summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/tests/browser
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/ml/tests/browser')
-rw-r--r--toolkit/components/ml/tests/browser/browser.toml5
-rw-r--r--toolkit/components/ml/tests/browser/browser_ml_engine.js219
-rw-r--r--toolkit/components/ml/tests/browser/head.js155
3 files changed, 379 insertions, 0 deletions
diff --git a/toolkit/components/ml/tests/browser/browser.toml b/toolkit/components/ml/tests/browser/browser.toml
new file mode 100644
index 0000000000..9ccda0beaa
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/browser.toml
@@ -0,0 +1,5 @@
+[DEFAULT]
+support-files = [
+ "head.js",
+]
+["browser_ml_engine.js"]
diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine.js b/toolkit/components/ml/tests/browser/browser_ml_engine.js
new file mode 100644
index 0000000000..6942809d6d
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/browser_ml_engine.js
@@ -0,0 +1,219 @@
+/* Any copyright is dedicated to the Public Domain.
+ http://creativecommons.org/publicdomain/zero/1.0/ */
+
+"use strict";
+
+/// <reference path="head.js" />
+
+async function setup({ disabled = false, prefs = [] } = {}) {
+ const { removeMocks, remoteClients } = await createAndMockMLRemoteSettings({
+ autoDownloadFromRemoteSettings: false,
+ });
+
+ await SpecialPowers.pushPrefEnv({
+ set: [
+ // Enabled by default.
+ ["browser.ml.enable", !disabled],
+ ["browser.ml.logLevel", "All"],
+ ...prefs,
+ ],
+ });
+
+ return {
+ remoteClients,
+ async cleanup() {
+ await removeMocks();
+ await waitForCondition(
+ () => EngineProcess.areAllEnginesTerminated(),
+ "Waiting for all of the engines to be terminated.",
+ 100,
+ 200
+ );
+ },
+ };
+}
+
+add_task(async function test_ml_engine_basics() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer");
+ const summarizePromise = summarizer.run("This gets cut in half.");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.models.resolvePendingDownloads(1);
+ await remoteClients.wasm.resolvePendingDownloads(1);
+
+ is(
+ await summarizePromise,
+ "This gets c",
+ "The text gets cut in half simulating summarizing"
+ );
+
+ ok(
+ !EngineProcess.areAllEnginesTerminated(),
+ "The engine process is still active."
+ );
+
+ await EngineProcess.destroyMLEngine();
+
+ await cleanup();
+});
+
+add_task(async function test_ml_engine_model_rejection() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer");
+ const summarizePromise = summarizer.run("This gets cut in half.");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.wasm.resolvePendingDownloads(1);
+ await remoteClients.models.rejectPendingDownloads(1);
+
+ let error;
+ try {
+ await summarizePromise;
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ "Intentionally rejecting downloads.",
+ "The error is correctly surfaced."
+ );
+
+ await cleanup();
+});
+
+add_task(async function test_ml_engine_wasm_rejection() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer");
+ const summarizePromise = summarizer.run("This gets cut in half.");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.wasm.rejectPendingDownloads(1);
+ await remoteClients.models.resolvePendingDownloads(1);
+
+ let error;
+ try {
+ await summarizePromise;
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ "Intentionally rejecting downloads.",
+ "The error is correctly surfaced."
+ );
+
+ await cleanup();
+});
+
+/**
+ * Tests that the SummarizerModel's internal errors are correctly surfaced.
+ */
+add_task(async function test_ml_engine_model_error() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer with a throwing example.");
+ const summarizePromise = summarizer.run("throw");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.wasm.resolvePendingDownloads(1);
+ await remoteClients.models.resolvePendingDownloads(1);
+
+ let error;
+ try {
+ await summarizePromise;
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ 'Error: Received the message "throw", so intentionally throwing an error.',
+ "The error is correctly surfaced."
+ );
+
+ summarizer.terminate();
+
+ await cleanup();
+});
+
+/**
+ * This test is really similar to the "basic" test, but tests manually destroying
+ * the summarizer.
+ */
+add_task(async function test_ml_engine_destruction() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer");
+ const summarizePromise = summarizer.run("This gets cut in half.");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.models.resolvePendingDownloads(1);
+ await remoteClients.wasm.resolvePendingDownloads(1);
+
+ is(
+ await summarizePromise,
+ "This gets c",
+ "The text gets cut in half simulating summarizing"
+ );
+
+ ok(
+ !EngineProcess.areAllEnginesTerminated(),
+ "The engine process is still active."
+ );
+
+ summarizer.terminate();
+
+ info(
+ "The summarizer is manually destroyed. The cleanup function should wait for the engine process to be destroyed."
+ );
+
+ await cleanup();
+});
diff --git a/toolkit/components/ml/tests/browser/head.js b/toolkit/components/ml/tests/browser/head.js
new file mode 100644
index 0000000000..99d27ce18a
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/head.js
@@ -0,0 +1,155 @@
+/* Any copyright is dedicated to the Public Domain.
+ http://creativecommons.org/publicdomain/zero/1.0/ */
+
+/// <reference path="../../../../../toolkit/components/translations/tests/browser/shared-head.js" />
+
+"use strict";
+
+/**
+ * @type {import("../../content/SummarizerModel.sys.mjs")}
+ */
+const { SummarizerModel } = ChromeUtils.importESModule(
+ "chrome://global/content/ml/SummarizerModel.sys.mjs"
+);
+
+/**
+ * @type {import("../../actors/MLEngineParent.sys.mjs")}
+ */
+const { MLEngineParent } = ChromeUtils.importESModule(
+ "resource://gre/actors/MLEngineParent.sys.mjs"
+);
+
+// This test suite shares some utility functions with translations as they work in a very
+// similar fashion. Eventually, the plan is to unify these two components.
+Services.scriptloader.loadSubScript(
+ "chrome://mochitests/content/browser/toolkit/components/translations/tests/browser/shared-head.js",
+ this
+);
+
+function getDefaultModelRecords() {
+ return [
+ {
+ name: "summarizer-model",
+ version: SummarizerModel.MODEL_MAJOR_VERSION + ".0",
+ },
+ ];
+}
+
+function getDefaultWasmRecords() {
+ return [
+ {
+ name: "inference-engine",
+ version: MLEngineParent.WASM_MAJOR_VERSION + ".0",
+ },
+ ];
+}
+
+/**
+ * Creates a local RemoteSettingsClient for use within tests.
+ *
+ * @param {boolean} autoDownloadFromRemoteSettings
+ * @param {object[]} models
+ * @returns {AttachmentMock}
+ */
+async function createMLModelsRemoteClient(
+ autoDownloadFromRemoteSettings,
+ models = getDefaultModelRecords()
+) {
+ const { RemoteSettings } = ChromeUtils.importESModule(
+ "resource://services-settings/remote-settings.sys.mjs"
+ );
+ const mockedCollectionName = "test-ml-models";
+ const client = RemoteSettings(
+ `${mockedCollectionName}-${_remoteSettingsMockId++}`
+ );
+ const metadata = {};
+ await client.db.clear();
+ await client.db.importChanges(
+ metadata,
+ Date.now(),
+ models.map(({ name, version }) => ({
+ id: crypto.randomUUID(),
+ name,
+ version,
+ last_modified: Date.now(),
+ schema: Date.now(),
+ attachment: {
+ hash: `${crypto.randomUUID()}`,
+ size: `123`,
+ filename: name,
+ location: `main-workspace/ml-models/${crypto.randomUUID()}.bin`,
+ mimetype: "application/octet-stream",
+ },
+ }))
+ );
+
+ return createAttachmentMock(
+ client,
+ mockedCollectionName,
+ autoDownloadFromRemoteSettings
+ );
+}
+
+async function createAndMockMLRemoteSettings({
+ models = getDefaultModelRecords(),
+ autoDownloadFromRemoteSettings = false,
+} = {}) {
+ const remoteClients = {
+ models: await createMLModelsRemoteClient(
+ autoDownloadFromRemoteSettings,
+ models
+ ),
+ wasm: await createMLWasmRemoteClient(autoDownloadFromRemoteSettings),
+ };
+
+ MLEngineParent.mockRemoteSettings(remoteClients.wasm.client);
+ SummarizerModel.mockRemoteSettings(remoteClients.models.client);
+
+ return {
+ async removeMocks() {
+ await remoteClients.models.client.attachments.deleteAll();
+ await remoteClients.models.client.db.clear();
+ await remoteClients.wasm.client.attachments.deleteAll();
+ await remoteClients.wasm.client.db.clear();
+
+ MLEngineParent.removeMocks();
+ SummarizerModel.removeMocks();
+ },
+ remoteClients,
+ };
+}
+
+/**
+ * Creates a local RemoteSettingsClient for use within tests.
+ *
+ * @param {boolean} autoDownloadFromRemoteSettings
+ * @returns {AttachmentMock}
+ */
+async function createMLWasmRemoteClient(autoDownloadFromRemoteSettings) {
+ const { RemoteSettings } = ChromeUtils.importESModule(
+ "resource://services-settings/remote-settings.sys.mjs"
+ );
+ const mockedCollectionName = "test-translation-wasm";
+ const client = RemoteSettings(
+ `${mockedCollectionName}-${_remoteSettingsMockId++}`
+ );
+ const metadata = {};
+ await client.db.clear();
+ await client.db.importChanges(
+ metadata,
+ Date.now(),
+ getDefaultWasmRecords().map(({ name, version }) => ({
+ id: crypto.randomUUID(),
+ name,
+ version,
+ last_modified: Date.now(),
+ schema: Date.now(),
+ }))
+ );
+
+ return createAttachmentMock(
+ client,
+ mockedCollectionName,
+ autoDownloadFromRemoteSettings
+ );
+}