diff options
Diffstat (limited to '')
-rw-r--r-- | test/conftest.py | 145 |
1 files changed, 145 insertions, 0 deletions
diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..e3878b1 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,145 @@ +import io +import sys +import socket +import pytest + + +@pytest.fixture(scope='module') +def ssh_audit(): + import ssh_audit.ssh_audit + return ssh_audit.ssh_audit + + +# pylint: disable=attribute-defined-outside-init +class _OutputSpy(list): + def begin(self): + self.__out = io.StringIO() + self.__old_stdout = sys.stdout + sys.stdout = self.__out + + def flush(self): + lines = self.__out.getvalue().splitlines() + sys.stdout = self.__old_stdout + self.__out = None + return lines + + +@pytest.fixture(scope='module') +def output_spy(): + return _OutputSpy() + + +class _VirtualGlobalSocket: + def __init__(self, vsocket): + self.vsocket = vsocket + self.addrinfodata = {} + + # pylint: disable=unused-argument + def create_connection(self, address, timeout=0, source_address=None): + # pylint: disable=protected-access + return self.vsocket._connect(address, True) + + # pylint: disable=unused-argument + def socket(self, + family=socket.AF_INET, + socktype=socket.SOCK_STREAM, + proto=0, + fileno=None): + return self.vsocket + + def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0): + key = '{}#{}'.format(host, port) + if key in self.addrinfodata: + data = self.addrinfodata[key] + if isinstance(data, Exception): + raise data + return data + if host == 'localhost': + r = [] + if family in (0, socket.AF_INET): + r.append((socket.AF_INET, 1, 6, '', ('127.0.0.1', port))) + if family in (0, socket.AF_INET6): + r.append((socket.AF_INET6, 1, 6, '', ('::1', port))) + return r + return [] + + +class _VirtualSocket: + def __init__(self): + self.sock_address = ('127.0.0.1', 0) + self.peer_address = None + self._connected = False + self.timeout = -1.0 + self.rdata = [] + self.sdata = [] + self.errors = {} + self.gsock = _VirtualGlobalSocket(self) + + def _check_err(self, method): + method_error = self.errors.get(method) + if method_error: + raise method_error + + def connect(self, address): + return self._connect(address, False) + + def _connect(self, address, ret=True): + self.peer_address = address + self._connected = True + self._check_err('connect') + return self if ret else None + + def settimeout(self, timeout): + self.timeout = timeout + + def gettimeout(self): + return self.timeout + + def getpeername(self): + if self.peer_address is None or not self._connected: + raise OSError(57, 'Socket is not connected') + return self.peer_address + + def getsockname(self): + return self.sock_address + + def bind(self, address): + self.sock_address = address + + def listen(self, backlog): + pass + + def accept(self): + # pylint: disable=protected-access + conn = _VirtualSocket() + conn.sock_address = self.sock_address + conn.peer_address = ('127.0.0.1', 0) + conn._connected = True + return conn, conn.peer_address + + def recv(self, bufsize, flags=0): + # pylint: disable=unused-argument + if not self._connected: + raise OSError(54, 'Connection reset by peer') + if not len(self.rdata) > 0: + return b'' + data = self.rdata.pop(0) + if isinstance(data, Exception): + raise data + return data + + def send(self, data): + if self.peer_address is None or not self._connected: + raise OSError(32, 'Broken pipe') + self._check_err('send') + self.sdata.append(data) + + +@pytest.fixture() +def virtual_socket(monkeypatch): + vsocket = _VirtualSocket() + gsock = vsocket.gsock + monkeypatch.setattr(socket, 'create_connection', gsock.create_connection) + monkeypatch.setattr(socket, 'socket', gsock.socket) + monkeypatch.setattr(socket, 'getaddrinfo', gsock.getaddrinfo) + return vsocket |