Diffstat (limited to 'tests')
-rw-r--r--  tests/README.md | 91
-rw-r--r--  tests/__init__.py | 0
-rw-r--r--  tests/conftest.py | 22
-rwxr-xr-x  tests/generate.py | 196
-rw-r--r--  tests/grpc/__init__.py | 0
-rw-r--r--  tests/grpc/test_grpclib_client.py | 298
-rw-r--r--  tests/grpc/test_stream_stream.py | 99
-rw-r--r--  tests/grpc/thing_service.py | 85
-rw-r--r--  tests/inputs/bool/bool.json | 3
-rw-r--r--  tests/inputs/bool/bool.proto | 7
-rw-r--r--  tests/inputs/bool/test_bool.py | 19
-rw-r--r--  tests/inputs/bytes/bytes.json | 3
-rw-r--r--  tests/inputs/bytes/bytes.proto | 7
-rw-r--r--  tests/inputs/casing/casing.json | 4
-rw-r--r--  tests/inputs/casing/casing.proto | 20
-rw-r--r--  tests/inputs/casing/test_casing.py | 23
-rw-r--r--  tests/inputs/casing_inner_class/casing_inner_class.proto | 10
-rw-r--r--  tests/inputs/casing_inner_class/test_casing_inner_class.py | 14
-rw-r--r--  tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.proto | 9
-rw-r--r--  tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py | 14
-rw-r--r--  tests/inputs/config.py | 30
-rw-r--r--  tests/inputs/deprecated/deprecated.json | 6
-rw-r--r--  tests/inputs/deprecated/deprecated.proto | 14
-rw-r--r--  tests/inputs/double/double-negative.json | 3
-rw-r--r--  tests/inputs/double/double.json | 3
-rw-r--r--  tests/inputs/double/double.proto | 7
-rw-r--r--  tests/inputs/empty_repeated/empty_repeated.json | 3
-rw-r--r--  tests/inputs/empty_repeated/empty_repeated.proto | 11
-rw-r--r--  tests/inputs/empty_service/empty_service.proto | 7
-rw-r--r--  tests/inputs/entry/entry.proto | 20
-rw-r--r--  tests/inputs/enum/enum.json | 9
-rw-r--r--  tests/inputs/enum/enum.proto | 25
-rw-r--r--  tests/inputs/enum/test_enum.py | 114
-rw-r--r--  tests/inputs/example/example.proto | 911
-rw-r--r--  tests/inputs/example_service/example_service.proto | 20
-rw-r--r--  tests/inputs/example_service/test_example_service.py | 86
-rw-r--r--  tests/inputs/field_name_identical_to_type/field_name_identical_to_type.json | 7
-rw-r--r--  tests/inputs/field_name_identical_to_type/field_name_identical_to_type.proto | 13
-rw-r--r--  tests/inputs/fixed/fixed.json | 6
-rw-r--r--  tests/inputs/fixed/fixed.proto | 10
-rw-r--r--  tests/inputs/float/float.json | 9
-rw-r--r--  tests/inputs/float/float.proto | 14
-rw-r--r--  tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto | 22
-rw-r--r--  tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py | 93
-rw-r--r--  tests/inputs/googletypes/googletypes-missing.json | 1
-rw-r--r--  tests/inputs/googletypes/googletypes.json | 7
-rw-r--r--  tests/inputs/googletypes/googletypes.proto | 16
-rw-r--r--  tests/inputs/googletypes_request/googletypes_request.proto | 29
-rw-r--r--  tests/inputs/googletypes_request/test_googletypes_request.py | 47
-rw-r--r--  tests/inputs/googletypes_response/googletypes_response.proto | 23
-rw-r--r--  tests/inputs/googletypes_response/test_googletypes_response.py | 64
-rw-r--r--  tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto | 26
-rw-r--r--  tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py | 40
-rw-r--r--  tests/inputs/googletypes_service_returns_empty/googletypes_service_returns_empty.proto | 13
-rw-r--r--  tests/inputs/googletypes_service_returns_googletype/googletypes_service_returns_googletype.proto | 18
-rw-r--r--  tests/inputs/googletypes_struct/googletypes_struct.json | 5
-rw-r--r--  tests/inputs/googletypes_struct/googletypes_struct.proto | 9
-rw-r--r--  tests/inputs/googletypes_value/googletypes_value.json | 11
-rw-r--r--  tests/inputs/googletypes_value/googletypes_value.proto | 15
-rw-r--r--  tests/inputs/import_capitalized_package/capitalized.proto | 8
-rw-r--r--  tests/inputs/import_capitalized_package/test.proto | 11
-rw-r--r--  tests/inputs/import_child_package_from_package/child.proto | 7
-rw-r--r--  tests/inputs/import_child_package_from_package/import_child_package_from_package.proto | 11
-rw-r--r--  tests/inputs/import_child_package_from_package/package_message.proto | 9
-rw-r--r--  tests/inputs/import_child_package_from_root/child.proto | 7
-rw-r--r--  tests/inputs/import_child_package_from_root/import_child_package_from_root.proto | 11
-rw-r--r--  tests/inputs/import_circular_dependency/import_circular_dependency.proto | 30
-rw-r--r--  tests/inputs/import_circular_dependency/other.proto | 8
-rw-r--r--  tests/inputs/import_circular_dependency/root.proto | 7
-rw-r--r--  tests/inputs/import_cousin_package/cousin.proto | 6
-rw-r--r--  tests/inputs/import_cousin_package/test.proto | 11
-rw-r--r--  tests/inputs/import_cousin_package_same_name/cousin.proto | 6
-rw-r--r--  tests/inputs/import_cousin_package_same_name/test.proto | 11
-rw-r--r--  tests/inputs/import_packages_same_name/import_packages_same_name.proto | 13
-rw-r--r--  tests/inputs/import_packages_same_name/posts_v1.proto | 7
-rw-r--r--  tests/inputs/import_packages_same_name/users_v1.proto | 7
-rw-r--r--  tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto | 12
-rw-r--r--  tests/inputs/import_parent_package_from_child/parent_package_message.proto | 6
-rw-r--r--  tests/inputs/import_root_package_from_child/child.proto | 11
-rw-r--r--  tests/inputs/import_root_package_from_child/root.proto | 7
-rw-r--r--  tests/inputs/import_root_sibling/import_root_sibling.proto | 11
-rw-r--r--  tests/inputs/import_root_sibling/sibling.proto | 7
-rw-r--r--  tests/inputs/import_service_input_message/child_package_request_message.proto | 7
-rw-r--r--  tests/inputs/import_service_input_message/import_service_input_message.proto | 25
-rw-r--r--  tests/inputs/import_service_input_message/request_message.proto | 7
-rw-r--r--  tests/inputs/import_service_input_message/test_import_service_input_message.py | 36
-rw-r--r--  tests/inputs/int32/int32.json | 4
-rw-r--r--  tests/inputs/int32/int32.proto | 10
-rw-r--r--  tests/inputs/map/map.json | 7
-rw-r--r--  tests/inputs/map/map.proto | 7
-rw-r--r--  tests/inputs/mapmessage/mapmessage.json | 10
-rw-r--r--  tests/inputs/mapmessage/mapmessage.proto | 11
-rw-r--r--  tests/inputs/namespace_builtin_types/namespace_builtin_types.json | 16
-rw-r--r--  tests/inputs/namespace_builtin_types/namespace_builtin_types.proto | 40
-rw-r--r--  tests/inputs/namespace_keywords/namespace_keywords.json | 37
-rw-r--r--  tests/inputs/namespace_keywords/namespace_keywords.proto | 46
-rw-r--r--  tests/inputs/nested/nested.json | 7
-rw-r--r--  tests/inputs/nested/nested.proto | 26
-rw-r--r--  tests/inputs/nested2/nested2.proto | 21
-rw-r--r--  tests/inputs/nested2/package.proto | 7
-rw-r--r--  tests/inputs/nestedtwice/nestedtwice.json | 11
-rw-r--r--  tests/inputs/nestedtwice/nestedtwice.proto | 40
-rw-r--r--  tests/inputs/nestedtwice/test_nestedtwice.py | 25
-rw-r--r--  tests/inputs/oneof/oneof-name.json | 3
-rw-r--r--  tests/inputs/oneof/oneof.json | 3
-rw-r--r--  tests/inputs/oneof/oneof.proto | 23
-rw-r--r--  tests/inputs/oneof/oneof_name.json | 3
-rw-r--r--  tests/inputs/oneof/test_oneof.py | 43
-rw-r--r--  tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto | 30
-rw-r--r--  tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py | 75
-rw-r--r--  tests/inputs/oneof_empty/oneof_empty.json | 3
-rw-r--r--  tests/inputs/oneof_empty/oneof_empty.proto | 17
-rw-r--r--  tests/inputs/oneof_empty/oneof_empty_maybe1.json | 3
-rw-r--r--  tests/inputs/oneof_empty/oneof_empty_maybe2.json | 5
-rw-r--r--  tests/inputs/oneof_empty/test_oneof_empty.py | 0
-rw-r--r--  tests/inputs/oneof_enum/oneof_enum-enum-0.json | 3
-rw-r--r--  tests/inputs/oneof_enum/oneof_enum-enum-1.json | 3
-rw-r--r--  tests/inputs/oneof_enum/oneof_enum.json | 6
-rw-r--r--  tests/inputs/oneof_enum/oneof_enum.proto | 20
-rw-r--r--  tests/inputs/oneof_enum/test_oneof_enum.py | 47
-rw-r--r--  tests/inputs/proto3_field_presence/proto3_field_presence.json | 13
-rw-r--r--  tests/inputs/proto3_field_presence/proto3_field_presence.proto | 26
-rw-r--r--  tests/inputs/proto3_field_presence/proto3_field_presence_default.json | 1
-rw-r--r--  tests/inputs/proto3_field_presence/proto3_field_presence_missing.json | 9
-rw-r--r--  tests/inputs/proto3_field_presence/test_proto3_field_presence.py | 48
-rw-r--r--  tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json | 3
-rw-r--r--  tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto | 22
-rw-r--r--  tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py | 29
-rw-r--r--  tests/inputs/recursivemessage/recursivemessage.json | 12
-rw-r--r--  tests/inputs/recursivemessage/recursivemessage.proto | 15
-rw-r--r--  tests/inputs/ref/ref.json | 5
-rw-r--r--  tests/inputs/ref/ref.proto | 9
-rw-r--r--  tests/inputs/ref/repeatedmessage.proto | 11
-rw-r--r--  tests/inputs/regression_387/regression_387.proto | 12
-rw-r--r--  tests/inputs/regression_387/test_regression_387.py | 12
-rw-r--r--  tests/inputs/regression_414/regression_414.proto | 9
-rw-r--r--  tests/inputs/regression_414/test_regression_414.py | 15
-rw-r--r--  tests/inputs/repeated/repeated.json | 3
-rw-r--r--  tests/inputs/repeated/repeated.proto | 7
-rw-r--r--  tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.json | 4
-rw-r--r--  tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.proto | 12
-rw-r--r--  tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py | 12
-rw-r--r--  tests/inputs/repeatedmessage/repeatedmessage.json | 10
-rw-r--r--  tests/inputs/repeatedmessage/repeatedmessage.proto | 11
-rw-r--r--  tests/inputs/repeatedpacked/repeatedpacked.json | 5
-rw-r--r--  tests/inputs/repeatedpacked/repeatedpacked.proto | 9
-rw-r--r--  tests/inputs/service/service.proto | 35
-rw-r--r--  tests/inputs/service_separate_packages/messages.proto | 31
-rw-r--r--  tests/inputs/service_separate_packages/service.proto | 12
-rw-r--r--  tests/inputs/service_uppercase/service.proto | 16
-rw-r--r--  tests/inputs/service_uppercase/test_service.py | 8
-rw-r--r--  tests/inputs/signed/signed.json | 6
-rw-r--r--  tests/inputs/signed/signed.proto | 11
-rw-r--r--  tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py | 82
-rw-r--r--  tests/inputs/timestamp_dict_encode/timestamp_dict_encode.json | 3
-rw-r--r--  tests/inputs/timestamp_dict_encode/timestamp_dict_encode.proto | 9
-rw-r--r--  tests/mocks.py | 40
-rw-r--r--  tests/oneof_pattern_matching.py | 46
-rw-r--r--  tests/streams/delimited_messages.in | 2
-rw-r--r--  tests/streams/dump_varint_negative.expected | 1
-rw-r--r--  tests/streams/dump_varint_positive.expected | 1
-rw-r--r--  tests/streams/java/.gitignore | 38
-rw-r--r--  tests/streams/java/pom.xml | 94
-rw-r--r--  tests/streams/java/src/main/java/aristaproto/CompatibilityTest.java | 41
-rw-r--r--  tests/streams/java/src/main/java/aristaproto/Tests.java | 115
-rw-r--r--  tests/streams/java/src/main/proto/aristaproto/nested.proto | 27
-rw-r--r--  tests/streams/java/src/main/proto/aristaproto/oneof.proto | 19
-rw-r--r--  tests/streams/load_varint_cutoff.in | 1
-rw-r--r--  tests/streams/message_dump_file_multiple.expected | 2
-rw-r--r--  tests/streams/message_dump_file_single.expected | 1
-rw-r--r--  tests/test_casing.py | 129
-rw-r--r--  tests/test_deprecated.py | 45
-rw-r--r--  tests/test_enum.py | 79
-rw-r--r--  tests/test_features.py | 682
-rw-r--r--  tests/test_get_ref_type.py | 371
-rw-r--r--  tests/test_inputs.py | 225
-rw-r--r--  tests/test_mapmessage.py | 18
-rw-r--r--  tests/test_pickling.py | 203
-rw-r--r--  tests/test_streams.py | 434
-rw-r--r--  tests/test_struct.py | 36
-rw-r--r--  tests/test_timestamp.py | 27
-rw-r--r--  tests/test_version.py | 16
-rw-r--r--  tests/util.py | 169
183 files changed, 7001 insertions, 0 deletions
diff --git a/tests/README.md b/tests/README.md
new file mode 100644
index 0000000..1301f6b
--- /dev/null
+++ b/tests/README.md
@@ -0,0 +1,91 @@
+# Standard Tests Development Guide
+
+Standard test cases are found in [aristaproto/tests/inputs](inputs), where each subdirectory represents a test case that is verified in isolation.
+
+```
+inputs/
+ bool/
+ double/
+ int32/
+ ...
+```
+
+## Test case directory structure
+
+Each test case has a `<name>.proto` file with a message called `Test`, and optionally a matching `.json` file and a custom test called `test_<name>.py`.
+
+```bash
+bool/
+ bool.proto
+ bool.json # optional
+ test_bool.py # optional
+```
+
+### proto
+
+`<name>.proto` &mdash; *The protobuf message to test*
+
+```protobuf
+syntax = "proto3";
+
+message Test {
+ bool value = 1;
+}
+```
+
+You can add multiple `.proto` files to the test case, as long as one file matches the directory name.
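For example, the `import_service_input_message` case added in this diff keeps several `.proto` files alongside the one that matches the directory name:

```bash
import_service_input_message/
  import_service_input_message.proto   # matches the directory name
  request_message.proto
  child_package_request_message.proto
  test_import_service_input_message.py
```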
+
+### json
+
+`<name>.json` &mdash; *Test-data to validate the message with*
+
+```json
+{
+ "value": true
+}
+```
+
+### pytest
+
+`test_<name>.py` &mdash; *Custom test to validate specific aspects of the generated class*
+
+```python
+from tests.output_aristaproto.bool.bool import Test
+
+def test_value():
+ message = Test()
+ assert not message.value, "Boolean is False by default"
+```
+
+## Standard tests
+
+The following tests are automatically executed for all cases:
+
+- [x] Can the generated python code be imported?
+- [x] Can the generated message class be instantiated?
+- [x] Is the generated code compatible with Google's `grpc_tools.protoc` implementation?
+ - _when `.json` is present_
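These standard checks are implemented in `tests/test_inputs.py` (added in this diff). Purely as an illustration, a minimal, hypothetical version of the import-and-instantiate check might look like the following, assuming the generated packages follow the `tests.output_aristaproto.<case>` layout used throughout these tests:

```python
import importlib

import pytest

# Hypothetical, hard-coded list; the real suite discovers cases from inputs/.
CASES = ["bool", "bytes", "double"]


@pytest.mark.parametrize("name", CASES)
def test_generated_code_imports_and_instantiates(name):
    # Import the plugin-generated module for this case...
    module = importlib.import_module(f"tests.output_aristaproto.{name}")
    # ...and check that its Test message can be constructed.
    assert module.Test() is not None
```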
+
+## Running the tests
+
+- `pipenv run generate`
+ This generates:
+ - `aristaproto/tests/output_aristaproto` &mdash; *the plugin generated python classes*
+ - `aristaproto/tests/output_reference` &mdash; *reference implementation classes*
+- `pipenv run test`
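Assuming the `generate` Pipenv script passes extra arguments through to `tests/generate.py` (whose usage text appears further down in this diff), a subset of cases can be regenerated by name, for example:

```bash
pipenv run generate bool double enum
```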
+
+## Intentionally Failing tests
+
+The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrected in the future.
+
+When running `pytest`, they show up as `x` or `X` in the test results.
+
+```
+aristaproto/tests/test_inputs.py ..x...x..x...x.X........xx........x.....x.......x.xx....x...................... [ 84%]
+```
+
+- `.` &mdash; PASSED
+- `x` &mdash; XFAIL: expected failure
+- `X` &mdash; XPASS: expected failure, but still passed
+
+Test cases marked for expected failure are declared in [inputs/config.py](inputs/config.py) \ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..c6b256d
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,22 @@
+import copy
+import sys
+
+import pytest
+
+
+def pytest_addoption(parser):
+ parser.addoption(
+ "--repeat", type=int, default=1, help="repeat the operation multiple times"
+ )
+
+
+@pytest.fixture(scope="session")
+def repeat(request):
+ return request.config.getoption("repeat")
+
+
+@pytest.fixture
+def reset_sys_path():
+ original = copy.deepcopy(sys.path)
+ yield
+ sys.path = original
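For illustration, a hypothetical test (not part of this diff) consuming the `repeat` and `reset_sys_path` fixtures defined above might look like this, assuming the generated `bool` test case is available:

```python
import sys


def test_repeated_roundtrip(repeat):
    # Re-runs the round-trip as many times as requested via --repeat.
    from tests.output_aristaproto.bool import Test

    for _ in range(repeat):
        assert Test().parse(bytes(Test(value=True))).value is True


def test_temporary_import_path(reset_sys_path, tmp_path):
    # sys.path may be mutated freely here; the fixture restores the original afterwards.
    sys.path.append(str(tmp_path))
    assert str(tmp_path) in sys.path
```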
diff --git a/tests/generate.py b/tests/generate.py
new file mode 100755
index 0000000..d6f36de
--- /dev/null
+++ b/tests/generate.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python
+import asyncio
+import os
+import platform
+import shutil
+import sys
+from pathlib import Path
+from typing import Set
+
+from tests.util import (
+ get_directories,
+ inputs_path,
+ output_path_aristaproto,
+ output_path_aristaproto_pydantic,
+ output_path_reference,
+ protoc,
+)
+
+
+# Force pure-python implementation instead of C++, otherwise imports
+# break things because we can't properly reset the symbol database.
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+
+def clear_directory(dir_path: Path):
+ for file_or_directory in dir_path.glob("*"):
+ if file_or_directory.is_dir():
+ shutil.rmtree(file_or_directory)
+ else:
+ file_or_directory.unlink()
+
+
+async def generate(whitelist: Set[str], verbose: bool):
+ test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}
+
+ path_whitelist = set()
+ name_whitelist = set()
+ for item in whitelist:
+ if item in test_case_names:
+ name_whitelist.add(item)
+ continue
+ path_whitelist.add(item)
+
+ generation_tasks = []
+ for test_case_name in sorted(test_case_names):
+ test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
+ if (
+ whitelist
+ and str(test_case_input_path) not in path_whitelist
+ and test_case_name not in name_whitelist
+ ):
+ continue
+ generation_tasks.append(
+ generate_test_case_output(test_case_input_path, test_case_name, verbose)
+ )
+
+ failed_test_cases = []
+ # Wait for all subprocs and match any failures to names to report
+ for test_case_name, result in zip(
+ sorted(test_case_names), await asyncio.gather(*generation_tasks)
+ ):
+ if result != 0:
+ failed_test_cases.append(test_case_name)
+
+ if len(failed_test_cases) > 0:
+ sys.stderr.write(
+ "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
+ )
+ for failed_test_case in failed_test_cases:
+ sys.stderr.write(f"- {failed_test_case}\n")
+
+ sys.exit(1)
+
+
+async def generate_test_case_output(
+ test_case_input_path: Path, test_case_name: str, verbose: bool
+) -> int:
+ """
+ Returns the max of the subprocess return values
+ """
+
+ test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
+ test_case_output_path_aristaproto = output_path_aristaproto
+ test_case_output_path_aristaproto_pyd = output_path_aristaproto_pydantic
+
+ os.makedirs(test_case_output_path_reference, exist_ok=True)
+ os.makedirs(test_case_output_path_aristaproto, exist_ok=True)
+ os.makedirs(test_case_output_path_aristaproto_pyd, exist_ok=True)
+
+ clear_directory(test_case_output_path_reference)
+ clear_directory(test_case_output_path_aristaproto)
+
+ (
+ (ref_out, ref_err, ref_code),
+ (plg_out, plg_err, plg_code),
+ (plg_out_pyd, plg_err_pyd, plg_code_pyd),
+ ) = await asyncio.gather(
+ protoc(test_case_input_path, test_case_output_path_reference, True),
+ protoc(test_case_input_path, test_case_output_path_aristaproto, False),
+ protoc(
+ test_case_input_path, test_case_output_path_aristaproto_pyd, False, True
+ ),
+ )
+
+ if ref_code == 0:
+ print(f"\033[31;1;4mGenerated reference output for {test_case_name!r}\033[0m")
+ else:
+ print(
+ f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m"
+ )
+
+ if verbose:
+ if ref_out:
+ print("Reference stdout:")
+ sys.stdout.buffer.write(ref_out)
+ sys.stdout.buffer.flush()
+
+ if ref_err:
+ print("Reference stderr:")
+ sys.stderr.buffer.write(ref_err)
+ sys.stderr.buffer.flush()
+
+ if plg_code == 0:
+ print(f"\033[31;1;4mGenerated plugin output for {test_case_name!r}\033[0m")
+ else:
+ print(
+ f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m"
+ )
+
+ if verbose:
+ if plg_out:
+ print("Plugin stdout:")
+ sys.stdout.buffer.write(plg_out)
+ sys.stdout.buffer.flush()
+
+ if plg_err:
+ print("Plugin stderr:")
+ sys.stderr.buffer.write(plg_err)
+ sys.stderr.buffer.flush()
+
+ if plg_code_pyd == 0:
+ print(
+ f"\033[31;1;4mGenerated plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
+ )
+ else:
+ print(
+ f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
+ )
+
+ if verbose:
+ if plg_out_pyd:
+ print("Plugin stdout:")
+ sys.stdout.buffer.write(plg_out_pyd)
+ sys.stdout.buffer.flush()
+
+ if plg_err_pyd:
+ print("Plugin stderr:")
+ sys.stderr.buffer.write(plg_err_pyd)
+ sys.stderr.buffer.flush()
+
+ return max(ref_code, plg_code, plg_code_pyd)
+
+
+HELP = "\n".join(
+ (
+ "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
+ "Generate python classes for standard tests.",
+ "",
+ "DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
+ " python generate.py inputs/bool inputs/double inputs/enum",
+ "",
+ "NAMES One or more test-case names to generate classes for.",
+ " python generate.py bool double enums",
+ )
+)
+
+
+def main():
+ if set(sys.argv).intersection({"-h", "--help"}):
+ print(HELP)
+ return
+ if sys.argv[1:2] == ["-v"]:
+ verbose = True
+ whitelist = set(sys.argv[2:])
+ else:
+ verbose = False
+ whitelist = set(sys.argv[1:])
+
+ if platform.system() == "Windows":
+ asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
+
+ asyncio.run(generate(whitelist, verbose))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/grpc/__init__.py b/tests/grpc/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/grpc/__init__.py
diff --git a/tests/grpc/test_grpclib_client.py b/tests/grpc/test_grpclib_client.py
new file mode 100644
index 0000000..d36e4a5
--- /dev/null
+++ b/tests/grpc/test_grpclib_client.py
@@ -0,0 +1,298 @@
+import asyncio
+import sys
+import uuid
+
+import grpclib
+import grpclib.client
+import grpclib.metadata
+import grpclib.server
+import pytest
+from grpclib.testing import ChannelFor
+
+from aristaproto.grpc.util.async_channel import AsyncChannel
+from tests.output_aristaproto.service import (
+ DoThingRequest,
+ DoThingResponse,
+ GetThingRequest,
+ TestStub as ThingServiceClient,
+)
+
+from .thing_service import ThingService
+
+
+async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
+ response = await client.do_thing(DoThingRequest(name=name), **kwargs)
+ assert response.names == [name]
+
+
+def _assert_request_meta_received(deadline, metadata):
+ def server_side_test(stream):
+ assert stream.deadline._timestamp == pytest.approx(
+ deadline._timestamp, 1
+ ), "The provided deadline should be received serverside"
+ assert (
+ stream.metadata["authorization"] == metadata["authorization"]
+ ), "The provided authorization metadata should be received serverside"
+
+ return server_side_test
+
+
+@pytest.fixture
+def handler_trailer_only_unauthenticated():
+ async def handler(stream: grpclib.server.Stream):
+ await stream.recv_message()
+ await stream.send_initial_metadata()
+ await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED)
+
+ return handler
+
+
+@pytest.mark.asyncio
+async def test_simple_service_call():
+ async with ChannelFor([ThingService()]) as channel:
+ await _test_client(ThingServiceClient(channel))
+
+
+@pytest.mark.asyncio
+async def test_trailer_only_error_unary_unary(
+ mocker, handler_trailer_only_unauthenticated
+):
+ service = ThingService()
+ mocker.patch.object(
+ service,
+ "do_thing",
+ side_effect=handler_trailer_only_unauthenticated,
+ autospec=True,
+ )
+ async with ChannelFor([service]) as channel:
+ with pytest.raises(grpclib.exceptions.GRPCError) as e:
+ await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
+ assert e.value.status == grpclib.Status.UNAUTHENTICATED
+
+
+@pytest.mark.asyncio
+async def test_trailer_only_error_stream_unary(
+ mocker, handler_trailer_only_unauthenticated
+):
+ service = ThingService()
+ mocker.patch.object(
+ service,
+ "do_many_things",
+ side_effect=handler_trailer_only_unauthenticated,
+ autospec=True,
+ )
+ async with ChannelFor([service]) as channel:
+ with pytest.raises(grpclib.exceptions.GRPCError) as e:
+ await ThingServiceClient(channel).do_many_things(
+ do_thing_request_iterator=[DoThingRequest(name="something")]
+ )
+ await _test_client(ThingServiceClient(channel))
+ assert e.value.status == grpclib.Status.UNAUTHENTICATED
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(
+    sys.version_info < (3, 8), reason="async mock spy only works on python3.8+"
+)
+async def test_service_call_mutable_defaults(mocker):
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+ spy = mocker.spy(client, "_unary_unary")
+ await _test_client(client)
+ comments = spy.call_args_list[-1].args[1].comments
+ await _test_client(client)
+ assert spy.call_args_list[-1].args[1].comments is not comments
+
+
+@pytest.mark.asyncio
+async def test_service_call_with_upfront_request_params():
+ # Setting deadline
+ deadline = grpclib.metadata.Deadline.from_timeout(22)
+ metadata = {"authorization": "12345"}
+ async with ChannelFor(
+ [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
+ ) as channel:
+ await _test_client(
+ ThingServiceClient(channel, deadline=deadline, metadata=metadata)
+ )
+
+ # Setting timeout
+ timeout = 99
+ deadline = grpclib.metadata.Deadline.from_timeout(timeout)
+ metadata = {"authorization": "12345"}
+ async with ChannelFor(
+ [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
+ ) as channel:
+ await _test_client(
+ ThingServiceClient(channel, timeout=timeout, metadata=metadata)
+ )
+
+
+@pytest.mark.asyncio
+async def test_service_call_lower_level_with_overrides():
+ THING_TO_DO = "get milk"
+
+ # Setting deadline
+ deadline = grpclib.metadata.Deadline.from_timeout(22)
+ metadata = {"authorization": "12345"}
+ kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
+ kwarg_metadata = {"authorization": "12345"}
+ async with ChannelFor(
+ [ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
+ ) as channel:
+ client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
+ response = await client._unary_unary(
+ "/service.Test/DoThing",
+ DoThingRequest(THING_TO_DO),
+ DoThingResponse,
+ deadline=kwarg_deadline,
+ metadata=kwarg_metadata,
+ )
+ assert response.names == [THING_TO_DO]
+
+ # Setting timeout
+ timeout = 99
+ deadline = grpclib.metadata.Deadline.from_timeout(timeout)
+ metadata = {"authorization": "12345"}
+ kwarg_timeout = 9000
+ kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
+ kwarg_metadata = {"authorization": "09876"}
+ async with ChannelFor(
+ [
+ ThingService(
+ test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata),
+ )
+ ]
+ ) as channel:
+ client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
+ response = await client._unary_unary(
+ "/service.Test/DoThing",
+ DoThingRequest(THING_TO_DO),
+ DoThingResponse,
+ timeout=kwarg_timeout,
+ metadata=kwarg_metadata,
+ )
+ assert response.names == [THING_TO_DO]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+ ("overrides_gen",),
+ [
+ (lambda: dict(timeout=10),),
+ (lambda: dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),),
+ (lambda: dict(metadata={"authorization": str(uuid.uuid4())}),),
+ (lambda: dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),),
+ ],
+)
+async def test_service_call_high_level_with_overrides(mocker, overrides_gen):
+ overrides = overrides_gen()
+ request_spy = mocker.spy(grpclib.client.Channel, "request")
+ name = str(uuid.uuid4())
+ defaults = dict(
+ timeout=99,
+ deadline=grpclib.metadata.Deadline.from_timeout(99),
+ metadata={"authorization": name},
+ )
+
+ async with ChannelFor(
+ [
+ ThingService(
+ test_hook=_assert_request_meta_received(
+ deadline=grpclib.metadata.Deadline.from_timeout(
+ overrides.get("timeout", 99)
+ ),
+ metadata=overrides.get("metadata", defaults.get("metadata")),
+ )
+ )
+ ]
+ ) as channel:
+ client = ThingServiceClient(channel, **defaults)
+ await _test_client(client, name=name, **overrides)
+ assert request_spy.call_count == 1
+
+ # for python <3.8 request_spy.call_args.kwargs do not work
+ _, request_spy_call_kwargs = request_spy.call_args_list[0]
+
+ # ensure all overrides were successful
+ for key, value in overrides.items():
+ assert key in request_spy_call_kwargs
+ assert request_spy_call_kwargs[key] == value
+
+ # ensure default values were retained
+ for key in set(defaults.keys()) - set(overrides.keys()):
+ assert key in request_spy_call_kwargs
+ assert request_spy_call_kwargs[key] == defaults[key]
+
+
+@pytest.mark.asyncio
+async def test_async_gen_for_unary_stream_request():
+ thing_name = "my milkshakes"
+
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+ expected_versions = [5, 4, 3, 2, 1]
+ async for response in client.get_thing_versions(
+ GetThingRequest(name=thing_name)
+ ):
+ assert response.name == thing_name
+ assert response.version == expected_versions.pop()
+
+
+@pytest.mark.asyncio
+async def test_async_gen_for_stream_stream_request():
+ some_things = ["cake", "cricket", "coral reef"]
+ more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
+ expected_things = (*some_things, *more_things)
+
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+        # Use an AsyncChannel to decouple sending and receiving: it'll send some_things
+        # immediately, and we'll use it to send more_things later, after receiving some
+        # results
+ request_chan = AsyncChannel()
+ send_initial_requests = asyncio.ensure_future(
+ request_chan.send_from(GetThingRequest(name) for name in some_things)
+ )
+ response_index = 0
+ async for response in client.get_different_things(request_chan):
+ assert response.name == expected_things[response_index]
+ assert response.version == response_index + 1
+ response_index += 1
+ if more_things:
+ # Send some more requests as we receive responses to be sure coordination of
+ # send/receive events doesn't matter
+ await request_chan.send(GetThingRequest(more_things.pop(0)))
+ elif not send_initial_requests.done():
+                # Make sure the sending task has completed
+ await send_initial_requests
+ else:
+                # No more things to send; make sure the channel is closed
+ request_chan.close()
+ assert response_index == len(
+ expected_things
+ ), "Didn't receive all expected responses"
+
+
+@pytest.mark.asyncio
+async def test_stream_unary_with_empty_iterable():
+ things = [] # empty
+
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+ requests = [DoThingRequest(name) for name in things]
+ response = await client.do_many_things(requests)
+ assert len(response.names) == 0
+
+
+@pytest.mark.asyncio
+async def test_stream_stream_with_empty_iterable():
+ things = [] # empty
+
+ async with ChannelFor([ThingService()]) as channel:
+ client = ThingServiceClient(channel)
+ requests = [GetThingRequest(name) for name in things]
+ responses = [
+ response async for response in client.get_different_things(requests)
+ ]
+ assert len(responses) == 0
diff --git a/tests/grpc/test_stream_stream.py b/tests/grpc/test_stream_stream.py
new file mode 100644
index 0000000..d4b27e5
--- /dev/null
+++ b/tests/grpc/test_stream_stream.py
@@ -0,0 +1,99 @@
+import asyncio
+from dataclasses import dataclass
+from typing import AsyncIterator
+
+import pytest
+
+import aristaproto
+from aristaproto.grpc.util.async_channel import AsyncChannel
+
+
+@dataclass
+class Message(aristaproto.Message):
+ body: str = aristaproto.string_field(1)
+
+
+@pytest.fixture
+def expected_responses():
+ return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")]
+
+
+class ClientStub:
+ async def connect(self, requests: AsyncIterator):
+ await asyncio.sleep(0.1)
+ async for request in requests:
+ await asyncio.sleep(0.1)
+ yield request
+ await asyncio.sleep(0.1)
+ yield Message("Done")
+
+
+async def to_list(generator: AsyncIterator):
+ return [value async for value in generator]
+
+
+@pytest.fixture
+def client():
+ # channel = Channel(host='127.0.0.1', port=50051)
+ # return ClientStub(channel)
+ return ClientStub()
+
+
+@pytest.mark.asyncio
+async def test_send_from_before_connect_and_close_automatically(
+ client, expected_responses
+):
+ requests = AsyncChannel()
+ await requests.send_from(
+ [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
+ )
+ responses = client.connect(requests)
+
+ assert await to_list(responses) == expected_responses
+
+
+@pytest.mark.asyncio
+async def test_send_from_after_connect_and_close_automatically(
+ client, expected_responses
+):
+ requests = AsyncChannel()
+ responses = client.connect(requests)
+ await requests.send_from(
+ [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
+ )
+
+ assert await to_list(responses) == expected_responses
+
+
+@pytest.mark.asyncio
+async def test_send_from_close_manually_immediately(client, expected_responses):
+ requests = AsyncChannel()
+ responses = client.connect(requests)
+ await requests.send_from(
+ [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
+ )
+ requests.close()
+
+ assert await to_list(responses) == expected_responses
+
+
+@pytest.mark.asyncio
+async def test_send_individually_and_close_before_connect(client, expected_responses):
+ requests = AsyncChannel()
+ await requests.send(Message(body="Hello world 1"))
+ await requests.send(Message(body="Hello world 2"))
+ requests.close()
+ responses = client.connect(requests)
+
+ assert await to_list(responses) == expected_responses
+
+
+@pytest.mark.asyncio
+async def test_send_individually_and_close_after_connect(client, expected_responses):
+ requests = AsyncChannel()
+ await requests.send(Message(body="Hello world 1"))
+ await requests.send(Message(body="Hello world 2"))
+ responses = client.connect(requests)
+ requests.close()
+
+ assert await to_list(responses) == expected_responses
diff --git a/tests/grpc/thing_service.py b/tests/grpc/thing_service.py
new file mode 100644
index 0000000..5b00cbe
--- /dev/null
+++ b/tests/grpc/thing_service.py
@@ -0,0 +1,85 @@
+from typing import Dict
+
+import grpclib
+import grpclib.server
+
+from tests.output_aristaproto.service import (
+ DoThingRequest,
+ DoThingResponse,
+ GetThingRequest,
+ GetThingResponse,
+)
+
+
+class ThingService:
+ def __init__(self, test_hook=None):
+ # This lets us pass assertions to the servicer ;)
+ self.test_hook = test_hook
+
+ async def do_thing(
+ self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
+ ):
+ request = await stream.recv_message()
+ if self.test_hook is not None:
+ self.test_hook(stream)
+ await stream.send_message(DoThingResponse([request.name]))
+
+ async def do_many_things(
+ self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
+ ):
+ thing_names = [request.name async for request in stream]
+ if self.test_hook is not None:
+ self.test_hook(stream)
+ await stream.send_message(DoThingResponse(thing_names))
+
+ async def get_thing_versions(
+ self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
+ ):
+ request = await stream.recv_message()
+ if self.test_hook is not None:
+ self.test_hook(stream)
+ for version_num in range(1, 6):
+ await stream.send_message(
+ GetThingResponse(name=request.name, version=version_num)
+ )
+
+ async def get_different_things(
+ self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
+ ):
+ if self.test_hook is not None:
+ self.test_hook(stream)
+ # Respond to each input item immediately
+ response_num = 0
+ async for request in stream:
+ response_num += 1
+ await stream.send_message(
+ GetThingResponse(name=request.name, version=response_num)
+ )
+
+ def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]:
+ return {
+ "/service.Test/DoThing": grpclib.const.Handler(
+ self.do_thing,
+ grpclib.const.Cardinality.UNARY_UNARY,
+ DoThingRequest,
+ DoThingResponse,
+ ),
+ "/service.Test/DoManyThings": grpclib.const.Handler(
+ self.do_many_things,
+ grpclib.const.Cardinality.STREAM_UNARY,
+ DoThingRequest,
+ DoThingResponse,
+ ),
+ "/service.Test/GetThingVersions": grpclib.const.Handler(
+ self.get_thing_versions,
+ grpclib.const.Cardinality.UNARY_STREAM,
+ GetThingRequest,
+ GetThingResponse,
+ ),
+ "/service.Test/GetDifferentThings": grpclib.const.Handler(
+ self.get_different_things,
+ grpclib.const.Cardinality.STREAM_STREAM,
+ GetThingRequest,
+ GetThingResponse,
+ ),
+ }
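In the tests above this servicer is only ever mounted through `grpclib.testing.ChannelFor`, but the `__mapping__` it exposes is the same contract grpclib's regular server uses. A minimal sketch of serving it standalone (host and port chosen arbitrarily for illustration):

```python
import asyncio

from grpclib.server import Server

from tests.grpc.thing_service import ThingService


async def main(host: str = "127.0.0.1", port: int = 50051) -> None:
    # Server() accepts any handler object that exposes __mapping__().
    server = Server([ThingService()])
    await server.start(host, port)
    await server.wait_closed()


if __name__ == "__main__":
    asyncio.run(main())
```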
diff --git a/tests/inputs/bool/bool.json b/tests/inputs/bool/bool.json
new file mode 100644
index 0000000..348e031
--- /dev/null
+++ b/tests/inputs/bool/bool.json
@@ -0,0 +1,3 @@
+{
+ "value": true
+}
diff --git a/tests/inputs/bool/bool.proto b/tests/inputs/bool/bool.proto
new file mode 100644
index 0000000..77836b8
--- /dev/null
+++ b/tests/inputs/bool/bool.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package bool;
+
+message Test {
+ bool value = 1;
+}
diff --git a/tests/inputs/bool/test_bool.py b/tests/inputs/bool/test_bool.py
new file mode 100644
index 0000000..f9554ae
--- /dev/null
+++ b/tests/inputs/bool/test_bool.py
@@ -0,0 +1,19 @@
+import pytest
+
+from tests.output_aristaproto.bool import Test
+from tests.output_aristaproto_pydantic.bool import Test as TestPyd
+
+
+def test_value():
+ message = Test()
+ assert not message.value, "Boolean is False by default"
+
+
+def test_pydantic_no_value():
+ with pytest.raises(ValueError):
+ TestPyd()
+
+
+def test_pydantic_value():
+ message = Test(value=False)
+ assert not message.value
diff --git a/tests/inputs/bytes/bytes.json b/tests/inputs/bytes/bytes.json
new file mode 100644
index 0000000..34c4554
--- /dev/null
+++ b/tests/inputs/bytes/bytes.json
@@ -0,0 +1,3 @@
+{
+ "data": "SGVsbG8sIFdvcmxkIQ=="
+}
diff --git a/tests/inputs/bytes/bytes.proto b/tests/inputs/bytes/bytes.proto
new file mode 100644
index 0000000..9895468
--- /dev/null
+++ b/tests/inputs/bytes/bytes.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package bytes;
+
+message Test {
+ bytes data = 1;
+}
diff --git a/tests/inputs/casing/casing.json b/tests/inputs/casing/casing.json
new file mode 100644
index 0000000..559104b
--- /dev/null
+++ b/tests/inputs/casing/casing.json
@@ -0,0 +1,4 @@
+{
+ "camelCase": 1,
+ "snakeCase": "ONE"
+}
diff --git a/tests/inputs/casing/casing.proto b/tests/inputs/casing/casing.proto
new file mode 100644
index 0000000..2023d93
--- /dev/null
+++ b/tests/inputs/casing/casing.proto
@@ -0,0 +1,20 @@
+syntax = "proto3";
+
+package casing;
+
+enum my_enum {
+ ZERO = 0;
+ ONE = 1;
+ TWO = 2;
+}
+
+message Test {
+ int32 camelCase = 1;
+ my_enum snake_case = 2;
+ snake_case_message snake_case_message = 3;
+ int32 UPPERCASE = 4;
+}
+
+message snake_case_message {
+
+} \ No newline at end of file
diff --git a/tests/inputs/casing/test_casing.py b/tests/inputs/casing/test_casing.py
new file mode 100644
index 0000000..0fa609b
--- /dev/null
+++ b/tests/inputs/casing/test_casing.py
@@ -0,0 +1,23 @@
+import tests.output_aristaproto.casing as casing
+from tests.output_aristaproto.casing import Test
+
+
+def test_message_attributes():
+ message = Test()
+ assert hasattr(
+ message, "snake_case_message"
+ ), "snake_case field name is same in python"
+ assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
+ assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python"
+
+
+def test_message_casing():
+ assert hasattr(
+ casing, "SnakeCaseMessage"
+ ), "snake_case Message name is converted to CamelCase in python"
+
+
+def test_enum_casing():
+ assert hasattr(
+ casing, "MyEnum"
+ ), "snake_case Enum name is converted to CamelCase in python"
diff --git a/tests/inputs/casing_inner_class/casing_inner_class.proto b/tests/inputs/casing_inner_class/casing_inner_class.proto
new file mode 100644
index 0000000..fae2a4c
--- /dev/null
+++ b/tests/inputs/casing_inner_class/casing_inner_class.proto
@@ -0,0 +1,10 @@
+syntax = "proto3";
+
+package casing_inner_class;
+
+message Test {
+ message inner_class {
+ sint32 old_exp = 1;
+ }
+ inner_class inner = 2;
+} \ No newline at end of file
diff --git a/tests/inputs/casing_inner_class/test_casing_inner_class.py b/tests/inputs/casing_inner_class/test_casing_inner_class.py
new file mode 100644
index 0000000..7c43add
--- /dev/null
+++ b/tests/inputs/casing_inner_class/test_casing_inner_class.py
@@ -0,0 +1,14 @@
+import tests.output_aristaproto.casing_inner_class as casing_inner_class
+
+
+def test_message_casing_inner_class_name():
+ assert hasattr(
+ casing_inner_class, "TestInnerClass"
+ ), "Inline defined Message is correctly converted to CamelCase"
+
+
+def test_message_casing_inner_class_attributes():
+ message = casing_inner_class.Test()
+ assert hasattr(
+ message.inner, "old_exp"
+ ), "Inline defined Message attribute is snake_case"
diff --git a/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.proto b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.proto
new file mode 100644
index 0000000..c6d42c3
--- /dev/null
+++ b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+package casing_message_field_uppercase;
+
+message Test {
+ int32 UPPERCASE = 1;
+ int32 UPPERCASE_V2 = 2;
+ int32 UPPER_CAMEL_CASE = 3;
+} \ No newline at end of file
diff --git a/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py
new file mode 100644
index 0000000..01a5234
--- /dev/null
+++ b/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py
@@ -0,0 +1,14 @@
+from tests.output_aristaproto.casing_message_field_uppercase import Test
+
+
+def test_message_casing():
+ message = Test()
+ assert hasattr(
+ message, "uppercase"
+ ), "UPPERCASE attribute is converted to 'uppercase' in python"
+ assert hasattr(
+ message, "uppercase_v2"
+ ), "UPPERCASE_V2 attribute is converted to 'uppercase_v2' in python"
+ assert hasattr(
+ message, "upper_camel_case"
+ ), "UPPER_CAMEL_CASE attribute is converted to upper_camel_case in python"
diff --git a/tests/inputs/config.py b/tests/inputs/config.py
new file mode 100644
index 0000000..6da1f88
--- /dev/null
+++ b/tests/inputs/config.py
@@ -0,0 +1,30 @@
+# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
+# Remove from list when fixed.
+xfail = {
+ "namespace_keywords", # 70
+ "googletypes_struct", # 9
+ "googletypes_value", # 9
+ "import_capitalized_package",
+ "example", # This is the example in the readme. Not a test.
+}
+
+services = {
+ "googletypes_request",
+ "googletypes_response",
+ "googletypes_response_embedded",
+ "service",
+ "service_separate_packages",
+ "import_service_input_message",
+ "googletypes_service_returns_empty",
+ "googletypes_service_returns_googletype",
+ "example_service",
+ "empty_service",
+ "service_uppercase",
+}
+
+
+# Indicate json sample messages to skip when testing that json (de)serialization
+# is symmetrical, because some cases legitimately are not symmetrical.
+# Each key references the name of the test scenario and the values in the tuple
+# are the names of the json files.
+non_symmetrical_json = {"empty_repeated": ("empty_repeated",)}
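The sets above are consumed by `tests/test_inputs.py` (added in this diff). Purely as an illustrative sketch, a parametrized runner could apply the `xfail` set roughly like this (case names are hard-coded here, whereas the real suite discovers them from `inputs/`):

```python
import pytest

from tests.inputs import config

# Hypothetical case list for illustration only.
ALL_CASES = ["bool", "double", "namespace_keywords"]


@pytest.mark.parametrize(
    "name",
    [
        pytest.param(case, marks=pytest.mark.xfail) if case in config.xfail else case
        for case in ALL_CASES
    ],
)
def test_standard_case(name):
    ...  # import the generated module, compare it with the reference output, etc.
```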
diff --git a/tests/inputs/deprecated/deprecated.json b/tests/inputs/deprecated/deprecated.json
new file mode 100644
index 0000000..43b2b65
--- /dev/null
+++ b/tests/inputs/deprecated/deprecated.json
@@ -0,0 +1,6 @@
+{
+ "message": {
+ "value": "hello"
+ },
+ "value": 10
+}
diff --git a/tests/inputs/deprecated/deprecated.proto b/tests/inputs/deprecated/deprecated.proto
new file mode 100644
index 0000000..81d69c0
--- /dev/null
+++ b/tests/inputs/deprecated/deprecated.proto
@@ -0,0 +1,14 @@
+syntax = "proto3";
+
+package deprecated;
+
+// Some documentation about the Test message.
+message Test {
+ Message message = 1 [deprecated=true];
+ int32 value = 2;
+}
+
+message Message {
+ option deprecated = true;
+ string value = 1;
+}
diff --git a/tests/inputs/double/double-negative.json b/tests/inputs/double/double-negative.json
new file mode 100644
index 0000000..e0776c7
--- /dev/null
+++ b/tests/inputs/double/double-negative.json
@@ -0,0 +1,3 @@
+{
+ "count": -123.45
+}
diff --git a/tests/inputs/double/double.json b/tests/inputs/double/double.json
new file mode 100644
index 0000000..321412e
--- /dev/null
+++ b/tests/inputs/double/double.json
@@ -0,0 +1,3 @@
+{
+ "count": 123.45
+}
diff --git a/tests/inputs/double/double.proto b/tests/inputs/double/double.proto
new file mode 100644
index 0000000..66aea95
--- /dev/null
+++ b/tests/inputs/double/double.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package double;
+
+message Test {
+ double count = 1;
+}
diff --git a/tests/inputs/empty_repeated/empty_repeated.json b/tests/inputs/empty_repeated/empty_repeated.json
new file mode 100644
index 0000000..12a801c
--- /dev/null
+++ b/tests/inputs/empty_repeated/empty_repeated.json
@@ -0,0 +1,3 @@
+{
+ "msg": [{"values":[]}]
+}
diff --git a/tests/inputs/empty_repeated/empty_repeated.proto b/tests/inputs/empty_repeated/empty_repeated.proto
new file mode 100644
index 0000000..f787301
--- /dev/null
+++ b/tests/inputs/empty_repeated/empty_repeated.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package empty_repeated;
+
+message MessageA {
+ repeated float values = 1;
+}
+
+message Test {
+ repeated MessageA msg = 1;
+}
diff --git a/tests/inputs/empty_service/empty_service.proto b/tests/inputs/empty_service/empty_service.proto
new file mode 100644
index 0000000..e96ff64
--- /dev/null
+++ b/tests/inputs/empty_service/empty_service.proto
@@ -0,0 +1,7 @@
+/* Empty service without comments */
+syntax = "proto3";
+
+package empty_service;
+
+service Test {
+}
diff --git a/tests/inputs/entry/entry.proto b/tests/inputs/entry/entry.proto
new file mode 100644
index 0000000..3f2af4d
--- /dev/null
+++ b/tests/inputs/entry/entry.proto
@@ -0,0 +1,20 @@
+syntax = "proto3";
+
+package entry;
+
+// This is a minimal example of a repeated message field that caused issues when
+// checking whether a message is a map.
+//
+// During the check of whether a field is a "map", the string "entry" is added to
+// the field name, checked against the type name and then further checks are
+// made against the nested type of a parent message. In this edge-case, the
+// first check would pass even though it shouldn't and that would cause an
+// error because the parent type does not have a "nested_type" attribute.
+
+message Test {
+ repeated ExportEntry export = 1;
+}
+
+message ExportEntry {
+ string name = 1;
+}
diff --git a/tests/inputs/enum/enum.json b/tests/inputs/enum/enum.json
new file mode 100644
index 0000000..d68f1c5
--- /dev/null
+++ b/tests/inputs/enum/enum.json
@@ -0,0 +1,9 @@
+{
+ "choice": "FOUR",
+ "choices": [
+ "ZERO",
+ "ONE",
+ "THREE",
+ "FOUR"
+ ]
+}
diff --git a/tests/inputs/enum/enum.proto b/tests/inputs/enum/enum.proto
new file mode 100644
index 0000000..5e2e80c
--- /dev/null
+++ b/tests/inputs/enum/enum.proto
@@ -0,0 +1,25 @@
+syntax = "proto3";
+
+package enum;
+
+// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
+message Test {
+ Choice choice = 1;
+ repeated Choice choices = 2;
+}
+
+enum Choice {
+ ZERO = 0;
+ ONE = 1;
+ // TWO = 2;
+ FOUR = 4;
+ THREE = 3;
+}
+
+// A "C" like enum with the enum name prefixed onto members, these should be stripped
+enum ArithmeticOperator {
+ ARITHMETIC_OPERATOR_NONE = 0;
+ ARITHMETIC_OPERATOR_PLUS = 1;
+ ARITHMETIC_OPERATOR_MINUS = 2;
+ ARITHMETIC_OPERATOR_0_PREFIXED = 3;
+}
diff --git a/tests/inputs/enum/test_enum.py b/tests/inputs/enum/test_enum.py
new file mode 100644
index 0000000..cf14c68
--- /dev/null
+++ b/tests/inputs/enum/test_enum.py
@@ -0,0 +1,114 @@
+from tests.output_aristaproto.enum import (
+ ArithmeticOperator,
+ Choice,
+ Test,
+)
+
+
+def test_enum_set_and_get():
+ assert Test(choice=Choice.ZERO).choice == Choice.ZERO
+ assert Test(choice=Choice.ONE).choice == Choice.ONE
+ assert Test(choice=Choice.THREE).choice == Choice.THREE
+ assert Test(choice=Choice.FOUR).choice == Choice.FOUR
+
+
+def test_enum_set_with_int():
+ assert Test(choice=0).choice == Choice.ZERO
+ assert Test(choice=1).choice == Choice.ONE
+ assert Test(choice=3).choice == Choice.THREE
+ assert Test(choice=4).choice == Choice.FOUR
+
+
+def test_enum_is_comparable_with_int():
+ assert Test(choice=Choice.ZERO).choice == 0
+ assert Test(choice=Choice.ONE).choice == 1
+ assert Test(choice=Choice.THREE).choice == 3
+ assert Test(choice=Choice.FOUR).choice == 4
+
+
+def test_enum_to_dict():
+ assert (
+ "choice" not in Test(choice=Choice.ZERO).to_dict()
+ ), "Default enum value is not serialized"
+ assert (
+ Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"]
+ == "ZERO"
+ )
+ assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE"
+ assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE"
+ assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR"
+
+
+def test_repeated_enum_is_comparable_with_int():
+ assert Test(choices=[Choice.ZERO]).choices == [0]
+ assert Test(choices=[Choice.ONE]).choices == [1]
+ assert Test(choices=[Choice.THREE]).choices == [3]
+ assert Test(choices=[Choice.FOUR]).choices == [4]
+
+
+def test_repeated_enum_set_and_get():
+ assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO]
+ assert Test(choices=[Choice.ONE]).choices == [Choice.ONE]
+ assert Test(choices=[Choice.THREE]).choices == [Choice.THREE]
+ assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR]
+
+
+def test_repeated_enum_to_dict():
+ assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"]
+ assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"]
+ assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"]
+ assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"]
+
+ all_enums_dict = Test(
+ choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR]
+ ).to_dict()
+ assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"]
+
+
+def test_repeated_enum_with_single_value_to_dict():
+ assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"]
+ assert Test(choices=1).to_dict()["choices"] == ["ONE"]
+
+
+def test_repeated_enum_with_non_list_iterables_to_dict():
+ assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
+ assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
+ assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [
+ "ONE",
+ "THREE",
+ ]
+
+ def enum_generator():
+ yield Choice.ONE
+ yield Choice.THREE
+
+ assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
+
+
+def test_enum_mapped_on_parse():
+ # test default value
+ b = Test().parse(bytes(Test()))
+ assert b.choice.name == Choice.ZERO.name
+ assert b.choices == []
+
+ # test non default value
+ a = Test().parse(bytes(Test(choice=Choice.ONE)))
+ assert a.choice.name == Choice.ONE.name
+    assert a.choices == []
+
+ # test repeated
+ c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
+ assert c.choices[0].name == Choice.THREE.name
+ assert c.choices[1].name == Choice.FOUR.name
+
+ # bonus: defaults after empty init are also mapped
+ assert Test().choice.name == Choice.ZERO.name
+
+
+def test_renamed_enum_members():
+ assert set(ArithmeticOperator.__members__) == {
+ "NONE",
+ "PLUS",
+ "MINUS",
+ "_0_PREFIXED",
+ }
diff --git a/tests/inputs/example/example.proto b/tests/inputs/example/example.proto
new file mode 100644
index 0000000..56bd364
--- /dev/null
+++ b/tests/inputs/example/example.proto
@@ -0,0 +1,911 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Author: kenton@google.com (Kenton Varda)
+// Based on original Protocol Buffers design by
+// Sanjay Ghemawat, Jeff Dean, and others.
+//
+// The messages in this file describe the definitions found in .proto files.
+// A valid .proto file can be translated directly to a FileDescriptorProto
+// without any other information (e.g. without reading its imports).
+
+
+syntax = "proto2";
+
+package example;
+
+// package google.protobuf;
+
+option go_package = "google.golang.org/protobuf/types/descriptorpb";
+option java_package = "com.google.protobuf";
+option java_outer_classname = "DescriptorProtos";
+option csharp_namespace = "Google.Protobuf.Reflection";
+option objc_class_prefix = "GPB";
+option cc_enable_arenas = true;
+
+// descriptor.proto must be optimized for speed because reflection-based
+// algorithms don't work during bootstrapping.
+option optimize_for = SPEED;
+
+// The protocol compiler can output a FileDescriptorSet containing the .proto
+// files it parses.
+message FileDescriptorSet {
+ repeated FileDescriptorProto file = 1;
+}
+
+// Describes a complete .proto file.
+message FileDescriptorProto {
+ optional string name = 1; // file name, relative to root of source tree
+ optional string package = 2; // e.g. "foo", "foo.bar", etc.
+
+ // Names of files imported by this file.
+ repeated string dependency = 3;
+ // Indexes of the public imported files in the dependency list above.
+ repeated int32 public_dependency = 10;
+ // Indexes of the weak imported files in the dependency list.
+ // For Google-internal migration only. Do not use.
+ repeated int32 weak_dependency = 11;
+
+ // All top-level definitions in this file.
+ repeated DescriptorProto message_type = 4;
+ repeated EnumDescriptorProto enum_type = 5;
+ repeated ServiceDescriptorProto service = 6;
+ repeated FieldDescriptorProto extension = 7;
+
+ optional FileOptions options = 8;
+
+ // This field contains optional information about the original source code.
+ // You may safely remove this entire field without harming runtime
+ // functionality of the descriptors -- the information is needed only by
+ // development tools.
+ optional SourceCodeInfo source_code_info = 9;
+
+ // The syntax of the proto file.
+ // The supported values are "proto2" and "proto3".
+ optional string syntax = 12;
+}
+
+// Describes a message type.
+message DescriptorProto {
+ optional string name = 1;
+
+ repeated FieldDescriptorProto field = 2;
+ repeated FieldDescriptorProto extension = 6;
+
+ repeated DescriptorProto nested_type = 3;
+ repeated EnumDescriptorProto enum_type = 4;
+
+ message ExtensionRange {
+ optional int32 start = 1; // Inclusive.
+ optional int32 end = 2; // Exclusive.
+
+ optional ExtensionRangeOptions options = 3;
+ }
+ repeated ExtensionRange extension_range = 5;
+
+ repeated OneofDescriptorProto oneof_decl = 8;
+
+ optional MessageOptions options = 7;
+
+ // Range of reserved tag numbers. Reserved tag numbers may not be used by
+ // fields or extension ranges in the same message. Reserved ranges may
+ // not overlap.
+ message ReservedRange {
+ optional int32 start = 1; // Inclusive.
+ optional int32 end = 2; // Exclusive.
+ }
+ repeated ReservedRange reserved_range = 9;
+ // Reserved field names, which may not be used by fields in the same message.
+ // A given name may only be reserved once.
+ repeated string reserved_name = 10;
+}
+
+message ExtensionRangeOptions {
+ // The parser stores options it doesn't recognize here. See above.
+ repeated UninterpretedOption uninterpreted_option = 999;
+
+
+ // Clients can define custom options in extensions of this message. See above.
+ extensions 1000 to max;
+}
+
+// Describes a field within a message.
+message FieldDescriptorProto {
+ enum Type {
+ // 0 is reserved for errors.
+ // Order is weird for historical reasons.
+ TYPE_DOUBLE = 1;
+ TYPE_FLOAT = 2;
+ // Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT64 if
+ // negative values are likely.
+ TYPE_INT64 = 3;
+ TYPE_UINT64 = 4;
+ // Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT32 if
+ // negative values are likely.
+ TYPE_INT32 = 5;
+ TYPE_FIXED64 = 6;
+ TYPE_FIXED32 = 7;
+ TYPE_BOOL = 8;
+ TYPE_STRING = 9;
+ // Tag-delimited aggregate.
+ // Group type is deprecated and not supported in proto3. However, Proto3
+ // implementations should still be able to parse the group wire format and
+ // treat group fields as unknown fields.
+ TYPE_GROUP = 10;
+ TYPE_MESSAGE = 11; // Length-delimited aggregate.
+
+ // New in version 2.
+ TYPE_BYTES = 12;
+ TYPE_UINT32 = 13;
+ TYPE_ENUM = 14;
+ TYPE_SFIXED32 = 15;
+ TYPE_SFIXED64 = 16;
+ TYPE_SINT32 = 17; // Uses ZigZag encoding.
+ TYPE_SINT64 = 18; // Uses ZigZag encoding.
+ }
+
+ enum Label {
+ // 0 is reserved for errors
+ LABEL_OPTIONAL = 1;
+ LABEL_REQUIRED = 2;
+ LABEL_REPEATED = 3;
+ }
+
+ optional string name = 1;
+ optional int32 number = 3;
+ optional Label label = 4;
+
+ // If type_name is set, this need not be set. If both this and type_name
+ // are set, this must be one of TYPE_ENUM, TYPE_MESSAGE or TYPE_GROUP.
+ optional Type type = 5;
+
+ // For message and enum types, this is the name of the type. If the name
+ // starts with a '.', it is fully-qualified. Otherwise, C++-like scoping
+ // rules are used to find the type (i.e. first the nested types within this
+ // message are searched, then within the parent, on up to the root
+ // namespace).
+ optional string type_name = 6;
+
+ // For extensions, this is the name of the type being extended. It is
+ // resolved in the same manner as type_name.
+ optional string extendee = 2;
+
+ // For numeric types, contains the original text representation of the value.
+ // For booleans, "true" or "false".
+ // For strings, contains the default text contents (not escaped in any way).
+ // For bytes, contains the C escaped value. All bytes >= 128 are escaped.
+ // TODO(kenton): Base-64 encode?
+ optional string default_value = 7;
+
+ // If set, gives the index of a oneof in the containing type's oneof_decl
+ // list. This field is a member of that oneof.
+ optional int32 oneof_index = 9;
+
+ // JSON name of this field. The value is set by protocol compiler. If the
+ // user has set a "json_name" option on this field, that option's value
+ // will be used. Otherwise, it's deduced from the field's name by converting
+ // it to camelCase.
+ optional string json_name = 10;
+
+ optional FieldOptions options = 8;
+
+ // If true, this is a proto3 "optional". When a proto3 field is optional, it
+ // tracks presence regardless of field type.
+ //
+ // When proto3_optional is true, this field must belong to a oneof to
+ // signal to old proto3 clients that presence is tracked for this field. This
+ // oneof is known as a "synthetic" oneof, and this field must be its sole
+ // member (each proto3 optional field gets its own synthetic oneof). Synthetic
+ // oneofs exist in the descriptor only, and do not generate any API. Synthetic
+ // oneofs must be ordered after all "real" oneofs.
+ //
+ // For message fields, proto3_optional doesn't create any semantic change,
+ // since non-repeated message fields always track presence. However it still
+ // indicates the semantic detail of whether the user wrote "optional" or not.
+ // This can be useful for round-tripping the .proto file. For consistency we
+ // give message fields a synthetic oneof also, even though it is not required
+ // to track presence. This is especially important because the parser can't
+ // tell if a field is a message or an enum, so it must always create a
+ // synthetic oneof.
+ //
+ // Proto2 optional fields do not set this flag, because they already indicate
+ // optional with `LABEL_OPTIONAL`.
+ optional bool proto3_optional = 17;
+}
+
+// Describes a oneof.
+message OneofDescriptorProto {
+ optional string name = 1;
+ optional OneofOptions options = 2;
+}
+
+// Describes an enum type.
+message EnumDescriptorProto {
+ optional string name = 1;
+
+ repeated EnumValueDescriptorProto value = 2;
+
+ optional EnumOptions options = 3;
+
+ // Range of reserved numeric values. Reserved values may not be used by
+ // entries in the same enum. Reserved ranges may not overlap.
+ //
+ // Note that this is distinct from DescriptorProto.ReservedRange in that it
+ // is inclusive such that it can appropriately represent the entire int32
+ // domain.
+ message EnumReservedRange {
+ optional int32 start = 1; // Inclusive.
+ optional int32 end = 2; // Inclusive.
+ }
+
+ // Range of reserved numeric values. Reserved numeric values may not be used
+ // by enum values in the same enum declaration. Reserved ranges may not
+ // overlap.
+ repeated EnumReservedRange reserved_range = 4;
+
+ // Reserved enum value names, which may not be reused. A given name may only
+ // be reserved once.
+ repeated string reserved_name = 5;
+}
+
+// Describes a value within an enum.
+message EnumValueDescriptorProto {
+ optional string name = 1;
+ optional int32 number = 2;
+
+ optional EnumValueOptions options = 3;
+}
+
+// Describes a service.
+message ServiceDescriptorProto {
+ optional string name = 1;
+ repeated MethodDescriptorProto method = 2;
+
+ optional ServiceOptions options = 3;
+}
+
+// Describes a method of a service.
+message MethodDescriptorProto {
+ optional string name = 1;
+
+ // Input and output type names. These are resolved in the same way as
+ // FieldDescriptorProto.type_name, but must refer to a message type.
+ optional string input_type = 2;
+ optional string output_type = 3;
+
+ optional MethodOptions options = 4;
+
+ // Identifies whether the client streams multiple client messages.
+ optional bool client_streaming = 5 [default = false];
+ // Identifies whether the server streams multiple server messages.
+ optional bool server_streaming = 6 [default = false];
+}
+
+
+// ===================================================================
+// Options
+
+// Each of the definitions above may have "options" attached. These are
+// just annotations which may cause code to be generated slightly differently
+// or may contain hints for code that manipulates protocol messages.
+//
+// Clients may define custom options as extensions of the *Options messages.
+// These extensions may not yet be known at parsing time, so the parser cannot
+// store the values in them. Instead it stores them in a field in the *Options
+// message called uninterpreted_option. This field must have the same name
+// across all *Options messages. We then use this field to populate the
+// extensions when we build a descriptor, at which point all protos have been
+// parsed and so all extensions are known.
+//
+// Extension numbers for custom options may be chosen as follows:
+// * For options which will only be used within a single application or
+// organization, or for experimental options, use field numbers 50000
+// through 99999. It is up to you to ensure that you do not use the
+// same number for multiple options.
+// * For options which will be published and used publicly by multiple
+// independent entities, e-mail protobuf-global-extension-registry@google.com
+// to reserve extension numbers. Simply provide your project name (e.g.
+// Objective-C plugin) and your project website (if available) -- there's no
+// need to explain how you intend to use them. Usually you only need one
+// extension number. You can declare multiple options with only one extension
+// number by putting them in a sub-message. See the Custom Options section of
+// the docs for examples:
+// https://developers.google.com/protocol-buffers/docs/proto#options
+// If this turns out to be popular, a web service will be set up
+// to automatically assign option numbers.
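+//
+// Illustrative sketch (not part of the upstream file): a custom file-level
+// option in the experimental number range could be declared and used roughly
+// like this, assuming a hypothetical option named my_option:
+//
+//   import "google/protobuf/descriptor.proto";
+//
+//   extend google.protobuf.FileOptions {
+//     optional string my_option = 50001;
+//   }
+//
+//   option (my_option) = "Hello world!";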
+
+message FileOptions {
+
+ // Sets the Java package where classes generated from this .proto will be
+ // placed. By default, the proto package is used, but this is often
+ // inappropriate because proto packages do not normally start with backwards
+ // domain names.
+ optional string java_package = 1;
+
+
+ // If set, all the classes from the .proto file are wrapped in a single
+ // outer class with the given name. This applies to both Proto1
+ // (equivalent to the old "--one_java_file" option) and Proto2 (where
+ // a .proto always translates to a single class, but you may want to
+ // explicitly choose the class name).
+ optional string java_outer_classname = 8;
+
+ // If set true, then the Java code generator will generate a separate .java
+ // file for each top-level message, enum, and service defined in the .proto
+ // file. Thus, these types will *not* be nested inside the outer class
+ // named by java_outer_classname. However, the outer class will still be
+ // generated to contain the file's getDescriptor() method as well as any
+ // top-level extensions defined in the file.
+ optional bool java_multiple_files = 10 [default = false];
+
+ // This option does nothing.
+ optional bool java_generate_equals_and_hash = 20 [deprecated=true];
+
+ // If set true, then the Java2 code generator will generate code that
+ // throws an exception whenever an attempt is made to assign a non-UTF-8
+ // byte sequence to a string field.
+ // Message reflection will do the same.
+ // However, an extension field still accepts non-UTF-8 byte sequences.
+ // This option has no effect when used with the lite runtime.
+ optional bool java_string_check_utf8 = 27 [default = false];
+
+
+ // Generated classes can be optimized for speed or code size.
+ enum OptimizeMode {
+ SPEED = 1; // Generate complete code for parsing, serialization,
+ // etc.
+ CODE_SIZE = 2; // Use ReflectionOps to implement these methods.
+ LITE_RUNTIME = 3; // Generate code using MessageLite and the lite runtime.
+ }
+ optional OptimizeMode optimize_for = 9 [default = SPEED];
+
+ // Sets the Go package where structs generated from this .proto will be
+ // placed. If omitted, the Go package will be derived from the following:
+ // - The basename of the package import path, if provided.
+ // - Otherwise, the package statement in the .proto file, if present.
+ // - Otherwise, the basename of the .proto file, without extension.
+ optional string go_package = 11;
+
+
+
+
+ // Should generic services be generated in each language? "Generic" services
+ // are not specific to any particular RPC system. They are generated by the
+ // main code generators in each language (without additional plugins).
+ // Generic services were the only kind of service generation supported by
+ // early versions of google.protobuf.
+ //
+ // Generic services are now considered deprecated in favor of using plugins
+ // that generate code specific to your particular RPC system. Therefore,
+ // these default to false. Old code which depends on generic services should
+ // explicitly set them to true.
+ optional bool cc_generic_services = 16 [default = false];
+ optional bool java_generic_services = 17 [default = false];
+ optional bool py_generic_services = 18 [default = false];
+ optional bool php_generic_services = 42 [default = false];
+
+ // Is this file deprecated?
+ // Depending on the target platform, this can emit Deprecated annotations
+ // for everything in the file, or it will be completely ignored; in the very
+ // least, this is a formalization for deprecating files.
+ optional bool deprecated = 23 [default = false];
+
+ // Enables the use of arenas for the proto messages in this file. This applies
+ // only to generated classes for C++.
+ optional bool cc_enable_arenas = 31 [default = true];
+
+
+ // Sets the objective c class prefix which is prepended to all objective c
+ // generated classes from this .proto. There is no default.
+ optional string objc_class_prefix = 36;
+
+ // Namespace for generated classes; defaults to the package.
+ optional string csharp_namespace = 37;
+
+ // By default Swift generators will take the proto package and CamelCase it
+ // replacing '.' with underscore and use that to prefix the types/symbols
+ // defined. When this option is provided, they will use this value instead
+ // to prefix the types/symbols defined.
+ optional string swift_prefix = 39;
+
+ // Sets the php class prefix which is prepended to all php generated classes
+ // from this .proto. Default is empty.
+ optional string php_class_prefix = 40;
+
+ // Use this option to change the namespace of php generated classes. Default
+ // is empty. When this option is empty, the package name will be used for
+ // determining the namespace.
+ optional string php_namespace = 41;
+
+ // Use this option to change the namespace of php generated metadata classes.
+ // Default is empty. When this option is empty, the proto file name will be
+ // used for determining the namespace.
+ optional string php_metadata_namespace = 44;
+
+ // Use this option to change the package of ruby generated classes. Default
+ // is empty. When this option is not set, the package name will be used for
+ // determining the ruby package.
+ optional string ruby_package = 45;
+
+
+ // The parser stores options it doesn't recognize here.
+ // See the documentation for the "Options" section above.
+ repeated UninterpretedOption uninterpreted_option = 999;
+
+ // Clients can define custom options in extensions of this message.
+ // See the documentation for the "Options" section above.
+ extensions 1000 to max;
+
+ reserved 38;
+}
+
+message MessageOptions {
+ // Set true to use the old proto1 MessageSet wire format for extensions.
+ // This is provided for backwards-compatibility with the MessageSet wire
+ // format. You should not use this for any other reason: It's less
+ // efficient, has fewer features, and is more complicated.
+ //
+ // The message must be defined exactly as follows:
+ // message Foo {
+ // option message_set_wire_format = true;
+ // extensions 4 to max;
+ // }
+ // Note that the message cannot have any defined fields; MessageSets only
+ // have extensions.
+ //
+ // All extensions of your type must be singular messages; e.g. they cannot
+ // be int32s, enums, or repeated messages.
+ //
+ // Because this is an option, the above two restrictions are not enforced by
+ // the protocol compiler.
+ optional bool message_set_wire_format = 1 [default = false];
+
+ // Disables the generation of the standard "descriptor()" accessor, which can
+ // conflict with a field of the same name. This is meant to make migration
+ // from proto1 easier; new code should avoid fields named "descriptor".
+ optional bool no_standard_descriptor_accessor = 2 [default = false];
+
+ // Is this message deprecated?
+ // Depending on the target platform, this can emit Deprecated annotations
+ // for the message, or it will be completely ignored; in the very least,
+ // this is a formalization for deprecating messages.
+ optional bool deprecated = 3 [default = false];
+
+ // Whether the message is an automatically generated map entry type for the
+ // maps field.
+ //
+ // For maps fields:
+ // map<KeyType, ValueType> map_field = 1;
+ // The parsed descriptor looks like:
+ // message MapFieldEntry {
+ // option map_entry = true;
+ // optional KeyType key = 1;
+ // optional ValueType value = 2;
+ // }
+ // repeated MapFieldEntry map_field = 1;
+ //
+ // Implementations may choose not to generate the map_entry=true message, but
+ // use a native map in the target language to hold the keys and values.
+ // The reflection APIs in such implementations still need to work as
+ // if the field is a repeated message field.
+ //
+ // NOTE: Do not set the option in .proto files. Always use the maps syntax
+ // instead. The option should only be implicitly set by the proto compiler
+ // parser.
+ optional bool map_entry = 7;
+
+ reserved 8; // javalite_serializable
+ reserved 9; // javanano_as_lite
+
+
+ // The parser stores options it doesn't recognize here. See above.
+ repeated UninterpretedOption uninterpreted_option = 999;
+
+ // Clients can define custom options in extensions of this message. See above.
+ extensions 1000 to max;
+}
+
+message FieldOptions {
+ // The ctype option instructs the C++ code generator to use a different
+ // representation of the field than it normally would. See the specific
+ // options below. This option is not yet implemented in the open source
+ // release -- sorry, we'll try to include it in a future version!
+ optional CType ctype = 1 [default = STRING];
+ enum CType {
+ // Default mode.
+ STRING = 0;
+
+ CORD = 1;
+
+ STRING_PIECE = 2;
+ }
+ // The packed option can be enabled for repeated primitive fields to enable
+ // a more efficient representation on the wire. Rather than repeatedly
+ // writing the tag and type for each element, the entire array is encoded as
+ // a single length-delimited blob. In proto3, only explicitly setting it to
+ // false will avoid using packed encoding.
+ optional bool packed = 2;
+
+ // The jstype option determines the JavaScript type used for values of the
+ // field. The option is permitted only for 64 bit integral and fixed types
+ // (int64, uint64, sint64, fixed64, sfixed64). A field with jstype JS_STRING
+ // is represented as a JavaScript string, which avoids loss of precision that
+ // can happen when a large value is converted to a floating point JavaScript
+ // number.
+ // Specifying JS_NUMBER for the jstype causes the generated JavaScript code to
+ // use the JavaScript "number" type. The behavior of the default option
+ // JS_NORMAL is implementation dependent.
+ //
+ // This option is an enum to permit additional types to be added, e.g.
+ // goog.math.Integer.
+ optional JSType jstype = 6 [default = JS_NORMAL];
+ enum JSType {
+ // Use the default type.
+ JS_NORMAL = 0;
+
+ // Use JavaScript strings.
+ JS_STRING = 1;
+
+ // Use JavaScript numbers.
+ JS_NUMBER = 2;
+ }
+
+ // Should this field be parsed lazily? Lazy applies only to message-type
+ // fields. It means that when the outer message is initially parsed, the
+ // inner message's contents will not be parsed but instead stored in encoded
+ // form. The inner message will actually be parsed when it is first accessed.
+ //
+ // This is only a hint. Implementations are free to choose whether to use
+ // eager or lazy parsing regardless of the value of this option. However,
+ // setting this option true suggests that the protocol author believes that
+ // using lazy parsing on this field is worth the additional bookkeeping
+ // overhead typically needed to implement it.
+ //
+ // This option does not affect the public interface of any generated code;
+ // all method signatures remain the same. Furthermore, thread-safety of the
+ // interface is not affected by this option; const methods remain safe to
+ // call from multiple threads concurrently, while non-const methods continue
+ // to require exclusive access.
+ //
+ //
+ // Note that implementations may choose not to check required fields within
+ // a lazy sub-message. That is, calling IsInitialized() on the outer message
+ // may return true even if the inner message has missing required fields.
+ // This is necessary because otherwise the inner message would have to be
+ // parsed in order to perform the check, defeating the purpose of lazy
+ // parsing. An implementation which chooses not to check required fields
+ // must be consistent about it. That is, for any particular sub-message, the
+ // implementation must either *always* check its required fields, or *never*
+ // check its required fields, regardless of whether or not the message has
+ // been parsed.
+ optional bool lazy = 5 [default = false];
+
+ // Is this field deprecated?
+ // Depending on the target platform, this can emit Deprecated annotations
+ // for accessors, or it will be completely ignored; in the very least, this
+ // is a formalization for deprecating fields.
+ optional bool deprecated = 3 [default = false];
+
+ // For Google-internal migration only. Do not use.
+ optional bool weak = 10 [default = false];
+
+
+ // The parser stores options it doesn't recognize here. See above.
+ repeated UninterpretedOption uninterpreted_option = 999;
+
+ // Clients can define custom options in extensions of this message. See above.
+ extensions 1000 to max;
+
+ reserved 4; // removed jtype
+}
+
+message OneofOptions {
+ // The parser stores options it doesn't recognize here. See above.
+ repeated UninterpretedOption uninterpreted_option = 999;
+
+ // Clients can define custom options in extensions of this message. See above.
+ extensions 1000 to max;
+}
+
+message EnumOptions {
+
+ // Set this option to true to allow mapping different tag names to the same
+ // value.
+ optional bool allow_alias = 2;
+
+ // Is this enum deprecated?
+ // Depending on the target platform, this can emit Deprecated annotations
+ // for the enum, or it will be completely ignored; in the very least, this
+ // is a formalization for deprecating enums.
+ optional bool deprecated = 3 [default = false];
+
+ reserved 5; // javanano_as_lite
+
+ // The parser stores options it doesn't recognize here. See above.
+ repeated UninterpretedOption uninterpreted_option = 999;
+
+ // Clients can define custom options in extensions of this message. See above.
+ extensions 1000 to max;
+}
+
+message EnumValueOptions {
+ // Is this enum value deprecated?
+ // Depending on the target platform, this can emit Deprecated annotations
+ // for the enum value, or it will be completely ignored; in the very least,
+ // this is a formalization for deprecating enum values.
+ optional bool deprecated = 1 [default = false];
+
+ // The parser stores options it doesn't recognize here. See above.
+ repeated UninterpretedOption uninterpreted_option = 999;
+
+ // Clients can define custom options in extensions of this message. See above.
+ extensions 1000 to max;
+}
+
+message ServiceOptions {
+
+ // Note: Field numbers 1 through 32 are reserved for Google's internal RPC
+ // framework. We apologize for hoarding these numbers to ourselves, but
+ // we were already using them long before we decided to release Protocol
+ // Buffers.
+
+ // Is this service deprecated?
+ // Depending on the target platform, this can emit Deprecated annotations
+ // for the service, or it will be completely ignored; in the very least,
+ // this is a formalization for deprecating services.
+ optional bool deprecated = 33 [default = false];
+
+ // The parser stores options it doesn't recognize here. See above.
+ repeated UninterpretedOption uninterpreted_option = 999;
+
+ // Clients can define custom options in extensions of this message. See above.
+ extensions 1000 to max;
+}
+
+message MethodOptions {
+
+ // Note: Field numbers 1 through 32 are reserved for Google's internal RPC
+ // framework. We apologize for hoarding these numbers to ourselves, but
+ // we were already using them long before we decided to release Protocol
+ // Buffers.
+
+ // Is this method deprecated?
+ // Depending on the target platform, this can emit Deprecated annotations
+ // for the method, or it will be completely ignored; in the very least,
+ // this is a formalization for deprecating methods.
+ optional bool deprecated = 33 [default = false];
+
+ // Is this method side-effect-free (or safe in HTTP parlance), or idempotent,
+ // or neither? HTTP-based RPC implementations may choose the GET verb for safe
+ // methods, and the PUT verb for idempotent methods, instead of the default POST.
+ enum IdempotencyLevel {
+ IDEMPOTENCY_UNKNOWN = 0;
+ NO_SIDE_EFFECTS = 1; // implies idempotent
+ IDEMPOTENT = 2; // idempotent, but may have side effects
+ }
+ optional IdempotencyLevel idempotency_level = 34
+ [default = IDEMPOTENCY_UNKNOWN];
+
+ // The parser stores options it doesn't recognize here. See above.
+ repeated UninterpretedOption uninterpreted_option = 999;
+
+ // Clients can define custom options in extensions of this message. See above.
+ extensions 1000 to max;
+}
+
+
+// A message representing an option the parser does not recognize. This only
+// appears in options protos created by the compiler::Parser class.
+// DescriptorPool resolves these when building Descriptor objects. Therefore,
+// options protos in descriptor objects (e.g. returned by Descriptor::options(),
+// or produced by Descriptor::CopyTo()) will never have UninterpretedOptions
+// in them.
+message UninterpretedOption {
+ // The name of the uninterpreted option. Each string represents a segment in
+ // a dot-separated name. is_extension is true iff a segment represents an
+ // extension (denoted with parentheses in options specs in .proto files).
+ // E.g.,{ ["foo", false], ["bar.baz", true], ["qux", false] } represents
+ // "foo.(bar.baz).qux".
+ message NamePart {
+ required string name_part = 1;
+ required bool is_extension = 2;
+ }
+ repeated NamePart name = 2;
+
+ // The value of the uninterpreted option, in whatever type the tokenizer
+ // identified it as during parsing. Exactly one of these should be set.
+ optional string identifier_value = 3;
+ optional uint64 positive_int_value = 4;
+ optional int64 negative_int_value = 5;
+ optional double double_value = 6;
+ optional bytes string_value = 7;
+ optional string aggregate_value = 8;
+}
+
+// ===================================================================
+// Optional source code info
+
+// Encapsulates information about the original source file from which a
+// FileDescriptorProto was generated.
+message SourceCodeInfo {
+ // A Location identifies a piece of source code in a .proto file which
+ // corresponds to a particular definition. This information is intended
+ // to be useful to IDEs, code indexers, documentation generators, and similar
+ // tools.
+ //
+ // For example, say we have a file like:
+ // message Foo {
+ // optional string foo = 1;
+ // }
+ // Let's look at just the field definition:
+ // optional string foo = 1;
+ // ^ ^^ ^^ ^ ^^^
+ // a bc de f ghi
+ // We have the following locations:
+ // span path represents
+ // [a,i) [ 4, 0, 2, 0 ] The whole field definition.
+ // [a,b) [ 4, 0, 2, 0, 4 ] The label (optional).
+ // [c,d) [ 4, 0, 2, 0, 5 ] The type (string).
+ // [e,f) [ 4, 0, 2, 0, 1 ] The name (foo).
+ // [g,h) [ 4, 0, 2, 0, 3 ] The number (1).
+ //
+ // Notes:
+ // - A location may refer to a repeated field itself (i.e. not to any
+ // particular index within it). This is used whenever a set of elements are
+ // logically enclosed in a single code segment. For example, an entire
+ // extend block (possibly containing multiple extension definitions) will
+ // have an outer location whose path refers to the "extensions" repeated
+ // field without an index.
+ // - Multiple locations may have the same path. This happens when a single
+ // logical declaration is spread out across multiple places. The most
+ // obvious example is the "extend" block again -- there may be multiple
+ // extend blocks in the same scope, each of which will have the same path.
+ // - A location's span is not always a subset of its parent's span. For
+ // example, the "extendee" of an extension declaration appears at the
+ // beginning of the "extend" block and is shared by all extensions within
+ // the block.
+ // - Just because a location's span is a subset of some other location's span
+ // does not mean that it is a descendant. For example, a "group" defines
+ // both a type and a field in a single declaration. Thus, the locations
+ // corresponding to the type and field and their components will overlap.
+ // - Code which tries to interpret locations should probably be designed to
+ // ignore those that it doesn't understand, as more types of locations could
+ // be recorded in the future.
+ repeated Location location = 1;
+ message Location {
+ // Identifies which part of the FileDescriptorProto was defined at this
+ // location.
+ //
+ // Each element is a field number or an index. They form a path from
+ // the root FileDescriptorProto to the place where the definition appears. For
+ // example, this path:
+ // [ 4, 3, 2, 7, 1 ]
+ // refers to:
+ // file.message_type(3) // 4, 3
+ // .field(7) // 2, 7
+ // .name() // 1
+ // This is because FileDescriptorProto.message_type has field number 4:
+ // repeated DescriptorProto message_type = 4;
+ // and DescriptorProto.field has field number 2:
+ // repeated FieldDescriptorProto field = 2;
+ // and FieldDescriptorProto.name has field number 1:
+ // optional string name = 1;
+ //
+ // Thus, the above path gives the location of a field name. If we removed
+ // the last element:
+ // [ 4, 3, 2, 7 ]
+ // this path refers to the whole field declaration (from the beginning
+ // of the label to the terminating semicolon).
+ repeated int32 path = 1 [packed = true];
+
+ // Always has exactly three or four elements: start line, start column,
+ // end line (optional, otherwise assumed same as start line), end column.
+ // These are packed into a single field for efficiency. Note that line
+ // and column numbers are zero-based -- typically you will want to add
+ // 1 to each before displaying to a user.
+ repeated int32 span = 2 [packed = true];
+
+ // If this SourceCodeInfo represents a complete declaration, these are any
+ // comments appearing before and after the declaration which appear to be
+ // attached to the declaration.
+ //
+ // A series of line comments appearing on consecutive lines, with no other
+ // tokens appearing on those lines, will be treated as a single comment.
+ //
+ // leading_detached_comments will keep paragraphs of comments that appear
+ // before (but not connected to) the current element. Each paragraph,
+ // separated by empty lines, will be one comment element in the repeated
+ // field.
+ //
+ // Only the comment content is provided; comment markers (e.g. //) are
+ // stripped out. For block comments, leading whitespace and an asterisk
+ // will be stripped from the beginning of each line other than the first.
+ // Newlines are included in the output.
+ //
+ // Examples:
+ //
+ // optional int32 foo = 1; // Comment attached to foo.
+ // // Comment attached to bar.
+ // optional int32 bar = 2;
+ //
+ // optional string baz = 3;
+ // // Comment attached to baz.
+ // // Another line attached to baz.
+ //
+ // // Comment attached to qux.
+ // //
+ // // Another line attached to qux.
+ // optional double qux = 4;
+ //
+ // // Detached comment for corge. This is not leading or trailing comments
+ // // to qux or corge because there are blank lines separating it from
+ // // both.
+ //
+ // // Detached comment for corge paragraph 2.
+ //
+ // optional string corge = 5;
+ // /* Block comment attached
+ // * to corge. Leading asterisks
+ // * will be removed. */
+ // /* Block comment attached to
+ // * grault. */
+ // optional int32 grault = 6;
+ //
+ // // ignored detached comments.
+ optional string leading_comments = 3;
+ optional string trailing_comments = 4;
+ repeated string leading_detached_comments = 6;
+ }
+}
+
+// Describes the relationship between generated code and its original source
+// file. A GeneratedCodeInfo message is associated with only one generated
+// source file, but may contain references to different source .proto files.
+message GeneratedCodeInfo {
+ // An Annotation connects some span of text in generated code to an element
+ // of its generating .proto file.
+ repeated Annotation annotation = 1;
+ message Annotation {
+ // Identifies the element in the original source .proto file. This field
+ // is formatted the same as SourceCodeInfo.Location.path.
+ repeated int32 path = 1 [packed = true];
+
+ // Identifies the filesystem path to the original source .proto.
+ optional string source_file = 2;
+
+ // Identifies the starting offset in bytes in the generated code
+ // that relates to the identified object.
+ optional int32 begin = 3;
+
+ // Identifies the ending offset in bytes in the generated code that
+ // relates to the identified offset. The end offset should be one past
+ // the last relevant byte (so the length of the text = end - begin).
+ optional int32 end = 4;
+ }
+}
diff --git a/tests/inputs/example_service/example_service.proto b/tests/inputs/example_service/example_service.proto
new file mode 100644
index 0000000..96455cc
--- /dev/null
+++ b/tests/inputs/example_service/example_service.proto
@@ -0,0 +1,20 @@
+syntax = "proto3";
+
+package example_service;
+
+service Test {
+ rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse);
+ rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse);
+ rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse);
+ rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse);
+}
+
+message ExampleRequest {
+ string example_string = 1;
+ int64 example_integer = 2;
+}
+
+message ExampleResponse {
+ string example_string = 1;
+ int64 example_integer = 2;
+}
diff --git a/tests/inputs/example_service/test_example_service.py b/tests/inputs/example_service/test_example_service.py
new file mode 100644
index 0000000..551e3fe
--- /dev/null
+++ b/tests/inputs/example_service/test_example_service.py
@@ -0,0 +1,86 @@
+from typing import (
+ AsyncIterable,
+ AsyncIterator,
+)
+
+import pytest
+from grpclib.testing import ChannelFor
+
+from tests.output_aristaproto.example_service import (
+ ExampleRequest,
+ ExampleResponse,
+ TestBase,
+ TestStub,
+)
+
+
+class ExampleService(TestBase):
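+ """Echo service used by the test below; covers all four RPC cardinalities."""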
+ async def example_unary_unary(
+ self, example_request: ExampleRequest
+ ) -> "ExampleResponse":
+ return ExampleResponse(
+ example_string=example_request.example_string,
+ example_integer=example_request.example_integer,
+ )
+
+ async def example_unary_stream(
+ self, example_request: ExampleRequest
+ ) -> AsyncIterator["ExampleResponse"]:
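+ # Echo the request back three times to exercise server streaming.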
+ response = ExampleResponse(
+ example_string=example_request.example_string,
+ example_integer=example_request.example_integer,
+ )
+ yield response
+ yield response
+ yield response
+
+ async def example_stream_unary(
+ self, example_request_iterator: AsyncIterator["ExampleRequest"]
+ ) -> "ExampleResponse":
+ async for example_request in example_request_iterator:
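+ # Only the first streamed request is echoed back; returning ends the RPC.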
+ return ExampleResponse(
+ example_string=example_request.example_string,
+ example_integer=example_request.example_integer,
+ )
+
+ async def example_stream_stream(
+ self, example_request_iterator: AsyncIterator["ExampleRequest"]
+ ) -> AsyncIterator["ExampleResponse"]:
+ async for example_request in example_request_iterator:
+ yield ExampleResponse(
+ example_string=example_request.example_string,
+ example_integer=example_request.example_integer,
+ )
+
+
+@pytest.mark.asyncio
+async def test_calls_with_different_cardinalities():
+ example_request = ExampleRequest("test string", 42)
+
+ async with ChannelFor([ExampleService()]) as channel:
+ stub = TestStub(channel)
+
+ # unary unary
+ response = await stub.example_unary_unary(example_request)
+ assert response.example_string == example_request.example_string
+ assert response.example_integer == example_request.example_integer
+
+ # unary stream
+ async for response in stub.example_unary_stream(example_request):
+ assert response.example_string == example_request.example_string
+ assert response.example_integer == example_request.example_integer
+
+ # stream unary
+ async def request_iterator():
+ yield example_request
+ yield example_request
+ yield example_request
+
+ response = await stub.example_stream_unary(request_iterator())
+ assert response.example_string == example_request.example_string
+ assert response.example_integer == example_request.example_integer
+
+ # stream stream
+ async for response in stub.example_stream_stream(request_iterator()):
+ assert response.example_string == example_request.example_string
+ assert response.example_integer == example_request.example_integer
diff --git a/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.json b/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.json
new file mode 100644
index 0000000..7a6e7ae
--- /dev/null
+++ b/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.json
@@ -0,0 +1,7 @@
+{
+ "int": 26,
+ "float": 26.0,
+ "str": "value-for-str",
+ "bytes": "001a",
+ "bool": true
+} \ No newline at end of file
diff --git a/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.proto b/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.proto
new file mode 100644
index 0000000..81a0fc4
--- /dev/null
+++ b/tests/inputs/field_name_identical_to_type/field_name_identical_to_type.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+package field_name_identical_to_type;
+
+// Tests that messages may contain fields with names that are identical to their Python types (PR #294)
+
+message Test {
+ int32 int = 1;
+ float float = 2;
+ string str = 3;
+ bytes bytes = 4;
+ bool bool = 5;
+} \ No newline at end of file
diff --git a/tests/inputs/fixed/fixed.json b/tests/inputs/fixed/fixed.json
new file mode 100644
index 0000000..8858780
--- /dev/null
+++ b/tests/inputs/fixed/fixed.json
@@ -0,0 +1,6 @@
+{
+ "foo": 4294967295,
+ "bar": -2147483648,
+ "baz": "18446744073709551615",
+ "qux": "-9223372036854775808"
+}
diff --git a/tests/inputs/fixed/fixed.proto b/tests/inputs/fixed/fixed.proto
new file mode 100644
index 0000000..0f0ffb4
--- /dev/null
+++ b/tests/inputs/fixed/fixed.proto
@@ -0,0 +1,10 @@
+syntax = "proto3";
+
+package fixed;
+
+message Test {
+ fixed32 foo = 1;
+ sfixed32 bar = 2;
+ fixed64 baz = 3;
+ sfixed64 qux = 4;
+}
diff --git a/tests/inputs/float/float.json b/tests/inputs/float/float.json
new file mode 100644
index 0000000..3adac97
--- /dev/null
+++ b/tests/inputs/float/float.json
@@ -0,0 +1,9 @@
+{
+ "positive": "Infinity",
+ "negative": "-Infinity",
+ "nan": "NaN",
+ "three": 3.0,
+ "threePointOneFour": 3.14,
+ "negThree": -3.0,
+ "negThreePointOneFour": -3.14
+ }
diff --git a/tests/inputs/float/float.proto b/tests/inputs/float/float.proto
new file mode 100644
index 0000000..fea12b3
--- /dev/null
+++ b/tests/inputs/float/float.proto
@@ -0,0 +1,14 @@
+syntax = "proto3";
+
+package float;
+
+// Some documentation about the Test message.
+message Test {
+ double positive = 1;
+ double negative = 2;
+ double nan = 3;
+ double three = 4;
+ double three_point_one_four = 5;
+ double neg_three = 6;
+ double neg_three_point_one_four = 7;
+}
diff --git a/tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto b/tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto
new file mode 100644
index 0000000..66ef8a6
--- /dev/null
+++ b/tests/inputs/google_impl_behavior_equivalence/google_impl_behavior_equivalence.proto
@@ -0,0 +1,22 @@
+syntax = "proto3";
+
+import "google/protobuf/timestamp.proto";
+package google_impl_behavior_equivalence;
+
+message Foo { int64 bar = 1; }
+
+message Test {
+ oneof group {
+ string string = 1;
+ int64 integer = 2;
+ Foo foo = 3;
+ }
+}
+
+message Spam {
+ google.protobuf.Timestamp ts = 1;
+}
+
+message Request { Empty foo = 1; }
+
+message Empty {}
diff --git a/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py
new file mode 100644
index 0000000..c621f11
--- /dev/null
+++ b/tests/inputs/google_impl_behavior_equivalence/test_google_impl_behavior_equivalence.py
@@ -0,0 +1,93 @@
+from datetime import (
+ datetime,
+ timezone,
+)
+
+import pytest
+from google.protobuf import json_format
+from google.protobuf.timestamp_pb2 import Timestamp
+
+import aristaproto
+from tests.output_aristaproto.google_impl_behavior_equivalence import (
+ Empty,
+ Foo,
+ Request,
+ Spam,
+ Test,
+)
+from tests.output_reference.google_impl_behavior_equivalence.google_impl_behavior_equivalence_pb2 import (
+ Empty as ReferenceEmpty,
+ Foo as ReferenceFoo,
+ Request as ReferenceRequest,
+ Spam as ReferenceSpam,
+ Test as ReferenceTest,
+)
+
+
+def test_oneof_serializes_similar_to_google_oneof():
+ tests = [
+ (Test(string="abc"), ReferenceTest(string="abc")),
+ (Test(integer=2), ReferenceTest(integer=2)),
+ (Test(foo=Foo(bar=1)), ReferenceTest(foo=ReferenceFoo(bar=1))),
+ # Default values should also behave the same within oneofs
+ (Test(string=""), ReferenceTest(string="")),
+ (Test(integer=0), ReferenceTest(integer=0)),
+ (Test(foo=Foo(bar=0)), ReferenceTest(foo=ReferenceFoo(bar=0))),
+ ]
+ for message, message_reference in tests:
+ # NOTE: As of July 2020, MessageToJson inserts newlines in the output string,
+ # so just compare dicts
+ assert message.to_dict() == json_format.MessageToDict(message_reference)
+
+
+def test_bytes_are_the_same_for_oneof():
+ message = Test(string="")
+ message_reference = ReferenceTest(string="")
+
+ message_bytes = bytes(message)
+ message_reference_bytes = message_reference.SerializeToString()
+
+ assert message_bytes == message_reference_bytes
+
+ message2 = Test().parse(message_reference_bytes)
+ message_reference2 = ReferenceTest()
+ message_reference2.ParseFromString(message_reference_bytes)
+
+ assert message == message2
+ assert message_reference == message_reference2
+
+ # None of these fields were explicitly set BUT they should not actually be null
+ # themselves
+ assert not hasattr(message, "foo")
+ assert object.__getattribute__(message, "foo") == aristaproto.PLACEHOLDER
+ assert not hasattr(message2, "foo")
+ assert object.__getattribute__(message2, "foo") == aristaproto.PLACEHOLDER
+
+ assert isinstance(message_reference.foo, ReferenceFoo)
+ assert isinstance(message_reference2.foo, ReferenceFoo)
+
+
+@pytest.mark.parametrize("dt", (datetime.min.replace(tzinfo=timezone.utc),))
+def test_datetime_clamping(dt): # see #407
+ ts = Timestamp()
+ ts.FromDatetime(dt)
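+ # Both implementations should serialize the minimum datetime to the same bytes.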
+ assert bytes(Spam(dt)) == ReferenceSpam(ts=ts).SerializeToString()
+ message_bytes = bytes(Spam(dt))
+
+ assert (
+ Spam().parse(message_bytes).ts.timestamp()
+ == ReferenceSpam.FromString(message_bytes).ts.seconds
+ )
+
+
+def test_empty_message_field():
+ message = Request()
+ reference_message = ReferenceRequest()
+
+ message.foo = Empty()
+ reference_message.foo.CopyFrom(ReferenceEmpty())
+
+ assert aristaproto.serialized_on_wire(message.foo)
+ assert reference_message.HasField("foo")
+
+ assert bytes(message) == reference_message.SerializeToString()
diff --git a/tests/inputs/googletypes/googletypes-missing.json b/tests/inputs/googletypes/googletypes-missing.json
new file mode 100644
index 0000000..0967ef4
--- /dev/null
+++ b/tests/inputs/googletypes/googletypes-missing.json
@@ -0,0 +1 @@
+{}
diff --git a/tests/inputs/googletypes/googletypes.json b/tests/inputs/googletypes/googletypes.json
new file mode 100644
index 0000000..0a002e9
--- /dev/null
+++ b/tests/inputs/googletypes/googletypes.json
@@ -0,0 +1,7 @@
+{
+ "maybe": false,
+ "ts": "1972-01-01T10:00:20.021Z",
+ "duration": "1.200s",
+ "important": 10,
+ "empty": {}
+}
diff --git a/tests/inputs/googletypes/googletypes.proto b/tests/inputs/googletypes/googletypes.proto
new file mode 100644
index 0000000..ef8cb4a
--- /dev/null
+++ b/tests/inputs/googletypes/googletypes.proto
@@ -0,0 +1,16 @@
+syntax = "proto3";
+
+package googletypes;
+
+import "google/protobuf/duration.proto";
+import "google/protobuf/timestamp.proto";
+import "google/protobuf/wrappers.proto";
+import "google/protobuf/empty.proto";
+
+message Test {
+ google.protobuf.BoolValue maybe = 1;
+ google.protobuf.Timestamp ts = 2;
+ google.protobuf.Duration duration = 3;
+ google.protobuf.Int32Value important = 4;
+ google.protobuf.Empty empty = 5;
+}
diff --git a/tests/inputs/googletypes_request/googletypes_request.proto b/tests/inputs/googletypes_request/googletypes_request.proto
new file mode 100644
index 0000000..1cedcaa
--- /dev/null
+++ b/tests/inputs/googletypes_request/googletypes_request.proto
@@ -0,0 +1,29 @@
+syntax = "proto3";
+
+package googletypes_request;
+
+import "google/protobuf/duration.proto";
+import "google/protobuf/empty.proto";
+import "google/protobuf/timestamp.proto";
+import "google/protobuf/wrappers.proto";
+
+// Tests that google types can be used as params
+
+service Test {
+ rpc SendDouble (google.protobuf.DoubleValue) returns (Input);
+ rpc SendFloat (google.protobuf.FloatValue) returns (Input);
+ rpc SendInt64 (google.protobuf.Int64Value) returns (Input);
+ rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input);
+ rpc SendInt32 (google.protobuf.Int32Value) returns (Input);
+ rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input);
+ rpc SendBool (google.protobuf.BoolValue) returns (Input);
+ rpc SendString (google.protobuf.StringValue) returns (Input);
+ rpc SendBytes (google.protobuf.BytesValue) returns (Input);
+ rpc SendDatetime (google.protobuf.Timestamp) returns (Input);
+ rpc SendTimedelta (google.protobuf.Duration) returns (Input);
+ rpc SendEmpty (google.protobuf.Empty) returns (Input);
+}
+
+message Input {
+
+}
diff --git a/tests/inputs/googletypes_request/test_googletypes_request.py b/tests/inputs/googletypes_request/test_googletypes_request.py
new file mode 100644
index 0000000..8351f71
--- /dev/null
+++ b/tests/inputs/googletypes_request/test_googletypes_request.py
@@ -0,0 +1,47 @@
+from datetime import (
+ datetime,
+ timedelta,
+)
+from typing import (
+ Any,
+ Callable,
+)
+
+import pytest
+
+import aristaproto.lib.google.protobuf as protobuf
+from tests.mocks import MockChannel
+from tests.output_aristaproto.googletypes_request import (
+ Input,
+ TestStub,
+)
+
+
+test_cases = [
+ (TestStub.send_double, protobuf.DoubleValue, 2.5),
+ (TestStub.send_float, protobuf.FloatValue, 2.5),
+ (TestStub.send_int64, protobuf.Int64Value, -64),
+ (TestStub.send_u_int64, protobuf.UInt64Value, 64),
+ (TestStub.send_int32, protobuf.Int32Value, -32),
+ (TestStub.send_u_int32, protobuf.UInt32Value, 32),
+ (TestStub.send_bool, protobuf.BoolValue, True),
+ (TestStub.send_string, protobuf.StringValue, "string"),
+ (TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
+ (TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)),
+ (TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)),
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
+async def test_channel_receives_wrapped_type(
+ service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
+):
+ wrapped_value = wrapper_class()
+ wrapped_value.value = value
+ channel = MockChannel(responses=[Input()])
+ service = TestStub(channel)
+
+ await service_method(service, wrapped_value)
+
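+ # The recorded request should be the wrapper message type itself, i.e. the
+ # value is sent wrapped rather than converted to a bare Python type.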
+ assert channel.requests[0]["request"] == type(wrapped_value)
diff --git a/tests/inputs/googletypes_response/googletypes_response.proto b/tests/inputs/googletypes_response/googletypes_response.proto
new file mode 100644
index 0000000..8917d1c
--- /dev/null
+++ b/tests/inputs/googletypes_response/googletypes_response.proto
@@ -0,0 +1,23 @@
+syntax = "proto3";
+
+package googletypes_response;
+
+import "google/protobuf/wrappers.proto";
+
+// Tests that wrapped values can be used directly as return values
+
+service Test {
+ rpc GetDouble (Input) returns (google.protobuf.DoubleValue);
+ rpc GetFloat (Input) returns (google.protobuf.FloatValue);
+ rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
+ rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value);
+ rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
+ rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value);
+ rpc GetBool (Input) returns (google.protobuf.BoolValue);
+ rpc GetString (Input) returns (google.protobuf.StringValue);
+ rpc GetBytes (Input) returns (google.protobuf.BytesValue);
+}
+
+message Input {
+
+}
diff --git a/tests/inputs/googletypes_response/test_googletypes_response.py b/tests/inputs/googletypes_response/test_googletypes_response.py
new file mode 100644
index 0000000..4ac340e
--- /dev/null
+++ b/tests/inputs/googletypes_response/test_googletypes_response.py
@@ -0,0 +1,64 @@
+from typing import (
+ Any,
+ Callable,
+ Optional,
+)
+
+import pytest
+
+import aristaproto.lib.google.protobuf as protobuf
+from tests.mocks import MockChannel
+from tests.output_aristaproto.googletypes_response import (
+ Input,
+ TestStub,
+)
+
+
+test_cases = [
+ (TestStub.get_double, protobuf.DoubleValue, 2.5),
+ (TestStub.get_float, protobuf.FloatValue, 2.5),
+ (TestStub.get_int64, protobuf.Int64Value, -64),
+ (TestStub.get_u_int64, protobuf.UInt64Value, 64),
+ (TestStub.get_int32, protobuf.Int32Value, -32),
+ (TestStub.get_u_int32, protobuf.UInt32Value, 32),
+ (TestStub.get_bool, protobuf.BoolValue, True),
+ (TestStub.get_string, protobuf.StringValue, "string"),
+ (TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
+async def test_channel_receives_wrapped_type(
+ service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
+):
+ wrapped_value = wrapper_class()
+ wrapped_value.value = value
+ channel = MockChannel(responses=[wrapped_value])
+ service = TestStub(channel)
+ method_param = Input()
+
+ await service_method(service, method_param)
+
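+ # The declared response type should be the wrapper message, not the bare
+ # (optional) Python type of the value.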
+ assert channel.requests[0]["response_type"] != Optional[type(value)]
+ assert channel.requests[0]["response_type"] == type(wrapped_value)
+
+
+@pytest.mark.asyncio
+@pytest.mark.xfail
+@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
+async def test_service_unwraps_response(
+ service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
+):
+ """
+ grpclib does not unwrap wrapper values returned by services
+ """
+ wrapped_value = wrapper_class()
+ wrapped_value.value = value
+ service = TestStub(MockChannel(responses=[wrapped_value]))
+ method_param = Input()
+
+ response_value = await service_method(service, method_param)
+
+ assert response_value == value
+ assert type(response_value) == type(value)
diff --git a/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto b/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto
new file mode 100644
index 0000000..47284e3
--- /dev/null
+++ b/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto
@@ -0,0 +1,26 @@
+syntax = "proto3";
+
+package googletypes_response_embedded;
+
+import "google/protobuf/wrappers.proto";
+
+// Tests that wrapped values are supported as part of an output message
+service Test {
+ rpc getOutput (Input) returns (Output);
+}
+
+message Input {
+
+}
+
+message Output {
+ google.protobuf.DoubleValue double_value = 1;
+ google.protobuf.FloatValue float_value = 2;
+ google.protobuf.Int64Value int64_value = 3;
+ google.protobuf.UInt64Value uint64_value = 4;
+ google.protobuf.Int32Value int32_value = 5;
+ google.protobuf.UInt32Value uint32_value = 6;
+ google.protobuf.BoolValue bool_value = 7;
+ google.protobuf.StringValue string_value = 8;
+ google.protobuf.BytesValue bytes_value = 9;
+}
diff --git a/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py b/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py
new file mode 100644
index 0000000..3d31728
--- /dev/null
+++ b/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py
@@ -0,0 +1,40 @@
+import pytest
+
+from tests.mocks import MockChannel
+from tests.output_aristaproto.googletypes_response_embedded import (
+ Input,
+ Output,
+ TestStub,
+)
+
+
+@pytest.mark.asyncio
+async def test_service_passes_through_unwrapped_values_embedded_in_response():
+ """
+ We do not need to implement value unwrapping for embedded well-known types,
+ as this is already handled by grpclib. This test merely shows that this is the case.
+ """
+ output = Output(
+ double_value=10.0,
+ float_value=12.0,
+ int64_value=-13,
+ uint64_value=14,
+ int32_value=-15,
+ uint32_value=16,
+ bool_value=True,
+ string_value="string",
+ bytes_value=bytes(0xFF)[0:4],
+ )
+
+ service = TestStub(MockChannel(responses=[output]))
+ response = await service.get_output(Input())
+
+ assert response.double_value == 10.0
+ assert response.float_value == 12.0
+ assert response.int64_value == -13
+ assert response.uint64_value == 14
+ assert response.int32_value == -15
+ assert response.uint32_value == 16
+ assert response.bool_value
+ assert response.string_value == "string"
+ assert response.bytes_value == bytes(0xFF)[0:4]
diff --git a/tests/inputs/googletypes_service_returns_empty/googletypes_service_returns_empty.proto b/tests/inputs/googletypes_service_returns_empty/googletypes_service_returns_empty.proto
new file mode 100644
index 0000000..2153ad5
--- /dev/null
+++ b/tests/inputs/googletypes_service_returns_empty/googletypes_service_returns_empty.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+package googletypes_service_returns_empty;
+
+import "google/protobuf/empty.proto";
+
+service Test {
+ rpc Send (RequestMessage) returns (google.protobuf.Empty) {
+ }
+}
+
+message RequestMessage {
+} \ No newline at end of file
diff --git a/tests/inputs/googletypes_service_returns_googletype/googletypes_service_returns_googletype.proto b/tests/inputs/googletypes_service_returns_googletype/googletypes_service_returns_googletype.proto
new file mode 100644
index 0000000..457707b
--- /dev/null
+++ b/tests/inputs/googletypes_service_returns_googletype/googletypes_service_returns_googletype.proto
@@ -0,0 +1,18 @@
+syntax = "proto3";
+
+package googletypes_service_returns_googletype;
+
+import "google/protobuf/empty.proto";
+import "google/protobuf/struct.proto";
+
+// Tests that imports are generated correctly when returning Google well-known types
+
+service Test {
+ rpc GetEmpty (RequestMessage) returns (google.protobuf.Empty);
+ rpc GetStruct (RequestMessage) returns (google.protobuf.Struct);
+ rpc GetListValue (RequestMessage) returns (google.protobuf.ListValue);
+ rpc GetValue (RequestMessage) returns (google.protobuf.Value);
+}
+
+message RequestMessage {
+} \ No newline at end of file
diff --git a/tests/inputs/googletypes_struct/googletypes_struct.json b/tests/inputs/googletypes_struct/googletypes_struct.json
new file mode 100644
index 0000000..ecc175e
--- /dev/null
+++ b/tests/inputs/googletypes_struct/googletypes_struct.json
@@ -0,0 +1,5 @@
+{
+ "struct": {
+ "key": true
+ }
+}
diff --git a/tests/inputs/googletypes_struct/googletypes_struct.proto b/tests/inputs/googletypes_struct/googletypes_struct.proto
new file mode 100644
index 0000000..2b8b5c5
--- /dev/null
+++ b/tests/inputs/googletypes_struct/googletypes_struct.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+package googletypes_struct;
+
+import "google/protobuf/struct.proto";
+
+message Test {
+ google.protobuf.Struct struct = 1;
+}
diff --git a/tests/inputs/googletypes_value/googletypes_value.json b/tests/inputs/googletypes_value/googletypes_value.json
new file mode 100644
index 0000000..db52d5c
--- /dev/null
+++ b/tests/inputs/googletypes_value/googletypes_value.json
@@ -0,0 +1,11 @@
+{
+ "value1": "hello world",
+ "value2": true,
+ "value3": 1,
+ "value4": null,
+ "value5": [
+ 1,
+ 2,
+ 3
+ ]
+}
diff --git a/tests/inputs/googletypes_value/googletypes_value.proto b/tests/inputs/googletypes_value/googletypes_value.proto
new file mode 100644
index 0000000..d5089d5
--- /dev/null
+++ b/tests/inputs/googletypes_value/googletypes_value.proto
@@ -0,0 +1,15 @@
+syntax = "proto3";
+
+package googletypes_value;
+
+import "google/protobuf/struct.proto";
+
+// Tests that fields of type google.protobuf.Value can contain arbitrary JSON values.
+
+message Test {
+ google.protobuf.Value value1 = 1;
+ google.protobuf.Value value2 = 2;
+ google.protobuf.Value value3 = 3;
+ google.protobuf.Value value4 = 4;
+ google.protobuf.Value value5 = 5;
+}
diff --git a/tests/inputs/import_capitalized_package/capitalized.proto b/tests/inputs/import_capitalized_package/capitalized.proto
new file mode 100644
index 0000000..e80c95c
--- /dev/null
+++ b/tests/inputs/import_capitalized_package/capitalized.proto
@@ -0,0 +1,8 @@
+syntax = "proto3";
+
+
+package import_capitalized_package.Capitalized;
+
+message Message {
+
+}
diff --git a/tests/inputs/import_capitalized_package/test.proto b/tests/inputs/import_capitalized_package/test.proto
new file mode 100644
index 0000000..38c9b2d
--- /dev/null
+++ b/tests/inputs/import_capitalized_package/test.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package import_capitalized_package;
+
+import "capitalized.proto";
+
+// Tests that we can import from a package with a capitalized name that looks like a nested type, but isn't.
+
+message Test {
+ Capitalized.Message message = 1;
+}
diff --git a/tests/inputs/import_child_package_from_package/child.proto b/tests/inputs/import_child_package_from_package/child.proto
new file mode 100644
index 0000000..d99c7c3
--- /dev/null
+++ b/tests/inputs/import_child_package_from_package/child.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package import_child_package_from_package.package.childpackage;
+
+message ChildMessage {
+
+}
diff --git a/tests/inputs/import_child_package_from_package/import_child_package_from_package.proto b/tests/inputs/import_child_package_from_package/import_child_package_from_package.proto
new file mode 100644
index 0000000..66e0aa8
--- /dev/null
+++ b/tests/inputs/import_child_package_from_package/import_child_package_from_package.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package import_child_package_from_package;
+
+import "package_message.proto";
+
+// Tests generated imports when a message in a package refers to a message in a nested child package.
+
+message Test {
+ package.PackageMessage message = 1;
+}
diff --git a/tests/inputs/import_child_package_from_package/package_message.proto b/tests/inputs/import_child_package_from_package/package_message.proto
new file mode 100644
index 0000000..79d66f3
--- /dev/null
+++ b/tests/inputs/import_child_package_from_package/package_message.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+import "child.proto";
+
+package import_child_package_from_package.package;
+
+message PackageMessage {
+ package.childpackage.ChildMessage c = 1;
+}
diff --git a/tests/inputs/import_child_package_from_root/child.proto b/tests/inputs/import_child_package_from_root/child.proto
new file mode 100644
index 0000000..2a46d5f
--- /dev/null
+++ b/tests/inputs/import_child_package_from_root/child.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package import_child_package_from_root.childpackage;
+
+message Message {
+
+}
diff --git a/tests/inputs/import_child_package_from_root/import_child_package_from_root.proto b/tests/inputs/import_child_package_from_root/import_child_package_from_root.proto
new file mode 100644
index 0000000..6299831
--- /dev/null
+++ b/tests/inputs/import_child_package_from_root/import_child_package_from_root.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package import_child_package_from_root;
+
+import "child.proto";
+
+// Tests generated imports when a message in root refers to a message in a child package.
+
+message Test {
+ childpackage.Message child = 1;
+}
diff --git a/tests/inputs/import_circular_dependency/import_circular_dependency.proto b/tests/inputs/import_circular_dependency/import_circular_dependency.proto
new file mode 100644
index 0000000..8b159e2
--- /dev/null
+++ b/tests/inputs/import_circular_dependency/import_circular_dependency.proto
@@ -0,0 +1,30 @@
+syntax = "proto3";
+
+package import_circular_dependency;
+
+import "root.proto";
+import "other.proto";
+
+// This test case verifies support for circular dependencies in the generated Python files.
+//
+// This is important because we generate one Python module per package, rather than one file per proto file.
+//
+// Scenario:
+//
+// The proto messages depend on each other in a non-circular way:
+//
+// Test -------> RootPackageMessage <--------------.
+// `------------------------------------> OtherPackageMessage
+//
+// Test and RootPackageMessage are in different files, but belong to the same package (root):
+//
+// (Test -------> RootPackageMessage) <------------.
+// `------------------------------------> OtherPackageMessage
+//
+// After grouping the packages into single files or modules, a circular dependency is created:
+//
+// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
+message Test {
+ RootPackageMessage message = 1;
+ other.OtherPackageMessage other = 2;
+}
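To make the circular-import scenario above concrete, here is a hedged usage sketch of the generated modules. The module paths and snake_case field names follow the tests.output_aristaproto.<package> conventions seen in the other tests in this diff; they are assumptions rather than actual generated output:

    from tests.output_aristaproto.import_circular_dependency import (
        RootPackageMessage,
        Test,
    )
    from tests.output_aristaproto.import_circular_dependency.other import (
        OtherPackageMessage,
    )


    def test_circular_modules_round_trip():
        # Building a Test that spans both generated modules exercises the circular
        # import between the root package and the "other" package.
        message = Test(
            message=RootPackageMessage(),
            other=OtherPackageMessage(root_package_message=RootPackageMessage()),
        )
        assert Test().parse(bytes(message)) == message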
diff --git a/tests/inputs/import_circular_dependency/other.proto b/tests/inputs/import_circular_dependency/other.proto
new file mode 100644
index 0000000..833b869
--- /dev/null
+++ b/tests/inputs/import_circular_dependency/other.proto
@@ -0,0 +1,8 @@
+syntax = "proto3";
+
+import "root.proto";
+package import_circular_dependency.other;
+
+message OtherPackageMessage {
+ RootPackageMessage rootPackageMessage = 1;
+}
diff --git a/tests/inputs/import_circular_dependency/root.proto b/tests/inputs/import_circular_dependency/root.proto
new file mode 100644
index 0000000..7383947
--- /dev/null
+++ b/tests/inputs/import_circular_dependency/root.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package import_circular_dependency;
+
+message RootPackageMessage {
+
+}
diff --git a/tests/inputs/import_cousin_package/cousin.proto b/tests/inputs/import_cousin_package/cousin.proto
new file mode 100644
index 0000000..2870dfe
--- /dev/null
+++ b/tests/inputs/import_cousin_package/cousin.proto
@@ -0,0 +1,6 @@
+syntax = "proto3";
+
+package import_cousin_package.cousin.cousin_subpackage;
+
+message CousinMessage {
+}
diff --git a/tests/inputs/import_cousin_package/test.proto b/tests/inputs/import_cousin_package/test.proto
new file mode 100644
index 0000000..89ec3d8
--- /dev/null
+++ b/tests/inputs/import_cousin_package/test.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package import_cousin_package.test.subpackage;
+
+import "cousin.proto";
+
+// Verify that we can import a message unrelated to us
+
+message Test {
+ cousin.cousin_subpackage.CousinMessage message = 1;
+}
diff --git a/tests/inputs/import_cousin_package_same_name/cousin.proto b/tests/inputs/import_cousin_package_same_name/cousin.proto
new file mode 100644
index 0000000..84b6a40
--- /dev/null
+++ b/tests/inputs/import_cousin_package_same_name/cousin.proto
@@ -0,0 +1,6 @@
+syntax = "proto3";
+
+package import_cousin_package_same_name.cousin.subpackage;
+
+message CousinMessage {
+}
diff --git a/tests/inputs/import_cousin_package_same_name/test.proto b/tests/inputs/import_cousin_package_same_name/test.proto
new file mode 100644
index 0000000..7b420d3
--- /dev/null
+++ b/tests/inputs/import_cousin_package_same_name/test.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package import_cousin_package_same_name.test.subpackage;
+
+import "cousin.proto";
+
+// Verify that we can import a message unrelated to us from a subpackage with the same name as ours.
+
+message Test {
+ cousin.subpackage.CousinMessage message = 1;
+}
diff --git a/tests/inputs/import_packages_same_name/import_packages_same_name.proto b/tests/inputs/import_packages_same_name/import_packages_same_name.proto
new file mode 100644
index 0000000..dff7efe
--- /dev/null
+++ b/tests/inputs/import_packages_same_name/import_packages_same_name.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+package import_packages_same_name;
+
+import "users_v1.proto";
+import "posts_v1.proto";
+
+// Tests that a generated message can correctly reference two packages with the same leaf name
+
+message Test {
+ users.v1.User user = 1;
+ posts.v1.Post post = 2;
+}
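As a rough illustration of what this test expects, the two leaf packages named v1 should map to distinct generated submodules. The module paths below follow the same tests.output_aristaproto.<package> convention as the other tests here and are assumptions, not part of this change:

    from tests.output_aristaproto.import_packages_same_name import Test
    from tests.output_aristaproto.import_packages_same_name.posts.v1 import Post
    from tests.output_aristaproto.import_packages_same_name.users.v1 import User


    def test_same_leaf_name_packages_resolve():
        # users.v1.User and posts.v1.Post share the leaf name "v1" but come from
        # two different generated modules.
        message = Test(user=User(), post=Post())
        assert Test().parse(bytes(message)) == message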
diff --git a/tests/inputs/import_packages_same_name/posts_v1.proto b/tests/inputs/import_packages_same_name/posts_v1.proto
new file mode 100644
index 0000000..d3b9b1c
--- /dev/null
+++ b/tests/inputs/import_packages_same_name/posts_v1.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package import_packages_same_name.posts.v1;
+
+message Post {
+
+}
diff --git a/tests/inputs/import_packages_same_name/users_v1.proto b/tests/inputs/import_packages_same_name/users_v1.proto
new file mode 100644
index 0000000..d3a17e9
--- /dev/null
+++ b/tests/inputs/import_packages_same_name/users_v1.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package import_packages_same_name.users.v1;
+
+message User {
+
+}
diff --git a/tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto b/tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto
new file mode 100644
index 0000000..edc4736
--- /dev/null
+++ b/tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+import "parent_package_message.proto";
+
+package import_parent_package_from_child.parent.child;
+
+// Tests generated imports when a message refers to a message defined in its parent package
+
+message Test {
+ ParentPackageMessage message_implicit = 1;
+ parent.ParentPackageMessage message_explicit = 2;
+}
diff --git a/tests/inputs/import_parent_package_from_child/parent_package_message.proto b/tests/inputs/import_parent_package_from_child/parent_package_message.proto
new file mode 100644
index 0000000..fb3fd31
--- /dev/null
+++ b/tests/inputs/import_parent_package_from_child/parent_package_message.proto
@@ -0,0 +1,6 @@
+syntax = "proto3";
+
+package import_parent_package_from_child.parent;
+
+message ParentPackageMessage {
+}
diff --git a/tests/inputs/import_root_package_from_child/child.proto b/tests/inputs/import_root_package_from_child/child.proto
new file mode 100644
index 0000000..bd51967
--- /dev/null
+++ b/tests/inputs/import_root_package_from_child/child.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package import_root_package_from_child.child;
+
+import "root.proto";
+
+// Verify that we can import a root message from a child package
+
+message Test {
+ RootMessage message = 1;
+}
diff --git a/tests/inputs/import_root_package_from_child/root.proto b/tests/inputs/import_root_package_from_child/root.proto
new file mode 100644
index 0000000..6ae955a
--- /dev/null
+++ b/tests/inputs/import_root_package_from_child/root.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package import_root_package_from_child;
+
+
+message RootMessage {
+}
diff --git a/tests/inputs/import_root_sibling/import_root_sibling.proto b/tests/inputs/import_root_sibling/import_root_sibling.proto
new file mode 100644
index 0000000..759e606
--- /dev/null
+++ b/tests/inputs/import_root_sibling/import_root_sibling.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package import_root_sibling;
+
+import "sibling.proto";
+
+// Tests generated imports when a message in the root package refers to another message in the root package
+
+message Test {
+ SiblingMessage sibling = 1;
+}
diff --git a/tests/inputs/import_root_sibling/sibling.proto b/tests/inputs/import_root_sibling/sibling.proto
new file mode 100644
index 0000000..6b6ba2e
--- /dev/null
+++ b/tests/inputs/import_root_sibling/sibling.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package import_root_sibling;
+
+message SiblingMessage {
+
+}
diff --git a/tests/inputs/import_service_input_message/child_package_request_message.proto b/tests/inputs/import_service_input_message/child_package_request_message.proto
new file mode 100644
index 0000000..54fc112
--- /dev/null
+++ b/tests/inputs/import_service_input_message/child_package_request_message.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package import_service_input_message.child;
+
+message ChildRequestMessage {
+ int32 child_argument = 1;
+} \ No newline at end of file
diff --git a/tests/inputs/import_service_input_message/import_service_input_message.proto b/tests/inputs/import_service_input_message/import_service_input_message.proto
new file mode 100644
index 0000000..cbf48fa
--- /dev/null
+++ b/tests/inputs/import_service_input_message/import_service_input_message.proto
@@ -0,0 +1,25 @@
+syntax = "proto3";
+
+package import_service_input_message;
+
+import "request_message.proto";
+import "child_package_request_message.proto";
+
+// Tests that the generated service correctly imports the RequestMessage
+
+service Test {
+ rpc DoThing (RequestMessage) returns (RequestResponse);
+ rpc DoThing2 (child.ChildRequestMessage) returns (RequestResponse);
+ rpc DoThing3 (Nested.RequestMessage) returns (RequestResponse);
+}
+
+
+message RequestResponse {
+ int32 value = 1;
+}
+
+message Nested {
+ message RequestMessage {
+ int32 nestedArgument = 1;
+ }
+} \ No newline at end of file
diff --git a/tests/inputs/import_service_input_message/request_message.proto b/tests/inputs/import_service_input_message/request_message.proto
new file mode 100644
index 0000000..36a6e78
--- /dev/null
+++ b/tests/inputs/import_service_input_message/request_message.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package import_service_input_message;
+
+message RequestMessage {
+ int32 argument = 1;
+} \ No newline at end of file
diff --git a/tests/inputs/import_service_input_message/test_import_service_input_message.py b/tests/inputs/import_service_input_message/test_import_service_input_message.py
new file mode 100644
index 0000000..66c654b
--- /dev/null
+++ b/tests/inputs/import_service_input_message/test_import_service_input_message.py
@@ -0,0 +1,36 @@
+import pytest
+
+from tests.mocks import MockChannel
+from tests.output_aristaproto.import_service_input_message import (
+ NestedRequestMessage,
+ RequestMessage,
+ RequestResponse,
+ TestStub,
+)
+from tests.output_aristaproto.import_service_input_message.child import (
+ ChildRequestMessage,
+)
+
+
+@pytest.mark.asyncio
+async def test_service_correctly_imports_reference_message():
+ mock_response = RequestResponse(value=10)
+ service = TestStub(MockChannel([mock_response]))
+ response = await service.do_thing(RequestMessage(1))
+ assert mock_response == response
+
+
+@pytest.mark.asyncio
+async def test_service_correctly_imports_reference_message_from_child_package():
+ mock_response = RequestResponse(value=10)
+ service = TestStub(MockChannel([mock_response]))
+ response = await service.do_thing2(ChildRequestMessage(1))
+ assert mock_response == response
+
+
+@pytest.mark.asyncio
+async def test_service_correctly_imports_nested_reference():
+ mock_response = RequestResponse(value=10)
+ service = TestStub(MockChannel([mock_response]))
+ response = await service.do_thing3(NestedRequestMessage(1))
+ assert mock_response == response
diff --git a/tests/inputs/int32/int32.json b/tests/inputs/int32/int32.json
new file mode 100644
index 0000000..34d4111
--- /dev/null
+++ b/tests/inputs/int32/int32.json
@@ -0,0 +1,4 @@
+{
+ "positive": 150,
+ "negative": -150
+}
diff --git a/tests/inputs/int32/int32.proto b/tests/inputs/int32/int32.proto
new file mode 100644
index 0000000..4721c23
--- /dev/null
+++ b/tests/inputs/int32/int32.proto
@@ -0,0 +1,10 @@
+syntax = "proto3";
+
+package int32;
+
+// Some documentation about the Test message.
+message Test {
+ // Some documentation about the count.
+ int32 positive = 1;
+ int32 negative = 2;
+}
diff --git a/tests/inputs/map/map.json b/tests/inputs/map/map.json
new file mode 100644
index 0000000..6a1e853
--- /dev/null
+++ b/tests/inputs/map/map.json
@@ -0,0 +1,7 @@
+{
+ "counts": {
+ "item1": 1,
+ "item2": 2,
+ "item3": 3
+ }
+}
diff --git a/tests/inputs/map/map.proto b/tests/inputs/map/map.proto
new file mode 100644
index 0000000..ecef3cc
--- /dev/null
+++ b/tests/inputs/map/map.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package map;
+
+message Test {
+ map<string, int32> counts = 1;
+}
diff --git a/tests/inputs/mapmessage/mapmessage.json b/tests/inputs/mapmessage/mapmessage.json
new file mode 100644
index 0000000..a944ddd
--- /dev/null
+++ b/tests/inputs/mapmessage/mapmessage.json
@@ -0,0 +1,10 @@
+{
+ "items": {
+ "foo": {
+ "count": 1
+ },
+ "bar": {
+ "count": 2
+ }
+ }
+}
diff --git a/tests/inputs/mapmessage/mapmessage.proto b/tests/inputs/mapmessage/mapmessage.proto
new file mode 100644
index 0000000..2c704a4
--- /dev/null
+++ b/tests/inputs/mapmessage/mapmessage.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package mapmessage;
+
+message Test {
+ map<string, Nested> items = 1;
+}
+
+message Nested {
+ int32 count = 1;
+} \ No newline at end of file
diff --git a/tests/inputs/namespace_builtin_types/namespace_builtin_types.json b/tests/inputs/namespace_builtin_types/namespace_builtin_types.json
new file mode 100644
index 0000000..8200032
--- /dev/null
+++ b/tests/inputs/namespace_builtin_types/namespace_builtin_types.json
@@ -0,0 +1,16 @@
+{
+ "int": "value-for-int",
+ "float": "value-for-float",
+ "complex": "value-for-complex",
+ "list": "value-for-list",
+ "tuple": "value-for-tuple",
+ "range": "value-for-range",
+ "str": "value-for-str",
+ "bytearray": "value-for-bytearray",
+ "bytes": "value-for-bytes",
+ "memoryview": "value-for-memoryview",
+ "set": "value-for-set",
+ "frozenset": "value-for-frozenset",
+ "map": "value-for-map",
+ "bool": "value-for-bool"
+} \ No newline at end of file
diff --git a/tests/inputs/namespace_builtin_types/namespace_builtin_types.proto b/tests/inputs/namespace_builtin_types/namespace_builtin_types.proto
new file mode 100644
index 0000000..71cb029
--- /dev/null
+++ b/tests/inputs/namespace_builtin_types/namespace_builtin_types.proto
@@ -0,0 +1,40 @@
+syntax = "proto3";
+
+package namespace_builtin_types;
+
+// Tests that messages may contain fields with names that are built-in Python types
+
+message Test {
+ // https://docs.python.org/2/library/stdtypes.html#numeric-types-int-float-long-complex
+ string int = 1;
+ string float = 2;
+ string complex = 3;
+
+ // https://docs.python.org/3/library/stdtypes.html#sequence-types-list-tuple-range
+ string list = 4;
+ string tuple = 5;
+ string range = 6;
+
+ // https://docs.python.org/3/library/stdtypes.html#str
+ string str = 7;
+
+ // https://docs.python.org/3/library/stdtypes.html#bytearray-objects
+ string bytearray = 8;
+
+ // https://docs.python.org/3/library/stdtypes.html#bytes-and-bytearray-operations
+ string bytes = 9;
+
+ // https://docs.python.org/3/library/stdtypes.html#memory-views
+ string memoryview = 10;
+
+ // https://docs.python.org/3/library/stdtypes.html#set-types-set-frozenset
+ string set = 11;
+ string frozenset = 12;
+
+ // https://docs.python.org/3/library/stdtypes.html#dict
+ string map = 13;
+ string dict = 14;
+
+ // https://docs.python.org/3/library/stdtypes.html#boolean-values
+ string bool = 15;
+} \ No newline at end of file
diff --git a/tests/inputs/namespace_keywords/namespace_keywords.json b/tests/inputs/namespace_keywords/namespace_keywords.json
new file mode 100644
index 0000000..4f11b60
--- /dev/null
+++ b/tests/inputs/namespace_keywords/namespace_keywords.json
@@ -0,0 +1,37 @@
+{
+ "False": 1,
+ "None": 2,
+ "True": 3,
+ "and": 4,
+ "as": 5,
+ "assert": 6,
+ "async": 7,
+ "await": 8,
+ "break": 9,
+ "class": 10,
+ "continue": 11,
+ "def": 12,
+ "del": 13,
+ "elif": 14,
+ "else": 15,
+ "except": 16,
+ "finally": 17,
+ "for": 18,
+ "from": 19,
+ "global": 20,
+ "if": 21,
+ "import": 22,
+ "in": 23,
+ "is": 24,
+ "lambda": 25,
+ "nonlocal": 26,
+ "not": 27,
+ "or": 28,
+ "pass": 29,
+ "raise": 30,
+ "return": 31,
+ "try": 32,
+ "while": 33,
+ "with": 34,
+ "yield": 35
+}
diff --git a/tests/inputs/namespace_keywords/namespace_keywords.proto b/tests/inputs/namespace_keywords/namespace_keywords.proto
new file mode 100644
index 0000000..ac3e5c5
--- /dev/null
+++ b/tests/inputs/namespace_keywords/namespace_keywords.proto
@@ -0,0 +1,46 @@
+syntax = "proto3";
+
+package namespace_keywords;
+
+// Tests that messages may contain fields that are Python keywords
+//
+// Generated with Python 3.7.6
+// print('\n'.join(f'string {k} = {i+1};' for i,k in enumerate(keyword.kwlist)))
+
+message Test {
+ string False = 1;
+ string None = 2;
+ string True = 3;
+ string and = 4;
+ string as = 5;
+ string assert = 6;
+ string async = 7;
+ string await = 8;
+ string break = 9;
+ string class = 10;
+ string continue = 11;
+ string def = 12;
+ string del = 13;
+ string elif = 14;
+ string else = 15;
+ string except = 16;
+ string finally = 17;
+ string for = 18;
+ string from = 19;
+ string global = 20;
+ string if = 21;
+ string import = 22;
+ string in = 23;
+ string is = 24;
+ string lambda = 25;
+ string nonlocal = 26;
+ string not = 27;
+ string or = 28;
+ string pass = 29;
+ string raise = 30;
+ string return = 31;
+ string try = 32;
+ string while = 33;
+ string with = 34;
+ string yield = 35;
+} \ No newline at end of file
diff --git a/tests/inputs/nested/nested.json b/tests/inputs/nested/nested.json
new file mode 100644
index 0000000..f460cad
--- /dev/null
+++ b/tests/inputs/nested/nested.json
@@ -0,0 +1,7 @@
+{
+ "nested": {
+ "count": 150
+ },
+ "sibling": {},
+ "msg": "THIS"
+}
diff --git a/tests/inputs/nested/nested.proto b/tests/inputs/nested/nested.proto
new file mode 100644
index 0000000..619c721
--- /dev/null
+++ b/tests/inputs/nested/nested.proto
@@ -0,0 +1,26 @@
+syntax = "proto3";
+
+package nested;
+
+// A test message with a nested message inside of it.
+message Test {
+ // This is the nested type.
+ message Nested {
+ // Stores a simple counter.
+ int32 count = 1;
+ }
+ // This is the nested enum.
+ enum Msg {
+ NONE = 0;
+ THIS = 1;
+ }
+
+ Nested nested = 1;
+ Sibling sibling = 2;
+ Sibling sibling2 = 3;
+ Msg msg = 4;
+}
+
+message Sibling {
+ int32 foo = 1;
+} \ No newline at end of file
diff --git a/tests/inputs/nested2/nested2.proto b/tests/inputs/nested2/nested2.proto
new file mode 100644
index 0000000..cd6510c
--- /dev/null
+++ b/tests/inputs/nested2/nested2.proto
@@ -0,0 +1,21 @@
+syntax = "proto3";
+
+package nested2;
+
+import "package.proto";
+
+message Game {
+ message Player {
+ enum Race {
+ human = 0;
+ orc = 1;
+ }
+ }
+}
+
+message Test {
+ Game game = 1;
+ Game.Player GamePlayer = 2;
+ Game.Player.Race GamePlayerRace = 3;
+ equipment.Weapon Weapon = 4;
+} \ No newline at end of file
diff --git a/tests/inputs/nested2/package.proto b/tests/inputs/nested2/package.proto
new file mode 100644
index 0000000..e12abb1
--- /dev/null
+++ b/tests/inputs/nested2/package.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package nested2.equipment;
+
+message Weapon {
+
+} \ No newline at end of file
diff --git a/tests/inputs/nestedtwice/nestedtwice.json b/tests/inputs/nestedtwice/nestedtwice.json
new file mode 100644
index 0000000..c953132
--- /dev/null
+++ b/tests/inputs/nestedtwice/nestedtwice.json
@@ -0,0 +1,11 @@
+{
+ "top": {
+ "name": "double-nested",
+ "middle": {
+ "bottom": [{"foo": "hello"}],
+ "enumBottom": ["A"],
+ "topMiddleBottom": [{"a": "hello"}],
+ "bar": true
+ }
+ }
+}
diff --git a/tests/inputs/nestedtwice/nestedtwice.proto b/tests/inputs/nestedtwice/nestedtwice.proto
new file mode 100644
index 0000000..84d142a
--- /dev/null
+++ b/tests/inputs/nestedtwice/nestedtwice.proto
@@ -0,0 +1,40 @@
+syntax = "proto3";
+
+package nestedtwice;
+
+/* Test doc. */
+message Test {
+ /* Top doc. */
+ message Top {
+ /* Middle doc. */
+ message Middle {
+ /* TopMiddleBottom doc.*/
+ message TopMiddleBottom {
+ // TopMiddleBottom.a doc.
+ string a = 1;
+ }
+ /* EnumBottom doc. */
+ enum EnumBottom{
+ /* EnumBottom.A doc. */
+ A = 0;
+ B = 1;
+ }
+ /* Bottom doc. */
+ message Bottom {
+ /* Bottom.foo doc. */
+ string foo = 1;
+ }
+ reserved 1;
+ /* Middle.bottom doc. */
+ repeated Bottom bottom = 2;
+ repeated EnumBottom enumBottom=3;
+ repeated TopMiddleBottom topMiddleBottom=4;
+ bool bar = 5;
+ }
+ /* Top.name doc. */
+ string name = 1;
+ Middle middle = 2;
+ }
+ /* Test.top doc. */
+ Top top = 1;
+}
diff --git a/tests/inputs/nestedtwice/test_nestedtwice.py b/tests/inputs/nestedtwice/test_nestedtwice.py
new file mode 100644
index 0000000..502e710
--- /dev/null
+++ b/tests/inputs/nestedtwice/test_nestedtwice.py
@@ -0,0 +1,25 @@
+import pytest
+
+from tests.output_aristaproto.nestedtwice import (
+ Test,
+ TestTop,
+ TestTopMiddle,
+ TestTopMiddleBottom,
+ TestTopMiddleEnumBottom,
+ TestTopMiddleTopMiddleBottom,
+)
+
+
+@pytest.mark.parametrize(
+ ("cls", "expected_comment"),
+ [
+ (Test, "Test doc."),
+ (TestTopMiddleEnumBottom, "EnumBottom doc."),
+ (TestTop, "Top doc."),
+ (TestTopMiddle, "Middle doc."),
+ (TestTopMiddleTopMiddleBottom, "TopMiddleBottom doc."),
+ (TestTopMiddleBottom, "Bottom doc."),
+ ],
+)
+def test_comment(cls, expected_comment):
+ assert cls.__doc__ == expected_comment
diff --git a/tests/inputs/oneof/oneof-name.json b/tests/inputs/oneof/oneof-name.json
new file mode 100644
index 0000000..605484b
--- /dev/null
+++ b/tests/inputs/oneof/oneof-name.json
@@ -0,0 +1,3 @@
+{
+ "pitier": "Mr. T"
+}
diff --git a/tests/inputs/oneof/oneof.json b/tests/inputs/oneof/oneof.json
new file mode 100644
index 0000000..65cafc5
--- /dev/null
+++ b/tests/inputs/oneof/oneof.json
@@ -0,0 +1,3 @@
+{
+ "pitied": 100
+}
diff --git a/tests/inputs/oneof/oneof.proto b/tests/inputs/oneof/oneof.proto
new file mode 100644
index 0000000..41f93b0
--- /dev/null
+++ b/tests/inputs/oneof/oneof.proto
@@ -0,0 +1,23 @@
+syntax = "proto3";
+
+package oneof;
+
+message MixedDrink {
+ int32 shots = 1;
+}
+
+message Test {
+ oneof foo {
+ int32 pitied = 1;
+ string pitier = 2;
+ }
+
+ int32 just_a_regular_field = 3;
+
+ oneof bar {
+ int32 drinks = 11;
+ string bar_name = 12;
+ MixedDrink mixed_drink = 13;
+ }
+}
+
diff --git a/tests/inputs/oneof/oneof_name.json b/tests/inputs/oneof/oneof_name.json
new file mode 100644
index 0000000..605484b
--- /dev/null
+++ b/tests/inputs/oneof/oneof_name.json
@@ -0,0 +1,3 @@
+{
+ "pitier": "Mr. T"
+}
diff --git a/tests/inputs/oneof/test_oneof.py b/tests/inputs/oneof/test_oneof.py
new file mode 100644
index 0000000..8a38496
--- /dev/null
+++ b/tests/inputs/oneof/test_oneof.py
@@ -0,0 +1,43 @@
+import pytest
+
+import aristaproto
+from tests.output_aristaproto.oneof import (
+ MixedDrink,
+ Test,
+)
+from tests.output_aristaproto_pydantic.oneof import Test as TestPyd
+from tests.util import get_test_case_json_data
+
+
+def test_which_count():
+ message = Test()
+ message.from_json(get_test_case_json_data("oneof")[0].json)
+ assert aristaproto.which_one_of(message, "foo") == ("pitied", 100)
+
+
+def test_which_name():
+ message = Test()
+ message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
+ assert aristaproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
+
+
+def test_which_count_pyd():
+ message = TestPyd(pitier="Mr. T", just_a_regular_field=2, bar_name="a_bar")
+ assert aristaproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
+
+
+def test_oneof_constructor_assign():
+ message = Test(mixed_drink=MixedDrink(shots=42))
+ field, value = aristaproto.which_one_of(message, "bar")
+ assert field == "mixed_drink"
+ assert value.shots == 42
+
+
+# Issue #305:
+@pytest.mark.xfail
+def test_oneof_nested_assign():
+ message = Test()
+ message.mixed_drink.shots = 42
+ field, value = aristaproto.which_one_of(message, "bar")
+ assert field == "mixed_drink"
+ assert value.shots == 42
diff --git a/tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto b/tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto
new file mode 100644
index 0000000..f7ac6fe
--- /dev/null
+++ b/tests/inputs/oneof_default_value_serialization/oneof_default_value_serialization.proto
@@ -0,0 +1,30 @@
+syntax = "proto3";
+
+package oneof_default_value_serialization;
+
+import "google/protobuf/duration.proto";
+import "google/protobuf/timestamp.proto";
+import "google/protobuf/wrappers.proto";
+
+message Message{
+ int64 value = 1;
+}
+
+message NestedMessage{
+ int64 id = 1;
+ oneof value_type{
+ Message wrapped_message_value = 2;
+ }
+}
+
+message Test{
+ oneof value_type {
+ bool bool_value = 1;
+ int64 int64_value = 2;
+ google.protobuf.Timestamp timestamp_value = 3;
+ google.protobuf.Duration duration_value = 4;
+ Message wrapped_message_value = 5;
+ NestedMessage wrapped_nested_message_value = 6;
+ google.protobuf.BoolValue wrapped_bool_value = 7;
+ }
+} \ No newline at end of file
diff --git a/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py
new file mode 100644
index 0000000..0fad3d6
--- /dev/null
+++ b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py
@@ -0,0 +1,75 @@
+import datetime
+
+import pytest
+
+import aristaproto
+from tests.output_aristaproto.oneof_default_value_serialization import (
+ Message,
+ NestedMessage,
+ Test,
+)
+
+
+def assert_round_trip_serialization_works(message: Test) -> None:
+ assert aristaproto.which_one_of(message, "value_type") == aristaproto.which_one_of(
+ Test().from_json(message.to_json()), "value_type"
+ )
+
+
+def test_oneof_default_value_serialization_works_for_all_values():
+ """
+ Serialization from message with oneof set to default -> JSON -> message should keep
+ default value field intact.
+ """
+
+ test_cases = [
+ Test(bool_value=False),
+ Test(int64_value=0),
+ Test(
+ timestamp_value=datetime.datetime(
+ year=1970,
+ month=1,
+ day=1,
+ hour=0,
+ minute=0,
+ tzinfo=datetime.timezone.utc,
+ )
+ ),
+ Test(duration_value=datetime.timedelta(0)),
+ Test(wrapped_message_value=Message(value=0)),
+        # NOTE: Do NOT use aristaproto.BoolValue here; it will cause JSON
+        # serialization errors.
+ # TODO: Do we want to allow use of BoolValue directly within a wrapped field or
+ # should we simply hard fail here?
+ Test(wrapped_bool_value=False),
+ ]
+ for message in test_cases:
+ assert_round_trip_serialization_works(message)
+
+
+def test_oneof_no_default_values_passed():
+ message = Test()
+ assert (
+ aristaproto.which_one_of(message, "value_type")
+ == aristaproto.which_one_of(Test().from_json(message.to_json()), "value_type")
+ == ("", None)
+ )
+
+
+def test_oneof_nested_oneof_messages_are_serialized_with_defaults():
+ """
+ Nested messages with oneofs should also be handled
+ """
+ message = Test(
+ wrapped_nested_message_value=NestedMessage(
+ id=0, wrapped_message_value=Message(value=0)
+ )
+ )
+ assert (
+ aristaproto.which_one_of(message, "value_type")
+ == aristaproto.which_one_of(Test().from_json(message.to_json()), "value_type")
+ == (
+ "wrapped_nested_message_value",
+ NestedMessage(id=0, wrapped_message_value=Message(value=0)),
+ )
+ )
diff --git a/tests/inputs/oneof_empty/oneof_empty.json b/tests/inputs/oneof_empty/oneof_empty.json
new file mode 100644
index 0000000..9d21c89
--- /dev/null
+++ b/tests/inputs/oneof_empty/oneof_empty.json
@@ -0,0 +1,3 @@
+{
+ "nothing": {}
+}
diff --git a/tests/inputs/oneof_empty/oneof_empty.proto b/tests/inputs/oneof_empty/oneof_empty.proto
new file mode 100644
index 0000000..ca51d5a
--- /dev/null
+++ b/tests/inputs/oneof_empty/oneof_empty.proto
@@ -0,0 +1,17 @@
+syntax = "proto3";
+
+package oneof_empty;
+
+message Nothing {}
+
+message MaybeNothing {
+ string sometimes = 42;
+}
+
+message Test {
+ oneof empty {
+ Nothing nothing = 1;
+ MaybeNothing maybe1 = 2;
+ MaybeNothing maybe2 = 3;
+ }
+}
diff --git a/tests/inputs/oneof_empty/oneof_empty_maybe1.json b/tests/inputs/oneof_empty/oneof_empty_maybe1.json
new file mode 100644
index 0000000..f7a2d27
--- /dev/null
+++ b/tests/inputs/oneof_empty/oneof_empty_maybe1.json
@@ -0,0 +1,3 @@
+{
+ "maybe1": {}
+}
diff --git a/tests/inputs/oneof_empty/oneof_empty_maybe2.json b/tests/inputs/oneof_empty/oneof_empty_maybe2.json
new file mode 100644
index 0000000..bc2b385
--- /dev/null
+++ b/tests/inputs/oneof_empty/oneof_empty_maybe2.json
@@ -0,0 +1,5 @@
+{
+ "maybe2": {
+ "sometimes": "now"
+ }
+}
diff --git a/tests/inputs/oneof_empty/test_oneof_empty.py b/tests/inputs/oneof_empty/test_oneof_empty.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/inputs/oneof_empty/test_oneof_empty.py
diff --git a/tests/inputs/oneof_enum/oneof_enum-enum-0.json b/tests/inputs/oneof_enum/oneof_enum-enum-0.json
new file mode 100644
index 0000000..be30cf0
--- /dev/null
+++ b/tests/inputs/oneof_enum/oneof_enum-enum-0.json
@@ -0,0 +1,3 @@
+{
+ "signal": "PASS"
+}
diff --git a/tests/inputs/oneof_enum/oneof_enum-enum-1.json b/tests/inputs/oneof_enum/oneof_enum-enum-1.json
new file mode 100644
index 0000000..cb63873
--- /dev/null
+++ b/tests/inputs/oneof_enum/oneof_enum-enum-1.json
@@ -0,0 +1,3 @@
+{
+ "signal": "RESIGN"
+}
diff --git a/tests/inputs/oneof_enum/oneof_enum.json b/tests/inputs/oneof_enum/oneof_enum.json
new file mode 100644
index 0000000..3220b70
--- /dev/null
+++ b/tests/inputs/oneof_enum/oneof_enum.json
@@ -0,0 +1,6 @@
+{
+ "move": {
+ "x": 2,
+ "y": 3
+ }
+}
diff --git a/tests/inputs/oneof_enum/oneof_enum.proto b/tests/inputs/oneof_enum/oneof_enum.proto
new file mode 100644
index 0000000..906abcb
--- /dev/null
+++ b/tests/inputs/oneof_enum/oneof_enum.proto
@@ -0,0 +1,20 @@
+syntax = "proto3";
+
+package oneof_enum;
+
+message Test {
+ oneof action {
+ Signal signal = 1;
+ Move move = 2;
+ }
+}
+
+enum Signal {
+ PASS = 0;
+ RESIGN = 1;
+}
+
+message Move {
+ int32 x = 1;
+ int32 y = 2;
+} \ No newline at end of file
diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py
new file mode 100644
index 0000000..98de22a
--- /dev/null
+++ b/tests/inputs/oneof_enum/test_oneof_enum.py
@@ -0,0 +1,47 @@
+import pytest
+
+import aristaproto
+from tests.output_aristaproto.oneof_enum import (
+ Move,
+ Signal,
+ Test,
+)
+from tests.util import get_test_case_json_data
+
+
+def test_which_one_of_returns_enum_with_default_value():
+ """
+    returns the first field when it is an enum set to its default value
+ """
+ message = Test()
+ message.from_json(
+ get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
+ )
+
+ assert not hasattr(message, "move")
+ assert object.__getattribute__(message, "move") == aristaproto.PLACEHOLDER
+ assert message.signal == Signal.PASS
+ assert aristaproto.which_one_of(message, "action") == ("signal", Signal.PASS)
+
+
+def test_which_one_of_returns_enum_with_non_default_value():
+ """
+    returns the first field when it is an enum set to a non-default value
+ """
+ message = Test()
+ message.from_json(
+ get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
+ )
+ assert not hasattr(message, "move")
+ assert object.__getattribute__(message, "move") == aristaproto.PLACEHOLDER
+ assert message.signal == Signal.RESIGN
+ assert aristaproto.which_one_of(message, "action") == ("signal", Signal.RESIGN)
+
+
+def test_which_one_of_returns_second_field_when_set():
+ message = Test()
+ message.from_json(get_test_case_json_data("oneof_enum")[0].json)
+ assert message.move == Move(x=2, y=3)
+ assert not hasattr(message, "signal")
+ assert object.__getattribute__(message, "signal") == aristaproto.PLACEHOLDER
+ assert aristaproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))
diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence.json b/tests/inputs/proto3_field_presence/proto3_field_presence.json
new file mode 100644
index 0000000..988df8e
--- /dev/null
+++ b/tests/inputs/proto3_field_presence/proto3_field_presence.json
@@ -0,0 +1,13 @@
+{
+ "test1": 128,
+ "test2": true,
+ "test3": "A value",
+ "test4": "aGVsbG8=",
+ "test5": {
+ "test": "Hello"
+ },
+ "test6": "B",
+ "test7": "8589934592",
+ "test8": 2.5,
+ "test9": "2022-01-24T12:12:42Z"
+}
diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence.proto b/tests/inputs/proto3_field_presence/proto3_field_presence.proto
new file mode 100644
index 0000000..f28123d
--- /dev/null
+++ b/tests/inputs/proto3_field_presence/proto3_field_presence.proto
@@ -0,0 +1,26 @@
+syntax = "proto3";
+
+package proto3_field_presence;
+
+import "google/protobuf/timestamp.proto";
+
+message InnerTest {
+ string test = 1;
+}
+
+message Test {
+ optional uint32 test1 = 1;
+ optional bool test2 = 2;
+ optional string test3 = 3;
+ optional bytes test4 = 4;
+ optional InnerTest test5 = 5;
+ optional TestEnum test6 = 6;
+ optional uint64 test7 = 7;
+ optional float test8 = 8;
+ optional google.protobuf.Timestamp test9 = 9;
+}
+
+enum TestEnum {
+ A = 0;
+ B = 1;
+}
diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence_default.json b/tests/inputs/proto3_field_presence/proto3_field_presence_default.json
new file mode 100644
index 0000000..0967ef4
--- /dev/null
+++ b/tests/inputs/proto3_field_presence/proto3_field_presence_default.json
@@ -0,0 +1 @@
+{}
diff --git a/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json b/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json
new file mode 100644
index 0000000..b19ae98
--- /dev/null
+++ b/tests/inputs/proto3_field_presence/proto3_field_presence_missing.json
@@ -0,0 +1,9 @@
+{
+ "test1": 0,
+ "test2": false,
+ "test3": "",
+ "test4": "",
+ "test6": "A",
+ "test7": "0",
+ "test8": 0
+}
diff --git a/tests/inputs/proto3_field_presence/test_proto3_field_presence.py b/tests/inputs/proto3_field_presence/test_proto3_field_presence.py
new file mode 100644
index 0000000..80696b2
--- /dev/null
+++ b/tests/inputs/proto3_field_presence/test_proto3_field_presence.py
@@ -0,0 +1,48 @@
+import json
+
+from tests.output_aristaproto.proto3_field_presence import (
+ InnerTest,
+ Test,
+ TestEnum,
+)
+
+
+def test_null_fields_json():
+ """Ensure that using "null" in JSON is equivalent to not specifying a
+ field, for fields with explicit presence"""
+
+ def test_json(ref_json: str, obj_json: str) -> None:
+ """`ref_json` and `obj_json` are JSON strings describing a `Test` object.
+ Test that deserializing both leads to the same object, and that
+ `ref_json` is the normalized format."""
+ ref_obj = Test().from_json(ref_json)
+ obj = Test().from_json(obj_json)
+
+ assert obj == ref_obj
+ assert json.loads(obj.to_json(0)) == json.loads(ref_json)
+
+ test_json("{}", '{ "test1": null, "test2": null, "test3": null }')
+ test_json("{}", '{ "test4": null, "test5": null, "test6": null }')
+ test_json("{}", '{ "test7": null, "test8": null }')
+ test_json('{ "test5": {} }', '{ "test3": null, "test5": {} }')
+
+ # Make sure that if include_default_values is set, None values are
+ # exported.
+ obj = Test()
+ assert obj.to_dict() == {}
+ assert obj.to_dict(include_default_values=True) == {
+ "test1": None,
+ "test2": None,
+ "test3": None,
+ "test4": None,
+ "test5": None,
+ "test6": None,
+ "test7": None,
+ "test8": None,
+ "test9": None,
+ }
+
+
+def test_unset_access(): # see #523
+ assert Test().test1 is None
+ assert Test(test1=None).test1 is None
diff --git a/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json
new file mode 100644
index 0000000..da08192
--- /dev/null
+++ b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.json
@@ -0,0 +1,3 @@
+{
+ "nested": {}
+}
diff --git a/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto
new file mode 100644
index 0000000..caa76ec
--- /dev/null
+++ b/tests/inputs/proto3_field_presence_oneof/proto3_field_presence_oneof.proto
@@ -0,0 +1,22 @@
+syntax = "proto3";
+
+package proto3_field_presence_oneof;
+
+message Test {
+ oneof kind {
+ Nested nested = 1;
+ WithOptional with_optional = 2;
+ }
+}
+
+message InnerNested {
+ optional bool a = 1;
+}
+
+message Nested {
+ InnerNested inner = 1;
+}
+
+message WithOptional {
+ optional bool b = 2;
+}
diff --git a/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py b/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py
new file mode 100644
index 0000000..f13c973
--- /dev/null
+++ b/tests/inputs/proto3_field_presence_oneof/test_proto3_field_presence_oneof.py
@@ -0,0 +1,29 @@
+from tests.output_aristaproto.proto3_field_presence_oneof import (
+ InnerNested,
+ Nested,
+ Test,
+ WithOptional,
+)
+
+
+def test_serialization():
+    """Ensure that serializing fields that are unset but have explicit field
+    presence does not bloat the serialized payload with zero-length
+    length-delimited fields"""
+
+ def test_empty_nested(message: Test) -> None:
+ # '0a' => tag 1, length delimited
+ # '00' => length: 0
+ assert bytes(message) == bytearray.fromhex("0a 00")
+
+ test_empty_nested(Test(nested=Nested()))
+ test_empty_nested(Test(nested=Nested(inner=None)))
+ test_empty_nested(Test(nested=Nested(inner=InnerNested(a=None))))
+
+ def test_empty_with_optional(message: Test) -> None:
+ # '12' => tag 2, length delimited
+ # '00' => length: 0
+ assert bytes(message) == bytearray.fromhex("12 00")
+
+ test_empty_with_optional(Test(with_optional=WithOptional()))
+ test_empty_with_optional(Test(with_optional=WithOptional(b=None)))
diff --git a/tests/inputs/recursivemessage/recursivemessage.json b/tests/inputs/recursivemessage/recursivemessage.json
new file mode 100644
index 0000000..e92c3fb
--- /dev/null
+++ b/tests/inputs/recursivemessage/recursivemessage.json
@@ -0,0 +1,12 @@
+{
+ "name": "Zues",
+ "child": {
+ "name": "Hercules"
+ },
+ "intermediate": {
+ "child": {
+ "name": "Douglas Adams"
+ },
+ "number": 42
+ }
+}
diff --git a/tests/inputs/recursivemessage/recursivemessage.proto b/tests/inputs/recursivemessage/recursivemessage.proto
new file mode 100644
index 0000000..1da2b57
--- /dev/null
+++ b/tests/inputs/recursivemessage/recursivemessage.proto
@@ -0,0 +1,15 @@
+syntax = "proto3";
+
+package recursivemessage;
+
+message Test {
+ string name = 1;
+ Test child = 2;
+ Intermediate intermediate = 3;
+}
+
+
+message Intermediate {
+ int32 number = 1;
+ Test child = 2;
+}
diff --git a/tests/inputs/ref/ref.json b/tests/inputs/ref/ref.json
new file mode 100644
index 0000000..2c6bdc1
--- /dev/null
+++ b/tests/inputs/ref/ref.json
@@ -0,0 +1,5 @@
+{
+ "greeting": {
+ "greeting": "hello"
+ }
+}
diff --git a/tests/inputs/ref/ref.proto b/tests/inputs/ref/ref.proto
new file mode 100644
index 0000000..6945590
--- /dev/null
+++ b/tests/inputs/ref/ref.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+package ref;
+
+import "repeatedmessage.proto";
+
+message Test {
+ repeatedmessage.Sub greeting = 1;
+}
diff --git a/tests/inputs/ref/repeatedmessage.proto b/tests/inputs/ref/repeatedmessage.proto
new file mode 100644
index 0000000..0ffacaf
--- /dev/null
+++ b/tests/inputs/ref/repeatedmessage.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package repeatedmessage;
+
+message Test {
+ repeated Sub greetings = 1;
+}
+
+message Sub {
+ string greeting = 1;
+} \ No newline at end of file
diff --git a/tests/inputs/regression_387/regression_387.proto b/tests/inputs/regression_387/regression_387.proto
new file mode 100644
index 0000000..57bd954
--- /dev/null
+++ b/tests/inputs/regression_387/regression_387.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+package regression_387;
+
+message Test {
+ uint64 id = 1;
+}
+
+message ParentElement {
+ string name = 1;
+ repeated Test elems = 2;
+} \ No newline at end of file
diff --git a/tests/inputs/regression_387/test_regression_387.py b/tests/inputs/regression_387/test_regression_387.py
new file mode 100644
index 0000000..92d96ba
--- /dev/null
+++ b/tests/inputs/regression_387/test_regression_387.py
@@ -0,0 +1,12 @@
+from tests.output_aristaproto.regression_387 import (
+ ParentElement,
+ Test,
+)
+
+
+def test_regression_387():
+ el = ParentElement(name="test", elems=[Test(id=0), Test(id=42)])
+ binary = bytes(el)
+ decoded = ParentElement().parse(binary)
+ assert decoded == el
+ assert decoded.elems == [Test(id=0), Test(id=42)]
diff --git a/tests/inputs/regression_414/regression_414.proto b/tests/inputs/regression_414/regression_414.proto
new file mode 100644
index 0000000..d20ddda
--- /dev/null
+++ b/tests/inputs/regression_414/regression_414.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+package regression_414;
+
+message Test {
+ bytes body = 1;
+ bytes auth = 2;
+ repeated bytes signatures = 3;
+} \ No newline at end of file
diff --git a/tests/inputs/regression_414/test_regression_414.py b/tests/inputs/regression_414/test_regression_414.py
new file mode 100644
index 0000000..9441470
--- /dev/null
+++ b/tests/inputs/regression_414/test_regression_414.py
@@ -0,0 +1,15 @@
+from tests.output_aristaproto.regression_414 import Test
+
+
+def test_full_cycle():
+ body = bytes([0, 1])
+ auth = bytes([2, 3])
+ sig = [b""]
+
+ obj = Test(body=body, auth=auth, signatures=sig)
+
+ decoded = Test().parse(bytes(obj))
+ assert decoded == obj
+ assert decoded.body == body
+ assert decoded.auth == auth
+ assert decoded.signatures == sig
diff --git a/tests/inputs/repeated/repeated.json b/tests/inputs/repeated/repeated.json
new file mode 100644
index 0000000..b8a7c4e
--- /dev/null
+++ b/tests/inputs/repeated/repeated.json
@@ -0,0 +1,3 @@
+{
+ "names": ["one", "two", "three"]
+}
diff --git a/tests/inputs/repeated/repeated.proto b/tests/inputs/repeated/repeated.proto
new file mode 100644
index 0000000..4f3c788
--- /dev/null
+++ b/tests/inputs/repeated/repeated.proto
@@ -0,0 +1,7 @@
+syntax = "proto3";
+
+package repeated;
+
+message Test {
+ repeated string names = 1;
+}
diff --git a/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.json b/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.json
new file mode 100644
index 0000000..6ce7b34
--- /dev/null
+++ b/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.json
@@ -0,0 +1,4 @@
+{
+ "times": ["1972-01-01T10:00:20.021Z", "1972-01-01T10:00:20.021Z"],
+ "durations": ["1.200s", "1.200s"]
+}
diff --git a/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.proto b/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.proto
new file mode 100644
index 0000000..38f1eaa
--- /dev/null
+++ b/tests/inputs/repeated_duration_timestamp/repeated_duration_timestamp.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+package repeated_duration_timestamp;
+
+import "google/protobuf/duration.proto";
+import "google/protobuf/timestamp.proto";
+
+
+message Test {
+ repeated google.protobuf.Timestamp times = 1;
+ repeated google.protobuf.Duration durations = 2;
+}
diff --git a/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py b/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py
new file mode 100644
index 0000000..aafc951
--- /dev/null
+++ b/tests/inputs/repeated_duration_timestamp/test_repeated_duration_timestamp.py
@@ -0,0 +1,12 @@
+from datetime import (
+ datetime,
+ timedelta,
+)
+
+from tests.output_aristaproto.repeated_duration_timestamp import Test
+
+
+def test_roundtrip():
+ message = Test()
+ message.times = [datetime.now(), datetime.now()]
+ message.durations = [timedelta(), timedelta()]
diff --git a/tests/inputs/repeatedmessage/repeatedmessage.json b/tests/inputs/repeatedmessage/repeatedmessage.json
new file mode 100644
index 0000000..90ec596
--- /dev/null
+++ b/tests/inputs/repeatedmessage/repeatedmessage.json
@@ -0,0 +1,10 @@
+{
+ "greetings": [
+ {
+ "greeting": "hello"
+ },
+ {
+ "greeting": "hi"
+ }
+ ]
+}
diff --git a/tests/inputs/repeatedmessage/repeatedmessage.proto b/tests/inputs/repeatedmessage/repeatedmessage.proto
new file mode 100644
index 0000000..0ffacaf
--- /dev/null
+++ b/tests/inputs/repeatedmessage/repeatedmessage.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package repeatedmessage;
+
+message Test {
+ repeated Sub greetings = 1;
+}
+
+message Sub {
+ string greeting = 1;
+} \ No newline at end of file
diff --git a/tests/inputs/repeatedpacked/repeatedpacked.json b/tests/inputs/repeatedpacked/repeatedpacked.json
new file mode 100644
index 0000000..106fd90
--- /dev/null
+++ b/tests/inputs/repeatedpacked/repeatedpacked.json
@@ -0,0 +1,5 @@
+{
+ "counts": [1, 2, -1, -2],
+ "signed": ["1", "2", "-1", "-2"],
+ "fixed": [1.0, 2.7, 3.4]
+}
diff --git a/tests/inputs/repeatedpacked/repeatedpacked.proto b/tests/inputs/repeatedpacked/repeatedpacked.proto
new file mode 100644
index 0000000..a037d1b
--- /dev/null
+++ b/tests/inputs/repeatedpacked/repeatedpacked.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+package repeatedpacked;
+
+message Test {
+ repeated int32 counts = 1;
+ repeated sint64 signed = 2;
+ repeated double fixed = 3;
+}
diff --git a/tests/inputs/service/service.proto b/tests/inputs/service/service.proto
new file mode 100644
index 0000000..53d84fb
--- /dev/null
+++ b/tests/inputs/service/service.proto
@@ -0,0 +1,35 @@
+syntax = "proto3";
+
+package service;
+
+enum ThingType {
+ UNKNOWN = 0;
+ LIVING = 1;
+ DEAD = 2;
+}
+
+message DoThingRequest {
+ string name = 1;
+ repeated string comments = 2;
+ ThingType type = 3;
+}
+
+message DoThingResponse {
+ repeated string names = 1;
+}
+
+message GetThingRequest {
+ string name = 1;
+}
+
+message GetThingResponse {
+ string name = 1;
+ int32 version = 2;
+}
+
+service Test {
+ rpc DoThing (DoThingRequest) returns (DoThingResponse);
+ rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse);
+ rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse);
+ rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse);
+}
diff --git a/tests/inputs/service_separate_packages/messages.proto b/tests/inputs/service_separate_packages/messages.proto
new file mode 100644
index 0000000..270b188
--- /dev/null
+++ b/tests/inputs/service_separate_packages/messages.proto
@@ -0,0 +1,31 @@
+syntax = "proto3";
+
+import "google/protobuf/duration.proto";
+import "google/protobuf/timestamp.proto";
+
+package service_separate_packages.things.messages;
+
+message DoThingRequest {
+ string name = 1;
+
+ // use `repeated` so we can check if `List` is correctly imported
+ repeated string comments = 2;
+
+ // use google types `timestamp` and `duration` so we can check
+ // if everything from `datetime` is correctly imported
+ google.protobuf.Timestamp when = 3;
+ google.protobuf.Duration duration = 4;
+}
+
+message DoThingResponse {
+ repeated string names = 1;
+}
+
+message GetThingRequest {
+ string name = 1;
+}
+
+message GetThingResponse {
+ string name = 1;
+ int32 version = 2;
+}
diff --git a/tests/inputs/service_separate_packages/service.proto b/tests/inputs/service_separate_packages/service.proto
new file mode 100644
index 0000000..950eab4
--- /dev/null
+++ b/tests/inputs/service_separate_packages/service.proto
@@ -0,0 +1,12 @@
+syntax = "proto3";
+
+import "messages.proto";
+
+package service_separate_packages.things.service;
+
+service Test {
+ rpc DoThing (things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
+ rpc DoManyThings (stream things.messages.DoThingRequest) returns (things.messages.DoThingResponse);
+ rpc GetThingVersions (things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
+ rpc GetDifferentThings (stream things.messages.GetThingRequest) returns (stream things.messages.GetThingResponse);
+}
diff --git a/tests/inputs/service_uppercase/service.proto b/tests/inputs/service_uppercase/service.proto
new file mode 100644
index 0000000..786eec2
--- /dev/null
+++ b/tests/inputs/service_uppercase/service.proto
@@ -0,0 +1,16 @@
+syntax = "proto3";
+
+package service_uppercase;
+
+message DoTHINGRequest {
+ string name = 1;
+ repeated string comments = 2;
+}
+
+message DoTHINGResponse {
+ repeated string names = 1;
+}
+
+service Test {
+ rpc DoThing (DoTHINGRequest) returns (DoTHINGResponse);
+}
diff --git a/tests/inputs/service_uppercase/test_service.py b/tests/inputs/service_uppercase/test_service.py
new file mode 100644
index 0000000..d10fccf
--- /dev/null
+++ b/tests/inputs/service_uppercase/test_service.py
@@ -0,0 +1,8 @@
+import inspect
+
+from tests.output_aristaproto.service_uppercase import TestStub
+
+
+def test_parameters():
+ sig = inspect.signature(TestStub.do_thing)
+ assert len(sig.parameters) == 5, "Expected 5 parameters"
diff --git a/tests/inputs/signed/signed.json b/tests/inputs/signed/signed.json
new file mode 100644
index 0000000..b171e15
--- /dev/null
+++ b/tests/inputs/signed/signed.json
@@ -0,0 +1,6 @@
+{
+ "signed32": 150,
+ "negative32": -150,
+ "string64": "150",
+ "negative64": "-150"
+}
diff --git a/tests/inputs/signed/signed.proto b/tests/inputs/signed/signed.proto
new file mode 100644
index 0000000..b40aad4
--- /dev/null
+++ b/tests/inputs/signed/signed.proto
@@ -0,0 +1,11 @@
+syntax = "proto3";
+
+package signed;
+
+message Test {
+  // TODO: rename fields after fixing the bug where 'signed_32_positive' maps to 'signed_32Positive' in the output JSON
+ sint32 signed32 = 1; // signed_32_positive
+ sint32 negative32 = 2; // signed_32_negative
+ sint64 string64 = 3; // signed_64_positive
+ sint64 negative64 = 4; // signed_64_negative
+}
diff --git a/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py b/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py
new file mode 100644
index 0000000..59be3d1
--- /dev/null
+++ b/tests/inputs/timestamp_dict_encode/test_timestamp_dict_encode.py
@@ -0,0 +1,82 @@
+from datetime import (
+ datetime,
+ timedelta,
+ timezone,
+)
+
+import pytest
+
+from tests.output_aristaproto.timestamp_dict_encode import Test
+
+
+# Current World Timezone range (UTC-12 to UTC+14)
+MIN_UTC_OFFSET_MIN = -12 * 60
+MAX_UTC_OFFSET_MIN = 14 * 60
+
+# Generate all timezones in range in 15 min increments
+timezones = [
+ timezone(timedelta(minutes=x))
+ for x in range(MIN_UTC_OFFSET_MIN, MAX_UTC_OFFSET_MIN + 1, 15)
+]
+
+
+@pytest.mark.parametrize("tz", timezones)
+def test_timezone_aware_datetime_dict_encode(tz: timezone):
+ original_time = datetime.now(tz=tz)
+ original_message = Test()
+ original_message.ts = original_time
+ encoded = original_message.to_dict()
+ decoded_message = Test()
+ decoded_message.from_dict(encoded)
+
+ # check that the timestamps are equal after decoding from dict
+ assert original_message.ts.tzinfo is not None
+ assert decoded_message.ts.tzinfo is not None
+ assert original_message.ts == decoded_message.ts
+
+
+def test_naive_datetime_dict_encode():
+    # make sure naive datetime objects are still treated as UTC
+ original_time = datetime.now()
+ assert original_time.tzinfo is None
+ original_message = Test()
+ original_message.ts = original_time
+ original_time_utc = original_time.replace(tzinfo=timezone.utc)
+ encoded = original_message.to_dict()
+ decoded_message = Test()
+ decoded_message.from_dict(encoded)
+
+ # check that the timestamps are equal after decoding from dict
+ assert decoded_message.ts.tzinfo is not None
+ assert original_time_utc == decoded_message.ts
+
+
+@pytest.mark.parametrize("tz", timezones)
+def test_timezone_aware_json_serialize(tz: timezone):
+ original_time = datetime.now(tz=tz)
+ original_message = Test()
+ original_message.ts = original_time
+ json_serialized = original_message.to_json()
+ decoded_message = Test()
+ decoded_message.from_json(json_serialized)
+
+    # check that the timestamps are equal after decoding from JSON
+ assert original_message.ts.tzinfo is not None
+ assert decoded_message.ts.tzinfo is not None
+ assert original_message.ts == decoded_message.ts
+
+
+def test_naive_datetime_json_serialize():
+    # make sure naive datetime objects are still treated as UTC
+ original_time = datetime.now()
+ assert original_time.tzinfo is None
+ original_message = Test()
+ original_message.ts = original_time
+ original_time_utc = original_time.replace(tzinfo=timezone.utc)
+ json_serialized = original_message.to_json()
+ decoded_message = Test()
+ decoded_message.from_json(json_serialized)
+
+    # check that the timestamps are equal after decoding from JSON
+ assert decoded_message.ts.tzinfo is not None
+ assert original_time_utc == decoded_message.ts
diff --git a/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.json b/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.json
new file mode 100644
index 0000000..3f45558
--- /dev/null
+++ b/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.json
@@ -0,0 +1,3 @@
+{
+ "ts" : "2023-03-15T22:35:51.253277Z"
+} \ No newline at end of file
diff --git a/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.proto b/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.proto
new file mode 100644
index 0000000..9c4081a
--- /dev/null
+++ b/tests/inputs/timestamp_dict_encode/timestamp_dict_encode.proto
@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+package timestamp_dict_encode;
+
+import "google/protobuf/timestamp.proto";
+
+message Test {
+ google.protobuf.Timestamp ts = 1;
+} \ No newline at end of file
diff --git a/tests/mocks.py b/tests/mocks.py
new file mode 100644
index 0000000..dc6e117
--- /dev/null
+++ b/tests/mocks.py
@@ -0,0 +1,40 @@
+from typing import List
+
+from grpclib.client import Channel
+
+
+class MockChannel(Channel):
+ # noinspection PyMissingConstructor
+ def __init__(self, responses=None) -> None:
+ self.responses = responses or []
+ self.requests = []
+ self._loop = None
+
+ def request(self, route, cardinality, request, response_type, **kwargs):
+ self.requests.append(
+ {
+ "route": route,
+ "cardinality": cardinality,
+ "request": request,
+ "response_type": response_type,
+ }
+ )
+ return MockStream(self.responses)
+
+
+class MockStream:
+ def __init__(self, responses: List) -> None:
+ super().__init__()
+ self.responses = responses
+
+ async def recv_message(self):
+ return self.responses.pop(0)
+
+ async def send_message(self, *args, **kwargs):
+ pass
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ return True
+
+ async def __aenter__(self):
+ return self
diff --git a/tests/oneof_pattern_matching.py b/tests/oneof_pattern_matching.py
new file mode 100644
index 0000000..2c5e797
--- /dev/null
+++ b/tests/oneof_pattern_matching.py
@@ -0,0 +1,46 @@
+from dataclasses import dataclass
+
+import pytest
+
+import aristaproto
+
+
+def test_oneof_pattern_matching():
+ @dataclass
+ class Sub(aristaproto.Message):
+ val: int = aristaproto.int32_field(1)
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ bar: int = aristaproto.int32_field(1, group="group1")
+ baz: str = aristaproto.string_field(2, group="group1")
+ sub: Sub = aristaproto.message_field(3, group="group2")
+ abc: str = aristaproto.string_field(4, group="group2")
+
+ foo = Foo(baz="test1", abc="test2")
+
+ match foo:
+ case Foo(bar=_):
+ pytest.fail("Matched 'bar' instead of 'baz'")
+ case Foo(baz=v):
+ assert v == "test1"
+ case _:
+ pytest.fail("Matched neither 'bar' nor 'baz'")
+
+ match foo:
+ case Foo(sub=_):
+ pytest.fail("Matched 'sub' instead of 'abc'")
+ case Foo(abc=v):
+ assert v == "test2"
+ case _:
+ pytest.fail("Matched neither 'sub' nor 'abc'")
+
+ foo.sub = Sub(val=1)
+
+ match foo:
+ case Foo(sub=Sub(val=v)):
+ assert v == 1
+ case Foo(abc=v):
+ pytest.fail("Matched 'abc' instead of 'sub'")
+ case _:
+ pytest.fail("Matched neither 'sub' nor 'abc'")
diff --git a/tests/streams/delimited_messages.in b/tests/streams/delimited_messages.in
new file mode 100644
index 0000000..5993ac6
--- /dev/null
+++ b/tests/streams/delimited_messages.in
@@ -0,0 +1,2 @@
+•šï:bTesting•šï:bTesting
+  \ No newline at end of file
diff --git a/tests/streams/dump_varint_negative.expected b/tests/streams/dump_varint_negative.expected
new file mode 100644
index 0000000..0954822
--- /dev/null
+++ b/tests/streams/dump_varint_negative.expected
@@ -0,0 +1 @@
+ÿÿÿÿÿÿÿÿÿ€Óûÿÿÿÿÿ€€€€€€€€€€€€€€€€€ \ No newline at end of file
diff --git a/tests/streams/dump_varint_positive.expected b/tests/streams/dump_varint_positive.expected
new file mode 100644
index 0000000..8614b9d
--- /dev/null
+++ b/tests/streams/dump_varint_positive.expected
@@ -0,0 +1 @@
+€­â \ No newline at end of file
diff --git a/tests/streams/java/.gitignore b/tests/streams/java/.gitignore
new file mode 100644
index 0000000..9b1ebba
--- /dev/null
+++ b/tests/streams/java/.gitignore
@@ -0,0 +1,38 @@
+### Output ###
+target/
+!.mvn/wrapper/maven-wrapper.jar
+!**/src/main/**/target/
+!**/src/test/**/target/
+dependency-reduced-pom.xml
+MANIFEST.MF
+
+### IntelliJ IDEA ###
+.idea/
+*.iws
+*.iml
+*.ipr
+
+### Eclipse ###
+.apt_generated
+.classpath
+.factorypath
+.project
+.settings
+.springBeans
+.sts4-cache
+
+### NetBeans ###
+/nbproject/private/
+/nbbuild/
+/dist/
+/nbdist/
+/.nb-gradle/
+build/
+!**/src/main/**/build/
+!**/src/test/**/build/
+
+### VS Code ###
+.vscode/
+
+### Mac OS ###
+.DS_Store \ No newline at end of file
diff --git a/tests/streams/java/pom.xml b/tests/streams/java/pom.xml
new file mode 100644
index 0000000..e39c567
--- /dev/null
+++ b/tests/streams/java/pom.xml
@@ -0,0 +1,94 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+
+ <groupId>aristaproto</groupId>
+ <artifactId>compatibility-test</artifactId>
+ <version>1.0-SNAPSHOT</version>
+ <packaging>jar</packaging>
+
+ <properties>
+ <maven.compiler.source>11</maven.compiler.source>
+ <maven.compiler.target>11</maven.compiler.target>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <protobuf.version>3.23.4</protobuf.version>
+ </properties>
+
+ <dependencies>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <extensions>
+ <extension>
+ <groupId>kr.motd.maven</groupId>
+ <artifactId>os-maven-plugin</artifactId>
+ <version>1.7.1</version>
+ </extension>
+ </extensions>
+
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ <version>3.5.0</version>
+ <executions>
+ <execution>
+ <phase>package</phase>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <configuration>
+ <transformers>
+ <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
+ <mainClass>aristaproto.CompatibilityTest</mainClass>
+ </transformer>
+ </transformers>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <version>3.3.0</version>
+ <configuration>
+ <archive>
+ <manifest>
+ <addClasspath>true</addClasspath>
+ <mainClass>aristaproto.CompatibilityTest</mainClass>
+ </manifest>
+ </archive>
+ </configuration>
+ </plugin>
+
+ <plugin>
+ <groupId>org.xolstice.maven.plugins</groupId>
+ <artifactId>protobuf-maven-plugin</artifactId>
+ <version>0.6.1</version>
+ <executions>
+ <execution>
+ <goals>
+ <goal>compile</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <protocArtifact>
+ com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
+ </protocArtifact>
+ </configuration>
+ </plugin>
+ </plugins>
+
+ <finalName>${project.artifactId}</finalName>
+ </build>
+
+</project> \ No newline at end of file
diff --git a/tests/streams/java/src/main/java/aristaproto/CompatibilityTest.java b/tests/streams/java/src/main/java/aristaproto/CompatibilityTest.java
new file mode 100644
index 0000000..b0cff9f
--- /dev/null
+++ b/tests/streams/java/src/main/java/aristaproto/CompatibilityTest.java
@@ -0,0 +1,41 @@
+package aristaproto;
+
+import java.io.IOException;
+
+public class CompatibilityTest {
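+  // args[0] selects the test case; args[1] is the directory containing the Python-generated *.out files.
+  // Each test reads the Python output and writes the Java equivalent back for comparison.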
+ public static void main(String[] args) throws IOException {
+ if (args.length < 2)
+ throw new RuntimeException("Attempted to run without the required arguments.");
+    else if (args.length > 2)
+      throw new RuntimeException(
+          "Attempted to run with more than the expected number of arguments (expected 2).");
+
+ Tests tests = new Tests(args[1]);
+
+ switch (args[0]) {
+ case "single_varint":
+ tests.testSingleVarint();
+ break;
+
+ case "multiple_varints":
+ tests.testMultipleVarints();
+ break;
+
+ case "single_message":
+ tests.testSingleMessage();
+ break;
+
+ case "multiple_messages":
+ tests.testMultipleMessages();
+ break;
+
+ case "infinite_messages":
+ tests.testInfiniteMessages();
+ break;
+
+ default:
+ throw new RuntimeException(
+ "Attempted to run with unknown argument '" + args[0] + "'.");
+ }
+ }
+}
diff --git a/tests/streams/java/src/main/java/aristaproto/Tests.java b/tests/streams/java/src/main/java/aristaproto/Tests.java
new file mode 100644
index 0000000..aabbac7
--- /dev/null
+++ b/tests/streams/java/src/main/java/aristaproto/Tests.java
@@ -0,0 +1,115 @@
+package aristaproto;
+
+import aristaproto.nested.NestedOuterClass;
+import aristaproto.oneof.Oneof;
+
+import com.google.protobuf.CodedInputStream;
+import com.google.protobuf.CodedOutputStream;
+
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+public class Tests {
+ String path;
+
+ public Tests(String path) {
+ this.path = path;
+ }
+
+ public void testSingleVarint() throws IOException {
+ // Read in the Python-generated single varint file
+ FileInputStream inputStream = new FileInputStream(path + "/py_single_varint.out");
+ CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
+
+ int value = codedInput.readUInt32();
+
+ inputStream.close();
+
+ // Write the value back to a file
+ FileOutputStream outputStream = new FileOutputStream(path + "/java_single_varint.out");
+ CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
+
+ codedOutput.writeUInt32NoTag(value);
+
+ codedOutput.flush();
+ outputStream.close();
+ }
+
+ public void testMultipleVarints() throws IOException {
+ // Read in the Python-generated multiple varints file
+ FileInputStream inputStream = new FileInputStream(path + "/py_multiple_varints.out");
+ CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
+
+ int value1 = codedInput.readUInt32();
+ int value2 = codedInput.readUInt32();
+ long value3 = codedInput.readUInt64();
+
+ inputStream.close();
+
+ // Write the values back to a file
+ FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_varints.out");
+ CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
+
+ codedOutput.writeUInt32NoTag(value1);
+ codedOutput.writeUInt64NoTag(value2);
+ codedOutput.writeUInt64NoTag(value3);
+
+ codedOutput.flush();
+ outputStream.close();
+ }
+
+ public void testSingleMessage() throws IOException {
+ // Read in the Python-generated single message file
+ FileInputStream inputStream = new FileInputStream(path + "/py_single_message.out");
+ CodedInputStream codedInput = CodedInputStream.newInstance(inputStream);
+
+ Oneof.Test message = Oneof.Test.parseFrom(codedInput);
+
+ inputStream.close();
+
+ // Write the message back to a file
+ FileOutputStream outputStream = new FileOutputStream(path + "/java_single_message.out");
+ CodedOutputStream codedOutput = CodedOutputStream.newInstance(outputStream);
+
+ message.writeTo(codedOutput);
+
+ codedOutput.flush();
+ outputStream.close();
+ }
+
+ public void testMultipleMessages() throws IOException {
+ // Read in the Python-generated multi-message file
+ FileInputStream inputStream = new FileInputStream(path + "/py_multiple_messages.out");
+
+ Oneof.Test oneof = Oneof.Test.parseDelimitedFrom(inputStream);
+ NestedOuterClass.Test nested = NestedOuterClass.Test.parseDelimitedFrom(inputStream);
+
+ inputStream.close();
+
+ // Write the messages back to a file
+ FileOutputStream outputStream = new FileOutputStream(path + "/java_multiple_messages.out");
+
+ oneof.writeDelimitedTo(outputStream);
+ nested.writeDelimitedTo(outputStream);
+
+ outputStream.flush();
+ outputStream.close();
+ }
+
+ public void testInfiniteMessages() throws IOException {
+ // Read in as many messages as are present in the Python-generated file and write them back
+ FileInputStream inputStream = new FileInputStream(path + "/py_infinite_messages.out");
+ FileOutputStream outputStream = new FileOutputStream(path + "/java_infinite_messages.out");
+
+ Oneof.Test current = Oneof.Test.parseDelimitedFrom(inputStream);
+ while (current != null) {
+ current.writeDelimitedTo(outputStream);
+ current = Oneof.Test.parseDelimitedFrom(inputStream);
+ }
+
+ inputStream.close();
+ outputStream.flush();
+ outputStream.close();
+ }
+}
diff --git a/tests/streams/java/src/main/proto/aristaproto/nested.proto b/tests/streams/java/src/main/proto/aristaproto/nested.proto
new file mode 100644
index 0000000..46a5783
--- /dev/null
+++ b/tests/streams/java/src/main/proto/aristaproto/nested.proto
@@ -0,0 +1,27 @@
+syntax = "proto3";
+
+package nested;
+option java_package = "aristaproto.nested";
+
+// A test message with a nested message inside of it.
+message Test {
+ // This is the nested type.
+ message Nested {
+ // Stores a simple counter.
+ int32 count = 1;
+ }
+ // This is the nested enum.
+ enum Msg {
+ NONE = 0;
+ THIS = 1;
+ }
+
+ Nested nested = 1;
+ Sibling sibling = 2;
+ Sibling sibling2 = 3;
+ Msg msg = 4;
+}
+
+message Sibling {
+ int32 foo = 1;
+} \ No newline at end of file
diff --git a/tests/streams/java/src/main/proto/aristaproto/oneof.proto b/tests/streams/java/src/main/proto/aristaproto/oneof.proto
new file mode 100644
index 0000000..44a8949
--- /dev/null
+++ b/tests/streams/java/src/main/proto/aristaproto/oneof.proto
@@ -0,0 +1,19 @@
+syntax = "proto3";
+
+package oneof;
+option java_package = "aristaproto.oneof";
+
+message Test {
+ oneof foo {
+ int32 pitied = 1;
+ string pitier = 2;
+ }
+
+ int32 just_a_regular_field = 3;
+
+ oneof bar {
+ int32 drinks = 11;
+ string bar_name = 12;
+ }
+}
+
diff --git a/tests/streams/load_varint_cutoff.in b/tests/streams/load_varint_cutoff.in
new file mode 100644
index 0000000..52b9bf1
--- /dev/null
+++ b/tests/streams/load_varint_cutoff.in
@@ -0,0 +1 @@
+È \ No newline at end of file
diff --git a/tests/streams/message_dump_file_multiple.expected b/tests/streams/message_dump_file_multiple.expected
new file mode 100644
index 0000000..b5fdf9c
--- /dev/null
+++ b/tests/streams/message_dump_file_multiple.expected
@@ -0,0 +1,2 @@
+•šï:bTesting•šï:bTesting
+  \ No newline at end of file
diff --git a/tests/streams/message_dump_file_single.expected b/tests/streams/message_dump_file_single.expected
new file mode 100644
index 0000000..9b7bafb
--- /dev/null
+++ b/tests/streams/message_dump_file_single.expected
@@ -0,0 +1 @@
+•šï:bTesting \ No newline at end of file
diff --git a/tests/test_casing.py b/tests/test_casing.py
new file mode 100644
index 0000000..b16d326
--- /dev/null
+++ b/tests/test_casing.py
@@ -0,0 +1,129 @@
+import pytest
+
+from aristaproto.casing import (
+ camel_case,
+ pascal_case,
+ snake_case,
+)
+
+
+@pytest.mark.parametrize(
+ ["value", "expected"],
+ [
+ ("", ""),
+ ("a", "A"),
+ ("foobar", "Foobar"),
+ ("fooBar", "FooBar"),
+ ("FooBar", "FooBar"),
+ ("foo.bar", "FooBar"),
+ ("foo_bar", "FooBar"),
+ ("FOOBAR", "Foobar"),
+ ("FOOBar", "FooBar"),
+ ("UInt32", "UInt32"),
+ ("FOO_BAR", "FooBar"),
+ ("FOOBAR1", "Foobar1"),
+ ("FOOBAR_1", "Foobar1"),
+ ("FOO1BAR2", "Foo1Bar2"),
+ ("foo__bar", "FooBar"),
+ ("_foobar", "Foobar"),
+ ("foobaR", "FoobaR"),
+ ("foo~bar", "FooBar"),
+ ("foo:bar", "FooBar"),
+ ("1foobar", "1Foobar"),
+ ],
+)
+def test_pascal_case(value, expected):
+ actual = pascal_case(value, strict=True)
+ assert actual == expected, f"{value} => {expected} (actual: {actual})"
+
+
+@pytest.mark.parametrize(
+ ["value", "expected"],
+ [
+ ("", ""),
+ ("a", "a"),
+ ("foobar", "foobar"),
+ ("fooBar", "fooBar"),
+ ("FooBar", "fooBar"),
+ ("foo.bar", "fooBar"),
+ ("foo_bar", "fooBar"),
+ ("FOOBAR", "foobar"),
+ ("FOO_BAR", "fooBar"),
+ ("FOOBAR1", "foobar1"),
+ ("FOOBAR_1", "foobar1"),
+ ("FOO1BAR2", "foo1Bar2"),
+ ("foo__bar", "fooBar"),
+ ("_foobar", "foobar"),
+ ("foobaR", "foobaR"),
+ ("foo~bar", "fooBar"),
+ ("foo:bar", "fooBar"),
+ ("1foobar", "1Foobar"),
+ ],
+)
+def test_camel_case_strict(value, expected):
+ actual = camel_case(value, strict=True)
+ assert actual == expected, f"{value} => {expected} (actual: {actual})"
+
+
+@pytest.mark.parametrize(
+ ["value", "expected"],
+ [
+ ("foo_bar", "fooBar"),
+ ("FooBar", "fooBar"),
+ ("foo__bar", "foo_Bar"),
+ ("foo__Bar", "foo__Bar"),
+ ],
+)
+def test_camel_case_not_strict(value, expected):
+ actual = camel_case(value, strict=False)
+ assert actual == expected, f"{value} => {expected} (actual: {actual})"
+
+
+@pytest.mark.parametrize(
+ ["value", "expected"],
+ [
+ ("", ""),
+ ("a", "a"),
+ ("foobar", "foobar"),
+ ("fooBar", "foo_bar"),
+ ("FooBar", "foo_bar"),
+ ("foo.bar", "foo_bar"),
+ ("foo_bar", "foo_bar"),
+ ("foo_Bar", "foo_bar"),
+ ("FOOBAR", "foobar"),
+ ("FOOBar", "foo_bar"),
+ ("UInt32", "u_int32"),
+ ("FOO_BAR", "foo_bar"),
+ ("FOOBAR1", "foobar1"),
+ ("FOOBAR_1", "foobar_1"),
+ ("FOOBAR_123", "foobar_123"),
+ ("FOO1BAR2", "foo1_bar2"),
+ ("foo__bar", "foo_bar"),
+ ("_foobar", "foobar"),
+ ("foobaR", "fooba_r"),
+ ("foo~bar", "foo_bar"),
+ ("foo:bar", "foo_bar"),
+ ("1foobar", "1_foobar"),
+ ("GetUInt64", "get_u_int64"),
+ ],
+)
+def test_snake_case_strict(value, expected):
+ actual = snake_case(value)
+ assert actual == expected, f"{value} => {expected} (actual: {actual})"
+
+
+@pytest.mark.parametrize(
+ ["value", "expected"],
+ [
+ ("fooBar", "foo_bar"),
+ ("FooBar", "foo_bar"),
+ ("foo_Bar", "foo__bar"),
+ ("foo__bar", "foo__bar"),
+ ("FOOBar", "foo_bar"),
+ ("__foo", "__foo"),
+ ("GetUInt64", "get_u_int64"),
+ ],
+)
+def test_snake_case_not_strict(value, expected):
+ actual = snake_case(value, strict=False)
+ assert actual == expected, f"{value} => {expected} (actual: {actual})"
diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
new file mode 100644
index 0000000..fd4de82
--- /dev/null
+++ b/tests/test_deprecated.py
@@ -0,0 +1,45 @@
+import warnings
+
+import pytest
+
+from tests.output_aristaproto.deprecated import (
+ Message,
+ Test,
+)
+
+
+@pytest.fixture
+def message():
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+ return Message(value="hello")
+
+
+def test_deprecated_message():
+ with pytest.warns(DeprecationWarning) as record:
+ Message(value="hello")
+
+ assert len(record) == 1
+ assert str(record[0].message) == f"{Message.__name__} is deprecated"
+
+
+def test_message_with_deprecated_field(message):
+ with pytest.warns(DeprecationWarning) as record:
+ Test(message=message, value=10)
+
+ assert len(record) == 1
+ assert str(record[0].message) == f"{Test.__name__}.message is deprecated"
+
+
+def test_message_with_deprecated_field_not_set(message):
+ with pytest.warns(None) as record:
+ Test(value=10)
+
+ assert not record
+
+
+def test_message_with_deprecated_field_not_set_default(message):
+ with pytest.warns(None) as record:
+ _ = Test(value=10).message
+
+ assert not record
diff --git a/tests/test_enum.py b/tests/test_enum.py
new file mode 100644
index 0000000..807e785
--- /dev/null
+++ b/tests/test_enum.py
@@ -0,0 +1,79 @@
+from typing import (
+ Optional,
+ Tuple,
+)
+
+import pytest
+
+import aristaproto
+
+
+class Colour(aristaproto.Enum):
+ RED = 1
+ GREEN = 2
+ BLUE = 3
+
+
+PURPLE = Colour.__new__(Colour, name=None, value=4)
+
+
+@pytest.mark.parametrize(
+ "member, str_value",
+ [
+ (Colour.RED, "RED"),
+ (Colour.GREEN, "GREEN"),
+ (Colour.BLUE, "BLUE"),
+ ],
+)
+def test_str(member: Colour, str_value: str) -> None:
+ assert str(member) == str_value
+
+
+@pytest.mark.parametrize(
+ "member, repr_value",
+ [
+ (Colour.RED, "Colour.RED"),
+ (Colour.GREEN, "Colour.GREEN"),
+ (Colour.BLUE, "Colour.BLUE"),
+ ],
+)
+def test_repr(member: Colour, repr_value: str) -> None:
+ assert repr(member) == repr_value
+
+
+@pytest.mark.parametrize(
+ "member, values",
+ [
+ (Colour.RED, ("RED", 1)),
+ (Colour.GREEN, ("GREEN", 2)),
+ (Colour.BLUE, ("BLUE", 3)),
+ (PURPLE, (None, 4)),
+ ],
+)
+def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
+ assert (member.name, member.value) == values
+
+
+@pytest.mark.parametrize(
+ "member, input_str",
+ [
+ (Colour.RED, "RED"),
+ (Colour.GREEN, "GREEN"),
+ (Colour.BLUE, "BLUE"),
+ ],
+)
+def test_from_string(member: Colour, input_str: str) -> None:
+ assert Colour.from_string(input_str) == member
+
+
+@pytest.mark.parametrize(
+ "member, input_int",
+ [
+ (Colour.RED, 1),
+ (Colour.GREEN, 2),
+ (Colour.BLUE, 3),
+ (PURPLE, 4),
+ ],
+)
+def test_try_value(member: Colour, input_int: int) -> None:
+ assert Colour.try_value(input_int) == member
diff --git a/tests/test_features.py b/tests/test_features.py
new file mode 100644
index 0000000..638e668
--- /dev/null
+++ b/tests/test_features.py
@@ -0,0 +1,682 @@
+import json
+import sys
+from copy import (
+ copy,
+ deepcopy,
+)
+from dataclasses import dataclass
+from datetime import (
+ datetime,
+ timedelta,
+)
+from inspect import (
+ Parameter,
+ signature,
+)
+from typing import (
+ Dict,
+ List,
+ Optional,
+)
+from unittest.mock import ANY
+
+import pytest
+
+import aristaproto
+
+
+def test_has_field():
+ @dataclass
+ class Bar(aristaproto.Message):
+ baz: int = aristaproto.int32_field(1)
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ bar: Bar = aristaproto.message_field(1)
+
+ # Unset by default
+ foo = Foo()
+ assert aristaproto.serialized_on_wire(foo.bar) is False
+
+ # Serialized after setting something
+ foo.bar.baz = 1
+ assert aristaproto.serialized_on_wire(foo.bar) is True
+
+ # Still has it after setting the default value
+ foo.bar.baz = 0
+ assert aristaproto.serialized_on_wire(foo.bar) is True
+
+ # Manual override (don't do this)
+ foo.bar._serialized_on_wire = False
+ assert aristaproto.serialized_on_wire(foo.bar) is False
+
+ # Can manually set it but defaults to false
+ foo.bar = Bar()
+ assert aristaproto.serialized_on_wire(foo.bar) is False
+
+ @dataclass
+ class WithCollections(aristaproto.Message):
+ test_list: List[str] = aristaproto.string_field(1)
+ test_map: Dict[str, str] = aristaproto.map_field(
+ 2, aristaproto.TYPE_STRING, aristaproto.TYPE_STRING
+ )
+
+ # Is always set from parse, even if all collections are empty
+ with_collections_empty = WithCollections().parse(bytes(WithCollections()))
+    assert aristaproto.serialized_on_wire(with_collections_empty) is True
+    with_collections_list = WithCollections().parse(
+        bytes(WithCollections(test_list=["a", "b", "c"]))
+    )
+    assert aristaproto.serialized_on_wire(with_collections_list) is True
+    with_collections_map = WithCollections().parse(
+        bytes(WithCollections(test_map={"a": "b", "c": "d"}))
+    )
+    assert aristaproto.serialized_on_wire(with_collections_map) is True
+
+
+def test_class_init():
+ @dataclass
+ class Bar(aristaproto.Message):
+ name: str = aristaproto.string_field(1)
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ name: str = aristaproto.string_field(1)
+ child: Bar = aristaproto.message_field(2)
+
+ foo = Foo(name="foo", child=Bar(name="bar"))
+
+ assert foo.to_dict() == {"name": "foo", "child": {"name": "bar"}}
+ assert foo.to_pydict() == {"name": "foo", "child": {"name": "bar"}}
+
+
+def test_enum_as_int_json():
+ class TestEnum(aristaproto.Enum):
+ ZERO = 0
+ ONE = 1
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ bar: TestEnum = aristaproto.enum_field(1)
+
+ # JSON strings are supported, but ints should still be supported too.
+ foo = Foo().from_dict({"bar": 1})
+ assert foo.bar == TestEnum.ONE
+
+ # Plain-ol'-ints should serialize properly too.
+ foo.bar = 1
+ assert foo.to_dict() == {"bar": "ONE"}
+
+ # Similar expectations for pydict
+ foo = Foo().from_pydict({"bar": 1})
+ assert foo.bar == TestEnum.ONE
+ assert foo.to_pydict() == {"bar": TestEnum.ONE}
+
+
+def test_unknown_fields():
+ @dataclass
+ class Newer(aristaproto.Message):
+ foo: bool = aristaproto.bool_field(1)
+ bar: int = aristaproto.int32_field(2)
+ baz: str = aristaproto.string_field(3)
+
+ @dataclass
+ class Older(aristaproto.Message):
+ foo: bool = aristaproto.bool_field(1)
+
+ newer = Newer(foo=True, bar=1, baz="Hello")
+ serialized_newer = bytes(newer)
+
+ # Unknown fields in `Newer` should round trip with `Older`
+ round_trip = bytes(Older().parse(serialized_newer))
+ assert serialized_newer == round_trip
+
+ new_again = Newer().parse(round_trip)
+ assert newer == new_again
+
+
+def test_oneof_support():
+ @dataclass
+ class Sub(aristaproto.Message):
+ val: int = aristaproto.int32_field(1)
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ bar: int = aristaproto.int32_field(1, group="group1")
+ baz: str = aristaproto.string_field(2, group="group1")
+ sub: Sub = aristaproto.message_field(3, group="group2")
+ abc: str = aristaproto.string_field(4, group="group2")
+
+ foo = Foo()
+
+ assert aristaproto.which_one_of(foo, "group1")[0] == ""
+
+ foo.bar = 1
+ foo.baz = "test"
+
+ # Other oneof fields should now be unset
+ assert not hasattr(foo, "bar")
+ assert object.__getattribute__(foo, "bar") == aristaproto.PLACEHOLDER
+ assert aristaproto.which_one_of(foo, "group1")[0] == "baz"
+
+ foo.sub = Sub(val=1)
+ assert aristaproto.serialized_on_wire(foo.sub)
+
+ foo.abc = "test"
+
+ # Group 1 shouldn't be touched, group 2 should have reset
+ assert not hasattr(foo, "sub")
+ assert object.__getattribute__(foo, "sub") == aristaproto.PLACEHOLDER
+ assert aristaproto.which_one_of(foo, "group2")[0] == "abc"
+
+ # Zero value should always serialize for one-of
+ foo = Foo(bar=0)
+ assert aristaproto.which_one_of(foo, "group1")[0] == "bar"
+ assert bytes(foo) == b"\x08\x00"
+
+ # Round trip should also work
+ foo2 = Foo().parse(bytes(foo))
+ assert aristaproto.which_one_of(foo2, "group1")[0] == "bar"
+    assert foo2.bar == 0
+ assert aristaproto.which_one_of(foo2, "group2")[0] == ""
+
+
+@pytest.mark.skipif(
+ sys.version_info < (3, 10),
+ reason="pattern matching is only supported in python3.10+",
+)
+def test_oneof_pattern_matching():
+ from .oneof_pattern_matching import test_oneof_pattern_matching
+
+ test_oneof_pattern_matching()
+
+
+def test_json_casing():
+ @dataclass
+ class CasingTest(aristaproto.Message):
+ pascal_case: int = aristaproto.int32_field(1)
+ camel_case: int = aristaproto.int32_field(2)
+ snake_case: int = aristaproto.int32_field(3)
+ kabob_case: int = aristaproto.int32_field(4)
+
+ # Parsing should accept almost any input
+ test = CasingTest().from_dict(
+ {"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}
+ )
+
+ assert test == CasingTest(1, 2, 3, 4)
+
+ # Serializing should be strict.
+ assert json.loads(test.to_json()) == {
+ "pascalCase": 1,
+ "camelCase": 2,
+ "snakeCase": 3,
+ "kabobCase": 4,
+ }
+
+ assert json.loads(test.to_json(casing=aristaproto.Casing.SNAKE)) == {
+ "pascal_case": 1,
+ "camel_case": 2,
+ "snake_case": 3,
+ "kabob_case": 4,
+ }
+
+
+def test_dict_casing():
+ @dataclass
+ class CasingTest(aristaproto.Message):
+ pascal_case: int = aristaproto.int32_field(1)
+ camel_case: int = aristaproto.int32_field(2)
+ snake_case: int = aristaproto.int32_field(3)
+ kabob_case: int = aristaproto.int32_field(4)
+
+ # Parsing should accept almost any input
+ test = CasingTest().from_dict(
+ {"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}
+ )
+
+ assert test == CasingTest(1, 2, 3, 4)
+
+ # Serializing should be strict.
+ assert test.to_dict() == {
+ "pascalCase": 1,
+ "camelCase": 2,
+ "snakeCase": 3,
+ "kabobCase": 4,
+ }
+ assert test.to_pydict() == {
+ "pascalCase": 1,
+ "camelCase": 2,
+ "snakeCase": 3,
+ "kabobCase": 4,
+ }
+
+ assert test.to_dict(casing=aristaproto.Casing.SNAKE) == {
+ "pascal_case": 1,
+ "camel_case": 2,
+ "snake_case": 3,
+ "kabob_case": 4,
+ }
+ assert test.to_pydict(casing=aristaproto.Casing.SNAKE) == {
+ "pascal_case": 1,
+ "camel_case": 2,
+ "snake_case": 3,
+ "kabob_case": 4,
+ }
+
+
+def test_optional_flag():
+ @dataclass
+ class Request(aristaproto.Message):
+ flag: Optional[bool] = aristaproto.message_field(1, wraps=aristaproto.TYPE_BOOL)
+
+ # Serialization of not passed vs. set vs. zero-value.
+ assert bytes(Request()) == b""
+ assert bytes(Request(flag=True)) == b"\n\x02\x08\x01"
+ assert bytes(Request(flag=False)) == b"\n\x00"
+
+ # Differentiate between not passed and the zero-value.
+ assert Request().parse(b"").flag is None
+ assert Request().parse(b"\n\x00").flag is False
+
+
+def test_optional_datetime_to_dict():
+ @dataclass
+ class Request(aristaproto.Message):
+ date: Optional[datetime] = aristaproto.message_field(1, optional=True)
+
+ # Check dict serialization
+ assert Request().to_dict() == {}
+ assert Request().to_dict(include_default_values=True) == {"date": None}
+ assert Request(date=datetime(2020, 1, 1)).to_dict() == {
+ "date": "2020-01-01T00:00:00Z"
+ }
+ assert Request(date=datetime(2020, 1, 1)).to_dict(include_default_values=True) == {
+ "date": "2020-01-01T00:00:00Z"
+ }
+
+ # Check pydict serialization
+ assert Request().to_pydict() == {}
+ assert Request().to_pydict(include_default_values=True) == {"date": None}
+ assert Request(date=datetime(2020, 1, 1)).to_pydict() == {
+ "date": datetime(2020, 1, 1)
+ }
+ assert Request(date=datetime(2020, 1, 1)).to_pydict(
+ include_default_values=True
+ ) == {"date": datetime(2020, 1, 1)}
+
+
+def test_to_json_default_values():
+ @dataclass
+ class TestMessage(aristaproto.Message):
+ some_int: int = aristaproto.int32_field(1)
+ some_double: float = aristaproto.double_field(2)
+ some_str: str = aristaproto.string_field(3)
+ some_bool: bool = aristaproto.bool_field(4)
+
+ # Empty dict
+ test = TestMessage().from_dict({})
+
+ assert json.loads(test.to_json(include_default_values=True)) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ # All default values
+ test = TestMessage().from_dict(
+ {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False}
+ )
+
+ assert json.loads(test.to_json(include_default_values=True)) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+
+def test_to_dict_default_values():
+ @dataclass
+ class TestMessage(aristaproto.Message):
+ some_int: int = aristaproto.int32_field(1)
+ some_double: float = aristaproto.double_field(2)
+ some_str: str = aristaproto.string_field(3)
+ some_bool: bool = aristaproto.bool_field(4)
+
+ # Empty dict
+ test = TestMessage().from_dict({})
+
+ assert test.to_dict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ test = TestMessage().from_pydict({})
+
+ assert test.to_pydict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ # All default values
+ test = TestMessage().from_dict(
+ {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False}
+ )
+
+ assert test.to_dict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ test = TestMessage().from_pydict(
+ {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False}
+ )
+
+ assert test.to_pydict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ # Some default and some other values
+ @dataclass
+ class TestMessage2(aristaproto.Message):
+ some_int: int = aristaproto.int32_field(1)
+ some_double: float = aristaproto.double_field(2)
+ some_str: str = aristaproto.string_field(3)
+ some_bool: bool = aristaproto.bool_field(4)
+ some_default_int: int = aristaproto.int32_field(5)
+ some_default_double: float = aristaproto.double_field(6)
+ some_default_str: str = aristaproto.string_field(7)
+ some_default_bool: bool = aristaproto.bool_field(8)
+
+ test = TestMessage2().from_dict(
+ {
+ "someInt": 2,
+ "someDouble": 1.2,
+ "someStr": "hello",
+ "someBool": True,
+ "someDefaultInt": 0,
+ "someDefaultDouble": 0.0,
+ "someDefaultStr": "",
+ "someDefaultBool": False,
+ }
+ )
+
+ assert test.to_dict(include_default_values=True) == {
+ "someInt": 2,
+ "someDouble": 1.2,
+ "someStr": "hello",
+ "someBool": True,
+ "someDefaultInt": 0,
+ "someDefaultDouble": 0.0,
+ "someDefaultStr": "",
+ "someDefaultBool": False,
+ }
+
+ test = TestMessage2().from_pydict(
+ {
+ "someInt": 2,
+ "someDouble": 1.2,
+ "someStr": "hello",
+ "someBool": True,
+ "someDefaultInt": 0,
+ "someDefaultDouble": 0.0,
+ "someDefaultStr": "",
+ "someDefaultBool": False,
+ }
+ )
+
+ assert test.to_pydict(include_default_values=True) == {
+ "someInt": 2,
+ "someDouble": 1.2,
+ "someStr": "hello",
+ "someBool": True,
+ "someDefaultInt": 0,
+ "someDefaultDouble": 0.0,
+ "someDefaultStr": "",
+ "someDefaultBool": False,
+ }
+
+ # Nested messages
+ @dataclass
+ class TestChildMessage(aristaproto.Message):
+ some_other_int: int = aristaproto.int32_field(1)
+
+ @dataclass
+ class TestParentMessage(aristaproto.Message):
+ some_int: int = aristaproto.int32_field(1)
+ some_double: float = aristaproto.double_field(2)
+ some_message: TestChildMessage = aristaproto.message_field(3)
+
+ test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2})
+
+ assert test.to_dict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 1.2,
+ "someMessage": {"someOtherInt": 0},
+ }
+
+ test = TestParentMessage().from_pydict({"someInt": 0, "someDouble": 1.2})
+
+ assert test.to_pydict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 1.2,
+ "someMessage": {"someOtherInt": 0},
+ }
+
+
+def test_to_dict_datetime_values():
+ @dataclass
+ class TestDatetimeMessage(aristaproto.Message):
+ bar: datetime = aristaproto.message_field(1)
+ baz: timedelta = aristaproto.message_field(2)
+
+ test = TestDatetimeMessage().from_dict(
+ {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"}
+ )
+
+ assert test.to_dict() == {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"}
+
+ test = TestDatetimeMessage().from_pydict(
+ {"bar": datetime(year=2020, month=1, day=1), "baz": timedelta(days=1)}
+ )
+
+ assert test.to_pydict() == {
+ "bar": datetime(year=2020, month=1, day=1),
+ "baz": timedelta(days=1),
+ }
+
+
+def test_oneof_default_value_set_causes_writes_wire():
+ @dataclass
+ class Empty(aristaproto.Message):
+ pass
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ bar: int = aristaproto.int32_field(1, group="group1")
+ baz: str = aristaproto.string_field(2, group="group1")
+ qux: Empty = aristaproto.message_field(3, group="group1")
+
+ def _round_trip_serialization(foo: Foo) -> Foo:
+ return Foo().parse(bytes(foo))
+
+ foo1 = Foo(bar=0)
+ foo2 = Foo(baz="")
+ foo3 = Foo(qux=Empty())
+ foo4 = Foo()
+
+ assert bytes(foo1) == b"\x08\x00"
+ assert (
+ aristaproto.which_one_of(foo1, "group1")
+ == aristaproto.which_one_of(_round_trip_serialization(foo1), "group1")
+ == ("bar", 0)
+ )
+
+ assert bytes(foo2) == b"\x12\x00" # Baz is just an empty string
+ assert (
+ aristaproto.which_one_of(foo2, "group1")
+ == aristaproto.which_one_of(_round_trip_serialization(foo2), "group1")
+ == ("baz", "")
+ )
+
+ assert bytes(foo3) == b"\x1a\x00"
+ assert (
+ aristaproto.which_one_of(foo3, "group1")
+ == aristaproto.which_one_of(_round_trip_serialization(foo3), "group1")
+ == ("qux", Empty())
+ )
+
+ assert bytes(foo4) == b""
+ assert (
+ aristaproto.which_one_of(foo4, "group1")
+ == aristaproto.which_one_of(_round_trip_serialization(foo4), "group1")
+ == ("", None)
+ )
+
+
+def test_message_repr():
+ from tests.output_aristaproto.recursivemessage import Test
+
+ assert repr(Test(name="Loki")) == "Test(name='Loki')"
+ assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())"
+
+
+def test_bool():
+    """Messages should evaluate to ``bool`` similarly to a collection.
+
+    >>> test = []
+    >>> bool(test)
+    False
+    >>> test.append(1)
+    >>> bool(test)
+    True
+    >>> del test[0]
+    >>> bool(test)
+    False
+    """
+
+ @dataclass
+ class Falsy(aristaproto.Message):
+ pass
+
+ @dataclass
+ class Truthy(aristaproto.Message):
+ bar: int = aristaproto.int32_field(1)
+
+ assert not Falsy()
+ t = Truthy()
+ assert not t
+ t.bar = 1
+ assert t
+ t.bar = 0
+ assert not t
+
+
+# valid ISO datetimes according to https://www.myintervals.com/blog/2009/05/20/iso-8601-date-validation-that-doesnt-suck/
+iso_candidates = """2009-12-12T12:34
+2009
+2009-05-19
+2009-05-19
+20090519
+2009123
+2009-05
+2009-123
+2009-222
+2009-001
+2009-W01-1
+2009-W51-1
+2009-W33
+2009W511
+2009-05-19
+2009-05-19 00:00
+2009-05-19 14
+2009-05-19 14:31
+2009-05-19 14:39:22
+2009-05-19T14:39Z
+2009-W21-2
+2009-W21-2T01:22
+2009-139
+2009-05-19 14:39:22-06:00
+2009-05-19 14:39:22+0600
+2009-05-19 14:39:22-01
+20090621T0545Z
+2007-04-06T00:00
+2007-04-05T24:00
+2010-02-18T16:23:48.5
+2010-02-18T16:23:48,444
+2010-02-18T16:23:48,3-06:00
+2010-02-18T16:23:00.4
+2010-02-18T16:23:00,25
+2010-02-18T16:23:00.33+0600
+2010-02-18T16:00:00.23334444
+2010-02-18T16:00:00,2283
+2009-05-19 143922
+2009-05-19 1439""".split(
+ "\n"
+)
+
+
+def test_iso_datetime():
+ @dataclass
+ class Envelope(aristaproto.Message):
+ ts: datetime = aristaproto.message_field(1)
+
+ msg = Envelope()
+
+    for candidate in iso_candidates:
+ msg.from_dict({"ts": candidate})
+ assert isinstance(msg.ts, datetime)
+
+
+def test_iso_datetime_list():
+ @dataclass
+ class Envelope(aristaproto.Message):
+ timestamps: List[datetime] = aristaproto.message_field(1)
+
+ msg = Envelope()
+
+ msg.from_dict({"timestamps": iso_candidates})
+ assert all([isinstance(item, datetime) for item in msg.timestamps])
+
+
+def test_service_argument__expected_parameter():
+ from tests.output_aristaproto.service import TestStub
+
+ sig = signature(TestStub.do_thing)
+ do_thing_request_parameter = sig.parameters["do_thing_request"]
+ assert do_thing_request_parameter.default is Parameter.empty
+ assert do_thing_request_parameter.annotation == "DoThingRequest"
+
+
+def test_is_set():
+ @dataclass
+ class Spam(aristaproto.Message):
+ foo: bool = aristaproto.bool_field(1)
+ bar: Optional[int] = aristaproto.int32_field(2, optional=True)
+
+ assert not Spam().is_set("foo")
+ assert not Spam().is_set("bar")
+ assert Spam(foo=True).is_set("foo")
+ assert Spam(foo=True, bar=0).is_set("bar")
+
+
+def test_equality_comparison():
+ from tests.output_aristaproto.bool import Test as TestMessage
+
+ msg = TestMessage(value=True)
+
+ assert msg == msg
+ assert msg == ANY
+ assert msg == TestMessage(value=True)
+ assert msg != 1
+ assert msg != TestMessage(value=False)
diff --git a/tests/test_get_ref_type.py b/tests/test_get_ref_type.py
new file mode 100644
index 0000000..a4c6f76
--- /dev/null
+++ b/tests/test_get_ref_type.py
@@ -0,0 +1,371 @@
+import pytest
+
+from aristaproto.compile.importing import (
+ get_type_reference,
+ parse_source_type_name,
+)
+
+
+@pytest.mark.parametrize(
+ ["google_type", "expected_name", "expected_import"],
+ [
+ (
+ ".google.protobuf.Empty",
+ '"aristaproto_lib_google_protobuf.Empty"',
+ "import aristaproto.lib.google.protobuf as aristaproto_lib_google_protobuf",
+ ),
+ (
+ ".google.protobuf.Struct",
+ '"aristaproto_lib_google_protobuf.Struct"',
+ "import aristaproto.lib.google.protobuf as aristaproto_lib_google_protobuf",
+ ),
+ (
+ ".google.protobuf.ListValue",
+ '"aristaproto_lib_google_protobuf.ListValue"',
+ "import aristaproto.lib.google.protobuf as aristaproto_lib_google_protobuf",
+ ),
+ (
+ ".google.protobuf.Value",
+ '"aristaproto_lib_google_protobuf.Value"',
+ "import aristaproto.lib.google.protobuf as aristaproto_lib_google_protobuf",
+ ),
+ ],
+)
+def test_reference_google_wellknown_types_non_wrappers(
+ google_type: str, expected_name: str, expected_import: str
+):
+ imports = set()
+ name = get_type_reference(
+ package="", imports=imports, source_type=google_type, pydantic=False
+ )
+
+ assert name == expected_name
+    assert (
+        expected_import in imports
+    ), f"{expected_import} not found in {imports}"
+
+
+@pytest.mark.parametrize(
+ ["google_type", "expected_name", "expected_import"],
+ [
+ (
+ ".google.protobuf.Empty",
+ '"aristaproto_lib_pydantic_google_protobuf.Empty"',
+ "import aristaproto.lib.pydantic.google.protobuf as aristaproto_lib_pydantic_google_protobuf",
+ ),
+ (
+ ".google.protobuf.Struct",
+ '"aristaproto_lib_pydantic_google_protobuf.Struct"',
+ "import aristaproto.lib.pydantic.google.protobuf as aristaproto_lib_pydantic_google_protobuf",
+ ),
+ (
+ ".google.protobuf.ListValue",
+ '"aristaproto_lib_pydantic_google_protobuf.ListValue"',
+ "import aristaproto.lib.pydantic.google.protobuf as aristaproto_lib_pydantic_google_protobuf",
+ ),
+ (
+ ".google.protobuf.Value",
+ '"aristaproto_lib_pydantic_google_protobuf.Value"',
+ "import aristaproto.lib.pydantic.google.protobuf as aristaproto_lib_pydantic_google_protobuf",
+ ),
+ ],
+)
+def test_reference_google_wellknown_types_non_wrappers_pydantic(
+ google_type: str, expected_name: str, expected_import: str
+):
+ imports = set()
+ name = get_type_reference(
+ package="", imports=imports, source_type=google_type, pydantic=True
+ )
+
+ assert name == expected_name
+    assert (
+        expected_import in imports
+    ), f"{expected_import} not found in {imports}"
+
+
+@pytest.mark.parametrize(
+ ["google_type", "expected_name"],
+ [
+ (".google.protobuf.DoubleValue", "Optional[float]"),
+ (".google.protobuf.FloatValue", "Optional[float]"),
+ (".google.protobuf.Int32Value", "Optional[int]"),
+ (".google.protobuf.Int64Value", "Optional[int]"),
+ (".google.protobuf.UInt32Value", "Optional[int]"),
+ (".google.protobuf.UInt64Value", "Optional[int]"),
+ (".google.protobuf.BoolValue", "Optional[bool]"),
+ (".google.protobuf.StringValue", "Optional[str]"),
+ (".google.protobuf.BytesValue", "Optional[bytes]"),
+ ],
+)
+def test_referencing_google_wrappers_unwraps_them(
+ google_type: str, expected_name: str
+):
+ imports = set()
+ name = get_type_reference(package="", imports=imports, source_type=google_type)
+
+ assert name == expected_name
+ assert imports == set()
+
+
+@pytest.mark.parametrize(
+ ["google_type", "expected_name"],
+ [
+ (
+ ".google.protobuf.DoubleValue",
+ '"aristaproto_lib_google_protobuf.DoubleValue"',
+ ),
+ (".google.protobuf.FloatValue", '"aristaproto_lib_google_protobuf.FloatValue"'),
+ (".google.protobuf.Int32Value", '"aristaproto_lib_google_protobuf.Int32Value"'),
+ (".google.protobuf.Int64Value", '"aristaproto_lib_google_protobuf.Int64Value"'),
+ (
+ ".google.protobuf.UInt32Value",
+ '"aristaproto_lib_google_protobuf.UInt32Value"',
+ ),
+ (
+ ".google.protobuf.UInt64Value",
+ '"aristaproto_lib_google_protobuf.UInt64Value"',
+ ),
+ (".google.protobuf.BoolValue", '"aristaproto_lib_google_protobuf.BoolValue"'),
+ (
+ ".google.protobuf.StringValue",
+ '"aristaproto_lib_google_protobuf.StringValue"',
+ ),
+ (".google.protobuf.BytesValue", '"aristaproto_lib_google_protobuf.BytesValue"'),
+ ],
+)
+def test_referencing_google_wrappers_without_unwrapping(
+ google_type: str, expected_name: str
+):
+ name = get_type_reference(
+ package="", imports=set(), source_type=google_type, unwrap=False
+ )
+
+ assert name == expected_name
+
+
+def test_reference_child_package_from_package():
+ imports = set()
+ name = get_type_reference(
+ package="package", imports=imports, source_type="package.child.Message"
+ )
+
+ assert imports == {"from . import child"}
+ assert name == '"child.Message"'
+
+
+def test_reference_child_package_from_root():
+ imports = set()
+ name = get_type_reference(package="", imports=imports, source_type="child.Message")
+
+ assert imports == {"from . import child"}
+ assert name == '"child.Message"'
+
+
+def test_reference_camel_cased():
+ imports = set()
+ name = get_type_reference(
+ package="", imports=imports, source_type="child_package.example_message"
+ )
+
+ assert imports == {"from . import child_package"}
+ assert name == '"child_package.ExampleMessage"'
+
+
+def test_reference_nested_child_from_root():
+ imports = set()
+ name = get_type_reference(
+ package="", imports=imports, source_type="nested.child.Message"
+ )
+
+ assert imports == {"from .nested import child as nested_child"}
+ assert name == '"nested_child.Message"'
+
+
+def test_reference_deeply_nested_child_from_root():
+ imports = set()
+ name = get_type_reference(
+ package="", imports=imports, source_type="deeply.nested.child.Message"
+ )
+
+ assert imports == {"from .deeply.nested import child as deeply_nested_child"}
+ assert name == '"deeply_nested_child.Message"'
+
+
+def test_reference_deeply_nested_child_from_package():
+ imports = set()
+ name = get_type_reference(
+ package="package",
+ imports=imports,
+ source_type="package.deeply.nested.child.Message",
+ )
+
+ assert imports == {"from .deeply.nested import child as deeply_nested_child"}
+ assert name == '"deeply_nested_child.Message"'
+
+
+def test_reference_root_sibling():
+ imports = set()
+ name = get_type_reference(package="", imports=imports, source_type="Message")
+
+ assert imports == set()
+ assert name == '"Message"'
+
+
+def test_reference_nested_siblings():
+ imports = set()
+ name = get_type_reference(package="foo", imports=imports, source_type="foo.Message")
+
+ assert imports == set()
+ assert name == '"Message"'
+
+
+def test_reference_deeply_nested_siblings():
+ imports = set()
+ name = get_type_reference(
+ package="foo.bar", imports=imports, source_type="foo.bar.Message"
+ )
+
+ assert imports == set()
+ assert name == '"Message"'
+
+
+def test_reference_parent_package_from_child():
+ imports = set()
+ name = get_type_reference(
+ package="package.child", imports=imports, source_type="package.Message"
+ )
+
+ assert imports == {"from ... import package as __package__"}
+ assert name == '"__package__.Message"'
+
+
+def test_reference_parent_package_from_deeply_nested_child():
+ imports = set()
+ name = get_type_reference(
+ package="package.deeply.nested.child",
+ imports=imports,
+ source_type="package.deeply.nested.Message",
+ )
+
+ assert imports == {"from ... import nested as __nested__"}
+ assert name == '"__nested__.Message"'
+
+
+def test_reference_ancestor_package_from_nested_child():
+ imports = set()
+ name = get_type_reference(
+ package="package.ancestor.nested.child",
+ imports=imports,
+ source_type="package.ancestor.Message",
+ )
+
+ assert imports == {"from .... import ancestor as ___ancestor__"}
+ assert name == '"___ancestor__.Message"'
+
+
+def test_reference_root_package_from_child():
+ imports = set()
+ name = get_type_reference(
+ package="package.child", imports=imports, source_type="Message"
+ )
+
+ assert imports == {"from ... import Message as __Message__"}
+ assert name == '"__Message__"'
+
+
+def test_reference_root_package_from_deeply_nested_child():
+ imports = set()
+ name = get_type_reference(
+ package="package.deeply.nested.child", imports=imports, source_type="Message"
+ )
+
+ assert imports == {"from ..... import Message as ____Message__"}
+ assert name == '"____Message__"'
+
+
+def test_reference_unrelated_package():
+ imports = set()
+ name = get_type_reference(package="a", imports=imports, source_type="p.Message")
+
+ assert imports == {"from .. import p as _p__"}
+ assert name == '"_p__.Message"'
+
+
+def test_reference_unrelated_nested_package():
+ imports = set()
+ name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
+
+ assert imports == {"from ...p import q as __p_q__"}
+ assert name == '"__p_q__.Message"'
+
+
+def test_reference_unrelated_deeply_nested_package():
+ imports = set()
+ name = get_type_reference(
+ package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message"
+ )
+
+ assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
+ assert name == '"____p_q_r_s__.Message"'
+
+
+def test_reference_cousin_package():
+ imports = set()
+ name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
+
+ assert imports == {"from .. import y as _y__"}
+ assert name == '"_y__.Message"'
+
+
+def test_reference_cousin_package_different_name():
+ imports = set()
+ name = get_type_reference(
+ package="test.package1", imports=imports, source_type="cousin.package2.Message"
+ )
+
+ assert imports == {"from ...cousin import package2 as __cousin_package2__"}
+ assert name == '"__cousin_package2__.Message"'
+
+
+def test_reference_cousin_package_same_name():
+ imports = set()
+ name = get_type_reference(
+ package="test.package", imports=imports, source_type="cousin.package.Message"
+ )
+
+ assert imports == {"from ...cousin import package as __cousin_package__"}
+ assert name == '"__cousin_package__.Message"'
+
+
+def test_reference_far_cousin_package():
+ imports = set()
+ name = get_type_reference(
+ package="a.x.y", imports=imports, source_type="a.b.c.Message"
+ )
+
+ assert imports == {"from ...b import c as __b_c__"}
+ assert name == '"__b_c__.Message"'
+
+
+def test_reference_far_far_cousin_package():
+ imports = set()
+ name = get_type_reference(
+ package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message"
+ )
+
+ assert imports == {"from ....b.c import d as ___b_c_d__"}
+ assert name == '"___b_c_d__.Message"'
+
+
+@pytest.mark.parametrize(
+ ["full_name", "expected_output"],
+ [
+ ("package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")),
+ (".package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")),
+ (".service.ExampleRequest", ("service", "ExampleRequest")),
+ (".package.lower_case_message", ("package", "lower_case_message")),
+ ],
+)
+def test_parse_field_type_name(full_name, expected_output):
+ assert parse_source_type_name(full_name) == expected_output
diff --git a/tests/test_inputs.py b/tests/test_inputs.py
new file mode 100644
index 0000000..9247e7b
--- /dev/null
+++ b/tests/test_inputs.py
@@ -0,0 +1,225 @@
+import importlib
+import json
+import math
+import os
+import sys
+from collections import namedtuple
+from types import ModuleType
+from typing import (
+ Any,
+ Dict,
+ List,
+ Set,
+ Tuple,
+)
+
+import pytest
+
+import aristaproto
+from tests.inputs import config as test_input_config
+from tests.mocks import MockChannel
+from tests.util import (
+ find_module,
+ get_directories,
+ get_test_case_json_data,
+ inputs_path,
+)
+
+
+# Force pure-python implementation instead of C++, otherwise imports
+# break things because we can't properly reset the symbol database.
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+from google.protobuf.json_format import Parse
+
+
+class TestCases:
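+    """Collects the test case directories under ``path`` and applies xfail marks from the inputs config."""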
+ def __init__(
+ self,
+ path,
+ services: Set[str],
+ xfail: Set[str],
+ ):
+ _all = set(get_directories(path)) - {"__pycache__"}
+ _services = services
+ _messages = (_all - services) - {"__pycache__"}
+ _messages_with_json = {
+ test for test in _messages if get_test_case_json_data(test)
+ }
+
+ unknown_xfail_tests = xfail - _all
+ if unknown_xfail_tests:
+ raise Exception(f"Unknown test(s) in config.py: {unknown_xfail_tests}")
+
+ self.all = self.apply_xfail_marks(_all, xfail)
+ self.services = self.apply_xfail_marks(_services, xfail)
+ self.messages = self.apply_xfail_marks(_messages, xfail)
+ self.messages_with_json = self.apply_xfail_marks(_messages_with_json, xfail)
+
+ @staticmethod
+ def apply_xfail_marks(test_set: Set[str], xfail: Set[str]):
+ return [
+ pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test
+ for test in test_set
+ ]
+
+
+test_cases = TestCases(
+ path=inputs_path,
+ services=test_input_config.services,
+ xfail=test_input_config.xfail,
+)
+
+plugin_output_package = "tests.output_aristaproto"
+reference_output_package = "tests.output_reference"
+
+TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"])
+
+
+def module_has_entry_point(module: ModuleType):
+ return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
+
+
+def list_replace_nans(items: List) -> List[Any]:
+ """Replace float("nan") in a list with the string "NaN"
+
+ Parameters
+ ----------
+ items : List
+ List to update
+
+ Returns
+ -------
+ List[Any]
+ Updated list
+ """
+ result = []
+ for item in items:
+ if isinstance(item, list):
+ result.append(list_replace_nans(item))
+ elif isinstance(item, dict):
+ result.append(dict_replace_nans(item))
+        elif isinstance(item, float) and math.isnan(item):
+            result.append(aristaproto.NAN)
+        else:
+            result.append(item)
+    return result
+
+
+def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
+ """Replace float("nan") in a dictionary with the string "NaN"
+
+ Parameters
+ ----------
+ input_dict : Dict[Any, Any]
+ Dictionary to update
+
+ Returns
+ -------
+ Dict[Any, Any]
+ Updated dictionary
+ """
+ result = {}
+ for key, value in input_dict.items():
+ if isinstance(value, dict):
+ value = dict_replace_nans(value)
+ elif isinstance(value, list):
+ value = list_replace_nans(value)
+ elif isinstance(value, float) and math.isnan(value):
+ value = aristaproto.NAN
+ result[key] = value
+ return result
+
+
+@pytest.fixture
+def test_data(request, reset_sys_path):
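+    """Imports the plugin-generated module and the reference ``*_pb2`` module for one test case."""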
+ test_case_name = request.param
+
+ reference_module_root = os.path.join(
+ *reference_output_package.split("."), test_case_name
+ )
+ sys.path.append(reference_module_root)
+
+ plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
+
+ plugin_module_entry_point = find_module(plugin_module, module_has_entry_point)
+
+ if not plugin_module_entry_point:
+ raise Exception(
+ f"Test case {repr(test_case_name)} has no entry point. "
+ "Please add a proto message or service called Test and recompile."
+ )
+
+ yield (
+ TestData(
+ plugin_module=plugin_module_entry_point,
+ reference_module=lambda: importlib.import_module(
+ f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
+ ),
+ json_data=get_test_case_json_data(test_case_name),
+ )
+ )
+
+
+@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True)
+def test_message_can_be_instantiated(test_data: TestData) -> None:
+ plugin_module, *_ = test_data
+ plugin_module.Test()
+
+
+@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True)
+def test_message_equality(test_data: TestData) -> None:
+ plugin_module, *_ = test_data
+ message1 = plugin_module.Test()
+ message2 = plugin_module.Test()
+ assert message1 == message2
+
+
+@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
+def test_message_json(repeat, test_data: TestData) -> None:
+ plugin_module, _, json_data = test_data
+
+ for _ in range(repeat):
+ for sample in json_data:
+ if sample.belongs_to(test_input_config.non_symmetrical_json):
+ continue
+
+ message: aristaproto.Message = plugin_module.Test()
+
+ message.from_json(sample.json)
+ message_json = message.to_json(0)
+
+ assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
+ json.loads(sample.json)
+ )
+
+
+@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
+def test_service_can_be_instantiated(test_data: TestData) -> None:
+ test_data.plugin_module.TestStub(MockChannel())
+
+
+@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
+def test_binary_compatibility(repeat, test_data: TestData) -> None:
+ plugin_module, reference_module, json_data = test_data
+
+ for sample in json_data:
+ reference_instance = Parse(sample.json, reference_module().Test())
+ reference_binary_output = reference_instance.SerializeToString()
+
+ for _ in range(repeat):
+ plugin_instance_from_json: aristaproto.Message = (
+ plugin_module.Test().from_json(sample.json)
+ )
+ plugin_instance_from_binary = plugin_module.Test.FromString(
+ reference_binary_output
+ )
+
+ # Generally this can't be relied on, but here we are aiming to match the
+ # existing Python implementation and aren't doing anything tricky.
+ # https://developers.google.com/protocol-buffers/docs/encoding#implications
+ assert bytes(plugin_instance_from_json) == reference_binary_output
+ assert bytes(plugin_instance_from_binary) == reference_binary_output
+
+ assert plugin_instance_from_json == plugin_instance_from_binary
+ assert dict_replace_nans(
+ plugin_instance_from_json.to_dict()
+ ) == dict_replace_nans(plugin_instance_from_binary.to_dict())
diff --git a/tests/test_mapmessage.py b/tests/test_mapmessage.py
new file mode 100644
index 0000000..75220e4
--- /dev/null
+++ b/tests/test_mapmessage.py
@@ -0,0 +1,18 @@
+from tests.output_aristaproto.mapmessage import (
+ Nested,
+ Test,
+)
+
+
+def test_mapmessage_to_dict_preserves_message():
+ message = Test(
+ items={
+ "test": Nested(
+ count=1,
+ )
+ }
+ )
+
+ message.to_dict()
+
+ assert isinstance(message.items["test"], Nested), "Wrong nested type after to_dict"
diff --git a/tests/test_pickling.py b/tests/test_pickling.py
new file mode 100644
index 0000000..2356d98
--- /dev/null
+++ b/tests/test_pickling.py
@@ -0,0 +1,203 @@
+import pickle
+from copy import (
+ copy,
+ deepcopy,
+)
+from dataclasses import dataclass
+from typing import (
+ Dict,
+ List,
+)
+from unittest.mock import ANY
+
+import cachelib
+
+import aristaproto
+from aristaproto.lib.google import protobuf as google
+
+
+def unpickled(message):
+ return pickle.loads(pickle.dumps(message))
+
+
+@dataclass(eq=False, repr=False)
+class Fe(aristaproto.Message):
+ abc: str = aristaproto.string_field(1)
+
+
+@dataclass(eq=False, repr=False)
+class Fi(aristaproto.Message):
+ abc: str = aristaproto.string_field(1)
+
+
+@dataclass(eq=False, repr=False)
+class Fo(aristaproto.Message):
+ abc: str = aristaproto.string_field(1)
+
+
+@dataclass(eq=False, repr=False)
+class NestedData(aristaproto.Message):
+ struct_foo: Dict[str, "google.Struct"] = aristaproto.map_field(
+ 1, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE
+ )
+ map_str_any_bar: Dict[str, "google.Any"] = aristaproto.map_field(
+ 2, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE
+ )
+
+
+@dataclass(eq=False, repr=False)
+class Complex(aristaproto.Message):
+ foo_str: str = aristaproto.string_field(1)
+ fe: "Fe" = aristaproto.message_field(3, group="grp")
+ fi: "Fi" = aristaproto.message_field(4, group="grp")
+ fo: "Fo" = aristaproto.message_field(5, group="grp")
+ nested_data: "NestedData" = aristaproto.message_field(6)
+ mapping: Dict[str, "google.Any"] = aristaproto.map_field(
+ 7, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE
+ )
+
+
+def complex_msg():
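+    """Builds a ``Complex`` message exercising a oneof group, nested ``Struct`` values and ``Any`` payloads."""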
+ return Complex(
+ foo_str="yep",
+ fe=Fe(abc="1"),
+ nested_data=NestedData(
+ struct_foo={
+ "foo": google.Struct(
+ fields={
+ "hello": google.Value(
+ list_value=google.ListValue(
+ values=[google.Value(string_value="world")]
+ )
+ )
+ }
+ ),
+ },
+ map_str_any_bar={
+ "key": google.Any(value=b"value"),
+ },
+ ),
+ mapping={
+ "message": google.Any(value=bytes(Fi(abc="hi"))),
+ "string": google.Any(value=b"howdy"),
+ },
+ )
+
+
+def test_pickling_complex_message():
+ msg = complex_msg()
+ deser = unpickled(msg)
+ assert msg == deser
+ assert msg.fe.abc == "1"
+ assert msg.is_set("fi") is not True
+ assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
+ assert msg.mapping["string"].value.decode() == "howdy"
+ assert (
+ msg.nested_data.struct_foo["foo"]
+ .fields["hello"]
+ .list_value.values[0]
+ .string_value
+ == "world"
+ )
+
+
+def test_recursive_message():
+ from tests.output_aristaproto.recursivemessage import Test as RecursiveMessage
+
+ msg = RecursiveMessage()
+ msg = unpickled(msg)
+
+ assert msg.child == RecursiveMessage()
+
+ # Lazily-created zero-value children must not affect equality.
+ assert msg == RecursiveMessage()
+
+ # Lazily-created zero-value children must not affect serialization.
+ assert bytes(msg) == b""
+
+
+def test_recursive_message_defaults():
+ from tests.output_aristaproto.recursivemessage import (
+ Intermediate,
+ Test as RecursiveMessage,
+ )
+
+ msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
+ msg = unpickled(msg)
+
+ # set values are as expected
+ assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
+
+    # lazy initialization modifies the message
+ assert msg != RecursiveMessage(
+ name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
+ )
+ msg.child.child.name = "jude"
+ assert msg == RecursiveMessage(
+ name="bob",
+ intermediate=Intermediate(42),
+ child=RecursiveMessage(child=RecursiveMessage(name="jude")),
+ )
+
+    # lazy initialization recurses as needed
+ assert msg.child.child.child.child.child.child.child == RecursiveMessage()
+ assert msg.intermediate.child.intermediate == Intermediate()
+
+
+@dataclass
+class PickledMessage(aristaproto.Message):
+ foo: bool = aristaproto.bool_field(1)
+ bar: int = aristaproto.int32_field(2)
+ baz: List[str] = aristaproto.string_field(3)
+
+
+def test_copyability():
+ msg = PickledMessage(bar=12, baz=["hello"])
+ msg = unpickled(msg)
+
+ copied = copy(msg)
+ assert msg == copied
+ assert msg is not copied
+ assert msg.baz is copied.baz
+
+ deepcopied = deepcopy(msg)
+ assert msg == deepcopied
+ assert msg is not deepcopied
+ assert msg.baz is not deepcopied.baz
+
+
+def test_message_can_be_cached():
+ """Cachelib uses pickling to cache values"""
+
+ cache = cachelib.SimpleCache()
+
+ def use_cache():
+ calls = getattr(use_cache, "calls", 0)
+ result = cache.get("message")
+ if result is not None:
+ return result
+ else:
+ setattr(use_cache, "calls", calls + 1)
+ result = complex_msg()
+ cache.set("message", result)
+ return result
+
+ for n in range(10):
+ if n == 0:
+ assert not cache.has("message")
+ else:
+ assert cache.has("message")
+
+ msg = use_cache()
+ assert use_cache.calls == 1 # The message is only ever built once
+ assert msg.fe.abc == "1"
+ assert not msg.is_set("fi")
+ assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
+ assert msg.mapping["string"].value.decode() == "howdy"
+ assert (
+ msg.nested_data.struct_foo["foo"]
+ .fields["hello"]
+ .list_value.values[0]
+ .string_value
+ == "world"
+ )
diff --git a/tests/test_streams.py b/tests/test_streams.py
new file mode 100644
index 0000000..7ae441b
--- /dev/null
+++ b/tests/test_streams.py
@@ -0,0 +1,434 @@
+from dataclasses import dataclass
+from io import BytesIO
+from pathlib import Path
+from shutil import which
+from subprocess import run
+from typing import Optional
+
+import pytest
+
+import aristaproto
+from tests.output_aristaproto import (
+ map,
+ nested,
+ oneof,
+ repeated,
+ repeatedpacked,
+)
+
+
+oneof_example = oneof.Test().from_dict(
+ {"pitied": 1, "just_a_regular_field": 123456789, "bar_name": "Testing"}
+)
+
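+# len(message) is the message's serialized size in bytes; the stream tests below
+# use it as the expected number of bytes to read back.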
+len_oneof = len(oneof_example)
+
+nested_example = nested.Test().from_dict(
+ {
+ "nested": {"count": 1},
+ "sibling": {"foo": 2},
+ "sibling2": {"foo": 3},
+ "msg": nested.TestMsg.THIS,
+ }
+)
+
+repeated_example = repeated.Test().from_dict({"names": ["blah", "Blah2"]})
+
+packed_example = repeatedpacked.Test().from_dict(
+ {"counts": [1, 2, 3], "signed": [-1, 2, -3], "fixed": [1.2, -2.3, 3.4]}
+)
+
+map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})
+
+streams_path = Path("tests/streams/")
+
+java = which("java")
+
+
+def test_load_varint_too_long():
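+ # A 64-bit varint uses at most 10 bytes (7 payload bits per byte), so an
+ # 11-byte continuation sequence cannot encode a valid value and must be rejected.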
+ with BytesIO(
+ b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01"
+ ) as stream, pytest.raises(ValueError):
+ aristaproto.load_varint(stream)
+
+ with BytesIO(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01") as stream:
+ # This should not raise a ValueError, as it is within 64 bits
+ aristaproto.load_varint(stream)
+
+
+def test_load_varint_file():
+ with open(streams_path / "message_dump_file_single.expected", "rb") as stream:
+ assert aristaproto.load_varint(stream) == (8, b"\x08") # Single-byte varint
+ stream.read(2) # Skip until first multi-byte
+ assert aristaproto.load_varint(stream) == (
+ 123456789,
+ b"\x95\x9A\xEF\x3A",
+ ) # Multi-byte varint
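+ # 123456789 splits into 7-bit groups (LSB first) 0x15, 0x1A, 0x6F, 0x3A;
+ # setting the continuation bit on all but the last byte gives 95 9A EF 3A.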
+
+
+def test_load_varint_cutoff():
+ with open(streams_path / "load_varint_cutoff.in", "rb") as stream:
+ with pytest.raises(EOFError):
+ aristaproto.load_varint(stream)
+
+ stream.seek(1)
+ with pytest.raises(EOFError):
+ aristaproto.load_varint(stream)
+
+
+def test_dump_varint_file(tmp_path):
+ # Dump test varints to file
+ with open(tmp_path / "dump_varint_file.out", "wb") as stream:
+ aristaproto.dump_varint(8, stream) # Single-byte varint
+ aristaproto.dump_varint(123456789, stream) # Multi-byte varint
+
+ # Check that file contents are as expected
+ with open(tmp_path / "dump_varint_file.out", "rb") as test_stream, open(
+ streams_path / "message_dump_file_single.expected", "rb"
+ ) as exp_stream:
+ assert aristaproto.load_varint(test_stream) == aristaproto.load_varint(
+ exp_stream
+ )
+ exp_stream.read(2)
+ assert aristaproto.load_varint(test_stream) == aristaproto.load_varint(
+ exp_stream
+ )
+
+
+def test_parse_fields():
+ with open(streams_path / "message_dump_file_single.expected", "rb") as stream:
+ parsed_bytes = aristaproto.parse_fields(stream.read())
+
+ with open(streams_path / "message_dump_file_single.expected", "rb") as stream:
+ parsed_stream = aristaproto.load_fields(stream)
+ for field in parsed_bytes:
+ assert field == next(parsed_stream)
+
+
+def test_message_dump_file_single(tmp_path):
+ # Write the message to the stream
+ with open(tmp_path / "message_dump_file_single.out", "wb") as stream:
+ oneof_example.dump(stream)
+
+ # Check that the output file is exactly as expected
+ with open(tmp_path / "message_dump_file_single.out", "rb") as test_stream, open(
+ streams_path / "message_dump_file_single.expected", "rb"
+ ) as exp_stream:
+ assert test_stream.read() == exp_stream.read()
+
+
+def test_message_dump_file_multiple(tmp_path):
+ # Write the same Message twice and another, different message
+ with open(tmp_path / "message_dump_file_multiple.out", "wb") as stream:
+ oneof_example.dump(stream)
+ oneof_example.dump(stream)
+ nested_example.dump(stream)
+
+ # Check that all three messages were written to the file correctly
+ with open(tmp_path / "message_dump_file_multiple.out", "rb") as test_stream, open(
+ streams_path / "message_dump_file_multiple.expected", "rb"
+ ) as exp_stream:
+ assert test_stream.read() == exp_stream.read()
+
+
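+# SIZE_DELIMITED prefixes each message with its varint-encoded length; the Java
+# compatibility tests below rely on this matching protobuf's standard delimited framing.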
+def test_message_dump_delimited(tmp_path):
+ with open(tmp_path / "message_dump_delimited.out", "wb") as stream:
+ oneof_example.dump(stream, aristaproto.SIZE_DELIMITED)
+ oneof_example.dump(stream, aristaproto.SIZE_DELIMITED)
+ nested_example.dump(stream, aristaproto.SIZE_DELIMITED)
+
+ with open(tmp_path / "message_dump_delimited.out", "rb") as test_stream, open(
+ streams_path / "delimited_messages.in", "rb"
+ ) as exp_stream:
+ assert test_stream.read() == exp_stream.read()
+
+
+def test_message_len():
+ assert len_oneof == len(bytes(oneof_example))
+ assert len(nested_example) == len(bytes(nested_example))
+
+
+def test_message_load_file_single():
+ with open(streams_path / "message_dump_file_single.expected", "rb") as stream:
+ assert oneof.Test().load(stream) == oneof_example
+ stream.seek(0)
+ assert oneof.Test().load(stream, len_oneof) == oneof_example
+
+
+def test_message_load_file_multiple():
+ with open(streams_path / "message_dump_file_multiple.expected", "rb") as stream:
+ oneof_size = len_oneof
+ assert oneof.Test().load(stream, oneof_size) == oneof_example
+ assert oneof.Test().load(stream, oneof_size) == oneof_example
+ assert nested.Test().load(stream) == nested_example
+ assert stream.read(1) == b""
+
+
+def test_message_load_too_small():
+ with open(
+ streams_path / "message_dump_file_single.expected", "rb"
+ ) as stream, pytest.raises(ValueError):
+ oneof.Test().load(stream, len_oneof - 1)
+
+
+def test_message_load_delimited():
+ with open(streams_path / "delimited_messages.in", "rb") as stream:
+ assert oneof.Test().load(stream, aristaproto.SIZE_DELIMITED) == oneof_example
+ assert oneof.Test().load(stream, aristaproto.SIZE_DELIMITED) == oneof_example
+ assert nested.Test().load(stream, aristaproto.SIZE_DELIMITED) == nested_example
+ assert stream.read(1) == b""
+
+
+def test_message_load_too_large():
+ with open(
+ streams_path / "message_dump_file_single.expected", "rb"
+ ) as stream, pytest.raises(ValueError):
+ oneof.Test().load(stream, len_oneof + 1)
+
+
+def test_message_len_optional_field():
+ @dataclass
+ class Request(aristaproto.Message):
+ flag: Optional[bool] = aristaproto.message_field(1, wraps=aristaproto.TYPE_BOOL)
+
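+ # A wrapped bool is an embedded google.protobuf.BoolValue: tag 0x0A (field 1,
+ # length-delimited), a length byte, then the inner message
+ # (b"\x08\x01" for True, empty for False).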
+ assert len(Request()) == len(b"")
+ assert len(Request(flag=True)) == len(b"\n\x02\x08\x01")
+ assert len(Request(flag=False)) == len(b"\n\x00")
+
+
+def test_message_len_repeated_field():
+ assert len(repeated_example) == len(bytes(repeated_example))
+
+
+def test_message_len_packed_field():
+ assert len(packed_example) == len(bytes(packed_example))
+
+
+def test_message_len_map_field():
+ assert len(map_example) == len(bytes(map_example))
+
+
+def test_message_len_empty_string():
+ @dataclass
+ class Empty(aristaproto.Message):
+ string: str = aristaproto.string_field(1, "group")
+ integer: int = aristaproto.int32_field(2, "group")
+
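+ # A oneof member explicitly set to "" should still be serialized (as a
+ # zero-length field), so the computed length has to account for it.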
+ empty = Empty().from_dict({"string": ""})
+ assert len(empty) == len(bytes(empty))
+
+
+def test_calculate_varint_size_negative():
+ single_byte = -1
+ multi_byte = -10000000
+ edge = -(1 << 63)
+ beyond = -(1 << 63) - 1
+ before = -(1 << 63) + 1
+
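+ # Negative values are sign-extended to 64 bits on the wire, so every negative
+ # varint occupies the full 10 bytes; anything below -(1 << 63) is unrepresentable
+ # and raises ValueError.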
+ assert (
+ aristaproto.size_varint(single_byte)
+ == len(aristaproto.encode_varint(single_byte))
+ == 10
+ )
+ assert (
+ aristaproto.size_varint(multi_byte)
+ == len(aristaproto.encode_varint(multi_byte))
+ == 10
+ )
+ assert aristaproto.size_varint(edge) == len(aristaproto.encode_varint(edge)) == 10
+ assert (
+ aristaproto.size_varint(before) == len(aristaproto.encode_varint(before)) == 10
+ )
+
+ with pytest.raises(ValueError):
+ aristaproto.size_varint(beyond)
+
+
+def test_calculate_varint_size_positive():
+ single_byte = 1
+ multi_byte = 10000000
+
+ assert aristaproto.size_varint(single_byte) == len(
+ aristaproto.encode_varint(single_byte)
+ )
+ assert aristaproto.size_varint(multi_byte) == len(
+ aristaproto.encode_varint(multi_byte)
+ )
+
+
+def test_dump_varint_negative(tmp_path):
+ single_byte = -1
+ multi_byte = -10000000
+ edge = -(1 << 63)
+ beyond = -(1 << 63) - 1
+ before = -(1 << 63) + 1
+
+ with open(tmp_path / "dump_varint_negative.out", "wb") as stream:
+ aristaproto.dump_varint(single_byte, stream)
+ aristaproto.dump_varint(multi_byte, stream)
+ aristaproto.dump_varint(edge, stream)
+ aristaproto.dump_varint(before, stream)
+
+ with pytest.raises(ValueError):
+ aristaproto.dump_varint(beyond, stream)
+
+ with open(streams_path / "dump_varint_negative.expected", "rb") as exp_stream, open(
+ tmp_path / "dump_varint_negative.out", "rb"
+ ) as test_stream:
+ assert test_stream.read() == exp_stream.read()
+
+
+def test_dump_varint_positive(tmp_path):
+ single_byte = 1
+ multi_byte = 10000000
+
+ with open(tmp_path / "dump_varint_positive.out", "wb") as stream:
+ aristaproto.dump_varint(single_byte, stream)
+ aristaproto.dump_varint(multi_byte, stream)
+
+ with open(tmp_path / "dump_varint_positive.out", "rb") as test_stream, open(
+ streams_path / "dump_varint_positive.expected", "rb"
+ ) as exp_stream:
+ assert test_stream.read() == exp_stream.read()
+
+
+# Java compatibility tests
+
+
+@pytest.fixture(scope="module")
+def compile_jar():
+ # Skip if not all required tools are present
+ if java is None:
+ pytest.skip("`java` command not found; it is required for the compatibility tests")
+ mvn = which("mvn")
+ if mvn is None:
+ pytest.skip("`mvn` command not found; Maven is required to build the compatibility JAR")
+
+ # Compile the JAR
+ proc_maven = run([mvn, "clean", "install", "-f", "tests/streams/java/pom.xml"])
+ if proc_maven.returncode != 0:
+ pytest.skip(
+ "Maven compatibility-test.jar build failed (maybe Java version <11?)"
+ )
+
+
+jar = "tests/streams/java/target/compatibility-test.jar"
+
+
+def run_jar(command: str, tmp_path):
+ return run([java, "-jar", jar, command, tmp_path], check=True)
+
+
+def run_java_single_varint(value: int, tmp_path) -> int:
+ # Write single varint to file
+ with open(tmp_path / "py_single_varint.out", "wb") as stream:
+ aristaproto.dump_varint(value, stream)
+
+ # Have Java read this varint and write it back
+ run_jar("single_varint", tmp_path)
+
+ # Read single varint from Java output file
+ with open(tmp_path / "java_single_varint.out", "rb") as stream:
+ returned = aristaproto.load_varint(stream)
+ with pytest.raises(EOFError):
+ aristaproto.load_varint(stream)
+
+ return returned
+
+
+def test_single_varint(compile_jar, tmp_path):
+ single_byte = (1, b"\x01")
+ multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
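+ # load_varint returns (value, raw_bytes), so each comparison checks both the
+ # decoded value and the exact encoding round-tripped through Java.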
+
+ # Write a single-byte varint to a file and have Java read it back
+ returned = run_java_single_varint(single_byte[0], tmp_path)
+ assert returned == single_byte
+
+ # Same for a multi-byte varint
+ returned = run_java_single_varint(multi_byte[0], tmp_path)
+ assert returned == multi_byte
+
+
+def test_multiple_varints(compile_jar, tmp_path):
+ single_byte = (1, b"\x01")
+ multi_byte = (123456789, b"\x95\x9A\xEF\x3A")
+ over32 = (3000000000, b"\x80\xBC\xC1\x96\x0B")
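+ # 3000000000 exceeds a signed 32-bit int, exercising values above
+ # Integer.MAX_VALUE on the Java side.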
+
+ # Write two varints to the same file
+ with open(tmp_path / "py_multiple_varints.out", "wb") as stream:
+ aristaproto.dump_varint(single_byte[0], stream)
+ aristaproto.dump_varint(multi_byte[0], stream)
+ aristaproto.dump_varint(over32[0], stream)
+
+ # Have Java read these varints and write them back
+ run_jar("multiple_varints", tmp_path)
+
+ # Read varints from Java output file
+ with open(tmp_path / "java_multiple_varints.out", "rb") as stream:
+ returned_single = aristaproto.load_varint(stream)
+ returned_multi = aristaproto.load_varint(stream)
+ returned_over32 = aristaproto.load_varint(stream)
+ with pytest.raises(EOFError):
+ aristaproto.load_varint(stream)
+
+ assert returned_single == single_byte
+ assert returned_multi == multi_byte
+ assert returned_over32 == over32
+
+
+def test_single_message(compile_jar, tmp_path):
+ # Write message to file
+ with open(tmp_path / "py_single_message.out", "wb") as stream:
+ oneof_example.dump(stream)
+
+ # Have Java read and return the message
+ run_jar("single_message", tmp_path)
+
+ # Read and check the returned message
+ with open(tmp_path / "java_single_message.out", "rb") as stream:
+ returned = oneof.Test().load(stream, len(bytes(oneof_example)))
+ assert stream.read() == b""
+
+ assert returned == oneof_example
+
+
+def test_multiple_messages(compile_jar, tmp_path):
+ # Write delimited messages to file
+ with open(tmp_path / "py_multiple_messages.out", "wb") as stream:
+ oneof_example.dump(stream, aristaproto.SIZE_DELIMITED)
+ nested_example.dump(stream, aristaproto.SIZE_DELIMITED)
+
+ # Have Java read and return the messages
+ run_jar("multiple_messages", tmp_path)
+
+ # Read and check the returned messages
+ with open(tmp_path / "java_multiple_messages.out", "rb") as stream:
+ returned_oneof = oneof.Test().load(stream, aristaproto.SIZE_DELIMITED)
+ returned_nested = nested.Test().load(stream, aristaproto.SIZE_DELIMITED)
+ assert stream.read() == b""
+
+ assert returned_oneof == oneof_example
+ assert returned_nested == nested_example
+
+
+def test_infinite_messages(compile_jar, tmp_path):
+ num_messages = 5
+
+ # Write delimited messages to file
+ with open(tmp_path / "py_infinite_messages.out", "wb") as stream:
+ for x in range(num_messages):
+ oneof_example.dump(stream, aristaproto.SIZE_DELIMITED)
+
+ # Have Java read and return the messages
+ run_jar("infinite_messages", tmp_path)
+
+ # Read and check the returned messages
+ messages = []
+ with open(tmp_path / "java_infinite_messages.out", "rb") as stream:
+ while True:
+ try:
+ messages.append(oneof.Test().load(stream, aristaproto.SIZE_DELIMITED))
+ except EOFError:
+ break
+
+ assert len(messages) == num_messages
diff --git a/tests/test_struct.py b/tests/test_struct.py
new file mode 100644
index 0000000..c562763
--- /dev/null
+++ b/tests/test_struct.py
@@ -0,0 +1,36 @@
+import json
+
+from aristaproto.lib.google.protobuf import Struct
+from aristaproto.lib.pydantic.google.protobuf import Struct as StructPydantic
+
+
+def test_struct_roundtrip():
+ data = {
+ "foo": "bar",
+ "baz": None,
+ "quux": 123,
+ "zap": [1, {"two": 3}, "four"],
+ }
+ data_json = json.dumps(data)
+
+ struct_from_dict = Struct().from_dict(data)
+ assert struct_from_dict.fields == data
+ assert struct_from_dict.to_dict() == data
+ assert struct_from_dict.to_json() == data_json
+
+ struct_from_json = Struct().from_json(data_json)
+ assert struct_from_json.fields == data
+ assert struct_from_json.to_dict() == data
+ assert struct_from_json == struct_from_dict
+ assert struct_from_json.to_json() == data_json
+
+ struct_pyd_from_dict = StructPydantic(fields={}).from_dict(data)
+ assert struct_pyd_from_dict.fields == data
+ assert struct_pyd_from_dict.to_dict() == data
+ assert struct_pyd_from_dict.to_json() == data_json
+
+ struct_pyd_from_json = StructPydantic(fields={}).from_json(data_json)
+ assert struct_pyd_from_json.fields == data
+ assert struct_pyd_from_json.to_dict() == data
+ assert struct_pyd_from_json == struct_pyd_from_dict
+ assert struct_pyd_from_json.to_json() == data_json
diff --git a/tests/test_timestamp.py b/tests/test_timestamp.py
new file mode 100644
index 0000000..dd51420
--- /dev/null
+++ b/tests/test_timestamp.py
@@ -0,0 +1,27 @@
+from datetime import (
+ datetime,
+ timezone,
+)
+
+import pytest
+
+from aristaproto import _Timestamp
+
+
+@pytest.mark.parametrize(
+ "dt",
+ [
+ datetime(2023, 10, 11, 9, 41, 12, tzinfo=timezone.utc),
+ datetime.now(timezone.utc),
+ # potential issue with floating point precision:
+ datetime(2242, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
+ # potential issue with negative timestamps:
+ datetime(1969, 12, 31, 23, 0, 0, 1, tzinfo=timezone.utc),
+ ],
+)
+def test_timestamp_to_datetime_and_back(dt: datetime):
+ """
+ Make sure converting a datetime to a protobuf timestamp message
+ and then back again ends up with the same datetime.
+ """
+ assert _Timestamp.from_datetime(dt).to_datetime() == dt
diff --git a/tests/test_version.py b/tests/test_version.py
new file mode 100644
index 0000000..bfbe842
--- /dev/null
+++ b/tests/test_version.py
@@ -0,0 +1,16 @@
+from pathlib import Path
+
+import tomlkit
+
+from aristaproto import __version__
+
+
+PROJECT_TOML = Path(__file__).joinpath("..", "..", "pyproject.toml").resolve()
+
+
+def test_version():
+ with PROJECT_TOML.open() as toml_file:
+ project_config = tomlkit.loads(toml_file.read())
+ assert (
+ __version__ == project_config["tool"]["poetry"]["version"]
+ ), "Project version should match in package and package config"
diff --git a/tests/util.py b/tests/util.py
new file mode 100644
index 0000000..2ba7cab
--- /dev/null
+++ b/tests/util.py
@@ -0,0 +1,169 @@
+import asyncio
+import atexit
+import importlib
+import os
+import platform
+import sys
+import tempfile
+from dataclasses import dataclass
+from pathlib import Path
+from types import ModuleType
+from typing import (
+ Callable,
+ Dict,
+ Generator,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+root_path = Path(__file__).resolve().parent
+inputs_path = root_path.joinpath("inputs")
+output_path_reference = root_path.joinpath("output_reference")
+output_path_aristaproto = root_path.joinpath("output_aristaproto")
+output_path_aristaproto_pydantic = root_path.joinpath("output_aristaproto_pydantic")
+
+
+def get_files(path, suffix: str) -> Generator[str, None, None]:
+ for r, dirs, files in os.walk(path):
+ for filename in [f for f in files if f.endswith(suffix)]:
+ yield os.path.join(r, filename)
+
+
+def get_directories(path):
+ for root, directories, files in os.walk(path):
+ yield from directories
+
+
+async def protoc(
+ path: Union[str, Path],
+ output_dir: Union[str, Path],
+ reference: bool = False,
+ pydantic_dataclasses: bool = False,
+):
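+ """
+ Run protoc on every .proto file in `path`, writing generated code to `output_dir`.
+ Depending on the flags this produces reference (google), aristaproto, or
+ pydantic-dataclass output. Returns (stdout, stderr, returncode) from the subprocess.
+ """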
+ path: Path = Path(path).resolve()
+ output_dir: Path = Path(output_dir).resolve()
+ python_out_option: str = "python_aristaproto_out" if not reference else "python_out"
+
+ if pydantic_dataclasses:
+ plugin_path = Path("src/aristaproto/plugin/main.py")
+
+ if "Win" in platform.system():
+ with tempfile.NamedTemporaryFile(
+ "w", encoding="UTF-8", suffix=".bat", delete=False
+ ) as tf:
+ # See https://stackoverflow.com/a/42622705
+ tf.writelines(
+ [
+ "@echo off",
+ f"\nchdir {os.getcwd()}",
+ f"\n{sys.executable} -u {plugin_path.as_posix()}",
+ ]
+ )
+
+ tf.flush()
+
+ plugin_path = Path(tf.name)
+ atexit.register(os.remove, plugin_path)
+
+ command = [
+ sys.executable,
+ "-m",
+ "grpc.tools.protoc",
+ f"--plugin=protoc-gen-custom={plugin_path.as_posix()}",
+ "--experimental_allow_proto3_optional",
+ "--custom_opt=pydantic_dataclasses",
+ f"--proto_path={path.as_posix()}",
+ f"--custom_out={output_dir.as_posix()}",
+ *[p.as_posix() for p in path.glob("*.proto")],
+ ]
+ else:
+ command = [
+ sys.executable,
+ "-m",
+ "grpc.tools.protoc",
+ f"--proto_path={path.as_posix()}",
+ f"--{python_out_option}={output_dir.as_posix()}",
+ *[p.as_posix() for p in path.glob("*.proto")],
+ ]
+ proc = await asyncio.create_subprocess_exec(
+ *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
+ )
+ stdout, stderr = await proc.communicate()
+ return stdout, stderr, proc.returncode
+
+
+@dataclass
+class TestCaseJsonFile:
+ json: str
+ test_name: str
+ file_name: str
+
+ def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]):
+ return self.file_name in non_symmetrical_json.get(self.test_name, tuple())
+
+
+def get_test_case_json_data(
+ test_case_name: str, *json_file_names: str
+) -> List[TestCaseJsonFile]:
+ """
+ :return:
+ A list of all files found in "{inputs_path}/{test_case_name}" whose names match
+ f"{test_case_name}.json" or f"{test_case_name}_*.json", plus any files named in
+ json_file_names
+ """
+ test_case_dir = inputs_path.joinpath(test_case_name)
+ possible_file_paths = [
+ *(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names),
+ test_case_dir.joinpath(f"{test_case_name}.json"),
+ *test_case_dir.glob(f"{test_case_name}_*.json"),
+ ]
+
+ result = []
+ for test_data_file_path in possible_file_paths:
+ if not test_data_file_path.exists():
+ continue
+ with test_data_file_path.open("r") as fh:
+ result.append(
+ TestCaseJsonFile(
+ fh.read(), test_case_name, test_data_file_path.name.split(".")[0]
+ )
+ )
+
+ return result
+
+
+def find_module(
+ module: ModuleType, predicate: Callable[[ModuleType], bool]
+) -> Optional[ModuleType]:
+ """
+ Recursively search module tree for a module that matches the search predicate.
+ Assumes that the submodules are directories containing __init__.py.
+
+ Example:
+
+ # find module inside foo that contains Test
+ import foo
+ test_module = find_module(foo, lambda m: hasattr(m, 'Test'))
+ """
+ if predicate(module):
+ return module
+
+ module_path = Path(*module.__path__)
+
+ for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
+ if sub == module_path:
+ continue
+ sub_module_path = sub.relative_to(module_path)
+ sub_module_name = ".".join(sub_module_path.parts)
+
+ sub_module = importlib.import_module(f".{sub_module_name}", module.__name__)
+
+ if predicate(sub_module):
+ return sub_module
+
+ return None