summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/actors
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/ml/actors')
-rw-r--r--toolkit/components/ml/actors/MLEngineChild.sys.mjs325
-rw-r--r--toolkit/components/ml/actors/MLEngineParent.sys.mjs403
-rw-r--r--toolkit/components/ml/actors/moz.build8
3 files changed, 736 insertions, 0 deletions
diff --git a/toolkit/components/ml/actors/MLEngineChild.sys.mjs b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
new file mode 100644
index 0000000000..925ce59266
--- /dev/null
+++ b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
@@ -0,0 +1,325 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
+
+/**
+ * @typedef {import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
+ */
+
+/**
+ * @typedef {object} Lazy
+ * @property {typeof import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
+ * @property {typeof setTimeout} setTimeout
+ * @property {typeof clearTimeout} clearTimeout
+ */
+
+/** @type {Lazy} */
+const lazy = {};
+ChromeUtils.defineESModuleGetters(lazy, {
+ BasePromiseWorker: "resource://gre/modules/PromiseWorker.sys.mjs",
+ setTimeout: "resource://gre/modules/Timer.sys.mjs",
+ clearTimeout: "resource://gre/modules/Timer.sys.mjs",
+});
+
+ChromeUtils.defineLazyGetter(lazy, "console", () => {
+ return console.createInstance({
+ maxLogLevelPref: "browser.ml.logLevel",
+ prefix: "ML",
+ });
+});
+
+XPCOMUtils.defineLazyPreferenceGetter(
+ lazy,
+ "loggingLevel",
+ "browser.ml.logLevel"
+);
+
+/**
+ * The engine child is responsible for the life cycle and instantiation of the local
+ * machine learning inference engine.
+ */
+export class MLEngineChild extends JSWindowActorChild {
+ /**
+ * The cached engines.
+ *
+ * @type {Map<string, EngineDispatcher>}
+ */
+ #engineDispatchers = new Map();
+
+ // eslint-disable-next-line consistent-return
+ async receiveMessage({ name, data }) {
+ switch (name) {
+ case "MLEngine:NewPort": {
+ const { engineName, port, timeoutMS } = data;
+ this.#engineDispatchers.set(
+ engineName,
+ new EngineDispatcher(this, port, engineName, timeoutMS)
+ );
+ break;
+ }
+ case "MLEngine:ForceShutdown": {
+ for (const engineDispatcher of this.#engineDispatchers.values()) {
+ return engineDispatcher.terminate();
+ }
+ this.#engineDispatchers = null;
+ break;
+ }
+ }
+ }
+
+ handleEvent(event) {
+ switch (event.type) {
+ case "DOMContentLoaded":
+ this.sendAsyncMessage("MLEngine:Ready");
+ break;
+ }
+ }
+
+ /**
+ * @returns {ArrayBuffer}
+ */
+ getWasmArrayBuffer() {
+ return this.sendQuery("MLEngine:GetWasmArrayBuffer");
+ }
+
+ /**
+ * @param {string} engineName
+ */
+ removeEngine(engineName) {
+ this.#engineDispatchers.delete(engineName);
+ if (this.#engineDispatchers.size === 0) {
+ this.sendQuery("MLEngine:DestroyEngineProcess");
+ }
+ }
+}
+
+/**
+ * This classes manages the lifecycle of an ML Engine, and handles dispatching messages
+ * to it.
+ */
+class EngineDispatcher {
+ /** @type {Set<MessagePort>} */
+ #ports = new Set();
+
+ /** @type {TimeoutID | null} */
+ #keepAliveTimeout = null;
+
+ /** @type {PromiseWithResolvers} */
+ #modelRequest;
+
+ /** @type {Promise<Engine> | null} */
+ #engine = null;
+
+ /** @type {string} */
+ #engineName;
+
+ /**
+ * @param {MLEngineChild} mlEngineChild
+ * @param {MessagePort} port
+ * @param {string} engineName
+ * @param {number} timeoutMS
+ */
+ constructor(mlEngineChild, port, engineName, timeoutMS) {
+ /** @type {MLEngineChild} */
+ this.mlEngineChild = mlEngineChild;
+
+ /** @type {number} */
+ this.timeoutMS = timeoutMS;
+
+ this.#engineName = engineName;
+
+ this.#engine = Promise.all([
+ this.mlEngineChild.getWasmArrayBuffer(),
+ this.getModel(port),
+ ]).then(([wasm, model]) => FakeEngine.create(wasm, model));
+
+ this.#engine
+ .then(() => void this.keepAlive())
+ .catch(error => {
+ if (
+ // Ignore errors from tests intentionally causing errors.
+ !error?.message?.startsWith("Intentionally")
+ ) {
+ lazy.console.error("Could not initalize the engine", error);
+ }
+ });
+
+ this.setupMessageHandler(port);
+ }
+
+ /**
+ * The worker needs to be shutdown after some amount of time of not being used.
+ */
+ keepAlive() {
+ if (this.#keepAliveTimeout) {
+ // Clear any previous timeout.
+ lazy.clearTimeout(this.#keepAliveTimeout);
+ }
+ // In automated tests, the engine is manually destroyed.
+ if (!Cu.isInAutomation) {
+ this.#keepAliveTimeout = lazy.setTimeout(this.terminate, this.timeoutMS);
+ }
+ }
+
+ /**
+ * @param {MessagePort} port
+ */
+ getModel(port) {
+ if (this.#modelRequest) {
+ // There could be a race to get a model, use the first request.
+ return this.#modelRequest.promise;
+ }
+ this.#modelRequest = Promise.withResolvers();
+ port.postMessage({ type: "EnginePort:ModelRequest" });
+ return this.#modelRequest.promise;
+ }
+
+ /**
+ * @param {MessagePort} port
+ */
+ setupMessageHandler(port) {
+ port.onmessage = async ({ data }) => {
+ switch (data.type) {
+ case "EnginePort:Discard": {
+ port.close();
+ this.#ports.delete(port);
+ break;
+ }
+ case "EnginePort:Terminate": {
+ this.terminate();
+ break;
+ }
+ case "EnginePort:ModelResponse": {
+ if (this.#modelRequest) {
+ const { model, error } = data;
+ if (model) {
+ this.#modelRequest.resolve(model);
+ } else {
+ this.#modelRequest.reject(error);
+ }
+ this.#modelRequest = null;
+ } else {
+ lazy.console.error(
+ "Got a EnginePort:ModelResponse but no model resolvers"
+ );
+ }
+ break;
+ }
+ case "EnginePort:Run": {
+ const { requestId, request } = data;
+ let engine;
+ try {
+ engine = await this.#engine;
+ } catch (error) {
+ port.postMessage({
+ type: "EnginePort:RunResponse",
+ requestId,
+ response: null,
+ error,
+ });
+ // The engine failed to load. Terminate the entire dispatcher.
+ this.terminate();
+ return;
+ }
+
+ // Do not run the keepAlive timer until we are certain that the engine loaded,
+ // as the engine shouldn't be killed while it is initializing.
+ this.keepAlive();
+
+ try {
+ port.postMessage({
+ type: "EnginePort:RunResponse",
+ requestId,
+ response: await engine.run(request),
+ error: null,
+ });
+ } catch (error) {
+ port.postMessage({
+ type: "EnginePort:RunResponse",
+ requestId,
+ response: null,
+ error,
+ });
+ }
+ break;
+ }
+ default:
+ lazy.console.error("Unknown port message to engine: ", data);
+ break;
+ }
+ };
+ }
+
+ /**
+ * Terminates the engine and its worker after a timeout.
+ */
+ async terminate() {
+ if (this.#keepAliveTimeout) {
+ lazy.clearTimeout(this.#keepAliveTimeout);
+ this.#keepAliveTimeout = null;
+ }
+ for (const port of this.#ports) {
+ port.postMessage({ type: "EnginePort:EngineTerminated" });
+ port.close();
+ }
+ this.#ports = new Set();
+ this.mlEngineChild.removeEngine(this.#engineName);
+ try {
+ const engine = await this.#engine;
+ engine.terminate();
+ } catch (error) {
+ console.error("Failed to get the engine", error);
+ }
+ }
+}
+
+/**
+ * Fake the engine by slicing the text in half.
+ */
+class FakeEngine {
+ /** @type {BasePromiseWorker} */
+ #worker;
+
+ /**
+ * Initialize the worker.
+ *
+ * @param {ArrayBuffer} wasm
+ * @param {ArrayBuffer} model
+ * @returns {FakeEngine}
+ */
+ static async create(wasm, model) {
+ /** @type {BasePromiseWorker} */
+ const worker = new lazy.BasePromiseWorker(
+ "chrome://global/content/ml/MLEngine.worker.mjs",
+ { type: "module" }
+ );
+
+ const args = [wasm, model, lazy.loggingLevel];
+ const closure = {};
+ const transferables = [wasm, model];
+ await worker.post("initializeEngine", args, closure, transferables);
+ return new FakeEngine(worker);
+ }
+
+ /**
+ * @param {BasePromiseWorker} worker
+ */
+ constructor(worker) {
+ this.#worker = worker;
+ }
+
+ /**
+ * @param {string} request
+ * @returns {Promise<string>}
+ */
+ run(request) {
+ return this.#worker.post("run", [request]);
+ }
+
+ terminate() {
+ this.#worker.terminate();
+ this.#worker = null;
+ }
+}
diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
new file mode 100644
index 0000000000..10b4eed4fa
--- /dev/null
+++ b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
@@ -0,0 +1,403 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/**
+ * @typedef {object} Lazy
+ * @property {typeof console} console
+ * @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess
+ * @property {typeof import("../../../../services/settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings
+ * @property {typeof import("../../translations/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent
+ */
+
+/** @type {Lazy} */
+const lazy = {};
+
+ChromeUtils.defineLazyGetter(lazy, "console", () => {
+ return console.createInstance({
+ maxLogLevelPref: "browser.ml.logLevel",
+ prefix: "ML",
+ });
+});
+
+ChromeUtils.defineESModuleGetters(lazy, {
+ EngineProcess: "chrome://global/content/ml/EngineProcess.sys.mjs",
+ RemoteSettings: "resource://services-settings/remote-settings.sys.mjs",
+ TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs",
+});
+
+/**
+ * @typedef {import("../../translations/translations").WasmRecord} WasmRecord
+ */
+
+const DEFAULT_CACHE_TIMEOUT_MS = 15_000;
+
+/**
+ * The ML engine is in its own content process. This actor handles the
+ * marshalling of the data such as the engine payload.
+ */
+export class MLEngineParent extends JSWindowActorParent {
+ /**
+ * The RemoteSettingsClient that downloads the wasm binaries.
+ *
+ * @type {RemoteSettingsClient | null}
+ */
+ static #remoteClient = null;
+
+ /** @type {Promise<WasmRecord> | null} */
+ static #wasmRecord = null;
+
+ /**
+ * The following constant controls the major version for wasm downloaded from
+ * Remote Settings. When a breaking change is introduced, Nightly will have these
+ * numbers incremented by one, but Beta and Release will still be on the previous
+ * version. Remote Settings will ship both versions of the records, and the latest
+ * asset released in that version will be used. For instance, with a major version
+ * of "1", assets can be downloaded for "1.0", "1.2", "1.3beta", but assets marked
+ * as "2.0", "2.1", etc will not be downloaded.
+ */
+ static WASM_MAJOR_VERSION = 1;
+
+ /**
+ * Remote settings isn't available in tests, so provide mocked responses.
+ *
+ * @param {RemoteSettingsClient} remoteClient
+ */
+ static mockRemoteSettings(remoteClient) {
+ lazy.console.log("Mocking remote settings in MLEngineParent.");
+ MLEngineParent.#remoteClient = remoteClient;
+ MLEngineParent.#wasmRecord = null;
+ }
+
+ /**
+ * Remove anything that could have been mocked.
+ */
+ static removeMocks() {
+ lazy.console.log("Removing mocked remote client in MLEngineParent.");
+ MLEngineParent.#remoteClient = null;
+ MLEngineParent.#wasmRecord = null;
+ }
+
+ /**
+ * @param {string} engineName
+ * @param {() => Promise<ArrayBuffer>} getModel
+ * @param {number} cacheTimeoutMS - How long the engine cache remains alive between
+ * uses, in milliseconds. In automation the engine is manually created and destroyed
+ * to avoid timing issues.
+ * @returns {MLEngine}
+ */
+ getEngine(engineName, getModel, cacheTimeoutMS = DEFAULT_CACHE_TIMEOUT_MS) {
+ return new MLEngine(this, engineName, getModel, cacheTimeoutMS);
+ }
+
+ // eslint-disable-next-line consistent-return
+ async receiveMessage({ name, data }) {
+ switch (name) {
+ case "MLEngine:Ready":
+ if (lazy.EngineProcess.resolveMLEngineParent) {
+ lazy.EngineProcess.resolveMLEngineParent(this);
+ } else {
+ lazy.console.error(
+ "Expected #resolveMLEngineParent to exist when then ML Engine is ready."
+ );
+ }
+ break;
+ case "MLEngine:GetWasmArrayBuffer":
+ return MLEngineParent.getWasmArrayBuffer();
+ case "MLEngine:DestroyEngineProcess":
+ lazy.EngineProcess.destroyMLEngine().catch(error =>
+ console.error(error)
+ );
+ break;
+ }
+ }
+
+ /**
+ * @param {RemoteSettingsClient} client
+ */
+ static async #getWasmArrayRecord(client) {
+ // Load the wasm binary from remote settings, if it hasn't been already.
+ lazy.console.log(`Getting remote wasm records.`);
+
+ /** @type {WasmRecord[]} */
+ const wasmRecords = await lazy.TranslationsParent.getMaxVersionRecords(
+ client,
+ {
+ // TODO - This record needs to be created with the engine wasm payload.
+ filters: { name: "inference-engine" },
+ majorVersion: MLEngineParent.WASM_MAJOR_VERSION,
+ }
+ );
+
+ if (wasmRecords.length === 0) {
+ // The remote settings client provides an empty list of records when there is
+ // an error.
+ throw new Error("Unable to get the ML engine from Remote Settings.");
+ }
+
+ if (wasmRecords.length > 1) {
+ MLEngineParent.reportError(
+ new Error("Expected the ml engine to only have 1 record."),
+ wasmRecords
+ );
+ }
+ const [record] = wasmRecords;
+ lazy.console.log(
+ `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`,
+ record
+ );
+ return record;
+ }
+
+ /**
+ * Download the wasm for the ML inference engine.
+ *
+ * @returns {Promise<ArrayBuffer>}
+ */
+ static async getWasmArrayBuffer() {
+ const client = MLEngineParent.#getRemoteClient();
+
+ if (!MLEngineParent.#wasmRecord) {
+ // Place the records into a promise to prevent any races.
+ MLEngineParent.#wasmRecord = MLEngineParent.#getWasmArrayRecord(client);
+ }
+
+ let wasmRecord;
+ try {
+ wasmRecord = await MLEngineParent.#wasmRecord;
+ if (!wasmRecord) {
+ return Promise.reject(
+ "Error: Unable to get the ML engine from Remote Settings."
+ );
+ }
+ } catch (error) {
+ MLEngineParent.#wasmRecord = null;
+ throw error;
+ }
+
+ /** @type {{buffer: ArrayBuffer}} */
+ const { buffer } = await client.attachments.download(wasmRecord);
+
+ return buffer;
+ }
+
+ /**
+ * Lazily initializes the RemoteSettingsClient for the downloaded wasm binary data.
+ *
+ * @returns {RemoteSettingsClient}
+ */
+ static #getRemoteClient() {
+ if (MLEngineParent.#remoteClient) {
+ return MLEngineParent.#remoteClient;
+ }
+
+ /** @type {RemoteSettingsClient} */
+ const client = lazy.RemoteSettings("ml-wasm");
+
+ MLEngineParent.#remoteClient = client;
+
+ client.on("sync", async ({ data: { created, updated, deleted } }) => {
+ lazy.console.log(`"sync" event for ml-wasm`, {
+ created,
+ updated,
+ deleted,
+ });
+
+ // Remove all the deleted records.
+ for (const record of deleted) {
+ await client.attachments.deleteDownloaded(record);
+ }
+
+ // Remove any updated records, and download the new ones.
+ for (const { old: oldRecord } of updated) {
+ await client.attachments.deleteDownloaded(oldRecord);
+ }
+
+ // Do nothing for the created records.
+ });
+
+ return client;
+ }
+
+ /**
+ * Send a message to gracefully shutdown all of the ML engines in the engine process.
+ * This mostly exists for testing the shutdown paths of the code.
+ */
+ forceShutdown() {
+ return this.sendQuery("MLEngine:ForceShutdown");
+ }
+}
+
+/**
+ * This contains all of the information needed to perform a translation request.
+ *
+ * @typedef {object} TranslationRequest
+ * @property {Node} node
+ * @property {string} sourceText
+ * @property {boolean} isHTML
+ * @property {Function} resolve
+ * @property {Function} reject
+ */
+
+/**
+ * The interface to communicate to an MLEngine in the parent process. The engine manages
+ * its own lifetime, and is kept alive with a timeout. A reference to this engine can
+ * be retained, but once idle, the engine will be destroyed. If a new request to run
+ * is sent, the engine will be recreated on demand. This balances the cost of retaining
+ * potentially large amounts of memory to run models, with the speed and ease of running
+ * the engine.
+ *
+ * @template Request
+ * @template Response
+ */
+class MLEngine {
+ /**
+ * @type {MessagePort | null}
+ */
+ #port = null;
+
+ #nextRequestId = 0;
+
+ /**
+ * Tie together a message id to a resolved response.
+ *
+ * @type {Map<number, PromiseWithResolvers<Request>>}
+ */
+ #requests = new Map();
+
+ /**
+ * @type {"uninitialized" | "ready" | "error" | "closed"}
+ */
+ engineStatus = "uninitialized";
+
+ /**
+ * @param {MLEngineParent} mlEngineParent
+ * @param {string} engineName
+ * @param {() => Promise<ArrayBuffer>} getModel
+ * @param {number} timeoutMS
+ */
+ constructor(mlEngineParent, engineName, getModel, timeoutMS) {
+ /** @type {MLEngineParent} */
+ this.mlEngineParent = mlEngineParent;
+ /** @type {string} */
+ this.engineName = engineName;
+ /** @type {() => Promise<ArrayBuffer>} */
+ this.getModel = getModel;
+ /** @type {number} */
+ this.timeoutMS = timeoutMS;
+
+ this.#setupPortCommunication();
+ }
+
+ /**
+ * Create a MessageChannel to communicate with the engine directly.
+ */
+ #setupPortCommunication() {
+ const { port1: childPort, port2: parentPort } = new MessageChannel();
+ const transferables = [childPort];
+ this.#port = parentPort;
+
+ this.#port.onmessage = this.handlePortMessage;
+ this.mlEngineParent.sendAsyncMessage(
+ "MLEngine:NewPort",
+ {
+ port: childPort,
+ engineName: this.engineName,
+ timeoutMS: this.timeoutMS,
+ },
+ transferables
+ );
+ }
+
+ handlePortMessage = ({ data }) => {
+ switch (data.type) {
+ case "EnginePort:ModelRequest": {
+ if (this.#port) {
+ this.getModel().then(
+ model => {
+ this.#port.postMessage({
+ type: "EnginePort:ModelResponse",
+ model,
+ error: null,
+ });
+ },
+ error => {
+ this.#port.postMessage({
+ type: "EnginePort:ModelResponse",
+ model: null,
+ error,
+ });
+ if (
+ // Ignore intentional errors in tests.
+ !error?.message.startsWith("Intentionally")
+ ) {
+ lazy.console.error("Failed to get the model", error);
+ }
+ }
+ );
+ } else {
+ lazy.console.error(
+ "Expected a port to exist during the EnginePort:GetModel event"
+ );
+ }
+ break;
+ }
+ case "EnginePort:RunResponse": {
+ const { response, error, requestId } = data;
+ const request = this.#requests.get(requestId);
+ if (request) {
+ if (response) {
+ request.resolve(response);
+ } else {
+ request.reject(error);
+ }
+ } else {
+ lazy.console.error(
+ "Could not resolve response in the MLEngineParent",
+ data
+ );
+ }
+ this.#requests.delete(requestId);
+ break;
+ }
+ case "EnginePort:EngineTerminated": {
+ // The engine was terminated, and if a new run is needed a new port
+ // will need to be requested.
+ this.engineStatus = "closed";
+ this.discardPort();
+ break;
+ }
+ default:
+ lazy.console.error("Unknown port message from engine", data);
+ break;
+ }
+ };
+
+ discardPort() {
+ if (this.#port) {
+ this.#port.postMessage({ type: "EnginePort:Discard" });
+ this.#port.close();
+ this.#port = null;
+ }
+ }
+
+ terminate() {
+ this.#port.postMessage({ type: "EnginePort:Terminate" });
+ }
+
+ /**
+ * @param {Request} request
+ * @returns {Promise<Response>}
+ */
+ run(request) {
+ const resolvers = Promise.withResolvers();
+ const requestId = this.#nextRequestId++;
+ this.#requests.set(requestId, resolvers);
+ this.#port.postMessage({
+ type: "EnginePort:Run",
+ requestId,
+ request,
+ });
+ return resolvers.promise;
+ }
+}
diff --git a/toolkit/components/ml/actors/moz.build b/toolkit/components/ml/actors/moz.build
new file mode 100644
index 0000000000..de3e27ae2a
--- /dev/null
+++ b/toolkit/components/ml/actors/moz.build
@@ -0,0 +1,8 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+FINAL_TARGET_FILES.actors += [
+ "MLEngineChild.sys.mjs",
+ "MLEngineParent.sys.mjs",
+]