diff options
Diffstat (limited to '')
-rw-r--r-- | tests/test_packetizer.py | 148 |
1 files changed, 148 insertions, 0 deletions
diff --git a/tests/test_packetizer.py b/tests/test_packetizer.py new file mode 100644 index 0000000..aee21c2 --- /dev/null +++ b/tests/test_packetizer.py @@ -0,0 +1,148 @@ +# Copyright (C) 2003-2009 Robey Pointer <robeypointer@gmail.com> +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +""" +Some unit tests for the ssh2 protocol in Transport. +""" + +import sys +import unittest +from hashlib import sha1 + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes + +from paramiko import Message, Packetizer, util +from paramiko.common import byte_chr, zero_byte + +from ._loop import LoopSocket + + +x55 = byte_chr(0x55) +x1f = byte_chr(0x1F) + + +class PacketizerTest(unittest.TestCase): + def test_write(self): + rsock = LoopSocket() + wsock = LoopSocket() + rsock.link(wsock) + p = Packetizer(wsock) + p.set_log(util.get_logger("paramiko.transport")) + p.set_hexdump(True) + encryptor = Cipher( + algorithms.AES(zero_byte * 16), + modes.CBC(x55 * 16), + backend=default_backend(), + ).encryptor() + p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20) + + # message has to be at least 16 bytes long, so we'll have at least one + # block of data encrypted that contains zero random padding bytes + m = Message() + m.add_byte(byte_chr(100)) + m.add_int(100) + m.add_int(1) + m.add_int(900) + p.send_message(m) + data = rsock.recv(100) + # 32 + 12 bytes of MAC = 44 + self.assertEqual(44, len(data)) + self.assertEqual( + b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0", # noqa + data[:16], + ) + + def test_read(self): + rsock = LoopSocket() + wsock = LoopSocket() + rsock.link(wsock) + p = Packetizer(rsock) + p.set_log(util.get_logger("paramiko.transport")) + p.set_hexdump(True) + decryptor = Cipher( + algorithms.AES(zero_byte * 16), + modes.CBC(x55 * 16), + backend=default_backend(), + ).decryptor() + p.set_inbound_cipher(decryptor, 16, sha1, 12, x1f * 20) + wsock.send( + b"\x43\x91\x97\xbd\x5b\x50\xac\x25\x87\xc2\xc4\x6b\xc7\xe9\x38\xc0\x90\xd2\x16\x56\x0d\x71\x73\x61\x38\x7c\x4c\x3d\xfb\x97\x7d\xe2\x6e\x03\xb1\xa0\xc2\x1c\xd6\x41\x41\x4c\xb4\x59" # noqa + ) + cmd, m = p.read_message() + self.assertEqual(100, cmd) + self.assertEqual(100, m.get_int()) + self.assertEqual(1, m.get_int()) + self.assertEqual(900, m.get_int()) + + def test_closed(self): + if sys.platform.startswith("win"): # no SIGALRM on windows + return + rsock = LoopSocket() + wsock = LoopSocket() + rsock.link(wsock) + p = Packetizer(wsock) + p.set_log(util.get_logger("paramiko.transport")) + p.set_hexdump(True) + encryptor = Cipher( + algorithms.AES(zero_byte * 16), + modes.CBC(x55 * 16), + backend=default_backend(), + ).encryptor() + p.set_outbound_cipher(encryptor, 16, sha1, 12, x1f * 20) + + # message has to be at least 16 bytes long, so we'll have at least one + # block of data encrypted that contains zero random padding bytes + m = Message() + m.add_byte(byte_chr(100)) + m.add_int(100) + m.add_int(1) + m.add_int(900) + wsock.send = lambda x: 0 + from functools import wraps + import errno + import os + import signal + + class TimeoutError(Exception): + def __init__(self, error_message): + if hasattr(errno, "ETIME"): + self.message = os.sterror(errno.ETIME) + else: + self.messaage = error_message + + def timeout(seconds=1, error_message="Timer expired"): + def decorator(func): + def _handle_timeout(signum, frame): + raise TimeoutError(error_message) + + def wrapper(*args, **kwargs): + signal.signal(signal.SIGALRM, _handle_timeout) + signal.alarm(seconds) + try: + result = func(*args, **kwargs) + finally: + signal.alarm(0) + return result + + return wraps(func)(wrapper) + + return decorator + + send = timeout()(p.send_message) + self.assertRaises(EOFError, send, m) |