summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/actors/MLEngineChild.sys.mjs
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/ml/actors/MLEngineChild.sys.mjs')
-rw-r--r--toolkit/components/ml/actors/MLEngineChild.sys.mjs156
1 files changed, 125 insertions, 31 deletions
diff --git a/toolkit/components/ml/actors/MLEngineChild.sys.mjs b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
index 925ce59266..17a8b3511a 100644
--- a/toolkit/components/ml/actors/MLEngineChild.sys.mjs
+++ b/toolkit/components/ml/actors/MLEngineChild.sys.mjs
@@ -21,6 +21,8 @@ ChromeUtils.defineESModuleGetters(lazy, {
BasePromiseWorker: "resource://gre/modules/PromiseWorker.sys.mjs",
setTimeout: "resource://gre/modules/Timer.sys.mjs",
clearTimeout: "resource://gre/modules/Timer.sys.mjs",
+ ModelHub: "chrome://global/content/ml/ModelHub.sys.mjs",
+ PipelineOptions: "chrome://global/content/ml/EngineProcess.sys.mjs",
});
ChromeUtils.defineLazyGetter(lazy, "console", () => {
@@ -32,9 +34,20 @@ ChromeUtils.defineLazyGetter(lazy, "console", () => {
XPCOMUtils.defineLazyPreferenceGetter(
lazy,
- "loggingLevel",
- "browser.ml.logLevel"
+ "CACHE_TIMEOUT_MS",
+ "browser.ml.modelCacheTimeout"
);
+XPCOMUtils.defineLazyPreferenceGetter(
+ lazy,
+ "MODEL_HUB_ROOT_URL",
+ "browser.ml.modelHubRootUrl"
+);
+XPCOMUtils.defineLazyPreferenceGetter(
+ lazy,
+ "MODEL_HUB_URL_TEMPLATE",
+ "browser.ml.modelHubUrlTemplate"
+);
+XPCOMUtils.defineLazyPreferenceGetter(lazy, "LOG_LEVEL", "browser.ml.logLevel");
/**
* The engine child is responsible for the life cycle and instantiation of the local
@@ -52,10 +65,21 @@ export class MLEngineChild extends JSWindowActorChild {
async receiveMessage({ name, data }) {
switch (name) {
case "MLEngine:NewPort": {
- const { engineName, port, timeoutMS } = data;
+ const { port, pipelineOptions } = data;
+
+ // Override some options using prefs
+ let options = new lazy.PipelineOptions(pipelineOptions);
+
+ options.updateOptions({
+ modelHubRootUrl: lazy.MODEL_HUB_ROOT_URL,
+ modelHubUrlTemplate: lazy.MODEL_HUB_URL_TEMPLATE,
+ timeoutMS: lazy.CACHE_TIMEOUT_MS,
+ logLevel: lazy.LOG_LEVEL,
+ });
+
this.#engineDispatchers.set(
- engineName,
- new EngineDispatcher(this, port, engineName, timeoutMS)
+ options.taskName,
+ new EngineDispatcher(this, port, options)
);
break;
}
@@ -78,13 +102,24 @@ export class MLEngineChild extends JSWindowActorChild {
}
/**
- * @returns {ArrayBuffer}
+ * Gets the wasm array buffer from RemoteSettings.
+ *
+ * The request is issued through `sendQuery` with the
+ * "MLEngine:GetWasmArrayBuffer" message name; the value returned by
+ * `sendQuery` is handed back to the caller unchanged.
+ *
+ * @returns {Promise<ArrayBuffer>}
*/
getWasmArrayBuffer() {
return this.sendQuery("MLEngine:GetWasmArrayBuffer");
}
/**
+ * Gets the inference options from RemoteSettings.
+ *
+ * The task name is embedded in the query name itself: the query sent is
+ * `MLEngine:GetInferenceOptions:<taskName>`.
+ *
+ * @param {string} taskName - Name of the task whose inference options are requested.
+ * @returns {Promise<object>}
+ */
+ getInferenceOptions(taskName) {
+ return this.sendQuery(`MLEngine:GetInferenceOptions:${taskName}`);
+ }
+
+ /**
* @param {string} engineName
*/
removeEngine(engineName) {
@@ -113,28 +148,45 @@ class EngineDispatcher {
#engine = null;
/** @type {string} */
- #engineName;
+ #taskName;
+
+ /**
+ * Creates the inference engine given the wasm runtime and the run options.
+ *
+ * The initialization is done in three steps:
+ * 1. The wasm runtime is fetched from RemoteSettings (RS).
+ * 2. The inference options for this dispatcher's task are fetched from RS and
+ * applied to the pipeline options through `updateOptions`.
+ * 3. The inference engine is created with the wasm runtime and the options.
+ *
+ * Steps 1 and 2 are awaited one after the other, not concurrently.
+ *
+ * Any exception here will be bubbled up for the constructor to log.
+ *
+ * @param {PipelineOptions} pipelineOptions - Options for the pipeline;
+ * `updateOptions` is called on this object with the fetched inference options.
+ * NOTE(review): presumably this mutates the caller's object in place - confirm
+ * against PipelineOptions.
+ * @returns {Promise<InferenceEngine>}
+ */
+ async initializeInferenceEngine(pipelineOptions) {
+ // Fetch the wasm runtime, then the inference options for this task.
+ const wasm = await this.mlEngineChild.getWasmArrayBuffer();
+ const inferenceOptions = await this.mlEngineChild.getInferenceOptions(
+ this.#taskName
+ );
+ lazy.console.debug("Inference engine options:", inferenceOptions);
+ pipelineOptions.updateOptions(inferenceOptions);
+
+ return InferenceEngine.create(wasm, pipelineOptions);
+ }
/**
* @param {MLEngineChild} mlEngineChild
* @param {MessagePort} port
- * @param {string} engineName
- * @param {number} timeoutMS
+ * @param {PipelineOptions} pipelineOptions
*/
- constructor(mlEngineChild, port, engineName, timeoutMS) {
- /** @type {MLEngineChild} */
+ constructor(mlEngineChild, port, pipelineOptions) {
this.mlEngineChild = mlEngineChild;
+ this.#taskName = pipelineOptions.taskName;
+ this.timeoutMS = pipelineOptions.timeoutMS;
- /** @type {number} */
- this.timeoutMS = timeoutMS;
-
- this.#engineName = engineName;
-
- this.#engine = Promise.all([
- this.mlEngineChild.getWasmArrayBuffer(),
- this.getModel(port),
- ]).then(([wasm, model]) => FakeEngine.create(wasm, model));
+ this.#engine = this.initializeInferenceEngine(pipelineOptions);
+ // Trigger the keep alive timer.
this.#engine
.then(() => void this.keepAlive())
.catch(error => {
@@ -265,20 +317,61 @@ class EngineDispatcher {
port.close();
}
this.#ports = new Set();
- this.mlEngineChild.removeEngine(this.#engineName);
+ this.mlEngineChild.removeEngine(this.#taskName);
try {
const engine = await this.#engine;
engine.terminate();
} catch (error) {
- console.error("Failed to get the engine", error);
+ lazy.console.error("Failed to get the engine", error);
}
}
}
+let modelHub = null; // This will hold the ModelHub instance to reuse it.
+
+/**
+ * Retrieves a model file as an ArrayBuffer from the specified URL.
+ *
+ * If the URL starts with the model hub root URL (the `browser.ml.modelHubRootUrl`
+ * pref), that prefix is removed and a leading slash is added when missing. The
+ * result is then parsed with `modelHub.parseUrl` and the file is fetched with
+ * `modelHub.getModelFileAsArrayBuffer`. The `modelHub` instance is created only
+ * once and reused for subsequent calls.
+ *
+ * NOTE(review): the instance keeps the root URL and URL template it was created
+ * with, while the prefix check reads the pref value on every call - confirm this
+ * is intended if the pref can change at runtime.
+ *
+ * @param {string} url - The URL of the model file to fetch. Can be a path relative to
+ * the model hub root or an absolute URL.
+ * @returns {Promise<BasePromiseWorker.Meta>} A promise that resolves to a Meta object
+ * wrapping `[url, headers, data]`, where `url` is the URL after the root prefix has
+ * been stripped. `data` is listed in `transfers` so it is transferred rather than cloned.
+ */
+async function getModelFile(url) {
+ // Create the model hub instance if needed
+ if (!modelHub) {
+ lazy.console.debug("Creating model hub instance");
+ modelHub = new lazy.ModelHub({
+ rootUrl: lazy.MODEL_HUB_ROOT_URL,
+ urlTemplate: lazy.MODEL_HUB_URL_TEMPLATE,
+ });
+ }
+
+ if (url.startsWith(lazy.MODEL_HUB_ROOT_URL)) {
+ url = url.slice(lazy.MODEL_HUB_ROOT_URL.length);
+ // Make sure the remaining path starts with a leading slash
+ if (!url.startsWith("/")) {
+ url = `/${url}`;
+ }
+ }
+
+ // Parse the url to get the model name and file path.
+ // If this errors out, it will be caught in the worker.
+ const parsedUrl = modelHub.parseUrl(url);
+
+ let [data, headers] = await modelHub.getModelFileAsArrayBuffer(parsedUrl);
+ return new lazy.BasePromiseWorker.Meta([url, headers, data], {
+ transfers: [data],
+ });
+}
+
/**
- * Fake the engine by slicing the text in half.
+ * Wrapper around the ChromeWorker that runs the inference.
*/
-class FakeEngine {
+class InferenceEngine {
/** @type {BasePromiseWorker} */
#worker;
@@ -286,21 +379,22 @@ class FakeEngine {
* Initialize the worker.
*
* @param {ArrayBuffer} wasm
- * @param {ArrayBuffer} model
- * @returns {FakeEngine}
+ * @param {PipelineOptions} pipelineOptions
+ * @returns {Promise<InferenceEngine>}
*/
- static async create(wasm, model) {
+ static async create(wasm, pipelineOptions) {
/** @type {BasePromiseWorker} */
const worker = new lazy.BasePromiseWorker(
"chrome://global/content/ml/MLEngine.worker.mjs",
- { type: "module" }
+ { type: "module" },
+ { getModelFile }
);
- const args = [wasm, model, lazy.loggingLevel];
+ const args = [wasm, pipelineOptions];
const closure = {};
- const transferables = [wasm, model];
+ const transferables = [wasm];
await worker.post("initializeEngine", args, closure, transferables);
- return new FakeEngine(worker);
+ return new InferenceEngine(worker);
}
/**