diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-15 03:34:42 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-15 03:34:42 +0000 |
commit | da4c7e7ed675c3bf405668739c3012d140856109 (patch) | |
tree | cdd868dba063fecba609a1d819de271f0d51b23e /toolkit/components/ml | |
parent | Adding upstream version 125.0.3. (diff) | |
download | firefox-da4c7e7ed675c3bf405668739c3012d140856109.tar.xz firefox-da4c7e7ed675c3bf405668739c3012d140856109.zip |
Adding upstream version 126.0.upstream/126.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'toolkit/components/ml')
9 files changed, 1108 insertions, 1 deletions
diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs index 10b4eed4fa..05203e5f69 100644 --- a/toolkit/components/ml/actors/MLEngineParent.sys.mjs +++ b/toolkit/components/ml/actors/MLEngineParent.sys.mjs @@ -91,7 +91,7 @@ export class MLEngineParent extends JSWindowActorParent { } // eslint-disable-next-line consistent-return - async receiveMessage({ name, data }) { + async receiveMessage({ name }) { switch (name) { case "MLEngine:Ready": if (lazy.EngineProcess.resolveMLEngineParent) { diff --git a/toolkit/components/ml/content/ModelHub.sys.mjs b/toolkit/components/ml/content/ModelHub.sys.mjs new file mode 100644 index 0000000000..4c2181ff14 --- /dev/null +++ b/toolkit/components/ml/content/ModelHub.sys.mjs @@ -0,0 +1,690 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +const lazy = {}; + +ChromeUtils.defineESModuleGetters(lazy, { + clearTimeout: "resource://gre/modules/Timer.sys.mjs", + setTimeout: "resource://gre/modules/Timer.sys.mjs", +}); + +ChromeUtils.defineLazyGetter(lazy, "console", () => { + return console.createInstance({ + maxLogLevelPref: "browser.ml.logLevel", + prefix: "ML", + }); +}); + +const ALLOWED_HUBS = [ + "chrome://*", + "resource://*", + "http://localhost", + "https://localhost", + "https://model-hub.mozilla.org", +]; + +const ALLOWED_HEADERS_KEYS = ["Content-Type", "ETag", "status"]; +const DEFAULT_URL_TEMPLATE = + "${organization}/${modelName}/resolve/${modelVersion}/${file}"; + +/** + * Checks if a given URL string corresponds to an allowed hub. + * + * This function validates a URL against a list of allowed hubs, ensuring that it: + * - Is well-formed according to the URL standard. + * - Does not include a username or password. + * - Matches the allowed scheme and hostname. + * + * @param {string} urlString The URL string to validate. + * @returns {boolean} True if the URL is allowed; false otherwise. + */ +function allowedHub(urlString) { + try { + const url = new URL(urlString); + // Check for username or password in the URL + if (url.username !== "" || url.password !== "") { + return false; // Reject URLs with username or password + } + const scheme = url.protocol; + const host = url.hostname; + const fullPrefix = `${scheme}//${host}`; + + return ALLOWED_HUBS.some(allowedHub => { + const [allowedScheme, allowedHost] = allowedHub.split("://"); + if (allowedHost === "*") { + return `${allowedScheme}:` === scheme; + } + const allowedPrefix = `${allowedScheme}://${allowedHost}`; + return fullPrefix === allowedPrefix; + }); + } catch (error) { + lazy.console.error("Error parsing URL:", error); + return false; + } +} + +const NO_ETAG = "NO_ETAG"; + +/** + * Class for managing a cache stored in IndexedDB. + */ +export class IndexedDBCache { + /** + * Reference to the IndexedDB database. + * + * @type {IDBDatabase|null} + */ + db = null; + + /** + * Version of the database. Null if not set. + * + * @type {number|null} + */ + dbVersion = null; + + /** + * Total size of the files stored in the cache. + * + * @type {number} + */ + totalSize = 0; + + /** + * Name of the database used by IndexedDB. + * + * @type {string} + */ + dbName; + + /** + * Name of the object store for storing files. + * + * @type {string} + */ + fileStoreName; + + /** + * Name of the object store for storing headers. + * + * @type {string} + */ + headersStoreName; + /** + * Maximum size of the cache in bytes. Defaults to 1GB. + * + * @type {number} + */ + #maxSize = 1_073_741_824; // 1GB in bytes + + /** + * Private constructor to prevent direct instantiation. + * Use IndexedDBCache.init to create an instance. + * + * @param {string} dbName - The name of the database file. + * @param {number} version - The version number of the database. + */ + constructor(dbName = "modelFiles", version = 1) { + this.dbName = dbName; + this.dbVersion = version; + this.fileStoreName = "files"; + this.headersStoreName = "headers"; + } + + /** + * Static method to create and initialize an instance of IndexedDBCache. + * + * @param {string} [dbName="modelFiles"] - The name of the database. + * @param {number} [version=1] - The version number of the database. + * @returns {Promise<IndexedDBCache>} An initialized instance of IndexedDBCache. + */ + static async init(dbName = "modelFiles", version = 1) { + const cacheInstance = new IndexedDBCache(dbName, version); + cacheInstance.db = await cacheInstance.#openDB(); + const storedSize = await cacheInstance.#getData( + cacheInstance.headersStoreName, + "totalSize" + ); + cacheInstance.totalSize = storedSize ? storedSize.size : 0; + return cacheInstance; + } + + /** + * Called to close the DB connection and dispose the instance + * + */ + async dispose() { + if (this.db) { + this.db.close(); + this.db = null; + } + } + + /** + * Opens or creates the IndexedDB database. + * + * @returns {Promise<IDBDatabase>} + */ + async #openDB() { + return new Promise((resolve, reject) => { + const request = indexedDB.open(this.dbName, this.dbVersion); + request.onerror = event => reject(event.target.error); + request.onsuccess = event => resolve(event.target.result); + request.onupgradeneeded = event => { + const db = event.target.result; + if (!db.objectStoreNames.contains(this.fileStoreName)) { + db.createObjectStore(this.fileStoreName, { keyPath: "id" }); + } + if (!db.objectStoreNames.contains(this.headersStoreName)) { + db.createObjectStore(this.headersStoreName, { keyPath: "id" }); + } + }; + }); + } + + /** + * Generic method to get the data from a specified object store. + * + * @param {string} storeName - The name of the object store. + * @param {string} key - The key within the object store to retrieve the data from. + * @returns {Promise<any>} + */ + async #getData(storeName, key) { + return new Promise((resolve, reject) => { + const transaction = this.db.transaction([storeName], "readonly"); + const store = transaction.objectStore(storeName); + const request = store.get(key); + request.onerror = event => reject(event.target.error); + request.onsuccess = event => resolve(event.target.result); + }); + } + + // Used in tests + async _testGetData(storeName, key) { + return this.#getData(storeName, key); + } + + /** + * Generic method to update data in a specified object store. + * + * @param {string} storeName - The name of the object store. + * @param {object} data - The data to store. + * @returns {Promise<void>} + */ + async #updateData(storeName, data) { + return new Promise((resolve, reject) => { + const transaction = this.db.transaction([storeName], "readwrite"); + const store = transaction.objectStore(storeName); + const request = store.put(data); + request.onerror = event => reject(event.target.error); + request.onsuccess = () => resolve(); + }); + } + + /** + * Deletes a specific cache entry. + * + * @param {string} storeName - The name of the object store. + * @param {string} key - The key of the entry to delete. + * @returns {Promise<void>} + */ + async #deleteData(storeName, key) { + return new Promise((resolve, reject) => { + const transaction = this.db.transaction([storeName], "readwrite"); + const store = transaction.objectStore(storeName); + const request = store.delete(key); + request.onerror = event => reject(event.target.error); + request.onsuccess = () => resolve(); + }); + } + + /** + * Retrieves the headers for a specific cache entry. + * + * @param {string} organization - The organization name. + * @param {string} modelName - The model name. + * @param {string} modelVersion - The model version. + * @param {string} file - The file name. + * @returns {Promise<object|null>} The headers or null if not found. + */ + async getHeaders(organization, modelName, modelVersion, file) { + const headersKey = `${organization}/${modelName}/${modelVersion}`; + const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`; + const headers = await this.#getData(this.headersStoreName, headersKey); + if (headers && headers.files[cacheKey]) { + return headers.files[cacheKey]; + } + return null; // Return null if no headers is found + } + + /** + * Retrieves the file for a specific cache entry. + * + * @param {string} organization - The organization name. + * @param {string} modelName - The model name. + * @param {string} modelVersion - The model version. + * @param {string} file - The file name. + * @returns {Promise<[ArrayBuffer, object]|null>} The file ArrayBuffer and its headers or null if not found. + */ + async getFile(organization, modelName, modelVersion, file) { + const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`; + const stored = await this.#getData(this.fileStoreName, cacheKey); + if (stored) { + const headers = await this.getHeaders( + organization, + modelName, + modelVersion, + file + ); + return [stored.data, headers]; + } + return null; // Return null if no file is found + } + + /** + * Adds or updates a cache entry. + * + * @param {string} organization - The organization name. + * @param {string} modelName - The model name. + * @param {string} modelVersion - The model version. + * @param {string} file - The file name. + * @param {ArrayBuffer} arrayBuffer - The data to cache. + * @param {object} [headers] - The headers for the file. + * @returns {Promise<void>} + */ + async put( + organization, + modelName, + modelVersion, + file, + arrayBuffer, + headers = {} + ) { + const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`; + const newSize = this.totalSize + arrayBuffer.byteLength; + if (newSize > this.#maxSize) { + throw new Error("Exceeding total cache size limit of 1GB"); + } + + const headersKey = `${organization}/${modelName}/${modelVersion}`; + const data = { id: cacheKey, data: arrayBuffer }; + + // Store the file data + await this.#updateData(this.fileStoreName, data); + + // Update headers store - whith defaults for ETag and Content-Type + headers = headers || {}; + headers["Content-Type"] = + headers["Content-Type"] ?? "application/octet-stream"; + headers.ETag = headers.ETag ?? NO_ETAG; + + // filter out any keys that are not allowed + headers = Object.keys(headers) + .filter(key => ALLOWED_HEADERS_KEYS.includes(key)) + .reduce((obj, key) => { + obj[key] = headers[key]; + return obj; + }, {}); + + const headersStore = (await this.#getData( + this.headersStoreName, + headersKey + )) || { + id: headersKey, + files: {}, + }; + headersStore.files[cacheKey] = headers; + await this.#updateData(this.headersStoreName, headersStore); + + // Update size + await this.#updateTotalSize(arrayBuffer.byteLength); + } + + /** + * Updates the total size of the cache. + * + * @param {number} sizeToAdd - The size to add to the total. + * @returns {Promise<void>} + */ + async #updateTotalSize(sizeToAdd) { + this.totalSize += sizeToAdd; + await this.#updateData(this.headersStoreName, { + id: "totalSize", + size: this.totalSize, + }); + } + /** + * Deletes all data related to a specific model. + * + * @param {string} organization - The organization name. + * @param {string} modelName - The model name. + * @param {string} modelVersion - The model version. + * @returns {Promise<void>} + */ + async deleteModel(organization, modelName, modelVersion) { + const headersKey = `${organization}/${modelName}/${modelVersion}`; + const headers = await this.#getData(this.headersStoreName, headersKey); + if (headers) { + for (const fileKey in headers.files) { + await this.#deleteData(this.fileStoreName, fileKey); + } + await this.#deleteData(this.headersStoreName, headersKey); // Remove headers entry after files are deleted + } + } + + /** + * Lists all models stored in the cache. + * + * @returns {Promise<Array<string>>} An array of model identifiers. + */ + async listModels() { + const models = []; + return new Promise((resolve, reject) => { + const transaction = this.db.transaction( + [this.headersStoreName], + "readonly" + ); + const store = transaction.objectStore(this.headersStoreName); + const request = store.openCursor(); + request.onerror = event => reject(event.target.error); + request.onsuccess = event => { + const cursor = event.target.result; + if (cursor) { + models.push(cursor.value.id); // Assuming id is the organization/modelName + cursor.continue(); + } else { + resolve(models); + } + }; + }); + } +} + +export class ModelHub { + constructor({ rootUrl, urlTemplate = DEFAULT_URL_TEMPLATE }) { + if (!allowedHub(rootUrl)) { + throw new Error(`Invalid model hub root url: ${rootUrl}`); + } + this.rootUrl = rootUrl; + this.cache = null; + + // Ensures the URL template is well-formed and does not contain any invalid characters. + const pattern = /^(?:\$\{\w+\}|\w+)(?:\/(?:\$\{\w+\}|\w+))*$/; + // ^ $ Start and end of string + // (?:\$\{\w+\}|\w+) Match a ${placeholder} or alphanumeric characters + // (?:\/(?:\$\{\w+\}|\w+))* Zero or more groups of a forward slash followed by a ${placeholder} or alphanumeric characters + if (!pattern.test(urlTemplate)) { + throw new Error(`Invalid URL template: ${urlTemplate}`); + } + this.urlTemplate = urlTemplate; + } + + async #initCache() { + if (this.cache) { + return; + } + this.cache = await IndexedDBCache.init(); + } + + /** Creates the file URL from the organization, model, and version. + * + * @param {string} organization + * @param {string} modelName + * @param {string} modelVersion + * @param {string} file + * @returns {string} The full URL + */ + #fileUrl(organization, modelName, modelVersion, file) { + const baseUrl = new URL(this.rootUrl); + if (!baseUrl.pathname.endsWith("/")) { + baseUrl.pathname += "/"; + } + + // Replace placeholders in the URL template with the provided data. + // If some keys are missing in the data object, the placeholder is left as is. + // If the placeholder is not found in the data object, it is left as is. + const data = { + organization, + modelName, + modelVersion, + file, + }; + const path = this.urlTemplate.replace( + /\$\{(\w+)\}/g, + (match, key) => data[key] || match + ); + const fullPath = `${baseUrl.pathname}${ + path.startsWith("/") ? path.slice(1) : path + }`; + + const urlObject = new URL(fullPath, baseUrl.origin); + urlObject.searchParams.append("download", "true"); + return urlObject.toString(); + } + + /** Checks the organization, model, and version inputs. + * + * @param { string } organization + * @param { string } modelName + * @param { string } modelVersion + * @param { string } file + * @returns { Error } The error instance(can be null) + */ + #checkInput(organization, modelName, modelVersion, file) { + // Ensures string consists only of letters, digits, and hyphens without starting/ending + // with a hyphen or containing consecutive hyphens. + // + // ^ $ Start and end of string + // (?!-) (?<!-) Negative lookahead/behind for not starting or ending with hyphen + // (?!.*--) Negative lookahead for not containing consecutive hyphens + // [A-Za-z0-9-]+ Alphanum characters or hyphens, one or more + const orgRegex = /^(?!-)(?!.*--)[A-Za-z0-9-]+(?<!-)$/; + + // Matches strings containing letters, digits, hyphens, underscores, or periods. + // ^ $ Start and end of string + // [A-Za-z0-9-_.]+ Alphanum characters, hyphens, underscores, or periods, one or more times + const modelRegex = /^[A-Za-z0-9-_.]+$/; + + // Matches strings consisting of alphanumeric characters, hyphens, or periods. + // + // ^ $ Start and end of string + // [A-Za-z0-9-.]+ Alphanum characters, hyphens, or periods, one or more times + const versionRegex = /^[A-Za-z0-9-.]+$/; + + // Matches filenames with subdirectories, starting with alphanumeric or underscore, + // and optionally ending with a dot followed by a 2-4 letter extension. + // + // ^ $ Start and end of string + // (?:\/)? Optional leading slash (for absolute paths or root directory) + // (?!\/) Negative lookahead for not starting with a slash + // [A-Za-z0-9-_]+ First directory or filename + // (?: Begin non-capturing group for additional directories or file + // \/ Directory separator + // [A-Za-z0-9-_]+ Directory or file name + // )* Zero or more times + // (?:[.][A-Za-z]{2,4})? Optional non-capturing group for file extension + const fileRegex = + /^(?:\/)?(?!\/)[A-Za-z0-9-_]+(?:\/[A-Za-z0-9-_]+)*(?:[.][A-Za-z]{2,4})?$/; + + if (!orgRegex.test(organization) || !isNaN(parseInt(organization))) { + return new Error(`Invalid organization name ${organization}`); + } + + if (!modelRegex.test(modelName)) { + return new Error("Invalid model name."); + } + + if ( + !versionRegex.test(modelVersion) || + modelVersion.includes(" ") || + /[\^$]/.test(modelVersion) + ) { + return new Error("Invalid version identifier."); + } + + if (!fileRegex.test(file)) { + return new Error("Invalid file name"); + } + + return null; + } + + /** + * Returns the ETag value given an URL + * + * @param {string} url + * @param {number} timeout in ms. Default is 1000 + * @returns {Promise<string>} ETag (can be null) + */ + async #getETag(url, timeout = 1000) { + const controller = new AbortController(); + const id = lazy.setTimeout(() => controller.abort(), timeout); + + try { + const headResponse = await fetch(url, { + method: "HEAD", + signal: controller.signal, + }); + const currentEtag = headResponse.headers.get("ETag"); + return currentEtag; + } catch (error) { + lazy.console.warn("An error occurred when calling HEAD:", error); + return null; + } finally { + lazy.clearTimeout(id); + } + } + + /** + * Given an organization, model, and version, fetch a model file in the hub as a Response. + * + * @param {object} config + * @param {string} config.organization + * @param {string} config.modelName + * @param {string} config.modelVersion + * @param {string} config.file + * @returns {Promise<Response>} The file content + */ + async getModelFileAsResponse({ + organization, + modelName, + modelVersion, + file, + }) { + const [blob, headers] = await this.getModelFileAsBlob({ + organization, + modelName, + modelVersion, + file, + }); + return new Response(blob, { headers }); + } + + /** + * Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer. + * + * @param {object} config + * @param {string} config.organization + * @param {string} config.modelName + * @param {string} config.modelVersion + * @param {string} config.file + * @returns {Promise<[ArrayBuffer, headers]>} The file content + */ + async getModelFileAsArrayBuffer({ + organization, + modelName, + modelVersion, + file, + }) { + const [blob, headers] = await this.getModelFileAsBlob({ + organization, + modelName, + modelVersion, + file, + }); + return [await blob.arrayBuffer(), headers]; + } + + /** + * Given an organization, model, and version, fetch a model file in the hub as blob. + * + * @param {object} config + * @param {string} config.organization + * @param {string} config.modelName + * @param {string} config.modelVersion + * @param {string} config.file + * @returns {Promise<[Blob, object]>} The file content + */ + async getModelFileAsBlob({ organization, modelName, modelVersion, file }) { + // Make sure inputs are clean. We don't sanitize them but throw an exception + let checkError = this.#checkInput( + organization, + modelName, + modelVersion, + file + ); + if (checkError) { + throw checkError; + } + + const url = this.#fileUrl(organization, modelName, modelVersion, file); + lazy.console.debug(`Getting model file from ${url}`); + + await this.#initCache(); + + // this can be null if no ETag was found or there were a network error + const hubETag = await this.#getETag(url); + + lazy.console.debug( + `Checking the cache for ${organization}/${modelName}/${modelVersion}/${file}` + ); + + // storage lookup + const cachedHeaders = await this.cache.getHeaders( + organization, + modelName, + modelVersion, + file + ); + const cachedEtag = cachedHeaders ? cachedHeaders.ETag : null; + + // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response + if (cachedEtag !== null && (hubETag === null || cachedEtag === hubETag)) { + lazy.console.debug(`Cache Hit`); + return await this.cache.getFile( + organization, + modelName, + modelVersion, + file + ); + } + + lazy.console.debug(`Fetching ${url}`); + try { + const response = await fetch(url); + if (response.ok) { + const clone = response.clone(); + const headers = { + // We don't store the boundary or the charset, just the content type, + // so we drop what's after the semicolon. + "Content-Type": response.headers.get("Content-Type").split(";")[0], + ETag: hubETag, + }; + + await this.cache.put( + organization, + modelName, + modelVersion, + file, + await clone.blob(), + headers + ); + return [await response.blob(), headers]; + } + } catch (error) { + lazy.console.error(`Failed to fetch ${url}:`, error); + } + + throw new Error(`Failed to fetch the model file: ${url}`); + } +} diff --git a/toolkit/components/ml/jar.mn b/toolkit/components/ml/jar.mn index 56bfb0d469..c9c82e0b26 100644 --- a/toolkit/components/ml/jar.mn +++ b/toolkit/components/ml/jar.mn @@ -6,4 +6,5 @@ toolkit.jar: content/global/ml/EngineProcess.sys.mjs (content/EngineProcess.sys.mjs) content/global/ml/MLEngine.worker.mjs (content/MLEngine.worker.mjs) content/global/ml/MLEngine.html (content/MLEngine.html) + content/global/ml/ModelHub.sys.mjs (content/ModelHub.sys.mjs) content/global/ml/SummarizerModel.sys.mjs (content/SummarizerModel.sys.mjs) diff --git a/toolkit/components/ml/tests/browser/browser.toml b/toolkit/components/ml/tests/browser/browser.toml index 9ccda0beaa..57637c8bda 100644 --- a/toolkit/components/ml/tests/browser/browser.toml +++ b/toolkit/components/ml/tests/browser/browser.toml @@ -1,5 +1,9 @@ [DEFAULT] support-files = [ "head.js", + "data/**/*.*" ] + +["browser_ml_cache.js"] + ["browser_ml_engine.js"] diff --git a/toolkit/components/ml/tests/browser/browser_ml_cache.js b/toolkit/components/ml/tests/browser/browser_ml_cache.js new file mode 100644 index 0000000000..d8725368bd --- /dev/null +++ b/toolkit/components/ml/tests/browser/browser_ml_cache.js @@ -0,0 +1,361 @@ +/* Any copyright is dedicated to the Public Domain. +http://creativecommons.org/publicdomain/zero/1.0/ */ +"use strict"; + +const { sinon } = ChromeUtils.importESModule( + "resource://testing-common/Sinon.sys.mjs" +); + +// Root URL of the fake hub, see the `data` dir in the tests. +const FAKE_HUB = + "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data"; + +const FAKE_MODEL_ARGS = { + organization: "acme", + modelName: "bert", + modelVersion: "main", + file: "config.json", +}; + +const FAKE_ONNX_MODEL_ARGS = { + organization: "acme", + modelName: "bert", + modelVersion: "main", + file: "onnx/config.json", +}; + +const badHubs = [ + "https://my.cool.hub", + "https://sub.localhost/myhub", // Subdomain of allowed domain + "https://model-hub.mozilla.org.evil.com", // Manipulating path to mimic domain + "httpsz://localhost/myhub", // Similar-looking scheme + "https://localhost.", // Trailing dot in domain + "resource://user@localhost", // User info in URL + "ftp://localhost/myhub", // Disallowed scheme with allowed host + "https://model-hub.mozilla.org.hack", // Domain that contains allowed domain +]; + +add_task(async function test_bad_hubs() { + for (const badHub of badHubs) { + Assert.throws( + () => new ModelHub({ rootUrl: badHub }), + new RegExp(`Error: Invalid model hub root url: ${badHub}`), + `Should throw with ${badHub}` + ); + } +}); + +let goodHubs = [ + "https:///localhost/myhub", // Triple slashes, see https://stackoverflow.com/a/22775589 + "https://localhost:8080/myhub", + "http://localhost/myhub", + "https://model-hub.mozilla.org", + "chrome://gre/somewhere/in/the/code/base", +]; + +add_task(async function test_allowed_hub() { + goodHubs.forEach(url => new ModelHub({ rootUrl: url })); +}); + +const badInputs = [ + [ + { + organization: "ac me", + modelName: "bert", + modelVersion: "main", + file: "config.json", + }, + "Org can only contain letters, numbers, and hyphens", + ], + [ + { + organization: "1111", + modelName: "bert", + modelVersion: "main", + file: "config.json", + }, + "Org cannot contain only numbers", + ], + [ + { + organization: "-acme", + modelName: "bert", + modelVersion: "main", + file: "config.json", + }, + "Org start or end with a hyphen, or use consecutive hyphens", + ], + [ + { + organization: "a-c-m-e", + modelName: "#bert", + modelVersion: "main", + file: "config.json", + }, + "Models can only contain letters, numbers, and hyphens, underscord, periods", + ], + [ + { + organization: "a-c-m-e", + modelName: "b$ert", + modelVersion: "main", + file: "config.json", + }, + "Models cannot contain spaces or control characters", + ], + [ + { + organization: "a-c-m-e", + modelName: "b$ert", + modelVersion: "main", + file: ".filename", + }, + "File", + ], +]; + +add_task(async function test_bad_inputs() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + + for (const badInput of badInputs) { + const params = badInput[0]; + const errorMsg = badInput[1]; + try { + await hub.getModelFileAsArrayBuffer(params); + } catch (error) { + continue; + } + throw new Error(errorMsg); + } +}); + +add_task(async function test_getting_file() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + + let [array, headers] = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS); + + Assert.equal(headers["Content-Type"], "application/json"); + + // check the content of the file. + let jsonData = JSON.parse( + String.fromCharCode.apply(null, new Uint8Array(array)) + ); + + Assert.equal(jsonData.hidden_size, 768); +}); + +add_task(async function test_getting_file_in_subdir() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + + let [array, metadata] = await hub.getModelFileAsArrayBuffer( + FAKE_ONNX_MODEL_ARGS + ); + + Assert.equal(metadata["Content-Type"], "application/json"); + + // check the content of the file. + let jsonData = JSON.parse( + String.fromCharCode.apply(null, new Uint8Array(array)) + ); + + Assert.equal(jsonData.hidden_size, 768); +}); + +add_task(async function test_getting_file_custom_path() { + const hub = new ModelHub({ + rootUrl: FAKE_HUB, + urlTemplate: "${organization}/${modelName}/resolve/${modelVersion}/${file}", + }); + + let res = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS); + + Assert.equal(res[1]["Content-Type"], "application/json"); +}); + +add_task(async function test_getting_file_custom_path_rogue() { + const urlTemplate = + "${organization}/${modelName}/resolve/${modelVersion}/${file}?some_id=bedqwdw"; + Assert.throws( + () => new ModelHub({ rootUrl: FAKE_HUB, urlTemplate }), + /Invalid URL template/, + `Should throw with ${urlTemplate}` + ); +}); + +add_task(async function test_getting_file_as_response() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + + let response = await hub.getModelFileAsResponse(FAKE_MODEL_ARGS); + + // check the content of the file. + let jsonData = await response.json(); + Assert.equal(jsonData.hidden_size, 768); +}); + +add_task(async function test_getting_file_from_cache() { + const hub = new ModelHub({ rootUrl: FAKE_HUB }); + let array = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS); + + // stub to verify that the data was retrieved from IndexDB + let matchMethod = hub.cache._testGetData; + + sinon.stub(hub.cache, "_testGetData").callsFake(function () { + return matchMethod.apply(this, arguments).then(result => { + Assert.notEqual(result, null); + return result; + }); + }); + + // exercises the cache + let array2 = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS); + hub.cache._testGetData.restore(); + + Assert.deepEqual(array, array2); +}); + +// IndexedDB tests + +/** + * Helper function to initialize the cache + */ +async function initializeCache() { + const randomSuffix = Math.floor(Math.random() * 10000); + return await IndexedDBCache.init(`modelFiles-${randomSuffix}`); +} + +/** + * Helper function to delete the cache database + */ +async function deleteCache(cache) { + await cache.dispose(); + indexedDB.deleteDatabase(cache.dbName); +} + +/** + * Test the initialization and creation of the IndexedDBCache instance. + */ +add_task(async function test_Init() { + const cache = await initializeCache(); + Assert.ok( + cache instanceof IndexedDBCache, + "The cache instance should be created successfully." + ); + Assert.ok( + IDBDatabase.isInstance(cache.db), + `The cache should have an IDBDatabase instance. Found ${cache.db}` + ); + await deleteCache(cache); +}); + +/** + * Test adding data to the cache and retrieving it. + */ +add_task(async function test_PutAndGet() { + const cache = await initializeCache(); + const testData = new ArrayBuffer(8); // Example data + await cache.put("org", "model", "v1", "file.txt", testData, { + ETag: "ETAG123", + }); + + const [retrievedData, headers] = await cache.getFile( + "org", + "model", + "v1", + "file.txt" + ); + Assert.deepEqual( + retrievedData, + testData, + "The retrieved data should match the stored data." + ); + Assert.equal( + headers.ETag, + "ETAG123", + "The retrieved ETag should match the stored ETag." + ); + + await deleteCache(cache); +}); + +/** + * Test retrieving the headers for a cache entry. + */ +add_task(async function test_GetHeaders() { + const cache = await initializeCache(); + const testData = new ArrayBuffer(8); + const headers = { + ETag: "ETAG123", + status: 200, + extra: "extra", + }; + + await cache.put("org", "model", "v1", "file.txt", testData, headers); + + const storedHeaders = await cache.getHeaders( + "org", + "model", + "v1", + "file.txt" + ); + + // The `extra` field should be removed from the stored headers because + // it's not part of the allowed keys. + // The content-type one is added when not present + Assert.deepEqual( + { + ETag: "ETAG123", + status: 200, + "Content-Type": "application/octet-stream", + }, + storedHeaders, + "The retrieved headers should match the stored headers." + ); + await deleteCache(cache); +}); + +/** + * Test listing all models stored in the cache. + */ +add_task(async function test_ListModels() { + const cache = await initializeCache(); + await cache.put( + "org1", + "modelA", + "v1", + "file1.txt", + new ArrayBuffer(8), + null + ); + await cache.put( + "org2", + "modelB", + "v1", + "file2.txt", + new ArrayBuffer(8), + null + ); + + const models = await cache.listModels(); + Assert.ok( + models.includes("org1/modelA/v1") && models.includes("org2/modelB/v1"), + "All models should be listed." + ); + await deleteCache(cache); +}); + +/** + * Test deleting a model and its data from the cache. + */ +add_task(async function test_DeleteModel() { + const cache = await initializeCache(); + await cache.put("org", "model", "v1", "file.txt", new ArrayBuffer(8), null); + await cache.deleteModel("org", "model", "v1"); + + const dataAfterDelete = await cache.getFile("org", "model", "v1", "file.txt"); + Assert.equal( + dataAfterDelete, + null, + "The data for the deleted model should not exist." + ); + await deleteCache(cache); +}); diff --git a/toolkit/components/ml/tests/browser/data/README.md b/toolkit/components/ml/tests/browser/data/README.md new file mode 100644 index 0000000000..d826cf7ee6 --- /dev/null +++ b/toolkit/components/ml/tests/browser/data/README.md @@ -0,0 +1,5 @@ +# fake hub + +This directory is a fake hub that is served via chrome://global/content/ml/tests. + +All files in this directory are included with a wildcard in the component `jar.md`. diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json new file mode 100644 index 0000000000..50dbb760bb --- /dev/null +++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json @@ -0,0 +1,21 @@ +{ + "architectures": ["BertForMaskedLM"], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 30522 +} diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json new file mode 100644 index 0000000000..50dbb760bb --- /dev/null +++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json @@ -0,0 +1,21 @@ +{ + "architectures": ["BertForMaskedLM"], + "attention_probs_dropout_prob": 0.1, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0.dev0", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 30522 +} diff --git a/toolkit/components/ml/tests/browser/head.js b/toolkit/components/ml/tests/browser/head.js index 99d27ce18a..9fc20c0e84 100644 --- a/toolkit/components/ml/tests/browser/head.js +++ b/toolkit/components/ml/tests/browser/head.js @@ -19,6 +19,10 @@ const { MLEngineParent } = ChromeUtils.importESModule( "resource://gre/actors/MLEngineParent.sys.mjs" ); +const { ModelHub, IndexedDBCache } = ChromeUtils.importESModule( + "chrome://global/content/ml/ModelHub.sys.mjs" +); + // This test suite shares some utility functions with translations as they work in a very // similar fashion. Eventually, the plan is to unify these two components. Services.scriptloader.loadSubScript( |