author     Daniel Baumann <daniel.baumann@progress-linux.org>  2022-09-21 01:44:43 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2022-09-21 01:44:53 +0000
commit     97864fef063b0960fd3df4529c561296e7805e8c (patch)
tree       e0f84cbf2df15e6c8dafb6161babb551d6b6fda6 /pgcli
parent     Releasing debian version 3.4.1-1. (diff)
download   pgcli-97864fef063b0960fd3df4529c561296e7805e8c.tar.xz
           pgcli-97864fef063b0960fd3df4529c561296e7805e8c.zip
Merging upstream version 3.5.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'pgcli')
-rw-r--r--  pgcli/__init__.py                           2
-rw-r--r--  pgcli/auth.py                              58
-rw-r--r--  pgcli/explain_output_formatter.py          18
-rw-r--r--  pgcli/key_bindings.py                      10
-rw-r--r--  pgcli/main.py                             137
-rw-r--r--  pgcli/packages/formatter/__init__.py        1
-rw-r--r--  pgcli/packages/formatter/sqlformatter.py   71
-rw-r--r--  pgcli/packages/parseutils/tables.py         2
-rw-r--r--  pgcli/packages/sqlcompletion.py             2
-rw-r--r--  pgcli/pgbuffer.py                          11
-rw-r--r--  pgcli/pgclirc                               4
-rw-r--r--  pgcli/pgexecute.py                        351
-rw-r--r--  pgcli/pgstyle.py                            2
-rw-r--r--  pgcli/pgtoolbar.py                          9
-rw-r--r--  pgcli/pyev.py                             439
15 files changed, 804 insertions, 313 deletions
diff --git a/pgcli/__init__.py b/pgcli/__init__.py
index a5cfdf5..dcbfb52 100644
--- a/pgcli/__init__.py
+++ b/pgcli/__init__.py
@@ -1 +1 @@
-__version__ = "3.4.1"
+__version__ = "3.5.0"
diff --git a/pgcli/auth.py b/pgcli/auth.py
new file mode 100644
index 0000000..342c412
--- /dev/null
+++ b/pgcli/auth.py
@@ -0,0 +1,58 @@
+import click
+from textwrap import dedent
+
+
+keyring = None # keyring will be loaded later
+
+
+keyring_error_message = dedent(
+ """\
+ {}
+ {}
+ To remove this message do one of the following:
+    - set up keyring as described at: https://keyring.readthedocs.io/en/stable/
+    - uninstall keyring: pip uninstall keyring
+    - disable keyring in your configuration: add keyring = False to [main]"""
+)
+
+
+def keyring_initialize(keyring_enabled, *, logger):
+ """Initialize keyring only if explicitly enabled"""
+ global keyring
+
+ if keyring_enabled:
+ # Try best to load keyring (issue #1041).
+ import importlib
+
+ try:
+ keyring = importlib.import_module("keyring")
+ except Exception as e: # ImportError for Python 2, ModuleNotFoundError for Python 3
+ logger.warning("import keyring failed: %r.", e)
+
+
+def keyring_get_password(key):
+ """Attempt to get password from keyring"""
+ # Find password from store
+ passwd = ""
+ try:
+ passwd = keyring.get_password("pgcli", key) or ""
+ except Exception as e:
+ click.secho(
+ keyring_error_message.format(
+                "Loading your password from keyring returned:", str(e)
+ ),
+ err=True,
+ fg="red",
+ )
+ return passwd
+
+
+def keyring_set_password(key, passwd):
+ try:
+ keyring.set_password("pgcli", key, passwd)
+ except Exception as e:
+ click.secho(
+            keyring_error_message.format("Setting the password in keyring returned:", str(e)),
+ err=True,
+ fg="red",
+ )
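How the new module is meant to be driven (the real call sites live in main.py below) — a minimal sketch, assuming pgcli 3.5.0 is importable; the logger name, key and password are illustrative:

    import logging
    from pgcli import auth

    logger = logging.getLogger("pgcli.example")
    # Imports keyring lazily; on failure it only logs a warning.
    auth.keyring_initialize(True, logger=logger)

    key = "postgres@localhost"  # built as f"{user}@{host}" in main.py
    if auth.keyring:  # only touch the store if the import succeeded
        passwd = auth.keyring_get_password(key)
        if not passwd:
            auth.keyring_set_password(key, "s3cret")  # illustrative secret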
diff --git a/pgcli/explain_output_formatter.py b/pgcli/explain_output_formatter.py
new file mode 100644
index 0000000..b14cf44
--- /dev/null
+++ b/pgcli/explain_output_formatter.py
@@ -0,0 +1,18 @@
+from pgcli.pyev import Visualizer
+import json
+
+
+"""Explain response output adapter"""
+
+
+class ExplainOutputFormatter:
+ def __init__(self, max_width):
+ self.max_width = max_width
+
+ def format_output(self, cur, headers, **output_kwargs):
+ (data,) = cur.fetchone()
+ explain_list = json.loads(data)
+ visualizer = Visualizer(self.max_width)
+ for explain in explain_list:
+ visualizer.load(explain)
+ yield visualizer.get_list()
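The formatter assumes the cursor holds exactly one row with one column containing the document produced by EXPLAIN (FORMAT JSON): a list of explain objects, each carrying a "Plan" tree. A stub-driven sketch (FakeCursor and the plan numbers are made up):

    import json
    from pgcli.explain_output_formatter import ExplainOutputFormatter

    plan = {
        "Planning Time": 0.05,
        "Execution Time": 0.10,
        "Plan": {
            "Node Type": "Result",
            "Plan Rows": 1, "Actual Rows": 1, "Actual Loops": 1,
            "Total Cost": 0.01, "Actual Total Time": 0.01,
        },
    }

    class FakeCursor:  # hypothetical stand-in for a psycopg cursor
        def fetchone(self):
            return (json.dumps([plan]),)

    formatter = ExplainOutputFormatter(max_width=100)
    for block in formatter.format_output(FakeCursor(), headers=None):
        print(block)  # one rendered plan tree per explain object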
diff --git a/pgcli/key_bindings.py b/pgcli/key_bindings.py
index 23174b6..9c016f7 100644
--- a/pgcli/key_bindings.py
+++ b/pgcli/key_bindings.py
@@ -9,7 +9,7 @@ from prompt_toolkit.filters import (
vi_mode,
)
-from .pgbuffer import buffer_should_be_handled
+from .pgbuffer import buffer_should_be_handled, safe_multi_line_mode
_logger = logging.getLogger(__name__)
@@ -39,6 +39,12 @@ def pgcli_bindings(pgcli):
pgcli.vi_mode = not pgcli.vi_mode
event.app.editing_mode = EditingMode.VI if pgcli.vi_mode else EditingMode.EMACS
+ @kb.add("f5")
+ def _(event):
+        """Toggle explain mode."""
+ _logger.debug("Detected F5 key.")
+ pgcli.explain_mode = not pgcli.explain_mode
+
@kb.add("tab")
def _(event):
"""Force autocompletion at cursor on non-empty lines."""
@@ -108,7 +114,7 @@ def pgcli_bindings(pgcli):
_logger.debug("Detected enter key.")
event.current_buffer.validate_and_handle()
- @kb.add("escape", "enter", filter=~vi_mode)
+ @kb.add("escape", "enter", filter=~vi_mode & ~safe_multi_line_mode(pgcli))
def _(event):
"""Introduces a line break regardless of multi-line mode or not."""
_logger.debug("Detected alt-enter key.")
diff --git a/pgcli/main.py b/pgcli/main.py
index 2d7edfa..0fa264f 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -1,12 +1,7 @@
-import platform
-import warnings
-
from configobj import ConfigObj, ParseError
from pgspecial.namedqueries import NamedQueries
from .config import skip_initial_comment
-warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
-
import atexit
import os
import re
@@ -22,13 +17,11 @@ import itertools
import platform
from time import time, sleep
from typing import Optional
-from urllib.parse import urlparse
-
-keyring = None # keyring will be loaded later
from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
from cli_helpers.utils import strip_ansi
+from .explain_output_formatter import ExplainOutputFormatter
import click
try:
@@ -54,6 +47,7 @@ from pygments.lexers.sql import PostgresLexer
from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT
import pgspecial as special
+from . import auth
from .pgcompleter import PGCompleter
from .pgtoolbar import create_toolbar_tokens_func
from .pgstyle import style_factory, style_factory_output
@@ -68,6 +62,7 @@ from .config import (
get_config_filename,
)
from .key_bindings import pgcli_bindings
+from .packages.formatter.sqlformatter import register_new_formatter
from .packages.prompt_utils import confirm_destructive_query
from .__init__ import __version__
@@ -79,16 +74,12 @@ except ImportError:
from urllib.parse import urlparse, unquote, parse_qs
from getpass import getuser
-from psycopg2 import OperationalError, InterfaceError
-# pg3: https://www.psycopg.org/psycopg3/docs/api/conninfo.html
-from psycopg2.extensions import make_dsn, parse_dsn
-import psycopg2
+from psycopg import OperationalError, InterfaceError
+from psycopg.conninfo import make_conninfo, conninfo_to_dict
from collections import namedtuple
-from textwrap import dedent
-
try:
import sshtunnel
@@ -205,6 +196,7 @@ class PGCli:
self.output_file = None
self.pgspecial = PGSpecial()
+ self.explain_mode = False
self.multi_line = c["main"].as_bool("multi_line")
self.multiline_mode = c["main"].get("multi_line_mode", "psql")
self.vi_mode = c["main"].as_bool("vi")
@@ -248,7 +240,7 @@ class PGCli:
self.on_error = c["main"]["on_error"].upper()
self.decimal_format = c["data_formats"]["decimal"]
self.float_format = c["data_formats"]["float"]
- self.initialize_keyring()
+ auth.keyring_initialize(c["main"].as_bool("keyring"), logger=self.logger)
self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
self.pgspecial.pset_pager(
@@ -292,6 +284,10 @@ class PGCli:
self.ssh_tunnel_url = ssh_tunnel_url
self.ssh_tunnel = None
+ # formatter setup
+ self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"])
+ register_new_formatter(self.formatter)
+
def quit(self):
raise PgCliQuitError
@@ -436,7 +432,10 @@ class PGCli:
on_error_resume = self.on_error == "RESUME"
return self.pgexecute.run(
- query, self.pgspecial, on_error_resume=on_error_resume
+ query,
+ self.pgspecial,
+ on_error_resume=on_error_resume,
+ explain_mode=self.explain_mode,
)
def write_to_file(self, pattern, **_):
@@ -500,19 +499,6 @@ class PGCli:
pgspecial_logger.addHandler(handler)
pgspecial_logger.setLevel(log_level)
- def initialize_keyring(self):
- global keyring
-
- keyring_enabled = self.config["main"].as_bool("keyring")
- if keyring_enabled:
- # Try best to load keyring (issue #1041).
- import importlib
-
- try:
- keyring = importlib.import_module("keyring")
- except Exception as e: # ImportError for Python 2, ModuleNotFoundError for Python 3
- self.logger.warning("import keyring failed: %r.", e)
-
def connect_dsn(self, dsn, **kwargs):
self.connect(dsn=dsn, **kwargs)
@@ -532,7 +518,7 @@ class PGCli:
)
def connect_uri(self, uri):
- kwargs = psycopg2.extensions.parse_dsn(uri)
+ kwargs = conninfo_to_dict(uri)
remap = {"dbname": "database", "password": "passwd"}
kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
self.connect(**kwargs)
@@ -555,18 +541,6 @@ class PGCli:
if not self.force_passwd_prompt and not passwd:
passwd = os.environ.get("PGPASSWORD", "")
- # Find password from store
- key = f"{user}@{host}"
- keyring_error_message = dedent(
- """\
- {}
- {}
- To remove this message do one of the following:
- - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
- - uninstall keyring: pip uninstall keyring
- - disable keyring in our configuration: add keyring = False to [main]"""
- )
-
# Prompt for a password immediately if requested via the -W flag. This
# avoids wasting time trying to connect to the database and catching a
# no-password exception.
@@ -577,18 +551,10 @@ class PGCli:
"Password for %s" % user, hide_input=True, show_default=False, type=str
)
- if not passwd and keyring:
+ key = f"{user}@{host}"
- try:
- passwd = keyring.get_password("pgcli", key)
- except (RuntimeError, keyring.errors.InitError) as e:
- click.secho(
- keyring_error_message.format(
- "Load your password from keyring returned:", str(e)
- ),
- err=True,
- fg="red",
- )
+ if not passwd and auth.keyring:
+ passwd = auth.keyring_get_password(key)
def should_ask_for_password(exc):
# Prompt for a password after 1st attempt to connect
@@ -603,7 +569,7 @@ class PGCli:
return False
if dsn:
- parsed_dsn = parse_dsn(dsn)
+ parsed_dsn = conninfo_to_dict(dsn)
if "host" in parsed_dsn:
host = parsed_dsn["host"]
if "port" in parsed_dsn:
@@ -650,7 +616,7 @@ class PGCli:
port = self.ssh_tunnel.local_bind_ports[0]
if dsn:
- dsn = make_dsn(dsn, host=host, port=port)
+ dsn = make_conninfo(dsn, host=host, port=port)
# Attempt to connect to the database.
# Note that passwd may be empty on the first attempt. If connection
@@ -672,17 +638,8 @@ class PGCli:
)
else:
raise e
- if passwd and keyring:
- try:
- keyring.set_password("pgcli", key, passwd)
- except (RuntimeError, keyring.errors.KeyringError) as e:
- click.secho(
- keyring_error_message.format(
- "Set password in keyring returned:", str(e)
- ),
- err=True,
- fg="red",
- )
+ if passwd and auth.keyring:
+ auth.keyring_set_password(key, passwd)
except Exception as e: # Connecting to a database could fail.
self.logger.debug("Database connection failed: %r.", e)
@@ -954,6 +911,8 @@ class PGCli:
def _should_limit_output(self, sql, cur):
"""returns True if the output should be truncated, False otherwise."""
+ if self.explain_mode:
+ return False
if not is_select(sql):
return False
@@ -986,6 +945,8 @@ class PGCli:
logger = self.logger
logger.debug("sql: %r", text)
+        # hand the query to the formatter so the sql-* formats can extract the table name
+ self.formatter.query = text
all_success = True
meta_changed = False # CREATE, ALTER, DROP, etc
mutated = False # INSERT, DELETE, etc
@@ -999,7 +960,11 @@ class PGCli:
start = time()
on_error_resume = self.on_error == "RESUME"
res = self.pgexecute.run(
- text, self.pgspecial, exception_formatter, on_error_resume
+ text,
+ self.pgspecial,
+ exception_formatter,
+ on_error_resume,
+ explain_mode=self.explain_mode,
)
is_special = None
@@ -1034,7 +999,9 @@ class PGCli:
max_field_width=self.max_field_width,
)
execution = time() - start
- formatted = format_output(title, cur, headers, status, settings)
+ formatted = format_output(
+ title, cur, headers, status, settings, self.explain_mode
+ )
output.extend(formatted)
total = time() - start
@@ -1195,7 +1162,7 @@ class PGCli:
@click.command()
-# Default host is '' so psycopg2 can default to either localhost or unix socket
+# Default host is '' so psycopg can default to either localhost or unix socket
@click.option(
"-h",
"--host",
@@ -1523,13 +1490,16 @@ def exception_formatter(e):
return click.style(str(e), fg="red")
-def format_output(title, cur, headers, status, settings):
+def format_output(title, cur, headers, status, settings, explain_mode=False):
output = []
expanded = settings.expanded or settings.table_format == "vertical"
table_format = "vertical" if settings.expanded else settings.table_format
max_width = settings.max_width
case_function = settings.case_function
- formatter = TabularOutputFormatter(format_name=table_format)
+ if explain_mode:
+ formatter = ExplainOutputFormatter(max_width or 100)
+ else:
+ formatter = TabularOutputFormatter(format_name=table_format)
def format_array(val):
if val is None:
@@ -1590,18 +1560,11 @@ def format_output(title, cur, headers, status, settings):
if hasattr(cur, "description"):
column_types = []
for d in cur.description:
- # pg3: type_name = cur.adapters.types[d.type_code].name
- if (
- # pg3: type_name in ("numeric", "float4", "float8")
- d[1] in psycopg2.extensions.DECIMAL.values
- or d[1] in psycopg2.extensions.FLOAT.values
- ):
+ col_type = cur.adapters.types.get(d.type_code)
+ type_name = col_type.name if col_type else None
+ if type_name in ("numeric", "float4", "float8"):
column_types.append(float)
- if (
- # pg3: type_name in ("int2", "int4", "int8")
- d[1] == psycopg2.extensions.INTEGER.values
- or d[1] in psycopg2.extensions.LONGINTEGER.values
- ):
+            elif type_name in ("int2", "int4", "int8"):
column_types.append(int)
else:
column_types.append(str)
@@ -1618,15 +1581,19 @@ def format_output(title, cur, headers, status, settings):
and headers
):
formatted = formatter.format_output(
- cur, headers, format_name="vertical", column_types=None, **output_kwargs
+ cur,
+ headers,
+ format_name="vertical",
+ column_types=column_types,
+ **output_kwargs,
)
if isinstance(formatted, str):
formatted = iter(formatted.splitlines())
output = itertools.chain(output, formatted)
- # Only print the status if it's not None and we are not producing CSV
- if status and table_format != "csv":
+ # Only print the status if it's not None
+ if status:
output = itertools.chain(output, [format_status(cur, status)])
return output
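For reference, the two psycopg 3 conninfo helpers this file switches to behave like their psycopg2 counterparts (parse_dsn/make_dsn); a short sketch with illustrative connection values:

    from psycopg.conninfo import conninfo_to_dict, make_conninfo

    conninfo_to_dict("postgresql://app@db.example.com:5433/sales")
    # {'user': 'app', 'host': 'db.example.com', 'port': '5433', 'dbname': 'sales'}

    # make_conninfo merges keyword overrides into an existing DSN, as in the
    # SSH-tunnel branch above that rewrites host and port.
    make_conninfo("dbname=sales user=app", host="127.0.0.1", port="16033")
    # 'dbname=sales user=app host=127.0.0.1 port=16033'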
diff --git a/pgcli/packages/formatter/__init__.py b/pgcli/packages/formatter/__init__.py
new file mode 100644
index 0000000..9bad579
--- /dev/null
+++ b/pgcli/packages/formatter/__init__.py
@@ -0,0 +1 @@
+# coding=utf-8
diff --git a/pgcli/packages/formatter/sqlformatter.py b/pgcli/packages/formatter/sqlformatter.py
new file mode 100644
index 0000000..5bf25fe
--- /dev/null
+++ b/pgcli/packages/formatter/sqlformatter.py
@@ -0,0 +1,71 @@
+# coding=utf-8
+
+from pgcli.packages.parseutils.tables import extract_tables
+
+
+supported_formats = (
+ "sql-insert",
+ "sql-update",
+ "sql-update-1",
+ "sql-update-2",
+)
+
+preprocessors = ()
+
+
+def escape_for_sql_statement(value):
+ if isinstance(value, bytes):
+ return f"X'{value.hex()}'"
+ else:
+ return "'{}'".format(value)
+
+
+def adapter(data, headers, table_format=None, **kwargs):
+ tables = extract_tables(formatter.query)
+ if len(tables) > 0:
+ table = tables[0]
+ if table[0]:
+ table_name = "{}.{}".format(*table[:2])
+ else:
+ table_name = table[1]
+ else:
+ table_name = '"DUAL"'
+ if table_format == "sql-insert":
+ h = '", "'.join(headers)
+ yield 'INSERT INTO "{}" ("{}") VALUES'.format(table_name, h)
+ prefix = " "
+ for d in data:
+            values = ", ".join(escape_for_sql_statement(v) for v in d)
+ yield "{}({})".format(prefix, values)
+ if prefix == " ":
+ prefix = ", "
+ yield ";"
+ if table_format.startswith("sql-update"):
+ s = table_format.split("-")
+ keys = 1
+ if len(s) > 2:
+ keys = int(s[-1])
+ for d in data:
+ yield 'UPDATE "{}" SET'.format(table_name)
+ prefix = " "
+ for i, v in enumerate(d[keys:], keys):
+ yield '{}"{}" = {}'.format(
+ prefix, headers[i], escape_for_sql_statement(v)
+ )
+ if prefix == " ":
+ prefix = ", "
+ f = '"{}" = {}'
+ where = (
+ f.format(headers[i], escape_for_sql_statement(d[i]))
+ for i in range(keys)
+ )
+ yield "WHERE {};".format(" AND ".join(where))
+
+
+def register_new_formatter(TabularOutputFormatter):
+ global formatter
+ formatter = TabularOutputFormatter
+ for sql_format in supported_formats:
+ TabularOutputFormatter.register_new_formatter(
+ sql_format, adapter, preprocessors, {"table_format": sql_format}
+ )
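What the new sql-insert format emits, as a minimal sketch; StubFormatter is a hypothetical stand-in for the cli_helpers TabularOutputFormatter instance that register_new_formatter normally receives:

    from pgcli.packages.formatter import sqlformatter

    class StubFormatter:
        query = "SELECT * FROM users"  # normally set by PGCli before formatting

        @staticmethod
        def register_new_formatter(name, adapter, preprocessors, kwargs):
            pass  # cli_helpers does the real registration bookkeeping

    sqlformatter.register_new_formatter(StubFormatter)
    for line in sqlformatter.adapter(
        [(1, "alice"), (2, "bob")], ["id", "name"], table_format="sql-insert"
    ):
        print(line)
    # INSERT INTO "users" ("id", "name") VALUES
    #   ('1', 'alice')
    # , ('2', 'bob')
    # ;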
diff --git a/pgcli/packages/parseutils/tables.py b/pgcli/packages/parseutils/tables.py
index f2e1e42..9098115 100644
--- a/pgcli/packages/parseutils/tables.py
+++ b/pgcli/packages/parseutils/tables.py
@@ -139,7 +139,7 @@ def extract_table_identifiers(token_stream, allow_functions=True):
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
- """Extract the table names from an SQL statment.
+ """Extract the table names from an SQL statement.
Returns a list of TableReference namedtuples
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index 6305301..be4933a 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -380,7 +380,7 @@ def suggest_based_on_last_token(token, stmt):
)
elif p.token_first().value.lower() == "select":
- # If the lparen is preceeded by a space chances are we're about to
+ # If the lparen is preceded by a space chances are we're about to
# do a sub-select.
if last_word(stmt.text_before_cursor, "all_punctuations").startswith("("):
return (Keyword(),)
diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py
index 706ed25..c236c13 100644
--- a/pgcli/pgbuffer.py
+++ b/pgcli/pgbuffer.py
@@ -22,6 +22,17 @@ mode, which by default will insert new lines on Enter.
"""
+def safe_multi_line_mode(pgcli):
+ @Condition
+ def cond():
+ _logger.debug(
+ 'Multi-line mode state: "%s" / "%s"', pgcli.multi_line, pgcli.multiline_mode
+ )
+ return pgcli.multi_line and (pgcli.multiline_mode == "safe")
+
+ return cond
+
+
def buffer_should_be_handled(pgcli):
@Condition
def cond():
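The new safe_multi_line_mode filter composes via prompt_toolkit's filter algebra — the ~ and & operators used in key_bindings.py above. A self-contained sketch of that mechanism (the state dict and names are illustrative):

    from prompt_toolkit.filters import Condition

    state = {"safe": True, "vi": False}

    @Condition
    def safe_mode():
        return state["safe"]

    @Condition
    def vi_mode_on():
        return state["vi"]

    # Mirrors @kb.add("escape", "enter", filter=~vi_mode & ~safe_multi_line_mode(pgcli))
    binding_active = ~vi_mode_on & ~safe_mode
    print(binding_active())  # False while safe mode is on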
diff --git a/pgcli/pgclirc b/pgcli/pgclirc
index 6654ce9..dcff63d 100644
--- a/pgcli/pgclirc
+++ b/pgcli/pgclirc
@@ -95,7 +95,9 @@ show_bottom_toolbar = True
# Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe,
# ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs,
-# textile, moinmoin, jira, vertical, tsv, csv.
+# textile, moinmoin, jira, vertical, tsv, csv, sql-insert, sql-update,
+# sql-update-1, sql-update-2 (formats with the sql-* prefix render query
+# output as executable INSERT or UPDATE statements).
# Recommended: psql, fancy_grid and grid.
table_format = psql
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 4808630..8f2968d 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -1,155 +1,45 @@
import logging
-import select
import traceback
+from collections import namedtuple
import pgspecial as special
-import psycopg2
-import psycopg2.errorcodes
-import psycopg2.extensions as ext
-import psycopg2.extras
+import psycopg
+import psycopg.sql
+from psycopg.conninfo import make_conninfo
import sqlparse
-from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
-# Cast all database input to unicode automatically.
-# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info.
-# pg3: These should be automatic: unicode is the default
-ext.register_type(ext.UNICODE)
-ext.register_type(ext.UNICODEARRAY)
-ext.register_type(ext.new_type((705,), "UNKNOWN", ext.UNICODE))
-# See https://github.com/dbcli/pgcli/issues/426 for more details.
-# This registers a unicode type caster for datatype 'RECORD'.
-ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE))
-
-# Cast bytea fields to text. By default, this will render as hex strings with
-# Postgres 9+ and as escaped binary in earlier versions.
-ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
-
-# TODO: Get default timeout from pgclirc?
-_WAIT_SELECT_TIMEOUT = 1
-_wait_callback_is_set = False
-
-
-# pg3: it is already "green" but Ctrl-C breaks the query
-# pg3: This should be fixed upstream: https://github.com/psycopg/psycopg/issues/231
-def _wait_select(conn):
- """
- copy-pasted from psycopg2.extras.wait_select
- the default implementation doesn't define a timeout in the select calls
- """
- try:
- while 1:
- try:
- state = conn.poll()
- if state == POLL_OK:
- break
- elif state == POLL_READ:
- select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
- elif state == POLL_WRITE:
- select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
- else:
- raise conn.OperationalError("bad state from poll: %s" % state)
- except KeyboardInterrupt:
- conn.cancel()
- # the loop will be broken by a server error
- continue
- except OSError as e:
- errno = e.args[0]
- if errno != 4:
- raise
- except psycopg2.OperationalError:
- pass
-
-
-def _set_wait_callback(is_virtual_database):
- global _wait_callback_is_set
- if _wait_callback_is_set:
- return
- _wait_callback_is_set = True
- if is_virtual_database:
- return
- # When running a query, make pressing CTRL+C raise a KeyboardInterrupt
- # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
- # See also https://github.com/psycopg/psycopg2/issues/468
- ext.set_wait_callback(_wait_select)
-
-
-# pg3: You can do something like:
-# pg3: cnn.adapters.register_loader("date", psycopg.types.string.TextLoader)
-def register_date_typecasters(connection):
- """
- Casts date and timestamp values to string, resolves issues with out of
- range dates (e.g. BC) which psycopg2 can't handle
- """
-
- def cast_date(value, cursor):
- return value
-
- cursor = connection.cursor()
- cursor.execute("SELECT NULL::date")
- if cursor.description is None:
- return
- date_oid = cursor.description[0][1]
- cursor.execute("SELECT NULL::timestamp")
- timestamp_oid = cursor.description[0][1]
- cursor.execute("SELECT NULL::timestamp with time zone")
- timestamptz_oid = cursor.description[0][1]
- oids = (date_oid, timestamp_oid, timestamptz_oid)
- new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date)
- psycopg2.extensions.register_type(new_type)
-
-
-def register_json_typecasters(conn, loads_fn):
- """Set the function for converting JSON data for a connection.
-
- Use the supplied function to decode JSON data returned from the database
- via the given connection. The function should accept a single argument of
- the data as a string encoded in the database's character encoding.
- psycopg2's default handler for JSON data is json.loads.
- http://initd.org/psycopg/docs/extras.html#json-adaptation
-
- This function attempts to register the typecaster for both JSON and JSONB
- types.
-
- Returns a set that is a subset of {'json', 'jsonb'} indicating which types
- (if any) were successfully registered.
- """
- available = set()
-
- for name in ["json", "jsonb"]:
- try:
- psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
- available.add(name)
- except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
- pass
-
- return available
-
-
-# pg3: Probably you don't need this because by default unknown -> unicode
-def register_hstore_typecaster(conn):
- """
- Instead of using register_hstore() which converts hstore into a python
- dict, we query the 'oid' of hstore which will be different for each
- database and register a type caster that converts it to unicode.
- http://initd.org/psycopg/docs/extras.html#psycopg2.extras.register_hstore
- """
- with conn.cursor() as cur:
- try:
- cur.execute(
- "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined"
- )
- oid = cur.fetchone()[0]
- ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE))
- except Exception:
- pass
+ViewDef = namedtuple(
+ "ViewDef", "nspname relname relkind viewdef reloptions checkoption"
+)
+
+
+def register_typecasters(connection):
+    """Cast date, time, bytea and json values to strings; this resolves issues
+    with out-of-range dates (e.g. BC) which psycopg can't handle"""
+ for forced_text_type in [
+ "date",
+ "time",
+ "timestamp",
+ "timestamptz",
+ "bytea",
+ "json",
+ "jsonb",
+ ]:
+ connection.adapters.register_loader(
+ forced_text_type, psycopg.types.string.TextLoader
+ )
# pg3: I don't know what this is
-class ProtocolSafeCursor(psycopg2.extensions.cursor):
+class ProtocolSafeCursor(psycopg.Cursor):
+ """This class wraps and suppresses Protocol Errors with pgbouncer database.
+ See https://github.com/dbcli/pgcli/pull/1097.
+ Pgbouncer database is a virtual database with its own set of commands."""
+
def __init__(self, *args, **kwargs):
self.protocol_error = False
self.protocol_message = ""
@@ -170,14 +60,18 @@ class ProtocolSafeCursor(psycopg2.extensions.cursor):
return (self.protocol_message,)
return super().fetchone()
- def execute(self, sql, args=None):
+ # def mogrify(self, query, params):
+ # args = [Literal(v).as_string(self.connection) for v in params]
+ # return query % tuple(args)
+ #
+ def execute(self, *args, **kwargs):
try:
- psycopg2.extensions.cursor.execute(self, sql, args)
+ super().execute(*args, **kwargs)
self.protocol_error = False
self.protocol_message = ""
- except psycopg2.errors.ProtocolViolation as ex:
+ except psycopg.errors.ProtocolViolation as ex:
self.protocol_error = True
- self.protocol_message = ex.pgerror
+ self.protocol_message = str(ex)
_logger.debug("%s: %s" % (ex.__class__.__name__, ex))
@@ -290,7 +184,7 @@ class PGExecute:
conn_params = self._conn_params.copy()
new_params = {
- "database": database,
+ "dbname": database,
"user": user,
"password": password,
"host": host,
@@ -303,15 +197,15 @@ class PGExecute:
new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
if new_params["password"]:
- new_params["dsn"] = make_dsn(
+ new_params["dsn"] = make_conninfo(
new_params["dsn"], password=new_params.pop("password")
)
conn_params.update({k: v for k, v in new_params.items() if v})
- conn_params["cursor_factory"] = ProtocolSafeCursor
- conn = psycopg2.connect(**conn_params)
- conn.set_client_encoding("utf8")
+ conn_info = make_conninfo(**conn_params)
+ conn = psycopg.connect(conn_info)
+ conn.cursor_factory = ProtocolSafeCursor
self._conn_params = conn_params
if self.conn:
@@ -322,19 +216,7 @@ class PGExecute:
# When we connect using a DSN, we don't really know what db,
# user, etc. we connected to. Let's read it.
# Note: moved this after setting autocommit because of #664.
- libpq_version = psycopg2.__libpq_version__
- dsn_parameters = {}
- if libpq_version >= 93000:
- # use actual connection info from psycopg2.extensions.Connection.info
- # as libpq_version > 9.3 is available and required dependency
- dsn_parameters = conn.info.dsn_parameters
- else:
- try:
- dsn_parameters = conn.get_dsn_parameters()
- except Exception as x:
- # https://github.com/dbcli/pgcli/issues/1110
- # PQconninfo not available in libpq < 9.3
- _logger.info("Exception in get_dsn_parameters: %r", x)
+ dsn_parameters = conn.info.get_parameters()
if dsn_parameters:
self.dbname = dsn_parameters.get("dbname")
@@ -357,16 +239,14 @@ class PGExecute:
else self.get_socket_directory()
)
- self.pid = conn.get_backend_pid()
- self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
- self.server_version = conn.get_parameter_status("server_version") or ""
+ self.pid = conn.info.backend_pid
+ self.superuser = conn.info.parameter_status("is_superuser") in ("on", "1")
+ self.server_version = conn.info.parameter_status("server_version") or ""
- _set_wait_callback(self.is_virtual_database())
+ # _set_wait_callback(self.is_virtual_database())
if not self.is_virtual_database():
- register_date_typecasters(conn)
- register_json_typecasters(self.conn, self._json_typecaster)
- register_hstore_typecaster(self.conn)
+ register_typecasters(conn)
@property
def short_host(self):
@@ -387,31 +267,23 @@ class PGExecute:
cur.execute(sql)
return cur.fetchone()
- def _json_typecaster(self, json_data):
- """Interpret incoming JSON data as a string.
-
- The raw data is decoded using the connection's encoding, which defaults
- to the database's encoding.
-
- See http://initd.org/psycopg/docs/connection.html#connection.encoding
- """
-
- return json_data
-
def failed_transaction(self):
- # pg3: self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
- status = self.conn.get_transaction_status()
- return status == ext.TRANSACTION_STATUS_INERROR
+ return self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
def valid_transaction(self):
- status = self.conn.get_transaction_status()
+ status = self.conn.info.transaction_status
return (
- status == ext.TRANSACTION_STATUS_ACTIVE
- or status == ext.TRANSACTION_STATUS_INTRANS
+ status == psycopg.pq.TransactionStatus.ACTIVE
+ or status == psycopg.pq.TransactionStatus.INTRANS
)
def run(
- self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False
+ self,
+ statement,
+ pgspecial=None,
+ exception_formatter=None,
+ on_error_resume=False,
+ explain_mode=False,
):
"""Execute the sql in the database and return the results.
@@ -432,17 +304,38 @@ class PGExecute:
# Remove spaces and EOL
statement = statement.strip()
if not statement: # Empty string
- yield (None, None, None, None, statement, False, False)
+ yield None, None, None, None, statement, False, False
+
+        # sqlparse doesn't split a leading comment from a following special
+        # command, so do that split ourselves
+
+ sqltemp = []
+ sqlarr = []
+
+        if statement.startswith("--"):
+            # keep the leading line comment, then split the remainder
+            sqltemp = statement.split("\n", 1)
+            sqlarr.append(sqltemp[0])
+            if len(sqltemp) > 1:
+                sqlarr.extend(sqlparse.split(sqltemp[1]))
+        elif statement.startswith("/*"):
+            # keep the leading block comment, then split the remainder
+            sqltemp = statement.split("*/", 1)
+            sqlarr.append(sqltemp[0] + "*/")
+            if len(sqltemp) > 1:
+                sqlarr.extend(sqlparse.split(sqltemp[1]))
+        else:
+            sqlarr = sqlparse.split(statement)
- # Split the sql into separate queries and run each one.
- for sql in sqlparse.split(statement):
+ # run each sql query
+ for sql in sqlarr:
# Remove spaces, eol and semi-colons.
sql = sql.rstrip(";")
sql = sqlparse.format(sql, strip_comments=False).strip()
if not sql:
continue
try:
- if pgspecial:
+ if explain_mode:
+ sql = self.explain_prefix() + sql
+ elif pgspecial:
# \G is treated specially since we have to set the expanded output.
if sql.endswith("\\G"):
if not pgspecial.expanded_output:
@@ -454,7 +347,7 @@ class PGExecute:
_logger.debug("Trying a pgspecial command. sql: %r", sql)
try:
cur = self.conn.cursor()
- except psycopg2.InterfaceError:
+ except psycopg.InterfaceError:
# edge case when connection is already closed, but we
# don't need cursor for special_cmd.arg_type == NO_QUERY.
# See https://github.com/dbcli/pgcli/issues/1014.
@@ -478,7 +371,7 @@ class PGExecute:
# Not a special command, so execute as normal sql
yield self.execute_normal_sql(sql) + (sql, True, False)
- except psycopg2.DatabaseError as e:
+ except psycopg.DatabaseError as e:
_logger.error("sql: %r, error: %r", sql, e)
_logger.error("traceback: %r", traceback.format_exc())
@@ -498,7 +391,7 @@ class PGExecute:
"""Return true if e is an error that should not be caught in ``run``.
An uncaught error will prompt the user to reconnect; as long as we
- detect that the connection is stil open, we catch the error, as
+ detect that the connection is still open, we catch the error, as
reconnecting won't solve that problem.
:param e: DatabaseError. An exception raised while executing a query.
@@ -511,13 +404,23 @@ class PGExecute:
def execute_normal_sql(self, split_sql):
"""Returns tuple (title, rows, headers, status)"""
_logger.debug("Regular sql statement. sql: %r", split_sql)
- cur = self.conn.cursor()
- cur.execute(split_sql)
- # conn.notices persist between queies, we use pop to clear out the list
title = ""
- while len(self.conn.notices) > 0:
- title = self.conn.notices.pop() + title
+
+ def handle_notices(n):
+ nonlocal title
+ title = f"{n.message_primary}\n{n.message_detail}\n{title}"
+
+ self.conn.add_notice_handler(handle_notices)
+
+ if self.is_virtual_database() and "show help" in split_sql.lower():
+ # see https://github.com/psycopg/psycopg/issues/303
+ # special case "show help" in pgbouncer
+ res = self.conn.pgconn.exec_(split_sql.encode())
+ return title, None, None, res.command_status.decode()
+
+ cur = self.conn.cursor()
+ cur.execute(split_sql)
# cur.description will be None for operations that do not return
# rows.
@@ -539,7 +442,7 @@ class PGExecute:
_logger.debug("Search path query. sql: %r", self.search_path_query)
cur.execute(self.search_path_query)
return [x[0] for x in cur.fetchall()]
- except psycopg2.ProgrammingError:
+ except psycopg.ProgrammingError:
fallback = "SELECT * FROM current_schemas(true)"
with self.conn.cursor() as cur:
_logger.debug("Search path query. sql: %r", fallback)
@@ -549,9 +452,6 @@ class PGExecute:
def view_definition(self, spec):
"""Returns the SQL defining views described by `spec`"""
- # pg3: you may want to use `psycopg.sql` for client-side composition
- # pg3: (also available in psycopg2 by the way)
- template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}"
# 2: relkind, v or m (materialized)
# 4: reloptions, null
# 5: checkoption: local or cascaded
@@ -560,11 +460,21 @@ class PGExecute:
_logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec)
try:
cur.execute(sql, (spec,))
- except psycopg2.ProgrammingError:
+ except psycopg.ProgrammingError:
raise RuntimeError(f"View {spec} does not exist.")
- result = cur.fetchone()
- view_type = "MATERIALIZED" if result[2] == "m" else ""
- return template.format(*result + (view_type,))
+ result = ViewDef(*cur.fetchone())
+ if result.relkind == "m":
+ template = "CREATE OR REPLACE MATERIALIZED VIEW {name} AS \n{stmt}"
+ else:
+ template = "CREATE OR REPLACE VIEW {name} AS \n{stmt}"
+ return (
+ psycopg.sql.SQL(template)
+ .format(
+                name=psycopg.sql.Identifier(result.nspname, result.relname),
+ stmt=psycopg.sql.SQL(result.viewdef),
+ )
+ .as_string(self.conn)
+ )
def function_definition(self, spec):
"""Returns the SQL defining functions described by `spec`"""
@@ -576,7 +486,7 @@ class PGExecute:
cur.execute(sql, (spec,))
result = cur.fetchone()
return result[0]
- except psycopg2.ProgrammingError:
+ except psycopg.ProgrammingError:
raise RuntimeError(f"Function {spec} does not exist.")
def schemata(self):
@@ -600,9 +510,9 @@ class PGExecute:
"""
with self.conn.cursor() as cur:
- sql = cur.mogrify(self.tables_query, [kinds])
- _logger.debug("Tables Query. sql: %r", sql)
- cur.execute(sql)
+ # sql = cur.mogrify(self.tables_query, kinds)
+ # _logger.debug("Tables Query. sql: %r", sql)
+ cur.execute(self.tables_query, [kinds])
yield from cur
def tables(self):
@@ -628,7 +538,7 @@ class PGExecute:
:return: list of (schema_name, relation_name, column_name, column_type) tuples
"""
- if self.conn.server_version >= 80400:
+ if self.conn.info.server_version >= 80400:
columns_query = """
SELECT nsp.nspname schema_name,
cls.relname table_name,
@@ -669,9 +579,9 @@ class PGExecute:
ORDER BY 1, 2, att.attnum"""
with self.conn.cursor() as cur:
- sql = cur.mogrify(columns_query, [kinds])
- _logger.debug("Columns Query. sql: %r", sql)
- cur.execute(sql)
+ # sql = cur.mogrify(columns_query, kinds)
+ # _logger.debug("Columns Query. sql: %r", sql)
+ cur.execute(columns_query, [kinds])
yield from cur
def table_columns(self):
@@ -712,7 +622,7 @@ class PGExecute:
def foreignkeys(self):
"""Yields ForeignKey named tuples"""
- if self.conn.server_version < 90000:
+ if self.conn.info.server_version < 90000:
return
with self.conn.cursor() as cur:
@@ -752,7 +662,7 @@ class PGExecute:
def functions(self):
"""Yields FunctionMetadata named tuples"""
- if self.conn.server_version >= 110000:
+ if self.conn.info.server_version >= 110000:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
@@ -772,7 +682,7 @@ class PGExecute:
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
- elif self.conn.server_version > 90000:
+ elif self.conn.info.server_version > 90000:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
@@ -792,7 +702,7 @@ class PGExecute:
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
- elif self.conn.server_version >= 80400:
+ elif self.conn.info.server_version >= 80400:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
@@ -843,7 +753,7 @@ class PGExecute:
"""Yields tuples of (schema_name, type_name)"""
with self.conn.cursor() as cur:
- if self.conn.server_version > 90000:
+ if self.conn.info.server_version > 90000:
query = """
SELECT n.nspname schema_name,
t.typname type_name
@@ -931,3 +841,6 @@ class PGExecute:
cur.execute(query)
for row in cur:
yield row[0]
+
+ def explain_prefix(self):
+ return "EXPLAIN (ANALYZE, COSTS, VERBOSE, BUFFERS, FORMAT JSON) "
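The register_typecasters rewrite leans on psycopg 3's loader registry; a standalone sketch of the same call, assuming a reachable server (the DSN is illustrative):

    import psycopg
    from psycopg.types.string import TextLoader

    with psycopg.connect("dbname=postgres") as conn:
        # Force dates back as text so out-of-range values (e.g. BC) survive.
        conn.adapters.register_loader("date", TextLoader)
        with conn.cursor() as cur:
            cur.execute("SELECT DATE '0044-03-15 BC'")
            print(cur.fetchone()[0])  # '0044-03-15 BC' as str, not datetime.date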
diff --git a/pgcli/pgstyle.py b/pgcli/pgstyle.py
index 8229037..77874f4 100644
--- a/pgcli/pgstyle.py
+++ b/pgcli/pgstyle.py
@@ -83,7 +83,7 @@ def style_factory(name, cli_style):
logger.error("Unhandled style / class name: %s", token)
else:
# treat as prompt style name (2.0). See default style names here:
- # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py
+ # https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/src/prompt_toolkit/styles/defaults.py
prompt_styles.append((token, cli_style[token]))
override_style = Style([("bottom-toolbar", "noreverse")])
diff --git a/pgcli/pgtoolbar.py b/pgcli/pgtoolbar.py
index 41f903d..7b5883e 100644
--- a/pgcli/pgtoolbar.py
+++ b/pgcli/pgtoolbar.py
@@ -47,10 +47,15 @@ def create_toolbar_tokens_func(pgcli):
if pgcli.vi_mode:
result.append(
- ("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ")")
+ ("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ") ")
)
else:
- result.append(("class:bottom-toolbar", "[F4] Emacs-mode"))
+ result.append(("class:bottom-toolbar", "[F4] Emacs-mode "))
+
+ if pgcli.explain_mode:
+ result.append(("class:bottom-toolbar", "[F5] Explain: ON "))
+ else:
+ result.append(("class:bottom-toolbar", "[F5] Explain: OFF "))
if pgcli.pgexecute.failed_transaction():
result.append(
diff --git a/pgcli/pyev.py b/pgcli/pyev.py
new file mode 100644
index 0000000..202947f
--- /dev/null
+++ b/pgcli/pyev.py
@@ -0,0 +1,439 @@
+import textwrap
+import re
+from click import style as color
+
+DESCRIPTIONS = {
+ "Append": "Used in a UNION to merge multiple record sets by appending them together.",
+ "Limit": "Returns a specified number of rows from a record set.",
+ "Sort": "Sorts a record set based on the specified sort key.",
+ "Nested Loop": "Merges two record sets by looping through every record in the first set and trying to find a match in the second set. All matching records are returned.",
+ "Merge Join": "Merges two record sets by first sorting them on a join key.",
+ "Hash": "Generates a hash table from the records in the input recordset. Hash is used by Hash Join.",
+    "Hash Join": "Joins two record sets by hashing one of them (using a Hash Scan).",
+ "Aggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()).",
+ "Hashaggregate": "Groups records together based on a GROUP BY or aggregate function (e.g. sum()). Hash Aggregate uses a hash to first organize the records by a key.",
+ "Sequence Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).",
+ "Seq Scan": "Finds relevant records by sequentially scanning the input record set. When reading from a table, Seq Scans (unlike Index Scans) perform a single read operation (only the table is read).",
+ "Index Scan": "Finds relevant records based on an Index. Index Scans perform 2 read operations: one to read the index and another to read the actual value from the table.",
+ "Index Only Scan": "Finds relevant records based on an Index. Index Only Scans perform a single read operation from the index and do not read from the corresponding table.",
+ "Bitmap Heap Scan": "Searches through the pages returned by the Bitmap Index Scan for relevant rows.",
+ "Bitmap Index Scan": "Uses a Bitmap Index (index which uses 1 bit per page) to find all relevant pages. Results of this node are fed to the Bitmap Heap Scan.",
+    "CTE Scan": "Performs a sequential scan of Common Table Expression (CTE) query results. Note that results of a CTE are materialized (calculated and temporarily stored).",
+    "ProjectSet": "Appears when the SELECT or ORDER BY clause contains set-returning functions; executes the set-returning function(s) for each input tuple until none of them returns any more records.",
+    "Result": "Returns a computed result (e.g. a constant expression) without scanning a table.",
+}
+
+
+class Visualizer:
+ def __init__(self, terminal_width=100, color=True):
+ self.color = color
+ self.terminal_width = terminal_width
+ self.string_lines = []
+
+ def load(self, explain_dict):
+ self.plan = explain_dict.pop("Plan")
+ self.explain = explain_dict
+ self.process_all()
+ self.generate_lines()
+
+ def process_all(self):
+ self.plan = self.process_plan(self.plan)
+ self.plan = self.calculate_outlier_nodes(self.plan)
+
+ #
+ def process_plan(self, plan):
+ plan = self.calculate_planner_estimate(plan)
+ plan = self.calculate_actuals(plan)
+ self.calculate_maximums(plan)
+ #
+ for index in range(len(plan.get("Plans", []))):
+ _plan = plan["Plans"][index]
+ plan["Plans"][index] = self.process_plan(_plan)
+ return plan
+
+ def prefix_format(self, v):
+ if self.color:
+ return color(v, fg="bright_black")
+ return v
+
+ def tag_format(self, v):
+ if self.color:
+ return color(v, fg="white", bg="red")
+ return v
+
+ def muted_format(self, v):
+ if self.color:
+ return color(v, fg="bright_black")
+ return v
+
+ def bold_format(self, v):
+ if self.color:
+ return color(v, fg="white")
+ return v
+
+ def good_format(self, v):
+ if self.color:
+ return color(v, fg="green")
+ return v
+
+ def warning_format(self, v):
+ if self.color:
+ return color(v, fg="yellow")
+ return v
+
+ def critical_format(self, v):
+ if self.color:
+ return color(v, fg="red")
+ return v
+
+ def output_format(self, v):
+ if self.color:
+ return color(v, fg="cyan")
+ return v
+
+ def calculate_planner_estimate(self, plan):
+ plan["Planner Row Estimate Factor"] = 0
+ plan["Planner Row Estimate Direction"] = "Under"
+
+ if plan["Plan Rows"] == plan["Actual Rows"]:
+ return plan
+
+ if plan["Plan Rows"] != 0:
+ plan["Planner Row Estimate Factor"] = (
+ plan["Actual Rows"] / plan["Plan Rows"]
+ )
+
+ if plan["Planner Row Estimate Factor"] < 10:
+ plan["Planner Row Estimate Factor"] = 0
+ plan["Planner Row Estimate Direction"] = "Over"
+ if plan["Actual Rows"] != 0:
+ plan["Planner Row Estimate Factor"] = (
+ plan["Plan Rows"] / plan["Actual Rows"]
+ )
+ return plan
+
+ #
+ def calculate_actuals(self, plan):
+ plan["Actual Duration"] = plan["Actual Total Time"]
+ plan["Actual Cost"] = plan["Total Cost"]
+
+ for child in plan.get("Plans", []):
+            if child["Node Type"] != "CTE Scan":
+ plan["Actual Duration"] = (
+ plan["Actual Duration"] - child["Actual Total Time"]
+ )
+ plan["Actual Cost"] = plan["Actual Cost"] - child["Total Cost"]
+
+ if plan["Actual Cost"] < 0:
+ plan["Actual Cost"] = 0
+
+ plan["Actual Duration"] = plan["Actual Duration"] * plan["Actual Loops"]
+ return plan
+
+ def calculate_outlier_nodes(self, plan):
+ plan["Costliest"] = plan["Actual Cost"] == self.explain["Max Cost"]
+ plan["Largest"] = plan["Actual Rows"] == self.explain["Max Rows"]
+ plan["Slowest"] = plan["Actual Duration"] == self.explain["Max Duration"]
+
+ for index in range(len(plan.get("Plans", []))):
+ _plan = plan["Plans"][index]
+ plan["Plans"][index] = self.calculate_outlier_nodes(_plan)
+ return plan
+
+ def calculate_maximums(self, plan):
+ if not self.explain.get("Max Rows"):
+ self.explain["Max Rows"] = plan["Actual Rows"]
+ elif self.explain.get("Max Rows") < plan["Actual Rows"]:
+ self.explain["Max Rows"] = plan["Actual Rows"]
+
+        if not self.explain.get("Max Cost"):
+ self.explain["Max Cost"] = plan["Actual Cost"]
+ elif self.explain.get("Max Cost") < plan["Actual Cost"]:
+ self.explain["Max Cost"] = plan["Actual Cost"]
+
+ if not self.explain.get("Max Duration"):
+ self.explain["Max Duration"] = plan["Actual Duration"]
+ elif self.explain.get("Max Duration") < plan["Actual Duration"]:
+ self.explain["Max Duration"] = plan["Actual Duration"]
+
+ if not self.explain.get("Total Cost"):
+ self.explain["Total Cost"] = plan["Actual Cost"]
+ elif self.explain.get("Total Cost") < plan["Actual Cost"]:
+ self.explain["Total Cost"] = plan["Actual Cost"]
+
+ #
+ def duration_to_string(self, value):
+ if value < 1:
+ return self.good_format("<1 ms")
+ elif value < 100:
+ return self.good_format("%.2f ms" % value)
+ elif value < 1000:
+ return self.warning_format("%.2f ms" % value)
+ elif value < 60000:
+ return self.critical_format(
+                "%.2f s" % (value / 1000.0),
+ )
+ else:
+ return self.critical_format(
+ "%.2f m" % (value / 60000.0),
+ )
+
+ #
+ def format_details(self, plan):
+ details = []
+
+ if plan.get("Scan Direction"):
+ details.append(plan["Scan Direction"])
+
+ if plan.get("Strategy"):
+ details.append(plan["Strategy"])
+
+ if len(details) > 0:
+ return self.muted_format(" [%s]" % ", ".join(details))
+
+ return ""
+
+ def format_tags(self, plan):
+ tags = []
+
+ if plan["Slowest"]:
+ tags.append(self.tag_format("slowest"))
+ if plan["Costliest"]:
+ tags.append(self.tag_format("costliest"))
+ if plan["Largest"]:
+ tags.append(self.tag_format("largest"))
+ if plan.get("Planner Row Estimate Factor", 0) >= 100:
+ tags.append(self.tag_format("bad estimate"))
+
+ return " ".join(tags)
+
+ def get_terminator(self, index, plan):
+ if index == 0:
+ if len(plan.get("Plans", [])) == 0:
+ return "⌡► "
+ else:
+ return "├► "
+ else:
+ if len(plan.get("Plans", [])) == 0:
+ return " "
+ else:
+ return "│ "
+
+ def wrap_string(self, line, width):
+ if width == 0:
+ return [line]
+ return textwrap.wrap(line, width)
+
+ def intcomma(self, value):
+ sep = ","
+ if not isinstance(value, str):
+ value = int(value)
+
+ orig = str(value)
+
+ new = re.sub(r"^(-?\d+)(\d{3})", rf"\g<1>{sep}\g<2>", orig)
+ if orig == new:
+ return new
+ else:
+ return self.intcomma(new)
+
+ def output_fn(self, current_prefix, string):
+ return "%s%s" % (self.prefix_format(current_prefix), string)
+
+ def create_lines(self, plan, prefix, depth, width, last_child):
+ current_prefix = prefix
+ self.string_lines.append(
+ self.output_fn(current_prefix, self.prefix_format("│"))
+ )
+
+ joint = "├"
+ if last_child:
+ joint = "└"
+ #
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s%s %s"
+ % (
+ self.prefix_format(joint + "─⌠"),
+ self.bold_format(plan["Node Type"]),
+ self.format_details(plan),
+ self.format_tags(plan),
+ ),
+ )
+ )
+ #
+ if last_child:
+ prefix += " "
+ else:
+ prefix += "│ "
+
+ current_prefix = prefix + "│ "
+
+ cols = width - len(current_prefix)
+
+ for line in self.wrap_string(
+            DESCRIPTIONS.get(plan["Node Type"], "Not found: %s" % plan["Node Type"]),
+ cols,
+ ):
+ self.string_lines.append(
+ self.output_fn(current_prefix, "%s" % self.muted_format(line))
+ )
+ #
+ if plan.get("Actual Duration"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "○ %s %s (%.0f%%)"
+ % (
+ "Duration:",
+ self.duration_to_string(plan["Actual Duration"]),
+ (plan["Actual Duration"] / self.explain["Execution Time"])
+ * 100,
+ ),
+ )
+ )
+
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "○ %s %s (%.0f%%)"
+ % (
+ "Cost:",
+ self.intcomma(plan["Actual Cost"]),
+ (plan["Actual Cost"] / self.explain["Total Cost"]) * 100,
+ ),
+ )
+ )
+
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "○ %s %s" % ("Rows:", self.intcomma(plan["Actual Rows"])),
+ )
+ )
+
+ current_prefix = current_prefix + " "
+
+ if plan.get("Join Type"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s" % (plan["Join Type"], self.muted_format("join")),
+ )
+ )
+
+ if plan.get("Relation Name"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s.%s"
+ % (
+ self.muted_format("on"),
+ plan.get("Schema", "unknown"),
+ plan["Relation Name"],
+ ),
+ )
+ )
+
+ if plan.get("Index Name"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s" % (self.muted_format("using"), plan["Index Name"]),
+ )
+ )
+
+ if plan.get("Index Condition"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s" % (self.muted_format("condition"), plan["Index Condition"]),
+ )
+ )
+
+ if plan.get("Filter"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s %s"
+ % (
+ self.muted_format("filter"),
+ plan["Filter"],
+ self.muted_format(
+ "[-%s rows]" % self.intcomma(plan["Rows Removed by Filter"])
+ ),
+ ),
+ )
+ )
+
+ if plan.get("Hash Condition"):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %s" % (self.muted_format("on"), plan["Hash Condition"]),
+ )
+ )
+
+ if plan.get("CTE Name"):
+ self.string_lines.append(
+ self.output_fn(current_prefix, "CTE %s" % plan["CTE Name"])
+ )
+
+ if plan.get("Planner Row Estimate Factor") != 0:
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ "%s %sestimated %s %.2fx"
+ % (
+ self.muted_format("rows"),
+ plan["Planner Row Estimate Direction"],
+ self.muted_format("by"),
+ plan["Planner Row Estimate Factor"],
+ ),
+ )
+ )
+
+ current_prefix = prefix
+
+ if len(plan.get("Output", [])) > 0:
+ for index, line in enumerate(
+ self.wrap_string(" + ".join(plan["Output"]), cols)
+ ):
+ self.string_lines.append(
+ self.output_fn(
+ current_prefix,
+ self.prefix_format(self.get_terminator(index, plan))
+ + self.output_format(line),
+ )
+ )
+
+ for index, nested_plan in enumerate(plan.get("Plans", [])):
+ self.create_lines(
+ nested_plan, prefix, depth + 1, width, index == len(plan["Plans"]) - 1
+ )
+
+ def generate_lines(self):
+ self.string_lines = [
+ "○ Total Cost: %s" % self.intcomma(self.explain["Total Cost"]),
+ "○ Planning Time: %s"
+ % self.duration_to_string(self.explain["Planning Time"]),
+ "○ Execution Time: %s"
+ % self.duration_to_string(self.explain["Execution Time"]),
+ self.prefix_format("┬"),
+ ]
+ self.create_lines(
+ self.plan,
+ "",
+ 0,
+ self.terminal_width,
+ len(self.plan.get("Plans", [])) == 1,
+ )
+
+ def get_list(self):
+ return "\n".join(self.string_lines)
+
+ def print(self):
+        for line in self.string_lines:
+            print(line)
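End to end, the pieces above combine as in this minimal sketch (the DSN and query are illustrative; with the json TextLoader registered in pgexecute.py the payload arrives as a string, while plain psycopg hands back the already-parsed list, so both cases are handled):

    import json
    import psycopg
    from pgcli.pyev import Visualizer

    with psycopg.connect("dbname=postgres") as conn, conn.cursor() as cur:
        cur.execute("EXPLAIN (ANALYZE, COSTS, VERBOSE, BUFFERS, FORMAT JSON) SELECT 1")
        (payload,) = cur.fetchone()  # one row, one json column, a list of plans
        plans = json.loads(payload) if isinstance(payload, str) else payload
        for explain in plans:
            viz = Visualizer(terminal_width=100)
            viz.load(explain)
            print(viz.get_list())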