From f6a576f0ec04a9b2fa2982e2e9188d874bbd156c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 8 Apr 2024 10:07:30 +0200 Subject: Merging upstream version 1.27.2. Signed-off-by: Daniel Baumann --- mycli/sqlexecute.py | 45 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) (limited to 'mycli/sqlexecute.py') diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index c019707..bd5f5d9 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -56,7 +56,7 @@ class ServerInfo: re_species = ( (r'(?P[0-9\.]+)-MariaDB', ServerSpecies.MariaDB), - (r'(?P[0-9\.]+)[a-z0-9]*-TiDB', ServerSpecies.TiDB), + (r'[0-9\.]*-TiDB-v(?P[0-9\.]+)-?(?P[a-z0-9\-]*)', ServerSpecies.TiDB), (r'(?P[0-9\.]+)[a-z0-9]*-(?P[0-9]+$)', ServerSpecies.Percona), (r'(?P[0-9\.]+)[a-z0-9]*-(?P[A-Za-z0-9_]+)', @@ -176,11 +176,15 @@ class SQLExecute(object): if init_command and len(list(special.split_queries(init_command))) > 1: client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS + ssl_context = None + if ssl: + ssl_context = self._create_ssl_ctx(ssl) + conn = pymysql.connect( database=db, user=user, password=password, host=host, port=port, unix_socket=socket, use_unicode=True, charset=charset, autocommit=True, client_flag=client_flag, - local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli", + local_infile=local_infile, conv=conv, ssl=ssl_context, program_name="mycli", defer_connect=defer_connect, init_command=init_command ) @@ -354,3 +358,40 @@ class SQLExecute(object): def change_db(self, db): self.conn.select_db(db) self.dbname = db + + def _create_ssl_ctx(self, sslp): + import ssl + + ca = sslp.get("ca") + capath = sslp.get("capath") + hasnoca = ca is None and capath is None + ctx = ssl.create_default_context(cafile=ca, capath=capath) + ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True) + ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED + if "cert" in sslp: + ctx.load_cert_chain(sslp["cert"], keyfile=sslp.get("key")) + if "cipher" in sslp: + ctx.set_ciphers(sslp["cipher"]) + + # raise this default to v1.1 or v1.2? + ctx.minimum_version = ssl.TLSVersion.TLSv1 + + if "tls_version" in sslp: + tls_version = sslp["tls_version"] + + if tls_version == "TLSv1": + ctx.minimum_version = ssl.TLSVersion.TLSv1 + ctx.maximum_version = ssl.TLSVersion.TLSv1 + elif tls_version == "TLSv1.1": + ctx.minimum_version = ssl.TLSVersion.TLSv1_1 + ctx.maximum_version = ssl.TLSVersion.TLSv1_1 + elif tls_version == "TLSv1.2": + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 + ctx.maximum_version = ssl.TLSVersion.TLSv1_2 + elif tls_version == "TLSv1.3": + ctx.minimum_version = ssl.TLSVersion.TLSv1_3 + ctx.maximum_version = ssl.TLSVersion.TLSv1_3 + else: + _logger.error('Invalid tls version: %s', tls_version) + + return ctx -- cgit v1.2.3