diff options
Diffstat (limited to 'toolkit/components/ml/tests/browser/browser_ml_engine.js')
-rw-r--r-- | toolkit/components/ml/tests/browser/browser_ml_engine.js | 155 |
1 files changed, 83 insertions, 72 deletions
diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine.js b/toolkit/components/ml/tests/browser/browser_ml_engine.js index 6942809d6d..3a93b78182 100644 --- a/toolkit/components/ml/tests/browser/browser_ml_engine.js +++ b/toolkit/components/ml/tests/browser/browser_ml_engine.js @@ -5,6 +5,8 @@ /// <reference path="head.js" /> +requestLongerTimeout(2); + async function setup({ disabled = false, prefs = [] } = {}) { const { removeMocks, remoteClients } = await createAndMockMLRemoteSettings({ autoDownloadFromRemoteSettings: false, @@ -15,6 +17,7 @@ async function setup({ disabled = false, prefs = [] } = {}) { // Enabled by default. ["browser.ml.enable", !disabled], ["browser.ml.logLevel", "All"], + ["browser.ml.modelCacheTimeout", 1000], ...prefs, ], }); @@ -33,6 +36,8 @@ async function setup({ disabled = false, prefs = [] } = {}) { }; } +const PIPELINE_OPTIONS = new PipelineOptions({ taskName: "echo" }); + add_task(async function test_ml_engine_basics() { const { cleanup, remoteClients } = await setup(); @@ -40,22 +45,18 @@ add_task(async function test_ml_engine_basics() { const mlEngineParent = await EngineProcess.getMLEngineParent(); info("Get summarizer"); - const summarizer = mlEngineParent.getEngine( - "summarizer", - SummarizerModel.getModel - ); + const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS); info("Run the summarizer"); - const summarizePromise = summarizer.run("This gets cut in half."); + const inferencePromise = summarizer.run({ data: "This gets echoed." }); info("Wait for the pending downloads."); - await remoteClients.models.resolvePendingDownloads(1); - await remoteClients.wasm.resolvePendingDownloads(1); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - is( - await summarizePromise, - "This gets c", - "The text gets cut in half simulating summarizing" + Assert.equal( + (await inferencePromise).output, + "This gets echoed.", + "The text get echoed exercising the whole flow." ); ok( @@ -68,40 +69,6 @@ add_task(async function test_ml_engine_basics() { await cleanup(); }); -add_task(async function test_ml_engine_model_rejection() { - const { cleanup, remoteClients } = await setup(); - - info("Get the engine process"); - const mlEngineParent = await EngineProcess.getMLEngineParent(); - - info("Get summarizer"); - const summarizer = mlEngineParent.getEngine( - "summarizer", - SummarizerModel.getModel - ); - - info("Run the summarizer"); - const summarizePromise = summarizer.run("This gets cut in half."); - - info("Wait for the pending downloads."); - await remoteClients.wasm.resolvePendingDownloads(1); - await remoteClients.models.rejectPendingDownloads(1); - - let error; - try { - await summarizePromise; - } catch (e) { - error = e; - } - is( - error?.message, - "Intentionally rejecting downloads.", - "The error is correctly surfaced." - ); - - await cleanup(); -}); - add_task(async function test_ml_engine_wasm_rejection() { const { cleanup, remoteClients } = await setup(); @@ -109,21 +76,18 @@ add_task(async function test_ml_engine_wasm_rejection() { const mlEngineParent = await EngineProcess.getMLEngineParent(); info("Get summarizer"); - const summarizer = mlEngineParent.getEngine( - "summarizer", - SummarizerModel.getModel - ); + const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS); info("Run the summarizer"); - const summarizePromise = summarizer.run("This gets cut in half."); + const inferencePromise = summarizer.run({ data: "This gets echoed." }); info("Wait for the pending downloads."); - await remoteClients.wasm.rejectPendingDownloads(1); - await remoteClients.models.resolvePendingDownloads(1); + await remoteClients["ml-onnx-runtime"].rejectPendingDownloads(1); + //await remoteClients.models.resolvePendingDownloads(1); let error; try { - await summarizePromise; + await inferencePromise; } catch (e) { error = e; } @@ -146,21 +110,18 @@ add_task(async function test_ml_engine_model_error() { const mlEngineParent = await EngineProcess.getMLEngineParent(); info("Get summarizer"); - const summarizer = mlEngineParent.getEngine( - "summarizer", - SummarizerModel.getModel - ); + const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS); info("Run the summarizer with a throwing example."); - const summarizePromise = summarizer.run("throw"); + const inferencePromise = summarizer.run("throw"); info("Wait for the pending downloads."); - await remoteClients.wasm.resolvePendingDownloads(1); - await remoteClients.models.resolvePendingDownloads(1); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + //await remoteClients.models.resolvePendingDownloads(1); let error; try { - await summarizePromise; + await inferencePromise; } catch (e) { error = e; } @@ -186,22 +147,18 @@ add_task(async function test_ml_engine_destruction() { const mlEngineParent = await EngineProcess.getMLEngineParent(); info("Get summarizer"); - const summarizer = mlEngineParent.getEngine( - "summarizer", - SummarizerModel.getModel - ); + const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS); info("Run the summarizer"); - const summarizePromise = summarizer.run("This gets cut in half."); + const inferencePromise = summarizer.run({ data: "This gets echoed." }); info("Wait for the pending downloads."); - await remoteClients.models.resolvePendingDownloads(1); - await remoteClients.wasm.resolvePendingDownloads(1); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); - is( - await summarizePromise, - "This gets c", - "The text gets cut in half simulating summarizing" + Assert.equal( + (await inferencePromise).output, + "This gets echoed.", + "The text get echoed exercising the whole flow." ); ok( @@ -217,3 +174,57 @@ add_task(async function test_ml_engine_destruction() { await cleanup(); }); + +/** + * Tests that we display a nice error message when the pref is off + */ +add_task(async function test_pref_is_off() { + await SpecialPowers.pushPrefEnv({ + set: [["browser.ml.enable", false]], + }); + + info("Get the engine process"); + let error; + + try { + await EngineProcess.getMLEngineParent(); + } catch (e) { + error = e; + } + is( + error?.message, + "MLEngine is disabled. Check the browser.ml prefs.", + "The error is correctly surfaced." + ); + + await SpecialPowers.pushPrefEnv({ + set: [["browser.ml.enable", true]], + }); +}); + +/** + * Tests that we verify the task name is valid + */ +add_task(async function test_invalid_task_name() { + const { cleanup, remoteClients } = await setup(); + + const options = new PipelineOptions({ taskName: "inv#alid" }); + const mlEngineParent = await EngineProcess.getMLEngineParent(); + const summarizer = mlEngineParent.getEngine(options); + + let error; + + try { + const res = summarizer.run({ data: "This gets echoed." }); + await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1); + await res; + } catch (e) { + error = e; + } + is( + error?.message, + "Invalid task name. Task name should contain only alphanumeric characters and underscores/dashes.", + "The error is correctly surfaced." + ); + await cleanup(); +}); |