summaryrefslogtreecommitdiffstats
path: root/tests/_util.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 17:47:36 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 17:47:36 +0000
commit3c7813683b1845959aca706eaa23f062a006356b (patch)
treeecba42f14f0c919d94332e2633d9b0e6834c9cec /tests/_util.py
parentInitial commit. (diff)
downloadparamiko-3c7813683b1845959aca706eaa23f062a006356b.tar.xz
paramiko-3c7813683b1845959aca706eaa23f062a006356b.zip
Adding upstream version 3.4.0.upstream/3.4.0upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/_util.py')
-rw-r--r--tests/_util.py468
1 files changed, 468 insertions, 0 deletions
diff --git a/tests/_util.py b/tests/_util.py
new file mode 100644
index 0000000..f0ae1d4
--- /dev/null
+++ b/tests/_util.py
@@ -0,0 +1,468 @@
+from contextlib import contextmanager
+from os.path import dirname, realpath, join
+import builtins
+import os
+from pathlib import Path
+import socket
+import struct
+import sys
+import unittest
+import time
+import threading
+
+import pytest
+
+from paramiko import (
+ ServerInterface,
+ RSAKey,
+ DSSKey,
+ AUTH_FAILED,
+ AUTH_PARTIALLY_SUCCESSFUL,
+ AUTH_SUCCESSFUL,
+ OPEN_SUCCEEDED,
+ OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED,
+ InteractiveQuery,
+ Transport,
+)
+from paramiko.ssh_gss import GSS_AUTH_AVAILABLE
+
+from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.asymmetric import padding, rsa
+
+tests_dir = dirname(realpath(__file__))
+
+from ._loop import LoopSocket
+
+
+def _support(filename):
+ base = Path(tests_dir)
+ top = base / filename
+ deeper = base / "_support" / filename
+ return str(deeper if deeper.exists() else top)
+
+
+def _config(name):
+ return join(tests_dir, "configs", name)
+
+
+needs_gssapi = pytest.mark.skipif(
+ not GSS_AUTH_AVAILABLE, reason="No GSSAPI to test"
+)
+
+
+def needs_builtin(name):
+ """
+ Skip decorated test if builtin name does not exist.
+ """
+ reason = "Test requires a builtin '{}'".format(name)
+ return pytest.mark.skipif(not hasattr(builtins, name), reason=reason)
+
+
+slow = pytest.mark.slow
+
+# GSSAPI / Kerberos related tests need a working Kerberos environment.
+# The class `KerberosTestCase` provides such an environment or skips all tests.
+# There are 3 distinct cases:
+#
+# - A Kerberos environment has already been created and the environment
+# contains the required information.
+#
+# - We can use the package 'k5test' to setup an working kerberos environment on
+# the fly.
+#
+# - We skip all tests.
+#
+# ToDo: add a Windows specific implementation?
+
+if (
+ os.environ.get("K5TEST_USER_PRINC", None)
+ and os.environ.get("K5TEST_HOSTNAME", None)
+ and os.environ.get("KRB5_KTNAME", None)
+): # add other vars as needed
+
+ # The environment provides the required information
+ class DummyK5Realm:
+ def __init__(self):
+ for k in os.environ:
+ if not k.startswith("K5TEST_"):
+ continue
+ setattr(self, k[7:].lower(), os.environ[k])
+ self.env = {}
+
+ class KerberosTestCase(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.realm = DummyK5Realm()
+
+ @classmethod
+ def tearDownClass(cls):
+ del cls.realm
+
+else:
+ try:
+ # Try to setup a kerberos environment
+ from k5test import KerberosTestCase
+ except Exception:
+ # Use a dummy, that skips all tests
+ class KerberosTestCase(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ raise unittest.SkipTest(
+ "Missing extension package k5test. "
+ 'Please run "pip install k5test" '
+ "to install it."
+ )
+
+
+def update_env(testcase, mapping, env=os.environ):
+ """Modify os.environ during a test case and restore during cleanup."""
+ saved_env = env.copy()
+
+ def replace(target, source):
+ target.update(source)
+ for k in list(target):
+ if k not in source:
+ target.pop(k, None)
+
+ testcase.addCleanup(replace, env, saved_env)
+ env.update(mapping)
+ return testcase
+
+
+def k5shell(args=None):
+ """Create a shell with an kerberos environment
+
+ This can be used to debug paramiko or to test the old GSSAPI.
+ To test a different GSSAPI, simply activate a suitable venv
+ within the shell.
+ """
+ import k5test
+ import atexit
+ import subprocess
+
+ k5 = k5test.K5Realm()
+ atexit.register(k5.stop)
+ os.environ.update(k5.env)
+ for n in ("realm", "user_princ", "hostname"):
+ os.environ["K5TEST_" + n.upper()] = getattr(k5, n)
+
+ if not args:
+ args = sys.argv[1:]
+ if not args:
+ args = [os.environ.get("SHELL", "bash")]
+ sys.exit(subprocess.call(args))
+
+
+def is_low_entropy():
+ """
+ Attempts to detect whether running interpreter is low-entropy.
+
+ "low-entropy" is defined as being in 32-bit mode and with the hash seed set
+ to zero.
+ """
+ is_32bit = struct.calcsize("P") == 32 / 8
+ # I don't see a way to tell internally if the hash seed was set this
+ # way, but env should be plenty sufficient, this is only for testing.
+ return is_32bit and os.environ.get("PYTHONHASHSEED", None) == "0"
+
+
+def sha1_signing_unsupported():
+ """
+ This is used to skip tests in environments where SHA-1 signing is
+ not supported by the backend.
+ """
+ private_key = rsa.generate_private_key(
+ public_exponent=65537, key_size=2048, backend=default_backend()
+ )
+ message = b"Some dummy text"
+ try:
+ private_key.sign(
+ message,
+ padding.PSS(
+ mgf=padding.MGF1(hashes.SHA1()),
+ salt_length=padding.PSS.MAX_LENGTH,
+ ),
+ hashes.SHA1(),
+ )
+ return False
+ except UnsupportedAlgorithm as e:
+ return e._reason is _Reasons.UNSUPPORTED_HASH
+
+
+requires_sha1_signing = unittest.skipIf(
+ sha1_signing_unsupported(), "SHA-1 signing not supported"
+)
+
+_disable_sha2 = dict(
+ disabled_algorithms=dict(keys=["rsa-sha2-256", "rsa-sha2-512"])
+)
+_disable_sha1 = dict(disabled_algorithms=dict(keys=["ssh-rsa"]))
+_disable_sha2_pubkey = dict(
+ disabled_algorithms=dict(pubkeys=["rsa-sha2-256", "rsa-sha2-512"])
+)
+_disable_sha1_pubkey = dict(disabled_algorithms=dict(pubkeys=["ssh-rsa"]))
+
+
+unicodey = "\u2022"
+
+
+class TestServer(ServerInterface):
+ paranoid_did_password = False
+ paranoid_did_public_key = False
+ # TODO: make this ed25519 or something else modern? (_is_ this used??)
+ paranoid_key = DSSKey.from_private_key_file(_support("dss.key"))
+
+ def __init__(self, allowed_keys=None):
+ self.allowed_keys = allowed_keys if allowed_keys is not None else []
+
+ def check_channel_request(self, kind, chanid):
+ if kind == "bogus":
+ return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+ return OPEN_SUCCEEDED
+
+ def check_channel_exec_request(self, channel, command):
+ if command != b"yes":
+ return False
+ return True
+
+ def check_channel_shell_request(self, channel):
+ return True
+
+ def check_global_request(self, kind, msg):
+ self._global_request = kind
+ # NOTE: for w/e reason, older impl of this returned False always, even
+ # tho that's only supposed to occur if the request cannot be served.
+ # For now, leaving that the default unless test supplies specific
+ # 'acceptable' request kind
+ return kind == "acceptable"
+
+ def check_channel_x11_request(
+ self,
+ channel,
+ single_connection,
+ auth_protocol,
+ auth_cookie,
+ screen_number,
+ ):
+ self._x11_single_connection = single_connection
+ self._x11_auth_protocol = auth_protocol
+ self._x11_auth_cookie = auth_cookie
+ self._x11_screen_number = screen_number
+ return True
+
+ def check_port_forward_request(self, addr, port):
+ self._listen = socket.socket()
+ self._listen.bind(("127.0.0.1", 0))
+ self._listen.listen(1)
+ return self._listen.getsockname()[1]
+
+ def cancel_port_forward_request(self, addr, port):
+ self._listen.close()
+ self._listen = None
+
+ def check_channel_direct_tcpip_request(self, chanid, origin, destination):
+ self._tcpip_dest = destination
+ return OPEN_SUCCEEDED
+
+ def get_allowed_auths(self, username):
+ if username == "slowdive":
+ return "publickey,password"
+ if username == "paranoid":
+ if (
+ not self.paranoid_did_password
+ and not self.paranoid_did_public_key
+ ):
+ return "publickey,password"
+ elif self.paranoid_did_password:
+ return "publickey"
+ else:
+ return "password"
+ if username == "commie":
+ return "keyboard-interactive"
+ if username == "utf8":
+ return "password"
+ if username == "non-utf8":
+ return "password"
+ return "publickey"
+
+ def check_auth_password(self, username, password):
+ if (username == "slowdive") and (password == "pygmalion"):
+ return AUTH_SUCCESSFUL
+ if (username == "paranoid") and (password == "paranoid"):
+ # 2-part auth (even openssh doesn't support this)
+ self.paranoid_did_password = True
+ if self.paranoid_did_public_key:
+ return AUTH_SUCCESSFUL
+ return AUTH_PARTIALLY_SUCCESSFUL
+ if (username == "utf8") and (password == unicodey):
+ return AUTH_SUCCESSFUL
+ if (username == "non-utf8") and (password == "\xff"):
+ return AUTH_SUCCESSFUL
+ if username == "bad-server":
+ raise Exception("Ack!")
+ if username == "unresponsive-server":
+ time.sleep(5)
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
+
+ def check_auth_publickey(self, username, key):
+ if (username == "paranoid") and (key == self.paranoid_key):
+ # 2-part auth
+ self.paranoid_did_public_key = True
+ if self.paranoid_did_password:
+ return AUTH_SUCCESSFUL
+ return AUTH_PARTIALLY_SUCCESSFUL
+ # TODO: make sure all tests incidentally using this to pass, _without
+ # sending a username oops_, get updated somehow - probably via server()
+ # default always injecting a username
+ elif key in self.allowed_keys:
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
+
+ def check_auth_interactive(self, username, submethods):
+ if username == "commie":
+ self.username = username
+ return InteractiveQuery(
+ "password", "Please enter a password.", ("Password", False)
+ )
+ return AUTH_FAILED
+
+ def check_auth_interactive_response(self, responses):
+ if self.username == "commie":
+ if (len(responses) == 1) and (responses[0] == "cat"):
+ return AUTH_SUCCESSFUL
+ return AUTH_FAILED
+
+
+@contextmanager
+def server(
+ hostkey=None,
+ init=None,
+ server_init=None,
+ client_init=None,
+ connect=None,
+ pubkeys=None,
+ catch_error=False,
+ transport_factory=None,
+ server_transport_factory=None,
+ defer=False,
+ skip_verify=False,
+):
+ """
+ SSH server contextmanager for testing.
+
+ Yields a tuple of ``(tc, ts)`` (client- and server-side `Transport`
+ objects), or ``(tc, ts, err)`` when ``catch_error==True``.
+
+ :param hostkey:
+ Host key to use for the server; if None, loads
+ ``rsa.key``.
+ :param init:
+ Default `Transport` constructor kwargs to use for both sides.
+ :param server_init:
+ Extends and/or overrides ``init`` for server transport only.
+ :param client_init:
+ Extends and/or overrides ``init`` for client transport only.
+ :param connect:
+ Kwargs to use for ``connect()`` on the client.
+ :param pubkeys:
+ List of public keys for auth.
+ :param catch_error:
+ Whether to capture connection errors & yield from contextmanager.
+ Necessary for connection_time exception testing.
+ :param transport_factory:
+ Like the same-named param in SSHClient: which Transport class to use.
+ :param server_transport_factory:
+ Like ``transport_factory``, but only impacts the server transport.
+ :param bool defer:
+ Whether to defer authentication during connecting.
+
+ This is really just shorthand for ``connect={}`` which would do roughly
+ the same thing. Also: this implies skip_verify=True automatically!
+ :param bool skip_verify:
+ Whether NOT to do the default "make sure auth passed" check.
+ """
+ if init is None:
+ init = {}
+ if server_init is None:
+ server_init = {}
+ if client_init is None:
+ client_init = {}
+ if connect is None:
+ # No auth at all please
+ if defer:
+ connect = dict()
+ # Default username based auth
+ else:
+ connect = dict(username="slowdive", password="pygmalion")
+ socks = LoopSocket()
+ sockc = LoopSocket()
+ sockc.link(socks)
+ if transport_factory is None:
+ transport_factory = Transport
+ if server_transport_factory is None:
+ server_transport_factory = transport_factory
+ tc = transport_factory(sockc, **dict(init, **client_init))
+ ts = server_transport_factory(socks, **dict(init, **server_init))
+
+ if hostkey is None:
+ hostkey = RSAKey.from_private_key_file(_support("rsa.key"))
+ ts.add_server_key(hostkey)
+ event = threading.Event()
+ server = TestServer(allowed_keys=pubkeys)
+ assert not event.is_set()
+ assert not ts.is_active()
+ assert tc.get_username() is None
+ assert ts.get_username() is None
+ assert not tc.is_authenticated()
+ assert not ts.is_authenticated()
+
+ err = None
+ # Trap errors and yield instead of raising right away; otherwise callers
+ # cannot usefully deal with problems at connect time which stem from errors
+ # in the server side.
+ try:
+ ts.start_server(event, server)
+ tc.connect(**connect)
+
+ event.wait(1.0)
+ assert event.is_set()
+ assert ts.is_active()
+ assert tc.is_active()
+
+ except Exception as e:
+ if not catch_error:
+ raise
+ err = e
+
+ yield (tc, ts, err) if catch_error else (tc, ts)
+
+ if not (catch_error or skip_verify or defer):
+ assert ts.is_authenticated()
+ assert tc.is_authenticated()
+
+ tc.close()
+ ts.close()
+ socks.close()
+ sockc.close()
+
+
+def wait_until(condition, *, timeout=2):
+ """
+ Wait until `condition()` no longer raises an `AssertionError` or until
+ `timeout` seconds have passed, which causes a `TimeoutError` to be raised.
+ """
+ deadline = time.time() + timeout
+
+ while True:
+ try:
+ condition()
+ except AssertionError as e:
+ if time.time() > deadline:
+ timeout_message = f"Condition not reached after {timeout}s"
+ raise TimeoutError(timeout_message) from e
+ else:
+ return
+ time.sleep(0.01)