diff options
Diffstat (limited to 'toolkit/components/ml/actors/MLEngineParent.sys.mjs')
-rw-r--r-- | toolkit/components/ml/actors/MLEngineParent.sys.mjs | 403 |
1 files changed, 403 insertions, 0 deletions
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

/**
 * @typedef {object} Lazy
 * @property {typeof console} console
 * @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess
 * @property {typeof import("../../../../services/settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings
 * @property {typeof import("../../translations/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent
 */

/** @type {Lazy} */
const lazy = {};

ChromeUtils.defineLazyGetter(lazy, "console", () => {
  return console.createInstance({
    maxLogLevelPref: "browser.ml.logLevel",
    prefix: "ML",
  });
});

ChromeUtils.defineESModuleGetters(lazy, {
  EngineProcess: "chrome://global/content/ml/EngineProcess.sys.mjs",
  RemoteSettings: "resource://services-settings/remote-settings.sys.mjs",
  TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs",
});

/**
 * @typedef {import("../../translations/translations").WasmRecord} WasmRecord
 */

const DEFAULT_CACHE_TIMEOUT_MS = 15_000;

/**
 * The ML engine is in its own content process. This actor handles the
 * marshalling of the data such as the engine payload.
 */
export class MLEngineParent extends JSWindowActorParent {
  /**
   * The RemoteSettingsClient that downloads the wasm binaries.
   *
   * @type {RemoteSettingsClient | null}
   */
  static #remoteClient = null;

  /**
   * Cached promise for the wasm record lookup. It is cleared when the lookup
   * fails, when mocks are installed or removed, and when the collection syncs.
   *
   * @type {Promise<WasmRecord> | null}
   */
  static #wasmRecord = null;

  /**
   * The following constant controls the major version for wasm downloaded from
   * Remote Settings. When a breaking change is introduced, Nightly will have these
   * numbers incremented by one, but Beta and Release will still be on the previous
   * version. Remote Settings will ship both versions of the records, and the latest
   * asset released in that version will be used. For instance, with a major version
   * of "1", assets can be downloaded for "1.0", "1.2", "1.3beta", but assets marked
   * as "2.0", "2.1", etc will not be downloaded.
   */
  static WASM_MAJOR_VERSION = 1;

  /**
   * Remote settings isn't available in tests, so provide mocked responses.
   *
   * @param {RemoteSettingsClient} remoteClient
   */
  static mockRemoteSettings(remoteClient) {
    lazy.console.log("Mocking remote settings in MLEngineParent.");
    MLEngineParent.#remoteClient = remoteClient;
    MLEngineParent.#wasmRecord = null;
  }

  /**
   * Remove anything that could have been mocked.
   */
  static removeMocks() {
    lazy.console.log("Removing mocked remote client in MLEngineParent.");
    MLEngineParent.#remoteClient = null;
    MLEngineParent.#wasmRecord = null;
  }

  /**
   * @param {string} engineName
   * @param {() => Promise<ArrayBuffer>} getModel
   * @param {number} cacheTimeoutMS - How long the engine cache remains alive between
   *   uses, in milliseconds. In automation the engine is manually created and destroyed
   *   to avoid timing issues.
   * @returns {MLEngine}
   */
  getEngine(engineName, getModel, cacheTimeoutMS = DEFAULT_CACHE_TIMEOUT_MS) {
    return new MLEngine(this, engineName, getModel, cacheTimeoutMS);
  }

  /**
   * Handles the messages sent to this actor.
   *
   * @param {object} message
   * @param {string} message.name
   * @param {any} message.data
   * @returns {Promise<ArrayBuffer | undefined>} The wasm binary for
   *   "MLEngine:GetWasmArrayBuffer", otherwise nothing.
   */
  // eslint-disable-next-line consistent-return
  async receiveMessage({ name, data }) {
    switch (name) {
      case "MLEngine:Ready":
        if (lazy.EngineProcess.resolveMLEngineParent) {
          lazy.EngineProcess.resolveMLEngineParent(this);
        } else {
          lazy.console.error(
            "Expected #resolveMLEngineParent to exist when the ML Engine is ready."
          );
        }
        break;
      case "MLEngine:GetWasmArrayBuffer":
        return MLEngineParent.getWasmArrayBuffer();
      case "MLEngine:DestroyEngineProcess":
        // Fire and forget, but route failures through the prefixed ML logger.
        lazy.EngineProcess.destroyMLEngine().catch(error =>
          lazy.console.error(error)
        );
        break;
    }
  }

  /**
   * Looks up the record that describes the inference engine wasm binary.
   *
   * @param {RemoteSettingsClient} client
   * @returns {Promise<WasmRecord>}
   * @throws {Error} When no matching record is available.
   */
  static async #getWasmArrayRecord(client) {
    // Load the wasm binary from remote settings, if it hasn't been already.
    lazy.console.log(`Getting remote wasm records.`);

    /** @type {WasmRecord[]} */
    const wasmRecords = await lazy.TranslationsParent.getMaxVersionRecords(
      client,
      {
        // TODO - This record needs to be created with the engine wasm payload.
        filters: { name: "inference-engine" },
        majorVersion: MLEngineParent.WASM_MAJOR_VERSION,
      }
    );

    if (wasmRecords.length === 0) {
      // The remote settings client provides an empty list of records when there is
      // an error.
      throw new Error("Unable to get the ML engine from Remote Settings.");
    }

    if (wasmRecords.length > 1) {
      // This is unexpected but recoverable: report it and use the first record.
      lazy.console.error(
        new Error("Expected the ml engine to only have 1 record."),
        wasmRecords
      );
    }
    const [record] = wasmRecords;
    lazy.console.log(
      `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`,
      record
    );
    return record;
  }

  /**
   * Download the wasm for the ML inference engine.
   *
   * @returns {Promise<ArrayBuffer>}
   * @throws {Error} When the record can not be retrieved. The cached lookup is
   *   cleared in that case so that a later call can retry.
   */
  static async getWasmArrayBuffer() {
    const client = MLEngineParent.#getRemoteClient();

    if (!MLEngineParent.#wasmRecord) {
      // Place the records into a promise to prevent any races.
      MLEngineParent.#wasmRecord = MLEngineParent.#getWasmArrayRecord(client);
    }

    let wasmRecord;
    try {
      wasmRecord = await MLEngineParent.#wasmRecord;
      if (!wasmRecord) {
        // Thrown inside the try block so that the cached promise gets cleared.
        throw new Error("Unable to get the ML engine from Remote Settings.");
      }
    } catch (error) {
      MLEngineParent.#wasmRecord = null;
      throw error;
    }

    /** @type {{buffer: ArrayBuffer}} */
    const { buffer } = await client.attachments.download(wasmRecord);

    return buffer;
  }

  /**
   * Lazily initializes the RemoteSettingsClient for the downloaded wasm binary data.
   *
   * @returns {RemoteSettingsClient}
   */
  static #getRemoteClient() {
    if (MLEngineParent.#remoteClient) {
      return MLEngineParent.#remoteClient;
    }

    /** @type {RemoteSettingsClient} */
    const client = lazy.RemoteSettings("ml-wasm");

    MLEngineParent.#remoteClient = client;

    client.on("sync", async ({ data: { created, updated, deleted } }) => {
      lazy.console.log(`"sync" event for ml-wasm`, {
        created,
        updated,
        deleted,
      });

      // The cached record may describe an attachment that is removed below, so
      // force the next request to look the record up again.
      MLEngineParent.#wasmRecord = null;

      // Remove all the deleted records.
      for (const record of deleted) {
        await client.attachments.deleteDownloaded(record);
      }

      // Remove the old attachment of any updated record. The new attachment is
      // downloaded the next time the wasm is requested.
      for (const { old: oldRecord } of updated) {
        await client.attachments.deleteDownloaded(oldRecord);
      }

      // Do nothing for the created records.
    });

    return client;
  }

  /**
   * Send a message to gracefully shutdown all of the ML engines in the engine process.
   * This mostly exists for testing the shutdown paths of the code.
   *
   * @returns {Promise<any>} The result of the "MLEngine:ForceShutdown" query.
   */
  forceShutdown() {
    return this.sendQuery("MLEngine:ForceShutdown");
  }
}

/**
 * This contains all of the information needed to perform a translation request.
 */
/**
 * @typedef {object} TranslationRequest
 * @property {Node} node
 * @property {string} sourceText
 * @property {boolean} isHTML
 * @property {Function} resolve
 * @property {Function} reject
 */

/**
 * The interface to communicate to an MLEngine in the parent process. The engine manages
 * its own lifetime, and is kept alive with a timeout. A reference to this engine can
 * be retained, but once idle, the engine will be destroyed. If a new request to run
 * is sent, the engine will be recreated on demand. This balances the cost of retaining
 * potentially large amounts of memory to run models, with the speed and ease of running
 * the engine.
 *
 * @template Request
 * @template Response
 */
class MLEngine {
  /**
   * The parent side of the MessageChannel, or null once the port is discarded.
   *
   * @type {MessagePort | null}
   */
  #port = null;

  #nextRequestId = 0;

  /**
   * Tie together a message id to a resolved response.
   *
   * @type {Map<number, PromiseWithResolvers<Response>>}
   */
  #requests = new Map();

  /**
   * @type {"uninitialized" | "ready" | "error" | "closed"}
   */
  engineStatus = "uninitialized";

  /**
   * @param {MLEngineParent} mlEngineParent
   * @param {string} engineName
   * @param {() => Promise<ArrayBuffer>} getModel
   * @param {number} timeoutMS
   */
  constructor(mlEngineParent, engineName, getModel, timeoutMS) {
    /** @type {MLEngineParent} */
    this.mlEngineParent = mlEngineParent;
    /** @type {string} */
    this.engineName = engineName;
    /** @type {() => Promise<ArrayBuffer>} */
    this.getModel = getModel;
    /** @type {number} */
    this.timeoutMS = timeoutMS;

    this.#setupPortCommunication();
  }

  /**
   * Create a MessageChannel to communicate with the engine directly. One port is
   * kept here, the other is sent through the actor with "MLEngine:NewPort".
   */
  #setupPortCommunication() {
    const { port1: childPort, port2: parentPort } = new MessageChannel();
    const transferables = [childPort];
    this.#port = parentPort;

    this.#port.onmessage = this.handlePortMessage;
    this.mlEngineParent.sendAsyncMessage(
      "MLEngine:NewPort",
      {
        port: childPort,
        engineName: this.engineName,
        timeoutMS: this.timeoutMS,
      },
      transferables
    );
  }

  /**
   * Reject every request that has not received a response yet.
   *
   * @param {Error} error
   */
  #rejectPendingRequests(error) {
    for (const request of this.#requests.values()) {
      request.reject(error);
    }
    this.#requests.clear();
  }

  /**
   * Handles the messages that are received on the port.
   *
   * @param {{ data: { type: string } }} message
   */
  handlePortMessage = ({ data }) => {
    switch (data.type) {
      case "EnginePort:ModelRequest": {
        // Hold on to the port, as it may be discarded while the model is loading.
        const port = this.#port;
        if (!port) {
          lazy.console.error(
            "Expected a port to exist during the EnginePort:GetModel event"
          );
          break;
        }
        this.getModel().then(
          model => {
            port.postMessage({
              type: "EnginePort:ModelResponse",
              model,
              error: null,
            });
          },
          error => {
            port.postMessage({
              type: "EnginePort:ModelResponse",
              model: null,
              error,
            });
            if (
              // Ignore intentional errors in tests. The rejection value is not
              // guaranteed to have a message.
              !error?.message?.startsWith("Intentionally")
            ) {
              lazy.console.error("Failed to get the model", error);
            }
          }
        );
        break;
      }
      case "EnginePort:RunResponse": {
        const { response, error, requestId } = data;
        const request = this.#requests.get(requestId);
        if (!request) {
          lazy.console.error(
            "Could not resolve response in the MLEngineParent",
            data
          );
          break;
        }
        this.#requests.delete(requestId);
        // Only a missing response is a failure; 0, "" and false are valid results.
        if (response !== null && response !== undefined) {
          request.resolve(response);
        } else {
          request.reject(error);
        }
        break;
      }
      case "EnginePort:EngineTerminated": {
        // The engine was terminated, and if a new run is needed a new port
        // will need to be requested.
        this.engineStatus = "closed";
        this.discardPort();
        break;
      }
      default:
        lazy.console.error("Unknown port message from engine", data);
        break;
    }
  };

  /**
   * Tell the engine that the port is going away and close it. Requests that are
   * still waiting for a response are rejected, as no response can arrive anymore.
   */
  discardPort() {
    if (this.#port) {
      this.#port.postMessage({ type: "EnginePort:Discard" });
      this.#port.close();
      this.#port = null;
    }
    this.#rejectPendingRequests(
      new Error(
        "The ML engine port was discarded before a response was received."
      )
    );
  }

  /**
   * Ask the engine to terminate. This does nothing when the port has already
   * been discarded.
   */
  terminate() {
    this.#port?.postMessage({ type: "EnginePort:Terminate" });
  }

  /**
   * Send a request to the engine. A new port is requested first when the previous
   * one has been discarded.
   *
   * @param {Request} request
   * @returns {Promise<Response>}
   */
  run(request) {
    if (!this.#port) {
      this.engineStatus = "uninitialized";
      this.#setupPortCommunication();
    }
    const resolvers = Promise.withResolvers();
    const requestId = this.#nextRequestId++;
    this.#requests.set(requestId, resolvers);
    this.#port.postMessage({
      type: "EnginePort:Run",
      requestId,
      request,
    });
    return resolvers.promise;
  }
}