diff options
Diffstat (limited to 'tests/test_streams.py')
-rw-r--r-- | tests/test_streams.py | 434 |
1 files changed, 434 insertions, 0 deletions
diff --git a/tests/test_streams.py b/tests/test_streams.py new file mode 100644 index 0000000..7ae441b --- /dev/null +++ b/tests/test_streams.py @@ -0,0 +1,434 @@ +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path +from shutil import which +from subprocess import run +from typing import Optional + +import pytest + +import aristaproto +from tests.output_aristaproto import ( + map, + nested, + oneof, + repeated, + repeatedpacked, +) + + +oneof_example = oneof.Test().from_dict( + {"pitied": 1, "just_a_regular_field": 123456789, "bar_name": "Testing"} +) + +len_oneof = len(oneof_example) + +nested_example = nested.Test().from_dict( + { + "nested": {"count": 1}, + "sibling": {"foo": 2}, + "sibling2": {"foo": 3}, + "msg": nested.TestMsg.THIS, + } +) + +repeated_example = repeated.Test().from_dict({"names": ["blah", "Blah2"]}) + +packed_example = repeatedpacked.Test().from_dict( + {"counts": [1, 2, 3], "signed": [-1, 2, -3], "fixed": [1.2, -2.3, 3.4]} +) + +map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}}) + +streams_path = Path("tests/streams/") + +java = which("java") + + +def test_load_varint_too_long(): + with BytesIO( + b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01" + ) as stream, pytest.raises(ValueError): + aristaproto.load_varint(stream) + + with BytesIO(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01") as stream: + # This should not raise a ValueError, as it is within 64 bits + aristaproto.load_varint(stream) + + +def test_load_varint_file(): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + assert aristaproto.load_varint(stream) == (8, b"\x08") # Single-byte varint + stream.read(2) # Skip until first multi-byte + assert aristaproto.load_varint(stream) == ( + 123456789, + b"\x95\x9A\xEF\x3A", + ) # Multi-byte varint + + +def test_load_varint_cutoff(): + with open(streams_path / "load_varint_cutoff.in", "rb") as stream: + with pytest.raises(EOFError): + aristaproto.load_varint(stream) + + stream.seek(1) + with pytest.raises(EOFError): + aristaproto.load_varint(stream) + + +def test_dump_varint_file(tmp_path): + # Dump test varints to file + with open(tmp_path / "dump_varint_file.out", "wb") as stream: + aristaproto.dump_varint(8, stream) # Single-byte varint + aristaproto.dump_varint(123456789, stream) # Multi-byte varint + + # Check that file contents are as expected + with open(tmp_path / "dump_varint_file.out", "rb") as test_stream, open( + streams_path / "message_dump_file_single.expected", "rb" + ) as exp_stream: + assert aristaproto.load_varint(test_stream) == aristaproto.load_varint( + exp_stream + ) + exp_stream.read(2) + assert aristaproto.load_varint(test_stream) == aristaproto.load_varint( + exp_stream + ) + + +def test_parse_fields(): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + parsed_bytes = aristaproto.parse_fields(stream.read()) + + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + parsed_stream = aristaproto.load_fields(stream) + for field in parsed_bytes: + assert field == next(parsed_stream) + + +def test_message_dump_file_single(tmp_path): + # Write the message to the stream + with open(tmp_path / "message_dump_file_single.out", "wb") as stream: + oneof_example.dump(stream) + + # Check that the outputted file is exactly as expected + with open(tmp_path / "message_dump_file_single.out", "rb") as test_stream, open( + streams_path / "message_dump_file_single.expected", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +def test_message_dump_file_multiple(tmp_path): + # Write the same Message twice and another, different message + with open(tmp_path / "message_dump_file_multiple.out", "wb") as stream: + oneof_example.dump(stream) + oneof_example.dump(stream) + nested_example.dump(stream) + + # Check that all three Messages were outputted to the file correctly + with open(tmp_path / "message_dump_file_multiple.out", "rb") as test_stream, open( + streams_path / "message_dump_file_multiple.expected", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +def test_message_dump_delimited(tmp_path): + with open(tmp_path / "message_dump_delimited.out", "wb") as stream: + oneof_example.dump(stream, aristaproto.SIZE_DELIMITED) + oneof_example.dump(stream, aristaproto.SIZE_DELIMITED) + nested_example.dump(stream, aristaproto.SIZE_DELIMITED) + + with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open( + streams_path / "delimited_messages.in", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +def test_message_len(): + assert len_oneof == len(bytes(oneof_example)) + assert len(nested_example) == len(bytes(nested_example)) + + +def test_message_load_file_single(): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + assert oneof.Test().load(stream) == oneof_example + stream.seek(0) + assert oneof.Test().load(stream, len_oneof) == oneof_example + + +def test_message_load_file_multiple(): + with open(streams_path / "message_dump_file_multiple.expected", "rb") as stream: + oneof_size = len_oneof + assert oneof.Test().load(stream, oneof_size) == oneof_example + assert oneof.Test().load(stream, oneof_size) == oneof_example + assert nested.Test().load(stream) == nested_example + assert stream.read(1) == b"" + + +def test_message_load_too_small(): + with open( + streams_path / "message_dump_file_single.expected", "rb" + ) as stream, pytest.raises(ValueError): + oneof.Test().load(stream, len_oneof - 1) + + +def test_message_load_delimited(): + with open(streams_path / "delimited_messages.in", "rb") as stream: + assert oneof.Test().load(stream, aristaproto.SIZE_DELIMITED) == oneof_example + assert oneof.Test().load(stream, aristaproto.SIZE_DELIMITED) == oneof_example + assert nested.Test().load(stream, aristaproto.SIZE_DELIMITED) == nested_example + assert stream.read(1) == b"" + + +def test_message_load_too_large(): + with open( + streams_path / "message_dump_file_single.expected", "rb" + ) as stream, pytest.raises(ValueError): + oneof.Test().load(stream, len_oneof + 1) + + +def test_message_len_optional_field(): + @dataclass + class Request(aristaproto.Message): + flag: Optional[bool] = aristaproto.message_field(1, wraps=aristaproto.TYPE_BOOL) + + assert len(Request()) == len(b"") + assert len(Request(flag=True)) == len(b"\n\x02\x08\x01") + assert len(Request(flag=False)) == len(b"\n\x00") + + +def test_message_len_repeated_field(): + assert len(repeated_example) == len(bytes(repeated_example)) + + +def test_message_len_packed_field(): + assert len(packed_example) == len(bytes(packed_example)) + + +def test_message_len_map_field(): + assert len(map_example) == len(bytes(map_example)) + + +def test_message_len_empty_string(): + @dataclass + class Empty(aristaproto.Message): + string: str = aristaproto.string_field(1, "group") + integer: int = aristaproto.int32_field(2, "group") + + empty = Empty().from_dict({"string": ""}) + assert len(empty) == len(bytes(empty)) + + +def test_calculate_varint_size_negative(): + single_byte = -1 + multi_byte = -10000000 + edge = -(1 << 63) + beyond = -(1 << 63) - 1 + before = -(1 << 63) + 1 + + assert ( + aristaproto.size_varint(single_byte) + == len(aristaproto.encode_varint(single_byte)) + == 10 + ) + assert ( + aristaproto.size_varint(multi_byte) + == len(aristaproto.encode_varint(multi_byte)) + == 10 + ) + assert aristaproto.size_varint(edge) == len(aristaproto.encode_varint(edge)) == 10 + assert ( + aristaproto.size_varint(before) == len(aristaproto.encode_varint(before)) == 10 + ) + + with pytest.raises(ValueError): + aristaproto.size_varint(beyond) + + +def test_calculate_varint_size_positive(): + single_byte = 1 + multi_byte = 10000000 + + assert aristaproto.size_varint(single_byte) == len( + aristaproto.encode_varint(single_byte) + ) + assert aristaproto.size_varint(multi_byte) == len( + aristaproto.encode_varint(multi_byte) + ) + + +def test_dump_varint_negative(tmp_path): + single_byte = -1 + multi_byte = -10000000 + edge = -(1 << 63) + beyond = -(1 << 63) - 1 + before = -(1 << 63) + 1 + + with open(tmp_path / "dump_varint_negative.out", "wb") as stream: + aristaproto.dump_varint(single_byte, stream) + aristaproto.dump_varint(multi_byte, stream) + aristaproto.dump_varint(edge, stream) + aristaproto.dump_varint(before, stream) + + with pytest.raises(ValueError): + aristaproto.dump_varint(beyond, stream) + + with open(streams_path / "dump_varint_negative.expected", "rb") as exp_stream, open( + tmp_path / "dump_varint_negative.out", "rb" + ) as test_stream: + assert test_stream.read() == exp_stream.read() + + +def test_dump_varint_positive(tmp_path): + single_byte = 1 + multi_byte = 10000000 + + with open(tmp_path / "dump_varint_positive.out", "wb") as stream: + aristaproto.dump_varint(single_byte, stream) + aristaproto.dump_varint(multi_byte, stream) + + with open(tmp_path / "dump_varint_positive.out", "rb") as test_stream, open( + streams_path / "dump_varint_positive.expected", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +# Java compatibility tests + + +@pytest.fixture(scope="module") +def compile_jar(): + # Skip if not all required tools are present + if java is None: + pytest.skip("`java` command is absent and is required") + mvn = which("mvn") + if mvn is None: + pytest.skip("Maven is absent and is required") + + # Compile the JAR + proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"]) + if proc_maven.returncode != 0: + pytest.skip( + "Maven compatibility-test.jar build failed (maybe Java version <11?)" + ) + + +jar = "tests/streams/java/target/compatibility-test.jar" + + +def run_jar(command: str, tmp_path): + return run([java, "-jar", jar, command, tmp_path], check=True) + + +def run_java_single_varint(value: int, tmp_path) -> int: + # Write single varint to file + with open(tmp_path / "py_single_varint.out", "wb") as stream: + aristaproto.dump_varint(value, stream) + + # Have Java read this varint and write it back + run_jar("single_varint", tmp_path) + + # Read single varint from Java output file + with open(tmp_path / "java_single_varint.out", "rb") as stream: + returned = aristaproto.load_varint(stream) + with pytest.raises(EOFError): + aristaproto.load_varint(stream) + + return returned + + +def test_single_varint(compile_jar, tmp_path): + single_byte = (1, b"\x01") + multi_byte = (123456789, b"\x95\x9A\xEF\x3A") + + # Write a single-byte varint to a file and have Java read it back + returned = run_java_single_varint(single_byte[0], tmp_path) + assert returned == single_byte + + # Same for a multi-byte varint + returned = run_java_single_varint(multi_byte[0], tmp_path) + assert returned == multi_byte + + +def test_multiple_varints(compile_jar, tmp_path): + single_byte = (1, b"\x01") + multi_byte = (123456789, b"\x95\x9A\xEF\x3A") + over32 = (3000000000, b"\x80\xBC\xC1\x96\x0B") + + # Write two varints to the same file + with open(tmp_path / "py_multiple_varints.out", "wb") as stream: + aristaproto.dump_varint(single_byte[0], stream) + aristaproto.dump_varint(multi_byte[0], stream) + aristaproto.dump_varint(over32[0], stream) + + # Have Java read these varints and write them back + run_jar("multiple_varints", tmp_path) + + # Read varints from Java output file + with open(tmp_path / "java_multiple_varints.out", "rb") as stream: + returned_single = aristaproto.load_varint(stream) + returned_multi = aristaproto.load_varint(stream) + returned_over32 = aristaproto.load_varint(stream) + with pytest.raises(EOFError): + aristaproto.load_varint(stream) + + assert returned_single == single_byte + assert returned_multi == multi_byte + assert returned_over32 == over32 + + +def test_single_message(compile_jar, tmp_path): + # Write message to file + with open(tmp_path / "py_single_message.out", "wb") as stream: + oneof_example.dump(stream) + + # Have Java read and return the message + run_jar("single_message", tmp_path) + + # Read and check the returned message + with open(tmp_path / "java_single_message.out", "rb") as stream: + returned = oneof.Test().load(stream, len(bytes(oneof_example))) + assert stream.read() == b"" + + assert returned == oneof_example + + +def test_multiple_messages(compile_jar, tmp_path): + # Write delimited messages to file + with open(tmp_path / "py_multiple_messages.out", "wb") as stream: + oneof_example.dump(stream, aristaproto.SIZE_DELIMITED) + nested_example.dump(stream, aristaproto.SIZE_DELIMITED) + + # Have Java read and return the messages + run_jar("multiple_messages", tmp_path) + + # Read and check the returned messages + with open(tmp_path / "java_multiple_messages.out", "rb") as stream: + returned_oneof = oneof.Test().load(stream, aristaproto.SIZE_DELIMITED) + returned_nested = nested.Test().load(stream, aristaproto.SIZE_DELIMITED) + assert stream.read() == b"" + + assert returned_oneof == oneof_example + assert returned_nested == nested_example + + +def test_infinite_messages(compile_jar, tmp_path): + num_messages = 5 + + # Write delimited messages to file + with open(tmp_path / "py_infinite_messages.out", "wb") as stream: + for x in range(num_messages): + oneof_example.dump(stream, aristaproto.SIZE_DELIMITED) + + # Have Java read and return the messages + run_jar("infinite_messages", tmp_path) + + # Read and check the returned messages + messages = [] + with open(tmp_path / "java_infinite_messages.out", "rb") as stream: + while True: + try: + messages.append(oneof.Test().load(stream, aristaproto.SIZE_DELIMITED)) + except EOFError: + break + + assert len(messages) == num_messages |