diff options
Diffstat (limited to 'toolkit/components/ml/content')
-rw-r--r-- | toolkit/components/ml/content/EngineProcess.sys.mjs | 188 | ||||
-rw-r--r-- | toolkit/components/ml/content/MLEngine.worker.mjs | 98 | ||||
-rw-r--r-- | toolkit/components/ml/content/ModelHub.sys.mjs | 317 | ||||
-rw-r--r-- | toolkit/components/ml/content/ONNXPipeline.mjs | 297 | ||||
-rw-r--r-- | toolkit/components/ml/content/SummarizerModel.sys.mjs | 160 | ||||
-rw-r--r-- | toolkit/components/ml/content/Utils.sys.mjs | 77 |
6 files changed, 798 insertions, 339 deletions
diff --git a/toolkit/components/ml/content/EngineProcess.sys.mjs b/toolkit/components/ml/content/EngineProcess.sys.mjs index 36a9381192..0fe6403cc8 100644 --- a/toolkit/components/ml/content/EngineProcess.sys.mjs +++ b/toolkit/components/ml/content/EngineProcess.sys.mjs @@ -2,10 +2,17 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +// known to be loaded early in the startup process, and should be loaded eagerly +import { AppConstants } from "resource://gre/modules/AppConstants.sys.mjs"; + const lazy = {}; -ChromeUtils.defineESModuleGetters(lazy, { - HiddenFrame: "resource://gre/modules/HiddenFrame.sys.mjs", -}); +ChromeUtils.defineESModuleGetters( + lazy, + { + HiddenFrame: "resource://gre/modules/HiddenFrame.sys.mjs", + }, + { global: "current" } +); /** * @typedef {import("../actors/MLEngineParent.sys.mjs").MLEngineParent} MLEngineParent @@ -16,6 +23,172 @@ ChromeUtils.defineESModuleGetters(lazy, { */ /** + * This class encapsulates the options for a pipeline process. + */ +export class PipelineOptions { + /** + * The name of the task the pipeline is configured for. + * + * @type {?string} + */ + taskName = null; + + /** + * The maximum amount of time in milliseconds the pipeline should wait for a response. + * + * @type {?number} + */ + timeoutMS = null; + + /** + * The root URL of the model hub where models are hosted. + * + * @type {?string} + */ + modelHubRootUrl = null; + + /** + * A template URL for building the full URL for the model. + * + * @type {?string} + */ + modelHubUrlTemplate = null; + + /** + * The identifier for the specific model to be used by the pipeline. + * + * @type {?string} + */ + modelId = null; + + /** + * The revision for the specific model to be used by the pipeline. + * + * @type {?string} + */ + modelRevision = null; + + /** + * The identifier for the tokenizer associated with the model, used for pre-processing inputs. 
+ * + * @type {?string} + */ + tokenizerId = null; + + /** + * The revision for the tokenizer associated with the model, used for pre-processing inputs. + * + * @type {?string} + */ + tokenizerRevision = null; + + /** + * The identifier for any processor required by the model, used for additional input processing. + * + * @type {?string} + */ + processorId = null; + + /** + * The revision for any processor required by the model, used for additional input processing. + * + * @type {?string} + */ + + processorRevision = null; + + /** + * The log level used in the worker + * + * @type {?string} + */ + logLevel = null; + + /** + * Name of the runtime wasm file + * + * @type {?string} + */ + runtimeFilename = null; + + /** + * Create a PipelineOptions instance. + * + * @param {object} options - The options for the pipeline. Must include mandatory fields. + */ + constructor(options) { + this.updateOptions(options); + } + + /** + * Updates multiple options at once. + * + * @param {object} options - An object containing the options to update. + * @throws {Error} Throws an error if an invalid option is provided. + */ + updateOptions(options) { + const allowedKeys = [ + "taskName", + "modelHubRootUrl", + "modelHubUrlTemplate", + "timeoutMS", + "modelId", + "modelRevision", + "tokenizerId", + "tokenizerRevision", + "processorId", + "processorRevision", + "logLevel", + "runtimeFilename", + ]; + + Object.keys(options).forEach(key => { + if (allowedKeys.includes(key)) { + this[key] = options[key]; // Use bracket notation to access setter + } else { + throw new Error(`Invalid option: ${key}`); + } + }); + } + + /** + * Returns an object containing all current options. + + * @returns {object} An object with the current options. 
+ */ + getOptions() { + return { + taskName: this.taskName, + modelHubRootUrl: this.modelHubRootUrl, + modelHubUrlTemplate: this.modelHubUrlTemplate, + timeoutMS: this.timeoutMS, + modelId: this.modelId, + modelRevision: this.modelRevision, + tokenizerId: this.tokenizerId, + tokenizerRevision: this.tokenizerRevision, + processorId: this.processorId, + processorRevision: this.processorRevision, + logLevel: this.logLevel, + runtimeFilename: this.runtimeFilename, + }; + } + + /** + * Updates the given configuration object with the options. + * + * @param {object} config - The configuration object to be updated. + */ + applyToConfig(config) { + const options = this.getOptions(); + Object.keys(options).forEach(key => { + if (options[key] !== null) { + config[key] = options[key]; + } + }); + } +} + +/** * This class controls the life cycle of the engine process used both in the * Translations engine and the MLEngine component. */ @@ -68,6 +241,15 @@ export class EngineProcess { * @returns {Promise<MLEngineParent>} */ static async getMLEngineParent() { + // Bug 1890946 - enable the inference engine in release + if (!AppConstants.NIGHTLY_BUILD) { + throw new Error("MLEngine is only available in Nightly builds."); + } + // the pref is off by default + if (!Services.prefs.getBoolPref("browser.ml.enable")) { + throw new Error("MLEngine is disabled. Check the browser.ml prefs."); + } + if (!this.mlEngineParent) { this.mlEngineParent = this.#attachBrowser({ id: "ml-engine-browser", diff --git a/toolkit/components/ml/content/MLEngine.worker.mjs b/toolkit/components/ml/content/MLEngine.worker.mjs index 1013977e07..585ac4ab04 100644 --- a/toolkit/components/ml/content/MLEngine.worker.mjs +++ b/toolkit/components/ml/content/MLEngine.worker.mjs @@ -2,74 +2,99 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. 
*/ -import { PromiseWorker } from "resource://gre/modules/workers/PromiseWorker.mjs"; +const lazy = {}; -// Respect the preference "browser.ml.logLevel". -let _loggingLevel = "Error"; -function log(...args) { - if (_loggingLevel !== "Error" && _loggingLevel !== "Warn") { - console.log("ML:", ...args); - } -} -function trace(...args) { - if (_loggingLevel === "Trace" || _loggingLevel === "All") { - console.log("ML:", ...args); - } -} +ChromeUtils.defineESModuleGetters( + lazy, + { + PromiseWorker: "resource://gre/modules/workers/PromiseWorker.mjs", + Pipeline: "chrome://global/content/ml/ONNXPipeline.mjs", + PipelineOptions: "chrome://global/content/ml/EngineProcess.sys.mjs", + }, + { global: "current" } +); /** * The actual MLEngine lives here in a worker. */ class MLEngineWorker { - /** @type {ArrayBuffer} */ - #wasm; - /** @type {ArrayBuffer} */ - #model; + #pipeline; constructor() { // Connect the provider to the worker. this.#connectToPromiseWorker(); } + /** Implements the `match` function from the Cache API for Transformers.js custom cache. + * + * See https://developer.mozilla.org/en-US/docs/Web/API/Cache + * + * Attempts to match and retrieve a model file based on a provided key. + * Fetches a model file by delegating the call to the worker's main thread. + * Then wraps the fetched model file into a response object compatible with Transformers.js expectations. + * + * @param {string} key The unique identifier for the model to fetch. + * @returns {Promise<Response|null>} A promise that resolves with a Response object containing the model file or null if not found. 
+ */ + async match(key) { + let res = await this.getModelFile(key); + if (res.fail) { + return null; + } + let headers = res.ok[1]; + let modelFile = res.ok[2]; + // Transformers.js expects a response object, so we wrap the array buffer + const response = new Response(modelFile, { + status: 200, + headers, + }); + return response; + } + + async getModelFile(...args) { + let result = await self.callMainThread("getModelFile", args); + return result; + } + /** - * @param {ArrayBuffer} wasm - * @param {ArrayBuffer} model - * @param {string} loggingLevel + * Placeholder for the `put` method from the Cache API for Transformers.js custom cache. + * + * @throws {Error} Always thrown to indicate the method is not implemented. */ - initializeEngine(wasm, model, loggingLevel) { - this.#wasm = wasm; - this.#model = model; - _loggingLevel = loggingLevel; - // TODO - Initialize the engine for real here. - log("MLEngineWorker is initalized"); + put() { + throw new Error("Method not implemented."); } /** + * @param {ArrayBuffer} wasm + * @param {object} options received as an object, converted to a PipelineOptions instance + */ + async initializeEngine(wasm, options) { + this.#pipeline = await lazy.Pipeline.initialize( + this, + wasm, + new lazy.PipelineOptions(options) + ); + } + /** * Run the worker. * * @param {string} request */ - run(request) { - if (!this.#wasm) { - throw new Error("Expected the wasm to exist."); - } - if (!this.#model) { - throw new Error("Expected the model to exist"); - } + async run(request) { if (request === "throw") { throw new Error( 'Received the message "throw", so intentionally throwing an error.' ); } - trace("inference run requested with:", request); - return request.slice(0, Math.floor(request.length / 2)); + return await this.#pipeline.run(request); } /** * Glue code to connect the `MLEngineWorker` to the PromiseWorker interface. 
*/ #connectToPromiseWorker() { - const worker = new PromiseWorker.AbstractWorker(); + const worker = new lazy.PromiseWorker.AbstractWorker(); worker.dispatch = (method, args = []) => { if (!this[method]) { throw new Error("Method does not exist: " + method); @@ -81,6 +106,7 @@ class MLEngineWorker { self.postMessage(message, ...transfers); }; + self.callMainThread = worker.callMainThread.bind(worker); self.addEventListener("message", msg => worker.handleMessage(msg)); self.addEventListener("unhandledrejection", function (error) { throw error.reason; diff --git a/toolkit/components/ml/content/ModelHub.sys.mjs b/toolkit/components/ml/content/ModelHub.sys.mjs index 4c2181ff14..10f83c3000 100644 --- a/toolkit/components/ml/content/ModelHub.sys.mjs +++ b/toolkit/components/ml/content/ModelHub.sys.mjs @@ -24,8 +24,7 @@ const ALLOWED_HUBS = [ ]; const ALLOWED_HEADERS_KEYS = ["Content-Type", "ETag", "status"]; -const DEFAULT_URL_TEMPLATE = - "${organization}/${modelName}/resolve/${modelVersion}/${file}"; +const DEFAULT_URL_TEMPLATE = "{model}/resolve/{revision}"; /** * Checks if a given URL string corresponds to an allowed hub. @@ -239,17 +238,36 @@ export class IndexedDBCache { } /** + * Checks if a specified model file exists in storage. + * + * @param {string} model - The model name (organization/name) + * @param {string} revision - The model revision. + * @param {string} file - The file name. + * @returns {Promise<boolean>} A promise that resolves with `true` if the key exists, otherwise `false`. 
+ */ + async fileExists(model, revision, file) { + const storeName = this.fileStoreName; + const cacheKey = `${model}/${revision}/${file}`; + return new Promise((resolve, reject) => { + const transaction = this.db.transaction([storeName], "readonly"); + const store = transaction.objectStore(storeName); + const request = store.getKey(cacheKey); + request.onerror = event => reject(event.target.error); + request.onsuccess = event => resolve(event.target.result !== undefined); + }); + } + + /** * Retrieves the headers for a specific cache entry. * - * @param {string} organization - The organization name. - * @param {string} modelName - The model name. - * @param {string} modelVersion - The model version. + * @param {string} model - The model name (organization/name) + * @param {string} revision - The model revision. * @param {string} file - The file name. * @returns {Promise<object|null>} The headers or null if not found. */ - async getHeaders(organization, modelName, modelVersion, file) { - const headersKey = `${organization}/${modelName}/${modelVersion}`; - const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`; + async getHeaders(model, revision, file) { + const headersKey = `${model}/${revision}`; + const cacheKey = `${model}/${revision}/${file}`; const headers = await this.#getData(this.headersStoreName, headersKey); if (headers && headers.files[cacheKey]) { return headers.files[cacheKey]; @@ -260,22 +278,16 @@ export class IndexedDBCache { /** * Retrieves the file for a specific cache entry. * - * @param {string} organization - The organization name. - * @param {string} modelName - The model name. - * @param {string} modelVersion - The model version. + * @param {string} model - The model name (organization/name). + * @param {string} revision - The model version. * @param {string} file - The file name. * @returns {Promise<[ArrayBuffer, object]|null>} The file ArrayBuffer and its headers or null if not found. 
*/ - async getFile(organization, modelName, modelVersion, file) { - const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`; + async getFile(model, revision, file) { + const cacheKey = `${model}/${revision}/${file}`; const stored = await this.#getData(this.fileStoreName, cacheKey); if (stored) { - const headers = await this.getHeaders( - organization, - modelName, - modelVersion, - file - ); + const headers = await this.getHeaders(model, revision, file); return [stored.data, headers]; } return null; // Return null if no file is found @@ -284,29 +296,21 @@ export class IndexedDBCache { /** * Adds or updates a cache entry. * - * @param {string} organization - The organization name. - * @param {string} modelName - The model name. - * @param {string} modelVersion - The model version. + * @param {string} model - The model name (organization/name). + * @param {string} revision - The model version. * @param {string} file - The file name. * @param {ArrayBuffer} arrayBuffer - The data to cache. * @param {object} [headers] - The headers for the file. * @returns {Promise<void>} */ - async put( - organization, - modelName, - modelVersion, - file, - arrayBuffer, - headers = {} - ) { - const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`; + async put(model, revision, file, arrayBuffer, headers = {}) { + const cacheKey = `${model}/${revision}/${file}`; const newSize = this.totalSize + arrayBuffer.byteLength; if (newSize > this.#maxSize) { throw new Error("Exceeding total cache size limit of 1GB"); } - const headersKey = `${organization}/${modelName}/${modelVersion}`; + const headersKey = `${model}/${revision}`; const data = { id: cacheKey, data: arrayBuffer }; // Store the file data @@ -356,13 +360,12 @@ export class IndexedDBCache { /** * Deletes all data related to a specific model. * - * @param {string} organization - The organization name. - * @param {string} modelName - The model name. - * @param {string} modelVersion - The model version. 
+ * @param {string} model - The model name (organization/name). + * @param {string} revision - The model version. * @returns {Promise<void>} */ - async deleteModel(organization, modelName, modelVersion) { - const headersKey = `${organization}/${modelName}/${modelVersion}`; + async deleteModel(model, revision) { + const headersKey = `${model}/${revision}`; const headers = await this.#getData(this.headersStoreName, headersKey); if (headers) { for (const fileKey in headers.files) { @@ -409,9 +412,9 @@ export class ModelHub { this.cache = null; // Ensures the URL template is well-formed and does not contain any invalid characters. - const pattern = /^(?:\$\{\w+\}|\w+)(?:\/(?:\$\{\w+\}|\w+))*$/; + const pattern = /^(?:\{\w+\}|\w+)(?:\/(?:\{\w+\}|\w+))*$/; // ^ $ Start and end of string - // (?:\$\{\w+\}|\w+) Match a ${placeholder} or alphanumeric characters + // (?:\{\w+\}|\w+) Match a {placeholder} or alphanumeric characters // (?:\/(?:\$\{\w+\}|\w+))* Zero or more groups of a forward slash followed by a ${placeholder} or alphanumeric characters if (!pattern.test(urlTemplate)) { throw new Error(`Invalid URL template: ${urlTemplate}`); @@ -426,33 +429,94 @@ export class ModelHub { this.cache = await IndexedDBCache.init(); } + /** + * This method takes a model URL and parses it to extract the + * model name, optional model version, and file path. + * + * The expected URL format are : + * + * `/organization/model/revision/filePath` + * `https://hub/organization/model/revision/filePath` + * + * @param {string} url - The full URL to the model, including protocol and domain - or the relative path. + * @returns {object} An object containing the parsed components of the URL. The + * object has properties `model`, and `file`, + * and optionally `revision` if the URL includes a version. + * @throws {Error} Throws an error if the URL does not start with `this.rootUrl` or + * if the URL format does not match the expected structure. 
+ * + * @example + * // For a URL + * parseUrl("https://example.com/org1/model1/v1/file/path"); + * // returns { model: "org1/model1", revision: "v1", file: "file/path" } + * + * @example + * // For a relative URL + * parseUrl("/org1/model1/v1/file/path"); + * // returns { model: "org1/model1", revision: "v1", file: "file/path" } + */ + parseUrl(url) { + let parts; + if (url.startsWith("/")) { + // relative URL + parts = url.slice(1).split("/"); + } else { + // absolute URL + if (!url.startsWith(this.rootUrl)) { + throw new Error(`Invalid domain for model URL: ${url}`); + } + const urlObject = new URL(url); + const rootUrlObject = new URL(this.rootUrl); + + // Remove the root URL's pathname from the full URL's pathname + const relativePath = urlObject.pathname.substring( + rootUrlObject.pathname.length + ); + + parts = relativePath.slice(1).split("/"); + } + + if (parts.length < 3) { + throw new Error(`Invalid model URL: ${url}`); + } + + const file = parts.slice(3).join("/"); + if (file == null || !file.length) { + throw new Error(`Invalid model URL: ${url}`); + } + + return { + model: `${parts[0]}/${parts[1]}`, + revision: parts[2], + file, + }; + } + /** Creates the file URL from the organization, model, and version. * - * @param {string} organization - * @param {string} modelName - * @param {string} modelVersion + * @param {string} model + * @param {string} revision * @param {string} file * @returns {string} The full URL */ - #fileUrl(organization, modelName, modelVersion, file) { + #fileUrl(model, revision, file) { const baseUrl = new URL(this.rootUrl); if (!baseUrl.pathname.endsWith("/")) { baseUrl.pathname += "/"; } - // Replace placeholders in the URL template with the provided data. // If some keys are missing in the data object, the placeholder is left as is. // If the placeholder is not found in the data object, it is left as is. 
const data = { - organization, - modelName, - modelVersion, - file, + model, + revision, }; - const path = this.urlTemplate.replace( - /\$\{(\w+)\}/g, + let path = this.urlTemplate.replace( + /\{(\w+)\}/g, (match, key) => data[key] || match ); + path = `${path}/${file}`; + const fullPath = `${baseUrl.pathname}${ path.startsWith("/") ? path.slice(1) : path }`; @@ -462,28 +526,29 @@ export class ModelHub { return urlObject.toString(); } - /** Checks the organization, model, and version inputs. + /** Checks the model and revision inputs. * - * @param { string } organization - * @param { string } modelName - * @param { string } modelVersion + * @param { string } model + * @param { string } revision * @param { string } file * @returns { Error } The error instance(can be null) */ - #checkInput(organization, modelName, modelVersion, file) { - // Ensures string consists only of letters, digits, and hyphens without starting/ending - // with a hyphen or containing consecutive hyphens. + #checkInput(model, revision, file) { + // Matches a string with the format 'organization/model' where: + // - 'organization' consists only of letters, digits, and hyphens, cannot start or end with a hyphen, + // and cannot contain consecutive hyphens. + // - 'model' can contain letters, digits, hyphens, underscores, or periods. // - // ^ $ Start and end of string - // (?!-) (?<!-) Negative lookahead/behind for not starting or ending with hyphen - // (?!.*--) Negative lookahead for not containing consecutive hyphens - // [A-Za-z0-9-]+ Alphanum characters or hyphens, one or more - const orgRegex = /^(?!-)(?!.*--)[A-Za-z0-9-]+(?<!-)$/; - - // Matches strings containing letters, digits, hyphens, underscores, or periods. 
- // ^ $ Start and end of string - // [A-Za-z0-9-_.]+ Alphanum characters, hyphens, underscores, or periods, one or more times - const modelRegex = /^[A-Za-z0-9-_.]+$/; + // Pattern breakdown: + // ^ Start of string + // (?!-) Negative lookahead for 'organization' not starting with hyphen + // (?!.*--) Negative lookahead for 'organization' not containing consecutive hyphens + // [A-Za-z0-9-]+ 'organization' part: Alphanumeric characters or hyphens + // (?<!-) Negative lookbehind for 'organization' not ending with a hyphen + // \/ Literal '/' character separating 'organization' and 'model' + // [A-Za-z0-9-_.]+ 'model' part: Alphanumeric characters, hyphens, underscores, or periods + // $ End of string + const modelRegex = /^(?!-)(?!.*--)[A-Za-z0-9-]+(?<!-)\/[A-Za-z0-9-_.]+$/; // Matches strings consisting of alphanumeric characters, hyphens, or periods. // @@ -502,22 +567,18 @@ export class ModelHub { // \/ Directory separator // [A-Za-z0-9-_]+ Directory or file name // )* Zero or more times - // (?:[.][A-Za-z]{2,4})? Optional non-capturing group for file extension + // (?:[.][A-Za-z]{2,4})? Optional non-capturing group for file extension const fileRegex = /^(?:\/)?(?!\/)[A-Za-z0-9-_]+(?:\/[A-Za-z0-9-_]+)*(?:[.][A-Za-z]{2,4})?$/; - if (!orgRegex.test(organization) || !isNaN(parseInt(organization))) { - return new Error(`Invalid organization name ${organization}`); - } - - if (!modelRegex.test(modelName)) { + if (!modelRegex.test(model)) { return new Error("Invalid model name."); } if ( - !versionRegex.test(modelVersion) || - modelVersion.includes(" ") || - /[\^$]/.test(modelVersion) + !versionRegex.test(revision) || + revision.includes(" ") || + /[\^$]/.test(revision) ) { return new Error("Invalid version identifier."); } @@ -536,7 +597,7 @@ export class ModelHub { * @param {number} timeout in ms. 
Default is 1000 * @returns {Promise<string>} ETag (can be null) */ - async #getETag(url, timeout = 1000) { + async getETag(url, timeout = 1000) { const controller = new AbortController(); const id = lazy.setTimeout(() => controller.abort(), timeout); @@ -559,24 +620,18 @@ export class ModelHub { * Given an organization, model, and version, fetch a model file in the hub as a Response. * * @param {object} config - * @param {string} config.organization - * @param {string} config.modelName - * @param {string} config.modelVersion + * @param {string} config.model + * @param {string} config.revision * @param {string} config.file * @returns {Promise<Response>} The file content */ - async getModelFileAsResponse({ - organization, - modelName, - modelVersion, - file, - }) { + async getModelFileAsResponse({ model, revision, file }) { const [blob, headers] = await this.getModelFileAsBlob({ - organization, - modelName, - modelVersion, + model, + revision, file, }); + return new Response(blob, { headers }); } @@ -584,22 +639,15 @@ export class ModelHub { * Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer. * * @param {object} config - * @param {string} config.organization - * @param {string} config.modelName - * @param {string} config.modelVersion + * @param {string} config.model + * @param {string} config.revision * @param {string} config.file * @returns {Promise<[ArrayBuffer, headers]>} The file content */ - async getModelFileAsArrayBuffer({ - organization, - modelName, - modelVersion, - file, - }) { + async getModelFileAsArrayBuffer({ model, revision, file }) { const [blob, headers] = await this.getModelFileAsBlob({ - organization, - modelName, - modelVersion, + model, + revision, file, }); return [await blob.arrayBuffer(), headers]; @@ -609,54 +657,44 @@ export class ModelHub { * Given an organization, model, and version, fetch a model file in the hub as blob. 
* * @param {object} config - * @param {string} config.organization - * @param {string} config.modelName - * @param {string} config.modelVersion + * @param {string} config.model + * @param {string} config.revision * @param {string} config.file * @returns {Promise<[Blob, object]>} The file content */ - async getModelFileAsBlob({ organization, modelName, modelVersion, file }) { + async getModelFileAsBlob({ model, revision, file }) { // Make sure inputs are clean. We don't sanitize them but throw an exception - let checkError = this.#checkInput( - organization, - modelName, - modelVersion, - file - ); + let checkError = this.#checkInput(model, revision, file); if (checkError) { throw checkError; } - - const url = this.#fileUrl(organization, modelName, modelVersion, file); + const url = this.#fileUrl(model, revision, file); lazy.console.debug(`Getting model file from ${url}`); await this.#initCache(); - // this can be null if no ETag was found or there were a network error - const hubETag = await this.#getETag(url); + let useCached; - lazy.console.debug( - `Checking the cache for ${organization}/${modelName}/${modelVersion}/${file}` - ); + // If the revision is `main` we want to check the ETag in the hub + if (revision === "main") { + // this can be null if no ETag was found or there were a network error + const hubETag = await this.getETag(url); - // storage lookup - const cachedHeaders = await this.cache.getHeaders( - organization, - modelName, - modelVersion, - file - ); - const cachedEtag = cachedHeaders ? 
cachedHeaders.ETag : null; - - // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response - if (cachedEtag !== null && (hubETag === null || cachedEtag === hubETag)) { - lazy.console.debug(`Cache Hit`); - return await this.cache.getFile( - organization, - modelName, - modelVersion, - file - ); + // Storage ETag lookup + const cachedHeaders = await this.cache.getHeaders(model, revision, file); + const cachedEtag = cachedHeaders ? cachedHeaders.ETag : null; + + // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response + useCached = + cachedEtag !== null && (hubETag === null || cachedEtag === hubETag); + } else { + // If we are dealing with a pinned revision, we ignore the ETag, to spare HEAD hits on every call + useCached = await this.cache.fileExists(model, revision, file); + } + + if (useCached) { + lazy.console.debug(`Cache Hit for ${url}`); + return await this.cache.getFile(model, revision, file); } lazy.console.debug(`Fetching ${url}`); @@ -668,13 +706,12 @@ export class ModelHub { // We don't store the boundary or the charset, just the content type, // so we drop what's after the semicolon. "Content-Type": response.headers.get("Content-Type").split(";")[0], - ETag: hubETag, + ETag: response.headers.get("ETag"), }; await this.cache.put( - organization, - modelName, - modelVersion, + model, + revision, file, await clone.blob(), headers diff --git a/toolkit/components/ml/content/ONNXPipeline.mjs b/toolkit/components/ml/content/ONNXPipeline.mjs new file mode 100644 index 0000000000..fcc1a0eb77 --- /dev/null +++ b/toolkit/components/ml/content/ONNXPipeline.mjs @@ -0,0 +1,297 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. 
*/ + +// This import does not use ChromeUtils because the next version of the library +// will require an async import, which is not supported by importESModule, +// so we'll just add await here. +import { + env, + RawImage, + AutoProcessor, + AutoTokenizer, + AutoModelForVision2Seq, +} from "chrome://global/content/ml/transformers-dev.js"; + +/** + * Lazy initialization container. + * + * @type {object} + */ + +const lazy = {}; + +ChromeUtils.defineESModuleGetters( + lazy, + { + arrayBufferToBlobURL: "chrome://global/content/ml/Utils.sys.mjs", + }, + { global: "current" } +); + +// Using a custom console, see https://bugzilla.mozilla.org/show_bug.cgi?id=1891789 +let _logLevel = "Error"; + +function debug(...args) { + if (["Debug", "Trace", "All"].includes(_logLevel)) { + console.log("ML:", ...args); // eslint-disable-line no-console + } +} + +/** + * Echo inference for testing purposes. + * + * @async + * @param {object} request - The request object containing image data. + * @param {object} _model - The model used for inference. + * @param {object} _tokenizer - The tokenizer used for decoding. + * @param {object} _processor - The processor used for preparing image data. + * @returns {Promise<object>} The result object containing the processed text. + */ +async function echo(request, _model, _tokenizer, _processor) { + return { + metrics: { + tokenizingTime: 0, + }, + output: request.data, + }; +} + +/** + * Converts an image to text using a machine learning model. + * + * @async + * @param {object} request - The request object containing image data. + * @param {string} [request.imageUrl] - The URL of the image to process. Either `imageUrl` or `data` must be provided, but not both. + * @param {ArrayBuffer} [request.data] - The raw image data to process. Either `data` or `imageUrl` must be provided, but not both. + * @param {string} request.mimeType - The MIME type of the image data. + * @param {object} model - The model used for inference. 
+ * @param {object} tokenizer - The tokenizer used for decoding. + * @param {object} processor - The processor used for preparing image data. + * @returns {Promise<object>} The result object containing the processed text. + */ +async function imageToText(request, model, tokenizer, processor) { + let result = { + metrics: { + inferenceTime: 0, + tokenizingTime: 0, + }, + }; + let start = Date.now(); + let rawImage; + + if ("imageUrl" in request) { + rawImage = await RawImage.fromUrl(request.imageUrl); + } else { + const blob = new Blob([request.data], { type: request.mimeType }); + rawImage = await RawImage.fromBlob(blob); + } + + debug("Image loaded in ", Date.now() - start); + + const { pixel_values } = await processor(rawImage); + result.metrics.tokenizingTime += Date.now() - start; + const toReturn = []; + for (const batch of pixel_values) { + batch.dims = [1, ...batch.dims]; + start = Date.now(); + const output = await model.generate(batch); + result.metrics.inferenceTime += Date.now() - start; + start = Date.now(); + const decoded = tokenizer + .batch_decode(output, { + skip_special_tokens: true, + }) + .map(x => ({ generated_text: x.trim() })); + result.metrics.tokenizingTime += Date.now() - start; + toReturn.push(decoded); + } + debug("Inference done in ", Date.now() - start); + result.output = toReturn[0][0].generated_text; + return result; +} + +/** + * Configuration for engine. Each task has a configuration object that + * gets merged at runtime with the options from PipelineOptions. + * + * When a key exists in both the default configuration and the options, + * the value from the options is used. 
+ * + * The configuration keys that are not exposed as options are all the + * callables that are used in the pipeline: + * + * - modelClass + * - tokenizerClass + * - processorClass + * - pipelineFunction + * + * @type {object} + */ +const ENGINE_CONFIGURATION = { + "image-to-text": { + modelId: "mozilla/distilvit", + modelClass: AutoModelForVision2Seq, + tokenizerId: "mozilla/distilvit", + tokenizerClass: AutoTokenizer, + processorId: "mozilla/distilvit", + processorClass: AutoProcessor, + pipelineFunction: imageToText, + }, + echo: { + modelId: null, + modelClass: null, + tokenizerId: null, + tokenizerClass: null, + processorId: null, + processorClass: null, + pipelineFunction: echo, + }, +}; + +/** + * Represents a pipeline for processing machine learning tasks. + */ +export class Pipeline { + #modelCache = null; + #model = null; + #tokenizer = null; + #processor = null; + #pipelineFunction = null; + #taskName = null; + #initTime = 0; + #isReady = false; + + /** + * Creates an instance of a Pipeline. + * + * @param {object} modelCache - Implements the Cache interface and used to get models + * @param {object} config - The configuration options + */ + constructor(modelCache, config) { + let start = Date.now(); + this.#modelCache = modelCache; + + _logLevel = config.logLevel || "Error"; + // Setting up the Transformers.js environment + // See https://huggingface.co/docs/transformers.js/api/env + + // Caching strategy. + // Here we make sure that every time transformers.js requires a file, it uses + // modelCache, which transfers the request to the main thread and uses the + // ModelHub that caches files into IndexedDB. 
+ env.useBrowserCache = false; + env.allowLocalModels = false; + env.remoteHost = config.modelHubRootUrl; + env.remotePathTemplate = config.modelHubUrlTemplate; + env.useCustomCache = true; + env.customCache = this.#modelCache; + env.localModelPath = "/"; + + // ONNX runtime - we set up the wasm runtime we got from RS for the ONNX backend to pick + debug("Setting up ONNX backend"); + env.backends.onnx.wasm.wasmPaths = {}; + env.backends.onnx.wasm.wasmPaths[config.runtimeFilename] = + lazy.arrayBufferToBlobURL(config.runtime); + + if (config.modelClass && config.modelId) { + debug(`Loading model ${config.modelId} with class ${config.modelClass}`); + this.#model = config.modelClass.from_pretrained(config.modelId); + } + if (config.tokenizerClass && config.tokenizerId) { + debug( + `Loading tokenizer ${config.tokenizerId} with class ${config.tokenizerClass}` + ); + this.#tokenizer = config.tokenizerClass.from_pretrained( + config.tokenizerId + ); + } + if (config.processorClass && config.processorId) { + debug( + `Loading processor ${config.processorId} with class ${config.processorClass}` + ); + this.#processor = config.processorClass.from_pretrained( + config.processorId + ); + } + this.#taskName = config.taskName; + this.#pipelineFunction = config.pipelineFunction.bind(this); + this.#initTime = Date.now() - start; + debug("Pipeline initialized, took ", this.#initTime); + } + + /** + * Initializes the pipeline with given options. + * + * @static + * @async + * @param {object} modelCache - Implements the Cache interface and used to get models + * @param {ArrayBuffer} runtime - The runtime wasm file. + * @param {PipelineOptions} options - The options for initialization. + * @returns {Promise<Pipeline>} The initialized pipeline instance. 
+ */ + static async initialize(modelCache, runtime, options) { + const taskName = options.taskName; + debug(`Initializing Pipeline for task ${taskName}`); + + if (!ENGINE_CONFIGURATION[taskName]) { + throw new Error(`Task ${taskName} is not supported`); + } + + // Loading the config defaults for the task + let config = { ...ENGINE_CONFIGURATION[taskName] }; + config.runtime = runtime; + + // Overriding the defaults with the options + options.applyToConfig(config); + + if (!config.pipelineFunction) { + throw new Error("pipelineFunction is required for the pipeline"); + } + return new Pipeline(modelCache, config); + } + + /** + * Runs the pipeline with the given request. + * + * @async + * @param {T} request - The request object to be processed. The fields it may contain + * depends on the task. See each pipeline function for more details. + * @returns {Promise<object>} The result object from the pipeline execution. + */ + async run(request) { + debug("Running task: ", this.#taskName); + // Calling all promises to ensure they are resolved before running the first pipeline + if (!this.#isReady) { + let start = Date.now(); + debug("Initializing model, tokenizer and processor"); + + // deactive console.warn, see https://bugzilla.mozilla.org/show_bug.cgi?id=1891003 + const originalWarn = console.warn; + console.warn = () => {}; + try { + this.#model = await this.#model; + this.#tokenizer = await this.#tokenizer; + this.#processor = await this.#processor; + this.#isReady = true; + } catch (error) { + debug("Error initializing pipeline", error); + throw error; + } finally { + console.warn = originalWarn; + } + + this.#initTime += Date.now() - start; + debug("Pipeline is fully initialized, took ", this.#initTime); + } + + let result = await this.#pipelineFunction( + request, + this.#model, + this.#tokenizer, + this.#processor + ); + result.metrics.initTime = this.#initTime; + return result; + } +} diff --git a/toolkit/components/ml/content/SummarizerModel.sys.mjs 
b/toolkit/components/ml/content/SummarizerModel.sys.mjs deleted file mode 100644 index 7cac55d92f..0000000000 --- a/toolkit/components/ml/content/SummarizerModel.sys.mjs +++ /dev/null @@ -1,160 +0,0 @@ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this - * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ - -/** - * @typedef {object} LazyImports - * @property {typeof import("../actors/MLEngineParent.sys.mjs").MLEngineParent} MLEngineParent - */ - -/** @type {LazyImports} */ -const lazy = {}; - -ChromeUtils.defineESModuleGetters(lazy, { - RemoteSettings: "resource://services-settings/remote-settings.sys.mjs", - TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs", -}); - -ChromeUtils.defineLazyGetter(lazy, "console", () => { - return console.createInstance({ - maxLogLevelPref: "browser.ml.logLevel", - prefix: "ML", - }); -}); - -export class SummarizerModel { - /** - * The RemoteSettingsClient that downloads the summarizer model. - * - * @type {RemoteSettingsClient | null} - */ - static #remoteClient = null; - - /** @type {Promise<WasmRecord> | null} */ - static #modelRecord = null; - - /** - * The following constant controls the major version for wasm downloaded from - * Remote Settings. When a breaking change is introduced, Nightly will have these - * numbers incremented by one, but Beta and Release will still be on the previous - * version. Remote Settings will ship both versions of the records, and the latest - * asset released in that version will be used. For instance, with a major version - * of "1", assets can be downloaded for "1.0", "1.2", "1.3beta", but assets marked - * as "2.0", "2.1", etc will not be downloaded. - */ - static MODEL_MAJOR_VERSION = 1; - - /** - * Remote settings isn't available in tests, so provide mocked responses. 
- */ - static mockRemoteSettings(remoteClient) { - lazy.console.log("Mocking remote client in SummarizerModel."); - SummarizerModel.#remoteClient = remoteClient; - SummarizerModel.#modelRecord = null; - } - - /** - * Remove anything that could have been mocked. - */ - static removeMocks() { - lazy.console.log("Removing mocked remote client in SummarizerModel."); - SummarizerModel.#remoteClient = null; - SummarizerModel.#modelRecord = null; - } - /** - * Download or load the model from remote settings. - * - * @returns {Promise<ArrayBuffer>} - */ - static async getModel() { - const client = SummarizerModel.#getRemoteClient(); - - if (!SummarizerModel.#modelRecord) { - // Place the records into a promise to prevent any races. - SummarizerModel.#modelRecord = (async () => { - // Load the wasm binary from remote settings, if it hasn't been already. - lazy.console.log(`Getting the summarizer model record.`); - - // TODO - The getMaxVersionRecords should eventually migrated to some kind of - // shared utility. - const { getMaxVersionRecords } = lazy.TranslationsParent; - - /** @type {WasmRecord[]} */ - const wasmRecords = await getMaxVersionRecords(client, { - // TODO - This record needs to be created with the engine wasm payload. - filters: { name: "summarizer-model" }, - majorVersion: SummarizerModel.MODEL_MAJOR_VERSION, - }); - - if (wasmRecords.length === 0) { - // The remote settings client provides an empty list of records when there is - // an error. 
- throw new Error("Unable to get the models from Remote Settings."); - } - - if (wasmRecords.length > 1) { - SummarizerModel.reportError( - new Error("Expected the ml engine to only have 1 record."), - wasmRecords - ); - } - const [record] = wasmRecords; - lazy.console.log( - `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`, - record - ); - return record; - })(); - } - - try { - /** @type {{buffer: ArrayBuffer}} */ - const { buffer } = await client.attachments.download( - await SummarizerModel.#modelRecord - ); - - return buffer; - } catch (error) { - SummarizerModel.#modelRecord = null; - throw error; - } - } - - /** - * Lazily initializes the RemoteSettingsClient. - * - * @returns {RemoteSettingsClient} - */ - static #getRemoteClient() { - if (SummarizerModel.#remoteClient) { - return SummarizerModel.#remoteClient; - } - - /** @type {RemoteSettingsClient} */ - const client = lazy.RemoteSettings("ml-model"); - - SummarizerModel.#remoteClient = client; - - client.on("sync", async ({ data: { created, updated, deleted } }) => { - lazy.console.log(`"sync" event for ml-model`, { - created, - updated, - deleted, - }); - - // Remove all the deleted records. - for (const record of deleted) { - await client.attachments.deleteDownloaded(record); - } - - // Remove any updated records, and download the new ones. - for (const { old: oldRecord } of updated) { - await client.attachments.deleteDownloaded(oldRecord); - } - - // Do nothing for the created records. - }); - - return client; - } -} diff --git a/toolkit/components/ml/content/Utils.sys.mjs b/toolkit/components/ml/content/Utils.sys.mjs new file mode 100644 index 0000000000..b3a25e84d7 --- /dev/null +++ b/toolkit/components/ml/content/Utils.sys.mjs @@ -0,0 +1,77 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. 
If a copy of the MPL was not distributed with this
+ * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
+
+/**
+ * Converts an ArrayBuffer to a Blob URL.
+ *
+ * @param {ArrayBuffer} buffer - The ArrayBuffer to convert.
+ * @returns {string} The Blob URL.
+ */
+export function arrayBufferToBlobURL(buffer) {
+  let blob = new Blob([buffer], { type: "application/wasm" });
+  return URL.createObjectURL(blob);
+}
+
+/**
+ * Validate some simple Wasm that uses a SIMD operation.
+ *
+ * @returns {boolean} True when the embedded SIMD module passes WebAssembly.validate.
+ */
+function detectSimdSupport() {
+  return WebAssembly.validate(
+    new Uint8Array(
+      // ```
+      // ;; Detect SIMD support.
+      // ;; Compile by running: wat2wasm --enable-all simd-detect.wat
+      //
+      // (module
+      //   (func (result v128)
+      //     i32.const 0
+      //     i8x16.splat
+      //     i8x16.popcnt
+      //   )
+      // )
+      // ```
+
+      // prettier-ignore
+      [
+        0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00, 0x01, 0x05, 0x01, 0x60, 0x00,
+        0x01, 0x7b, 0x03, 0x02, 0x01, 0x00, 0x0a, 0x0a, 0x01, 0x08, 0x00, 0x41, 0x00,
+        0xfd, 0x0f, 0xfd, 0x62, 0x0b
+      ]
+    )
+  );
+}
+
+// Result of the SIMD probe. It does not depend on the caller's context, so it
+// is computed once and reused.
+let cachedSimdSupport = null;
+
+/**
+ * Determines the appropriate WebAssembly (Wasm) filename based on the runtime capabilities of the browser.
+ * This function considers both SIMD and multi-threading support.
+ * It returns a filename that matches the browser's capabilities, ensuring the most optimized version of the Wasm file is used.
+ *
+ * Only the SIMD detection is cached. Multi-threading support is read from the
+ * given context on every call, so a result computed for one context is never
+ * returned for a context with a different cross-origin isolation state.
+ *
+ * @param {Window|null} browsingContext - The browsing context to use for feature detection.
+ * @returns {string} The filename of the Wasm file best suited for the current browser's capabilities.
+ */
+export function getRuntimeWasmFilename(browsingContext = null) {
+  if (cachedSimdSupport === null) {
+    cachedSimdSupport = detectSimdSupport();
+  }
+
+  // The cross-origin isolation flag is used to determine if we have multi-threading support.
+  const hasMultiThreadSupport = Boolean(browsingContext?.crossOriginIsolated);
+
+  if (cachedSimdSupport) {
+    return hasMultiThreadSupport
+      ? "ort-wasm-simd-threaded.wasm"
+      : "ort-wasm-simd.wasm";
+  }
+  return hasMultiThreadSupport ? "ort-wasm-threaded.wasm" : "ort-wasm.wasm";
+}
|