/* Any copyright is dedicated to the Public Domain.
http://creativecommons.org/publicdomain/zero/1.0/ */
///
// Load the shared-head file first.
Services.scriptloader.loadSubScript(
"chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/shared-head.js",
this
);
/**
* @type {import("../../actors/MLEngineParent.sys.mjs")}
*/
const { MLEngineParent, MLEngine } = ChromeUtils.importESModule(
"resource://gre/actors/MLEngineParent.sys.mjs"
);
const { ModelHub, TestIndexedDBCache } = ChromeUtils.importESModule(
"chrome://global/content/ml/ModelHub.sys.mjs"
);
const { getInferenceProcessInfo } = ChromeUtils.importESModule(
"chrome://global/content/ml/Utils.sys.mjs"
);
const MS_PER_SEC = 1000;
const IndexedDBCache = TestIndexedDBCache;
const {
createEngine,
PipelineOptions,
QuantizationLevel,
ExecutionPriority,
InferenceDevice,
LogLevel,
} = ChromeUtils.importESModule(
"chrome://global/content/ml/EngineProcess.sys.mjs"
);
// This test suite shares some utility functions with translations as they work in a very
// similar fashion. Eventually, the plan is to unify these two components.
Services.scriptloader.loadSubScript(
"chrome://mochitests/content/browser/toolkit/components/translations/tests/browser/shared-head.js",
this
);
/**
 * Prepares the environment for an ML engine test: installs the mocked Remote
 * Settings clients and then pushes the prefs used by the tests.
 *
 * @param {object} [options]
 * @param {boolean} [options.disabled=false] - When true, "browser.ml.enable" is pushed as false.
 * @param {Array} [options.prefs=[]] - Extra pref entries, appended after the defaults so they are listed last.
 * @param {?Array} [options.records=null] - Forwarded to createAndMockMLRemoteSettings.
 * @param {string} [options.backend] - Forwarded to createAndMockMLRemoteSettings.
 * @returns {Promise<{remoteClients: object, cleanup: Function}>}
 *   The mocked clients and an async `cleanup` function to call at the end of the test.
 */
async function setup({
  disabled = false,
  prefs = [],
  records = null,
  backend,
} = {}) {
  const { removeMocks: uninstallMocks, remoteClients } =
    await createAndMockMLRemoteSettings({
      autoDownloadFromRemoteSettings: false,
      records,
      backend,
    });

  const defaultPrefs = [
    // Enabled by default.
    ["browser.ml.enable", !disabled],
    ["browser.ml.logLevel", "All"],
    ["browser.ml.modelCacheTimeout", 1000],
    ["browser.ml.checkForMemory", false],
    ["browser.ml.queueWaitTimeout", 2],
    ["javascript.options.wasm_lazy_tiering", true],
  ];
  await SpecialPowers.pushPrefEnv({ set: [...defaultPrefs, ...prefs] });

  // Undoes everything above: mocks first, then wait for the engines to go
  // away, and only then restore the prefs.
  async function cleanup() {
    await uninstallMocks();
    await waitForCondition(
      () => EngineProcess.areAllEnginesTerminated(),
      "Waiting for all of the engines to be terminated.",
      100,
      200
    );
    await SpecialPowers.popPrefEnv();
  }

  return { remoteClients, cleanup };
}
/**
 * Builds the default list of wasm runtime records for a backend.
 *
 * @param {string} [backend] - Backend key; when falsy, MLEngineParent.DEFAULT_BACKEND is used.
 * @returns {Array<{name: string, version: string}>}
 *   A single record whose name comes from MLEngineParent.WASM_FILENAME and
 *   whose version is the backend's WASM_MAJOR_VERSION followed by ".0".
 */
function getDefaultWasmRecords(backend) {
  const effectiveBackend = backend || MLEngineParent.DEFAULT_BACKEND;
  const name = MLEngineParent.WASM_FILENAME[effectiveBackend];
  const version = MLEngineParent.WASM_MAJOR_VERSION[effectiveBackend] + ".0";
  return [{ name, version }];
}
/**
 * Creates the three local Remote Settings clients used by the ML tests (wasm
 * runtime, inference options, allow/deny list) and registers them through
 * MLEngineParent.mockRemoteSettings.
 *
 * @param {object} [options]
 * @param {boolean} [options.autoDownloadFromRemoteSettings=false]
 *   Forwarded to the runtime client's attachment mock.
 * @param {?Array} [options.records=null]
 *   Inference option records; when falsy a single default "moz-echo" record is used.
 * @param {string} [options.backend] - Passed to getDefaultWasmRecords.
 * @returns {Promise<{removeMocks: Function, remoteClients: object}>}
 *   `removeMocks` is async and clears the clients before unregistering the mocks.
 */
async function createAndMockMLRemoteSettings({
  autoDownloadFromRemoteSettings = false,
  records = null,
  backend,
} = {}) {
  const wasmRecords = [];
  for (const { name, version } of getDefaultWasmRecords(backend)) {
    wasmRecords.push({
      id: crypto.randomUUID(),
      name,
      version,
      last_modified: Date.now(),
      schema: Date.now(),
    });
  }

  // Created with attachmentMock: true, so the underlying client is read from
  // `runtime.client` below.
  const runtime = await createRemoteClient({
    collectionName: "test-translation-wasm",
    records: wasmRecords,
    attachmentMock: true,
    autoDownloadFromRemoteSettings,
  });

  const defaultInferenceOptions = [
    {
      taskName: "moz-echo",
      modelId: "mozilla/distilvit",
      processorId: "mozilla/distilvit",
      tokenizerId: "mozilla/distilvit",
      modelRevision: "main",
      processorRevision: "main",
      tokenizerRevision: "main",
      dtype: "q8",
      id: "74a71cfd-1734-44e6-85c0-69cf3e874138",
    },
  ];
  const options = await createRemoteClient({
    records: records || defaultInferenceOptions,
    collectionName: "test-ml-inference-options",
  });

  const allowDeny = await createRemoteClient({
    records: [
      {
        filter: "ALLOW",
        urlPrefix: "https://",
        id: "74a71cfd-1734-44e6-85c0-69cf3e874138",
      },
    ],
    collectionName: "test-ml-allow-deny-list",
  });

  MLEngineParent.mockRemoteSettings({
    "ml-onnx-runtime": runtime.client,
    "ml-inference-options": options,
    "ml-model-allow-deny-list": allowDeny,
  });

  // Handed back to the caller; note the runtime entry is the attachment mock
  // wrapper, not the bare client.
  const remoteClients = {
    "ml-onnx-runtime": runtime,
    "ml-inference-options": options,
    "ml-model-allow-deny-list": allowDeny,
  };

  async function removeMocks() {
    await runtime.client.attachments.deleteAll();
    await runtime.client.db.clear();
    await options.db.clear();
    await allowDeny.db.clear();
    MLEngineParent.removeMocks();
  }

  return { removeMocks, remoteClients };
}
/**
 * Creates a local RemoteSettingsClient for use within tests. The client's
 * database is cleared and then seeded with `records`.
 *
 * @param {object} options
 * @param {Array} options.records - Records imported into the client's database.
 * @param {string} options.collectionName
 *   Base collection name; a counter suffix is appended to it.
 * @param {boolean} [options.attachmentMock=false]
 *   When true, the result of createAttachmentMock is returned instead of the client.
 * @param {boolean} [options.autoDownloadFromRemoteSettings=false]
 *   Forwarded to createAttachmentMock.
 * @returns {Promise<RemoteSettings|AttachmentMock>}
 */
async function createRemoteClient({
  records,
  collectionName,
  attachmentMock = false,
  autoDownloadFromRemoteSettings = false,
}) {
  const { RemoteSettings } = ChromeUtils.importESModule(
    "resource://services-settings/remote-settings.sys.mjs"
  );

  // Each call consumes one value of the shared mock id counter.
  const mockedCollectionId = `${collectionName}-${_remoteSettingsMockId++}`;
  const client = RemoteSettings(mockedCollectionId);

  const { db } = client;
  await db.clear();
  await db.importChanges({}, Date.now(), records);

  return attachmentMock
    ? createAttachmentMock(
        client,
        collectionName,
        autoDownloadFromRemoteSettings
      )
    : client;
}
/*
* Perftest related
*
* String constants and small helpers used when collecting and reporting
* performance metrics.
*/
// Number of bytes in one mebibyte.
const ONE_MIB = 1024 * 1024;
// Names of start/end markers (presumably matched against engine metric
// entries; confirm against the code that consumes them).
const INIT_START = "initializationStart";
const INIT_END = "initializationEnd";
const RUN_START = "runStart";
const RUN_END = "runEnd";
const PIPELINE_READY_START = "ensurePipelineIsReadyStart";
const PIPELINE_READY_END = "ensurePipelineIsReadyEnd";
// Names under which the computed metrics are reported.
const PIPELINE_READY_LATENCY = "pipeline-ready-latency";
const INITIALIZATION_LATENCY = "initialization-latency";
const MODEL_RUN_LATENCY = "model-run-latency";
const TOTAL_MEMORY_USAGE = "total-memory-usage";
const COLD_START_PREFIX = "cold-start-";
const PEAK_MEMORY_USAGE = "peak-memory-usage";
const ITERATIONS = 10;
// Property keys (presumably read from a metric entry; confirm at the call sites).
const WHEN = "when";
const MEMORY = "memory";
const E2E_INIT_LATENCY = "e2e-init-latency";
const FIRST_TOKEN_LATENCY = "1st-token-latency";
const DECODING_LATENCY = "decoding-latency";
// Token speeds are appropriate for comparing the speed of the same model.
const DECODING_TOKEN_SPEED = "decoding-tokenSpeed";
const PROMPT_TOKEN_SPEED = "prompt-tokenSpeed";
// Characters speed is appropriate for comparing the speed of two different models.
const DECODING_CHARACTERS_SPEED = "decoding-charactersSpeed";
const PROMPT_CHARACTERS_SPEED = "prompt-charactersSpeed";
// Formats a number using the en-US locale with at most four significant digits.
const formatNumber = new Intl.NumberFormat("en-US", {
maximumSignificantDigits: 4,
}).format;
/**
 * Computes the median of a list of numbers without modifying the input.
 *
 * @param {Iterable<number>} arr - The values to examine.
 * @returns {number}
 *   The middle value for an odd count, or the mean of the two middle values
 *   for an even count. An empty input yields NaN.
 */
function median(arr) {
  // Sort a copy numerically so the caller's array keeps its order.
  const sorted = [...arr].sort((x, y) => x - y);
  const half = Math.floor(sorted.length / 2);
  const hasOddCount = sorted.length % 2 !== 0;
  return hasOddCount ? sorted[half] : (sorted[half - 1] + sorted[half]) / 2;
}
/**
 * Renders a list of numbers as one row of right-aligned columns.
 *
 * Each value is formatted with `formatNumber` and left-padded to seven
 * characters; a value that already fills the column gets one extra leading
 * space so that adjacent columns never touch.
 *
 * @param {number[]} arr - The values to render.
 * @returns {string} The concatenated columns, or "" for an empty list.
 */
function stringify(arr) {
  const columns = arr.map(elem => {
    const cell = formatNumber(elem).padStart(7, " ");
    return cell.startsWith(" ") ? cell : " " + cell;
  });
  return columns.join("");
}
/**
 * Reports the collected metrics in two forms: a human-readable table written
 * with `dump`, and a machine-readable "perfMetrics | <json>" line written
 * with `info`.
 *
 * Only the journal's own enumerable properties are reported, so the table,
 * its column width and the JSON payload always describe the same metrics.
 *
 * @param {Object<string, number[]>} journal
 *   Maps a metric name to the list of values recorded for it.
 */
function reportMetrics(journal) {
  const entries = Object.entries(journal);
  // Width of the "name:" column: the longest name plus the trailing colon.
  const prefixLen = 1 + Math.max(...entries.map(([name]) => name.length));

  let text = "\nResults (ms)\n";
  const reportedMetrics = [];
  for (const [name, values] of entries) {
    // Computed once and shared by the table row and the JSON payload.
    const value = median(values);
    text += `${name}:`.padEnd(prefixLen, " ") + stringify(values);
    text += ` median ${formatNumber(value)}\n`;
    reportedMetrics.push({ name, values, value });
  }

  dump(text);
  info(`perfMetrics | ${JSON.stringify(reportedMetrics)}`);
}
/**
* Fetches the latest metric entry with the specified name and retrieves its value for the given key.
* If multiple metrics share the same name, the function returns the key from the most recent one.
*
* @param {Array