summaryrefslogtreecommitdiffstats
path: root/litecli/sqlcompleter.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2020-08-14 16:58:23 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2021-02-07 10:11:48 +0000
commit4edd467b28c895483cd5468d51d1c6824a21715a (patch)
tree04a4f32d617905acfc23653025b6e8d3899f51c6 /litecli/sqlcompleter.py
parentInitial commit. (diff)
downloadlitecli-4edd467b28c895483cd5468d51d1c6824a21715a.tar.xz
litecli-4edd467b28c895483cd5468d51d1c6824a21715a.zip
Adding upstream version 1.5.0.upstream/1.5.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'litecli/sqlcompleter.py')
-rw-r--r--litecli/sqlcompleter.py612
1 files changed, 612 insertions, 0 deletions
diff --git a/litecli/sqlcompleter.py b/litecli/sqlcompleter.py
new file mode 100644
index 0000000..64ca352
--- /dev/null
+++ b/litecli/sqlcompleter.py
@@ -0,0 +1,612 @@
+from __future__ import print_function
+from __future__ import unicode_literals
+import logging
+from re import compile, escape
+from collections import Counter
+
+from prompt_toolkit.completion import Completer, Completion
+
+from .packages.completion_engine import suggest_type
+from .packages.parseutils import last_word
+from .packages.special.iocommands import favoritequeries
+from .packages.filepaths import parse_path, complete_path, suggest_path
+
+_logger = logging.getLogger(__name__)
+
+
class SQLCompleter(Completer):
    """prompt_toolkit completer for SQLite-flavoured SQL.

    Produces completions for keywords, built-in functions, and — once fed
    via the various ``extend_*`` methods — databases, tables, views,
    user-defined functions, columns, aliases, special commands, favorite
    queries, table output formats and file names.
    """

    # SQL keywords (including a few multi-word phrases such as
    # "ORDER BY" and "NULLS FIRST") offered for keyword completion.
    # The individual words of each entry also seed the reserved_words
    # set used by escape_name().
    keywords = [
        "ABORT",
        "ACTION",
        "ADD",
        "AFTER",
        "ALL",
        "ALTER",
        "ANALYZE",
        "AND",
        "AS",
        "ASC",
        "ATTACH",
        "AUTOINCREMENT",
        "BEFORE",
        "BEGIN",
        "BETWEEN",
        "BIGINT",
        "BLOB",
        "BOOLEAN",
        "BY",
        "CASCADE",
        "CASE",
        "CAST",
        "CHARACTER",
        "CHECK",
        "CLOB",
        "COLLATE",
        "COLUMN",
        "COMMIT",
        "CONFLICT",
        "CONSTRAINT",
        "CREATE",
        "CROSS",
        "CURRENT",
        "CURRENT_DATE",
        "CURRENT_TIME",
        "CURRENT_TIMESTAMP",
        "DATABASE",
        "DATE",
        "DATETIME",
        "DECIMAL",
        "DEFAULT",
        "DEFERRABLE",
        "DEFERRED",
        "DELETE",
        "DETACH",
        "DISTINCT",
        "DO",
        "DOUBLE PRECISION",
        "DOUBLE",
        "DROP",
        "EACH",
        "ELSE",
        "END",
        "ESCAPE",
        "EXCEPT",
        "EXCLUSIVE",
        "EXISTS",
        "EXPLAIN",
        "FAIL",
        "FILTER",
        "FLOAT",
        "FOLLOWING",
        "FOR",
        "FOREIGN",
        "FROM",
        "FULL",
        "GLOB",
        "GROUP",
        "HAVING",
        "IF",
        "IGNORE",
        "IMMEDIATE",
        "IN",
        "INDEX",
        "INDEXED",
        "INITIALLY",
        "INNER",
        "INSERT",
        "INSTEAD",
        "INT",
        "INT2",
        "INT8",
        "INTEGER",
        "INTERSECT",
        "INTO",
        "IS",
        "ISNULL",
        "JOIN",
        "KEY",
        "LEFT",
        "LIKE",
        "LIMIT",
        "MATCH",
        "MEDIUMINT",
        "NATIVE CHARACTER",
        "NATURAL",
        "NCHAR",
        "NO",
        "NOT",
        "NOTHING",
        "NULL",
        "NULLS FIRST",
        "NULLS LAST",
        "NUMERIC",
        "NVARCHAR",
        "OF",
        "OFFSET",
        "ON",
        "OR",
        "ORDER BY",
        "OUTER",
        "OVER",
        "PARTITION",
        "PLAN",
        "PRAGMA",
        "PRECEDING",
        "PRIMARY",
        "QUERY",
        "RAISE",
        "RANGE",
        "REAL",
        "RECURSIVE",
        "REFERENCES",
        "REGEXP",
        "REINDEX",
        "RELEASE",
        "RENAME",
        "REPLACE",
        "RESTRICT",
        "RIGHT",
        "ROLLBACK",
        "ROW",
        "ROWS",
        "SAVEPOINT",
        "SELECT",
        "SET",
        "SMALLINT",
        "TABLE",
        "TEMP",
        "TEMPORARY",
        "TEXT",
        "THEN",
        "TINYINT",
        "TO",
        "TRANSACTION",
        "TRIGGER",
        "UNBOUNDED",
        "UNION",
        "UNIQUE",
        "UNSIGNED BIG INT",
        "UPDATE",
        "USING",
        "VACUUM",
        "VALUES",
        "VARCHAR",
        "VARYING CHARACTER",
        "VIEW",
        "VIRTUAL",
        "WHEN",
        "WHERE",
        "WINDOW",
        "WITH",
        "WITHOUT",
    ]

    # Built-in function names (scalar, aggregate, window and JSON_*)
    # offered for function completion and also consulted by
    # escape_name() when deciding whether a name needs quoting.
    functions = [
        "ABS",
        "AVG",
        "CHANGES",
        "CHAR",
        "COALESCE",
        "COUNT",
        "CUME_DIST",
        "DATE",
        "DATETIME",
        "DENSE_RANK",
        "GLOB",
        "GROUP_CONCAT",
        "HEX",
        "IFNULL",
        "INSTR",
        "JSON",
        "JSON_ARRAY",
        "JSON_ARRAY_LENGTH",
        "JSON_EACH",
        "JSON_EXTRACT",
        "JSON_GROUP_ARRAY",
        "JSON_GROUP_OBJECT",
        "JSON_INSERT",
        "JSON_OBJECT",
        "JSON_PATCH",
        "JSON_QUOTE",
        "JSON_REMOVE",
        "JSON_REPLACE",
        "JSON_SET",
        "JSON_TREE",
        "JSON_TYPE",
        "JSON_VALID",
        "JULIANDAY",
        "LAG",
        "LAST_INSERT_ROWID",
        "LENGTH",
        "LIKELIHOOD",
        "LIKELY",
        "LOAD_EXTENSION",
        "LOWER",
        "LTRIM",
        "MAX",
        "MIN",
        "NTILE",
        "NULLIF",
        "PERCENT_RANK",
        "PRINTF",
        "QUOTE",
        "RANDOM",
        "RANDOMBLOB",
        "RANK",
        "REPLACE",
        "ROUND",
        "ROW_NUMBER",
        "RTRIM",
        "SOUNDEX",
        "SQLITE_COMPILEOPTION_GET",
        "SQLITE_COMPILEOPTION_USED",
        "SQLITE_OFFSET",
        "SQLITE_SOURCE_ID",
        "SQLITE_VERSION",
        "STRFTIME",
        "SUBSTR",
        "SUM",
        "TIME",
        "TOTAL",
        "TOTAL_CHANGES",
        "TRIM",
    ]
+
+ def __init__(self, supported_formats=(), keyword_casing="auto"):
+ super(self.__class__, self).__init__()
+ self.reserved_words = set()
+ for x in self.keywords:
+ self.reserved_words.update(x.split())
+ self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$")
+
+ self.special_commands = []
+ self.table_formats = supported_formats
+ if keyword_casing not in ("upper", "lower", "auto"):
+ keyword_casing = "auto"
+ self.keyword_casing = keyword_casing
+ self.reset_completions()
+
+ def escape_name(self, name):
+ if name and (
+ (not self.name_pattern.match(name))
+ or (name.upper() in self.reserved_words)
+ or (name.upper() in self.functions)
+ ):
+ name = "`%s`" % name
+
+ return name
+
+ def unescape_name(self, name):
+ """Unquote a string."""
+ if name and name[0] == '"' and name[-1] == '"':
+ name = name[1:-1]
+
+ return name
+
+ def escaped_names(self, names):
+ return [self.escape_name(name) for name in names]
+
+ def extend_special_commands(self, special_commands):
+ # Special commands are not part of all_completions since they can only
+ # be at the beginning of a line.
+ self.special_commands.extend(special_commands)
+
+ def extend_database_names(self, databases):
+ self.databases.extend(databases)
+
+ def extend_keywords(self, additional_keywords):
+ self.keywords.extend(additional_keywords)
+ self.all_completions.update(additional_keywords)
+
+ def extend_schemata(self, schema):
+ if schema is None:
+ return
+ metadata = self.dbmetadata["tables"]
+ metadata[schema] = {}
+
+ # dbmetadata.values() are the 'tables' and 'functions' dicts
+ for metadata in self.dbmetadata.values():
+ metadata[schema] = {}
+ self.all_completions.update(schema)
+
+ def extend_relations(self, data, kind):
+ """Extend metadata for tables or views
+
+ :param data: list of (rel_name, ) tuples
+ :param kind: either 'tables' or 'views'
+ :return:
+ """
+ # 'data' is a generator object. It can throw an exception while being
+ # consumed. This could happen if the user has launched the app without
+ # specifying a database name. This exception must be handled to prevent
+ # crashing.
+ try:
+ data = [self.escaped_names(d) for d in data]
+ except Exception:
+ data = []
+
+ # dbmetadata['tables'][$schema_name][$table_name] should be a list of
+ # column names. Default to an asterisk
+ metadata = self.dbmetadata[kind]
+ for relname in data:
+ try:
+ metadata[self.dbname][relname[0]] = ["*"]
+ except KeyError:
+ _logger.error(
+ "%r %r listed in unrecognized schema %r",
+ kind,
+ relname[0],
+ self.dbname,
+ )
+ self.all_completions.add(relname[0])
+
+ def extend_columns(self, column_data, kind):
+ """Extend column metadata
+
+ :param column_data: list of (rel_name, column_name) tuples
+ :param kind: either 'tables' or 'views'
+ :return:
+ """
+ # 'column_data' is a generator object. It can throw an exception while
+ # being consumed. This could happen if the user has launched the app
+ # without specifying a database name. This exception must be handled to
+ # prevent crashing.
+ try:
+ column_data = [self.escaped_names(d) for d in column_data]
+ except Exception:
+ column_data = []
+
+ metadata = self.dbmetadata[kind]
+ for relname, column in column_data:
+ metadata[self.dbname][relname].append(column)
+ self.all_completions.add(column)
+
+ def extend_functions(self, func_data):
+ # 'func_data' is a generator object. It can throw an exception while
+ # being consumed. This could happen if the user has launched the app
+ # without specifying a database name. This exception must be handled to
+ # prevent crashing.
+ try:
+ func_data = [self.escaped_names(d) for d in func_data]
+ except Exception:
+ func_data = []
+
+ # dbmetadata['functions'][$schema_name][$function_name] should return
+ # function metadata.
+ metadata = self.dbmetadata["functions"]
+
+ for func in func_data:
+ metadata[self.dbname][func[0]] = None
+ self.all_completions.add(func[0])
+
    def set_dbname(self, dbname):
        # Remember the current database name; subsequently learned
        # relations/functions are filed under it in dbmetadata.
        self.dbname = dbname
+
+ def reset_completions(self):
+ self.databases = []
+ self.dbname = ""
+ self.dbmetadata = {"tables": {}, "views": {}, "functions": {}}
+ self.all_completions = set(self.keywords + self.functions)
+
+ @staticmethod
+ def find_matches(
+ text,
+ collection,
+ start_only=False,
+ fuzzy=True,
+ casing=None,
+ punctuations="most_punctuations",
+ ):
+ """Find completion matches for the given text.
+
+ Given the user's input text and a collection of available
+ completions, find completions matching the last word of the
+ text.
+
+ If `start_only` is True, the text will match an available
+ completion only at the beginning. Otherwise, a completion is
+ considered a match if the text appears anywhere within it.
+
+ yields prompt_toolkit Completion instances for any matches found
+ in the collection of available completions.
+ """
+ last = last_word(text, include=punctuations)
+ text = last.lower()
+
+ completions = []
+
+ if fuzzy:
+ regex = ".*?".join(map(escape, text))
+ pat = compile("(%s)" % regex)
+ for item in sorted(collection):
+ r = pat.search(item.lower())
+ if r:
+ completions.append((len(r.group()), r.start(), item))
+ else:
+ match_end_limit = len(text) if start_only else None
+ for item in sorted(collection):
+ match_point = item.lower().find(text, 0, match_end_limit)
+ if match_point >= 0:
+ completions.append((len(text), match_point, item))
+
+ if casing == "auto":
+ casing = "lower" if last and last[-1].islower() else "upper"
+
+ def apply_case(kw):
+ if casing == "upper":
+ return kw.upper()
+ return kw.lower()
+
+ return (
+ Completion(z if casing is None else apply_case(z), -len(text))
+ for x, y, z in sorted(completions)
+ )
+
    def get_completions(self, document, complete_event):
        """Return completions for the current cursor position.

        Asks the completion engine what kinds of object make sense at
        the cursor (columns, tables, keywords, ...) and gathers matches
        for each suggested kind in turn.

        :param document: prompt_toolkit Document for the current input
        :param complete_event: prompt_toolkit CompleteEvent (unused)
        :return: list of prompt_toolkit Completion objects
        """
        word_before_cursor = document.get_word_before_cursor(WORD=True)
        completions = []
        suggestions = suggest_type(document.text, document.text_before_cursor)

        for suggestion in suggestions:

            _logger.debug("Suggestion type: %r", suggestion["type"])

            if suggestion["type"] == "column":
                tables = suggestion["tables"]
                _logger.debug("Completion column scope: %r", tables)
                scoped_cols = self.populate_scoped_cols(tables)
                if suggestion.get("drop_unique"):
                    # drop_unique is used for 'tb11 JOIN tbl2 USING (...'
                    # which should suggest only columns that appear in more than
                    # one table
                    scoped_cols = [
                        col
                        for (col, count) in Counter(scoped_cols).items()
                        if count > 1 and col != "*"
                    ]

                cols = self.find_matches(word_before_cursor, scoped_cols)
                completions.extend(cols)

            elif suggestion["type"] == "function":
                # suggest user-defined functions using substring matching
                funcs = self.populate_schema_objects(suggestion["schema"], "functions")
                user_funcs = self.find_matches(word_before_cursor, funcs)
                completions.extend(user_funcs)

                # suggest hardcoded functions using startswith matching only if
                # there is no schema qualifier. If a schema qualifier is
                # present it probably denotes a table.
                # eg: SELECT * FROM users u WHERE u.
                if not suggestion["schema"]:
                    predefined_funcs = self.find_matches(
                        word_before_cursor,
                        self.functions,
                        start_only=True,
                        fuzzy=False,
                        casing=self.keyword_casing,
                    )
                    completions.extend(predefined_funcs)

            elif suggestion["type"] == "table":
                tables = self.populate_schema_objects(suggestion["schema"], "tables")
                tables = self.find_matches(word_before_cursor, tables)
                completions.extend(tables)

            elif suggestion["type"] == "view":
                views = self.populate_schema_objects(suggestion["schema"], "views")
                views = self.find_matches(word_before_cursor, views)
                completions.extend(views)

            elif suggestion["type"] == "alias":
                # Table aliases collected from the current statement.
                aliases = suggestion["aliases"]
                aliases = self.find_matches(word_before_cursor, aliases)
                completions.extend(aliases)

            elif suggestion["type"] == "database":
                dbs = self.find_matches(word_before_cursor, self.databases)
                completions.extend(dbs)

            elif suggestion["type"] == "keyword":
                # Keywords match at the start only and honor the
                # configured keyword casing.
                keywords = self.find_matches(
                    word_before_cursor,
                    self.keywords,
                    start_only=True,
                    fuzzy=False,
                    casing=self.keyword_casing,
                    punctuations="many_punctuations",
                )
                completions.extend(keywords)

            elif suggestion["type"] == "special":
                special = self.find_matches(
                    word_before_cursor,
                    self.special_commands,
                    start_only=True,
                    fuzzy=False,
                    punctuations="many_punctuations",
                )
                completions.extend(special)
            elif suggestion["type"] == "favoritequery":
                queries = self.find_matches(
                    word_before_cursor,
                    favoritequeries.list(),
                    start_only=False,
                    fuzzy=True,
                )
                completions.extend(queries)
            elif suggestion["type"] == "table_format":
                formats = self.find_matches(
                    word_before_cursor, self.table_formats, start_only=True, fuzzy=False
                )
                completions.extend(formats)
            elif suggestion["type"] == "file_name":
                file_names = self.find_files(word_before_cursor)
                completions.extend(file_names)

        return completions
+
+ def find_files(self, word):
+ """Yield matching directory or file names.
+
+ :param word:
+ :return: iterable
+
+ """
+ base_path, last_path, position = parse_path(word)
+ paths = suggest_path(word)
+ for name in sorted(paths):
+ suggestion = complete_path(name, last_path)
+ if suggestion:
+ yield Completion(suggestion, position)
+
+ def populate_scoped_cols(self, scoped_tbls):
+ """Find all columns in a set of scoped_tables
+ :param scoped_tbls: list of (schema, table, alias) tuples
+ :return: list of column names
+ """
+ columns = []
+ meta = self.dbmetadata
+
+ for tbl in scoped_tbls:
+ # A fully qualified schema.relname reference or default_schema
+ # DO NOT escape schema names.
+ schema = tbl[0] or self.dbname
+ relname = tbl[1]
+ escaped_relname = self.escape_name(tbl[1])
+
+ # We don't know if schema.relname is a table or view. Since
+ # tables and views cannot share the same name, we can check one
+ # at a time
+ try:
+ columns.extend(meta["tables"][schema][relname])
+
+ # Table exists, so don't bother checking for a view
+ continue
+ except KeyError:
+ try:
+ columns.extend(meta["tables"][schema][escaped_relname])
+ # Table exists, so don't bother checking for a view
+ continue
+ except KeyError:
+ pass
+
+ try:
+ columns.extend(meta["views"][schema][relname])
+ except KeyError:
+ pass
+
+ return columns
+
+ def populate_schema_objects(self, schema, obj_type):
+ """Returns list of tables or functions for a (optional) schema"""
+ metadata = self.dbmetadata[obj_type]
+ schema = schema or self.dbname
+
+ try:
+ objects = metadata[schema].keys()
+ except KeyError:
+ # schema doesn't exist
+ objects = []
+
+ return objects