diff options
Diffstat (limited to 'pgcli/packages/sqlcompletion.py')
-rw-r--r-- | pgcli/packages/sqlcompletion.py | 608 |
1 files changed, 608 insertions, 0 deletions
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py new file mode 100644 index 0000000..6ef8859 --- /dev/null +++ b/pgcli/packages/sqlcompletion.py @@ -0,0 +1,608 @@ +import sys +import re +import sqlparse +from collections import namedtuple +from sqlparse.sql import Comparison, Identifier, Where +from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier +from .parseutils.tables import extract_tables +from .parseutils.ctes import isolate_query_ctes +from pgspecial.main import parse_special_command + + +Special = namedtuple("Special", []) +Database = namedtuple("Database", []) +Schema = namedtuple("Schema", ["quoted"]) +Schema.__new__.__defaults__ = (False,) +# FromClauseItem is a table/view/function used in the FROM clause +# `table_refs` contains the list of tables/... already in the statement, +# used to ensure that the alias we suggest is unique +FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables") +Table = namedtuple("Table", ["schema", "table_refs", "local_tables"]) +TableFormat = namedtuple("TableFormat", []) +View = namedtuple("View", ["schema", "table_refs"]) +# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid' +JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"]) +# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid' +Join = namedtuple("Join", ["table_refs", "schema"]) + +Function = namedtuple("Function", ["schema", "table_refs", "usage"]) +# For convenience, don't require the `usage` argument in Function constructor +Function.__new__.__defaults__ = (None, tuple(), None) +Table.__new__.__defaults__ = (None, tuple(), tuple()) +View.__new__.__defaults__ = (None, tuple()) +FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple()) + +Column = namedtuple( + "Column", + ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"], +) +Column.__new__.__defaults__ = (None, None, tuple(), False, None) + +Keyword = namedtuple("Keyword", ["last_token"]) +Keyword.__new__.__defaults__ = (None,) +NamedQuery = namedtuple("NamedQuery", []) +Datatype = namedtuple("Datatype", ["schema"]) +Alias = namedtuple("Alias", ["aliases"]) + +Path = namedtuple("Path", []) + + +class SqlStatement(object): + def __init__(self, full_text, text_before_cursor): + self.identifier = None + self.word_before_cursor = word_before_cursor = last_word( + text_before_cursor, include="many_punctuations" + ) + full_text = _strip_named_query(full_text) + text_before_cursor = _strip_named_query(text_before_cursor) + + full_text, text_before_cursor, self.local_tables = isolate_query_ctes( + full_text, text_before_cursor + ) + + self.text_before_cursor_including_last_word = text_before_cursor + + # If we've partially typed a word then word_before_cursor won't be an + # empty string. In that case we want to remove the partially typed + # string before sending it to the sqlparser. Otherwise the last token + # will always be the partially typed string which renders the smart + # completion useless because it will always return the list of + # keywords as completion. + if self.word_before_cursor: + if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\": + parsed = sqlparse.parse(text_before_cursor) + else: + text_before_cursor = text_before_cursor[: -len(word_before_cursor)] + parsed = sqlparse.parse(text_before_cursor) + self.identifier = parse_partial_identifier(word_before_cursor) + else: + parsed = sqlparse.parse(text_before_cursor) + + full_text, text_before_cursor, parsed = _split_multiple_statements( + full_text, text_before_cursor, parsed + ) + + self.full_text = full_text + self.text_before_cursor = text_before_cursor + self.parsed = parsed + + self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or "" + + def is_insert(self): + return self.parsed.token_first().value.lower() == "insert" + + def get_tables(self, scope="full"): + """Gets the tables available in the statement. + param `scope:` possible values: 'full', 'insert', 'before' + If 'insert', only the first table is returned. + If 'before', only tables before the cursor are returned. + If not 'insert' and the stmt is an insert, the first table is skipped. + """ + tables = extract_tables( + self.full_text if scope == "full" else self.text_before_cursor + ) + if scope == "insert": + tables = tables[:1] + elif self.is_insert(): + tables = tables[1:] + return tables + + def get_previous_token(self, token): + return self.parsed.token_prev(self.parsed.token_index(token))[1] + + def get_identifier_schema(self): + schema = (self.identifier and self.identifier.get_parent_name()) or None + # If schema name is unquoted, lower-case it + if schema and self.identifier.value[0] != '"': + schema = schema.lower() + + return schema + + def reduce_to_prev_keyword(self, n_skip=0): + prev_keyword, self.text_before_cursor = find_prev_keyword( + self.text_before_cursor, n_skip=n_skip + ) + return prev_keyword + + +def suggest_type(full_text, text_before_cursor): + """Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + """ + + if full_text.startswith("\\i "): + return (Path(),) + + # This is a temporary hack; the exception handling + # here should be removed once sqlparse has been fixed + try: + stmt = SqlStatement(full_text, text_before_cursor) + except (TypeError, AttributeError): + return [] + + # Check for special commands and handle those separately + if stmt.parsed: + # Be careful here because trivial whitespace is parsed as a + # statement, but the statement won't have a first token + tok1 = stmt.parsed.token_first() + if tok1 and tok1.value.startswith("\\"): + text = stmt.text_before_cursor + stmt.word_before_cursor + return suggest_special(text) + + return suggest_based_on_last_token(stmt.last_token, stmt) + + +named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+") + + +def _strip_named_query(txt): + """ + This will strip "save named query" command in the beginning of the line: + '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + """ + + if named_query_regex.match(txt): + txt = named_query_regex.sub("", txt) + return txt + + +function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M) + + +def _find_function_body(text): + split = function_body_pattern.search(text) + return (split.start(2), split.end(2)) if split else (None, None) + + +def _statement_from_function(full_text, text_before_cursor, statement): + current_pos = len(text_before_cursor) + body_start, body_end = _find_function_body(full_text) + if body_start is None: + return full_text, text_before_cursor, statement + if not body_start <= current_pos < body_end: + return full_text, text_before_cursor, statement + full_text = full_text[body_start:body_end] + text_before_cursor = text_before_cursor[body_start:] + parsed = sqlparse.parse(text_before_cursor) + return _split_multiple_statements(full_text, text_before_cursor, parsed) + + +def _split_multiple_statements(full_text, text_before_cursor, parsed): + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds + # the current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(str(statement)) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + break + + elif parsed: + # A single statement + statement = parsed[0] + else: + # The empty string + return full_text, text_before_cursor, None + + token2 = None + if statement.get_type() in ("CREATE", "CREATE OR REPLACE"): + token1 = statement.token_first() + if token1: + token1_idx = statement.token_index(token1) + token2 = statement.token_next(token1_idx)[1] + if token2 and token2.value.upper() == "FUNCTION": + full_text, text_before_cursor, statement = _statement_from_function( + full_text, text_before_cursor, statement + ) + return full_text, text_before_cursor, statement + + +SPECIALS_SUGGESTION = { + "dT": Datatype, + "df": Function, + "dt": Table, + "dv": View, + "sf": Function, +} + + +def suggest_special(text): + text = text.lstrip() + cmd, _, arg = parse_special_command(text) + + if cmd == text: + # Trying to complete the special command itself + return (Special(),) + + if cmd in ("\\c", "\\connect"): + return (Database(),) + + if cmd == "\\T": + return (TableFormat(),) + + if cmd == "\\dn": + return (Schema(),) + + if arg: + # Try to distinguish "\d name" from "\d schema.name" + # Note that this will fail to obtain a schema name if wildcards are + # used, e.g. "\d schema???.name" + parsed = sqlparse.parse(arg)[0].tokens[0] + try: + schema = parsed.get_parent_name() + except AttributeError: + schema = None + else: + schema = None + + if cmd[1:] == "d": + # \d can describe tables or views + if schema: + return (Table(schema=schema), View(schema=schema)) + else: + return (Schema(), Table(schema=None), View(schema=None)) + elif cmd[1:] in SPECIALS_SUGGESTION: + rel_type = SPECIALS_SUGGESTION[cmd[1:]] + if schema: + if rel_type == Function: + return (Function(schema=schema, usage="special"),) + return (rel_type(schema=schema),) + else: + if rel_type == Function: + return (Schema(), Function(schema=None, usage="special")) + return (Schema(), rel_type(schema=None)) + + if cmd in ["\\n", "\\ns", "\\nd"]: + return (NamedQuery(),) + + return (Keyword(), Special()) + + +def suggest_based_on_last_token(token, stmt): + + if isinstance(token, str): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly + prev_keyword = stmt.reduce_to_prev_keyword() + return suggest_based_on_last_token(prev_keyword, stmt) + elif isinstance(token, Identifier): + # If the previous token is an identifier, we can suggest datatypes if + # we're in a parenthesized column/field list, e.g.: + # CREATE TABLE foo (Identifier <CURSOR> + # CREATE FUNCTION foo (Identifier <CURSOR> + # If we're not in a parenthesized list, the most likely scenario is the + # user is about to specify an alias, e.g.: + # SELECT Identifier <CURSOR> + # SELECT foo FROM Identifier <CURSOR> + prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor) + if prev_keyword and prev_keyword.value == "(": + # Suggest datatypes + return suggest_based_on_last_token("type", stmt) + else: + return (Keyword(),) + else: + token_v = token.value.lower() + + if not token: + return (Keyword(), Special()) + elif token_v.endswith("("): + p = sqlparse.parse(stmt.text_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + column_suggestions = suggest_based_on_last_token("where", stmt) + + # Check for a subquery expression (cases 3 & 4) + where = p.tokens[-1] + prev_tok = where.token_prev(len(where.tokens) - 1)[1] + + if isinstance(prev_tok, Comparison): + # e.g. "SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == "exists": + return (Keyword(),) + else: + return column_suggestions + + # Get the token before the parens + prev_tok = p.token_prev(len(p.tokens) - 1)[1] + + if ( + prev_tok + and prev_tok.value + and prev_tok.value.lower().split(" ")[-1] == "using" + ): + # tbl1 INNER JOIN tbl2 USING (col1, col2) + tables = stmt.get_tables("before") + + # suggest columns that are present in more than one table + return ( + Column( + table_refs=tables, + require_last_table=True, + local_tables=stmt.local_tables, + ), + ) + + elif p.token_first().value.lower() == "select": + # If the lparen is preceeded by a space chances are we're about to + # do a sub-select. + if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("): + return (Keyword(),) + prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1] + if prev_prev_tok and prev_prev_tok.normalized == "INTO": + return (Column(table_refs=stmt.get_tables("insert"), context="insert"),) + # We're probably in a function argument list + return _suggest_expression(token_v, stmt) + elif token_v == "set": + return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),) + elif token_v in ("select", "where", "having", "order by", "distinct"): + return _suggest_expression(token_v, stmt) + elif token_v == "as": + # Don't suggest anything for aliases + return () + elif (token_v.endswith("join") and token.is_keyword) or ( + token_v in ("copy", "from", "update", "into", "describe", "truncate") + ): + + schema = stmt.get_identifier_schema() + tables = extract_tables(stmt.text_before_cursor) + is_join = token_v.endswith("join") and token.is_keyword + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + if token_v == "from" or is_join: + suggest.append( + FromClauseItem( + schema=schema, table_refs=tables, local_tables=stmt.local_tables + ) + ) + elif token_v == "truncate": + suggest.append(Table(schema)) + else: + suggest.extend((Table(schema), View(schema))) + + if is_join and _allow_join(stmt.parsed): + tables = stmt.get_tables("before") + suggest.append(Join(table_refs=tables, schema=schema)) + + return tuple(suggest) + + elif token_v == "function": + schema = stmt.get_identifier_schema() + + # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:` + try: + prev = stmt.get_previous_token(token).value.lower() + if prev in ("drop", "alter", "create", "create or replace"): + + # Suggest functions from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + suggest.append(Function(schema=schema, usage="signature")) + return tuple(suggest) + + except ValueError: + pass + return tuple() + + elif token_v in ("table", "view"): + # E.g. 'ALTER TABLE <tablname>' + rel_type = {"table": Table, "view": View, "function": Function}[token_v] + schema = stmt.get_identifier_schema() + if schema: + return (rel_type(schema=schema),) + else: + return (Schema(), rel_type(schema=schema)) + + elif token_v == "column": + # E.g. 'ALTER TABLE foo ALTER COLUMN bar + return (Column(table_refs=stmt.get_tables()),) + + elif token_v == "on": + tables = stmt.get_tables("before") + parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None + if parent: + # "ON parent.<suggestion>" + # parent can be either a schema name or table alias + filteredtables = tuple(t for t in tables if identifies(parent, t)) + sugs = [ + Column(table_refs=filteredtables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ] + if filteredtables and _allow_join_condition(stmt.parsed): + sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1])) + return tuple(sugs) + else: + # ON <suggestion> + # Use table alias if there is one, otherwise the table name + aliases = tuple(t.ref for t in tables) + if _allow_join_condition(stmt.parsed): + return ( + Alias(aliases=aliases), + JoinCondition(table_refs=tables, parent=None), + ) + else: + return (Alias(aliases=aliases),) + + elif token_v in ("c", "use", "database", "template"): + # "\c <db", "use <db>", "DROP DATABASE <db>", + # "CREATE DATABASE <newdb> WITH TEMPLATE <db>" + return (Database(),) + elif token_v == "schema": + # DROP SCHEMA schema_name, SET SCHEMA schema name + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2) + quoted = prev_keyword and prev_keyword.value.lower() == "set" + return (Schema(quoted),) + elif token_v.endswith(",") or token_v in ("=", "and", "or"): + prev_keyword = stmt.reduce_to_prev_keyword() + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return () + elif token_v in ("type", "::"): + # ALTER TABLE foo SET DATA TYPE bar + # SELECT foo::bar + # Note that tables are a form of composite type in postgresql, so + # they're suggested here as well + schema = stmt.get_identifier_schema() + suggestions = [Datatype(schema=schema), Table(schema=schema)] + if not schema: + suggestions.append(Schema()) + return tuple(suggestions) + elif token_v in {"alter", "create", "drop"}: + return (Keyword(token_v.upper()),) + elif token.is_keyword: + # token is a keyword we haven't implemented any special handling for + # go backwards in the query until we find one we do recognize + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1) + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return (Keyword(token_v.upper()),) + else: + return (Keyword(),) + + +def _suggest_expression(token_v, stmt): + """ + Return suggestions for an expression, taking account of any partially-typed + identifier's parent, which may be a table alias or schema name. + """ + parent = stmt.identifier.get_parent_name() if stmt.identifier else [] + tables = stmt.get_tables() + + if parent: + tables = tuple(t for t in tables if identifies(parent, t)) + return ( + Column(table_refs=tables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent), + ) + + return ( + Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True), + Function(schema=None), + Keyword(token_v.upper()), + ) + + +def identifies(id, ref): + """Returns true if string `id` matches TableReference `ref`""" + + return ( + id == ref.alias + or id == ref.name + or (ref.schema and (id == ref.schema + "." + ref.name)) + ) + + +def _allow_join_condition(statement): + """ + Tests if a join condition should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b on a.id = <cursor> + So check that the preceding token is a ON, AND, or OR keyword, instead of + e.g. an equals sign. + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower() in ("on", "and", "or") + + +def _allow_join(statement): + """ + Tests if a join should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b <cursor> + So check that the preceding token is a JOIN keyword + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in ( + "cross join", + "natural join", + ) |