path: root/src/arrow/python/pyarrow/serialization.py
Diffstat (limited to 'src/arrow/python/pyarrow/serialization.py')
-rw-r--r-- src/arrow/python/pyarrow/serialization.py | 504
1 file changed, 504 insertions(+), 0 deletions(-)
diff --git a/src/arrow/python/pyarrow/serialization.py b/src/arrow/python/pyarrow/serialization.py
new file mode 100644
index 000000000..d59a13166
--- /dev/null
+++ b/src/arrow/python/pyarrow/serialization.py
@@ -0,0 +1,504 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import collections
+import warnings
+
+import numpy as np
+
+import pyarrow as pa
+from pyarrow.lib import SerializationContext, py_buffer, builtin_pickle
+
+try:
+ import cloudpickle
+except ImportError:
+ cloudpickle = builtin_pickle
+
+
+try:
+ # This function is available in numpy 1.16.0 and later.
+ # See also: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py
+ from numpy.lib.format import descr_to_dtype
+except ImportError:
+ def descr_to_dtype(descr):
+ '''
+ descr may be stored as dtype.descr, which is a list of (name, format,
+ [shape]) tuples where format may be a str or a tuple. Offsets are not
+ explicitly saved, rather empty fields with name, format == '', '|Vn'
+ are added as padding. This function reverses the process, eliminating
+ the empty padding fields.
+ '''
+ if isinstance(descr, str):
+ # No padding removal needed
+ return np.dtype(descr)
+ elif isinstance(descr, tuple):
+ # subtype, will always have a shape descr[1]
+ dt = descr_to_dtype(descr[0])
+ return np.dtype((dt, descr[1]))
+ fields = []
+ offset = 0
+ for field in descr:
+ if len(field) == 2:
+ name, descr_str = field
+ dt = descr_to_dtype(descr_str)
+ else:
+ name, descr_str, shape = field
+ dt = np.dtype((descr_to_dtype(descr_str), shape))
+
+ # Ignore padding bytes, which will be void bytes with '' as name.
+ # Once support for blank names is removed, only "if name == ''" will
+ # be needed.
+ is_pad = (name == '' and dt.type is np.void and dt.names is None)
+ if not is_pad:
+ fields.append((name, dt, offset))
+
+ offset += dt.itemsize
+
+ names, formats, offsets = zip(*fields)
+ # names may be (title, names) tuples
+ nametups = (n if isinstance(n, tuple) else (None, n) for n in names)
+ titles, names = zip(*nametups)
+ return np.dtype({'names': names, 'formats': formats, 'titles': titles,
+ 'offsets': offsets, 'itemsize': offset})
+
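+# Illustrative sketch (not executed) of the round trip the fallback above
+# reverses; the dtype here is hypothetical. numpy stores padded structured
+# dtypes as descr lists with unnamed '|V<n>' void fields, which
+# descr_to_dtype strips back out:
+#
+#   dt = np.dtype({'names': ['a'], 'formats': ['<i4'], 'itemsize': 8})
+#   descr = np.lib.format.dtype_to_descr(dt)   # [('a', '<i4'), ('', '|V4')]
+#   assert descr_to_dtype(descr) == dt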
+
+def _deprecate_serialization(name):
+ msg = (
+ "'pyarrow.{}' is deprecated as of 2.0.0 and will be removed in a "
+ "future version. Use pickle or the pyarrow IPC functionality instead."
+ ).format(name)
+ warnings.warn(msg, FutureWarning, stacklevel=3)
+
+
+# ----------------------------------------------------------------------
+# Set up serialization for numpy with dtype object (primitive types are
+# handled efficiently with Arrow's Tensor facilities, see
+# python_to_arrow.cc)
+
+def _serialize_numpy_array_list(obj):
+ if obj.dtype.str != '|O':
+ # Make the array c_contiguous if necessary so that we can change
+ # the view.
+ if not obj.flags.c_contiguous:
+ obj = np.ascontiguousarray(obj)
+ return obj.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype)
+ else:
+ return obj.tolist(), np.lib.format.dtype_to_descr(obj.dtype)
+
+
+def _deserialize_numpy_array_list(data):
+ if data[1] != '|O':
+ assert data[0].dtype == np.uint8
+ return data[0].view(descr_to_dtype(data[1]))
+ else:
+ return np.array(data[0], dtype=np.dtype(data[1]))
+
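+# Illustrative round trip for the ndarray helpers above (a sketch; in real
+# use the serialization context moves the uint8 view through an Arrow
+# Tensor in between):
+#
+#   arr = np.arange(6, dtype='<f8').reshape(2, 3)
+#   payload = _serialize_numpy_array_list(arr)
+#   restored = _deserialize_numpy_array_list(payload)
+#   assert (restored == arr).all()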
+
+def _serialize_numpy_matrix(obj):
+ if obj.dtype.str != '|O':
+ # Take the underlying ndarray view first (np.ascontiguousarray
+ # returns a plain ndarray without the matrix .A attribute) and
+ # make it c_contiguous if necessary so that we can change the view.
+ arr = obj.A
+ if not arr.flags.c_contiguous:
+ arr = np.ascontiguousarray(arr)
+ return arr.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype)
+ else:
+ return obj.A.tolist(), np.lib.format.dtype_to_descr(obj.dtype)
+
+
+def _deserialize_numpy_matrix(data):
+ if data[1] != '|O':
+ assert data[0].dtype == np.uint8
+ return np.matrix(data[0].view(descr_to_dtype(data[1])),
+ copy=False)
+ else:
+ return np.matrix(data[0], dtype=np.dtype(data[1]), copy=False)
+
+
+# ----------------------------------------------------------------------
+# pyarrow.RecordBatch-specific serialization matters
+
+def _serialize_pyarrow_recordbatch(batch):
+ output_stream = pa.BufferOutputStream()
+ with pa.RecordBatchStreamWriter(output_stream, schema=batch.schema) as wr:
+ wr.write_batch(batch)
+ return output_stream.getvalue() # This will also close the stream.
+
+
+def _deserialize_pyarrow_recordbatch(buf):
+ with pa.RecordBatchStreamReader(buf) as reader:
+ return reader.read_next_batch()
+
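+# Sketch of the IPC round trip implemented above:
+#
+#   batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ['x'])
+#   buf = _serialize_pyarrow_recordbatch(batch)   # pa.Buffer (IPC stream)
+#   assert _deserialize_pyarrow_recordbatch(buf).equals(batch)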
+
+# ----------------------------------------------------------------------
+# pyarrow.Array-specific serialization matters
+
+def _serialize_pyarrow_array(array):
+ # TODO(suquark): implement more efficient array serialization.
+ batch = pa.RecordBatch.from_arrays([array], [''])
+ return _serialize_pyarrow_recordbatch(batch)
+
+
+def _deserialize_pyarrow_array(buf):
+ # TODO(suquark): implement more efficient array deserialization.
+ batch = _deserialize_pyarrow_recordbatch(buf)
+ return batch.columns[0]
+
+
+# ----------------------------------------------------------------------
+# pyarrow.Table-specific serialization matters
+
+def _serialize_pyarrow_table(table):
+ output_stream = pa.BufferOutputStream()
+ with pa.RecordBatchStreamWriter(output_stream, schema=table.schema) as wr:
+ wr.write_table(table)
+ return output_stream.getvalue() # This will also close the stream.
+
+
+def _deserialize_pyarrow_table(buf):
+ with pa.RecordBatchStreamReader(buf) as reader:
+ return reader.read_all()
+
+
+def _pickle_to_buffer(x):
+ pickled = builtin_pickle.dumps(x, protocol=builtin_pickle.HIGHEST_PROTOCOL)
+ return py_buffer(pickled)
+
+
+def _load_pickle_from_buffer(data):
+ as_memoryview = memoryview(data)
+ return builtin_pickle.loads(as_memoryview)
+
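+# These helpers wrap pickled bytes in an Arrow buffer so the payload can
+# ride along in the serialized stream, e.g. (sketch):
+#
+#   buf = _pickle_to_buffer({'a': 1})   # pa.Buffer
+#   assert _load_pickle_from_buffer(buf) == {'a': 1}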
+
+# ----------------------------------------------------------------------
+# pandas-specific serialization matters
+
+def _register_custom_pandas_handlers(context):
+ # ARROW-1784: faster, pandas-specific serialization path (only
+ # registered when pandas is importable)
+
+ try:
+ import pandas as pd
+ except ImportError:
+ return
+
+ import pyarrow.pandas_compat as pdcompat
+
+ sparse_type_error_msg = (
+ '{0} serialization is not supported.\n'
+ 'Note that {0} is planned to be deprecated '
+ 'in future pandas releases.\n'
+ 'See https://github.com/pandas-dev/pandas/issues/19239 '
+ 'for more information.'
+ )
+
+ def _serialize_pandas_dataframe(obj):
+ if (pdcompat._pandas_api.has_sparse and
+ isinstance(obj, pd.SparseDataFrame)):
+ raise NotImplementedError(
+ sparse_type_error_msg.format('SparseDataFrame')
+ )
+
+ return pdcompat.dataframe_to_serialized_dict(obj)
+
+ def _deserialize_pandas_dataframe(data):
+ return pdcompat.serialized_dict_to_dataframe(data)
+
+ def _serialize_pandas_series(obj):
+ if (pdcompat._pandas_api.has_sparse and
+ isinstance(obj, pd.SparseSeries)):
+ raise NotImplementedError(
+ sparse_type_error_msg.format('SparseSeries')
+ )
+
+ return _serialize_pandas_dataframe(pd.DataFrame({obj.name: obj}))
+
+ def _deserialize_pandas_series(data):
+ deserialized = _deserialize_pandas_dataframe(data)
+ return deserialized[deserialized.columns[0]]
+
+ context.register_type(
+ pd.Series, 'pd.Series',
+ custom_serializer=_serialize_pandas_series,
+ custom_deserializer=_deserialize_pandas_series)
+
+ context.register_type(
+ pd.Index, 'pd.Index',
+ custom_serializer=_pickle_to_buffer,
+ custom_deserializer=_load_pickle_from_buffer)
+
+ if hasattr(pd.core, 'arrays'):
+ if hasattr(pd.core.arrays, 'interval'):
+ context.register_type(
+ pd.core.arrays.interval.IntervalArray,
+ 'pd.core.arrays.interval.IntervalArray',
+ custom_serializer=_pickle_to_buffer,
+ custom_deserializer=_load_pickle_from_buffer)
+
+ if hasattr(pd.core.arrays, 'period'):
+ context.register_type(
+ pd.core.arrays.period.PeriodArray,
+ 'pd.core.arrays.period.PeriodArray',
+ custom_serializer=_pickle_to_buffer,
+ custom_deserializer=_load_pickle_from_buffer)
+
+ if hasattr(pd.core.arrays, 'datetimes'):
+ context.register_type(
+ pd.core.arrays.datetimes.DatetimeArray,
+ 'pd.core.arrays.datetimes.DatetimeArray',
+ custom_serializer=_pickle_to_buffer,
+ custom_deserializer=_load_pickle_from_buffer)
+
+ context.register_type(
+ pd.DataFrame, 'pd.DataFrame',
+ custom_serializer=_serialize_pandas_dataframe,
+ custom_deserializer=_deserialize_pandas_dataframe)
+
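+# A DataFrame is decomposed into a dict of Arrow-backed blocks by
+# pdcompat.dataframe_to_serialized_dict and rebuilt on read; a Series
+# round-trips as a one-column DataFrame. With the handlers registered, a
+# round trip through the (deprecated) API looks roughly like this sketch:
+#
+#   ctx = pa.default_serialization_context()   # emits a FutureWarning
+#   df = pd.DataFrame({'a': [1, 2]})
+#   buf = pa.serialize(df, context=ctx).to_buffer()
+#   restored = pa.deserialize(buf, context=ctx)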
+
+def register_torch_serialization_handlers(serialization_context):
+ # ----------------------------------------------------------------------
+ # Set up serialization for pytorch tensors
+ _deprecate_serialization("register_torch_serialization_handlers")
+
+ try:
+ import torch
+
+ def _serialize_torch_tensor(obj):
+ if obj.is_sparse:
+ return pa.SparseCOOTensor.from_numpy(
+ obj._values().detach().numpy(),
+ obj._indices().detach().numpy().T,
+ shape=list(obj.shape))
+ else:
+ return obj.detach().numpy()
+
+ def _deserialize_torch_tensor(data):
+ if isinstance(data, pa.SparseCOOTensor):
+ return torch.sparse_coo_tensor(
+ indices=data.to_numpy()[1].T,
+ values=data.to_numpy()[0][:, 0],
+ size=data.shape)
+ else:
+ return torch.from_numpy(data)
+
+ for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
+ torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
+ torch.IntTensor, torch.LongTensor, torch.Tensor]:
+ serialization_context.register_type(
+ t, "torch." + t.__name__,
+ custom_serializer=_serialize_torch_tensor,
+ custom_deserializer=_deserialize_torch_tensor)
+ except ImportError:
+ # no torch
+ pass
+
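+# Example use, as a sketch assuming torch is installed: dense tensors
+# round-trip through numpy, sparse ones through SparseCOOTensor.
+#
+#   ctx = pa.default_serialization_context()   # emits a FutureWarning
+#   register_torch_serialization_handlers(ctx)
+#   t = torch.arange(4.)
+#   assert torch.equal(ctx.deserialize(ctx.serialize(t).to_buffer()), t)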
+
+def _register_collections_serialization_handlers(serialization_context):
+ def _serialize_deque(obj):
+ return list(obj)
+
+ def _deserialize_deque(data):
+ return collections.deque(data)
+
+ serialization_context.register_type(
+ collections.deque, "collections.deque",
+ custom_serializer=_serialize_deque,
+ custom_deserializer=_deserialize_deque)
+
+ def _serialize_ordered_dict(obj):
+ return list(obj.keys()), list(obj.values())
+
+ def _deserialize_ordered_dict(data):
+ return collections.OrderedDict(zip(data[0], data[1]))
+
+ serialization_context.register_type(
+ collections.OrderedDict, "collections.OrderedDict",
+ custom_serializer=_serialize_ordered_dict,
+ custom_deserializer=_deserialize_ordered_dict)
+
+ def _serialize_default_dict(obj):
+ return list(obj.keys()), list(obj.values()), obj.default_factory
+
+ def _deserialize_default_dict(data):
+ return collections.defaultdict(data[2], zip(data[0], data[1]))
+
+ serialization_context.register_type(
+ collections.defaultdict, "collections.defaultdict",
+ custom_serializer=_serialize_default_dict,
+ custom_deserializer=_deserialize_default_dict)
+
+ def _serialize_counter(obj):
+ return list(obj.keys()), list(obj.values())
+
+ def _deserialize_counter(data):
+ return collections.Counter(dict(zip(data[0], data[1])))
+
+ serialization_context.register_type(
+ collections.Counter, "collections.Counter",
+ custom_serializer=_serialize_counter,
+ custom_deserializer=_deserialize_counter)
+
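+# The pattern above decomposes each container into plain lists that Arrow
+# can handle natively and rebuilds the container on read, e.g.:
+#
+#   d = collections.OrderedDict([('a', 1), ('b', 2)])
+#   data = _serialize_ordered_dict(d)    # (['a', 'b'], [1, 2])
+#   assert _deserialize_ordered_dict(data) == d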
+
+# ----------------------------------------------------------------------
+# Set up serialization for scipy sparse matrices (primitive types are
+# handled efficiently with Arrow's SparseTensor facilities, see
+# numpy_convert.cc)
+
+def _register_scipy_handlers(serialization_context):
+ try:
+ from scipy.sparse import (csr_matrix, csc_matrix, coo_matrix,
+ isspmatrix_coo, isspmatrix_csr,
+ isspmatrix_csc, isspmatrix)
+
+ def _serialize_scipy_sparse(obj):
+ if isspmatrix_coo(obj):
+ return 'coo', pa.SparseCOOTensor.from_scipy(obj)
+
+ elif isspmatrix_csr(obj):
+ return 'csr', pa.SparseCSRMatrix.from_scipy(obj)
+
+ elif isspmatrix_csc(obj):
+ return 'csc', pa.SparseCSCMatrix.from_scipy(obj)
+
+ elif isspmatrix(obj):
+ return 'coo', pa.SparseCOOTensor.from_scipy(obj.tocoo())
+
+ else:
+ raise NotImplementedError(
+ "Serialization of {} is not supported.".format(obj[0]))
+
+ def _deserialize_scipy_sparse(data):
+ # All of the wrapped Arrow sparse tensor types convert back to
+ # scipy with the same call; the format tag in data[0] is kept only
+ # for symmetry with the serializer.
+ return data[1].to_scipy()
+
+ serialization_context.register_type(
+ coo_matrix, 'scipy.sparse.coo.coo_matrix',
+ custom_serializer=_serialize_scipy_sparse,
+ custom_deserializer=_deserialize_scipy_sparse)
+
+ serialization_context.register_type(
+ csr_matrix, 'scipy.sparse.csr.csr_matrix',
+ custom_serializer=_serialize_scipy_sparse,
+ custom_deserializer=_deserialize_scipy_sparse)
+
+ serialization_context.register_type(
+ csc_matrix, 'scipy.sparse.csc.csc_matrix',
+ custom_serializer=_serialize_scipy_sparse,
+ custom_deserializer=_deserialize_scipy_sparse)
+
+ except ImportError:
+ # no scipy
+ pass
+
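+# Example round trip for a COO matrix (sketch, assuming scipy is available
+# and the handlers above were registered on ctx):
+#
+#   from scipy.sparse import coo_matrix
+#   m = coo_matrix(([1.0, 2.0], ([0, 1], [1, 0])), shape=(2, 2))
+#   restored = ctx.deserialize(ctx.serialize(m).to_buffer())
+#   assert (restored.toarray() == m.toarray()).all()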
+
+# ----------------------------------------------------------------------
+# Set up serialization for pydata/sparse tensors.
+
+def _register_pydata_sparse_handlers(serialization_context):
+ try:
+ import sparse
+
+ def _serialize_pydata_sparse(obj):
+ if isinstance(obj, sparse.COO):
+ return 'coo', pa.SparseCOOTensor.from_pydata_sparse(obj)
+ else:
+ raise NotImplementedError(
+ "Serialization of {} is not supported.".format(sparse.COO))
+
+ def _deserialize_pydata_sparse(data):
+ if data[0] == 'coo':
+ data_array, coords = data[1].to_numpy()
+ return sparse.COO(
+ data=data_array[:, 0],
+ coords=coords.T, shape=data[1].shape)
+
+ serialization_context.register_type(
+ sparse.COO, 'sparse.COO',
+ custom_serializer=_serialize_pydata_sparse,
+ custom_deserializer=_deserialize_pydata_sparse)
+
+ except ImportError:
+ # no pydata/sparse
+ pass
+
+
+def _register_default_serialization_handlers(serialization_context):
+
+ # ----------------------------------------------------------------------
+ # Set up serialization for primitive datatypes
+
+ # TODO(pcm): This is currently a workaround until arrow supports
+ # arbitrary precision integers. This is only called on long integers,
+ # see the associated case in the append method in python_to_arrow.cc
+ serialization_context.register_type(
+ int, "int",
+ custom_serializer=lambda obj: str(obj),
+ custom_deserializer=lambda data: int(data))
+
+ serialization_context.register_type(
+ type(lambda: 0), "function",
+ pickle=True)
+
+ serialization_context.register_type(type, "type", pickle=True)
+
+ serialization_context.register_type(
+ np.matrix, 'np.matrix',
+ custom_serializer=_serialize_numpy_matrix,
+ custom_deserializer=_deserialize_numpy_matrix)
+
+ serialization_context.register_type(
+ np.ndarray, 'np.array',
+ custom_serializer=_serialize_numpy_array_list,
+ custom_deserializer=_deserialize_numpy_array_list)
+
+ serialization_context.register_type(
+ pa.Array, 'pyarrow.Array',
+ custom_serializer=_serialize_pyarrow_array,
+ custom_deserializer=_deserialize_pyarrow_array)
+
+ serialization_context.register_type(
+ pa.RecordBatch, 'pyarrow.RecordBatch',
+ custom_serializer=_serialize_pyarrow_recordbatch,
+ custom_deserializer=_deserialize_pyarrow_recordbatch)
+
+ serialization_context.register_type(
+ pa.Table, 'pyarrow.Table',
+ custom_serializer=_serialize_pyarrow_table,
+ custom_deserializer=_deserialize_pyarrow_table)
+
+ _register_collections_serialization_handlers(serialization_context)
+ _register_custom_pandas_handlers(serialization_context)
+ _register_scipy_handlers(serialization_context)
+ _register_pydata_sparse_handlers(serialization_context)
+
+
+def register_default_serialization_handlers(serialization_context):
+ _deprecate_serialization("register_default_serialization_handlers")
+ _register_default_serialization_handlers(serialization_context)
+
+
+def default_serialization_context():
+ _deprecate_serialization("default_serialization_context")
+ context = SerializationContext()
+ _register_default_serialization_handlers(context)
+ return context
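+
+
+# Typical end-to-end use of this (deprecated) module, as a sketch:
+#
+#   import pyarrow as pa
+#   ctx = pa.default_serialization_context()   # emits a FutureWarning
+#   buf = pa.serialize([1, 'a', {'b': 2}], context=ctx).to_buffer()
+#   assert pa.deserialize(buf, context=ctx) == [1, 'a', {'b': 2}]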