# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest

import collections
import datetime
import os
import pathlib
import pickle
import subprocess
import string
import sys

import pyarrow as pa
import numpy as np

import pyarrow.tests.util as test_util

try:
    import torch
except ImportError:
    torch = None
    # Blacklist the module in case `import torch` is costly before
    # failing (ARROW-2071)
    sys.modules['torch'] = None

try:
    from scipy.sparse import coo_matrix, csr_matrix, csc_matrix
except ImportError:
    coo_matrix = None
    csr_matrix = None
    csc_matrix = None

try:
    import sparse
except ImportError:
    sparse = None


# ignore all serialization deprecation warnings in this file, we test that the
# warnings are actually raised in test_serialization_deprecated.py
pytestmark = pytest.mark.filterwarnings("ignore:'pyarrow:FutureWarning")


def assert_equal(obj1, obj2):
    """Recursively assert that ``obj1`` and ``obj2`` are equal.

    Dispatch is by type, checked in this order:

    * torch tensors (when torch is importable): densified if sparse, then
      compared with ``torch.equal``;
    * NumPy objects: 0-d values compared with ``==``, everything else with
      ``np.testing.assert_equal``;
    * objects with a ``__dict__``: same attribute names (ignoring
      ``_pytype_``), then each attribute compared recursively;
    * dicts, lists, tuples and named tuples: compared element-wise;
    * pyarrow arrays, tensors, sparse tensors/matrices, record batches
      and tables: compared with their ``.equals`` method;
    * anything else: must have the identical type and compare ``==``.

    Raises:
        AssertionError: on the first mismatch found.
    """
    if torch is not None and torch.is_tensor(obj1) and torch.is_tensor(obj2):
        if obj1.is_sparse:
            obj1 = obj1.to_dense()
        if obj2.is_sparse:
            obj2 = obj2.to_dense()
        assert torch.equal(obj1, obj2)
        return

    module_numpy = (type(obj1).__module__ == np.__name__ or
                    type(obj2).__module__ == np.__name__)
    if module_numpy:
        empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or
                       (hasattr(obj2, "shape") and obj2.shape == ()))
        if empty_shape:
            # This is a special case because currently np.testing.assert_equal
            # fails because we do not properly handle different numerical
            # types.
            assert obj1 == obj2, ("Objects {} and {} are "
                                  "different.".format(obj1, obj2))
        else:
            np.testing.assert_equal(obj1, obj2)
    elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
        # "_pytype_" is added to both sides so that its presence on only
        # one side does not make the attribute sets differ.
        special_keys = ["_pytype_"]
        assert (set(list(obj1.__dict__.keys()) + special_keys) ==
                set(list(obj2.__dict__.keys()) + special_keys)), (
            "Objects {} and {} are different.".format(obj1, obj2))
        if not obj1.__dict__:
            print("WARNING: Empty dict in ", obj1)
        for key in obj1.__dict__:
            if key not in special_keys:
                assert_equal(obj1.__dict__[key], obj2.__dict__[key])
    elif type(obj1) is dict or type(obj2) is dict:
        assert_equal(obj1.keys(), obj2.keys())
        for key in obj1.keys():
            assert_equal(obj1[key], obj2[key])
    elif type(obj1) is list or type(obj2) is list:
        assert len(obj1) == len(obj2), ("Objects {} and {} are lists with "
                                        "different lengths."
                                        .format(obj1, obj2))
        # Lengths are equal, so zip pairs every element.
        for item1, item2 in zip(obj1, obj2):
            assert_equal(item1, item2)
    elif type(obj1) is tuple or type(obj2) is tuple:
        assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with "
                                        "different lengths."
                                        .format(obj1, obj2))
        for item1, item2 in zip(obj1, obj2):
            assert_equal(item1, item2)
    elif (pa.lib.is_named_tuple(type(obj1)) or
          pa.lib.is_named_tuple(type(obj2))):
        assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples "
                                        "with different lengths."
                                        .format(obj1, obj2))
        for item1, item2 in zip(obj1, obj2):
            assert_equal(item1, item2)
    elif any(isinstance(obj1, cls) and isinstance(obj2, cls)
             for cls in (pa.Array, pa.Tensor, pa.SparseCOOTensor,
                         pa.SparseCSRMatrix, pa.SparseCSCMatrix,
                         pa.SparseCSFTensor, pa.RecordBatch, pa.Table)):
        # All pyarrow container types share the same ``.equals`` check;
        # both operands must be instances of the same one of them.
        assert obj1.equals(obj2)
    else:
        assert type(obj1) is type(obj2) and obj1 == obj2, \
            "Objects {} and {} are different.".format(obj1, obj2)


PRIMITIVE_OBJECTS = [
    0, 0.0, 0.9, 1 << 62, 1 << 999,
    [1 << 100, [1 << 100]], "a", string.printable, "\u262F",
    "hello world", "hello world", "\xff\xfe\x9c\x001\x000\x00",
    None, True, False, [], (), {}, {(1, 2): 1}, {(): 2},
    [1, "hello", 3.0], "\u262F", 42.0, (1.0, "hi"),
    [1, 2, 3, None], [(None,), 3, 1.0], ["h", "e", "l", "l", "o", None],
    (None, None), ("hello", None), (True, False),
    {True: "hello", False: "world"}, {"hello": "world", 1: 42, 2.5: 45},
    {"hello": {2, 3}, "world": {42.0}, "this": None},
    np.int8(3), np.int32(4), np.int64(5),
    np.uint8(3), np.uint32(4), np.uint64(5),
    np.float16(1.9), np.float32(1.9),
    np.float64(1.9), np.zeros([8, 20]),
    np.random.normal(size=[17, 10]), np.array(["hi", 3]),
    np.array(["hi", 3], dtype=object),
    np.random.normal(size=[15, 13]).T
]


index_types = ('i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8')
tensor_types = ('i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8',
                'f2', 'f4', 'f8')

PRIMITIVE_OBJECTS += [0, np.array([["hi", "hi"], [1.3, 1]])]


COMPLEX_OBJECTS = [
    [[[[[[[[[[[[]]]]]]]]]]]],
    {"obj{}".format(i): np.random.normal(size=[4, 4]) for i in range(5)},
    # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
    #      (): {(): {}}}}}}}}}}}}},
    ((((((((((),),),),),),),),),),
    {"a": {"b": {"c": {"d": {}}}}},
]


class Foo:
    # Simple hashable value holder; equality and hash use ``value`` only.
    def __init__(self, value=0):
        self.value = value

    def __hash__(self):
        return hash(self.value)

    def __eq__(self, other):
        return other.value == self.value


class Bar:
    # Carries every COMPLEX_OBJECTS entry as attributes field0, field1, ...
    def __init__(self):
        for i, val in enumerate(COMPLEX_OBJECTS):
            setattr(self, "field{}".format(i), val)


class Baz:
    def __init__(self):
        self.foo = Foo()
        self.bar = Bar()

    def method(self, arg):
        pass


class Qux:
    def __init__(self):
        self.objs = [Foo(1), Foo(42)]


class SubQux(Qux):
    def __init__(self):
        Qux.__init__(self)


class SubQuxPickle(Qux):
    def __init__(self):
        Qux.__init__(self)


class CustomError(Exception):
    pass


Point = collections.namedtuple("Point", ["x", "y"])
NamedTupleExample = collections.namedtuple(
    "Example", "field1, field2, field3, field4, field5")


CUSTOM_OBJECTS = [Exception("Test object."), CustomError(),
                  Point(11, y=22),
                  Foo(), Bar(), Baz(), Qux(), SubQux(), SubQuxPickle(),
                  NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]),
                  collections.OrderedDict([("hello", 1), ("world", 2)]),
                  collections.deque([1, 2, 3, "a", "b", "c", 3.5]),
                  collections.Counter([1, 1, 1, 2, 2, 3, "a", "b"])]


def make_serialization_context():
    """Return a default serialization context with this module's test
    types registered (``SubQuxPickle`` is registered with ``pickle=True``).
    """
    # Creating the default context is expected to emit a FutureWarning.
    with pytest.warns(FutureWarning):
        context = pa.default_serialization_context()

    context.register_type(Foo, "Foo")
    context.register_type(Bar, "Bar")
    context.register_type(Baz, "Baz")
    context.register_type(Qux, "Quz")
    context.register_type(SubQux, "SubQux")
    context.register_type(SubQuxPickle, "SubQuxPickle", pickle=True)
    context.register_type(Exception, "Exception")
    context.register_type(CustomError, "CustomError")
    context.register_type(Point, "Point")
    context.register_type(NamedTupleExample, "NamedTupleExample")

    return context
global_serialization_context = make_serialization_context() def serialization_roundtrip(value, scratch_buffer, context=global_serialization_context): writer = pa.FixedSizeBufferWriter(scratch_buffer) pa.serialize_to(value, writer, context=context) reader = pa.BufferReader(scratch_buffer) result = pa.deserialize_from(reader, None, context=context) assert_equal(value, result) _check_component_roundtrip(value, context=context) def _check_component_roundtrip(value, context=global_serialization_context): # Test to/from components serialized = pa.serialize(value, context=context) components = serialized.to_components() from_comp = pa.SerializedPyObject.from_components(components) recons = from_comp.deserialize(context=context) assert_equal(value, recons) @pytest.fixture(scope='session') def large_buffer(size=32*1024*1024): yield pa.allocate_buffer(size) def large_memory_map(tmpdir_factory, size=100*1024*1024): path = (tmpdir_factory.mktemp('data') .join('pyarrow-serialization-tmp-file').strpath) # Create a large memory mapped file with open(path, 'wb') as f: f.write(np.random.randint(0, 256, size=size) .astype('u1') .tobytes() [:size]) return path def test_clone(): context = pa.SerializationContext() class Foo: pass def custom_serializer(obj): return 0 def custom_deserializer(serialized_obj): return (serialized_obj, 'a') context.register_type(Foo, 'Foo', custom_serializer=custom_serializer, custom_deserializer=custom_deserializer) new_context = context.clone() f = Foo() serialized = pa.serialize(f, context=context) deserialized = serialized.deserialize(context=context) assert deserialized == (0, 'a') serialized = pa.serialize(f, context=new_context) deserialized = serialized.deserialize(context=new_context) assert deserialized == (0, 'a') def test_primitive_serialization_notbroken(large_buffer): serialization_roundtrip({(1, 2): 2}, large_buffer) def test_primitive_serialization_broken(large_buffer): serialization_roundtrip({(): 2}, large_buffer) def 
test_primitive_serialization(large_buffer): for obj in PRIMITIVE_OBJECTS: serialization_roundtrip(obj, large_buffer) def test_integer_limits(large_buffer): # Check that Numpy scalars can be represented up to their limit values # (except np.uint64 which is limited to 2**63 - 1) for dt in [np.int8, np.int64, np.int32, np.int64, np.uint8, np.uint64, np.uint32, np.uint64]: scal = dt(np.iinfo(dt).min) serialization_roundtrip(scal, large_buffer) if dt is not np.uint64: scal = dt(np.iinfo(dt).max) serialization_roundtrip(scal, large_buffer) else: scal = dt(2**63 - 1) serialization_roundtrip(scal, large_buffer) for v in (2**63, 2**64 - 1): scal = dt(v) with pytest.raises(pa.ArrowInvalid): pa.serialize(scal) def test_serialize_to_buffer(): for nthreads in [1, 4]: for value in COMPLEX_OBJECTS: buf = pa.serialize(value).to_buffer(nthreads=nthreads) result = pa.deserialize(buf) assert_equal(value, result) def test_complex_serialization(large_buffer): for obj in COMPLEX_OBJECTS: serialization_roundtrip(obj, large_buffer) def test_custom_serialization(large_buffer): for obj in CUSTOM_OBJECTS: serialization_roundtrip(obj, large_buffer) def test_default_dict_serialization(large_buffer): pytest.importorskip("cloudpickle") obj = collections.defaultdict(lambda: 0, [("hello", 1), ("world", 2)]) serialization_roundtrip(obj, large_buffer) def test_numpy_serialization(large_buffer): for t in ["bool", "int8", "uint8", "int16", "uint16", "int32", "uint32", "float16", "float32", "float64", "