Diffstat (limited to 'toolkit/components/ml/tests/browser/browser_ml_engine.js')
-rw-r--r--  toolkit/components/ml/tests/browser/browser_ml_engine.js  155
1 file changed, 83 insertions(+), 72 deletions(-)
diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine.js b/toolkit/components/ml/tests/browser/browser_ml_engine.js
index 6942809d6d..3a93b78182 100644
--- a/toolkit/components/ml/tests/browser/browser_ml_engine.js
+++ b/toolkit/components/ml/tests/browser/browser_ml_engine.js
@@ -5,6 +5,8 @@
/// <reference path="head.js" />
+requestLongerTimeout(2);
+
async function setup({ disabled = false, prefs = [] } = {}) {
const { removeMocks, remoteClients } = await createAndMockMLRemoteSettings({
autoDownloadFromRemoteSettings: false,
@@ -15,6 +17,7 @@ async function setup({ disabled = false, prefs = [] } = {}) {
// Enabled by default.
["browser.ml.enable", !disabled],
["browser.ml.logLevel", "All"],
+ ["browser.ml.modelCacheTimeout", 1000],
...prefs,
],
});
@@ -33,6 +36,8 @@ async function setup({ disabled = false, prefs = [] } = {}) {
};
}
+const PIPELINE_OPTIONS = new PipelineOptions({ taskName: "echo" });
+
add_task(async function test_ml_engine_basics() {
const { cleanup, remoteClients } = await setup();
@@ -40,22 +45,18 @@ add_task(async function test_ml_engine_basics() {
const mlEngineParent = await EngineProcess.getMLEngineParent();
info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
+ const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
info("Run the summarizer");
- const summarizePromise = summarizer.run("This gets cut in half.");
+ const inferencePromise = summarizer.run({ data: "This gets echoed." });
info("Wait for the pending downloads.");
- await remoteClients.models.resolvePendingDownloads(1);
- await remoteClients.wasm.resolvePendingDownloads(1);
+ await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
- is(
- await summarizePromise,
- "This gets c",
- "The text gets cut in half simulating summarizing"
+ Assert.equal(
+ (await inferencePromise).output,
+ "This gets echoed.",
+ "The text get echoed exercising the whole flow."
);
ok(
@@ -68,40 +69,6 @@ add_task(async function test_ml_engine_basics() {
await cleanup();
});
-add_task(async function test_ml_engine_model_rejection() {
- const { cleanup, remoteClients } = await setup();
-
- info("Get the engine process");
- const mlEngineParent = await EngineProcess.getMLEngineParent();
-
- info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
-
- info("Run the summarizer");
- const summarizePromise = summarizer.run("This gets cut in half.");
-
- info("Wait for the pending downloads.");
- await remoteClients.wasm.resolvePendingDownloads(1);
- await remoteClients.models.rejectPendingDownloads(1);
-
- let error;
- try {
- await summarizePromise;
- } catch (e) {
- error = e;
- }
- is(
- error?.message,
- "Intentionally rejecting downloads.",
- "The error is correctly surfaced."
- );
-
- await cleanup();
-});
-
add_task(async function test_ml_engine_wasm_rejection() {
const { cleanup, remoteClients } = await setup();
@@ -109,21 +76,18 @@ add_task(async function test_ml_engine_wasm_rejection() {
const mlEngineParent = await EngineProcess.getMLEngineParent();
info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
+ const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
info("Run the summarizer");
- const summarizePromise = summarizer.run("This gets cut in half.");
+ const inferencePromise = summarizer.run({ data: "This gets echoed." });
info("Wait for the pending downloads.");
- await remoteClients.wasm.rejectPendingDownloads(1);
- await remoteClients.models.resolvePendingDownloads(1);
+ await remoteClients["ml-onnx-runtime"].rejectPendingDownloads(1);
let error;
try {
- await summarizePromise;
+ await inferencePromise;
} catch (e) {
error = e;
}
@@ -146,21 +110,18 @@ add_task(async function test_ml_engine_model_error() {
const mlEngineParent = await EngineProcess.getMLEngineParent();
info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
+ const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
info("Run the summarizer with a throwing example.");
- const summarizePromise = summarizer.run("throw");
+ const inferencePromise = summarizer.run("throw");
info("Wait for the pending downloads.");
- await remoteClients.wasm.resolvePendingDownloads(1);
- await remoteClients.models.resolvePendingDownloads(1);
+ await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
let error;
try {
- await summarizePromise;
+ await inferencePromise;
} catch (e) {
error = e;
}
@@ -186,22 +147,18 @@ add_task(async function test_ml_engine_destruction() {
const mlEngineParent = await EngineProcess.getMLEngineParent();
info("Get summarizer");
- const summarizer = mlEngineParent.getEngine(
- "summarizer",
- SummarizerModel.getModel
- );
+ const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
info("Run the summarizer");
- const summarizePromise = summarizer.run("This gets cut in half.");
+ const inferencePromise = summarizer.run({ data: "This gets echoed." });
info("Wait for the pending downloads.");
- await remoteClients.models.resolvePendingDownloads(1);
- await remoteClients.wasm.resolvePendingDownloads(1);
+ await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
- is(
- await summarizePromise,
- "This gets c",
- "The text gets cut in half simulating summarizing"
+ Assert.equal(
+ (await inferencePromise).output,
+ "This gets echoed.",
+ "The text get echoed exercising the whole flow."
);
ok(
@@ -217,3 +174,57 @@ add_task(async function test_ml_engine_destruction() {
await cleanup();
});
+
+/**
+ * Tests that we display a nice error message when the pref is off
+ */
+add_task(async function test_pref_is_off() {
+ await SpecialPowers.pushPrefEnv({
+ set: [["browser.ml.enable", false]],
+ });
+
+ info("Get the engine process");
+ let error;
+
+ try {
+ await EngineProcess.getMLEngineParent();
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ "MLEngine is disabled. Check the browser.ml prefs.",
+ "The error is correctly surfaced."
+ );
+
+ await SpecialPowers.pushPrefEnv({
+ set: [["browser.ml.enable", true]],
+ });
+});
+
+/**
+ * Tests that we verify the task name is valid
+ */
+add_task(async function test_invalid_task_name() {
+ const { cleanup, remoteClients } = await setup();
+
+ const options = new PipelineOptions({ taskName: "inv#alid" });
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+ const summarizer = mlEngineParent.getEngine(options);
+
+ let error;
+
+ try {
+ const res = summarizer.run({ data: "This gets echoed." });
+ await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
+ await res;
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ "Invalid task name. Task name should contain only alphanumeric characters and underscores/dashes.",
+ "The error is correctly surfaced."
+ );
+ await cleanup();
+});
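
Taken together, the tests above show the shape of the new PipelineOptions-based API: engines are requested with a PipelineOptions instance instead of a ("summarizer", SummarizerModel.getModel) pair, run() takes a request object and resolves with a record carrying an output field, and both the wasm runtime and model artifacts arrive through the single "ml-onnx-runtime" mock client. As a rough usage sketch, assuming the mocked setup() and the "echo" task from this test file (not a general-purpose API guide), the flow a caller follows is:

const { cleanup, remoteClients } = await setup();

// Engines are keyed by a PipelineOptions instance rather than a
// (name, getModel) pair.
const mlEngineParent = await EngineProcess.getMLEngineParent();
const engine = mlEngineParent.getEngine(
  new PipelineOptions({ taskName: "echo" })
);

// run() takes a request object; the result record exposes `output`.
const request = engine.run({ data: "hello" });

// Runtime and model downloads are both served by the single
// "ml-onnx-runtime" mock client in this test harness.
await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);

Assert.equal((await request).output, "hello", "The echo task round-trips.");

await cleanup();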