summaryrefslogtreecommitdiffstats
path: root/mycli/sqlexecute.py
diff options
context:
space:
mode:
Diffstat (limited to 'mycli/sqlexecute.py')
-rw-r--r--mycli/sqlexecute.py45
1 files changed, 43 insertions, 2 deletions
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py
index c019707..bd5f5d9 100644
--- a/mycli/sqlexecute.py
+++ b/mycli/sqlexecute.py
@@ -56,7 +56,7 @@ class ServerInfo:
re_species = (
(r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB),
- (r'(?P<version>[0-9\.]+)[a-z0-9]*-TiDB', ServerSpecies.TiDB),
+ (r'[0-9\.]*-TiDB-v(?P<version>[0-9\.]+)-?(?P<comment>[a-z0-9\-]*)', ServerSpecies.TiDB),
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)',
ServerSpecies.Percona),
(r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)',
@@ -176,11 +176,15 @@ class SQLExecute(object):
if init_command and len(list(special.split_queries(init_command))) > 1:
client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
+ ssl_context = None
+ if ssl:
+ ssl_context = self._create_ssl_ctx(ssl)
+
conn = pymysql.connect(
database=db, user=user, password=password, host=host, port=port,
unix_socket=socket, use_unicode=True, charset=charset,
autocommit=True, client_flag=client_flag,
- local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
+ local_infile=local_infile, conv=conv, ssl=ssl_context, program_name="mycli",
defer_connect=defer_connect, init_command=init_command
)
@@ -354,3 +358,40 @@ class SQLExecute(object):
def change_db(self, db):
self.conn.select_db(db)
self.dbname = db
+
+ def _create_ssl_ctx(self, sslp):
+ import ssl
+
+ ca = sslp.get("ca")
+ capath = sslp.get("capath")
+ hasnoca = ca is None and capath is None
+ ctx = ssl.create_default_context(cafile=ca, capath=capath)
+ ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
+ ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
+ if "cert" in sslp:
+ ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key"))
+ if "cipher" in sslp:
+ ctx.set_ciphers(sslp["cipher"])
+
+ # raise this default to v1.1 or v1.2?
+ ctx.minimum_version = ssl.TLSVersion.TLSv1
+
+ if "tls_version" in sslp:
+ tls_version = sslp["tls_version"]
+
+ if tls_version == "TLSv1":
+ ctx.minimum_version = ssl.TLSVersion.TLSv1
+ ctx.maximum_version = ssl.TLSVersion.TLSv1
+ elif tls_version == "TLSv1.1":
+ ctx.minimum_version = ssl.TLSVersion.TLSv1_1
+ ctx.maximum_version = ssl.TLSVersion.TLSv1_1
+ elif tls_version == "TLSv1.2":
+ ctx.minimum_version = ssl.TLSVersion.TLSv1_2
+ ctx.maximum_version = ssl.TLSVersion.TLSv1_2
+ elif tls_version == "TLSv1.3":
+ ctx.minimum_version = ssl.TLSVersion.TLSv1_3
+ ctx.maximum_version = ssl.TLSVersion.TLSv1_3
+ else:
+ _logger.error('Invalid tls version: %s', tls_version)
+
+ return ctx