diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2021-09-06 04:17:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2021-09-06 04:17:09 +0000 |
commit | 76d27bc43d56d7ef3ca0090fb199777888adf7c3 (patch) | |
tree | b6d3d562a0be03d404426bb43aff62174be3dc5e /pgcli/packages/parseutils/__init__.py | |
parent | Adding upstream version 3.1.0. (diff) | |
download | pgcli-76d27bc43d56d7ef3ca0090fb199777888adf7c3.tar.xz pgcli-76d27bc43d56d7ef3ca0090fb199777888adf7c3.zip |
Adding upstream version 3.2.0.upstream/3.2.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'pgcli/packages/parseutils/__init__.py')
-rw-r--r-- | pgcli/packages/parseutils/__init__.py | 32 |
1 files changed, 22 insertions, 10 deletions
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py index a11e7bf..1acc008 100644 --- a/pgcli/packages/parseutils/__init__.py +++ b/pgcli/packages/parseutils/__init__.py @@ -1,22 +1,34 @@ import sqlparse -def query_starts_with(query, prefixes): +def query_starts_with(formatted_sql, prefixes): """Check if the query starts with any item from *prefixes*.""" prefixes = [prefix.lower() for prefix in prefixes] - formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() return bool(formatted_sql) and formatted_sql.split()[0] in prefixes -def queries_start_with(queries, prefixes): - """Check if any queries start with any item from *prefixes*.""" - for query in sqlparse.split(queries): - if query and query_starts_with(query, prefixes) is True: - return True - return False +def query_is_unconditional_update(formatted_sql): + """Check if the query starts with UPDATE and contains no WHERE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" and "where" not in tokens + +def query_is_simple_update(formatted_sql): + """Check if the query starts with UPDATE.""" + tokens = formatted_sql.split() + return bool(tokens) and tokens[0] == "update" -def is_destructive(queries): + +def is_destructive(queries, warning_level="all"): """Returns if any of the queries in *queries* is destructive.""" keywords = ("drop", "shutdown", "delete", "truncate", "alter") - return queries_start_with(queries, keywords) + for query in sqlparse.split(queries): + if query: + formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip() + if query_starts_with(formatted_sql, keywords): + return True + if query_is_unconditional_update(formatted_sql): + return True + if warning_level == "all" and query_is_simple_update(formatted_sql): + return True + return False |