summaryrefslogtreecommitdiffstats
path: root/test/conftest.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/conftest.py')
-rw-r--r--test/conftest.py145
1 files changed, 145 insertions, 0 deletions
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..e3878b1
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,145 @@
+import io
+import sys
+import socket
+import pytest
+
+
+@pytest.fixture(scope='module')
+def ssh_audit():
+ import ssh_audit.ssh_audit
+ return ssh_audit.ssh_audit
+
+
+# pylint: disable=attribute-defined-outside-init
+class _OutputSpy(list):
+ def begin(self):
+ self.__out = io.StringIO()
+ self.__old_stdout = sys.stdout
+ sys.stdout = self.__out
+
+ def flush(self):
+ lines = self.__out.getvalue().splitlines()
+ sys.stdout = self.__old_stdout
+ self.__out = None
+ return lines
+
+
+@pytest.fixture(scope='module')
+def output_spy():
+ return _OutputSpy()
+
+
+class _VirtualGlobalSocket:
+ def __init__(self, vsocket):
+ self.vsocket = vsocket
+ self.addrinfodata = {}
+
+ # pylint: disable=unused-argument
+ def create_connection(self, address, timeout=0, source_address=None):
+ # pylint: disable=protected-access
+ return self.vsocket._connect(address, True)
+
+ # pylint: disable=unused-argument
+ def socket(self,
+ family=socket.AF_INET,
+ socktype=socket.SOCK_STREAM,
+ proto=0,
+ fileno=None):
+ return self.vsocket
+
+ def getaddrinfo(self, host, port, family=0, socktype=0, proto=0, flags=0):
+ key = '{}#{}'.format(host, port)
+ if key in self.addrinfodata:
+ data = self.addrinfodata[key]
+ if isinstance(data, Exception):
+ raise data
+ return data
+ if host == 'localhost':
+ r = []
+ if family in (0, socket.AF_INET):
+ r.append((socket.AF_INET, 1, 6, '', ('127.0.0.1', port)))
+ if family in (0, socket.AF_INET6):
+ r.append((socket.AF_INET6, 1, 6, '', ('::1', port)))
+ return r
+ return []
+
+
+class _VirtualSocket:
+ def __init__(self):
+ self.sock_address = ('127.0.0.1', 0)
+ self.peer_address = None
+ self._connected = False
+ self.timeout = -1.0
+ self.rdata = []
+ self.sdata = []
+ self.errors = {}
+ self.gsock = _VirtualGlobalSocket(self)
+
+ def _check_err(self, method):
+ method_error = self.errors.get(method)
+ if method_error:
+ raise method_error
+
+ def connect(self, address):
+ return self._connect(address, False)
+
+ def _connect(self, address, ret=True):
+ self.peer_address = address
+ self._connected = True
+ self._check_err('connect')
+ return self if ret else None
+
+ def settimeout(self, timeout):
+ self.timeout = timeout
+
+ def gettimeout(self):
+ return self.timeout
+
+ def getpeername(self):
+ if self.peer_address is None or not self._connected:
+ raise OSError(57, 'Socket is not connected')
+ return self.peer_address
+
+ def getsockname(self):
+ return self.sock_address
+
+ def bind(self, address):
+ self.sock_address = address
+
+ def listen(self, backlog):
+ pass
+
+ def accept(self):
+ # pylint: disable=protected-access
+ conn = _VirtualSocket()
+ conn.sock_address = self.sock_address
+ conn.peer_address = ('127.0.0.1', 0)
+ conn._connected = True
+ return conn, conn.peer_address
+
+ def recv(self, bufsize, flags=0):
+ # pylint: disable=unused-argument
+ if not self._connected:
+ raise OSError(54, 'Connection reset by peer')
+ if not len(self.rdata) > 0:
+ return b''
+ data = self.rdata.pop(0)
+ if isinstance(data, Exception):
+ raise data
+ return data
+
+ def send(self, data):
+ if self.peer_address is None or not self._connected:
+ raise OSError(32, 'Broken pipe')
+ self._check_err('send')
+ self.sdata.append(data)
+
+
+@pytest.fixture()
+def virtual_socket(monkeypatch):
+ vsocket = _VirtualSocket()
+ gsock = vsocket.gsock
+ monkeypatch.setattr(socket, 'create_connection', gsock.create_connection)
+ monkeypatch.setattr(socket, 'socket', gsock.socket)
+ monkeypatch.setattr(socket, 'getaddrinfo', gsock.getaddrinfo)
+ return vsocket