summaryrefslogtreecommitdiffstats
path: root/pgcli/packages/parseutils
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/packages/parseutils')
-rw-r--r--pgcli/packages/parseutils/__init__.py32
-rw-r--r--pgcli/packages/parseutils/meta.py2
-rw-r--r--pgcli/packages/parseutils/tables.py3
3 files changed, 24 insertions, 13 deletions
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
index a11e7bf..1acc008 100644
--- a/pgcli/packages/parseutils/__init__.py
+++ b/pgcli/packages/parseutils/__init__.py
@@ -1,22 +1,34 @@
import sqlparse
-def query_starts_with(query, prefixes):
+def query_starts_with(formatted_sql, prefixes):
"""Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes]
- formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
-def queries_start_with(queries, prefixes):
- """Check if any queries start with any item from *prefixes*."""
- for query in sqlparse.split(queries):
- if query and query_starts_with(query, prefixes) is True:
- return True
- return False
+def query_is_unconditional_update(formatted_sql):
+ """Check if the query starts with UPDATE and contains no WHERE."""
+ tokens = formatted_sql.split()
+ return bool(tokens) and tokens[0] == "update" and "where" not in tokens
+
+def query_is_simple_update(formatted_sql):
+ """Check if the query starts with UPDATE."""
+ tokens = formatted_sql.split()
+ return bool(tokens) and tokens[0] == "update"
-def is_destructive(queries):
+
+def is_destructive(queries, warning_level="all"):
"""Returns if any of the queries in *queries* is destructive."""
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
- return queries_start_with(queries, keywords)
+ for query in sqlparse.split(queries):
+ if query:
+ formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
+ if query_starts_with(formatted_sql, keywords):
+ return True
+ if query_is_unconditional_update(formatted_sql):
+ return True
+ if warning_level == "all" and query_is_simple_update(formatted_sql):
+ return True
+ return False
diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py
index 108c01a..333cab5 100644
--- a/pgcli/packages/parseutils/meta.py
+++ b/pgcli/packages/parseutils/meta.py
@@ -50,7 +50,7 @@ def parse_defaults(defaults_string):
yield current
-class FunctionMetadata(object):
+class FunctionMetadata:
def __init__(
self,
schema_name,
diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py
index 0ec3e69..aaa676c 100644
--- a/pgcli/packages/parseutils/tables.py
+++ b/pgcli/packages/parseutils/tables.py
@@ -42,8 +42,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
for item in parsed.tokens:
if tbl_prefix_seen:
if is_subselect(item):
- for x in extract_from_part(item, stop_at_punctuation):
- yield x
+ yield from extract_from_part(item, stop_at_punctuation)
elif stop_at_punctuation and item.ttype is Punctuation:
return
# An incomplete nested select won't be recognized correctly as a