summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-15 03:34:42 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-15 03:34:42 +0000
commitda4c7e7ed675c3bf405668739c3012d140856109 (patch)
treecdd868dba063fecba609a1d819de271f0d51b23e /toolkit/components/ml
parentAdding upstream version 125.0.3. (diff)
downloadfirefox-da4c7e7ed675c3bf405668739c3012d140856109.tar.xz
firefox-da4c7e7ed675c3bf405668739c3012d140856109.zip
Adding upstream version 126.0.upstream/126.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'toolkit/components/ml')
-rw-r--r--toolkit/components/ml/actors/MLEngineParent.sys.mjs2
-rw-r--r--toolkit/components/ml/content/ModelHub.sys.mjs690
-rw-r--r--toolkit/components/ml/jar.mn1
-rw-r--r--toolkit/components/ml/tests/browser/browser.toml4
-rw-r--r--toolkit/components/ml/tests/browser/browser_ml_cache.js361
-rw-r--r--toolkit/components/ml/tests/browser/data/README.md5
-rw-r--r--toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json21
-rw-r--r--toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json21
-rw-r--r--toolkit/components/ml/tests/browser/head.js4
9 files changed, 1108 insertions, 1 deletions
diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
index 10b4eed4fa..05203e5f69 100644
--- a/toolkit/components/ml/actors/MLEngineParent.sys.mjs
+++ b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
@@ -91,7 +91,7 @@ export class MLEngineParent extends JSWindowActorParent {
}
// eslint-disable-next-line consistent-return
- async receiveMessage({ name, data }) {
+ async receiveMessage({ name }) {
switch (name) {
case "MLEngine:Ready":
if (lazy.EngineProcess.resolveMLEngineParent) {
diff --git a/toolkit/components/ml/content/ModelHub.sys.mjs b/toolkit/components/ml/content/ModelHub.sys.mjs
new file mode 100644
index 0000000000..4c2181ff14
--- /dev/null
+++ b/toolkit/components/ml/content/ModelHub.sys.mjs
@@ -0,0 +1,690 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+const lazy = {};
+
+ChromeUtils.defineESModuleGetters(lazy, {
+ clearTimeout: "resource://gre/modules/Timer.sys.mjs",
+ setTimeout: "resource://gre/modules/Timer.sys.mjs",
+});
+
+ChromeUtils.defineLazyGetter(lazy, "console", () => {
+ return console.createInstance({
+ maxLogLevelPref: "browser.ml.logLevel",
+ prefix: "ML",
+ });
+});
+
+const ALLOWED_HUBS = [
+ "chrome://*",
+ "resource://*",
+ "http://localhost",
+ "https://localhost",
+ "https://model-hub.mozilla.org",
+];
+
+const ALLOWED_HEADERS_KEYS = ["Content-Type", "ETag", "status"];
+const DEFAULT_URL_TEMPLATE =
+ "${organization}/${modelName}/resolve/${modelVersion}/${file}";
+
+/**
+ * Checks if a given URL string corresponds to an allowed hub.
+ *
+ * This function validates a URL against a list of allowed hubs, ensuring that it:
+ * - Is well-formed according to the URL standard.
+ * - Does not include a username or password.
+ * - Matches the allowed scheme and hostname.
+ *
+ * @param {string} urlString The URL string to validate.
+ * @returns {boolean} True if the URL is allowed; false otherwise.
+ */
+function allowedHub(urlString) {
+ try {
+ const url = new URL(urlString);
+ // Check for username or password in the URL
+ if (url.username !== "" || url.password !== "") {
+ return false; // Reject URLs with username or password
+ }
+ const scheme = url.protocol;
+ const host = url.hostname;
+ const fullPrefix = `${scheme}//${host}`;
+
+ return ALLOWED_HUBS.some(allowedHub => {
+ const [allowedScheme, allowedHost] = allowedHub.split("://");
+ if (allowedHost === "*") {
+ return `${allowedScheme}:` === scheme;
+ }
+ const allowedPrefix = `${allowedScheme}://${allowedHost}`;
+ return fullPrefix === allowedPrefix;
+ });
+ } catch (error) {
+ lazy.console.error("Error parsing URL:", error);
+ return false;
+ }
+}
+
+const NO_ETAG = "NO_ETAG";
+
+/**
+ * Class for managing a cache stored in IndexedDB.
+ */
+export class IndexedDBCache {
+ /**
+ * Reference to the IndexedDB database.
+ *
+ * @type {IDBDatabase|null}
+ */
+ db = null;
+
+ /**
+ * Version of the database. Null if not set.
+ *
+ * @type {number|null}
+ */
+ dbVersion = null;
+
+ /**
+ * Total size of the files stored in the cache.
+ *
+ * @type {number}
+ */
+ totalSize = 0;
+
+ /**
+ * Name of the database used by IndexedDB.
+ *
+ * @type {string}
+ */
+ dbName;
+
+ /**
+ * Name of the object store for storing files.
+ *
+ * @type {string}
+ */
+ fileStoreName;
+
+ /**
+ * Name of the object store for storing headers.
+ *
+ * @type {string}
+ */
+ headersStoreName;
+ /**
+ * Maximum size of the cache in bytes. Defaults to 1GB.
+ *
+ * @type {number}
+ */
+ #maxSize = 1_073_741_824; // 1GB in bytes
+
+ /**
+ * Private constructor to prevent direct instantiation.
+ * Use IndexedDBCache.init to create an instance.
+ *
+ * @param {string} dbName - The name of the database file.
+ * @param {number} version - The version number of the database.
+ */
+ constructor(dbName = "modelFiles", version = 1) {
+ this.dbName = dbName;
+ this.dbVersion = version;
+ this.fileStoreName = "files";
+ this.headersStoreName = "headers";
+ }
+
+ /**
+ * Static method to create and initialize an instance of IndexedDBCache.
+ *
+ * @param {string} [dbName="modelFiles"] - The name of the database.
+ * @param {number} [version=1] - The version number of the database.
+ * @returns {Promise<IndexedDBCache>} An initialized instance of IndexedDBCache.
+ */
+ static async init(dbName = "modelFiles", version = 1) {
+ const cacheInstance = new IndexedDBCache(dbName, version);
+ cacheInstance.db = await cacheInstance.#openDB();
+ const storedSize = await cacheInstance.#getData(
+ cacheInstance.headersStoreName,
+ "totalSize"
+ );
+ cacheInstance.totalSize = storedSize ? storedSize.size : 0;
+ return cacheInstance;
+ }
+
+ /**
+ * Called to close the DB connection and dispose the instance
+ *
+ */
+ async dispose() {
+ if (this.db) {
+ this.db.close();
+ this.db = null;
+ }
+ }
+
+ /**
+ * Opens or creates the IndexedDB database.
+ *
+ * @returns {Promise<IDBDatabase>}
+ */
+ async #openDB() {
+ return new Promise((resolve, reject) => {
+ const request = indexedDB.open(this.dbName, this.dbVersion);
+ request.onerror = event => reject(event.target.error);
+ request.onsuccess = event => resolve(event.target.result);
+ request.onupgradeneeded = event => {
+ const db = event.target.result;
+ if (!db.objectStoreNames.contains(this.fileStoreName)) {
+ db.createObjectStore(this.fileStoreName, { keyPath: "id" });
+ }
+ if (!db.objectStoreNames.contains(this.headersStoreName)) {
+ db.createObjectStore(this.headersStoreName, { keyPath: "id" });
+ }
+ };
+ });
+ }
+
+ /**
+ * Generic method to get the data from a specified object store.
+ *
+ * @param {string} storeName - The name of the object store.
+ * @param {string} key - The key within the object store to retrieve the data from.
+ * @returns {Promise<any>}
+ */
+ async #getData(storeName, key) {
+ return new Promise((resolve, reject) => {
+ const transaction = this.db.transaction([storeName], "readonly");
+ const store = transaction.objectStore(storeName);
+ const request = store.get(key);
+ request.onerror = event => reject(event.target.error);
+ request.onsuccess = event => resolve(event.target.result);
+ });
+ }
+
+ // Used in tests
+ async _testGetData(storeName, key) {
+ return this.#getData(storeName, key);
+ }
+
+ /**
+ * Generic method to update data in a specified object store.
+ *
+ * @param {string} storeName - The name of the object store.
+ * @param {object} data - The data to store.
+ * @returns {Promise<void>}
+ */
+ async #updateData(storeName, data) {
+ return new Promise((resolve, reject) => {
+ const transaction = this.db.transaction([storeName], "readwrite");
+ const store = transaction.objectStore(storeName);
+ const request = store.put(data);
+ request.onerror = event => reject(event.target.error);
+ request.onsuccess = () => resolve();
+ });
+ }
+
+ /**
+ * Deletes a specific cache entry.
+ *
+ * @param {string} storeName - The name of the object store.
+ * @param {string} key - The key of the entry to delete.
+ * @returns {Promise<void>}
+ */
+ async #deleteData(storeName, key) {
+ return new Promise((resolve, reject) => {
+ const transaction = this.db.transaction([storeName], "readwrite");
+ const store = transaction.objectStore(storeName);
+ const request = store.delete(key);
+ request.onerror = event => reject(event.target.error);
+ request.onsuccess = () => resolve();
+ });
+ }
+
+ /**
+ * Retrieves the headers for a specific cache entry.
+ *
+ * @param {string} organization - The organization name.
+ * @param {string} modelName - The model name.
+ * @param {string} modelVersion - The model version.
+ * @param {string} file - The file name.
+ * @returns {Promise<object|null>} The headers or null if not found.
+ */
+ async getHeaders(organization, modelName, modelVersion, file) {
+ const headersKey = `${organization}/${modelName}/${modelVersion}`;
+ const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`;
+ const headers = await this.#getData(this.headersStoreName, headersKey);
+ if (headers && headers.files[cacheKey]) {
+ return headers.files[cacheKey];
+ }
+ return null; // Return null if no headers is found
+ }
+
+ /**
+ * Retrieves the file for a specific cache entry.
+ *
+ * @param {string} organization - The organization name.
+ * @param {string} modelName - The model name.
+ * @param {string} modelVersion - The model version.
+ * @param {string} file - The file name.
+ * @returns {Promise<[ArrayBuffer, object]|null>} The file ArrayBuffer and its headers or null if not found.
+ */
+ async getFile(organization, modelName, modelVersion, file) {
+ const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`;
+ const stored = await this.#getData(this.fileStoreName, cacheKey);
+ if (stored) {
+ const headers = await this.getHeaders(
+ organization,
+ modelName,
+ modelVersion,
+ file
+ );
+ return [stored.data, headers];
+ }
+ return null; // Return null if no file is found
+ }
+
+ /**
+ * Adds or updates a cache entry.
+ *
+ * @param {string} organization - The organization name.
+ * @param {string} modelName - The model name.
+ * @param {string} modelVersion - The model version.
+ * @param {string} file - The file name.
+ * @param {ArrayBuffer} arrayBuffer - The data to cache.
+ * @param {object} [headers] - The headers for the file.
+ * @returns {Promise<void>}
+ */
+ async put(
+ organization,
+ modelName,
+ modelVersion,
+ file,
+ arrayBuffer,
+ headers = {}
+ ) {
+ const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`;
+ const newSize = this.totalSize + arrayBuffer.byteLength;
+ if (newSize > this.#maxSize) {
+ throw new Error("Exceeding total cache size limit of 1GB");
+ }
+
+ const headersKey = `${organization}/${modelName}/${modelVersion}`;
+ const data = { id: cacheKey, data: arrayBuffer };
+
+ // Store the file data
+ await this.#updateData(this.fileStoreName, data);
+
+ // Update headers store - whith defaults for ETag and Content-Type
+ headers = headers || {};
+ headers["Content-Type"] =
+ headers["Content-Type"] ?? "application/octet-stream";
+ headers.ETag = headers.ETag ?? NO_ETAG;
+
+ // filter out any keys that are not allowed
+ headers = Object.keys(headers)
+ .filter(key => ALLOWED_HEADERS_KEYS.includes(key))
+ .reduce((obj, key) => {
+ obj[key] = headers[key];
+ return obj;
+ }, {});
+
+ const headersStore = (await this.#getData(
+ this.headersStoreName,
+ headersKey
+ )) || {
+ id: headersKey,
+ files: {},
+ };
+ headersStore.files[cacheKey] = headers;
+ await this.#updateData(this.headersStoreName, headersStore);
+
+ // Update size
+ await this.#updateTotalSize(arrayBuffer.byteLength);
+ }
+
+ /**
+ * Updates the total size of the cache.
+ *
+ * @param {number} sizeToAdd - The size to add to the total.
+ * @returns {Promise<void>}
+ */
+ async #updateTotalSize(sizeToAdd) {
+ this.totalSize += sizeToAdd;
+ await this.#updateData(this.headersStoreName, {
+ id: "totalSize",
+ size: this.totalSize,
+ });
+ }
+ /**
+ * Deletes all data related to a specific model.
+ *
+ * @param {string} organization - The organization name.
+ * @param {string} modelName - The model name.
+ * @param {string} modelVersion - The model version.
+ * @returns {Promise<void>}
+ */
+ async deleteModel(organization, modelName, modelVersion) {
+ const headersKey = `${organization}/${modelName}/${modelVersion}`;
+ const headers = await this.#getData(this.headersStoreName, headersKey);
+ if (headers) {
+ for (const fileKey in headers.files) {
+ await this.#deleteData(this.fileStoreName, fileKey);
+ }
+ await this.#deleteData(this.headersStoreName, headersKey); // Remove headers entry after files are deleted
+ }
+ }
+
+ /**
+ * Lists all models stored in the cache.
+ *
+ * @returns {Promise<Array<string>>} An array of model identifiers.
+ */
+ async listModels() {
+ const models = [];
+ return new Promise((resolve, reject) => {
+ const transaction = this.db.transaction(
+ [this.headersStoreName],
+ "readonly"
+ );
+ const store = transaction.objectStore(this.headersStoreName);
+ const request = store.openCursor();
+ request.onerror = event => reject(event.target.error);
+ request.onsuccess = event => {
+ const cursor = event.target.result;
+ if (cursor) {
+ models.push(cursor.value.id); // Assuming id is the organization/modelName
+ cursor.continue();
+ } else {
+ resolve(models);
+ }
+ };
+ });
+ }
+}
+
+export class ModelHub {
+ constructor({ rootUrl, urlTemplate = DEFAULT_URL_TEMPLATE }) {
+ if (!allowedHub(rootUrl)) {
+ throw new Error(`Invalid model hub root url: ${rootUrl}`);
+ }
+ this.rootUrl = rootUrl;
+ this.cache = null;
+
+ // Ensures the URL template is well-formed and does not contain any invalid characters.
+ const pattern = /^(?:\$\{\w+\}|\w+)(?:\/(?:\$\{\w+\}|\w+))*$/;
+ // ^ $ Start and end of string
+ // (?:\$\{\w+\}|\w+) Match a ${placeholder} or alphanumeric characters
+ // (?:\/(?:\$\{\w+\}|\w+))* Zero or more groups of a forward slash followed by a ${placeholder} or alphanumeric characters
+ if (!pattern.test(urlTemplate)) {
+ throw new Error(`Invalid URL template: ${urlTemplate}`);
+ }
+ this.urlTemplate = urlTemplate;
+ }
+
+ async #initCache() {
+ if (this.cache) {
+ return;
+ }
+ this.cache = await IndexedDBCache.init();
+ }
+
+ /** Creates the file URL from the organization, model, and version.
+ *
+ * @param {string} organization
+ * @param {string} modelName
+ * @param {string} modelVersion
+ * @param {string} file
+ * @returns {string} The full URL
+ */
+ #fileUrl(organization, modelName, modelVersion, file) {
+ const baseUrl = new URL(this.rootUrl);
+ if (!baseUrl.pathname.endsWith("/")) {
+ baseUrl.pathname += "/";
+ }
+
+ // Replace placeholders in the URL template with the provided data.
+ // If some keys are missing in the data object, the placeholder is left as is.
+ // If the placeholder is not found in the data object, it is left as is.
+ const data = {
+ organization,
+ modelName,
+ modelVersion,
+ file,
+ };
+ const path = this.urlTemplate.replace(
+ /\$\{(\w+)\}/g,
+ (match, key) => data[key] || match
+ );
+ const fullPath = `${baseUrl.pathname}${
+ path.startsWith("/") ? path.slice(1) : path
+ }`;
+
+ const urlObject = new URL(fullPath, baseUrl.origin);
+ urlObject.searchParams.append("download", "true");
+ return urlObject.toString();
+ }
+
+ /** Checks the organization, model, and version inputs.
+ *
+ * @param { string } organization
+ * @param { string } modelName
+ * @param { string } modelVersion
+ * @param { string } file
+ * @returns { Error } The error instance(can be null)
+ */
+ #checkInput(organization, modelName, modelVersion, file) {
+ // Ensures string consists only of letters, digits, and hyphens without starting/ending
+ // with a hyphen or containing consecutive hyphens.
+ //
+ // ^ $ Start and end of string
+ // (?!-) (?<!-) Negative lookahead/behind for not starting or ending with hyphen
+ // (?!.*--) Negative lookahead for not containing consecutive hyphens
+ // [A-Za-z0-9-]+ Alphanum characters or hyphens, one or more
+ const orgRegex = /^(?!-)(?!.*--)[A-Za-z0-9-]+(?<!-)$/;
+
+ // Matches strings containing letters, digits, hyphens, underscores, or periods.
+ // ^ $ Start and end of string
+ // [A-Za-z0-9-_.]+ Alphanum characters, hyphens, underscores, or periods, one or more times
+ const modelRegex = /^[A-Za-z0-9-_.]+$/;
+
+ // Matches strings consisting of alphanumeric characters, hyphens, or periods.
+ //
+ // ^ $ Start and end of string
+ // [A-Za-z0-9-.]+ Alphanum characters, hyphens, or periods, one or more times
+ const versionRegex = /^[A-Za-z0-9-.]+$/;
+
+ // Matches filenames with subdirectories, starting with alphanumeric or underscore,
+ // and optionally ending with a dot followed by a 2-4 letter extension.
+ //
+ // ^ $ Start and end of string
+ // (?:\/)? Optional leading slash (for absolute paths or root directory)
+ // (?!\/) Negative lookahead for not starting with a slash
+ // [A-Za-z0-9-_]+ First directory or filename
+ // (?: Begin non-capturing group for additional directories or file
+ // \/ Directory separator
+ // [A-Za-z0-9-_]+ Directory or file name
+ // )* Zero or more times
+ // (?:[.][A-Za-z]{2,4})? Optional non-capturing group for file extension
+ const fileRegex =
+ /^(?:\/)?(?!\/)[A-Za-z0-9-_]+(?:\/[A-Za-z0-9-_]+)*(?:[.][A-Za-z]{2,4})?$/;
+
+ if (!orgRegex.test(organization) || !isNaN(parseInt(organization))) {
+ return new Error(`Invalid organization name ${organization}`);
+ }
+
+ if (!modelRegex.test(modelName)) {
+ return new Error("Invalid model name.");
+ }
+
+ if (
+ !versionRegex.test(modelVersion) ||
+ modelVersion.includes(" ") ||
+ /[\^$]/.test(modelVersion)
+ ) {
+ return new Error("Invalid version identifier.");
+ }
+
+ if (!fileRegex.test(file)) {
+ return new Error("Invalid file name");
+ }
+
+ return null;
+ }
+
+ /**
+ * Returns the ETag value given an URL
+ *
+ * @param {string} url
+ * @param {number} timeout in ms. Default is 1000
+ * @returns {Promise<string>} ETag (can be null)
+ */
+ async #getETag(url, timeout = 1000) {
+ const controller = new AbortController();
+ const id = lazy.setTimeout(() => controller.abort(), timeout);
+
+ try {
+ const headResponse = await fetch(url, {
+ method: "HEAD",
+ signal: controller.signal,
+ });
+ const currentEtag = headResponse.headers.get("ETag");
+ return currentEtag;
+ } catch (error) {
+ lazy.console.warn("An error occurred when calling HEAD:", error);
+ return null;
+ } finally {
+ lazy.clearTimeout(id);
+ }
+ }
+
+ /**
+ * Given an organization, model, and version, fetch a model file in the hub as a Response.
+ *
+ * @param {object} config
+ * @param {string} config.organization
+ * @param {string} config.modelName
+ * @param {string} config.modelVersion
+ * @param {string} config.file
+ * @returns {Promise<Response>} The file content
+ */
+ async getModelFileAsResponse({
+ organization,
+ modelName,
+ modelVersion,
+ file,
+ }) {
+ const [blob, headers] = await this.getModelFileAsBlob({
+ organization,
+ modelName,
+ modelVersion,
+ file,
+ });
+ return new Response(blob, { headers });
+ }
+
+ /**
+ * Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer.
+ *
+ * @param {object} config
+ * @param {string} config.organization
+ * @param {string} config.modelName
+ * @param {string} config.modelVersion
+ * @param {string} config.file
+ * @returns {Promise<[ArrayBuffer, headers]>} The file content
+ */
+ async getModelFileAsArrayBuffer({
+ organization,
+ modelName,
+ modelVersion,
+ file,
+ }) {
+ const [blob, headers] = await this.getModelFileAsBlob({
+ organization,
+ modelName,
+ modelVersion,
+ file,
+ });
+ return [await blob.arrayBuffer(), headers];
+ }
+
+ /**
+ * Given an organization, model, and version, fetch a model file in the hub as blob.
+ *
+ * @param {object} config
+ * @param {string} config.organization
+ * @param {string} config.modelName
+ * @param {string} config.modelVersion
+ * @param {string} config.file
+ * @returns {Promise<[Blob, object]>} The file content
+ */
+ async getModelFileAsBlob({ organization, modelName, modelVersion, file }) {
+ // Make sure inputs are clean. We don't sanitize them but throw an exception
+ let checkError = this.#checkInput(
+ organization,
+ modelName,
+ modelVersion,
+ file
+ );
+ if (checkError) {
+ throw checkError;
+ }
+
+ const url = this.#fileUrl(organization, modelName, modelVersion, file);
+ lazy.console.debug(`Getting model file from ${url}`);
+
+ await this.#initCache();
+
+ // this can be null if no ETag was found or there were a network error
+ const hubETag = await this.#getETag(url);
+
+ lazy.console.debug(
+ `Checking the cache for ${organization}/${modelName}/${modelVersion}/${file}`
+ );
+
+ // storage lookup
+ const cachedHeaders = await this.cache.getHeaders(
+ organization,
+ modelName,
+ modelVersion,
+ file
+ );
+ const cachedEtag = cachedHeaders ? cachedHeaders.ETag : null;
+
+ // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response
+ if (cachedEtag !== null && (hubETag === null || cachedEtag === hubETag)) {
+ lazy.console.debug(`Cache Hit`);
+ return await this.cache.getFile(
+ organization,
+ modelName,
+ modelVersion,
+ file
+ );
+ }
+
+ lazy.console.debug(`Fetching ${url}`);
+ try {
+ const response = await fetch(url);
+ if (response.ok) {
+ const clone = response.clone();
+ const headers = {
+ // We don't store the boundary or the charset, just the content type,
+ // so we drop what's after the semicolon.
+ "Content-Type": response.headers.get("Content-Type").split(";")[0],
+ ETag: hubETag,
+ };
+
+ await this.cache.put(
+ organization,
+ modelName,
+ modelVersion,
+ file,
+ await clone.blob(),
+ headers
+ );
+ return [await response.blob(), headers];
+ }
+ } catch (error) {
+ lazy.console.error(`Failed to fetch ${url}:`, error);
+ }
+
+ throw new Error(`Failed to fetch the model file: ${url}`);
+ }
+}
diff --git a/toolkit/components/ml/jar.mn b/toolkit/components/ml/jar.mn
index 56bfb0d469..c9c82e0b26 100644
--- a/toolkit/components/ml/jar.mn
+++ b/toolkit/components/ml/jar.mn
@@ -6,4 +6,5 @@ toolkit.jar:
content/global/ml/EngineProcess.sys.mjs (content/EngineProcess.sys.mjs)
content/global/ml/MLEngine.worker.mjs (content/MLEngine.worker.mjs)
content/global/ml/MLEngine.html (content/MLEngine.html)
+ content/global/ml/ModelHub.sys.mjs (content/ModelHub.sys.mjs)
content/global/ml/SummarizerModel.sys.mjs (content/SummarizerModel.sys.mjs)
diff --git a/toolkit/components/ml/tests/browser/browser.toml b/toolkit/components/ml/tests/browser/browser.toml
index 9ccda0beaa..57637c8bda 100644
--- a/toolkit/components/ml/tests/browser/browser.toml
+++ b/toolkit/components/ml/tests/browser/browser.toml
@@ -1,5 +1,9 @@
[DEFAULT]
support-files = [
"head.js",
+ "data/**/*.*"
]
+
+["browser_ml_cache.js"]
+
["browser_ml_engine.js"]
diff --git a/toolkit/components/ml/tests/browser/browser_ml_cache.js b/toolkit/components/ml/tests/browser/browser_ml_cache.js
new file mode 100644
index 0000000000..d8725368bd
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/browser_ml_cache.js
@@ -0,0 +1,361 @@
+/* Any copyright is dedicated to the Public Domain.
+http://creativecommons.org/publicdomain/zero/1.0/ */
+"use strict";
+
+const { sinon } = ChromeUtils.importESModule(
+ "resource://testing-common/Sinon.sys.mjs"
+);
+
+// Root URL of the fake hub, see the `data` dir in the tests.
+const FAKE_HUB =
+ "chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data";
+
+const FAKE_MODEL_ARGS = {
+ organization: "acme",
+ modelName: "bert",
+ modelVersion: "main",
+ file: "config.json",
+};
+
+const FAKE_ONNX_MODEL_ARGS = {
+ organization: "acme",
+ modelName: "bert",
+ modelVersion: "main",
+ file: "onnx/config.json",
+};
+
+const badHubs = [
+ "https://my.cool.hub",
+ "https://sub.localhost/myhub", // Subdomain of allowed domain
+ "https://model-hub.mozilla.org.evil.com", // Manipulating path to mimic domain
+ "httpsz://localhost/myhub", // Similar-looking scheme
+ "https://localhost.", // Trailing dot in domain
+ "resource://user@localhost", // User info in URL
+ "ftp://localhost/myhub", // Disallowed scheme with allowed host
+ "https://model-hub.mozilla.org.hack", // Domain that contains allowed domain
+];
+
+add_task(async function test_bad_hubs() {
+ for (const badHub of badHubs) {
+ Assert.throws(
+ () => new ModelHub({ rootUrl: badHub }),
+ new RegExp(`Error: Invalid model hub root url: ${badHub}`),
+ `Should throw with ${badHub}`
+ );
+ }
+});
+
+let goodHubs = [
+ "https:///localhost/myhub", // Triple slashes, see https://stackoverflow.com/a/22775589
+ "https://localhost:8080/myhub",
+ "http://localhost/myhub",
+ "https://model-hub.mozilla.org",
+ "chrome://gre/somewhere/in/the/code/base",
+];
+
+add_task(async function test_allowed_hub() {
+ goodHubs.forEach(url => new ModelHub({ rootUrl: url }));
+});
+
+const badInputs = [
+ [
+ {
+ organization: "ac me",
+ modelName: "bert",
+ modelVersion: "main",
+ file: "config.json",
+ },
+ "Org can only contain letters, numbers, and hyphens",
+ ],
+ [
+ {
+ organization: "1111",
+ modelName: "bert",
+ modelVersion: "main",
+ file: "config.json",
+ },
+ "Org cannot contain only numbers",
+ ],
+ [
+ {
+ organization: "-acme",
+ modelName: "bert",
+ modelVersion: "main",
+ file: "config.json",
+ },
+ "Org start or end with a hyphen, or use consecutive hyphens",
+ ],
+ [
+ {
+ organization: "a-c-m-e",
+ modelName: "#bert",
+ modelVersion: "main",
+ file: "config.json",
+ },
+ "Models can only contain letters, numbers, and hyphens, underscord, periods",
+ ],
+ [
+ {
+ organization: "a-c-m-e",
+ modelName: "b$ert",
+ modelVersion: "main",
+ file: "config.json",
+ },
+ "Models cannot contain spaces or control characters",
+ ],
+ [
+ {
+ organization: "a-c-m-e",
+ modelName: "b$ert",
+ modelVersion: "main",
+ file: ".filename",
+ },
+ "File",
+ ],
+];
+
+add_task(async function test_bad_inputs() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+
+ for (const badInput of badInputs) {
+ const params = badInput[0];
+ const errorMsg = badInput[1];
+ try {
+ await hub.getModelFileAsArrayBuffer(params);
+ } catch (error) {
+ continue;
+ }
+ throw new Error(errorMsg);
+ }
+});
+
+add_task(async function test_getting_file() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+
+ let [array, headers] = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
+
+ Assert.equal(headers["Content-Type"], "application/json");
+
+ // check the content of the file.
+ let jsonData = JSON.parse(
+ String.fromCharCode.apply(null, new Uint8Array(array))
+ );
+
+ Assert.equal(jsonData.hidden_size, 768);
+});
+
+add_task(async function test_getting_file_in_subdir() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+
+ let [array, metadata] = await hub.getModelFileAsArrayBuffer(
+ FAKE_ONNX_MODEL_ARGS
+ );
+
+ Assert.equal(metadata["Content-Type"], "application/json");
+
+ // check the content of the file.
+ let jsonData = JSON.parse(
+ String.fromCharCode.apply(null, new Uint8Array(array))
+ );
+
+ Assert.equal(jsonData.hidden_size, 768);
+});
+
+add_task(async function test_getting_file_custom_path() {
+ const hub = new ModelHub({
+ rootUrl: FAKE_HUB,
+ urlTemplate: "${organization}/${modelName}/resolve/${modelVersion}/${file}",
+ });
+
+ let res = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
+
+ Assert.equal(res[1]["Content-Type"], "application/json");
+});
+
+add_task(async function test_getting_file_custom_path_rogue() {
+ const urlTemplate =
+ "${organization}/${modelName}/resolve/${modelVersion}/${file}?some_id=bedqwdw";
+ Assert.throws(
+ () => new ModelHub({ rootUrl: FAKE_HUB, urlTemplate }),
+ /Invalid URL template/,
+ `Should throw with ${urlTemplate}`
+ );
+});
+
+add_task(async function test_getting_file_as_response() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+
+ let response = await hub.getModelFileAsResponse(FAKE_MODEL_ARGS);
+
+ // check the content of the file.
+ let jsonData = await response.json();
+ Assert.equal(jsonData.hidden_size, 768);
+});
+
+add_task(async function test_getting_file_from_cache() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+ let array = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
+
+ // stub to verify that the data was retrieved from IndexDB
+ let matchMethod = hub.cache._testGetData;
+
+ sinon.stub(hub.cache, "_testGetData").callsFake(function () {
+ return matchMethod.apply(this, arguments).then(result => {
+ Assert.notEqual(result, null);
+ return result;
+ });
+ });
+
+ // exercises the cache
+ let array2 = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
+ hub.cache._testGetData.restore();
+
+ Assert.deepEqual(array, array2);
+});
+
+// IndexedDB tests
+
+/**
+ * Helper function to initialize the cache
+ */
+async function initializeCache() {
+ const randomSuffix = Math.floor(Math.random() * 10000);
+ return await IndexedDBCache.init(`modelFiles-${randomSuffix}`);
+}
+
+/**
+ * Helper function to delete the cache database
+ */
+async function deleteCache(cache) {
+ await cache.dispose();
+ indexedDB.deleteDatabase(cache.dbName);
+}
+
+/**
+ * Test the initialization and creation of the IndexedDBCache instance.
+ */
+add_task(async function test_Init() {
+ const cache = await initializeCache();
+ Assert.ok(
+ cache instanceof IndexedDBCache,
+ "The cache instance should be created successfully."
+ );
+ Assert.ok(
+ IDBDatabase.isInstance(cache.db),
+ `The cache should have an IDBDatabase instance. Found ${cache.db}`
+ );
+ await deleteCache(cache);
+});
+
+/**
+ * Test adding data to the cache and retrieving it.
+ */
+add_task(async function test_PutAndGet() {
+ const cache = await initializeCache();
+ const testData = new ArrayBuffer(8); // Example data
+ await cache.put("org", "model", "v1", "file.txt", testData, {
+ ETag: "ETAG123",
+ });
+
+ const [retrievedData, headers] = await cache.getFile(
+ "org",
+ "model",
+ "v1",
+ "file.txt"
+ );
+ Assert.deepEqual(
+ retrievedData,
+ testData,
+ "The retrieved data should match the stored data."
+ );
+ Assert.equal(
+ headers.ETag,
+ "ETAG123",
+ "The retrieved ETag should match the stored ETag."
+ );
+
+ await deleteCache(cache);
+});
+
+/**
+ * Test retrieving the headers for a cache entry.
+ */
+add_task(async function test_GetHeaders() {
+ const cache = await initializeCache();
+ const testData = new ArrayBuffer(8);
+ const headers = {
+ ETag: "ETAG123",
+ status: 200,
+ extra: "extra",
+ };
+
+ await cache.put("org", "model", "v1", "file.txt", testData, headers);
+
+ const storedHeaders = await cache.getHeaders(
+ "org",
+ "model",
+ "v1",
+ "file.txt"
+ );
+
+ // The `extra` field should be removed from the stored headers because
+ // it's not part of the allowed keys.
+ // The content-type one is added when not present
+ Assert.deepEqual(
+ {
+ ETag: "ETAG123",
+ status: 200,
+ "Content-Type": "application/octet-stream",
+ },
+ storedHeaders,
+ "The retrieved headers should match the stored headers."
+ );
+ await deleteCache(cache);
+});
+
+/**
+ * Test listing all models stored in the cache.
+ */
+add_task(async function test_ListModels() {
+ const cache = await initializeCache();
+ await cache.put(
+ "org1",
+ "modelA",
+ "v1",
+ "file1.txt",
+ new ArrayBuffer(8),
+ null
+ );
+ await cache.put(
+ "org2",
+ "modelB",
+ "v1",
+ "file2.txt",
+ new ArrayBuffer(8),
+ null
+ );
+
+ const models = await cache.listModels();
+ Assert.ok(
+ models.includes("org1/modelA/v1") && models.includes("org2/modelB/v1"),
+ "All models should be listed."
+ );
+ await deleteCache(cache);
+});
+
+/**
+ * Test deleting a model and its data from the cache.
+ */
+add_task(async function test_DeleteModel() {
+ const cache = await initializeCache();
+ await cache.put("org", "model", "v1", "file.txt", new ArrayBuffer(8), null);
+ await cache.deleteModel("org", "model", "v1");
+
+ const dataAfterDelete = await cache.getFile("org", "model", "v1", "file.txt");
+ Assert.equal(
+ dataAfterDelete,
+ null,
+ "The data for the deleted model should not exist."
+ );
+ await deleteCache(cache);
+});
diff --git a/toolkit/components/ml/tests/browser/data/README.md b/toolkit/components/ml/tests/browser/data/README.md
new file mode 100644
index 0000000000..d826cf7ee6
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/data/README.md
@@ -0,0 +1,5 @@
+# fake hub
+
+This directory is a fake hub that is served via chrome://global/content/ml/tests.
+
+All files in this directory are included with a wildcard in the component `jar.md`.
diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json
new file mode 100644
index 0000000000..50dbb760bb
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/config.json
@@ -0,0 +1,21 @@
+{
+ "architectures": ["BertForMaskedLM"],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
diff --git a/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json
new file mode 100644
index 0000000000..50dbb760bb
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/data/acme/bert/resolve/main/onnx/config.json
@@ -0,0 +1,21 @@
+{
+ "architectures": ["BertForMaskedLM"],
+ "attention_probs_dropout_prob": 0.1,
+ "gradient_checkpointing": false,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "transformers_version": "4.6.0.dev0",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 30522
+}
diff --git a/toolkit/components/ml/tests/browser/head.js b/toolkit/components/ml/tests/browser/head.js
index 99d27ce18a..9fc20c0e84 100644
--- a/toolkit/components/ml/tests/browser/head.js
+++ b/toolkit/components/ml/tests/browser/head.js
@@ -19,6 +19,10 @@ const { MLEngineParent } = ChromeUtils.importESModule(
"resource://gre/actors/MLEngineParent.sys.mjs"
);
+const { ModelHub, IndexedDBCache } = ChromeUtils.importESModule(
+ "chrome://global/content/ml/ModelHub.sys.mjs"
+);
+
// This test suite shares some utility functions with translations as they work in a very
// similar fashion. Eventually, the plan is to unify these two components.
Services.scriptloader.loadSubScript(