summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/content/ModelHub.sys.mjs
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 05:35:37 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 05:35:37 +0000
commita90a5cba08fdf6c0ceb95101c275108a152a3aed (patch)
tree532507288f3defd7f4dcf1af49698bcb76034855 /toolkit/components/ml/content/ModelHub.sys.mjs
parentAdding debian version 126.0.1-1. (diff)
downloadfirefox-a90a5cba08fdf6c0ceb95101c275108a152a3aed.tar.xz
firefox-a90a5cba08fdf6c0ceb95101c275108a152a3aed.zip
Merging upstream version 127.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'toolkit/components/ml/content/ModelHub.sys.mjs')
-rw-r--r--toolkit/components/ml/content/ModelHub.sys.mjs317
1 files changed, 177 insertions, 140 deletions
diff --git a/toolkit/components/ml/content/ModelHub.sys.mjs b/toolkit/components/ml/content/ModelHub.sys.mjs
index 4c2181ff14..10f83c3000 100644
--- a/toolkit/components/ml/content/ModelHub.sys.mjs
+++ b/toolkit/components/ml/content/ModelHub.sys.mjs
@@ -24,8 +24,7 @@ const ALLOWED_HUBS = [
];
const ALLOWED_HEADERS_KEYS = ["Content-Type", "ETag", "status"];
-const DEFAULT_URL_TEMPLATE =
- "${organization}/${modelName}/resolve/${modelVersion}/${file}";
+const DEFAULT_URL_TEMPLATE = "{model}/resolve/{revision}";
/**
* Checks if a given URL string corresponds to an allowed hub.
@@ -239,17 +238,36 @@ export class IndexedDBCache {
}
/**
+ * Checks if a specified model file exists in storage.
+ *
+ * @param {string} model - The model name (organization/name)
+ * @param {string} revision - The model revision.
+ * @param {string} file - The file name.
+ * @returns {Promise<boolean>} A promise that resolves with `true` if the key exists, otherwise `false`.
+ */
+ async fileExists(model, revision, file) {
+ const storeName = this.fileStoreName;
+ const cacheKey = `${model}/${revision}/${file}`;
+ return new Promise((resolve, reject) => {
+ const transaction = this.db.transaction([storeName], "readonly");
+ const store = transaction.objectStore(storeName);
+ const request = store.getKey(cacheKey);
+ request.onerror = event => reject(event.target.error);
+ request.onsuccess = event => resolve(event.target.result !== undefined);
+ });
+ }
+
+ /**
* Retrieves the headers for a specific cache entry.
*
- * @param {string} organization - The organization name.
- * @param {string} modelName - The model name.
- * @param {string} modelVersion - The model version.
+ * @param {string} model - The model name (organization/name)
+ * @param {string} revision - The model revision.
* @param {string} file - The file name.
* @returns {Promise<object|null>} The headers or null if not found.
*/
- async getHeaders(organization, modelName, modelVersion, file) {
- const headersKey = `${organization}/${modelName}/${modelVersion}`;
- const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`;
+ async getHeaders(model, revision, file) {
+ const headersKey = `${model}/${revision}`;
+ const cacheKey = `${model}/${revision}/${file}`;
const headers = await this.#getData(this.headersStoreName, headersKey);
if (headers && headers.files[cacheKey]) {
return headers.files[cacheKey];
@@ -260,22 +278,16 @@ export class IndexedDBCache {
/**
* Retrieves the file for a specific cache entry.
*
- * @param {string} organization - The organization name.
- * @param {string} modelName - The model name.
- * @param {string} modelVersion - The model version.
+ * @param {string} model - The model name (organization/name).
+ * @param {string} revision - The model version.
* @param {string} file - The file name.
* @returns {Promise<[ArrayBuffer, object]|null>} The file ArrayBuffer and its headers or null if not found.
*/
- async getFile(organization, modelName, modelVersion, file) {
- const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`;
+ async getFile(model, revision, file) {
+ const cacheKey = `${model}/${revision}/${file}`;
const stored = await this.#getData(this.fileStoreName, cacheKey);
if (stored) {
- const headers = await this.getHeaders(
- organization,
- modelName,
- modelVersion,
- file
- );
+ const headers = await this.getHeaders(model, revision, file);
return [stored.data, headers];
}
return null; // Return null if no file is found
@@ -284,29 +296,21 @@ export class IndexedDBCache {
/**
* Adds or updates a cache entry.
*
- * @param {string} organization - The organization name.
- * @param {string} modelName - The model name.
- * @param {string} modelVersion - The model version.
+ * @param {string} model - The model name (organization/name).
+ * @param {string} revision - The model version.
* @param {string} file - The file name.
* @param {ArrayBuffer} arrayBuffer - The data to cache.
* @param {object} [headers] - The headers for the file.
* @returns {Promise<void>}
*/
- async put(
- organization,
- modelName,
- modelVersion,
- file,
- arrayBuffer,
- headers = {}
- ) {
- const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`;
+ async put(model, revision, file, arrayBuffer, headers = {}) {
+ const cacheKey = `${model}/${revision}/${file}`;
const newSize = this.totalSize + arrayBuffer.byteLength;
if (newSize > this.#maxSize) {
throw new Error("Exceeding total cache size limit of 1GB");
}
- const headersKey = `${organization}/${modelName}/${modelVersion}`;
+ const headersKey = `${model}/${revision}`;
const data = { id: cacheKey, data: arrayBuffer };
// Store the file data
@@ -356,13 +360,12 @@ export class IndexedDBCache {
/**
* Deletes all data related to a specific model.
*
- * @param {string} organization - The organization name.
- * @param {string} modelName - The model name.
- * @param {string} modelVersion - The model version.
+ * @param {string} model - The model name (organization/name).
+ * @param {string} revision - The model version.
* @returns {Promise<void>}
*/
- async deleteModel(organization, modelName, modelVersion) {
- const headersKey = `${organization}/${modelName}/${modelVersion}`;
+ async deleteModel(model, revision) {
+ const headersKey = `${model}/${revision}`;
const headers = await this.#getData(this.headersStoreName, headersKey);
if (headers) {
for (const fileKey in headers.files) {
@@ -409,9 +412,9 @@ export class ModelHub {
this.cache = null;
// Ensures the URL template is well-formed and does not contain any invalid characters.
- const pattern = /^(?:\$\{\w+\}|\w+)(?:\/(?:\$\{\w+\}|\w+))*$/;
+ const pattern = /^(?:\{\w+\}|\w+)(?:\/(?:\{\w+\}|\w+))*$/;
// ^ $ Start and end of string
- // (?:\$\{\w+\}|\w+) Match a ${placeholder} or alphanumeric characters
+ // (?:\{\w+\}|\w+) Match a {placeholder} or alphanumeric characters
// (?:\/(?:\$\{\w+\}|\w+))* Zero or more groups of a forward slash followed by a ${placeholder} or alphanumeric characters
if (!pattern.test(urlTemplate)) {
throw new Error(`Invalid URL template: ${urlTemplate}`);
@@ -426,33 +429,94 @@ export class ModelHub {
this.cache = await IndexedDBCache.init();
}
+ /**
+ * This method takes a model URL and parses it to extract the
+ * model name, optional model version, and file path.
+ *
+ * The expected URL format are :
+ *
+ * `/organization/model/revision/filePath`
+ * `https://hub/organization/model/revision/filePath`
+ *
+ * @param {string} url - The full URL to the model, including protocol and domain - or the relative path.
+ * @returns {object} An object containing the parsed components of the URL. The
+ * object has properties `model`, and `file`,
+ * and optionally `revision` if the URL includes a version.
+ * @throws {Error} Throws an error if the URL does not start with `this.rootUrl` or
+ * if the URL format does not match the expected structure.
+ *
+ * @example
+ * // For a URL
+ * parseModelUrl("https://example.com/org1/model1/v1/file/path");
+ * // returns { model: "org1/model1", revision: "v1", file: "file/path" }
+ *
+ * @example
+ * // For a relative URL
+ * parseModelUrl("/org1/model1/revision/file/path");
+ * // returns { model: "org1/model1", revision: "v1", file: "file/path" }
+ */
+ parseUrl(url) {
+ let parts;
+ if (url.startsWith("/")) {
+ // relative URL
+ parts = url.slice(1).split("/");
+ } else {
+ // absolute URL
+ if (!url.startsWith(this.rootUrl)) {
+ throw new Error(`Invalid domain for model URL: ${url}`);
+ }
+ const urlObject = new URL(url);
+ const rootUrlObject = new URL(this.rootUrl);
+
+ // Remove the root URL's pathname from the full URL's pathname
+ const relativePath = urlObject.pathname.substring(
+ rootUrlObject.pathname.length
+ );
+
+ parts = relativePath.slice(1).split("/");
+ }
+
+ if (parts.length < 3) {
+ throw new Error(`Invalid model URL: ${url}`);
+ }
+
+ const file = parts.slice(3).join("/");
+ if (file == null || !file.length) {
+ throw new Error(`Invalid model URL: ${url}`);
+ }
+
+ return {
+ model: `${parts[0]}/${parts[1]}`,
+ revision: parts[2],
+ file,
+ };
+ }
+
/** Creates the file URL from the organization, model, and version.
*
- * @param {string} organization
- * @param {string} modelName
- * @param {string} modelVersion
+ * @param {string} model
+ * @param {string} revision
* @param {string} file
* @returns {string} The full URL
*/
- #fileUrl(organization, modelName, modelVersion, file) {
+ #fileUrl(model, revision, file) {
const baseUrl = new URL(this.rootUrl);
if (!baseUrl.pathname.endsWith("/")) {
baseUrl.pathname += "/";
}
-
// Replace placeholders in the URL template with the provided data.
// If some keys are missing in the data object, the placeholder is left as is.
// If the placeholder is not found in the data object, it is left as is.
const data = {
- organization,
- modelName,
- modelVersion,
- file,
+ model,
+ revision,
};
- const path = this.urlTemplate.replace(
- /\$\{(\w+)\}/g,
+ let path = this.urlTemplate.replace(
+ /\{(\w+)\}/g,
(match, key) => data[key] || match
);
+ path = `${path}/${file}`;
+
const fullPath = `${baseUrl.pathname}${
path.startsWith("/") ? path.slice(1) : path
}`;
@@ -462,28 +526,29 @@ export class ModelHub {
return urlObject.toString();
}
- /** Checks the organization, model, and version inputs.
+ /** Checks the model and revision inputs.
*
- * @param { string } organization
- * @param { string } modelName
- * @param { string } modelVersion
+ * @param { string } model
+ * @param { string } revision
* @param { string } file
* @returns { Error } The error instance(can be null)
*/
- #checkInput(organization, modelName, modelVersion, file) {
- // Ensures string consists only of letters, digits, and hyphens without starting/ending
- // with a hyphen or containing consecutive hyphens.
+ #checkInput(model, revision, file) {
+ // Matches a string with the format 'organization/model' where:
+ // - 'organization' consists only of letters, digits, and hyphens, cannot start or end with a hyphen,
+ // and cannot contain consecutive hyphens.
+ // - 'model' can contain letters, digits, hyphens, underscores, or periods.
//
- // ^ $ Start and end of string
- // (?!-) (?<!-) Negative lookahead/behind for not starting or ending with hyphen
- // (?!.*--) Negative lookahead for not containing consecutive hyphens
- // [A-Za-z0-9-]+ Alphanum characters or hyphens, one or more
- const orgRegex = /^(?!-)(?!.*--)[A-Za-z0-9-]+(?<!-)$/;
-
- // Matches strings containing letters, digits, hyphens, underscores, or periods.
- // ^ $ Start and end of string
- // [A-Za-z0-9-_.]+ Alphanum characters, hyphens, underscores, or periods, one or more times
- const modelRegex = /^[A-Za-z0-9-_.]+$/;
+ // Pattern breakdown:
+ // ^ Start of string
+ // (?!-) Negative lookahead for 'organization' not starting with hyphen
+ // (?!.*--) Negative lookahead for 'organization' not containing consecutive hyphens
+ // [A-Za-z0-9-]+ 'organization' part: Alphanumeric characters or hyphens
+ // (?<!-) Negative lookbehind for 'organization' not ending with a hyphen
+ // \/ Literal '/' character separating 'organization' and 'model'
+ // [A-Za-z0-9-_.]+ 'model' part: Alphanumeric characters, hyphens, underscores, or periods
+ // $ End of string
+ const modelRegex = /^(?!-)(?!.*--)[A-Za-z0-9-]+(?<!-)\/[A-Za-z0-9-_.]+$/;
// Matches strings consisting of alphanumeric characters, hyphens, or periods.
//
@@ -502,22 +567,18 @@ export class ModelHub {
// \/ Directory separator
// [A-Za-z0-9-_]+ Directory or file name
// )* Zero or more times
- // (?:[.][A-Za-z]{2,4})? Optional non-capturing group for file extension
+ // (?:[.][A-Za-z]{2,4})? Optional non-capturing group for file extension
const fileRegex =
/^(?:\/)?(?!\/)[A-Za-z0-9-_]+(?:\/[A-Za-z0-9-_]+)*(?:[.][A-Za-z]{2,4})?$/;
- if (!orgRegex.test(organization) || !isNaN(parseInt(organization))) {
- return new Error(`Invalid organization name ${organization}`);
- }
-
- if (!modelRegex.test(modelName)) {
+ if (!modelRegex.test(model)) {
return new Error("Invalid model name.");
}
if (
- !versionRegex.test(modelVersion) ||
- modelVersion.includes(" ") ||
- /[\^$]/.test(modelVersion)
+ !versionRegex.test(revision) ||
+ revision.includes(" ") ||
+ /[\^$]/.test(revision)
) {
return new Error("Invalid version identifier.");
}
@@ -536,7 +597,7 @@ export class ModelHub {
* @param {number} timeout in ms. Default is 1000
* @returns {Promise<string>} ETag (can be null)
*/
- async #getETag(url, timeout = 1000) {
+ async getETag(url, timeout = 1000) {
const controller = new AbortController();
const id = lazy.setTimeout(() => controller.abort(), timeout);
@@ -559,24 +620,18 @@ export class ModelHub {
* Given an organization, model, and version, fetch a model file in the hub as a Response.
*
* @param {object} config
- * @param {string} config.organization
- * @param {string} config.modelName
- * @param {string} config.modelVersion
+ * @param {string} config.model
+ * @param {string} config.revision
* @param {string} config.file
* @returns {Promise<Response>} The file content
*/
- async getModelFileAsResponse({
- organization,
- modelName,
- modelVersion,
- file,
- }) {
+ async getModelFileAsResponse({ model, revision, file }) {
const [blob, headers] = await this.getModelFileAsBlob({
- organization,
- modelName,
- modelVersion,
+ model,
+ revision,
file,
});
+
return new Response(blob, { headers });
}
@@ -584,22 +639,15 @@ export class ModelHub {
* Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer.
*
* @param {object} config
- * @param {string} config.organization
- * @param {string} config.modelName
- * @param {string} config.modelVersion
+ * @param {string} config.model
+ * @param {string} config.revision
* @param {string} config.file
* @returns {Promise<[ArrayBuffer, headers]>} The file content
*/
- async getModelFileAsArrayBuffer({
- organization,
- modelName,
- modelVersion,
- file,
- }) {
+ async getModelFileAsArrayBuffer({ model, revision, file }) {
const [blob, headers] = await this.getModelFileAsBlob({
- organization,
- modelName,
- modelVersion,
+ model,
+ revision,
file,
});
return [await blob.arrayBuffer(), headers];
@@ -609,54 +657,44 @@ export class ModelHub {
* Given an organization, model, and version, fetch a model file in the hub as blob.
*
* @param {object} config
- * @param {string} config.organization
- * @param {string} config.modelName
- * @param {string} config.modelVersion
+ * @param {string} config.model
+ * @param {string} config.revision
* @param {string} config.file
* @returns {Promise<[Blob, object]>} The file content
*/
- async getModelFileAsBlob({ organization, modelName, modelVersion, file }) {
+ async getModelFileAsBlob({ model, revision, file }) {
// Make sure inputs are clean. We don't sanitize them but throw an exception
- let checkError = this.#checkInput(
- organization,
- modelName,
- modelVersion,
- file
- );
+ let checkError = this.#checkInput(model, revision, file);
if (checkError) {
throw checkError;
}
-
- const url = this.#fileUrl(organization, modelName, modelVersion, file);
+ const url = this.#fileUrl(model, revision, file);
lazy.console.debug(`Getting model file from ${url}`);
await this.#initCache();
- // this can be null if no ETag was found or there were a network error
- const hubETag = await this.#getETag(url);
+ let useCached;
- lazy.console.debug(
- `Checking the cache for ${organization}/${modelName}/${modelVersion}/${file}`
- );
+ // If the revision is `main` we want to check the ETag in the hub
+ if (revision === "main") {
+ // this can be null if no ETag was found or there were a network error
+ const hubETag = await this.getETag(url);
- // storage lookup
- const cachedHeaders = await this.cache.getHeaders(
- organization,
- modelName,
- modelVersion,
- file
- );
- const cachedEtag = cachedHeaders ? cachedHeaders.ETag : null;
-
- // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response
- if (cachedEtag !== null && (hubETag === null || cachedEtag === hubETag)) {
- lazy.console.debug(`Cache Hit`);
- return await this.cache.getFile(
- organization,
- modelName,
- modelVersion,
- file
- );
+ // Storage ETag lookup
+ const cachedHeaders = await this.cache.getHeaders(model, revision, file);
+ const cachedEtag = cachedHeaders ? cachedHeaders.ETag : null;
+
+ // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response
+ useCached =
+ cachedEtag !== null && (hubETag === null || cachedEtag === hubETag);
+ } else {
+ // If we are dealing with a pinned revision, we ignore the ETag, to spare HEAD hits on every call
+ useCached = await this.cache.fileExists(model, revision, file);
+ }
+
+ if (useCached) {
+ lazy.console.debug(`Cache Hit for ${url}`);
+ return await this.cache.getFile(model, revision, file);
}
lazy.console.debug(`Fetching ${url}`);
@@ -668,13 +706,12 @@ export class ModelHub {
// We don't store the boundary or the charset, just the content type,
// so we drop what's after the semicolon.
"Content-Type": response.headers.get("Content-Type").split(";")[0],
- ETag: hubETag,
+ ETag: response.headers.get("ETag"),
};
await this.cache.put(
- organization,
- modelName,
- modelVersion,
+ model,
+ revision,
file,
await clone.blob(),
headers