Diffstat (limited to 'toolkit/components/ml/tests')
7 files changed, 386 insertions, 205 deletions
diff --git a/toolkit/components/ml/tests/browser/browser.toml b/toolkit/components/ml/tests/browser/browser.toml
index 57637c8bda..ed3c04f066 100644
--- a/toolkit/components/ml/tests/browser/browser.toml
+++ b/toolkit/components/ml/tests/browser/browser.toml
@@ -1,9 +1,15 @@
 [DEFAULT]
+run-if = ["nightly_build"] # Bug 1890946 - enable the inference engine in release
 support-files = [
   "head.js",
   "data/**/*.*"
 ]
 
 ["browser_ml_cache.js"]
+lineno = "7"
 
 ["browser_ml_engine.js"]
+lineno = "10"
+
+["browser_ml_utils.js"]
+lineno = "13"
diff --git a/toolkit/components/ml/tests/browser/browser_ml_cache.js b/toolkit/components/ml/tests/browser/browser_ml_cache.js
index d8725368bd..8d879fc74d 100644
--- a/toolkit/components/ml/tests/browser/browser_ml_cache.js
+++ b/toolkit/components/ml/tests/browser/browser_ml_cache.js
@@ -11,16 +11,20 @@ const FAKE_HUB =
   "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data";
 
 const FAKE_MODEL_ARGS = {
-  organization: "acme",
-  modelName: "bert",
-  modelVersion: "main",
+  model: "acme/bert",
+  revision: "main",
+  file: "config.json",
+};
+
+const FAKE_RELEASED_MODEL_ARGS = {
+  model: "acme/bert",
+  revision: "v0.1",
   file: "config.json",
 };
 
 const FAKE_ONNX_MODEL_ARGS = {
-  organization: "acme",
-  modelName: "bert",
-  modelVersion: "main",
+  model: "acme/bert",
+  revision: "main",
   file: "onnx/config.json",
 };
 
@@ -35,6 +39,9 @@ const badHubs = [
   "https://model-hub.mozilla.org.hack", // Domain that contains allowed domain
 ];
 
+/**
+ * Make sure we reject bad model hub URLs.
+ */
 add_task(async function test_bad_hubs() {
   for (const badHub of badHubs) {
     Assert.throws(
@@ -60,60 +67,57 @@ add_task(async function test_allowed_hub() {
 const badInputs = [
   [
     {
-      organization: "ac me",
-      modelName: "bert",
-      modelVersion: "main",
+      model: "ac me/bert",
+      revision: "main",
       file: "config.json",
     },
     "Org can only contain letters, numbers, and hyphens",
   ],
   [
     {
-      organization: "1111",
-      modelName: "bert",
-      modelVersion: "main",
+      model: "1111/bert",
+      revision: "main",
       file: "config.json",
     },
     "Org cannot contain only numbers",
   ],
   [
     {
-      organization: "-acme",
-      modelName: "bert",
-      modelVersion: "main",
+      model: "-acme/bert",
+      revision: "main",
       file: "config.json",
     },
     "Org start or end with a hyphen, or use consecutive hyphens",
   ],
   [
     {
-      organization: "a-c-m-e",
-      modelName: "#bert",
-      modelVersion: "main",
+      model: "a-c-m-e/#bert",
+      revision: "main",
       file: "config.json",
     },
     "Models can only contain letters, numbers, and hyphens, underscord, periods",
   ],
   [
     {
-      organization: "a-c-m-e",
-      modelName: "b$ert",
-      modelVersion: "main",
+      model: "a-c-m-e/b$ert",
+      revision: "main",
       file: "config.json",
     },
     "Models cannot contain spaces or control characters",
   ],
   [
     {
-      organization: "a-c-m-e",
-      modelName: "b$ert",
-      modelVersion: "main",
+      model: "a-c-m-e/b$ert",
+      revision: "main",
       file: ".filename",
     },
     "File",
   ],
 ];
 
+/**
+ * Make sure we reject bad inputs.
+ */
 add_task(async function test_bad_inputs() {
   const hub = new ModelHub({ rootUrl: FAKE_HUB });
 
@@ -129,6 +133,9 @@
   }
 });
 
+/**
+ * Test that we can retrieve a file as an ArrayBuffer.
+ */
 add_task(async function test_getting_file() {
   const hub = new ModelHub({ rootUrl: FAKE_HUB });
 
@@ -144,6 +151,35 @@
   Assert.equal(jsonData.hidden_size, 768);
 });
 
+/**
+ * Test that we can retrieve a file from a released model and skip head calls
+ */
+add_task(async function test_getting_released_file() {
+  const hub = new ModelHub({ rootUrl: FAKE_HUB });
+  console.log(hub);
+
+  let spy = sinon.spy(hub, "getETag");
+  let [array, headers] = await hub.getModelFileAsArrayBuffer(
+    FAKE_RELEASED_MODEL_ARGS
+  );
+
+  Assert.equal(headers["Content-Type"], "application/json");
+
+  // check the content of the file.
+  let jsonData = JSON.parse(
+    String.fromCharCode.apply(null, new Uint8Array(array))
+  );
+
+  Assert.equal(jsonData.hidden_size, 768);
+
+  // check that head calls were not made
+  Assert.ok(!spy.called, "getETag should have never been called.");
+  spy.restore();
+});
+
+/**
+ * Make sure files can be located in subdirectories
+ */
 add_task(async function test_getting_file_in_subdir() {
   const hub = new ModelHub({ rootUrl: FAKE_HUB });
 
@@ -161,10 +197,13 @@
   Assert.equal(jsonData.hidden_size, 768);
 });
 
+/**
+ * Test that we can use a custom URL template.
+ */
 add_task(async function test_getting_file_custom_path() {
   const hub = new ModelHub({
     rootUrl: FAKE_HUB,
-    urlTemplate: "${organization}/${modelName}/resolve/${modelVersion}/${file}",
+    urlTemplate: "{model}/resolve/{revision}",
   });
 
   let res = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
@@ -172,9 +211,11 @@
   Assert.equal(res[1]["Content-Type"], "application/json");
 });
 
+/**
+ * Test that we can't use a URL with a query for the template
+ */
 add_task(async function test_getting_file_custom_path_rogue() {
-  const urlTemplate =
-    "${organization}/${modelName}/resolve/${modelVersion}/${file}?some_id=bedqwdw";
+  const urlTemplate = "{model}/resolve/{revision}/?some_id=bedqwdw";
   Assert.throws(
     () => new ModelHub({ rootUrl: FAKE_HUB, urlTemplate }),
     /Invalid URL template/,
@@ -182,6 +223,9 @@
   );
 });
 
+/**
+ * Test that the file can be returned as a response and that its content is correct.
+ */
 add_task(async function test_getting_file_as_response() {
   const hub = new ModelHub({ rootUrl: FAKE_HUB });
 
@@ -192,6 +236,10 @@
   Assert.equal(jsonData.hidden_size, 768);
 });
 
+/**
+ * Test that the cache is used when the data is retrieved from the server
+ * and that the cache is updated with the new data.
+ */
 add_task(async function test_getting_file_from_cache() {
   const hub = new ModelHub({ rootUrl: FAKE_HUB });
   let array = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
@@ -213,6 +261,75 @@
   Assert.deepEqual(array, array2);
 });
 
+/**
+ * Test parsing of a well-formed full URL, including protocol and path.
+ */
+add_task(async function testWellFormedFullUrl() {
+  const hub = new ModelHub({ rootUrl: FAKE_HUB });
+  const url = `${FAKE_HUB}/org1/model1/v1/file/path`;
+  const result = hub.parseUrl(url);
+
+  Assert.equal(
+    result.model,
+    "org1/model1",
+    "Model should be parsed correctly."
+  );
+  Assert.equal(result.revision, "v1", "Revision should be parsed correctly.");
+  Assert.equal(
+    result.file,
+    "file/path",
+    "File path should be parsed correctly."
+  );
+});
+
+/**
+ * Test parsing of a well-formed relative URL, starting with a slash.
+ */
+add_task(async function testWellFormedRelativeUrl() {
+  const hub = new ModelHub({ rootUrl: FAKE_HUB });
+
+  const url = "/org1/model1/v1/file/path";
+  const result = hub.parseUrl(url);
+
+  Assert.equal(
+    result.model,
+    "org1/model1",
+    "Model should be parsed correctly."
+  );
+  Assert.equal(result.revision, "v1", "Revision should be parsed correctly.");
+  Assert.equal(
+    result.file,
+    "file/path",
+    "File path should be parsed correctly."
+  );
+});
+
+/**
+ * Ensures an error is thrown when the URL does not start with the expected root URL or a slash.
+ */
+add_task(async function testInvalidDomain() {
+  const hub = new ModelHub({ rootUrl: FAKE_HUB });
+  const url = "https://example.com/org1/model1/v1/file/path";
+  Assert.throws(
+    () => hub.parseUrl(url),
+    new RegExp(`Error: Invalid domain for model URL: ${url}`),
+    `Should throw with ${url}`
+  );
+});
+
+/**
+ * Tests the method's error handling when the URL format does not include the required segments.
+ */
+add_task(async function testTooFewParts() {
+  const hub = new ModelHub({ rootUrl: FAKE_HUB });
+  const url = "/org1/model1";
+  Assert.throws(
+    () => hub.parseUrl(url),
+    new RegExp(`Error: Invalid model URL: ${url}`),
+    `Should throw with ${url}`
+  );
+});
+
 // IndexedDB tests
 
 /**
@@ -248,18 +365,41 @@
 });
 
 /**
+ * Test checking existence of data in the cache.
+ */
+add_task(async function test_PutAndCheckExists() {
+  const cache = await initializeCache();
+  const testData = new ArrayBuffer(8); // Example data
+  const key = "file.txt";
+  await cache.put("org/model", "v1", "file.txt", testData, {
+    ETag: "ETAG123",
+  });
+
+  // Checking if the file exists
+  let exists = await cache.fileExists("org/model", "v1", key);
+  Assert.ok(exists, "The file should exist in the cache.");
+
+  // Removing all files from the model
+  await cache.deleteModel("org/model", "v1");
+
+  exists = await cache.fileExists("org/model", "v1", key);
+  Assert.ok(!exists, "The file should be gone from the cache.");
+
+  await deleteCache(cache);
+});
+
+/**
  * Test adding data to the cache and retrieving it.
  */
 add_task(async function test_PutAndGet() {
   const cache = await initializeCache();
   const testData = new ArrayBuffer(8); // Example data
-  await cache.put("org", "model", "v1", "file.txt", testData, {
+  await cache.put("org/model", "v1", "file.txt", testData, {
     ETag: "ETAG123",
   });
 
   const [retrievedData, headers] = await cache.getFile(
-    "org",
-    "model",
+    "org/model",
     "v1",
     "file.txt"
   );
@@ -289,14 +429,9 @@ add_task(async function test_GetHeaders() {
     extra: "extra",
   };
 
-  await cache.put("org", "model", "v1", "file.txt", testData, headers);
+  await cache.put("org/model", "v1", "file.txt", testData, headers);
 
-  const storedHeaders = await cache.getHeaders(
-    "org",
-    "model",
-    "v1",
-    "file.txt"
-  );
+  const storedHeaders = await cache.getHeaders("org/model", "v1", "file.txt");
 
   // The `extra` field should be removed from the stored headers because
   // it's not part of the allowed keys.
@@ -318,22 +453,8 @@
  */
 add_task(async function test_ListModels() {
   const cache = await initializeCache();
-  await cache.put(
-    "org1",
-    "modelA",
-    "v1",
-    "file1.txt",
-    new ArrayBuffer(8),
-    null
-  );
-  await cache.put(
-    "org2",
-    "modelB",
-    "v1",
-    "file2.txt",
-    new ArrayBuffer(8),
-    null
-  );
+  await cache.put("org1/modelA", "v1", "file1.txt", new ArrayBuffer(8), null);
+  await cache.put("org2/modelB", "v1", "file2.txt", new ArrayBuffer(8), null);
 
   const models = await cache.listModels();
   Assert.ok(
@@ -348,10 +469,10 @@
  */
 add_task(async function test_DeleteModel() {
   const cache = await initializeCache();
-  await cache.put("org", "model", "v1", "file.txt", new ArrayBuffer(8), null);
-  await cache.deleteModel("org", "model", "v1");
+  await cache.put("org/model", "v1", "file.txt", new ArrayBuffer(8), null);
+  await cache.deleteModel("org/model", "v1");
 
-  const dataAfterDelete = await cache.getFile("org", "model", "v1", "file.txt");
+  const dataAfterDelete = await cache.getFile("org/model", "v1", "file.txt");
   Assert.equal(
     dataAfterDelete,
     null,
diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine.js b/toolkit/components/ml/tests/browser/browser_ml_engine.js
index 6942809d6d..3a93b78182 100644
--- a/toolkit/components/ml/tests/browser/browser_ml_engine.js
+++ b/toolkit/components/ml/tests/browser/browser_ml_engine.js
@@ -5,6 +5,8 @@
 
 /// <reference path="head.js" />
 
+requestLongerTimeout(2);
+
 async function setup({ disabled = false, prefs = [] } = {}) {
   const { removeMocks, remoteClients } = await createAndMockMLRemoteSettings({
     autoDownloadFromRemoteSettings: false,
@@ -15,6 +17,7 @@ async function setup({ disabled = false, prefs = [] } = {}) {
       // Enabled by default.
       ["browser.ml.enable", !disabled],
       ["browser.ml.logLevel", "All"],
+      ["browser.ml.modelCacheTimeout", 1000],
       ...prefs,
     ],
   });
@@ -33,6 +36,8 @@
   };
 }
 
+const PIPELINE_OPTIONS = new PipelineOptions({ taskName: "echo" });
+
 add_task(async function test_ml_engine_basics() {
   const { cleanup, remoteClients } = await setup();
 
@@ -40,22 +45,18 @@
   info("Get the engine process");
   const mlEngineParent = await EngineProcess.getMLEngineParent();
 
   info("Get summarizer");
-  const summarizer = mlEngineParent.getEngine(
-    "summarizer",
-    SummarizerModel.getModel
-  );
+  const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
 
   info("Run the summarizer");
-  const summarizePromise = summarizer.run("This gets cut in half.");
+  const inferencePromise = summarizer.run({ data: "This gets echoed." });
 
   info("Wait for the pending downloads.");
-  await remoteClients.models.resolvePendingDownloads(1);
-  await remoteClients.wasm.resolvePendingDownloads(1);
+  await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
 
-  is(
-    await summarizePromise,
-    "This gets c",
-    "The text gets cut in half simulating summarizing"
+  Assert.equal(
+    (await inferencePromise).output,
+    "This gets echoed.",
+    "The text gets echoed, exercising the whole flow."
   );
 
   ok(
@@ -68,40 +69,6 @@
   await cleanup();
 });
 
-add_task(async function test_ml_engine_model_rejection() {
-  const { cleanup, remoteClients } = await setup();
-
-  info("Get the engine process");
-  const mlEngineParent = await EngineProcess.getMLEngineParent();
-
-  info("Get summarizer");
-  const summarizer = mlEngineParent.getEngine(
-    "summarizer",
-    SummarizerModel.getModel
-  );
-
-  info("Run the summarizer");
-  const summarizePromise = summarizer.run("This gets cut in half.");
-
-  info("Wait for the pending downloads.");
-  await remoteClients.wasm.resolvePendingDownloads(1);
-  await remoteClients.models.rejectPendingDownloads(1);
-
-  let error;
-  try {
-    await summarizePromise;
-  } catch (e) {
-    error = e;
-  }
-  is(
-    error?.message,
-    "Intentionally rejecting downloads.",
-    "The error is correctly surfaced."
-  );
-
-  await cleanup();
-});
-
 add_task(async function test_ml_engine_wasm_rejection() {
   const { cleanup, remoteClients } = await setup();
 
@@ -109,21 +76,18 @@
   info("Get the engine process");
   const mlEngineParent = await EngineProcess.getMLEngineParent();
 
   info("Get summarizer");
-  const summarizer = mlEngineParent.getEngine(
-    "summarizer",
-    SummarizerModel.getModel
-  );
+  const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
 
   info("Run the summarizer");
-  const summarizePromise = summarizer.run("This gets cut in half.");
+  const inferencePromise = summarizer.run({ data: "This gets echoed." });
 
   info("Wait for the pending downloads.");
-  await remoteClients.wasm.rejectPendingDownloads(1);
-  await remoteClients.models.resolvePendingDownloads(1);
+  await remoteClients["ml-onnx-runtime"].rejectPendingDownloads(1);
+  //await remoteClients.models.resolvePendingDownloads(1);
 
   let error;
   try {
-    await summarizePromise;
+    await inferencePromise;
   } catch (e) {
     error = e;
   }
@@ -146,21 +110,18 @@ add_task(async function test_ml_engine_model_error() {
   info("Get the engine process");
   const mlEngineParent = await EngineProcess.getMLEngineParent();
 
   info("Get summarizer");
-  const summarizer = mlEngineParent.getEngine(
-    "summarizer",
-    SummarizerModel.getModel
-  );
+  const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
 
   info("Run the summarizer with a throwing example.");
-  const summarizePromise = summarizer.run("throw");
+  const inferencePromise = summarizer.run("throw");
 
   info("Wait for the pending downloads.");
-  await remoteClients.wasm.resolvePendingDownloads(1);
-  await remoteClients.models.resolvePendingDownloads(1);
+  await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
+  //await remoteClients.models.resolvePendingDownloads(1);
 
   let error;
   try {
-    await summarizePromise;
+    await inferencePromise;
   } catch (e) {
     error = e;
   }
@@ -186,22 +147,18 @@ add_task(async function test_ml_engine_destruction() {
   info("Get the engine process");
   const mlEngineParent = await EngineProcess.getMLEngineParent();
 
   info("Get summarizer");
-  const summarizer = mlEngineParent.getEngine(
-    "summarizer",
-    SummarizerModel.getModel
-  );
+  const summarizer = mlEngineParent.getEngine(PIPELINE_OPTIONS);
 
   info("Run the summarizer");
-  const summarizePromise = summarizer.run("This gets cut in half.");
+  const inferencePromise = summarizer.run({ data: "This gets echoed." });
 
   info("Wait for the pending downloads.");
-  await remoteClients.models.resolvePendingDownloads(1);
-  await remoteClients.wasm.resolvePendingDownloads(1);
+  await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
 
-  is(
-    await summarizePromise,
-    "This gets c",
-    "The text gets cut in half simulating summarizing"
+  Assert.equal(
+    (await inferencePromise).output,
+    "This gets echoed.",
+    "The text gets echoed, exercising the whole flow."
   );
 
   ok(
@@ -217,3 +174,57 @@
 
   await cleanup();
 });
+
+/**
+ * Tests that we display a nice error message when the pref is off
+ */
+add_task(async function test_pref_is_off() {
+  await SpecialPowers.pushPrefEnv({
+    set: [["browser.ml.enable", false]],
+  });
+
+  info("Get the engine process");
+  let error;
+
+  try {
+    await EngineProcess.getMLEngineParent();
+  } catch (e) {
+    error = e;
+  }
+  is(
+    error?.message,
+    "MLEngine is disabled. Check the browser.ml prefs.",
+    "The error is correctly surfaced."
+  );
+
+  await SpecialPowers.pushPrefEnv({
+    set: [["browser.ml.enable", true]],
+  });
+});
+
+/**
+ * Tests that we verify the task name is valid
+ */
+add_task(async function test_invalid_task_name() {
+  const { cleanup, remoteClients } = await setup();
+
+  const options = new PipelineOptions({ taskName: "inv#alid" });
+  const mlEngineParent = await EngineProcess.getMLEngineParent();
+  const summarizer = mlEngineParent.getEngine(options);
+
+  let error;
+
+  try {
+    const res = summarizer.run({ data: "This gets echoed." });
+    await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
+    await res;
+  } catch (e) {
+    error = e;
+  }
+  is(
+    error?.message,
+    "Invalid task name. Task name should contain only alphanumeric characters and underscores/dashes.",
+    "The error is correctly surfaced."
+  );
+  await cleanup();
+});
diff --git a/toolkit/components/ml/tests/browser/browser_ml_utils.js b/toolkit/components/ml/tests/browser/browser_ml_utils.js
new file mode 100644
index 0000000000..c215349af4
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/browser_ml_utils.js
@@ -0,0 +1,26 @@
+/* Any copyright is dedicated to the Public Domain.
+http://creativecommons.org/publicdomain/zero/1.0/ */
+"use strict";
+
+const { arrayBufferToBlobURL } = ChromeUtils.importESModule(
+  "chrome://global/content/ml/Utils.sys.mjs"
+);
+
+/**
+ * Test arrayBufferToBlobURL function.
+ */
+add_task(async function test_ml_utils_array_buffer_to_blob_url() {
+  const buffer = new ArrayBuffer(8);
+  const blobURL = arrayBufferToBlobURL(buffer);
+
+  Assert.equal(
+    typeof blobURL,
+    "string",
+    "arrayBufferToBlobURL should return a string"
+  );
+  Assert.equal(
+    blobURL.startsWith("blob:"),
+    true,
+    "The returned string should be a Blob URL"
+  );
+});
diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/config.json
new file mode 100644
index 0000000000..50dbb760bb
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/config.json
@@ -0,0 +1,21 @@
+{
+  "architectures": ["BertForMaskedLM"],
+  "attention_probs_dropout_prob": 0.1,
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "transformers_version": "4.6.0.dev0",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 30522
+}
diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/onnx/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/onnx/config.json
new file mode 100644
index 0000000000..50dbb760bb
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/v0.1/onnx/config.json
@@ -0,0 +1,21 @@
+{
+  "architectures": ["BertForMaskedLM"],
+  "attention_probs_dropout_prob": 0.1,
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 768,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "transformers_version": "4.6.0.dev0",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 30522
+}
diff --git a/toolkit/components/ml/tests/browser/head.js b/toolkit/components/ml/tests/browser/head.js
index 9fc20c0e84..25e611a753 100644
--- a/toolkit/components/ml/tests/browser/head.js
+++ b/toolkit/components/ml/tests/browser/head.js
@@ -6,13 +6,6 @@
 "use strict";
 
 /**
- * @type {import("../../content/SummarizerModel.sys.mjs")}
- */
-const { SummarizerModel } = ChromeUtils.importESModule(
-  "chrome://global/content/ml/SummarizerModel.sys.mjs"
-);
-
-/**
  * @type {import("../../actors/MLEngineParent.sys.mjs")}
  */
 const { MLEngineParent } = ChromeUtils.importESModule(
@@ -23,6 +16,14 @@ const { ModelHub, IndexedDBCache } = ChromeUtils.importESModule(
   "chrome://global/content/ml/ModelHub.sys.mjs"
 );
 
+const { getRuntimeWasmFilename } = ChromeUtils.importESModule(
+  "chrome://global/content/ml/Utils.sys.mjs"
+);
+
+const { PipelineOptions } = ChromeUtils.importESModule(
+  "chrome://global/content/ml/EngineProcess.sys.mjs"
+);
+
 // This test suite shares some utility functions with translations as they work in a very
 // similar fashion. Eventually, the plan is to unify these two components.
 Services.scriptloader.loadSubScript(
@@ -30,94 +31,39 @@
   this
 );
 
-function getDefaultModelRecords() {
-  return [
-    {
-      name: "summarizer-model",
-      version: SummarizerModel.MODEL_MAJOR_VERSION + ".0",
-    },
-  ];
-}
-
 function getDefaultWasmRecords() {
   return [
     {
-      name: "inference-engine",
+      name: getRuntimeWasmFilename(),
       version: MLEngineParent.WASM_MAJOR_VERSION + ".0",
     },
   ];
 }
 
-/**
- * Creates a local RemoteSettingsClient for use within tests.
- *
- * @param {boolean} autoDownloadFromRemoteSettings
- * @param {object[]} models
- * @returns {AttachmentMock}
- */
-async function createMLModelsRemoteClient(
-  autoDownloadFromRemoteSettings,
-  models = getDefaultModelRecords()
-) {
-  const { RemoteSettings } = ChromeUtils.importESModule(
-    "resource://services-settings/remote-settings.sys.mjs"
-  );
-  const mockedCollectionName = "test-ml-models";
-  const client = RemoteSettings(
-    `${mockedCollectionName}-${_remoteSettingsMockId++}`
-  );
-  const metadata = {};
-  await client.db.clear();
-  await client.db.importChanges(
-    metadata,
-    Date.now(),
-    models.map(({ name, version }) => ({
-      id: crypto.randomUUID(),
-      name,
-      version,
-      last_modified: Date.now(),
-      schema: Date.now(),
-      attachment: {
-        hash: `${crypto.randomUUID()}`,
-        size: `123`,
-        filename: name,
-        location: `main-workspace/ml-models/${crypto.randomUUID()}.bin`,
-        mimetype: "application/octet-stream",
-      },
-    }))
-  );
-
-  return createAttachmentMock(
-    client,
-    mockedCollectionName,
-    autoDownloadFromRemoteSettings
-  );
-}
-
 async function createAndMockMLRemoteSettings({
-  models = getDefaultModelRecords(),
   autoDownloadFromRemoteSettings = false,
 } = {}) {
+  const runtime = await createMLWasmRemoteClient(
+    autoDownloadFromRemoteSettings
+  );
+  const options = await createOptionsRemoteClient();
+
   const remoteClients = {
-    models: await createMLModelsRemoteClient(
-      autoDownloadFromRemoteSettings,
-      models
-    ),
-    wasm: await createMLWasmRemoteClient(autoDownloadFromRemoteSettings),
+    "ml-onnx-runtime": runtime,
+    "ml-inference-options": options,
   };
 
-  MLEngineParent.mockRemoteSettings(remoteClients.wasm.client);
-  SummarizerModel.mockRemoteSettings(remoteClients.models.client);
+  MLEngineParent.mockRemoteSettings({
+    "ml-onnx-runtime": runtime.client,
+    "ml-inference-options": options,
+  });
 
   return {
     async removeMocks() {
-      await remoteClients.models.client.attachments.deleteAll();
-      await remoteClients.models.client.db.clear();
-      await remoteClients.wasm.client.attachments.deleteAll();
-      await remoteClients.wasm.client.db.clear();
-
+      await runtime.client.attachments.deleteAll();
+      await runtime.client.db.clear();
+      await options.db.clear();
       MLEngineParent.removeMocks();
-      SummarizerModel.removeMocks();
     },
     remoteClients,
   };
@@ -157,3 +103,32 @@
     autoDownloadFromRemoteSettings
   );
 }
+
+/**
+ * Creates a local RemoteSettingsClient for use within tests.
+ *
+ * @returns {RemoteSettings}
+ */
+async function createOptionsRemoteClient() {
+  const { RemoteSettings } = ChromeUtils.importESModule(
+    "resource://services-settings/remote-settings.sys.mjs"
+  );
+  const mockedCollectionName = "test-ml-inference-options";
+  const client = RemoteSettings(
+    `${mockedCollectionName}-${_remoteSettingsMockId++}`
+  );
+
+  const record = {
+    taskName: "echo",
+    modelId: "mozilla/distilvit",
+    processorId: "mozilla/distilvit",
+    tokenizerId: "mozilla/distilvit",
+    modelRevision: "main",
+    processorRevision: "main",
+    tokenizerRevision: "main",
+    id: "74a71cfd-1734-44e6-85c0-69cf3e874138",
+  };
+  await client.db.clear();
+  await client.db.importChanges({}, Date.now(), [record]);
+  return client;
+}