summaryrefslogtreecommitdiffstats
path: root/pgcli/packages/parseutils/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/packages/parseutils/__init__.py')
-rw-r--r--pgcli/packages/parseutils/__init__.py48
1 files changed, 36 insertions, 12 deletions
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
index 1acc008..023e13b 100644
--- a/pgcli/packages/parseutils/__init__.py
+++ b/pgcli/packages/parseutils/__init__.py
@@ -1,6 +1,17 @@
import sqlparse
+BASE_KEYWORDS = [
+ "drop",
+ "shutdown",
+ "delete",
+ "truncate",
+ "alter",
+ "unconditional_update",
+]
+ALL_KEYWORDS = BASE_KEYWORDS + ["update"]
+
+
def query_starts_with(formatted_sql, prefixes):
"""Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes]
@@ -13,22 +24,35 @@ def query_is_unconditional_update(formatted_sql):
return bool(tokens) and tokens[0] == "update" and "where" not in tokens
-def query_is_simple_update(formatted_sql):
- """Check if the query starts with UPDATE."""
- tokens = formatted_sql.split()
- return bool(tokens) and tokens[0] == "update"
-
-
-def is_destructive(queries, warning_level="all"):
+def is_destructive(queries, keywords):
"""Returns if any of the queries in *queries* is destructive."""
- keywords = ("drop", "shutdown", "delete", "truncate", "alter")
for query in sqlparse.split(queries):
if query:
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
- if query_starts_with(formatted_sql, keywords):
- return True
- if query_is_unconditional_update(formatted_sql):
+ if "unconditional_update" in keywords and query_is_unconditional_update(
+ formatted_sql
+ ):
return True
- if warning_level == "all" and query_is_simple_update(formatted_sql):
+ if query_starts_with(formatted_sql, keywords):
return True
return False
+
+
+def parse_destructive_warning(warning_level):
+ """Converts a deprecated destructive warning option to a list of command keywords."""
+ if not warning_level:
+ return []
+
+ if not isinstance(warning_level, list):
+ if "," in warning_level:
+ return warning_level.split(",")
+ warning_level = [warning_level]
+
+ return {
+ "true": ALL_KEYWORDS,
+ "false": [],
+ "all": ALL_KEYWORDS,
+ "moderate": BASE_KEYWORDS,
+ "off": [],
+ "": [],
+ }.get(warning_level[0], warning_level)