diff options
Diffstat (limited to 'tests/test_transport.py')
-rw-r--r-- | tests/test_transport.py | 1446 |
1 files changed, 1446 insertions, 0 deletions
diff --git a/tests/test_transport.py b/tests/test_transport.py new file mode 100644 index 0000000..67e2eb4 --- /dev/null +++ b/tests/test_transport.py @@ -0,0 +1,1446 @@ +# Copyright (C) 2003-2009 Robey Pointer <robeypointer@gmail.com> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +""" +Some unit tests for the ssh2 protocol in Transport. +""" + + +from binascii import hexlify +import itertools +import select +import socket +import time +import threading +import random +import sys +import unittest +from unittest.mock import Mock + +from paramiko import ( + AuthHandler, + ChannelException, + IncompatiblePeer, + MessageOrderError, + Packetizer, + RSAKey, + SSHException, + SecurityOptions, + ServiceRequestingTransport, + Transport, +) +from paramiko.auth_handler import AuthOnlyHandler +from paramiko import OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED +from paramiko.common import ( + DEFAULT_MAX_PACKET_SIZE, + DEFAULT_WINDOW_SIZE, + MAX_WINDOW_SIZE, + MIN_PACKET_SIZE, + MIN_WINDOW_SIZE, + MSG_CHANNEL_OPEN, + MSG_DEBUG, + MSG_IGNORE, + MSG_KEXINIT, + MSG_UNIMPLEMENTED, + MSG_USERAUTH_SUCCESS, + byte_chr, + cMSG_CHANNEL_WINDOW_ADJUST, + cMSG_UNIMPLEMENTED, +) +from paramiko.message import Message + +from ._util import ( + needs_builtin, + _support, + requires_sha1_signing, + slow, + server, + _disable_sha2, + _disable_sha1, + TestServer as NullServer, +) +from ._loop import LoopSocket +from pytest import mark, raises + + +LONG_BANNER = """\ +Welcome to the super-fun-land BBS, where our MOTD is the primary thing we +provide. All rights reserved. Offer void in Tennessee. Stunt drivers were +used. Do not attempt at home. Some restrictions apply. + +Happy birthday to Commie the cat! + +Note: An SSH banner may eventually appear. + +Maybe. +""" + +# Faux 'packet type' we do not implement and are unlikely ever to (but which is +# technically "within spec" re RFC 4251 +MSG_FUGGEDABOUTIT = 253 + + +class TransportTest(unittest.TestCase): + # TODO: this can get nuked once ServiceRequestingTransport becomes the + # only Transport, as it has this baked in. + _auth_handler_class = AuthHandler + + def setUp(self): + self.socks = LoopSocket() + self.sockc = LoopSocket() + self.sockc.link(self.socks) + self.tc = Transport(self.sockc) + self.ts = Transport(self.socks) + + def tearDown(self): + self.tc.close() + self.ts.close() + self.socks.close() + self.sockc.close() + + # TODO: unify with newer contextmanager + def setup_test_server( + self, client_options=None, server_options=None, connect_kwargs=None + ): + host_key = RSAKey.from_private_key_file(_support("rsa.key")) + public_host_key = RSAKey(data=host_key.asbytes()) + self.ts.add_server_key(host_key) + + if client_options is not None: + client_options(self.tc.get_security_options()) + if server_options is not None: + server_options(self.ts.get_security_options()) + + event = threading.Event() + self.server = NullServer() + self.assertTrue(not event.is_set()) + self.ts.start_server(event, self.server) + if connect_kwargs is None: + connect_kwargs = dict( + hostkey=public_host_key, + username="slowdive", + password="pygmalion", + ) + self.tc.connect(**connect_kwargs) + event.wait(1.0) + self.assertTrue(event.is_set()) + self.assertTrue(self.ts.is_active()) + + def test_security_options(self): + o = self.tc.get_security_options() + self.assertEqual(type(o), SecurityOptions) + self.assertTrue(("aes256-cbc", "aes192-cbc") != o.ciphers) + o.ciphers = ("aes256-cbc", "aes192-cbc") + self.assertEqual(("aes256-cbc", "aes192-cbc"), o.ciphers) + try: + o.ciphers = ("aes256-cbc", "made-up-cipher") + self.assertTrue(False) + except ValueError: + pass + try: + o.ciphers = 23 + self.assertTrue(False) + except TypeError: + pass + + def testb_security_options_reset(self): + o = self.tc.get_security_options() + # should not throw any exceptions + o.ciphers = o.ciphers + o.digests = o.digests + o.key_types = o.key_types + o.kex = o.kex + o.compression = o.compression + + def test_compute_key(self): + self.tc.K = 123281095979686581523377256114209720774539068973101330872763622971399429481072519713536292772709507296759612401802191955568143056534122385270077606457721553469730659233569339356140085284052436697480759510519672848743794433460113118986816826624865291116513647975790797391795651716378444844877749505443714557929 # noqa + self.tc.H = b"\x0C\x83\x07\xCD\xE6\x85\x6F\xF3\x0B\xA9\x36\x84\xEB\x0F\x04\xC2\x52\x0E\x9E\xD3" # noqa + self.tc.session_id = self.tc.H + key = self.tc._compute_key("C", 32) + self.assertEqual( + b"207E66594CA87C44ECCBA3B3CD39FDDB378E6FDB0F97C54B2AA0CFBF900CD995", # noqa + hexlify(key).upper(), + ) + + def test_simple(self): + """ + verify that we can establish an ssh link with ourselves across the + loopback sockets. this is hardly "simple" but it's simpler than the + later tests. :) + """ + host_key = RSAKey.from_private_key_file(_support("rsa.key")) + public_host_key = RSAKey(data=host_key.asbytes()) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assertTrue(not event.is_set()) + self.assertEqual(None, self.tc.get_username()) + self.assertEqual(None, self.ts.get_username()) + self.assertEqual(False, self.tc.is_authenticated()) + self.assertEqual(False, self.ts.is_authenticated()) + self.ts.start_server(event, server) + self.tc.connect( + hostkey=public_host_key, username="slowdive", password="pygmalion" + ) + event.wait(1.0) + self.assertTrue(event.is_set()) + self.assertTrue(self.ts.is_active()) + self.assertEqual("slowdive", self.tc.get_username()) + self.assertEqual("slowdive", self.ts.get_username()) + self.assertEqual(True, self.tc.is_authenticated()) + self.assertEqual(True, self.ts.is_authenticated()) + + def test_long_banner(self): + """ + verify that a long banner doesn't mess up the handshake. + """ + host_key = RSAKey.from_private_key_file(_support("rsa.key")) + public_host_key = RSAKey(data=host_key.asbytes()) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assertTrue(not event.is_set()) + self.socks.send(LONG_BANNER) + self.ts.start_server(event, server) + self.tc.connect( + hostkey=public_host_key, username="slowdive", password="pygmalion" + ) + event.wait(1.0) + self.assertTrue(event.is_set()) + self.assertTrue(self.ts.is_active()) + + def test_special(self): + """ + verify that the client can demand odd handshake settings, and can + renegotiate keys in mid-stream. + """ + + def force_algorithms(options): + options.ciphers = ("aes256-cbc",) + options.digests = ("hmac-md5-96",) + + self.setup_test_server(client_options=force_algorithms) + self.assertEqual("aes256-cbc", self.tc.local_cipher) + self.assertEqual("aes256-cbc", self.tc.remote_cipher) + self.assertEqual(12, self.tc.packetizer.get_mac_size_out()) + self.assertEqual(12, self.tc.packetizer.get_mac_size_in()) + + self.tc.send_ignore(1024) + self.tc.renegotiate_keys() + self.ts.send_ignore(1024) + + @slow + def test_keepalive(self): + """ + verify that the keepalive will be sent. + """ + self.setup_test_server() + self.assertEqual(None, getattr(self.server, "_global_request", None)) + self.tc.set_keepalive(1) + time.sleep(2) + self.assertEqual("keepalive@lag.net", self.server._global_request) + + def test_exec_command(self): + """ + verify that exec_command() does something reasonable. + """ + self.setup_test_server() + + chan = self.tc.open_session() + schan = self.ts.accept(1.0) + try: + chan.exec_command( + b"command contains \xfc and is not a valid UTF-8 string" + ) + self.assertTrue(False) + except SSHException: + pass + + chan = self.tc.open_session() + chan.exec_command("yes") + schan = self.ts.accept(1.0) + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") + schan.close() + + f = chan.makefile() + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) + f = chan.makefile_stderr() + self.assertEqual("This is on stderr.\n", f.readline()) + self.assertEqual("", f.readline()) + + # now try it with combined stdout/stderr + chan = self.tc.open_session() + chan.exec_command("yes") + schan = self.ts.accept(1.0) + schan.send("Hello there.\n") + schan.send_stderr("This is on stderr.\n") + schan.close() + + chan.set_combine_stderr(True) + f = chan.makefile() + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("This is on stderr.\n", f.readline()) + self.assertEqual("", f.readline()) + + def test_channel_can_be_used_as_context_manager(self): + """ + verify that exec_command() does something reasonable. + """ + self.setup_test_server() + + with self.tc.open_session() as chan: + with self.ts.accept(1.0) as schan: + chan.exec_command("yes") + schan.send("Hello there.\n") + schan.close() + + f = chan.makefile() + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) + + def test_invoke_shell(self): + """ + verify that invoke_shell() does something reasonable. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.invoke_shell() + schan = self.ts.accept(1.0) + chan.send("communist j. cat\n") + f = schan.makefile() + self.assertEqual("communist j. cat\n", f.readline()) + chan.close() + self.assertEqual("", f.readline()) + + def test_channel_exception(self): + """ + verify that ChannelException is thrown for a bad open-channel request. + """ + self.setup_test_server() + try: + self.tc.open_channel("bogus") + self.fail("expected exception") + except ChannelException as e: + self.assertTrue(e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED) + + def test_exit_status(self): + """ + verify that get_exit_status() works. + """ + self.setup_test_server() + + chan = self.tc.open_session() + schan = self.ts.accept(1.0) + chan.exec_command("yes") + schan.send("Hello there.\n") + self.assertTrue(not chan.exit_status_ready()) + # trigger an EOF + schan.shutdown_read() + schan.shutdown_write() + schan.send_exit_status(23) + schan.close() + + f = chan.makefile() + self.assertEqual("Hello there.\n", f.readline()) + self.assertEqual("", f.readline()) + count = 0 + while not chan.exit_status_ready(): + time.sleep(0.1) + count += 1 + if count > 50: + raise Exception("timeout") + self.assertEqual(23, chan.recv_exit_status()) + chan.close() + + def test_select(self): + """ + verify that select() on a channel works. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.invoke_shell() + schan = self.ts.accept(1.0) + + # nothing should be ready + r, w, e = select.select([chan], [], [], 0.1) + self.assertEqual([], r) + self.assertEqual([], w) + self.assertEqual([], e) + + schan.send("hello\n") + + # something should be ready now (give it 1 second to appear) + for i in range(10): + r, w, e = select.select([chan], [], [], 0.1) + if chan in r: + break + time.sleep(0.1) + self.assertEqual([chan], r) + self.assertEqual([], w) + self.assertEqual([], e) + + self.assertEqual(b"hello\n", chan.recv(6)) + + # and, should be dead again now + r, w, e = select.select([chan], [], [], 0.1) + self.assertEqual([], r) + self.assertEqual([], w) + self.assertEqual([], e) + + schan.close() + + # detect eof? + for i in range(10): + r, w, e = select.select([chan], [], [], 0.1) + if chan in r: + break + time.sleep(0.1) + self.assertEqual([chan], r) + self.assertEqual([], w) + self.assertEqual([], e) + self.assertEqual(b"", chan.recv(16)) + + # make sure the pipe is still open for now... + p = chan._pipe + self.assertEqual(False, p._closed) + chan.close() + # ...and now is closed. + self.assertEqual(True, p._closed) + + def test_renegotiate(self): + """ + verify that a transport can correctly renegotiate mid-stream. + """ + self.setup_test_server() + self.tc.packetizer.REKEY_BYTES = 16384 + chan = self.tc.open_session() + chan.exec_command("yes") + schan = self.ts.accept(1.0) + + self.assertEqual(self.tc.H, self.tc.session_id) + for i in range(20): + chan.send("x" * 1024) + chan.close() + + # allow a few seconds for the rekeying to complete + for i in range(50): + if self.tc.H != self.tc.session_id: + break + time.sleep(0.1) + self.assertNotEqual(self.tc.H, self.tc.session_id) + + schan.close() + + def test_compression(self): + """ + verify that zlib compression is basically working. + """ + + def force_compression(o): + o.compression = ("zlib",) + + self.setup_test_server(force_compression, force_compression) + chan = self.tc.open_session() + chan.exec_command("yes") + schan = self.ts.accept(1.0) + + bytes = self.tc.packetizer._Packetizer__sent_bytes + chan.send("x" * 1024) + bytes2 = self.tc.packetizer._Packetizer__sent_bytes + block_size = self.tc._cipher_info[self.tc.local_cipher]["block-size"] + mac_size = self.tc._mac_info[self.tc.local_mac]["size"] + # tests show this is actually compressed to *52 bytes*! including + # packet overhead! nice!! :) + self.assertTrue(bytes2 - bytes < 1024) + self.assertEqual(16 + block_size + mac_size, bytes2 - bytes) + + chan.close() + schan.close() + + def test_x11(self): + """ + verify that an x11 port can be requested and opened. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command("yes") + schan = self.ts.accept(1.0) + + requested = [] + + def handler(c, addr_port): + addr, port = addr_port + requested.append((addr, port)) + self.tc._queue_incoming_channel(c) + + self.assertEqual( + None, getattr(self.server, "_x11_screen_number", None) + ) + cookie = chan.request_x11(0, single_connection=True, handler=handler) + self.assertEqual(0, self.server._x11_screen_number) + self.assertEqual("MIT-MAGIC-COOKIE-1", self.server._x11_auth_protocol) + self.assertEqual(cookie, self.server._x11_auth_cookie) + self.assertEqual(True, self.server._x11_single_connection) + + x11_server = self.ts.open_x11_channel(("localhost", 6093)) + x11_client = self.tc.accept() + self.assertEqual("localhost", requested[0][0]) + self.assertEqual(6093, requested[0][1]) + + x11_server.send("hello") + self.assertEqual(b"hello", x11_client.recv(5)) + + x11_server.close() + x11_client.close() + chan.close() + schan.close() + + def test_reverse_port_forwarding(self): + """ + verify that a client can ask the server to open a reverse port for + forwarding. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command("yes") + self.ts.accept(1.0) + + requested = [] + + def handler(c, origin_addr_port, server_addr_port): + requested.append(origin_addr_port) + requested.append(server_addr_port) + self.tc._queue_incoming_channel(c) + + port = self.tc.request_port_forward("127.0.0.1", 0, handler) + self.assertEqual(port, self.server._listen.getsockname()[1]) + + cs = socket.socket() + cs.connect(("127.0.0.1", port)) + ss, _ = self.server._listen.accept() + sch = self.ts.open_forwarded_tcpip_channel( + ss.getsockname(), ss.getpeername() + ) + cch = self.tc.accept() + + sch.send("hello") + self.assertEqual(b"hello", cch.recv(5)) + sch.close() + cch.close() + ss.close() + cs.close() + + # now cancel it. + self.tc.cancel_port_forward("127.0.0.1", port) + self.assertTrue(self.server._listen is None) + + def test_port_forwarding(self): + """ + verify that a client can forward new connections from a locally- + forwarded port. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.exec_command("yes") + self.ts.accept(1.0) + + # open a port on the "server" that the client will ask to forward to. + greeting_server = socket.socket() + greeting_server.bind(("127.0.0.1", 0)) + greeting_server.listen(1) + greeting_port = greeting_server.getsockname()[1] + + cs = self.tc.open_channel( + "direct-tcpip", ("127.0.0.1", greeting_port), ("", 9000) + ) + sch = self.ts.accept(1.0) + cch = socket.socket() + cch.connect(self.server._tcpip_dest) + + ss, _ = greeting_server.accept() + ss.send(b"Hello!\n") + ss.close() + sch.send(cch.recv(8192)) + sch.close() + + self.assertEqual(b"Hello!\n", cs.recv(7)) + cs.close() + + def test_stderr_select(self): + """ + verify that select() on a channel works even if only stderr is + receiving data. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.invoke_shell() + schan = self.ts.accept(1.0) + + # nothing should be ready + r, w, e = select.select([chan], [], [], 0.1) + self.assertEqual([], r) + self.assertEqual([], w) + self.assertEqual([], e) + + schan.send_stderr("hello\n") + + # something should be ready now (give it 1 second to appear) + for i in range(10): + r, w, e = select.select([chan], [], [], 0.1) + if chan in r: + break + time.sleep(0.1) + self.assertEqual([chan], r) + self.assertEqual([], w) + self.assertEqual([], e) + + self.assertEqual(b"hello\n", chan.recv_stderr(6)) + + # and, should be dead again now + r, w, e = select.select([chan], [], [], 0.1) + self.assertEqual([], r) + self.assertEqual([], w) + self.assertEqual([], e) + + schan.close() + chan.close() + + def test_send_ready(self): + """ + verify that send_ready() indicates when a send would not block. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.invoke_shell() + schan = self.ts.accept(1.0) + + self.assertEqual(chan.send_ready(), True) + total = 0 + K = "*" * 1024 + limit = 1 + (64 * 2**15) + while total < limit: + chan.send(K) + total += len(K) + if not chan.send_ready(): + break + self.assertTrue(total < limit) + + schan.close() + chan.close() + self.assertEqual(chan.send_ready(), True) + + def test_rekey_deadlock(self): + """ + Regression test for deadlock when in-transit messages are received + after MSG_KEXINIT is sent + + Note: When this test fails, it may leak threads. + """ + + # Test for an obscure deadlocking bug that can occur if we receive + # certain messages while initiating a key exchange. + # + # The deadlock occurs as follows: + # + # In the main thread: + # 1. The user's program calls Channel.send(), which sends + # MSG_CHANNEL_DATA to the remote host. + # 2. Packetizer discovers that REKEY_BYTES has been exceeded, and + # sets the __need_rekey flag. + # + # In the Transport thread: + # 3. Packetizer notices that the __need_rekey flag is set, and raises + # NeedRekeyException. + # 4. In response to NeedRekeyException, the transport thread sends + # MSG_KEXINIT to the remote host. + # + # On the remote host (using any SSH implementation): + # 5. The MSG_CHANNEL_DATA is received, and MSG_CHANNEL_WINDOW_ADJUST + # is sent. + # 6. The MSG_KEXINIT is received, and a corresponding MSG_KEXINIT is + # sent. + # + # In the main thread: + # 7. The user's program calls Channel.send(). + # 8. Channel.send acquires Channel.lock, then calls + # Transport._send_user_message(). + # 9. Transport._send_user_message waits for Transport.clear_to_send + # to be set (i.e., it waits for re-keying to complete). + # Channel.lock is still held. + # + # In the Transport thread: + # 10. MSG_CHANNEL_WINDOW_ADJUST is received; Channel._window_adjust + # is called to handle it. + # 11. Channel._window_adjust tries to acquire Channel.lock, but it + # blocks because the lock is already held by the main thread. + # + # The result is that the Transport thread never processes the remote + # host's MSG_KEXINIT packet, because it becomes deadlocked while + # handling the preceding MSG_CHANNEL_WINDOW_ADJUST message. + + # We set up two separate threads for sending and receiving packets, + # while the main thread acts as a watchdog timer. If the timer + # expires, a deadlock is assumed. + + class SendThread(threading.Thread): + def __init__(self, chan, iterations, done_event): + threading.Thread.__init__( + self, None, None, self.__class__.__name__ + ) + self.daemon = True + self.chan = chan + self.iterations = iterations + self.done_event = done_event + self.watchdog_event = threading.Event() + self.last = None + + def run(self): + try: + for i in range(1, 1 + self.iterations): + if self.done_event.is_set(): + break + self.watchdog_event.set() + # print i, "SEND" + self.chan.send("x" * 2048) + finally: + self.done_event.set() + self.watchdog_event.set() + + class ReceiveThread(threading.Thread): + def __init__(self, chan, done_event): + threading.Thread.__init__( + self, None, None, self.__class__.__name__ + ) + self.daemon = True + self.chan = chan + self.done_event = done_event + self.watchdog_event = threading.Event() + + def run(self): + try: + while not self.done_event.is_set(): + if self.chan.recv_ready(): + chan.recv(65536) + self.watchdog_event.set() + else: + if random.randint(0, 1): + time.sleep(random.randint(0, 500) / 1000.0) + finally: + self.done_event.set() + self.watchdog_event.set() + + self.setup_test_server() + self.ts.packetizer.REKEY_BYTES = 2048 + + chan = self.tc.open_session() + chan.exec_command("yes") + schan = self.ts.accept(1.0) + + # Monkey patch the client's Transport._handler_table so that the client + # sends MSG_CHANNEL_WINDOW_ADJUST whenever it receives an initial + # MSG_KEXINIT. This is used to simulate the effect of network latency + # on a real MSG_CHANNEL_WINDOW_ADJUST message. + self.tc._handler_table = ( + self.tc._handler_table.copy() + ) # copy per-class dictionary + _negotiate_keys = self.tc._handler_table[MSG_KEXINIT] + + def _negotiate_keys_wrapper(self, m): + if self.local_kex_init is None: # Remote side sent KEXINIT + # Simulate in-transit MSG_CHANNEL_WINDOW_ADJUST by sending it + # before responding to the incoming MSG_KEXINIT. + m2 = Message() + m2.add_byte(cMSG_CHANNEL_WINDOW_ADJUST) + m2.add_int(chan.remote_chanid) + m2.add_int(1) # bytes to add + self._send_message(m2) + return _negotiate_keys(self, m) + + self.tc._handler_table[MSG_KEXINIT] = _negotiate_keys_wrapper + + # Parameters for the test + iterations = 500 # The deadlock does not happen every time, but it + # should after many iterations. + timeout = 5 + + # This event is set when the test is completed + done_event = threading.Event() + + # Start the sending thread + st = SendThread(schan, iterations, done_event) + st.start() + + # Start the receiving thread + rt = ReceiveThread(chan, done_event) + rt.start() + + # Act as a watchdog timer, checking + deadlocked = False + while not deadlocked and not done_event.is_set(): + for event in (st.watchdog_event, rt.watchdog_event): + event.wait(timeout) + if done_event.is_set(): + break + if not event.is_set(): + deadlocked = True + break + event.clear() + + # Tell the threads to stop (if they haven't already stopped). Note + # that if one or more threads are deadlocked, they might hang around + # forever (until the process exits). + done_event.set() + + # Assertion: We must not have detected a timeout. + self.assertFalse(deadlocked) + + # Close the channels + schan.close() + chan.close() + + def test_sanitze_packet_size(self): + """ + verify that we conform to the rfc of packet and window sizes. + """ + for val, correct in [ + (4095, MIN_PACKET_SIZE), + (None, DEFAULT_MAX_PACKET_SIZE), + (2**32, MAX_WINDOW_SIZE), + ]: + self.assertEqual(self.tc._sanitize_packet_size(val), correct) + + def test_sanitze_window_size(self): + """ + verify that we conform to the rfc of packet and window sizes. + """ + for val, correct in [ + (32767, MIN_WINDOW_SIZE), + (None, DEFAULT_WINDOW_SIZE), + (2**32, MAX_WINDOW_SIZE), + ]: + self.assertEqual(self.tc._sanitize_window_size(val), correct) + + @slow + def test_handshake_timeout(self): + """ + verify that we can get a handshake timeout. + """ + # Tweak client Transport instance's Packetizer instance so + # its read_message() sleeps a bit. This helps prevent race conditions + # where the client Transport's timeout timer thread doesn't even have + # time to get scheduled before the main client thread finishes + # handshaking with the server. + # (Doing this on the server's transport *sounds* more 'correct' but + # actually doesn't work nearly as well for whatever reason.) + class SlowPacketizer(Packetizer): + def read_message(self): + time.sleep(1) + return super().read_message() + + # NOTE: prettttty sure since the replaced .packetizer Packetizer is now + # no longer doing anything with its copy of the socket...everything'll + # be fine. Even tho it's a bit squicky. + self.tc.packetizer = SlowPacketizer(self.tc.sock) + # Continue with regular test red tape. + host_key = RSAKey.from_private_key_file(_support("rsa.key")) + public_host_key = RSAKey(data=host_key.asbytes()) + self.ts.add_server_key(host_key) + event = threading.Event() + server = NullServer() + self.assertTrue(not event.is_set()) + self.tc.handshake_timeout = 0.000000000001 + self.ts.start_server(event, server) + self.assertRaises( + EOFError, + self.tc.connect, + hostkey=public_host_key, + username="slowdive", + password="pygmalion", + ) + + def test_select_after_close(self): + """ + verify that select works when a channel is already closed. + """ + self.setup_test_server() + chan = self.tc.open_session() + chan.invoke_shell() + schan = self.ts.accept(1.0) + schan.close() + + # give client a moment to receive close notification + time.sleep(0.1) + + r, w, e = select.select([chan], [], [], 0.1) + self.assertEqual([chan], r) + self.assertEqual([], w) + self.assertEqual([], e) + + def test_channel_send_misc(self): + """ + verify behaviours sending various instances to a channel + """ + self.setup_test_server() + text = "\xa7 slice me nicely" + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # TypeError raised on non string or buffer type + self.assertRaises(TypeError, chan.send, object()) + self.assertRaises(TypeError, chan.sendall, object()) + + # sendall() accepts a unicode instance + chan.sendall(text) + expected = text.encode("utf-8") + self.assertEqual(sfile.read(len(expected)), expected) + + @needs_builtin("buffer") + def test_channel_send_buffer(self): + """ + verify sending buffer instances to a channel + """ + self.setup_test_server() + data = 3 * b"some test data\n whole" + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # send() accepts buffer instances + sent = 0 + while sent < len(data): + sent += chan.send(buffer(data, sent, 8)) # noqa + self.assertEqual(sfile.read(len(data)), data) + + # sendall() accepts a buffer instance + chan.sendall(buffer(data)) # noqa + self.assertEqual(sfile.read(len(data)), data) + + @needs_builtin("memoryview") + def test_channel_send_memoryview(self): + """ + verify sending memoryview instances to a channel + """ + self.setup_test_server() + data = 3 * b"some test data\n whole" + with self.tc.open_session() as chan: + schan = self.ts.accept(1.0) + if schan is None: + self.fail("Test server transport failed to accept") + sfile = schan.makefile() + + # send() accepts memoryview slices + sent = 0 + view = memoryview(data) + while sent < len(view): + sent += chan.send(view[sent : sent + 8]) + self.assertEqual(sfile.read(len(data)), data) + + # sendall() accepts a memoryview instance + chan.sendall(memoryview(data)) + self.assertEqual(sfile.read(len(data)), data) + + def test_server_rejects_open_channel_without_auth(self): + try: + self.setup_test_server(connect_kwargs={}) + self.tc.open_session() + except ChannelException as e: + assert e.code == OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + else: + assert False, "Did not raise ChannelException!" + + def test_server_rejects_arbitrary_global_request_without_auth(self): + self.setup_test_server(connect_kwargs={}) + # NOTE: this dummy global request kind would normally pass muster + # from the test server. + self.tc.global_request("acceptable") + # Global requests never raise exceptions, even on failure (not sure why + # this was the original design...ugh.) Best we can do to tell failure + # happened is that the client transport's global_response was set back + # to None; if it had succeeded, it would be the response Message. + err = "Unauthed global response incorrectly succeeded!" + assert self.tc.global_response is None, err + + def test_server_rejects_port_forward_without_auth(self): + # NOTE: at protocol level port forward requests are treated same as a + # regular global request, but Paramiko server implements a special-case + # method for it, so it gets its own test. (plus, THAT actually raises + # an exception on the client side, unlike the general case...) + self.setup_test_server(connect_kwargs={}) + try: + self.tc.request_port_forward("localhost", 1234) + except SSHException as e: + assert "forwarding request denied" in str(e) + else: + assert False, "Did not raise SSHException!" + + def _send_unimplemented(self, server_is_sender): + self.setup_test_server() + sender, recipient = self.tc, self.ts + if server_is_sender: + sender, recipient = self.ts, self.tc + recipient._send_message = Mock() + msg = Message() + msg.add_byte(cMSG_UNIMPLEMENTED) + sender._send_message(msg) + # TODO: I hate this but I literally don't see a good way to know when + # the recipient has received the sender's message (there are no + # existing threading events in play that work for this), esp in this + # case where we don't WANT a response (as otherwise we could + # potentially try blocking on the sender's receipt of a reply...maybe). + time.sleep(0.1) + assert not recipient._send_message.called + + def test_server_does_not_respond_to_MSG_UNIMPLEMENTED(self): + self._send_unimplemented(server_is_sender=False) + + def test_client_does_not_respond_to_MSG_UNIMPLEMENTED(self): + self._send_unimplemented(server_is_sender=True) + + def _send_client_message(self, message_type): + self.setup_test_server(connect_kwargs={}) + self.ts._send_message = Mock() + # NOTE: this isn't 100% realistic (most of these message types would + # have actual other fields in 'em) but it suffices to test the level of + # message dispatch we're interested in here. + msg = Message() + # TODO: really not liking the whole cMSG_XXX vs MSG_XXX duality right + # now, esp since the former is almost always just byte_chr(the + # latter)...but since that's the case... + msg.add_byte(byte_chr(message_type)) + self.tc._send_message(msg) + # No good way to actually wait for server action (see above tests re: + # MSG_UNIMPLEMENTED). Grump. + time.sleep(0.1) + + def _expect_unimplemented(self): + # Ensure MSG_UNIMPLEMENTED was sent (implies it hit end of loop instead + # of truly handling the given message). + # NOTE: When bug present, this will actually be the first thing that + # fails (since in many cases actual message handling doesn't involve + # sending a message back right away). + assert self.ts._send_message.call_count == 1 + reply = self.ts._send_message.call_args[0][0] + reply.rewind() # Because it's pre-send, not post-receive + assert reply.get_byte() == cMSG_UNIMPLEMENTED + + def test_server_transports_reject_client_message_types(self): + # TODO: handle Transport's own tables too, not just its inner auth + # handler's table. See TODOs in auth_handler.py + some_handler = self._auth_handler_class(self.tc) + for message_type in some_handler._client_handler_table: + self._send_client_message(message_type) + self._expect_unimplemented() + # Reset for rest of loop + self.tearDown() + self.setUp() + + def test_server_rejects_client_MSG_USERAUTH_SUCCESS(self): + self._send_client_message(MSG_USERAUTH_SUCCESS) + # Sanity checks + assert not self.ts.authenticated + assert not self.ts.auth_handler.authenticated + # Real fix's behavior + self._expect_unimplemented() + + def test_can_override_packetizer_used(self): + class MyPacketizer(Packetizer): + pass + + # control case + assert Transport(sock=LoopSocket()).packetizer.__class__ is Packetizer + # overridden case + tweaked = Transport(sock=LoopSocket(), packetizer_class=MyPacketizer) + assert tweaked.packetizer.__class__ is MyPacketizer + + +# TODO: for now this is purely a regression test. It needs actual tests of the +# intentional new behavior too! +class ServiceRequestingTransportTest(TransportTest): + _auth_handler_class = AuthOnlyHandler + + def setUp(self): + # Copypasta (Transport init is load-bearing) + self.socks = LoopSocket() + self.sockc = LoopSocket() + self.sockc.link(self.socks) + # New class who dis + self.tc = ServiceRequestingTransport(self.sockc) + self.ts = ServiceRequestingTransport(self.socks) + + +class AlgorithmDisablingTests(unittest.TestCase): + def test_preferred_lists_default_to_private_attribute_contents(self): + t = Transport(sock=Mock()) + assert t.preferred_ciphers == t._preferred_ciphers + assert t.preferred_macs == t._preferred_macs + assert t.preferred_keys == tuple( + t._preferred_keys + + tuple( + "{}-cert-v01@openssh.com".format(x) for x in t._preferred_keys + ) + ) + assert t.preferred_kex == t._preferred_kex + + def test_preferred_lists_filter_disabled_algorithms(self): + t = Transport( + sock=Mock(), + disabled_algorithms={ + "ciphers": ["aes128-cbc"], + "macs": ["hmac-md5"], + "keys": ["ssh-dss"], + "kex": ["diffie-hellman-group14-sha256"], + }, + ) + assert "aes128-cbc" in t._preferred_ciphers + assert "aes128-cbc" not in t.preferred_ciphers + assert "hmac-md5" in t._preferred_macs + assert "hmac-md5" not in t.preferred_macs + assert "ssh-dss" in t._preferred_keys + assert "ssh-dss" not in t.preferred_keys + assert "ssh-dss-cert-v01@openssh.com" not in t.preferred_keys + assert "diffie-hellman-group14-sha256" in t._preferred_kex + assert "diffie-hellman-group14-sha256" not in t.preferred_kex + + def test_implementation_refers_to_public_algo_lists(self): + t = Transport( + sock=Mock(), + disabled_algorithms={ + "ciphers": ["aes128-cbc"], + "macs": ["hmac-md5"], + "keys": ["ssh-dss"], + "kex": ["diffie-hellman-group14-sha256"], + "compression": ["zlib"], + }, + ) + # Enable compression cuz otherwise disabling one option for it makes no + # sense... + t.use_compression(True) + # Effectively a random spot check, but kex init touches most/all of the + # algorithm lists so it's a good spot. + t._send_message = Mock() + t._send_kex_init() + # Cribbed from Transport._parse_kex_init, which didn't feel worth + # refactoring given all the vars involved :( + m = t._send_message.call_args[0][0] + m.rewind() + m.get_byte() # the msg type + m.get_bytes(16) # cookie, discarded + kexen = m.get_list() + server_keys = m.get_list() + ciphers = m.get_list() + m.get_list() + macs = m.get_list() + m.get_list() + compressions = m.get_list() + # OK, now we can actually check that our disabled algos were not + # included (as this message includes the full lists) + assert "aes128-cbc" not in ciphers + assert "hmac-md5" not in macs + assert "ssh-dss" not in server_keys + assert "diffie-hellman-group14-sha256" not in kexen + assert "zlib" not in compressions + + +class TestSHA2SignatureKeyExchange(unittest.TestCase): + # NOTE: these all rely on the default server() hostkey being RSA + # NOTE: these rely on both sides being properly implemented re: agreed-upon + # hostkey during kex being what's actually used. Truly proving that eg + # SHA512 was used, is quite difficult w/o super gross hacks. However, there + # are new tests in test_pkey.py which use known signature blobs to prove + # the SHA2 family was in fact used! + + @requires_sha1_signing + def test_base_case_ssh_rsa_still_used_as_fallback(self): + # Prove that ssh-rsa is used if either, or both, participants have SHA2 + # algorithms disabled + for which in ("init", "client_init", "server_init"): + with server(**{which: _disable_sha2}) as (tc, _): + assert tc.host_key_type == "ssh-rsa" + + def test_kex_with_sha2_512(self): + # It's the default! + with server() as (tc, _): + assert tc.host_key_type == "rsa-sha2-512" + + def test_kex_with_sha2_256(self): + # No 512 -> you get 256 + with server( + init=dict(disabled_algorithms=dict(keys=["rsa-sha2-512"])) + ) as (tc, _): + assert tc.host_key_type == "rsa-sha2-256" + + def _incompatible_peers(self, client_init, server_init): + with server( + client_init=client_init, server_init=server_init, catch_error=True + ) as (tc, ts, err): + # If neither side blew up then that's bad! + assert err is not None + # If client side blew up first, it'll be straightforward + if isinstance(err, IncompatiblePeer): + pass + # If server side blew up first, client sees EOF & we need to check + # the server transport for its saved error (otherwise it can only + # appear in log output) + elif isinstance(err, EOFError): + assert ts.saved_exception is not None + assert isinstance(ts.saved_exception, IncompatiblePeer) + # If it was something else, welp + else: + raise err + + def test_client_sha2_disabled_server_sha1_disabled_no_match(self): + self._incompatible_peers( + client_init=_disable_sha2, server_init=_disable_sha1 + ) + + def test_client_sha1_disabled_server_sha2_disabled_no_match(self): + self._incompatible_peers( + client_init=_disable_sha1, server_init=_disable_sha2 + ) + + def test_explicit_client_hostkey_not_limited(self): + # Be very explicit about the hostkey on BOTH ends, + # and ensure it still ends up choosing sha2-512. + # (This is a regression test vs previous implementation which overwrote + # the entire preferred-hostkeys structure when given an explicit key as + # a client.) + hostkey = RSAKey.from_private_key_file(_support("rsa.key")) + connect = dict( + hostkey=hostkey, username="slowdive", password="pygmalion" + ) + with server(hostkey=hostkey, connect=connect) as (tc, _): + assert tc.host_key_type == "rsa-sha2-512" + + +class TestExtInfo(unittest.TestCase): + def test_ext_info_handshake_exposed_in_client_kexinit(self): + with server() as (tc, _): + # NOTE: this is latest KEXINIT /sent by us/ (Transport retains it) + kex = tc._get_latest_kex_init() + # flag in KexAlgorithms list + assert "ext-info-c" in kex["kex_algo_list"] + # data stored on Transport after hearing back from a compatible + # server (such as ourselves in server mode) + assert tc.server_extensions == { + "server-sig-algs": b"ssh-ed25519,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-512,rsa-sha2-256,ssh-rsa,ssh-dss" # noqa + } + + def test_client_uses_server_sig_algs_for_pubkey_auth(self): + privkey = RSAKey.from_private_key_file(_support("rsa.key")) + with server( + pubkeys=[privkey], + connect=dict(pkey=privkey), + server_init=dict( + disabled_algorithms=dict(pubkeys=["rsa-sha2-512"]) + ), + ) as (tc, _): + assert tc.is_authenticated() + # Client settled on 256 despite itself not having 512 disabled (and + # otherwise, 512 would have been earlier in the preferred list) + assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" + + +class BadSeqPacketizer(Packetizer): + def read_message(self): + cmd, msg = super().read_message() + # Only mess w/ seqno if kexinit. + if cmd is MSG_KEXINIT: + # NOTE: this is /only/ the copy of the seqno which gets + # transmitted up from Packetizer; it's not modifying + # Packetizer's own internal seqno. For these tests, + # modifying the latter isn't required, and is also harder + # to do w/o triggering MAC mismatches. + msg.seqno = 17 # arbitrary nonzero int + return cmd, msg + + +class TestStrictKex: + def test_kex_algos_includes_kex_strict_c(self): + with server() as (tc, _): + kex = tc._get_latest_kex_init() + assert "kex-strict-c-v00@openssh.com" in kex["kex_algo_list"] + + @mark.parametrize( + "server_active,client_active", + itertools.product([True, False], repeat=2), + ) + def test_mode_agreement(self, server_active, client_active): + with server( + server_init=dict(strict_kex=server_active), + client_init=dict(strict_kex=client_active), + ) as (tc, ts): + if server_active and client_active: + assert tc.agreed_on_strict_kex is True + assert ts.agreed_on_strict_kex is True + else: + assert tc.agreed_on_strict_kex is False + assert ts.agreed_on_strict_kex is False + + def test_mode_advertised_by_default(self): + # NOTE: no explicit strict_kex overrides... + with server() as (tc, ts): + assert all( + ( + tc.advertise_strict_kex, + tc.agreed_on_strict_kex, + ts.advertise_strict_kex, + ts.agreed_on_strict_kex, + ) + ) + + @mark.parametrize( + "ptype", + ( + # "normal" but definitely out-of-order message + MSG_CHANNEL_OPEN, + # Normally ignored, but not in this case + MSG_IGNORE, + # Normally triggers debug parsing, but not in this case + MSG_DEBUG, + # Normally ignored, but...you get the idea + MSG_UNIMPLEMENTED, + # Not real, so would normally trigger us /sending/ + # MSG_UNIMPLEMENTED, but... + MSG_FUGGEDABOUTIT, + ), + ) + def test_MessageOrderError_non_kex_messages_in_initial_kex(self, ptype): + class AttackTransport(Transport): + # Easiest apparent spot on server side which is: + # - late enough for both ends to have handshook on strict mode + # - early enough to be in the window of opportunity for Terrapin + # attack; essentially during actual kex, when the engine is + # waiting for things like MSG_KEXECDH_REPLY (for eg curve25519). + def _negotiate_keys(self, m): + self.clear_to_send_lock.acquire() + try: + self.clear_to_send.clear() + finally: + self.clear_to_send_lock.release() + if self.local_kex_init is None: + # remote side wants to renegotiate + self._send_kex_init() + self._parse_kex_init(m) + # Here, we would normally kick over to kex_engine, but instead + # we want the server to send the OOO message. + m = Message() + m.add_byte(byte_chr(ptype)) + # rest of packet unnecessary... + self._send_message(m) + + with raises(MessageOrderError): + with server(server_transport_factory=AttackTransport) as (tc, _): + pass # above should run and except during connect() + + def test_SSHException_raised_on_out_of_order_messages_when_not_strict( + self, + ): + # This is kind of dumb (either situation is still fatal!) but whatever, + # may as well be strict with our new strict flag... + with raises(SSHException) as info: # would be true either way, but + with server( + client_init=dict(strict_kex=False), + ) as (tc, _): + tc._expect_packet(MSG_KEXINIT) + tc.open_session() + assert info.type is SSHException # NOT MessageOrderError! + + def test_error_not_raised_when_kexinit_not_seq_0_but_unstrict(self): + with server( + client_init=dict( + # Disable strict kex + strict_kex=False, + # Give our clientside a packetizer that sets all kexinit + # Message objects to have .seqno==17, which would trigger the + # new logic if we'd forgotten to wrap it in strict-kex check + packetizer_class=BadSeqPacketizer, + ), + ): + pass # kexinit happens at connect... + + def test_MessageOrderError_raised_when_kexinit_not_seq_0_and_strict(self): + with raises(MessageOrderError): + with server( + # Give our clientside a packetizer that sets all kexinit + # Message objects to have .seqno==17, which should trigger the + # new logic (given we are NOT disabling strict-mode) + client_init=dict(packetizer_class=BadSeqPacketizer), + ): + pass # kexinit happens at connect... + + def test_sequence_numbers_reset_on_newkeys_when_strict(self): + with server(defer=True) as (tc, ts): + # When in strict mode, these should all be zero or close to it + # (post-kexinit, pre-auth). + # Server->client will be 1 (EXT_INFO got sent after NEWKEYS) + assert tc.packetizer._Packetizer__sequence_number_in == 1 + assert ts.packetizer._Packetizer__sequence_number_out == 1 + # Client->server will be 0 + assert tc.packetizer._Packetizer__sequence_number_out == 0 + assert ts.packetizer._Packetizer__sequence_number_in == 0 + + def test_sequence_numbers_not_reset_on_newkeys_when_not_strict(self): + with server(defer=True, client_init=dict(strict_kex=False)) as ( + tc, + ts, + ): + # When not in strict mode, these will all be ~3-4 or so + # (post-kexinit, pre-auth). Not encoding exact values as it will + # change anytime we mess with the test harness... + assert tc.packetizer._Packetizer__sequence_number_in != 0 + assert tc.packetizer._Packetizer__sequence_number_out != 0 + assert ts.packetizer._Packetizer__sequence_number_in != 0 + assert ts.packetizer._Packetizer__sequence_number_out != 0 + + def test_sequence_number_rollover_detected(self): + class RolloverTransport(Transport): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Induce an about-to-rollover seqno, such that it rolls over + # during initial kex. + setattr( + self.packetizer, + "_Packetizer__sequence_number_in", + sys.maxsize, + ) + setattr( + self.packetizer, + "_Packetizer__sequence_number_out", + sys.maxsize, + ) + + with raises( + SSHException, + match=r"Sequence number rolled over during initial kex!", + ): + with server( + client_init=dict( + # Disable strict kex - this should happen always + strict_kex=False, + ), + # Transport which tickles its packetizer seqno's + transport_factory=RolloverTransport, + ): + pass # kexinit happens at connect... |