summaryrefslogtreecommitdiffstats
path: root/pgspecial/iocommands.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--pgspecial/iocommands.py342
1 files changed, 342 insertions, 0 deletions
diff --git a/pgspecial/iocommands.py b/pgspecial/iocommands.py
new file mode 100644
index 0000000..1a4bc8f
--- /dev/null
+++ b/pgspecial/iocommands.py
@@ -0,0 +1,342 @@
+from __future__ import unicode_literals
+import re
+import sys
+import logging
+import click
+import io
+import shlex
+import sqlparse
+import psycopg
+from os.path import expanduser
+from .namedqueries import NamedQueries
+from . import export
+from .main import show_extra_help_command, special_command
+
+NAMED_QUERY_PLACEHOLDERS = frozenset({"$1", "$*", "$@"})
+
+DEFAULT_WATCH_SECONDS = 2
+
+_logger = logging.getLogger(__name__)
+
+
+@export
+def editor_command(command):
+ """
+ Is this an external editor command? (\\e or \\ev)
+
+ :param command: string
+
+ Returns the specific external editor command found.
+ """
+ # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
+ # for both conditions.
+
+ stripped = command.strip()
+ for sought in ("\\e ", "\\ev ", "\\ef "):
+ if stripped.startswith(sought):
+ return sought.strip()
+ for sought in ("\\e",):
+ if stripped.endswith(sought):
+ return sought
+
+
+@export
+def get_filename(sql):
+ if sql.strip().startswith("\\e"):
+ command, _, filename = sql.partition(" ")
+ return filename.strip() or None
+
+
+@export
+@show_extra_help_command(
+ "\\watch",
+ f"\\watch [sec={DEFAULT_WATCH_SECONDS}]",
+ "Execute query every `sec` seconds.",
+)
+def get_watch_command(command):
+ match = re.match(r"(.*?)[\s]*\\watch(\s+\d+)?\s*;?\s*$", command, re.DOTALL)
+ if match:
+ groups = match.groups(default=f"{DEFAULT_WATCH_SECONDS}")
+ return groups[0], int(groups[1])
+ return None, None
+
+
+@export
+def get_editor_query(sql):
+ """Get the query part of an editor command."""
+ sql = sql.strip()
+
+ # The reason we can't simply do .strip('\e') is that it strips characters,
+ # not a substring. So it'll strip "e" in the end of the sql also!
+ # Ex: "select * from style\e" -> "select * from styl".
+ pattern = re.compile(r"(^\\e|\\e$)")
+ while pattern.search(sql):
+ sql = pattern.sub("", sql)
+
+ return sql
+
+
+@export
+def open_external_editor(filename=None, sql=None, editor=None):
+ """
+ Open external editor, wait for the user to type in his query,
+ return the query.
+ :return: list with one tuple, query as first element.
+ """
+
+ message = None
+ filename = filename.strip().split(" ", 1)[0] if filename else None
+
+ sql = sql or ""
+ MARKER = "# Type your query above this line.\n"
+
+ # Populate the editor buffer with the partial sql (if available) and a
+ # placeholder comment.
+ query = click.edit(
+ "{sql}\n\n{marker}".format(sql=sql, marker=MARKER),
+ filename=filename,
+ extension=".sql",
+ editor=editor,
+ )
+
+ if filename:
+ try:
+ query = read_from_file(filename)
+ except IOError:
+ message = "Error reading file: %s." % filename
+
+ if query is not None:
+ query = query.split(MARKER, 1)[0].rstrip("\n")
+ else:
+ # Don't return None for the caller to deal with.
+ # Empty string is ok.
+ query = sql
+
+ return (query, message)
+
+
+def read_from_file(path):
+ with io.open(expanduser(path), encoding="utf-8") as f:
+ contents = f.read()
+ return contents
+
+
+def _index_of_file_name(tokenlist):
+ for idx, token in reversed(list(enumerate(tokenlist[:-2]))):
+ if token.is_keyword and token.value.upper() in ("TO", "FROM"):
+ return idx + 2
+ raise Exception("Missing keyword in \\copy command. Either TO or FROM is required.")
+
+
+@special_command(
+ "\\copy",
+ "\\copy [tablename] to/from [filename]",
+ "Copy data between a file and a table.",
+)
+def copy(cur, pattern, verbose):
+ """Copies table data to/from files"""
+
+ # Replace the specified file destination with STDIN or STDOUT
+ parsed = sqlparse.parse(pattern)
+ tokens = parsed[0].tokens
+ idx = _index_of_file_name(tokens)
+ file_name = tokens[idx].value
+ before_file_name = "".join(t.value for t in tokens[:idx])
+ after_file_name = "".join(t.value for t in tokens[idx + 1 :])
+
+ direction = tokens[idx - 2].value.upper()
+ replacement_file_name = "STDIN" if direction == "FROM" else "STDOUT"
+ query = f"{before_file_name} {replacement_file_name} {after_file_name}"
+ open_mode = "r" if direction == "FROM" else "wb"
+ if file_name.startswith("'") and file_name.endswith("'"):
+ file = io.open(expanduser(file_name.strip("'")), mode=open_mode)
+ elif "stdin" in file_name.lower():
+ file = sys.stdin.buffer
+ elif "stdout" in file_name.lower():
+ file = sys.stdout.buffer
+ else:
+ raise Exception("Enclose filename in single quotes")
+
+ if direction == "FROM":
+ with cur.copy("copy " + query) as pgcopy:
+ while True:
+ data = file.read(8192)
+ if not data:
+ break
+ pgcopy.write(data)
+ else:
+ with cur.copy("copy " + query) as pgcopy:
+ for data in pgcopy:
+ file.write(bytes(data))
+
+ if cur.description:
+ headers = [x.name for x in cur.description]
+ return [(None, None, headers, cur.statusmessage)]
+ else:
+ return [(None, None, None, cur.statusmessage)]
+
+
+def subst_favorite_query_args(query, args):
+ """replace positional parameters ($1,$2,...$n) in query."""
+ is_query_with_aggregation = ("$*" in query) or ("$@" in query)
+
+ # In case of arguments aggregation we replace all positional arguments until the
+ # first one not present in the query. Then we aggregate all the remaining ones and
+ # replace the placeholder with them.
+ for idx, val in enumerate(args, start=1):
+ subst_var = "$" + str(idx)
+ if subst_var not in query:
+ if is_query_with_aggregation:
+ # remove consumed arguments ( - 1 to include current value)
+ args = args[idx - 1 :]
+ break
+
+ return [
+ None,
+ "query does not have substitution parameter "
+ + subst_var
+ + ":\n "
+ + query,
+ ]
+
+ query = query.replace(subst_var, val)
+ # we consumed all arguments
+ else:
+ args = []
+
+ if is_query_with_aggregation and not args:
+ return [None, "missing substitution for $* or $@ in query:\n" + query]
+
+ if "$*" in query:
+ query = query.replace("$*", ", ".join(args))
+ elif "$@" in query:
+ query = query.replace("$@", ", ".join(map("'{}'".format, args)))
+
+ match = re.search("\\$\\d+", query)
+ if match:
+ return [
+ None,
+ "missing substitution for " + match.group(0) + " in query:\n " + query,
+ ]
+
+ return [query, None]
+
+
+@special_command(
+ "\\n", "\\n[+] [name] [param1 param2 ...]", "List or execute named queries."
+)
+def execute_named_query(cur, pattern, **_):
+ """Returns (title, rows, headers, status)"""
+ if pattern == "":
+ return list_named_queries(True)
+
+ params = shlex.split(pattern)
+ pattern = params.pop(0)
+
+ query = NamedQueries.instance.get(pattern)
+ title = "> {}".format(query)
+ if query is None:
+ message = "No named query: {}".format(pattern)
+ return [(None, None, None, message)]
+
+ try:
+ if any(p in query for p in NAMED_QUERY_PLACEHOLDERS):
+ query, params = subst_favorite_query_args(query, params)
+ if query is None:
+ raise Exception("Bad arguments\n" + params)
+ cur.execute(query)
+ except psycopg.errors.SyntaxError:
+ if "%s" in query:
+ raise Exception(
+ "Bad arguments: "
+ 'please use "$1", "$2", etc. for named queries instead of "%s"'
+ )
+ else:
+ raise
+ except (IndexError, TypeError):
+ raise Exception("Bad arguments")
+
+ if cur.description:
+ headers = [x.name for x in cur.description]
+ return [(title, cur, headers, cur.statusmessage)]
+ else:
+ return [(title, None, None, cur.statusmessage)]
+
+
+def list_named_queries(verbose):
+ """List of all named queries.
+ Returns (title, rows, headers, status)"""
+ if not verbose:
+ rows = [[r] for r in NamedQueries.instance.list()]
+ headers = ["Name"]
+ else:
+ headers = ["Name", "Query"]
+ rows = [[r, NamedQueries.instance.get(r)] for r in NamedQueries.instance.list()]
+
+ if not rows:
+ status = NamedQueries.instance.usage
+ else:
+ status = ""
+ return [("", rows, headers, status)]
+
+
+@special_command("\\np", "\\np name_pattern", "Print a named query.")
+def get_named_query(pattern, **_):
+ """Get a named query that matches name_pattern.
+
+ The named pattern can be a regular expression. Returns (title,
+ rows, headers, status)
+
+ """
+
+ usage = "Syntax: \\np name.\n\n" + NamedQueries.instance.usage
+ if not pattern:
+ return [(None, None, None, usage)]
+
+ name = pattern.strip()
+ if not name:
+ return [(None, None, None, usage + "Err: A name is required.")]
+
+ headers = ["Name", "Query"]
+ rows = [
+ (r, NamedQueries.instance.get(r))
+ for r in NamedQueries.instance.list()
+ if re.search(name, r)
+ ]
+
+ status = ""
+ if not rows:
+ status = "No match found"
+
+ return [("", rows, headers, status)]
+
+
+@special_command("\\ns", "\\ns name query", "Save a named query.")
+def save_named_query(pattern, **_):
+ """Save a new named query.
+ Returns (title, rows, headers, status)"""
+
+ usage = "Syntax: \\ns name query.\n\n" + NamedQueries.instance.usage
+ if not pattern:
+ return [(None, None, None, usage)]
+
+ name, _, query = pattern.partition(" ")
+
+ # If either name or query is missing then print the usage and complain.
+ if (not name) or (not query):
+ return [(None, None, None, usage + "Err: Both name and query are required.")]
+
+ NamedQueries.instance.save(name, query)
+ return [(None, None, None, "Saved.")]
+
+
+@special_command("\\nd", "\\nd [name]", "Delete a named query.")
+def delete_named_query(pattern, **_):
+ """Delete an existing named query."""
+ usage = "Syntax: \\nd name.\n\n" + NamedQueries.instance.usage
+ if not pattern:
+ return [(None, None, None, usage)]
+
+ status = NamedQueries.instance.delete(pattern)
+
+ return [(None, None, None, status)]