Diffstat (limited to 'pgcli/packages/parseutils')
-rw-r--r--   pgcli/packages/parseutils/__init__.py    22
-rw-r--r--   pgcli/packages/parseutils/ctes.py        141
-rw-r--r--   pgcli/packages/parseutils/meta.py        170
-rw-r--r--   pgcli/packages/parseutils/tables.py      170
-rw-r--r--   pgcli/packages/parseutils/utils.py       140
5 files changed, 643 insertions, 0 deletions
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
new file mode 100644
index 0000000..a11e7bf
--- /dev/null
+++ b/pgcli/packages/parseutils/__init__.py
@@ -0,0 +1,22 @@
+import sqlparse
+
+
+def query_starts_with(query, prefixes):
+    """Check if the query starts with any item from *prefixes*."""
+    prefixes = [prefix.lower() for prefix in prefixes]
+    formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
+    return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
+
+
+def queries_start_with(queries, prefixes):
+    """Check if any of the queries start with any item from *prefixes*."""
+    for query in sqlparse.split(queries):
+        if query and query_starts_with(query, prefixes):
+            return True
+    return False
+
+
+def is_destructive(queries):
+    """Return True if any of the queries in *queries* is destructive."""
+    keywords = ("drop", "shutdown", "delete", "truncate", "alter")
+    return queries_start_with(queries, keywords)
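+
+
+# A hedged usage sketch (illustrative, not part of this module): comments are
+# stripped before the prefix check, and sqlparse.split() handles
+# multi-statement input, so one would expect roughly:
+#
+#     >>> query_starts_with("-- cleanup\nDROP TABLE foo;", ("drop",))
+#     True
+#     >>> is_destructive("select 1; truncate table logs")
+#     True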
diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py
new file mode 100644
index 0000000..e1f9088
--- /dev/null
+++ b/pgcli/packages/parseutils/ctes.py
@@ -0,0 +1,141 @@
+from sqlparse import parse
+from sqlparse.tokens import Keyword, CTE, DML
+from sqlparse.sql import Identifier, IdentifierList, Parenthesis
+from collections import namedtuple
+from .meta import TableMetadata, ColumnMetadata
+
+
+# TableExpression is a namedtuple representing a CTE, used internally
+# name: cte alias assigned in the query
+# columns: list of column names
+# start: index into the original string of the left paren starting the CTE
+# stop: index into the original string of the right paren ending the CTE
+TableExpression = namedtuple("TableExpression", "name columns start stop")
+
+
+def isolate_query_ctes(full_text, text_before_cursor):
+    """Simplify a query by converting CTEs into table metadata objects"""
+
+    if not full_text or not full_text.strip():
+        return full_text, text_before_cursor, ()
+
+    ctes, remainder = extract_ctes(full_text)
+    if not ctes:
+        return full_text, text_before_cursor, ()
+
+    current_position = len(text_before_cursor)
+    meta = []
+
+    for cte in ctes:
+        if cte.start < current_position < cte.stop:
+            # Currently editing a cte - treat its body as the current full_text
+            text_before_cursor = full_text[cte.start : current_position]
+            full_text = full_text[cte.start : cte.stop]
+            return full_text, text_before_cursor, meta
+
+        # Append this cte to the list of available table metadata
+        cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
+        meta.append(TableMetadata(cte.name, cols))
+
+    # Editing past the last cte (i.e. the main body of the query)
+    full_text = full_text[ctes[-1].stop :]
+    text_before_cursor = text_before_cursor[ctes[-1].stop : current_position]
+
+    return full_text, text_before_cursor, tuple(meta)
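+
+
+# A hedged sketch of the intended behavior (values reasoned through, not
+# captured output): with the cursor at the end of the text, the CTE is folded
+# into table metadata and the remainder becomes the new "query":
+#
+#     full_text = "WITH x AS (SELECT 1 AS a) SELECT * FROM "
+#     isolate_query_ctes(full_text, full_text)
+#     # => (' SELECT * FROM ', ' SELECT * FROM ', (TableMetadata(name='x', ...),))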
+
+
+def extract_ctes(sql):
+    """Extract common table expressions from a query
+
+    Returns tuple (ctes, remainder_sql)
+
+    ctes is a list of TableExpression namedtuples
+    remainder_sql is the text from the original query after the CTEs have
+    been stripped.
+    """
+
+    p = parse(sql)[0]
+
+    # Make sure the first meaningful token is "WITH", which is necessary to
+    # define CTEs
+    idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
+    if not (tok and tok.ttype == CTE):
+        return [], sql
+
+    # Get the next (meaningful) token, which should be the first CTE
+    idx, tok = p.token_next(idx)
+    if not tok:
+        return [], ""
+    start_pos = token_start_pos(p.tokens, idx)
+    ctes = []
+
+    if isinstance(tok, IdentifierList):
+        # Multiple ctes
+        for t in tok.get_identifiers():
+            cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
+            cte = get_cte_from_token(t, start_pos + cte_start_offset)
+            if not cte:
+                continue
+            ctes.append(cte)
+    elif isinstance(tok, Identifier):
+        # A single CTE
+        cte = get_cte_from_token(tok, start_pos)
+        if cte:
+            ctes.append(cte)
+
+    idx = p.token_index(tok) + 1
+
+    # Collapse everything after the ctes into a remainder query
+    remainder = "".join(str(tok) for tok in p.tokens[idx:])
+
+    return ctes, remainder
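+
+
+# A hedged example of the expected result (start/stop are offsets of the CTE
+# body's parentheses in this exact string):
+#
+#     extract_ctes("WITH x AS (SELECT 1 AS a) SELECT * FROM x")
+#     # => ([TableExpression(name='x', columns=('a',), start=10, stop=25)],
+#     #     ' SELECT * FROM x')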
+
+
+def get_cte_from_token(tok, pos0):
+    cte_name = tok.get_real_name()
+    if not cte_name:
+        return None
+
+    # Find the start position of the opening parens enclosing the cte body
+    idx, parens = tok.token_next_by(Parenthesis)
+    if not parens:
+        return None
+
+    start_pos = pos0 + token_start_pos(tok.tokens, idx)
+    cte_len = len(str(parens))  # includes parens
+    stop_pos = start_pos + cte_len
+
+    column_names = extract_column_names(parens)
+
+    return TableExpression(cte_name, column_names, start_pos, stop_pos)
+
+
+def extract_column_names(parsed):
+    # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
+    idx, tok = parsed.token_next_by(t=DML)
+    tok_val = tok and tok.value.lower()
+
+    if tok_val in ("insert", "update", "delete"):
+        # Jump ahead to the RETURNING clause where the list of column names is
+        idx, tok = parsed.token_next_by(idx=idx, m=(Keyword, "returning"))
+    elif tok_val != "select":
+        # Must be invalid CTE
+        return ()
+
+    # The next token should be either a column name, or a list of column names
+    idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
+    return tuple(t.get_name() for t in _identifiers(tok))
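+
+
+# A hedged illustration: both of these CTE bodies should yield the column
+# tuple ('x', 'y'), the first via SELECT output names and the second via its
+# RETURNING clause:
+#
+#     (SELECT 1 AS x, 2 AS y)
+#     (INSERT INTO t (a, b) VALUES (1, 2) RETURNING a AS x, b AS y)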
+
+
+def token_start_pos(tokens, idx):
+    return sum(len(str(t)) for t in tokens[:idx])
+
+
+def _identifiers(tok):
+    if isinstance(tok, IdentifierList):
+        for t in tok.get_identifiers():
+            # NB: IdentifierList.get_identifiers() can return non-identifiers!
+            if isinstance(t, Identifier):
+                yield t
+    elif isinstance(tok, Identifier):
+        yield tok
diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py
new file mode 100644
index 0000000..108c01a
--- /dev/null
+++ b/pgcli/packages/parseutils/meta.py
@@ -0,0 +1,170 @@
+from collections import namedtuple
+
+_ColumnMetadata = namedtuple(
+    "ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]
+)
+
+
+def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
+    return _ColumnMetadata(name, datatype, foreignkeys or [], default, has_default)
+
+
+ForeignKey = namedtuple(
+    "ForeignKey",
+    [
+        "parentschema",
+        "parenttable",
+        "parentcolumn",
+        "childschema",
+        "childtable",
+        "childcolumn",
+    ],
+)
+TableMetadata = namedtuple("TableMetadata", "name columns")
+
+
+def parse_defaults(defaults_string):
+    """Yields default values for a function, given the string provided by
+    pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
+    if not defaults_string:
+        return
+    current = ""
+    in_quote = None
+    for char in defaults_string:
+        if current == "" and char == " ":
+            # Skip space after comma separating default expressions
+            continue
+        if char == '"' or char == "'":
+            if in_quote and char == in_quote:
+                # End quote
+                in_quote = None
+            elif not in_quote:
+                # Begin quote
+                in_quote = char
+        elif char == "," and not in_quote:
+            # End of expression
+            yield current
+            current = ""
+            continue
+        current += char
+    yield current
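+
+
+# A hedged illustration: commas inside quotes do not split expressions, so one
+# would expect roughly:
+#
+#     >>> list(parse_defaults("1, 'a,b', NULL::text"))
+#     ['1', "'a,b'", 'NULL::text']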
+
+
+class FunctionMetadata(object):
+    def __init__(
+        self,
+        schema_name,
+        func_name,
+        arg_names,
+        arg_types,
+        arg_modes,
+        return_type,
+        is_aggregate,
+        is_window,
+        is_set_returning,
+        is_extension,
+        arg_defaults,
+    ):
+        """Class for describing a postgresql function"""
+
+        self.schema_name = schema_name
+        self.func_name = func_name
+
+        self.arg_modes = tuple(arg_modes) if arg_modes else None
+        self.arg_names = tuple(arg_names) if arg_names else None
+
+        # Be flexible in not requiring arg_types -- use None as a placeholder
+        # for each arg. (Used for compatibility with old versions of postgresql
+        # where such info is hard to get.)
+        if arg_types:
+            self.arg_types = tuple(arg_types)
+        elif arg_modes:
+            self.arg_types = tuple([None] * len(arg_modes))
+        elif arg_names:
+            self.arg_types = tuple([None] * len(arg_names))
+        else:
+            self.arg_types = None
+
+        self.arg_defaults = tuple(parse_defaults(arg_defaults))
+
+        self.return_type = return_type.strip()
+        self.is_aggregate = is_aggregate
+        self.is_window = is_window
+        self.is_set_returning = is_set_returning
+        self.is_extension = bool(is_extension)
+        self.is_public = self.schema_name and self.schema_name == "public"
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def _signature(self):
+        return (
+            self.schema_name,
+            self.func_name,
+            self.arg_names,
+            self.arg_types,
+            self.arg_modes,
+            self.return_type,
+            self.is_aggregate,
+            self.is_window,
+            self.is_set_returning,
+            self.is_extension,
+            self.arg_defaults,
+        )
+
+    def __hash__(self):
+        return hash(self._signature())
+
+    def __repr__(self):
+        return (
+            "%s(schema_name=%r, func_name=%r, arg_names=%r, "
+            "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
+            "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)"
+        ) % ((self.__class__.__name__,) + self._signature())
+
+    def has_variadic(self):
+        return self.arg_modes and any(arg_mode == "v" for arg_mode in self.arg_modes)
+
+    def args(self):
+        """Returns a list of input-parameter ColumnMetadata namedtuples."""
+        if not self.arg_names:
+            return []
+        modes = self.arg_modes or ["i"] * len(self.arg_names)
+        args = [
+            (name, typ)
+            for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
+            if mode in ("i", "b", "v")  # IN, INOUT, VARIADIC
+        ]
+
+        def arg(name, typ, num):
+            # Postgres aligns defaults with the last input args: given
+            # num_args args and num_defaults defaults, only the final
+            # num_defaults args have a default value.
+            num_args = len(args)
+            num_defaults = len(self.arg_defaults)
+            has_default = num + num_defaults >= num_args
+            default = (
+                self.arg_defaults[num - num_args + num_defaults]
+                if has_default
+                else None
+            )
+            return ColumnMetadata(name, typ, [], default, has_default)
+
+        return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]
+
+    def fields(self):
+        """Returns a list of output-field ColumnMetadata namedtuples"""
+
+        if self.return_type.lower() == "void":
+            return []
+        elif not self.arg_modes:
+            # For functions without output parameters, the function name
+            # is used as the name of the output column.
+            # E.g. 'SELECT unnest FROM unnest(...);'
+            return [ColumnMetadata(self.func_name, self.return_type, [])]
+
+        return [
+            ColumnMetadata(name, typ, [])
+            for name, typ, mode in zip(self.arg_names, self.arg_types, self.arg_modes)
+            if mode in ("o", "b", "t")  # OUT, INOUT, TABLE
+        ]
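+
+
+# A hedged sketch (hypothetical function "concat_def"; values invented for
+# illustration): with two input args and one parsed default, only the last
+# arg should report has_default:
+#
+#     fm = FunctionMetadata(
+#         "public", "concat_def", ["a", "b"], ["text", "text"], None,
+#         "text", False, False, False, False, "'x'::text",
+#     )
+#     [(c.name, c.has_default, c.default) for c in fm.args()]
+#     # => [('a', False, None), ('b', True, "'x'::text")]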
diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py
new file mode 100644
index 0000000..0ec3e69
--- /dev/null
+++ b/pgcli/packages/parseutils/tables.py
@@ -0,0 +1,170 @@
+import sqlparse
+from collections import namedtuple
+from sqlparse.sql import IdentifierList, Identifier, Function
+from sqlparse.tokens import Keyword, DML, Punctuation
+
+TableReference = namedtuple(
+    "TableReference", ["schema", "name", "alias", "is_function"]
+)
+TableReference.ref = property(
+    lambda self: self.alias
+    or (
+        self.name
+        if self.name.islower() or self.name[0] == '"'
+        else '"' + self.name + '"'
+    )
+)
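+
+
+# Hedged illustration of the `ref` property: an explicit alias wins; otherwise
+# mixed-case names are re-quoted so the reference stays valid in Postgres:
+#
+#     TableReference(None, "foo", "f", False).ref   # => 'f'
+#     TableReference(None, "Foo", None, False).ref  # => '"Foo"'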
+
+
+# This code is borrowed from sqlparse example script.
+# <url>
+def is_subselect(parsed):
+    if not parsed.is_group:
+        return False
+    for item in parsed.tokens:
+        if item.ttype is DML and item.value.upper() in (
+            "SELECT",
+            "INSERT",
+            "UPDATE",
+            "CREATE",
+            "DELETE",
+        ):
+            return True
+    return False
+
+
+def _identifier_is_function(identifier):
+    return any(isinstance(t, Function) for t in identifier.tokens)
+
+
+def extract_from_part(parsed, stop_at_punctuation=True):
+    tbl_prefix_seen = False
+    for item in parsed.tokens:
+        if tbl_prefix_seen:
+            if is_subselect(item):
+                for x in extract_from_part(item, stop_at_punctuation):
+                    yield x
+            elif stop_at_punctuation and item.ttype is Punctuation:
+                return
+            # An incomplete nested select won't be recognized correctly as a
+            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
+            # the second FROM to trigger this elif condition resulting in a
+            # `return`. So we need to ignore the keyword if it is FROM.
+            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
+            # condition. So we need to ignore the keyword JOIN and its variants
+            # INNER JOIN, FULL OUTER JOIN, etc.
+            elif (
+                item.ttype is Keyword
+                and (not item.value.upper() == "FROM")
+                and (not item.value.upper().endswith("JOIN"))
+            ):
+                tbl_prefix_seen = False
+            else:
+                yield item
+        elif item.ttype is Keyword or item.ttype is Keyword.DML:
+            item_val = item.value.upper()
+            if (
+                item_val
+                in (
+                    "COPY",
+                    "FROM",
+                    "INTO",
+                    "UPDATE",
+                    "TABLE",
+                )
+                or item_val.endswith("JOIN")
+            ):
+                tbl_prefix_seen = True
+        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
+        # So this check here is necessary.
+        elif isinstance(item, IdentifierList):
+            for identifier in item.get_identifiers():
+                if identifier.ttype is Keyword and identifier.value.upper() == "FROM":
+                    tbl_prefix_seen = True
+                    break
+
+
+def extract_table_identifiers(token_stream, allow_functions=True):
+    """yields tuples of TableReference namedtuples"""
+
+    # We need to do some massaging of the names because postgres is case-
+    # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
+    def parse_identifier(item):
+        name = item.get_real_name()
+        schema_name = item.get_parent_name()
+        alias = item.get_alias()
+        if not name:
+            schema_name = None
+            name = item.get_name()
+            alias = alias or name
+        schema_quoted = schema_name and item.value[0] == '"'
+        if schema_name and not schema_quoted:
+            schema_name = schema_name.lower()
+        quote_count = item.value.count('"')
+        name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
+        alias_quoted = alias and item.value[-1] == '"'
+        if alias_quoted or name_quoted and not alias and name.islower():
+            alias = '"' + (alias or name) + '"'
+        if name and not name_quoted and not name.islower():
+            if not alias:
+                alias = name
+            name = name.lower()
+        return schema_name, name, alias
+
+    try:
+        for item in token_stream:
+            if isinstance(item, IdentifierList):
+                for identifier in item.get_identifiers():
+                    # Sometimes Keywords (such as FROM ) are classified as
+                    # identifiers which don't have the get_real_name() method.
+                    try:
+                        schema_name = identifier.get_parent_name()
+                        real_name = identifier.get_real_name()
+                        is_function = allow_functions and _identifier_is_function(
+                            identifier
+                        )
+                    except AttributeError:
+                        continue
+                    if real_name:
+                        yield TableReference(
+                            schema_name, real_name, identifier.get_alias(), is_function
+                        )
+            elif isinstance(item, Identifier):
+                schema_name, real_name, alias = parse_identifier(item)
+                is_function = allow_functions and _identifier_is_function(item)
+
+                yield TableReference(schema_name, real_name, alias, is_function)
+            elif isinstance(item, Function):
+                schema_name, real_name, alias = parse_identifier(item)
+                yield TableReference(None, real_name, alias, allow_functions)
+    except StopIteration:
+        return
+
+
+# extract_tables is inspired from examples in the sqlparse lib.
+def extract_tables(sql):
+    """Extract the table names from an SQL statement.
+
+    Returns a list of TableReference namedtuples
+
+    """
+    parsed = sqlparse.parse(sql)
+    if not parsed:
+        return ()
+
+    # INSERT statements must stop looking for tables at the sign of first
+    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
+    # abc is the table name, but if we don't stop at the first lparen, then
+    # we'll identify abc, col1 and col2 as table names.
+    insert_stmt = parsed[0].token_first().value.lower() == "insert"
+    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+
+    # Kludge: sqlparse mistakenly identifies insert statements as
+    # function calls due to the parenthesized column list, e.g. interprets
+    # "insert into foo (bar, baz)" as a function call to foo with arguments
+    # (bar, baz). So don't allow any identifiers in insert statements
+    # to have is_function=True
+    identifiers = extract_table_identifiers(stream, allow_functions=not insert_stmt)
+    # In the case 'sche.<cursor>', we get an empty TableReference; remove that
+    return tuple(i for i in identifiers if i.name)
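+
+
+# A hedged end-to-end sketch (expected values, not captured output):
+#
+#     extract_tables("select * from abc a join xyz x on a.id = x.id")
+#     # => (TableReference(schema=None, name='abc', alias='a', is_function=False),
+#     #     TableReference(schema=None, name='xyz', alias='x', is_function=False))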
diff --git a/pgcli/packages/parseutils/utils.py b/pgcli/packages/parseutils/utils.py
new file mode 100644
index 0000000..034c96e
--- /dev/null
+++ b/pgcli/packages/parseutils/utils.py
@@ -0,0 +1,140 @@
+import re
+import sqlparse
+from sqlparse.sql import Identifier
+from sqlparse.tokens import Token, Error
+
+cleanup_regex = {
+    # This matches only alphanumerics and underscores.
+    "alphanum_underscore": re.compile(r"(\w+)$"),
+    # This matches everything except spaces, parens, colon, and comma
+    "many_punctuations": re.compile(r"([^():,\s]+)$"),
+    # This matches everything except spaces, parens, colon, comma, and period
+    "most_punctuations": re.compile(r"([^\.():,\s]+)$"),
+    # This matches everything except a space.
+    "all_punctuations": re.compile(r"([^\s]+)$"),
+}
+
+
+def last_word(text, include="alphanum_underscore"):
+    r"""
+    Find the last word in a sentence.
+
+    >>> last_word('abc')
+    'abc'
+    >>> last_word(' abc')
+    'abc'
+    >>> last_word('')
+    ''
+    >>> last_word(' ')
+    ''
+    >>> last_word('abc ')
+    ''
+    >>> last_word('abc def')
+    'def'
+    >>> last_word('abc def ')
+    ''
+    >>> last_word('abc def;')
+    ''
+    >>> last_word('bac $def')
+    'def'
+    >>> last_word('bac $def', include='most_punctuations')
+    '$def'
+    >>> last_word('bac \def', include='most_punctuations')
+    '\\def'
+    >>> last_word('bac \def;', include='most_punctuations')
+    '\\def;'
+    >>> last_word('bac::def', include='most_punctuations')
+    'def'
+    >>> last_word('"foo*bar', include='most_punctuations')
+    '"foo*bar'
+    """
+
+    if not text:  # Empty string
+        return ""
+
+    if text[-1].isspace():
+        return ""
+    else:
+        regex = cleanup_regex[include]
+        matches = regex.search(text)
+        if matches:
+            return matches.group(0)
+        else:
+            return ""
+
+
+def find_prev_keyword(sql, n_skip=0):
+    """Find the last sql keyword in an SQL statement
+
+    Returns the last keyword token, and the text of the query with
+    everything after the last keyword stripped
+    """
+    if not sql.strip():
+        return None, ""
+
+    parsed = sqlparse.parse(sql)[0]
+    flattened = list(parsed.flatten())
+    flattened = flattened[: len(flattened) - n_skip]
+
+    logical_operators = ("AND", "OR", "NOT", "BETWEEN")
+
+    for t in reversed(flattened):
+        if t.value == "(" or (
+            t.is_keyword and (t.value.upper() not in logical_operators)
+        ):
+            # Find the location of token t in the original parsed statement
+            # We can't use parsed.token_index(t) because t may be a child token
+            # inside a TokenList, in which case token_index throws an error
+            # Minimal example:
+            #   p = sqlparse.parse('select * from foo where bar')[0]
+            #   t = list(p.flatten())[-3]  # The "Where" token
+            #   p.token_index(t)  # Throws ValueError: not in list
+            idx = flattened.index(t)
+
+            # Combine the string values of all tokens in the original list
+            # up to and including the target keyword token t, to produce a
+            # query string with everything after the keyword token removed
+            text = "".join(tok.value for tok in flattened[: idx + 1])
+            return t, text
+
+    return None, ""
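+
+
+# A hedged illustration (logical operators are skipped on purpose so that,
+# e.g., 'WHERE x > 1 AND ' still completes column names):
+#
+#     tok, text = find_prev_keyword("select * from foo where bar = 1")
+#     (tok.value, text)
+#     # => ('where', 'select * from foo where')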
+
+
+# Postgresql dollar quote signs look like `$$` or `$tag$`
+dollar_quote_regex = re.compile(r"^\$[^$]*\$$")
+
+
+def is_open_quote(sql):
+    """Returns true if the query contains an unclosed quote"""
+
+    # parsed can contain one or more semi-colon separated commands
+    parsed = sqlparse.parse(sql)
+    return any(_parsed_is_open_quote(p) for p in parsed)
+
+
+def _parsed_is_open_quote(parsed):
+    # Look for unmatched single quotes, or unmatched dollar sign quotes
+    return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten())
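+
+
+# Hedged examples of the intent (sqlparse emits Token.Error for an unmatched
+# quote character):
+#
+#     is_open_quote("select 'abc")   # => True  (unterminated string literal)
+#     is_open_quote("select 'abc'")  # => False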
+
+
+def parse_partial_identifier(word):
+    """Attempt to parse a (partially typed) word as an identifier
+
+    word may include a schema qualification, like `schema_name.partial_name`
+    or `schema_name.` There may also be unclosed quotation marks, like
+    `"schema`, or `schema."partial_name`
+
+    :param word: string representing a (partially complete) identifier
+    :return: sqlparse.sql.Identifier, or None
+    """
+
+    p = sqlparse.parse(word)[0]
+    n_tok = len(p.tokens)
+    if n_tok == 1 and isinstance(p.tokens[0], Identifier):
+        return p.tokens[0]
+    elif p.token_next_by(m=(Error, '"'))[1]:
+        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
+        # Close the double quote, then reparse
+        return parse_partial_identifier(word + '"')
+    else:
+        return None
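+
+
+# A hedged sketch of the expected behavior: an unclosed double quote is closed
+# and the word re-parsed, so both calls should return an Identifier node:
+#
+#     parse_partial_identifier('schema_name.partial_name')
+#     parse_partial_identifier('schema."partial')  # reparsed as 'schema."partial"'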