diff options
Diffstat (limited to '')
-rw-r--r-- | tests/test_pickling.py | 203 |
1 files changed, 203 insertions, 0 deletions
diff --git a/tests/test_pickling.py b/tests/test_pickling.py new file mode 100644 index 0000000..2356d98 --- /dev/null +++ b/tests/test_pickling.py @@ -0,0 +1,203 @@ +import pickle +from copy import ( + copy, + deepcopy, +) +from dataclasses import dataclass +from typing import ( + Dict, + List, +) +from unittest.mock import ANY + +import cachelib + +import aristaproto +from aristaproto.lib.google import protobuf as google + + +def unpickled(message): + return pickle.loads(pickle.dumps(message)) + + +@dataclass(eq=False, repr=False) +class Fe(aristaproto.Message): + abc: str = aristaproto.string_field(1) + + +@dataclass(eq=False, repr=False) +class Fi(aristaproto.Message): + abc: str = aristaproto.string_field(1) + + +@dataclass(eq=False, repr=False) +class Fo(aristaproto.Message): + abc: str = aristaproto.string_field(1) + + +@dataclass(eq=False, repr=False) +class NestedData(aristaproto.Message): + struct_foo: Dict[str, "google.Struct"] = aristaproto.map_field( + 1, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE + ) + map_str_any_bar: Dict[str, "google.Any"] = aristaproto.map_field( + 2, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE + ) + + +@dataclass(eq=False, repr=False) +class Complex(aristaproto.Message): + foo_str: str = aristaproto.string_field(1) + fe: "Fe" = aristaproto.message_field(3, group="grp") + fi: "Fi" = aristaproto.message_field(4, group="grp") + fo: "Fo" = aristaproto.message_field(5, group="grp") + nested_data: "NestedData" = aristaproto.message_field(6) + mapping: Dict[str, "google.Any"] = aristaproto.map_field( + 7, aristaproto.TYPE_STRING, aristaproto.TYPE_MESSAGE + ) + + +def complex_msg(): + return Complex( + foo_str="yep", + fe=Fe(abc="1"), + nested_data=NestedData( + struct_foo={ + "foo": google.Struct( + fields={ + "hello": google.Value( + list_value=google.ListValue( + values=[google.Value(string_value="world")] + ) + ) + } + ), + }, + map_str_any_bar={ + "key": google.Any(value=b"value"), + }, + ), + mapping={ + "message": google.Any(value=bytes(Fi(abc="hi"))), + "string": google.Any(value=b"howdy"), + }, + ) + + +def test_pickling_complex_message(): + msg = complex_msg() + deser = unpickled(msg) + assert msg == deser + assert msg.fe.abc == "1" + assert msg.is_set("fi") is not True + assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi"))) + assert msg.mapping["string"].value.decode() == "howdy" + assert ( + msg.nested_data.struct_foo["foo"] + .fields["hello"] + .list_value.values[0] + .string_value + == "world" + ) + + +def test_recursive_message(): + from tests.output_aristaproto.recursivemessage import Test as RecursiveMessage + + msg = RecursiveMessage() + msg = unpickled(msg) + + assert msg.child == RecursiveMessage() + + # Lazily-created zero-value children must not affect equality. + assert msg == RecursiveMessage() + + # Lazily-created zero-value children must not affect serialization. + assert bytes(msg) == b"" + + +def test_recursive_message_defaults(): + from tests.output_aristaproto.recursivemessage import ( + Intermediate, + Test as RecursiveMessage, + ) + + msg = RecursiveMessage(name="bob", intermediate=Intermediate(42)) + msg = unpickled(msg) + + # set values are as expected + assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42)) + + # lazy initialized works modifies the message + assert msg != RecursiveMessage( + name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude") + ) + msg.child.child.name = "jude" + assert msg == RecursiveMessage( + name="bob", + intermediate=Intermediate(42), + child=RecursiveMessage(child=RecursiveMessage(name="jude")), + ) + + # lazily initialization recurses as needed + assert msg.child.child.child.child.child.child.child == RecursiveMessage() + assert msg.intermediate.child.intermediate == Intermediate() + + +@dataclass +class PickledMessage(aristaproto.Message): + foo: bool = aristaproto.bool_field(1) + bar: int = aristaproto.int32_field(2) + baz: List[str] = aristaproto.string_field(3) + + +def test_copyability(): + msg = PickledMessage(bar=12, baz=["hello"]) + msg = unpickled(msg) + + copied = copy(msg) + assert msg == copied + assert msg is not copied + assert msg.baz is copied.baz + + deepcopied = deepcopy(msg) + assert msg == deepcopied + assert msg is not deepcopied + assert msg.baz is not deepcopied.baz + + +def test_message_can_be_cached(): + """Cachelib uses pickling to cache values""" + + cache = cachelib.SimpleCache() + + def use_cache(): + calls = getattr(use_cache, "calls", 0) + result = cache.get("message") + if result is not None: + return result + else: + setattr(use_cache, "calls", calls + 1) + result = complex_msg() + cache.set("message", result) + return result + + for n in range(10): + if n == 0: + assert not cache.has("message") + else: + assert cache.has("message") + + msg = use_cache() + assert use_cache.calls == 1 # The message is only ever built once + assert msg.fe.abc == "1" + assert msg.is_set("fi") is not True + assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi"))) + assert msg.mapping["string"].value.decode() == "howdy" + assert ( + msg.nested_data.struct_foo["foo"] + .fields["hello"] + .list_value.values[0] + .string_value + == "world" + ) |