summaryrefslogtreecommitdiffstats
path: root/tests/utils.py
blob: 79d59e62624f801d52514f5bca165b641cc7a1d5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# -*- coding: utf-8 -*-

import os
import time
import signal
import platform
import multiprocessing
from contextlib import closing

import sqlite3
import pytest

from litecli.main import special

# Presumably the file name of an on-disk test database. It is not referenced
# anywhere in this module; NOTE(review): confirm which test modules use it.
DATABASE = "test.sqlite3"


def db_connection(dbname=":memory:"):
    """Open a new sqlite3 connection to ``dbname`` in autocommit mode.

    Passing ``isolation_level=None`` turns off the sqlite3 module's implicit
    transaction handling, so statements take effect as they are executed.

    :param dbname: database path, or ``":memory:"`` (the default) for a
        private in-memory database.
    :return: a fresh :class:`sqlite3.Connection`; closing it is the caller's job.
    """
    return sqlite3.connect(database=dbname, isolation_level=None)


# Probe once, at import time, whether an sqlite connection can be opened at
# all. The probe connection is closed right away: leaving it open leaks it, and
# Python 3.13+ emits a ResourceWarning for sqlite3 connections never closed.
try:
    db_connection().close()
    CAN_CONNECT_TO_DB = True
except Exception:
    # Deliberately broad: any failure just means DB-backed tests get skipped.
    CAN_CONNECT_TO_DB = False

# Marker for tests that need a working sqlite connection; they are skipped
# when the probe above failed.
dbtest = pytest.mark.skipif(not CAN_CONNECT_TO_DB, reason="Error creating sqlite connection")


def create_db(dbname):
    """Best-effort database setup hook for the test fixtures.

    ``dbname`` is accepted for interface compatibility but is not used. The
    statements run against a fresh in-memory connection from
    :func:`db_connection`. ``DROP DATABASE`` / ``CREATE DATABASE`` are not
    part of SQLite's SQL dialect, so SQLite rejects them with an
    ``sqlite3.Error``. That error is deliberately ignored, which means this
    function has no lasting effect in practice.

    :param dbname: unused.
    """
    # Close the connection as well as the cursor. Closing only the cursor
    # leaked the connection (Python 3.13+ warns about unclosed connections).
    with closing(db_connection()) as conn, closing(conn.cursor()) as cur:
        try:
            cur.execute("""DROP DATABASE IF EXISTS _test_db""")
            cur.execute("""CREATE DATABASE _test_db""")
        except sqlite3.Error:
            # Best effort only: SQLite has no DROP/CREATE DATABASE statements.
            pass


def drop_tables(dbname):
    """Best-effort database teardown hook for the test fixtures.

    ``dbname`` is accepted for interface compatibility but is not used. The
    statement runs against a fresh in-memory connection from
    :func:`db_connection`. ``DROP DATABASE`` is not part of SQLite's SQL
    dialect, so SQLite rejects it with an ``sqlite3.Error``. That error is
    deliberately ignored, which means this function has no lasting effect in
    practice.

    :param dbname: unused.
    """
    # Close the connection as well as the cursor. Closing only the cursor
    # leaked the connection (Python 3.13+ warns about unclosed connections).
    with closing(db_connection()) as conn, closing(conn.cursor()) as cur:
        try:
            cur.execute("""DROP DATABASE IF EXISTS _test_db""")
        except sqlite3.Error:
            # Best effort only: SQLite has no DROP DATABASE statement.
            pass


def run(executor, sql, rows_as_list=True):
    """Execute ``sql`` through ``executor`` and collect every result as a dict.

    ``executor.run(sql)`` is expected to yield ``(title, rows, headers,
    status)`` tuples. Each one becomes a dict with those four keys.

    :param executor: object whose ``run(sql)`` yields result tuples.
    :param sql: the SQL text to execute.
    :param rows_as_list: when true, a truthy ``rows`` value is materialized
        with ``list()``. Falsy values, such as ``None`` or an empty list, are
        kept as-is.
    :return: list of ``{"title", "rows", "headers", "status"}`` dicts, in order.
    """
    return [
        {
            "title": title,
            "rows": list(rows) if (rows_as_list and rows) else rows,
            "headers": headers,
            "status": status,
        }
        for title, rows, headers, status in executor.run(sql)
    ]


def set_expanded_output(is_expanded):
    """Forward ``is_expanded`` to litecli's ``special.set_expanded_output``.

    Thin test convenience wrapper. Whatever the underlying setter returns is
    handed back unchanged.
    """
    setter = special.set_expanded_output
    return setter(is_expanded)


def is_expanded_output():
    """Return the value of litecli's ``special.is_expanded_output()``.

    Thin test convenience wrapper around the ``special`` module's query.
    """
    getter = special.is_expanded_output
    return getter()


def send_ctrl_c_to_pid(pid, wait_seconds):
    """Sleep ``wait_seconds`` seconds, then deliver a Ctrl-C style signal to ``pid``.

    On Windows the signal is ``signal.CTRL_C_EVENT``; everywhere else it is
    ``signal.SIGINT``. The platform is checked only after the sleep.

    :param pid: process id to signal.
    :param wait_seconds: delay in seconds, passed to :func:`time.sleep`.
    """
    time.sleep(wait_seconds)
    # Only the selected branch is evaluated, so CTRL_C_EVENT (a Windows-only
    # attribute) is never looked up on other platforms.
    interrupt = signal.CTRL_C_EVENT if platform.system() == "Windows" else signal.SIGINT
    os.kill(pid, interrupt)


def send_ctrl_c(wait_seconds):
    """Start a helper process that interrupts the calling process later.

    The child runs :func:`send_ctrl_c_to_pid` against the PID of the process
    calling this function. A Ctrl-C style signal therefore arrives here after
    ``wait_seconds`` seconds.

    :param wait_seconds: delay in seconds before the signal is sent.
    :return: the already-started :class:`multiprocessing.Process`.
    """
    own_pid = os.getpid()
    helper = multiprocessing.Process(target=send_ctrl_c_to_pid, args=(own_pid, wait_seconds))
    helper.start()
    return helper


def assert_result_equal(
    result,
    title=None,
    rows=None,
    headers=None,
    status=None,
    auto_status=True,
    assert_contains=False,
):
    """Check an ``sqlexecute.run()``-style result against expected values.

    :param result: list of result dicts (keys ``title``, ``rows``,
        ``headers``, ``status``).
    :param title: expected title.
    :param rows: expected rows.
    :param headers: expected headers.
    :param status: expected status. When omitted, ``auto_status`` is true and
        ``rows`` is non-empty, it defaults to ``"N row(s) in set"``.
    :param auto_status: whether to synthesize the status from ``rows``.
    :param assert_contains: when true, each truthy expected field must be
        ``in`` the matching field of ``result[0]``. Otherwise ``result`` must
        equal a one-element list holding exactly the expected dict.
    :raises AssertionError: on any mismatch.
    """
    if status is None and auto_status and rows:
        row_count = len(rows)
        plural = "s" if row_count > 1 else ""
        status = "{} row{} in set".format(row_count, plural)

    expected = {"title": title, "rows": rows, "headers": headers, "status": status}

    if not assert_contains:
        # Strict mode: the whole result must be exactly one matching dict.
        assert result == [expected]
        return

    # Loose mode: result[0] is only indexed once a truthy field needs checking.
    for name, wanted in expected.items():
        if not wanted:
            continue
        assert wanted in result[0][name]