From a90a5cba08fdf6c0ceb95101c275108a152a3aed Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 12 Jun 2024 07:35:37 +0200 Subject: Merging upstream version 127.0. Signed-off-by: Daniel Baumann --- toolkit/components/ml/content/ModelHub.sys.mjs | 317 ++++++++++++++----------- 1 file changed, 177 insertions(+), 140 deletions(-) (limited to 'toolkit/components/ml/content/ModelHub.sys.mjs') diff --git a/toolkit/components/ml/content/ModelHub.sys.mjs b/toolkit/components/ml/content/ModelHub.sys.mjs index 4c2181ff14..10f83c3000 100644 --- a/toolkit/components/ml/content/ModelHub.sys.mjs +++ b/toolkit/components/ml/content/ModelHub.sys.mjs @@ -24,8 +24,7 @@ const ALLOWED_HUBS = [ ]; const ALLOWED_HEADERS_KEYS = ["Content-Type", "ETag", "status"]; -const DEFAULT_URL_TEMPLATE = - "${organization}/${modelName}/resolve/${modelVersion}/${file}"; +const DEFAULT_URL_TEMPLATE = "{model}/resolve/{revision}"; /** * Checks if a given URL string corresponds to an allowed hub. @@ -238,18 +237,37 @@ export class IndexedDBCache { }); } + /** + * Checks if a specified model file exists in storage. + * + * @param {string} model - The model name (organization/name) + * @param {string} revision - The model revision. + * @param {string} file - The file name. + * @returns {Promise} A promise that resolves with `true` if the key exists, otherwise `false`. + */ + async fileExists(model, revision, file) { + const storeName = this.fileStoreName; + const cacheKey = `${model}/${revision}/${file}`; + return new Promise((resolve, reject) => { + const transaction = this.db.transaction([storeName], "readonly"); + const store = transaction.objectStore(storeName); + const request = store.getKey(cacheKey); + request.onerror = event => reject(event.target.error); + request.onsuccess = event => resolve(event.target.result !== undefined); + }); + } + /** * Retrieves the headers for a specific cache entry. * - * @param {string} organization - The organization name. - * @param {string} modelName - The model name. - * @param {string} modelVersion - The model version. + * @param {string} model - The model name (organization/name) + * @param {string} revision - The model revision. * @param {string} file - The file name. * @returns {Promise} The headers or null if not found. */ - async getHeaders(organization, modelName, modelVersion, file) { - const headersKey = `${organization}/${modelName}/${modelVersion}`; - const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`; + async getHeaders(model, revision, file) { + const headersKey = `${model}/${revision}`; + const cacheKey = `${model}/${revision}/${file}`; const headers = await this.#getData(this.headersStoreName, headersKey); if (headers && headers.files[cacheKey]) { return headers.files[cacheKey]; @@ -260,22 +278,16 @@ export class IndexedDBCache { /** * Retrieves the file for a specific cache entry. * - * @param {string} organization - The organization name. - * @param {string} modelName - The model name. - * @param {string} modelVersion - The model version. + * @param {string} model - The model name (organization/name). + * @param {string} revision - The model version. * @param {string} file - The file name. * @returns {Promise<[ArrayBuffer, object]|null>} The file ArrayBuffer and its headers or null if not found. */ - async getFile(organization, modelName, modelVersion, file) { - const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`; + async getFile(model, revision, file) { + const cacheKey = `${model}/${revision}/${file}`; const stored = await this.#getData(this.fileStoreName, cacheKey); if (stored) { - const headers = await this.getHeaders( - organization, - modelName, - modelVersion, - file - ); + const headers = await this.getHeaders(model, revision, file); return [stored.data, headers]; } return null; // Return null if no file is found @@ -284,29 +296,21 @@ export class IndexedDBCache { /** * Adds or updates a cache entry. * - * @param {string} organization - The organization name. - * @param {string} modelName - The model name. - * @param {string} modelVersion - The model version. + * @param {string} model - The model name (organization/name). + * @param {string} revision - The model version. * @param {string} file - The file name. * @param {ArrayBuffer} arrayBuffer - The data to cache. * @param {object} [headers] - The headers for the file. * @returns {Promise} */ - async put( - organization, - modelName, - modelVersion, - file, - arrayBuffer, - headers = {} - ) { - const cacheKey = `${organization}/${modelName}/${modelVersion}/${file}`; + async put(model, revision, file, arrayBuffer, headers = {}) { + const cacheKey = `${model}/${revision}/${file}`; const newSize = this.totalSize + arrayBuffer.byteLength; if (newSize > this.#maxSize) { throw new Error("Exceeding total cache size limit of 1GB"); } - const headersKey = `${organization}/${modelName}/${modelVersion}`; + const headersKey = `${model}/${revision}`; const data = { id: cacheKey, data: arrayBuffer }; // Store the file data @@ -356,13 +360,12 @@ export class IndexedDBCache { /** * Deletes all data related to a specific model. * - * @param {string} organization - The organization name. - * @param {string} modelName - The model name. - * @param {string} modelVersion - The model version. + * @param {string} model - The model name (organization/name). + * @param {string} revision - The model version. * @returns {Promise} */ - async deleteModel(organization, modelName, modelVersion) { - const headersKey = `${organization}/${modelName}/${modelVersion}`; + async deleteModel(model, revision) { + const headersKey = `${model}/${revision}`; const headers = await this.#getData(this.headersStoreName, headersKey); if (headers) { for (const fileKey in headers.files) { @@ -409,9 +412,9 @@ export class ModelHub { this.cache = null; // Ensures the URL template is well-formed and does not contain any invalid characters. - const pattern = /^(?:\$\{\w+\}|\w+)(?:\/(?:\$\{\w+\}|\w+))*$/; + const pattern = /^(?:\{\w+\}|\w+)(?:\/(?:\{\w+\}|\w+))*$/; // ^ $ Start and end of string - // (?:\$\{\w+\}|\w+) Match a ${placeholder} or alphanumeric characters + // (?:\{\w+\}|\w+) Match a {placeholder} or alphanumeric characters // (?:\/(?:\$\{\w+\}|\w+))* Zero or more groups of a forward slash followed by a ${placeholder} or alphanumeric characters if (!pattern.test(urlTemplate)) { throw new Error(`Invalid URL template: ${urlTemplate}`); @@ -426,33 +429,94 @@ export class ModelHub { this.cache = await IndexedDBCache.init(); } + /** + * This method takes a model URL and parses it to extract the + * model name, optional model version, and file path. + * + * The expected URL format are : + * + * `/organization/model/revision/filePath` + * `https://hub/organization/model/revision/filePath` + * + * @param {string} url - The full URL to the model, including protocol and domain - or the relative path. + * @returns {object} An object containing the parsed components of the URL. The + * object has properties `model`, and `file`, + * and optionally `revision` if the URL includes a version. + * @throws {Error} Throws an error if the URL does not start with `this.rootUrl` or + * if the URL format does not match the expected structure. + * + * @example + * // For a URL + * parseModelUrl("https://example.com/org1/model1/v1/file/path"); + * // returns { model: "org1/model1", revision: "v1", file: "file/path" } + * + * @example + * // For a relative URL + * parseModelUrl("/org1/model1/revision/file/path"); + * // returns { model: "org1/model1", revision: "v1", file: "file/path" } + */ + parseUrl(url) { + let parts; + if (url.startsWith("/")) { + // relative URL + parts = url.slice(1).split("/"); + } else { + // absolute URL + if (!url.startsWith(this.rootUrl)) { + throw new Error(`Invalid domain for model URL: ${url}`); + } + const urlObject = new URL(url); + const rootUrlObject = new URL(this.rootUrl); + + // Remove the root URL's pathname from the full URL's pathname + const relativePath = urlObject.pathname.substring( + rootUrlObject.pathname.length + ); + + parts = relativePath.slice(1).split("/"); + } + + if (parts.length < 3) { + throw new Error(`Invalid model URL: ${url}`); + } + + const file = parts.slice(3).join("/"); + if (file == null || !file.length) { + throw new Error(`Invalid model URL: ${url}`); + } + + return { + model: `${parts[0]}/${parts[1]}`, + revision: parts[2], + file, + }; + } + /** Creates the file URL from the organization, model, and version. * - * @param {string} organization - * @param {string} modelName - * @param {string} modelVersion + * @param {string} model + * @param {string} revision * @param {string} file * @returns {string} The full URL */ - #fileUrl(organization, modelName, modelVersion, file) { + #fileUrl(model, revision, file) { const baseUrl = new URL(this.rootUrl); if (!baseUrl.pathname.endsWith("/")) { baseUrl.pathname += "/"; } - // Replace placeholders in the URL template with the provided data. // If some keys are missing in the data object, the placeholder is left as is. // If the placeholder is not found in the data object, it is left as is. const data = { - organization, - modelName, - modelVersion, - file, + model, + revision, }; - const path = this.urlTemplate.replace( - /\$\{(\w+)\}/g, + let path = this.urlTemplate.replace( + /\{(\w+)\}/g, (match, key) => data[key] || match ); + path = `${path}/${file}`; + const fullPath = `${baseUrl.pathname}${ path.startsWith("/") ? path.slice(1) : path }`; @@ -462,28 +526,29 @@ export class ModelHub { return urlObject.toString(); } - /** Checks the organization, model, and version inputs. + /** Checks the model and revision inputs. * - * @param { string } organization - * @param { string } modelName - * @param { string } modelVersion + * @param { string } model + * @param { string } revision * @param { string } file * @returns { Error } The error instance(can be null) */ - #checkInput(organization, modelName, modelVersion, file) { - // Ensures string consists only of letters, digits, and hyphens without starting/ending - // with a hyphen or containing consecutive hyphens. + #checkInput(model, revision, file) { + // Matches a string with the format 'organization/model' where: + // - 'organization' consists only of letters, digits, and hyphens, cannot start or end with a hyphen, + // and cannot contain consecutive hyphens. + // - 'model' can contain letters, digits, hyphens, underscores, or periods. // - // ^ $ Start and end of string - // (?!-) (?} ETag (can be null) */ - async #getETag(url, timeout = 1000) { + async getETag(url, timeout = 1000) { const controller = new AbortController(); const id = lazy.setTimeout(() => controller.abort(), timeout); @@ -559,24 +620,18 @@ export class ModelHub { * Given an organization, model, and version, fetch a model file in the hub as a Response. * * @param {object} config - * @param {string} config.organization - * @param {string} config.modelName - * @param {string} config.modelVersion + * @param {string} config.model + * @param {string} config.revision * @param {string} config.file * @returns {Promise} The file content */ - async getModelFileAsResponse({ - organization, - modelName, - modelVersion, - file, - }) { + async getModelFileAsResponse({ model, revision, file }) { const [blob, headers] = await this.getModelFileAsBlob({ - organization, - modelName, - modelVersion, + model, + revision, file, }); + return new Response(blob, { headers }); } @@ -584,22 +639,15 @@ export class ModelHub { * Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer. * * @param {object} config - * @param {string} config.organization - * @param {string} config.modelName - * @param {string} config.modelVersion + * @param {string} config.model + * @param {string} config.revision * @param {string} config.file * @returns {Promise<[ArrayBuffer, headers]>} The file content */ - async getModelFileAsArrayBuffer({ - organization, - modelName, - modelVersion, - file, - }) { + async getModelFileAsArrayBuffer({ model, revision, file }) { const [blob, headers] = await this.getModelFileAsBlob({ - organization, - modelName, - modelVersion, + model, + revision, file, }); return [await blob.arrayBuffer(), headers]; @@ -609,54 +657,44 @@ export class ModelHub { * Given an organization, model, and version, fetch a model file in the hub as blob. * * @param {object} config - * @param {string} config.organization - * @param {string} config.modelName - * @param {string} config.modelVersion + * @param {string} config.model + * @param {string} config.revision * @param {string} config.file * @returns {Promise<[Blob, object]>} The file content */ - async getModelFileAsBlob({ organization, modelName, modelVersion, file }) { + async getModelFileAsBlob({ model, revision, file }) { // Make sure inputs are clean. We don't sanitize them but throw an exception - let checkError = this.#checkInput( - organization, - modelName, - modelVersion, - file - ); + let checkError = this.#checkInput(model, revision, file); if (checkError) { throw checkError; } - - const url = this.#fileUrl(organization, modelName, modelVersion, file); + const url = this.#fileUrl(model, revision, file); lazy.console.debug(`Getting model file from ${url}`); await this.#initCache(); - // this can be null if no ETag was found or there were a network error - const hubETag = await this.#getETag(url); + let useCached; - lazy.console.debug( - `Checking the cache for ${organization}/${modelName}/${modelVersion}/${file}` - ); + // If the revision is `main` we want to check the ETag in the hub + if (revision === "main") { + // this can be null if no ETag was found or there were a network error + const hubETag = await this.getETag(url); - // storage lookup - const cachedHeaders = await this.cache.getHeaders( - organization, - modelName, - modelVersion, - file - ); - const cachedEtag = cachedHeaders ? cachedHeaders.ETag : null; - - // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response - if (cachedEtag !== null && (hubETag === null || cachedEtag === hubETag)) { - lazy.console.debug(`Cache Hit`); - return await this.cache.getFile( - organization, - modelName, - modelVersion, - file - ); + // Storage ETag lookup + const cachedHeaders = await this.cache.getHeaders(model, revision, file); + const cachedEtag = cachedHeaders ? cachedHeaders.ETag : null; + + // If we have something in store, and the hub ETag is null or it matches the cached ETag, return the cached response + useCached = + cachedEtag !== null && (hubETag === null || cachedEtag === hubETag); + } else { + // If we are dealing with a pinned revision, we ignore the ETag, to spare HEAD hits on every call + useCached = await this.cache.fileExists(model, revision, file); + } + + if (useCached) { + lazy.console.debug(`Cache Hit for ${url}`); + return await this.cache.getFile(model, revision, file); } lazy.console.debug(`Fetching ${url}`); @@ -668,13 +706,12 @@ export class ModelHub { // We don't store the boundary or the charset, just the content type, // so we drop what's after the semicolon. "Content-Type": response.headers.get("Content-Type").split(";")[0], - ETag: hubETag, + ETag: response.headers.get("ETag"), }; await this.cache.put( - organization, - modelName, - modelVersion, + model, + revision, file, await clone.blob(), headers -- cgit v1.2.3