403 lines
13 KiB
JavaScript
403 lines
13 KiB
JavaScript
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
|
|
|
|
/**
 * MLSuggest helps with ML-based suggestions around intents and location.
 */
|
|
|
|
const lazy = {};
|
|
|
|
ChromeUtils.defineESModuleGetters(lazy, {
|
|
createEngine: "chrome://global/content/ml/EngineProcess.sys.mjs",
|
|
UrlbarPrefs: "resource:///modules/UrlbarPrefs.sys.mjs",
|
|
});
|
|
|
|
// Prepositions stripped from the end of the subject once the location has
// been removed from the query (e.g. "pizza in" -> "pizza").
const PREPOSITIONS = Object.freeze(["in", "at", "on", "for", "to", "near"]);

// Queries longer than this are rejected before any model is run.
const MAX_QUERY_LENGTH = 200;

// Punctuation that is glued onto the neighbouring NER token instead of being
// kept as a separate word when city/state names are assembled.
const NAME_PUNCTUATION = Object.freeze([".", "-", "'"]);

// Same list without ".": a token that follows "-" or "'" is appended to the
// previous token, whereas a token that follows "." starts a new word.
const NAME_PUNCTUATION_EXCEPT_DOT = Object.freeze(
  NAME_PUNCTUATION.filter(mark => mark !== ".")
);
|
|
|
|
/**
 * Class for handling ML-based suggestions using intent and NER models.
 *
 * Two engines are created through `createEngine()`: one text-classification
 * engine (intent) and one token-classification engine (NER). Created engines
 * are cached by their `featureId`; an engine whose `run()` fails is dropped
 * from the cache and re-created in the background.
 *
 * @class
 */
class _MLSuggest {
  // Created engines, keyed by the `featureId` of the options used to create
  // them. A missing key means "not initialized (or dropped after a failure)".
  #modelEngines = {};

  INTENT_OPTIONS = {
    taskName: "text-classification",
    featureId: "suggest-intent-classification",
    timeoutMS: -1,
    numThreads: 2,
  };

  NER_OPTIONS = {
    taskName: "token-classification",
    featureId: "suggest-NER",
    timeoutMS: -1,
    numThreads: 2,
  };

  // Helper to wrap createEngine for testing purpose
  createEngine(args) {
    return lazy.createEngine(args);
  }

  /**
   * Initializes the intent and NER models.
   *
   * Both engines are created in parallel; the returned promise rejects if
   * either creation fails.
   */
  async initialize() {
    await Promise.all([
      this.#initializeModelEngine(this.INTENT_OPTIONS),
      this.#initializeModelEngine(this.NER_OPTIONS),
    ]);
  }

  /**
   * Generates ML-based suggestions by finding intent, detecting entities, and
   * combining locations.
   *
   * @param {string} query
   *   The user's input query.
   * @returns {object | null}
   *   `null` if the query is not a string, is longer than MAX_QUERY_LENGTH, an
   *   engine is not initialized, or a model run fails. Otherwise an object
   *   with:
   *   - {string} intent
   *     The predicted intent label of the query. Possible values include:
   *       - 'information_intent': For queries seeking general information.
   *       - 'yelp_intent': For queries related to local businesses or services.
   *       - 'navigation_intent': For queries with navigation-related actions.
   *       - 'travel_intent': For queries showing travel-related interests.
   *       - 'purchase_intent': For queries with purchase or shopping intent.
   *       - 'weather_intent': For queries asking about weather or forecasts.
   *       - 'translation_intent': For queries seeking translations.
   *       - 'unknown': When the intent cannot be classified with confidence.
   *       - '' (empty string): Returned when the top intent's score does not
   *         exceed the `intentThreshold` pref.
   *   - {object} location
   *     The detected location from the query:
   *       - {string|null} city: The detected city, or `null` if none is found.
   *       - {string|null} state: The detected state, or `null` if none is found.
   *   - {string} subject
   *     The subject of the query after location is removed.
   *   - {object} metrics
   *     `{ intent, ner }`: the `metrics` property of each model's result.
   */
  async makeSuggestions(query) {
    // Reject non-strings (they would throw further down) and avoid a bunch of
    // work for very long strings.
    if (typeof query !== "string" || query.length > MAX_QUERY_LENGTH) {
      return null;
    }

    let intentRes;
    let nerResult;
    try {
      [intentRes, nerResult] = await Promise.all([
        this._findIntent(query),
        this._findNER(query),
      ]);
    } catch {
      // Per the documented contract, any model error yields `null`.
      return null;
    }

    if (!intentRes || !nerResult) {
      return null;
    }

    const locationResVal = this.#combineLocations(
      nerResult,
      lazy.UrlbarPrefs.get("nerThreshold")
    );

    const intentLabel = this.#applyIntentThreshold(
      intentRes,
      lazy.UrlbarPrefs.get("intentThreshold")
    );

    return {
      intent: intentLabel,
      location: locationResVal,
      subject: this.#findSubjectFromQuery(query, locationResVal),
      metrics: { intent: intentRes.metrics, ner: nerResult.metrics },
    };
  }

  /**
   * Shuts down all initialized engines.
   *
   * Every cached engine is terminated and removed from the cache even if an
   * earlier one fails to terminate. If any termination failed, the first
   * failure is rethrown once all engines have been processed.
   */
  async shutdown() {
    const failures = [];
    for (const [key, engine] of Object.entries(this.#modelEngines)) {
      try {
        // `engine` may be nullish if something cleared the slot without
        // deleting it; `terminate` itself is optional on the engine.
        await engine?.terminate?.();
      } catch (error) {
        failures.push(error);
      } finally {
        // Remove each engine after termination, successful or not.
        delete this.#modelEngines[key];
      }
    }

    if (failures.length) {
      throw failures[0];
    }
  }

  /**
   * Returns the cached engine for `options.featureId`, creating and caching
   * it first if needed.
   *
   * @param {object} options
   *   Engine options; `featureId` is used as the cache key.
   * @returns {Promise<object>}
   *   The engine.
   */
  async #initializeModelEngine(options) {
    const featureId = options.featureId;

    // uses cache if engine was used
    if (this.#modelEngines[featureId]) {
      return this.#modelEngines[featureId];
    }

    const engine = await this.createEngine(options);
    // Cache the engine
    this.#modelEngines[featureId] = engine;
    return engine;
  }

  /**
   * Drops a failed engine from the cache and re-creates it in the background.
   *
   * @param {object} engineOptions
   *   `INTENT_OPTIONS` or `NER_OPTIONS`.
   * @param {object} failedEngine
   *   The engine whose `run()` just failed. Nothing happens if the cache no
   *   longer holds this exact engine, so several concurrent failures of the
   *   same engine trigger only one re-creation.
   */
  #resetModelEngine(engineOptions, failedEngine) {
    const featureId = engineOptions.featureId;
    if (this.#modelEngines[featureId] !== failedEngine) {
      return;
    }

    delete this.#modelEngines[featureId];
    // Deliberately not awaited: the caller should not block on re-creation.
    // The rejection is handled here so it never becomes an unhandled one.
    this.#initializeModelEngine(engineOptions).catch(error => {
      console.error(
        `MLSuggest: failed to re-create engine "${featureId}"`,
        error
      );
    });
  }

  /**
   * Runs `query` through the cached engine described by `engineOptions`.
   *
   * @param {object} engineOptions
   *   `INTENT_OPTIONS` or `NER_OPTIONS`.
   * @param {string} query
   *   The user's input query.
   * @param {object} options
   *   The options for the engine pipeline.
   * @returns {Promise<object[] | null>}
   *   The engine's result, or `null` if the engine is not initialized or the
   *   run failed.
   */
  async #runEngine(engineOptions, query, options) {
    const engine = this.#modelEngines[engineOptions.featureId];
    if (!engine) {
      return null;
    }

    try {
      // The `await` is required: returning the bare promise would let a
      // rejection escape this try/catch.
      return await engine.run({ args: [query], options });
    } catch {
      // engine could timeout or fail, so remove that from cache
      // and reinitialize
      this.#resetModelEngine(engineOptions, engine);
      return null;
    }
  }

  /**
   * Finds the intent of the query using the intent classification model.
   * (This has been made public to enable testing)
   *
   * @param {string} query
   *   The user's input query.
   * @param {object} options
   *   The options for the engine pipeline
   * @returns {object[] | null}
   *   The intent results, or null if the model is not initialized or the run
   *   fails.
   */
  async _findIntent(query, options = {}) {
    return this.#runEngine(this.INTENT_OPTIONS, query, options);
  }

  /**
   * Finds named entities in the query using the NER model.
   * (This has been made public to enable testing)
   *
   * @param {string} query
   *   The user's input query.
   * @param {object} options
   *   The options for the engine pipeline
   * @returns {object[] | null}
   *   The NER results, or null if the model is not initialized or the run
   *   fails.
   */
  async _findNER(query, options = {}) {
    return this.#runEngine(this.NER_OPTIONS, query, options);
  }

  /**
   * Applies a confidence threshold to determine the intent label.
   *
   * Only the first entry of the result is considered. If its score is
   * strictly greater than the threshold, its label is returned; otherwise the
   * empty string is returned.
   *
   * @param {object[]} intentResult
   *   The result of the intent classification model, where each item includes
   *   a `label` and `score`.
   * @param {number} intentThreshold
   *   The confidence threshold for accepting the intent label.
   * @returns {string}
   *   The determined intent label, or '' if the threshold is not met.
   */
  #applyIntentThreshold(intentResult, intentThreshold) {
    const topIntent = intentResult[0];
    return topIntent?.score > intentThreshold ? topIntent.label : "";
  }

  /**
   * Combines location tokens detected by NER into separate city and state
   * components. This method processes city, state, and combined city-state
   * entities, returning an object with `city` and `state` fields.
   *
   * Handles the following entity types:
   * - B-CITY, I-CITY: Identifies city tokens.
   * - B-STATE, I-STATE: Identifies state tokens.
   * - B-CITYSTATE, I-CITYSTATE: Identifies tokens that represent a combined
   *   city and state. These are only used when no separate city or state
   *   token was accepted; the text is split at the first comma.
   *
   * @param {object[]} nerResult
   *   The NER results containing tokens and their corresponding entity labels.
   * @param {number} nerThreshold
   *   The confidence threshold for including entities. Tokens whose score is
   *   not above this threshold are ignored.
   * @returns {object}
   *   An object with `city` and `state` fields:
   *   - {string|null} city: The detected city, or `null` if no city is found.
   *   - {string|null} state: The detected state, or `null` if no state is found.
   */
  #combineLocations(nerResult, nerThreshold) {
    let cityResult = [];
    let stateResult = [];
    const cityStateResult = [];

    for (let i = 0; i < nerResult.length; i++) {
      const res = nerResult[i];
      switch (res.entity) {
        case "B-CITY":
        case "I-CITY":
          this.#processNERToken(res, cityResult, nerThreshold);
          break;
        case "B-STATE":
        case "I-STATE":
          this.#processNERToken(res, stateResult, nerThreshold);
          break;
        case "B-CITYSTATE":
        case "I-CITYSTATE":
          this.#processNERToken(res, cityStateResult, nerThreshold);
          break;
      }
    }

    // Handle city_state as combined and split into city and state
    if (cityStateResult.length && !cityResult.length && !stateResult.length) {
      // Anything after a second comma is ignored.
      const [cityPart, statePart] = cityStateResult.join(" ").split(",");
      const city = cityPart.trim();
      const state = statePart?.trim();
      cityResult = city ? [city] : [];
      stateResult = state ? [state] : [];
    }

    // Remove trailing punctuation from the last element if present
    this.#removePunctFromEndIfPresent(cityResult);
    this.#removePunctFromEndIfPresent(stateResult);

    // Return city and state as separate components if detected
    return {
      city: cityResult.join(" ").trim() || null,
      state: stateResult.join(" ").trim() || null,
    };
  }

  /**
   * Processes a token from the NER results, appending it to the provided result
   * array while handling wordpieces (e.g., "##"), punctuation, and
   * multi-token entities.
   *
   * - Appends wordpieces (starting with "##") to the last token in the array.
   * - Appends punctuation tokens (".", "-", "'") to the last token, and
   *   appends any token that directly follows "-" or "'".
   * - Every other token starts a new array element.
   *
   * @param {object} res
   *   The NER result token to process. Should include:
   *   - {string} word: The word or token from the NER output.
   *   - {number} score: The confidence score for the token.
   *   - {string} entity: The entity type label (e.g., "B-CITY", "I-STATE").
   * @param {string[]} resultArray
   *   The array to append the processed token. Typically `cityResult`,
   *   `stateResult`, or `cityStateResult`. Modified in place.
   * @param {number} nerThreshold
   *   The confidence threshold for including tokens. Tokens whose score is
   *   not above this threshold are ignored.
   */
  #processNERToken(res, resultArray, nerThreshold) {
    const { word, score } = res;

    // Skip low-confidence tokens, and tokens without usable text (which
    // would otherwise throw below).
    if (typeof word !== "string" || score <= nerThreshold) {
      return;
    }

    // Nothing to attach to yet: the token is stored as-is, even if it is a
    // "##" wordpiece or a punctuation mark.
    if (!resultArray.length) {
      resultArray.push(word);
      return;
    }

    const lastTokenIndex = resultArray.length - 1;
    const previousToken = resultArray[lastTokenIndex];

    // "##" prefix indicates that a token is continuation of a word
    // rather than a start of a new word.
    // reference -> https://github.com/google-research/bert/blob/master/tokenization.py#L314-L316
    if (word.startsWith("##")) {
      resultArray[lastTokenIndex] = previousToken + word.slice(2);
    } else if (
      NAME_PUNCTUATION.includes(word) ||
      NAME_PUNCTUATION_EXCEPT_DOT.includes(previousToken.slice(-1))
    ) {
      // Special handling for punctuation like ".", "-", or "'"
      resultArray[lastTokenIndex] = previousToken + word;
    } else {
      resultArray.push(word);
    }
  }

  /**
   * Removes trailing punctuation from the last element in the result array
   * if the last character matches any punctuation in `NAME_PUNCTUATION`.
   *
   * At most one character is removed.
   *
   * @param {string[]} resultArray
   *   An array of strings representing detected entities (e.g., cities or states).
   *   The array is modified in place if the last element ends with punctuation.
   */
  #removePunctFromEndIfPresent(resultArray) {
    if (!resultArray.length) {
      return;
    }

    const lastTokenIndex = resultArray.length - 1;
    const lastToken = resultArray[lastTokenIndex];
    if (NAME_PUNCTUATION.includes(lastToken.slice(-1))) {
      resultArray[lastTokenIndex] = lastToken.slice(0, -1);
    }
  }

  /**
   * Derives the subject of the query by removing the detected location.
   *
   * Runs of non-word characters are collapsed to single spaces in both the
   * query and the location values before matching, so punctuation differences
   * (e.g. "st. louis" vs "st louis") do not prevent removal. Matching is
   * case-sensitive and on word boundaries.
   *
   * @param {string} query
   *   The user's input query.
   * @param {object|null} location
   *   The `{ city, state }` object produced by `#combineLocations`.
   * @returns {string}
   *   The unmodified query if no city or state was detected; otherwise the
   *   remaining words joined by single spaces, without trailing prepositions.
   */
  #findSubjectFromQuery(query, location) {
    // If location is null or no city/state, return the entire query
    if (!location || (!location.city && !location.state)) {
      return query;
    }

    // Location values are reduced to word characters and spaces here, so they
    // are safe to embed in the regular expression below without escaping.
    const locValues = Object.values(location)
      .map(loc => loc?.replace(/\W+/g, " "))
      .filter(loc => loc?.trim());

    // Regular expression to remove locations
    // This handles single & multi-worded cities/states
    const locPattern = locValues.map(loc => `\\b${loc}\\b`).join("|");
    const locRegex = new RegExp(locPattern, "g");

    // Remove locations, trim whitespace, and split words
    const words = query
      .replace(/\W+/g, " ")
      .replace(locRegex, "")
      .split(/\W+/)
      .filter(word => !!word.length);

    return this.#cleanSubject(words).join(" ");
  }

  /**
   * Removes trailing prepositions (see `PREPOSITIONS`) from a list of words.
   *
   * @param {string[]} words
   *   The subject words. Modified in place.
   * @returns {string[]}
   *   The same array, without trailing prepositions.
   */
  #cleanSubject(words) {
    while (words.length && PREPOSITIONS.includes(words[words.length - 1])) {
      words.pop();
    }
    return words;
  }
}
|
|
|
|
// Export the singleton instance. The binding is never reassigned, so it is
// declared `const`.
export const MLSuggest = new _MLSuggest();
|