summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/ml')
-rw-r--r--toolkit/components/ml/actors/MLEngineChild.sys.mjs325
-rw-r--r--toolkit/components/ml/actors/MLEngineParent.sys.mjs403
-rw-r--r--toolkit/components/ml/actors/moz.build8
-rw-r--r--toolkit/components/ml/content/EngineProcess.sys.mjs241
-rw-r--r--toolkit/components/ml/content/MLEngine.html16
-rw-r--r--toolkit/components/ml/content/MLEngine.worker.mjs91
-rw-r--r--toolkit/components/ml/content/SummarizerModel.sys.mjs160
-rw-r--r--toolkit/components/ml/docs/index.md44
-rw-r--r--toolkit/components/ml/jar.mn9
-rw-r--r--toolkit/components/ml/moz.build17
-rw-r--r--toolkit/components/ml/tests/browser/browser.toml5
-rw-r--r--toolkit/components/ml/tests/browser/browser_ml_engine.js219
-rw-r--r--toolkit/components/ml/tests/browser/head.js155
13 files changed, 1693 insertions, 0 deletions
diff --git a/toolkit/components/ml/actors/MLEngineChild.sys.mjs b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
new file mode 100644
index 0000000000..925ce59266
--- /dev/null
+++ b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
@@ -0,0 +1,325 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
+
+/**
+ * @typedef {import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
+ */
+
+/**
+ * @typedef {object} Lazy
+ * @property {typeof import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
+ * @property {typeof setTimeout} setTimeout
+ * @property {typeof clearTimeout} clearTimeout
+ */
+
+/** @type {Lazy} */
+const lazy = {};
+ChromeUtils.defineESModuleGetters(lazy, {
+ BasePromiseWorker: "resource://gre/modules/PromiseWorker.sys.mjs",
+ setTimeout: "resource://gre/modules/Timer.sys.mjs",
+ clearTimeout: "resource://gre/modules/Timer.sys.mjs",
+});
+
+ChromeUtils.defineLazyGetter(lazy, "console", () => {
+ return console.createInstance({
+ maxLogLevelPref: "browser.ml.logLevel",
+ prefix: "ML",
+ });
+});
+
+XPCOMUtils.defineLazyPreferenceGetter(
+ lazy,
+ "loggingLevel",
+ "browser.ml.logLevel"
+);
+
+/**
+ * The engine child is responsible for the life cycle and instantiation of the local
+ * machine learning inference engine.
+ */
+export class MLEngineChild extends JSWindowActorChild {
+ /**
+ * The cached engines.
+ *
+ * @type {Map<string, EngineDispatcher>}
+ */
+ #engineDispatchers = new Map();
+
+ // eslint-disable-next-line consistent-return
+ async receiveMessage({ name, data }) {
+ switch (name) {
+ case "MLEngine:NewPort": {
+ const { engineName, port, timeoutMS } = data;
+ this.#engineDispatchers.set(
+ engineName,
+ new EngineDispatcher(this, port, engineName, timeoutMS)
+ );
+ break;
+ }
+ case "MLEngine:ForceShutdown": {
+ for (const engineDispatcher of this.#engineDispatchers.values()) {
+ return engineDispatcher.terminate();
+ }
+ this.#engineDispatchers = null;
+ break;
+ }
+ }
+ }
+
+ handleEvent(event) {
+ switch (event.type) {
+ case "DOMContentLoaded":
+ this.sendAsyncMessage("MLEngine:Ready");
+ break;
+ }
+ }
+
+ /**
+ * @returns {ArrayBuffer}
+ */
+ getWasmArrayBuffer() {
+ return this.sendQuery("MLEngine:GetWasmArrayBuffer");
+ }
+
+ /**
+ * @param {string} engineName
+ */
+ removeEngine(engineName) {
+ this.#engineDispatchers.delete(engineName);
+ if (this.#engineDispatchers.size === 0) {
+ this.sendQuery("MLEngine:DestroyEngineProcess");
+ }
+ }
+}
+
+/**
+ * This classes manages the lifecycle of an ML Engine, and handles dispatching messages
+ * to it.
+ */
+class EngineDispatcher {
+ /** @type {Set<MessagePort>} */
+ #ports = new Set();
+
+ /** @type {TimeoutID | null} */
+ #keepAliveTimeout = null;
+
+ /** @type {PromiseWithResolvers} */
+ #modelRequest;
+
+ /** @type {Promise<Engine> | null} */
+ #engine = null;
+
+ /** @type {string} */
+ #engineName;
+
+ /**
+ * @param {MLEngineChild} mlEngineChild
+ * @param {MessagePort} port
+ * @param {string} engineName
+ * @param {number} timeoutMS
+ */
+ constructor(mlEngineChild, port, engineName, timeoutMS) {
+ /** @type {MLEngineChild} */
+ this.mlEngineChild = mlEngineChild;
+
+ /** @type {number} */
+ this.timeoutMS = timeoutMS;
+
+ this.#engineName = engineName;
+
+ this.#engine = Promise.all([
+ this.mlEngineChild.getWasmArrayBuffer(),
+ this.getModel(port),
+ ]).then(([wasm, model]) => FakeEngine.create(wasm, model));
+
+ this.#engine
+ .then(() => void this.keepAlive())
+ .catch(error => {
+ if (
+ // Ignore errors from tests intentionally causing errors.
+ !error?.message?.startsWith("Intentionally")
+ ) {
+ lazy.console.error("Could not initalize the engine", error);
+ }
+ });
+
+ this.setupMessageHandler(port);
+ }
+
+ /**
+ * The worker needs to be shutdown after some amount of time of not being used.
+ */
+ keepAlive() {
+ if (this.#keepAliveTimeout) {
+ // Clear any previous timeout.
+ lazy.clearTimeout(this.#keepAliveTimeout);
+ }
+ // In automated tests, the engine is manually destroyed.
+ if (!Cu.isInAutomation) {
+ this.#keepAliveTimeout = lazy.setTimeout(this.terminate, this.timeoutMS);
+ }
+ }
+
+ /**
+ * @param {MessagePort} port
+ */
+ getModel(port) {
+ if (this.#modelRequest) {
+ // There could be a race to get a model, use the first request.
+ return this.#modelRequest.promise;
+ }
+ this.#modelRequest = Promise.withResolvers();
+ port.postMessage({ type: "EnginePort:ModelRequest" });
+ return this.#modelRequest.promise;
+ }
+
+ /**
+ * @param {MessagePort} port
+ */
+ setupMessageHandler(port) {
+ port.onmessage = async ({ data }) => {
+ switch (data.type) {
+ case "EnginePort:Discard": {
+ port.close();
+ this.#ports.delete(port);
+ break;
+ }
+ case "EnginePort:Terminate": {
+ this.terminate();
+ break;
+ }
+ case "EnginePort:ModelResponse": {
+ if (this.#modelRequest) {
+ const { model, error } = data;
+ if (model) {
+ this.#modelRequest.resolve(model);
+ } else {
+ this.#modelRequest.reject(error);
+ }
+ this.#modelRequest = null;
+ } else {
+ lazy.console.error(
+ "Got a EnginePort:ModelResponse but no model resolvers"
+ );
+ }
+ break;
+ }
+ case "EnginePort:Run": {
+ const { requestId, request } = data;
+ let engine;
+ try {
+ engine = await this.#engine;
+ } catch (error) {
+ port.postMessage({
+ type: "EnginePort:RunResponse",
+ requestId,
+ response: null,
+ error,
+ });
+ // The engine failed to load. Terminate the entire dispatcher.
+ this.terminate();
+ return;
+ }
+
+ // Do not run the keepAlive timer until we are certain that the engine loaded,
+ // as the engine shouldn't be killed while it is initializing.
+ this.keepAlive();
+
+ try {
+ port.postMessage({
+ type: "EnginePort:RunResponse",
+ requestId,
+ response: await engine.run(request),
+ error: null,
+ });
+ } catch (error) {
+ port.postMessage({
+ type: "EnginePort:RunResponse",
+ requestId,
+ response: null,
+ error,
+ });
+ }
+ break;
+ }
+ default:
+ lazy.console.error("Unknown port message to engine: ", data);
+ break;
+ }
+ };
+ }
+
+ /**
+ * Terminates the engine and its worker after a timeout.
+ */
+ async terminate() {
+ if (this.#keepAliveTimeout) {
+ lazy.clearTimeout(this.#keepAliveTimeout);
+ this.#keepAliveTimeout = null;
+ }
+ for (const port of this.#ports) {
+ port.postMessage({ type: "EnginePort:EngineTerminated" });
+ port.close();
+ }
+ this.#ports = new Set();
+ this.mlEngineChild.removeEngine(this.#engineName);
+ try {
+ const engine = await this.#engine;
+ engine.terminate();
+ } catch (error) {
+ console.error("Failed to get the engine", error);
+ }
+ }
+}
+
+/**
+ * Fake the engine by slicing the text in half.
+ */
+class FakeEngine {
+ /** @type {BasePromiseWorker} */
+ #worker;
+
+ /**
+ * Initialize the worker.
+ *
+ * @param {ArrayBuffer} wasm
+ * @param {ArrayBuffer} model
+ * @returns {FakeEngine}
+ */
+ static async create(wasm, model) {
+ /** @type {BasePromiseWorker} */
+ const worker = new lazy.BasePromiseWorker(
+ "chrome://global/content/ml/MLEngine.worker.mjs",
+ { type: "module" }
+ );
+
+ const args = [wasm, model, lazy.loggingLevel];
+ const closure = {};
+ const transferables = [wasm, model];
+ await worker.post("initializeEngine", args, closure, transferables);
+ return new FakeEngine(worker);
+ }
+
+ /**
+ * @param {BasePromiseWorker} worker
+ */
+ constructor(worker) {
+ this.#worker = worker;
+ }
+
+ /**
+ * @param {string} request
+ * @returns {Promise<string>}
+ */
+ run(request) {
+ return this.#worker.post("run", [request]);
+ }
+
+ terminate() {
+ this.#worker.terminate();
+ this.#worker = null;
+ }
+}
diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
new file mode 100644
index 0000000000..10b4eed4fa
--- /dev/null
+++ b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
@@ -0,0 +1,403 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/**
+ * @typedef {object} Lazy
+ * @property {typeof console} console
+ * @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess
+ * @property {typeof import("../../../../services/settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings
+ * @property {typeof import("../../translations/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent
+ */
+
+/** @type {Lazy} */
+const lazy = {};
+
+ChromeUtils.defineLazyGetter(lazy, "console", () => {
+ return console.createInstance({
+ maxLogLevelPref: "browser.ml.logLevel",
+ prefix: "ML",
+ });
+});
+
+ChromeUtils.defineESModuleGetters(lazy, {
+ EngineProcess: "chrome://global/content/ml/EngineProcess.sys.mjs",
+ RemoteSettings: "resource://services-settings/remote-settings.sys.mjs",
+ TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs",
+});
+
+/**
+ * @typedef {import("../../translations/translations").WasmRecord} WasmRecord
+ */
+
+const DEFAULT_CACHE_TIMEOUT_MS = 15_000;
+
+/**
+ * The ML engine is in its own content process. This actor handles the
+ * marshalling of the data such as the engine payload.
+ */
+export class MLEngineParent extends JSWindowActorParent {
+ /**
+ * The RemoteSettingsClient that downloads the wasm binaries.
+ *
+ * @type {RemoteSettingsClient | null}
+ */
+ static #remoteClient = null;
+
+ /** @type {Promise<WasmRecord> | null} */
+ static #wasmRecord = null;
+
+ /**
+ * The following constant controls the major version for wasm downloaded from
+ * Remote Settings. When a breaking change is introduced, Nightly will have these
+ * numbers incremented by one, but Beta and Release will still be on the previous
+ * version. Remote Settings will ship both versions of the records, and the latest
+ * asset released in that version will be used. For instance, with a major version
+ * of "1", assets can be downloaded for "1.0", "1.2", "1.3beta", but assets marked
+ * as "2.0", "2.1", etc will not be downloaded.
+ */
+ static WASM_MAJOR_VERSION = 1;
+
+ /**
+ * Remote settings isn't available in tests, so provide mocked responses.
+ *
+ * @param {RemoteSettingsClient} remoteClient
+ */
+ static mockRemoteSettings(remoteClient) {
+ lazy.console.log("Mocking remote settings in MLEngineParent.");
+ MLEngineParent.#remoteClient = remoteClient;
+ MLEngineParent.#wasmRecord = null;
+ }
+
+ /**
+ * Remove anything that could have been mocked.
+ */
+ static removeMocks() {
+ lazy.console.log("Removing mocked remote client in MLEngineParent.");
+ MLEngineParent.#remoteClient = null;
+ MLEngineParent.#wasmRecord = null;
+ }
+
+ /**
+ * @param {string} engineName
+ * @param {() => Promise<ArrayBuffer>} getModel
+ * @param {number} cacheTimeoutMS - How long the engine cache remains alive between
+ * uses, in milliseconds. In automation the engine is manually created and destroyed
+ * to avoid timing issues.
+ * @returns {MLEngine}
+ */
+ getEngine(engineName, getModel, cacheTimeoutMS = DEFAULT_CACHE_TIMEOUT_MS) {
+ return new MLEngine(this, engineName, getModel, cacheTimeoutMS);
+ }
+
+ // eslint-disable-next-line consistent-return
+ async receiveMessage({ name, data }) {
+ switch (name) {
+ case "MLEngine:Ready":
+ if (lazy.EngineProcess.resolveMLEngineParent) {
+ lazy.EngineProcess.resolveMLEngineParent(this);
+ } else {
+ lazy.console.error(
+ "Expected #resolveMLEngineParent to exist when then ML Engine is ready."
+ );
+ }
+ break;
+ case "MLEngine:GetWasmArrayBuffer":
+ return MLEngineParent.getWasmArrayBuffer();
+ case "MLEngine:DestroyEngineProcess":
+ lazy.EngineProcess.destroyMLEngine().catch(error =>
+ console.error(error)
+ );
+ break;
+ }
+ }
+
+ /**
+ * @param {RemoteSettingsClient} client
+ */
+ static async #getWasmArrayRecord(client) {
+ // Load the wasm binary from remote settings, if it hasn't been already.
+ lazy.console.log(`Getting remote wasm records.`);
+
+ /** @type {WasmRecord[]} */
+ const wasmRecords = await lazy.TranslationsParent.getMaxVersionRecords(
+ client,
+ {
+ // TODO - This record needs to be created with the engine wasm payload.
+ filters: { name: "inference-engine" },
+ majorVersion: MLEngineParent.WASM_MAJOR_VERSION,
+ }
+ );
+
+ if (wasmRecords.length === 0) {
+ // The remote settings client provides an empty list of records when there is
+ // an error.
+ throw new Error("Unable to get the ML engine from Remote Settings.");
+ }
+
+ if (wasmRecords.length > 1) {
+ MLEngineParent.reportError(
+ new Error("Expected the ml engine to only have 1 record."),
+ wasmRecords
+ );
+ }
+ const [record] = wasmRecords;
+ lazy.console.log(
+ `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`,
+ record
+ );
+ return record;
+ }
+
+ /**
+ * Download the wasm for the ML inference engine.
+ *
+ * @returns {Promise<ArrayBuffer>}
+ */
+ static async getWasmArrayBuffer() {
+ const client = MLEngineParent.#getRemoteClient();
+
+ if (!MLEngineParent.#wasmRecord) {
+ // Place the records into a promise to prevent any races.
+ MLEngineParent.#wasmRecord = MLEngineParent.#getWasmArrayRecord(client);
+ }
+
+ let wasmRecord;
+ try {
+ wasmRecord = await MLEngineParent.#wasmRecord;
+ if (!wasmRecord) {
+ return Promise.reject(
+ "Error: Unable to get the ML engine from Remote Settings."
+ );
+ }
+ } catch (error) {
+ MLEngineParent.#wasmRecord = null;
+ throw error;
+ }
+
+ /** @type {{buffer: ArrayBuffer}} */
+ const { buffer } = await client.attachments.download(wasmRecord);
+
+ return buffer;
+ }
+
+ /**
+ * Lazily initializes the RemoteSettingsClient for the downloaded wasm binary data.
+ *
+ * @returns {RemoteSettingsClient}
+ */
+ static #getRemoteClient() {
+ if (MLEngineParent.#remoteClient) {
+ return MLEngineParent.#remoteClient;
+ }
+
+ /** @type {RemoteSettingsClient} */
+ const client = lazy.RemoteSettings("ml-wasm");
+
+ MLEngineParent.#remoteClient = client;
+
+ client.on("sync", async ({ data: { created, updated, deleted } }) => {
+ lazy.console.log(`"sync" event for ml-wasm`, {
+ created,
+ updated,
+ deleted,
+ });
+
+ // Remove all the deleted records.
+ for (const record of deleted) {
+ await client.attachments.deleteDownloaded(record);
+ }
+
+ // Remove any updated records, and download the new ones.
+ for (const { old: oldRecord } of updated) {
+ await client.attachments.deleteDownloaded(oldRecord);
+ }
+
+ // Do nothing for the created records.
+ });
+
+ return client;
+ }
+
+ /**
+ * Send a message to gracefully shutdown all of the ML engines in the engine process.
+ * This mostly exists for testing the shutdown paths of the code.
+ */
+ forceShutdown() {
+ return this.sendQuery("MLEngine:ForceShutdown");
+ }
+}
+
+/**
+ * This contains all of the information needed to perform a translation request.
+ *
+ * @typedef {object} TranslationRequest
+ * @property {Node} node
+ * @property {string} sourceText
+ * @property {boolean} isHTML
+ * @property {Function} resolve
+ * @property {Function} reject
+ */
+
+/**
+ * The interface to communicate to an MLEngine in the parent process. The engine manages
+ * its own lifetime, and is kept alive with a timeout. A reference to this engine can
+ * be retained, but once idle, the engine will be destroyed. If a new request to run
+ * is sent, the engine will be recreated on demand. This balances the cost of retaining
+ * potentially large amounts of memory to run models, with the speed and ease of running
+ * the engine.
+ *
+ * @template Request
+ * @template Response
+ */
+class MLEngine {
+ /**
+ * @type {MessagePort | null}
+ */
+ #port = null;
+
+ #nextRequestId = 0;
+
+ /**
+ * Tie together a message id to a resolved response.
+ *
+ * @type {Map<number, PromiseWithResolvers<Request>>}
+ */
+ #requests = new Map();
+
+ /**
+ * @type {"uninitialized" | "ready" | "error" | "closed"}
+ */
+ engineStatus = "uninitialized";
+
+ /**
+ * @param {MLEngineParent} mlEngineParent
+ * @param {string} engineName
+ * @param {() => Promise<ArrayBuffer>} getModel
+ * @param {number} timeoutMS
+ */
+ constructor(mlEngineParent, engineName, getModel, timeoutMS) {
+ /** @type {MLEngineParent} */
+ this.mlEngineParent = mlEngineParent;
+ /** @type {string} */
+ this.engineName = engineName;
+ /** @type {() => Promise<ArrayBuffer>} */
+ this.getModel = getModel;
+ /** @type {number} */
+ this.timeoutMS = timeoutMS;
+
+ this.#setupPortCommunication();
+ }
+
+ /**
+ * Create a MessageChannel to communicate with the engine directly.
+ */
+ #setupPortCommunication() {
+ const { port1: childPort, port2: parentPort } = new MessageChannel();
+ const transferables = [childPort];
+ this.#port = parentPort;
+
+ this.#port.onmessage = this.handlePortMessage;
+ this.mlEngineParent.sendAsyncMessage(
+ "MLEngine:NewPort",
+ {
+ port: childPort,
+ engineName: this.engineName,
+ timeoutMS: this.timeoutMS,
+ },
+ transferables
+ );
+ }
+
+ handlePortMessage = ({ data }) => {
+ switch (data.type) {
+ case "EnginePort:ModelRequest": {
+ if (this.#port) {
+ this.getModel().then(
+ model => {
+ this.#port.postMessage({
+ type: "EnginePort:ModelResponse",
+ model,
+ error: null,
+ });
+ },
+ error => {
+ this.#port.postMessage({
+ type: "EnginePort:ModelResponse",
+ model: null,
+ error,
+ });
+ if (
+ // Ignore intentional errors in tests.
+ !error?.message.startsWith("Intentionally")
+ ) {
+ lazy.console.error("Failed to get the model", error);
+ }
+ }
+ );
+ } else {
+ lazy.console.error(
+ "Expected a port to exist during the EnginePort:GetModel event"
+ );
+ }
+ break;
+ }
+ case "EnginePort:RunResponse": {
+ const { response, error, requestId } = data;
+ const request = this.#requests.get(requestId);
+ if (request) {
+ if (response) {
+ request.resolve(response);
+ } else {
+ request.reject(error);
+ }
+ } else {
+ lazy.console.error(
+ "Could not resolve response in the MLEngineParent",
+ data
+ );
+ }
+ this.#requests.delete(requestId);
+ break;
+ }
+ case "EnginePort:EngineTerminated": {
+ // The engine was terminated, and if a new run is needed a new port
+ // will need to be requested.
+ this.engineStatus = "closed";
+ this.discardPort();
+ break;
+ }
+ default:
+ lazy.console.error("Unknown port message from engine", data);
+ break;
+ }
+ };
+
+ discardPort() {
+ if (this.#port) {
+ this.#port.postMessage({ type: "EnginePort:Discard" });
+ this.#port.close();
+ this.#port = null;
+ }
+ }
+
+ terminate() {
+ this.#port.postMessage({ type: "EnginePort:Terminate" });
+ }
+
+ /**
+ * @param {Request} request
+ * @returns {Promise<Response>}
+ */
+ run(request) {
+ const resolvers = Promise.withResolvers();
+ const requestId = this.#nextRequestId++;
+ this.#requests.set(requestId, resolvers);
+ this.#port.postMessage({
+ type: "EnginePort:Run",
+ requestId,
+ request,
+ });
+ return resolvers.promise;
+ }
+}
diff --git a/toolkit/components/ml/actors/moz.build b/toolkit/components/ml/actors/moz.build
new file mode 100644
index 0000000000..de3e27ae2a
--- /dev/null
+++ b/toolkit/components/ml/actors/moz.build
@@ -0,0 +1,8 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+FINAL_TARGET_FILES.actors += [
+ "MLEngineChild.sys.mjs",
+ "MLEngineParent.sys.mjs",
+]
diff --git a/toolkit/components/ml/content/EngineProcess.sys.mjs b/toolkit/components/ml/content/EngineProcess.sys.mjs
new file mode 100644
index 0000000000..36a9381192
--- /dev/null
+++ b/toolkit/components/ml/content/EngineProcess.sys.mjs
@@ -0,0 +1,241 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+const lazy = {};
+ChromeUtils.defineESModuleGetters(lazy, {
+ HiddenFrame: "resource://gre/modules/HiddenFrame.sys.mjs",
+});
+
+/**
+ * @typedef {import("../actors/MLEngineParent.sys.mjs").MLEngineParent} MLEngineParent
+ */
+
+/**
+ * @typedef {import("../../translations/actors/TranslationsEngineParent.sys.mjs").TranslationsEngineParent} TranslationsEngineParent
+ */
+
+/**
+ * This class controls the life cycle of the engine process used both in the
+ * Translations engine and the MLEngine component.
+ */
+export class EngineProcess {
+ /**
+ * @type {Promise<{ hiddenFrame: HiddenFrame, actor: TranslationsEngineParent }> | null}
+ */
+
+ /** @type {Promise<HiddenFrame> | null} */
+ static #hiddenFrame = null;
+ /** @type {Promise<TranslationsEngineParent> | null} */
+ static translationsEngineParent = null;
+ /** @type {Promise<MLEngineParent> | null} */
+ static mlEngineParent = null;
+
+ /** @type {((actor: TranslationsEngineParent) => void) | null} */
+ resolveTranslationsEngineParent = null;
+
+ /** @type {((actor: MLEngineParent) => void) | null} */
+ resolveMLEngineParent = null;
+
+ /**
+ * See if all engines are terminated. This is useful for testing.
+ *
+ * @returns {boolean}
+ */
+ static areAllEnginesTerminated() {
+ return (
+ !EngineProcess.#hiddenFrame &&
+ !EngineProcess.translationsEngineParent &&
+ !EngineProcess.mlEngineParent
+ );
+ }
+
+ /**
+ * @returns {Promise<TranslationsEngineParent>}
+ */
+ static async getTranslationsEngineParent() {
+ if (!this.translationsEngineParent) {
+ this.translationsEngineParent = this.#attachBrowser({
+ id: "translations-engine-browser",
+ url: "chrome://global/content/translations/translations-engine.html",
+ resolverName: "resolveTranslationsEngineParent",
+ });
+ }
+ return this.translationsEngineParent;
+ }
+
+ /**
+ * @returns {Promise<MLEngineParent>}
+ */
+ static async getMLEngineParent() {
+ if (!this.mlEngineParent) {
+ this.mlEngineParent = this.#attachBrowser({
+ id: "ml-engine-browser",
+ url: "chrome://global/content/ml/MLEngine.html",
+ resolverName: "resolveMLEngineParent",
+ });
+ }
+ return this.mlEngineParent;
+ }
+
+ /**
+ * @param {object} config
+ * @param {string} config.url
+ * @param {string} config.id
+ * @param {string} config.resolverName
+ * @returns {Promise<TranslationsEngineParent>}
+ */
+ static async #attachBrowser({ url, id, resolverName }) {
+ const hiddenFrame = await this.#getHiddenFrame();
+ const chromeWindow = await hiddenFrame.get();
+ const doc = chromeWindow.document;
+
+ if (doc.getElementById(id)) {
+ throw new Error(
+ "Attempting to append the translations-engine.html <browser> when one " +
+ "already exists."
+ );
+ }
+
+ const browser = doc.createXULElement("browser");
+ browser.setAttribute("id", id);
+ browser.setAttribute("remote", "true");
+ browser.setAttribute("remoteType", "web");
+ browser.setAttribute("disableglobalhistory", "true");
+ browser.setAttribute("type", "content");
+ browser.setAttribute("src", url);
+
+ ChromeUtils.addProfilerMarker(
+ "EngineProcess",
+ {},
+ `Creating the "${id}" process`
+ );
+ doc.documentElement.appendChild(browser);
+
+ const { promise, resolve } = Promise.withResolvers();
+
+ // The engine parents must resolve themselves when they are ready.
+ this[resolverName] = resolve;
+
+ return promise;
+ }
+
+ /**
+ * @returns {HiddenFrame}
+ */
+ static async #getHiddenFrame() {
+ if (!EngineProcess.#hiddenFrame) {
+ EngineProcess.#hiddenFrame = new lazy.HiddenFrame();
+ }
+ return EngineProcess.#hiddenFrame;
+ }
+
+ /**
+ * Destroy the translations engine, and remove the hidden frame if no other
+ * engines exist.
+ */
+ static destroyTranslationsEngine() {
+ return this.#destroyEngine({
+ id: "translations-engine-browser",
+ keyName: "translationsEngineParent",
+ });
+ }
+
+ /**
+ * Destroy the ML engine, and remove the hidden frame if no other engines exist.
+ */
+ static destroyMLEngine() {
+ return this.#destroyEngine({
+ id: "ml-engine-browser",
+ keyName: "mlEngineParent",
+ });
+ }
+
+ /**
+ * Destroy the specified engine and maybe the entire hidden frame as well if no engines
+ * are remaining.
+ */
+ static #destroyEngine({ id, keyName }) {
+ ChromeUtils.addProfilerMarker(
+ "EngineProcess",
+ {},
+ `Destroying the "${id}" engine`
+ );
+
+ const actorShutdown = this.forceActorShutdown(id, keyName).catch(
+ error => void console.error(error)
+ );
+
+ this[keyName] = null;
+
+ const hiddenFrame = EngineProcess.#hiddenFrame;
+ if (hiddenFrame && !this.translationsEngineParent && !this.mlEngineParent) {
+ EngineProcess.#hiddenFrame = null;
+
+ // Both actors are destroyed, also destroy the hidden frame.
+ actorShutdown.then(() => {
+ // Double check a race condition that no new actors have been created during
+ // shutdown.
+ if (this.translationsEngineParent && this.mlEngineParent) {
+ return;
+ }
+ if (!hiddenFrame) {
+ return;
+ }
+ hiddenFrame.destroy();
+ ChromeUtils.addProfilerMarker(
+ "EngineProcess",
+ {},
+ `Removing the hidden frame`
+ );
+ });
+ }
+
+ // Infallibly resolve the promise even if there are errors.
+ return Promise.resolve();
+ }
+
+ /**
+ * Shut down an actor and remove its <browser> element.
+ *
+ * @param {string} id
+ * @param {string} keyName
+ */
+ static async forceActorShutdown(id, keyName) {
+ const actorPromise = this[keyName];
+ if (!actorPromise) {
+ return;
+ }
+
+ let actor;
+ try {
+ actor = await actorPromise;
+ } catch {
+ // The actor failed to initialize, so it doesn't need to be shut down.
+ return;
+ }
+
+ // Shut down the actor.
+ try {
+ await actor.forceShutdown();
+ } catch (error) {
+ console.error("Failed to shut down the actor " + id, error);
+ return;
+ }
+
+ if (!EngineProcess.#hiddenFrame) {
+ // The hidden frame was already removed.
+ return;
+ }
+
+ // Remove the <brower> element.
+ const chromeWindow = EngineProcess.#hiddenFrame.getWindow();
+ const doc = chromeWindow.document;
+ const element = doc.getElementById(id);
+ if (!element) {
+ console.error("Could not find the <browser> element for " + id);
+ return;
+ }
+ element.remove();
+ }
+}
diff --git a/toolkit/components/ml/content/MLEngine.html b/toolkit/components/ml/content/MLEngine.html
new file mode 100644
index 0000000000..8763995102
--- /dev/null
+++ b/toolkit/components/ml/content/MLEngine.html
@@ -0,0 +1,16 @@
+<!-- This Source Code Form is subject to the terms of the Mozilla Public
+ - License, v. 2.0. If a copy of the MPL was not distributed with this
+ - file, You can obtain one at http://mozilla.org/MPL/2.0/. -->
+
+<!DOCTYPE html>
+<html>
+ <head>
+ <meta charset="utf-8" />
+ <meta
+ http-equiv="Content-Security-Policy"
+ content="default-src chrome: resource:; object-src 'none'"
+ />
+ <!-- Run the machine learning inference engine in its own singleton content process. -->
+ </head>
+ <body></body>
+</html>
diff --git a/toolkit/components/ml/content/MLEngine.worker.mjs b/toolkit/components/ml/content/MLEngine.worker.mjs
new file mode 100644
index 0000000000..1013977e07
--- /dev/null
+++ b/toolkit/components/ml/content/MLEngine.worker.mjs
@@ -0,0 +1,91 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+import { PromiseWorker } from "resource://gre/modules/workers/PromiseWorker.mjs";
+
+// Respect the preference "browser.ml.logLevel".
+let _loggingLevel = "Error";
+function log(...args) {
+ if (_loggingLevel !== "Error" && _loggingLevel !== "Warn") {
+ console.log("ML:", ...args);
+ }
+}
+function trace(...args) {
+ if (_loggingLevel === "Trace" || _loggingLevel === "All") {
+ console.log("ML:", ...args);
+ }
+}
+
+/**
+ * The actual MLEngine lives here in a worker.
+ */
+class MLEngineWorker {
+ /** @type {ArrayBuffer} */
+ #wasm;
+ /** @type {ArrayBuffer} */
+ #model;
+
+ constructor() {
+ // Connect the provider to the worker.
+ this.#connectToPromiseWorker();
+ }
+
+ /**
+ * @param {ArrayBuffer} wasm
+ * @param {ArrayBuffer} model
+ * @param {string} loggingLevel
+ */
+ initializeEngine(wasm, model, loggingLevel) {
+ this.#wasm = wasm;
+ this.#model = model;
+ _loggingLevel = loggingLevel;
+ // TODO - Initialize the engine for real here.
+ log("MLEngineWorker is initalized");
+ }
+
+ /**
+ * Run the worker.
+ *
+ * @param {string} request
+ */
+ run(request) {
+ if (!this.#wasm) {
+ throw new Error("Expected the wasm to exist.");
+ }
+ if (!this.#model) {
+ throw new Error("Expected the model to exist");
+ }
+ if (request === "throw") {
+ throw new Error(
+ 'Received the message "throw", so intentionally throwing an error.'
+ );
+ }
+ trace("inference run requested with:", request);
+ return request.slice(0, Math.floor(request.length / 2));
+ }
+
+ /**
+ * Glue code to connect the `MLEngineWorker` to the PromiseWorker interface.
+ */
+ #connectToPromiseWorker() {
+ const worker = new PromiseWorker.AbstractWorker();
+ worker.dispatch = (method, args = []) => {
+ if (!this[method]) {
+ throw new Error("Method does not exist: " + method);
+ }
+ return this[method](...args);
+ };
+ worker.close = () => self.close();
+ worker.postMessage = (message, ...transfers) => {
+ self.postMessage(message, ...transfers);
+ };
+
+ self.addEventListener("message", msg => worker.handleMessage(msg));
+ self.addEventListener("unhandledrejection", function (error) {
+ throw error.reason;
+ });
+ }
+}
+
+new MLEngineWorker();
diff --git a/toolkit/components/ml/content/SummarizerModel.sys.mjs b/toolkit/components/ml/content/SummarizerModel.sys.mjs
new file mode 100644
index 0000000000..7cac55d92f
--- /dev/null
+++ b/toolkit/components/ml/content/SummarizerModel.sys.mjs
@@ -0,0 +1,160 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/**
+ * @typedef {object} LazyImports
+ * @property {typeof import("../actors/MLEngineParent.sys.mjs").MLEngineParent} MLEngineParent
+ */
+
+/** @type {LazyImports} */
+const lazy = {};
+
+ChromeUtils.defineESModuleGetters(lazy, {
+ RemoteSettings: "resource://services-settings/remote-settings.sys.mjs",
+ TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs",
+});
+
+ChromeUtils.defineLazyGetter(lazy, "console", () => {
+ return console.createInstance({
+ maxLogLevelPref: "browser.ml.logLevel",
+ prefix: "ML",
+ });
+});
+
+export class SummarizerModel {
+ /**
+ * The RemoteSettingsClient that downloads the summarizer model.
+ *
+ * @type {RemoteSettingsClient | null}
+ */
+ static #remoteClient = null;
+
+ /** @type {Promise<WasmRecord> | null} */
+ static #modelRecord = null;
+
+ /**
+ * The following constant controls the major version for wasm downloaded from
+ * Remote Settings. When a breaking change is introduced, Nightly will have these
+ * numbers incremented by one, but Beta and Release will still be on the previous
+ * version. Remote Settings will ship both versions of the records, and the latest
+ * asset released in that version will be used. For instance, with a major version
+ * of "1", assets can be downloaded for "1.0", "1.2", "1.3beta", but assets marked
+ * as "2.0", "2.1", etc will not be downloaded.
+ */
+ static MODEL_MAJOR_VERSION = 1;
+
+ /**
+ * Remote settings isn't available in tests, so provide mocked responses.
+ */
+ static mockRemoteSettings(remoteClient) {
+ lazy.console.log("Mocking remote client in SummarizerModel.");
+ SummarizerModel.#remoteClient = remoteClient;
+ SummarizerModel.#modelRecord = null;
+ }
+
+ /**
+ * Remove anything that could have been mocked.
+ */
+ static removeMocks() {
+ lazy.console.log("Removing mocked remote client in SummarizerModel.");
+ SummarizerModel.#remoteClient = null;
+ SummarizerModel.#modelRecord = null;
+ }
+ /**
+ * Download or load the model from remote settings.
+ *
+ * @returns {Promise<ArrayBuffer>}
+ */
+ static async getModel() {
+ const client = SummarizerModel.#getRemoteClient();
+
+ if (!SummarizerModel.#modelRecord) {
+ // Place the records into a promise to prevent any races.
+ SummarizerModel.#modelRecord = (async () => {
+ // Load the wasm binary from remote settings, if it hasn't been already.
+ lazy.console.log(`Getting the summarizer model record.`);
+
+ // TODO - The getMaxVersionRecords should eventually migrated to some kind of
+ // shared utility.
+ const { getMaxVersionRecords } = lazy.TranslationsParent;
+
+ /** @type {WasmRecord[]} */
+ const wasmRecords = await getMaxVersionRecords(client, {
+ // TODO - This record needs to be created with the engine wasm payload.
+ filters: { name: "summarizer-model" },
+ majorVersion: SummarizerModel.MODEL_MAJOR_VERSION,
+ });
+
+ if (wasmRecords.length === 0) {
+ // The remote settings client provides an empty list of records when there is
+ // an error.
+ throw new Error("Unable to get the models from Remote Settings.");
+ }
+
+ if (wasmRecords.length > 1) {
+ SummarizerModel.reportError(
+ new Error("Expected the ml engine to only have 1 record."),
+ wasmRecords
+ );
+ }
+ const [record] = wasmRecords;
+ lazy.console.log(
+ `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`,
+ record
+ );
+ return record;
+ })();
+ }
+
+ try {
+ /** @type {{buffer: ArrayBuffer}} */
+ const { buffer } = await client.attachments.download(
+ await SummarizerModel.#modelRecord
+ );
+
+ return buffer;
+ } catch (error) {
+ SummarizerModel.#modelRecord = null;
+ throw error;
+ }
+ }
+
+ /**
+ * Lazily initializes the RemoteSettingsClient.
+ *
+ * @returns {RemoteSettingsClient}
+ */
+ static #getRemoteClient() {
+ if (SummarizerModel.#remoteClient) {
+ return SummarizerModel.#remoteClient;
+ }
+
+ /** @type {RemoteSettingsClient} */
+ const client = lazy.RemoteSettings("ml-model");
+
+ SummarizerModel.#remoteClient = client;
+
+ client.on("sync", async ({ data: { created, updated, deleted } }) => {
+ lazy.console.log(`"sync" event for ml-model`, {
+ created,
+ updated,
+ deleted,
+ });
+
+ // Remove all the deleted records.
+ for (const record of deleted) {
+ await client.attachments.deleteDownloaded(record);
+ }
+
+ // Remove any updated records, and download the new ones.
+ for (const { old: oldRecord } of updated) {
+ await client.attachments.deleteDownloaded(oldRecord);
+ }
+
+ // Do nothing for the created records.
+ });
+
+ return client;
+ }
+}
diff --git a/toolkit/components/ml/docs/index.md b/toolkit/components/ml/docs/index.md
new file mode 100644
index 0000000000..1b2015456b
--- /dev/null
+++ b/toolkit/components/ml/docs/index.md
@@ -0,0 +1,44 @@
+# Machine Learning
+
+This component is an experimental machine learning local inference engine. Currently there is no inference engine actually integrated yet.
+
+Here is an example of the API:
+
+```js
+// The engine process manages the life cycle of the engine. It runs in its own process.
+// Models can consume large amounts of memory, and this helps encapsulate it at the
+// operating system level.
+const EngineProcess = ChromeUtils.importESModule("chrome://global/content/ml/EngineProcess.sys.mjs");
+
+// The MLEngineParent is a JSActor that can communicate with the engine process.
+const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+
+/**
+ * When implementing a model, there should be a class that provides a `getModel` function
+ * that is responsible for providing the `ArrayBuffer` of the model. Typically this
+ * download is managed by RemoteSettings.
+ */
+class SummarizerModel {
+ /**
+ * @returns {ArrayBuffer}
+ */
+ static getModel() { ... }
+}
+
+// An engine can be created using a unique name for the engine, and the function
+// to get the model. This class handles the life cycle of the engine.
+const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+);
+
+// In order to run the model, use the `run` method. This will initiate the engine if
+// it is needed, and return the result. The messaging to the engine process happens
+// through a MessagePort.
+const result = await summarizer.run("A sentence that can be summarized.")
+
+// The engine can be explicitly terminated, or it will be destroyed through an idle
+// timeout when not in use, as the memory requirements for models can be quite large.
+summarizer.terminate();
+```
diff --git a/toolkit/components/ml/jar.mn b/toolkit/components/ml/jar.mn
new file mode 100644
index 0000000000..56bfb0d469
--- /dev/null
+++ b/toolkit/components/ml/jar.mn
@@ -0,0 +1,9 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+toolkit.jar:
+ content/global/ml/EngineProcess.sys.mjs (content/EngineProcess.sys.mjs)
+ content/global/ml/MLEngine.worker.mjs (content/MLEngine.worker.mjs)
+ content/global/ml/MLEngine.html (content/MLEngine.html)
+ content/global/ml/SummarizerModel.sys.mjs (content/SummarizerModel.sys.mjs)
diff --git a/toolkit/components/ml/moz.build b/toolkit/components/ml/moz.build
new file mode 100644
index 0000000000..3308d8f085
--- /dev/null
+++ b/toolkit/components/ml/moz.build
@@ -0,0 +1,17 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+SPHINX_TREES["/toolkit/components/ml"] = "docs"
+
+JAR_MANIFESTS += ["jar.mn"]
+
+with Files("**"):
+ BUG_COMPONENT = ("Core", "Machine Learning")
+
+DIRS += ["actors"]
+
+BROWSER_CHROME_MANIFESTS += ["tests/browser/browser.toml"]
+
+with Files("docs/**"):
+ SCHEDULES.exclusive = ["docs"]
diff --git a/toolkit/components/ml/tests/browser/browser.toml b/toolkit/components/ml/tests/browser/browser.toml
new file mode 100644
index 0000000000..9ccda0beaa
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/browser.toml
@@ -0,0 +1,5 @@
+[DEFAULT]
+support-files = [
+ "head.js",
+]
+["browser_ml_engine.js"]
diff --git a/toolkit/components/ml/tests/browser/browser_ml_engine.js b/toolkit/components/ml/tests/browser/browser_ml_engine.js
new file mode 100644
index 0000000000..6942809d6d
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/browser_ml_engine.js
@@ -0,0 +1,219 @@
+/* Any copyright is dedicated to the Public Domain.
+ http://creativecommons.org/publicdomain/zero/1.0/ */
+
+"use strict";
+
+/// <reference path="head.js" />
+
+async function setup({ disabled = false, prefs = [] } = {}) {
+ const { removeMocks, remoteClients } = await createAndMockMLRemoteSettings({
+ autoDownloadFromRemoteSettings: false,
+ });
+
+ await SpecialPowers.pushPrefEnv({
+ set: [
+ // Enabled by default.
+ ["browser.ml.enable", !disabled],
+ ["browser.ml.logLevel", "All"],
+ ...prefs,
+ ],
+ });
+
+ return {
+ remoteClients,
+ async cleanup() {
+ await removeMocks();
+ await waitForCondition(
+ () => EngineProcess.areAllEnginesTerminated(),
+ "Waiting for all of the engines to be terminated.",
+ 100,
+ 200
+ );
+ },
+ };
+}
+
+add_task(async function test_ml_engine_basics() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer");
+ const summarizePromise = summarizer.run("This gets cut in half.");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.models.resolvePendingDownloads(1);
+ await remoteClients.wasm.resolvePendingDownloads(1);
+
+ is(
+ await summarizePromise,
+ "This gets c",
+ "The text gets cut in half simulating summarizing"
+ );
+
+ ok(
+ !EngineProcess.areAllEnginesTerminated(),
+ "The engine process is still active."
+ );
+
+ await EngineProcess.destroyMLEngine();
+
+ await cleanup();
+});
+
+add_task(async function test_ml_engine_model_rejection() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer");
+ const summarizePromise = summarizer.run("This gets cut in half.");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.wasm.resolvePendingDownloads(1);
+ await remoteClients.models.rejectPendingDownloads(1);
+
+ let error;
+ try {
+ await summarizePromise;
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ "Intentionally rejecting downloads.",
+ "The error is correctly surfaced."
+ );
+
+ await cleanup();
+});
+
+add_task(async function test_ml_engine_wasm_rejection() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer");
+ const summarizePromise = summarizer.run("This gets cut in half.");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.wasm.rejectPendingDownloads(1);
+ await remoteClients.models.resolvePendingDownloads(1);
+
+ let error;
+ try {
+ await summarizePromise;
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ "Intentionally rejecting downloads.",
+ "The error is correctly surfaced."
+ );
+
+ await cleanup();
+});
+
+/**
+ * Tests that the SummarizerModel's internal errors are correctly surfaced.
+ */
+add_task(async function test_ml_engine_model_error() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer with a throwing example.");
+ const summarizePromise = summarizer.run("throw");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.wasm.resolvePendingDownloads(1);
+ await remoteClients.models.resolvePendingDownloads(1);
+
+ let error;
+ try {
+ await summarizePromise;
+ } catch (e) {
+ error = e;
+ }
+ is(
+ error?.message,
+ 'Error: Received the message "throw", so intentionally throwing an error.',
+ "The error is correctly surfaced."
+ );
+
+ summarizer.terminate();
+
+ await cleanup();
+});
+
+/**
+ * This test is really similar to the "basic" test, but tests manually destroying
+ * the summarizer.
+ */
+add_task(async function test_ml_engine_destruction() {
+ const { cleanup, remoteClients } = await setup();
+
+ info("Get the engine process");
+ const mlEngineParent = await EngineProcess.getMLEngineParent();
+
+ info("Get summarizer");
+ const summarizer = mlEngineParent.getEngine(
+ "summarizer",
+ SummarizerModel.getModel
+ );
+
+ info("Run the summarizer");
+ const summarizePromise = summarizer.run("This gets cut in half.");
+
+ info("Wait for the pending downloads.");
+ await remoteClients.models.resolvePendingDownloads(1);
+ await remoteClients.wasm.resolvePendingDownloads(1);
+
+ is(
+ await summarizePromise,
+ "This gets c",
+ "The text gets cut in half simulating summarizing"
+ );
+
+ ok(
+ !EngineProcess.areAllEnginesTerminated(),
+ "The engine process is still active."
+ );
+
+ summarizer.terminate();
+
+ info(
+ "The summarizer is manually destroyed. The cleanup function should wait for the engine process to be destroyed."
+ );
+
+ await cleanup();
+});
diff --git a/toolkit/components/ml/tests/browser/head.js b/toolkit/components/ml/tests/browser/head.js
new file mode 100644
index 0000000000..99d27ce18a
--- /dev/null
+++ b/toolkit/components/ml/tests/browser/head.js
@@ -0,0 +1,155 @@
+/* Any copyright is dedicated to the Public Domain.
+ http://creativecommons.org/publicdomain/zero/1.0/ */
+
+/// <reference path="../../../../../toolkit/components/translations/tests/browser/shared-head.js" />
+
+"use strict";
+
+/**
+ * @type {import("../../content/SummarizerModel.sys.mjs")}
+ */
+const { SummarizerModel } = ChromeUtils.importESModule(
+ "chrome://global/content/ml/SummarizerModel.sys.mjs"
+);
+
+/**
+ * @type {import("../../actors/MLEngineParent.sys.mjs")}
+ */
+const { MLEngineParent } = ChromeUtils.importESModule(
+ "resource://gre/actors/MLEngineParent.sys.mjs"
+);
+
+// This test suite shares some utility functions with translations as they work in a very
+// similar fashion. Eventually, the plan is to unify these two components.
+Services.scriptloader.loadSubScript(
+ "chrome://mochitests/content/browser/toolkit/components/translations/tests/browser/shared-head.js",
+ this
+);
+
+function getDefaultModelRecords() {
+ return [
+ {
+ name: "summarizer-model",
+ version: SummarizerModel.MODEL_MAJOR_VERSION + ".0",
+ },
+ ];
+}
+
+function getDefaultWasmRecords() {
+ return [
+ {
+ name: "inference-engine",
+ version: MLEngineParent.WASM_MAJOR_VERSION + ".0",
+ },
+ ];
+}
+
+/**
+ * Creates a local RemoteSettingsClient for use within tests.
+ *
+ * @param {boolean} autoDownloadFromRemoteSettings
+ * @param {object[]} models
+ * @returns {AttachmentMock}
+ */
+async function createMLModelsRemoteClient(
+ autoDownloadFromRemoteSettings,
+ models = getDefaultModelRecords()
+) {
+ const { RemoteSettings } = ChromeUtils.importESModule(
+ "resource://services-settings/remote-settings.sys.mjs"
+ );
+ const mockedCollectionName = "test-ml-models";
+ const client = RemoteSettings(
+ `${mockedCollectionName}-${_remoteSettingsMockId++}`
+ );
+ const metadata = {};
+ await client.db.clear();
+ await client.db.importChanges(
+ metadata,
+ Date.now(),
+ models.map(({ name, version }) => ({
+ id: crypto.randomUUID(),
+ name,
+ version,
+ last_modified: Date.now(),
+ schema: Date.now(),
+ attachment: {
+ hash: `${crypto.randomUUID()}`,
+ size: `123`,
+ filename: name,
+ location: `main-workspace/ml-models/${crypto.randomUUID()}.bin`,
+ mimetype: "application/octet-stream",
+ },
+ }))
+ );
+
+ return createAttachmentMock(
+ client,
+ mockedCollectionName,
+ autoDownloadFromRemoteSettings
+ );
+}
+
+async function createAndMockMLRemoteSettings({
+ models = getDefaultModelRecords(),
+ autoDownloadFromRemoteSettings = false,
+} = {}) {
+ const remoteClients = {
+ models: await createMLModelsRemoteClient(
+ autoDownloadFromRemoteSettings,
+ models
+ ),
+ wasm: await createMLWasmRemoteClient(autoDownloadFromRemoteSettings),
+ };
+
+ MLEngineParent.mockRemoteSettings(remoteClients.wasm.client);
+ SummarizerModel.mockRemoteSettings(remoteClients.models.client);
+
+ return {
+ async removeMocks() {
+ await remoteClients.models.client.attachments.deleteAll();
+ await remoteClients.models.client.db.clear();
+ await remoteClients.wasm.client.attachments.deleteAll();
+ await remoteClients.wasm.client.db.clear();
+
+ MLEngineParent.removeMocks();
+ SummarizerModel.removeMocks();
+ },
+ remoteClients,
+ };
+}
+
+/**
+ * Creates a local RemoteSettingsClient for use within tests.
+ *
+ * @param {boolean} autoDownloadFromRemoteSettings
+ * @returns {AttachmentMock}
+ */
+async function createMLWasmRemoteClient(autoDownloadFromRemoteSettings) {
+ const { RemoteSettings } = ChromeUtils.importESModule(
+ "resource://services-settings/remote-settings.sys.mjs"
+ );
+ const mockedCollectionName = "test-translation-wasm";
+ const client = RemoteSettings(
+ `${mockedCollectionName}-${_remoteSettingsMockId++}`
+ );
+ const metadata = {};
+ await client.db.clear();
+ await client.db.importChanges(
+ metadata,
+ Date.now(),
+ getDefaultWasmRecords().map(({ name, version }) => ({
+ id: crypto.randomUUID(),
+ name,
+ version,
+ last_modified: Date.now(),
+ schema: Date.now(),
+ }))
+ );
+
+ return createAttachmentMock(
+ client,
+ mockedCollectionName,
+ autoDownloadFromRemoteSettings
+ );
+}