diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2021-02-08 10:31:05 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2021-02-08 10:31:05 +0000 |
commit | 6884720fae8a2622b14e93d9e35ca5fcc2283b40 (patch) | |
tree | df6f736bb623cdd7932bbe2256101a6ac4ef7f35 /pgcli/packages | |
parent | Initial commit. (diff) | |
download | pgcli-6884720fae8a2622b14e93d9e35ca5fcc2283b40.tar.xz pgcli-6884720fae8a2622b14e93d9e35ca5fcc2283b40.zip |
Adding upstream version 3.1.0.upstream/3.1.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'pgcli/packages')
-rw-r--r-- | pgcli/packages/__init__.py | 0 | ||||
-rw-r--r-- | pgcli/packages/parseutils/__init__.py | 22 | ||||
-rw-r--r-- | pgcli/packages/parseutils/ctes.py | 141 | ||||
-rw-r--r-- | pgcli/packages/parseutils/meta.py | 170 | ||||
-rw-r--r-- | pgcli/packages/parseutils/tables.py | 170 | ||||
-rw-r--r-- | pgcli/packages/parseutils/utils.py | 140 | ||||
-rw-r--r-- | pgcli/packages/pgliterals/__init__.py | 0 | ||||
-rw-r--r-- | pgcli/packages/pgliterals/main.py | 15 | ||||
-rw-r--r-- | pgcli/packages/pgliterals/pgliterals.json | 629 | ||||
-rw-r--r-- | pgcli/packages/prioritization.py | 51 | ||||
-rw-r--r-- | pgcli/packages/prompt_utils.py | 35 | ||||
-rw-r--r-- | pgcli/packages/sqlcompletion.py | 608 |
12 files changed, 1981 insertions, 0 deletions
diff --git a/pgcli/packages/__init__.py b/pgcli/packages/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/pgcli/packages/__init__.py diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py new file mode 100644 index 0000000..a11e7bf --- /dev/null +++ b/pgcli/packages/parseutils/__init__.py @@ -0,0 +1,22 @@ +import sqlparse + + +def query_starts_with(query, prefixes): + """Check if the query starts with any item from *prefixes*.""" + prefixes = [prefix.lower() for prefix in prefixes] + formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() + return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + + +def queries_start_with(queries, prefixes): + """Check if any queries start with any item from *prefixes*.""" + for query in sqlparse.split(queries): + if query and query_starts_with(query, prefixes) is True: + return True + return False + + +def is_destructive(queries): + """Returns if any of the queries in *queries* is destructive.""" + keywords = ("drop", "shutdown", "delete", "truncate", "alter") + return queries_start_with(queries, keywords) diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py new file mode 100644 index 0000000..e1f9088 --- /dev/null +++ b/pgcli/packages/parseutils/ctes.py @@ -0,0 +1,141 @@ +from sqlparse import parse +from sqlparse.tokens import Keyword, CTE, DML +from sqlparse.sql import Identifier, IdentifierList, Parenthesis +from collections import namedtuple +from .meta import TableMetadata, ColumnMetadata + + +# TableExpression is a namedtuple representing a CTE, used internally +# name: cte alias assigned in the query +# columns: list of column names +# start: index into the original string of the left parens starting the CTE +# stop: index into the original string of the right parens ending the CTE +TableExpression = namedtuple("TableExpression", "name columns start stop") + + +def isolate_query_ctes(full_text, text_before_cursor): + """Simplify a query by converting CTEs into table metadata objects""" + + if not full_text or not full_text.strip(): + return full_text, text_before_cursor, tuple() + + ctes, remainder = extract_ctes(full_text) + if not ctes: + return full_text, text_before_cursor, () + + current_position = len(text_before_cursor) + meta = [] + + for cte in ctes: + if cte.start < current_position < cte.stop: + # Currently editing a cte - treat its body as the current full_text + text_before_cursor = full_text[cte.start : current_position] + full_text = full_text[cte.start : cte.stop] + return full_text, text_before_cursor, meta + + # Append this cte to the list of available table metadata + cols = (ColumnMetadata(name, None, ()) for name in cte.columns) + meta.append(TableMetadata(cte.name, cols)) + + # Editing past the last cte (ie the main body of the query) + full_text = full_text[ctes[-1].stop :] + text_before_cursor = text_before_cursor[ctes[-1].stop : current_position] + + return full_text, text_before_cursor, tuple(meta) + + +def extract_ctes(sql): + """Extract constant table expresseions from a query + + Returns tuple (ctes, remainder_sql) + + ctes is a list of TableExpression namedtuples + remainder_sql is the text from the original query after the CTEs have + been stripped. + """ + + p = parse(sql)[0] + + # Make sure the first meaningful token is "WITH" which is necessary to + # define CTEs + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return [], sql + + # Get the next (meaningful) token, which should be the first CTE + idx, tok = p.token_next(idx) + if not tok: + return ([], "") + start_pos = token_start_pos(p.tokens, idx) + ctes = [] + + if isinstance(tok, IdentifierList): + # Multiple ctes + for t in tok.get_identifiers(): + cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t)) + cte = get_cte_from_token(t, start_pos + cte_start_offset) + if not cte: + continue + ctes.append(cte) + elif isinstance(tok, Identifier): + # A single CTE + cte = get_cte_from_token(tok, start_pos) + if cte: + ctes.append(cte) + + idx = p.token_index(tok) + 1 + + # Collapse everything after the ctes into a remainder query + remainder = "".join(str(tok) for tok in p.tokens[idx:]) + + return ctes, remainder + + +def get_cte_from_token(tok, pos0): + cte_name = tok.get_real_name() + if not cte_name: + return None + + # Find the start position of the opening parens enclosing the cte body + idx, parens = tok.token_next_by(Parenthesis) + if not parens: + return None + + start_pos = pos0 + token_start_pos(tok.tokens, idx) + cte_len = len(str(parens)) # includes parens + stop_pos = start_pos + cte_len + + column_names = extract_column_names(parens) + + return TableExpression(cte_name, column_names, start_pos, stop_pos) + + +def extract_column_names(parsed): + # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE + idx, tok = parsed.token_next_by(t=DML) + tok_val = tok and tok.value.lower() + + if tok_val in ("insert", "update", "delete"): + # Jump ahead to the RETURNING clause where the list of column names is + idx, tok = parsed.token_next_by(idx, (Keyword, "returning")) + elif not tok_val == "select": + # Must be invalid CTE + return () + + # The next token should be either a column name, or a list of column names + idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True) + return tuple(t.get_name() for t in _identifiers(tok)) + + +def token_start_pos(tokens, idx): + return sum(len(str(t)) for t in tokens[:idx]) + + +def _identifiers(tok): + if isinstance(tok, IdentifierList): + for t in tok.get_identifiers(): + # NB: IdentifierList.get_identifiers() can return non-identifiers! + if isinstance(t, Identifier): + yield t + elif isinstance(tok, Identifier): + yield tok diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py new file mode 100644 index 0000000..108c01a --- /dev/null +++ b/pgcli/packages/parseutils/meta.py @@ -0,0 +1,170 @@ +from collections import namedtuple + +_ColumnMetadata = namedtuple( + "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"] +) + + +def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False): + return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default) + + +ForeignKey = namedtuple( + "ForeignKey", + [ + "parentschema", + "parenttable", + "parentcolumn", + "childschema", + "childtable", + "childcolumn", + ], +) +TableMetadata = namedtuple("TableMetadata", "name columns") + + +def parse_defaults(defaults_string): + """Yields default values for a function, given the string provided by + pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)""" + if not defaults_string: + return + current = "" + in_quote = None + for char in defaults_string: + if current == "" and char == " ": + # Skip space after comma separating default expressions + continue + if char == '"' or char == "'": + if in_quote and char == in_quote: + # End quote + in_quote = None + elif not in_quote: + # Begin quote + in_quote = char + elif char == "," and not in_quote: + # End of expression + yield current + current = "" + continue + current += char + yield current + + +class FunctionMetadata(object): + def __init__( + self, + schema_name, + func_name, + arg_names, + arg_types, + arg_modes, + return_type, + is_aggregate, + is_window, + is_set_returning, + is_extension, + arg_defaults, + ): + """Class for describing a postgresql function""" + + self.schema_name = schema_name + self.func_name = func_name + + self.arg_modes = tuple(arg_modes) if arg_modes else None + self.arg_names = tuple(arg_names) if arg_names else None + + # Be flexible in not requiring arg_types -- use None as a placeholder + # for each arg. (Used for compatibility with old versions of postgresql + # where such info is hard to get. + if arg_types: + self.arg_types = tuple(arg_types) + elif arg_modes: + self.arg_types = tuple([None] * len(arg_modes)) + elif arg_names: + self.arg_types = tuple([None] * len(arg_names)) + else: + self.arg_types = None + + self.arg_defaults = tuple(parse_defaults(arg_defaults)) + + self.return_type = return_type.strip() + self.is_aggregate = is_aggregate + self.is_window = is_window + self.is_set_returning = is_set_returning + self.is_extension = bool(is_extension) + self.is_public = self.schema_name and self.schema_name == "public" + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + def _signature(self): + return ( + self.schema_name, + self.func_name, + self.arg_names, + self.arg_types, + self.arg_modes, + self.return_type, + self.is_aggregate, + self.is_window, + self.is_set_returning, + self.is_extension, + self.arg_defaults, + ) + + def __hash__(self): + return hash(self._signature()) + + def __repr__(self): + return ( + "%s(schema_name=%r, func_name=%r, arg_names=%r, " + "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, " + "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)" + ) % ((self.__class__.__name__,) + self._signature()) + + def has_variadic(self): + return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes) + + def args(self): + """Returns a list of input-parameter ColumnMetadata namedtuples.""" + if not self.arg_names: + return [] + modes = self.arg_modes or ["i"] * len(self.arg_names) + args = [ + (name, typ) + for name, typ, mode in zip(self.arg_names, self.arg_types, modes) + if mode in ("i", "b", "v") # IN, INOUT, VARIADIC + ] + + def arg(name, typ, num): + num_args = len(args) + num_defaults = len(self.arg_defaults) + has_default = num + num_defaults >= num_args + default = ( + self.arg_defaults[num - num_args + num_defaults] + if has_default + else None + ) + return ColumnMetadata(name, typ, [], default, has_default) + + return [arg(name, typ, num) for num, (name, typ) in enumerate(args)] + + def fields(self): + """Returns a list of output-field ColumnMetadata namedtuples""" + + if self.return_type.lower() == "void": + return [] + elif not self.arg_modes: + # For functions without output parameters, the function name + # is used as the name of the output column. + # E.g. 'SELECT unnest FROM unnest(...);' + return [ColumnMetadata(self.func_name, self.return_type, [])] + + return [ + ColumnMetadata(name, typ, []) + for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes) + if mode in ("o", "b", "t") + ] # OUT, INOUT, TABLE diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py new file mode 100644 index 0000000..0ec3e69 --- /dev/null +++ b/pgcli/packages/parseutils/tables.py @@ -0,0 +1,170 @@ +import sqlparse +from collections import namedtuple +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation + +TableReference = namedtuple( + "TableReference", ["schema", "name", "alias", "is_function"] +) +TableReference.ref = property( + lambda self: self.alias + or ( + self.name + if self.name.islower() or self.name[0] == '"' + else '"' + self.name + '"' + ) +) + + +# This code is borrowed from sqlparse example script. +# <url> +def is_subselect(parsed): + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ( + "SELECT", + "INSERT", + "UPDATE", + "CREATE", + "DELETE", + ): + return True + return False + + +def _identifier_is_function(identifier): + return any(isinstance(t, Function) for t in identifier.tokens) + + +def extract_from_part(parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + for x in extract_from_part(item, stop_at_punctuation): + yield x + elif stop_at_punctuation and item.ttype is Punctuation: + return + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # `return`. So we need to ignore the keyword if the keyword + # FROM. + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif ( + item.ttype is Keyword + and (not item.value.upper() == "FROM") + and (not item.value.upper().endswith("JOIN")) + ): + tbl_prefix_seen = False + else: + yield item + elif item.ttype is Keyword or item.ttype is Keyword.DML: + item_val = item.value.upper() + if ( + item_val + in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + ) + or item_val.endswith("JOIN") + ): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if identifier.ttype is Keyword and identifier.value.upper() == "FROM": + tbl_prefix_seen = True + break + + +def extract_table_identifiers(token_stream, allow_functions=True): + """yields tuples of TableReference namedtuples""" + + # We need to do some massaging of the names because postgres is case- + # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is) + def parse_identifier(item): + name = item.get_real_name() + schema_name = item.get_parent_name() + alias = item.get_alias() + if not name: + schema_name = None + name = item.get_name() + alias = alias or name + schema_quoted = schema_name and item.value[0] == '"' + if schema_name and not schema_quoted: + schema_name = schema_name.lower() + quote_count = item.value.count('"') + name_quoted = quote_count > 2 or (quote_count and not schema_quoted) + alias_quoted = alias and item.value[-1] == '"' + if alias_quoted or name_quoted and not alias and name.islower(): + alias = '"' + (alias or name) + '"' + if name and not name_quoted and not name.islower(): + if not alias: + alias = name + name = name.lower() + return schema_name, name, alias + + try: + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + is_function = allow_functions and _identifier_is_function( + identifier + ) + except AttributeError: + continue + if real_name: + yield TableReference( + schema_name, real_name, identifier.get_alias(), is_function + ) + elif isinstance(item, Identifier): + schema_name, real_name, alias = parse_identifier(item) + is_function = allow_functions and _identifier_is_function(item) + + yield TableReference(schema_name, real_name, alias, is_function) + elif isinstance(item, Function): + schema_name, real_name, alias = parse_identifier(item) + yield TableReference(None, real_name, alias, allow_functions) + except StopIteration: + return + + +# extract_tables is inspired from examples in the sqlparse lib. +def extract_tables(sql): + """Extract the table names from an SQL statment. + + Returns a list of TableReference namedtuples + + """ + parsed = sqlparse.parse(sql) + if not parsed: + return () + + # INSERT statements must stop looking for tables at the sign of first + # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) + # abc is the table name, but if we don't stop at the first lparen, then + # we'll identify abc, col1 and col2 as table names. + insert_stmt = parsed[0].token_first().value.lower() == "insert" + stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + + # Kludge: sqlparse mistakenly identifies insert statements as + # function calls due to the parenthesized column list, e.g. interprets + # "insert into foo (bar, baz)" as a function call to foo with arguments + # (bar, baz). So don't allow any identifiers in insert statements + # to have is_function=True + identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt) + # In the case 'sche.<cursor>', we get an empty TableReference; remove that + return tuple(i for i in identifiers if i.name) diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py new file mode 100644 index 0000000..034c96e --- /dev/null +++ b/pgcli/packages/parseutils/utils.py @@ -0,0 +1,140 @@ +import re +import sqlparse +from sqlparse.sql import Identifier +from sqlparse.tokens import Token, Error + +cleanup_regex = { + # This matches only alphanumerics and underscores. + "alphanum_underscore": re.compile(r"(\w+)$"), + # This matches everything except spaces, parens, colon, and comma + "many_punctuations": re.compile(r"([^():,\s]+)$"), + # This matches everything except spaces, parens, colon, comma, and period + "most_punctuations": re.compile(r"([^\.():,\s]+)$"), + # This matches everything except a space. + "all_punctuations": re.compile(r"([^\s]+)$"), +} + + +def last_word(text, include="alphanum_underscore"): + r""" + Find the last word in a sentence. + + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + >>> last_word('"foo*bar', include='most_punctuations') + '"foo*bar' + """ + + if not text: # Empty string + return "" + + if text[-1].isspace(): + return "" + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return "" + + +def find_prev_keyword(sql, n_skip=0): + """Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, "" + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + flattened = flattened[: len(flattened) - n_skip] + + logical_operators = ("AND", "OR", "NOT", "BETWEEN") + + for t in reversed(flattened): + if t.value == "(" or ( + t.is_keyword and (t.value.upper() not in logical_operators) + ): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index throws an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = "".join(tok.value for tok in flattened[: idx + 1]) + return t, text + + return None, "" + + +# Postgresql dollar quote signs look like `$$` or `$tag$` +dollar_quote_regex = re.compile(r"^\$[^$]*\$$") + + +def is_open_quote(sql): + """Returns true if the query contains an unclosed quote""" + + # parsed can contain one or more semi-colon separated commands + parsed = sqlparse.parse(sql) + return any(_parsed_is_open_quote(p) for p in parsed) + + +def _parsed_is_open_quote(parsed): + # Look for unmatched single quotes, or unmatched dollar sign quotes + return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten()) + + +def parse_partial_identifier(word): + """Attempt to parse a (partially typed) word as an identifier + + word may include a schema qualification, like `schema_name.partial_name` + or `schema_name.` There may also be unclosed quotation marks, like + `"schema`, or `schema."partial_name` + + :param word: string representing a (partially complete) identifier + :return: sqlparse.sql.Identifier, or None + """ + + p = sqlparse.parse(word)[0] + n_tok = len(p.tokens) + if n_tok == 1 and isinstance(p.tokens[0], Identifier): + return p.tokens[0] + elif p.token_next_by(m=(Error, '"'))[1]: + # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar' + # Close the double quote, then reparse + return parse_partial_identifier(word + '"') + else: + return None diff --git a/pgcli/packages/pgliterals/__init__.py b/pgcli/packages/pgliterals/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/pgcli/packages/pgliterals/__init__.py diff --git a/pgcli/packages/pgliterals/main.py b/pgcli/packages/pgliterals/main.py new file mode 100644 index 0000000..5c39296 --- /dev/null +++ b/pgcli/packages/pgliterals/main.py @@ -0,0 +1,15 @@ +import os +import json + +root = os.path.dirname(__file__) +literal_file = os.path.join(root, "pgliterals.json") + +with open(literal_file) as f: + literals = json.load(f) + + +def get_literals(literal_type, type_=tuple): + # Where `literal_type` is one of 'keywords', 'functions', 'datatypes', + # returns a tuple of literal values of that type. + + return type_(literals[literal_type]) diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json new file mode 100644 index 0000000..c7b74b5 --- /dev/null +++ b/pgcli/packages/pgliterals/pgliterals.json @@ -0,0 +1,629 @@ +{ + "keywords": { + "ACCESS": [], + "ADD": [], + "ALL": [], + "ALTER": [ + "AGGREGATE", + "COLLATION", + "COLUMN", + "CONVERSION", + "DATABASE", + "DEFAULT", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN", + "FUNCTION", + "GROUP", + "INDEX", + "LANGUAGE", + "LARGE OBJECT", + "MATERIALIZED VIEW", + "OPERATOR", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "SYSTEM", + "TABLE", + "TABLESPACE", + "TEXT SEARCH", + "TRIGGER", + "TYPE", + "USER", + "VIEW" + ], + "AND": [], + "ANY": [], + "AS": [], + "ASC": [], + "AUDIT": [], + "BEGIN": [], + "BETWEEN": [], + "BY": [], + "CASE": [], + "CHAR": [], + "CHECK": [], + "CLUSTER": [], + "COLUMN": [], + "COMMENT": [], + "COMMIT": [], + "COMPRESS": [], + "CONCURRENTLY": [], + "CONNECT": [], + "COPY": [], + "CREATE": [ + "ACCESS METHOD", + "AGGREGATE", + "CAST", + "COLLATION", + "CONVERSION", + "DATABASE", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN DATA WRAPPER", + "FOREIGN EXTENSION", + "FUNCTION", + "GLOBAL", + "GROUP", + "IF NOT EXISTS", + "INDEX", + "LANGUAGE", + "LOCAL", + "MATERIALIZED VIEW", + "OPERATOR", + "OR REPLACE", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "TABLE", + "TABLESPACE", + "TEMPORARY", + "TEXT SEARCH", + "TRIGGER", + "TYPE", + "UNIQUE", + "UNLOGGED", + "USER", + "USER MAPPING", + "VIEW" + ], + "CURRENT": [], + "DATABASE": [], + "DATE": [], + "DECIMAL": [], + "DEFAULT": [], + "DELETE FROM": [], + "DELIMITER": [], + "DESC": [], + "DESCRIBE": [], + "DISTINCT": [], + "DROP": [ + "ACCESS METHOD", + "AGGREGATE", + "CAST", + "COLLATION", + "COLUMN", + "CONVERSION", + "DATABASE", + "DOMAIN", + "EVENT TRIGGER", + "EXTENSION", + "FOREIGN DATA WRAPPER", + "FOREIGN TABLE", + "FUNCTION", + "GROUP", + "INDEX", + "LANGUAGE", + "MATERIALIZED VIEW", + "OPERATOR", + "OWNED", + "POLICY", + "ROLE", + "RULE", + "SCHEMA", + "SEQUENCE", + "SERVER", + "TABLE", + "TABLESPACE", + "TEXT SEARCH", + "TRANSFORM", + "TRIGGER", + "TYPE", + "USER", + "USER MAPPING", + "VIEW" + ], + "EXPLAIN": [], + "ELSE": [], + "ENCODING": [], + "ESCAPE": [], + "EXCLUSIVE": [], + "EXISTS": [], + "EXTENSION": [], + "FILE": [], + "FLOAT": [], + "FOR": [], + "FORMAT": [], + "FORCE_QUOTE": [], + "FORCE_NOT_NULL": [], + "FREEZE": [], + "FROM": [], + "FULL": [], + "FUNCTION": [], + "GRANT": [], + "GROUP BY": [], + "HAVING": [], + "HEADER": [], + "IDENTIFIED": [], + "IMMEDIATE": [], + "IN": [], + "INCREMENT": [], + "INDEX": [], + "INITIAL": [], + "INSERT INTO": [], + "INTEGER": [], + "INTERSECT": [], + "INTERVAL": [], + "INTO": [], + "IS": [], + "JOIN": [], + "LANGUAGE": [], + "LEFT": [], + "LEVEL": [], + "LIKE": [], + "LIMIT": [], + "LOCK": [], + "LONG": [], + "MATERIALIZED VIEW": [], + "MAXEXTENTS": [], + "MINUS": [], + "MLSLABEL": [], + "MODE": [], + "MODIFY": [], + "NOT": [], + "NOAUDIT": [], + "NOTICE": [], + "NOCOMPRESS": [], + "NOWAIT": [], + "NULL": [], + "NUMBER": [], + "OIDS": [], + "OF": [], + "OFFLINE": [], + "ON": [], + "ONLINE": [], + "OPTION": [], + "OR": [], + "ORDER BY": [], + "OUTER": [], + "OWNER": [], + "PCTFREE": [], + "PRIMARY": [], + "PRIOR": [], + "PRIVILEGES": [], + "QUOTE": [], + "RAISE": [], + "RENAME": [], + "REPLACE": [], + "RESET": ["ALL"], + "RAW": [], + "REFRESH MATERIALIZED VIEW": [], + "RESOURCE": [], + "RETURNS": [], + "REVOKE": [], + "RIGHT": [], + "ROLLBACK": [], + "ROW": [], + "ROWID": [], + "ROWNUM": [], + "ROWS": [], + "SELECT": [], + "SESSION": [], + "SET": [], + "SHARE": [], + "SHOW": [], + "SIZE": [], + "SMALLINT": [], + "START": [], + "SUCCESSFUL": [], + "SYNONYM": [], + "SYSDATE": [], + "TABLE": [], + "TEMPLATE": [], + "THEN": [], + "TO": [], + "TRIGGER": [], + "TRUNCATE": [], + "UID": [], + "UNION": [], + "UNIQUE": [], + "UPDATE": [], + "USE": [], + "USER": [], + "USING": [], + "VALIDATE": [], + "VALUES": [], + "VARCHAR": [], + "VARCHAR2": [], + "VIEW": [], + "WHEN": [], + "WHENEVER": [], + "WHERE": [], + "WITH": [] + }, + "functions": [ + "ABBREV", + "ABS", + "AGE", + "AREA", + "ARRAY_AGG", + "ARRAY_APPEND", + "ARRAY_CAT", + "ARRAY_DIMS", + "ARRAY_FILL", + "ARRAY_LENGTH", + "ARRAY_LOWER", + "ARRAY_NDIMS", + "ARRAY_POSITION", + "ARRAY_POSITIONS", + "ARRAY_PREPEND", + "ARRAY_REMOVE", + "ARRAY_REPLACE", + "ARRAY_TO_STRING", + "ARRAY_UPPER", + "ASCII", + "AVG", + "BIT_AND", + "BIT_LENGTH", + "BIT_OR", + "BOOL_AND", + "BOOL_OR", + "BOUND_BOX", + "BOX", + "BROADCAST", + "BTRIM", + "CARDINALITY", + "CBRT", + "CEIL", + "CEILING", + "CENTER", + "CHAR_LENGTH", + "CHR", + "CIRCLE", + "CLOCK_TIMESTAMP", + "CONCAT", + "CONCAT_WS", + "CONVERT", + "CONVERT_FROM", + "CONVERT_TO", + "COUNT", + "CUME_DIST", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATE_PART", + "DATE_TRUNC", + "DECODE", + "DEGREES", + "DENSE_RANK", + "DIAMETER", + "DIV", + "ENCODE", + "ENUM_FIRST", + "ENUM_LAST", + "ENUM_RANGE", + "EVERY", + "EXP", + "EXTRACT", + "FAMILY", + "FIRST_VALUE", + "FLOOR", + "FORMAT", + "GET_BIT", + "GET_BYTE", + "HEIGHT", + "HOST", + "HOSTMASK", + "INET_MERGE", + "INET_SAME_FAMILY", + "INITCAP", + "ISCLOSED", + "ISFINITE", + "ISOPEN", + "JUSTIFY_DAYS", + "JUSTIFY_HOURS", + "JUSTIFY_INTERVAL", + "LAG", + "LAST_VALUE", + "LEAD", + "LEFT", + "LENGTH", + "LINE", + "LN", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOG", + "LOG10", + "LOWER", + "LPAD", + "LSEG", + "LTRIM", + "MAKE_DATE", + "MAKE_INTERVAL", + "MAKE_TIME", + "MAKE_TIMESTAMP", + "MAKE_TIMESTAMPTZ", + "MASKLEN", + "MAX", + "MD5", + "MIN", + "MOD", + "NETMASK", + "NETWORK", + "NOW", + "NPOINTS", + "NTH_VALUE", + "NTILE", + "NUM_NONNULLS", + "NUM_NULLS", + "OCTET_LENGTH", + "OVERLAY", + "PARSE_IDENT", + "PATH", + "PCLOSE", + "PERCENT_RANK", + "PG_CLIENT_ENCODING", + "PI", + "POINT", + "POLYGON", + "POPEN", + "POSITION", + "POWER", + "QUOTE_IDENT", + "QUOTE_LITERAL", + "QUOTE_NULLABLE", + "RADIANS", + "RADIUS", + "RANK", + "REGEXP_MATCH", + "REGEXP_MATCHES", + "REGEXP_REPLACE", + "REGEXP_SPLIT_TO_ARRAY", + "REGEXP_SPLIT_TO_TABLE", + "REPEAT", + "REPLACE", + "REVERSE", + "RIGHT", + "ROUND", + "ROW_NUMBER", + "RPAD", + "RTRIM", + "SCALE", + "SET_BIT", + "SET_BYTE", + "SET_MASKLEN", + "SHA224", + "SHA256", + "SHA384", + "SHA512", + "SIGN", + "SPLIT_PART", + "SQRT", + "STARTS_WITH", + "STATEMENT_TIMESTAMP", + "STRING_TO_ARRAY", + "STRPOS", + "SUBSTR", + "SUBSTRING", + "SUM", + "TEXT", + "TIMEOFDAY", + "TO_ASCII", + "TO_CHAR", + "TO_DATE", + "TO_HEX", + "TO_NUMBER", + "TO_TIMESTAMP", + "TRANSACTION_TIMESTAMP", + "TRANSLATE", + "TRIM", + "TRUNC", + "UNNEST", + "UPPER", + "WIDTH", + "WIDTH_BUCKET", + "XMLAGG" + ], + "datatypes": [ + "ANY", + "ANYARRAY", + "ANYELEMENT", + "ANYENUM", + "ANYNONARRAY", + "ANYRANGE", + "BIGINT", + "BIGSERIAL", + "BIT", + "BIT VARYING", + "BOOL", + "BOOLEAN", + "BOX", + "BYTEA", + "CHAR", + "CHARACTER", + "CHARACTER VARYING", + "CIDR", + "CIRCLE", + "CSTRING", + "DATE", + "DECIMAL", + "DOUBLE PRECISION", + "EVENT_TRIGGER", + "FDW_HANDLER", + "FLOAT4", + "FLOAT8", + "INET", + "INT", + "INT2", + "INT4", + "INT8", + "INTEGER", + "INTERNAL", + "INTERVAL", + "JSON", + "JSONB", + "LANGUAGE_HANDLER", + "LINE", + "LSEG", + "MACADDR", + "MACADDR8", + "MONEY", + "NUMERIC", + "OID", + "OPAQUE", + "PATH", + "PG_LSN", + "POINT", + "POLYGON", + "REAL", + "RECORD", + "REGCLASS", + "REGCONFIG", + "REGDICTIONARY", + "REGNAMESPACE", + "REGOPER", + "REGOPERATOR", + "REGPROC", + "REGPROCEDURE", + "REGROLE", + "REGTYPE", + "SERIAL", + "SERIAL2", + "SERIAL4", + "SERIAL8", + "SMALLINT", + "SMALLSERIAL", + "TEXT", + "TIME", + "TIMESTAMP", + "TRIGGER", + "TSQUERY", + "TSVECTOR", + "TXID_SNAPSHOT", + "UUID", + "VARBIT", + "VARCHAR", + "VOID", + "XML" + ], + "reserved": [ + "ALL", + "ANALYSE", + "ANALYZE", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASYMMETRIC", + "BOTH", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "CONSTRAINT", + "CREATE", + "CURRENT_CATALOG", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DESC", + "DISTINCT", + "DO", + "ELSE", + "END", + "EXCEPT", + "FALSE", + "FETCH", + "FOR", + "FOREIGN", + "FROM", + "GRANT", + "GROUP", + "HAVING", + "IN", + "INITIALLY", + "INTERSECT", + "INTO", + "LATERAL", + "LEADING", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NOT", + "NULL", + "OFFSET", + "ON", + "ONLY", + "OR", + "ORDER", + "PLACING", + "PRIMARY", + "REFERENCES", + "RETURNING", + "SELECT", + "SESSION_USER", + "SOME", + "SYMMETRIC", + "TABLE", + "THEN", + "TO", + "TRAILING", + "TRUE", + "UNION", + "UNIQUE", + "USER", + "USING", + "VARIADIC", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "AUTHORIZATION", + "BINARY", + "COLLATION", + "CONCURRENTLY", + "CROSS", + "CURRENT_SCHEMA", + "FREEZE", + "FULL", + "ILIKE", + "INNER", + "IS", + "ISNULL", + "JOIN", + "LEFT", + "LIKE", + "NATURAL", + "NOTNULL", + "OUTER", + "OVERLAPS", + "RIGHT", + "SIMILAR", + "TABLESAMPLE", + "VERBOSE" + ] +} diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py new file mode 100644 index 0000000..e92dcbb --- /dev/null +++ b/pgcli/packages/prioritization.py @@ -0,0 +1,51 @@ +import re +import sqlparse +from sqlparse.tokens import Name +from collections import defaultdict +from .pgliterals.main import get_literals + + +white_space_regex = re.compile("\\s+", re.MULTILINE) + + +def _compile_regex(keyword): + # Surround the keyword with word boundaries and replace interior whitespace + # with whitespace wildcards + pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b" + return re.compile(pattern, re.MULTILINE | re.IGNORECASE) + + +keywords = get_literals("keywords") +keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords) + + +class PrevalenceCounter(object): + def __init__(self): + self.keyword_counts = defaultdict(int) + self.name_counts = defaultdict(int) + + def update(self, text): + self.update_keywords(text) + self.update_names(text) + + def update_names(self, text): + for parsed in sqlparse.parse(text): + for token in parsed.flatten(): + if token.ttype in Name: + self.name_counts[token.value] += 1 + + def clear_names(self): + self.name_counts = defaultdict(int) + + def update_keywords(self, text): + # Count keywords. Can't rely for sqlparse for this, because it's + # database agnostic + for keyword, regex in keyword_regexs.items(): + for _ in regex.finditer(text): + self.keyword_counts[keyword] += 1 + + def keyword_count(self, keyword): + return self.keyword_counts[keyword] + + def name_count(self, name): + return self.name_counts[name] diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py new file mode 100644 index 0000000..3c58490 --- /dev/null +++ b/pgcli/packages/prompt_utils.py @@ -0,0 +1,35 @@ +import sys +import click +from .parseutils import is_destructive + + +def confirm_destructive_query(queries): + """Check if the query is destructive and prompts the user to confirm. + + Returns: + * None if the query is non-destructive or we can't prompt the user. + * True if the query is destructive and the user wants to proceed. + * False if the query is destructive and the user doesn't want to proceed. + + """ + prompt_text = ( + "You're about to run a destructive command.\n" "Do you want to proceed? (y/n)" + ) + if is_destructive(queries) and sys.stdin.isatty(): + return prompt(prompt_text, type=bool) + + +def confirm(*args, **kwargs): + """Prompt for confirmation (yes/no) and handle any abort exceptions.""" + try: + return click.confirm(*args, **kwargs) + except click.Abort: + return False + + +def prompt(*args, **kwargs): + """Prompt the user for input and handle any abort exceptions.""" + try: + return click.prompt(*args, **kwargs) + except click.Abort: + return False diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py new file mode 100644 index 0000000..6ef8859 --- /dev/null +++ b/pgcli/packages/sqlcompletion.py @@ -0,0 +1,608 @@ +import sys +import re +import sqlparse +from collections import namedtuple +from sqlparse.sql import Comparison, Identifier, Where +from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier +from .parseutils.tables import extract_tables +from .parseutils.ctes import isolate_query_ctes +from pgspecial.main import parse_special_command + + +Special = namedtuple("Special", []) +Database = namedtuple("Database", []) +Schema = namedtuple("Schema", ["quoted"]) +Schema.__new__.__defaults__ = (False,) +# FromClauseItem is a table/view/function used in the FROM clause +# `table_refs` contains the list of tables/... already in the statement, +# used to ensure that the alias we suggest is unique +FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables") +Table = namedtuple("Table", ["schema", "table_refs", "local_tables"]) +TableFormat = namedtuple("TableFormat", []) +View = namedtuple("View", ["schema", "table_refs"]) +# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid' +JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"]) +# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid' +Join = namedtuple("Join", ["table_refs", "schema"]) + +Function = namedtuple("Function", ["schema", "table_refs", "usage"]) +# For convenience, don't require the `usage` argument in Function constructor +Function.__new__.__defaults__ = (None, tuple(), None) +Table.__new__.__defaults__ = (None, tuple(), tuple()) +View.__new__.__defaults__ = (None, tuple()) +FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple()) + +Column = namedtuple( + "Column", + ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"], +) +Column.__new__.__defaults__ = (None, None, tuple(), False, None) + +Keyword = namedtuple("Keyword", ["last_token"]) +Keyword.__new__.__defaults__ = (None,) +NamedQuery = namedtuple("NamedQuery", []) +Datatype = namedtuple("Datatype", ["schema"]) +Alias = namedtuple("Alias", ["aliases"]) + +Path = namedtuple("Path", []) + + +class SqlStatement(object): + def __init__(self, full_text, text_before_cursor): + self.identifier = None + self.word_before_cursor = word_before_cursor = last_word( + text_before_cursor, include="many_punctuations" + ) + full_text = _strip_named_query(full_text) + text_before_cursor = _strip_named_query(text_before_cursor) + + full_text, text_before_cursor, self.local_tables = isolate_query_ctes( + full_text, text_before_cursor + ) + + self.text_before_cursor_including_last_word = text_before_cursor + + # If we've partially typed a word then word_before_cursor won't be an + # empty string. In that case we want to remove the partially typed + # string before sending it to the sqlparser. Otherwise the last token + # will always be the partially typed string which renders the smart + # completion useless because it will always return the list of + # keywords as completion. + if self.word_before_cursor: + if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\": + parsed = sqlparse.parse(text_before_cursor) + else: + text_before_cursor = text_before_cursor[: -len(word_before_cursor)] + parsed = sqlparse.parse(text_before_cursor) + self.identifier = parse_partial_identifier(word_before_cursor) + else: + parsed = sqlparse.parse(text_before_cursor) + + full_text, text_before_cursor, parsed = _split_multiple_statements( + full_text, text_before_cursor, parsed + ) + + self.full_text = full_text + self.text_before_cursor = text_before_cursor + self.parsed = parsed + + self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or "" + + def is_insert(self): + return self.parsed.token_first().value.lower() == "insert" + + def get_tables(self, scope="full"): + """Gets the tables available in the statement. + param `scope:` possible values: 'full', 'insert', 'before' + If 'insert', only the first table is returned. + If 'before', only tables before the cursor are returned. + If not 'insert' and the stmt is an insert, the first table is skipped. + """ + tables = extract_tables( + self.full_text if scope == "full" else self.text_before_cursor + ) + if scope == "insert": + tables = tables[:1] + elif self.is_insert(): + tables = tables[1:] + return tables + + def get_previous_token(self, token): + return self.parsed.token_prev(self.parsed.token_index(token))[1] + + def get_identifier_schema(self): + schema = (self.identifier and self.identifier.get_parent_name()) or None + # If schema name is unquoted, lower-case it + if schema and self.identifier.value[0] != '"': + schema = schema.lower() + + return schema + + def reduce_to_prev_keyword(self, n_skip=0): + prev_keyword, self.text_before_cursor = find_prev_keyword( + self.text_before_cursor, n_skip=n_skip + ) + return prev_keyword + + +def suggest_type(full_text, text_before_cursor): + """Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + """ + + if full_text.startswith("\\i "): + return (Path(),) + + # This is a temporary hack; the exception handling + # here should be removed once sqlparse has been fixed + try: + stmt = SqlStatement(full_text, text_before_cursor) + except (TypeError, AttributeError): + return [] + + # Check for special commands and handle those separately + if stmt.parsed: + # Be careful here because trivial whitespace is parsed as a + # statement, but the statement won't have a first token + tok1 = stmt.parsed.token_first() + if tok1 and tok1.value.startswith("\\"): + text = stmt.text_before_cursor + stmt.word_before_cursor + return suggest_special(text) + + return suggest_based_on_last_token(stmt.last_token, stmt) + + +named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+") + + +def _strip_named_query(txt): + """ + This will strip "save named query" command in the beginning of the line: + '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + """ + + if named_query_regex.match(txt): + txt = named_query_regex.sub("", txt) + return txt + + +function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M) + + +def _find_function_body(text): + split = function_body_pattern.search(text) + return (split.start(2), split.end(2)) if split else (None, None) + + +def _statement_from_function(full_text, text_before_cursor, statement): + current_pos = len(text_before_cursor) + body_start, body_end = _find_function_body(full_text) + if body_start is None: + return full_text, text_before_cursor, statement + if not body_start <= current_pos < body_end: + return full_text, text_before_cursor, statement + full_text = full_text[body_start:body_end] + text_before_cursor = text_before_cursor[body_start:] + parsed = sqlparse.parse(text_before_cursor) + return _split_multiple_statements(full_text, text_before_cursor, parsed) + + +def _split_multiple_statements(full_text, text_before_cursor, parsed): + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds + # the current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(str(statement)) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + break + + elif parsed: + # A single statement + statement = parsed[0] + else: + # The empty string + return full_text, text_before_cursor, None + + token2 = None + if statement.get_type() in ("CREATE", "CREATE OR REPLACE"): + token1 = statement.token_first() + if token1: + token1_idx = statement.token_index(token1) + token2 = statement.token_next(token1_idx)[1] + if token2 and token2.value.upper() == "FUNCTION": + full_text, text_before_cursor, statement = _statement_from_function( + full_text, text_before_cursor, statement + ) + return full_text, text_before_cursor, statement + + +SPECIALS_SUGGESTION = { + "dT": Datatype, + "df": Function, + "dt": Table, + "dv": View, + "sf": Function, +} + + +def suggest_special(text): + text = text.lstrip() + cmd, _, arg = parse_special_command(text) + + if cmd == text: + # Trying to complete the special command itself + return (Special(),) + + if cmd in ("\\c", "\\connect"): + return (Database(),) + + if cmd == "\\T": + return (TableFormat(),) + + if cmd == "\\dn": + return (Schema(),) + + if arg: + # Try to distinguish "\d name" from "\d schema.name" + # Note that this will fail to obtain a schema name if wildcards are + # used, e.g. "\d schema???.name" + parsed = sqlparse.parse(arg)[0].tokens[0] + try: + schema = parsed.get_parent_name() + except AttributeError: + schema = None + else: + schema = None + + if cmd[1:] == "d": + # \d can describe tables or views + if schema: + return (Table(schema=schema), View(schema=schema)) + else: + return (Schema(), Table(schema=None), View(schema=None)) + elif cmd[1:] in SPECIALS_SUGGESTION: + rel_type = SPECIALS_SUGGESTION[cmd[1:]] + if schema: + if rel_type == Function: + return (Function(schema=schema, usage="special"),) + return (rel_type(schema=schema),) + else: + if rel_type == Function: + return (Schema(), Function(schema=None, usage="special")) + return (Schema(), rel_type(schema=None)) + + if cmd in ["\\n", "\\ns", "\\nd"]: + return (NamedQuery(),) + + return (Keyword(), Special()) + + +def suggest_based_on_last_token(token, stmt): + + if isinstance(token, str): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly + prev_keyword = stmt.reduce_to_prev_keyword() + return suggest_based_on_last_token(prev_keyword, stmt) + elif isinstance(token, Identifier): + # If the previous token is an identifier, we can suggest datatypes if + # we're in a parenthesized column/field list, e.g.: + # CREATE TABLE foo (Identifier <CURSOR> + # CREATE FUNCTION foo (Identifier <CURSOR> + # If we're not in a parenthesized list, the most likely scenario is the + # user is about to specify an alias, e.g.: + # SELECT Identifier <CURSOR> + # SELECT foo FROM Identifier <CURSOR> + prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor) + if prev_keyword and prev_keyword.value == "(": + # Suggest datatypes + return suggest_based_on_last_token("type", stmt) + else: + return (Keyword(),) + else: + token_v = token.value.lower() + + if not token: + return (Keyword(), Special()) + elif token_v.endswith("("): + p = sqlparse.parse(stmt.text_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + column_suggestions = suggest_based_on_last_token("where", stmt) + + # Check for a subquery expression (cases 3 & 4) + where = p.tokens[-1] + prev_tok = where.token_prev(len(where.tokens) - 1)[1] + + if isinstance(prev_tok, Comparison): + # e.g. "SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == "exists": + return (Keyword(),) + else: + return column_suggestions + + # Get the token before the parens + prev_tok = p.token_prev(len(p.tokens) - 1)[1] + + if ( + prev_tok + and prev_tok.value + and prev_tok.value.lower().split(" ")[-1] == "using" + ): + # tbl1 INNER JOIN tbl2 USING (col1, col2) + tables = stmt.get_tables("before") + + # suggest columns that are present in more than one table + return ( + Column( + table_refs=tables, + require_last_table=True, + local_tables=stmt.local_tables, + ), + ) + + elif p.token_first().value.lower() == "select": + # If the lparen is preceeded by a space chances are we're about to + # do a sub-select. + if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("): + return (Keyword(),) + prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1] + if prev_prev_tok and prev_prev_tok.normalized == "INTO": + return (Column(table_refs=stmt.get_tables("insert"), context="insert"),) + # We're probably in a function argument list + return _suggest_expression(token_v, stmt) + elif token_v == "set": + return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),) + elif token_v in ("select", "where", "having", "order by", "distinct"): + return _suggest_expression(token_v, stmt) + elif token_v == "as": + # Don't suggest anything for aliases + return () + elif (token_v.endswith("join") and token.is_keyword) or ( + token_v in ("copy", "from", "update", "into", "describe", "truncate") + ): + + schema = stmt.get_identifier_schema() + tables = extract_tables(stmt.text_before_cursor) + is_join = token_v.endswith("join") and token.is_keyword + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + if token_v == "from" or is_join: + suggest.append( + FromClauseItem( + schema=schema, table_refs=tables, local_tables=stmt.local_tables + ) + ) + elif token_v == "truncate": + suggest.append(Table(schema)) + else: + suggest.extend((Table(schema), View(schema))) + + if is_join and _allow_join(stmt.parsed): + tables = stmt.get_tables("before") + suggest.append(Join(table_refs=tables, schema=schema)) + + return tuple(suggest) + + elif token_v == "function": + schema = stmt.get_identifier_schema() + + # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:` + try: + prev = stmt.get_previous_token(token).value.lower() + if prev in ("drop", "alter", "create", "create or replace"): + + # Suggest functions from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + suggest.append(Function(schema=schema, usage="signature")) + return tuple(suggest) + + except ValueError: + pass + return tuple() + + elif token_v in ("table", "view"): + # E.g. 'ALTER TABLE <tablname>' + rel_type = {"table": Table, "view": View, "function": Function}[token_v] + schema = stmt.get_identifier_schema() + if schema: + return (rel_type(schema=schema),) + else: + return (Schema(), rel_type(schema=schema)) + + elif token_v == "column": + # E.g. 'ALTER TABLE foo ALTER COLUMN bar + return (Column(table_refs=stmt.get_tables()),) + + elif token_v == "on": + tables = stmt.get_tables("before") + parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None + if parent: + # "ON parent.<suggestion>" + # parent can be either a schema name or table alias + filteredtables = tuple(t for t in tables if identifies(parent, t)) + sugs = [ + Column(table_refs=filteredtables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ] + if filteredtables and _allow_join_condition(stmt.parsed): + sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1])) + return tuple(sugs) + else: + # ON <suggestion> + # Use table alias if there is one, otherwise the table name + aliases = tuple(t.ref for t in tables) + if _allow_join_condition(stmt.parsed): + return ( + Alias(aliases=aliases), + JoinCondition(table_refs=tables, parent=None), + ) + else: + return (Alias(aliases=aliases),) + + elif token_v in ("c", "use", "database", "template"): + # "\c <db", "use <db>", "DROP DATABASE <db>", + # "CREATE DATABASE <newdb> WITH TEMPLATE <db>" + return (Database(),) + elif token_v == "schema": + # DROP SCHEMA schema_name, SET SCHEMA schema name + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2) + quoted = prev_keyword and prev_keyword.value.lower() == "set" + return (Schema(quoted),) + elif token_v.endswith(",") or token_v in ("=", "and", "or"): + prev_keyword = stmt.reduce_to_prev_keyword() + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return () + elif token_v in ("type", "::"): + # ALTER TABLE foo SET DATA TYPE bar + # SELECT foo::bar + # Note that tables are a form of composite type in postgresql, so + # they're suggested here as well + schema = stmt.get_identifier_schema() + suggestions = [Datatype(schema=schema), Table(schema=schema)] + if not schema: + suggestions.append(Schema()) + return tuple(suggestions) + elif token_v in {"alter", "create", "drop"}: + return (Keyword(token_v.upper()),) + elif token.is_keyword: + # token is a keyword we haven't implemented any special handling for + # go backwards in the query until we find one we do recognize + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1) + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return (Keyword(token_v.upper()),) + else: + return (Keyword(),) + + +def _suggest_expression(token_v, stmt): + """ + Return suggestions for an expression, taking account of any partially-typed + identifier's parent, which may be a table alias or schema name. + """ + parent = stmt.identifier.get_parent_name() if stmt.identifier else [] + tables = stmt.get_tables() + + if parent: + tables = tuple(t for t in tables if identifies(parent, t)) + return ( + Column(table_refs=tables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ) + + return ( + Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True), + Function(schema=None), + Keyword(token_v.upper()), + ) + + +def identifies(id, ref): + """Returns true if string `id` matches TableReference `ref`""" + + return ( + id == ref.alias + or id == ref.name + or (ref.schema and (id == ref.schema + "." + ref.name)) + ) + + +def _allow_join_condition(statement): + """ + Tests if a join condition should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b on a.id = <cursor> + So check that the preceding token is a ON, AND, or OR keyword, instead of + e.g. an equals sign. + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower() in ("on", "and", "or") + + +def _allow_join(statement): + """ + Tests if a join should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b <cursor> + So check that the preceding token is a JOIN keyword + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in ( + "cross join", + "natural join", + ) |