summaryrefslogtreecommitdiffstats
path: root/pgcli/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/main.py')
-rw-r--r--pgcli/main.py93
1 files changed, 57 insertions, 36 deletions
diff --git a/pgcli/main.py b/pgcli/main.py
index b146898..5135f6f 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -2,8 +2,9 @@ import platform
import warnings
from os.path import expanduser
-from configobj import ConfigObj
+from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
+from .config import skip_initial_comment
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
@@ -20,12 +21,12 @@ import datetime as dt
import itertools
import platform
from time import time, sleep
-from codecs import open
keyring = None # keyring will be loaded later
from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
+from cli_helpers.utils import strip_ansi
import click
try:
@@ -62,6 +63,7 @@ from .config import (
config_location,
ensure_dir_exists,
get_config,
+ get_config_filename,
)
from .key_bindings import pgcli_bindings
from .packages.prompt_utils import confirm_destructive_query
@@ -122,7 +124,7 @@ class PgCliQuitError(Exception):
pass
-class PGCli(object):
+class PGCli:
default_prompt = "\\u@\\h:\\d> "
max_len_prompt = 30
@@ -175,7 +177,11 @@ class PGCli(object):
# Load config.
c = self.config = get_config(pgclirc_file)
- NamedQueries.instance = NamedQueries.from_config(self.config)
+ # at this point, config should be written to pgclirc_file if it did not exist. Read it.
+ self.config_writer = load_config(get_config_filename(pgclirc_file))
+
+ # make sure to use self.config_writer, not self.config
+ NamedQueries.instance = NamedQueries.from_config(self.config_writer)
self.logger = logging.getLogger(__name__)
self.initialize_logging()
@@ -201,8 +207,11 @@ class PGCli(object):
self.syntax_style = c["main"]["syntax_style"]
self.cli_style = c["colors"]
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
- c_dest_warning = c["main"].as_bool("destructive_warning")
- self.destructive_warning = c_dest_warning if warn is None else warn
+ self.destructive_warning = warn or c["main"]["destructive_warning"]
+ # also handle boolean format of destructive warning
+ self.destructive_warning = {"true": "all", "false": "off"}.get(
+ self.destructive_warning.lower(), self.destructive_warning
+ )
self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
self.null_string = c["main"].get("null_string", "<null>")
self.prompt_format = (
@@ -325,11 +334,11 @@ class PGCli(object):
if pattern not in TabularOutputFormatter().supported_formats:
raise ValueError()
self.table_format = pattern
- yield (None, None, None, "Changed table format to {}".format(pattern))
+ yield (None, None, None, f"Changed table format to {pattern}")
except ValueError:
- msg = "Table format {} not recognized. Allowed formats:".format(pattern)
+ msg = f"Table format {pattern} not recognized. Allowed formats:"
for table_type in TabularOutputFormatter().supported_formats:
- msg += "\n\t{}".format(table_type)
+ msg += f"\n\t{table_type}"
msg += "\nCurrently set to: %s" % self.table_format
yield (None, None, None, msg)
@@ -386,10 +395,13 @@ class PGCli(object):
try:
with open(os.path.expanduser(pattern), encoding="utf-8") as f:
query = f.read()
- except IOError as e:
+ except OSError as e:
return [(None, None, None, str(e), "", False, True)]
- if self.destructive_warning and confirm_destructive_query(query) is False:
+ if (
+ self.destructive_warning != "off"
+ and confirm_destructive_query(query, self.destructive_warning) is False
+ ):
message = "Wise choice. Command execution stopped."
return [(None, None, None, message)]
@@ -407,7 +419,7 @@ class PGCli(object):
if not os.path.isfile(filename):
try:
open(filename, "w").close()
- except IOError as e:
+ except OSError as e:
self.output_file = None
message = str(e) + "\nFile output disabled"
return [(None, None, None, message, "", False, True)]
@@ -479,7 +491,7 @@ class PGCli(object):
service_config, file = parse_service_info(service)
if service_config is None:
click.secho(
- "service '%s' was not found in %s" % (service, file), err=True, fg="red"
+ f"service '{service}' was not found in {file}", err=True, fg="red"
)
exit(1)
self.connect(
@@ -515,7 +527,7 @@ class PGCli(object):
passwd = os.environ.get("PGPASSWORD", "")
# Find password from store
- key = "%s@%s" % (user, host)
+ key = f"{user}@{host}"
keyring_error_message = dedent(
"""\
{}
@@ -644,8 +656,10 @@ class PGCli(object):
query = MetaQuery(query=text, successful=False)
try:
- if self.destructive_warning:
- destroy = confirm = confirm_destructive_query(text)
+ if self.destructive_warning != "off":
+ destroy = confirm = confirm_destructive_query(
+ text, self.destructive_warning
+ )
if destroy is False:
click.secho("Wise choice!")
raise KeyboardInterrupt
@@ -677,7 +691,7 @@ class PGCli(object):
click.echo(text, file=f)
click.echo("\n".join(output), file=f)
click.echo("", file=f) # extra newline
- except IOError as e:
+ except OSError as e:
click.secho(str(e), err=True, fg="red")
else:
if output:
@@ -729,7 +743,6 @@ class PGCli(object):
if not self.less_chatty:
print("Server: PostgreSQL", self.pgexecute.server_version)
print("Version:", __version__)
- print("Chat: https://gitter.im/dbcli/pgcli")
print("Home: http://pgcli.com")
try:
@@ -753,11 +766,7 @@ class PGCli(object):
while self.watch_command:
try:
query = self.execute_command(self.watch_command)
- click.echo(
- "Waiting for {0} seconds before repeating".format(
- timing
- )
- )
+ click.echo(f"Waiting for {timing} seconds before repeating")
sleep(timing)
except KeyboardInterrupt:
self.watch_command = None
@@ -979,16 +988,13 @@ class PGCli(object):
callback = functools.partial(
self._on_completions_refreshed, persist_priorities=persist_priorities
)
- self.completion_refresher.refresh(
+ return self.completion_refresher.refresh(
self.pgexecute,
self.pgspecial,
callback,
history=history,
settings=self.settings,
)
- return [
- (None, None, None, "Auto-completion refresh started in the background.")
- ]
def _on_completions_refreshed(self, new_completer, persist_priorities):
self._swap_completer_objects(new_completer, persist_priorities)
@@ -1049,7 +1055,7 @@ class PGCli(object):
str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
)
string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
- string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">")
+ string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
string = string.replace("\\n", "\n")
return string
@@ -1075,9 +1081,10 @@ class PGCli(object):
def echo_via_pager(self, text, color=None):
if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
click.echo(text, color=color)
- elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv":
- click.echo_via_pager(text, color)
- elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT:
+ elif (
+ self.pgspecial.pager_config == PAGER_LONG_OUTPUT
+ and self.table_format != "csv"
+ ):
lines = text.split("\n")
# The last 4 lines are reserved for the pgcli menu and padding
@@ -1192,7 +1199,10 @@ class PGCli(object):
help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
)
@click.option(
- "--warn/--no-warn", default=None, help="Warn before running a destructive query."
+ "--warn",
+ default=None,
+ type=click.Choice(["all", "moderate", "off"]),
+ help="Warn before running a destructive query.",
)
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
@@ -1384,7 +1394,7 @@ def is_mutating(status):
if not status:
return False
- mutating = set(["insert", "update", "delete"])
+ mutating = {"insert", "update", "delete"}
return status.split(None, 1)[0].lower() in mutating
@@ -1475,7 +1485,12 @@ def format_output(title, cur, headers, status, settings):
formatted = iter(formatted.splitlines())
first_line = next(formatted)
formatted = itertools.chain([first_line], formatted)
- if not expanded and max_width and len(first_line) > max_width and headers:
+ if (
+ not expanded
+ and max_width
+ and len(strip_ansi(first_line)) > max_width
+ and headers
+ ):
formatted = formatter.format_output(
cur, headers, format_name="vertical", column_types=None, **output_kwargs
)
@@ -1502,10 +1517,16 @@ def parse_service_info(service):
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
else:
service_file = expanduser("~/.pg_service.conf")
- if not service:
+ if not service or not os.path.exists(service_file):
# nothing to do
return None, service_file
- service_file_config = ConfigObj(service_file)
+ with open(service_file, newline="") as f:
+ skipped_lines = skip_initial_comment(f)
+ try:
+ service_file_config = ConfigObj(f)
+ except ParseError as err:
+ err.line_number += skipped_lines
+ raise err
if service not in service_file_config:
return None, service_file
service_conf = service_file_config.get(service)