diff options
Diffstat (limited to 'pgcli/packages/parseutils')
-rw-r--r-- | pgcli/packages/parseutils/__init__.py | 32 | ||||
-rw-r--r-- | pgcli/packages/parseutils/meta.py | 2 | ||||
-rw-r--r-- | pgcli/packages/parseutils/tables.py | 3 |
3 files changed, 24 insertions, 13 deletions
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py index a11e7bf..1acc008 100644 --- a/pgcli/packages/parseutils/__init__.py +++ b/pgcli/packages/parseutils/__init__.py @@ -1,22 +1,34 @@ import sqlparse -def query_starts_with(query, prefixes): +def query_starts_with(formatted_sql, prefixes): """Check if the query starts with any item from *prefixes*.""" prefixes = [prefix.lower() for prefix in prefixes] - formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() return bool(formatted_sql) and formatted_sql.split()[0] in prefixes -def queries_start_with(queries, prefixes): - """Check if any queries start with any item from *prefixes*.""" - for query in sqlparse.split(queries): - if query and query_starts_with(query, prefixes) is True: - return True - return False +def query_is_unconditional_update(formatted_sql): + """Check if the query starts with UPDATE and contains no WHERE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" and "where" not in tokens + +def query_is_simple_update(formatted_sql): + """Check if the query starts with UPDATE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" -def is_destructive(queries): + +def is_destructive(queries, warning_level="all"): """Returns if any of the queries in *queries* is destructive.""" keywords = ("drop", "shutdown", "delete", "truncate", "alter") - return queries_start_with(queries, keywords) + for query in sqlparse.split(queries): + if query: + formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() + if query_starts_with(formatted_sql, keywords): + return True + if query_is_unconditional_update(formatted_sql): + return True + if warning_level == "all" and query_is_simple_update(formatted_sql): + return True + return False diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py index 108c01a..333cab5 100644 --- a/pgcli/packages/parseutils/meta.py +++ b/pgcli/packages/parseutils/meta.py @@ -50,7 +50,7 @@ def parse_defaults(defaults_string): yield current -class FunctionMetadata(object): +class FunctionMetadata: def __init__( self, schema_name, diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py index 0ec3e69..aaa676c 100644 --- a/pgcli/packages/parseutils/tables.py +++ b/pgcli/packages/parseutils/tables.py @@ -42,8 +42,7 @@ def extract_from_part(parsed, stop_at_punctuation=True): for item in parsed.tokens: if tbl_prefix_seen: if is_subselect(item): - for x in extract_from_part(item, stop_at_punctuation): - yield x + yield from extract_from_part(item, stop_at_punctuation) elif stop_at_punctuation and item.ttype is Punctuation: return # An incomplete nested select won't be recognized correctly as a |