summaryrefslogtreecommitdiffstats
path: root/mycli/sqlexecute.py
diff options
context:
space:
mode:
Diffstat (limited to 'mycli/sqlexecute.py')
-rw-r--r--mycli/sqlexecute.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py
index c68af0f..7534982 100644
--- a/mycli/sqlexecute.py
+++ b/mycli/sqlexecute.py
@@ -42,7 +42,7 @@ class SQLExecute(object):
def __init__(self, database, user, password, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password,
- ssh_key_filename):
+ ssh_key_filename, init_command=None):
self.dbname = database
self.user = user
self.password = password
@@ -59,12 +59,13 @@ class SQLExecute(object):
self.ssh_port = ssh_port
self.ssh_password = ssh_password
self.ssh_key_filename = ssh_key_filename
+ self.init_command = init_command
self.connect()
def connect(self, database=None, user=None, password=None, host=None,
port=None, socket=None, charset=None, local_infile=None,
ssl=None, ssh_host=None, ssh_port=None, ssh_user=None,
- ssh_password=None, ssh_key_filename=None):
+ ssh_password=None, ssh_key_filename=None, init_command=None):
db = (database or self.dbname)
user = (user or self.user)
password = (password or self.password)
@@ -79,6 +80,7 @@ class SQLExecute(object):
ssh_port = (ssh_port or self.ssh_port)
ssh_password = (ssh_password or self.ssh_password)
ssh_key_filename = (ssh_key_filename or self.ssh_key_filename)
+ init_command = (init_command or self.init_command)
_logger.debug(
'Connection DB Params: \n'
'\tdatabase: %r'
@@ -93,9 +95,11 @@ class SQLExecute(object):
'\tssh_host: %r'
'\tssh_port: %r'
'\tssh_password: %r'
- '\tssh_key_filename: %r',
+ '\tssh_key_filename: %r'
+ '\tinit_command: %r',
db, user, host, port, socket, charset, local_infile, ssl,
- ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename
+ ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename,
+ init_command
)
conv = conversions.copy()
conv.update({
@@ -110,12 +114,16 @@ class SQLExecute(object):
if ssh_host:
defer_connect = True
+ client_flag = pymysql.constants.CLIENT.INTERACTIVE
+ if init_command and len(list(special.split_queries(init_command))) > 1:
+ client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS
+
conn = pymysql.connect(
database=db, user=user, password=password, host=host, port=port,
unix_socket=socket, use_unicode=True, charset=charset,
- autocommit=True, client_flag=pymysql.constants.CLIENT.INTERACTIVE,
+ autocommit=True, client_flag=client_flag,
local_infile=local_infile, conv=conv, ssl=ssl, program_name="mycli",
- defer_connect=defer_connect
+ defer_connect=defer_connect, init_command=init_command
)
if ssh_host:
@@ -146,6 +154,7 @@ class SQLExecute(object):
self.socket = socket
self.charset = charset
self.ssl = ssl
+ self.init_command = init_command
# retrieve connection id
self.reset_connection_id()