summaryrefslogtreecommitdiffstats
path: root/tests/test_features.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_features.py')
-rw-r--r--tests/test_features.py682
1 files changed, 682 insertions, 0 deletions
diff --git a/tests/test_features.py b/tests/test_features.py
new file mode 100644
index 0000000..638e668
--- /dev/null
+++ b/tests/test_features.py
@@ -0,0 +1,682 @@
+import json
+import sys
+from copy import (
+ copy,
+ deepcopy,
+)
+from dataclasses import dataclass
+from datetime import (
+ datetime,
+ timedelta,
+)
+from inspect import (
+ Parameter,
+ signature,
+)
+from typing import (
+ Dict,
+ List,
+ Optional,
+)
+from unittest.mock import ANY
+
+import pytest
+
+import aristaproto
+
+
+def test_has_field():
+ @dataclass
+ class Bar(aristaproto.Message):
+ baz: int = aristaproto.int32_field(1)
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ bar: Bar = aristaproto.message_field(1)
+
+ # Unset by default
+ foo = Foo()
+ assert aristaproto.serialized_on_wire(foo.bar) is False
+
+ # Serialized after setting something
+ foo.bar.baz = 1
+ assert aristaproto.serialized_on_wire(foo.bar) is True
+
+ # Still has it after setting the default value
+ foo.bar.baz = 0
+ assert aristaproto.serialized_on_wire(foo.bar) is True
+
+ # Manual override (don't do this)
+ foo.bar._serialized_on_wire = False
+ assert aristaproto.serialized_on_wire(foo.bar) is False
+
+ # Can manually set it but defaults to false
+ foo.bar = Bar()
+ assert aristaproto.serialized_on_wire(foo.bar) is False
+
+ @dataclass
+ class WithCollections(aristaproto.Message):
+ test_list: List[str] = aristaproto.string_field(1)
+ test_map: Dict[str, str] = aristaproto.map_field(
+ 2, aristaproto.TYPE_STRING, aristaproto.TYPE_STRING
+ )
+
+ # Is always set from parse, even if all collections are empty
+ with_collections_empty = WithCollections().parse(bytes(WithCollections()))
+ assert aristaproto.serialized_on_wire(with_collections_empty) == True
+ with_collections_list = WithCollections().parse(
+ bytes(WithCollections(test_list=["a", "b", "c"]))
+ )
+ assert aristaproto.serialized_on_wire(with_collections_list) == True
+ with_collections_map = WithCollections().parse(
+ bytes(WithCollections(test_map={"a": "b", "c": "d"}))
+ )
+ assert aristaproto.serialized_on_wire(with_collections_map) == True
+
+
+def test_class_init():
+ @dataclass
+ class Bar(aristaproto.Message):
+ name: str = aristaproto.string_field(1)
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ name: str = aristaproto.string_field(1)
+ child: Bar = aristaproto.message_field(2)
+
+ foo = Foo(name="foo", child=Bar(name="bar"))
+
+ assert foo.to_dict() == {"name": "foo", "child": {"name": "bar"}}
+ assert foo.to_pydict() == {"name": "foo", "child": {"name": "bar"}}
+
+
+def test_enum_as_int_json():
+ class TestEnum(aristaproto.Enum):
+ ZERO = 0
+ ONE = 1
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ bar: TestEnum = aristaproto.enum_field(1)
+
+ # JSON strings are supported, but ints should still be supported too.
+ foo = Foo().from_dict({"bar": 1})
+ assert foo.bar == TestEnum.ONE
+
+ # Plain-ol'-ints should serialize properly too.
+ foo.bar = 1
+ assert foo.to_dict() == {"bar": "ONE"}
+
+ # Similar expectations for pydict
+ foo = Foo().from_pydict({"bar": 1})
+ assert foo.bar == TestEnum.ONE
+ assert foo.to_pydict() == {"bar": TestEnum.ONE}
+
+
+def test_unknown_fields():
+ @dataclass
+ class Newer(aristaproto.Message):
+ foo: bool = aristaproto.bool_field(1)
+ bar: int = aristaproto.int32_field(2)
+ baz: str = aristaproto.string_field(3)
+
+ @dataclass
+ class Older(aristaproto.Message):
+ foo: bool = aristaproto.bool_field(1)
+
+ newer = Newer(foo=True, bar=1, baz="Hello")
+ serialized_newer = bytes(newer)
+
+ # Unknown fields in `Newer` should round trip with `Older`
+ round_trip = bytes(Older().parse(serialized_newer))
+ assert serialized_newer == round_trip
+
+ new_again = Newer().parse(round_trip)
+ assert newer == new_again
+
+
+def test_oneof_support():
+ @dataclass
+ class Sub(aristaproto.Message):
+ val: int = aristaproto.int32_field(1)
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ bar: int = aristaproto.int32_field(1, group="group1")
+ baz: str = aristaproto.string_field(2, group="group1")
+ sub: Sub = aristaproto.message_field(3, group="group2")
+ abc: str = aristaproto.string_field(4, group="group2")
+
+ foo = Foo()
+
+ assert aristaproto.which_one_of(foo, "group1")[0] == ""
+
+ foo.bar = 1
+ foo.baz = "test"
+
+ # Other oneof fields should now be unset
+ assert not hasattr(foo, "bar")
+ assert object.__getattribute__(foo, "bar") == aristaproto.PLACEHOLDER
+ assert aristaproto.which_one_of(foo, "group1")[0] == "baz"
+
+ foo.sub = Sub(val=1)
+ assert aristaproto.serialized_on_wire(foo.sub)
+
+ foo.abc = "test"
+
+ # Group 1 shouldn't be touched, group 2 should have reset
+ assert not hasattr(foo, "sub")
+ assert object.__getattribute__(foo, "sub") == aristaproto.PLACEHOLDER
+ assert aristaproto.which_one_of(foo, "group2")[0] == "abc"
+
+ # Zero value should always serialize for one-of
+ foo = Foo(bar=0)
+ assert aristaproto.which_one_of(foo, "group1")[0] == "bar"
+ assert bytes(foo) == b"\x08\x00"
+
+ # Round trip should also work
+ foo2 = Foo().parse(bytes(foo))
+ assert aristaproto.which_one_of(foo2, "group1")[0] == "bar"
+ assert foo.bar == 0
+ assert aristaproto.which_one_of(foo2, "group2")[0] == ""
+
+
+@pytest.mark.skipif(
+ sys.version_info < (3, 10),
+ reason="pattern matching is only supported in python3.10+",
+)
+def test_oneof_pattern_matching():
+ from .oneof_pattern_matching import test_oneof_pattern_matching
+
+ test_oneof_pattern_matching()
+
+
+def test_json_casing():
+ @dataclass
+ class CasingTest(aristaproto.Message):
+ pascal_case: int = aristaproto.int32_field(1)
+ camel_case: int = aristaproto.int32_field(2)
+ snake_case: int = aristaproto.int32_field(3)
+ kabob_case: int = aristaproto.int32_field(4)
+
+ # Parsing should accept almost any input
+ test = CasingTest().from_dict(
+ {"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}
+ )
+
+ assert test == CasingTest(1, 2, 3, 4)
+
+ # Serializing should be strict.
+ assert json.loads(test.to_json()) == {
+ "pascalCase": 1,
+ "camelCase": 2,
+ "snakeCase": 3,
+ "kabobCase": 4,
+ }
+
+ assert json.loads(test.to_json(casing=aristaproto.Casing.SNAKE)) == {
+ "pascal_case": 1,
+ "camel_case": 2,
+ "snake_case": 3,
+ "kabob_case": 4,
+ }
+
+
+def test_dict_casing():
+ @dataclass
+ class CasingTest(aristaproto.Message):
+ pascal_case: int = aristaproto.int32_field(1)
+ camel_case: int = aristaproto.int32_field(2)
+ snake_case: int = aristaproto.int32_field(3)
+ kabob_case: int = aristaproto.int32_field(4)
+
+ # Parsing should accept almost any input
+ test = CasingTest().from_dict(
+ {"PascalCase": 1, "camelCase": 2, "snake_case": 3, "kabob-case": 4}
+ )
+
+ assert test == CasingTest(1, 2, 3, 4)
+
+ # Serializing should be strict.
+ assert test.to_dict() == {
+ "pascalCase": 1,
+ "camelCase": 2,
+ "snakeCase": 3,
+ "kabobCase": 4,
+ }
+ assert test.to_pydict() == {
+ "pascalCase": 1,
+ "camelCase": 2,
+ "snakeCase": 3,
+ "kabobCase": 4,
+ }
+
+ assert test.to_dict(casing=aristaproto.Casing.SNAKE) == {
+ "pascal_case": 1,
+ "camel_case": 2,
+ "snake_case": 3,
+ "kabob_case": 4,
+ }
+ assert test.to_pydict(casing=aristaproto.Casing.SNAKE) == {
+ "pascal_case": 1,
+ "camel_case": 2,
+ "snake_case": 3,
+ "kabob_case": 4,
+ }
+
+
+def test_optional_flag():
+ @dataclass
+ class Request(aristaproto.Message):
+ flag: Optional[bool] = aristaproto.message_field(1, wraps=aristaproto.TYPE_BOOL)
+
+ # Serialization of not passed vs. set vs. zero-value.
+ assert bytes(Request()) == b""
+ assert bytes(Request(flag=True)) == b"\n\x02\x08\x01"
+ assert bytes(Request(flag=False)) == b"\n\x00"
+
+ # Differentiate between not passed and the zero-value.
+ assert Request().parse(b"").flag is None
+ assert Request().parse(b"\n\x00").flag is False
+
+
+def test_optional_datetime_to_dict():
+ @dataclass
+ class Request(aristaproto.Message):
+ date: Optional[datetime] = aristaproto.message_field(1, optional=True)
+
+ # Check dict serialization
+ assert Request().to_dict() == {}
+ assert Request().to_dict(include_default_values=True) == {"date": None}
+ assert Request(date=datetime(2020, 1, 1)).to_dict() == {
+ "date": "2020-01-01T00:00:00Z"
+ }
+ assert Request(date=datetime(2020, 1, 1)).to_dict(include_default_values=True) == {
+ "date": "2020-01-01T00:00:00Z"
+ }
+
+ # Check pydict serialization
+ assert Request().to_pydict() == {}
+ assert Request().to_pydict(include_default_values=True) == {"date": None}
+ assert Request(date=datetime(2020, 1, 1)).to_pydict() == {
+ "date": datetime(2020, 1, 1)
+ }
+ assert Request(date=datetime(2020, 1, 1)).to_pydict(
+ include_default_values=True
+ ) == {"date": datetime(2020, 1, 1)}
+
+
+def test_to_json_default_values():
+ @dataclass
+ class TestMessage(aristaproto.Message):
+ some_int: int = aristaproto.int32_field(1)
+ some_double: float = aristaproto.double_field(2)
+ some_str: str = aristaproto.string_field(3)
+ some_bool: bool = aristaproto.bool_field(4)
+
+ # Empty dict
+ test = TestMessage().from_dict({})
+
+ assert json.loads(test.to_json(include_default_values=True)) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ # All default values
+ test = TestMessage().from_dict(
+ {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False}
+ )
+
+ assert json.loads(test.to_json(include_default_values=True)) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+
+def test_to_dict_default_values():
+ @dataclass
+ class TestMessage(aristaproto.Message):
+ some_int: int = aristaproto.int32_field(1)
+ some_double: float = aristaproto.double_field(2)
+ some_str: str = aristaproto.string_field(3)
+ some_bool: bool = aristaproto.bool_field(4)
+
+ # Empty dict
+ test = TestMessage().from_dict({})
+
+ assert test.to_dict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ test = TestMessage().from_pydict({})
+
+ assert test.to_pydict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ # All default values
+ test = TestMessage().from_dict(
+ {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False}
+ )
+
+ assert test.to_dict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ test = TestMessage().from_pydict(
+ {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False}
+ )
+
+ assert test.to_pydict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 0.0,
+ "someStr": "",
+ "someBool": False,
+ }
+
+ # Some default and some other values
+ @dataclass
+ class TestMessage2(aristaproto.Message):
+ some_int: int = aristaproto.int32_field(1)
+ some_double: float = aristaproto.double_field(2)
+ some_str: str = aristaproto.string_field(3)
+ some_bool: bool = aristaproto.bool_field(4)
+ some_default_int: int = aristaproto.int32_field(5)
+ some_default_double: float = aristaproto.double_field(6)
+ some_default_str: str = aristaproto.string_field(7)
+ some_default_bool: bool = aristaproto.bool_field(8)
+
+ test = TestMessage2().from_dict(
+ {
+ "someInt": 2,
+ "someDouble": 1.2,
+ "someStr": "hello",
+ "someBool": True,
+ "someDefaultInt": 0,
+ "someDefaultDouble": 0.0,
+ "someDefaultStr": "",
+ "someDefaultBool": False,
+ }
+ )
+
+ assert test.to_dict(include_default_values=True) == {
+ "someInt": 2,
+ "someDouble": 1.2,
+ "someStr": "hello",
+ "someBool": True,
+ "someDefaultInt": 0,
+ "someDefaultDouble": 0.0,
+ "someDefaultStr": "",
+ "someDefaultBool": False,
+ }
+
+ test = TestMessage2().from_pydict(
+ {
+ "someInt": 2,
+ "someDouble": 1.2,
+ "someStr": "hello",
+ "someBool": True,
+ "someDefaultInt": 0,
+ "someDefaultDouble": 0.0,
+ "someDefaultStr": "",
+ "someDefaultBool": False,
+ }
+ )
+
+ assert test.to_pydict(include_default_values=True) == {
+ "someInt": 2,
+ "someDouble": 1.2,
+ "someStr": "hello",
+ "someBool": True,
+ "someDefaultInt": 0,
+ "someDefaultDouble": 0.0,
+ "someDefaultStr": "",
+ "someDefaultBool": False,
+ }
+
+ # Nested messages
+ @dataclass
+ class TestChildMessage(aristaproto.Message):
+ some_other_int: int = aristaproto.int32_field(1)
+
+ @dataclass
+ class TestParentMessage(aristaproto.Message):
+ some_int: int = aristaproto.int32_field(1)
+ some_double: float = aristaproto.double_field(2)
+ some_message: TestChildMessage = aristaproto.message_field(3)
+
+ test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2})
+
+ assert test.to_dict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 1.2,
+ "someMessage": {"someOtherInt": 0},
+ }
+
+ test = TestParentMessage().from_pydict({"someInt": 0, "someDouble": 1.2})
+
+ assert test.to_pydict(include_default_values=True) == {
+ "someInt": 0,
+ "someDouble": 1.2,
+ "someMessage": {"someOtherInt": 0},
+ }
+
+
+def test_to_dict_datetime_values():
+ @dataclass
+ class TestDatetimeMessage(aristaproto.Message):
+ bar: datetime = aristaproto.message_field(1)
+ baz: timedelta = aristaproto.message_field(2)
+
+ test = TestDatetimeMessage().from_dict(
+ {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"}
+ )
+
+ assert test.to_dict() == {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"}
+
+ test = TestDatetimeMessage().from_pydict(
+ {"bar": datetime(year=2020, month=1, day=1), "baz": timedelta(days=1)}
+ )
+
+ assert test.to_pydict() == {
+ "bar": datetime(year=2020, month=1, day=1),
+ "baz": timedelta(days=1),
+ }
+
+
+def test_oneof_default_value_set_causes_writes_wire():
+ @dataclass
+ class Empty(aristaproto.Message):
+ pass
+
+ @dataclass
+ class Foo(aristaproto.Message):
+ bar: int = aristaproto.int32_field(1, group="group1")
+ baz: str = aristaproto.string_field(2, group="group1")
+ qux: Empty = aristaproto.message_field(3, group="group1")
+
+ def _round_trip_serialization(foo: Foo) -> Foo:
+ return Foo().parse(bytes(foo))
+
+ foo1 = Foo(bar=0)
+ foo2 = Foo(baz="")
+ foo3 = Foo(qux=Empty())
+ foo4 = Foo()
+
+ assert bytes(foo1) == b"\x08\x00"
+ assert (
+ aristaproto.which_one_of(foo1, "group1")
+ == aristaproto.which_one_of(_round_trip_serialization(foo1), "group1")
+ == ("bar", 0)
+ )
+
+ assert bytes(foo2) == b"\x12\x00" # Baz is just an empty string
+ assert (
+ aristaproto.which_one_of(foo2, "group1")
+ == aristaproto.which_one_of(_round_trip_serialization(foo2), "group1")
+ == ("baz", "")
+ )
+
+ assert bytes(foo3) == b"\x1a\x00"
+ assert (
+ aristaproto.which_one_of(foo3, "group1")
+ == aristaproto.which_one_of(_round_trip_serialization(foo3), "group1")
+ == ("qux", Empty())
+ )
+
+ assert bytes(foo4) == b""
+ assert (
+ aristaproto.which_one_of(foo4, "group1")
+ == aristaproto.which_one_of(_round_trip_serialization(foo4), "group1")
+ == ("", None)
+ )
+
+
+def test_message_repr():
+ from tests.output_aristaproto.recursivemessage import Test
+
+ assert repr(Test(name="Loki")) == "Test(name='Loki')"
+ assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())"
+
+
+def test_bool():
+ """Messages should evaluate similarly to a collection
+ >>> test = []
+ >>> bool(test)
+ ... False
+ >>> test.append(1)
+ >>> bool(test)
+ ... True
+ >>> del test[0]
+ >>> bool(test)
+ ... False
+ """
+
+ @dataclass
+ class Falsy(aristaproto.Message):
+ pass
+
+ @dataclass
+ class Truthy(aristaproto.Message):
+ bar: int = aristaproto.int32_field(1)
+
+ assert not Falsy()
+ t = Truthy()
+ assert not t
+ t.bar = 1
+ assert t
+ t.bar = 0
+ assert not t
+
+
+# valid ISO datetimes according to https://www.myintervals.com/blog/2009/05/20/iso-8601-date-validation-that-doesnt-suck/
+iso_candidates = """2009-12-12T12:34
+2009
+2009-05-19
+2009-05-19
+20090519
+2009123
+2009-05
+2009-123
+2009-222
+2009-001
+2009-W01-1
+2009-W51-1
+2009-W33
+2009W511
+2009-05-19
+2009-05-19 00:00
+2009-05-19 14
+2009-05-19 14:31
+2009-05-19 14:39:22
+2009-05-19T14:39Z
+2009-W21-2
+2009-W21-2T01:22
+2009-139
+2009-05-19 14:39:22-06:00
+2009-05-19 14:39:22+0600
+2009-05-19 14:39:22-01
+20090621T0545Z
+2007-04-06T00:00
+2007-04-05T24:00
+2010-02-18T16:23:48.5
+2010-02-18T16:23:48,444
+2010-02-18T16:23:48,3-06:00
+2010-02-18T16:23:00.4
+2010-02-18T16:23:00,25
+2010-02-18T16:23:00.33+0600
+2010-02-18T16:00:00.23334444
+2010-02-18T16:00:00,2283
+2009-05-19 143922
+2009-05-19 1439""".split(
+ "\n"
+)
+
+
+def test_iso_datetime():
+ @dataclass
+ class Envelope(aristaproto.Message):
+ ts: datetime = aristaproto.message_field(1)
+
+ msg = Envelope()
+
+ for _, candidate in enumerate(iso_candidates):
+ msg.from_dict({"ts": candidate})
+ assert isinstance(msg.ts, datetime)
+
+
+def test_iso_datetime_list():
+ @dataclass
+ class Envelope(aristaproto.Message):
+ timestamps: List[datetime] = aristaproto.message_field(1)
+
+ msg = Envelope()
+
+ msg.from_dict({"timestamps": iso_candidates})
+ assert all([isinstance(item, datetime) for item in msg.timestamps])
+
+
+def test_service_argument__expected_parameter():
+ from tests.output_aristaproto.service import TestStub
+
+ sig = signature(TestStub.do_thing)
+ do_thing_request_parameter = sig.parameters["do_thing_request"]
+ assert do_thing_request_parameter.default is Parameter.empty
+ assert do_thing_request_parameter.annotation == "DoThingRequest"
+
+
+def test_is_set():
+ @dataclass
+ class Spam(aristaproto.Message):
+ foo: bool = aristaproto.bool_field(1)
+ bar: Optional[int] = aristaproto.int32_field(2, optional=True)
+
+ assert not Spam().is_set("foo")
+ assert not Spam().is_set("bar")
+ assert Spam(foo=True).is_set("foo")
+ assert Spam(foo=True, bar=0).is_set("bar")
+
+
+def test_equality_comparison():
+ from tests.output_aristaproto.bool import Test as TestMessage
+
+ msg = TestMessage(value=True)
+
+ assert msg == msg
+ assert msg == ANY
+ assert msg == TestMessage(value=True)
+ assert msg != 1
+ assert msg != TestMessage(value=False)