diff options
Diffstat (limited to '')
-rw-r--r-- | litecli/sqlexecute.py | 205 |
1 files changed, 205 insertions, 0 deletions
diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py new file mode 100644 index 0000000..2392472 --- /dev/null +++ b/litecli/sqlexecute.py @@ -0,0 +1,205 @@ +import logging +import sqlite3 +from contextlib import closing +from sqlite3 import OperationalError +from litecli.packages.special.utils import check_if_sqlitedotcommand + +import sqlparse +import os.path + +from .packages import special + +_logger = logging.getLogger(__name__) + +# FIELD_TYPES = decoders.copy() +# FIELD_TYPES.update({ +# FIELD_TYPE.NULL: type(None) +# }) + + +class SQLExecute(object): + databases_query = """ + PRAGMA database_list + """ + + tables_query = """ + SELECT name + FROM sqlite_master + WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + + table_columns_query = """ + SELECT m.name as tableName, p.name as columnName + FROM sqlite_master m + LEFT OUTER JOIN pragma_table_info((m.name)) p ON m.name <> p.name + WHERE m.type IN ('table','view') AND m.name NOT LIKE 'sqlite_%' + ORDER BY tableName, columnName + """ + + indexes_query = """ + SELECT name + FROM sqlite_master + WHERE type = 'index' AND name NOT LIKE 'sqlite_%' + ORDER BY 1 + """ + + functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES + WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' + + def __init__(self, database): + self.dbname = database + self._server_type = None + self.conn = None + if not database: + _logger.debug("Database is not specified. Skip connection.") + return + self.connect() + + def connect(self, database=None): + db = database or self.dbname + _logger.debug("Connection DB Params: \n" "\tdatabase: %r", database) + + db_name = os.path.expanduser(db) + db_dir_name = os.path.dirname(os.path.abspath(db_name)) + if not os.path.exists(db_dir_name): + raise Exception("Path does not exist: {}".format(db_dir_name)) + + conn = sqlite3.connect(database=db_name, isolation_level=None) + conn.text_factory = lambda x: x.decode("utf-8", "backslashreplace") + if self.conn: + self.conn.close() + + self.conn = conn + # Update them after the connection is made to ensure that it was a + # successful connection. + self.dbname = db + + def run(self, statement): + """Execute the sql in the database and return the results. The results + are a list of tuples. Each tuple has 4 values + (title, rows, headers, status). + """ + # Remove spaces and EOL + statement = statement.strip() + if not statement: # Empty string + yield (None, None, None, None) + + # Split the sql into separate queries and run each one. + # Unless it's saving a favorite query, in which case we + # want to save them all together. + if statement.startswith("\\fs"): + components = [statement] + else: + components = sqlparse.split(statement) + + for sql in components: + # Remove spaces, eol and semi-colons. + sql = sql.rstrip(";") + + # \G is treated specially since we have to set the expanded output. + if sql.endswith("\\G"): + special.set_expanded_output(True) + sql = sql[:-2].strip() + + if not self.conn and not ( + sql.startswith(".open") + or sql.lower().startswith("use") + or sql.startswith("\\u") + or sql.startswith("\\?") + or sql.startswith("\\q") + or sql.startswith("help") + or sql.startswith("exit") + or sql.startswith("quit") + ): + _logger.debug( + "Not connected to database. Will not run statement: %s.", sql + ) + raise OperationalError("Not connected to database.") + # yield ('Not connected to database', None, None, None) + # return + + cur = self.conn.cursor() if self.conn else None + try: # Special command + _logger.debug("Trying a dbspecial command. sql: %r", sql) + for result in special.execute(cur, sql): + yield result + except special.CommandNotFound: # Regular SQL + if check_if_sqlitedotcommand(sql): + yield ("dot command not implemented", None, None, None) + else: + _logger.debug("Regular sql statement. sql: %r", sql) + cur.execute(sql) + yield self.get_result(cur) + + def get_result(self, cursor): + """Get the current result's data from the cursor.""" + title = headers = None + + # cursor.description is not None for queries that return result sets, + # e.g. SELECT. + if cursor.description is not None: + headers = [x[0] for x in cursor.description] + status = "{0} row{1} in set" + cursor = list(cursor) + rowcount = len(cursor) + else: + _logger.debug("No rows in result.") + status = "Query OK, {0} row{1} affected" + rowcount = 0 if cursor.rowcount == -1 else cursor.rowcount + cursor = None + + status = status.format(rowcount, "" if rowcount == 1 else "s") + + return (title, cursor, headers, status) + + def tables(self): + """Yields table names""" + + with closing(self.conn.cursor()) as cur: + _logger.debug("Tables Query. sql: %r", self.tables_query) + cur.execute(self.tables_query) + for row in cur: + yield row + + def table_columns(self): + """Yields column names""" + with closing(self.conn.cursor()) as cur: + _logger.debug("Columns Query. sql: %r", self.table_columns_query) + cur.execute(self.table_columns_query) + for row in cur: + yield row + + def databases(self): + if not self.conn: + return + + with closing(self.conn.cursor()) as cur: + _logger.debug("Databases Query. sql: %r", self.databases_query) + for row in cur.execute(self.databases_query): + yield row[1] + + def functions(self): + """Yields tuples of (schema_name, function_name)""" + + with closing(self.conn.cursor()) as cur: + _logger.debug("Functions Query. sql: %r", self.functions_query) + cur.execute(self.functions_query % self.dbname) + for row in cur: + yield row + + def show_candidates(self): + with closing(self.conn.cursor()) as cur: + _logger.debug("Show Query. sql: %r", self.show_candidates_query) + try: + cur.execute(self.show_candidates_query) + except sqlite3.DatabaseError as e: + _logger.error("No show completions due to %r", e) + yield "" + else: + for row in cur: + yield (row[0].split(None, 1)[-1],) + + def server_type(self): + self._server_type = ("sqlite3", "3") + return self._server_type |