diff options
Diffstat (limited to 'mycli/sqlexecute.py')
-rw-r--r-- | mycli/sqlexecute.py | 45 |
1 files changed, 43 insertions, 2 deletions
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index c019707..bd5f5d9 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -56,7 +56,7 @@ class ServerInfo: re_species = ( (r'(?P<version>[0-9\.]+)-MariaDB', ServerSpecies.MariaDB), - (r'(?P<version>[0-9\.]+)[a-z0-9]*-TiDB', ServerSpecies.TiDB), + (r'[0-9\.]*-TiDB-v(?P<version>[0-9\.]+)-?(?P<comment>[a-z0-9\-]*)', ServerSpecies.TiDB), (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[0-9]+$)', ServerSpecies.Percona), (r'(?P<version>[0-9\.]+)[a-z0-9]*-(?P<comment>[A-Za-z0-9_]+)', @@ -176,11 +176,15 @@ class SQLExecute(object): if init_command and len(list(special.split_queries(init_command))) > 1: client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS + ssl_context = None + if ssl: + ssl_context = self._create_ssl_ctx(ssl) + conn = pymysql.connect( database=db, user=user, password=password, host=host, port=port, unix_socket=socket, use_unicode=True, charset=charset, autocommit=True, client_flag=client_flag, - local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", + local_infile=local_infile, conv=conv, ssl=ssl_context, program_name="mycli", defer_connect=defer_connect, init_command=init_command ) @@ -354,3 +358,40 @@ class SQLExecute(object): def change_db(self, db): self.conn.select_db(db) self.dbname = db + + def _create_ssl_ctx(self, sslp): + import ssl + + ca = sslp.get("ca") + capath = sslp.get("capath") + hasnoca = ca is None and capath is None + ctx = ssl.create_default_context(cafile=ca, capath=capath) + ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True) + ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED + if "cert" in sslp: + ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key")) + if "cipher" in sslp: + ctx.set_ciphers(sslp["cipher"]) + + # raise this default to v1.1 or v1.2? + ctx.minimum_version = ssl.TLSVersion.TLSv1 + + if "tls_version" in sslp: + tls_version = sslp["tls_version"] + + if tls_version == "TLSv1": + ctx.minimum_version = ssl.TLSVersion.TLSv1 + ctx.maximum_version = ssl.TLSVersion.TLSv1 + elif tls_version == "TLSv1.1": + ctx.minimum_version = ssl.TLSVersion.TLSv1_1 + ctx.maximum_version = ssl.TLSVersion.TLSv1_1 + elif tls_version == "TLSv1.2": + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + ctx.maximum_version = ssl.TLSVersion.TLSv1_2 + elif tls_version == "TLSv1.3": + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + ctx.maximum_version = ssl.TLSVersion.TLSv1_3 + else: + _logger.error('Invalid tls version: %s', tls_version) + + return ctx |