From 009d0b0f17cc82919a683a1ecb6a334f5354090d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 8 Mar 2021 07:40:40 +0100 Subject: Merging upstream version 1.24.1. Signed-off-by: Daniel Baumann --- mycli/sqlexecute.py | 102 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 64 insertions(+), 38 deletions(-) (limited to 'mycli/sqlexecute.py') diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 7534982..9461438 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -1,6 +1,8 @@ +import enum import logging +import re + import pymysql -import sqlparse from .packages import special from pymysql.constants import FIELD_TYPE from pymysql.converters import (convert_datetime, @@ -18,17 +20,71 @@ FIELD_TYPES.update({ FIELD_TYPE.NULL: type(None) }) + +ERROR_CODE_ACCESS_DENIED = 1045 + + +class ServerSpecies(enum.Enum): + MySQL = 'MySQL' + MariaDB = 'MariaDB' + Percona = 'Percona' + Unknown = 'MySQL' + + +class ServerInfo: + def __init__(self, species, version_str): + self.species = species + self.version_str = version_str + self.version = self.calc_mysql_version_value(version_str) + + @staticmethod + def calc_mysql_version_value(version_str) -> int: + if not version_str or not isinstance(version_str, str): + return 0 + try: + major, minor, patch = version_str.split('.') + except ValueError: + return 0 + else: + return int(major) * 10_000 + int(minor) * 100 + int(patch) + + @classmethod + def from_version_string(cls, version_string): + if not version_string: + return cls(ServerSpecies.Unknown, '') + + re_species = ( + (r'(?P[0-9\.]+)-MariaDB', ServerSpecies.MariaDB), + (r'(?P[0-9\.]+)[a-z0-9]*-(?P[0-9]+$)', + ServerSpecies.Percona), + (r'(?P[0-9\.]+)[a-z0-9]*-(?P[A-Za-z0-9_]+)', + ServerSpecies.MySQL), + ) + for regexp, species in re_species: + match = re.search(regexp, version_string) + if match is not None: + parsed_version = match.group('version') + detected_species = species + break + else: + detected_species = ServerSpecies.Unknown + parsed_version = '' + + return cls(detected_species, parsed_version) + + def __str__(self): + if self.species: + return f'{self.species.value} {self.version_str}' + else: + return self.version_str + + class SQLExecute(object): databases_query = '''SHOW DATABASES''' tables_query = '''SHOW TABLES''' - version_query = '''SELECT @@VERSION''' - - version_comment_query = '''SELECT @@VERSION_COMMENT''' - version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"''' - show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' @@ -52,7 +108,7 @@ class SQLExecute(object): self.charset = charset self.local_infile = local_infile self.ssl = ssl - self._server_type = None + self.server_info = None self.connection_id = None self.ssh_user = ssh_user self.ssh_host = ssh_host @@ -157,6 +213,7 @@ class SQLExecute(object): self.init_command = init_command # retrieve connection id self.reset_connection_id() + self.server_info = ServerInfo.from_version_string(conn.server_version) def run(self, statement): """Execute the sql in the database and return the results. The results @@ -273,37 +330,6 @@ class SQLExecute(object): for row in cur: yield row - def server_type(self): - if self._server_type: - return self._server_type - with self.conn.cursor() as cur: - _logger.debug('Version Query. sql: %r', self.version_query) - cur.execute(self.version_query) - version = cur.fetchone()[0] - if version[0] == '4': - _logger.debug('Version Comment. sql: %r', - self.version_comment_query_mysql4) - cur.execute(self.version_comment_query_mysql4) - version_comment = cur.fetchone()[1].lower() - if isinstance(version_comment, bytes): - # with python3 this query returns bytes - version_comment = version_comment.decode('utf-8') - else: - _logger.debug('Version Comment. sql: %r', - self.version_comment_query) - cur.execute(self.version_comment_query) - version_comment = cur.fetchone()[0].lower() - - if 'mariadb' in version_comment: - product_type = 'mariadb' - elif 'percona' in version_comment: - product_type = 'percona' - else: - product_type = 'mysql' - - self._server_type = (product_type, version) - return self._server_type - def get_connection_id(self): if not self.connection_id: self.reset_connection_id() -- cgit v1.2.3