summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/actors
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 05:43:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 05:43:14 +0000
commit8dd16259287f58f9273002717ec4d27e97127719 (patch)
tree3863e62a53829a84037444beab3abd4ed9dfc7d0 /toolkit/components/ml/actors
parentReleasing progress-linux version 126.0.1-1~progress7.99u1. (diff)
downloadfirefox-8dd16259287f58f9273002717ec4d27e97127719.tar.xz
firefox-8dd16259287f58f9273002717ec4d27e97127719.zip
Merging upstream version 127.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'toolkit/components/ml/actors')
-rw-r--r--toolkit/components/ml/actors/MLEngineChild.sys.mjs156
-rw-r--r--toolkit/components/ml/actors/MLEngineParent.sys.mjs175
2 files changed, 233 insertions, 98 deletions
diff --git a/toolkit/components/ml/actors/MLEngineChild.sys.mjs b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
index 925ce59266..17a8b3511a 100644
--- a/toolkit/components/ml/actors/MLEngineChild.sys.mjs
+++ b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
@@ -21,6 +21,8 @@ ChromeUtils.defineESModuleGetters(lazy, {
BasePromiseWorker: "resource://gre/modules/PromiseWorker.sys.mjs",
setTimeout: "resource://gre/modules/Timer.sys.mjs",
clearTimeout: "resource://gre/modules/Timer.sys.mjs",
+ ModelHub: "chrome://global/content/ml/ModelHub.sys.mjs",
+ PipelineOptions: "chrome://global/content/ml/EngineProcess.sys.mjs",
});
ChromeUtils.defineLazyGetter(lazy, "console", () => {
@@ -32,9 +34,20 @@ ChromeUtils.defineLazyGetter(lazy, "console", () => {
XPCOMUtils.defineLazyPreferenceGetter(
lazy,
- "loggingLevel",
- "browser.ml.logLevel"
+ "CACHE_TIMEOUT_MS",
+ "browser.ml.modelCacheTimeout"
);
+XPCOMUtils.defineLazyPreferenceGetter(
+ lazy,
+ "MODEL_HUB_ROOT_URL",
+ "browser.ml.modelHubRootUrl"
+);
+XPCOMUtils.defineLazyPreferenceGetter(
+ lazy,
+ "MODEL_HUB_URL_TEMPLATE",
+ "browser.ml.modelHubUrlTemplate"
+);
+XPCOMUtils.defineLazyPreferenceGetter(lazy, "LOG_LEVEL", "browser.ml.logLevel");
/**
* The engine child is responsible for the life cycle and instantiation of the local
@@ -52,10 +65,21 @@ export class MLEngineChild extends JSWindowActorChild {
async receiveMessage({ name, data }) {
switch (name) {
case "MLEngine:NewPort": {
- const { engineName, port, timeoutMS } = data;
+ const { port, pipelineOptions } = data;
+
+ // Override some options using prefs
+ let options = new lazy.PipelineOptions(pipelineOptions);
+
+ options.updateOptions({
+ modelHubRootUrl: lazy.MODEL_HUB_ROOT_URL,
+ modelHubUrlTemplate: lazy.MODEL_HUB_URL_TEMPLATE,
+ timeoutMS: lazy.CACHE_TIMEOUT_MS,
+ logLevel: lazy.LOG_LEVEL,
+ });
+
this.#engineDispatchers.set(
- engineName,
- new EngineDispatcher(this, port, engineName, timeoutMS)
+ options.taskName,
+ new EngineDispatcher(this, port, options)
);
break;
}
@@ -78,13 +102,24 @@ export class MLEngineChild extends JSWindowActorChild {
}
/**
- * @returns {ArrayBuffer}
+ * Gets the wasm array buffer from RemoteSettings.
+ *
+ * @returns {Promise<ArrayBuffer>}
*/
getWasmArrayBuffer() {
return this.sendQuery("MLEngine:GetWasmArrayBuffer");
}
/**
+ * Gets the inference options from RemoteSettings.
+ *
+ * @returns {Promise<object>}
+ */
+ getInferenceOptions(taskName) {
+ return this.sendQuery(`MLEngine:GetInferenceOptions:${taskName}`);
+ }
+
+ /**
* @param {string} engineName
*/
removeEngine(engineName) {
@@ -113,28 +148,45 @@ class EngineDispatcher {
#engine = null;
/** @type {string} */
- #engineName;
+ #taskName;
+
+ /** Creates the inference engine given the wasm runtime and the run options.
+ *
+ * The initialization is done in three steps:
+ * 1. The wasm runtime is fetched from RS
+ * 2. The inference options are fetched from RS and augmented with the pipeline options.
+ * 3. The inference engine is created with the wasm runtime and the options.
+ *
+ * Any exception here will be bubbled up for the constructor to log.
+ *
+ * @param {PipelineOptions} pipelineOptions
+ * @returns {Promise<Engine>}
+ */
+ async initializeInferenceEngine(pipelineOptions) {
+ // Create the inference engine given the wasm runtime and the options.
+ const wasm = await this.mlEngineChild.getWasmArrayBuffer();
+ const inferenceOptions = await this.mlEngineChild.getInferenceOptions(
+ this.#taskName
+ );
+ lazy.console.debug("Inference engine options:", inferenceOptions);
+ pipelineOptions.updateOptions(inferenceOptions);
+
+ return InferenceEngine.create(wasm, pipelineOptions);
+ }
/**
* @param {MLEngineChild} mlEngineChild
* @param {MessagePort} port
- * @param {string} engineName
- * @param {number} timeoutMS
+ * @param {PipelineOptions} pipelineOptions
*/
- constructor(mlEngineChild, port, engineName, timeoutMS) {
- /** @type {MLEngineChild} */
+ constructor(mlEngineChild, port, pipelineOptions) {
this.mlEngineChild = mlEngineChild;
+ this.#taskName = pipelineOptions.taskName;
+ this.timeoutMS = pipelineOptions.timeoutMS;
- /** @type {number} */
- this.timeoutMS = timeoutMS;
-
- this.#engineName = engineName;
-
- this.#engine = Promise.all([
- this.mlEngineChild.getWasmArrayBuffer(),
- this.getModel(port),
- ]).then(([wasm, model]) => FakeEngine.create(wasm, model));
+ this.#engine = this.initializeInferenceEngine(pipelineOptions);
+ // Trigger the keep alive timer.
this.#engine
.then(() => void this.keepAlive())
.catch(error => {
@@ -265,20 +317,61 @@ class EngineDispatcher {
port.close();
}
this.#ports = new Set();
- this.mlEngineChild.removeEngine(this.#engineName);
+ this.mlEngineChild.removeEngine(this.#taskName);
try {
const engine = await this.#engine;
engine.terminate();
} catch (error) {
- console.error("Failed to get the engine", error);
+ lazy.console.error("Failed to get the engine", error);
}
}
}
+let modelHub = null; // This will hold the ModelHub instance to reuse it.
+
+/**
+ * Retrieves a model file as an ArrayBuffer from the specified URL.
+ * This function normalizes the URL, extracts the organization, model name, and file path,
+ * then fetches the model file using the ModelHub API. The `modelHub` instance is created
+ * only once and reused for subsequent calls to optimize performance.
+ *
+ * @param {string} url - The URL of the model file to fetch. Can be a path relative to
+ * the model hub root or an absolute URL.
+ * @returns {Promise} A promise that resolves to a Meta object containing the URL, response headers,
+ * and data as an ArrayBuffer. The data is marked for transfer to avoid cloning.
+ */
+async function getModelFile(url) {
+ // Create the model hub instance if needed
+ if (!modelHub) {
+ lazy.console.debug("Creating model hub instance");
+ modelHub = new lazy.ModelHub({
+ rootUrl: lazy.MODEL_HUB_ROOT_URL,
+ urlTemplate: lazy.MODEL_HUB_URL_TEMPLATE,
+ });
+ }
+
+ if (url.startsWith(lazy.MODEL_HUB_ROOT_URL)) {
+ url = url.slice(lazy.MODEL_HUB_ROOT_URL.length);
+ // Make sure we get a front slash
+ if (!url.startsWith("/")) {
+ url = `/${url}`;
+ }
+ }
+
+ // Parsing url to get model name, and file path.
+ // if this errors out, it will be caught in the worker
+ const parsedUrl = modelHub.parseUrl(url);
+
+ let [data, headers] = await modelHub.getModelFileAsArrayBuffer(parsedUrl);
+ return new lazy.BasePromiseWorker.Meta([url, headers, data], {
+ transfers: [data],
+ });
+}
+
/**
- * Fake the engine by slicing the text in half.
+ * Wrapper around the ChromeWorker that runs the inference.
*/
-class FakeEngine {
+class InferenceEngine {
/** @type {BasePromiseWorker} */
#worker;
@@ -286,21 +379,22 @@ class FakeEngine {
* Initialize the worker.
*
* @param {ArrayBuffer} wasm
- * @param {ArrayBuffer} model
- * @returns {FakeEngine}
+ * @param {PipelineOptions} pipelineOptions
+ * @returns {InferenceEngine}
*/
- static async create(wasm, model) {
+ static async create(wasm, pipelineOptions) {
/** @type {BasePromiseWorker} */
const worker = new lazy.BasePromiseWorker(
"chrome://global/content/ml/MLEngine.worker.mjs",
- { type: "module" }
+ { type: "module" },
+ { getModelFile }
);
- const args = [wasm, model, lazy.loggingLevel];
+ const args = [wasm, pipelineOptions];
const closure = {};
- const transferables = [wasm, model];
+ const transferables = [wasm];
await worker.post("initializeEngine", args, closure, transferables);
- return new FakeEngine(worker);
+ return new InferenceEngine(worker);
}
/**
diff --git a/toolkit/components/ml/actors/MLEngineParent.sys.mjs b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
index 05203e5f69..65941d9d4e 100644
--- a/toolkit/components/ml/actors/MLEngineParent.sys.mjs
+++ b/toolkit/components/ml/actors/MLEngineParent.sys.mjs
@@ -5,6 +5,7 @@
/**
* @typedef {object} Lazy
* @property {typeof console} console
+ * @property {typeof import("../content/Utils.sys.mjs").getRuntimeWasmFilename} getRuntimeWasmFilename
* @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess
* @property {typeof import("../../../../services/settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings
* @property {typeof import("../../translations/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent
@@ -21,16 +22,14 @@ ChromeUtils.defineLazyGetter(lazy, "console", () => {
});
ChromeUtils.defineESModuleGetters(lazy, {
+ getRuntimeWasmFilename: "chrome://global/content/ml/Utils.sys.mjs",
EngineProcess: "chrome://global/content/ml/EngineProcess.sys.mjs",
RemoteSettings: "resource://services-settings/remote-settings.sys.mjs",
TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs",
});
-/**
- * @typedef {import("../../translations/translations").WasmRecord} WasmRecord
- */
-
-const DEFAULT_CACHE_TIMEOUT_MS = 15_000;
+const RS_RUNTIME_COLLECTION = "ml-onnx-runtime";
+const RS_INFERENCE_OPTIONS_COLLECTION = "ml-inference-options";
/**
* The ML engine is in its own content process. This actor handles the
@@ -40,9 +39,9 @@ export class MLEngineParent extends JSWindowActorParent {
/**
* The RemoteSettingsClient that downloads the wasm binaries.
*
- * @type {RemoteSettingsClient | null}
+ * @type {Record<string, RemoteSettingsClient>}
*/
- static #remoteClient = null;
+ static #remoteClients = {};
/** @type {Promise<WasmRecord> | null} */
static #wasmRecord = null;
@@ -61,11 +60,11 @@ export class MLEngineParent extends JSWindowActorParent {
/**
* Remote settings isn't available in tests, so provide mocked responses.
*
- * @param {RemoteSettingsClient} remoteClient
+ * @param {RemoteSettingsClient} remoteClients
*/
- static mockRemoteSettings(remoteClient) {
+ static mockRemoteSettings(remoteClients) {
lazy.console.log("Mocking remote settings in MLEngineParent.");
- MLEngineParent.#remoteClient = remoteClient;
+ MLEngineParent.#remoteClients = remoteClients;
MLEngineParent.#wasmRecord = null;
}
@@ -74,24 +73,49 @@ export class MLEngineParent extends JSWindowActorParent {
*/
static removeMocks() {
lazy.console.log("Removing mocked remote client in MLEngineParent.");
- MLEngineParent.#remoteClient = null;
+ MLEngineParent.#remoteClients = {};
MLEngineParent.#wasmRecord = null;
}
- /**
- * @param {string} engineName
- * @param {() => Promise<ArrayBuffer>} getModel
- * @param {number} cacheTimeoutMS - How long the engine cache remains alive between
- * uses, in milliseconds. In automation the engine is manually created and destroyed
- * to avoid timing issues.
+ /** Creates a new MLEngine.
+ *
+ * @param {PipelineOptions} pipelineOptions
* @returns {MLEngine}
*/
- getEngine(engineName, getModel, cacheTimeoutMS = DEFAULT_CACHE_TIMEOUT_MS) {
- return new MLEngine(this, engineName, getModel, cacheTimeoutMS);
+ getEngine(pipelineOptions) {
+ return new MLEngine({ mlEngineParent: this, pipelineOptions });
+ }
+
+ /** Extracts the task name from the name and validates it.
+ *
+ * Throws an exception if the task name is invalid.
+ *
+ * @param {string} name
+ * @returns {string}
+ */
+ nameToTaskName(name) {
+ // Extract taskName after the specific prefix
+ const taskName = name.split("MLEngine:GetInferenceOptions:")[1];
+
+ // Define a regular expression to verify taskName pattern (alphanumeric and underscores/dashes)
+ const validTaskNamePattern = /^[a-zA-Z0-9_\-]+$/;
+
+ // Check if taskName matches the pattern
+ if (!validTaskNamePattern.test(taskName)) {
+ // Handle invalid taskName, e.g., throw an error or return null
+ throw new Error(
+ "Invalid task name. Task name should contain only alphanumeric characters and underscores/dashes."
+ );
+ }
+ return taskName;
}
// eslint-disable-next-line consistent-return
async receiveMessage({ name }) {
+ if (name.startsWith("MLEngine:GetInferenceOptions")) {
+ return MLEngineParent.getInferenceOptions(this.nameToTaskName(name));
+ }
+
switch (name) {
case "MLEngine:Ready":
if (lazy.EngineProcess.resolveMLEngineParent) {
@@ -112,19 +136,18 @@ export class MLEngineParent extends JSWindowActorParent {
}
}
- /**
+ /** Gets the wasm file from remote settings.
+ *
* @param {RemoteSettingsClient} client
*/
static async #getWasmArrayRecord(client) {
- // Load the wasm binary from remote settings, if it hasn't been already.
- lazy.console.log(`Getting remote wasm records.`);
+ const wasmFilename = lazy.getRuntimeWasmFilename(this.browsingContext);
/** @type {WasmRecord[]} */
const wasmRecords = await lazy.TranslationsParent.getMaxVersionRecords(
client,
{
- // TODO - This record needs to be created with the engine wasm payload.
- filters: { name: "inference-engine" },
+ filters: { name: wasmFilename },
majorVersion: MLEngineParent.WASM_MAJOR_VERSION,
}
);
@@ -142,20 +165,47 @@ export class MLEngineParent extends JSWindowActorParent {
);
}
const [record] = wasmRecords;
- lazy.console.log(
- `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`,
- record
- );
+ lazy.console.log(`Using runtime ${record.name}@${record.version}`, record);
return record;
}
+ /** Gets the inference options from remote settings given a task name.
+ *
+ * @type {string} taskName - name of the inference :wtask
+ * @returns {Promise<ModelRevisionRecord>}
+ */
+ static async getInferenceOptions(taskName) {
+ const client = MLEngineParent.#getRemoteClient(
+ RS_INFERENCE_OPTIONS_COLLECTION
+ );
+ const records = await client.get({
+ filters: {
+ taskName,
+ },
+ });
+
+ if (records.length === 0) {
+ throw new Error(`No inference options found for task ${taskName}`);
+ }
+ const options = records[0];
+ return {
+ modelRevision: options.modelRevision,
+ modelId: options.modelId,
+ tokenizerRevision: options.tokenizerRevision,
+ tokenizerId: options.tokenizerId,
+ processorRevision: options.processorRevision,
+ processorId: options.processorId,
+ runtimeFilename: lazy.getRuntimeWasmFilename(this.browsingContext),
+ };
+ }
+
/**
* Download the wasm for the ML inference engine.
*
* @returns {Promise<ArrayBuffer>}
*/
static async getWasmArrayBuffer() {
- const client = MLEngineParent.#getRemoteClient();
+ const client = MLEngineParent.#getRemoteClient(RS_RUNTIME_COLLECTION);
if (!MLEngineParent.#wasmRecord) {
// Place the records into a promise to prevent any races.
@@ -184,20 +234,23 @@ export class MLEngineParent extends JSWindowActorParent {
/**
* Lazily initializes the RemoteSettingsClient for the downloaded wasm binary data.
*
+ * @param {string} collectionName - The name of the collection to use.
* @returns {RemoteSettingsClient}
*/
- static #getRemoteClient() {
- if (MLEngineParent.#remoteClient) {
- return MLEngineParent.#remoteClient;
+ static #getRemoteClient(collectionName) {
+ if (MLEngineParent.#remoteClients[collectionName]) {
+ return MLEngineParent.#remoteClients[collectionName];
}
/** @type {RemoteSettingsClient} */
- const client = lazy.RemoteSettings("ml-wasm");
+ const client = lazy.RemoteSettings(collectionName, {
+ bucketName: "main",
+ });
- MLEngineParent.#remoteClient = client;
+ MLEngineParent.#remoteClients[collectionName] = client;
client.on("sync", async ({ data: { created, updated, deleted } }) => {
- lazy.console.log(`"sync" event for ml-wasm`, {
+ lazy.console.log(`"sync" event for ${collectionName}`, {
created,
updated,
deleted,
@@ -229,17 +282,6 @@ export class MLEngineParent extends JSWindowActorParent {
}
/**
- * This contains all of the information needed to perform a translation request.
- *
- * @typedef {object} TranslationRequest
- * @property {Node} node
- * @property {string} sourceText
- * @property {boolean} isHTML
- * @property {Function} resolve
- * @property {Function} reject
- */
-
-/**
* The interface to communicate to an MLEngine in the parent process. The engine manages
* its own lifetime, and is kept alive with a timeout. A reference to this engine can
* be retained, but once idle, the engine will be destroyed. If a new request to run
@@ -271,21 +313,13 @@ class MLEngine {
engineStatus = "uninitialized";
/**
- * @param {MLEngineParent} mlEngineParent
- * @param {string} engineName
- * @param {() => Promise<ArrayBuffer>} getModel
- * @param {number} timeoutMS
+ * @param {object} config - The configuration object for the instance.
+ * @param {object} config.mlEngineParent - The parent machine learning engine associated with this instance.
+ * @param {object} config.pipelineOptions - The options for configuring the pipeline associated with this instance.
*/
- constructor(mlEngineParent, engineName, getModel, timeoutMS) {
- /** @type {MLEngineParent} */
+ constructor({ mlEngineParent, pipelineOptions }) {
this.mlEngineParent = mlEngineParent;
- /** @type {string} */
- this.engineName = engineName;
- /** @type {() => Promise<ArrayBuffer>} */
- this.getModel = getModel;
- /** @type {number} */
- this.timeoutMS = timeoutMS;
-
+ this.pipelineOptions = pipelineOptions;
this.#setupPortCommunication();
}
@@ -296,14 +330,12 @@ class MLEngine {
const { port1: childPort, port2: parentPort } = new MessageChannel();
const transferables = [childPort];
this.#port = parentPort;
-
this.#port.onmessage = this.handlePortMessage;
this.mlEngineParent.sendAsyncMessage(
"MLEngine:NewPort",
{
port: childPort,
- engineName: this.engineName,
- timeoutMS: this.timeoutMS,
+ pipelineOptions: this.pipelineOptions.getOptions(),
},
transferables
);
@@ -393,11 +425,20 @@ class MLEngine {
const resolvers = Promise.withResolvers();
const requestId = this.#nextRequestId++;
this.#requests.set(requestId, resolvers);
- this.#port.postMessage({
- type: "EnginePort:Run",
- requestId,
- request,
- });
+
+ let transferables = [];
+ if (request.data instanceof ArrayBuffer) {
+ transferables.push(request.data);
+ }
+
+ this.#port.postMessage(
+ {
+ type: "EnginePort:Run",
+ requestId,
+ request,
+ },
+ transferables
+ );
return resolvers.promise;
}
}