summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/tests
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/ml/tests')
-rw-r--r--toolkit/components/ml/tests/browser/browser.toml6
-rw-r--r--toolkit/components/ml/tests/browser/browser_ml_cache.js233
-rw-r--r--toolkit/components/ml/tests/browser/browser_ml_engine.js155
-rw-r--r--toolkit/components/ml/tests/browser/browser_ml_utils.js26
-rw-r--r--toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/config.json21
-rw-r--r--toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/onnx/config.json21
-rw-r--r--toolkit/components/ml/tests/browser/head.js129
7 files changed, 386 insertions, 205 deletions
diff --git a/toolkit/components/ml/tests/browser/browser.toml b/toolkit/components/ml/tests/browser/browser.toml
index 57637c8bda..ed3c04f066 100644
--- a/toolkit/components/ml/tests/browser/browser.toml
+++ b/toolkit/components/ml/tests/browser/browser.toml
@@ -1,9 +1,15 @@
[DEFAULT]
+run-if = ["nightly_build"] # Bug 1890946 - enable the inference engine in release
support-files = [
"head.js",
"data/**/*.*"
]
["browser_ml_cache.js"]
+lineno = "7"
["browser_ml_engine.js"]
+lineno = "10"
+
+["browser_ml_utils.js"]
+lineno = "13"
diff --git a/toolkit/components/ml/tests/browser/browser_ml_cache.js b/toolkit/components/ml/tests/browser/browser_ml_cache.js
index d8725368bd..8d879fc74d 100644
--- a/toolkit/components/ml/tests/browser/browser_ml_cache.js
+++ b/toolkit/components/ml/tests/browser/browser_ml_cache.js
@@ -11,16 +11,20 @@ const FAKE_HUB =
"chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data";
const FAKE_MODEL_ARGS = {
- organization: "acme",
- modelName: "bert",
- modelVersion: "main",
+ model: "acme/bert",
+ revision: "main",
+ file: "config.json",
+};
+
+const FAKE_RELEASED_MODEL_ARGS = {
+ model: "acme/bert",
+ revision: "v0.1",
file: "config.json",
};
const FAKE_ONNX_MODEL_ARGS = {
- organization: "acme",
- modelName: "bert",
- modelVersion: "main",
+ model: "acme/bert",
+ revision: "main",
file: "onnx/config.json",
};
@@ -35,6 +39,9 @@ const badHubs = [
"https://model-hub.mozilla.org.hack", // Domain that contains allowed domain
];
+/**
+ * Make sure we reject bad model hub URLs.
+ */
add_task(async function test_bad_hubs() {
for (const badHub of badHubs) {
Assert.throws(
@@ -60,60 +67,57 @@ add_task(async function test_allowed_hub() {
const badInputs = [
[
{
- organization: "ac me",
- modelName: "bert",
- modelVersion: "main",
+ model: "ac me/bert",
+ revision: "main",
file: "config.json",
},
"Org can only contain letters, numbers, and hyphens",
],
[
{
- organization: "1111",
- modelName: "bert",
- modelVersion: "main",
+ model: "1111/bert",
+ revision: "main",
file: "config.json",
},
"Org cannot contain only numbers",
],
[
{
- organization: "-acme",
- modelName: "bert",
- modelVersion: "main",
+ model: "-acme/bert",
+ revision: "main",
file: "config.json",
},
"Org start or end with a hyphen, or use consecutive hyphens",
],
[
{
- organization: "a-c-m-e",
- modelName: "#bert",
- modelVersion: "main",
+ model: "a-c-m-e/#bert",
+ revision: "main",
file: "config.json",
},
"Models can only contain letters, numbers, and hyphens, underscord, periods",
],
[
{
- organization: "a-c-m-e",
- modelName: "b$ert",
- modelVersion: "main",
+ model: "a-c-m-e/b$ert",
+ revision: "main",
file: "config.json",
},
"Models cannot contain spaces or control characters",
],
[
{
- organization: "a-c-m-e",
- modelName: "b$ert",
- modelVersion: "main",
+ model: "a-c-m-e/b$ert",
+ revision: "main",
file: ".filename",
},
"File",
],
];
+/**
+ * Make sure we reject bad inputs.
+ */
add_task(async function test_bad_inputs() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
@@ -129,6 +133,9 @@ add_task(async function test_bad_inputs() {
}
});
+/**
+ * Test that we can retrieve a file as an ArrayBuffer.
+ */
add_task(async function test_getting_file() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
@@ -144,6 +151,35 @@ add_task(async function test_getting_file() {
Assert.equal(jsonData.hidden_size, 768);
});
+/**
+ * Test that we can retrieve a file from a released model and skip head calls
+ */
+add_task(async function test_getting_released_file() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+ console.log(hub);
+
+ let spy = sinon.spy(hub, "getETag");
+ let [array, headers] = await hub.getModelFileAsArrayBuffer(
+ FAKE_RELEASED_MODEL_ARGS
+ );
+
+ Assert.equal(headers["Content-Type"], "application/json");
+
+ // check the content of the file.
+ let jsonData = JSON.parse(
+ String.fromCharCode.apply(null, new Uint8Array(array))
+ );
+
+ Assert.equal(jsonData.hidden_size, 768);
+
+ // check that head calls were not made
+ Assert.ok(!spy.called, "getETag should have never been called.");
+ spy.restore();
+});
+
+/**
+ * Make sure files can be located in sub directories
+ */
add_task(async function test_getting_file_in_subdir() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
@@ -161,10 +197,13 @@ add_task(async function test_getting_file_in_subdir() {
Assert.equal(jsonData.hidden_size, 768);
});
+/**
+ * Test that we can use a custom URL template.
+ */
add_task(async function test_getting_file_custom_path() {
const hub = new ModelHub({
rootUrl: FAKE_HUB,
- urlTemplate: "${organization}/${modelName}/resolve/${modelVersion}/${file}",
+ urlTemplate: "{model}/resolve/{revision}",
});
let res = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
@@ -172,9 +211,11 @@ add_task(async function test_getting_file_custom_path() {
Assert.equal(res[1]["Content-Type"], "application/json");
});
+/**
+ * Test that we can't use an URL with a query for the template
+ */
add_task(async function test_getting_file_custom_path_rogue() {
- const urlTemplate =
- "${organization}/${modelName}/resolve/${modelVersion}/${file}?some_id=bedqwdw";
+ const urlTemplate = "{model}/resolve/{revision}/?some_id=bedqwdw";
Assert.throws(
() => new ModelHub({ rootUrl: FAKE_HUB, urlTemplate }),
/Invalid URL template/,
@@ -182,6 +223,9 @@ add_task(async function test_getting_file_custom_path_rogue() {
);
});
+/**
+ * Test that the file can be returned as a response and its content correct.
+ */
add_task(async function test_getting_file_as_response() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
@@ -192,6 +236,10 @@ add_task(async function test_getting_file_as_response() {
Assert.equal(jsonData.hidden_size, 768);
});
+/**
+ * Test that the cache is used when the data is retrieved from the server
+ * and that the cache is updated with the new data.
+ */
add_task(async function test_getting_file_from_cache() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
let array = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
@@ -213,6 +261,75 @@ add_task(async function test_getting_file_from_cache() {
Assert.deepEqual(array, array2);
});
+/**
+ * Test parsing of a well-formed full URL, including protocol and path.
+ */
+add_task(async function testWellFormedFullUrl() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+ const url = `${FAKE_HUB}/org1/model1/v1/file/path`;
+ const result = hub.parseUrl(url);
+
+ Assert.equal(
+ result.model,
+ "org1/model1",
+ "Model should be parsed correctly."
+ );
+ Assert.equal(result.revision, "v1", "Revision should be parsed correctly.");
+ Assert.equal(
+ result.file,
+ "file/path",
+ "File path should be parsed correctly."
+ );
+});
+
+/**
+ * Test parsing of a well-formed relative URL, starting with a slash.
+ */
+add_task(async function testWellFormedRelativeUrl() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+
+ const url = "/org1/model1/v1/file/path";
+ const result = hub.parseUrl(url);
+
+ Assert.equal(
+ result.model,
+ "org1/model1",
+ "Model should be parsed correctly."
+ );
+ Assert.equal(result.revision, "v1", "Revision should be parsed correctly.");
+ Assert.equal(
+ result.file,
+ "file/path",
+ "File path should be parsed correctly."
+ );
+});
+
+/**
+ * Ensures an error is thrown when the URL does not start with the expected root URL or a slash.
+ */
+add_task(async function testInvalidDomain() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+ const url = "https://example.com/org1/model1/v1/file/path";
+ Assert.throws(
+ () => hub.parseUrl(url),
+ new RegExp(`Error: Invalid domain for model URL: ${url}`),
+ `Should throw with ${url}`
+ );
+});
+
+/** Tests the method's error handling when the URL format does not include the required segments.
+ *
+ */
+add_task(async function testTooFewParts() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+ const url = "/org1/model1";
+ Assert.throws(
+ () => hub.parseUrl(url),
+ new RegExp(`Error: Invalid model URL: ${url}`),
+ `Should throw with ${url}`
+ );
+});
+
// IndexedDB tests
/**
@@ -248,18 +365,41 @@ add_task(async function test_Init() {
});
/**
+ * Test checking existence of data in the cache.
+ */
+add_task(async function test_PutAndCheckExists() {
+ const cache = await initializeCache();
+ const testData = new ArrayBuffer(8); // Example data
+ const key = "file.txt";
+ await cache.put("org/model", "v1", "file.txt", testData, {
+ ETag: "ETAG123",
+ });
+
+ // Checking if the file exists
+ let exists = await cache.fileExists("org/model", "v1", key);
+ Assert.ok(exists, "The file should exist in the cache.");
+
+ // Removing all files from the model
+ await cache.deleteModel("org/model", "v1");
+
+ exists = await cache.fileExists("org/model", "v1", key);
+ Assert.ok(!exists, "The file should be gone from the cache.");
+
+ await deleteCache(cache);
+});
+
+/**
* Test adding data to the cache and retrieving it.
*/
add_task(async function test_PutAndGet() {
const cache = await initializeCache();
const testData = new ArrayBuffer(8); // Example data
- await cache.put("org", "model", "v1", "file.txt", testData, {
+ await cache.put("org/model", "v1", "file.txt", testData, {
ETag: "ETAG123",
});
const [retrievedData, headers] = await cache.getFile(
- "org",
- "model",
+ "org/model",
"v1",
"file.txt"
);
@@ -289,14 +429,9 @@ add_task(async function test_GetHeaders() {
extra: "extra",
};
- await cache.put("org", "model", "v1", "file.txt", testData, headers);
+ await cache.put("org/model", "v1", "file.txt", testData, headers);
- const storedHeaders = await cache.getHeaders(
- "org",
- "model",
- "v1",
- "file.txt"
- );
+ const storedHeaders = await cache.getHeaders("org/model", "v1", "file.txt");
// The `extra` field should be removed from the stored headers because
// it's not part of the allowed keys.
@@ -318,22 +453,8 @@ add_task(async function test_GetHeaders() {
*/
add_task(async function test_ListModels() {
const cache = await initializeCache();
- await cache.put(
- "org1",
- "modelA",
- "v1",
- "file1.txt",
- new ArrayBuffer(8),
- null
- );
- await cache.put(
- "org2",
- "modelB",
- "v1",
- "file2.txt",
- new ArrayBuffer(8),
- null
- );
+ await cache.put("org1/modelA", "v1", "file1.txt", new ArrayBuffer(8), null);
+ await cache.put("org2/modelB", "v1", "file2.txt", new ArrayBuffer(8), null);
const models = await cache.listModels();
Assert.ok(
@@ -348,10 +469,10 @@ add_task(async function test_ListModels() {
*/
add_task(async function test_DeleteModel() {
const cache = await initializeCache();
- await cache.put("org", "model", "v1", "file.txt", new ArrayBuffer(8), null);
- await cache.deleteModel("org", "model", "v1");
+ await cache.put("org/model", "v1", "file.txt", new ArrayBuffer(8), null);
+ await cache.deleteModel("org/model", "v1");
- const dataAfterDelete = await cache.getFile("org", "model", "v1", "file.txt");
+ const dataAfterDelete = await cache.getFile("org/model", "v1", "file.txt");
Assert.equal(
dataAfterDelete,
null,
diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine.js b/toolkit/components/ml/tests/browser/browser_ml_engine.js
index 6942809d6d..3a93b78182 100644
--- a/toolkit/components/ml/tests/browser/browser_ml_engine.js
+++ b/toolkit/components/ml/tests/browser/browser_ml_engine.js
@@ -5,6 +5,8 @@
/// <reference path="head.js" />
+requestLongerTimeout(2);
+
async function setup({ disabled = false, prefs = [] } = {}) {
const { removeMocks, remoteClients } = await createAndMockMLRemoteSettings({
autoDownloadFromRemoteSettings: false,
@@ -15,6 +17,7 @@ async function setup({ disabled = false, prefs = [] } = {}) {
// Enabled by default.
["browser.ml.enable", !disabled],
["browser.ml.logLevel", "All"],
+ ["browser.ml.modelCacheTimeout", 1000],
...prefs,
],
});
@@ -33,6 +36,8 @@ async function setup({ disabled = false, prefs = [] } = {}) {
};
}
+const PIPELINE_OPTIONS = new PipelineOptions({ taskName: "echo" });
+
add_task(async function test_ml_engine_basics() {
const { cleanup, remoteClients } = await setup();
@@ -40,22 +45,18 @@ add_task(async function test_ml_engine_basics() {
const mlEngineParent = await EngineProcess.getMLEngineParent();
info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
+ const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
info("Run the summarizer");
- const summarizePromise = summarizer.run("This gets cut in half.");
+ const inferencePromise = summarizer.run({ data: "This gets echoed." });
info("Wait for the pending downloads.");
- await remoteClients.models.resolvePendingDownloads(1);
- await remoteClients.wasm.resolvePendingDownloads(1);
+ await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
- is(
- await summarizePromise,
- "This gets c",
- "The text gets cut in half simulating summarizing"
+ Assert.equal(
+ (await inferencePromise).output,
+ "This gets echoed.",
+ "The text get echoed exercising the whole flow."
);
ok(
@@ -68,40 +69,6 @@ add_task(async function test_ml_engine_basics() {
await cleanup();
});
-add_task(async function test_ml_engine_model_rejection() {
- const { cleanup, remoteClients } = await setup();
-
- info("Get the engine process");
- const mlEngineParent = await EngineProcess.getMLEngineParent();
-
- info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
-
- info("Run the summarizer");
- const summarizePromise = summarizer.run("This gets cut in half.");
-
- info("Wait for the pending downloads.");
- await remoteClients.wasm.resolvePendingDownloads(1);
- await remoteClients.models.rejectPendingDownloads(1);
-
- let error;
- try {
- await summarizePromise;
- } catch (e) {
- error = e;
- }
- is(
- error?.message,
- "Intentionally rejecting downloads.",
- "The error is correctly surfaced."
- );
-
- await cleanup();
-});
-
add_task(async function test_ml_engine_wasm_rejection() {
const { cleanup, remoteClients } = await setup();
@@ -109,21 +76,18 @@ add_task(async function test_ml_engine_wasm_rejection() {
const mlEngineParent = await EngineProcess.getMLEngineParent();
info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
+ const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
info("Run the summarizer");
- const summarizePromise = summarizer.run("This gets cut in half.");
+ const inferencePromise = summarizer.run({ data: "This gets echoed." });
info("Wait for the pending downloads.");
- await remoteClients.wasm.rejectPendingDownloads(1);
- await remoteClients.models.resolvePendingDownloads(1);
+ await remoteClients["ml-onnx-runtime"].rejectPendingDownloads(1);
+ //await remoteClients.models.resolvePendingDownloads(1);
let error;
try {
- await summarizePromise;
+ await inferencePromise;
} catch (e) {
error = e;
}
@@ -146,21 +110,18 @@ add_task(async function test_ml_engine_model_error() {
const mlEngineParent = await EngineProcess.getMLEngineParent();
info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
+ const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
info("Run the summarizer with a throwing example.");
- const summarizePromise = summarizer.run("throw");
+ const inferencePromise = summarizer.run("throw");
info("Wait for the pending downloads.");
- await remoteClients.wasm.resolvePendingDownloads(1);
- await remoteClients.models.resolvePendingDownloads(1);
+ await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
+ //await remoteClients.models.resolvePendingDownloads(1);
let error;
try {
- await summarizePromise;
+ await inferencePromise;
} catch (e) {
error = e;
}
@@ -186,22 +147,18 @@ add_task(async function test_ml_engine_destruction() {
const mlEngineParent = await EngineProcess.getMLEngineParent();
info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
+ const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
info("Run the summarizer");
- const summarizePromise = summarizer.run("This gets cut in half.");
+ const inferencePromise = summarizer.run({ data: "This gets echoed." });
info("Wait for the pending downloads.");
- await remoteClients.models.resolvePendingDownloads(1);
- await remoteClients.wasm.resolvePendingDownloads(1);
+ await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
- is(
- await summarizePromise,
- "This gets c",
- "The text gets cut in half simulating summarizing"
+ Assert.equal(
+ (await inferencePromise).output,
+ "This gets echoed.",
+ "The text get echoed exercising the whole flow."
);
ok(
@@ -217,3 +174,57 @@ add_task(async function test_ml_engine_destruction() {
await cleanup();
});
+
+/**
+ * Tests that we display a nice error message when the pref is off
+ */
+add_task(async function test_pref_is_off() {
+ await SpecialPowers.pushPrefEnv({
+ set: [["browser.ml.enable", false]],
+ });
+
+ info("Get the engine process");
+ let error;
+
+ try {
+ await EngineProcess.getMLEngineParent();
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ "MLEngine is disabled. Check the browser.ml prefs.",
+ "The error is correctly surfaced."
+ );
+
+ await SpecialPowers.pushPrefEnv({
+ set: [["browser.ml.enable", true]],
+ });
+});
+
+/**
+ * Tests that we verify the task name is valid
+ */
+add_task(async function test_invalid_task_name() {
+ const { cleanup, remoteClients } = await setup();
+
+ const options = new PipelineOptions({ taskName: "inv#alid" });
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+ const summarizer = mlEngineParent.getEngine(options);
+
+ let error;
+
+ try {
+ const res = summarizer.run({ data: "This gets echoed." });
+ await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
+ await res;
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ "Invalid task name. Task name should contain only alphanumeric characters and underscores/dashes.",
+ "The error is correctly surfaced."
+ );
+ await cleanup();
+});
diff --git a/toolkit/components/ml/tests/browser/browser_ml_utils.js b/toolkit/components/ml/tests/browser/browser_ml_utils.js
new file mode 100644
index 0000000000..c215349af4
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/browser_ml_utils.js
@@ -0,0 +1,26 @@
+/* Any copyright is dedicated to the Public Domain.
+http://creativecommons.org/publicdomain/zero/1.0/ */
+"use strict";
+
+const { arrayBufferToBlobURL } = ChromeUtils.importESModule(
+ "chrome://global/content/ml/Utils.sys.mjs"
+);
+
+/**
+ * Test arrayBufferToBlobURL function.
+ */
+add_task(async function test_ml_utils_array_buffer_to_blob_url() {
+ const buffer = new ArrayBuffer(8);
+ const blobURL = arrayBufferToBlobURL(buffer);
+
+ Assert.equal(
+ typeof blobURL,
+ "string",
+ "arrayBufferToBlobURL should return a string"
+ );
+ Assert.equal(
+ blobURL.startsWith("blob:"),
+ true,
+ "The returned string should be a Blob URL"
+ );
+});
diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/config.json
new file mode 100644
index 0000000000..50dbb760bb
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/config.json
@@ -0,0 +1,21 @@
+{
+ "architectures": ["BertForMaskedLM"],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/onnx/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/onnx/config.json
new file mode 100644
index 0000000000..50dbb760bb
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/onnx/config.json
@@ -0,0 +1,21 @@
+{
+ "architectures": ["BertForMaskedLM"],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
diff --git a/toolkit/components/ml/tests/browser/head.js b/toolkit/components/ml/tests/browser/head.js
index 9fc20c0e84..25e611a753 100644
--- a/toolkit/components/ml/tests/browser/head.js
+++ b/toolkit/components/ml/tests/browser/head.js
@@ -6,13 +6,6 @@
"use strict";
/**
- * @type {import("../../content/SummarizerModel.sys.mjs")}
- */
-const { SummarizerModel } = ChromeUtils.importESModule(
- "chrome://global/content/ml/SummarizerModel.sys.mjs"
-);
-
-/**
* @type {import("../../actors/MLEngineParent.sys.mjs")}
*/
const { MLEngineParent } = ChromeUtils.importESModule(
@@ -23,6 +16,14 @@ const { ModelHub, IndexedDBCache } = ChromeUtils.importESModule(
"chrome://global/content/ml/ModelHub.sys.mjs"
);
+const { getRuntimeWasmFilename } = ChromeUtils.importESModule(
+ "chrome://global/content/ml/Utils.sys.mjs"
+);
+
+const { PipelineOptions } = ChromeUtils.importESModule(
+ "chrome://global/content/ml/EngineProcess.sys.mjs"
+);
+
// This test suite shares some utility functions with translations as they work in a very
// similar fashion. Eventually, the plan is to unify these two components.
Services.scriptloader.loadSubScript(
@@ -30,94 +31,39 @@ Services.scriptloader.loadSubScript(
this
);
-function getDefaultModelRecords() {
- return [
- {
- name: "summarizer-model",
- version: SummarizerModel.MODEL_MAJOR_VERSION + ".0",
- },
- ];
-}
-
function getDefaultWasmRecords() {
return [
{
- name: "inference-engine",
+ name: getRuntimeWasmFilename(),
version: MLEngineParent.WASM_MAJOR_VERSION + ".0",
},
];
}
-/**
- * Creates a local RemoteSettingsClient for use within tests.
- *
- * @param {boolean} autoDownloadFromRemoteSettings
- * @param {object[]} models
- * @returns {AttachmentMock}
- */
-async function createMLModelsRemoteClient(
- autoDownloadFromRemoteSettings,
- models = getDefaultModelRecords()
-) {
- const { RemoteSettings } = ChromeUtils.importESModule(
- "resource://services-settings/remote-settings.sys.mjs"
- );
- const mockedCollectionName = "test-ml-models";
- const client = RemoteSettings(
- `${mockedCollectionName}-${_remoteSettingsMockId++}`
- );
- const metadata = {};
- await client.db.clear();
- await client.db.importChanges(
- metadata,
- Date.now(),
- models.map(({ name, version }) => ({
- id: crypto.randomUUID(),
- name,
- version,
- last_modified: Date.now(),
- schema: Date.now(),
- attachment: {
- hash: `${crypto.randomUUID()}`,
- size: `123`,
- filename: name,
- location: `main-workspace/ml-models/${crypto.randomUUID()}.bin`,
- mimetype: "application/octet-stream",
- },
- }))
- );
-
- return createAttachmentMock(
- client,
- mockedCollectionName,
- autoDownloadFromRemoteSettings
- );
-}
-
async function createAndMockMLRemoteSettings({
- models = getDefaultModelRecords(),
autoDownloadFromRemoteSettings = false,
} = {}) {
+ const runtime = await createMLWasmRemoteClient(
+ autoDownloadFromRemoteSettings
+ );
+ const options = await createOptionsRemoteClient();
+
const remoteClients = {
- models: await createMLModelsRemoteClient(
- autoDownloadFromRemoteSettings,
- models
- ),
- wasm: await createMLWasmRemoteClient(autoDownloadFromRemoteSettings),
+ "ml-onnx-runtime": runtime,
+ "ml-inference-options": options,
};
- MLEngineParent.mockRemoteSettings(remoteClients.wasm.client);
- SummarizerModel.mockRemoteSettings(remoteClients.models.client);
+ MLEngineParent.mockRemoteSettings({
+ "ml-onnx-runtime": runtime.client,
+ "ml-inference-options": options,
+ });
return {
async removeMocks() {
- await remoteClients.models.client.attachments.deleteAll();
- await remoteClients.models.client.db.clear();
- await remoteClients.wasm.client.attachments.deleteAll();
- await remoteClients.wasm.client.db.clear();
-
+ await runtime.client.attachments.deleteAll();
+ await runtime.client.db.clear();
+ await options.db.clear();
MLEngineParent.removeMocks();
- SummarizerModel.removeMocks();
},
remoteClients,
};
@@ -157,3 +103,32 @@ async function createMLWasmRemoteClient(autoDownloadFromRemoteSettings) {
autoDownloadFromRemoteSettings
);
}
+
+/**
+ * Creates a local RemoteSettingsClient for use within tests.
+ *
+ * @returns {RemoteSettings}
+ */
+async function createOptionsRemoteClient() {
+ const { RemoteSettings } = ChromeUtils.importESModule(
+ "resource://services-settings/remote-settings.sys.mjs"
+ );
+ const mockedCollectionName = "test-ml-inference-options";
+ const client = RemoteSettings(
+ `${mockedCollectionName}-${_remoteSettingsMockId++}`
+ );
+
+ const record = {
+ taskName: "echo",
+ modelId: "mozilla/distilvit",
+ processorId: "mozilla/distilvit",
+ tokenizerId: "mozilla/distilvit",
+ modelRevision: "main",
+ processorRevision: "main",
+ tokenizerRevision: "main",
+ id: "74a71cfd-1734-44e6-85c0-69cf3e874138",
+ };
+ await client.db.clear();
+ await client.db.importChanges({}, Date.now(), [record]);
+ return client;
+}