diff options
Diffstat (limited to 'toolkit/components/ml/actors')
-rw-r--r-- | toolkit/components/ml/actors/MLEngineChild.sys.mjs | 325 | ||||
-rw-r--r-- | toolkit/components/ml/actors/MLEngineParent.sys.mjs | 403 | ||||
-rw-r--r-- | toolkit/components/ml/actors/moz.build | 8 |
3 files changed, 736 insertions, 0 deletions
diff --git a/toolkit/components/ml/actors/MLEngineChild.sys.mjs b/toolkit/components/ml/actors/MLEngineChild.sys.mjs new file mode 100644 index 0000000000..925ce59266 --- /dev/null +++ b/toolkit/components/ml/actors/MLEngineChild.sys.mjs @@ -0,0 +1,325 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs"; + +/** + * @typedef {import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker + */ + +/** + * @typedef {object} Lazy + * @property {typeof import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker + * @property {typeof setTimeout} setTimeout + * @property {typeof clearTimeout} clearTimeout + */ + +/** @type {Lazy} */ +const lazy = {}; +ChromeUtils.defineESModuleGetters(lazy, { + BasePromiseWorker: "resource://gre/modules/PromiseWorker.sys.mjs", + setTimeout: "resource://gre/modules/Timer.sys.mjs", + clearTimeout: "resource://gre/modules/Timer.sys.mjs", +}); + +ChromeUtils.defineLazyGetter(lazy, "console", () => { + return console.createInstance({ + maxLogLevelPref: "browser.ml.logLevel", + prefix: "ML", + }); +}); + +XPCOMUtils.defineLazyPreferenceGetter( + lazy, + "loggingLevel", + "browser.ml.logLevel" +); + +/** + * The engine child is responsible for the life cycle and instantiation of the local + * machine learning inference engine. + */ +export class MLEngineChild extends JSWindowActorChild { + /** + * The cached engines. + * + * @type {Map<string, EngineDispatcher>} + */ + #engineDispatchers = new Map(); + + // eslint-disable-next-line consistent-return + async receiveMessage({ name, data }) { + switch (name) { + case "MLEngine:NewPort": { + const { engineName, port, timeoutMS } = data; + this.#engineDispatchers.set( + engineName, + new EngineDispatcher(this, port, engineName, timeoutMS) + ); + break; + } + case "MLEngine:ForceShutdown": { + for (const engineDispatcher of this.#engineDispatchers.values()) { + return engineDispatcher.terminate(); + } + this.#engineDispatchers = null; + break; + } + } + } + + handleEvent(event) { + switch (event.type) { + case "DOMContentLoaded": + this.sendAsyncMessage("MLEngine:Ready"); + break; + } + } + + /** + * @returns {ArrayBuffer} + */ + getWasmArrayBuffer() { + return this.sendQuery("MLEngine:GetWasmArrayBuffer"); + } + + /** + * @param {string} engineName + */ + removeEngine(engineName) { + this.#engineDispatchers.delete(engineName); + if (this.#engineDispatchers.size === 0) { + this.sendQuery("MLEngine:DestroyEngineProcess"); + } + } +} + +/** + * This classes manages the lifecycle of an ML Engine, and handles dispatching messages + * to it. + */ +class EngineDispatcher { + /** @type {Set<MessagePort>} */ + #ports = new Set(); + + /** @type {TimeoutID | null} */ + #keepAliveTimeout = null; + + /** @type {PromiseWithResolvers} */ + #modelRequest; + + /** @type {Promise<Engine> | null} */ + #engine = null; + + /** @type {string} */ + #engineName; + + /** + * @param {MLEngineChild} mlEngineChild + * @param {MessagePort} port + * @param {string} engineName + * @param {number} timeoutMS + */ + constructor(mlEngineChild, port, engineName, timeoutMS) { + /** @type {MLEngineChild} */ + this.mlEngineChild = mlEngineChild; + + /** @type {number} */ + this.timeoutMS = timeoutMS; + + this.#engineName = engineName; + + this.#engine = Promise.all([ + this.mlEngineChild.getWasmArrayBuffer(), + this.getModel(port), + ]).then(([wasm, model]) => FakeEngine.create(wasm, model)); + + this.#engine + .then(() => void this.keepAlive()) + .catch(error => { + if ( + // Ignore errors from tests intentionally causing errors. + !error?.message?.startsWith("Intentionally") + ) { + lazy.console.error("Could not initalize the engine", error); + } + }); + + this.setupMessageHandler(port); + } + + /** + * The worker needs to be shutdown after some amount of time of not being used. + */ + keepAlive() { + if (this.#keepAliveTimeout) { + // Clear any previous timeout. + lazy.clearTimeout(this.#keepAliveTimeout); + } + // In automated tests, the engine is manually destroyed. + if (!Cu.isInAutomation) { + this.#keepAliveTimeout = lazy.setTimeout(this.terminate, this.timeoutMS); + } + } + + /** + * @param {MessagePort} port + */ + getModel(port) { + if (this.#modelRequest) { + // There could be a race to get a model, use the first request. + return this.#modelRequest.promise; + } + this.#modelRequest = Promise.withResolvers(); + port.postMessage({ type: "EnginePort:ModelRequest" }); + return this.#modelRequest.promise; + } + + /** + * @param {MessagePort} port + */ + setupMessageHandler(port) { + port.onmessage = async ({ data }) => { + switch (data.type) { + case "EnginePort:Discard": { + port.close(); + this.#ports.delete(port); + break; + } + case "EnginePort:Terminate": { + this.terminate(); + break; + } + case "EnginePort:ModelResponse": { + if (this.#modelRequest) { + const { model, error } = data; + if (model) { + this.#modelRequest.resolve(model); + } else { + this.#modelRequest.reject(error); + } + this.#modelRequest = null; + } else { + lazy.console.error( + "Got a EnginePort:ModelResponse but no model resolvers" + ); + } + break; + } + case "EnginePort:Run": { + const { requestId, request } = data; + let engine; + try { + engine = await this.#engine; + } catch (error) { + port.postMessage({ + type: "EnginePort:RunResponse", + requestId, + response: null, + error, + }); + // The engine failed to load. Terminate the entire dispatcher. + this.terminate(); + return; + } + + // Do not run the keepAlive timer until we are certain that the engine loaded, + // as the engine shouldn't be killed while it is initializing. + this.keepAlive(); + + try { + port.postMessage({ + type: "EnginePort:RunResponse", + requestId, + response: await engine.run(request), + error: null, + }); + } catch (error) { + port.postMessage({ + type: "EnginePort:RunResponse", + requestId, + response: null, + error, + }); + } + break; + } + default: + lazy.console.error("Unknown port message to engine: ", data); + break; + } + }; + } + + /** + * Terminates the engine and its worker after a timeout. + */ + async terminate() { + if (this.#keepAliveTimeout) { + lazy.clearTimeout(this.#keepAliveTimeout); + this.#keepAliveTimeout = null; + } + for (const port of this.#ports) { + port.postMessage({ type: "EnginePort:EngineTerminated" }); + port.close(); + } + this.#ports = new Set(); + this.mlEngineChild.removeEngine(this.#engineName); + try { + const engine = await this.#engine; + engine.terminate(); + } catch (error) { + console.error("Failed to get the engine", error); + } + } +} + +/** + * Fake the engine by slicing the text in half. + */ +class FakeEngine { + /** @type {BasePromiseWorker} */ + #worker; + + /** + * Initialize the worker. + * + * @param {ArrayBuffer} wasm + * @param {ArrayBuffer} model + * @returns {FakeEngine} + */ + static async create(wasm, model) { + /** @type {BasePromiseWorker} */ + const worker = new lazy.BasePromiseWorker( + "chrome://global/content/ml/MLEngine.worker.mjs", + { type: "module" } + ); + + const args = [wasm, model, lazy.loggingLevel]; + const closure = {}; + const transferables = [wasm, model]; + await worker.post("initializeEngine", args, closure, transferables); + return new FakeEngine(worker); + } + + /** + * @param {BasePromiseWorker} worker + */ + constructor(worker) { + this.#worker = worker; + } + + /** + * @param {string} request + * @returns {Promise<string>} + */ + run(request) { + return this.#worker.post("run", [request]); + } + + terminate() { + this.#worker.terminate(); + this.#worker = null; + } +} diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs new file mode 100644 index 0000000000..10b4eed4fa --- /dev/null +++ b/toolkit/components/ml/actors/MLEngineParent.sys.mjs @@ -0,0 +1,403 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +/** + * @typedef {object} Lazy + * @property {typeof console} console + * @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess + * @property {typeof import("../../../../services/settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings + * @property {typeof import("../../translations/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent + */ + +/** @type {Lazy} */ +const lazy = {}; + +ChromeUtils.defineLazyGetter(lazy, "console", () => { + return console.createInstance({ + maxLogLevelPref: "browser.ml.logLevel", + prefix: "ML", + }); +}); + +ChromeUtils.defineESModuleGetters(lazy, { + EngineProcess: "chrome://global/content/ml/EngineProcess.sys.mjs", + RemoteSettings: "resource://services-settings/remote-settings.sys.mjs", + TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs", +}); + +/** + * @typedef {import("../../translations/translations").WasmRecord} WasmRecord + */ + +const DEFAULT_CACHE_TIMEOUT_MS = 15_000; + +/** + * The ML engine is in its own content process. This actor handles the + * marshalling of the data such as the engine payload. + */ +export class MLEngineParent extends JSWindowActorParent { + /** + * The RemoteSettingsClient that downloads the wasm binaries. + * + * @type {RemoteSettingsClient | null} + */ + static #remoteClient = null; + + /** @type {Promise<WasmRecord> | null} */ + static #wasmRecord = null; + + /** + * The following constant controls the major version for wasm downloaded from + * Remote Settings. When a breaking change is introduced, Nightly will have these + * numbers incremented by one, but Beta and Release will still be on the previous + * version. Remote Settings will ship both versions of the records, and the latest + * asset released in that version will be used. For instance, with a major version + * of "1", assets can be downloaded for "1.0", "1.2", "1.3beta", but assets marked + * as "2.0", "2.1", etc will not be downloaded. + */ + static WASM_MAJOR_VERSION = 1; + + /** + * Remote settings isn't available in tests, so provide mocked responses. + * + * @param {RemoteSettingsClient} remoteClient + */ + static mockRemoteSettings(remoteClient) { + lazy.console.log("Mocking remote settings in MLEngineParent."); + MLEngineParent.#remoteClient = remoteClient; + MLEngineParent.#wasmRecord = null; + } + + /** + * Remove anything that could have been mocked. + */ + static removeMocks() { + lazy.console.log("Removing mocked remote client in MLEngineParent."); + MLEngineParent.#remoteClient = null; + MLEngineParent.#wasmRecord = null; + } + + /** + * @param {string} engineName + * @param {() => Promise<ArrayBuffer>} getModel + * @param {number} cacheTimeoutMS - How long the engine cache remains alive between + * uses, in milliseconds. In automation the engine is manually created and destroyed + * to avoid timing issues. + * @returns {MLEngine} + */ + getEngine(engineName, getModel, cacheTimeoutMS = DEFAULT_CACHE_TIMEOUT_MS) { + return new MLEngine(this, engineName, getModel, cacheTimeoutMS); + } + + // eslint-disable-next-line consistent-return + async receiveMessage({ name, data }) { + switch (name) { + case "MLEngine:Ready": + if (lazy.EngineProcess.resolveMLEngineParent) { + lazy.EngineProcess.resolveMLEngineParent(this); + } else { + lazy.console.error( + "Expected #resolveMLEngineParent to exist when then ML Engine is ready." + ); + } + break; + case "MLEngine:GetWasmArrayBuffer": + return MLEngineParent.getWasmArrayBuffer(); + case "MLEngine:DestroyEngineProcess": + lazy.EngineProcess.destroyMLEngine().catch(error => + console.error(error) + ); + break; + } + } + + /** + * @param {RemoteSettingsClient} client + */ + static async #getWasmArrayRecord(client) { + // Load the wasm binary from remote settings, if it hasn't been already. + lazy.console.log(`Getting remote wasm records.`); + + /** @type {WasmRecord[]} */ + const wasmRecords = await lazy.TranslationsParent.getMaxVersionRecords( + client, + { + // TODO - This record needs to be created with the engine wasm payload. + filters: { name: "inference-engine" }, + majorVersion: MLEngineParent.WASM_MAJOR_VERSION, + } + ); + + if (wasmRecords.length === 0) { + // The remote settings client provides an empty list of records when there is + // an error. + throw new Error("Unable to get the ML engine from Remote Settings."); + } + + if (wasmRecords.length > 1) { + MLEngineParent.reportError( + new Error("Expected the ml engine to only have 1 record."), + wasmRecords + ); + } + const [record] = wasmRecords; + lazy.console.log( + `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`, + record + ); + return record; + } + + /** + * Download the wasm for the ML inference engine. + * + * @returns {Promise<ArrayBuffer>} + */ + static async getWasmArrayBuffer() { + const client = MLEngineParent.#getRemoteClient(); + + if (!MLEngineParent.#wasmRecord) { + // Place the records into a promise to prevent any races. + MLEngineParent.#wasmRecord = MLEngineParent.#getWasmArrayRecord(client); + } + + let wasmRecord; + try { + wasmRecord = await MLEngineParent.#wasmRecord; + if (!wasmRecord) { + return Promise.reject( + "Error: Unable to get the ML engine from Remote Settings." + ); + } + } catch (error) { + MLEngineParent.#wasmRecord = null; + throw error; + } + + /** @type {{buffer: ArrayBuffer}} */ + const { buffer } = await client.attachments.download(wasmRecord); + + return buffer; + } + + /** + * Lazily initializes the RemoteSettingsClient for the downloaded wasm binary data. + * + * @returns {RemoteSettingsClient} + */ + static #getRemoteClient() { + if (MLEngineParent.#remoteClient) { + return MLEngineParent.#remoteClient; + } + + /** @type {RemoteSettingsClient} */ + const client = lazy.RemoteSettings("ml-wasm"); + + MLEngineParent.#remoteClient = client; + + client.on("sync", async ({ data: { created, updated, deleted } }) => { + lazy.console.log(`"sync" event for ml-wasm`, { + created, + updated, + deleted, + }); + + // Remove all the deleted records. + for (const record of deleted) { + await client.attachments.deleteDownloaded(record); + } + + // Remove any updated records, and download the new ones. + for (const { old: oldRecord } of updated) { + await client.attachments.deleteDownloaded(oldRecord); + } + + // Do nothing for the created records. + }); + + return client; + } + + /** + * Send a message to gracefully shutdown all of the ML engines in the engine process. + * This mostly exists for testing the shutdown paths of the code. + */ + forceShutdown() { + return this.sendQuery("MLEngine:ForceShutdown"); + } +} + +/** + * This contains all of the information needed to perform a translation request. + * + * @typedef {object} TranslationRequest + * @property {Node} node + * @property {string} sourceText + * @property {boolean} isHTML + * @property {Function} resolve + * @property {Function} reject + */ + +/** + * The interface to communicate to an MLEngine in the parent process. The engine manages + * its own lifetime, and is kept alive with a timeout. A reference to this engine can + * be retained, but once idle, the engine will be destroyed. If a new request to run + * is sent, the engine will be recreated on demand. This balances the cost of retaining + * potentially large amounts of memory to run models, with the speed and ease of running + * the engine. + * + * @template Request + * @template Response + */ +class MLEngine { + /** + * @type {MessagePort | null} + */ + #port = null; + + #nextRequestId = 0; + + /** + * Tie together a message id to a resolved response. + * + * @type {Map<number, PromiseWithResolvers<Request>>} + */ + #requests = new Map(); + + /** + * @type {"uninitialized" | "ready" | "error" | "closed"} + */ + engineStatus = "uninitialized"; + + /** + * @param {MLEngineParent} mlEngineParent + * @param {string} engineName + * @param {() => Promise<ArrayBuffer>} getModel + * @param {number} timeoutMS + */ + constructor(mlEngineParent, engineName, getModel, timeoutMS) { + /** @type {MLEngineParent} */ + this.mlEngineParent = mlEngineParent; + /** @type {string} */ + this.engineName = engineName; + /** @type {() => Promise<ArrayBuffer>} */ + this.getModel = getModel; + /** @type {number} */ + this.timeoutMS = timeoutMS; + + this.#setupPortCommunication(); + } + + /** + * Create a MessageChannel to communicate with the engine directly. + */ + #setupPortCommunication() { + const { port1: childPort, port2: parentPort } = new MessageChannel(); + const transferables = [childPort]; + this.#port = parentPort; + + this.#port.onmessage = this.handlePortMessage; + this.mlEngineParent.sendAsyncMessage( + "MLEngine:NewPort", + { + port: childPort, + engineName: this.engineName, + timeoutMS: this.timeoutMS, + }, + transferables + ); + } + + handlePortMessage = ({ data }) => { + switch (data.type) { + case "EnginePort:ModelRequest": { + if (this.#port) { + this.getModel().then( + model => { + this.#port.postMessage({ + type: "EnginePort:ModelResponse", + model, + error: null, + }); + }, + error => { + this.#port.postMessage({ + type: "EnginePort:ModelResponse", + model: null, + error, + }); + if ( + // Ignore intentional errors in tests. + !error?.message.startsWith("Intentionally") + ) { + lazy.console.error("Failed to get the model", error); + } + } + ); + } else { + lazy.console.error( + "Expected a port to exist during the EnginePort:GetModel event" + ); + } + break; + } + case "EnginePort:RunResponse": { + const { response, error, requestId } = data; + const request = this.#requests.get(requestId); + if (request) { + if (response) { + request.resolve(response); + } else { + request.reject(error); + } + } else { + lazy.console.error( + "Could not resolve response in the MLEngineParent", + data + ); + } + this.#requests.delete(requestId); + break; + } + case "EnginePort:EngineTerminated": { + // The engine was terminated, and if a new run is needed a new port + // will need to be requested. + this.engineStatus = "closed"; + this.discardPort(); + break; + } + default: + lazy.console.error("Unknown port message from engine", data); + break; + } + }; + + discardPort() { + if (this.#port) { + this.#port.postMessage({ type: "EnginePort:Discard" }); + this.#port.close(); + this.#port = null; + } + } + + terminate() { + this.#port.postMessage({ type: "EnginePort:Terminate" }); + } + + /** + * @param {Request} request + * @returns {Promise<Response>} + */ + run(request) { + const resolvers = Promise.withResolvers(); + const requestId = this.#nextRequestId++; + this.#requests.set(requestId, resolvers); + this.#port.postMessage({ + type: "EnginePort:Run", + requestId, + request, + }); + return resolvers.promise; + } +} diff --git a/toolkit/components/ml/actors/moz.build b/toolkit/components/ml/actors/moz.build new file mode 100644 index 0000000000..de3e27ae2a --- /dev/null +++ b/toolkit/components/ml/actors/moz.build @@ -0,0 +1,8 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +FINAL_TARGET_FILES.actors += [ + "MLEngineChild.sys.mjs", + "MLEngineParent.sys.mjs", +] |