summaryrefslogtreecommitdiffstats
path: root/mycli/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'mycli/main.py')
-rwxr-xr-xmycli/main.py90
1 files changed, 67 insertions, 23 deletions
diff --git a/mycli/main.py b/mycli/main.py
index 03797a0..f2b2fd8 100755
--- a/mycli/main.py
+++ b/mycli/main.py
@@ -21,13 +21,14 @@ from cli_helpers.tabular_output import preprocessors
from cli_helpers.utils import strip_ansi
import click
import sqlparse
-from mycli.packages.parseutils import is_dropping_database
+from mycli.packages.parseutils import is_dropping_database, is_destructive
from prompt_toolkit.completion import DynamicCompleter
from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register
from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
from prompt_toolkit.document import Document
from prompt_toolkit.filters import HasFocus, IsDone
+from prompt_toolkit.formatted_text import ANSI
from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor,
ConditionalProcessor)
from prompt_toolkit.lexers import PygmentsLexer
@@ -98,7 +99,7 @@ class MyCli(object):
xdg_config_home = "~/.config"
system_config_files = [
'/etc/myclirc',
- os.path.join(xdg_config_home, "mycli", "myclirc")
+ os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc")
]
default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc')
@@ -152,7 +153,7 @@ class MyCli(object):
c['main'].as_bool('auto_vertical_output')
# Write user config if system config wasn't the last config loaded.
- if c.filename not in self.system_config_files:
+ if c.filename not in self.system_config_files and not os.path.exists(myclirc):
write_default_config(self.default_config_file, myclirc)
# audit log
@@ -238,6 +239,9 @@ class MyCli(object):
)
return
+ if arg.startswith('`') and arg.endswith('`'):
+ arg = re.sub(r'^`(.*)`$', r'\1', arg)
+ arg = re.sub(r'``', r'`', arg)
self.sqlexecute.change_db(arg)
yield (None, None, None, 'You are now connected to database "%s" as '
@@ -363,7 +367,7 @@ class MyCli(object):
def connect(self, database='', user='', passwd='', host='', port='',
socket='', charset='', local_infile='', ssl='',
ssh_user='', ssh_host='', ssh_port='',
- ssh_password='', ssh_key_filename=''):
+ ssh_password='', ssh_key_filename='', init_command=''):
cnf = {'database': None,
'user': None,
@@ -387,16 +391,16 @@ class MyCli(object):
database = database or cnf['database']
# Socket interface not supported for SSH connections
- if port or host or ssh_host or ssh_port:
+ if port or (host and host != 'localhost') or (ssh_host and ssh_port):
socket = ''
else:
socket = socket or cnf['socket'] or guess_socket_location()
user = user or cnf['user'] or os.getenv('USER')
host = host or cnf['host']
- port = port or cnf['port']
+ port = int(port or cnf['port'] or 3306)
ssl = ssl or {}
- passwd = passwd or cnf['password']
+ passwd = passwd if isinstance(passwd, str) else cnf['password']
charset = charset or cnf['default-character-set'] or 'utf8'
# Favor whichever local_infile option is set.
@@ -420,7 +424,7 @@ class MyCli(object):
self.sqlexecute = SQLExecute(
database, user, passwd, host, port, socket, charset,
local_infile, ssl, ssh_user, ssh_host, ssh_port,
- ssh_password, ssh_key_filename
+ ssh_password, ssh_key_filename, init_command
)
except OperationalError as e:
if ('Access denied for user' in e.args[1]):
@@ -429,7 +433,7 @@ class MyCli(object):
self.sqlexecute = SQLExecute(
database, user, new_passwd, host, port, socket,
charset, local_infile, ssl, ssh_user, ssh_host,
- ssh_port, ssh_password, ssh_key_filename
+ ssh_port, ssh_password, ssh_key_filename, init_command
)
else:
raise e
@@ -438,7 +442,7 @@ class MyCli(object):
if not WIN and socket:
socket_owner = getpwuid(os.stat(socket).st_uid).pw_name
self.echo(
- f"Connecting to socket {socket}, owned by user {socket_owner}")
+ f"Connecting to socket {socket}, owned by user {socket_owner}", err=True)
try:
_connect()
except OperationalError as e:
@@ -481,7 +485,7 @@ class MyCli(object):
exit(1)
def handle_editor_command(self, text):
- """Editor command is any query that is prefixed or suffixed by a '\e'.
+ r"""Editor command is any query that is prefixed or suffixed by a '\e'.
The reason for a while loop is because a user might edit a query
multiple times. For eg:
@@ -511,6 +515,24 @@ class MyCli(object):
continue
return text
+ def handle_clip_command(self, text):
+ r"""A clip command is any query that is prefixed or suffixed by a
+ '\clip'.
+
+ :param text: Document
+ :return: Boolean
+
+ """
+
+ if special.clip_command(text):
+ query = (special.get_clip_query(text) or
+ self.get_last_query())
+ message = special.copy_query_to_clipboard(sql=query)
+ if message:
+ raise RuntimeError(message)
+ return True
+ return False
+
def run_cli(self):
iterations = 0
sqlexecute = self.sqlexecute
@@ -548,10 +570,13 @@ class MyCli(object):
prompt = self.get_prompt(self.prompt_format)
if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt:
prompt = self.get_prompt('\\d> ')
- return [('class:prompt', prompt)]
+ prompt = prompt.replace("\\x1b", "\x1b")
+ return ANSI(prompt)
def get_continuation(width, *_):
- if self.multiline_continuation_char:
+ if self.multiline_continuation_char == '':
+ continuation = ''
+ elif self.multiline_continuation_char:
left_padding = width - len(self.multiline_continuation_char)
continuation = " " * \
max((left_padding - 1), 0) + \
@@ -580,6 +605,15 @@ class MyCli(object):
self.echo(str(e), err=True, fg='red')
return
+ try:
+ if self.handle_clip_command(text):
+ return
+ except RuntimeError as e:
+ logger.error("sql: %r, error: %r", text, e)
+ logger.error("traceback: %r", traceback.format_exc())
+ self.echo(str(e), err=True, fg='red')
+ return
+
if not text.strip():
return
@@ -654,6 +688,7 @@ class MyCli(object):
result_count += 1
mutating = mutating or destroy or is_mutating(status)
special.unset_once_if_written()
+ special.unset_pipe_once_if_written()
except EOFError as e:
raise e
except KeyboardInterrupt:
@@ -814,6 +849,7 @@ class MyCli(object):
self.log_output(line)
special.write_tee(line)
special.write_once(line)
+ special.write_pipe_once(line)
if fits or output_via_pager:
# buffering
@@ -1051,6 +1087,10 @@ class MyCli(object):
help='Read this path from the login file.')
@click.option('-e', '--execute', type=str,
help='Execute command and quit.')
+@click.option('--init-command', type=str,
+ help='SQL statement to execute after connecting.')
+@click.option('--charset', type=str,
+ help='Character set for MySQL session.')
@click.argument('database', default='', nargs=1)
def cli(database, user, host, port, socket, password, dbname,
version, verbose, prompt, logfile, defaults_group_suffix,
@@ -1058,7 +1098,8 @@ def cli(database, user, host, port, socket, password, dbname,
ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher,
ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn,
list_dsn, ssh_user, ssh_host, ssh_port, ssh_password,
- ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host):
+ ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host,
+ init_command, charset):
"""A MySQL terminal client with auto-completion and syntax highlighting.
\b
@@ -1182,7 +1223,9 @@ def cli(database, user, host, port, socket, password, dbname,
ssh_host=ssh_host,
ssh_port=ssh_port,
ssh_password=ssh_password,
- ssh_key_filename=ssh_key_filename
+ ssh_key_filename=ssh_key_filename,
+ init_command=init_command,
+ charset=charset
)
mycli.logger.debug('Launch Params: \n'
@@ -1217,14 +1260,15 @@ def cli(database, user, host, port, socket, password, dbname,
click.secho('Sorry... :(', err=True, fg='red')
exit(1)
- try:
- sys.stdin = open('/dev/tty')
- except (IOError, OSError):
- mycli.logger.warning('Unable to open TTY as stdin.')
+ if mycli.destructive_warning and is_destructive(stdin_text):
+ try:
+ sys.stdin = open('/dev/tty')
+ warn_confirmed = confirm_destructive_query(stdin_text)
+ except (IOError, OSError):
+ mycli.logger.warning('Unable to open TTY as stdin.')
+ if not warn_confirmed:
+ exit(0)
- if (mycli.destructive_warning and
- confirm_destructive_query(stdin_text) is False):
- exit(0)
try:
new_line = True
@@ -1287,7 +1331,7 @@ def is_select(status):
def thanks_picker(files=()):
contents = []
for line in fileinput.input(files=files):
- m = re.match('^ *\* (.*)', line)
+ m = re.match(r'^ *\* (.*)', line)
if m:
contents.append(m.group(1))
return choice(contents)