summaryrefslogtreecommitdiffstats
path: root/test/features/steps/wrappers.py
blob: de833dd23edf467105e672d0f36168e8bcd4198f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import re
import pexpect
import sys
import textwrap

try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO


# Matches ANSI CSI escape sequences: ESC '[', parameter bytes (0x30-0x3F),
# intermediate bytes (0x20-0x2F), then a single final byte (0x40-0x7E).
# Covers colour codes such as "\x1b[1;31m" and erase-line "\x1b[K" without
# consuming the ordinary text that follows the sequence.
ANSI_ESCAPE_RE = re.compile(r'\x1b\[[0-?]*[ -/]*[@-~]')


def expect_exact(context, expected, timeout):
    """Wait until *expected* appears verbatim in the CLI output.

    :param context: behave context holding ``cli`` (the pexpect child) and
        ``logfile`` (the StringIO the child's session is logged to).
    :param expected: text passed to ``context.cli.expect_exact``.
    :param timeout: seconds to wait before giving up.
    :raises Exception: on pexpect TIMEOUT. The message shows the expected
        text, the output seen so far with colour codes stripped, and the
        full session log.
    """
    try:
        context.cli.expect_exact(expected, timeout=timeout)
    except pexpect.exceptions.TIMEOUT:
        # Strip color codes out of the output. ``before`` is normally the
        # buffered output after a timeout; guard against it being unset.
        actual = ANSI_ESCAPE_RE.sub('', context.cli.before or '')
        # ``from None`` keeps the traceback focused on this message, as the
        # original (which raised outside the handler) did.
        raise Exception(
            textwrap.dedent('''\
                Expected:
                ---
                {0!r}
                ---
                Actual:
                ---
                {1!r}
                ---
                Full log:
                ---
                {2!r}
                ---
            ''').format(
                expected,
                actual,
                context.logfile.getvalue()
            )
        ) from None


def expect_pager(context, expected, timeout):
    """Wait for *expected* to be printed between two pager boundary lines.

    The boundary marker is read from ``context.conf['pager_boundary']``.
    The awaited text is the boundary plus CRLF, then *expected*, then the
    boundary plus CRLF again.
    """
    boundary = context.conf['pager_boundary']
    framed_output = '{0}\r\n{1}{0}\r\n'.format(boundary, expected)
    expect_exact(context, framed_output, timeout=timeout)


def run_cli(context, run_args=None):
    """Run the process using pexpect.

    Builds the mycli command line from ``context.conf`` and spawns it,
    storing the child on ``context.cli``.

    :param context: behave context. Reads ``conf`` and ``package_root``.
        Sets ``cli``, ``logfile`` (StringIO capturing the session),
        ``exit_sent`` (False) and ``currentdb``.
    :param run_args: optional extra command-line arguments, placed before
        the connection options. The caller's list is never modified.
    """
    conf = context.conf
    # Copy so that adding connection options does not mutate the caller's
    # list (the previous ``run_args or []`` + ``extend`` did).
    args = list(run_args or [])
    # (conf key, CLI flag) pairs; a flag is emitted only for truthy values.
    connection_flags = (
        ('host', '-h'),
        ('user', '-u'),
        ('pass', '-p'),
        ('dbname', '-D'),
        ('defaults-file', '--defaults-file'),
        ('myclirc', '--myclirc'),
    )
    for key, flag in connection_flags:
        value = conf.get(key)
        if value:
            args.extend((flag, value))

    try:
        cli_cmd = conf['cli_command']
    except KeyError:
        # Default: run mycli in the current interpreter, calling
        # coverage.process_startup() first so the child can be measured.
        cli_cmd = (
            '{0!s} -c "'
            'import coverage ; '
            'coverage.process_startup(); '
            'import mycli.main; '
            'mycli.main.cli()'
            '"'
        ).format(sys.executable)

    # NOTE(review): arguments are joined with plain spaces and pexpect splits
    # the resulting string itself, so values containing whitespace or quotes
    # are not escaped here.
    cmd = ' '.join([cli_cmd] + args)
    context.logfile = StringIO()
    context.cli = pexpect.spawnu(
        cmd, cwd=context.package_root, logfile=context.logfile)
    context.exit_sent = False
    # ``.get`` matches the optional handling of 'dbname' above and avoids a
    # KeyError after the child process has already been spawned.
    context.currentdb = conf.get('dbname')


def wait_prompt(context, prompt=None, timeout=5):
    """Make sure prompt is displayed.

    :param context: behave context. Reads ``conf['user']``, ``conf['host']``
        and ``currentdb`` when *prompt* is not given. Sets ``atprompt`` to
        True once the prompt is seen.
    :param prompt: exact prompt text to wait for. Defaults to
        ``'<user>@<host>:<currentdb>>'``.
    :param timeout: seconds to wait (defaults to 5, the previous fixed value).
    :raises Exception: propagated from ``expect_exact`` if the prompt does
        not appear in time.
    """
    if prompt is None:
        # Build a plain string. The former trailing comma turned this into a
        # 1-tuple, which only matched because pexpect iterates pattern lists.
        prompt = '{0}@{1}:{2}>'.format(
            context.conf['user'], context.conf['host'], context.currentdb)
    expect_exact(context, prompt, timeout=timeout)
    context.atprompt = True