1537 lines
49 KiB
JavaScript
1537 lines
49 KiB
JavaScript
/* This Source Code Form is subject to the terms of the Mozilla Public
|
||
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
||
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
||
|
||
import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
|
||
import { createEngine } from "chrome://global/content/ml/EngineProcess.sys.mjs";
|
||
import {
|
||
cosSim,
|
||
KeywordExtractor,
|
||
} from "chrome://global/content/ml/NLPUtils.sys.mjs";
|
||
|
||
import {
|
||
computeCentroidFrom2DArray,
|
||
computeRandScore,
|
||
euclideanDistance,
|
||
getAccuracyStats,
|
||
kmeansPlusPlus,
|
||
silhouetteCoefficients,
|
||
} from "chrome://global/content/ml/ClusterAlgos.sys.mjs";
|
||
|
||
const lazy = {};
|
||
|
||
ChromeUtils.defineESModuleGetters(lazy, {
|
||
NLP: "resource://gre/modules/NLP.sys.mjs",
|
||
MLEngineParent: "resource://gre/actors/MLEngineParent.sys.mjs",
|
||
MultiProgressAggregator: "chrome://global/content/ml/Utils.sys.mjs",
|
||
Progress: "chrome://global/content/ml/Utils.sys.mjs",
|
||
});
|
||
|
||
// Pref value meaning that no specific model revision has been pinned.
const LATEST_MODEL_REVISION = "latest";

// Strategies for suggesting tabs that are similar to the current tab.
export const SUGGEST_OTHER_TABS_METHODS = {
  KMEANS_WITH_ANCHOR: "KMEANS_WITH_ANCHOR",
  NEAREST_NEIGHBOR: "NEAREST_NEIGHBOR",
  LOGISTIC_REGRESSION: "LOGISTIC_REGRESSION",
};
|
||
|
||
// Preference-backed settings, exposed as lazily evaluated properties of `lazy`.

// Which SUGGEST_OTHER_TABS_METHODS strategy smartTabGroupingForGroup uses.
XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "suggestOtherTabsMethod",
  "browser.tabs.groups.smart.suggestOtherTabsMethod"
);

// Revision of the topic model; compared against LATEST_MODEL_REVISION.
XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "topicModelRevision",
  "browser.tabs.groups.smart.topicModelRevision"
);

// Revision of the embedding model; compared against LATEST_MODEL_REVISION.
XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "embeddingModelRevision",
  "browser.tabs.groups.smart.embeddingModelRevision"
);

// Nearest-neighbor similarity threshold; divided by 1000 before it is compared
// with a cosine similarity.
XPCOMUtils.defineLazyPreferenceGetter(
  lazy,
  "nearestNeighborThresholdInt",
  "browser.tabs.groups.smart.nearestNeighborThresholdInt"
);
|
||
|
||
// Key under which _prepareTabData stores the text that gets embedded.
const EMBED_TEXT_KEY = "combined_text";

// Supported clustering implementations.
export const CLUSTER_METHODS = {
  KMEANS: "KMEANS",
};

// Methods for finding similar items for an existing cluster
export const ANCHOR_METHODS = {
  DRIFT: "DRIFT", // We let k-means clustering run, and find the cluster with the most anchor items
  FIXED: "FIXED", // We always group with the anchor items in the 0 cluster, and never let them be reassigned
};

// Methods for handling tabs that already belong to a group while clustering
export const PREGROUPED_HANDLING_METHODS = {
  EXCLUDE: "EXCLUDE", // The already-grouped tab indices are handed to k-means as preassigned indices
  IGNORE: "IGNORE", // No preassigned indices are handed to k-means
};
|
||
|
||
// Number of objects each model is expected to load; preloadAllModels uses the
// sum to scale early download progress.
const EXPECTED_TOPIC_MODEL_OBJECTS = 6;
const EXPECTED_EMBEDDING_MODEL_OBJECTS = 4;

// No dimensionality reduction methods are registered yet.
export const DIM_REDUCTION_METHODS = {};

// Silhouette penalty scale applied when anchor items fall outside the chosen cluster.
const MISSING_ANCHOR_IN_CLUSTER_PENALTY = 0.2;
// Upper bound on how many grouped tabs are used as anchors for similarity.
const MAX_NN_GROUPED_TABS = 4;
// Upper bound on how many suggested tabs are returned.
const MAX_SUGGESTED_TABS = 10;

const DISSIMILAR_TAB_LABEL = "none";
const ADULT_TAB_LABEL = "adult content";
const LABELS_TO_EXCLUDE = [DISSIMILAR_TAB_LABEL, ADULT_TAB_LABEL];

// Task names used in the engine configurations.
const ML_TASK_FEATURE_EXTRACTION = "feature-extraction";
const ML_TASK_TEXT2TEXT = "text2text-generation";

const LABEL_REASONS = {
  DEFAULT: "DEFAULT",
  LOW_CONFIDENCE: "LOW_CONFIDENCE",
  EXCLUDE: "EXCLUDE",
  ERROR: "ERROR",
};
|
||
|
||
// Default configuration used by SmartTabGroupingManager when none is supplied.
const SMART_TAB_GROUPING_CONFIG = {
  // Engine settings for generating embeddings.
  embedding: {
    dtype: "q8",
    timeoutMS: 2 * 60 * 1000, // 2 minutes
    taskName: ML_TASK_FEATURE_EXTRACTION,
    featureId: "smart-tab-embedding",
  },
  // Engine settings for generating topic labels.
  topicGeneration: {
    dtype: "q8",
    timeoutMS: 2 * 60 * 1000, // 2 minutes
    taskName: ML_TASK_TEXT2TEXT,
    featureId: "smart-tab-topic",
  },
  // Names of the tab fields read when preparing text to embed.
  dataConfig: {
    titleKey: "label",
    descriptionKey: "description",
  },
  clustering: {
    dimReductionMethod: null, // Not completed.
    clusterImplementation: CLUSTER_METHODS.KMEANS,
    clusteringTriesPerK: 3,
    anchorMethod: ANCHOR_METHODS.FIXED,
    pregroupedHandlingMethod: PREGROUPED_HANDLING_METHODS.EXCLUDE,
    pregroupedSilhouetteBoost: 2, // Relative weight of the cluster's score and all other cluster's combined
    suggestOtherTabsMethod: SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR,
  },
};
|
||
|
||
// These parameters were generated by training a logistic regression
// model on synthetic data. See https://github.com/mozilla/smart-tab-grouping
// for more info.
const LOGISTIC_REGRESSION_PARAMS = {
  // Used when group-name similarities are available.
  TITLE_WITH_GROUP_NAME: {
    GROUP_SIMILARITY_WEIGHT: 6.76420017,
    TITLE_SIMILARITY_WEIGHT: 2.95779555,
    INTERCEPT: -3.06862155,
    THRESHOLD: 0.45,
  },
  // Used when only title similarities are available.
  TITLE_ONLY: {
    GROUP_SIMILARITY_WEIGHT: 0,
    TITLE_SIMILARITY_WEIGHT: 2.50596721,
    INTERCEPT: -0.54293376,
    THRESHOLD: 0.6,
  },
};
|
||
|
||
// Tabs showing one of these URLs are never offered as suggestions.
const TAB_URLS_TO_EXCLUDE = [
  "about:newtab",
  "about:home",
  "about:privatebrowsing",
  "chrome://browser/content/blanktab.html",
  "about:firefoxview",
];
|
||
|
||
/**
|
||
* For a given set of clusters represented by indices, returns the index of the cluster
|
||
* that has the most anchor items inside it.
|
||
*
|
||
* An anchor item is an index that represents the index to a tab that is already grouped and in
|
||
* the cluster we're interested in finding more items for.
|
||
*
|
||
* @param {number[][]} groupIndices - Array of clusters represented as arrays of indices.
|
||
* @param {number[]} anchorItems - Array of anchor item indices.
|
||
* @returns {{anchorClusterIndex: number, numAnchorItemsInCluster: number}} Index of best cluster and the number of anchor items.
|
||
*/
|
||
export function getBestAnchorClusterInfo(groupIndices, anchorItems) {
  const anchorItemSet = new Set(anchorItems);

  // With no clusters at all the result is index -1 holding 0 anchor items,
  // so callers always receive a numeric count.
  let anchorClusterIndex = -1;
  let numAnchorItemsInCluster = 0;

  groupIndices.forEach((cluster, clusterIndex) => {
    const anchorCount = cluster.reduce(
      (count, itemIndex) => (anchorItemSet.has(itemIndex) ? count + 1 : count),
      0
    );
    // Strict ">" keeps the first cluster on ties, matching indexOf(max).
    if (anchorClusterIndex === -1 || anchorCount > numAnchorItemsInCluster) {
      anchorClusterIndex = clusterIndex;
      numAnchorItemsInCluster = anchorCount;
    }
  });

  return { anchorClusterIndex, numAnchorItemsInCluster };
}
|
||
|
||
export class SmartTabGroupingManager {
|
||
/**
|
||
* Creates the SmartTabGroupingManager object.
|
||
* @param {object} config configuration options
|
||
*/
|
||
constructor(config) {
|
||
this.config = config || SMART_TAB_GROUPING_CONFIG;
|
||
}
|
||
|
||
/**
|
||
*
|
||
* @param {MLEngine} engine the engine to check
|
||
* @return {boolean} true if the engine has not been initialized or closed
|
||
*/
|
||
static isEngineClosed(engine) {
|
||
return !engine || engine?.engineStatus === "closed";
|
||
}
|
||
|
||
/**
|
||
* Initializes the embedding engine by running a test request
|
||
* This helps remove the init latency
|
||
*/
|
||
async initEmbeddingEngine() {
|
||
if (!SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) {
|
||
return;
|
||
}
|
||
try {
|
||
this.embeddingEngine = await this._createMLEngine(this.config.embedding);
|
||
const request = {
|
||
args: ["Test"],
|
||
options: { pooling: "mean", normalize: true },
|
||
};
|
||
this.embeddingEngine.run(request);
|
||
} catch (e) {}
|
||
}
|
||
|
||
/**
|
||
* Generates suggested tabs for an existing or provisional group
|
||
* @param {object} group active group we are adding tabs to
|
||
* @param {array} tabs list of tabs from gbrowser, some of which may be grouped in other groups
|
||
* @returns a list of suggested new tabs. If no new tabs are suggested an empty list is returned.
|
||
*/
|
||
async smartTabGroupingForGroup(group, tabs) {
|
||
// Add tabs to suggested group
|
||
const groupTabs = group.tabs;
|
||
const allTabs = tabs.filter(tab => {
|
||
// Don't include tabs already pinned
|
||
if (tab.pinned) {
|
||
return false;
|
||
}
|
||
if (!tab?.linkedBrowser?.currentURI?.spec) {
|
||
return false;
|
||
}
|
||
return true;
|
||
});
|
||
|
||
// find tabs that are part of the group
|
||
const groupIndices = groupTabs
|
||
.map(a => allTabs.indexOf(a))
|
||
.filter(a => a >= 0);
|
||
|
||
// find tabs that are part of other groups
|
||
const alreadyGroupedIndices = allTabs
|
||
.map((t, i) => (t.group ? i : -1))
|
||
.filter(a => a >= 0);
|
||
|
||
let suggestedTabs;
|
||
switch (lazy.suggestOtherTabsMethod) {
|
||
case SUGGEST_OTHER_TABS_METHODS.KMEANS_WITH_ANCHOR:
|
||
suggestedTabs = await this.generateClusters(
|
||
allTabs,
|
||
null,
|
||
null,
|
||
null,
|
||
groupIndices,
|
||
alreadyGroupedIndices
|
||
).then(clusters => {
|
||
if (!clusters) {
|
||
return [];
|
||
}
|
||
const targetCluster = clusters.clusterRepresentations.find(c =>
|
||
groupTabs.some(g => c.tabs.includes(g))
|
||
);
|
||
if (targetCluster) {
|
||
// Return only tabs not already grouped
|
||
return targetCluster.tabs.filter(t => !t.group);
|
||
}
|
||
return [];
|
||
});
|
||
break;
|
||
case SUGGEST_OTHER_TABS_METHODS.LOGISTIC_REGRESSION:
|
||
suggestedTabs = await this.findSimilarTabsLogisticRegression({
|
||
allTabs,
|
||
groupedIndices: groupIndices,
|
||
alreadyGroupedIndices,
|
||
groupLabel: group?.label,
|
||
});
|
||
break;
|
||
case SUGGEST_OTHER_TABS_METHODS.NEAREST_NEIGHBOR:
|
||
default:
|
||
// find nearest neighbors to current group
|
||
suggestedTabs = await this.findNearestNeighbors({
|
||
allTabs,
|
||
groupedIndices: groupIndices,
|
||
alreadyGroupedIndices,
|
||
groupLabel: group?.label,
|
||
});
|
||
}
|
||
return suggestedTabs.slice(0, MAX_SUGGESTED_TABS);
|
||
}
|
||
|
||
/**
|
||
* Get tabs that need to be included in suggestions
|
||
* @param {array} allTabs all tabs that are part of the window
|
||
* @param {array} groupedIndices indices of tabs that are already part of the group
|
||
* @param {array} alreadyGroupedIndices indices of tabs that are part of other groups
|
||
* @returns {array} tabs indices to be considered for suggestions
|
||
*/
|
||
getTabsToSuggest(allTabs, groupedIndices, alreadyGroupedIndices) {
|
||
// tabs to be excluded
|
||
// indices of all tabs that should be excluded (with duplicates)
|
||
const tabURLIndicesToExclude = allTabs
|
||
.map((at, index) => (TAB_URLS_TO_EXCLUDE.includes(at.url) ? index : -1))
|
||
.filter(index => index !== -1);
|
||
const excludedTabIndices = [
|
||
...groupedIndices,
|
||
...alreadyGroupedIndices,
|
||
...tabURLIndicesToExclude,
|
||
];
|
||
|
||
// tabs to be included
|
||
return allTabs
|
||
.map((_, index) => index)
|
||
.filter(i => !excludedTabIndices.includes(i));
|
||
}
|
||
|
||
/*
|
||
* Finds tabs that are similar to a grouped list of tabs
|
||
* @param {array} allTabs all tabs that are part of the window
|
||
* @param {array} groupedIndices indices of tabs that are already part of the group
|
||
* @param {array} alreadyGroupedIndices indices of tabs that are part of other groups
|
||
* @param {string} groupLabel name of group if present
|
||
* @param {number} thresholdMills nearest neighbor similarity threshold in thousandths (divided by 1000 before comparison)
|
||
* @returns a list of suggested tabs that are similar to the groupedIndices tabs
|
||
*/
|
||
async findNearestNeighbors({
|
||
allTabs,
|
||
groupedIndices,
|
||
alreadyGroupedIndices,
|
||
groupLabel = "",
|
||
thresholdMills = lazy.nearestNeighborThresholdInt,
|
||
precomputedEmbeddings = [],
|
||
depth = 0,
|
||
}) {
|
||
// get embeddings for all the tabs
|
||
const tabData = await this._prepareTabData(allTabs);
|
||
let embeddings = precomputedEmbeddings;
|
||
if (precomputedEmbeddings.length === 0) {
|
||
embeddings = await this._generateEmbeddings(
|
||
tabData.map((td, index) => {
|
||
let text = SmartTabGroupingManager.preprocessText(td[EMBED_TEXT_KEY]);
|
||
// augment with group name if it's present
|
||
if (groupLabel && groupedIndices.includes(index)) {
|
||
text = `${groupLabel.slice(0, 100)}. ${text}`;
|
||
}
|
||
return text;
|
||
})
|
||
);
|
||
}
|
||
|
||
// tabs that need to be assigned after filtering
|
||
const tabsToAssignIndices = this.getTabsToSuggest(
|
||
tabData,
|
||
groupedIndices,
|
||
alreadyGroupedIndices
|
||
);
|
||
|
||
let closestTabs = [];
|
||
const similarTabsIndices = [];
|
||
for (let i = 0; i < tabsToAssignIndices.length; i++) {
|
||
let closestScore = null;
|
||
for (
|
||
let j = 0;
|
||
j < Math.min(groupedIndices.length, MAX_NN_GROUPED_TABS);
|
||
j++
|
||
) {
|
||
const cosineSim = cosSim(
|
||
embeddings[tabsToAssignIndices[i]],
|
||
embeddings[groupedIndices[j]]
|
||
);
|
||
if (!closestScore || cosineSim > closestScore) {
|
||
closestScore = cosineSim;
|
||
}
|
||
}
|
||
// threshold could also be set via a nimbus experiment, in which case
|
||
// it will be an int <= 1000
|
||
if (closestScore > thresholdMills / 1000) {
|
||
closestTabs.push([allTabs[tabsToAssignIndices[i]], closestScore]);
|
||
similarTabsIndices.push(tabsToAssignIndices[i]);
|
||
}
|
||
}
|
||
closestTabs.sort((a, b) => b[1] - a[1]);
|
||
closestTabs = closestTabs.map(t => t[0]);
|
||
// recurse once if the initial call only had a single tab
|
||
// and we found at least 1 similar tab - this improves recall
|
||
if (groupedIndices.length === 1 && !!closestTabs.length && depth === 1) {
|
||
const recurseSimilarTabs = await this.findNearestNeighbors({
|
||
allTabs,
|
||
groupedIndices: similarTabsIndices,
|
||
alreadyGroupedIndices: alreadyGroupedIndices.concat(groupedIndices),
|
||
groupLabel,
|
||
thresholdMills,
|
||
precomputedEmbeddings: embeddings,
|
||
depth: depth - 1,
|
||
});
|
||
closestTabs = closestTabs.concat(recurseSimilarTabs);
|
||
}
|
||
return closestTabs;
|
||
}
|
||
|
||
/**
|
||
* Calculates the average similarity between the anchor embeddings and the candidate embeddings
|
||
* @param {list[Number]} anchorEmbeddings title embeddings for the anchor tabs
|
||
* @param {list[Number]} candidateEmbeddings title embeddings for the candidate tabs
|
||
*/
|
||
getAverageSimilarity(anchorEmbeddings, candidateEmbeddings) {
|
||
let averageSimilarities = [];
|
||
for (let candidate_embedding of candidateEmbeddings) {
|
||
let averageSimilarity = 0;
|
||
for (let anchor_embedding of anchorEmbeddings) {
|
||
averageSimilarity += cosSim(candidate_embedding, anchor_embedding);
|
||
}
|
||
averageSimilarities.push(averageSimilarity / anchorEmbeddings.length);
|
||
}
|
||
return averageSimilarities;
|
||
}
|
||
|
||
/**
|
||
* Calculates the sigmoid value of the input
|
||
*
|
||
* @param {Number} z
|
||
* @return {Number}
|
||
*/
|
||
sigmoid(z) {
|
||
return 1 / (1 + Math.exp(-z));
|
||
}
|
||
|
||
/**
|
||
* Calculates the probability using the linear combination of the parameters
|
||
*
|
||
* @param {Number} groupSimilarity how similar a candidate tab is to the group name
|
||
* @param {Number} titleSimilarity how similar a candidate tab is to the anchors
|
||
* @param {Object} params the logistic regression weights assigned to each parameter
|
||
* @return {Number}
|
||
*/
|
||
calculateProbability(groupSimilarity, titleSimilarity, params) {
|
||
return this.sigmoid(
|
||
groupSimilarity * params.GROUP_SIMILARITY_WEIGHT +
|
||
titleSimilarity * params.TITLE_SIMILARITY_WEIGHT +
|
||
params.INTERCEPT
|
||
);
|
||
}
|
||
|
||
/**
|
||
* Calculates the probabilities given two lists of the same length
|
||
*
|
||
* @param {list[Number]} groupSimilarities cosine similarity between the candidate tabs and the group name
|
||
* @param {list[Number]} titleSimilarities average cosine similarity between the candidate tabs and anchors
|
||
* @return {list[Number]} probabilities for each candidate tab
|
||
*/
|
||
calculateAllProbabilities(groupSimilarities, titleSimilarities) {
|
||
const hasGroupSimilarity = Boolean(groupSimilarities);
|
||
let probabilities = [];
|
||
for (let i = 0; i < titleSimilarities.length; i++) {
|
||
probabilities.push(
|
||
this.calculateProbability(
|
||
hasGroupSimilarity ? groupSimilarities[i] : 0,
|
||
titleSimilarities[i],
|
||
hasGroupSimilarity
|
||
? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME
|
||
: LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY
|
||
)
|
||
);
|
||
}
|
||
return probabilities;
|
||
}
|
||
|
||
/**
|
||
* Generates similar tabs to a grouped list of tabs using a logistic regression "model"
|
||
*
|
||
* @param {array} allTabs all tabs that are part of the window
|
||
* @param {array} groupedIndices indices of tabs that are already part of the group
|
||
* @param {array} alreadyGroupedIndices indices of tabs that are part of other groups
|
||
* @param {string} groupLabel name of group if present
|
||
*/
|
||
async findSimilarTabsLogisticRegression({
|
||
allTabs,
|
||
groupedIndices,
|
||
alreadyGroupedIndices,
|
||
groupLabel = "",
|
||
}) {
|
||
const tabData = await this._prepareTabData(allTabs);
|
||
const candidateIndices = this.getTabsToSuggest(
|
||
tabData,
|
||
groupedIndices,
|
||
alreadyGroupedIndices
|
||
);
|
||
|
||
const candidateTabsData = candidateIndices.map(ci => allTabs[ci]);
|
||
const candidateTabsPrep = await this._prepareTabData(candidateTabsData);
|
||
|
||
const anchorTabsPrep = groupedIndices
|
||
.map(gi => tabData[gi])
|
||
.slice(0, MAX_NN_GROUPED_TABS);
|
||
|
||
// generate embeddings for both anchor and candidate titles
|
||
const titleEmbeddings = await this._generateEmbeddings(
|
||
anchorTabsPrep
|
||
.concat(candidateTabsPrep)
|
||
.map(tab => SmartTabGroupingManager.preprocessText(tab[EMBED_TEXT_KEY]))
|
||
);
|
||
|
||
let groupEmbedding;
|
||
let groupSimilarities;
|
||
if (groupLabel) {
|
||
groupEmbedding = await this._generateEmbeddings([groupLabel]);
|
||
// calculate similarity between the group and the candidate tabs if group name is present
|
||
groupSimilarities = this.getAverageSimilarity(
|
||
groupEmbedding,
|
||
titleEmbeddings.slice(anchorTabsPrep.length)
|
||
);
|
||
}
|
||
|
||
// calculate the similarity between the anchors and candidate titles
|
||
const titleSimilarities = this.getAverageSimilarity(
|
||
titleEmbeddings.slice(0, anchorTabsPrep.length),
|
||
titleEmbeddings.slice(anchorTabsPrep.length)
|
||
);
|
||
|
||
const candidateProbabilities = this.calculateAllProbabilities(
|
||
groupSimilarities,
|
||
titleSimilarities
|
||
);
|
||
|
||
// get proper params depending on group name availability
|
||
const probabilityThreshold = groupEmbedding
|
||
? LOGISTIC_REGRESSION_PARAMS.TITLE_WITH_GROUP_NAME.THRESHOLD
|
||
: LOGISTIC_REGRESSION_PARAMS.TITLE_ONLY.THRESHOLD;
|
||
|
||
return (
|
||
candidateTabsData
|
||
// combine candidate tabs with corresponding probabilities
|
||
.map((ct, index) => ({
|
||
ct,
|
||
prob: candidateProbabilities[index],
|
||
}))
|
||
// only keep those that are within the probability threshold
|
||
.filter(item => item.prob >= probabilityThreshold)
|
||
// ensure the highest probability candidates come first in the list
|
||
.sort((a, b) => b.prob - a.prob)
|
||
// keep the tabs only
|
||
.map(item => item.ct)
|
||
);
|
||
}
|
||
|
||
/**
|
||
* This function will terminate a grouping or label generation in progress
|
||
* It is currently not implemented.
|
||
*/
|
||
terminateProcess() {
|
||
// TODO - teminate AI processes, This method will be
|
||
// called when tab grouping panel is closed.
|
||
}
|
||
|
||
/**
|
||
* Changes the clustering method. Must be one of supported methods.
|
||
* @param {string} method Name of method
|
||
*/
|
||
setClusteringMethod(method) {
|
||
if (!(method in CLUSTER_METHODS)) {
|
||
throw new Error(`Clustering method ${method} not supported`);
|
||
}
|
||
this.config.clustering.clusterImplementation = method;
|
||
}
|
||
|
||
/**
|
||
* Set the technique for clustering when certain tabs are already assigned to groups
|
||
*
|
||
* @param {string} method which is one of ANCHOR_METHODS
|
||
*/
|
||
setAnchorMethod(method) {
|
||
if (!(method in ANCHOR_METHODS)) {
|
||
throw new Error(`Clustering anchor method ${method} not supported`);
|
||
}
|
||
this.config.clustering.anchorMethod = method;
|
||
}
|
||
|
||
setSilBoost(boost) {
|
||
this.config.clustering.pregroupedSilhouetteBoost = boost;
|
||
}
|
||
|
||
/**
|
||
* Sets method to reduce dimensionality of embeddings prior to clustering
|
||
* @param {string} method Name of method
|
||
*/
|
||
setDimensionReductionMethod(method) {
|
||
if (method && !(method in DIM_REDUCTION_METHODS)) {
|
||
throw new Error(`Dimension reduction method ${method} not supported`);
|
||
}
|
||
this.config.clustering.dimReductionMethod = method;
|
||
}
|
||
|
||
/**
|
||
* Sets the field name of the title of a page to be used when clustering or generating embeddings
|
||
* This is useful when clustering test data that is not a tab object
|
||
* @param {string} titleKey KEY FOR THE TITLE
|
||
*/
|
||
setDataTitleKey(titleKey) {
|
||
this.config.dataConfig.titleKey = titleKey;
|
||
}
|
||
|
||
/**
|
||
* Logs to the appropriate place for debugging. Console for now
|
||
* @param {string} msg Message to log
|
||
* (Currently a no-op: the message is discarded.)
|
||
*/
|
||
log(_msg) {}
|
||
|
||
/**
|
||
* Prepares data to be used by the ml models
|
||
* @param {Object[]} tabList list of tabs in the current window
|
||
* @param {boolean} useDescription whether we should combined the title and description
|
||
* @return {Promise<*[Object]>}
|
||
* @private
|
||
*/
|
||
async _prepareTabData(tabList, useDescription = false) {
|
||
const titleKey = this.config.dataConfig.titleKey;
|
||
const descriptionKey = this.config.dataConfig.descriptionKey;
|
||
const structuredData = [];
|
||
for (let tab of tabList) {
|
||
const description =
|
||
useDescription && descriptionKey && tab[descriptionKey];
|
||
|
||
let textToEmbed;
|
||
if (description) {
|
||
textToEmbed = tab[titleKey] + " " + description;
|
||
} else {
|
||
textToEmbed = tab[titleKey] || "Unknown";
|
||
}
|
||
|
||
structuredData.push({
|
||
[EMBED_TEXT_KEY]: textToEmbed,
|
||
title: tab[titleKey],
|
||
description,
|
||
url: tab?.linkedBrowser?.currentURI?.spec,
|
||
});
|
||
}
|
||
return structuredData;
|
||
}
|
||
|
||
/**
|
||
* Get updated config for the ml engine
|
||
*
|
||
* @param {object} initData
|
||
* @param {string} featureId
|
||
* @return {*}
|
||
*/
|
||
static getUpdatedInitData(initData, featureId) {
|
||
// we're setting a specific modelRevision through about:config or Nimbus
|
||
if (
|
||
featureId === SMART_TAB_GROUPING_CONFIG.topicGeneration.featureId &&
|
||
lazy.topicModelRevision !== LATEST_MODEL_REVISION
|
||
) {
|
||
initData.modelRevision = lazy.topicModelRevision;
|
||
} else if (
|
||
featureId === SMART_TAB_GROUPING_CONFIG.embedding.featureId &&
|
||
lazy.embeddingModelRevision !== LATEST_MODEL_REVISION
|
||
) {
|
||
initData.modelRevision = lazy.embeddingModelRevision;
|
||
}
|
||
return initData;
|
||
}
|
||
|
||
/**
|
||
* Creates an ML engine for a given config.
|
||
* @param {*} engineConfig
|
||
* @param {function} progressCallback
|
||
* @returns MLEngine
|
||
*/
|
||
async _createMLEngine(engineConfig, progressCallback) {
|
||
const {
|
||
featureId,
|
||
engineId,
|
||
dtype,
|
||
taskName,
|
||
timeoutMS,
|
||
modelId,
|
||
modelRevision,
|
||
} = engineConfig;
|
||
let initData = {
|
||
featureId,
|
||
engineId,
|
||
dtype,
|
||
taskName,
|
||
timeoutMS,
|
||
modelId,
|
||
modelRevision,
|
||
};
|
||
|
||
initData = SmartTabGroupingManager.getUpdatedInitData(initData, featureId);
|
||
return await createEngine(initData, progressCallback);
|
||
}
|
||
|
||
/**
|
||
* Generates embeddings from a list of tab data structures
|
||
* @param textToEmbedList List of texts to generate embeddings for
|
||
* @returns {Promise<*[]>} List of embeddings (2d array)
|
||
* @private
|
||
*/
|
||
async _generateEmbeddings(textToEmbedList) {
|
||
const inputData = {
|
||
inputArgs: textToEmbedList,
|
||
runOptions: {
|
||
pooling: "mean",
|
||
normalize: true,
|
||
},
|
||
};
|
||
|
||
if (SmartTabGroupingManager.isEngineClosed(this.embeddingEngine)) {
|
||
this.embeddingEngine = await this._createMLEngine(this.config.embedding);
|
||
}
|
||
const request = {
|
||
args: [inputData.inputArgs],
|
||
options: inputData.runOptions,
|
||
};
|
||
return await this.embeddingEngine.run(request);
|
||
}
|
||
|
||
/**
|
||
* Clusters in desired methods
|
||
* based on the config of the class
|
||
* @param tabList List of tabs as array
|
||
* @param docEmbeddings Precomputed embeddings for the Tab as two dimensional array
|
||
* @param k Desired number of clusters. Tries a range of sizes if 0.
|
||
* @param {function} randomFunc Optional seeded random number generator for testing
|
||
* @returns {SmartTabGroupingResult}
|
||
* @private
|
||
*/
|
||
_clusterEmbeddings({
|
||
tabs,
|
||
embeddings,
|
||
k,
|
||
randomFunc,
|
||
anchorIndices,
|
||
alreadyGroupedIndices = [],
|
||
}) {
|
||
let allItems;
|
||
|
||
const freezeAnchorsInZeroCluster =
|
||
anchorIndices &&
|
||
this.config.clustering.anchorMethod == ANCHOR_METHODS.FIXED;
|
||
|
||
const dimReductionMethod = this.config.clustering.dimReductionMethod;
|
||
switch (dimReductionMethod) {
|
||
default:
|
||
// Dimensionality reduction support is landing very soon.
|
||
break;
|
||
}
|
||
k = k || 0;
|
||
let startK = k;
|
||
let endK = k + 1;
|
||
if (!k) {
|
||
startK = 2;
|
||
// Find a reasonable max # of clusters
|
||
endK =
|
||
Math.min(
|
||
Math.floor(Math.log(embeddings.length) * 2.0),
|
||
embeddings.length
|
||
) + 1;
|
||
}
|
||
let bestResult;
|
||
let bestResultSilScore = -100.0;
|
||
let bestResultCenterCluster = 0;
|
||
|
||
const clusteringMethod = this.config.clustering.clusterImplementation;
|
||
const clusteringTriesPerK = this.config.clustering.clusteringTriesPerK;
|
||
for (let curK = startK; curK < endK; curK++) {
|
||
let bestItemsForK;
|
||
let bestInertiaForK = 500000000000;
|
||
for (let j = 0; j < clusteringTriesPerK; j++) {
|
||
switch (clusteringMethod) {
|
||
case CLUSTER_METHODS.KMEANS:
|
||
allItems = kmeansPlusPlus({
|
||
data: embeddings,
|
||
k: curK,
|
||
maxIterations: 0,
|
||
randomFunc,
|
||
anchorIndices,
|
||
preassignedIndices:
|
||
this.config.clustering.pregroupedHandlingMethod ===
|
||
PREGROUPED_HANDLING_METHODS.EXCLUDE
|
||
? alreadyGroupedIndices
|
||
: [],
|
||
freezeAnchorsInZeroCluster,
|
||
});
|
||
break;
|
||
default:
|
||
throw Error("Clustering implementation not supported");
|
||
}
|
||
const tempResult = new SmartTabGroupingResult({
|
||
indices: allItems,
|
||
embeddings,
|
||
config: this.config,
|
||
});
|
||
const inertia = tempResult.getCentroidInertia();
|
||
if (inertia < bestInertiaForK) {
|
||
bestInertiaForK = inertia;
|
||
bestItemsForK = tempResult;
|
||
}
|
||
}
|
||
const silScores = silhouetteCoefficients(
|
||
embeddings,
|
||
bestItemsForK.indices
|
||
);
|
||
|
||
if (
|
||
freezeAnchorsInZeroCluster &&
|
||
this.config.clustering.pregroupedSilhouetteBoost > 0
|
||
) {
|
||
// Boost silhouette score of target cluster when we are grouping around an existing cluster
|
||
// pregroupedSilhouetteBoost indicates the relative weight of the cluster's score and all other cluster's combined
|
||
silScores[0] *= this.config.clustering.pregroupedSilhouetteBoost;
|
||
}
|
||
|
||
let avgSil = silScores.reduce((p, c) => p + c, 0) / silScores.length;
|
||
let curAnchorCluster = 0;
|
||
if (anchorIndices && !freezeAnchorsInZeroCluster) {
|
||
const { anchorClusterIndex, numAnchorItemsInCluster } =
|
||
getBestAnchorClusterInfo(bestItemsForK.indices, anchorIndices);
|
||
curAnchorCluster = anchorClusterIndex;
|
||
const penalty =
|
||
(MISSING_ANCHOR_IN_CLUSTER_PENALTY *
|
||
(anchorIndices.length - numAnchorItemsInCluster)) /
|
||
anchorIndices.length;
|
||
avgSil -= penalty;
|
||
}
|
||
if (avgSil > bestResultSilScore) {
|
||
bestResultSilScore = avgSil;
|
||
bestResult = bestItemsForK.indices;
|
||
bestResultCenterCluster = curAnchorCluster;
|
||
}
|
||
}
|
||
const result = new SmartTabGroupingResult({
|
||
indices: bestResult,
|
||
tabs,
|
||
embeddings,
|
||
config: this.config,
|
||
});
|
||
if (anchorIndices) {
|
||
result.setAnchorClusterIndex(
|
||
freezeAnchorsInZeroCluster ? 0 : bestResultCenterCluster
|
||
); // In our k-means clustering implementation anchor cluster is always first
|
||
if (!freezeAnchorsInZeroCluster) {
|
||
result.adjustClusterForAnchors(anchorIndices);
|
||
}
|
||
}
|
||
return result;
|
||
}
|
||
|
||
/**
|
||
* Generate a label for tabs in a group created by the user
|
||
*
|
||
* @param tabs tabs that are currently in the group
|
||
* @param otherTabs tabs in the window not part of the group
|
||
* @return {Promise<null|string|string|*>}
|
||
*/
|
||
async getPredictedLabelForGroup(tabs, otherTabs) {
|
||
const clusters = this.createStaticCluster(tabs);
|
||
const otherClusters = this.createStaticCluster(otherTabs);
|
||
let predictedLabel;
|
||
try {
|
||
// function below modifies "clusters" object
|
||
await this.generateGroupLabels(clusters, otherClusters);
|
||
predictedLabel = clusters.clusterRepresentations[0].predictedTopicLabel;
|
||
} catch (e) {
|
||
this.labelReason = LABEL_REASONS.ERROR;
|
||
predictedLabel = "";
|
||
}
|
||
return predictedLabel;
|
||
}
|
||
|
||
/**
|
||
* Generates clusters for a given list of tabs using precomputed embeddings or newly generated ones.
|
||
*
|
||
* @param {Object[]} tabList - List of tab objects to be clustered.
|
||
* @param {number[][]} [precomputedEmbeddings] - Precomputed embeddings for tab titles and descriptions.
|
||
* @param {number} numClusters - Number of clusters to form.
|
||
* @param {Function} randFunc - Random function used for clustering initialization.
|
||
* @param {number[]} [anchorIndices=[]] - Indices of anchor tabs that should be prioritized in clustering.
|
||
* @param {number[]} [alreadyGroupedIndices=[]] - Indices of tabs that are already assigned to groups.
|
||
* @returns {SmartTabGroupingResult} - The best clustering result based on centroid inertia.
|
||
*/
|
||
async generateClusters(
|
||
tabList,
|
||
precomputedEmbeddings,
|
||
numClusters,
|
||
randFunc,
|
||
anchorIndices = [],
|
||
alreadyGroupedIndices = []
|
||
) {
|
||
numClusters = numClusters ?? 0;
|
||
const structuredData = await this._prepareTabData(tabList);
|
||
|
||
// embeddings for title and description
|
||
if (precomputedEmbeddings) {
|
||
this.docEmbeddings = precomputedEmbeddings;
|
||
} else {
|
||
this.docEmbeddings = await this._generateEmbeddings(
|
||
structuredData.map(a => a[EMBED_TEXT_KEY])
|
||
);
|
||
}
|
||
let bestResultCluster;
|
||
let bestResultDistance = 50000000.0;
|
||
|
||
const NUM_RUNS = 1;
|
||
for (let i = 0; i < NUM_RUNS; i++) {
|
||
const curResult = this._clusterEmbeddings({
|
||
tabs: tabList,
|
||
embeddings: this.docEmbeddings,
|
||
k: numClusters,
|
||
randomFunc: randFunc,
|
||
anchorIndices,
|
||
alreadyGroupedIndices,
|
||
});
|
||
const distance = curResult.getCentroidInertia();
|
||
if (distance < bestResultDistance) {
|
||
bestResultDistance = distance;
|
||
bestResultCluster = curResult;
|
||
}
|
||
}
|
||
return bestResultCluster;
|
||
}
|
||
|
||
/**
|
||
* Create static cluster from a list of tabs. A single tab is Ok. Returns null when `tabs` is null or undefined
|
||
* @param tabs
|
||
* @returns {SmartTabGroupingResult} groupingResult
|
||
*/
|
||
createStaticCluster(tabs) {
|
||
if (!tabs) {
|
||
return null;
|
||
}
|
||
|
||
return new SmartTabGroupingResult({
|
||
indices: [Array.from({ length: tabs.length }, (_, i) => i)],
|
||
tabs,
|
||
config: this.config,
|
||
});
|
||
}
|
||
|
||
/***
|
||
* Utility function that loads all required engines for Smart Tab Grouping and any dependent models
|
||
* @param {(progress: { percentage: number }) => void} progressCallback callback function to call.
|
||
* Callback passes a dict with percentage indicating best effort 0.0-100.0 progress in model download.
|
||
*/
|
||
async preloadAllModels(progressCallback) {
|
||
let previousProgress = -1;
|
||
const expectedObjects =
|
||
EXPECTED_TOPIC_MODEL_OBJECTS + EXPECTED_EMBEDDING_MODEL_OBJECTS;
|
||
// TODO - Find a way to get these fields. Add as a transformers js callback or within remotesettings
|
||
|
||
const UPDATE_THRESHOLD_PERCENTAGE = 0.5;
|
||
const ONE_MB = 1024 * 1024;
|
||
const START_THRESHOLD_BYTES = ONE_MB * 0.2;
|
||
|
||
const mutliProgressAggregator = new lazy.MultiProgressAggregator({
|
||
progressCallback: ({ progress, totalLoaded, metadata }) => {
|
||
if (totalLoaded < START_THRESHOLD_BYTES) {
|
||
progress = 0.0;
|
||
} else {
|
||
const numObjSeen = metadata.totalObjectsSeen || 0;
|
||
if (numObjSeen > 0 && numObjSeen < expectedObjects) {
|
||
// When starting to download we may still be getting configs and not have all the data
|
||
progress *= numObjSeen / expectedObjects;
|
||
}
|
||
if (progress > 100) {
|
||
progress = 100;
|
||
}
|
||
}
|
||
if (
|
||
Math.abs(previousProgress - progress) > UPDATE_THRESHOLD_PERCENTAGE
|
||
) {
|
||
// Update only once changes are above a threshold to avoid throttling the UI with events.
|
||
progressCallback({
|
||
percentage: progress,
|
||
});
|
||
previousProgress = progress;
|
||
}
|
||
},
|
||
watchedTypes: [
|
||
lazy.Progress.ProgressType.DOWNLOAD,
|
||
lazy.Progress.ProgressType.LOAD_FROM_CACHE,
|
||
],
|
||
});
|
||
|
||
const [topicEngine, embeddingEngine] = await Promise.all([
|
||
this._createMLEngine(
|
||
this.config.topicGeneration,
|
||
mutliProgressAggregator?.aggregateCallback.bind(
|
||
mutliProgressAggregator
|
||
) || null
|
||
),
|
||
this._createMLEngine(
|
||
this.config.embedding,
|
||
mutliProgressAggregator?.aggregateCallback.bind(
|
||
mutliProgressAggregator
|
||
) || null
|
||
),
|
||
]);
|
||
this.topicEngine = topicEngine;
|
||
this.embeddingEngine = embeddingEngine;
|
||
}
|
||
|
||
/**
|
||
* Generate model input from keywords and documents
|
||
* @param {string []} keywords
|
||
* @param {string []} documents
|
||
*/
|
||
createModelInput(keywords, documents) {
|
||
if (!keywords || keywords.length === 0) {
|
||
return `Topic from keywords: titles: \n${documents.join(" \n")}`;
|
||
}
|
||
return `Topic from keywords: ${keywords.join(", ")}. titles: \n${documents.join(" \n")}`;
|
||
}
|
||
|
||
/**
|
||
* One artifact of the LLM output is that sometimes words are duplicated
|
||
* This function cuts the phrase when it sees the first duplicate word.
|
||
* Handles simple singluar / plural duplicates (-s only).
|
||
* @param {string} phrase Input phrase
|
||
* @returns {string} phrase cut before any duplicate word
|
||
*/
|
||
static cutAtDuplicateWords(phrase) {
|
||
if (!phrase.length) {
|
||
return phrase;
|
||
}
|
||
const wordsSet = new Set();
|
||
const wordList = phrase.split(" ");
|
||
for (let i = 0; i < wordList.length; i++) {
|
||
let baseWord = wordList[i].toLowerCase();
|
||
if (baseWord.length > 3) {
|
||
if (baseWord.slice(-1) === "s") {
|
||
baseWord = baseWord.slice(0, -1);
|
||
}
|
||
}
|
||
if (wordsSet.has(baseWord)) {
|
||
// We are seeing a baseWord word. Exit with just the words so far and don't
|
||
// add any new words
|
||
return wordList.slice(0, i).join(" ");
|
||
}
|
||
wordsSet.add(baseWord);
|
||
}
|
||
return phrase; // return original phrase
|
||
}
|
||
|
||
/**
|
||
* Removes trailing domain-related text such as '... - Mail' or '... | News'
|
||
* If there's not enough information remaining after, we keep the text as is
|
||
* @param {string} text tab title with potential domain information
|
||
* @return {string}
|
||
*/
|
||
static preprocessText(text) {
|
||
// Matches 'xyz - Domain' or 'xyz | Domain'
|
||
// with a space before and after delimiter
|
||
// or if there are multiple delimiters next to each other
|
||
const delimiters = /(?<=\s)[|–-]+(?=\s)/;
|
||
const splitText = text.split(delimiters);
|
||
|
||
// ensure there's enough info without the last element
|
||
const hasEnoughInfo =
|
||
!!splitText.length && splitText.slice(0, -1).join(" ").length > 5;
|
||
|
||
// domain related texts are usually shorter, this takes care of the most common cases
|
||
const isPotentialDomainInfo =
|
||
splitText.length > 1 && splitText[splitText.length - 1].length < 20;
|
||
|
||
// If both conditions are met, remove the last chunk, filter out empty strings,
|
||
// join on space, trim, and lowercase
|
||
if (hasEnoughInfo && isPotentialDomainInfo) {
|
||
return splitText
|
||
.slice(0, -1) // everything except the last element
|
||
.map(t => t.trim())
|
||
.filter(Boolean) // remove empty strings
|
||
.join(" ") // join with spaces
|
||
.trim(); // remove leading/trailing spaces
|
||
}
|
||
|
||
// Otherwise, just return the text
|
||
return text;
|
||
}
|
||
|
||
/**
|
||
* Postprocessing of raw output from Topic Model ML Engine
|
||
* @param {string | undefined} topic Raw topic phrase from topic model or undefined in case of an error
|
||
*/
|
||
processTopicModelResult(topic) {
|
||
let basicResult = (topic || "").trim();
|
||
if (!basicResult) {
|
||
this.labelReason = LABEL_REASONS.LOW_CONFIDENCE;
|
||
}
|
||
if (LABELS_TO_EXCLUDE.includes(basicResult.toLowerCase())) {
|
||
this.labelReason = LABEL_REASONS.EXCLUDE;
|
||
return "";
|
||
}
|
||
return SmartTabGroupingManager.cutAtDuplicateWords(basicResult);
|
||
}
|
||
|
||
/**
|
||
* Add titles to a cluster in a SmartTabGroupingResult using generative tehniques
|
||
* Currently this function only works with a single target group, and a separate
|
||
* item that represents all other ungrouped tabs.
|
||
*
|
||
* In the future this may be updated to more generally find labels for a set of clusters.
|
||
* @param {SmartTabGroupingResult} groupingResult The cluster we are generating the label for
|
||
* @param {SmartTabGroupingResult} otherGroupingResult A 'made up' cluster representing all other tabs in the window
|
||
*/
|
||
async generateGroupLabels(groupingResult, otherGroupingResult = null) {
|
||
const { keywords, documents } =
|
||
groupingResult.getRepresentativeDocsAndKeywords(
|
||
otherGroupingResult
|
||
? otherGroupingResult.getRepresentativeDocuments()
|
||
: []
|
||
);
|
||
const inputArgs = this.createModelInput(
|
||
keywords ? keywords[0] : [],
|
||
documents
|
||
);
|
||
const requestInfo = {
|
||
inputArgs,
|
||
runOptions: {
|
||
max_length: 6,
|
||
},
|
||
};
|
||
|
||
if (SmartTabGroupingManager.isEngineClosed(this.topicEngine)) {
|
||
this.topicEngine = await this._createMLEngine(
|
||
this.config.topicGeneration
|
||
);
|
||
}
|
||
const request = {
|
||
args: [requestInfo.inputArgs],
|
||
options: requestInfo.runOptions,
|
||
};
|
||
const genLabelResults = await this.topicEngine.run(request);
|
||
genLabelResults.forEach((genResult, genResultIndex) => {
|
||
groupingResult.clusterRepresentations[
|
||
genResultIndex
|
||
].predictedTopicLabel = this.processTopicModelResult(
|
||
genResult.generated_text
|
||
);
|
||
});
|
||
}
|
||
|
||
getLabelReason() {
|
||
return this.labelReason || LABEL_REASONS.DEFAULT;
|
||
}
|
||
|
||
/**
|
||
* Generates glean metrics for ml smart tab label / topic.
|
||
* This is currently called when the user saves or cancels the "suggest label" flow.
|
||
*
|
||
* @param {string} action "save" or "cancel"
|
||
* @param {number} numTabsInGroup Number of tabs used to generate the label
|
||
* @param {string} mlLabel ML generated label for the tab group
|
||
* @param {string} userLabel User saved label for the tab group
|
||
* @param {string} id The id of the group
|
||
*/
|
||
async handleLabelTelemetry({
|
||
action,
|
||
numTabsInGroup,
|
||
mlLabel,
|
||
userLabel,
|
||
id = "",
|
||
}) {
|
||
const { [ML_TASK_TEXT2TEXT]: topicEngineConfig } =
|
||
await this.getEngineConfigs();
|
||
const labelReason = this.getLabelReason();
|
||
Glean.tabgroup.smartTabTopic.record({
|
||
action,
|
||
tabs_in_group: numTabsInGroup,
|
||
ml_label_length: (mlLabel || "").length,
|
||
user_label_length: (userLabel || "").length,
|
||
levenshtein_distance: lazy.NLP.levenshtein(
|
||
userLabel || "",
|
||
mlLabel || ""
|
||
),
|
||
model_revision: topicEngineConfig.modelRevision || "",
|
||
id,
|
||
label_reason: labelReason,
|
||
});
|
||
this.labelReason = LABEL_REASONS.DEFAULT;
|
||
}
|
||
|
||
/**
|
||
* Generates glean metrics for ml smart tab label / topic.
|
||
* This is currently called when the user saves or cancels the "suggest other tabs" flow
|
||
*
|
||
* @param {string} action "save" or "cancel"
|
||
* @param {number} numTabsInWindow Number of tabs in the current window
|
||
* @param {number} numTabsInGroup Number of tabs in the current group
|
||
* @param {number} numTabsSuggested Number of tabs suggested by the model
|
||
* @param {number} numTabsApproved Number of tabs approved by the user
|
||
* @param {number} numTabsRemoved Number of tabs removed by the user
|
||
* @param {string} id The id of the group
|
||
*/
|
||
async handleSuggestTelemetry({
|
||
action,
|
||
numTabsInWindow,
|
||
numTabsInGroup,
|
||
numTabsSuggested,
|
||
numTabsApproved,
|
||
numTabsRemoved,
|
||
id = "",
|
||
}) {
|
||
const { [ML_TASK_FEATURE_EXTRACTION]: embeddingEngineConfig } =
|
||
await this.getEngineConfigs();
|
||
Glean.tabgroup.smartTabSuggest.record({
|
||
action,
|
||
tabs_in_window: numTabsInWindow,
|
||
tabs_in_group: numTabsInGroup,
|
||
tabs_suggested: numTabsSuggested,
|
||
tabs_approved: numTabsApproved,
|
||
tabs_removed: numTabsRemoved,
|
||
model_revision: embeddingEngineConfig.modelRevision || "",
|
||
id,
|
||
});
|
||
}
|
||
|
||
/**
|
||
* Gets config that engine was initialized with
|
||
*
|
||
* @return {Promise<{"[ML_TASK_TEXT2TEXT]", "[ML_TASK_FEATURE_EXTRACTION]"}>}
|
||
*/
|
||
async getEngineConfigs() {
|
||
if (!this.topicEngineConfig) {
|
||
this.topicEngineConfig = await lazy.MLEngineParent.getInferenceOptions(
|
||
this.config.topicGeneration.featureId,
|
||
this.config.topicGeneration.taskName
|
||
);
|
||
}
|
||
if (!this.embeddingEngineConfig) {
|
||
this.embeddingEngineConfig =
|
||
await lazy.MLEngineParent.getInferenceOptions(
|
||
this.config.embedding.featureId,
|
||
this.config.embedding.taskName
|
||
);
|
||
}
|
||
return {
|
||
[ML_TASK_TEXT2TEXT]: this.topicEngineConfig,
|
||
[ML_TASK_FEATURE_EXTRACTION]: this.embeddingEngineConfig,
|
||
};
|
||
}
|
||
}
|
||
|
||
export class SmartTabGroupingResult {
  // Index of the cluster holding the original items we are building the
  // clustering around. -1 means no anchor cluster has been set.
  #anchorClusterIndex = -1;

  /**
   * Creates a result from indices and complete tab and embedding lists.
   * Empty clusters are removed and a ClusterRepresentation is built for each
   * remaining cluster.
   *
   * @param {object} options
   * @param options.indices indices of clusters (eg [[2,4], [1], [3]])
   * @param options.tabs 1D array of tabs
   * @param options.embeddings Two dimensional array of embeddings
   * @param options.config Cluster config
   */
  constructor({ indices = [], tabs, embeddings, config }) {
    this.embeddingItems = embeddings;
    this.config = config;
    // Cleanup any empty clusters
    this.indices = indices.filter(cluster => Boolean(cluster.length));
    this.tabItems = tabs;
    this._buildClusterRepresentations();
  }

  /**
   * Builds the list of ClusterRepresentations, one per entry of `this.indices`.
   * A missing tab or embedding list is passed through as-is instead of being mapped.
   */
  _buildClusterRepresentations() {
    const pickItems = (items, positions) =>
      items && positions.map(position => items[position]);

    this.clusterRepresentations = this.indices.map(
      clusterIndices =>
        new ClusterRepresentation({
          tabs: pickItems(this.tabItems, clusterIndices),
          embeddings: pickItems(this.embeddingItems, clusterIndices),
          config: this.config,
        })
    );
  }

  /**
   * Returns up to 10 documents taken from all tabs of this result, in tab order.
   * Each document is the tab's value under `config.dataConfig.titleKey`.
   * The full list is computed once and cached on `this.documents`.
   *
   * @return {string[]} Titles that represent the result
   */
  getRepresentativeDocuments() {
    this.documents ||= this.tabItems.map(
      tab => tab[this.config.dataConfig.titleKey]
    );
    // set a limit of 10 for now
    return this.documents.slice(0, 10);
  }

  /**
   * Returns the keywords and documents for the cluster, computing if needed.
   * Keywords are extracted from the first 3 documents contrasted against
   * `otherDocuments`. With fewer than two documents the keyword list is empty.
   * Keywords are cached after the first computation.
   *
   * @param {string[]} otherDocuments documents of other clusters that we'll compare against
   * @return {{keywords, documents: string[]}} keywords and documents that represent the cluster
   */
  getRepresentativeDocsAndKeywords(otherDocuments = []) {
    const representativeDocs = this.getRepresentativeDocuments();
    this.documents = representativeDocs;

    if (!this.keywords) {
      const ownText = representativeDocs.slice(0, 3).join(" ");
      const contrastText = otherDocuments.join(" ");
      this.keywords =
        representativeDocs.length > 1
          ? new KeywordExtractor().fitTransform([ownText, contrastText])
          : [];
    }

    return { keywords: this.keywords, documents: this.documents };
  }

  /**
   * Records which cluster is the anchor cluster.
   *
   * @param {number} index Position of the anchor cluster
   */
  setAnchorClusterIndex(index) {
    this.#anchorClusterIndex = index;
  }

  /**
   * Get the cluster we originally are grouping around (finding additional items)
   *
   * @returns {ClusterRepresentation|null} null when no anchor cluster has been set
   */
  getAnchorCluster() {
    return this.#anchorClusterIndex === -1
      ? null
      : this.clusterRepresentations[this.#anchorClusterIndex];
  }

  /**
   * Given the indices that we were clustering around, make sure they are all in the anchor cluster.
   * Our generic k-means clustering might have them in separate groups.
   * Anchor items found in other clusters are moved to the end of the anchor
   * cluster, and the cluster representations are rebuilt afterwards.
   *
   * @param {number[]} anchorIndices Item indices that must live in the anchor cluster
   */
  adjustClusterForAnchors(anchorIndices) {
    if (!anchorIndices.length) {
      return;
    }
    const anchors = new Set(anchorIndices);

    this.indices.forEach((members, clusterPosition) => {
      if (clusterPosition === this.#anchorClusterIndex) {
        return;
      }
      const remaining = [];
      for (const member of members) {
        if (anchors.has(member)) {
          this.indices[this.#anchorClusterIndex].push(member);
        } else {
          remaining.push(member);
        }
      }
      this.indices[clusterPosition] = remaining;
    });

    this._buildClusterRepresentations();
  }

  /**
   * Calls print() on every cluster representation
   */
  printClusters() {
    this.clusterRepresentations.forEach(cluster => cluster.print());
  }

  /**
   * Computes the inertia of the clustering: the sum over all clusters of each
   * cluster's total squared centroid distance.
   *
   * @returns {number}
   */
  getCentroidInertia() {
    return this.clusterRepresentations.reduce(
      (inertia, cluster) =>
        inertia + cluster.computeTotalSquaredCentroidDistance(),
      0
    );
  }

  /**
   * Converts the cluster representations to a flat list of tab copies, with a clusterID key in each
   * copy holding the id of the cluster it was part of. The original tabs are not modified.
   *
   * @returns {[Object]}
   */
  _flatMapItemsInClusters() {
    return this.clusterRepresentations.flatMap(cluster =>
      cluster.tabs.map(tab =>
        Object.assign({}, tab, { clusterID: cluster.clusterID })
      )
    );
  }

  /**
   * Get rand score which describes the accuracy versus a user labeled
   * annotation on the dataset. Requires the dataset to be labeled.
   *
   * @param labelKey Key in the tabs that represent a unique label ID for the cluster.
   * @returns {number} The rand score.
   */
  getRandScore(labelKey = "annotatedLabel") {
    return computeRandScore(
      this._flatMapItemsInClusters(),
      "clusterID",
      labelKey
    );
  }

  /**
   * Get accuracy stats for a specific labeled cluster.
   * The predicted cluster compared against is the one containing the first tab
   * labeled `clusterValue`. An error is thrown when no tab carries that label.
   *
   * @param labelKey Key in the tabs that represent a unique label ID for the cluster.
   * @param clusterValue is the label value of the cluster we are comparing
   * @returns The result of getAccuracyStats for the tallied true/false positives and negatives.
   */
  getAccuracyStatsForCluster(labelKey = "annotatedLabel", clusterValue) {
    const flatItems = this._flatMapItemsInClusters();
    const targetClusterId = flatItems.find(
      item => item[labelKey] === clusterValue
    ).clusterID;

    const tally = {
      truePositives: 0,
      trueNegatives: 0,
      falsePositives: 0,
      falseNegatives: 0,
    };

    for (const item of flatItems) {
      const hasLabel = item[labelKey] === clusterValue;
      const inTargetCluster = item.clusterID === targetClusterId;
      if (hasLabel) {
        if (inTargetCluster) {
          tally.truePositives++;
        } else {
          tally.falseNegatives++;
        }
      } else if (inTargetCluster) {
        tally.falsePositives++;
      } else {
        tally.trueNegatives++;
      }
    }

    return getAccuracyStats(tally);
  }
}
|
||
|
||
/**
 * Utility function to generate a random ID string made of upper-case
 * hexadecimal digits. Uses Math.random, so the result is not suitable
 * for anything security sensitive.
 *
 * @param len Length of the string
 * @returns {string}
 */
function genHexString(len) {
  const HEX_DIGITS = "0123456789ABCDEF";
  let id = "";
  // Each pass appends exactly one character, so the length doubles as the counter.
  while (id.length < len) {
    id += HEX_DIGITS[Math.floor(Math.random() * HEX_DIGITS.length)];
  }
  return id;
}
|
||
|
||
/**
 * A group of tabs together with their embeddings and the centroid of those embeddings.
 */
class EmbeddingCluster {
  /**
   * @param {object} options
   * @param options.tabs Tabs that belong to the cluster
   * @param options.embeddings Embeddings of the tabs; may be omitted
   * @param options.centroid Precomputed centroid. When omitted it is computed
   *   from `embeddings` (and left unset when there are no embeddings either).
   */
  constructor({ tabs, embeddings, centroid }) {
    this.embeddings = embeddings;
    this.centroid =
      centroid || (embeddings && computeCentroidFrom2DArray(this.embeddings));
    this.tabs = tabs;
  }

  /**
   * Sums the distance of every embedding from the cluster's centroid.
   * A cluster created without embeddings (which the constructor allows), or
   * with an empty embedding list, has a total of 0 instead of throwing.
   *
   * @returns {number} total of euclideanDistance(centroid, embedding, true) over all embeddings
   *   (the `true` flag presumably requests the squared distance; confirm in ClusterAlgos)
   */
  computeTotalSquaredCentroidDistance() {
    if (!this.embeddings?.length) {
      return 0;
    }
    return this.embeddings.reduce(
      (total, embedding) =>
        total + euclideanDistance(this.centroid, embedding, true),
      0
    );
  }

  /**
   * Returns number of items in the cluster
   *
   * @returns {int}
   */
  numItems() {
    return this.tabs.length;
  }
}
|
||
|
||
/**
 * A single cluster plus the metadata saved for it: topic labels, cached
 * representative text, keywords, documents and a random 10 character id.
 */
export class ClusterRepresentation extends EmbeddingCluster {
  /**
   * @param {object} options
   * @param options.tabs Tabs in the cluster
   * @param options.embeddings Embeddings of the tabs
   * @param options.centroid Optional precomputed centroid
   * @param options.config Cluster config; `config.dataConfig.titleKey` names the tab title field
   */
  constructor({ tabs, embeddings, centroid, config }) {
    super({ tabs, embeddings, centroid });
    this.config = config;
    // Labels and cached text start out unset.
    this.predictedTopicLabel = null;
    this.annotatedTopicLabel = null;
    this.userEditedTopicLabel = null;
    this.representativeText = null;
    this.keywords = null;
    this.documents = null;
    this.clusterID = genHexString(10);
  }

  /**
   * Returns the representative text for the cluster, generating and caching it on first use
   *
   * @returns {string}
   */
  getRepresentativeText() {
    this.representativeText ||= this._generateRepresentativeText();
    return this.representativeText;
  }

  /**
   * Builds representative text for the cluster.
   * In this initial implementation it is the titles of the first 3 tabs,
   * each preceded by a newline.
   *
   * @returns {string}
   * @private
   */
  _generateRepresentativeText() {
    const { titleKey } = this.config.dataConfig;
    return this.tabs
      .slice(0, 3)
      .map(tab => `\n${tab[titleKey]}`)
      .join("");
  }

  /**
   * Debugging hook; intentionally does nothing for now.
   */
  print() {
    // Add console log for debugging
  }
}
|