diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2021-02-08 11:28:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2021-02-08 11:28:14 +0000 |
commit | b678a621c57a6d3fdfac14bdbbef0ed743ab1742 (patch) | |
tree | 5481c14ce75dfda9c55721de033992b45ab0e1dc /mycli/sqlexecute.py | |
parent | Initial commit. (diff) | |
download | mycli-b678a621c57a6d3fdfac14bdbbef0ed743ab1742.tar.xz mycli-b678a621c57a6d3fdfac14bdbbef0ed743ab1742.zip |
Adding upstream version 1.22.2.upstream/1.22.2
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'mycli/sqlexecute.py')
-rw-r--r-- | mycli/sqlexecute.py | 313 |
1 files changed, 313 insertions, 0 deletions
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py new file mode 100644 index 0000000..c68af0f --- /dev/null +++ b/mycli/sqlexecute.py @@ -0,0 +1,313 @@ +import logging +import pymysql +import sqlparse +from .packages import special +from pymysql.constants import FIELD_TYPE +from pymysql.converters import (convert_datetime, + convert_timedelta, convert_date, conversions, + decoders) +try: + import paramiko +except ImportError: + from mycli.packages.paramiko_stub import paramiko + +_logger = logging.getLogger(__name__) + +FIELD_TYPES = decoders.copy() +FIELD_TYPES.update({ + FIELD_TYPE.NULL: type(None) +}) + +class SQLExecute(object): + + databases_query = '''SHOW DATABASES''' + + tables_query = '''SHOW TABLES''' + + version_query = '''SELECT @@VERSION''' + + version_comment_query = '''SELECT @@VERSION_COMMENT''' + version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"''' + + show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' + + users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' + + functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES + WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' + + table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns + where table_schema = '%s' + order by table_name,ordinal_position''' + + def __init__(self, database, user, password, host, port, socket, charset, + local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, + ssh_key_filename): + self.dbname = database + self.user = user + self.password = password + self.host = host + self.port = port + self.socket = socket + self.charset = charset + self.local_infile = local_infile + self.ssl = ssl + self._server_type = None + self.connection_id = None + self.ssh_user = ssh_user + self.ssh_host = ssh_host + self.ssh_port = ssh_port + self.ssh_password = ssh_password + self.ssh_key_filename = ssh_key_filename + self.connect() + + def connect(self, database=None, user=None, password=None, host=None, + port=None, socket=None, charset=None, local_infile=None, + ssl=None, ssh_host=None, ssh_port=None, ssh_user=None, + ssh_password=None, ssh_key_filename=None): + db = (database or self.dbname) + user = (user or self.user) + password = (password or self.password) + host = (host or self.host) + port = (port or self.port) + socket = (socket or self.socket) + charset = (charset or self.charset) + local_infile = (local_infile or self.local_infile) + ssl = (ssl or self.ssl) + ssh_user = (ssh_user or self.ssh_user) + ssh_host = (ssh_host or self.ssh_host) + ssh_port = (ssh_port or self.ssh_port) + ssh_password = (ssh_password or self.ssh_password) + ssh_key_filename = (ssh_key_filename or self.ssh_key_filename) + _logger.debug( + 'Connection DB Params: \n' + '\tdatabase: %r' + '\tuser: %r' + '\thost: %r' + '\tport: %r' + '\tsocket: %r' + '\tcharset: %r' + '\tlocal_infile: %r' + '\tssl: %r' + '\tssh_user: %r' + '\tssh_host: %r' + '\tssh_port: %r' + '\tssh_password: %r' + '\tssh_key_filename: %r', + db, user, host, port, socket, charset, local_infile, ssl, + ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename + ) + conv = conversions.copy() + conv.update({ + FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj), + FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj), + FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj), + FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj), + }) + + defer_connect = False + + if ssh_host: + defer_connect = True + + conn = pymysql.connect( + database=db, user=user, password=password, host=host, port=port, + unix_socket=socket, use_unicode=True, charset=charset, + autocommit=True, client_flag=pymysql.constants.CLIENT.INTERACTIVE, + local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", + defer_connect=defer_connect + ) + + if ssh_host: + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + client.connect( + ssh_host, ssh_port, ssh_user, ssh_password, + key_filename=ssh_key_filename + ) + chan = client.get_transport().open_channel( + 'direct-tcpip', + (host, port), + ('0.0.0.0', 0), + ) + conn.connect(chan) + + if hasattr(self, 'conn'): + self.conn.close() + self.conn = conn + # Update them after the connection is made to ensure that it was a + # successful connection. + self.dbname = db + self.user = user + self.password = password + self.host = host + self.port = port + self.socket = socket + self.charset = charset + self.ssl = ssl + # retrieve connection id + self.reset_connection_id() + + def run(self, statement): + """Execute the sql in the database and return the results. The results + are a list of tuples. Each tuple has 4 values + (title, rows, headers, status). + """ + + # Remove spaces and EOL + statement = statement.strip() + if not statement: # Empty string + yield (None, None, None, None) + + # Split the sql into separate queries and run each one. + # Unless it's saving a favorite query, in which case we + # want to save them all together. + if statement.startswith('\\fs'): + components = [statement] + else: + components = special.split_queries(statement) + + for sql in components: + # \G is treated specially since we have to set the expanded output. + if sql.endswith('\\G'): + special.set_expanded_output(True) + sql = sql[:-2].strip() + + cur = self.conn.cursor() + try: # Special command + _logger.debug('Trying a dbspecial command. sql: %r', sql) + for result in special.execute(cur, sql): + yield result + except special.CommandNotFound: # Regular SQL + _logger.debug('Regular sql statement. sql: %r', sql) + cur.execute(sql) + while True: + yield self.get_result(cur) + + # PyMySQL returns an extra, empty result set with stored + # procedures. We skip it (rowcount is zero and no + # description). + if not cur.nextset() or (not cur.rowcount and cur.description is None): + break + + def get_result(self, cursor): + """Get the current result's data from the cursor.""" + title = headers = None + + # cursor.description is not None for queries that return result sets, + # e.g. SELECT or SHOW. + if cursor.description is not None: + headers = [x[0] for x in cursor.description] + status = '{0} row{1} in set' + else: + _logger.debug('No rows in result.') + status = 'Query OK, {0} row{1} affected' + status = status.format(cursor.rowcount, + '' if cursor.rowcount == 1 else 's') + + return (title, cursor if cursor.description else None, headers, status) + + def tables(self): + """Yields table names""" + + with self.conn.cursor() as cur: + _logger.debug('Tables Query. sql: %r', self.tables_query) + cur.execute(self.tables_query) + for row in cur: + yield row + + def table_columns(self): + """Yields (table name, column name) pairs""" + with self.conn.cursor() as cur: + _logger.debug('Columns Query. sql: %r', self.table_columns_query) + cur.execute(self.table_columns_query % self.dbname) + for row in cur: + yield row + + def databases(self): + with self.conn.cursor() as cur: + _logger.debug('Databases Query. sql: %r', self.databases_query) + cur.execute(self.databases_query) + return [x[0] for x in cur.fetchall()] + + def functions(self): + """Yields tuples of (schema_name, function_name)""" + + with self.conn.cursor() as cur: + _logger.debug('Functions Query. sql: %r', self.functions_query) + cur.execute(self.functions_query % self.dbname) + for row in cur: + yield row + + def show_candidates(self): + with self.conn.cursor() as cur: + _logger.debug('Show Query. sql: %r', self.show_candidates_query) + try: + cur.execute(self.show_candidates_query) + except pymysql.DatabaseError as e: + _logger.error('No show completions due to %r', e) + yield '' + else: + for row in cur: + yield (row[0].split(None, 1)[-1], ) + + def users(self): + with self.conn.cursor() as cur: + _logger.debug('Users Query. sql: %r', self.users_query) + try: + cur.execute(self.users_query) + except pymysql.DatabaseError as e: + _logger.error('No user completions due to %r', e) + yield '' + else: + for row in cur: + yield row + + def server_type(self): + if self._server_type: + return self._server_type + with self.conn.cursor() as cur: + _logger.debug('Version Query. sql: %r', self.version_query) + cur.execute(self.version_query) + version = cur.fetchone()[0] + if version[0] == '4': + _logger.debug('Version Comment. sql: %r', + self.version_comment_query_mysql4) + cur.execute(self.version_comment_query_mysql4) + version_comment = cur.fetchone()[1].lower() + if isinstance(version_comment, bytes): + # with python3 this query returns bytes + version_comment = version_comment.decode('utf-8') + else: + _logger.debug('Version Comment. sql: %r', + self.version_comment_query) + cur.execute(self.version_comment_query) + version_comment = cur.fetchone()[0].lower() + + if 'mariadb' in version_comment: + product_type = 'mariadb' + elif 'percona' in version_comment: + product_type = 'percona' + else: + product_type = 'mysql' + + self._server_type = (product_type, version) + return self._server_type + + def get_connection_id(self): + if not self.connection_id: + self.reset_connection_id() + return self.connection_id + + def reset_connection_id(self): + # Remember current connection id + _logger.debug('Get current connection id') + res = self.run('select connection_id()') + for title, cur, headers, status in res: + self.connection_id = cur.fetchone()[0] + _logger.debug('Current connection id: %s', self.connection_id) + + def change_db(self, db): + self.conn.select_db(db) + self.dbname = db |