summaryrefslogtreecommitdiffstats
path: root/pgcli/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'pgcli/main.py')
-rw-r--r--pgcli/main.py137
1 files changed, 52 insertions, 85 deletions
diff --git a/pgcli/main.py b/pgcli/main.py
index 2d7edfa..0fa264f 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -1,12 +1,7 @@
-import platform
-import warnings
-
from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
from .config import skip_initial_comment
-warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
-
import atexit
import os
import re
@@ -22,13 +17,11 @@ import itertools
import platform
from time import time, sleep
from typing import Optional
-from urllib.parse import urlparse
-
-keyring = None # keyring will be loaded later
from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
from cli_helpers.utils import strip_ansi
+from .explain_output_formatter import ExplainOutputFormatter
import click
try:
@@ -54,6 +47,7 @@ from pygments.lexers.sql import PostgresLexer
from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT
import pgspecial as special
+from . import auth
from .pgcompleter import PGCompleter
from .pgtoolbar import create_toolbar_tokens_func
from .pgstyle import style_factory, style_factory_output
@@ -68,6 +62,7 @@ from .config import (
get_config_filename,
)
from .key_bindings import pgcli_bindings
+from .packages.formatter.sqlformatter import register_new_formatter
from .packages.prompt_utils import confirm_destructive_query
from .__init__ import __version__
@@ -79,16 +74,12 @@ except ImportError:
from urllib.parse import urlparse, unquote, parse_qs
from getpass import getuser
-from psycopg2 import OperationalError, InterfaceError
-# pg3: https://www.psycopg.org/psycopg3/docs/api/conninfo.html
-from psycopg2.extensions import make_dsn, parse_dsn
-import psycopg2
+from psycopg import OperationalError, InterfaceError
+from psycopg.conninfo import make_conninfo, conninfo_to_dict
from collections import namedtuple
-from textwrap import dedent
-
try:
import sshtunnel
@@ -205,6 +196,7 @@ class PGCli:
self.output_file = None
self.pgspecial = PGSpecial()
+ self.explain_mode = False
self.multi_line = c["main"].as_bool("multi_line")
self.multiline_mode = c["main"].get("multi_line_mode", "psql")
self.vi_mode = c["main"].as_bool("vi")
@@ -248,7 +240,7 @@ class PGCli:
self.on_error = c["main"]["on_error"].upper()
self.decimal_format = c["data_formats"]["decimal"]
self.float_format = c["data_formats"]["float"]
- self.initialize_keyring()
+ auth.keyring_initialize(c["main"].as_bool("keyring"), logger=self.logger)
self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
self.pgspecial.pset_pager(
@@ -292,6 +284,10 @@ class PGCli:
self.ssh_tunnel_url = ssh_tunnel_url
self.ssh_tunnel = None
+ # formatter setup
+ self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
+ register_new_formatter(self.formatter)
+
def quit(self):
raise PgCliQuitError
@@ -436,7 +432,10 @@ class PGCli:
on_error_resume = self.on_error == "RESUME"
return self.pgexecute.run(
- query, self.pgspecial, on_error_resume=on_error_resume
+ query,
+ self.pgspecial,
+ on_error_resume=on_error_resume,
+ explain_mode=self.explain_mode,
)
def write_to_file(self, pattern, **_):
@@ -500,19 +499,6 @@ class PGCli:
pgspecial_logger.addHandler(handler)
pgspecial_logger.setLevel(log_level)
- def initialize_keyring(self):
- global keyring
-
- keyring_enabled = self.config["main"].as_bool("keyring")
- if keyring_enabled:
- # Try best to load keyring (issue #1041).
- import importlib
-
- try:
- keyring = importlib.import_module("keyring")
- except Exception as e: # ImportError for Python 2, ModuleNotFoundError for Python 3
- self.logger.warning("import keyring failed: %r.", e)
-
def connect_dsn(self, dsn, **kwargs):
self.connect(dsn=dsn, **kwargs)
@@ -532,7 +518,7 @@ class PGCli:
)
def connect_uri(self, uri):
- kwargs = psycopg2.extensions.parse_dsn(uri)
+ kwargs = conninfo_to_dict(uri)
remap = {"dbname": "database", "password": "passwd"}
kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
self.connect(**kwargs)
@@ -555,18 +541,6 @@ class PGCli:
if not self.force_passwd_prompt and not passwd:
passwd = os.environ.get("PGPASSWORD", "")
- # Find password from store
- key = f"{user}@{host}"
- keyring_error_message = dedent(
- """\
- {}
- {}
- To remove this message do one of the following:
- - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
- - uninstall keyring: pip uninstall keyring
- - disable keyring in our configuration: add keyring = False to [main]"""
- )
-
# Prompt for a password immediately if requested via the -W flag. This
# avoids wasting time trying to connect to the database and catching a
# no-password exception.
@@ -577,18 +551,10 @@ class PGCli:
"Password for %s" % user, hide_input=True, show_default=False, type=str
)
- if not passwd and keyring:
+ key = f"{user}@{host}"
- try:
- passwd = keyring.get_password("pgcli", key)
- except (RuntimeError, keyring.errors.InitError) as e:
- click.secho(
- keyring_error_message.format(
- "Load your password from keyring returned:", str(e)
- ),
- err=True,
- fg="red",
- )
+ if not passwd and auth.keyring:
+ passwd = auth.keyring_get_password(key)
def should_ask_for_password(exc):
# Prompt for a password after 1st attempt to connect
@@ -603,7 +569,7 @@ class PGCli:
return False
if dsn:
- parsed_dsn = parse_dsn(dsn)
+ parsed_dsn = conninfo_to_dict(dsn)
if "host" in parsed_dsn:
host = parsed_dsn["host"]
if "port" in parsed_dsn:
@@ -650,7 +616,7 @@ class PGCli:
port = self.ssh_tunnel.local_bind_ports[0]
if dsn:
- dsn = make_dsn(dsn, host=host, port=port)
+ dsn = make_conninfo(dsn, host=host, port=port)
# Attempt to connect to the database.
# Note that passwd may be empty on the first attempt. If connection
@@ -672,17 +638,8 @@ class PGCli:
)
else:
raise e
- if passwd and keyring:
- try:
- keyring.set_password("pgcli", key, passwd)
- except (RuntimeError, keyring.errors.KeyringError) as e:
- click.secho(
- keyring_error_message.format(
- "Set password in keyring returned:", str(e)
- ),
- err=True,
- fg="red",
- )
+ if passwd and auth.keyring:
+ auth.keyring_set_password(key, passwd)
except Exception as e: # Connecting to a database could fail.
self.logger.debug("Database connection failed: %r.", e)
@@ -954,6 +911,8 @@ class PGCli:
def _should_limit_output(self, sql, cur):
"""returns True if the output should be truncated, False otherwise."""
+ if self.explain_mode:
+ return False
if not is_select(sql):
return False
@@ -986,6 +945,8 @@ class PGCli:
logger = self.logger
logger.debug("sql: %r", text)
+ # set query to formatter in order to parse table name
+ self.formatter.query = text
all_success = True
meta_changed = False # CREATE, ALTER, DROP, etc
mutated = False # INSERT, DELETE, etc
@@ -999,7 +960,11 @@ class PGCli:
start = time()
on_error_resume = self.on_error == "RESUME"
res = self.pgexecute.run(
- text, self.pgspecial, exception_formatter, on_error_resume
+ text,
+ self.pgspecial,
+ exception_formatter,
+ on_error_resume,
+ explain_mode=self.explain_mode,
)
is_special = None
@@ -1034,7 +999,9 @@ class PGCli:
max_field_width=self.max_field_width,
)
execution = time() - start
- formatted = format_output(title, cur, headers, status, settings)
+ formatted = format_output(
+ title, cur, headers, status, settings, self.explain_mode
+ )
output.extend(formatted)
total = time() - start
@@ -1195,7 +1162,7 @@ class PGCli:
@click.command()
-# Default host is '' so psycopg2 can default to either localhost or unix socket
+# Default host is '' so psycopg can default to either localhost or unix socket
@click.option(
"-h",
"--host",
@@ -1523,13 +1490,16 @@ def exception_formatter(e):
return click.style(str(e), fg="red")
-def format_output(title, cur, headers, status, settings):
+def format_output(title, cur, headers, status, settings, explain_mode=False):
output = []
expanded = settings.expanded or settings.table_format == "vertical"
table_format = "vertical" if settings.expanded else settings.table_format
max_width = settings.max_width
case_function = settings.case_function
- formatter = TabularOutputFormatter(format_name=table_format)
+ if explain_mode:
+ formatter = ExplainOutputFormatter(max_width or 100)
+ else:
+ formatter = TabularOutputFormatter(format_name=table_format)
def format_array(val):
if val is None:
@@ -1590,18 +1560,11 @@ def format_output(title, cur, headers, status, settings):
if hasattr(cur, "description"):
column_types = []
for d in cur.description:
- # pg3: type_name = cur.adapters.types[d.type_code].name
- if (
- # pg3: type_name in ("numeric", "float4", "float8")
- d[1] in psycopg2.extensions.DECIMAL.values
- or d[1] in psycopg2.extensions.FLOAT.values
- ):
+ col_type = cur.adapters.types.get(d.type_code)
+ type_name = col_type.name if col_type else None
+ if type_name in ("numeric", "float4", "float8"):
column_types.append(float)
- if (
- # pg3: type_name in ("int2", "int4", "int8")
- d[1] == psycopg2.extensions.INTEGER.values
- or d[1] in psycopg2.extensions.LONGINTEGER.values
- ):
+ if type_name in ("int2", "int4", "int8"):
column_types.append(int)
else:
column_types.append(str)
@@ -1618,15 +1581,19 @@ def format_output(title, cur, headers, status, settings):
and headers
):
formatted = formatter.format_output(
- cur, headers, format_name="vertical", column_types=None, **output_kwargs
+ cur,
+ headers,
+ format_name="vertical",
+ column_types=column_types,
+ **output_kwargs,
)
if isinstance(formatted, str):
formatted = iter(formatted.splitlines())
output = itertools.chain(output, formatted)
- # Only print the status if it's not None and we are not producing CSV
- if status and table_format != "csv":
+ # Only print the status if it's not None
+ if status:
output = itertools.chain(output, [format_status(cur, status)])
return output