diff options
Diffstat (limited to 'pgcli/pgexecute.py')
-rw-r--r-- | pgcli/pgexecute.py | 351 |
1 files changed, 132 insertions, 219 deletions
diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index 4808630..8f2968d 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -1,155 +1,45 @@ import logging -import select import traceback +from collections import namedtuple import pgspecial as special -import psycopg2 -import psycopg2.errorcodes -import psycopg2.extensions as ext -import psycopg2.extras +import psycopg +import psycopg.sql +from psycopg.conninfo import make_conninfo import sqlparse -from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn from .packages.parseutils.meta import FunctionMetadata, ForeignKey _logger = logging.getLogger(__name__) -# Cast all database input to unicode automatically. -# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info. -# pg3: These should be automatic: unicode is the default -ext.register_type(ext.UNICODE) -ext.register_type(ext.UNICODEARRAY) -ext.register_type(ext.new_type((705,), "UNKNOWN", ext.UNICODE)) -# See https://github.com/dbcli/pgcli/issues/426 for more details. -# This registers a unicode type caster for datatype 'RECORD'. -ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE)) - -# Cast bytea fields to text. By default, this will render as hex strings with -# Postgres 9+ and as escaped binary in earlier versions. -ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING)) - -# TODO: Get default timeout from pgclirc? -_WAIT_SELECT_TIMEOUT = 1 -_wait_callback_is_set = False - - -# pg3: it is already "green" but Ctrl-C breaks the query -# pg3: This should be fixed upstream: https://github.com/psycopg/psycopg/issues/231 -def _wait_select(conn): - """ - copy-pasted from psycopg2.extras.wait_select - the default implementation doesn't define a timeout in the select calls - """ - try: - while 1: - try: - state = conn.poll() - if state == POLL_OK: - break - elif state == POLL_READ: - select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT) - elif state == POLL_WRITE: - select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT) - else: - raise conn.OperationalError("bad state from poll: %s" % state) - except KeyboardInterrupt: - conn.cancel() - # the loop will be broken by a server error - continue - except OSError as e: - errno = e.args[0] - if errno != 4: - raise - except psycopg2.OperationalError: - pass - - -def _set_wait_callback(is_virtual_database): - global _wait_callback_is_set - if _wait_callback_is_set: - return - _wait_callback_is_set = True - if is_virtual_database: - return - # When running a query, make pressing CTRL+C raise a KeyboardInterrupt - # See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/ - # See also https://github.com/psycopg/psycopg2/issues/468 - ext.set_wait_callback(_wait_select) - - -# pg3: You can do something like: -# pg3: cnn.adapters.register_loader("date", psycopg.types.string.TextLoader) -def register_date_typecasters(connection): - """ - Casts date and timestamp values to string, resolves issues with out of - range dates (e.g. BC) which psycopg2 can't handle - """ - - def cast_date(value, cursor): - return value - - cursor = connection.cursor() - cursor.execute("SELECT NULL::date") - if cursor.description is None: - return - date_oid = cursor.description[0][1] - cursor.execute("SELECT NULL::timestamp") - timestamp_oid = cursor.description[0][1] - cursor.execute("SELECT NULL::timestamp with time zone") - timestamptz_oid = cursor.description[0][1] - oids = (date_oid, timestamp_oid, timestamptz_oid) - new_type = psycopg2.extensions.new_type(oids, "DATE", cast_date) - psycopg2.extensions.register_type(new_type) - - -def register_json_typecasters(conn, loads_fn): - """Set the function for converting JSON data for a connection. - - Use the supplied function to decode JSON data returned from the database - via the given connection. The function should accept a single argument of - the data as a string encoded in the database's character encoding. - psycopg2's default handler for JSON data is json.loads. - http://initd.org/psycopg/docs/extras.html#json-adaptation - - This function attempts to register the typecaster for both JSON and JSONB - types. - - Returns a set that is a subset of {'json', 'jsonb'} indicating which types - (if any) were successfully registered. - """ - available = set() - - for name in ["json", "jsonb"]: - try: - psycopg2.extras.register_json(conn, loads=loads_fn, name=name) - available.add(name) - except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation): - pass - - return available - - -# pg3: Probably you don't need this because by default unknown -> unicode -def register_hstore_typecaster(conn): - """ - Instead of using register_hstore() which converts hstore into a python - dict, we query the 'oid' of hstore which will be different for each - database and register a type caster that converts it to unicode. - http://initd.org/psycopg/docs/extras.html#psycopg2.extras.register_hstore - """ - with conn.cursor() as cur: - try: - cur.execute( - "select t.oid FROM pg_type t WHERE t.typname = 'hstore' and t.typisdefined" - ) - oid = cur.fetchone()[0] - ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE)) - except Exception: - pass +ViewDef = namedtuple( + "ViewDef", "nspname relname relkind viewdef reloptions checkoption" +) + + +def register_typecasters(connection): + """Casts date and timestamp values to string, resolves issues with out-of-range + dates (e.g. BC) which psycopg can't handle""" + for forced_text_type in [ + "date", + "time", + "timestamp", + "timestamptz", + "bytea", + "json", + "jsonb", + ]: + connection.adapters.register_loader( + forced_text_type, psycopg.types.string.TextLoader + ) # pg3: I don't know what is this -class ProtocolSafeCursor(psycopg2.extensions.cursor): +class ProtocolSafeCursor(psycopg.Cursor): + """This class wraps and suppresses Protocol Errors with pgbouncer database. + See https://github.com/dbcli/pgcli/pull/1097. + Pgbouncer database is a virtual database with its own set of commands.""" + def __init__(self, *args, **kwargs): self.protocol_error = False self.protocol_message = "" @@ -170,14 +60,18 @@ class ProtocolSafeCursor(psycopg2.extensions.cursor): return (self.protocol_message,) return super().fetchone() - def execute(self, sql, args=None): + # def mogrify(self, query, params): + # args = [Literal(v).as_string(self.connection) for v in params] + # return query % tuple(args) + # + def execute(self, *args, **kwargs): try: - psycopg2.extensions.cursor.execute(self, sql, args) + super().execute(*args, **kwargs) self.protocol_error = False self.protocol_message = "" - except psycopg2.errors.ProtocolViolation as ex: + except psycopg.errors.ProtocolViolation as ex: self.protocol_error = True - self.protocol_message = ex.pgerror + self.protocol_message = str(ex) _logger.debug("%s: %s" % (ex.__class__.__name__, ex)) @@ -290,7 +184,7 @@ class PGExecute: conn_params = self._conn_params.copy() new_params = { - "database": database, + "dbname": database, "user": user, "password": password, "host": host, @@ -303,15 +197,15 @@ class PGExecute: new_params = {"dsn": new_params["dsn"], "password": new_params["password"]} if new_params["password"]: - new_params["dsn"] = make_dsn( + new_params["dsn"] = make_conninfo( new_params["dsn"], password=new_params.pop("password") ) conn_params.update({k: v for k, v in new_params.items() if v}) - conn_params["cursor_factory"] = ProtocolSafeCursor - conn = psycopg2.connect(**conn_params) - conn.set_client_encoding("utf8") + conn_info = make_conninfo(**conn_params) + conn = psycopg.connect(conn_info) + conn.cursor_factory = ProtocolSafeCursor self._conn_params = conn_params if self.conn: @@ -322,19 +216,7 @@ class PGExecute: # When we connect using a DSN, we don't really know what db, # user, etc. we connected to. Let's read it. # Note: moved this after setting autocommit because of #664. - libpq_version = psycopg2.__libpq_version__ - dsn_parameters = {} - if libpq_version >= 93000: - # use actual connection info from psycopg2.extensions.Connection.info - # as libpq_version > 9.3 is available and required dependency - dsn_parameters = conn.info.dsn_parameters - else: - try: - dsn_parameters = conn.get_dsn_parameters() - except Exception as x: - # https://github.com/dbcli/pgcli/issues/1110 - # PQconninfo not available in libpq < 9.3 - _logger.info("Exception in get_dsn_parameters: %r", x) + dsn_parameters = conn.info.get_parameters() if dsn_parameters: self.dbname = dsn_parameters.get("dbname") @@ -357,16 +239,14 @@ class PGExecute: else self.get_socket_directory() ) - self.pid = conn.get_backend_pid() - self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1") - self.server_version = conn.get_parameter_status("server_version") or "" + self.pid = conn.info.backend_pid + self.superuser = conn.info.parameter_status("is_superuser") in ("on", "1") + self.server_version = conn.info.parameter_status("server_version") or "" - _set_wait_callback(self.is_virtual_database()) + # _set_wait_callback(self.is_virtual_database()) if not self.is_virtual_database(): - register_date_typecasters(conn) - register_json_typecasters(self.conn, self._json_typecaster) - register_hstore_typecaster(self.conn) + register_typecasters(conn) @property def short_host(self): @@ -387,31 +267,23 @@ class PGExecute: cur.execute(sql) return cur.fetchone() - def _json_typecaster(self, json_data): - """Interpret incoming JSON data as a string. - - The raw data is decoded using the connection's encoding, which defaults - to the database's encoding. - - See http://initd.org/psycopg/docs/connection.html#connection.encoding - """ - - return json_data - def failed_transaction(self): - # pg3: self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR - status = self.conn.get_transaction_status() - return status == ext.TRANSACTION_STATUS_INERROR + return self.conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR def valid_transaction(self): - status = self.conn.get_transaction_status() + status = self.conn.info.transaction_status return ( - status == ext.TRANSACTION_STATUS_ACTIVE - or status == ext.TRANSACTION_STATUS_INTRANS + status == psycopg.pq.TransactionStatus.ACTIVE + or status == psycopg.pq.TransactionStatus.INTRANS ) def run( - self, statement, pgspecial=None, exception_formatter=None, on_error_resume=False + self, + statement, + pgspecial=None, + exception_formatter=None, + on_error_resume=False, + explain_mode=False, ): """Execute the sql in the database and return the results. @@ -432,17 +304,38 @@ class PGExecute: # Remove spaces and EOL statement = statement.strip() if not statement: # Empty string - yield (None, None, None, None, statement, False, False) + yield None, None, None, None, statement, False, False + + # sql parse doesn't split on a comment first + special + # so we're going to do it + + sqltemp = [] + sqlarr = [] + + if statement.startswith("--"): + sqltemp = statement.split("\n") + sqlarr.append(sqltemp[0]) + for i in sqlparse.split(sqltemp[1]): + sqlarr.append(i) + elif statement.startswith("/*"): + sqltemp = statement.split("*/") + sqltemp[0] = sqltemp[0] + "*/" + for i in sqlparse.split(sqltemp[1]): + sqlarr.append(i) + else: + sqlarr = sqlparse.split(statement) - # Split the sql into separate queries and run each one. - for sql in sqlparse.split(statement): + # run each sql query + for sql in sqlarr: # Remove spaces, eol and semi-colons. sql = sql.rstrip(";") sql = sqlparse.format(sql, strip_comments=False).strip() if not sql: continue try: - if pgspecial: + if explain_mode: + sql = self.explain_prefix() + sql + elif pgspecial: # \G is treated specially since we have to set the expanded output. if sql.endswith("\\G"): if not pgspecial.expanded_output: @@ -454,7 +347,7 @@ class PGExecute: _logger.debug("Trying a pgspecial command. sql: %r", sql) try: cur = self.conn.cursor() - except psycopg2.InterfaceError: + except psycopg.InterfaceError: # edge case when connection is already closed, but we # don't need cursor for special_cmd.arg_type == NO_QUERY. # See https://github.com/dbcli/pgcli/issues/1014. @@ -478,7 +371,7 @@ class PGExecute: # Not a special command, so execute as normal sql yield self.execute_normal_sql(sql) + (sql, True, False) - except psycopg2.DatabaseError as e: + except psycopg.DatabaseError as e: _logger.error("sql: %r, error: %r", sql, e) _logger.error("traceback: %r", traceback.format_exc()) @@ -498,7 +391,7 @@ class PGExecute: """Return true if e is an error that should not be caught in ``run``. An uncaught error will prompt the user to reconnect; as long as we - detect that the connection is stil open, we catch the error, as + detect that the connection is still open, we catch the error, as reconnecting won't solve that problem. :param e: DatabaseError. An exception raised while executing a query. @@ -511,13 +404,23 @@ class PGExecute: def execute_normal_sql(self, split_sql): """Returns tuple (title, rows, headers, status)""" _logger.debug("Regular sql statement. sql: %r", split_sql) - cur = self.conn.cursor() - cur.execute(split_sql) - # conn.notices persist between queies, we use pop to clear out the list title = "" - while len(self.conn.notices) > 0: - title = self.conn.notices.pop() + title + + def handle_notices(n): + nonlocal title + title = f"{n.message_primary}\n{n.message_detail}\n{title}" + + self.conn.add_notice_handler(handle_notices) + + if self.is_virtual_database() and "show help" in split_sql.lower(): + # see https://github.com/psycopg/psycopg/issues/303 + # special case "show help" in pgbouncer + res = self.conn.pgconn.exec_(split_sql.encode()) + return title, None, None, res.command_status.decode() + + cur = self.conn.cursor() + cur.execute(split_sql) # cur.description will be None for operations that do not return # rows. @@ -539,7 +442,7 @@ class PGExecute: _logger.debug("Search path query. sql: %r", self.search_path_query) cur.execute(self.search_path_query) return [x[0] for x in cur.fetchall()] - except psycopg2.ProgrammingError: + except psycopg.ProgrammingError: fallback = "SELECT * FROM current_schemas(true)" with self.conn.cursor() as cur: _logger.debug("Search path query. sql: %r", fallback) @@ -549,9 +452,6 @@ class PGExecute: def view_definition(self, spec): """Returns the SQL defining views described by `spec`""" - # pg3: you may want to use `psycopg.sql` for client-side composition - # pg3: (also available in psycopg2 by the way) - template = "CREATE OR REPLACE {6} VIEW {0}.{1} AS \n{3}" # 2: relkind, v or m (materialized) # 4: reloptions, null # 5: checkoption: local or cascaded @@ -560,11 +460,21 @@ class PGExecute: _logger.debug("View Definition Query. sql: %r\nspec: %r", sql, spec) try: cur.execute(sql, (spec,)) - except psycopg2.ProgrammingError: + except psycopg.ProgrammingError: raise RuntimeError(f"View {spec} does not exist.") - result = cur.fetchone() - view_type = "MATERIALIZED" if result[2] == "m" else "" - return template.format(*result + (view_type,)) + result = ViewDef(*cur.fetchone()) + if result.relkind == "m": + template = "CREATE OR REPLACE MATERIALIZED VIEW {name} AS \n{stmt}" + else: + template = "CREATE OR REPLACE VIEW {name} AS \n{stmt}" + return ( + psycopg.sql.SQL(template) + .format( + name=psycopg.sql.Identifier(f"{result.nspname}.{result.relname}"), + stmt=psycopg.sql.SQL(result.viewdef), + ) + .as_string(self.conn) + ) def function_definition(self, spec): """Returns the SQL defining functions described by `spec`""" @@ -576,7 +486,7 @@ class PGExecute: cur.execute(sql, (spec,)) result = cur.fetchone() return result[0] - except psycopg2.ProgrammingError: + except psycopg.ProgrammingError: raise RuntimeError(f"Function {spec} does not exist.") def schemata(self): @@ -600,9 +510,9 @@ class PGExecute: """ with self.conn.cursor() as cur: - sql = cur.mogrify(self.tables_query, [kinds]) - _logger.debug("Tables Query. sql: %r", sql) - cur.execute(sql) + # sql = cur.mogrify(self.tables_query, kinds) + # _logger.debug("Tables Query. sql: %r", sql) + cur.execute(self.tables_query, [kinds]) yield from cur def tables(self): @@ -628,7 +538,7 @@ class PGExecute: :return: list of (schema_name, relation_name, column_name, column_type) tuples """ - if self.conn.server_version >= 80400: + if self.conn.info.server_version >= 80400: columns_query = """ SELECT nsp.nspname schema_name, cls.relname table_name, @@ -669,9 +579,9 @@ class PGExecute: ORDER BY 1, 2, att.attnum""" with self.conn.cursor() as cur: - sql = cur.mogrify(columns_query, [kinds]) - _logger.debug("Columns Query. sql: %r", sql) - cur.execute(sql) + # sql = cur.mogrify(columns_query, kinds) + # _logger.debug("Columns Query. sql: %r", sql) + cur.execute(columns_query, [kinds]) yield from cur def table_columns(self): @@ -712,7 +622,7 @@ class PGExecute: def foreignkeys(self): """Yields ForeignKey named tuples""" - if self.conn.server_version < 90000: + if self.conn.info.server_version < 90000: return with self.conn.cursor() as cur: @@ -752,7 +662,7 @@ class PGExecute: def functions(self): """Yields FunctionMetadata named tuples""" - if self.conn.server_version >= 110000: + if self.conn.info.server_version >= 110000: query = """ SELECT n.nspname schema_name, p.proname func_name, @@ -772,7 +682,7 @@ class PGExecute: WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 """ - elif self.conn.server_version > 90000: + elif self.conn.info.server_version > 90000: query = """ SELECT n.nspname schema_name, p.proname func_name, @@ -792,7 +702,7 @@ class PGExecute: WHERE p.prorettype::regtype != 'trigger'::regtype ORDER BY 1, 2 """ - elif self.conn.server_version >= 80400: + elif self.conn.info.server_version >= 80400: query = """ SELECT n.nspname schema_name, p.proname func_name, @@ -843,7 +753,7 @@ class PGExecute: """Yields tuples of (schema_name, type_name)""" with self.conn.cursor() as cur: - if self.conn.server_version > 90000: + if self.conn.info.server_version > 90000: query = """ SELECT n.nspname schema_name, t.typname type_name @@ -931,3 +841,6 @@ class PGExecute: cur.execute(query) for row in cur: yield row[0] + + def explain_prefix(self): + return "EXPLAIN (ANALYZE, COSTS, VERBOSE, BUFFERS, FORMAT JSON) " |