From 76d27bc43d56d7ef3ca0090fb199777888adf7c3 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 6 Sep 2021 06:17:09 +0200 Subject: Adding upstream version 3.2.0. Signed-off-by: Daniel Baumann --- pgcli/pgexecute.py | 178 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 122 insertions(+), 56 deletions(-) (limited to 'pgcli/pgexecute.py') diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index d34bf26..a013b55 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -1,13 +1,15 @@ -import traceback import logging +import select +import traceback + +import pgspecial as special import psycopg2 -import psycopg2.extras import psycopg2.errorcodes import psycopg2.extensions as ext +import psycopg2.extras import sqlparse -import pgspecial as special -import select from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn + from .packages.parseutils.meta import FunctionMetadata, ForeignKey _logger = logging.getLogger(__name__) @@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING)) # TODO: Get default timeout from pgclirc? _WAIT_SELECT_TIMEOUT = 1 +_wait_callback_is_set = False def _wait_select(conn): @@ -34,31 +37,41 @@ def _wait_select(conn): copy-pasted from psycopg2.extras.wait_select the default implementation doesn't define a timeout in the select calls """ - while 1: - try: - state = conn.poll() - if state == POLL_OK: - break - elif state == POLL_READ: - select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) - elif state == POLL_WRITE: - select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) - else: - raise conn.OperationalError("bad state from poll: %s" % state) - except KeyboardInterrupt: - conn.cancel() - # the loop will be broken by a server error - continue - except select.error as e: - errno = e.args[0] - if errno != 4: - raise + try: + while 1: + try: + state = conn.poll() + if state == POLL_OK: + break + elif state == POLL_READ: + select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) + elif state == POLL_WRITE: + select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) + else: + raise conn.OperationalError("bad state from poll: %s" % state) + except KeyboardInterrupt: + conn.cancel() + # the loop will be broken by a server error + continue + except OSError as e: + errno = e.args[0] + if errno != 4: + raise + except psycopg2.OperationalError: + pass -# When running a query, make pressing CTRL+C raise a KeyboardInterrupt -# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ -# See also https://github.com/psycopg/psycopg2/issues/468 -ext.set_wait_callback(_wait_select) +def _set_wait_callback(is_virtual_database): + global _wait_callback_is_set + if _wait_callback_is_set: + return + _wait_callback_is_set = True + if is_virtual_database: + return + # When running a query, make pressing CTRL+C raise a KeyboardInterrupt + # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ + # See also https://github.com/psycopg/psycopg2/issues/468 + ext.set_wait_callback(_wait_select) def register_date_typecasters(connection): @@ -72,6 +85,8 @@ def register_date_typecasters(connection): cursor = connection.cursor() cursor.execute("SELECT NULL::date") + if cursor.description is None: + return date_oid = cursor.description[0][1] cursor.execute("SELECT NULL::timestamp") timestamp_oid = cursor.description[0][1] @@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn): try: psycopg2.extras.register_json(conn, loads=loads_fn, name=name) available.add(name) - except psycopg2.ProgrammingError: + except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation): pass return available @@ -127,7 +142,39 @@ def register_hstore_typecaster(conn): pass -class PGExecute(object): +class ProtocolSafeCursor(psycopg2.extensions.cursor): + def __init__(self, *args, **kwargs): + self.protocol_error = False + self.protocol_message = "" + super().__init__(*args, **kwargs) + + def __iter__(self): + if self.protocol_error: + raise StopIteration + return super().__iter__() + + def fetchall(self): + if self.protocol_error: + return [(self.protocol_message,)] + return super().fetchall() + + def fetchone(self): + if self.protocol_error: + return (self.protocol_message,) + return super().fetchone() + + def execute(self, sql, args=None): + try: + psycopg2.extensions.cursor.execute(self, sql, args) + self.protocol_error = False + self.protocol_message = "" + except psycopg2.errors.ProtocolViolation as ex: + self.protocol_error = True + self.protocol_message = ex.pgerror + _logger.debug("%s: %s" % (ex.__class__.__name__, ex)) + + +class PGExecute: # The boolean argument to the current_schemas function indicates whether # implicit schemas, e.g. pg_catalog @@ -190,8 +237,6 @@ class PGExecute(object): SELECT pg_catalog.pg_get_functiondef(f.f_oid) FROM f""" - version_query = "SELECT version();" - def __init__( self, database=None, @@ -203,6 +248,7 @@ class PGExecute(object): **kwargs, ): self._conn_params = {} + self._is_virtual_database = None self.conn = None self.dbname = None self.user = None @@ -214,6 +260,11 @@ class PGExecute(object): self.connect(database, user, password, host, port, dsn, **kwargs) self.reset_expanded = None + def is_virtual_database(self): + if self._is_virtual_database is None: + self._is_virtual_database = self.is_protocol_error() + return self._is_virtual_database + def copy(self): """Returns a clone of the current executor.""" return self.__class__(**self._conn_params) @@ -250,9 +301,9 @@ class PGExecute(object): ) conn_params.update({k: v for k, v in new_params.items() if v}) + conn_params["cursor_factory"] = ProtocolSafeCursor conn = psycopg2.connect(**conn_params) - cursor = conn.cursor() conn.set_client_encoding("utf8") self._conn_params = conn_params @@ -293,16 +344,22 @@ class PGExecute(object): self.extra_args = kwargs if not self.host: - self.host = self.get_socket_directory() + self.host = ( + "pgbouncer" + if self.is_virtual_database() + else self.get_socket_directory() + ) - pid = self._select_one(cursor, "select pg_backend_pid()")[0] - self.pid = pid + self.pid = conn.get_backend_pid() self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1") - self.server_version = conn.get_parameter_status("server_version") + self.server_version = conn.get_parameter_status("server_version") or "" + + _set_wait_callback(self.is_virtual_database()) - register_date_typecasters(conn) - register_json_typecasters(self.conn, self._json_typecaster) - register_hstore_typecaster(self.conn) + if not self.is_virtual_database(): + register_date_typecasters(conn) + register_json_typecasters(self.conn, self._json_typecaster) + register_hstore_typecaster(self.conn) @property def short_host(self): @@ -395,7 +452,13 @@ class PGExecute(object): # See https://github.com/dbcli/pgcli/issues/1014. cur = None try: - for result in pgspecial.execute(cur, sql): + response = pgspecial.execute(cur, sql) + if cur and cur.protocol_error: + yield None, None, None, cur.protocol_message, statement, False, False + # this would close connection. We should reconnect. + self.connect() + continue + for result in response: # e.g. execute_from_file already appends these if len(result) < 7: yield result + (sql, True, True) @@ -453,6 +516,9 @@ class PGExecute(object): if cur.description: headers = [x[0] for x in cur.description] return title, cur, headers, cur.statusmessage + elif cur.protocol_error: + _logger.debug("Protocol error, unsupported command.") + return title, None, None, cur.protocol_message else: _logger.debug("No rows in result.") return title, None, None, cur.statusmessage @@ -485,7 +551,7 @@ class PGExecute(object): try: cur.execute(sql, (spec,)) except psycopg2.ProgrammingError: - raise RuntimeError("View {} does not exist.".format(spec)) + raise RuntimeError(f"View {spec} does not exist.") result = cur.fetchone() view_type = "MATERIALIZED" if result[2] == "m" else "" return template.format(*result + (view_type,)) @@ -501,7 +567,7 @@ class PGExecute(object): result = cur.fetchone() return result[0] except psycopg2.ProgrammingError: - raise RuntimeError("Function {} does not exist.".format(spec)) + raise RuntimeError(f"Function {spec} does not exist.") def schemata(self): """Returns a list of schema names in the database""" @@ -527,21 +593,18 @@ class PGExecute(object): sql = cur.mogrify(self.tables_query, [kinds]) _logger.debug("Tables Query. sql: %r", sql) cur.execute(sql) - for row in cur: - yield row + yield from cur def tables(self): """Yields (schema_name, table_name) tuples""" - for row in self._relations(kinds=["r", "p", "f"]): - yield row + yield from self._relations(kinds=["r", "p", "f"]) def views(self): """Yields (schema_name, view_name) tuples. Includes both views and and materialized views """ - for row in self._relations(kinds=["v", "m"]): - yield row + yield from self._relations(kinds=["v", "m"]) def _columns(self, kinds=("r", "p", "f", "v", "m")): """Get column metadata for tables and views @@ -599,16 +662,13 @@ class PGExecute(object): sql = cur.mogrify(columns_query, [kinds]) _logger.debug("Columns Query. sql: %r", sql) cur.execute(sql) - for row in cur: - yield row + yield from cur def table_columns(self): - for row in self._columns(kinds=["r", "p", "f"]): - yield row + yield from self._columns(kinds=["r", "p", "f"]) def view_columns(self): - for row in self._columns(kinds=["v", "m"]): - yield row + yield from self._columns(kinds=["v", "m"]) def databases(self): with self.conn.cursor() as cur: @@ -623,6 +683,13 @@ class PGExecute(object): headers = [x[0] for x in cur.description] return cur.fetchall(), headers, cur.statusmessage + def is_protocol_error(self): + query = "SELECT 1" + with self.conn.cursor() as cur: + _logger.debug("Simple Query. sql: %r", query) + cur.execute(query) + return bool(cur.protocol_error) + def get_socket_directory(self): with self.conn.cursor() as cur: _logger.debug( @@ -804,8 +871,7 @@ class PGExecute(object): """ _logger.debug("Datatypes Query. sql: %r", query) cur.execute(query) - for row in cur: - yield row + yield from cur def casing(self): """Yields the most common casing for names used in db functions""" -- cgit v1.2.3