summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/content
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/ml/content')
-rw-r--r--toolkit/components/ml/content/EngineProcess.sys.mjs188
-rw-r--r--toolkit/components/ml/content/MLEngine.worker.mjs98
-rw-r--r--toolkit/components/ml/content/ModelHub.sys.mjs317
-rw-r--r--toolkit/components/ml/content/ONNXPipeline.mjs297
-rw-r--r--toolkit/components/ml/content/SummarizerModel.sys.mjs160
-rw-r--r--toolkit/components/ml/content/Utils.sys.mjs77
6 files changed, 798 insertions, 339 deletions
diff --git a/toolkit/components/ml/content/EngineProcess.sys.mjs b/toolkit/components/ml/content/EngineProcess.sys.mjs
index 36a9381192..0fe6403cc8 100644
--- a/toolkit/components/ml/content/EngineProcess.sys.mjs
+++ b/toolkit/components/ml/content/EngineProcess.sys.mjs
@@ -2,10 +2,17 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+// known to be loaded early in the startup process, and should be loaded eagerly
+import { AppConstants } from "resource://gre/modules/AppConstants.sys.mjs";
+
const lazy = {};
-ChromeUtils.defineESModuleGetters(lazy, {
- HiddenFrame: "resource://gre/modules/HiddenFrame.sys.mjs",
-});
+ChromeUtils.defineESModuleGetters(
+ lazy,
+ {
+ HiddenFrame: "resource://gre/modules/HiddenFrame.sys.mjs",
+ },
+ { global: "current" }
+);
/**
* @typedef {import("../actors/MLEngineParent.sys.mjs").MLEngineParent} MLEngineParent
@@ -16,6 +23,172 @@ ChromeUtils.defineESModuleGetters(lazy, {
*/
/**
+ * This class encapsulates the options for a pipeline process.
+ */
+export class PipelineOptions {
+ /**
+ * The name of the task the pipeline is configured for.
+ *
+ * @type {?string}
+ */
+ taskName = null;
+
+ /**
+ * The maximum amount of time in milliseconds the pipeline should wait for a response.
+ *
+ * @type {?number}
+ */
+ timeoutMS = null;
+
+ /**
+ * The root URL of the model hub where models are hosted.
+ *
+ * @type {?string}
+ */
+ modelHubRootUrl = null;
+
+ /**
+ * A template URL for building the full URL for the model.
+ *
+ * @type {?string}
+ */
+ modelHubUrlTemplate = null;
+
+ /**
+ * The identifier for the specific model to be used by the pipeline.
+ *
+ * @type {?string}
+ */
+ modelId = null;
+
+ /**
+ * The revision for the specific model to be used by the pipeline.
+ *
+ * @type {?string}
+ */
+ modelRevision = null;
+
+ /**
+ * The identifier for the tokenizer associated with the model, used for pre-processing inputs.
+ *
+ * @type {?string}
+ */
+ tokenizerId = null;
+
+ /**
+ * The revision for the tokenizer associated with the model, used for pre-processing inputs.
+ *
+ * @type {?string}
+ */
+ tokenizerRevision = null;
+
+ /**
+ * The identifier for any processor required by the model, used for additional input processing.
+ *
+ * @type {?string}
+ */
+ processorId = null;
+
+ /**
+ * The revision for any processor required by the model, used for additional input processing.
+ *
+ * @type {?string}
+ */
+
+ processorRevision = null;
+
+ /**
+ * The log level used in the worker
+ *
+ * @type {?string}
+ */
+ logLevel = null;
+
+ /**
+ * Name of the runtime wasm file
+ *
+ * @type {?string}
+ */
+ runtimeFilename = null;
+
+ /**
+ * Create a PipelineOptions instance.
+ *
+ * @param {object} options - The options for the pipeline. Must include mandatory fields.
+ */
+ constructor(options) {
+ this.updateOptions(options);
+ }
+
+ /**
+ * Updates multiple options at once.
+ *
+ * @param {object} options - An object containing the options to update.
+ * @throws {Error} Throws an error if an invalid option is provided.
+ */
+ updateOptions(options) {
+ const allowedKeys = [
+ "taskName",
+ "modelHubRootUrl",
+ "modelHubUrlTemplate",
+ "timeoutMS",
+ "modelId",
+ "modelRevision",
+ "tokenizerId",
+ "tokenizerRevision",
+ "processorId",
+ "processorRevision",
+ "logLevel",
+ "runtimeFilename",
+ ];
+
+ Object.keys(options).forEach(key => {
+ if (allowedKeys.includes(key)) {
+ this[key] = options[key]; // Use bracket notation to access setter
+ } else {
+ throw new Error(`Invalid option: ${key}`);
+ }
+ });
+ }
+
+ /**
+ * Returns an object containing all current options.
+
+ * @returns {object} An object with the current options.
+ */
+ getOptions() {
+ return {
+ taskName: this.taskName,
+ modelHubRootUrl: this.modelHubRootUrl,
+ modelHubUrlTemplate: this.modelHubUrlTemplate,
+ timeoutMS: this.timeoutMS,
+ modelId: this.modelId,
+ modelRevision: this.modelRevision,
+ tokenizerId: this.tokenizerId,
+ tokenizerRevision: this.tokenizerRevision,
+ processorId: this.processorId,
+ processorRevision: this.processorRevision,
+ logLevel: this.logLevel,
+ runtimeFilename: this.runtimeFilename,
+ };
+ }
+
+ /**
+ * Updates the given configuration object with the options.
+ *
+ * @param {object} config - The configuration object to be updated.
+ */
+ applyToConfig(config) {
+ const options = this.getOptions();
+ Object.keys(options).forEach(key => {
+ if (options[key] !== null) {
+ config[key] = options[key];
+ }
+ });
+ }
+}
+
+/**
* This class controls the life cycle of the engine process used both in the
* Translations engine and the MLEngine component.
*/
@@ -68,6 +241,15 @@ export class EngineProcess {
* @returns {Promise<MLEngineParent>}
*/
static async getMLEngineParent() {
+ // Bug 1890946 - enable the inference engine in release
+ if (!AppConstants.NIGHTLY_BUILD) {
+ throw new Error("MLEngine is only available in Nightly builds.");
+ }
+ // the pref is off by default
+ if (!Services.prefs.getBoolPref("browser.ml.enable")) {
+ throw new Error("MLEngine is disabled. Check the browser.ml prefs.");
+ }
+
if (!this.mlEngineParent) {
this.mlEngineParent = this.#attachBrowser({
id: "ml-engine-browser",
diff --git a/toolkit/components/ml/content/MLEngine.worker.mjs b/toolkit/components/ml/content/MLEngine.worker.mjs
index 1013977e07..585ac4ab04 100644
--- a/toolkit/components/ml/content/MLEngine.worker.mjs
+++ b/toolkit/components/ml/content/MLEngine.worker.mjs
@@ -2,74 +2,99 @@
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
-import { PromiseWorker } from "resource://gre/modules/workers/PromiseWorker.mjs";
+const lazy = {};
-// Respect the preference "browser.ml.logLevel".
-let _loggingLevel = "Error";
-function log(...args) {
- if (_loggingLevel !== "Error" && _loggingLevel !== "Warn") {
- console.log("ML:", ...args);
- }
-}
-function trace(...args) {
- if (_loggingLevel === "Trace" || _loggingLevel === "All") {
- console.log("ML:", ...args);
- }
-}
+ChromeUtils.defineESModuleGetters(
+ lazy,
+ {
+ PromiseWorker: "resource://gre/modules/workers/PromiseWorker.mjs",
+ Pipeline: "chrome://global/content/ml/ONNXPipeline.mjs",
+ PipelineOptions: "chrome://global/content/ml/EngineProcess.sys.mjs",
+ },
+ { global: "current" }
+);
/**
* The actual MLEngine lives here in a worker.
*/
class MLEngineWorker {
- /** @type {ArrayBuffer} */
- #wasm;
- /** @type {ArrayBuffer} */
- #model;
+ #pipeline;
constructor() {
// Connect the provider to the worker.
this.#connectToPromiseWorker();
}
+ /** Implements the `match` function from the Cache API for Transformers.js custom cache.
+ *
+ * See https://developer.mozilla.org/en-US/docs/Web/API/Cache
+ *
+ * Attempts to match and retrieve a model file based on a provided key.
+ * Fetches a model file by delegating the call to the worker's main thread.
+ * Then wraps the fetched model file into a response object compatible with Transformers.js expectations.
+ *
+ * @param {string} key The unique identifier for the model to fetch.
+ * @returns {Promise<Response|null>} A promise that resolves with a Response object containing the model file or null if not found.
+ */
+ async match(key) {
+ let res = await this.getModelFile(key);
+ if (res.fail) {
+ return null;
+ }
+ let headers = res.ok[1];
+ let modelFile = res.ok[2];
+ // Transformers.js expects a response object, so we wrap the array buffer
+ const response = new Response(modelFile, {
+ status: 200,
+ headers,
+ });
+ return response;
+ }
+
+ async getModelFile(...args) {
+ let result = await self.callMainThread("getModelFile", args);
+ return result;
+ }
+
/**
- * @param {ArrayBuffer} wasm
- * @param {ArrayBuffer} model
- * @param {string} loggingLevel
+ * Placeholder for the `put` method from the Cache API for Transformers.js custom cache.
+ *
+ * @throws {Error} Always thrown to indicate the method is not implemented.
*/
- initializeEngine(wasm, model, loggingLevel) {
- this.#wasm = wasm;
- this.#model = model;
- _loggingLevel = loggingLevel;
- // TODO - Initialize the engine for real here.
- log("MLEngineWorker is initalized");
+ put() {
+ throw new Error("Method not implemented.");
}
/**
+ * @param {ArrayBuffer} wasm
+ * @param {object} options received as an object, converted to a PipelineOptions instance
+ */
+ async initializeEngine(wasm, options) {
+ this.#pipeline = await lazy.Pipeline.initialize(
+ this,
+ wasm,
+ new lazy.PipelineOptions(options)
+ );
+ }
+ /**
* Run the worker.
*
* @param {string} request
*/
- run(request) {
- if (!this.#wasm) {
- throw new Error("Expected the wasm to exist.");
- }
- if (!this.#model) {
- throw new Error("Expected the model to exist");
- }
+ async run(request) {
if (request === "throw") {
throw new Error(
'Received the message "throw", so intentionally throwing an error.'
);
}
- trace("inference run requested with:", request);
- return request.slice(0, Math.floor(request.length / 2));
+ return await this.#pipeline.run(request);
}
/**
* Glue code to connect the `MLEngineWorker` to the PromiseWorker interface.
*/
#connectToPromiseWorker() {
- const worker = new PromiseWorker.AbstractWorker();
+ const worker = new lazy.PromiseWorker.AbstractWorker();
worker.dispatch = (method, args = []) => {
if (!this[method]) {
throw new Error("Method does not exist: " + method);
@@ -81,6 +106,7 @@ class MLEngineWorker {
self.postMessage(message, ...transfers);
};
+ self.callMainThread = worker.callMainThread.bind(worker);
self.addEventListener("message", msg => worker.handleMessage(msg));
self.addEventListener("unhandledrejection", function (error) {
throw error.reason;
diff --git a/toolkit/components/ml/content/ModelHub.sys.mjs b/toolkit/components/ml/content/ModelHub.sys.mjs
index 4c2181ff14..10f83c3000 100644
--- a/toolkit/components/ml/content/ModelHub.sys.mjs
+++ b/toolkit/components/ml/content/ModelHub.sys.mjs
@@ -24,8 +24,7 @@ const ALLOWED_HUBS = [
];
const ALLOWED_HEADERS_KEYS = ["Content-Type", "ETag", "status"];
-const DEFAULT_URL_TEMPLATE =
- "${organization}/${modelName}/resolve/${modelVersion}/${file}";
+const DEFAULT_URL_TEMPLATE = "{model}/resolve/{revision}";
/**
* Checks if a given URL string corresponds to an allowed hub.
@@ -239,17 +238,36 @@ export class IndexedDBCache {
}
/**
+ * Checks if a specified model file exists in storage.
+ *
+ * @param {string} model - The model name (organization/name)
+ * @param {string} revision - The model revision.
+ * @param {string} file - The file name.
+ * @returns {Promise<boolean>} A promise that resolves with `true` if the key exists, otherwise `false`.
+ */
+ async fileExists(model, revision, file) {
+ const storeName = this.fileStoreName;
+ const cacheKey = `${model}/${revision}/${file}`;
+ return new Promise((resolve, reject) => {
+ const transaction = this.db.transaction([storeName], "readonly");
+ const store = transaction.objectStore(storeName);
+ const request = store.getKey(cacheKey);
+ request.onerror = event => reject(event.target.error);
+ request.onsuccess = event => resolve(event.target.result !== undefined);
+ });
+ }
+
+ /**
* Retrieves the headers for a specific cache entry.
*
- * @param {string} organization - The organization name.
- * @param {string} modelName - The model name.
- * @param {string} modelVersion - The model version.
+ * @param {string} model - The model name (organization/name)
+ * @param {string} revision - The model revision.
* @param {string} file - The file name.
* @returns {Promise<object|null>} The headers or null if not found.
*/
- async getHeaders(organization, modelName, modelVersion, file) {
- const headersKey = `${organization}/${modelName}/${modelVersion}`;
- const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`;
+ async getHeaders(model, revision, file) {
+ const headersKey = `${model}/${revision}`;
+ const cacheKey = `${model}/${revision}/${file}`;
const headers = await this.#getData(this.headersStoreName, headersKey);
if (headers && headers.files[cacheKey]) {
return headers.files[cacheKey];
@@ -260,22 +278,16 @@ export class IndexedDBCache {
/**
* Retrieves the file for a specific cache entry.
*
- * @param {string} organization - The organization name.
- * @param {string} modelName - The model name.
- * @param {string} modelVersion - The model version.
+ * @param {string} model - The model name (organization/name).
+ * @param {string} revision - The model version.
* @param {string} file - The file name.
* @returns {Promise<[ArrayBuffer, object]|null>} The file ArrayBuffer and its headers or null if not found.
*/
- async getFile(organization, modelName, modelVersion, file) {
- const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`;
+ async getFile(model, revision, file) {
+ const cacheKey = `${model}/${revision}/${file}`;
const stored = await this.#getData(this.fileStoreName, cacheKey);
if (stored) {
- const headers = await this.getHeaders(
- organization,
- modelName,
- modelVersion,
- file
- );
+ const headers = await this.getHeaders(model, revision, file);
return [stored.data, headers];
}
return null; // Return null if no file is found
@@ -284,29 +296,21 @@ export class IndexedDBCache {
/**
* Adds or updates a cache entry.
*
- * @param {string} organization - The organization name.
- * @param {string} modelName - The model name.
- * @param {string} modelVersion - The model version.
+ * @param {string} model - The model name (organization/name).
+ * @param {string} revision - The model version.
* @param {string} file - The file name.
* @param {ArrayBuffer} arrayBuffer - The data to cache.
* @param {object} [headers] - The headers for the file.
* @returns {Promise<void>}
*/
- async put(
- organization,
- modelName,
- modelVersion,
- file,
- arrayBuffer,
- headers = {}
- ) {
- const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`;
+ async put(model, revision, file, arrayBuffer, headers = {}) {
+ const cacheKey = `${model}/${revision}/${file}`;
const newSize = this.totalSize + arrayBuffer.byteLength;
if (newSize > this.#maxSize) {
throw new Error("Exceeding total cache size limit of 1GB");
}
- const headersKey = `${organization}/${modelName}/${modelVersion}`;
+ const headersKey = `${model}/${revision}`;
const data = { id: cacheKey, data: arrayBuffer };
// Store the file data
@@ -356,13 +360,12 @@ export class IndexedDBCache {
/**
* Deletes all data related to a specific model.
*
- * @param {string} organization - The organization name.
- * @param {string} modelName - The model name.
- * @param {string} modelVersion - The model version.
+ * @param {string} model - The model name (organization/name).
+ * @param {string} revision - The model version.
* @returns {Promise<void>}
*/
- async deleteModel(organization, modelName, modelVersion) {
- const headersKey = `${organization}/${modelName}/${modelVersion}`;
+ async deleteModel(model, revision) {
+ const headersKey = `${model}/${revision}`;
const headers = await this.#getData(this.headersStoreName, headersKey);
if (headers) {
for (const fileKey in headers.files) {
@@ -409,9 +412,9 @@ export class ModelHub {
this.cache = null;
// Ensures the URL template is well-formed and does not contain any invalid characters.
- const pattern = /^(?:\$\{\w+\}|\w+)(?:\/(?:\$\{\w+\}|\w+))*$/;
+ const pattern = /^(?:\{\w+\}|\w+)(?:\/(?:\{\w+\}|\w+))*$/;
// ^ $ Start and end of string
- // (?:\$\{\w+\}|\w+) Match a ${placeholder} or alphanumeric characters
+ // (?:\{\w+\}|\w+) Match a {placeholder} or alphanumeric characters
// (?:\/(?:\$\{\w+\}|\w+))* Zero or more groups of a forward slash followed by a ${placeholder} or alphanumeric characters
if (!pattern.test(urlTemplate)) {
throw new Error(`Invalid URL template: ${urlTemplate}`);
@@ -426,33 +429,94 @@ export class ModelHub {
this.cache = await IndexedDBCache.init();
}
+ /**
+ * This method takes a model URL and parses it to extract the
+ * model name, optional model version, and file path.
+ *
+ * The expected URL format are :
+ *
+ * `/organization/model/revision/filePath`
+ * `https://hub/organization/model/revision/filePath`
+ *
+ * @param {string} url - The full URL to the model, including protocol and domain - or the relative path.
+ * @returns {object} An object containing the parsed components of the URL. The
+ * object has properties `model`, and `file`,
+ * and optionally `revision` if the URL includes a version.
+ * @throws {Error} Throws an error if the URL does not start with `this.rootUrl` or
+ * if the URL format does not match the expected structure.
+ *
+ * @example
+ * // For a URL
+ * parseModelUrl("https://example.com/org1/model1/v1/file/path");
+ * // returns { model: "org1/model1", revision: "v1", file: "file/path" }
+ *
+ * @example
+ * // For a relative URL
+ * parseModelUrl("/org1/model1/revision/file/path");
+ * // returns { model: "org1/model1", revision: "v1", file: "file/path" }
+ */
+ parseUrl(url) {
+ let parts;
+ if (url.startsWith("/")) {
+ // relative URL
+ parts = url.slice(1).split("/");
+ } else {
+ // absolute URL
+ if (!url.startsWith(this.rootUrl)) {
+ throw new Error(`Invalid domain for model URL: ${url}`);
+ }
+ const urlObject = new URL(url);
+ const rootUrlObject = new URL(this.rootUrl);
+
+ // Remove the root URL's pathname from the full URL's pathname
+ const relativePath = urlObject.pathname.substring(
+ rootUrlObject.pathname.length
+ );
+
+ parts = relativePath.slice(1).split("/");
+ }
+
+ if (parts.length < 3) {
+ throw new Error(`Invalid model URL: ${url}`);
+ }
+
+ const file = parts.slice(3).join("/");
+ if (file == null || !file.length) {
+ throw new Error(`Invalid model URL: ${url}`);
+ }
+
+ return {
+ model: `${parts[0]}/${parts[1]}`,
+ revision: parts[2],
+ file,
+ };
+ }
+
/** Creates the file URL from the organization, model, and version.
*
- * @param {string} organization
- * @param {string} modelName
- * @param {string} modelVersion
+ * @param {string} model
+ * @param {string} revision
* @param {string} file
* @returns {string} The full URL
*/
- #fileUrl(organization, modelName, modelVersion, file) {
+ #fileUrl(model, revision, file) {
const baseUrl = new URL(this.rootUrl);
if (!baseUrl.pathname.endsWith("/")) {
baseUrl.pathname += "/";
}
-
// Replace placeholders in the URL template with the provided data.
// If some keys are missing in the data object, the placeholder is left as is.
// If the placeholder is not found in the data object, it is left as is.
const data = {
- organization,
- modelName,
- modelVersion,
- file,
+ model,
+ revision,
};
- const path = this.urlTemplate.replace(
- /\$\{(\w+)\}/g,
+ let path = this.urlTemplate.replace(
+ /\{(\w+)\}/g,
(match, key) => data[key] || match
);
+ path = `${path}/${file}`;
+
const fullPath = `${baseUrl.pathname}${
path.startsWith("/") ? path.slice(1) : path
}`;
@@ -462,28 +526,29 @@ export class ModelHub {
return urlObject.toString();
}
- /** Checks the organization, model, and version inputs.
+ /** Checks the model and revision inputs.
*
- * @param { string } organization
- * @param { string } modelName
- * @param { string } modelVersion
+ * @param { string } model
+ * @param { string } revision
* @param { string } file
* @returns { Error } The error instance(can be null)
*/
- #checkInput(organization, modelName, modelVersion, file) {
- // Ensures string consists only of letters, digits, and hyphens without starting/ending
- // with a hyphen or containing consecutive hyphens.
+ #checkInput(model, revision, file) {
+ // Matches a string with the format 'organization/model' where:
+ // - 'organization' consists only of letters, digits, and hyphens, cannot start or end with a hyphen,
+ // and cannot contain consecutive hyphens.
+ // - 'model' can contain letters, digits, hyphens, underscores, or periods.
//
- // ^ $ Start and end of string
- // (?!-) (?<!-) Negative lookahead/behind for not starting or ending with hyphen
- // (?!.*--) Negative lookahead for not containing consecutive hyphens
- // [A-Za-z0-9-]+ Alphanum characters or hyphens, one or more
- const orgRegex = /^(?!-)(?!.*--)[A-Za-z0-9-]+(?<!-)$/;
-
- // Matches strings containing letters, digits, hyphens, underscores, or periods.
- // ^ $ Start and end of string
- // [A-Za-z0-9-_.]+ Alphanum characters, hyphens, underscores, or periods, one or more times
- const modelRegex = /^[A-Za-z0-9-_.]+$/;
+ // Pattern breakdown:
+ // ^ Start of string
+ // (?!-) Negative lookahead for 'organization' not starting with hyphen
+ // (?!.*--) Negative lookahead for 'organization' not containing consecutive hyphens
+ // [A-Za-z0-9-]+ 'organization' part: Alphanumeric characters or hyphens
+ // (?<!-) Negative lookbehind for 'organization' not ending with a hyphen
+ // \/ Literal '/' character separating 'organization' and 'model'
+ // [A-Za-z0-9-_.]+ 'model' part: Alphanumeric characters, hyphens, underscores, or periods
+ // $ End of string
+ const modelRegex = /^(?!-)(?!.*--)[A-Za-z0-9-]+(?<!-)\/[A-Za-z0-9-_.]+$/;
// Matches strings consisting of alphanumeric characters, hyphens, or periods.
//
@@ -502,22 +567,18 @@ export class ModelHub {
// \/ Directory separator
// [A-Za-z0-9-_]+ Directory or file name
// )* Zero or more times
- // (?:[.][A-Za-z]{2,4})? Optional non-capturing group for file extension
+ // (?:[.][A-Za-z]{2,4})? Optional non-capturing group for file extension
const fileRegex =
/^(?:\/)?(?!\/)[A-Za-z0-9-_]+(?:\/[A-Za-z0-9-_]+)*(?:[.][A-Za-z]{2,4})?$/;
- if (!orgRegex.test(organization) || !isNaN(parseInt(organization))) {
- return new Error(`Invalid organization name ${organization}`);
- }
-
- if (!modelRegex.test(modelName)) {
+ if (!modelRegex.test(model)) {
return new Error("Invalid model name.");
}
if (
- !versionRegex.test(modelVersion) ||
- modelVersion.includes(" ") ||
- /[\^$]/.test(modelVersion)
+ !versionRegex.test(revision) ||
+ revision.includes(" ") ||
+ /[\^$]/.test(revision)
) {
return new Error("Invalid version identifier.");
}
@@ -536,7 +597,7 @@ export class ModelHub {
* @param {number} timeout in ms. Default is 1000
* @returns {Promise<string>} ETag (can be null)
*/
- async #getETag(url, timeout = 1000) {
+ async getETag(url, timeout = 1000) {
const controller = new AbortController();
const id = lazy.setTimeout(() => controller.abort(), timeout);
@@ -559,24 +620,18 @@ export class ModelHub {
* Given an organization, model, and version, fetch a model file in the hub as a Response.
*
* @param {object} config
- * @param {string} config.organization
- * @param {string} config.modelName
- * @param {string} config.modelVersion
+ * @param {string} config.model
+ * @param {string} config.revision
* @param {string} config.file
* @returns {Promise<Response>} The file content
*/
- async getModelFileAsResponse({
- organization,
- modelName,
- modelVersion,
- file,
- }) {
+ async getModelFileAsResponse({ model, revision, file }) {
const [blob, headers] = await this.getModelFileAsBlob({
- organization,
- modelName,
- modelVersion,
+ model,
+ revision,
file,
});
+
return new Response(blob, { headers });
}
@@ -584,22 +639,15 @@ export class ModelHub {
* Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer.
*
* @param {object} config
- * @param {string} config.organization
- * @param {string} config.modelName
- * @param {string} config.modelVersion
+ * @param {string} config.model
+ * @param {string} config.revision
* @param {string} config.file
* @returns {Promise<[ArrayBuffer, headers]>} The file content
*/
- async getModelFileAsArrayBuffer({
- organization,
- modelName,
- modelVersion,
- file,
- }) {
+ async getModelFileAsArrayBuffer({ model, revision, file }) {
const [blob, headers] = await this.getModelFileAsBlob({
- organization,
- modelName,
- modelVersion,
+ model,
+ revision,
file,
});
return [await blob.arrayBuffer(), headers];
@@ -609,54 +657,44 @@ export class ModelHub {
* Given an organization, model, and version, fetch a model file in the hub as blob.
*
* @param {object} config
- * @param {string} config.organization
- * @param {string} config.modelName
- * @param {string} config.modelVersion
+ * @param {string} config.model
+ * @param {string} config.revision
* @param {string} config.file
* @returns {Promise<[Blob, object]>} The file content
*/
- async getModelFileAsBlob({ organization, modelName, modelVersion, file }) {
+ async getModelFileAsBlob({ model, revision, file }) {
// Make sure inputs are clean. We don't sanitize them but throw an exception
- let checkError = this.#checkInput(
- organization,
- modelName,
- modelVersion,
- file
- );
+ let checkError = this.#checkInput(model, revision, file);
if (checkError) {
throw checkError;
}
-
- const url = this.#fileUrl(organization, modelName, modelVersion, file);
+ const url = this.#fileUrl(model, revision, file);
lazy.console.debug(`Getting model file from ${url}`);
await this.#initCache();
- // this can be null if no ETag was found or there were a network error
- const hubETag = await this.#getETag(url);
+ let useCached;
- lazy.console.debug(
- `Checking the cache for ${organization}/${modelName}/${modelVersion}/${file}`
- );
+ // If the revision is `main` we want to check the ETag in the hub
+ if (revision === "main") {
+ // this can be null if no ETag was found or there were a network error
+ const hubETag = await this.getETag(url);
- // storage lookup
- const cachedHeaders = await this.cache.getHeaders(
- organization,
- modelName,
- modelVersion,
- file
- );
- const cachedEtag = cachedHeaders ? cachedHeaders.ETag : null;
-
- // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response
- if (cachedEtag !== null && (hubETag === null || cachedEtag === hubETag)) {
- lazy.console.debug(`Cache Hit`);
- return await this.cache.getFile(
- organization,
- modelName,
- modelVersion,
- file
- );
+ // Storage ETag lookup
+ const cachedHeaders = await this.cache.getHeaders(model, revision, file);
+ const cachedEtag = cachedHeaders ? cachedHeaders.ETag : null;
+
+ // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response
+ useCached =
+ cachedEtag !== null && (hubETag === null || cachedEtag === hubETag);
+ } else {
+ // If we are dealing with a pinned revision, we ignore the ETag, to spare HEAD hits on every call
+ useCached = await this.cache.fileExists(model, revision, file);
+ }
+
+ if (useCached) {
+ lazy.console.debug(`Cache Hit for ${url}`);
+ return await this.cache.getFile(model, revision, file);
}
lazy.console.debug(`Fetching ${url}`);
@@ -668,13 +706,12 @@ export class ModelHub {
// We don't store the boundary or the charset, just the content type,
// so we drop what's after the semicolon.
"Content-Type": response.headers.get("Content-Type").split(";")[0],
- ETag: hubETag,
+ ETag: response.headers.get("ETag"),
};
await this.cache.put(
- organization,
- modelName,
- modelVersion,
+ model,
+ revision,
file,
await clone.blob(),
headers
diff --git a/toolkit/components/ml/content/ONNXPipeline.mjs b/toolkit/components/ml/content/ONNXPipeline.mjs
new file mode 100644
index 0000000000..fcc1a0eb77
--- /dev/null
+++ b/toolkit/components/ml/content/ONNXPipeline.mjs
@@ -0,0 +1,297 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+// This import does not use Chromutils because the next version of the library
+// will require an async import, which is not supported by importESModule,
+// so we'll just add await here.
+import {
+ env,
+ RawImage,
+ AutoProcessor,
+ AutoTokenizer,
+ AutoModelForVision2Seq,
+} from "chrome://global/content/ml/transformers-dev.js";
+
+/**
+ * Lazy initialization container.
+ *
+ * @type {object}
+ */
+
+const lazy = {};
+
+ChromeUtils.defineESModuleGetters(
+ lazy,
+ {
+ arrayBufferToBlobURL: "chrome://global/content/ml/Utils.sys.mjs",
+ },
+ { global: "current" }
+);
+
+// Using a custom console, see https://bugzilla.mozilla.org/show_bug.cgi?id=1891789
+let _logLevel = "Error";
+
+function debug(...args) {
+ if (["Debug", "Trace", "All"].includes(_logLevel)) {
+ console.log("ML:", ...args); // eslint-disable-line no-console
+ }
+}
+
+/**
+ * Echo inference for testing purposes.
+ *
+ * @async
+ * @param {object} request - The request object containing image data.
+ * @param {object} _model - The model used for inference.
+ * @param {object} _tokenizer - The tokenizer used for decoding.
+ * @param {object} _processor - The processor used for preparing image data.
+ * @returns {Promise<object>} The result object containing the processed text.
+ */
+async function echo(request, _model, _tokenizer, _processor) {
+ return {
+ metrics: {
+ tokenizingTime: 0,
+ },
+ output: request.data,
+ };
+}
+
+/**
+ * Converts an image to text using a machine learning model.
+ *
+ * @async
+ * @param {object} request - The request object containing image data.
+ * @param {string} [request.imageUrl] - The URL of the image to process. Either `imageUrl` or `data` must be provided, but not both.
+ * @param {ArrayBuffer} [request.data] - The raw image data to process. Either `data` or `imageUrl` must be provided, but not both.
+ * @param {string} request.mimeType - The MIME type of the image data.
+ * @param {object} model - The model used for inference.
+ * @param {object} tokenizer - The tokenizer used for decoding.
+ * @param {object} processor - The processor used for preparing image data.
+ * @returns {Promise<object>} The result object containing the processed text.
+ */
+async function imageToText(request, model, tokenizer, processor) {
+ let result = {
+ metrics: {
+ inferenceTime: 0,
+ tokenizingTime: 0,
+ },
+ };
+ let start = Date.now();
+ let rawImage;
+
+ if ("imageUrl" in request) {
+ rawImage = await RawImage.fromUrl(request.imageUrl);
+ } else {
+ const blob = new Blob([request.data], { type: request.mimeType });
+ rawImage = await RawImage.fromBlob(blob);
+ }
+
+ debug("Image loaded in ", Date.now() - start);
+
+ const { pixel_values } = await processor(rawImage);
+ result.metrics.tokenizingTime += Date.now() - start;
+ const toReturn = [];
+ for (const batch of pixel_values) {
+ batch.dims = [1, ...batch.dims];
+ start = Date.now();
+ const output = await model.generate(batch);
+ result.metrics.inferenceTime += Date.now() - start;
+ start = Date.now();
+ const decoded = tokenizer
+ .batch_decode(output, {
+ skip_special_tokens: true,
+ })
+ .map(x => ({ generated_text: x.trim() }));
+ result.metrics.tokenizingTime += Date.now() - start;
+ toReturn.push(decoded);
+ }
+ debug("Inference done in ", Date.now() - start);
+ result.output = toReturn[0][0].generated_text;
+ return result;
+}
+
+/**
+ * Configuration for engine. Each task has a configuration object that
+ * gets merged at runtime with the options from PipelineOptions.
+ *
+ * When a key exists in both the default configuration and the options,
+ * the value from the options is used.
+ *
+ * The configuration keys that are not exposed as options are all the
+ * callables that are used in the pipeline:
+ *
+ * - modelClass
+ * - tokenizerClass
+ * - processorClass
+ * - pipelineFunction
+ *
+ * @type {object}
+ */
+const ENGINE_CONFIGURATION = {
+ "image-to-text": {
+ modelId: "mozilla/distilvit",
+ modelClass: AutoModelForVision2Seq,
+ tokenizerId: "mozilla/distilvit",
+ tokenizerClass: AutoTokenizer,
+ processorId: "mozilla/distilvit",
+ processorClass: AutoProcessor,
+ pipelineFunction: imageToText,
+ },
+ echo: {
+ modelId: null,
+ modelClass: null,
+ tokenizerId: null,
+ tokenizerClass: null,
+ processorId: null,
+ processorClass: null,
+ pipelineFunction: echo,
+ },
+};
+
+/**
+ * Represents a pipeline for processing machine learning tasks.
+ */
+export class Pipeline {
+ #modelCache = null;
+ #model = null;
+ #tokenizer = null;
+ #processor = null;
+ #pipelineFunction = null;
+ #taskName = null;
+ #initTime = 0;
+ #isReady = false;
+
+ /**
+ * Creates an instance of a Pipeline.
+ *
+ * @param {object} modelCache - Implements the Cache interface and used to get models
+ * @param {object} config - The configuration options
+ */
+ constructor(modelCache, config) {
+ let start = Date.now();
+ this.#modelCache = modelCache;
+
+ _logLevel = config.logLevel || "Error";
+ // Setting up the Transformers.js environment
+ // See https://huggingface.co/docs/transformers.js/api/env
+
+ // Caching strategy.
+ // Here we make sure that everytime transformers.js requires a file, it uses
+ // modelCache, which transfers the request to the main thread and uses the
+ // ModelHub that caches files into IndexDB.
+ env.useBrowserCache = false;
+ env.allowLocalModels = false;
+ env.remoteHost = config.modelHubRootUrl;
+ env.remotePathTemplate = config.modelHubUrlTemplate;
+ env.useCustomCache = true;
+ env.customCache = this.#modelCache;
+ env.localModelPath = "/";
+
+ // ONNX runtime - we set up the wasm runtime we got from RS for the ONNX backend to pick
+ debug("Setting up ONNX backend");
+ env.backends.onnx.wasm.wasmPaths = {};
+ env.backends.onnx.wasm.wasmPaths[config.runtimeFilename] =
+ lazy.arrayBufferToBlobURL(config.runtime);
+
+ if (config.modelClass && config.modelId) {
+ debug(`Loading model ${config.modelId} with class ${config.modelClass}`);
+ this.#model = config.modelClass.from_pretrained(config.modelId);
+ }
+ if (config.tokenizerClass && config.tokenizerId) {
+ debug(
+ `Loading tokenizer ${config.tokenizerId} with class ${config.tokenizerClass}`
+ );
+ this.#tokenizer = config.tokenizerClass.from_pretrained(
+ config.tokenizerId
+ );
+ }
+ if (config.processorClass && config.processorId) {
+ debug(
+ `Loading processor ${config.processorId} with class ${config.processorClass}`
+ );
+ this.#processor = config.processorClass.from_pretrained(
+ config.processorId
+ );
+ }
+ this.#taskName = config.taskName;
+ this.#pipelineFunction = config.pipelineFunction.bind(this);
+ this.#initTime = Date.now() - start;
+ debug("Pipeline initialized, took ", this.#initTime);
+ }
+
+ /**
+ * Initializes the pipeline with given options.
+ *
+ * @static
+ * @async
+ * @param {object} modelCache - Implements the Cache interface and used to get models
+ * @param {ArrayBuffer} runtime - The runtime wasm file.
+ * @param {PipelineOptions} options - The options for initialization.
+ * @returns {Promise<Pipeline>} The initialized pipeline instance.
+ */
+ static async initialize(modelCache, runtime, options) {
+ const taskName = options.taskName;
+ debug(`Initializing Pipeline for task ${taskName}`);
+
+ if (!ENGINE_CONFIGURATION[taskName]) {
+ throw new Error(`Task ${taskName} is not supported`);
+ }
+
+ // Loading the config defaults for the task
+ let config = { ...ENGINE_CONFIGURATION[taskName] };
+ config.runtime = runtime;
+
+ // Overriding the defaults with the options
+ options.applyToConfig(config);
+
+ if (!config.pipelineFunction) {
+ throw new Error("pipelineFunction is required for the pipeline");
+ }
+ return new Pipeline(modelCache, config);
+ }
+
+ /**
+ * Runs the pipeline with the given request.
+ *
+ * @async
+ * @param {T} request - The request object to be processed. The fields it may contain
+ * depends on the task. See each pipeline function for more details.
+ * @returns {Promise<object>} The result object from the pipeline execution.
+ */
+ async run(request) {
+ debug("Running task: ", this.#taskName);
+ // Calling all promises to ensure they are resolved before running the first pipeline
+ if (!this.#isReady) {
+ let start = Date.now();
+ debug("Initializing model, tokenizer and processor");
+
+ // deactive console.warn, see https://bugzilla.mozilla.org/show_bug.cgi?id=1891003
+ const originalWarn = console.warn;
+ console.warn = () => {};
+ try {
+ this.#model = await this.#model;
+ this.#tokenizer = await this.#tokenizer;
+ this.#processor = await this.#processor;
+ this.#isReady = true;
+ } catch (error) {
+ debug("Error initializing pipeline", error);
+ throw error;
+ } finally {
+ console.warn = originalWarn;
+ }
+
+ this.#initTime += Date.now() - start;
+ debug("Pipeline is fully initialized, took ", this.#initTime);
+ }
+
+ let result = await this.#pipelineFunction(
+ request,
+ this.#model,
+ this.#tokenizer,
+ this.#processor
+ );
+ result.metrics.initTime = this.#initTime;
+ return result;
+ }
+}
diff --git a/toolkit/components/ml/content/SummarizerModel.sys.mjs b/toolkit/components/ml/content/SummarizerModel.sys.mjs
deleted file mode 100644
index 7cac55d92f..0000000000
--- a/toolkit/components/ml/content/SummarizerModel.sys.mjs
+++ /dev/null
@@ -1,160 +0,0 @@
-/* This Source Code Form is subject to the terms of the Mozilla Public
- * License, v. 2.0. If a copy of the MPL was not distributed with this
- * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
-
-/**
- * @typedef {object} LazyImports
- * @property {typeof import("../actors/MLEngineParent.sys.mjs").MLEngineParent} MLEngineParent
- */
-
-/** @type {LazyImports} */
-const lazy = {};
-
-ChromeUtils.defineESModuleGetters(lazy, {
- RemoteSettings: "resource://services-settings/remote-settings.sys.mjs",
- TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs",
-});
-
-ChromeUtils.defineLazyGetter(lazy, "console", () => {
- return console.createInstance({
- maxLogLevelPref: "browser.ml.logLevel",
- prefix: "ML",
- });
-});
-
-export class SummarizerModel {
- /**
- * The RemoteSettingsClient that downloads the summarizer model.
- *
- * @type {RemoteSettingsClient | null}
- */
- static #remoteClient = null;
-
- /** @type {Promise<WasmRecord> | null} */
- static #modelRecord = null;
-
- /**
- * The following constant controls the major version for wasm downloaded from
- * Remote Settings. When a breaking change is introduced, Nightly will have these
- * numbers incremented by one, but Beta and Release will still be on the previous
- * version. Remote Settings will ship both versions of the records, and the latest
- * asset released in that version will be used. For instance, with a major version
- * of "1", assets can be downloaded for "1.0", "1.2", "1.3beta", but assets marked
- * as "2.0", "2.1", etc will not be downloaded.
- */
- static MODEL_MAJOR_VERSION = 1;
-
- /**
- * Remote settings isn't available in tests, so provide mocked responses.
- */
- static mockRemoteSettings(remoteClient) {
- lazy.console.log("Mocking remote client in SummarizerModel.");
- SummarizerModel.#remoteClient = remoteClient;
- SummarizerModel.#modelRecord = null;
- }
-
- /**
- * Remove anything that could have been mocked.
- */
- static removeMocks() {
- lazy.console.log("Removing mocked remote client in SummarizerModel.");
- SummarizerModel.#remoteClient = null;
- SummarizerModel.#modelRecord = null;
- }
- /**
- * Download or load the model from remote settings.
- *
- * @returns {Promise<ArrayBuffer>}
- */
- static async getModel() {
- const client = SummarizerModel.#getRemoteClient();
-
- if (!SummarizerModel.#modelRecord) {
- // Place the records into a promise to prevent any races.
- SummarizerModel.#modelRecord = (async () => {
- // Load the wasm binary from remote settings, if it hasn't been already.
- lazy.console.log(`Getting the summarizer model record.`);
-
- // TODO - The getMaxVersionRecords should eventually migrated to some kind of
- // shared utility.
- const { getMaxVersionRecords } = lazy.TranslationsParent;
-
- /** @type {WasmRecord[]} */
- const wasmRecords = await getMaxVersionRecords(client, {
- // TODO - This record needs to be created with the engine wasm payload.
- filters: { name: "summarizer-model" },
- majorVersion: SummarizerModel.MODEL_MAJOR_VERSION,
- });
-
- if (wasmRecords.length === 0) {
- // The remote settings client provides an empty list of records when there is
- // an error.
- throw new Error("Unable to get the models from Remote Settings.");
- }
-
- if (wasmRecords.length > 1) {
- SummarizerModel.reportError(
- new Error("Expected the ml engine to only have 1 record."),
- wasmRecords
- );
- }
- const [record] = wasmRecords;
- lazy.console.log(
- `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`,
- record
- );
- return record;
- })();
- }
-
- try {
- /** @type {{buffer: ArrayBuffer}} */
- const { buffer } = await client.attachments.download(
- await SummarizerModel.#modelRecord
- );
-
- return buffer;
- } catch (error) {
- SummarizerModel.#modelRecord = null;
- throw error;
- }
- }
-
- /**
- * Lazily initializes the RemoteSettingsClient.
- *
- * @returns {RemoteSettingsClient}
- */
- static #getRemoteClient() {
- if (SummarizerModel.#remoteClient) {
- return SummarizerModel.#remoteClient;
- }
-
- /** @type {RemoteSettingsClient} */
- const client = lazy.RemoteSettings("ml-model");
-
- SummarizerModel.#remoteClient = client;
-
- client.on("sync", async ({ data: { created, updated, deleted } }) => {
- lazy.console.log(`"sync" event for ml-model`, {
- created,
- updated,
- deleted,
- });
-
- // Remove all the deleted records.
- for (const record of deleted) {
- await client.attachments.deleteDownloaded(record);
- }
-
- // Remove any updated records, and download the new ones.
- for (const { old: oldRecord } of updated) {
- await client.attachments.deleteDownloaded(oldRecord);
- }
-
- // Do nothing for the created records.
- });
-
- return client;
- }
-}
diff --git a/toolkit/components/ml/content/Utils.sys.mjs b/toolkit/components/ml/content/Utils.sys.mjs
new file mode 100644
index 0000000000..b3a25e84d7
--- /dev/null
+++ b/toolkit/components/ml/content/Utils.sys.mjs
@@ -0,0 +1,77 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
+
+/**
+ * Converts an ArrayBuffer to a Blob URL.
+ *
+ * @param {ArrayBuffer} buffer - The ArrayBuffer to convert.
+ * @returns {string} The Blob URL.
+ */
+export function arrayBufferToBlobURL(buffer) {
+ let blob = new Blob([buffer], { type: "application/wasm" });
+ return URL.createObjectURL(blob);
+}
+
+/**
+ * Validate some simple Wasm that uses a SIMD operation.
+ */
+function detectSimdSupport() {
+ return WebAssembly.validate(
+ new Uint8Array(
+ // ```
+ // ;; Detect SIMD support.
+ // ;; Compile by running: wat2wasm --enable-all simd-detect.wat
+ //
+ // (module
+ // (func (result v128)
+ // i32.const 0
+ // i8x16.splat
+ // i8x16.popcnt
+ // )
+ // )
+ // ```
+
+ // prettier-ignore
+ [
+ 0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x05, 0x01, 0x60, 0x00,
+ 0x01, 0x7b, 0x03, 0x02, 0x01, 0x00, 0x0a, 0x0a, 0x01, 0x08, 0x00, 0x41, 0x00,
+ 0xfd, 0x0f, 0xfd, 0x62, 0x0b
+ ]
+ )
+ );
+}
+
+let cachedRuntimeWasmFilename = null;
+
+/**
+ * Determines the appropriate WebAssembly (Wasm) filename based on the runtime capabilities of the browser.
+ * This function considers both SIMD and multi-threading support.
+ * It returns a filename that matches the browser's capabilities, ensuring the most optimized version of the Wasm file is used.
+ *
+ * The result is cached to avoid re-computation.
+ *
+ * @param {Window|null} browsingContext - The browsing context to use for feature detection.
+ * @returns {string} The filename of the Wasm file best suited for the current browser's capabilities.
+ */
+export function getRuntimeWasmFilename(browsingContext = null) {
+ if (cachedRuntimeWasmFilename != null) {
+ return cachedRuntimeWasmFilename;
+ }
+
+ // The cross-origin isolation flag is used to determine if we have multi-threading support.
+ const hasMultiThreadSupport = browsingContext
+ ? browsingContext.crossOriginIsolated
+ : false;
+
+ let res;
+ if (detectSimdSupport()) {
+ res = hasMultiThreadSupport
+ ? "ort-wasm-simd-threaded.wasm"
+ : "ort-wasm-simd.wasm";
+ } else {
+ res = hasMultiThreadSupport ? "ort-wasm-threaded.wasm" : "ort-wasm.wasm";
+ }
+ cachedRuntimeWasmFilename = res;
+ return res;
+}