/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";

/**
 * @typedef {import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
 */

/**
 * @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
 */

/**
 * @typedef {object} Lazy
 * @property {typeof import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
 * @property {typeof setTimeout} setTimeout
 * @property {typeof clearTimeout} clearTimeout
 */

/** @type {Lazy} */
const lazy = {};

ChromeUtils.defineESModuleGetters(lazy, {
  BasePromiseWorker: "resource://gre/modules/PromiseWorker.sys.mjs",
  setTimeout: "resource://gre/modules/Timer.sys.mjs",
  clearTimeout: "resource://gre/modules/Timer.sys.mjs",
  PipelineOptions: "chrome://global/content/ml/EngineProcess.sys.mjs",
  DEFAULT_ENGINE_ID: "chrome://global/content/ml/EngineProcess.sys.mjs",
  DEFAULT_MODELS: "chrome://global/content/ml/EngineProcess.sys.mjs",
  WASM_BACKENDS: "chrome://global/content/ml/EngineProcess.sys.mjs",
});

ChromeUtils.defineLazyGetter(lazy, "console", () => {
  return console.createInstance({
    maxLogLevelPref: "browser.ml.logLevel",
    prefix: "ML:EngineChild",
  });
});

XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "CACHE_TIMEOUT_MS",
  "browser.ml.modelCacheTimeout"
);

XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "MODEL_HUB_ROOT_URL",
  "browser.ml.modelHubRootUrl"
);

XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "MODEL_HUB_URL_TEMPLATE",
  "browser.ml.modelHubUrlTemplate"
);

XPCOMUtils.defineLazyPreferenceGetter(lazy, "LOG_LEVEL", "browser.ml.logLevel");

XPCOMUtils.defineLazyServiceGetter(
  lazy,
  "mlUtils",
  "@mozilla.org/ml-utils;1",
  "nsIMLUtils"
);

XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "PIPELINE_OVERRIDE_OPTIONS",
  "browser.ml.overridePipelineOptions",
  "{}"
);

// Pipeline options that may be overridden per feature via the
// browser.ml.overridePipelineOptions pref.
const SAFE_OVERRIDE_OPTIONS = [
  "dtype",
  "logLevel",
  "modelRevision",
  "numThreads",
  "processorRevision",
  "timeoutMS",
  "tokenizerRevision",
];
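// A sketch of the pref value that getUpdatedPipelineOptions() consumes below.
// The feature id and option values here are hypothetical:
//
//   browser.ml.overridePipelineOptions =
//     '{"example-feature": {"dtype": "q8", "numThreads": 2}}'
//
// Only keys listed in SAFE_OVERRIDE_OPTIONS are applied; any other key in the
// pref value is ignored.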
/**
 * The engine child is responsible for the life cycle and instantiation of the local
 * machine learning inference engine.
 */
export class MLEngineChild extends JSProcessActorChild {
  /**
   * The cached engines.
   *
   * @type {Map<string, EngineDispatcher>}
   */
  #engineDispatchers = new Map();

  /**
   * Engine statuses.
   *
   * @type {Map<string, string>}
   */
  #engineStatuses = new Map();

  // eslint-disable-next-line consistent-return
  async receiveMessage({ name, data }) {
    switch (name) {
      case "MLEngine:NewPort": {
        await this.#onNewPortCreated(data);
        break;
      }
      case "MLEngine:GetStatus": {
        return this.getStatus();
      }
      case "MLEngine:ForceShutdown": {
        for (const engineDispatcher of this.#engineDispatchers.values()) {
          await engineDispatcher.terminate(
            /* shutDownIfEmpty */ true,
            /* replacement */ false
          );
        }
        break;
      }
    }
  }

  /**
   * Handles the actions to be performed after a new port has been created.
   * Specifically, it ensures that the engine dispatcher is created if not already present,
   * and notifies the parent through the port once the engine dispatcher is ready.
   *
   * @param {object} config - Configuration object.
   * @param {MessagePort} config.port - The port of the channel.
   * @param {PipelineOptions} config.pipelineOptions - The options for the pipeline.
   * @returns {Promise<void>} - A promise that resolves once the necessary actions are complete.
   */
  async #onNewPortCreated({ port, pipelineOptions }) {
    try {
      // We get some default options from the prefs.
      let options = new lazy.PipelineOptions({
        modelHubRootUrl: lazy.MODEL_HUB_ROOT_URL,
        modelHubUrlTemplate: lazy.MODEL_HUB_URL_TEMPLATE,
        timeoutMS: lazy.CACHE_TIMEOUT_MS,
        logLevel: lazy.LOG_LEVEL,
      });

      const updatedPipelineOptions =
        this.getUpdatedPipelineOptions(pipelineOptions);

      options.updateOptions(updatedPipelineOptions);

      const engineId = options.engineId;
      this.#engineStatuses.set(engineId, "INITIALIZING");

      // Check if we already have an engine under this id.
      if (this.#engineDispatchers.has(engineId)) {
        let currentEngineDispatcher = this.#engineDispatchers.get(engineId);

        // The options match, let's reuse the engine.
        if (currentEngineDispatcher.pipelineOptions.equals(options)) {
          port.postMessage({
            type: "EnginePort:EngineReady",
            error: null,
          });
          this.#engineStatuses.set(engineId, "READY");
          return;
        }

        // The options do not match, terminate the old one so we have a single engine per id.
        await currentEngineDispatcher.terminate(
          /* shutDownIfEmpty */ false,
          /* replacement */ true
        );
        this.#engineDispatchers.delete(engineId);
      }

      this.#engineStatuses.set(engineId, "CREATING");
      const dispatcher = new EngineDispatcher(this, port, options);
      this.#engineDispatchers.set(engineId, dispatcher);

      // When the pipeline is mocked, typically in unit tests, the WASM files
      // are mocked. In these cases, the pipeline is not resolved during
      // initialization to allow the test to work.
      //
      // NOTE: This is done after adding to #engineDispatchers to ensure other
      // async calls see the new dispatcher.
      if (!lazy.PipelineOptions.isMocked(pipelineOptions)) {
        await dispatcher.ensureInferenceEngineIsReady();
      }

      this.#engineStatuses.set(engineId, "READY");
      port.postMessage({
        type: "EnginePort:EngineReady",
        error: null,
      });
    } catch (error) {
      port.postMessage({
        type: "EnginePort:EngineReady",
        error,
      });
    }
  }

  /**
   * Gets the wasm array buffer from RemoteSettings.
   *
   * @param {string} backend - The ML engine for which the WASM buffer is requested.
   * @returns {Promise<ArrayBuffer>}
   */
  getWasmArrayBuffer(backend) {
    return this.sendQuery("MLEngine:GetWasmArrayBuffer", backend);
  }

  /**
   * Gets the configuration of the worker.
   *
   * @returns {Promise<object>}
   */
  getWorkerConfig() {
    return this.sendQuery("MLEngine:GetWorkerConfig");
  }

  /**
   * Gets the inference options from RemoteSettings.
   *
   * @param {string} featureId - The feature id.
   * @param {string} taskName - The name of the inference task.
   * @param {?string} modelId - The model id, when known.
   * @returns {Promise<object>}
   */
  getInferenceOptions(featureId, taskName, modelId) {
    return this.sendQuery("MLEngine:GetInferenceOptions", {
      featureId,
      taskName,
      modelId,
    });
  }

  /**
   * Retrieves a model file and headers by communicating with the parent actor.
   *
   * @param {object} config - The configuration accepted by the parent function.
   * @returns {Promise<[string, object]>} The file local path and headers.
   */
  getModelFile(config) {
    return this.sendQuery("MLEngine:GetModelFile", config);
  }

  /**
   * Notifies the parent actor that the model download is complete.
   *
   * @param {object} config - The configuration accepted by the parent function.
   */
  async notifyModelDownloadComplete(config) {
    return this.sendQuery("MLEngine:NotifyModelDownloadComplete", config);
  }

  /**
   * Removes an engine by its ID. Optionally shuts down if no engines remain.
   *
   * @param {string} engineId - The ID of the engine to remove.
   * @param {boolean} [shutDownIfEmpty] - If true, shuts down the engine process if no engines remain.
   * @param {boolean} replacement - Flag indicating whether the engine is being replaced.
   */
  removeEngine(engineId, shutDownIfEmpty, replacement) {
    this.#engineDispatchers.delete(engineId);
    this.#engineStatuses.delete(engineId);
    this.sendAsyncMessage("MLEngine:Removed", {
      engineId,
      shutdown: shutDownIfEmpty,
      replacement,
    });

    if (this.#engineDispatchers.size === 0 && shutDownIfEmpty) {
      this.sendAsyncMessage("MLEngine:DestroyEngineProcess");
    }
  }
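  // A sketch of the Map that getStatus() below resolves to. The engine id and
  // status value are illustrative ("default-engine" follows the imported
  // DEFAULT_ENGINE_ID default):
  //
  //   Map {
  //     "default-engine" => {
  //       status: "READY",
  //       options: <PipelineOptions>,
  //       engineId: "default-engine",
  //     },
  //   }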
  /**
   * Collects information about the current status.
   */
  async getStatus() {
    const statusMap = new Map();
    for (const [key, value] of this.#engineStatuses) {
      if (this.#engineDispatchers.has(key)) {
        statusMap.set(key, this.#engineDispatchers.get(key).getStatus());
      } else {
        // The engine is probably being created.
        statusMap.set(key, { status: value });
      }
    }
    return statusMap;
  }

  /**
   * @param {PipelineOptions} pipelineOptions - The options that we want to safely override.
   * @returns {object} - The updated pipeline options.
   */
  getUpdatedPipelineOptions(pipelineOptions) {
    const overrideOptionsByFeature = JSON.parse(lazy.PIPELINE_OVERRIDE_OPTIONS);
    const overrideOptions = {};
    if (overrideOptionsByFeature.hasOwnProperty(pipelineOptions.featureId)) {
      for (let key of Object.keys(
        overrideOptionsByFeature[pipelineOptions.featureId]
      )) {
        if (SAFE_OVERRIDE_OPTIONS.includes(key)) {
          overrideOptions[key] =
            overrideOptionsByFeature[pipelineOptions.featureId][key];
        }
      }
    }
    return { ...pipelineOptions, ...overrideOptions };
  }
}

/**
 * This class manages the lifecycle of an ML Engine, and handles dispatching messages
 * to it.
 */
class EngineDispatcher {
  /** @type {MessagePort | null} */
  #port = null;

  /** @type {TimeoutID | null} */
  #keepAliveTimeout = null;

  /** @type {PromiseWithResolvers | null} */
  #modelRequest = null;

  /** @type {Promise<InferenceEngine> | InferenceEngine | null} */
  #engine = null;

  /** @type {string} */
  #taskName;

  /** @type {string} */
  #featureId;

  /** @type {string} */
  #engineId;

  /** @type {PipelineOptions | null} */
  pipelineOptions = null;

  /** @type {string} */
  #status;
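  // A sketch of the values #status moves through, derived from the
  // assignments in this class (not an enforced state machine):
  //
  //   CREATED -> READY -> RUNNING <-> IDLING -> TERMINATING -> TERMINATED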
  /**
   * Creates the inference engine given the wasm runtime and the run options.
   *
   * The initialization is done in three steps:
   * 1. The wasm runtime is fetched from RemoteSettings.
   * 2. The inference options are fetched from RemoteSettings and augmented with the pipeline options.
   * 3. The inference engine is created with the wasm runtime and the options.
   *
   * Any exception here will be bubbled up for the constructor to log.
   *
   * @param {PipelineOptions} pipelineOptions
   * @param {?function(ProgressAndStatusCallbackParams):void} notificationsCallback The callback to call for updating about notifications such as download progress status.
   * @returns {Promise<InferenceEngine>}
   */
  async initializeInferenceEngine(pipelineOptions, notificationsCallback) {
    let remoteSettingsOptions = await this.mlEngineChild.getInferenceOptions(
      this.#featureId,
      this.#taskName,
      pipelineOptions.modelId ?? null
    );

    // Merge the RemoteSettings inference options with the pipeline options provided.
    let mergedOptions = new lazy.PipelineOptions(remoteSettingsOptions);
    mergedOptions.updateOptions(pipelineOptions);

    // If the merged options don't have a modelId and we have a default modelId, we set it.
    if (!mergedOptions.modelId) {
      const defaultModelEntry = lazy.DEFAULT_MODELS[this.#taskName];
      if (defaultModelEntry) {
        lazy.console.debug(
          `Using default model ${defaultModelEntry.modelId} for task ${this.#taskName}`
        );
        mergedOptions.updateOptions(defaultModelEntry);
      } else {
        throw new Error(`No default model found for task ${this.#taskName}`);
      }
    }

    lazy.console.debug("Inference engine options:", mergedOptions);
    this.pipelineOptions = mergedOptions;

    // Load the wasm if required.
    let wasm = null;
    if (lazy.WASM_BACKENDS.includes(pipelineOptions.backend || "onnx")) {
      wasm = await this.mlEngineChild.getWasmArrayBuffer(
        pipelineOptions.backend
      );
    }

    const workerConfig = await this.mlEngineChild.getWorkerConfig();

    return InferenceEngine.create({
      workerUrl: workerConfig.url,
      workerOptions: workerConfig.options,
      wasm,
      pipelineOptions: mergedOptions,
      notificationsCallback,
      getModelFileFn: this.mlEngineChild.getModelFile.bind(this.mlEngineChild),
      notifyModelDownloadCompleteFn:
        this.mlEngineChild.notifyModelDownloadComplete.bind(this.mlEngineChild),
    });
  }

  /**
   * Constructor for an Engine Dispatcher.
   *
   * @param {MLEngineChild} mlEngineChild
   * @param {MessagePort} port
   * @param {PipelineOptions} pipelineOptions
   */
  constructor(mlEngineChild, port, pipelineOptions) {
    this.#status = "CREATED";
    this.mlEngineChild = mlEngineChild;
    this.#featureId = pipelineOptions.featureId;
    this.#taskName = pipelineOptions.taskName;
    this.timeoutMS = pipelineOptions.timeoutMS;
    this.#engineId = pipelineOptions.engineId;

    this.#engine = this.initializeInferenceEngine(
      pipelineOptions,
      notificationsData => {
        this.handleInitProgressStatus(port, notificationsData);
      }
    );

    // Trigger the keep alive timer.
    this.#engine
      .then(() => void this.keepAlive())
      .catch(error => {
        if (
          // Ignore errors from tests intentionally causing errors.
          !error?.message?.startsWith("Intentionally")
        ) {
          lazy.console.error("Could not initialize the engine", error);
        }
      });

    this.#setupMessageHandler(port);
  }

  /**
   * Returns the status of the engine.
   */
  getStatus() {
    return {
      status: this.#status,
      options: this.pipelineOptions,
      engineId: this.#engineId,
    };
  }

  /**
   * Resolves the engine to fully initialize it.
   */
  async ensureInferenceEngineIsReady() {
    this.#engine = await this.#engine;
    this.#status = "READY";
  }

  handleInitProgressStatus(port, notificationsData) {
    port.postMessage({
      type: "EnginePort:InitProgress",
      statusResponse: notificationsData,
    });
  }

  /**
   * The worker will be shut down automatically after some amount of time of not being used, unless:
   *
   * - timeoutMS is set to -1
   */
  keepAlive() {
    if (this.#keepAliveTimeout) {
      // Clear any previous timeout.
      lazy.clearTimeout(this.#keepAliveTimeout);
    }
    if (this.timeoutMS >= 0) {
      this.#keepAliveTimeout = lazy.setTimeout(
        this.terminate.bind(
          this,
          /* shutDownIfEmpty */ true,
          /* replacement */ false
        ),
        this.timeoutMS
      );
    } else {
      this.#keepAliveTimeout = null;
    }
  }
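  // For example, a dispatcher created with the hypothetical option
  // { timeoutMS: 5000 } is terminated after five idle seconds, while
  // { timeoutMS: -1 } skips the timer entirely and keeps the engine alive
  // until it is terminated explicitly.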
  /**
   * @param {MessagePort} port
   */
  getModel(port) {
    if (this.#modelRequest) {
      // There could be a race to get a model, use the first request.
      return this.#modelRequest.promise;
    }
    this.#modelRequest = Promise.withResolvers();
    port.postMessage({ type: "EnginePort:ModelRequest" });
    return this.#modelRequest.promise;
  }

  /**
   * @param {MessagePort} port
   */
  #setupMessageHandler(port) {
    this.#port = port;
    port.onmessage = async ({ data }) => {
      switch (data.type) {
        case "EnginePort:Discard": {
          port.close();
          this.#port = null;
          break;
        }
        case "EnginePort:Terminate": {
          await this.terminate(data.shutdown, data.replacement);
          break;
        }
        case "EnginePort:ModelResponse": {
          if (this.#modelRequest) {
            const { model, error } = data;
            if (model) {
              this.#modelRequest.resolve(model);
            } else {
              this.#modelRequest.reject(error);
            }
            this.#modelRequest = null;
          } else {
            lazy.console.error(
              "Got an EnginePort:ModelResponse but no model resolvers"
            );
          }
          break;
        }
        case "EnginePort:Run": {
          const { requestId, request, engineRunOptions } = data;
          try {
            await this.ensureInferenceEngineIsReady();
          } catch (error) {
            port.postMessage({
              type: "EnginePort:RunResponse",
              requestId,
              response: null,
              error,
            });
            // The engine failed to load. Terminate the entire dispatcher.
            await this.terminate(
              /* shutDownIfEmpty */ true,
              /* replacement */ false
            );
            return;
          }

          // Do not run the keepAlive timer until we are certain that the engine loaded,
          // as the engine shouldn't be killed while it is initializing.
          this.keepAlive();

          this.#status = "RUNNING";
          try {
            port.postMessage({
              type: "EnginePort:RunResponse",
              requestId,
              response: await this.#engine.run(
                request,
                requestId,
                engineRunOptions
              ),
              error: null,
            });
          } catch (error) {
            port.postMessage({
              type: "EnginePort:RunResponse",
              requestId,
              response: null,
              error,
            });
          }
          this.#status = "IDLING";
          break;
        }
        default:
          lazy.console.error("Unknown port message to engine: ", data);
          break;
      }
    };
  }

  /**
   * Terminates the engine and its worker after a timeout.
   *
   * @param {boolean} shutDownIfEmpty - If true, shuts down the engine process if no engines remain.
   * @param {boolean} replacement - Flag indicating whether the engine is being replaced.
   */
  async terminate(shutDownIfEmpty, replacement) {
    if (this.#keepAliveTimeout) {
      lazy.clearTimeout(this.#keepAliveTimeout);
      this.#keepAliveTimeout = null;
    }

    if (this.#port) {
      // This call will trigger back an EnginePort:Discard that will close the port.
      this.#port.postMessage({ type: "EnginePort:EngineTerminated" });
    }

    this.#status = "TERMINATING";
    try {
      const engine = await this.#engine;
      engine.terminate();
    } catch (error) {
      lazy.console.error("Failed to get the engine", error);
    }
    this.#status = "TERMINATED";

    this.mlEngineChild.removeEngine(
      this.#engineId,
      shutDownIfEmpty,
      replacement
    );
  }
}
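// A sketch of one EnginePort round trip handled by the dispatcher above. The
// message type names are the real ones; the payload values are illustrative:
//
//   parent -> child: { type: "EnginePort:Run", requestId: 1, request, engineRunOptions }
//   child -> parent: { type: "EnginePort:RunResponse", requestId: 1, response, error: null }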
/**
 * Wrapper for a function that fetches a model file from a specified URL and task name.
 *
 * @param {object} config
 * @param {string} config.engineId - The engine id - defaults to "default-engine".
 * @param {string} config.taskName - The name of the inference task.
 * @param {string} config.url - The URL of the model file to fetch. Can be a path relative to
 * the model hub root or an absolute URL.
 * @param {string} config.modelHubRootUrl - The root url of the model hub. When not provided, uses the default from prefs.
 * @param {string} config.modelHubUrlTemplate - The url template of the model hub. When not provided, uses the default from prefs.
 * @param {?function(object):Promise<[string, object]>} config.getModelFileFn - A function that actually retrieves the model and headers.
 * @param {string} config.featureId - The feature id.
 * @param {string} config.sessionId - Shared across the same session.
 * @param {object} config.telemetryData - Additional telemetry data.
 * @returns {Promise} A promise that resolves to a Meta object containing the URL, response headers,
 * and model path.
 */
async function getModelFile({
  engineId,
  taskName,
  url,
  getModelFileFn,
  modelHubRootUrl,
  modelHubUrlTemplate,
  featureId,
  sessionId,
  telemetryData,
}) {
  const [data, headers] = await getModelFileFn({
    engineId: engineId || lazy.DEFAULT_ENGINE_ID,
    taskName,
    url,
    rootUrl: modelHubRootUrl || lazy.MODEL_HUB_ROOT_URL,
    urlTemplate: modelHubUrlTemplate || lazy.MODEL_HUB_URL_TEMPLATE,
    featureId,
    sessionId,
    telemetryData,
  });
  return new lazy.BasePromiseWorker.Meta([url, headers, data], {});
}

/**
 * Wrapper around the ChromeWorker that runs the inference.
 */
class InferenceEngine {
  /** @type {BasePromiseWorker | null} */
  #worker;

  /**
   * Initialize the worker.
   *
   * @param {object} config
   * @param {string} config.workerUrl - The url of the worker.
   * @param {object} config.workerOptions - The options to pass to BasePromiseWorker.
   * @param {ArrayBuffer} config.wasm
   * @param {PipelineOptions} config.pipelineOptions
   * @param {?function(ProgressAndStatusCallbackParams):void} config.notificationsCallback The callback to call for updating about notifications such as download progress status.
   * @param {?function(object):Promise<[string, object]>} config.getModelFileFn - A function that actually retrieves the model and headers.
   * @param {?function(object):Promise<void>} config.notifyModelDownloadCompleteFn - A function to notify that all files needing downloads are completed.
   * @returns {Promise<InferenceEngine>}
   */
  static async create({
    workerUrl,
    workerOptions,
    wasm,
    pipelineOptions,
    notificationsCallback,
    getModelFileFn,
    notifyModelDownloadCompleteFn,
  }) {
    // Check for the numThreads value. If it's not set, use the best value for
    // the platform, which is the number of physical cores.
    pipelineOptions.numThreads =
      pipelineOptions.numThreads || lazy.mlUtils.getOptimalCPUConcurrency();

    /** @type {BasePromiseWorker} */
    const worker = new lazy.BasePromiseWorker(workerUrl, workerOptions, {
      getModelFile: async (url, sessionId = "") =>
        getModelFile({
          engineId: pipelineOptions.engineId,
          url,
          taskName: pipelineOptions.taskName,
          getModelFileFn,
          modelHubRootUrl: pipelineOptions.modelHubRootUrl,
          modelHubUrlTemplate: pipelineOptions.modelHubUrlTemplate,
          featureId: pipelineOptions.featureId,
          sessionId,
          // We have model and revision values that are parsed from the url.
          // However, we want to save in telemetry the ones that are configured
          // for the pipeline. This allows consistent reporting regardless of
          // how the backend constructs the url.
          telemetryData: {
            modelId: pipelineOptions.modelId,
            modelRevision: pipelineOptions.modelRevision,
          },
        }),
      onInferenceProgress: notificationsCallback,
      notifyModelDownloadComplete: async (sessionId = "") =>
        notifyModelDownloadCompleteFn({
          sessionId,
          featureId: pipelineOptions.featureId,
          engineId: pipelineOptions.engineId,
          modelId: pipelineOptions.modelId,
          modelRevision: pipelineOptions.modelRevision,
        }),
    });

    const args = [wasm, pipelineOptions];
    const closure = {};
    // Transfer the WASM buffer to the worker instead of copying it.
    const transferables = wasm instanceof ArrayBuffer ? [wasm] : [];

    await worker.post("initializeEngine", args, closure, transferables);
    return new InferenceEngine(worker);
  }

  /**
   * @param {BasePromiseWorker} worker
   */
  constructor(worker) {
    this.#worker = worker;
  }
  /**
   * @param {string} request
   * @param {string} requestId - The identifier used to internally track this request.
   * @param {object} engineRunOptions - Additional run options for the engine.
   * @param {boolean} engineRunOptions.enableInferenceProgress - Whether to enable inference progress.
   * @returns {Promise}
   */
  run(request, requestId, engineRunOptions) {
    return this.#worker.post("run", [request, requestId, engineRunOptions]);
  }

  terminate() {
    if (this.#worker) {
      this.#worker.terminate();
      this.#worker = null;
    }
  }
}
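// A minimal usage sketch of InferenceEngine, with illustrative argument values
// (the real caller is the EngineDispatcher above):
//
//   const engine = await InferenceEngine.create({
//     workerUrl, workerOptions, wasm, pipelineOptions,
//     notificationsCallback, getModelFileFn, notifyModelDownloadCompleteFn,
//   });
//   const response = await engine.run(request, requestId, {
//     enableInferenceProgress: false,
//   });
//   engine.terminate();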