diff options
Diffstat (limited to 'tests')
183 files changed, 7001 insertions, 0 deletions
diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..1301f6b --- /dev/null +++ b/tests/README.md @@ -0,0 +1,91 @@ +# Standard Tests Development Guide + +Standard test cases are found in [aristaproto/tests/inputs](inputs), where each subdirectory represents a testcase, that is verified in isolation. + +``` +inputs/ + bool/ + double/ + int32/ + ... +``` + +## Test case directory structure + +Each testcase has a `<name>.proto` file with a message called `Test`, and optionally a matching `.json` file and a custom test called `test_*.py`. + +```bash +bool/ + bool.proto + bool.json # optional + test_bool.py # optional +``` + +### proto + +`<name>.proto` — *The protobuf message to test* + +```protobuf +syntax = "proto3"; + +message Test { + bool value = 1; +} +``` + +You can add multiple `.proto` files to the test case, as long as one file matches the directory name. + +### json + +`<name>.json` —Â *Test-data to validate the message with* + +```json +{ + "value": true +} +``` + +### pytest + +`test_<name>.py` — *Custom test to validate specific aspects of the generated class* + +```python +from tests.output_aristaproto.bool.bool import Test + +def test_value(): + message = Test() + assert not message.value, "Boolean is False by default" +``` + +## Standard tests + +The following tests are automatically executed for all cases: + +- [x] Can the generated python code be imported? +- [x] Can the generated message class be instantiated? +- [x] Is the generated code compatible with the Google's `grpc_tools.protoc` implementation? + - _when `.json` is present_ + +## Running the tests + +- `pipenv run generate` + This generates: + - `aristaproto/tests/output_aristaproto` —Â *the plugin generated python classes* + - `aristaproto/tests/output_reference` — *reference implementation classes* +- `pipenv run test` + +## Intentionally Failing tests + +The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrected in the future. + +When running `pytest`, they show up as `x` or `X` in the test results. + +``` +aristaproto/tests/test_inputs.py ..x...x..x...x.X........xx........x.....x.......x.xx....x...................... [ 84%] +``` + +- `.` — PASSED +- `x` —Â XFAIL: expected failure +- `X` —Â XPASS: expected failure, but still passed + +Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py)
\ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..c6b256d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,22 @@ +import copy +import sys + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--repeat", type=int, default=1, help="repeat the operation multiple times" + ) + + +@pytest.fixture(scope="session") +def repeat(request): + return request.config.getoption("repeat") + + +@pytest.fixture +def reset_sys_path(): + original = copy.deepcopy(sys.path) + yield + sys.path = original diff --git a/tests/generate.py b/tests/generate.py new file mode 100755 index 0000000..d6f36de --- /dev/null +++ b/tests/generate.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python +import asyncio +import os +import platform +import shutil +import sys +from pathlib import Path +from typing import Set + +from tests.util import ( + get_directories, + inputs_path, + output_path_aristaproto, + output_path_aristaproto_pydantic, + output_path_reference, + protoc, +) + + +# Force pure-python implementation instead of C++, otherwise imports +# break things because we can't properly reset the symbol database. +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + + +def clear_directory(dir_path: Path): + for file_or_directory in dir_path.glob("*"): + if file_or_directory.is_dir(): + shutil.rmtree(file_or_directory) + else: + file_or_directory.unlink() + + +async def generate(whitelist: Set[str], verbose: bool): + test_case_names = set(get_directories(inputs_path)) - {"__pycache__"} + + path_whitelist = set() + name_whitelist = set() + for item in whitelist: + if item in test_case_names: + name_whitelist.add(item) + continue + path_whitelist.add(item) + + generation_tasks = [] + for test_case_name in sorted(test_case_names): + test_case_input_path = inputs_path.joinpath(test_case_name).resolve() + if ( + whitelist + and str(test_case_input_path) not in path_whitelist + and test_case_name not in name_whitelist + ): + continue + generation_tasks.append( + generate_test_case_output(test_case_input_path, test_case_name, verbose) + ) + + failed_test_cases = [] + # Wait for all subprocs and match any failures to names to report + for test_case_name, result in zip( + sorted(test_case_names), await asyncio.gather(*generation_tasks) + ): + if result != 0: + failed_test_cases.append(test_case_name) + + if len(failed_test_cases) > 0: + sys.stderr.write( + "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n" + ) + for failed_test_case in failed_test_cases: + sys.stderr.write(f"- {failed_test_case}\n") + + sys.exit(1) + + +async def generate_test_case_output( + test_case_input_path: Path, test_case_name: str, verbose: bool +) -> int: + """ + Returns the max of the subprocess return values + """ + + test_case_output_path_reference = output_path_reference.joinpath(test_case_name) + test_case_output_path_aristaproto = output_path_aristaproto + test_case_output_path_aristaproto_pyd = output_path_aristaproto_pydantic + + os.makedirs(test_case_output_path_reference, exist_ok=True) + os.makedirs(test_case_output_path_aristaproto, exist_ok=True) + os.makedirs(test_case_output_path_aristaproto_pyd, exist_ok=True) + + clear_directory(test_case_output_path_reference) + clear_directory(test_case_output_path_aristaproto) + + ( + (ref_out, ref_err, ref_code), + (plg_out, plg_err, plg_code), + (plg_out_pyd, plg_err_pyd, plg_code_pyd), + ) = await asyncio.gather( + protoc(test_case_input_path, test_case_output_path_reference, True), + protoc(test_case_input_path, test_case_output_path_aristaproto, False), + protoc( + test_case_input_path, test_case_output_path_aristaproto_pyd, False, True + ), + ) + + if ref_code == 0: + print(f"\033[31;1;4mGenerated reference output for {test_case_name!r}\033[0m") + else: + print( + f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m" + ) + + if verbose: + if ref_out: + print("Reference stdout:") + sys.stdout.buffer.write(ref_out) + sys.stdout.buffer.flush() + + if ref_err: + print("Reference stderr:") + sys.stderr.buffer.write(ref_err) + sys.stderr.buffer.flush() + + if plg_code == 0: + print(f"\033[31;1;4mGenerated plugin output for {test_case_name!r}\033[0m") + else: + print( + f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m" + ) + + if verbose: + if plg_out: + print("Plugin stdout:") + sys.stdout.buffer.write(plg_out) + sys.stdout.buffer.flush() + + if plg_err: + print("Plugin stderr:") + sys.stderr.buffer.write(plg_err) + sys.stderr.buffer.flush() + + if plg_code_pyd == 0: + print( + f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m" + ) + else: + print( + f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m" + ) + + if verbose: + if plg_out_pyd: + print("Plugin stdout:") + sys.stdout.buffer.write(plg_out_pyd) + sys.stdout.buffer.flush() + + if plg_err_pyd: + print("Plugin stderr:") + sys.stderr.buffer.write(plg_err_pyd) + sys.stderr.buffer.flush() + + return max(ref_code, plg_code, plg_code_pyd) + + +HELP = "\n".join( + ( + "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]", + "Generate python classes for standard tests.", + "", + "DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.", + " python generate.py inputs/bool inputs/double inputs/enum", + "", + "NAMES One or more test-case names to generate classes for.", + " python generate.py bool double enums", + ) +) + + +def main(): + if set(sys.argv).intersection({"-h", "--help"}): + print(HELP) + return + if sys.argv[1:2] == ["-v"]: + verbose = True + whitelist = set(sys.argv[2:]) + else: + verbose = False + whitelist = set(sys.argv[1:]) + + if platform.system() == "Windows": + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + + asyncio.run(generate(whitelist, verbose)) + + +if __name__ == "__main__": + main() diff --git a/tests/grpc/__init__.py b/tests/grpc/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/grpc/__init__.py diff --git a/tests/grpc/test_grpclib_client.py b/tests/grpc/test_grpclib_client.py new file mode 100644 index 0000000..d36e4a5 --- /dev/null +++ b/tests/grpc/test_grpclib_client.py @@ -0,0 +1,298 @@ +import asyncio +import sys +import uuid + +import grpclib +import grpclib.client +import grpclib.metadata +import grpclib.server +import pytest +from grpclib.testing import ChannelFor + +from aristaproto.grpc.util.async_channel import AsyncChannel +from tests.output_aristaproto.service import ( + DoThingRequest, + DoThingResponse, + GetThingRequest, + TestStub as ThingServiceClient, +) + +from .thing_service import ThingService + + +async def _test_client(client: ThingServiceClient, name="clean room", **kwargs): + response = await client.do_thing(DoThingRequest(name=name), **kwargs) + assert response.names == [name] + + +def _assert_request_meta_received(deadline, metadata): + def server_side_test(stream): + assert stream.deadline._timestamp == pytest.approx( + deadline._timestamp, 1 + ), "The provided deadline should be received serverside" + assert ( + stream.metadata["authorization"] == metadata["authorization"] + ), "The provided authorization metadata should be received serverside" + + return server_side_test + + +@pytest.fixture +def handler_trailer_only_unauthenticated(): + async def handler(stream: grpclib.server.Stream): + await stream.recv_message() + await stream.send_initial_metadata() + await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED) + + return handler + + +@pytest.mark.asyncio +async def test_simple_service_call(): + async with ChannelFor([ThingService()]) as channel: + await _test_client(ThingServiceClient(channel)) + + +@pytest.mark.asyncio +async def test_trailer_only_error_unary_unary( + mocker, handler_trailer_only_unauthenticated +): + service = ThingService() + mocker.patch.object( + service, + "do_thing", + side_effect=handler_trailer_only_unauthenticated, + autospec=True, + ) + async with ChannelFor([service]) as channel: + with pytest.raises(grpclib.exceptions.GRPCError) as e: + await ThingServiceClient(channel).do_thing(DoThingRequest(name="something")) + assert e.value.status == grpclib.Status.UNAUTHENTICATED + + +@pytest.mark.asyncio +async def test_trailer_only_error_stream_unary( + mocker, handler_trailer_only_unauthenticated +): + service = ThingService() + mocker.patch.object( + service, + "do_many_things", + side_effect=handler_trailer_only_unauthenticated, + autospec=True, + ) + async with ChannelFor([service]) as channel: + with pytest.raises(grpclib.exceptions.GRPCError) as e: + await ThingServiceClient(channel).do_many_things( + do_thing_request_iterator=[DoThingRequest(name="something")] + ) + await _test_client(ThingServiceClient(channel)) + assert e.value.status == grpclib.Status.UNAUTHENTICATED + + +@pytest.mark.asyncio +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="async mock spy does works for python3.8+" +) +async def test_service_call_mutable_defaults(mocker): + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + spy = mocker.spy(client, "_unary_unary") + await _test_client(client) + comments = spy.call_args_list[-1].args[1].comments + await _test_client(client) + assert spy.call_args_list[-1].args[1].comments is not comments + + +@pytest.mark.asyncio +async def test_service_call_with_upfront_request_params(): + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] + ) as channel: + await _test_client( + ThingServiceClient(channel, deadline=deadline, metadata=metadata) + ) + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] + ) as channel: + await _test_client( + ThingServiceClient(channel, timeout=timeout, metadata=metadata) + ) + + +@pytest.mark.asyncio +async def test_service_call_lower_level_with_overrides(): + THING_TO_DO = "get milk" + + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) + kwarg_metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + deadline=kwarg_deadline, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + kwarg_timeout = 9000 + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) + kwarg_metadata = {"authorization": "09876"} + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata), + ) + ] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + timeout=kwarg_timeout, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("overrides_gen",), + [ + (lambda: dict(timeout=10),), + (lambda: dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),), + (lambda: dict(metadata={"authorization": str(uuid.uuid4())}),), + (lambda: dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),), + ], +) +async def test_service_call_high_level_with_overrides(mocker, overrides_gen): + overrides = overrides_gen() + request_spy = mocker.spy(grpclib.client.Channel, "request") + name = str(uuid.uuid4()) + defaults = dict( + timeout=99, + deadline=grpclib.metadata.Deadline.from_timeout(99), + metadata={"authorization": name}, + ) + + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_received( + deadline=grpclib.metadata.Deadline.from_timeout( + overrides.get("timeout", 99) + ), + metadata=overrides.get("metadata", defaults.get("metadata")), + ) + ) + ] + ) as channel: + client = ThingServiceClient(channel, **defaults) + await _test_client(client, name=name, **overrides) + assert request_spy.call_count == 1 + + # for python <3.8 request_spy.call_args.kwargs do not work + _, request_spy_call_kwargs = request_spy.call_args_list[0] + + # ensure all overrides were successful + for key, value in overrides.items(): + assert key in request_spy_call_kwargs + assert request_spy_call_kwargs[key] == value + + # ensure default values were retained + for key in set(defaults.keys()) - set(overrides.keys()): + assert key in request_spy_call_kwargs + assert request_spy_call_kwargs[key] == defaults[key] + + +@pytest.mark.asyncio +async def test_async_gen_for_unary_stream_request(): + thing_name = "my milkshakes" + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + expected_versions = [5, 4, 3, 2, 1] + async for response in client.get_thing_versions( + GetThingRequest(name=thing_name) + ): + assert response.name == thing_name + assert response.version == expected_versions.pop() + + +@pytest.mark.asyncio +async def test_async_gen_for_stream_stream_request(): + some_things = ["cake", "cricket", "coral reef"] + more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"] + expected_things = (*some_things, *more_things) + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + # Use an AsyncChannel to decouple sending and recieving, it'll send some_things + # immediately and we'll use it to send more_things later, after recieving some + # results + request_chan = AsyncChannel() + send_initial_requests = asyncio.ensure_future( + request_chan.send_from(GetThingRequest(name) for name in some_things) + ) + response_index = 0 + async for response in client.get_different_things(request_chan): + assert response.name == expected_things[response_index] + assert response.version == response_index + 1 + response_index += 1 + if more_things: + # Send some more requests as we receive responses to be sure coordination of + # send/receive events doesn't matter + await request_chan.send(GetThingRequest(more_things.pop(0))) + elif not send_initial_requests.done(): + # Make sure the sending task it completed + await send_initial_requests + else: + # No more things to send make sure channel is closed + request_chan.close() + assert response_index == len( + expected_things + ), "Didn't receive all expected responses" + + +@pytest.mark.asyncio +async def test_stream_unary_with_empty_iterable(): + things = [] # empty + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + requests = [DoThingRequest(name) for name in things] + response = await client.do_many_things(requests) + assert len(response.names) == 0 + + +@pytest.mark.asyncio +async def test_stream_stream_with_empty_iterable(): + things = [] # empty + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + requests = [GetThingRequest(name) for name in things] + responses = [ + response async for response in client.get_different_things(requests) + ] + assert len(responses) == 0 diff --git a/tests/grpc/test_stream_stream.py b/tests/grpc/test_stream_stream.py new file mode 100644 index 0000000..d4b27e5 --- /dev/null +++ b/tests/grpc/test_stream_stream.py @@ -0,0 +1,99 @@ +import asyncio +from dataclasses import dataclass +from typing import AsyncIterator + +import pytest + +import aristaproto +from aristaproto.grpc.util.async_channel import AsyncChannel + + +@dataclass +class Message(aristaproto.Message): + body: str = aristaproto.string_field(1) + + +@pytest.fixture +def expected_responses(): + return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] + + +class ClientStub: + async def connect(self, requests: AsyncIterator): + await asyncio.sleep(0.1) + async for request in requests: + await asyncio.sleep(0.1) + yield request + await asyncio.sleep(0.1) + yield Message("Done") + + +async def to_list(generator: AsyncIterator): + return [value async for value in generator] + + +@pytest.fixture +def client(): + # channel = Channel(host='127.0.0.1', port=50051) + # return ClientStub(channel) + return ClientStub() + + +@pytest.mark.asyncio +async def test_send_from_before_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_after_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + responses = client.connect(requests) + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_close_manually_immediately(client, expected_responses): + requests = AsyncChannel() + responses = client.connect(requests) + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False + ) + requests.close() + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_before_connect(client, expected_responses): + requests = AsyncChannel() + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + requests.close() + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_after_connect(client, expected_responses): + requests = AsyncChannel() + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + responses = client.connect(requests) + requests.close() + + assert await to_list(responses) == expected_responses diff --git a/tests/grpc/thing_service.py b/tests/grpc/thing_service.py new file mode 100644 index 0000000..5b00cbe --- /dev/null +++ b/tests/grpc/thing_service.py @@ -0,0 +1,85 @@ +from typing import Dict + +import grpclib +import grpclib.server + +from tests.output_aristaproto.service import ( + DoThingRequest, + DoThingResponse, + GetThingRequest, + GetThingResponse, +) + + +class ThingService: + def __init__(self, test_hook=None): + # This lets us pass assertions to the servicer ;) + self.test_hook = test_hook + + async def do_thing( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse([request.name])) + + async def do_many_things( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + thing_names = [request.name async for request in stream] + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse(thing_names)) + + async def get_thing_versions( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + for version_num in range(1, 6): + await stream.send_message( + GetThingResponse(name=request.name, version=version_num) + ) + + async def get_different_things( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + if self.test_hook is not None: + self.test_hook(stream) + # Respond to each input item immediately + response_num = 0 + async for request in stream: + response_num += 1 + await stream.send_message( + GetThingResponse(name=request.name, version=response_num) + ) + + def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]: + return { + "/service.Test/DoThing": grpclib.const.Handler( + self.do_thing, + grpclib.const.Cardinality.UNARY_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/DoManyThings": grpclib.const.Handler( + self.do_many_things, + grpclib.const.Cardinality.STREAM_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/GetThingVersions": grpclib.const.Handler( + self.get_thing_versions, + grpclib.const.Cardinality.UNARY_STREAM, + GetThingRequest, + GetThingResponse, + ), + "/service.Test/GetDifferentThings": grpclib.const.Handler( + self.get_different_things, + grpclib.const.Cardinality.STREAM_STREAM, + GetThingRequest, + GetThingResponse, + ), + } diff --git a/tests/inputs/bool/bool.json b/tests/inputs/bool/bool.json new file mode 100644 index 0000000..348e031 --- /dev/null +++ b/tests/inputs/bool/bool.json @@ -0,0 +1,3 @@ +{ + "value": true +} diff --git a/tests/inputs/bool/bool.proto b/tests/inputs/bool/bool.proto new file mode 100644 index 0000000..77836b8 --- /dev/null +++ b/tests/inputs/bool/bool.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package bool; + +message Test { + bool value = 1; +} diff --git a/tests/inputs/bool/test_bool.py b/tests/inputs/bool/test_bool.py new file mode 100644 index 0000000..f9554ae --- /dev/null +++ b/tests/inputs/bool/test_bool.py @@ -0,0 +1,19 @@ +import pytest + +from tests.output_aristaproto.bool import Test +from tests.output_aristaproto_pydantic.bool import Test as TestPyd + + +def test_value(): + message = Test() + assert not message.value, "Boolean is False by default" + + +def test_pydantic_no_value(): + with pytest.raises(ValueError): + TestPyd() + + +def test_pydantic_value(): + message = Test(value=False) + assert not message.value diff --git a/tests/inputs/bytes/bytes.json b/tests/inputs/bytes/bytes.json new file mode 100644 index 0000000..34c4554 --- /dev/null +++ b/tests/inputs/bytes/bytes.json @@ -0,0 +1,3 @@ +{ + "data": "SGVsbG8sIFdvcmxkIQ==" +} diff --git a/tests/inputs/bytes/bytes.proto b/tests/inputs/bytes/bytes.proto new file mode 100644 index 0000000..9895468 --- /dev/null +++ b/tests/inputs/bytes/bytes.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package bytes; + +message Test { + bytes data = 1; +} diff --git a/tests/inputs/casing/casing.json b/tests/inputs/casing/casing.json new file mode 100644 index 0000000..559104b --- /dev/null +++ b/tests/inputs/casing/casing.json @@ -0,0 +1,4 @@ +{ + "camelCase": 1, + "snakeCase": "ONE" +} diff --git a/tests/inputs/casing/casing.proto b/tests/inputs/casing/casing.proto new file mode 100644 index 0000000..2023d93 --- /dev/null +++ b/tests/inputs/casing/casing.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package casing; + +enum my_enum { + ZERO = 0; + ONE = 1; + TWO = 2; +} + +message Test { + int32 camelCase = 1; + my_enum snake_case = 2; + snake_case_message snake_case_message = 3; + int32 UPPERCASE = 4; +} + +message snake_case_message { + +}
\ No newline at end of file diff --git a/tests/inputs/casing/test_casing.py b/tests/inputs/casing/test_casing.py new file mode 100644 index 0000000..0fa609b --- /dev/null +++ b/tests/inputs/casing/test_casing.py @@ -0,0 +1,23 @@ +import tests.output_aristaproto.casing as casing +from tests.output_aristaproto.casing import Test + + +def test_message_attributes(): + message = Test() + assert hasattr( + message, "snake_case_message" + ), "snake_case field name is same in python" + assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python" + assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python" + + +def test_message_casing(): + assert hasattr( + casing, "SnakeCaseMessage" + ), "snake_case Message name is converted to CamelCase in python" + + +def test_enum_casing(): + assert hasattr( + casing, "MyEnum" + ), "snake_case Enum name is converted to CamelCase in python" diff --git a/tests/inputs/casing_inner_class/casing_inner_class.proto b/tests/inputs/casing_inner_class/casing_inner_class.proto new file mode 100644 index 0000000..fae2a4c --- /dev/null +++ b/tests/inputs/casing_inner_class/casing_inner_class.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package casing_inner_class; + +message Test { + message inner_class { + sint32 old_exp = 1; + } + inner_class inner = 2; +}
\ No newline at end of file diff --git a/tests/inputs/casing_inner_class/test_casing_inner_class.py b/tests/inputs/casing_inner_class/test_casing_inner_class.py new file mode 100644 index 0000000..7c43add --- /dev/null +++ b/tests/inputs/casing_inner_class/test_casing_inner_class.py @@ -0,0 +1,14 @@ +import tests.output_aristaproto.casing_inner_class as casing_inner_class + + +def test_message_casing_inner_class_name(): + assert hasattr( + casing_inner_class, "TestInnerClass" + ), "Inline defined Message is correctly converted to CamelCase" + + +def test_message_casing_inner_class_attributes(): + message = casing_inner_class.Test() + assert hasattr( + message.inner, "old_exp" + ), "Inline defined Message attribute is snake_case" diff --git a/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.proto b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.proto new file mode 100644 index 0000000..c6d42c3 --- /dev/null +++ b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package casing_message_field_uppercase; + +message Test { + int32 UPPERCASE = 1; + int32 UPPERCASE_V2 = 2; + int32 UPPER_CAMEL_CASE = 3; +}
\ No newline at end of file diff --git a/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py new file mode 100644 index 0000000..01a5234 --- /dev/null +++ b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py @@ -0,0 +1,14 @@ +from tests.output_aristaproto.casing_message_field_uppercase import Test + + +def test_message_casing(): + message = Test() + assert hasattr( + message, "uppercase" + ), "UPPERCASE attribute is converted to 'uppercase' in python" + assert hasattr( + message, "uppercase_v2" + ), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python" + assert hasattr( + message, "upper_camel_case" + ), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python" diff --git a/tests/inputs/config.py b/tests/inputs/config.py new file mode 100644 index 0000000..6da1f88 --- /dev/null +++ b/tests/inputs/config.py @@ -0,0 +1,30 @@ +# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes. +# Remove from list when fixed. +xfail = { + "namespace_keywords", # 70 + "googletypes_struct", # 9 + "googletypes_value", # 9 + "import_capitalized_package", + "example", # This is the example in the readme. Not a test. +} + +services = { + "googletypes_request", + "googletypes_response", + "googletypes_response_embedded", + "service", + "service_separate_packages", + "import_service_input_message", + "googletypes_service_returns_empty", + "googletypes_service_returns_googletype", + "example_service", + "empty_service", + "service_uppercase", +} + + +# Indicate json sample messages to skip when testing that json (de)serialization +# is symmetrical becuase some cases legitimately are not symmetrical. +# Each key references the name of the test scenario and the values in the tuple +# Are the names of the json files. +non_symmetrical_json = {"empty_repeated": ("empty_repeated",)} diff --git a/tests/inputs/deprecated/deprecated.json b/tests/inputs/deprecated/deprecated.json new file mode 100644 index 0000000..43b2b65 --- /dev/null +++ b/tests/inputs/deprecated/deprecated.json @@ -0,0 +1,6 @@ +{ + "message": { + "value": "hello" + }, + "value": 10 +} diff --git a/tests/inputs/deprecated/deprecated.proto b/tests/inputs/deprecated/deprecated.proto new file mode 100644 index 0000000..81d69c0 --- /dev/null +++ b/tests/inputs/deprecated/deprecated.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package deprecated; + +// Some documentation about the Test message. +message Test { + Message message = 1 [deprecated=true]; + int32 value = 2; +} + +message Message { + option deprecated = true; + string value = 1; +} diff --git a/tests/inputs/double/double-negative.json b/tests/inputs/double/double-negative.json new file mode 100644 index 0000000..e0776c7 --- /dev/null +++ b/tests/inputs/double/double-negative.json @@ -0,0 +1,3 @@ +{ + "count": -123.45 +} diff --git a/tests/inputs/double/double.json b/tests/inputs/double/double.json new file mode 100644 index 0000000..321412e --- /dev/null +++ b/tests/inputs/double/double.json @@ -0,0 +1,3 @@ +{ + "count": 123.45 +} diff --git a/tests/inputs/double/double.proto b/tests/inputs/double/double.proto new file mode 100644 index 0000000..66aea95 --- /dev/null +++ b/tests/inputs/double/double.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package double; + +message Test { + double count = 1; +} diff --git a/tests/inputs/empty_repeated/empty_repeated.json b/tests/inputs/empty_repeated/empty_repeated.json new file mode 100644 index 0000000..12a801c --- /dev/null +++ b/tests/inputs/empty_repeated/empty_repeated.json @@ -0,0 +1,3 @@ +{ + "msg": [{"values":[]}] +} diff --git a/tests/inputs/empty_repeated/empty_repeated.proto b/tests/inputs/empty_repeated/empty_repeated.proto new file mode 100644 index 0000000..f787301 --- /dev/null +++ b/tests/inputs/empty_repeated/empty_repeated.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package empty_repeated; + +message MessageA { + repeated float values = 1; +} + +message Test { + repeated MessageA msg = 1; +} diff --git a/tests/inputs/empty_service/empty_service.proto b/tests/inputs/empty_service/empty_service.proto new file mode 100644 index 0000000..e96ff64 --- /dev/null +++ b/tests/inputs/empty_service/empty_service.proto @@ -0,0 +1,7 @@ +/* Empty service without comments */ +syntax = "proto3"; + +package empty_service; + +service Test { +} diff --git a/tests/inputs/entry/entry.proto b/tests/inputs/entry/entry.proto new file mode 100644 index 0000000..3f2af4d --- /dev/null +++ b/tests/inputs/entry/entry.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package entry; + +// This is a minimal example of a repeated message field that caused issues when +// checking whether a message is a map. +// +// During the check wheter a field is a "map", the string "entry" is added to +// the field name, checked against the type name and then further checks are +// made against the nested type of a parent message. In this edge-case, the +// first check would pass even though it shouldn't and that would cause an +// error because the parent type does not have a "nested_type" attribute. + +message Test { + repeated ExportEntry export = 1; +} + +message ExportEntry { + string name = 1; +} diff --git a/tests/inputs/enum/enum.json b/tests/inputs/enum/enum.json new file mode 100644 index 0000000..d68f1c5 --- /dev/null +++ b/tests/inputs/enum/enum.json @@ -0,0 +1,9 @@ +{ + "choice": "FOUR", + "choices": [ + "ZERO", + "ONE", + "THREE", + "FOUR" + ] +} diff --git a/tests/inputs/enum/enum.proto b/tests/inputs/enum/enum.proto new file mode 100644 index 0000000..5e2e80c --- /dev/null +++ b/tests/inputs/enum/enum.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package enum; + +// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values +message Test { + Choice choice = 1; + repeated Choice choices = 2; +} + +enum Choice { + ZERO = 0; + ONE = 1; + // TWO = 2; + FOUR = 4; + THREE = 3; +} + +// A "C" like enum with the enum name prefixed onto members, these should be stripped +enum ArithmeticOperator { + ARITHMETIC_OPERATOR_NONE = 0; + ARITHMETIC_OPERATOR_PLUS = 1; + ARITHMETIC_OPERATOR_MINUS = 2; + ARITHMETIC_OPERATOR_0_PREFIXED = 3; +} diff --git a/tests/inputs/enum/test_enum.py b/tests/inputs/enum/test_enum.py new file mode 100644 index 0000000..cf14c68 --- /dev/null +++ b/tests/inputs/enum/test_enum.py @@ -0,0 +1,114 @@ +from tests.output_aristaproto.enum import ( + ArithmeticOperator, + Choice, + Test, +) + + +def test_enum_set_and_get(): + assert Test(choice=Choice.ZERO).choice == Choice.ZERO + assert Test(choice=Choice.ONE).choice == Choice.ONE + assert Test(choice=Choice.THREE).choice == Choice.THREE + assert Test(choice=Choice.FOUR).choice == Choice.FOUR + + +def test_enum_set_with_int(): + assert Test(choice=0).choice == Choice.ZERO + assert Test(choice=1).choice == Choice.ONE + assert Test(choice=3).choice == Choice.THREE + assert Test(choice=4).choice == Choice.FOUR + + +def test_enum_is_comparable_with_int(): + assert Test(choice=Choice.ZERO).choice == 0 + assert Test(choice=Choice.ONE).choice == 1 + assert Test(choice=Choice.THREE).choice == 3 + assert Test(choice=Choice.FOUR).choice == 4 + + +def test_enum_to_dict(): + assert ( + "choice" not in Test(choice=Choice.ZERO).to_dict() + ), "Default enum value is not serialized" + assert ( + Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"] + == "ZERO" + ) + assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE" + assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE" + assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR" + + +def test_repeated_enum_is_comparable_with_int(): + assert Test(choices=[Choice.ZERO]).choices == [0] + assert Test(choices=[Choice.ONE]).choices == [1] + assert Test(choices=[Choice.THREE]).choices == [3] + assert Test(choices=[Choice.FOUR]).choices == [4] + + +def test_repeated_enum_set_and_get(): + assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO] + assert Test(choices=[Choice.ONE]).choices == [Choice.ONE] + assert Test(choices=[Choice.THREE]).choices == [Choice.THREE] + assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR] + + +def test_repeated_enum_to_dict(): + assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"] + assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"] + assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"] + assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"] + + all_enums_dict = Test( + choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR] + ).to_dict() + assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"] + + +def test_repeated_enum_with_single_value_to_dict(): + assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"] + assert Test(choices=1).to_dict()["choices"] == ["ONE"] + + +def test_repeated_enum_with_non_list_iterables_to_dict(): + assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"] + assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"] + assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [ + "ONE", + "THREE", + ] + + def enum_generator(): + yield Choice.ONE + yield Choice.THREE + + assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"] + + +def test_enum_mapped_on_parse(): + # test default value + b = Test().parse(bytes(Test())) + assert b.choice.name == Choice.ZERO.name + assert b.choices == [] + + # test non default value + a = Test().parse(bytes(Test(choice=Choice.ONE))) + assert a.choice.name == Choice.ONE.name + assert b.choices == [] + + # test repeated + c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR]))) + assert c.choices[0].name == Choice.THREE.name + assert c.choices[1].name == Choice.FOUR.name + + # bonus: defaults after empty init are also mapped + assert Test().choice.name == Choice.ZERO.name + + +def test_renamed_enum_members(): + assert set(ArithmeticOperator.__members__) == { + "NONE", + "PLUS", + "MINUS", + "_0_PREFIXED", + } diff --git a/tests/inputs/example/example.proto b/tests/inputs/example/example.proto new file mode 100644 index 0000000..56bd364 --- /dev/null +++ b/tests/inputs/example/example.proto @@ -0,0 +1,911 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) +// Based on original Protocol Buffers design by +// Sanjay Ghemawat, Jeff Dean, and others. +// +// The messages in this file describe the definitions found in .proto files. +// A valid .proto file can be translated directly to a FileDescriptorProto +// without any other information (e.g. without reading its imports). + + +syntax = "proto2"; + +package example; + +// package google.protobuf; + +option go_package = "google.golang.org/protobuf/types/descriptorpb"; +option java_package = "com.google.protobuf"; +option java_outer_classname = "DescriptorProtos"; +option csharp_namespace = "Google.Protobuf.Reflection"; +option objc_class_prefix = "GPB"; +option cc_enable_arenas = true; + +// descriptor.proto must be optimized for speed because reflection-based +// algorithms don't work during bootstrapping. +option optimize_for = SPEED; + +// The protocol compiler can output a FileDescriptorSet containing the .proto +// files it parses. +message FileDescriptorSet { + repeated FileDescriptorProto file = 1; +} + +// Describes a complete .proto file. +message FileDescriptorProto { + optional string name = 1; // file name, relative to root of source tree + optional string package = 2; // e.g. "foo", "foo.bar", etc. + + // Names of files imported by this file. + repeated string dependency = 3; + // Indexes of the public imported files in the dependency list above. + repeated int32 public_dependency = 10; + // Indexes of the weak imported files in the dependency list. + // For Google-internal migration only. Do not use. + repeated int32 weak_dependency = 11; + + // All top-level definitions in this file. + repeated DescriptorProto message_type = 4; + repeated EnumDescriptorProto enum_type = 5; + repeated ServiceDescriptorProto service = 6; + repeated FieldDescriptorProto extension = 7; + + optional FileOptions options = 8; + + // This field contains optional information about the original source code. + // You may safely remove this entire field without harming runtime + // functionality of the descriptors -- the information is needed only by + // development tools. + optional SourceCodeInfo source_code_info = 9; + + // The syntax of the proto file. + // The supported values are "proto2" and "proto3". + optional string syntax = 12; +} + +// Describes a message type. +message DescriptorProto { + optional string name = 1; + + repeated FieldDescriptorProto field = 2; + repeated FieldDescriptorProto extension = 6; + + repeated DescriptorProto nested_type = 3; + repeated EnumDescriptorProto enum_type = 4; + + message ExtensionRange { + optional int32 start = 1; // Inclusive. + optional int32 end = 2; // Exclusive. + + optional ExtensionRangeOptions options = 3; + } + repeated ExtensionRange extension_range = 5; + + repeated OneofDescriptorProto oneof_decl = 8; + + optional MessageOptions options = 7; + + // Range of reserved tag numbers. Reserved tag numbers may not be used by + // fields or extension ranges in the same message. Reserved ranges may + // not overlap. + message ReservedRange { + optional int32 start = 1; // Inclusive. + optional int32 end = 2; // Exclusive. + } + repeated ReservedRange reserved_range = 9; + // Reserved field names, which may not be used by fields in the same message. + // A given name may only be reserved once. + repeated string reserved_name = 10; +} + +message ExtensionRangeOptions { + // The parser stores options it doesn't recognize here. See above. + repeated UninterpretedOption uninterpreted_option = 999; + + + // Clients can define custom options in extensions of this message. See above. + extensions 1000 to max; +} + +// Describes a field within a message. +message FieldDescriptorProto { + enum Type { + // 0 is reserved for errors. + // Order is weird for historical reasons. + TYPE_DOUBLE = 1; + TYPE_FLOAT = 2; + // Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT64 if + // negative values are likely. + TYPE_INT64 = 3; + TYPE_UINT64 = 4; + // Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT32 if + // negative values are likely. + TYPE_INT32 = 5; + TYPE_FIXED64 = 6; + TYPE_FIXED32 = 7; + TYPE_BOOL = 8; + TYPE_STRING = 9; + // Tag-delimited aggregate. + // Group type is deprecated and not supported in proto3. However, Proto3 + // implementations should still be able to parse the group wire format and + // treat group fields as unknown fields. + TYPE_GROUP = 10; + TYPE_MESSAGE = 11; // Length-delimited aggregate. + + // New in version 2. + TYPE_BYTES = 12; + TYPE_UINT32 = 13; + TYPE_ENUM = 14; + TYPE_SFIXED32 = 15; + TYPE_SFIXED64 = 16; + TYPE_SINT32 = 17; // Uses ZigZag encoding. + TYPE_SINT64 = 18; // Uses ZigZag encoding. + } + + enum Label { + // 0 is reserved for errors + LABEL_OPTIONAL = 1; + LABEL_REQUIRED = 2; + LABEL_REPEATED = 3; + } + + optional string name = 1; + optional int32 number = 3; + optional Label label = 4; + + // If type_name is set, this need not be set. If both this and type_name + // are set, this must be one of TYPE_ENUM, TYPE_MESSAGE or TYPE_GROUP. + optional Type type = 5; + + // For message and enum types, this is the name of the type. If the name + // starts with a '.', it is fully-qualified. Otherwise, C++-like scoping + // rules are used to find the type (i.e. first the nested types within this + // message are searched, then within the parent, on up to the root + // namespace). + optional string type_name = 6; + + // For extensions, this is the name of the type being extended. It is + // resolved in the same manner as type_name. + optional string extendee = 2; + + // For numeric types, contains the original text representation of the value. + // For booleans, "true" or "false". + // For strings, contains the default text contents (not escaped in any way). + // For bytes, contains the C escaped value. All bytes >= 128 are escaped. + // TODO(kenton): Base-64 encode? + optional string default_value = 7; + + // If set, gives the index of a oneof in the containing type's oneof_decl + // list. This field is a member of that oneof. + optional int32 oneof_index = 9; + + // JSON name of this field. The value is set by protocol compiler. If the + // user has set a "json_name" option on this field, that option's value + // will be used. Otherwise, it's deduced from the field's name by converting + // it to camelCase. + optional string json_name = 10; + + optional FieldOptions options = 8; + + // If true, this is a proto3 "optional". When a proto3 field is optional, it + // tracks presence regardless of field type. + // + // When proto3_optional is true, this field must be belong to a oneof to + // signal to old proto3 clients that presence is tracked for this field. This + // oneof is known as a "synthetic" oneof, and this field must be its sole + // member (each proto3 optional field gets its own synthetic oneof). Synthetic + // oneofs exist in the descriptor only, and do not generate any API. Synthetic + // oneofs must be ordered after all "real" oneofs. + // + // For message fields, proto3_optional doesn't create any semantic change, + // since non-repeated message fields always track presence. However it still + // indicates the semantic detail of whether the user wrote "optional" or not. + // This can be useful for round-tripping the .proto file. For consistency we + // give message fields a synthetic oneof also, even though it is not required + // to track presence. This is especially important because the parser can't + // tell if a field is a message or an enum, so it must always create a + // synthetic oneof. + // + // Proto2 optional fields do not set this flag, because they already indicate + // optional with `LABEL_OPTIONAL`. + optional bool proto3_optional = 17; +} + +// Describes a oneof. +message OneofDescriptorProto { + optional string name = 1; + optional OneofOptions options = 2; +} + +// Describes an enum type. +message EnumDescriptorProto { + optional string name = 1; + + repeated EnumValueDescriptorProto value = 2; + + optional EnumOptions options = 3; + + // Range of reserved numeric values. Reserved values may not be used by + // entries in the same enum. Reserved ranges may not overlap. + // + // Note that this is distinct from DescriptorProto.ReservedRange in that it + // is inclusive such that it can appropriately represent the entire int32 + // domain. + message EnumReservedRange { + optional int32 start = 1; // Inclusive. + optional int32 end = 2; // Inclusive. + } + + // Range of reserved numeric values. Reserved numeric values may not be used + // by enum values in the same enum declaration. Reserved ranges may not + // overlap. + repeated EnumReservedRange reserved_range = 4; + + // Reserved enum value names, which may not be reused. A given name may only + // be reserved once. + repeated string reserved_name = 5; +} + +// Describes a value within an enum. +message EnumValueDescriptorProto { + optional string name = 1; + optional int32 number = 2; + + optional EnumValueOptions options = 3; +} + +// Describes a service. +message ServiceDescriptorProto { + optional string name = 1; + repeated MethodDescriptorProto method = 2; + + optional ServiceOptions options = 3; +} + +// Describes a method of a service. +message MethodDescriptorProto { + optional string name = 1; + + // Input and output type names. These are resolved in the same way as + // FieldDescriptorProto.type_name, but must refer to a message type. + optional string input_type = 2; + optional string output_type = 3; + + optional MethodOptions options = 4; + + // Identifies if client streams multiple client messages + optional bool client_streaming = 5 [default = false]; + // Identifies if server streams multiple server messages + optional bool server_streaming = 6 [default = false]; +} + + +// =================================================================== +// Options + +// Each of the definitions above may have "options" attached. These are +// just annotations which may cause code to be generated slightly differently +// or may contain hints for code that manipulates protocol messages. +// +// Clients may define custom options as extensions of the *Options messages. +// These extensions may not yet be known at parsing time, so the parser cannot +// store the values in them. Instead it stores them in a field in the *Options +// message called uninterpreted_option. This field must have the same name +// across all *Options messages. We then use this field to populate the +// extensions when we build a descriptor, at which point all protos have been +// parsed and so all extensions are known. +// +// Extension numbers for custom options may be chosen as follows: +// * For options which will only be used within a single application or +// organization, or for experimental options, use field numbers 50000 +// through 99999. It is up to you to ensure that you do not use the +// same number for multiple options. +// * For options which will be published and used publicly by multiple +// independent entities, e-mail protobuf-global-extension-registry@google.com +// to reserve extension numbers. Simply provide your project name (e.g. +// Objective-C plugin) and your project website (if available) -- there's no +// need to explain how you intend to use them. Usually you only need one +// extension number. You can declare multiple options with only one extension +// number by putting them in a sub-message. See the Custom Options section of +// the docs for examples: +// https://developers.google.com/protocol-buffers/docs/proto#options +// If this turns out to be popular, a web service will be set up +// to automatically assign option numbers. + +message FileOptions { + + // Sets the Java package where classes generated from this .proto will be + // placed. By default, the proto package is used, but this is often + // inappropriate because proto packages do not normally start with backwards + // domain names. + optional string java_package = 1; + + + // If set, all the classes from the .proto file are wrapped in a single + // outer class with the given name. This applies to both Proto1 + // (equivalent to the old "--one_java_file" option) and Proto2 (where + // a .proto always translates to a single class, but you may want to + // explicitly choose the class name). + optional string java_outer_classname = 8; + + // If set true, then the Java code generator will generate a separate .java + // file for each top-level message, enum, and service defined in the .proto + // file. Thus, these types will *not* be nested inside the outer class + // named by java_outer_classname. However, the outer class will still be + // generated to contain the file's getDescriptor() method as well as any + // top-level extensions defined in the file. + optional bool java_multiple_files = 10 [default = false]; + + // This option does nothing. + optional bool java_generate_equals_and_hash = 20 [deprecated=true]; + + // If set true, then the Java2 code generator will generate code that + // throws an exception whenever an attempt is made to assign a non-UTF-8 + // byte sequence to a string field. + // Message reflection will do the same. + // However, an extension field still accepts non-UTF-8 byte sequences. + // This option has no effect on when used with the lite runtime. + optional bool java_string_check_utf8 = 27 [default = false]; + + + // Generated classes can be optimized for speed or code size. + enum OptimizeMode { + SPEED = 1; // Generate complete code for parsing, serialization, + // etc. + CODE_SIZE = 2; // Use ReflectionOps to implement these methods. + LITE_RUNTIME = 3; // Generate code using MessageLite and the lite runtime. + } + optional OptimizeMode optimize_for = 9 [default = SPEED]; + + // Sets the Go package where structs generated from this .proto will be + // placed. If omitted, the Go package will be derived from the following: + // - The basename of the package import path, if provided. + // - Otherwise, the package statement in the .proto file, if present. + // - Otherwise, the basename of the .proto file, without extension. + optional string go_package = 11; + + + + + // Should generic services be generated in each language? "Generic" services + // are not specific to any particular RPC system. They are generated by the + // main code generators in each language (without additional plugins). + // Generic services were the only kind of service generation supported by + // early versions of google.protobuf. + // + // Generic services are now considered deprecated in favor of using plugins + // that generate code specific to your particular RPC system. Therefore, + // these default to false. Old code which depends on generic services should + // explicitly set them to true. + optional bool cc_generic_services = 16 [default = false]; + optional bool java_generic_services = 17 [default = false]; + optional bool py_generic_services = 18 [default = false]; + optional bool php_generic_services = 42 [default = false]; + + // Is this file deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for everything in the file, or it will be completely ignored; in the very + // least, this is a formalization for deprecating files. + optional bool deprecated = 23 [default = false]; + + // Enables the use of arenas for the proto messages in this file. This applies + // only to generated classes for C++. + optional bool cc_enable_arenas = 31 [default = true]; + + + // Sets the objective c class prefix which is prepended to all objective c + // generated classes from this .proto. There is no default. + optional string objc_class_prefix = 36; + + // Namespace for generated classes; defaults to the package. + optional string csharp_namespace = 37; + + // By default Swift generators will take the proto package and CamelCase it + // replacing '.' with underscore and use that to prefix the types/symbols + // defined. When this options is provided, they will use this value instead + // to prefix the types/symbols defined. + optional string swift_prefix = 39; + + // Sets the php class prefix which is prepended to all php generated classes + // from this .proto. Default is empty. + optional string php_class_prefix = 40; + + // Use this option to change the namespace of php generated classes. Default + // is empty. When this option is empty, the package name will be used for + // determining the namespace. + optional string php_namespace = 41; + + // Use this option to change the namespace of php generated metadata classes. + // Default is empty. When this option is empty, the proto file name will be + // used for determining the namespace. + optional string php_metadata_namespace = 44; + + // Use this option to change the package of ruby generated classes. Default + // is empty. When this option is not set, the package name will be used for + // determining the ruby package. + optional string ruby_package = 45; + + + // The parser stores options it doesn't recognize here. + // See the documentation for the "Options" section above. + repeated UninterpretedOption uninterpreted_option = 999; + + // Clients can define custom options in extensions of this message. + // See the documentation for the "Options" section above. + extensions 1000 to max; + + reserved 38; +} + +message MessageOptions { + // Set true to use the old proto1 MessageSet wire format for extensions. + // This is provided for backwards-compatibility with the MessageSet wire + // format. You should not use this for any other reason: It's less + // efficient, has fewer features, and is more complicated. + // + // The message must be defined exactly as follows: + // message Foo { + // option message_set_wire_format = true; + // extensions 4 to max; + // } + // Note that the message cannot have any defined fields; MessageSets only + // have extensions. + // + // All extensions of your type must be singular messages; e.g. they cannot + // be int32s, enums, or repeated messages. + // + // Because this is an option, the above two restrictions are not enforced by + // the protocol compiler. + optional bool message_set_wire_format = 1 [default = false]; + + // Disables the generation of the standard "descriptor()" accessor, which can + // conflict with a field of the same name. This is meant to make migration + // from proto1 easier; new code should avoid fields named "descriptor". + optional bool no_standard_descriptor_accessor = 2 [default = false]; + + // Is this message deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the message, or it will be completely ignored; in the very least, + // this is a formalization for deprecating messages. + optional bool deprecated = 3 [default = false]; + + // Whether the message is an automatically generated map entry type for the + // maps field. + // + // For maps fields: + // map<KeyType, ValueType> map_field = 1; + // The parsed descriptor looks like: + // message MapFieldEntry { + // option map_entry = true; + // optional KeyType key = 1; + // optional ValueType value = 2; + // } + // repeated MapFieldEntry map_field = 1; + // + // Implementations may choose not to generate the map_entry=true message, but + // use a native map in the target language to hold the keys and values. + // The reflection APIs in such implementations still need to work as + // if the field is a repeated message field. + // + // NOTE: Do not set the option in .proto files. Always use the maps syntax + // instead. The option should only be implicitly set by the proto compiler + // parser. + optional bool map_entry = 7; + + reserved 8; // javalite_serializable + reserved 9; // javanano_as_lite + + + // The parser stores options it doesn't recognize here. See above. + repeated UninterpretedOption uninterpreted_option = 999; + + // Clients can define custom options in extensions of this message. See above. + extensions 1000 to max; +} + +message FieldOptions { + // The ctype option instructs the C++ code generator to use a different + // representation of the field than it normally would. See the specific + // options below. This option is not yet implemented in the open source + // release -- sorry, we'll try to include it in a future version! + optional CType ctype = 1 [default = STRING]; + enum CType { + // Default mode. + STRING = 0; + + CORD = 1; + + STRING_PIECE = 2; + } + // The packed option can be enabled for repeated primitive fields to enable + // a more efficient representation on the wire. Rather than repeatedly + // writing the tag and type for each element, the entire array is encoded as + // a single length-delimited blob. In proto3, only explicit setting it to + // false will avoid using packed encoding. + optional bool packed = 2; + + // The jstype option determines the JavaScript type used for values of the + // field. The option is permitted only for 64 bit integral and fixed types + // (int64, uint64, sint64, fixed64, sfixed64). A field with jstype JS_STRING + // is represented as JavaScript string, which avoids loss of precision that + // can happen when a large value is converted to a floating point JavaScript. + // Specifying JS_NUMBER for the jstype causes the generated JavaScript code to + // use the JavaScript "number" type. The behavior of the default option + // JS_NORMAL is implementation dependent. + // + // This option is an enum to permit additional types to be added, e.g. + // goog.math.Integer. + optional JSType jstype = 6 [default = JS_NORMAL]; + enum JSType { + // Use the default type. + JS_NORMAL = 0; + + // Use JavaScript strings. + JS_STRING = 1; + + // Use JavaScript numbers. + JS_NUMBER = 2; + } + + // Should this field be parsed lazily? Lazy applies only to message-type + // fields. It means that when the outer message is initially parsed, the + // inner message's contents will not be parsed but instead stored in encoded + // form. The inner message will actually be parsed when it is first accessed. + // + // This is only a hint. Implementations are free to choose whether to use + // eager or lazy parsing regardless of the value of this option. However, + // setting this option true suggests that the protocol author believes that + // using lazy parsing on this field is worth the additional bookkeeping + // overhead typically needed to implement it. + // + // This option does not affect the public interface of any generated code; + // all method signatures remain the same. Furthermore, thread-safety of the + // interface is not affected by this option; const methods remain safe to + // call from multiple threads concurrently, while non-const methods continue + // to require exclusive access. + // + // + // Note that implementations may choose not to check required fields within + // a lazy sub-message. That is, calling IsInitialized() on the outer message + // may return true even if the inner message has missing required fields. + // This is necessary because otherwise the inner message would have to be + // parsed in order to perform the check, defeating the purpose of lazy + // parsing. An implementation which chooses not to check required fields + // must be consistent about it. That is, for any particular sub-message, the + // implementation must either *always* check its required fields, or *never* + // check its required fields, regardless of whether or not the message has + // been parsed. + optional bool lazy = 5 [default = false]; + + // Is this field deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for accessors, or it will be completely ignored; in the very least, this + // is a formalization for deprecating fields. + optional bool deprecated = 3 [default = false]; + + // For Google-internal migration only. Do not use. + optional bool weak = 10 [default = false]; + + + // The parser stores options it doesn't recognize here. See above. + repeated UninterpretedOption uninterpreted_option = 999; + + // Clients can define custom options in extensions of this message. See above. + extensions 1000 to max; + + reserved 4; // removed jtype +} + +message OneofOptions { + // The parser stores options it doesn't recognize here. See above. + repeated UninterpretedOption uninterpreted_option = 999; + + // Clients can define custom options in extensions of this message. See above. + extensions 1000 to max; +} + +message EnumOptions { + + // Set this option to true to allow mapping different tag names to the same + // value. + optional bool allow_alias = 2; + + // Is this enum deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the enum, or it will be completely ignored; in the very least, this + // is a formalization for deprecating enums. + optional bool deprecated = 3 [default = false]; + + reserved 5; // javanano_as_lite + + // The parser stores options it doesn't recognize here. See above. + repeated UninterpretedOption uninterpreted_option = 999; + + // Clients can define custom options in extensions of this message. See above. + extensions 1000 to max; +} + +message EnumValueOptions { + // Is this enum value deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the enum value, or it will be completely ignored; in the very least, + // this is a formalization for deprecating enum values. + optional bool deprecated = 1 [default = false]; + + // The parser stores options it doesn't recognize here. See above. + repeated UninterpretedOption uninterpreted_option = 999; + + // Clients can define custom options in extensions of this message. See above. + extensions 1000 to max; +} + +message ServiceOptions { + + // Note: Field numbers 1 through 32 are reserved for Google's internal RPC + // framework. We apologize for hoarding these numbers to ourselves, but + // we were already using them long before we decided to release Protocol + // Buffers. + + // Is this service deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the service, or it will be completely ignored; in the very least, + // this is a formalization for deprecating services. + optional bool deprecated = 33 [default = false]; + + // The parser stores options it doesn't recognize here. See above. + repeated UninterpretedOption uninterpreted_option = 999; + + // Clients can define custom options in extensions of this message. See above. + extensions 1000 to max; +} + +message MethodOptions { + + // Note: Field numbers 1 through 32 are reserved for Google's internal RPC + // framework. We apologize for hoarding these numbers to ourselves, but + // we were already using them long before we decided to release Protocol + // Buffers. + + // Is this method deprecated? + // Depending on the target platform, this can emit Deprecated annotations + // for the method, or it will be completely ignored; in the very least, + // this is a formalization for deprecating methods. + optional bool deprecated = 33 [default = false]; + + // Is this method side-effect-free (or safe in HTTP parlance), or idempotent, + // or neither? HTTP based RPC implementation may choose GET verb for safe + // methods, and PUT verb for idempotent methods instead of the default POST. + enum IdempotencyLevel { + IDEMPOTENCY_UNKNOWN = 0; + NO_SIDE_EFFECTS = 1; // implies idempotent + IDEMPOTENT = 2; // idempotent, but may have side effects + } + optional IdempotencyLevel idempotency_level = 34 + [default = IDEMPOTENCY_UNKNOWN]; + + // The parser stores options it doesn't recognize here. See above. + repeated UninterpretedOption uninterpreted_option = 999; + + // Clients can define custom options in extensions of this message. See above. + extensions 1000 to max; +} + + +// A message representing a option the parser does not recognize. This only +// appears in options protos created by the compiler::Parser class. +// DescriptorPool resolves these when building Descriptor objects. Therefore, +// options protos in descriptor objects (e.g. returned by Descriptor::options(), +// or produced by Descriptor::CopyTo()) will never have UninterpretedOptions +// in them. +message UninterpretedOption { + // The name of the uninterpreted option. Each string represents a segment in + // a dot-separated name. is_extension is true iff a segment represents an + // extension (denoted with parentheses in options specs in .proto files). + // E.g.,{ ["foo", false], ["bar.baz", true], ["qux", false] } represents + // "foo.(bar.baz).qux". + message NamePart { + required string name_part = 1; + required bool is_extension = 2; + } + repeated NamePart name = 2; + + // The value of the uninterpreted option, in whatever type the tokenizer + // identified it as during parsing. Exactly one of these should be set. + optional string identifier_value = 3; + optional uint64 positive_int_value = 4; + optional int64 negative_int_value = 5; + optional double double_value = 6; + optional bytes string_value = 7; + optional string aggregate_value = 8; +} + +// =================================================================== +// Optional source code info + +// Encapsulates information about the original source file from which a +// FileDescriptorProto was generated. +message SourceCodeInfo { + // A Location identifies a piece of source code in a .proto file which + // corresponds to a particular definition. This information is intended + // to be useful to IDEs, code indexers, documentation generators, and similar + // tools. + // + // For example, say we have a file like: + // message Foo { + // optional string foo = 1; + // } + // Let's look at just the field definition: + // optional string foo = 1; + // ^ ^^ ^^ ^ ^^^ + // a bc de f ghi + // We have the following locations: + // span path represents + // [a,i) [ 4, 0, 2, 0 ] The whole field definition. + // [a,b) [ 4, 0, 2, 0, 4 ] The label (optional). + // [c,d) [ 4, 0, 2, 0, 5 ] The type (string). + // [e,f) [ 4, 0, 2, 0, 1 ] The name (foo). + // [g,h) [ 4, 0, 2, 0, 3 ] The number (1). + // + // Notes: + // - A location may refer to a repeated field itself (i.e. not to any + // particular index within it). This is used whenever a set of elements are + // logically enclosed in a single code segment. For example, an entire + // extend block (possibly containing multiple extension definitions) will + // have an outer location whose path refers to the "extensions" repeated + // field without an index. + // - Multiple locations may have the same path. This happens when a single + // logical declaration is spread out across multiple places. The most + // obvious example is the "extend" block again -- there may be multiple + // extend blocks in the same scope, each of which will have the same path. + // - A location's span is not always a subset of its parent's span. For + // example, the "extendee" of an extension declaration appears at the + // beginning of the "extend" block and is shared by all extensions within + // the block. + // - Just because a location's span is a subset of some other location's span + // does not mean that it is a descendant. For example, a "group" defines + // both a type and a field in a single declaration. Thus, the locations + // corresponding to the type and field and their components will overlap. + // - Code which tries to interpret locations should probably be designed to + // ignore those that it doesn't understand, as more types of locations could + // be recorded in the future. + repeated Location location = 1; + message Location { + // Identifies which part of the FileDescriptorProto was defined at this + // location. + // + // Each element is a field number or an index. They form a path from + // the root FileDescriptorProto to the place where the definition. For + // example, this path: + // [ 4, 3, 2, 7, 1 ] + // refers to: + // file.message_type(3) // 4, 3 + // .field(7) // 2, 7 + // .name() // 1 + // This is because FileDescriptorProto.message_type has field number 4: + // repeated DescriptorProto message_type = 4; + // and DescriptorProto.field has field number 2: + // repeated FieldDescriptorProto field = 2; + // and FieldDescriptorProto.name has field number 1: + // optional string name = 1; + // + // Thus, the above path gives the location of a field name. If we removed + // the last element: + // [ 4, 3, 2, 7 ] + // this path refers to the whole field declaration (from the beginning + // of the label to the terminating semicolon). + repeated int32 path = 1 [packed = true]; + + // Always has exactly three or four elements: start line, start column, + // end line (optional, otherwise assumed same as start line), end column. + // These are packed into a single field for efficiency. Note that line + // and column numbers are zero-based -- typically you will want to add + // 1 to each before displaying to a user. + repeated int32 span = 2 [packed = true]; + + // If this SourceCodeInfo represents a complete declaration, these are any + // comments appearing before and after the declaration which appear to be + // attached to the declaration. + // + // A series of line comments appearing on consecutive lines, with no other + // tokens appearing on those lines, will be treated as a single comment. + // + // leading_detached_comments will keep paragraphs of comments that appear + // before (but not connected to) the current element. Each paragraph, + // separated by empty lines, will be one comment element in the repeated + // field. + // + // Only the comment content is provided; comment markers (e.g. //) are + // stripped out. For block comments, leading whitespace and an asterisk + // will be stripped from the beginning of each line other than the first. + // Newlines are included in the output. + // + // Examples: + // + // optional int32 foo = 1; // Comment attached to foo. + // // Comment attached to bar. + // optional int32 bar = 2; + // + // optional string baz = 3; + // // Comment attached to baz. + // // Another line attached to baz. + // + // // Comment attached to qux. + // // + // // Another line attached to qux. + // optional double qux = 4; + // + // // Detached comment for corge. This is not leading or trailing comments + // // to qux or corge because there are blank lines separating it from + // // both. + // + // // Detached comment for corge paragraph 2. + // + // optional string corge = 5; + // /* Block comment attached + // * to corge. Leading asterisks + // * will be removed. */ + // /* Block comment attached to + // * grault. */ + // optional int32 grault = 6; + // + // // ignored detached comments. + optional string leading_comments = 3; + optional string trailing_comments = 4; + repeated string leading_detached_comments = 6; + } +} + +// Describes the relationship between generated code and its original source +// file. A GeneratedCodeInfo message is associated with only one generated +// source file, but may contain references to different source .proto files. +message GeneratedCodeInfo { + // An Annotation connects some span of text in generated code to an element + // of its generating .proto file. + repeated Annotation annotation = 1; + message Annotation { + // Identifies the element in the original source .proto file. This field + // is formatted the same as SourceCodeInfo.Location.path. + repeated int32 path = 1 [packed = true]; + + // Identifies the filesystem path to the original source .proto. + optional string source_file = 2; + + // Identifies the starting offset in bytes in the generated code + // that relates to the identified object. + optional int32 begin = 3; + + // Identifies the ending offset in bytes in the generated code that + // relates to the identified offset. The end offset should be one past + // the last relevant byte (so the length of the text = end - begin). + optional int32 end = 4; + } +} diff --git a/tests/inputs/example_service/example_service.proto b/tests/inputs/example_service/example_service.proto new file mode 100644 index 0000000..96455cc --- /dev/null +++ b/tests/inputs/example_service/example_service.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package example_service; + +service Test { + rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse); + rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse); + rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse); + rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse); +} + +message ExampleRequest { + string example_string = 1; + int64 example_integer = 2; +} + +message ExampleResponse { + string example_string = 1; + int64 example_integer = 2; +} diff --git a/tests/inputs/example_service/test_example_service.py b/tests/inputs/example_service/test_example_service.py new file mode 100644 index 0000000..551e3fe --- /dev/null +++ b/tests/inputs/example_service/test_example_service.py @@ -0,0 +1,86 @@ +from typing import ( + AsyncIterable, + AsyncIterator, +) + +import pytest +from grpclib.testing import ChannelFor + +from tests.output_aristaproto.example_service import ( + ExampleRequest, + ExampleResponse, + TestBase, + TestStub, +) + + +class ExampleService(TestBase): + async def example_unary_unary( + self, example_request: ExampleRequest + ) -> "ExampleResponse": + return ExampleResponse( + example_string=example_request.example_string, + example_integer=example_request.example_integer, + ) + + async def example_unary_stream( + self, example_request: ExampleRequest + ) -> AsyncIterator["ExampleResponse"]: + response = ExampleResponse( + example_string=example_request.example_string, + example_integer=example_request.example_integer, + ) + yield response + yield response + yield response + + async def example_stream_unary( + self, example_request_iterator: AsyncIterator["ExampleRequest"] + ) -> "ExampleResponse": + async for example_request in example_request_iterator: + return ExampleResponse( + example_string=example_request.example_string, + example_integer=example_request.example_integer, + ) + + async def example_stream_stream( + self, example_request_iterator: AsyncIterator["ExampleRequest"] + ) -> AsyncIterator["ExampleResponse"]: + async for example_request in example_request_iterator: + yield ExampleResponse( + example_string=example_request.example_string, + example_integer=example_request.example_integer, + ) + + +@pytest.mark.asyncio +async def test_calls_with_different_cardinalities(): + example_request = ExampleRequest("test string", 42) + + async with ChannelFor([ExampleService()]) as channel: + stub = TestStub(channel) + + # unary unary + response = await stub.example_unary_unary(example_request) + assert response.example_string == example_request.example_string + assert response.example_integer == example_request.example_integer + + # unary stream + async for response in stub.example_unary_stream(example_request): + assert response.example_string == example_request.example_string + assert response.example_integer == example_request.example_integer + + # stream unary + async def request_iterator(): + yield example_request + yield example_request + yield example_request + + response = await stub.example_stream_unary(request_iterator()) + assert response.example_string == example_request.example_string + assert response.example_integer == example_request.example_integer + + # stream stream + async for response in stub.example_stream_stream(request_iterator()): + assert response.example_string == example_request.example_string + assert response.example_integer == example_request.example_integer diff --git a/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.json b/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.json new file mode 100644 index 0000000..7a6e7ae --- /dev/null +++ b/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.json @@ -0,0 +1,7 @@ +{ + "int": 26, + "float": 26.0, + "str": "value-for-str", + "bytes": "001a", + "bool": true +}
\ No newline at end of file diff --git a/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.proto b/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.proto new file mode 100644 index 0000000..81a0fc4 --- /dev/null +++ b/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package field_name_identical_to_type; + +// Tests that messages may contain fields with names that are identical to their python types (PR #294) + +message Test { + int32 int = 1; + float float = 2; + string str = 3; + bytes bytes = 4; + bool bool = 5; +}
\ No newline at end of file diff --git a/tests/inputs/fixed/fixed.json b/tests/inputs/fixed/fixed.json new file mode 100644 index 0000000..8858780 --- /dev/null +++ b/tests/inputs/fixed/fixed.json @@ -0,0 +1,6 @@ +{ + "foo": 4294967295, + "bar": -2147483648, + "baz": "18446744073709551615", + "qux": "-9223372036854775808" +} diff --git a/tests/inputs/fixed/fixed.proto b/tests/inputs/fixed/fixed.proto new file mode 100644 index 0000000..0f0ffb4 --- /dev/null +++ b/tests/inputs/fixed/fixed.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package fixed; + +message Test { + fixed32 foo = 1; + sfixed32 bar = 2; + fixed64 baz = 3; + sfixed64 qux = 4; +} diff --git a/tests/inputs/float/float.json b/tests/inputs/float/float.json new file mode 100644 index 0000000..3adac97 --- /dev/null +++ b/tests/inputs/float/float.json @@ -0,0 +1,9 @@ +{ + "positive": "Infinity", + "negative": "-Infinity", + "nan": "NaN", + "three": 3.0, + "threePointOneFour": 3.14, + "negThree": -3.0, + "negThreePointOneFour": -3.14 + } diff --git a/tests/inputs/float/float.proto b/tests/inputs/float/float.proto new file mode 100644 index 0000000..fea12b3 --- /dev/null +++ b/tests/inputs/float/float.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package float; + +// Some documentation about the Test message. +message Test { + double positive = 1; + double negative = 2; + double nan = 3; + double three = 4; + double three_point_one_four = 5; + double neg_three = 6; + double neg_three_point_one_four = 7; +} diff --git a/tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto b/tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto new file mode 100644 index 0000000..66ef8a6 --- /dev/null +++ b/tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; +package google_impl_behavior_equivalence; + +message Foo { int64 bar = 1; } + +message Test { + oneof group { + string string = 1; + int64 integer = 2; + Foo foo = 3; + } +} + +message Spam { + google.protobuf.Timestamp ts = 1; +} + +message Request { Empty foo = 1; } + +message Empty {} diff --git a/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py new file mode 100644 index 0000000..c621f11 --- /dev/null +++ b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py @@ -0,0 +1,93 @@ +from datetime import ( + datetime, + timezone, +) + +import pytest +from google.protobuf import json_format +from google.protobuf.timestamp_pb2 import Timestamp + +import aristaproto +from tests.output_aristaproto.google_impl_behavior_equivalence import ( + Empty, + Foo, + Request, + Spam, + Test, +) +from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import ( + Empty as ReferenceEmpty, + Foo as ReferenceFoo, + Request as ReferenceRequest, + Spam as ReferenceSpam, + Test as ReferenceTest, +) + + +def test_oneof_serializes_similar_to_google_oneof(): + tests = [ + (Test(string="abc"), ReferenceTest(string="abc")), + (Test(integer=2), ReferenceTest(integer=2)), + (Test(foo=Foo(bar=1)), ReferenceTest(foo=ReferenceFoo(bar=1))), + # Default values should also behave the same within oneofs + (Test(string=""), ReferenceTest(string="")), + (Test(integer=0), ReferenceTest(integer=0)), + (Test(foo=Foo(bar=0)), ReferenceTest(foo=ReferenceFoo(bar=0))), + ] + for message, message_reference in tests: + # NOTE: As of July 2020, MessageToJson inserts newlines in the output string so, + # just compare dicts + assert message.to_dict() == json_format.MessageToDict(message_reference) + + +def test_bytes_are_the_same_for_oneof(): + message = Test(string="") + message_reference = ReferenceTest(string="") + + message_bytes = bytes(message) + message_reference_bytes = message_reference.SerializeToString() + + assert message_bytes == message_reference_bytes + + message2 = Test().parse(message_reference_bytes) + message_reference2 = ReferenceTest() + message_reference2.ParseFromString(message_reference_bytes) + + assert message == message2 + assert message_reference == message_reference2 + + # None of these fields were explicitly set BUT they should not actually be null + # themselves + assert not hasattr(message, "foo") + assert object.__getattribute__(message, "foo") == aristaproto.PLACEHOLDER + assert not hasattr(message2, "foo") + assert object.__getattribute__(message2, "foo") == aristaproto.PLACEHOLDER + + assert isinstance(message_reference.foo, ReferenceFoo) + assert isinstance(message_reference2.foo, ReferenceFoo) + + +@pytest.mark.parametrize("dt", (datetime.min.replace(tzinfo=timezone.utc),)) +def test_datetime_clamping(dt): # see #407 + ts = Timestamp() + ts.FromDatetime(dt) + assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString() + message_bytes = bytes(Spam(dt)) + + assert ( + Spam().parse(message_bytes).ts.timestamp() + == ReferenceSpam.FromString(message_bytes).ts.seconds + ) + + +def test_empty_message_field(): + message = Request() + reference_message = ReferenceRequest() + + message.foo = Empty() + reference_message.foo.CopyFrom(ReferenceEmpty()) + + assert aristaproto.serialized_on_wire(message.foo) + assert reference_message.HasField("foo") + + assert bytes(message) == reference_message.SerializeToString() diff --git a/tests/inputs/googletypes/googletypes-missing.json b/tests/inputs/googletypes/googletypes-missing.json new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/tests/inputs/googletypes/googletypes-missing.json @@ -0,0 +1 @@ +{} diff --git a/tests/inputs/googletypes/googletypes.json b/tests/inputs/googletypes/googletypes.json new file mode 100644 index 0000000..0a002e9 --- /dev/null +++ b/tests/inputs/googletypes/googletypes.json @@ -0,0 +1,7 @@ +{ + "maybe": false, + "ts": "1972-01-01T10:00:20.021Z", + "duration": "1.200s", + "important": 10, + "empty": {} +} diff --git a/tests/inputs/googletypes/googletypes.proto b/tests/inputs/googletypes/googletypes.proto new file mode 100644 index 0000000..ef8cb4a --- /dev/null +++ b/tests/inputs/googletypes/googletypes.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package googletypes; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; +import "google/protobuf/empty.proto"; + +message Test { + google.protobuf.BoolValue maybe = 1; + google.protobuf.Timestamp ts = 2; + google.protobuf.Duration duration = 3; + google.protobuf.Int32Value important = 4; + google.protobuf.Empty empty = 5; +} diff --git a/tests/inputs/googletypes_request/googletypes_request.proto b/tests/inputs/googletypes_request/googletypes_request.proto new file mode 100644 index 0000000..1cedcaa --- /dev/null +++ b/tests/inputs/googletypes_request/googletypes_request.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package googletypes_request; + +import "google/protobuf/duration.proto"; +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; + +// Tests that google types can be used as params + +service Test { + rpc SendDouble (google.protobuf.DoubleValue) returns (Input); + rpc SendFloat (google.protobuf.FloatValue) returns (Input); + rpc SendInt64 (google.protobuf.Int64Value) returns (Input); + rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input); + rpc SendInt32 (google.protobuf.Int32Value) returns (Input); + rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input); + rpc SendBool (google.protobuf.BoolValue) returns (Input); + rpc SendString (google.protobuf.StringValue) returns (Input); + rpc SendBytes (google.protobuf.BytesValue) returns (Input); + rpc SendDatetime (google.protobuf.Timestamp) returns (Input); + rpc SendTimedelta (google.protobuf.Duration) returns (Input); + rpc SendEmpty (google.protobuf.Empty) returns (Input); +} + +message Input { + +} diff --git a/tests/inputs/googletypes_request/test_googletypes_request.py b/tests/inputs/googletypes_request/test_googletypes_request.py new file mode 100644 index 0000000..8351f71 --- /dev/null +++ b/tests/inputs/googletypes_request/test_googletypes_request.py @@ -0,0 +1,47 @@ +from datetime import ( + datetime, + timedelta, +) +from typing import ( + Any, + Callable, +) + +import pytest + +import aristaproto.lib.google.protobuf as protobuf +from tests.mocks import MockChannel +from tests.output_aristaproto.googletypes_request import ( + Input, + TestStub, +) + + +test_cases = [ + (TestStub.send_double, protobuf.DoubleValue, 2.5), + (TestStub.send_float, protobuf.FloatValue, 2.5), + (TestStub.send_int64, protobuf.Int64Value, -64), + (TestStub.send_u_int64, protobuf.UInt64Value, 64), + (TestStub.send_int32, protobuf.Int32Value, -32), + (TestStub.send_u_int32, protobuf.UInt32Value, 32), + (TestStub.send_bool, protobuf.BoolValue, True), + (TestStub.send_string, protobuf.StringValue, "string"), + (TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]), + (TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)), + (TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)), +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) +async def test_channel_receives_wrapped_type( + service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value +): + wrapped_value = wrapper_class() + wrapped_value.value = value + channel = MockChannel(responses=[Input()]) + service = TestStub(channel) + + await service_method(service, wrapped_value) + + assert channel.requests[0]["request"] == type(wrapped_value) diff --git a/tests/inputs/googletypes_response/googletypes_response.proto b/tests/inputs/googletypes_response/googletypes_response.proto new file mode 100644 index 0000000..8917d1c --- /dev/null +++ b/tests/inputs/googletypes_response/googletypes_response.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; + +package googletypes_response; + +import "google/protobuf/wrappers.proto"; + +// Tests that wrapped values can be used directly as return values + +service Test { + rpc GetDouble (Input) returns (google.protobuf.DoubleValue); + rpc GetFloat (Input) returns (google.protobuf.FloatValue); + rpc GetInt64 (Input) returns (google.protobuf.Int64Value); + rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value); + rpc GetInt32 (Input) returns (google.protobuf.Int32Value); + rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value); + rpc GetBool (Input) returns (google.protobuf.BoolValue); + rpc GetString (Input) returns (google.protobuf.StringValue); + rpc GetBytes (Input) returns (google.protobuf.BytesValue); +} + +message Input { + +} diff --git a/tests/inputs/googletypes_response/test_googletypes_response.py b/tests/inputs/googletypes_response/test_googletypes_response.py new file mode 100644 index 0000000..4ac340e --- /dev/null +++ b/tests/inputs/googletypes_response/test_googletypes_response.py @@ -0,0 +1,64 @@ +from typing import ( + Any, + Callable, + Optional, +) + +import pytest + +import aristaproto.lib.google.protobuf as protobuf +from tests.mocks import MockChannel +from tests.output_aristaproto.googletypes_response import ( + Input, + TestStub, +) + + +test_cases = [ + (TestStub.get_double, protobuf.DoubleValue, 2.5), + (TestStub.get_float, protobuf.FloatValue, 2.5), + (TestStub.get_int64, protobuf.Int64Value, -64), + (TestStub.get_u_int64, protobuf.UInt64Value, 64), + (TestStub.get_int32, protobuf.Int32Value, -32), + (TestStub.get_u_int32, protobuf.UInt32Value, 32), + (TestStub.get_bool, protobuf.BoolValue, True), + (TestStub.get_string, protobuf.StringValue, "string"), + (TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]), +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) +async def test_channel_receives_wrapped_type( + service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value +): + wrapped_value = wrapper_class() + wrapped_value.value = value + channel = MockChannel(responses=[wrapped_value]) + service = TestStub(channel) + method_param = Input() + + await service_method(service, method_param) + + assert channel.requests[0]["response_type"] != Optional[type(value)] + assert channel.requests[0]["response_type"] == type(wrapped_value) + + +@pytest.mark.asyncio +@pytest.mark.xfail +@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) +async def test_service_unwraps_response( + service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value +): + """ + grpclib does not unwrap wrapper values returned by services + """ + wrapped_value = wrapper_class() + wrapped_value.value = value + service = TestStub(MockChannel(responses=[wrapped_value])) + method_param = Input() + + response_value = await service_method(service, method_param) + + assert response_value == value + assert type(response_value) == type(value) diff --git a/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto b/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto new file mode 100644 index 0000000..47284e3 --- /dev/null +++ b/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package googletypes_response_embedded; + +import "google/protobuf/wrappers.proto"; + +// Tests that wrapped values are supported as part of output message +service Test { + rpc getOutput (Input) returns (Output); +} + +message Input { + +} + +message Output { + google.protobuf.DoubleValue double_value = 1; + google.protobuf.FloatValue float_value = 2; + google.protobuf.Int64Value int64_value = 3; + google.protobuf.UInt64Value uint64_value = 4; + google.protobuf.Int32Value int32_value = 5; + google.protobuf.UInt32Value uint32_value = 6; + google.protobuf.BoolValue bool_value = 7; + google.protobuf.StringValue string_value = 8; + google.protobuf.BytesValue bytes_value = 9; +} diff --git a/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py b/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py new file mode 100644 index 0000000..3d31728 --- /dev/null +++ b/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py @@ -0,0 +1,40 @@ +import pytest + +from tests.mocks import MockChannel +from tests.output_aristaproto.googletypes_response_embedded import ( + Input, + Output, + TestStub, +) + + +@pytest.mark.asyncio +async def test_service_passes_through_unwrapped_values_embedded_in_response(): + """ + We do not not need to implement value unwrapping for embedded well-known types, + as this is already handled by grpclib. This test merely shows that this is the case. + """ + output = Output( + double_value=10.0, + float_value=12.0, + int64_value=-13, + uint64_value=14, + int32_value=-15, + uint32_value=16, + bool_value=True, + string_value="string", + bytes_value=bytes(0xFF)[0:4], + ) + + service = TestStub(MockChannel(responses=[output])) + response = await service.get_output(Input()) + + assert response.double_value == 10.0 + assert response.float_value == 12.0 + assert response.int64_value == -13 + assert response.uint64_value == 14 + assert response.int32_value == -15 + assert response.uint32_value == 16 + assert response.bool_value + assert response.string_value == "string" + assert response.bytes_value == bytes(0xFF)[0:4] diff --git a/tests/inputs/googletypes_service_returns_empty/googletypes_service_returns_empty.proto b/tests/inputs/googletypes_service_returns_empty/googletypes_service_returns_empty.proto new file mode 100644 index 0000000..2153ad5 --- /dev/null +++ b/tests/inputs/googletypes_service_returns_empty/googletypes_service_returns_empty.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package googletypes_service_returns_empty; + +import "google/protobuf/empty.proto"; + +service Test { + rpc Send (RequestMessage) returns (google.protobuf.Empty) { + } +} + +message RequestMessage { +}
\ No newline at end of file diff --git a/tests/inputs/googletypes_service_returns_googletype/googletypes_service_returns_googletype.proto b/tests/inputs/googletypes_service_returns_googletype/googletypes_service_returns_googletype.proto new file mode 100644 index 0000000..457707b --- /dev/null +++ b/tests/inputs/googletypes_service_returns_googletype/googletypes_service_returns_googletype.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package googletypes_service_returns_googletype; + +import "google/protobuf/empty.proto"; +import "google/protobuf/struct.proto"; + +// Tests that imports are generated correctly when returning Google well-known types + +service Test { + rpc GetEmpty (RequestMessage) returns (google.protobuf.Empty); + rpc GetStruct (RequestMessage) returns (google.protobuf.Struct); + rpc GetListValue (RequestMessage) returns (google.protobuf.ListValue); + rpc GetValue (RequestMessage) returns (google.protobuf.Value); +} + +message RequestMessage { +}
\ No newline at end of file diff --git a/tests/inputs/googletypes_struct/googletypes_struct.json b/tests/inputs/googletypes_struct/googletypes_struct.json new file mode 100644 index 0000000..ecc175e --- /dev/null +++ b/tests/inputs/googletypes_struct/googletypes_struct.json @@ -0,0 +1,5 @@ +{ + "struct": { + "key": true + } +} diff --git a/tests/inputs/googletypes_struct/googletypes_struct.proto b/tests/inputs/googletypes_struct/googletypes_struct.proto new file mode 100644 index 0000000..2b8b5c5 --- /dev/null +++ b/tests/inputs/googletypes_struct/googletypes_struct.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package googletypes_struct; + +import "google/protobuf/struct.proto"; + +message Test { + google.protobuf.Struct struct = 1; +} diff --git a/tests/inputs/googletypes_value/googletypes_value.json b/tests/inputs/googletypes_value/googletypes_value.json new file mode 100644 index 0000000..db52d5c --- /dev/null +++ b/tests/inputs/googletypes_value/googletypes_value.json @@ -0,0 +1,11 @@ +{ + "value1": "hello world", + "value2": true, + "value3": 1, + "value4": null, + "value5": [ + 1, + 2, + 3 + ] +} diff --git a/tests/inputs/googletypes_value/googletypes_value.proto b/tests/inputs/googletypes_value/googletypes_value.proto new file mode 100644 index 0000000..d5089d5 --- /dev/null +++ b/tests/inputs/googletypes_value/googletypes_value.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package googletypes_value; + +import "google/protobuf/struct.proto"; + +// Tests that fields of type google.protobuf.Value can contain arbitrary JSON-values. + +message Test { + google.protobuf.Value value1 = 1; + google.protobuf.Value value2 = 2; + google.protobuf.Value value3 = 3; + google.protobuf.Value value4 = 4; + google.protobuf.Value value5 = 5; +} diff --git a/tests/inputs/import_capitalized_package/capitalized.proto b/tests/inputs/import_capitalized_package/capitalized.proto new file mode 100644 index 0000000..e80c95c --- /dev/null +++ b/tests/inputs/import_capitalized_package/capitalized.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + + +package import_capitalized_package.Capitalized; + +message Message { + +} diff --git a/tests/inputs/import_capitalized_package/test.proto b/tests/inputs/import_capitalized_package/test.proto new file mode 100644 index 0000000..38c9b2d --- /dev/null +++ b/tests/inputs/import_capitalized_package/test.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package import_capitalized_package; + +import "capitalized.proto"; + +// Tests that we can import from a package with a capital name, that looks like a nested type, but isn't. + +message Test { + Capitalized.Message message = 1; +} diff --git a/tests/inputs/import_child_package_from_package/child.proto b/tests/inputs/import_child_package_from_package/child.proto new file mode 100644 index 0000000..d99c7c3 --- /dev/null +++ b/tests/inputs/import_child_package_from_package/child.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_child_package_from_package.package.childpackage; + +message ChildMessage { + +} diff --git a/tests/inputs/import_child_package_from_package/import_child_package_from_package.proto b/tests/inputs/import_child_package_from_package/import_child_package_from_package.proto new file mode 100644 index 0000000..66e0aa8 --- /dev/null +++ b/tests/inputs/import_child_package_from_package/import_child_package_from_package.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package import_child_package_from_package; + +import "package_message.proto"; + +// Tests generated imports when a message in a package refers to a message in a nested child package. + +message Test { + package.PackageMessage message = 1; +} diff --git a/tests/inputs/import_child_package_from_package/package_message.proto b/tests/inputs/import_child_package_from_package/package_message.proto new file mode 100644 index 0000000..79d66f3 --- /dev/null +++ b/tests/inputs/import_child_package_from_package/package_message.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +import "child.proto"; + +package import_child_package_from_package.package; + +message PackageMessage { + package.childpackage.ChildMessage c = 1; +} diff --git a/tests/inputs/import_child_package_from_root/child.proto b/tests/inputs/import_child_package_from_root/child.proto new file mode 100644 index 0000000..2a46d5f --- /dev/null +++ b/tests/inputs/import_child_package_from_root/child.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_child_package_from_root.childpackage; + +message Message { + +} diff --git a/tests/inputs/import_child_package_from_root/import_child_package_from_root.proto b/tests/inputs/import_child_package_from_root/import_child_package_from_root.proto new file mode 100644 index 0000000..6299831 --- /dev/null +++ b/tests/inputs/import_child_package_from_root/import_child_package_from_root.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package import_child_package_from_root; + +import "child.proto"; + +// Tests generated imports when a message in root refers to a message in a child package. + +message Test { + childpackage.Message child = 1; +} diff --git a/tests/inputs/import_circular_dependency/import_circular_dependency.proto b/tests/inputs/import_circular_dependency/import_circular_dependency.proto new file mode 100644 index 0000000..8b159e2 --- /dev/null +++ b/tests/inputs/import_circular_dependency/import_circular_dependency.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package import_circular_dependency; + +import "root.proto"; +import "other.proto"; + +// This test-case verifies support for circular dependencies in the generated python files. +// +// This is important because we generate 1 python file/module per package, rather than 1 file per proto file. +// +// Scenario: +// +// The proto messages depend on each other in a non-circular way: +// +// Test -------> RootPackageMessage <--------------. +// `------------------------------------> OtherPackageMessage +// +// Test and RootPackageMessage are in different files, but belong to the same package (root): +// +// (Test -------> RootPackageMessage) <------------. +// `------------------------------------> OtherPackageMessage +// +// After grouping the packages into single files or modules, a circular dependency is created: +// +// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage) +message Test { + RootPackageMessage message = 1; + other.OtherPackageMessage other = 2; +} diff --git a/tests/inputs/import_circular_dependency/other.proto b/tests/inputs/import_circular_dependency/other.proto new file mode 100644 index 0000000..833b869 --- /dev/null +++ b/tests/inputs/import_circular_dependency/other.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + +import "root.proto"; +package import_circular_dependency.other; + +message OtherPackageMessage { + RootPackageMessage rootPackageMessage = 1; +} diff --git a/tests/inputs/import_circular_dependency/root.proto b/tests/inputs/import_circular_dependency/root.proto new file mode 100644 index 0000000..7383947 --- /dev/null +++ b/tests/inputs/import_circular_dependency/root.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_circular_dependency; + +message RootPackageMessage { + +} diff --git a/tests/inputs/import_cousin_package/cousin.proto b/tests/inputs/import_cousin_package/cousin.proto new file mode 100644 index 0000000..2870dfe --- /dev/null +++ b/tests/inputs/import_cousin_package/cousin.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +package import_cousin_package.cousin.cousin_subpackage; + +message CousinMessage { +} diff --git a/tests/inputs/import_cousin_package/test.proto b/tests/inputs/import_cousin_package/test.proto new file mode 100644 index 0000000..89ec3d8 --- /dev/null +++ b/tests/inputs/import_cousin_package/test.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package import_cousin_package.test.subpackage; + +import "cousin.proto"; + +// Verify that we can import message unrelated to us + +message Test { + cousin.cousin_subpackage.CousinMessage message = 1; +} diff --git a/tests/inputs/import_cousin_package_same_name/cousin.proto b/tests/inputs/import_cousin_package_same_name/cousin.proto new file mode 100644 index 0000000..84b6a40 --- /dev/null +++ b/tests/inputs/import_cousin_package_same_name/cousin.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +package import_cousin_package_same_name.cousin.subpackage; + +message CousinMessage { +} diff --git a/tests/inputs/import_cousin_package_same_name/test.proto b/tests/inputs/import_cousin_package_same_name/test.proto new file mode 100644 index 0000000..7b420d3 --- /dev/null +++ b/tests/inputs/import_cousin_package_same_name/test.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package import_cousin_package_same_name.test.subpackage; + +import "cousin.proto"; + +// Verify that we can import a message unrelated to us, in a subpackage with the same name as us. + +message Test { + cousin.subpackage.CousinMessage message = 1; +} diff --git a/tests/inputs/import_packages_same_name/import_packages_same_name.proto b/tests/inputs/import_packages_same_name/import_packages_same_name.proto new file mode 100644 index 0000000..dff7efe --- /dev/null +++ b/tests/inputs/import_packages_same_name/import_packages_same_name.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package import_packages_same_name; + +import "users_v1.proto"; +import "posts_v1.proto"; + +// Tests generated message can correctly reference two packages with the same leaf-name + +message Test { + users.v1.User user = 1; + posts.v1.Post post = 2; +} diff --git a/tests/inputs/import_packages_same_name/posts_v1.proto b/tests/inputs/import_packages_same_name/posts_v1.proto new file mode 100644 index 0000000..d3b9b1c --- /dev/null +++ b/tests/inputs/import_packages_same_name/posts_v1.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_packages_same_name.posts.v1; + +message Post { + +} diff --git a/tests/inputs/import_packages_same_name/users_v1.proto b/tests/inputs/import_packages_same_name/users_v1.proto new file mode 100644 index 0000000..d3a17e9 --- /dev/null +++ b/tests/inputs/import_packages_same_name/users_v1.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_packages_same_name.users.v1; + +message User { + +} diff --git a/tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto b/tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto new file mode 100644 index 0000000..edc4736 --- /dev/null +++ b/tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +import "parent_package_message.proto"; + +package import_parent_package_from_child.parent.child; + +// Tests generated imports when a message refers to a message defined in its parent package + +message Test { + ParentPackageMessage message_implicit = 1; + parent.ParentPackageMessage message_explicit = 2; +} diff --git a/tests/inputs/import_parent_package_from_child/parent_package_message.proto b/tests/inputs/import_parent_package_from_child/parent_package_message.proto new file mode 100644 index 0000000..fb3fd31 --- /dev/null +++ b/tests/inputs/import_parent_package_from_child/parent_package_message.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +package import_parent_package_from_child.parent; + +message ParentPackageMessage { +} diff --git a/tests/inputs/import_root_package_from_child/child.proto b/tests/inputs/import_root_package_from_child/child.proto new file mode 100644 index 0000000..bd51967 --- /dev/null +++ b/tests/inputs/import_root_package_from_child/child.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package import_root_package_from_child.child; + +import "root.proto"; + +// Verify that we can import root message from child package + +message Test { + RootMessage message = 1; +} diff --git a/tests/inputs/import_root_package_from_child/root.proto b/tests/inputs/import_root_package_from_child/root.proto new file mode 100644 index 0000000..6ae955a --- /dev/null +++ b/tests/inputs/import_root_package_from_child/root.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_root_package_from_child; + + +message RootMessage { +} diff --git a/tests/inputs/import_root_sibling/import_root_sibling.proto b/tests/inputs/import_root_sibling/import_root_sibling.proto new file mode 100644 index 0000000..759e606 --- /dev/null +++ b/tests/inputs/import_root_sibling/import_root_sibling.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package import_root_sibling; + +import "sibling.proto"; + +// Tests generated imports when a message in the root package refers to another message in the root package + +message Test { + SiblingMessage sibling = 1; +} diff --git a/tests/inputs/import_root_sibling/sibling.proto b/tests/inputs/import_root_sibling/sibling.proto new file mode 100644 index 0000000..6b6ba2e --- /dev/null +++ b/tests/inputs/import_root_sibling/sibling.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_root_sibling; + +message SiblingMessage { + +} diff --git a/tests/inputs/import_service_input_message/child_package_request_message.proto b/tests/inputs/import_service_input_message/child_package_request_message.proto new file mode 100644 index 0000000..54fc112 --- /dev/null +++ b/tests/inputs/import_service_input_message/child_package_request_message.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_service_input_message.child; + +message ChildRequestMessage { + int32 child_argument = 1; +}
\ No newline at end of file diff --git a/tests/inputs/import_service_input_message/import_service_input_message.proto b/tests/inputs/import_service_input_message/import_service_input_message.proto new file mode 100644 index 0000000..cbf48fa --- /dev/null +++ b/tests/inputs/import_service_input_message/import_service_input_message.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package import_service_input_message; + +import "request_message.proto"; +import "child_package_request_message.proto"; + +// Tests generated service correctly imports the RequestMessage + +service Test { + rpc DoThing (RequestMessage) returns (RequestResponse); + rpc DoThing2 (child.ChildRequestMessage) returns (RequestResponse); + rpc DoThing3 (Nested.RequestMessage) returns (RequestResponse); +} + + +message RequestResponse { + int32 value = 1; +} + +message Nested { + message RequestMessage { + int32 nestedArgument = 1; + } +}
\ No newline at end of file diff --git a/tests/inputs/import_service_input_message/request_message.proto b/tests/inputs/import_service_input_message/request_message.proto new file mode 100644 index 0000000..36a6e78 --- /dev/null +++ b/tests/inputs/import_service_input_message/request_message.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package import_service_input_message; + +message RequestMessage { + int32 argument = 1; +}
\ No newline at end of file diff --git a/tests/inputs/import_service_input_message/test_import_service_input_message.py b/tests/inputs/import_service_input_message/test_import_service_input_message.py new file mode 100644 index 0000000..66c654b --- /dev/null +++ b/tests/inputs/import_service_input_message/test_import_service_input_message.py @@ -0,0 +1,36 @@ +import pytest + +from tests.mocks import MockChannel +from tests.output_aristaproto.import_service_input_message import ( + NestedRequestMessage, + RequestMessage, + RequestResponse, + TestStub, +) +from tests.output_aristaproto.import_service_input_message.child import ( + ChildRequestMessage, +) + + +@pytest.mark.asyncio +async def test_service_correctly_imports_reference_message(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing(RequestMessage(1)) + assert mock_response == response + + +@pytest.mark.asyncio +async def test_service_correctly_imports_reference_message_from_child_package(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing2(ChildRequestMessage(1)) + assert mock_response == response + + +@pytest.mark.asyncio +async def test_service_correctly_imports_nested_reference(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing3(NestedRequestMessage(1)) + assert mock_response == response diff --git a/tests/inputs/int32/int32.json b/tests/inputs/int32/int32.json new file mode 100644 index 0000000..34d4111 --- /dev/null +++ b/tests/inputs/int32/int32.json @@ -0,0 +1,4 @@ +{ + "positive": 150, + "negative": -150 +} diff --git a/tests/inputs/int32/int32.proto b/tests/inputs/int32/int32.proto new file mode 100644 index 0000000..4721c23 --- /dev/null +++ b/tests/inputs/int32/int32.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package int32; + +// Some documentation about the Test message. +message Test { + // Some documentation about the count. + int32 positive = 1; + int32 negative = 2; +} diff --git a/tests/inputs/map/map.json b/tests/inputs/map/map.json new file mode 100644 index 0000000..6a1e853 --- /dev/null +++ b/tests/inputs/map/map.json @@ -0,0 +1,7 @@ +{ + "counts": { + "item1": 1, + "item2": 2, + "item3": 3 + } +} diff --git a/tests/inputs/map/map.proto b/tests/inputs/map/map.proto new file mode 100644 index 0000000..ecef3cc --- /dev/null +++ b/tests/inputs/map/map.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package map; + +message Test { + map<string, int32> counts = 1; +} diff --git a/tests/inputs/mapmessage/mapmessage.json b/tests/inputs/mapmessage/mapmessage.json new file mode 100644 index 0000000..a944ddd --- /dev/null +++ b/tests/inputs/mapmessage/mapmessage.json @@ -0,0 +1,10 @@ +{ + "items": { + "foo": { + "count": 1 + }, + "bar": { + "count": 2 + } + } +} diff --git a/tests/inputs/mapmessage/mapmessage.proto b/tests/inputs/mapmessage/mapmessage.proto new file mode 100644 index 0000000..2c704a4 --- /dev/null +++ b/tests/inputs/mapmessage/mapmessage.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package mapmessage; + +message Test { + map<string, Nested> items = 1; +} + +message Nested { + int32 count = 1; +}
\ No newline at end of file diff --git a/tests/inputs/namespace_builtin_types/namespace_builtin_types.json b/tests/inputs/namespace_builtin_types/namespace_builtin_types.json new file mode 100644 index 0000000..8200032 --- /dev/null +++ b/tests/inputs/namespace_builtin_types/namespace_builtin_types.json @@ -0,0 +1,16 @@ +{ + "int": "value-for-int", + "float": "value-for-float", + "complex": "value-for-complex", + "list": "value-for-list", + "tuple": "value-for-tuple", + "range": "value-for-range", + "str": "value-for-str", + "bytearray": "value-for-bytearray", + "bytes": "value-for-bytes", + "memoryview": "value-for-memoryview", + "set": "value-for-set", + "frozenset": "value-for-frozenset", + "map": "value-for-map", + "bool": "value-for-bool" +}
\ No newline at end of file diff --git a/tests/inputs/namespace_builtin_types/namespace_builtin_types.proto b/tests/inputs/namespace_builtin_types/namespace_builtin_types.proto new file mode 100644 index 0000000..71cb029 --- /dev/null +++ b/tests/inputs/namespace_builtin_types/namespace_builtin_types.proto @@ -0,0 +1,40 @@ +syntax = "proto3"; + +package namespace_builtin_types; + +// Tests that messages may contain fields with names that are python types + +message Test { + // https://docs.python.org/2/library/stdtypes.html#numeric-types-int-float-long-complex + string int = 1; + string float = 2; + string complex = 3; + + // https://docs.python.org/3/library/stdtypes.html#sequence-types-list-tuple-range + string list = 4; + string tuple = 5; + string range = 6; + + // https://docs.python.org/3/library/stdtypes.html#str + string str = 7; + + // https://docs.python.org/3/library/stdtypes.html#bytearray-objects + string bytearray = 8; + + // https://docs.python.org/3/library/stdtypes.html#bytes-and-bytearray-operations + string bytes = 9; + + // https://docs.python.org/3/library/stdtypes.html#memory-views + string memoryview = 10; + + // https://docs.python.org/3/library/stdtypes.html#set-types-set-frozenset + string set = 11; + string frozenset = 12; + + // https://docs.python.org/3/library/stdtypes.html#dict + string map = 13; + string dict = 14; + + // https://docs.python.org/3/library/stdtypes.html#boolean-values + string bool = 15; +}
\ No newline at end of file diff --git a/tests/inputs/namespace_keywords/namespace_keywords.json b/tests/inputs/namespace_keywords/namespace_keywords.json new file mode 100644 index 0000000..4f11b60 --- /dev/null +++ b/tests/inputs/namespace_keywords/namespace_keywords.json @@ -0,0 +1,37 @@ +{ + "False": 1, + "None": 2, + "True": 3, + "and": 4, + "as": 5, + "assert": 6, + "async": 7, + "await": 8, + "break": 9, + "class": 10, + "continue": 11, + "def": 12, + "del": 13, + "elif": 14, + "else": 15, + "except": 16, + "finally": 17, + "for": 18, + "from": 19, + "global": 20, + "if": 21, + "import": 22, + "in": 23, + "is": 24, + "lambda": 25, + "nonlocal": 26, + "not": 27, + "or": 28, + "pass": 29, + "raise": 30, + "return": 31, + "try": 32, + "while": 33, + "with": 34, + "yield": 35 +} diff --git a/tests/inputs/namespace_keywords/namespace_keywords.proto b/tests/inputs/namespace_keywords/namespace_keywords.proto new file mode 100644 index 0000000..ac3e5c5 --- /dev/null +++ b/tests/inputs/namespace_keywords/namespace_keywords.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package namespace_keywords; + +// Tests that messages may contain fields that are Python keywords +// +// Generated with Python 3.7.6 +// print('\n'.join(f'string {k} = {i+1};' for i,k in enumerate(keyword.kwlist))) + +message Test { + string False = 1; + string None = 2; + string True = 3; + string and = 4; + string as = 5; + string assert = 6; + string async = 7; + string await = 8; + string break = 9; + string class = 10; + string continue = 11; + string def = 12; + string del = 13; + string elif = 14; + string else = 15; + string except = 16; + string finally = 17; + string for = 18; + string from = 19; + string global = 20; + string if = 21; + string import = 22; + string in = 23; + string is = 24; + string lambda = 25; + string nonlocal = 26; + string not = 27; + string or = 28; + string pass = 29; + string raise = 30; + string return = 31; + string try = 32; + string while = 33; + string with = 34; + string yield = 35; +}
\ No newline at end of file diff --git a/tests/inputs/nested/nested.json b/tests/inputs/nested/nested.json new file mode 100644 index 0000000..f460cad --- /dev/null +++ b/tests/inputs/nested/nested.json @@ -0,0 +1,7 @@ +{ + "nested": { + "count": 150 + }, + "sibling": {}, + "msg": "THIS" +} diff --git a/tests/inputs/nested/nested.proto b/tests/inputs/nested/nested.proto new file mode 100644 index 0000000..619c721 --- /dev/null +++ b/tests/inputs/nested/nested.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package nested; + +// A test message with a nested message inside of it. +message Test { + // This is the nested type. + message Nested { + // Stores a simple counter. + int32 count = 1; + } + // This is the nested enum. + enum Msg { + NONE = 0; + THIS = 1; + } + + Nested nested = 1; + Sibling sibling = 2; + Sibling sibling2 = 3; + Msg msg = 4; +} + +message Sibling { + int32 foo = 1; +}
\ No newline at end of file diff --git a/tests/inputs/nested2/nested2.proto b/tests/inputs/nested2/nested2.proto new file mode 100644 index 0000000..cd6510c --- /dev/null +++ b/tests/inputs/nested2/nested2.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +package nested2; + +import "package.proto"; + +message Game { + message Player { + enum Race { + human = 0; + orc = 1; + } + } +} + +message Test { + Game game = 1; + Game.Player GamePlayer = 2; + Game.Player.Race GamePlayerRace = 3; + equipment.Weapon Weapon = 4; +}
\ No newline at end of file diff --git a/tests/inputs/nested2/package.proto b/tests/inputs/nested2/package.proto new file mode 100644 index 0000000..e12abb1 --- /dev/null +++ b/tests/inputs/nested2/package.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package nested2.equipment; + +message Weapon { + +}
\ No newline at end of file diff --git a/tests/inputs/nestedtwice/nestedtwice.json b/tests/inputs/nestedtwice/nestedtwice.json new file mode 100644 index 0000000..c953132 --- /dev/null +++ b/tests/inputs/nestedtwice/nestedtwice.json @@ -0,0 +1,11 @@ +{ + "top": { + "name": "double-nested", + "middle": { + "bottom": [{"foo": "hello"}], + "enumBottom": ["A"], + "topMiddleBottom": [{"a": "hello"}], + "bar": true + } + } +} diff --git a/tests/inputs/nestedtwice/nestedtwice.proto b/tests/inputs/nestedtwice/nestedtwice.proto new file mode 100644 index 0000000..84d142a --- /dev/null +++ b/tests/inputs/nestedtwice/nestedtwice.proto @@ -0,0 +1,40 @@ +syntax = "proto3"; + +package nestedtwice; + +/* Test doc. */ +message Test { + /* Top doc. */ + message Top { + /* Middle doc. */ + message Middle { + /* TopMiddleBottom doc.*/ + message TopMiddleBottom { + // TopMiddleBottom.a doc. + string a = 1; + } + /* EnumBottom doc. */ + enum EnumBottom{ + /* EnumBottom.A doc. */ + A = 0; + B = 1; + } + /* Bottom doc. */ + message Bottom { + /* Bottom.foo doc. */ + string foo = 1; + } + reserved 1; + /* Middle.bottom doc. */ + repeated Bottom bottom = 2; + repeated EnumBottom enumBottom=3; + repeated TopMiddleBottom topMiddleBottom=4; + bool bar = 5; + } + /* Top.name doc. */ + string name = 1; + Middle middle = 2; + } + /* Test.top doc. */ + Top top = 1; +} diff --git a/tests/inputs/nestedtwice/test_nestedtwice.py b/tests/inputs/nestedtwice/test_nestedtwice.py new file mode 100644 index 0000000..502e710 --- /dev/null +++ b/tests/inputs/nestedtwice/test_nestedtwice.py @@ -0,0 +1,25 @@ +import pytest + +from tests.output_aristaproto.nestedtwice import ( + Test, + TestTop, + TestTopMiddle, + TestTopMiddleBottom, + TestTopMiddleEnumBottom, + TestTopMiddleTopMiddleBottom, +) + + +@pytest.mark.parametrize( + ("cls", "expected_comment"), + [ + (Test, "Test doc."), + (TestTopMiddleEnumBottom, "EnumBottom doc."), + (TestTop, "Top doc."), + (TestTopMiddle, "Middle doc."), + (TestTopMiddleTopMiddleBottom, "TopMiddleBottom doc."), + (TestTopMiddleBottom, "Bottom doc."), + ], +) +def test_comment(cls, expected_comment): + assert cls.__doc__ == expected_comment diff --git a/tests/inputs/oneof/oneof-name.json b/tests/inputs/oneof/oneof-name.json new file mode 100644 index 0000000..605484b --- /dev/null +++ b/tests/inputs/oneof/oneof-name.json @@ -0,0 +1,3 @@ +{ + "pitier": "Mr. T" +} diff --git a/tests/inputs/oneof/oneof.json b/tests/inputs/oneof/oneof.json new file mode 100644 index 0000000..65cafc5 --- /dev/null +++ b/tests/inputs/oneof/oneof.json @@ -0,0 +1,3 @@ +{ + "pitied": 100 +} diff --git a/tests/inputs/oneof/oneof.proto b/tests/inputs/oneof/oneof.proto new file mode 100644 index 0000000..41f93b0 --- /dev/null +++ b/tests/inputs/oneof/oneof.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; + +package oneof; + +message MixedDrink { + int32 shots = 1; +} + +message Test { + oneof foo { + int32 pitied = 1; + string pitier = 2; + } + + int32 just_a_regular_field = 3; + + oneof bar { + int32 drinks = 11; + string bar_name = 12; + MixedDrink mixed_drink = 13; + } +} + diff --git a/tests/inputs/oneof/oneof_name.json b/tests/inputs/oneof/oneof_name.json new file mode 100644 index 0000000..605484b --- /dev/null +++ b/tests/inputs/oneof/oneof_name.json @@ -0,0 +1,3 @@ +{ + "pitier": "Mr. T" +} diff --git a/tests/inputs/oneof/test_oneof.py b/tests/inputs/oneof/test_oneof.py new file mode 100644 index 0000000..8a38496 --- /dev/null +++ b/tests/inputs/oneof/test_oneof.py @@ -0,0 +1,43 @@ +import pytest + +import aristaproto +from tests.output_aristaproto.oneof import ( + MixedDrink, + Test, +) +from tests.output_aristaproto_pydantic.oneof import Test as TestPyd +from tests.util import get_test_case_json_data + + +def test_which_count(): + message = Test() + message.from_json(get_test_case_json_data("oneof")[0].json) + assert aristaproto.which_one_of(message, "foo") == ("pitied", 100) + + +def test_which_name(): + message = Test() + message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json) + assert aristaproto.which_one_of(message, "foo") == ("pitier", "Mr. T") + + +def test_which_count_pyd(): + message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar") + assert aristaproto.which_one_of(message, "foo") == ("pitier", "Mr. T") + + +def test_oneof_constructor_assign(): + message = Test(mixed_drink=MixedDrink(shots=42)) + field, value = aristaproto.which_one_of(message, "bar") + assert field == "mixed_drink" + assert value.shots == 42 + + +# Issue #305: +@pytest.mark.xfail +def test_oneof_nested_assign(): + message = Test() + message.mixed_drink.shots = 42 + field, value = aristaproto.which_one_of(message, "bar") + assert field == "mixed_drink" + assert value.shots == 42 diff --git a/tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto b/tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto new file mode 100644 index 0000000..f7ac6fe --- /dev/null +++ b/tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package oneof_default_value_serialization; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; + +message Message{ + int64 value = 1; +} + +message NestedMessage{ + int64 id = 1; + oneof value_type{ + Message wrapped_message_value = 2; + } +} + +message Test{ + oneof value_type { + bool bool_value = 1; + int64 int64_value = 2; + google.protobuf.Timestamp timestamp_value = 3; + google.protobuf.Duration duration_value = 4; + Message wrapped_message_value = 5; + NestedMessage wrapped_nested_message_value = 6; + google.protobuf.BoolValue wrapped_bool_value = 7; + } +}
\ No newline at end of file diff --git a/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py new file mode 100644 index 0000000..0fad3d6 --- /dev/null +++ b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py @@ -0,0 +1,75 @@ +import datetime + +import pytest + +import aristaproto +from tests.output_aristaproto.oneof_default_value_serialization import ( + Message, + NestedMessage, + Test, +) + + +def assert_round_trip_serialization_works(message: Test) -> None: + assert aristaproto.which_one_of(message, "value_type") == aristaproto.which_one_of( + Test().from_json(message.to_json()), "value_type" + ) + + +def test_oneof_default_value_serialization_works_for_all_values(): + """ + Serialization from message with oneof set to default -> JSON -> message should keep + default value field intact. + """ + + test_cases = [ + Test(bool_value=False), + Test(int64_value=0), + Test( + timestamp_value=datetime.datetime( + year=1970, + month=1, + day=1, + hour=0, + minute=0, + tzinfo=datetime.timezone.utc, + ) + ), + Test(duration_value=datetime.timedelta(0)), + Test(wrapped_message_value=Message(value=0)), + # NOTE: Do NOT use aristaproto.BoolValue here, it will cause JSON serialization + # errors. + # TODO: Do we want to allow use of BoolValue directly within a wrapped field or + # should we simply hard fail here? + Test(wrapped_bool_value=False), + ] + for message in test_cases: + assert_round_trip_serialization_works(message) + + +def test_oneof_no_default_values_passed(): + message = Test() + assert ( + aristaproto.which_one_of(message, "value_type") + == aristaproto.which_one_of(Test().from_json(message.to_json()), "value_type") + == ("", None) + ) + + +def test_oneof_nested_oneof_messages_are_serialized_with_defaults(): + """ + Nested messages with oneofs should also be handled + """ + message = Test( + wrapped_nested_message_value=NestedMessage( + id=0, wrapped_message_value=Message(value=0) + ) + ) + assert ( + aristaproto.which_one_of(message, "value_type") + == aristaproto.which_one_of(Test().from_json(message.to_json()), "value_type") + == ( + "wrapped_nested_message_value", + NestedMessage(id=0, wrapped_message_value=Message(value=0)), + ) + ) diff --git a/tests/inputs/oneof_empty/oneof_empty.json b/tests/inputs/oneof_empty/oneof_empty.json new file mode 100644 index 0000000..9d21c89 --- /dev/null +++ b/tests/inputs/oneof_empty/oneof_empty.json @@ -0,0 +1,3 @@ +{ + "nothing": {} +} diff --git a/tests/inputs/oneof_empty/oneof_empty.proto b/tests/inputs/oneof_empty/oneof_empty.proto new file mode 100644 index 0000000..ca51d5a --- /dev/null +++ b/tests/inputs/oneof_empty/oneof_empty.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +package oneof_empty; + +message Nothing {} + +message MaybeNothing { + string sometimes = 42; +} + +message Test { + oneof empty { + Nothing nothing = 1; + MaybeNothing maybe1 = 2; + MaybeNothing maybe2 = 3; + } +} diff --git a/tests/inputs/oneof_empty/oneof_empty_maybe1.json b/tests/inputs/oneof_empty/oneof_empty_maybe1.json new file mode 100644 index 0000000..f7a2d27 --- /dev/null +++ b/tests/inputs/oneof_empty/oneof_empty_maybe1.json @@ -0,0 +1,3 @@ +{ + "maybe1": {} +} diff --git a/tests/inputs/oneof_empty/oneof_empty_maybe2.json b/tests/inputs/oneof_empty/oneof_empty_maybe2.json new file mode 100644 index 0000000..bc2b385 --- /dev/null +++ b/tests/inputs/oneof_empty/oneof_empty_maybe2.json @@ -0,0 +1,5 @@ +{ + "maybe2": { + "sometimes": "now" + } +} diff --git a/tests/inputs/oneof_empty/test_oneof_empty.py b/tests/inputs/oneof_empty/test_oneof_empty.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/inputs/oneof_empty/test_oneof_empty.py diff --git a/tests/inputs/oneof_enum/oneof_enum-enum-0.json b/tests/inputs/oneof_enum/oneof_enum-enum-0.json new file mode 100644 index 0000000..be30cf0 --- /dev/null +++ b/tests/inputs/oneof_enum/oneof_enum-enum-0.json @@ -0,0 +1,3 @@ +{ + "signal": "PASS" +} diff --git a/tests/inputs/oneof_enum/oneof_enum-enum-1.json b/tests/inputs/oneof_enum/oneof_enum-enum-1.json new file mode 100644 index 0000000..cb63873 --- /dev/null +++ b/tests/inputs/oneof_enum/oneof_enum-enum-1.json @@ -0,0 +1,3 @@ +{ + "signal": "RESIGN" +} diff --git a/tests/inputs/oneof_enum/oneof_enum.json b/tests/inputs/oneof_enum/oneof_enum.json new file mode 100644 index 0000000..3220b70 --- /dev/null +++ b/tests/inputs/oneof_enum/oneof_enum.json @@ -0,0 +1,6 @@ +{ + "move": { + "x": 2, + "y": 3 + } +} diff --git a/tests/inputs/oneof_enum/oneof_enum.proto b/tests/inputs/oneof_enum/oneof_enum.proto new file mode 100644 index 0000000..906abcb --- /dev/null +++ b/tests/inputs/oneof_enum/oneof_enum.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package oneof_enum; + +message Test { + oneof action { + Signal signal = 1; + Move move = 2; + } +} + +enum Signal { + PASS = 0; + RESIGN = 1; +} + +message Move { + int32 x = 1; + int32 y = 2; +}
\ No newline at end of file diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py new file mode 100644 index 0000000..98de22a --- /dev/null +++ b/tests/inputs/oneof_enum/test_oneof_enum.py @@ -0,0 +1,47 @@ +import pytest + +import aristaproto +from tests.output_aristaproto.oneof_enum import ( + Move, + Signal, + Test, +) +from tests.util import get_test_case_json_data + + +def test_which_one_of_returns_enum_with_default_value(): + """ + returns first field when it is enum and set with default value + """ + message = Test() + message.from_json( + get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json + ) + + assert not hasattr(message, "move") + assert object.__getattribute__(message, "move") == aristaproto.PLACEHOLDER + assert message.signal == Signal.PASS + assert aristaproto.which_one_of(message, "action") == ("signal", Signal.PASS) + + +def test_which_one_of_returns_enum_with_non_default_value(): + """ + returns first field when it is enum and set with non default value + """ + message = Test() + message.from_json( + get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json + ) + assert not hasattr(message, "move") + assert object.__getattribute__(message, "move") == aristaproto.PLACEHOLDER + assert message.signal == Signal.RESIGN + assert aristaproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) + + +def test_which_one_of_returns_second_field_when_set(): + message = Test() + message.from_json(get_test_case_json_data("oneof_enum")[0].json) + assert message.move == Move(x=2, y=3) + assert not hasattr(message, "signal") + assert object.__getattribute__(message, "signal") == aristaproto.PLACEHOLDER + assert aristaproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence.json b/tests/inputs/proto3_field_presence/proto3_field_presence.json new file mode 100644 index 0000000..988df8e --- /dev/null +++ b/tests/inputs/proto3_field_presence/proto3_field_presence.json @@ -0,0 +1,13 @@ +{ + "test1": 128, + "test2": true, + "test3": "A value", + "test4": "aGVsbG8=", + "test5": { + "test": "Hello" + }, + "test6": "B", + "test7": "8589934592", + "test8": 2.5, + "test9": "2022-01-24T12:12:42Z" +} diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence.proto b/tests/inputs/proto3_field_presence/proto3_field_presence.proto new file mode 100644 index 0000000..f28123d --- /dev/null +++ b/tests/inputs/proto3_field_presence/proto3_field_presence.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package proto3_field_presence; + +import "google/protobuf/timestamp.proto"; + +message InnerTest { + string test = 1; +} + +message Test { + optional uint32 test1 = 1; + optional bool test2 = 2; + optional string test3 = 3; + optional bytes test4 = 4; + optional InnerTest test5 = 5; + optional TestEnum test6 = 6; + optional uint64 test7 = 7; + optional float test8 = 8; + optional google.protobuf.Timestamp test9 = 9; +} + +enum TestEnum { + A = 0; + B = 1; +} diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence_default.json b/tests/inputs/proto3_field_presence/proto3_field_presence_default.json new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/tests/inputs/proto3_field_presence/proto3_field_presence_default.json @@ -0,0 +1 @@ +{} diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json b/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json new file mode 100644 index 0000000..b19ae98 --- /dev/null +++ b/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json @@ -0,0 +1,9 @@ +{ + "test1": 0, + "test2": false, + "test3": "", + "test4": "", + "test6": "A", + "test7": "0", + "test8": 0 +} diff --git a/tests/inputs/proto3_field_presence/test_proto3_field_presence.py b/tests/inputs/proto3_field_presence/test_proto3_field_presence.py new file mode 100644 index 0000000..80696b2 --- /dev/null +++ b/tests/inputs/proto3_field_presence/test_proto3_field_presence.py @@ -0,0 +1,48 @@ +import json + +from tests.output_aristaproto.proto3_field_presence import ( + InnerTest, + Test, + TestEnum, +) + + +def test_null_fields_json(): + """Ensure that using "null" in JSON is equivalent to not specifying a + field, for fields with explicit presence""" + + def test_json(ref_json: str, obj_json: str) -> None: + """`ref_json` and `obj_json` are JSON strings describing a `Test` object. + Test that deserializing both leads to the same object, and that + `ref_json` is the normalized format.""" + ref_obj = Test().from_json(ref_json) + obj = Test().from_json(obj_json) + + assert obj == ref_obj + assert json.loads(obj.to_json(0)) == json.loads(ref_json) + + test_json("{}", '{ "test1": null, "test2": null, "test3": null }') + test_json("{}", '{ "test4": null, "test5": null, "test6": null }') + test_json("{}", '{ "test7": null, "test8": null }') + test_json('{ "test5": {} }', '{ "test3": null, "test5": {} }') + + # Make sure that if include_default_values is set, None values are + # exported. + obj = Test() + assert obj.to_dict() == {} + assert obj.to_dict(include_default_values=True) == { + "test1": None, + "test2": None, + "test3": None, + "test4": None, + "test5": None, + "test6": None, + "test7": None, + "test8": None, + "test9": None, + } + + +def test_unset_access(): # see #523 + assert Test().test1 is None + assert Test(test1=None).test1 is None diff --git a/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json new file mode 100644 index 0000000..da08192 --- /dev/null +++ b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json @@ -0,0 +1,3 @@ +{ + "nested": {} +} diff --git a/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto new file mode 100644 index 0000000..caa76ec --- /dev/null +++ b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +package proto3_field_presence_oneof; + +message Test { + oneof kind { + Nested nested = 1; + WithOptional with_optional = 2; + } +} + +message InnerNested { + optional bool a = 1; +} + +message Nested { + InnerNested inner = 1; +} + +message WithOptional { + optional bool b = 2; +} diff --git a/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py b/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py new file mode 100644 index 0000000..f13c973 --- /dev/null +++ b/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py @@ -0,0 +1,29 @@ +from tests.output_aristaproto.proto3_field_presence_oneof import ( + InnerNested, + Nested, + Test, + WithOptional, +) + + +def test_serialization(): + """Ensure that serialization of fields unset but with explicit field + presence do not bloat the serialized payload with length-delimited fields + with length 0""" + + def test_empty_nested(message: Test) -> None: + # '0a' => tag 1, length delimited + # '00' => length: 0 + assert bytes(message) == bytearray.fromhex("0a 00") + + test_empty_nested(Test(nested=Nested())) + test_empty_nested(Test(nested=Nested(inner=None))) + test_empty_nested(Test(nested=Nested(inner=InnerNested(a=None)))) + + def test_empty_with_optional(message: Test) -> None: + # '12' => tag 2, length delimited + # '00' => length: 0 + assert bytes(message) == bytearray.fromhex("12 00") + + test_empty_with_optional(Test(with_optional=WithOptional())) + test_empty_with_optional(Test(with_optional=WithOptional(b=None))) diff --git a/tests/inputs/recursivemessage/recursivemessage.json b/tests/inputs/recursivemessage/recursivemessage.json new file mode 100644 index 0000000..e92c3fb --- /dev/null +++ b/tests/inputs/recursivemessage/recursivemessage.json @@ -0,0 +1,12 @@ +{ + "name": "Zues", + "child": { + "name": "Hercules" + }, + "intermediate": { + "child": { + "name": "Douglas Adams" + }, + "number": 42 + } +} diff --git a/tests/inputs/recursivemessage/recursivemessage.proto b/tests/inputs/recursivemessage/recursivemessage.proto new file mode 100644 index 0000000..1da2b57 --- /dev/null +++ b/tests/inputs/recursivemessage/recursivemessage.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package recursivemessage; + +message Test { + string name = 1; + Test child = 2; + Intermediate intermediate = 3; +} + + +message Intermediate { + int32 number = 1; + Test child = 2; +} diff --git a/tests/inputs/ref/ref.json b/tests/inputs/ref/ref.json new file mode 100644 index 0000000..2c6bdc1 --- /dev/null +++ b/tests/inputs/ref/ref.json @@ -0,0 +1,5 @@ +{ + "greeting": { + "greeting": "hello" + } +} diff --git a/tests/inputs/ref/ref.proto b/tests/inputs/ref/ref.proto new file mode 100644 index 0000000..6945590 --- /dev/null +++ b/tests/inputs/ref/ref.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package ref; + +import "repeatedmessage.proto"; + +message Test { + repeatedmessage.Sub greeting = 1; +} diff --git a/tests/inputs/ref/repeatedmessage.proto b/tests/inputs/ref/repeatedmessage.proto new file mode 100644 index 0000000..0ffacaf --- /dev/null +++ b/tests/inputs/ref/repeatedmessage.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package repeatedmessage; + +message Test { + repeated Sub greetings = 1; +} + +message Sub { + string greeting = 1; +}
\ No newline at end of file diff --git a/tests/inputs/regression_387/regression_387.proto b/tests/inputs/regression_387/regression_387.proto new file mode 100644 index 0000000..57bd954 --- /dev/null +++ b/tests/inputs/regression_387/regression_387.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package regression_387; + +message Test { + uint64 id = 1; +} + +message ParentElement { + string name = 1; + repeated Test elems = 2; +}
\ No newline at end of file diff --git a/tests/inputs/regression_387/test_regression_387.py b/tests/inputs/regression_387/test_regression_387.py new file mode 100644 index 0000000..92d96ba --- /dev/null +++ b/tests/inputs/regression_387/test_regression_387.py @@ -0,0 +1,12 @@ +from tests.output_aristaproto.regression_387 import ( + ParentElement, + Test, +) + + +def test_regression_387(): + el = ParentElement(name="test", elems=[Test(id=0), Test(id=42)]) + binary = bytes(el) + decoded = ParentElement().parse(binary) + assert decoded == el + assert decoded.elems == [Test(id=0), Test(id=42)] diff --git a/tests/inputs/regression_414/regression_414.proto b/tests/inputs/regression_414/regression_414.proto new file mode 100644 index 0000000..d20ddda --- /dev/null +++ b/tests/inputs/regression_414/regression_414.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package regression_414; + +message Test { + bytes body = 1; + bytes auth = 2; + repeated bytes signatures = 3; +}
\ No newline at end of file diff --git a/tests/inputs/regression_414/test_regression_414.py b/tests/inputs/regression_414/test_regression_414.py new file mode 100644 index 0000000..9441470 --- /dev/null +++ b/tests/inputs/regression_414/test_regression_414.py @@ -0,0 +1,15 @@ +from tests.output_aristaproto.regression_414 import Test + + +def test_full_cycle(): + body = bytes([0, 1]) + auth = bytes([2, 3]) + sig = [b""] + + obj = Test(body=body, auth=auth, signatures=sig) + + decoded = Test().parse(bytes(obj)) + assert decoded == obj + assert decoded.body == body + assert decoded.auth == auth + assert decoded.signatures == sig diff --git a/tests/inputs/repeated/repeated.json b/tests/inputs/repeated/repeated.json new file mode 100644 index 0000000..b8a7c4e --- /dev/null +++ b/tests/inputs/repeated/repeated.json @@ -0,0 +1,3 @@ +{ + "names": ["one", "two", "three"] +} diff --git a/tests/inputs/repeated/repeated.proto b/tests/inputs/repeated/repeated.proto new file mode 100644 index 0000000..4f3c788 --- /dev/null +++ b/tests/inputs/repeated/repeated.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package repeated; + +message Test { + repeated string names = 1; +} diff --git a/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.json b/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.json new file mode 100644 index 0000000..6ce7b34 --- /dev/null +++ b/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.json @@ -0,0 +1,4 @@ +{ + "times": ["1972-01-01T10:00:20.021Z", "1972-01-01T10:00:20.021Z"], + "durations": ["1.200s", "1.200s"] +} diff --git a/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.proto b/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.proto new file mode 100644 index 0000000..38f1eaa --- /dev/null +++ b/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package repeated_duration_timestamp; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; + + +message Test { + repeated google.protobuf.Timestamp times = 1; + repeated google.protobuf.Duration durations = 2; +} diff --git a/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py b/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py new file mode 100644 index 0000000..aafc951 --- /dev/null +++ b/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py @@ -0,0 +1,12 @@ +from datetime import ( + datetime, + timedelta, +) + +from tests.output_aristaproto.repeated_duration_timestamp import Test + + +def test_roundtrip(): + message = Test() + message.times = [datetime.now(), datetime.now()] + message.durations = [timedelta(), timedelta()] diff --git a/tests/inputs/repeatedmessage/repeatedmessage.json b/tests/inputs/repeatedmessage/repeatedmessage.json new file mode 100644 index 0000000..90ec596 --- /dev/null +++ b/tests/inputs/repeatedmessage/repeatedmessage.json @@ -0,0 +1,10 @@ +{ + "greetings": [ + { + "greeting": "hello" + }, + { + "greeting": "hi" + } + ] +} diff --git a/tests/inputs/repeatedmessage/repeatedmessage.proto b/tests/inputs/repeatedmessage/repeatedmessage.proto new file mode 100644 index 0000000..0ffacaf --- /dev/null +++ b/tests/inputs/repeatedmessage/repeatedmessage.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package repeatedmessage; + +message Test { + repeated Sub greetings = 1; +} + +message Sub { + string greeting = 1; +}
\ No newline at end of file diff --git a/tests/inputs/repeatedpacked/repeatedpacked.json b/tests/inputs/repeatedpacked/repeatedpacked.json new file mode 100644 index 0000000..106fd90 --- /dev/null +++ b/tests/inputs/repeatedpacked/repeatedpacked.json @@ -0,0 +1,5 @@ +{ + "counts": [1, 2, -1, -2], + "signed": ["1", "2", "-1", "-2"], + "fixed": [1.0, 2.7, 3.4] +} diff --git a/tests/inputs/repeatedpacked/repeatedpacked.proto b/tests/inputs/repeatedpacked/repeatedpacked.proto new file mode 100644 index 0000000..a037d1b --- /dev/null +++ b/tests/inputs/repeatedpacked/repeatedpacked.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package repeatedpacked; + +message Test { + repeated int32 counts = 1; + repeated sint64 signed = 2; + repeated double fixed = 3; +} diff --git a/tests/inputs/service/service.proto b/tests/inputs/service/service.proto new file mode 100644 index 0000000..53d84fb --- /dev/null +++ b/tests/inputs/service/service.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; + +package service; + +enum ThingType { + UNKNOWN = 0; + LIVING = 1; + DEAD = 2; +} + +message DoThingRequest { + string name = 1; + repeated string comments = 2; + ThingType type = 3; +} + +message DoThingResponse { + repeated string names = 1; +} + +message GetThingRequest { + string name = 1; +} + +message GetThingResponse { + string name = 1; + int32 version = 2; +} + +service Test { + rpc DoThing (DoThingRequest) returns (DoThingResponse); + rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse); + rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse); + rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse); +} diff --git a/tests/inputs/service_separate_packages/messages.proto b/tests/inputs/service_separate_packages/messages.proto new file mode 100644 index 0000000..270b188 --- /dev/null +++ b/tests/inputs/service_separate_packages/messages.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; + +package service_separate_packages.things.messages; + +message DoThingRequest { + string name = 1; + + // use `repeated` so we can check if `List` is correctly imported + repeated string comments = 2; + + // use google types `timestamp` and `duration` so we can check + // if everything from `datetime` is correctly imported + google.protobuf.Timestamp when = 3; + google.protobuf.Duration duration = 4; +} + +message DoThingResponse { + repeated string names = 1; +} + +message GetThingRequest { + string name = 1; +} + +message GetThingResponse { + string name = 1; + int32 version = 2; +} diff --git a/tests/inputs/service_separate_packages/service.proto b/tests/inputs/service_separate_packages/service.proto new file mode 100644 index 0000000..950eab4 --- /dev/null +++ b/tests/inputs/service_separate_packages/service.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +import "messages.proto"; + +package service_separate_packages.things.service; + +service Test { + rpc DoThing (things.messages.DoThingRequest) returns (things.messages.DoThingResponse); + rpc DoManyThings (stream things.messages.DoThingRequest) returns (things.messages.DoThingResponse); + rpc GetThingVersions (things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse); + rpc GetDifferentThings (stream things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse); +} diff --git a/tests/inputs/service_uppercase/service.proto b/tests/inputs/service_uppercase/service.proto new file mode 100644 index 0000000..786eec2 --- /dev/null +++ b/tests/inputs/service_uppercase/service.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package service_uppercase; + +message DoTHINGRequest { + string name = 1; + repeated string comments = 2; +} + +message DoTHINGResponse { + repeated string names = 1; +} + +service Test { + rpc DoThing (DoTHINGRequest) returns (DoTHINGResponse); +} diff --git a/tests/inputs/service_uppercase/test_service.py b/tests/inputs/service_uppercase/test_service.py new file mode 100644 index 0000000..d10fccf --- /dev/null +++ b/tests/inputs/service_uppercase/test_service.py @@ -0,0 +1,8 @@ +import inspect + +from tests.output_aristaproto.service_uppercase import TestStub + + +def test_parameters(): + sig = inspect.signature(TestStub.do_thing) + assert len(sig.parameters) == 5, "Expected 5 parameters" diff --git a/tests/inputs/signed/signed.json b/tests/inputs/signed/signed.json new file mode 100644 index 0000000..b171e15 --- /dev/null +++ b/tests/inputs/signed/signed.json @@ -0,0 +1,6 @@ +{ + "signed32": 150, + "negative32": -150, + "string64": "150", + "negative64": "-150" +} diff --git a/tests/inputs/signed/signed.proto b/tests/inputs/signed/signed.proto new file mode 100644 index 0000000..b40aad4 --- /dev/null +++ b/tests/inputs/signed/signed.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package signed; + +message Test { + // todo: rename fields after fixing bug where 'signed_32_positive' will map to 'signed_32Positive' as output json + sint32 signed32 = 1; // signed_32_positive + sint32 negative32 = 2; // signed_32_negative + sint64 string64 = 3; // signed_64_positive + sint64 negative64 = 4; // signed_64_negative +} diff --git a/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py b/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py new file mode 100644 index 0000000..59be3d1 --- /dev/null +++ b/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py @@ -0,0 +1,82 @@ +from datetime import ( + datetime, + timedelta, + timezone, +) + +import pytest + +from tests.output_aristaproto.timestamp_dict_encode import Test + + +# Current World Timezone range (UTC-12 to UTC+14) +MIN_UTC_OFFSET_MIN = -12 * 60 +MAX_UTC_OFFSET_MIN = 14 * 60 + +# Generate all timezones in range in 15 min increments +timezones = [ + timezone(timedelta(minutes=x)) + for x in range(MIN_UTC_OFFSET_MIN, MAX_UTC_OFFSET_MIN + 1, 15) +] + + +@pytest.mark.parametrize("tz", timezones) +def test_timezone_aware_datetime_dict_encode(tz: timezone): + original_time = datetime.now(tz=tz) + original_message = Test() + original_message.ts = original_time + encoded = original_message.to_dict() + decoded_message = Test() + decoded_message.from_dict(encoded) + + # check that the timestamps are equal after decoding from dict + assert original_message.ts.tzinfo is not None + assert decoded_message.ts.tzinfo is not None + assert original_message.ts == decoded_message.ts + + +def test_naive_datetime_dict_encode(): + # make suer naive datetime objects are still treated as utc + original_time = datetime.now() + assert original_time.tzinfo is None + original_message = Test() + original_message.ts = original_time + original_time_utc = original_time.replace(tzinfo=timezone.utc) + encoded = original_message.to_dict() + decoded_message = Test() + decoded_message.from_dict(encoded) + + # check that the timestamps are equal after decoding from dict + assert decoded_message.ts.tzinfo is not None + assert original_time_utc == decoded_message.ts + + +@pytest.mark.parametrize("tz", timezones) +def test_timezone_aware_json_serialize(tz: timezone): + original_time = datetime.now(tz=tz) + original_message = Test() + original_message.ts = original_time + json_serialized = original_message.to_json() + decoded_message = Test() + decoded_message.from_json(json_serialized) + + # check that the timestamps are equal after decoding from dict + assert original_message.ts.tzinfo is not None + assert decoded_message.ts.tzinfo is not None + assert original_message.ts == decoded_message.ts + + +def test_naive_datetime_json_serialize(): + # make suer naive datetime objects are still treated as utc + original_time = datetime.now() + assert original_time.tzinfo is None + original_message = Test() + original_message.ts = original_time + original_time_utc = original_time.replace(tzinfo=timezone.utc) + json_serialized = original_message.to_json() + decoded_message = Test() + decoded_message.from_json(json_serialized) + + # check that the timestamps are equal after decoding from dict + assert decoded_message.ts.tzinfo is not None + assert original_time_utc == decoded_message.ts diff --git a/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.json b/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.json new file mode 100644 index 0000000..3f45558 --- /dev/null +++ b/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.json @@ -0,0 +1,3 @@ +{ + "ts" : "2023-03-15T22:35:51.253277Z" +}
\ No newline at end of file diff --git a/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.proto b/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.proto new file mode 100644 index 0000000..9c4081a --- /dev/null +++ b/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package timestamp_dict_encode; + +import "google/protobuf/timestamp.proto"; + +message Test { + google.protobuf.Timestamp ts = 1; +}
\ No newline at end of file diff --git a/tests/mocks.py b/tests/mocks.py new file mode 100644 index 0000000..dc6e117 --- /dev/null +++ b/tests/mocks.py @@ -0,0 +1,40 @@ +from typing import List + +from grpclib.client import Channel + + +class MockChannel(Channel): + # noinspection PyMissingConstructor + def __init__(self, responses=None) -> None: + self.responses = responses or [] + self.requests = [] + self._loop = None + + def request(self, route, cardinality, request, response_type, **kwargs): + self.requests.append( + { + "route": route, + "cardinality": cardinality, + "request": request, + "response_type": response_type, + } + ) + return MockStream(self.responses) + + +class MockStream: + def __init__(self, responses: List) -> None: + super().__init__() + self.responses = responses + + async def recv_message(self): + return self.responses.pop(0) + + async def send_message(self, *args, **kwargs): + pass + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return True + + async def __aenter__(self): + return self diff --git a/tests/oneof_pattern_matching.py b/tests/oneof_pattern_matching.py new file mode 100644 index 0000000..2c5e797 --- /dev/null +++ b/tests/oneof_pattern_matching.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass + +import pytest + +import aristaproto + + +def test_oneof_pattern_matching(): + @dataclass + class Sub(aristaproto.Message): + val: int = aristaproto.int32_field(1) + + @dataclass + class Foo(aristaproto.Message): + bar: int = aristaproto.int32_field(1, group="group1") + baz: str = aristaproto.string_field(2, group="group1") + sub: Sub = aristaproto.message_field(3, group="group2") + abc: str = aristaproto.string_field(4, group="group2") + + foo = Foo(baz="test1", abc="test2") + + match foo: + case Foo(bar=_): + pytest.fail("Matched 'bar' instead of 'baz'") + case Foo(baz=v): + assert v == "test1" + case _: + pytest.fail("Matched neither 'bar' nor 'baz'") + + match foo: + case Foo(sub=_): + pytest.fail("Matched 'sub' instead of 'abc'") + case Foo(abc=v): + assert v == "test2" + case _: + pytest.fail("Matched neither 'sub' nor 'abc'") + + foo.sub = Sub(val=1) + + match foo: + case Foo(sub=Sub(val=v)): + assert v == 1 + case Foo(abc=v): + pytest.fail("Matched 'abc' instead of 'sub'") + case _: + pytest.fail("Matched neither 'sub' nor 'abc'") diff --git a/tests/streams/delimited_messages.in b/tests/streams/delimited_messages.in new file mode 100644 index 0000000..5993ac6 --- /dev/null +++ b/tests/streams/delimited_messages.in @@ -0,0 +1,2 @@ +•šï:bTesting•šï:bTesting +
\ No newline at end of file diff --git a/tests/streams/dump_varint_negative.expected b/tests/streams/dump_varint_negative.expected new file mode 100644 index 0000000..0954822 --- /dev/null +++ b/tests/streams/dump_varint_negative.expected @@ -0,0 +1 @@ +ÿÿÿÿÿÿÿÿÿ€Óûÿÿÿÿÿ€€€€€€€€€€€€€€€€€
\ No newline at end of file diff --git a/tests/streams/dump_varint_positive.expected b/tests/streams/dump_varint_positive.expected new file mode 100644 index 0000000..8614b9d --- /dev/null +++ b/tests/streams/dump_varint_positive.expected @@ -0,0 +1 @@ +ۉ
\ No newline at end of file diff --git a/tests/streams/java/.gitignore b/tests/streams/java/.gitignore new file mode 100644 index 0000000..9b1ebba --- /dev/null +++ b/tests/streams/java/.gitignore @@ -0,0 +1,38 @@ +### Output ### +target/ +!.mvn/wrapper/maven-wrapper.jar +!**/src/main/**/target/ +!**/src/test/**/target/ +dependency-reduced-pom.xml +MANIFEST.MF + +### IntelliJ IDEA ### +.idea/ +*.iws +*.iml +*.ipr + +### Eclipse ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ +!**/src/main/**/build/ +!**/src/test/**/build/ + +### VS Code ### +.vscode/ + +### Mac OS ### +.DS_Store
\ No newline at end of file diff --git a/tests/streams/java/pom.xml b/tests/streams/java/pom.xml new file mode 100644 index 0000000..e39c567 --- /dev/null +++ b/tests/streams/java/pom.xml @@ -0,0 +1,94 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <groupId>aristaproto</groupId> + <artifactId>compatibility-test</artifactId> + <version>1.0-SNAPSHOT</version> + <packaging>jar</packaging> + + <properties> + <maven.compiler.source>11</maven.compiler.source> + <maven.compiler.target>11</maven.compiler.target> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + <protobuf.version>3.23.4</protobuf.version> + </properties> + + <dependencies> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${protobuf.version}</version> + </dependency> + </dependencies> + + <build> + <extensions> + <extension> + <groupId>kr.motd.maven</groupId> + <artifactId>os-maven-plugin</artifactId> + <version>1.7.1</version> + </extension> + </extensions> + + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + <version>3.5.0</version> + <executions> + <execution> + <phase>package</phase> + <goals> + <goal>shade</goal> + </goals> + <configuration> + <transformers> + <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> + <mainClass>aristaproto.CompatibilityTest</mainClass> + </transformer> + </transformers> + </configuration> + </execution> + </executions> + </plugin> + + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <version>3.3.0</version> + <configuration> + <archive> + <manifest> + <addClasspath>true</addClasspath> + <mainClass>aristaproto.CompatibilityTest</mainClass> + </manifest> + </archive> + </configuration> + </plugin> + + <plugin> + <groupId>org.xolstice.maven.plugins</groupId> + <artifactId>protobuf-maven-plugin</artifactId> + <version>0.6.1</version> + <executions> + <execution> + <goals> + <goal>compile</goal> + </goals> + </execution> + </executions> + <configuration> + <protocArtifact> + com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} + </protocArtifact> + </configuration> + </plugin> + </plugins> + + <finalName>${project.artifactId}</finalName> + </build> + +</project>
\ No newline at end of file diff --git a/tests/streams/java/src/main/java/aristaproto/CompatibilityTest.java b/tests/streams/java/src/main/java/aristaproto/CompatibilityTest.java new file mode 100644 index 0000000..b0cff9f --- /dev/null +++ b/tests/streams/java/src/main/java/aristaproto/CompatibilityTest.java @@ -0,0 +1,41 @@ +package aristaproto; + +import java.io.IOException; + +public class CompatibilityTest { + public static void main(String[] args) throws IOException { + if (args.length < 2) + throw new RuntimeException("Attempted to run without the required arguments."); + else if (args.length > 2) + throw new RuntimeException( + "Attempted to run with more than the expected number of arguments (>1)."); + + Tests tests = new Tests(args[1]); + + switch (args[0]) { + case "single_varint": + tests.testSingleVarint(); + break; + + case "multiple_varints": + tests.testMultipleVarints(); + break; + + case "single_message": + tests.testSingleMessage(); + break; + + case "multiple_messages": + tests.testMultipleMessages(); + break; + + case "infinite_messages": + tests.testInfiniteMessages(); + break; + + default: + throw new RuntimeException( + "Attempted to run with unknown argument '" + args[0] + "'."); + } + } +} diff --git a/tests/streams/java/src/main/java/aristaproto/Tests.java b/tests/streams/java/src/main/java/aristaproto/Tests.java new file mode 100644 index 0000000..aabbac7 --- /dev/null +++ b/tests/streams/java/src/main/java/aristaproto/Tests.java @@ -0,0 +1,115 @@ +package aristaproto; + +import aristaproto.nested.NestedOuterClass; +import aristaproto.oneof.Oneof; + +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.CodedOutputStream; + +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +public class Tests { + String path; + + public Tests(String path) { + this.path = path; + } + + public void testSingleVarint() throws IOException { + // Read in the Python-generated single varint file + FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out"); + CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); + + int value = codedInput.readUInt32(); + + inputStream.close(); + + // Write the value back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out"); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); + + codedOutput.writeUInt32NoTag(value); + + codedOutput.flush(); + outputStream.close(); + } + + public void testMultipleVarints() throws IOException { + // Read in the Python-generated multiple varints file + FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out"); + CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); + + int value1 = codedInput.readUInt32(); + int value2 = codedInput.readUInt32(); + long value3 = codedInput.readUInt64(); + + inputStream.close(); + + // Write the values back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out"); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); + + codedOutput.writeUInt32NoTag(value1); + codedOutput.writeUInt64NoTag(value2); + codedOutput.writeUInt64NoTag(value3); + + codedOutput.flush(); + outputStream.close(); + } + + public void testSingleMessage() throws IOException { + // Read in the Python-generated single message file + FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out"); + CodedInputStream codedInput = CodedInputStream.newInstance(inputStream); + + Oneof.Test message = Oneof.Test.parseFrom(codedInput); + + inputStream.close(); + + // Write the message back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out"); + CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream); + + message.writeTo(codedOutput); + + codedOutput.flush(); + outputStream.close(); + } + + public void testMultipleMessages() throws IOException { + // Read in the Python-generated multi-message file + FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out"); + + Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream); + NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream); + + inputStream.close(); + + // Write the messages back to a file + FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out"); + + oneof.writeDelimitedTo(outputStream); + nested.writeDelimitedTo(outputStream); + + outputStream.flush(); + outputStream.close(); + } + + public void testInfiniteMessages() throws IOException { + // Read in as many messages as are present in the Python-generated file and write them back + FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out"); + FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out"); + + Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream); + while (current != null) { + current.writeDelimitedTo(outputStream); + current = Oneof.Test.parseDelimitedFrom(inputStream); + } + + inputStream.close(); + outputStream.flush(); + outputStream.close(); + } +} diff --git a/tests/streams/java/src/main/proto/aristaproto/nested.proto b/tests/streams/java/src/main/proto/aristaproto/nested.proto new file mode 100644 index 0000000..46a5783 --- /dev/null +++ b/tests/streams/java/src/main/proto/aristaproto/nested.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package nested; +option java_package = "aristaproto.nested"; + +// A test message with a nested message inside of it. +message Test { + // This is the nested type. + message Nested { + // Stores a simple counter. + int32 count = 1; + } + // This is the nested enum. + enum Msg { + NONE = 0; + THIS = 1; + } + + Nested nested = 1; + Sibling sibling = 2; + Sibling sibling2 = 3; + Msg msg = 4; +} + +message Sibling { + int32 foo = 1; +}
\ No newline at end of file diff --git a/tests/streams/java/src/main/proto/aristaproto/oneof.proto b/tests/streams/java/src/main/proto/aristaproto/oneof.proto new file mode 100644 index 0000000..44a8949 --- /dev/null +++ b/tests/streams/java/src/main/proto/aristaproto/oneof.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +package oneof; +option java_package = "aristaproto.oneof"; + +message Test { + oneof foo { + int32 pitied = 1; + string pitier = 2; + } + + int32 just_a_regular_field = 3; + + oneof bar { + int32 drinks = 11; + string bar_name = 12; + } +} + diff --git a/tests/streams/load_varint_cutoff.in b/tests/streams/load_varint_cutoff.in new file mode 100644 index 0000000..52b9bf1 --- /dev/null +++ b/tests/streams/load_varint_cutoff.in @@ -0,0 +1 @@ +È
\ No newline at end of file diff --git a/tests/streams/message_dump_file_multiple.expected b/tests/streams/message_dump_file_multiple.expected new file mode 100644 index 0000000..b5fdf9c --- /dev/null +++ b/tests/streams/message_dump_file_multiple.expected @@ -0,0 +1,2 @@ +•šï:bTesting•šï:bTesting +
\ No newline at end of file diff --git a/tests/streams/message_dump_file_single.expected b/tests/streams/message_dump_file_single.expected new file mode 100644 index 0000000..9b7bafb --- /dev/null +++ b/tests/streams/message_dump_file_single.expected @@ -0,0 +1 @@ +•šï:bTesting
\ No newline at end of file diff --git a/tests/test_casing.py b/tests/test_casing.py new file mode 100644 index 0000000..b16d326 --- /dev/null +++ b/tests/test_casing.py @@ -0,0 +1,129 @@ +import pytest + +from aristaproto.casing import ( + camel_case, + pascal_case, + snake_case, +) + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("", ""), + ("a", "A"), + ("foobar", "Foobar"), + ("fooBar", "FooBar"), + ("FooBar", "FooBar"), + ("foo.bar", "FooBar"), + ("foo_bar", "FooBar"), + ("FOOBAR", "Foobar"), + ("FOOBar", "FooBar"), + ("UInt32", "UInt32"), + ("FOO_BAR", "FooBar"), + ("FOOBAR1", "Foobar1"), + ("FOOBAR_1", "Foobar1"), + ("FOO1BAR2", "Foo1Bar2"), + ("foo__bar", "FooBar"), + ("_foobar", "Foobar"), + ("foobaR", "FoobaR"), + ("foo~bar", "FooBar"), + ("foo:bar", "FooBar"), + ("1foobar", "1Foobar"), + ], +) +def test_pascal_case(value, expected): + actual = pascal_case(value, strict=True) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("", ""), + ("a", "a"), + ("foobar", "foobar"), + ("fooBar", "fooBar"), + ("FooBar", "fooBar"), + ("foo.bar", "fooBar"), + ("foo_bar", "fooBar"), + ("FOOBAR", "foobar"), + ("FOO_BAR", "fooBar"), + ("FOOBAR1", "foobar1"), + ("FOOBAR_1", "foobar1"), + ("FOO1BAR2", "foo1Bar2"), + ("foo__bar", "fooBar"), + ("_foobar", "foobar"), + ("foobaR", "foobaR"), + ("foo~bar", "fooBar"), + ("foo:bar", "fooBar"), + ("1foobar", "1Foobar"), + ], +) +def test_camel_case_strict(value, expected): + actual = camel_case(value, strict=True) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("foo_bar", "fooBar"), + ("FooBar", "fooBar"), + ("foo__bar", "foo_Bar"), + ("foo__Bar", "foo__Bar"), + ], +) +def test_camel_case_not_strict(value, expected): + actual = camel_case(value, strict=False) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("", ""), + ("a", "a"), + ("foobar", "foobar"), + ("fooBar", "foo_bar"), + ("FooBar", "foo_bar"), + ("foo.bar", "foo_bar"), + ("foo_bar", "foo_bar"), + ("foo_Bar", "foo_bar"), + ("FOOBAR", "foobar"), + ("FOOBar", "foo_bar"), + ("UInt32", "u_int32"), + ("FOO_BAR", "foo_bar"), + ("FOOBAR1", "foobar1"), + ("FOOBAR_1", "foobar_1"), + ("FOOBAR_123", "foobar_123"), + ("FOO1BAR2", "foo1_bar2"), + ("foo__bar", "foo_bar"), + ("_foobar", "foobar"), + ("foobaR", "fooba_r"), + ("foo~bar", "foo_bar"), + ("foo:bar", "foo_bar"), + ("1foobar", "1_foobar"), + ("GetUInt64", "get_u_int64"), + ], +) +def test_snake_case_strict(value, expected): + actual = snake_case(value) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("fooBar", "foo_bar"), + ("FooBar", "foo_bar"), + ("foo_Bar", "foo__bar"), + ("foo__bar", "foo__bar"), + ("FOOBar", "foo_bar"), + ("__foo", "__foo"), + ("GetUInt64", "get_u_int64"), + ], +) +def test_snake_case_not_strict(value, expected): + actual = snake_case(value, strict=False) + assert actual == expected, f"{value} => {expected} (actual: {actual})" diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py new file mode 100644 index 0000000..fd4de82 --- /dev/null +++ b/tests/test_deprecated.py @@ -0,0 +1,45 @@ +import warnings + +import pytest + +from tests.output_aristaproto.deprecated import ( + Message, + Test, +) + + +@pytest.fixture +def message(): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + return Message(value="hello") + + +def test_deprecated_message(): + with pytest.warns(DeprecationWarning) as record: + Message(value="hello") + + assert len(record) == 1 + assert str(record[0].message) == f"{Message.__name__} is deprecated" + + +def test_message_with_deprecated_field(message): + with pytest.warns(DeprecationWarning) as record: + Test(message=message, value=10) + + assert len(record) == 1 + assert str(record[0].message) == f"{Test.__name__}.message is deprecated" + + +def test_message_with_deprecated_field_not_set(message): + with pytest.warns(None) as record: + Test(value=10) + + assert not record + + +def test_message_with_deprecated_field_not_set_default(message): + with pytest.warns(None) as record: + _ = Test(value=10).message + + assert not record diff --git a/tests/test_enum.py b/tests/test_enum.py new file mode 100644 index 0000000..807e785 --- /dev/null +++ b/tests/test_enum.py @@ -0,0 +1,79 @@ +from typing import ( + Optional, + Tuple, +) + +import pytest + +import aristaproto + + +class Colour(aristaproto.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + +PURPLE = Colour.__new__(Colour, name=None, value=4) + + +@pytest.mark.parametrize( + "member, str_value", + [ + (Colour.RED, "RED"), + (Colour.GREEN, "GREEN"), + (Colour.BLUE, "BLUE"), + ], +) +def test_str(member: Colour, str_value: str) -> None: + assert str(member) == str_value + + +@pytest.mark.parametrize( + "member, repr_value", + [ + (Colour.RED, "Colour.RED"), + (Colour.GREEN, "Colour.GREEN"), + (Colour.BLUE, "Colour.BLUE"), + ], +) +def test_repr(member: Colour, repr_value: str) -> None: + assert repr(member) == repr_value + + +@pytest.mark.parametrize( + "member, values", + [ + (Colour.RED, ("RED", 1)), + (Colour.GREEN, ("GREEN", 2)), + (Colour.BLUE, ("BLUE", 3)), + (PURPLE, (None, 4)), + ], +) +def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None: + assert (member.name, member.value) == values + + +@pytest.mark.parametrize( + "member, input_str", + [ + (Colour.RED, "RED"), + (Colour.GREEN, "GREEN"), + (Colour.BLUE, "BLUE"), + ], +) +def test_from_string(member: Colour, input_str: str) -> None: + assert Colour.from_string(input_str) == member + + +@pytest.mark.parametrize( + "member, input_int", + [ + (Colour.RED, 1), + (Colour.GREEN, 2), + (Colour.BLUE, 3), + (PURPLE, 4), + ], +) +def test_try_value(member: Colour, input_int: int) -> None: + assert Colour.try_value(input_int) == member diff --git a/tests/test_features.py b/tests/test_features.py new file mode 100644 index 0000000..638e668 --- /dev/null +++ b/tests/test_features.py @@ -0,0 +1,682 @@ +import json +import sys +from copy import ( + copy, + deepcopy, +) +from dataclasses import dataclass +from datetime import ( + datetime, + timedelta, +) +from inspect import ( + Parameter, + signature, +) +from typing import ( + Dict, + List, + Optional, +) +from unittest.mock import ANY + +import pytest + +import aristaproto + + +def test_has_field(): + @dataclass + class Bar(aristaproto.Message): + baz: int = aristaproto.int32_field(1) + + @dataclass + class Foo(aristaproto.Message): + bar: Bar = aristaproto.message_field(1) + + # Unset by default + foo = Foo() + assert aristaproto.serialized_on_wire(foo.bar) is False + + # Serialized after setting something + foo.bar.baz = 1 + assert aristaproto.serialized_on_wire(foo.bar) is True + + # Still has it after setting the default value + foo.bar.baz = 0 + assert aristaproto.serialized_on_wire(foo.bar) is True + + # Manual override (don't do this) + foo.bar._serialized_on_wire = False + assert aristaproto.serialized_on_wire(foo.bar) is False + + # Can manually set it but defaults to false + foo.bar = Bar() + assert aristaproto.serialized_on_wire(foo.bar) is False + + @dataclass + class WithCollections(aristaproto.Message): + test_list: List[str] = aristaproto.string_field(1) + test_map: Dict[str, str] = aristaproto.map_field( + 2, aristaproto.TYPE_STRING, aristaproto.TYPE_STRING + ) + + # Is always set from parse, even if all collections are empty + with_collections_empty = WithCollections().parse(bytes(WithCollections())) + assert aristaproto.serialized_on_wire(with_collections_empty) == True + with_collections_list = WithCollections().parse( + bytes(WithCollections(test_list=["a", "b", "c"])) + ) + assert aristaproto.serialized_on_wire(with_collections_list) == True + with_collections_map = WithCollections().parse( + bytes(WithCollections(test_map={"a": "b", "c": "d"})) + ) + assert aristaproto.serialized_on_wire(with_collections_map) == True + + +def test_class_init(): + @dataclass + class Bar(aristaproto.Message): + name: str = aristaproto.string_field(1) + + @dataclass + class Foo(aristaproto.Message): + name: str = aristaproto.string_field(1) + child: Bar = aristaproto.message_field(2) + + foo = Foo(name="foo", child=Bar(name="bar")) + + assert foo.to_dict() == {"name": "foo", "child": {"name": "bar"}} + assert foo.to_pydict() == {"name": "foo", "child": {"name": "bar"}} + + +def test_enum_as_int_json(): + class TestEnum(aristaproto.Enum): + ZERO = 0 + ONE = 1 + + @dataclass + class Foo(aristaproto.Message): + bar: TestEnum = aristaproto.enum_field(1) + + # JSON strings are supported, but ints should still be supported too. + foo = Foo().from_dict({"bar": 1}) + assert foo.bar == TestEnum.ONE + + # Plain-ol'-ints should serialize properly too. + foo.bar = 1 + assert foo.to_dict() == {"bar": "ONE"} + + # Similar expectations for pydict + foo = Foo().from_pydict({"bar": 1}) + assert foo.bar == TestEnum.ONE + assert foo.to_pydict() == {"bar": TestEnum.ONE} + + +def test_unknown_fields(): + @dataclass + class Newer(aristaproto.Message): + foo: bool = aristaproto.bool_field(1) + bar: int = aristaproto.int32_field(2) + baz: str = aristaproto.string_field(3) + + @dataclass + class Older(aristaproto.Message): + foo: bool = aristaproto.bool_field(1) + + newer = Newer(foo=True, bar=1, baz="Hello") + serialized_newer = bytes(newer) + + # Unknown fields in `Newer` should round trip with `Older` + round_trip = bytes(Older().parse(serialized_newer)) + assert serialized_newer == round_trip + + new_again = Newer().parse(round_trip) + assert newer == new_again + + +def test_oneof_support(): + @dataclass + class Sub(aristaproto.Message): + val: int = aristaproto.int32_field(1) + + @dataclass + class Foo(aristaproto.Message): + bar: int = aristaproto.int32_field(1, group="group1") + baz: str = aristaproto.string_field(2, group="group1") + sub: Sub = aristaproto.message_field(3, group="group2") + abc: str = aristaproto.string_field(4, group="group2") + + foo = Foo() + + assert aristaproto.which_one_of(foo, "group1")[0] == "" + + foo.bar = 1 + foo.baz = "test" + + # Other oneof fields should now be unset + assert not hasattr(foo, "bar") + assert object.__getattribute__(foo, "bar") == aristaproto.PLACEHOLDER + assert aristaproto.which_one_of(foo, "group1")[0] == "baz" + + foo.sub = Sub(val=1) + assert aristaproto.serialized_on_wire(foo.sub) + + foo.abc = "test" + + # Group 1 shouldn't be touched, group 2 should have reset + assert not hasattr(foo, "sub") + assert object.__getattribute__(foo, "sub") == aristaproto.PLACEHOLDER + assert aristaproto.which_one_of(foo, "group2")[0] == "abc" + + # Zero value should always serialize for one-of + foo = Foo(bar=0) + assert aristaproto.which_one_of(foo, "group1")[0] == "bar" + assert bytes(foo) == b"\x08\x00" + + # Round trip should also work + foo2 = Foo().parse(bytes(foo)) + assert aristaproto.which_one_of(foo2, "group1")[0] == "bar" + assert foo.bar == 0 + assert aristaproto.which_one_of(foo2, "group2")[0] == "" + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="pattern matching is only supported in python3.10+", +) +def test_oneof_pattern_matching(): + from .oneof_pattern_matching import test_oneof_pattern_matching + + test_oneof_pattern_matching() + + +def test_json_casing(): + @dataclass + class CasingTest(aristaproto.Message): + pascal_case: int = aristaproto.int32_field(1) + camel_case: int = aristaproto.int32_field(2) + snake_case: int = aristaproto.int32_field(3) + kabob_case: int = aristaproto.int32_field(4) + + # Parsing should accept almost any input + test = CasingTest().from_dict( + {"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4} + ) + + assert test == CasingTest(1, 2, 3, 4) + + # Serializing should be strict. + assert json.loads(test.to_json()) == { + "pascalCase": 1, + "camelCase": 2, + "snakeCase": 3, + "kabobCase": 4, + } + + assert json.loads(test.to_json(casing=aristaproto.Casing.SNAKE)) == { + "pascal_case": 1, + "camel_case": 2, + "snake_case": 3, + "kabob_case": 4, + } + + +def test_dict_casing(): + @dataclass + class CasingTest(aristaproto.Message): + pascal_case: int = aristaproto.int32_field(1) + camel_case: int = aristaproto.int32_field(2) + snake_case: int = aristaproto.int32_field(3) + kabob_case: int = aristaproto.int32_field(4) + + # Parsing should accept almost any input + test = CasingTest().from_dict( + {"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4} + ) + + assert test == CasingTest(1, 2, 3, 4) + + # Serializing should be strict. + assert test.to_dict() == { + "pascalCase": 1, + "camelCase": 2, + "snakeCase": 3, + "kabobCase": 4, + } + assert test.to_pydict() == { + "pascalCase": 1, + "camelCase": 2, + "snakeCase": 3, + "kabobCase": 4, + } + + assert test.to_dict(casing=aristaproto.Casing.SNAKE) == { + "pascal_case": 1, + "camel_case": 2, + "snake_case": 3, + "kabob_case": 4, + } + assert test.to_pydict(casing=aristaproto.Casing.SNAKE) == { + "pascal_case": 1, + "camel_case": 2, + "snake_case": 3, + "kabob_case": 4, + } + + +def test_optional_flag(): + @dataclass + class Request(aristaproto.Message): + flag: Optional[bool] = aristaproto.message_field(1, wraps=aristaproto.TYPE_BOOL) + + # Serialization of not passed vs. set vs. zero-value. + assert bytes(Request()) == b"" + assert bytes(Request(flag=True)) == b"\n\x02\x08\x01" + assert bytes(Request(flag=False)) == b"\n\x00" + + # Differentiate between not passed and the zero-value. + assert Request().parse(b"").flag is None + assert Request().parse(b"\n\x00").flag is False + + +def test_optional_datetime_to_dict(): + @dataclass + class Request(aristaproto.Message): + date: Optional[datetime] = aristaproto.message_field(1, optional=True) + + # Check dict serialization + assert Request().to_dict() == {} + assert Request().to_dict(include_default_values=True) == {"date": None} + assert Request(date=datetime(2020, 1, 1)).to_dict() == { + "date": "2020-01-01T00:00:00Z" + } + assert Request(date=datetime(2020, 1, 1)).to_dict(include_default_values=True) == { + "date": "2020-01-01T00:00:00Z" + } + + # Check pydict serialization + assert Request().to_pydict() == {} + assert Request().to_pydict(include_default_values=True) == {"date": None} + assert Request(date=datetime(2020, 1, 1)).to_pydict() == { + "date": datetime(2020, 1, 1) + } + assert Request(date=datetime(2020, 1, 1)).to_pydict( + include_default_values=True + ) == {"date": datetime(2020, 1, 1)} + + +def test_to_json_default_values(): + @dataclass + class TestMessage(aristaproto.Message): + some_int: int = aristaproto.int32_field(1) + some_double: float = aristaproto.double_field(2) + some_str: str = aristaproto.string_field(3) + some_bool: bool = aristaproto.bool_field(4) + + # Empty dict + test = TestMessage().from_dict({}) + + assert json.loads(test.to_json(include_default_values=True)) == { + "someInt": 0, + "someDouble": 0.0, + "someStr": "", + "someBool": False, + } + + # All default values + test = TestMessage().from_dict( + {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False} + ) + + assert json.loads(test.to_json(include_default_values=True)) == { + "someInt": 0, + "someDouble": 0.0, + "someStr": "", + "someBool": False, + } + + +def test_to_dict_default_values(): + @dataclass + class TestMessage(aristaproto.Message): + some_int: int = aristaproto.int32_field(1) + some_double: float = aristaproto.double_field(2) + some_str: str = aristaproto.string_field(3) + some_bool: bool = aristaproto.bool_field(4) + + # Empty dict + test = TestMessage().from_dict({}) + + assert test.to_dict(include_default_values=True) == { + "someInt": 0, + "someDouble": 0.0, + "someStr": "", + "someBool": False, + } + + test = TestMessage().from_pydict({}) + + assert test.to_pydict(include_default_values=True) == { + "someInt": 0, + "someDouble": 0.0, + "someStr": "", + "someBool": False, + } + + # All default values + test = TestMessage().from_dict( + {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False} + ) + + assert test.to_dict(include_default_values=True) == { + "someInt": 0, + "someDouble": 0.0, + "someStr": "", + "someBool": False, + } + + test = TestMessage().from_pydict( + {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False} + ) + + assert test.to_pydict(include_default_values=True) == { + "someInt": 0, + "someDouble": 0.0, + "someStr": "", + "someBool": False, + } + + # Some default and some other values + @dataclass + class TestMessage2(aristaproto.Message): + some_int: int = aristaproto.int32_field(1) + some_double: float = aristaproto.double_field(2) + some_str: str = aristaproto.string_field(3) + some_bool: bool = aristaproto.bool_field(4) + some_default_int: int = aristaproto.int32_field(5) + some_default_double: float = aristaproto.double_field(6) + some_default_str: str = aristaproto.string_field(7) + some_default_bool: bool = aristaproto.bool_field(8) + + test = TestMessage2().from_dict( + { + "someInt": 2, + "someDouble": 1.2, + "someStr": "hello", + "someBool": True, + "someDefaultInt": 0, + "someDefaultDouble": 0.0, + "someDefaultStr": "", + "someDefaultBool": False, + } + ) + + assert test.to_dict(include_default_values=True) == { + "someInt": 2, + "someDouble": 1.2, + "someStr": "hello", + "someBool": True, + "someDefaultInt": 0, + "someDefaultDouble": 0.0, + "someDefaultStr": "", + "someDefaultBool": False, + } + + test = TestMessage2().from_pydict( + { + "someInt": 2, + "someDouble": 1.2, + "someStr": "hello", + "someBool": True, + "someDefaultInt": 0, + "someDefaultDouble": 0.0, + "someDefaultStr": "", + "someDefaultBool": False, + } + ) + + assert test.to_pydict(include_default_values=True) == { + "someInt": 2, + "someDouble": 1.2, + "someStr": "hello", + "someBool": True, + "someDefaultInt": 0, + "someDefaultDouble": 0.0, + "someDefaultStr": "", + "someDefaultBool": False, + } + + # Nested messages + @dataclass + class TestChildMessage(aristaproto.Message): + some_other_int: int = aristaproto.int32_field(1) + + @dataclass + class TestParentMessage(aristaproto.Message): + some_int: int = aristaproto.int32_field(1) + some_double: float = aristaproto.double_field(2) + some_message: TestChildMessage = aristaproto.message_field(3) + + test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2}) + + assert test.to_dict(include_default_values=True) == { + "someInt": 0, + "someDouble": 1.2, + "someMessage": {"someOtherInt": 0}, + } + + test = TestParentMessage().from_pydict({"someInt": 0, "someDouble": 1.2}) + + assert test.to_pydict(include_default_values=True) == { + "someInt": 0, + "someDouble": 1.2, + "someMessage": {"someOtherInt": 0}, + } + + +def test_to_dict_datetime_values(): + @dataclass + class TestDatetimeMessage(aristaproto.Message): + bar: datetime = aristaproto.message_field(1) + baz: timedelta = aristaproto.message_field(2) + + test = TestDatetimeMessage().from_dict( + {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"} + ) + + assert test.to_dict() == {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"} + + test = TestDatetimeMessage().from_pydict( + {"bar": datetime(year=2020, month=1, day=1), "baz": timedelta(days=1)} + ) + + assert test.to_pydict() == { + "bar": datetime(year=2020, month=1, day=1), + "baz": timedelta(days=1), + } + + +def test_oneof_default_value_set_causes_writes_wire(): + @dataclass + class Empty(aristaproto.Message): + pass + + @dataclass + class Foo(aristaproto.Message): + bar: int = aristaproto.int32_field(1, group="group1") + baz: str = aristaproto.string_field(2, group="group1") + qux: Empty = aristaproto.message_field(3, group="group1") + + def _round_trip_serialization(foo: Foo) -> Foo: + return Foo().parse(bytes(foo)) + + foo1 = Foo(bar=0) + foo2 = Foo(baz="") + foo3 = Foo(qux=Empty()) + foo4 = Foo() + + assert bytes(foo1) == b"\x08\x00" + assert ( + aristaproto.which_one_of(foo1, "group1") + == aristaproto.which_one_of(_round_trip_serialization(foo1), "group1") + == ("bar", 0) + ) + + assert bytes(foo2) == b"\x12\x00" # Baz is just an empty string + assert ( + aristaproto.which_one_of(foo2, "group1") + == aristaproto.which_one_of(_round_trip_serialization(foo2), "group1") + == ("baz", "") + ) + + assert bytes(foo3) == b"\x1a\x00" + assert ( + aristaproto.which_one_of(foo3, "group1") + == aristaproto.which_one_of(_round_trip_serialization(foo3), "group1") + == ("qux", Empty()) + ) + + assert bytes(foo4) == b"" + assert ( + aristaproto.which_one_of(foo4, "group1") + == aristaproto.which_one_of(_round_trip_serialization(foo4), "group1") + == ("", None) + ) + + +def test_message_repr(): + from tests.output_aristaproto.recursivemessage import Test + + assert repr(Test(name="Loki")) == "Test(name='Loki')" + assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())" + + +def test_bool(): + """Messages should evaluate similarly to a collection + >>> test = [] + >>> bool(test) + ... False + >>> test.append(1) + >>> bool(test) + ... True + >>> del test[0] + >>> bool(test) + ... False + """ + + @dataclass + class Falsy(aristaproto.Message): + pass + + @dataclass + class Truthy(aristaproto.Message): + bar: int = aristaproto.int32_field(1) + + assert not Falsy() + t = Truthy() + assert not t + t.bar = 1 + assert t + t.bar = 0 + assert not t + + +# valid ISO datetimes according to https://www.myintervals.com/blog/2009/05/20/iso-8601-date-validation-that-doesnt-suck/ +iso_candidates = """2009-12-12T12:34 +2009 +2009-05-19 +2009-05-19 +20090519 +2009123 +2009-05 +2009-123 +2009-222 +2009-001 +2009-W01-1 +2009-W51-1 +2009-W33 +2009W511 +2009-05-19 +2009-05-19 00:00 +2009-05-19 14 +2009-05-19 14:31 +2009-05-19 14:39:22 +2009-05-19T14:39Z +2009-W21-2 +2009-W21-2T01:22 +2009-139 +2009-05-19 14:39:22-06:00 +2009-05-19 14:39:22+0600 +2009-05-19 14:39:22-01 +20090621T0545Z +2007-04-06T00:00 +2007-04-05T24:00 +2010-02-18T16:23:48.5 +2010-02-18T16:23:48,444 +2010-02-18T16:23:48,3-06:00 +2010-02-18T16:23:00.4 +2010-02-18T16:23:00,25 +2010-02-18T16:23:00.33+0600 +2010-02-18T16:00:00.23334444 +2010-02-18T16:00:00,2283 +2009-05-19 143922 +2009-05-19 1439""".split( + "\n" +) + + +def test_iso_datetime(): + @dataclass + class Envelope(aristaproto.Message): + ts: datetime = aristaproto.message_field(1) + + msg = Envelope() + + for _, candidate in enumerate(iso_candidates): + msg.from_dict({"ts": candidate}) + assert isinstance(msg.ts, datetime) + + +def test_iso_datetime_list(): + @dataclass + class Envelope(aristaproto.Message): + timestamps: List[datetime] = aristaproto.message_field(1) + + msg = Envelope() + + msg.from_dict({"timestamps": iso_candidates}) + assert all([isinstance(item, datetime) for item in msg.timestamps]) + + +def test_service_argument__expected_parameter(): + from tests.output_aristaproto.service import TestStub + + sig = signature(TestStub.do_thing) + do_thing_request_parameter = sig.parameters["do_thing_request"] + assert do_thing_request_parameter.default is Parameter.empty + assert do_thing_request_parameter.annotation == "DoThingRequest" + + +def test_is_set(): + @dataclass + class Spam(aristaproto.Message): + foo: bool = aristaproto.bool_field(1) + bar: Optional[int] = aristaproto.int32_field(2, optional=True) + + assert not Spam().is_set("foo") + assert not Spam().is_set("bar") + assert Spam(foo=True).is_set("foo") + assert Spam(foo=True, bar=0).is_set("bar") + + +def test_equality_comparison(): + from tests.output_aristaproto.bool import Test as TestMessage + + msg = TestMessage(value=True) + + assert msg == msg + assert msg == ANY + assert msg == TestMessage(value=True) + assert msg != 1 + assert msg != TestMessage(value=False) diff --git a/tests/test_get_ref_type.py b/tests/test_get_ref_type.py new file mode 100644 index 0000000..a4c6f76 --- /dev/null +++ b/tests/test_get_ref_type.py @@ -0,0 +1,371 @@ +import pytest + +from aristaproto.compile.importing import ( + get_type_reference, + parse_source_type_name, +) + + +@pytest.mark.parametrize( + ["google_type", "expected_name", "expected_import"], + [ + ( + ".google.protobuf.Empty", + '"aristaproto_lib_google_protobuf.Empty"', + "import aristaproto.lib.google.protobuf as aristaproto_lib_google_protobuf", + ), + ( + ".google.protobuf.Struct", + '"aristaproto_lib_google_protobuf.Struct"', + "import aristaproto.lib.google.protobuf as aristaproto_lib_google_protobuf", + ), + ( + ".google.protobuf.ListValue", + '"aristaproto_lib_google_protobuf.ListValue"', + "import aristaproto.lib.google.protobuf as aristaproto_lib_google_protobuf", + ), + ( + ".google.protobuf.Value", + '"aristaproto_lib_google_protobuf.Value"', + "import aristaproto.lib.google.protobuf as aristaproto_lib_google_protobuf", + ), + ], +) +def test_reference_google_wellknown_types_non_wrappers( + google_type: str, expected_name: str, expected_import: str +): + imports = set() + name = get_type_reference( + package="", imports=imports, source_type=google_type, pydantic=False + ) + + assert name == expected_name + assert imports.__contains__( + expected_import + ), f"{expected_import} not found in {imports}" + + +@pytest.mark.parametrize( + ["google_type", "expected_name", "expected_import"], + [ + ( + ".google.protobuf.Empty", + '"aristaproto_lib_pydantic_google_protobuf.Empty"', + "import aristaproto.lib.pydantic.google.protobuf as aristaproto_lib_pydantic_google_protobuf", + ), + ( + ".google.protobuf.Struct", + '"aristaproto_lib_pydantic_google_protobuf.Struct"', + "import aristaproto.lib.pydantic.google.protobuf as aristaproto_lib_pydantic_google_protobuf", + ), + ( + ".google.protobuf.ListValue", + '"aristaproto_lib_pydantic_google_protobuf.ListValue"', + "import aristaproto.lib.pydantic.google.protobuf as aristaproto_lib_pydantic_google_protobuf", + ), + ( + ".google.protobuf.Value", + '"aristaproto_lib_pydantic_google_protobuf.Value"', + "import aristaproto.lib.pydantic.google.protobuf as aristaproto_lib_pydantic_google_protobuf", + ), + ], +) +def test_reference_google_wellknown_types_non_wrappers_pydantic( + google_type: str, expected_name: str, expected_import: str +): + imports = set() + name = get_type_reference( + package="", imports=imports, source_type=google_type, pydantic=True + ) + + assert name == expected_name + assert imports.__contains__( + expected_import + ), f"{expected_import} not found in {imports}" + + +@pytest.mark.parametrize( + ["google_type", "expected_name"], + [ + (".google.protobuf.DoubleValue", "Optional[float]"), + (".google.protobuf.FloatValue", "Optional[float]"), + (".google.protobuf.Int32Value", "Optional[int]"), + (".google.protobuf.Int64Value", "Optional[int]"), + (".google.protobuf.UInt32Value", "Optional[int]"), + (".google.protobuf.UInt64Value", "Optional[int]"), + (".google.protobuf.BoolValue", "Optional[bool]"), + (".google.protobuf.StringValue", "Optional[str]"), + (".google.protobuf.BytesValue", "Optional[bytes]"), + ], +) +def test_referenceing_google_wrappers_unwraps_them( + google_type: str, expected_name: str +): + imports = set() + name = get_type_reference(package="", imports=imports, source_type=google_type) + + assert name == expected_name + assert imports == set() + + +@pytest.mark.parametrize( + ["google_type", "expected_name"], + [ + ( + ".google.protobuf.DoubleValue", + '"aristaproto_lib_google_protobuf.DoubleValue"', + ), + (".google.protobuf.FloatValue", '"aristaproto_lib_google_protobuf.FloatValue"'), + (".google.protobuf.Int32Value", '"aristaproto_lib_google_protobuf.Int32Value"'), + (".google.protobuf.Int64Value", '"aristaproto_lib_google_protobuf.Int64Value"'), + ( + ".google.protobuf.UInt32Value", + '"aristaproto_lib_google_protobuf.UInt32Value"', + ), + ( + ".google.protobuf.UInt64Value", + '"aristaproto_lib_google_protobuf.UInt64Value"', + ), + (".google.protobuf.BoolValue", '"aristaproto_lib_google_protobuf.BoolValue"'), + ( + ".google.protobuf.StringValue", + '"aristaproto_lib_google_protobuf.StringValue"', + ), + (".google.protobuf.BytesValue", '"aristaproto_lib_google_protobuf.BytesValue"'), + ], +) +def test_referenceing_google_wrappers_without_unwrapping( + google_type: str, expected_name: str +): + name = get_type_reference( + package="", imports=set(), source_type=google_type, unwrap=False + ) + + assert name == expected_name + + +def test_reference_child_package_from_package(): + imports = set() + name = get_type_reference( + package="package", imports=imports, source_type="package.child.Message" + ) + + assert imports == {"from . import child"} + assert name == '"child.Message"' + + +def test_reference_child_package_from_root(): + imports = set() + name = get_type_reference(package="", imports=imports, source_type="child.Message") + + assert imports == {"from . import child"} + assert name == '"child.Message"' + + +def test_reference_camel_cased(): + imports = set() + name = get_type_reference( + package="", imports=imports, source_type="child_package.example_message" + ) + + assert imports == {"from . import child_package"} + assert name == '"child_package.ExampleMessage"' + + +def test_reference_nested_child_from_root(): + imports = set() + name = get_type_reference( + package="", imports=imports, source_type="nested.child.Message" + ) + + assert imports == {"from .nested import child as nested_child"} + assert name == '"nested_child.Message"' + + +def test_reference_deeply_nested_child_from_root(): + imports = set() + name = get_type_reference( + package="", imports=imports, source_type="deeply.nested.child.Message" + ) + + assert imports == {"from .deeply.nested import child as deeply_nested_child"} + assert name == '"deeply_nested_child.Message"' + + +def test_reference_deeply_nested_child_from_package(): + imports = set() + name = get_type_reference( + package="package", + imports=imports, + source_type="package.deeply.nested.child.Message", + ) + + assert imports == {"from .deeply.nested import child as deeply_nested_child"} + assert name == '"deeply_nested_child.Message"' + + +def test_reference_root_sibling(): + imports = set() + name = get_type_reference(package="", imports=imports, source_type="Message") + + assert imports == set() + assert name == '"Message"' + + +def test_reference_nested_siblings(): + imports = set() + name = get_type_reference(package="foo", imports=imports, source_type="foo.Message") + + assert imports == set() + assert name == '"Message"' + + +def test_reference_deeply_nested_siblings(): + imports = set() + name = get_type_reference( + package="foo.bar", imports=imports, source_type="foo.bar.Message" + ) + + assert imports == set() + assert name == '"Message"' + + +def test_reference_parent_package_from_child(): + imports = set() + name = get_type_reference( + package="package.child", imports=imports, source_type="package.Message" + ) + + assert imports == {"from ... import package as __package__"} + assert name == '"__package__.Message"' + + +def test_reference_parent_package_from_deeply_nested_child(): + imports = set() + name = get_type_reference( + package="package.deeply.nested.child", + imports=imports, + source_type="package.deeply.nested.Message", + ) + + assert imports == {"from ... import nested as __nested__"} + assert name == '"__nested__.Message"' + + +def test_reference_ancestor_package_from_nested_child(): + imports = set() + name = get_type_reference( + package="package.ancestor.nested.child", + imports=imports, + source_type="package.ancestor.Message", + ) + + assert imports == {"from .... import ancestor as ___ancestor__"} + assert name == '"___ancestor__.Message"' + + +def test_reference_root_package_from_child(): + imports = set() + name = get_type_reference( + package="package.child", imports=imports, source_type="Message" + ) + + assert imports == {"from ... import Message as __Message__"} + assert name == '"__Message__"' + + +def test_reference_root_package_from_deeply_nested_child(): + imports = set() + name = get_type_reference( + package="package.deeply.nested.child", imports=imports, source_type="Message" + ) + + assert imports == {"from ..... import Message as ____Message__"} + assert name == '"____Message__"' + + +def test_reference_unrelated_package(): + imports = set() + name = get_type_reference(package="a", imports=imports, source_type="p.Message") + + assert imports == {"from .. import p as _p__"} + assert name == '"_p__.Message"' + + +def test_reference_unrelated_nested_package(): + imports = set() + name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message") + + assert imports == {"from ...p import q as __p_q__"} + assert name == '"__p_q__.Message"' + + +def test_reference_unrelated_deeply_nested_package(): + imports = set() + name = get_type_reference( + package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message" + ) + + assert imports == {"from .....p.q.r import s as ____p_q_r_s__"} + assert name == '"____p_q_r_s__.Message"' + + +def test_reference_cousin_package(): + imports = set() + name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message") + + assert imports == {"from .. import y as _y__"} + assert name == '"_y__.Message"' + + +def test_reference_cousin_package_different_name(): + imports = set() + name = get_type_reference( + package="test.package1", imports=imports, source_type="cousin.package2.Message" + ) + + assert imports == {"from ...cousin import package2 as __cousin_package2__"} + assert name == '"__cousin_package2__.Message"' + + +def test_reference_cousin_package_same_name(): + imports = set() + name = get_type_reference( + package="test.package", imports=imports, source_type="cousin.package.Message" + ) + + assert imports == {"from ...cousin import package as __cousin_package__"} + assert name == '"__cousin_package__.Message"' + + +def test_reference_far_cousin_package(): + imports = set() + name = get_type_reference( + package="a.x.y", imports=imports, source_type="a.b.c.Message" + ) + + assert imports == {"from ...b import c as __b_c__"} + assert name == '"__b_c__.Message"' + + +def test_reference_far_far_cousin_package(): + imports = set() + name = get_type_reference( + package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message" + ) + + assert imports == {"from ....b.c import d as ___b_c_d__"} + assert name == '"___b_c_d__.Message"' + + +@pytest.mark.parametrize( + ["full_name", "expected_output"], + [ + ("package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")), + (".package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")), + (".service.ExampleRequest", ("service", "ExampleRequest")), + (".package.lower_case_message", ("package", "lower_case_message")), + ], +) +def test_parse_field_type_name(full_name, expected_output): + assert parse_source_type_name(full_name) == expected_output diff --git a/tests/test_inputs.py b/tests/test_inputs.py new file mode 100644 index 0000000..9247e7b --- /dev/null +++ b/tests/test_inputs.py @@ -0,0 +1,225 @@ +import importlib +import json +import math +import os +import sys +from collections import namedtuple +from types import ModuleType +from typing import ( + Any, + Dict, + List, + Set, + Tuple, +) + +import pytest + +import aristaproto +from tests.inputs import config as test_input_config +from tests.mocks import MockChannel +from tests.util import ( + find_module, + get_directories, + get_test_case_json_data, + inputs_path, +) + + +# Force pure-python implementation instead of C++, otherwise imports +# break things because we can't properly reset the symbol database. +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + +from google.protobuf.json_format import Parse + + +class TestCases: + def __init__( + self, + path, + services: Set[str], + xfail: Set[str], + ): + _all = set(get_directories(path)) - {"__pycache__"} + _services = services + _messages = (_all - services) - {"__pycache__"} + _messages_with_json = { + test for test in _messages if get_test_case_json_data(test) + } + + unknown_xfail_tests = xfail - _all + if unknown_xfail_tests: + raise Exception(f"Unknown test(s) in config.py: {unknown_xfail_tests}") + + self.all = self.apply_xfail_marks(_all, xfail) + self.services = self.apply_xfail_marks(_services, xfail) + self.messages = self.apply_xfail_marks(_messages, xfail) + self.messages_with_json = self.apply_xfail_marks(_messages_with_json, xfail) + + @staticmethod + def apply_xfail_marks(test_set: Set[str], xfail: Set[str]): + return [ + pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test + for test in test_set + ] + + +test_cases = TestCases( + path=inputs_path, + services=test_input_config.services, + xfail=test_input_config.xfail, +) + +plugin_output_package = "tests.output_aristaproto" +reference_output_package = "tests.output_reference" + +TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"]) + + +def module_has_entry_point(module: ModuleType): + return any(hasattr(module, attr) for attr in ["Test", "TestStub"]) + + +def list_replace_nans(items: List) -> List[Any]: + """Replace float("nan") in a list with the string "NaN" + + Parameters + ---------- + items : List + List to update + + Returns + ------- + List[Any] + Updated list + """ + result = [] + for item in items: + if isinstance(item, list): + result.append(list_replace_nans(item)) + elif isinstance(item, dict): + result.append(dict_replace_nans(item)) + elif isinstance(item, float) and math.isnan(item): + result.append(aristaproto.NAN) + return result + + +def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]: + """Replace float("nan") in a dictionary with the string "NaN" + + Parameters + ---------- + input_dict : Dict[Any, Any] + Dictionary to update + + Returns + ------- + Dict[Any, Any] + Updated dictionary + """ + result = {} + for key, value in input_dict.items(): + if isinstance(value, dict): + value = dict_replace_nans(value) + elif isinstance(value, list): + value = list_replace_nans(value) + elif isinstance(value, float) and math.isnan(value): + value = aristaproto.NAN + result[key] = value + return result + + +@pytest.fixture +def test_data(request, reset_sys_path): + test_case_name = request.param + + reference_module_root = os.path.join( + *reference_output_package.split("."), test_case_name + ) + sys.path.append(reference_module_root) + + plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}") + + plugin_module_entry_point = find_module(plugin_module, module_has_entry_point) + + if not plugin_module_entry_point: + raise Exception( + f"Test case {repr(test_case_name)} has no entry point. " + "Please add a proto message or service called Test and recompile." + ) + + yield ( + TestData( + plugin_module=plugin_module_entry_point, + reference_module=lambda: importlib.import_module( + f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" + ), + json_data=get_test_case_json_data(test_case_name), + ) + ) + + +@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) +def test_message_can_instantiated(test_data: TestData) -> None: + plugin_module, *_ = test_data + plugin_module.Test() + + +@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) +def test_message_equality(test_data: TestData) -> None: + plugin_module, *_ = test_data + message1 = plugin_module.Test() + message2 = plugin_module.Test() + assert message1 == message2 + + +@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) +def test_message_json(repeat, test_data: TestData) -> None: + plugin_module, _, json_data = test_data + + for _ in range(repeat): + for sample in json_data: + if sample.belongs_to(test_input_config.non_symmetrical_json): + continue + + message: aristaproto.Message = plugin_module.Test() + + message.from_json(sample.json) + message_json = message.to_json(0) + + assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans( + json.loads(sample.json) + ) + + +@pytest.mark.parametrize("test_data", test_cases.services, indirect=True) +def test_service_can_be_instantiated(test_data: TestData) -> None: + test_data.plugin_module.TestStub(MockChannel()) + + +@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) +def test_binary_compatibility(repeat, test_data: TestData) -> None: + plugin_module, reference_module, json_data = test_data + + for sample in json_data: + reference_instance = Parse(sample.json, reference_module().Test()) + reference_binary_output = reference_instance.SerializeToString() + + for _ in range(repeat): + plugin_instance_from_json: aristaproto.Message = ( + plugin_module.Test().from_json(sample.json) + ) + plugin_instance_from_binary = plugin_module.Test.FromString( + reference_binary_output + ) + + # Generally this can't be relied on, but here we are aiming to match the + # existing Python implementation and aren't doing anything tricky. + # https://developers.google.com/protocol-buffers/docs/encoding#implications + assert bytes(plugin_instance_from_json) == reference_binary_output + assert bytes(plugin_instance_from_binary) == reference_binary_output + + assert plugin_instance_from_json == plugin_instance_from_binary + assert dict_replace_nans( + plugin_instance_from_json.to_dict() + ) == dict_replace_nans(plugin_instance_from_binary.to_dict()) diff --git a/tests/test_mapmessage.py b/tests/test_mapmessage.py new file mode 100644 index 0000000..75220e4 --- /dev/null +++ b/tests/test_mapmessage.py @@ -0,0 +1,18 @@ +from tests.output_aristaproto.mapmessage import ( + Nested, + Test, +) + + +def test_mapmessage_to_dict_preserves_message(): + message = Test( + items={ + "test": Nested( + count=1, + ) + } + ) + + message.to_dict() + + assert isinstance(message.items["test"], Nested), "Wrong nested type after to_dict" diff --git a/tests/test_pickling.py b/tests/test_pickling.py new file mode 100644 index 0000000..2356d98 --- /dev/null +++ b/tests/test_pickling.py @@ -0,0 +1,203 @@ +import pickle +from copy import ( + copy, + deepcopy, +) +from dataclasses import dataclass +from typing import ( + Dict, + List, +) +from unittest.mock import ANY + +import cachelib + +import aristaproto +from aristaproto.lib.google import protobuf as google + + +def unpickled(message): + return pickle.loads(pickle.dumps(message)) + + +@dataclass(eq=False, repr=False) +class Fe(aristaproto.Message): + abc: str = aristaproto.string_field(1) + + +@dataclass(eq=False, repr=False) +class Fi(aristaproto.Message): + abc: str = aristaproto.string_field(1) + + +@dataclass(eq=False, repr=False) +class Fo(aristaproto.Message): + abc: str = aristaproto.string_field(1) + + +@dataclass(eq=False, repr=False) +class NestedData(aristaproto.Message): + struct_foo: Dict[str, "google.Struct"] = aristaproto.map_field( + 1, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE + ) + map_str_any_bar: Dict[str, "google.Any"] = aristaproto.map_field( + 2, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE + ) + + +@dataclass(eq=False, repr=False) +class Complex(aristaproto.Message): + foo_str: str = aristaproto.string_field(1) + fe: "Fe" = aristaproto.message_field(3, group="grp") + fi: "Fi" = aristaproto.message_field(4, group="grp") + fo: "Fo" = aristaproto.message_field(5, group="grp") + nested_data: "NestedData" = aristaproto.message_field(6) + mapping: Dict[str, "google.Any"] = aristaproto.map_field( + 7, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE + ) + + +def complex_msg(): + return Complex( + foo_str="yep", + fe=Fe(abc="1"), + nested_data=NestedData( + struct_foo={ + "foo": google.Struct( + fields={ + "hello": google.Value( + list_value=google.ListValue( + values=[google.Value(string_value="world")] + ) + ) + } + ), + }, + map_str_any_bar={ + "key": google.Any(value=b"value"), + }, + ), + mapping={ + "message": google.Any(value=bytes(Fi(abc="hi"))), + "string": google.Any(value=b"howdy"), + }, + ) + + +def test_pickling_complex_message(): + msg = complex_msg() + deser = unpickled(msg) + assert msg == deser + assert msg.fe.abc == "1" + assert msg.is_set("fi") is not True + assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi"))) + assert msg.mapping["string"].value.decode() == "howdy" + assert ( + msg.nested_data.struct_foo["foo"] + .fields["hello"] + .list_value.values[0] + .string_value + == "world" + ) + + +def test_recursive_message(): + from tests.output_aristaproto.recursivemessage import Test as RecursiveMessage + + msg = RecursiveMessage() + msg = unpickled(msg) + + assert msg.child == RecursiveMessage() + + # Lazily-created zero-value children must not affect equality. + assert msg == RecursiveMessage() + + # Lazily-created zero-value children must not affect serialization. + assert bytes(msg) == b"" + + +def test_recursive_message_defaults(): + from tests.output_aristaproto.recursivemessage import ( + Intermediate, + Test as RecursiveMessage, + ) + + msg = RecursiveMessage(name="bob", intermediate=Intermediate(42)) + msg = unpickled(msg) + + # set values are as expected + assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42)) + + # lazy initialized works modifies the message + assert msg != RecursiveMessage( + name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude") + ) + msg.child.child.name = "jude" + assert msg == RecursiveMessage( + name="bob", + intermediate=Intermediate(42), + child=RecursiveMessage(child=RecursiveMessage(name="jude")), + ) + + # lazily initialization recurses as needed + assert msg.child.child.child.child.child.child.child == RecursiveMessage() + assert msg.intermediate.child.intermediate == Intermediate() + + +@dataclass +class PickledMessage(aristaproto.Message): + foo: bool = aristaproto.bool_field(1) + bar: int = aristaproto.int32_field(2) + baz: List[str] = aristaproto.string_field(3) + + +def test_copyability(): + msg = PickledMessage(bar=12, baz=["hello"]) + msg = unpickled(msg) + + copied = copy(msg) + assert msg == copied + assert msg is not copied + assert msg.baz is copied.baz + + deepcopied = deepcopy(msg) + assert msg == deepcopied + assert msg is not deepcopied + assert msg.baz is not deepcopied.baz + + +def test_message_can_be_cached(): + """Cachelib uses pickling to cache values""" + + cache = cachelib.SimpleCache() + + def use_cache(): + calls = getattr(use_cache, "calls", 0) + result = cache.get("message") + if result is not None: + return result + else: + setattr(use_cache, "calls", calls + 1) + result = complex_msg() + cache.set("message", result) + return result + + for n in range(10): + if n == 0: + assert not cache.has("message") + else: + assert cache.has("message") + + msg = use_cache() + assert use_cache.calls == 1 # The message is only ever built once + assert msg.fe.abc == "1" + assert msg.is_set("fi") is not True + assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi"))) + assert msg.mapping["string"].value.decode() == "howdy" + assert ( + msg.nested_data.struct_foo["foo"] + .fields["hello"] + .list_value.values[0] + .string_value + == "world" + ) diff --git a/tests/test_streams.py b/tests/test_streams.py new file mode 100644 index 0000000..7ae441b --- /dev/null +++ b/tests/test_streams.py @@ -0,0 +1,434 @@ +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path +from shutil import which +from subprocess import run +from typing import Optional + +import pytest + +import aristaproto +from tests.output_aristaproto import ( + map, + nested, + oneof, + repeated, + repeatedpacked, +) + + +oneof_example = oneof.Test().from_dict( + {"pitied": 1, "just_a_regular_field": 123456789, "bar_name": "Testing"} +) + +len_oneof = len(oneof_example) + +nested_example = nested.Test().from_dict( + { + "nested": {"count": 1}, + "sibling": {"foo": 2}, + "sibling2": {"foo": 3}, + "msg": nested.TestMsg.THIS, + } +) + +repeated_example = repeated.Test().from_dict({"names": ["blah", "Blah2"]}) + +packed_example = repeatedpacked.Test().from_dict( + {"counts": [1, 2, 3], "signed": [-1, 2, -3], "fixed": [1.2, -2.3, 3.4]} +) + +map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}}) + +streams_path = Path("tests/streams/") + +java = which("java") + + +def test_load_varint_too_long(): + with BytesIO( + b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01" + ) as stream, pytest.raises(ValueError): + aristaproto.load_varint(stream) + + with BytesIO(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01") as stream: + # This should not raise a ValueError, as it is within 64 bits + aristaproto.load_varint(stream) + + +def test_load_varint_file(): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + assert aristaproto.load_varint(stream) == (8, b"\x08") # Single-byte varint + stream.read(2) # Skip until first multi-byte + assert aristaproto.load_varint(stream) == ( + 123456789, + b"\x95\x9A\xEF\x3A", + ) # Multi-byte varint + + +def test_load_varint_cutoff(): + with open(streams_path / "load_varint_cutoff.in", "rb") as stream: + with pytest.raises(EOFError): + aristaproto.load_varint(stream) + + stream.seek(1) + with pytest.raises(EOFError): + aristaproto.load_varint(stream) + + +def test_dump_varint_file(tmp_path): + # Dump test varints to file + with open(tmp_path / "dump_varint_file.out", "wb") as stream: + aristaproto.dump_varint(8, stream) # Single-byte varint + aristaproto.dump_varint(123456789, stream) # Multi-byte varint + + # Check that file contents are as expected + with open(tmp_path / "dump_varint_file.out", "rb") as test_stream, open( + streams_path / "message_dump_file_single.expected", "rb" + ) as exp_stream: + assert aristaproto.load_varint(test_stream) == aristaproto.load_varint( + exp_stream + ) + exp_stream.read(2) + assert aristaproto.load_varint(test_stream) == aristaproto.load_varint( + exp_stream + ) + + +def test_parse_fields(): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + parsed_bytes = aristaproto.parse_fields(stream.read()) + + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + parsed_stream = aristaproto.load_fields(stream) + for field in parsed_bytes: + assert field == next(parsed_stream) + + +def test_message_dump_file_single(tmp_path): + # Write the message to the stream + with open(tmp_path / "message_dump_file_single.out", "wb") as stream: + oneof_example.dump(stream) + + # Check that the outputted file is exactly as expected + with open(tmp_path / "message_dump_file_single.out", "rb") as test_stream, open( + streams_path / "message_dump_file_single.expected", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +def test_message_dump_file_multiple(tmp_path): + # Write the same Message twice and another, different message + with open(tmp_path / "message_dump_file_multiple.out", "wb") as stream: + oneof_example.dump(stream) + oneof_example.dump(stream) + nested_example.dump(stream) + + # Check that all three Messages were outputted to the file correctly + with open(tmp_path / "message_dump_file_multiple.out", "rb") as test_stream, open( + streams_path / "message_dump_file_multiple.expected", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +def test_message_dump_delimited(tmp_path): + with open(tmp_path / "message_dump_delimited.out", "wb") as stream: + oneof_example.dump(stream, aristaproto.SIZE_DELIMITED) + oneof_example.dump(stream, aristaproto.SIZE_DELIMITED) + nested_example.dump(stream, aristaproto.SIZE_DELIMITED) + + with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open( + streams_path / "delimited_messages.in", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +def test_message_len(): + assert len_oneof == len(bytes(oneof_example)) + assert len(nested_example) == len(bytes(nested_example)) + + +def test_message_load_file_single(): + with open(streams_path / "message_dump_file_single.expected", "rb") as stream: + assert oneof.Test().load(stream) == oneof_example + stream.seek(0) + assert oneof.Test().load(stream, len_oneof) == oneof_example + + +def test_message_load_file_multiple(): + with open(streams_path / "message_dump_file_multiple.expected", "rb") as stream: + oneof_size = len_oneof + assert oneof.Test().load(stream, oneof_size) == oneof_example + assert oneof.Test().load(stream, oneof_size) == oneof_example + assert nested.Test().load(stream) == nested_example + assert stream.read(1) == b"" + + +def test_message_load_too_small(): + with open( + streams_path / "message_dump_file_single.expected", "rb" + ) as stream, pytest.raises(ValueError): + oneof.Test().load(stream, len_oneof - 1) + + +def test_message_load_delimited(): + with open(streams_path / "delimited_messages.in", "rb") as stream: + assert oneof.Test().load(stream, aristaproto.SIZE_DELIMITED) == oneof_example + assert oneof.Test().load(stream, aristaproto.SIZE_DELIMITED) == oneof_example + assert nested.Test().load(stream, aristaproto.SIZE_DELIMITED) == nested_example + assert stream.read(1) == b"" + + +def test_message_load_too_large(): + with open( + streams_path / "message_dump_file_single.expected", "rb" + ) as stream, pytest.raises(ValueError): + oneof.Test().load(stream, len_oneof + 1) + + +def test_message_len_optional_field(): + @dataclass + class Request(aristaproto.Message): + flag: Optional[bool] = aristaproto.message_field(1, wraps=aristaproto.TYPE_BOOL) + + assert len(Request()) == len(b"") + assert len(Request(flag=True)) == len(b"\n\x02\x08\x01") + assert len(Request(flag=False)) == len(b"\n\x00") + + +def test_message_len_repeated_field(): + assert len(repeated_example) == len(bytes(repeated_example)) + + +def test_message_len_packed_field(): + assert len(packed_example) == len(bytes(packed_example)) + + +def test_message_len_map_field(): + assert len(map_example) == len(bytes(map_example)) + + +def test_message_len_empty_string(): + @dataclass + class Empty(aristaproto.Message): + string: str = aristaproto.string_field(1, "group") + integer: int = aristaproto.int32_field(2, "group") + + empty = Empty().from_dict({"string": ""}) + assert len(empty) == len(bytes(empty)) + + +def test_calculate_varint_size_negative(): + single_byte = -1 + multi_byte = -10000000 + edge = -(1 << 63) + beyond = -(1 << 63) - 1 + before = -(1 << 63) + 1 + + assert ( + aristaproto.size_varint(single_byte) + == len(aristaproto.encode_varint(single_byte)) + == 10 + ) + assert ( + aristaproto.size_varint(multi_byte) + == len(aristaproto.encode_varint(multi_byte)) + == 10 + ) + assert aristaproto.size_varint(edge) == len(aristaproto.encode_varint(edge)) == 10 + assert ( + aristaproto.size_varint(before) == len(aristaproto.encode_varint(before)) == 10 + ) + + with pytest.raises(ValueError): + aristaproto.size_varint(beyond) + + +def test_calculate_varint_size_positive(): + single_byte = 1 + multi_byte = 10000000 + + assert aristaproto.size_varint(single_byte) == len( + aristaproto.encode_varint(single_byte) + ) + assert aristaproto.size_varint(multi_byte) == len( + aristaproto.encode_varint(multi_byte) + ) + + +def test_dump_varint_negative(tmp_path): + single_byte = -1 + multi_byte = -10000000 + edge = -(1 << 63) + beyond = -(1 << 63) - 1 + before = -(1 << 63) + 1 + + with open(tmp_path / "dump_varint_negative.out", "wb") as stream: + aristaproto.dump_varint(single_byte, stream) + aristaproto.dump_varint(multi_byte, stream) + aristaproto.dump_varint(edge, stream) + aristaproto.dump_varint(before, stream) + + with pytest.raises(ValueError): + aristaproto.dump_varint(beyond, stream) + + with open(streams_path / "dump_varint_negative.expected", "rb") as exp_stream, open( + tmp_path / "dump_varint_negative.out", "rb" + ) as test_stream: + assert test_stream.read() == exp_stream.read() + + +def test_dump_varint_positive(tmp_path): + single_byte = 1 + multi_byte = 10000000 + + with open(tmp_path / "dump_varint_positive.out", "wb") as stream: + aristaproto.dump_varint(single_byte, stream) + aristaproto.dump_varint(multi_byte, stream) + + with open(tmp_path / "dump_varint_positive.out", "rb") as test_stream, open( + streams_path / "dump_varint_positive.expected", "rb" + ) as exp_stream: + assert test_stream.read() == exp_stream.read() + + +# Java compatibility tests + + +@pytest.fixture(scope="module") +def compile_jar(): + # Skip if not all required tools are present + if java is None: + pytest.skip("`java` command is absent and is required") + mvn = which("mvn") + if mvn is None: + pytest.skip("Maven is absent and is required") + + # Compile the JAR + proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"]) + if proc_maven.returncode != 0: + pytest.skip( + "Maven compatibility-test.jar build failed (maybe Java version <11?)" + ) + + +jar = "tests/streams/java/target/compatibility-test.jar" + + +def run_jar(command: str, tmp_path): + return run([java, "-jar", jar, command, tmp_path], check=True) + + +def run_java_single_varint(value: int, tmp_path) -> int: + # Write single varint to file + with open(tmp_path / "py_single_varint.out", "wb") as stream: + aristaproto.dump_varint(value, stream) + + # Have Java read this varint and write it back + run_jar("single_varint", tmp_path) + + # Read single varint from Java output file + with open(tmp_path / "java_single_varint.out", "rb") as stream: + returned = aristaproto.load_varint(stream) + with pytest.raises(EOFError): + aristaproto.load_varint(stream) + + return returned + + +def test_single_varint(compile_jar, tmp_path): + single_byte = (1, b"\x01") + multi_byte = (123456789, b"\x95\x9A\xEF\x3A") + + # Write a single-byte varint to a file and have Java read it back + returned = run_java_single_varint(single_byte[0], tmp_path) + assert returned == single_byte + + # Same for a multi-byte varint + returned = run_java_single_varint(multi_byte[0], tmp_path) + assert returned == multi_byte + + +def test_multiple_varints(compile_jar, tmp_path): + single_byte = (1, b"\x01") + multi_byte = (123456789, b"\x95\x9A\xEF\x3A") + over32 = (3000000000, b"\x80\xBC\xC1\x96\x0B") + + # Write two varints to the same file + with open(tmp_path / "py_multiple_varints.out", "wb") as stream: + aristaproto.dump_varint(single_byte[0], stream) + aristaproto.dump_varint(multi_byte[0], stream) + aristaproto.dump_varint(over32[0], stream) + + # Have Java read these varints and write them back + run_jar("multiple_varints", tmp_path) + + # Read varints from Java output file + with open(tmp_path / "java_multiple_varints.out", "rb") as stream: + returned_single = aristaproto.load_varint(stream) + returned_multi = aristaproto.load_varint(stream) + returned_over32 = aristaproto.load_varint(stream) + with pytest.raises(EOFError): + aristaproto.load_varint(stream) + + assert returned_single == single_byte + assert returned_multi == multi_byte + assert returned_over32 == over32 + + +def test_single_message(compile_jar, tmp_path): + # Write message to file + with open(tmp_path / "py_single_message.out", "wb") as stream: + oneof_example.dump(stream) + + # Have Java read and return the message + run_jar("single_message", tmp_path) + + # Read and check the returned message + with open(tmp_path / "java_single_message.out", "rb") as stream: + returned = oneof.Test().load(stream, len(bytes(oneof_example))) + assert stream.read() == b"" + + assert returned == oneof_example + + +def test_multiple_messages(compile_jar, tmp_path): + # Write delimited messages to file + with open(tmp_path / "py_multiple_messages.out", "wb") as stream: + oneof_example.dump(stream, aristaproto.SIZE_DELIMITED) + nested_example.dump(stream, aristaproto.SIZE_DELIMITED) + + # Have Java read and return the messages + run_jar("multiple_messages", tmp_path) + + # Read and check the returned messages + with open(tmp_path / "java_multiple_messages.out", "rb") as stream: + returned_oneof = oneof.Test().load(stream, aristaproto.SIZE_DELIMITED) + returned_nested = nested.Test().load(stream, aristaproto.SIZE_DELIMITED) + assert stream.read() == b"" + + assert returned_oneof == oneof_example + assert returned_nested == nested_example + + +def test_infinite_messages(compile_jar, tmp_path): + num_messages = 5 + + # Write delimited messages to file + with open(tmp_path / "py_infinite_messages.out", "wb") as stream: + for x in range(num_messages): + oneof_example.dump(stream, aristaproto.SIZE_DELIMITED) + + # Have Java read and return the messages + run_jar("infinite_messages", tmp_path) + + # Read and check the returned messages + messages = [] + with open(tmp_path / "java_infinite_messages.out", "rb") as stream: + while True: + try: + messages.append(oneof.Test().load(stream, aristaproto.SIZE_DELIMITED)) + except EOFError: + break + + assert len(messages) == num_messages diff --git a/tests/test_struct.py b/tests/test_struct.py new file mode 100644 index 0000000..c562763 --- /dev/null +++ b/tests/test_struct.py @@ -0,0 +1,36 @@ +import json + +from aristaproto.lib.google.protobuf import Struct +from aristaproto.lib.pydantic.google.protobuf import Struct as StructPydantic + + +def test_struct_roundtrip(): + data = { + "foo": "bar", + "baz": None, + "quux": 123, + "zap": [1, {"two": 3}, "four"], + } + data_json = json.dumps(data) + + struct_from_dict = Struct().from_dict(data) + assert struct_from_dict.fields == data + assert struct_from_dict.to_dict() == data + assert struct_from_dict.to_json() == data_json + + struct_from_json = Struct().from_json(data_json) + assert struct_from_json.fields == data + assert struct_from_json.to_dict() == data + assert struct_from_json == struct_from_dict + assert struct_from_json.to_json() == data_json + + struct_pyd_from_dict = StructPydantic(fields={}).from_dict(data) + assert struct_pyd_from_dict.fields == data + assert struct_pyd_from_dict.to_dict() == data + assert struct_pyd_from_dict.to_json() == data_json + + struct_pyd_from_dict = StructPydantic(fields={}).from_json(data_json) + assert struct_pyd_from_dict.fields == data + assert struct_pyd_from_dict.to_dict() == data + assert struct_pyd_from_dict == struct_pyd_from_dict + assert struct_pyd_from_dict.to_json() == data_json diff --git a/tests/test_timestamp.py b/tests/test_timestamp.py new file mode 100644 index 0000000..dd51420 --- /dev/null +++ b/tests/test_timestamp.py @@ -0,0 +1,27 @@ +from datetime import ( + datetime, + timezone, +) + +import pytest + +from aristaproto import _Timestamp + + +@pytest.mark.parametrize( + "dt", + [ + datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc), + datetime.now(timezone.utc), + # potential issue with floating point precision: + datetime(2242, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc), + # potential issue with negative timestamps: + datetime(1969, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc), + ], +) +def test_timestamp_to_datetime_and_back(dt: datetime): + """ + Make sure converting a datetime to a protobuf timestamp message + and then back again ends up with the same datetime. + """ + assert _Timestamp.from_datetime(dt).to_datetime() == dt diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 0000000..bfbe842 --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,16 @@ +from pathlib import Path + +import tomlkit + +from aristaproto import __version__ + + +PROJECT_TOML = Path(__file__).joinpath("..", "..", "pyproject.toml").resolve() + + +def test_version(): + with PROJECT_TOML.open() as toml_file: + project_config = tomlkit.loads(toml_file.read()) + assert ( + __version__ == project_config["tool"]["poetry"]["version"] + ), "Project version should match in package and package config" diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 0000000..2ba7cab --- /dev/null +++ b/tests/util.py @@ -0,0 +1,169 @@ +import asyncio +import atexit +import importlib +import os +import platform +import sys +import tempfile +from dataclasses import dataclass +from pathlib import Path +from types import ModuleType +from typing import ( + Callable, + Dict, + Generator, + List, + Optional, + Tuple, + Union, +) + + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + +root_path = Path(__file__).resolve().parent +inputs_path = root_path.joinpath("inputs") +output_path_reference = root_path.joinpath("output_reference") +output_path_aristaproto = root_path.joinpath("output_aristaproto") +output_path_aristaproto_pydantic = root_path.joinpath("output_aristaproto_pydantic") + + +def get_files(path, suffix: str) -> Generator[str, None, None]: + for r, dirs, files in os.walk(path): + for filename in [f for f in files if f.endswith(suffix)]: + yield os.path.join(r, filename) + + +def get_directories(path): + for root, directories, files in os.walk(path): + yield from directories + + +async def protoc( + path: Union[str, Path], + output_dir: Union[str, Path], + reference: bool = False, + pydantic_dataclasses: bool = False, +): + path: Path = Path(path).resolve() + output_dir: Path = Path(output_dir).resolve() + python_out_option: str = "python_aristaproto_out" if not reference else "python_out" + + if pydantic_dataclasses: + plugin_path = Path("src/aristaproto/plugin/main.py") + + if "Win" in platform.system(): + with tempfile.NamedTemporaryFile( + "w", encoding="UTF-8", suffix=".bat", delete=False + ) as tf: + # See https://stackoverflow.com/a/42622705 + tf.writelines( + [ + "@echo off", + f"\nchdir {os.getcwd()}", + f"\n{sys.executable} -u {plugin_path.as_posix()}", + ] + ) + + tf.flush() + + plugin_path = Path(tf.name) + atexit.register(os.remove, plugin_path) + + command = [ + sys.executable, + "-m", + "grpc.tools.protoc", + f"--plugin=protoc-gen-custom={plugin_path.as_posix()}", + "--experimental_allow_proto3_optional", + "--custom_opt=pydantic_dataclasses", + f"--proto_path={path.as_posix()}", + f"--custom_out={output_dir.as_posix()}", + *[p.as_posix() for p in path.glob("*.proto")], + ] + else: + command = [ + sys.executable, + "-m", + "grpc.tools.protoc", + f"--proto_path={path.as_posix()}", + f"--{python_out_option}={output_dir.as_posix()}", + *[p.as_posix() for p in path.glob("*.proto")], + ] + proc = await asyncio.create_subprocess_exec( + *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await proc.communicate() + return stdout, stderr, proc.returncode + + +@dataclass +class TestCaseJsonFile: + json: str + test_name: str + file_name: str + + def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]): + return self.file_name in non_symmetrical_json.get(self.test_name, tuple()) + + +def get_test_case_json_data( + test_case_name: str, *json_file_names: str +) -> List[TestCaseJsonFile]: + """ + :return: + A list of all files found in "{inputs_path}/test_case_name" with names matching + f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by + json_file_names + """ + test_case_dir = inputs_path.joinpath(test_case_name) + possible_file_paths = [ + *(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names), + test_case_dir.joinpath(f"{test_case_name}.json"), + *test_case_dir.glob(f"{test_case_name}_*.json"), + ] + + result = [] + for test_data_file_path in possible_file_paths: + if not test_data_file_path.exists(): + continue + with test_data_file_path.open("r") as fh: + result.append( + TestCaseJsonFile( + fh.read(), test_case_name, test_data_file_path.name.split(".")[0] + ) + ) + + return result + + +def find_module( + module: ModuleType, predicate: Callable[[ModuleType], bool] +) -> Optional[ModuleType]: + """ + Recursively search module tree for a module that matches the search predicate. + Assumes that the submodules are directories containing __init__.py. + + Example: + + # find module inside foo that contains Test + import foo + test_module = find_module(foo, lambda m: hasattr(m, 'Test')) + """ + if predicate(module): + return module + + module_path = Path(*module.__path__) + + for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]: + if sub == module_path: + continue + sub_module_path = sub.relative_to(module_path) + sub_module_name = ".".join(sub_module_path.parts) + + sub_module = importlib.import_module(f".{sub_module_name}", module.__name__) + + if predicate(sub_module): + return sub_module + + return None |