summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/actors/MLEngineParent.sys.mjs
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/ml/actors/MLEngineParent.sys.mjs')
-rw-r--r--toolkit/components/ml/actors/MLEngineParent.sys.mjs403
1 files changed, 403 insertions, 0 deletions
diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
new file mode 100644
index 0000000000..10b4eed4fa
--- /dev/null
+++ b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
@@ -0,0 +1,403 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/**
+ * @typedef {object} Lazy
+ * @property {typeof console} console
+ * @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess
+ * @property {typeof import("../../../../services/settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings
+ * @property {typeof import("../../translations/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent
+ */
+
+/** @type {Lazy} */
+const lazy = {};
+
+ChromeUtils.defineLazyGetter(lazy, "console", () => {
+ return console.createInstance({
+ maxLogLevelPref: "browser.ml.logLevel",
+ prefix: "ML",
+ });
+});
+
+ChromeUtils.defineESModuleGetters(lazy, {
+ EngineProcess: "chrome://global/content/ml/EngineProcess.sys.mjs",
+ RemoteSettings: "resource://services-settings/remote-settings.sys.mjs",
+ TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs",
+});
+
+/**
+ * @typedef {import("../../translations/translations").WasmRecord} WasmRecord
+ */
+
+const DEFAULT_CACHE_TIMEOUT_MS = 15_000;
+
+/**
+ * The ML engine is in its own content process. This actor handles the
+ * marshalling of the data such as the engine payload.
+ */
+export class MLEngineParent extends JSWindowActorParent {
+ /**
+ * The RemoteSettingsClient that downloads the wasm binaries.
+ *
+ * @type {RemoteSettingsClient | null}
+ */
+ static #remoteClient = null;
+
+ /** @type {Promise<WasmRecord> | null} */
+ static #wasmRecord = null;
+
+ /**
+ * The following constant controls the major version for wasm downloaded from
+ * Remote Settings. When a breaking change is introduced, Nightly will have these
+ * numbers incremented by one, but Beta and Release will still be on the previous
+ * version. Remote Settings will ship both versions of the records, and the latest
+ * asset released in that version will be used. For instance, with a major version
+ * of "1", assets can be downloaded for "1.0", "1.2", "1.3beta", but assets marked
+ * as "2.0", "2.1", etc will not be downloaded.
+ */
+ static WASM_MAJOR_VERSION = 1;
+
+ /**
+ * Remote settings isn't available in tests, so provide mocked responses.
+ *
+ * @param {RemoteSettingsClient} remoteClient
+ */
+ static mockRemoteSettings(remoteClient) {
+ lazy.console.log("Mocking remote settings in MLEngineParent.");
+ MLEngineParent.#remoteClient = remoteClient;
+ MLEngineParent.#wasmRecord = null;
+ }
+
+ /**
+ * Remove anything that could have been mocked.
+ */
+ static removeMocks() {
+ lazy.console.log("Removing mocked remote client in MLEngineParent.");
+ MLEngineParent.#remoteClient = null;
+ MLEngineParent.#wasmRecord = null;
+ }
+
+ /**
+ * @param {string} engineName
+ * @param {() => Promise<ArrayBuffer>} getModel
+ * @param {number} cacheTimeoutMS - How long the engine cache remains alive between
+ * uses, in milliseconds. In automation the engine is manually created and destroyed
+ * to avoid timing issues.
+ * @returns {MLEngine}
+ */
+ getEngine(engineName, getModel, cacheTimeoutMS = DEFAULT_CACHE_TIMEOUT_MS) {
+ return new MLEngine(this, engineName, getModel, cacheTimeoutMS);
+ }
+
+ // eslint-disable-next-line consistent-return
+ async receiveMessage({ name, data }) {
+ switch (name) {
+ case "MLEngine:Ready":
+ if (lazy.EngineProcess.resolveMLEngineParent) {
+ lazy.EngineProcess.resolveMLEngineParent(this);
+ } else {
+ lazy.console.error(
+ "Expected #resolveMLEngineParent to exist when then ML Engine is ready."
+ );
+ }
+ break;
+ case "MLEngine:GetWasmArrayBuffer":
+ return MLEngineParent.getWasmArrayBuffer();
+ case "MLEngine:DestroyEngineProcess":
+ lazy.EngineProcess.destroyMLEngine().catch(error =>
+ console.error(error)
+ );
+ break;
+ }
+ }
+
+ /**
+ * @param {RemoteSettingsClient} client
+ */
+ static async #getWasmArrayRecord(client) {
+ // Load the wasm binary from remote settings, if it hasn't been already.
+ lazy.console.log(`Getting remote wasm records.`);
+
+ /** @type {WasmRecord[]} */
+ const wasmRecords = await lazy.TranslationsParent.getMaxVersionRecords(
+ client,
+ {
+ // TODO - This record needs to be created with the engine wasm payload.
+ filters: { name: "inference-engine" },
+ majorVersion: MLEngineParent.WASM_MAJOR_VERSION,
+ }
+ );
+
+ if (wasmRecords.length === 0) {
+ // The remote settings client provides an empty list of records when there is
+ // an error.
+ throw new Error("Unable to get the ML engine from Remote Settings.");
+ }
+
+ if (wasmRecords.length > 1) {
+ MLEngineParent.reportError(
+ new Error("Expected the ml engine to only have 1 record."),
+ wasmRecords
+ );
+ }
+ const [record] = wasmRecords;
+ lazy.console.log(
+ `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`,
+ record
+ );
+ return record;
+ }
+
+ /**
+ * Download the wasm for the ML inference engine.
+ *
+ * @returns {Promise<ArrayBuffer>}
+ */
+ static async getWasmArrayBuffer() {
+ const client = MLEngineParent.#getRemoteClient();
+
+ if (!MLEngineParent.#wasmRecord) {
+ // Place the records into a promise to prevent any races.
+ MLEngineParent.#wasmRecord = MLEngineParent.#getWasmArrayRecord(client);
+ }
+
+ let wasmRecord;
+ try {
+ wasmRecord = await MLEngineParent.#wasmRecord;
+ if (!wasmRecord) {
+ return Promise.reject(
+ "Error: Unable to get the ML engine from Remote Settings."
+ );
+ }
+ } catch (error) {
+ MLEngineParent.#wasmRecord = null;
+ throw error;
+ }
+
+ /** @type {{buffer: ArrayBuffer}} */
+ const { buffer } = await client.attachments.download(wasmRecord);
+
+ return buffer;
+ }
+
+ /**
+ * Lazily initializes the RemoteSettingsClient for the downloaded wasm binary data.
+ *
+ * @returns {RemoteSettingsClient}
+ */
+ static #getRemoteClient() {
+ if (MLEngineParent.#remoteClient) {
+ return MLEngineParent.#remoteClient;
+ }
+
+ /** @type {RemoteSettingsClient} */
+ const client = lazy.RemoteSettings("ml-wasm");
+
+ MLEngineParent.#remoteClient = client;
+
+ client.on("sync", async ({ data: { created, updated, deleted } }) => {
+ lazy.console.log(`"sync" event for ml-wasm`, {
+ created,
+ updated,
+ deleted,
+ });
+
+ // Remove all the deleted records.
+ for (const record of deleted) {
+ await client.attachments.deleteDownloaded(record);
+ }
+
+ // Remove any updated records, and download the new ones.
+ for (const { old: oldRecord } of updated) {
+ await client.attachments.deleteDownloaded(oldRecord);
+ }
+
+ // Do nothing for the created records.
+ });
+
+ return client;
+ }
+
+ /**
+ * Send a message to gracefully shutdown all of the ML engines in the engine process.
+ * This mostly exists for testing the shutdown paths of the code.
+ */
+ forceShutdown() {
+ return this.sendQuery("MLEngine:ForceShutdown");
+ }
+}
+
+/**
+ * This contains all of the information needed to perform a translation request.
+ *
+ * @typedef {object} TranslationRequest
+ * @property {Node} node
+ * @property {string} sourceText
+ * @property {boolean} isHTML
+ * @property {Function} resolve
+ * @property {Function} reject
+ */
+
+/**
+ * The interface to communicate to an MLEngine in the parent process. The engine manages
+ * its own lifetime, and is kept alive with a timeout. A reference to this engine can
+ * be retained, but once idle, the engine will be destroyed. If a new request to run
+ * is sent, the engine will be recreated on demand. This balances the cost of retaining
+ * potentially large amounts of memory to run models, with the speed and ease of running
+ * the engine.
+ *
+ * @template Request
+ * @template Response
+ */
+class MLEngine {
+ /**
+ * @type {MessagePort | null}
+ */
+ #port = null;
+
+ #nextRequestId = 0;
+
+ /**
+ * Tie together a message id to a resolved response.
+ *
+ * @type {Map<number, PromiseWithResolvers<Request>>}
+ */
+ #requests = new Map();
+
+ /**
+ * @type {"uninitialized" | "ready" | "error" | "closed"}
+ */
+ engineStatus = "uninitialized";
+
+ /**
+ * @param {MLEngineParent} mlEngineParent
+ * @param {string} engineName
+ * @param {() => Promise<ArrayBuffer>} getModel
+ * @param {number} timeoutMS
+ */
+ constructor(mlEngineParent, engineName, getModel, timeoutMS) {
+ /** @type {MLEngineParent} */
+ this.mlEngineParent = mlEngineParent;
+ /** @type {string} */
+ this.engineName = engineName;
+ /** @type {() => Promise<ArrayBuffer>} */
+ this.getModel = getModel;
+ /** @type {number} */
+ this.timeoutMS = timeoutMS;
+
+ this.#setupPortCommunication();
+ }
+
+ /**
+ * Create a MessageChannel to communicate with the engine directly.
+ */
+ #setupPortCommunication() {
+ const { port1: childPort, port2: parentPort } = new MessageChannel();
+ const transferables = [childPort];
+ this.#port = parentPort;
+
+ this.#port.onmessage = this.handlePortMessage;
+ this.mlEngineParent.sendAsyncMessage(
+ "MLEngine:NewPort",
+ {
+ port: childPort,
+ engineName: this.engineName,
+ timeoutMS: this.timeoutMS,
+ },
+ transferables
+ );
+ }
+
+ handlePortMessage = ({ data }) => {
+ switch (data.type) {
+ case "EnginePort:ModelRequest": {
+ if (this.#port) {
+ this.getModel().then(
+ model => {
+ this.#port.postMessage({
+ type: "EnginePort:ModelResponse",
+ model,
+ error: null,
+ });
+ },
+ error => {
+ this.#port.postMessage({
+ type: "EnginePort:ModelResponse",
+ model: null,
+ error,
+ });
+ if (
+ // Ignore intentional errors in tests.
+ !error?.message.startsWith("Intentionally")
+ ) {
+ lazy.console.error("Failed to get the model", error);
+ }
+ }
+ );
+ } else {
+ lazy.console.error(
+ "Expected a port to exist during the EnginePort:GetModel event"
+ );
+ }
+ break;
+ }
+ case "EnginePort:RunResponse": {
+ const { response, error, requestId } = data;
+ const request = this.#requests.get(requestId);
+ if (request) {
+ if (response) {
+ request.resolve(response);
+ } else {
+ request.reject(error);
+ }
+ } else {
+ lazy.console.error(
+ "Could not resolve response in the MLEngineParent",
+ data
+ );
+ }
+ this.#requests.delete(requestId);
+ break;
+ }
+ case "EnginePort:EngineTerminated": {
+ // The engine was terminated, and if a new run is needed a new port
+ // will need to be requested.
+ this.engineStatus = "closed";
+ this.discardPort();
+ break;
+ }
+ default:
+ lazy.console.error("Unknown port message from engine", data);
+ break;
+ }
+ };
+
+ discardPort() {
+ if (this.#port) {
+ this.#port.postMessage({ type: "EnginePort:Discard" });
+ this.#port.close();
+ this.#port = null;
+ }
+ }
+
+ terminate() {
+ this.#port.postMessage({ type: "EnginePort:Terminate" });
+ }
+
+ /**
+ * @param {Request} request
+ * @returns {Promise<Response>}
+ */
+ run(request) {
+ const resolvers = Promise.withResolvers();
+ const requestId = this.#nextRequestId++;
+ this.#requests.set(requestId, resolvers);
+ this.#port.postMessage({
+ type: "EnginePort:Run",
+ requestId,
+ request,
+ });
+ return resolvers.promise;
+ }
+}