summaryrefslogtreecommitdiffstats
path: root/toolkit/components/ml/tests/browser/browser_ml_cache.js
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/ml/tests/browser/browser_ml_cache.js')
-rw-r--r--toolkit/components/ml/tests/browser/browser_ml_cache.js233
1 files changed, 177 insertions, 56 deletions
diff --git a/toolkit/components/ml/tests/browser/browser_ml_cache.js b/toolkit/components/ml/tests/browser/browser_ml_cache.js
index d8725368bd..8d879fc74d 100644
--- a/toolkit/components/ml/tests/browser/browser_ml_cache.js
+++ b/toolkit/components/ml/tests/browser/browser_ml_cache.js
@@ -11,16 +11,20 @@ const FAKE_HUB =
"chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data";
const FAKE_MODEL_ARGS = {
- organization: "acme",
- modelName: "bert",
- modelVersion: "main",
+ model: "acme/bert",
+ revision: "main",
+ file: "config.json",
+};
+
+const FAKE_RELEASED_MODEL_ARGS = {
+ model: "acme/bert",
+ revision: "v0.1",
file: "config.json",
};
const FAKE_ONNX_MODEL_ARGS = {
- organization: "acme",
- modelName: "bert",
- modelVersion: "main",
+ model: "acme/bert",
+ revision: "main",
file: "onnx/config.json",
};
@@ -35,6 +39,9 @@ const badHubs = [
"https://model-hub.mozilla.org.hack", // Domain that contains allowed domain
];
+/**
+ * Make sure we reject bad model hub URLs.
+ */
add_task(async function test_bad_hubs() {
for (const badHub of badHubs) {
Assert.throws(
@@ -60,60 +67,57 @@ add_task(async function test_allowed_hub() {
const badInputs = [
[
{
- organization: "ac me",
- modelName: "bert",
- modelVersion: "main",
+ model: "ac me/bert",
+ revision: "main",
file: "config.json",
},
"Org can only contain letters, numbers, and hyphens",
],
[
{
- organization: "1111",
- modelName: "bert",
- modelVersion: "main",
+ model: "1111/bert",
+ revision: "main",
file: "config.json",
},
"Org cannot contain only numbers",
],
[
{
- organization: "-acme",
- modelName: "bert",
- modelVersion: "main",
+ model: "-acme/bert",
+ revision: "main",
file: "config.json",
},
"Org start or end with a hyphen, or use consecutive hyphens",
],
[
{
- organization: "a-c-m-e",
- modelName: "#bert",
- modelVersion: "main",
+ model: "a-c-m-e/#bert",
+ revision: "main",
file: "config.json",
},
"Models can only contain letters, numbers, and hyphens, underscord, periods",
],
[
{
- organization: "a-c-m-e",
- modelName: "b$ert",
- modelVersion: "main",
+ model: "a-c-m-e/b$ert",
+ revision: "main",
file: "config.json",
},
"Models cannot contain spaces or control characters",
],
[
{
- organization: "a-c-m-e",
- modelName: "b$ert",
- modelVersion: "main",
+ model: "a-c-m-e/b$ert",
+ revision: "main",
file: ".filename",
},
"File",
],
];
+/**
+ * Make sure we reject bad inputs.
+ */
add_task(async function test_bad_inputs() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
@@ -129,6 +133,9 @@ add_task(async function test_bad_inputs() {
}
});
+/**
+ * Test that we can retrieve a file as an ArrayBuffer.
+ */
add_task(async function test_getting_file() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
@@ -144,6 +151,35 @@ add_task(async function test_getting_file() {
Assert.equal(jsonData.hidden_size, 768);
});
+/**
+ * Test that we can retrieve a file from a released model and skip head calls
+ */
+add_task(async function test_getting_released_file() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+ console.log(hub);
+
+ let spy = sinon.spy(hub, "getETag");
+ let [array, headers] = await hub.getModelFileAsArrayBuffer(
+ FAKE_RELEASED_MODEL_ARGS
+ );
+
+ Assert.equal(headers["Content-Type"], "application/json");
+
+ // check the content of the file.
+ let jsonData = JSON.parse(
+ String.fromCharCode.apply(null, new Uint8Array(array))
+ );
+
+ Assert.equal(jsonData.hidden_size, 768);
+
+ // check that head calls were not made
+ Assert.ok(!spy.called, "getETag should have never been called.");
+ spy.restore();
+});
+
+/**
+ * Make sure files can be located in sub directories
+ */
add_task(async function test_getting_file_in_subdir() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
@@ -161,10 +197,13 @@ add_task(async function test_getting_file_in_subdir() {
Assert.equal(jsonData.hidden_size, 768);
});
+/**
+ * Test that we can use a custom URL template.
+ */
add_task(async function test_getting_file_custom_path() {
const hub = new ModelHub({
rootUrl: FAKE_HUB,
- urlTemplate: "${organization}/${modelName}/resolve/${modelVersion}/${file}",
+ urlTemplate: "{model}/resolve/{revision}",
});
let res = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
@@ -172,9 +211,11 @@ add_task(async function test_getting_file_custom_path() {
Assert.equal(res[1]["Content-Type"], "application/json");
});
+/**
+ * Test that we can't use an URL with a query for the template
+ */
add_task(async function test_getting_file_custom_path_rogue() {
- const urlTemplate =
- "${organization}/${modelName}/resolve/${modelVersion}/${file}?some_id=bedqwdw";
+ const urlTemplate = "{model}/resolve/{revision}/?some_id=bedqwdw";
Assert.throws(
() => new ModelHub({ rootUrl: FAKE_HUB, urlTemplate }),
/Invalid URL template/,
@@ -182,6 +223,9 @@ add_task(async function test_getting_file_custom_path_rogue() {
);
});
+/**
+ * Test that the file can be returned as a response and its content correct.
+ */
add_task(async function test_getting_file_as_response() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
@@ -192,6 +236,10 @@ add_task(async function test_getting_file_as_response() {
Assert.equal(jsonData.hidden_size, 768);
});
+/**
+ * Test that the cache is used when the data is retrieved from the server
+ * and that the cache is updated with the new data.
+ */
add_task(async function test_getting_file_from_cache() {
const hub = new ModelHub({ rootUrl: FAKE_HUB });
let array = await hub.getModelFileAsArrayBuffer(FAKE_MODEL_ARGS);
@@ -213,6 +261,75 @@ add_task(async function test_getting_file_from_cache() {
Assert.deepEqual(array, array2);
});
+/**
+ * Test parsing of a well-formed full URL, including protocol and path.
+ */
+add_task(async function testWellFormedFullUrl() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+ const url = `${FAKE_HUB}/org1/model1/v1/file/path`;
+ const result = hub.parseUrl(url);
+
+ Assert.equal(
+ result.model,
+ "org1/model1",
+ "Model should be parsed correctly."
+ );
+ Assert.equal(result.revision, "v1", "Revision should be parsed correctly.");
+ Assert.equal(
+ result.file,
+ "file/path",
+ "File path should be parsed correctly."
+ );
+});
+
+/**
+ * Test parsing of a well-formed relative URL, starting with a slash.
+ */
+add_task(async function testWellFormedRelativeUrl() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+
+ const url = "/org1/model1/v1/file/path";
+ const result = hub.parseUrl(url);
+
+ Assert.equal(
+ result.model,
+ "org1/model1",
+ "Model should be parsed correctly."
+ );
+ Assert.equal(result.revision, "v1", "Revision should be parsed correctly.");
+ Assert.equal(
+ result.file,
+ "file/path",
+ "File path should be parsed correctly."
+ );
+});
+
+/**
+ * Ensures an error is thrown when the URL does not start with the expected root URL or a slash.
+ */
+add_task(async function testInvalidDomain() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+ const url = "https://example.com/org1/model1/v1/file/path";
+ Assert.throws(
+ () => hub.parseUrl(url),
+ new RegExp(`Error: Invalid domain for model URL: ${url}`),
+ `Should throw with ${url}`
+ );
+});
+
+/** Tests the method's error handling when the URL format does not include the required segments.
+ *
+ */
+add_task(async function testTooFewParts() {
+ const hub = new ModelHub({ rootUrl: FAKE_HUB });
+ const url = "/org1/model1";
+ Assert.throws(
+ () => hub.parseUrl(url),
+ new RegExp(`Error: Invalid model URL: ${url}`),
+ `Should throw with ${url}`
+ );
+});
+
// IndexedDB tests
/**
@@ -248,18 +365,41 @@ add_task(async function test_Init() {
});
/**
+ * Test checking existence of data in the cache.
+ */
+add_task(async function test_PutAndCheckExists() {
+ const cache = await initializeCache();
+ const testData = new ArrayBuffer(8); // Example data
+ const key = "file.txt";
+ await cache.put("org/model", "v1", "file.txt", testData, {
+ ETag: "ETAG123",
+ });
+
+ // Checking if the file exists
+ let exists = await cache.fileExists("org/model", "v1", key);
+ Assert.ok(exists, "The file should exist in the cache.");
+
+ // Removing all files from the model
+ await cache.deleteModel("org/model", "v1");
+
+ exists = await cache.fileExists("org/model", "v1", key);
+ Assert.ok(!exists, "The file should be gone from the cache.");
+
+ await deleteCache(cache);
+});
+
+/**
* Test adding data to the cache and retrieving it.
*/
add_task(async function test_PutAndGet() {
const cache = await initializeCache();
const testData = new ArrayBuffer(8); // Example data
- await cache.put("org", "model", "v1", "file.txt", testData, {
+ await cache.put("org/model", "v1", "file.txt", testData, {
ETag: "ETAG123",
});
const [retrievedData, headers] = await cache.getFile(
- "org",
- "model",
+ "org/model",
"v1",
"file.txt"
);
@@ -289,14 +429,9 @@ add_task(async function test_GetHeaders() {
extra: "extra",
};
- await cache.put("org", "model", "v1", "file.txt", testData, headers);
+ await cache.put("org/model", "v1", "file.txt", testData, headers);
- const storedHeaders = await cache.getHeaders(
- "org",
- "model",
- "v1",
- "file.txt"
- );
+ const storedHeaders = await cache.getHeaders("org/model", "v1", "file.txt");
// The `extra` field should be removed from the stored headers because
// it's not part of the allowed keys.
@@ -318,22 +453,8 @@ add_task(async function test_GetHeaders() {
*/
add_task(async function test_ListModels() {
const cache = await initializeCache();
- await cache.put(
- "org1",
- "modelA",
- "v1",
- "file1.txt",
- new ArrayBuffer(8),
- null
- );
- await cache.put(
- "org2",
- "modelB",
- "v1",
- "file2.txt",
- new ArrayBuffer(8),
- null
- );
+ await cache.put("org1/modelA", "v1", "file1.txt", new ArrayBuffer(8), null);
+ await cache.put("org2/modelB", "v1", "file2.txt", new ArrayBuffer(8), null);
const models = await cache.listModels();
Assert.ok(
@@ -348,10 +469,10 @@ add_task(async function test_ListModels() {
*/
add_task(async function test_DeleteModel() {
const cache = await initializeCache();
- await cache.put("org", "model", "v1", "file.txt", new ArrayBuffer(8), null);
- await cache.deleteModel("org", "model", "v1");
+ await cache.put("org/model", "v1", "file.txt", new ArrayBuffer(8), null);
+ await cache.deleteModel("org/model", "v1");
- const dataAfterDelete = await cache.getFile("org", "model", "v1", "file.txt");
+ const dataAfterDelete = await cache.getFile("org/model", "v1", "file.txt");
Assert.equal(
dataAfterDelete,
null,