summaryrefslogtreecommitdiffstats
path: root/pgcli/packages
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2021-02-08 10:31:05 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2021-02-08 10:31:05 +0000
commit6884720fae8a2622b14e93d9e35ca5fcc2283b40 (patch)
treedf6f736bb623cdd7932bbe2256101a6ac4ef7f35 /pgcli/packages
parentInitial commit. (diff)
downloadpgcli-6884720fae8a2622b14e93d9e35ca5fcc2283b40.tar.xz
pgcli-6884720fae8a2622b14e93d9e35ca5fcc2283b40.zip
Adding upstream version 3.1.0.upstream/3.1.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'pgcli/packages')
-rw-r--r--pgcli/packages/__init__.py0
-rw-r--r--pgcli/packages/parseutils/__init__.py22
-rw-r--r--pgcli/packages/parseutils/ctes.py141
-rw-r--r--pgcli/packages/parseutils/meta.py170
-rw-r--r--pgcli/packages/parseutils/tables.py170
-rw-r--r--pgcli/packages/parseutils/utils.py140
-rw-r--r--pgcli/packages/pgliterals/__init__.py0
-rw-r--r--pgcli/packages/pgliterals/main.py15
-rw-r--r--pgcli/packages/pgliterals/pgliterals.json629
-rw-r--r--pgcli/packages/prioritization.py51
-rw-r--r--pgcli/packages/prompt_utils.py35
-rw-r--r--pgcli/packages/sqlcompletion.py608
12 files changed, 1981 insertions, 0 deletions
diff --git a/pgcli/packages/__init__.py b/pgcli/packages/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pgcli/packages/__init__.py
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
new file mode 100644
index 0000000..a11e7bf
--- /dev/null
+++ b/pgcli/packages/parseutils/__init__.py
@@ -0,0 +1,22 @@
+import sqlparse
+
+
def query_starts_with(query, prefixes):
    """Return True if the first word of *query* is one of *prefixes*.

    Comments are stripped and matching is case-insensitive.
    """
    lowered_prefixes = [prefix.lower() for prefix in prefixes]
    stripped = sqlparse.format(query.lower(), strip_comments=True).strip()
    if not stripped:
        return False
    return stripped.split()[0] in lowered_prefixes
+
+
def queries_start_with(queries, prefixes):
    """Return True if any statement in *queries* starts with one of *prefixes*."""
    return any(
        query_starts_with(query, prefixes)
        for query in sqlparse.split(queries)
        if query
    )
+
+
def is_destructive(queries):
    """Return True when any statement in *queries* is potentially destructive."""
    destructive_keywords = ("drop", "shutdown", "delete", "truncate", "alter")
    return queries_start_with(queries, destructive_keywords)
diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py
new file mode 100644
index 0000000..e1f9088
--- /dev/null
+++ b/pgcli/packages/parseutils/ctes.py
@@ -0,0 +1,141 @@
+from sqlparse import parse
+from sqlparse.tokens import Keyword, CTE, DML
+from sqlparse.sql import Identifier, IdentifierList, Parenthesis
+from collections import namedtuple
+from .meta import TableMetadata, ColumnMetadata
+
+
+# TableExpression is a namedtuple representing a CTE, used internally
+# name: cte alias assigned in the query
+# columns: list of column names
+# start: index into the original string of the left parens starting the CTE
+# stop: index into the original string of the right parens ending the CTE
+TableExpression = namedtuple("TableExpression", "name columns start stop")
+
+
def isolate_query_ctes(full_text, text_before_cursor):
    """Simplify a query by converting CTEs into table metadata objects

    Returns (full_text, text_before_cursor, meta).  If the cursor is inside
    a CTE body, that body becomes the new full_text; otherwise the CTEs are
    stripped off and returned as a tuple of TableMetadata objects.
    """

    if not full_text or not full_text.strip():
        return full_text, text_before_cursor, tuple()

    ctes, remainder = extract_ctes(full_text)
    if not ctes:
        return full_text, text_before_cursor, ()

    # Cursor offset within the original text
    current_position = len(text_before_cursor)
    meta = []

    for cte in ctes:
        if cte.start < current_position < cte.stop:
            # Currently editing a cte - treat its body as the current full_text
            text_before_cursor = full_text[cte.start : current_position]
            full_text = full_text[cte.start : cte.stop]
            return full_text, text_before_cursor, meta

        # Append this cte to the list of available table metadata
        # (cols is a generator; consumed by whoever reads the TableMetadata)
        cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
        meta.append(TableMetadata(cte.name, cols))

    # Editing past the last cte (ie the main body of the query)
    full_text = full_text[ctes[-1].stop :]
    text_before_cursor = text_before_cursor[ctes[-1].stop : current_position]

    return full_text, text_before_cursor, tuple(meta)
+
+
def extract_ctes(sql):
    """Extract common table expressions (CTEs) from a query

    Returns tuple (ctes, remainder_sql)

    ctes is a list of TableExpression namedtuples
    remainder_sql is the text from the original query after the CTEs have
    been stripped.
    """

    p = parse(sql)[0]

    # Make sure the first meaningful token is "WITH" which is necessary to
    # define CTEs
    idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
    if not (tok and tok.ttype == CTE):
        return [], sql

    # Get the next (meaningful) token, which should be the first CTE
    idx, tok = p.token_next(idx)
    if not tok:
        return ([], "")
    # Character offset of the first CTE within the original string
    start_pos = token_start_pos(p.tokens, idx)
    ctes = []

    if isinstance(tok, IdentifierList):
        # Multiple ctes
        for t in tok.get_identifiers():
            cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
            cte = get_cte_from_token(t, start_pos + cte_start_offset)
            if not cte:
                continue
            ctes.append(cte)
    elif isinstance(tok, Identifier):
        # A single CTE
        cte = get_cte_from_token(tok, start_pos)
        if cte:
            ctes.append(cte)

    idx = p.token_index(tok) + 1

    # Collapse everything after the ctes into a remainder query
    remainder = "".join(str(tok) for tok in p.tokens[idx:])

    return ctes, remainder
+
+
def get_cte_from_token(tok, pos0):
    """Build a TableExpression from the Identifier token *tok*.

    *pos0* is the character offset of *tok* within the original query.
    Returns None when the token has no real name or no parenthesized body.
    """
    cte_name = tok.get_real_name()
    if not cte_name:
        return None

    # Find the start position of the opening parens enclosing the cte body
    idx, parens = tok.token_next_by(Parenthesis)
    if not parens:
        return None

    start_pos = pos0 + token_start_pos(tok.tokens, idx)
    cte_len = len(str(parens))  # includes parens
    stop_pos = start_pos + cte_len

    column_names = extract_column_names(parens)

    return TableExpression(cte_name, column_names, start_pos, stop_pos)
+
+
def extract_column_names(parsed):
    """Return a tuple of column names produced by the CTE body *parsed*.

    For SELECT bodies the select list is used; for INSERT/UPDATE/DELETE
    bodies the RETURNING clause is used.  Anything else yields ().
    """
    # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
    idx, tok = parsed.token_next_by(t=DML)
    tok_val = tok and tok.value.lower()

    if tok_val in ("insert", "update", "delete"):
        # Jump ahead to the RETURNING clause where the list of column names is
        idx, tok = parsed.token_next_by(idx, (Keyword, "returning"))
    elif not tok_val == "select":
        # Must be invalid CTE
        return ()

    # The next token should be either a column name, or a list of column names
    idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
    return tuple(t.get_name() for t in _identifiers(tok))
+
+
def token_start_pos(tokens, idx):
    """Character offset of tokens[idx]: total string length of tokens[:idx]."""
    total = 0
    for token in tokens[:idx]:
        total += len(str(token))
    return total
+
+
def _identifiers(tok):
    """Yield the Identifier tokens contained in *tok* (or *tok* itself)."""
    if isinstance(tok, Identifier):
        yield tok
    elif isinstance(tok, IdentifierList):
        # NB: IdentifierList.get_identifiers() can return non-identifiers,
        # so they must be filtered out here.
        for candidate in tok.get_identifiers():
            if isinstance(candidate, Identifier):
                yield candidate
diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py
new file mode 100644
index 0000000..108c01a
--- /dev/null
+++ b/pgcli/packages/parseutils/meta.py
@@ -0,0 +1,170 @@
+from collections import namedtuple
+
# Underlying record type; constructed through the ColumnMetadata factory below
_ColumnMetadata = namedtuple(
    "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]
)


def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
    """Build a ColumnMetadata record, substituting a fresh list for a
    missing/falsy *foreignkeys* value."""
    keys = foreignkeys or []
    return _ColumnMetadata(name, datatype, keys, default, has_default)
+
+
# One foreign-key relationship, each side qualified by schema and table.
# NOTE(review): naming suggests "parent" is the referenced table and
# "child" the referencing one — confirm against callers.
ForeignKey = namedtuple(
    "ForeignKey",
    [
        "parentschema",
        "parenttable",
        "parentcolumn",
        "childschema",
        "childtable",
        "childcolumn",
    ],
)
# A table name plus its columns (an iterable of ColumnMetadata records)
TableMetadata = namedtuple("TableMetadata", "name columns")
+
+
def parse_defaults(defaults_string):
    """Yields default values for a function, given the string provided by
    pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)

    Expressions are comma-separated; commas inside single- or double-quoted
    sections do not split expressions.
    """
    if not defaults_string:
        return
    expr = ""
    open_quote = None
    for ch in defaults_string:
        if ch == " " and not expr:
            # Skip the space following a separating comma
            continue
        if ch in ('"', "'"):
            if open_quote == ch:
                # End quote
                open_quote = None
            elif open_quote is None:
                # Begin quote
                open_quote = ch
        elif ch == "," and open_quote is None:
            # End of one default expression
            yield expr
            expr = ""
            continue
        expr += ch
    yield expr
+
+
class FunctionMetadata(object):
    """Describes a postgresql function: identity, argument metadata, return
    type and behavioral flags (aggregate/window/set-returning/extension)."""

    def __init__(
        self,
        schema_name,
        func_name,
        arg_names,
        arg_types,
        arg_modes,
        return_type,
        is_aggregate,
        is_window,
        is_set_returning,
        is_extension,
        arg_defaults,
    ):
        """Class for describing a postgresql function"""

        self.schema_name = schema_name
        self.func_name = func_name

        # Empty sequences are normalized to None
        self.arg_modes = tuple(arg_modes) if arg_modes else None
        self.arg_names = tuple(arg_names) if arg_names else None

        # Be flexible in not requiring arg_types -- use None as a placeholder
        # for each arg. (Used for compatibility with old versions of postgresql
        # where such info is hard to get.
        if arg_types:
            self.arg_types = tuple(arg_types)
        elif arg_modes:
            self.arg_types = tuple([None] * len(arg_modes))
        elif arg_names:
            self.arg_types = tuple([None] * len(arg_names))
        else:
            self.arg_types = None

        # arg_defaults is the pg_get_expr() string; parsed into a tuple of
        # per-argument default expressions (see parse_defaults)
        self.arg_defaults = tuple(parse_defaults(arg_defaults))

        self.return_type = return_type.strip()
        self.is_aggregate = is_aggregate
        self.is_window = is_window
        self.is_set_returning = is_set_returning
        self.is_extension = bool(is_extension)
        # NB: falsy (not False) when schema_name is empty/None
        self.is_public = self.schema_name and self.schema_name == "public"

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        # Redundant on Python 3 (derived from __eq__) but harmless
        return not self.__eq__(other)

    def _signature(self):
        # Tuple of all identity-defining attributes; basis for hash and repr
        return (
            self.schema_name,
            self.func_name,
            self.arg_names,
            self.arg_types,
            self.arg_modes,
            self.return_type,
            self.is_aggregate,
            self.is_window,
            self.is_set_returning,
            self.is_extension,
            self.arg_defaults,
        )

    def __hash__(self):
        return hash(self._signature())

    def __repr__(self):
        return (
            "%s(schema_name=%r, func_name=%r, arg_names=%r, "
            "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
            "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)"
        ) % ((self.__class__.__name__,) + self._signature())

    def has_variadic(self):
        """True when any argument mode is 'v' (VARIADIC)."""
        return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes)

    def args(self):
        """Returns a list of input-parameter ColumnMetadata namedtuples."""
        if not self.arg_names:
            return []
        # With no explicit modes, every argument is an input ('i')
        modes = self.arg_modes or ["i"] * len(self.arg_names)
        args = [
            (name, typ)
            for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
            if mode in ("i", "b", "v")  # IN, INOUT, VARIADIC
        ]

        def arg(name, typ, num):
            # Defaults apply to the trailing arguments: argument *num* has a
            # default when num >= num_args - num_defaults
            num_args = len(args)
            num_defaults = len(self.arg_defaults)
            has_default = num + num_defaults >= num_args
            default = (
                self.arg_defaults[num - num_args + num_defaults]
                if has_default
                else None
            )
            return ColumnMetadata(name, typ, [], default, has_default)

        return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]

    def fields(self):
        """Returns a list of output-field ColumnMetadata namedtuples"""

        if self.return_type.lower() == "void":
            return []
        elif not self.arg_modes:
            # For functions without output parameters, the function name
            # is used as the name of the output column.
            # E.g. 'SELECT unnest FROM unnest(...);'
            return [ColumnMetadata(self.func_name, self.return_type, [])]

        return [
            ColumnMetadata(name, typ, [])
            for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes)
            if mode in ("o", "b", "t")
        ]  # OUT, INOUT, TABLE
diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py
new file mode 100644
index 0000000..0ec3e69
--- /dev/null
+++ b/pgcli/packages/parseutils/tables.py
@@ -0,0 +1,170 @@
+import sqlparse
+from collections import namedtuple
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
TableReference = namedtuple(
    "TableReference", ["schema", "name", "alias", "is_function"]
)


def _table_ref(self):
    """Name a query should use for this table: the alias when present,
    otherwise the name — double-quoted unless lowercase or already quoted."""
    if self.alias:
        return self.alias
    if self.name.islower() or self.name[0] == '"':
        return self.name
    return '"' + self.name + '"'


TableReference.ref = property(_table_ref)
+
+
# This code is borrowed from the sqlparse example script
# extract_table_names.py (see the examples/ directory of the sqlparse
# repository: https://github.com/andialbrecht/sqlparse).
def is_subselect(parsed):
    """Return True when *parsed* is a token group containing a DML keyword."""
    if not parsed.is_group:
        return False
    dml_values = ("SELECT", "INSERT", "UPDATE", "CREATE", "DELETE")
    return any(
        item.ttype is DML and item.value.upper() in dml_values
        for item in parsed.tokens
    )
+
+
def _identifier_is_function(identifier):
    """True when any child token of *identifier* is a Function call."""
    for child in identifier.tokens:
        if isinstance(child, Function):
            return True
    return False
+
+
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens of *parsed* that follow a table-introducing keyword
    (FROM/INTO/UPDATE/TABLE/COPY/...JOIN), recursing into sub-selects.

    When *stop_at_punctuation* is true, iteration stops at the first
    punctuation token after the table prefix (used for INSERT statements).
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
            # the second FROM to trigger this elif condition resulting in a
            # `return`. So we need to ignore the keyword if the keyword
            # FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its variants
            # INNER JOIN, FULL OUTER JOIN, etc.
            elif (
                item.ttype is Keyword
                and (not item.value.upper() == "FROM")
                and (not item.value.upper().endswith("JOIN"))
            ):
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if (
                item_val
                in (
                    "COPY",
                    "FROM",
                    "INTO",
                    "UPDATE",
                    "TABLE",
                )
                or item_val.endswith("JOIN")
            ):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
                    tbl_prefix_seen = True
                    break
+
+
def extract_table_identifiers(token_stream, allow_functions=True):
    """yields tuples of TableReference namedtuples"""

    # We need to do some massaging of the names because postgres is case-
    # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
    def parse_identifier(item):
        # Returns (schema, name, alias), lower-casing unquoted names and
        # preserving mixed case by quoting the alias where needed
        name = item.get_real_name()
        schema_name = item.get_parent_name()
        alias = item.get_alias()
        if not name:
            schema_name = None
            name = item.get_name()
            alias = alias or name
        schema_quoted = schema_name and item.value[0] == '"'
        if schema_name and not schema_quoted:
            schema_name = schema_name.lower()
        quote_count = item.value.count('"')
        name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
        alias_quoted = alias and item.value[-1] == '"'
        if alias_quoted or name_quoted and not alias and name.islower():
            alias = '"' + (alias or name) + '"'
        if name and not name_quoted and not name.islower():
            if not alias:
                alias = name
            name = name.lower()
        return schema_name, name, alias

    # NOTE(review): the try/except guards against a StopIteration leaking
    # out of *token_stream* (pre-PEP 479 generator behavior)
    try:
        for item in token_stream:
            if isinstance(item, IdentifierList):
                for identifier in item.get_identifiers():
                    # Sometimes Keywords (such as FROM ) are classified as
                    # identifiers which don't have the get_real_name() method.
                    try:
                        schema_name = identifier.get_parent_name()
                        real_name = identifier.get_real_name()
                        is_function = allow_functions and _identifier_is_function(
                            identifier
                        )
                    except AttributeError:
                        continue
                    if real_name:
                        yield TableReference(
                            schema_name, real_name, identifier.get_alias(), is_function
                        )
            elif isinstance(item, Identifier):
                schema_name, real_name, alias = parse_identifier(item)
                is_function = allow_functions and _identifier_is_function(item)

                yield TableReference(schema_name, real_name, alias, is_function)
            elif isinstance(item, Function):
                schema_name, real_name, alias = parse_identifier(item)
                yield TableReference(None, real_name, alias, allow_functions)
    except StopIteration:
        return
+
+
+# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
    """Extract the table names from an SQL statement.

    Returns a tuple of TableReference namedtuples.
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return ()

    # token_first() skips whitespace and returns None for whitespace- or
    # comment-only input, which would otherwise crash the .value access below
    first_token = parsed[0].token_first()
    if first_token is None:
        return ()

    # INSERT statements must stop looking for tables at the sign of first
    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
    # abc is the table name, but if we don't stop at the first lparen, then
    # we'll identify abc, col1 and col2 as table names.
    insert_stmt = first_token.value.lower() == "insert"
    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)

    # Kludge: sqlparse mistakenly identifies insert statements as
    # function calls due to the parenthesized column list, e.g. interprets
    # "insert into foo (bar, baz)" as a function call to foo with arguments
    # (bar, baz). So don't allow any identifiers in insert statements
    # to have is_function=True
    identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt)
    # In the case 'sche.<cursor>', we get an empty TableReference; remove that
    return tuple(i for i in identifiers if i.name)
diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py
new file mode 100644
index 0000000..034c96e
--- /dev/null
+++ b/pgcli/packages/parseutils/utils.py
@@ -0,0 +1,140 @@
+import re
+import sqlparse
+from sqlparse.sql import Identifier
+from sqlparse.tokens import Token, Error
+
# Regexes used by last_word() to pull the trailing "word" out of a piece of
# text; the key chooses which characters may belong to the word.
cleanup_regex = {
    # This matches only alphanumerics and underscores.
    "alphanum_underscore": re.compile(r"(\w+)$"),
    # This matches everything except spaces, parens, colon, and comma
    "many_punctuations": re.compile(r"([^():,\s]+)$"),
    # This matches everything except spaces, parens, colon, comma, and period
    "most_punctuations": re.compile(r"([^\.():,\s]+)$"),
    # This matches everything except a space.
    "all_punctuations": re.compile(r"([^\s]+)$"),
}
+
+
def last_word(text, include="alphanum_underscore"):
    r"""
    Find the last word in a sentence.

    >>> last_word('abc')
    'abc'
    >>> last_word(' abc')
    'abc'
    >>> last_word('')
    ''
    >>> last_word(' ')
    ''
    >>> last_word('abc ')
    ''
    >>> last_word('abc def')
    'def'
    >>> last_word('abc def ')
    ''
    >>> last_word('abc def;')
    ''
    >>> last_word('bac $def')
    'def'
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    >>> last_word('bac \def', include='most_punctuations')
    '\\\\def'
    >>> last_word('bac \def;', include='most_punctuations')
    '\\\\def;'
    >>> last_word('bac::def', include='most_punctuations')
    'def'
    >>> last_word('"foo*bar', include='most_punctuations')
    '"foo*bar'
    """
    # Empty input or trailing whitespace means there is no trailing word
    if not text or text[-1].isspace():
        return ""

    match = cleanup_regex[include].search(text)
    return match.group(0) if match else ""
+
+
def find_prev_keyword(sql, n_skip=0):
    """Find the last sql keyword in an SQL statement

    Returns the last keyword token (or opening-paren token) and the text of
    the query with everything after that token stripped; (None, "") when
    no keyword is found.  *n_skip* trailing tokens are ignored.
    """
    if not sql.strip():
        return None, ""

    parsed = sqlparse.parse(sql)[0]
    flattened = list(parsed.flatten())
    flattened = flattened[: len(flattened) - n_skip]

    # Logical operators don't count as "the last keyword" for completion
    logical_operators = ("AND", "OR", "NOT", "BETWEEN")

    for t in reversed(flattened):
        if t.value == "(" or (
            t.is_keyword and (t.value.upper() not in logical_operators)
        ):
            # Find the location of token t in the original parsed statement
            # We can't use parsed.token_index(t) because t may be a child token
            # inside a TokenList, in which case token_index throws an error
            # Minimal example:
            # p = sqlparse.parse('select * from foo where bar')
            # t = list(p.flatten())[-3] # The "Where" token
            # p.token_index(t) # Throws ValueError: not in list
            idx = flattened.index(t)

            # Combine the string values of all tokens in the original list
            # up to and including the target keyword token t, to produce a
            # query string with everything after the keyword token removed
            text = "".join(tok.value for tok in flattened[: idx + 1])
            return t, text

    return None, ""
+
+
# Postgresql dollar quote signs look like `$$` or `$tag$`
# NOTE(review): not referenced elsewhere in this module's visible code —
# presumably imported by other modules; confirm before removing.
dollar_quote_regex = re.compile(r"^\$[^$]*\$$")
+
+
def is_open_quote(sql):
    """Returns true if the query contains an unclosed quote"""

    # sql may hold several semicolon-separated statements; any one of them
    # containing an unterminated quote makes the whole input "open"
    for statement in sqlparse.parse(sql):
        if _parsed_is_open_quote(statement):
            return True
    return False
+
+
def _parsed_is_open_quote(parsed):
    """True when *parsed* contains an Error token for a dangling ' or $."""
    for tok in parsed.flatten():
        if tok.match(Token.Error, ("'", "$")):
            return True
    return False
+
+
def parse_partial_identifier(word):
    """Attempt to parse a (partially typed) word as an identifier

    word may include a schema qualification, like `schema_name.partial_name`
    or `schema_name.` There may also be unclosed quotation marks, like
    `"schema`, or `schema."partial_name`

    :param word: string representing a (partially complete) identifier
    :return: sqlparse.sql.Identifier, or None
    """

    p = sqlparse.parse(word)[0]
    n_tok = len(p.tokens)
    if n_tok == 1 and isinstance(p.tokens[0], Identifier):
        return p.tokens[0]
    elif p.token_next_by(m=(Error, '"'))[1]:
        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
        # Close the double quote, then reparse
        # (the appended quote closes the open one, so the recursion ends)
        return parse_partial_identifier(word + '"')
    else:
        return None
diff --git a/pgcli/packages/pgliterals/__init__.py b/pgcli/packages/pgliterals/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pgcli/packages/pgliterals/__init__.py
diff --git a/pgcli/packages/pgliterals/main.py b/pgcli/packages/pgliterals/main.py
new file mode 100644
index 0000000..5c39296
--- /dev/null
+++ b/pgcli/packages/pgliterals/main.py
@@ -0,0 +1,15 @@
+import os
+import json
+
# Directory of this module; pgliterals.json is shipped next to it
root = os.path.dirname(__file__)
literal_file = os.path.join(root, "pgliterals.json")

# Load the literal lists (keywords, functions, datatypes, ...) once at
# import time
with open(literal_file) as f:
    literals = json.load(f)
+
+
def get_literals(literal_type, type_=tuple):
    """Return the literals of the given type.

    *literal_type* is one of 'keywords', 'functions', 'datatypes', ...;
    the result is converted via *type_* (a tuple by default).
    """
    return type_(literals[literal_type])
diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json
new file mode 100644
index 0000000..c7b74b5
--- /dev/null
+++ b/pgcli/packages/pgliterals/pgliterals.json
@@ -0,0 +1,629 @@
+{
+ "keywords": {
+ "ACCESS": [],
+ "ADD": [],
+ "ALL": [],
+ "ALTER": [
+ "AGGREGATE",
+ "COLLATION",
+ "COLUMN",
+ "CONVERSION",
+ "DATABASE",
+ "DEFAULT",
+ "DOMAIN",
+ "EVENT TRIGGER",
+ "EXTENSION",
+ "FOREIGN",
+ "FUNCTION",
+ "GROUP",
+ "INDEX",
+ "LANGUAGE",
+ "LARGE OBJECT",
+ "MATERIALIZED VIEW",
+ "OPERATOR",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SCHEMA",
+ "SEQUENCE",
+ "SERVER",
+ "SYSTEM",
+ "TABLE",
+ "TABLESPACE",
+ "TEXT SEARCH",
+ "TRIGGER",
+ "TYPE",
+ "USER",
+ "VIEW"
+ ],
+ "AND": [],
+ "ANY": [],
+ "AS": [],
+ "ASC": [],
+ "AUDIT": [],
+ "BEGIN": [],
+ "BETWEEN": [],
+ "BY": [],
+ "CASE": [],
+ "CHAR": [],
+ "CHECK": [],
+ "CLUSTER": [],
+ "COLUMN": [],
+ "COMMENT": [],
+ "COMMIT": [],
+ "COMPRESS": [],
+ "CONCURRENTLY": [],
+ "CONNECT": [],
+ "COPY": [],
+ "CREATE": [
+ "ACCESS METHOD",
+ "AGGREGATE",
+ "CAST",
+ "COLLATION",
+ "CONVERSION",
+ "DATABASE",
+ "DOMAIN",
+ "EVENT TRIGGER",
+ "EXTENSION",
+ "FOREIGN DATA WRAPPER",
+ "FOREIGN EXTENSION",
+ "FUNCTION",
+ "GLOBAL",
+ "GROUP",
+ "IF NOT EXISTS",
+ "INDEX",
+ "LANGUAGE",
+ "LOCAL",
+ "MATERIALIZED VIEW",
+ "OPERATOR",
+ "OR REPLACE",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SCHEMA",
+ "SEQUENCE",
+ "SERVER",
+ "TABLE",
+ "TABLESPACE",
+ "TEMPORARY",
+ "TEXT SEARCH",
+ "TRIGGER",
+ "TYPE",
+ "UNIQUE",
+ "UNLOGGED",
+ "USER",
+ "USER MAPPING",
+ "VIEW"
+ ],
+ "CURRENT": [],
+ "DATABASE": [],
+ "DATE": [],
+ "DECIMAL": [],
+ "DEFAULT": [],
+ "DELETE FROM": [],
+ "DELIMITER": [],
+ "DESC": [],
+ "DESCRIBE": [],
+ "DISTINCT": [],
+ "DROP": [
+ "ACCESS METHOD",
+ "AGGREGATE",
+ "CAST",
+ "COLLATION",
+ "COLUMN",
+ "CONVERSION",
+ "DATABASE",
+ "DOMAIN",
+ "EVENT TRIGGER",
+ "EXTENSION",
+ "FOREIGN DATA WRAPPER",
+ "FOREIGN TABLE",
+ "FUNCTION",
+ "GROUP",
+ "INDEX",
+ "LANGUAGE",
+ "MATERIALIZED VIEW",
+ "OPERATOR",
+ "OWNED",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SCHEMA",
+ "SEQUENCE",
+ "SERVER",
+ "TABLE",
+ "TABLESPACE",
+ "TEXT SEARCH",
+ "TRANSFORM",
+ "TRIGGER",
+ "TYPE",
+ "USER",
+ "USER MAPPING",
+ "VIEW"
+ ],
+ "EXPLAIN": [],
+ "ELSE": [],
+ "ENCODING": [],
+ "ESCAPE": [],
+ "EXCLUSIVE": [],
+ "EXISTS": [],
+ "EXTENSION": [],
+ "FILE": [],
+ "FLOAT": [],
+ "FOR": [],
+ "FORMAT": [],
+ "FORCE_QUOTE": [],
+ "FORCE_NOT_NULL": [],
+ "FREEZE": [],
+ "FROM": [],
+ "FULL": [],
+ "FUNCTION": [],
+ "GRANT": [],
+ "GROUP BY": [],
+ "HAVING": [],
+ "HEADER": [],
+ "IDENTIFIED": [],
+ "IMMEDIATE": [],
+ "IN": [],
+ "INCREMENT": [],
+ "INDEX": [],
+ "INITIAL": [],
+ "INSERT INTO": [],
+ "INTEGER": [],
+ "INTERSECT": [],
+ "INTERVAL": [],
+ "INTO": [],
+ "IS": [],
+ "JOIN": [],
+ "LANGUAGE": [],
+ "LEFT": [],
+ "LEVEL": [],
+ "LIKE": [],
+ "LIMIT": [],
+ "LOCK": [],
+ "LONG": [],
+ "MATERIALIZED VIEW": [],
+ "MAXEXTENTS": [],
+ "MINUS": [],
+ "MLSLABEL": [],
+ "MODE": [],
+ "MODIFY": [],
+ "NOT": [],
+ "NOAUDIT": [],
+ "NOTICE": [],
+ "NOCOMPRESS": [],
+ "NOWAIT": [],
+ "NULL": [],
+ "NUMBER": [],
+ "OIDS": [],
+ "OF": [],
+ "OFFLINE": [],
+ "ON": [],
+ "ONLINE": [],
+ "OPTION": [],
+ "OR": [],
+ "ORDER BY": [],
+ "OUTER": [],
+ "OWNER": [],
+ "PCTFREE": [],
+ "PRIMARY": [],
+ "PRIOR": [],
+ "PRIVILEGES": [],
+ "QUOTE": [],
+ "RAISE": [],
+ "RENAME": [],
+ "REPLACE": [],
+ "RESET": ["ALL"],
+ "RAW": [],
+ "REFRESH MATERIALIZED VIEW": [],
+ "RESOURCE": [],
+ "RETURNS": [],
+ "REVOKE": [],
+ "RIGHT": [],
+ "ROLLBACK": [],
+ "ROW": [],
+ "ROWID": [],
+ "ROWNUM": [],
+ "ROWS": [],
+ "SELECT": [],
+ "SESSION": [],
+ "SET": [],
+ "SHARE": [],
+ "SHOW": [],
+ "SIZE": [],
+ "SMALLINT": [],
+ "START": [],
+ "SUCCESSFUL": [],
+ "SYNONYM": [],
+ "SYSDATE": [],
+ "TABLE": [],
+ "TEMPLATE": [],
+ "THEN": [],
+ "TO": [],
+ "TRIGGER": [],
+ "TRUNCATE": [],
+ "UID": [],
+ "UNION": [],
+ "UNIQUE": [],
+ "UPDATE": [],
+ "USE": [],
+ "USER": [],
+ "USING": [],
+ "VALIDATE": [],
+ "VALUES": [],
+ "VARCHAR": [],
+ "VARCHAR2": [],
+ "VIEW": [],
+ "WHEN": [],
+ "WHENEVER": [],
+ "WHERE": [],
+ "WITH": []
+ },
+ "functions": [
+ "ABBREV",
+ "ABS",
+ "AGE",
+ "AREA",
+ "ARRAY_AGG",
+ "ARRAY_APPEND",
+ "ARRAY_CAT",
+ "ARRAY_DIMS",
+ "ARRAY_FILL",
+ "ARRAY_LENGTH",
+ "ARRAY_LOWER",
+ "ARRAY_NDIMS",
+ "ARRAY_POSITION",
+ "ARRAY_POSITIONS",
+ "ARRAY_PREPEND",
+ "ARRAY_REMOVE",
+ "ARRAY_REPLACE",
+ "ARRAY_TO_STRING",
+ "ARRAY_UPPER",
+ "ASCII",
+ "AVG",
+ "BIT_AND",
+ "BIT_LENGTH",
+ "BIT_OR",
+ "BOOL_AND",
+ "BOOL_OR",
+ "BOUND_BOX",
+ "BOX",
+ "BROADCAST",
+ "BTRIM",
+ "CARDINALITY",
+ "CBRT",
+ "CEIL",
+ "CEILING",
+ "CENTER",
+ "CHAR_LENGTH",
+ "CHR",
+ "CIRCLE",
+ "CLOCK_TIMESTAMP",
+ "CONCAT",
+ "CONCAT_WS",
+ "CONVERT",
+ "CONVERT_FROM",
+ "CONVERT_TO",
+ "COUNT",
+ "CUME_DIST",
+ "CURRENT_DATE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "DATE_PART",
+ "DATE_TRUNC",
+ "DECODE",
+ "DEGREES",
+ "DENSE_RANK",
+ "DIAMETER",
+ "DIV",
+ "ENCODE",
+ "ENUM_FIRST",
+ "ENUM_LAST",
+ "ENUM_RANGE",
+ "EVERY",
+ "EXP",
+ "EXTRACT",
+ "FAMILY",
+ "FIRST_VALUE",
+ "FLOOR",
+ "FORMAT",
+ "GET_BIT",
+ "GET_BYTE",
+ "HEIGHT",
+ "HOST",
+ "HOSTMASK",
+ "INET_MERGE",
+ "INET_SAME_FAMILY",
+ "INITCAP",
+ "ISCLOSED",
+ "ISFINITE",
+ "ISOPEN",
+ "JUSTIFY_DAYS",
+ "JUSTIFY_HOURS",
+ "JUSTIFY_INTERVAL",
+ "LAG",
+ "LAST_VALUE",
+ "LEAD",
+ "LEFT",
+ "LENGTH",
+ "LINE",
+ "LN",
+ "LOCALTIME",
+ "LOCALTIMESTAMP",
+ "LOG",
+ "LOG10",
+ "LOWER",
+ "LPAD",
+ "LSEG",
+ "LTRIM",
+ "MAKE_DATE",
+ "MAKE_INTERVAL",
+ "MAKE_TIME",
+ "MAKE_TIMESTAMP",
+ "MAKE_TIMESTAMPTZ",
+ "MASKLEN",
+ "MAX",
+ "MD5",
+ "MIN",
+ "MOD",
+ "NETMASK",
+ "NETWORK",
+ "NOW",
+ "NPOINTS",
+ "NTH_VALUE",
+ "NTILE",
+ "NUM_NONNULLS",
+ "NUM_NULLS",
+ "OCTET_LENGTH",
+ "OVERLAY",
+ "PARSE_IDENT",
+ "PATH",
+ "PCLOSE",
+ "PERCENT_RANK",
+ "PG_CLIENT_ENCODING",
+ "PI",
+ "POINT",
+ "POLYGON",
+ "POPEN",
+ "POSITION",
+ "POWER",
+ "QUOTE_IDENT",
+ "QUOTE_LITERAL",
+ "QUOTE_NULLABLE",
+ "RADIANS",
+ "RADIUS",
+ "RANK",
+ "REGEXP_MATCH",
+ "REGEXP_MATCHES",
+ "REGEXP_REPLACE",
+ "REGEXP_SPLIT_TO_ARRAY",
+ "REGEXP_SPLIT_TO_TABLE",
+ "REPEAT",
+ "REPLACE",
+ "REVERSE",
+ "RIGHT",
+ "ROUND",
+ "ROW_NUMBER",
+ "RPAD",
+ "RTRIM",
+ "SCALE",
+ "SET_BIT",
+ "SET_BYTE",
+ "SET_MASKLEN",
+ "SHA224",
+ "SHA256",
+ "SHA384",
+ "SHA512",
+ "SIGN",
+ "SPLIT_PART",
+ "SQRT",
+ "STARTS_WITH",
+ "STATEMENT_TIMESTAMP",
+ "STRING_TO_ARRAY",
+ "STRPOS",
+ "SUBSTR",
+ "SUBSTRING",
+ "SUM",
+ "TEXT",
+ "TIMEOFDAY",
+ "TO_ASCII",
+ "TO_CHAR",
+ "TO_DATE",
+ "TO_HEX",
+ "TO_NUMBER",
+ "TO_TIMESTAMP",
+ "TRANSACTION_TIMESTAMP",
+ "TRANSLATE",
+ "TRIM",
+ "TRUNC",
+ "UNNEST",
+ "UPPER",
+ "WIDTH",
+ "WIDTH_BUCKET",
+ "XMLAGG"
+ ],
+ "datatypes": [
+ "ANY",
+ "ANYARRAY",
+ "ANYELEMENT",
+ "ANYENUM",
+ "ANYNONARRAY",
+ "ANYRANGE",
+ "BIGINT",
+ "BIGSERIAL",
+ "BIT",
+ "BIT VARYING",
+ "BOOL",
+ "BOOLEAN",
+ "BOX",
+ "BYTEA",
+ "CHAR",
+ "CHARACTER",
+ "CHARACTER VARYING",
+ "CIDR",
+ "CIRCLE",
+ "CSTRING",
+ "DATE",
+ "DECIMAL",
+ "DOUBLE PRECISION",
+ "EVENT_TRIGGER",
+ "FDW_HANDLER",
+ "FLOAT4",
+ "FLOAT8",
+ "INET",
+ "INT",
+ "INT2",
+ "INT4",
+ "INT8",
+ "INTEGER",
+ "INTERNAL",
+ "INTERVAL",
+ "JSON",
+ "JSONB",
+ "LANGUAGE_HANDLER",
+ "LINE",
+ "LSEG",
+ "MACADDR",
+ "MACADDR8",
+ "MONEY",
+ "NUMERIC",
+ "OID",
+ "OPAQUE",
+ "PATH",
+ "PG_LSN",
+ "POINT",
+ "POLYGON",
+ "REAL",
+ "RECORD",
+ "REGCLASS",
+ "REGCONFIG",
+ "REGDICTIONARY",
+ "REGNAMESPACE",
+ "REGOPER",
+ "REGOPERATOR",
+ "REGPROC",
+ "REGPROCEDURE",
+ "REGROLE",
+ "REGTYPE",
+ "SERIAL",
+ "SERIAL2",
+ "SERIAL4",
+ "SERIAL8",
+ "SMALLINT",
+ "SMALLSERIAL",
+ "TEXT",
+ "TIME",
+ "TIMESTAMP",
+ "TRIGGER",
+ "TSQUERY",
+ "TSVECTOR",
+ "TXID_SNAPSHOT",
+ "UUID",
+ "VARBIT",
+ "VARCHAR",
+ "VOID",
+ "XML"
+ ],
+ "reserved": [
+ "ALL",
+ "ANALYSE",
+ "ANALYZE",
+ "AND",
+ "ANY",
+ "ARRAY",
+ "AS",
+ "ASC",
+ "ASYMMETRIC",
+ "BOTH",
+ "CASE",
+ "CAST",
+ "CHECK",
+ "COLLATE",
+ "COLUMN",
+ "CONSTRAINT",
+ "CREATE",
+ "CURRENT_CATALOG",
+ "CURRENT_DATE",
+ "CURRENT_ROLE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "CURRENT_USER",
+ "DEFAULT",
+ "DEFERRABLE",
+ "DESC",
+ "DISTINCT",
+ "DO",
+ "ELSE",
+ "END",
+ "EXCEPT",
+ "FALSE",
+ "FETCH",
+ "FOR",
+ "FOREIGN",
+ "FROM",
+ "GRANT",
+ "GROUP",
+ "HAVING",
+ "IN",
+ "INITIALLY",
+ "INTERSECT",
+ "INTO",
+ "LATERAL",
+ "LEADING",
+ "LIMIT",
+ "LOCALTIME",
+ "LOCALTIMESTAMP",
+ "NOT",
+ "NULL",
+ "OFFSET",
+ "ON",
+ "ONLY",
+ "OR",
+ "ORDER",
+ "PLACING",
+ "PRIMARY",
+ "REFERENCES",
+ "RETURNING",
+ "SELECT",
+ "SESSION_USER",
+ "SOME",
+ "SYMMETRIC",
+ "TABLE",
+ "THEN",
+ "TO",
+ "TRAILING",
+ "TRUE",
+ "UNION",
+ "UNIQUE",
+ "USER",
+ "USING",
+ "VARIADIC",
+ "WHEN",
+ "WHERE",
+ "WINDOW",
+ "WITH",
+ "AUTHORIZATION",
+ "BINARY",
+ "COLLATION",
+ "CONCURRENTLY",
+ "CROSS",
+ "CURRENT_SCHEMA",
+ "FREEZE",
+ "FULL",
+ "ILIKE",
+ "INNER",
+ "IS",
+ "ISNULL",
+ "JOIN",
+ "LEFT",
+ "LIKE",
+ "NATURAL",
+ "NOTNULL",
+ "OUTER",
+ "OVERLAPS",
+ "RIGHT",
+ "SIMILAR",
+ "TABLESAMPLE",
+ "VERBOSE"
+ ]
+}
diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py
new file mode 100644
index 0000000..e92dcbb
--- /dev/null
+++ b/pgcli/packages/prioritization.py
@@ -0,0 +1,51 @@
+import re
+import sqlparse
+from sqlparse.tokens import Name
+from collections import defaultdict
+from .pgliterals.main import get_literals
+
+
+white_space_regex = re.compile("\\s+", re.MULTILINE)
+
+
+def _compile_regex(keyword):
+ # Surround the keyword with word boundaries and replace interior whitespace
+ # with whitespace wildcards
+ pattern = "\\b" + white_space_regex.sub(r"\\s+", keyword) + "\\b"
+ return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
+
+
# All SQL keywords known to pgcli, loaded from pgliterals.json.
keywords = get_literals("keywords")
# One pre-compiled whole-word regex per keyword, used by PrevalenceCounter.
keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)
+
+
class PrevalenceCounter(object):
    """Tracks how often SQL keywords and identifier names occur in queries.

    The counts are used to rank completion candidates: frequently used
    keywords and names are suggested first.
    """

    def __init__(self):
        self.keyword_counts = defaultdict(int)
        self.name_counts = defaultdict(int)

    def update(self, text):
        """Record keyword and identifier occurrences from one SQL text."""
        self.update_keywords(text)
        self.update_names(text)

    def update_names(self, text):
        """Count identifier (Name token) occurrences via sqlparse."""
        for statement in sqlparse.parse(text):
            for tok in statement.flatten():
                if tok.ttype in Name:
                    self.name_counts[tok.value] += 1

    def clear_names(self):
        """Forget all identifier counts (keyword counts are preserved)."""
        self.name_counts = defaultdict(int)

    def update_keywords(self, text):
        # sqlparse's keyword recognition is database-agnostic, so keywords
        # are counted with our own per-keyword regexes instead.
        for keyword, regex in keyword_regexs.items():
            occurrences = len(regex.findall(text))
            if occurrences:
                self.keyword_counts[keyword] += occurrences

    def keyword_count(self, keyword):
        return self.keyword_counts[keyword]

    def name_count(self, name):
        return self.name_counts[name]
diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py
new file mode 100644
index 0000000..3c58490
--- /dev/null
+++ b/pgcli/packages/prompt_utils.py
@@ -0,0 +1,35 @@
+import sys
+import click
+from .parseutils import is_destructive
+
+
def confirm_destructive_query(queries):
    """Prompt the user for confirmation when *queries* look destructive.

    Returns:
        * None if the query is non-destructive, or stdin is not a tty so
          we can't ask the user.
        * True if the query is destructive and the user wants to proceed.
        * False if the query is destructive and the user declines (or
          aborts the prompt).
    """
    if is_destructive(queries) and sys.stdin.isatty():
        return prompt(
            "You're about to run a destructive command.\n"
            "Do you want to proceed? (y/n)",
            type=bool,
        )
+
+
def confirm(*args, **kwargs):
    """Ask a yes/no question via click, treating an abort (Ctrl-C/D) as "no"."""
    try:
        answer = click.confirm(*args, **kwargs)
    except click.Abort:
        answer = False
    return answer
+
+
def prompt(*args, **kwargs):
    """Prompt the user for input via click, returning False on abort."""
    try:
        answer = click.prompt(*args, **kwargs)
    except click.Abort:
        answer = False
    return answer
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
new file mode 100644
index 0000000..6ef8859
--- /dev/null
+++ b/pgcli/packages/sqlcompletion.py
@@ -0,0 +1,608 @@
+import sys
+import re
+import sqlparse
+from collections import namedtuple
+from sqlparse.sql import Comparison, Identifier, Where
+from .parseutils.utils import last_word, find_prev_keyword, parse_partial_identifier
+from .parseutils.tables import extract_tables
+from .parseutils.ctes import isolate_query_ctes
+from pgspecial.main import parse_special_command
+
+
# Suggestion types: each namedtuple below represents one category of
# completion the completer may offer, plus the context needed to produce it.
Special = namedtuple("Special", [])
Database = namedtuple("Database", [])
Schema = namedtuple("Schema", ["quoted"])
Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple("FromClauseItem", "schema table_refs local_tables")
Table = namedtuple("Table", ["schema", "table_refs", "local_tables"])
TableFormat = namedtuple("TableFormat", [])
View = namedtuple("View", ["schema", "table_refs"])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple("JoinCondition", ["table_refs", "parent"])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple("Join", ["table_refs", "schema"])

Function = namedtuple("Function", ["schema", "table_refs", "usage"])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple(), tuple())
View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())

Column = namedtuple(
    "Column",
    ["table_refs", "require_last_table", "local_tables", "qualifiable", "context"],
)
Column.__new__.__defaults__ = (None, None, tuple(), False, None)

Keyword = namedtuple("Keyword", ["last_token"])
Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple("NamedQuery", [])
Datatype = namedtuple("Datatype", ["schema"])
Alias = namedtuple("Alias", ["aliases"])

# Filesystem-path completion, used for the \i (execute file) command.
Path = namedtuple("Path", [])
+
+
class SqlStatement(object):
    """The SQL statement currently being edited.

    Normalizes the raw editor text -- strips named-query prefixes, isolates
    CTEs, and narrows multi-statement buffers down to the statement under
    the cursor -- and exposes the helpers suggest_based_on_last_token()
    needs to decide what completions apply.
    """

    def __init__(self, full_text, text_before_cursor):
        # Partially-typed identifier under the cursor, if any (set below).
        self.identifier = None
        self.word_before_cursor = word_before_cursor = last_word(
            text_before_cursor, include="many_punctuations"
        )
        full_text = _strip_named_query(full_text)
        text_before_cursor = _strip_named_query(text_before_cursor)

        # Separate out any WITH clauses; remember the tables they define.
        full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
            full_text, text_before_cursor
        )

        self.text_before_cursor_including_last_word = text_before_cursor

        # If we've partially typed a word then word_before_cursor won't be an
        # empty string. In that case we want to remove the partially typed
        # string before sending it to the sqlparser. Otherwise the last token
        # will always be the partially typed string which renders the smart
        # completion useless because it will always return the list of
        # keywords as completion.
        if self.word_before_cursor:
            if word_before_cursor[-1] == "(" or word_before_cursor[0] == "\\":
                # An open paren or a backslash command is a meaningful last
                # token in its own right -- parse the text as-is.
                parsed = sqlparse.parse(text_before_cursor)
            else:
                text_before_cursor = text_before_cursor[: -len(word_before_cursor)]
                parsed = sqlparse.parse(text_before_cursor)
                self.identifier = parse_partial_identifier(word_before_cursor)
        else:
            parsed = sqlparse.parse(text_before_cursor)

        full_text, text_before_cursor, parsed = _split_multiple_statements(
            full_text, text_before_cursor, parsed
        )

        self.full_text = full_text
        self.text_before_cursor = text_before_cursor
        self.parsed = parsed

        # Last non-whitespace token of the statement, or "" for empty input.
        self.last_token = parsed and parsed.token_prev(len(parsed.tokens))[1] or ""

    def is_insert(self):
        """Return True if the statement is an INSERT."""
        return self.parsed.token_first().value.lower() == "insert"

    def get_tables(self, scope="full"):
        """Gets the tables available in the statement.
        param `scope:` possible values: 'full', 'insert', 'before'
        If 'insert', only the first table is returned.
        If 'before', only tables before the cursor are returned.
        If not 'insert' and the stmt is an insert, the first table is skipped.
        """
        tables = extract_tables(
            self.full_text if scope == "full" else self.text_before_cursor
        )
        if scope == "insert":
            tables = tables[:1]
        elif self.is_insert():
            # Skip the INSERT target table; it isn't selectable here.
            tables = tables[1:]
        return tables

    def get_previous_token(self, token):
        """Return the token immediately preceding *token* in the statement."""
        return self.parsed.token_prev(self.parsed.token_index(token))[1]

    def get_identifier_schema(self):
        """Return the schema qualifying the partial identifier, if any."""
        schema = (self.identifier and self.identifier.get_parent_name()) or None
        # If schema name is unquoted, lower-case it
        if schema and self.identifier.value[0] != '"':
            schema = schema.lower()

        return schema

    def reduce_to_prev_keyword(self, n_skip=0):
        """Truncate text_before_cursor at the previous keyword; return it."""
        prev_keyword, self.text_before_cursor = find_prev_keyword(
            self.text_before_cursor, n_skip=n_skip
        )
        return prev_keyword
+
+
def suggest_type(full_text, text_before_cursor):
    """Suggest what kind of entity to complete at the cursor.

    Takes the full text typed so far plus the text before the cursor and
    returns a tuple of suggestion namedtuples (Table, Column, Keyword, ...)
    describing what completions make sense, and in what scope (e.g. which
    tables a Column suggestion should draw from).
    """
    # \i takes a filesystem path, not SQL.
    if full_text.startswith("\\i "):
        return (Path(),)

    # Temporary hack: sqlparse occasionally blows up on partial input.
    # Drop this exception handling once sqlparse is fixed upstream.
    try:
        stmt = SqlStatement(full_text, text_before_cursor)
    except (TypeError, AttributeError):
        return []

    # Special (backslash) commands are handled separately. Trivial
    # whitespace parses as a statement with no first token, hence the guard.
    if stmt.parsed:
        first_token = stmt.parsed.token_first()
        if first_token is not None and first_token.value.startswith("\\"):
            return suggest_special(stmt.text_before_cursor + stmt.word_before_cursor)

    return suggest_based_on_last_token(stmt.last_token, stmt)
+
+
+named_query_regex = re.compile(r"^\s*\\ns\s+[A-z0-9\-_]+\s+")
+
+
+def _strip_named_query(txt):
+ """
+ This will strip "save named query" command in the beginning of the line:
+ '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+ ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+ """
+
+ if named_query_regex.match(txt):
+ txt = named_query_regex.sub("", txt)
+ return txt
+
+
+function_body_pattern = re.compile(r"(\$.*?\$)([\s\S]*?)\1", re.M)
+
+
+def _find_function_body(text):
+ split = function_body_pattern.search(text)
+ return (split.start(2), split.end(2)) if split else (None, None)
+
+
def _statement_from_function(full_text, text_before_cursor, statement):
    """When the cursor sits inside a dollar-quoted function body, re-parse
    just that body as the statement being edited; otherwise return the
    arguments unchanged."""
    cursor_pos = len(text_before_cursor)
    body_start, body_end = _find_function_body(full_text)
    # No complete body, or the cursor lies outside it: nothing to narrow.
    if body_start is None or not (body_start <= cursor_pos < body_end):
        return full_text, text_before_cursor, statement
    full_text = full_text[body_start:body_end]
    text_before_cursor = text_before_cursor[body_start:]
    return _split_multiple_statements(
        full_text, text_before_cursor, sqlparse.parse(text_before_cursor)
    )
+
+
def _split_multiple_statements(full_text, text_before_cursor, parsed):
    """Narrow a (possibly multi-statement) buffer to the statement under
    the cursor.

    Returns a (full_text, text_before_cursor, statement) triple where
    `statement` is the sqlparse statement containing the cursor, or None
    for empty input. A CREATE [OR REPLACE] FUNCTION statement is further
    narrowed to its dollar-quoted body when the cursor is inside it.
    """
    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds
        # the current position
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(str(statement))
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                text_before_cursor = full_text[stmt_start:current_pos]
                full_text = full_text[stmt_start:]
                break

    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        return full_text, text_before_cursor, None

    # Check for CREATE [OR REPLACE] FUNCTION: the second token (after
    # CREATE / CREATE OR REPLACE) would be the FUNCTION keyword.
    token2 = None
    if statement.get_type() in ("CREATE", "CREATE OR REPLACE"):
        token1 = statement.token_first()
        if token1:
            token1_idx = statement.token_index(token1)
            token2 = statement.token_next(token1_idx)[1]
    if token2 and token2.value.upper() == "FUNCTION":
        full_text, text_before_cursor, statement = _statement_from_function(
            full_text, text_before_cursor, statement
        )
    return full_text, text_before_cursor, statement
+
+
# Maps the tail of a backslash command (e.g. "df" for "\df") to the
# suggestion type its argument should complete to.
SPECIALS_SUGGESTION = {
    "dT": Datatype,
    "df": Function,
    "dt": Table,
    "dv": View,
    "sf": Function,
}
+
+
def suggest_special(text):
    """Return completion suggestions for a backslash (special) command."""
    text = text.lstrip()
    cmd, _, arg = parse_special_command(text)

    if cmd == text:
        # The command itself is still being typed.
        return (Special(),)

    if cmd in ("\\c", "\\connect"):
        return (Database(),)
    if cmd == "\\T":
        return (TableFormat(),)
    if cmd == "\\dn":
        return (Schema(),)

    # Try to distinguish "\d name" from "\d schema.name". Note that this
    # fails to obtain the schema if wildcards are used ("\d schema???.name").
    schema = None
    if arg:
        first_token = sqlparse.parse(arg)[0].tokens[0]
        try:
            schema = first_token.get_parent_name()
        except AttributeError:
            schema = None

    base = cmd[1:]
    if base == "d":
        # \d can describe tables or views
        if schema:
            return (Table(schema=schema), View(schema=schema))
        return (Schema(), Table(schema=None), View(schema=None))

    if base in SPECIALS_SUGGESTION:
        rel_type = SPECIALS_SUGGESTION[base]
        if rel_type == Function:
            # Functions carry a usage hint so the completer renders them
            # appropriately for special commands.
            if schema:
                return (Function(schema=schema, usage="special"),)
            return (Schema(), Function(schema=None, usage="special"))
        if schema:
            return (rel_type(schema=schema),)
        return (Schema(), rel_type(schema=None))

    if cmd in ("\\n", "\\ns", "\\nd"):
        return (NamedQuery(),)

    return (Keyword(), Special())
+
+
def suggest_based_on_last_token(token, stmt):
    """Dispatch on the last significant token before the cursor and return
    the tuple of suggestion namedtuples that apply there.

    `token` may be a plain string keyword (used for recursive calls) or an
    sqlparse token/group; `stmt` is the enclosing SqlStatement.
    """

    if isinstance(token, str):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly
        prev_keyword = stmt.reduce_to_prev_keyword()
        return suggest_based_on_last_token(prev_keyword, stmt)
    elif isinstance(token, Identifier):
        # If the previous token is an identifier, we can suggest datatypes if
        # we're in a parenthesized column/field list, e.g.:
        #   CREATE TABLE foo (Identifier <CURSOR>
        #   CREATE FUNCTION foo (Identifier <CURSOR>
        # If we're not in a parenthesized list, the most likely scenario is the
        # user is about to specify an alias, e.g.:
        #   SELECT Identifier <CURSOR>
        #   SELECT foo FROM Identifier <CURSOR>
        prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
        if prev_keyword and prev_keyword.value == "(":
            # Suggest datatypes
            return suggest_based_on_last_token("type", stmt)
        else:
            return (Keyword(),)
    else:
        token_v = token.value.lower()

    if not token:
        return (Keyword(), Special())
    elif token_v.endswith("("):
        p = sqlparse.parse(stmt.text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            #  1 - Parenthesized clause like "WHERE foo AND ("
            #      Suggest columns/functions
            #  2 - Function call like "WHERE foo("
            #      Suggest columns/functions
            #  3 - Subquery expression like "WHERE EXISTS ("
            #      Suggest keywords, in order to do a subquery
            #  4 - Subquery OR array comparison like "WHERE foo = ANY("
            #      Suggest columns/functions AND keywords. (If we wanted to be
            #      really fancy, we could suggest only array-typed columns)

            column_suggestions = suggest_based_on_last_token("where", stmt)

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            prev_tok = where.token_prev(len(where.tokens) - 1)[1]

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            prev_tok = prev_tok.value.lower()
            if prev_tok == "exists":
                return (Keyword(),)
            else:
                return column_suggestions

        # Get the token before the parens
        prev_tok = p.token_prev(len(p.tokens) - 1)[1]

        if (
            prev_tok
            and prev_tok.value
            and prev_tok.value.lower().split(" ")[-1] == "using"
        ):
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = stmt.get_tables("before")

            # suggest columns that are present in more than one table
            return (
                Column(
                    table_refs=tables,
                    require_last_table=True,
                    local_tables=stmt.local_tables,
                ),
            )

        elif p.token_first().value.lower() == "select":
            # If the lparen is preceeded by a space chances are we're about to
            # do a sub-select.
            if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("):
                return (Keyword(),)
        prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
        if prev_prev_tok and prev_prev_tok.normalized == "INTO":
            # INSERT INTO tbl ( -- suggest the target table's columns
            return (Column(table_refs=stmt.get_tables("insert"), context="insert"),)
        # We're probably in a function argument list
        return _suggest_expression(token_v, stmt)
    elif token_v == "set":
        return (Column(table_refs=stmt.get_tables(), local_tables=stmt.local_tables),)
    elif token_v in ("select", "where", "having", "order by", "distinct"):
        return _suggest_expression(token_v, stmt)
    elif token_v == "as":
        # Don't suggest anything for aliases
        return ()
    elif (token_v.endswith("join") and token.is_keyword) or (
        token_v in ("copy", "from", "update", "into", "describe", "truncate")
    ):

        schema = stmt.get_identifier_schema()
        tables = extract_tables(stmt.text_before_cursor)
        is_join = token_v.endswith("join") and token.is_keyword

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = []

        if not schema:
            # Suggest schemas
            suggest.insert(0, Schema())

        if token_v == "from" or is_join:
            suggest.append(
                FromClauseItem(
                    schema=schema, table_refs=tables, local_tables=stmt.local_tables
                )
            )
        elif token_v == "truncate":
            suggest.append(Table(schema))
        else:
            suggest.extend((Table(schema), View(schema)))

        if is_join and _allow_join(stmt.parsed):
            tables = stmt.get_tables("before")
            suggest.append(Join(table_refs=tables, schema=schema))

        return tuple(suggest)

    elif token_v == "function":
        schema = stmt.get_identifier_schema()

        # stmt.get_previous_token will fail for e.g. `SELECT 1 FROM functions WHERE function:`
        try:
            prev = stmt.get_previous_token(token).value.lower()
            if prev in ("drop", "alter", "create", "create or replace"):

                # Suggest functions from either the currently-selected schema or the
                # public schema if no schema has been specified
                suggest = []

                if not schema:
                    # Suggest schemas
                    suggest.insert(0, Schema())

                suggest.append(Function(schema=schema, usage="signature"))
                return tuple(suggest)

        except ValueError:
            pass
        return tuple()

    elif token_v in ("table", "view"):
        # E.g. 'ALTER TABLE <tablname>'
        rel_type = {"table": Table, "view": View, "function": Function}[token_v]
        schema = stmt.get_identifier_schema()
        if schema:
            return (rel_type(schema=schema),)
        else:
            return (Schema(), rel_type(schema=schema))

    elif token_v == "column":
        # E.g. 'ALTER TABLE foo ALTER COLUMN bar
        return (Column(table_refs=stmt.get_tables()),)

    elif token_v == "on":
        tables = stmt.get_tables("before")
        parent = (stmt.identifier and stmt.identifier.get_parent_name()) or None
        if parent:
            # "ON parent.<suggestion>"
            # parent can be either a schema name or table alias
            filteredtables = tuple(t for t in tables if identifies(parent, t))
            sugs = [
                Column(table_refs=filteredtables, local_tables=stmt.local_tables),
                Table(schema=parent),
                View(schema=parent),
                Function(schema=parent),
            ]
            if filteredtables and _allow_join_condition(stmt.parsed):
                sugs.append(JoinCondition(table_refs=tables, parent=filteredtables[-1]))
            return tuple(sugs)
        else:
            # ON <suggestion>
            # Use table alias if there is one, otherwise the table name
            aliases = tuple(t.ref for t in tables)
            if _allow_join_condition(stmt.parsed):
                return (
                    Alias(aliases=aliases),
                    JoinCondition(table_refs=tables, parent=None),
                )
            else:
                return (Alias(aliases=aliases),)

    elif token_v in ("c", "use", "database", "template"):
        # "\c <db", "use <db>", "DROP DATABASE <db>",
        # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
        return (Database(),)
    elif token_v == "schema":
        # DROP SCHEMA schema_name, SET SCHEMA schema name
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
        quoted = prev_keyword and prev_keyword.value.lower() == "set"
        return (Schema(quoted),)
    elif token_v.endswith(",") or token_v in ("=", "and", "or"):
        # Mid-expression -- recurse on the keyword that opened the clause.
        prev_keyword = stmt.reduce_to_prev_keyword()
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        else:
            return ()
    elif token_v in ("type", "::"):
        # ALTER TABLE foo SET DATA TYPE bar
        # SELECT foo::bar
        # Note that tables are a form of composite type in postgresql, so
        # they're suggested here as well
        schema = stmt.get_identifier_schema()
        suggestions = [Datatype(schema=schema), Table(schema=schema)]
        if not schema:
            suggestions.append(Schema())
        return tuple(suggestions)
    elif token_v in {"alter", "create", "drop"}:
        return (Keyword(token_v.upper()),)
    elif token.is_keyword:
        # token is a keyword we haven't implemented any special handling for
        # go backwards in the query until we find one we do recognize
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        else:
            return (Keyword(token_v.upper()),)
    else:
        return (Keyword(),)
+
+
def _suggest_expression(token_v, stmt):
    """
    Return suggestions for an expression, taking account of any partially-typed
    identifier's parent, which may be a table alias or schema name.
    """
    parent = stmt.identifier.get_parent_name() if stmt.identifier else []
    tables = stmt.get_tables()

    if not parent:
        # Unqualified: columns from every table in scope, plus functions
        # and keywords.
        return (
            Column(table_refs=tables, local_tables=stmt.local_tables, qualifiable=True),
            Function(schema=None),
            Keyword(token_v.upper()),
        )

    # Qualified ("parent.<cursor>") -- parent is a table alias or schema.
    tables = tuple(t for t in tables if identifies(parent, t))
    return (
        Column(table_refs=tables, local_tables=stmt.local_tables),
        Table(schema=parent),
        View(schema=parent),
        Function(schema=parent),
    )
+
+
def identifies(id, ref):
    """Returns true if string `id` matches TableReference `ref`.

    `id` may be the reference's alias, its bare name, or its
    schema-qualified name ("schema.name").
    """
    if id == ref.alias or id == ref.name:
        return True
    return ref.schema and id == "{0}.{1}".format(ref.schema, ref.name)
+
+
+def _allow_join_condition(statement):
+ """
+ Tests if a join condition should be suggested
+
+ We need this to avoid bad suggestions when entering e.g.
+ select * from tbl1 a join tbl2 b on a.id = <cursor>
+ So check that the preceding token is a ON, AND, or OR keyword, instead of
+ e.g. an equals sign.
+
+ :param statement: an sqlparse.sql.Statement
+ :return: boolean
+ """
+
+ if not statement or not statement.tokens:
+ return False
+
+ last_tok = statement.token_prev(len(statement.tokens))[1]
+ return last_tok.value.lower() in ("on", "and", "or")
+
+
+def _allow_join(statement):
+ """
+ Tests if a join should be suggested
+
+ We need this to avoid bad suggestions when entering e.g.
+ select * from tbl1 a join tbl2 b <cursor>
+ So check that the preceding token is a JOIN keyword
+
+ :param statement: an sqlparse.sql.Statement
+ :return: boolean
+ """
+
+ if not statement or not statement.tokens:
+ return False
+
+ last_tok = statement.token_prev(len(statement.tokens))[1]
+ return last_tok.value.lower().endswith("join") and last_tok.value.lower() not in (
+ "cross join",
+ "natural join",
+ )