summaryrefslogtreecommitdiffstats
path: root/pgcli
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2021-09-06 04:17:12 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2021-09-06 04:17:24 +0000
commit9b77fda0d4171f68760c070895dc5700cb6d1e0f (patch)
tree37912945e9d35ef62e5a4a1eb07ac224307e5db6 /pgcli
parentReleasing debian version 3.1.0-3. (diff)
downloadpgcli-9b77fda0d4171f68760c070895dc5700cb6d1e0f.tar.xz
pgcli-9b77fda0d4171f68760c070895dc5700cb6d1e0f.zip
Merging upstream version 3.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'pgcli')
-rw-r--r--pgcli/__init__.py2
-rw-r--r--pgcli/completion_refresher.py9
-rw-r--r--pgcli/config.py45
-rw-r--r--pgcli/magic.py8
-rw-r--r--pgcli/main.py93
-rw-r--r--pgcli/packages/parseutils/__init__.py32
-rw-r--r--pgcli/packages/parseutils/meta.py2
-rw-r--r--pgcli/packages/parseutils/tables.py3
-rw-r--r--pgcli/packages/pgliterals/pgliterals.json1
-rw-r--r--pgcli/packages/prioritization.py4
-rw-r--r--pgcli/packages/prompt_utils.py4
-rw-r--r--pgcli/packages/sqlcompletion.py2
-rw-r--r--pgcli/pgclirc19
-rw-r--r--pgcli/pgcompleter.py40
-rw-r--r--pgcli/pgexecute.py178
-rw-r--r--pgcli/pgtoolbar.py22
16 files changed, 311 insertions, 153 deletions
diff --git a/pgcli/__init__.py b/pgcli/__init__.py
index f5f41e5..1173108 100644
--- a/pgcli/__init__.py
+++ b/pgcli/__init__.py
@@ -1 +1 @@
-__version__ = "3.1.0"
+__version__ = "3.2.0"
diff --git a/pgcli/completion_refresher.py b/pgcli/completion_refresher.py
index cf0879f..1039d51 100644
--- a/pgcli/completion_refresher.py
+++ b/pgcli/completion_refresher.py
@@ -3,10 +3,9 @@ import os
from collections import OrderedDict
from .pgcompleter import PGCompleter
-from .pgexecute import PGExecute
-class CompletionRefresher(object):
+class CompletionRefresher:
refreshers = OrderedDict()
@@ -27,6 +26,10 @@ class CompletionRefresher(object):
has completed the refresh. The newly created completion
object will be passed in as an argument to each callback.
"""
+ if executor.is_virtual_database():
+ # do nothing
+ return [(None, None, None, "Auto-completion refresh can't be started.")]
+
if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, "Auto-completion refresh restarted.")]
@@ -141,7 +144,7 @@ def refresh_casing(completer, executor):
with open(casing_file, "w") as f:
f.write(casing_prefs)
if os.path.isfile(casing_file):
- with open(casing_file, "r") as f:
+ with open(casing_file) as f:
completer.extend_casing([line.strip() for line in f])
diff --git a/pgcli/config.py b/pgcli/config.py
index 0fc42dd..22f08dc 100644
--- a/pgcli/config.py
+++ b/pgcli/config.py
@@ -3,6 +3,8 @@ import shutil
import os
import platform
from os.path import expanduser, exists, dirname
+import re
+from typing import TextIO
from configobj import ConfigObj
@@ -16,11 +18,15 @@ def config_location():
def load_config(usr_cfg, def_cfg=None):
- cfg = ConfigObj()
- cfg.merge(ConfigObj(def_cfg, interpolation=False))
- cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
+ # avoid config merges when possible. For writing, we need an umerged config instance.
+ # see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171
+ if def_cfg:
+ cfg = ConfigObj()
+ cfg.merge(ConfigObj(def_cfg, interpolation=False))
+ cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
+ else:
+ cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")
cfg.filename = expanduser(usr_cfg)
-
return cfg
@@ -44,12 +50,16 @@ def upgrade_config(config, def_config):
cfg.write()
+def get_config_filename(pgclirc_file=None):
+ return pgclirc_file or "%sconfig" % config_location()
+
+
def get_config(pgclirc_file=None):
from pgcli import __file__ as package_root
package_root = os.path.dirname(package_root)
- pgclirc_file = pgclirc_file or "%sconfig" % config_location()
+ pgclirc_file = get_config_filename(pgclirc_file)
default_config = os.path.join(package_root, "pgclirc")
write_default_config(default_config, pgclirc_file)
@@ -62,3 +72,28 @@ def get_casing_file(config):
if casing_file == "default":
casing_file = config_location() + "casing"
return casing_file
+
+
+def skip_initial_comment(f_stream: TextIO) -> int:
+ """
+ Initial comment in ~/.pg_service.conf is not always marked with '#'
+ which crashes the parser. This function takes a file object and
+ "rewinds" it to the beginning of the first section,
+ from where on it can be parsed safely
+
+ :return: number of skipped lines
+ """
+ section_regex = r"\s*\["
+ pos = f_stream.tell()
+ lines_skipped = 0
+ while True:
+ line = f_stream.readline()
+ if line == "":
+ break
+ if re.match(section_regex, line) is not None:
+ f_stream.seek(pos)
+ break
+ else:
+ pos += len(line)
+ lines_skipped += 1
+ return lines_skipped
diff --git a/pgcli/magic.py b/pgcli/magic.py
index f58f415..6e58f28 100644
--- a/pgcli/magic.py
+++ b/pgcli/magic.py
@@ -25,7 +25,11 @@ def pgcli_line_magic(line):
if hasattr(sql.connection.Connection, "get"):
conn = sql.connection.Connection.get(parsed["connection"])
else:
- conn = sql.connection.Connection.set(parsed["connection"])
+ try:
+ conn = sql.connection.Connection.set(parsed["connection"])
+ # a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql
+ except TypeError:
+ conn = sql.connection.Connection.set(parsed["connection"], False)
try:
# A corresponding pgcli object already exists
@@ -43,7 +47,7 @@ def pgcli_line_magic(line):
conn._pgcli = pgcli
# For convenience, print the connection alias
- print("Connected: {}".format(conn.name))
+ print(f"Connected: {conn.name}")
try:
pgcli.run_cli()
diff --git a/pgcli/main.py b/pgcli/main.py
index b146898..5135f6f 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -2,8 +2,9 @@ import platform
import warnings
from os.path import expanduser
-from configobj import ConfigObj
+from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
+from .config import skip_initial_comment
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
@@ -20,12 +21,12 @@ import datetime as dt
import itertools
import platform
from time import time, sleep
-from codecs import open
keyring = None # keyring will be loaded later
from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
+from cli_helpers.utils import strip_ansi
import click
try:
@@ -62,6 +63,7 @@ from .config import (
config_location,
ensure_dir_exists,
get_config,
+ get_config_filename,
)
from .key_bindings import pgcli_bindings
from .packages.prompt_utils import confirm_destructive_query
@@ -122,7 +124,7 @@ class PgCliQuitError(Exception):
pass
-class PGCli(object):
+class PGCli:
default_prompt = "\\u@\\h:\\d> "
max_len_prompt = 30
@@ -175,7 +177,11 @@ class PGCli(object):
# Load config.
c = self.config = get_config(pgclirc_file)
- NamedQueries.instance = NamedQueries.from_config(self.config)
+ # at this point, config should be written to pgclirc_file if it did not exist. Read it.
+ self.config_writer = load_config(get_config_filename(pgclirc_file))
+
+ # make sure to use self.config_writer, not self.config
+ NamedQueries.instance = NamedQueries.from_config(self.config_writer)
self.logger = logging.getLogger(__name__)
self.initialize_logging()
@@ -201,8 +207,11 @@ class PGCli(object):
self.syntax_style = c["main"]["syntax_style"]
self.cli_style = c["colors"]
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
- c_dest_warning = c["main"].as_bool("destructive_warning")
- self.destructive_warning = c_dest_warning if warn is None else warn
+ self.destructive_warning = warn or c["main"]["destructive_warning"]
+ # also handle boolean format of destructive warning
+ self.destructive_warning = {"true": "all", "false": "off"}.get(
+ self.destructive_warning.lower(), self.destructive_warning
+ )
self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
self.null_string = c["main"].get("null_string", "<null>")
self.prompt_format = (
@@ -325,11 +334,11 @@ class PGCli(object):
if pattern not in TabularOutputFormatter().supported_formats:
raise ValueError()
self.table_format = pattern
- yield (None, None, None, "Changed table format to {}".format(pattern))
+ yield (None, None, None, f"Changed table format to {pattern}")
except ValueError:
- msg = "Table format {} not recognized. Allowed formats:".format(pattern)
+ msg = f"Table format {pattern} not recognized. Allowed formats:"
for table_type in TabularOutputFormatter().supported_formats:
- msg += "\n\t{}".format(table_type)
+ msg += f"\n\t{table_type}"
msg += "\nCurrently set to: %s" % self.table_format
yield (None, None, None, msg)
@@ -386,10 +395,13 @@ class PGCli(object):
try:
with open(os.path.expanduser(pattern), encoding="utf-8") as f:
query = f.read()
- except IOError as e:
+ except OSError as e:
return [(None, None, None, str(e), "", False, True)]
- if self.destructive_warning and confirm_destructive_query(query) is False:
+ if (
+ self.destructive_warning != "off"
+ and confirm_destructive_query(query, self.destructive_warning) is False
+ ):
message = "Wise choice. Command execution stopped."
return [(None, None, None, message)]
@@ -407,7 +419,7 @@ class PGCli(object):
if not os.path.isfile(filename):
try:
open(filename, "w").close()
- except IOError as e:
+ except OSError as e:
self.output_file = None
message = str(e) + "\nFile output disabled"
return [(None, None, None, message, "", False, True)]
@@ -479,7 +491,7 @@ class PGCli(object):
service_config, file = parse_service_info(service)
if service_config is None:
click.secho(
- "service '%s' was not found in %s" % (service, file), err=True, fg="red"
+ f"service '{service}' was not found in {file}", err=True, fg="red"
)
exit(1)
self.connect(
@@ -515,7 +527,7 @@ class PGCli(object):
passwd = os.environ.get("PGPASSWORD", "")
# Find password from store
- key = "%s@%s" % (user, host)
+ key = f"{user}@{host}"
keyring_error_message = dedent(
"""\
{}
@@ -644,8 +656,10 @@ class PGCli(object):
query = MetaQuery(query=text, successful=False)
try:
- if self.destructive_warning:
- destroy = confirm = confirm_destructive_query(text)
+ if self.destructive_warning != "off":
+ destroy = confirm = confirm_destructive_query(
+ text, self.destructive_warning
+ )
if destroy is False:
click.secho("Wise choice!")
raise KeyboardInterrupt
@@ -677,7 +691,7 @@ class PGCli(object):
click.echo(text, file=f)
click.echo("\n".join(output), file=f)
click.echo("", file=f) # extra newline
- except IOError as e:
+ except OSError as e:
click.secho(str(e), err=True, fg="red")
else:
if output:
@@ -729,7 +743,6 @@ class PGCli(object):
if not self.less_chatty:
print("Server: PostgreSQL", self.pgexecute.server_version)
print("Version:", __version__)
- print("Chat: https://gitter.im/dbcli/pgcli")
print("Home: http://pgcli.com")
try:
@@ -753,11 +766,7 @@ class PGCli(object):
while self.watch_command:
try:
query = self.execute_command(self.watch_command)
- click.echo(
- "Waiting for {0} seconds before repeating".format(
- timing
- )
- )
+ click.echo(f"Waiting for {timing} seconds before repeating")
sleep(timing)
except KeyboardInterrupt:
self.watch_command = None
@@ -979,16 +988,13 @@ class PGCli(object):
callback = functools.partial(
self._on_completions_refreshed, persist_priorities=persist_priorities
)
- self.completion_refresher.refresh(
+ return self.completion_refresher.refresh(
self.pgexecute,
self.pgspecial,
callback,
history=history,
settings=self.settings,
)
- return [
- (None, None, None, "Auto-completion refresh started in the background.")
- ]
def _on_completions_refreshed(self, new_completer, persist_priorities):
self._swap_completer_objects(new_completer, persist_priorities)
@@ -1049,7 +1055,7 @@ class PGCli(object):
str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
)
string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
- string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">")
+ string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
string = string.replace("\\n", "\n")
return string
@@ -1075,9 +1081,10 @@ class PGCli(object):
def echo_via_pager(self, text, color=None):
if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
click.echo(text, color=color)
- elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv":
- click.echo_via_pager(text, color)
- elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT:
+ elif (
+ self.pgspecial.pager_config == PAGER_LONG_OUTPUT
+ and self.table_format != "csv"
+ ):
lines = text.split("\n")
# The last 4 lines are reserved for the pgcli menu and padding
@@ -1192,7 +1199,10 @@ class PGCli(object):
help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
)
@click.option(
- "--warn/--no-warn", default=None, help="Warn before running a destructive query."
+ "--warn",
+ default=None,
+ type=click.Choice(["all", "moderate", "off"]),
+ help="Warn before running a destructive query.",
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
@@ -1384,7 +1394,7 @@ def is_mutating(status):
if not status:
return False
- mutating = set(["insert", "update", "delete"])
+ mutating = {"insert", "update", "delete"}
return status.split(None, 1)[0].lower() in mutating
@@ -1475,7 +1485,12 @@ def format_output(title, cur, headers, status, settings):
formatted = iter(formatted.splitlines())
first_line = next(formatted)
formatted = itertools.chain([first_line], formatted)
- if not expanded and max_width and len(first_line) > max_width and headers:
+ if (
+ not expanded
+ and max_width
+ and len(strip_ansi(first_line)) > max_width
+ and headers
+ ):
formatted = formatter.format_output(
cur, headers, format_name="vertical", column_types=None, **output_kwargs
)
@@ -1502,10 +1517,16 @@ def parse_service_info(service):
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
else:
service_file = expanduser("~/.pg_service.conf")
- if not service:
+ if not service or not os.path.exists(service_file):
# nothing to do
return None, service_file
- service_file_config = ConfigObj(service_file)
+ with open(service_file, newline="") as f:
+ skipped_lines = skip_initial_comment(f)
+ try:
+ service_file_config = ConfigObj(f)
+ except ParseError as err:
+ err.line_number += skipped_lines
+ raise err
if service not in service_file_config:
return None, service_file
service_conf = service_file_config.get(service)
diff --git a/pgcli/packages/parseutils/__init__.py b/pgcli/packages/parseutils/__init__.py
index a11e7bf..1acc008 100644
--- a/pgcli/packages/parseutils/__init__.py
+++ b/pgcli/packages/parseutils/__init__.py
@@ -1,22 +1,34 @@
import sqlparse
-def query_starts_with(query, prefixes):
+def query_starts_with(formatted_sql, prefixes):
"""Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes]
- formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
-def queries_start_with(queries, prefixes):
- """Check if any queries start with any item from *prefixes*."""
- for query in sqlparse.split(queries):
- if query and query_starts_with(query, prefixes) is True:
- return True
- return False
+def query_is_unconditional_update(formatted_sql):
+ """Check if the query starts with UPDATE and contains no WHERE."""
+ tokens = formatted_sql.split()
+ return bool(tokens) and tokens[0] == "update" and "where" not in tokens
+
+def query_is_simple_update(formatted_sql):
+ """Check if the query starts with UPDATE."""
+ tokens = formatted_sql.split()
+ return bool(tokens) and tokens[0] == "update"
-def is_destructive(queries):
+
+def is_destructive(queries, warning_level="all"):
"""Returns if any of the queries in *queries* is destructive."""
keywords = ("drop", "shutdown", "delete", "truncate", "alter")
- return queries_start_with(queries, keywords)
+ for query in sqlparse.split(queries):
+ if query:
+ formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
+ if query_starts_with(formatted_sql, keywords):
+ return True
+ if query_is_unconditional_update(formatted_sql):
+ return True
+ if warning_level == "all" and query_is_simple_update(formatted_sql):
+ return True
+ return False
diff --git a/pgcli/packages/parseutils/meta.py b/pgcli/packages/parseutils/meta.py
index 108c01a..333cab5 100644
--- a/pgcli/packages/parseutils/meta.py
+++ b/pgcli/packages/parseutils/meta.py
@@ -50,7 +50,7 @@ def parse_defaults(defaults_string):
yield current
-class FunctionMetadata(object):
+class FunctionMetadata:
def __init__(
self,
schema_name,
diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py
index 0ec3e69..aaa676c 100644
--- a/pgcli/packages/parseutils/tables.py
+++ b/pgcli/packages/parseutils/tables.py
@@ -42,8 +42,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
for item in parsed.tokens:
if tbl_prefix_seen:
if is_subselect(item):
- for x in extract_from_part(item, stop_at_punctuation):
- yield x
+ yield from extract_from_part(item, stop_at_punctuation)
elif stop_at_punctuation and item.ttype is Punctuation:
return
# An incomplete nested select won't be recognized correctly as a
diff --git a/pgcli/packages/pgliterals/pgliterals.json b/pgcli/packages/pgliterals/pgliterals.json
index c7b74b5..df00817 100644
--- a/pgcli/packages/pgliterals/pgliterals.json
+++ b/pgcli/packages/pgliterals/pgliterals.json
@@ -392,6 +392,7 @@
"QUOTE_NULLABLE",
"RADIANS",
"RADIUS",
+ "RANDOM",
"RANK",
"REGEXP_MATCH",
"REGEXP_MATCHES",
diff --git a/pgcli/packages/prioritization.py b/pgcli/packages/prioritization.py
index e92dcbb..f5a9cb5 100644
--- a/pgcli/packages/prioritization.py
+++ b/pgcli/packages/prioritization.py
@@ -16,10 +16,10 @@ def _compile_regex(keyword):
keywords = get_literals("keywords")
-keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)
+keyword_regexs = {kw: _compile_regex(kw) for kw in keywords}
-class PrevalenceCounter(object):
+class PrevalenceCounter:
def __init__(self):
self.keyword_counts = defaultdict(int)
self.name_counts = defaultdict(int)
diff --git a/pgcli/packages/prompt_utils.py b/pgcli/packages/prompt_utils.py
index 3c58490..e8589de 100644
--- a/pgcli/packages/prompt_utils.py
+++ b/pgcli/packages/prompt_utils.py
@@ -3,7 +3,7 @@ import click
from .parseutils import is_destructive
-def confirm_destructive_query(queries):
+def confirm_destructive_query(queries, warning_level):
"""Check if the query is destructive and prompts the user to confirm.
Returns:
@@ -15,7 +15,7 @@ def confirm_destructive_query(queries):
prompt_text = (
"You're about to run a destructive command.\n" "Do you want to proceed? (y/n)"
)
- if is_destructive(queries) and sys.stdin.isatty():
+ if is_destructive(queries, warning_level) and sys.stdin.isatty():
return prompt(prompt_text, type=bool)
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index 6ef8859..6305301 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -47,7 +47,7 @@ Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple("Path", [])
-class SqlStatement(object):
+class SqlStatement:
def __init__(self, full_text, text_before_cursor):
self.identifier = None
self.word_before_cursor = word_before_cursor = last_word(
diff --git a/pgcli/pgclirc b/pgcli/pgclirc
index e97afda..15c10f5 100644
--- a/pgcli/pgclirc
+++ b/pgcli/pgclirc
@@ -23,9 +23,13 @@ multi_line = False
multi_line_mode = psql
# Destructive warning mode will alert you before executing a sql statement
-# that may cause harm to the database such as "drop table", "drop database"
-# or "shutdown".
-destructive_warning = True
+# that may cause harm to the database such as "drop table", "drop database",
+# "shutdown", "delete", or "update".
+# Possible values:
+# "all" - warn on data definition statements, server actions such as SHUTDOWN, DELETE or UPDATE
+# "moderate" - skip warning on UPDATE statements, except for unconditional updates
+# "off" - skip all warnings
+destructive_warning = all
# Enables expand mode, which is similar to `\x` in psql.
expand = False
@@ -170,9 +174,12 @@ arg-toolbar = 'noinherit bold'
arg-toolbar.text = 'nobold'
bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold'
bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold'
-literal.string = '#ba2121'
-literal.number = '#666666'
-keyword = 'bold #008000'
+# These three values can be used to further refine the syntax highlighting.
+# They are commented out by default, since they have priority over the theme set
+# with the `syntax_style` setting and overriding its behavior can be confusing.
+# literal.string = '#ba2121'
+# literal.number = '#666666'
+# keyword = 'bold #008000'
# style classes for colored table output
output.header = "#00ff5f bold"
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
index 9c95a01..227e25c 100644
--- a/pgcli/pgcompleter.py
+++ b/pgcli/pgcompleter.py
@@ -83,7 +83,7 @@ class PGCompleter(Completer):
reserved_words = set(get_literals("reserved"))
def __init__(self, smart_completion=True, pgspecial=None, settings=None):
- super(PGCompleter, self).__init__()
+ super().__init__()
self.smart_completion = smart_completion
self.pgspecial = pgspecial
self.prioritizer = PrevalenceCounter()
@@ -140,7 +140,7 @@ class PGCompleter(Completer):
return "'{}'".format(self.unescape_name(name))
def unescape_name(self, name):
- """ Unquote a string."""
+ """Unquote a string."""
if name and name[0] == '"' and name[-1] == '"':
name = name[1:-1]
@@ -177,7 +177,7 @@ class PGCompleter(Completer):
:return:
"""
# casing should be a dict {lowercasename:PreferredCasingName}
- self.casing = dict((word.lower(), word) for word in words)
+ self.casing = {word.lower(): word for word in words}
def extend_relations(self, data, kind):
"""extend metadata for tables or views.
@@ -279,8 +279,8 @@ class PGCompleter(Completer):
fk = ForeignKey(
parentschema, parenttable, parcol, childschema, childtable, childcol
)
- childcolmeta.foreignkeys.append((fk))
- parcolmeta.foreignkeys.append((fk))
+ childcolmeta.foreignkeys.append(fk)
+ parcolmeta.foreignkeys.append(fk)
def extend_datatypes(self, type_data):
@@ -424,7 +424,7 @@ class PGCompleter(Completer):
# the same priority as unquoted names.
lexical_priority = (
tuple(
- 0 if c in (" _") else -ord(c)
+ 0 if c in " _" else -ord(c)
for c in self.unescape_name(item.lower())
)
+ (1,)
@@ -517,9 +517,9 @@ class PGCompleter(Completer):
# require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
# suggest only columns that appear in the last table and one more
ltbl = tables[-1].ref
- other_tbl_cols = set(
+ other_tbl_cols = {
c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
- )
+ }
scoped_cols = {
t: [col for col in cols if col.name in other_tbl_cols]
for t, cols in scoped_cols.items()
@@ -574,7 +574,7 @@ class PGCompleter(Completer):
tbls - TableReference iterable of tables already in query
"""
tbl = self.case(tbl)
- tbls = set(normalize_ref(t.ref) for t in tbls)
+ tbls = {normalize_ref(t.ref) for t in tbls}
if self.generate_aliases:
tbl = generate_alias(self.unescape_name(tbl))
if normalize_ref(tbl) not in tbls:
@@ -589,10 +589,10 @@ class PGCompleter(Completer):
tbls = suggestion.table_refs
cols = self.populate_scoped_cols(tbls)
# Set up some data structures for efficient access
- qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
- ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
- refs = set(normalize_ref(t.ref) for t in tbls)
- other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
+ qualified = {normalize_ref(t.ref): t.schema for t in tbls}
+ ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)}
+ refs = {normalize_ref(t.ref) for t in tbls}
+ other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
joins = []
# Iterate over FKs in existing tables to find potential joins
fks = (
@@ -667,7 +667,7 @@ class PGCompleter(Completer):
return d
# Tables that are closer to the cursor get higher prio
- ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs))
+ ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
# Map (schema, table, col) to tables
coldict = list_dict(
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
@@ -703,7 +703,11 @@ class PGCompleter(Completer):
not f.is_aggregate
and not f.is_window
and not f.is_extension
- and (f.is_public or f.schema_name == suggestion.schema)
+ and (
+ f.is_public
+ or f.schema_name in self.search_path
+ or f.schema_name == suggestion.schema
+ )
)
else:
@@ -721,9 +725,7 @@ class PGCompleter(Completer):
# Function overloading means we way have multiple functions of the same
# name at this point, so keep unique names only
all_functions = self.populate_functions(suggestion.schema, filt)
- funcs = set(
- self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions
- )
+ funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions}
matches = self.find_matches(word_before_cursor, funcs, meta="function")
@@ -953,7 +955,7 @@ class PGCompleter(Completer):
:return: {TableReference:{colname:ColumnMetaData}}
"""
- ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
+ ctes = {normalize_ref(t.name): t.columns for t in local_tbls}
columns = OrderedDict()
meta = self.dbmetadata
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index d34bf26..a013b55 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -1,13 +1,15 @@
-import traceback
import logging
+import select
+import traceback
+
+import pgspecial as special
import psycopg2
-import psycopg2.extras
import psycopg2.errorcodes
import psycopg2.extensions as ext
+import psycopg2.extras
import sqlparse
-import pgspecial as special
-import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
+
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
@@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
+_wait_callback_is_set = False
def _wait_select(conn):
@@ -34,31 +37,41 @@ def _wait_select(conn):
copy-pasted from psycopg2.extras.wait_select
the default implementation doesn't define a timeout in the select calls
"""
- while 1:
- try:
- state = conn.poll()
- if state == POLL_OK:
- break
- elif state == POLL_READ:
- select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
- elif state == POLL_WRITE:
- select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
- else:
- raise conn.OperationalError("bad state from poll: %s" % state)
- except KeyboardInterrupt:
- conn.cancel()
- # the loop will be broken by a server error
- continue
- except select.error as e:
- errno = e.args[0]
- if errno != 4:
- raise
+ try:
+ while 1:
+ try:
+ state = conn.poll()
+ if state == POLL_OK:
+ break
+ elif state == POLL_READ:
+ select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
+ elif state == POLL_WRITE:
+ select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
+ else:
+ raise conn.OperationalError("bad state from poll: %s" % state)
+ except KeyboardInterrupt:
+ conn.cancel()
+ # the loop will be broken by a server error
+ continue
+ except OSError as e:
+ errno = e.args[0]
+ if errno != 4:
+ raise
+ except psycopg2.OperationalError:
+ pass
-# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
-# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
-# See also https://github.com/psycopg/psycopg2/issues/468
-ext.set_wait_callback(_wait_select)
+def _set_wait_callback(is_virtual_database):
+ global _wait_callback_is_set
+ if _wait_callback_is_set:
+ return
+ _wait_callback_is_set = True
+ if is_virtual_database:
+ return
+ # When running a query, make pressing CTRL+C raise a KeyboardInterrupt
+ # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
+ # See also https://github.com/psycopg/psycopg2/issues/468
+ ext.set_wait_callback(_wait_select)
def register_date_typecasters(connection):
@@ -72,6 +85,8 @@ def register_date_typecasters(connection):
cursor = connection.cursor()
cursor.execute("SELECT NULL::date")
+ if cursor.description is None:
+ return
date_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
@@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn):
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
- except psycopg2.ProgrammingError:
+ except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
pass
return available
@@ -127,7 +142,39 @@ def register_hstore_typecaster(conn):
pass
-class PGExecute(object):
+class ProtocolSafeCursor(psycopg2.extensions.cursor):
+ def __init__(self, *args, **kwargs):
+ self.protocol_error = False
+ self.protocol_message = ""
+ super().__init__(*args, **kwargs)
+
+ def __iter__(self):
+ if self.protocol_error:
+ raise StopIteration
+ return super().__iter__()
+
+ def fetchall(self):
+ if self.protocol_error:
+ return [(self.protocol_message,)]
+ return super().fetchall()
+
+ def fetchone(self):
+ if self.protocol_error:
+ return (self.protocol_message,)
+ return super().fetchone()
+
+ def execute(self, sql, args=None):
+ try:
+ psycopg2.extensions.cursor.execute(self, sql, args)
+ self.protocol_error = False
+ self.protocol_message = ""
+ except psycopg2.errors.ProtocolViolation as ex:
+ self.protocol_error = True
+ self.protocol_message = ex.pgerror
+ _logger.debug("%s: %s" % (ex.__class__.__name__, ex))
+
+
+class PGExecute:
# The boolean argument to the current_schemas function indicates whether
# implicit schemas, e.g. pg_catalog
@@ -190,8 +237,6 @@ class PGExecute(object):
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
FROM f"""
- version_query = "SELECT version();"
-
def __init__(
self,
database=None,
@@ -203,6 +248,7 @@ class PGExecute(object):
**kwargs,
):
self._conn_params = {}
+ self._is_virtual_database = None
self.conn = None
self.dbname = None
self.user = None
@@ -214,6 +260,11 @@ class PGExecute(object):
self.connect(database, user, password, host, port, dsn, **kwargs)
self.reset_expanded = None
+ def is_virtual_database(self):
+ if self._is_virtual_database is None:
+ self._is_virtual_database = self.is_protocol_error()
+ return self._is_virtual_database
+
def copy(self):
"""Returns a clone of the current executor."""
return self.__class__(**self._conn_params)
@@ -250,9 +301,9 @@ class PGExecute(object):
)
conn_params.update({k: v for k, v in new_params.items() if v})
+ conn_params["cursor_factory"] = ProtocolSafeCursor
conn = psycopg2.connect(**conn_params)
- cursor = conn.cursor()
conn.set_client_encoding("utf8")
self._conn_params = conn_params
@@ -293,16 +344,22 @@ class PGExecute(object):
self.extra_args = kwargs
if not self.host:
- self.host = self.get_socket_directory()
+ self.host = (
+ "pgbouncer"
+ if self.is_virtual_database()
+ else self.get_socket_directory()
+ )
- pid = self._select_one(cursor, "select pg_backend_pid()")[0]
- self.pid = pid
+ self.pid = conn.get_backend_pid()
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
- self.server_version = conn.get_parameter_status("server_version")
+ self.server_version = conn.get_parameter_status("server_version") or ""
+
+ _set_wait_callback(self.is_virtual_database())
- register_date_typecasters(conn)
- register_json_typecasters(self.conn, self._json_typecaster)
- register_hstore_typecaster(self.conn)
+ if not self.is_virtual_database():
+ register_date_typecasters(conn)
+ register_json_typecasters(self.conn, self._json_typecaster)
+ register_hstore_typecaster(self.conn)
@property
def short_host(self):
@@ -395,7 +452,13 @@ class PGExecute(object):
# See https://github.com/dbcli/pgcli/issues/1014.
cur = None
try:
- for result in pgspecial.execute(cur, sql):
+ response = pgspecial.execute(cur, sql)
+ if cur and cur.protocol_error:
+ yield None, None, None, cur.protocol_message, statement, False, False
+ # this would close connection. We should reconnect.
+ self.connect()
+ continue
+ for result in response:
# e.g. execute_from_file already appends these
if len(result) < 7:
yield result + (sql, True, True)
@@ -453,6 +516,9 @@ class PGExecute(object):
if cur.description:
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
+ elif cur.protocol_error:
+ _logger.debug("Protocol error, unsupported command.")
+ return title, None, None, cur.protocol_message
else:
_logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
@@ -485,7 +551,7 @@ class PGExecute(object):
try:
cur.execute(sql, (spec,))
except psycopg2.ProgrammingError:
- raise RuntimeError("View {} does not exist.".format(spec))
+ raise RuntimeError(f"View {spec} does not exist.")
result = cur.fetchone()
view_type = "MATERIALIZED" if result[2] == "m" else ""
return template.format(*result + (view_type,))
@@ -501,7 +567,7 @@ class PGExecute(object):
result = cur.fetchone()
return result[0]
except psycopg2.ProgrammingError:
- raise RuntimeError("Function {} does not exist.".format(spec))
+ raise RuntimeError(f"Function {spec} does not exist.")
def schemata(self):
"""Returns a list of schema names in the database"""
@@ -527,21 +593,18 @@ class PGExecute(object):
sql = cur.mogrify(self.tables_query, [kinds])
_logger.debug("Tables Query. sql: %r", sql)
cur.execute(sql)
- for row in cur:
- yield row
+ yield from cur
def tables(self):
"""Yields (schema_name, table_name) tuples"""
- for row in self._relations(kinds=["r", "p", "f"]):
- yield row
+ yield from self._relations(kinds=["r", "p", "f"])
def views(self):
"""Yields (schema_name, view_name) tuples.
Includes both views and and materialized views
"""
- for row in self._relations(kinds=["v", "m"]):
- yield row
+ yield from self._relations(kinds=["v", "m"])
def _columns(self, kinds=("r", "p", "f", "v", "m")):
"""Get column metadata for tables and views
@@ -599,16 +662,13 @@ class PGExecute(object):
sql = cur.mogrify(columns_query, [kinds])
_logger.debug("Columns Query. sql: %r", sql)
cur.execute(sql)
- for row in cur:
- yield row
+ yield from cur
def table_columns(self):
- for row in self._columns(kinds=["r", "p", "f"]):
- yield row
+ yield from self._columns(kinds=["r", "p", "f"])
def view_columns(self):
- for row in self._columns(kinds=["v", "m"]):
- yield row
+ yield from self._columns(kinds=["v", "m"])
def databases(self):
with self.conn.cursor() as cur:
@@ -623,6 +683,13 @@ class PGExecute(object):
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
+ def is_protocol_error(self):
+ query = "SELECT 1"
+ with self.conn.cursor() as cur:
+ _logger.debug("Simple Query. sql: %r", query)
+ cur.execute(query)
+ return bool(cur.protocol_error)
+
def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug(
@@ -804,8 +871,7 @@ class PGExecute(object):
"""
_logger.debug("Datatypes Query. sql: %r", query)
cur.execute(query)
- for row in cur:
- yield row
+ yield from cur
def casing(self):
"""Yields the most common casing for names used in db functions"""
diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py
index f4289a1..41f903d 100644
--- a/pgcli/pgtoolbar.py
+++ b/pgcli/pgtoolbar.py
@@ -1,15 +1,23 @@
+from pkg_resources import packaging
+
+import prompt_toolkit
from prompt_toolkit.key_binding.vi_state import InputMode
from prompt_toolkit.application import get_app
+parse_version = packaging.version.parse
+
+vi_modes = {
+ InputMode.INSERT: "I",
+ InputMode.NAVIGATION: "N",
+ InputMode.REPLACE: "R",
+ InputMode.INSERT_MULTIPLE: "M",
+}
+if parse_version(prompt_toolkit.__version__) >= parse_version("3.0.6"):
+ vi_modes[InputMode.REPLACE_SINGLE] = "R"
+
def _get_vi_mode():
- return {
- InputMode.INSERT: "I",
- InputMode.NAVIGATION: "N",
- InputMode.REPLACE: "R",
- InputMode.REPLACE_SINGLE: "R",
- InputMode.INSERT_MULTIPLE: "M",
- }[get_app().vi_state.input_mode]
+ return vi_modes[get_app().vi_state.input_mode]
def create_toolbar_tokens_func(pgcli):