/* Any copyright is dedicated to the Public Domain.
   http://creativecommons.org/publicdomain/zero/1.0/ */

"use strict";

const { sinon } = ChromeUtils.importESModule(
  "resource://testing-common/Sinon.sys.mjs"
);

const { ProgressStatusText, ProgressType } = ChromeUtils.importESModule(
  "chrome://global/content/ml/Utils.sys.mjs"
);

const { OPFS } = ChromeUtils.importESModule(
  "chrome://global/content/ml/OPFS.sys.mjs"
);

const { URLChecker } = ChromeUtils.importESModule(
  "chrome://global/content/ml/Utils.sys.mjs"
);

// Root URL of the fake hub, see the `data` dir in the tests.
const FAKE_HUB =
  "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data";

const FAKE_URL_TEMPLATE = "{model}/resolve/{revision}";

const FAKE_MODEL_ARGS = {
  model: "acme/bert",
  revision: "main",
  file: "config.json",
  taskName: "task_model",
};

const FAKE_RELEASED_MODEL_ARGS = {
  model: "acme/bert",
  revision: "v0.1",
  file: "config.json",
  taskName: "task_released",
};

const FAKE_ONNX_MODEL_ARGS = {
  model: "acme/bert",
  revision: "main",
  file: "onnx/config.json",
  taskName: "task_onnx",
};

function createRandomBlob(blockSize = 8, count = 1) {
  const blocks = Array.from({ length: count }, () =>
    Uint32Array.from(
      { length: blockSize / 4 },
      () => Math.random() * 4294967296
    )
  );
  return new Blob(blocks, { type: "application/octet-stream" });
}

function createBlob(size = 8) {
  return createRandomBlob(size);
}

function stripLastUsed(data) {
  return data.map(({ lastUsed: _unusedLastUsed, ...rest }) => {
    return rest;
  });
}

/**
 * Test the MOZ_ALLOW_EXTERNAL_ML_HUB environment variable.
 */
add_task(async function test_allow_external_ml_hub() {
  Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "1");
  new ModelHub({ rootUrl: "https://huggingface.co" });
  Services.env.set("MOZ_ALLOW_EXTERNAL_ML_HUB", "");
});

const badInputs = [
  [
    {
      model: "ac me/bert",
      revision: "main",
      file: "config.json",
    },
    "Org can only contain letters, numbers, and hyphens",
  ],
  [
    {
      model: "1111/bert",
      revision: "main",
      file: "config.json",
    },
    "Org cannot contain only numbers",
  ],
  [
    {
      model: "-acme/bert",
      revision: "main",
      file: "config.json",
    },
    "Org cannot start or end with a hyphen, or use consecutive hyphens",
  ],
  [
    {
      model: "a-c-m-e/#bert",
      revision: "main",
      file: "config.json",
    },
    "Models can only contain letters, numbers, hyphens, underscores, and periods",
  ],
  [
    {
      model: "a-c-m-e/b$ert",
      revision: "main",
      file: "config.json",
    },
    "Models cannot contain spaces or control characters",
  ],
  [
    {
      model: "a-c-m-e/b$ert",
      revision: "main",
      file: ".filename",
    },
    "File",
  ],
];

/**
 * Make sure we reject bad inputs.
 */
add_task(async function test_bad_inputs() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });

  for (const badInput of badInputs) {
    const params = badInput[0];
    const errorMsg = badInput[1];
    try {
      await hub.getModelFileAsArrayBuffer(params);
    } catch (error) {
      continue;
    }
    throw new Error(errorMsg);
  }
});

/**
 * Test that we can retrieve a file as an ArrayBuffer.
 */
add_task(async function test_getting_file() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });

  let [array, headers] = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);

  Assert.equal(headers["Content-Type"], "application/json");

  // check the content of the file.
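  // The fake hub serves a small BERT-style config.json, so the parsed JSON
  // is expected to expose hidden_size = 768.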
  let jsonData = JSON.parse(
    String.fromCharCode.apply(null, new Uint8Array(array))
  );

  Assert.equal(jsonData.hidden_size, 768);
});

/**
 * Test that we can retrieve a file as an ArrayBuffer even if we don't have headers.
 */
add_task(async function test_getting_file_no_headers() {
  await SpecialPowers.pushPrefEnv({
    set: [
      // Enabled by default.
      ["browser.ml.logLevel", "All"],
    ],
  });

  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
    reset: true,
  });

  // Return empty headers
  sinon.stub(hub, "extractHeaders").callsFake(function () {
    return {};
  });

  let [array, headers] = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);

  Assert.equal(headers["Content-Type"], "application/octet-stream"); // default content type

  // check the content of the file.
  let jsonData = JSON.parse(
    String.fromCharCode.apply(null, new Uint8Array(array))
  );

  Assert.equal(jsonData.hidden_size, 768);

  hub.extractHeaders.restore();
  await deleteCache(hub.cache);
});

/**
 * Test that we can retrieve a file from a released model and skip HEAD calls.
 */
add_task(async function test_getting_released_file() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });

  let spy = sinon.spy(hub, "getETag");

  let [array, headers] = await hub.getModelFileAsArrayBuffer(
    FAKE_RELEASED_MODEL_ARGS
  );

  Assert.equal(headers["Content-Type"], "application/json");

  // check the content of the file.
  let jsonData = JSON.parse(
    String.fromCharCode.apply(null, new Uint8Array(array))
  );

  Assert.equal(jsonData.hidden_size, 768);

  // check that HEAD calls were not made
  Assert.ok(!spy.called, "getETag should have never been called.");

  spy.restore();
});

/**
 * Make sure files can be located in subdirectories.
 */
add_task(async function test_getting_file_in_subdir() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });

  let [array, metadata] =
    await hub.getModelFileAsArrayBuffer(FAKE_ONNX_MODEL_ARGS);

  Assert.equal(metadata["Content-Type"], "application/json");

  // check the content of the file.
  let jsonData = JSON.parse(
    String.fromCharCode.apply(null, new Uint8Array(array))
  );

  Assert.equal(jsonData.hidden_size, 768);
});

/**
 * Test that we can use a custom URL template.
 */
add_task(async function test_getting_file_custom_path() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: "{model}/resolve/{revision}",
  });

  let res = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);

  Assert.equal(res[1]["Content-Type"], "application/json");
});

/**
 * Test that we can't use a URL with a query for the template.
 */
add_task(async function test_getting_file_custom_path_rogue() {
  const urlTemplate = "{model}/resolve/{revision}/?some_id=bedqwdw";
  Assert.throws(
    () => new ModelHub({ rootUrl: FAKE_HUB, urlTemplate }),
    /Invalid URL template/,
    `Should throw with ${urlTemplate}`
  );
});

/**
 * Test that the file can be returned as a response and its content correct.
 */
add_task(async function test_getting_file_as_response() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });

  let response = await hub.getModelFileAsResponse(FAKE_MODEL_ARGS);

  // check the content of the file.
  let jsonData = await response.json();
  Assert.equal(jsonData.hidden_size, 768);
});

/**
 * Test that the cache is used when the data is retrieved from the server
 * and that the cache is updated with the new data.
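 * The second read should come straight from IndexedDB and refresh the
 * lastUsed timestamp on the cached entry.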
 */
add_task(async function test_getting_file_from_cache() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });
  let array = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
  var lastUsed = array[1].lastUsed;

  // stub to verify that the data was retrieved from IndexedDB
  let matchMethod = hub.cache._testGetData;

  sinon.stub(hub.cache, "_testGetData").callsFake(function () {
    return matchMethod.apply(this, arguments).then(result => {
      Assert.notEqual(result, null);
      return result;
    });
  });

  // exercises the cache
  let array2 = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
  hub.cache._testGetData.restore();

  let newLastUsed = array2[1].lastUsed;

  // make sure the last used field was updated
  Assert.greater(newLastUsed, lastUsed);

  // we don't compare the lastUsed field because it changes for each read
  Assert.deepEqual(stripLastUsed(array), stripLastUsed(array2));
});

/**
 * Test that the callback is appropriately called when the data is retrieved from the server
 * or from the cache.
 */
add_task(async function test_getting_file_from_url_cache_with_callback() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });
  hub.cache = await initializeCache();

  let numCalls = 0;
  let currentData = null;

  let array = await hub.getModelFileAsArrayBuffer({
    ...FAKE_MODEL_ARGS,
    progressCallback: data => {
      // expecting initiate status and download
      currentData = data;
      if (numCalls == 0) {
        Assert.deepEqual(
          {
            type: data.type,
            statusText: data.statusText,
            ok: data.ok,
            model: currentData?.metadata?.model,
            file: currentData?.metadata?.file,
            revision: currentData?.metadata?.revision,
            taskName: currentData?.metadata?.taskName,
          },
          {
            type: ProgressType.DOWNLOAD,
            statusText: ProgressStatusText.INITIATE,
            ok: true,
            ...FAKE_MODEL_ARGS,
          },
          "Initiate Data from server should be correct"
        );
      }

      if (numCalls == 1) {
        Assert.deepEqual(
          {
            type: data.type,
            statusText: data.statusText,
            ok: data.ok,
            model: currentData?.metadata?.model,
            file: currentData?.metadata?.file,
            revision: currentData?.metadata?.revision,
            taskName: currentData?.metadata?.taskName,
          },
          {
            type: ProgressType.DOWNLOAD,
            statusText: ProgressStatusText.SIZE_ESTIMATE,
            ok: true,
            ...FAKE_MODEL_ARGS,
          },
          "size estimate Data from server should be correct"
        );
      }
      numCalls += 1;
    },
  });

  var lastUsed = array[1].lastUsed;

  Assert.greaterOrEqual(numCalls, 3);

  // last received message is DONE
  Assert.deepEqual(
    {
      type: currentData?.type,
      statusText: currentData?.statusText,
      ok: currentData?.ok,
      model: currentData?.metadata?.model,
      file: currentData?.metadata?.file,
      revision: currentData?.metadata?.revision,
      taskName: currentData?.metadata?.taskName,
    },
    {
      type: ProgressType.DOWNLOAD,
      statusText: ProgressStatusText.DONE,
      ok: true,
      ...FAKE_MODEL_ARGS,
    },
    "Done Data from server should be correct"
  );

  // stub to verify that the data was retrieved from IndexedDB
  let matchMethod = hub.cache._testGetData;

  sinon.stub(hub.cache, "_testGetData").callsFake(function () {
    return matchMethod.apply(this, arguments).then(result => {
      Assert.notEqual(result, null);
      return result;
    });
  });

  numCalls = 0;
  currentData = null;

  // Now we expect the callback to indicate cache usage.
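  // The progress callback should now report ProgressType.LOAD_FROM_CACHE
  // instead of ProgressType.DOWNLOAD, since the file is served from the cache.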
  let array2 = await hub.getModelFileAsArrayBuffer({
    ...FAKE_MODEL_ARGS,
    progressCallback: data => {
      // expecting initiate status and load from cache
      currentData = data;
      if (numCalls == 0) {
        Assert.deepEqual(
          {
            type: data.type,
            statusText: data.statusText,
            ok: data.ok,
            model: currentData?.metadata?.model,
            file: currentData?.metadata?.file,
            revision: currentData?.metadata?.revision,
            taskName: currentData?.metadata?.taskName,
          },
          {
            type: ProgressType.LOAD_FROM_CACHE,
            statusText: ProgressStatusText.INITIATE,
            ok: true,
            ...FAKE_MODEL_ARGS,
          },
          "Initiate Data from cache should be correct"
        );
      }
      numCalls += 1;
    },
  });

  hub.cache._testGetData.restore();

  let newLastUsed = array2[1].lastUsed;

  // make sure the last used field was updated
  Assert.greater(newLastUsed, lastUsed);

  Assert.deepEqual(stripLastUsed(array), stripLastUsed(array2));

  // last received message is DONE
  Assert.deepEqual(
    {
      type: currentData?.type,
      statusText: currentData?.statusText,
      ok: currentData?.ok,
      model: currentData?.metadata?.model,
      file: currentData?.metadata?.file,
      revision: currentData?.metadata?.revision,
      taskName: currentData?.metadata?.taskName,
    },
    {
      type: ProgressType.LOAD_FROM_CACHE,
      statusText: ProgressStatusText.DONE,
      ok: true,
      ...FAKE_MODEL_ARGS,
    },
    "Done Data from cache should be correct"
  );

  await deleteCache(hub.cache);
});

/**
 * Test parsing of a well-formed full URL, including protocol and path.
 */
add_task(async function testWellFormedFullUrl() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: "{model}/{revision}",
  });
  const url = `${FAKE_HUB}/org1/model1/v1/file/path`;
  const result = hub.parseUrl(url);

  Assert.equal(
    result.model,
    "org1/model1",
    "Model should be parsed correctly."
  );
  Assert.equal(result.revision, "v1", "Revision should be parsed correctly.");
  Assert.equal(
    result.file,
    "file/path",
    "File path should be parsed correctly."
  );
});

/**
 * Test parsing of well-formed URLs starting with a slash.
 */
const URLS_AND_RESULT = [
  {
    url: "/Xenova/bert-base-NER/resolve/main/onnx/model.onnx",
    model: "Xenova/bert-base-NER",
    revision: "main",
    file: "onnx/model.onnx",
    urlTemplate: "{model}/resolve/{revision}",
  },
  {
    url: "/org1/model1/v1/file/path",
    model: "org1/model1",
    revision: "v1",
    file: "file/path",
    urlTemplate: "{model}/{revision}",
  },
];

add_task(async function testWellFormedRelativeUrl() {
  for (const example of URLS_AND_RESULT) {
    const hub = new ModelHub({
      rootUrl: FAKE_HUB,
      urlTemplate: example.urlTemplate,
    });

    const result = hub.parseUrl(example.url);

    Assert.equal(
      result.model,
      example.model,
      "Model should be parsed correctly."
    );
    Assert.equal(
      result.revision,
      example.revision,
      "Revision should be parsed correctly."
    );
    Assert.equal(
      result.file,
      example.file,
      "File path should be parsed correctly."
    );
  }
});

/**
 * Ensures an error is thrown when the URL does not start with the expected root URL or a slash.
 */
add_task(async function testInvalidDomain() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });
  const url = "https://example.com/org1/model1/resolve/v1/file/path";
  Assert.throws(
    () => hub.parseUrl(url),
    new RegExp(`Error: Invalid domain for model URL: ${url}`),
    `Should throw with ${url}`
  );
});

/**
 * Tests the method's error handling when the URL format does not include the
 * required segments.
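 * For example, `/org1/model1/resolve` is missing the revision and file parts.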
 */
add_task(async function testTooFewParts() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });
  const url = "/org1/model1/resolve";
  Assert.throws(
    () => hub.parseUrl(url),
    new RegExp(`Error: Invalid model URL format: ${url}`),
    `Should throw with ${url}`
  );
});

// IndexedDB tests

/**
 * Helper function to initialize the cache.
 */
async function initializeCache() {
  const randomSuffix = Math.floor(Math.random() * 10000);
  const dbName = `modelFiles-${randomSuffix}`;
  await OPFS.getDirectoryHandle(dbName, { create: true });
  return await IndexedDBCache.init({ dbName });
}

/**
 * Helper function to delete the cache database.
 */
async function deleteCache(cache) {
  await cache.dispose();
  indexedDB.deleteDatabase(cache.dbName);
  try {
    await OPFS.remove(cache.dbName, { recursive: true });
  } catch (e) {
    // the directory may not exist; errors are intentionally ignored
  }
}

/**
 * Test the initialization and creation of the IndexedDBCache instance.
 */
add_task(async function test_Init() {
  const cache = await initializeCache();
  Assert.ok(
    cache instanceof IndexedDBCache,
    "The cache instance should be created successfully."
  );
  Assert.ok(
    IDBDatabase.isInstance(cache.db),
    `The cache should have an IDBDatabase instance. Found ${cache.db}`
  );
  await deleteCache(cache);
});

/**
 * Test checking existence of data in the cache.
 */
add_task(async function test_PutAndCheckExists() {
  const cache = await initializeCache();
  const testData = createBlob();
  const key = "file.txt";
  await cache.put({
    taskName: "task",
    model: "org/model",
    revision: "v1",
    file: "file.txt",
    data: testData,
    headers: {
      ETag: "ETAG123",
    },
  });

  // Checking if the file exists
  let exists = await cache.fileExists({
    model: "org/model",
    revision: "v1",
    file: key,
  });
  Assert.ok(exists, "The file should exist in the cache.");

  // Removing all files from the model
  await cache.deleteModels({ model: "org/model", revision: "v1" });

  exists = await cache.fileExists({
    taskName: "task",
    model: "org/model",
    revision: "v1",
    file: key,
  });
  Assert.ok(!exists, "The file should be gone from the cache.");

  await deleteCache(cache);
});

/**
 * Test adding data to the cache and retrieving it.
 */
add_task(async function test_PutAndGet() {
  const cache = await initializeCache();
  const testData = createBlob();
  await cache.put({
    taskName: "task",
    model: "org/model",
    revision: "v1",
    file: "file.txt",
    data: testData,
    headers: {
      ETag: "ETAG123",
    },
  });

  const [retrievedData, headers] = await cache.getFile({
    model: "org/model",
    revision: "v1",
    file: "file.txt",
  });

  Assert.deepEqual(
    retrievedData,
    testData,
    "The retrieved data should match the stored data."
  );
  Assert.equal(
    headers.ETag,
    "ETAG123",
    "The retrieved ETag should match the stored ETag."
  );

  await deleteCache(cache);
});

/**
 * Test retrieving the headers for a cache entry.
 */
add_task(async function test_GetHeaders() {
  const cache = await initializeCache();
  const testData = createBlob();
  const headers = {
    ETag: "ETAG123",
    status: 200,
    extra: "extra",
  };
  const when = await cache.put({
    taskName: "task",
    model: "org/model",
    revision: "v1",
    file: "file.txt",
    data: testData,
    headers,
  });

  const storedHeaders = await cache.getHeaders({
    model: "org/model",
    revision: "v1",
    file: "file.txt",
  });

  // The `extra` field should be removed from the stored headers because
  // it's not part of the allowed keys.
  // The Content-Type one is added when not present.
  Assert.deepEqual(
    {
      ETag: "ETAG123",
      status: 200,
      "Content-Type": "application/octet-stream",
      fileSize: 8,
      lastUsed: when,
      lastUpdated: when,
    },
    storedHeaders,
    "The retrieved headers should match the stored headers."
  );

  await deleteCache(cache);
});

/**
 * Test listing all models stored in the cache.
 */
add_task(async function test_ListModels() {
  const cache = await initializeCache();
  await Promise.all([
    cache.put({
      taskName: "task1",
      model: "org1/modelA",
      revision: "v1",
      file: "file1.txt",
      data: createBlob(),
      headers: null,
    }),
    cache.put({
      taskName: "task2",
      model: "org2/modelB",
      revision: "v2",
      file: "file2.txt",
      data: createBlob(),
      headers: null,
    }),
  ]);

  const models = await cache.listModels();
  const expected = [
    { name: "org1/modelA", revision: "v1", taskName: "task1" },
    { name: "org2/modelB", revision: "v2", taskName: "task2" },
  ];
  Assert.deepEqual(models, expected, "All models should be listed");
  await deleteCache(cache);
});

/**
 * Test deleting a model and its data from the cache.
 */
add_task(async function test_DeleteModels() {
  const cache = await initializeCache();
  await cache.put({
    taskName: "task",
    model: "org/model",
    revision: "v1",
    file: "file.txt",
    data: createBlob(),
    headers: null,
  });
  await cache.deleteModels({ model: "org/model", revision: "v1" });

  const dataAfterDelete = await cache.getFile({
    model: "org/model",
    revision: "v1",
    file: "file.txt",
  });
  Assert.equal(
    dataAfterDelete,
    null,
    "The data for the deleted model should not exist."
  );
  await deleteCache(cache);
});

/**
 * Test that after deleting a model from the cache, the remaining models are still there.
 */
add_task(async function test_nonDeletedModels() {
  const cache = await initializeCache();
  const testData = createRandomBlob();
  await Promise.all([
    cache.put({
      taskName: "task1",
      model: "org/model",
      revision: "v1",
      file: "file.txt",
      data: testData,
      headers: {
        ETag: "ETAG123",
      },
    }),
    cache.put({
      taskName: "task2",
      model: "org/model2",
      revision: "v1",
      file: "file.txt",
      data: createRandomBlob(),
      headers: {
        ETag: "ETAG1234",
      },
    }),
    cache.put({
      taskName: "task3",
      model: "org/model2",
      revision: "v1",
      file: "file2.txt",
      data: createRandomBlob(),
      headers: {
        ETag: "ETAG1234",
      },
    }),
  ]);

  await cache.deleteModels({ model: "org/model2", revision: "v1" });

  const [retrievedData, headers] = await cache.getFile({
    model: "org/model",
    revision: "v1",
    file: "file.txt",
  });

  Assert.deepEqual(
    retrievedData,
    testData,
    "The retrieved data should match the stored data."
  );
  Assert.equal(
    headers.ETag,
    "ETAG123",
    "The retrieved ETag should match the stored ETag."
  );

  const dataAfterDelete = await cache.getFile({
    model: "org/model2",
    revision: "v1",
    file: "file.txt",
  });
  Assert.equal(
    dataAfterDelete,
    null,
    "The data for the deleted model should not exist."
  );

  const dataAfterDelete2 = await cache.getFile({
    model: "org/model2",
    revision: "v1",
    file: "file2.txt",
  });
  Assert.equal(
    dataAfterDelete2,
    null,
    "The data for the deleted model should not exist."
  );

  await deleteCache(cache);
});

/**
 * Test deleting a model and its data from the cache using a task name.
 */
add_task(async function test_DeleteModelsUsingTaskName() {
  const cache = await initializeCache();
  const model = "mozilla/distilvit";
  const revision = "main";
  const taskName = "echo";

  await cache.put({
    taskName,
    model,
    revision,
    file: "file.txt",
    data: createBlob(),
    headers: null,
  });

  await cache.deleteModels({ taskName });

  // Model should be gone.
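  // deleteModels({ taskName }) is expected to remove every file stored under
  // that task, so both the model listing and the file lookup should be empty.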
  const models = await cache.listModels();
  const expected = [];
  Assert.deepEqual(models, expected, "All models should be deleted.");

  const dataAfterDelete = await cache.getFile({
    model,
    revision,
    file: "file.txt",
  });
  Assert.equal(
    dataAfterDelete,
    null,
    "The data for the deleted model should not exist."
  );
  await deleteCache(cache);
});

/**
 * Test deleting a model and its data from the cache using a non-existing task name.
 */
add_task(async function test_DeleteModelsUsingNonExistingTaskName() {
  const cache = await initializeCache();
  const model = "mozilla/distilvit";
  const revision = "main";
  const taskName = "echo";

  await cache.put({
    taskName,
    model,
    revision,
    file: "file.txt",
    data: createBlob(),
    headers: null,
  });

  await cache.deleteModels({ taskName: "non-existing-task" });

  // Model should still be there.
  const models = await cache.listModels();
  const expected = [{ name: model, revision, taskName }];
  Assert.deepEqual(models, expected, "All models should be listed");
  await deleteCache(cache);
});

/**
 * Test that deleting non-matching revisions of a model keeps the target
 * revision and the other models intact.
 */
add_task(async function test_deleteNonMatchingModelRevisions() {
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });

  const cache = await initializeCache();
  hub.cache = cache;

  const testData = createRandomBlob();
  const testData2 = createRandomBlob();
  const taskName = "task";
  const file = "file.txt";
  const hostname = new URL(FAKE_HUB).hostname;

  await Promise.all([
    cache.put({
      taskName,
      model: `${hostname}/org/model`,
      revision: "v1",
      file,
      data: testData,
      headers: {
        ETag: "ETAG123",
      },
    }),
    cache.put({
      taskName,
      model: `${hostname}/org/model2`,
      revision: "v1",
      file,
      data: createRandomBlob(),
      headers: {
        ETag: "ETAG1234",
      },
    }),
    cache.put({
      taskName,
      model: `${hostname}/org/model2`,
      revision: "v2",
      file,
      data: createRandomBlob(),
      headers: {
        ETag: "ETAG1234",
      },
    }),
    cache.put({
      taskName,
      model: `${hostname}/org/model2`,
      revision: "v3",
      file,
      data: testData2,
      headers: {
        ETag: "ETAG1234",
      },
    }),
  ]);

  await hub.deleteNonMatchingModelRevisions({
    taskName,
    modelWithHostname: `${hostname}/org/model2`,
    targetRevision: "v3",
  });

  const [retrievedData, headers] = await cache.getFile({
    model: `${hostname}/org/model`,
    revision: "v1",
    file,
  });

  Assert.deepEqual(
    retrievedData,
    testData,
    "The retrieved data should match the stored data."
  );
  Assert.equal(
    headers.ETag,
    "ETAG123",
    "The retrieved ETag should match the stored ETag."
  );

  const dataAfterDelete = await cache.getFile({
    model: `${hostname}/org/model2`,
    revision: "v1",
    file,
  });
  Assert.equal(dataAfterDelete, null, "The data for v1 should not exist.");

  const dataAfterDelete2 = await cache.getFile({
    model: `${hostname}/org/model2`,
    revision: "v2",
    file,
  });
  Assert.equal(dataAfterDelete2, null, "The data for v2 should not exist.");

  const [retrievedData2, headers2] = await cache.getFile({
    model: `${hostname}/org/model2`,
    revision: "v3",
    file,
  });
  Assert.deepEqual(
    retrievedData2,
    testData2,
    "The retrieved data for v3 should match the stored data."
  );
  Assert.equal(
    headers2.ETag,
    "ETAG1234",
    "The retrieved ETag for v3 should match the stored ETag."
  );

  await deleteCache(cache);
});

/**
 * Test listing files.
 */
add_task(async function test_listFiles() {
  const cache = await initializeCache();
  const headers = { "Content-Length": "12345", ETag: "XYZ" };
  const blob = createBlob();

  const when1 = await cache.put({
    taskName: "task1",
    model: "org/model",
    revision: "v1",
    file: "file.txt",
    data: blob,
    headers: null,
  });
  const when2 = await cache.put({
    taskName: "task1",
    model: "org/model",
    revision: "v1",
    file: "file2.txt",
    data: blob,
    headers: null,
  });
  const when3 = await cache.put({
    taskName: "task2",
    model: "org/model",
    revision: "v1",
    file: "sub/file3.txt",
    data: createBlob(32),
    headers,
  });

  const { files } = await cache.listFiles({
    model: "org/model",
    revision: "v1",
  });

  const expected = [
    {
      path: "file.txt",
      headers: {
        "Content-Type": "application/octet-stream",
        fileSize: 8,
        ETag: "NO_ETAG",
        lastUsed: when1,
        lastUpdated: when1,
      },
      engineIds: [],
    },
    {
      path: "file2.txt",
      headers: {
        "Content-Type": "application/octet-stream",
        fileSize: 8,
        ETag: "NO_ETAG",
        lastUsed: when2,
        lastUpdated: when2,
      },
      engineIds: [],
    },
    {
      path: "sub/file3.txt",
      headers: {
        "Content-Length": "12345",
        "Content-Type": "application/octet-stream",
        fileSize: 32,
        ETag: "XYZ",
        lastUsed: when3,
        lastUpdated: when3,
      },
      engineIds: [],
    },
  ];

  Assert.deepEqual(files, expected);
  await deleteCache(cache);
});

/**
 * Test listing files using a task name.
 */
add_task(async function test_listFilesUsingTaskName() {
  const cache = await initializeCache();
  const model = "mozilla/distilvit";
  const revision = "main";
  const taskName = "echo";
  const headers = { "Content-Length": "12345", ETag: "XYZ" };
  const blob = createBlob();

  const when1 = await cache.put({
    taskName,
    model,
    revision,
    file: "file.txt",
    data: blob,
    headers: null,
  });
  const when2 = await cache.put({
    taskName,
    model,
    revision,
    file: "file2.txt",
    data: blob,
    headers: null,
  });
  const when3 = await cache.put({
    taskName,
    model,
    revision,
    file: "sub/file3.txt",
    data: createBlob(32),
    headers,
  });

  const { files } = await cache.listFiles({ taskName, model, revision });

  const expected = [
    {
      path: "file.txt",
      headers: {
        "Content-Type": "application/octet-stream",
        fileSize: 8,
        ETag: "NO_ETAG",
        lastUsed: when1,
        lastUpdated: when1,
      },
      engineIds: [],
    },
    {
      path: "file2.txt",
      headers: {
        "Content-Type": "application/octet-stream",
        fileSize: 8,
        ETag: "NO_ETAG",
        lastUsed: when2,
        lastUpdated: when2,
      },
      engineIds: [],
    },
    {
      path: "sub/file3.txt",
      headers: {
        "Content-Length": "12345",
        "Content-Type": "application/octet-stream",
        fileSize: 32,
        ETag: "XYZ",
        lastUsed: when3,
        lastUpdated: when3,
      },
      engineIds: [],
    },
  ];

  Assert.deepEqual(files, expected);
  await deleteCache(cache);
});

/**
 * Test listing files using a non-existing task name.
 */
add_task(async function test_listFilesUsingNonExistingTaskName() {
  const cache = await initializeCache();
  const model = "mozilla/distilvit";
  const revision = "main";
  const taskName = "echo";
  const headers = { "Content-Length": "12345", ETag: "XYZ" };
  const blob = createBlob();

  await Promise.all([
    cache.put({
      taskName,
      model,
      revision,
      file: "file.txt",
      data: blob,
      headers: null,
    }),
    cache.put({
      taskName,
      model,
      revision,
      file: "file2.txt",
      data: blob,
      headers: null,
    }),
    cache.put({
      taskName,
      model,
      revision,
      file: "sub/file3.txt",
      data: createBlob(32),
      headers,
    }),
  ]);

  const { files } = await cache.listFiles({ taskName: "non-existing-task" });
  Assert.deepEqual(files, []);
  await deleteCache(cache);
});

/**
 * Test that the cache can be initialized when no database exists yet.
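 * IndexedDBCache.init should create the database on first use.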
 */
add_task(async function test_initDbFromNonExisting() {
  const cache = await initializeCache();
  Assert.notEqual(cache, null);
  await deleteCache(cache);
});

/**
 * Test that we can upgrade even if the existing database is missing some stores or indices.
 */
add_task(async function test_initDbFromExistingEmpty() {
  const randomSuffix = Math.floor(Math.random() * 10000);
  const dbName = `modelFiles-${randomSuffix}`;
  const dbVersion = 1;
  const newVersion = dbVersion + 1;

  async function openDB() {
    return new Promise((resolve, reject) => {
      const request = indexedDB.open(dbName, dbVersion);
      request.onerror = event => reject(event.target.error);
      request.onsuccess = event => resolve(event.target.result);
    });
  }

  const db = await openDB();
  db.close();

  const cache = await IndexedDBCache.init({ dbName, version: newVersion });
  Assert.notEqual(cache, null);
  Assert.equal(cache.db.version, newVersion);

  const model = "mozilla/distilvit";
  const revision = "main";
  const taskName = "echo";
  const blob = createBlob();

  const when = await cache.put({
    taskName,
    model,
    revision,
    file: "file.txt",
    data: blob,
    headers: null,
  });

  const expected = [
    {
      path: "file.txt",
      headers: {
        "Content-Type": "application/octet-stream",
        fileSize: 8,
        ETag: "NO_ETAG",
        lastUsed: when,
        lastUpdated: when,
      },
      engineIds: [],
    },
  ];

  // Ensure every table and index is present so that we can list files.
  const { files } = await cache.listFiles({ taskName, model, revision });
  Assert.deepEqual(files, expected);
  await deleteCache(cache);
});

/**
 * Test that upgrading from version 1 to version 2 results in existing data being deleted.
 */
add_task(async function test_initDbFromExistingNoChange() {
  const randomSuffix = Math.floor(Math.random() * 10000);
  const dbName = `modelFiles-${randomSuffix}`;

  // Create version 1
  let cache = await IndexedDBCache.init({ dbName, version: 1 });
  Assert.notEqual(cache, null);
  Assert.equal(cache.db.version, 1);

  const model = "mozilla/distilvit";
  const revision = "main";
  const taskName = "echo";
  const blob = createBlob();

  await cache.put({
    taskName,
    model,
    revision,
    file: "file.txt",
    data: blob,
    headers: null,
  });

  cache.db.close();

  // Create version 2
  cache = await IndexedDBCache.init({ dbName, version: 2 });
  Assert.notEqual(cache, null);
  Assert.equal(cache.db.version, 2);

  // Ensure tables are all empty.
  const { files } = await cache.listFiles({ taskName });
  Assert.deepEqual(files, []);

  await deleteCache(cache);
});

/**
 * Test that upgrading an existing cache that is still open elsewhere is possible.
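 * The first connection is intentionally left open while the upgrade happens.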
 */
add_task(async function test_initDbFromExistingElseWhereStoreChanges() {
  const randomSuffix = Math.floor(Math.random() * 10000);
  const dbName = `modelFiles-${randomSuffix}`;
  const dbVersion = 2;
  const model = "mozilla/distilvit";
  const revision = "main";
  const taskName = "echo";
  const blob = createBlob();

  // Create version 2
  const cache1 = await IndexedDBCache.init({ dbName, version: dbVersion });
  Assert.notEqual(cache1, null);
  Assert.equal(cache1.db.version, 2);

  // Cache1 is not closed by design of this test

  // Create version 3
  const cache2 = await IndexedDBCache.init({ dbName, version: dbVersion + 1 });
  Assert.notEqual(cache2, null);
  Assert.equal(cache2.db.version, 3);

  const when = await cache2.put({
    taskName,
    model,
    revision,
    file: "file.txt",
    data: blob,
    headers: null,
  });

  const expected = [
    {
      path: "file.txt",
      headers: {
        "Content-Type": "application/octet-stream",
        fileSize: 8,
        ETag: "NO_ETAG",
        lastUpdated: when,
        lastUsed: when,
      },
      engineIds: [],
    },
  ];

  // Ensure every table and index is present so that we can list files.
  const { files } = await cache2.listFiles({ taskName, model, revision });
  Assert.deepEqual(files, expected);
  await deleteCache(cache2);
});

/**
 * Test that we can use a custom hub on every API call to get files.
 */
add_task(async function test_getting_file_custom_hub() {
  // The hub is configured to use localhost
  const hub = new ModelHub({
    rootUrl: "https://localhost",
    urlTemplate: "{model}/boo/revision",
  });

  // but we can use APIs against another hub
  const args = {
    model: "acme/bert",
    revision: "main",
    file: "config.json",
    taskName: "task_model",
    modelHubRootUrl: FAKE_HUB,
    modelHubUrlTemplate: "{model}/resolve/{revision}",
  };

  let [array, headers] = await hub.getModelFileAsArrayBuffer(args);
  Assert.equal(headers["Content-Type"], "application/json");

  // check the content of the file.
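  // The per-call modelHubRootUrl/modelHubUrlTemplate point at the fake hub,
  // so the same acme/bert config.json (hidden_size = 768) is expected.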
  let jsonData = JSON.parse(
    String.fromCharCode.apply(null, new Uint8Array(array))
  );
  Assert.equal(jsonData.hidden_size, 768);

  let res = await hub.getModelFileAsBlob(args);
  Assert.equal(res[0].size, 562);

  let response = await hub.getModelFileAsResponse(args);
  Assert.equal((await response.blob()).size, 562);
});

/**
 * Make sure that we can't pass a rootUrl that is not allowed when using the API calls.
 */
add_task(async function test_getting_file_disallowed_custom_hub() {
  // The hub is configured to use localhost
  const hub = new ModelHub({
    rootUrl: "https://localhost",
    urlTemplate: "{model}/boo/revision",
    allowDenyList: [{ filter: "ALLOW", urlPrefix: "https://example.com" }],
  });

  // and we can't use APIs against another hub if it's not allowed
  const args = {
    model: "acme/bert",
    revision: "main",
    file: "config.json",
    taskName: "task_model",
    modelHubRootUrl: "https://forbidden.com",
    modelHubUrlTemplate: "{model}/{revision}",
  };

  // This catches the error returned by getETag when checking if the file is in the cache
  try {
    await hub.getModelFileAsArrayBuffer(args);
    throw new Error("Expected method to reject.");
  } catch (error) {
    Assert.throws(
      () => {
        throw error;
      },
      new RegExp(`ForbiddenURLError`),
      `Should throw with https://forbidden.com`
    );
  }

  // This catches the error returned when useCached is false
  try {
    await hub.getModelFileAsArrayBuffer({ ...args, revision: "v1" });
    throw new Error("Expected method to reject.");
  } catch (error) {
    Assert.throws(
      () => {
        throw error;
      },
      new RegExp(`ForbiddenURLError`),
      `Should throw with https://forbidden.com`
    );
  }

  try {
    await hub.getModelFileAsBlob(args);
    throw new Error("Expected method to reject.");
  } catch (error) {
    Assert.throws(
      () => {
        throw error;
      },
      new RegExp(`ForbiddenURLError`),
      `Should throw with https://forbidden.com`
    );
  }

  try {
    await hub.getModelFileAsResponse(args);
    throw new Error("Expected method to reject.");
  } catch (error) {
    Assert.throws(
      () => {
        throw error;
      },
      new RegExp(`ForbiddenURLError`),
      `Should throw with https://forbidden.com`
    );
  }

  // This catches the error when HTTP error codes are returned and useCached is false
  try {
    await hub.getModelFileAsArrayBuffer({
      ...args,
      revision: "v1",
      modelHubRootUrl: "https://example.com",
    });
    throw new Error("Expected method to reject.");
  } catch (error) {
    Assert.throws(
      () => {
        throw error;
      },
      new RegExp(`HTTP error! Status: 404 Not Found`),
      `Should throw with 404`
    );
  }

  // This catches the error returned when useCached is true with no checks for ETags
  try {
    // store a file in the hub
    await hub.cache.put({
      ...args,
      model: "forbidden.com/acme/bert",
      engineId: "engineOne",
      revision: "v1",
      data: createBlob(),
      headers: null,
    });

    await hub.getModelFileAsArrayBuffer({ ...args, revision: "v1" });
    throw new Error("Expected method to reject.");
  } catch (error) {
    Assert.throws(
      () => {
        throw error;
      },
      new RegExp(`ForbiddenURLError`),
      `Should throw with https://forbidden.com`
    );
  }
});

/**
 * Test deleting files used by several engines.
 */
add_task(async function test_DeleteFileByEngines() {
  const cache = await initializeCache();
  const testData = createBlob();
  const engineOne = "engine-1";
  const engineTwo = "engine-2";

  // a file is stored by engineOne
  await cache.put({
    engineId: engineOne,
    taskName: "task",
    model: "org/model",
    revision: "v1",
    file: "file.txt",
    data: testData,
    headers: null,
  });

  // The file is read by engineTwo
  let retrievedData = await cache.getFile({
    engineId: engineTwo,
    model: "org/model",
    revision: "v1",
    file: "file.txt",
  });
  Assert.deepEqual(
    retrievedData[0],
    testData,
    "The retrieved data should match the stored data."
  );

  // if we delete the files registered by engineOne, they will still be around for engineTwo
  await cache.deleteFilesByEngine({ engineId: engineOne });

  retrievedData = await cache.getFile({
    engineId: engineTwo,
    model: "org/model",
    revision: "v1",
    file: "file.txt",
  });
  Assert.deepEqual(
    retrievedData[0],
    testData,
    "The retrieved data should match the stored data."
  );

  // now deleting via engineTwo
  await cache.deleteFilesByEngine({ engineId: engineTwo });

  // at this point we should not have any more files
  const dataAfterDelete = await cache.getFile({
    engineId: engineOne,
    model: "org/model",
    revision: "v1",
    file: "file.txt",
  });
  Assert.equal(
    dataAfterDelete,
    null,
    "The data for the deleted model should not exist."
  );

  await deleteCache(cache);
});

/**
 * Test deleting files used by an engine via the model hub. This is similar to
 * `test_DeleteFileByEngines`, except we are calling the model hub method.
 */
add_task(async function test_ModelHub_DeleteFileByEngines() {
  const cache = await initializeCache();
  const engineOne = "engine-1";
  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
    allowDenyList: [],
  });
  hub.cache = cache;

  // a file is stored by engineOne
  await cache.put({
    engineId: engineOne,
    taskName: "task",
    model: "org/model",
    revision: "v1",
    file: "file.txt",
    data: createBlob(),
    headers: null,
  });

  await hub.deleteFilesByEngine({ engineId: engineOne });

  // at this point we should not have any more files
  const dataAfterDelete = await cache.getFile({
    engineId: engineOne,
    model: "org/model",
    revision: "v1",
    file: "file.txt",
  });
  Assert.equal(
    dataAfterDelete,
    null,
    "The data for the deleted model should not exist."
  );

  await deleteCache(cache);
});

// tests allow/deny list updating after a model is cached
add_task(async function test_update_allow_deny_after_model_cache() {
  const cache = await initializeCache();
  const file = "config.json";
  const taskName = FAKE_MODEL_ARGS.taskName;
  const model = FAKE_MODEL_ARGS.model;
  const revision = "v0.1";

  await cache.put({
    taskName,
    model,
    revision,
    file,
    data: createBlob(),
    headers: null,
  });

  let exists = await cache.fileExists({
    model,
    revision,
    file,
  });
  Assert.ok(exists, "The file should exist in the cache.");

  let list = [
    {
      filter: "ALLOW",
      urlPrefix:
        "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data/acme",
    },
  ];

  let hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
    allowDenyList: list,
  });
  hub.cache = cache;

  // should go through since the model is allowed
  await hub.getModelFileAsArrayBuffer({ ...FAKE_MODEL_ARGS, file, revision });

  // put the model in the deny list
  list = [
    {
      filter: "DENY",
      urlPrefix:
        "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data/acme",
    },
  ];

  hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
    allowDenyList: list,
  });
  hub.cache = cache;

  // now ensure the model cannot be fetched after being put in the deny list
  try {
    await hub.getModelFileAsArrayBuffer({ ...FAKE_MODEL_ARGS, file, revision });
  } catch (e) {
    Assert.ok(e.name === "ForbiddenURLError");
  }

  // make sure that the model file is deleted afterwards
  const dataAfterForbidden = await cache.getFile({
    model,
    revision,
    file,
  });
  Assert.equal(
    dataAfterForbidden,
    null,
    "The data for the deleted model should not exist."
  );
});

/**
 * Test that data from OPFS is wiped.
 */
add_task(async function test_migrateStore_modelsDeleted() {
  const randomSuffix = Math.floor(Math.random() * 10000);
  const dbName = `modelFiles-${randomSuffix}`;

  // Initialize version 4 of the database
  let cache = await IndexedDBCache.init({ dbName, version: 4 });

  // Add some test data for unknown models
  await Promise.all([
    cache.put({
      taskName: "task",
      model: "random/model",
      revision: "v1",
      file: "random.txt",
      data: createBlob(),
      headers: null,
    }),
    cache.put({
      taskName: "task",
      model: "unknown/model",
      revision: "v2",
      file: "unknown.txt",
      data: createBlob(),
      headers: null,
    }),
  ]);

  // Close version 4 and upgrade to version 5
  cache.db.close();
  cache = await IndexedDBCache.init({ dbName, version: 5 });

  // Verify all unknown model data is deleted
  const { files: random } = await cache.listFiles({
    model: "random/model",
    revision: "v1",
  });
  Assert.deepEqual(random, [], "All unknown model files should be deleted.");

  const { files: unknown } = await cache.listFiles({
    model: "unknown/model",
    revision: "v2",
  });
  Assert.deepEqual(unknown, [], "All unknown model files should be deleted.");

  await deleteCache(cache);
});

/**
 * Test migration when the database starts empty.
 */
add_task(async function test_migrateStore_emptyDatabase() {
  const randomSuffix = Math.floor(Math.random() * 10000);
  const dbName = `modelFiles-${randomSuffix}`;

  // Initialize an empty version 4 database
  let cache = await IndexedDBCache.init({ dbName, version: 4 });
  cache.db.close();

  // Upgrade to version 5
  cache = await IndexedDBCache.init({ dbName, version: 5 });

  // Verify database is still empty
  const models = await cache.listModels();
  Assert.deepEqual(
    models,
    [],
    "The database should remain empty after migration."
  );

  await deleteCache(cache);
});

add_task(async function test_getOwnerIcon() {
  await SpecialPowers.pushPrefEnv({
    set: [["browser.ml.logLevel", "All"]],
  });

  const hub = new ModelHub({
    rootUrl: FAKE_HUB,
    urlTemplate: FAKE_URL_TEMPLATE,
  });

  const fullyQualifiedModelName = "mochitests/mozilla/distilvit";

  // first call will get the icon from the web
  const icon = await hub.getOwnerIcon(fullyQualifiedModelName);
  Assert.notEqual(icon, null);

  // second call will get it from the cache
  let spy = sinon.spy(OPFS.File.prototype, "getBlobFromOPFS");
  const icon2 = await hub.getOwnerIcon(fullyQualifiedModelName);
  Assert.notEqual(icon2, null);

  // check that it came from OPFS
  Assert.notEqual(await spy.lastCall.returnValue, null);

  sinon.restore();
});