summaryrefslogtreecommitdiffstats
path: root/mycli/sqlexecute.py
diff options
context:
space:
mode:
Diffstat (limited to 'mycli/sqlexecute.py')
-rw-r--r--mycli/sqlexecute.py102
1 files changed, 64 insertions, 38 deletions
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py
index 7534982..9461438 100644
--- a/mycli/sqlexecute.py
+++ b/mycli/sqlexecute.py
@@ -1,6 +1,8 @@
+import enum
import logging
+import re
+
import pymysql
-import sqlparse
from .packages import special
from pymysql.constants import FIELD_TYPE
from pymysql.converters import (convert_datetime,
@@ -18,17 +20,71 @@ FIELD_TYPES.update({
FIELD_TYPE.NULL: type(None)
})
+
+ERROR_CODE_ACCESS_DENIED = 1045
+
+
+class ServerSpecies(enum.Enum):
+ MySQL = 'MySQL'
+ MariaDB = 'MariaDB'
+ Percona = 'Percona'
+ Unknown = 'MySQL'
+
+
+class ServerInfo:
+ def __init__(self, species, version_str):
+ self.species = species
+ self.version_str = version_str
+ self.version = self.calc_mysql_version_value(version_str)
+
+ @staticmethod
+ def calc_mysql_version_value(version_str) -> int:
+ if not version_str or not isinstance(version_str, str):
+ return 0
+ try:
+ major, minor, patch = version_str.split('.')
+ except ValueError:
+ return 0
+ else:
+ return int(major) * 10_000 + int(minor) * 100 + int(patch)
+
+ @classmethod
+ def from_version_string(cls, version_string):
+ if not version_string:
+ return cls(ServerSpecies.Unknown, '')
+
+ re_species = (
+ (r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
+ ServerSpecies.Percona),
+ (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
+ ServerSpecies.MySQL),
+ )
+ for regexp, species in re_species:
+ match = re.search(regexp, version_string)
+ if match is not None:
+ parsed_version = match.group('version')
+ detected_species = species
+ break
+ else:
+ detected_species = ServerSpecies.Unknown
+ parsed_version = ''
+
+ return cls(detected_species, parsed_version)
+
+ def __str__(self):
+ if self.species:
+ return f'{self.species.value} {self.version_str}'
+ else:
+ return self.version_str
+
+
class SQLExecute(object):
databases_query = '''SHOW DATABASES'''
tables_query = '''SHOW TABLES'''
- version_query = '''SELECT @@VERSION'''
-
- version_comment_query = '''SELECT @@VERSION_COMMENT'''
- version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"'''
-
show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"'''
users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user'''
@@ -52,7 +108,7 @@ class SQLExecute(object):
self.charset = charset
self.local_infile = local_infile
self.ssl = ssl
- self._server_type = None
+ self.server_info = None
self.connection_id = None
self.ssh_user = ssh_user
self.ssh_host = ssh_host
@@ -157,6 +213,7 @@ class SQLExecute(object):
self.init_command = init_command
# retrieve connection id
self.reset_connection_id()
+ self.server_info = ServerInfo.from_version_string(conn.server_version)
def run(self, statement):
"""Execute the sql in the database and return the results. The results
@@ -273,37 +330,6 @@ class SQLExecute(object):
for row in cur:
yield row
- def server_type(self):
- if self._server_type:
- return self._server_type
- with self.conn.cursor() as cur:
- _logger.debug('Version Query. sql: %r', self.version_query)
- cur.execute(self.version_query)
- version = cur.fetchone()[0]
- if version[0] == '4':
- _logger.debug('Version Comment. sql: %r',
- self.version_comment_query_mysql4)
- cur.execute(self.version_comment_query_mysql4)
- version_comment = cur.fetchone()[1].lower()
- if isinstance(version_comment, bytes):
- # with python3 this query returns bytes
- version_comment = version_comment.decode('utf-8')
- else:
- _logger.debug('Version Comment. sql: %r',
- self.version_comment_query)
- cur.execute(self.version_comment_query)
- version_comment = cur.fetchone()[0].lower()
-
- if 'mariadb' in version_comment:
- product_type = 'mariadb'
- elif 'percona' in version_comment:
- product_type = 'percona'
- else:
- product_type = 'mysql'
-
- self._server_type = (product_type, version)
- return self._server_type
-
def get_connection_id(self):
if not self.connection_id:
self.reset_connection_id()