summaryrefslogtreecommitdiffstats
path: root/mycli/sqlexecute.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--mycli/sqlexecute.py356
1 files changed, 356 insertions, 0 deletions
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py
new file mode 100644
index 0000000..c019707
--- /dev/null
+++ b/mycli/sqlexecute.py
@@ -0,0 +1,356 @@
+import enum
+import logging
+import re
+
+import pymysql
+from .packages import special
+from pymysql.constants import FIELD_TYPE
+from pymysql.converters import (convert_datetime,
+ convert_timedelta, convert_date, conversions,
+ decoders)
+try:
+ import paramiko
+except ImportError:
+ from mycli.packages.paramiko_stub import paramiko
+
+_logger = logging.getLogger(__name__)
+
+FIELD_TYPES = decoders.copy()
+FIELD_TYPES.update({
+ FIELD_TYPE.NULL: type(None)
+})
+
+
+ERROR_CODE_ACCESS_DENIED = 1045
+
+
+class ServerSpecies(enum.Enum):
+ MySQL = 'MySQL'
+ MariaDB = 'MariaDB'
+ Percona = 'Percona'
+ TiDB = 'TiDB'
+ Unknown = 'MySQL'
+
+
+class ServerInfo:
+ def __init__(self, species, version_str):
+ self.species = species
+ self.version_str = version_str
+ self.version = self.calc_mysql_version_value(version_str)
+
+ @staticmethod
+ def calc_mysql_version_value(version_str) -> int:
+ if not version_str or not isinstance(version_str, str):
+ return 0
+ try:
+ major, minor, patch = version_str.split('.')
+ except ValueError:
+ return 0
+ else:
+ return int(major) * 10_000 + int(minor) * 100 + int(patch)
+
+ @classmethod
+ def from_version_string(cls, version_string):
+ if not version_string:
+ return cls(ServerSpecies.Unknown, '')
+
+ re_species = (
+ (r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-TiDB', ServerSpecies.TiDB),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
+ ServerSpecies.Percona),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
+ ServerSpecies.MySQL),
+ )
+ for regexp, species in re_species:
+ match = re.search(regexp, version_string)
+ if match is not None:
+ parsed_version = match.group('version')
+ detected_species = species
+ break
+ else:
+ detected_species = ServerSpecies.Unknown
+ parsed_version = ''
+
+ return cls(detected_species, parsed_version)
+
+ def __str__(self):
+ if self.species:
+ return f'{self.species.value} {self.version_str}'
+ else:
+ return self.version_str
+
+
+class SQLExecute(object):
+
+ databases_query = '''SHOW DATABASES'''
+
+ tables_query = '''SHOW TABLES'''
+
+ show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
+
+ users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
+
+ functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES
+ WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"'''
+
+ table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns
+ where table_schema = '%s'
+ order by table_name,ordinal_position'''
+
+ def __init__(self, database, user, password, host, port, socket, charset,
+ local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
+ ssh_key_filename, init_command=None):
+ self.dbname = database
+ self.user = user
+ self.password = password
+ self.host = host
+ self.port = port
+ self.socket = socket
+ self.charset = charset
+ self.local_infile = local_infile
+ self.ssl = ssl
+ self.server_info = None
+ self.connection_id = None
+ self.ssh_user = ssh_user
+ self.ssh_host = ssh_host
+ self.ssh_port = ssh_port
+ self.ssh_password = ssh_password
+ self.ssh_key_filename = ssh_key_filename
+ self.init_command = init_command
+ self.connect()
+
+ def connect(self, database=None, user=None, password=None, host=None,
+ port=None, socket=None, charset=None, local_infile=None,
+ ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
+ ssh_password=None, ssh_key_filename=None, init_command=None):
+ db = (database or self.dbname)
+ user = (user or self.user)
+ password = (password or self.password)
+ host = (host or self.host)
+ port = (port or self.port)
+ socket = (socket or self.socket)
+ charset = (charset or self.charset)
+ local_infile = (local_infile or self.local_infile)
+ ssl = (ssl or self.ssl)
+ ssh_user = (ssh_user or self.ssh_user)
+ ssh_host = (ssh_host or self.ssh_host)
+ ssh_port = (ssh_port or self.ssh_port)
+ ssh_password = (ssh_password or self.ssh_password)
+ ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
+ init_command = (init_command or self.init_command)
+ _logger.debug(
+ 'Connection DB Params: \n'
+ '\tdatabase: %r'
+ '\tuser: %r'
+ '\thost: %r'
+ '\tport: %r'
+ '\tsocket: %r'
+ '\tcharset: %r'
+ '\tlocal_infile: %r'
+ '\tssl: %r'
+ '\tssh_user: %r'
+ '\tssh_host: %r'
+ '\tssh_port: %r'
+ '\tssh_password: %r'
+ '\tssh_key_filename: %r'
+ '\tinit_command: %r',
+ db, user, host, port, socket, charset, local_infile, ssl,
+ ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename,
+ init_command
+ )
+ conv = conversions.copy()
+ conv.update({
+ FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj),
+ FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj),
+ FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj),
+ FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj),
+ })
+
+ defer_connect = False
+
+ if ssh_host:
+ defer_connect = True
+
+ client_flag = pymysql.constants.CLIENT.INTERACTIVE
+ if init_command and len(list(special.split_queries(init_command))) > 1:
+ client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
+
+ conn = pymysql.connect(
+ database=db, user=user, password=password, host=host, port=port,
+ unix_socket=socket, use_unicode=True, charset=charset,
+ autocommit=True, client_flag=client_flag,
+ local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
+ defer_connect=defer_connect, init_command=init_command
+ )
+
+ if ssh_host:
+ client = paramiko.SSHClient()
+ client.load_system_host_keys()
+ client.set_missing_host_key_policy(paramiko.WarningPolicy())
+ client.connect(
+ ssh_host, ssh_port, ssh_user, ssh_password,
+ key_filename=ssh_key_filename
+ )
+ chan = client.get_transport().open_channel(
+ 'direct-tcpip',
+ (host, port),
+ ('0.0.0.0', 0),
+ )
+ conn.connect(chan)
+
+ if hasattr(self, 'conn'):
+ self.conn.close()
+ self.conn = conn
+ # Update them after the connection is made to ensure that it was a
+ # successful connection.
+ self.dbname = db
+ self.user = user
+ self.password = password
+ self.host = host
+ self.port = port
+ self.socket = socket
+ self.charset = charset
+ self.ssl = ssl
+ self.init_command = init_command
+ # retrieve connection id
+ self.reset_connection_id()
+ self.server_info = ServerInfo.from_version_string(conn.server_version)
+
+ def run(self, statement):
+ """Execute the sql in the database and return the results. The results
+ are a list of tuples. Each tuple has 4 values
+ (title, rows, headers, status).
+ """
+
+ # Remove spaces and EOL
+ statement = statement.strip()
+ if not statement: # Empty string
+ yield (None, None, None, None)
+
+ # Split the sql into separate queries and run each one.
+ # Unless it's saving a favorite query, in which case we
+ # want to save them all together.
+ if statement.startswith('\\fs'):
+ components = [statement]
+ else:
+ components = special.split_queries(statement)
+
+ for sql in components:
+ # \G is treated specially since we have to set the expanded output.
+ if sql.endswith('\\G'):
+ special.set_expanded_output(True)
+ sql = sql[:-2].strip()
+
+ cur = self.conn.cursor()
+ try: # Special command
+ _logger.debug('Trying a dbspecial command. sql: %r', sql)
+ for result in special.execute(cur, sql):
+ yield result
+ except special.CommandNotFound: # Regular SQL
+ _logger.debug('Regular sql statement. sql: %r', sql)
+ cur.execute(sql)
+ while True:
+ yield self.get_result(cur)
+
+ # PyMySQL returns an extra, empty result set with stored
+ # procedures. We skip it (rowcount is zero and no
+ # description).
+ if not cur.nextset() or (not cur.rowcount and cur.description is None):
+ break
+
+ def get_result(self, cursor):
+ """Get the current result's data from the cursor."""
+ title = headers = None
+
+ # cursor.description is not None for queries that return result sets,
+ # e.g. SELECT or SHOW.
+ if cursor.description is not None:
+ headers = [x[0] for x in cursor.description]
+ status = '{0} row{1} in set'
+ else:
+ _logger.debug('No rows in result.')
+ status = 'Query OK, {0} row{1} affected'
+ status = status.format(cursor.rowcount,
+ '' if cursor.rowcount == 1 else 's')
+
+ return (title, cursor if cursor.description else None, headers, status)
+
+ def tables(self):
+ """Yields table names"""
+
+ with self.conn.cursor() as cur:
+ _logger.debug('Tables Query. sql: %r', self.tables_query)
+ cur.execute(self.tables_query)
+ for row in cur:
+ yield row
+
+ def table_columns(self):
+ """Yields (table name, column name) pairs"""
+ with self.conn.cursor() as cur:
+ _logger.debug('Columns Query. sql: %r', self.table_columns_query)
+ cur.execute(self.table_columns_query % self.dbname)
+ for row in cur:
+ yield row
+
+ def databases(self):
+ with self.conn.cursor() as cur:
+ _logger.debug('Databases Query. sql: %r', self.databases_query)
+ cur.execute(self.databases_query)
+ return [x[0] for x in cur.fetchall()]
+
+ def functions(self):
+ """Yields tuples of (schema_name, function_name)"""
+
+ with self.conn.cursor() as cur:
+ _logger.debug('Functions Query. sql: %r', self.functions_query)
+ cur.execute(self.functions_query % self.dbname)
+ for row in cur:
+ yield row
+
+ def show_candidates(self):
+ with self.conn.cursor() as cur:
+ _logger.debug('Show Query. sql: %r', self.show_candidates_query)
+ try:
+ cur.execute(self.show_candidates_query)
+ except pymysql.DatabaseError as e:
+ _logger.error('No show completions due to %r', e)
+ yield ''
+ else:
+ for row in cur:
+ yield (row[0].split(None, 1)[-1], )
+
+ def users(self):
+ with self.conn.cursor() as cur:
+ _logger.debug('Users Query. sql: %r', self.users_query)
+ try:
+ cur.execute(self.users_query)
+ except pymysql.DatabaseError as e:
+ _logger.error('No user completions due to %r', e)
+ yield ''
+ else:
+ for row in cur:
+ yield row
+
+ def get_connection_id(self):
+ if not self.connection_id:
+ self.reset_connection_id()
+ return self.connection_id
+
+ def reset_connection_id(self):
+ # Remember current connection id
+ _logger.debug('Get current connection id')
+ try:
+ res = self.run('select connection_id()')
+ for title, cur, headers, status in res:
+ self.connection_id = cur.fetchone()[0]
+ except Exception as e:
+ # See #1054
+ self.connection_id = -1
+ _logger.error('Failed to get connection id: %s', e)
+ else:
+ _logger.debug('Current connection id: %s', self.connection_id)
+
+ def change_db(self, db):
+ self.conn.select_db(db)
+ self.dbname = db