summaryrefslogtreecommitdiffstats
path: root/pgcli/pgexecute.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-09-20 15:46:57 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-09-20 15:46:57 +0000
commit917739023a7acaae3645bbfd27ed454df3c5be33 (patch)
tree4e205849ae64ccd4d1797a1ad7579416f69f52ee /pgcli/pgexecute.py
parentAdding upstream version 3.4.1. (diff)
downloadpgcli-917739023a7acaae3645bbfd27ed454df3c5be33.tar.xz
pgcli-917739023a7acaae3645bbfd27ed454df3c5be33.zip
Adding upstream version 3.5.0.upstream/3.5.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'pgcli/pgexecute.py')
-rw-r--r--pgcli/pgexecute.py351
1 files changed, 132 insertions, 219 deletions
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py
index 4808630..8f2968d 100644
--- a/pgcli/pgexecute.py
+++ b/pgcli/pgexecute.py
@@ -1,155 +1,45 @@
import logging
-import select
import traceback
+from collections import namedtuple
import pgspecial as special
-import psycopg2
-import psycopg2.errorcodes
-import psycopg2.extensions as ext
-import psycopg2.extras
+import psycopg
+import psycopg.sql
+from psycopg.conninfo import make_conninfo
import sqlparse
-from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
-# Cast all database input to unicode automatically.
-# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info.
-# pg3: These should be automatic: unicode is the default
-ext.register_type(ext.UNICODE)
-ext.register_type(ext.UNICODEARRAY)
-ext.register_type(ext.new_type((705,), "UNKNOWN", ext.UNICODE))
-# See https://github.com/dbcli/pgcli/issues/426 for more details.
-# This registers a unicode type caster for datatype 'RECORD'.
-ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE))
-
-# Cast bytea fields to text. By default, this will render as hex strings with
-# Postgres 9+ and as escaped binary in earlier versions.
-ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
-
-# TODO: Get default timeout from pgclirc?
-_WAIT_SELECT_TIMEOUT = 1
-_wait_callback_is_set = False
-
-
-# pg3: it is already "green" but Ctrl-C breaks the query
-# pg3: This should be fixed upstream: https://github.com/psycopg/psycopg/issues/231
-def _wait_select(conn):
- """
- copy-pasted from psycopg2.extras.wait_select
- the default implementation doesn't define a timeout in the select calls
- """
- try:
- while 1:
- try:
- state = conn.poll()
- if state == POLL_OK:
- break
- elif state == POLL_READ:
- select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
- elif state == POLL_WRITE:
- select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
- else:
- raise conn.OperationalError("bad state from poll: %s" % state)
- except KeyboardInterrupt:
- conn.cancel()
- # the loop will be broken by a server error
- continue
- except OSError as e:
- errno = e.args[0]
- if errno != 4:
- raise
- except psycopg2.OperationalError:
- pass
-
-
-def _set_wait_callback(is_virtual_database):
- global _wait_callback_is_set
- if _wait_callback_is_set:
- return
- _wait_callback_is_set = True
- if is_virtual_database:
- return
- # When running a query, make pressing CTRL+C raise a KeyboardInterrupt
- # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
- # See also https://github.com/psycopg/psycopg2/issues/468
- ext.set_wait_callback(_wait_select)
-
-
-# pg3: You can do something like:
-# pg3: cnn.adapters.register_loader("date", psycopg.types.string.TextLoader)
-def register_date_typecasters(connection):
- """
- Casts date and timestamp values to string, resolves issues with out of
- range dates (e.g. BC) which psycopg2 can't handle
- """
-
- def cast_date(value, cursor):
- return value
-
- cursor = connection.cursor()
- cursor.execute("SELECT NULL::date")
- if cursor.description is None:
- return
- date_oid = cursor.description[0][1]
- cursor.execute("SELECT NULL::timestamp")
- timestamp_oid = cursor.description[0][1]
- cursor.execute("SELECT NULL::timestamp with time zone")
- timestamptz_oid = cursor.description[0][1]
- oids = (date_oid, timestamp_oid, timestamptz_oid)
- new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date)
- psycopg2.extensions.register_type(new_type)
-
-
-def register_json_typecasters(conn, loads_fn):
- """Set the function for converting JSON data for a connection.
-
- Use the supplied function to decode JSON data returned from the database
- via the given connection. The function should accept a single argument of
- the data as a string encoded in the database's character encoding.
- psycopg2's default handler for JSON data is json.loads.
- http://initd.org/psycopg/docs/extras.html#json-adaptation
-
- This function attempts to register the typecaster for both JSON and JSONB
- types.
-
- Returns a set that is a subset of {'json', 'jsonb'} indicating which types
- (if any) were successfully registered.
- """
- available = set()
-
- for name in ["json", "jsonb"]:
- try:
- psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
- available.add(name)
- except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
- pass
-
- return available
-
-
-# pg3: Probably you don't need this because by default unknown -> unicode
-def register_hstore_typecaster(conn):
- """
- Instead of using register_hstore() which converts hstore into a python
- dict, we query the 'oid' of hstore which will be different for each
- database and register a type caster that converts it to unicode.
- http://initd.org/psycopg/docs/extras.html#psycopg2.extras.register_hstore
- """
- with conn.cursor() as cur:
- try:
- cur.execute(
- "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined"
- )
- oid = cur.fetchone()[0]
- ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE))
- except Exception:
- pass
+ViewDef = namedtuple(
+ "ViewDef", "nspname relname relkind viewdef reloptions checkoption"
+)
+
+
+def register_typecasters(connection):
+ """Casts date and timestamp values to string, resolves issues with out-of-range
+ dates (e.g. BC) which psycopg can't handle"""
+ for forced_text_type in [
+ "date",
+ "time",
+ "timestamp",
+ "timestamptz",
+ "bytea",
+ "json",
+ "jsonb",
+ ]:
+ connection.adapters.register_loader(
+ forced_text_type, psycopg.types.string.TextLoader
+ )
# pg3: I don't know what is this
-class ProtocolSafeCursor(psycopg2.extensions.cursor):
+class ProtocolSafeCursor(psycopg.Cursor):
+ """This class wraps and suppresses Protocol Errors with pgbouncer database.
+ See https://github.com/dbcli/pgcli/pull/1097.
+ Pgbouncer database is a virtual database with its own set of commands."""
+
def __init__(self, *args, **kwargs):
self.protocol_error = False
self.protocol_message = ""
@@ -170,14 +60,18 @@ class ProtocolSafeCursor(psycopg2.extensions.cursor):
return (self.protocol_message,)
return super().fetchone()
- def execute(self, sql, args=None):
+ # def mogrify(self, query, params):
+ # args = [Literal(v).as_string(self.connection) for v in params]
+ # return query % tuple(args)
+ #
+ def execute(self, *args, **kwargs):
try:
- psycopg2.extensions.cursor.execute(self, sql, args)
+ super().execute(*args, **kwargs)
self.protocol_error = False
self.protocol_message = ""
- except psycopg2.errors.ProtocolViolation as ex:
+ except psycopg.errors.ProtocolViolation as ex:
self.protocol_error = True
- self.protocol_message = ex.pgerror
+ self.protocol_message = str(ex)
_logger.debug("%s: %s" % (ex.__class__.__name__, ex))
@@ -290,7 +184,7 @@ class PGExecute:
conn_params = self._conn_params.copy()
new_params = {
- "database": database,
+ "dbname": database,
"user": user,
"password": password,
"host": host,
@@ -303,15 +197,15 @@ class PGExecute:
new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
if new_params["password"]:
- new_params["dsn"] = make_dsn(
+ new_params["dsn"] = make_conninfo(
new_params["dsn"], password=new_params.pop("password")
)
conn_params.update({k: v for k, v in new_params.items() if v})
- conn_params["cursor_factory"] = ProtocolSafeCursor
- conn = psycopg2.connect(**conn_params)
- conn.set_client_encoding("utf8")
+ conn_info = make_conninfo(**conn_params)
+ conn = psycopg.connect(conn_info)
+ conn.cursor_factory = ProtocolSafeCursor
self._conn_params = conn_params
if self.conn:
@@ -322,19 +216,7 @@ class PGExecute:
# When we connect using a DSN, we don't really know what db,
# user, etc. we connected to. Let's read it.
# Note: moved this after setting autocommit because of #664.
- libpq_version = psycopg2.__libpq_version__
- dsn_parameters = {}
- if libpq_version >= 93000:
- # use actual connection info from psycopg2.extensions.Connection.info
- # as libpq_version > 9.3 is available and required dependency
- dsn_parameters = conn.info.dsn_parameters
- else:
- try:
- dsn_parameters = conn.get_dsn_parameters()
- except Exception as x:
- # https://github.com/dbcli/pgcli/issues/1110
- # PQconninfo not available in libpq < 9.3
- _logger.info("Exception in get_dsn_parameters: %r", x)
+ dsn_parameters = conn.info.get_parameters()
if dsn_parameters:
self.dbname = dsn_parameters.get("dbname")
@@ -357,16 +239,14 @@ class PGExecute:
else self.get_socket_directory()
)
- self.pid = conn.get_backend_pid()
- self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
- self.server_version = conn.get_parameter_status("server_version") or ""
+ self.pid = conn.info.backend_pid
+ self.superuser = conn.info.parameter_status("is_superuser") in ("on", "1")
+ self.server_version = conn.info.parameter_status("server_version") or ""
- _set_wait_callback(self.is_virtual_database())
+ # _set_wait_callback(self.is_virtual_database())
if not self.is_virtual_database():
- register_date_typecasters(conn)
- register_json_typecasters(self.conn, self._json_typecaster)
- register_hstore_typecaster(self.conn)
+ register_typecasters(conn)
@property
def short_host(self):
@@ -387,31 +267,23 @@ class PGExecute:
cur.execute(sql)
return cur.fetchone()
- def _json_typecaster(self, json_data):
- """Interpret incoming JSON data as a string.
-
- The raw data is decoded using the connection's encoding, which defaults
- to the database's encoding.
-
- See http://initd.org/psycopg/docs/connection.html#connection.encoding
- """
-
- return json_data
-
def failed_transaction(self):
- # pg3: self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
- status = self.conn.get_transaction_status()
- return status == ext.TRANSACTION_STATUS_INERROR
+ return self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR
def valid_transaction(self):
- status = self.conn.get_transaction_status()
+ status = self.conn.info.transaction_status
return (
- status == ext.TRANSACTION_STATUS_ACTIVE
- or status == ext.TRANSACTION_STATUS_INTRANS
+ status == psycopg.pq.TransactionStatus.ACTIVE
+ or status == psycopg.pq.TransactionStatus.INTRANS
)
def run(
- self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False
+ self,
+ statement,
+ pgspecial=None,
+ exception_formatter=None,
+ on_error_resume=False,
+ explain_mode=False,
):
"""Execute the sql in the database and return the results.
@@ -432,17 +304,38 @@ class PGExecute:
# Remove spaces and EOL
statement = statement.strip()
if not statement: # Empty string
- yield (None, None, None, None, statement, False, False)
+ yield None, None, None, None, statement, False, False
+
+ # sql parse doesn't split on a comment first + special
+ # so we're going to do it
+
+ sqltemp = []
+ sqlarr = []
+
+ if statement.startswith("--"):
+ sqltemp = statement.split("\n")
+ sqlarr.append(sqltemp[0])
+ for i in sqlparse.split(sqltemp[1]):
+ sqlarr.append(i)
+ elif statement.startswith("/*"):
+ sqltemp = statement.split("*/")
+ sqltemp[0] = sqltemp[0] + "*/"
+ for i in sqlparse.split(sqltemp[1]):
+ sqlarr.append(i)
+ else:
+ sqlarr = sqlparse.split(statement)
- # Split the sql into separate queries and run each one.
- for sql in sqlparse.split(statement):
+ # run each sql query
+ for sql in sqlarr:
# Remove spaces, eol and semi-colons.
sql = sql.rstrip(";")
sql = sqlparse.format(sql, strip_comments=False).strip()
if not sql:
continue
try:
- if pgspecial:
+ if explain_mode:
+ sql = self.explain_prefix() + sql
+ elif pgspecial:
# \G is treated specially since we have to set the expanded output.
if sql.endswith("\\G"):
if not pgspecial.expanded_output:
@@ -454,7 +347,7 @@ class PGExecute:
_logger.debug("Trying a pgspecial command. sql: %r", sql)
try:
cur = self.conn.cursor()
- except psycopg2.InterfaceError:
+ except psycopg.InterfaceError:
# edge case when connection is already closed, but we
# don't need cursor for special_cmd.arg_type == NO_QUERY.
# See https://github.com/dbcli/pgcli/issues/1014.
@@ -478,7 +371,7 @@ class PGExecute:
# Not a special command, so execute as normal sql
yield self.execute_normal_sql(sql) + (sql, True, False)
- except psycopg2.DatabaseError as e:
+ except psycopg.DatabaseError as e:
_logger.error("sql: %r, error: %r", sql, e)
_logger.error("traceback: %r", traceback.format_exc())
@@ -498,7 +391,7 @@ class PGExecute:
"""Return true if e is an error that should not be caught in ``run``.
An uncaught error will prompt the user to reconnect; as long as we
- detect that the connection is stil open, we catch the error, as
+ detect that the connection is still open, we catch the error, as
reconnecting won't solve that problem.
:param e: DatabaseError. An exception raised while executing a query.
@@ -511,13 +404,23 @@ class PGExecute:
def execute_normal_sql(self, split_sql):
"""Returns tuple (title, rows, headers, status)"""
_logger.debug("Regular sql statement. sql: %r", split_sql)
- cur = self.conn.cursor()
- cur.execute(split_sql)
- # conn.notices persist between queies, we use pop to clear out the list
title = ""
- while len(self.conn.notices) > 0:
- title = self.conn.notices.pop() + title
+
+ def handle_notices(n):
+ nonlocal title
+ title = f"{n.message_primary}\n{n.message_detail}\n{title}"
+
+ self.conn.add_notice_handler(handle_notices)
+
+ if self.is_virtual_database() and "show help" in split_sql.lower():
+ # see https://github.com/psycopg/psycopg/issues/303
+ # special case "show help" in pgbouncer
+ res = self.conn.pgconn.exec_(split_sql.encode())
+ return title, None, None, res.command_status.decode()
+
+ cur = self.conn.cursor()
+ cur.execute(split_sql)
# cur.description will be None for operations that do not return
# rows.
@@ -539,7 +442,7 @@ class PGExecute:
_logger.debug("Search path query. sql: %r", self.search_path_query)
cur.execute(self.search_path_query)
return [x[0] for x in cur.fetchall()]
- except psycopg2.ProgrammingError:
+ except psycopg.ProgrammingError:
fallback = "SELECT * FROM current_schemas(true)"
with self.conn.cursor() as cur:
_logger.debug("Search path query. sql: %r", fallback)
@@ -549,9 +452,6 @@ class PGExecute:
def view_definition(self, spec):
"""Returns the SQL defining views described by `spec`"""
- # pg3: you may want to use `psycopg.sql` for client-side composition
- # pg3: (also available in psycopg2 by the way)
- template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}"
# 2: relkind, v or m (materialized)
# 4: reloptions, null
# 5: checkoption: local or cascaded
@@ -560,11 +460,21 @@ class PGExecute:
_logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec)
try:
cur.execute(sql, (spec,))
- except psycopg2.ProgrammingError:
+ except psycopg.ProgrammingError:
raise RuntimeError(f"View {spec} does not exist.")
- result = cur.fetchone()
- view_type = "MATERIALIZED" if result[2] == "m" else ""
- return template.format(*result + (view_type,))
+ result = ViewDef(*cur.fetchone())
+ if result.relkind == "m":
+ template = "CREATE OR REPLACE MATERIALIZED VIEW {name} AS \n{stmt}"
+ else:
+ template = "CREATE OR REPLACE VIEW {name} AS \n{stmt}"
+ return (
+ psycopg.sql.SQL(template)
+ .format(
+ name=psycopg.sql.Identifier(f"{result.nspname}.{result.relname}"),
+ stmt=psycopg.sql.SQL(result.viewdef),
+ )
+ .as_string(self.conn)
+ )
def function_definition(self, spec):
"""Returns the SQL defining functions described by `spec`"""
@@ -576,7 +486,7 @@ class PGExecute:
cur.execute(sql, (spec,))
result = cur.fetchone()
return result[0]
- except psycopg2.ProgrammingError:
+ except psycopg.ProgrammingError:
raise RuntimeError(f"Function {spec} does not exist.")
def schemata(self):
@@ -600,9 +510,9 @@ class PGExecute:
"""
with self.conn.cursor() as cur:
- sql = cur.mogrify(self.tables_query, [kinds])
- _logger.debug("Tables Query. sql: %r", sql)
- cur.execute(sql)
+ # sql = cur.mogrify(self.tables_query, kinds)
+ # _logger.debug("Tables Query. sql: %r", sql)
+ cur.execute(self.tables_query, [kinds])
yield from cur
def tables(self):
@@ -628,7 +538,7 @@ class PGExecute:
:return: list of (schema_name, relation_name, column_name, column_type) tuples
"""
- if self.conn.server_version >= 80400:
+ if self.conn.info.server_version >= 80400:
columns_query = """
SELECT nsp.nspname schema_name,
cls.relname table_name,
@@ -669,9 +579,9 @@ class PGExecute:
ORDER BY 1, 2, att.attnum"""
with self.conn.cursor() as cur:
- sql = cur.mogrify(columns_query, [kinds])
- _logger.debug("Columns Query. sql: %r", sql)
- cur.execute(sql)
+ # sql = cur.mogrify(columns_query, kinds)
+ # _logger.debug("Columns Query. sql: %r", sql)
+ cur.execute(columns_query, [kinds])
yield from cur
def table_columns(self):
@@ -712,7 +622,7 @@ class PGExecute:
def foreignkeys(self):
"""Yields ForeignKey named tuples"""
- if self.conn.server_version < 90000:
+ if self.conn.info.server_version < 90000:
return
with self.conn.cursor() as cur:
@@ -752,7 +662,7 @@ class PGExecute:
def functions(self):
"""Yields FunctionMetadata named tuples"""
- if self.conn.server_version >= 110000:
+ if self.conn.info.server_version >= 110000:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
@@ -772,7 +682,7 @@ class PGExecute:
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
- elif self.conn.server_version > 90000:
+ elif self.conn.info.server_version > 90000:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
@@ -792,7 +702,7 @@ class PGExecute:
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
"""
- elif self.conn.server_version >= 80400:
+ elif self.conn.info.server_version >= 80400:
query = """
SELECT n.nspname schema_name,
p.proname func_name,
@@ -843,7 +753,7 @@ class PGExecute:
"""Yields tuples of (schema_name, type_name)"""
with self.conn.cursor() as cur:
- if self.conn.server_version > 90000:
+ if self.conn.info.server_version > 90000:
query = """
SELECT n.nspname schema_name,
t.typname type_name
@@ -931,3 +841,6 @@ class PGExecute:
cur.execute(query)
for row in cur:
yield row[0]
+
+ def explain_prefix(self):
+ return "EXPLAIN (ANALYZE, COSTS, VERBOSE, BUFFERS, FORMAT JSON) "