1
0
Fork 0
firefox/browser/components/tabbrowser/SmartTabGrouping.sys.mjs
Daniel Baumann 5e9a113729
Adding upstream version 140.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
2025-06-25 09:37:52 +02:00

1537 lines
49 KiB
JavaScript
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
import { createEngine } from "chrome://global/content/ml/EngineProcess.sys.mjs";
import {
cosSim,
KeywordExtractor,
} from "chrome://global/content/ml/NLPUtils.sys.mjs";
import {
computeCentroidFrom2DArray,
computeRandScore,
euclideanDistance,
getAccuracyStats,
kmeansPlusPlus,
silhouetteCoefficients,
} from "chrome://global/content/ml/ClusterAlgos.sys.mjs";
// Holder for lazily-resolved values: the ES module getters below and the
// preference getters defined later in this file all attach to this object.
const lazy = {};
ChromeUtils.defineESModuleGetters(lazy, {
NLP: "resource://gre/modules/NLP.sys.mjs",
MLEngineParent: "resource://gre/actors/MLEngineParent.sys.mjs",
MultiProgressAggregator: "chrome://global/content/ml/Utils.sys.mjs",
Progress: "chrome://global/content/ml/Utils.sys.mjs",
});
// Sentinel revision value: when a model-revision pref equals this string,
// getUpdatedInitData leaves the engine's modelRevision untouched.
const LATEST_MODEL_REVISION = "latest";
// Methods for suggesting tabs that are similar to current tab
export const SUGGEST_OTHER_TABS_METHODS = {
KMEANS_WITH_ANCHOR: "KMEANS_WITH_ANCHOR",
NEAREST_NEIGHBOR: "NEAREST_NEIGHBOR",
LOGISTIC_REGRESSION: "LOGISTIC_REGRESSION",
};
// Pref selecting which of SUGGEST_OTHER_TABS_METHODS smartTabGroupingForGroup uses
XPCOMUtils.defineLazyPreferenceGetter(
lazy,
"suggestOtherTabsMethod",
"browser.tabs.groups.smart.suggestOtherTabsMethod"
);
// Pref pinning the topic model revision (compared against LATEST_MODEL_REVISION)
XPCOMUtils.defineLazyPreferenceGetter(
lazy,
"topicModelRevision",
"browser.tabs.groups.smart.topicModelRevision"
);
// Pref pinning the embedding model revision (compared against LATEST_MODEL_REVISION)
XPCOMUtils.defineLazyPreferenceGetter(
lazy,
"embeddingModelRevision",
"browser.tabs.groups.smart.embeddingModelRevision"
);
// Pref holding the nearest-neighbor similarity threshold as an integer;
// findNearestNeighbors divides it by 1000 before comparing with cosine similarity
XPCOMUtils.defineLazyPreferenceGetter(
lazy,
"nearestNeighborThresholdInt",
"browser.tabs.groups.smart.nearestNeighborThresholdInt"
);
// Key under which _prepareTabData stores the text that is sent to the embedding model
const EMBED_TEXT_KEY = "combined_text";
// Supported clustering implementations
export const CLUSTER_METHODS = {
KMEANS: "KMEANS",
};
// Methods for finding similar items for an existing cluster
export const ANCHOR_METHODS = {
DRIFT: "DRIFT", // We let k-means clustering run, and find the cluster with the most anchor items
FIXED: "FIXED", // We always group with the anchor items in the 0 cluster, and never let them be reassigned
};
// Methods for handling tabs that already belong to other groups
export const PREGROUPED_HANDLING_METHODS = {
EXCLUDE: "EXCLUDE", // Tabs already in other groups are handed to k-means as preassigned indices
IGNORE: "IGNORE", // No preassigned indices are handed to k-means for already grouped tabs
};
// Expected number of objects per model; preloadAllModels uses the sum to scale
// reported progress while not all objects have been seen yet
const EXPECTED_TOPIC_MODEL_OBJECTS = 6;
const EXPECTED_EMBEDDING_MODEL_OBJECTS = 4;
// No dimensionality reduction methods are registered yet
export const DIM_REDUCTION_METHODS = {};
// Upper bound of the silhouette penalty applied in DRIFT mode; scaled by the
// fraction of anchor items missing from the chosen cluster
const MISSING_ANCHOR_IN_CLUSTER_PENALTY = 0.2;
// Maximum number of group (anchor) tabs compared against each candidate tab
const MAX_NN_GROUPED_TABS = 4;
// Maximum number of suggestions returned by smartTabGroupingForGroup
const MAX_SUGGESTED_TABS = 10;
// Topic model outputs (compared lowercased) that are replaced by an empty label
const DISSIMILAR_TAB_LABEL = "none";
const ADULT_TAB_LABEL = "adult content";
const LABELS_TO_EXCLUDE = [DISSIMILAR_TAB_LABEL, ADULT_TAB_LABEL];
// Task names used in the ML engine configurations
const ML_TASK_FEATURE_EXTRACTION = "feature-extraction";
const ML_TASK_TEXT2TEXT = "text2text-generation";
// Reasons recorded on the manager (labelReason) describing how a label was produced
const LABEL_REASONS = {
DEFAULT: "DEFAULT",
LOW_CONFIDENCE: "LOW_CONFIDENCE",
EXCLUDE: "EXCLUDE",
ERROR: "ERROR",
};
// Default configuration used by SmartTabGroupingManager when none is passed to
// its constructor. The manager's setters write into whichever config object it holds.
const SMART_TAB_GROUPING_CONFIG = {
// Engine settings for the embedding (feature extraction) model
embedding: {
dtype: "q8",
timeoutMS: 2 * 60 * 1000, // 2 minutes
taskName: ML_TASK_FEATURE_EXTRACTION,
featureId: "smart-tab-embedding",
},
// Engine settings for the topic (label generation) model
topicGeneration: {
dtype: "q8",
timeoutMS: 2 * 60 * 1000, // 2 minutes
taskName: ML_TASK_TEXT2TEXT,
featureId: "smart-tab-topic",
},
// Property names read from each tab object by _prepareTabData
dataConfig: {
titleKey: "label",
descriptionKey: "description",
},
clustering: {
dimReductionMethod: null, // Not completed.
clusterImplementation: CLUSTER_METHODS.KMEANS,
clusteringTriesPerK: 3, // Number of clustering attempts per cluster count; the lowest-inertia attempt is kept
anchorMethod: ANCHOR_METHODS.FIXED,
pregroupedHandlingMethod: PREGROUPED_HANDLING_METHODS.EXCLUDE,
pregroupedSilhouetteBoost: 2, // Relative weight of the cluster's score and all other cluster's combined
suggestOtherTabsMethod: SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR,
},
};
// these parameters were generated by training a logistic regression
// model on synthetic data. see https://github.com/mozilla/smart-tab-grouping
// for more info
const LOGISTIC_REGRESSION_PARAMS = {
// Used when the group has a label: both similarity to the label and to the anchor titles count
TITLE_WITH_GROUP_NAME: {
GROUP_SIMILARITY_WEIGHT: 6.76420017,
TITLE_SIMILARITY_WEIGHT: 2.95779555,
INTERCEPT: -3.06862155,
THRESHOLD: 0.45, // minimum probability for a candidate tab to be suggested
},
// Used when the group has no label: only similarity to the anchor titles counts
TITLE_ONLY: {
GROUP_SIMILARITY_WEIGHT: 0,
TITLE_SIMILARITY_WEIGHT: 2.50596721,
INTERCEPT: -0.54293376,
THRESHOLD: 0.6, // minimum probability for a candidate tab to be suggested
},
};
// Tabs whose URL is in this list are never offered as suggestions (see getTabsToSuggest)
const TAB_URLS_TO_EXCLUDE = [
"about:newtab",
"about:home",
"about:privatebrowsing",
"chrome://browser/content/blanktab.html",
"about:firefoxview",
];
/**
* For a given set of clusters represented by indices, returns the index of the cluster
* that has the most anchor items inside it.
*
* An anchor item is the index of a tab that is already grouped and in
* the cluster we're interested in finding more items for.
*
* @param {number[][]} groupIndices - Array of clusters represented as arrays of indices.
* @param {number[]} anchorItems - Array of anchor item indices.
* @returns {{anchorClusterIndex: number, numAnchorItemsInCluster: number}} Index of best cluster and the number of anchor items.
*/
export function getBestAnchorClusterInfo(groupIndices, anchorItems) {
  // Finds the cluster holding the most anchor items. When several clusters tie,
  // the earliest one wins. When there are no clusters at all the result is
  // { anchorClusterIndex: -1, numAnchorItemsInCluster: 0 } so callers never
  // receive an undefined count.
  const anchorItemSet = new Set(anchorItems);
  let anchorClusterIndex = -1;
  let numAnchorItemsInCluster = 0;
  groupIndices.forEach((cluster, clusterIndex) => {
    let anchorCount = 0;
    for (const itemIndex of cluster) {
      if (anchorItemSet.has(itemIndex)) {
        anchorCount++;
      }
    }
    // The first cluster is always accepted; afterwards only a strictly larger
    // count replaces the current best, which keeps the earliest cluster on ties.
    if (anchorClusterIndex === -1 || anchorCount > numAnchorItemsInCluster) {
      anchorClusterIndex = clusterIndex;
      numAnchorItemsInCluster = anchorCount;
    }
  });
  return { anchorClusterIndex, numAnchorItemsInCluster };
}
export class SmartTabGroupingManager {
/**
* Creates the SmartTabGroupingManager object.
* @param {object} config configuration options
*/
constructor(config) {
this.config = config || SMART_TAB_GROUPING_CONFIG;
}
/**
* Checks whether an ML engine is missing or has been closed
* @param {MLEngine} engine the engine to check
* @return {boolean} true if the engine has not been initialized or closed
*/
static isEngineClosed(engine) {
return !engine || engine?.engineStatus === "closed";
}
/**
* Initializes the embedding engine by running a test request
* This helps remove the init latency
*/
async initEmbeddingEngine() {
if (!SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) {
return;
}
try {
this.embeddingEngine = await this._createMLEngine(this.config.embedding);
const request = {
args: ["Test"],
options: { pooling: "mean", normalize: true },
};
this.embeddingEngine.run(request);
} catch (e) {}
}
/**
* Generates suggested tabs for an existing or provisional group
* @param {object} group active group we are adding tabs to
* @param {array} tabs list of tabs from gbrowser, some of which may be grouped in other groups
* @returns a list of suggested new tabs. If no new tabs are suggested an empty list is returned.
*/
async smartTabGroupingForGroup(group, tabs) {
// Add tabs to suggested group
const groupTabs = group.tabs;
const allTabs = tabs.filter(tab => {
// Don't include tabs already pinned
if (tab.pinned) {
return false;
}
if (!tab?.linkedBrowser?.currentURI?.spec) {
return false;
}
return true;
});
// find tabs that are part of the group
const groupIndices = groupTabs
.map(a => allTabs.indexOf(a))
.filter(a => a >= 0);
// find tabs that are part of other groups
const alreadyGroupedIndices = allTabs
.map((t, i) => (t.group ? i : -1))
.filter(a => a >= 0);
let suggestedTabs;
switch (lazy.suggestOtherTabsMethod) {
case SUGGEST_OTHER_TABS_METHODS.KMEANS_WITH_ANCHOR:
suggestedTabs = await this.generateClusters(
allTabs,
null,
null,
null,
groupIndices,
alreadyGroupedIndices
).then(clusters => {
if (!clusters) {
return [];
}
const targetCluster = clusters.clusterRepresentations.find(c =>
groupTabs.some(g => c.tabs.includes(g))
);
if (targetCluster) {
// Return only tabs not already grouped
return targetCluster.tabs.filter(t => !t.group);
}
return [];
});
break;
case SUGGEST_OTHER_TABS_METHODS.LOGISTIC_REGRESSION:
suggestedTabs = await this.findSimilarTabsLogisticRegression({
allTabs,
groupedIndices: groupIndices,
alreadyGroupedIndices,
groupLabel: group?.label,
});
break;
case SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR:
default:
// find nearest neighbors to current group
suggestedTabs = await this.findNearestNeighbors({
allTabs,
groupedIndices: groupIndices,
alreadyGroupedIndices,
groupLabel: group?.label,
});
}
return suggestedTabs.slice(0, MAX_SUGGESTED_TABS);
}
/**
* Get tabs that need to be included in suggestions
* @param {array} allTabs all tabs that are part of the window
* @param {array} groupedIndices indices of tabs that are already part of the group
* @param {array} alreadyGroupedIndices indices of tabs that are part of other groups
* @returns {array} tabs indices to be considered for suggestions
*/
getTabsToSuggest(allTabs, groupedIndices, alreadyGroupedIndices) {
// tabs to be excluded
// indices of all tabs that should be excluded (with duplicates)
const tabURLIndicesToExclude = allTabs
.map((at, index) => (TAB_URLS_TO_EXCLUDE.includes(at.url) ? index : -1))
.filter(index => index !== -1);
const excludedTabIndices = [
...groupedIndices,
...alreadyGroupedIndices,
...tabURLIndicesToExclude,
];
// tabs to be included
return allTabs
.map((_, index) => index)
.filter(i => !excludedTabIndices.includes(i));
}
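/**
* Generates tabs similar to a grouped list of tabs using nearest-neighbor cosine similarity
* @param {object} options
* @param {array} options.allTabs all tabs that are part of the window
* @param {array} options.groupedIndices indices of tabs that are already part of the group
* @param {array} options.alreadyGroupedIndices indices of tabs that are part of other groups
* @param {string} options.groupLabel name of group if present
* @param {number} options.thresholdMills nearest neighbor similarity threshold in thousandths (divided by 1000 before comparing)
* @param {array} options.precomputedEmbeddings embeddings to reuse; generated from the tabs when empty
* @param {number} options.depth remaining recursion budget; the search recurses only when this is 1
* @returns a list of suggested tabs that are similar to the groupedIndices tabs
*/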
async findNearestNeighbors({
allTabs,
groupedIndices,
alreadyGroupedIndices,
groupLabel = "",
thresholdMills = lazy.nearestNeighborThresholdInt,
precomputedEmbeddings = [],
depth = 0,
}) {
// get embeddings for all the tabs
const tabData = await this._prepareTabData(allTabs);
let embeddings = precomputedEmbeddings;
if (precomputedEmbeddings.length === 0) {
embeddings = await this._generateEmbeddings(
tabData.map((td, index) => {
let text = SmartTabGroupingManager.preprocessText(td[EMBED_TEXT_KEY]);
// augment with group name if it's present
if (groupLabel && groupedIndices.includes(index)) {
text = `${groupLabel.slice(0, 100)}. ${text}`;
}
return text;
})
);
}
// tabs that need to be assigned after filtering
const tabsToAssignIndices = this.getTabsToSuggest(
tabData,
groupedIndices,
alreadyGroupedIndices
);
let closestTabs = [];
const similarTabsIndices = [];
for (let i = 0; i < tabsToAssignIndices.length; i++) {
let closestScore = null;
for (
let j = 0;
j < Math.min(groupedIndices.length, MAX_NN_GROUPED_TABS);
j++
) {
const cosineSim = cosSim(
embeddings[tabsToAssignIndices[i]],
embeddings[groupedIndices[j]]
);
if (!closestScore || cosineSim > closestScore) {
closestScore = cosineSim;
}
}
// threshold could also be set via a nimbus experiment, in which case
// it will be an int <= 1000
if (closestScore > thresholdMills / 1000) {
closestTabs.push([allTabs[tabsToAssignIndices[i]], closestScore]);
similarTabsIndices.push(tabsToAssignIndices[i]);
}
}
closestTabs.sort((a, b) => b[1] - a[1]);
closestTabs = closestTabs.map(t => t[0]);
// recurse once if the initial call only had a single tab
// and we found at least 1 similar tab - this improves recall
if (groupedIndices.length === 1 && !!closestTabs.length && depth === 1) {
const recurseSimilarTabs = await this.findNearestNeighbors({
allTabs,
groupedIndices: similarTabsIndices,
alreadyGroupedIndices: alreadyGroupedIndices.concat(groupedIndices),
groupLabel,
thresholdMills,
precomputedEmbeddings: embeddings,
depth: depth - 1,
});
closestTabs = closestTabs.concat(recurseSimilarTabs);
}
return closestTabs;
}
/**
* Calculates the average similarity between the anchor embeddings and the candidate embeddings
* @param {list[Number]} anchorEmbeddings title embeddings for the anchor tabs
* @param {list[Number]} candidateEmbeddings title embeddings for the candidate tabs
*/
getAverageSimilarity(anchorEmbeddings, candidateEmbeddings) {
let averageSimilarities = [];
for (let candidate_embedding of candidateEmbeddings) {
let averageSimilarity = 0;
for (let anchor_embedding of anchorEmbeddings) {
averageSimilarity += cosSim(candidate_embedding, anchor_embedding);
}
averageSimilarities.push(averageSimilarity / anchorEmbeddings.length);
}
return averageSimilarities;
}
/**
* Calculates the sigmoid value of the input
*
* @param {Number} z
* @return {Number}
*/
sigmoid(z) {
return 1 / (1 + Math.exp(-z));
}
/**
* Calculates the probability using the linear combination of the parameters
*
* @param {Number} groupSimilarity how similar a candidate tab is to the group name
* @param {Number} titleSimilarity how similar a candidate tab is to the anchors
* @param {Object} params the logistic regression weights assigned to each parameter
* @return {Number}
*/
calculateProbability(groupSimilarity, titleSimilarity, params) {
return this.sigmoid(
groupSimilarity * params.GROUP_SIMILARITY_WEIGHT +
titleSimilarity * params.TITLE_SIMILARITY_WEIGHT +
params.INTERCEPT
);
}
/**
* Calculates the probabilities given two lists of the same length
*
* @param {list[Number]} groupSimilarities cosine similarity between the candidate tabs and the group name
* @param {list[Number]} titleSimilarities average cosine similarity between the candidate tabs and anchors
* @return {list[Number]} probabilities for each candidate tab
*/
calculateAllProbabilities(groupSimilarities, titleSimilarities) {
const hasGroupSimilarity = Boolean(groupSimilarities);
let probabilities = [];
for (let i = 0; i < titleSimilarities.length; i++) {
probabilities.push(
this.calculateProbability(
hasGroupSimilarity ? groupSimilarities[i] : 0,
titleSimilarities[i],
hasGroupSimilarity
? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME
: LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY
)
);
}
return probabilities;
}
/**
* Generates similar tabs to a grouped list of tabs using a logistic regression "model"
*
* @param {array} allTabs all tabs that are part of the window
* @param {array} groupedIndices indices of tabs that are already part of the group
* @param {array} alreadyGroupedIndices indices of tabs that are part of other groups
* @param {string} groupLabel name of group if present
*/
async findSimilarTabsLogisticRegression({
allTabs,
groupedIndices,
alreadyGroupedIndices,
groupLabel = "",
}) {
const tabData = await this._prepareTabData(allTabs);
const candidateIndices = this.getTabsToSuggest(
tabData,
groupedIndices,
alreadyGroupedIndices
);
const candidateTabsData = candidateIndices.map(ci => allTabs[ci]);
const candidateTabsPrep = await this._prepareTabData(candidateTabsData);
const anchorTabsPrep = groupedIndices
.map(gi => tabData[gi])
.slice(0, MAX_NN_GROUPED_TABS);
// generate embeddings for both anchor and candidate titles
const titleEmbeddings = await this._generateEmbeddings(
anchorTabsPrep
.concat(candidateTabsPrep)
.map(tab => SmartTabGroupingManager.preprocessText(tab[EMBED_TEXT_KEY]))
);
let groupEmbedding;
let groupSimilarities;
if (groupLabel) {
groupEmbedding = await this._generateEmbeddings([groupLabel]);
// calculate similarity between the group and the candidate tabs if group name is present
groupSimilarities = this.getAverageSimilarity(
groupEmbedding,
titleEmbeddings.slice(anchorTabsPrep.length)
);
}
// calculate the similarity between the anchors and candidate titles
const titleSimilarities = this.getAverageSimilarity(
titleEmbeddings.slice(0, anchorTabsPrep.length),
titleEmbeddings.slice(anchorTabsPrep.length)
);
const candidateProbabilities = this.calculateAllProbabilities(
groupSimilarities,
titleSimilarities
);
// get proper params depending on group name availability
const probabilityThreshold = groupEmbedding
? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME.THRESHOLD
: LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY.THRESHOLD;
return (
candidateTabsData
// combine candidate tabs with corresponding probabilities
.map((ct, index) => ({
ct,
prob: candidateProbabilities[index],
}))
// only keep those that are within the probability threshold
.filter(item => item.prob >= probabilityThreshold)
// ensure the highest probability candidates come first in the list
.sort((a, b) => b.prob - a.prob)
// keep the tabs only
.map(item => item.ct)
);
}
/**
* This function will terminate a grouping or label generation in progress
* It is currently not implemented.
*/
terminateProcess() {
// Intentionally empty for now.
// TODO - terminate AI processes. This method will be
// called when tab grouping panel is closed.
}
/**
* Changes the clustering method. Must be one of supported methods.
* @param {string} method Name of method
*/
setClusteringMethod(method) {
if (!(method in CLUSTER_METHODS)) {
throw new Error(`Clustering method ${method} not supported`);
}
this.config.clustering.clusterImplementation = method;
}
/**
* Set the technique for clustering when certain tabs are already assigned to groups
*
* @param {string} method which is one of ANCHOR_METHODS
*/
setAnchorMethod(method) {
if (!(method in ANCHOR_METHODS)) {
throw new Error(`Clustering anchor method ${method} not supported`);
}
this.config.clustering.anchorMethod = method;
}
setSilBoost(boost) {
this.config.clustering.pregroupedSilhouetteBoost = boost;
}
/**
* Sets method to reduce dimensionality of embeddings prior to clustering
* @param {string} method Name of method
*/
setDimensionReductionMethod(method) {
if (method && !(method in DIM_REDUCTION_METHODS)) {
throw new Error(`Dimension reduction method ${method} not supported`);
}
this.config.clustering.dimReductionMethod = method;
}
/**
* Sets the field name of the title of a page to be used when clustering or generating embeddings
* This is useful when clustering test data that is not a tab object
* @param {string} titleKey KEY FOR THE TITLE
*/
setDataTitleKey(titleKey) {
this.config.dataConfig.titleKey = titleKey;
}
/**
* Logging hook for debugging. Currently a no-op: the message is discarded
* @param {string} _msg Message to log
*/
// Intentionally a no-op: the message is ignored.
log(_msg) {}
/**
* Prepares data to be used by the ml models
* @param {Object[]} tabList list of tabs in the current window
* @param {boolean} useDescription whether we should combined the title and description
* @return {Promise<*[Object]>}
* @private
*/
async _prepareTabData(tabList, useDescription = false) {
const titleKey = this.config.dataConfig.titleKey;
const descriptionKey = this.config.dataConfig.descriptionKey;
const structuredData = [];
for (let tab of tabList) {
const description =
useDescription && descriptionKey && tab[descriptionKey];
let textToEmbed;
if (description) {
textToEmbed = tab[titleKey] + " " + description;
} else {
textToEmbed = tab[titleKey] || "Unknown";
}
structuredData.push({
[EMBED_TEXT_KEY]: textToEmbed,
title: tab[titleKey],
description,
url: tab?.linkedBrowser?.currentURI?.spec,
});
}
return structuredData;
}
/**
* Get updated config for the ml engine
*
* @param {object} initData
* @param {string} featureId
* @return {*}
*/
static getUpdatedInitData(initData, featureId) {
// we're setting a specific modelRevision through about:config or Nimbus
if (
featureId === SMART_TAB_GROUPING_CONFIG.topicGeneration.featureId &&
lazy.topicModelRevision !== LATEST_MODEL_REVISION
) {
initData.modelRevision = lazy.topicModelRevision;
} else if (
featureId === SMART_TAB_GROUPING_CONFIG.embedding.featureId &&
lazy.embeddingModelRevision !== LATEST_MODEL_REVISION
) {
initData.modelRevision = lazy.embeddingModelRevision;
}
return initData;
}
/**
* Creates an ML engine for a given config.
* @param {*} engineConfig
* @param {function} progressCallback
* @returns MLEngine
*/
async _createMLEngine(engineConfig, progressCallback) {
const {
featureId,
engineId,
dtype,
taskName,
timeoutMS,
modelId,
modelRevision,
} = engineConfig;
let initData = {
featureId,
engineId,
dtype,
taskName,
timeoutMS,
modelId,
modelRevision,
};
initData = SmartTabGroupingManager.getUpdatedInitData(initData, featureId);
return await createEngine(initData, progressCallback);
}
/**
* Generates embeddings for a list of texts, creating the embedding engine first if needed
* @param {string[]} textToEmbedList List of texts to embed
* @returns {Promise<*[]>} List of embeddings (2d array)
* @private
*/
async _generateEmbeddings(textToEmbedList) {
const inputData = {
inputArgs: textToEmbedList,
runOptions: {
pooling: "mean",
normalize: true,
},
};
if (SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) {
this.embeddingEngine = await this._createMLEngine(this.config.embedding);
}
const request = {
args: [inputData.inputArgs],
options: inputData.runOptions,
};
return await this.embeddingEngine.run(request);
}
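/**
* Clusters embeddings with the method
* selected in the config of the class
* @param {object} options
* @param options.tabs List of tabs as array
* @param options.embeddings Precomputed embeddings for the tabs as two dimensional array
* @param options.k Desired number of clusters. Tries a range of sizes if 0.
* @param {function} options.randomFunc Optional seeded random number generator for testing
* @param {number[]} options.anchorIndices Indices of the tabs that are already in the target group, if any
* @param {number[]} options.alreadyGroupedIndices Indices of tabs that are part of other groups
* @returns {SmartTabGroupingResult}
* @private
*/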
_clusterEmbeddings({
tabs,
embeddings,
k,
randomFunc,
anchorIndices,
alreadyGroupedIndices = [],
}) {
// Overview: for every candidate cluster count, run k-means several times and
// keep the attempt with the lowest centroid inertia; then compare the cluster
// counts by (adjusted) average silhouette score and return the best one.
let allItems;
// In FIXED mode the anchors are pinned to cluster 0. Note that any truthy
// anchorIndices value, including an empty array, enables this.
const freezeAnchorsInZeroCluster =
anchorIndices &&
this.config.clustering.anchorMethod == ANCHOR_METHODS.FIXED;
const dimReductionMethod = this.config.clustering.dimReductionMethod;
switch (dimReductionMethod) {
default:
// Dimensionality reduction support is landing very soon.
break;
}
// A missing or zero k means "search over a range of cluster counts".
k = k || 0;
// [startK, endK) is the range of cluster counts to try; a fixed k yields
// exactly one iteration.
let startK = k;
let endK = k + 1;
if (!k) {
startK = 2;
// Find a reasonable max # of clusters
endK =
Math.min(
Math.floor(Math.log(embeddings.length) * 2.0),
embeddings.length
) + 1;
}
let bestResult;
// Sentinel below any score produced here, so the first cluster count wins initially.
let bestResultSilScore = -100.0;
let bestResultCenterCluster = 0;
const clusteringMethod = this.config.clustering.clusterImplementation;
const clusteringTriesPerK = this.config.clustering.clusteringTriesPerK;
for (let curK = startK; curK < endK; curK++) {
let bestItemsForK;
// Large sentinel so that the first attempt is normally accepted.
let bestInertiaForK = 500000000000;
for (let j = 0; j < clusteringTriesPerK; j++) {
switch (clusteringMethod) {
case CLUSTER_METHODS.KMEANS:
allItems = kmeansPlusPlus({
data: embeddings,
k: curK,
// NOTE(review): the meaning of 0 is defined by kmeansPlusPlus - confirm there
maxIterations: 0,
randomFunc,
anchorIndices,
// Already grouped tabs are only passed on in EXCLUDE mode.
preassignedIndices:
this.config.clustering.pregroupedHandlingMethod ===
PREGROUPED_HANDLING_METHODS.EXCLUDE
? alreadyGroupedIndices
: [],
freezeAnchorsInZeroCluster,
});
break;
default:
throw Error("Clustering implementation not supported");
}
// Temporary result (without tabs) used only to measure this attempt.
const tempResult = new SmartTabGroupingResult({
indices: allItems,
embeddings,
config: this.config,
});
const inertia = tempResult.getCentroidInertia();
if (inertia < bestInertiaForK) {
bestInertiaForK = inertia;
bestItemsForK = tempResult;
}
}
// NOTE(review): bestItemsForK stays undefined when no attempt beat the
// sentinel (e.g. clusteringTriesPerK is 0), and the access below then throws.
const silScores = silhouetteCoefficients(
embeddings,
bestItemsForK.indices
);
if (
freezeAnchorsInZeroCluster &&
this.config.clustering.pregroupedSilhouetteBoost > 0
) {
// Boost silhouette score of target cluster when we are grouping around an existing cluster
// pregroupedSilhouetteBoost indicates the relative weight of the cluster's score and all other cluster's combined
silScores[0] *= this.config.clustering.pregroupedSilhouetteBoost;
}
let avgSil = silScores.reduce((p, c) => p + c, 0) / silScores.length;
let curAnchorCluster = 0;
if (anchorIndices && !freezeAnchorsInZeroCluster) {
// Anchors were free to move: use the cluster holding most of them and
// penalize the score by the fraction of anchors that ended up elsewhere.
const { anchorClusterIndex, numAnchorItemsInCluster } =
getBestAnchorClusterInfo(bestItemsForK.indices, anchorIndices);
curAnchorCluster = anchorClusterIndex;
const penalty =
(MISSING_ANCHOR_IN_CLUSTER_PENALTY *
(anchorIndices.length - numAnchorItemsInCluster)) /
anchorIndices.length;
avgSil -= penalty;
}
// Keep the cluster count with the best adjusted silhouette score.
if (avgSil > bestResultSilScore) {
bestResultSilScore = avgSil;
bestResult = bestItemsForK.indices;
bestResultCenterCluster = curAnchorCluster;
}
}
// Final result for the winning cluster count, this time including the tabs.
const result = new SmartTabGroupingResult({
indices: bestResult,
tabs,
embeddings,
config: this.config,
});
if (anchorIndices) {
result.setAnchorClusterIndex(
freezeAnchorsInZeroCluster ? 0 : bestResultCenterCluster
); // In our k-means clustering implementation anchor cluster is always first
if (!freezeAnchorsInZeroCluster) {
result.adjustClusterForAnchors(anchorIndices);
}
}
return result;
}
/**
* Generate a label for tabs in a group created by the user
*
* @param tabs tabs that are currently in the group
* @param otherTabs tabs in the window not part of the group
* @return {Promise<null|string|string|*>}
*/
async getPredictedLabelForGroup(tabs, otherTabs) {
const clusters = this.createStaticCluster(tabs);
const otherClusters = this.createStaticCluster(otherTabs);
let predictedLabel;
try {
// function below modifies "clusters" object
await this.generateGroupLabels(clusters, otherClusters);
predictedLabel = clusters.clusterRepresentations[0].predictedTopicLabel;
} catch (e) {
this.labelReason = LABEL_REASONS.ERROR;
predictedLabel = "";
}
return predictedLabel;
}
/**
* Generates clusters for a given list of tabs using precomputed embeddings or newly generated ones.
*
* @param {Object[]} tabList - List of tab objects to be clustered.
* @param {number[][]} [precomputedEmbeddings] - Precomputed embeddings for tab titles and descriptions.
* @param {number} numClusters - Number of clusters to form.
* @param {Function} randFunc - Random function used for clustering initialization.
* @param {number[]} [anchorIndices=[]] - Indices of anchor tabs that should be prioritized in clustering.
* @param {number[]} [alreadyGroupedIndices=[]] - Indices of tabs that are already assigned to groups.
* @returns {SmartTabGroupingResult} - The best clustering result based on centroid inertia.
*/
async generateClusters(
tabList,
precomputedEmbeddings,
numClusters,
randFunc,
anchorIndices = [],
alreadyGroupedIndices = []
) {
numClusters = numClusters ?? 0;
const structuredData = await this._prepareTabData(tabList);
// embeddings for title and description
if (precomputedEmbeddings) {
this.docEmbeddings = precomputedEmbeddings;
} else {
this.docEmbeddings = await this._generateEmbeddings(
structuredData.map(a => a[EMBED_TEXT_KEY])
);
}
let bestResultCluster;
let bestResultDistance = 50000000.0;
const NUM_RUNS = 1;
for (let i = 0; i < NUM_RUNS; i++) {
const curResult = this._clusterEmbeddings({
tabs: tabList,
embeddings: this.docEmbeddings,
k: numClusters,
randomFunc: randFunc,
anchorIndices,
alreadyGroupedIndices,
});
const distance = curResult.getCentroidInertia();
if (distance < bestResultDistance) {
bestResultDistance = distance;
bestResultCluster = curResult;
}
}
return bestResultCluster;
}
/**
* Create static cluster from a list of tabs. A single tab is Ok. Returns null when no
* tab list is given; an empty list still yields a result holding one empty cluster
* @param tabs
* @returns {SmartTabGroupingResult} groupingResult
*/
createStaticCluster(tabs) {
if (!tabs) {
return null;
}
return new SmartTabGroupingResult({
indices: [Array.from({ length: tabs.length }, (_, i) => i)],
tabs,
config: this.config,
});
}
/***
* Utility function that loads all required engines for Smart Tab Grouping and any dependent models
* @param {(progress: { percentage: number }) => void} progressCallback callback function to call.
* Callback passes a dict with percentage indicating best effort 0.0-100.0 progress in model download.
*/
async preloadAllModels(progressCallback) {
let previousProgress = -1;
const expectedObjects =
EXPECTED_TOPIC_MODEL_OBJECTS + EXPECTED_EMBEDDING_MODEL_OBJECTS;
// TODO - Find a way to get these fields. Add as a transformers js callback or within remotesettings
const UPDATE_THRESHOLD_PERCENTAGE = 0.5;
const ONE_MB = 1024 * 1024;
const START_THRESHOLD_BYTES = ONE_MB * 0.2;
const mutliProgressAggregator = new lazy.MultiProgressAggregator({
progressCallback: ({ progress, totalLoaded, metadata }) => {
if (totalLoaded < START_THRESHOLD_BYTES) {
progress = 0.0;
} else {
const numObjSeen = metadata.totalObjectsSeen || 0;
if (numObjSeen > 0 && numObjSeen < expectedObjects) {
// When starting to download we may still be getting configs and not have all the data
progress *= numObjSeen / expectedObjects;
}
if (progress > 100) {
progress = 100;
}
}
if (
Math.abs(previousProgress - progress) > UPDATE_THRESHOLD_PERCENTAGE
) {
// Update only once changes are above a threshold to avoid throttling the UI with events.
progressCallback({
percentage: progress,
});
previousProgress = progress;
}
},
watchedTypes: [
lazy.Progress.ProgressType.DOWNLOAD,
lazy.Progress.ProgressType.LOAD_FROM_CACHE,
],
});
const [topicEngine, embeddingEngine] = await Promise.all([
this._createMLEngine(
this.config.topicGeneration,
mutliProgressAggregator?.aggregateCallback.bind(
mutliProgressAggregator
) || null
),
this._createMLEngine(
this.config.embedding,
mutliProgressAggregator?.aggregateCallback.bind(
mutliProgressAggregator
) || null
),
]);
this.topicEngine = topicEngine;
this.embeddingEngine = embeddingEngine;
}
/**
* Generate model input from keywords and documents
* @param {string []} keywords
* @param {string []} documents
*/
createModelInput(keywords, documents) {
if (!keywords || keywords.length === 0) {
return `Topic from keywords: titles: \n${documents.join(" \n")}`;
}
return `Topic from keywords: ${keywords.join(", ")}. titles: \n${documents.join(" \n")}`;
}
/**
* One artifact of the LLM output is that sometimes words are duplicated
* This function cuts the phrase when it sees the first duplicate word.
* Handles simple singular / plural duplicates (-s only).
* @param {string} phrase Input phrase
* @returns {string} phrase cut before any duplicate word
*/
static cutAtDuplicateWords(phrase) {
if (!phrase.length) {
return phrase;
}
const wordsSet = new Set();
const wordList = phrase.split(" ");
for (let i = 0; i < wordList.length; i++) {
let baseWord = wordList[i].toLowerCase();
if (baseWord.length > 3) {
if (baseWord.slice(-1) === "s") {
baseWord = baseWord.slice(0, -1);
}
}
if (wordsSet.has(baseWord)) {
// We are seeing a baseWord word. Exit with just the words so far and don't
// add any new words
return wordList.slice(0, i).join(" ");
}
wordsSet.add(baseWord);
}
return phrase; // return original phrase
}
/**
* Removes trailing domain-related text such as '... - Mail' or '... | News'
* If there's not enough information remaining after, we keep the text as is
* @param {string} text tab title with potential domain information
* @return {string}
*/
static preprocessText(text) {
// Matches 'xyz - Domain' or 'xyz | Domain'
// with a space before and after delimiter
// or if there are multiple delimiters next to each other
const delimiters = /(?<=\s)[|-]+(?=\s)/;
const splitText = text.split(delimiters);
// ensure there's enough info without the last element
const hasEnoughInfo =
!!splitText.length && splitText.slice(0, -1).join(" ").length > 5;
// domain related texts are usually shorter, this takes care of the most common cases
const isPotentialDomainInfo =
splitText.length > 1 && splitText[splitText.length - 1].length < 20;
// If both conditions are met, remove the last chunk, filter out empty strings,
// join on space, trim, and lowercase
if (hasEnoughInfo && isPotentialDomainInfo) {
return splitText
.slice(0, -1) // everything except the last element
.map(t => t.trim())
.filter(Boolean) // remove empty strings
.join(" ") // join with spaces
.trim(); // remove leading/trailing spaces
}
// Otherwise, just return the text
return text;
}
/**
* Postprocessing of raw output from Topic Model ML Engine
* @param {string | undefined} topic Raw topic phrase from topic model or undefined in case of an error
*/
processTopicModelResult(topic) {
let basicResult = (topic || "").trim();
if (!basicResult) {
this.labelReason = LABEL_REASONS.LOW_CONFIDENCE;
}
if (LABELS_TO_EXCLUDE.includes(basicResult.toLowerCase())) {
this.labelReason = LABEL_REASONS.EXCLUDE;
return "";
}
return SmartTabGroupingManager.cutAtDuplicateWords(basicResult);
}
/**
* Add titles to a cluster in a SmartTabGroupingResult using generative techniques
* Currently this function only works with a single target group, and a separate
* item that represents all other ungrouped tabs.
*
* In the future this may be updated to more generally find labels for a set of clusters.
* @param {SmartTabGroupingResult} groupingResult The cluster we are generating the label for
* @param {SmartTabGroupingResult} otherGroupingResult A 'made up' cluster representing all other tabs in the window
*/
async generateGroupLabels(groupingResult, otherGroupingResult = null) {
const { keywords, documents } =
groupingResult.getRepresentativeDocsAndKeywords(
otherGroupingResult
? otherGroupingResult.getRepresentativeDocuments()
: []
);
const inputArgs = this.createModelInput(
keywords ? keywords[0] : [],
documents
);
const requestInfo = {
inputArgs,
runOptions: {
max_length: 6,
},
};
if (SmartTabGroupingManager.isEngineClosed(this.topicEngine)) {
this.topicEngine = await this._createMLEngine(
this.config.topicGeneration
);
}
const request = {
args: [requestInfo.inputArgs],
options: requestInfo.runOptions,
};
const genLabelResults = await this.topicEngine.run(request);
genLabelResults.forEach((genResult, genResultIndex) => {
groupingResult.clusterRepresentations[
genResultIndex
].predictedTopicLabel = this.processTopicModelResult(
genResult.generated_text
);
});
}
getLabelReason() {
return this.labelReason || LABEL_REASONS.DEFAULT;
}
/**
* Generates glean metrics for ml smart tab label / topic.
* This is currently called when the user saves or cancels the "suggest label" flow.
*
* @param {string} action "save" or "cancel"
* @param {number} numTabsInGroup Number of tabs used to generate the label
* @param {string} mlLabel ML generated label for the tab group
* @param {string} userLabel User saved label for the tab group
* @param {string} id The id of the group
*/
async handleLabelTelemetry({
action,
numTabsInGroup,
mlLabel,
userLabel,
id = "",
}) {
const { [ML_TASK_TEXT2TEXT]: topicEngineConfig } =
await this.getEngineConfigs();
const labelReason = this.getLabelReason();
Glean.tabgroup.smartTabTopic.record({
action,
tabs_in_group: numTabsInGroup,
ml_label_length: (mlLabel || "").length,
user_label_length: (userLabel || "").length,
levenshtein_distance: lazy.NLP.levenshtein(
userLabel || "",
mlLabel || ""
),
model_revision: topicEngineConfig.modelRevision || "",
id,
label_reason: labelReason,
});
this.labelReason = LABEL_REASONS.DEFAULT;
}
/**
* Generates glean metrics for ml smart tab suggestions.
* This is currently called when the user saves or cancels the "suggest other tabs" flow
*
* @param {string} action "save" or "cancel"
* @param {number} numTabsInWindow Number of tabs in the current window
* @param {number} numTabsInGroup Number of tabs in the current group
* @param {number} numTabsSuggested Number of tabs suggested by the model
* @param {number} numTabsApproved Number of tabs approved by the user
* @param {number} numTabsRemoved Number of tabs removed by the user
* @param {string} id The id of the group
*/
async handleSuggestTelemetry({
action,
numTabsInWindow,
numTabsInGroup,
numTabsSuggested,
numTabsApproved,
numTabsRemoved,
id = "",
}) {
const { [ML_TASK_FEATURE_EXTRACTION]: embeddingEngineConfig } =
await this.getEngineConfigs();
Glean.tabgroup.smartTabSuggest.record({
action,
tabs_in_window: numTabsInWindow,
tabs_in_group: numTabsInGroup,
tabs_suggested: numTabsSuggested,
tabs_approved: numTabsApproved,
tabs_removed: numTabsRemoved,
model_revision: embeddingEngineConfig.modelRevision || "",
id,
});
}
/**
* Gets config that engine was initialized with
*
* @return {Promise<object>} Inference options keyed by ML_TASK_TEXT2TEXT and ML_TASK_FEATURE_EXTRACTION
*/
async getEngineConfigs() {
if (!this.topicEngineConfig) {
this.topicEngineConfig = await lazy.MLEngineParent.getInferenceOptions(
this.config.topicGeneration.featureId,
this.config.topicGeneration.taskName
);
}
if (!this.embeddingEngineConfig) {
this.embeddingEngineConfig =
await lazy.MLEngineParent.getInferenceOptions(
this.config.embedding.featureId,
this.config.embedding.taskName
);
}
return {
[ML_TASK_TEXT2TEXT]: this.topicEngineConfig,
[ML_TASK_FEATURE_EXTRACTION]: this.embeddingEngineConfig,
};
}
}
/**
 * Result of a clustering pass: the index lists that define each cluster, the tab and
 * embedding lists those indices point into, and one ClusterRepresentation per cluster.
 */
export class SmartTabGroupingResult {
  // Index of the cluster that has the original items we're building the clustering around.
  // -1 means that no anchor cluster has been set.
  #anchorClusterIndex = -1;

  /**
   * Creates a result from indices and complete tab and embedding lists.
   * Empty clusters are dropped before the cluster representations are built.
   *
   * @param {object} options
   * @param {number[][]} [options.indices] indices of clusters (eg [[2,4], [1], [3]])
   * @param {object[]} options.tabs 1D array of tabs
   * @param {number[][]} options.embeddings Two dimensional array of embeddings
   * @param {object} options.config Cluster config
   */
  constructor({ indices = [], tabs, embeddings, config }) {
    this.embeddingItems = embeddings;
    this.config = config;
    this.indices = indices.filter(subArray => !!subArray.length); // Cleanup any empty clusters
    this.tabItems = tabs;
    this._buildClusterRepresentations();
  }

  /**
   * Builds the list of ClusterRepresentations, one per entry of this.indices.
   * Tabs and embeddings are only mapped when the corresponding list was provided.
   */
  _buildClusterRepresentations() {
    this.clusterRepresentations = this.indices.map(subClusterIndices => {
      const tabItemsMapped =
        this.tabItems && subClusterIndices.map(idx => this.tabItems[idx]);
      const embeddingItemsMapped =
        this.embeddingItems &&
        subClusterIndices.map(idx => this.embeddingItems[idx]);
      return new ClusterRepresentation({
        tabs: tabItemsMapped,
        embeddings: embeddingItemsMapped,
        config: this.config,
      });
    });
  }

  /**
   * Returns up to 10 documents, read from this.tabItems through the configured
   * title key (config.dataConfig.titleKey), in the order of the tab list.
   *
   * @return {string[]} Titles that represent the tabs of this result
   */
  getRepresentativeDocuments() {
    if (!this.documents) {
      this.documents = this.tabItems.map(
        t => t[this.config.dataConfig.titleKey]
      );
    }
    // set a limit of 10 for now
    return this.documents.slice(0, 10);
  }

  /**
   * Returns the keywords and documents for the cluster, computing them if needed.
   * No keywords are returned when there is at most one document.
   *
   * @param {string[]} otherDocuments documents of other clusters that we'll compare against
   * @return {{keywords: Array, documents: string[]}} keywords and documents that represent the cluster
   */
  getRepresentativeDocsAndKeywords(otherDocuments = []) {
    this.documents = this.getRepresentativeDocuments();
    if (!this.keywords) {
      // Only the first 3 documents are used on our side of the keyword extraction.
      const joinedDocs = this.documents.slice(0, 3).join(" ");
      const otherDocs = otherDocuments.join(" ");
      if (this.documents.length > 1) {
        const keywordExtractor = new KeywordExtractor();
        this.keywords = keywordExtractor.fitTransform([joinedDocs, otherDocs]);
      } else {
        this.keywords = [];
      }
    }
    return { keywords: this.keywords, documents: this.documents };
  }

  /**
   * Records which cluster is the anchor cluster.
   *
   * @param {number} index Position of the anchor cluster in this.indices
   */
  setAnchorClusterIndex(index) {
    this.#anchorClusterIndex = index;
  }

  /**
   * Get the cluster we originally are grouping around (finding additional items)
   *
   * @returns {ClusterRepresentation|null} null when no anchor cluster has been set
   */
  getAnchorCluster() {
    if (this.#anchorClusterIndex === -1) {
      return null;
    }
    return this.clusterRepresentations[this.#anchorClusterIndex];
  }

  /**
   * Given the indices that we were clustering around, make sure they are all in the
   * anchor cluster. Our generic k-means clustering might have them in separate groups.
   * Does nothing when there are no anchor indices, or when the anchor cluster index
   * does not point at an existing cluster (e.g. it was never set).
   *
   * @param {number[]} anchorIndices Item indices that must end up in the anchor cluster
   */
  adjustClusterForAnchors(anchorIndices) {
    if (!anchorIndices.length) {
      return;
    }
    const anchorCluster = this.indices[this.#anchorClusterIndex];
    if (!anchorCluster) {
      // Without a valid anchor cluster there is nowhere to move the anchors to.
      // Previously this case threw a TypeError as soon as an anchor was found.
      return;
    }
    const anchorSet = new Set(anchorIndices);
    for (let i = 0; i < this.indices.length; i++) {
      if (i === this.#anchorClusterIndex) {
        continue;
      }
      this.indices[i] = this.indices[i].filter(item => {
        if (anchorSet.has(item)) {
          anchorCluster.push(item);
          return false;
        }
        return true;
      });
    }
    // The index lists changed, so the derived representations are rebuilt.
    this._buildClusterRepresentations();
  }

  /**
   * Prints information about the clusters
   */
  printClusters() {
    for (const cluster of this.clusterRepresentations) {
      cluster.print();
    }
  }

  /**
   * Computes the inertia of the clustering: the sum over all clusters of their
   * total squared centroid distance.
   *
   * @returns {number}
   */
  getCentroidInertia() {
    return this.clusterRepresentations.reduce(
      (runningTotalDistance, rep) =>
        runningTotalDistance + rep.computeTotalSquaredCentroidDistance(),
      0
    );
  }

  /**
   * Converts the cluster representations to a flat list of tab copies, with a clusterID
   * key in each copy holding the id of the cluster it was part of.
   *
   * @returns {object[]}
   */
  _flatMapItemsInClusters() {
    return this.clusterRepresentations.flatMap(clusterRep =>
      clusterRep.tabs.map(tab => ({
        ...tab,
        clusterID: clusterRep.clusterID,
      }))
    );
  }

  /**
   * Get rand score which describes the accuracy versus a user labeled
   * annotation on the dataset. Requires the dataset to be labeled.
   *
   * @param {string} labelKey Key in the tabs that represents a unique label ID for the cluster.
   * @returns {number} The rand score.
   */
  getRandScore(labelKey = "annotatedLabel") {
    const combinedItems = this._flatMapItemsInClusters();
    return computeRandScore(combinedItems, "clusterID", labelKey);
  }

  /**
   * Get accuracy stats for a specific cluster. The cluster that contains the first tab
   * labeled with clusterValue is treated as the predicted cluster for that label.
   * When no tab carries the label, the lookup yields no cluster instead of throwing.
   *
   * @param {string} labelKey Key in the tabs that represents a unique label ID for the cluster.
   * @param {*} clusterValue Label value of the cluster we are comparing
   * @returns The value produced by getAccuracyStats for the four confusion counts.
   */
  getAccuracyStatsForCluster(labelKey = "annotatedLabel", clusterValue) {
    const combinedItems = this._flatMapItemsInClusters();
    const keyClusterId = combinedItems.find(
      a => a[labelKey] === clusterValue
    )?.clusterID;
    let truePositives = 0;
    let trueNegatives = 0;
    let falseNegatives = 0;
    let falsePositives = 0;
    for (const item of combinedItems) {
      const sameLabel = item[labelKey] === clusterValue;
      const sameCluster = item.clusterID === keyClusterId;
      if (sameLabel) {
        if (sameCluster) {
          truePositives++;
        } else {
          falseNegatives++;
        }
      } else if (sameCluster) {
        falsePositives++;
      } else {
        trueNegatives++;
      }
    }
    return getAccuracyStats({
      truePositives,
      trueNegatives,
      falsePositives,
      falseNegatives,
    });
  }
}
/**
* Utility function to generate a random ID string
* @param len Length of the string
* @returns {string}
*/
function genHexString(len) {
  const HEX_DIGITS = "0123456789ABCDEF";
  let id = "";
  // Each pass appends exactly one random digit, so the string grows by one per iteration.
  while (id.length < len) {
    const pick = Math.floor(Math.random() * HEX_DIGITS.length);
    id += HEX_DIGITS.charAt(pick);
  }
  return id;
}
/**
 * A set of tabs with their embeddings and the centroid of those embeddings.
 */
class EmbeddingCluster {
  /**
   * @param {object} options
   * @param {object[]} [options.tabs] Tabs that belong to the cluster
   * @param {number[][]} [options.embeddings] Embeddings that belong to the cluster
   * @param {number[]} [options.centroid] Centroid to use; when omitted it is computed
   *   from the embeddings, if any were given
   */
  constructor({ tabs, embeddings, centroid }) {
    this.embeddings = embeddings;
    this.centroid =
      centroid || (embeddings && computeCentroidFrom2DArray(this.embeddings));
    this.tabs = tabs;
  }

  /**
   * Sums the distance of each embedding from the cluster's centroid.
   * A cluster without embeddings (missing or empty list) has a total of 0.
   *
   * @returns {number} total sum euclidean squared distance of each item from cluster's centroid
   */
  computeTotalSquaredCentroidDistance() {
    // The constructor accepts a cluster without embeddings, so guard against
    // both a missing and an empty list here.
    if (!this.embeddings?.length) {
      return 0;
    }
    // NOTE(review): the `true` flag presumably requests the squared distance — confirm in ClusterAlgos.
    return this.embeddings.reduce(
      (totalDistance, embedding) =>
        totalDistance + euclideanDistance(this.centroid, embedding, true),
      0
    );
  }

  /**
   * Returns number of items in the cluster (0 when no tab list was given)
   *
   * @returns {number}
   */
  numItems() {
    return this.tabs?.length ?? 0;
  }
}
/**
* Represents a single cluster with additional saved metadata
*/
export class ClusterRepresentation extends EmbeddingCluster {
  /**
   * Sets up the cluster plus its metadata slots. All label and text fields start out
   * as null, and every instance receives a freshly generated 10 character clusterID.
   */
  constructor({ tabs, embeddings, centroid, config }) {
    super({ tabs, embeddings, centroid });
    Object.assign(this, {
      config,
      predictedTopicLabel: null,
      annotatedTopicLabel: null,
      userEditedTopicLabel: null,
      representativeText: null,
      keywords: null,
      documents: null,
      clusterID: genHexString(10),
    });
  }

  /**
   * Gives the representative text of the cluster, generating and storing it
   * first when no (non-empty) text has been stored yet.
   */
  getRepresentativeText() {
    const storedText = this.representativeText;
    if (storedText) {
      return storedText;
    }
    this.representativeText = this._generateRepresentativeText();
    return this.representativeText;
  }

  /**
   * Builds the representative text for the cluster.
   * In this initial implementation it is the title of the first (up to) 3 tabs,
   * each one preceded by a newline.
   * @returns {string}
   * @private
   */
  _generateRepresentativeText() {
    const { titleKey } = this.config.dataConfig;
    return this.tabs
      .slice(0, 3)
      .map(tab => `\n${tab[titleKey]}`)
      .join("");
  }

  /**
   * Debugging hook; intentionally does nothing at the moment.
   */
  print() {
    // Add console log for debugging
  }
}