diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-06-12 05:43:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-06-12 05:43:14 +0000 |
commit | 8dd16259287f58f9273002717ec4d27e97127719 (patch) | |
tree | 3863e62a53829a84037444beab3abd4ed9dfc7d0 /toolkit/components/ml/actors | |
parent | Releasing progress-linux version 126.0.1-1~progress7.99u1. (diff) | |
download | firefox-8dd16259287f58f9273002717ec4d27e97127719.tar.xz firefox-8dd16259287f58f9273002717ec4d27e97127719.zip |
Merging upstream version 127.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'toolkit/components/ml/actors')
-rw-r--r-- | toolkit/components/ml/actors/MLEngineChild.sys.mjs | 156 | ||||
-rw-r--r-- | toolkit/components/ml/actors/MLEngineParent.sys.mjs | 175 |
2 files changed, 233 insertions, 98 deletions
diff --git a/toolkit/components/ml/actors/MLEngineChild.sys.mjs b/toolkit/components/ml/actors/MLEngineChild.sys.mjs index 925ce59266..17a8b3511a 100644 --- a/toolkit/components/ml/actors/MLEngineChild.sys.mjs +++ b/toolkit/components/ml/actors/MLEngineChild.sys.mjs @@ -21,6 +21,8 @@ ChromeUtils.defineESModuleGetters(lazy, { BasePromiseWorker: "resource://gre/modules/PromiseWorker.sys.mjs", setTimeout: "resource://gre/modules/Timer.sys.mjs", clearTimeout: "resource://gre/modules/Timer.sys.mjs", + ModelHub: "chrome://global/content/ml/ModelHub.sys.mjs", + PipelineOptions: "chrome://global/content/ml/EngineProcess.sys.mjs", }); ChromeUtils.defineLazyGetter(lazy, "console", () => { @@ -32,9 +34,20 @@ ChromeUtils.defineLazyGetter(lazy, "console", () => { XPCOMUtils.defineLazyPreferenceGetter( lazy, - "loggingLevel", - "browser.ml.logLevel" + "CACHE_TIMEOUT_MS", + "browser.ml.modelCacheTimeout" ); +XPCOMUtils.defineLazyPreferenceGetter( + lazy, + "MODEL_HUB_ROOT_URL", + "browser.ml.modelHubRootUrl" +); +XPCOMUtils.defineLazyPreferenceGetter( + lazy, + "MODEL_HUB_URL_TEMPLATE", + "browser.ml.modelHubUrlTemplate" +); +XPCOMUtils.defineLazyPreferenceGetter(lazy, "LOG_LEVEL", "browser.ml.logLevel"); /** * The engine child is responsible for the life cycle and instantiation of the local @@ -52,10 +65,21 @@ export class MLEngineChild extends JSWindowActorChild { async receiveMessage({ name, data }) { switch (name) { case "MLEngine:NewPort": { - const { engineName, port, timeoutMS } = data; + const { port, pipelineOptions } = data; + + // Override some options using prefs + let options = new lazy.PipelineOptions(pipelineOptions); + + options.updateOptions({ + modelHubRootUrl: lazy.MODEL_HUB_ROOT_URL, + modelHubUrlTemplate: lazy.MODEL_HUB_URL_TEMPLATE, + timeoutMS: lazy.CACHE_TIMEOUT_MS, + logLevel: lazy.LOG_LEVEL, + }); + this.#engineDispatchers.set( - engineName, - new EngineDispatcher(this, port, engineName, timeoutMS) + options.taskName, + new EngineDispatcher(this, port, options) ); break; } @@ -78,13 +102,24 @@ export class MLEngineChild extends JSWindowActorChild { } /** - * @returns {ArrayBuffer} + * Gets the wasm array buffer from RemoteSettings. + * + * @returns {Promise<ArrayBuffer>} */ getWasmArrayBuffer() { return this.sendQuery("MLEngine:GetWasmArrayBuffer"); } /** + * Gets the inference options from RemoteSettings. + * + * @returns {Promise<object>} + */ + getInferenceOptions(taskName) { + return this.sendQuery(`MLEngine:GetInferenceOptions:${taskName}`); + } + + /** * @param {string} engineName */ removeEngine(engineName) { @@ -113,28 +148,45 @@ class EngineDispatcher { #engine = null; /** @type {string} */ - #engineName; + #taskName; + + /** Creates the inference engine given the wasm runtime and the run options. + * + * The initialization is done in three steps: + * 1. The wasm runtime is fetched from RS + * 2. The inference options are fetched from RS and augmented with the pipeline options. + * 3. The inference engine is created with the wasm runtime and the options. + * + * Any exception here will be bubbled up for the constructor to log. + * + * @param {PipelineOptions} pipelineOptions + * @returns {Promise<Engine>} + */ + async initializeInferenceEngine(pipelineOptions) { + // Create the inference engine given the wasm runtime and the options. + const wasm = await this.mlEngineChild.getWasmArrayBuffer(); + const inferenceOptions = await this.mlEngineChild.getInferenceOptions( + this.#taskName + ); + lazy.console.debug("Inference engine options:", inferenceOptions); + pipelineOptions.updateOptions(inferenceOptions); + + return InferenceEngine.create(wasm, pipelineOptions); + } /** * @param {MLEngineChild} mlEngineChild * @param {MessagePort} port - * @param {string} engineName - * @param {number} timeoutMS + * @param {PipelineOptions} pipelineOptions */ - constructor(mlEngineChild, port, engineName, timeoutMS) { - /** @type {MLEngineChild} */ + constructor(mlEngineChild, port, pipelineOptions) { this.mlEngineChild = mlEngineChild; + this.#taskName = pipelineOptions.taskName; + this.timeoutMS = pipelineOptions.timeoutMS; - /** @type {number} */ - this.timeoutMS = timeoutMS; - - this.#engineName = engineName; - - this.#engine = Promise.all([ - this.mlEngineChild.getWasmArrayBuffer(), - this.getModel(port), - ]).then(([wasm, model]) => FakeEngine.create(wasm, model)); + this.#engine = this.initializeInferenceEngine(pipelineOptions); + // Trigger the keep alive timer. this.#engine .then(() => void this.keepAlive()) .catch(error => { @@ -265,20 +317,61 @@ class EngineDispatcher { port.close(); } this.#ports = new Set(); - this.mlEngineChild.removeEngine(this.#engineName); + this.mlEngineChild.removeEngine(this.#taskName); try { const engine = await this.#engine; engine.terminate(); } catch (error) { - console.error("Failed to get the engine", error); + lazy.console.error("Failed to get the engine", error); } } } +let modelHub = null; // This will hold the ModelHub instance to reuse it. + +/** + * Retrieves a model file as an ArrayBuffer from the specified URL. + * This function normalizes the URL, extracts the organization, model name, and file path, + * then fetches the model file using the ModelHub API. The `modelHub` instance is created + * only once and reused for subsequent calls to optimize performance. + * + * @param {string} url - The URL of the model file to fetch. Can be a path relative to + * the model hub root or an absolute URL. + * @returns {Promise} A promise that resolves to a Meta object containing the URL, response headers, + * and data as an ArrayBuffer. The data is marked for transfer to avoid cloning. + */ +async function getModelFile(url) { + // Create the model hub instance if needed + if (!modelHub) { + lazy.console.debug("Creating model hub instance"); + modelHub = new lazy.ModelHub({ + rootUrl: lazy.MODEL_HUB_ROOT_URL, + urlTemplate: lazy.MODEL_HUB_URL_TEMPLATE, + }); + } + + if (url.startsWith(lazy.MODEL_HUB_ROOT_URL)) { + url = url.slice(lazy.MODEL_HUB_ROOT_URL.length); + // Make sure we get a front slash + if (!url.startsWith("/")) { + url = `/${url}`; + } + } + + // Parsing url to get model name, and file path. + // if this errors out, it will be caught in the worker + const parsedUrl = modelHub.parseUrl(url); + + let [data, headers] = await modelHub.getModelFileAsArrayBuffer(parsedUrl); + return new lazy.BasePromiseWorker.Meta([url, headers, data], { + transfers: [data], + }); +} + /** - * Fake the engine by slicing the text in half. + * Wrapper around the ChromeWorker that runs the inference. */ -class FakeEngine { +class InferenceEngine { /** @type {BasePromiseWorker} */ #worker; @@ -286,21 +379,22 @@ class FakeEngine { * Initialize the worker. * * @param {ArrayBuffer} wasm - * @param {ArrayBuffer} model - * @returns {FakeEngine} + * @param {PipelineOptions} pipelineOptions + * @returns {InferenceEngine} */ - static async create(wasm, model) { + static async create(wasm, pipelineOptions) { /** @type {BasePromiseWorker} */ const worker = new lazy.BasePromiseWorker( "chrome://global/content/ml/MLEngine.worker.mjs", - { type: "module" } + { type: "module" }, + { getModelFile } ); - const args = [wasm, model, lazy.loggingLevel]; + const args = [wasm, pipelineOptions]; const closure = {}; - const transferables = [wasm, model]; + const transferables = [wasm]; await worker.post("initializeEngine", args, closure, transferables); - return new FakeEngine(worker); + return new InferenceEngine(worker); } /** diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs index 05203e5f69..65941d9d4e 100644 --- a/toolkit/components/ml/actors/MLEngineParent.sys.mjs +++ b/toolkit/components/ml/actors/MLEngineParent.sys.mjs @@ -5,6 +5,7 @@ /** * @typedef {object} Lazy * @property {typeof console} console + * @property {typeof import("../content/Utils.sys.mjs").getRuntimeWasmFilename} getRuntimeWasmFilename * @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess * @property {typeof import("../../../../services/settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings * @property {typeof import("../../translations/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent @@ -21,16 +22,14 @@ ChromeUtils.defineLazyGetter(lazy, "console", () => { }); ChromeUtils.defineESModuleGetters(lazy, { + getRuntimeWasmFilename: "chrome://global/content/ml/Utils.sys.mjs", EngineProcess: "chrome://global/content/ml/EngineProcess.sys.mjs", RemoteSettings: "resource://services-settings/remote-settings.sys.mjs", TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs", }); -/** - * @typedef {import("../../translations/translations").WasmRecord} WasmRecord - */ - -const DEFAULT_CACHE_TIMEOUT_MS = 15_000; +const RS_RUNTIME_COLLECTION = "ml-onnx-runtime"; +const RS_INFERENCE_OPTIONS_COLLECTION = "ml-inference-options"; /** * The ML engine is in its own content process. This actor handles the @@ -40,9 +39,9 @@ export class MLEngineParent extends JSWindowActorParent { /** * The RemoteSettingsClient that downloads the wasm binaries. * - * @type {RemoteSettingsClient | null} + * @type {Record<string, RemoteSettingsClient>} */ - static #remoteClient = null; + static #remoteClients = {}; /** @type {Promise<WasmRecord> | null} */ static #wasmRecord = null; @@ -61,11 +60,11 @@ export class MLEngineParent extends JSWindowActorParent { /** * Remote settings isn't available in tests, so provide mocked responses. * - * @param {RemoteSettingsClient} remoteClient + * @param {RemoteSettingsClient} remoteClients */ - static mockRemoteSettings(remoteClient) { + static mockRemoteSettings(remoteClients) { lazy.console.log("Mocking remote settings in MLEngineParent."); - MLEngineParent.#remoteClient = remoteClient; + MLEngineParent.#remoteClients = remoteClients; MLEngineParent.#wasmRecord = null; } @@ -74,24 +73,49 @@ export class MLEngineParent extends JSWindowActorParent { */ static removeMocks() { lazy.console.log("Removing mocked remote client in MLEngineParent."); - MLEngineParent.#remoteClient = null; + MLEngineParent.#remoteClients = {}; MLEngineParent.#wasmRecord = null; } - /** - * @param {string} engineName - * @param {() => Promise<ArrayBuffer>} getModel - * @param {number} cacheTimeoutMS - How long the engine cache remains alive between - * uses, in milliseconds. In automation the engine is manually created and destroyed - * to avoid timing issues. + /** Creates a new MLEngine. + * + * @param {PipelineOptions} pipelineOptions * @returns {MLEngine} */ - getEngine(engineName, getModel, cacheTimeoutMS = DEFAULT_CACHE_TIMEOUT_MS) { - return new MLEngine(this, engineName, getModel, cacheTimeoutMS); + getEngine(pipelineOptions) { + return new MLEngine({ mlEngineParent: this, pipelineOptions }); + } + + /** Extracts the task name from the name and validates it. + * + * Throws an exception if the task name is invalid. + * + * @param {string} name + * @returns {string} + */ + nameToTaskName(name) { + // Extract taskName after the specific prefix + const taskName = name.split("MLEngine:GetInferenceOptions:")[1]; + + // Define a regular expression to verify taskName pattern (alphanumeric and underscores/dashes) + const validTaskNamePattern = /^[a-zA-Z0-9_\-]+$/; + + // Check if taskName matches the pattern + if (!validTaskNamePattern.test(taskName)) { + // Handle invalid taskName, e.g., throw an error or return null + throw new Error( + "Invalid task name. Task name should contain only alphanumeric characters and underscores/dashes." + ); + } + return taskName; } // eslint-disable-next-line consistent-return async receiveMessage({ name }) { + if (name.startsWith("MLEngine:GetInferenceOptions")) { + return MLEngineParent.getInferenceOptions(this.nameToTaskName(name)); + } + switch (name) { case "MLEngine:Ready": if (lazy.EngineProcess.resolveMLEngineParent) { @@ -112,19 +136,18 @@ export class MLEngineParent extends JSWindowActorParent { } } - /** + /** Gets the wasm file from remote settings. + * * @param {RemoteSettingsClient} client */ static async #getWasmArrayRecord(client) { - // Load the wasm binary from remote settings, if it hasn't been already. - lazy.console.log(`Getting remote wasm records.`); + const wasmFilename = lazy.getRuntimeWasmFilename(this.browsingContext); /** @type {WasmRecord[]} */ const wasmRecords = await lazy.TranslationsParent.getMaxVersionRecords( client, { - // TODO - This record needs to be created with the engine wasm payload. - filters: { name: "inference-engine" }, + filters: { name: wasmFilename }, majorVersion: MLEngineParent.WASM_MAJOR_VERSION, } ); @@ -142,20 +165,47 @@ export class MLEngineParent extends JSWindowActorParent { ); } const [record] = wasmRecords; - lazy.console.log( - `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`, - record - ); + lazy.console.log(`Using runtime ${record.name}@${record.version}`, record); return record; } + /** Gets the inference options from remote settings given a task name. + * + * @type {string} taskName - name of the inference :wtask + * @returns {Promise<ModelRevisionRecord>} + */ + static async getInferenceOptions(taskName) { + const client = MLEngineParent.#getRemoteClient( + RS_INFERENCE_OPTIONS_COLLECTION + ); + const records = await client.get({ + filters: { + taskName, + }, + }); + + if (records.length === 0) { + throw new Error(`No inference options found for task ${taskName}`); + } + const options = records[0]; + return { + modelRevision: options.modelRevision, + modelId: options.modelId, + tokenizerRevision: options.tokenizerRevision, + tokenizerId: options.tokenizerId, + processorRevision: options.processorRevision, + processorId: options.processorId, + runtimeFilename: lazy.getRuntimeWasmFilename(this.browsingContext), + }; + } + /** * Download the wasm for the ML inference engine. * * @returns {Promise<ArrayBuffer>} */ static async getWasmArrayBuffer() { - const client = MLEngineParent.#getRemoteClient(); + const client = MLEngineParent.#getRemoteClient(RS_RUNTIME_COLLECTION); if (!MLEngineParent.#wasmRecord) { // Place the records into a promise to prevent any races. @@ -184,20 +234,23 @@ export class MLEngineParent extends JSWindowActorParent { /** * Lazily initializes the RemoteSettingsClient for the downloaded wasm binary data. * + * @param {string} collectionName - The name of the collection to use. * @returns {RemoteSettingsClient} */ - static #getRemoteClient() { - if (MLEngineParent.#remoteClient) { - return MLEngineParent.#remoteClient; + static #getRemoteClient(collectionName) { + if (MLEngineParent.#remoteClients[collectionName]) { + return MLEngineParent.#remoteClients[collectionName]; } /** @type {RemoteSettingsClient} */ - const client = lazy.RemoteSettings("ml-wasm"); + const client = lazy.RemoteSettings(collectionName, { + bucketName: "main", + }); - MLEngineParent.#remoteClient = client; + MLEngineParent.#remoteClients[collectionName] = client; client.on("sync", async ({ data: { created, updated, deleted } }) => { - lazy.console.log(`"sync" event for ml-wasm`, { + lazy.console.log(`"sync" event for ${collectionName}`, { created, updated, deleted, @@ -229,17 +282,6 @@ export class MLEngineParent extends JSWindowActorParent { } /** - * This contains all of the information needed to perform a translation request. - * - * @typedef {object} TranslationRequest - * @property {Node} node - * @property {string} sourceText - * @property {boolean} isHTML - * @property {Function} resolve - * @property {Function} reject - */ - -/** * The interface to communicate to an MLEngine in the parent process. The engine manages * its own lifetime, and is kept alive with a timeout. A reference to this engine can * be retained, but once idle, the engine will be destroyed. If a new request to run @@ -271,21 +313,13 @@ class MLEngine { engineStatus = "uninitialized"; /** - * @param {MLEngineParent} mlEngineParent - * @param {string} engineName - * @param {() => Promise<ArrayBuffer>} getModel - * @param {number} timeoutMS + * @param {object} config - The configuration object for the instance. + * @param {object} config.mlEngineParent - The parent machine learning engine associated with this instance. + * @param {object} config.pipelineOptions - The options for configuring the pipeline associated with this instance. */ - constructor(mlEngineParent, engineName, getModel, timeoutMS) { - /** @type {MLEngineParent} */ + constructor({ mlEngineParent, pipelineOptions }) { this.mlEngineParent = mlEngineParent; - /** @type {string} */ - this.engineName = engineName; - /** @type {() => Promise<ArrayBuffer>} */ - this.getModel = getModel; - /** @type {number} */ - this.timeoutMS = timeoutMS; - + this.pipelineOptions = pipelineOptions; this.#setupPortCommunication(); } @@ -296,14 +330,12 @@ class MLEngine { const { port1: childPort, port2: parentPort } = new MessageChannel(); const transferables = [childPort]; this.#port = parentPort; - this.#port.onmessage = this.handlePortMessage; this.mlEngineParent.sendAsyncMessage( "MLEngine:NewPort", { port: childPort, - engineName: this.engineName, - timeoutMS: this.timeoutMS, + pipelineOptions: this.pipelineOptions.getOptions(), }, transferables ); @@ -393,11 +425,20 @@ class MLEngine { const resolvers = Promise.withResolvers(); const requestId = this.#nextRequestId++; this.#requests.set(requestId, resolvers); - this.#port.postMessage({ - type: "EnginePort:Run", - requestId, - request, - }); + + let transferables = []; + if (request.data instanceof ArrayBuffer) { + transferables.push(request.data); + } + + this.#port.postMessage( + { + type: "EnginePort:Run", + requestId, + request, + }, + transferables + ); return resolvers.promise; } } |