summaryrefslogtreecommitdiffstats
path: root/test/conftest.py
blob: e3878b1079749dd501a08eb4e90a78b95b9a959d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import io
import sys
import socket
import pytest


@pytest.fixture(scope='module')
def ssh_audit():
    """Return the ``ssh_audit.ssh_audit`` module under test.

    The import is deferred until a test first requests the fixture; with
    module scope the same module object is reused for the whole test module.
    """
    # pylint: disable=import-outside-toplevel
    import ssh_audit.ssh_audit as audit_module
    return audit_module


# pylint: disable=attribute-defined-outside-init
class _OutputSpy(list):
    """Capture everything written to ``sys.stdout`` between ``begin`` and ``flush``.

    The ``list`` base class is not used for storage: the captured lines are
    returned by ``flush`` rather than kept on the instance.
    """

    # Class-level defaults (name-mangled exactly like the instance attributes
    # set below) so the spy can tell whether a capture is currently active.
    __out = None
    __old_stdout = None

    def begin(self):
        """Start redirecting ``sys.stdout`` into a fresh in-memory buffer.

        Calling ``begin`` again before ``flush`` discards the output captured
        so far and starts over, but keeps the stream that was active before
        the first ``begin`` so ``flush`` can still restore it.
        """
        if self.__out is None:
            # Only remember the real stream.  Re-saving while a capture is
            # active would store our own buffer and lose the original stdout.
            self.__old_stdout = sys.stdout
        self.__out = io.StringIO()
        sys.stdout = self.__out

    def flush(self):
        """Stop capturing, restore the previous ``sys.stdout`` and return the output.

        Returns:
            list[str]: the captured text split into lines, line endings removed.

        Raises:
            RuntimeError: if no capture is active (``begin`` was never called,
                or ``flush`` was already called for it).
        """
        if self.__out is None:
            raise RuntimeError('output capture is not active; call begin() first')
        lines = self.__out.getvalue().splitlines()
        sys.stdout = self.__old_stdout
        self.__out = None
        self.__old_stdout = None
        return lines


@pytest.fixture(scope='module')
def output_spy():
    """Provide one ``_OutputSpy`` shared by every test in a module."""
    spy = _OutputSpy()
    return spy


class _VirtualGlobalSocket:
    """Replacement for the module-level ``socket`` functions.

    ``create_connection`` and ``socket`` both hand back the ``vsocket`` given
    to the constructor.  ``addrinfodata`` maps ``'host#port'`` keys to the
    value ``getaddrinfo`` should return for that pair, or to an exception
    instance it should raise.
    """

    def __init__(self, vsocket):
        self.vsocket = vsocket
        self.addrinfodata = {}

    # pylint: disable=unused-argument
    def create_connection(self, address, timeout=0, source_address=None,
                          *, all_errors=False):
        """Fake ``socket.create_connection``.

        Connects the virtual socket to *address* and returns what its
        ``_connect(address, True)`` returns.  ``timeout``, ``source_address``
        and the keyword-only ``all_errors`` (added to the stdlib in Python
        3.11) are accepted so stdlib-style calls work; they are ignored.
        """
        # pylint: disable=protected-access
        return self.vsocket._connect(address, True)

    # pylint: disable=unused-argument,redefined-builtin
    def socket(self,
               family=socket.AF_INET,
               socktype=socket.SOCK_STREAM,
               proto=0,
               fileno=None,
               *,
               type=None):
        """Fake ``socket.socket`` constructor; always returns the virtual socket.

        The stdlib calls the second parameter ``type``; it is accepted as a
        keyword in addition to the historical ``socktype`` name.  All
        arguments are ignored.
        """
        return self.vsocket

    # NOTE: from here on, the name ``socket`` inside this class body refers to
    # the method above, so default values below must not use the socket module.
    def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0,
                    *, type=None):
        """Fake ``socket.getaddrinfo``.

        Lookup order:

        1. ``addrinfodata['host#port']`` if present -- raised when it is an
           exception instance, otherwise returned unchanged;
        2. for ``'localhost'``, the IPv4 and/or IPv6 loopback entries allowed
           by *family* (0 selects both);
        3. otherwise an empty list.

        ``socktype`` (stdlib keyword: ``type``), ``proto`` and ``flags`` are
        accepted and ignored.
        """
        key = '{}#{}'.format(host, port)
        if key in self.addrinfodata:
            data = self.addrinfodata[key]
            if isinstance(data, Exception):
                raise data
            return data
        if host != 'localhost':
            return []
        # Entries are (family, 1, 6, canonname, sockaddr); 1 and 6 are the
        # usual numeric values of SOCK_STREAM and IPPROTO_TCP.
        # NOTE(review): the real getaddrinfo returns a 4-tuple sockaddr for
        # AF_INET6; the 2-tuple is kept because tests presumably rely on it.
        result = []
        if family in (0, socket.AF_INET):
            result.append((socket.AF_INET, 1, 6, '', ('127.0.0.1', port)))
        if family in (0, socket.AF_INET6):
            result.append((socket.AF_INET6, 1, 6, '', ('::1', port)))
        return result


class _VirtualSocket:
    """In-memory stand-in for a TCP socket, scripted by tests.

    Attributes a test typically drives:

    * ``rdata`` -- queue of items handed out by ``recv``; an ``Exception``
      instance in the queue is raised instead of returned.
    * ``sdata`` -- every payload passed to ``send``, in order.
    * ``errors`` -- maps a method name (``'connect'`` or ``'send'``) to an
      exception that method raises.
    * ``gsock`` -- the ``_VirtualGlobalSocket`` wrapping this socket.
    """

    def __init__(self):
        self.sock_address = ('127.0.0.1', 0)
        self.peer_address = None
        self._connected = False
        self.timeout = -1.0
        self.rdata = []
        self.sdata = []
        self.errors = {}
        self.gsock = _VirtualGlobalSocket(self)

    def _check_err(self, method):
        """Raise the exception registered in ``errors`` for *method*, if any."""
        method_error = self.errors.get(method)
        if method_error:
            raise method_error

    def connect(self, address):
        """Fake ``socket.connect``: connect to *address* and return ``None``."""
        return self._connect(address, False)

    def _connect(self, address, ret=True):
        """Connect to *address*, honouring any error registered for ``'connect'``.

        ``peer_address`` always records the attempted address so tests can
        inspect it.  The socket is marked connected only if no error is
        raised: a failed connect must not leave ``recv``, ``send`` or
        ``getpeername`` behaving as if the connection had succeeded.

        Returns ``self`` when *ret* is true, otherwise ``None``.
        """
        self.peer_address = address
        self._connected = False
        self._check_err('connect')
        self._connected = True
        return self if ret else None

    def settimeout(self, timeout):
        """Record *timeout*; it does not affect the fake's behaviour."""
        self.timeout = timeout

    def gettimeout(self):
        """Return the value last passed to ``settimeout``."""
        return self.timeout

    def getpeername(self):
        """Return the peer address.

        Raises:
            OSError: errno 57 ('Socket is not connected') when not connected.
        """
        if self.peer_address is None or not self._connected:
            raise OSError(57, 'Socket is not connected')
        return self.peer_address

    def getsockname(self):
        """Return the local address (set by ``bind``)."""
        return self.sock_address

    def bind(self, address):
        """Record *address* as the local address."""
        self.sock_address = address

    def listen(self, backlog):
        """No-op: the fake has no real listen queue."""

    def accept(self):
        """Return a new connected ``_VirtualSocket`` and its peer address."""
        # pylint: disable=protected-access
        conn = _VirtualSocket()
        conn.sock_address = self.sock_address
        conn.peer_address = ('127.0.0.1', 0)
        conn._connected = True
        return conn, conn.peer_address

    def recv(self, bufsize, flags=0):
        """Pop and return the next item queued in ``rdata``.

        *bufsize* and *flags* are ignored, so a queued chunk is returned whole.
        Returns ``b''`` (end of stream) once the queue is empty.  An exception
        instance in the queue is raised instead of returned.

        Raises:
            OSError: errno 54 ('Connection reset by peer') when not connected.
        """
        # pylint: disable=unused-argument
        if not self._connected:
            raise OSError(54, 'Connection reset by peer')
        if not self.rdata:
            return b''
        data = self.rdata.pop(0)
        if isinstance(data, Exception):
            raise data
        return data

    def send(self, data):
        """Record *data* in ``sdata`` and return its length, like ``socket.send``.

        Raises:
            OSError: errno 32 ('Broken pipe') when not connected, or the
                exception registered in ``errors['send']``.
        """
        if self.peer_address is None or not self._connected:
            raise OSError(32, 'Broken pipe')
        self._check_err('send')
        self.sdata.append(data)
        # socket.send returns the number of bytes sent; the fake always
        # "sends" the whole payload.
        return len(data)


@pytest.fixture()
def virtual_socket(monkeypatch):
    """Patch the ``socket`` module with in-memory fakes for a single test.

    ``socket.create_connection``, ``socket.socket`` and ``socket.getaddrinfo``
    are replaced by the same-named methods of the fake socket's ``gsock``
    (``monkeypatch`` reverts them after the test).  The returned
    ``_VirtualSocket`` is the object those fakes hand out, so tests can script
    its ``rdata``/``errors`` and inspect ``sdata``.
    """
    fake_sock = _VirtualSocket()
    for attr_name in ('create_connection', 'socket', 'getaddrinfo'):
        monkeypatch.setattr(socket, attr_name, getattr(fake_sock.gsock, attr_name))
    return fake_sock