summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/actors/MLEngineChild.sys.mjs
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/ml/actors/MLEngineChild.sys.mjs')
-rw-r--r--toolkit/components/ml/actors/MLEngineChild.sys.mjs325
1 files changed, 325 insertions, 0 deletions
diff --git a/toolkit/components/ml/actors/MLEngineChild.sys.mjs b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
new file mode 100644
index 0000000000..925ce59266
--- /dev/null
+++ b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
@@ -0,0 +1,325 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
+
+/**
+ * @typedef {import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
+ */
+
+/**
+ * @typedef {object} Lazy
+ * @property {typeof import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
+ * @property {typeof setTimeout} setTimeout
+ * @property {typeof clearTimeout} clearTimeout
+ */
+
+/** @type {Lazy} */
+const lazy = {};
+ChromeUtils.defineESModuleGetters(lazy, {
+ BasePromiseWorker: "resource://gre/modules/PromiseWorker.sys.mjs",
+ setTimeout: "resource://gre/modules/Timer.sys.mjs",
+ clearTimeout: "resource://gre/modules/Timer.sys.mjs",
+});
+
+ChromeUtils.defineLazyGetter(lazy, "console", () => {
+ return console.createInstance({
+ maxLogLevelPref: "browser.ml.logLevel",
+ prefix: "ML",
+ });
+});
+
+XPCOMUtils.defineLazyPreferenceGetter(
+ lazy,
+ "loggingLevel",
+ "browser.ml.logLevel"
+);
+
+/**
+ * The engine child is responsible for the life cycle and instantiation of the local
+ * machine learning inference engine.
+ */
+export class MLEngineChild extends JSWindowActorChild {
+ /**
+ * The cached engines.
+ *
+ * @type {Map<string, EngineDispatcher>}
+ */
+ #engineDispatchers = new Map();
+
+ // eslint-disable-next-line consistent-return
+ async receiveMessage({ name, data }) {
+ switch (name) {
+ case "MLEngine:NewPort": {
+ const { engineName, port, timeoutMS } = data;
+ this.#engineDispatchers.set(
+ engineName,
+ new EngineDispatcher(this, port, engineName, timeoutMS)
+ );
+ break;
+ }
+ case "MLEngine:ForceShutdown": {
+ for (const engineDispatcher of this.#engineDispatchers.values()) {
+ return engineDispatcher.terminate();
+ }
+ this.#engineDispatchers = null;
+ break;
+ }
+ }
+ }
+
+ handleEvent(event) {
+ switch (event.type) {
+ case "DOMContentLoaded":
+ this.sendAsyncMessage("MLEngine:Ready");
+ break;
+ }
+ }
+
+ /**
+ * @returns {ArrayBuffer}
+ */
+ getWasmArrayBuffer() {
+ return this.sendQuery("MLEngine:GetWasmArrayBuffer");
+ }
+
+ /**
+ * @param {string} engineName
+ */
+ removeEngine(engineName) {
+ this.#engineDispatchers.delete(engineName);
+ if (this.#engineDispatchers.size === 0) {
+ this.sendQuery("MLEngine:DestroyEngineProcess");
+ }
+ }
+}
+
+/**
+ * This classes manages the lifecycle of an ML Engine, and handles dispatching messages
+ * to it.
+ */
+class EngineDispatcher {
+ /** @type {Set<MessagePort>} */
+ #ports = new Set();
+
+ /** @type {TimeoutID | null} */
+ #keepAliveTimeout = null;
+
+ /** @type {PromiseWithResolvers} */
+ #modelRequest;
+
+ /** @type {Promise<Engine> | null} */
+ #engine = null;
+
+ /** @type {string} */
+ #engineName;
+
+ /**
+ * @param {MLEngineChild} mlEngineChild
+ * @param {MessagePort} port
+ * @param {string} engineName
+ * @param {number} timeoutMS
+ */
+ constructor(mlEngineChild, port, engineName, timeoutMS) {
+ /** @type {MLEngineChild} */
+ this.mlEngineChild = mlEngineChild;
+
+ /** @type {number} */
+ this.timeoutMS = timeoutMS;
+
+ this.#engineName = engineName;
+
+ this.#engine = Promise.all([
+ this.mlEngineChild.getWasmArrayBuffer(),
+ this.getModel(port),
+ ]).then(([wasm, model]) => FakeEngine.create(wasm, model));
+
+ this.#engine
+ .then(() => void this.keepAlive())
+ .catch(error => {
+ if (
+ // Ignore errors from tests intentionally causing errors.
+ !error?.message?.startsWith("Intentionally")
+ ) {
+ lazy.console.error("Could not initalize the engine", error);
+ }
+ });
+
+ this.setupMessageHandler(port);
+ }
+
+ /**
+ * The worker needs to be shutdown after some amount of time of not being used.
+ */
+ keepAlive() {
+ if (this.#keepAliveTimeout) {
+ // Clear any previous timeout.
+ lazy.clearTimeout(this.#keepAliveTimeout);
+ }
+ // In automated tests, the engine is manually destroyed.
+ if (!Cu.isInAutomation) {
+ this.#keepAliveTimeout = lazy.setTimeout(this.terminate, this.timeoutMS);
+ }
+ }
+
+ /**
+ * @param {MessagePort} port
+ */
+ getModel(port) {
+ if (this.#modelRequest) {
+ // There could be a race to get a model, use the first request.
+ return this.#modelRequest.promise;
+ }
+ this.#modelRequest = Promise.withResolvers();
+ port.postMessage({ type: "EnginePort:ModelRequest" });
+ return this.#modelRequest.promise;
+ }
+
+ /**
+ * @param {MessagePort} port
+ */
+ setupMessageHandler(port) {
+ port.onmessage = async ({ data }) => {
+ switch (data.type) {
+ case "EnginePort:Discard": {
+ port.close();
+ this.#ports.delete(port);
+ break;
+ }
+ case "EnginePort:Terminate": {
+ this.terminate();
+ break;
+ }
+ case "EnginePort:ModelResponse": {
+ if (this.#modelRequest) {
+ const { model, error } = data;
+ if (model) {
+ this.#modelRequest.resolve(model);
+ } else {
+ this.#modelRequest.reject(error);
+ }
+ this.#modelRequest = null;
+ } else {
+ lazy.console.error(
+ "Got a EnginePort:ModelResponse but no model resolvers"
+ );
+ }
+ break;
+ }
+ case "EnginePort:Run": {
+ const { requestId, request } = data;
+ let engine;
+ try {
+ engine = await this.#engine;
+ } catch (error) {
+ port.postMessage({
+ type: "EnginePort:RunResponse",
+ requestId,
+ response: null,
+ error,
+ });
+ // The engine failed to load. Terminate the entire dispatcher.
+ this.terminate();
+ return;
+ }
+
+ // Do not run the keepAlive timer until we are certain that the engine loaded,
+ // as the engine shouldn't be killed while it is initializing.
+ this.keepAlive();
+
+ try {
+ port.postMessage({
+ type: "EnginePort:RunResponse",
+ requestId,
+ response: await engine.run(request),
+ error: null,
+ });
+ } catch (error) {
+ port.postMessage({
+ type: "EnginePort:RunResponse",
+ requestId,
+ response: null,
+ error,
+ });
+ }
+ break;
+ }
+ default:
+ lazy.console.error("Unknown port message to engine: ", data);
+ break;
+ }
+ };
+ }
+
+ /**
+ * Terminates the engine and its worker after a timeout.
+ */
+ async terminate() {
+ if (this.#keepAliveTimeout) {
+ lazy.clearTimeout(this.#keepAliveTimeout);
+ this.#keepAliveTimeout = null;
+ }
+ for (const port of this.#ports) {
+ port.postMessage({ type: "EnginePort:EngineTerminated" });
+ port.close();
+ }
+ this.#ports = new Set();
+ this.mlEngineChild.removeEngine(this.#engineName);
+ try {
+ const engine = await this.#engine;
+ engine.terminate();
+ } catch (error) {
+ console.error("Failed to get the engine", error);
+ }
+ }
+}
+
+/**
+ * Fake the engine by slicing the text in half.
+ */
+class FakeEngine {
+ /** @type {BasePromiseWorker} */
+ #worker;
+
+ /**
+ * Initialize the worker.
+ *
+ * @param {ArrayBuffer} wasm
+ * @param {ArrayBuffer} model
+ * @returns {FakeEngine}
+ */
+ static async create(wasm, model) {
+ /** @type {BasePromiseWorker} */
+ const worker = new lazy.BasePromiseWorker(
+ "chrome://global/content/ml/MLEngine.worker.mjs",
+ { type: "module" }
+ );
+
+ const args = [wasm, model, lazy.loggingLevel];
+ const closure = {};
+ const transferables = [wasm, model];
+ await worker.post("initializeEngine", args, closure, transferables);
+ return new FakeEngine(worker);
+ }
+
+ /**
+ * @param {BasePromiseWorker} worker
+ */
+ constructor(worker) {
+ this.#worker = worker;
+ }
+
+ /**
+ * @param {string} request
+ * @returns {Promise<string>}
+ */
+ run(request) {
+ return this.#worker.post("run", [request]);
+ }
+
+ terminate() {
+ this.#worker.terminate();
+ this.#worker = null;
+ }
+}