Diffstat (limited to 'toolkit/components/ml/content/ONNXPipeline.mjs')
-rw-r--r--  toolkit/components/ml/content/ONNXPipeline.mjs  |  297
1 file changed, 297 insertions(+), 0 deletions(-)
diff --git a/toolkit/components/ml/content/ONNXPipeline.mjs b/toolkit/components/ml/content/ONNXPipeline.mjs
new file mode 100644
index 0000000000..fcc1a0eb77
--- /dev/null
+++ b/toolkit/components/ml/content/ONNXPipeline.mjs
@@ -0,0 +1,297 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+// This import does not use ChromeUtils because the next version of the library
+// will require an async import, which is not supported by importESModule,
+// so we'll just add an await here.
+import {
+ env,
+ RawImage,
+ AutoProcessor,
+ AutoTokenizer,
+ AutoModelForVision2Seq,
+} from "chrome://global/content/ml/transformers-dev.js";
+
+/**
+ * Lazy initialization container.
+ *
+ * @type {object}
+ */
+const lazy = {};
+
+ChromeUtils.defineESModuleGetters(
+ lazy,
+ {
+ arrayBufferToBlobURL: "chrome://global/content/ml/Utils.sys.mjs",
+ },
+ { global: "current" }
+);
+
+// Using a custom console, see https://bugzilla.mozilla.org/show_bug.cgi?id=1891789
+let _logLevel = "Error";
+
+function debug(...args) {
+ if (["Debug", "Trace", "All"].includes(_logLevel)) {
+ console.log("ML:", ...args); // eslint-disable-line no-console
+ }
+}
+
+/**
+ * Echo inference for testing purposes. Returns the request data unchanged.
+ *
+ * @async
+ * @param {object} request - The request object; its `data` field is echoed back.
+ * @param {object} _model - Unused; present to match the pipeline function signature.
+ * @param {object} _tokenizer - Unused.
+ * @param {object} _processor - Unused.
+ * @returns {Promise<object>} The result object containing the echoed data.
+ */
+async function echo(request, _model, _tokenizer, _processor) {
+ return {
+ metrics: {
+ tokenizingTime: 0,
+ },
+ output: request.data,
+ };
+}
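+
+// A minimal usage sketch for the echo task (values are illustrative): it
+// mirrors `request.data` back, which is handy for smoke-testing the pipeline
+// plumbing without loading a model.
+//
+//   const result = await echo({ data: "ping" });
+//   // result.output === "ping", result.metrics.tokenizingTime === 0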
+
+/**
+ * Converts an image to text using a machine learning model.
+ *
+ * @async
+ * @param {object} request - The request object containing image data.
+ * @param {string} [request.imageUrl] - The URL of the image to process. Either `imageUrl` or `data` must be provided, but not both.
+ * @param {ArrayBuffer} [request.data] - The raw image data to process. Either `data` or `imageUrl` must be provided, but not both.
+ * @param {string} request.mimeType - The MIME type of the image data.
+ * @param {object} model - The model used for inference.
+ * @param {object} tokenizer - The tokenizer used for decoding.
+ * @param {object} processor - The processor used for preparing image data.
+ * @returns {Promise<object>} The result object containing the processed text.
+ */
+async function imageToText(request, model, tokenizer, processor) {
+ let result = {
+ metrics: {
+ inferenceTime: 0,
+ tokenizingTime: 0,
+ },
+ };
+ let start = Date.now();
+ let rawImage;
+
+ if ("imageUrl" in request) {
+ rawImage = await RawImage.fromUrl(request.imageUrl);
+ } else {
+ const blob = new Blob([request.data], { type: request.mimeType });
+ rawImage = await RawImage.fromBlob(blob);
+ }
+
+ debug("Image loaded in ", Date.now() - start);
+
+ const { pixel_values } = await processor(rawImage);
+ result.metrics.tokenizingTime += Date.now() - start;
+ const toReturn = [];
+ for (const batch of pixel_values) {
+ batch.dims = [1, ...batch.dims];
+ start = Date.now();
+ const output = await model.generate(batch);
+ result.metrics.inferenceTime += Date.now() - start;
+ start = Date.now();
+ const decoded = tokenizer
+ .batch_decode(output, {
+ skip_special_tokens: true,
+ })
+ .map(x => ({ generated_text: x.trim() }));
+ result.metrics.tokenizingTime += Date.now() - start;
+ toReturn.push(decoded);
+ }
+ debug("Inference done in ", Date.now() - start);
+ result.output = toReturn[0][0].generated_text;
+ return result;
+}
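+
+// A hedged usage sketch for imageToText (the instance names are assumptions;
+// they normally come from the Pipeline class below). The request shape follows
+// the JSDoc above, with either `imageUrl` or raw `data` plus a `mimeType`.
+//
+//   const result = await imageToText(
+//     { data: imageArrayBuffer, mimeType: "image/jpeg" },
+//     model,
+//     tokenizer,
+//     processor
+//   );
+//   // result.output holds the caption; result.metrics holds the timings.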
+
+/**
+ * Configuration for the engine. Each task has a configuration object that
+ * gets merged at runtime with the options from PipelineOptions.
+ *
+ * When a key exists in both the default configuration and the options,
+ * the value from the options is used.
+ *
+ * The configuration keys that are not exposed as options are all the
+ * callables that are used in the pipeline:
+ *
+ * - modelClass
+ * - tokenizerClass
+ * - processorClass
+ * - pipelineFunction
+ *
+ * @type {object}
+ */
+const ENGINE_CONFIGURATION = {
+ "image-to-text": {
+ modelId: "mozilla/distilvit",
+ modelClass: AutoModelForVision2Seq,
+ tokenizerId: "mozilla/distilvit",
+ tokenizerClass: AutoTokenizer,
+ processorId: "mozilla/distilvit",
+ processorClass: AutoProcessor,
+ pipelineFunction: imageToText,
+ },
+ echo: {
+ modelId: null,
+ modelClass: null,
+ tokenizerId: null,
+ tokenizerClass: null,
+ processorId: null,
+ processorClass: null,
+ pipelineFunction: echo,
+ },
+};
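+
+// A minimal sketch of the merge described above (assuming PipelineOptions
+// exposes the per-task keys such as modelId and taskName). Option values win
+// over defaults; the callables are not options and are never overridden.
+//
+//   const config = { ...ENGINE_CONFIGURATION["image-to-text"] };
+//   options.applyToConfig(config);
+//   // config.modelId may now differ from "mozilla/distilvit",
+//   // but config.pipelineFunction is still imageToText.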
+
+/**
+ * Represents a pipeline for processing machine learning tasks.
+ */
+export class Pipeline {
+ #modelCache = null;
+ #model = null;
+ #tokenizer = null;
+ #processor = null;
+ #pipelineFunction = null;
+ #taskName = null;
+ #initTime = 0;
+ #isReady = false;
+
+ /**
+ * Creates an instance of a Pipeline.
+ *
+ * @param {object} modelCache - Implements the Cache interface; used to fetch model files.
+ * @param {object} config - The configuration options
+ */
+ constructor(modelCache, config) {
+ let start = Date.now();
+ this.#modelCache = modelCache;
+
+ _logLevel = config.logLevel || "Error";
+ // Setting up the Transformers.js environment
+ // See https://huggingface.co/docs/transformers.js/api/env
+
+ // Caching strategy.
+ // Here we make sure that every time transformers.js requires a file, it uses
+ // modelCache, which transfers the request to the main thread and uses the
+ // ModelHub that caches files into IndexedDB.
+ env.useBrowserCache = false;
+ env.allowLocalModels = false;
+ env.remoteHost = config.modelHubRootUrl;
+ env.remotePathTemplate = config.modelHubUrlTemplate;
+ env.useCustomCache = true;
+ env.customCache = this.#modelCache;
+ env.localModelPath = "/";
+
+ // ONNX runtime - we register the wasm runtime we got from Remote Settings (RS)
+ // so the ONNX backend can pick it up.
+ debug("Setting up ONNX backend");
+ env.backends.onnx.wasm.wasmPaths = {};
+ env.backends.onnx.wasm.wasmPaths[config.runtimeFilename] =
+ lazy.arrayBufferToBlobURL(config.runtime);
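+ // After this, wasmPaths maps the runtime file name to a Blob URL, e.g.
+ // { "ort-wasm-simd.wasm": "blob:..." } (the file name here is illustrative).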
+
+ if (config.modelClass && config.modelId) {
+ debug(`Loading model ${config.modelId} with class ${config.modelClass}`);
+ this.#model = config.modelClass.from_pretrained(config.modelId);
+ }
+ if (config.tokenizerClass && config.tokenizerId) {
+ debug(
+ `Loading tokenizer ${config.tokenizerId} with class ${config.tokenizerClass}`
+ );
+ this.#tokenizer = config.tokenizerClass.from_pretrained(
+ config.tokenizerId
+ );
+ }
+ if (config.processorClass && config.processorId) {
+ debug(
+ `Loading processor ${config.processorId} with class ${config.processorClass}`
+ );
+ this.#processor = config.processorClass.from_pretrained(
+ config.processorId
+ );
+ }
+ this.#taskName = config.taskName;
+ this.#pipelineFunction = config.pipelineFunction.bind(this);
+ this.#initTime = Date.now() - start;
+ debug("Pipeline initialized, took ", this.#initTime);
+ }
+
+ /**
+ * Initializes the pipeline with given options.
+ *
+ * @static
+ * @async
+ * @param {object} modelCache - Implements the Cache interface; used to fetch model files.
+ * @param {ArrayBuffer} runtime - The runtime wasm file.
+ * @param {PipelineOptions} options - The options for initialization.
+ * @returns {Promise<Pipeline>} The initialized pipeline instance.
+ */
+ static async initialize(modelCache, runtime, options) {
+ const taskName = options.taskName;
+ debug(`Initializing Pipeline for task ${taskName}`);
+
+ if (!ENGINE_CONFIGURATION[taskName]) {
+ throw new Error(`Task ${taskName} is not supported`);
+ }
+
+ // Loading the config defaults for the task
+ let config = { ...ENGINE_CONFIGURATION[taskName] };
+ config.runtime = runtime;
+
+ // Overriding the defaults with the options
+ options.applyToConfig(config);
+
+ if (!config.pipelineFunction) {
+ throw new Error("pipelineFunction is required for the pipeline");
+ }
+ return new Pipeline(modelCache, config);
+ }
+
+ /**
+ * Runs the pipeline with the given request.
+ *
+ * @async
+ * @param {object} request - The request object to be processed. The fields it
+ * may contain depend on the task; see each pipeline function for details.
+ * @returns {Promise<object>} The result object from the pipeline execution.
+ */
+ async run(request) {
+ debug("Running task: ", this.#taskName);
+ // Awaiting all promises to ensure they are resolved before running the
+ // pipeline for the first time.
+ if (!this.#isReady) {
+ let start = Date.now();
+ debug("Initializing model, tokenizer and processor");
+
+ // Deactivate console.warn, see https://bugzilla.mozilla.org/show_bug.cgi?id=1891003
+ const originalWarn = console.warn;
+ console.warn = () => {};
+ try {
+ this.#model = await this.#model;
+ this.#tokenizer = await this.#tokenizer;
+ this.#processor = await this.#processor;
+ this.#isReady = true;
+ } catch (error) {
+ debug("Error initializing pipeline", error);
+ throw error;
+ } finally {
+ console.warn = originalWarn;
+ }
+
+ this.#initTime += Date.now() - start;
+ debug("Pipeline is fully initialized, took ", this.#initTime);
+ }
+
+ let result = await this.#pipelineFunction(
+ request,
+ this.#model,
+ this.#tokenizer,
+ this.#processor
+ );
+ result.metrics.initTime = this.#initTime;
+ return result;
+ }
+}
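+
+// An end-to-end usage sketch, assuming the caller has a cache implementing
+// the Cache interface, the ONNX wasm runtime as an ArrayBuffer, and a
+// PipelineOptions instance (all names below are illustrative):
+//
+//   const pipeline = await Pipeline.initialize(modelCache, wasmBuffer, options);
+//   // options.taskName would be "image-to-text" for this request shape.
+//   const result = await pipeline.run({
+//     data: imageArrayBuffer,
+//     mimeType: "image/png",
+//   });
+//   console.log(result.output, result.metrics);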