249 lines
7.1 KiB
JavaScript
249 lines
7.1 KiB
JavaScript
/* This Source Code Form is subject to the terms of the Mozilla Public
|
|
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
|
|
|
const {
|
|
SmartTabGroupingManager,
|
|
CLUSTER_METHODS,
|
|
ANCHOR_METHODS,
|
|
getBestAnchorClusterInfo,
|
|
} = ChromeUtils.importESModule(
|
|
"moz-src:///browser/components/tabbrowser/SmartTabGrouping.sys.mjs"
|
|
);
|
|
|
|
/**
 * Checks whether two numbers are equal once both are rounded to
 * `decimalPoints` decimal places.
 *
 * The comparison is made on the `toFixed` representation of each number. A
 * value that rounds to negative zero (e.g. `-0.001` -> "-0.00") is treated the
 * same as positive zero, so tiny values on either side of zero compare equal.
 *
 * @param {number} a
 * @param {number} b
 * @param {number} decimalPoints Number of decimal places kept for comparison
 * @returns {boolean} True if numbers are similar
 */
function numberLooseEquals(a, b, decimalPoints = 2) {
  // toFixed keeps the sign of values that round to zero ("-0.00"); strip it
  // so that it matches the representation of positive zero ("0.00").
  const toRounded = value =>
    value.toFixed(decimalPoints).replace(/^-(0(?:\.0+)?)$/, "$1");
  return toRounded(a) === toRounded(b);
}
|
|
|
|
/**
 * Compares two vectors element by element after rounding every value to
 * `decimalPoints` decimal places.
 *
 * Vectors of different lengths are never considered equal. A value that
 * rounds to negative zero (e.g. `-0.001` -> "-0.00") is treated the same as
 * positive zero.
 *
 * @param {number[]} a
 * @param {number[]} b
 * @param {number} decimalPoints Number of decimal places kept for comparison
 * @returns {boolean} True if vectors are similar
 */
function vectorLooseEquals(a, b, decimalPoints = 2) {
  // Without this check a longer `b` would be silently accepted and a shorter
  // `b` would throw when reading past its end.
  if (a.length !== b.length) {
    return false;
  }
  // toFixed keeps the sign of values that round to zero ("-0.00"); strip it
  // so that it matches the representation of positive zero ("0.00").
  const toRounded = value =>
    value.toFixed(decimalPoints).replace(/^-(0(?:\.0+)?)$/, "$1");
  return a.every((item, index) => toRounded(item) === toRounded(b[index]));
}
|
|
|
|
/**
 * Extremely simple deterministic, seeded generator of numbers between 0 and 1,
 * for use in tests in place of a true random generator.
 *
 * The generator cycles through a fixed table of values. The seed (floored and
 * wrapped to the table length) selects the starting position; the first call
 * returns the value *after* that position. Negative seeds wrap around the
 * table instead of indexing outside of it.
 *
 * @param {number} seed
 * @returns {function(): number}
 */
function simpleNumberSequence(seed = 0) {
  const values = [
    0.42, 0.145, 0.5, 0.9234, 0.343, 0.1324, 0.8343, 0.534, 0.634, 0.3233,
  ];
  // The remainder operator keeps the sign of the dividend, so a negative seed
  // would otherwise yield a negative index and `undefined` results.
  const start = Math.floor(seed) % values.length;
  let counter = (start + values.length) % values.length;
  return () => {
    counter = (counter + 1) % values.length;
    return values[counter];
  };
}
|
|
|
|
/**
 * Shuffles an array in place by walking from the last position down to the
 * first and swapping each position with a randomly picked position at or
 * below it.
 *
 * The random source is called exactly once per array element.
 *
 * @param {object[]} array of items to shuffle (modified in place)
 * @param {Function} randFunc function that returns a number between 0 and 1;
 *   Math.random is used when it is null or undefined
 */
function shuffleArray(array, randFunc) {
  const nextRandom = randFunc ?? Math.random;
  let upper = array.length;
  while (upper > 0) {
    // Pick among the `upper` positions that have not been finalized yet.
    const pick = Math.floor(nextRandom() * upper);
    upper -= 1;
    const held = array[upper];
    array[upper] = array[pick];
    array[pick] = held;
  }
}
|
|
|
|
/**
 * Builds an object holding, for every key of the first entry, the mean of that
 * key's values across all entries.
 *
 * Only the keys of the first entry are considered. An empty input list gives
 * an empty object.
 *
 * @param {object[]} itemArray List of dicts, each with values to average
 * @returns {object} Object with average of values passed in itemArray
 */
function averageStatsValues(itemArray) {
  const count = itemArray.length;
  if (count === 0) {
    return {};
  }
  const averages = {};
  const [firstEntry] = itemArray;
  for (const field of Object.keys(firstEntry)) {
    const sum = itemArray.reduce((acc, entry) => acc + entry[field], 0.0);
    averages[field] = sum / count;
  }
  return averages;
}
|
|
|
|
/**
 * Parses a TSV string whose first row is a header into a list of objects.
 *
 * Each subsequent row becomes one object mapping header names to the cell
 * (as a string) in the same column. Cells missing from a short row are
 * `undefined`. Both LF and CRLF line endings are accepted.
 *
 * @param {string} tsvString string to read from
 * @returns {object[]} One object per data row, keyed by the header names
 */
function parseTsvStructured(tsvString) {
  // Split on CRLF as well as LF so that a stray "\r" does not end up in the
  // last header name and the last cell of every row.
  const rows = tsvString.trim().split(/\r?\n/);
  const keys = rows[0].split("\t");
  return rows.slice(1).map(row => {
    const values = row.split("\t");
    // Map keys to corresponding values
    return Object.fromEntries(keys.map((key, index) => [key, values[index]]));
  });
}
|
|
|
|
/**
 * Parses a TSV string of embeddings: one vector per line, components
 * separated by tabs.
 *
 * Every cell is converted with parseFloat, so a cell that is not numeric
 * becomes NaN. Empty or whitespace-only input yields an empty list. Both LF
 * and CRLF line endings are accepted.
 *
 * @param {string} tsvString string with embeddings present
 * @returns {number[][]} One array of numbers per row
 */
function parseTsvEmbeddings(tsvString) {
  const trimmed = tsvString.trim();
  // Splitting an empty string would produce a single bogus [NaN] vector.
  if (trimmed === "") {
    return [];
  }
  return trimmed
    .split(/\r?\n/)
    .map(row => row.split("\t").map(value => Number.parseFloat(value)));
}
|
|
|
|
/**
 * Configures a SmartTabGroupingManager, runs generateClusters `iterations`
 * times and returns the averaged accuracy stats of the anchor cluster.
 *
 * On every run it asserts that the titles of all pre-grouped tabs are present
 * in the anchor cluster returned by the grouping result. The same
 * deterministic number sequence is shared by all runs.
 *
 * @param {string} clusterMethod value passed to setClusteringMethod
 * @param {string} umapMethod value passed to setDimensionReductionMethod
 * @param {object[]} tabs tabs to cluster
 * @param {object[]} embeddings precomputed embeddings for the tabs
 * @param {number} iterations number of clustering runs to average over
 * @param {number[]} preGroupedTabIndices indices (into tabs) of the tabs that
 *   are already present in the group
 * @param {string} anchorMethod value passed to setAnchorMethod
 * @param {number} silBoost value passed to setSilBoost; skipped when undefined
 * @returns {Promise<object>} average of metric results
 */
async function testAugmentGroup(
  clusterMethod,
  umapMethod,
  tabs,
  embeddings,
  iterations = 1,
  preGroupedTabIndices,
  anchorMethod = ANCHOR_METHODS.FIXED,
  silBoost = undefined
) {
  const titleField = "title";

  const manager = new SmartTabGroupingManager();
  manager.setAnchorMethod(anchorMethod);
  if (silBoost !== undefined) {
    manager.setSilBoost(silBoost);
  }
  const nextRandom = simpleNumberSequence();
  manager.setDataTitleKey(titleField);
  manager.setClusteringMethod(clusterMethod);
  manager.setDimensionReductionMethod(umapMethod);

  const perRunStats = [];
  for (let run = 0; run < iterations; run++) {
    const clustering = await manager.generateClusters(
      tabs,
      embeddings,
      0,
      nextRandom,
      preGroupedTabIndices
    );

    const anchorClusterTitles = new Set(
      clustering.getAnchorCluster().tabs.map(tab => tab[titleField])
    );
    clustering.getAnchorCluster().print();

    const expectedTitles = new Set(
      preGroupedTabIndices.map(tabIndex => tabs[tabIndex][titleField])
    );
    // Every pre-grouped title must have ended up in the anchor cluster.
    const sharedTitles = anchorClusterTitles.intersection(expectedTitles);
    Assert.equal(
      sharedTitles.size,
      expectedTitles.size,
      `All anchor indices in target cluster`
    );

    perRunStats.push(
      clustering.getAccuracyStatsForCluster(
        "smart_group_label",
        clustering.getAnchorCluster().tabs[0].smart_group_label
      )
    );
  }
  return averageStatsValues(perRunStats);
}
|
|
|
|
/**
 * Runs the augment-group clustering test once for each configured parameter
 * set (currently a single k-means entry with no dimension reduction method).
 *
 * @param {object[]} data tabs to run test on
 * @param {object[]} precomputedEmbeddings embeddings for the tabs
 * @param {number[]} anchorGroupIndices indices of tabs already present in the group
 * @param {string} anchorMethod fixed or drift anchor method
 * @param {number} silBoost value with which to boost silhouette score
 * @returns {Promise<{}|null>} metric stats of the run when exactly one
 *   parameter set is configured, otherwise null
 */
async function runAnchorTabTest(
  data,
  precomputedEmbeddings = null,
  anchorGroupIndices,
  anchorMethod = ANCHOR_METHODS.FIXED,
  silBoost = undefined
) {
  const paramSets = [[CLUSTER_METHODS.KMEANS]];
  let latestStats;
  for (const [clusterMethod, umapMethod] of paramSets) {
    latestStats = await testAugmentGroup(
      clusterMethod,
      umapMethod,
      data,
      precomputedEmbeddings,
      1,
      anchorGroupIndices,
      anchorMethod,
      silBoost
    );
  }
  return paramSets.length === 1 ? latestStats : null;
}
|
|
|
|
/**
 * Retrieves a file with an asynchronous XMLHttpRequest GET; the URL is the
 * plain concatenation of prefix and filename.
 *
 * @param {string} host_prefix root data folder path
 * @param {string} filename name of file
 * @returns {Promise<string>} Resolves with the response text when the status
 *   is 200; rejects with an Error on any other status or on a network error
 */
function fetchFile(host_prefix, filename) {
  return new Promise((resolve, reject) => {
    const url = `${host_prefix}${filename}`;
    const request = new XMLHttpRequest();
    request.open("GET", url, true);
    request.onerror = () => reject(new Error(`Network error getting ${url}`));
    request.onload = () => {
      if (request.status !== 200) {
        reject(new Error(`Failed to fetch data: ${request.statusText}`));
        return;
      }
      resolve(request.responseText);
    };
    request.send();
  });
}
|