"""Decide what kind of completion to offer for partially typed SQL.

The public entry point is :func:`suggest_type`.  It returns a tuple of the
lightweight "suggestion" namedtuples defined below (``Table``, ``Column``,
``Keyword``, ...).  Each one names a category of object to complete, plus
whatever scope information (schema, tables in scope, ...) narrows that
category.
"""

import sys
import re
import sqlparse
from collections import namedtuple
from sqlparse.sql import Comparison, Identifier, Where
from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier
from .parseutils.tables import extract_tables
from .parseutils.ctes import isolate_query_ctes
from pgspecial.main import parse_special_command


# ---------------------------------------------------------------------------
# Suggestion types returned by suggest_type()
# ---------------------------------------------------------------------------

Special = namedtuple("Special", [])
Database = namedtuple("Database", [])
Schema = namedtuple("Schema", ["quoted"])
Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
TableFormat = namedtuple("TableFormat", [])
View = namedtuple("View", ["schema", "table_refs"])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple("Join", ["table_refs", "schema"])

Function = namedtuple("Function", ["schema", "table_refs", "usage"])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple(), tuple())
View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())

Column = namedtuple(
    "Column",
    ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"],
)
Column.__new__.__defaults__ = (None, None, tuple(), False, None)

Keyword = namedtuple("Keyword", ["last_token"])
Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple("NamedQuery", [])
Datatype = namedtuple("Datatype", ["schema"])
Alias = namedtuple("Alias", ["aliases"])

Path = namedtuple("Path", [])


class SqlStatement:
    """The statement under the cursor, parsed and ready for inspection.

    Attributes set by ``__init__``:

    * ``word_before_cursor`` -- the (possibly partial) word just before the
      cursor, as returned by ``last_word(..., include="many_punctuations")``.
    * ``identifier`` -- ``parse_partial_identifier(word_before_cursor)`` when
      that word was stripped from the text before parsing, else ``None``.
    * ``local_tables`` -- the third value returned by ``isolate_query_ctes``
      (presumably the tables defined by the query's CTEs).
    * ``text_before_cursor_including_last_word`` -- the text before the
      cursor after named-query/CTE stripping, still including the last word.
    * ``full_text`` / ``text_before_cursor`` -- narrowed to the single
      statement (or function body) that contains the cursor.
    * ``parsed`` -- the sqlparse statement for that text, or ``None`` when
      there is nothing to parse.
    * ``last_token`` -- the last token before the cursor (found with
      ``token_prev``), or ``""`` if there is none.
    """

    def __init__(self, full_text, text_before_cursor):
        self.identifier = None
        self.word_before_cursor = word_before_cursor = last_word(
            text_before_cursor, include="many_punctuations"
        )
        full_text = _strip_named_query(full_text)
        text_before_cursor = _strip_named_query(text_before_cursor)

        full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
            full_text, text_before_cursor
        )

        self.text_before_cursor_including_last_word = text_before_cursor

        # If we've partially typed a word then word_before_cursor won't be an
        # empty string. In that case we want to remove the partially typed
        # string before sending it to the sqlparser. Otherwise the last token
        # will always be the partially typed string which renders the smart
        # completion useless because it will always return the list of
        # keywords as completion.
        if self.word_before_cursor:
            if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
                parsed = sqlparse.parse(text_before_cursor)
            else:
                text_before_cursor = text_before_cursor[: -len(word_before_cursor)]
                parsed = sqlparse.parse(text_before_cursor)
                self.identifier = parse_partial_identifier(word_before_cursor)
        else:
            parsed = sqlparse.parse(text_before_cursor)

        full_text, text_before_cursor, parsed = _split_multiple_statements(
            full_text, text_before_cursor, parsed
        )

        self.full_text = full_text
        self.text_before_cursor = text_before_cursor
        self.parsed = parsed

        # token_prev() yields None when nothing but whitespace precedes the
        # end of the statement; normalise that (and parsed=None) to "".
        last = parsed.token_prev(len(parsed.tokens))[1] if parsed else None
        self.last_token = last or ""

    def is_insert(self):
        """Return True if the statement's first token is ``INSERT``."""
        if self.parsed is None:
            return False
        first = self.parsed.token_first()
        # A whitespace-only statement has no first token.
        return first is not None and first.value.lower() == "insert"

    def get_tables(self, scope="full"):
        """Gets the tables available in the statement.
        param `scope:` possible values: 'full', 'insert', 'before'
        If 'insert', only the first table is returned.
        If 'before', only tables before the cursor are returned.
        If not 'insert' and the stmt is an insert, the first table is skipped.
        """
        tables = extract_tables(
            self.full_text if scope == "full" else self.text_before_cursor
        )
        if scope == "insert":
            tables = tables[:1]
        elif self.is_insert():
            tables = tables[1:]
        return tables

    def get_previous_token(self, token):
        """Return the token before ``token`` in the parsed statement.

        ``token_index`` raises ValueError when ``token`` is not a direct
        child of the parsed statement; callers are expected to handle that.
        """
        return self.parsed.token_prev(self.parsed.token_index(token))[1]

    def get_identifier_schema(self):
        """Return the schema qualifier of the partial identifier, if any.

        Unquoted schema names are lower-cased; quoted ones are kept verbatim.
        """
        schema = (self.identifier and self.identifier.get_parent_name()) or None
        # If schema name is unquoted, lower-case it
        if schema and self.identifier.value[0] != '"':
            schema = schema.lower()

        return schema

    def reduce_to_prev_keyword(self, n_skip=0):
        """Rewind ``text_before_cursor`` to the previous keyword and return it.

        Delegates to ``find_prev_keyword``; note that this mutates
        ``self.text_before_cursor``.
        """
        prev_keyword, self.text_before_cursor = find_prev_keyword(
            self.text_before_cursor, n_skip=n_skip
        )
        return prev_keyword


def suggest_type(full_text, text_before_cursor):
    """Suggest what kinds of objects to complete at the cursor.

    :param full_text: the full text typed so far.
    :param text_before_cursor: the part of that text before the cursor.
    :return: a tuple of suggestion namedtuples (``Table``, ``Column``,
        ``Keyword``, ...).  Column suggestions carry the tables in scope in
        their ``table_refs`` field.  An empty list is returned when building
        the statement raises TypeError or AttributeError.
    """
    if full_text.startswith("\\i "):
        return (Path(),)

    # This is a temporary hack; the exception handling
    # here should be removed once sqlparse has been fixed
    try:
        stmt = SqlStatement(full_text, text_before_cursor)
    except (TypeError, AttributeError):
        return []

    # Check for special commands and handle those separately
    if stmt.parsed:
        # Be careful here because trivial whitespace is parsed as a
        # statement, but the statement won't have a first token
        tok1 = stmt.parsed.token_first()
        if tok1 and tok1.value.startswith("\\"):
            text = stmt.text_before_cursor + stmt.word_before_cursor
            return suggest_special(text)

    return suggest_based_on_last_token(stmt.last_token, stmt)


# Matches a leading "\ns <name> " (save named query) prefix.  The name class
# lists A-Z and a-z separately on purpose: the ASCII range "A-z" would also
# admit the punctuation characters [ \ ] ^ _ and ` that sit between Z and a.
named_query_regex = re.compile(r"^\s*\\ns\s+[A-Za-z0-9\-_]+\s+")


def _strip_named_query(txt):
    """
    This will strip "save named query" command in the beginning of the line:
    '\ns zzz SELECT * FROM abc'   -> 'SELECT * FROM abc'
    '  \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
    """

    if named_query_regex.match(txt):
        txt = named_query_regex.sub("", txt)
    return txt


# A dollar-quoted string: an opening "$tag$", the body (group 2), and the same
# "$tag$" again via the \1 back-reference.
function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)


def _find_function_body(text):
    """Return ``(start, end)`` offsets of the first dollar-quoted body in
    ``text``, or ``(None, None)`` if there is none."""
    split = function_body_pattern.search(text)
    return (split.start(2), split.end(2)) if split else (None, None)


def _statement_from_function(full_text, text_before_cursor, statement):
    """Narrow analysis to the SQL inside a function's dollar-quoted body.

    If the cursor lies within the body, the body text is re-parsed and split
    with :func:`_split_multiple_statements`; otherwise the inputs are
    returned unchanged.
    """
    current_pos = len(text_before_cursor)
    body_start, body_end = _find_function_body(full_text)
    if body_start is None:
        return full_text, text_before_cursor, statement
    # Both ends are inclusive: a cursor sitting right before the closing
    # "$tag$" (e.g. in an empty "$$$$" body) is still inside the body.
    if not body_start <= current_pos <= body_end:
        return full_text, text_before_cursor, statement
    full_text = full_text[body_start:body_end]
    text_before_cursor = text_before_cursor[body_start:]
    parsed = sqlparse.parse(text_before_cursor)
    return _split_multiple_statements(full_text, text_before_cursor, parsed)


def _split_multiple_statements(full_text, text_before_cursor, parsed):
    """Isolate the single statement that contains the cursor.

    :param parsed: the result of ``sqlparse.parse(text_before_cursor)``.
    :return: ``(full_text, text_before_cursor, statement)`` narrowed to that
        statement; ``statement`` is ``None`` when ``parsed`` is empty.
        CREATE [OR REPLACE] FUNCTION statements are further narrowed to the
        function body when the cursor is inside it.
    """
    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds
        # the current position
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(str(statement))
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                text_before_cursor = full_text[stmt_start:current_pos]
                full_text = full_text[stmt_start:]
                break

    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        return full_text, text_before_cursor, None

    token2 = None
    if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
        token1 = statement.token_first()
        if token1:
            token1_idx = statement.token_index(token1)
            token2 = statement.token_next(token1_idx)[1]
    if token2 and token2.value.upper() == "FUNCTION":
        full_text, text_before_cursor, statement = _statement_from_function(
            full_text, text_before_cursor, statement
        )
    return full_text, text_before_cursor, statement


# Special command (without its leading backslash) -> suggestion type for the
# relation names it takes as an argument.
SPECIALS_SUGGESTION = {
    "dT": Datatype,
    "df": Function,
    "dt": Table,
    "dv": View,
    "sf": Function,
}


def suggest_special(text):
    """Return suggestions for a backslash special command such as ``\\d``.

    :param text: the command text before the cursor (leading whitespace is
        ignored).
    :return: a tuple of suggestion namedtuples.
    """
    text = text.lstrip()
    cmd, _, arg = parse_special_command(text)

    if cmd == text:
        # Trying to complete the special command itself
        return (Special(),)

    if cmd in ("\\c", "\\connect"):
        return (Database(),)

    if cmd == "\\T":
        return (TableFormat(),)

    if cmd == "\\dn":
        return (Schema(),)

    if arg:
        # Try to distinguish "\d name" from "\d schema.name"
        # Note that this will fail to obtain a schema name if wildcards are
        # used, e.g. "\d schema???.name"
        parsed = sqlparse.parse(arg)[0].tokens[0]
        try:
            schema = parsed.get_parent_name()
        except AttributeError:
            schema = None
    else:
        schema = None

    if cmd[1:] == "d":
        # \d can describe tables or views
        if schema:
            return (Table(schema=schema), View(schema=schema))
        else:
            return (Schema(), Table(schema=None), View(schema=None))
    elif cmd[1:] in SPECIALS_SUGGESTION:
        rel_type = SPECIALS_SUGGESTION[cmd[1:]]
        if schema:
            if rel_type == Function:
                return (Function(schema=schema, usage="special"),)
            return (rel_type(schema=schema),)
        else:
            if rel_type == Function:
                return (Schema(), Function(schema=None, usage="special"))
            return (Schema(), rel_type(schema=None))

    if cmd in ["\\n", "\\ns", "\\nd"]:
        return (NamedQuery(),)

    return (Keyword(), Special())


def suggest_based_on_last_token(token, stmt):
    """Return suggestions based on the token just before the cursor.

    :param token: either a string (e.g. ``"where"``/``"type"`` on recursive
        calls, or ``""`` when nothing precedes the cursor) or an sqlparse
        token.
    :param stmt: the :class:`SqlStatement` being completed.  Several branches
        rewind ``stmt.text_before_cursor`` via ``reduce_to_prev_keyword`` and
        recurse on the keyword found.
    :return: a tuple of suggestion namedtuples (possibly empty).
    """
    if isinstance(token, str):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly
        prev_keyword = stmt.reduce_to_prev_keyword()
        return suggest_based_on_last_token(prev_keyword, stmt)
    elif isinstance(token, Identifier):
        # If the previous token is an identifier, we can suggest datatypes if
        # we're in a parenthesized column/field list, e.g. ("|" is the cursor):
        #       CREATE TABLE foo (Identifier |
        #       CREATE FUNCTION foo (Identifier |
        # If we're not in a parenthesized list, the most likely scenario is the
        # user is about to specify an alias, e.g.:
        #       SELECT Identifier |
        #       SELECT foo FROM Identifier |
        prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
        if prev_keyword and prev_keyword.value == "(":
            # Suggest datatypes
            return suggest_based_on_last_token("type", stmt)
        else:
            return (Keyword(),)
    else:
        token_v = token.value.lower()

    if not token:
        return (Keyword(), Special())
    elif token_v.endswith("("):
        p = sqlparse.parse(stmt.text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            #  1 - Parenthesized clause like "WHERE foo AND ("
            #        Suggest columns/functions
            #  2 - Function call like "WHERE foo("
            #        Suggest columns/functions
            #  3 - Subquery expression like "WHERE EXISTS ("
            #        Suggest keywords, in order to do a subquery
            #  4 - Subquery OR array comparison like "WHERE foo = ANY("
            #        Suggest columns/functions AND keywords. (If we wanted to be
            #        really fancy, we could suggest only array-typed columns)

            column_suggestions = suggest_based_on_last_token("where", stmt)

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            prev_tok = where.token_prev(len(where.tokens) - 1)[1]

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            prev_tok = prev_tok.value.lower()
            if prev_tok == "exists":
                return (Keyword(),)
            else:
                return column_suggestions

        # Get the token before the parens
        prev_tok = p.token_prev(len(p.tokens) - 1)[1]

        if (
            prev_tok
            and prev_tok.value
            and prev_tok.value.lower().split(" ")[-1] == "using"
        ):
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = stmt.get_tables("before")

            # suggest columns that are present in more than one table
            return (
                Column(
                    table_refs=tables,
                    require_last_table=True,
                    local_tables=stmt.local_tables,
                ),
            )

        elif p.token_first().value.lower() == "select":
            # If the lparen is preceded by a space chances are we're about to
            # do a sub-select.
            if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("):
                return (Keyword(),)
        prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
        if prev_prev_tok and prev_prev_tok.normalized == "INTO":
            return (Column(table_refs=stmt.get_tables("insert"), context="insert"),)
        # We're probably in a function argument list
        return _suggest_expression(token_v, stmt)
    elif token_v == "set":
        return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),)
    elif token_v in ("select", "where", "having", "order by", "distinct"):
        return _suggest_expression(token_v, stmt)
    elif token_v == "as":
        # Don't suggest anything for aliases
        return ()
    elif (token_v.endswith("join") and token.is_keyword) or (
        token_v in ("copy", "from", "update", "into", "describe", "truncate")
    ):
        schema = stmt.get_identifier_schema()
        tables = extract_tables(stmt.text_before_cursor)
        is_join = token_v.endswith("join") and token.is_keyword

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = []

        if not schema:
            # Suggest schemas
            suggest.insert(0, Schema())

        if token_v == "from" or is_join:
            suggest.append(
                FromClauseItem(
                    schema=schema, table_refs=tables, local_tables=stmt.local_tables
                )
            )
        elif token_v == "truncate":
            suggest.append(Table(schema))
        else:
            suggest.extend((Table(schema), View(schema)))

        if is_join and _allow_join(stmt.parsed):
            tables = stmt.get_tables("before")
            suggest.append(Join(table_refs=tables, schema=schema))

        return tuple(suggest)

    elif token_v == "function":
        schema = stmt.get_identifier_schema()

        # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:`
        try:
            prev = stmt.get_previous_token(token).value.lower()
            if prev in ("drop", "alter", "create", "create or replace"):
                # Suggest functions from either the currently-selected schema or the
                # public schema if no schema has been specified
                suggest = []

                if not schema:
                    # Suggest schemas
                    suggest.insert(0, Schema())

                suggest.append(Function(schema=schema, usage="signature"))
                return tuple(suggest)

        except ValueError:
            pass
        return tuple()

    elif token_v in ("table", "view"):
        # E.g. 'ALTER TABLE <tablename>'
        rel_type = Table if token_v == "table" else View
        schema = stmt.get_identifier_schema()
        if schema:
            return (rel_type(schema=schema),)
        else:
            return (Schema(), rel_type(schema=schema))

    elif token_v == "column":
        # E.g. 'ALTER TABLE foo ALTER COLUMN bar'
        return (Column(table_refs=stmt.get_tables()),)

    elif token_v == "on":
        tables = stmt.get_tables("before")
        parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None
        if parent:
            # "ON parent.|"
            # parent can be either a schema name or table alias
            filteredtables = tuple(t for t in tables if identifies(parent, t))
            sugs = [
                Column(table_refs=filteredtables, local_tables=stmt.local_tables),
                Table(schema=parent),
                View(schema=parent),
                Function(schema=parent),
            ]
            if filteredtables and _allow_join_condition(stmt.parsed):
                sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1]))
            return tuple(sugs)
        else:
            # "ON |"
            # Use table alias if there is one, otherwise the table name
            aliases = tuple(t.ref for t in tables)
            if _allow_join_condition(stmt.parsed):
                return (
                    Alias(aliases=aliases),
                    JoinCondition(table_refs=tables, parent=None),
                )
            else:
                return (Alias(aliases=aliases),)

    elif token_v in ("c", "use", "database", "template"):
        # "\c <db>", "use <db>", "DROP DATABASE <db>",
        # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
        return (Database(),)
    elif token_v == "schema":
        # DROP SCHEMA schema_name, SET SCHEMA schema name
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
        quoted = prev_keyword and prev_keyword.value.lower() == "set"
        return (Schema(quoted),)
    elif token_v.endswith(",") or token_v in ("=", "and", "or"):
        prev_keyword = stmt.reduce_to_prev_keyword()
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        else:
            return ()
    elif token_v in ("type", "::"):
        #   ALTER TABLE foo SET DATA TYPE bar
        #   SELECT foo::bar
        # Note that tables are a form of composite type in postgresql, so
        # they're suggested here as well
        schema = stmt.get_identifier_schema()
        suggestions = [Datatype(schema=schema), Table(schema=schema)]
        if not schema:
            suggestions.append(Schema())
        return tuple(suggestions)
    elif token_v in {"alter", "create", "drop"}:
        return (Keyword(token_v.upper()),)
    elif token.is_keyword:
        # token is a keyword we haven't implemented any special handling for
        # go backwards in the query until we find one we do recognize
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        else:
            return (Keyword(token_v.upper()),)
    else:
        return (Keyword(),)


def _suggest_expression(token_v, stmt):
    """
    Return suggestions for an expression, taking account of any partially-typed
    identifier's parent, which may be a table alias or schema name.
    """
    parent = stmt.identifier.get_parent_name() if stmt.identifier else None
    tables = stmt.get_tables()

    if parent:
        tables = tuple(t for t in tables if identifies(parent, t))
        return (
            Column(table_refs=tables, local_tables=stmt.local_tables),
            Table(schema=parent),
            View(schema=parent),
            Function(schema=parent),
        )

    return (
        Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True),
        Function(schema=None),
        Keyword(token_v.upper()),
    )


def identifies(id, ref):
    """Returns true if string `id` matches TableReference `ref`"""

    return (
        id == ref.alias
        or id == ref.name
        or (ref.schema and (id == ref.schema + "." + ref.name))
    )


def _allow_join_condition(statement):
    """
    Tests if a join condition should be suggested

    We need this to avoid bad suggestions when entering e.g.
        select * from tbl1 a join tbl2 b on a.id = <cursor>
    So check that the preceding token is a ON, AND, or OR keyword, instead of
    e.g. an equals sign.

    :param statement: an sqlparse.sql.Statement
    :return: boolean
    """

    if not statement or not statement.tokens:
        return False

    last_tok = statement.token_prev(len(statement.tokens))[1]
    # token_prev yields no token when only whitespace precedes the end.
    if last_tok is None:
        return False
    return last_tok.value.lower() in ("on", "and", "or")


def _allow_join(statement):
    """
    Tests if a join should be suggested

    We need this to avoid bad suggestions when entering e.g.
        select * from tbl1 a join tbl2 b <cursor>
    So check that the preceding token is a JOIN keyword

    :param statement: an sqlparse.sql.Statement
    :return: boolean
    """

    if not statement or not statement.tokens:
        return False

    last_tok = statement.token_prev(len(statement.tokens))[1]
    # token_prev yields no token when only whitespace precedes the end.
    if last_tok is None:
        return False
    return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in (
        "cross join",
        "natural join",
    )