# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import collections
import warnings

import numpy as np

import pyarrow as pa
from pyarrow.lib import SerializationContext, py_buffer, builtin_pickle

try:
    import cloudpickle
except ImportError:
    cloudpickle = builtin_pickle

try:
    # This function is available from numpy 1.16.0 onwards.
    # See also: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py
    from numpy.lib.format import descr_to_dtype
except ImportError:
    def descr_to_dtype(descr):
        '''
        descr may be stored as dtype.descr, which is a list of
        (name, format, [shape]) tuples where format may be a str or a tuple.
        Offsets are not explicitly saved; instead, empty fields with
        name, format == '', '|Vn' are added as padding. This function
        reverses the process, eliminating the empty padding fields.
        '''
        if isinstance(descr, str):
            # No padding removal needed
            return np.dtype(descr)
        elif isinstance(descr, tuple):
            # subtype, will always have a shape descr[1]
            dt = descr_to_dtype(descr[0])
            return np.dtype((dt, descr[1]))
        fields = []
        offset = 0
        for field in descr:
            if len(field) == 2:
                name, descr_str = field
                dt = descr_to_dtype(descr_str)
            else:
                name, descr_str, shape = field
                dt = np.dtype((descr_to_dtype(descr_str), shape))

            # Ignore padding bytes, which will be void bytes with '' as name.
            # Once support for blank names is removed, only "if name == ''"
            # is needed.
            is_pad = (name == '' and dt.type is np.void and dt.names is None)
            if not is_pad:
                fields.append((name, dt, offset))

            offset += dt.itemsize

        names, formats, offsets = zip(*fields)
        # names may be (title, names) tuples
        nametups = (n if isinstance(n, tuple) else (None, n) for n in names)
        titles, names = zip(*nametups)
        return np.dtype({'names': names, 'formats': formats,
                         'titles': titles, 'offsets': offsets,
                         'itemsize': offset})


def _deprecate_serialization(name):
    msg = (
        "'pyarrow.{}' is deprecated as of 2.0.0 and will be removed in a "
        "future version. Use pickle or the pyarrow IPC functionality instead."
    ).format(name)
    warnings.warn(msg, FutureWarning, stacklevel=3)
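

# The sketch below is illustrative only; the helper name is hypothetical and
# not part of this module. It demonstrates the dtype <-> descr round trip
# that the numpy handlers further down rely on.
def _example_descr_roundtrip():
    # Flatten a structured dtype to its descr form and recover it.
    dt = np.dtype([('x', '<f8'), ('y', '<i4')])
    descr = np.lib.format.dtype_to_descr(dt)  # [('x', '<f8'), ('y', '<i4')]
    assert descr_to_dtype(descr) == dt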


# ----------------------------------------------------------------------
# Set up serialization for numpy with dtype object (primitive types are
# handled efficiently with Arrow's Tensor facilities, see
# python_to_arrow.cc)


def _serialize_numpy_array_list(obj):
    if obj.dtype.str != '|O':
        # Make the array c_contiguous if necessary so that we can change
        # the view.
        if not obj.flags.c_contiguous:
            obj = np.ascontiguousarray(obj)
        return obj.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype)
    else:
        return obj.tolist(), np.lib.format.dtype_to_descr(obj.dtype)


def _deserialize_numpy_array_list(data):
    if data[1] != '|O':
        assert data[0].dtype == np.uint8
        return data[0].view(descr_to_dtype(data[1]))
    else:
        return np.array(data[0], dtype=np.dtype(data[1]))


def _serialize_numpy_matrix(obj):
    if obj.dtype.str != '|O':
        # Extract the underlying ndarray and make it c_contiguous if
        # necessary so that we can change the view.
        arr = obj.A
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)
        return arr.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype)
    else:
        return obj.A.tolist(), np.lib.format.dtype_to_descr(obj.dtype)


def _deserialize_numpy_matrix(data):
    if data[1] != '|O':
        assert data[0].dtype == np.uint8
        return np.matrix(data[0].view(descr_to_dtype(data[1])), copy=False)
    else:
        return np.matrix(data[0], dtype=np.dtype(data[1]), copy=False)


# ----------------------------------------------------------------------
# pyarrow.RecordBatch-specific serialization matters


def _serialize_pyarrow_recordbatch(batch):
    output_stream = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(output_stream, schema=batch.schema) as wr:
        wr.write_batch(batch)
    return output_stream.getvalue()  # This will also close the stream.


def _deserialize_pyarrow_recordbatch(buf):
    with pa.RecordBatchStreamReader(buf) as reader:
        return reader.read_next_batch()


# ----------------------------------------------------------------------
# pyarrow.Array-specific serialization matters


def _serialize_pyarrow_array(array):
    # TODO(suquark): implement more efficient array serialization.
    batch = pa.RecordBatch.from_arrays([array], [''])
    return _serialize_pyarrow_recordbatch(batch)


def _deserialize_pyarrow_array(buf):
    # TODO(suquark): implement more efficient array deserialization.
    batch = _deserialize_pyarrow_recordbatch(buf)
    return batch.columns[0]


# ----------------------------------------------------------------------
# pyarrow.Table-specific serialization matters


def _serialize_pyarrow_table(table):
    output_stream = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(output_stream, schema=table.schema) as wr:
        wr.write_table(table)
    return output_stream.getvalue()  # This will also close the stream.


def _deserialize_pyarrow_table(buf):
    with pa.RecordBatchStreamReader(buf) as reader:
        return reader.read_all()


def _pickle_to_buffer(x):
    pickled = builtin_pickle.dumps(
        x, protocol=builtin_pickle.HIGHEST_PROTOCOL)
    return py_buffer(pickled)


def _load_pickle_from_buffer(data):
    as_memoryview = memoryview(data)
    return builtin_pickle.loads(as_memoryview)
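

# Illustrative sketch only; the helper name is hypothetical and not part of
# this module. The IPC stream helpers above double as the recommended
# replacement for the deprecated serialization machinery in this file.
def _example_recordbatch_roundtrip():
    # Write a batch to an in-memory IPC stream and read it back.
    batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ['ints'])
    buf = _serialize_pyarrow_recordbatch(batch)
    assert _deserialize_pyarrow_recordbatch(buf).equals(batch)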


# ----------------------------------------------------------------------
# pandas-specific serialization matters


def _register_custom_pandas_handlers(context):
    # ARROW-1784, faster path for pandas-only visibility
    try:
        import pandas as pd
    except ImportError:
        return

    import pyarrow.pandas_compat as pdcompat

    sparse_type_error_msg = (
        '{0} serialization is not supported.\n'
        'Note that {0} is planned to be deprecated '
        'in pandas future releases.\n'
        'See https://github.com/pandas-dev/pandas/issues/19239 '
        'for more information.'
    )

    def _serialize_pandas_dataframe(obj):
        if (pdcompat._pandas_api.has_sparse
                and isinstance(obj, pd.SparseDataFrame)):
            raise NotImplementedError(
                sparse_type_error_msg.format('SparseDataFrame')
            )

        return pdcompat.dataframe_to_serialized_dict(obj)

    def _deserialize_pandas_dataframe(data):
        return pdcompat.serialized_dict_to_dataframe(data)

    def _serialize_pandas_series(obj):
        if (pdcompat._pandas_api.has_sparse
                and isinstance(obj, pd.SparseSeries)):
            raise NotImplementedError(
                sparse_type_error_msg.format('SparseSeries')
            )

        return _serialize_pandas_dataframe(pd.DataFrame({obj.name: obj}))

    def _deserialize_pandas_series(data):
        deserialized = _deserialize_pandas_dataframe(data)
        return deserialized[deserialized.columns[0]]

    context.register_type(
        pd.Series, 'pd.Series',
        custom_serializer=_serialize_pandas_series,
        custom_deserializer=_deserialize_pandas_series)

    context.register_type(
        pd.Index, 'pd.Index',
        custom_serializer=_pickle_to_buffer,
        custom_deserializer=_load_pickle_from_buffer)

    if hasattr(pd.core, 'arrays'):
        if hasattr(pd.core.arrays, 'interval'):
            context.register_type(
                pd.core.arrays.interval.IntervalArray,
                'pd.core.arrays.interval.IntervalArray',
                custom_serializer=_pickle_to_buffer,
                custom_deserializer=_load_pickle_from_buffer)

        if hasattr(pd.core.arrays, 'period'):
            context.register_type(
                pd.core.arrays.period.PeriodArray,
                'pd.core.arrays.period.PeriodArray',
                custom_serializer=_pickle_to_buffer,
                custom_deserializer=_load_pickle_from_buffer)

        if hasattr(pd.core.arrays, 'datetimes'):
            context.register_type(
                pd.core.arrays.datetimes.DatetimeArray,
                'pd.core.arrays.datetimes.DatetimeArray',
                custom_serializer=_pickle_to_buffer,
                custom_deserializer=_load_pickle_from_buffer)

    context.register_type(
        pd.DataFrame, 'pd.DataFrame',
        custom_serializer=_serialize_pandas_dataframe,
        custom_deserializer=_deserialize_pandas_dataframe)


def register_torch_serialization_handlers(serialization_context):
    # ----------------------------------------------------------------------
    # Set up serialization for pytorch tensors
    _deprecate_serialization("register_torch_serialization_handlers")
    try:
        import torch

        def _serialize_torch_tensor(obj):
            if obj.is_sparse:
                return pa.SparseCOOTensor.from_numpy(
                    obj._values().detach().numpy(),
                    obj._indices().detach().numpy().T,
                    shape=list(obj.shape))
            else:
                return obj.detach().numpy()

        def _deserialize_torch_tensor(data):
            if isinstance(data, pa.SparseCOOTensor):
                return torch.sparse_coo_tensor(
                    indices=data.to_numpy()[1].T,
                    values=data.to_numpy()[0][:, 0],
                    size=data.shape)
            else:
                return torch.from_numpy(data)

        for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
                  torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
                  torch.IntTensor, torch.LongTensor, torch.Tensor]:
            serialization_context.register_type(
                t, "torch." + t.__name__,
                custom_serializer=_serialize_torch_tensor,
                custom_deserializer=_deserialize_torch_tensor)
    except ImportError:
        # no torch
        pass
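

# Illustrative sketch only; ``Point2D`` is a hypothetical user type, not part
# of pyarrow. Registering a custom class follows the same pattern as the
# pandas and torch handlers above.
def _example_register_custom_type(context):
    class Point2D:
        def __init__(self, x, y):
            self.x, self.y = x, y

    # Serialize to a plain tuple; rebuild the object on deserialization.
    context.register_type(
        Point2D, 'Point2D',
        custom_serializer=lambda p: (p.x, p.y),
        custom_deserializer=lambda data: Point2D(*data))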


def _register_collections_serialization_handlers(serialization_context):
    def _serialize_deque(obj):
        return list(obj)

    def _deserialize_deque(data):
        return collections.deque(data)

    serialization_context.register_type(
        collections.deque, "collections.deque",
        custom_serializer=_serialize_deque,
        custom_deserializer=_deserialize_deque)

    def _serialize_ordered_dict(obj):
        return list(obj.keys()), list(obj.values())

    def _deserialize_ordered_dict(data):
        return collections.OrderedDict(zip(data[0], data[1]))

    serialization_context.register_type(
        collections.OrderedDict, "collections.OrderedDict",
        custom_serializer=_serialize_ordered_dict,
        custom_deserializer=_deserialize_ordered_dict)

    def _serialize_default_dict(obj):
        return list(obj.keys()), list(obj.values()), obj.default_factory

    def _deserialize_default_dict(data):
        return collections.defaultdict(data[2], zip(data[0], data[1]))

    serialization_context.register_type(
        collections.defaultdict, "collections.defaultdict",
        custom_serializer=_serialize_default_dict,
        custom_deserializer=_deserialize_default_dict)

    def _serialize_counter(obj):
        return list(obj.keys()), list(obj.values())

    def _deserialize_counter(data):
        return collections.Counter(dict(zip(data[0], data[1])))

    serialization_context.register_type(
        collections.Counter, "collections.Counter",
        custom_serializer=_serialize_counter,
        custom_deserializer=_deserialize_counter)


# ----------------------------------------------------------------------
# Set up serialization for scipy sparse matrices. Primitive types are handled
# efficiently with Arrow's SparseTensor facilities, see numpy_convert.cc


def _register_scipy_handlers(serialization_context):
    try:
        from scipy.sparse import (csr_matrix, csc_matrix, coo_matrix,
                                  isspmatrix_coo, isspmatrix_csr,
                                  isspmatrix_csc, isspmatrix)

        def _serialize_scipy_sparse(obj):
            if isspmatrix_coo(obj):
                return 'coo', pa.SparseCOOTensor.from_scipy(obj)
            elif isspmatrix_csr(obj):
                return 'csr', pa.SparseCSRMatrix.from_scipy(obj)
            elif isspmatrix_csc(obj):
                return 'csc', pa.SparseCSCMatrix.from_scipy(obj)
            elif isspmatrix(obj):
                # Other sparse formats are converted to COO first.
                return 'coo', pa.SparseCOOTensor.from_scipy(obj.tocoo())
            else:
                raise NotImplementedError(
                    "Serialization of {} is not supported.".format(type(obj)))

        def _deserialize_scipy_sparse(data):
            # Every tensor flavor converts back to scipy the same way.
            return data[1].to_scipy()

        serialization_context.register_type(
            coo_matrix, 'scipy.sparse.coo.coo_matrix',
            custom_serializer=_serialize_scipy_sparse,
            custom_deserializer=_deserialize_scipy_sparse)

        serialization_context.register_type(
            csr_matrix, 'scipy.sparse.csr.csr_matrix',
            custom_serializer=_serialize_scipy_sparse,
            custom_deserializer=_deserialize_scipy_sparse)

        serialization_context.register_type(
            csc_matrix, 'scipy.sparse.csc.csc_matrix',
            custom_serializer=_serialize_scipy_sparse,
            custom_deserializer=_deserialize_scipy_sparse)
    except ImportError:
        # no scipy
        pass
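

# Illustrative sketch only; the helper name is hypothetical and it assumes
# scipy is installed. It shows the SparseCOOTensor round trip the handlers
# above are built on.
def _example_scipy_coo_roundtrip():
    from scipy.sparse import coo_matrix
    mat = coo_matrix(([1.0, 2.0], ([0, 1], [1, 0])), shape=(2, 2))
    tensor = pa.SparseCOOTensor.from_scipy(mat)
    # Compare densified contents; sparse layouts may differ after the trip.
    assert (tensor.to_scipy().toarray() == mat.toarray()).all()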


# ----------------------------------------------------------------------
# Set up serialization for pydata/sparse tensors.


def _register_pydata_sparse_handlers(serialization_context):
    try:
        import sparse

        def _serialize_pydata_sparse(obj):
            if isinstance(obj, sparse.COO):
                return 'coo', pa.SparseCOOTensor.from_pydata_sparse(obj)
            else:
                raise NotImplementedError(
                    "Serialization of {} is not supported.".format(type(obj)))

        def _deserialize_pydata_sparse(data):
            if data[0] == 'coo':
                data_array, coords = data[1].to_numpy()
                return sparse.COO(
                    data=data_array[:, 0],
                    coords=coords.T, shape=data[1].shape)

        serialization_context.register_type(
            sparse.COO, 'sparse.COO',
            custom_serializer=_serialize_pydata_sparse,
            custom_deserializer=_deserialize_pydata_sparse)
    except ImportError:
        # no pydata/sparse
        pass


def _register_default_serialization_handlers(serialization_context):
    # ----------------------------------------------------------------------
    # Set up serialization for primitive datatypes

    # TODO(pcm): This is currently a workaround until arrow supports
    # arbitrary precision integers. This is only called on long integers,
    # see the associated case in the append method in python_to_arrow.cc
    serialization_context.register_type(
        int, "int",
        custom_serializer=lambda obj: str(obj),
        custom_deserializer=lambda data: int(data))

    serialization_context.register_type(
        type(lambda: 0), "function",
        pickle=True)

    serialization_context.register_type(type, "type", pickle=True)

    serialization_context.register_type(
        np.matrix, 'np.matrix',
        custom_serializer=_serialize_numpy_matrix,
        custom_deserializer=_deserialize_numpy_matrix)

    serialization_context.register_type(
        np.ndarray, 'np.array',
        custom_serializer=_serialize_numpy_array_list,
        custom_deserializer=_deserialize_numpy_array_list)

    serialization_context.register_type(
        pa.Array, 'pyarrow.Array',
        custom_serializer=_serialize_pyarrow_array,
        custom_deserializer=_deserialize_pyarrow_array)

    serialization_context.register_type(
        pa.RecordBatch, 'pyarrow.RecordBatch',
        custom_serializer=_serialize_pyarrow_recordbatch,
        custom_deserializer=_deserialize_pyarrow_recordbatch)

    serialization_context.register_type(
        pa.Table, 'pyarrow.Table',
        custom_serializer=_serialize_pyarrow_table,
        custom_deserializer=_deserialize_pyarrow_table)

    _register_collections_serialization_handlers(serialization_context)
    _register_custom_pandas_handlers(serialization_context)
    _register_scipy_handlers(serialization_context)
    _register_pydata_sparse_handlers(serialization_context)


def register_default_serialization_handlers(serialization_context):
    _deprecate_serialization("register_default_serialization_handlers")
    _register_default_serialization_handlers(serialization_context)


def default_serialization_context():
    _deprecate_serialization("default_serialization_context")
    context = SerializationContext()
    _register_default_serialization_handlers(context)
    return context
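

# Illustrative sketch only; the helper name is hypothetical. Note that
# pyarrow.serialize/deserialize are themselves deprecated as of 2.0.0, so the
# FutureWarning is silenced for the demonstration.
def _example_default_context_roundtrip():
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', FutureWarning)
        context = default_serialization_context()
        buf = pa.serialize({'key': np.arange(3)}, context=context).to_buffer()
        restored = pa.deserialize(buf, context=context)
    assert (restored['key'] == np.arange(3)).all()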