Diffstat (limited to '')
-rw-r--r--  src/arrow/python/pyarrow/__init__.pxd  42
-rw-r--r--  src/arrow/python/pyarrow/__init__.py  511
-rw-r--r--  src/arrow/python/pyarrow/_compute.pxd  30
-rw-r--r--  src/arrow/python/pyarrow/_compute.pyx  1296
-rw-r--r--  src/arrow/python/pyarrow/_csv.pxd  54
-rw-r--r--  src/arrow/python/pyarrow/_csv.pyx  1077
-rw-r--r--  src/arrow/python/pyarrow/_cuda.pxd  67
-rw-r--r--  src/arrow/python/pyarrow/_cuda.pyx  1060
-rw-r--r--  src/arrow/python/pyarrow/_dataset.pxd  51
-rw-r--r--  src/arrow/python/pyarrow/_dataset.pyx  3408
-rw-r--r--  src/arrow/python/pyarrow/_dataset_orc.pyx  42
-rw-r--r--  src/arrow/python/pyarrow/_feather.pyx  113
-rw-r--r--  src/arrow/python/pyarrow/_flight.pyx  2664
-rw-r--r--  src/arrow/python/pyarrow/_fs.pxd  94
-rw-r--r--  src/arrow/python/pyarrow/_fs.pyx  1233
-rw-r--r--  src/arrow/python/pyarrow/_hdfs.pyx  149
-rw-r--r--  src/arrow/python/pyarrow/_hdfsio.pyx  480
-rw-r--r--  src/arrow/python/pyarrow/_json.pyx  248
-rw-r--r--  src/arrow/python/pyarrow/_orc.pxd  63
-rw-r--r--  src/arrow/python/pyarrow/_orc.pyx  163
-rw-r--r--  src/arrow/python/pyarrow/_parquet.pxd  559
-rw-r--r--  src/arrow/python/pyarrow/_parquet.pyx  1466
-rw-r--r--  src/arrow/python/pyarrow/_plasma.pyx  867
-rw-r--r--  src/arrow/python/pyarrow/_s3fs.pyx  284
-rw-r--r--  src/arrow/python/pyarrow/array.pxi  2541
-rw-r--r--  src/arrow/python/pyarrow/benchmark.pxi  20
-rw-r--r--  src/arrow/python/pyarrow/benchmark.py  21
-rw-r--r--  src/arrow/python/pyarrow/builder.pxi  82
-rw-r--r--  src/arrow/python/pyarrow/cffi.py  71
-rw-r--r--  src/arrow/python/pyarrow/compat.pxi  65
-rw-r--r--  src/arrow/python/pyarrow/compat.py  29
-rw-r--r--  src/arrow/python/pyarrow/compute.py  759
-rw-r--r--  src/arrow/python/pyarrow/config.pxi  74
-rw-r--r--  src/arrow/python/pyarrow/csv.py  22
-rw-r--r--  src/arrow/python/pyarrow/cuda.py  25
-rw-r--r--  src/arrow/python/pyarrow/dataset.py  881
-rw-r--r--  src/arrow/python/pyarrow/error.pxi  242
-rw-r--r--  src/arrow/python/pyarrow/feather.py  265
-rw-r--r--  src/arrow/python/pyarrow/filesystem.py  511
-rw-r--r--  src/arrow/python/pyarrow/flight.py  63
-rw-r--r--  src/arrow/python/pyarrow/fs.py  405
-rw-r--r--  src/arrow/python/pyarrow/gandiva.pyx  518
-rw-r--r--  src/arrow/python/pyarrow/hdfs.py  240
-rw-r--r--  src/arrow/python/pyarrow/includes/__init__.pxd  0
-rw-r--r--  src/arrow/python/pyarrow/includes/common.pxd  138
-rw-r--r--  src/arrow/python/pyarrow/includes/libarrow.pxd  2615
-rw-r--r--  src/arrow/python/pyarrow/includes/libarrow_cuda.pxd  107
-rw-r--r--  src/arrow/python/pyarrow/includes/libarrow_dataset.pxd  478
-rw-r--r--  src/arrow/python/pyarrow/includes/libarrow_feather.pxd  49
-rw-r--r--  src/arrow/python/pyarrow/includes/libarrow_flight.pxd  560
-rw-r--r--  src/arrow/python/pyarrow/includes/libarrow_fs.pxd  296
-rw-r--r--  src/arrow/python/pyarrow/includes/libgandiva.pxd  286
-rw-r--r--  src/arrow/python/pyarrow/includes/libplasma.pxd  25
-rw-r--r--  src/arrow/python/pyarrow/io.pxi  2137
-rw-r--r--  src/arrow/python/pyarrow/ipc.pxi  1009
-rw-r--r--  src/arrow/python/pyarrow/ipc.py  233
-rw-r--r--  src/arrow/python/pyarrow/json.py  19
-rw-r--r--  src/arrow/python/pyarrow/jvm.py  335
-rw-r--r--  src/arrow/python/pyarrow/lib.pxd  604
-rw-r--r--  src/arrow/python/pyarrow/lib.pyx  172
-rw-r--r--  src/arrow/python/pyarrow/memory.pxi  249
-rw-r--r--  src/arrow/python/pyarrow/orc.py  177
-rw-r--r--  src/arrow/python/pyarrow/pandas-shim.pxi  254
-rw-r--r--  src/arrow/python/pyarrow/pandas_compat.py  1226
-rw-r--r--  src/arrow/python/pyarrow/parquet.py  2299
-rw-r--r--  src/arrow/python/pyarrow/plasma.py  152
-rw-r--r--  src/arrow/python/pyarrow/public-api.pxi  418
-rw-r--r--  src/arrow/python/pyarrow/scalar.pxi  1048
-rw-r--r--  src/arrow/python/pyarrow/serialization.pxi  556
-rw-r--r--  src/arrow/python/pyarrow/serialization.py  504
-rw-r--r--  src/arrow/python/pyarrow/table.pxi  2389
-rw-r--r--  src/arrow/python/pyarrow/tensor.pxi  1025
-rw-r--r--  src/arrow/python/pyarrow/tensorflow/plasma_op.cc  391
-rw-r--r--  src/arrow/python/pyarrow/tests/__init__.py  0
-rw-r--r--  src/arrow/python/pyarrow/tests/arrow_7980.py  30
-rw-r--r--  src/arrow/python/pyarrow/tests/bound_function_visit_strings.pyx  68
-rw-r--r--  src/arrow/python/pyarrow/tests/conftest.py  302
-rw-r--r--  src/arrow/python/pyarrow/tests/data/feather/v0.17.0.version=2-compression=lz4.feather  bin 0 -> 594 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/orc/README.md  22
-rw-r--r--  src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.jsn.gz  bin 0 -> 50 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.orc  bin 0 -> 523 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.jsn.gz  bin 0 -> 323 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.orc  bin 0 -> 1711 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.jsn.gz  bin 0 -> 182453 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.orc  bin 0 -> 30941 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/orc/decimal.jsn.gz  bin 0 -> 19313 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/orc/decimal.orc  bin 0 -> 16337 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.all-named-index.parquet  bin 0 -> 3948 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.column-metadata-handling.parquet  bin 0 -> 2012 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.parquet  bin 0 -> 4372 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.some-named-index.parquet  bin 0 -> 4008 bytes
-rw-r--r--  src/arrow/python/pyarrow/tests/deserialize_buffer.py  26
-rw-r--r--  src/arrow/python/pyarrow/tests/pandas_examples.py  172
-rw-r--r--  src/arrow/python/pyarrow/tests/pandas_threaded_import.py  44
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/common.py  177
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/conftest.py  87
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/test_basic.py  631
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/test_compliant_nested_type.py  115
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/test_data_types.py  529
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/test_dataset.py  1661
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/test_datetime.py  440
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/test_metadata.py  524
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/test_pandas.py  687
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/test_parquet_file.py  276
-rw-r--r--  src/arrow/python/pyarrow/tests/parquet/test_parquet_writer.py  278
-rw-r--r--  src/arrow/python/pyarrow/tests/pyarrow_cython_example.pyx  55
-rw-r--r--  src/arrow/python/pyarrow/tests/strategies.py  419
-rw-r--r--  src/arrow/python/pyarrow/tests/test_adhoc_memory_leak.py  43
-rw-r--r--  src/arrow/python/pyarrow/tests/test_array.py  3064
-rw-r--r--  src/arrow/python/pyarrow/tests/test_builder.py  67
-rw-r--r--  src/arrow/python/pyarrow/tests/test_cffi.py  398
-rw-r--r--  src/arrow/python/pyarrow/tests/test_compute.py  2238
-rw-r--r--  src/arrow/python/pyarrow/tests/test_convert_builtin.py  2309
-rw-r--r--  src/arrow/python/pyarrow/tests/test_csv.py  1824
-rw-r--r--  src/arrow/python/pyarrow/tests/test_cuda.py  792
-rw-r--r--  src/arrow/python/pyarrow/tests/test_cuda_numba_interop.py  235
-rw-r--r--  src/arrow/python/pyarrow/tests/test_cython.py  180
-rw-r--r--  src/arrow/python/pyarrow/tests/test_dataset.py  3976
-rw-r--r--  src/arrow/python/pyarrow/tests/test_deprecations.py  23
-rw-r--r--  src/arrow/python/pyarrow/tests/test_extension_type.py  779
-rw-r--r--  src/arrow/python/pyarrow/tests/test_feather.py  799
-rw-r--r--  src/arrow/python/pyarrow/tests/test_filesystem.py  67
-rw-r--r--  src/arrow/python/pyarrow/tests/test_flight.py  2047
-rw-r--r--  src/arrow/python/pyarrow/tests/test_fs.py  1714
-rw-r--r--  src/arrow/python/pyarrow/tests/test_gandiva.py  391
-rw-r--r--  src/arrow/python/pyarrow/tests/test_hdfs.py  447
-rw-r--r--  src/arrow/python/pyarrow/tests/test_io.py  1886
-rw-r--r--  src/arrow/python/pyarrow/tests/test_ipc.py  999
-rw-r--r--  src/arrow/python/pyarrow/tests/test_json.py  310
-rw-r--r--  src/arrow/python/pyarrow/tests/test_jvm.py  433
-rw-r--r--  src/arrow/python/pyarrow/tests/test_memory.py  161
-rw-r--r--  src/arrow/python/pyarrow/tests/test_misc.py  185
-rw-r--r--  src/arrow/python/pyarrow/tests/test_orc.py  271
-rw-r--r--  src/arrow/python/pyarrow/tests/test_pandas.py  4386
-rw-r--r--  src/arrow/python/pyarrow/tests/test_plasma.py  1073
-rw-r--r--  src/arrow/python/pyarrow/tests/test_plasma_tf_op.py  104
-rw-r--r--  src/arrow/python/pyarrow/tests/test_scalars.py  687
-rw-r--r--  src/arrow/python/pyarrow/tests/test_schema.py  730
-rw-r--r--  src/arrow/python/pyarrow/tests/test_serialization.py  1233
-rw-r--r--  src/arrow/python/pyarrow/tests/test_serialization_deprecated.py  56
-rw-r--r--  src/arrow/python/pyarrow/tests/test_sparse_tensor.py  491
-rw-r--r--  src/arrow/python/pyarrow/tests/test_strategies.py  61
-rw-r--r--  src/arrow/python/pyarrow/tests/test_table.py  1748
-rw-r--r--  src/arrow/python/pyarrow/tests/test_tensor.py  216
-rw-r--r--  src/arrow/python/pyarrow/tests/test_types.py  1067
-rw-r--r--  src/arrow/python/pyarrow/tests/test_util.py  52
-rw-r--r--  src/arrow/python/pyarrow/tests/util.py  331
-rw-r--r--  src/arrow/python/pyarrow/types.pxi  2930
-rw-r--r--  src/arrow/python/pyarrow/types.py  550
-rw-r--r--  src/arrow/python/pyarrow/util.py  178
-rw-r--r--  src/arrow/python/pyarrow/vendored/__init__.py  16
-rw-r--r--  src/arrow/python/pyarrow/vendored/version.py  545
152 files changed, 91241 insertions, 0 deletions
diff --git a/src/arrow/python/pyarrow/__init__.pxd b/src/arrow/python/pyarrow/__init__.pxd
new file mode 100644
index 000000000..8cc54b4c6
--- /dev/null
+++ b/src/arrow/python/pyarrow/__init__.pxd
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from libcpp.memory cimport shared_ptr
+from pyarrow.includes.libarrow cimport (CArray, CBuffer, CDataType,
+ CField, CRecordBatch, CSchema,
+ CTable, CTensor, CSparseCOOTensor,
+ CSparseCSRMatrix, CSparseCSCMatrix,
+ CSparseCSFTensor)
+
+cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py":
+ cdef int import_pyarrow() except -1
+ cdef object wrap_buffer(const shared_ptr[CBuffer]& buffer)
+ cdef object wrap_data_type(const shared_ptr[CDataType]& type)
+ cdef object wrap_field(const shared_ptr[CField]& field)
+ cdef object wrap_schema(const shared_ptr[CSchema]& schema)
+ cdef object wrap_array(const shared_ptr[CArray]& sp_array)
+ cdef object wrap_tensor(const shared_ptr[CTensor]& sp_tensor)
+ cdef object wrap_sparse_tensor_coo(
+ const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor)
+ cdef object wrap_sparse_tensor_csr(
+ const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor)
+ cdef object wrap_sparse_tensor_csc(
+ const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor)
+ cdef object wrap_sparse_tensor_csf(
+ const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor)
+ cdef object wrap_table(const shared_ptr[CTable]& ctable)
+ cdef object wrap_batch(const shared_ptr[CRecordBatch]& cbatch)
diff --git a/src/arrow/python/pyarrow/__init__.py b/src/arrow/python/pyarrow/__init__.py
new file mode 100644
index 000000000..1ec229d53
--- /dev/null
+++ b/src/arrow/python/pyarrow/__init__.py
@@ -0,0 +1,511 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# flake8: noqa
+
+"""
+PyArrow is the python implementation of Apache Arrow.
+
+Apache Arrow is a cross-language development platform for in-memory data.
+It specifies a standardized language-independent columnar memory format for
+flat and hierarchical data, organized for efficient analytic operations on
+modern hardware. It also provides computational libraries and zero-copy
+streaming messaging and interprocess communication.
+
+For more information see the official page at https://arrow.apache.org
+"""
+
+import gc as _gc
+import os as _os
+import sys as _sys
+import warnings as _warnings
+
+try:
+ from ._generated_version import version as __version__
+except ImportError:
+ # Package is not installed, parse git tag at runtime
+ try:
+ import setuptools_scm
+ # Code duplicated from setup.py to avoid a dependency on each other
+
+ def parse_git(root, **kwargs):
+ """
+ Parse function for setuptools_scm that ignores tags for non-C++
+ subprojects, e.g. apache-arrow-js-XXX tags.
+ """
+ from setuptools_scm.git import parse
+ kwargs['describe_command'] = \
+ "git describe --dirty --tags --long --match 'apache-arrow-[0-9].*'"
+ return parse(root, **kwargs)
+ __version__ = setuptools_scm.get_version('../',
+ parse=parse_git)
+ except ImportError:
+ __version__ = None
+
+# ARROW-8684: Disable GC while initializing Cython extension module,
+# to workaround Cython bug in https://github.com/cython/cython/issues/3603
+_gc_enabled = _gc.isenabled()
+_gc.disable()
+import pyarrow.lib as _lib
+if _gc_enabled:
+ _gc.enable()
+
+from pyarrow.lib import (BuildInfo, RuntimeInfo, MonthDayNano,
+ VersionInfo, cpp_build_info, cpp_version,
+ cpp_version_info, runtime_info, cpu_count,
+ set_cpu_count, enable_signal_handlers,
+ io_thread_count, set_io_thread_count)
+
+
+def show_versions():
+ """
+ Print various version information, to help with error reporting.
+ """
+ # TODO: CPU information and flags
+ print("pyarrow version info\n--------------------")
+ print("Package kind: {}".format(cpp_build_info.package_kind
+ if len(cpp_build_info.package_kind) > 0
+ else "not indicated"))
+ print("Arrow C++ library version: {0}".format(cpp_build_info.version))
+ print("Arrow C++ compiler: {0} {1}"
+ .format(cpp_build_info.compiler_id, cpp_build_info.compiler_version))
+ print("Arrow C++ compiler flags: {0}"
+ .format(cpp_build_info.compiler_flags))
+ print("Arrow C++ git revision: {0}".format(cpp_build_info.git_id))
+ print("Arrow C++ git description: {0}"
+ .format(cpp_build_info.git_description))
+
+
+from pyarrow.lib import (null, bool_,
+ int8, int16, int32, int64,
+ uint8, uint16, uint32, uint64,
+ time32, time64, timestamp, date32, date64, duration,
+ month_day_nano_interval,
+ float16, float32, float64,
+ binary, string, utf8,
+ large_binary, large_string, large_utf8,
+ decimal128, decimal256,
+ list_, large_list, map_, struct,
+ union, sparse_union, dense_union,
+ dictionary,
+ field,
+ type_for_alias,
+ DataType, DictionaryType, StructType,
+ ListType, LargeListType, MapType, FixedSizeListType,
+ UnionType, SparseUnionType, DenseUnionType,
+ TimestampType, Time32Type, Time64Type, DurationType,
+ FixedSizeBinaryType, Decimal128Type, Decimal256Type,
+ BaseExtensionType, ExtensionType,
+ PyExtensionType, UnknownExtensionType,
+ register_extension_type, unregister_extension_type,
+ DictionaryMemo,
+ KeyValueMetadata,
+ Field,
+ Schema,
+ schema,
+ unify_schemas,
+ Array, Tensor,
+ array, chunked_array, record_batch, nulls, repeat,
+ SparseCOOTensor, SparseCSRMatrix, SparseCSCMatrix,
+ SparseCSFTensor,
+ infer_type, from_numpy_dtype,
+ NullArray,
+ NumericArray, IntegerArray, FloatingPointArray,
+ BooleanArray,
+ Int8Array, UInt8Array,
+ Int16Array, UInt16Array,
+ Int32Array, UInt32Array,
+ Int64Array, UInt64Array,
+ ListArray, LargeListArray, MapArray,
+ FixedSizeListArray, UnionArray,
+ BinaryArray, StringArray,
+ LargeBinaryArray, LargeStringArray,
+ FixedSizeBinaryArray,
+ DictionaryArray,
+ Date32Array, Date64Array, TimestampArray,
+ Time32Array, Time64Array, DurationArray,
+ MonthDayNanoIntervalArray,
+ Decimal128Array, Decimal256Array, StructArray, ExtensionArray,
+ scalar, NA, _NULL as NULL, Scalar,
+ NullScalar, BooleanScalar,
+ Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,
+ UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar,
+ HalfFloatScalar, FloatScalar, DoubleScalar,
+ Decimal128Scalar, Decimal256Scalar,
+ ListScalar, LargeListScalar, FixedSizeListScalar,
+ Date32Scalar, Date64Scalar,
+ Time32Scalar, Time64Scalar,
+ TimestampScalar, DurationScalar,
+ MonthDayNanoIntervalScalar,
+ BinaryScalar, LargeBinaryScalar,
+ StringScalar, LargeStringScalar,
+ FixedSizeBinaryScalar, DictionaryScalar,
+ MapScalar, StructScalar, UnionScalar,
+ ExtensionScalar)
+
+# Buffers, allocation
+from pyarrow.lib import (Buffer, ResizableBuffer, foreign_buffer, py_buffer,
+ Codec, compress, decompress, allocate_buffer)
+
+from pyarrow.lib import (MemoryPool, LoggingMemoryPool, ProxyMemoryPool,
+ total_allocated_bytes, set_memory_pool,
+ default_memory_pool, system_memory_pool,
+ jemalloc_memory_pool, mimalloc_memory_pool,
+ logging_memory_pool, proxy_memory_pool,
+ log_memory_allocations, jemalloc_set_decay_ms)
+
+# I/O
+from pyarrow.lib import (NativeFile, PythonFile,
+ BufferedInputStream, BufferedOutputStream,
+ CompressedInputStream, CompressedOutputStream,
+ TransformInputStream, transcoding_input_stream,
+ FixedSizeBufferWriter,
+ BufferReader, BufferOutputStream,
+ OSFile, MemoryMappedFile, memory_map,
+ create_memory_map, MockOutputStream,
+ input_stream, output_stream)
+
+from pyarrow._hdfsio import HdfsFile, have_libhdfs
+
+from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table,
+ concat_arrays, concat_tables)
+
+# Exceptions
+from pyarrow.lib import (ArrowCancelled,
+ ArrowCapacityError,
+ ArrowException,
+ ArrowKeyError,
+ ArrowIndexError,
+ ArrowInvalid,
+ ArrowIOError,
+ ArrowMemoryError,
+ ArrowNotImplementedError,
+ ArrowTypeError,
+ ArrowSerializationError)
+
+# Serialization
+from pyarrow.lib import (deserialize_from, deserialize,
+ deserialize_components,
+ serialize, serialize_to, read_serialized,
+ SerializationCallbackError,
+ DeserializationCallbackError)
+
+import pyarrow.hdfs as hdfs
+
+from pyarrow.ipc import serialize_pandas, deserialize_pandas
+import pyarrow.ipc as ipc
+
+from pyarrow.serialization import (default_serialization_context,
+ register_default_serialization_handlers,
+ register_torch_serialization_handlers)
+
+import pyarrow.types as types
+
+
+# deprecated top-level access
+
+
+from pyarrow.filesystem import FileSystem as _FileSystem
+from pyarrow.filesystem import LocalFileSystem as _LocalFileSystem
+from pyarrow.hdfs import HadoopFileSystem as _HadoopFileSystem
+
+from pyarrow.lib import SerializationContext as _SerializationContext
+from pyarrow.lib import SerializedPyObject as _SerializedPyObject
+
+
+_localfs = _LocalFileSystem._get_instance()
+
+
+_msg = (
+ "pyarrow.{0} is deprecated as of 2.0.0, please use pyarrow.fs.{1} instead."
+)
+
+_serialization_msg = (
+ "'pyarrow.{0}' is deprecated and will be removed in a future version. "
+ "Use pickle or the pyarrow IPC functionality instead."
+)
+
+_deprecated = {
+ "localfs": (_localfs, "LocalFileSystem"),
+ "FileSystem": (_FileSystem, "FileSystem"),
+ "LocalFileSystem": (_LocalFileSystem, "LocalFileSystem"),
+ "HadoopFileSystem": (_HadoopFileSystem, "HadoopFileSystem"),
+}
+
+_serialization_deprecated = {
+ "SerializationContext": _SerializationContext,
+ "SerializedPyObject": _SerializedPyObject,
+}
+
+if _sys.version_info >= (3, 7):
+ def __getattr__(name):
+ if name in _deprecated:
+ obj, new_name = _deprecated[name]
+ _warnings.warn(_msg.format(name, new_name),
+ FutureWarning, stacklevel=2)
+ return obj
+ elif name in _serialization_deprecated:
+ _warnings.warn(_serialization_msg.format(name),
+ FutureWarning, stacklevel=2)
+ return _serialization_deprecated[name]
+
+ raise AttributeError(
+ "module 'pyarrow' has no attribute '{0}'".format(name)
+ )
+else:
+ localfs = _localfs
+ FileSystem = _FileSystem
+ LocalFileSystem = _LocalFileSystem
+ HadoopFileSystem = _HadoopFileSystem
+ SerializationContext = _SerializationContext
+ SerializedPyObject = _SerializedPyObject
+
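+# A minimal sketch of the deprecation shim above on the Python >= 3.7 path
+# (module-level __getattr__); the exact warning text comes from _msg:
+#
+# >>> import warnings
+# >>> import pyarrow
+# >>> with warnings.catch_warnings(record=True) as caught:
+# ...     warnings.simplefilter("always")
+# ...     _ = pyarrow.LocalFileSystem  # deprecated alias, still resolves
+# >>> issubclass(caught[-1].category, FutureWarning)
+# True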
+
+# Entry point for starting the plasma store
+
+
+def _plasma_store_entry_point():
+ """Entry point for starting the plasma store.
+
+ This can be used by invoking e.g.
+ ``plasma_store -s /tmp/plasma -m 1000000000``
+ from the command line and will start the plasma_store executable with the
+ given arguments.
+ """
+ import pyarrow
+ plasma_store_executable = _os.path.join(pyarrow.__path__[0],
+ "plasma-store-server")
+ _os.execv(plasma_store_executable, _sys.argv)
+
+
+# ----------------------------------------------------------------------
+# Deprecations
+
+from pyarrow.util import _deprecate_api, _deprecate_class
+
+read_message = _deprecate_api("read_message", "ipc.read_message",
+ ipc.read_message, "0.17.0")
+
+read_record_batch = _deprecate_api("read_record_batch",
+ "ipc.read_record_batch",
+ ipc.read_record_batch, "0.17.0")
+
+read_schema = _deprecate_api("read_schema", "ipc.read_schema",
+ ipc.read_schema, "0.17.0")
+
+read_tensor = _deprecate_api("read_tensor", "ipc.read_tensor",
+ ipc.read_tensor, "0.17.0")
+
+write_tensor = _deprecate_api("write_tensor", "ipc.write_tensor",
+ ipc.write_tensor, "0.17.0")
+
+get_record_batch_size = _deprecate_api("get_record_batch_size",
+ "ipc.get_record_batch_size",
+ ipc.get_record_batch_size, "0.17.0")
+
+get_tensor_size = _deprecate_api("get_tensor_size",
+ "ipc.get_tensor_size",
+ ipc.get_tensor_size, "0.17.0")
+
+open_stream = _deprecate_api("open_stream", "ipc.open_stream",
+ ipc.open_stream, "0.17.0")
+
+open_file = _deprecate_api("open_file", "ipc.open_file", ipc.open_file,
+ "0.17.0")
+
+
+def _deprecate_scalar(ty, symbol):
+ return _deprecate_class("{}Value".format(ty), symbol, "1.0.0")
+
+
+ArrayValue = _deprecate_class("ArrayValue", Scalar, "1.0.0")
+NullType = _deprecate_class("NullType", NullScalar, "1.0.0")
+
+BooleanValue = _deprecate_scalar("Boolean", BooleanScalar)
+Int8Value = _deprecate_scalar("Int8", Int8Scalar)
+Int16Value = _deprecate_scalar("Int16", Int16Scalar)
+Int32Value = _deprecate_scalar("Int32", Int32Scalar)
+Int64Value = _deprecate_scalar("Int64", Int64Scalar)
+UInt8Value = _deprecate_scalar("UInt8", UInt8Scalar)
+UInt16Value = _deprecate_scalar("UInt16", UInt16Scalar)
+UInt32Value = _deprecate_scalar("UInt32", UInt32Scalar)
+UInt64Value = _deprecate_scalar("UInt64", UInt64Scalar)
+HalfFloatValue = _deprecate_scalar("HalfFloat", HalfFloatScalar)
+FloatValue = _deprecate_scalar("Float", FloatScalar)
+DoubleValue = _deprecate_scalar("Double", DoubleScalar)
+ListValue = _deprecate_scalar("List", ListScalar)
+LargeListValue = _deprecate_scalar("LargeList", LargeListScalar)
+MapValue = _deprecate_scalar("Map", MapScalar)
+FixedSizeListValue = _deprecate_scalar("FixedSizeList", FixedSizeListScalar)
+BinaryValue = _deprecate_scalar("Binary", BinaryScalar)
+StringValue = _deprecate_scalar("String", StringScalar)
+LargeBinaryValue = _deprecate_scalar("LargeBinary", LargeBinaryScalar)
+LargeStringValue = _deprecate_scalar("LargeString", LargeStringScalar)
+FixedSizeBinaryValue = _deprecate_scalar("FixedSizeBinary",
+ FixedSizeBinaryScalar)
+Decimal128Value = _deprecate_scalar("Decimal128", Decimal128Scalar)
+Decimal256Value = _deprecate_scalar("Decimal256", Decimal256Scalar)
+UnionValue = _deprecate_scalar("Union", UnionScalar)
+StructValue = _deprecate_scalar("Struct", StructScalar)
+DictionaryValue = _deprecate_scalar("Dictionary", DictionaryScalar)
+Date32Value = _deprecate_scalar("Date32", Date32Scalar)
+Date64Value = _deprecate_scalar("Date64", Date64Scalar)
+Time32Value = _deprecate_scalar("Time32", Time32Scalar)
+Time64Value = _deprecate_scalar("Time64", Time64Scalar)
+TimestampValue = _deprecate_scalar("Timestamp", TimestampScalar)
+DurationValue = _deprecate_scalar("Duration", DurationScalar)
+
+
+# TODO: Deprecate these somehow in the pyarrow namespace
+from pyarrow.ipc import (Message, MessageReader, MetadataVersion,
+ RecordBatchFileReader, RecordBatchFileWriter,
+ RecordBatchStreamReader, RecordBatchStreamWriter)
+
+# ----------------------------------------------------------------------
+# Returning absolute path to the pyarrow include directory (if bundled, e.g. in
+# wheels)
+
+
+def get_include():
+ """
+ Return absolute path to directory containing Arrow C++ include
+ headers. Similar to numpy.get_include
+ """
+ return _os.path.join(_os.path.dirname(__file__), 'include')
+
+
+def _get_pkg_config_executable():
+ return _os.environ.get('PKG_CONFIG', 'pkg-config')
+
+
+def _has_pkg_config(pkgname):
+ import subprocess
+ try:
+ return subprocess.call([_get_pkg_config_executable(),
+ '--exists', pkgname]) == 0
+ except FileNotFoundError:
+ return False
+
+
+def _read_pkg_config_variable(pkgname, cli_args):
+ import subprocess
+ cmd = [_get_pkg_config_executable(), pkgname] + cli_args
+ proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ out, err = proc.communicate()
+ if proc.returncode != 0:
+ raise RuntimeError("pkg-config failed: " + err.decode('utf8'))
+ return out.rstrip().decode('utf8')
+
+
+def get_libraries():
+ """
+ Return list of library names to include in the `libraries` argument for C
+ or Cython extensions using pyarrow
+ """
+ return ['arrow', 'arrow_python']
+
+
+def create_library_symlinks():
+ """
+ With Linux and macOS wheels, the bundled shared libraries have an embedded
+ ABI version like libarrow.so.17 or libarrow.17.dylib and so linking to them
+ with -larrow won't work unless we create symlinks at locations like
+ site-packages/pyarrow/libarrow.so. This unfortunate workaround addresses
+ prior problems we had with shipping two copies of the shared libraries to
+ permit third party projects like turbodbc to build their C++ extensions
+ against the pyarrow wheels.
+
+ This function must only be invoked once and only when the shared libraries
+ are bundled with the Python package, which should only apply to wheel-based
+ installs. It requires write access to the site-packages/pyarrow directory
+ and so, depending on your system, may need to be run as root.
+ """
+ import glob
+ if _sys.platform == 'win32':
+ return
+ package_cwd = _os.path.dirname(__file__)
+
+ if _sys.platform == 'linux':
+ bundled_libs = glob.glob(_os.path.join(package_cwd, '*.so.*'))
+
+ def get_symlink_path(hard_path):
+ return hard_path.rsplit('.', 1)[0]
+ else:
+ bundled_libs = glob.glob(_os.path.join(package_cwd, '*.*.dylib'))
+
+ def get_symlink_path(hard_path):
+ return '.'.join((hard_path.rsplit('.', 2)[0], 'dylib'))
+
+ for lib_hard_path in bundled_libs:
+ symlink_path = get_symlink_path(lib_hard_path)
+ if _os.path.exists(symlink_path):
+ continue
+ try:
+ _os.symlink(lib_hard_path, symlink_path)
+ except PermissionError:
+ print("Tried creating symlink {}. If you need to link to "
+ "bundled shared libraries, run "
+ "pyarrow.create_library_symlinks() as root")
+
+
+def get_library_dirs():
+ """
+ Return lists of directories likely to contain Arrow C++ libraries for
+ linking C or Cython extensions using pyarrow
+ """
+ package_cwd = _os.path.dirname(__file__)
+ library_dirs = [package_cwd]
+
+ def append_library_dir(library_dir):
+ if library_dir not in library_dirs:
+ library_dirs.append(library_dir)
+
+ # Search library paths via pkg-config. This is necessary if the user
+ # installed libarrow and the other shared libraries manually and they
+ # are not shipped inside the pyarrow package (see also ARROW-2976).
+ pkg_config_executable = _os.environ.get('PKG_CONFIG') or 'pkg-config'
+ for pkgname in ["arrow", "arrow_python"]:
+ if _has_pkg_config(pkgname):
+ library_dir = _read_pkg_config_variable(pkgname,
+ ["--libs-only-L"])
+ # pkg-config output could be empty if Arrow is installed
+ # as a system package.
+ if library_dir:
+ if not library_dir.startswith("-L"):
+ raise ValueError(
+ "pkg-config --libs-only-L returned unexpected "
+ "value {!r}".format(library_dir))
+ append_library_dir(library_dir[2:])
+
+ if _sys.platform == 'win32':
+ # TODO(wesm): Is this necessary, or does setuptools within a conda
+ # installation add Library\lib to the linker path for MSVC?
+ python_base_install = _os.path.dirname(_sys.executable)
+ library_dir = _os.path.join(python_base_install, 'Library', 'lib')
+
+ if _os.path.exists(_os.path.join(library_dir, 'arrow.lib')):
+ append_library_dir(library_dir)
+
+ # ARROW-4074: Allow for ARROW_HOME to be set to some other directory
+ if _os.environ.get('ARROW_HOME'):
+ append_library_dir(_os.path.join(_os.environ['ARROW_HOME'], 'lib'))
+ else:
+ # Python wheels bundle the Arrow libraries in the pyarrow directory.
+ append_library_dir(_os.path.dirname(_os.path.abspath(__file__)))
+
+ return library_dirs
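+
+# A sketch of how get_include(), get_libraries() and get_library_dirs() are
+# typically consumed from a setup.py building a Cython extension against
+# pyarrow; the extension name "myext" and source "myext.pyx" are placeholders:
+#
+# from setuptools import setup, Extension
+# from Cython.Build import cythonize
+# import pyarrow as pa
+#
+# ext = Extension("myext", ["myext.pyx"],
+#                 include_dirs=[pa.get_include()],
+#                 libraries=pa.get_libraries(),
+#                 library_dirs=pa.get_library_dirs())
+# setup(ext_modules=cythonize([ext]))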
diff --git a/src/arrow/python/pyarrow/_compute.pxd b/src/arrow/python/pyarrow/_compute.pxd
new file mode 100644
index 000000000..8358271ef
--- /dev/null
+++ b/src/arrow/python/pyarrow/_compute.pxd
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+from pyarrow.lib cimport *
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+
+
+cdef class FunctionOptions(_Weakrefable):
+ cdef:
+ unique_ptr[CFunctionOptions] wrapped
+
+ cdef const CFunctionOptions* get_options(self) except NULL
+ cdef void init(self, unique_ptr[CFunctionOptions] options)
diff --git a/src/arrow/python/pyarrow/_compute.pyx b/src/arrow/python/pyarrow/_compute.pyx
new file mode 100644
index 000000000..d62c9c0ee
--- /dev/null
+++ b/src/arrow/python/pyarrow/_compute.pyx
@@ -0,0 +1,1296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+import sys
+
+from cython.operator cimport dereference as deref
+
+from collections import namedtuple
+
+from pyarrow.lib import frombytes, tobytes, ordered_dict
+from pyarrow.lib cimport *
+from pyarrow.includes.libarrow cimport *
+import pyarrow.lib as lib
+
+import numpy as np
+
+
+cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func):
+ """
+ Wrap a C++ scalar Function in a ScalarFunction object.
+ """
+ cdef ScalarFunction func = ScalarFunction.__new__(ScalarFunction)
+ func.init(sp_func)
+ return func
+
+
+cdef wrap_vector_function(const shared_ptr[CFunction]& sp_func):
+ """
+ Wrap a C++ vector Function in a VectorFunction object.
+ """
+ cdef VectorFunction func = VectorFunction.__new__(VectorFunction)
+ func.init(sp_func)
+ return func
+
+
+cdef wrap_scalar_aggregate_function(const shared_ptr[CFunction]& sp_func):
+ """
+ Wrap a C++ aggregate Function in a ScalarAggregateFunction object.
+ """
+ cdef ScalarAggregateFunction func = \
+ ScalarAggregateFunction.__new__(ScalarAggregateFunction)
+ func.init(sp_func)
+ return func
+
+
+cdef wrap_hash_aggregate_function(const shared_ptr[CFunction]& sp_func):
+ """
+ Wrap a C++ aggregate Function in a HashAggregateFunction object.
+ """
+ cdef HashAggregateFunction func = \
+ HashAggregateFunction.__new__(HashAggregateFunction)
+ func.init(sp_func)
+ return func
+
+
+cdef wrap_meta_function(const shared_ptr[CFunction]& sp_func):
+ """
+ Wrap a C++ meta Function in a MetaFunction object.
+ """
+ cdef MetaFunction func = MetaFunction.__new__(MetaFunction)
+ func.init(sp_func)
+ return func
+
+
+cdef wrap_function(const shared_ptr[CFunction]& sp_func):
+ """
+ Wrap a C++ Function in a Function object.
+
+ This dispatches to specialized wrappers depending on the function kind.
+ """
+ if sp_func.get() == NULL:
+ raise ValueError("Function was NULL")
+
+ cdef FunctionKind c_kind = sp_func.get().kind()
+ if c_kind == FunctionKind_SCALAR:
+ return wrap_scalar_function(sp_func)
+ elif c_kind == FunctionKind_VECTOR:
+ return wrap_vector_function(sp_func)
+ elif c_kind == FunctionKind_SCALAR_AGGREGATE:
+ return wrap_scalar_aggregate_function(sp_func)
+ elif c_kind == FunctionKind_HASH_AGGREGATE:
+ return wrap_hash_aggregate_function(sp_func)
+ elif c_kind == FunctionKind_META:
+ return wrap_meta_function(sp_func)
+ else:
+ raise NotImplementedError("Unknown Function::Kind")
+
+
+cdef wrap_scalar_kernel(const CScalarKernel* c_kernel):
+ if c_kernel == NULL:
+ raise ValueError("Kernel was NULL")
+ cdef ScalarKernel kernel = ScalarKernel.__new__(ScalarKernel)
+ kernel.init(c_kernel)
+ return kernel
+
+
+cdef wrap_vector_kernel(const CVectorKernel* c_kernel):
+ if c_kernel == NULL:
+ raise ValueError("Kernel was NULL")
+ cdef VectorKernel kernel = VectorKernel.__new__(VectorKernel)
+ kernel.init(c_kernel)
+ return kernel
+
+
+cdef wrap_scalar_aggregate_kernel(const CScalarAggregateKernel* c_kernel):
+ if c_kernel == NULL:
+ raise ValueError("Kernel was NULL")
+ cdef ScalarAggregateKernel kernel = \
+ ScalarAggregateKernel.__new__(ScalarAggregateKernel)
+ kernel.init(c_kernel)
+ return kernel
+
+
+cdef wrap_hash_aggregate_kernel(const CHashAggregateKernel* c_kernel):
+ if c_kernel == NULL:
+ raise ValueError("Kernel was NULL")
+ cdef HashAggregateKernel kernel = \
+ HashAggregateKernel.__new__(HashAggregateKernel)
+ kernel.init(c_kernel)
+ return kernel
+
+
+cdef class Kernel(_Weakrefable):
+ """
+ A kernel object.
+
+ Kernels handle the execution of a Function for a certain signature.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly"
+ .format(self.__class__.__name__))
+
+
+cdef class ScalarKernel(Kernel):
+ cdef const CScalarKernel* kernel
+
+ cdef void init(self, const CScalarKernel* kernel) except *:
+ self.kernel = kernel
+
+ def __repr__(self):
+ return ("ScalarKernel<{}>"
+ .format(frombytes(self.kernel.signature.get().ToString())))
+
+
+cdef class VectorKernel(Kernel):
+ cdef const CVectorKernel* kernel
+
+ cdef void init(self, const CVectorKernel* kernel) except *:
+ self.kernel = kernel
+
+ def __repr__(self):
+ return ("VectorKernel<{}>"
+ .format(frombytes(self.kernel.signature.get().ToString())))
+
+
+cdef class ScalarAggregateKernel(Kernel):
+ cdef const CScalarAggregateKernel* kernel
+
+ cdef void init(self, const CScalarAggregateKernel* kernel) except *:
+ self.kernel = kernel
+
+ def __repr__(self):
+ return ("ScalarAggregateKernel<{}>"
+ .format(frombytes(self.kernel.signature.get().ToString())))
+
+
+cdef class HashAggregateKernel(Kernel):
+ cdef const CHashAggregateKernel* kernel
+
+ cdef void init(self, const CHashAggregateKernel* kernel) except *:
+ self.kernel = kernel
+
+ def __repr__(self):
+ return ("HashAggregateKernel<{}>"
+ .format(frombytes(self.kernel.signature.get().ToString())))
+
+
+FunctionDoc = namedtuple(
+ "FunctionDoc",
+ ("summary", "description", "arg_names", "options_class"))
+
+
+cdef class Function(_Weakrefable):
+ """
+ A compute function.
+
+ A function implements a certain logical computation over a range of
+ possible input signatures. Each signature accepts a range of input
+ types and is implemented by a given Kernel.
+
+ Functions can be of different kinds:
+
+ * "scalar" functions apply an item-wise computation over all items
+ of their inputs. Each item in the output only depends on the values
+ of the inputs at the same position. Examples: addition, comparisons,
+ string predicates...
+
+ * "vector" functions apply a collection-wise computation, such that
+ each item in the output may depend on the values of several items
+ in each input. Examples: dictionary encoding, sorting, extracting
+ unique values...
+
+ * "scalar_aggregate" functions reduce the dimensionality of the inputs by
+ applying a reduction function. Examples: sum, min_max, mode...
+
+ * "hash_aggregate" functions apply a reduction function to an input
+ subdivided by grouping criteria. They may not be directly called.
+ Examples: hash_sum, hash_min_max...
+
+ * "meta" functions dispatch to other functions.
+ """
+
+ cdef:
+ shared_ptr[CFunction] sp_func
+ CFunction* base_func
+
+ _kind_map = {
+ FunctionKind_SCALAR: "scalar",
+ FunctionKind_VECTOR: "vector",
+ FunctionKind_SCALAR_AGGREGATE: "scalar_aggregate",
+ FunctionKind_HASH_AGGREGATE: "hash_aggregate",
+ FunctionKind_META: "meta",
+ }
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly"
+ .format(self.__class__.__name__))
+
+ cdef void init(self, const shared_ptr[CFunction]& sp_func) except *:
+ self.sp_func = sp_func
+ self.base_func = sp_func.get()
+
+ def __repr__(self):
+ return ("arrow.compute.Function<name={}, kind={}, "
+ "arity={}, num_kernels={}>"
+ .format(self.name, self.kind, self.arity, self.num_kernels))
+
+ def __reduce__(self):
+ # Reduction uses the global registry
+ return get_function, (self.name,)
+
+ @property
+ def name(self):
+ """
+ The function name.
+ """
+ return frombytes(self.base_func.name())
+
+ @property
+ def arity(self):
+ """
+ The function arity.
+
+ If Ellipsis (i.e. `...`) is returned, the function takes a variable
+ number of arguments.
+ """
+ cdef CArity arity = self.base_func.arity()
+ if arity.is_varargs:
+ return ...
+ else:
+ return arity.num_args
+
+ @property
+ def kind(self):
+ """
+ The function kind.
+ """
+ cdef FunctionKind c_kind = self.base_func.kind()
+ try:
+ return self._kind_map[c_kind]
+ except KeyError:
+ raise NotImplementedError("Unknown Function::Kind")
+
+ @property
+ def _doc(self):
+ """
+ The C++-like function documentation (for internal use).
+ """
+ cdef CFunctionDoc c_doc = self.base_func.doc()
+ return FunctionDoc(frombytes(c_doc.summary),
+ frombytes(c_doc.description),
+ [frombytes(s) for s in c_doc.arg_names],
+ frombytes(c_doc.options_class))
+
+ @property
+ def num_kernels(self):
+ """
+ The number of kernels implementing this function.
+ """
+ return self.base_func.num_kernels()
+
+ def call(self, args, FunctionOptions options=None,
+ MemoryPool memory_pool=None):
+ """
+ Call the function on the given arguments.
+ """
+ cdef:
+ const CFunctionOptions* c_options = NULL
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+ CExecContext c_exec_ctx = CExecContext(pool)
+ vector[CDatum] c_args
+ CDatum result
+
+ _pack_compute_args(args, &c_args)
+
+ if options is not None:
+ c_options = options.get_options()
+
+ with nogil:
+ result = GetResultValue(
+ self.base_func.Execute(c_args, c_options, &c_exec_ctx)
+ )
+
+ return wrap_datum(result)
+
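+# Illustrative sketch of the Function properties documented above, looked up
+# through the global registry ("add" and "sum" are built-in compute functions;
+# outputs abbreviated):
+#
+# >>> import pyarrow.compute as pc
+# >>> add = pc.get_function("add")
+# >>> add.kind, add.arity, add.num_kernels > 0
+# ('scalar', 2, True)
+# >>> pc.get_function("sum").kind
+# 'scalar_aggregate'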
+
+cdef class ScalarFunction(Function):
+ cdef const CScalarFunction* func
+
+ cdef void init(self, const shared_ptr[CFunction]& sp_func) except *:
+ Function.init(self, sp_func)
+ self.func = <const CScalarFunction*> sp_func.get()
+
+ @property
+ def kernels(self):
+ """
+ The kernels implementing this function.
+ """
+ cdef vector[const CScalarKernel*] kernels = self.func.kernels()
+ return [wrap_scalar_kernel(k) for k in kernels]
+
+
+cdef class VectorFunction(Function):
+ cdef const CVectorFunction* func
+
+ cdef void init(self, const shared_ptr[CFunction]& sp_func) except *:
+ Function.init(self, sp_func)
+ self.func = <const CVectorFunction*> sp_func.get()
+
+ @property
+ def kernels(self):
+ """
+ The kernels implementing this function.
+ """
+ cdef vector[const CVectorKernel*] kernels = self.func.kernels()
+ return [wrap_vector_kernel(k) for k in kernels]
+
+
+cdef class ScalarAggregateFunction(Function):
+ cdef const CScalarAggregateFunction* func
+
+ cdef void init(self, const shared_ptr[CFunction]& sp_func) except *:
+ Function.init(self, sp_func)
+ self.func = <const CScalarAggregateFunction*> sp_func.get()
+
+ @property
+ def kernels(self):
+ """
+ The kernels implementing this function.
+ """
+ cdef vector[const CScalarAggregateKernel*] kernels = \
+ self.func.kernels()
+ return [wrap_scalar_aggregate_kernel(k) for k in kernels]
+
+
+cdef class HashAggregateFunction(Function):
+ cdef const CHashAggregateFunction* func
+
+ cdef void init(self, const shared_ptr[CFunction]& sp_func) except *:
+ Function.init(self, sp_func)
+ self.func = <const CHashAggregateFunction*> sp_func.get()
+
+ @property
+ def kernels(self):
+ """
+ The kernels implementing this function.
+ """
+ cdef vector[const CHashAggregateKernel*] kernels = self.func.kernels()
+ return [wrap_hash_aggregate_kernel(k) for k in kernels]
+
+
+cdef class MetaFunction(Function):
+ cdef const CMetaFunction* func
+
+ cdef void init(self, const shared_ptr[CFunction]& sp_func) except *:
+ Function.init(self, sp_func)
+ self.func = <const CMetaFunction*> sp_func.get()
+
+ # Since num_kernels is exposed, also expose a kernels property
+ @property
+ def kernels(self):
+ """
+ The kernels implementing this function.
+ """
+ return []
+
+
+cdef _pack_compute_args(object values, vector[CDatum]* out):
+ for val in values:
+ if isinstance(val, (list, np.ndarray)):
+ val = lib.asarray(val)
+
+ if isinstance(val, Array):
+ out.push_back(CDatum((<Array> val).sp_array))
+ continue
+ elif isinstance(val, ChunkedArray):
+ out.push_back(CDatum((<ChunkedArray> val).sp_chunked_array))
+ continue
+ elif isinstance(val, Scalar):
+ out.push_back(CDatum((<Scalar> val).unwrap()))
+ continue
+ elif isinstance(val, RecordBatch):
+ out.push_back(CDatum((<RecordBatch> val).sp_batch))
+ continue
+ elif isinstance(val, Table):
+ out.push_back(CDatum((<Table> val).sp_table))
+ continue
+ else:
+ # Is it a Python scalar?
+ try:
+ scal = lib.scalar(val)
+ except Exception:
+ # Raise dedicated error below
+ pass
+ else:
+ out.push_back(CDatum((<Scalar> scal).unwrap()))
+ continue
+
+ raise TypeError(f"Got unexpected argument type {type(val)} "
+ "for compute function")
+
+
+cdef class FunctionRegistry(_Weakrefable):
+ cdef CFunctionRegistry* registry
+
+ def __init__(self):
+ self.registry = GetFunctionRegistry()
+
+ def list_functions(self):
+ """
+ Return all function names in the registry.
+ """
+ cdef vector[c_string] names = self.registry.GetFunctionNames()
+ return [frombytes(name) for name in names]
+
+ def get_function(self, name):
+ """
+ Look up a function by name in the registry.
+
+ Parameters
+ ----------
+ name : str
+ The name of the function to look up
+ """
+ cdef:
+ c_string c_name = tobytes(name)
+ shared_ptr[CFunction] func
+ with nogil:
+ func = GetResultValue(self.registry.GetFunction(c_name))
+ return wrap_function(func)
+
+
+cdef FunctionRegistry _global_func_registry = FunctionRegistry()
+
+
+def function_registry():
+ return _global_func_registry
+
+
+def get_function(name):
+ """
+ Get a function by name.
+
+ The function is looked up in the global registry
+ (as returned by `function_registry()`).
+
+ Parameters
+ ----------
+ name : str
+ The name of the function to look up
+ """
+ return _global_func_registry.get_function(name)
+
+
+def list_functions():
+ """
+ Return all function names in the global registry.
+ """
+ return _global_func_registry.list_functions()
+
+
+def call_function(name, args, options=None, memory_pool=None):
+ """
+ Call a named function.
+
+ The function is looked up in the global registry
+ (as returned by `function_registry()`).
+
+ Parameters
+ ----------
+ name : str
+ The name of the function to call.
+ args : list
+ The arguments to the function.
+ options : optional
+ Options provided to the function.
+ memory_pool : MemoryPool, optional
+ Memory pool to use for allocations during function execution.
+ """
+ func = _global_func_registry.get_function(name)
+ return func.call(args, options=options, memory_pool=memory_pool)
+
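+# Short usage sketch for the registry helpers above; plain Python values are
+# converted to Datums by _pack_compute_args, so they can be passed directly
+# (result values are illustrative):
+#
+# >>> import pyarrow as pa
+# >>> import pyarrow.compute as pc
+# >>> pc.call_function("add", [pa.array([1, 2, 3]), 10]).to_pylist()
+# [11, 12, 13]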
+
+cdef class FunctionOptions(_Weakrefable):
+ __slots__ = () # avoid mistakenly creating attributes
+
+ cdef const CFunctionOptions* get_options(self) except NULL:
+ return self.wrapped.get()
+
+ cdef void init(self, unique_ptr[CFunctionOptions] options):
+ self.wrapped = move(options)
+
+ def serialize(self):
+ cdef:
+ CResult[shared_ptr[CBuffer]] res = self.get_options().Serialize()
+ shared_ptr[CBuffer] c_buf = GetResultValue(res)
+ return pyarrow_wrap_buffer(c_buf)
+
+ @staticmethod
+ def deserialize(buf):
+ """
+ Deserialize options for a function.
+
+ Parameters
+ ----------
+ buf : Buffer
+ The buffer containing the data to deserialize.
+ """
+ cdef:
+ shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf)
+ CResult[unique_ptr[CFunctionOptions]] maybe_options = \
+ DeserializeFunctionOptions(deref(c_buf))
+ unique_ptr[CFunctionOptions] c_options
+ c_options = move(GetResultValue(move(maybe_options)))
+ type_name = frombytes(c_options.get().options_type().type_name())
+ module = globals()
+ if type_name not in module:
+ raise ValueError(f'Cannot deserialize "{type_name}"')
+ klass = module[type_name]
+ options = klass.__new__(klass)
+ (<FunctionOptions> options).init(move(c_options))
+ return options
+
+ def __repr__(self):
+ type_name = self.__class__.__name__
+ # Remove {} so we can use our own braces
+ string_repr = frombytes(self.get_options().ToString())[1:-1]
+ return f"{type_name}({string_repr})"
+
+ def __eq__(self, FunctionOptions other):
+ return self.get_options().Equals(deref(other.get_options()))
+
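+# Round-trip sketch for the serialize()/deserialize() pair above, assuming the
+# concrete options class (here CastOptions) is defined in this module so
+# deserialize() can resolve its type name:
+#
+# >>> import pyarrow as pa
+# >>> import pyarrow.compute as pc
+# >>> opts = pc.CastOptions.safe(pa.int32())
+# >>> buf = opts.serialize()  # a pyarrow Buffer
+# >>> pc.FunctionOptions.deserialize(buf) == opts
+# True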
+
+def _raise_invalid_function_option(value, description, *,
+ exception_class=ValueError):
+ raise exception_class(f"\"{value}\" is not a valid {description}")
+
+
+# NOTE:
+# To properly expose the constructor signature of FunctionOptions
+# subclasses, we use a two-level inheritance:
+# 1. a C extension class that implements option validation and setting
+# (won't expose function signatures because of
+# https://github.com/cython/cython/issues/3873)
+# 2. a Python derived class that implements the constructor
+
+cdef class _CastOptions(FunctionOptions):
+ cdef CCastOptions* options
+
+ cdef void init(self, unique_ptr[CFunctionOptions] options):
+ FunctionOptions.init(self, move(options))
+ self.options = <CCastOptions*> self.wrapped.get()
+
+ def _set_options(self, DataType target_type, allow_int_overflow,
+ allow_time_truncate, allow_time_overflow,
+ allow_decimal_truncate, allow_float_truncate,
+ allow_invalid_utf8):
+ self.init(unique_ptr[CFunctionOptions](new CCastOptions()))
+ self._set_type(target_type)
+ if allow_int_overflow is not None:
+ self.allow_int_overflow = allow_int_overflow
+ if allow_time_truncate is not None:
+ self.allow_time_truncate = allow_time_truncate
+ if allow_time_overflow is not None:
+ self.allow_time_overflow = allow_time_overflow
+ if allow_decimal_truncate is not None:
+ self.allow_decimal_truncate = allow_decimal_truncate
+ if allow_float_truncate is not None:
+ self.allow_float_truncate = allow_float_truncate
+ if allow_invalid_utf8 is not None:
+ self.allow_invalid_utf8 = allow_invalid_utf8
+
+ def _set_type(self, target_type=None):
+ if target_type is not None:
+ deref(self.options).to_type = \
+ (<DataType> ensure_type(target_type)).sp_type
+
+ def _set_safe(self):
+ self.init(unique_ptr[CFunctionOptions](
+ new CCastOptions(CCastOptions.Safe())))
+
+ def _set_unsafe(self):
+ self.init(unique_ptr[CFunctionOptions](
+ new CCastOptions(CCastOptions.Unsafe())))
+
+ def is_safe(self):
+ return not (deref(self.options).allow_int_overflow or
+ deref(self.options).allow_time_truncate or
+ deref(self.options).allow_time_overflow or
+ deref(self.options).allow_decimal_truncate or
+ deref(self.options).allow_float_truncate or
+ deref(self.options).allow_invalid_utf8)
+
+ @property
+ def allow_int_overflow(self):
+ return deref(self.options).allow_int_overflow
+
+ @allow_int_overflow.setter
+ def allow_int_overflow(self, c_bool flag):
+ deref(self.options).allow_int_overflow = flag
+
+ @property
+ def allow_time_truncate(self):
+ return deref(self.options).allow_time_truncate
+
+ @allow_time_truncate.setter
+ def allow_time_truncate(self, c_bool flag):
+ deref(self.options).allow_time_truncate = flag
+
+ @property
+ def allow_time_overflow(self):
+ return deref(self.options).allow_time_overflow
+
+ @allow_time_overflow.setter
+ def allow_time_overflow(self, c_bool flag):
+ deref(self.options).allow_time_overflow = flag
+
+ @property
+ def allow_decimal_truncate(self):
+ return deref(self.options).allow_decimal_truncate
+
+ @allow_decimal_truncate.setter
+ def allow_decimal_truncate(self, c_bool flag):
+ deref(self.options).allow_decimal_truncate = flag
+
+ @property
+ def allow_float_truncate(self):
+ return deref(self.options).allow_float_truncate
+
+ @allow_float_truncate.setter
+ def allow_float_truncate(self, c_bool flag):
+ deref(self.options).allow_float_truncate = flag
+
+ @property
+ def allow_invalid_utf8(self):
+ return deref(self.options).allow_invalid_utf8
+
+ @allow_invalid_utf8.setter
+ def allow_invalid_utf8(self, c_bool flag):
+ deref(self.options).allow_invalid_utf8 = flag
+
+
+class CastOptions(_CastOptions):
+
+ def __init__(self, target_type=None, *, allow_int_overflow=None,
+ allow_time_truncate=None, allow_time_overflow=None,
+ allow_decimal_truncate=None, allow_float_truncate=None,
+ allow_invalid_utf8=None):
+ self._set_options(target_type, allow_int_overflow, allow_time_truncate,
+ allow_time_overflow, allow_decimal_truncate,
+ allow_float_truncate, allow_invalid_utf8)
+
+ @staticmethod
+ def safe(target_type=None):
+ """"
+ Create a CastOptions for a safe cast.
+
+ Parameters
+ ----------
+ target_type : optional
+ Target cast type for the safe cast.
+ """
+ self = CastOptions()
+ self._set_safe()
+ self._set_type(target_type)
+ return self
+
+ @staticmethod
+ def unsafe(target_type=None):
+ """"
+ Create a CastOptions for an unsafe cast.
+
+ Parameters
+ ----------
+ target_type : optional
+ Target cast type for the unsafe cast.
+ """
+ self = CastOptions()
+ self._set_unsafe()
+ self._set_type(target_type)
+ return self
+
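+# Example sketch for the safe/unsafe factories above, driving the "cast" meta
+# function through call_function (values shown are illustrative):
+#
+# >>> import pyarrow as pa
+# >>> import pyarrow.compute as pc
+# >>> arr = pa.array([1.5, 2.7])
+# >>> pc.call_function("cast", [arr], pc.CastOptions.unsafe(pa.int32())).to_pylist()
+# [1, 2]
+# >>> pc.call_function("cast", [arr], pc.CastOptions.safe(pa.int32()))
+# ... raises ArrowInvalid, since the float values cannot be cast safely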
+
+cdef class _ElementWiseAggregateOptions(FunctionOptions):
+ def _set_options(self, skip_nulls):
+ self.wrapped.reset(new CElementWiseAggregateOptions(skip_nulls))
+
+
+class ElementWiseAggregateOptions(_ElementWiseAggregateOptions):
+ def __init__(self, *, skip_nulls=True):
+ self._set_options(skip_nulls)
+
+
+cdef CRoundMode unwrap_round_mode(round_mode) except *:
+ if round_mode == "down":
+ return CRoundMode_DOWN
+ elif round_mode == "up":
+ return CRoundMode_UP
+ elif round_mode == "towards_zero":
+ return CRoundMode_TOWARDS_ZERO
+ elif round_mode == "towards_infinity":
+ return CRoundMode_TOWARDS_INFINITY
+ elif round_mode == "half_down":
+ return CRoundMode_HALF_DOWN
+ elif round_mode == "half_up":
+ return CRoundMode_HALF_UP
+ elif round_mode == "half_towards_zero":
+ return CRoundMode_HALF_TOWARDS_ZERO
+ elif round_mode == "half_towards_infinity":
+ return CRoundMode_HALF_TOWARDS_INFINITY
+ elif round_mode == "half_to_even":
+ return CRoundMode_HALF_TO_EVEN
+ elif round_mode == "half_to_odd":
+ return CRoundMode_HALF_TO_ODD
+ _raise_invalid_function_option(round_mode, "round mode")
+
+
+cdef class _RoundOptions(FunctionOptions):
+ def _set_options(self, ndigits, round_mode):
+ self.wrapped.reset(
+ new CRoundOptions(ndigits, unwrap_round_mode(round_mode))
+ )
+
+
+class RoundOptions(_RoundOptions):
+ def __init__(self, ndigits=0, round_mode="half_to_even"):
+ self._set_options(ndigits, round_mode)
+
+
+cdef class _RoundToMultipleOptions(FunctionOptions):
+ def _set_options(self, multiple, round_mode):
+ self.wrapped.reset(
+ new CRoundToMultipleOptions(multiple,
+ unwrap_round_mode(round_mode))
+ )
+
+
+class RoundToMultipleOptions(_RoundToMultipleOptions):
+ def __init__(self, multiple=1.0, round_mode="half_to_even"):
+ self._set_options(multiple, round_mode)
+
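+# Sketch of the rounding options above applied through call_function
+# (results are illustrative):
+#
+# >>> import pyarrow as pa
+# >>> import pyarrow.compute as pc
+# >>> pc.call_function("round", [pa.array([2.5, 3.5])],
+# ...                  pc.RoundOptions(round_mode="half_to_even")).to_pylist()
+# [2.0, 4.0]
+# >>> pc.call_function("round_to_multiple", [pa.array([12.7])],
+# ...                  pc.RoundToMultipleOptions(multiple=5)).to_pylist()
+# [15.0]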
+
+cdef class _JoinOptions(FunctionOptions):
+ _null_handling_map = {
+ "emit_null": CJoinNullHandlingBehavior_EMIT_NULL,
+ "skip": CJoinNullHandlingBehavior_SKIP,
+ "replace": CJoinNullHandlingBehavior_REPLACE,
+ }
+
+ def _set_options(self, null_handling, null_replacement):
+ try:
+ self.wrapped.reset(
+ new CJoinOptions(self._null_handling_map[null_handling],
+ tobytes(null_replacement))
+ )
+ except KeyError:
+ _raise_invalid_function_option(null_handling, "null handling")
+
+
+class JoinOptions(_JoinOptions):
+ def __init__(self, null_handling="emit_null", null_replacement=""):
+ self._set_options(null_handling, null_replacement)
+
+
+cdef class _MatchSubstringOptions(FunctionOptions):
+ def _set_options(self, pattern, ignore_case):
+ self.wrapped.reset(
+ new CMatchSubstringOptions(tobytes(pattern), ignore_case)
+ )
+
+
+class MatchSubstringOptions(_MatchSubstringOptions):
+ def __init__(self, pattern, *, ignore_case=False):
+ self._set_options(pattern, ignore_case)
+
+
+cdef class _PadOptions(FunctionOptions):
+ def _set_options(self, width, padding):
+ self.wrapped.reset(new CPadOptions(width, tobytes(padding)))
+
+
+class PadOptions(_PadOptions):
+ def __init__(self, width, padding=' '):
+ self._set_options(width, padding)
+
+
+cdef class _TrimOptions(FunctionOptions):
+ def _set_options(self, characters):
+ self.wrapped.reset(new CTrimOptions(tobytes(characters)))
+
+
+class TrimOptions(_TrimOptions):
+ def __init__(self, characters):
+ self._set_options(characters)
+
+
+cdef class _ReplaceSliceOptions(FunctionOptions):
+ def _set_options(self, start, stop, replacement):
+ self.wrapped.reset(
+ new CReplaceSliceOptions(start, stop, tobytes(replacement))
+ )
+
+
+class ReplaceSliceOptions(_ReplaceSliceOptions):
+ def __init__(self, start, stop, replacement):
+ self._set_options(start, stop, replacement)
+
+
+cdef class _ReplaceSubstringOptions(FunctionOptions):
+ def _set_options(self, pattern, replacement, max_replacements):
+ self.wrapped.reset(
+ new CReplaceSubstringOptions(tobytes(pattern),
+ tobytes(replacement),
+ max_replacements)
+ )
+
+
+class ReplaceSubstringOptions(_ReplaceSubstringOptions):
+ def __init__(self, pattern, replacement, *, max_replacements=-1):
+ self._set_options(pattern, replacement, max_replacements)
+
+
+cdef class _ExtractRegexOptions(FunctionOptions):
+ def _set_options(self, pattern):
+ self.wrapped.reset(new CExtractRegexOptions(tobytes(pattern)))
+
+
+class ExtractRegexOptions(_ExtractRegexOptions):
+ def __init__(self, pattern):
+ self._set_options(pattern)
+
+
+cdef class _SliceOptions(FunctionOptions):
+ def _set_options(self, start, stop, step):
+ self.wrapped.reset(new CSliceOptions(start, stop, step))
+
+
+class SliceOptions(_SliceOptions):
+ def __init__(self, start, stop=sys.maxsize, step=1):
+ self._set_options(start, stop, step)
+
+
+cdef class _FilterOptions(FunctionOptions):
+ _null_selection_map = {
+ "drop": CFilterNullSelectionBehavior_DROP,
+ "emit_null": CFilterNullSelectionBehavior_EMIT_NULL,
+ }
+
+ def _set_options(self, null_selection_behavior):
+ try:
+ self.wrapped.reset(
+ new CFilterOptions(
+ self._null_selection_map[null_selection_behavior]
+ )
+ )
+ except KeyError:
+ _raise_invalid_function_option(null_selection_behavior,
+ "null selection behavior")
+
+
+class FilterOptions(_FilterOptions):
+ def __init__(self, null_selection_behavior="drop"):
+ self._set_options(null_selection_behavior)
+
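+# Sketch of the null selection behaviors above with the "filter" meta function
+# (results are illustrative):
+#
+# >>> import pyarrow as pa
+# >>> import pyarrow.compute as pc
+# >>> arr = pa.array([1, 2, 3])
+# >>> mask = pa.array([True, None, False])
+# >>> pc.call_function("filter", [arr, mask], pc.FilterOptions("drop")).to_pylist()
+# [1]
+# >>> pc.call_function("filter", [arr, mask], pc.FilterOptions("emit_null")).to_pylist()
+# [1, None]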
+
+cdef class _DictionaryEncodeOptions(FunctionOptions):
+ _null_encoding_map = {
+ "encode": CDictionaryEncodeNullEncodingBehavior_ENCODE,
+ "mask": CDictionaryEncodeNullEncodingBehavior_MASK,
+ }
+
+ def _set_options(self, null_encoding):
+ try:
+ self.wrapped.reset(
+ new CDictionaryEncodeOptions(
+ self._null_encoding_map[null_encoding]
+ )
+ )
+ except KeyError:
+ _raise_invalid_function_option(null_encoding, "null encoding")
+
+
+class DictionaryEncodeOptions(_DictionaryEncodeOptions):
+ def __init__(self, null_encoding="mask"):
+ self._set_options(null_encoding)
+
+
+cdef class _TakeOptions(FunctionOptions):
+ def _set_options(self, boundscheck):
+ self.wrapped.reset(new CTakeOptions(boundscheck))
+
+
+class TakeOptions(_TakeOptions):
+ def __init__(self, *, boundscheck=True):
+ self._set_options(boundscheck)
+
+
+cdef class _MakeStructOptions(FunctionOptions):
+ def _set_options(self, field_names, field_nullability, field_metadata):
+ cdef:
+ vector[c_string] c_field_names
+ vector[shared_ptr[const CKeyValueMetadata]] c_field_metadata
+ for name in field_names:
+ c_field_names.push_back(tobytes(name))
+ for metadata in field_metadata:
+ c_field_metadata.push_back(pyarrow_unwrap_metadata(metadata))
+ self.wrapped.reset(
+ new CMakeStructOptions(c_field_names, field_nullability,
+ c_field_metadata)
+ )
+
+
+class MakeStructOptions(_MakeStructOptions):
+ def __init__(self, field_names, *, field_nullability=None,
+ field_metadata=None):
+ if field_nullability is None:
+ field_nullability = [True] * len(field_names)
+ if field_metadata is None:
+ field_metadata = [None] * len(field_names)
+ self._set_options(field_names, field_nullability, field_metadata)
+
+
+cdef class _ScalarAggregateOptions(FunctionOptions):
+ def _set_options(self, skip_nulls, min_count):
+ self.wrapped.reset(new CScalarAggregateOptions(skip_nulls, min_count))
+
+
+class ScalarAggregateOptions(_ScalarAggregateOptions):
+ def __init__(self, *, skip_nulls=True, min_count=1):
+ self._set_options(skip_nulls, min_count)
+
+
+cdef class _CountOptions(FunctionOptions):
+ _mode_map = {
+ "only_valid": CCountMode_ONLY_VALID,
+ "only_null": CCountMode_ONLY_NULL,
+ "all": CCountMode_ALL,
+ }
+
+ def _set_options(self, mode):
+ try:
+ self.wrapped.reset(new CCountOptions(self._mode_map[mode]))
+ except KeyError:
+ _raise_invalid_function_option(mode, "count mode")
+
+
+class CountOptions(_CountOptions):
+ def __init__(self, mode="only_valid"):
+ self._set_options(mode)
+
+
+cdef class _IndexOptions(FunctionOptions):
+ def _set_options(self, scalar):
+ self.wrapped.reset(new CIndexOptions(pyarrow_unwrap_scalar(scalar)))
+
+
+class IndexOptions(_IndexOptions):
+ """
+ Options for the index kernel.
+
+ Parameters
+ ----------
+ value : Scalar
+ The value to search for.
+ """
+
+ def __init__(self, value):
+ self._set_options(value)
+
+
+cdef class _ModeOptions(FunctionOptions):
+ def _set_options(self, n, skip_nulls, min_count):
+ self.wrapped.reset(new CModeOptions(n, skip_nulls, min_count))
+
+
+class ModeOptions(_ModeOptions):
+ def __init__(self, n=1, *, skip_nulls=True, min_count=0):
+ self._set_options(n, skip_nulls, min_count)
+
+
+cdef class _SetLookupOptions(FunctionOptions):
+ def _set_options(self, value_set, c_bool skip_nulls):
+ cdef unique_ptr[CDatum] valset
+ if isinstance(value_set, Array):
+ valset.reset(new CDatum((<Array> value_set).sp_array))
+ elif isinstance(value_set, ChunkedArray):
+ valset.reset(
+ new CDatum((<ChunkedArray> value_set).sp_chunked_array)
+ )
+ elif isinstance(value_set, Scalar):
+ valset.reset(new CDatum((<Scalar> value_set).unwrap()))
+ else:
+ _raise_invalid_function_option(value_set, "value set",
+ exception_class=TypeError)
+
+ self.wrapped.reset(new CSetLookupOptions(deref(valset), skip_nulls))
+
+
+class SetLookupOptions(_SetLookupOptions):
+ def __init__(self, value_set, *, skip_nulls=False):
+ self._set_options(value_set, skip_nulls)
+
+
+cdef class _StrptimeOptions(FunctionOptions):
+ _unit_map = {
+ "s": TimeUnit_SECOND,
+ "ms": TimeUnit_MILLI,
+ "us": TimeUnit_MICRO,
+ "ns": TimeUnit_NANO,
+ }
+
+ def _set_options(self, format, unit):
+ try:
+ self.wrapped.reset(
+ new CStrptimeOptions(tobytes(format), self._unit_map[unit])
+ )
+ except KeyError:
+ _raise_invalid_function_option(unit, "time unit")
+
+
+class StrptimeOptions(_StrptimeOptions):
+ def __init__(self, format, unit):
+ self._set_options(format, unit)
+
+
+cdef class _StrftimeOptions(FunctionOptions):
+ def _set_options(self, format, locale):
+ self.wrapped.reset(
+ new CStrftimeOptions(tobytes(format), tobytes(locale))
+ )
+
+
+class StrftimeOptions(_StrftimeOptions):
+ def __init__(self, format="%Y-%m-%dT%H:%M:%S", locale="C"):
+ self._set_options(format, locale)
+
+
+cdef class _DayOfWeekOptions(FunctionOptions):
+ def _set_options(self, count_from_zero, week_start):
+ self.wrapped.reset(
+ new CDayOfWeekOptions(count_from_zero, week_start)
+ )
+
+
+class DayOfWeekOptions(_DayOfWeekOptions):
+ def __init__(self, *, count_from_zero=True, week_start=1):
+ self._set_options(count_from_zero, week_start)
+
+
+cdef class _WeekOptions(FunctionOptions):
+ def _set_options(self, week_starts_monday, count_from_zero,
+ first_week_is_fully_in_year):
+ self.wrapped.reset(
+ new CWeekOptions(week_starts_monday, count_from_zero,
+ first_week_is_fully_in_year)
+ )
+
+
+class WeekOptions(_WeekOptions):
+ def __init__(self, *, week_starts_monday=True, count_from_zero=False,
+ first_week_is_fully_in_year=False):
+ self._set_options(week_starts_monday,
+ count_from_zero, first_week_is_fully_in_year)
+
+
+cdef class _AssumeTimezoneOptions(FunctionOptions):
+ _ambiguous_map = {
+ "raise": CAssumeTimezoneAmbiguous_AMBIGUOUS_RAISE,
+ "earliest": CAssumeTimezoneAmbiguous_AMBIGUOUS_EARLIEST,
+ "latest": CAssumeTimezoneAmbiguous_AMBIGUOUS_LATEST,
+ }
+ _nonexistent_map = {
+ "raise": CAssumeTimezoneNonexistent_NONEXISTENT_RAISE,
+ "earliest": CAssumeTimezoneNonexistent_NONEXISTENT_EARLIEST,
+ "latest": CAssumeTimezoneNonexistent_NONEXISTENT_LATEST,
+ }
+
+ def _set_options(self, timezone, ambiguous, nonexistent):
+ if ambiguous not in self._ambiguous_map:
+ _raise_invalid_function_option(ambiguous,
+ "'ambiguous' timestamp handling")
+ if nonexistent not in self._nonexistent_map:
+ _raise_invalid_function_option(nonexistent,
+ "'nonexistent' timestamp handling")
+ self.wrapped.reset(
+ new CAssumeTimezoneOptions(tobytes(timezone),
+ self._ambiguous_map[ambiguous],
+ self._nonexistent_map[nonexistent])
+ )
+
+
+class AssumeTimezoneOptions(_AssumeTimezoneOptions):
+ def __init__(self, timezone, *, ambiguous="raise", nonexistent="raise"):
+ self._set_options(timezone, ambiguous, nonexistent)
+
+
+cdef class _NullOptions(FunctionOptions):
+ def _set_options(self, nan_is_null):
+ self.wrapped.reset(new CNullOptions(nan_is_null))
+
+
+class NullOptions(_NullOptions):
+ def __init__(self, *, nan_is_null=False):
+ self._set_options(nan_is_null)
+
+
+cdef class _VarianceOptions(FunctionOptions):
+ def _set_options(self, ddof, skip_nulls, min_count):
+ self.wrapped.reset(new CVarianceOptions(ddof, skip_nulls, min_count))
+
+
+class VarianceOptions(_VarianceOptions):
+ def __init__(self, *, ddof=0, skip_nulls=True, min_count=0):
+ self._set_options(ddof, skip_nulls, min_count)
+
+
+cdef class _SplitOptions(FunctionOptions):
+ def _set_options(self, max_splits, reverse):
+ self.wrapped.reset(new CSplitOptions(max_splits, reverse))
+
+
+class SplitOptions(_SplitOptions):
+ def __init__(self, *, max_splits=-1, reverse=False):
+ self._set_options(max_splits, reverse)
+
+
+cdef class _SplitPatternOptions(FunctionOptions):
+ def _set_options(self, pattern, max_splits, reverse):
+ self.wrapped.reset(
+ new CSplitPatternOptions(tobytes(pattern), max_splits, reverse)
+ )
+
+
+class SplitPatternOptions(_SplitPatternOptions):
+ def __init__(self, pattern, *, max_splits=-1, reverse=False):
+ self._set_options(pattern, max_splits, reverse)
+
+
+cdef CSortOrder unwrap_sort_order(order) except *:
+ if order == "ascending":
+ return CSortOrder_Ascending
+ elif order == "descending":
+ return CSortOrder_Descending
+ _raise_invalid_function_option(order, "sort order")
+
+
+cdef CNullPlacement unwrap_null_placement(null_placement) except *:
+ if null_placement == "at_start":
+ return CNullPlacement_AtStart
+ elif null_placement == "at_end":
+ return CNullPlacement_AtEnd
+ _raise_invalid_function_option(null_placement, "null placement")
+
+
+cdef class _PartitionNthOptions(FunctionOptions):
+ def _set_options(self, pivot, null_placement):
+ self.wrapped.reset(new CPartitionNthOptions(
+ pivot, unwrap_null_placement(null_placement)))
+
+
+class PartitionNthOptions(_PartitionNthOptions):
+ def __init__(self, pivot, *, null_placement="at_end"):
+ self._set_options(pivot, null_placement)
+
+
+cdef class _ArraySortOptions(FunctionOptions):
+ def _set_options(self, order, null_placement):
+ self.wrapped.reset(new CArraySortOptions(
+ unwrap_sort_order(order), unwrap_null_placement(null_placement)))
+
+
+class ArraySortOptions(_ArraySortOptions):
+ def __init__(self, order="ascending", *, null_placement="at_end"):
+ self._set_options(order, null_placement)
+
+
+cdef class _SortOptions(FunctionOptions):
+ def _set_options(self, sort_keys, null_placement):
+ cdef vector[CSortKey] c_sort_keys
+ for name, order in sort_keys:
+ c_sort_keys.push_back(
+ CSortKey(tobytes(name), unwrap_sort_order(order))
+ )
+ self.wrapped.reset(new CSortOptions(
+ c_sort_keys, unwrap_null_placement(null_placement)))
+
+
+class SortOptions(_SortOptions):
+ def __init__(self, sort_keys, *, null_placement="at_end"):
+ self._set_options(sort_keys, null_placement)
+
+
+cdef class _SelectKOptions(FunctionOptions):
+ def _set_options(self, k, sort_keys):
+ cdef vector[CSortKey] c_sort_keys
+ for name, order in sort_keys:
+ c_sort_keys.push_back(
+ CSortKey(tobytes(name), unwrap_sort_order(order))
+ )
+ self.wrapped.reset(new CSelectKOptions(k, c_sort_keys))
+
+
+class SelectKOptions(_SelectKOptions):
+ def __init__(self, k, sort_keys):
+ self._set_options(k, sort_keys)
+
+
+cdef class _QuantileOptions(FunctionOptions):
+ _interp_map = {
+ "linear": CQuantileInterp_LINEAR,
+ "lower": CQuantileInterp_LOWER,
+ "higher": CQuantileInterp_HIGHER,
+ "nearest": CQuantileInterp_NEAREST,
+ "midpoint": CQuantileInterp_MIDPOINT,
+ }
+
+ def _set_options(self, quantiles, interp, skip_nulls, min_count):
+ try:
+ self.wrapped.reset(
+ new CQuantileOptions(quantiles, self._interp_map[interp],
+ skip_nulls, min_count)
+ )
+ except KeyError:
+ _raise_invalid_function_option(interp, "quantile interpolation")
+
+
+class QuantileOptions(_QuantileOptions):
+ def __init__(self, q=0.5, *, interpolation="linear", skip_nulls=True,
+ min_count=0):
+ if not isinstance(q, (list, tuple, np.ndarray)):
+ q = [q]
+ self._set_options(q, interpolation, skip_nulls, min_count)
+
+
+cdef class _TDigestOptions(FunctionOptions):
+ def _set_options(self, quantiles, delta, buffer_size, skip_nulls,
+ min_count):
+ self.wrapped.reset(
+ new CTDigestOptions(quantiles, delta, buffer_size, skip_nulls,
+ min_count)
+ )
+
+
+class TDigestOptions(_TDigestOptions):
+ def __init__(self, q=0.5, *, delta=100, buffer_size=500, skip_nulls=True,
+ min_count=0):
+ if not isinstance(q, (list, tuple, np.ndarray)):
+ q = [q]
+ self._set_options(q, delta, buffer_size, skip_nulls, min_count)
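+# A minimal usage sketch for the option classes above (the array values are
+# arbitrary illustrations, and it assumes the pyarrow.compute wrappers
+# re-export these classes and accept them through their `options` argument):
+#
+#   import pyarrow as pa
+#   import pyarrow.compute as pc
+#
+#   arr = pa.array([1.0, 2.0, 2.0, None])
+#   pc.count(arr, options=pc.CountOptions(mode="only_null"))       # -> 1
+#   pc.quantile(arr, options=pc.QuantileOptions(q=[0.25, 0.75]))
+#   pc.filter(arr, pa.array([True, None, False, True]),
+#             options=pc.FilterOptions(null_selection_behavior="emit_null"))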
diff --git a/src/arrow/python/pyarrow/_csv.pxd b/src/arrow/python/pyarrow/_csv.pxd
new file mode 100644
index 000000000..b2fe7d639
--- /dev/null
+++ b/src/arrow/python/pyarrow/_csv.pxd
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+from pyarrow.includes.libarrow cimport *
+from pyarrow.lib cimport _Weakrefable
+
+
+cdef class ConvertOptions(_Weakrefable):
+ cdef:
+ unique_ptr[CCSVConvertOptions] options
+
+ @staticmethod
+ cdef ConvertOptions wrap(CCSVConvertOptions options)
+
+
+cdef class ParseOptions(_Weakrefable):
+ cdef:
+ unique_ptr[CCSVParseOptions] options
+
+ @staticmethod
+ cdef ParseOptions wrap(CCSVParseOptions options)
+
+
+cdef class ReadOptions(_Weakrefable):
+ cdef:
+ unique_ptr[CCSVReadOptions] options
+ public object encoding
+
+ @staticmethod
+ cdef ReadOptions wrap(CCSVReadOptions options)
+
+
+cdef class WriteOptions(_Weakrefable):
+ cdef:
+ unique_ptr[CCSVWriteOptions] options
+
+ @staticmethod
+ cdef WriteOptions wrap(CCSVWriteOptions options)
diff --git a/src/arrow/python/pyarrow/_csv.pyx b/src/arrow/python/pyarrow/_csv.pyx
new file mode 100644
index 000000000..19ade4324
--- /dev/null
+++ b/src/arrow/python/pyarrow/_csv.pyx
@@ -0,0 +1,1077 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: profile=False
+# distutils: language = c++
+# cython: language_level = 3
+
+from cython.operator cimport dereference as deref
+
+import codecs
+from collections.abc import Mapping
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema,
+ RecordBatchReader, ensure_type,
+ maybe_unbox_memory_pool, get_input_stream,
+ get_writer, native_transcoding_input_stream,
+ pyarrow_unwrap_batch, pyarrow_unwrap_schema,
+ pyarrow_unwrap_table, pyarrow_wrap_schema,
+ pyarrow_wrap_table, pyarrow_wrap_data_type,
+ pyarrow_unwrap_data_type, Table, RecordBatch,
+ StopToken, _CRecordBatchWriter)
+from pyarrow.lib import frombytes, tobytes, SignalStopHandler
+from pyarrow.util import _stringify_path
+
+
+cdef unsigned char _single_char(s) except 0:
+ val = ord(s)
+ if val == 0 or val > 127:
+ raise ValueError("Expecting an ASCII character")
+ return <unsigned char> val
+
+
+cdef class ReadOptions(_Weakrefable):
+ """
+ Options for reading CSV files.
+
+ Parameters
+ ----------
+ use_threads : bool, optional (default True)
+ Whether to use multiple threads to accelerate reading
+ block_size : int, optional
+ How many bytes to process at a time from the input stream.
+ This will determine multi-threading granularity as well as
+ the size of individual record batches or table chunks.
+ Minimum valid value for block size is 1.
+ skip_rows : int, optional (default 0)
+ The number of rows to skip before the column names (if any)
+ and the CSV data.
+ skip_rows_after_names : int, optional (default 0)
+ The number of rows to skip after the column names.
+ This number can be larger than the number of rows in one
+ block, and empty rows are counted.
+ The order of application is as follows:
+ - `skip_rows` is applied (if non-zero);
+ - column names are read (unless `column_names` is set);
+ - `skip_rows_after_names` is applied (if non-zero).
+ column_names : list, optional
+ The column names of the target table. If empty, fall back on
+ `autogenerate_column_names`.
+ autogenerate_column_names : bool, optional (default False)
+ Whether to autogenerate column names if `column_names` is empty.
+ If true, column names will be of the form "f0", "f1"...
+ If false, column names will be read from the first CSV row
+ after `skip_rows`.
+ encoding : str, optional (default 'utf8')
+ The character encoding of the CSV data. Columns that cannot
+ be decoded using this encoding can still be read as Binary.
+ """
+
+ # Avoid mistakenly creating attributes
+ __slots__ = ()
+
+ # __init__() is not called when unpickling, initialize storage here
+ def __cinit__(self, *args, **kwargs):
+ self.options.reset(new CCSVReadOptions(CCSVReadOptions.Defaults()))
+
+ def __init__(self, *, use_threads=None, block_size=None, skip_rows=None,
+ column_names=None, autogenerate_column_names=None,
+ encoding='utf8', skip_rows_after_names=None):
+ if use_threads is not None:
+ self.use_threads = use_threads
+ if block_size is not None:
+ self.block_size = block_size
+ if skip_rows is not None:
+ self.skip_rows = skip_rows
+ if column_names is not None:
+ self.column_names = column_names
+ if autogenerate_column_names is not None:
+ self.autogenerate_column_names = autogenerate_column_names
+ # Python-specific option
+ self.encoding = encoding
+ if skip_rows_after_names is not None:
+ self.skip_rows_after_names = skip_rows_after_names
+
+ @property
+ def use_threads(self):
+ """
+ Whether to use multiple threads to accelerate reading.
+ """
+ return deref(self.options).use_threads
+
+ @use_threads.setter
+ def use_threads(self, value):
+ deref(self.options).use_threads = value
+
+ @property
+ def block_size(self):
+ """
+ How many bytes to process at a time from the input stream.
+ This will determine multi-threading granularity as well as
+ the size of individual record batches or table chunks.
+ """
+ return deref(self.options).block_size
+
+ @block_size.setter
+ def block_size(self, value):
+ deref(self.options).block_size = value
+
+ @property
+ def skip_rows(self):
+ """
+ The number of rows to skip before the column names (if any)
+ and the CSV data.
+ See `skip_rows_after_names` for a description of how the two interact.
+ """
+ return deref(self.options).skip_rows
+
+ @skip_rows.setter
+ def skip_rows(self, value):
+ deref(self.options).skip_rows = value
+
+ @property
+ def column_names(self):
+ """
+ The column names of the target table. If empty, fall back on
+ `autogenerate_column_names`.
+ """
+ return [frombytes(s) for s in deref(self.options).column_names]
+
+ @column_names.setter
+ def column_names(self, value):
+ deref(self.options).column_names.clear()
+ for item in value:
+ deref(self.options).column_names.push_back(tobytes(item))
+
+ @property
+ def autogenerate_column_names(self):
+ """
+ Whether to autogenerate column names if `column_names` is empty.
+ If true, column names will be of the form "f0", "f1"...
+ If false, column names will be read from the first CSV row
+ after `skip_rows`.
+ """
+ return deref(self.options).autogenerate_column_names
+
+ @autogenerate_column_names.setter
+ def autogenerate_column_names(self, value):
+ deref(self.options).autogenerate_column_names = value
+
+ @property
+ def skip_rows_after_names(self):
+ """
+ The number of rows to skip after the column names.
+ This number can be larger than the number of rows in one
+ block, and empty rows are counted.
+ The order of application is as follows:
+ - `skip_rows` is applied (if non-zero);
+ - column names are read (unless `column_names` is set);
+ - `skip_rows_after_names` is applied (if non-zero).
+ """
+ return deref(self.options).skip_rows_after_names
+
+ @skip_rows_after_names.setter
+ def skip_rows_after_names(self, value):
+ deref(self.options).skip_rows_after_names = value
+
+ def validate(self):
+ check_status(deref(self.options).Validate())
+
+ def equals(self, ReadOptions other):
+ return (
+ self.use_threads == other.use_threads and
+ self.block_size == other.block_size and
+ self.skip_rows == other.skip_rows and
+ self.column_names == other.column_names and
+ self.autogenerate_column_names ==
+ other.autogenerate_column_names and
+ self.encoding == other.encoding and
+ self.skip_rows_after_names == other.skip_rows_after_names
+ )
+
+ @staticmethod
+ cdef ReadOptions wrap(CCSVReadOptions options):
+ out = ReadOptions()
+ out.options.reset(new CCSVReadOptions(move(options)))
+ out.encoding = 'utf8' # No way to know this
+ return out
+
+ def __getstate__(self):
+ return (self.use_threads, self.block_size, self.skip_rows,
+ self.column_names, self.autogenerate_column_names,
+ self.encoding, self.skip_rows_after_names)
+
+ def __setstate__(self, state):
+ (self.use_threads, self.block_size, self.skip_rows,
+ self.column_names, self.autogenerate_column_names,
+ self.encoding, self.skip_rows_after_names) = state
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return False
+
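+# For illustration, a header-less file with one leading comment line might be
+# configured as follows (the column names and encoding are made-up examples):
+#
+#   from pyarrow import csv
+#
+#   opts = csv.ReadOptions(skip_rows=1, column_names=["id", "value"],
+#                          encoding="latin-1")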
+
+cdef class ParseOptions(_Weakrefable):
+ """
+ Options for parsing CSV files.
+
+ Parameters
+ ----------
+ delimiter : 1-character string, optional (default ',')
+ The character delimiting individual cells in the CSV data.
+ quote_char : 1-character string or False, optional (default '"')
+ The character used optionally for quoting CSV values
+ (False if quoting is not allowed).
+ double_quote : bool, optional (default True)
+ Whether two quotes in a quoted CSV value denote a single quote
+ in the data.
+ escape_char : 1-character string or False, optional (default False)
+ The character used optionally for escaping special characters
+ (False if escaping is not allowed).
+ newlines_in_values : bool, optional (default False)
+ Whether newline characters are allowed in CSV values.
+ Setting this to True reduces the performance of multi-threaded
+ CSV reading.
+ ignore_empty_lines : bool, optional (default True)
+ Whether empty lines are ignored in CSV input.
+ If False, an empty line is interpreted as containing a single empty
+ value (assuming a one-column CSV file).
+ """
+ __slots__ = ()
+
+ def __cinit__(self, *args, **kwargs):
+ self.options.reset(new CCSVParseOptions(CCSVParseOptions.Defaults()))
+
+ def __init__(self, *, delimiter=None, quote_char=None, double_quote=None,
+ escape_char=None, newlines_in_values=None,
+ ignore_empty_lines=None):
+ if delimiter is not None:
+ self.delimiter = delimiter
+ if quote_char is not None:
+ self.quote_char = quote_char
+ if double_quote is not None:
+ self.double_quote = double_quote
+ if escape_char is not None:
+ self.escape_char = escape_char
+ if newlines_in_values is not None:
+ self.newlines_in_values = newlines_in_values
+ if ignore_empty_lines is not None:
+ self.ignore_empty_lines = ignore_empty_lines
+
+ @property
+ def delimiter(self):
+ """
+ The character delimiting individual cells in the CSV data.
+ """
+ return chr(deref(self.options).delimiter)
+
+ @delimiter.setter
+ def delimiter(self, value):
+ deref(self.options).delimiter = _single_char(value)
+
+ @property
+ def quote_char(self):
+ """
+ The character used optionally for quoting CSV values
+ (False if quoting is not allowed).
+ """
+ if deref(self.options).quoting:
+ return chr(deref(self.options).quote_char)
+ else:
+ return False
+
+ @quote_char.setter
+ def quote_char(self, value):
+ if value is False:
+ deref(self.options).quoting = False
+ else:
+ deref(self.options).quote_char = _single_char(value)
+ deref(self.options).quoting = True
+
+ @property
+ def double_quote(self):
+ """
+ Whether two quotes in a quoted CSV value denote a single quote
+ in the data.
+ """
+ return deref(self.options).double_quote
+
+ @double_quote.setter
+ def double_quote(self, value):
+ deref(self.options).double_quote = value
+
+ @property
+ def escape_char(self):
+ """
+ The character used optionally for escaping special characters
+ (False if escaping is not allowed).
+ """
+ if deref(self.options).escaping:
+ return chr(deref(self.options).escape_char)
+ else:
+ return False
+
+ @escape_char.setter
+ def escape_char(self, value):
+ if value is False:
+ deref(self.options).escaping = False
+ else:
+ deref(self.options).escape_char = _single_char(value)
+ deref(self.options).escaping = True
+
+ @property
+ def newlines_in_values(self):
+ """
+ Whether newline characters are allowed in CSV values.
+ Setting this to True reduces the performance of multi-threaded
+ CSV reading.
+ """
+ return deref(self.options).newlines_in_values
+
+ @newlines_in_values.setter
+ def newlines_in_values(self, value):
+ deref(self.options).newlines_in_values = value
+
+ @property
+ def ignore_empty_lines(self):
+ """
+ Whether empty lines are ignored in CSV input.
+ If False, an empty line is interpreted as containing a single empty
+ value (assuming a one-column CSV file).
+ """
+ return deref(self.options).ignore_empty_lines
+
+ @ignore_empty_lines.setter
+ def ignore_empty_lines(self, value):
+ deref(self.options).ignore_empty_lines = value
+
+ def validate(self):
+ check_status(deref(self.options).Validate())
+
+ def equals(self, ParseOptions other):
+ return (
+ self.delimiter == other.delimiter and
+ self.quote_char == other.quote_char and
+ self.double_quote == other.double_quote and
+ self.escape_char == other.escape_char and
+ self.newlines_in_values == other.newlines_in_values and
+ self.ignore_empty_lines == other.ignore_empty_lines
+ )
+
+ @staticmethod
+ cdef ParseOptions wrap(CCSVParseOptions options):
+ out = ParseOptions()
+ out.options.reset(new CCSVParseOptions(move(options)))
+ return out
+
+ def __getstate__(self):
+ return (self.delimiter, self.quote_char, self.double_quote,
+ self.escape_char, self.newlines_in_values,
+ self.ignore_empty_lines)
+
+ def __setstate__(self, state):
+ (self.delimiter, self.quote_char, self.double_quote,
+ self.escape_char, self.newlines_in_values,
+ self.ignore_empty_lines) = state
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return False
+
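+# For example, tab-separated data that uses backslash escapes instead of
+# quoting could be parsed with a configuration like this (values are
+# illustrative only):
+#
+#   from pyarrow import csv
+#
+#   opts = csv.ParseOptions(delimiter="\t", quote_char=False, escape_char="\\")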
+
+cdef class _ISO8601(_Weakrefable):
+ """
+ A special object indicating ISO-8601 parsing.
+ """
+ __slots__ = ()
+
+ def __str__(self):
+ return 'ISO8601'
+
+ def __eq__(self, other):
+ return isinstance(other, _ISO8601)
+
+
+ISO8601 = _ISO8601()
+
+
+cdef class ConvertOptions(_Weakrefable):
+ """
+ Options for converting CSV data.
+
+ Parameters
+ ----------
+ check_utf8 : bool, optional (default True)
+ Whether to check UTF8 validity of string columns.
+ column_types : pa.Schema or dict, optional
+ Explicitly map column names to column types. Passing this argument
+ disables type inference on the defined columns.
+ null_values : list, optional
+ A sequence of strings that denote nulls in the data
+ (defaults are appropriate in most cases). Note that by default,
+ string columns are not checked for null values. To enable
+ null checking for those, specify ``strings_can_be_null=True``.
+ true_values : list, optional
+ A sequence of strings that denote true booleans in the data
+ (defaults are appropriate in most cases).
+ false_values : list, optional
+ A sequence of strings that denote false booleans in the data
+ (defaults are appropriate in most cases).
+ decimal_point : 1-character string, optional (default '.')
+ The character used as decimal point in floating-point and decimal
+ data.
+ timestamp_parsers : list, optional
+ A sequence of strptime()-compatible format strings, tried in order
+ when attempting to infer or convert timestamp values (the special
+ value ISO8601() can also be given). By default, a fast built-in
+ ISO-8601 parser is used.
+ strings_can_be_null : bool, optional (default False)
+ Whether string / binary columns can have null values.
+ If true, then strings in null_values are considered null for
+ string columns.
+ If false, then all strings are valid string values.
+ quoted_strings_can_be_null : bool, optional (default True)
+ Whether quoted values can be null.
+ If true, then strings in "null_values" are also considered null
+ when they appear quoted in the CSV file. Otherwise, quoted values
+ are never considered null.
+ auto_dict_encode : bool, optional (default False)
+ Whether to try to automatically dict-encode string / binary data.
+ If true, then when type inference detects a string or binary column,
+ it is dict-encoded up to `auto_dict_max_cardinality` distinct values
+ (per chunk), after which it switches to regular encoding.
+ This setting is ignored for non-inferred columns (those in
+ `column_types`).
+ auto_dict_max_cardinality : int, optional
+ The maximum dictionary cardinality for `auto_dict_encode`.
+ This value is per chunk.
+ include_columns : list, optional
+ The names of columns to include in the Table.
+ If empty, the Table will include all columns from the CSV file.
+ If not empty, only these columns will be included, in this order.
+ include_missing_columns : bool, optional (default False)
+ If false, columns in `include_columns` but not in the CSV file will
+ error out.
+ If true, columns in `include_columns` but not in the CSV file will
+ produce a column of nulls (whose type is selected using
+ `column_types`, or null by default).
+ This option is ignored if `include_columns` is empty.
+ """
+ # Avoid mistakenly creating attributes
+ __slots__ = ()
+
+ def __cinit__(self, *args, **kwargs):
+ self.options.reset(
+ new CCSVConvertOptions(CCSVConvertOptions.Defaults()))
+
+ def __init__(self, *, check_utf8=None, column_types=None, null_values=None,
+ true_values=None, false_values=None, decimal_point=None,
+ strings_can_be_null=None, quoted_strings_can_be_null=None,
+ include_columns=None, include_missing_columns=None,
+ auto_dict_encode=None, auto_dict_max_cardinality=None,
+ timestamp_parsers=None):
+ if check_utf8 is not None:
+ self.check_utf8 = check_utf8
+ if column_types is not None:
+ self.column_types = column_types
+ if null_values is not None:
+ self.null_values = null_values
+ if true_values is not None:
+ self.true_values = true_values
+ if false_values is not None:
+ self.false_values = false_values
+ if decimal_point is not None:
+ self.decimal_point = decimal_point
+ if strings_can_be_null is not None:
+ self.strings_can_be_null = strings_can_be_null
+ if quoted_strings_can_be_null is not None:
+ self.quoted_strings_can_be_null = quoted_strings_can_be_null
+ if include_columns is not None:
+ self.include_columns = include_columns
+ if include_missing_columns is not None:
+ self.include_missing_columns = include_missing_columns
+ if auto_dict_encode is not None:
+ self.auto_dict_encode = auto_dict_encode
+ if auto_dict_max_cardinality is not None:
+ self.auto_dict_max_cardinality = auto_dict_max_cardinality
+ if timestamp_parsers is not None:
+ self.timestamp_parsers = timestamp_parsers
+
+ @property
+ def check_utf8(self):
+ """
+ Whether to check UTF8 validity of string columns.
+ """
+ return deref(self.options).check_utf8
+
+ @check_utf8.setter
+ def check_utf8(self, value):
+ deref(self.options).check_utf8 = value
+
+ @property
+ def strings_can_be_null(self):
+ """
+ Whether string / binary columns can have null values.
+ """
+ return deref(self.options).strings_can_be_null
+
+ @strings_can_be_null.setter
+ def strings_can_be_null(self, value):
+ deref(self.options).strings_can_be_null = value
+
+ @property
+ def quoted_strings_can_be_null(self):
+ """
+ Whether quoted values can be null.
+ """
+ return deref(self.options).quoted_strings_can_be_null
+
+ @quoted_strings_can_be_null.setter
+ def quoted_strings_can_be_null(self, value):
+ deref(self.options).quoted_strings_can_be_null = value
+
+ @property
+ def column_types(self):
+ """
+ Explicitly map column names to column types.
+ """
+ d = {frombytes(item.first): pyarrow_wrap_data_type(item.second)
+ for item in deref(self.options).column_types}
+ return d
+
+ @column_types.setter
+ def column_types(self, value):
+ cdef:
+ shared_ptr[CDataType] typ
+
+ if isinstance(value, Mapping):
+ value = value.items()
+
+ deref(self.options).column_types.clear()
+ for item in value:
+ if isinstance(item, Field):
+ k = item.name
+ v = item.type
+ else:
+ k, v = item
+ typ = pyarrow_unwrap_data_type(ensure_type(v))
+ assert typ != NULL
+ deref(self.options).column_types[tobytes(k)] = typ
+
+ @property
+ def null_values(self):
+ """
+ A sequence of strings that denote nulls in the data.
+ """
+ return [frombytes(x) for x in deref(self.options).null_values]
+
+ @null_values.setter
+ def null_values(self, value):
+ deref(self.options).null_values = [tobytes(x) for x in value]
+
+ @property
+ def true_values(self):
+ """
+ A sequence of strings that denote true booleans in the data.
+ """
+ return [frombytes(x) for x in deref(self.options).true_values]
+
+ @true_values.setter
+ def true_values(self, value):
+ deref(self.options).true_values = [tobytes(x) for x in value]
+
+ @property
+ def false_values(self):
+ """
+ A sequence of strings that denote false booleans in the data.
+ """
+ return [frombytes(x) for x in deref(self.options).false_values]
+
+ @false_values.setter
+ def false_values(self, value):
+ deref(self.options).false_values = [tobytes(x) for x in value]
+
+ @property
+ def decimal_point(self):
+ """
+ The character used as decimal point in floating-point and decimal
+ data.
+ """
+ return chr(deref(self.options).decimal_point)
+
+ @decimal_point.setter
+ def decimal_point(self, value):
+ deref(self.options).decimal_point = _single_char(value)
+
+ @property
+ def auto_dict_encode(self):
+ """
+ Whether to try to automatically dict-encode string / binary data.
+ """
+ return deref(self.options).auto_dict_encode
+
+ @auto_dict_encode.setter
+ def auto_dict_encode(self, value):
+ deref(self.options).auto_dict_encode = value
+
+ @property
+ def auto_dict_max_cardinality(self):
+ """
+ The maximum dictionary cardinality for `auto_dict_encode`.
+
+ This value is per chunk.
+ """
+ return deref(self.options).auto_dict_max_cardinality
+
+ @auto_dict_max_cardinality.setter
+ def auto_dict_max_cardinality(self, value):
+ deref(self.options).auto_dict_max_cardinality = value
+
+ @property
+ def include_columns(self):
+ """
+ The names of columns to include in the Table.
+
+ If empty, the Table will include all columns from the CSV file.
+ If not empty, only these columns will be included, in this order.
+ """
+ return [frombytes(s) for s in deref(self.options).include_columns]
+
+ @include_columns.setter
+ def include_columns(self, value):
+ deref(self.options).include_columns.clear()
+ for item in value:
+ deref(self.options).include_columns.push_back(tobytes(item))
+
+ @property
+ def include_missing_columns(self):
+ """
+ If false, columns in `include_columns` but not in the CSV file will
+ error out.
+ If true, columns in `include_columns` but not in the CSV file will
+ produce a null column (whose type is selected using `column_types`,
+ or null by default).
+ This option is ignored if `include_columns` is empty.
+ """
+ return deref(self.options).include_missing_columns
+
+ @include_missing_columns.setter
+ def include_missing_columns(self, value):
+ deref(self.options).include_missing_columns = value
+
+ @property
+ def timestamp_parsers(self):
+ """
+ A sequence of strptime()-compatible format strings, tried in order
+ when attempting to infer or convert timestamp values (the special
+ value ISO8601() can also be given). By default, a fast built-in
+ ISO-8601 parser is used.
+ """
+ cdef:
+ shared_ptr[CTimestampParser] c_parser
+ c_string kind
+
+ parsers = []
+ for c_parser in deref(self.options).timestamp_parsers:
+ kind = deref(c_parser).kind()
+ if kind == b'strptime':
+ parsers.append(frombytes(deref(c_parser).format()))
+ else:
+ assert kind == b'iso8601'
+ parsers.append(ISO8601)
+
+ return parsers
+
+ @timestamp_parsers.setter
+ def timestamp_parsers(self, value):
+ cdef:
+ vector[shared_ptr[CTimestampParser]] c_parsers
+
+ for v in value:
+ if isinstance(v, str):
+ c_parsers.push_back(CTimestampParser.MakeStrptime(tobytes(v)))
+ elif v == ISO8601:
+ c_parsers.push_back(CTimestampParser.MakeISO8601())
+ else:
+ raise TypeError("Expected list of str or ISO8601 objects")
+
+ deref(self.options).timestamp_parsers = move(c_parsers)
+
+ @staticmethod
+ cdef ConvertOptions wrap(CCSVConvertOptions options):
+ out = ConvertOptions()
+ out.options.reset(new CCSVConvertOptions(move(options)))
+ return out
+
+ def validate(self):
+ check_status(deref(self.options).Validate())
+
+ def equals(self, ConvertOptions other):
+ return (
+ self.check_utf8 == other.check_utf8 and
+ self.column_types == other.column_types and
+ self.null_values == other.null_values and
+ self.true_values == other.true_values and
+ self.false_values == other.false_values and
+ self.decimal_point == other.decimal_point and
+ self.timestamp_parsers == other.timestamp_parsers and
+ self.strings_can_be_null == other.strings_can_be_null and
+ self.quoted_strings_can_be_null ==
+ other.quoted_strings_can_be_null and
+ self.auto_dict_encode == other.auto_dict_encode and
+ self.auto_dict_max_cardinality ==
+ other.auto_dict_max_cardinality and
+ self.include_columns == other.include_columns and
+ self.include_missing_columns == other.include_missing_columns
+ )
+
+ def __getstate__(self):
+ return (self.check_utf8, self.column_types, self.null_values,
+ self.true_values, self.false_values, self.decimal_point,
+ self.timestamp_parsers, self.strings_can_be_null,
+ self.quoted_strings_can_be_null, self.auto_dict_encode,
+ self.auto_dict_max_cardinality, self.include_columns,
+ self.include_missing_columns)
+
+ def __setstate__(self, state):
+ (self.check_utf8, self.column_types, self.null_values,
+ self.true_values, self.false_values, self.decimal_point,
+ self.timestamp_parsers, self.strings_can_be_null,
+ self.quoted_strings_can_be_null, self.auto_dict_encode,
+ self.auto_dict_max_cardinality, self.include_columns,
+ self.include_missing_columns) = state
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return False
+
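+# A minimal sketch of a typical conversion setup (the column names, types and
+# null markers below are arbitrary placeholders):
+#
+#   import pyarrow as pa
+#   from pyarrow import csv
+#
+#   opts = csv.ConvertOptions(
+#       column_types={"id": pa.int64(), "when": pa.timestamp("s")},
+#       null_values=["", "NA"], strings_can_be_null=True,
+#       timestamp_parsers=[csv.ISO8601, "%d/%m/%Y"])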
+
+cdef _get_reader(input_file, ReadOptions read_options,
+ shared_ptr[CInputStream]* out):
+ use_memory_map = False
+ get_input_stream(input_file, use_memory_map, out)
+ if read_options is not None:
+ out[0] = native_transcoding_input_stream(out[0],
+ read_options.encoding,
+ 'utf8')
+
+
+cdef _get_read_options(ReadOptions read_options, CCSVReadOptions* out):
+ if read_options is None:
+ out[0] = CCSVReadOptions.Defaults()
+ else:
+ out[0] = deref(read_options.options)
+
+
+cdef _get_parse_options(ParseOptions parse_options, CCSVParseOptions* out):
+ if parse_options is None:
+ out[0] = CCSVParseOptions.Defaults()
+ else:
+ out[0] = deref(parse_options.options)
+
+
+cdef _get_convert_options(ConvertOptions convert_options,
+ CCSVConvertOptions* out):
+ if convert_options is None:
+ out[0] = CCSVConvertOptions.Defaults()
+ else:
+ out[0] = deref(convert_options.options)
+
+
+cdef class CSVStreamingReader(RecordBatchReader):
+ """An object that reads record batches incrementally from a CSV file.
+
+ Should not be instantiated directly by user code.
+ """
+ cdef readonly:
+ Schema schema
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, "
+ "use pyarrow.csv.open_csv() instead."
+ .format(self.__class__.__name__))
+
+ # Note about cancellation: we cannot create a SignalStopHandler
+ # by default here, as several CSVStreamingReader instances may be
+ # created (including by the same thread). Handling cancellation
+ # would require having the user pass the SignalStopHandler.
+ # (in addition to solving ARROW-11853)
+
+ cdef _open(self, shared_ptr[CInputStream] stream,
+ CCSVReadOptions c_read_options,
+ CCSVParseOptions c_parse_options,
+ CCSVConvertOptions c_convert_options,
+ MemoryPool memory_pool):
+ cdef:
+ shared_ptr[CSchema] c_schema
+ CIOContext io_context
+
+ io_context = CIOContext(maybe_unbox_memory_pool(memory_pool))
+
+ with nogil:
+ self.reader = <shared_ptr[CRecordBatchReader]> GetResultValue(
+ CCSVStreamingReader.Make(
+ io_context, stream,
+ move(c_read_options), move(c_parse_options),
+ move(c_convert_options)))
+ c_schema = self.reader.get().schema()
+
+ self.schema = pyarrow_wrap_schema(c_schema)
+
+
+def read_csv(input_file, read_options=None, parse_options=None,
+ convert_options=None, MemoryPool memory_pool=None):
+ """
+ Read a Table from a stream of CSV data.
+
+ Parameters
+ ----------
+ input_file : string, path or file-like object
+ The location of CSV data. If a string or path, and if it ends
+ with a recognized compressed file extension (e.g. ".gz" or ".bz2"),
+ the data is automatically decompressed when reading.
+ read_options : pyarrow.csv.ReadOptions, optional
+ Options for the CSV reader (see pyarrow.csv.ReadOptions constructor
+ for defaults)
+ parse_options : pyarrow.csv.ParseOptions, optional
+ Options for the CSV parser
+ (see pyarrow.csv.ParseOptions constructor for defaults)
+ convert_options : pyarrow.csv.ConvertOptions, optional
+ Options for converting CSV data
+ (see pyarrow.csv.ConvertOptions constructor for defaults)
+ memory_pool : MemoryPool, optional
+ Pool to allocate Table memory from
+
+ Returns
+ -------
+ :class:`pyarrow.Table`
+ Contents of the CSV file as an in-memory table.
+ """
+ cdef:
+ shared_ptr[CInputStream] stream
+ CCSVReadOptions c_read_options
+ CCSVParseOptions c_parse_options
+ CCSVConvertOptions c_convert_options
+ CIOContext io_context
+ shared_ptr[CCSVReader] reader
+ shared_ptr[CTable] table
+
+ _get_reader(input_file, read_options, &stream)
+ _get_read_options(read_options, &c_read_options)
+ _get_parse_options(parse_options, &c_parse_options)
+ _get_convert_options(convert_options, &c_convert_options)
+
+ with SignalStopHandler() as stop_handler:
+ io_context = CIOContext(
+ maybe_unbox_memory_pool(memory_pool),
+ (<StopToken> stop_handler.stop_token).stop_token)
+ reader = GetResultValue(CCSVReader.Make(
+ io_context, stream,
+ c_read_options, c_parse_options, c_convert_options))
+
+ with nogil:
+ table = GetResultValue(reader.get().Read())
+
+ return pyarrow_wrap_table(table)
+
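+# A minimal usage sketch ("data.csv" is a placeholder path):
+#
+#   from pyarrow import csv
+#
+#   table = csv.read_csv("data.csv",
+#                        parse_options=csv.ParseOptions(delimiter=";"),
+#                        convert_options=csv.ConvertOptions(
+#                            strings_can_be_null=True))
+#   df = table.to_pandas()   # optional, requires pandas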
+
+def open_csv(input_file, read_options=None, parse_options=None,
+ convert_options=None, MemoryPool memory_pool=None):
+ """
+ Open a streaming reader of CSV data.
+
+ Reading using this function is always single-threaded.
+
+ Parameters
+ ----------
+ input_file : string, path or file-like object
+ The location of CSV data. If a string or path, and if it ends
+ with a recognized compressed file extension (e.g. ".gz" or ".bz2"),
+ the data is automatically decompressed when reading.
+ read_options : pyarrow.csv.ReadOptions, optional
+ Options for the CSV reader (see pyarrow.csv.ReadOptions constructor
+ for defaults)
+ parse_options : pyarrow.csv.ParseOptions, optional
+ Options for the CSV parser
+ (see pyarrow.csv.ParseOptions constructor for defaults)
+ convert_options : pyarrow.csv.ConvertOptions, optional
+ Options for converting CSV data
+ (see pyarrow.csv.ConvertOptions constructor for defaults)
+ memory_pool : MemoryPool, optional
+ Pool to allocate Table memory from
+
+ Returns
+ -------
+ :class:`pyarrow.csv.CSVStreamingReader`
+ """
+ cdef:
+ shared_ptr[CInputStream] stream
+ CCSVReadOptions c_read_options
+ CCSVParseOptions c_parse_options
+ CCSVConvertOptions c_convert_options
+ CSVStreamingReader reader
+
+ _get_reader(input_file, read_options, &stream)
+ _get_read_options(read_options, &c_read_options)
+ _get_parse_options(parse_options, &c_parse_options)
+ _get_convert_options(convert_options, &c_convert_options)
+
+ reader = CSVStreamingReader.__new__(CSVStreamingReader)
+ reader._open(stream, move(c_read_options), move(c_parse_options),
+ move(c_convert_options), memory_pool)
+ return reader
+
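+# Incremental reading sketch ("large.csv" is a placeholder path); iterating
+# the returned reader yields one pyarrow.RecordBatch per CSV block:
+#
+#   from pyarrow import csv
+#
+#   reader = csv.open_csv("large.csv")
+#   for batch in reader:
+#       ...  # process each RecordBatch without materializing the whole table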
+
+cdef class WriteOptions(_Weakrefable):
+ """
+ Options for writing CSV files.
+
+ Parameters
+ ----------
+ include_header : bool, optional (default True)
+ Whether to write an initial header line with column names
+ batch_size : int, optional (default 1024)
+ How many rows to process together when converting and writing
+ CSV data
+ """
+
+ # Avoid mistakenly creating attributes
+ __slots__ = ()
+
+ def __init__(self, *, include_header=None, batch_size=None):
+ self.options.reset(new CCSVWriteOptions(CCSVWriteOptions.Defaults()))
+ if include_header is not None:
+ self.include_header = include_header
+ if batch_size is not None:
+ self.batch_size = batch_size
+
+ @property
+ def include_header(self):
+ """
+ Whether to write an initial header line with column names.
+ """
+ return deref(self.options).include_header
+
+ @include_header.setter
+ def include_header(self, value):
+ deref(self.options).include_header = value
+
+ @property
+ def batch_size(self):
+ """
+ How many rows to process together when converting and writing
+ CSV data.
+ """
+ return deref(self.options).batch_size
+
+ @batch_size.setter
+ def batch_size(self, value):
+ deref(self.options).batch_size = value
+
+ @staticmethod
+ cdef WriteOptions wrap(CCSVWriteOptions options):
+ out = WriteOptions()
+ out.options.reset(new CCSVWriteOptions(move(options)))
+ return out
+
+ def validate(self):
+ check_status(self.options.get().Validate())
+
+
+cdef _get_write_options(WriteOptions write_options, CCSVWriteOptions* out):
+ if write_options is None:
+ out[0] = CCSVWriteOptions.Defaults()
+ else:
+ out[0] = deref(write_options.options)
+
+
+def write_csv(data, output_file, write_options=None,
+ MemoryPool memory_pool=None):
+ """
+ Write record batch or table to a CSV file.
+
+ Parameters
+ ----------
+ data : pyarrow.RecordBatch or pyarrow.Table
+ The data to write.
+ output_file : string, path, pyarrow.NativeFile, or file-like object
+ The location where to write the CSV data.
+ write_options : pyarrow.csv.WriteOptions
+ Options to configure writing the CSV data.
+ memory_pool : MemoryPool, optional
+ Pool for temporary allocations.
+ """
+ cdef:
+ shared_ptr[COutputStream] stream
+ CCSVWriteOptions c_write_options
+ CMemoryPool* c_memory_pool
+ CRecordBatch* batch
+ CTable* table
+ _get_write_options(write_options, &c_write_options)
+
+ get_writer(output_file, &stream)
+ c_memory_pool = maybe_unbox_memory_pool(memory_pool)
+ c_write_options.io_context = CIOContext(c_memory_pool)
+ if isinstance(data, RecordBatch):
+ batch = pyarrow_unwrap_batch(data).get()
+ with nogil:
+ check_status(WriteCSV(deref(batch), c_write_options, stream.get()))
+ elif isinstance(data, Table):
+ table = pyarrow_unwrap_table(data).get()
+ with nogil:
+ check_status(WriteCSV(deref(table), c_write_options, stream.get()))
+ else:
+ raise TypeError(f"Expected Table or RecordBatch, got '{type(data)}'")
+
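+# For example (the output path and data are placeholders):
+#
+#   import pyarrow as pa
+#   from pyarrow import csv
+#
+#   table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
+#   csv.write_csv(table, "out.csv",
+#                 write_options=csv.WriteOptions(include_header=True))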
+
+cdef class CSVWriter(_CRecordBatchWriter):
+ """
+ Writer to create a CSV file.
+
+ Parameters
+ ----------
+ sink : str, path, pyarrow.OutputStream or file-like object
+ The location where to write the CSV data.
+ schema : pyarrow.Schema
+ The schema of the data to be written.
+ write_options : pyarrow.csv.WriteOptions
+ Options to configure writing the CSV data.
+ memory_pool : MemoryPool, optional
+ Pool for temporary allocations.
+ """
+
+ def __init__(self, sink, Schema schema, *,
+ WriteOptions write_options=None, MemoryPool memory_pool=None):
+ cdef:
+ shared_ptr[COutputStream] c_stream
+ shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
+ CCSVWriteOptions c_write_options
+ CMemoryPool* c_memory_pool = maybe_unbox_memory_pool(memory_pool)
+ _get_write_options(write_options, &c_write_options)
+ c_write_options.io_context = CIOContext(c_memory_pool)
+ get_writer(sink, &c_stream)
+ with nogil:
+ self.writer = GetResultValue(MakeCSVWriter(
+ c_stream, c_schema, c_write_options))
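+# Incremental writing sketch, assuming CSVWriter is re-exported by pyarrow.csv
+# (the path, schema and data are illustrative):
+#
+#   import pyarrow as pa
+#   from pyarrow import csv
+#
+#   schema = pa.schema([("a", pa.int64())])
+#   with csv.CSVWriter("out.csv", schema) as writer:
+#       writer.write_batch(
+#           pa.record_batch([pa.array([1, 2, 3])], schema=schema))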
diff --git a/src/arrow/python/pyarrow/_cuda.pxd b/src/arrow/python/pyarrow/_cuda.pxd
new file mode 100644
index 000000000..6acb8826d
--- /dev/null
+++ b/src/arrow/python/pyarrow/_cuda.pxd
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+from pyarrow.lib cimport *
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_cuda cimport *
+
+
+cdef class Context(_Weakrefable):
+ cdef:
+ shared_ptr[CCudaContext] context
+ int device_number
+
+ cdef void init(self, const shared_ptr[CCudaContext]& ctx)
+
+
+cdef class IpcMemHandle(_Weakrefable):
+ cdef:
+ shared_ptr[CCudaIpcMemHandle] handle
+
+ cdef void init(self, shared_ptr[CCudaIpcMemHandle]& h)
+
+
+cdef class CudaBuffer(Buffer):
+ cdef:
+ shared_ptr[CCudaBuffer] cuda_buffer
+ object base
+
+ cdef void init_cuda(self,
+ const shared_ptr[CCudaBuffer]& buffer,
+ object base)
+
+
+cdef class HostBuffer(Buffer):
+ cdef:
+ shared_ptr[CCudaHostBuffer] host_buffer
+
+ cdef void init_host(self, const shared_ptr[CCudaHostBuffer]& buffer)
+
+
+cdef class BufferReader(NativeFile):
+ cdef:
+ CCudaBufferReader* reader
+ CudaBuffer buffer
+
+
+cdef class BufferWriter(NativeFile):
+ cdef:
+ CCudaBufferWriter* writer
+ CudaBuffer buffer
diff --git a/src/arrow/python/pyarrow/_cuda.pyx b/src/arrow/python/pyarrow/_cuda.pyx
new file mode 100644
index 000000000..1b66b9508
--- /dev/null
+++ b/src/arrow/python/pyarrow/_cuda.pyx
@@ -0,0 +1,1060 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from pyarrow.lib import tobytes
+from pyarrow.lib cimport *
+from pyarrow.includes.libarrow_cuda cimport *
+from pyarrow.lib import py_buffer, allocate_buffer, as_buffer, ArrowTypeError
+from pyarrow.util import get_contiguous_span
+cimport cpython as cp
+
+
+cdef class Context(_Weakrefable):
+ """
+ CUDA driver context.
+ """
+
+ def __init__(self, *args, **kwargs):
+ """
+ Create a CUDA driver context for a particular device.
+
+ If a CUDA context handle is passed, it is wrapped, otherwise
+ a default CUDA context for the given device is requested.
+
+ Parameters
+ ----------
+ device_number : int (default 0)
+ Specify the GPU device for which the CUDA driver context is
+ requested.
+ handle : int, optional
+ Specify CUDA handle for a shared context that has been created
+ by another library.
+ """
+ # This method is exposed because autodoc doesn't pick up __cinit__
+
+ def __cinit__(self, int device_number=0, uintptr_t handle=0):
+ cdef CCudaDeviceManager* manager
+ manager = GetResultValue(CCudaDeviceManager.Instance())
+ cdef int n = manager.num_devices()
+ if device_number >= n or device_number < 0:
+ self.context.reset()
+ raise ValueError('device_number argument must be '
+ 'non-negative and less than %s' % (n))
+ if handle == 0:
+ self.context = GetResultValue(manager.GetContext(device_number))
+ else:
+ self.context = GetResultValue(manager.GetSharedContext(
+ device_number, <void*>handle))
+ self.device_number = device_number
+
+ @staticmethod
+ def from_numba(context=None):
+ """
+ Create a Context instance from a Numba CUDA context.
+
+ Parameters
+ ----------
+ context : {numba.cuda.cudadrv.driver.Context, None}
+ A Numba CUDA context instance.
+ If None, the current Numba context is used.
+
+ Returns
+ -------
+ shared_context : pyarrow.cuda.Context
+ Context instance.
+ """
+ if context is None:
+ import numba.cuda
+ context = numba.cuda.current_context()
+ return Context(device_number=context.device.id,
+ handle=context.handle.value)
+
+ def to_numba(self):
+ """
+ Convert Context to a Numba CUDA context.
+
+ Returns
+ -------
+ context : numba.cuda.cudadrv.driver.Context
+ Numba CUDA context instance.
+ """
+ import ctypes
+ import numba.cuda
+ device = numba.cuda.gpus[self.device_number]
+ handle = ctypes.c_void_p(self.handle)
+ context = numba.cuda.cudadrv.driver.Context(device, handle)
+
+ class DummyPendingDeallocs(object):
+ # Context is managed by pyarrow
+ def add_item(self, *args, **kwargs):
+ pass
+
+ context.deallocations = DummyPendingDeallocs()
+ return context
+
+ @staticmethod
+ def get_num_devices():
+ """ Return the number of GPU devices.
+ """
+ cdef CCudaDeviceManager* manager
+ manager = GetResultValue(CCudaDeviceManager.Instance())
+ return manager.num_devices()
+
+ @property
+ def device_number(self):
+ """ Return context device number.
+ """
+ return self.device_number
+
+ @property
+ def handle(self):
+ """ Return pointer to context handle.
+ """
+ return <uintptr_t>self.context.get().handle()
+
+ cdef void init(self, const shared_ptr[CCudaContext]& ctx):
+ self.context = ctx
+
+ def synchronize(self):
+ """Blocks until the device has completed all preceding requested
+ tasks.
+ """
+ check_status(self.context.get().Synchronize())
+
+ @property
+ def bytes_allocated(self):
+ """Return the number of allocated bytes.
+ """
+ return self.context.get().bytes_allocated()
+
+ def get_device_address(self, uintptr_t address):
+ """Return the device address that is reachable from kernels running in
+ the context
+
+ Parameters
+ ----------
+ address : int
+ Specify memory address value
+
+ Returns
+ -------
+ device_address : int
+ Device address accessible from device context
+
+ Notes
+ -----
+ The device address is defined as a memory address accessible
+ by the device. While it is often a device memory address, it
+ can also be a host memory address, for instance when the
+ memory is allocated as host memory (using cudaMallocHost or
+ cudaHostAlloc), as managed memory (using cudaMallocManaged),
+ or when the host memory is page-locked (using cudaHostRegister).
+ """
+ return GetResultValue(self.context.get().GetDeviceAddress(address))
+
+ def new_buffer(self, int64_t nbytes):
+ """Return new device buffer.
+
+ Parameters
+ ----------
+ nbytes : int
+ Specify the number of bytes to be allocated.
+
+ Returns
+ -------
+ buf : CudaBuffer
+ Allocated buffer.
+ """
+ cdef:
+ shared_ptr[CCudaBuffer] cudabuf
+ with nogil:
+ cudabuf = GetResultValue(self.context.get().Allocate(nbytes))
+ return pyarrow_wrap_cudabuffer(cudabuf)
+
+ def foreign_buffer(self, address, size, base=None):
+ """
+ Create device buffer from address and size as a view.
+
+ The caller is responsible for allocating and freeing the
+ memory. When both `address` and `size` are 0, a new zero-sized
+ buffer is returned.
+
+ Parameters
+ ----------
+ address : int
+ Specify the starting address of the buffer. The address can
+ refer to both device or host memory but it must be
+ accessible from device after mapping it with
+ `get_device_address` method.
+ size : int
+ Specify the size of device buffer in bytes.
+ base : {None, object}
+ Specify object that owns the referenced memory.
+
+ Returns
+ -------
+ cbuf : CudaBuffer
+ Device buffer as a view of device reachable memory.
+
+ """
+ if not address and size == 0:
+ return self.new_buffer(0)
+ cdef:
+ uintptr_t c_addr = self.get_device_address(address)
+ int64_t c_size = size
+ shared_ptr[CCudaBuffer] cudabuf
+
+ cudabuf = GetResultValue(self.context.get().View(
+ <uint8_t*>c_addr, c_size))
+ return pyarrow_wrap_cudabuffer_base(cudabuf, base)
+
+ def open_ipc_buffer(self, ipc_handle):
+ """ Open existing CUDA IPC memory handle
+
+ Parameters
+ ----------
+ ipc_handle : IpcMemHandle
+ Specify opaque pointer to CUipcMemHandle (driver API).
+
+ Returns
+ -------
+ buf : CudaBuffer
+ referencing device buffer
+ """
+ handle = pyarrow_unwrap_cudaipcmemhandle(ipc_handle)
+ cdef shared_ptr[CCudaBuffer] cudabuf
+ with nogil:
+ cudabuf = GetResultValue(
+ self.context.get().OpenIpcBuffer(handle.get()[0]))
+ return pyarrow_wrap_cudabuffer(cudabuf)
+
+ def buffer_from_data(self, object data, int64_t offset=0, int64_t size=-1):
+ """Create device buffer and initialize with data.
+
+ Parameters
+ ----------
+ data : {CudaBuffer, HostBuffer, Buffer, array-like}
+ Specify data to be copied to device buffer.
+ offset : int
+ Specify the offset of input buffer for device data
+ buffering. Default: 0.
+ size : int
+ Specify the size of device buffer in bytes. Default: all
+ (starting from input offset)
+
+ Returns
+ -------
+ cbuf : CudaBuffer
+ Device buffer with copied data.
+ """
+ is_host_data = not pyarrow_is_cudabuffer(data)
+ buf = as_buffer(data) if is_host_data else data
+
+ bsize = buf.size
+ if offset < 0 or (bsize and offset >= bsize):
+ raise ValueError('offset argument is out-of-range')
+ if size < 0:
+ size = bsize - offset
+ elif offset + size > bsize:
+ raise ValueError(
+ 'requested larger slice than available in device buffer')
+
+ if offset != 0 or size != bsize:
+ buf = buf.slice(offset, size)
+
+ result = self.new_buffer(size)
+ if is_host_data:
+ result.copy_from_host(buf, position=0, nbytes=size)
+ else:
+ result.copy_from_device(buf, position=0, nbytes=size)
+ return result
+
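+ # Usage sketch (editorial example): round-tripping host bytes through the
+ # device; assumes `ctx` is an existing cuda.Context.
+ #
+ #   dbuf = ctx.buffer_from_data(b"hello device")
+ #   assert dbuf.copy_to_host().to_pybytes() == b"hello device"
+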
+ def buffer_from_object(self, obj):
+ """Create device buffer view of arbitrary object that references
+ device accessible memory.
+
+ When the object contains a non-contiguous view of device
+ accessible memory then the returned device buffer will contain
+ contiguous view of the memory, that is, including the
+ intermediate data that is otherwise invisible to the input
+ object.
+
+ Parameters
+ ----------
+ obj : {object, Buffer, HostBuffer, CudaBuffer, ...}
+ Specify an object that holds (device or host) address that
+ can be accessed from device. This includes objects with
+ types defined in pyarrow.cuda as well as arbitrary objects
+ that implement the CUDA array interface as defined by numba.
+
+ Returns
+ -------
+ cbuf : CudaBuffer
+ Device buffer as a view of device accessible memory.
+
+ """
+ if isinstance(obj, HostBuffer):
+ return self.foreign_buffer(obj.address, obj.size, base=obj)
+ elif isinstance(obj, Buffer):
+ return CudaBuffer.from_buffer(obj)
+ elif isinstance(obj, CudaBuffer):
+ return obj
+ elif hasattr(obj, '__cuda_array_interface__'):
+ desc = obj.__cuda_array_interface__
+ addr = desc['data'][0]
+ if addr is None:
+ return self.new_buffer(0)
+ import numpy as np
+ start, end = get_contiguous_span(
+ desc['shape'], desc.get('strides'),
+ np.dtype(desc['typestr']).itemsize)
+ return self.foreign_buffer(addr + start, end - start, base=obj)
+ raise ArrowTypeError('cannot create device buffer view from'
+ ' `%s` object' % (type(obj)))
+
+
+cdef class IpcMemHandle(_Weakrefable):
+ """A serializable container for a CUDA IPC handle.
+ """
+ cdef void init(self, shared_ptr[CCudaIpcMemHandle]& h):
+ self.handle = h
+
+ @staticmethod
+ def from_buffer(Buffer opaque_handle):
+ """Create IpcMemHandle from opaque buffer (e.g. from another
+ process)
+
+ Parameters
+ ----------
+ opaque_handle : Buffer
+ An opaque buffer containing a serialized CUipcMemHandle
+ (driver API), e.g. as produced by IpcMemHandle.serialize().
+
+ Returns
+ -------
+ ipc_handle : IpcMemHandle
+ """
+ c_buf = pyarrow_unwrap_buffer(opaque_handle)
+ cdef:
+ shared_ptr[CCudaIpcMemHandle] handle
+
+ handle = GetResultValue(
+ CCudaIpcMemHandle.FromBuffer(c_buf.get().data()))
+ return pyarrow_wrap_cudaipcmemhandle(handle)
+
+ def serialize(self, pool=None):
+ """Write IpcMemHandle to a Buffer
+
+ Parameters
+ ----------
+ pool : {MemoryPool, None}
+ Specify a pool to allocate memory from
+
+ Returns
+ -------
+ buf : Buffer
+ The serialized buffer.
+ """
+ cdef CMemoryPool* pool_ = maybe_unbox_memory_pool(pool)
+ cdef shared_ptr[CBuffer] buf
+ cdef CCudaIpcMemHandle* h = self.handle.get()
+ with nogil:
+ buf = GetResultValue(h.Serialize(pool_))
+ return pyarrow_wrap_buffer(buf)
+
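+# Usage sketch (editorial example): sharing a device buffer between processes
+# via CUDA IPC; assumes `ctx` is a cuda.Context, `dbuf` a CudaBuffer from it,
+# `pa` is pyarrow, and some transport (pipe, socket, ...) carries the bytes.
+#
+#   # exporting process
+#   payload = dbuf.export_for_ipc().serialize().to_pybytes()
+#   # importing process, same GPU
+#   handle = cuda.IpcMemHandle.from_buffer(pa.py_buffer(payload))
+#   shared = ctx.open_ipc_buffer(handle)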
+
+cdef class CudaBuffer(Buffer):
+ """An Arrow buffer with data located in a GPU device.
+
+ To create a CudaBuffer instance, use Context.device_buffer().
+
+ The memory allocated in a CudaBuffer is freed when the buffer object
+ is deleted.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call CudaBuffer's constructor directly, use "
+ "`<pyarrow.Context instance>.device_buffer`"
+ " method instead.")
+
+ cdef void init_cuda(self,
+ const shared_ptr[CCudaBuffer]& buffer,
+ object base):
+ self.cuda_buffer = buffer
+ self.init(<shared_ptr[CBuffer]> buffer)
+ self.base = base
+
+ @staticmethod
+ def from_buffer(buf):
+ """ Convert back generic buffer into CudaBuffer
+
+ Parameters
+ ----------
+ buf : Buffer
+ Specify a buffer that wraps CUDA device memory.
+
+ Returns
+ -------
+ dbuf : CudaBuffer
+ Resulting device buffer.
+ """
+ c_buf = pyarrow_unwrap_buffer(buf)
+ cuda_buffer = GetResultValue(CCudaBuffer.FromBuffer(c_buf))
+ return pyarrow_wrap_cudabuffer(cuda_buffer)
+
+ @staticmethod
+ def from_numba(mem):
+ """Create a CudaBuffer view from numba MemoryPointer instance.
+
+ Parameters
+ ----------
+ mem : numba.cuda.cudadrv.driver.MemoryPointer
+
+ Returns
+ -------
+ cbuf : CudaBuffer
+ Device buffer as a view of numba MemoryPointer.
+ """
+ ctx = Context.from_numba(mem.context)
+ if mem.device_pointer.value is None and mem.size == 0:
+ return ctx.new_buffer(0)
+ return ctx.foreign_buffer(mem.device_pointer.value, mem.size, base=mem)
+
+ def to_numba(self):
+ """Return numba memory pointer of CudaBuffer instance.
+ """
+ import ctypes
+ from numba.cuda.cudadrv.driver import MemoryPointer
+ return MemoryPointer(self.context.to_numba(),
+ pointer=ctypes.c_void_p(self.address),
+ size=self.size)
+
+ cdef getitem(self, int64_t i):
+ return self.copy_to_host(position=i, nbytes=1)[0]
+
+ def copy_to_host(self, int64_t position=0, int64_t nbytes=-1,
+ Buffer buf=None,
+ MemoryPool memory_pool=None, c_bool resizable=False):
+ """Copy memory from GPU device to CPU host
+
+ Caller is responsible for ensuring that all tasks affecting
+ the memory are finished. Use
+
+ `<CudaBuffer instance>.context.synchronize()`
+
+ when needed.
+
+ Parameters
+ ----------
+ position : int
+ Specify the starting position of the source data in GPU
+ device buffer. Default: 0.
+ nbytes : int
+ Specify the number of bytes to copy. Default: -1 (all from
+ the position until host buffer is full).
+ buf : Buffer
+ Specify a pre-allocated output buffer in host. Default: None
+ (allocate new output buffer).
+ memory_pool : MemoryPool
+ resizable : bool
+ Specify extra arguments to allocate_buffer. Used only when
+ buf is None.
+
+ Returns
+ -------
+ buf : Buffer
+ Output buffer in host.
+
+ """
+ if position < 0 or (self.size and position > self.size) \
+ or (self.size == 0 and position != 0):
+ raise ValueError('position argument is out-of-range')
+ cdef:
+ int64_t c_nbytes
+ if buf is None:
+ if nbytes < 0:
+ # copy all starting from position to new host buffer
+ c_nbytes = self.size - position
+ else:
+ if nbytes > self.size - position:
+ raise ValueError(
+ 'requested more to copy than available from '
+ 'device buffer')
+ # copy nbytes starting from position to new host buffer
+ c_nbytes = nbytes
+ buf = allocate_buffer(c_nbytes, memory_pool=memory_pool,
+ resizable=resizable)
+ else:
+ if nbytes < 0:
+ # copy all from position until given host buffer is full
+ c_nbytes = min(self.size - position, buf.size)
+ else:
+ if nbytes > buf.size:
+ raise ValueError(
+ 'requested copy does not fit into host buffer')
+ # copy nbytes from position to given host buffer
+ c_nbytes = nbytes
+
+ cdef:
+ shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf)
+ int64_t c_position = position
+ with nogil:
+ check_status(self.cuda_buffer.get()
+ .CopyToHost(c_position, c_nbytes,
+ c_buf.get().mutable_data()))
+ return buf
+
+ def copy_from_host(self, data, int64_t position=0, int64_t nbytes=-1):
+ """Copy data from host to device.
+
+ The device buffer must be pre-allocated.
+
+ Parameters
+ ----------
+ data : {Buffer, array-like}
+ Specify data on the host. It can be an array-like object that
+ is a valid argument to py_buffer.
+ position : int
+ Specify the starting position of the copy in device buffer.
+ Default: 0.
+ nbytes : int
+ Specify the number of bytes to copy. Default: -1 (all from
+ source until device buffer, starting from position, is full)
+
+ Returns
+ -------
+ nbytes : int
+ Number of bytes copied.
+ """
+ if position < 0 or position > self.size:
+ raise ValueError('position argument is out-of-range')
+ cdef:
+ int64_t c_nbytes
+ buf = as_buffer(data)
+
+ if nbytes < 0:
+ # copy from host buffer to device buffer starting from
+ # position until device buffer is full
+ c_nbytes = min(self.size - position, buf.size)
+ else:
+ if nbytes > buf.size:
+ raise ValueError(
+ 'requested more to copy than available from host buffer')
+ if nbytes > self.size - position:
+ raise ValueError(
+ 'requested more to copy than available in device buffer')
+ # copy nbytes from host buffer to device buffer starting
+ # from position
+ c_nbytes = nbytes
+
+ cdef:
+ shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf)
+ int64_t c_position = position
+ with nogil:
+ check_status(self.cuda_buffer.get().
+ CopyFromHost(c_position, c_buf.get().data(),
+ c_nbytes))
+ return c_nbytes
+
+ def copy_from_device(self, buf, int64_t position=0, int64_t nbytes=-1):
+ """Copy data from device to device.
+
+ Parameters
+ ----------
+ buf : CudaBuffer
+ Specify source device buffer.
+ position : int
+ Specify the starting position of the copy in device buffer.
+ Default: 0.
+ nbytes : int
+ Specify the number of bytes to copy. Default: -1 (all from
+ source until device buffer, starting from position, is full)
+
+ Returns
+ -------
+ nbytes : int
+ Number of bytes copied.
+
+ """
+ if position < 0 or position > self.size:
+ raise ValueError('position argument is out-of-range')
+ cdef:
+ int64_t c_nbytes
+
+ if nbytes < 0:
+ # copy from source device buffer to device buffer starting
+ # from position until device buffer is full
+ c_nbytes = min(self.size - position, buf.size)
+ else:
+ if nbytes > buf.size:
+ raise ValueError(
+ 'requested more to copy than available from device buffer')
+ if nbytes > self.size - position:
+ raise ValueError(
+ 'requested more to copy than available in device buffer')
+ # copy nbytes from source device buffer to device buffer
+ # starting from position
+ c_nbytes = nbytes
+
+ cdef:
+ shared_ptr[CCudaBuffer] c_buf = pyarrow_unwrap_cudabuffer(buf)
+ int64_t c_position = position
+ shared_ptr[CCudaContext] c_src_ctx = pyarrow_unwrap_cudacontext(
+ buf.context)
+ void* c_source_data = <void*>(c_buf.get().address())
+
+ if self.context.handle != buf.context.handle:
+ with nogil:
+ check_status(self.cuda_buffer.get().
+ CopyFromAnotherDevice(c_src_ctx, c_position,
+ c_source_data, c_nbytes))
+ else:
+ with nogil:
+ check_status(self.cuda_buffer.get().
+ CopyFromDevice(c_position, c_source_data,
+ c_nbytes))
+ return c_nbytes
+
+ def export_for_ipc(self):
+ """
+ Expose this device buffer as IPC memory which can be used in other
+ processes.
+
+ After calling this function, this device memory will not be
+ freed when the CudaBuffer is destructed.
+
+ Returns
+ -------
+ ipc_handle : IpcMemHandle
+ The exported IPC handle
+
+ """
+ cdef shared_ptr[CCudaIpcMemHandle] handle
+ with nogil:
+ handle = GetResultValue(self.cuda_buffer.get().ExportForIpc())
+ return pyarrow_wrap_cudaipcmemhandle(handle)
+
+ @property
+ def context(self):
+ """Returns the CUDA driver context of this buffer.
+ """
+ return pyarrow_wrap_cudacontext(self.cuda_buffer.get().context())
+
+ def slice(self, offset=0, length=None):
+ """Return slice of device buffer
+
+ Parameters
+ ----------
+ offset : int, default 0
+ Specify offset from the start of device buffer to slice
+ length : int, default None
+ Specify the length of slice (default is until end of device
+ buffer starting from offset). If the length is larger than
+ the data available, the returned slice will have a size of
+ the available data starting from the offset.
+
+ Returns
+ -------
+ sliced : CudaBuffer
+ Zero-copy slice of device buffer.
+
+ """
+ if offset < 0 or (self.size and offset >= self.size):
+ raise ValueError('offset argument is out-of-range')
+ cdef int64_t offset_ = offset
+ cdef int64_t size
+ if length is None:
+ size = self.size - offset_
+ elif offset + length <= self.size:
+ size = length
+ else:
+ size = self.size - offset
+ parent = pyarrow_unwrap_cudabuffer(self)
+ return pyarrow_wrap_cudabuffer(make_shared[CCudaBuffer](parent,
+ offset_, size))
+
+ def to_pybytes(self):
+ """Return device buffer content as Python bytes.
+ """
+ return self.copy_to_host().to_pybytes()
+
+ def __getbuffer__(self, cp.Py_buffer* buffer, int flags):
+ # Device buffer contains data pointers on the device. Hence,
+ # cannot support buffer protocol PEP-3118 for CudaBuffer.
+ raise BufferError('buffer protocol for device buffer not supported')
+
+
+cdef class HostBuffer(Buffer):
+ """Device-accessible CPU memory created using cudaHostAlloc.
+
+ To create a HostBuffer instance, use
+
+ cuda.new_host_buffer(<nbytes>)
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call HostBuffer's constructor directly,"
+ " use `cuda.new_host_buffer` function instead.")
+
+ cdef void init_host(self, const shared_ptr[CCudaHostBuffer]& buffer):
+ self.host_buffer = buffer
+ self.init(<shared_ptr[CBuffer]> buffer)
+
+ @property
+ def size(self):
+ return self.host_buffer.get().size()
+
+
+cdef class BufferReader(NativeFile):
+ """File interface for zero-copy read from CUDA buffers.
+
+ Note: Read methods return pointers to device memory. This means
+ you must be careful using this interface with any Arrow code which
+ may expect to be able to do anything other than pointer arithmetic
+ on the returned buffers.
+ """
+
+ def __cinit__(self, CudaBuffer obj):
+ self.buffer = obj
+ self.reader = new CCudaBufferReader(self.buffer.buffer)
+ self.set_random_access_file(
+ shared_ptr[CRandomAccessFile](self.reader))
+ self.is_readable = True
+
+ def read_buffer(self, nbytes=None):
+ """Return a slice view of the underlying device buffer.
+
+ The slice will start at the current reader position and will
+ have the specified size in bytes.
+
+ Parameters
+ ----------
+ nbytes : int, default None
+ Specify the number of bytes to read. Default: None (read all
+ remaining bytes).
+
+ Returns
+ -------
+ cbuf : CudaBuffer
+ New device buffer.
+
+ """
+ cdef:
+ int64_t c_nbytes
+ int64_t bytes_read = 0
+ shared_ptr[CCudaBuffer] output
+
+ if nbytes is None:
+ c_nbytes = self.size() - self.tell()
+ else:
+ c_nbytes = nbytes
+
+ with nogil:
+ output = static_pointer_cast[CCudaBuffer, CBuffer](
+ GetResultValue(self.reader.Read(c_nbytes)))
+
+ return pyarrow_wrap_cudabuffer(output)
+
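+# Usage sketch (editorial example): zero-copy reads from a device buffer;
+# assumes `dbuf` is an existing CudaBuffer.
+#
+#   reader = cuda.BufferReader(dbuf)
+#   header = reader.read_buffer(16)   # CudaBuffer view of the first 16 bytes
+#   rest = reader.read_buffer()       # view of everything that remains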
+
+cdef class BufferWriter(NativeFile):
+ """File interface for writing to CUDA buffers.
+
+ By default writes are unbuffered. Use set_buffer_size to enable
+ buffering.
+ """
+
+ def __cinit__(self, CudaBuffer buffer):
+ self.buffer = buffer
+ self.writer = new CCudaBufferWriter(self.buffer.cuda_buffer)
+ self.set_output_stream(shared_ptr[COutputStream](self.writer))
+ self.is_writable = True
+
+ def writeat(self, int64_t position, object data):
+ """Write data to buffer starting from position.
+
+ Parameters
+ ----------
+ position : int
+ Specify device buffer position where the data will be
+ written.
+ data : array-like
+ Specify data, the data instance must implement buffer
+ protocol.
+ """
+ cdef:
+ Buffer buf = as_buffer(data)
+ const uint8_t* c_data = buf.buffer.get().data()
+ int64_t c_size = buf.buffer.get().size()
+
+ with nogil:
+ check_status(self.writer.WriteAt(position, c_data, c_size))
+
+ def flush(self):
+ """ Flush the buffer stream """
+ with nogil:
+ check_status(self.writer.Flush())
+
+ def seek(self, int64_t position, int whence=0):
+ # TODO: remove this method after NativeFile.seek supports
+ # writable files.
+ cdef int64_t offset
+
+ with nogil:
+ if whence == 0:
+ offset = position
+ elif whence == 1:
+ offset = GetResultValue(self.writer.Tell())
+ offset = offset + position
+ else:
+ with gil:
+ raise ValueError("Invalid value of whence: {0}"
+ .format(whence))
+ check_status(self.writer.Seek(offset))
+ return self.tell()
+
+ @property
+ def buffer_size(self):
+ """Returns size of host (CPU) buffer, 0 for unbuffered
+ """
+ return self.writer.buffer_size()
+
+ @buffer_size.setter
+ def buffer_size(self, int64_t buffer_size):
+ """Set CPU buffer size to limit calls to cudaMemcpy
+
+ Parameters
+ ----------
+ buffer_size : int
+ Specify the size of CPU buffer to allocate in bytes.
+ """
+ with nogil:
+ check_status(self.writer.SetBufferSize(buffer_size))
+
+ @property
+ def num_bytes_buffered(self):
+ """Returns number of bytes buffered on host
+ """
+ return self.writer.num_bytes_buffered()
+
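+# Usage sketch (editorial example): writing host bytes into a pre-allocated
+# device buffer; assumes `ctx` is a cuda.Context.
+#
+#   dbuf = ctx.new_buffer(128)
+#   writer = cuda.BufferWriter(dbuf)
+#   writer.buffer_size = 64           # stage writes in a 64-byte host buffer
+#   writer.write(b"abc")              # write() is inherited from NativeFile
+#   writer.flush()
+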
+# Functions
+
+
+def new_host_buffer(const int64_t size, int device=0):
+ """Return buffer with CUDA-accessible memory on CPU host
+
+ Parameters
+ ----------
+ size : int
+ Specify the number of bytes to be allocated.
+ device : int
+ Specify GPU device number.
+
+ Returns
+ -------
+ dbuf : HostBuffer
+ Allocated host buffer
+ """
+ cdef shared_ptr[CCudaHostBuffer] buffer
+ with nogil:
+ buffer = GetResultValue(AllocateCudaHostBuffer(device, size))
+ return pyarrow_wrap_cudahostbuffer(buffer)
+
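+# Usage sketch (editorial example): pinned host memory can be viewed with
+# NumPy without copying; assumes numpy is imported as np.
+#
+#   hbuf = cuda.new_host_buffer(1024)
+#   arr = np.frombuffer(hbuf, dtype=np.uint8)   # zero-copy view of pinned RAM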
+
+def serialize_record_batch(object batch, object ctx):
+ """ Write record batch message to GPU device memory
+
+ Parameters
+ ----------
+ batch : RecordBatch
+ Record batch to write
+ ctx : Context
+ CUDA Context to allocate device memory from
+
+ Returns
+ -------
+ dbuf : CudaBuffer
+ device buffer which contains the record batch message
+ """
+ cdef shared_ptr[CCudaBuffer] buffer
+ cdef CRecordBatch* batch_ = pyarrow_unwrap_batch(batch).get()
+ cdef CCudaContext* ctx_ = pyarrow_unwrap_cudacontext(ctx).get()
+ with nogil:
+ buffer = GetResultValue(CudaSerializeRecordBatch(batch_[0], ctx_))
+ return pyarrow_wrap_cudabuffer(buffer)
+
+
+def read_message(object source, pool=None):
+ """ Read Arrow IPC message located on GPU device
+
+ Parameters
+ ----------
+ source : {CudaBuffer, cuda.BufferReader}
+ Device buffer or reader of device buffer.
+ pool : MemoryPool (optional)
+ Pool to allocate CPU memory for the metadata
+
+ Returns
+ -------
+ message : Message
+ The deserialized message, body still on device
+ """
+ cdef:
+ Message result = Message.__new__(Message)
+ BufferReader reader
+ cdef CMemoryPool* pool_ = maybe_unbox_memory_pool(pool)
+ if not isinstance(source, BufferReader):
+ reader = BufferReader(source)
+ else:
+ reader = source
+ with nogil:
+ result.message = move(
+ GetResultValue(ReadMessage(reader.reader, pool_)))
+ return result
+
+
+def read_record_batch(object buffer, object schema, *,
+ DictionaryMemo dictionary_memo=None, pool=None):
+ """Construct RecordBatch referencing IPC message located on CUDA device.
+
+ While the metadata is copied to host memory for deserialization,
+ the record batch data remains on the device.
+
+ Parameters
+ ----------
+ buffer : CudaBuffer
+ Device buffer containing the complete IPC message
+ schema : Schema
+ The schema for the record batch
+ dictionary_memo : DictionaryMemo, optional
+ If message contains dictionaries, must pass a populated
+ DictionaryMemo
+ pool : MemoryPool (optional)
+ Pool to allocate metadata from
+
+ Returns
+ -------
+ batch : RecordBatch
+ Reconstructed record batch, with device pointers
+
+ """
+ cdef:
+ shared_ptr[CSchema] schema_ = pyarrow_unwrap_schema(schema)
+ shared_ptr[CCudaBuffer] buffer_ = pyarrow_unwrap_cudabuffer(buffer)
+ CDictionaryMemo temp_memo
+ CDictionaryMemo* arg_dict_memo
+ CMemoryPool* pool_ = maybe_unbox_memory_pool(pool)
+ shared_ptr[CRecordBatch] batch
+
+ if dictionary_memo is not None:
+ arg_dict_memo = dictionary_memo.memo
+ else:
+ arg_dict_memo = &temp_memo
+
+ with nogil:
+ batch = GetResultValue(CudaReadRecordBatch(
+ schema_, arg_dict_memo, buffer_, pool_))
+ return pyarrow_wrap_batch(batch)
+
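+# Usage sketch (editorial example): round-tripping a RecordBatch through
+# device memory; assumes `ctx` is a cuda.Context and `batch` a RecordBatch.
+#
+#   dbuf = cuda.serialize_record_batch(batch, ctx)
+#   on_device = cuda.read_record_batch(dbuf, batch.schema)  # body stays on GPU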
+
+# Public API
+
+
+cdef public api bint pyarrow_is_buffer(object buffer):
+ return isinstance(buffer, Buffer)
+
+# cudabuffer
+
+cdef public api bint pyarrow_is_cudabuffer(object buffer):
+ return isinstance(buffer, CudaBuffer)
+
+
+cdef public api object \
+ pyarrow_wrap_cudabuffer_base(const shared_ptr[CCudaBuffer]& buf, base):
+ cdef CudaBuffer result = CudaBuffer.__new__(CudaBuffer)
+ result.init_cuda(buf, base)
+ return result
+
+
+cdef public api object \
+ pyarrow_wrap_cudabuffer(const shared_ptr[CCudaBuffer]& buf):
+ cdef CudaBuffer result = CudaBuffer.__new__(CudaBuffer)
+ result.init_cuda(buf, None)
+ return result
+
+
+cdef public api shared_ptr[CCudaBuffer] pyarrow_unwrap_cudabuffer(object obj):
+ if pyarrow_is_cudabuffer(obj):
+ return (<CudaBuffer>obj).cuda_buffer
+ raise TypeError('expected CudaBuffer instance, got %s'
+ % (type(obj).__name__))
+
+# cudahostbuffer
+
+cdef public api bint pyarrow_is_cudahostbuffer(object buffer):
+ return isinstance(buffer, HostBuffer)
+
+
+cdef public api object \
+ pyarrow_wrap_cudahostbuffer(const shared_ptr[CCudaHostBuffer]& buf):
+ cdef HostBuffer result = HostBuffer.__new__(HostBuffer)
+ result.init_host(buf)
+ return result
+
+
+cdef public api shared_ptr[CCudaHostBuffer] \
+ pyarrow_unwrap_cudahostbuffer(object obj):
+ if pyarrow_is_cudahostbuffer(obj):
+ return (<HostBuffer>obj).host_buffer
+ raise TypeError('expected HostBuffer instance, got %s'
+ % (type(obj).__name__))
+
+# cudacontext
+
+cdef public api bint pyarrow_is_cudacontext(object ctx):
+ return isinstance(ctx, Context)
+
+
+cdef public api object \
+ pyarrow_wrap_cudacontext(const shared_ptr[CCudaContext]& ctx):
+ cdef Context result = Context.__new__(Context)
+ result.init(ctx)
+ return result
+
+
+cdef public api shared_ptr[CCudaContext] \
+ pyarrow_unwrap_cudacontext(object obj):
+ if pyarrow_is_cudacontext(obj):
+ return (<Context>obj).context
+ raise TypeError('expected Context instance, got %s'
+ % (type(obj).__name__))
+
+# cudaipcmemhandle
+
+cdef public api bint pyarrow_is_cudaipcmemhandle(object handle):
+ return isinstance(handle, IpcMemHandle)
+
+
+cdef public api object \
+ pyarrow_wrap_cudaipcmemhandle(shared_ptr[CCudaIpcMemHandle]& h):
+ cdef IpcMemHandle result = IpcMemHandle.__new__(IpcMemHandle)
+ result.init(h)
+ return result
+
+
+cdef public api shared_ptr[CCudaIpcMemHandle] \
+ pyarrow_unwrap_cudaipcmemhandle(object obj):
+ if pyarrow_is_cudaipcmemhandle(obj):
+ return (<IpcMemHandle>obj).handle
+ raise TypeError('expected IpcMemHandle instance, got %s'
+ % (type(obj).__name__))
diff --git a/src/arrow/python/pyarrow/_dataset.pxd b/src/arrow/python/pyarrow/_dataset.pxd
new file mode 100644
index 000000000..875e13f87
--- /dev/null
+++ b/src/arrow/python/pyarrow/_dataset.pxd
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+"""Dataset is currently unstable. APIs subject to change without notice."""
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow_dataset cimport *
+from pyarrow.lib cimport *
+
+
+cdef class FragmentScanOptions(_Weakrefable):
+
+ cdef:
+ shared_ptr[CFragmentScanOptions] wrapped
+
+ cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp)
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CFragmentScanOptions]& sp)
+
+
+cdef class FileFormat(_Weakrefable):
+
+ cdef:
+ shared_ptr[CFileFormat] wrapped
+ CFileFormat* format
+
+ cdef void init(self, const shared_ptr[CFileFormat]& sp)
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CFileFormat]& sp)
+
+ cdef inline shared_ptr[CFileFormat] unwrap(self)
+
+ cdef _set_default_fragment_scan_options(self, FragmentScanOptions options)
diff --git a/src/arrow/python/pyarrow/_dataset.pyx b/src/arrow/python/pyarrow/_dataset.pyx
new file mode 100644
index 000000000..459c3b8fb
--- /dev/null
+++ b/src/arrow/python/pyarrow/_dataset.pyx
@@ -0,0 +1,3408 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+"""Dataset is currently unstable. APIs subject to change without notice."""
+
+from cpython.object cimport Py_LT, Py_EQ, Py_GT, Py_LE, Py_NE, Py_GE
+from cython.operator cimport dereference as deref
+
+import collections
+import os
+import warnings
+
+import pyarrow as pa
+from pyarrow.lib cimport *
+from pyarrow.lib import ArrowTypeError, frombytes, tobytes
+from pyarrow.includes.libarrow_dataset cimport *
+from pyarrow._fs cimport FileSystem, FileInfo, FileSelector
+from pyarrow._csv cimport (
+ ConvertOptions, ParseOptions, ReadOptions, WriteOptions)
+from pyarrow.util import _is_iterable, _is_path_like, _stringify_path
+
+from pyarrow._parquet cimport (
+ _create_writer_properties, _create_arrow_writer_properties,
+ FileMetaData, RowGroupMetaData, ColumnChunkMetaData
+)
+
+
+def _forbid_instantiation(klass, subclasses_instead=True):
+ msg = '{} is an abstract class thus cannot be initialized.'.format(
+ klass.__name__
+ )
+ if subclasses_instead:
+ subclasses = [cls.__name__ for cls in klass.__subclasses__()]
+ msg += ' Use one of the subclasses instead: {}'.format(
+ ', '.join(subclasses)
+ )
+ raise TypeError(msg)
+
+
+_orc_fileformat = None
+_orc_imported = False
+
+
+def _get_orc_fileformat():
+ """
+ Import OrcFileFormat on first usage (to avoid circular import issue
+ when `pyarrow._dataset_orc` would be imported first)
+ """
+ global _orc_fileformat
+ global _orc_imported
+ if not _orc_imported:
+ try:
+ from pyarrow._dataset_orc import OrcFileFormat
+ _orc_fileformat = OrcFileFormat
+ except ImportError as e:
+ _orc_fileformat = None
+ finally:
+ _orc_imported = True
+ return _orc_fileformat
+
+
+cdef CFileSource _make_file_source(object file, FileSystem filesystem=None):
+
+ cdef:
+ CFileSource c_source
+ shared_ptr[CFileSystem] c_filesystem
+ c_string c_path
+ shared_ptr[CRandomAccessFile] c_file
+ shared_ptr[CBuffer] c_buffer
+
+ if isinstance(file, Buffer):
+ c_buffer = pyarrow_unwrap_buffer(file)
+ c_source = CFileSource(move(c_buffer))
+
+ elif _is_path_like(file):
+ if filesystem is None:
+ raise ValueError("cannot construct a FileSource from "
+ "a path without a FileSystem")
+ c_filesystem = filesystem.unwrap()
+ c_path = tobytes(_stringify_path(file))
+ c_source = CFileSource(move(c_path), move(c_filesystem))
+
+ elif hasattr(file, 'read'):
+ # Optimistically hope this is file-like
+ c_file = get_native_file(file, False).get_random_access_file()
+ c_source = CFileSource(move(c_file))
+
+ else:
+ raise TypeError("cannot construct a FileSource "
+ "from " + str(file))
+
+ return c_source
+
+
+cdef CSegmentEncoding _get_segment_encoding(str segment_encoding):
+ if segment_encoding == "none":
+ return CSegmentEncodingNone
+ elif segment_encoding == "uri":
+ return CSegmentEncodingUri
+ raise ValueError(f"Unknown segment encoding: {segment_encoding}")
+
+
+cdef class Expression(_Weakrefable):
+ """
+ A logical expression to be evaluated against some input.
+
+ To create an expression:
+
+ - Use the factory function ``pyarrow.dataset.scalar()`` to create a
+ scalar (not necessary when combined, see example below).
+ - Use the factory function ``pyarrow.dataset.field()`` to reference
+ a field (column in table).
+ - Compare fields and scalars with ``<``, ``<=``, ``==``, ``>=``, ``>``.
+ - Combine expressions using python operators ``&`` (logical and),
+ ``|`` (logical or) and ``~`` (logical not).
+ Note: python keywords ``and``, ``or`` and ``not`` cannot be used
+ to combine expressions.
+ - Check whether the expression is contained in a list of values with
+ the ``pyarrow.dataset.Expression.isin()`` member function.
+
+ Examples
+ --------
+
+ >>> import pyarrow.dataset as ds
+ >>> (ds.field("a") < ds.scalar(3)) | (ds.field("b") > 7)
+ <pyarrow.dataset.Expression ((a < 3:int64) or (b > 7:int64))>
+ >>> ds.field('a') != 3
+ <pyarrow.dataset.Expression (a != 3)>
+ >>> ds.field('a').isin([1, 2, 3])
+ <pyarrow.dataset.Expression (a is in [
+ 1,
+ 2,
+ 3
+ ])>
+ """
+ cdef:
+ CExpression expr
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef void init(self, const CExpression& sp):
+ self.expr = sp
+
+ @staticmethod
+ cdef wrap(const CExpression& sp):
+ cdef Expression self = Expression.__new__(Expression)
+ self.init(sp)
+ return self
+
+ cdef inline CExpression unwrap(self):
+ return self.expr
+
+ def equals(self, Expression other):
+ return self.expr.Equals(other.unwrap())
+
+ def __str__(self):
+ return frombytes(self.expr.ToString())
+
+ def __repr__(self):
+ return "<pyarrow.dataset.{0} {1}>".format(
+ self.__class__.__name__, str(self)
+ )
+
+ @staticmethod
+ def _deserialize(Buffer buffer not None):
+ return Expression.wrap(GetResultValue(CDeserializeExpression(
+ pyarrow_unwrap_buffer(buffer))))
+
+ def __reduce__(self):
+ buffer = pyarrow_wrap_buffer(GetResultValue(
+ CSerializeExpression(self.expr)))
+ return Expression._deserialize, (buffer,)
+
+ @staticmethod
+ cdef Expression _expr_or_scalar(object expr):
+ if isinstance(expr, Expression):
+ return (<Expression> expr)
+ return (<Expression> Expression._scalar(expr))
+
+ @staticmethod
+ cdef Expression _call(str function_name, list arguments,
+ shared_ptr[CFunctionOptions] options=(
+ <shared_ptr[CFunctionOptions]> nullptr)):
+ cdef:
+ vector[CExpression] c_arguments
+
+ for argument in arguments:
+ c_arguments.push_back((<Expression> argument).expr)
+
+ return Expression.wrap(CMakeCallExpression(tobytes(function_name),
+ move(c_arguments), options))
+
+ def __richcmp__(self, other, int op):
+ other = Expression._expr_or_scalar(other)
+ return Expression._call({
+ Py_EQ: "equal",
+ Py_NE: "not_equal",
+ Py_GT: "greater",
+ Py_GE: "greater_equal",
+ Py_LT: "less",
+ Py_LE: "less_equal",
+ }[op], [self, other])
+
+ def __bool__(self):
+ raise ValueError(
+ "An Expression cannot be evaluated to python True or False. "
+ "If you are using the 'and', 'or' or 'not' operators, use '&', "
+ "'|' or '~' instead."
+ )
+
+ def __invert__(self):
+ return Expression._call("invert", [self])
+
+ def __and__(Expression self, other):
+ other = Expression._expr_or_scalar(other)
+ return Expression._call("and_kleene", [self, other])
+
+ def __or__(Expression self, other):
+ other = Expression._expr_or_scalar(other)
+ return Expression._call("or_kleene", [self, other])
+
+ def __add__(Expression self, other):
+ other = Expression._expr_or_scalar(other)
+ return Expression._call("add_checked", [self, other])
+
+ def __mul__(Expression self, other):
+ other = Expression._expr_or_scalar(other)
+ return Expression._call("multiply_checked", [self, other])
+
+ def __sub__(Expression self, other):
+ other = Expression._expr_or_scalar(other)
+ return Expression._call("subtract_checked", [self, other])
+
+ def __truediv__(Expression self, other):
+ other = Expression._expr_or_scalar(other)
+ return Expression._call("divide_checked", [self, other])
+
+ def is_valid(self):
+ """Checks whether the expression is not-null (valid)"""
+ return Expression._call("is_valid", [self])
+
+ def is_null(self, bint nan_is_null=False):
+ """Checks whether the expression is null"""
+ cdef:
+ shared_ptr[CFunctionOptions] c_options
+
+ c_options.reset(new CNullOptions(nan_is_null))
+ return Expression._call("is_null", [self], c_options)
+
+ def cast(self, type, bint safe=True):
+ """Explicitly change the expression's data type"""
+ cdef shared_ptr[CCastOptions] c_options
+ c_options.reset(new CCastOptions(safe))
+ c_options.get().to_type = pyarrow_unwrap_data_type(ensure_type(type))
+ return Expression._call("cast", [self],
+ <shared_ptr[CFunctionOptions]> c_options)
+
+ def isin(self, values):
+ """Checks whether the expression is contained in values"""
+ cdef:
+ shared_ptr[CFunctionOptions] c_options
+ CDatum c_values
+
+ if not isinstance(values, pa.Array):
+ values = pa.array(values)
+
+ c_values = CDatum(pyarrow_unwrap_array(values))
+ c_options.reset(new CSetLookupOptions(c_values, True))
+ return Expression._call("is_in", [self], c_options)
+
+ @staticmethod
+ def _field(str name not None):
+ return Expression.wrap(CMakeFieldExpression(tobytes(name)))
+
+ @staticmethod
+ def _scalar(value):
+ cdef:
+ Scalar scalar
+
+ if isinstance(value, Scalar):
+ scalar = value
+ else:
+ scalar = pa.scalar(value)
+
+ return Expression.wrap(CMakeScalarExpression(scalar.unwrap()))
+
+
+_deserialize = Expression._deserialize
+cdef Expression _true = Expression._scalar(True)
+
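+# Usage sketch (editorial example): expressions compose with Python operators
+# and round-trip through pickle via __reduce__ above.
+#
+#   import pickle
+#   import pyarrow.dataset as ds
+#   expr = (ds.field("a") > 1) & ds.field("b").isin(["x", "y"])
+#   assert pickle.loads(pickle.dumps(expr)).equals(expr)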
+
+cdef class Dataset(_Weakrefable):
+ """
+ Collection of data fragments and potentially child datasets.
+
+ Arrow Datasets allow you to query against data that has been split across
+ multiple files. This sharding of data may indicate partitioning, which
+ can accelerate queries that only touch some partitions (files).
+ """
+
+ cdef:
+ shared_ptr[CDataset] wrapped
+ CDataset* dataset
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef void init(self, const shared_ptr[CDataset]& sp):
+ self.wrapped = sp
+ self.dataset = sp.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CDataset]& sp):
+ type_name = frombytes(sp.get().type_name())
+
+ classes = {
+ 'union': UnionDataset,
+ 'filesystem': FileSystemDataset,
+ 'in-memory': InMemoryDataset,
+ }
+
+ class_ = classes.get(type_name, None)
+ if class_ is None:
+ raise TypeError(type_name)
+
+ cdef Dataset self = class_.__new__(class_)
+ self.init(sp)
+ return self
+
+ cdef shared_ptr[CDataset] unwrap(self) nogil:
+ return self.wrapped
+
+ @property
+ def partition_expression(self):
+ """
+ An Expression which evaluates to true for all data viewed by this
+ Dataset.
+ """
+ return Expression.wrap(self.dataset.partition_expression())
+
+ def replace_schema(self, Schema schema not None):
+ """
+ Return a copy of this Dataset with a different schema.
+
+ The copy will view the same Fragments. If the new schema is not
+ compatible with the original dataset's schema then an error will
+ be raised.
+ """
+ cdef shared_ptr[CDataset] copy = GetResultValue(
+ self.dataset.ReplaceSchema(pyarrow_unwrap_schema(schema)))
+ return Dataset.wrap(move(copy))
+
+ def get_fragments(self, Expression filter=None):
+ """Returns an iterator over the fragments in this dataset.
+
+ Parameters
+ ----------
+ filter : Expression, default None
+ Return fragments matching the optional filter, either using the
+ partition_expression or internal information like Parquet's
+ statistics.
+
+ Returns
+ -------
+ fragments : iterator of Fragment
+ """
+ cdef:
+ CExpression c_filter
+ CFragmentIterator c_iterator
+
+ if filter is None:
+ c_fragments = move(GetResultValue(self.dataset.GetFragments()))
+ else:
+ c_filter = _bind(filter, self.schema)
+ c_fragments = move(GetResultValue(
+ self.dataset.GetFragments(c_filter)))
+
+ for maybe_fragment in c_fragments:
+ yield Fragment.wrap(GetResultValue(move(maybe_fragment)))
+
+ def scanner(self, **kwargs):
+ """Builds a scan operation against the dataset.
+
+ Data is not loaded immediately. Instead, this produces a Scanner,
+ which exposes further operations (e.g. loading all data as a
+ table, counting rows).
+
+ Parameters
+ ----------
+ columns : list of str, default None
+ The columns to project. This can be a list of column names to
+ include (order and duplicates will be preserved), or a dictionary
+ with {new_column_name: expression} values for more advanced
+ projections.
+ The columns will be passed down to Datasets and corresponding data
+ fragments to avoid loading, copying, and deserializing columns
+ that will not be required further down the compute chain.
+ By default all of the available columns are projected. Raises
+ an exception if any of the referenced column names does not exist
+ in the dataset's Schema.
+ filter : Expression, default None
+ Scan will return only the rows matching the filter.
+ If possible the predicate will be pushed down to exploit the
+ partition information or internal metadata found in the data
+ source, e.g. Parquet statistics. Otherwise filters the loaded
+ RecordBatches before yielding them.
+ batch_size : int, default 1M
+ The maximum row count for scanned record batches. If scanned
+ record batches are overflowing memory then this method can be
+ called to reduce their size.
+ use_threads : bool, default True
+ If enabled, maximum parallelism will be used, as determined by
+ the number of available CPU cores.
+ use_async : bool, default False
+ If enabled, an async scanner will be used that should offer
+ better performance with high-latency/highly-parallel filesystems
+ (e.g. S3)
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required. If not specified, uses the
+ default pool.
+ fragment_scan_options : FragmentScanOptions, default None
+ Options specific to a particular scan and fragment type, which
+ can change between different scans of the same dataset.
+
+ Returns
+ -------
+ scanner : Scanner
+
+ Examples
+ --------
+ >>> import pyarrow.dataset as ds
+ >>> dataset = ds.dataset("path/to/dataset")
+
+ Selecting a subset of the columns:
+
+ >>> dataset.scanner(columns=["A", "B"]).to_table()
+
+ Projecting selected columns using an expression:
+
+ >>> dataset.scanner(columns={
+ ... "A_int": ds.field("A").cast("int64"),
+ ... }).to_table()
+
+ Filtering rows while scanning:
+
+ >>> dataset.scanner(filter=ds.field("A") > 0).to_table()
+ """
+ return Scanner.from_dataset(self, **kwargs)
+
+ def to_batches(self, **kwargs):
+ """Read the dataset as materialized record batches.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ record_batches : iterator of RecordBatch
+ """
+ return self.scanner(**kwargs).to_batches()
+
+ def to_table(self, **kwargs):
+ """Read the dataset to an arrow table.
+
+ Note that this method reads all the selected data from the dataset
+ into memory.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ table : Table instance
+ """
+ return self.scanner(**kwargs).to_table()
+
+ def take(self, object indices, **kwargs):
+ """Select rows of data by index.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ table : Table instance
+ """
+ return self.scanner(**kwargs).take(indices)
+
+ def head(self, int num_rows, **kwargs):
+ """Load the first N rows of the dataset.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ table : Table instance
+ """
+ return self.scanner(**kwargs).head(num_rows)
+
+ def count_rows(self, **kwargs):
+ """Count rows matching the scanner filter.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ count : int
+ """
+ return self.scanner(**kwargs).count_rows()
+
+ @property
+ def schema(self):
+ """The common schema of the full Dataset"""
+ return pyarrow_wrap_schema(self.dataset.schema())
+
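+# Usage sketch (editorial example): typical high-level use of a Dataset;
+# the "/data/events" path and the column names are hypothetical.
+#
+#   import pyarrow.dataset as ds
+#   dataset = ds.dataset("/data/events", format="parquet")
+#   n = dataset.count_rows(filter=ds.field("year") == 2021)
+#   preview = dataset.head(10, columns=["year", "value"])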
+
+cdef class InMemoryDataset(Dataset):
+ """
+ A Dataset wrapping in-memory data.
+
+ Parameters
+ ----------
+ source : RecordBatch, Table, list, iterable or RecordBatchReader
+ The data for this dataset. Can be a RecordBatch, Table, list of
+ RecordBatch/Table, iterable of RecordBatch, or a RecordBatchReader.
+ If an iterable is provided, the schema must also be provided.
+ schema : Schema, optional
+ Only required if passing an iterable as the source.
+ """
+
+ cdef:
+ CInMemoryDataset* in_memory_dataset
+
+ def __init__(self, source, Schema schema=None):
+ cdef:
+ RecordBatchReader reader
+ shared_ptr[CInMemoryDataset] in_memory_dataset
+
+ if isinstance(source, (pa.RecordBatch, pa.Table)):
+ source = [source]
+
+ if isinstance(source, (list, tuple)):
+ batches = []
+ for item in source:
+ if isinstance(item, pa.RecordBatch):
+ batches.append(item)
+ elif isinstance(item, pa.Table):
+ batches.extend(item.to_batches())
+ else:
+ raise TypeError(
+ 'Expected a list of tables or batches. The given list '
+ 'contains a ' + type(item).__name__)
+ if schema is None:
+ schema = item.schema
+ elif not schema.equals(item.schema):
+ raise ArrowTypeError(
+ f'Item has schema\n{item.schema}\nwhich does not '
+ f'match expected schema\n{schema}')
+ if not batches and schema is None:
+ raise ValueError('Must provide schema to construct in-memory '
+ 'dataset from an empty list')
+ table = pa.Table.from_batches(batches, schema=schema)
+ in_memory_dataset = make_shared[CInMemoryDataset](
+ pyarrow_unwrap_table(table))
+ else:
+ raise TypeError(
+ 'Expected a table, batch, or list of tables/batches '
+ 'instead of the given type: ' +
+ type(source).__name__
+ )
+
+ self.init(<shared_ptr[CDataset]> in_memory_dataset)
+
+ cdef void init(self, const shared_ptr[CDataset]& sp):
+ Dataset.init(self, sp)
+ self.in_memory_dataset = <CInMemoryDataset*> sp.get()
+
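+# Usage sketch (editorial example): wrapping an existing Table so it can be
+# queried through the dataset API.
+#
+#   import pyarrow as pa
+#   import pyarrow.dataset as ds
+#   table = pa.table({"a": [1, 2, 3]})
+#   dataset = ds.InMemoryDataset(table)
+#   assert dataset.to_table(filter=ds.field("a") > 1).num_rows == 2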
+
+cdef class UnionDataset(Dataset):
+ """
+ A Dataset wrapping child datasets.
+
+ Children's schemas must agree with the provided schema.
+
+ Parameters
+ ----------
+ schema : Schema
+ A known schema to conform to.
+ children : list of Dataset
+ One or more input children
+ """
+
+ cdef:
+ CUnionDataset* union_dataset
+
+ def __init__(self, Schema schema not None, children):
+ cdef:
+ Dataset child
+ CDatasetVector c_children
+ shared_ptr[CUnionDataset] union_dataset
+
+ for child in children:
+ c_children.push_back(child.wrapped)
+
+ union_dataset = GetResultValue(CUnionDataset.Make(
+ pyarrow_unwrap_schema(schema), move(c_children)))
+ self.init(<shared_ptr[CDataset]> union_dataset)
+
+ cdef void init(self, const shared_ptr[CDataset]& sp):
+ Dataset.init(self, sp)
+ self.union_dataset = <CUnionDataset*> sp.get()
+
+ def __reduce__(self):
+ return UnionDataset, (self.schema, self.children)
+
+ @property
+ def children(self):
+ cdef CDatasetVector children = self.union_dataset.children()
+ return [Dataset.wrap(children[i]) for i in range(children.size())]
+
+
+cdef class FileSystemDataset(Dataset):
+ """
+ A Dataset of file fragments.
+
+ A FileSystemDataset is composed of one or more FileFragment.
+
+ Parameters
+ ----------
+ fragments : list[Fragment]
+ List of fragments to consume.
+ schema : Schema
+ The top-level schema of the Dataset.
+ format : FileFormat
+ File format of the fragments, currently only ParquetFileFormat,
+ IpcFileFormat, and CsvFileFormat are supported.
+ filesystem : FileSystem
+ FileSystem of the fragments.
+ root_partition : Expression, optional
+ The top-level partition of the Dataset.
+ """
+
+ cdef:
+ CFileSystemDataset* filesystem_dataset
+
+ def __init__(self, fragments, Schema schema, FileFormat format,
+ FileSystem filesystem=None, root_partition=None):
+ cdef:
+ FileFragment fragment=None
+ vector[shared_ptr[CFileFragment]] c_fragments
+ CResult[shared_ptr[CDataset]] result
+ shared_ptr[CFileSystem] c_filesystem
+
+ if root_partition is None:
+ root_partition = _true
+ elif not isinstance(root_partition, Expression):
+ raise TypeError(
+ "Argument 'root_partition' has incorrect type (expected "
+ "Epression, got {0})".format(type(root_partition))
+ )
+
+ for fragment in fragments:
+ c_fragments.push_back(
+ static_pointer_cast[CFileFragment, CFragment](
+ fragment.unwrap()))
+
+ if filesystem is None:
+ filesystem = fragment.filesystem
+
+ if filesystem is not None:
+ c_filesystem = filesystem.unwrap()
+
+ result = CFileSystemDataset.Make(
+ pyarrow_unwrap_schema(schema),
+ (<Expression> root_partition).unwrap(),
+ format.unwrap(),
+ c_filesystem,
+ c_fragments
+ )
+ self.init(GetResultValue(result))
+
+ @property
+ def filesystem(self):
+ return FileSystem.wrap(self.filesystem_dataset.filesystem())
+
+ @property
+ def partitioning(self):
+ """
+ The partitioning of the Dataset source, if discovered.
+
+ If the FileSystemDataset is created using the ``dataset()`` factory
+ function with a partitioning specified, this will return the
+ finalized Partitioning object from the dataset discovery. In all
+ other cases, this returns None.
+ """
+ c_partitioning = self.filesystem_dataset.partitioning()
+ if c_partitioning.get() == nullptr:
+ return None
+ try:
+ return Partitioning.wrap(c_partitioning)
+ except TypeError:
+ # e.g. type_name "default"
+ return None
+
+ cdef void init(self, const shared_ptr[CDataset]& sp):
+ Dataset.init(self, sp)
+ self.filesystem_dataset = <CFileSystemDataset*> sp.get()
+
+ def __reduce__(self):
+ return FileSystemDataset, (
+ list(self.get_fragments()),
+ self.schema,
+ self.format,
+ self.filesystem,
+ self.partition_expression
+ )
+
+ @classmethod
+ def from_paths(cls, paths, schema=None, format=None,
+ filesystem=None, partitions=None, root_partition=None):
+ """A Dataset created from a list of paths on a particular filesystem.
+
+ Parameters
+ ----------
+ paths : list of str
+ List of file paths to create the fragments from.
+ schema : Schema
+ The top-level schema of the Dataset.
+ format : FileFormat
+ File format to create fragments from, currently only
+ ParquetFileFormat, IpcFileFormat, and CsvFileFormat are supported.
+ filesystem : FileSystem
+ The filesystem which files are from.
+ partitions : List[Expression], optional
+ Attach additional partition information for the file paths.
+ root_partition : Expression, optional
+ The top-level partition of the Dataset.
+ """
+ cdef:
+ FileFragment fragment
+
+ if root_partition is None:
+ root_partition = _true
+
+ for arg, class_, name in [
+ (schema, Schema, 'schema'),
+ (format, FileFormat, 'format'),
+ (filesystem, FileSystem, 'filesystem'),
+ (root_partition, Expression, 'root_partition')
+ ]:
+ if not isinstance(arg, class_):
+ raise TypeError(
+ "Argument '{0}' has incorrect type (expected {1}, "
+ "got {2})".format(name, class_.__name__, type(arg))
+ )
+
+ partitions = partitions or [_true] * len(paths)
+
+ if len(paths) != len(partitions):
+ raise ValueError(
+ 'The number of files resulting from paths_or_selector '
+ 'must be equal to the number of partitions.'
+ )
+
+ fragments = [
+ format.make_fragment(path, filesystem, partitions[i])
+ for i, path in enumerate(paths)
+ ]
+ return FileSystemDataset(fragments, schema, format,
+ filesystem, root_partition)
+
+ @property
+ def files(self):
+ """List of the files"""
+ cdef vector[c_string] files = self.filesystem_dataset.files()
+ return [frombytes(f) for f in files]
+
+ @property
+ def format(self):
+ """The FileFormat of this source."""
+ return FileFormat.wrap(self.filesystem_dataset.format())
+
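+# Usage sketch (editorial example): assembling a FileSystemDataset directly
+# from known paths; `schema` and the file names here are hypothetical.
+#
+#   import pyarrow.dataset as ds
+#   from pyarrow.fs import LocalFileSystem
+#   dataset = ds.FileSystemDataset.from_paths(
+#       ["part-0.parquet", "part-1.parquet"], schema=schema,
+#       format=ds.ParquetFileFormat(), filesystem=LocalFileSystem())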
+
+cdef CExpression _bind(Expression filter, Schema schema) except *:
+ assert schema is not None
+
+ if filter is None:
+ return _true.unwrap()
+
+ return GetResultValue(filter.unwrap().Bind(
+ deref(pyarrow_unwrap_schema(schema).get())))
+
+
+cdef class FileWriteOptions(_Weakrefable):
+
+ cdef:
+ shared_ptr[CFileWriteOptions] wrapped
+ CFileWriteOptions* c_options
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef void init(self, const shared_ptr[CFileWriteOptions]& sp):
+ self.wrapped = sp
+ self.c_options = sp.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CFileWriteOptions]& sp):
+ type_name = frombytes(sp.get().type_name())
+
+ classes = {
+ 'csv': CsvFileWriteOptions,
+ 'ipc': IpcFileWriteOptions,
+ 'parquet': ParquetFileWriteOptions,
+ }
+
+ class_ = classes.get(type_name, None)
+ if class_ is None:
+ raise TypeError(type_name)
+
+ cdef FileWriteOptions self = class_.__new__(class_)
+ self.init(sp)
+ return self
+
+ @property
+ def format(self):
+ return FileFormat.wrap(self.c_options.format())
+
+ cdef inline shared_ptr[CFileWriteOptions] unwrap(self):
+ return self.wrapped
+
+
+cdef class FileFormat(_Weakrefable):
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef void init(self, const shared_ptr[CFileFormat]& sp):
+ self.wrapped = sp
+ self.format = sp.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CFileFormat]& sp):
+ type_name = frombytes(sp.get().type_name())
+
+ classes = {
+ 'ipc': IpcFileFormat,
+ 'csv': CsvFileFormat,
+ 'parquet': ParquetFileFormat,
+ 'orc': _get_orc_fileformat(),
+ }
+
+ class_ = classes.get(type_name, None)
+ if class_ is None:
+ raise TypeError(type_name)
+
+ cdef FileFormat self = class_.__new__(class_)
+ self.init(sp)
+ return self
+
+ cdef inline shared_ptr[CFileFormat] unwrap(self):
+ return self.wrapped
+
+ def inspect(self, file, filesystem=None):
+ """Infer the schema of a file."""
+ c_source = _make_file_source(file, filesystem)
+ c_schema = GetResultValue(self.format.Inspect(c_source))
+ return pyarrow_wrap_schema(move(c_schema))
+
+ def make_fragment(self, file, filesystem=None,
+ Expression partition_expression=None):
+ """
+ Make a FileFragment from the given file using this FileFormat. The
+ partition expression, if given, may not reference fields absent from
+ the file's physical schema, which is inferred when no schema is
+ provided.
+ """
+ if partition_expression is None:
+ partition_expression = _true
+
+ c_source = _make_file_source(file, filesystem)
+ c_fragment = <shared_ptr[CFragment]> GetResultValue(
+ self.format.MakeFragment(move(c_source),
+ partition_expression.unwrap(),
+ <shared_ptr[CSchema]>nullptr))
+ return Fragment.wrap(move(c_fragment))
+
+ def make_write_options(self):
+ return FileWriteOptions.wrap(self.format.DefaultWriteOptions())
+
+ @property
+ def default_extname(self):
+ return frombytes(self.format.type_name())
+
+ @property
+ def default_fragment_scan_options(self):
+ return FragmentScanOptions.wrap(
+ self.wrapped.get().default_fragment_scan_options)
+
+ @default_fragment_scan_options.setter
+ def default_fragment_scan_options(self, FragmentScanOptions options):
+ if options is None:
+ self.wrapped.get().default_fragment_scan_options =\
+ <shared_ptr[CFragmentScanOptions]>nullptr
+ else:
+ self._set_default_fragment_scan_options(options)
+
+ cdef _set_default_fragment_scan_options(self, FragmentScanOptions options):
+ raise ValueError(f"Cannot set fragment scan options for "
+ f"'{options.type_name}' on {self.__class__.__name__}")
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return False
+
+
+cdef class Fragment(_Weakrefable):
+ """Fragment of data from a Dataset."""
+
+ cdef:
+ shared_ptr[CFragment] wrapped
+ CFragment* fragment
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef void init(self, const shared_ptr[CFragment]& sp):
+ self.wrapped = sp
+ self.fragment = sp.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CFragment]& sp):
+ type_name = frombytes(sp.get().type_name())
+
+ classes = {
+ # IpcFileFormat and CsvFileFormat do not have corresponding
+ # subclasses of FileFragment
+ 'ipc': FileFragment,
+ 'csv': FileFragment,
+ 'parquet': ParquetFileFragment,
+ }
+
+ class_ = classes.get(type_name, None)
+ if class_ is None:
+ class_ = Fragment
+
+ cdef Fragment self = class_.__new__(class_)
+ self.init(sp)
+ return self
+
+ cdef inline shared_ptr[CFragment] unwrap(self):
+ return self.wrapped
+
+ @property
+ def physical_schema(self):
+ """Return the physical schema of this Fragment. This schema can be
+ different from the dataset read schema."""
+ cdef:
+ CResult[shared_ptr[CSchema]] maybe_schema
+ with nogil:
+ maybe_schema = self.fragment.ReadPhysicalSchema()
+ return pyarrow_wrap_schema(GetResultValue(maybe_schema))
+
+ @property
+ def partition_expression(self):
+ """An Expression which evaluates to true for all data viewed by this
+ Fragment.
+ """
+ return Expression.wrap(self.fragment.partition_expression())
+
+ def scanner(self, Schema schema=None, **kwargs):
+ """Builds a scan operation against the dataset.
+
+ Data is not loaded immediately. Instead, this produces a Scanner,
+ which exposes further operations (e.g. loading all data as a
+ table, counting rows).
+
+ Parameters
+ ----------
+ schema : Schema
+ Schema to use for scanning. This is used to unify a Fragment to
+ its Dataset's schema. If not specified this will use the
+ Fragment's physical schema which might differ for each Fragment.
+ columns : list of str, default None
+ The columns to project. This can be a list of column names to
+ include (order and duplicates will be preserved), or a dictionary
+ with {new_column_name: expression} values for more advanced
+ projections.
+ The columns will be passed down to Datasets and corresponding data
+ fragments to avoid loading, copying, and deserializing columns
+ that will not be required further down the compute chain.
+ By default all of the available columns are projected. Raises
+ an exception if any of the referenced column names does not exist
+ in the dataset's Schema.
+ filter : Expression, default None
+ Scan will return only the rows matching the filter.
+ If possible the predicate will be pushed down to exploit the
+ partition information or internal metadata found in the data
+ source, e.g. Parquet statistics. Otherwise filters the loaded
+ RecordBatches before yielding them.
+ batch_size : int, default 1M
+ The maximum row count for scanned record batches. If scanned
+ record batches are overflowing memory then this method can be
+ called to reduce their size.
+ use_threads : bool, default True
+ If enabled, maximum parallelism will be used, as determined by
+ the number of available CPU cores.
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required. If not specified, uses the
+ default pool.
+ fragment_scan_options : FragmentScanOptions, default None
+ Options specific to a particular scan and fragment type, which
+ can change between different scans of the same dataset.
+
+ Returns
+ -------
+ scanner : Scanner
+
+ """
+ return Scanner.from_fragment(self, schema=schema, **kwargs)
+
+ def to_batches(self, Schema schema=None, **kwargs):
+ """Read the fragment as materialized record batches.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ record_batches : iterator of RecordBatch
+ """
+ return self.scanner(schema=schema, **kwargs).to_batches()
+
+ def to_table(self, Schema schema=None, **kwargs):
+ """Convert this Fragment into a Table.
+
+ Use this convenience utility with care. This will serially materialize
+ the Scan result in memory before creating the Table.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ table : Table
+ """
+ return self.scanner(schema=schema, **kwargs).to_table()
+
+ def take(self, object indices, **kwargs):
+ """Select rows of data by index.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ table : Table instance
+ """
+ return self.scanner(**kwargs).take(indices)
+
+ def head(self, int num_rows, **kwargs):
+ """Load the first N rows of the fragment.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ table : Table instance
+ """
+ return self.scanner(**kwargs).head(num_rows)
+
+ def count_rows(self, **kwargs):
+ """Count rows matching the scanner filter.
+
+ See scanner method parameters documentation.
+
+ Returns
+ -------
+ count : int
+ """
+ return self.scanner(**kwargs).count_rows()
+
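+# Usage sketch (editorial example): fragments can be scanned one at a time,
+# e.g. to bound memory use; assumes `dataset` is an existing Dataset and
+# `ds` is pyarrow.dataset.
+#
+#   for fragment in dataset.get_fragments(filter=ds.field("year") == 2021):
+#       table = fragment.to_table(filter=ds.field("year") == 2021)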
+
+cdef class FileFragment(Fragment):
+ """A Fragment representing a data file."""
+
+ cdef:
+ CFileFragment* file_fragment
+
+ cdef void init(self, const shared_ptr[CFragment]& sp):
+ Fragment.init(self, sp)
+ self.file_fragment = <CFileFragment*> sp.get()
+
+ def __repr__(self):
+ type_name = frombytes(self.fragment.type_name())
+ if type_name != "parquet":
+ typ = f" type={type_name}"
+ else:
+ # parquet has a subclass -> type embedded in class name
+ typ = ""
+ partition_dict = _get_partition_keys(self.partition_expression)
+ partition = ", ".join(
+ [f"{key}={val}" for key, val in partition_dict.items()]
+ )
+ if partition:
+ partition = f" partition=[{partition}]"
+ return "<pyarrow.dataset.{0}{1} path={2}{3}>".format(
+ self.__class__.__name__, typ, self.path, partition
+ )
+
+ def __reduce__(self):
+ buffer = self.buffer
+ return self.format.make_fragment, (
+ self.path if buffer is None else buffer,
+ self.filesystem,
+ self.partition_expression
+ )
+
+ @property
+ def path(self):
+ """
+ The path of the data file viewed by this fragment, if it views a
+ file. If instead it views a buffer, this will be "<Buffer>".
+ """
+ return frombytes(self.file_fragment.source().path())
+
+ @property
+ def filesystem(self):
+ """
+ The FileSystem containing the data file viewed by this fragment, if
+ it views a file. If instead it views a buffer, this will be None.
+ """
+ cdef:
+ shared_ptr[CFileSystem] c_fs
+ c_fs = self.file_fragment.source().filesystem()
+
+ if c_fs.get() == nullptr:
+ return None
+
+ return FileSystem.wrap(c_fs)
+
+ @property
+ def buffer(self):
+ """
+ The buffer viewed by this fragment, if it views a buffer. If
+ instead it views a file, this will be None.
+ """
+ cdef:
+ shared_ptr[CBuffer] c_buffer
+ c_buffer = self.file_fragment.source().buffer()
+
+ if c_buffer.get() == nullptr:
+ return None
+
+ return pyarrow_wrap_buffer(c_buffer)
+
+ @property
+ def format(self):
+ """
+ The format of the data file viewed by this fragment.
+ """
+ return FileFormat.wrap(self.file_fragment.format())
+
+
+class RowGroupInfo:
+ """
+ A wrapper class for RowGroup information.
+
+ Parameters
+ ----------
+ id : int
+ The group id.
+ metadata : RowGroupMetaData
+ The row group metadata.
+ schema : Schema
+ Schema of the rows.
+ """
+
+ def __init__(self, id, metadata, schema):
+ self.id = id
+ self.metadata = metadata
+ self.schema = schema
+
+ @property
+ def num_rows(self):
+ return self.metadata.num_rows
+
+ @property
+ def total_byte_size(self):
+ return self.metadata.total_byte_size
+
+ @property
+ def statistics(self):
+ def name_stats(i):
+ col = self.metadata.column(i)
+
+ stats = col.statistics
+ if stats is None or not stats.has_min_max:
+ return None, None
+
+ name = col.path_in_schema
+ field_index = self.schema.get_field_index(name)
+ if field_index < 0:
+ return None, None
+
+ typ = self.schema.field(field_index).type
+ return col.path_in_schema, {
+ 'min': pa.scalar(stats.min, type=typ).as_py(),
+ 'max': pa.scalar(stats.max, type=typ).as_py()
+ }
+
+ return {
+ name: stats for name, stats
+ in map(name_stats, range(self.metadata.num_columns))
+ if stats is not None
+ }
+
+ def __repr__(self):
+ return "RowGroupInfo({})".format(self.id)
+
+ def __eq__(self, other):
+ if isinstance(other, int):
+ return self.id == other
+ if not isinstance(other, RowGroupInfo):
+ return False
+ return self.id == other.id
+
+
+cdef class FragmentScanOptions(_Weakrefable):
+ """Scan options specific to a particular fragment and scan operation."""
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp):
+ self.wrapped = sp
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CFragmentScanOptions]& sp):
+ if not sp:
+ return None
+
+ type_name = frombytes(sp.get().type_name())
+
+ classes = {
+ 'csv': CsvFragmentScanOptions,
+ 'parquet': ParquetFragmentScanOptions,
+ }
+
+ class_ = classes.get(type_name, None)
+ if class_ is None:
+ raise TypeError(type_name)
+
+ cdef FragmentScanOptions self = class_.__new__(class_)
+ self.init(sp)
+ return self
+
+ @property
+ def type_name(self):
+ return frombytes(self.wrapped.get().type_name())
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return False
+
+
+cdef class ParquetFileFragment(FileFragment):
+ """A Fragment representing a parquet file."""
+
+ cdef:
+ CParquetFileFragment* parquet_file_fragment
+
+ cdef void init(self, const shared_ptr[CFragment]& sp):
+ FileFragment.init(self, sp)
+ self.parquet_file_fragment = <CParquetFileFragment*> sp.get()
+
+ def __reduce__(self):
+ buffer = self.buffer
+ row_groups = [row_group.id for row_group in self.row_groups]
+ return self.format.make_fragment, (
+ self.path if buffer is None else buffer,
+ self.filesystem,
+ self.partition_expression,
+ row_groups
+ )
+
+ def ensure_complete_metadata(self):
+ """
+ Ensure that all metadata (statistics, physical schema, ...) have
+ been read and cached in this fragment.
+ """
+ check_status(self.parquet_file_fragment.EnsureCompleteMetadata())
+
+ @property
+ def row_groups(self):
+ metadata = self.metadata
+ cdef vector[int] row_groups = self.parquet_file_fragment.row_groups()
+ return [RowGroupInfo(i, metadata.row_group(i), self.physical_schema)
+ for i in row_groups]
+
+ @property
+ def metadata(self):
+ self.ensure_complete_metadata()
+ cdef FileMetaData metadata = FileMetaData()
+ metadata.init(self.parquet_file_fragment.metadata())
+ return metadata
+
+ @property
+ def num_row_groups(self):
+ """
+ Return the number of row groups viewed by this fragment (not the
+ number of row groups in the origin file).
+ """
+ self.ensure_complete_metadata()
+ return self.parquet_file_fragment.row_groups().size()
+
+ def split_by_row_group(self, Expression filter=None,
+ Schema schema=None):
+ """
+ Split the fragment into multiple fragments.
+
+ Yield a Fragment wrapping each row group in this ParquetFileFragment.
+ Row groups whose metadata contradicts the optional filter will be
+ excluded.
+
+ Parameters
+ ----------
+ filter : Expression, default None
+ Only include the row groups which satisfy this predicate (using
+ the Parquet RowGroup statistics).
+ schema : Schema, default None
+ Schema to use when filtering row groups. Defaults to the
+ Fragment's physical schema.
+
+ Returns
+ -------
+ A list of Fragments
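+
+ Examples
+ --------
+ A minimal sketch; ``fragment`` is assumed to be an existing
+ ParquetFileFragment and the field name is hypothetical:
+
+ >>> import pyarrow.dataset as ds
+ >>> row_group_fragments = fragment.split_by_row_group(
+ ...     filter=ds.field("year") == 2020)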
+ """
+ cdef:
+ vector[shared_ptr[CFragment]] c_fragments
+ CExpression c_filter
+ shared_ptr[CFragment] c_fragment
+
+ schema = schema or self.physical_schema
+ c_filter = _bind(filter, schema)
+ with nogil:
+ c_fragments = move(GetResultValue(
+ self.parquet_file_fragment.SplitByRowGroup(move(c_filter))))
+
+ return [Fragment.wrap(c_fragment) for c_fragment in c_fragments]
+
+ def subset(self, Expression filter=None, Schema schema=None,
+ object row_group_ids=None):
+ """
+ Create a subset of the fragment (viewing a subset of the row groups).
+
+ Subset can be specified by either a filter predicate (with optional
+ schema) or by a list of row group IDs. Note that when using a filter,
+ the resulting fragment can be empty (viewing no row groups).
+
+ Parameters
+ ----------
+ filter : Expression, default None
+ Only include the row groups which satisfy this predicate (using
+ the Parquet RowGroup statistics).
+ schema : Schema, default None
+ Schema to use when filtering row groups. Defaults to the
+ Fragment's physical schema.
+ row_group_ids : list of ints
+ The row group IDs to include in the subset. Can only be specified
+ if `filter` is None.
+
+ Returns
+ -------
+ ParquetFileFragment
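+
+ Examples
+ --------
+ A minimal sketch; ``fragment`` is assumed to be an existing
+ ParquetFileFragment:
+
+ >>> subset_fragment = fragment.subset(row_group_ids=[0, 2])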
+ """
+ cdef:
+ CExpression c_filter
+ vector[int] c_row_group_ids
+ shared_ptr[CFragment] c_fragment
+
+ if filter is not None and row_group_ids is not None:
+ raise ValueError(
+ "Cannot specify both 'filter' and 'row_group_ids'."
+ )
+
+ if filter is not None:
+ schema = schema or self.physical_schema
+ c_filter = _bind(filter, schema)
+ with nogil:
+ c_fragment = move(GetResultValue(
+ self.parquet_file_fragment.SubsetWithFilter(
+ move(c_filter))))
+ elif row_group_ids is not None:
+ c_row_group_ids = [
+ <int> row_group for row_group in sorted(set(row_group_ids))
+ ]
+ with nogil:
+ c_fragment = move(GetResultValue(
+ self.parquet_file_fragment.SubsetWithIds(
+ move(c_row_group_ids))))
+ else:
+ raise ValueError(
+ "Need to specify one of 'filter' or 'row_group_ids'"
+ )
+
+ return Fragment.wrap(c_fragment)
+
+
+cdef class ParquetReadOptions(_Weakrefable):
+ """
+ Parquet format specific options for reading.
+
+ Parameters
+ ----------
+ dictionary_columns : list of string, default None
+ Names of columns which should be dictionary encoded as
+ they are read.
+ coerce_int96_timestamp_unit : str, default None
+ Cast timestamps that are stored in INT96 format to a particular
+ resolution (e.g. 'ms'). Setting to None is equivalent to 'ns'
+ and therefore INT96 timestamps will be inferred as timestamps
+ in nanoseconds.
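+
+ Examples
+ --------
+ A minimal sketch; the column name is hypothetical:
+
+ >>> from pyarrow.dataset import ParquetReadOptions
+ >>> options = ParquetReadOptions(dictionary_columns=["genre"],
+ ...                              coerce_int96_timestamp_unit="ms")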
+ """
+
+ cdef public:
+ set dictionary_columns
+ TimeUnit _coerce_int96_timestamp_unit
+
+ # Also see _PARQUET_READ_OPTIONS
+ def __init__(self, dictionary_columns=None,
+ coerce_int96_timestamp_unit=None):
+ self.dictionary_columns = set(dictionary_columns or set())
+ self.coerce_int96_timestamp_unit = coerce_int96_timestamp_unit
+
+ @property
+ def coerce_int96_timestamp_unit(self):
+ return timeunit_to_string(self._coerce_int96_timestamp_unit)
+
+ @coerce_int96_timestamp_unit.setter
+ def coerce_int96_timestamp_unit(self, unit):
+ if unit is not None:
+ self._coerce_int96_timestamp_unit = string_to_timeunit(unit)
+ else:
+ self._coerce_int96_timestamp_unit = TimeUnit_NANO
+
+ def equals(self, ParquetReadOptions other):
+ return (self.dictionary_columns == other.dictionary_columns and
+ self.coerce_int96_timestamp_unit ==
+ other.coerce_int96_timestamp_unit)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return False
+
+ def __repr__(self):
+ return (
+ f"<ParquetReadOptions"
+ f" dictionary_columns={self.dictionary_columns}"
+ f" coerce_int96_timestamp_unit={self.coerce_int96_timestamp_unit}>"
+ )
+
+
+cdef class ParquetFileWriteOptions(FileWriteOptions):
+
+ cdef:
+ CParquetFileWriteOptions* parquet_options
+ object _properties
+
+ def update(self, **kwargs):
+ arrow_fields = {
+ "use_deprecated_int96_timestamps",
+ "coerce_timestamps",
+ "allow_truncated_timestamps",
+ }
+
+ setters = set()
+ for name, value in kwargs.items():
+ if name not in self._properties:
+ raise TypeError("unexpected parquet write option: " + name)
+ self._properties[name] = value
+ if name in arrow_fields:
+ setters.add(self._set_arrow_properties)
+ else:
+ setters.add(self._set_properties)
+
+ for setter in setters:
+ setter()
+
+ def _set_properties(self):
+ cdef CParquetFileWriteOptions* opts = self.parquet_options
+
+ opts.writer_properties = _create_writer_properties(
+ use_dictionary=self._properties["use_dictionary"],
+ compression=self._properties["compression"],
+ version=self._properties["version"],
+ write_statistics=self._properties["write_statistics"],
+ data_page_size=self._properties["data_page_size"],
+ compression_level=self._properties["compression_level"],
+ use_byte_stream_split=(
+ self._properties["use_byte_stream_split"]
+ ),
+ data_page_version=self._properties["data_page_version"],
+ )
+
+ def _set_arrow_properties(self):
+ cdef CParquetFileWriteOptions* opts = self.parquet_options
+
+ opts.arrow_writer_properties = _create_arrow_writer_properties(
+ use_deprecated_int96_timestamps=(
+ self._properties["use_deprecated_int96_timestamps"]
+ ),
+ coerce_timestamps=self._properties["coerce_timestamps"],
+ allow_truncated_timestamps=(
+ self._properties["allow_truncated_timestamps"]
+ ),
+ writer_engine_version="V2",
+ use_compliant_nested_type=(
+ self._properties["use_compliant_nested_type"]
+ )
+ )
+
+ cdef void init(self, const shared_ptr[CFileWriteOptions]& sp):
+ FileWriteOptions.init(self, sp)
+ self.parquet_options = <CParquetFileWriteOptions*> sp.get()
+ self._properties = dict(
+ use_dictionary=True,
+ compression="snappy",
+ version="1.0",
+ write_statistics=None,
+ data_page_size=None,
+ compression_level=None,
+ use_byte_stream_split=False,
+ data_page_version="1.0",
+ use_deprecated_int96_timestamps=False,
+ coerce_timestamps=None,
+ allow_truncated_timestamps=False,
+ use_compliant_nested_type=False,
+ )
+ self._set_properties()
+ self._set_arrow_properties()
+
+
+cdef set _PARQUET_READ_OPTIONS = {
+ 'dictionary_columns', 'coerce_int96_timestamp_unit'
+}
+
+
+cdef class ParquetFileFormat(FileFormat):
+ """
+ FileFormat for Parquet
+
+ Parameters
+ ----------
+ read_options : ParquetReadOptions
+ Read options for the file.
+ default_fragment_scan_options : ParquetFragmentScanOptions
+ Scan options for the file.
+ **kwargs : dict
+ Additional options for read options or scan options.
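+
+ Examples
+ --------
+ A minimal sketch; the dictionary column name is hypothetical:
+
+ >>> from pyarrow.dataset import ParquetFileFormat
+ >>> parquet_format = ParquetFileFormat(dictionary_columns={"genre"},
+ ...                                    pre_buffer=True)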
+ """
+
+ cdef:
+ CParquetFileFormat* parquet_format
+
+ def __init__(self, read_options=None,
+ default_fragment_scan_options=None, **kwargs):
+ cdef:
+ shared_ptr[CParquetFileFormat] wrapped
+ CParquetFileFormatReaderOptions* options
+
+ # Read/scan options
+ read_options_args = {option: kwargs[option] for option in kwargs
+ if option in _PARQUET_READ_OPTIONS}
+ scan_args = {option: kwargs[option] for option in kwargs
+ if option not in _PARQUET_READ_OPTIONS}
+ if read_options and read_options_args:
+ duplicates = ', '.join(sorted(read_options_args))
+ raise ValueError(f'If `read_options` is given, '
+ f'cannot specify {duplicates}')
+ if default_fragment_scan_options and scan_args:
+ duplicates = ', '.join(sorted(scan_args))
+ raise ValueError(f'If `default_fragment_scan_options` is given, '
+ f'cannot specify {duplicates}')
+
+ if read_options is None:
+ read_options = ParquetReadOptions(**read_options_args)
+ elif isinstance(read_options, dict):
+ # For backwards compatibility
+ duplicates = []
+ for option, value in read_options.items():
+ if option in _PARQUET_READ_OPTIONS:
+ read_options_args[option] = value
+ else:
+ duplicates.append(option)
+ scan_args[option] = value
+ if duplicates:
+ duplicates = ", ".join(duplicates)
+ warnings.warn(f'The scan options {duplicates} should be '
+ 'specified directly as keyword arguments')
+ read_options = ParquetReadOptions(**read_options_args)
+ elif not isinstance(read_options, ParquetReadOptions):
+ raise TypeError('`read_options` must be either a dictionary or an '
+ 'instance of ParquetReadOptions')
+
+ if default_fragment_scan_options is None:
+ default_fragment_scan_options = ParquetFragmentScanOptions(
+ **scan_args)
+ elif isinstance(default_fragment_scan_options, dict):
+ default_fragment_scan_options = ParquetFragmentScanOptions(
+ **default_fragment_scan_options)
+ elif not isinstance(default_fragment_scan_options,
+ ParquetFragmentScanOptions):
+ raise TypeError('`default_fragment_scan_options` must be either a '
+ 'dictionary or an instance of '
+ 'ParquetFragmentScanOptions')
+
+ wrapped = make_shared[CParquetFileFormat]()
+ options = &(wrapped.get().reader_options)
+ if read_options.dictionary_columns is not None:
+ for column in read_options.dictionary_columns:
+ options.dict_columns.insert(tobytes(column))
+ options.coerce_int96_timestamp_unit = \
+ read_options._coerce_int96_timestamp_unit
+
+ self.init(<shared_ptr[CFileFormat]> wrapped)
+ self.default_fragment_scan_options = default_fragment_scan_options
+
+ cdef void init(self, const shared_ptr[CFileFormat]& sp):
+ FileFormat.init(self, sp)
+ self.parquet_format = <CParquetFileFormat*> sp.get()
+
+ @property
+ def read_options(self):
+ cdef CParquetFileFormatReaderOptions* options
+ options = &self.parquet_format.reader_options
+ parquet_read_options = ParquetReadOptions(
+ dictionary_columns={frombytes(col)
+ for col in options.dict_columns},
+ )
+ # Read options getter/setter works with strings so setting
+ # the private property which uses the C Type
+ parquet_read_options._coerce_int96_timestamp_unit = \
+ options.coerce_int96_timestamp_unit
+ return parquet_read_options
+
+ def make_write_options(self, **kwargs):
+ opts = FileFormat.make_write_options(self)
+ (<ParquetFileWriteOptions> opts).update(**kwargs)
+ return opts
+
+ cdef _set_default_fragment_scan_options(self, FragmentScanOptions options):
+ if options.type_name == 'parquet':
+ self.parquet_format.default_fragment_scan_options = options.wrapped
+ else:
+ super()._set_default_fragment_scan_options(options)
+
+ def equals(self, ParquetFileFormat other):
+ return (
+ self.read_options.equals(other.read_options) and
+ self.default_fragment_scan_options ==
+ other.default_fragment_scan_options
+ )
+
+ def __reduce__(self):
+ return ParquetFileFormat, (self.read_options,
+ self.default_fragment_scan_options)
+
+ def __repr__(self):
+ return f"<ParquetFileFormat read_options={self.read_options}>"
+
+ def make_fragment(self, file, filesystem=None,
+ Expression partition_expression=None, row_groups=None):
+ cdef:
+ vector[int] c_row_groups
+
+ if partition_expression is None:
+ partition_expression = _true
+
+ if row_groups is None:
+ return super().make_fragment(file, filesystem,
+ partition_expression)
+
+ c_source = _make_file_source(file, filesystem)
+ c_row_groups = [<int> row_group for row_group in set(row_groups)]
+
+ c_fragment = <shared_ptr[CFragment]> GetResultValue(
+ self.parquet_format.MakeFragment(move(c_source),
+ partition_expression.unwrap(),
+ <shared_ptr[CSchema]>nullptr,
+ move(c_row_groups)))
+ return Fragment.wrap(move(c_fragment))
+
+
+cdef class ParquetFragmentScanOptions(FragmentScanOptions):
+ """
+ Scan-specific options for Parquet fragments.
+
+ Parameters
+ ----------
+ use_buffered_stream : bool, default False
+ Read files through buffered input streams rather than loading entire
+ row groups at once. This may be enabled to reduce memory overhead.
+ Disabled by default.
+ buffer_size : int, default 8192
+ Size of buffered stream, if enabled. Default is 8KB.
+ pre_buffer : bool, default False
+ If enabled, pre-buffer the raw Parquet data instead of issuing one
+ read per column chunk. This can improve performance on high-latency
+ filesystems.
+ enable_parallel_column_conversion : bool, default False
+ EXPERIMENTAL: Parallelize conversion across columns. This option is
+ ignored if a scan is already parallelized across input files to avoid
+ thread contention. This option will be removed after support is added
+ for simultaneous parallelization across files and columns.
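+
+ Examples
+ --------
+ A minimal sketch:
+
+ >>> from pyarrow.dataset import ParquetFragmentScanOptions
+ >>> scan_options = ParquetFragmentScanOptions(pre_buffer=True)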
+ """
+
+ cdef:
+ CParquetFragmentScanOptions* parquet_options
+
+ # Avoid mistakenly creating attributes
+ __slots__ = ()
+
+ def __init__(self, bint use_buffered_stream=False,
+ buffer_size=8192,
+ bint pre_buffer=False,
+ bint enable_parallel_column_conversion=False):
+ self.init(shared_ptr[CFragmentScanOptions](
+ new CParquetFragmentScanOptions()))
+ self.use_buffered_stream = use_buffered_stream
+ self.buffer_size = buffer_size
+ self.pre_buffer = pre_buffer
+ self.enable_parallel_column_conversion = \
+ enable_parallel_column_conversion
+
+ cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp):
+ FragmentScanOptions.init(self, sp)
+ self.parquet_options = <CParquetFragmentScanOptions*> sp.get()
+
+ cdef CReaderProperties* reader_properties(self):
+ return self.parquet_options.reader_properties.get()
+
+ cdef ArrowReaderProperties* arrow_reader_properties(self):
+ return self.parquet_options.arrow_reader_properties.get()
+
+ @property
+ def use_buffered_stream(self):
+ return self.reader_properties().is_buffered_stream_enabled()
+
+ @use_buffered_stream.setter
+ def use_buffered_stream(self, bint use_buffered_stream):
+ if use_buffered_stream:
+ self.reader_properties().enable_buffered_stream()
+ else:
+ self.reader_properties().disable_buffered_stream()
+
+ @property
+ def buffer_size(self):
+ return self.reader_properties().buffer_size()
+
+ @buffer_size.setter
+ def buffer_size(self, buffer_size):
+ if buffer_size <= 0:
+ raise ValueError("Buffer size must be larger than zero")
+ self.reader_properties().set_buffer_size(buffer_size)
+
+ @property
+ def pre_buffer(self):
+ return self.arrow_reader_properties().pre_buffer()
+
+ @pre_buffer.setter
+ def pre_buffer(self, bint pre_buffer):
+ self.arrow_reader_properties().set_pre_buffer(pre_buffer)
+
+ @property
+ def enable_parallel_column_conversion(self):
+ return self.parquet_options.enable_parallel_column_conversion
+
+ @enable_parallel_column_conversion.setter
+ def enable_parallel_column_conversion(
+ self, bint enable_parallel_column_conversion):
+ self.parquet_options.enable_parallel_column_conversion = \
+ enable_parallel_column_conversion
+
+ def equals(self, ParquetFragmentScanOptions other):
+ return (
+ self.use_buffered_stream == other.use_buffered_stream and
+ self.buffer_size == other.buffer_size and
+ self.pre_buffer == other.pre_buffer and
+ self.enable_parallel_column_conversion ==
+ other.enable_parallel_column_conversion
+ )
+
+ def __reduce__(self):
+ return ParquetFragmentScanOptions, (
+ self.use_buffered_stream, self.buffer_size, self.pre_buffer,
+ self.enable_parallel_column_conversion
+ )
+
+
+cdef class IpcFileWriteOptions(FileWriteOptions):
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+
+cdef class IpcFileFormat(FileFormat):
+
+ def __init__(self):
+ self.init(shared_ptr[CFileFormat](new CIpcFileFormat()))
+
+ def equals(self, IpcFileFormat other):
+ return True
+
+ @property
+ def default_extname(self):
+ return "feather"
+
+ def __reduce__(self):
+ return IpcFileFormat, tuple()
+
+
+cdef class CsvFileFormat(FileFormat):
+ """
+ FileFormat for CSV files.
+
+ Parameters
+ ----------
+ parse_options : ParseOptions
+ Options regarding CSV parsing.
+ convert_options : ConvertOptions
+ Options regarding value conversion.
+ read_options : ReadOptions
+ General read options.
+ default_fragment_scan_options : CsvFragmentScanOptions
+ Default options for fragment scanning.
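+
+ Examples
+ --------
+ A minimal sketch; the delimiter is arbitrary:
+
+ >>> from pyarrow import csv
+ >>> from pyarrow.dataset import CsvFileFormat
+ >>> csv_format = CsvFileFormat(
+ ...     parse_options=csv.ParseOptions(delimiter=";"))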
+ """
+ cdef:
+ CCsvFileFormat* csv_format
+
+ # Avoid mistakenly creating attributes
+ __slots__ = ()
+
+ def __init__(self, ParseOptions parse_options=None,
+ default_fragment_scan_options=None,
+ ConvertOptions convert_options=None,
+ ReadOptions read_options=None):
+ self.init(shared_ptr[CFileFormat](new CCsvFileFormat()))
+ if parse_options is not None:
+ self.parse_options = parse_options
+ if convert_options is not None or read_options is not None:
+ if default_fragment_scan_options:
+ raise ValueError('If `default_fragment_scan_options` is '
+ 'given, cannot specify convert_options '
+ 'or read_options')
+ self.default_fragment_scan_options = CsvFragmentScanOptions(
+ convert_options=convert_options, read_options=read_options)
+ elif isinstance(default_fragment_scan_options, dict):
+ self.default_fragment_scan_options = CsvFragmentScanOptions(
+ **default_fragment_scan_options)
+ elif isinstance(default_fragment_scan_options, CsvFragmentScanOptions):
+ self.default_fragment_scan_options = default_fragment_scan_options
+ elif default_fragment_scan_options is not None:
+ raise TypeError('`default_fragment_scan_options` must be either '
+ 'a dictionary or an instance of '
+ 'CsvFragmentScanOptions')
+
+ cdef void init(self, const shared_ptr[CFileFormat]& sp):
+ FileFormat.init(self, sp)
+ self.csv_format = <CCsvFileFormat*> sp.get()
+
+ def make_write_options(self, **kwargs):
+ cdef CsvFileWriteOptions opts = \
+ <CsvFileWriteOptions> FileFormat.make_write_options(self)
+ opts.write_options = WriteOptions(**kwargs)
+ return opts
+
+ @property
+ def parse_options(self):
+ return ParseOptions.wrap(self.csv_format.parse_options)
+
+ @parse_options.setter
+ def parse_options(self, ParseOptions parse_options not None):
+ self.csv_format.parse_options = deref(parse_options.options)
+
+ cdef _set_default_fragment_scan_options(self, FragmentScanOptions options):
+ if options.type_name == 'csv':
+ self.csv_format.default_fragment_scan_options = options.wrapped
+ else:
+ super()._set_default_fragment_scan_options(options)
+
+ def equals(self, CsvFileFormat other):
+ return (
+ self.parse_options.equals(other.parse_options) and
+ self.default_fragment_scan_options ==
+ other.default_fragment_scan_options)
+
+ def __reduce__(self):
+ return CsvFileFormat, (self.parse_options,
+ self.default_fragment_scan_options)
+
+ def __repr__(self):
+ return f"<CsvFileFormat parse_options={self.parse_options}>"
+
+
+cdef class CsvFragmentScanOptions(FragmentScanOptions):
+ """
+ Scan-specific options for CSV fragments.
+
+ Parameters
+ ----------
+ convert_options : ConvertOptions
+ Options regarding value conversion.
+ read_options : ReadOptions
+ General read options.
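+
+ Examples
+ --------
+ A minimal sketch:
+
+ >>> from pyarrow import csv
+ >>> from pyarrow.dataset import CsvFragmentScanOptions
+ >>> scan_options = CsvFragmentScanOptions(
+ ...     convert_options=csv.ConvertOptions(strings_can_be_null=True))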
+ """
+
+ cdef:
+ CCsvFragmentScanOptions* csv_options
+
+ # Avoid mistakenly creating attributes
+ __slots__ = ()
+
+ def __init__(self, ConvertOptions convert_options=None,
+ ReadOptions read_options=None):
+ self.init(shared_ptr[CFragmentScanOptions](
+ new CCsvFragmentScanOptions()))
+ if convert_options is not None:
+ self.convert_options = convert_options
+ if read_options is not None:
+ self.read_options = read_options
+
+ cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp):
+ FragmentScanOptions.init(self, sp)
+ self.csv_options = <CCsvFragmentScanOptions*> sp.get()
+
+ @property
+ def convert_options(self):
+ return ConvertOptions.wrap(self.csv_options.convert_options)
+
+ @convert_options.setter
+ def convert_options(self, ConvertOptions convert_options not None):
+ self.csv_options.convert_options = deref(convert_options.options)
+
+ @property
+ def read_options(self):
+ return ReadOptions.wrap(self.csv_options.read_options)
+
+ @read_options.setter
+ def read_options(self, ReadOptions read_options not None):
+ self.csv_options.read_options = deref(read_options.options)
+
+ def equals(self, CsvFragmentScanOptions other):
+ return (
+ other and
+ self.convert_options.equals(other.convert_options) and
+ self.read_options.equals(other.read_options))
+
+ def __reduce__(self):
+ return CsvFragmentScanOptions, (self.convert_options,
+ self.read_options)
+
+
+cdef class CsvFileWriteOptions(FileWriteOptions):
+ cdef:
+ CCsvFileWriteOptions* csv_options
+ object _properties
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ @property
+ def write_options(self):
+ return WriteOptions.wrap(deref(self.csv_options.write_options))
+
+ @write_options.setter
+ def write_options(self, WriteOptions write_options not None):
+ self.csv_options.write_options.reset(
+ new CCSVWriteOptions(deref(write_options.options)))
+
+ cdef void init(self, const shared_ptr[CFileWriteOptions]& sp):
+ FileWriteOptions.init(self, sp)
+ self.csv_options = <CCsvFileWriteOptions*> sp.get()
+
+
+cdef class Partitioning(_Weakrefable):
+
+ cdef:
+ shared_ptr[CPartitioning] wrapped
+ CPartitioning* partitioning
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef init(self, const shared_ptr[CPartitioning]& sp):
+ self.wrapped = sp
+ self.partitioning = sp.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CPartitioning]& sp):
+ type_name = frombytes(sp.get().type_name())
+
+ classes = {
+ 'directory': DirectoryPartitioning,
+ 'hive': HivePartitioning,
+ }
+
+ class_ = classes.get(type_name, None)
+ if class_ is None:
+ raise TypeError(type_name)
+
+ cdef Partitioning self = class_.__new__(class_)
+ self.init(sp)
+ return self
+
+ cdef inline shared_ptr[CPartitioning] unwrap(self):
+ return self.wrapped
+
+ def parse(self, path):
+ cdef CResult[CExpression] result
+ result = self.partitioning.Parse(tobytes(path))
+ return Expression.wrap(GetResultValue(result))
+
+ @property
+ def schema(self):
+ """The arrow Schema attached to the partitioning."""
+ return pyarrow_wrap_schema(self.partitioning.schema())
+
+
+cdef class PartitioningFactory(_Weakrefable):
+
+ cdef:
+ shared_ptr[CPartitioningFactory] wrapped
+ CPartitioningFactory* factory
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef init(self, const shared_ptr[CPartitioningFactory]& sp):
+ self.wrapped = sp
+ self.factory = sp.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CPartitioningFactory]& sp):
+ cdef PartitioningFactory self = PartitioningFactory.__new__(
+ PartitioningFactory
+ )
+ self.init(sp)
+ return self
+
+ cdef inline shared_ptr[CPartitioningFactory] unwrap(self):
+ return self.wrapped
+
+ @property
+ def type_name(self):
+ return frombytes(self.factory.type_name())
+
+
+cdef vector[shared_ptr[CArray]] _partitioning_dictionaries(
+ Schema schema, dictionaries) except *:
+ cdef:
+ vector[shared_ptr[CArray]] c_dictionaries
+
+ dictionaries = dictionaries or {}
+
+ for field in schema:
+ dictionary = dictionaries.get(field.name)
+
+ if (isinstance(field.type, pa.DictionaryType) and
+ dictionary is not None):
+ c_dictionaries.push_back(pyarrow_unwrap_array(dictionary))
+ else:
+ c_dictionaries.push_back(<shared_ptr[CArray]> nullptr)
+
+ return c_dictionaries
+
+
+cdef class DirectoryPartitioning(Partitioning):
+ """
+ A Partitioning based on a specified Schema.
+
+ The DirectoryPartitioning expects one segment in the file path for each
+ field in the schema (all fields are required to be present).
+ For example, given schema<year:int16, month:int8>, the path "/2009/11"
+ would be parsed to ("year"_ == 2009 and "month"_ == 11).
+
+ Parameters
+ ----------
+ schema : Schema
+ The schema that describes the partitions present in the file path.
+ dictionaries : Dict[str, Array]
+ If the type of any field of `schema` is a dictionary type, the
+ corresponding entry of `dictionaries` must be an array containing
+ every value which may be taken by the corresponding column or an
+ error will be raised in parsing.
+ segment_encoding : str, default "uri"
+ After splitting paths into segments, decode the segments. Valid
+ values are "uri" (URI-decode segments) and "none" (leave as-is).
+
+ Returns
+ -------
+ DirectoryPartitioning
+
+ Examples
+ --------
+ >>> from pyarrow.dataset import DirectoryPartitioning
+ >>> partitioning = DirectoryPartitioning(
+ ... pa.schema([("year", pa.int16()), ("month", pa.int8())]))
+ >>> print(partitioning.parse("/2009/11"))
+ ((year == 2009:int16) and (month == 11:int8))
+ """
+
+ cdef:
+ CDirectoryPartitioning* directory_partitioning
+
+ def __init__(self, Schema schema not None, dictionaries=None,
+ segment_encoding="uri"):
+ cdef:
+ shared_ptr[CDirectoryPartitioning] c_partitioning
+ CKeyValuePartitioningOptions c_options
+
+ c_options.segment_encoding = _get_segment_encoding(segment_encoding)
+ c_partitioning = make_shared[CDirectoryPartitioning](
+ pyarrow_unwrap_schema(schema),
+ _partitioning_dictionaries(schema, dictionaries),
+ c_options,
+ )
+ self.init(<shared_ptr[CPartitioning]> c_partitioning)
+
+ cdef init(self, const shared_ptr[CPartitioning]& sp):
+ Partitioning.init(self, sp)
+ self.directory_partitioning = <CDirectoryPartitioning*> sp.get()
+
+ @staticmethod
+ def discover(field_names=None, infer_dictionary=False,
+ max_partition_dictionary_size=0,
+ schema=None, segment_encoding="uri"):
+ """
+ Discover a DirectoryPartitioning.
+
+ Parameters
+ ----------
+ field_names : list of str
+ The names to associate with the values from the subdirectory names.
+ If schema is given, will be populated from the schema.
+ infer_dictionary : bool, default False
+ When inferring a schema for partition fields, yield dictionary
+ encoded types instead of plain types. This can be more efficient
+ when materializing virtual columns, and Expressions parsed by the
+ finished Partitioning will include dictionaries of all unique
+ inspected values for each field.
+ max_partition_dictionary_size : int, default 0
+ Synonymous with infer_dictionary for backwards compatibility with
+ 1.0: setting this to -1 or None is equivalent to passing
+ infer_dictionary=True.
+ schema : Schema, default None
+ Use this schema instead of inferring a schema from partition
+ values. Partition values will be validated against this schema
+ before accumulation into the Partitioning's dictionary.
+ segment_encoding : str, default "uri"
+ After splitting paths into segments, decode the segments. Valid
+ values are "uri" (URI-decode segments) and "none" (leave as-is).
+
+ Returns
+ -------
+ PartitioningFactory
+ To be used in the FileSystemFactoryOptions.
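+
+ Examples
+ --------
+ A minimal sketch; the field names are hypothetical:
+
+ >>> from pyarrow.dataset import DirectoryPartitioning
+ >>> factory = DirectoryPartitioning.discover(
+ ...     field_names=["year", "month"])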
+ """
+ cdef:
+ CPartitioningFactoryOptions c_options
+ vector[c_string] c_field_names
+
+ if max_partition_dictionary_size in {-1, None}:
+ infer_dictionary = True
+ elif max_partition_dictionary_size != 0:
+ raise NotImplementedError("max_partition_dictionary_size must be "
+ "0, -1, or None")
+
+ if infer_dictionary:
+ c_options.infer_dictionary = True
+
+ if schema:
+ c_options.schema = pyarrow_unwrap_schema(schema)
+ c_field_names = [tobytes(f.name) for f in schema]
+ elif not field_names:
+ raise ValueError(
+ "Neither field_names nor schema was passed; "
+ "cannot infer field_names")
+ else:
+ c_field_names = [tobytes(s) for s in field_names]
+
+ c_options.segment_encoding = _get_segment_encoding(segment_encoding)
+
+ return PartitioningFactory.wrap(
+ CDirectoryPartitioning.MakeFactory(c_field_names, c_options))
+
+ @property
+ def dictionaries(self):
+ """
+ The unique values for each partition field, if available.
+
+ Those values are only available if the Partitioning object was
+ created through dataset discovery from a PartitioningFactory, or
+ if the dictionaries were manually specified in the constructor.
+ If not available, this returns None.
+ """
+ cdef vector[shared_ptr[CArray]] c_arrays
+ c_arrays = self.directory_partitioning.dictionaries()
+ res = []
+ for arr in c_arrays:
+ if arr.get() == nullptr:
+ # Partitioning object has not been created through
+ # inspected Factory
+ return None
+ res.append(pyarrow_wrap_array(arr))
+ return res
+
+
+cdef class HivePartitioning(Partitioning):
+ """
+ A Partitioning for "/$key=$value/" nested directories as found in
+ Apache Hive.
+
+ Multi-level, directory based partitioning scheme originating from
+ Apache Hive with all data files stored in the leaf directories. Data is
+ partitioned by static values of a particular column in the schema.
+ Partition keys are represented in the form $key=$value in directory names.
+ Field order is ignored, as are missing or unrecognized field names.
+
+ For example, given schema<year:int16, month:int8, day:int8>, a possible
+ path would be "/year=2009/month=11/day=15".
+
+ Parameters
+ ----------
+ schema : Schema
+ The schema that describes the partitions present in the file path.
+ dictionaries : Dict[str, Array]
+ If the type of any field of `schema` is a dictionary type, the
+ corresponding entry of `dictionaries` must be an array containing
+ every value which may be taken by the corresponding column or an
+ error will be raised in parsing.
+ null_fallback : str, default "__HIVE_DEFAULT_PARTITION__"
+ If any field is None then this fallback will be used as a label.
+ segment_encoding : str, default "uri"
+ After splitting paths into segments, decode the segments. Valid
+ values are "uri" (URI-decode segments) and "none" (leave as-is).
+
+ Returns
+ -------
+ HivePartitioning
+
+ Examples
+ --------
+ >>> from pyarrow.dataset import HivePartitioning
+ >>> partitioning = HivePartitioning(
+ ... pa.schema([("year", pa.int16()), ("month", pa.int8())]))
+ >>> print(partitioning.parse("/year=2009/month=11"))
+ ((year == 2009:int16) and (month == 11:int8))
+
+ """
+
+ cdef:
+ CHivePartitioning* hive_partitioning
+
+ def __init__(self,
+ Schema schema not None,
+ dictionaries=None,
+ null_fallback="__HIVE_DEFAULT_PARTITION__",
+ segment_encoding="uri"):
+
+ cdef:
+ shared_ptr[CHivePartitioning] c_partitioning
+ CHivePartitioningOptions c_options
+
+ c_options.null_fallback = tobytes(null_fallback)
+ c_options.segment_encoding = _get_segment_encoding(segment_encoding)
+
+ c_partitioning = make_shared[CHivePartitioning](
+ pyarrow_unwrap_schema(schema),
+ _partitioning_dictionaries(schema, dictionaries),
+ c_options,
+ )
+ self.init(<shared_ptr[CPartitioning]> c_partitioning)
+
+ cdef init(self, const shared_ptr[CPartitioning]& sp):
+ Partitioning.init(self, sp)
+ self.hive_partitioning = <CHivePartitioning*> sp.get()
+
+ @staticmethod
+ def discover(infer_dictionary=False,
+ max_partition_dictionary_size=0,
+ null_fallback="__HIVE_DEFAULT_PARTITION__",
+ schema=None,
+ segment_encoding="uri"):
+ """
+ Discover a HivePartitioning.
+
+ Parameters
+ ----------
+ infer_dictionary : bool, default False
+ When inferring a schema for partition fields, yield dictionary
+ encoded types instead of plain types. This can be more efficient when
+ materializing virtual columns, and Expressions parsed by the
+ finished Partitioning will include dictionaries of all unique
+ inspected values for each field.
+ max_partition_dictionary_size : int, default 0
+ Synonymous with infer_dictionary for backwards compatibility with
+ 1.0: setting this to -1 or None is equivalent to passing
+ infer_dictionary=True.
+ null_fallback : str, default "__HIVE_DEFAULT_PARTITION__"
+ When inferring a schema for partition fields, this value will be
+ replaced by null. The default is set to __HIVE_DEFAULT_PARTITION__
+ for compatibility with Spark.
+ schema : Schema, default None
+ Use this schema instead of inferring a schema from partition
+ values. Partition values will be validated against this schema
+ before accumulation into the Partitioning's dictionary.
+ segment_encoding : str, default "uri"
+ After splitting paths into segments, decode the segments. Valid
+ values are "uri" (URI-decode segments) and "none" (leave as-is).
+
+ Returns
+ -------
+ PartitioningFactory
+ To be used in the FileSystemFactoryOptions.
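+
+ Examples
+ --------
+ A minimal sketch:
+
+ >>> from pyarrow.dataset import HivePartitioning
+ >>> factory = HivePartitioning.discover(infer_dictionary=True)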
+ """
+ cdef:
+ CHivePartitioningFactoryOptions c_options
+
+ if max_partition_dictionary_size in {-1, None}:
+ infer_dictionary = True
+ elif max_partition_dictionary_size != 0:
+ raise NotImplementedError("max_partition_dictionary_size must be "
+ "0, -1, or None")
+
+ if infer_dictionary:
+ c_options.infer_dictionary = True
+
+ c_options.null_fallback = tobytes(null_fallback)
+
+ if schema:
+ c_options.schema = pyarrow_unwrap_schema(schema)
+
+ c_options.segment_encoding = _get_segment_encoding(segment_encoding)
+
+ return PartitioningFactory.wrap(
+ CHivePartitioning.MakeFactory(c_options))
+
+ @property
+ def dictionaries(self):
+ """
+ The unique values for each partition field, if available.
+
+ Those values are only available if the Partitioning object was
+ created through dataset discovery from a PartitioningFactory, or
+ if the dictionaries were manually specified in the constructor.
+ If not available, this returns None.
+ """
+ cdef vector[shared_ptr[CArray]] c_arrays
+ c_arrays = self.hive_partitioning.dictionaries()
+ res = []
+ for arr in c_arrays:
+ if arr.get() == nullptr:
+ # Partitioning object has not been created through
+ # inspected Factory
+ return None
+ res.append(pyarrow_wrap_array(arr))
+ return res
+
+
+cdef class DatasetFactory(_Weakrefable):
+ """
+ DatasetFactory is used to create a Dataset, inspect the Schema
+ of the fragments contained in it, and declare a partitioning.
+ """
+
+ cdef:
+ shared_ptr[CDatasetFactory] wrapped
+ CDatasetFactory* factory
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef init(self, const shared_ptr[CDatasetFactory]& sp):
+ self.wrapped = sp
+ self.factory = sp.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CDatasetFactory]& sp):
+ cdef DatasetFactory self = \
+ DatasetFactory.__new__(DatasetFactory)
+ self.init(sp)
+ return self
+
+ cdef inline shared_ptr[CDatasetFactory] unwrap(self) nogil:
+ return self.wrapped
+
+ @property
+ def root_partition(self):
+ return Expression.wrap(self.factory.root_partition())
+
+ @root_partition.setter
+ def root_partition(self, Expression expr):
+ check_status(self.factory.SetRootPartition(expr.unwrap()))
+
+ def inspect_schemas(self):
+ cdef CResult[vector[shared_ptr[CSchema]]] result
+ cdef CInspectOptions options
+ with nogil:
+ result = self.factory.InspectSchemas(options)
+
+ schemas = []
+ for s in GetResultValue(result):
+ schemas.append(pyarrow_wrap_schema(s))
+ return schemas
+
+ def inspect(self):
+ """
+ Inspect all data fragments and return a common Schema.
+
+ Returns
+ -------
+ Schema
+ """
+ cdef:
+ CInspectOptions options
+ CResult[shared_ptr[CSchema]] result
+ with nogil:
+ result = self.factory.Inspect(options)
+ return pyarrow_wrap_schema(GetResultValue(result))
+
+ def finish(self, Schema schema=None):
+ """
+ Create a Dataset using the inspected schema or an explicit schema
+ (if given).
+
+ Parameters
+ ----------
+ schema : Schema, default None
+ The schema to conform the source to. If None, the inspected
+ schema is used.
+
+ Returns
+ -------
+ Dataset
+ """
+ cdef:
+ shared_ptr[CSchema] sp_schema
+ CResult[shared_ptr[CDataset]] result
+
+ if schema is not None:
+ sp_schema = pyarrow_unwrap_schema(schema)
+ with nogil:
+ result = self.factory.FinishWithSchema(sp_schema)
+ else:
+ with nogil:
+ result = self.factory.Finish()
+
+ return Dataset.wrap(GetResultValue(result))
+
+
+cdef class FileSystemFactoryOptions(_Weakrefable):
+ """
+ Influences the discovery of filesystem paths.
+
+ Parameters
+ ----------
+ partition_base_dir : str, optional
+ For the purposes of applying the partitioning, paths will be
+ stripped of the partition_base_dir. Files not matching the
+ partition_base_dir prefix will be skipped for partitioning discovery.
+ The ignored files will still be part of the Dataset, but will not
+ have partition information.
+ partitioning : Partitioning/PartitioningFactory, optional
+ Apply the Partitioning to every discovered Fragment. See Partitioning or
+ PartitioningFactory documentation.
+ exclude_invalid_files : bool, optional (default True)
+ If True, invalid files will be excluded (file format specific check).
+ This will incur IO for each file in a serial and single-threaded
+ fashion. Disabling this feature will skip the IO, but unsupported
+ files may be present in the Dataset (resulting in an error at scan
+ time).
+ selector_ignore_prefixes : list, optional
+ When discovering from a Selector (and not from an explicit file list),
+ ignore files and directories matching any of these prefixes.
+ By default this is ['.', '_'].
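+
+ Examples
+ --------
+ A minimal sketch; the base directory is hypothetical:
+
+ >>> from pyarrow.dataset import FileSystemFactoryOptions
+ >>> from pyarrow.dataset import HivePartitioning
+ >>> options = FileSystemFactoryOptions(
+ ...     partition_base_dir="/data",
+ ...     partitioning=HivePartitioning.discover())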
+ """
+
+ cdef:
+ CFileSystemFactoryOptions options
+
+ __slots__ = ()  # avoid mistakenly creating attributes
+
+ def __init__(self, partition_base_dir=None, partitioning=None,
+ exclude_invalid_files=None,
+ list selector_ignore_prefixes=None):
+ if isinstance(partitioning, PartitioningFactory):
+ self.partitioning_factory = partitioning
+ elif isinstance(partitioning, Partitioning):
+ self.partitioning = partitioning
+
+ if partition_base_dir is not None:
+ self.partition_base_dir = partition_base_dir
+ if exclude_invalid_files is not None:
+ self.exclude_invalid_files = exclude_invalid_files
+ if selector_ignore_prefixes is not None:
+ self.selector_ignore_prefixes = selector_ignore_prefixes
+
+ cdef inline CFileSystemFactoryOptions unwrap(self):
+ return self.options
+
+ @property
+ def partitioning(self):
+ """Partitioning to apply to discovered files.
+
+ NOTE: setting this property will overwrite partitioning_factory.
+ """
+ c_partitioning = self.options.partitioning.partitioning()
+ if c_partitioning.get() == nullptr:
+ return None
+ return Partitioning.wrap(c_partitioning)
+
+ @partitioning.setter
+ def partitioning(self, Partitioning value):
+ self.options.partitioning = (<Partitioning> value).unwrap()
+
+ @property
+ def partitioning_factory(self):
+ """PartitioningFactory to apply to discovered files and
+ discover a Partitioning.
+
+ NOTE: setting this property will overwrite partitioning.
+ """
+ c_factory = self.options.partitioning.factory()
+ if c_factory.get() == nullptr:
+ return None
+ return PartitioningFactory.wrap(c_factory)
+
+ @partitioning_factory.setter
+ def partitioning_factory(self, PartitioningFactory value):
+ self.options.partitioning = (<PartitioningFactory> value).unwrap()
+
+ @property
+ def partition_base_dir(self):
+ """
+ Base directory to strip paths before applying the partitioning.
+ """
+ return frombytes(self.options.partition_base_dir)
+
+ @partition_base_dir.setter
+ def partition_base_dir(self, value):
+ self.options.partition_base_dir = tobytes(value)
+
+ @property
+ def exclude_invalid_files(self):
+ """Whether to exclude invalid files."""
+ return self.options.exclude_invalid_files
+
+ @exclude_invalid_files.setter
+ def exclude_invalid_files(self, bint value):
+ self.options.exclude_invalid_files = value
+
+ @property
+ def selector_ignore_prefixes(self):
+ """
+ List of prefixes. Files matching one of those prefixes will be
+ ignored by the discovery process.
+ """
+ return [frombytes(p) for p in self.options.selector_ignore_prefixes]
+
+ @selector_ignore_prefixes.setter
+ def selector_ignore_prefixes(self, values):
+ self.options.selector_ignore_prefixes = [tobytes(v) for v in values]
+
+
+cdef class FileSystemDatasetFactory(DatasetFactory):
+ """
+ Create a DatasetFactory from a list of paths with schema inspection.
+
+ Parameters
+ ----------
+ filesystem : pyarrow.fs.FileSystem
+ Filesystem to discover.
+ paths_or_selector : pyarrow.fs.Selector or list of path-likes
+ Either a Selector object or a list of path-like objects.
+ format : FileFormat
+ Currently only ParquetFileFormat and IpcFileFormat are supported.
+ options : FileSystemFactoryOptions, optional
+ Various flags influencing the discovery of filesystem paths.
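+
+ Examples
+ --------
+ A minimal sketch; the path is hypothetical:
+
+ >>> from pyarrow.fs import LocalFileSystem
+ >>> from pyarrow.dataset import FileSystemDatasetFactory
+ >>> from pyarrow.dataset import ParquetFileFormat
+ >>> factory = FileSystemDatasetFactory(
+ ...     LocalFileSystem(), ["/data/part-0.parquet"],
+ ...     ParquetFileFormat())
+ >>> dataset = factory.finish()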
+ """
+
+ cdef:
+ CFileSystemDatasetFactory* filesystem_factory
+
+ def __init__(self, FileSystem filesystem not None, paths_or_selector,
+ FileFormat format not None,
+ FileSystemFactoryOptions options=None):
+ cdef:
+ vector[c_string] paths
+ CFileSelector c_selector
+ CResult[shared_ptr[CDatasetFactory]] result
+ shared_ptr[CFileSystem] c_filesystem
+ shared_ptr[CFileFormat] c_format
+ CFileSystemFactoryOptions c_options
+
+ options = options or FileSystemFactoryOptions()
+ c_options = options.unwrap()
+ c_filesystem = filesystem.unwrap()
+ c_format = format.unwrap()
+
+ if isinstance(paths_or_selector, FileSelector):
+ with nogil:
+ c_selector = (<FileSelector> paths_or_selector).selector
+ result = CFileSystemDatasetFactory.MakeFromSelector(
+ c_filesystem,
+ c_selector,
+ c_format,
+ c_options
+ )
+ elif isinstance(paths_or_selector, (list, tuple)):
+ paths = [tobytes(s) for s in paths_or_selector]
+ with nogil:
+ result = CFileSystemDatasetFactory.MakeFromPaths(
+ c_filesystem,
+ paths,
+ c_format,
+ c_options
+ )
+ else:
+ raise TypeError('Must pass either paths or a FileSelector, but '
+ 'passed {}'.format(type(paths_or_selector)))
+
+ self.init(GetResultValue(result))
+
+ cdef init(self, shared_ptr[CDatasetFactory]& sp):
+ DatasetFactory.init(self, sp)
+ self.filesystem_factory = <CFileSystemDatasetFactory*> sp.get()
+
+
+cdef class UnionDatasetFactory(DatasetFactory):
+ """
+ Provides a way to inspect/discover a Dataset's expected schema before
+ materialization.
+
+ Parameters
+ ----------
+ factories : list of DatasetFactory
+ """
+
+ cdef:
+ CUnionDatasetFactory* union_factory
+
+ def __init__(self, list factories):
+ cdef:
+ DatasetFactory factory
+ vector[shared_ptr[CDatasetFactory]] c_factories
+ for factory in factories:
+ c_factories.push_back(factory.unwrap())
+ self.init(GetResultValue(CUnionDatasetFactory.Make(c_factories)))
+
+ cdef init(self, const shared_ptr[CDatasetFactory]& sp):
+ DatasetFactory.init(self, sp)
+ self.union_factory = <CUnionDatasetFactory*> sp.get()
+
+
+cdef class ParquetFactoryOptions(_Weakrefable):
+ """
+ Influences the discovery of a parquet dataset.
+
+ Parameters
+ ----------
+ partition_base_dir : str, optional
+ For the purposes of applying the partitioning, paths will be
+ stripped of the partition_base_dir. Files not matching the
+ partition_base_dir prefix will be skipped for partitioning discovery.
+ The ignored files will still be part of the Dataset, but will not
+ have partition information.
+ partitioning : Partitioning, PartitioningFactory, optional
+ The partitioning scheme applied to fragments, see ``Partitioning``.
+ validate_column_chunk_paths : bool, default False
+ Assert that all ColumnChunk paths are consistent. The parquet spec
+ allows for ColumnChunk data to be stored in multiple files, but
+ ParquetDatasetFactory supports only a single file with all ColumnChunk
+ data. If this flag is set construction of a ParquetDatasetFactory will
+ raise an error if ColumnChunk data is not resident in a single file.
+ """
+
+ cdef:
+ CParquetFactoryOptions options
+
+ __slots__ = ()  # avoid mistakenly creating attributes
+
+ def __init__(self, partition_base_dir=None, partitioning=None,
+ validate_column_chunk_paths=False):
+ if isinstance(partitioning, PartitioningFactory):
+ self.partitioning_factory = partitioning
+ elif isinstance(partitioning, Partitioning):
+ self.partitioning = partitioning
+
+ if partition_base_dir is not None:
+ self.partition_base_dir = partition_base_dir
+
+ self.options.validate_column_chunk_paths = validate_column_chunk_paths
+
+ cdef inline CParquetFactoryOptions unwrap(self):
+ return self.options
+
+ @property
+ def partitioning(self):
+ """Partitioning to apply to discovered files.
+
+ NOTE: setting this property will overwrite partitioning_factory.
+ """
+ c_partitioning = self.options.partitioning.partitioning()
+ if c_partitioning.get() == nullptr:
+ return None
+ return Partitioning.wrap(c_partitioning)
+
+ @partitioning.setter
+ def partitioning(self, Partitioning value):
+ self.options.partitioning = (<Partitioning> value).unwrap()
+
+ @property
+ def partitioning_factory(self):
+ """PartitioningFactory to apply to discovered files and
+ discover a Partitioning.
+
+ NOTE: setting this property will overwrite partitioning.
+ """
+ c_factory = self.options.partitioning.factory()
+ if c_factory.get() == nullptr:
+ return None
+ return PartitioningFactory.wrap(c_factory)
+
+ @partitioning_factory.setter
+ def partitioning_factory(self, PartitioningFactory value):
+ self.options.partitioning = (<PartitioningFactory> value).unwrap()
+
+ @property
+ def partition_base_dir(self):
+ """
+ Base directory to strip paths before applying the partitioning.
+ """
+ return frombytes(self.options.partition_base_dir)
+
+ @partition_base_dir.setter
+ def partition_base_dir(self, value):
+ self.options.partition_base_dir = tobytes(value)
+
+ @property
+ def validate_column_chunk_paths(self):
+ """
+ Whether to assert that all ColumnChunk paths are consistent.
+ """
+ return self.options.validate_column_chunk_paths
+
+ @validate_column_chunk_paths.setter
+ def validate_column_chunk_paths(self, value):
+ self.options.validate_column_chunk_paths = value
+
+
+cdef class ParquetDatasetFactory(DatasetFactory):
+ """
+ Create a ParquetDatasetFactory from a Parquet `_metadata` file.
+
+ Parameters
+ ----------
+ metadata_path : str
+ Path to the `_metadata` parquet metadata-only file generated with
+ `pyarrow.parquet.write_metadata`.
+ filesystem : pyarrow.fs.FileSystem
+ Filesystem to read the metadata_path from, and subsequent parquet
+ files.
+ format : ParquetFileFormat
+ Parquet format options.
+ options : ParquetFactoryOptions, optional
+ Various flags influencing the discovery of filesystem paths.
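+
+ Examples
+ --------
+ A minimal sketch; the ``_metadata`` path is hypothetical:
+
+ >>> from pyarrow.fs import LocalFileSystem
+ >>> from pyarrow.dataset import ParquetDatasetFactory
+ >>> from pyarrow.dataset import ParquetFileFormat
+ >>> factory = ParquetDatasetFactory(
+ ...     "/data/_metadata", LocalFileSystem(), ParquetFileFormat())
+ >>> dataset = factory.finish()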
+ """
+
+ cdef:
+ CParquetDatasetFactory* parquet_factory
+
+ def __init__(self, metadata_path, FileSystem filesystem not None,
+ FileFormat format not None,
+ ParquetFactoryOptions options=None):
+ cdef:
+ c_string path
+ shared_ptr[CFileSystem] c_filesystem
+ shared_ptr[CParquetFileFormat] c_format
+ CResult[shared_ptr[CDatasetFactory]] result
+ CParquetFactoryOptions c_options
+
+ c_path = tobytes(metadata_path)
+ c_filesystem = filesystem.unwrap()
+ c_format = static_pointer_cast[CParquetFileFormat, CFileFormat](
+ format.unwrap())
+ options = options or ParquetFactoryOptions()
+ c_options = options.unwrap()
+
+ result = CParquetDatasetFactory.MakeFromMetaDataPath(
+ c_path, c_filesystem, c_format, c_options)
+ self.init(GetResultValue(result))
+
+ cdef init(self, shared_ptr[CDatasetFactory]& sp):
+ DatasetFactory.init(self, sp)
+ self.parquet_factory = <CParquetDatasetFactory*> sp.get()
+
+
+cdef class RecordBatchIterator(_Weakrefable):
+ """An iterator over a sequence of record batches."""
+ cdef:
+ # An object that must be kept alive with the iterator.
+ object iterator_owner
+ # Iterator is a non-POD type and Cython uses offsetof, leading
+ # to a compiler warning unless wrapped like so
+ shared_ptr[CRecordBatchIterator] iterator
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__, subclasses_instead=False)
+
+ @staticmethod
+ cdef wrap(object owner, CRecordBatchIterator iterator):
+ cdef RecordBatchIterator self = \
+ RecordBatchIterator.__new__(RecordBatchIterator)
+ self.iterator_owner = owner
+ self.iterator = make_shared[CRecordBatchIterator](move(iterator))
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ cdef shared_ptr[CRecordBatch] record_batch
+ with nogil:
+ record_batch = GetResultValue(move(self.iterator.get().Next()))
+ if record_batch == NULL:
+ raise StopIteration
+ return pyarrow_wrap_batch(record_batch)
+
+
+class TaggedRecordBatch(collections.namedtuple(
+ "TaggedRecordBatch", ["record_batch", "fragment"])):
+ """
+ A combination of a record batch and the fragment it came from.
+
+ Parameters
+ ----------
+ record_batch : RecordBatch
+ The record batch.
+ fragment : Fragment
+ The fragment the record batch came from.
+ """
+
+
+cdef class TaggedRecordBatchIterator(_Weakrefable):
+ """An iterator over a sequence of record batches with fragments."""
+ cdef:
+ object iterator_owner
+ shared_ptr[CTaggedRecordBatchIterator] iterator
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__, subclasses_instead=False)
+
+ @staticmethod
+ cdef wrap(object owner, CTaggedRecordBatchIterator iterator):
+ cdef TaggedRecordBatchIterator self = \
+ TaggedRecordBatchIterator.__new__(TaggedRecordBatchIterator)
+ self.iterator_owner = owner
+ self.iterator = make_shared[CTaggedRecordBatchIterator](
+ move(iterator))
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ cdef CTaggedRecordBatch batch
+ with nogil:
+ batch = GetResultValue(move(self.iterator.get().Next()))
+ if batch.record_batch == NULL:
+ raise StopIteration
+ return TaggedRecordBatch(
+ record_batch=pyarrow_wrap_batch(batch.record_batch),
+ fragment=Fragment.wrap(batch.fragment))
+
+
+_DEFAULT_BATCH_SIZE = 2**20
+
+
+cdef void _populate_builder(const shared_ptr[CScannerBuilder]& ptr,
+ object columns=None, Expression filter=None,
+ int batch_size=_DEFAULT_BATCH_SIZE,
+ bint use_threads=True, bint use_async=False,
+ MemoryPool memory_pool=None,
+ FragmentScanOptions fragment_scan_options=None)\
+ except *:
+ cdef:
+ CScannerBuilder *builder
+ vector[CExpression] c_exprs
+
+ builder = ptr.get()
+
+ check_status(builder.Filter(_bind(
+ filter, pyarrow_wrap_schema(builder.schema()))))
+
+ if columns is not None:
+ if isinstance(columns, dict):
+ for expr in columns.values():
+ if not isinstance(expr, Expression):
+ raise TypeError(
+ "Expected an Expression for a 'column' dictionary "
+ "value, got {} instead".format(type(expr))
+ )
+ c_exprs.push_back((<Expression> expr).unwrap())
+
+ check_status(
+ builder.Project(c_exprs, [tobytes(c) for c in columns.keys()])
+ )
+ elif isinstance(columns, list):
+ check_status(builder.ProjectColumns([tobytes(c) for c in columns]))
+ else:
+ raise ValueError(
+ "Expected a list or a dict for 'columns', "
+ "got {} instead.".format(type(columns))
+ )
+
+ check_status(builder.BatchSize(batch_size))
+ check_status(builder.UseThreads(use_threads))
+ check_status(builder.UseAsync(use_async))
+ if memory_pool:
+ check_status(builder.Pool(maybe_unbox_memory_pool(memory_pool)))
+ if fragment_scan_options:
+ check_status(
+ builder.FragmentScanOptions(fragment_scan_options.wrapped))
+
+
+cdef class Scanner(_Weakrefable):
+ """A materialized scan operation with context and options bound.
+
+ A scanner is the class that glues the scan tasks, data fragments and data
+ sources together.
+
+ Parameters
+ ----------
+ dataset : Dataset
+ Dataset to scan.
+ columns : list of str or dict, default None
+ The columns to project. This can be a list of column names to include
+ (order and duplicates will be preserved), or a dictionary with
+ {new_column_name: expression} values for more advanced projections.
+ The columns will be passed down to Datasets and corresponding data
+ fragments to avoid loading, copying, and deserializing columns
+ that will not be required further down the compute chain.
+ By default all of the available columns are projected. Raises
+ an exception if any of the referenced column names does not exist
+ in the dataset's Schema.
+ filter : Expression, default None
+ Scan will return only the rows matching the filter.
+ If possible the predicate will be pushed down to exploit the
+ partition information or internal metadata found in the data
+ source, e.g. Parquet statistics. Otherwise filters the loaded
+ RecordBatches before yielding them.
+ batch_size : int, default 1M
+ The maximum row count for scanned record batches. If scanned
+ record batches are overflowing memory then this value can be
+ reduced to limit their size.
+ use_threads : bool, default True
+ If enabled, maximum parallelism will be used, as determined by
+ the number of available CPU cores.
+ use_async : bool, default False
+ If enabled, an async scanner will be used that should offer
+ better performance with high-latency/highly-parallel filesystems
+ (e.g. S3).
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required. If not specified, uses the
+ default pool.
+ """
+
+ cdef:
+ shared_ptr[CScanner] wrapped
+ CScanner* scanner
+
+ def __init__(self):
+ _forbid_instantiation(self.__class__)
+
+ cdef void init(self, const shared_ptr[CScanner]& sp):
+ self.wrapped = sp
+ self.scanner = sp.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CScanner]& sp):
+ cdef Scanner self = Scanner.__new__(Scanner)
+ self.init(sp)
+ return self
+
+ cdef inline shared_ptr[CScanner] unwrap(self):
+ return self.wrapped
+
+ @staticmethod
+ def from_dataset(Dataset dataset not None,
+ bint use_threads=True, bint use_async=False,
+ MemoryPool memory_pool=None,
+ object columns=None, Expression filter=None,
+ int batch_size=_DEFAULT_BATCH_SIZE,
+ FragmentScanOptions fragment_scan_options=None):
+ """
+ Create a Scanner from a Dataset.
+ Refer to the Scanner class documentation for additional details.
+
+ Parameters
+ ----------
+ dataset : Dataset
+ Dataset to scan.
+ columns : list of str or dict, default None
+ The columns to project.
+ filter : Expression, default None
+ Scan will return only the rows matching the filter.
+ batch_size : int, default 1M
+ The maximum row count for scanned record batches.
+        use_threads : bool, default True
+            If enabled, maximum parallelism will be used, as determined
+            by the number of available CPU cores.
+        use_async : bool, default False
+            If enabled, an async scanner will be used that should offer
+            better performance with high-latency/highly-parallel filesystems
+            (e.g. S3).
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required. If not specified, uses the
+ default pool.
+ fragment_scan_options : FragmentScanOptions
+ The fragment scan options.
+ """
+ cdef:
+ shared_ptr[CScanOptions] options = make_shared[CScanOptions]()
+ shared_ptr[CScannerBuilder] builder
+ shared_ptr[CScanner] scanner
+
+ builder = make_shared[CScannerBuilder](dataset.unwrap(), options)
+ _populate_builder(builder, columns=columns, filter=filter,
+ batch_size=batch_size, use_threads=use_threads,
+ use_async=use_async, memory_pool=memory_pool,
+ fragment_scan_options=fragment_scan_options)
+
+ scanner = GetResultValue(builder.get().Finish())
+ return Scanner.wrap(scanner)
+
+ @staticmethod
+ def from_fragment(Fragment fragment not None, Schema schema=None,
+ bint use_threads=True, bint use_async=False,
+ MemoryPool memory_pool=None,
+ object columns=None, Expression filter=None,
+ int batch_size=_DEFAULT_BATCH_SIZE,
+ FragmentScanOptions fragment_scan_options=None):
+ """
+        Create a Scanner from a Fragment.
+
+        Refer to the Scanner class documentation for additional details.
+
+ Parameters
+ ----------
+ fragment : Fragment
+            Fragment to scan.
+ schema : Schema
+ The schema of the fragment.
+ columns : list of str or dict, default None
+ The columns to project.
+ filter : Expression, default None
+ Scan will return only the rows matching the filter.
+ batch_size : int, default 1M
+ The maximum row count for scanned record batches.
+        use_threads : bool, default True
+            If enabled, maximum parallelism will be used, as determined
+            by the number of available CPU cores.
+        use_async : bool, default False
+            If enabled, an async scanner will be used that should offer
+            better performance with high-latency/highly-parallel filesystems
+            (e.g. S3).
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required. If not specified, uses the
+ default pool.
+ fragment_scan_options : FragmentScanOptions
+ The fragment scan options.
+ """
+ cdef:
+ shared_ptr[CScanOptions] options = make_shared[CScanOptions]()
+ shared_ptr[CScannerBuilder] builder
+ shared_ptr[CScanner] scanner
+
+ schema = schema or fragment.physical_schema
+
+ builder = make_shared[CScannerBuilder](pyarrow_unwrap_schema(schema),
+ fragment.unwrap(), options)
+ _populate_builder(builder, columns=columns, filter=filter,
+ batch_size=batch_size, use_threads=use_threads,
+ use_async=use_async, memory_pool=memory_pool,
+ fragment_scan_options=fragment_scan_options)
+
+ scanner = GetResultValue(builder.get().Finish())
+ return Scanner.wrap(scanner)
+
+ @staticmethod
+ def from_batches(source, Schema schema=None, bint use_threads=True,
+ bint use_async=False,
+ MemoryPool memory_pool=None, object columns=None,
+ Expression filter=None,
+ int batch_size=_DEFAULT_BATCH_SIZE,
+ FragmentScanOptions fragment_scan_options=None):
+ """
+ Create a Scanner from an iterator of batches.
+
+ This creates a scanner which can be used only once. It is
+ intended to support writing a dataset (which takes a scanner)
+ from a source which can be read only once (e.g. a
+ RecordBatchReader or generator).
+
+ Parameters
+ ----------
+ source : Iterator
+            The iterator of record batches.
+ schema : Schema
+ The schema of the batches.
+ columns : list of str or dict, default None
+ The columns to project.
+ filter : Expression, default None
+ Scan will return only the rows matching the filter.
+ batch_size : int, default 1M
+ The maximum row count for scanned record batches.
+        use_threads : bool, default True
+            If enabled, maximum parallelism will be used, as determined
+            by the number of available CPU cores.
+        use_async : bool, default False
+            If enabled, an async scanner will be used that should offer
+            better performance with high-latency/highly-parallel filesystems
+            (e.g. S3).
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required. If not specified, uses the
+ default pool.
+ fragment_scan_options : FragmentScanOptions
+ The fragment scan options.
+ """
+ cdef:
+ shared_ptr[CScanOptions] options = make_shared[CScanOptions]()
+ shared_ptr[CScannerBuilder] builder
+ shared_ptr[CScanner] scanner
+ RecordBatchReader reader
+ if isinstance(source, pa.ipc.RecordBatchReader):
+ if schema:
+ raise ValueError('Cannot specify a schema when providing '
+ 'a RecordBatchReader')
+ reader = source
+ elif _is_iterable(source):
+ if schema is None:
+ raise ValueError('Must provide schema to construct scanner '
+ 'from an iterable')
+ reader = pa.ipc.RecordBatchReader.from_batches(schema, source)
+ else:
+ raise TypeError('Expected a RecordBatchReader or an iterable of '
+ 'batches instead of the given type: ' +
+ type(source).__name__)
+ builder = CScannerBuilder.FromRecordBatchReader(reader.reader)
+ _populate_builder(builder, columns=columns, filter=filter,
+ batch_size=batch_size, use_threads=use_threads,
+ use_async=use_async, memory_pool=memory_pool,
+ fragment_scan_options=fragment_scan_options)
+ scanner = GetResultValue(builder.get().Finish())
+ return Scanner.wrap(scanner)
+
+ @property
+ def dataset_schema(self):
+ """The schema with which batches will be read from fragments."""
+ return pyarrow_wrap_schema(
+ self.scanner.options().get().dataset_schema)
+
+ @property
+ def projected_schema(self):
+ """The materialized schema of the data, accounting for projections.
+
+ This is the schema of any data returned from the scanner.
+ """
+ return pyarrow_wrap_schema(
+ self.scanner.options().get().projected_schema)
+
+ def to_batches(self):
+ """Consume a Scanner in record batches.
+
+ Returns
+ -------
+ record_batches : iterator of RecordBatch
+ """
+ def _iterator(batch_iter):
+ for batch in batch_iter:
+ yield batch.record_batch
+ # Don't make ourselves a generator so errors are raised immediately
+ return _iterator(self.scan_batches())
+
+ def scan_batches(self):
+ """Consume a Scanner in record batches with corresponding fragments.
+
+ Returns
+ -------
+ record_batches : iterator of TaggedRecordBatch
+ """
+ cdef CTaggedRecordBatchIterator iterator
+ with nogil:
+ iterator = move(GetResultValue(self.scanner.ScanBatches()))
+ # Don't make ourselves a generator so errors are raised immediately
+ return TaggedRecordBatchIterator.wrap(self, move(iterator))
+
+ def to_table(self):
+ """Convert a Scanner into a Table.
+
+ Use this convenience utility with care. This will serially materialize
+ the Scan result in memory before creating the Table.
+
+ Returns
+ -------
+ table : Table
+ """
+ cdef CResult[shared_ptr[CTable]] result
+
+ with nogil:
+ result = self.scanner.ToTable()
+
+ return pyarrow_wrap_table(GetResultValue(result))
+
+ def take(self, object indices):
+ """Select rows of data by index.
+
+ Will only consume as many batches of the underlying dataset as
+ needed. Otherwise, this is equivalent to
+ ``to_table().take(indices)``.
+
+ Returns
+ -------
+ table : Table
+ """
+ cdef CResult[shared_ptr[CTable]] result
+ cdef shared_ptr[CArray] c_indices = pyarrow_unwrap_array(indices)
+ with nogil:
+ result = self.scanner.TakeRows(deref(c_indices))
+ return pyarrow_wrap_table(GetResultValue(result))
+
+ def head(self, int num_rows):
+ """Load the first N rows of the dataset.
+
+ Returns
+ -------
+ table : Table instance
+ """
+ cdef CResult[shared_ptr[CTable]] result
+ with nogil:
+ result = self.scanner.Head(num_rows)
+ return pyarrow_wrap_table(GetResultValue(result))
+
+ def count_rows(self):
+ """Count rows matching the scanner filter.
+
+ Returns
+ -------
+ count : int
+ """
+ cdef CResult[int64_t] result
+ with nogil:
+ result = self.scanner.CountRows()
+ return GetResultValue(result)
+
+ def to_reader(self):
+ """Consume this scanner as a RecordBatchReader."""
+ cdef RecordBatchReader reader
+ reader = RecordBatchReader.__new__(RecordBatchReader)
+ reader.reader = GetResultValue(self.scanner.ToRecordBatchReader())
+ return reader
+
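+# Illustrative usage sketch for the Scanner class above (comments only, so it
+# is not executed at import time). The dataset path, column names and filter
+# values are hypothetical placeholders, not part of this module.
+#
+#   >>> import pyarrow.dataset as ds
+#   >>> dataset = ds.dataset("/path/to/parquet/dir", format="parquet")
+#   >>> scanner = ds.Scanner.from_dataset(
+#   ...     dataset,
+#   ...     columns={"a": ds.field("a"), "b_positive": ds.field("b") > 0},
+#   ...     filter=ds.field("year") == 2020,
+#   ...     batch_size=64 * 1024)
+#   >>> table = scanner.to_table()            # materialize the whole scan
+#   >>> # alternatively: scanner.head(5), scanner.count_rows(),
+#   >>> # or iterate lazily with scanner.to_batches()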
+
+def _get_partition_keys(Expression partition_expression):
+ """
+ Extract partition keys (equality constraints between a field and a scalar)
+ from an expression as a dict mapping the field's name to its value.
+
+ NB: All expressions yielded by a HivePartitioning or DirectoryPartitioning
+ will be conjunctions of equality conditions and are accessible through this
+ function. Other subexpressions will be ignored.
+
+ For example, an expression of
+ <pyarrow.dataset.Expression ((part == A:string) and (year == 2016:int32))>
+ is converted to {'part': 'A', 'year': 2016}
+ """
+ cdef:
+ CExpression expr = partition_expression.unwrap()
+ pair[CFieldRef, CDatum] ref_val
+
+ out = {}
+ for ref_val in GetResultValue(CExtractKnownFieldValues(expr)).map:
+ assert ref_val.first.name() != nullptr
+ assert ref_val.second.kind() == DatumType_SCALAR
+ val = pyarrow_wrap_scalar(ref_val.second.scalar())
+ out[frombytes(deref(ref_val.first.name()))] = val.as_py()
+ return out
+
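+# Illustrative sketch of the extraction performed above (hypothetical values):
+#
+#   >>> import pyarrow.dataset as ds
+#   >>> expr = (ds.field("part") == "A") & (ds.field("year") == 2016)
+#   >>> _get_partition_keys(expr)
+#   {'part': 'A', 'year': 2016}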
+
+ctypedef CParquetFileWriter* _CParquetFileWriterPtr
+
+cdef class WrittenFile(_Weakrefable):
+ """
+    Metadata about files written as part of a dataset write operation.
+ """
+
+ """The full path to the created file"""
+ cdef public str path
+ """
+ If the file is a parquet file this will contain the parquet metadata.
+ This metadata will have the file path attribute set to the path of
+ the written file.
+ """
+ cdef public object metadata
+
+ def __init__(self, path, metadata):
+ self.path = path
+ self.metadata = metadata
+
+cdef void _filesystemdataset_write_visitor(
+ dict visit_args,
+ CFileWriter* file_writer):
+ cdef:
+ str path
+ str base_dir
+ WrittenFile written_file
+ FileMetaData parquet_metadata
+ CParquetFileWriter* parquet_file_writer
+
+ parquet_metadata = None
+ path = frombytes(deref(file_writer).destination().path)
+ if deref(deref(file_writer).format()).type_name() == b"parquet":
+ parquet_file_writer = dynamic_cast[_CParquetFileWriterPtr](file_writer)
+ with nogil:
+ metadata = deref(
+ deref(parquet_file_writer).parquet_writer()).metadata()
+ if metadata:
+ base_dir = frombytes(visit_args['base_dir'])
+ parquet_metadata = FileMetaData()
+ parquet_metadata.init(metadata)
+ parquet_metadata.set_file_path(os.path.relpath(path, base_dir))
+ written_file = WrittenFile(path, parquet_metadata)
+ visit_args['file_visitor'](written_file)
+
+
+def _filesystemdataset_write(
+ Scanner data not None,
+ object base_dir not None,
+ str basename_template not None,
+ FileSystem filesystem not None,
+ Partitioning partitioning not None,
+ FileWriteOptions file_options not None,
+ int max_partitions,
+ object file_visitor,
+ str existing_data_behavior not None
+):
+ """
+ CFileSystemDataset.Write wrapper
+ """
+ cdef:
+ CFileSystemDatasetWriteOptions c_options
+ shared_ptr[CScanner] c_scanner
+ vector[shared_ptr[CRecordBatch]] c_batches
+ dict visit_args
+
+ c_options.file_write_options = file_options.unwrap()
+ c_options.filesystem = filesystem.unwrap()
+ c_options.base_dir = tobytes(_stringify_path(base_dir))
+ c_options.partitioning = partitioning.unwrap()
+ c_options.max_partitions = max_partitions
+ c_options.basename_template = tobytes(basename_template)
+ if existing_data_behavior == 'error':
+ c_options.existing_data_behavior = ExistingDataBehavior_ERROR
+ elif existing_data_behavior == 'overwrite_or_ignore':
+ c_options.existing_data_behavior =\
+ ExistingDataBehavior_OVERWRITE_OR_IGNORE
+ elif existing_data_behavior == 'delete_matching':
+ c_options.existing_data_behavior = ExistingDataBehavior_DELETE_MATCHING
+ else:
+ raise ValueError(
+            "existing_data_behavior must be one of 'error', "
+            "'overwrite_or_ignore' or 'delete_matching'"
+ )
+
+ if file_visitor is not None:
+ visit_args = {'base_dir': c_options.base_dir,
+ 'file_visitor': file_visitor}
+ # Need to use post_finish because parquet metadata is not available
+ # until after Finish has been called
+ c_options.writer_post_finish = BindFunction[cb_writer_finish_internal](
+ &_filesystemdataset_write_visitor, visit_args)
+
+ c_scanner = data.unwrap()
+ with nogil:
+ check_status(CFileSystemDataset.Write(c_options, c_scanner))
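+
+# Illustrative sketch: this low-level writer is normally reached through the
+# public pyarrow.dataset.write_dataset() helper. The table contents, output
+# path and visitor callback below are hypothetical placeholders.
+#
+#   >>> import pyarrow as pa
+#   >>> import pyarrow.dataset as ds
+#   >>> table = pa.table({"year": [2020, 2021], "n": [1, 2]})
+#   >>> def visitor(written_file):
+#   ...     print(written_file.path, written_file.metadata)
+#   >>> ds.write_dataset(table, "/tmp/out", format="parquet",
+#   ...                  partitioning=["year"], partitioning_flavor="hive",
+#   ...                  existing_data_behavior="overwrite_or_ignore",
+#   ...                  file_visitor=visitor)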
diff --git a/src/arrow/python/pyarrow/_dataset_orc.pyx b/src/arrow/python/pyarrow/_dataset_orc.pyx
new file mode 100644
index 000000000..40a21ef54
--- /dev/null
+++ b/src/arrow/python/pyarrow/_dataset_orc.pyx
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+"""Dataset support for ORC file format."""
+
+from pyarrow.lib cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_dataset cimport *
+
+from pyarrow._dataset cimport FileFormat
+
+
+cdef class OrcFileFormat(FileFormat):
+
+ def __init__(self):
+ self.init(shared_ptr[CFileFormat](new COrcFileFormat()))
+
+ def equals(self, OrcFileFormat other):
+ return True
+
+ @property
+ def default_extname(self):
+ return "orc"
+
+ def __reduce__(self):
+ return OrcFileFormat, tuple()
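+
+# Illustrative sketch (hypothetical path): ORC datasets can be opened with the
+# format name or with an OrcFileFormat instance.
+#
+#   >>> import pyarrow.dataset as ds
+#   >>> dataset = ds.dataset("/path/to/orc/dir", format="orc")
+#   >>> # equivalently: ds.dataset("/path/to/orc/dir", format=OrcFileFormat())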
diff --git a/src/arrow/python/pyarrow/_feather.pyx b/src/arrow/python/pyarrow/_feather.pyx
new file mode 100644
index 000000000..8df7935aa
--- /dev/null
+++ b/src/arrow/python/pyarrow/_feather.pyx
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# ---------------------------------------------------------------------
+# Implement Feather file format
+
+# cython: profile=False
+# distutils: language = c++
+# cython: language_level=3
+
+from cython.operator cimport dereference as deref
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_feather cimport *
+from pyarrow.lib cimport (check_status, Table, _Weakrefable,
+ get_writer, get_reader, pyarrow_wrap_table)
+from pyarrow.lib import tobytes
+
+
+class FeatherError(Exception):
+ pass
+
+
+def write_feather(Table table, object dest, compression=None,
+ compression_level=None, chunksize=None, version=2):
+ cdef shared_ptr[COutputStream] sink
+ get_writer(dest, &sink)
+
+ cdef CFeatherProperties properties
+ if version == 2:
+ properties.version = kFeatherV2Version
+ else:
+ properties.version = kFeatherV1Version
+
+ if compression == 'zstd':
+ properties.compression = CCompressionType_ZSTD
+ elif compression == 'lz4':
+ properties.compression = CCompressionType_LZ4_FRAME
+ else:
+ properties.compression = CCompressionType_UNCOMPRESSED
+
+ if chunksize is not None:
+ properties.chunksize = chunksize
+
+ if compression_level is not None:
+ properties.compression_level = compression_level
+
+ with nogil:
+ check_status(WriteFeather(deref(table.table), sink.get(),
+ properties))
+
+
+cdef class FeatherReader(_Weakrefable):
+ cdef:
+ shared_ptr[CFeatherReader] reader
+
+ def __cinit__(self, source, c_bool use_memory_map):
+ cdef shared_ptr[CRandomAccessFile] reader
+ get_reader(source, use_memory_map, &reader)
+ with nogil:
+ self.reader = GetResultValue(CFeatherReader.Open(reader))
+
+ @property
+ def version(self):
+ return self.reader.get().version()
+
+ def read(self):
+ cdef shared_ptr[CTable] sp_table
+ with nogil:
+ check_status(self.reader.get()
+ .Read(&sp_table))
+
+ return pyarrow_wrap_table(sp_table)
+
+ def read_indices(self, indices):
+ cdef:
+ shared_ptr[CTable] sp_table
+ vector[int] c_indices
+
+ for index in indices:
+ c_indices.push_back(index)
+ with nogil:
+ check_status(self.reader.get()
+ .Read(c_indices, &sp_table))
+
+ return pyarrow_wrap_table(sp_table)
+
+ def read_names(self, names):
+ cdef:
+ shared_ptr[CTable] sp_table
+ vector[c_string] c_names
+
+ for name in names:
+ c_names.push_back(tobytes(name))
+ with nogil:
+ check_status(self.reader.get()
+ .Read(c_names, &sp_table))
+
+ return pyarrow_wrap_table(sp_table)
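+
+# Illustrative round trip through the public pyarrow.feather wrappers that
+# call into the helpers above; the file path is a hypothetical placeholder.
+#
+#   >>> import pyarrow as pa
+#   >>> import pyarrow.feather as feather
+#   >>> t = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})
+#   >>> feather.write_feather(t, "/tmp/data.feather", compression="zstd")
+#   >>> feather.read_table("/tmp/data.feather", columns=["a"])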
diff --git a/src/arrow/python/pyarrow/_flight.pyx b/src/arrow/python/pyarrow/_flight.pyx
new file mode 100644
index 000000000..5b00c531d
--- /dev/null
+++ b/src/arrow/python/pyarrow/_flight.pyx
@@ -0,0 +1,2664 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+import collections
+import contextlib
+import enum
+import re
+import socket
+import time
+import threading
+import warnings
+
+from cython.operator cimport dereference as deref
+from cython.operator cimport postincrement
+from libcpp cimport bool as c_bool
+
+from pyarrow.lib cimport *
+from pyarrow.lib import ArrowException, ArrowInvalid, SignalStopHandler
+from pyarrow.lib import as_buffer, frombytes, tobytes
+from pyarrow.includes.libarrow_flight cimport *
+from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin
+import pyarrow.lib as lib
+
+
+cdef CFlightCallOptions DEFAULT_CALL_OPTIONS
+
+
+cdef int check_flight_status(const CStatus& status) nogil except -1:
+ cdef shared_ptr[FlightStatusDetail] detail
+
+ if status.ok():
+ return 0
+
+ detail = FlightStatusDetail.UnwrapStatus(status)
+ if detail:
+ with gil:
+ message = frombytes(status.message(), safe=True)
+ detail_msg = detail.get().extra_info()
+ if detail.get().code() == CFlightStatusInternal:
+ raise FlightInternalError(message, detail_msg)
+ elif detail.get().code() == CFlightStatusFailed:
+ message = _munge_grpc_python_error(message)
+ raise FlightServerError(message, detail_msg)
+ elif detail.get().code() == CFlightStatusTimedOut:
+ raise FlightTimedOutError(message, detail_msg)
+ elif detail.get().code() == CFlightStatusCancelled:
+ raise FlightCancelledError(message, detail_msg)
+ elif detail.get().code() == CFlightStatusUnauthenticated:
+ raise FlightUnauthenticatedError(message, detail_msg)
+ elif detail.get().code() == CFlightStatusUnauthorized:
+ raise FlightUnauthorizedError(message, detail_msg)
+ elif detail.get().code() == CFlightStatusUnavailable:
+ raise FlightUnavailableError(message, detail_msg)
+
+ size_detail = FlightWriteSizeStatusDetail.UnwrapStatus(status)
+ if size_detail:
+ with gil:
+ message = frombytes(status.message(), safe=True)
+ raise FlightWriteSizeExceededError(
+ message,
+ size_detail.get().limit(), size_detail.get().actual())
+
+ return check_status(status)
+
+
+_FLIGHT_SERVER_ERROR_REGEX = re.compile(
+ r'Flight RPC failed with message: (.*). Detail: '
+ r'Python exception: (.*)',
+ re.DOTALL
+)
+
+
+def _munge_grpc_python_error(message):
+ m = _FLIGHT_SERVER_ERROR_REGEX.match(message)
+ if m:
+ return ('Flight RPC failed with Python exception \"{}: {}\"'
+ .format(m.group(2), m.group(1)))
+ else:
+ return message
+
+
+cdef IpcWriteOptions _get_options(options):
+ return <IpcWriteOptions> _get_legacy_format_default(
+ use_legacy_format=None, options=options)
+
+
+cdef class FlightCallOptions(_Weakrefable):
+ """RPC-layer options for a Flight call."""
+
+ cdef:
+ CFlightCallOptions options
+
+ def __init__(self, timeout=None, write_options=None, headers=None):
+ """Create call options.
+
+ Parameters
+ ----------
+ timeout : float, None
+ A timeout for the call, in seconds. None means that the
+ timeout defaults to an implementation-specific value.
+ write_options : pyarrow.ipc.IpcWriteOptions, optional
+ IPC write options. The default options can be controlled
+ by environment variables (see pyarrow.ipc).
+ headers : List[Tuple[str, str]], optional
+ A list of arbitrary headers as key, value tuples
+ """
+ cdef IpcWriteOptions c_write_options
+
+ if timeout is not None:
+ self.options.timeout = CTimeoutDuration(timeout)
+ if write_options is not None:
+ c_write_options = _get_options(write_options)
+ self.options.write_options = c_write_options.c_options
+ if headers is not None:
+ self.options.headers = headers
+
+ @staticmethod
+ cdef CFlightCallOptions* unwrap(obj):
+ if not obj:
+ return &DEFAULT_CALL_OPTIONS
+ elif isinstance(obj, FlightCallOptions):
+ return &((<FlightCallOptions> obj).options)
+ raise TypeError("Expected a FlightCallOptions object, not "
+ "'{}'".format(type(obj)))
+
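+# Illustrative sketch: call options can put a time limit on an RPC and attach
+# custom headers; the values below are hypothetical.
+#
+#   >>> import pyarrow.flight as flight
+#   >>> options = flight.FlightCallOptions(
+#   ...     timeout=10.0,
+#   ...     headers=[(b"x-custom-header", b"value")])
+#   >>> # e.g. client.do_get(ticket, options=options)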
+
+_CertKeyPair = collections.namedtuple('_CertKeyPair', ['cert', 'key'])
+
+
+class CertKeyPair(_CertKeyPair):
+ """A TLS certificate and key for use in Flight."""
+
+
+cdef class FlightError(Exception):
+ cdef dict __dict__
+
+ def __init__(self, message='', extra_info=b''):
+ super().__init__(message)
+ self.extra_info = tobytes(extra_info)
+
+ cdef CStatus to_status(self):
+ message = tobytes("Flight error: {}".format(str(self)))
+ return CStatus_UnknownError(message)
+
+cdef class FlightInternalError(FlightError, ArrowException):
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusInternal,
+ tobytes(str(self)), self.extra_info)
+
+
+cdef class FlightTimedOutError(FlightError, ArrowException):
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusTimedOut,
+ tobytes(str(self)), self.extra_info)
+
+
+cdef class FlightCancelledError(FlightError, ArrowException):
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusCancelled, tobytes(str(self)),
+ self.extra_info)
+
+
+cdef class FlightServerError(FlightError, ArrowException):
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusFailed, tobytes(str(self)),
+ self.extra_info)
+
+
+cdef class FlightUnauthenticatedError(FlightError, ArrowException):
+ cdef CStatus to_status(self):
+ return MakeFlightError(
+ CFlightStatusUnauthenticated, tobytes(str(self)), self.extra_info)
+
+
+cdef class FlightUnauthorizedError(FlightError, ArrowException):
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusUnauthorized, tobytes(str(self)),
+ self.extra_info)
+
+
+cdef class FlightUnavailableError(FlightError, ArrowException):
+ cdef CStatus to_status(self):
+ return MakeFlightError(CFlightStatusUnavailable, tobytes(str(self)),
+ self.extra_info)
+
+
+class FlightWriteSizeExceededError(ArrowInvalid):
+ """A write operation exceeded the client-configured limit."""
+
+ def __init__(self, message, limit, actual):
+ super().__init__(message)
+ self.limit = limit
+ self.actual = actual
+
+
+cdef class Action(_Weakrefable):
+ """An action executable on a Flight service."""
+ cdef:
+ CAction action
+
+ def __init__(self, action_type, buf):
+ """Create an action from a type and a buffer.
+
+ Parameters
+ ----------
+ action_type : bytes or str
+ buf : Buffer or bytes-like object
+ """
+ self.action.type = tobytes(action_type)
+ self.action.body = pyarrow_unwrap_buffer(as_buffer(buf))
+
+ @property
+ def type(self):
+ """The action type."""
+ return frombytes(self.action.type)
+
+ @property
+ def body(self):
+ """The action body (arguments for the action)."""
+ return pyarrow_wrap_buffer(self.action.body)
+
+ @staticmethod
+ cdef CAction unwrap(action) except *:
+ if not isinstance(action, Action):
+ raise TypeError("Must provide Action, not '{}'".format(
+ type(action)))
+ return (<Action> action).action
+
+
+_ActionType = collections.namedtuple('_ActionType', ['type', 'description'])
+
+
+class ActionType(_ActionType):
+ """A type of action that is executable on a Flight service."""
+
+ def make_action(self, buf):
+ """Create an Action with this type.
+
+ Parameters
+ ----------
+ buf : obj
+ An Arrow buffer or Python bytes or bytes-like object.
+ """
+ return Action(self.type, buf)
+
+
+cdef class Result(_Weakrefable):
+ """A result from executing an Action."""
+ cdef:
+ unique_ptr[CFlightResult] result
+
+ def __init__(self, buf):
+ """Create a new result.
+
+ Parameters
+ ----------
+ buf : Buffer or bytes-like object
+ """
+ self.result.reset(new CFlightResult())
+ self.result.get().body = pyarrow_unwrap_buffer(as_buffer(buf))
+
+ @property
+ def body(self):
+ """Get the Buffer containing the result."""
+ return pyarrow_wrap_buffer(self.result.get().body)
+
+
+cdef class BasicAuth(_Weakrefable):
+ """A container for basic auth."""
+ cdef:
+ unique_ptr[CBasicAuth] basic_auth
+
+ def __init__(self, username=None, password=None):
+ """Create a new basic auth object.
+
+ Parameters
+ ----------
+ username : string
+ password : string
+ """
+ self.basic_auth.reset(new CBasicAuth())
+ if username:
+ self.basic_auth.get().username = tobytes(username)
+ if password:
+ self.basic_auth.get().password = tobytes(password)
+
+ @property
+ def username(self):
+ """Get the username."""
+ return self.basic_auth.get().username
+
+ @property
+ def password(self):
+ """Get the password."""
+ return self.basic_auth.get().password
+
+ @staticmethod
+ def deserialize(string):
+ auth = BasicAuth()
+ check_flight_status(DeserializeBasicAuth(string, &auth.basic_auth))
+ return auth
+
+ def serialize(self):
+ cdef:
+ c_string auth
+ check_flight_status(SerializeBasicAuth(deref(self.basic_auth), &auth))
+ return frombytes(auth)
+
+
+class DescriptorType(enum.Enum):
+ """
+ The type of a FlightDescriptor.
+
+ Attributes
+ ----------
+
+ UNKNOWN
+ An unknown descriptor type.
+
+ PATH
+ A Flight stream represented by a path.
+
+ CMD
+ A Flight stream represented by an application-defined command.
+
+ """
+
+ UNKNOWN = 0
+ PATH = 1
+ CMD = 2
+
+
+class FlightMethod(enum.Enum):
+ """The implemented methods in Flight."""
+
+ INVALID = 0
+ HANDSHAKE = 1
+ LIST_FLIGHTS = 2
+ GET_FLIGHT_INFO = 3
+ GET_SCHEMA = 4
+ DO_GET = 5
+ DO_PUT = 6
+ DO_ACTION = 7
+ LIST_ACTIONS = 8
+ DO_EXCHANGE = 9
+
+
+cdef wrap_flight_method(CFlightMethod method):
+ if method == CFlightMethodHandshake:
+ return FlightMethod.HANDSHAKE
+ elif method == CFlightMethodListFlights:
+ return FlightMethod.LIST_FLIGHTS
+ elif method == CFlightMethodGetFlightInfo:
+ return FlightMethod.GET_FLIGHT_INFO
+ elif method == CFlightMethodGetSchema:
+ return FlightMethod.GET_SCHEMA
+ elif method == CFlightMethodDoGet:
+ return FlightMethod.DO_GET
+ elif method == CFlightMethodDoPut:
+ return FlightMethod.DO_PUT
+ elif method == CFlightMethodDoAction:
+ return FlightMethod.DO_ACTION
+ elif method == CFlightMethodListActions:
+ return FlightMethod.LIST_ACTIONS
+ elif method == CFlightMethodDoExchange:
+ return FlightMethod.DO_EXCHANGE
+ return FlightMethod.INVALID
+
+
+cdef class FlightDescriptor(_Weakrefable):
+ """A description of a data stream available from a Flight service."""
+ cdef:
+ CFlightDescriptor descriptor
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use "
+ "`pyarrow.flight.FlightDescriptor.for_{path,command}` "
+ "function instead."
+ .format(self.__class__.__name__))
+
+ @staticmethod
+ def for_path(*path):
+ """Create a FlightDescriptor for a resource path."""
+ cdef FlightDescriptor result = \
+ FlightDescriptor.__new__(FlightDescriptor)
+ result.descriptor.type = CDescriptorTypePath
+ result.descriptor.path = [tobytes(p) for p in path]
+ return result
+
+ @staticmethod
+ def for_command(command):
+ """Create a FlightDescriptor for an opaque command."""
+ cdef FlightDescriptor result = \
+ FlightDescriptor.__new__(FlightDescriptor)
+ result.descriptor.type = CDescriptorTypeCmd
+ result.descriptor.cmd = tobytes(command)
+ return result
+
+ @property
+ def descriptor_type(self):
+ """Get the type of this descriptor."""
+ if self.descriptor.type == CDescriptorTypeUnknown:
+ return DescriptorType.UNKNOWN
+ elif self.descriptor.type == CDescriptorTypePath:
+ return DescriptorType.PATH
+ elif self.descriptor.type == CDescriptorTypeCmd:
+ return DescriptorType.CMD
+ raise RuntimeError("Invalid descriptor type!")
+
+ @property
+ def command(self):
+ """Get the command for this descriptor."""
+ if self.descriptor_type != DescriptorType.CMD:
+ return None
+ return self.descriptor.cmd
+
+ @property
+ def path(self):
+ """Get the path for this descriptor."""
+ if self.descriptor_type != DescriptorType.PATH:
+ return None
+ return self.descriptor.path
+
+ def __repr__(self):
+ if self.descriptor_type == DescriptorType.PATH:
+ return "<FlightDescriptor path: {!r}>".format(self.path)
+ elif self.descriptor_type == DescriptorType.CMD:
+ return "<FlightDescriptor command: {!r}>".format(self.command)
+ else:
+ return "<FlightDescriptor type: {!r}>".format(self.descriptor_type)
+
+ @staticmethod
+ cdef CFlightDescriptor unwrap(descriptor) except *:
+ if not isinstance(descriptor, FlightDescriptor):
+ raise TypeError("Must provide a FlightDescriptor, not '{}'".format(
+ type(descriptor)))
+ return (<FlightDescriptor> descriptor).descriptor
+
+ def serialize(self):
+ """Get the wire-format representation of this type.
+
+ Useful when interoperating with non-Flight systems (e.g. REST
+ services) that may want to return Flight types.
+
+ """
+ cdef c_string out
+ check_flight_status(self.descriptor.SerializeToString(&out))
+ return out
+
+ @classmethod
+ def deserialize(cls, serialized):
+ """Parse the wire-format representation of this type.
+
+ Useful when interoperating with non-Flight systems (e.g. REST
+ services) that may want to return Flight types.
+
+ """
+ cdef FlightDescriptor descriptor = \
+ FlightDescriptor.__new__(FlightDescriptor)
+ check_flight_status(CFlightDescriptor.Deserialize(
+ tobytes(serialized), &descriptor.descriptor))
+ return descriptor
+
+ def __eq__(self, FlightDescriptor other):
+ return self.descriptor == other.descriptor
+
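+# Illustrative sketch of building descriptors and the wire-format round trip
+# used when interoperating with non-Flight systems; the path segments and
+# command payload are hypothetical.
+#
+#   >>> import pyarrow.flight as flight
+#   >>> desc = flight.FlightDescriptor.for_path("datasets", "2021", "part-0")
+#   >>> cmd = flight.FlightDescriptor.for_command(b'{"query": "SELECT 1"}')
+#   >>> flight.FlightDescriptor.deserialize(desc.serialize()) == desc
+#   True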
+
+cdef class Ticket(_Weakrefable):
+ """A ticket for requesting a Flight stream."""
+
+ cdef:
+ CTicket ticket
+
+ def __init__(self, ticket):
+ self.ticket.ticket = tobytes(ticket)
+
+ @property
+ def ticket(self):
+ return self.ticket.ticket
+
+ def serialize(self):
+ """Get the wire-format representation of this type.
+
+ Useful when interoperating with non-Flight systems (e.g. REST
+ services) that may want to return Flight types.
+
+ """
+ cdef c_string out
+ check_flight_status(self.ticket.SerializeToString(&out))
+ return out
+
+ @classmethod
+ def deserialize(cls, serialized):
+ """Parse the wire-format representation of this type.
+
+ Useful when interoperating with non-Flight systems (e.g. REST
+ services) that may want to return Flight types.
+
+ """
+ cdef:
+ CTicket c_ticket
+ Ticket ticket
+ check_flight_status(
+ CTicket.Deserialize(tobytes(serialized), &c_ticket))
+ ticket = Ticket.__new__(Ticket)
+ ticket.ticket = c_ticket
+ return ticket
+
+ def __eq__(self, Ticket other):
+ return self.ticket == other.ticket
+
+ def __repr__(self):
+ return '<Ticket {}>'.format(self.ticket.ticket)
+
+
+cdef class Location(_Weakrefable):
+ """The location of a Flight service."""
+ cdef:
+ CLocation location
+
+ def __init__(self, uri):
+ check_flight_status(CLocation.Parse(tobytes(uri), &self.location))
+
+ def __repr__(self):
+ return '<Location {}>'.format(self.location.ToString())
+
+ @property
+ def uri(self):
+ return self.location.ToString()
+
+ def equals(self, Location other):
+ return self == other
+
+ def __eq__(self, other):
+ if not isinstance(other, Location):
+ return NotImplemented
+ return self.location.Equals((<Location> other).location)
+
+ @staticmethod
+ def for_grpc_tcp(host, port):
+ """Create a Location for a TCP-based gRPC service."""
+ cdef:
+ c_string c_host = tobytes(host)
+ int c_port = port
+ Location result = Location.__new__(Location)
+ check_flight_status(
+ CLocation.ForGrpcTcp(c_host, c_port, &result.location))
+ return result
+
+ @staticmethod
+ def for_grpc_tls(host, port):
+ """Create a Location for a TLS-based gRPC service."""
+ cdef:
+ c_string c_host = tobytes(host)
+ int c_port = port
+ Location result = Location.__new__(Location)
+ check_flight_status(
+ CLocation.ForGrpcTls(c_host, c_port, &result.location))
+ return result
+
+ @staticmethod
+ def for_grpc_unix(path):
+ """Create a Location for a domain socket-based gRPC service."""
+ cdef:
+ c_string c_path = tobytes(path)
+ Location result = Location.__new__(Location)
+ check_flight_status(CLocation.ForGrpcUnix(c_path, &result.location))
+ return result
+
+ @staticmethod
+ cdef Location wrap(CLocation location):
+ cdef Location result = Location.__new__(Location)
+ result.location = location
+ return result
+
+ @staticmethod
+ cdef CLocation unwrap(object location) except *:
+ cdef CLocation c_location
+ if isinstance(location, str):
+ check_flight_status(
+ CLocation.Parse(tobytes(location), &c_location))
+ return c_location
+ elif not isinstance(location, Location):
+ raise TypeError("Must provide a Location, not '{}'".format(
+ type(location)))
+ return (<Location> location).location
+
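+# Illustrative sketch: a Location can be parsed from a URI or built with the
+# helpers above (host/port values are hypothetical).
+#
+#   >>> import pyarrow.flight as flight
+#   >>> flight.Location("grpc+tcp://localhost:8815")
+#   >>> flight.Location.for_grpc_tcp("localhost", 8815)
+#   >>> flight.Location.for_grpc_tls("example.com", 443)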
+
+cdef class FlightEndpoint(_Weakrefable):
+ """A Flight stream, along with the ticket and locations to access it."""
+ cdef:
+ CFlightEndpoint endpoint
+
+ def __init__(self, ticket, locations):
+ """Create a FlightEndpoint from a ticket and list of locations.
+
+ Parameters
+ ----------
+ ticket : Ticket or bytes
+ the ticket needed to access this flight
+ locations : list of string URIs
+ locations where this flight is available
+
+ Raises
+ ------
+ ArrowException
+ If one of the location URIs is not a valid URI.
+ """
+ cdef:
+ CLocation c_location
+
+ if isinstance(ticket, Ticket):
+ self.endpoint.ticket.ticket = tobytes(ticket.ticket)
+ else:
+ self.endpoint.ticket.ticket = tobytes(ticket)
+
+ for location in locations:
+ if isinstance(location, Location):
+ c_location = (<Location> location).location
+ else:
+ c_location = CLocation()
+ check_flight_status(
+ CLocation.Parse(tobytes(location), &c_location))
+ self.endpoint.locations.push_back(c_location)
+
+ @property
+ def ticket(self):
+ """Get the ticket in this endpoint."""
+ return Ticket(self.endpoint.ticket.ticket)
+
+ @property
+ def locations(self):
+ return [Location.wrap(location)
+ for location in self.endpoint.locations]
+
+ def __repr__(self):
+ return "<FlightEndpoint ticket: {!r} locations: {!r}>".format(
+ self.ticket, self.locations)
+
+ def __eq__(self, FlightEndpoint other):
+ return self.endpoint == other.endpoint
+
+
+cdef class SchemaResult(_Weakrefable):
+    """A result from a GetSchema request, holding a schema."""
+ cdef:
+ unique_ptr[CSchemaResult] result
+
+ def __init__(self, Schema schema):
+ """Create a SchemaResult from a schema.
+
+ Parameters
+ ----------
+        schema : Schema
+ the schema of the data in this flight.
+ """
+ cdef:
+ shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
+ check_flight_status(CreateSchemaResult(c_schema, &self.result))
+
+ @property
+ def schema(self):
+ """The schema of the data in this flight."""
+ cdef:
+ shared_ptr[CSchema] schema
+ CDictionaryMemo dummy_memo
+
+ check_flight_status(self.result.get().GetSchema(&dummy_memo, &schema))
+ return pyarrow_wrap_schema(schema)
+
+
+cdef class FlightInfo(_Weakrefable):
+ """A description of a Flight stream."""
+ cdef:
+ unique_ptr[CFlightInfo] info
+
+ def __init__(self, Schema schema, FlightDescriptor descriptor, endpoints,
+ total_records, total_bytes):
+ """Create a FlightInfo object from a schema, descriptor, and endpoints.
+
+ Parameters
+ ----------
+ schema : Schema
+ the schema of the data in this flight.
+ descriptor : FlightDescriptor
+ the descriptor for this flight.
+ endpoints : list of FlightEndpoint
+ a list of endpoints where this flight is available.
+ total_records : int
+ the total records in this flight, or -1 if unknown
+ total_bytes : int
+ the total bytes in this flight, or -1 if unknown
+ """
+ cdef:
+ shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
+ vector[CFlightEndpoint] c_endpoints
+
+ for endpoint in endpoints:
+ if isinstance(endpoint, FlightEndpoint):
+ c_endpoints.push_back((<FlightEndpoint> endpoint).endpoint)
+ else:
+ raise TypeError('Endpoint {} is not instance of'
+ ' FlightEndpoint'.format(endpoint))
+
+ check_flight_status(CreateFlightInfo(c_schema,
+ descriptor.descriptor,
+ c_endpoints,
+ total_records,
+ total_bytes, &self.info))
+
+ @property
+ def total_records(self):
+ """The total record count of this flight, or -1 if unknown."""
+ return self.info.get().total_records()
+
+ @property
+ def total_bytes(self):
+ """The size in bytes of the data in this flight, or -1 if unknown."""
+ return self.info.get().total_bytes()
+
+ @property
+ def schema(self):
+ """The schema of the data in this flight."""
+ cdef:
+ shared_ptr[CSchema] schema
+ CDictionaryMemo dummy_memo
+
+ check_flight_status(self.info.get().GetSchema(&dummy_memo, &schema))
+ return pyarrow_wrap_schema(schema)
+
+ @property
+ def descriptor(self):
+ """The descriptor of the data in this flight."""
+ cdef FlightDescriptor result = \
+ FlightDescriptor.__new__(FlightDescriptor)
+ result.descriptor = self.info.get().descriptor()
+ return result
+
+ @property
+ def endpoints(self):
+ """The endpoints where this flight is available."""
+ # TODO: get Cython to iterate over reference directly
+ cdef:
+ vector[CFlightEndpoint] endpoints = self.info.get().endpoints()
+ FlightEndpoint py_endpoint
+
+ result = []
+ for endpoint in endpoints:
+ py_endpoint = FlightEndpoint.__new__(FlightEndpoint)
+ py_endpoint.endpoint = endpoint
+ result.append(py_endpoint)
+ return result
+
+ def serialize(self):
+ """Get the wire-format representation of this type.
+
+ Useful when interoperating with non-Flight systems (e.g. REST
+ services) that may want to return Flight types.
+
+ """
+ cdef c_string out
+ check_flight_status(self.info.get().SerializeToString(&out))
+ return out
+
+ @classmethod
+ def deserialize(cls, serialized):
+ """Parse the wire-format representation of this type.
+
+ Useful when interoperating with non-Flight systems (e.g. REST
+ services) that may want to return Flight types.
+
+ """
+ cdef FlightInfo info = FlightInfo.__new__(FlightInfo)
+ check_flight_status(CFlightInfo.Deserialize(
+ tobytes(serialized), &info.info))
+ return info
+
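+# Illustrative server-side sketch of assembling a FlightInfo; the schema,
+# ticket bytes and location URI are hypothetical placeholders.
+#
+#   >>> import pyarrow as pa
+#   >>> import pyarrow.flight as flight
+#   >>> schema = pa.schema([("a", pa.int64())])
+#   >>> desc = flight.FlightDescriptor.for_path("example")
+#   >>> endpoint = flight.FlightEndpoint(b"ticket-bytes",
+#   ...                                  ["grpc+tcp://localhost:8815"])
+#   >>> info = flight.FlightInfo(schema, desc, [endpoint],
+#   ...                          total_records=-1, total_bytes=-1)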
+
+cdef class FlightStreamChunk(_Weakrefable):
+ """A RecordBatch with application metadata on the side."""
+ cdef:
+ CFlightStreamChunk chunk
+
+ @property
+ def data(self):
+ if self.chunk.data == NULL:
+ return None
+ return pyarrow_wrap_batch(self.chunk.data)
+
+ @property
+ def app_metadata(self):
+ if self.chunk.app_metadata == NULL:
+ return None
+ return pyarrow_wrap_buffer(self.chunk.app_metadata)
+
+ def __iter__(self):
+ return iter((self.data, self.app_metadata))
+
+ def __repr__(self):
+ return "<FlightStreamChunk with data: {} with metadata: {}>".format(
+ self.chunk.data != NULL, self.chunk.app_metadata != NULL)
+
+
+cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin):
+ """A reader for Flight streams."""
+
+ # Needs to be separate class so the "real" class can subclass the
+ # pure-Python mixin class
+
+ cdef dict __dict__
+ cdef shared_ptr[CMetadataRecordBatchReader] reader
+
+ def __iter__(self):
+ while True:
+ yield self.read_chunk()
+
+ @property
+ def schema(self):
+ """Get the schema for this reader."""
+ cdef shared_ptr[CSchema] c_schema
+ with nogil:
+ c_schema = GetResultValue(self.reader.get().GetSchema())
+ return pyarrow_wrap_schema(c_schema)
+
+ def read_all(self):
+ """Read the entire contents of the stream as a Table."""
+ cdef:
+ shared_ptr[CTable] c_table
+ with nogil:
+ check_flight_status(self.reader.get().ReadAll(&c_table))
+ return pyarrow_wrap_table(c_table)
+
+ def read_chunk(self):
+ """Read the next RecordBatch along with any metadata.
+
+ Returns
+ -------
+ data : RecordBatch
+ The next RecordBatch in the stream.
+ app_metadata : Buffer or None
+ Application-specific metadata for the batch as defined by
+ Flight.
+
+ Raises
+ ------
+ StopIteration
+ when the stream is finished
+ """
+ cdef:
+ FlightStreamChunk chunk = FlightStreamChunk()
+
+ with nogil:
+ check_flight_status(self.reader.get().Next(&chunk.chunk))
+
+ if chunk.chunk.data == NULL and chunk.chunk.app_metadata == NULL:
+ raise StopIteration
+
+ return chunk
+
+ def to_reader(self):
+ """Convert this reader into a regular RecordBatchReader.
+
+ This may fail if the schema cannot be read from the remote end.
+ """
+ cdef RecordBatchReader reader
+ reader = RecordBatchReader.__new__(RecordBatchReader)
+ reader.reader = GetResultValue(MakeRecordBatchReader(self.reader))
+ return reader
+
+
+cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader):
+ """The virtual base class for readers for Flight streams."""
+
+
+cdef class FlightStreamReader(MetadataRecordBatchReader):
+ """A reader that can also be canceled."""
+
+ def cancel(self):
+ """Cancel the read operation."""
+ with nogil:
+ (<CFlightStreamReader*> self.reader.get()).Cancel()
+
+ def read_all(self):
+ """Read the entire contents of the stream as a Table."""
+ cdef:
+ shared_ptr[CTable] c_table
+ CStopToken stop_token
+ with SignalStopHandler() as stop_handler:
+ stop_token = (<StopToken> stop_handler.stop_token).stop_token
+ with nogil:
+ check_flight_status(
+ (<CFlightStreamReader*> self.reader.get())
+ .ReadAllWithStopToken(&c_table, stop_token))
+ return pyarrow_wrap_table(c_table)
+
+
+cdef class MetadataRecordBatchWriter(_CRecordBatchWriter):
+ """A RecordBatchWriter that also allows writing application metadata.
+
+ This class is a context manager; on exit, close() will be called.
+ """
+
+ cdef CMetadataRecordBatchWriter* _writer(self) nogil:
+ return <CMetadataRecordBatchWriter*> self.writer.get()
+
+ def begin(self, schema: Schema, options=None):
+ """Prepare to write data to this stream with the given schema."""
+ cdef:
+ shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
+ CIpcWriteOptions c_options = _get_options(options).c_options
+ with nogil:
+ check_flight_status(self._writer().Begin(c_schema, c_options))
+
+ def write_metadata(self, buf):
+ """Write Flight metadata by itself."""
+ cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf))
+ with nogil:
+ check_flight_status(
+ self._writer().WriteMetadata(c_buf))
+
+ def write_batch(self, RecordBatch batch):
+ """
+ Write RecordBatch to stream.
+
+ Parameters
+ ----------
+ batch : RecordBatch
+ """
+ # Override superclass method to use check_flight_status so we
+ # can generate FlightWriteSizeExceededError. We don't do this
+ # for write_table as callers who intend to handle the error
+ # and retry with a smaller batch should be working with
+ # individual batches to have control.
+ with nogil:
+ check_flight_status(
+ self._writer().WriteRecordBatch(deref(batch.batch)))
+
+ def write_table(self, Table table, max_chunksize=None, **kwargs):
+ """
+ Write Table to stream in (contiguous) RecordBatch objects.
+
+ Parameters
+ ----------
+ table : Table
+ max_chunksize : int, default None
+ Maximum size for RecordBatch chunks. Individual chunks may be
+ smaller depending on the chunk layout of individual columns.
+ """
+ cdef:
+ # max_chunksize must be > 0 to have any impact
+ int64_t c_max_chunksize = -1
+
+ if 'chunksize' in kwargs:
+ max_chunksize = kwargs['chunksize']
+ msg = ('The parameter chunksize is deprecated for the write_table '
+ 'methods as of 0.15, please use parameter '
+ 'max_chunksize instead')
+ warnings.warn(msg, FutureWarning)
+
+ if max_chunksize is not None:
+ c_max_chunksize = max_chunksize
+
+ with nogil:
+ check_flight_status(
+ self._writer().WriteTable(table.table[0], c_max_chunksize))
+
+ def close(self):
+ """
+ Close stream and write end-of-stream 0 marker.
+ """
+ with nogil:
+ check_flight_status(self._writer().Close())
+
+ def write_with_metadata(self, RecordBatch batch, buf):
+ """Write a RecordBatch along with Flight metadata.
+
+ Parameters
+ ----------
+ batch : RecordBatch
+ The next RecordBatch in the stream.
+ buf : Buffer
+ Application-specific metadata for the batch as defined by
+ Flight.
+ """
+ cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf))
+ with nogil:
+ check_flight_status(
+ self._writer().WriteWithMetadata(deref(batch.batch), c_buf))
+
+
+cdef class FlightStreamWriter(MetadataRecordBatchWriter):
+ """A writer that also allows closing the write side of a stream."""
+
+ def done_writing(self):
+ """Indicate that the client is done writing, but not done reading."""
+ with nogil:
+ check_flight_status(
+ (<CFlightStreamWriter*> self.writer.get()).DoneWriting())
+
+
+cdef class FlightMetadataReader(_Weakrefable):
+ """A reader for Flight metadata messages sent during a DoPut."""
+
+ cdef:
+ unique_ptr[CFlightMetadataReader] reader
+
+ def read(self):
+ """Read the next metadata message."""
+ cdef shared_ptr[CBuffer] buf
+ with nogil:
+ check_flight_status(self.reader.get().ReadMetadata(&buf))
+ if buf == NULL:
+ return None
+ return pyarrow_wrap_buffer(buf)
+
+
+cdef class FlightMetadataWriter(_Weakrefable):
+ """A sender for Flight metadata messages during a DoPut."""
+
+ cdef:
+ unique_ptr[CFlightMetadataWriter] writer
+
+ def write(self, message):
+ """Write the next metadata message.
+
+ Parameters
+ ----------
+ message : Buffer
+ """
+ cdef shared_ptr[CBuffer] buf = \
+ pyarrow_unwrap_buffer(as_buffer(message))
+ with nogil:
+ check_flight_status(self.writer.get().WriteMetadata(deref(buf)))
+
+
+cdef class FlightClient(_Weakrefable):
+ """A client to a Flight service.
+
+ Connect to a Flight service on the given host and port.
+
+ Parameters
+ ----------
+ location : str, tuple or Location
+ Location to connect to. Either a gRPC URI like `grpc://localhost:port`,
+        a (host, port) tuple, or a Location instance.
+ tls_root_certs : bytes or None
+        PEM-encoded root certificates for TLS.
+    cert_chain : bytes or None
+        Client certificate if using mutual TLS.
+    private_key : bytes or None
+        Client private key for cert_chain if using mutual TLS.
+ override_hostname : str or None
+ Override the hostname checked by TLS. Insecure, use with caution.
+    middleware : list, default None
+ A list of ClientMiddlewareFactory instances.
+    write_size_limit_bytes : int, default None
+ A soft limit on the size of a data payload sent to the
+ server. Enabled if positive. If enabled, writing a record
+ batch that (when serialized) exceeds this limit will raise an
+ exception; the client can retry the write with a smaller
+ batch.
+    disable_server_verification : bool, default False
+        If True and the client is connecting with TLS, skip server
+        verification. If this is enabled, all other TLS settings are
+        overridden.
+    generic_options : list, default None
+ A list of generic (string, int or string) option tuples passed
+ to the underlying transport. Effect is implementation
+ dependent.
+ """
+ cdef:
+ unique_ptr[CFlightClient] client
+
+ def __init__(self, location, *, tls_root_certs=None, cert_chain=None,
+ private_key=None, override_hostname=None, middleware=None,
+ write_size_limit_bytes=None,
+ disable_server_verification=None, generic_options=None):
+ if isinstance(location, (bytes, str)):
+ location = Location(location)
+ elif isinstance(location, tuple):
+ host, port = location
+ if tls_root_certs or disable_server_verification is not None:
+ location = Location.for_grpc_tls(host, port)
+ else:
+ location = Location.for_grpc_tcp(host, port)
+ elif not isinstance(location, Location):
+ raise TypeError('`location` argument must be a string, tuple or a '
+ 'Location instance')
+ self.init(location, tls_root_certs, cert_chain, private_key,
+ override_hostname, middleware, write_size_limit_bytes,
+ disable_server_verification, generic_options)
+
+ cdef init(self, Location location, tls_root_certs, cert_chain,
+ private_key, override_hostname, middleware,
+ write_size_limit_bytes, disable_server_verification,
+ generic_options):
+ cdef:
+ int c_port = 0
+ CLocation c_location = Location.unwrap(location)
+ CFlightClientOptions c_options = CFlightClientOptions.Defaults()
+ function[cb_client_middleware_start_call] start_call = \
+ &_client_middleware_start_call
+ CIntStringVariant variant
+
+ if tls_root_certs:
+ c_options.tls_root_certs = tobytes(tls_root_certs)
+ if cert_chain:
+ c_options.cert_chain = tobytes(cert_chain)
+ if private_key:
+ c_options.private_key = tobytes(private_key)
+ if override_hostname:
+ c_options.override_hostname = tobytes(override_hostname)
+ if disable_server_verification is not None:
+ c_options.disable_server_verification = disable_server_verification
+ if middleware:
+ for factory in middleware:
+ c_options.middleware.push_back(
+ <shared_ptr[CClientMiddlewareFactory]>
+ make_shared[CPyClientMiddlewareFactory](
+ <PyObject*> factory, start_call))
+ if write_size_limit_bytes is not None:
+ c_options.write_size_limit_bytes = write_size_limit_bytes
+ else:
+ c_options.write_size_limit_bytes = 0
+ if generic_options:
+ for key, value in generic_options:
+ if isinstance(value, (str, bytes)):
+ variant = CIntStringVariant(<c_string> tobytes(value))
+ else:
+ variant = CIntStringVariant(<int> value)
+ c_options.generic_options.push_back(
+ pair[c_string, CIntStringVariant](tobytes(key), variant))
+
+ with nogil:
+ check_flight_status(CFlightClient.Connect(c_location, c_options,
+ &self.client))
+
+ def wait_for_available(self, timeout=5):
+ """Block until the server can be contacted.
+
+ Parameters
+ ----------
+ timeout : int, default 5
+ The maximum seconds to wait.
+ """
+ deadline = time.time() + timeout
+ while True:
+ try:
+ list(self.list_flights())
+ except FlightUnavailableError:
+ if time.time() < deadline:
+ time.sleep(0.025)
+ continue
+ else:
+ raise
+ except NotImplementedError:
+ # allow if list_flights is not implemented, because
+ # the server can be contacted nonetheless
+ break
+ else:
+ break
+
+ @classmethod
+ def connect(cls, location, tls_root_certs=None, cert_chain=None,
+ private_key=None, override_hostname=None,
+ disable_server_verification=None):
+ warnings.warn("The 'FlightClient.connect' method is deprecated, use "
+ "FlightClient constructor or pyarrow.flight.connect "
+ "function instead")
+ return FlightClient(
+ location, tls_root_certs=tls_root_certs,
+ cert_chain=cert_chain, private_key=private_key,
+ override_hostname=override_hostname,
+ disable_server_verification=disable_server_verification
+ )
+
+ def authenticate(self, auth_handler, options: FlightCallOptions = None):
+ """Authenticate to the server.
+
+ Parameters
+ ----------
+ auth_handler : ClientAuthHandler
+ The authentication mechanism to use.
+ options : FlightCallOptions
+ Options for this call.
+ """
+ cdef:
+ unique_ptr[CClientAuthHandler] handler
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+
+ if not isinstance(auth_handler, ClientAuthHandler):
+ raise TypeError(
+ "FlightClient.authenticate takes a ClientAuthHandler, "
+ "not '{}'".format(type(auth_handler)))
+ handler.reset((<ClientAuthHandler> auth_handler).to_handler())
+ with nogil:
+ check_flight_status(
+ self.client.get().Authenticate(deref(c_options),
+ move(handler)))
+
+ def authenticate_basic_token(self, username, password,
+ options: FlightCallOptions = None):
+ """Authenticate to the server with HTTP basic authentication.
+
+ Parameters
+ ----------
+ username : string
+ Username to authenticate with
+ password : string
+ Password to authenticate with
+ options : FlightCallOptions
+ Options for this call
+
+ Returns
+ -------
+ tuple : Tuple[str, str]
+            A tuple representing the authorization header entry of a
+            bearer token, which can be passed to subsequent calls via
+            FlightCallOptions headers.
+ """
+ cdef:
+ CResult[pair[c_string, c_string]] result
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ c_string user = tobytes(username)
+ c_string pw = tobytes(password)
+
+ with nogil:
+ result = self.client.get().AuthenticateBasicToken(deref(c_options),
+ user, pw)
+ check_flight_status(result.status())
+
+ return GetResultValue(result)
+
+ def list_actions(self, options: FlightCallOptions = None):
+ """List the actions available on a service."""
+ cdef:
+ vector[CActionType] results
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+
+ with SignalStopHandler() as stop_handler:
+ c_options.stop_token = \
+ (<StopToken> stop_handler.stop_token).stop_token
+ with nogil:
+ check_flight_status(
+ self.client.get().ListActions(deref(c_options), &results))
+
+ result = []
+ for action_type in results:
+ py_action = ActionType(frombytes(action_type.type),
+ frombytes(action_type.description))
+ result.append(py_action)
+
+ return result
+
+ def do_action(self, action, options: FlightCallOptions = None):
+ """
+ Execute an action on a service.
+
+ Parameters
+ ----------
+ action : str, tuple, or Action
+ Can be action type name (no body), type and body, or any Action
+ object
+ options : FlightCallOptions
+ RPC options
+
+ Returns
+ -------
+ results : iterator of Result values
+ """
+ cdef:
+ unique_ptr[CResultStream] results
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+
+ if isinstance(action, (str, bytes)):
+ action = Action(action, b'')
+ elif isinstance(action, tuple):
+ action = Action(*action)
+ elif not isinstance(action, Action):
+ raise TypeError("Action must be Action instance, string, or tuple")
+
+ cdef CAction c_action = Action.unwrap(<Action> action)
+ with nogil:
+ check_flight_status(
+ self.client.get().DoAction(
+ deref(c_options), c_action, &results))
+
+ def _do_action_response():
+ cdef:
+ Result result
+ while True:
+ result = Result.__new__(Result)
+ with nogil:
+ check_flight_status(results.get().Next(&result.result))
+ if result.result == NULL:
+ break
+ yield result
+ return _do_action_response()
+
+ def list_flights(self, criteria: bytes = None,
+ options: FlightCallOptions = None):
+ """List the flights available on a service."""
+ cdef:
+ unique_ptr[CFlightListing] listing
+ FlightInfo result
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ CCriteria c_criteria
+
+ if criteria:
+ c_criteria.expression = tobytes(criteria)
+
+ with SignalStopHandler() as stop_handler:
+ c_options.stop_token = \
+ (<StopToken> stop_handler.stop_token).stop_token
+ with nogil:
+ check_flight_status(
+ self.client.get().ListFlights(deref(c_options),
+ c_criteria, &listing))
+
+ while True:
+ result = FlightInfo.__new__(FlightInfo)
+ with nogil:
+ check_flight_status(listing.get().Next(&result.info))
+ if result.info == NULL:
+ break
+ yield result
+
+ def get_flight_info(self, descriptor: FlightDescriptor,
+ options: FlightCallOptions = None):
+ """Request information about an available flight."""
+ cdef:
+ FlightInfo result = FlightInfo.__new__(FlightInfo)
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ CFlightDescriptor c_descriptor = \
+ FlightDescriptor.unwrap(descriptor)
+
+ with nogil:
+ check_flight_status(self.client.get().GetFlightInfo(
+ deref(c_options), c_descriptor, &result.info))
+
+ return result
+
+ def get_schema(self, descriptor: FlightDescriptor,
+ options: FlightCallOptions = None):
+ """Request schema for an available flight."""
+ cdef:
+ SchemaResult result = SchemaResult.__new__(SchemaResult)
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ CFlightDescriptor c_descriptor = \
+ FlightDescriptor.unwrap(descriptor)
+ with nogil:
+ check_status(
+ self.client.get()
+ .GetSchema(deref(c_options), c_descriptor, &result.result)
+ )
+
+ return result
+
+ def do_get(self, ticket: Ticket, options: FlightCallOptions = None):
+ """Request the data for a flight.
+
+ Returns
+ -------
+ reader : FlightStreamReader
+ """
+ cdef:
+ unique_ptr[CFlightStreamReader] reader
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+
+ with nogil:
+ check_flight_status(
+ self.client.get().DoGet(
+ deref(c_options), ticket.ticket, &reader))
+ result = FlightStreamReader()
+ result.reader.reset(reader.release())
+ return result
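+
+ # Usage sketch combining get_flight_info and do_get (illustrative only;
+ # the server location and dataset path are assumptions):
+ #
+ #   client = FlightClient("grpc://localhost:8815")
+ #   info = client.get_flight_info(FlightDescriptor.for_path("example.parquet"))
+ #   reader = client.do_get(info.endpoints[0].ticket)
+ #   table = reader.read_all()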
+
+ def do_put(self, descriptor: FlightDescriptor, schema: Schema,
+ options: FlightCallOptions = None):
+ """Upload data to a flight.
+
+ Returns
+ -------
+ writer : FlightStreamWriter
+ reader : FlightMetadataReader
+ """
+ cdef:
+ shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
+ unique_ptr[CFlightStreamWriter] writer
+ unique_ptr[CFlightMetadataReader] metadata_reader
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ CFlightDescriptor c_descriptor = \
+ FlightDescriptor.unwrap(descriptor)
+ FlightMetadataReader reader = FlightMetadataReader()
+
+ with nogil:
+ check_flight_status(self.client.get().DoPut(
+ deref(c_options),
+ c_descriptor,
+ c_schema,
+ &writer,
+ &reader.reader))
+ result = FlightStreamWriter()
+ result.writer.reset(writer.release())
+ return result, reader
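+
+ # Usage sketch for do_put (illustrative only; assumes `table` is a
+ # pyarrow.Table already built by the caller):
+ #
+ #   descriptor = FlightDescriptor.for_path("uploaded.parquet")
+ #   writer, metadata_reader = client.do_put(descriptor, table.schema)
+ #   writer.write_table(table)
+ #   writer.close()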
+
+ def do_exchange(self, descriptor: FlightDescriptor,
+ options: FlightCallOptions = None):
+ """Start a bidirectional data exchange with a server.
+
+ Parameters
+ ----------
+ descriptor : FlightDescriptor
+ A descriptor for the flight.
+ options : FlightCallOptions
+ RPC options.
+
+ Returns
+ -------
+ writer : FlightStreamWriter
+ reader : FlightStreamReader
+ """
+ cdef:
+ unique_ptr[CFlightStreamWriter] c_writer
+ unique_ptr[CFlightStreamReader] c_reader
+ CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
+ CFlightDescriptor c_descriptor = \
+ FlightDescriptor.unwrap(descriptor)
+
+ with nogil:
+ check_flight_status(self.client.get().DoExchange(
+ deref(c_options),
+ c_descriptor,
+ &c_writer,
+ &c_reader))
+ py_writer = FlightStreamWriter()
+ py_writer.writer.reset(c_writer.release())
+ py_reader = FlightStreamReader()
+ py_reader.reader.reset(c_reader.release())
+ return py_writer, py_reader
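+
+ # Usage sketch for do_exchange (illustrative only; assumes `batch` is a
+ # pyarrow.RecordBatch and the server echoes data back):
+ #
+ #   writer, reader = client.do_exchange(descriptor)
+ #   writer.begin(batch.schema)
+ #   writer.write_batch(batch)
+ #   writer.done_writing()
+ #   response = reader.read_chunk()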
+
+
+cdef class FlightDataStream(_Weakrefable):
+ """Abstract base class for Flight data streams."""
+
+ cdef CFlightDataStream* to_stream(self) except *:
+ """Create the C++ data stream for the backing Python object.
+
+ We don't expose the C++ object to Python, so we can manage its
+ lifetime from the Cython/C++ side.
+ """
+ raise NotImplementedError
+
+
+cdef class RecordBatchStream(FlightDataStream):
+ """A Flight data stream backed by RecordBatches."""
+ cdef:
+ object data_source
+ CIpcWriteOptions write_options
+
+ def __init__(self, data_source, options=None):
+ """Create a RecordBatchStream from a data source.
+
+ Parameters
+ ----------
+ data_source : RecordBatchReader or Table
+ options : pyarrow.ipc.IpcWriteOptions, optional
+ """
+ if (not isinstance(data_source, RecordBatchReader) and
+ not isinstance(data_source, lib.Table)):
+ raise TypeError("Expected RecordBatchReader or Table, "
+ "but got: {}".format(type(data_source)))
+ self.data_source = data_source
+ self.write_options = _get_options(options).c_options
+
+ cdef CFlightDataStream* to_stream(self) except *:
+ cdef:
+ shared_ptr[CRecordBatchReader] reader
+ if isinstance(self.data_source, RecordBatchReader):
+ reader = (<RecordBatchReader> self.data_source).reader
+ elif isinstance(self.data_source, lib.Table):
+ table = (<Table> self.data_source).table
+ reader.reset(new TableBatchReader(deref(table)))
+ else:
+ raise RuntimeError("Can't construct RecordBatchStream "
+ "from type {}".format(type(self.data_source)))
+ return new CRecordBatchStream(reader, self.write_options)
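+
+ # Sketch of a server-side do_get returning a RecordBatchStream (illustrative
+ # only; the pa.table(...) data below is an assumption):
+ #
+ #   class ExampleServer(FlightServerBase):
+ #       def do_get(self, context, ticket):
+ #           import pyarrow as pa
+ #           table = pa.table({"x": [1, 2, 3]})
+ #           return RecordBatchStream(table)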
+
+
+cdef class GeneratorStream(FlightDataStream):
+ """A Flight data stream backed by a Python generator."""
+ cdef:
+ shared_ptr[CSchema] schema
+ object generator
+ # A substream currently being consumed by the client, if
+ # present. Produced by the generator.
+ unique_ptr[CFlightDataStream] current_stream
+ CIpcWriteOptions c_options
+
+ def __init__(self, schema, generator, options=None):
+ """Create a GeneratorStream from a Python generator.
+
+ Parameters
+ ----------
+ schema : Schema
+ The schema for the data to be returned.
+
+ generator : iterator or iterable
+ The generator should yield other FlightDataStream objects,
+ Tables, RecordBatches, or RecordBatchReaders.
+
+ options : pyarrow.ipc.IpcWriteOptions, optional
+ """
+ self.schema = pyarrow_unwrap_schema(schema)
+ self.generator = iter(generator)
+ self.c_options = _get_options(options).c_options
+
+ cdef CFlightDataStream* to_stream(self) except *:
+ cdef:
+ function[cb_data_stream_next] callback = &_data_stream_next
+ return new CPyGeneratorFlightDataStream(self, self.schema, callback,
+ self.c_options)
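+
+ # Sketch of a do_get handler streaming batches lazily via GeneratorStream
+ # (illustrative only; `schema` and `batches` are assumed to be provided by
+ # the application):
+ #
+ #   def do_get(self, context, ticket):
+ #       return GeneratorStream(schema, (batch for batch in batches))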
+
+
+cdef class ServerCallContext(_Weakrefable):
+ """Per-call state/context."""
+ cdef:
+ const CServerCallContext* context
+
+ def peer_identity(self):
+ """Get the identity of the authenticated peer.
+
+ May be the empty string.
+ """
+ return tobytes(self.context.peer_identity())
+
+ def peer(self):
+ """Get the address of the peer."""
+ # Set safe=True as gRPC on Windows sometimes gives garbage bytes
+ return frombytes(self.context.peer(), safe=True)
+
+ def is_cancelled(self):
+ return self.context.is_cancelled()
+
+ def get_middleware(self, key):
+ """
+ Get a middleware instance by key.
+
+ Returns None if the middleware was not found.
+ """
+ cdef:
+ CServerMiddleware* c_middleware = \
+ self.context.GetMiddleware(CPyServerMiddlewareName)
+ CPyServerMiddleware* middleware
+ if c_middleware == NULL:
+ return None
+ if c_middleware.name() != CPyServerMiddlewareName:
+ return None
+ middleware = <CPyServerMiddleware*> c_middleware
+ py_middleware = <_ServerMiddlewareWrapper> middleware.py_object()
+ return py_middleware.middleware.get(key)
+
+ @staticmethod
+ cdef ServerCallContext wrap(const CServerCallContext& context):
+ cdef ServerCallContext result = \
+ ServerCallContext.__new__(ServerCallContext)
+ result.context = &context
+ return result
+
+
+cdef class ServerAuthReader(_Weakrefable):
+ """A reader for messages from the client during an auth handshake."""
+ cdef:
+ CServerAuthReader* reader
+
+ def read(self):
+ cdef c_string token
+ if not self.reader:
+ raise ValueError("Cannot use ServerAuthReader outside "
+ "ServerAuthHandler.authenticate")
+ with nogil:
+ check_flight_status(self.reader.Read(&token))
+ return token
+
+ cdef void poison(self):
+ """Prevent further usage of this object.
+
+ This object is constructed by taking a pointer to a reference,
+ so we want to make sure Python users do not access this after
+ the reference goes away.
+ """
+ self.reader = NULL
+
+ @staticmethod
+ cdef ServerAuthReader wrap(CServerAuthReader* reader):
+ cdef ServerAuthReader result = \
+ ServerAuthReader.__new__(ServerAuthReader)
+ result.reader = reader
+ return result
+
+
+cdef class ServerAuthSender(_Weakrefable):
+ """A writer for messages to the client during an auth handshake."""
+ cdef:
+ CServerAuthSender* sender
+
+ def write(self, message):
+ cdef c_string c_message = tobytes(message)
+ if not self.sender:
+ raise ValueError("Cannot use ServerAuthSender outside "
+ "ServerAuthHandler.authenticate")
+ with nogil:
+ check_flight_status(self.sender.Write(c_message))
+
+ cdef void poison(self):
+ """Prevent further usage of this object.
+
+ This object is constructed by taking a pointer to a reference,
+ so we want to make sure Python users do not access this after
+ the reference goes away.
+ """
+ self.sender = NULL
+
+ @staticmethod
+ cdef ServerAuthSender wrap(CServerAuthSender* sender):
+ cdef ServerAuthSender result = \
+ ServerAuthSender.__new__(ServerAuthSender)
+ result.sender = sender
+ return result
+
+
+cdef class ClientAuthReader(_Weakrefable):
+ """A reader for messages from the server during an auth handshake."""
+ cdef:
+ CClientAuthReader* reader
+
+ def read(self):
+ cdef c_string token
+ if not self.reader:
+ raise ValueError("Cannot use ClientAuthReader outside "
+ "ClientAuthHandler.authenticate")
+ with nogil:
+ check_flight_status(self.reader.Read(&token))
+ return token
+
+ cdef void poison(self):
+ """Prevent further usage of this object.
+
+ This object is constructed by taking a pointer to a reference,
+ so we want to make sure Python users do not access this after
+ the reference goes away.
+ """
+ self.reader = NULL
+
+ @staticmethod
+ cdef ClientAuthReader wrap(CClientAuthReader* reader):
+ cdef ClientAuthReader result = \
+ ClientAuthReader.__new__(ClientAuthReader)
+ result.reader = reader
+ return result
+
+
+cdef class ClientAuthSender(_Weakrefable):
+ """A writer for messages to the server during an auth handshake."""
+ cdef:
+ CClientAuthSender* sender
+
+ def write(self, message):
+ cdef c_string c_message = tobytes(message)
+ if not self.sender:
+ raise ValueError("Cannot use ClientAuthSender outside "
+ "ClientAuthHandler.authenticate")
+ with nogil:
+ check_flight_status(self.sender.Write(c_message))
+
+ cdef void poison(self):
+ """Prevent further usage of this object.
+
+ This object is constructed by taking a pointer to a reference,
+ so we want to make sure Python users do not access this after
+ the reference goes away.
+ """
+ self.sender = NULL
+
+ @staticmethod
+ cdef ClientAuthSender wrap(CClientAuthSender* sender):
+ cdef ClientAuthSender result = \
+ ClientAuthSender.__new__(ClientAuthSender)
+ result.sender = sender
+ return result
+
+
+cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *:
+ """Callback for implementing FlightDataStream in Python."""
+ cdef:
+ unique_ptr[CFlightDataStream] data_stream
+
+ py_stream = <object> self
+ if not isinstance(py_stream, GeneratorStream):
+ raise RuntimeError("self object in callback is not GeneratorStream")
+ stream = <GeneratorStream> py_stream
+
+ # The generator is allowed to yield a reader or table which we
+ # yield from; if that sub-generator is empty, we need to reset and
+ # try again. However, limit the number of attempts so that we
+ # don't just spin forever.
+ max_attempts = 128
+ for _ in range(max_attempts):
+ if stream.current_stream != nullptr:
+ check_flight_status(stream.current_stream.get().Next(payload))
+ # If the stream ended, see if there's another stream from the
+ # generator
+ if payload.ipc_message.metadata != nullptr:
+ return CStatus_OK()
+ stream.current_stream.reset(nullptr)
+
+ try:
+ result = next(stream.generator)
+ except StopIteration:
+ payload.ipc_message.metadata.reset(<CBuffer*> nullptr)
+ return CStatus_OK()
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+
+ if isinstance(result, (list, tuple)):
+ result, metadata = result
+ else:
+ result, metadata = result, None
+
+ if isinstance(result, (Table, RecordBatchReader)):
+ if metadata:
+ raise ValueError("Can only return metadata alongside a "
+ "RecordBatch.")
+ result = RecordBatchStream(result)
+
+ stream_schema = pyarrow_wrap_schema(stream.schema)
+ if isinstance(result, FlightDataStream):
+ if metadata:
+ raise ValueError("Can only return metadata alongside a "
+ "RecordBatch.")
+ data_stream = unique_ptr[CFlightDataStream](
+ (<FlightDataStream> result).to_stream())
+ substream_schema = pyarrow_wrap_schema(data_stream.get().schema())
+ if substream_schema != stream_schema:
+ raise ValueError("Got a FlightDataStream whose schema "
+ "does not match the declared schema of this "
+ "GeneratorStream. "
+ "Got: {}\nExpected: {}".format(
+ substream_schema, stream_schema))
+ stream.current_stream.reset(
+ new CPyFlightDataStream(result, move(data_stream)))
+ # Loop around and try again
+ continue
+ elif isinstance(result, RecordBatch):
+ batch = <RecordBatch> result
+ if batch.schema != stream_schema:
+ raise ValueError("Got a RecordBatch whose schema does not "
+ "match the declared schema of this "
+ "GeneratorStream. "
+ "Got: {}\nExpected: {}".format(batch.schema,
+ stream_schema))
+ check_flight_status(GetRecordBatchPayload(
+ deref(batch.batch),
+ stream.c_options,
+ &payload.ipc_message))
+ if metadata:
+ payload.app_metadata = pyarrow_unwrap_buffer(
+ as_buffer(metadata))
+ else:
+ raise TypeError("GeneratorStream must be initialized with "
+ "an iterator of FlightDataStream, Table, "
+ "RecordBatch, or RecordBatchReader objects, "
+ "not {}.".format(type(result)))
+ # Don't loop around
+ return CStatus_OK()
+ # Ran out of attempts (the RPC handler kept yielding empty tables/readers)
+ raise RuntimeError("While getting next payload, ran out of attempts to "
+ "get something to send "
+ "(application server implementation error)")
+
+
+cdef CStatus _list_flights(void* self, const CServerCallContext& context,
+ const CCriteria* c_criteria,
+ unique_ptr[CFlightListing]* listing) except *:
+ """Callback for implementing ListFlights in Python."""
+ cdef:
+ vector[CFlightInfo] flights
+
+ try:
+ result = (<object> self).list_flights(ServerCallContext.wrap(context),
+ c_criteria.expression)
+ for info in result:
+ if not isinstance(info, FlightInfo):
+ raise TypeError("FlightServerBase.list_flights must return "
+ "FlightInfo instances, but got {}".format(
+ type(info)))
+ flights.push_back(deref((<FlightInfo> info).info.get()))
+ listing.reset(new CSimpleFlightListing(flights))
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
+
+
+cdef CStatus _get_flight_info(void* self, const CServerCallContext& context,
+ CFlightDescriptor c_descriptor,
+ unique_ptr[CFlightInfo]* info) except *:
+ """Callback for implementing Flight servers in Python."""
+ cdef:
+ FlightDescriptor py_descriptor = \
+ FlightDescriptor.__new__(FlightDescriptor)
+ py_descriptor.descriptor = c_descriptor
+ try:
+ result = (<object> self).get_flight_info(
+ ServerCallContext.wrap(context),
+ py_descriptor)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ if not isinstance(result, FlightInfo):
+ raise TypeError("FlightServerBase.get_flight_info must return "
+ "a FlightInfo instance, but got {}".format(
+ type(result)))
+ info.reset(new CFlightInfo(deref((<FlightInfo> result).info.get())))
+ return CStatus_OK()
+
+cdef CStatus _get_schema(void* self, const CServerCallContext& context,
+ CFlightDescriptor c_descriptor,
+ unique_ptr[CSchemaResult]* info) except *:
+ """Callback for implementing Flight servers in Python."""
+ cdef:
+ FlightDescriptor py_descriptor = \
+ FlightDescriptor.__new__(FlightDescriptor)
+ py_descriptor.descriptor = c_descriptor
+ result = (<object> self).get_schema(ServerCallContext.wrap(context),
+ py_descriptor)
+ if not isinstance(result, SchemaResult):
+ raise TypeError("FlightServerBase.get_schema_info must return "
+ "a SchemaResult instance, but got {}".format(
+ type(result)))
+ info.reset(new CSchemaResult(deref((<SchemaResult> result).result.get())))
+ return CStatus_OK()
+
+cdef CStatus _do_put(void* self, const CServerCallContext& context,
+ unique_ptr[CFlightMessageReader] reader,
+ unique_ptr[CFlightMetadataWriter] writer) except *:
+ """Callback for implementing Flight servers in Python."""
+ cdef:
+ MetadataRecordBatchReader py_reader = MetadataRecordBatchReader()
+ FlightMetadataWriter py_writer = FlightMetadataWriter()
+ FlightDescriptor descriptor = \
+ FlightDescriptor.__new__(FlightDescriptor)
+
+ descriptor.descriptor = reader.get().descriptor()
+ py_reader.reader.reset(reader.release())
+ py_writer.writer.reset(writer.release())
+ try:
+ (<object> self).do_put(ServerCallContext.wrap(context), descriptor,
+ py_reader, py_writer)
+ return CStatus_OK()
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+
+
+cdef CStatus _do_get(void* self, const CServerCallContext& context,
+ CTicket ticket,
+ unique_ptr[CFlightDataStream]* stream) except *:
+ """Callback for implementing Flight servers in Python."""
+ cdef:
+ unique_ptr[CFlightDataStream] data_stream
+
+ py_ticket = Ticket(ticket.ticket)
+ try:
+ result = (<object> self).do_get(ServerCallContext.wrap(context),
+ py_ticket)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ if not isinstance(result, FlightDataStream):
+ raise TypeError("FlightServerBase.do_get must return "
+ "a FlightDataStream")
+ data_stream = unique_ptr[CFlightDataStream](
+ (<FlightDataStream> result).to_stream())
+ stream[0] = unique_ptr[CFlightDataStream](
+ new CPyFlightDataStream(result, move(data_stream)))
+ return CStatus_OK()
+
+
+cdef CStatus _do_exchange(void* self, const CServerCallContext& context,
+ unique_ptr[CFlightMessageReader] reader,
+ unique_ptr[CFlightMessageWriter] writer) except *:
+ """Callback for implementing Flight servers in Python."""
+ cdef:
+ MetadataRecordBatchReader py_reader = MetadataRecordBatchReader()
+ MetadataRecordBatchWriter py_writer = MetadataRecordBatchWriter()
+ FlightDescriptor descriptor = \
+ FlightDescriptor.__new__(FlightDescriptor)
+
+ descriptor.descriptor = reader.get().descriptor()
+ py_reader.reader.reset(reader.release())
+ py_writer.writer.reset(writer.release())
+ try:
+ (<object> self).do_exchange(ServerCallContext.wrap(context),
+ descriptor, py_reader, py_writer)
+ return CStatus_OK()
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+
+
+cdef CStatus _do_action_result_next(
+ void* self,
+ unique_ptr[CFlightResult]* result
+) except *:
+ """Callback for implementing Flight servers in Python."""
+ cdef:
+ CFlightResult* c_result
+
+ try:
+ action_result = next(<object> self)
+ if not isinstance(action_result, Result):
+ action_result = Result(action_result)
+ c_result = (<Result> action_result).result.get()
+ result.reset(new CFlightResult(deref(c_result)))
+ except StopIteration:
+ result.reset(nullptr)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
+
+
+cdef CStatus _do_action(void* self, const CServerCallContext& context,
+ const CAction& action,
+ unique_ptr[CResultStream]* result) except *:
+ """Callback for implementing Flight servers in Python."""
+ cdef:
+ function[cb_result_next] ptr = &_do_action_result_next
+ py_action = Action(action.type, pyarrow_wrap_buffer(action.body))
+ try:
+ responses = (<object> self).do_action(ServerCallContext.wrap(context),
+ py_action)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ # Let the application return an iterator or anything convertible
+ # into one
+ if responses is None:
+ # Server didn't return anything
+ responses = []
+ result.reset(new CPyFlightResultStream(iter(responses), ptr))
+ return CStatus_OK()
+
+
+cdef CStatus _list_actions(void* self, const CServerCallContext& context,
+ vector[CActionType]* actions) except *:
+ """Callback for implementing Flight servers in Python."""
+ cdef:
+ CActionType action_type
+ # Method should return a list of ActionTypes or similar tuple
+ try:
+ result = (<object> self).list_actions(ServerCallContext.wrap(context))
+ for action in result:
+ if not isinstance(action, tuple):
+ raise TypeError(
+ "Results of list_actions must be ActionType or tuple")
+ action_type.type = tobytes(action[0])
+ action_type.description = tobytes(action[1])
+ actions.push_back(action_type)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
+
+
+cdef CStatus _server_authenticate(void* self, CServerAuthSender* outgoing,
+ CServerAuthReader* incoming) except *:
+ """Callback for implementing authentication in Python."""
+ sender = ServerAuthSender.wrap(outgoing)
+ reader = ServerAuthReader.wrap(incoming)
+ try:
+ (<object> self).authenticate(sender, reader)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ finally:
+ sender.poison()
+ reader.poison()
+ return CStatus_OK()
+
+cdef CStatus _is_valid(void* self, const c_string& token,
+ c_string* peer_identity) except *:
+ """Callback for implementing authentication in Python."""
+ cdef c_string c_result
+ try:
+ c_result = tobytes((<object> self).is_valid(token))
+ peer_identity[0] = c_result
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
+
+
+cdef CStatus _client_authenticate(void* self, CClientAuthSender* outgoing,
+ CClientAuthReader* incoming) except *:
+ """Callback for implementing authentication in Python."""
+ sender = ClientAuthSender.wrap(outgoing)
+ reader = ClientAuthReader.wrap(incoming)
+ try:
+ (<object> self).authenticate(sender, reader)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ finally:
+ sender.poison()
+ reader.poison()
+ return CStatus_OK()
+
+
+cdef CStatus _get_token(void* self, c_string* token) except *:
+ """Callback for implementing authentication in Python."""
+ cdef c_string c_result
+ try:
+ c_result = tobytes((<object> self).get_token())
+ token[0] = c_result
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
+
+
+cdef CStatus _middleware_sending_headers(
+ void* self, CAddCallHeaders* add_headers) except *:
+ """Callback for implementing middleware."""
+ try:
+ headers = (<object> self).sending_headers()
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+
+ if headers:
+ for header, values in headers.items():
+ if isinstance(values, (str, bytes)):
+ values = (values,)
+ # Headers in gRPC (and HTTP/1, HTTP/2) are required to be
+ # valid ASCII.
+ if isinstance(header, str):
+ header = header.encode("ascii")
+ for value in values:
+ if isinstance(value, str):
+ value = value.encode("ascii")
+ # Allow bytes values to pass through.
+ add_headers.AddHeader(header, value)
+
+ return CStatus_OK()
+
+
+cdef CStatus _middleware_call_completed(
+ void* self,
+ const CStatus& call_status) except *:
+ """Callback for implementing middleware."""
+ try:
+ try:
+ check_flight_status(call_status)
+ except Exception as e:
+ (<object> self).call_completed(e)
+ else:
+ (<object> self).call_completed(None)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
+
+
+cdef CStatus _middleware_received_headers(
+ void* self,
+ const CCallHeaders& c_headers) except *:
+ """Callback for implementing middleware."""
+ try:
+ headers = convert_headers(c_headers)
+ (<object> self).received_headers(headers)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+ return CStatus_OK()
+
+
+cdef dict convert_headers(const CCallHeaders& c_headers):
+ cdef:
+ CCallHeaders.const_iterator header_iter = c_headers.cbegin()
+ headers = {}
+ while header_iter != c_headers.cend():
+ header = c_string(deref(header_iter).first).decode("ascii")
+ value = c_string(deref(header_iter).second)
+ if not header.endswith("-bin"):
+ # Text header values in gRPC (and HTTP/1, HTTP/2) are
+ # required to be valid ASCII. Binary header values are
+ # exposed as bytes.
+ value = value.decode("ascii")
+ headers.setdefault(header, []).append(value)
+ postincrement(header_iter)
+ return headers
+
+
+cdef CStatus _server_middleware_start_call(
+ void* self,
+ const CCallInfo& c_info,
+ const CCallHeaders& c_headers,
+ shared_ptr[CServerMiddleware]* c_instance) except *:
+ """Callback for implementing server middleware."""
+ instance = None
+ try:
+ call_info = wrap_call_info(c_info)
+ headers = convert_headers(c_headers)
+ instance = (<object> self).start_call(call_info, headers)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+
+ if instance:
+ ServerMiddleware.wrap(instance, c_instance)
+
+ return CStatus_OK()
+
+
+cdef CStatus _client_middleware_start_call(
+ void* self,
+ const CCallInfo& c_info,
+ unique_ptr[CClientMiddleware]* c_instance) except *:
+ """Callback for implementing client middleware."""
+ instance = None
+ try:
+ call_info = wrap_call_info(c_info)
+ instance = (<object> self).start_call(call_info)
+ except FlightError as flight_error:
+ return (<FlightError> flight_error).to_status()
+
+ if instance:
+ ClientMiddleware.wrap(instance, c_instance)
+
+ return CStatus_OK()
+
+
+cdef class ServerAuthHandler(_Weakrefable):
+ """Authentication middleware for a server.
+
+ To implement an authentication mechanism, subclass this class and
+ override its methods.
+
+ """
+
+ def authenticate(self, outgoing, incoming):
+ """Conduct the handshake with the client.
+
+ May raise an error if the client cannot authenticate.
+
+ Parameters
+ ----------
+ outgoing : ServerAuthSender
+ A channel to send messages to the client.
+ incoming : ServerAuthReader
+ A channel to read messages from the client.
+ """
+ raise NotImplementedError
+
+ def is_valid(self, token):
+ """Validate a client token, returning their identity.
+
+ May return an empty string (if the auth mechanism does not
+ name the peer) or raise an exception (if the token is
+ invalid).
+
+ Parameters
+ ----------
+ token : bytes
+ The authentication token from the client.
+
+ """
+ raise NotImplementedError
+
+ cdef PyServerAuthHandler* to_handler(self):
+ cdef PyServerAuthHandlerVtable vtable
+ vtable.authenticate = _server_authenticate
+ vtable.is_valid = _is_valid
+ return new PyServerAuthHandler(self, vtable)
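+
+ # Sketch of a ServerAuthHandler accepting a single shared token (illustrative
+ # only; the b"secret" value and identity string are assumptions):
+ #
+ #   class TokenServerAuthHandler(ServerAuthHandler):
+ #       def authenticate(self, outgoing, incoming):
+ #           token = incoming.read()
+ #           if token != b"secret":
+ #               raise FlightUnauthenticatedError("invalid token")
+ #           outgoing.write(token)
+ #
+ #       def is_valid(self, token):
+ #           if token != b"secret":
+ #               raise FlightUnauthenticatedError("invalid token")
+ #           return b"authenticated-user"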
+
+
+cdef class ClientAuthHandler(_Weakrefable):
+ """Authentication plugin for a client."""
+
+ def authenticate(self, outgoing, incoming):
+ """Conduct the handshake with the server.
+
+ Parameters
+ ----------
+ outgoing : ClientAuthSender
+ A channel to send messages to the server.
+ incoming : ClientAuthReader
+ A channel to read messages from the server.
+ """
+ raise NotImplementedError
+
+ def get_token(self):
+ """Get the auth token for a call."""
+ raise NotImplementedError
+
+ cdef PyClientAuthHandler* to_handler(self):
+ cdef PyClientAuthHandlerVtable vtable
+ vtable.authenticate = _client_authenticate
+ vtable.get_token = _get_token
+ return new PyClientAuthHandler(self, vtable)
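+
+ # Sketch of the matching client-side handler (illustrative only):
+ #
+ #   class TokenClientAuthHandler(ClientAuthHandler):
+ #       def __init__(self, token):
+ #           super().__init__()
+ #           self.token = token
+ #
+ #       def authenticate(self, outgoing, incoming):
+ #           outgoing.write(self.token)
+ #           self.token = incoming.read()
+ #
+ #       def get_token(self):
+ #           return self.token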
+
+
+_CallInfo = collections.namedtuple("_CallInfo", ["method"])
+
+
+class CallInfo(_CallInfo):
+ """Information about a particular RPC for Flight middleware."""
+
+
+cdef wrap_call_info(const CCallInfo& c_info):
+ method = wrap_flight_method(c_info.method)
+ return CallInfo(method=method)
+
+
+cdef class ClientMiddlewareFactory(_Weakrefable):
+ """A factory for new middleware instances.
+
+ All middleware methods will be called from the same thread as the
+ RPC method implementation. That is, thread-locals set in the
+ client are accessible from the middleware itself.
+
+ """
+
+ def start_call(self, info):
+ """Called at the start of an RPC.
+
+ This must be thread-safe and must not raise exceptions.
+
+ Parameters
+ ----------
+ info : CallInfo
+ Information about the call.
+
+ Returns
+ -------
+ instance : ClientMiddleware
+ An instance of ClientMiddleware (the instance to use for
+ the call), or None if this call is not intercepted.
+
+ """
+
+
+cdef class ClientMiddleware(_Weakrefable):
+ """Client-side middleware for a call, instantiated per RPC.
+
+ Methods here should be fast and must be infallible: they should
+ not raise exceptions or stall indefinitely.
+
+ """
+
+ def sending_headers(self):
+ """A callback before headers are sent.
+
+ Returns
+ -------
+ headers : dict
+ A dictionary of header values to add to the request, or
+ None if no headers are to be added. The dictionary should
+ have string keys and string or list-of-string values.
+
+ Bytes values are allowed, but the underlying transport may
+ not support them or may restrict them. For gRPC, binary
+ values are only allowed on headers ending in "-bin".
+
+ """
+
+ def received_headers(self, headers):
+ """A callback when headers are received.
+
+ The default implementation does nothing.
+
+ Parameters
+ ----------
+ headers : dict
+ A dictionary of headers from the server. Keys are strings
+ and values are lists of strings (for text headers) or
+ bytes (for binary headers).
+
+ """
+
+ def call_completed(self, exception):
+ """A callback when the call finishes.
+
+ The default implementation does nothing.
+
+ Parameters
+ ----------
+ exception : ArrowException
+ If the call errored, this is the equivalent
+ exception. Will be None if the call succeeded.
+
+ """
+
+ @staticmethod
+ cdef void wrap(object py_middleware,
+ unique_ptr[CClientMiddleware]* c_instance):
+ cdef PyClientMiddlewareVtable vtable
+ vtable.sending_headers = _middleware_sending_headers
+ vtable.received_headers = _middleware_received_headers
+ vtable.call_completed = _middleware_call_completed
+ c_instance[0].reset(new CPyClientMiddleware(py_middleware, vtable))
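+
+ # Sketch of client middleware attaching an authorization header to every
+ # call (illustrative only; the header value is an assumption):
+ #
+ #   class AuthHeaderFactory(ClientMiddlewareFactory):
+ #       def start_call(self, info):
+ #           return AuthHeaderMiddleware()
+ #
+ #   class AuthHeaderMiddleware(ClientMiddleware):
+ #       def sending_headers(self):
+ #           return {"authorization": "Bearer <token>"}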
+
+
+cdef class ServerMiddlewareFactory(_Weakrefable):
+ """A factory for new middleware instances.
+
+ All middleware methods will be called from the same thread as the
+ RPC method implementation. That is, thread-locals set in the
+ middleware are accessible from the method itself.
+
+ """
+
+ def start_call(self, info, headers):
+ """Called at the start of an RPC.
+
+ This must be thread-safe.
+
+ Parameters
+ ----------
+ info : CallInfo
+ Information about the call.
+ headers : dict
+ A dictionary of headers from the client. Keys are strings
+ and values are lists of strings (for text headers) or
+ bytes (for binary headers).
+
+ Returns
+ -------
+ instance : ServerMiddleware
+ An instance of ServerMiddleware (the instance to use for
+ the call), or None if this call is not intercepted.
+
+ Raises
+ ------
+ exception : pyarrow.ArrowException
+ If an exception is raised, the call will be rejected with
+ the given error.
+
+ """
+
+
+cdef class ServerMiddleware(_Weakrefable):
+ """Server-side middleware for a call, instantiated per RPC.
+
+ Methods here should be fast and must be infallible: they should
+ not raise exceptions or stall indefinitely.
+
+ """
+
+ def sending_headers(self):
+ """A callback before headers are sent.
+
+ Returns
+ -------
+ headers : dict
+ A dictionary of header values to add to the response, or
+ None if no headers are to be added. The dictionary should
+ have string keys and string or list-of-string values.
+
+ Bytes values are allowed, but the underlying transport may
+ not support them or may restrict them. For gRPC, binary
+ values are only allowed on headers ending in "-bin".
+
+ """
+
+ def call_completed(self, exception):
+ """A callback when the call finishes.
+
+ Parameters
+ ----------
+ exception : pyarrow.ArrowException
+ If the call errored, this is the equivalent
+ exception. Will be None if the call succeeded.
+
+ """
+
+ @staticmethod
+ cdef void wrap(object py_middleware,
+ shared_ptr[CServerMiddleware]* c_instance):
+ cdef PyServerMiddlewareVtable vtable
+ vtable.sending_headers = _middleware_sending_headers
+ vtable.call_completed = _middleware_call_completed
+ c_instance[0].reset(new CPyServerMiddleware(py_middleware, vtable))
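+
+ # Sketch of server middleware rejecting calls that lack an authorization
+ # header (illustrative only):
+ #
+ #   class CheckAuthFactory(ServerMiddlewareFactory):
+ #       def start_call(self, info, headers):
+ #           if not headers.get("authorization"):
+ #               raise FlightUnauthenticatedError("no authorization header")
+ #           return None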
+
+
+cdef class _ServerMiddlewareFactoryWrapper(ServerMiddlewareFactory):
+ """Wrapper to bundle server middleware into a single C++ one."""
+
+ cdef:
+ dict factories
+
+ def __init__(self, dict factories):
+ self.factories = factories
+
+ def start_call(self, info, headers):
+ instances = {}
+ for key, factory in self.factories.items():
+ instance = factory.start_call(info, headers)
+ if instance:
+ # TODO: prevent duplicate keys
+ instances[key] = instance
+ if instances:
+ wrapper = _ServerMiddlewareWrapper(instances)
+ return wrapper
+ return None
+
+
+cdef class _ServerMiddlewareWrapper(ServerMiddleware):
+ cdef:
+ dict middleware
+
+ def __init__(self, dict middleware):
+ self.middleware = middleware
+
+ def sending_headers(self):
+ headers = collections.defaultdict(list)
+ for instance in self.middleware.values():
+ more_headers = instance.sending_headers()
+ if not more_headers:
+ continue
+ # Manually merge with existing headers (since headers are
+ # multi-valued)
+ for key, values in more_headers.items():
+ if isinstance(values, (bytes, str)):
+ values = (values,)
+ headers[key].extend(values)
+ return headers
+
+ def call_completed(self, exception):
+ for instance in self.middleware.values():
+ instance.call_completed(exception)
+
+
+cdef class FlightServerBase(_Weakrefable):
+ """A Flight service definition.
+
+ Override methods to define your Flight service.
+
+ Parameters
+ ----------
+ location : str, tuple or Location optional, default None
+ Location to serve on. Either a gRPC URI like `grpc://localhost:port`,
+ a tuple of (host, port) pair, or a Location instance.
+ If None is passed then the server will be started on localhost with a
+ system provided random port.
+ auth_handler : ServerAuthHandler optional, default None
+ An authentication mechanism to use. May be None.
+ tls_certificates : list optional, default None
+ A list of (certificate, key) pairs.
+ verify_client : boolean optional, default False
+ If True, then enable mutual TLS: require the client to present
+ a client certificate, and validate the certificate.
+ root_certificates : bytes optional, default None
+ If enabling mutual TLS, this specifies the PEM-encoded root
+ certificate used to validate client certificates.
+ middleware : dict optional, default None
+ A dictionary of :class:`ServerMiddlewareFactory` items. The
+ keys are used to retrieve the middleware instance during calls
+ (see :meth:`ServerCallContext.get_middleware`).
+
+ """
+
+ cdef:
+ unique_ptr[PyFlightServer] server
+
+ def __init__(self, location=None, auth_handler=None,
+ tls_certificates=None, verify_client=None,
+ root_certificates=None, middleware=None):
+ if isinstance(location, (bytes, str)):
+ location = Location(location)
+ elif isinstance(location, (tuple, type(None))):
+ if location is None:
+ location = ('localhost', 0)
+ host, port = location
+ if tls_certificates:
+ location = Location.for_grpc_tls(host, port)
+ else:
+ location = Location.for_grpc_tcp(host, port)
+ elif not isinstance(location, Location):
+ raise TypeError('`location` argument must be a string, tuple or a '
+ 'Location instance')
+ self.init(location, auth_handler, tls_certificates, verify_client,
+ tobytes(root_certificates or b""), middleware)
+
+ cdef init(self, Location location, ServerAuthHandler auth_handler,
+ list tls_certificates, c_bool verify_client,
+ bytes root_certificates, dict middleware):
+ cdef:
+ PyFlightServerVtable vtable = PyFlightServerVtable()
+ PyFlightServer* c_server
+ unique_ptr[CFlightServerOptions] c_options
+ CCertKeyPair c_cert
+ function[cb_server_middleware_start_call] start_call = \
+ &_server_middleware_start_call
+ pair[c_string, shared_ptr[CServerMiddlewareFactory]] c_middleware
+
+ c_options.reset(new CFlightServerOptions(Location.unwrap(location)))
+ # mTLS configuration
+ c_options.get().verify_client = verify_client
+ c_options.get().root_certificates = root_certificates
+
+ if auth_handler:
+ if not isinstance(auth_handler, ServerAuthHandler):
+ raise TypeError("auth_handler must be a ServerAuthHandler, "
+ "not a '{}'".format(type(auth_handler)))
+ c_options.get().auth_handler.reset(
+ (<ServerAuthHandler> auth_handler).to_handler())
+
+ if tls_certificates:
+ for cert, key in tls_certificates:
+ c_cert.pem_cert = tobytes(cert)
+ c_cert.pem_key = tobytes(key)
+ c_options.get().tls_certificates.push_back(c_cert)
+
+ if middleware:
+ py_middleware = _ServerMiddlewareFactoryWrapper(middleware)
+ c_middleware.first = CPyServerMiddlewareName
+ c_middleware.second.reset(new CPyServerMiddlewareFactory(
+ py_middleware,
+ start_call))
+ c_options.get().middleware.push_back(c_middleware)
+
+ vtable.list_flights = &_list_flights
+ vtable.get_flight_info = &_get_flight_info
+ vtable.get_schema = &_get_schema
+ vtable.do_put = &_do_put
+ vtable.do_get = &_do_get
+ vtable.do_exchange = &_do_exchange
+ vtable.list_actions = &_list_actions
+ vtable.do_action = &_do_action
+
+ c_server = new PyFlightServer(self, vtable)
+ self.server.reset(c_server)
+ with nogil:
+ check_flight_status(c_server.Init(deref(c_options)))
+
+ @property
+ def port(self):
+ """
+ Get the port that this server is listening on.
+
+ Returns a non-positive value if the operation is invalid
+ (e.g. init() was not called or the server is listening on a domain
+ socket).
+ """
+ return self.server.get().port()
+
+ def list_flights(self, context, criteria):
+ raise NotImplementedError
+
+ def get_flight_info(self, context, descriptor):
+ raise NotImplementedError
+
+ def get_schema(self, context, descriptor):
+ raise NotImplementedError
+
+ def do_put(self, context, descriptor, reader,
+ writer: FlightMetadataWriter):
+ raise NotImplementedError
+
+ def do_get(self, context, ticket):
+ raise NotImplementedError
+
+ def do_exchange(self, context, descriptor, reader, writer):
+ raise NotImplementedError
+
+ def list_actions(self, context):
+ raise NotImplementedError
+
+ def do_action(self, context, action):
+ raise NotImplementedError
+
+ def serve(self):
+ """Start serving.
+
+ This method only returns if shutdown() is called or a signal is
+ received.
+ """
+ if self.server.get() == nullptr:
+ raise ValueError("serve() on uninitialized FlightServerBase")
+ with nogil:
+ check_flight_status(self.server.get().ServeWithSignals())
+
+ def run(self):
+ warnings.warn("The 'FlightServer.run' method is deprecated, use "
+ "the 'FlightServer.serve' method instead")
+ self.serve()
+
+ def shutdown(self):
+ """Shut down the server, blocking until current requests finish.
+
+ Do not call this directly from the implementation of a Flight
+ method, as then the server will block forever waiting for that
+ request to finish. Instead, call this method from a background
+ thread.
+ """
+ # Must not hold the GIL: shutdown waits for pending RPCs to
+ # complete. Holding the GIL means Python-implemented Flight
+ # methods will never get to run, so this will hang
+ # indefinitely.
+ if self.server.get() == nullptr:
+ raise ValueError("shutdown() on uninitialized FlightServerBase")
+ with nogil:
+ check_flight_status(self.server.get().Shutdown())
+
+ def wait(self):
+ """Block until server is terminated with shutdown."""
+ with nogil:
+ self.server.get().Wait()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.shutdown()
+ self.wait()
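+
+ # Minimal FlightServerBase usage sketch (illustrative only; the port and the
+ # returned data are assumptions):
+ #
+ #   class EchoServer(FlightServerBase):
+ #       def do_get(self, context, ticket):
+ #           import pyarrow as pa
+ #           return RecordBatchStream(pa.table({"x": [1, 2, 3]}))
+ #
+ #   server = EchoServer("grpc://0.0.0.0:8815")
+ #   server.serve()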
+
+
+def connect(location, **kwargs):
+ """
+ Connect to the Flight server.
+
+ Parameters
+ ----------
+ location : str, tuple or Location
+ Location to connect to. Either a gRPC URI like `grpc://localhost:port`,
+ a tuple of (host, port) pair, or a Location instance.
+ tls_root_certs : bytes or None
+ PEM-encoded root certificates used to validate the server certificate.
+ cert_chain : str or None
+ Client certificate; if provided, enables TLS mutual authentication.
+ private_key : str or None
+ Client private key; if provided, enables TLS mutual authentication.
+ override_hostname : str or None
+ Override the hostname checked by TLS. Insecure, use with caution.
+ middleware : list or None
+ A list of ClientMiddlewareFactory instances to apply.
+ write_size_limit_bytes : int or None
+ A soft limit on the size of a data payload sent to the
+ server. Enabled if positive. If enabled, writing a record
+ batch that (when serialized) exceeds this limit will raise an
+ exception; the client can retry the write with a smaller
+ batch.
+ disable_server_verification : boolean or None
+ Disable verifying the server when using TLS.
+ Insecure, use with caution.
+ generic_options : list or None
+ A list of generic (string, int or string) options to pass to
+ the underlying transport.
+
+ Returns
+ -------
+ client : FlightClient
+ """
+ return FlightClient(location, **kwargs)
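+
+ # Usage sketch for connect (illustrative only; the location and certificate
+ # file are assumptions):
+ #
+ #   client = connect("grpc+tls://localhost:8815",
+ #                    tls_root_certs=open("ca.pem", "rb").read())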
diff --git a/src/arrow/python/pyarrow/_fs.pxd b/src/arrow/python/pyarrow/_fs.pxd
new file mode 100644
index 000000000..4504b78b8
--- /dev/null
+++ b/src/arrow/python/pyarrow/_fs.pxd
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow_fs cimport *
+from pyarrow.lib import _detect_compression, frombytes, tobytes
+from pyarrow.lib cimport *
+
+
+cpdef enum FileType:
+ NotFound = <int8_t> CFileType_NotFound
+ Unknown = <int8_t> CFileType_Unknown
+ File = <int8_t> CFileType_File
+ Directory = <int8_t> CFileType_Directory
+
+
+cdef class FileInfo(_Weakrefable):
+ cdef:
+ CFileInfo info
+
+ @staticmethod
+ cdef wrap(CFileInfo info)
+
+ cdef inline CFileInfo unwrap(self) nogil
+
+ @staticmethod
+ cdef CFileInfo unwrap_safe(obj)
+
+
+cdef class FileSelector(_Weakrefable):
+ cdef:
+ CFileSelector selector
+
+ @staticmethod
+ cdef FileSelector wrap(CFileSelector selector)
+
+ cdef inline CFileSelector unwrap(self) nogil
+
+
+cdef class FileSystem(_Weakrefable):
+ cdef:
+ shared_ptr[CFileSystem] wrapped
+ CFileSystem* fs
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped)
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CFileSystem]& sp)
+
+ cdef inline shared_ptr[CFileSystem] unwrap(self) nogil
+
+
+cdef class LocalFileSystem(FileSystem):
+ cdef:
+ CLocalFileSystem* localfs
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped)
+
+
+cdef class SubTreeFileSystem(FileSystem):
+ cdef:
+ CSubTreeFileSystem* subtreefs
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped)
+
+
+cdef class _MockFileSystem(FileSystem):
+ cdef:
+ CMockFileSystem* mockfs
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped)
+
+
+cdef class PyFileSystem(FileSystem):
+ cdef:
+ CPyFileSystem* pyfs
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped)
diff --git a/src/arrow/python/pyarrow/_fs.pyx b/src/arrow/python/pyarrow/_fs.pyx
new file mode 100644
index 000000000..34a0ef55b
--- /dev/null
+++ b/src/arrow/python/pyarrow/_fs.pyx
@@ -0,0 +1,1233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+from cpython.datetime cimport datetime, PyDateTime_DateTime
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport PyDateTime_to_TimePoint
+from pyarrow.lib import _detect_compression, frombytes, tobytes
+from pyarrow.lib cimport *
+from pyarrow.util import _stringify_path
+
+from abc import ABC, abstractmethod
+from datetime import datetime, timezone
+import os
+import pathlib
+import sys
+import warnings
+
+
+cdef _init_ca_paths():
+ cdef CFileSystemGlobalOptions options
+
+ import ssl
+ paths = ssl.get_default_verify_paths()
+ if paths.cafile:
+ options.tls_ca_file_path = os.fsencode(paths.cafile)
+ if paths.capath:
+ options.tls_ca_dir_path = os.fsencode(paths.capath)
+ check_status(CFileSystemsInitialize(options))
+
+
+if sys.platform == 'linux':
+ # ARROW-9261: On Linux, we may need to fixup the paths to TLS CA certs
+ # (especially in manylinux packages) since the values hardcoded at
+ # compile-time in libcurl may be wrong.
+ _init_ca_paths()
+
+
+cdef inline c_string _path_as_bytes(path) except *:
+ # handle only abstract paths, not bound to any filesystem like pathlib is,
+ # so we only accept plain strings
+ if not isinstance(path, (bytes, str)):
+ raise TypeError('Path must be a string')
+ # tobytes always uses utf-8, which is more or less ok, at least on Windows
+ # since the C++ side then decodes from utf-8. On Unix, os.fsencode may be
+ # better.
+ return tobytes(path)
+
+
+cdef object _wrap_file_type(CFileType ty):
+ return FileType(<int8_t> ty)
+
+
+cdef CFileType _unwrap_file_type(FileType ty) except *:
+ if ty == FileType.Unknown:
+ return CFileType_Unknown
+ elif ty == FileType.NotFound:
+ return CFileType_NotFound
+ elif ty == FileType.File:
+ return CFileType_File
+ elif ty == FileType.Directory:
+ return CFileType_Directory
+ assert 0
+
+
+cdef class FileInfo(_Weakrefable):
+ """
+ FileSystem entry info.
+
+ Parameters
+ ----------
+ path : str
+ The full path to the filesystem entry.
+ type : FileType
+ The type of the filesystem entry.
+ mtime : datetime or float, default None
+ If given, the modification time of the filesystem entry.
+ If a float is given, it is the number of seconds since the
+ Unix epoch.
+ mtime_ns : int, default None
+ If given, the modification time of the filesystem entry,
+ in nanoseconds since the Unix epoch.
+ `mtime` and `mtime_ns` are mutually exclusive.
+ size : int, default None
+ If given, the filesystem entry size in bytes. This should only
+ be given if `type` is `FileType.File`.
+
+ """
+
+ def __init__(self, path, FileType type=FileType.Unknown, *,
+ mtime=None, mtime_ns=None, size=None):
+ self.info.set_path(tobytes(path))
+ self.info.set_type(_unwrap_file_type(type))
+ if mtime is not None:
+ if mtime_ns is not None:
+ raise TypeError("Only one of mtime and mtime_ns "
+ "can be given")
+ if isinstance(mtime, datetime):
+ self.info.set_mtime(PyDateTime_to_TimePoint(
+ <PyDateTime_DateTime*> mtime))
+ else:
+ self.info.set_mtime(TimePoint_from_s(mtime))
+ elif mtime_ns is not None:
+ self.info.set_mtime(TimePoint_from_ns(mtime_ns))
+ if size is not None:
+ self.info.set_size(size)
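+
+ # Construction sketch (illustrative only; the path and size are assumptions):
+ #
+ #   info = FileInfo("data/file.parquet", FileType.File,
+ #                   mtime=datetime.now(timezone.utc), size=1024)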
+
+ @staticmethod
+ cdef wrap(CFileInfo info):
+ cdef FileInfo self = FileInfo.__new__(FileInfo)
+ self.info = move(info)
+ return self
+
+ cdef inline CFileInfo unwrap(self) nogil:
+ return self.info
+
+ @staticmethod
+ cdef CFileInfo unwrap_safe(obj):
+ if not isinstance(obj, FileInfo):
+ raise TypeError("Expected FileInfo instance, got {0}"
+ .format(type(obj)))
+ return (<FileInfo> obj).unwrap()
+
+ def __repr__(self):
+ def getvalue(attr):
+ try:
+ return getattr(self, attr)
+ except ValueError:
+ return ''
+
+ s = '<FileInfo for {!r}: type={}'.format(self.path, str(self.type))
+ if self.is_file:
+ s += ', size={}'.format(self.size)
+ s += '>'
+ return s
+
+ @property
+ def type(self):
+ """
+ Type of the file.
+
+ The returned enum values can be the following:
+
+ - FileType.NotFound: target does not exist
+ - FileType.Unknown: target exists but its type is unknown (could be a
+ special file such as a Unix socket or character device, or
+ Windows NUL / CON / ...)
+ - FileType.File: target is a regular file
+ - FileType.Directory: target is a regular directory
+
+ Returns
+ -------
+ type : FileType
+ """
+ return _wrap_file_type(self.info.type())
+
+ @property
+ def is_file(self):
+ """
+ Whether the entry is a regular file.
+ """
+ return self.type == FileType.File
+
+ @property
+ def path(self):
+ """
+ The full file path in the filesystem.
+ """
+ return frombytes(self.info.path())
+
+ @property
+ def base_name(self):
+ """
+ The file base name.
+
+ Component after the last directory separator.
+ """
+ return frombytes(self.info.base_name())
+
+ @property
+ def size(self):
+ """
+ The size in bytes, if available.
+
+ Only regular files are guaranteed to have a size.
+
+ Returns
+ -------
+ size : int or None
+ """
+ cdef int64_t size
+ size = self.info.size()
+ return (size if size != -1 else None)
+
+ @property
+ def extension(self):
+ """
+ The file extension.
+ """
+ return frombytes(self.info.extension())
+
+ @property
+ def mtime(self):
+ """
+ The time of last modification, if available.
+
+ Returns
+ -------
+ mtime : datetime.datetime or None
+ """
+ cdef int64_t nanoseconds
+ nanoseconds = TimePoint_to_ns(self.info.mtime())
+ return (datetime.fromtimestamp(nanoseconds / 1.0e9, timezone.utc)
+ if nanoseconds != -1 else None)
+
+ @property
+ def mtime_ns(self):
+ """
+ The time of last modification, if available, expressed in nanoseconds
+ since the Unix epoch.
+
+ Returns
+ -------
+ mtime_ns : int or None
+ """
+ cdef int64_t nanoseconds
+ nanoseconds = TimePoint_to_ns(self.info.mtime())
+ return (nanoseconds if nanoseconds != -1 else None)
+
+
+cdef class FileSelector(_Weakrefable):
+ """
+ File and directory selector.
+
+ It contains a set of options that describes how to search for files and
+ directories.
+
+ Parameters
+ ----------
+ base_dir : str
+ The directory in which to select files. Relative paths also work;
+ use '.' for the current directory and '..' for the parent.
+ allow_not_found : bool, default False
+ The behavior if `base_dir` doesn't exist in the filesystem.
+ If false, an error is returned.
+ If true, an empty selection is returned.
+ recursive : bool, default False
+ Whether to recurse into subdirectories.
+ """
+
+ def __init__(self, base_dir, bint allow_not_found=False,
+ bint recursive=False):
+ self.base_dir = base_dir
+ self.recursive = recursive
+ self.allow_not_found = allow_not_found
+
+ @staticmethod
+ cdef FileSelector wrap(CFileSelector wrapped):
+ cdef FileSelector self = FileSelector.__new__(FileSelector)
+ self.selector = move(wrapped)
+ return self
+
+ cdef inline CFileSelector unwrap(self) nogil:
+ return self.selector
+
+ @property
+ def base_dir(self):
+ return frombytes(self.selector.base_dir)
+
+ @base_dir.setter
+ def base_dir(self, base_dir):
+ self.selector.base_dir = _path_as_bytes(base_dir)
+
+ @property
+ def allow_not_found(self):
+ return self.selector.allow_not_found
+
+ @allow_not_found.setter
+ def allow_not_found(self, bint allow_not_found):
+ self.selector.allow_not_found = allow_not_found
+
+ @property
+ def recursive(self):
+ return self.selector.recursive
+
+ @recursive.setter
+ def recursive(self, bint recursive):
+ self.selector.recursive = recursive
+
+ def __repr__(self):
+ return ("<FileSelector base_dir={0.base_dir!r} "
+ "recursive={0.recursive}>".format(self))
+
+
+cdef class FileSystem(_Weakrefable):
+ """
+ Abstract file system API.
+ """
+
+ def __init__(self):
+ raise TypeError("FileSystem is an abstract class, instantiate one of "
+ "the subclasses instead: LocalFileSystem or "
+ "SubTreeFileSystem")
+
+ @staticmethod
+ def from_uri(uri):
+ """
+ Create a new FileSystem from URI or Path.
+
+ Recognized URI schemes are "file", "mock", "s3", "hdfs" and "viewfs".
+ In addition, the argument can be a pathlib.Path object, or a string
+ describing an absolute local path.
+
+ Parameters
+ ----------
+ uri : string
+ URI-based path, for example: file:///some/local/path.
+
+ Returns
+ -------
+ A (filesystem, path) tuple, where `path` is the abstract path
+ inside the FileSystem instance.
+ """
+ cdef:
+ c_string path
+ CResult[shared_ptr[CFileSystem]] result
+
+ if isinstance(uri, pathlib.Path):
+ # Make absolute
+ uri = uri.resolve().absolute()
+ uri = _stringify_path(uri)
+ result = CFileSystemFromUriOrPath(tobytes(uri), &path)
+ return FileSystem.wrap(GetResultValue(result)), frombytes(path)
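+
+ # Usage sketch (illustrative only; the URI is an assumption):
+ #
+ #   fs, path = FileSystem.from_uri("file:///tmp/data")
+ #   # fs is a LocalFileSystem and path == "/tmp/data"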
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped):
+ self.wrapped = wrapped
+ self.fs = wrapped.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CFileSystem]& sp):
+ cdef FileSystem self
+
+ typ = frombytes(sp.get().type_name())
+ if typ == 'local':
+ self = LocalFileSystem.__new__(LocalFileSystem)
+ elif typ == 'mock':
+ self = _MockFileSystem.__new__(_MockFileSystem)
+ elif typ == 'subtree':
+ self = SubTreeFileSystem.__new__(SubTreeFileSystem)
+ elif typ == 's3':
+ from pyarrow._s3fs import S3FileSystem
+ self = S3FileSystem.__new__(S3FileSystem)
+ elif typ == 'hdfs':
+ from pyarrow._hdfs import HadoopFileSystem
+ self = HadoopFileSystem.__new__(HadoopFileSystem)
+ elif typ.startswith('py::'):
+ self = PyFileSystem.__new__(PyFileSystem)
+ else:
+ raise TypeError('Cannot wrap FileSystem pointer')
+
+ self.init(sp)
+ return self
+
+ cdef inline shared_ptr[CFileSystem] unwrap(self) nogil:
+ return self.wrapped
+
+ def equals(self, FileSystem other):
+ return self.fs.Equals(other.unwrap())
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ @property
+ def type_name(self):
+ """
+ The filesystem's type name.
+ """
+ return frombytes(self.fs.type_name())
+
+ def get_file_info(self, paths_or_selector):
+ """
+ Get info for the given files.
+
+ Any symlink is automatically dereferenced, recursively. A non-existing
+ or unreachable file returns a FileInfo object and has a FileType of
+ value NotFound. An exception indicates a truly exceptional condition
+ (low-level I/O error, etc.).
+
+ Parameters
+ ----------
+ paths_or_selector : FileSelector, path-like or list of path-likes
+ Either a selector object, a path-like object or a list of
+ path-like objects. The selector's base directory will not be
+ part of the results, even if it exists. If it doesn't exist,
+ use `allow_not_found`.
+
+ Returns
+ -------
+ FileInfo or list of FileInfo
+ Single FileInfo object is returned for a single path, otherwise
+ a list of FileInfo objects is returned.
+ """
+ cdef:
+ CFileInfo info
+ c_string path
+ vector[CFileInfo] infos
+ vector[c_string] paths
+ CFileSelector selector
+
+ if isinstance(paths_or_selector, FileSelector):
+ with nogil:
+ selector = (<FileSelector>paths_or_selector).selector
+ infos = GetResultValue(self.fs.GetFileInfo(selector))
+ elif isinstance(paths_or_selector, (list, tuple)):
+ paths = [_path_as_bytes(s) for s in paths_or_selector]
+ with nogil:
+ infos = GetResultValue(self.fs.GetFileInfo(paths))
+ elif isinstance(paths_or_selector, (bytes, str)):
+ path = _path_as_bytes(paths_or_selector)
+ with nogil:
+ info = GetResultValue(self.fs.GetFileInfo(path))
+ return FileInfo.wrap(info)
+ else:
+ raise TypeError('Must pass either path(s) or a FileSelector')
+
+ return [FileInfo.wrap(info) for info in infos]
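+
+ # Usage sketch (illustrative only; the paths are assumptions):
+ #
+ #   fs = LocalFileSystem()
+ #   info = fs.get_file_info("/tmp/data.csv")
+ #   infos = fs.get_file_info(FileSelector("/tmp", recursive=True))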
+
+ def create_dir(self, path, *, bint recursive=True):
+ """
+ Create a directory and subdirectories.
+
+ This function succeeds if the directory already exists.
+
+ Parameters
+ ----------
+ path : str
+ The path of the new directory.
+ recursive : bool, default True
+ Create nested directories as well.
+ """
+ cdef c_string directory = _path_as_bytes(path)
+ with nogil:
+ check_status(self.fs.CreateDir(directory, recursive=recursive))
+
+ def delete_dir(self, path):
+ """Delete a directory and its contents, recursively.
+
+ Parameters
+ ----------
+ path : str
+ The path of the directory to be deleted.
+ """
+ cdef c_string directory = _path_as_bytes(path)
+ with nogil:
+ check_status(self.fs.DeleteDir(directory))
+
+ def delete_dir_contents(self, path, *, bint accept_root_dir=False):
+ """Delete a directory's contents, recursively.
+
+ Like delete_dir, but doesn't delete the directory itself.
+
+ Parameters
+ ----------
+ path : str
+ The path of the directory to be deleted.
+ accept_root_dir : boolean, default False
+ Allow deleting the root directory's contents
+ (if path is empty or "/")
+ """
+ cdef c_string directory = _path_as_bytes(path)
+ if accept_root_dir and directory.strip(b"/") == b"":
+ with nogil:
+ check_status(self.fs.DeleteRootDirContents())
+ else:
+ with nogil:
+ check_status(self.fs.DeleteDirContents(directory))
+
+ def move(self, src, dest):
+ """
+ Move / rename a file or directory.
+
+ If the destination exists:
+ - if it is a non-empty directory, an error is returned
+ - otherwise, if it has the same type as the source, it is replaced
+ - otherwise, behavior is unspecified (implementation-dependent).
+
+ Parameters
+ ----------
+ src : str
+ The path of the file or the directory to be moved.
+ dest : str
+ The destination path where the file or directory is moved to.
+ """
+ cdef:
+ c_string source = _path_as_bytes(src)
+ c_string destination = _path_as_bytes(dest)
+ with nogil:
+ check_status(self.fs.Move(source, destination))
+
+ def copy_file(self, src, dest):
+ """
+ Copy a file.
+
+ If the destination exists and is a directory, an error is returned.
+ Otherwise, it is replaced.
+
+ Parameters
+ ----------
+ src : str
+ The path of the file to be copied from.
+ dest : str
+ The destination path where the file is copied to.
+ """
+ cdef:
+ c_string source = _path_as_bytes(src)
+ c_string destination = _path_as_bytes(dest)
+ with nogil:
+ check_status(self.fs.CopyFile(source, destination))
+
+ def delete_file(self, path):
+ """
+ Delete a file.
+
+ Parameters
+ ----------
+ path : str
+ The path of the file to be deleted.
+ """
+ cdef c_string file = _path_as_bytes(path)
+ with nogil:
+ check_status(self.fs.DeleteFile(file))
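+
+ # A sketch chaining the file helpers above (assumes `fs` is a
+ # LocalFileSystem and that /tmp/src.txt exists):
+ #
+ # >>> fs.copy_file("/tmp/src.txt", "/tmp/dst.txt")
+ # >>> fs.move("/tmp/dst.txt", "/tmp/renamed.txt")
+ # >>> fs.delete_file("/tmp/renamed.txt")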
+
+ def _wrap_input_stream(self, stream, path, compression, buffer_size):
+ if buffer_size is not None and buffer_size != 0:
+ stream = BufferedInputStream(stream, buffer_size)
+ if compression == 'detect':
+ compression = _detect_compression(path)
+ if compression is not None:
+ stream = CompressedInputStream(stream, compression)
+ return stream
+
+ def _wrap_output_stream(self, stream, path, compression, buffer_size):
+ if buffer_size is not None and buffer_size != 0:
+ stream = BufferedOutputStream(stream, buffer_size)
+ if compression == 'detect':
+ compression = _detect_compression(path)
+ if compression is not None:
+ stream = CompressedOutputStream(stream, compression)
+ return stream
+
+ def open_input_file(self, path):
+ """
+ Open an input file for random access reading.
+
+ Parameters
+ ----------
+ path : str
+ The source to open for reading.
+
+ Returns
+ -------
+ stream : NativeFile
+ """
+ cdef:
+ c_string pathstr = _path_as_bytes(path)
+ NativeFile stream = NativeFile()
+ shared_ptr[CRandomAccessFile] in_handle
+
+ with nogil:
+ in_handle = GetResultValue(self.fs.OpenInputFile(pathstr))
+
+ stream.set_random_access_file(in_handle)
+ stream.is_readable = True
+ return stream
+
+ def open_input_stream(self, path, compression='detect', buffer_size=None):
+ """
+ Open an input stream for sequential reading.
+
+ Parameters
+ ----------
+ path : str
+ The source to open for reading.
+ compression : str optional, default 'detect'
+ The compression algorithm to use for on-the-fly decompression.
+ If "detect" and source is a file path, then compression will be
+ chosen based on the file extension.
+ If None, no compression will be applied. Otherwise, a well-known
+ algorithm name must be supplied (e.g. "gzip").
+ buffer_size : int optional, default None
+ If None or 0, no buffering will happen. Otherwise the size of the
+ temporary read buffer.
+
+ Returns
+ -------
+ stream : NativeFile
+ """
+ cdef:
+ c_string pathstr = _path_as_bytes(path)
+ NativeFile stream = NativeFile()
+ shared_ptr[CInputStream] in_handle
+
+ with nogil:
+ in_handle = GetResultValue(self.fs.OpenInputStream(pathstr))
+
+ stream.set_input_stream(in_handle)
+ stream.is_readable = True
+
+ return self._wrap_input_stream(
+ stream, path=path, compression=compression, buffer_size=buffer_size
+ )
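+
+ # A sketch of on-the-fly decompression (assumes `fs` is a LocalFileSystem
+ # and /tmp/data.csv.gz exists; gzip is inferred from the ".gz" extension
+ # because compression defaults to 'detect'):
+ #
+ # >>> with fs.open_input_stream("/tmp/data.csv.gz") as f:
+ # ...     raw = f.read()                 # decompressed bytes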
+
+ def open_output_stream(self, path, compression='detect',
+ buffer_size=None, metadata=None):
+ """
+ Open an output stream for sequential writing.
+
+ If the target already exists, existing data is truncated.
+
+ Parameters
+ ----------
+ path : str
+ The path of the file to open for writing.
+ compression : str optional, default 'detect'
+ The compression algorithm to use for on-the-fly compression.
+ If "detect" and source is a file path, then compression will be
+ chosen based on the file extension.
+ If None, no compression will be applied. Otherwise, a well-known
+ algorithm name must be supplied (e.g. "gzip").
+ buffer_size : int optional, default None
+ If None or 0, no buffering will happen. Otherwise the size of the
+ temporary write buffer.
+ metadata : dict optional, default None
+ If not None, a mapping of string keys to string values.
+ Some filesystems support storing metadata along the file
+ (such as "Content-Type").
+ Unsupported metadata keys will be ignored.
+
+ Returns
+ -------
+ stream : NativeFile
+ """
+ cdef:
+ c_string pathstr = _path_as_bytes(path)
+ NativeFile stream = NativeFile()
+ shared_ptr[COutputStream] out_handle
+ shared_ptr[const CKeyValueMetadata] c_metadata
+
+ if metadata is not None:
+ c_metadata = pyarrow_unwrap_metadata(KeyValueMetadata(metadata))
+
+ with nogil:
+ out_handle = GetResultValue(
+ self.fs.OpenOutputStream(pathstr, c_metadata))
+
+ stream.set_output_stream(out_handle)
+ stream.is_writable = True
+
+ return self._wrap_output_stream(
+ stream, path=path, compression=compression, buffer_size=buffer_size
+ )
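+
+ # A sketch of buffered, compressed writing (the path is illustrative; the
+ # ".gz" suffix triggers gzip compression under compression='detect'):
+ #
+ # >>> with fs.open_output_stream("/tmp/out.txt.gz", buffer_size=64 * 1024) as f:
+ # ...     f.write(b"hello")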
+
+ def open_append_stream(self, path, compression='detect',
+ buffer_size=None, metadata=None):
+ """
+ DEPRECATED: Open an output stream for appending.
+
+ If the target doesn't exist, a new empty file is created.
+
+ .. deprecated:: 6.0
+ Several filesystems don't support this functionality
+ and it will be removed in a later release.
+
+ Parameters
+ ----------
+ path : str
+ The path of the file to open for appending.
+ compression : str optional, default 'detect'
+ The compression algorithm to use for on-the-fly compression.
+ If "detect" and source is a file path, then compression will be
+ chosen based on the file extension.
+ If None, no compression will be applied. Otherwise, a well-known
+ algorithm name must be supplied (e.g. "gzip").
+ buffer_size : int optional, default None
+ If None or 0, no buffering will happen. Otherwise the size of the
+ temporary write buffer.
+ metadata : dict optional, default None
+ If not None, a mapping of string keys to string values.
+ Some filesystems support storing metadata along the file
+ (such as "Content-Type").
+ Unsupported metadata keys will be ignored.
+
+ Returns
+ -------
+ stream : NativeFile
+ """
+ cdef:
+ c_string pathstr = _path_as_bytes(path)
+ NativeFile stream = NativeFile()
+ shared_ptr[COutputStream] out_handle
+ shared_ptr[const CKeyValueMetadata] c_metadata
+
+ warnings.warn(
+ "`open_append_stream` is deprecated as of 6.0.0; several "
+ "filesystems don't support it and it will be later removed",
+ FutureWarning)
+
+ if metadata is not None:
+ c_metadata = pyarrow_unwrap_metadata(KeyValueMetadata(metadata))
+
+ with nogil:
+ out_handle = GetResultValue(
+ self.fs.OpenAppendStream(pathstr, c_metadata))
+
+ stream.set_output_stream(out_handle)
+ stream.is_writable = True
+
+ return self._wrap_output_stream(
+ stream, path=path, compression=compression, buffer_size=buffer_size
+ )
+
+ def normalize_path(self, path):
+ """
+ Normalize filesystem path.
+
+ Parameters
+ ----------
+ path : str
+ The path to normalize
+
+ Returns
+ -------
+ normalized_path : str
+ The normalized path
+ """
+ cdef:
+ c_string c_path = _path_as_bytes(path)
+ c_string c_path_normalized
+
+ c_path_normalized = GetResultValue(self.fs.NormalizePath(c_path))
+ return frombytes(c_path_normalized)
+
+
+cdef class LocalFileSystem(FileSystem):
+ """
+ A FileSystem implementation accessing files on the local machine.
+
+ Details such as symlinks are abstracted away (symlinks are always followed,
+ except when deleting an entry).
+
+ Parameters
+ ----------
+ use_mmap : bool, default False
+ Whether open_input_stream and open_input_file should return
+ a mmap'ed file or a regular file.
+ """
+
+ def __init__(self, *, use_mmap=False):
+ cdef:
+ CLocalFileSystemOptions opts
+ shared_ptr[CLocalFileSystem] fs
+
+ opts = CLocalFileSystemOptions.Defaults()
+ opts.use_mmap = use_mmap
+
+ fs = make_shared[CLocalFileSystem](opts)
+ self.init(<shared_ptr[CFileSystem]> fs)
+
+ cdef init(self, const shared_ptr[CFileSystem]& c_fs):
+ FileSystem.init(self, c_fs)
+ self.localfs = <CLocalFileSystem*> c_fs.get()
+
+ @classmethod
+ def _reconstruct(cls, kwargs):
+ # __reduce__ doesn't allow passing named arguments directly to the
+ # reconstructor, hence this wrapper.
+ return cls(**kwargs)
+
+ def __reduce__(self):
+ cdef CLocalFileSystemOptions opts = self.localfs.options()
+ return LocalFileSystem._reconstruct, (dict(
+ use_mmap=opts.use_mmap),)
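+
+ # A minimal sketch (the path is illustrative; use_mmap only affects how
+ # input files are opened):
+ #
+ # >>> from pyarrow.fs import LocalFileSystem
+ # >>> local = LocalFileSystem(use_mmap=True)
+ # >>> with local.open_input_file("/tmp/example.bin") as f:
+ # ...     data = f.read()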
+
+
+cdef class SubTreeFileSystem(FileSystem):
+ """
+ Delegates to another implementation after prepending a fixed base path.
+
+ This is useful to expose a logical view of a subtree of a filesystem,
+ for example a directory in a LocalFileSystem.
+
+ Note that this makes no security guarantee. For example, symlinks may
+ allow escaping the subtree and accessing other parts of the underlying
+ filesystem.
+
+ Parameters
+ ----------
+ base_path : str
+ The root of the subtree.
+ base_fs : FileSystem
+ The FileSystem object the operations are delegated to.
+ """
+
+ def __init__(self, base_path, FileSystem base_fs):
+ cdef:
+ c_string pathstr
+ shared_ptr[CSubTreeFileSystem] wrapped
+
+ pathstr = _path_as_bytes(base_path)
+ wrapped = make_shared[CSubTreeFileSystem](pathstr, base_fs.wrapped)
+
+ self.init(<shared_ptr[CFileSystem]> wrapped)
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped):
+ FileSystem.init(self, wrapped)
+ self.subtreefs = <CSubTreeFileSystem*> wrapped.get()
+
+ def __repr__(self):
+ return ("SubTreeFileSystem(base_path={}, base_fs={}"
+ .format(self.base_path, self.base_fs))
+
+ def __reduce__(self):
+ return SubTreeFileSystem, (
+ frombytes(self.subtreefs.base_path()),
+ FileSystem.wrap(self.subtreefs.base_fs())
+ )
+
+ @property
+ def base_path(self):
+ return frombytes(self.subtreefs.base_path())
+
+ @property
+ def base_fs(self):
+ return FileSystem.wrap(self.subtreefs.base_fs())
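+
+ # A sketch exposing one directory as its own filesystem (paths are
+ # illustrative):
+ #
+ # >>> from pyarrow.fs import LocalFileSystem, SubTreeFileSystem
+ # >>> subtree = SubTreeFileSystem("/tmp/project", LocalFileSystem())
+ # >>> subtree.get_file_info("data.csv")  # resolves to /tmp/project/data.csv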
+
+
+cdef class _MockFileSystem(FileSystem):
+
+ def __init__(self, datetime current_time=None):
+ cdef shared_ptr[CMockFileSystem] wrapped
+
+ current_time = current_time or datetime.now()
+ wrapped = make_shared[CMockFileSystem](
+ PyDateTime_to_TimePoint(<PyDateTime_DateTime*> current_time)
+ )
+
+ self.init(<shared_ptr[CFileSystem]> wrapped)
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped):
+ FileSystem.init(self, wrapped)
+ self.mockfs = <CMockFileSystem*> wrapped.get()
+
+
+cdef class PyFileSystem(FileSystem):
+ """
+ A FileSystem with behavior implemented in Python.
+
+ Parameters
+ ----------
+ handler : FileSystemHandler
+ The handler object implementing custom filesystem behavior.
+ """
+
+ def __init__(self, handler):
+ cdef:
+ CPyFileSystemVtable vtable
+ shared_ptr[CPyFileSystem] wrapped
+
+ if not isinstance(handler, FileSystemHandler):
+ raise TypeError("Expected a FileSystemHandler instance, got {0}"
+ .format(type(handler)))
+
+ vtable.get_type_name = _cb_get_type_name
+ vtable.equals = _cb_equals
+ vtable.get_file_info = _cb_get_file_info
+ vtable.get_file_info_vector = _cb_get_file_info_vector
+ vtable.get_file_info_selector = _cb_get_file_info_selector
+ vtable.create_dir = _cb_create_dir
+ vtable.delete_dir = _cb_delete_dir
+ vtable.delete_dir_contents = _cb_delete_dir_contents
+ vtable.delete_root_dir_contents = _cb_delete_root_dir_contents
+ vtable.delete_file = _cb_delete_file
+ vtable.move = _cb_move
+ vtable.copy_file = _cb_copy_file
+ vtable.open_input_stream = _cb_open_input_stream
+ vtable.open_input_file = _cb_open_input_file
+ vtable.open_output_stream = _cb_open_output_stream
+ vtable.open_append_stream = _cb_open_append_stream
+ vtable.normalize_path = _cb_normalize_path
+
+ wrapped = CPyFileSystem.Make(handler, move(vtable))
+ self.init(<shared_ptr[CFileSystem]> wrapped)
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped):
+ FileSystem.init(self, wrapped)
+ self.pyfs = <CPyFileSystem*> wrapped.get()
+
+ @property
+ def handler(self):
+ """
+ The filesystem's underlying handler.
+
+ Returns
+ -------
+ handler : FileSystemHandler
+ """
+ return <object> self.pyfs.handler()
+
+ def __reduce__(self):
+ return PyFileSystem, (self.handler,)
+
+
+class FileSystemHandler(ABC):
+ """
+ An abstract class exposing methods to implement PyFileSystem's behavior.
+ """
+
+ @abstractmethod
+ def get_type_name(self):
+ """
+ Implement PyFileSystem.type_name.
+ """
+
+ @abstractmethod
+ def get_file_info(self, paths):
+ """
+ Implement PyFileSystem.get_file_info(paths).
+
+ Parameters
+ ----------
+ paths : paths for which we want to retrieve the info.
+ """
+
+ @abstractmethod
+ def get_file_info_selector(self, selector):
+ """
+ Implement PyFileSystem.get_file_info(selector).
+
+ Parameters
+ ----------
+ selector : selector for which we want to retrieve the info.
+ """
+
+ @abstractmethod
+ def create_dir(self, path, recursive):
+ """
+ Implement PyFileSystem.create_dir(...).
+
+ Parameters
+ ----------
+ path : path of the directory.
+ recursive : if the parent directories should be created too.
+ """
+
+ @abstractmethod
+ def delete_dir(self, path):
+ """
+ Implement PyFileSystem.delete_dir(...).
+
+ Parameters
+ ----------
+ path : path of the directory.
+ """
+
+ @abstractmethod
+ def delete_dir_contents(self, path):
+ """
+ Implement PyFileSystem.delete_dir_contents(...).
+
+ Parameters
+ ----------
+ path : path of the directory.
+ """
+
+ @abstractmethod
+ def delete_root_dir_contents(self):
+ """
+ Implement PyFileSystem.delete_dir_contents("/", accept_root_dir=True).
+ """
+
+ @abstractmethod
+ def delete_file(self, path):
+ """
+ Implement PyFileSystem.delete_file(...).
+
+ Parameters
+ ----------
+ path : path of the file.
+ """
+
+ @abstractmethod
+ def move(self, src, dest):
+ """
+ Implement PyFileSystem.move(...).
+
+ Parameters
+ ----------
+ src : path of what should be moved.
+ dest : path of where it should be moved to.
+ """
+
+ @abstractmethod
+ def copy_file(self, src, dest):
+ """
+ Implement PyFileSystem.copy_file(...).
+
+ Parameters
+ ----------
+ src : path of what should be copied.
+ dest : path of where it should be copied to.
+ """
+
+ @abstractmethod
+ def open_input_stream(self, path):
+ """
+ Implement PyFileSystem.open_input_stream(...).
+
+ Parameters
+ ----------
+ path : path of what should be opened.
+ """
+
+ @abstractmethod
+ def open_input_file(self, path):
+ """
+ Implement PyFileSystem.open_input_file(...).
+
+ Parameters
+ ----------
+ path : path of what should be opened.
+ """
+
+ @abstractmethod
+ def open_output_stream(self, path, metadata):
+ """
+ Implement PyFileSystem.open_output_stream(...).
+
+ Parameters
+ ----------
+ path : path of what should be opened.
+ metadata : mapping of string keys to string values.
+ Some filesystems support storing metadata along the file
+ (such as "Content-Type").
+ """
+
+ @abstractmethod
+ def open_append_stream(self, path, metadata):
+ """
+ Implement PyFileSystem.open_append_stream(...).
+
+ Parameters
+ ----------
+ path : path of what should be opened.
+ metadata : mapping of string keys to string values.
+ Some filesystems support storing metadata along the file
+ (such as "Content-Type").
+ """
+
+ @abstractmethod
+ def normalize_path(self, path):
+ """
+ Implement PyFileSystem.normalize_path(...).
+
+ Parameters
+ ----------
+ path : path of what should be normalized.
+ """
+
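+# PyFileSystem plugs any FileSystemHandler implementation into the Arrow
+# filesystem machinery. A sketch using FSSpecHandler, a ready-made handler
+# wrapping fsspec filesystems (assumes the optional fsspec package is
+# installed):
+#
+# >>> import fsspec
+# >>> from pyarrow.fs import PyFileSystem, FSSpecHandler
+# >>> py_fs = PyFileSystem(FSSpecHandler(fsspec.filesystem("memory")))
+# >>> py_fs.create_dir("bucket/data")
+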
+# Callback definitions for CPyFileSystemVtable
+
+
+cdef void _cb_get_type_name(handler, c_string* out) except *:
+ out[0] = tobytes("py::" + handler.get_type_name())
+
+cdef c_bool _cb_equals(handler, const CFileSystem& c_other) except False:
+ if c_other.type_name().startswith(b"py::"):
+ return <object> (<const CPyFileSystem&> c_other).handler() == handler
+
+ return False
+
+cdef void _cb_get_file_info(handler, const c_string& path,
+ CFileInfo* out) except *:
+ infos = handler.get_file_info([frombytes(path)])
+ if not isinstance(infos, list) or len(infos) != 1:
+ raise TypeError("get_file_info should have returned a 1-element list")
+ out[0] = FileInfo.unwrap_safe(infos[0])
+
+cdef void _cb_get_file_info_vector(handler, const vector[c_string]& paths,
+ vector[CFileInfo]* out) except *:
+ py_paths = [frombytes(paths[i]) for i in range(len(paths))]
+ infos = handler.get_file_info(py_paths)
+ if not isinstance(infos, list):
+ raise TypeError("get_file_info should have returned a list")
+ out[0].clear()
+ out[0].reserve(len(infos))
+ for info in infos:
+ out[0].push_back(FileInfo.unwrap_safe(info))
+
+cdef void _cb_get_file_info_selector(handler, const CFileSelector& selector,
+ vector[CFileInfo]* out) except *:
+ infos = handler.get_file_info_selector(FileSelector.wrap(selector))
+ if not isinstance(infos, list):
+ raise TypeError("get_file_info_selector should have returned a list")
+ out[0].clear()
+ out[0].reserve(len(infos))
+ for info in infos:
+ out[0].push_back(FileInfo.unwrap_safe(info))
+
+cdef void _cb_create_dir(handler, const c_string& path,
+ c_bool recursive) except *:
+ handler.create_dir(frombytes(path), recursive)
+
+cdef void _cb_delete_dir(handler, const c_string& path) except *:
+ handler.delete_dir(frombytes(path))
+
+cdef void _cb_delete_dir_contents(handler, const c_string& path) except *:
+ handler.delete_dir_contents(frombytes(path))
+
+cdef void _cb_delete_root_dir_contents(handler) except *:
+ handler.delete_root_dir_contents()
+
+cdef void _cb_delete_file(handler, const c_string& path) except *:
+ handler.delete_file(frombytes(path))
+
+cdef void _cb_move(handler, const c_string& src,
+ const c_string& dest) except *:
+ handler.move(frombytes(src), frombytes(dest))
+
+cdef void _cb_copy_file(handler, const c_string& src,
+ const c_string& dest) except *:
+ handler.copy_file(frombytes(src), frombytes(dest))
+
+cdef void _cb_open_input_stream(handler, const c_string& path,
+ shared_ptr[CInputStream]* out) except *:
+ stream = handler.open_input_stream(frombytes(path))
+ if not isinstance(stream, NativeFile):
+ raise TypeError("open_input_stream should have returned "
+ "a PyArrow file")
+ out[0] = (<NativeFile> stream).get_input_stream()
+
+cdef void _cb_open_input_file(handler, const c_string& path,
+ shared_ptr[CRandomAccessFile]* out) except *:
+ stream = handler.open_input_file(frombytes(path))
+ if not isinstance(stream, NativeFile):
+ raise TypeError("open_input_file should have returned "
+ "a PyArrow file")
+ out[0] = (<NativeFile> stream).get_random_access_file()
+
+cdef void _cb_open_output_stream(
+ handler, const c_string& path,
+ const shared_ptr[const CKeyValueMetadata]& metadata,
+ shared_ptr[COutputStream]* out) except *:
+ stream = handler.open_output_stream(
+ frombytes(path), pyarrow_wrap_metadata(metadata))
+ if not isinstance(stream, NativeFile):
+ raise TypeError("open_output_stream should have returned "
+ "a PyArrow file")
+ out[0] = (<NativeFile> stream).get_output_stream()
+
+cdef void _cb_open_append_stream(
+ handler, const c_string& path,
+ const shared_ptr[const CKeyValueMetadata]& metadata,
+ shared_ptr[COutputStream]* out) except *:
+ stream = handler.open_append_stream(
+ frombytes(path), pyarrow_wrap_metadata(metadata))
+ if not isinstance(stream, NativeFile):
+ raise TypeError("open_append_stream should have returned "
+ "a PyArrow file")
+ out[0] = (<NativeFile> stream).get_output_stream()
+
+cdef void _cb_normalize_path(handler, const c_string& path,
+ c_string* out) except *:
+ out[0] = tobytes(handler.normalize_path(frombytes(path)))
+
+
+def _copy_files(FileSystem source_fs, str source_path,
+ FileSystem destination_fs, str destination_path,
+ int64_t chunk_size, c_bool use_threads):
+ # low-level helper exposed through pyarrow/fs.py::copy_files
+ cdef:
+ CFileLocator c_source
+ vector[CFileLocator] c_sources
+ CFileLocator c_destination
+ vector[CFileLocator] c_destinations
+ FileSystem fs
+ CStatus c_status
+ shared_ptr[CFileSystem] c_fs
+
+ c_source.filesystem = source_fs.unwrap()
+ c_source.path = tobytes(source_path)
+ c_sources.push_back(c_source)
+
+ c_destination.filesystem = destination_fs.unwrap()
+ c_destination.path = tobytes(destination_path)
+ c_destinations.push_back(c_destination)
+
+ with nogil:
+ check_status(CCopyFiles(
+ c_sources, c_destinations,
+ c_default_io_context(), chunk_size, use_threads,
+ ))
+
+
+def _copy_files_selector(FileSystem source_fs, FileSelector source_sel,
+ FileSystem destination_fs, str destination_base_dir,
+ int64_t chunk_size, c_bool use_threads):
+ # low-level helper exposed through pyarrow/fs.py::copy_files
+ cdef c_string c_destination_base_dir = tobytes(destination_base_dir)
+
+ with nogil:
+ check_status(CCopyFilesWithSelector(
+ source_fs.unwrap(), source_sel.unwrap(),
+ destination_fs.unwrap(), c_destination_base_dir,
+ c_default_io_context(), chunk_size, use_threads,
+ ))
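+
+# A usage sketch of the public wrapper, pyarrow.fs.copy_files(), which these
+# helpers back (directories are illustrative):
+#
+# >>> from pyarrow import fs
+# >>> local = fs.LocalFileSystem()
+# >>> fs.copy_files("/tmp/source_dir", "/tmp/dest_dir",
+# ...               source_filesystem=local, destination_filesystem=local)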
diff --git a/src/arrow/python/pyarrow/_hdfs.pyx b/src/arrow/python/pyarrow/_hdfs.pyx
new file mode 100644
index 000000000..7a3b974be
--- /dev/null
+++ b/src/arrow/python/pyarrow/_hdfs.pyx
@@ -0,0 +1,149 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+from pyarrow.lib cimport check_status
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_fs cimport *
+from pyarrow._fs cimport FileSystem
+
+from pyarrow.lib import frombytes, tobytes
+from pyarrow.util import _stringify_path
+
+
+cdef class HadoopFileSystem(FileSystem):
+ """
+ HDFS-backed FileSystem implementation.
+
+ Parameters
+ ----------
+ host : str
+ HDFS host to connect to. Set to "default" for fs.defaultFS from
+ core-site.xml.
+ port : int, default 8020
+ HDFS port to connect to. Set to 0 for default or logical (HA) nodes.
+ user : str, default None
+ Username when connecting to HDFS; None implies login user.
+ replication : int, default 3
+ Number of copies each block will have.
+ buffer_size : int, default 0
+ If 0, no buffering will happen; otherwise the size of the temporary read
+ and write buffer.
+ default_block_size : int, default None
+ None means the default configuration for HDFS; a typical block size is
+ 128 MB.
+ kerb_ticket : string or path, default None
+ If not None, the path to the Kerberos ticket cache.
+ extra_conf : dict, default None
+ Extra key/value pairs for configuration; will override any
+ hdfs-site.xml properties.
+ """
+
+ cdef:
+ CHadoopFileSystem* hdfs
+
+ def __init__(self, str host, int port=8020, *, str user=None,
+ int replication=3, int buffer_size=0,
+ default_block_size=None, kerb_ticket=None,
+ extra_conf=None):
+ cdef:
+ CHdfsOptions options
+ shared_ptr[CHadoopFileSystem] wrapped
+
+ if not host.startswith(('hdfs://', 'viewfs://')) and host != "default":
+ # TODO(kszucs): do more sanitization
+ host = 'hdfs://{}'.format(host)
+
+ options.ConfigureEndPoint(tobytes(host), int(port))
+ options.ConfigureReplication(replication)
+ options.ConfigureBufferSize(buffer_size)
+
+ if user is not None:
+ options.ConfigureUser(tobytes(user))
+ if default_block_size is not None:
+ options.ConfigureBlockSize(default_block_size)
+ if kerb_ticket is not None:
+ options.ConfigureKerberosTicketCachePath(
+ tobytes(_stringify_path(kerb_ticket)))
+ if extra_conf is not None:
+ for k, v in extra_conf.items():
+ options.ConfigureExtraConf(tobytes(k), tobytes(v))
+
+ with nogil:
+ wrapped = GetResultValue(CHadoopFileSystem.Make(options))
+ self.init(<shared_ptr[CFileSystem]> wrapped)
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped):
+ FileSystem.init(self, wrapped)
+ self.hdfs = <CHadoopFileSystem*> wrapped.get()
+
+ @staticmethod
+ def from_uri(uri):
+ """
+ Instantiate a HadoopFileSystem object from a URI string.
+
+ The following two calls are equivalent:
+
+ * ``HadoopFileSystem.from_uri('hdfs://localhost:8020/?user=test\
+&replication=1')``
+ * ``HadoopFileSystem('localhost', port=8020, user='test', \
+replication=1)``
+
+ Parameters
+ ----------
+ uri : str
+ A string URI describing the connection to HDFS.
+ In order to change the user, replication, buffer_size or
+ default_block_size, pass the values as query parts.
+
+ Returns
+ -------
+ HadoopFileSystem
+ """
+ cdef:
+ HadoopFileSystem self = HadoopFileSystem.__new__(HadoopFileSystem)
+ shared_ptr[CHadoopFileSystem] wrapped
+ CHdfsOptions options
+
+ options = GetResultValue(CHdfsOptions.FromUriString(tobytes(uri)))
+ with nogil:
+ wrapped = GetResultValue(CHadoopFileSystem.Make(options))
+
+ self.init(<shared_ptr[CFileSystem]> wrapped)
+ return self
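+
+ # A connection sketch (host, user and replication are placeholders for a
+ # real cluster; libhdfs and a Java runtime must be available at runtime):
+ #
+ # >>> hdfs = HadoopFileSystem.from_uri(
+ # ...     "hdfs://namenode:8020/?user=alice&replication=2")
+ # >>> # equivalent keyword form:
+ # >>> hdfs = HadoopFileSystem("namenode", 8020, user="alice", replication=2)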
+
+ @classmethod
+ def _reconstruct(cls, kwargs):
+ return cls(**kwargs)
+
+ def __reduce__(self):
+ cdef CHdfsOptions opts = self.hdfs.options()
+ return (
+ HadoopFileSystem._reconstruct, (dict(
+ host=frombytes(opts.connection_config.host),
+ port=opts.connection_config.port,
+ user=frombytes(opts.connection_config.user),
+ replication=opts.replication,
+ buffer_size=opts.buffer_size,
+ default_block_size=opts.default_block_size,
+ kerb_ticket=frombytes(opts.connection_config.kerb_ticket),
+ extra_conf={frombytes(k): frombytes(v)
+ for k, v in opts.connection_config.extra_conf},
+ ),)
+ )
diff --git a/src/arrow/python/pyarrow/_hdfsio.pyx b/src/arrow/python/pyarrow/_hdfsio.pyx
new file mode 100644
index 000000000..b864f8a68
--- /dev/null
+++ b/src/arrow/python/pyarrow/_hdfsio.pyx
@@ -0,0 +1,480 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# ----------------------------------------------------------------------
+# HDFS IO implementation
+
+# cython: language_level = 3
+
+import re
+
+from pyarrow.lib cimport check_status, _Weakrefable, NativeFile
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_fs cimport *
+from pyarrow.lib import frombytes, tobytes, ArrowIOError
+
+from queue import Queue, Empty as QueueEmpty, Full as QueueFull
+
+
+_HDFS_PATH_RE = re.compile(r'hdfs://(.*):(\d+)(.*)')
+
+
+def have_libhdfs():
+ try:
+ with nogil:
+ check_status(HaveLibHdfs())
+ return True
+ except Exception:
+ return False
+
+
+def strip_hdfs_abspath(path):
+ m = _HDFS_PATH_RE.match(path)
+ if m:
+ return m.group(3)
+ else:
+ return path
+
+
+cdef class HadoopFileSystem(_Weakrefable):
+ cdef:
+ shared_ptr[CIOHadoopFileSystem] client
+
+ cdef readonly:
+ bint is_open
+ object host
+ object user
+ object kerb_ticket
+ int port
+ dict extra_conf
+
+ def _connect(self, host, port, user, kerb_ticket, extra_conf):
+ cdef HdfsConnectionConfig conf
+
+ if host is not None:
+ conf.host = tobytes(host)
+ self.host = host
+
+ conf.port = port
+ self.port = port
+
+ if user is not None:
+ conf.user = tobytes(user)
+ self.user = user
+
+ if kerb_ticket is not None:
+ conf.kerb_ticket = tobytes(kerb_ticket)
+ self.kerb_ticket = kerb_ticket
+
+ with nogil:
+ check_status(HaveLibHdfs())
+
+ if extra_conf is not None and isinstance(extra_conf, dict):
+ conf.extra_conf = {tobytes(k): tobytes(v)
+ for k, v in extra_conf.items()}
+ self.extra_conf = extra_conf
+
+ with nogil:
+ check_status(CIOHadoopFileSystem.Connect(&conf, &self.client))
+ self.is_open = True
+
+ @classmethod
+ def connect(cls, *args, **kwargs):
+ return cls(*args, **kwargs)
+
+ def __dealloc__(self):
+ if self.is_open:
+ self.close()
+
+ def close(self):
+ """
+ Disconnect from the HDFS cluster
+ """
+ self._ensure_client()
+ with nogil:
+ check_status(self.client.get().Disconnect())
+ self.is_open = False
+
+ cdef _ensure_client(self):
+ if self.client.get() == NULL:
+ raise IOError('HDFS client improperly initialized')
+ elif not self.is_open:
+ raise IOError('HDFS client is closed')
+
+ def exists(self, path):
+ """
+ Return True if the path is known to the cluster, False if it is not
+ (or if there is an RPC error).
+ """
+ self._ensure_client()
+
+ cdef c_string c_path = tobytes(path)
+ cdef c_bool result
+ with nogil:
+ result = self.client.get().Exists(c_path)
+ return result
+
+ def isdir(self, path):
+ cdef HdfsPathInfo info
+ try:
+ self._path_info(path, &info)
+ except ArrowIOError:
+ return False
+ return info.kind == ObjectType_DIRECTORY
+
+ def isfile(self, path):
+ cdef HdfsPathInfo info
+ try:
+ self._path_info(path, &info)
+ except ArrowIOError:
+ return False
+ return info.kind == ObjectType_FILE
+
+ def get_capacity(self):
+ """
+ Get reported total capacity of file system
+
+ Returns
+ -------
+ capacity : int
+ """
+ cdef int64_t capacity = 0
+ with nogil:
+ check_status(self.client.get().GetCapacity(&capacity))
+ return capacity
+
+ def get_space_used(self):
+ """
+ Get space used on file system
+
+ Returns
+ -------
+ space_used : int
+ """
+ cdef int64_t space_used = 0
+ with nogil:
+ check_status(self.client.get().GetUsed(&space_used))
+ return space_used
+
+ def df(self):
+ """
+ Return free space on disk, like the UNIX df command
+
+ Returns
+ -------
+ space : int
+ """
+ return self.get_capacity() - self.get_space_used()
+
+ def rename(self, path, new_path):
+ cdef c_string c_path = tobytes(path)
+ cdef c_string c_new_path = tobytes(new_path)
+ with nogil:
+ check_status(self.client.get().Rename(c_path, c_new_path))
+
+ def info(self, path):
+ """
+ Return detailed HDFS information for path
+
+ Parameters
+ ----------
+ path : string
+ Path to file or directory
+
+ Returns
+ -------
+ path_info : dict
+ """
+ cdef HdfsPathInfo info
+ self._path_info(path, &info)
+ return {
+ 'path': frombytes(info.name),
+ 'owner': frombytes(info.owner),
+ 'group': frombytes(info.group),
+ 'size': info.size,
+ 'block_size': info.block_size,
+ 'last_modified': info.last_modified_time,
+ 'last_accessed': info.last_access_time,
+ 'replication': info.replication,
+ 'permissions': info.permissions,
+ 'kind': ('directory' if info.kind == ObjectType_DIRECTORY
+ else 'file')
+ }
+
+ def stat(self, path):
+ """
+ Return basic file system statistics about path
+
+ Parameters
+ ----------
+ path : string
+ Path to file or directory
+
+ Returns
+ -------
+ stat : dict
+ """
+ cdef FileStatistics info
+ cdef c_string c_path = tobytes(path)
+ with nogil:
+ check_status(self.client.get()
+ .Stat(c_path, &info))
+ return {
+ 'size': info.size,
+ 'kind': ('directory' if info.kind == ObjectType_DIRECTORY
+ else 'file')
+ }
+
+ cdef _path_info(self, path, HdfsPathInfo* info):
+ cdef c_string c_path = tobytes(path)
+
+ with nogil:
+ check_status(self.client.get()
+ .GetPathInfo(c_path, info))
+
+ def ls(self, path, bint full_info):
+ cdef:
+ c_string c_path = tobytes(path)
+ vector[HdfsPathInfo] listing
+ list results = []
+ int i
+
+ self._ensure_client()
+
+ with nogil:
+ check_status(self.client.get()
+ .ListDirectory(c_path, &listing))
+
+ cdef const HdfsPathInfo* info
+ for i in range(<int> listing.size()):
+ info = &listing[i]
+
+ # Try to trim off the hdfs://HOST:PORT piece
+ name = strip_hdfs_abspath(frombytes(info.name))
+
+ if full_info:
+ kind = ('file' if info.kind == ObjectType_FILE
+ else 'directory')
+
+ results.append({
+ 'kind': kind,
+ 'name': name,
+ 'owner': frombytes(info.owner),
+ 'group': frombytes(info.group),
+ 'last_modified_time': info.last_modified_time,
+ 'last_access_time': info.last_access_time,
+ 'size': info.size,
+ 'replication': info.replication,
+ 'block_size': info.block_size,
+ 'permissions': info.permissions
+ })
+ else:
+ results.append(name)
+
+ return results
+
+ def chmod(self, path, mode):
+ """
+ Change file permissions
+
+ Parameters
+ ----------
+ path : string
+ absolute path to file or directory
+ mode : int
+ POSIX-like bitmask
+ """
+ self._ensure_client()
+ cdef c_string c_path = tobytes(path)
+ cdef int c_mode = mode
+ with nogil:
+ check_status(self.client.get()
+ .Chmod(c_path, c_mode))
+
+ def chown(self, path, owner=None, group=None):
+ """
+ Change the owner and/or group of a file or directory
+
+ Parameters
+ ----------
+ path : string
+ absolute path to file or directory
+ owner : string, default None
+ New owner, None for no change
+ group : string, default None
+ New group, None for no change
+ """
+ cdef:
+ c_string c_path
+ c_string c_owner
+ c_string c_group
+ const char* c_owner_ptr = NULL
+ const char* c_group_ptr = NULL
+
+ self._ensure_client()
+
+ c_path = tobytes(path)
+ if owner is not None:
+ c_owner = tobytes(owner)
+ c_owner_ptr = c_owner.c_str()
+
+ if group is not None:
+ c_group = tobytes(group)
+ c_group_ptr = c_group.c_str()
+
+ with nogil:
+ check_status(self.client.get()
+ .Chown(c_path, c_owner_ptr, c_group_ptr))
+
+ def mkdir(self, path):
+ """
+ Create indicated directory and any necessary parent directories
+ """
+ self._ensure_client()
+ cdef c_string c_path = tobytes(path)
+ with nogil:
+ check_status(self.client.get()
+ .MakeDirectory(c_path))
+
+ def delete(self, path, bint recursive=False):
+ """
+ Delete the indicated file or directory
+
+ Parameters
+ ----------
+ path : string
+ recursive : boolean, default False
+ If True, also delete child paths for directories
+ """
+ self._ensure_client()
+
+ cdef c_string c_path = tobytes(path)
+ with nogil:
+ check_status(self.client.get()
+ .Delete(c_path, recursive == 1))
+
+ def open(self, path, mode='rb', buffer_size=None, replication=None,
+ default_block_size=None):
+ """
+ Open an HDFS file for reading or writing.
+
+ Parameters
+ ----------
+ path : string
+ Absolute path of the HDFS file to open
+ mode : string
+ Must be one of 'rb' (read), 'wb' (write new file), or 'ab' (append)
+
+ Returns
+ -------
+ handle : HdfsFile
+ """
+ self._ensure_client()
+
+ cdef HdfsFile out = HdfsFile()
+
+ if mode not in ('rb', 'wb', 'ab'):
+ raise Exception("Mode must be 'rb' (read), "
+ "'wb' (write, new file), or 'ab' (append)")
+
+ cdef c_string c_path = tobytes(path)
+ cdef c_bool append = False
+
+ # 0 in libhdfs means "use the default"
+ cdef int32_t c_buffer_size = buffer_size or 0
+ cdef int16_t c_replication = replication or 0
+ cdef int64_t c_default_block_size = default_block_size or 0
+
+ cdef shared_ptr[HdfsOutputStream] wr_handle
+ cdef shared_ptr[HdfsReadableFile] rd_handle
+
+ if mode in ('wb', 'ab'):
+ if mode == 'ab':
+ append = True
+
+ with nogil:
+ check_status(
+ self.client.get()
+ .OpenWritable(c_path, append, c_buffer_size,
+ c_replication, c_default_block_size,
+ &wr_handle))
+
+ out.set_output_stream(<shared_ptr[COutputStream]> wr_handle)
+ out.is_writable = True
+ else:
+ with nogil:
+ check_status(self.client.get()
+ .OpenReadable(c_path, &rd_handle))
+
+ out.set_random_access_file(
+ <shared_ptr[CRandomAccessFile]> rd_handle)
+ out.is_readable = True
+
+ assert not out.closed
+
+ if c_buffer_size == 0:
+ c_buffer_size = 2 ** 16
+
+ out.mode = mode
+ out.buffer_size = c_buffer_size
+ out.parent = _HdfsFileNanny(self, out)
+ out.own_file = True
+
+ return out
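+
+ # A sketch of this legacy API (assumes `client` was obtained through the
+ # deprecated pyarrow.hdfs.connect() and that the paths exist):
+ #
+ # >>> with client.open("/data/part-0.bin", "rb") as f:
+ # ...     payload = f.read()
+ # >>> with client.open("/data/copy.bin", "wb") as f:
+ # ...     f.write(payload)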
+
+ def download(self, path, stream, buffer_size=None):
+ with self.open(path, 'rb') as f:
+ f.download(stream, buffer_size=buffer_size)
+
+ def upload(self, path, stream, buffer_size=None):
+ """
+ Upload file-like object to HDFS path
+ """
+ with self.open(path, 'wb') as f:
+ f.upload(stream, buffer_size=buffer_size)
+
+
+# ARROW-404: Helper class to ensure that files are closed before the
+# client. During deallocation of the extension class, the attributes are
+# decref'd which can cause the client to get closed first if the file has the
+# last remaining reference
+cdef class _HdfsFileNanny(_Weakrefable):
+ cdef:
+ object client
+ object file_handle_ref
+
+ def __cinit__(self, client, file_handle):
+ import weakref
+ self.client = client
+ self.file_handle_ref = weakref.ref(file_handle)
+
+ def __dealloc__(self):
+ fh = self.file_handle_ref()
+ if fh:
+ fh.close()
+ # avoid cyclic GC
+ self.file_handle_ref = None
+ self.client = None
+
+
+cdef class HdfsFile(NativeFile):
+ cdef readonly:
+ int32_t buffer_size
+ object mode
+ object parent
+
+ def __dealloc__(self):
+ self.parent = None
diff --git a/src/arrow/python/pyarrow/_json.pyx b/src/arrow/python/pyarrow/_json.pyx
new file mode 100644
index 000000000..1c08e546e
--- /dev/null
+++ b/src/arrow/python/pyarrow/_json.pyx
@@ -0,0 +1,248 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: profile=False
+# distutils: language = c++
+# cython: language_level = 3
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.lib cimport (check_status, _Weakrefable, Field, MemoryPool,
+ ensure_type, maybe_unbox_memory_pool,
+ get_input_stream, pyarrow_wrap_table,
+ pyarrow_wrap_data_type, pyarrow_unwrap_data_type,
+ pyarrow_wrap_schema, pyarrow_unwrap_schema)
+
+
+cdef class ReadOptions(_Weakrefable):
+ """
+ Options for reading JSON files.
+
+ Parameters
+ ----------
+ use_threads : bool, optional (default True)
+ Whether to use multiple threads to accelerate reading
+ block_size : int, optional
+ How many bytes to process at a time from the input stream.
+ This will determine multi-threading granularity as well as
+ the size of individual chunks in the Table.
+ """
+ cdef:
+ CJSONReadOptions options
+
+ # Avoid mistakenly creating attributes
+ __slots__ = ()
+
+ def __init__(self, use_threads=None, block_size=None):
+ self.options = CJSONReadOptions.Defaults()
+ if use_threads is not None:
+ self.use_threads = use_threads
+ if block_size is not None:
+ self.block_size = block_size
+
+ @property
+ def use_threads(self):
+ """
+ Whether to use multiple threads to accelerate reading.
+ """
+ return self.options.use_threads
+
+ @use_threads.setter
+ def use_threads(self, value):
+ self.options.use_threads = value
+
+ @property
+ def block_size(self):
+ """
+ How many bytes to process at a time from the input stream.
+
+ This will determine multi-threading granularity as well as the size of
+ individual chunks in the Table.
+ """
+ return self.options.block_size
+
+ @block_size.setter
+ def block_size(self, value):
+ self.options.block_size = value
+
+
+cdef class ParseOptions(_Weakrefable):
+ """
+ Options for parsing JSON files.
+
+ Parameters
+ ----------
+ explicit_schema : Schema, optional (default None)
+ Optional explicit schema (no type inference, ignores other fields).
+ newlines_in_values : bool, optional (default False)
+ Whether objects may be printed across multiple lines (for example
+ pretty printed). If false, input must end with an empty line.
+ unexpected_field_behavior : str, default "infer"
+ How JSON fields outside of explicit_schema (if given) are treated.
+
+ Possible behaviors:
+
+ - "ignore": unexpected JSON fields are ignored
+ - "error": error out on unexpected JSON fields
+ - "infer": unexpected JSON fields are type-inferred and included in
+ the output
+ """
+
+ cdef:
+ CJSONParseOptions options
+
+ __slots__ = ()
+
+ def __init__(self, explicit_schema=None, newlines_in_values=None,
+ unexpected_field_behavior=None):
+ self.options = CJSONParseOptions.Defaults()
+ if explicit_schema is not None:
+ self.explicit_schema = explicit_schema
+ if newlines_in_values is not None:
+ self.newlines_in_values = newlines_in_values
+ if unexpected_field_behavior is not None:
+ self.unexpected_field_behavior = unexpected_field_behavior
+
+ @property
+ def explicit_schema(self):
+ """
+ Optional explicit schema (no type inference, ignores other fields)
+ """
+ if self.options.explicit_schema.get() == NULL:
+ return None
+ else:
+ return pyarrow_wrap_schema(self.options.explicit_schema)
+
+ @explicit_schema.setter
+ def explicit_schema(self, value):
+ self.options.explicit_schema = pyarrow_unwrap_schema(value)
+
+ @property
+ def newlines_in_values(self):
+ """
+ Whether newline characters are allowed in JSON values.
+ Setting this to True reduces the performance of multi-threaded
+ JSON reading.
+ """
+ return self.options.newlines_in_values
+
+ @newlines_in_values.setter
+ def newlines_in_values(self, value):
+ self.options.newlines_in_values = value
+
+ @property
+ def unexpected_field_behavior(self):
+ """
+ How JSON fields outside of explicit_schema (if given) are treated.
+
+ Possible behaviors:
+
+ - "ignore": unexpected JSON fields are ignored
+ - "error": error out on unexpected JSON fields
+ - "infer": unexpected JSON fields are type-inferred and included in
+ the output
+
+ Set to "infer" by default.
+ """
+ v = self.options.unexpected_field_behavior
+ if v == CUnexpectedFieldBehavior_Ignore:
+ return "ignore"
+ elif v == CUnexpectedFieldBehavior_Error:
+ return "error"
+ elif v == CUnexpectedFieldBehavior_InferType:
+ return "infer"
+ else:
+ raise ValueError('Unexpected value for unexpected_field_behavior')
+
+ @unexpected_field_behavior.setter
+ def unexpected_field_behavior(self, value):
+ cdef CUnexpectedFieldBehavior v
+
+ if value == "ignore":
+ v = CUnexpectedFieldBehavior_Ignore
+ elif value == "error":
+ v = CUnexpectedFieldBehavior_Error
+ elif value == "infer":
+ v = CUnexpectedFieldBehavior_InferType
+ else:
+ raise ValueError(
+ "Unexpected value `{}` for `unexpected_field_behavior`, pass "
+ "either `ignore`, `error` or `infer`.".format(value)
+ )
+
+ self.options.unexpected_field_behavior = v
+
+
+cdef _get_reader(input_file, shared_ptr[CInputStream]* out):
+ use_memory_map = False
+ get_input_stream(input_file, use_memory_map, out)
+
+cdef _get_read_options(ReadOptions read_options, CJSONReadOptions* out):
+ if read_options is None:
+ out[0] = CJSONReadOptions.Defaults()
+ else:
+ out[0] = read_options.options
+
+cdef _get_parse_options(ParseOptions parse_options, CJSONParseOptions* out):
+ if parse_options is None:
+ out[0] = CJSONParseOptions.Defaults()
+ else:
+ out[0] = parse_options.options
+
+
+def read_json(input_file, read_options=None, parse_options=None,
+ MemoryPool memory_pool=None):
+ """
+ Read a Table from a stream of JSON data.
+
+ Parameters
+ ----------
+ input_file : str, path or file-like object
+ The location of JSON data. Currently only the line-delimited JSON
+ format is supported.
+ read_options : pyarrow.json.ReadOptions, optional
+ Options for the JSON reader (see ReadOptions constructor for defaults).
+ parse_options : pyarrow.json.ParseOptions, optional
+ Options for the JSON parser
+ (see ParseOptions constructor for defaults).
+ memory_pool : MemoryPool, optional
+ Pool to allocate Table memory from.
+
+ Returns
+ -------
+ :class:`pyarrow.Table`
+ Contents of the JSON file as an in-memory table.
+ """
+ cdef:
+ shared_ptr[CInputStream] stream
+ CJSONReadOptions c_read_options
+ CJSONParseOptions c_parse_options
+ shared_ptr[CJSONReader] reader
+ shared_ptr[CTable] table
+
+ _get_reader(input_file, &stream)
+ _get_read_options(read_options, &c_read_options)
+ _get_parse_options(parse_options, &c_parse_options)
+
+ reader = GetResultValue(
+ CJSONReader.Make(maybe_unbox_memory_pool(memory_pool),
+ stream, c_read_options, c_parse_options))
+
+ with nogil:
+ table = GetResultValue(reader.get().Read())
+
+ return pyarrow_wrap_table(table)
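+
+# A self-contained sketch reading line-delimited JSON from a buffer (the
+# sample data is illustrative):
+#
+# >>> import io
+# >>> from pyarrow import json
+# >>> buf = io.BytesIO(b'{"a": 1, "b": "x"}\n{"a": 2, "b": "y"}\n')
+# >>> table = json.read_json(buf)
+# >>> table.num_rows
+# 2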
diff --git a/src/arrow/python/pyarrow/_orc.pxd b/src/arrow/python/pyarrow/_orc.pxd
new file mode 100644
index 000000000..736622591
--- /dev/null
+++ b/src/arrow/python/pyarrow/_orc.pxd
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+# cython: language_level = 3
+
+from libc.string cimport const_char
+from libcpp.vector cimport vector as std_vector
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport (CArray, CSchema, CStatus,
+ CResult, CTable, CMemoryPool,
+ CKeyValueMetadata,
+ CRecordBatch,
+ CTable,
+ CRandomAccessFile, COutputStream,
+ TimeUnit)
+
+
+cdef extern from "arrow/adapters/orc/adapter.h" \
+ namespace "arrow::adapters::orc" nogil:
+
+ cdef cppclass ORCFileReader:
+ @staticmethod
+ CResult[unique_ptr[ORCFileReader]] Open(
+ const shared_ptr[CRandomAccessFile]& file,
+ CMemoryPool* pool)
+
+ CResult[shared_ptr[const CKeyValueMetadata]] ReadMetadata()
+
+ CResult[shared_ptr[CSchema]] ReadSchema()
+
+ CResult[shared_ptr[CRecordBatch]] ReadStripe(int64_t stripe)
+ CResult[shared_ptr[CRecordBatch]] ReadStripe(
+ int64_t stripe, std_vector[c_string])
+
+ CResult[shared_ptr[CTable]] Read()
+ CResult[shared_ptr[CTable]] Read(std_vector[c_string])
+
+ int64_t NumberOfStripes()
+
+ int64_t NumberOfRows()
+
+ cdef cppclass ORCFileWriter:
+ @staticmethod
+ CResult[unique_ptr[ORCFileWriter]] Open(COutputStream* output_stream)
+
+ CStatus Write(const CTable& table)
+
+ CStatus Close()
diff --git a/src/arrow/python/pyarrow/_orc.pyx b/src/arrow/python/pyarrow/_orc.pyx
new file mode 100644
index 000000000..18ca28682
--- /dev/null
+++ b/src/arrow/python/pyarrow/_orc.pyx
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: profile=False
+# distutils: language = c++
+
+from cython.operator cimport dereference as deref
+from libcpp.vector cimport vector as std_vector
+from libcpp.utility cimport move
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.lib cimport (check_status, _Weakrefable,
+ MemoryPool, maybe_unbox_memory_pool,
+ Schema, pyarrow_wrap_schema,
+ KeyValueMetadata,
+ pyarrow_wrap_batch,
+ RecordBatch,
+ Table,
+ pyarrow_wrap_table,
+ pyarrow_unwrap_schema,
+ pyarrow_wrap_metadata,
+ pyarrow_unwrap_table,
+ get_reader,
+ get_writer)
+from pyarrow.lib import tobytes
+
+
+cdef class ORCReader(_Weakrefable):
+ cdef:
+ object source
+ CMemoryPool* allocator
+ unique_ptr[ORCFileReader] reader
+
+ def __cinit__(self, MemoryPool memory_pool=None):
+ self.allocator = maybe_unbox_memory_pool(memory_pool)
+
+ def open(self, object source, c_bool use_memory_map=True):
+ cdef:
+ shared_ptr[CRandomAccessFile] rd_handle
+
+ self.source = source
+
+ get_reader(source, use_memory_map, &rd_handle)
+ with nogil:
+ self.reader = move(GetResultValue(
+ ORCFileReader.Open(rd_handle, self.allocator)
+ ))
+
+ def metadata(self):
+ """
+ The arrow metadata for this file.
+
+ Returns
+ -------
+ metadata : pyarrow.KeyValueMetadata
+ """
+ cdef:
+ shared_ptr[const CKeyValueMetadata] sp_arrow_metadata
+
+ with nogil:
+ sp_arrow_metadata = GetResultValue(
+ deref(self.reader).ReadMetadata()
+ )
+
+ return pyarrow_wrap_metadata(sp_arrow_metadata)
+
+ def schema(self):
+ """
+ The arrow schema for this file.
+
+ Returns
+ -------
+ schema : pyarrow.Schema
+ """
+ cdef:
+ shared_ptr[CSchema] sp_arrow_schema
+
+ with nogil:
+ sp_arrow_schema = GetResultValue(deref(self.reader).ReadSchema())
+
+ return pyarrow_wrap_schema(sp_arrow_schema)
+
+ def nrows(self):
+ return deref(self.reader).NumberOfRows()
+
+ def nstripes(self):
+ return deref(self.reader).NumberOfStripes()
+
+ def read_stripe(self, n, columns=None):
+ cdef:
+ shared_ptr[CRecordBatch] sp_record_batch
+ RecordBatch batch
+ int64_t stripe
+ std_vector[c_string] c_names
+
+ stripe = n
+
+ if columns is None:
+ with nogil:
+ sp_record_batch = GetResultValue(
+ deref(self.reader).ReadStripe(stripe)
+ )
+ else:
+ c_names = [tobytes(name) for name in columns]
+ with nogil:
+ sp_record_batch = GetResultValue(
+ deref(self.reader).ReadStripe(stripe, c_names)
+ )
+
+ return pyarrow_wrap_batch(sp_record_batch)
+
+ def read(self, columns=None):
+ cdef:
+ shared_ptr[CTable] sp_table
+ std_vector[c_string] c_names
+
+ if columns is None:
+ with nogil:
+ sp_table = GetResultValue(deref(self.reader).Read())
+ else:
+ c_names = [tobytes(name) for name in columns]
+ with nogil:
+ sp_table = GetResultValue(deref(self.reader).Read(c_names))
+
+ return pyarrow_wrap_table(sp_table)
+
+cdef class ORCWriter(_Weakrefable):
+ cdef:
+ object source
+ unique_ptr[ORCFileWriter] writer
+ shared_ptr[COutputStream] wr_handle
+
+ def open(self, object source):
+ self.source = source
+ get_writer(source, &self.wr_handle)
+ with nogil:
+ self.writer = move(GetResultValue[unique_ptr[ORCFileWriter]](
+ ORCFileWriter.Open(self.wr_handle.get())))
+
+ def write(self, Table table):
+ cdef:
+ shared_ptr[CTable] sp_table
+ sp_table = pyarrow_unwrap_table(table)
+ with nogil:
+ check_status(deref(self.writer).Write(deref(sp_table)))
+
+ def close(self):
+ with nogil:
+ check_status(deref(self.writer).Close())
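+
+# A round-trip sketch with these low-level wrappers, which are normally
+# reached through the pyarrow.orc module (`table` and the path are
+# illustrative):
+#
+# >>> writer = ORCWriter()
+# >>> writer.open("/tmp/example.orc")
+# >>> writer.write(table)
+# >>> writer.close()
+# >>> reader = ORCReader()
+# >>> reader.open("/tmp/example.orc")
+# >>> reader.read()                      # -> pyarrow.Table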
diff --git a/src/arrow/python/pyarrow/_parquet.pxd b/src/arrow/python/pyarrow/_parquet.pxd
new file mode 100644
index 000000000..0efbb0c86
--- /dev/null
+++ b/src/arrow/python/pyarrow/_parquet.pxd
@@ -0,0 +1,559 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+# cython: language_level = 3
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport (CChunkedArray, CSchema, CStatus,
+ CTable, CMemoryPool, CBuffer,
+ CKeyValueMetadata,
+ CRandomAccessFile, COutputStream,
+ TimeUnit, CRecordBatchReader)
+from pyarrow.lib cimport _Weakrefable
+
+
+cdef extern from "parquet/api/schema.h" namespace "parquet::schema" nogil:
+ cdef cppclass Node:
+ pass
+
+ cdef cppclass GroupNode(Node):
+ pass
+
+ cdef cppclass PrimitiveNode(Node):
+ pass
+
+ cdef cppclass ColumnPath:
+ c_string ToDotString()
+ vector[c_string] ToDotVector()
+
+
+cdef extern from "parquet/api/schema.h" namespace "parquet" nogil:
+ enum ParquetType" parquet::Type::type":
+ ParquetType_BOOLEAN" parquet::Type::BOOLEAN"
+ ParquetType_INT32" parquet::Type::INT32"
+ ParquetType_INT64" parquet::Type::INT64"
+ ParquetType_INT96" parquet::Type::INT96"
+ ParquetType_FLOAT" parquet::Type::FLOAT"
+ ParquetType_DOUBLE" parquet::Type::DOUBLE"
+ ParquetType_BYTE_ARRAY" parquet::Type::BYTE_ARRAY"
+ ParquetType_FIXED_LEN_BYTE_ARRAY" parquet::Type::FIXED_LEN_BYTE_ARRAY"
+
+ enum ParquetLogicalTypeId" parquet::LogicalType::Type::type":
+ ParquetLogicalType_UNDEFINED" parquet::LogicalType::Type::UNDEFINED"
+ ParquetLogicalType_STRING" parquet::LogicalType::Type::STRING"
+ ParquetLogicalType_MAP" parquet::LogicalType::Type::MAP"
+ ParquetLogicalType_LIST" parquet::LogicalType::Type::LIST"
+ ParquetLogicalType_ENUM" parquet::LogicalType::Type::ENUM"
+ ParquetLogicalType_DECIMAL" parquet::LogicalType::Type::DECIMAL"
+ ParquetLogicalType_DATE" parquet::LogicalType::Type::DATE"
+ ParquetLogicalType_TIME" parquet::LogicalType::Type::TIME"
+ ParquetLogicalType_TIMESTAMP" parquet::LogicalType::Type::TIMESTAMP"
+ ParquetLogicalType_INT" parquet::LogicalType::Type::INT"
+ ParquetLogicalType_JSON" parquet::LogicalType::Type::JSON"
+ ParquetLogicalType_BSON" parquet::LogicalType::Type::BSON"
+ ParquetLogicalType_UUID" parquet::LogicalType::Type::UUID"
+ ParquetLogicalType_NONE" parquet::LogicalType::Type::NONE"
+
+ enum ParquetTimeUnit" parquet::LogicalType::TimeUnit::unit":
+ ParquetTimeUnit_UNKNOWN" parquet::LogicalType::TimeUnit::UNKNOWN"
+ ParquetTimeUnit_MILLIS" parquet::LogicalType::TimeUnit::MILLIS"
+ ParquetTimeUnit_MICROS" parquet::LogicalType::TimeUnit::MICROS"
+ ParquetTimeUnit_NANOS" parquet::LogicalType::TimeUnit::NANOS"
+
+ enum ParquetConvertedType" parquet::ConvertedType::type":
+ ParquetConvertedType_NONE" parquet::ConvertedType::NONE"
+ ParquetConvertedType_UTF8" parquet::ConvertedType::UTF8"
+ ParquetConvertedType_MAP" parquet::ConvertedType::MAP"
+ ParquetConvertedType_MAP_KEY_VALUE \
+ " parquet::ConvertedType::MAP_KEY_VALUE"
+ ParquetConvertedType_LIST" parquet::ConvertedType::LIST"
+ ParquetConvertedType_ENUM" parquet::ConvertedType::ENUM"
+ ParquetConvertedType_DECIMAL" parquet::ConvertedType::DECIMAL"
+ ParquetConvertedType_DATE" parquet::ConvertedType::DATE"
+ ParquetConvertedType_TIME_MILLIS" parquet::ConvertedType::TIME_MILLIS"
+ ParquetConvertedType_TIME_MICROS" parquet::ConvertedType::TIME_MICROS"
+ ParquetConvertedType_TIMESTAMP_MILLIS \
+ " parquet::ConvertedType::TIMESTAMP_MILLIS"
+ ParquetConvertedType_TIMESTAMP_MICROS \
+ " parquet::ConvertedType::TIMESTAMP_MICROS"
+ ParquetConvertedType_UINT_8" parquet::ConvertedType::UINT_8"
+ ParquetConvertedType_UINT_16" parquet::ConvertedType::UINT_16"
+ ParquetConvertedType_UINT_32" parquet::ConvertedType::UINT_32"
+ ParquetConvertedType_UINT_64" parquet::ConvertedType::UINT_64"
+ ParquetConvertedType_INT_8" parquet::ConvertedType::INT_8"
+ ParquetConvertedType_INT_16" parquet::ConvertedType::INT_16"
+ ParquetConvertedType_INT_32" parquet::ConvertedType::INT_32"
+ ParquetConvertedType_INT_64" parquet::ConvertedType::INT_64"
+ ParquetConvertedType_JSON" parquet::ConvertedType::JSON"
+ ParquetConvertedType_BSON" parquet::ConvertedType::BSON"
+ ParquetConvertedType_INTERVAL" parquet::ConvertedType::INTERVAL"
+
+ enum ParquetRepetition" parquet::Repetition::type":
+ ParquetRepetition_REQUIRED" parquet::REPETITION::REQUIRED"
+ ParquetRepetition_OPTIONAL" parquet::REPETITION::OPTIONAL"
+ ParquetRepetition_REPEATED" parquet::REPETITION::REPEATED"
+
+ enum ParquetEncoding" parquet::Encoding::type":
+ ParquetEncoding_PLAIN" parquet::Encoding::PLAIN"
+ ParquetEncoding_PLAIN_DICTIONARY" parquet::Encoding::PLAIN_DICTIONARY"
+ ParquetEncoding_RLE" parquet::Encoding::RLE"
+ ParquetEncoding_BIT_PACKED" parquet::Encoding::BIT_PACKED"
+ ParquetEncoding_DELTA_BINARY_PACKED \
+ " parquet::Encoding::DELTA_BINARY_PACKED"
+ ParquetEncoding_DELTA_LENGTH_BYTE_ARRAY \
+ " parquet::Encoding::DELTA_LENGTH_BYTE_ARRAY"
+ ParquetEncoding_DELTA_BYTE_ARRAY" parquet::Encoding::DELTA_BYTE_ARRAY"
+ ParquetEncoding_RLE_DICTIONARY" parquet::Encoding::RLE_DICTIONARY"
+ ParquetEncoding_BYTE_STREAM_SPLIT \
+ " parquet::Encoding::BYTE_STREAM_SPLIT"
+
+ enum ParquetCompression" parquet::Compression::type":
+ ParquetCompression_UNCOMPRESSED" parquet::Compression::UNCOMPRESSED"
+ ParquetCompression_SNAPPY" parquet::Compression::SNAPPY"
+ ParquetCompression_GZIP" parquet::Compression::GZIP"
+ ParquetCompression_LZO" parquet::Compression::LZO"
+ ParquetCompression_BROTLI" parquet::Compression::BROTLI"
+ ParquetCompression_LZ4" parquet::Compression::LZ4"
+ ParquetCompression_ZSTD" parquet::Compression::ZSTD"
+
+ enum ParquetVersion" parquet::ParquetVersion::type":
+ ParquetVersion_V1" parquet::ParquetVersion::PARQUET_1_0"
+ ParquetVersion_V2_0" parquet::ParquetVersion::PARQUET_2_0"
+ ParquetVersion_V2_4" parquet::ParquetVersion::PARQUET_2_4"
+ ParquetVersion_V2_6" parquet::ParquetVersion::PARQUET_2_6"
+
+ enum ParquetSortOrder" parquet::SortOrder::type":
+ ParquetSortOrder_SIGNED" parquet::SortOrder::SIGNED"
+ ParquetSortOrder_UNSIGNED" parquet::SortOrder::UNSIGNED"
+ ParquetSortOrder_UNKNOWN" parquet::SortOrder::UNKNOWN"
+
+ cdef cppclass CParquetLogicalType" parquet::LogicalType":
+ c_string ToString() const
+ c_string ToJSON() const
+ ParquetLogicalTypeId type() const
+
+ cdef cppclass CParquetDecimalType \
+ " parquet::DecimalLogicalType"(CParquetLogicalType):
+ int32_t precision() const
+ int32_t scale() const
+
+ cdef cppclass CParquetIntType \
+ " parquet::IntLogicalType"(CParquetLogicalType):
+ int bit_width() const
+ c_bool is_signed() const
+
+ cdef cppclass CParquetTimeType \
+ " parquet::TimeLogicalType"(CParquetLogicalType):
+ c_bool is_adjusted_to_utc() const
+ ParquetTimeUnit time_unit() const
+
+ cdef cppclass CParquetTimestampType \
+ " parquet::TimestampLogicalType"(CParquetLogicalType):
+ c_bool is_adjusted_to_utc() const
+ ParquetTimeUnit time_unit() const
+
+ cdef cppclass ColumnDescriptor" parquet::ColumnDescriptor":
+ c_bool Equals(const ColumnDescriptor& other)
+
+ shared_ptr[ColumnPath] path()
+ int16_t max_definition_level()
+ int16_t max_repetition_level()
+
+ ParquetType physical_type()
+ const shared_ptr[const CParquetLogicalType]& logical_type()
+ ParquetConvertedType converted_type()
+ const c_string& name()
+ int type_length()
+ int type_precision()
+ int type_scale()
+
+ cdef cppclass SchemaDescriptor:
+ const ColumnDescriptor* Column(int i)
+ shared_ptr[Node] schema()
+ GroupNode* group()
+ c_bool Equals(const SchemaDescriptor& other)
+ c_string ToString()
+ int num_columns()
+
+ cdef c_string FormatStatValue(ParquetType parquet_type, c_string val)
+
+
+cdef extern from "parquet/api/reader.h" namespace "parquet" nogil:
+ cdef cppclass ColumnReader:
+ pass
+
+ cdef cppclass BoolReader(ColumnReader):
+ pass
+
+ cdef cppclass Int32Reader(ColumnReader):
+ pass
+
+ cdef cppclass Int64Reader(ColumnReader):
+ pass
+
+ cdef cppclass Int96Reader(ColumnReader):
+ pass
+
+ cdef cppclass FloatReader(ColumnReader):
+ pass
+
+ cdef cppclass DoubleReader(ColumnReader):
+ pass
+
+ cdef cppclass ByteArrayReader(ColumnReader):
+ pass
+
+ cdef cppclass RowGroupReader:
+ pass
+
+ cdef cppclass CEncodedStatistics" parquet::EncodedStatistics":
+ const c_string& max() const
+ const c_string& min() const
+ int64_t null_count
+ int64_t distinct_count
+ bint has_min
+ bint has_max
+ bint has_null_count
+ bint has_distinct_count
+
+ cdef cppclass ParquetByteArray" parquet::ByteArray":
+ uint32_t len
+ const uint8_t* ptr
+
+ cdef cppclass ParquetFLBA" parquet::FLBA":
+ const uint8_t* ptr
+
+ cdef cppclass CStatistics" parquet::Statistics":
+ int64_t null_count() const
+ int64_t distinct_count() const
+ int64_t num_values() const
+ bint HasMinMax()
+ bint HasNullCount()
+ bint HasDistinctCount()
+ c_bool Equals(const CStatistics&) const
+ void Reset()
+ c_string EncodeMin()
+ c_string EncodeMax()
+ CEncodedStatistics Encode()
+ void SetComparator()
+ ParquetType physical_type() const
+ const ColumnDescriptor* descr() const
+
+ cdef cppclass CBoolStatistics" parquet::BoolStatistics"(CStatistics):
+ c_bool min()
+ c_bool max()
+
+ cdef cppclass CInt32Statistics" parquet::Int32Statistics"(CStatistics):
+ int32_t min()
+ int32_t max()
+
+ cdef cppclass CInt64Statistics" parquet::Int64Statistics"(CStatistics):
+ int64_t min()
+ int64_t max()
+
+ cdef cppclass CFloatStatistics" parquet::FloatStatistics"(CStatistics):
+ float min()
+ float max()
+
+ cdef cppclass CDoubleStatistics" parquet::DoubleStatistics"(CStatistics):
+ double min()
+ double max()
+
+ cdef cppclass CByteArrayStatistics \
+ " parquet::ByteArrayStatistics"(CStatistics):
+ ParquetByteArray min()
+ ParquetByteArray max()
+
+ cdef cppclass CFLBAStatistics" parquet::FLBAStatistics"(CStatistics):
+ ParquetFLBA min()
+ ParquetFLBA max()
+
+ cdef cppclass CColumnChunkMetaData" parquet::ColumnChunkMetaData":
+ int64_t file_offset() const
+ const c_string& file_path() const
+
+ ParquetType type() const
+ int64_t num_values() const
+ shared_ptr[ColumnPath] path_in_schema() const
+ bint is_stats_set() const
+ shared_ptr[CStatistics] statistics() const
+ ParquetCompression compression() const
+ const vector[ParquetEncoding]& encodings() const
+ c_bool Equals(const CColumnChunkMetaData&) const
+
+ int64_t has_dictionary_page() const
+ int64_t dictionary_page_offset() const
+ int64_t data_page_offset() const
+ int64_t index_page_offset() const
+ int64_t total_compressed_size() const
+ int64_t total_uncompressed_size() const
+
+ cdef cppclass CRowGroupMetaData" parquet::RowGroupMetaData":
+ c_bool Equals(const CRowGroupMetaData&) const
+ int num_columns()
+ int64_t num_rows()
+ int64_t total_byte_size()
+ unique_ptr[CColumnChunkMetaData] ColumnChunk(int i) const
+
+ cdef cppclass CFileMetaData" parquet::FileMetaData":
+ c_bool Equals(const CFileMetaData&) const
+ uint32_t size()
+ int num_columns()
+ int64_t num_rows()
+ int num_row_groups()
+ ParquetVersion version()
+ const c_string created_by()
+ int num_schema_elements()
+
+ void set_file_path(const c_string& path)
+ void AppendRowGroups(const CFileMetaData& other) except +
+
+ unique_ptr[CRowGroupMetaData] RowGroup(int i)
+ const SchemaDescriptor* schema()
+ shared_ptr[const CKeyValueMetadata] key_value_metadata() const
+ void WriteTo(COutputStream* dst) const
+
+ cdef shared_ptr[CFileMetaData] CFileMetaData_Make \
+ " parquet::FileMetaData::Make"(const void* serialized_metadata,
+ uint32_t* metadata_len)
+
+ cdef cppclass CReaderProperties" parquet::ReaderProperties":
+ c_bool is_buffered_stream_enabled() const
+ void enable_buffered_stream()
+ void disable_buffered_stream()
+ void set_buffer_size(int64_t buf_size)
+ int64_t buffer_size() const
+
+ CReaderProperties default_reader_properties()
+
+ cdef cppclass ArrowReaderProperties:
+ ArrowReaderProperties()
+ void set_read_dictionary(int column_index, c_bool read_dict)
+ c_bool read_dictionary()
+ void set_batch_size(int64_t batch_size)
+ int64_t batch_size()
+ void set_pre_buffer(c_bool pre_buffer)
+ c_bool pre_buffer() const
+ void set_coerce_int96_timestamp_unit(TimeUnit unit)
+ TimeUnit coerce_int96_timestamp_unit() const
+
+ ArrowReaderProperties default_arrow_reader_properties()
+
+ cdef cppclass ParquetFileReader:
+ shared_ptr[CFileMetaData] metadata()
+
+
+cdef extern from "parquet/api/writer.h" namespace "parquet" nogil:
+ cdef cppclass WriterProperties:
+ cppclass Builder:
+ Builder* data_page_version(ParquetDataPageVersion version)
+ Builder* version(ParquetVersion version)
+ Builder* compression(ParquetCompression codec)
+ Builder* compression(const c_string& path,
+ ParquetCompression codec)
+ Builder* compression_level(int compression_level)
+ Builder* compression_level(const c_string& path,
+ int compression_level)
+ Builder* disable_dictionary()
+ Builder* enable_dictionary()
+ Builder* enable_dictionary(const c_string& path)
+ Builder* disable_statistics()
+ Builder* enable_statistics()
+ Builder* enable_statistics(const c_string& path)
+ Builder* data_pagesize(int64_t size)
+ Builder* encoding(ParquetEncoding encoding)
+ Builder* encoding(const c_string& path,
+ ParquetEncoding encoding)
+ Builder* write_batch_size(int64_t batch_size)
+ shared_ptr[WriterProperties] build()
+
+ cdef cppclass ArrowWriterProperties:
+ cppclass Builder:
+ Builder()
+ Builder* disable_deprecated_int96_timestamps()
+ Builder* enable_deprecated_int96_timestamps()
+ Builder* coerce_timestamps(TimeUnit unit)
+ Builder* allow_truncated_timestamps()
+ Builder* disallow_truncated_timestamps()
+ Builder* store_schema()
+ Builder* enable_compliant_nested_types()
+ Builder* disable_compliant_nested_types()
+ Builder* set_engine_version(ArrowWriterEngineVersion version)
+ shared_ptr[ArrowWriterProperties] build()
+ c_bool support_deprecated_int96_timestamps()
+
+
+cdef extern from "parquet/arrow/reader.h" namespace "parquet::arrow" nogil:
+ cdef cppclass FileReader:
+ FileReader(CMemoryPool* pool, unique_ptr[ParquetFileReader] reader)
+
+ CStatus GetSchema(shared_ptr[CSchema]* out)
+
+ CStatus ReadColumn(int i, shared_ptr[CChunkedArray]* out)
+ CStatus ReadSchemaField(int i, shared_ptr[CChunkedArray]* out)
+
+ int num_row_groups()
+ CStatus ReadRowGroup(int i, shared_ptr[CTable]* out)
+ CStatus ReadRowGroup(int i, const vector[int]& column_indices,
+ shared_ptr[CTable]* out)
+
+ CStatus ReadRowGroups(const vector[int]& row_groups,
+ shared_ptr[CTable]* out)
+ CStatus ReadRowGroups(const vector[int]& row_groups,
+ const vector[int]& column_indices,
+ shared_ptr[CTable]* out)
+
+ CStatus GetRecordBatchReader(const vector[int]& row_group_indices,
+ const vector[int]& column_indices,
+ unique_ptr[CRecordBatchReader]* out)
+ CStatus GetRecordBatchReader(const vector[int]& row_group_indices,
+ unique_ptr[CRecordBatchReader]* out)
+
+ CStatus ReadTable(shared_ptr[CTable]* out)
+ CStatus ReadTable(const vector[int]& column_indices,
+ shared_ptr[CTable]* out)
+
+ CStatus ScanContents(vector[int] columns, int32_t column_batch_size,
+ int64_t* num_rows)
+
+ const ParquetFileReader* parquet_reader()
+
+ void set_use_threads(c_bool use_threads)
+
+ void set_batch_size(int64_t batch_size)
+
+ cdef cppclass FileReaderBuilder:
+ FileReaderBuilder()
+ CStatus Open(const shared_ptr[CRandomAccessFile]& file,
+ const CReaderProperties& properties,
+ const shared_ptr[CFileMetaData]& metadata)
+
+ ParquetFileReader* raw_reader()
+ FileReaderBuilder* memory_pool(CMemoryPool*)
+ FileReaderBuilder* properties(const ArrowReaderProperties&)
+ CStatus Build(unique_ptr[FileReader]* out)
+
+ CStatus FromParquetSchema(
+ const SchemaDescriptor* parquet_schema,
+ const ArrowReaderProperties& properties,
+ const shared_ptr[const CKeyValueMetadata]& key_value_metadata,
+ shared_ptr[CSchema]* out)
+
+cdef extern from "parquet/arrow/schema.h" namespace "parquet::arrow" nogil:
+
+ CStatus ToParquetSchema(
+ const CSchema* arrow_schema,
+ const ArrowReaderProperties& properties,
+ const shared_ptr[const CKeyValueMetadata]& key_value_metadata,
+ shared_ptr[SchemaDescriptor]* out)
+
+
+cdef extern from "parquet/properties.h" namespace "parquet" nogil:
+ cdef enum ArrowWriterEngineVersion:
+ V1 "parquet::ArrowWriterProperties::V1",
+ V2 "parquet::ArrowWriterProperties::V2"
+
+ cdef cppclass ParquetDataPageVersion:
+ pass
+
+ cdef ParquetDataPageVersion ParquetDataPageVersion_V1 \
+ " parquet::ParquetDataPageVersion::V1"
+ cdef ParquetDataPageVersion ParquetDataPageVersion_V2 \
+ " parquet::ParquetDataPageVersion::V2"
+
+cdef extern from "parquet/arrow/writer.h" namespace "parquet::arrow" nogil:
+ cdef cppclass FileWriter:
+
+ @staticmethod
+ CStatus Open(const CSchema& schema, CMemoryPool* pool,
+ const shared_ptr[COutputStream]& sink,
+ const shared_ptr[WriterProperties]& properties,
+ const shared_ptr[ArrowWriterProperties]& arrow_properties,
+ unique_ptr[FileWriter]* writer)
+
+ CStatus WriteTable(const CTable& table, int64_t chunk_size)
+ CStatus NewRowGroup(int64_t chunk_size)
+ CStatus Close()
+
+ const shared_ptr[CFileMetaData] metadata() const
+
+ CStatus WriteMetaDataFile(
+ const CFileMetaData& file_metadata,
+ const COutputStream* sink)
+
+
+cdef shared_ptr[WriterProperties] _create_writer_properties(
+ use_dictionary=*,
+ compression=*,
+ version=*,
+ write_statistics=*,
+ data_page_size=*,
+ compression_level=*,
+ use_byte_stream_split=*,
+ data_page_version=*) except *
+
+
+cdef shared_ptr[ArrowWriterProperties] _create_arrow_writer_properties(
+ use_deprecated_int96_timestamps=*,
+ coerce_timestamps=*,
+ allow_truncated_timestamps=*,
+ writer_engine_version=*,
+ use_compliant_nested_type=*) except *
+
+cdef class ParquetSchema(_Weakrefable):
+ cdef:
+ FileMetaData parent # the FileMetaData owning the SchemaDescriptor
+ const SchemaDescriptor* schema
+
+cdef class FileMetaData(_Weakrefable):
+ cdef:
+ shared_ptr[CFileMetaData] sp_metadata
+ CFileMetaData* _metadata
+ ParquetSchema _schema
+
+ cdef inline init(self, const shared_ptr[CFileMetaData]& metadata):
+ self.sp_metadata = metadata
+ self._metadata = metadata.get()
+
+cdef class RowGroupMetaData(_Weakrefable):
+ cdef:
+ int index # for pickling support
+ unique_ptr[CRowGroupMetaData] up_metadata
+ CRowGroupMetaData* metadata
+ FileMetaData parent
+
+cdef class ColumnChunkMetaData(_Weakrefable):
+ cdef:
+ unique_ptr[CColumnChunkMetaData] up_metadata
+ CColumnChunkMetaData* metadata
+ RowGroupMetaData parent
+
+ cdef inline init(self, RowGroupMetaData parent, int i):
+ self.up_metadata = parent.metadata.ColumnChunk(i)
+ self.metadata = self.up_metadata.get()
+ self.parent = parent
+
+cdef class Statistics(_Weakrefable):
+ cdef:
+ shared_ptr[CStatistics] statistics
+ ColumnChunkMetaData parent
+
+ cdef inline init(self, const shared_ptr[CStatistics]& statistics,
+ ColumnChunkMetaData parent):
+ self.statistics = statistics
+ self.parent = parent
diff --git a/src/arrow/python/pyarrow/_parquet.pyx b/src/arrow/python/pyarrow/_parquet.pyx
new file mode 100644
index 000000000..87dfecf1b
--- /dev/null
+++ b/src/arrow/python/pyarrow/_parquet.pyx
@@ -0,0 +1,1466 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: profile=False
+# distutils: language = c++
+
+import io
+from textwrap import indent
+import warnings
+
+import numpy as np
+
+from cython.operator cimport dereference as deref
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.lib cimport (_Weakrefable, Buffer, Array, Schema,
+ check_status,
+ MemoryPool, maybe_unbox_memory_pool,
+ Table, NativeFile,
+ pyarrow_wrap_chunked_array,
+ pyarrow_wrap_schema,
+ pyarrow_wrap_table,
+ pyarrow_wrap_buffer,
+ pyarrow_wrap_batch,
+ NativeFile, get_reader, get_writer,
+ string_to_timeunit)
+
+from pyarrow.lib import (ArrowException, NativeFile, BufferOutputStream,
+ _stringify_path, _datetime_from_int,
+ tobytes, frombytes)
+
+cimport cpython as cp
+
+
+cdef class Statistics(_Weakrefable):
+ def __cinit__(self):
+ pass
+
+ def __repr__(self):
+ return """{}
+ has_min_max: {}
+ min: {}
+ max: {}
+ null_count: {}
+ distinct_count: {}
+ num_values: {}
+ physical_type: {}
+ logical_type: {}
+ converted_type (legacy): {}""".format(object.__repr__(self),
+ self.has_min_max,
+ self.min,
+ self.max,
+ self.null_count,
+ self.distinct_count,
+ self.num_values,
+ self.physical_type,
+ str(self.logical_type),
+ self.converted_type)
+
+ def to_dict(self):
+ d = dict(
+ has_min_max=self.has_min_max,
+ min=self.min,
+ max=self.max,
+ null_count=self.null_count,
+ distinct_count=self.distinct_count,
+ num_values=self.num_values,
+ physical_type=self.physical_type
+ )
+ return d
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def equals(self, Statistics other):
+ return self.statistics.get().Equals(deref(other.statistics.get()))
+
+ @property
+ def has_min_max(self):
+ return self.statistics.get().HasMinMax()
+
+ @property
+ def has_null_count(self):
+ return self.statistics.get().HasNullCount()
+
+ @property
+ def has_distinct_count(self):
+ return self.statistics.get().HasDistinctCount()
+
+ @property
+ def min_raw(self):
+ if self.has_min_max:
+ return _cast_statistic_raw_min(self.statistics.get())
+ else:
+ return None
+
+ @property
+ def max_raw(self):
+ if self.has_min_max:
+ return _cast_statistic_raw_max(self.statistics.get())
+ else:
+ return None
+
+ @property
+ def min(self):
+ if self.has_min_max:
+ return _cast_statistic_min(self.statistics.get())
+ else:
+ return None
+
+ @property
+ def max(self):
+ if self.has_min_max:
+ return _cast_statistic_max(self.statistics.get())
+ else:
+ return None
+
+ @property
+ def null_count(self):
+ return self.statistics.get().null_count()
+
+ @property
+ def distinct_count(self):
+ return self.statistics.get().distinct_count()
+
+ @property
+ def num_values(self):
+ return self.statistics.get().num_values()
+
+ @property
+ def physical_type(self):
+ raw_physical_type = self.statistics.get().physical_type()
+ return physical_type_name_from_enum(raw_physical_type)
+
+ @property
+ def logical_type(self):
+ return wrap_logical_type(self.statistics.get().descr().logical_type())
+
+ @property
+ def converted_type(self):
+ raw_converted_type = self.statistics.get().descr().converted_type()
+ return converted_type_name_from_enum(raw_converted_type)
+
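+# A minimal usage sketch for the Statistics wrapper above (not executed here;
+# "example.parquet" is a hypothetical local file). Statistics objects are
+# normally obtained through ColumnChunkMetaData.statistics rather than
+# constructed directly:
+#
+#   import pyarrow.parquet as pq
+#   meta = pq.ParquetFile("example.parquet").metadata
+#   stats = meta.row_group(0).column(0).statistics
+#   if stats is not None and stats.has_min_max:
+#       print(stats.min, stats.max, stats.null_count)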
+
+cdef class ParquetLogicalType(_Weakrefable):
+ cdef:
+ shared_ptr[const CParquetLogicalType] type
+
+ def __cinit__(self):
+ pass
+
+ cdef init(self, const shared_ptr[const CParquetLogicalType]& type):
+ self.type = type
+
+ def __str__(self):
+ return frombytes(self.type.get().ToString(), safe=True)
+
+ def to_json(self):
+ return frombytes(self.type.get().ToJSON())
+
+ @property
+ def type(self):
+ return logical_type_name_from_enum(self.type.get().type())
+
+
+cdef wrap_logical_type(const shared_ptr[const CParquetLogicalType]& type):
+ cdef ParquetLogicalType out = ParquetLogicalType()
+ out.init(type)
+ return out
+
+
+cdef _cast_statistic_raw_min(CStatistics* statistics):
+ cdef ParquetType physical_type = statistics.physical_type()
+ cdef uint32_t type_length = statistics.descr().type_length()
+ if physical_type == ParquetType_BOOLEAN:
+ return (<CBoolStatistics*> statistics).min()
+ elif physical_type == ParquetType_INT32:
+ return (<CInt32Statistics*> statistics).min()
+ elif physical_type == ParquetType_INT64:
+ return (<CInt64Statistics*> statistics).min()
+ elif physical_type == ParquetType_FLOAT:
+ return (<CFloatStatistics*> statistics).min()
+ elif physical_type == ParquetType_DOUBLE:
+ return (<CDoubleStatistics*> statistics).min()
+ elif physical_type == ParquetType_BYTE_ARRAY:
+ return _box_byte_array((<CByteArrayStatistics*> statistics).min())
+ elif physical_type == ParquetType_FIXED_LEN_BYTE_ARRAY:
+ return _box_flba((<CFLBAStatistics*> statistics).min(), type_length)
+
+
+cdef _cast_statistic_raw_max(CStatistics* statistics):
+ cdef ParquetType physical_type = statistics.physical_type()
+ cdef uint32_t type_length = statistics.descr().type_length()
+ if physical_type == ParquetType_BOOLEAN:
+ return (<CBoolStatistics*> statistics).max()
+ elif physical_type == ParquetType_INT32:
+ return (<CInt32Statistics*> statistics).max()
+ elif physical_type == ParquetType_INT64:
+ return (<CInt64Statistics*> statistics).max()
+ elif physical_type == ParquetType_FLOAT:
+ return (<CFloatStatistics*> statistics).max()
+ elif physical_type == ParquetType_DOUBLE:
+ return (<CDoubleStatistics*> statistics).max()
+ elif physical_type == ParquetType_BYTE_ARRAY:
+ return _box_byte_array((<CByteArrayStatistics*> statistics).max())
+ elif physical_type == ParquetType_FIXED_LEN_BYTE_ARRAY:
+ return _box_flba((<CFLBAStatistics*> statistics).max(), type_length)
+
+
+cdef _cast_statistic_min(CStatistics* statistics):
+ min_raw = _cast_statistic_raw_min(statistics)
+ return _box_logical_type_value(min_raw, statistics.descr())
+
+
+cdef _cast_statistic_max(CStatistics* statistics):
+ max_raw = _cast_statistic_raw_max(statistics)
+ return _box_logical_type_value(max_raw, statistics.descr())
+
+
+cdef _box_logical_type_value(object value, const ColumnDescriptor* descr):
+ cdef:
+ const CParquetLogicalType* ltype = descr.logical_type().get()
+ ParquetTimeUnit time_unit
+ const CParquetIntType* itype
+ const CParquetTimestampType* ts_type
+
+ if ltype.type() == ParquetLogicalType_STRING:
+ return value.decode('utf8')
+ elif ltype.type() == ParquetLogicalType_TIME:
+ time_unit = (<const CParquetTimeType*> ltype).time_unit()
+ if time_unit == ParquetTimeUnit_MILLIS:
+ return _datetime_from_int(value, unit=TimeUnit_MILLI).time()
+ else:
+ return _datetime_from_int(value, unit=TimeUnit_MICRO).time()
+ elif ltype.type() == ParquetLogicalType_TIMESTAMP:
+ ts_type = <const CParquetTimestampType*> ltype
+ time_unit = ts_type.time_unit()
+ if ts_type.is_adjusted_to_utc():
+ import pytz
+ tzinfo = pytz.utc
+ else:
+ tzinfo = None
+ if time_unit == ParquetTimeUnit_MILLIS:
+ return _datetime_from_int(value, unit=TimeUnit_MILLI,
+ tzinfo=tzinfo)
+ elif time_unit == ParquetTimeUnit_MICROS:
+ return _datetime_from_int(value, unit=TimeUnit_MICRO,
+ tzinfo=tzinfo)
+ elif time_unit == ParquetTimeUnit_NANOS:
+ return _datetime_from_int(value, unit=TimeUnit_NANO,
+ tzinfo=tzinfo)
+ else:
+ raise ValueError("Unsupported time unit")
+ elif ltype.type() == ParquetLogicalType_INT:
+ itype = <const CParquetIntType*> ltype
+ if not itype.is_signed() and itype.bit_width() == 32:
+ return int(np.int32(value).view(np.uint32))
+ elif not itype.is_signed() and itype.bit_width() == 64:
+ return int(np.int64(value).view(np.uint64))
+ else:
+ return value
+ else:
+ # No logical boxing defined
+ return value
+
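+# Worked examples of the boxing performed by _box_logical_type_value above:
+# for a column with logical type TIMESTAMP(MILLIS, adjusted_to_utc=True), a
+# raw int64 statistic of 1577836800000 is boxed to
+# datetime.datetime(2020, 1, 1, 0, 0, tzinfo=pytz.utc); a raw TIME(MILLIS)
+# value of 3600000 becomes datetime.time(1, 0); and a raw value of -1 in an
+# unsigned INT(32) column is reinterpreted as 4294967295.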
+
+cdef _box_byte_array(ParquetByteArray val):
+ return cp.PyBytes_FromStringAndSize(<char*> val.ptr, <Py_ssize_t> val.len)
+
+
+cdef _box_flba(ParquetFLBA val, uint32_t len):
+ return cp.PyBytes_FromStringAndSize(<char*> val.ptr, <Py_ssize_t> len)
+
+
+cdef class ColumnChunkMetaData(_Weakrefable):
+ def __cinit__(self):
+ pass
+
+ def __repr__(self):
+ statistics = indent(repr(self.statistics), 4 * ' ')
+ return """{0}
+ file_offset: {1}
+ file_path: {2}
+ physical_type: {3}
+ num_values: {4}
+ path_in_schema: {5}
+ is_stats_set: {6}
+ statistics:
+{7}
+ compression: {8}
+ encodings: {9}
+ has_dictionary_page: {10}
+ dictionary_page_offset: {11}
+ data_page_offset: {12}
+ total_compressed_size: {13}
+ total_uncompressed_size: {14}""".format(object.__repr__(self),
+ self.file_offset,
+ self.file_path,
+ self.physical_type,
+ self.num_values,
+ self.path_in_schema,
+ self.is_stats_set,
+ statistics,
+ self.compression,
+ self.encodings,
+ self.has_dictionary_page,
+ self.dictionary_page_offset,
+ self.data_page_offset,
+ self.total_compressed_size,
+ self.total_uncompressed_size)
+
+ def to_dict(self):
+ statistics = self.statistics.to_dict() if self.is_stats_set else None
+ d = dict(
+ file_offset=self.file_offset,
+ file_path=self.file_path,
+ physical_type=self.physical_type,
+ num_values=self.num_values,
+ path_in_schema=self.path_in_schema,
+ is_stats_set=self.is_stats_set,
+ statistics=statistics,
+ compression=self.compression,
+ encodings=self.encodings,
+ has_dictionary_page=self.has_dictionary_page,
+ dictionary_page_offset=self.dictionary_page_offset,
+ data_page_offset=self.data_page_offset,
+ total_compressed_size=self.total_compressed_size,
+ total_uncompressed_size=self.total_uncompressed_size
+ )
+ return d
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def equals(self, ColumnChunkMetaData other):
+ return self.metadata.Equals(deref(other.metadata))
+
+ @property
+ def file_offset(self):
+ return self.metadata.file_offset()
+
+ @property
+ def file_path(self):
+ return frombytes(self.metadata.file_path())
+
+ @property
+ def physical_type(self):
+ return physical_type_name_from_enum(self.metadata.type())
+
+ @property
+ def num_values(self):
+ return self.metadata.num_values()
+
+ @property
+ def path_in_schema(self):
+ path = self.metadata.path_in_schema().get().ToDotString()
+ return frombytes(path)
+
+ @property
+ def is_stats_set(self):
+ return self.metadata.is_stats_set()
+
+ @property
+ def statistics(self):
+ if not self.metadata.is_stats_set():
+ return None
+ statistics = Statistics()
+ statistics.init(self.metadata.statistics(), self)
+ return statistics
+
+ @property
+ def compression(self):
+ return compression_name_from_enum(self.metadata.compression())
+
+ @property
+ def encodings(self):
+ return tuple(map(encoding_name_from_enum, self.metadata.encodings()))
+
+ @property
+ def has_dictionary_page(self):
+ return bool(self.metadata.has_dictionary_page())
+
+ @property
+ def dictionary_page_offset(self):
+ if self.has_dictionary_page:
+ return self.metadata.dictionary_page_offset()
+ else:
+ return None
+
+ @property
+ def data_page_offset(self):
+ return self.metadata.data_page_offset()
+
+ @property
+ def has_index_page(self):
+ raise NotImplementedError('not supported in parquet-cpp')
+
+ @property
+ def index_page_offset(self):
+ raise NotImplementedError("parquet-cpp doesn't return valid values")
+
+ @property
+ def total_compressed_size(self):
+ return self.metadata.total_compressed_size()
+
+ @property
+ def total_uncompressed_size(self):
+ return self.metadata.total_uncompressed_size()
+
+
+cdef class RowGroupMetaData(_Weakrefable):
+ def __cinit__(self, FileMetaData parent, int index):
+ if index < 0 or index >= parent.num_row_groups:
+ raise IndexError('{0} out of bounds'.format(index))
+ self.up_metadata = parent._metadata.RowGroup(index)
+ self.metadata = self.up_metadata.get()
+ self.parent = parent
+ self.index = index
+
+ def __reduce__(self):
+ return RowGroupMetaData, (self.parent, self.index)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def equals(self, RowGroupMetaData other):
+ return self.metadata.Equals(deref(other.metadata))
+
+ def column(self, int i):
+ if i < 0 or i >= self.num_columns:
+ raise IndexError('{0} out of bounds'.format(i))
+ chunk = ColumnChunkMetaData()
+ chunk.init(self, i)
+ return chunk
+
+ def __repr__(self):
+ return """{0}
+ num_columns: {1}
+ num_rows: {2}
+ total_byte_size: {3}""".format(object.__repr__(self),
+ self.num_columns,
+ self.num_rows,
+ self.total_byte_size)
+
+ def to_dict(self):
+ columns = []
+ d = dict(
+ num_columns=self.num_columns,
+ num_rows=self.num_rows,
+ total_byte_size=self.total_byte_size,
+ columns=columns,
+ )
+ for i in range(self.num_columns):
+ columns.append(self.column(i).to_dict())
+ return d
+
+ @property
+ def num_columns(self):
+ return self.metadata.num_columns()
+
+ @property
+ def num_rows(self):
+ return self.metadata.num_rows()
+
+ @property
+ def total_byte_size(self):
+ return self.metadata.total_byte_size()
+
+
+def _reconstruct_filemetadata(Buffer serialized):
+ cdef:
+ FileMetaData metadata = FileMetaData.__new__(FileMetaData)
+ CBuffer *buffer = serialized.buffer.get()
+ uint32_t metadata_len = <uint32_t>buffer.size()
+
+ metadata.init(CFileMetaData_Make(buffer.data(), &metadata_len))
+
+ return metadata
+
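+# _reconstruct_filemetadata above exists to support pickling: the __reduce__
+# defined on FileMetaData below serializes the Thrift metadata into a Buffer,
+# and this helper rebuilds the object from it. A rough round-trip sketch
+# ("example.parquet" is a hypothetical path):
+#
+#   import pickle
+#   import pyarrow.parquet as pq
+#   meta = pq.ParquetFile("example.parquet").metadata
+#   restored = pickle.loads(pickle.dumps(meta))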
+
+cdef class FileMetaData(_Weakrefable):
+ def __cinit__(self):
+ pass
+
+ def __reduce__(self):
+ cdef:
+ NativeFile sink = BufferOutputStream()
+ COutputStream* c_sink = sink.get_output_stream().get()
+ with nogil:
+ self._metadata.WriteTo(c_sink)
+
+ cdef Buffer buffer = sink.getvalue()
+ return _reconstruct_filemetadata, (buffer,)
+
+ def __repr__(self):
+ return """{0}
+ created_by: {1}
+ num_columns: {2}
+ num_rows: {3}
+ num_row_groups: {4}
+ format_version: {5}
+ serialized_size: {6}""".format(object.__repr__(self),
+ self.created_by, self.num_columns,
+ self.num_rows, self.num_row_groups,
+ self.format_version,
+ self.serialized_size)
+
+ def to_dict(self):
+ row_groups = []
+ d = dict(
+ created_by=self.created_by,
+ num_columns=self.num_columns,
+ num_rows=self.num_rows,
+ num_row_groups=self.num_row_groups,
+ row_groups=row_groups,
+ format_version=self.format_version,
+ serialized_size=self.serialized_size
+ )
+ for i in range(self.num_row_groups):
+ row_groups.append(self.row_group(i).to_dict())
+ return d
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def equals(self, FileMetaData other):
+ return self._metadata.Equals(deref(other._metadata))
+
+ @property
+ def schema(self):
+ if self._schema is None:
+ self._schema = ParquetSchema(self)
+ return self._schema
+
+ @property
+ def serialized_size(self):
+ return self._metadata.size()
+
+ @property
+ def num_columns(self):
+ return self._metadata.num_columns()
+
+ @property
+ def num_rows(self):
+ return self._metadata.num_rows()
+
+ @property
+ def num_row_groups(self):
+ return self._metadata.num_row_groups()
+
+ @property
+ def format_version(self):
+ cdef ParquetVersion version = self._metadata.version()
+ if version == ParquetVersion_V1:
+ return '1.0'
+ elif version == ParquetVersion_V2_0:
+ return 'pseudo-2.0'
+ elif version == ParquetVersion_V2_4:
+ return '2.4'
+ elif version == ParquetVersion_V2_6:
+ return '2.6'
+ else:
+ warnings.warn('Unrecognized file version, assuming 1.0: {}'
+ .format(version))
+ return '1.0'
+
+ @property
+ def created_by(self):
+ return frombytes(self._metadata.created_by())
+
+ @property
+ def metadata(self):
+ cdef:
+ unordered_map[c_string, c_string] metadata
+ const CKeyValueMetadata* underlying_metadata
+ underlying_metadata = self._metadata.key_value_metadata().get()
+ if underlying_metadata != NULL:
+ underlying_metadata.ToUnorderedMap(&metadata)
+ return metadata
+ else:
+ return None
+
+ def row_group(self, int i):
+ return RowGroupMetaData(self, i)
+
+ def set_file_path(self, path):
+ """
+ Set the file_path field of each ColumnChunk in the
+ FileMetaData to the given value.
+ """
+ cdef:
+ c_string c_path = tobytes(path)
+ self._metadata.set_file_path(c_path)
+
+ def append_row_groups(self, FileMetaData other):
+ """
+ Append the row groups of another FileMetaData object.
+ """
+ cdef shared_ptr[CFileMetaData] c_metadata
+
+ c_metadata = other.sp_metadata
+ self._metadata.AppendRowGroups(deref(c_metadata))
+
+ def write_metadata_file(self, where):
+ """
+ Write the metadata object to a metadata-only file
+ """
+ cdef:
+ shared_ptr[COutputStream] sink
+ c_string c_where
+
+ try:
+ where = _stringify_path(where)
+ except TypeError:
+ get_writer(where, &sink)
+ else:
+ c_where = tobytes(where)
+ with nogil:
+ sink = GetResultValue(FileOutputStream.Open(c_where))
+
+ with nogil:
+ check_status(
+ WriteMetaDataFile(deref(self._metadata), sink.get()))
+
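+# A rough sketch of the "_metadata" sidecar pattern the FileMetaData methods
+# above support, assuming a list ``metadata_pieces`` of per-file FileMetaData
+# objects and matching relative ``paths`` (both hypothetical names):
+#
+#   merged, *rest = metadata_pieces
+#   merged.set_file_path(paths[0])
+#   for piece, path in zip(rest, paths[1:]):
+#       piece.set_file_path(path)
+#       merged.append_row_groups(piece)
+#   merged.write_metadata_file("dataset/_metadata")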
+
+cdef class ParquetSchema(_Weakrefable):
+ def __cinit__(self, FileMetaData container):
+ self.parent = container
+ self.schema = container._metadata.schema()
+
+ def __repr__(self):
+ return "{0}\n{1}".format(
+ object.__repr__(self),
+ frombytes(self.schema.ToString(), safe=True))
+
+ def __reduce__(self):
+ return ParquetSchema, (self.parent,)
+
+ def __len__(self):
+ return self.schema.num_columns()
+
+ def __getitem__(self, i):
+ return self.column(i)
+
+ @property
+ def names(self):
+ return [self[i].name for i in range(len(self))]
+
+ def to_arrow_schema(self):
+ """
+ Convert Parquet schema to effective Arrow schema
+
+ Returns
+ -------
+ schema : pyarrow.Schema
+ """
+ cdef shared_ptr[CSchema] sp_arrow_schema
+
+ with nogil:
+ check_status(FromParquetSchema(
+ self.schema, default_arrow_reader_properties(),
+ self.parent._metadata.key_value_metadata(),
+ &sp_arrow_schema))
+
+ return pyarrow_wrap_schema(sp_arrow_schema)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def equals(self, ParquetSchema other):
+ """
+ Returns True if the Parquet schemas are equal
+ """
+ return self.schema.Equals(deref(other.schema))
+
+ def column(self, i):
+ if i < 0 or i >= len(self):
+ raise IndexError('{0} out of bounds'.format(i))
+
+ return ColumnSchema(self, i)
+
+
+cdef class ColumnSchema(_Weakrefable):
+ cdef:
+ int index
+ ParquetSchema parent
+ const ColumnDescriptor* descr
+
+ def __cinit__(self, ParquetSchema schema, int index):
+ self.parent = schema
+ self.index = index # for pickling support
+ self.descr = schema.schema.Column(index)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def __reduce__(self):
+ return ColumnSchema, (self.parent, self.index)
+
+ def equals(self, ColumnSchema other):
+ """
+ Returns True if the column schemas are equal
+ """
+ return self.descr.Equals(deref(other.descr))
+
+ def __repr__(self):
+ physical_type = self.physical_type
+ converted_type = self.converted_type
+ if converted_type == 'DECIMAL':
+ converted_type = 'DECIMAL({0}, {1})'.format(self.precision,
+ self.scale)
+ elif physical_type == 'FIXED_LEN_BYTE_ARRAY':
+ converted_type = ('FIXED_LEN_BYTE_ARRAY(length={0})'
+ .format(self.length))
+
+ return """<ParquetColumnSchema>
+ name: {0}
+ path: {1}
+ max_definition_level: {2}
+ max_repetition_level: {3}
+ physical_type: {4}
+ logical_type: {5}
+ converted_type (legacy): {6}""".format(self.name, self.path,
+ self.max_definition_level,
+ self.max_repetition_level,
+ physical_type,
+ str(self.logical_type),
+ converted_type)
+
+ @property
+ def name(self):
+ return frombytes(self.descr.name())
+
+ @property
+ def path(self):
+ return frombytes(self.descr.path().get().ToDotString())
+
+ @property
+ def max_definition_level(self):
+ return self.descr.max_definition_level()
+
+ @property
+ def max_repetition_level(self):
+ return self.descr.max_repetition_level()
+
+ @property
+ def physical_type(self):
+ return physical_type_name_from_enum(self.descr.physical_type())
+
+ @property
+ def logical_type(self):
+ return wrap_logical_type(self.descr.logical_type())
+
+ @property
+ def converted_type(self):
+ return converted_type_name_from_enum(self.descr.converted_type())
+
+ # FIXED_LEN_BYTE_ARRAY attribute
+ @property
+ def length(self):
+ return self.descr.type_length()
+
+ # Decimal attributes
+ @property
+ def precision(self):
+ return self.descr.type_precision()
+
+ @property
+ def scale(self):
+ return self.descr.type_scale()
+
+
+cdef physical_type_name_from_enum(ParquetType type_):
+ return {
+ ParquetType_BOOLEAN: 'BOOLEAN',
+ ParquetType_INT32: 'INT32',
+ ParquetType_INT64: 'INT64',
+ ParquetType_INT96: 'INT96',
+ ParquetType_FLOAT: 'FLOAT',
+ ParquetType_DOUBLE: 'DOUBLE',
+ ParquetType_BYTE_ARRAY: 'BYTE_ARRAY',
+ ParquetType_FIXED_LEN_BYTE_ARRAY: 'FIXED_LEN_BYTE_ARRAY',
+ }.get(type_, 'UNKNOWN')
+
+
+cdef logical_type_name_from_enum(ParquetLogicalTypeId type_):
+ return {
+ ParquetLogicalType_UNDEFINED: 'UNDEFINED',
+ ParquetLogicalType_STRING: 'STRING',
+ ParquetLogicalType_MAP: 'MAP',
+ ParquetLogicalType_LIST: 'LIST',
+ ParquetLogicalType_ENUM: 'ENUM',
+ ParquetLogicalType_DECIMAL: 'DECIMAL',
+ ParquetLogicalType_DATE: 'DATE',
+ ParquetLogicalType_TIME: 'TIME',
+ ParquetLogicalType_TIMESTAMP: 'TIMESTAMP',
+ ParquetLogicalType_INT: 'INT',
+ ParquetLogicalType_JSON: 'JSON',
+ ParquetLogicalType_BSON: 'BSON',
+ ParquetLogicalType_UUID: 'UUID',
+ ParquetLogicalType_NONE: 'NONE',
+ }.get(type_, 'UNKNOWN')
+
+
+cdef converted_type_name_from_enum(ParquetConvertedType type_):
+ return {
+ ParquetConvertedType_NONE: 'NONE',
+ ParquetConvertedType_UTF8: 'UTF8',
+ ParquetConvertedType_MAP: 'MAP',
+ ParquetConvertedType_MAP_KEY_VALUE: 'MAP_KEY_VALUE',
+ ParquetConvertedType_LIST: 'LIST',
+ ParquetConvertedType_ENUM: 'ENUM',
+ ParquetConvertedType_DECIMAL: 'DECIMAL',
+ ParquetConvertedType_DATE: 'DATE',
+ ParquetConvertedType_TIME_MILLIS: 'TIME_MILLIS',
+ ParquetConvertedType_TIME_MICROS: 'TIME_MICROS',
+ ParquetConvertedType_TIMESTAMP_MILLIS: 'TIMESTAMP_MILLIS',
+ ParquetConvertedType_TIMESTAMP_MICROS: 'TIMESTAMP_MICROS',
+ ParquetConvertedType_UINT_8: 'UINT_8',
+ ParquetConvertedType_UINT_16: 'UINT_16',
+ ParquetConvertedType_UINT_32: 'UINT_32',
+ ParquetConvertedType_UINT_64: 'UINT_64',
+ ParquetConvertedType_INT_8: 'INT_8',
+ ParquetConvertedType_INT_16: 'INT_16',
+ ParquetConvertedType_INT_32: 'INT_32',
+ ParquetConvertedType_INT_64: 'INT_64',
+ ParquetConvertedType_JSON: 'JSON',
+ ParquetConvertedType_BSON: 'BSON',
+ ParquetConvertedType_INTERVAL: 'INTERVAL',
+ }.get(type_, 'UNKNOWN')
+
+
+cdef encoding_name_from_enum(ParquetEncoding encoding_):
+ return {
+ ParquetEncoding_PLAIN: 'PLAIN',
+ ParquetEncoding_PLAIN_DICTIONARY: 'PLAIN_DICTIONARY',
+ ParquetEncoding_RLE: 'RLE',
+ ParquetEncoding_BIT_PACKED: 'BIT_PACKED',
+ ParquetEncoding_DELTA_BINARY_PACKED: 'DELTA_BINARY_PACKED',
+ ParquetEncoding_DELTA_LENGTH_BYTE_ARRAY: 'DELTA_LENGTH_BYTE_ARRAY',
+ ParquetEncoding_DELTA_BYTE_ARRAY: 'DELTA_BYTE_ARRAY',
+ ParquetEncoding_RLE_DICTIONARY: 'RLE_DICTIONARY',
+ ParquetEncoding_BYTE_STREAM_SPLIT: 'BYTE_STREAM_SPLIT',
+ }.get(encoding_, 'UNKNOWN')
+
+
+cdef compression_name_from_enum(ParquetCompression compression_):
+ return {
+ ParquetCompression_UNCOMPRESSED: 'UNCOMPRESSED',
+ ParquetCompression_SNAPPY: 'SNAPPY',
+ ParquetCompression_GZIP: 'GZIP',
+ ParquetCompression_LZO: 'LZO',
+ ParquetCompression_BROTLI: 'BROTLI',
+ ParquetCompression_LZ4: 'LZ4',
+ ParquetCompression_ZSTD: 'ZSTD',
+ }.get(compression_, 'UNKNOWN')
+
+
+cdef int check_compression_name(name) except -1:
+ if name.upper() not in {'NONE', 'SNAPPY', 'GZIP', 'LZO', 'BROTLI', 'LZ4',
+ 'ZSTD'}:
+ raise ArrowException("Unsupported compression: " + name)
+ return 0
+
+
+cdef ParquetCompression compression_from_name(name):
+ name = name.upper()
+ if name == 'SNAPPY':
+ return ParquetCompression_SNAPPY
+ elif name == 'GZIP':
+ return ParquetCompression_GZIP
+ elif name == 'LZO':
+ return ParquetCompression_LZO
+ elif name == 'BROTLI':
+ return ParquetCompression_BROTLI
+ elif name == 'LZ4':
+ return ParquetCompression_LZ4
+ elif name == 'ZSTD':
+ return ParquetCompression_ZSTD
+ else:
+ return ParquetCompression_UNCOMPRESSED
+
+
+cdef class ParquetReader(_Weakrefable):
+ cdef:
+ object source
+ CMemoryPool* pool
+ unique_ptr[FileReader] reader
+ FileMetaData _metadata
+
+ cdef public:
+ _column_idx_map
+
+ def __cinit__(self, MemoryPool memory_pool=None):
+ self.pool = maybe_unbox_memory_pool(memory_pool)
+ self._metadata = None
+
+ def open(self, object source not None, bint use_memory_map=True,
+ read_dictionary=None, FileMetaData metadata=None,
+ int buffer_size=0, bint pre_buffer=False,
+ coerce_int96_timestamp_unit=None):
+ cdef:
+ shared_ptr[CRandomAccessFile] rd_handle
+ shared_ptr[CFileMetaData] c_metadata
+ CReaderProperties properties = default_reader_properties()
+ ArrowReaderProperties arrow_props = (
+ default_arrow_reader_properties())
+ c_string path
+ FileReaderBuilder builder
+ TimeUnit int96_timestamp_unit_code
+
+ if metadata is not None:
+ c_metadata = metadata.sp_metadata
+
+ if buffer_size > 0:
+ properties.enable_buffered_stream()
+ properties.set_buffer_size(buffer_size)
+ elif buffer_size == 0:
+ properties.disable_buffered_stream()
+ else:
+ raise ValueError('Buffer size must be larger than zero')
+
+ arrow_props.set_pre_buffer(pre_buffer)
+
+ if coerce_int96_timestamp_unit is None:
+ # use the default defined in default_arrow_reader_properties()
+ pass
+ else:
+ arrow_props.set_coerce_int96_timestamp_unit(
+ string_to_timeunit(coerce_int96_timestamp_unit))
+
+ self.source = source
+
+ get_reader(source, use_memory_map, &rd_handle)
+ with nogil:
+ check_status(builder.Open(rd_handle, properties, c_metadata))
+
+ # Set up metadata
+ with nogil:
+ c_metadata = builder.raw_reader().metadata()
+ self._metadata = result = FileMetaData()
+ result.init(c_metadata)
+
+ if read_dictionary is not None:
+ self._set_read_dictionary(read_dictionary, &arrow_props)
+
+ with nogil:
+ check_status(builder.memory_pool(self.pool)
+ .properties(arrow_props)
+ .Build(&self.reader))
+
+ cdef _set_read_dictionary(self, read_dictionary,
+ ArrowReaderProperties* props):
+ for column in read_dictionary:
+ if not isinstance(column, int):
+ column = self.column_name_idx(column)
+ props.set_read_dictionary(column, True)
+
+ @property
+ def column_paths(self):
+ cdef:
+ FileMetaData container = self.metadata
+ const CFileMetaData* metadata = container._metadata
+ vector[c_string] path
+ int i = 0
+
+ paths = []
+ for i in range(0, metadata.num_columns()):
+ path = (metadata.schema().Column(i)
+ .path().get().ToDotVector())
+ paths.append([frombytes(x) for x in path])
+
+ return paths
+
+ @property
+ def metadata(self):
+ return self._metadata
+
+ @property
+ def schema_arrow(self):
+ cdef shared_ptr[CSchema] out
+ with nogil:
+ check_status(self.reader.get().GetSchema(&out))
+ return pyarrow_wrap_schema(out)
+
+ @property
+ def num_row_groups(self):
+ return self.reader.get().num_row_groups()
+
+ def set_use_threads(self, bint use_threads):
+ self.reader.get().set_use_threads(use_threads)
+
+ def set_batch_size(self, int64_t batch_size):
+ self.reader.get().set_batch_size(batch_size)
+
+ def iter_batches(self, int64_t batch_size, row_groups, column_indices=None,
+ bint use_threads=True):
+ cdef:
+ vector[int] c_row_groups
+ vector[int] c_column_indices
+ shared_ptr[CRecordBatch] record_batch
+ shared_ptr[TableBatchReader] batch_reader
+ unique_ptr[CRecordBatchReader] recordbatchreader
+
+ self.set_batch_size(batch_size)
+
+ if use_threads:
+ self.set_use_threads(use_threads)
+
+ for row_group in row_groups:
+ c_row_groups.push_back(row_group)
+
+ if column_indices is not None:
+ for index in column_indices:
+ c_column_indices.push_back(index)
+ with nogil:
+ check_status(
+ self.reader.get().GetRecordBatchReader(
+ c_row_groups, c_column_indices, &recordbatchreader
+ )
+ )
+ else:
+ with nogil:
+ check_status(
+ self.reader.get().GetRecordBatchReader(
+ c_row_groups, &recordbatchreader
+ )
+ )
+
+ while True:
+ with nogil:
+ check_status(
+ recordbatchreader.get().ReadNext(&record_batch)
+ )
+
+ if record_batch.get() == NULL:
+ break
+
+ yield pyarrow_wrap_batch(record_batch)
+
+ def read_row_group(self, int i, column_indices=None,
+ bint use_threads=True):
+ return self.read_row_groups([i], column_indices, use_threads)
+
+ def read_row_groups(self, row_groups not None, column_indices=None,
+ bint use_threads=True):
+ cdef:
+ shared_ptr[CTable] ctable
+ vector[int] c_row_groups
+ vector[int] c_column_indices
+
+ self.set_use_threads(use_threads)
+
+ for row_group in row_groups:
+ c_row_groups.push_back(row_group)
+
+ if column_indices is not None:
+ for index in column_indices:
+ c_column_indices.push_back(index)
+
+ with nogil:
+ check_status(self.reader.get()
+ .ReadRowGroups(c_row_groups, c_column_indices,
+ &ctable))
+ else:
+ # Read all columns
+ with nogil:
+ check_status(self.reader.get()
+ .ReadRowGroups(c_row_groups, &ctable))
+ return pyarrow_wrap_table(ctable)
+
+ def read_all(self, column_indices=None, bint use_threads=True):
+ cdef:
+ shared_ptr[CTable] ctable
+ vector[int] c_column_indices
+
+ self.set_use_threads(use_threads)
+
+ if column_indices is not None:
+ for index in column_indices:
+ c_column_indices.push_back(index)
+
+ with nogil:
+ check_status(self.reader.get()
+ .ReadTable(c_column_indices, &ctable))
+ else:
+ # Read all columns
+ with nogil:
+ check_status(self.reader.get()
+ .ReadTable(&ctable))
+ return pyarrow_wrap_table(ctable)
+
+ def scan_contents(self, column_indices=None, batch_size=65536):
+ cdef:
+ vector[int] c_column_indices
+ int32_t c_batch_size
+ int64_t c_num_rows
+
+ if column_indices is not None:
+ for index in column_indices:
+ c_column_indices.push_back(index)
+
+ c_batch_size = batch_size
+
+ with nogil:
+ check_status(self.reader.get()
+ .ScanContents(c_column_indices, c_batch_size,
+ &c_num_rows))
+
+ return c_num_rows
+
+ def column_name_idx(self, column_name):
+ """
+ Find the matching index of a column in the schema.
+
+ Parameters
+ ----------
+ column_name : str
+ Name of the column; nesting levels are separated by ".".
+
+ Returns
+ -------
+ column_idx : int
+ Index of the column within the schema.
+ """
+ cdef:
+ FileMetaData container = self.metadata
+ const CFileMetaData* metadata = container._metadata
+ int i = 0
+
+ if self._column_idx_map is None:
+ self._column_idx_map = {}
+ for i in range(0, metadata.num_columns()):
+ col_bytes = tobytes(metadata.schema().Column(i)
+ .path().get().ToDotString())
+ self._column_idx_map[col_bytes] = i
+
+ return self._column_idx_map[tobytes(column_name)]
+
+ def read_column(self, int column_index):
+ cdef shared_ptr[CChunkedArray] out
+ with nogil:
+ check_status(self.reader.get()
+ .ReadColumn(column_index, &out))
+ return pyarrow_wrap_chunked_array(out)
+
+ def read_schema_field(self, int field_index):
+ cdef shared_ptr[CChunkedArray] out
+ with nogil:
+ check_status(self.reader.get()
+ .ReadSchemaField(field_index, &out))
+ return pyarrow_wrap_chunked_array(out)
+
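+# ParquetReader above is the low-level binding; the public entry point is
+# pyarrow.parquet.ParquetFile, which wraps it roughly like this (a sketch,
+# with "example.parquet" as a hypothetical path):
+#
+#   import pyarrow.parquet as pq
+#   pf = pq.ParquetFile("example.parquet")          # ParquetReader.open()
+#   table = pf.read()                               # ParquetReader.read_all()
+#   for batch in pf.iter_batches(batch_size=1024):  # ParquetReader.iter_batches()
+#       ...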
+
+cdef shared_ptr[WriterProperties] _create_writer_properties(
+ use_dictionary=None,
+ compression=None,
+ version=None,
+ write_statistics=None,
+ data_page_size=None,
+ compression_level=None,
+ use_byte_stream_split=False,
+ data_page_version=None) except *:
+ """General writer properties"""
+ cdef:
+ shared_ptr[WriterProperties] properties
+ WriterProperties.Builder props
+
+ # data_page_version
+
+ if data_page_version is not None:
+ if data_page_version == "1.0":
+ props.data_page_version(ParquetDataPageVersion_V1)
+ elif data_page_version == "2.0":
+ props.data_page_version(ParquetDataPageVersion_V2)
+ else:
+ raise ValueError("Unsupported Parquet data page version: {0}"
+ .format(data_page_version))
+
+ # version
+
+ if version is not None:
+ if version == "1.0":
+ props.version(ParquetVersion_V1)
+ elif version in ("2.0", "pseudo-2.0"):
+ warnings.warn(
+ "Parquet format '2.0' pseudo version is deprecated, use "
+ "'2.4' or '2.6' for fine-grained feature selection",
+ FutureWarning, stacklevel=2)
+ props.version(ParquetVersion_V2_0)
+ elif version == "2.4":
+ props.version(ParquetVersion_V2_4)
+ elif version == "2.6":
+ props.version(ParquetVersion_V2_6)
+ else:
+ raise ValueError("Unsupported Parquet format version: {0}"
+ .format(version))
+
+ # compression
+
+ if isinstance(compression, basestring):
+ check_compression_name(compression)
+ props.compression(compression_from_name(compression))
+ elif compression is not None:
+ for column, codec in compression.items():
+ check_compression_name(codec)
+ props.compression(tobytes(column), compression_from_name(codec))
+
+ if isinstance(compression_level, int):
+ props.compression_level(compression_level)
+ elif compression_level is not None:
+ for column, level in compression_level.items():
+ props.compression_level(tobytes(column), level)
+
+ # use_dictionary
+
+ if isinstance(use_dictionary, bool):
+ if use_dictionary:
+ props.enable_dictionary()
+ else:
+ props.disable_dictionary()
+ elif use_dictionary is not None:
+ # Deactivate dictionary encoding by default
+ props.disable_dictionary()
+ for column in use_dictionary:
+ props.enable_dictionary(tobytes(column))
+
+ # write_statistics
+
+ if isinstance(write_statistics, bool):
+ if write_statistics:
+ props.enable_statistics()
+ else:
+ props.disable_statistics()
+ elif write_statistics is not None:
+ # Deactivate statistics by default and enable for specified columns
+ props.disable_statistics()
+ for column in write_statistics:
+ props.enable_statistics(tobytes(column))
+
+ # use_byte_stream_split
+
+ if isinstance(use_byte_stream_split, bool):
+ if use_byte_stream_split:
+ props.encoding(ParquetEncoding_BYTE_STREAM_SPLIT)
+ elif use_byte_stream_split is not None:
+ for column in use_byte_stream_split:
+ props.encoding(tobytes(column),
+ ParquetEncoding_BYTE_STREAM_SPLIT)
+
+ if data_page_size is not None:
+ props.data_pagesize(data_page_size)
+
+ properties = props.build()
+
+ return properties
+
+
+cdef shared_ptr[ArrowWriterProperties] _create_arrow_writer_properties(
+ use_deprecated_int96_timestamps=False,
+ coerce_timestamps=None,
+ allow_truncated_timestamps=False,
+ writer_engine_version=None,
+ use_compliant_nested_type=False) except *:
+ """Arrow writer properties"""
+ cdef:
+ shared_ptr[ArrowWriterProperties] arrow_properties
+ ArrowWriterProperties.Builder arrow_props
+
+ # Store the original Arrow schema so things like dictionary types can
+ # be automatically reconstructed
+ arrow_props.store_schema()
+
+ # int96 support
+
+ if use_deprecated_int96_timestamps:
+ arrow_props.enable_deprecated_int96_timestamps()
+ else:
+ arrow_props.disable_deprecated_int96_timestamps()
+
+ # coerce_timestamps
+
+ if coerce_timestamps == 'ms':
+ arrow_props.coerce_timestamps(TimeUnit_MILLI)
+ elif coerce_timestamps == 'us':
+ arrow_props.coerce_timestamps(TimeUnit_MICRO)
+ elif coerce_timestamps is not None:
+ raise ValueError('Invalid value for coerce_timestamps: {0}'
+ .format(coerce_timestamps))
+
+ # allow_truncated_timestamps
+
+ if allow_truncated_timestamps:
+ arrow_props.allow_truncated_timestamps()
+ else:
+ arrow_props.disallow_truncated_timestamps()
+
+ # use_compliant_nested_type
+
+ if use_compliant_nested_type:
+ arrow_props.enable_compliant_nested_types()
+ else:
+ arrow_props.disable_compliant_nested_types()
+
+ # writer_engine_version
+
+ if writer_engine_version == "V1":
+ warnings.warn("V1 parquet writer engine is a no-op. Use V2.")
+ arrow_props.set_engine_version(ArrowWriterEngineVersion.V1)
+ elif writer_engine_version != "V2":
+ raise ValueError("Unsupported Writer Engine Version: {0}"
+ .format(writer_engine_version))
+
+ arrow_properties = arrow_props.build()
+
+ return arrow_properties
+
+
+cdef class ParquetWriter(_Weakrefable):
+ cdef:
+ unique_ptr[FileWriter] writer
+ shared_ptr[COutputStream] sink
+ bint own_sink
+
+ cdef readonly:
+ object use_dictionary
+ object use_deprecated_int96_timestamps
+ object use_byte_stream_split
+ object coerce_timestamps
+ object allow_truncated_timestamps
+ object compression
+ object compression_level
+ object data_page_version
+ object use_compliant_nested_type
+ object version
+ object write_statistics
+ object writer_engine_version
+ int row_group_size
+ int64_t data_page_size
+
+ def __cinit__(self, where, Schema schema, use_dictionary=None,
+ compression=None, version=None,
+ write_statistics=None,
+ MemoryPool memory_pool=None,
+ use_deprecated_int96_timestamps=False,
+ coerce_timestamps=None,
+ data_page_size=None,
+ allow_truncated_timestamps=False,
+ compression_level=None,
+ use_byte_stream_split=False,
+ writer_engine_version=None,
+ data_page_version=None,
+ use_compliant_nested_type=False):
+ cdef:
+ shared_ptr[WriterProperties] properties
+ shared_ptr[ArrowWriterProperties] arrow_properties
+ c_string c_where
+ CMemoryPool* pool
+
+ try:
+ where = _stringify_path(where)
+ except TypeError:
+ get_writer(where, &self.sink)
+ self.own_sink = False
+ else:
+ c_where = tobytes(where)
+ with nogil:
+ self.sink = GetResultValue(FileOutputStream.Open(c_where))
+ self.own_sink = True
+
+ properties = _create_writer_properties(
+ use_dictionary=use_dictionary,
+ compression=compression,
+ version=version,
+ write_statistics=write_statistics,
+ data_page_size=data_page_size,
+ compression_level=compression_level,
+ use_byte_stream_split=use_byte_stream_split,
+ data_page_version=data_page_version
+ )
+ arrow_properties = _create_arrow_writer_properties(
+ use_deprecated_int96_timestamps=use_deprecated_int96_timestamps,
+ coerce_timestamps=coerce_timestamps,
+ allow_truncated_timestamps=allow_truncated_timestamps,
+ writer_engine_version=writer_engine_version,
+ use_compliant_nested_type=use_compliant_nested_type
+ )
+
+ pool = maybe_unbox_memory_pool(memory_pool)
+ with nogil:
+ check_status(
+ FileWriter.Open(deref(schema.schema), pool,
+ self.sink, properties, arrow_properties,
+ &self.writer))
+
+ def close(self):
+ with nogil:
+ check_status(self.writer.get().Close())
+ if self.own_sink:
+ check_status(self.sink.get().Close())
+
+ def write_table(self, Table table, row_group_size=None):
+ cdef:
+ CTable* ctable = table.table
+ int64_t c_row_group_size
+
+ if row_group_size is None or row_group_size == -1:
+ c_row_group_size = ctable.num_rows()
+ elif row_group_size == 0:
+ raise ValueError('Row group size cannot be 0')
+ else:
+ c_row_group_size = row_group_size
+
+ with nogil:
+ check_status(self.writer.get()
+ .WriteTable(deref(ctable), c_row_group_size))
+
+ @property
+ def metadata(self):
+ cdef:
+ shared_ptr[CFileMetaData] metadata
+ FileMetaData result
+ with nogil:
+ metadata = self.writer.get().metadata()
+ if metadata:
+ result = FileMetaData()
+ result.init(metadata)
+ return result
+ raise RuntimeError(
+ 'file metadata is only available after writer close')
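+
+# A minimal sketch of driving the ParquetWriter above directly; in practice
+# the public pyarrow.parquet.ParquetWriter / write_table wrap this class.
+# The output path "example.parquet" is hypothetical:
+#
+#   import pyarrow as pa
+#   from pyarrow._parquet import ParquetWriter
+#
+#   table = pa.table({"x": [1, 2, 3]})
+#   writer = ParquetWriter("example.parquet", table.schema,
+#                          compression="snappy")
+#   writer.write_table(table, row_group_size=2)   # two row groups: 2 + 1 rows
+#   writer.close()
+#   print(writer.metadata.num_row_groups)         # -> 2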
diff --git a/src/arrow/python/pyarrow/_plasma.pyx b/src/arrow/python/pyarrow/_plasma.pyx
new file mode 100644
index 000000000..e38c81f80
--- /dev/null
+++ b/src/arrow/python/pyarrow/_plasma.pyx
@@ -0,0 +1,867 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: profile=False
+# distutils: language = c++
+# cython: language_level = 3
+
+from libcpp cimport bool as c_bool, nullptr
+from libcpp.memory cimport shared_ptr, unique_ptr, make_shared
+from libcpp.string cimport string as c_string
+from libcpp.vector cimport vector as c_vector
+from libcpp.unordered_map cimport unordered_map
+from libc.stdint cimport int64_t, uint8_t, uintptr_t
+from cython.operator cimport dereference as deref, preincrement as inc
+from cpython.pycapsule cimport *
+
+from collections.abc import Sequence
+import random
+import socket
+import warnings
+
+import pyarrow
+from pyarrow.lib cimport (Buffer, NativeFile, _Weakrefable,
+ check_status, pyarrow_wrap_buffer)
+from pyarrow.lib import ArrowException, frombytes
+from pyarrow.includes.libarrow cimport (CBuffer, CMutableBuffer,
+ CFixedSizeBufferWriter, CStatus)
+from pyarrow.includes.libplasma cimport *
+
+PLASMA_WAIT_TIMEOUT = 2 ** 30
+
+
+cdef extern from "plasma/common.h" nogil:
+ cdef cppclass CCudaIpcPlaceholder" plasma::internal::CudaIpcPlaceholder":
+ pass
+
+ cdef cppclass CUniqueID" plasma::UniqueID":
+
+ @staticmethod
+ CUniqueID from_binary(const c_string& binary)
+
+ @staticmethod
+ CUniqueID from_random()
+
+ c_bool operator==(const CUniqueID& rhs) const
+
+ c_string hex() const
+
+ c_string binary() const
+
+ @staticmethod
+ int64_t size()
+
+ cdef enum CObjectState" plasma::ObjectState":
+ PLASMA_CREATED" plasma::ObjectState::PLASMA_CREATED"
+ PLASMA_SEALED" plasma::ObjectState::PLASMA_SEALED"
+
+ cdef struct CObjectTableEntry" plasma::ObjectTableEntry":
+ int fd
+ int device_num
+ int64_t map_size
+ ptrdiff_t offset
+ uint8_t* pointer
+ int64_t data_size
+ int64_t metadata_size
+ int ref_count
+ int64_t create_time
+ int64_t construct_duration
+ CObjectState state
+ shared_ptr[CCudaIpcPlaceholder] ipc_handle
+
+ ctypedef unordered_map[CUniqueID, unique_ptr[CObjectTableEntry]] \
+ CObjectTable" plasma::ObjectTable"
+
+
+cdef extern from "plasma/common.h":
+ cdef int64_t kDigestSize" plasma::kDigestSize"
+
+cdef extern from "plasma/client.h" nogil:
+
+ cdef cppclass CPlasmaClient" plasma::PlasmaClient":
+
+ CPlasmaClient()
+
+ CStatus Connect(const c_string& store_socket_name,
+ const c_string& manager_socket_name,
+ int release_delay, int num_retries)
+
+ CStatus Create(const CUniqueID& object_id,
+ int64_t data_size, const uint8_t* metadata, int64_t
+ metadata_size, const shared_ptr[CBuffer]* data)
+
+ CStatus CreateAndSeal(const CUniqueID& object_id,
+ const c_string& data, const c_string& metadata)
+
+ CStatus Get(const c_vector[CUniqueID] object_ids, int64_t timeout_ms,
+ c_vector[CObjectBuffer]* object_buffers)
+
+ CStatus Seal(const CUniqueID& object_id)
+
+ CStatus Evict(int64_t num_bytes, int64_t& num_bytes_evicted)
+
+ CStatus Hash(const CUniqueID& object_id, uint8_t* digest)
+
+ CStatus Release(const CUniqueID& object_id)
+
+ CStatus Contains(const CUniqueID& object_id, c_bool* has_object)
+
+ CStatus List(CObjectTable* objects)
+
+ CStatus Subscribe(int* fd)
+
+ CStatus DecodeNotifications(const uint8_t* buffer,
+ c_vector[CUniqueID]* object_ids,
+ c_vector[int64_t]* data_sizes,
+ c_vector[int64_t]* metadata_sizes)
+
+ CStatus GetNotification(int fd, CUniqueID* object_id,
+ int64_t* data_size, int64_t* metadata_size)
+
+ CStatus Disconnect()
+
+ CStatus Delete(const c_vector[CUniqueID] object_ids)
+
+ CStatus SetClientOptions(const c_string& client_name,
+ int64_t limit_output_memory)
+
+ c_string DebugString()
+
+ int64_t store_capacity()
+
+cdef extern from "plasma/client.h" nogil:
+
+ cdef struct CObjectBuffer" plasma::ObjectBuffer":
+ shared_ptr[CBuffer] data
+ shared_ptr[CBuffer] metadata
+
+
+def make_object_id(object_id):
+ return ObjectID(object_id)
+
+
+cdef class ObjectID(_Weakrefable):
+ """
+ An ObjectID represents a string of bytes used to identify Plasma objects.
+ """
+
+ cdef:
+ CUniqueID data
+
+ def __cinit__(self, object_id):
+ if (not isinstance(object_id, bytes) or
+ len(object_id) != CUniqueID.size()):
+            raise ValueError("Object ID must be 20 bytes,"
+ " is " + str(object_id))
+ self.data = CUniqueID.from_binary(object_id)
+
+ def __eq__(self, other):
+ try:
+ return self.data == (<ObjectID?>other).data
+ except TypeError:
+ return False
+
+ def __hash__(self):
+ return hash(self.data.binary())
+
+ def __repr__(self):
+ return "ObjectID(" + self.data.hex().decode() + ")"
+
+ def __reduce__(self):
+ return (make_object_id, (self.data.binary(),))
+
+ def binary(self):
+ """
+ Return the binary representation of this ObjectID.
+
+ Returns
+ -------
+ bytes
+ Binary representation of the ObjectID.
+ """
+ return self.data.binary()
+
+ @staticmethod
+ def from_random():
+ """
+ Returns a randomly generated ObjectID.
+
+ Returns
+ -------
+ ObjectID
+ A randomly generated ObjectID.
+ """
+ random_id = bytes(bytearray(
+ random.getrandbits(8) for _ in range(CUniqueID.size())))
+ return ObjectID(random_id)
+
+
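+# Illustrative sketch only: the two supported ways of constructing an
+# ObjectID, either from 20 explicit bytes or from random bytes.
+def _example_object_ids():
+    explicit = ObjectID(b"\x00" * 20)
+    generated = ObjectID.from_random()
+    return explicit.binary(), generated.binary()
+
+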
+cdef class ObjectNotAvailable(_Weakrefable):
+ """
+ Placeholder for an object that was not available within the given timeout.
+ """
+ pass
+
+
+cdef class PlasmaBuffer(Buffer):
+ """
+ This is the type returned by calls to get with a PlasmaClient.
+
+ We define our own class instead of directly returning a buffer object so
+ that we can add a custom destructor which notifies Plasma that the object
+ is no longer being used, so the memory in the Plasma store backing the
+ object can potentially be freed.
+
+ Attributes
+ ----------
+ object_id : ObjectID
+ The ID of the object in the buffer.
+ client : PlasmaClient
+ The PlasmaClient that we use to communicate with the store and manager.
+ """
+
+ cdef:
+ ObjectID object_id
+ PlasmaClient client
+
+ @staticmethod
+ cdef PlasmaBuffer create(ObjectID object_id, PlasmaClient client,
+ const shared_ptr[CBuffer]& buffer):
+ cdef PlasmaBuffer self = PlasmaBuffer.__new__(PlasmaBuffer)
+ self.object_id = object_id
+ self.client = client
+ self.init(buffer)
+ return self
+
+ def __init__(self):
+ raise TypeError("Do not call PlasmaBuffer's constructor directly, use "
+ "`PlasmaClient.create` instead.")
+
+ def __dealloc__(self):
+ """
+ Notify Plasma that the object is no longer needed.
+
+ If the plasma client has been shut down, then don't do anything.
+ """
+ self.client._release(self.object_id)
+
+
+class PlasmaObjectNotFound(ArrowException):
+ pass
+
+
+class PlasmaStoreFull(ArrowException):
+ pass
+
+
+class PlasmaObjectExists(ArrowException):
+ pass
+
+
+cdef int plasma_check_status(const CStatus& status) nogil except -1:
+ if status.ok():
+ return 0
+
+ with gil:
+ message = frombytes(status.message())
+ if IsPlasmaObjectExists(status):
+ raise PlasmaObjectExists(message)
+ elif IsPlasmaObjectNotFound(status):
+ raise PlasmaObjectNotFound(message)
+ elif IsPlasmaStoreFull(status):
+ raise PlasmaStoreFull(message)
+
+ return check_status(status)
+
+
+def get_socket_from_fd(fileno, family, type):
+ import socket
+ return socket.socket(fileno=fileno, family=family, type=type)
+
+
+cdef class PlasmaClient(_Weakrefable):
+ """
+ The PlasmaClient is used to interface with a plasma store and manager.
+
+ The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a
+ buffer, and get a buffer. Buffers are referred to by object IDs, which are
+ strings.
+ """
+
+ cdef:
+ shared_ptr[CPlasmaClient] client
+ int notification_fd
+ c_string store_socket_name
+
+ def __cinit__(self):
+ self.client.reset(new CPlasmaClient())
+ self.notification_fd = -1
+ self.store_socket_name = b""
+
+ cdef _get_object_buffers(self, object_ids, int64_t timeout_ms,
+ c_vector[CObjectBuffer]* result):
+ cdef:
+ c_vector[CUniqueID] ids
+ ObjectID object_id
+
+ for object_id in object_ids:
+ ids.push_back(object_id.data)
+ with nogil:
+ plasma_check_status(self.client.get().Get(ids, timeout_ms, result))
+
+ # XXX C++ API should instead expose some kind of CreateAuto()
+ cdef _make_mutable_plasma_buffer(self, ObjectID object_id, uint8_t* data,
+ int64_t size):
+ cdef shared_ptr[CBuffer] buffer
+ buffer.reset(new CMutableBuffer(data, size))
+ return PlasmaBuffer.create(object_id, self, buffer)
+
+ @property
+ def store_socket_name(self):
+ return self.store_socket_name.decode()
+
+ def create(self, ObjectID object_id, int64_t data_size,
+ c_string metadata=b""):
+ """
+ Create a new buffer in the PlasmaStore for a particular object ID.
+
+ The returned buffer is mutable until seal is called.
+
+ Parameters
+ ----------
+ object_id : ObjectID
+ The object ID used to identify an object.
+ size : int
+ The size in bytes of the created buffer.
+ metadata : bytes
+ An optional string of bytes encoding whatever metadata the user
+ wishes to encode.
+
+ Raises
+ ------
+ PlasmaObjectExists
+ This exception is raised if the object could not be created because
+ there already is an object with the same ID in the plasma store.
+
+        PlasmaStoreFull
+            This exception is raised if the object could not be created
+            because the plasma store is unable to evict enough objects to
+            create room for it.
+ """
+ cdef shared_ptr[CBuffer] data
+ with nogil:
+ plasma_check_status(
+ self.client.get().Create(object_id.data, data_size,
+ <uint8_t*>(metadata.data()),
+ metadata.size(), &data))
+ return self._make_mutable_plasma_buffer(object_id,
+ data.get().mutable_data(),
+ data_size)
+
+ def create_and_seal(self, ObjectID object_id, c_string data,
+ c_string metadata=b""):
+ """
+ Store a new object in the PlasmaStore for a particular object ID.
+
+ Parameters
+ ----------
+ object_id : ObjectID
+ The object ID used to identify an object.
+ data : bytes
+ The object to store.
+ metadata : bytes
+ An optional string of bytes encoding whatever metadata the user
+ wishes to encode.
+
+ Raises
+ ------
+ PlasmaObjectExists
+ This exception is raised if the object could not be created because
+ there already is an object with the same ID in the plasma store.
+
+        PlasmaStoreFull
+            This exception is raised if the object could not be created
+            because the plasma store is unable to evict enough objects to
+            create room for it.
+ """
+ with nogil:
+ plasma_check_status(
+ self.client.get().CreateAndSeal(object_id.data, data,
+ metadata))
+
+ def get_buffers(self, object_ids, timeout_ms=-1, with_meta=False):
+ """
+        Returns data buffers from the PlasmaStore based on object IDs.
+
+ If the object has not been sealed yet, this call will block. The
+ retrieved buffer is immutable.
+
+ Parameters
+ ----------
+ object_ids : list
+ A list of ObjectIDs used to identify some objects.
+ timeout_ms : int
+ The number of milliseconds that the get call should block before
+ timing out and returning. Pass -1 if the call should block and 0
+ if the call should return immediately.
+        with_meta : bool, default False
+            If True, return a (metadata, data) tuple for each object instead
+            of only the data buffer.
+
+ Returns
+ -------
+ list
+ If with_meta=False, this is a list of PlasmaBuffers for the data
+ associated with the object_ids and None if the object was not
+ available. If with_meta=True, this is a list of tuples of
+ PlasmaBuffer and metadata bytes.
+ """
+ cdef c_vector[CObjectBuffer] object_buffers
+ self._get_object_buffers(object_ids, timeout_ms, &object_buffers)
+ result = []
+ for i in range(object_buffers.size()):
+ if object_buffers[i].data.get() != nullptr:
+ data = pyarrow_wrap_buffer(object_buffers[i].data)
+ else:
+ data = None
+ if not with_meta:
+ result.append(data)
+ else:
+ if object_buffers[i].metadata.get() != nullptr:
+ size = object_buffers[i].metadata.get().size()
+ metadata = object_buffers[i].metadata.get().data()[:size]
+ else:
+ metadata = None
+ result.append((metadata, data))
+ return result
+
+ def get_metadata(self, object_ids, timeout_ms=-1):
+ """
+        Returns metadata buffers from the PlasmaStore based on object IDs.
+
+ If the object has not been sealed yet, this call will block. The
+ retrieved buffer is immutable.
+
+ Parameters
+ ----------
+ object_ids : list
+ A list of ObjectIDs used to identify some objects.
+ timeout_ms : int
+ The number of milliseconds that the get call should block before
+ timing out and returning. Pass -1 if the call should block and 0
+ if the call should return immediately.
+
+ Returns
+ -------
+ list
+ List of PlasmaBuffers for the metadata associated with the
+ object_ids and None if the object was not available.
+ """
+ cdef c_vector[CObjectBuffer] object_buffers
+ self._get_object_buffers(object_ids, timeout_ms, &object_buffers)
+ result = []
+ for i in range(object_buffers.size()):
+ if object_buffers[i].metadata.get() != nullptr:
+ result.append(pyarrow_wrap_buffer(object_buffers[i].metadata))
+ else:
+ result.append(None)
+ return result
+
+ def put_raw_buffer(self, object value, ObjectID object_id=None,
+ c_string metadata=b"", int memcopy_threads=6):
+ """
+ Store Python buffer into the object store.
+
+ Parameters
+ ----------
+ value : Python object that implements the buffer protocol
+ A Python buffer object to store.
+ object_id : ObjectID, default None
+ If this is provided, the specified object ID will be used to refer
+ to the object.
+ metadata : bytes
+ An optional string of bytes encoding whatever metadata the user
+ wishes to encode.
+ memcopy_threads : int, default 6
+ The number of threads to use to write the serialized object into
+ the object store for large objects.
+
+ Returns
+ -------
+        The object ID associated with the Python buffer object.
+ """
+ cdef ObjectID target_id = (object_id if object_id
+ else ObjectID.from_random())
+ cdef Buffer arrow_buffer = pyarrow.py_buffer(value)
+ write_buffer = self.create(target_id, len(value), metadata)
+ stream = pyarrow.FixedSizeBufferWriter(write_buffer)
+ stream.set_memcopy_threads(memcopy_threads)
+ stream.write(arrow_buffer)
+ self.seal(target_id)
+ return target_id
+
+ def put(self, object value, ObjectID object_id=None, int memcopy_threads=6,
+ serialization_context=None):
+ """
+ Store a Python value into the object store.
+
+ Parameters
+ ----------
+ value : object
+ A Python object to store.
+ object_id : ObjectID, default None
+ If this is provided, the specified object ID will be used to refer
+ to the object.
+ memcopy_threads : int, default 6
+ The number of threads to use to write the serialized object into
+ the object store for large objects.
+ serialization_context : pyarrow.SerializationContext, default None
+ Custom serialization and deserialization context.
+
+ Returns
+ -------
+        The object ID associated with the Python object.
+ """
+ cdef ObjectID target_id = (object_id if object_id
+ else ObjectID.from_random())
+ if serialization_context is not None:
+ warnings.warn(
+ "'serialization_context' is deprecated and will be removed "
+ "in a future version.",
+ FutureWarning, stacklevel=2
+ )
+ serialized = pyarrow.lib._serialize(value, serialization_context)
+ buffer = self.create(target_id, serialized.total_bytes)
+ stream = pyarrow.FixedSizeBufferWriter(buffer)
+ stream.set_memcopy_threads(memcopy_threads)
+ serialized.write_to(stream)
+ self.seal(target_id)
+ return target_id
+
+ def get(self, object_ids, int timeout_ms=-1, serialization_context=None):
+ """
+ Get one or more Python values from the object store.
+
+ Parameters
+ ----------
+ object_ids : list or ObjectID
+ Object ID or list of object IDs associated to the values we get
+ from the store.
+ timeout_ms : int, default -1
+ The number of milliseconds that the get call should block before
+ timing out and returning. Pass -1 if the call should block and 0
+ if the call should return immediately.
+ serialization_context : pyarrow.SerializationContext, default None
+ Custom serialization and deserialization context.
+
+ Returns
+ -------
+ list or object
+ Python value or list of Python values for the data associated with
+ the object_ids and ObjectNotAvailable if the object was not
+ available.
+ """
+ if serialization_context is not None:
+ warnings.warn(
+ "'serialization_context' is deprecated and will be removed "
+ "in a future version.",
+ FutureWarning, stacklevel=2
+ )
+ if isinstance(object_ids, Sequence):
+ results = []
+ buffers = self.get_buffers(object_ids, timeout_ms)
+ for i in range(len(object_ids)):
+ # buffers[i] is None if this object was not available within
+ # the timeout
+ if buffers[i]:
+ val = pyarrow.lib._deserialize(buffers[i],
+ serialization_context)
+ results.append(val)
+ else:
+ results.append(ObjectNotAvailable)
+ return results
+ else:
+ return self.get([object_ids], timeout_ms, serialization_context)[0]
+
+ def seal(self, ObjectID object_id):
+ """
+ Seal the buffer in the PlasmaStore for a particular object ID.
+
+ Once a buffer has been sealed, the buffer is immutable and can only be
+ accessed through get.
+
+ Parameters
+ ----------
+ object_id : ObjectID
+ A string used to identify an object.
+ """
+ with nogil:
+ plasma_check_status(self.client.get().Seal(object_id.data))
+
+ def _release(self, ObjectID object_id):
+ """
+ Notify Plasma that the object is no longer needed.
+
+ Parameters
+ ----------
+ object_id : ObjectID
+ A string used to identify an object.
+ """
+ with nogil:
+ plasma_check_status(self.client.get().Release(object_id.data))
+
+ def contains(self, ObjectID object_id):
+ """
+ Check if the object is present and sealed in the PlasmaStore.
+
+ Parameters
+ ----------
+ object_id : ObjectID
+ A string used to identify an object.
+ """
+ cdef c_bool is_contained
+ with nogil:
+ plasma_check_status(self.client.get().Contains(object_id.data,
+ &is_contained))
+ return is_contained
+
+ def hash(self, ObjectID object_id):
+ """
+ Compute the checksum of an object in the object store.
+
+ Parameters
+ ----------
+ object_id : ObjectID
+ A string used to identify an object.
+
+ Returns
+ -------
+ bytes
+            A digest string of the object's hash. If the object isn't in the
+            object store, the string will have length zero.
+ """
+ cdef c_vector[uint8_t] digest = c_vector[uint8_t](kDigestSize)
+ with nogil:
+ plasma_check_status(self.client.get().Hash(object_id.data,
+ digest.data()))
+ return bytes(digest[:])
+
+ def evict(self, int64_t num_bytes):
+ """
+        Evict some objects to recover some bytes.
+
+ Recover at least num_bytes bytes if possible.
+
+ Parameters
+ ----------
+ num_bytes : int
+            The number of bytes to attempt to recover.
+
+        Returns
+        -------
+        int
+            The number of bytes actually evicted.
+        """
+ cdef int64_t num_bytes_evicted = -1
+ with nogil:
+ plasma_check_status(
+ self.client.get().Evict(num_bytes, num_bytes_evicted))
+ return num_bytes_evicted
+
+ def subscribe(self):
+ """Subscribe to notifications about sealed objects."""
+ with nogil:
+ plasma_check_status(
+ self.client.get().Subscribe(&self.notification_fd))
+
+ def get_notification_socket(self):
+ """
+ Get the notification socket.
+ """
+ return get_socket_from_fd(self.notification_fd,
+ family=socket.AF_UNIX,
+ type=socket.SOCK_STREAM)
+
+ def decode_notifications(self, const uint8_t* buf):
+ """
+        Decode the notifications contained in the given buffer.
+
+        Returns
+        -------
+        list of ObjectID
+            The object IDs in the notification message.
+        list of int
+            The data sizes of the objects in the notification message.
+        list of int
+            The metadata sizes of the objects in the notification message.
+ """
+ cdef c_vector[CUniqueID] ids
+ cdef c_vector[int64_t] data_sizes
+ cdef c_vector[int64_t] metadata_sizes
+ with nogil:
+ status = self.client.get().DecodeNotifications(buf,
+ &ids,
+ &data_sizes,
+ &metadata_sizes)
+ plasma_check_status(status)
+ object_ids = []
+ for object_id in ids:
+ object_ids.append(ObjectID(object_id.binary()))
+ return object_ids, data_sizes, metadata_sizes
+
+ def get_next_notification(self):
+ """
+ Get the next notification from the notification socket.
+
+ Returns
+ -------
+ ObjectID
+ The object ID of the object that was stored.
+ int
+ The data size of the object that was stored.
+ int
+ The metadata size of the object that was stored.
+ """
+ cdef ObjectID object_id = ObjectID(CUniqueID.size() * b"\0")
+ cdef int64_t data_size
+ cdef int64_t metadata_size
+ with nogil:
+ status = self.client.get().GetNotification(self.notification_fd,
+ &object_id.data,
+ &data_size,
+ &metadata_size)
+ plasma_check_status(status)
+ return object_id, data_size, metadata_size
+
+ def to_capsule(self):
+ return PyCapsule_New(<void *>self.client.get(), "plasma", NULL)
+
+ def disconnect(self):
+ """
+ Disconnect this client from the Plasma store.
+ """
+ with nogil:
+ plasma_check_status(self.client.get().Disconnect())
+
+ def delete(self, object_ids):
+ """
+        Delete the objects with the given IDs from the object store.
+
+ Parameters
+ ----------
+ object_ids : list
+ A list of strings used to identify the objects.
+ """
+ cdef c_vector[CUniqueID] ids
+ cdef ObjectID object_id
+ for object_id in object_ids:
+ ids.push_back(object_id.data)
+ with nogil:
+ plasma_check_status(self.client.get().Delete(ids))
+
+ def set_client_options(self, client_name, int64_t limit_output_memory):
+ cdef c_string name
+ name = client_name.encode()
+ with nogil:
+ plasma_check_status(
+ self.client.get().SetClientOptions(name, limit_output_memory))
+
+ def debug_string(self):
+ cdef c_string result
+ with nogil:
+ result = self.client.get().DebugString()
+ return result.decode()
+
+ def list(self):
+ """
+ Experimental: List the objects in the store.
+
+ Returns
+ -------
+ dict
+ Dictionary from ObjectIDs to an "info" dictionary describing the
+ object. The "info" dictionary has the following entries:
+
+ data_size
+ size of the object in bytes
+
+ metadata_size
+ size of the object metadata in bytes
+
+ ref_count
+ Number of clients referencing the object buffer
+
+ create_time
+ Unix timestamp of the creation of the object
+
+ construct_duration
+ Time the creation of the object took in seconds
+
+ state
+ "created" if the object is still being created and
+ "sealed" if it is already sealed
+ """
+ cdef CObjectTable objects
+ with nogil:
+ plasma_check_status(self.client.get().List(&objects))
+ result = dict()
+ cdef ObjectID object_id
+ cdef CObjectTableEntry entry
+ it = objects.begin()
+ while it != objects.end():
+ object_id = ObjectID(deref(it).first.binary())
+ entry = deref(deref(it).second)
+ if entry.state == CObjectState.PLASMA_CREATED:
+ state = "created"
+ else:
+ state = "sealed"
+ result[object_id] = {
+ "data_size": entry.data_size,
+ "metadata_size": entry.metadata_size,
+ "ref_count": entry.ref_count,
+ "create_time": entry.create_time,
+ "construct_duration": entry.construct_duration,
+ "state": state
+ }
+ inc(it)
+ return result
+
+ def store_capacity(self):
+ """
+ Get the memory capacity of the store.
+
+ Returns
+ -------
+ int
+ The memory capacity of the store in bytes.
+ """
+ return self.client.get().store_capacity()
+
+
+def connect(store_socket_name, int num_retries=-1):
+ """
+    Return a new PlasmaClient that is connected to a plasma store and
+ optionally a manager.
+
+ Parameters
+ ----------
+ store_socket_name : str
+ Name of the socket the plasma store is listening at.
+ num_retries : int, default -1
+        Number of times to try to connect to the plasma store. The default
+        value of -1 uses the default (50).
+ """
+ cdef PlasmaClient result = PlasmaClient()
+ cdef int deprecated_release_delay = 0
+ result.store_socket_name = store_socket_name.encode()
+ with nogil:
+ plasma_check_status(
+ result.client.get().Connect(result.store_socket_name, b"",
+ deprecated_release_delay, num_retries))
+ return result
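+
+
+# Illustrative usage sketch: the typical connect/put/get round trip with the
+# client defined above. The socket path below is a hypothetical example and
+# assumes a plasma store is already running there.
+def _example_plasma_round_trip(store_socket="/tmp/plasma"):
+    client = connect(store_socket)
+    try:
+        # Store a Python buffer and fetch it back by its ObjectID.
+        object_id = client.put_raw_buffer(b"hello plasma")
+        [buf] = client.get_buffers([object_id])
+        assert buf.to_pybytes() == b"hello plasma"
+        # Sealed objects appear in the (experimental) store listing.
+        info = client.list()[object_id]
+        return info["data_size"], info["state"]
+    finally:
+        client.disconnect()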
diff --git a/src/arrow/python/pyarrow/_s3fs.pyx b/src/arrow/python/pyarrow/_s3fs.pyx
new file mode 100644
index 000000000..5829d74d3
--- /dev/null
+++ b/src/arrow/python/pyarrow/_s3fs.pyx
@@ -0,0 +1,284 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+from pyarrow.lib cimport (check_status, pyarrow_wrap_metadata,
+ pyarrow_unwrap_metadata)
+from pyarrow.lib import frombytes, tobytes, KeyValueMetadata
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_fs cimport *
+from pyarrow._fs cimport FileSystem
+
+
+cpdef enum S3LogLevel:
+ Off = <int8_t> CS3LogLevel_Off
+ Fatal = <int8_t> CS3LogLevel_Fatal
+ Error = <int8_t> CS3LogLevel_Error
+ Warn = <int8_t> CS3LogLevel_Warn
+ Info = <int8_t> CS3LogLevel_Info
+ Debug = <int8_t> CS3LogLevel_Debug
+ Trace = <int8_t> CS3LogLevel_Trace
+
+
+def initialize_s3(S3LogLevel log_level=S3LogLevel.Fatal):
+ """
+    Initialize S3 support.
+
+    Parameters
+    ----------
+    log_level : S3LogLevel
+        Level of logging.
+ """
+ cdef CS3GlobalOptions options
+ options.log_level = <CS3LogLevel> log_level
+ check_status(CInitializeS3(options))
+
+
+def finalize_s3():
+ check_status(CFinalizeS3())
+
+
+cdef class S3FileSystem(FileSystem):
+ """
+ S3-backed FileSystem implementation
+
+    If neither access_key nor secret_key is provided, and role_arn is also
+    not provided, the filesystem attempts to initialize credentials from the
+    standard AWS environment variables and configuration files; otherwise
+    both access_key and secret_key must be provided.
+
+ If role_arn is provided instead of access_key and secret_key, temporary
+ credentials will be fetched by issuing a request to STS to assume the
+ specified role.
+
+ Note: S3 buckets are special and the operations available on them may be
+ limited or more expensive than desired.
+
+ Parameters
+ ----------
+ access_key : str, default None
+ AWS Access Key ID. Pass None to use the standard AWS environment
+ variables and/or configuration file.
+ secret_key : str, default None
+ AWS Secret Access key. Pass None to use the standard AWS environment
+ variables and/or configuration file.
+ session_token : str, default None
+ AWS Session Token. An optional session token, required if access_key
+ and secret_key are temporary credentials from STS.
+ anonymous : boolean, default False
+ Whether to connect anonymously if access_key and secret_key are None.
+ If true, will not attempt to look up credentials using standard AWS
+ configuration methods.
+ role_arn : str, default None
+ AWS Role ARN. If provided instead of access_key and secret_key,
+ temporary credentials will be fetched by assuming this role.
+ session_name : str, default None
+ An optional identifier for the assumed role session.
+ external_id : str, default None
+ An optional unique identifier that might be required when you assume
+ a role in another account.
+ load_frequency : int, default 900
+ The frequency (in seconds) with which temporary credentials from an
+ assumed role session will be refreshed.
+ region : str, default 'us-east-1'
+ AWS region to connect to.
+ scheme : str, default 'https'
+ S3 connection transport scheme.
+ endpoint_override : str, default None
+        Override the S3 endpoint with a connection string such as
+        "localhost:9000".
+ background_writes : boolean, default True
+ Whether file writes will be issued in the background, without
+ blocking.
+ default_metadata : mapping or KeyValueMetadata, default None
+ Default metadata for open_output_stream. This will be ignored if
+ non-empty metadata is passed to open_output_stream.
+ proxy_options : dict or str, default None
+ If a proxy is used, provide the options here. Supported options are:
+ 'scheme' (str: 'http' or 'https'; required), 'host' (str; required),
+ 'port' (int; required), 'username' (str; optional),
+ 'password' (str; optional).
+ A proxy URI (str) can also be provided, in which case these options
+ will be derived from the provided URI.
+ The following are equivalent::
+
+ S3FileSystem(proxy_options='http://username:password@localhost:8020')
+ S3FileSystem(proxy_options={'scheme': 'http', 'host': 'localhost',
+ 'port': 8020, 'username': 'username',
+ 'password': 'password'})
+ """
+
+ cdef:
+ CS3FileSystem* s3fs
+
+ def __init__(self, *, access_key=None, secret_key=None, session_token=None,
+ bint anonymous=False, region=None, scheme=None,
+ endpoint_override=None, bint background_writes=True,
+ default_metadata=None, role_arn=None, session_name=None,
+ external_id=None, load_frequency=900, proxy_options=None):
+ cdef:
+ CS3Options options
+ shared_ptr[CS3FileSystem] wrapped
+
+ if access_key is not None and secret_key is None:
+ raise ValueError(
+ 'In order to initialize with explicit credentials both '
+ 'access_key and secret_key must be provided, '
+ '`secret_key` is not set.'
+ )
+ elif access_key is None and secret_key is not None:
+ raise ValueError(
+ 'In order to initialize with explicit credentials both '
+ 'access_key and secret_key must be provided, '
+ '`access_key` is not set.'
+ )
+
+ elif session_token is not None and (access_key is None or
+ secret_key is None):
+ raise ValueError(
+ 'In order to initialize a session with temporary credentials, '
+ 'both secret_key and access_key must be provided in addition '
+ 'to session_token.'
+ )
+
+ elif (access_key is not None or secret_key is not None):
+ if anonymous:
+ raise ValueError(
+ 'Cannot pass anonymous=True together with access_key '
+ 'and secret_key.')
+
+ if role_arn:
+ raise ValueError(
+ 'Cannot provide role_arn with access_key and secret_key')
+
+ if session_token is None:
+ session_token = ""
+
+ options = CS3Options.FromAccessKey(
+ tobytes(access_key),
+ tobytes(secret_key),
+ tobytes(session_token)
+ )
+ elif anonymous:
+ if role_arn:
+ raise ValueError(
+ 'Cannot provide role_arn with anonymous=True')
+
+ options = CS3Options.Anonymous()
+ elif role_arn:
+
+ options = CS3Options.FromAssumeRole(
+ tobytes(role_arn),
+ tobytes(session_name),
+ tobytes(external_id),
+ load_frequency
+ )
+ else:
+ options = CS3Options.Defaults()
+
+ if region is not None:
+ options.region = tobytes(region)
+ if scheme is not None:
+ options.scheme = tobytes(scheme)
+ if endpoint_override is not None:
+ options.endpoint_override = tobytes(endpoint_override)
+ if background_writes is not None:
+ options.background_writes = background_writes
+ if default_metadata is not None:
+ if not isinstance(default_metadata, KeyValueMetadata):
+ default_metadata = KeyValueMetadata(default_metadata)
+ options.default_metadata = pyarrow_unwrap_metadata(
+ default_metadata)
+
+ if proxy_options is not None:
+ if isinstance(proxy_options, dict):
+ options.proxy_options.scheme = tobytes(proxy_options["scheme"])
+ options.proxy_options.host = tobytes(proxy_options["host"])
+ options.proxy_options.port = proxy_options["port"]
+ proxy_username = proxy_options.get("username", None)
+ if proxy_username:
+ options.proxy_options.username = tobytes(proxy_username)
+ proxy_password = proxy_options.get("password", None)
+ if proxy_password:
+ options.proxy_options.password = tobytes(proxy_password)
+ elif isinstance(proxy_options, str):
+ options.proxy_options = GetResultValue(
+ CS3ProxyOptions.FromUriString(tobytes(proxy_options)))
+ else:
+ raise TypeError(
+ "'proxy_options': expected 'dict' or 'str', "
+ f"got {type(proxy_options)} instead.")
+
+ with nogil:
+ wrapped = GetResultValue(CS3FileSystem.Make(options))
+
+ self.init(<shared_ptr[CFileSystem]> wrapped)
+
+ cdef init(self, const shared_ptr[CFileSystem]& wrapped):
+ FileSystem.init(self, wrapped)
+ self.s3fs = <CS3FileSystem*> wrapped.get()
+
+ @classmethod
+ def _reconstruct(cls, kwargs):
+ return cls(**kwargs)
+
+ def __reduce__(self):
+ cdef CS3Options opts = self.s3fs.options()
+
+ # if creds were explicitly provided, then use them
+ # else obtain them as they were last time.
+ if opts.credentials_kind == CS3CredentialsKind_Explicit:
+ access_key = frombytes(opts.GetAccessKey())
+ secret_key = frombytes(opts.GetSecretKey())
+ session_token = frombytes(opts.GetSessionToken())
+ else:
+ access_key = None
+ secret_key = None
+ session_token = None
+
+ return (
+ S3FileSystem._reconstruct, (dict(
+ access_key=access_key,
+ secret_key=secret_key,
+ session_token=session_token,
+ anonymous=(opts.credentials_kind ==
+ CS3CredentialsKind_Anonymous),
+ region=frombytes(opts.region),
+ scheme=frombytes(opts.scheme),
+ endpoint_override=frombytes(opts.endpoint_override),
+ role_arn=frombytes(opts.role_arn),
+ session_name=frombytes(opts.session_name),
+ external_id=frombytes(opts.external_id),
+ load_frequency=opts.load_frequency,
+ background_writes=opts.background_writes,
+ default_metadata=pyarrow_wrap_metadata(opts.default_metadata),
+ proxy_options={'scheme': frombytes(opts.proxy_options.scheme),
+ 'host': frombytes(opts.proxy_options.host),
+ 'port': opts.proxy_options.port,
+ 'username': frombytes(
+ opts.proxy_options.username),
+ 'password': frombytes(
+ opts.proxy_options.password)}
+ ),)
+ )
+
+ @property
+ def region(self):
+ """
+ The AWS region this filesystem connects to.
+ """
+ return frombytes(self.s3fs.region())
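+
+
+# Illustrative usage sketch: constructing the filesystem with explicit
+# credentials and a proxy. Every value shown is a placeholder; calling
+# S3FileSystem() with no arguments uses the standard AWS configuration chain.
+def _example_s3_filesystem():
+    fs = S3FileSystem(
+        access_key="<access-key-id>",
+        secret_key="<secret-access-key>",
+        region="us-east-1",
+        proxy_options={"scheme": "http", "host": "localhost", "port": 8020},
+    )
+    # Instances pickle by reconstructing from their options (see __reduce__).
+    import pickle
+    return pickle.loads(pickle.dumps(fs))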
diff --git a/src/arrow/python/pyarrow/array.pxi b/src/arrow/python/pyarrow/array.pxi
new file mode 100644
index 000000000..97cbce759
--- /dev/null
+++ b/src/arrow/python/pyarrow/array.pxi
@@ -0,0 +1,2541 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import warnings
+
+
+cdef _sequence_to_array(object sequence, object mask, object size,
+ DataType type, CMemoryPool* pool, c_bool from_pandas):
+ cdef:
+ int64_t c_size
+ PyConversionOptions options
+ shared_ptr[CChunkedArray] chunked
+
+ if type is not None:
+ options.type = type.sp_type
+
+ if size is not None:
+ options.size = size
+
+ options.from_pandas = from_pandas
+ options.ignore_timezone = os.environ.get('PYARROW_IGNORE_TIMEZONE', False)
+
+ with nogil:
+ chunked = GetResultValue(
+ ConvertPySequence(sequence, mask, options, pool)
+ )
+
+ if chunked.get().num_chunks() == 1:
+ return pyarrow_wrap_array(chunked.get().chunk(0))
+ else:
+ return pyarrow_wrap_chunked_array(chunked)
+
+
+cdef inline _is_array_like(obj):
+ if isinstance(obj, np.ndarray):
+ return True
+ return pandas_api._have_pandas_internal() and pandas_api.is_array_like(obj)
+
+
+def _ndarray_to_arrow_type(object values, DataType type):
+ return pyarrow_wrap_data_type(_ndarray_to_type(values, type))
+
+
+cdef shared_ptr[CDataType] _ndarray_to_type(object values,
+ DataType type) except *:
+ cdef shared_ptr[CDataType] c_type
+
+ dtype = values.dtype
+
+ if type is None and dtype != object:
+ with nogil:
+ check_status(NumPyDtypeToArrow(dtype, &c_type))
+
+ if type is not None:
+ c_type = type.sp_type
+
+ return c_type
+
+
+cdef _ndarray_to_array(object values, object mask, DataType type,
+ c_bool from_pandas, c_bool safe, CMemoryPool* pool):
+ cdef:
+ shared_ptr[CChunkedArray] chunked_out
+ shared_ptr[CDataType] c_type = _ndarray_to_type(values, type)
+ CCastOptions cast_options = CCastOptions(safe)
+
+ with nogil:
+ check_status(NdarrayToArrow(pool, values, mask, from_pandas,
+ c_type, cast_options, &chunked_out))
+
+ if chunked_out.get().num_chunks() > 1:
+ return pyarrow_wrap_chunked_array(chunked_out)
+ else:
+ return pyarrow_wrap_array(chunked_out.get().chunk(0))
+
+
+cdef _codes_to_indices(object codes, object mask, DataType type,
+ MemoryPool memory_pool):
+ """
+ Convert the codes of a pandas Categorical to indices for a pyarrow
+ DictionaryArray, taking into account missing values + mask
+ """
+ if mask is None:
+ mask = codes == -1
+ else:
+ mask = mask | (codes == -1)
+ return array(codes, mask=mask, type=type, memory_pool=memory_pool)
+
+
+def _handle_arrow_array_protocol(obj, type, mask, size):
+ if mask is not None or size is not None:
+ raise ValueError(
+ "Cannot specify a mask or a size when passing an object that is "
+ "converted with the __arrow_array__ protocol.")
+ res = obj.__arrow_array__(type=type)
+ if not isinstance(res, (Array, ChunkedArray)):
+ raise TypeError("The object's __arrow_array__ method does not "
+ "return a pyarrow Array or ChunkedArray.")
+ return res
+
+
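+# Illustrative sketch only: a minimal container implementing the
+# ``__arrow_array__`` protocol handled above. Any object exposing this method
+# is converted by ``array()`` through _handle_arrow_array_protocol.
+class _ExampleArrowConvertible:
+
+    def __init__(self, values):
+        self._values = list(values)
+
+    def __arrow_array__(self, type=None):
+        # Must return a pyarrow Array or ChunkedArray.
+        return array(self._values, type=type)
+
+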
+def array(object obj, type=None, mask=None, size=None, from_pandas=None,
+ bint safe=True, MemoryPool memory_pool=None):
+ """
+ Create pyarrow.Array instance from a Python object.
+
+ Parameters
+ ----------
+ obj : sequence, iterable, ndarray or Series
+        If both type and size are specified, may be a single-use iterable.
+        If not strongly-typed, the Arrow type will be inferred for the
+        resulting array.
+ type : pyarrow.DataType
+ Explicit type to attempt to coerce to, otherwise will be inferred from
+ the data.
+ mask : array[bool], optional
+ Indicate which values are null (True) or not null (False).
+ size : int64, optional
+        Size of the resulting array (number of elements). If the input is
+        larger than size, conversion stops at this length. For iterators, if
+        size is larger than the input iterator this will be treated as a
+        "max size", but will involve an initial allocation of size followed
+        by a resize to the actual size (so if you know the exact size
+        specifying it correctly will give you better performance).
+ from_pandas : bool, default None
+ Use pandas's semantics for inferring nulls from values in
+        ndarray-like data. If passed, the mask takes precedence, but
+ if a value is unmasked (not-null), but still null according to
+ pandas semantics, then it is null. Defaults to False if not
+ passed explicitly by user, or True if a pandas object is
+ passed in.
+ safe : bool, default True
+ Check for overflows or other unsafe conversions.
+ memory_pool : pyarrow.MemoryPool, optional
+ If not passed, will allocate memory from the currently-set default
+ memory pool.
+
+ Returns
+ -------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ A ChunkedArray instead of an Array is returned if:
+
+ - the object data overflowed binary storage.
+ - the object's ``__arrow_array__`` protocol method returned a chunked
+ array.
+
+ Notes
+ -----
+ Localized timestamps will currently be returned as UTC (pandas's native
+ representation). Timezone-naive data will be implicitly interpreted as
+ UTC.
+
+ Pandas's DateOffsets and dateutil.relativedelta.relativedelta are by
+ default converted as MonthDayNanoIntervalArray. relativedelta leapdays
+ are ignored as are all absolute fields on both objects. datetime.timedelta
+ can also be converted to MonthDayNanoIntervalArray but this requires
+ passing MonthDayNanoIntervalType explicitly.
+
+ Converting to dictionary array will promote to a wider integer type for
+ indices if the number of distinct values cannot be represented, even if
+ the index type was explicitly set. This means that if there are more than
+ 127 values the returned dictionary array's index type will be at least
+ pa.int16() even if pa.int8() was passed to the function. Note that an
+ explicit index type will not be demoted even if it is wider than required.
+
+ Examples
+ --------
+ >>> import pandas as pd
+ >>> import pyarrow as pa
+ >>> pa.array(pd.Series([1, 2]))
+ <pyarrow.lib.Int64Array object at 0x7f674e4c0e10>
+ [
+ 1,
+ 2
+ ]
+
+ >>> pa.array(["a", "b", "a"], type=pa.dictionary(pa.int8(), pa.string()))
+ <pyarrow.lib.DictionaryArray object at 0x7feb288d9040>
+ -- dictionary:
+ [
+ "a",
+ "b"
+ ]
+ -- indices:
+ [
+ 0,
+ 1,
+ 0
+ ]
+
+ >>> import numpy as np
+ >>> pa.array(pd.Series([1, 2]), mask=np.array([0, 1], dtype=bool))
+ <pyarrow.lib.Int64Array object at 0x7f9019e11208>
+ [
+ 1,
+ null
+ ]
+
+ >>> arr = pa.array(range(1024), type=pa.dictionary(pa.int8(), pa.int64()))
+ >>> arr.type.index_type
+ DataType(int16)
+ """
+ cdef:
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+ bint is_pandas_object = False
+ bint c_from_pandas
+
+ type = ensure_type(type, allow_none=True)
+
+ if from_pandas is None:
+ c_from_pandas = False
+ else:
+ c_from_pandas = from_pandas
+
+ if hasattr(obj, '__arrow_array__'):
+ return _handle_arrow_array_protocol(obj, type, mask, size)
+ elif _is_array_like(obj):
+ if mask is not None:
+ if _is_array_like(mask):
+ mask = get_values(mask, &is_pandas_object)
+ else:
+ raise TypeError("Mask must be a numpy array "
+ "when converting numpy arrays")
+
+ values = get_values(obj, &is_pandas_object)
+ if is_pandas_object and from_pandas is None:
+ c_from_pandas = True
+
+ if isinstance(values, np.ma.MaskedArray):
+ if mask is not None:
+ raise ValueError("Cannot pass a numpy masked array and "
+ "specify a mask at the same time")
+ else:
+ # don't use shrunken masks
+ mask = None if values.mask is np.ma.nomask else values.mask
+ values = values.data
+
+ if mask is not None:
+ if mask.dtype != np.bool_:
+ raise TypeError("Mask must be boolean dtype")
+ if mask.ndim != 1:
+ raise ValueError("Mask must be 1D array")
+ if len(values) != len(mask):
+ raise ValueError(
+ "Mask is a different length from sequence being converted")
+
+ if hasattr(values, '__arrow_array__'):
+ return _handle_arrow_array_protocol(values, type, mask, size)
+ elif pandas_api.is_categorical(values):
+ if type is not None:
+ if type.id != Type_DICTIONARY:
+ return _ndarray_to_array(
+ np.asarray(values), mask, type, c_from_pandas, safe,
+ pool)
+ index_type = type.index_type
+ value_type = type.value_type
+ if values.ordered != type.ordered:
+ warnings.warn(
+ "The 'ordered' flag of the passed categorical values "
+ "does not match the 'ordered' of the specified type. "
+ "Using the flag of the values, but in the future this "
+ "mismatch will raise a ValueError.",
+ FutureWarning, stacklevel=2)
+ else:
+ index_type = None
+ value_type = None
+
+ indices = _codes_to_indices(
+ values.codes, mask, index_type, memory_pool)
+ try:
+ dictionary = array(
+ values.categories.values, type=value_type,
+ memory_pool=memory_pool)
+ except TypeError:
+ # TODO when removing the deprecation warning, this whole
+ # try/except can be removed (to bubble the TypeError of
+ # the first array(..) call)
+ if value_type is not None:
+ warnings.warn(
+ "The dtype of the 'categories' of the passed "
+ "categorical values ({0}) does not match the "
+ "specified type ({1}). For now ignoring the specified "
+ "type, but in the future this mismatch will raise a "
+ "TypeError".format(
+ values.categories.dtype, value_type),
+ FutureWarning, stacklevel=2)
+ dictionary = array(
+ values.categories.values, memory_pool=memory_pool)
+ else:
+ raise
+
+ return DictionaryArray.from_arrays(
+ indices, dictionary, ordered=values.ordered, safe=safe)
+ else:
+ if pandas_api.have_pandas:
+ values, type = pandas_api.compat.get_datetimetz_type(
+ values, obj.dtype, type)
+ return _ndarray_to_array(values, mask, type, c_from_pandas, safe,
+ pool)
+ else:
+ # ConvertPySequence does strict conversion if type is explicitly passed
+ return _sequence_to_array(obj, mask, size, type, pool, c_from_pandas)
+
+
+def asarray(values, type=None):
+ """
+ Convert to pyarrow.Array, inferring type if not provided.
+
+ Parameters
+ ----------
+ values : array-like
+ This can be a sequence, numpy.ndarray, pyarrow.Array or
+ pyarrow.ChunkedArray. If a ChunkedArray is passed, the output will be
+ a ChunkedArray, otherwise the output will be a Array.
+ type : string or DataType
+ Explicitly construct the array with this type. Attempt to cast if
+ indicated type is different.
+
+ Returns
+ -------
+ arr : Array or ChunkedArray
+ """
+ if isinstance(values, (Array, ChunkedArray)):
+ if type is not None and not values.type.equals(type):
+ values = values.cast(type)
+ return values
+ else:
+ return array(values, type=type)
+
+
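+# Illustrative sketch only: asarray() passes Arrow-backed inputs through
+# unchanged and falls back to array() for plain Python sequences.
+def _example_asarray():
+    arr = array([1, 2, 3])
+    assert asarray(arr) is arr    # already an Arrow Array: returned as-is
+    return asarray([4, 5, 6])     # plain sequences are converted via array()
+
+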
+def nulls(size, type=None, MemoryPool memory_pool=None):
+ """
+ Create a strongly-typed Array instance with all elements null.
+
+ Parameters
+ ----------
+ size : int
+ Array length.
+ type : pyarrow.DataType, default None
+ Explicit type for the array. By default use NullType.
+ memory_pool : MemoryPool, default None
+ Arrow MemoryPool to use for allocations. Uses the default memory
+        pool if not passed.
+
+ Returns
+ -------
+ arr : Array
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> pa.nulls(10)
+ <pyarrow.lib.NullArray object at 0x7ffaf04c2e50>
+ 10 nulls
+
+ >>> pa.nulls(3, pa.uint32())
+ <pyarrow.lib.UInt32Array object at 0x7ffaf04c2e50>
+ [
+ null,
+ null,
+ null
+ ]
+ """
+ cdef:
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+ int64_t length = size
+ shared_ptr[CDataType] ty
+ shared_ptr[CArray] arr
+
+ type = ensure_type(type, allow_none=True)
+ if type is None:
+ type = null()
+
+ ty = pyarrow_unwrap_data_type(type)
+ with nogil:
+ arr = GetResultValue(MakeArrayOfNull(ty, length, pool))
+
+ return pyarrow_wrap_array(arr)
+
+
+def repeat(value, size, MemoryPool memory_pool=None):
+ """
+ Create an Array instance whose slots are the given scalar.
+
+ Parameters
+ ----------
+ value : Scalar-like object
+ Either a pyarrow.Scalar or any python object coercible to a Scalar.
+ size : int
+ Number of times to repeat the scalar in the output Array.
+ memory_pool : MemoryPool, default None
+ Arrow MemoryPool to use for allocations. Uses the default memory
+        pool if not passed.
+
+ Returns
+ -------
+ arr : Array
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> pa.repeat(10, 3)
+ <pyarrow.lib.Int64Array object at 0x7ffac03a2750>
+ [
+ 10,
+ 10,
+ 10
+ ]
+
+ >>> pa.repeat([1, 2], 2)
+ <pyarrow.lib.ListArray object at 0x7ffaf04c2e50>
+ [
+ [
+ 1,
+ 2
+ ],
+ [
+ 1,
+ 2
+ ]
+ ]
+
+ >>> pa.repeat("string", 3)
+ <pyarrow.lib.StringArray object at 0x7ffac03a2750>
+ [
+ "string",
+ "string",
+ "string"
+ ]
+
+ >>> pa.repeat(pa.scalar({'a': 1, 'b': [1, 2]}), 2)
+ <pyarrow.lib.StructArray object at 0x7ffac03a2750>
+ -- is_valid: all not null
+ -- child 0 type: int64
+ [
+ 1,
+ 1
+ ]
+ -- child 1 type: list<item: int64>
+ [
+ [
+ 1,
+ 2
+ ],
+ [
+ 1,
+ 2
+ ]
+ ]
+ """
+ cdef:
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+ int64_t length = size
+ shared_ptr[CArray] c_array
+ shared_ptr[CScalar] c_scalar
+
+ if not isinstance(value, Scalar):
+ value = scalar(value, memory_pool=memory_pool)
+
+ c_scalar = (<Scalar> value).unwrap()
+ with nogil:
+ c_array = GetResultValue(
+ MakeArrayFromScalar(deref(c_scalar), length, pool)
+ )
+
+ return pyarrow_wrap_array(c_array)
+
+
+def infer_type(values, mask=None, from_pandas=False):
+ """
+ Attempt to infer Arrow data type that can hold the passed Python
+ sequence type in an Array object
+
+ Parameters
+ ----------
+ values : array-like
+ Sequence to infer type from.
+ mask : ndarray (bool type), optional
+ Optional exclusion mask where True marks null, False non-null.
+ from_pandas : bool, default False
+ Use pandas's NA/null sentinel values for type inference.
+
+ Returns
+ -------
+ type : DataType
+ """
+ cdef:
+ shared_ptr[CDataType] out
+ c_bool use_pandas_sentinels = from_pandas
+
+ if mask is not None and not isinstance(mask, np.ndarray):
+ mask = np.array(mask, dtype=bool)
+
+ out = GetResultValue(InferArrowType(values, mask, use_pandas_sentinels))
+ return pyarrow_wrap_data_type(out)
+
+
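+# Illustrative sketch only: inferring an Arrow type from a plain Python
+# sequence, with a mask marking the second value as null. The values are
+# arbitrary examples.
+def _example_infer_type():
+    # Floating-point input is inferred as Arrow's double (float64) type.
+    return infer_type([1.0, 2.0, 3.0], mask=np.array([False, True, False]))
+
+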
+def _normalize_slice(object arrow_obj, slice key):
+ """
+ Slices with step not equal to 1 (or None) will produce a copy
+ rather than a zero-copy view
+ """
+ cdef:
+ Py_ssize_t start, stop, step
+ Py_ssize_t n = len(arrow_obj)
+
+ start = key.start or 0
+ if start < 0:
+ start += n
+ if start < 0:
+ start = 0
+ elif start >= n:
+ start = n
+
+ stop = key.stop if key.stop is not None else n
+ if stop < 0:
+ stop += n
+ if stop < 0:
+ stop = 0
+ elif stop >= n:
+ stop = n
+
+ step = key.step or 1
+ if step != 1:
+ if step < 0:
+ # Negative steps require some special handling
+ if key.start is None:
+ start = n - 1
+
+ if key.stop is None:
+ stop = -1
+
+ indices = np.arange(start, stop, step)
+ return arrow_obj.take(indices)
+ else:
+ length = max(stop - start, 0)
+ return arrow_obj.slice(start, length)
+
+
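+# Illustrative sketch of the behaviour implemented above: unit-step slices
+# are zero-copy views, while any other step is materialised through take()
+# and therefore copies.
+def _example_slicing():
+    arr = array([0, 1, 2, 3, 4, 5])
+    view = arr[1:4]        # step 1: zero-copy slice over the same buffers
+    copied = arr[::2]      # step != 1: goes through take()
+    reversed_ = arr[::-1]  # negative steps are handled the same way
+    return view, copied, reversed_
+
+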
+cdef Py_ssize_t _normalize_index(Py_ssize_t index,
+ Py_ssize_t length) except -1:
+ if index < 0:
+ index += length
+ if index < 0:
+ raise IndexError("index out of bounds")
+ elif index >= length:
+ raise IndexError("index out of bounds")
+ return index
+
+
+cdef wrap_datum(const CDatum& datum):
+ if datum.kind() == DatumType_ARRAY:
+ return pyarrow_wrap_array(MakeArray(datum.array()))
+ elif datum.kind() == DatumType_CHUNKED_ARRAY:
+ return pyarrow_wrap_chunked_array(datum.chunked_array())
+ elif datum.kind() == DatumType_RECORD_BATCH:
+ return pyarrow_wrap_batch(datum.record_batch())
+ elif datum.kind() == DatumType_TABLE:
+ return pyarrow_wrap_table(datum.table())
+ elif datum.kind() == DatumType_SCALAR:
+ return pyarrow_wrap_scalar(datum.scalar())
+ else:
+ raise ValueError("Unable to wrap Datum in a Python object")
+
+
+cdef _append_array_buffers(const CArrayData* ad, list res):
+ """
+ Recursively append Buffer wrappers from *ad* and its children.
+ """
+ cdef size_t i, n
+ assert ad != NULL
+ n = ad.buffers.size()
+ for i in range(n):
+ buf = ad.buffers[i]
+ res.append(pyarrow_wrap_buffer(buf)
+ if buf.get() != NULL else None)
+ n = ad.child_data.size()
+ for i in range(n):
+ _append_array_buffers(ad.child_data[i].get(), res)
+
+
+cdef _reduce_array_data(const CArrayData* ad):
+ """
+    Recursively dissect ArrayData to (picklable) tuples.
+ """
+ cdef size_t i, n
+ assert ad != NULL
+
+ n = ad.buffers.size()
+ buffers = []
+ for i in range(n):
+ buf = ad.buffers[i]
+ buffers.append(pyarrow_wrap_buffer(buf)
+ if buf.get() != NULL else None)
+
+ children = []
+ n = ad.child_data.size()
+ for i in range(n):
+ children.append(_reduce_array_data(ad.child_data[i].get()))
+
+ if ad.dictionary.get() != NULL:
+ dictionary = _reduce_array_data(ad.dictionary.get())
+ else:
+ dictionary = None
+
+ return pyarrow_wrap_data_type(ad.type), ad.length, ad.null_count, \
+ ad.offset, buffers, children, dictionary
+
+
+cdef shared_ptr[CArrayData] _reconstruct_array_data(data):
+ """
+ Reconstruct CArrayData objects from the tuple structure generated
+ by _reduce_array_data.
+ """
+ cdef:
+ int64_t length, null_count, offset, i
+ DataType dtype
+ Buffer buf
+ vector[shared_ptr[CBuffer]] c_buffers
+ vector[shared_ptr[CArrayData]] c_children
+ shared_ptr[CArrayData] c_dictionary
+
+ dtype, length, null_count, offset, buffers, children, dictionary = data
+
+ for i in range(len(buffers)):
+ buf = buffers[i]
+ if buf is None:
+ c_buffers.push_back(shared_ptr[CBuffer]())
+ else:
+ c_buffers.push_back(buf.buffer)
+
+ for i in range(len(children)):
+ c_children.push_back(_reconstruct_array_data(children[i]))
+
+ if dictionary is not None:
+ c_dictionary = _reconstruct_array_data(dictionary)
+
+ return CArrayData.MakeWithChildrenAndDictionary(
+ dtype.sp_type,
+ length,
+ c_buffers,
+ c_children,
+ c_dictionary,
+ null_count,
+ offset)
+
+
+def _restore_array(data):
+ """
+ Reconstruct an Array from pickled ArrayData.
+ """
+ cdef shared_ptr[CArrayData] ad = _reconstruct_array_data(data)
+ return pyarrow_wrap_array(MakeArray(ad))
+
+
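+# Illustrative sketch only: arrays pickle by reducing their ArrayData to the
+# tuples produced above and rebuilding them through _restore_array.
+def _example_pickle_round_trip():
+    import pickle
+    arr = array(["a", "b", None])
+    restored = pickle.loads(pickle.dumps(arr))
+    assert restored.equals(arr)
+    return restored
+
+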
+cdef class _PandasConvertible(_Weakrefable):
+
+ def to_pandas(
+ self,
+ memory_pool=None,
+ categories=None,
+ bint strings_to_categorical=False,
+ bint zero_copy_only=False,
+ bint integer_object_nulls=False,
+ bint date_as_object=True,
+ bint timestamp_as_object=False,
+ bint use_threads=True,
+ bint deduplicate_objects=True,
+ bint ignore_metadata=False,
+ bint safe=True,
+ bint split_blocks=False,
+ bint self_destruct=False,
+ types_mapper=None
+ ):
+ """
+ Convert to a pandas-compatible NumPy array or DataFrame, as appropriate
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ Arrow MemoryPool to use for allocations. Uses the default memory
+            pool if not passed.
+ strings_to_categorical : bool, default False
+ Encode string (UTF8) and binary types to pandas.Categorical.
+        categories : list, default empty
+ List of fields that should be returned as pandas.Categorical. Only
+ applies to table-like data structures.
+ zero_copy_only : bool, default False
+ Raise an ArrowException if this function call would require copying
+ the underlying data.
+ integer_object_nulls : bool, default False
+ Cast integers with nulls to objects
+ date_as_object : bool, default True
+ Cast dates to objects. If False, convert to datetime64[ns] dtype.
+ timestamp_as_object : bool, default False
+ Cast non-nanosecond timestamps (np.datetime64) to objects. This is
+ useful if you have timestamps that don't fit in the normal date
+ range of nanosecond timestamps (1678 CE-2262 CE).
+ If False, all timestamps are converted to datetime64[ns] dtype.
+        use_threads : bool, default True
+ Whether to parallelize the conversion using multiple threads.
+        deduplicate_objects : bool, default True
+            Do not create multiple copies of Python objects when converting,
+            to save on memory use. Conversion will be slower.
+ ignore_metadata : bool, default False
+ If True, do not use the 'pandas' metadata to reconstruct the
+ DataFrame index, if present
+ safe : bool, default True
+ For certain data types, a cast is needed in order to store the
+ data in a pandas DataFrame or Series (e.g. timestamps are always
+ stored as nanoseconds in pandas). This option controls whether it
+ is a safe cast or not.
+ split_blocks : bool, default False
+ If True, generate one internal "block" for each column when
+ creating a pandas.DataFrame from a RecordBatch or Table. While this
+            can temporarily reduce memory, note that various pandas operations
+ can trigger "consolidation" which may balloon memory use.
+ self_destruct : bool, default False
+ EXPERIMENTAL: If True, attempt to deallocate the originating Arrow
+ memory while converting the Arrow object to pandas. If you use the
+ object after calling to_pandas with this option it will crash your
+ program.
+
+            Note that you may not always see memory usage improvements. For
+ example, if multiple columns share an underlying allocation,
+ memory can't be freed until all columns are converted.
+ types_mapper : function, default None
+ A function mapping a pyarrow DataType to a pandas ExtensionDtype.
+ This can be used to override the default pandas type for conversion
+ of built-in pyarrow types or in absence of pandas_metadata in the
+ Table schema. The function receives a pyarrow DataType and is
+ expected to return a pandas ExtensionDtype or ``None`` if the
+ default conversion should be used for that type. If you have
+ a dictionary mapping, you can pass ``dict.get`` as function.
+
+ Returns
+ -------
+ pandas.Series or pandas.DataFrame depending on type of object
+ """
+ options = dict(
+ pool=memory_pool,
+ strings_to_categorical=strings_to_categorical,
+ zero_copy_only=zero_copy_only,
+ integer_object_nulls=integer_object_nulls,
+ date_as_object=date_as_object,
+ timestamp_as_object=timestamp_as_object,
+ use_threads=use_threads,
+ deduplicate_objects=deduplicate_objects,
+ safe=safe,
+ split_blocks=split_blocks,
+ self_destruct=self_destruct
+ )
+ return self._to_pandas(options, categories=categories,
+ ignore_metadata=ignore_metadata,
+ types_mapper=types_mapper)
+
+
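+# Illustrative sketch of the types_mapper hook documented above: a dict from
+# pyarrow types to pandas extension dtypes, passed as ``dict.get`` so that
+# unmapped types fall back to the default conversion. Assumes pandas is
+# installed and that ``table`` is a pyarrow Table; the int64()/string() type
+# factories come from the surrounding pyarrow namespace.
+def _example_types_mapper(table):
+    import pandas as pd
+    mapping = {int64(): pd.Int64Dtype(), string(): pd.StringDtype()}
+    return table.to_pandas(types_mapper=mapping.get)
+
+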
+cdef PandasOptions _convert_pandas_options(dict options):
+ cdef PandasOptions result
+ result.pool = maybe_unbox_memory_pool(options['pool'])
+ result.strings_to_categorical = options['strings_to_categorical']
+ result.zero_copy_only = options['zero_copy_only']
+ result.integer_object_nulls = options['integer_object_nulls']
+ result.date_as_object = options['date_as_object']
+ result.timestamp_as_object = options['timestamp_as_object']
+ result.use_threads = options['use_threads']
+ result.deduplicate_objects = options['deduplicate_objects']
+ result.safe_cast = options['safe']
+ result.split_blocks = options['split_blocks']
+ result.self_destruct = options['self_destruct']
+ result.ignore_timezone = os.environ.get('PYARROW_IGNORE_TIMEZONE', False)
+ return result
+
+
+cdef class Array(_PandasConvertible):
+ """
+ The base class for all Arrow arrays.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use one of "
+ "the `pyarrow.Array.from_*` functions instead."
+ .format(self.__class__.__name__))
+
+ cdef void init(self, const shared_ptr[CArray]& sp_array) except *:
+ self.sp_array = sp_array
+ self.ap = sp_array.get()
+ self.type = pyarrow_wrap_data_type(self.sp_array.get().type())
+
+ def _debug_print(self):
+ with nogil:
+ check_status(DebugPrint(deref(self.ap), 0))
+
+ def diff(self, Array other):
+ """
+ Compare contents of this array against another one.
+
+ Return string containing the result of arrow::Diff comparing contents
+ of this array against the other array.
+ """
+ cdef c_string result
+ with nogil:
+ result = self.ap.Diff(deref(other.ap))
+ return frombytes(result, safe=True)
+
+ def cast(self, object target_type, safe=True):
+ """
+ Cast array values to another data type
+
+ See pyarrow.compute.cast for usage
+ """
+ return _pc().cast(self, target_type, safe=safe)
+
+ def view(self, object target_type):
+ """
+ Return zero-copy "view" of array as another data type.
+
+ The data types must have compatible columnar buffer layouts
+
+ Parameters
+ ----------
+ target_type : DataType
+ Type to construct view as.
+
+ Returns
+ -------
+ view : Array
+ """
+ cdef DataType type = ensure_type(target_type)
+ cdef shared_ptr[CArray] result
+ with nogil:
+ result = GetResultValue(self.ap.View(type.sp_type))
+ return pyarrow_wrap_array(result)
+
+ def sum(self, **kwargs):
+ """
+ Sum the values in a numerical array.
+ """
+ options = _pc().ScalarAggregateOptions(**kwargs)
+ return _pc().call_function('sum', [self], options)
+
+ def unique(self):
+ """
+ Compute distinct elements in array.
+ """
+ return _pc().call_function('unique', [self])
+
+ def dictionary_encode(self, null_encoding='mask'):
+ """
+ Compute dictionary-encoded representation of array.
+ """
+ options = _pc().DictionaryEncodeOptions(null_encoding)
+ return _pc().call_function('dictionary_encode', [self], options)
+
+ def value_counts(self):
+ """
+ Compute counts of unique elements in array.
+
+ Returns
+ -------
+ An array of <input type "Values", int64_t "Counts"> structs
+ """
+ return _pc().call_function('value_counts', [self])
+
+ @staticmethod
+ def from_pandas(obj, mask=None, type=None, bint safe=True,
+ MemoryPool memory_pool=None):
+ """
+ Convert pandas.Series to an Arrow Array.
+
+ This method uses Pandas semantics about what values indicate
+ nulls. See pyarrow.array for more general conversion from arrays or
+ sequences to Arrow arrays.
+
+ Parameters
+ ----------
+ obj : ndarray, pandas.Series, array-like
+ mask : array (boolean), optional
+ Indicate which values are null (True) or not null (False).
+ type : pyarrow.DataType
+ Explicit type to attempt to coerce to, otherwise will be inferred
+ from the data.
+ safe : bool, default True
+ Check for overflows or other unsafe conversions.
+ memory_pool : pyarrow.MemoryPool, optional
+ If not passed, will allocate memory from the currently-set default
+ memory pool.
+
+ Notes
+ -----
+ Localized timestamps will currently be returned as UTC (pandas's native
+ representation). Timezone-naive data will be implicitly interpreted as
+ UTC.
+
+ Returns
+ -------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ ChunkedArray is returned if object data overflows binary buffer.
+ """
+ return array(obj, mask=mask, type=type, safe=safe, from_pandas=True,
+ memory_pool=memory_pool)
+
+ def __reduce__(self):
+ return _restore_array, \
+ (_reduce_array_data(self.sp_array.get().data().get()),)
+
+ @staticmethod
+ def from_buffers(DataType type, length, buffers, null_count=-1, offset=0,
+ children=None):
+ """
+ Construct an Array from a sequence of buffers.
+
+ The concrete type returned depends on the datatype.
+
+ Parameters
+ ----------
+ type : DataType
+ The value type of the array.
+ length : int
+ The number of values in the array.
+ buffers : List[Buffer]
+ The buffers backing this array.
+ null_count : int, default -1
+ The number of null entries in the array. Negative value means that
+ the null count is not known.
+ offset : int, default 0
+ The array's logical offset (in values, not in bytes) from the
+ start of each buffer.
+ children : List[Array], default None
+ Nested type children with length matching type.num_fields.
+
+ Returns
+ -------
+ array : Array
+ """
+ cdef:
+ Buffer buf
+ Array child
+ vector[shared_ptr[CBuffer]] c_buffers
+ vector[shared_ptr[CArrayData]] c_child_data
+ shared_ptr[CArrayData] array_data
+
+ children = children or []
+
+ if type.num_fields != len(children):
+ raise ValueError("Type's expected number of children "
+ "({0}) did not match the passed number "
+ "({1}).".format(type.num_fields, len(children)))
+
+ if type.num_buffers != len(buffers):
+ raise ValueError("Type's expected number of buffers "
+ "({0}) did not match the passed number "
+ "({1}).".format(type.num_buffers, len(buffers)))
+
+ for buf in buffers:
+ # None will produce a null buffer pointer
+ c_buffers.push_back(pyarrow_unwrap_buffer(buf))
+
+ for child in children:
+ c_child_data.push_back(child.ap.data())
+
+ array_data = CArrayData.MakeWithChildren(type.sp_type, length,
+ c_buffers, c_child_data,
+ null_count, offset)
+ cdef Array result = pyarrow_wrap_array(MakeArray(array_data))
+ result.validate()
+ return result
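+
+ # A minimal usage sketch for from_buffers, assuming ``import pyarrow as pa``
+ # and ``import numpy as np``; an int32 array has two buffers (validity
+ # bitmap and values), and passing None as the validity buffer means the
+ # array has no nulls:
+ #
+ # data = pa.py_buffer(np.array([1, 2, 3, 4], dtype='int32').tobytes())
+ # arr = pa.Array.from_buffers(pa.int32(), 4, [None, data])
+ # # expected: arr is an Int32Array containing [1, 2, 3, 4]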
+
+ @property
+ def null_count(self):
+ return self.sp_array.get().null_count()
+
+ @property
+ def nbytes(self):
+ """
+ Total number of bytes consumed by the elements of the array.
+ """
+ size = 0
+ for buf in self.buffers():
+ if buf is not None:
+ size += buf.size
+ return size
+
+ def __sizeof__(self):
+ return super(Array, self).__sizeof__() + self.nbytes
+
+ def __iter__(self):
+ for i in range(len(self)):
+ yield self.getitem(i)
+
+ def __repr__(self):
+ type_format = object.__repr__(self)
+ return '{0}\n{1}'.format(type_format, str(self))
+
+ def to_string(self, *, int indent=0, int window=10,
+ c_bool skip_new_lines=False):
+ """
+ Render a "pretty-printed" string representation of the Array.
+
+ Parameters
+ ----------
+ indent : int
+ How much to indent the content of the array to the right,
+ by default ``0``.
+ window : int
+ How many items to preview at the beginning and end
+ of the array when the array is bigger than the window.
+ The other elements will be replaced with an ellipsis.
+ skip_new_lines : bool
+ If True, render the array on a single line of text;
+ otherwise each element is printed on its own line.
+ """
+ cdef:
+ c_string result
+ PrettyPrintOptions options
+
+ with nogil:
+ options = PrettyPrintOptions(indent, window)
+ options.skip_new_lines = skip_new_lines
+ check_status(
+ PrettyPrint(
+ deref(self.ap),
+ options,
+ &result
+ )
+ )
+
+ return frombytes(result, safe=True)
+
+ def format(self, **kwargs):
+ import warnings
+ warnings.warn('Array.format is deprecated, use Array.to_string')
+ return self.to_string(**kwargs)
+
+ def __str__(self):
+ return self.to_string()
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ # This also handles comparing with None
+ # as Array.equals(None) raises a TypeError.
+ return NotImplemented
+
+ def equals(Array self, Array other not None):
+ return self.ap.Equals(deref(other.ap))
+
+ def __len__(self):
+ return self.length()
+
+ cdef int64_t length(self):
+ if self.sp_array.get():
+ return self.sp_array.get().length()
+ else:
+ return 0
+
+ def is_null(self, *, nan_is_null=False):
+ """
+ Return BooleanArray indicating the null values.
+
+ Parameters
+ ----------
+ nan_is_null : bool (optional, default False)
+ Whether floating-point NaN values should also be considered null.
+
+ Returns
+ -------
+ array : boolean Array
+ """
+ options = _pc().NullOptions(nan_is_null=nan_is_null)
+ return _pc().call_function('is_null', [self], options)
+
+ def is_valid(self):
+ """
+ Return BooleanArray indicating the non-null values.
+ """
+ return _pc().is_valid(self)
+
+ def fill_null(self, fill_value):
+ """
+ See pyarrow.compute.fill_null for usage.
+ """
+ return _pc().fill_null(self, fill_value)
+
+ def __getitem__(self, key):
+ """
+ Slice or return value at given index
+
+ Parameters
+ ----------
+ key : integer or slice
+ Slices with step not equal to 1 (or None) will produce a copy
+ rather than a zero-copy view
+
+ Returns
+ -------
+ value : Scalar (index) or Array (slice)
+ """
+ if PySlice_Check(key):
+ return _normalize_slice(self, key)
+
+ return self.getitem(_normalize_index(key, self.length()))
+
+ cdef getitem(self, int64_t i):
+ return Scalar.wrap(GetResultValue(self.ap.GetScalar(i)))
+
+ def slice(self, offset=0, length=None):
+ """
+ Compute zero-copy slice of this array.
+
+ Parameters
+ ----------
+ offset : int, default 0
+ Offset from start of array to slice.
+ length : int, default None
+ Length of slice (default is until end of Array starting from
+ offset).
+
+ Returns
+ -------
+ sliced : Array
+ """
+ cdef:
+ shared_ptr[CArray] result
+
+ if offset < 0:
+ raise IndexError('Offset must be non-negative')
+
+ offset = min(len(self), offset)
+ if length is None:
+ result = self.ap.Slice(offset)
+ else:
+ if length < 0:
+ raise ValueError('Length must be non-negative')
+ result = self.ap.Slice(offset, length)
+
+ return pyarrow_wrap_array(result)
+
+ def take(self, object indices):
+ """
+ Select values from an array. See pyarrow.compute.take for full usage.
+ """
+ return _pc().take(self, indices)
+
+ def drop_null(self):
+ """
+ Remove missing values from an array.
+ """
+ return _pc().drop_null(self)
+
+ def filter(self, Array mask, *, null_selection_behavior='drop'):
+ """
+ Select values from an array. See pyarrow.compute.filter for full usage.
+ """
+ return _pc().filter(self, mask,
+ null_selection_behavior=null_selection_behavior)
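+
+ # A small sketch of the filter semantics, assuming ``import pyarrow as pa``;
+ # nulls in the mask are dropped by default and kept as nulls with
+ # null_selection_behavior='emit_null':
+ #
+ # arr = pa.array([1, 2, 3])
+ # mask = pa.array([True, False, None])
+ # arr.filter(mask) # expected: [1]
+ # arr.filter(mask, null_selection_behavior='emit_null') # expected: [1, null]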
+
+ def index(self, value, start=None, end=None, *, memory_pool=None):
+ """
+ Find the first index of a value.
+
+ See pyarrow.compute.index for full usage.
+ """
+ return _pc().index(self, value, start, end, memory_pool=memory_pool)
+
+ def _to_pandas(self, options, **kwargs):
+ return _array_like_to_pandas(self, options)
+
+ def __array__(self, dtype=None):
+ values = self.to_numpy(zero_copy_only=False)
+ if dtype is None:
+ return values
+ return values.astype(dtype)
+
+ def to_numpy(self, zero_copy_only=True, writable=False):
+ """
+ Return a NumPy view or copy of this array (experimental).
+
+ By default, tries to return a view of this array. This is only
+ supported for primitive arrays with the same memory layout as NumPy
+ (i.e. integers, floating point, ..) and without any nulls.
+
+ Parameters
+ ----------
+ zero_copy_only : bool, default True
+ If True, an exception will be raised if the conversion to a numpy
+ array would require copying the underlying data (e.g. in presence
+ of nulls, or for non-primitive types).
+ writable : bool, default False
+ For numpy arrays created with zero copy (view on the Arrow data),
+ the resulting array is not writable (Arrow data is immutable).
+ By setting this to True, a copy of the array is made to ensure
+ it is writable.
+
+ Returns
+ -------
+ array : numpy.ndarray
+ """
+ cdef:
+ PyObject* out
+ PandasOptions c_options
+ object values
+
+ if zero_copy_only and writable:
+ raise ValueError(
+ "Cannot return a writable array if asking for zero-copy")
+
+ # If there are nulls and the array is a DictionaryArray,
+ # decoding the dictionary will make sure nulls are correctly handled.
+ # Decoding a dictionary implies a copy, so it cannot be done
+ # when the user requested zero-copy-only behavior.
+ c_options.decode_dictionaries = not zero_copy_only
+ c_options.zero_copy_only = zero_copy_only
+
+ with nogil:
+ check_status(ConvertArrayToPandas(c_options, self.sp_array,
+ self, &out))
+
+ # wrap_array_output uses pandas to convert to Categorical, here
+ # always convert to numpy array without pandas dependency
+ array = PyObject_to_object(out)
+
+ if isinstance(array, dict):
+ array = np.take(array['dictionary'], array['indices'])
+
+ if writable and not array.flags.writeable:
+ # if the conversion already required a copy, the result is writable
+ array = array.copy()
+ return array
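+
+ # A short sketch of the zero_copy_only behavior, assuming
+ # ``import pyarrow as pa``:
+ #
+ # pa.array([1, 2, 3]).to_numpy() # zero-copy view is possible
+ # pa.array([1, None, 3]).to_numpy(zero_copy_only=False) # copy; null -> NaN
+ # pa.array([1, None, 3]).to_numpy() # expected to raise ArrowInvalid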
+
+ def to_pylist(self):
+ """
+ Convert to a list of native Python objects.
+
+ Returns
+ -------
+ lst : list
+ """
+ return [x.as_py() for x in self]
+
+ def tolist(self):
+ """
+ Alias of to_pylist for compatibility with NumPy.
+ """
+ return self.to_pylist()
+
+ def validate(self, *, full=False):
+ """
+ Perform validation checks. An exception is raised if validation fails.
+
+ By default only cheap validation checks are run. Pass `full=True`
+ for thorough validation checks (potentially O(n)).
+
+ Parameters
+ ----------
+ full : bool, default False
+ If True, run expensive checks, otherwise cheap checks only.
+
+ Raises
+ ------
+ ArrowInvalid
+ """
+ if full:
+ with nogil:
+ check_status(self.ap.ValidateFull())
+ else:
+ with nogil:
+ check_status(self.ap.Validate())
+
+ @property
+ def offset(self):
+ """
+ A relative position into another array's data.
+
+ The purpose is to enable zero-copy slicing. This value defaults to zero
+ but must be applied on all operations with the physical storage
+ buffers.
+ """
+ return self.sp_array.get().offset()
+
+ def buffers(self):
+ """
+ Return a list of Buffer objects pointing to this array's physical
+ storage.
+
+ To correctly interpret these buffers, you also need to apply the offset
+ multiplied by the size of the stored data type.
+ """
+ res = []
+ _append_array_buffers(self.sp_array.get().data().get(), res)
+ return res
+
+ def _export_to_c(self, uintptr_t out_ptr, uintptr_t out_schema_ptr=0):
+ """
+ Export to a C ArrowArray struct, given its pointer.
+
+ If a C ArrowSchema struct pointer is also given, the array type
+ is exported to it at the same time.
+
+ Parameters
+ ----------
+ out_ptr: int
+ The raw pointer to a C ArrowArray struct.
+ out_schema_ptr: int (optional)
+ The raw pointer to a C ArrowSchema struct.
+
+ Be careful: if you don't pass the ArrowArray struct to a consumer,
+ array memory will leak. This is a low-level function intended for
+ expert users.
+ """
+ with nogil:
+ check_status(ExportArray(deref(self.sp_array),
+ <ArrowArray*> out_ptr,
+ <ArrowSchema*> out_schema_ptr))
+
+ @staticmethod
+ def _import_from_c(uintptr_t in_ptr, type):
+ """
+ Import Array from a C ArrowArray struct, given its pointer
+ and the imported array type.
+
+ Parameters
+ ----------
+ in_ptr: int
+ The raw pointer to a C ArrowArray struct.
+ type: DataType or int
+ Either a DataType object, or the raw pointer to a C ArrowSchema
+ struct.
+
+ This is a low-level function intended for expert users.
+ """
+ cdef:
+ shared_ptr[CArray] c_array
+
+ c_type = pyarrow_unwrap_data_type(type)
+ if c_type == nullptr:
+ # Not a DataType object, perhaps a raw ArrowSchema pointer
+ type_ptr = <uintptr_t> type
+ with nogil:
+ c_array = GetResultValue(ImportArray(<ArrowArray*> in_ptr,
+ <ArrowSchema*> type_ptr))
+ else:
+ with nogil:
+ c_array = GetResultValue(ImportArray(<ArrowArray*> in_ptr,
+ c_type))
+ return pyarrow_wrap_array(c_array)
+
+
+cdef _array_like_to_pandas(obj, options):
+ cdef:
+ PyObject* out
+ PandasOptions c_options = _convert_pandas_options(options)
+
+ original_type = obj.type
+ name = obj._name
+
+ # ARROW-3789(wesm): Convert date/timestamp types to datetime64[ns]
+ c_options.coerce_temporal_nanoseconds = True
+
+ if isinstance(obj, Array):
+ with nogil:
+ check_status(ConvertArrayToPandas(c_options,
+ (<Array> obj).sp_array,
+ obj, &out))
+ elif isinstance(obj, ChunkedArray):
+ with nogil:
+ check_status(libarrow.ConvertChunkedArrayToPandas(
+ c_options,
+ (<ChunkedArray> obj).sp_chunked_array,
+ obj, &out))
+
+ arr = wrap_array_output(out)
+
+ if (isinstance(original_type, TimestampType) and
+ options["timestamp_as_object"]):
+ # ARROW-5359 - need to specify object dtype to avoid pandas to
+ # coerce back to ns resolution
+ dtype = "object"
+ else:
+ dtype = None
+
+ result = pandas_api.series(arr, dtype=dtype, name=name)
+
+ if (isinstance(original_type, TimestampType) and
+ original_type.tz is not None and
+ # can be object dtype for non-ns and timestamp_as_object=True
+ result.dtype.kind == "M"):
+ from pyarrow.pandas_compat import make_tz_aware
+ result = make_tz_aware(result, original_type.tz)
+
+ return result
+
+
+cdef wrap_array_output(PyObject* output):
+ cdef object obj = PyObject_to_object(output)
+
+ if isinstance(obj, dict):
+ return pandas_api.categorical_type(obj['indices'],
+ categories=obj['dictionary'],
+ ordered=obj['ordered'],
+ fastpath=True)
+ else:
+ return obj
+
+
+cdef class NullArray(Array):
+ """
+ Concrete class for Arrow arrays of null data type.
+ """
+
+
+cdef class BooleanArray(Array):
+ """
+ Concrete class for Arrow arrays of boolean data type.
+ """
+ @property
+ def false_count(self):
+ return (<CBooleanArray*> self.ap).false_count()
+
+ @property
+ def true_count(self):
+ return (<CBooleanArray*> self.ap).true_count()
+
+
+cdef class NumericArray(Array):
+ """
+ A base class for Arrow numeric arrays.
+ """
+
+
+cdef class IntegerArray(NumericArray):
+ """
+ A base class for Arrow integer arrays.
+ """
+
+
+cdef class FloatingPointArray(NumericArray):
+ """
+ A base class for Arrow floating-point arrays.
+ """
+
+
+cdef class Int8Array(IntegerArray):
+ """
+ Concrete class for Arrow arrays of int8 data type.
+ """
+
+
+cdef class UInt8Array(IntegerArray):
+ """
+ Concrete class for Arrow arrays of uint8 data type.
+ """
+
+
+cdef class Int16Array(IntegerArray):
+ """
+ Concrete class for Arrow arrays of int16 data type.
+ """
+
+
+cdef class UInt16Array(IntegerArray):
+ """
+ Concrete class for Arrow arrays of uint16 data type.
+ """
+
+
+cdef class Int32Array(IntegerArray):
+ """
+ Concrete class for Arrow arrays of int32 data type.
+ """
+
+
+cdef class UInt32Array(IntegerArray):
+ """
+ Concrete class for Arrow arrays of uint32 data type.
+ """
+
+
+cdef class Int64Array(IntegerArray):
+ """
+ Concrete class for Arrow arrays of int64 data type.
+ """
+
+
+cdef class UInt64Array(IntegerArray):
+ """
+ Concrete class for Arrow arrays of uint64 data type.
+ """
+
+
+cdef class Date32Array(NumericArray):
+ """
+ Concrete class for Arrow arrays of date32 data type.
+ """
+
+
+cdef class Date64Array(NumericArray):
+ """
+ Concrete class for Arrow arrays of date64 data type.
+ """
+
+
+cdef class TimestampArray(NumericArray):
+ """
+ Concrete class for Arrow arrays of timestamp data type.
+ """
+
+
+cdef class Time32Array(NumericArray):
+ """
+ Concrete class for Arrow arrays of time32 data type.
+ """
+
+
+cdef class Time64Array(NumericArray):
+ """
+ Concrete class for Arrow arrays of time64 data type.
+ """
+
+
+cdef class DurationArray(NumericArray):
+ """
+ Concrete class for Arrow arrays of duration data type.
+ """
+
+
+cdef class MonthDayNanoIntervalArray(Array):
+ """
+ Concrete class for Arrow arrays of interval[MonthDayNano] type.
+ """
+
+ def to_pylist(self):
+ """
+ Convert to a list of native Python objects.
+
+ pyarrow.MonthDayNano is used as the native representation.
+
+ Returns
+ -------
+ lst : list
+ """
+ cdef:
+ CResult[PyObject*] maybe_py_list
+ PyObject* py_list
+ CMonthDayNanoIntervalArray* array
+ array = <CMonthDayNanoIntervalArray*>self.sp_array.get()
+ maybe_py_list = MonthDayNanoIntervalArrayToPyList(deref(array))
+ py_list = GetResultValue(maybe_py_list)
+ return PyObject_to_object(py_list)
+
+
+cdef class HalfFloatArray(FloatingPointArray):
+ """
+ Concrete class for Arrow arrays of float16 data type.
+ """
+
+
+cdef class FloatArray(FloatingPointArray):
+ """
+ Concrete class for Arrow arrays of float32 data type.
+ """
+
+
+cdef class DoubleArray(FloatingPointArray):
+ """
+ Concrete class for Arrow arrays of float64 data type.
+ """
+
+
+cdef class FixedSizeBinaryArray(Array):
+ """
+ Concrete class for Arrow arrays of a fixed-size binary data type.
+ """
+
+
+cdef class Decimal128Array(FixedSizeBinaryArray):
+ """
+ Concrete class for Arrow arrays of decimal128 data type.
+ """
+
+
+cdef class Decimal256Array(FixedSizeBinaryArray):
+ """
+ Concrete class for Arrow arrays of decimal256 data type.
+ """
+
+cdef class BaseListArray(Array):
+
+ def flatten(self):
+ """
+ Unnest this ListArray/LargeListArray by one level.
+
+ The returned Array is logically a concatenation of all the sub-lists
+ in this Array.
+
+ Note that this method is different from ``self.values()`` in that
+ it takes care of the slicing offset as well as null elements backed
+ by non-empty sub-lists.
+
+ Returns
+ -------
+ result : Array
+ """
+ return _pc().list_flatten(self)
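+
+ # A small sketch, assuming ``import pyarrow as pa``; null and empty lists
+ # simply contribute no values to the flattened result:
+ #
+ # arr = pa.array([[1, 2], None, [], [3]])
+ # arr.flatten() # expected: [1, 2, 3]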
+
+ def value_parent_indices(self):
+ """
+ Return an array with the same length as the list child values array,
+ where each output value is the index of the parent list slot that
+ contains the corresponding child value.
+
+ Examples
+ --------
+ >>> arr = pa.array([[1, 2, 3], [], None, [4]],
+ ... type=pa.list_(pa.int32()))
+ >>> arr.value_parent_indices()
+ <pyarrow.lib.Int32Array object at 0x7efc5db958a0>
+ [
+ 0,
+ 0,
+ 0,
+ 3
+ ]
+ """
+ return _pc().list_parent_indices(self)
+
+ def value_lengths(self):
+ """
+ Return an integer array whose values equal the length of each list
+ element. Null list values are null in the output.
+
+ Examples
+ --------
+ >>> arr = pa.array([[1, 2, 3], [], None, [4]],
+ ... type=pa.list_(pa.int32()))
+ >>> arr.value_lengths()
+ <pyarrow.lib.Int32Array object at 0x7efc5db95910>
+ [
+ 3,
+ 0,
+ null,
+ 1
+ ]
+ """
+ return _pc().list_value_length(self)
+
+
+cdef class ListArray(BaseListArray):
+ """
+ Concrete class for Arrow arrays of a list data type.
+ """
+
+ @staticmethod
+ def from_arrays(offsets, values, MemoryPool pool=None):
+ """
+ Construct ListArray from arrays of int32 offsets and values.
+
+ Parameters
+ ----------
+ offsets : Array (int32 type)
+ values : Array (any type)
+ pool : MemoryPool
+
+ Returns
+ -------
+ list_array : ListArray
+
+ Examples
+ --------
+ >>> values = pa.array([1, 2, 3, 4])
+ >>> offsets = pa.array([0, 2, 4])
+ >>> pa.ListArray.from_arrays(offsets, values)
+ <pyarrow.lib.ListArray object at 0x7fbde226bf40>
+ [
+ [
+ 1,
+ 2
+ ],
+ [
+ 3,
+ 4
+ ]
+ ]
+ # nulls in the offsets array become null lists
+ >>> offsets = pa.array([0, None, 2, 4])
+ >>> pa.ListArray.from_arrays(offsets, values)
+ <pyarrow.lib.ListArray object at 0x7fbde226bf40>
+ [
+ [
+ 1,
+ 2
+ ],
+ null,
+ [
+ 3,
+ 4
+ ]
+ ]
+ """
+ cdef:
+ Array _offsets, _values
+ shared_ptr[CArray] out
+ cdef CMemoryPool* cpool = maybe_unbox_memory_pool(pool)
+
+ _offsets = asarray(offsets, type='int32')
+ _values = asarray(values)
+
+ with nogil:
+ out = GetResultValue(
+ CListArray.FromArrays(_offsets.ap[0], _values.ap[0], cpool))
+ cdef Array result = pyarrow_wrap_array(out)
+ result.validate()
+ return result
+
+ @property
+ def values(self):
+ cdef CListArray* arr = <CListArray*> self.ap
+ return pyarrow_wrap_array(arr.values())
+
+ @property
+ def offsets(self):
+ """
+ Return the offsets as an int32 array.
+ """
+ return pyarrow_wrap_array((<CListArray*> self.ap).offsets())
+
+
+cdef class LargeListArray(BaseListArray):
+ """
+ Concrete class for Arrow arrays of a large list data type.
+
+ Identical to ListArray, but 64-bit offsets.
+ """
+
+ @staticmethod
+ def from_arrays(offsets, values, MemoryPool pool=None):
+ """
+ Construct LargeListArray from arrays of int64 offsets and values.
+
+ Parameters
+ ----------
+ offsets : Array (int64 type)
+ values : Array (any type)
+ pool : MemoryPool
+
+ Returns
+ -------
+ list_array : LargeListArray
+ """
+ cdef:
+ Array _offsets, _values
+ shared_ptr[CArray] out
+ cdef CMemoryPool* cpool = maybe_unbox_memory_pool(pool)
+
+ _offsets = asarray(offsets, type='int64')
+ _values = asarray(values)
+
+ with nogil:
+ out = GetResultValue(
+ CLargeListArray.FromArrays(_offsets.ap[0], _values.ap[0],
+ cpool))
+ cdef Array result = pyarrow_wrap_array(out)
+ result.validate()
+ return result
+
+ @property
+ def values(self):
+ cdef CLargeListArray* arr = <CLargeListArray*> self.ap
+ return pyarrow_wrap_array(arr.values())
+
+ @property
+ def offsets(self):
+ """
+ Return the offsets as an int64 array.
+ """
+ return pyarrow_wrap_array((<CLargeListArray*> self.ap).offsets())
+
+
+cdef class MapArray(Array):
+ """
+ Concrete class for Arrow arrays of a map data type.
+ """
+
+ @staticmethod
+ def from_arrays(offsets, keys, items, MemoryPool pool=None):
+ """
+ Construct MapArray from arrays of int32 offsets and key, item arrays.
+
+ Parameters
+ ----------
+ offsets : array-like or sequence (int32 type)
+ keys : array-like or sequence (any type)
+ items : array-like or sequence (any type)
+ pool : MemoryPool
+
+ Returns
+ -------
+ map_array : MapArray
+ """
+ cdef:
+ Array _offsets, _keys, _items
+ shared_ptr[CArray] out
+ cdef CMemoryPool* cpool = maybe_unbox_memory_pool(pool)
+
+ _offsets = asarray(offsets, type='int32')
+ _keys = asarray(keys)
+ _items = asarray(items)
+
+ with nogil:
+ out = GetResultValue(
+ CMapArray.FromArrays(_offsets.sp_array,
+ _keys.sp_array,
+ _items.sp_array, cpool))
+ cdef Array result = pyarrow_wrap_array(out)
+ result.validate()
+ return result
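+
+ # A minimal usage sketch, assuming ``import pyarrow as pa``; an offsets
+ # array of length n+1 describes n map values:
+ #
+ # offsets = pa.array([0, 2, 3], type=pa.int32())
+ # keys = pa.array(['x', 'y', 'z'])
+ # items = pa.array([1, 2, 3])
+ # m = pa.MapArray.from_arrays(offsets, keys, items)
+ # # expected: two maps, {'x': 1, 'y': 2} and {'z': 3}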
+
+ @property
+ def keys(self):
+ return pyarrow_wrap_array((<CMapArray*> self.ap).keys())
+
+ @property
+ def items(self):
+ return pyarrow_wrap_array((<CMapArray*> self.ap).items())
+
+
+cdef class FixedSizeListArray(Array):
+ """
+ Concrete class for Arrow arrays of a fixed size list data type.
+ """
+
+ @staticmethod
+ def from_arrays(values, int32_t list_size):
+ """
+ Construct FixedSizeListArray from array of values and a list length.
+
+ Parameters
+ ----------
+ values : Array (any type)
+ list_size : int
+ The fixed length of the lists.
+
+ Returns
+ -------
+ FixedSizeListArray
+ """
+ cdef:
+ Array _values
+ CResult[shared_ptr[CArray]] c_result
+
+ _values = asarray(values)
+
+ with nogil:
+ c_result = CFixedSizeListArray.FromArrays(
+ _values.sp_array, list_size)
+ cdef Array result = pyarrow_wrap_array(GetResultValue(c_result))
+ result.validate()
+ return result
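+
+ # A minimal usage sketch, assuming ``import pyarrow as pa``; the length of
+ # ``values`` must be a multiple of ``list_size``:
+ #
+ # values = pa.array([1, 2, 3, 4])
+ # fsl = pa.FixedSizeListArray.from_arrays(values, 2)
+ # # expected: [[1, 2], [3, 4]]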
+
+ @property
+ def values(self):
+ return self.flatten()
+
+ def flatten(self):
+ """
+ Unnest this FixedSizeListArray by one level.
+
+ Returns
+ -------
+ result : Array
+ """
+ cdef CFixedSizeListArray* arr = <CFixedSizeListArray*> self.ap
+ return pyarrow_wrap_array(arr.values())
+
+
+cdef class UnionArray(Array):
+ """
+ Concrete class for Arrow arrays of a Union data type.
+ """
+
+ def child(self, int pos):
+ import warnings
+ warnings.warn("child is deprecated, use field", FutureWarning)
+ return self.field(pos)
+
+ def field(self, int pos):
+ """
+ Return the given child field as an individual array.
+
+ For sparse unions, the returned array has its offset, length,
+ and null count adjusted.
+
+ For dense unions, the returned array is unchanged.
+ """
+ cdef shared_ptr[CArray] result
+ result = (<CUnionArray*> self.ap).field(pos)
+ if result != NULL:
+ return pyarrow_wrap_array(result)
+ raise KeyError("UnionArray does not have child {}".format(pos))
+
+ @property
+ def type_codes(self):
+ """Get the type codes array."""
+ buf = pyarrow_wrap_buffer((<CUnionArray*> self.ap).type_codes())
+ return Array.from_buffers(int8(), len(self), [None, buf])
+
+ @property
+ def offsets(self):
+ """
+ Get the value offsets array (dense arrays only).
+
+ Does not account for any slice offset.
+ """
+ if self.type.mode != "dense":
+ raise ArrowTypeError("Can only get value offsets for dense arrays")
+ cdef CDenseUnionArray* dense = <CDenseUnionArray*> self.ap
+ buf = pyarrow_wrap_buffer(dense.value_offsets())
+ return Array.from_buffers(int32(), len(self), [None, buf])
+
+ @staticmethod
+ def from_dense(Array types, Array value_offsets, list children,
+ list field_names=None, list type_codes=None):
+ """
+ Construct a dense UnionArray from arrays of int8 types, int32 offsets,
+ and child arrays.
+
+ Parameters
+ ----------
+ types : Array (int8 type)
+ value_offsets : Array (int32 type)
+ children : list
+ field_names : list
+ type_codes : list
+
+ Returns
+ -------
+ union_array : UnionArray
+ """
+ cdef:
+ shared_ptr[CArray] out
+ vector[shared_ptr[CArray]] c
+ Array child
+ vector[c_string] c_field_names
+ vector[int8_t] c_type_codes
+
+ for child in children:
+ c.push_back(child.sp_array)
+ if field_names is not None:
+ for x in field_names:
+ c_field_names.push_back(tobytes(x))
+ if type_codes is not None:
+ for x in type_codes:
+ c_type_codes.push_back(x)
+
+ with nogil:
+ out = GetResultValue(CDenseUnionArray.Make(
+ deref(types.ap), deref(value_offsets.ap), c, c_field_names,
+ c_type_codes))
+
+ cdef Array result = pyarrow_wrap_array(out)
+ result.validate()
+ return result
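+
+ # A minimal usage sketch, assuming ``import pyarrow as pa``; each type code
+ # selects a child array and each offset indexes into that child:
+ #
+ # types = pa.array([0, 1, 0], type=pa.int8())
+ # offsets = pa.array([0, 0, 1], type=pa.int32())
+ # children = [pa.array([10, 20]), pa.array(['a'])]
+ # u = pa.UnionArray.from_dense(types, offsets, children)
+ # # expected logical values: [10, 'a', 20]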
+
+ @staticmethod
+ def from_sparse(Array types, list children, list field_names=None,
+ list type_codes=None):
+ """
+ Construct a sparse UnionArray from an array of int8 type codes and child
+ arrays.
+
+ Parameters
+ ----------
+ types : Array (int8 type)
+ children : list
+ field_names : list
+ type_codes : list
+
+ Returns
+ -------
+ union_array : UnionArray
+ """
+ cdef:
+ shared_ptr[CArray] out
+ vector[shared_ptr[CArray]] c
+ Array child
+ vector[c_string] c_field_names
+ vector[int8_t] c_type_codes
+
+ for child in children:
+ c.push_back(child.sp_array)
+ if field_names is not None:
+ for x in field_names:
+ c_field_names.push_back(tobytes(x))
+ if type_codes is not None:
+ for x in type_codes:
+ c_type_codes.push_back(x)
+
+ with nogil:
+ out = GetResultValue(CSparseUnionArray.Make(
+ deref(types.ap), c, c_field_names, c_type_codes))
+
+ cdef Array result = pyarrow_wrap_array(out)
+ result.validate()
+ return result
+
+
+cdef class StringArray(Array):
+ """
+ Concrete class for Arrow arrays of string (or utf8) data type.
+ """
+
+ @staticmethod
+ def from_buffers(int length, Buffer value_offsets, Buffer data,
+ Buffer null_bitmap=None, int null_count=-1,
+ int offset=0):
+ """
+ Construct a StringArray from value_offsets and data buffers.
+ If there are nulls in the data, a null_bitmap and the matching
+ null_count must also be passed.
+
+ Parameters
+ ----------
+ length : int
+ value_offsets : Buffer
+ data : Buffer
+ null_bitmap : Buffer, optional
+ null_count : int, default -1
+ offset : int, default 0
+
+ Returns
+ -------
+ string_array : StringArray
+ """
+ return Array.from_buffers(utf8(), length,
+ [null_bitmap, value_offsets, data],
+ null_count, offset)
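+
+ # A minimal usage sketch, assuming ``import pyarrow as pa`` and
+ # ``import numpy as np``; the int32 offsets delimit each string inside the
+ # UTF-8 data buffer:
+ #
+ # offsets = pa.py_buffer(np.array([0, 3, 6], dtype='int32').tobytes())
+ # data = pa.py_buffer(b'foobar')
+ # s = pa.StringArray.from_buffers(2, offsets, data)
+ # # expected: ['foo', 'bar']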
+
+
+cdef class LargeStringArray(Array):
+ """
+ Concrete class for Arrow arrays of large string (or utf8) data type.
+ """
+
+ @staticmethod
+ def from_buffers(int length, Buffer value_offsets, Buffer data,
+ Buffer null_bitmap=None, int null_count=-1,
+ int offset=0):
+ """
+ Construct a LargeStringArray from value_offsets and data buffers.
+ If there are nulls in the data, a null_bitmap and the matching
+ null_count must also be passed.
+
+ Parameters
+ ----------
+ length : int
+ value_offsets : Buffer
+ data : Buffer
+ null_bitmap : Buffer, optional
+ null_count : int, default -1
+ offset : int, default 0
+
+ Returns
+ -------
+ string_array : LargeStringArray
+ """
+ return Array.from_buffers(large_utf8(), length,
+ [null_bitmap, value_offsets, data],
+ null_count, offset)
+
+
+cdef class BinaryArray(Array):
+ """
+ Concrete class for Arrow arrays of variable-sized binary data type.
+ """
+ @property
+ def total_values_length(self):
+ """
+ The number of bytes from beginning to end of the data buffer addressed
+ by the offsets of this BinaryArray.
+ """
+ return (<CBinaryArray*> self.ap).total_values_length()
+
+
+cdef class LargeBinaryArray(Array):
+ """
+ Concrete class for Arrow arrays of large variable-sized binary data type.
+ """
+ @property
+ def total_values_length(self):
+ """
+ The number of bytes from beginning to end of the data buffer addressed
+ by the offsets of this LargeBinaryArray.
+ """
+ return (<CLargeBinaryArray*> self.ap).total_values_length()
+
+
+cdef class DictionaryArray(Array):
+ """
+ Concrete class for dictionary-encoded Arrow arrays.
+ """
+
+ def dictionary_encode(self):
+ return self
+
+ def dictionary_decode(self):
+ """
+ Decodes the DictionaryArray to an Array.
+ """
+ return self.dictionary.take(self.indices)
+
+ @property
+ def dictionary(self):
+ cdef CDictionaryArray* darr = <CDictionaryArray*>(self.ap)
+
+ if self._dictionary is None:
+ self._dictionary = pyarrow_wrap_array(darr.dictionary())
+
+ return self._dictionary
+
+ @property
+ def indices(self):
+ cdef CDictionaryArray* darr = <CDictionaryArray*>(self.ap)
+
+ if self._indices is None:
+ self._indices = pyarrow_wrap_array(darr.indices())
+
+ return self._indices
+
+ @staticmethod
+ def from_arrays(indices, dictionary, mask=None, bint ordered=False,
+ bint from_pandas=False, bint safe=True,
+ MemoryPool memory_pool=None):
+ """
+ Construct a DictionaryArray from indices and values.
+
+ Parameters
+ ----------
+ indices : pyarrow.Array, numpy.ndarray or pandas.Series, int type
+ Non-negative integers referencing the dictionary values by zero-based
+ index.
+ dictionary : pyarrow.Array, ndarray or pandas.Series
+ The array of values referenced by the indices.
+ mask : ndarray or pandas.Series, bool type
+ True values indicate that indices are actually null.
+ from_pandas : bool, default False
+ If True, the indices should be treated as though they originated in
+ a pandas.Categorical (null encoded as -1).
+ ordered : bool, default False
+ Set to True if the category values are ordered.
+ safe : bool, default True
+ If True, check that the dictionary indices are in range.
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise uses default pool.
+
+ Returns
+ -------
+ dict_array : DictionaryArray
+ """
+ cdef:
+ Array _indices, _dictionary
+ shared_ptr[CDataType] c_type
+ shared_ptr[CArray] c_result
+
+ if isinstance(indices, Array):
+ if mask is not None:
+ raise NotImplementedError(
+ "mask not implemented with Arrow array inputs yet")
+ _indices = indices
+ else:
+ if from_pandas:
+ _indices = _codes_to_indices(indices, mask, None, memory_pool)
+ else:
+ _indices = array(indices, mask=mask, memory_pool=memory_pool)
+
+ if isinstance(dictionary, Array):
+ _dictionary = dictionary
+ else:
+ _dictionary = array(dictionary, memory_pool=memory_pool)
+
+ if not isinstance(_indices, IntegerArray):
+ raise ValueError('Indices must be integer type')
+
+ cdef c_bool c_ordered = ordered
+
+ c_type.reset(new CDictionaryType(_indices.type.sp_type,
+ _dictionary.sp_array.get().type(),
+ c_ordered))
+
+ if safe:
+ with nogil:
+ c_result = GetResultValue(
+ CDictionaryArray.FromArrays(c_type, _indices.sp_array,
+ _dictionary.sp_array))
+ else:
+ c_result.reset(new CDictionaryArray(c_type, _indices.sp_array,
+ _dictionary.sp_array))
+
+ cdef Array result = pyarrow_wrap_array(c_result)
+ result.validate()
+ return result
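+
+ # A minimal usage sketch, assuming ``import pyarrow as pa``:
+ #
+ # indices = pa.array([0, 1, 0, 2])
+ # dictionary = pa.array(['a', 'b', 'c'])
+ # d = pa.DictionaryArray.from_arrays(indices, dictionary)
+ # # expected: d.dictionary_decode() gives ['a', 'b', 'a', 'c']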
+
+
+cdef class StructArray(Array):
+ """
+ Concrete class for Arrow arrays of a struct data type.
+ """
+
+ def field(self, index):
+ """
+ Retrieves the child array belonging to field.
+
+ Parameters
+ ----------
+ index : Union[int, str]
+ Index / position or name of the field.
+
+ Returns
+ -------
+ result : Array
+ """
+ cdef:
+ CStructArray* arr = <CStructArray*> self.ap
+ shared_ptr[CArray] child
+
+ if isinstance(index, (bytes, str)):
+ child = arr.GetFieldByName(tobytes(index))
+ if child == nullptr:
+ raise KeyError(index)
+ elif isinstance(index, int):
+ child = arr.field(
+ <int>_normalize_index(index, self.ap.num_fields()))
+ else:
+ raise TypeError('Expected integer or string index')
+
+ return pyarrow_wrap_array(child)
+
+ def flatten(self, MemoryPool memory_pool=None):
+ """
+ Return one individual array for each field in the struct.
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise use default pool.
+
+ Returns
+ -------
+ result : List[Array]
+ """
+ cdef:
+ vector[shared_ptr[CArray]] arrays
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+ CStructArray* sarr = <CStructArray*> self.ap
+
+ with nogil:
+ arrays = GetResultValue(sarr.Flatten(pool))
+
+ return [pyarrow_wrap_array(arr) for arr in arrays]
+
+ @staticmethod
+ def from_arrays(arrays, names=None, fields=None, mask=None,
+ memory_pool=None):
+ """
+ Construct StructArray from collection of arrays representing
+ each field in the struct.
+
+ Either field names or field instances must be passed.
+
+ Parameters
+ ----------
+ arrays : sequence of Array
+ names : List[str] (optional)
+ Field names for each struct child.
+ fields : List[Field] (optional)
+ Field instances for each struct child.
+ mask : pyarrow.Array[bool] (optional)
+ Indicate which values are null (True) or not null (False).
+ memory_pool : MemoryPool (optional)
+ For memory allocations, if required, otherwise uses default pool.
+
+ Returns
+ -------
+ result : StructArray
+ """
+ cdef:
+ shared_ptr[CArray] c_array
+ shared_ptr[CBuffer] c_mask
+ vector[shared_ptr[CArray]] c_arrays
+ vector[c_string] c_names
+ vector[shared_ptr[CField]] c_fields
+ CResult[shared_ptr[CArray]] c_result
+ ssize_t num_arrays
+ ssize_t length
+ ssize_t i
+ Field py_field
+ DataType struct_type
+
+ if names is None and fields is None:
+ raise ValueError('Must pass either names or fields')
+ if names is not None and fields is not None:
+ raise ValueError('Must pass either names or fields, not both')
+
+ if mask is None:
+ c_mask = shared_ptr[CBuffer]()
+ elif isinstance(mask, Array):
+ if mask.type.id != Type_BOOL:
+ raise ValueError('Mask must be a pyarrow.Array of type bool')
+ if mask.null_count != 0:
+ raise ValueError('Mask must not contain nulls')
+ inverted_mask = _pc().invert(mask, memory_pool=memory_pool)
+ c_mask = pyarrow_unwrap_buffer(inverted_mask.buffers()[1])
+ else:
+ raise ValueError('Mask must be a pyarrow.Array of type bool')
+
+ arrays = [asarray(x) for x in arrays]
+ for arr in arrays:
+ c_array = pyarrow_unwrap_array(arr)
+ if c_array == nullptr:
+ raise TypeError(f"Expected Array, got {arr.__class__}")
+ c_arrays.push_back(c_array)
+ if names is not None:
+ for name in names:
+ c_names.push_back(tobytes(name))
+ else:
+ for item in fields:
+ if isinstance(item, tuple):
+ py_field = field(*item)
+ else:
+ py_field = item
+ c_fields.push_back(py_field.sp_field)
+
+ if (c_arrays.size() == 0 and c_names.size() == 0 and
+ c_fields.size() == 0):
+ # The C++ side doesn't allow this
+ return array([], struct([]))
+
+ if names is not None:
+ # XXX Cannot pass "nullptr" for a shared_ptr<T> argument:
+ # https://github.com/cython/cython/issues/3020
+ c_result = CStructArray.MakeFromFieldNames(
+ c_arrays, c_names, c_mask, -1, 0)
+ else:
+ c_result = CStructArray.MakeFromFields(
+ c_arrays, c_fields, c_mask, -1, 0)
+ cdef Array result = pyarrow_wrap_array(GetResultValue(c_result))
+ result.validate()
+ return result
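+
+ # A minimal usage sketch, assuming ``import pyarrow as pa``; either field
+ # names or Field instances can be given:
+ #
+ # s = pa.StructArray.from_arrays(
+ # [pa.array([1, 2]), pa.array(['a', 'b'])], names=['x', 'y'])
+ # # expected: [{'x': 1, 'y': 'a'}, {'x': 2, 'y': 'b'}]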
+
+
+cdef class ExtensionArray(Array):
+ """
+ Concrete class for Arrow extension arrays.
+ """
+
+ @property
+ def storage(self):
+ cdef:
+ CExtensionArray* ext_array = <CExtensionArray*>(self.ap)
+
+ return pyarrow_wrap_array(ext_array.storage())
+
+ @staticmethod
+ def from_storage(BaseExtensionType typ, Array storage):
+ """
+ Construct ExtensionArray from type and storage array.
+
+ Parameters
+ ----------
+ typ : DataType
+ The extension type for the result array.
+ storage : Array
+ The underlying storage for the result array.
+
+ Returns
+ -------
+ ext_array : ExtensionArray
+ """
+ cdef:
+ shared_ptr[CExtensionArray] ext_array
+
+ if storage.type != typ.storage_type:
+ raise TypeError("Incompatible storage type {0} "
+ "for extension type {1}".format(storage.type, typ))
+
+ ext_array = make_shared[CExtensionArray](typ.sp_type, storage.sp_array)
+ cdef Array result = pyarrow_wrap_array(<shared_ptr[CArray]> ext_array)
+ result.validate()
+ return result
+
+ def _to_pandas(self, options, **kwargs):
+ pandas_dtype = None
+ try:
+ pandas_dtype = self.type.to_pandas_dtype()
+ except NotImplementedError:
+ pass
+
+ # pandas ExtensionDtype that implements conversion from pyarrow
+ if hasattr(pandas_dtype, '__from_arrow__'):
+ arr = pandas_dtype.__from_arrow__(self)
+ return pandas_api.series(arr)
+
+ # otherwise convert the storage array with the base implementation
+ return Array._to_pandas(self.storage, options, **kwargs)
+
+ def to_numpy(self, **kwargs):
+ """
+ Convert extension array to a numpy ndarray.
+
+ See Also
+ --------
+ Array.to_numpy
+ """
+ return self.storage.to_numpy(**kwargs)
+
+
+cdef dict _array_classes = {
+ _Type_NA: NullArray,
+ _Type_BOOL: BooleanArray,
+ _Type_UINT8: UInt8Array,
+ _Type_UINT16: UInt16Array,
+ _Type_UINT32: UInt32Array,
+ _Type_UINT64: UInt64Array,
+ _Type_INT8: Int8Array,
+ _Type_INT16: Int16Array,
+ _Type_INT32: Int32Array,
+ _Type_INT64: Int64Array,
+ _Type_DATE32: Date32Array,
+ _Type_DATE64: Date64Array,
+ _Type_TIMESTAMP: TimestampArray,
+ _Type_TIME32: Time32Array,
+ _Type_TIME64: Time64Array,
+ _Type_DURATION: DurationArray,
+ _Type_INTERVAL_MONTH_DAY_NANO: MonthDayNanoIntervalArray,
+ _Type_HALF_FLOAT: HalfFloatArray,
+ _Type_FLOAT: FloatArray,
+ _Type_DOUBLE: DoubleArray,
+ _Type_LIST: ListArray,
+ _Type_LARGE_LIST: LargeListArray,
+ _Type_MAP: MapArray,
+ _Type_FIXED_SIZE_LIST: FixedSizeListArray,
+ _Type_SPARSE_UNION: UnionArray,
+ _Type_DENSE_UNION: UnionArray,
+ _Type_BINARY: BinaryArray,
+ _Type_STRING: StringArray,
+ _Type_LARGE_BINARY: LargeBinaryArray,
+ _Type_LARGE_STRING: LargeStringArray,
+ _Type_DICTIONARY: DictionaryArray,
+ _Type_FIXED_SIZE_BINARY: FixedSizeBinaryArray,
+ _Type_DECIMAL128: Decimal128Array,
+ _Type_DECIMAL256: Decimal256Array,
+ _Type_STRUCT: StructArray,
+ _Type_EXTENSION: ExtensionArray,
+}
+
+
+cdef object get_array_class_from_type(
+ const shared_ptr[CDataType]& sp_data_type):
+ cdef CDataType* data_type = sp_data_type.get()
+ if data_type == NULL:
+ raise ValueError('Array data type was NULL')
+
+ if data_type.id() == _Type_EXTENSION:
+ py_ext_data_type = pyarrow_wrap_data_type(sp_data_type)
+ return py_ext_data_type.__arrow_ext_class__()
+ else:
+ return _array_classes[data_type.id()]
+
+
+cdef object get_values(object obj, bint* is_series):
+ if pandas_api.is_series(obj) or pandas_api.is_index(obj):
+ result = pandas_api.get_values(obj)
+ is_series[0] = True
+ elif isinstance(obj, np.ndarray):
+ result = obj
+ is_series[0] = False
+ else:
+ result = pandas_api.series(obj).values
+ is_series[0] = False
+
+ return result
+
+
+def concat_arrays(arrays, MemoryPool memory_pool=None):
+ """
+ Concatenate the given arrays.
+
+ The contents of the input arrays are copied into the returned array.
+
+ Raises
+ ------
+ ArrowInvalid : if not all of the arrays have the same type.
+
+ Parameters
+ ----------
+ arrays : iterable of pyarrow.Array
+ Arrays to concatenate, must be identically typed.
+ memory_pool : MemoryPool, default None
+ For memory allocations. If None, the default pool is used.
+ """
+ cdef:
+ vector[shared_ptr[CArray]] c_arrays
+ shared_ptr[CArray] c_concatenated
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+
+ for array in arrays:
+ if not isinstance(array, Array):
+ raise TypeError("Iterable should contain Array objects, "
+ "got {0} instead".format(type(array)))
+ c_arrays.push_back(pyarrow_unwrap_array(array))
+
+ with nogil:
+ c_concatenated = GetResultValue(Concatenate(c_arrays, pool))
+
+ return pyarrow_wrap_array(c_concatenated)
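+
+# A short sketch, assuming ``import pyarrow as pa``; all inputs must share
+# a single type:
+#
+# pa.concat_arrays([pa.array([1, 2]), pa.array([3, 4])])
+# # expected: [1, 2, 3, 4]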
+
+
+def _empty_array(DataType type):
+ """
+ Create empty array of the given type.
+ """
+ if type.id == Type_DICTIONARY:
+ arr = DictionaryArray.from_arrays(
+ _empty_array(type.index_type), _empty_array(type.value_type),
+ ordered=type.ordered)
+ else:
+ arr = array([], type=type)
+ return arr
diff --git a/src/arrow/python/pyarrow/benchmark.pxi b/src/arrow/python/pyarrow/benchmark.pxi
new file mode 100644
index 000000000..ab251017d
--- /dev/null
+++ b/src/arrow/python/pyarrow/benchmark.pxi
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+def benchmark_PandasObjectIsNull(list obj):
+ Benchmark_PandasObjectIsNull(obj)
diff --git a/src/arrow/python/pyarrow/benchmark.py b/src/arrow/python/pyarrow/benchmark.py
new file mode 100644
index 000000000..25ee1141f
--- /dev/null
+++ b/src/arrow/python/pyarrow/benchmark.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# flake8: noqa
+
+
+from pyarrow.lib import benchmark_PandasObjectIsNull
diff --git a/src/arrow/python/pyarrow/builder.pxi b/src/arrow/python/pyarrow/builder.pxi
new file mode 100644
index 000000000..a34ea5412
--- /dev/null
+++ b/src/arrow/python/pyarrow/builder.pxi
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+cdef class StringBuilder(_Weakrefable):
+ """
+ Builder class for UTF8 strings.
+
+ This class exposes facilities for incrementally adding string values and
+ building the null bitmap for a pyarrow.Array (type='string').
+ """
+ cdef:
+ unique_ptr[CStringBuilder] builder
+
+ def __cinit__(self, MemoryPool memory_pool=None):
+ cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+ self.builder.reset(new CStringBuilder(pool))
+
+ def append(self, value):
+ """
+ Append a single value to the builder.
+
+ The value can either be a string/bytes object or a null value
+ (np.nan or None).
+
+ Parameters
+ ----------
+ value : string/bytes or np.nan/None
+ The value to append to the string array builder.
+ """
+ if value is None or value is np.nan:
+ self.builder.get().AppendNull()
+ elif isinstance(value, (bytes, str)):
+ self.builder.get().Append(tobytes(value))
+ else:
+ raise TypeError('StringBuilder only accepts string objects')
+
+ def append_values(self, values):
+ """
+ Append all the values from an iterable.
+
+ Parameters
+ ----------
+ values : iterable of string/bytes or np.nan/None values
+ The values to append to the string array builder.
+ """
+ for value in values:
+ self.append(value)
+
+ def finish(self):
+ """
+ Return result of builder as an Array object; also resets the builder.
+
+ Returns
+ -------
+ array : pyarrow.Array
+ """
+ cdef shared_ptr[CArray] out
+ with nogil:
+ self.builder.get().Finish(&out)
+ return pyarrow_wrap_array(out)
+
+ @property
+ def null_count(self):
+ return self.builder.get().null_count()
+
+ def __len__(self):
+ return self.builder.get().length()
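+
+ # A minimal usage sketch; it is assumed here that StringBuilder is
+ # reachable as ``pyarrow.lib.StringBuilder``:
+ #
+ # from pyarrow.lib import StringBuilder
+ # b = StringBuilder()
+ # b.append('foo')
+ # b.append(None)
+ # b.append_values(['bar', 'baz'])
+ # arr = b.finish()
+ # # expected: arr is ['foo', null, 'bar', 'baz'] with null_count == 1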
diff --git a/src/arrow/python/pyarrow/cffi.py b/src/arrow/python/pyarrow/cffi.py
new file mode 100644
index 000000000..961b61dee
--- /dev/null
+++ b/src/arrow/python/pyarrow/cffi.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import
+
+import cffi
+
+c_source = """
+ struct ArrowSchema {
+ // Array type description
+ const char* format;
+ const char* name;
+ const char* metadata;
+ int64_t flags;
+ int64_t n_children;
+ struct ArrowSchema** children;
+ struct ArrowSchema* dictionary;
+
+ // Release callback
+ void (*release)(struct ArrowSchema*);
+ // Opaque producer-specific data
+ void* private_data;
+ };
+
+ struct ArrowArray {
+ // Array data description
+ int64_t length;
+ int64_t null_count;
+ int64_t offset;
+ int64_t n_buffers;
+ int64_t n_children;
+ const void** buffers;
+ struct ArrowArray** children;
+ struct ArrowArray* dictionary;
+
+ // Release callback
+ void (*release)(struct ArrowArray*);
+ // Opaque producer-specific data
+ void* private_data;
+ };
+
+ struct ArrowArrayStream {
+ int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out);
+ int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out);
+
+ const char* (*get_last_error)(struct ArrowArrayStream*);
+
+ // Release callback
+ void (*release)(struct ArrowArrayStream*);
+ // Opaque producer-specific data
+ void* private_data;
+ };
+ """
+
+# TODO use out-of-line mode for faster import and avoid C parsing
+ffi = cffi.FFI()
+ffi.cdef(c_source)
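+
+# A sketch of how this ffi object can be combined with the C data interface
+# methods defined on Array (``_export_to_c`` / ``_import_from_c``), assuming
+# ``import pyarrow as pa``:
+#
+# c_schema = ffi.new('struct ArrowSchema*')
+# c_array = ffi.new('struct ArrowArray*')
+# schema_ptr = int(ffi.cast('uintptr_t', c_schema))
+# array_ptr = int(ffi.cast('uintptr_t', c_array))
+# arr = pa.array([1, 2, 3])
+# arr._export_to_c(array_ptr, schema_ptr)
+# roundtripped = pa.Array._import_from_c(array_ptr, schema_ptr)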
diff --git a/src/arrow/python/pyarrow/compat.pxi b/src/arrow/python/pyarrow/compat.pxi
new file mode 100644
index 000000000..a5db5741b
--- /dev/null
+++ b/src/arrow/python/pyarrow/compat.pxi
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+
+def encode_file_path(path):
+ if isinstance(path, str):
+ # POSIX systems can handle utf-8. UTF8 is converted to utf16-le in
+ # libarrow
+ encoded_path = path.encode('utf-8')
+ else:
+ encoded_path = path
+
+ # Windows file system requires utf-16le for file names; Arrow C++ libraries
+ # will convert utf8 to utf16
+ return encoded_path
+
+
+if sys.version_info >= (3, 7):
+ # Starting with Python 3.7, dicts are guaranteed to be insertion-ordered.
+ ordered_dict = dict
+else:
+ import collections
+ ordered_dict = collections.OrderedDict
+
+
+try:
+ import pickle5 as builtin_pickle
+except ImportError:
+ import pickle as builtin_pickle
+
+
+try:
+ import cloudpickle as pickle
+except ImportError:
+ pickle = builtin_pickle
+
+
+def tobytes(o):
+ if isinstance(o, str):
+ return o.encode('utf8')
+ else:
+ return o
+
+
+def frombytes(o, *, safe=False):
+ if safe:
+ return o.decode('utf8', errors='replace')
+ else:
+ return o.decode('utf8')
diff --git a/src/arrow/python/pyarrow/compat.py b/src/arrow/python/pyarrow/compat.py
new file mode 100644
index 000000000..814a51bef
--- /dev/null
+++ b/src/arrow/python/pyarrow/compat.py
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# flake8: noqa
+
+import pyarrow.util as util
+import warnings
+
+
+warnings.warn("pyarrow.compat has been deprecated and will be removed in a "
+ "future release", FutureWarning)
+
+
+guid = util._deprecate_api("compat.guid", "util.guid",
+ util.guid, "1.0.0")
diff --git a/src/arrow/python/pyarrow/compute.py b/src/arrow/python/pyarrow/compute.py
new file mode 100644
index 000000000..6e3bd7fca
--- /dev/null
+++ b/src/arrow/python/pyarrow/compute.py
@@ -0,0 +1,759 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyarrow._compute import ( # noqa
+ Function,
+ FunctionOptions,
+ FunctionRegistry,
+ HashAggregateFunction,
+ HashAggregateKernel,
+ Kernel,
+ ScalarAggregateFunction,
+ ScalarAggregateKernel,
+ ScalarFunction,
+ ScalarKernel,
+ VectorFunction,
+ VectorKernel,
+ # Option classes
+ ArraySortOptions,
+ AssumeTimezoneOptions,
+ CastOptions,
+ CountOptions,
+ DayOfWeekOptions,
+ DictionaryEncodeOptions,
+ ElementWiseAggregateOptions,
+ ExtractRegexOptions,
+ FilterOptions,
+ IndexOptions,
+ JoinOptions,
+ MakeStructOptions,
+ MatchSubstringOptions,
+ ModeOptions,
+ NullOptions,
+ PadOptions,
+ PartitionNthOptions,
+ QuantileOptions,
+ ReplaceSliceOptions,
+ ReplaceSubstringOptions,
+ RoundOptions,
+ RoundToMultipleOptions,
+ ScalarAggregateOptions,
+ SelectKOptions,
+ SetLookupOptions,
+ SliceOptions,
+ SortOptions,
+ SplitOptions,
+ SplitPatternOptions,
+ StrftimeOptions,
+ StrptimeOptions,
+ TakeOptions,
+ TDigestOptions,
+ TrimOptions,
+ VarianceOptions,
+ WeekOptions,
+ # Functions
+ call_function,
+ function_registry,
+ get_function,
+ list_functions,
+)
+
+import inspect
+from textwrap import dedent
+import warnings
+
+import pyarrow as pa
+
+
+def _get_arg_names(func):
+ return func._doc.arg_names
+
+
+def _decorate_compute_function(wrapper, exposed_name, func, option_class):
+ # Decorate the given compute function wrapper with useful metadata
+ # and documentation.
+ wrapper.__arrow_compute_function__ = dict(name=func.name,
+ arity=func.arity)
+ wrapper.__name__ = exposed_name
+ wrapper.__qualname__ = exposed_name
+
+ doc_pieces = []
+
+ cpp_doc = func._doc
+ summary = cpp_doc.summary
+ if not summary:
+ arg_str = "arguments" if func.arity > 1 else "argument"
+ summary = ("Call compute function {!r} with the given {}"
+ .format(func.name, arg_str))
+
+ description = cpp_doc.description
+ arg_names = _get_arg_names(func)
+
+ doc_pieces.append("""\
+ {}.
+
+ """.format(summary))
+
+ if description:
+ doc_pieces.append("{}\n\n".format(description))
+
+ doc_pieces.append("""\
+ Parameters
+ ----------
+ """)
+
+ for arg_name in arg_names:
+ if func.kind in ('vector', 'scalar_aggregate'):
+ arg_type = 'Array-like'
+ else:
+ arg_type = 'Array-like or scalar-like'
+ doc_pieces.append("""\
+ {} : {}
+ Argument to compute function
+ """.format(arg_name, arg_type))
+
+ doc_pieces.append("""\
+ memory_pool : pyarrow.MemoryPool, optional
+ If not passed, will allocate memory from the default memory pool.
+ """)
+ if option_class is not None:
+ doc_pieces.append("""\
+ options : pyarrow.compute.{0}, optional
+ Parameters altering compute function semantics.
+ """.format(option_class.__name__))
+ options_sig = inspect.signature(option_class)
+ for p in options_sig.parameters.values():
+ doc_pieces.append("""\
+ {0} : optional
+ Parameter for {1} constructor. Either `options`
+ or `{0}` can be passed, but not both at the same time.
+ """.format(p.name, option_class.__name__))
+
+ wrapper.__doc__ = "".join(dedent(s) for s in doc_pieces)
+ return wrapper
+
+
+def _get_options_class(func):
+ class_name = func._doc.options_class
+ if not class_name:
+ return None
+ try:
+ return globals()[class_name]
+ except KeyError:
+ warnings.warn("Python binding for {} not exposed"
+ .format(class_name), RuntimeWarning)
+ return None
+
+
+def _handle_options(name, option_class, options, kwargs):
+ if kwargs:
+ if options is None:
+ return option_class(**kwargs)
+ raise TypeError(
+ "Function {!r} called with both an 'options' argument "
+ "and additional named arguments"
+ .format(name))
+
+ if options is not None:
+ if isinstance(options, dict):
+ return option_class(**options)
+ elif isinstance(options, option_class):
+ return options
+ raise TypeError(
+ "Function {!r} expected a {} parameter, got {}"
+ .format(name, option_class, type(options)))
+
+ return options
+
+
+def _make_generic_wrapper(func_name, func, option_class):
+ if option_class is None:
+ def wrapper(*args, memory_pool=None):
+ return func.call(args, None, memory_pool)
+ else:
+ def wrapper(*args, memory_pool=None, options=None, **kwargs):
+ options = _handle_options(func_name, option_class, options,
+ kwargs)
+ return func.call(args, options, memory_pool)
+ return wrapper
+
+
+def _make_signature(arg_names, var_arg_names, option_class):
+ from inspect import Parameter
+ params = []
+ for name in arg_names:
+ params.append(Parameter(name, Parameter.POSITIONAL_OR_KEYWORD))
+ for name in var_arg_names:
+ params.append(Parameter(name, Parameter.VAR_POSITIONAL))
+ params.append(Parameter("memory_pool", Parameter.KEYWORD_ONLY,
+ default=None))
+ if option_class is not None:
+ params.append(Parameter("options", Parameter.KEYWORD_ONLY,
+ default=None))
+ options_sig = inspect.signature(option_class)
+ for p in options_sig.parameters.values():
+ # XXX for now, our generic wrappers don't allow positional
+ # option arguments
+ params.append(p.replace(kind=Parameter.KEYWORD_ONLY))
+ return inspect.Signature(params)
+
+
+def _wrap_function(name, func):
+ option_class = _get_options_class(func)
+ arg_names = _get_arg_names(func)
+ has_vararg = arg_names and arg_names[-1].startswith('*')
+ if has_vararg:
+ var_arg_names = [arg_names.pop().lstrip('*')]
+ else:
+ var_arg_names = []
+
+ wrapper = _make_generic_wrapper(name, func, option_class)
+ wrapper.__signature__ = _make_signature(arg_names, var_arg_names,
+ option_class)
+ return _decorate_compute_function(wrapper, name, func, option_class)
+
+
+def _make_global_functions():
+ """
+ Make global functions wrapping each compute function.
+
+ Note that some of the automatically-generated wrappers may be overridden
+ by custom versions below.
+ """
+ g = globals()
+ reg = function_registry()
+
+ # Avoid clashes with Python keywords
+ rewrites = {'and': 'and_',
+ 'or': 'or_'}
+
+ for cpp_name in reg.list_functions():
+ name = rewrites.get(cpp_name, cpp_name)
+ func = reg.get_function(cpp_name)
+ assert name not in g, name
+ g[cpp_name] = g[name] = _wrap_function(name, func)
+
+
+_make_global_functions()
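+
+
+# After the registry walk above, each registered compute function is exposed
+# as a module-level callable, under both its original name and, for keyword
+# clashes, the rewritten alias.  Illustrative usage (the exact set of
+# functions depends on the Arrow build):
+#
+#   >>> import pyarrow as pa
+#   >>> import pyarrow.compute as pc
+#   >>> pc.and_(pa.array([True, False]), pa.array([True, True]))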
+
+
+def cast(arr, target_type, safe=True):
+ """
+ Cast array values to another data type. Can also be invoked as an array
+ instance method.
+
+ Parameters
+ ----------
+ arr : Array or ChunkedArray
+ target_type : DataType or type string alias
+ Type to cast to
+ safe : bool, default True
+ Check for overflows or other unsafe conversions
+
+ Examples
+ --------
+ >>> from datetime import datetime
+ >>> import pyarrow as pa
+ >>> arr = pa.array([datetime(2010, 1, 1), datetime(2015, 1, 1)])
+ >>> arr.type
+ TimestampType(timestamp[us])
+
+ You can use ``pyarrow.DataType`` objects to specify the target type:
+
+ >>> cast(arr, pa.timestamp('ms'))
+ <pyarrow.lib.TimestampArray object at 0x7fe93c0f6910>
+ [
+ 2010-01-01 00:00:00.000,
+ 2015-01-01 00:00:00.000
+ ]
+
+ >>> cast(arr, pa.timestamp('ms')).type
+ TimestampType(timestamp[ms])
+
+ Alternatively, you can use the string aliases for these types:
+
+ >>> arr.cast('timestamp[ms]')
+ <pyarrow.lib.TimestampArray object at 0x10420eb88>
+ [
+ 1262304000000,
+ 1420070400000
+ ]
+ >>> arr.cast('timestamp[ms]').type
+ TimestampType(timestamp[ms])
+
+ Returns
+ -------
+ casted : Array
+ """
+ if target_type is None:
+ raise ValueError("Cast target type must not be None")
+ if safe:
+ options = CastOptions.safe(target_type)
+ else:
+ options = CastOptions.unsafe(target_type)
+ return call_function("cast", [arr], options)
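+
+
+# Hedged sketch of the unsafe path above: with safe=False a lossy cast
+# truncates rather than raising (values and output are illustrative):
+#
+#   >>> cast(pa.array([1.5, 2.7]), pa.int32(), safe=False)
+#   [1, 2]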
+
+
+def count_substring(array, pattern, *, ignore_case=False):
+ """
+ Count the occurrences of substring *pattern* in each value of a
+ string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ pattern to search for exact matches
+ ignore_case : bool, default False
+ Ignore case while searching.
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ """
+ return call_function("count_substring", [array],
+ MatchSubstringOptions(pattern,
+ ignore_case=ignore_case))
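+
+
+# Illustrative example for the wrapper above (exact repr depends on the
+# pyarrow version; the counts shown are indicative):
+#
+#   >>> count_substring(pa.array(["abcab", "cab", None]), "ab")
+#   [2, 1, null]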
+
+
+def count_substring_regex(array, pattern, *, ignore_case=False):
+ """
+ Count the non-overlapping matches of regex *pattern* in each value
+ of a string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ regex pattern to search for
+ ignore_case : bool, default False
+ Ignore case while searching.
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ """
+ return call_function("count_substring_regex", [array],
+ MatchSubstringOptions(pattern,
+ ignore_case=ignore_case))
+
+
+def find_substring(array, pattern, *, ignore_case=False):
+ """
+ Find the index of the first occurrence of substring *pattern* in each
+ value of a string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ pattern to search for exact matches
+ ignore_case : bool, default False
+ Ignore case while searching.
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ """
+ return call_function("find_substring", [array],
+ MatchSubstringOptions(pattern,
+ ignore_case=ignore_case))
+
+
+def find_substring_regex(array, pattern, *, ignore_case=False):
+ """
+ Find the index of the first match of regex *pattern* in each
+ value of a string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ regex pattern to search for
+ ignore_case : bool, default False
+ Ignore case while searching.
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ """
+ return call_function("find_substring_regex", [array],
+ MatchSubstringOptions(pattern,
+ ignore_case=ignore_case))
+
+
+def match_like(array, pattern, *, ignore_case=False):
+ """
+ Test if the SQL-style LIKE pattern *pattern* matches each value of a
+ string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ SQL-style LIKE pattern. '%' will match any number of
+ characters, '_' will match exactly one character, and all
+ other characters match themselves. To match a literal percent
+ sign or underscore, precede the character with a backslash.
+ ignore_case : bool, default False
+ Ignore case while searching.
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+
+ """
+ return call_function("match_like", [array],
+ MatchSubstringOptions(pattern,
+ ignore_case=ignore_case))
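+
+
+# Hedged example of the LIKE semantics documented above ('%' matches any run
+# of characters, '_' exactly one); values and output are illustrative:
+#
+#   >>> match_like(pa.array(["Arrow", "Array", "Arc", None]), "Arr%")
+#   [true, true, false, null]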
+
+
+def match_substring(array, pattern, *, ignore_case=False):
+ """
+ Test if substring *pattern* is contained within each value of a string
+ array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ pattern to search for exact matches
+ ignore_case : bool, default False
+ Ignore case while searching.
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ """
+ return call_function("match_substring", [array],
+ MatchSubstringOptions(pattern,
+ ignore_case=ignore_case))
+
+
+def match_substring_regex(array, pattern, *, ignore_case=False):
+ """
+ Test if regex *pattern* matches at any position within each value of a
+ string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ regex pattern to search for
+ ignore_case : bool, default False
+ Ignore case while searching.
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ """
+ return call_function("match_substring_regex", [array],
+ MatchSubstringOptions(pattern,
+ ignore_case=ignore_case))
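+
+
+# Illustrative usage of the regex matcher above (Arrow's kernels use RE2
+# syntax; the output shown is indicative):
+#
+#   >>> match_substring_regex(pa.array(["ab12", "cd", None]), r"\d+")
+#   [true, false, null]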
+
+
+def mode(array, n=1, *, skip_nulls=True, min_count=0):
+ """
+ Return the top-n most common values and the number of times they occur in
+ a numerical (chunked) array, in descending order of occurrence. If there
+ are multiple values with the same count, the smaller one is returned first.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ n : int, default 1
+ Specify the top-n values.
+ skip_nulls : bool, default True
+ If True, ignore nulls in the input. Else return an empty array
+ if any input is null.
+ min_count : int, default 0
+ If there are fewer than this many values in the input, return
+ an empty array.
+
+ Returns
+ -------
+ An array of <input type "Mode", int64_t "Count"> structs
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> import pyarrow.compute as pc
+ >>> arr = pa.array([1, 1, 2, 2, 3, 2, 2, 2])
+ >>> modes = pc.mode(arr, 2)
+ >>> modes[0]
+ <pyarrow.StructScalar: {'mode': 2, 'count': 5}>
+ >>> modes[1]
+ <pyarrow.StructScalar: {'mode': 1, 'count': 2}>
+ """
+ options = ModeOptions(n, skip_nulls=skip_nulls, min_count=min_count)
+ return call_function("mode", [array], options)
+
+
+def filter(data, mask, null_selection_behavior='drop'):
+ """
+ Select values (or records) from array- or table-like data given boolean
+ filter, where true values are selected.
+
+ Parameters
+ ----------
+ data : Array, ChunkedArray, RecordBatch, or Table
+ mask : Array, ChunkedArray
+ Must be of boolean type
+ null_selection_behavior : str, default 'drop'
+ Configure the behavior on encountering a null slot in the mask.
+ Allowed values are 'drop' and 'emit_null'.
+
+ - 'drop': nulls will be treated as equivalent to False.
+ - 'emit_null': nulls will result in a null in the output.
+
+ Returns
+ -------
+ result : depends on inputs
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> arr = pa.array(["a", "b", "c", None, "e"])
+ >>> mask = pa.array([True, False, None, False, True])
+ >>> arr.filter(mask)
+ <pyarrow.lib.StringArray object at 0x7fa826df9200>
+ [
+ "a",
+ "e"
+ ]
+ >>> arr.filter(mask, null_selection_behavior='emit_null')
+ <pyarrow.lib.StringArray object at 0x7fa826df9200>
+ [
+ "a",
+ null,
+ "e"
+ ]
+ """
+ options = FilterOptions(null_selection_behavior)
+ return call_function('filter', [data, mask], options)
+
+
+def index(data, value, start=None, end=None, *, memory_pool=None):
+ """
+ Find the index of the first occurrence of a given value.
+
+ Parameters
+ ----------
+ data : Array or ChunkedArray
+ value : Scalar-like object
+ start : int, optional
+ end : int, optional
+ memory_pool : MemoryPool, optional
+ If not passed, will allocate memory from the default memory pool.
+
+ Returns
+ -------
+ index : the index, or -1 if not found
+ """
+ if start is not None:
+ if end is not None:
+ data = data.slice(start, end - start)
+ else:
+ data = data.slice(start)
+ elif end is not None:
+ data = data.slice(0, end)
+
+ if not isinstance(value, pa.Scalar):
+ value = pa.scalar(value, type=data.type)
+ elif data.type != value.type:
+ value = pa.scalar(value.as_py(), type=data.type)
+ options = IndexOptions(value=value)
+ result = call_function('index', [data], options, memory_pool)
+ if start is not None and result.as_py() >= 0:
+ result = pa.scalar(result.as_py() + start, type=pa.int64())
+ return result
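+
+
+# Hedged example of the slicing logic above: with a `start` offset the search
+# happens on the sliced data and the result is shifted back to the original
+# frame (values illustrative; results are returned as Int64 scalars):
+#
+#   >>> index(pa.array(["a", "b", "a"]), "a")            # -> 0
+#   >>> index(pa.array(["a", "b", "a"]), "a", start=1)   # -> 2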
+
+
+def take(data, indices, *, boundscheck=True, memory_pool=None):
+ """
+ Select values (or records) from array- or table-like data given integer
+ selection indices.
+
+ The result will be of the same type(s) as the input, with elements taken
+ from the input array (or record batch / table fields) at the given
+ indices. If an index is null then the corresponding value in the output
+ will be null.
+
+ Parameters
+ ----------
+ data : Array, ChunkedArray, RecordBatch, or Table
+ indices : Array, ChunkedArray
+ Must be of integer type
+ boundscheck : boolean, default True
+ Whether to boundscheck the indices. If False and there is an out of
+ bounds index, will likely cause the process to crash.
+ memory_pool : MemoryPool, optional
+ If not passed, will allocate memory from the default memory pool.
+
+ Returns
+ -------
+ result : depends on inputs
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> arr = pa.array(["a", "b", "c", None, "e", "f"])
+ >>> indices = pa.array([0, None, 4, 3])
+ >>> arr.take(indices)
+ <pyarrow.lib.StringArray object at 0x7ffa4fc7d368>
+ [
+ "a",
+ null,
+ "e",
+ null
+ ]
+ """
+ options = TakeOptions(boundscheck=boundscheck)
+ return call_function('take', [data, indices], options, memory_pool)
+
+
+def fill_null(values, fill_value):
+ """
+ Replace each null element in values with fill_value. The fill_value must
+ be of the same type as values, or able to be implicitly cast to the
+ array's type.
+
+ This is an alias for :func:`coalesce`.
+
+ Parameters
+ ----------
+ values : Array, ChunkedArray, or Scalar-like object
+ Each null element is replaced with the corresponding value
+ from fill_value.
+ fill_value : Array, ChunkedArray, or Scalar-like object
+ If not same type as data will attempt to cast.
+
+ Returns
+ -------
+ result : depends on inputs
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> arr = pa.array([1, 2, None, 3], type=pa.int8())
+ >>> fill_value = pa.scalar(5, type=pa.int8())
+ >>> arr.fill_null(fill_value)
+ <pyarrow.lib.Int8Array object at 0x7f95437f01a0>
+ [
+ 1,
+ 2,
+ 5,
+ 3
+ ]
+ """
+ if not isinstance(fill_value, (pa.Array, pa.ChunkedArray, pa.Scalar)):
+ fill_value = pa.scalar(fill_value, type=values.type)
+ elif values.type != fill_value.type:
+ fill_value = pa.scalar(fill_value.as_py(), type=values.type)
+
+ return call_function("coalesce", [values, fill_value])
+
+
+def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None):
+ """
+ Select the indices of the top-k ordered elements from array- or table-like
+ data.
+
+ This is a specialization for :func:`select_k_unstable`. Output is not
+ guaranteed to be stable.
+
+ Parameters
+ ----------
+ values : Array, ChunkedArray, RecordBatch, or Table
+ Data to sort and get top indices from.
+ k : int
+ The number of elements to keep, i.e. the `k` in "top-k".
+ sort_keys : List-like
+ Column key names to order by when input is table-like data.
+ memory_pool : MemoryPool, optional
+ If not passed, will allocate memory from the default memory pool.
+
+ Returns
+ -------
+ result : Array of indices
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> import pyarrow.compute as pc
+ >>> arr = pa.array(["a", "b", "c", None, "e", "f"])
+ >>> pc.top_k_unstable(arr, k=3)
+ <pyarrow.lib.UInt64Array object at 0x7fdcb19d7f30>
+ [
+ 5,
+ 4,
+ 2
+ ]
+ """
+ if sort_keys is None:
+ sort_keys = []
+ if isinstance(values, (pa.Array, pa.ChunkedArray)):
+ sort_keys.append(("dummy", "descending"))
+ else:
+ sort_keys = map(lambda key_name: (key_name, "descending"), sort_keys)
+ options = SelectKOptions(k, sort_keys)
+ return call_function("select_k_unstable", [values], options, memory_pool)
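+
+
+# For table-like input the sort keys are column names; a hedged sketch
+# (the column name "values" is hypothetical):
+#
+#   >>> table = pa.table({"values": [3, 1, 4, 1, 5]})
+#   >>> top_k_unstable(table, k=2, sort_keys=["values"])
+#   # -> indices of the two largest entries, e.g. [4, 2]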
+
+
+def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None):
+ """
+ Select the indices of the bottom-k ordered elements from
+ array- or table-like data.
+
+ This is a specialization for :func:`select_k_unstable`. Output is not
+ guaranteed to be stable.
+
+ Parameters
+ ----------
+ values : Array, ChunkedArray, RecordBatch, or Table
+ Data to sort and get bottom indices from.
+ k : int
+ The number of elements to keep, i.e. the `k` in "bottom-k".
+ sort_keys : List-like
+ Column key names to order by when input is table-like data.
+ memory_pool : MemoryPool, optional
+ If not passed, will allocate memory from the default memory pool.
+
+ Returns
+ -------
+ result : Array of indices
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> import pyarrow.compute as pc
+ >>> arr = pa.array(["a", "b", "c", None, "e", "f"])
+ >>> pc.bottom_k_unstable(arr, k=3)
+ <pyarrow.lib.UInt64Array object at 0x7fdcb19d7fa0>
+ [
+ 0,
+ 1,
+ 2
+ ]
+ """
+ if sort_keys is None:
+ sort_keys = []
+ if isinstance(values, (pa.Array, pa.ChunkedArray)):
+ sort_keys.append(("dummy", "ascending"))
+ else:
+ sort_keys = map(lambda key_name: (key_name, "ascending"), sort_keys)
+ options = SelectKOptions(k, sort_keys)
+ return call_function("select_k_unstable", [values], options, memory_pool)
diff --git a/src/arrow/python/pyarrow/config.pxi b/src/arrow/python/pyarrow/config.pxi
new file mode 100644
index 000000000..fc88c28d6
--- /dev/null
+++ b/src/arrow/python/pyarrow/config.pxi
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyarrow.includes.libarrow cimport GetBuildInfo
+
+from collections import namedtuple
+
+
+VersionInfo = namedtuple('VersionInfo', ('major', 'minor', 'patch'))
+
+BuildInfo = namedtuple(
+ 'BuildInfo',
+ ('version', 'version_info', 'so_version', 'full_so_version',
+ 'compiler_id', 'compiler_version', 'compiler_flags',
+ 'git_id', 'git_description', 'package_kind'))
+
+RuntimeInfo = namedtuple('RuntimeInfo',
+ ('simd_level', 'detected_simd_level'))
+
+cdef _build_info():
+ cdef:
+ const CBuildInfo* c_info
+
+ c_info = &GetBuildInfo()
+
+ return BuildInfo(version=frombytes(c_info.version_string),
+ version_info=VersionInfo(c_info.version_major,
+ c_info.version_minor,
+ c_info.version_patch),
+ so_version=frombytes(c_info.so_version),
+ full_so_version=frombytes(c_info.full_so_version),
+ compiler_id=frombytes(c_info.compiler_id),
+ compiler_version=frombytes(c_info.compiler_version),
+ compiler_flags=frombytes(c_info.compiler_flags),
+ git_id=frombytes(c_info.git_id),
+ git_description=frombytes(c_info.git_description),
+ package_kind=frombytes(c_info.package_kind))
+
+
+cpp_build_info = _build_info()
+cpp_version = cpp_build_info.version
+cpp_version_info = cpp_build_info.version_info
+
+
+def runtime_info():
+ """
+ Get runtime information.
+
+ Returns
+ -------
+ info : pyarrow.RuntimeInfo
+ """
+ cdef:
+ CRuntimeInfo c_info
+
+ c_info = GetRuntimeInfo()
+
+ return RuntimeInfo(
+ simd_level=frombytes(c_info.simd_level),
+ detected_simd_level=frombytes(c_info.detected_simd_level))
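+
+
+# Example of what the helpers above surface at runtime (field values are
+# illustrative; they depend on how the Arrow C++ library was built):
+#
+#   >>> import pyarrow as pa
+#   >>> pa.cpp_version                 # e.g. '6.0.1'
+#   >>> pa.runtime_info()              # e.g. RuntimeInfo(simd_level='avx2',
+#   ...                                #      detected_simd_level='avx2')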
diff --git a/src/arrow/python/pyarrow/csv.py b/src/arrow/python/pyarrow/csv.py
new file mode 100644
index 000000000..e073252cb
--- /dev/null
+++ b/src/arrow/python/pyarrow/csv.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from pyarrow._csv import ( # noqa
+ ReadOptions, ParseOptions, ConvertOptions, ISO8601,
+ open_csv, read_csv, CSVStreamingReader, write_csv,
+ WriteOptions, CSVWriter)
diff --git a/src/arrow/python/pyarrow/cuda.py b/src/arrow/python/pyarrow/cuda.py
new file mode 100644
index 000000000..18c530d4a
--- /dev/null
+++ b/src/arrow/python/pyarrow/cuda.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# flake8: noqa
+
+
+from pyarrow._cuda import (Context, IpcMemHandle, CudaBuffer,
+ HostBuffer, BufferReader, BufferWriter,
+ new_host_buffer,
+ serialize_record_batch, read_message,
+ read_record_batch)
diff --git a/src/arrow/python/pyarrow/dataset.py b/src/arrow/python/pyarrow/dataset.py
new file mode 100644
index 000000000..42515a9f4
--- /dev/null
+++ b/src/arrow/python/pyarrow/dataset.py
@@ -0,0 +1,881 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Dataset is currently unstable. APIs subject to change without notice."""
+
+import pyarrow as pa
+from pyarrow.util import _is_iterable, _stringify_path, _is_path_like
+
+from pyarrow._dataset import ( # noqa
+ CsvFileFormat,
+ CsvFragmentScanOptions,
+ Expression,
+ Dataset,
+ DatasetFactory,
+ DirectoryPartitioning,
+ FileFormat,
+ FileFragment,
+ FileSystemDataset,
+ FileSystemDatasetFactory,
+ FileSystemFactoryOptions,
+ FileWriteOptions,
+ Fragment,
+ HivePartitioning,
+ IpcFileFormat,
+ IpcFileWriteOptions,
+ InMemoryDataset,
+ ParquetDatasetFactory,
+ ParquetFactoryOptions,
+ ParquetFileFormat,
+ ParquetFileFragment,
+ ParquetFileWriteOptions,
+ ParquetFragmentScanOptions,
+ ParquetReadOptions,
+ Partitioning,
+ PartitioningFactory,
+ RowGroupInfo,
+ Scanner,
+ TaggedRecordBatch,
+ UnionDataset,
+ UnionDatasetFactory,
+ _get_partition_keys,
+ _filesystemdataset_write,
+)
+
+_orc_available = False
+_orc_msg = (
+ "The pyarrow installation is not built with support for the ORC file "
+ "format."
+)
+
+try:
+ from pyarrow._dataset_orc import OrcFileFormat
+ _orc_available = True
+except ImportError:
+ pass
+
+
+def __getattr__(name):
+ if name == "OrcFileFormat" and not _orc_available:
+ raise ImportError(_orc_msg)
+
+ raise AttributeError(
+ "module 'pyarrow.dataset' has no attribute '{0}'".format(name)
+ )
+
+
+def field(name):
+ """Reference a named column of the dataset.
+
+ Stores only the field's name. Type and other information is known only
+ when the expression is bound to a dataset having an explicit schema.
+
+ Parameters
+ ----------
+ name : string
+ The name of the field the expression references.
+
+ Returns
+ -------
+ field_expr : Expression
+ """
+ return Expression._field(name)
+
+
+def scalar(value):
+ """Expression representing a scalar value.
+
+ Parameters
+ ----------
+ value : bool, int, float or string
+ Python value of the scalar. Note that only a subset of types are
+ currently supported.
+
+ Returns
+ -------
+ scalar_expr : Expression
+ """
+ return Expression._scalar(value)
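+
+
+# Illustrative sketch of combining the two expression helpers above into a
+# dataset filter (the column names are hypothetical):
+#
+#   import pyarrow.dataset as ds
+#   expr = (ds.field("year") >= ds.scalar(2009)) & (ds.field("month") == 11)
+#   # later: dataset.to_table(filter=expr)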
+
+
+def partitioning(schema=None, field_names=None, flavor=None,
+ dictionaries=None):
+ """
+ Specify a partitioning scheme.
+
+ The supported schemes include:
+
+ - "DirectoryPartitioning": this scheme expects one segment in the file path
+ for each field in the specified schema (all fields are required to be
+ present). For example given schema<year:int16, month:int8> the path
+ "/2009/11" would be parsed to ("year"_ == 2009 and "month"_ == 11).
+ - "HivePartitioning": a scheme for "/$key=$value/" nested directories as
+ found in Apache Hive. This is a multi-level, directory based partitioning
+ scheme. Data is partitioned by static values of a particular column in
+ the schema. Partition keys are represented in the form $key=$value in
+ directory names. Field order is ignored, as are missing or unrecognized
+ field names.
+ For example, given schema<year:int16, month:int8, day:int8>, a possible
+ path would be "/year=2009/month=11/day=15" (but the field order does not
+ need to match).
+
+ Parameters
+ ----------
+ schema : pyarrow.Schema, default None
+ The schema that describes the partitions present in the file path.
+ If not specified, and `field_names` and/or `flavor` are specified,
+ the schema will be inferred from the file path (and a
+ PartitioningFactory is returned).
+ field_names : list of str, default None
+ A list of strings (field names). If specified, the schema's types are
+ inferred from the file paths (only valid for DirectoryPartitioning).
+ flavor : str, default None
+ The default is DirectoryPartitioning. Specify ``flavor="hive"`` for
+ a HivePartitioning.
+ dictionaries : Dict[str, Array]
+ If the type of any field of `schema` is a dictionary type, the
+ corresponding entry of `dictionaries` must be an array containing
+ every value which may be taken by the corresponding column or an
+ error will be raised in parsing. Alternatively, pass `infer` to have
+ Arrow discover the dictionary values, in which case a
+ PartitioningFactory is returned.
+
+ Returns
+ -------
+ Partitioning or PartitioningFactory
+
+ Examples
+ --------
+
+ Specify the Schema for paths like "/2009/June":
+
+ >>> partitioning(pa.schema([("year", pa.int16()), ("month", pa.string())]))
+
+ or let the types be inferred by only specifying the field names:
+
+ >>> partitioning(field_names=["year", "month"])
+
+ For paths like "/2009/June", the year will be inferred as int32 while month
+ will be inferred as string.
+
+ Specify a Schema with dictionary encoding, providing dictionary values:
+
+ >>> partitioning(
+ ... pa.schema([
+ ... ("year", pa.int16()),
+ ... ("month", pa.dictionary(pa.int8(), pa.string()))
+ ... ]),
+ ... dictionaries={
+ ... "month": pa.array(["January", "February", "March"]),
+ ... })
+
+ Alternatively, specify a Schema with dictionary encoding, but have Arrow
+ infer the dictionary values:
+
+ >>> partitioning(
+ ... pa.schema([
+ ... ("year", pa.int16()),
+ ... ("month", pa.dictionary(pa.int8(), pa.string()))
+ ... ]),
+ ... dictionaries="infer")
+
+ Create a Hive scheme for a path like "/year=2009/month=11":
+
+ >>> partitioning(
+ ... pa.schema([("year", pa.int16()), ("month", pa.int8())]),
+ ... flavor="hive")
+
+ A Hive scheme can also be discovered from the directory structure (and
+ types will be inferred):
+
+ >>> partitioning(flavor="hive")
+
+ """
+ if flavor is None:
+ # default flavor
+ if schema is not None:
+ if field_names is not None:
+ raise ValueError(
+ "Cannot specify both 'schema' and 'field_names'")
+ if dictionaries == 'infer':
+ return DirectoryPartitioning.discover(schema=schema)
+ return DirectoryPartitioning(schema, dictionaries)
+ elif field_names is not None:
+ if isinstance(field_names, list):
+ return DirectoryPartitioning.discover(field_names)
+ else:
+ raise ValueError(
+ "Expected list of field names, got {}".format(
+ type(field_names)))
+ else:
+ raise ValueError(
+ "For the default directory flavor, need to specify "
+ "a Schema or a list of field names")
+ elif flavor == 'hive':
+ if field_names is not None:
+ raise ValueError("Cannot specify 'field_names' for flavor 'hive'")
+ elif schema is not None:
+ if isinstance(schema, pa.Schema):
+ if dictionaries == 'infer':
+ return HivePartitioning.discover(schema=schema)
+ return HivePartitioning(schema, dictionaries)
+ else:
+ raise ValueError(
+ "Expected Schema for 'schema', got {}".format(
+ type(schema)))
+ else:
+ return HivePartitioning.discover()
+ else:
+ raise ValueError("Unsupported flavor")
+
+
+def _ensure_partitioning(scheme):
+ """
+ Validate input and return a Partitioning(Factory).
+
+ It passes None through if no partitioning scheme is defined.
+ """
+ if scheme is None:
+ pass
+ elif isinstance(scheme, str):
+ scheme = partitioning(flavor=scheme)
+ elif isinstance(scheme, list):
+ scheme = partitioning(field_names=scheme)
+ elif isinstance(scheme, (Partitioning, PartitioningFactory)):
+ pass
+ else:
+ ValueError("Expected Partitioning or PartitioningFactory, got {}"
+ .format(type(scheme)))
+ return scheme
+
+
+def _ensure_format(obj):
+ if isinstance(obj, FileFormat):
+ return obj
+ elif obj == "parquet":
+ return ParquetFileFormat()
+ elif obj in {"ipc", "arrow", "feather"}:
+ return IpcFileFormat()
+ elif obj == "csv":
+ return CsvFileFormat()
+ elif obj == "orc":
+ if not _orc_available:
+ raise ValueError(_orc_msg)
+ return OrcFileFormat()
+ else:
+ raise ValueError("format '{}' is not supported".format(obj))
+
+
+def _ensure_multiple_sources(paths, filesystem=None):
+ """
+ Treat a list of paths as files belonging to a single file system
+
+ If the file system is local, this also validates that all paths refer to
+ existing *files*; otherwise (for example on a remote filesystem) any
+ non-file paths will be silently skipped.
+
+ Parameters
+ ----------
+ paths : list of path-like
+ Note that URIs are not allowed.
+ filesystem : FileSystem or str, optional
+ If a URI is passed, then its path component will act as a prefix for
+ the file paths.
+
+ Returns
+ -------
+ (FileSystem, list of str)
+ File system object and a list of normalized paths.
+
+ Raises
+ ------
+ TypeError
+ If the passed filesystem has the wrong type.
+ IOError
+ If the file system is local and a referenced path is not available or
+ not a file.
+ """
+ from pyarrow.fs import (
+ LocalFileSystem, SubTreeFileSystem, _MockFileSystem, FileType,
+ _ensure_filesystem
+ )
+
+ if filesystem is None:
+ # fall back to local file system as the default
+ filesystem = LocalFileSystem()
+ else:
+ # construct a filesystem if it is a valid URI
+ filesystem = _ensure_filesystem(filesystem)
+
+ is_local = (
+ isinstance(filesystem, (LocalFileSystem, _MockFileSystem)) or
+ (isinstance(filesystem, SubTreeFileSystem) and
+ isinstance(filesystem.base_fs, LocalFileSystem))
+ )
+
+ # allow normalizing irregular paths such as Windows local paths
+ paths = [filesystem.normalize_path(_stringify_path(p)) for p in paths]
+
+ # validate that all of the paths are pointing to existing *files*
+ # possible improvement is to group the file_infos by type and raise for
+ # multiple paths per error category
+ if is_local:
+ for info in filesystem.get_file_info(paths):
+ file_type = info.type
+ if file_type == FileType.File:
+ continue
+ elif file_type == FileType.NotFound:
+ raise FileNotFoundError(info.path)
+ elif file_type == FileType.Directory:
+ raise IsADirectoryError(
+ 'Path {} points to a directory, but only file paths are '
+ 'supported. To construct a nested or union dataset pass '
+ 'a list of dataset objects instead.'.format(info.path)
+ )
+ else:
+ raise IOError(
+ 'Path {} exists but its type is unknown (could be a '
+ 'special file such as a Unix socket or character device, '
+ 'or Windows NUL / CON / ...)'.format(info.path)
+ )
+
+ return filesystem, paths
+
+
+def _ensure_single_source(path, filesystem=None):
+ """
+ Treat path as either a recursively traversable directory or a single file.
+
+ Parameters
+ ----------
+ path : path-like
+ filesystem : FileSystem or str, optional
+ If a URI is passed, then its path component will act as a prefix for
+ the file paths.
+
+ Returns
+ -------
+ (FileSystem, list of str or fs.Selector)
+ File system object and either a single item list pointing to a file or
+ an fs.Selector object pointing to a directory.
+
+ Raises
+ ------
+ TypeError
+ If the passed filesystem has the wrong type.
+ FileNotFoundError
+ If the referenced file or directory doesn't exist.
+ """
+ from pyarrow.fs import FileType, FileSelector, _resolve_filesystem_and_path
+
+ # at this point we already checked that `path` is a path-like
+ filesystem, path = _resolve_filesystem_and_path(path, filesystem)
+
+ # ensure that the path is normalized before passing to dataset discovery
+ path = filesystem.normalize_path(path)
+
+ # retrieve the file descriptor
+ file_info = filesystem.get_file_info(path)
+
+ # depending on the path type either return with a recursive
+ # directory selector or as a list containing a single file
+ if file_info.type == FileType.Directory:
+ paths_or_selector = FileSelector(path, recursive=True)
+ elif file_info.type == FileType.File:
+ paths_or_selector = [path]
+ else:
+ raise FileNotFoundError(path)
+
+ return filesystem, paths_or_selector
+
+
+def _filesystem_dataset(source, schema=None, filesystem=None,
+ partitioning=None, format=None,
+ partition_base_dir=None, exclude_invalid_files=None,
+ selector_ignore_prefixes=None):
+ """
+ Create a FileSystemDataset which can be used to build a Dataset.
+
+ Parameters are documented in the dataset function.
+
+ Returns
+ -------
+ FileSystemDataset
+ """
+ format = _ensure_format(format or 'parquet')
+ partitioning = _ensure_partitioning(partitioning)
+
+ if isinstance(source, (list, tuple)):
+ fs, paths_or_selector = _ensure_multiple_sources(source, filesystem)
+ else:
+ fs, paths_or_selector = _ensure_single_source(source, filesystem)
+
+ options = FileSystemFactoryOptions(
+ partitioning=partitioning,
+ partition_base_dir=partition_base_dir,
+ exclude_invalid_files=exclude_invalid_files,
+ selector_ignore_prefixes=selector_ignore_prefixes
+ )
+ factory = FileSystemDatasetFactory(fs, paths_or_selector, format, options)
+
+ return factory.finish(schema)
+
+
+def _in_memory_dataset(source, schema=None, **kwargs):
+ if any(v is not None for v in kwargs.values()):
+ raise ValueError(
+ "For in-memory datasets, you cannot pass any additional arguments")
+ return InMemoryDataset(source, schema)
+
+
+def _union_dataset(children, schema=None, **kwargs):
+ if any(v is not None for v in kwargs.values()):
+ raise ValueError(
+ "When passing a list of Datasets, you cannot pass any additional "
+ "arguments"
+ )
+
+ if schema is None:
+ # unify the children datasets' schemas
+ schema = pa.unify_schemas([child.schema for child in children])
+
+ # create datasets with the requested schema
+ children = [child.replace_schema(schema) for child in children]
+
+ return UnionDataset(schema, children)
+
+
+def parquet_dataset(metadata_path, schema=None, filesystem=None, format=None,
+ partitioning=None, partition_base_dir=None):
+ """
+ Create a FileSystemDataset from a `_metadata` file created via
+ `pyarrow.parquet.write_metadata`.
+
+ Parameters
+ ----------
+ metadata_path : path
+ Path pointing to a single Parquet `_metadata` file.
+ schema : Schema, optional
+ Optionally provide the Schema for the Dataset, in which case it will
+ not be inferred from the source.
+ filesystem : FileSystem or URI string, default None
+ If a single path is given as source and filesystem is None, then the
+ filesystem will be inferred from the path.
+ If a URI string is passed, then a filesystem object is constructed
+ using the URI's optional path component as a directory prefix. See the
+ examples below.
+ Note that the URIs on Windows must follow 'file:///C:...' or
+ 'file:/C:...' patterns.
+ format : ParquetFileFormat
+ An instance of a ParquetFileFormat if special options needs to be
+ passed.
+ partitioning : Partitioning, PartitioningFactory, str, list of str
+ The partitioning scheme specified with the ``partitioning()``
+ function. A flavor string can be used as shortcut, and with a list of
+ field names a DirectoryPartitioning will be inferred.
+ partition_base_dir : str, optional
+ For the purposes of applying the partitioning, paths will be
+ stripped of the partition_base_dir. Files not matching the
+ partition_base_dir prefix will be skipped for partitioning discovery.
+ The ignored files will still be part of the Dataset, but will not
+ have partition information.
+
+ Returns
+ -------
+ FileSystemDataset
+ """
+ from pyarrow.fs import LocalFileSystem, _ensure_filesystem
+
+ if format is None:
+ format = ParquetFileFormat()
+ elif not isinstance(format, ParquetFileFormat):
+ raise ValueError("format argument must be a ParquetFileFormat")
+
+ if filesystem is None:
+ filesystem = LocalFileSystem()
+ else:
+ filesystem = _ensure_filesystem(filesystem)
+
+ metadata_path = filesystem.normalize_path(_stringify_path(metadata_path))
+ options = ParquetFactoryOptions(
+ partition_base_dir=partition_base_dir,
+ partitioning=_ensure_partitioning(partitioning)
+ )
+
+ factory = ParquetDatasetFactory(
+ metadata_path, filesystem, format, options=options)
+ return factory.finish(schema)
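+
+
+# Hedged usage sketch for the factory above (the path is hypothetical):
+#
+#   import pyarrow.dataset as ds
+#   dset = ds.parquet_dataset("/data/taxi/_metadata", partitioning="hive")
+#   # the dataset layout is read from the _metadata sidecar, which is
+#   # intended to avoid opening every data file during discovery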
+
+
+def dataset(source, schema=None, format=None, filesystem=None,
+ partitioning=None, partition_base_dir=None,
+ exclude_invalid_files=None, ignore_prefixes=None):
+ """
+ Open a dataset.
+
+ Datasets provide functionality to efficiently work with tabular,
+ potentially larger-than-memory, multi-file datasets.
+
+ - A unified interface for different sources, like Parquet and Feather
+ - Discovery of sources (crawling directories, handle directory-based
+ partitioned datasets, basic schema normalization)
+ - Optimized reading with predicate pushdown (filtering rows), projection
+ (selecting columns), parallel reading or fine-grained managing of tasks.
+
+ Note that this is the high-level API; to have more control over the dataset
+ construction, use the low-level API classes (FileSystemDataset,
+ FileSystemDatasetFactory, etc.).
+
+ Parameters
+ ----------
+ source : path, list of paths, dataset, list of datasets, (list of) batches\
+or tables, iterable of batches, RecordBatchReader, or URI
+ Path pointing to a single file:
+ Open a FileSystemDataset from a single file.
+ Path pointing to a directory:
+ The directory gets discovered recursively according to a
+ partitioning scheme if given.
+ List of file paths:
+ Create a FileSystemDataset from explicitly given files. The files
+ must be located on the same filesystem given by the filesystem
+ parameter.
+ Note that, in contrast to construction from a single file, passing
+ URIs as paths is not allowed.
+ List of datasets:
+ A nested UnionDataset gets constructed; it allows arbitrary
+ composition of other datasets.
+ Note that additional keyword arguments are not allowed.
+ (List of) batches or tables, iterable of batches, or RecordBatchReader:
+ Create an InMemoryDataset. If an iterable or empty list is given,
+ a schema must also be given. If an iterable or RecordBatchReader
+ is given, the resulting dataset can only be scanned once; further
+ attempts will raise an error.
+ schema : Schema, optional
+ Optionally provide the Schema for the Dataset, in which case it will
+ not be inferred from the source.
+ format : FileFormat or str
+ Currently "parquet" and "ipc"/"arrow"/"feather" are supported. For
+ Feather, only version 2 files are supported.
+ filesystem : FileSystem or URI string, default None
+ If a single path is given as source and filesystem is None, then the
+ filesystem will be inferred from the path.
+ If a URI string is passed, then a filesystem object is constructed
+ using the URI's optional path component as a directory prefix. See the
+ examples below.
+ Note that the URIs on Windows must follow 'file:///C:...' or
+ 'file:/C:...' patterns.
+ partitioning : Partitioning, PartitioningFactory, str, list of str
+ The partitioning scheme specified with the ``partitioning()``
+ function. A flavor string can be used as shortcut, and with a list of
+ field names a DirectoryPartitioning will be inferred.
+ partition_base_dir : str, optional
+ For the purposes of applying the partitioning, paths will be
+ stripped of the partition_base_dir. Files not matching the
+ partition_base_dir prefix will be skipped for partitioning discovery.
+ The ignored files will still be part of the Dataset, but will not
+ have partition information.
+ exclude_invalid_files : bool, optional (default True)
+ If True, invalid files will be excluded (file format specific check).
+ This will incur IO for each file in a serial and single-threaded
+ fashion. Disabling this feature will skip the IO, but unsupported
+ files may be present in the Dataset (resulting in an error at scan
+ time).
+ ignore_prefixes : list, optional
+ Files matching any of these prefixes will be ignored by the
+ discovery process. This is matched to the basename of a path.
+ By default this is ['.', '_'].
+ Note that discovery happens only if a directory is passed as source.
+
+ Returns
+ -------
+ dataset : Dataset
+ Either a FileSystemDataset or a UnionDataset depending on the source
+ parameter.
+
+ Examples
+ --------
+ Opening a single file:
+
+ >>> dataset("path/to/file.parquet", format="parquet")
+
+ Opening a single file with an explicit schema:
+
+ >>> dataset("path/to/file.parquet", schema=myschema, format="parquet")
+
+ Opening a dataset for a single directory:
+
+ >>> dataset("path/to/nyc-taxi/", format="parquet")
+ >>> dataset("s3://mybucket/nyc-taxi/", format="parquet")
+
+ Opening a dataset from a list of relative local paths:
+
+ >>> dataset([
+ ... "part0/data.parquet",
+ ... "part1/data.parquet",
+ ... "part3/data.parquet",
+ ... ], format='parquet')
+
+ With filesystem provided:
+
+ >>> paths = [
+ ... 'part0/data.parquet',
+ ... 'part1/data.parquet',
+ ... 'part3/data.parquet',
+ ... ]
+ >>> dataset(paths, filesystem='file:///directory/prefix', format='parquet')
+
+ Which is equivalent to:
+
+ >>> fs = SubTreeFileSystem("/directory/prefix", LocalFileSystem())
+ >>> dataset(paths, filesystem=fs, format='parquet')
+
+ With a remote filesystem URI:
+
+ >>> paths = [
+ ... 'nested/directory/part0/data.parquet',
+ ... 'nested/directory/part1/data.parquet',
+ ... 'nested/directory/part3/data.parquet',
+ ... ]
+ >>> dataset(paths, filesystem='s3://bucket/', format='parquet')
+
+ Similarly to the local example, the directory prefix may be included in the
+ filesystem URI:
+
+ >>> dataset(paths, filesystem='s3://bucket/nested/directory',
+ ... format='parquet')
+
+ Construction of a nested dataset:
+
+ >>> dataset([
+ ... dataset("s3://old-taxi-data", format="parquet"),
+ ... dataset("local/path/to/data", format="ipc")
+ ... ])
+ """
+ # collect the keyword arguments for later reuse
+ kwargs = dict(
+ schema=schema,
+ filesystem=filesystem,
+ partitioning=partitioning,
+ format=format,
+ partition_base_dir=partition_base_dir,
+ exclude_invalid_files=exclude_invalid_files,
+ selector_ignore_prefixes=ignore_prefixes
+ )
+
+ if _is_path_like(source):
+ return _filesystem_dataset(source, **kwargs)
+ elif isinstance(source, (tuple, list)):
+ if all(_is_path_like(elem) for elem in source):
+ return _filesystem_dataset(source, **kwargs)
+ elif all(isinstance(elem, Dataset) for elem in source):
+ return _union_dataset(source, **kwargs)
+ elif all(isinstance(elem, (pa.RecordBatch, pa.Table))
+ for elem in source):
+ return _in_memory_dataset(source, **kwargs)
+ else:
+ unique_types = set(type(elem).__name__ for elem in source)
+ type_names = ', '.join('{}'.format(t) for t in unique_types)
+ raise TypeError(
+ 'Expected a list of path-like or dataset objects, or a list '
+ 'of batches or tables. The given list contains the following '
+ 'types: {}'.format(type_names)
+ )
+ elif isinstance(source, (pa.RecordBatch, pa.Table)):
+ return _in_memory_dataset(source, **kwargs)
+ else:
+ raise TypeError(
+ 'Expected a path-like, list of path-likes or a list of Datasets '
+ 'instead of the given type: {}'.format(type(source).__name__)
+ )
+
+
+def _ensure_write_partitioning(part, schema, flavor):
+ if isinstance(part, PartitioningFactory):
+ raise ValueError("A PartitioningFactory cannot be used. "
+ "Did you call the partitioning function "
+ "without supplying a schema?")
+
+ if isinstance(part, Partitioning) and flavor:
+ raise ValueError(
+ "Providing a partitioning_flavor with "
+ "a Partitioning object is not supported"
+ )
+ elif isinstance(part, (tuple, list)):
+ # Name of fields were provided instead of a partitioning object.
+ # Create a partitioning factory with those field names.
+ part = partitioning(
+ schema=pa.schema([schema.field(f) for f in part]),
+ flavor=flavor
+ )
+ elif part is None:
+ part = partitioning(pa.schema([]), flavor=flavor)
+
+ if not isinstance(part, Partitioning):
+ raise ValueError(
+ "partitioning must be a Partitioning object or "
+ "a list of column names"
+ )
+
+ return part
+
+
+def write_dataset(data, base_dir, basename_template=None, format=None,
+ partitioning=None, partitioning_flavor=None, schema=None,
+ filesystem=None, file_options=None, use_threads=True,
+ max_partitions=None, file_visitor=None,
+ existing_data_behavior='error'):
+ """
+ Write a dataset to a given format and partitioning.
+
+ Parameters
+ ----------
+ data : Dataset, Table/RecordBatch, RecordBatchReader, list of
+ Table/RecordBatch, or iterable of RecordBatch
+ The data to write. This can be a Dataset instance or
+ in-memory Arrow data. If an iterable is given, the schema must
+ also be given.
+ base_dir : str
+ The root directory where to write the dataset.
+ basename_template : str, optional
+ A template string used to generate basenames of written data files.
+ The token '{i}' will be replaced with an automatically incremented
+ integer. If not specified, it defaults to
+ "part-{i}." + format.default_extname
+ format : FileFormat or str
+ The format in which to write the dataset. Currently supported:
+ "parquet", "ipc"/"feather". If a FileSystemDataset is being written
+ and `format` is not specified, it defaults to the same format as the
+ specified FileSystemDataset. When writing a Table or RecordBatch, this
+ keyword is required.
+ partitioning : Partitioning or list[str], optional
+ The partitioning scheme specified with the ``partitioning()``
+ function or a list of field names. When providing a list of
+ field names, you can use ``partitioning_flavor`` to drive which
+ partitioning type should be used.
+ partitioning_flavor : str, optional
+ One of the partitioning flavors supported by
+ ``pyarrow.dataset.partitioning``. If omitted will use the
+ default of ``partitioning()`` which is directory partitioning.
+ schema : Schema, optional
+ filesystem : FileSystem, optional
+ file_options : FileWriteOptions, optional
+ FileFormat specific write options, created using the
+ ``FileFormat.make_write_options()`` function.
+ use_threads : bool, default True
+ Write files in parallel. If enabled, maximum parallelism will be used,
+ as determined by the number of available CPU cores.
+ max_partitions : int, default 1024
+ Maximum number of partitions any batch may be written into.
+ file_visitor : Function
+ If set, this function will be called with a WrittenFile instance
+ for each file created during the call. This object will have both
+ a path attribute and a metadata attribute.
+
+ The path attribute will be a string containing the path to
+ the created file.
+
+ The metadata attribute will be the parquet metadata of the file.
+ This metadata will have the file path attribute set and can be used
+ to build a _metadata file. The metadata attribute will be None if
+ the format is not parquet.
+
+ Example visitor which simply collects the filenames created::
+
+ visited_paths = []
+
+ def file_visitor(written_file):
+ visited_paths.append(written_file.path)
+ existing_data_behavior : 'error' | 'overwrite_or_ignore' | \
+'delete_matching'
+ Controls how the dataset will handle data that already exists in
+ the destination. The default behavior ('error') is to raise an error
+ if any data exists in the destination.
+
+ 'overwrite_or_ignore' will ignore any existing data and will
+ overwrite files with the same name as an output file. Other
+ existing files will be ignored. This behavior, in combination
+ with a unique basename_template for each write, will allow for
+ an append workflow.
+
+ 'delete_matching' is useful when you are writing a partitioned
+ dataset. The first time each partition directory is encountered
+ the entire directory will be deleted. This allows you to overwrite
+ old partitions completely.
+ """
+ from pyarrow.fs import _resolve_filesystem_and_path
+
+ if isinstance(data, (list, tuple)):
+ schema = schema or data[0].schema
+ data = InMemoryDataset(data, schema=schema)
+ elif isinstance(data, (pa.RecordBatch, pa.Table)):
+ schema = schema or data.schema
+ data = InMemoryDataset(data, schema=schema)
+ elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
+ data = Scanner.from_batches(data, schema=schema, use_async=True)
+ schema = None
+ elif not isinstance(data, (Dataset, Scanner)):
+ raise ValueError(
+ "Only Dataset, Scanner, Table/RecordBatch, RecordBatchReader, "
+ "a list of Tables/RecordBatches, or iterable of batches are "
+ "supported."
+ )
+
+ if format is None and isinstance(data, FileSystemDataset):
+ format = data.format
+ else:
+ format = _ensure_format(format)
+
+ if file_options is None:
+ file_options = format.make_write_options()
+
+ if format != file_options.format:
+ raise TypeError("Supplied FileWriteOptions have format {}, "
+ "which doesn't match supplied FileFormat {}".format(
+ format, file_options))
+
+ if basename_template is None:
+ basename_template = "part-{i}." + format.default_extname
+
+ if max_partitions is None:
+ max_partitions = 1024
+
+ # at this point data is a Scanner or a Dataset, anything else
+ # was converted to one of those two. So we can grab the schema
+ # to build the partitioning object from Dataset.
+ if isinstance(data, Scanner):
+ partitioning_schema = data.dataset_schema
+ else:
+ partitioning_schema = data.schema
+ partitioning = _ensure_write_partitioning(partitioning,
+ schema=partitioning_schema,
+ flavor=partitioning_flavor)
+
+ filesystem, base_dir = _resolve_filesystem_and_path(base_dir, filesystem)
+
+ if isinstance(data, Dataset):
+ scanner = data.scanner(use_threads=use_threads, use_async=True)
+ else:
+ # scanner was passed directly by the user, in which case a schema
+ # cannot be passed
+ if schema is not None:
+ raise ValueError("Cannot specify a schema when writing a Scanner")
+ scanner = data
+
+ _filesystemdataset_write(
+ scanner, base_dir, basename_template, filesystem, partitioning,
+ file_options, max_partitions, file_visitor, existing_data_behavior
+ )
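+
+
+# A hedged end-to-end sketch of the writer above (table contents and paths
+# are hypothetical):
+#
+#   import pyarrow as pa
+#   import pyarrow.dataset as ds
+#   table = pa.table({"year": [2020, 2021], "value": [1.0, 2.0]})
+#   ds.write_dataset(table, "/tmp/out", format="parquet",
+#                    partitioning=["year"], partitioning_flavor="hive",
+#                    existing_data_behavior="overwrite_or_ignore")
+#   # -> /tmp/out/year=2020/part-0.parquet, /tmp/out/year=2021/part-0.parquet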
diff --git a/src/arrow/python/pyarrow/error.pxi b/src/arrow/python/pyarrow/error.pxi
new file mode 100644
index 000000000..233b4fb16
--- /dev/null
+++ b/src/arrow/python/pyarrow/error.pxi
@@ -0,0 +1,242 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from cpython.exc cimport PyErr_CheckSignals, PyErr_SetInterrupt
+
+from pyarrow.includes.libarrow cimport CStatus, IsPyError, RestorePyError
+from pyarrow.includes.common cimport c_string
+
+from contextlib import contextmanager
+import os
+import signal
+import threading
+
+from pyarrow.util import _break_traceback_cycle_from_frame
+
+
+class ArrowException(Exception):
+ pass
+
+
+class ArrowInvalid(ValueError, ArrowException):
+ pass
+
+
+class ArrowMemoryError(MemoryError, ArrowException):
+ pass
+
+
+class ArrowKeyError(KeyError, ArrowException):
+ def __str__(self):
+ # Override KeyError.__str__, as it uses the repr() of the key
+ return ArrowException.__str__(self)
+
+
+class ArrowTypeError(TypeError, ArrowException):
+ pass
+
+
+class ArrowNotImplementedError(NotImplementedError, ArrowException):
+ pass
+
+
+class ArrowCapacityError(ArrowException):
+ pass
+
+
+class ArrowIndexError(IndexError, ArrowException):
+ pass
+
+
+class ArrowSerializationError(ArrowException):
+ pass
+
+
+class ArrowCancelled(ArrowException):
+ def __init__(self, message, signum=None):
+ super().__init__(message)
+ self.signum = signum
+
+
+# Compatibility alias
+ArrowIOError = IOError
+
+
+# This function could be written directly in C++ if we didn't
+# define Arrow-specific subclasses (ArrowInvalid etc.)
+cdef int check_status(const CStatus& status) nogil except -1:
+ if status.ok():
+ return 0
+
+ with gil:
+ if IsPyError(status):
+ RestorePyError(status)
+ return -1
+
+ # We don't use Status::ToString() as it would redundantly include
+ # the C++ class name.
+ message = frombytes(status.message(), safe=True)
+ detail = status.detail()
+ if detail != nullptr:
+ message += ". Detail: " + frombytes(detail.get().ToString(),
+ safe=True)
+
+ if status.IsInvalid():
+ raise ArrowInvalid(message)
+ elif status.IsIOError():
+ # Note: OSError constructor is
+ # OSError(message)
+ # or
+ # OSError(errno, message, filename=None)
+ # or (on Windows)
+ # OSError(errno, message, filename, winerror)
+ errno = ErrnoFromStatus(status)
+ winerror = WinErrorFromStatus(status)
+ if winerror != 0:
+ raise IOError(errno, message, None, winerror)
+ elif errno != 0:
+ raise IOError(errno, message)
+ else:
+ raise IOError(message)
+ elif status.IsOutOfMemory():
+ raise ArrowMemoryError(message)
+ elif status.IsKeyError():
+ raise ArrowKeyError(message)
+ elif status.IsNotImplemented():
+ raise ArrowNotImplementedError(message)
+ elif status.IsTypeError():
+ raise ArrowTypeError(message)
+ elif status.IsCapacityError():
+ raise ArrowCapacityError(message)
+ elif status.IsIndexError():
+ raise ArrowIndexError(message)
+ elif status.IsSerializationError():
+ raise ArrowSerializationError(message)
+ elif status.IsCancelled():
+ signum = SignalFromStatus(status)
+ if signum > 0:
+ raise ArrowCancelled(message, signum)
+ else:
+ raise ArrowCancelled(message)
+ else:
+ message = frombytes(status.ToString(), safe=True)
+ raise ArrowException(message)
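+
+
+# Rough summary of the mapping above (governed by the C++ Status codes):
+# Invalid -> ArrowInvalid (a ValueError), IOError -> OSError with
+# errno/winerror when available, OutOfMemory -> ArrowMemoryError,
+# KeyError -> ArrowKeyError, NotImplemented -> ArrowNotImplementedError,
+# Cancelled -> ArrowCancelled.  Illustrative example of how this surfaces:
+#
+#   >>> pa.array([1, "a"])      # mixed types cannot be converted
+#   ArrowInvalid: ...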
+
+
+# This is an API function for C++ PyArrow
+cdef api int pyarrow_internal_check_status(const CStatus& status) \
+ nogil except -1:
+ return check_status(status)
+
+
+cdef class StopToken:
+ cdef void init(self, CStopToken stop_token):
+ self.stop_token = move(stop_token)
+
+
+cdef c_bool signal_handlers_enabled = True
+
+
+def enable_signal_handlers(c_bool enable):
+ """
+ Enable or disable interruption of long-running operations.
+
+ By default, certain long running operations will detect user
+ interruptions, such as by pressing Ctrl-C. This detection relies
+ on setting a signal handler for the duration of the long-running
+ operation, and may therefore interfere with other frameworks or
+ libraries (such as an event loop).
+
+ Parameters
+ ----------
+ enable : bool
+ Whether to enable user interruption by setting a temporary
+ signal handler.
+ """
+ global signal_handlers_enabled
+ signal_handlers_enabled = enable
+
+
+# For internal use
+
+# Whether we need a workaround for https://bugs.python.org/issue42248
+have_signal_refcycle = (sys.version_info < (3, 8, 10) or
+ (3, 9) <= sys.version_info < (3, 9, 5) or
+ sys.version_info[:2] == (3, 10))
+
+cdef class SignalStopHandler:
+ cdef:
+ StopToken _stop_token
+ vector[int] _signals
+ c_bool _enabled
+
+ def __cinit__(self):
+ self._enabled = False
+
+ self._init_signals()
+ if have_signal_refcycle:
+ _break_traceback_cycle_from_frame(sys._getframe(0))
+
+ self._stop_token = StopToken()
+ if not self._signals.empty():
+ self._stop_token.init(GetResultValue(
+ SetSignalStopSource()).token())
+ self._enabled = True
+
+ def _init_signals(self):
+ if (signal_handlers_enabled and
+ threading.current_thread() is threading.main_thread()):
+ self._signals = [
+ sig for sig in (signal.SIGINT, signal.SIGTERM)
+ if signal.getsignal(sig) not in (signal.SIG_DFL,
+ signal.SIG_IGN, None)]
+
+ def __enter__(self):
+ if self._enabled:
+ check_status(RegisterCancellingSignalHandler(self._signals))
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ if self._enabled:
+ UnregisterCancellingSignalHandler()
+ if isinstance(exc_value, ArrowCancelled):
+ if exc_value.signum:
+ # Re-emit the exact same signal. We restored the Python signal
+ # handler above, so it should receive it.
+ if os.name == 'nt':
+ SendSignal(exc_value.signum)
+ else:
+ SendSignalToThread(exc_value.signum, threading.get_ident())
+ else:
+ # Simulate Python receiving a SIGINT
+ # (see https://bugs.python.org/issue43356 for why we can't
+ # simulate the exact signal number)
+ PyErr_SetInterrupt()
+ # Maximize chances of the Python signal handler being executed now.
+ # Otherwise a potential KeyboardInterrupt might be missed by an
+ # immediately enclosing try/except block.
+ PyErr_CheckSignals()
+ # ArrowCancelled will be re-raised if PyErr_CheckSignals()
+ # returned successfully.
+
+ def __dealloc__(self):
+ if self._enabled:
+ ResetSignalStopSource()
+
+ @property
+ def stop_token(self):
+ return self._stop_token
diff --git a/src/arrow/python/pyarrow/feather.py b/src/arrow/python/pyarrow/feather.py
new file mode 100644
index 000000000..2170a93c3
--- /dev/null
+++ b/src/arrow/python/pyarrow/feather.py
@@ -0,0 +1,265 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import os
+
+from pyarrow.pandas_compat import _pandas_api # noqa
+from pyarrow.lib import (Codec, Table, # noqa
+ concat_tables, schema)
+import pyarrow.lib as ext
+from pyarrow import _feather
+from pyarrow._feather import FeatherError # noqa: F401
+from pyarrow.vendored.version import Version
+
+
+def _check_pandas_version():
+ if _pandas_api.loose_version < Version('0.17.0'):
+ raise ImportError("feather requires pandas >= 0.17.0")
+
+
+class FeatherDataset:
+ """
+ Encapsulates details of reading a list of Feather files.
+
+ Parameters
+ ----------
+ path_or_paths : List[str]
+ A list of file names
+ validate_schema : bool, default True
+ Check that individual file schemas are all the same / compatible
+ """
+
+ def __init__(self, path_or_paths, validate_schema=True):
+ self.paths = path_or_paths
+ self.validate_schema = validate_schema
+
+ def read_table(self, columns=None):
+ """
+ Read multiple feather files as a single pyarrow.Table
+
+ Parameters
+ ----------
+ columns : List[str]
+ Names of columns to read from the file
+
+ Returns
+ -------
+ pyarrow.Table
+ Content of the file as a table (of columns)
+ """
+ _fil = read_table(self.paths[0], columns=columns)
+ self._tables = [_fil]
+ self.schema = _fil.schema
+
+ for path in self.paths[1:]:
+ table = read_table(path, columns=columns)
+ if self.validate_schema:
+ self.validate_schemas(path, table)
+ self._tables.append(table)
+ return concat_tables(self._tables)
+
+ def validate_schemas(self, piece, table):
+ if not self.schema.equals(table.schema):
+ raise ValueError('Schema in {!s} was different. \n'
+ '{!s}\n\nvs\n\n{!s}'
+ .format(piece, self.schema,
+ table.schema))
+
+ def read_pandas(self, columns=None, use_threads=True):
+ """
+        Read multiple Feather files as a single pandas DataFrame
+
+ Parameters
+ ----------
+ columns : List[str]
+ Names of columns to read from the file
+ use_threads : bool, default True
+ Use multiple threads when converting to pandas
+
+ Returns
+ -------
+ pandas.DataFrame
+ Content of the file as a pandas DataFrame (of columns)
+ """
+ _check_pandas_version()
+ return self.read_table(columns=columns).to_pandas(
+ use_threads=use_threads)
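+
+# Minimal sketch (file names are illustrative): combine several Feather
+# files that share a schema into a single table:
+#
+#   from pyarrow.feather import FeatherDataset
+#
+#   dataset = FeatherDataset(["part-0.feather", "part-1.feather"])
+#   table = dataset.read_table(columns=["a", "b"])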
+
+
+def check_chunked_overflow(name, col):
+ if col.num_chunks == 1:
+ return
+
+ if col.type in (ext.binary(), ext.string()):
+ raise ValueError("Column '{}' exceeds 2GB maximum capacity of "
+ "a Feather binary column. This restriction may be "
+ "lifted in the future".format(name))
+ else:
+ # TODO(wesm): Not sure when else this might be reached
+ raise ValueError("Column '{}' of type {} was chunked on conversion "
+ "to Arrow and cannot be currently written to "
+ "Feather format".format(name, str(col.type)))
+
+
+_FEATHER_SUPPORTED_CODECS = {'lz4', 'zstd', 'uncompressed'}
+
+
+def write_feather(df, dest, compression=None, compression_level=None,
+ chunksize=None, version=2):
+ """
+ Write a pandas.DataFrame to Feather format.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame or pyarrow.Table
+ Data to write out as Feather format.
+ dest : str
+ Local destination path.
+ compression : string, default None
+ Can be one of {"zstd", "lz4", "uncompressed"}. The default of None uses
+ LZ4 for V2 files if it is available, otherwise uncompressed.
+ compression_level : int, default None
+ Use a compression level particular to the chosen compressor. If None
+ use the default compression level
+ chunksize : int, default None
+ For V2 files, the internal maximum size of Arrow RecordBatch chunks
+ when writing the Arrow IPC file format. None means use the default,
+ which is currently 64K
+ version : int, default 2
+        Feather file version. Version 2 is the current default; version 1 is
+        the more limited legacy format.
+ """
+ if _pandas_api.have_pandas:
+ _check_pandas_version()
+ if (_pandas_api.has_sparse and
+ isinstance(df, _pandas_api.pd.SparseDataFrame)):
+ df = df.to_dense()
+
+ if _pandas_api.is_data_frame(df):
+ table = Table.from_pandas(df, preserve_index=False)
+
+ if version == 1:
+            # Version 1 does not support chunking
+ for i, name in enumerate(table.schema.names):
+ col = table[i]
+ check_chunked_overflow(name, col)
+ else:
+ table = df
+
+ if version == 1:
+ if len(table.column_names) > len(set(table.column_names)):
+ raise ValueError("cannot serialize duplicate column names")
+
+ if compression is not None:
+ raise ValueError("Feather V1 files do not support compression "
+ "option")
+
+ if chunksize is not None:
+ raise ValueError("Feather V1 files do not support chunksize "
+ "option")
+ else:
+ if compression is None and Codec.is_available('lz4_frame'):
+ compression = 'lz4'
+ elif (compression is not None and
+ compression not in _FEATHER_SUPPORTED_CODECS):
+ raise ValueError('compression="{}" not supported, must be '
+ 'one of {}'.format(compression,
+ _FEATHER_SUPPORTED_CODECS))
+
+ try:
+ _feather.write_feather(table, dest, compression=compression,
+ compression_level=compression_level,
+ chunksize=chunksize, version=version)
+ except Exception:
+ if isinstance(dest, str):
+ try:
+ os.remove(dest)
+ except os.error:
+ pass
+ raise
+
+
+def read_feather(source, columns=None, use_threads=True, memory_map=True):
+ """
+ Read a pandas.DataFrame from Feather format. To read as pyarrow.Table use
+ feather.read_table.
+
+ Parameters
+ ----------
+ source : str file path, or file-like object
+ columns : sequence, optional
+ Only read a specific set of columns. If not provided, all columns are
+ read.
+ use_threads : bool, default True
+        Whether to parallelize reading using multiple threads. If False,
+        threads are only disabled for the conversion to pandas, not for
+        reading from the Feather file itself.
+ memory_map : boolean, default True
+ Use memory mapping when opening file on disk
+
+ Returns
+ -------
+ df : pandas.DataFrame
+ """
+ _check_pandas_version()
+ return (read_table(source, columns=columns, memory_map=memory_map)
+ .to_pandas(use_threads=use_threads))
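+
+# Round-trip sketch (assumes pandas and the zstd codec are available; the
+# file name is illustrative):
+#
+#   import pandas as pd
+#   import pyarrow.feather as feather
+#
+#   df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
+#   feather.write_feather(df, "example.feather", compression="zstd")
+#   restored = feather.read_feather("example.feather", columns=["a"])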
+
+
+def read_table(source, columns=None, memory_map=True):
+ """
+ Read a pyarrow.Table from Feather format
+
+ Parameters
+ ----------
+ source : str file path, or file-like object
+ columns : sequence, optional
+ Only read a specific set of columns. If not provided, all columns are
+ read.
+ memory_map : boolean, default True
+ Use memory mapping when opening file on disk
+
+ Returns
+ -------
+ table : pyarrow.Table
+ """
+ reader = _feather.FeatherReader(source, use_memory_map=memory_map)
+
+ if columns is None:
+ return reader.read()
+
+ column_types = [type(column) for column in columns]
+ if all(map(lambda t: t == int, column_types)):
+ table = reader.read_indices(columns)
+ elif all(map(lambda t: t == str, column_types)):
+ table = reader.read_names(columns)
+ else:
+ column_type_names = [t.__name__ for t in column_types]
+ raise TypeError("Columns must be indices or names. "
+ "Got columns {} of types {}"
+ .format(columns, column_type_names))
+
+ # Feather v1 already respects the column selection
+ if reader.version < 3:
+ return table
+ # Feather v2 reads with sorted / deduplicated selection
+ elif sorted(set(columns)) == columns:
+ return table
+ else:
+ # follow exact order / selection of names
+ return table.select(columns)
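+
+# Column-selection sketch (file name is illustrative): indices and names are
+# both accepted, and for Feather V2 the requested order is preserved:
+#
+#   import pyarrow.feather as feather
+#
+#   t1 = feather.read_table("example.feather", columns=[1, 0])
+#   t2 = feather.read_table("example.feather", columns=["b", "a"])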
diff --git a/src/arrow/python/pyarrow/filesystem.py b/src/arrow/python/pyarrow/filesystem.py
new file mode 100644
index 000000000..c2017e42b
--- /dev/null
+++ b/src/arrow/python/pyarrow/filesystem.py
@@ -0,0 +1,511 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import os
+import posixpath
+import sys
+import urllib.parse
+import warnings
+
+from os.path import join as pjoin
+
+import pyarrow as pa
+from pyarrow.util import implements, _stringify_path, _is_path_like, _DEPR_MSG
+
+
+_FS_DEPR_MSG = _DEPR_MSG.format(
+ "filesystem.LocalFileSystem", "2.0.0", "fs.LocalFileSystem"
+)
+
+
+class FileSystem:
+ """
+ Abstract filesystem interface.
+ """
+
+ def cat(self, path):
+ """
+ Return contents of file as a bytes object.
+
+ Parameters
+ ----------
+ path : str
+ File path to read content from.
+
+ Returns
+ -------
+ contents : bytes
+ """
+ with self.open(path, 'rb') as f:
+ return f.read()
+
+ def ls(self, path):
+ """
+ Return list of file paths.
+
+ Parameters
+ ----------
+ path : str
+ Directory to list contents from.
+ """
+ raise NotImplementedError
+
+ def delete(self, path, recursive=False):
+ """
+ Delete the indicated file or directory.
+
+ Parameters
+ ----------
+ path : str
+ Path to delete.
+ recursive : bool, default False
+ If True, also delete child paths for directories.
+ """
+ raise NotImplementedError
+
+ def disk_usage(self, path):
+ """
+ Compute bytes used by all contents under indicated path in file tree.
+
+ Parameters
+ ----------
+ path : str
+ Can be a file path or directory.
+
+ Returns
+ -------
+ usage : int
+ """
+ path = _stringify_path(path)
+ path_info = self.stat(path)
+ if path_info['kind'] == 'file':
+ return path_info['size']
+
+ total = 0
+ for root, directories, files in self.walk(path):
+ for child_path in files:
+ abspath = self._path_join(root, child_path)
+ total += self.stat(abspath)['size']
+
+ return total
+
+ def _path_join(self, *args):
+ return self.pathsep.join(args)
+
+ def stat(self, path):
+ """
+ Information about a filesystem entry.
+
+ Returns
+ -------
+ stat : dict
+ """
+ raise NotImplementedError('FileSystem.stat')
+
+ def rm(self, path, recursive=False):
+ """
+ Alias for FileSystem.delete.
+ """
+ return self.delete(path, recursive=recursive)
+
+ def mv(self, path, new_path):
+ """
+ Alias for FileSystem.rename.
+ """
+ return self.rename(path, new_path)
+
+ def rename(self, path, new_path):
+ """
+ Rename file, like UNIX mv command.
+
+ Parameters
+ ----------
+ path : str
+ Path to alter.
+ new_path : str
+ Path to move to.
+ """
+ raise NotImplementedError('FileSystem.rename')
+
+ def mkdir(self, path, create_parents=True):
+ """
+ Create a directory.
+
+ Parameters
+ ----------
+ path : str
+ Path to the directory.
+ create_parents : bool, default True
+            If the parent directories don't exist, create them as well.
+ """
+ raise NotImplementedError
+
+ def exists(self, path):
+ """
+ Return True if path exists.
+
+ Parameters
+ ----------
+ path : str
+ Path to check.
+ """
+ raise NotImplementedError
+
+ def isdir(self, path):
+ """
+ Return True if path is a directory.
+
+ Parameters
+ ----------
+ path : str
+ Path to check.
+ """
+ raise NotImplementedError
+
+ def isfile(self, path):
+ """
+ Return True if path is a file.
+
+ Parameters
+ ----------
+ path : str
+ Path to check.
+ """
+ raise NotImplementedError
+
+ def _isfilestore(self):
+ """
+ Returns True if this FileSystem is a unix-style file store with
+ directories.
+ """
+ raise NotImplementedError
+
+ def read_parquet(self, path, columns=None, metadata=None, schema=None,
+ use_threads=True, use_pandas_metadata=False):
+ """
+ Read Parquet data from path in file system. Can read from a single file
+ or a directory of files.
+
+ Parameters
+ ----------
+ path : str
+ Single file path or directory
+ columns : List[str], optional
+ Subset of columns to read.
+ metadata : pyarrow.parquet.FileMetaData
+ Known metadata to validate files against.
+ schema : pyarrow.parquet.Schema
+ Known schema to validate files against. Alternative to metadata
+ argument.
+ use_threads : bool, default True
+ Perform multi-threaded column reads.
+ use_pandas_metadata : bool, default False
+ If True and file has custom pandas schema metadata, ensure that
+ index columns are also loaded.
+
+ Returns
+ -------
+ table : pyarrow.Table
+ """
+ from pyarrow.parquet import ParquetDataset
+ dataset = ParquetDataset(path, schema=schema, metadata=metadata,
+ filesystem=self)
+ return dataset.read(columns=columns, use_threads=use_threads,
+ use_pandas_metadata=use_pandas_metadata)
+
+ def open(self, path, mode='rb'):
+ """
+ Open file for reading or writing.
+ """
+ raise NotImplementedError
+
+ @property
+ def pathsep(self):
+ return '/'
+
+
+class LocalFileSystem(FileSystem):
+
+ _instance = None
+
+ def __init__(self):
+ warnings.warn(_FS_DEPR_MSG, FutureWarning, stacklevel=2)
+ super().__init__()
+
+ @classmethod
+ def _get_instance(cls):
+ if cls._instance is None:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ cls._instance = LocalFileSystem()
+ return cls._instance
+
+ @classmethod
+ def get_instance(cls):
+ warnings.warn(_FS_DEPR_MSG, FutureWarning, stacklevel=2)
+ return cls._get_instance()
+
+ @implements(FileSystem.ls)
+ def ls(self, path):
+ path = _stringify_path(path)
+ return sorted(pjoin(path, x) for x in os.listdir(path))
+
+ @implements(FileSystem.mkdir)
+ def mkdir(self, path, create_parents=True):
+ path = _stringify_path(path)
+ if create_parents:
+ os.makedirs(path)
+ else:
+ os.mkdir(path)
+
+ @implements(FileSystem.isdir)
+ def isdir(self, path):
+ path = _stringify_path(path)
+ return os.path.isdir(path)
+
+ @implements(FileSystem.isfile)
+ def isfile(self, path):
+ path = _stringify_path(path)
+ return os.path.isfile(path)
+
+ @implements(FileSystem._isfilestore)
+ def _isfilestore(self):
+ return True
+
+ @implements(FileSystem.exists)
+ def exists(self, path):
+ path = _stringify_path(path)
+ return os.path.exists(path)
+
+ @implements(FileSystem.open)
+ def open(self, path, mode='rb'):
+ """
+ Open file for reading or writing.
+ """
+ path = _stringify_path(path)
+ return open(path, mode=mode)
+
+ @property
+ def pathsep(self):
+ return os.path.sep
+
+ def walk(self, path):
+ """
+ Directory tree generator, see os.walk.
+ """
+ path = _stringify_path(path)
+ return os.walk(path)
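+
+# Migration sketch (path is illustrative): this legacy class warns on
+# construction; the pyarrow.fs counterpart is the supported replacement:
+#
+#   from pyarrow import fs
+#
+#   local = fs.LocalFileSystem()
+#   infos = local.get_file_info(fs.FileSelector("/tmp", recursive=False))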
+
+
+class DaskFileSystem(FileSystem):
+ """
+    Wraps Dask filesystem implementations such as s3fs, gcsfs, etc.
+ """
+
+ def __init__(self, fs):
+ warnings.warn(
+ "The pyarrow.filesystem.DaskFileSystem/S3FSWrapper are deprecated "
+ "as of pyarrow 3.0.0, and will be removed in a future version.",
+ FutureWarning, stacklevel=2)
+ self.fs = fs
+
+ @implements(FileSystem.isdir)
+ def isdir(self, path):
+ raise NotImplementedError("Unsupported file system API")
+
+ @implements(FileSystem.isfile)
+ def isfile(self, path):
+ raise NotImplementedError("Unsupported file system API")
+
+ @implements(FileSystem._isfilestore)
+ def _isfilestore(self):
+ """
+ Object Stores like S3 and GCSFS are based on key lookups, not true
+ file-paths.
+ """
+ return False
+
+ @implements(FileSystem.delete)
+ def delete(self, path, recursive=False):
+ path = _stringify_path(path)
+ return self.fs.rm(path, recursive=recursive)
+
+ @implements(FileSystem.exists)
+ def exists(self, path):
+ path = _stringify_path(path)
+ return self.fs.exists(path)
+
+ @implements(FileSystem.mkdir)
+ def mkdir(self, path, create_parents=True):
+ path = _stringify_path(path)
+ if create_parents:
+ return self.fs.mkdirs(path)
+ else:
+ return self.fs.mkdir(path)
+
+ @implements(FileSystem.open)
+ def open(self, path, mode='rb'):
+ """
+ Open file for reading or writing.
+ """
+ path = _stringify_path(path)
+ return self.fs.open(path, mode=mode)
+
+ def ls(self, path, detail=False):
+ path = _stringify_path(path)
+ return self.fs.ls(path, detail=detail)
+
+ def walk(self, path):
+ """
+ Directory tree generator, like os.walk.
+ """
+ path = _stringify_path(path)
+ return self.fs.walk(path)
+
+
+class S3FSWrapper(DaskFileSystem):
+
+ @implements(FileSystem.isdir)
+ def isdir(self, path):
+ path = _sanitize_s3(_stringify_path(path))
+ try:
+ contents = self.fs.ls(path)
+ if len(contents) == 1 and contents[0] == path:
+ return False
+ else:
+ return True
+ except OSError:
+ return False
+
+ @implements(FileSystem.isfile)
+ def isfile(self, path):
+ path = _sanitize_s3(_stringify_path(path))
+ try:
+ contents = self.fs.ls(path)
+ return len(contents) == 1 and contents[0] == path
+ except OSError:
+ return False
+
+ def walk(self, path, refresh=False):
+ """
+ Directory tree generator, like os.walk.
+
+ Generator version of what is in s3fs, which yields a flattened list of
+ files.
+ """
+ path = _sanitize_s3(_stringify_path(path))
+ directories = set()
+ files = set()
+
+ for key in list(self.fs._ls(path, refresh=refresh)):
+ path = key['Key']
+ if key['StorageClass'] == 'DIRECTORY':
+ directories.add(path)
+ elif key['StorageClass'] == 'BUCKET':
+ pass
+ else:
+ files.add(path)
+
+ # s3fs creates duplicate 'DIRECTORY' entries
+ files = sorted([posixpath.split(f)[1] for f in files
+ if f not in directories])
+ directories = sorted([posixpath.split(x)[1]
+ for x in directories])
+
+ yield path, directories, files
+
+ for directory in directories:
+ yield from self.walk(directory, refresh=refresh)
+
+
+def _sanitize_s3(path):
+ if path.startswith('s3://'):
+ return path.replace('s3://', '')
+ else:
+ return path
+
+
+def _ensure_filesystem(fs):
+ fs_type = type(fs)
+
+ # If the arrow filesystem was subclassed, assume it supports the full
+ # interface and return it
+ if not issubclass(fs_type, FileSystem):
+ if "fsspec" in sys.modules:
+ fsspec = sys.modules["fsspec"]
+ if isinstance(fs, fsspec.AbstractFileSystem):
+ # for recent fsspec versions that stop inheriting from
+ # pyarrow.filesystem.FileSystem, still allow fsspec
+ # filesystems (which should be compatible with our legacy fs)
+ return fs
+
+ raise OSError('Unrecognized filesystem: {}'.format(fs_type))
+ else:
+ return fs
+
+
+def resolve_filesystem_and_path(where, filesystem=None):
+ """
+ Return filesystem from path which could be an HDFS URI, a local URI,
+ or a plain filesystem path.
+ """
+ if not _is_path_like(where):
+ if filesystem is not None:
+ raise ValueError("filesystem passed but where is file-like, so"
+ " there is nothing to open with filesystem.")
+ return filesystem, where
+
+ if filesystem is not None:
+ filesystem = _ensure_filesystem(filesystem)
+ if isinstance(filesystem, LocalFileSystem):
+ path = _stringify_path(where)
+ elif not isinstance(where, str):
+ raise TypeError(
+ "Expected string path; path-like objects are only allowed "
+ "with a local filesystem"
+ )
+ else:
+ path = where
+ return filesystem, path
+
+ path = _stringify_path(where)
+
+ parsed_uri = urllib.parse.urlparse(path)
+ if parsed_uri.scheme == 'hdfs' or parsed_uri.scheme == 'viewfs':
+ # Input is hdfs URI such as hdfs://host:port/myfile.parquet
+ netloc_split = parsed_uri.netloc.split(':')
+ host = netloc_split[0]
+ if host == '':
+ host = 'default'
+ else:
+ host = parsed_uri.scheme + "://" + host
+ port = 0
+ if len(netloc_split) == 2 and netloc_split[1].isnumeric():
+ port = int(netloc_split[1])
+ fs = pa.hdfs._connect(host=host, port=port)
+ fs_path = parsed_uri.path
+ elif parsed_uri.scheme == 'file':
+ # Input is local URI such as file:///home/user/myfile.parquet
+ fs = LocalFileSystem._get_instance()
+ fs_path = parsed_uri.path
+ else:
+ # Input is local path such as /home/user/myfile.parquet
+ fs = LocalFileSystem._get_instance()
+ fs_path = path
+
+ return fs, fs_path
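+
+# Resolution sketch (paths are illustrative; the HDFS case assumes a working
+# libhdfs setup): plain local paths resolve to the legacy LocalFileSystem,
+# while hdfs:// URIs are dispatched to pyarrow.hdfs:
+#
+#   fs, path = resolve_filesystem_and_path("/home/user/myfile.parquet")
+#   fs, path = resolve_filesystem_and_path("hdfs://namenode:8020/data.parquet")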
diff --git a/src/arrow/python/pyarrow/flight.py b/src/arrow/python/pyarrow/flight.py
new file mode 100644
index 000000000..0664ff2c9
--- /dev/null
+++ b/src/arrow/python/pyarrow/flight.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyarrow._flight import ( # noqa:F401
+ connect,
+ Action,
+ ActionType,
+ BasicAuth,
+ CallInfo,
+ CertKeyPair,
+ ClientAuthHandler,
+ ClientMiddleware,
+ ClientMiddlewareFactory,
+ DescriptorType,
+ FlightCallOptions,
+ FlightCancelledError,
+ FlightClient,
+ FlightDataStream,
+ FlightDescriptor,
+ FlightEndpoint,
+ FlightError,
+ FlightInfo,
+ FlightInternalError,
+ FlightMetadataReader,
+ FlightMetadataWriter,
+ FlightMethod,
+ FlightServerBase,
+ FlightServerError,
+ FlightStreamChunk,
+ FlightStreamReader,
+ FlightStreamWriter,
+ FlightTimedOutError,
+ FlightUnauthenticatedError,
+ FlightUnauthorizedError,
+ FlightUnavailableError,
+ FlightWriteSizeExceededError,
+ GeneratorStream,
+ Location,
+ MetadataRecordBatchReader,
+ MetadataRecordBatchWriter,
+ RecordBatchStream,
+ Result,
+ SchemaResult,
+ ServerAuthHandler,
+ ServerCallContext,
+ ServerMiddleware,
+ ServerMiddlewareFactory,
+ Ticket,
+)
diff --git a/src/arrow/python/pyarrow/fs.py b/src/arrow/python/pyarrow/fs.py
new file mode 100644
index 000000000..5d3326861
--- /dev/null
+++ b/src/arrow/python/pyarrow/fs.py
@@ -0,0 +1,405 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+FileSystem abstraction to interact with various local and remote filesystems.
+"""
+
+from pyarrow.util import _is_path_like, _stringify_path
+
+from pyarrow._fs import ( # noqa
+ FileSelector,
+ FileType,
+ FileInfo,
+ FileSystem,
+ LocalFileSystem,
+ SubTreeFileSystem,
+ _MockFileSystem,
+ FileSystemHandler,
+ PyFileSystem,
+ _copy_files,
+ _copy_files_selector,
+)
+
+# For backward compatibility.
+FileStats = FileInfo
+
+_not_imported = []
+
+try:
+ from pyarrow._hdfs import HadoopFileSystem # noqa
+except ImportError:
+ _not_imported.append("HadoopFileSystem")
+
+try:
+ from pyarrow._s3fs import ( # noqa
+ S3FileSystem, S3LogLevel, initialize_s3, finalize_s3)
+except ImportError:
+ _not_imported.append("S3FileSystem")
+else:
+ initialize_s3()
+
+
+def __getattr__(name):
+ if name in _not_imported:
+ raise ImportError(
+ "The pyarrow installation is not built with support for "
+ "'{0}'".format(name)
+ )
+
+ raise AttributeError(
+ "module 'pyarrow.fs' has no attribute '{0}'".format(name)
+ )
+
+
+def _filesystem_from_str(uri):
+ # instantiate the file system from an uri, if the uri has a path
+ # component then it will be treated as a path prefix
+ filesystem, prefix = FileSystem.from_uri(uri)
+ prefix = filesystem.normalize_path(prefix)
+ if prefix:
+ # validate that the prefix is pointing to a directory
+ prefix_info = filesystem.get_file_info([prefix])[0]
+ if prefix_info.type != FileType.Directory:
+ raise ValueError(
+ "The path component of the filesystem URI must point to a "
+ "directory but it has a type: `{}`. The path component "
+ "is `{}` and the given filesystem URI is `{}`".format(
+ prefix_info.type.name, prefix_info.path, uri
+ )
+ )
+ filesystem = SubTreeFileSystem(prefix, filesystem)
+ return filesystem
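+
+# URI sketch (bucket and prefix are illustrative, and the prefix must exist
+# as a directory): a URI with a path component yields a SubTreeFileSystem
+# rooted at that directory:
+#
+#   fs = _filesystem_from_str("s3://my-bucket/nested/prefix")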
+
+
+def _ensure_filesystem(
+ filesystem, use_mmap=False, allow_legacy_filesystem=False
+):
+ if isinstance(filesystem, FileSystem):
+ return filesystem
+ elif isinstance(filesystem, str):
+ if use_mmap:
+ raise ValueError(
+                "Specifying memory mapping is not supported for a "
+                "filesystem specified as a URI string"
+ )
+ return _filesystem_from_str(filesystem)
+
+ # handle fsspec-compatible filesystems
+ try:
+ import fsspec
+ except ImportError:
+ pass
+ else:
+ if isinstance(filesystem, fsspec.AbstractFileSystem):
+ if type(filesystem).__name__ == 'LocalFileSystem':
+                # In case it's a simple LocalFileSystem, use the native
+                # Arrow one
+ return LocalFileSystem(use_mmap=use_mmap)
+ return PyFileSystem(FSSpecHandler(filesystem))
+
+ # map old filesystems to new ones
+ import pyarrow.filesystem as legacyfs
+
+ if isinstance(filesystem, legacyfs.LocalFileSystem):
+ return LocalFileSystem(use_mmap=use_mmap)
+ # TODO handle HDFS?
+ if allow_legacy_filesystem and isinstance(filesystem, legacyfs.FileSystem):
+ return filesystem
+
+ raise TypeError(
+        "Unrecognized filesystem: {}. `filesystem` argument must be a "
+        "FileSystem instance or a valid file system URI".format(
+ type(filesystem))
+ )
+
+
+def _resolve_filesystem_and_path(
+ path, filesystem=None, allow_legacy_filesystem=False
+):
+ """
+    Return filesystem/path from path, which could be a URI or a plain
+    filesystem path.
+ """
+ if not _is_path_like(path):
+ if filesystem is not None:
+ raise ValueError(
+ "'filesystem' passed but the specified path is file-like, so"
+ " there is nothing to open with 'filesystem'."
+ )
+ return filesystem, path
+
+ if filesystem is not None:
+ filesystem = _ensure_filesystem(
+ filesystem, allow_legacy_filesystem=allow_legacy_filesystem
+ )
+ if isinstance(filesystem, LocalFileSystem):
+ path = _stringify_path(path)
+ elif not isinstance(path, str):
+ raise TypeError(
+ "Expected string path; path-like objects are only allowed "
+ "with a local filesystem"
+ )
+ if not allow_legacy_filesystem:
+ path = filesystem.normalize_path(path)
+ return filesystem, path
+
+ path = _stringify_path(path)
+
+    # if filesystem is not given, try to automatically determine one:
+    # first check if the file exists as a local (relative) file path,
+    # if not, try to parse the path as a URI
+ filesystem = LocalFileSystem()
+ try:
+ file_info = filesystem.get_file_info(path)
+ except OSError:
+ file_info = None
+ exists_locally = False
+ else:
+ exists_locally = (file_info.type != FileType.NotFound)
+
+    # if the file or directory doesn't exist locally, then assume that
+    # the path is a URI describing the file system as well
+ if not exists_locally:
+ try:
+ filesystem, path = FileSystem.from_uri(path)
+ except ValueError as e:
+            # neither a URI nor a locally existing path, so assume that a
+            # local path was given and propagate a nicer file-not-found error
+            # instead of a more confusing scheme parsing error
+ if "empty scheme" not in str(e):
+ raise
+ else:
+ path = filesystem.normalize_path(path)
+
+ return filesystem, path
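+
+# Resolution sketch (URIs/paths are illustrative): strings may be plain
+# paths or URIs, while path-like objects are only accepted together with a
+# local filesystem:
+#
+#   fs, path = _resolve_filesystem_and_path("s3://bucket/key.parquet")
+#   fs, path = _resolve_filesystem_and_path("relative/local/file.parquet")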
+
+
+def copy_files(source, destination,
+ source_filesystem=None, destination_filesystem=None,
+ *, chunk_size=1024*1024, use_threads=True):
+ """
+ Copy files between FileSystems.
+
+    This function allows you to recursively copy directories of files from
+ one file system to another, such as from S3 to your local machine.
+
+ Parameters
+ ----------
+ source : string
+ Source file path or URI to a single file or directory.
+ If a directory, files will be copied recursively from this path.
+ destination : string
+ Destination file path or URI. If `source` is a file, `destination`
+ is also interpreted as the destination file (not directory).
+ Directories will be created as necessary.
+ source_filesystem : FileSystem, optional
+ Source filesystem, needs to be specified if `source` is not a URI,
+ otherwise inferred.
+ destination_filesystem : FileSystem, optional
+ Destination filesystem, needs to be specified if `destination` is not
+ a URI, otherwise inferred.
+ chunk_size : int, default 1MB
+ The maximum size of block to read before flushing to the
+ destination file. A larger chunk_size will use more memory while
+ copying but may help accommodate high latency FileSystems.
+ use_threads : bool, default True
+ Whether to use multiple threads to accelerate copying.
+
+ Examples
+ --------
+ Copy an S3 bucket's files to a local directory:
+
+ >>> copy_files("s3://your-bucket-name", "local-directory")
+
+ Using a FileSystem object:
+
+ >>> copy_files("your-bucket-name", "local-directory",
+ ... source_filesystem=S3FileSystem(...))
+
+ """
+ source_fs, source_path = _resolve_filesystem_and_path(
+ source, source_filesystem
+ )
+ destination_fs, destination_path = _resolve_filesystem_and_path(
+ destination, destination_filesystem
+ )
+
+ file_info = source_fs.get_file_info(source_path)
+ if file_info.type == FileType.Directory:
+ source_sel = FileSelector(source_path, recursive=True)
+ _copy_files_selector(source_fs, source_sel,
+ destination_fs, destination_path,
+ chunk_size, use_threads)
+ else:
+ _copy_files(source_fs, source_path,
+ destination_fs, destination_path,
+ chunk_size, use_threads)
+
+
+class FSSpecHandler(FileSystemHandler):
+ """
+ Handler for fsspec-based Python filesystems.
+
+ https://filesystem-spec.readthedocs.io/en/latest/index.html
+
+ Parameters
+ ----------
+    fs : fsspec.AbstractFileSystem
+        The FSSpec-compliant filesystem instance to wrap.
+
+ Examples
+ --------
+ >>> PyFileSystem(FSSpecHandler(fsspec_fs))
+ """
+
+ def __init__(self, fs):
+ self.fs = fs
+
+ def __eq__(self, other):
+ if isinstance(other, FSSpecHandler):
+ return self.fs == other.fs
+ return NotImplemented
+
+ def __ne__(self, other):
+ if isinstance(other, FSSpecHandler):
+ return self.fs != other.fs
+ return NotImplemented
+
+ def get_type_name(self):
+ protocol = self.fs.protocol
+ if isinstance(protocol, list):
+ protocol = protocol[0]
+ return "fsspec+{0}".format(protocol)
+
+ def normalize_path(self, path):
+ return path
+
+ @staticmethod
+ def _create_file_info(path, info):
+ size = info["size"]
+ if info["type"] == "file":
+ ftype = FileType.File
+ elif info["type"] == "directory":
+ ftype = FileType.Directory
+ # some fsspec filesystems include a file size for directories
+ size = None
+ else:
+ ftype = FileType.Unknown
+ return FileInfo(path, ftype, size=size, mtime=info.get("mtime", None))
+
+ def get_file_info(self, paths):
+ infos = []
+ for path in paths:
+ try:
+ info = self.fs.info(path)
+ except FileNotFoundError:
+ infos.append(FileInfo(path, FileType.NotFound))
+ else:
+ infos.append(self._create_file_info(path, info))
+ return infos
+
+ def get_file_info_selector(self, selector):
+ if not self.fs.isdir(selector.base_dir):
+ if self.fs.exists(selector.base_dir):
+ raise NotADirectoryError(selector.base_dir)
+ else:
+ if selector.allow_not_found:
+ return []
+ else:
+ raise FileNotFoundError(selector.base_dir)
+
+ if selector.recursive:
+ maxdepth = None
+ else:
+ maxdepth = 1
+
+ infos = []
+ selected_files = self.fs.find(
+ selector.base_dir, maxdepth=maxdepth, withdirs=True, detail=True
+ )
+ for path, info in selected_files.items():
+ infos.append(self._create_file_info(path, info))
+
+ return infos
+
+ def create_dir(self, path, recursive):
+ # mkdir also raises FileNotFoundError when base directory is not found
+ try:
+ self.fs.mkdir(path, create_parents=recursive)
+ except FileExistsError:
+ pass
+
+ def delete_dir(self, path):
+ self.fs.rm(path, recursive=True)
+
+ def _delete_dir_contents(self, path):
+ for subpath in self.fs.listdir(path, detail=False):
+ if self.fs.isdir(subpath):
+ self.fs.rm(subpath, recursive=True)
+ elif self.fs.isfile(subpath):
+ self.fs.rm(subpath)
+
+ def delete_dir_contents(self, path):
+ if path.strip("/") == "":
+            raise ValueError(
+                "delete_dir_contents called on path '{}'".format(path))
+ self._delete_dir_contents(path)
+
+ def delete_root_dir_contents(self):
+ self._delete_dir_contents("/")
+
+ def delete_file(self, path):
+ # fs.rm correctly raises IsADirectoryError when `path` is a directory
+ # instead of a file and `recursive` is not set to True
+ if not self.fs.exists(path):
+ raise FileNotFoundError(path)
+ self.fs.rm(path)
+
+ def move(self, src, dest):
+ self.fs.mv(src, dest, recursive=True)
+
+ def copy_file(self, src, dest):
+ # fs.copy correctly raises IsADirectoryError when `src` is a directory
+ # instead of a file
+ self.fs.copy(src, dest)
+
+ # TODO can we read/pass metadata (e.g. Content-Type) in the methods below?
+
+ def open_input_stream(self, path):
+ from pyarrow import PythonFile
+
+ if not self.fs.isfile(path):
+ raise FileNotFoundError(path)
+
+ return PythonFile(self.fs.open(path, mode="rb"), mode="r")
+
+ def open_input_file(self, path):
+ from pyarrow import PythonFile
+
+ if not self.fs.isfile(path):
+ raise FileNotFoundError(path)
+
+ return PythonFile(self.fs.open(path, mode="rb"), mode="r")
+
+ def open_output_stream(self, path, metadata):
+ from pyarrow import PythonFile
+
+ return PythonFile(self.fs.open(path, mode="wb"), mode="w")
+
+ def open_append_stream(self, path, metadata):
+ from pyarrow import PythonFile
+
+ return PythonFile(self.fs.open(path, mode="ab"), mode="w")
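+
+# Wrapping sketch (assumes fsspec is installed; the "memory" protocol and
+# path are illustrative):
+#
+#   import fsspec
+#   from pyarrow.fs import PyFileSystem
+#
+#   fsspec_fs = fsspec.filesystem("memory")
+#   arrow_fs = PyFileSystem(FSSpecHandler(fsspec_fs))
+#   with arrow_fs.open_output_stream("example") as out:
+#       out.write(b"hello")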
diff --git a/src/arrow/python/pyarrow/gandiva.pyx b/src/arrow/python/pyarrow/gandiva.pyx
new file mode 100644
index 000000000..12d572b33
--- /dev/null
+++ b/src/arrow/python/pyarrow/gandiva.pyx
@@ -0,0 +1,518 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: profile=False
+# distutils: language = c++
+# cython: language_level = 3
+
+from libcpp cimport bool as c_bool, nullptr
+from libcpp.memory cimport shared_ptr, unique_ptr, make_shared
+from libcpp.string cimport string as c_string
+from libcpp.vector cimport vector as c_vector
+from libcpp.unordered_set cimport unordered_set as c_unordered_set
+from libc.stdint cimport int64_t, int32_t, uint8_t, uintptr_t
+
+from pyarrow.includes.libarrow cimport *
+from pyarrow.lib cimport (Array, DataType, Field, MemoryPool, RecordBatch,
+ Schema, check_status, pyarrow_wrap_array,
+ pyarrow_wrap_data_type, ensure_type, _Weakrefable,
+ pyarrow_wrap_field)
+from pyarrow.lib import frombytes
+
+from pyarrow.includes.libgandiva cimport (
+ CCondition, CExpression,
+ CNode, CProjector, CFilter,
+ CSelectionVector,
+ CSelectionVector_Mode,
+ _ensure_selection_mode,
+ CConfiguration,
+ CConfigurationBuilder,
+ TreeExprBuilder_MakeExpression,
+ TreeExprBuilder_MakeFunction,
+ TreeExprBuilder_MakeBoolLiteral,
+ TreeExprBuilder_MakeUInt8Literal,
+ TreeExprBuilder_MakeUInt16Literal,
+ TreeExprBuilder_MakeUInt32Literal,
+ TreeExprBuilder_MakeUInt64Literal,
+ TreeExprBuilder_MakeInt8Literal,
+ TreeExprBuilder_MakeInt16Literal,
+ TreeExprBuilder_MakeInt32Literal,
+ TreeExprBuilder_MakeInt64Literal,
+ TreeExprBuilder_MakeFloatLiteral,
+ TreeExprBuilder_MakeDoubleLiteral,
+ TreeExprBuilder_MakeStringLiteral,
+ TreeExprBuilder_MakeBinaryLiteral,
+ TreeExprBuilder_MakeField,
+ TreeExprBuilder_MakeIf,
+ TreeExprBuilder_MakeAnd,
+ TreeExprBuilder_MakeOr,
+ TreeExprBuilder_MakeCondition,
+ TreeExprBuilder_MakeInExpressionInt32,
+ TreeExprBuilder_MakeInExpressionInt64,
+ TreeExprBuilder_MakeInExpressionTime32,
+ TreeExprBuilder_MakeInExpressionTime64,
+ TreeExprBuilder_MakeInExpressionDate32,
+ TreeExprBuilder_MakeInExpressionDate64,
+ TreeExprBuilder_MakeInExpressionTimeStamp,
+ TreeExprBuilder_MakeInExpressionString,
+ TreeExprBuilder_MakeInExpressionBinary,
+ SelectionVector_MakeInt16,
+ SelectionVector_MakeInt32,
+ SelectionVector_MakeInt64,
+ Projector_Make,
+ Filter_Make,
+ CFunctionSignature,
+ GetRegisteredFunctionSignatures)
+
+cdef class Node(_Weakrefable):
+ cdef:
+ shared_ptr[CNode] node
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use the "
+                        "TreeExprBuilder API instead"
+ .format(self.__class__.__name__))
+
+ @staticmethod
+ cdef create(shared_ptr[CNode] node):
+ cdef Node self = Node.__new__(Node)
+ self.node = node
+ return self
+
+ def __str__(self):
+ return self.node.get().ToString().decode()
+
+ def __repr__(self):
+ type_format = object.__repr__(self)
+ return '{0}\n{1}'.format(type_format, str(self))
+
+ def return_type(self):
+ return pyarrow_wrap_data_type(self.node.get().return_type())
+
+cdef class Expression(_Weakrefable):
+ cdef:
+ shared_ptr[CExpression] expression
+
+ cdef void init(self, shared_ptr[CExpression] expression):
+ self.expression = expression
+
+ def __str__(self):
+ return self.expression.get().ToString().decode()
+
+ def __repr__(self):
+ type_format = object.__repr__(self)
+ return '{0}\n{1}'.format(type_format, str(self))
+
+ def root(self):
+ return Node.create(self.expression.get().root())
+
+ def result(self):
+ return pyarrow_wrap_field(self.expression.get().result())
+
+cdef class Condition(_Weakrefable):
+ cdef:
+ shared_ptr[CCondition] condition
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use the "
+ "TreeExprBuilder API instead"
+ .format(self.__class__.__name__))
+
+ @staticmethod
+ cdef create(shared_ptr[CCondition] condition):
+ cdef Condition self = Condition.__new__(Condition)
+ self.condition = condition
+ return self
+
+ def __str__(self):
+ return self.condition.get().ToString().decode()
+
+ def __repr__(self):
+ type_format = object.__repr__(self)
+ return '{0}\n{1}'.format(type_format, str(self))
+
+ def root(self):
+ return Node.create(self.condition.get().root())
+
+ def result(self):
+ return pyarrow_wrap_field(self.condition.get().result())
+
+cdef class SelectionVector(_Weakrefable):
+ cdef:
+ shared_ptr[CSelectionVector] selection_vector
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly."
+ .format(self.__class__.__name__))
+
+ @staticmethod
+ cdef create(shared_ptr[CSelectionVector] selection_vector):
+ cdef SelectionVector self = SelectionVector.__new__(SelectionVector)
+ self.selection_vector = selection_vector
+ return self
+
+ def to_array(self):
+ cdef shared_ptr[CArray] result = self.selection_vector.get().ToArray()
+ return pyarrow_wrap_array(result)
+
+cdef class Projector(_Weakrefable):
+ cdef:
+ shared_ptr[CProjector] projector
+ MemoryPool pool
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use "
+ "make_projector instead"
+ .format(self.__class__.__name__))
+
+ @staticmethod
+ cdef create(shared_ptr[CProjector] projector, MemoryPool pool):
+ cdef Projector self = Projector.__new__(Projector)
+ self.projector = projector
+ self.pool = pool
+ return self
+
+ @property
+ def llvm_ir(self):
+ return self.projector.get().DumpIR().decode()
+
+ def evaluate(self, RecordBatch batch, SelectionVector selection=None):
+ cdef vector[shared_ptr[CArray]] results
+ if selection is None:
+ check_status(self.projector.get().Evaluate(
+ batch.sp_batch.get()[0], self.pool.pool, &results))
+ else:
+ check_status(
+ self.projector.get().Evaluate(
+ batch.sp_batch.get()[0], selection.selection_vector.get(),
+ self.pool.pool, &results))
+ cdef shared_ptr[CArray] result
+ arrays = []
+ for result in results:
+ arrays.append(pyarrow_wrap_array(result))
+ return arrays
+
+cdef class Filter(_Weakrefable):
+ cdef:
+ shared_ptr[CFilter] filter
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use "
+ "make_filter instead"
+ .format(self.__class__.__name__))
+
+ @staticmethod
+ cdef create(shared_ptr[CFilter] filter):
+ cdef Filter self = Filter.__new__(Filter)
+ self.filter = filter
+ return self
+
+ @property
+ def llvm_ir(self):
+ return self.filter.get().DumpIR().decode()
+
+ def evaluate(self, RecordBatch batch, MemoryPool pool, dtype='int32'):
+ cdef:
+ DataType type = ensure_type(dtype)
+ shared_ptr[CSelectionVector] selection
+
+ if type.id == _Type_INT16:
+ check_status(SelectionVector_MakeInt16(
+ batch.num_rows, pool.pool, &selection))
+ elif type.id == _Type_INT32:
+ check_status(SelectionVector_MakeInt32(
+ batch.num_rows, pool.pool, &selection))
+ elif type.id == _Type_INT64:
+ check_status(SelectionVector_MakeInt64(
+ batch.num_rows, pool.pool, &selection))
+ else:
+ raise ValueError("'dtype' of the selection vector should be "
+ "one of 'int16', 'int32' and 'int64'.")
+
+ check_status(self.filter.get().Evaluate(
+ batch.sp_batch.get()[0], selection))
+ return SelectionVector.create(selection)
+
+
+cdef class TreeExprBuilder(_Weakrefable):
+
+ def make_literal(self, value, dtype):
+ cdef:
+ DataType type = ensure_type(dtype)
+ shared_ptr[CNode] r
+
+ if type.id == _Type_BOOL:
+ r = TreeExprBuilder_MakeBoolLiteral(value)
+ elif type.id == _Type_UINT8:
+ r = TreeExprBuilder_MakeUInt8Literal(value)
+ elif type.id == _Type_UINT16:
+ r = TreeExprBuilder_MakeUInt16Literal(value)
+ elif type.id == _Type_UINT32:
+ r = TreeExprBuilder_MakeUInt32Literal(value)
+ elif type.id == _Type_UINT64:
+ r = TreeExprBuilder_MakeUInt64Literal(value)
+ elif type.id == _Type_INT8:
+ r = TreeExprBuilder_MakeInt8Literal(value)
+ elif type.id == _Type_INT16:
+ r = TreeExprBuilder_MakeInt16Literal(value)
+ elif type.id == _Type_INT32:
+ r = TreeExprBuilder_MakeInt32Literal(value)
+ elif type.id == _Type_INT64:
+ r = TreeExprBuilder_MakeInt64Literal(value)
+ elif type.id == _Type_FLOAT:
+ r = TreeExprBuilder_MakeFloatLiteral(value)
+ elif type.id == _Type_DOUBLE:
+ r = TreeExprBuilder_MakeDoubleLiteral(value)
+ elif type.id == _Type_STRING:
+ r = TreeExprBuilder_MakeStringLiteral(value.encode('UTF-8'))
+ elif type.id == _Type_BINARY:
+ r = TreeExprBuilder_MakeBinaryLiteral(value)
+ else:
+ raise TypeError("Didn't recognize dtype " + str(dtype))
+
+ return Node.create(r)
+
+ def make_expression(self, Node root_node, Field return_field):
+ cdef shared_ptr[CExpression] r = TreeExprBuilder_MakeExpression(
+ root_node.node, return_field.sp_field)
+ cdef Expression expression = Expression()
+ expression.init(r)
+ return expression
+
+ def make_function(self, name, children, DataType return_type):
+ cdef c_vector[shared_ptr[CNode]] c_children
+ cdef Node child
+ for child in children:
+ c_children.push_back(child.node)
+ cdef shared_ptr[CNode] r = TreeExprBuilder_MakeFunction(
+ name.encode(), c_children, return_type.sp_type)
+ return Node.create(r)
+
+ def make_field(self, Field field):
+ cdef shared_ptr[CNode] r = TreeExprBuilder_MakeField(field.sp_field)
+ return Node.create(r)
+
+ def make_if(self, Node condition, Node this_node,
+ Node else_node, DataType return_type):
+ cdef shared_ptr[CNode] r = TreeExprBuilder_MakeIf(
+ condition.node, this_node.node, else_node.node,
+ return_type.sp_type)
+ return Node.create(r)
+
+ def make_and(self, children):
+ cdef c_vector[shared_ptr[CNode]] c_children
+ cdef Node child
+ for child in children:
+ c_children.push_back(child.node)
+ cdef shared_ptr[CNode] r = TreeExprBuilder_MakeAnd(c_children)
+ return Node.create(r)
+
+ def make_or(self, children):
+ cdef c_vector[shared_ptr[CNode]] c_children
+ cdef Node child
+ for child in children:
+ c_children.push_back(child.node)
+ cdef shared_ptr[CNode] r = TreeExprBuilder_MakeOr(c_children)
+ return Node.create(r)
+
+ def _make_in_expression_int32(self, Node node, values):
+ cdef shared_ptr[CNode] r
+ cdef c_unordered_set[int32_t] c_values
+ cdef int32_t v
+ for v in values:
+ c_values.insert(v)
+ r = TreeExprBuilder_MakeInExpressionInt32(node.node, c_values)
+ return Node.create(r)
+
+ def _make_in_expression_int64(self, Node node, values):
+ cdef shared_ptr[CNode] r
+ cdef c_unordered_set[int64_t] c_values
+ cdef int64_t v
+ for v in values:
+ c_values.insert(v)
+ r = TreeExprBuilder_MakeInExpressionInt64(node.node, c_values)
+ return Node.create(r)
+
+ def _make_in_expression_time32(self, Node node, values):
+ cdef shared_ptr[CNode] r
+ cdef c_unordered_set[int32_t] c_values
+ cdef int32_t v
+ for v in values:
+ c_values.insert(v)
+ r = TreeExprBuilder_MakeInExpressionTime32(node.node, c_values)
+ return Node.create(r)
+
+ def _make_in_expression_time64(self, Node node, values):
+ cdef shared_ptr[CNode] r
+ cdef c_unordered_set[int64_t] c_values
+ cdef int64_t v
+ for v in values:
+ c_values.insert(v)
+ r = TreeExprBuilder_MakeInExpressionTime64(node.node, c_values)
+ return Node.create(r)
+
+ def _make_in_expression_date32(self, Node node, values):
+ cdef shared_ptr[CNode] r
+ cdef c_unordered_set[int32_t] c_values
+ cdef int32_t v
+ for v in values:
+ c_values.insert(v)
+ r = TreeExprBuilder_MakeInExpressionDate32(node.node, c_values)
+ return Node.create(r)
+
+ def _make_in_expression_date64(self, Node node, values):
+ cdef shared_ptr[CNode] r
+ cdef c_unordered_set[int64_t] c_values
+ cdef int64_t v
+ for v in values:
+ c_values.insert(v)
+ r = TreeExprBuilder_MakeInExpressionDate64(node.node, c_values)
+ return Node.create(r)
+
+ def _make_in_expression_timestamp(self, Node node, values):
+ cdef shared_ptr[CNode] r
+ cdef c_unordered_set[int64_t] c_values
+ cdef int64_t v
+ for v in values:
+ c_values.insert(v)
+ r = TreeExprBuilder_MakeInExpressionTimeStamp(node.node, c_values)
+ return Node.create(r)
+
+ def _make_in_expression_binary(self, Node node, values):
+ cdef shared_ptr[CNode] r
+ cdef c_unordered_set[c_string] c_values
+ cdef c_string v
+ for v in values:
+ c_values.insert(v)
+        r = TreeExprBuilder_MakeInExpressionBinary(node.node, c_values)
+ return Node.create(r)
+
+ def _make_in_expression_string(self, Node node, values):
+ cdef shared_ptr[CNode] r
+ cdef c_unordered_set[c_string] c_values
+ cdef c_string _v
+ for v in values:
+ _v = v.encode('UTF-8')
+ c_values.insert(_v)
+ r = TreeExprBuilder_MakeInExpressionString(node.node, c_values)
+ return Node.create(r)
+
+ def make_in_expression(self, Node node, values, dtype):
+ cdef DataType type = ensure_type(dtype)
+
+ if type.id == _Type_INT32:
+ return self._make_in_expression_int32(node, values)
+ elif type.id == _Type_INT64:
+ return self._make_in_expression_int64(node, values)
+ elif type.id == _Type_TIME32:
+ return self._make_in_expression_time32(node, values)
+ elif type.id == _Type_TIME64:
+ return self._make_in_expression_time64(node, values)
+ elif type.id == _Type_TIMESTAMP:
+ return self._make_in_expression_timestamp(node, values)
+ elif type.id == _Type_DATE32:
+ return self._make_in_expression_date32(node, values)
+ elif type.id == _Type_DATE64:
+ return self._make_in_expression_date64(node, values)
+ elif type.id == _Type_BINARY:
+ return self._make_in_expression_binary(node, values)
+ elif type.id == _Type_STRING:
+ return self._make_in_expression_string(node, values)
+ else:
+ raise TypeError("Data type " + str(dtype) + " not supported.")
+
+ def make_condition(self, Node condition):
+ cdef shared_ptr[CCondition] r = TreeExprBuilder_MakeCondition(
+ condition.node)
+ return Condition.create(r)
+
+cpdef make_projector(Schema schema, children, MemoryPool pool,
+ str selection_mode="NONE"):
+ cdef c_vector[shared_ptr[CExpression]] c_children
+ cdef Expression child
+ for child in children:
+ c_children.push_back(child.expression)
+ cdef shared_ptr[CProjector] result
+ check_status(
+ Projector_Make(schema.sp_schema, c_children,
+ _ensure_selection_mode(selection_mode),
+ CConfigurationBuilder.DefaultConfiguration(),
+ &result))
+ return Projector.create(result, pool)
+
+cpdef make_filter(Schema schema, Condition condition):
+ cdef shared_ptr[CFilter] result
+ check_status(
+ Filter_Make(schema.sp_schema, condition.condition, &result))
+ return Filter.create(result)
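+
+# Usage sketch (illustrative; assumes the Gandiva registry provides an
+# add(int64, int64) -> int64 function, as in the Arrow documentation):
+#
+#   import pyarrow as pa
+#   import pyarrow.gandiva as gandiva
+#
+#   table = pa.table({'a': [1, 2, 3], 'b': [4, 5, 6]})
+#   builder = gandiva.TreeExprBuilder()
+#   node_a = builder.make_field(table.schema.field('a'))
+#   node_b = builder.make_field(table.schema.field('b'))
+#   sum_node = builder.make_function('add', [node_a, node_b], pa.int64())
+#   expr = builder.make_expression(sum_node, pa.field('a_plus_b', pa.int64()))
+#   projector = gandiva.make_projector(
+#       table.schema, [expr], pa.default_memory_pool())
+#   (a_plus_b,) = projector.evaluate(table.to_batches()[0])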
+
+cdef class FunctionSignature(_Weakrefable):
+ """
+ Signature of a Gandiva function including name, parameter types
+ and return type.
+ """
+
+ cdef:
+ shared_ptr[CFunctionSignature] signature
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly."
+ .format(self.__class__.__name__))
+
+ @staticmethod
+ cdef create(shared_ptr[CFunctionSignature] signature):
+ cdef FunctionSignature self = FunctionSignature.__new__(
+ FunctionSignature)
+ self.signature = signature
+ return self
+
+ def return_type(self):
+ return pyarrow_wrap_data_type(self.signature.get().ret_type())
+
+ def param_types(self):
+ result = []
+ cdef vector[shared_ptr[CDataType]] types = \
+ self.signature.get().param_types()
+ for t in types:
+ result.append(pyarrow_wrap_data_type(t))
+ return result
+
+ def name(self):
+ return self.signature.get().base_name().decode()
+
+ def __repr__(self):
+ signature = self.signature.get().ToString().decode()
+ return "FunctionSignature(" + signature + ")"
+
+
+def get_registered_function_signatures():
+ """
+    Return the registered function signatures in Gandiva's
+    ExpressionRegistry.
+
+    Returns
+    -------
+    registry : list of FunctionSignature
+        The registered function signatures.
+ """
+ results = []
+
+ cdef vector[shared_ptr[CFunctionSignature]] signatures = \
+ GetRegisteredFunctionSignatures()
+
+ for signature in signatures:
+ results.append(FunctionSignature.create(signature))
+
+ return results
diff --git a/src/arrow/python/pyarrow/hdfs.py b/src/arrow/python/pyarrow/hdfs.py
new file mode 100644
index 000000000..56667bd5d
--- /dev/null
+++ b/src/arrow/python/pyarrow/hdfs.py
@@ -0,0 +1,240 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import os
+import posixpath
+import sys
+import warnings
+
+from pyarrow.util import implements, _DEPR_MSG
+from pyarrow.filesystem import FileSystem
+import pyarrow._hdfsio as _hdfsio
+
+
+class HadoopFileSystem(_hdfsio.HadoopFileSystem, FileSystem):
+ """
+ DEPRECATED: FileSystem interface for HDFS cluster.
+
+ See pyarrow.hdfs.connect for full connection details
+
+ .. deprecated:: 2.0
+ ``pyarrow.hdfs.HadoopFileSystem`` is deprecated,
+ please use ``pyarrow.fs.HadoopFileSystem`` instead.
+ """
+
+ def __init__(self, host="default", port=0, user=None, kerb_ticket=None,
+ driver='libhdfs', extra_conf=None):
+ warnings.warn(
+ _DEPR_MSG.format(
+ "hdfs.HadoopFileSystem", "2.0.0", "fs.HadoopFileSystem"),
+ FutureWarning, stacklevel=2)
+ if driver == 'libhdfs':
+ _maybe_set_hadoop_classpath()
+
+ self._connect(host, port, user, kerb_ticket, extra_conf)
+
+ def __reduce__(self):
+ return (HadoopFileSystem, (self.host, self.port, self.user,
+ self.kerb_ticket, self.extra_conf))
+
+ def _isfilestore(self):
+ """
+ Return True if this is a Unix-style file store with directories.
+ """
+ return True
+
+ @implements(FileSystem.isdir)
+ def isdir(self, path):
+ return super().isdir(path)
+
+ @implements(FileSystem.isfile)
+ def isfile(self, path):
+ return super().isfile(path)
+
+ @implements(FileSystem.delete)
+ def delete(self, path, recursive=False):
+ return super().delete(path, recursive)
+
+ def mkdir(self, path, **kwargs):
+ """
+ Create directory in HDFS.
+
+ Parameters
+ ----------
+ path : str
+ Directory path to create, including any parent directories.
+
+ Notes
+ -----
+ libhdfs does not support create_parents=False, so we ignore this here
+ """
+ return super().mkdir(path)
+
+ @implements(FileSystem.rename)
+ def rename(self, path, new_path):
+ return super().rename(path, new_path)
+
+ @implements(FileSystem.exists)
+ def exists(self, path):
+ return super().exists(path)
+
+ def ls(self, path, detail=False):
+ """
+ Retrieve directory contents and metadata, if requested.
+
+ Parameters
+ ----------
+ path : str
+ HDFS path to retrieve contents of.
+ detail : bool, default False
+ If False, only return list of paths.
+
+ Returns
+ -------
+ result : list of dicts (detail=True) or strings (detail=False)
+ """
+ return super().ls(path, detail)
+
+ def walk(self, top_path):
+ """
+ Directory tree generator for HDFS, like os.walk.
+
+ Parameters
+ ----------
+ top_path : str
+ Root directory for tree traversal.
+
+ Returns
+ -------
+        Generator yielding 3-tuples of (dirpath, dirnames, filenames)
+ """
+ contents = self.ls(top_path, detail=True)
+
+ directories, files = _libhdfs_walk_files_dirs(top_path, contents)
+ yield top_path, directories, files
+ for dirname in directories:
+ yield from self.walk(self._path_join(top_path, dirname))
+
+
+def _maybe_set_hadoop_classpath():
+ import re
+
+ if re.search(r'hadoop-common[^/]+.jar', os.environ.get('CLASSPATH', '')):
+ return
+
+ if 'HADOOP_HOME' in os.environ:
+ if sys.platform != 'win32':
+ classpath = _derive_hadoop_classpath()
+ else:
+ hadoop_bin = '{}/bin/hadoop'.format(os.environ['HADOOP_HOME'])
+ classpath = _hadoop_classpath_glob(hadoop_bin)
+ else:
+ classpath = _hadoop_classpath_glob('hadoop')
+
+ os.environ['CLASSPATH'] = classpath.decode('utf-8')
+
+
+def _derive_hadoop_classpath():
+ import subprocess
+
+ find_args = ('find', '-L', os.environ['HADOOP_HOME'], '-name', '*.jar')
+ find = subprocess.Popen(find_args, stdout=subprocess.PIPE)
+ xargs_echo = subprocess.Popen(('xargs', 'echo'),
+ stdin=find.stdout,
+ stdout=subprocess.PIPE)
+ jars = subprocess.check_output(('tr', "' '", "':'"),
+ stdin=xargs_echo.stdout)
+ hadoop_conf = os.environ["HADOOP_CONF_DIR"] \
+ if "HADOOP_CONF_DIR" in os.environ \
+ else os.environ["HADOOP_HOME"] + "/etc/hadoop"
+ return (hadoop_conf + ":").encode("utf-8") + jars
+
+
+def _hadoop_classpath_glob(hadoop_bin):
+ import subprocess
+
+ hadoop_classpath_args = (hadoop_bin, 'classpath', '--glob')
+ return subprocess.check_output(hadoop_classpath_args)
+
+
+def _libhdfs_walk_files_dirs(top_path, contents):
+ files = []
+ directories = []
+ for c in contents:
+ scrubbed_name = posixpath.split(c['name'])[1]
+ if c['kind'] == 'file':
+ files.append(scrubbed_name)
+ else:
+ directories.append(scrubbed_name)
+
+ return directories, files
+
+
+def connect(host="default", port=0, user=None, kerb_ticket=None,
+ extra_conf=None):
+ """
+ DEPRECATED: Connect to an HDFS cluster.
+
+ All parameters are optional and should only be set if the defaults need
+ to be overridden.
+
+ Authentication should be automatic if the HDFS cluster uses Kerberos.
+ However, if a username is specified, then the ticket cache will likely
+ be required.
+
+ .. deprecated:: 2.0
+ ``pyarrow.hdfs.connect`` is deprecated,
+ please use ``pyarrow.fs.HadoopFileSystem`` instead.
+
+ Parameters
+ ----------
+    host : str, default "default"
+        NameNode hostname. Set to "default" to use fs.defaultFS from
+        core-site.xml.
+    port : int, default 0
+        NameNode port. Set to 0 for the default or for logical (HA) nodes.
+    user : str, default None
+        Username to use when connecting to HDFS; None implies the login user.
+    kerb_ticket : str, default None
+        Path to the Kerberos ticket cache.
+    extra_conf : dict, default None
+        Extra key/value pairs for configuration; will override any
+        hdfs-site.xml properties.
+
+ Notes
+ -----
+ The first time you call this method, it will take longer than usual due
+ to JNI spin-up time.
+
+ Returns
+ -------
+ filesystem : HadoopFileSystem
+ """
+ warnings.warn(
+ _DEPR_MSG.format("hdfs.connect", "2.0.0", "fs.HadoopFileSystem"),
+ FutureWarning, stacklevel=2
+ )
+ return _connect(
+ host=host, port=port, user=user, kerb_ticket=kerb_ticket,
+ extra_conf=extra_conf
+ )
+
+
+def _connect(host="default", port=0, user=None, kerb_ticket=None,
+ extra_conf=None):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ fs = HadoopFileSystem(host=host, port=port, user=user,
+ kerb_ticket=kerb_ticket,
+ extra_conf=extra_conf)
+ return fs
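+
+# Migration sketch (host and path are illustrative; requires a configured
+# libhdfs): the deprecated helpers above map onto pyarrow.fs:
+#
+#   from pyarrow import fs
+#
+#   hdfs = fs.HadoopFileSystem("default")
+#   listing = hdfs.get_file_info(fs.FileSelector("/user/me", recursive=True))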
diff --git a/src/arrow/python/pyarrow/includes/__init__.pxd b/src/arrow/python/pyarrow/includes/__init__.pxd
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/__init__.pxd
diff --git a/src/arrow/python/pyarrow/includes/common.pxd b/src/arrow/python/pyarrow/includes/common.pxd
new file mode 100644
index 000000000..902eaafbb
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/common.pxd
@@ -0,0 +1,138 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from libc.stdint cimport *
+from libcpp cimport bool as c_bool, nullptr
+from libcpp.functional cimport function
+from libcpp.memory cimport shared_ptr, unique_ptr, make_shared
+from libcpp.string cimport string as c_string
+from libcpp.utility cimport pair
+from libcpp.vector cimport vector
+from libcpp.unordered_map cimport unordered_map
+from libcpp.unordered_set cimport unordered_set
+
+from cpython cimport PyObject
+from cpython.datetime cimport PyDateTime_DateTime
+cimport cpython
+
+
+cdef extern from * namespace "std" nogil:
+ cdef shared_ptr[T] static_pointer_cast[T, U](shared_ptr[U])
+
+# vendored from the cymove project https://github.com/ozars/cymove
+cdef extern from * namespace "cymove" nogil:
+ """
+ #include <type_traits>
+ #include <utility>
+ namespace cymove {
+ template <typename T>
+ inline typename std::remove_reference<T>::type&& cymove(T& t) {
+ return std::move(t);
+ }
+ template <typename T>
+ inline typename std::remove_reference<T>::type&& cymove(T&& t) {
+ return std::move(t);
+ }
+ } // namespace cymove
+ """
+ cdef T move" cymove::cymove"[T](T)
+
+cdef extern from * namespace "arrow::py" nogil:
+ """
+ #include <memory>
+ #include <utility>
+
+ namespace arrow {
+ namespace py {
+ template <typename T>
+ std::shared_ptr<T> to_shared(std::unique_ptr<T>& t) {
+ return std::move(t);
+ }
+ template <typename T>
+ std::shared_ptr<T> to_shared(std::unique_ptr<T>&& t) {
+ return std::move(t);
+ }
+ } // namespace py
+ } // namespace arrow
+ """
+ cdef shared_ptr[T] to_shared" arrow::py::to_shared"[T](unique_ptr[T])
+
+cdef extern from "arrow/python/platform.h":
+ pass
+
+cdef extern from "<Python.h>":
+ void Py_XDECREF(PyObject* o)
+ Py_ssize_t Py_REFCNT(PyObject* o)
+
+cdef extern from "numpy/halffloat.h":
+ ctypedef uint16_t npy_half
+
+cdef extern from "arrow/api.h" namespace "arrow" nogil:
+ # We can later add more of the common status factory methods as needed
+ cdef CStatus CStatus_OK "arrow::Status::OK"()
+
+ cdef CStatus CStatus_Invalid "arrow::Status::Invalid"()
+ cdef CStatus CStatus_NotImplemented \
+ "arrow::Status::NotImplemented"(const c_string& msg)
+ cdef CStatus CStatus_UnknownError \
+ "arrow::Status::UnknownError"(const c_string& msg)
+
+ cdef cppclass CStatus "arrow::Status":
+ CStatus()
+
+ c_string ToString()
+ c_string message()
+ shared_ptr[CStatusDetail] detail()
+
+ c_bool ok()
+ c_bool IsIOError()
+ c_bool IsOutOfMemory()
+ c_bool IsInvalid()
+ c_bool IsKeyError()
+ c_bool IsNotImplemented()
+ c_bool IsTypeError()
+ c_bool IsCapacityError()
+ c_bool IsIndexError()
+ c_bool IsSerializationError()
+ c_bool IsCancelled()
+
+ cdef cppclass CStatusDetail "arrow::StatusDetail":
+ c_string ToString()
+
+
+cdef extern from "arrow/result.h" namespace "arrow" nogil:
+ cdef cppclass CResult "arrow::Result"[T]:
+ CResult()
+ CResult(CStatus)
+ CResult(T)
+ c_bool ok()
+ CStatus status()
+ T operator*()
+
+
+cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil:
+ T GetResultValue[T](CResult[T]) except *
+ cdef function[F] BindFunction[F](void* unbound, object bound, ...)
+
+
+cdef inline object PyObject_to_object(PyObject* o):
+ # Cast to "object" increments reference count
+ cdef object result = <object> o
+ cpython.Py_DECREF(result)
+ return result
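On the Python side, the Status and Result declarations above surface as exceptions rather than return codes; a small sketch of what that looks like through pyarrow's public API (exception classes as exported at the top level of pyarrow):

    import pyarrow as pa

    try:
        # A non-OK arrow::Status from the C++ layer is raised as an exception.
        pa.array(["not a number"]).cast(pa.int32())
    except pa.ArrowInvalid as exc:
        print(exc)
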
diff --git a/src/arrow/python/pyarrow/includes/libarrow.pxd b/src/arrow/python/pyarrow/includes/libarrow.pxd
new file mode 100644
index 000000000..815238f11
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/libarrow.pxd
@@ -0,0 +1,2615 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from pyarrow.includes.common cimport *
+
+cdef extern from "arrow/util/key_value_metadata.h" namespace "arrow" nogil:
+ cdef cppclass CKeyValueMetadata" arrow::KeyValueMetadata":
+ CKeyValueMetadata()
+ CKeyValueMetadata(const unordered_map[c_string, c_string]&)
+ CKeyValueMetadata(const vector[c_string]& keys,
+ const vector[c_string]& values)
+
+ void reserve(int64_t n)
+ int64_t size() const
+ c_string key(int64_t i) const
+ c_string value(int64_t i) const
+ int FindKey(const c_string& key) const
+
+ shared_ptr[CKeyValueMetadata] Copy() const
+ c_bool Equals(const CKeyValueMetadata& other)
+ void Append(const c_string& key, const c_string& value)
+ void ToUnorderedMap(unordered_map[c_string, c_string]*) const
+ c_string ToString() const
+
+ CResult[c_string] Get(const c_string& key) const
+ CStatus Delete(const c_string& key)
+ CStatus Set(const c_string& key, const c_string& value)
+ c_bool Contains(const c_string& key) const
+
+
+cdef extern from "arrow/util/decimal.h" namespace "arrow" nogil:
+ cdef cppclass CDecimal128" arrow::Decimal128":
+ c_string ToString(int32_t scale) const
+
+
+cdef extern from "arrow/util/decimal.h" namespace "arrow" nogil:
+ cdef cppclass CDecimal256" arrow::Decimal256":
+ c_string ToString(int32_t scale) const
+
+
+cdef extern from "arrow/config.h" namespace "arrow" nogil:
+ cdef cppclass CBuildInfo" arrow::BuildInfo":
+ int version
+ int version_major
+ int version_minor
+ int version_patch
+ c_string version_string
+ c_string so_version
+ c_string full_so_version
+ c_string compiler_id
+ c_string compiler_version
+ c_string compiler_flags
+ c_string git_id
+ c_string git_description
+ c_string package_kind
+
+ const CBuildInfo& GetBuildInfo()
+
+ cdef cppclass CRuntimeInfo" arrow::RuntimeInfo":
+ c_string simd_level
+ c_string detected_simd_level
+
+ CRuntimeInfo GetRuntimeInfo()
+
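These build- and runtime-info structs are reachable from pyarrow's top-level API; a brief sketch (attribute names as in recent pyarrow releases, so treat them as indicative):

    import pyarrow as pa

    info = pa.cpp_build_info              # wraps arrow::GetBuildInfo()
    print(info.version, info.compiler_id)
    print(pa.runtime_info().simd_level)   # wraps arrow::GetRuntimeInfo()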
+
+cdef extern from "arrow/api.h" namespace "arrow" nogil:
+ cdef enum Type" arrow::Type::type":
+ _Type_NA" arrow::Type::NA"
+
+ _Type_BOOL" arrow::Type::BOOL"
+
+ _Type_UINT8" arrow::Type::UINT8"
+ _Type_INT8" arrow::Type::INT8"
+ _Type_UINT16" arrow::Type::UINT16"
+ _Type_INT16" arrow::Type::INT16"
+ _Type_UINT32" arrow::Type::UINT32"
+ _Type_INT32" arrow::Type::INT32"
+ _Type_UINT64" arrow::Type::UINT64"
+ _Type_INT64" arrow::Type::INT64"
+
+ _Type_HALF_FLOAT" arrow::Type::HALF_FLOAT"
+ _Type_FLOAT" arrow::Type::FLOAT"
+ _Type_DOUBLE" arrow::Type::DOUBLE"
+
+ _Type_DECIMAL128" arrow::Type::DECIMAL128"
+ _Type_DECIMAL256" arrow::Type::DECIMAL256"
+
+ _Type_DATE32" arrow::Type::DATE32"
+ _Type_DATE64" arrow::Type::DATE64"
+ _Type_TIMESTAMP" arrow::Type::TIMESTAMP"
+ _Type_TIME32" arrow::Type::TIME32"
+ _Type_TIME64" arrow::Type::TIME64"
+ _Type_DURATION" arrow::Type::DURATION"
+ _Type_INTERVAL_MONTH_DAY_NANO" arrow::Type::INTERVAL_MONTH_DAY_NANO"
+
+ _Type_BINARY" arrow::Type::BINARY"
+ _Type_STRING" arrow::Type::STRING"
+ _Type_LARGE_BINARY" arrow::Type::LARGE_BINARY"
+ _Type_LARGE_STRING" arrow::Type::LARGE_STRING"
+ _Type_FIXED_SIZE_BINARY" arrow::Type::FIXED_SIZE_BINARY"
+
+ _Type_LIST" arrow::Type::LIST"
+ _Type_LARGE_LIST" arrow::Type::LARGE_LIST"
+ _Type_FIXED_SIZE_LIST" arrow::Type::FIXED_SIZE_LIST"
+ _Type_STRUCT" arrow::Type::STRUCT"
+ _Type_SPARSE_UNION" arrow::Type::SPARSE_UNION"
+ _Type_DENSE_UNION" arrow::Type::DENSE_UNION"
+ _Type_DICTIONARY" arrow::Type::DICTIONARY"
+ _Type_MAP" arrow::Type::MAP"
+
+ _Type_EXTENSION" arrow::Type::EXTENSION"
+
+ cdef enum UnionMode" arrow::UnionMode::type":
+ _UnionMode_SPARSE" arrow::UnionMode::SPARSE"
+ _UnionMode_DENSE" arrow::UnionMode::DENSE"
+
+ cdef enum TimeUnit" arrow::TimeUnit::type":
+ TimeUnit_SECOND" arrow::TimeUnit::SECOND"
+ TimeUnit_MILLI" arrow::TimeUnit::MILLI"
+ TimeUnit_MICRO" arrow::TimeUnit::MICRO"
+ TimeUnit_NANO" arrow::TimeUnit::NANO"
+
+ cdef cppclass CBufferSpec" arrow::DataTypeLayout::BufferSpec":
+ pass
+
+ cdef cppclass CDataTypeLayout" arrow::DataTypeLayout":
+ vector[CBufferSpec] buffers
+ c_bool has_dictionary
+
+ cdef cppclass CDataType" arrow::DataType":
+ Type id()
+
+ c_bool Equals(const CDataType& other)
+ c_bool Equals(const shared_ptr[CDataType]& other)
+
+ shared_ptr[CField] field(int i)
+ const vector[shared_ptr[CField]] fields()
+ int num_fields()
+ CDataTypeLayout layout()
+ c_string ToString()
+
+ c_bool is_primitive(Type type)
+
+ cdef cppclass CArrayData" arrow::ArrayData":
+ shared_ptr[CDataType] type
+ int64_t length
+ int64_t null_count
+ int64_t offset
+ vector[shared_ptr[CBuffer]] buffers
+ vector[shared_ptr[CArrayData]] child_data
+ shared_ptr[CArrayData] dictionary
+
+ @staticmethod
+ shared_ptr[CArrayData] Make(const shared_ptr[CDataType]& type,
+ int64_t length,
+ vector[shared_ptr[CBuffer]]& buffers,
+ int64_t null_count,
+ int64_t offset)
+
+ @staticmethod
+ shared_ptr[CArrayData] MakeWithChildren" Make"(
+ const shared_ptr[CDataType]& type,
+ int64_t length,
+ vector[shared_ptr[CBuffer]]& buffers,
+ vector[shared_ptr[CArrayData]]& child_data,
+ int64_t null_count,
+ int64_t offset)
+
+ @staticmethod
+ shared_ptr[CArrayData] MakeWithChildrenAndDictionary" Make"(
+ const shared_ptr[CDataType]& type,
+ int64_t length,
+ vector[shared_ptr[CBuffer]]& buffers,
+ vector[shared_ptr[CArrayData]]& child_data,
+ shared_ptr[CArrayData]& dictionary,
+ int64_t null_count,
+ int64_t offset)
+
+ cdef cppclass CArray" arrow::Array":
+ shared_ptr[CDataType] type()
+
+ int64_t length()
+ int64_t null_count()
+ int64_t offset()
+ Type type_id()
+
+ int num_fields()
+
+ CResult[shared_ptr[CScalar]] GetScalar(int64_t i) const
+
+ c_string Diff(const CArray& other)
+ c_bool Equals(const CArray& arr)
+ c_bool IsNull(int i)
+
+ shared_ptr[CArrayData] data()
+
+ shared_ptr[CArray] Slice(int64_t offset)
+ shared_ptr[CArray] Slice(int64_t offset, int64_t length)
+
+ CStatus Validate() const
+ CStatus ValidateFull() const
+ CResult[shared_ptr[CArray]] View(const shared_ptr[CDataType]& type)
+
+ shared_ptr[CArray] MakeArray(const shared_ptr[CArrayData]& data)
+ CResult[shared_ptr[CArray]] MakeArrayOfNull(
+ const shared_ptr[CDataType]& type, int64_t length, CMemoryPool* pool)
+
+ CResult[shared_ptr[CArray]] MakeArrayFromScalar(
+ const CScalar& scalar, int64_t length, CMemoryPool* pool)
+
+ CStatus DebugPrint(const CArray& arr, int indent)
+
+ cdef cppclass CFixedWidthType" arrow::FixedWidthType"(CDataType):
+ int bit_width()
+
+ cdef cppclass CNullArray" arrow::NullArray"(CArray):
+ CNullArray(int64_t length)
+
+ cdef cppclass CDictionaryArray" arrow::DictionaryArray"(CArray):
+ CDictionaryArray(const shared_ptr[CDataType]& type,
+ const shared_ptr[CArray]& indices,
+ const shared_ptr[CArray]& dictionary)
+
+ @staticmethod
+ CResult[shared_ptr[CArray]] FromArrays(
+ const shared_ptr[CDataType]& type,
+ const shared_ptr[CArray]& indices,
+ const shared_ptr[CArray]& dictionary)
+
+ shared_ptr[CArray] indices()
+ shared_ptr[CArray] dictionary()
+
+ cdef cppclass CDate32Type" arrow::Date32Type"(CFixedWidthType):
+ pass
+
+ cdef cppclass CDate64Type" arrow::Date64Type"(CFixedWidthType):
+ pass
+
+ cdef cppclass CTimestampType" arrow::TimestampType"(CFixedWidthType):
+ CTimestampType(TimeUnit unit)
+ TimeUnit unit()
+ const c_string& timezone()
+
+ cdef cppclass CTime32Type" arrow::Time32Type"(CFixedWidthType):
+ TimeUnit unit()
+
+ cdef cppclass CTime64Type" arrow::Time64Type"(CFixedWidthType):
+ TimeUnit unit()
+
+ shared_ptr[CDataType] ctime32" arrow::time32"(TimeUnit unit)
+ shared_ptr[CDataType] ctime64" arrow::time64"(TimeUnit unit)
+
+ cdef cppclass CDurationType" arrow::DurationType"(CFixedWidthType):
+ TimeUnit unit()
+
+ shared_ptr[CDataType] cduration" arrow::duration"(TimeUnit unit)
+
+ cdef cppclass CDictionaryType" arrow::DictionaryType"(CFixedWidthType):
+ CDictionaryType(const shared_ptr[CDataType]& index_type,
+ const shared_ptr[CDataType]& value_type,
+ c_bool ordered)
+
+ shared_ptr[CDataType] index_type()
+ shared_ptr[CDataType] value_type()
+ c_bool ordered()
+
+ shared_ptr[CDataType] ctimestamp" arrow::timestamp"(TimeUnit unit)
+ shared_ptr[CDataType] ctimestamp" arrow::timestamp"(
+ TimeUnit unit, const c_string& timezone)
+
+ cdef cppclass CMemoryPool" arrow::MemoryPool":
+ int64_t bytes_allocated()
+ int64_t max_memory()
+ c_string backend_name()
+ void ReleaseUnused()
+
+ cdef cppclass CLoggingMemoryPool" arrow::LoggingMemoryPool"(CMemoryPool):
+ CLoggingMemoryPool(CMemoryPool*)
+
+ cdef cppclass CProxyMemoryPool" arrow::ProxyMemoryPool"(CMemoryPool):
+ CProxyMemoryPool(CMemoryPool*)
+
+ cdef cppclass CBuffer" arrow::Buffer":
+ CBuffer(const uint8_t* data, int64_t size)
+ const uint8_t* data()
+ uint8_t* mutable_data()
+ uintptr_t address()
+ uintptr_t mutable_address()
+ int64_t size()
+ shared_ptr[CBuffer] parent()
+ c_bool is_cpu() const
+ c_bool is_mutable() const
+ c_string ToHexString()
+ c_bool Equals(const CBuffer& other)
+
+ shared_ptr[CBuffer] SliceBuffer(const shared_ptr[CBuffer]& buffer,
+ int64_t offset, int64_t length)
+ shared_ptr[CBuffer] SliceBuffer(const shared_ptr[CBuffer]& buffer,
+ int64_t offset)
+
+ cdef cppclass CMutableBuffer" arrow::MutableBuffer"(CBuffer):
+ CMutableBuffer(const uint8_t* data, int64_t size)
+
+ cdef cppclass CResizableBuffer" arrow::ResizableBuffer"(CMutableBuffer):
+ CStatus Resize(const int64_t new_size, c_bool shrink_to_fit)
+ CStatus Reserve(const int64_t new_size)
+
+ CResult[unique_ptr[CBuffer]] AllocateBuffer(const int64_t size,
+ CMemoryPool* pool)
+
+ CResult[unique_ptr[CResizableBuffer]] AllocateResizableBuffer(
+ const int64_t size, CMemoryPool* pool)
+
+ cdef CMemoryPool* c_default_memory_pool" arrow::default_memory_pool"()
+ cdef CMemoryPool* c_system_memory_pool" arrow::system_memory_pool"()
+ cdef CStatus c_jemalloc_memory_pool" arrow::jemalloc_memory_pool"(
+ CMemoryPool** out)
+ cdef CStatus c_mimalloc_memory_pool" arrow::mimalloc_memory_pool"(
+ CMemoryPool** out)
+
+ CStatus c_jemalloc_set_decay_ms" arrow::jemalloc_set_decay_ms"(int ms)
+
+ cdef cppclass CListType" arrow::ListType"(CDataType):
+ CListType(const shared_ptr[CDataType]& value_type)
+ CListType(const shared_ptr[CField]& field)
+ shared_ptr[CDataType] value_type()
+ shared_ptr[CField] value_field()
+
+ cdef cppclass CLargeListType" arrow::LargeListType"(CDataType):
+ CLargeListType(const shared_ptr[CDataType]& value_type)
+ CLargeListType(const shared_ptr[CField]& field)
+ shared_ptr[CDataType] value_type()
+ shared_ptr[CField] value_field()
+
+ cdef cppclass CMapType" arrow::MapType"(CDataType):
+ CMapType(const shared_ptr[CField]& key_field,
+ const shared_ptr[CField]& item_field, c_bool keys_sorted)
+ shared_ptr[CDataType] key_type()
+ shared_ptr[CField] key_field()
+ shared_ptr[CDataType] item_type()
+ shared_ptr[CField] item_field()
+ c_bool keys_sorted()
+
+ cdef cppclass CFixedSizeListType" arrow::FixedSizeListType"(CDataType):
+ CFixedSizeListType(const shared_ptr[CDataType]& value_type,
+ int32_t list_size)
+ CFixedSizeListType(const shared_ptr[CField]& field, int32_t list_size)
+ shared_ptr[CDataType] value_type()
+ shared_ptr[CField] value_field()
+ int32_t list_size()
+
+ cdef cppclass CStringType" arrow::StringType"(CDataType):
+ pass
+
+ cdef cppclass CFixedSizeBinaryType \
+ " arrow::FixedSizeBinaryType"(CFixedWidthType):
+ CFixedSizeBinaryType(int byte_width)
+ int byte_width()
+ int bit_width()
+
+ cdef cppclass CDecimal128Type \
+ " arrow::Decimal128Type"(CFixedSizeBinaryType):
+ CDecimal128Type(int precision, int scale)
+ int precision()
+ int scale()
+
+ cdef cppclass CDecimal256Type \
+ " arrow::Decimal256Type"(CFixedSizeBinaryType):
+ CDecimal256Type(int precision, int scale)
+ int precision()
+ int scale()
+
+ cdef cppclass CField" arrow::Field":
+ cppclass CMergeOptions "arrow::Field::MergeOptions":
+ c_bool promote_nullability
+
+ @staticmethod
+ CMergeOptions Defaults()
+
+ const c_string& name()
+ shared_ptr[CDataType] type()
+ c_bool nullable()
+
+ c_string ToString()
+ c_bool Equals(const CField& other, c_bool check_metadata)
+
+ shared_ptr[const CKeyValueMetadata] metadata()
+
+ CField(const c_string& name, const shared_ptr[CDataType]& type,
+ c_bool nullable)
+
+ CField(const c_string& name, const shared_ptr[CDataType]& type,
+ c_bool nullable, const shared_ptr[CKeyValueMetadata]& metadata)
+
+ # const is dropped in this Cython declaration so the generated code compiles without casts
+ shared_ptr[CField] AddMetadata(
+ const shared_ptr[CKeyValueMetadata]& metadata)
+ shared_ptr[CField] WithMetadata(
+ const shared_ptr[CKeyValueMetadata]& metadata)
+ shared_ptr[CField] RemoveMetadata()
+ shared_ptr[CField] WithType(const shared_ptr[CDataType]& type)
+ shared_ptr[CField] WithName(const c_string& name)
+ shared_ptr[CField] WithNullable(c_bool nullable)
+ vector[shared_ptr[CField]] Flatten()
+
+ cdef cppclass CFieldRef" arrow::FieldRef":
+ CFieldRef()
+ CFieldRef(c_string name)
+ CFieldRef(int index)
+ const c_string* name() const
+
+ cdef cppclass CFieldRefHash" arrow::FieldRef::Hash":
+ pass
+
+ cdef cppclass CStructType" arrow::StructType"(CDataType):
+ CStructType(const vector[shared_ptr[CField]]& fields)
+
+ shared_ptr[CField] GetFieldByName(const c_string& name)
+ vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
+ int GetFieldIndex(const c_string& name)
+ vector[int] GetAllFieldIndices(const c_string& name)
+
+ cdef cppclass CUnionType" arrow::UnionType"(CDataType):
+ UnionMode mode()
+ const vector[int8_t]& type_codes()
+ const vector[int]& child_ids()
+
+ cdef shared_ptr[CDataType] CMakeSparseUnionType" arrow::sparse_union"(
+ vector[shared_ptr[CField]] fields,
+ vector[int8_t] type_codes)
+
+ cdef shared_ptr[CDataType] CMakeDenseUnionType" arrow::dense_union"(
+ vector[shared_ptr[CField]] fields,
+ vector[int8_t] type_codes)
+
+ cdef cppclass CSchema" arrow::Schema":
+ CSchema(const vector[shared_ptr[CField]]& fields)
+ CSchema(const vector[shared_ptr[CField]]& fields,
+ const shared_ptr[const CKeyValueMetadata]& metadata)
+
+ # Does not actually exist, but gets Cython to not complain
+ CSchema(const vector[shared_ptr[CField]]& fields,
+ const shared_ptr[CKeyValueMetadata]& metadata)
+
+ c_bool Equals(const CSchema& other, c_bool check_metadata)
+
+ shared_ptr[CField] field(int i)
+ shared_ptr[const CKeyValueMetadata] metadata()
+ shared_ptr[CField] GetFieldByName(const c_string& name)
+ vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
+ int GetFieldIndex(const c_string& name)
+ vector[int] GetAllFieldIndices(const c_string& name)
+ int num_fields()
+ c_string ToString()
+
+ CResult[shared_ptr[CSchema]] AddField(int i,
+ const shared_ptr[CField]& field)
+ CResult[shared_ptr[CSchema]] RemoveField(int i)
+ CResult[shared_ptr[CSchema]] SetField(int i,
+ const shared_ptr[CField]& field)
+
+ # const is dropped in this Cython declaration so the generated code compiles without casts
+ shared_ptr[CSchema] AddMetadata(
+ const shared_ptr[CKeyValueMetadata]& metadata)
+ shared_ptr[CSchema] WithMetadata(
+ const shared_ptr[CKeyValueMetadata]& metadata)
+ shared_ptr[CSchema] RemoveMetadata()
+
+ CResult[shared_ptr[CSchema]] UnifySchemas(
+ const vector[shared_ptr[CSchema]]& schemas)
+
+ cdef cppclass PrettyPrintOptions:
+ PrettyPrintOptions()
+ PrettyPrintOptions(int indent_arg)
+ PrettyPrintOptions(int indent_arg, int window_arg)
+ int indent
+ int indent_size
+ int window
+ c_string null_rep
+ c_bool skip_new_lines
+ c_bool truncate_metadata
+ c_bool show_field_metadata
+ c_bool show_schema_metadata
+
+ @staticmethod
+ PrettyPrintOptions Defaults()
+
+ CStatus PrettyPrint(const CArray& schema,
+ const PrettyPrintOptions& options,
+ c_string* result)
+ CStatus PrettyPrint(const CChunkedArray& schema,
+ const PrettyPrintOptions& options,
+ c_string* result)
+ CStatus PrettyPrint(const CSchema& schema,
+ const PrettyPrintOptions& options,
+ c_string* result)
+
+ cdef cppclass CBooleanArray" arrow::BooleanArray"(CArray):
+ c_bool Value(int i)
+ int64_t false_count()
+ int64_t true_count()
+
+ cdef cppclass CUInt8Array" arrow::UInt8Array"(CArray):
+ uint8_t Value(int i)
+
+ cdef cppclass CInt8Array" arrow::Int8Array"(CArray):
+ int8_t Value(int i)
+
+ cdef cppclass CUInt16Array" arrow::UInt16Array"(CArray):
+ uint16_t Value(int i)
+
+ cdef cppclass CInt16Array" arrow::Int16Array"(CArray):
+ int16_t Value(int i)
+
+ cdef cppclass CUInt32Array" arrow::UInt32Array"(CArray):
+ uint32_t Value(int i)
+
+ cdef cppclass CInt32Array" arrow::Int32Array"(CArray):
+ int32_t Value(int i)
+
+ cdef cppclass CUInt64Array" arrow::UInt64Array"(CArray):
+ uint64_t Value(int i)
+
+ cdef cppclass CInt64Array" arrow::Int64Array"(CArray):
+ int64_t Value(int i)
+
+ cdef cppclass CDate32Array" arrow::Date32Array"(CArray):
+ int32_t Value(int i)
+
+ cdef cppclass CDate64Array" arrow::Date64Array"(CArray):
+ int64_t Value(int i)
+
+ cdef cppclass CTime32Array" arrow::Time32Array"(CArray):
+ int32_t Value(int i)
+
+ cdef cppclass CTime64Array" arrow::Time64Array"(CArray):
+ int64_t Value(int i)
+
+ cdef cppclass CTimestampArray" arrow::TimestampArray"(CArray):
+ int64_t Value(int i)
+
+ cdef cppclass CDurationArray" arrow::DurationArray"(CArray):
+ int64_t Value(int i)
+
+ cdef cppclass CMonthDayNanoIntervalArray \
+ "arrow::MonthDayNanoIntervalArray"(CArray):
+ pass
+
+ cdef cppclass CHalfFloatArray" arrow::HalfFloatArray"(CArray):
+ uint16_t Value(int i)
+
+ cdef cppclass CFloatArray" arrow::FloatArray"(CArray):
+ float Value(int i)
+
+ cdef cppclass CDoubleArray" arrow::DoubleArray"(CArray):
+ double Value(int i)
+
+ cdef cppclass CFixedSizeBinaryArray" arrow::FixedSizeBinaryArray"(CArray):
+ const uint8_t* GetValue(int i)
+
+ cdef cppclass CDecimal128Array" arrow::Decimal128Array"(
+ CFixedSizeBinaryArray
+ ):
+ c_string FormatValue(int i)
+
+ cdef cppclass CDecimal256Array" arrow::Decimal256Array"(
+ CFixedSizeBinaryArray
+ ):
+ c_string FormatValue(int i)
+
+ cdef cppclass CListArray" arrow::ListArray"(CArray):
+ @staticmethod
+ CResult[shared_ptr[CArray]] FromArrays(
+ const CArray& offsets, const CArray& values, CMemoryPool* pool)
+
+ const int32_t* raw_value_offsets()
+ int32_t value_offset(int i)
+ int32_t value_length(int i)
+ shared_ptr[CArray] values()
+ shared_ptr[CArray] offsets()
+ shared_ptr[CDataType] value_type()
+
+ cdef cppclass CLargeListArray" arrow::LargeListArray"(CArray):
+ @staticmethod
+ CResult[shared_ptr[CArray]] FromArrays(
+ const CArray& offsets, const CArray& values, CMemoryPool* pool)
+
+ int64_t value_offset(int i)
+ int64_t value_length(int i)
+ shared_ptr[CArray] values()
+ shared_ptr[CArray] offsets()
+ shared_ptr[CDataType] value_type()
+
+ cdef cppclass CFixedSizeListArray" arrow::FixedSizeListArray"(CArray):
+ @staticmethod
+ CResult[shared_ptr[CArray]] FromArrays(
+ const shared_ptr[CArray]& values, int32_t list_size)
+
+ int64_t value_offset(int i)
+ int64_t value_length(int i)
+ shared_ptr[CArray] values()
+ shared_ptr[CDataType] value_type()
+
+ cdef cppclass CMapArray" arrow::MapArray"(CArray):
+ @staticmethod
+ CResult[shared_ptr[CArray]] FromArrays(
+ const shared_ptr[CArray]& offsets,
+ const shared_ptr[CArray]& keys,
+ const shared_ptr[CArray]& items,
+ CMemoryPool* pool)
+
+ shared_ptr[CArray] keys()
+ shared_ptr[CArray] items()
+ CMapType* map_type()
+ int64_t value_offset(int i)
+ int64_t value_length(int i)
+ shared_ptr[CArray] values()
+ shared_ptr[CDataType] value_type()
+
+ cdef cppclass CUnionArray" arrow::UnionArray"(CArray):
+ shared_ptr[CBuffer] type_codes()
+ int8_t* raw_type_codes()
+ int child_id(int64_t index)
+ shared_ptr[CArray] field(int pos)
+ const CArray* UnsafeField(int pos)
+ UnionMode mode()
+
+ cdef cppclass CSparseUnionArray" arrow::SparseUnionArray"(CUnionArray):
+ @staticmethod
+ CResult[shared_ptr[CArray]] Make(
+ const CArray& type_codes,
+ const vector[shared_ptr[CArray]]& children,
+ const vector[c_string]& field_names,
+ const vector[int8_t]& type_codes)
+
+ cdef cppclass CDenseUnionArray" arrow::DenseUnionArray"(CUnionArray):
+ @staticmethod
+ CResult[shared_ptr[CArray]] Make(
+ const CArray& type_codes,
+ const CArray& value_offsets,
+ const vector[shared_ptr[CArray]]& children,
+ const vector[c_string]& field_names,
+ const vector[int8_t]& type_codes)
+
+ int32_t value_offset(int i)
+ shared_ptr[CBuffer] value_offsets()
+
+ cdef cppclass CBinaryArray" arrow::BinaryArray"(CArray):
+ const uint8_t* GetValue(int i, int32_t* length)
+ shared_ptr[CBuffer] value_data()
+ int32_t value_offset(int64_t i)
+ int32_t value_length(int64_t i)
+ int32_t total_values_length()
+
+ cdef cppclass CLargeBinaryArray" arrow::LargeBinaryArray"(CArray):
+ const uint8_t* GetValue(int i, int64_t* length)
+ shared_ptr[CBuffer] value_data()
+ int64_t value_offset(int64_t i)
+ int64_t value_length(int64_t i)
+ int64_t total_values_length()
+
+ cdef cppclass CStringArray" arrow::StringArray"(CBinaryArray):
+ CStringArray(int64_t length, shared_ptr[CBuffer] value_offsets,
+ shared_ptr[CBuffer] data,
+ shared_ptr[CBuffer] null_bitmap,
+ int64_t null_count,
+ int64_t offset)
+
+ c_string GetString(int i)
+
+ cdef cppclass CLargeStringArray" arrow::LargeStringArray" \
+ (CLargeBinaryArray):
+ CLargeStringArray(int64_t length, shared_ptr[CBuffer] value_offsets,
+ shared_ptr[CBuffer] data,
+ shared_ptr[CBuffer] null_bitmap,
+ int64_t null_count,
+ int64_t offset)
+
+ c_string GetString(int i)
+
+ cdef cppclass CStructArray" arrow::StructArray"(CArray):
+ CStructArray(shared_ptr[CDataType]& type, int64_t length,
+ vector[shared_ptr[CArray]]& children,
+ shared_ptr[CBuffer] null_bitmap=nullptr,
+ int64_t null_count=-1,
+ int64_t offset=0)
+
+ # XXX Cython crashes if default argument values are declared here
+ # https://github.com/cython/cython/issues/2167
+ @staticmethod
+ CResult[shared_ptr[CArray]] MakeFromFieldNames "Make"(
+ vector[shared_ptr[CArray]] children,
+ vector[c_string] field_names,
+ shared_ptr[CBuffer] null_bitmap,
+ int64_t null_count,
+ int64_t offset)
+
+ @staticmethod
+ CResult[shared_ptr[CArray]] MakeFromFields "Make"(
+ vector[shared_ptr[CArray]] children,
+ vector[shared_ptr[CField]] fields,
+ shared_ptr[CBuffer] null_bitmap,
+ int64_t null_count,
+ int64_t offset)
+
+ shared_ptr[CArray] field(int pos)
+ shared_ptr[CArray] GetFieldByName(const c_string& name) const
+
+ CResult[vector[shared_ptr[CArray]]] Flatten(CMemoryPool* pool)
+
+ cdef cppclass CChunkedArray" arrow::ChunkedArray":
+ CChunkedArray(const vector[shared_ptr[CArray]]& arrays)
+ CChunkedArray(const vector[shared_ptr[CArray]]& arrays,
+ const shared_ptr[CDataType]& type)
+ int64_t length()
+ int64_t null_count()
+ int num_chunks()
+ c_bool Equals(const CChunkedArray& other)
+
+ shared_ptr[CArray] chunk(int i)
+ shared_ptr[CDataType] type()
+ shared_ptr[CChunkedArray] Slice(int64_t offset, int64_t length) const
+ shared_ptr[CChunkedArray] Slice(int64_t offset) const
+
+ CResult[vector[shared_ptr[CChunkedArray]]] Flatten(CMemoryPool* pool)
+
+ CStatus Validate() const
+ CStatus ValidateFull() const
+
+ cdef cppclass CRecordBatch" arrow::RecordBatch":
+ @staticmethod
+ shared_ptr[CRecordBatch] Make(
+ const shared_ptr[CSchema]& schema, int64_t num_rows,
+ const vector[shared_ptr[CArray]]& columns)
+
+ @staticmethod
+ CResult[shared_ptr[CRecordBatch]] FromStructArray(
+ const shared_ptr[CArray]& array)
+
+ c_bool Equals(const CRecordBatch& other, c_bool check_metadata)
+
+ shared_ptr[CSchema] schema()
+ shared_ptr[CArray] column(int i)
+ const c_string& column_name(int i)
+
+ const vector[shared_ptr[CArray]]& columns()
+
+ int num_columns()
+ int64_t num_rows()
+
+ CStatus Validate() const
+ CStatus ValidateFull() const
+
+ shared_ptr[CRecordBatch] ReplaceSchemaMetadata(
+ const shared_ptr[CKeyValueMetadata]& metadata)
+
+ shared_ptr[CRecordBatch] Slice(int64_t offset)
+ shared_ptr[CRecordBatch] Slice(int64_t offset, int64_t length)
+
+ cdef cppclass CTable" arrow::Table":
+ CTable(const shared_ptr[CSchema]& schema,
+ const vector[shared_ptr[CChunkedArray]]& columns)
+
+ @staticmethod
+ shared_ptr[CTable] Make(
+ const shared_ptr[CSchema]& schema,
+ const vector[shared_ptr[CChunkedArray]]& columns)
+
+ @staticmethod
+ shared_ptr[CTable] MakeFromArrays" Make"(
+ const shared_ptr[CSchema]& schema,
+ const vector[shared_ptr[CArray]]& arrays)
+
+ @staticmethod
+ CResult[shared_ptr[CTable]] FromRecordBatches(
+ const shared_ptr[CSchema]& schema,
+ const vector[shared_ptr[CRecordBatch]]& batches)
+
+ int num_columns()
+ int64_t num_rows()
+
+ c_bool Equals(const CTable& other, c_bool check_metadata)
+
+ shared_ptr[CSchema] schema()
+ shared_ptr[CChunkedArray] column(int i)
+ shared_ptr[CField] field(int i)
+
+ CResult[shared_ptr[CTable]] AddColumn(
+ int i, shared_ptr[CField] field, shared_ptr[CChunkedArray] column)
+ CResult[shared_ptr[CTable]] RemoveColumn(int i)
+ CResult[shared_ptr[CTable]] SetColumn(
+ int i, shared_ptr[CField] field, shared_ptr[CChunkedArray] column)
+
+ vector[c_string] ColumnNames()
+ CResult[shared_ptr[CTable]] RenameColumns(const vector[c_string]&)
+ CResult[shared_ptr[CTable]] SelectColumns(const vector[int]&)
+
+ CResult[shared_ptr[CTable]] Flatten(CMemoryPool* pool)
+
+ CResult[shared_ptr[CTable]] CombineChunks(CMemoryPool* pool)
+
+ CStatus Validate() const
+ CStatus ValidateFull() const
+
+ shared_ptr[CTable] ReplaceSchemaMetadata(
+ const shared_ptr[CKeyValueMetadata]& metadata)
+
+ shared_ptr[CTable] Slice(int64_t offset)
+ shared_ptr[CTable] Slice(int64_t offset, int64_t length)
+
+ cdef cppclass CRecordBatchReader" arrow::RecordBatchReader":
+ shared_ptr[CSchema] schema()
+ CStatus ReadNext(shared_ptr[CRecordBatch]* batch)
+ CStatus ReadAll(shared_ptr[CTable]* out)
+
+ cdef cppclass TableBatchReader(CRecordBatchReader):
+ TableBatchReader(const CTable& table)
+ void set_chunksize(int64_t chunksize)
+
+ cdef cppclass CTensor" arrow::Tensor":
+ shared_ptr[CDataType] type()
+ shared_ptr[CBuffer] data()
+
+ const vector[int64_t]& shape()
+ const vector[int64_t]& strides()
+ int64_t size()
+
+ int ndim()
+ const vector[c_string]& dim_names()
+ const c_string& dim_name(int i)
+
+ c_bool is_mutable()
+ c_bool is_contiguous()
+ Type type_id()
+ c_bool Equals(const CTensor& other)
+
+ cdef cppclass CSparseIndex" arrow::SparseIndex":
+ pass
+
+ cdef cppclass CSparseCOOIndex" arrow::SparseCOOIndex":
+ c_bool is_canonical()
+
+ cdef cppclass CSparseCOOTensor" arrow::SparseCOOTensor":
+ shared_ptr[CDataType] type()
+ shared_ptr[CBuffer] data()
+ CResult[shared_ptr[CTensor]] ToTensor()
+
+ shared_ptr[CSparseIndex] sparse_index()
+
+ const vector[int64_t]& shape()
+ int64_t size()
+ int64_t non_zero_length()
+
+ int ndim()
+ const vector[c_string]& dim_names()
+ const c_string& dim_name(int i)
+
+ c_bool is_mutable()
+ Type type_id()
+ c_bool Equals(const CSparseCOOTensor& other)
+
+ cdef cppclass CSparseCSRMatrix" arrow::SparseCSRMatrix":
+ shared_ptr[CDataType] type()
+ shared_ptr[CBuffer] data()
+ CResult[shared_ptr[CTensor]] ToTensor()
+
+ const vector[int64_t]& shape()
+ int64_t size()
+ int64_t non_zero_length()
+
+ int ndim()
+ const vector[c_string]& dim_names()
+ const c_string& dim_name(int i)
+
+ c_bool is_mutable()
+ Type type_id()
+ c_bool Equals(const CSparseCSRMatrix& other)
+
+ cdef cppclass CSparseCSCMatrix" arrow::SparseCSCMatrix":
+ shared_ptr[CDataType] type()
+ shared_ptr[CBuffer] data()
+ CResult[shared_ptr[CTensor]] ToTensor()
+
+ const vector[int64_t]& shape()
+ int64_t size()
+ int64_t non_zero_length()
+
+ int ndim()
+ const vector[c_string]& dim_names()
+ const c_string& dim_name(int i)
+
+ c_bool is_mutable()
+ Type type_id()
+ c_bool Equals(const CSparseCSCMatrix& other)
+
+ cdef cppclass CSparseCSFTensor" arrow::SparseCSFTensor":
+ shared_ptr[CDataType] type()
+ shared_ptr[CBuffer] data()
+ CResult[shared_ptr[CTensor]] ToTensor()
+
+ const vector[int64_t]& shape()
+ int64_t size()
+ int64_t non_zero_length()
+
+ int ndim()
+ const vector[c_string]& dim_names()
+ const c_string& dim_name(int i)
+
+ c_bool is_mutable()
+ Type type_id()
+ c_bool Equals(const CSparseCSFTensor& other)
+
+ cdef cppclass CScalar" arrow::Scalar":
+ CScalar(shared_ptr[CDataType])
+
+ shared_ptr[CDataType] type
+ c_bool is_valid
+
+ c_string ToString() const
+ c_bool Equals(const CScalar& other) const
+ CStatus Validate() const
+ CStatus ValidateFull() const
+ CResult[shared_ptr[CScalar]] CastTo(shared_ptr[CDataType] to) const
+
+ cdef cppclass CScalarHash" arrow::Scalar::Hash":
+ size_t operator()(const shared_ptr[CScalar]& scalar) const
+
+ cdef cppclass CNullScalar" arrow::NullScalar"(CScalar):
+ CNullScalar()
+
+ cdef cppclass CBooleanScalar" arrow::BooleanScalar"(CScalar):
+ c_bool value
+
+ cdef cppclass CInt8Scalar" arrow::Int8Scalar"(CScalar):
+ int8_t value
+
+ cdef cppclass CUInt8Scalar" arrow::UInt8Scalar"(CScalar):
+ uint8_t value
+
+ cdef cppclass CInt16Scalar" arrow::Int16Scalar"(CScalar):
+ int16_t value
+
+ cdef cppclass CUInt16Scalar" arrow::UInt16Scalar"(CScalar):
+ uint16_t value
+
+ cdef cppclass CInt32Scalar" arrow::Int32Scalar"(CScalar):
+ int32_t value
+
+ cdef cppclass CUInt32Scalar" arrow::UInt32Scalar"(CScalar):
+ uint32_t value
+
+ cdef cppclass CInt64Scalar" arrow::Int64Scalar"(CScalar):
+ int64_t value
+
+ cdef cppclass CUInt64Scalar" arrow::UInt64Scalar"(CScalar):
+ uint64_t value
+
+ cdef cppclass CHalfFloatScalar" arrow::HalfFloatScalar"(CScalar):
+ npy_half value
+
+ cdef cppclass CFloatScalar" arrow::FloatScalar"(CScalar):
+ float value
+
+ cdef cppclass CDoubleScalar" arrow::DoubleScalar"(CScalar):
+ double value
+
+ cdef cppclass CDecimal128Scalar" arrow::Decimal128Scalar"(CScalar):
+ CDecimal128 value
+
+ cdef cppclass CDecimal256Scalar" arrow::Decimal256Scalar"(CScalar):
+ CDecimal256 value
+
+ cdef cppclass CDate32Scalar" arrow::Date32Scalar"(CScalar):
+ int32_t value
+
+ cdef cppclass CDate64Scalar" arrow::Date64Scalar"(CScalar):
+ int64_t value
+
+ cdef cppclass CTime32Scalar" arrow::Time32Scalar"(CScalar):
+ int32_t value
+
+ cdef cppclass CTime64Scalar" arrow::Time64Scalar"(CScalar):
+ int64_t value
+
+ cdef cppclass CTimestampScalar" arrow::TimestampScalar"(CScalar):
+ int64_t value
+
+ cdef cppclass CDurationScalar" arrow::DurationScalar"(CScalar):
+ int64_t value
+
+ cdef cppclass CMonthDayNanoIntervalScalar \
+ "arrow::MonthDayNanoIntervalScalar"(CScalar):
+ pass
+
+ cdef cppclass CBaseBinaryScalar" arrow::BaseBinaryScalar"(CScalar):
+ shared_ptr[CBuffer] value
+
+ cdef cppclass CBaseListScalar" arrow::BaseListScalar"(CScalar):
+ shared_ptr[CArray] value
+
+ cdef cppclass CListScalar" arrow::ListScalar"(CBaseListScalar):
+ pass
+
+ cdef cppclass CMapScalar" arrow::MapScalar"(CListScalar):
+ pass
+
+ cdef cppclass CStructScalar" arrow::StructScalar"(CScalar):
+ vector[shared_ptr[CScalar]] value
+ CResult[shared_ptr[CScalar]] field(CFieldRef ref) const
+
+ cdef cppclass CDictionaryScalarIndexAndDictionary \
+ "arrow::DictionaryScalar::ValueType":
+ shared_ptr[CScalar] index
+ shared_ptr[CArray] dictionary
+
+ cdef cppclass CDictionaryScalar" arrow::DictionaryScalar"(CScalar):
+ CDictionaryScalar(CDictionaryScalarIndexAndDictionary value,
+ shared_ptr[CDataType], c_bool is_valid)
+ CDictionaryScalarIndexAndDictionary value
+
+ CResult[shared_ptr[CScalar]] GetEncodedValue()
+
+ cdef cppclass CUnionScalar" arrow::UnionScalar"(CScalar):
+ shared_ptr[CScalar] value
+ int8_t type_code
+
+ cdef cppclass CExtensionScalar" arrow::ExtensionScalar"(CScalar):
+ shared_ptr[CScalar] value
+
+ shared_ptr[CScalar] MakeScalar[Value](Value value)
+
+ cdef cppclass CConcatenateTablesOptions" arrow::ConcatenateTablesOptions":
+ c_bool unify_schemas
+ CField.CMergeOptions field_merge_options
+
+ @staticmethod
+ CConcatenateTablesOptions Defaults()
+
+ CResult[shared_ptr[CTable]] ConcatenateTables(
+ const vector[shared_ptr[CTable]]& tables,
+ CConcatenateTablesOptions options,
+ CMemoryPool* memory_pool)
+
+ cdef cppclass CDictionaryUnifier" arrow::DictionaryUnifier":
+ @staticmethod
+ CResult[shared_ptr[CChunkedArray]] UnifyChunkedArray(
+ shared_ptr[CChunkedArray] array, CMemoryPool* pool)
+
+ @staticmethod
+ CResult[shared_ptr[CTable]] UnifyTable(
+ const CTable& table, CMemoryPool* pool)
+
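A short Python-level sketch of how a few of the declarations above (Table, ConcatenateTablesOptions with unify_schemas) are reachable from the public API; the promote= keyword matches the pyarrow generation this diff belongs to and is shown only as an illustration:

    import pyarrow as pa

    t1 = pa.table({"x": [1, 2]})
    t2 = pa.table({"x": [3], "y": ["a"]})
    # Schema unification: columns missing from a table are filled with nulls.
    combined = pa.concat_tables([t1, t2], promote=True)
    print(combined.schema)
    print(combined.column("y").null_count)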
+
+cdef extern from "arrow/builder.h" namespace "arrow" nogil:
+
+ cdef cppclass CArrayBuilder" arrow::ArrayBuilder":
+ CArrayBuilder(shared_ptr[CDataType], CMemoryPool* pool)
+
+ int64_t length()
+ int64_t null_count()
+ CStatus AppendNull()
+ CStatus Finish(shared_ptr[CArray]* out)
+ CStatus Reserve(int64_t additional_capacity)
+
+ cdef cppclass CBooleanBuilder" arrow::BooleanBuilder"(CArrayBuilder):
+ CBooleanBuilder(CMemoryPool* pool)
+ CStatus Append(const c_bool val)
+ CStatus Append(const uint8_t val)
+
+ cdef cppclass CInt8Builder" arrow::Int8Builder"(CArrayBuilder):
+ CInt8Builder(CMemoryPool* pool)
+ CStatus Append(const int8_t value)
+
+ cdef cppclass CInt16Builder" arrow::Int16Builder"(CArrayBuilder):
+ CInt16Builder(CMemoryPool* pool)
+ CStatus Append(const int16_t value)
+
+ cdef cppclass CInt32Builder" arrow::Int32Builder"(CArrayBuilder):
+ CInt32Builder(CMemoryPool* pool)
+ CStatus Append(const int32_t value)
+
+ cdef cppclass CInt64Builder" arrow::Int64Builder"(CArrayBuilder):
+ CInt64Builder(CMemoryPool* pool)
+ CStatus Append(const int64_t value)
+
+ cdef cppclass CUInt8Builder" arrow::UInt8Builder"(CArrayBuilder):
+ CUInt8Builder(CMemoryPool* pool)
+ CStatus Append(const uint8_t value)
+
+ cdef cppclass CUInt16Builder" arrow::UInt16Builder"(CArrayBuilder):
+ CUInt16Builder(CMemoryPool* pool)
+ CStatus Append(const uint16_t value)
+
+ cdef cppclass CUInt32Builder" arrow::UInt32Builder"(CArrayBuilder):
+ CUInt32Builder(CMemoryPool* pool)
+ CStatus Append(const uint32_t value)
+
+ cdef cppclass CUInt64Builder" arrow::UInt64Builder"(CArrayBuilder):
+ CUInt64Builder(CMemoryPool* pool)
+ CStatus Append(const uint64_t value)
+
+ cdef cppclass CHalfFloatBuilder" arrow::HalfFloatBuilder"(CArrayBuilder):
+ CHalfFloatBuilder(CMemoryPool* pool)
+
+ cdef cppclass CFloatBuilder" arrow::FloatBuilder"(CArrayBuilder):
+ CFloatBuilder(CMemoryPool* pool)
+ CStatus Append(const float value)
+
+ cdef cppclass CDoubleBuilder" arrow::DoubleBuilder"(CArrayBuilder):
+ CDoubleBuilder(CMemoryPool* pool)
+ CStatus Append(const double value)
+
+ cdef cppclass CBinaryBuilder" arrow::BinaryBuilder"(CArrayBuilder):
+ CArrayBuilder(shared_ptr[CDataType], CMemoryPool* pool)
+ CStatus Append(const char* value, int32_t length)
+
+ cdef cppclass CStringBuilder" arrow::StringBuilder"(CBinaryBuilder):
+ CStringBuilder(CMemoryPool* pool)
+
+ CStatus Append(const c_string& value)
+
+ cdef cppclass CTimestampBuilder "arrow::TimestampBuilder"(CArrayBuilder):
+ CTimestampBuilder(const shared_ptr[CDataType] typ, CMemoryPool* pool)
+ CStatus Append(const int64_t value)
+
+ cdef cppclass CDate32Builder "arrow::Date32Builder"(CArrayBuilder):
+ CDate32Builder(CMemoryPool* pool)
+ CStatus Append(const int32_t value)
+
+ cdef cppclass CDate64Builder "arrow::Date64Builder"(CArrayBuilder):
+ CDate64Builder(CMemoryPool* pool)
+ CStatus Append(const int64_t value)
+
+
+# Use typedef to emulate syntax for std::function<void(..)>
+ctypedef void CallbackTransform(object, const shared_ptr[CBuffer]& src,
+ shared_ptr[CBuffer]* dest)
+
+
+cdef extern from "arrow/util/cancel.h" namespace "arrow" nogil:
+ cdef cppclass CStopToken "arrow::StopToken":
+ CStatus Poll()
+ c_bool IsStopRequested()
+
+ cdef cppclass CStopSource "arrow::StopSource":
+ CStopToken token()
+
+ CResult[CStopSource*] SetSignalStopSource()
+ void ResetSignalStopSource()
+
+ CStatus RegisterCancellingSignalHandler(vector[int] signals)
+ void UnregisterCancellingSignalHandler()
+
+
+cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil:
+ cdef enum FileMode" arrow::io::FileMode::type":
+ FileMode_READ" arrow::io::FileMode::READ"
+ FileMode_WRITE" arrow::io::FileMode::WRITE"
+ FileMode_READWRITE" arrow::io::FileMode::READWRITE"
+
+ cdef enum ObjectType" arrow::io::ObjectType::type":
+ ObjectType_FILE" arrow::io::ObjectType::FILE"
+ ObjectType_DIRECTORY" arrow::io::ObjectType::DIRECTORY"
+
+ cdef cppclass CIOContext" arrow::io::IOContext":
+ CIOContext()
+ CIOContext(CStopToken)
+ CIOContext(CMemoryPool*)
+ CIOContext(CMemoryPool*, CStopToken)
+
+ CIOContext c_default_io_context "arrow::io::default_io_context"()
+ int GetIOThreadPoolCapacity()
+ CStatus SetIOThreadPoolCapacity(int threads)
+
+ cdef cppclass FileStatistics:
+ int64_t size
+ ObjectType kind
+
+ cdef cppclass FileInterface:
+ CStatus Close()
+ CResult[int64_t] Tell()
+ FileMode mode()
+ c_bool closed()
+
+ cdef cppclass Readable:
+ # put overload under a different name to avoid cython bug with multiple
+ # layers of inheritance
+ CResult[shared_ptr[CBuffer]] ReadBuffer" Read"(int64_t nbytes)
+ CResult[int64_t] Read(int64_t nbytes, uint8_t* out)
+
+ cdef cppclass Seekable:
+ CStatus Seek(int64_t position)
+
+ cdef cppclass Writable:
+ CStatus WriteBuffer" Write"(shared_ptr[CBuffer] data)
+ CStatus Write(const uint8_t* data, int64_t nbytes)
+ CStatus Flush()
+
+ cdef cppclass COutputStream" arrow::io::OutputStream"(FileInterface,
+ Writable):
+ pass
+
+ cdef cppclass CInputStream" arrow::io::InputStream"(FileInterface,
+ Readable):
+ CResult[shared_ptr[const CKeyValueMetadata]] ReadMetadata()
+
+ cdef cppclass CRandomAccessFile" arrow::io::RandomAccessFile"(CInputStream,
+ Seekable):
+ CResult[int64_t] GetSize()
+
+ CResult[int64_t] ReadAt(int64_t position, int64_t nbytes,
+ uint8_t* buffer)
+ CResult[shared_ptr[CBuffer]] ReadAt(int64_t position, int64_t nbytes)
+ c_bool supports_zero_copy()
+
+ cdef cppclass WritableFile(COutputStream, Seekable):
+ CStatus WriteAt(int64_t position, const uint8_t* data,
+ int64_t nbytes)
+
+ cdef cppclass ReadWriteFileInterface(CRandomAccessFile,
+ WritableFile):
+ pass
+
+ cdef cppclass CIOFileSystem" arrow::io::FileSystem":
+ CStatus Stat(const c_string& path, FileStatistics* stat)
+
+ cdef cppclass FileOutputStream(COutputStream):
+ @staticmethod
+ CResult[shared_ptr[COutputStream]] Open(const c_string& path)
+
+ int file_descriptor()
+
+ cdef cppclass ReadableFile(CRandomAccessFile):
+ @staticmethod
+ CResult[shared_ptr[ReadableFile]] Open(const c_string& path)
+
+ @staticmethod
+ CResult[shared_ptr[ReadableFile]] Open(const c_string& path,
+ CMemoryPool* memory_pool)
+
+ int file_descriptor()
+
+ cdef cppclass CMemoryMappedFile \
+ " arrow::io::MemoryMappedFile"(ReadWriteFileInterface):
+
+ @staticmethod
+ CResult[shared_ptr[CMemoryMappedFile]] Create(const c_string& path,
+ int64_t size)
+
+ @staticmethod
+ CResult[shared_ptr[CMemoryMappedFile]] Open(const c_string& path,
+ FileMode mode)
+
+ CStatus Resize(int64_t size)
+
+ int file_descriptor()
+
+ cdef cppclass CCompressedInputStream \
+ " arrow::io::CompressedInputStream"(CInputStream):
+ @staticmethod
+ CResult[shared_ptr[CCompressedInputStream]] Make(
+ CCodec* codec, shared_ptr[CInputStream] raw)
+
+ cdef cppclass CCompressedOutputStream \
+ " arrow::io::CompressedOutputStream"(COutputStream):
+ @staticmethod
+ CResult[shared_ptr[CCompressedOutputStream]] Make(
+ CCodec* codec, shared_ptr[COutputStream] raw)
+
+ cdef cppclass CBufferedInputStream \
+ " arrow::io::BufferedInputStream"(CInputStream):
+
+ @staticmethod
+ CResult[shared_ptr[CBufferedInputStream]] Create(
+ int64_t buffer_size, CMemoryPool* pool,
+ shared_ptr[CInputStream] raw)
+
+ CResult[shared_ptr[CInputStream]] Detach()
+
+ cdef cppclass CBufferedOutputStream \
+ " arrow::io::BufferedOutputStream"(COutputStream):
+
+ @staticmethod
+ CResult[shared_ptr[CBufferedOutputStream]] Create(
+ int64_t buffer_size, CMemoryPool* pool,
+ shared_ptr[COutputStream] raw)
+
+ CResult[shared_ptr[COutputStream]] Detach()
+
+ cdef cppclass CTransformInputStreamVTable \
+ "arrow::py::TransformInputStreamVTable":
+ CTransformInputStreamVTable()
+ function[CallbackTransform] transform
+
+ shared_ptr[CInputStream] MakeTransformInputStream \
+ "arrow::py::MakeTransformInputStream"(
+ shared_ptr[CInputStream] wrapped, CTransformInputStreamVTable vtable,
+ object method_arg)
+
+ # ----------------------------------------------------------------------
+ # HDFS
+
+ CStatus HaveLibHdfs()
+ CStatus HaveLibHdfs3()
+
+ cdef enum HdfsDriver" arrow::io::HdfsDriver":
+ HdfsDriver_LIBHDFS" arrow::io::HdfsDriver::LIBHDFS"
+ HdfsDriver_LIBHDFS3" arrow::io::HdfsDriver::LIBHDFS3"
+
+ cdef cppclass HdfsConnectionConfig:
+ c_string host
+ int port
+ c_string user
+ c_string kerb_ticket
+ unordered_map[c_string, c_string] extra_conf
+ HdfsDriver driver
+
+ cdef cppclass HdfsPathInfo:
+ ObjectType kind
+ c_string name
+ c_string owner
+ c_string group
+ int32_t last_modified_time
+ int32_t last_access_time
+ int64_t size
+ int16_t replication
+ int64_t block_size
+ int16_t permissions
+
+ cdef cppclass HdfsReadableFile(CRandomAccessFile):
+ pass
+
+ cdef cppclass HdfsOutputStream(COutputStream):
+ pass
+
+ cdef cppclass CIOHadoopFileSystem \
+ "arrow::io::HadoopFileSystem"(CIOFileSystem):
+ @staticmethod
+ CStatus Connect(const HdfsConnectionConfig* config,
+ shared_ptr[CIOHadoopFileSystem]* client)
+
+ CStatus MakeDirectory(const c_string& path)
+
+ CStatus Delete(const c_string& path, c_bool recursive)
+
+ CStatus Disconnect()
+
+ c_bool Exists(const c_string& path)
+
+ CStatus Chmod(const c_string& path, int mode)
+ CStatus Chown(const c_string& path, const char* owner,
+ const char* group)
+
+ CStatus GetCapacity(int64_t* nbytes)
+ CStatus GetUsed(int64_t* nbytes)
+
+ CStatus ListDirectory(const c_string& path,
+ vector[HdfsPathInfo]* listing)
+
+ CStatus GetPathInfo(const c_string& path, HdfsPathInfo* info)
+
+ CStatus Rename(const c_string& src, const c_string& dst)
+
+ CStatus OpenReadable(const c_string& path,
+ shared_ptr[HdfsReadableFile]* handle)
+
+ CStatus OpenWritable(const c_string& path, c_bool append,
+ int32_t buffer_size, int16_t replication,
+ int64_t default_block_size,
+ shared_ptr[HdfsOutputStream]* handle)
+
+ cdef cppclass CBufferReader \
+ " arrow::io::BufferReader"(CRandomAccessFile):
+ CBufferReader(const shared_ptr[CBuffer]& buffer)
+ CBufferReader(const uint8_t* data, int64_t nbytes)
+
+ cdef cppclass CBufferOutputStream \
+ " arrow::io::BufferOutputStream"(COutputStream):
+ CBufferOutputStream(const shared_ptr[CResizableBuffer]& buffer)
+
+ cdef cppclass CMockOutputStream \
+ " arrow::io::MockOutputStream"(COutputStream):
+ CMockOutputStream()
+ int64_t GetExtentBytesWritten()
+
+ cdef cppclass CFixedSizeBufferWriter \
+ " arrow::io::FixedSizeBufferWriter"(WritableFile):
+ CFixedSizeBufferWriter(const shared_ptr[CBuffer]& buffer)
+
+ void set_memcopy_threads(int num_threads)
+ void set_memcopy_blocksize(int64_t blocksize)
+ void set_memcopy_threshold(int64_t threshold)
+
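Several of the arrow::io classes above map directly onto pyarrow's stream objects; a minimal in-memory round-trip sketch using the public wrappers:

    import pyarrow as pa

    sink = pa.BufferOutputStream()          # arrow::io::BufferOutputStream
    sink.write(b"hello world")
    buf = sink.getvalue()                   # arrow::Buffer

    reader = pa.BufferReader(buf)           # arrow::io::BufferReader
    assert reader.read() == b"hello world"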
+
+cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
+ cdef enum MessageType" arrow::ipc::MessageType":
+ MessageType_SCHEMA" arrow::ipc::MessageType::SCHEMA"
+ MessageType_RECORD_BATCH" arrow::ipc::MessageType::RECORD_BATCH"
+ MessageType_DICTIONARY_BATCH \
+ " arrow::ipc::MessageType::DICTIONARY_BATCH"
+
+ # TODO: use "cpdef enum class" to automatically get a Python wrapper?
+ # See
+ # https://github.com/cython/cython/commit/2c7c22f51405299a4e247f78edf52957d30cf71d#diff-61c1365c0f761a8137754bb3a73bfbf7
+ ctypedef enum CMetadataVersion" arrow::ipc::MetadataVersion":
+ CMetadataVersion_V1" arrow::ipc::MetadataVersion::V1"
+ CMetadataVersion_V2" arrow::ipc::MetadataVersion::V2"
+ CMetadataVersion_V3" arrow::ipc::MetadataVersion::V3"
+ CMetadataVersion_V4" arrow::ipc::MetadataVersion::V4"
+ CMetadataVersion_V5" arrow::ipc::MetadataVersion::V5"
+
+ cdef cppclass CIpcWriteOptions" arrow::ipc::IpcWriteOptions":
+ c_bool allow_64bit
+ int max_recursion_depth
+ int32_t alignment
+ c_bool write_legacy_ipc_format
+ CMemoryPool* memory_pool
+ CMetadataVersion metadata_version
+ shared_ptr[CCodec] codec
+ c_bool use_threads
+ c_bool emit_dictionary_deltas
+
+ @staticmethod
+ CIpcWriteOptions Defaults()
+
+ cdef cppclass CIpcReadOptions" arrow::ipc::IpcReadOptions":
+ int max_recursion_depth
+ CMemoryPool* memory_pool
+ shared_ptr[unordered_set[int]] included_fields
+
+ @staticmethod
+ CIpcReadOptions Defaults()
+
+ cdef cppclass CIpcWriteStats" arrow::ipc::WriteStats":
+ int64_t num_messages
+ int64_t num_record_batches
+ int64_t num_dictionary_batches
+ int64_t num_dictionary_deltas
+ int64_t num_replaced_dictionaries
+
+ cdef cppclass CIpcReadStats" arrow::ipc::ReadStats":
+ int64_t num_messages
+ int64_t num_record_batches
+ int64_t num_dictionary_batches
+ int64_t num_dictionary_deltas
+ int64_t num_replaced_dictionaries
+
+ cdef cppclass CDictionaryMemo" arrow::ipc::DictionaryMemo":
+ pass
+
+ cdef cppclass CIpcPayload" arrow::ipc::IpcPayload":
+ MessageType type
+ shared_ptr[CBuffer] metadata
+ vector[shared_ptr[CBuffer]] body_buffers
+ int64_t body_length
+
+ cdef cppclass CMessage" arrow::ipc::Message":
+ CResult[unique_ptr[CMessage]] Open(shared_ptr[CBuffer] metadata,
+ shared_ptr[CBuffer] body)
+
+ shared_ptr[CBuffer] body()
+
+ c_bool Equals(const CMessage& other)
+
+ shared_ptr[CBuffer] metadata()
+ CMetadataVersion metadata_version()
+ MessageType type()
+
+ CStatus SerializeTo(COutputStream* stream,
+ const CIpcWriteOptions& options,
+ int64_t* output_length)
+
+ c_string FormatMessageType(MessageType type)
+
+ cdef cppclass CMessageReader" arrow::ipc::MessageReader":
+ @staticmethod
+ unique_ptr[CMessageReader] Open(const shared_ptr[CInputStream]& stream)
+
+ CResult[unique_ptr[CMessage]] ReadNextMessage()
+
+ cdef cppclass CRecordBatchWriter" arrow::ipc::RecordBatchWriter":
+ CStatus Close()
+ CStatus WriteRecordBatch(const CRecordBatch& batch)
+ CStatus WriteTable(const CTable& table, int64_t max_chunksize)
+
+ CIpcWriteStats stats()
+
+ cdef cppclass CRecordBatchStreamReader \
+ " arrow::ipc::RecordBatchStreamReader"(CRecordBatchReader):
+ @staticmethod
+ CResult[shared_ptr[CRecordBatchReader]] Open(
+ const shared_ptr[CInputStream], const CIpcReadOptions&)
+
+ @staticmethod
+ CResult[shared_ptr[CRecordBatchReader]] Open2" Open"(
+ unique_ptr[CMessageReader] message_reader,
+ const CIpcReadOptions& options)
+
+ CIpcReadStats stats()
+
+ cdef cppclass CRecordBatchFileReader \
+ " arrow::ipc::RecordBatchFileReader":
+ @staticmethod
+ CResult[shared_ptr[CRecordBatchFileReader]] Open(
+ CRandomAccessFile* file,
+ const CIpcReadOptions& options)
+
+ @staticmethod
+ CResult[shared_ptr[CRecordBatchFileReader]] Open2" Open"(
+ CRandomAccessFile* file, int64_t footer_offset,
+ const CIpcReadOptions& options)
+
+ shared_ptr[CSchema] schema()
+
+ int num_record_batches()
+
+ CResult[shared_ptr[CRecordBatch]] ReadRecordBatch(int i)
+
+ CIpcReadStats stats()
+
+ CResult[shared_ptr[CRecordBatchWriter]] MakeStreamWriter(
+ shared_ptr[COutputStream] sink, const shared_ptr[CSchema]& schema,
+ CIpcWriteOptions& options)
+
+ CResult[shared_ptr[CRecordBatchWriter]] MakeFileWriter(
+ shared_ptr[COutputStream] sink, const shared_ptr[CSchema]& schema,
+ CIpcWriteOptions& options)
+
+ CResult[unique_ptr[CMessage]] ReadMessage(CInputStream* stream,
+ CMemoryPool* pool)
+
+ CStatus GetRecordBatchSize(const CRecordBatch& batch, int64_t* size)
+ CStatus GetTensorSize(const CTensor& tensor, int64_t* size)
+
+ CStatus WriteTensor(const CTensor& tensor, COutputStream* dst,
+ int32_t* metadata_length,
+ int64_t* body_length)
+
+ CResult[shared_ptr[CTensor]] ReadTensor(CInputStream* stream)
+
+ CResult[shared_ptr[CRecordBatch]] ReadRecordBatch(
+ const CMessage& message, const shared_ptr[CSchema]& schema,
+ CDictionaryMemo* dictionary_memo,
+ const CIpcReadOptions& options)
+
+ CResult[shared_ptr[CBuffer]] SerializeSchema(
+ const CSchema& schema, CMemoryPool* pool)
+
+ CResult[shared_ptr[CBuffer]] SerializeRecordBatch(
+ const CRecordBatch& schema, const CIpcWriteOptions& options)
+
+ CResult[shared_ptr[CSchema]] ReadSchema(CInputStream* stream,
+ CDictionaryMemo* dictionary_memo)
+
+ CResult[shared_ptr[CRecordBatch]] ReadRecordBatch(
+ const shared_ptr[CSchema]& schema,
+ CDictionaryMemo* dictionary_memo,
+ const CIpcReadOptions& options,
+ CInputStream* stream)
+
+ CStatus AlignStream(CInputStream* stream, int64_t alignment)
+ CStatus AlignStream(COutputStream* stream, int64_t alignment)
+
+ cdef CStatus GetRecordBatchPayload \
+ " arrow::ipc::GetRecordBatchPayload"(
+ const CRecordBatch& batch,
+ const CIpcWriteOptions& options,
+ CIpcPayload* out)
+
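The IPC reader/writer declarations above back the pyarrow.ipc module; a small stream round-trip sketch using the public API:

    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3]})
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:   # MakeStreamWriter
        writer.write_batch(batch)

    reader = pa.ipc.open_stream(sink.getvalue())            # RecordBatchStreamReader
    table = reader.read_all()
    print(table)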
+
+cdef extern from "arrow/util/value_parsing.h" namespace "arrow" nogil:
+ cdef cppclass CTimestampParser" arrow::TimestampParser":
+ const char* kind() const
+ const char* format() const
+
+ @staticmethod
+ shared_ptr[CTimestampParser] MakeStrptime(c_string format)
+
+ @staticmethod
+ shared_ptr[CTimestampParser] MakeISO8601()
+
+
+cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil:
+
+ cdef cppclass CCSVParseOptions" arrow::csv::ParseOptions":
+ unsigned char delimiter
+ c_bool quoting
+ unsigned char quote_char
+ c_bool double_quote
+ c_bool escaping
+ unsigned char escape_char
+ c_bool newlines_in_values
+ c_bool ignore_empty_lines
+
+ CCSVParseOptions()
+ CCSVParseOptions(CCSVParseOptions&&)
+
+ @staticmethod
+ CCSVParseOptions Defaults()
+
+ CStatus Validate()
+
+ cdef cppclass CCSVConvertOptions" arrow::csv::ConvertOptions":
+ c_bool check_utf8
+ unordered_map[c_string, shared_ptr[CDataType]] column_types
+ vector[c_string] null_values
+ vector[c_string] true_values
+ vector[c_string] false_values
+ c_bool strings_can_be_null
+ c_bool quoted_strings_can_be_null
+ vector[shared_ptr[CTimestampParser]] timestamp_parsers
+
+ c_bool auto_dict_encode
+ int32_t auto_dict_max_cardinality
+ unsigned char decimal_point
+
+ vector[c_string] include_columns
+ c_bool include_missing_columns
+
+ CCSVConvertOptions()
+ CCSVConvertOptions(CCSVConvertOptions&&)
+
+ @staticmethod
+ CCSVConvertOptions Defaults()
+
+ CStatus Validate()
+
+ cdef cppclass CCSVReadOptions" arrow::csv::ReadOptions":
+ c_bool use_threads
+ int32_t block_size
+ int32_t skip_rows
+ int32_t skip_rows_after_names
+ vector[c_string] column_names
+ c_bool autogenerate_column_names
+
+ CCSVReadOptions()
+ CCSVReadOptions(CCSVReadOptions&&)
+
+ @staticmethod
+ CCSVReadOptions Defaults()
+
+ CStatus Validate()
+
+ cdef cppclass CCSVWriteOptions" arrow::csv::WriteOptions":
+ c_bool include_header
+ int32_t batch_size
+ CIOContext io_context
+
+ CCSVWriteOptions()
+ CCSVWriteOptions(CCSVWriteOptions&&)
+
+ @staticmethod
+ CCSVWriteOptions Defaults()
+
+ CStatus Validate()
+
+ cdef cppclass CCSVReader" arrow::csv::TableReader":
+ @staticmethod
+ CResult[shared_ptr[CCSVReader]] Make(
+ CIOContext, shared_ptr[CInputStream],
+ CCSVReadOptions, CCSVParseOptions, CCSVConvertOptions)
+
+ CResult[shared_ptr[CTable]] Read()
+
+ cdef cppclass CCSVStreamingReader" arrow::csv::StreamingReader"(
+ CRecordBatchReader):
+ @staticmethod
+ CResult[shared_ptr[CCSVStreamingReader]] Make(
+ CIOContext, shared_ptr[CInputStream],
+ CCSVReadOptions, CCSVParseOptions, CCSVConvertOptions)
+
+ cdef CStatus WriteCSV(CTable&, CCSVWriteOptions& options, COutputStream*)
+ cdef CStatus WriteCSV(
+ CRecordBatch&, CCSVWriteOptions& options, COutputStream*)
+ cdef CResult[shared_ptr[CRecordBatchWriter]] MakeCSVWriter(
+ shared_ptr[COutputStream], shared_ptr[CSchema],
+ CCSVWriteOptions& options)
+
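pyarrow.csv forwards its keyword options onto the Read/Parse/ConvertOptions structs declared above; a brief sketch exercising a few of them (the sample data is made up):

    import io
    from pyarrow import csv

    data = io.BytesIO(b"x;y\n1;a\n2;\n")
    table = csv.read_csv(
        data,
        read_options=csv.ReadOptions(block_size=1 << 20),
        parse_options=csv.ParseOptions(delimiter=";"),
        convert_options=csv.ConvertOptions(strings_can_be_null=True),
    )
    print(table)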
+
+cdef extern from "arrow/json/options.h" nogil:
+
+ ctypedef enum CUnexpectedFieldBehavior \
+ "arrow::json::UnexpectedFieldBehavior":
+ CUnexpectedFieldBehavior_Ignore \
+ "arrow::json::UnexpectedFieldBehavior::Ignore"
+ CUnexpectedFieldBehavior_Error \
+ "arrow::json::UnexpectedFieldBehavior::Error"
+ CUnexpectedFieldBehavior_InferType \
+ "arrow::json::UnexpectedFieldBehavior::InferType"
+
+ cdef cppclass CJSONReadOptions" arrow::json::ReadOptions":
+ c_bool use_threads
+ int32_t block_size
+
+ @staticmethod
+ CJSONReadOptions Defaults()
+
+ cdef cppclass CJSONParseOptions" arrow::json::ParseOptions":
+ shared_ptr[CSchema] explicit_schema
+ c_bool newlines_in_values
+ CUnexpectedFieldBehavior unexpected_field_behavior
+
+ @staticmethod
+ CJSONParseOptions Defaults()
+
+
+cdef extern from "arrow/json/reader.h" namespace "arrow::json" nogil:
+
+ cdef cppclass CJSONReader" arrow::json::TableReader":
+ @staticmethod
+ CResult[shared_ptr[CJSONReader]] Make(
+ CMemoryPool*, shared_ptr[CInputStream],
+ CJSONReadOptions, CJSONParseOptions)
+
+ CResult[shared_ptr[CTable]] Read()
+
+
+cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
+
+ cdef cppclass CExecContext" arrow::compute::ExecContext":
+ CExecContext()
+ CExecContext(CMemoryPool* pool)
+
+ cdef cppclass CKernelSignature" arrow::compute::KernelSignature":
+ c_string ToString() const
+
+ cdef cppclass CKernel" arrow::compute::Kernel":
+ shared_ptr[CKernelSignature] signature
+
+ cdef cppclass CArrayKernel" arrow::compute::ArrayKernel"(CKernel):
+ pass
+
+ cdef cppclass CScalarKernel" arrow::compute::ScalarKernel"(CArrayKernel):
+ pass
+
+ cdef cppclass CVectorKernel" arrow::compute::VectorKernel"(CArrayKernel):
+ pass
+
+ cdef cppclass CScalarAggregateKernel \
+ " arrow::compute::ScalarAggregateKernel"(CKernel):
+ pass
+
+ cdef cppclass CHashAggregateKernel \
+ " arrow::compute::HashAggregateKernel"(CKernel):
+ pass
+
+ cdef cppclass CArity" arrow::compute::Arity":
+ int num_args
+ c_bool is_varargs
+
+ cdef enum FunctionKind" arrow::compute::Function::Kind":
+ FunctionKind_SCALAR" arrow::compute::Function::SCALAR"
+ FunctionKind_VECTOR" arrow::compute::Function::VECTOR"
+ FunctionKind_SCALAR_AGGREGATE \
+ " arrow::compute::Function::SCALAR_AGGREGATE"
+ FunctionKind_HASH_AGGREGATE \
+ " arrow::compute::Function::HASH_AGGREGATE"
+ FunctionKind_META \
+ " arrow::compute::Function::META"
+
+ cdef cppclass CFunctionDoc" arrow::compute::FunctionDoc":
+ c_string summary
+ c_string description
+ vector[c_string] arg_names
+ c_string options_class
+
+ cdef cppclass CFunctionOptionsType" arrow::compute::FunctionOptionsType":
+ const char* type_name() const
+
+ cdef cppclass CFunctionOptions" arrow::compute::FunctionOptions":
+ const CFunctionOptionsType* options_type() const
+ const char* type_name() const
+ c_bool Equals(const CFunctionOptions& other) const
+ c_string ToString() const
+ unique_ptr[CFunctionOptions] Copy() const
+ CResult[shared_ptr[CBuffer]] Serialize() const
+
+ @staticmethod
+ CResult[unique_ptr[CFunctionOptions]] Deserialize(
+ const c_string& type_name, const CBuffer& buffer)
+
+ cdef cppclass CFunction" arrow::compute::Function":
+ const c_string& name() const
+ FunctionKind kind() const
+ const CArity& arity() const
+ const CFunctionDoc& doc() const
+ int num_kernels() const
+ CResult[CDatum] Execute(const vector[CDatum]& args,
+ const CFunctionOptions* options,
+ CExecContext* ctx) const
+
+ cdef cppclass CScalarFunction" arrow::compute::ScalarFunction"(CFunction):
+ vector[const CScalarKernel*] kernels() const
+
+ cdef cppclass CVectorFunction" arrow::compute::VectorFunction"(CFunction):
+ vector[const CVectorKernel*] kernels() const
+
+ cdef cppclass CScalarAggregateFunction \
+ " arrow::compute::ScalarAggregateFunction"(CFunction):
+ vector[const CScalarAggregateKernel*] kernels() const
+
+ cdef cppclass CHashAggregateFunction \
+ " arrow::compute::HashAggregateFunction"(CFunction):
+ vector[const CHashAggregateKernel*] kernels() const
+
+ cdef cppclass CMetaFunction" arrow::compute::MetaFunction"(CFunction):
+ pass
+
+ cdef cppclass CFunctionRegistry" arrow::compute::FunctionRegistry":
+ CResult[shared_ptr[CFunction]] GetFunction(
+ const c_string& name) const
+ vector[c_string] GetFunctionNames() const
+ int num_functions() const
+
+ CFunctionRegistry* GetFunctionRegistry()
+
+ cdef cppclass CElementWiseAggregateOptions \
+ "arrow::compute::ElementWiseAggregateOptions"(CFunctionOptions):
+ CElementWiseAggregateOptions(c_bool skip_nulls)
+ c_bool skip_nulls
+
+ ctypedef enum CRoundMode \
+ "arrow::compute::RoundMode":
+ CRoundMode_DOWN \
+ "arrow::compute::RoundMode::DOWN"
+ CRoundMode_UP \
+ "arrow::compute::RoundMode::UP"
+ CRoundMode_TOWARDS_ZERO \
+ "arrow::compute::RoundMode::TOWARDS_ZERO"
+ CRoundMode_TOWARDS_INFINITY \
+ "arrow::compute::RoundMode::TOWARDS_INFINITY"
+ CRoundMode_HALF_DOWN \
+ "arrow::compute::RoundMode::HALF_DOWN"
+ CRoundMode_HALF_UP \
+ "arrow::compute::RoundMode::HALF_UP"
+ CRoundMode_HALF_TOWARDS_ZERO \
+ "arrow::compute::RoundMode::HALF_TOWARDS_ZERO"
+ CRoundMode_HALF_TOWARDS_INFINITY \
+ "arrow::compute::RoundMode::HALF_TOWARDS_INFINITY"
+ CRoundMode_HALF_TO_EVEN \
+ "arrow::compute::RoundMode::HALF_TO_EVEN"
+ CRoundMode_HALF_TO_ODD \
+ "arrow::compute::RoundMode::HALF_TO_ODD"
+
+ cdef cppclass CRoundOptions \
+ "arrow::compute::RoundOptions"(CFunctionOptions):
+ CRoundOptions(int64_t ndigits, CRoundMode round_mode)
+ int64_t ndigits
+ CRoundMode round_mode
+
+ cdef cppclass CRoundToMultipleOptions \
+ "arrow::compute::RoundToMultipleOptions"(CFunctionOptions):
+ CRoundToMultipleOptions(double multiple, CRoundMode round_mode)
+ double multiple
+ CRoundMode round_mode
+
+ cdef enum CJoinNullHandlingBehavior \
+ "arrow::compute::JoinOptions::NullHandlingBehavior":
+ CJoinNullHandlingBehavior_EMIT_NULL \
+ "arrow::compute::JoinOptions::EMIT_NULL"
+ CJoinNullHandlingBehavior_SKIP \
+ "arrow::compute::JoinOptions::SKIP"
+ CJoinNullHandlingBehavior_REPLACE \
+ "arrow::compute::JoinOptions::REPLACE"
+
+ cdef cppclass CJoinOptions \
+ "arrow::compute::JoinOptions"(CFunctionOptions):
+ CJoinOptions(CJoinNullHandlingBehavior null_handling,
+ c_string null_replacement)
+ CJoinNullHandlingBehavior null_handling
+ c_string null_replacement
+
+ cdef cppclass CMatchSubstringOptions \
+ "arrow::compute::MatchSubstringOptions"(CFunctionOptions):
+ CMatchSubstringOptions(c_string pattern, c_bool ignore_case)
+ c_string pattern
+ c_bool ignore_case
+
+ cdef cppclass CTrimOptions \
+ "arrow::compute::TrimOptions"(CFunctionOptions):
+ CTrimOptions(c_string characters)
+ c_string characters
+
+ cdef cppclass CPadOptions \
+ "arrow::compute::PadOptions"(CFunctionOptions):
+ CPadOptions(int64_t width, c_string padding)
+ int64_t width
+ c_string padding
+
+ cdef cppclass CSliceOptions \
+ "arrow::compute::SliceOptions"(CFunctionOptions):
+ CSliceOptions(int64_t start, int64_t stop, int64_t step)
+ int64_t start
+ int64_t stop
+ int64_t step
+
+ cdef cppclass CSplitOptions \
+ "arrow::compute::SplitOptions"(CFunctionOptions):
+ CSplitOptions(int64_t max_splits, c_bool reverse)
+ int64_t max_splits
+ c_bool reverse
+
+ cdef cppclass CSplitPatternOptions \
+ "arrow::compute::SplitPatternOptions"(CFunctionOptions):
+ CSplitPatternOptions(c_string pattern, int64_t max_splits,
+ c_bool reverse)
+ int64_t max_splits
+ c_bool reverse
+ c_string pattern
+
+ cdef cppclass CReplaceSliceOptions \
+ "arrow::compute::ReplaceSliceOptions"(CFunctionOptions):
+ CReplaceSliceOptions(int64_t start, int64_t stop, c_string replacement)
+ int64_t start
+ int64_t stop
+ c_string replacement
+
+ cdef cppclass CReplaceSubstringOptions \
+ "arrow::compute::ReplaceSubstringOptions"(CFunctionOptions):
+ CReplaceSubstringOptions(c_string pattern, c_string replacement,
+ int64_t max_replacements)
+ c_string pattern
+ c_string replacement
+ int64_t max_replacements
+
+ cdef cppclass CExtractRegexOptions \
+ "arrow::compute::ExtractRegexOptions"(CFunctionOptions):
+ CExtractRegexOptions(c_string pattern)
+ c_string pattern
+
+ cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions):
+ CCastOptions()
+ CCastOptions(c_bool safe)
+ CCastOptions(CCastOptions&& options)
+
+ @staticmethod
+ CCastOptions Safe()
+
+ @staticmethod
+ CCastOptions Unsafe()
+ shared_ptr[CDataType] to_type
+ c_bool allow_int_overflow
+ c_bool allow_time_truncate
+ c_bool allow_time_overflow
+ c_bool allow_decimal_truncate
+ c_bool allow_float_truncate
+ c_bool allow_invalid_utf8
+
+ cdef enum CFilterNullSelectionBehavior \
+ "arrow::compute::FilterOptions::NullSelectionBehavior":
+ CFilterNullSelectionBehavior_DROP \
+ "arrow::compute::FilterOptions::DROP"
+ CFilterNullSelectionBehavior_EMIT_NULL \
+ "arrow::compute::FilterOptions::EMIT_NULL"
+
+ cdef cppclass CFilterOptions \
+ " arrow::compute::FilterOptions"(CFunctionOptions):
+ CFilterOptions()
+ CFilterOptions(CFilterNullSelectionBehavior null_selection_behavior)
+ CFilterNullSelectionBehavior null_selection_behavior
+
+ cdef enum CDictionaryEncodeNullEncodingBehavior \
+ "arrow::compute::DictionaryEncodeOptions::NullEncodingBehavior":
+ CDictionaryEncodeNullEncodingBehavior_ENCODE \
+ "arrow::compute::DictionaryEncodeOptions::ENCODE"
+ CDictionaryEncodeNullEncodingBehavior_MASK \
+ "arrow::compute::DictionaryEncodeOptions::MASK"
+
+ cdef cppclass CDictionaryEncodeOptions \
+ "arrow::compute::DictionaryEncodeOptions"(CFunctionOptions):
+ CDictionaryEncodeOptions(
+ CDictionaryEncodeNullEncodingBehavior null_encoding)
+ CDictionaryEncodeNullEncodingBehavior null_encoding
+
+ cdef cppclass CTakeOptions \
+ " arrow::compute::TakeOptions"(CFunctionOptions):
+ CTakeOptions(c_bool boundscheck)
+ c_bool boundscheck
+
+ cdef cppclass CStrptimeOptions \
+ "arrow::compute::StrptimeOptions"(CFunctionOptions):
+ CStrptimeOptions(c_string format, TimeUnit unit)
+ c_string format
+ TimeUnit unit
+
+ cdef cppclass CStrftimeOptions \
+ "arrow::compute::StrftimeOptions"(CFunctionOptions):
+ CStrftimeOptions(c_string format, c_string locale)
+ c_string format
+ c_string locale
+
+ cdef cppclass CDayOfWeekOptions \
+ "arrow::compute::DayOfWeekOptions"(CFunctionOptions):
+ CDayOfWeekOptions(c_bool count_from_zero, uint32_t week_start)
+ c_bool count_from_zero
+ uint32_t week_start
+
+ cdef enum CAssumeTimezoneAmbiguous \
+ "arrow::compute::AssumeTimezoneOptions::Ambiguous":
+ CAssumeTimezoneAmbiguous_AMBIGUOUS_RAISE \
+ "arrow::compute::AssumeTimezoneOptions::AMBIGUOUS_RAISE"
+ CAssumeTimezoneAmbiguous_AMBIGUOUS_EARLIEST \
+ "arrow::compute::AssumeTimezoneOptions::AMBIGUOUS_EARLIEST"
+ CAssumeTimezoneAmbiguous_AMBIGUOUS_LATEST \
+ "arrow::compute::AssumeTimezoneOptions::AMBIGUOUS_LATEST"
+
+ cdef enum CAssumeTimezoneNonexistent \
+ "arrow::compute::AssumeTimezoneOptions::Nonexistent":
+ CAssumeTimezoneNonexistent_NONEXISTENT_RAISE \
+ "arrow::compute::AssumeTimezoneOptions::NONEXISTENT_RAISE"
+ CAssumeTimezoneNonexistent_NONEXISTENT_EARLIEST \
+ "arrow::compute::AssumeTimezoneOptions::NONEXISTENT_EARLIEST"
+ CAssumeTimezoneNonexistent_NONEXISTENT_LATEST \
+ "arrow::compute::AssumeTimezoneOptions::NONEXISTENT_LATEST"
+
+ cdef cppclass CAssumeTimezoneOptions \
+ "arrow::compute::AssumeTimezoneOptions"(CFunctionOptions):
+ CAssumeTimezoneOptions(c_string timezone,
+ CAssumeTimezoneAmbiguous ambiguous,
+ CAssumeTimezoneNonexistent nonexistent)
+ c_string timezone
+ CAssumeTimezoneAmbiguous ambiguous
+ CAssumeTimezoneNonexistent nonexistent
+
+ cdef cppclass CWeekOptions \
+ "arrow::compute::WeekOptions"(CFunctionOptions):
+ CWeekOptions(c_bool week_starts_monday, c_bool count_from_zero,
+ c_bool first_week_is_fully_in_year)
+ c_bool week_starts_monday
+ c_bool count_from_zero
+ c_bool first_week_is_fully_in_year
+
+ cdef cppclass CNullOptions \
+ "arrow::compute::NullOptions"(CFunctionOptions):
+ CNullOptions(c_bool nan_is_null)
+ c_bool nan_is_null
+
+ cdef cppclass CVarianceOptions \
+ "arrow::compute::VarianceOptions"(CFunctionOptions):
+ CVarianceOptions(int ddof, c_bool skip_nulls, uint32_t min_count)
+ int ddof
+ c_bool skip_nulls
+ uint32_t min_count
+
+ cdef cppclass CScalarAggregateOptions \
+ "arrow::compute::ScalarAggregateOptions"(CFunctionOptions):
+ CScalarAggregateOptions(c_bool skip_nulls, uint32_t min_count)
+ c_bool skip_nulls
+ uint32_t min_count
+
+ cdef enum CCountMode "arrow::compute::CountOptions::CountMode":
+ CCountMode_ONLY_VALID "arrow::compute::CountOptions::ONLY_VALID"
+ CCountMode_ONLY_NULL "arrow::compute::CountOptions::ONLY_NULL"
+ CCountMode_ALL "arrow::compute::CountOptions::ALL"
+
+ cdef cppclass CCountOptions \
+ "arrow::compute::CountOptions"(CFunctionOptions):
+ CCountOptions(CCountMode mode)
+ CCountMode mode
+
+ cdef cppclass CModeOptions \
+ "arrow::compute::ModeOptions"(CFunctionOptions):
+ CModeOptions(int64_t n, c_bool skip_nulls, uint32_t min_count)
+ int64_t n
+ c_bool skip_nulls
+ uint32_t min_count
+
+ cdef cppclass CIndexOptions \
+ "arrow::compute::IndexOptions"(CFunctionOptions):
+ CIndexOptions(shared_ptr[CScalar] value)
+ shared_ptr[CScalar] value
+
+ cdef cppclass CMakeStructOptions \
+ "arrow::compute::MakeStructOptions"(CFunctionOptions):
+ CMakeStructOptions(vector[c_string] n,
+ vector[c_bool] r,
+ vector[shared_ptr[const CKeyValueMetadata]] m)
+ CMakeStructOptions(vector[c_string] n)
+ vector[c_string] field_names
+ vector[c_bool] field_nullability
+ vector[shared_ptr[const CKeyValueMetadata]] field_metadata
+
+ ctypedef enum CSortOrder" arrow::compute::SortOrder":
+ CSortOrder_Ascending \
+ "arrow::compute::SortOrder::Ascending"
+ CSortOrder_Descending \
+ "arrow::compute::SortOrder::Descending"
+
+ ctypedef enum CNullPlacement" arrow::compute::NullPlacement":
+ CNullPlacement_AtStart \
+ "arrow::compute::NullPlacement::AtStart"
+ CNullPlacement_AtEnd \
+ "arrow::compute::NullPlacement::AtEnd"
+
+ cdef cppclass CPartitionNthOptions \
+ "arrow::compute::PartitionNthOptions"(CFunctionOptions):
+ CPartitionNthOptions(int64_t pivot, CNullPlacement)
+ int64_t pivot
+ CNullPlacement null_placement
+
+ cdef cppclass CArraySortOptions \
+ "arrow::compute::ArraySortOptions"(CFunctionOptions):
+ CArraySortOptions(CSortOrder, CNullPlacement)
+ CSortOrder order
+ CNullPlacement null_placement
+
+ cdef cppclass CSortKey" arrow::compute::SortKey":
+ CSortKey(c_string name, CSortOrder order)
+ c_string name
+ CSortOrder order
+
+ cdef cppclass CSortOptions \
+ "arrow::compute::SortOptions"(CFunctionOptions):
+ CSortOptions(vector[CSortKey] sort_keys, CNullPlacement)
+ vector[CSortKey] sort_keys
+ CNullPlacement null_placement
+
+ cdef cppclass CSelectKOptions \
+ "arrow::compute::SelectKOptions"(CFunctionOptions):
+ CSelectKOptions(int64_t k, vector[CSortKey] sort_keys)
+ int64_t k
+ vector[CSortKey] sort_keys
+
+ cdef enum CQuantileInterp \
+ "arrow::compute::QuantileOptions::Interpolation":
+ CQuantileInterp_LINEAR "arrow::compute::QuantileOptions::LINEAR"
+ CQuantileInterp_LOWER "arrow::compute::QuantileOptions::LOWER"
+ CQuantileInterp_HIGHER "arrow::compute::QuantileOptions::HIGHER"
+ CQuantileInterp_NEAREST "arrow::compute::QuantileOptions::NEAREST"
+ CQuantileInterp_MIDPOINT "arrow::compute::QuantileOptions::MIDPOINT"
+
+ cdef cppclass CQuantileOptions \
+ "arrow::compute::QuantileOptions"(CFunctionOptions):
+ CQuantileOptions(vector[double] q, CQuantileInterp interpolation,
+ c_bool skip_nulls, uint32_t min_count)
+ vector[double] q
+ CQuantileInterp interpolation
+ c_bool skip_nulls
+ uint32_t min_count
+
+ cdef cppclass CTDigestOptions \
+ "arrow::compute::TDigestOptions"(CFunctionOptions):
+ CTDigestOptions(vector[double] q,
+ uint32_t delta, uint32_t buffer_size,
+ c_bool skip_nulls, uint32_t min_count)
+ vector[double] q
+ uint32_t delta
+ uint32_t buffer_size
+ c_bool skip_nulls
+ uint32_t min_count
+
+ cdef enum DatumType" arrow::Datum::type":
+ DatumType_NONE" arrow::Datum::NONE"
+ DatumType_SCALAR" arrow::Datum::SCALAR"
+ DatumType_ARRAY" arrow::Datum::ARRAY"
+ DatumType_CHUNKED_ARRAY" arrow::Datum::CHUNKED_ARRAY"
+ DatumType_RECORD_BATCH" arrow::Datum::RECORD_BATCH"
+ DatumType_TABLE" arrow::Datum::TABLE"
+ DatumType_COLLECTION" arrow::Datum::COLLECTION"
+
+ cdef cppclass CDatum" arrow::Datum":
+ CDatum()
+ CDatum(const shared_ptr[CArray]& value)
+ CDatum(const shared_ptr[CChunkedArray]& value)
+ CDatum(const shared_ptr[CScalar]& value)
+ CDatum(const shared_ptr[CRecordBatch]& value)
+ CDatum(const shared_ptr[CTable]& value)
+
+ DatumType kind() const
+ c_string ToString() const
+
+ const shared_ptr[CArrayData]& array() const
+ const shared_ptr[CChunkedArray]& chunked_array() const
+ const shared_ptr[CRecordBatch]& record_batch() const
+ const shared_ptr[CTable]& table() const
+ const shared_ptr[CScalar]& scalar() const
+
+ cdef cppclass CSetLookupOptions \
+ "arrow::compute::SetLookupOptions"(CFunctionOptions):
+ CSetLookupOptions(CDatum value_set, c_bool skip_nulls)
+ CDatum value_set
+ c_bool skip_nulls
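+ # Usage sketch (comment only): compute functions are looked up by name in
+ # the registry and invoked through CFunction.Execute, roughly
+ #
+ #     func = GetResultValue(GetFunctionRegistry().GetFunction(b"add"))
+ #     cdef vector[CDatum] args
+ #     args.push_back(CDatum(left))
+ #     args.push_back(CDatum(right))
+ #     result = GetResultValue(deref(func).Execute(args, NULL, &ctx))
+ #
+ # assuming a GetResultValue-style helper, a CExecContext `ctx`, and
+ # `left`/`right` held as shared_ptr[CArray].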
+
+
+cdef extern from * namespace "arrow::compute":
+ # inlined from compute/function_internal.h to avoid exposing
+ # implementation details
+ """
+ #include "arrow/compute/function.h"
+ namespace arrow {
+ namespace compute {
+ namespace internal {
+ Result<std::unique_ptr<FunctionOptions>> DeserializeFunctionOptions(
+ const Buffer& buffer);
+ } // namespace internal
+ } // namespace compute
+ } // namespace arrow
+ """
+ CResult[unique_ptr[CFunctionOptions]] DeserializeFunctionOptions \
+ " arrow::compute::internal::DeserializeFunctionOptions"(
+ const CBuffer& buffer)
+
+
+cdef extern from "arrow/python/api.h" namespace "arrow::py":
+ # Requires GIL
+ CResult[shared_ptr[CDataType]] InferArrowType(
+ object obj, object mask, c_bool pandas_null_sentinels)
+
+
+cdef extern from "arrow/python/api.h" namespace "arrow::py::internal":
+ object NewMonthDayNanoTupleType()
+ CResult[PyObject*] MonthDayNanoIntervalArrayToPyList(
+ const CMonthDayNanoIntervalArray& array)
+ CResult[PyObject*] MonthDayNanoIntervalScalarToPyObject(
+ const CMonthDayNanoIntervalScalar& scalar)
+
+
+cdef extern from "arrow/python/api.h" namespace "arrow::py" nogil:
+ shared_ptr[CDataType] GetPrimitiveType(Type type)
+
+ object PyHalf_FromHalf(npy_half value)
+
+ cdef cppclass PyConversionOptions:
+ PyConversionOptions()
+
+ shared_ptr[CDataType] type
+ int64_t size
+ CMemoryPool* pool
+ c_bool from_pandas
+ c_bool ignore_timezone
+ c_bool strict
+
+ # TODO Some functions below are not actually "nogil"
+
+ CResult[shared_ptr[CChunkedArray]] ConvertPySequence(
+ object obj, object mask, const PyConversionOptions& options,
+ CMemoryPool* pool)
+
+ CStatus NumPyDtypeToArrow(object dtype, shared_ptr[CDataType]* type)
+
+ CStatus NdarrayToArrow(CMemoryPool* pool, object ao, object mo,
+ c_bool from_pandas,
+ const shared_ptr[CDataType]& type,
+ shared_ptr[CChunkedArray]* out)
+
+ CStatus NdarrayToArrow(CMemoryPool* pool, object ao, object mo,
+ c_bool from_pandas,
+ const shared_ptr[CDataType]& type,
+ const CCastOptions& cast_options,
+ shared_ptr[CChunkedArray]* out)
+
+ CStatus NdarrayToTensor(CMemoryPool* pool, object ao,
+ const vector[c_string]& dim_names,
+ shared_ptr[CTensor]* out)
+
+ CStatus TensorToNdarray(const shared_ptr[CTensor]& tensor, object base,
+ PyObject** out)
+
+ CStatus SparseCOOTensorToNdarray(
+ const shared_ptr[CSparseCOOTensor]& sparse_tensor, object base,
+ PyObject** out_data, PyObject** out_coords)
+
+ CStatus SparseCSRMatrixToNdarray(
+ const shared_ptr[CSparseCSRMatrix]& sparse_tensor, object base,
+ PyObject** out_data, PyObject** out_indptr, PyObject** out_indices)
+
+ CStatus SparseCSCMatrixToNdarray(
+ const shared_ptr[CSparseCSCMatrix]& sparse_tensor, object base,
+ PyObject** out_data, PyObject** out_indptr, PyObject** out_indices)
+
+ CStatus SparseCSFTensorToNdarray(
+ const shared_ptr[CSparseCSFTensor]& sparse_tensor, object base,
+ PyObject** out_data, PyObject** out_indptr, PyObject** out_indices)
+
+ CStatus NdarraysToSparseCOOTensor(CMemoryPool* pool, object data_ao,
+ object coords_ao,
+ const vector[int64_t]& shape,
+ const vector[c_string]& dim_names,
+ shared_ptr[CSparseCOOTensor]* out)
+
+ CStatus NdarraysToSparseCSRMatrix(CMemoryPool* pool, object data_ao,
+ object indptr_ao, object indices_ao,
+ const vector[int64_t]& shape,
+ const vector[c_string]& dim_names,
+ shared_ptr[CSparseCSRMatrix]* out)
+
+ CStatus NdarraysToSparseCSCMatrix(CMemoryPool* pool, object data_ao,
+ object indptr_ao, object indices_ao,
+ const vector[int64_t]& shape,
+ const vector[c_string]& dim_names,
+ shared_ptr[CSparseCSCMatrix]* out)
+
+ CStatus NdarraysToSparseCSFTensor(CMemoryPool* pool, object data_ao,
+ object indptr_ao, object indices_ao,
+ const vector[int64_t]& shape,
+ const vector[int64_t]& axis_order,
+ const vector[c_string]& dim_names,
+ shared_ptr[CSparseCSFTensor]* out)
+
+ CStatus TensorToSparseCOOTensor(shared_ptr[CTensor],
+ shared_ptr[CSparseCOOTensor]* out)
+
+ CStatus TensorToSparseCSRMatrix(shared_ptr[CTensor],
+ shared_ptr[CSparseCSRMatrix]* out)
+
+ CStatus TensorToSparseCSCMatrix(shared_ptr[CTensor],
+ shared_ptr[CSparseCSCMatrix]* out)
+
+ CStatus TensorToSparseCSFTensor(shared_ptr[CTensor],
+ shared_ptr[CSparseCSFTensor]* out)
+
+ CStatus ConvertArrayToPandas(const PandasOptions& options,
+ shared_ptr[CArray] arr,
+ object py_ref, PyObject** out)
+
+ CStatus ConvertChunkedArrayToPandas(const PandasOptions& options,
+ shared_ptr[CChunkedArray] arr,
+ object py_ref, PyObject** out)
+
+ CStatus ConvertTableToPandas(const PandasOptions& options,
+ shared_ptr[CTable] table,
+ PyObject** out)
+
+ void c_set_default_memory_pool \
+ " arrow::py::set_default_memory_pool"(CMemoryPool* pool)\
+
+ CMemoryPool* c_get_memory_pool \
+ " arrow::py::get_memory_pool"()
+
+ cdef cppclass PyBuffer(CBuffer):
+ @staticmethod
+ CResult[shared_ptr[CBuffer]] FromPyObject(object obj)
+
+ cdef cppclass PyForeignBuffer(CBuffer):
+ @staticmethod
+ CStatus Make(const uint8_t* data, int64_t size, object base,
+ shared_ptr[CBuffer]* out)
+
+ cdef cppclass PyReadableFile(CRandomAccessFile):
+ PyReadableFile(object fo)
+
+ cdef cppclass PyOutputStream(COutputStream):
+ PyOutputStream(object fo)
+
+ cdef cppclass PandasOptions:
+ CMemoryPool* pool
+ c_bool strings_to_categorical
+ c_bool zero_copy_only
+ c_bool integer_object_nulls
+ c_bool date_as_object
+ c_bool timestamp_as_object
+ c_bool use_threads
+ c_bool coerce_temporal_nanoseconds
+ c_bool ignore_timezone
+ c_bool deduplicate_objects
+ c_bool safe_cast
+ c_bool split_blocks
+ c_bool self_destruct
+ c_bool decode_dictionaries
+ unordered_set[c_string] categorical_columns
+ unordered_set[c_string] extension_columns
+
+ cdef cppclass CSerializedPyObject" arrow::py::SerializedPyObject":
+ shared_ptr[CRecordBatch] batch
+ vector[shared_ptr[CTensor]] tensors
+
+ CStatus WriteTo(COutputStream* dst)
+ CStatus GetComponents(CMemoryPool* pool, PyObject** dst)
+
+ CStatus SerializeObject(object context, object sequence,
+ CSerializedPyObject* out)
+
+ CStatus DeserializeObject(object context,
+ const CSerializedPyObject& obj,
+ PyObject* base, PyObject** out)
+
+ CStatus ReadSerializedObject(CRandomAccessFile* src,
+ CSerializedPyObject* out)
+
+ cdef cppclass SparseTensorCounts:
+ SparseTensorCounts()
+ int coo
+ int csr
+ int csc
+ int csf
+ int ndim_csf
+ int num_total_tensors() const
+ int num_total_buffers() const
+
+ CStatus GetSerializedFromComponents(
+ int num_tensors,
+ const SparseTensorCounts& num_sparse_tensors,
+ int num_ndarrays,
+ int num_buffers,
+ object buffers,
+ CSerializedPyObject* out)
+
+
+cdef extern from "arrow/python/api.h" namespace "arrow::py::internal" nogil:
+ cdef cppclass CTimePoint "arrow::py::internal::TimePoint":
+ pass
+
+ CTimePoint PyDateTime_to_TimePoint(PyDateTime_DateTime* pydatetime)
+ int64_t TimePoint_to_ns(CTimePoint val)
+ CTimePoint TimePoint_from_s(double val)
+ CTimePoint TimePoint_from_ns(int64_t val)
+
+ CResult[c_string] TzinfoToString(PyObject* pytzinfo)
+ CResult[PyObject*] StringToTzinfo(c_string)
+
+
+cdef extern from "arrow/python/init.h":
+ int arrow_init_numpy() except -1
+
+
+cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py":
+ int import_pyarrow() except -1
+
+
+cdef extern from "arrow/python/common.h" namespace "arrow::py":
+ c_bool IsPyError(const CStatus& status)
+ void RestorePyError(const CStatus& status)
+
+
+cdef extern from "arrow/python/inference.h" namespace "arrow::py":
+ c_bool IsPyBool(object o)
+ c_bool IsPyInt(object o)
+ c_bool IsPyFloat(object o)
+
+
+cdef extern from "arrow/python/ipc.h" namespace "arrow::py":
+ cdef cppclass CPyRecordBatchReader" arrow::py::PyRecordBatchReader" \
+ (CRecordBatchReader):
+ @staticmethod
+ CResult[shared_ptr[CRecordBatchReader]] Make(shared_ptr[CSchema],
+ object)
+
+
+cdef extern from "arrow/extension_type.h" namespace "arrow":
+ cdef cppclass CExtensionTypeRegistry" arrow::ExtensionTypeRegistry":
+ @staticmethod
+ shared_ptr[CExtensionTypeRegistry] GetGlobalRegistry()
+
+ cdef cppclass CExtensionType" arrow::ExtensionType"(CDataType):
+ c_string extension_name()
+ shared_ptr[CDataType] storage_type()
+
+ @staticmethod
+ shared_ptr[CArray] WrapArray(shared_ptr[CDataType] ext_type,
+ shared_ptr[CArray] storage)
+
+ @staticmethod
+ shared_ptr[CChunkedArray] WrapArray(shared_ptr[CDataType] ext_type,
+ shared_ptr[CChunkedArray] storage)
+
+ cdef cppclass CExtensionArray" arrow::ExtensionArray"(CArray):
+ CExtensionArray(shared_ptr[CDataType], shared_ptr[CArray] storage)
+
+ shared_ptr[CArray] storage()
+
+
+cdef extern from "arrow/python/extension_type.h" namespace "arrow::py":
+ cdef cppclass CPyExtensionType \
+ " arrow::py::PyExtensionType"(CExtensionType):
+ @staticmethod
+ CStatus FromClass(const shared_ptr[CDataType] storage_type,
+ const c_string extension_name, object typ,
+ shared_ptr[CExtensionType]* out)
+
+ @staticmethod
+ CStatus FromInstance(shared_ptr[CDataType] storage_type,
+ object inst, shared_ptr[CExtensionType]* out)
+
+ object GetInstance()
+ CStatus SetInstance(object)
+
+ c_string PyExtensionName()
+ CStatus RegisterPyExtensionType(shared_ptr[CDataType])
+ CStatus UnregisterPyExtensionType(c_string type_name)
+
+
+cdef extern from "arrow/python/benchmark.h" namespace "arrow::py::benchmark":
+ void Benchmark_PandasObjectIsNull(object lst) except *
+
+
+cdef extern from "arrow/util/compression.h" namespace "arrow" nogil:
+ cdef enum CCompressionType" arrow::Compression::type":
+ CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED"
+ CCompressionType_SNAPPY" arrow::Compression::SNAPPY"
+ CCompressionType_GZIP" arrow::Compression::GZIP"
+ CCompressionType_BROTLI" arrow::Compression::BROTLI"
+ CCompressionType_ZSTD" arrow::Compression::ZSTD"
+ CCompressionType_LZ4" arrow::Compression::LZ4"
+ CCompressionType_LZ4_FRAME" arrow::Compression::LZ4_FRAME"
+ CCompressionType_BZ2" arrow::Compression::BZ2"
+
+ cdef cppclass CCodec" arrow::util::Codec":
+ @staticmethod
+ CResult[unique_ptr[CCodec]] Create(CCompressionType codec)
+
+ @staticmethod
+ CResult[unique_ptr[CCodec]] CreateWithLevel" Create"(
+ CCompressionType codec,
+ int compression_level)
+
+ @staticmethod
+ c_bool SupportsCompressionLevel(CCompressionType codec)
+
+ @staticmethod
+ CResult[int] MinimumCompressionLevel(CCompressionType codec)
+
+ @staticmethod
+ CResult[int] MaximumCompressionLevel(CCompressionType codec)
+
+ @staticmethod
+ CResult[int] DefaultCompressionLevel(CCompressionType codec)
+
+ @staticmethod
+ c_bool IsAvailable(CCompressionType codec)
+
+ CResult[int64_t] Decompress(int64_t input_len, const uint8_t* input,
+ int64_t output_len,
+ uint8_t* output_buffer)
+ CResult[int64_t] Compress(int64_t input_len, const uint8_t* input,
+ int64_t output_buffer_len,
+ uint8_t* output_buffer)
+ c_string name() const
+ int compression_level() const
+ int64_t MaxCompressedLen(int64_t input_len, const uint8_t* input)
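+ # Usage sketch (comment only): one-shot compression with these
+ # declarations looks roughly like
+ #
+ #     codec = GetResultValue(CCodec.Create(CCompressionType_ZSTD))
+ #     max_len = deref(codec).MaxCompressedLen(length, data)
+ #     # ... allocate an output buffer of max_len bytes ...
+ #     n_written = GetResultValue(
+ #         deref(codec).Compress(length, data, max_len, out))
+ #
+ # assuming a GetResultValue-style helper; Create returns a unique_ptr,
+ # which the caller keeps alive for the codec's lifetime.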
+
+
+cdef extern from "arrow/util/io_util.h" namespace "arrow::internal" nogil:
+ int ErrnoFromStatus(CStatus status)
+ int WinErrorFromStatus(CStatus status)
+ int SignalFromStatus(CStatus status)
+
+ CStatus SendSignal(int signum)
+ CStatus SendSignalToThread(int signum, uint64_t thread_id)
+
+
+cdef extern from "arrow/util/iterator.h" namespace "arrow" nogil:
+ cdef cppclass CIterator" arrow::Iterator"[T]:
+ CResult[T] Next()
+ CStatus Visit[Visitor](Visitor&& visitor)
+ cppclass RangeIterator:
+ CResult[T] operator*()
+ RangeIterator& operator++()
+ c_bool operator!=(RangeIterator) const
+ RangeIterator begin()
+ RangeIterator end()
+ CIterator[T] MakeVectorIterator[T](vector[T] v)
+
+cdef extern from "arrow/util/thread_pool.h" namespace "arrow" nogil:
+ int GetCpuThreadPoolCapacity()
+ CStatus SetCpuThreadPoolCapacity(int threads)
+
+cdef extern from "arrow/array/concatenate.h" namespace "arrow" nogil:
+ CResult[shared_ptr[CArray]] Concatenate(
+ const vector[shared_ptr[CArray]]& arrays,
+ CMemoryPool* pool)
+
+cdef extern from "arrow/c/abi.h":
+ cdef struct ArrowSchema:
+ pass
+
+ cdef struct ArrowArray:
+ pass
+
+ cdef struct ArrowArrayStream:
+ pass
+
+cdef extern from "arrow/c/bridge.h" namespace "arrow" nogil:
+ CStatus ExportType(CDataType&, ArrowSchema* out)
+ CResult[shared_ptr[CDataType]] ImportType(ArrowSchema*)
+
+ CStatus ExportField(CField&, ArrowSchema* out)
+ CResult[shared_ptr[CField]] ImportField(ArrowSchema*)
+
+ CStatus ExportSchema(CSchema&, ArrowSchema* out)
+ CResult[shared_ptr[CSchema]] ImportSchema(ArrowSchema*)
+
+ CStatus ExportArray(CArray&, ArrowArray* out)
+ CStatus ExportArray(CArray&, ArrowArray* out, ArrowSchema* out_schema)
+ CResult[shared_ptr[CArray]] ImportArray(ArrowArray*,
+ shared_ptr[CDataType])
+ CResult[shared_ptr[CArray]] ImportArray(ArrowArray*, ArrowSchema*)
+
+ CStatus ExportRecordBatch(CRecordBatch&, ArrowArray* out)
+ CStatus ExportRecordBatch(CRecordBatch&, ArrowArray* out,
+ ArrowSchema* out_schema)
+ CResult[shared_ptr[CRecordBatch]] ImportRecordBatch(ArrowArray*,
+ shared_ptr[CSchema])
+ CResult[shared_ptr[CRecordBatch]] ImportRecordBatch(ArrowArray*,
+ ArrowSchema*)
+
+ CStatus ExportRecordBatchReader(shared_ptr[CRecordBatchReader],
+ ArrowArrayStream*)
+ CResult[shared_ptr[CRecordBatchReader]] ImportRecordBatchReader(
+ ArrowArrayStream*)
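+ # Usage sketch (comment only): the C data interface round-trips an array
+ # through a pair of plain C structs, roughly
+ #
+ #     cdef ArrowArray c_array
+ #     cdef ArrowSchema c_schema
+ #     check_status(ExportArray(deref(arr), &c_array, &c_schema))
+ #     copied = GetResultValue(ImportArray(&c_array, &c_schema))
+ #
+ # assuming check_status/GetResultValue-style helpers as in error.pxi and
+ # `arr` held as shared_ptr[CArray].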
diff --git a/src/arrow/python/pyarrow/includes/libarrow_cuda.pxd b/src/arrow/python/pyarrow/includes/libarrow_cuda.pxd
new file mode 100644
index 000000000..3ac943cf9
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/libarrow_cuda.pxd
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from pyarrow.includes.libarrow cimport *
+
+cdef extern from "arrow/gpu/cuda_api.h" namespace "arrow::cuda" nogil:
+
+ cdef cppclass CCudaDeviceManager" arrow::cuda::CudaDeviceManager":
+ @staticmethod
+ CResult[CCudaDeviceManager*] Instance()
+ CResult[shared_ptr[CCudaContext]] GetContext(int gpu_number)
+ CResult[shared_ptr[CCudaContext]] GetSharedContext(int gpu_number,
+ void* handle)
+ CStatus AllocateHost(int device_number, int64_t nbytes,
+ shared_ptr[CCudaHostBuffer]* buffer)
+ int num_devices() const
+
+ cdef cppclass CCudaContext" arrow::cuda::CudaContext":
+ CResult[shared_ptr[CCudaBuffer]] Allocate(int64_t nbytes)
+ CResult[shared_ptr[CCudaBuffer]] View(uint8_t* data, int64_t nbytes)
+ CResult[shared_ptr[CCudaBuffer]] OpenIpcBuffer(
+ const CCudaIpcMemHandle& ipc_handle)
+ CStatus Synchronize()
+ int64_t bytes_allocated() const
+ const void* handle() const
+ int device_number() const
+ CResult[uintptr_t] GetDeviceAddress(uintptr_t addr)
+
+ cdef cppclass CCudaIpcMemHandle" arrow::cuda::CudaIpcMemHandle":
+ @staticmethod
+ CResult[shared_ptr[CCudaIpcMemHandle]] FromBuffer(
+ const void* opaque_handle)
+ CResult[shared_ptr[CBuffer]] Serialize(CMemoryPool* pool) const
+
+ cdef cppclass CCudaBuffer" arrow::cuda::CudaBuffer"(CBuffer):
+ CCudaBuffer(uint8_t* data, int64_t size,
+ const shared_ptr[CCudaContext]& context,
+ c_bool own_data=false, c_bool is_ipc=false)
+ CCudaBuffer(const shared_ptr[CCudaBuffer]& parent,
+ const int64_t offset, const int64_t size)
+
+ @staticmethod
+ CResult[shared_ptr[CCudaBuffer]] FromBuffer(shared_ptr[CBuffer] buf)
+
+ CStatus CopyToHost(const int64_t position, const int64_t nbytes,
+ void* out) const
+ CStatus CopyFromHost(const int64_t position, const void* data,
+ int64_t nbytes)
+ CStatus CopyFromDevice(const int64_t position, const void* data,
+ int64_t nbytes)
+ CStatus CopyFromAnotherDevice(const shared_ptr[CCudaContext]& src_ctx,
+ const int64_t position, const void* data,
+ int64_t nbytes)
+ CResult[shared_ptr[CCudaIpcMemHandle]] ExportForIpc()
+ shared_ptr[CCudaContext] context() const
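+ # Comment note: device buffers are shared across processes by exporting a
+ # CCudaIpcMemHandle with ExportForIpc(), serializing it with Serialize(),
+ # and reopening it elsewhere via CCudaContext.OpenIpcBuffer().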
+
+ cdef cppclass \
+ CCudaHostBuffer" arrow::cuda::CudaHostBuffer"(CMutableBuffer):
+ pass
+
+ cdef cppclass \
+ CCudaBufferReader" arrow::cuda::CudaBufferReader"(CBufferReader):
+ CCudaBufferReader(const shared_ptr[CBuffer]& buffer)
+ CResult[int64_t] Read(int64_t nbytes, void* buffer)
+ CResult[shared_ptr[CBuffer]] Read(int64_t nbytes)
+
+ cdef cppclass \
+ CCudaBufferWriter" arrow::cuda::CudaBufferWriter"(WritableFile):
+ CCudaBufferWriter(const shared_ptr[CCudaBuffer]& buffer)
+ CStatus Close()
+ CStatus Write(const void* data, int64_t nbytes)
+ CStatus WriteAt(int64_t position, const void* data, int64_t nbytes)
+ CStatus SetBufferSize(const int64_t buffer_size)
+ int64_t buffer_size()
+ int64_t num_bytes_buffered() const
+
+ CResult[shared_ptr[CCudaHostBuffer]] AllocateCudaHostBuffer(
+ int device_number, const int64_t size)
+
+ # Cuda prefix is added to avoid picking up arrow::cuda functions
+ # from arrow namespace.
+ CResult[shared_ptr[CCudaBuffer]] \
+ CudaSerializeRecordBatch" arrow::cuda::SerializeRecordBatch"\
+ (const CRecordBatch& batch,
+ CCudaContext* ctx)
+ CResult[shared_ptr[CRecordBatch]] \
+ CudaReadRecordBatch" arrow::cuda::ReadRecordBatch"\
+ (const shared_ptr[CSchema]& schema,
+ CDictionaryMemo* dictionary_memo,
+ const shared_ptr[CCudaBuffer]& buffer,
+ CMemoryPool* pool)
diff --git a/src/arrow/python/pyarrow/includes/libarrow_dataset.pxd b/src/arrow/python/pyarrow/includes/libarrow_dataset.pxd
new file mode 100644
index 000000000..abc79fea8
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/libarrow_dataset.pxd
@@ -0,0 +1,478 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from libcpp.unordered_map cimport unordered_map
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.libarrow_fs cimport *
+from pyarrow._parquet cimport *
+
+
+cdef extern from "arrow/api.h" namespace "arrow" nogil:
+
+ cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"(
+ CIterator[shared_ptr[CRecordBatch]]):
+ pass
+
+
+cdef extern from * namespace "arrow::compute":
+ # inlined from expression_internal.h to avoid
+ # proliferation of #include <unordered_map>
+ """
+ #include <unordered_map>
+
+ #include "arrow/type.h"
+ #include "arrow/datum.h"
+
+ namespace arrow {
+ namespace compute {
+ struct KnownFieldValues {
+ std::unordered_map<FieldRef, Datum, FieldRef::Hash> map;
+ };
+ } // namespace compute
+ } // namespace arrow
+ """
+ cdef struct CKnownFieldValues "arrow::compute::KnownFieldValues":
+ unordered_map[CFieldRef, CDatum, CFieldRefHash] map
+
+cdef extern from "arrow/compute/exec/expression.h" \
+ namespace "arrow::compute" nogil:
+
+ cdef cppclass CExpression "arrow::compute::Expression":
+ c_bool Equals(const CExpression& other) const
+ c_string ToString() const
+ CResult[CExpression] Bind(const CSchema&)
+
+ cdef CExpression CMakeScalarExpression \
+ "arrow::compute::literal"(shared_ptr[CScalar] value)
+
+ cdef CExpression CMakeFieldExpression \
+ "arrow::compute::field_ref"(c_string name)
+
+ cdef CExpression CMakeCallExpression \
+ "arrow::compute::call"(c_string function,
+ vector[CExpression] arguments,
+ shared_ptr[CFunctionOptions] options)
+
+ cdef CResult[shared_ptr[CBuffer]] CSerializeExpression \
+ "arrow::compute::Serialize"(const CExpression&)
+
+ cdef CResult[CExpression] CDeserializeExpression \
+ "arrow::compute::Deserialize"(shared_ptr[CBuffer])
+
+ cdef CResult[CKnownFieldValues] \
+ CExtractKnownFieldValues "arrow::compute::ExtractKnownFieldValues"(
+ const CExpression& partition_expression)
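+ # Usage sketch (comment only): a predicate such as x > 0 is assembled from
+ # these factories, roughly
+ #
+ #     cdef vector[CExpression] args
+ #     args.push_back(CMakeFieldExpression(b"x"))
+ #     args.push_back(CMakeScalarExpression(zero_scalar))
+ #     expr = CMakeCallExpression(b"greater", args,
+ #                                shared_ptr[CFunctionOptions]())
+ #
+ # where `zero_scalar` is an assumed shared_ptr[CScalar] holding 0.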
+
+ctypedef CStatus cb_writer_finish_internal(CFileWriter*)
+ctypedef void cb_writer_finish(dict, CFileWriter*)
+
+cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
+
+ cdef enum ExistingDataBehavior" arrow::dataset::ExistingDataBehavior":
+ ExistingDataBehavior_DELETE_MATCHING" \
+ arrow::dataset::ExistingDataBehavior::kDeleteMatchingPartitions"
+ ExistingDataBehavior_OVERWRITE_OR_IGNORE" \
+ arrow::dataset::ExistingDataBehavior::kOverwriteOrIgnore"
+ ExistingDataBehavior_ERROR" \
+ arrow::dataset::ExistingDataBehavior::kError"
+
+ cdef cppclass CScanOptions "arrow::dataset::ScanOptions":
+ @staticmethod
+ shared_ptr[CScanOptions] Make(shared_ptr[CSchema] schema)
+
+ shared_ptr[CSchema] dataset_schema
+ shared_ptr[CSchema] projected_schema
+
+ cdef cppclass CFragmentScanOptions "arrow::dataset::FragmentScanOptions":
+ c_string type_name() const
+
+ ctypedef CIterator[shared_ptr[CScanTask]] CScanTaskIterator \
+ "arrow::dataset::ScanTaskIterator"
+
+ cdef cppclass CScanTask" arrow::dataset::ScanTask":
+ CResult[CRecordBatchIterator] Execute()
+
+ cdef cppclass CFragment "arrow::dataset::Fragment":
+ CResult[shared_ptr[CSchema]] ReadPhysicalSchema()
+ CResult[CScanTaskIterator] Scan(shared_ptr[CScanOptions] options)
+ c_bool splittable() const
+ c_string type_name() const
+ const CExpression& partition_expression() const
+
+ ctypedef vector[shared_ptr[CFragment]] CFragmentVector \
+ "arrow::dataset::FragmentVector"
+
+ ctypedef CIterator[shared_ptr[CFragment]] CFragmentIterator \
+ "arrow::dataset::FragmentIterator"
+
+ cdef cppclass CInMemoryFragment "arrow::dataset::InMemoryFragment"(
+ CFragment):
+ CInMemoryFragment(vector[shared_ptr[CRecordBatch]] record_batches,
+ CExpression partition_expression)
+
+ cdef cppclass CTaggedRecordBatch "arrow::dataset::TaggedRecordBatch":
+ shared_ptr[CRecordBatch] record_batch
+ shared_ptr[CFragment] fragment
+
+ ctypedef CIterator[CTaggedRecordBatch] CTaggedRecordBatchIterator \
+ "arrow::dataset::TaggedRecordBatchIterator"
+
+ cdef cppclass CScanner "arrow::dataset::Scanner":
+ CScanner(shared_ptr[CDataset], shared_ptr[CScanOptions])
+ CScanner(shared_ptr[CFragment], shared_ptr[CScanOptions])
+ CResult[CScanTaskIterator] Scan()
+ CResult[CTaggedRecordBatchIterator] ScanBatches()
+ CResult[shared_ptr[CTable]] ToTable()
+ CResult[shared_ptr[CTable]] TakeRows(const CArray& indices)
+ CResult[shared_ptr[CTable]] Head(int64_t num_rows)
+ CResult[int64_t] CountRows()
+ CResult[CFragmentIterator] GetFragments()
+ CResult[shared_ptr[CRecordBatchReader]] ToRecordBatchReader()
+ const shared_ptr[CScanOptions]& options()
+
+ cdef cppclass CScannerBuilder "arrow::dataset::ScannerBuilder":
+ CScannerBuilder(shared_ptr[CDataset],
+ shared_ptr[CScanOptions] scan_options)
+ CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment],
+ shared_ptr[CScanOptions] scan_options)
+
+ @staticmethod
+ shared_ptr[CScannerBuilder] FromRecordBatchReader(
+ shared_ptr[CRecordBatchReader] reader)
+ CStatus ProjectColumns "Project"(const vector[c_string]& columns)
+ CStatus Project(vector[CExpression]& exprs, vector[c_string]& columns)
+ CStatus Filter(CExpression filter)
+ CStatus UseThreads(c_bool use_threads)
+ CStatus UseAsync(c_bool use_async)
+ CStatus Pool(CMemoryPool* pool)
+ CStatus BatchSize(int64_t batch_size)
+ CStatus FragmentScanOptions(
+ shared_ptr[CFragmentScanOptions] fragment_scan_options)
+ CResult[shared_ptr[CScanner]] Finish()
+ shared_ptr[CSchema] schema() const
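+ # Usage sketch (comment only): a scan is typically assembled as
+ #
+ #     builder = make_shared[CScannerBuilder](dataset, scan_options)
+ #     check_status(deref(builder).Filter(filter_expr))
+ #     check_status(deref(builder).ProjectColumns(column_names))
+ #     scanner = GetResultValue(deref(builder).Finish())
+ #     table = GetResultValue(deref(scanner).ToTable())
+ #
+ # assuming check_status/GetResultValue-style helpers and libcpp.memory's
+ # make_shared.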
+
+ ctypedef vector[shared_ptr[CDataset]] CDatasetVector \
+ "arrow::dataset::DatasetVector"
+
+ cdef cppclass CDataset "arrow::dataset::Dataset":
+ const shared_ptr[CSchema] & schema()
+ CResult[CFragmentIterator] GetFragments()
+ CResult[CFragmentIterator] GetFragments(CExpression predicate)
+ const CExpression & partition_expression()
+ c_string type_name()
+
+ CResult[shared_ptr[CDataset]] ReplaceSchema(shared_ptr[CSchema])
+
+ CResult[shared_ptr[CScannerBuilder]] NewScan()
+
+ cdef cppclass CInMemoryDataset "arrow::dataset::InMemoryDataset"(
+ CDataset):
+ CInMemoryDataset(shared_ptr[CRecordBatchReader])
+ CInMemoryDataset(shared_ptr[CTable])
+
+ cdef cppclass CUnionDataset "arrow::dataset::UnionDataset"(
+ CDataset):
+ @staticmethod
+ CResult[shared_ptr[CUnionDataset]] Make(shared_ptr[CSchema] schema,
+ CDatasetVector children)
+
+ const CDatasetVector& children() const
+
+ cdef cppclass CInspectOptions "arrow::dataset::InspectOptions":
+ int fragments
+
+ cdef cppclass CFinishOptions "arrow::dataset::FinishOptions":
+ shared_ptr[CSchema] schema
+ CInspectOptions inspect_options
+ c_bool validate_fragments
+
+ cdef cppclass CDatasetFactory "arrow::dataset::DatasetFactory":
+ CResult[vector[shared_ptr[CSchema]]] InspectSchemas(CInspectOptions)
+ CResult[shared_ptr[CSchema]] Inspect(CInspectOptions)
+ CResult[shared_ptr[CDataset]] FinishWithSchema "Finish"(
+ const shared_ptr[CSchema]& schema)
+ CResult[shared_ptr[CDataset]] Finish()
+ const CExpression& root_partition()
+ CStatus SetRootPartition(CExpression partition)
+
+ cdef cppclass CUnionDatasetFactory "arrow::dataset::UnionDatasetFactory":
+ @staticmethod
+ CResult[shared_ptr[CDatasetFactory]] Make(
+ vector[shared_ptr[CDatasetFactory]] factories)
+
+ cdef cppclass CFileSource "arrow::dataset::FileSource":
+ const c_string& path() const
+ const shared_ptr[CFileSystem]& filesystem() const
+ const shared_ptr[CBuffer]& buffer() const
+ # HACK: Cython can't handle all the overloads so don't declare them.
+ # This means invalid construction of CFileSource won't be caught in
+ # the C++ generation phase (though it will still be caught when
+ # the generated C++ is compiled).
+ CFileSource(...)
+
+ cdef cppclass CFileWriteOptions \
+ "arrow::dataset::FileWriteOptions":
+ const shared_ptr[CFileFormat]& format() const
+ c_string type_name() const
+
+ cdef cppclass CFileWriter \
+ "arrow::dataset::FileWriter":
+ const shared_ptr[CFileFormat]& format() const
+ const shared_ptr[CSchema]& schema() const
+ const shared_ptr[CFileWriteOptions]& options() const
+ const CFileLocator& destination() const
+
+ cdef cppclass CParquetFileWriter \
+ "arrow::dataset::ParquetFileWriter"(CFileWriter):
+ const shared_ptr[FileWriter]& parquet_writer() const
+
+ cdef cppclass CFileFormat "arrow::dataset::FileFormat":
+ shared_ptr[CFragmentScanOptions] default_fragment_scan_options
+ c_string type_name() const
+ CResult[shared_ptr[CSchema]] Inspect(const CFileSource&) const
+ CResult[shared_ptr[CFileFragment]] MakeFragment(
+ CFileSource source,
+ CExpression partition_expression,
+ shared_ptr[CSchema] physical_schema)
+ shared_ptr[CFileWriteOptions] DefaultWriteOptions()
+
+ cdef cppclass CFileFragment "arrow::dataset::FileFragment"(
+ CFragment):
+ const CFileSource& source() const
+ const shared_ptr[CFileFormat]& format() const
+
+ cdef cppclass CParquetFileWriteOptions \
+ "arrow::dataset::ParquetFileWriteOptions"(CFileWriteOptions):
+ shared_ptr[WriterProperties] writer_properties
+ shared_ptr[ArrowWriterProperties] arrow_writer_properties
+
+ cdef cppclass CParquetFileFragment "arrow::dataset::ParquetFileFragment"(
+ CFileFragment):
+ const vector[int]& row_groups() const
+ shared_ptr[CFileMetaData] metadata() const
+ CResult[vector[shared_ptr[CFragment]]] SplitByRowGroup(
+ CExpression predicate)
+ CResult[shared_ptr[CFragment]] SubsetWithFilter "Subset"(
+ CExpression predicate)
+ CResult[shared_ptr[CFragment]] SubsetWithIds "Subset"(
+ vector[int] row_group_ids)
+ CStatus EnsureCompleteMetadata()
+
+ cdef cppclass CFileSystemDatasetWriteOptions \
+ "arrow::dataset::FileSystemDatasetWriteOptions":
+ shared_ptr[CFileWriteOptions] file_write_options
+ shared_ptr[CFileSystem] filesystem
+ c_string base_dir
+ shared_ptr[CPartitioning] partitioning
+ int max_partitions
+ c_string basename_template
+ function[cb_writer_finish_internal] writer_pre_finish
+ function[cb_writer_finish_internal] writer_post_finish
+ ExistingDataBehavior existing_data_behavior
+
+ cdef cppclass CFileSystemDataset \
+ "arrow::dataset::FileSystemDataset"(CDataset):
+ @staticmethod
+ CResult[shared_ptr[CDataset]] Make(
+ shared_ptr[CSchema] schema,
+ CExpression source_partition,
+ shared_ptr[CFileFormat] format,
+ shared_ptr[CFileSystem] filesystem,
+ vector[shared_ptr[CFileFragment]] fragments)
+
+ @staticmethod
+ CStatus Write(
+ const CFileSystemDatasetWriteOptions& write_options,
+ shared_ptr[CScanner] scanner)
+
+ c_string type()
+ vector[c_string] files()
+ const shared_ptr[CFileFormat]& format() const
+ const shared_ptr[CFileSystem]& filesystem() const
+ const shared_ptr[CPartitioning]& partitioning() const
+
+ cdef cppclass CParquetFileFormatReaderOptions \
+ "arrow::dataset::ParquetFileFormat::ReaderOptions":
+ unordered_set[c_string] dict_columns
+ TimeUnit coerce_int96_timestamp_unit
+
+ cdef cppclass CParquetFileFormat "arrow::dataset::ParquetFileFormat"(
+ CFileFormat):
+ CParquetFileFormatReaderOptions reader_options
+ CResult[shared_ptr[CFileFragment]] MakeFragment(
+ CFileSource source,
+ CExpression partition_expression,
+ shared_ptr[CSchema] physical_schema,
+ vector[int] row_groups)
+
+ cdef cppclass CParquetFragmentScanOptions \
+ "arrow::dataset::ParquetFragmentScanOptions"(CFragmentScanOptions):
+ shared_ptr[CReaderProperties] reader_properties
+ shared_ptr[ArrowReaderProperties] arrow_reader_properties
+ c_bool enable_parallel_column_conversion
+
+ cdef cppclass CIpcFileWriteOptions \
+ "arrow::dataset::IpcFileWriteOptions"(CFileWriteOptions):
+ pass
+
+ cdef cppclass CIpcFileFormat "arrow::dataset::IpcFileFormat"(
+ CFileFormat):
+ pass
+
+ cdef cppclass COrcFileFormat "arrow::dataset::OrcFileFormat"(
+ CFileFormat):
+ pass
+
+ cdef cppclass CCsvFileWriteOptions \
+ "arrow::dataset::CsvFileWriteOptions"(CFileWriteOptions):
+ shared_ptr[CCSVWriteOptions] write_options
+ CMemoryPool* pool
+
+ cdef cppclass CCsvFileFormat "arrow::dataset::CsvFileFormat"(
+ CFileFormat):
+ CCSVParseOptions parse_options
+
+ cdef cppclass CCsvFragmentScanOptions \
+ "arrow::dataset::CsvFragmentScanOptions"(CFragmentScanOptions):
+ CCSVConvertOptions convert_options
+ CCSVReadOptions read_options
+
+ cdef cppclass CPartitioning "arrow::dataset::Partitioning":
+ c_string type_name() const
+ CResult[CExpression] Parse(const c_string & path) const
+ const shared_ptr[CSchema] & schema()
+
+ cdef cppclass CSegmentEncoding" arrow::dataset::SegmentEncoding":
+ pass
+
+ CSegmentEncoding CSegmentEncodingNone\
+ " arrow::dataset::SegmentEncoding::None"
+ CSegmentEncoding CSegmentEncodingUri\
+ " arrow::dataset::SegmentEncoding::Uri"
+
+ cdef cppclass CKeyValuePartitioningOptions \
+ "arrow::dataset::KeyValuePartitioningOptions":
+ CSegmentEncoding segment_encoding
+
+ cdef cppclass CHivePartitioningOptions \
+ "arrow::dataset::HivePartitioningOptions":
+ CSegmentEncoding segment_encoding
+ c_string null_fallback
+
+ cdef cppclass CPartitioningFactoryOptions \
+ "arrow::dataset::PartitioningFactoryOptions":
+ c_bool infer_dictionary
+ shared_ptr[CSchema] schema
+ CSegmentEncoding segment_encoding
+
+ cdef cppclass CHivePartitioningFactoryOptions \
+ "arrow::dataset::HivePartitioningFactoryOptions":
+ c_bool infer_dictionary
+ c_string null_fallback
+ shared_ptr[CSchema] schema
+ CSegmentEncoding segment_encoding
+
+ cdef cppclass CPartitioningFactory "arrow::dataset::PartitioningFactory":
+ c_string type_name() const
+
+ cdef cppclass CDirectoryPartitioning \
+ "arrow::dataset::DirectoryPartitioning"(CPartitioning):
+ CDirectoryPartitioning(shared_ptr[CSchema] schema,
+ vector[shared_ptr[CArray]] dictionaries)
+
+ @staticmethod
+ shared_ptr[CPartitioningFactory] MakeFactory(
+ vector[c_string] field_names, CPartitioningFactoryOptions)
+
+ vector[shared_ptr[CArray]] dictionaries() const
+
+ cdef cppclass CHivePartitioning \
+ "arrow::dataset::HivePartitioning"(CPartitioning):
+ CHivePartitioning(shared_ptr[CSchema] schema,
+ vector[shared_ptr[CArray]] dictionaries,
+ CHivePartitioningOptions options)
+
+ @staticmethod
+ shared_ptr[CPartitioningFactory] MakeFactory(
+ CHivePartitioningFactoryOptions)
+
+ vector[shared_ptr[CArray]] dictionaries() const
+
+ cdef cppclass CPartitioningOrFactory \
+ "arrow::dataset::PartitioningOrFactory":
+ CPartitioningOrFactory(shared_ptr[CPartitioning])
+ CPartitioningOrFactory(shared_ptr[CPartitioningFactory])
+ CPartitioningOrFactory & operator = (shared_ptr[CPartitioning])
+ CPartitioningOrFactory & operator = (
+ shared_ptr[CPartitioningFactory])
+ shared_ptr[CPartitioning] partitioning() const
+ shared_ptr[CPartitioningFactory] factory() const
+
+ cdef cppclass CFileSystemFactoryOptions \
+ "arrow::dataset::FileSystemFactoryOptions":
+ CPartitioningOrFactory partitioning
+ c_string partition_base_dir
+ c_bool exclude_invalid_files
+ vector[c_string] selector_ignore_prefixes
+
+ cdef cppclass CFileSystemDatasetFactory \
+ "arrow::dataset::FileSystemDatasetFactory"(
+ CDatasetFactory):
+ @staticmethod
+ CResult[shared_ptr[CDatasetFactory]] MakeFromPaths "Make"(
+ shared_ptr[CFileSystem] filesystem,
+ vector[c_string] paths,
+ shared_ptr[CFileFormat] format,
+ CFileSystemFactoryOptions options
+ )
+
+ @staticmethod
+ CResult[shared_ptr[CDatasetFactory]] MakeFromSelector "Make"(
+ shared_ptr[CFileSystem] filesystem,
+ CFileSelector,
+ shared_ptr[CFileFormat] format,
+ CFileSystemFactoryOptions options
+ )
+
+ cdef cppclass CParquetFactoryOptions \
+ "arrow::dataset::ParquetFactoryOptions":
+ CPartitioningOrFactory partitioning
+ c_string partition_base_dir
+ c_bool validate_column_chunk_paths
+
+ cdef cppclass CParquetDatasetFactory \
+ "arrow::dataset::ParquetDatasetFactory"(CDatasetFactory):
+ @staticmethod
+ CResult[shared_ptr[CDatasetFactory]] MakeFromMetaDataPath "Make"(
+ const c_string& metadata_path,
+ shared_ptr[CFileSystem] filesystem,
+ shared_ptr[CParquetFileFormat] format,
+ CParquetFactoryOptions options
+ )
+
+ @staticmethod
+ CResult[shared_ptr[CDatasetFactory]] MakeFromMetaDataSource "Make"(
+ const CFileSource& metadata_path,
+ const c_string& base_path,
+ shared_ptr[CFileSystem] filesystem,
+ shared_ptr[CParquetFileFormat] format,
+ CParquetFactoryOptions options
+ )
diff --git a/src/arrow/python/pyarrow/includes/libarrow_feather.pxd b/src/arrow/python/pyarrow/includes/libarrow_feather.pxd
new file mode 100644
index 000000000..ddfc8b2e5
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/libarrow_feather.pxd
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from pyarrow.includes.libarrow cimport (CCompressionType, CStatus, CTable,
+ COutputStream, CResult, shared_ptr,
+ vector, CRandomAccessFile, CSchema,
+ c_string)
+
+
+cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil:
+ int kFeatherV1Version" arrow::ipc::feather::kFeatherV1Version"
+ int kFeatherV2Version" arrow::ipc::feather::kFeatherV2Version"
+
+ cdef cppclass CFeatherProperties" arrow::ipc::feather::WriteProperties":
+ int version
+ int chunksize
+ CCompressionType compression
+ int compression_level
+
+ CStatus WriteFeather" arrow::ipc::feather::WriteTable" \
+ (const CTable& table, COutputStream* out,
+ CFeatherProperties properties)
+
+ cdef cppclass CFeatherReader" arrow::ipc::feather::Reader":
+ @staticmethod
+ CResult[shared_ptr[CFeatherReader]] Open(
+ const shared_ptr[CRandomAccessFile]& file)
+ int version()
+ shared_ptr[CSchema] schema()
+
+ CStatus Read(shared_ptr[CTable]* out)
+ CStatus Read(const vector[int] indices, shared_ptr[CTable]* out)
+ CStatus Read(const vector[c_string] names, shared_ptr[CTable]* out)
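+ # Comment note: _feather.pyx is expected to open a file with
+ # CFeatherReader.Open(), check version(), and then call one of the Read
+ # overloads, selecting columns either by index or by name.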
diff --git a/src/arrow/python/pyarrow/includes/libarrow_flight.pxd b/src/arrow/python/pyarrow/includes/libarrow_flight.pxd
new file mode 100644
index 000000000..2ac737aba
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/libarrow_flight.pxd
@@ -0,0 +1,560 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+
+
+cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
+ cdef char* CPyServerMiddlewareName\
+ " arrow::py::flight::kPyServerMiddlewareName"
+
+ cdef cppclass CActionType" arrow::flight::ActionType":
+ c_string type
+ c_string description
+
+ cdef cppclass CAction" arrow::flight::Action":
+ c_string type
+ shared_ptr[CBuffer] body
+
+ cdef cppclass CFlightResult" arrow::flight::Result":
+ CFlightResult()
+ CFlightResult(CFlightResult)
+ shared_ptr[CBuffer] body
+
+ cdef cppclass CBasicAuth" arrow::flight::BasicAuth":
+ CBasicAuth()
+ CBasicAuth(CBuffer)
+ CBasicAuth(CBasicAuth)
+ c_string username
+ c_string password
+
+ cdef cppclass CResultStream" arrow::flight::ResultStream":
+ CStatus Next(unique_ptr[CFlightResult]* result)
+
+ cdef cppclass CDescriptorType \
+ " arrow::flight::FlightDescriptor::DescriptorType":
+ bint operator==(CDescriptorType)
+
+ CDescriptorType CDescriptorTypeUnknown\
+ " arrow::flight::FlightDescriptor::UNKNOWN"
+ CDescriptorType CDescriptorTypePath\
+ " arrow::flight::FlightDescriptor::PATH"
+ CDescriptorType CDescriptorTypeCmd\
+ " arrow::flight::FlightDescriptor::CMD"
+
+ cdef cppclass CFlightDescriptor" arrow::flight::FlightDescriptor":
+ CDescriptorType type
+ c_string cmd
+ vector[c_string] path
+ CStatus SerializeToString(c_string* out)
+
+ @staticmethod
+ CStatus Deserialize(const c_string& serialized,
+ CFlightDescriptor* out)
+ bint operator==(CFlightDescriptor)
+
+ cdef cppclass CTicket" arrow::flight::Ticket":
+ CTicket()
+ c_string ticket
+ bint operator==(CTicket)
+ CStatus SerializeToString(c_string* out)
+
+ @staticmethod
+ CStatus Deserialize(const c_string& serialized, CTicket* out)
+
+ cdef cppclass CCriteria" arrow::flight::Criteria":
+ CCriteria()
+ c_string expression
+
+ cdef cppclass CLocation" arrow::flight::Location":
+ CLocation()
+ c_string ToString()
+ c_bool Equals(const CLocation& other)
+
+ @staticmethod
+ CStatus Parse(c_string& uri_string, CLocation* location)
+
+ @staticmethod
+ CStatus ForGrpcTcp(c_string& host, int port, CLocation* location)
+
+ @staticmethod
+ CStatus ForGrpcTls(c_string& host, int port, CLocation* location)
+
+ @staticmethod
+ CStatus ForGrpcUnix(c_string& path, CLocation* location)
+
+ cdef cppclass CFlightEndpoint" arrow::flight::FlightEndpoint":
+ CFlightEndpoint()
+
+ CTicket ticket
+ vector[CLocation] locations
+
+ bint operator==(CFlightEndpoint)
+
+ cdef cppclass CFlightInfo" arrow::flight::FlightInfo":
+ CFlightInfo(CFlightInfo info)
+ int64_t total_records()
+ int64_t total_bytes()
+ CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out)
+ CFlightDescriptor& descriptor()
+ const vector[CFlightEndpoint]& endpoints()
+ CStatus SerializeToString(c_string* out)
+
+ @staticmethod
+ CStatus Deserialize(const c_string& serialized,
+ unique_ptr[CFlightInfo]* out)
+
+ cdef cppclass CSchemaResult" arrow::flight::SchemaResult":
+ CSchemaResult(CSchemaResult result)
+ CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out)
+
+ cdef cppclass CFlightListing" arrow::flight::FlightListing":
+ CStatus Next(unique_ptr[CFlightInfo]* info)
+
+ cdef cppclass CSimpleFlightListing" arrow::flight::SimpleFlightListing":
+ CSimpleFlightListing(vector[CFlightInfo]&& info)
+
+ cdef cppclass CFlightPayload" arrow::flight::FlightPayload":
+ shared_ptr[CBuffer] descriptor
+ shared_ptr[CBuffer] app_metadata
+ CIpcPayload ipc_message
+
+ cdef cppclass CFlightDataStream" arrow::flight::FlightDataStream":
+ shared_ptr[CSchema] schema()
+ CStatus Next(CFlightPayload*)
+
+ cdef cppclass CFlightStreamChunk" arrow::flight::FlightStreamChunk":
+ CFlightStreamChunk()
+ shared_ptr[CRecordBatch] data
+ shared_ptr[CBuffer] app_metadata
+
+ cdef cppclass CMetadataRecordBatchReader \
+ " arrow::flight::MetadataRecordBatchReader":
+ CResult[shared_ptr[CSchema]] GetSchema()
+ CStatus Next(CFlightStreamChunk* out)
+ CStatus ReadAll(shared_ptr[CTable]* table)
+
+ CResult[shared_ptr[CRecordBatchReader]] MakeRecordBatchReader\
+ " arrow::flight::MakeRecordBatchReader"(
+ shared_ptr[CMetadataRecordBatchReader])
+
+ cdef cppclass CMetadataRecordBatchWriter \
+ " arrow::flight::MetadataRecordBatchWriter"(CRecordBatchWriter):
+ CStatus Begin(shared_ptr[CSchema] schema,
+ const CIpcWriteOptions& options)
+ CStatus WriteMetadata(shared_ptr[CBuffer] app_metadata)
+ CStatus WriteWithMetadata(const CRecordBatch& batch,
+ shared_ptr[CBuffer] app_metadata)
+
+ cdef cppclass CFlightStreamReader \
+ " arrow::flight::FlightStreamReader"(CMetadataRecordBatchReader):
+ void Cancel()
+ CStatus ReadAllWithStopToken" ReadAll"\
+ (shared_ptr[CTable]* table, const CStopToken& stop_token)
+
+ cdef cppclass CFlightMessageReader \
+ " arrow::flight::FlightMessageReader"(CMetadataRecordBatchReader):
+ CFlightDescriptor& descriptor()
+
+ cdef cppclass CFlightMessageWriter \
+ " arrow::flight::FlightMessageWriter"(CMetadataRecordBatchWriter):
+ pass
+
+ cdef cppclass CFlightStreamWriter \
+ " arrow::flight::FlightStreamWriter"(CMetadataRecordBatchWriter):
+ CStatus DoneWriting()
+
+ cdef cppclass CRecordBatchStream \
+ " arrow::flight::RecordBatchStream"(CFlightDataStream):
+ CRecordBatchStream(shared_ptr[CRecordBatchReader]& reader,
+ const CIpcWriteOptions& options)
+
+ cdef cppclass CFlightMetadataReader" arrow::flight::FlightMetadataReader":
+ CStatus ReadMetadata(shared_ptr[CBuffer]* out)
+
+ cdef cppclass CFlightMetadataWriter" arrow::flight::FlightMetadataWriter":
+ CStatus WriteMetadata(const CBuffer& message)
+
+ cdef cppclass CServerAuthReader" arrow::flight::ServerAuthReader":
+ CStatus Read(c_string* token)
+
+ cdef cppclass CServerAuthSender" arrow::flight::ServerAuthSender":
+ CStatus Write(c_string& token)
+
+ cdef cppclass CClientAuthReader" arrow::flight::ClientAuthReader":
+ CStatus Read(c_string* token)
+
+ cdef cppclass CClientAuthSender" arrow::flight::ClientAuthSender":
+ CStatus Write(c_string& token)
+
+ cdef cppclass CServerAuthHandler" arrow::flight::ServerAuthHandler":
+ pass
+
+ cdef cppclass CClientAuthHandler" arrow::flight::ClientAuthHandler":
+ pass
+
+ cdef cppclass CServerCallContext" arrow::flight::ServerCallContext":
+ c_string& peer_identity()
+ c_string& peer()
+ c_bool is_cancelled()
+ CServerMiddleware* GetMiddleware(const c_string& key)
+
+ cdef cppclass CTimeoutDuration" arrow::flight::TimeoutDuration":
+ CTimeoutDuration(double)
+
+ cdef cppclass CFlightCallOptions" arrow::flight::FlightCallOptions":
+ CFlightCallOptions()
+ CTimeoutDuration timeout
+ CIpcWriteOptions write_options
+ vector[pair[c_string, c_string]] headers
+ CStopToken stop_token
+
+ cdef cppclass CCertKeyPair" arrow::flight::CertKeyPair":
+ CCertKeyPair()
+ c_string pem_cert
+ c_string pem_key
+
+ cdef cppclass CFlightMethod" arrow::flight::FlightMethod":
+ bint operator==(CFlightMethod)
+
+ CFlightMethod CFlightMethodInvalid\
+ " arrow::flight::FlightMethod::Invalid"
+ CFlightMethod CFlightMethodHandshake\
+ " arrow::flight::FlightMethod::Handshake"
+ CFlightMethod CFlightMethodListFlights\
+ " arrow::flight::FlightMethod::ListFlights"
+ CFlightMethod CFlightMethodGetFlightInfo\
+ " arrow::flight::FlightMethod::GetFlightInfo"
+ CFlightMethod CFlightMethodGetSchema\
+ " arrow::flight::FlightMethod::GetSchema"
+ CFlightMethod CFlightMethodDoGet\
+ " arrow::flight::FlightMethod::DoGet"
+ CFlightMethod CFlightMethodDoPut\
+ " arrow::flight::FlightMethod::DoPut"
+ CFlightMethod CFlightMethodDoAction\
+ " arrow::flight::FlightMethod::DoAction"
+ CFlightMethod CFlightMethodListActions\
+ " arrow::flight::FlightMethod::ListActions"
+ CFlightMethod CFlightMethodDoExchange\
+ " arrow::flight::FlightMethod::DoExchange"
+
+ cdef cppclass CCallInfo" arrow::flight::CallInfo":
+ CFlightMethod method
+
+ # This is really std::unordered_multimap, but Cython has no
+ # bindings for it, so treat it as an opaque class and bind the
+ # methods we need
+ cdef cppclass CCallHeaders" arrow::flight::CallHeaders":
+ cppclass const_iterator:
+ pair[c_string, c_string] operator*()
+ const_iterator operator++()
+ bint operator==(const_iterator)
+ bint operator!=(const_iterator)
+ const_iterator cbegin()
+ const_iterator cend()
+
+ cdef cppclass CAddCallHeaders" arrow::flight::AddCallHeaders":
+ void AddHeader(const c_string& key, const c_string& value)
+
+ cdef cppclass CServerMiddleware" arrow::flight::ServerMiddleware":
+ c_string name()
+
+ cdef cppclass CServerMiddlewareFactory\
+ " arrow::flight::ServerMiddlewareFactory":
+ pass
+
+ cdef cppclass CClientMiddleware" arrow::flight::ClientMiddleware":
+ pass
+
+ cdef cppclass CClientMiddlewareFactory\
+ " arrow::flight::ClientMiddlewareFactory":
+ pass
+
+ cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions":
+ CFlightServerOptions(const CLocation& location)
+ CLocation location
+ unique_ptr[CServerAuthHandler] auth_handler
+ vector[CCertKeyPair] tls_certificates
+ c_bool verify_client
+ c_string root_certificates
+ vector[pair[c_string, shared_ptr[CServerMiddlewareFactory]]] middleware
+
+ cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions":
+ c_string tls_root_certs
+ c_string cert_chain
+ c_string private_key
+ c_string override_hostname
+ vector[shared_ptr[CClientMiddlewareFactory]] middleware
+ int64_t write_size_limit_bytes
+ vector[pair[c_string, CIntStringVariant]] generic_options
+ c_bool disable_server_verification
+
+ @staticmethod
+ CFlightClientOptions Defaults()
+
+ cdef cppclass CFlightClient" arrow::flight::FlightClient":
+ @staticmethod
+ CStatus Connect(const CLocation& location,
+ const CFlightClientOptions& options,
+ unique_ptr[CFlightClient]* client)
+
+ CStatus Authenticate(CFlightCallOptions& options,
+ unique_ptr[CClientAuthHandler] auth_handler)
+
+ CResult[pair[c_string, c_string]] AuthenticateBasicToken(
+ CFlightCallOptions& options,
+ const c_string& username,
+ const c_string& password)
+
+ CStatus DoAction(CFlightCallOptions& options, CAction& action,
+ unique_ptr[CResultStream]* results)
+ CStatus ListActions(CFlightCallOptions& options,
+ vector[CActionType]* actions)
+
+ CStatus ListFlights(CFlightCallOptions& options, CCriteria criteria,
+ unique_ptr[CFlightListing]* listing)
+ CStatus GetFlightInfo(CFlightCallOptions& options,
+ CFlightDescriptor& descriptor,
+ unique_ptr[CFlightInfo]* info)
+ CStatus GetSchema(CFlightCallOptions& options,
+ CFlightDescriptor& descriptor,
+ unique_ptr[CSchemaResult]* result)
+ CStatus DoGet(CFlightCallOptions& options, CTicket& ticket,
+ unique_ptr[CFlightStreamReader]* stream)
+ CStatus DoPut(CFlightCallOptions& options,
+ CFlightDescriptor& descriptor,
+ shared_ptr[CSchema]& schema,
+ unique_ptr[CFlightStreamWriter]* stream,
+ unique_ptr[CFlightMetadataReader]* reader)
+ CStatus DoExchange(CFlightCallOptions& options,
+ CFlightDescriptor& descriptor,
+ unique_ptr[CFlightStreamWriter]* writer,
+ unique_ptr[CFlightStreamReader]* reader)
+
+ cdef cppclass CFlightStatusCode" arrow::flight::FlightStatusCode":
+ bint operator==(CFlightStatusCode)
+
+ CFlightStatusCode CFlightStatusInternal \
+ " arrow::flight::FlightStatusCode::Internal"
+ CFlightStatusCode CFlightStatusTimedOut \
+ " arrow::flight::FlightStatusCode::TimedOut"
+ CFlightStatusCode CFlightStatusCancelled \
+ " arrow::flight::FlightStatusCode::Cancelled"
+ CFlightStatusCode CFlightStatusUnauthenticated \
+ " arrow::flight::FlightStatusCode::Unauthenticated"
+ CFlightStatusCode CFlightStatusUnauthorized \
+ " arrow::flight::FlightStatusCode::Unauthorized"
+ CFlightStatusCode CFlightStatusUnavailable \
+ " arrow::flight::FlightStatusCode::Unavailable"
+ CFlightStatusCode CFlightStatusFailed \
+ " arrow::flight::FlightStatusCode::Failed"
+
+ cdef cppclass FlightStatusDetail" arrow::flight::FlightStatusDetail":
+ CFlightStatusCode code()
+ c_string extra_info()
+
+ @staticmethod
+ shared_ptr[FlightStatusDetail] UnwrapStatus(const CStatus& status)
+
+ cdef cppclass FlightWriteSizeStatusDetail\
+ " arrow::flight::FlightWriteSizeStatusDetail":
+ int64_t limit()
+ int64_t actual()
+
+ @staticmethod
+ shared_ptr[FlightWriteSizeStatusDetail] UnwrapStatus(
+ const CStatus& status)
+
+ cdef CStatus MakeFlightError" arrow::flight::MakeFlightError" \
+ (CFlightStatusCode code, const c_string& message)
+
+ cdef CStatus MakeFlightError" arrow::flight::MakeFlightError" \
+ (CFlightStatusCode code,
+ const c_string& message,
+ const c_string& extra_info)
+
+# Callbacks for implementing Flight servers
+# Use typedef to emulate syntax for std::function<void(..)>
+ctypedef CStatus cb_list_flights(object, const CServerCallContext&,
+ const CCriteria*,
+ unique_ptr[CFlightListing]*)
+ctypedef CStatus cb_get_flight_info(object, const CServerCallContext&,
+ const CFlightDescriptor&,
+ unique_ptr[CFlightInfo]*)
+ctypedef CStatus cb_get_schema(object, const CServerCallContext&,
+ const CFlightDescriptor&,
+ unique_ptr[CSchemaResult]*)
+ctypedef CStatus cb_do_put(object, const CServerCallContext&,
+ unique_ptr[CFlightMessageReader],
+ unique_ptr[CFlightMetadataWriter])
+ctypedef CStatus cb_do_get(object, const CServerCallContext&,
+ const CTicket&,
+ unique_ptr[CFlightDataStream]*)
+ctypedef CStatus cb_do_exchange(object, const CServerCallContext&,
+ unique_ptr[CFlightMessageReader],
+ unique_ptr[CFlightMessageWriter])
+ctypedef CStatus cb_do_action(object, const CServerCallContext&,
+ const CAction&,
+ unique_ptr[CResultStream]*)
+ctypedef CStatus cb_list_actions(object, const CServerCallContext&,
+ vector[CActionType]*)
+ctypedef CStatus cb_result_next(object, unique_ptr[CFlightResult]*)
+ctypedef CStatus cb_data_stream_next(object, CFlightPayload*)
+ctypedef CStatus cb_server_authenticate(object, CServerAuthSender*,
+ CServerAuthReader*)
+ctypedef CStatus cb_is_valid(object, const c_string&, c_string*)
+ctypedef CStatus cb_client_authenticate(object, CClientAuthSender*,
+ CClientAuthReader*)
+ctypedef CStatus cb_get_token(object, c_string*)
+
+ctypedef CStatus cb_middleware_sending_headers(object, CAddCallHeaders*)
+ctypedef CStatus cb_middleware_call_completed(object, const CStatus&)
+ctypedef CStatus cb_client_middleware_received_headers(
+ object, const CCallHeaders&)
+ctypedef CStatus cb_server_middleware_start_call(
+ object,
+ const CCallInfo&,
+ const CCallHeaders&,
+ shared_ptr[CServerMiddleware]*)
+ctypedef CStatus cb_client_middleware_start_call(
+ object,
+ const CCallInfo&,
+ unique_ptr[CClientMiddleware]*)
+
+cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
+ cdef cppclass PyFlightServerVtable:
+ PyFlightServerVtable()
+ function[cb_list_flights] list_flights
+ function[cb_get_flight_info] get_flight_info
+ function[cb_get_schema] get_schema
+ function[cb_do_put] do_put
+ function[cb_do_get] do_get
+ function[cb_do_exchange] do_exchange
+ function[cb_do_action] do_action
+ function[cb_list_actions] list_actions
+
+ cdef cppclass PyServerAuthHandlerVtable:
+ PyServerAuthHandlerVtable()
+ function[cb_server_authenticate] authenticate
+ function[cb_is_valid] is_valid
+
+ cdef cppclass PyClientAuthHandlerVtable:
+ PyClientAuthHandlerVtable()
+ function[cb_client_authenticate] authenticate
+ function[cb_get_token] get_token
+
+ cdef cppclass PyFlightServer:
+ PyFlightServer(object server, PyFlightServerVtable vtable)
+
+ CStatus Init(CFlightServerOptions& options)
+ int port()
+ CStatus ServeWithSignals() except *
+ CStatus Shutdown()
+ CStatus Wait()
+
+ cdef cppclass PyServerAuthHandler\
+ " arrow::py::flight::PyServerAuthHandler"(CServerAuthHandler):
+ PyServerAuthHandler(object handler, PyServerAuthHandlerVtable vtable)
+
+ cdef cppclass PyClientAuthHandler\
+ " arrow::py::flight::PyClientAuthHandler"(CClientAuthHandler):
+ PyClientAuthHandler(object handler, PyClientAuthHandlerVtable vtable)
+
+ cdef cppclass CPyFlightResultStream\
+ " arrow::py::flight::PyFlightResultStream"(CResultStream):
+ CPyFlightResultStream(object generator,
+ function[cb_result_next] callback)
+
+ cdef cppclass CPyFlightDataStream\
+ " arrow::py::flight::PyFlightDataStream"(CFlightDataStream):
+ CPyFlightDataStream(object data_source,
+ unique_ptr[CFlightDataStream] stream)
+
+ cdef cppclass CPyGeneratorFlightDataStream\
+ " arrow::py::flight::PyGeneratorFlightDataStream"\
+ (CFlightDataStream):
+ CPyGeneratorFlightDataStream(object generator,
+ shared_ptr[CSchema] schema,
+ function[cb_data_stream_next] callback,
+ const CIpcWriteOptions& options)
+
+ cdef cppclass PyServerMiddlewareVtable\
+ " arrow::py::flight::PyServerMiddleware::Vtable":
+ PyServerMiddlewareVtable()
+ function[cb_middleware_sending_headers] sending_headers
+ function[cb_middleware_call_completed] call_completed
+
+ cdef cppclass PyClientMiddlewareVtable\
+ " arrow::py::flight::PyClientMiddleware::Vtable":
+ PyClientMiddlewareVtable()
+ function[cb_middleware_sending_headers] sending_headers
+ function[cb_client_middleware_received_headers] received_headers
+ function[cb_middleware_call_completed] call_completed
+
+ cdef cppclass CPyServerMiddleware\
+ " arrow::py::flight::PyServerMiddleware"(CServerMiddleware):
+ CPyServerMiddleware(object middleware, PyServerMiddlewareVtable vtable)
+ void* py_object()
+
+ cdef cppclass CPyServerMiddlewareFactory\
+ " arrow::py::flight::PyServerMiddlewareFactory"\
+ (CServerMiddlewareFactory):
+ CPyServerMiddlewareFactory(
+ object factory,
+ function[cb_server_middleware_start_call] start_call)
+
+ cdef cppclass CPyClientMiddleware\
+ " arrow::py::flight::PyClientMiddleware"(CClientMiddleware):
+ CPyClientMiddleware(object middleware, PyClientMiddlewareVtable vtable)
+
+ cdef cppclass CPyClientMiddlewareFactory\
+ " arrow::py::flight::PyClientMiddlewareFactory"\
+ (CClientMiddlewareFactory):
+ CPyClientMiddlewareFactory(
+ object factory,
+ function[cb_client_middleware_start_call] start_call)
+
+ cdef CStatus CreateFlightInfo" arrow::py::flight::CreateFlightInfo"(
+ shared_ptr[CSchema] schema,
+ CFlightDescriptor& descriptor,
+ vector[CFlightEndpoint] endpoints,
+ int64_t total_records,
+ int64_t total_bytes,
+ unique_ptr[CFlightInfo]* out)
+
+ cdef CStatus CreateSchemaResult" arrow::py::flight::CreateSchemaResult"(
+ shared_ptr[CSchema] schema,
+ unique_ptr[CSchemaResult]* out)
+
+ cdef CStatus DeserializeBasicAuth\
+ " arrow::py::flight::DeserializeBasicAuth"(
+ c_string buf,
+ unique_ptr[CBasicAuth]* out)
+
+ cdef CStatus SerializeBasicAuth" arrow::py::flight::SerializeBasicAuth"(
+ CBasicAuth basic_auth,
+ c_string* out)
+
+
+cdef extern from "arrow/util/variant.h" namespace "arrow" nogil:
+ cdef cppclass CIntStringVariant" arrow::util::Variant<int, std::string>":
+ CIntStringVariant()
+ CIntStringVariant(int)
+ CIntStringVariant(c_string)
diff --git a/src/arrow/python/pyarrow/includes/libarrow_fs.pxd b/src/arrow/python/pyarrow/includes/libarrow_fs.pxd
new file mode 100644
index 000000000..9dca5fbf6
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/libarrow_fs.pxd
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+
+cdef extern from "arrow/filesystem/api.h" namespace "arrow::fs" nogil:
+
+ ctypedef enum CFileType "arrow::fs::FileType":
+ CFileType_NotFound "arrow::fs::FileType::NotFound"
+ CFileType_Unknown "arrow::fs::FileType::Unknown"
+ CFileType_File "arrow::fs::FileType::File"
+ CFileType_Directory "arrow::fs::FileType::Directory"
+
+ cdef cppclass CFileInfo "arrow::fs::FileInfo":
+ CFileInfo()
+ CFileInfo(CFileInfo&&)
+ CFileInfo& operator=(CFileInfo&&)
+ CFileInfo(const CFileInfo&)
+ CFileInfo& operator=(const CFileInfo&)
+
+ CFileType type()
+ void set_type(CFileType type)
+ c_string path()
+ void set_path(const c_string& path)
+ c_string base_name()
+ int64_t size()
+ void set_size(int64_t size)
+ c_string extension()
+ CTimePoint mtime()
+ void set_mtime(CTimePoint mtime)
+
+ cdef cppclass CFileSelector "arrow::fs::FileSelector":
+ CFileSelector()
+ c_string base_dir
+ c_bool allow_not_found
+ c_bool recursive
+
+ cdef cppclass CFileLocator "arrow::fs::FileLocator":
+ shared_ptr[CFileSystem] filesystem
+ c_string path
+
+ cdef cppclass CFileSystem "arrow::fs::FileSystem":
+ shared_ptr[CFileSystem] shared_from_this()
+ c_string type_name() const
+ CResult[c_string] NormalizePath(c_string path)
+ CResult[CFileInfo] GetFileInfo(const c_string& path)
+ CResult[vector[CFileInfo]] GetFileInfo(
+ const vector[c_string]& paths)
+ CResult[vector[CFileInfo]] GetFileInfo(const CFileSelector& select)
+ CStatus CreateDir(const c_string& path, c_bool recursive)
+ CStatus DeleteDir(const c_string& path)
+ CStatus DeleteDirContents(const c_string& path)
+ CStatus DeleteRootDirContents()
+ CStatus DeleteFile(const c_string& path)
+ CStatus DeleteFiles(const vector[c_string]& paths)
+ CStatus Move(const c_string& src, const c_string& dest)
+ CStatus CopyFile(const c_string& src, const c_string& dest)
+ CResult[shared_ptr[CInputStream]] OpenInputStream(
+ const c_string& path)
+ CResult[shared_ptr[CRandomAccessFile]] OpenInputFile(
+ const c_string& path)
+ CResult[shared_ptr[COutputStream]] OpenOutputStream(
+ const c_string& path, const shared_ptr[const CKeyValueMetadata]&)
+ CResult[shared_ptr[COutputStream]] OpenAppendStream(
+ const c_string& path, const shared_ptr[const CKeyValueMetadata]&)
+ c_bool Equals(const CFileSystem& other)
+ c_bool Equals(shared_ptr[CFileSystem] other)
+
+ CResult[shared_ptr[CFileSystem]] CFileSystemFromUri \
+ "arrow::fs::FileSystemFromUri"(const c_string& uri, c_string* out_path)
+ CResult[shared_ptr[CFileSystem]] CFileSystemFromUriOrPath \
+ "arrow::fs::FileSystemFromUriOrPath"(const c_string& uri,
+ c_string* out_path)
+
+ cdef cppclass CFileSystemGlobalOptions \
+ "arrow::fs::FileSystemGlobalOptions":
+ c_string tls_ca_file_path
+ c_string tls_ca_dir_path
+
+ CStatus CFileSystemsInitialize "arrow::fs::Initialize" \
+ (const CFileSystemGlobalOptions& options)
+
+ cdef cppclass CLocalFileSystemOptions "arrow::fs::LocalFileSystemOptions":
+ c_bool use_mmap
+
+ @staticmethod
+ CLocalFileSystemOptions Defaults()
+
+ c_bool Equals(const CLocalFileSystemOptions& other)
+
+ cdef cppclass CLocalFileSystem "arrow::fs::LocalFileSystem"(CFileSystem):
+ CLocalFileSystem()
+ CLocalFileSystem(CLocalFileSystemOptions)
+ CLocalFileSystemOptions options()
+
+ cdef cppclass CSubTreeFileSystem \
+ "arrow::fs::SubTreeFileSystem"(CFileSystem):
+ CSubTreeFileSystem(const c_string& base_path,
+ shared_ptr[CFileSystem] base_fs)
+ c_string base_path()
+ shared_ptr[CFileSystem] base_fs()
+
+ ctypedef enum CS3LogLevel "arrow::fs::S3LogLevel":
+ CS3LogLevel_Off "arrow::fs::S3LogLevel::Off"
+ CS3LogLevel_Fatal "arrow::fs::S3LogLevel::Fatal"
+ CS3LogLevel_Error "arrow::fs::S3LogLevel::Error"
+ CS3LogLevel_Warn "arrow::fs::S3LogLevel::Warn"
+ CS3LogLevel_Info "arrow::fs::S3LogLevel::Info"
+ CS3LogLevel_Debug "arrow::fs::S3LogLevel::Debug"
+ CS3LogLevel_Trace "arrow::fs::S3LogLevel::Trace"
+
+ cdef struct CS3GlobalOptions "arrow::fs::S3GlobalOptions":
+ CS3LogLevel log_level
+
+ cdef cppclass CS3ProxyOptions "arrow::fs::S3ProxyOptions":
+ c_string scheme
+ c_string host
+ int port
+ c_string username
+ c_string password
+ c_bool Equals(const CS3ProxyOptions& other)
+
+ @staticmethod
+ CResult[CS3ProxyOptions] FromUriString "FromUri"(
+ const c_string& uri_string)
+
+ ctypedef enum CS3CredentialsKind "arrow::fs::S3CredentialsKind":
+ CS3CredentialsKind_Anonymous "arrow::fs::S3CredentialsKind::Anonymous"
+ CS3CredentialsKind_Default "arrow::fs::S3CredentialsKind::Default"
+ CS3CredentialsKind_Explicit "arrow::fs::S3CredentialsKind::Explicit"
+ CS3CredentialsKind_Role "arrow::fs::S3CredentialsKind::Role"
+ CS3CredentialsKind_WebIdentity \
+ "arrow::fs::S3CredentialsKind::WebIdentity"
+
+ cdef cppclass CS3Options "arrow::fs::S3Options":
+ c_string region
+ c_string endpoint_override
+ c_string scheme
+ c_bool background_writes
+ shared_ptr[const CKeyValueMetadata] default_metadata
+ c_string role_arn
+ c_string session_name
+ c_string external_id
+ int load_frequency
+ CS3ProxyOptions proxy_options
+ CS3CredentialsKind credentials_kind
+ void ConfigureDefaultCredentials()
+ void ConfigureAccessKey(const c_string& access_key,
+ const c_string& secret_key,
+ const c_string& session_token)
+ c_string GetAccessKey()
+ c_string GetSecretKey()
+ c_string GetSessionToken()
+ c_bool Equals(const CS3Options& other)
+
+ @staticmethod
+ CS3Options Defaults()
+
+ @staticmethod
+ CS3Options Anonymous()
+
+ @staticmethod
+ CS3Options FromAccessKey(const c_string& access_key,
+ const c_string& secret_key,
+ const c_string& session_token)
+
+ @staticmethod
+ CS3Options FromAssumeRole(const c_string& role_arn,
+ const c_string& session_name,
+ const c_string& external_id,
+ const int load_frequency)
+
+ cdef cppclass CS3FileSystem "arrow::fs::S3FileSystem"(CFileSystem):
+ @staticmethod
+ CResult[shared_ptr[CS3FileSystem]] Make(const CS3Options& options)
+ CS3Options options()
+ c_string region()
+
+ cdef CStatus CInitializeS3 "arrow::fs::InitializeS3"(
+ const CS3GlobalOptions& options)
+ cdef CStatus CFinalizeS3 "arrow::fs::FinalizeS3"()
+
+ cdef cppclass CHdfsOptions "arrow::fs::HdfsOptions":
+ HdfsConnectionConfig connection_config
+ int32_t buffer_size
+ int16_t replication
+ int64_t default_block_size
+
+ @staticmethod
+ CResult[CHdfsOptions] FromUriString "FromUri"(
+ const c_string& uri_string)
+ void ConfigureEndPoint(c_string host, int port)
+ void ConfigureDriver(c_bool use_hdfs3)
+ void ConfigureReplication(int16_t replication)
+ void ConfigureUser(c_string user_name)
+ void ConfigureBufferSize(int32_t buffer_size)
+ void ConfigureBlockSize(int64_t default_block_size)
+ void ConfigureKerberosTicketCachePath(c_string path)
+ void ConfigureExtraConf(c_string key, c_string value)
+
+ cdef cppclass CHadoopFileSystem "arrow::fs::HadoopFileSystem"(CFileSystem):
+ @staticmethod
+ CResult[shared_ptr[CHadoopFileSystem]] Make(
+ const CHdfsOptions& options)
+ CHdfsOptions options()
+
+ cdef cppclass CMockFileSystem "arrow::fs::internal::MockFileSystem"(
+ CFileSystem):
+ CMockFileSystem(CTimePoint current_time)
+
+ CStatus CCopyFiles "arrow::fs::CopyFiles"(
+ const vector[CFileLocator]& sources,
+ const vector[CFileLocator]& destinations,
+ const CIOContext& io_context,
+ int64_t chunk_size, c_bool use_threads)
+ CStatus CCopyFilesWithSelector "arrow::fs::CopyFiles"(
+ const shared_ptr[CFileSystem]& source_fs,
+ const CFileSelector& source_sel,
+ const shared_ptr[CFileSystem]& destination_fs,
+ const c_string& destination_base_dir,
+ const CIOContext& io_context,
+ int64_t chunk_size, c_bool use_threads)
+
+
+# Callbacks for implementing Python filesystems
+# Use typedef to emulate syntax for std::function<void(..)>
+ctypedef void CallbackGetTypeName(object, c_string*)
+ctypedef c_bool CallbackEquals(object, const CFileSystem&)
+
+ctypedef void CallbackGetFileInfo(object, const c_string&, CFileInfo*)
+ctypedef void CallbackGetFileInfoVector(object, const vector[c_string]&,
+ vector[CFileInfo]*)
+ctypedef void CallbackGetFileInfoSelector(object, const CFileSelector&,
+ vector[CFileInfo]*)
+ctypedef void CallbackCreateDir(object, const c_string&, c_bool)
+ctypedef void CallbackDeleteDir(object, const c_string&)
+ctypedef void CallbackDeleteDirContents(object, const c_string&)
+ctypedef void CallbackDeleteRootDirContents(object)
+ctypedef void CallbackDeleteFile(object, const c_string&)
+ctypedef void CallbackMove(object, const c_string&, const c_string&)
+ctypedef void CallbackCopyFile(object, const c_string&, const c_string&)
+
+ctypedef void CallbackOpenInputStream(object, const c_string&,
+ shared_ptr[CInputStream]*)
+ctypedef void CallbackOpenInputFile(object, const c_string&,
+ shared_ptr[CRandomAccessFile]*)
+ctypedef void CallbackOpenOutputStream(
+ object, const c_string&, const shared_ptr[const CKeyValueMetadata]&,
+ shared_ptr[COutputStream]*)
+ctypedef void CallbackNormalizePath(object, const c_string&, c_string*)
+
+cdef extern from "arrow/python/filesystem.h" namespace "arrow::py::fs" nogil:
+
+ cdef cppclass CPyFileSystemVtable "arrow::py::fs::PyFileSystemVtable":
+ PyFileSystemVtable()
+ function[CallbackGetTypeName] get_type_name
+ function[CallbackEquals] equals
+ function[CallbackGetFileInfo] get_file_info
+ function[CallbackGetFileInfoVector] get_file_info_vector
+ function[CallbackGetFileInfoSelector] get_file_info_selector
+ function[CallbackCreateDir] create_dir
+ function[CallbackDeleteDir] delete_dir
+ function[CallbackDeleteDirContents] delete_dir_contents
+ function[CallbackDeleteRootDirContents] delete_root_dir_contents
+ function[CallbackDeleteFile] delete_file
+ function[CallbackMove] move
+ function[CallbackCopyFile] copy_file
+ function[CallbackOpenInputStream] open_input_stream
+ function[CallbackOpenInputFile] open_input_file
+ function[CallbackOpenOutputStream] open_output_stream
+ function[CallbackOpenOutputStream] open_append_stream
+ function[CallbackNormalizePath] normalize_path
+
+ cdef cppclass CPyFileSystem "arrow::py::fs::PyFileSystem":
+ @staticmethod
+ shared_ptr[CPyFileSystem] Make(object handler,
+ CPyFileSystemVtable vtable)
+
+ PyObject* handler()
diff --git a/src/arrow/python/pyarrow/includes/libgandiva.pxd b/src/arrow/python/pyarrow/includes/libgandiva.pxd
new file mode 100644
index 000000000..c75977d37
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/libgandiva.pxd
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from libcpp.string cimport string as c_string
+from libcpp.unordered_set cimport unordered_set as c_unordered_set
+from libc.stdint cimport int64_t, int32_t, uint8_t, uintptr_t
+
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+
+cdef extern from "gandiva/node.h" namespace "gandiva" nogil:
+
+ cdef cppclass CNode" gandiva::Node":
+ c_string ToString()
+ shared_ptr[CDataType] return_type()
+
+ cdef cppclass CExpression" gandiva::Expression":
+ c_string ToString()
+ shared_ptr[CNode] root()
+ shared_ptr[CField] result()
+
+ ctypedef vector[shared_ptr[CNode]] CNodeVector" gandiva::NodeVector"
+
+ ctypedef vector[shared_ptr[CExpression]] \
+ CExpressionVector" gandiva::ExpressionVector"
+
+cdef extern from "gandiva/selection_vector.h" namespace "gandiva" nogil:
+
+ cdef cppclass CSelectionVector" gandiva::SelectionVector":
+
+ shared_ptr[CArray] ToArray()
+
+ enum CSelectionVector_Mode" gandiva::SelectionVector::Mode":
+ CSelectionVector_Mode_NONE" gandiva::SelectionVector::Mode::MODE_NONE"
+ CSelectionVector_Mode_UINT16" \
+ gandiva::SelectionVector::Mode::MODE_UINT16"
+ CSelectionVector_Mode_UINT32" \
+ gandiva::SelectionVector::Mode::MODE_UINT32"
+ CSelectionVector_Mode_UINT64" \
+ gandiva::SelectionVector::Mode::MODE_UINT64"
+
+ cdef CStatus SelectionVector_MakeInt16\
+ "gandiva::SelectionVector::MakeInt16"(
+ int64_t max_slots, CMemoryPool* pool,
+ shared_ptr[CSelectionVector]* selection_vector)
+
+ cdef CStatus SelectionVector_MakeInt32\
+ "gandiva::SelectionVector::MakeInt32"(
+ int64_t max_slots, CMemoryPool* pool,
+ shared_ptr[CSelectionVector]* selection_vector)
+
+ cdef CStatus SelectionVector_MakeInt64\
+ "gandiva::SelectionVector::MakeInt64"(
+ int64_t max_slots, CMemoryPool* pool,
+ shared_ptr[CSelectionVector]* selection_vector)
+
+cdef inline CSelectionVector_Mode _ensure_selection_mode(str name) except *:
+ uppercase = name.upper()
+ if uppercase == 'NONE':
+ return CSelectionVector_Mode_NONE
+ elif uppercase == 'UINT16':
+ return CSelectionVector_Mode_UINT16
+ elif uppercase == 'UINT32':
+ return CSelectionVector_Mode_UINT32
+ elif uppercase == 'UINT64':
+ return CSelectionVector_Mode_UINT64
+ else:
+ raise ValueError('Invalid value for Selection Mode: {!r}'.format(name))
+
+cdef inline str _selection_mode_name(CSelectionVector_Mode ctype):
+ if ctype == CSelectionVector_Mode_NONE:
+ return 'NONE'
+ elif ctype == CSelectionVector_Mode_UINT16:
+ return 'UINT16'
+ elif ctype == CSelectionVector_Mode_UINT32:
+ return 'UINT32'
+ elif ctype == CSelectionVector_Mode_UINT64:
+ return 'UINT64'
+ else:
+ raise RuntimeError('Unexpected CSelectionVector_Mode value')
+
+cdef extern from "gandiva/condition.h" namespace "gandiva" nogil:
+
+ cdef cppclass CCondition" gandiva::Condition":
+ c_string ToString()
+ shared_ptr[CNode] root()
+ shared_ptr[CField] result()
+
+cdef extern from "gandiva/arrow.h" namespace "gandiva" nogil:
+
+ ctypedef vector[shared_ptr[CArray]] CArrayVector" gandiva::ArrayVector"
+
+
+cdef extern from "gandiva/tree_expr_builder.h" namespace "gandiva" nogil:
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeBoolLiteral \
+ "gandiva::TreeExprBuilder::MakeLiteral"(c_bool value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt8Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(uint8_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt16Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(uint16_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt32Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(uint32_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt64Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(uint64_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInt8Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(int8_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInt16Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(int16_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInt32Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(int32_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInt64Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(int64_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeFloatLiteral \
+ "gandiva::TreeExprBuilder::MakeLiteral"(float value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeDoubleLiteral \
+ "gandiva::TreeExprBuilder::MakeLiteral"(double value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeStringLiteral \
+ "gandiva::TreeExprBuilder::MakeStringLiteral"(const c_string& value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeBinaryLiteral \
+ "gandiva::TreeExprBuilder::MakeBinaryLiteral"(const c_string& value)
+
+ cdef shared_ptr[CExpression] TreeExprBuilder_MakeExpression\
+ "gandiva::TreeExprBuilder::MakeExpression"(
+ shared_ptr[CNode] root_node, shared_ptr[CField] result_field)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeFunction \
+ "gandiva::TreeExprBuilder::MakeFunction"(
+ const c_string& name, const CNodeVector& children,
+ shared_ptr[CDataType] return_type)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeField \
+ "gandiva::TreeExprBuilder::MakeField"(shared_ptr[CField] field)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeIf \
+ "gandiva::TreeExprBuilder::MakeIf"(
+ shared_ptr[CNode] condition, shared_ptr[CNode] this_node,
+ shared_ptr[CNode] else_node, shared_ptr[CDataType] return_type)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeAnd \
+ "gandiva::TreeExprBuilder::MakeAnd"(const CNodeVector& children)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeOr \
+ "gandiva::TreeExprBuilder::MakeOr"(const CNodeVector& children)
+
+ cdef shared_ptr[CCondition] TreeExprBuilder_MakeCondition \
+ "gandiva::TreeExprBuilder::MakeCondition"(
+ shared_ptr[CNode] condition)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionInt32 \
+ "gandiva::TreeExprBuilder::MakeInExpressionInt32"(
+ shared_ptr[CNode] node, const c_unordered_set[int32_t]& values)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionInt64 \
+ "gandiva::TreeExprBuilder::MakeInExpressionInt64"(
+ shared_ptr[CNode] node, const c_unordered_set[int64_t]& values)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionTime32 \
+ "gandiva::TreeExprBuilder::MakeInExpressionTime32"(
+ shared_ptr[CNode] node, const c_unordered_set[int32_t]& values)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionTime64 \
+ "gandiva::TreeExprBuilder::MakeInExpressionTime64"(
+ shared_ptr[CNode] node, const c_unordered_set[int64_t]& values)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionDate32 \
+ "gandiva::TreeExprBuilder::MakeInExpressionDate32"(
+ shared_ptr[CNode] node, const c_unordered_set[int32_t]& values)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionDate64 \
+ "gandiva::TreeExprBuilder::MakeInExpressionDate64"(
+ shared_ptr[CNode] node, const c_unordered_set[int64_t]& values)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionTimeStamp \
+ "gandiva::TreeExprBuilder::MakeInExpressionTimeStamp"(
+ shared_ptr[CNode] node, const c_unordered_set[int64_t]& values)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionString \
+ "gandiva::TreeExprBuilder::MakeInExpressionString"(
+ shared_ptr[CNode] node, const c_unordered_set[c_string]& values)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionBinary \
+ "gandiva::TreeExprBuilder::MakeInExpressionBinary"(
+ shared_ptr[CNode] node, const c_unordered_set[c_string]& values)
+
+cdef extern from "gandiva/projector.h" namespace "gandiva" nogil:
+
+ cdef cppclass CProjector" gandiva::Projector":
+
+ CStatus Evaluate(
+ const CRecordBatch& batch, CMemoryPool* pool,
+ const CArrayVector* output)
+
+ CStatus Evaluate(
+ const CRecordBatch& batch,
+ const CSelectionVector* selection,
+ CMemoryPool* pool,
+ const CArrayVector* output)
+
+ c_string DumpIR()
+
+ cdef CStatus Projector_Make \
+ "gandiva::Projector::Make"(
+ shared_ptr[CSchema] schema, const CExpressionVector& children,
+ shared_ptr[CProjector]* projector)
+
+ cdef CStatus Projector_Make \
+ "gandiva::Projector::Make"(
+ shared_ptr[CSchema] schema, const CExpressionVector& children,
+ CSelectionVector_Mode mode,
+ shared_ptr[CConfiguration] configuration,
+ shared_ptr[CProjector]* projector)
+
+cdef extern from "gandiva/filter.h" namespace "gandiva" nogil:
+
+ cdef cppclass CFilter" gandiva::Filter":
+
+ CStatus Evaluate(
+ const CRecordBatch& batch,
+ shared_ptr[CSelectionVector] out_selection)
+
+ c_string DumpIR()
+
+ cdef CStatus Filter_Make \
+ "gandiva::Filter::Make"(
+ shared_ptr[CSchema] schema, shared_ptr[CCondition] condition,
+ shared_ptr[CFilter]* filter)
+
+cdef extern from "gandiva/function_signature.h" namespace "gandiva" nogil:
+
+ cdef cppclass CFunctionSignature" gandiva::FunctionSignature":
+
+ CFunctionSignature(const c_string& base_name,
+ vector[shared_ptr[CDataType]] param_types,
+ shared_ptr[CDataType] ret_type)
+
+ shared_ptr[CDataType] ret_type() const
+
+ const c_string& base_name() const
+
+ vector[shared_ptr[CDataType]] param_types() const
+
+ c_string ToString() const
+
+cdef extern from "gandiva/expression_registry.h" namespace "gandiva" nogil:
+
+ cdef vector[shared_ptr[CFunctionSignature]] \
+ GetRegisteredFunctionSignatures()
+
+cdef extern from "gandiva/configuration.h" namespace "gandiva" nogil:
+
+ cdef cppclass CConfiguration" gandiva::Configuration":
+ pass
+
+ cdef cppclass CConfigurationBuilder \
+ " gandiva::ConfigurationBuilder":
+ @staticmethod
+ shared_ptr[CConfiguration] DefaultConfiguration()
diff --git a/src/arrow/python/pyarrow/includes/libplasma.pxd b/src/arrow/python/pyarrow/includes/libplasma.pxd
new file mode 100644
index 000000000..d54e9f484
--- /dev/null
+++ b/src/arrow/python/pyarrow/includes/libplasma.pxd
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language = c++
+
+from pyarrow.includes.common cimport *
+
+cdef extern from "plasma/common.h" namespace "plasma" nogil:
+ cdef c_bool IsPlasmaObjectExists(const CStatus& status)
+ cdef c_bool IsPlasmaObjectNotFound(const CStatus& status)
+ cdef c_bool IsPlasmaStoreFull(const CStatus& status)
diff --git a/src/arrow/python/pyarrow/io.pxi b/src/arrow/python/pyarrow/io.pxi
new file mode 100644
index 000000000..f6c2b4219
--- /dev/null
+++ b/src/arrow/python/pyarrow/io.pxi
@@ -0,0 +1,2137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Cython wrappers for IO interfaces defined in arrow::io and messaging in
+# arrow::ipc
+
+from libc.stdlib cimport malloc, free
+
+import codecs
+import re
+import sys
+import threading
+import time
+import warnings
+from io import BufferedIOBase, IOBase, TextIOBase, UnsupportedOperation
+from queue import Queue, Empty as QueueEmpty
+
+from pyarrow.util import _is_path_like, _stringify_path
+
+
+# 64K
+DEFAULT_BUFFER_SIZE = 2 ** 16
+
+
+# To let us get a PyObject* and avoid Cython auto-ref-counting
+cdef extern from "Python.h":
+ PyObject* PyBytes_FromStringAndSizeNative" PyBytes_FromStringAndSize"(
+ char *v, Py_ssize_t len) except NULL
+
+
+def io_thread_count():
+ """
+ Return the number of threads to use for I/O operations.
+
+ Many operations, such as scanning a dataset, will implicitly make
+ use of this pool. The number of threads is set to a fixed value at
+ startup. It can be modified at runtime by calling
+ :func:`set_io_thread_count()`.
+
+ See Also
+ --------
+ set_io_thread_count : Modify the size of this pool.
+ cpu_count : The analogous function for the CPU thread pool.
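+
+    Examples
+    --------
+    A minimal illustration; the reported value depends on the environment.
+
+    >>> import pyarrow as pa
+    >>> pa.io_thread_count()  # doctest: +SKIP
+    8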
+ """
+ return GetIOThreadPoolCapacity()
+
+
+def set_io_thread_count(int count):
+ """
+ Set the number of threads to use for I/O operations.
+
+ Many operations, such as scanning a dataset, will implicitly make
+ use of this pool.
+
+ Parameters
+ ----------
+ count : int
+ The max number of threads that may be used for I/O.
+ Must be positive.
+
+ See Also
+ --------
+ io_thread_count : Get the size of this pool.
+ set_cpu_count : The analogous function for the CPU thread pool.
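+
+    Examples
+    --------
+    A minimal illustration; the count of 4 is an arbitrary example value.
+
+    >>> import pyarrow as pa
+    >>> pa.set_io_thread_count(4)  # doctest: +SKIP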
+ """
+ if count < 1:
+ raise ValueError("IO thread count must be strictly positive")
+ check_status(SetIOThreadPoolCapacity(count))
+
+
+cdef class NativeFile(_Weakrefable):
+ """
+ The base class for all Arrow streams.
+
+ Streams are either readable, writable, or both.
+ They optionally support seeking.
+
+ While this class exposes methods to read or write data from Python, the
+    primary intent of using an Arrow stream is to pass it to other Arrow
+ facilities that will make use of it, such as Arrow IPC routines.
+
+ Be aware that there are subtle differences with regular Python files,
+ e.g. destroying a writable Arrow stream without closing it explicitly
+ will not flush any pending data.
+ """
+
+ def __cinit__(self):
+ self.own_file = False
+ self.is_readable = False
+ self.is_writable = False
+ self.is_seekable = False
+
+ def __dealloc__(self):
+ if self.own_file:
+ self.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ self.close()
+
+ @property
+ def mode(self):
+ """
+ The file mode. Currently instances of NativeFile may support:
+
+ * rb: binary read
+ * wb: binary write
+ * rb+: binary read and write
+ """
+ # Emulate built-in file modes
+ if self.is_readable and self.is_writable:
+ return 'rb+'
+ elif self.is_readable:
+ return 'rb'
+ elif self.is_writable:
+ return 'wb'
+ else:
+ raise ValueError('File object is malformed, has no mode')
+
+ def readable(self):
+ self._assert_open()
+ return self.is_readable
+
+ def writable(self):
+ self._assert_open()
+ return self.is_writable
+
+ def seekable(self):
+ self._assert_open()
+ return self.is_seekable
+
+ def isatty(self):
+ self._assert_open()
+ return False
+
+ def fileno(self):
+ """
+ NOT IMPLEMENTED
+ """
+ raise UnsupportedOperation()
+
+ @property
+ def closed(self):
+ if self.is_readable:
+ return self.input_stream.get().closed()
+ elif self.is_writable:
+ return self.output_stream.get().closed()
+ else:
+ return True
+
+ def close(self):
+ if not self.closed:
+ with nogil:
+ if self.is_readable:
+ check_status(self.input_stream.get().Close())
+ else:
+ check_status(self.output_stream.get().Close())
+
+ cdef set_random_access_file(self, shared_ptr[CRandomAccessFile] handle):
+ self.input_stream = <shared_ptr[CInputStream]> handle
+ self.random_access = handle
+ self.is_seekable = True
+
+ cdef set_input_stream(self, shared_ptr[CInputStream] handle):
+ self.input_stream = handle
+ self.random_access.reset()
+ self.is_seekable = False
+
+ cdef set_output_stream(self, shared_ptr[COutputStream] handle):
+ self.output_stream = handle
+
+ cdef shared_ptr[CRandomAccessFile] get_random_access_file(self) except *:
+ self._assert_readable()
+ self._assert_seekable()
+ return self.random_access
+
+ cdef shared_ptr[CInputStream] get_input_stream(self) except *:
+ self._assert_readable()
+ return self.input_stream
+
+ cdef shared_ptr[COutputStream] get_output_stream(self) except *:
+ self._assert_writable()
+ return self.output_stream
+
+ def _assert_open(self):
+ if self.closed:
+ raise ValueError("I/O operation on closed file")
+
+ def _assert_readable(self):
+ self._assert_open()
+ if not self.is_readable:
+ # XXX UnsupportedOperation
+ raise IOError("only valid on readable files")
+
+ def _assert_writable(self):
+ self._assert_open()
+ if not self.is_writable:
+ raise IOError("only valid on writable files")
+
+ def _assert_seekable(self):
+ self._assert_open()
+ if not self.is_seekable:
+ raise IOError("only valid on seekable files")
+
+ def size(self):
+ """
+ Return file size
+ """
+ cdef int64_t size
+
+ handle = self.get_random_access_file()
+ with nogil:
+ size = GetResultValue(handle.get().GetSize())
+
+ return size
+
+ def metadata(self):
+ """
+ Return file metadata
+ """
+ cdef:
+ shared_ptr[const CKeyValueMetadata] c_metadata
+
+ handle = self.get_input_stream()
+ with nogil:
+ c_metadata = GetResultValue(handle.get().ReadMetadata())
+
+ metadata = {}
+ if c_metadata.get() != nullptr:
+ for i in range(c_metadata.get().size()):
+ metadata[frombytes(c_metadata.get().key(i))] = \
+ c_metadata.get().value(i)
+ return metadata
+
+ def tell(self):
+ """
+ Return current stream position
+ """
+ cdef int64_t position
+
+ if self.is_readable:
+ rd_handle = self.get_random_access_file()
+ with nogil:
+ position = GetResultValue(rd_handle.get().Tell())
+ else:
+ wr_handle = self.get_output_stream()
+ with nogil:
+ position = GetResultValue(wr_handle.get().Tell())
+
+ return position
+
+ def seek(self, int64_t position, int whence=0):
+ """
+ Change current file stream position
+
+ Parameters
+ ----------
+ position : int
+            Byte offset, interpreted relative to the value of the whence
+            argument
+ whence : int, default 0
+ Point of reference for seek offset
+
+ Notes
+ -----
+ Values of whence:
+ * 0 -- start of stream (the default); offset should be zero or positive
+ * 1 -- current stream position; offset may be negative
+ * 2 -- end of stream; offset is usually negative
+
+ Returns
+ -------
+ new_position : the new absolute stream position
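+
+        Examples
+        --------
+        An illustrative sketch using an in-memory reader:
+
+        >>> import pyarrow as pa
+        >>> f = pa.BufferReader(b'somedata')
+        >>> f.seek(4)
+        4
+        >>> f.seek(-2, 2)
+        6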
+ """
+ cdef int64_t offset
+ handle = self.get_random_access_file()
+
+ with nogil:
+ if whence == 0:
+ offset = position
+ elif whence == 1:
+ offset = GetResultValue(handle.get().Tell())
+ offset = offset + position
+ elif whence == 2:
+ offset = GetResultValue(handle.get().GetSize())
+ offset = offset + position
+ else:
+ with gil:
+ raise ValueError("Invalid value of whence: {0}"
+ .format(whence))
+ check_status(handle.get().Seek(offset))
+
+ return self.tell()
+
+ def flush(self):
+ """
+ Flush the stream, if applicable.
+
+ An error is raised if stream is not writable.
+ """
+ self._assert_open()
+ # For IOBase compatibility, flush() on an input stream is a no-op
+ if self.is_writable:
+ handle = self.get_output_stream()
+ with nogil:
+ check_status(handle.get().Flush())
+
+ def write(self, data):
+ """
+        Write data from any object implementing the buffer protocol (bytes,
+ bytearray, ndarray, pyarrow.Buffer)
+
+ Parameters
+ ----------
+ data : bytes-like object or exporter of buffer protocol
+
+ Returns
+ -------
+ nbytes : number of bytes written
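+
+        Examples
+        --------
+        A small illustration using an in-memory output stream:
+
+        >>> import pyarrow as pa
+        >>> stream = pa.BufferOutputStream()
+        >>> stream.write(b'hello')
+        5
+        >>> stream.getvalue().to_pybytes()
+        b'hello'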
+ """
+ self._assert_writable()
+ handle = self.get_output_stream()
+
+ cdef shared_ptr[CBuffer] buf = as_c_buffer(data)
+
+ with nogil:
+ check_status(handle.get().WriteBuffer(buf))
+ return buf.get().size()
+
+ def read(self, nbytes=None):
+ """
+        Read the indicated number of bytes from the file, or read all
+        remaining bytes if no argument is passed
+
+ Parameters
+ ----------
+ nbytes : int, default None
+
+ Returns
+ -------
+ data : bytes
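+
+        Examples
+        --------
+        An illustrative sketch using an in-memory reader:
+
+        >>> import pyarrow as pa
+        >>> f = pa.BufferReader(b'abcdef')
+        >>> f.read(3)
+        b'abc'
+        >>> f.read()
+        b'def'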
+ """
+ cdef:
+ int64_t c_nbytes
+ int64_t bytes_read = 0
+ PyObject* obj
+
+ if nbytes is None:
+ if not self.is_seekable:
+ # Cannot get file size => read chunkwise
+ bs = 16384
+ chunks = []
+ while True:
+ chunk = self.read(bs)
+ if not chunk:
+ break
+ chunks.append(chunk)
+ return b"".join(chunks)
+
+ c_nbytes = self.size() - self.tell()
+ else:
+ c_nbytes = nbytes
+
+ handle = self.get_input_stream()
+
+ # Allocate empty write space
+ obj = PyBytes_FromStringAndSizeNative(NULL, c_nbytes)
+
+ cdef uint8_t* buf = <uint8_t*> cp.PyBytes_AS_STRING(<object> obj)
+ with nogil:
+ bytes_read = GetResultValue(handle.get().Read(c_nbytes, buf))
+
+ if bytes_read < c_nbytes:
+ cp._PyBytes_Resize(&obj, <Py_ssize_t> bytes_read)
+
+ return PyObject_to_object(obj)
+
+ def read_at(self, nbytes, offset):
+ """
+        Read the indicated number of bytes at the given offset from the file
+
+ Parameters
+ ----------
+ nbytes : int
+ offset : int
+
+ Returns
+ -------
+ data : bytes
+ """
+ cdef:
+ int64_t c_nbytes
+ int64_t c_offset
+ int64_t bytes_read = 0
+ PyObject* obj
+
+ c_nbytes = nbytes
+
+ c_offset = offset
+
+ handle = self.get_random_access_file()
+
+ # Allocate empty write space
+ obj = PyBytes_FromStringAndSizeNative(NULL, c_nbytes)
+
+ cdef uint8_t* buf = <uint8_t*> cp.PyBytes_AS_STRING(<object> obj)
+ with nogil:
+ bytes_read = GetResultValue(handle.get().
+ ReadAt(c_offset, c_nbytes, buf))
+
+ if bytes_read < c_nbytes:
+ cp._PyBytes_Resize(&obj, <Py_ssize_t> bytes_read)
+
+ return PyObject_to_object(obj)
+
+ def read1(self, nbytes=None):
+ """Read and return up to n bytes.
+
+ Alias for read, needed to match the IOBase interface."""
+        return self.read(nbytes=nbytes)
+
+ def readall(self):
+ return self.read()
+
+ def readinto(self, b):
+ """
+ Read into the supplied buffer
+
+ Parameters
+        ----------
+        b : any Python object supporting the buffer interface
+
+        Returns
+        -------
+        number of bytes read into the supplied buffer
+ """
+
+ cdef:
+ int64_t bytes_read
+ uint8_t* buf
+ Buffer py_buf
+ int64_t buf_len
+
+ handle = self.get_input_stream()
+
+ py_buf = py_buffer(b)
+ buf_len = py_buf.size
+ buf = py_buf.buffer.get().mutable_data()
+
+ with nogil:
+ bytes_read = GetResultValue(handle.get().Read(buf_len, buf))
+
+ return bytes_read
+
+ def readline(self, size=None):
+ """NOT IMPLEMENTED. Read and return a line of bytes from the file.
+
+ If size is specified, read at most size bytes.
+
+ Line terminator is always b"\\n".
+ """
+
+ raise UnsupportedOperation()
+
+ def readlines(self, hint=None):
+ """NOT IMPLEMENTED. Read lines of the file
+
+ Parameters
+        ----------
+        hint : int
+            Maximum number of bytes read until we stop.
+ """
+
+ raise UnsupportedOperation()
+
+ def __iter__(self):
+ self._assert_readable()
+ return self
+
+ def __next__(self):
+ line = self.readline()
+ if not line:
+ raise StopIteration
+ return line
+
+ def read_buffer(self, nbytes=None):
+ cdef:
+ int64_t c_nbytes
+ int64_t bytes_read = 0
+ shared_ptr[CBuffer] output
+
+ handle = self.get_input_stream()
+
+ if nbytes is None:
+ if not self.is_seekable:
+ # Cannot get file size => read chunkwise
+ return py_buffer(self.read())
+ c_nbytes = self.size() - self.tell()
+ else:
+ c_nbytes = nbytes
+
+ with nogil:
+ output = GetResultValue(handle.get().ReadBuffer(c_nbytes))
+
+ return pyarrow_wrap_buffer(output)
+
+ def truncate(self):
+ """
+ NOT IMPLEMENTED
+ """
+ raise UnsupportedOperation()
+
+ def writelines(self, lines):
+ self._assert_writable()
+
+ for line in lines:
+ self.write(line)
+
+ def download(self, stream_or_path, buffer_size=None):
+ """
+        Read this file completely to a local path or destination stream
+        (rather than reading it completely into memory). Seeks to the
+        beginning of the file first.
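+
+        Examples
+        --------
+        Illustrative only; the paths shown are placeholders.
+
+        >>> import pyarrow as pa
+        >>> src = pa.OSFile('/tmp/source.bin')  # doctest: +SKIP
+        >>> src.download('/tmp/copy.bin')  # doctest: +SKIP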
+ """
+ cdef:
+ int64_t bytes_read = 0
+ uint8_t* buf
+
+ handle = self.get_input_stream()
+
+ buffer_size = buffer_size or DEFAULT_BUFFER_SIZE
+
+ write_queue = Queue(50)
+
+ if not hasattr(stream_or_path, 'read'):
+ stream = open(stream_or_path, 'wb')
+
+ def cleanup():
+ stream.close()
+ else:
+ stream = stream_or_path
+
+ def cleanup():
+ pass
+
+ done = False
+ exc_info = None
+
+        def bg_write():
+            # Propagate exceptions from the writer thread back to the caller
+            nonlocal exc_info
+            try:
+ while not done or write_queue.qsize() > 0:
+ try:
+ buf = write_queue.get(timeout=0.01)
+ except QueueEmpty:
+ continue
+ stream.write(buf)
+ except Exception as e:
+ exc_info = sys.exc_info()
+ finally:
+ cleanup()
+
+ self.seek(0)
+
+ writer_thread = threading.Thread(target=bg_write)
+
+ # This isn't ideal -- PyBytes_FromStringAndSize copies the data from
+ # the passed buffer, so it's hard for us to avoid doubling the memory
+ buf = <uint8_t*> malloc(buffer_size)
+ if buf == NULL:
+ raise MemoryError("Failed to allocate {0} bytes"
+ .format(buffer_size))
+
+ writer_thread.start()
+
+ cdef int64_t total_bytes = 0
+ cdef int32_t c_buffer_size = buffer_size
+
+ try:
+ while True:
+ with nogil:
+ bytes_read = GetResultValue(
+ handle.get().Read(c_buffer_size, buf))
+
+ total_bytes += bytes_read
+
+ # EOF
+ if bytes_read == 0:
+ break
+
+ pybuf = cp.PyBytes_FromStringAndSize(<const char*>buf,
+ bytes_read)
+
+ if writer_thread.is_alive():
+ while write_queue.full():
+ time.sleep(0.01)
+ else:
+ break
+
+ write_queue.put_nowait(pybuf)
+ finally:
+ free(buf)
+ done = True
+
+ writer_thread.join()
+ if exc_info is not None:
+ raise exc_info[0], exc_info[1], exc_info[2]
+
+ def upload(self, stream, buffer_size=None):
+ """
+ Pipe file-like object to file
+ """
+ write_queue = Queue(50)
+ self._assert_writable()
+
+ buffer_size = buffer_size or DEFAULT_BUFFER_SIZE
+
+ done = False
+ exc_info = None
+
+        def bg_write():
+            # Propagate exceptions from the writer thread back to the caller
+            nonlocal exc_info
+            try:
+ while not done or write_queue.qsize() > 0:
+ try:
+ buf = write_queue.get(timeout=0.01)
+ except QueueEmpty:
+ continue
+
+ self.write(buf)
+
+ except Exception as e:
+ exc_info = sys.exc_info()
+
+ writer_thread = threading.Thread(target=bg_write)
+ writer_thread.start()
+
+ try:
+ while True:
+ buf = stream.read(buffer_size)
+ if not buf:
+ break
+
+ if writer_thread.is_alive():
+ while write_queue.full():
+ time.sleep(0.01)
+ else:
+ break
+
+ write_queue.put_nowait(buf)
+ finally:
+ done = True
+
+ writer_thread.join()
+ if exc_info is not None:
+ raise exc_info[0], exc_info[1], exc_info[2]
+
+BufferedIOBase.register(NativeFile)
+
+# ----------------------------------------------------------------------
+# Python file-like objects
+
+
+cdef class PythonFile(NativeFile):
+ """
+ A stream backed by a Python file object.
+
+ This class allows using Python file objects with arbitrary Arrow
+    functions, including functions written in languages other than Python.
+
+ As a downside, there is a non-zero redirection cost in translating
+ Arrow stream calls to Python method calls. Furthermore, Python's
+ Global Interpreter Lock may limit parallelism in some situations.
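+
+    Examples
+    --------
+    A small illustration wrapping an in-memory Python buffer:
+
+    >>> import io
+    >>> import pyarrow as pa
+    >>> buf = io.BytesIO(b'hello world')
+    >>> f = pa.PythonFile(buf, mode='r')
+    >>> f.read(5)
+    b'hello'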
+ """
+ cdef:
+ object handle
+
+ def __cinit__(self, handle, mode=None):
+ self.handle = handle
+
+ if mode is None:
+ try:
+ inferred_mode = handle.mode
+ except AttributeError:
+ # Not all file-like objects have a mode attribute
+ # (e.g. BytesIO)
+ try:
+ inferred_mode = 'w' if handle.writable() else 'r'
+ except AttributeError:
+ raise ValueError("could not infer open mode for file-like "
+ "object %r, please pass it explicitly"
+ % (handle,))
+ else:
+ inferred_mode = mode
+
+ if inferred_mode.startswith('w'):
+ kind = 'w'
+ elif inferred_mode.startswith('r'):
+ kind = 'r'
+ else:
+ raise ValueError('Invalid file mode: {0}'.format(mode))
+
+ # If mode was given, check it matches the given file
+ if mode is not None:
+ if isinstance(handle, IOBase):
+ # Python 3 IO object
+ if kind == 'r':
+ if not handle.readable():
+ raise TypeError("readable file expected")
+ else:
+ if not handle.writable():
+ raise TypeError("writable file expected")
+ # (other duck-typed file-like objects are possible)
+
+ # If possible, check the file is a binary file
+ if isinstance(handle, TextIOBase):
+ raise TypeError("binary file expected, got text file")
+
+ if kind == 'r':
+ self.set_random_access_file(
+ shared_ptr[CRandomAccessFile](new PyReadableFile(handle)))
+ self.is_readable = True
+ else:
+ self.set_output_stream(
+ shared_ptr[COutputStream](new PyOutputStream(handle)))
+ self.is_writable = True
+
+ def truncate(self, pos=None):
+ self.handle.truncate(pos)
+
+ def readline(self, size=None):
+ return self.handle.readline(size)
+
+ def readlines(self, hint=None):
+ return self.handle.readlines(hint)
+
+
+cdef class MemoryMappedFile(NativeFile):
+ """
+ A stream that represents a memory-mapped file.
+
+ Supports 'r', 'r+', 'w' modes.
+ """
+ cdef:
+ shared_ptr[CMemoryMappedFile] handle
+ object path
+
+ @staticmethod
+ def create(path, size):
+ """
+ Create a MemoryMappedFile
+
+ Parameters
+ ----------
+ path : str
+ Where to create the file.
+ size : int
+ Size of the memory mapped file.
+ """
+ cdef:
+ shared_ptr[CMemoryMappedFile] handle
+ c_string c_path = encode_file_path(path)
+ int64_t c_size = size
+
+ with nogil:
+ handle = GetResultValue(CMemoryMappedFile.Create(c_path, c_size))
+
+ cdef MemoryMappedFile result = MemoryMappedFile()
+ result.path = path
+ result.is_readable = True
+ result.is_writable = True
+ result.set_output_stream(<shared_ptr[COutputStream]> handle)
+ result.set_random_access_file(<shared_ptr[CRandomAccessFile]> handle)
+ result.handle = handle
+
+ return result
+
+ def _open(self, path, mode='r'):
+ self.path = path
+
+ cdef:
+ FileMode c_mode
+ shared_ptr[CMemoryMappedFile] handle
+ c_string c_path = encode_file_path(path)
+
+ if mode in ('r', 'rb'):
+ c_mode = FileMode_READ
+ self.is_readable = True
+ elif mode in ('w', 'wb'):
+ c_mode = FileMode_WRITE
+ self.is_writable = True
+ elif mode in ('r+', 'r+b', 'rb+'):
+ c_mode = FileMode_READWRITE
+ self.is_readable = True
+ self.is_writable = True
+ else:
+ raise ValueError('Invalid file mode: {0}'.format(mode))
+
+ with nogil:
+ handle = GetResultValue(CMemoryMappedFile.Open(c_path, c_mode))
+
+ self.set_output_stream(<shared_ptr[COutputStream]> handle)
+ self.set_random_access_file(<shared_ptr[CRandomAccessFile]> handle)
+ self.handle = handle
+
+ def resize(self, new_size):
+ """
+ Resize the map and underlying file.
+
+ Parameters
+ ----------
+ new_size : int
+ New size in bytes.
+ """
+ check_status(self.handle.get().Resize(new_size))
+
+ def fileno(self):
+ self._assert_open()
+ return self.handle.get().file_descriptor()
+
+
+def memory_map(path, mode='r'):
+ """
+ Open memory map at file path. Size of the memory map cannot change.
+
+ Parameters
+ ----------
+ path : str
+ mode : {'r', 'r+', 'w'}, default 'r'
+ Whether the file is opened for reading ('r'), writing ('w')
+ or both ('r+').
+
+ Returns
+ -------
+ mmap : MemoryMappedFile
+ """
+ _check_is_file(path)
+
+ cdef MemoryMappedFile mmap = MemoryMappedFile()
+ mmap._open(path, mode)
+ return mmap
+
+
+cdef _check_is_file(path):
+ if os.path.isdir(path):
+ raise IOError("Expected file path, but {0} is a directory"
+ .format(path))
+
+
+def create_memory_map(path, size):
+ """
+ Create a file of the given size and memory-map it.
+
+ Parameters
+ ----------
+ path : str
+ The file path to create, on the local filesystem.
+ size : int
+ The file size to create.
+
+ Returns
+ -------
+ mmap : MemoryMappedFile
+ """
+ return MemoryMappedFile.create(path, size)
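+
+# A minimal usage sketch for the two helpers above, using the user-level
+# `pyarrow` namespace; the path below is hypothetical and must be writable:
+#
+#   import pyarrow as pa
+#   mm = pa.create_memory_map('/tmp/example.arrow', 1024)    # size in bytes
+#   mm.write(b'hello')
+#   mm.close()
+#   pa.memory_map('/tmp/example.arrow', 'r').read(5)          # b'hello'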
+
+
+cdef class OSFile(NativeFile):
+ """
+ A stream backed by a regular file descriptor.
+ """
+ cdef:
+ object path
+
+ def __cinit__(self, path, mode='r', MemoryPool memory_pool=None):
+ _check_is_file(path)
+ self.path = path
+
+ cdef:
+ FileMode c_mode
+ shared_ptr[Readable] handle
+ c_string c_path = encode_file_path(path)
+
+ if mode in ('r', 'rb'):
+ self._open_readable(c_path, maybe_unbox_memory_pool(memory_pool))
+ elif mode in ('w', 'wb'):
+ self._open_writable(c_path)
+ else:
+ raise ValueError('Invalid file mode: {0}'.format(mode))
+
+ cdef _open_readable(self, c_string path, CMemoryPool* pool):
+ cdef shared_ptr[ReadableFile] handle
+
+ with nogil:
+ handle = GetResultValue(ReadableFile.Open(path, pool))
+
+ self.is_readable = True
+ self.set_random_access_file(<shared_ptr[CRandomAccessFile]> handle)
+
+ cdef _open_writable(self, c_string path):
+ with nogil:
+ self.output_stream = GetResultValue(FileOutputStream.Open(path))
+ self.is_writable = True
+
+ def fileno(self):
+ self._assert_open()
+ return self.handle.file_descriptor()
+
+
+cdef class FixedSizeBufferWriter(NativeFile):
+ """
+ A stream writing to an Arrow buffer.
+ """
+
+ def __cinit__(self, Buffer buffer):
+ self.output_stream.reset(new CFixedSizeBufferWriter(buffer.buffer))
+ self.is_writable = True
+
+ def set_memcopy_threads(self, int num_threads):
+ cdef CFixedSizeBufferWriter* writer = \
+ <CFixedSizeBufferWriter*> self.output_stream.get()
+ writer.set_memcopy_threads(num_threads)
+
+ def set_memcopy_blocksize(self, int64_t blocksize):
+ cdef CFixedSizeBufferWriter* writer = \
+ <CFixedSizeBufferWriter*> self.output_stream.get()
+ writer.set_memcopy_blocksize(blocksize)
+
+ def set_memcopy_threshold(self, int64_t threshold):
+ cdef CFixedSizeBufferWriter* writer = \
+ <CFixedSizeBufferWriter*> self.output_stream.get()
+ writer.set_memcopy_threshold(threshold)
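+
+# An illustrative sketch of writing into a preallocated, fixed-size Arrow
+# buffer (user-level `pyarrow` API; the memcopy tuning call is optional):
+#
+#   import pyarrow as pa
+#   buf = pa.allocate_buffer(11)                # mutable 11-byte buffer
+#   writer = pa.FixedSizeBufferWriter(buf)
+#   writer.set_memcopy_threads(4)               # parallelize large copies
+#   writer.write(b'hello world')                # exactly fills the buffer
+#   writer.close()
+#   assert buf.to_pybytes() == b'hello world'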
+
+
+# ----------------------------------------------------------------------
+# Arrow buffers
+
+
+cdef class Buffer(_Weakrefable):
+ """
+ The base class for all Arrow buffers.
+
+ A buffer represents a contiguous memory area. Many buffers will own
+ their memory, though not all of them do.
+ """
+
+ def __cinit__(self):
+ pass
+
+ def __init__(self):
+ raise TypeError("Do not call Buffer's constructor directly, use "
+ "`pyarrow.py_buffer` function instead.")
+
+ cdef void init(self, const shared_ptr[CBuffer]& buffer):
+ self.buffer = buffer
+ self.shape[0] = self.size
+ self.strides[0] = <Py_ssize_t>(1)
+
+ def __len__(self):
+ return self.size
+
+ @property
+ def size(self):
+ """
+ The buffer size in bytes.
+ """
+ return self.buffer.get().size()
+
+ @property
+ def address(self):
+ """
+ The buffer's address, as an integer.
+
+ The returned address may point to CPU or device memory.
+ Use the `is_cpu` property to disambiguate.
+ """
+ return self.buffer.get().address()
+
+ def hex(self):
+ """
+ Compute hexadecimal representation of the buffer.
+
+ Returns
+ -------
+ bytes
+ """
+ return self.buffer.get().ToHexString()
+
+ @property
+ def is_mutable(self):
+ """
+ Whether the buffer is mutable.
+ """
+ return self.buffer.get().is_mutable()
+
+ @property
+ def is_cpu(self):
+ """
+ Whether the buffer is CPU-accessible.
+ """
+ return self.buffer.get().is_cpu()
+
+ @property
+ def parent(self):
+ cdef shared_ptr[CBuffer] parent_buf = self.buffer.get().parent()
+
+ if parent_buf.get() == NULL:
+ return None
+ else:
+ return pyarrow_wrap_buffer(parent_buf)
+
+ def __getitem__(self, key):
+ if PySlice_Check(key):
+ if (key.step or 1) != 1:
+ raise IndexError('only slices with step 1 supported')
+ return _normalize_slice(self, key)
+
+ return self.getitem(_normalize_index(key, self.size))
+
+ cdef getitem(self, int64_t i):
+ return self.buffer.get().data()[i]
+
+ def slice(self, offset=0, length=None):
+ """
+ Slice this buffer. Memory is not copied.
+
+ You can also use the Python slice notation ``buffer[start:stop]``.
+
+ Parameters
+ ----------
+ offset : int, default 0
+ Offset from start of buffer to slice.
+ length : int, default None
+ Length of slice (default is until end of Buffer starting from
+ offset).
+
+ Returns
+ -------
+ sliced : Buffer
+ A logical view over this buffer.
+ """
+ cdef shared_ptr[CBuffer] result
+
+ if offset < 0:
+ raise IndexError('Offset must be non-negative')
+
+ if length is None:
+ result = SliceBuffer(self.buffer, offset)
+ else:
+ result = SliceBuffer(self.buffer, offset, max(length, 0))
+
+ return pyarrow_wrap_buffer(result)
+
+ def equals(self, Buffer other):
+ """
+ Determine if two buffers contain exactly the same data.
+
+ Parameters
+ ----------
+ other : Buffer
+
+ Returns
+ -------
+ are_equal : bool
+ True if buffer contents and size are equal.
+ """
+ cdef c_bool result = False
+ with nogil:
+ result = self.buffer.get().Equals(deref(other.buffer.get()))
+ return result
+
+ def __eq__(self, other):
+ if isinstance(other, Buffer):
+ return self.equals(other)
+ else:
+ return self.equals(py_buffer(other))
+
+ def __reduce_ex__(self, protocol):
+ if protocol >= 5:
+ return py_buffer, (builtin_pickle.PickleBuffer(self),)
+ else:
+ return py_buffer, (self.to_pybytes(),)
+
+ def to_pybytes(self):
+ """
+ Return this buffer as a Python bytes object. Memory is copied.
+ """
+ return cp.PyBytes_FromStringAndSize(
+ <const char*>self.buffer.get().data(),
+ self.buffer.get().size())
+
+ def __getbuffer__(self, cp.Py_buffer* buffer, int flags):
+ if self.buffer.get().is_mutable():
+ buffer.readonly = 0
+ else:
+ if flags & cp.PyBUF_WRITABLE:
+ raise BufferError("Writable buffer requested but Arrow "
+ "buffer was not mutable")
+ buffer.readonly = 1
+ buffer.buf = <char *>self.buffer.get().data()
+ buffer.format = 'b'
+ buffer.internal = NULL
+ buffer.itemsize = 1
+ buffer.len = self.size
+ buffer.ndim = 1
+ buffer.obj = self
+ buffer.shape = self.shape
+ buffer.strides = self.strides
+ buffer.suboffsets = NULL
+
+ def __getsegcount__(self, Py_ssize_t *len_out):
+ if len_out != NULL:
+ len_out[0] = <Py_ssize_t>self.size
+ return 1
+
+ def __getreadbuffer__(self, Py_ssize_t idx, void **p):
+ if idx != 0:
+ raise SystemError("accessing non-existent buffer segment")
+ if p != NULL:
+ p[0] = <void*> self.buffer.get().data()
+ return self.size
+
+ def __getwritebuffer__(self, Py_ssize_t idx, void **p):
+ if not self.buffer.get().is_mutable():
+ raise SystemError("trying to write an immutable buffer")
+ if idx != 0:
+ raise SystemError("accessing non-existent buffer segment")
+ if p != NULL:
+ p[0] = <void*> self.buffer.get().data()
+ return self.size
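+
+# A short sketch of the Buffer API above, constructed through the public
+# py_buffer() factory (slicing is zero-copy):
+#
+#   import pyarrow as pa
+#   buf = pa.py_buffer(b'arrow data')
+#   buf.size                                    # 10
+#   buf[:5].to_pybytes()                        # b'arrow'
+#   buf.equals(pa.py_buffer(b'arrow data'))     # True
+#   bytes(memoryview(buf))                      # b'arrow data' (buffer protocol)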
+
+
+cdef class ResizableBuffer(Buffer):
+ """
+ A base class for buffers that can be resized.
+ """
+
+ cdef void init_rz(self, const shared_ptr[CResizableBuffer]& buffer):
+ self.init(<shared_ptr[CBuffer]> buffer)
+
+ def resize(self, int64_t new_size, shrink_to_fit=False):
+ """
+ Resize buffer to indicated size.
+
+ Parameters
+ ----------
+ new_size : int
+ New size of buffer (padding may be added internally).
+ shrink_to_fit : bool, default False
+ If this is true, the buffer is shrunk when new_size is less
+ than the current size.
+ If this is false, the buffer is never shrunk.
+ """
+ cdef c_bool c_shrink_to_fit = shrink_to_fit
+ with nogil:
+ check_status((<CResizableBuffer*> self.buffer.get())
+ .Resize(new_size, c_shrink_to_fit))
+
+
+cdef shared_ptr[CResizableBuffer] _allocate_buffer(CMemoryPool* pool) except *:
+ with nogil:
+ return to_shared(GetResultValue(AllocateResizableBuffer(0, pool)))
+
+
+def allocate_buffer(int64_t size, MemoryPool memory_pool=None,
+ resizable=False):
+ """
+ Allocate a mutable buffer.
+
+ Parameters
+ ----------
+ size : int
+ Number of bytes to allocate (plus internal padding)
+ memory_pool : MemoryPool, optional
+ The pool to allocate memory from.
+ If not given, the default memory pool is used.
+ resizable : bool, default False
+ If true, the returned buffer is resizable.
+
+ Returns
+ -------
+ buffer : Buffer or ResizableBuffer
+ """
+ cdef:
+ CMemoryPool* cpool = maybe_unbox_memory_pool(memory_pool)
+ shared_ptr[CResizableBuffer] c_rz_buffer
+ shared_ptr[CBuffer] c_buffer
+
+ if resizable:
+ with nogil:
+ c_rz_buffer = to_shared(GetResultValue(
+ AllocateResizableBuffer(size, cpool)))
+ return pyarrow_wrap_resizable_buffer(c_rz_buffer)
+ else:
+ with nogil:
+ c_buffer = to_shared(GetResultValue(AllocateBuffer(size, cpool)))
+ return pyarrow_wrap_buffer(c_buffer)
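+
+# A minimal sketch of allocate_buffer(); with resizable=True the result
+# supports in-place resizing:
+#
+#   import pyarrow as pa
+#   buf = pa.allocate_buffer(1024, resizable=True)
+#   buf.size          # 1024
+#   buf.resize(2048)
+#   buf.size          # 2048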
+
+
+cdef class BufferOutputStream(NativeFile):
+
+ cdef:
+ shared_ptr[CResizableBuffer] buffer
+
+ def __cinit__(self, MemoryPool memory_pool=None):
+ self.buffer = _allocate_buffer(maybe_unbox_memory_pool(memory_pool))
+ self.output_stream.reset(new CBufferOutputStream(
+ <shared_ptr[CResizableBuffer]> self.buffer))
+ self.is_writable = True
+
+ def getvalue(self):
+ """
+ Finalize output stream and return result as pyarrow.Buffer.
+
+ Returns
+ -------
+ value : Buffer
+ """
+ with nogil:
+ check_status(self.output_stream.get().Close())
+ return pyarrow_wrap_buffer(<shared_ptr[CBuffer]> self.buffer)
+
+
+cdef class MockOutputStream(NativeFile):
+
+ def __cinit__(self):
+ self.output_stream.reset(new CMockOutputStream())
+ self.is_writable = True
+
+ def size(self):
+ handle = <CMockOutputStream*> self.output_stream.get()
+ return handle.GetExtentBytesWritten()
+
+
+cdef class BufferReader(NativeFile):
+ """
+ Zero-copy reader from objects convertible to Arrow buffer.
+
+ Parameters
+ ----------
+ obj : Python bytes or pyarrow.Buffer
+ """
+ cdef:
+ Buffer buffer
+
+ def __cinit__(self, object obj):
+ self.buffer = as_buffer(obj)
+ self.set_random_access_file(shared_ptr[CRandomAccessFile](
+ new CBufferReader(self.buffer.buffer)))
+ self.is_readable = True
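+
+# An in-memory round trip through the two classes above (getvalue() closes
+# the output stream and hands back the accumulated Buffer):
+#
+#   import pyarrow as pa
+#   sink = pa.BufferOutputStream()
+#   sink.write(b'hello ')
+#   sink.write(b'world')
+#   buf = sink.getvalue()
+#   pa.BufferReader(buf).read()    # b'hello world'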
+
+
+cdef class CompressedInputStream(NativeFile):
+ """
+ An input stream wrapper which decompresses data on the fly.
+
+ Parameters
+ ----------
+ stream : string, path, pa.NativeFile, or file-like object
+ Input stream object to wrap with the compression.
+ compression : str
+ The compression type ("bz2", "brotli", "gzip", "lz4" or "zstd").
+ """
+
+ def __init__(self, object stream, str compression not None):
+ cdef:
+ NativeFile nf
+ Codec codec = Codec(compression)
+ shared_ptr[CInputStream] c_reader
+ shared_ptr[CCompressedInputStream] compressed_stream
+ nf = get_native_file(stream, False)
+ c_reader = nf.get_input_stream()
+ compressed_stream = GetResultValue(
+ CCompressedInputStream.Make(codec.unwrap(), c_reader)
+ )
+ self.set_input_stream(<shared_ptr[CInputStream]> compressed_stream)
+ self.is_readable = True
+
+
+cdef class CompressedOutputStream(NativeFile):
+ """
+ An output stream wrapper which compresses data on the fly.
+
+ Parameters
+ ----------
+ stream : string, path, pa.NativeFile, or file-like object
+ Output stream object to wrap with the compression.
+ compression : str
+ The compression type ("bz2", "brotli", "gzip", "lz4" or "zstd").
+ """
+
+ def __init__(self, object stream, str compression not None):
+ cdef:
+ Codec codec = Codec(compression)
+ shared_ptr[COutputStream] c_writer
+ shared_ptr[CCompressedOutputStream] compressed_stream
+ get_writer(stream, &c_writer)
+ compressed_stream = GetResultValue(
+ CCompressedOutputStream.Make(codec.unwrap(), c_writer)
+ )
+ self.set_output_stream(<shared_ptr[COutputStream]> compressed_stream)
+ self.is_writable = True
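+
+# A compression round trip through the two wrappers above (a sketch that
+# assumes the gzip codec is enabled in this Arrow build):
+#
+#   import pyarrow as pa
+#   data = b'ab' * 1000
+#   sink = pa.BufferOutputStream()
+#   with pa.CompressedOutputStream(sink, 'gzip') as out:
+#       out.write(data)                          # compressed on the fly
+#   compressed = sink.getvalue()
+#   with pa.CompressedInputStream(pa.BufferReader(compressed), 'gzip') as f:
+#       assert f.read() == data                  # decompressed on the fly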
+
+
+ctypedef CBufferedInputStream* _CBufferedInputStreamPtr
+ctypedef CBufferedOutputStream* _CBufferedOutputStreamPtr
+ctypedef CRandomAccessFile* _RandomAccessFilePtr
+
+
+cdef class BufferedInputStream(NativeFile):
+ """
+ An input stream that performs buffered reads from
+ an unbuffered input stream, which can mitigate the overhead
+ of many small reads in some cases.
+
+ Parameters
+ ----------
+ stream : NativeFile
+ The input stream to wrap with the buffer
+ buffer_size : int
+ Size of the temporary read buffer.
+ memory_pool : MemoryPool
+ The memory pool used to allocate the buffer.
+ """
+
+ def __init__(self, NativeFile stream, int buffer_size,
+ MemoryPool memory_pool=None):
+ cdef shared_ptr[CBufferedInputStream] buffered_stream
+
+ if buffer_size <= 0:
+ raise ValueError('Buffer size must be larger than zero')
+ buffered_stream = GetResultValue(CBufferedInputStream.Create(
+ buffer_size, maybe_unbox_memory_pool(memory_pool),
+ stream.get_input_stream()))
+
+ self.set_input_stream(<shared_ptr[CInputStream]> buffered_stream)
+ self.is_readable = True
+
+ def detach(self):
+ """
+ Release the raw InputStream.
+ Further operations on this stream are invalid.
+
+ Returns
+ -------
+ raw : NativeFile
+ The underlying raw input stream
+ """
+ cdef:
+ shared_ptr[CInputStream] c_raw
+ _CBufferedInputStreamPtr buffered
+ NativeFile raw
+
+ buffered = dynamic_cast[_CBufferedInputStreamPtr](
+ self.input_stream.get())
+ assert buffered != nullptr
+
+ with nogil:
+ c_raw = GetResultValue(buffered.Detach())
+
+ raw = NativeFile()
+ raw.is_readable = True
+ # Find out whether the raw stream is a RandomAccessFile
+ # or a mere InputStream. This helps us support seek() etc.
+ # selectively.
+ if dynamic_cast[_RandomAccessFilePtr](c_raw.get()) != nullptr:
+ raw.set_random_access_file(
+ static_pointer_cast[CRandomAccessFile, CInputStream](c_raw))
+ else:
+ raw.set_input_stream(c_raw)
+ return raw
+
+
+cdef class BufferedOutputStream(NativeFile):
+ """
+ An output stream that performs buffered writes to
+ an unbuffered output stream, which can mitigate the overhead
+ of many small writes in some cases.
+
+ Parameters
+ ----------
+ stream : NativeFile
+ The writable output stream to wrap with the buffer
+ buffer_size : int
+ Size of the buffer that should be added.
+ memory_pool : MemoryPool
+ The memory pool used to allocate the buffer.
+ """
+
+ def __init__(self, NativeFile stream, int buffer_size,
+ MemoryPool memory_pool=None):
+ cdef shared_ptr[CBufferedOutputStream] buffered_stream
+
+ if buffer_size <= 0:
+ raise ValueError('Buffer size must be larger than zero')
+ buffered_stream = GetResultValue(CBufferedOutputStream.Create(
+ buffer_size, maybe_unbox_memory_pool(memory_pool),
+ stream.get_output_stream()))
+
+ self.set_output_stream(<shared_ptr[COutputStream]> buffered_stream)
+ self.is_writable = True
+
+ def detach(self):
+ """
+ Flush any buffered writes and release the raw OutputStream.
+ Further operations on this stream are invalid.
+
+ Returns
+ -------
+ raw : NativeFile
+ The underlying raw output stream.
+ """
+ cdef:
+ shared_ptr[COutputStream] c_raw
+ _CBufferedOutputStreamPtr buffered
+ NativeFile raw
+
+ buffered = dynamic_cast[_CBufferedOutputStreamPtr](
+ self.output_stream.get())
+ assert buffered != nullptr
+
+ with nogil:
+ c_raw = GetResultValue(buffered.Detach())
+
+ raw = NativeFile()
+ raw.is_writable = True
+ raw.set_output_stream(c_raw)
+ return raw
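+
+# A sketch of buffering many small writes and then detaching (detach()
+# flushes the staging buffer and returns the raw sink-side stream):
+#
+#   import pyarrow as pa
+#   sink = pa.BufferOutputStream()
+#   buffered = pa.BufferedOutputStream(sink, 4096)   # 4 KiB staging buffer
+#   for piece in (b'many ', b'small ', b'writes'):
+#       buffered.write(piece)
+#   raw = buffered.detach()
+#   sink.getvalue().to_pybytes()                     # b'many small writes'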
+
+
+cdef void _cb_transform(transform_func, const shared_ptr[CBuffer]& src,
+ shared_ptr[CBuffer]* dest) except *:
+ py_dest = transform_func(pyarrow_wrap_buffer(src))
+ dest[0] = pyarrow_unwrap_buffer(py_buffer(py_dest))
+
+
+cdef class TransformInputStream(NativeFile):
+ """
+ Transform an input stream.
+
+ Parameters
+ ----------
+ stream : NativeFile
+ The stream to transform.
+ transform_func : callable
+ The transformation to apply.
+ """
+
+ def __init__(self, NativeFile stream, transform_func):
+ self.set_input_stream(TransformInputStream.make_native(
+ stream.get_input_stream(), transform_func))
+ self.is_readable = True
+
+ @staticmethod
+ cdef shared_ptr[CInputStream] make_native(
+ shared_ptr[CInputStream] stream, transform_func) except *:
+ cdef:
+ shared_ptr[CInputStream] transform_stream
+ CTransformInputStreamVTable vtable
+
+ vtable.transform = _cb_transform
+ return MakeTransformInputStream(stream, move(vtable),
+ transform_func)
+
+
+class Transcoder:
+
+ def __init__(self, decoder, encoder):
+ self._decoder = decoder
+ self._encoder = encoder
+
+ def __call__(self, buf):
+ final = len(buf) == 0
+ return self._encoder.encode(self._decoder.decode(buf, final), final)
+
+
+def transcoding_input_stream(stream, src_encoding, dest_encoding):
+ """
+ Add a transcoding transformation to the stream.
+ Incoming data will be decoded according to ``src_encoding`` and
+ then re-encoded according to ``dest_encoding``.
+
+ Parameters
+ ----------
+ stream : NativeFile
+ The stream to which the transformation should be applied.
+ src_encoding : str
+ The codec to use when reading data.
+ dest_encoding : str
+ The codec to use for emitted data.
+ """
+ src_codec = codecs.lookup(src_encoding)
+ dest_codec = codecs.lookup(dest_encoding)
+ if src_codec.name == dest_codec.name:
+ # Avoid losing performance on no-op transcoding
+ # (encoding errors won't be detected)
+ return stream
+ return TransformInputStream(stream,
+ Transcoder(src_codec.incrementaldecoder(),
+ dest_codec.incrementalencoder()))
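+
+# A minimal transcoding sketch: bytes stored as UTF-16 are re-emitted as
+# UTF-8 while streaming:
+#
+#   import pyarrow as pa
+#   raw = pa.BufferReader('hello'.encode('utf-16'))
+#   converted = pa.transcoding_input_stream(raw, 'utf-16', 'utf-8')
+#   converted.read()    # b'hello'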
+
+
+cdef shared_ptr[CInputStream] native_transcoding_input_stream(
+ shared_ptr[CInputStream] stream, src_encoding,
+ dest_encoding) except *:
+ src_codec = codecs.lookup(src_encoding)
+ dest_codec = codecs.lookup(dest_encoding)
+ if src_codec.name == dest_codec.name:
+ # Avoid losing performance on no-op transcoding
+ # (encoding errors won't be detected)
+ return stream
+ return TransformInputStream.make_native(
+ stream, Transcoder(src_codec.incrementaldecoder(),
+ dest_codec.incrementalencoder()))
+
+
+def py_buffer(object obj):
+ """
+ Construct an Arrow buffer from a Python bytes-like or buffer-like object
+
+ Parameters
+ ----------
+ obj : object
+ the object from which the buffer should be constructed.
+ """
+ cdef shared_ptr[CBuffer] buf
+ buf = GetResultValue(PyBuffer.FromPyObject(obj))
+ return pyarrow_wrap_buffer(buf)
+
+
+def foreign_buffer(address, size, base=None):
+ """
+ Construct an Arrow buffer with the given *address* and *size*.
+
+ The buffer will be optionally backed by the Python *base* object, if given.
+ The *base* object will be kept alive as long as this buffer is alive,
+ including across language boundaries (for example if the buffer is
+ referenced by C++ code).
+
+ Parameters
+ ----------
+ address : int
+ The starting address of the buffer. The address can
+ refer to both device or host memory but it must be
+ accessible from device after mapping it with
+ `get_device_address` method.
+ size : int
+ The size of the buffer, in bytes.
+ base : {None, object}
+ Object that owns the referenced memory.
+ """
+ cdef:
+ intptr_t c_addr = address
+ int64_t c_size = size
+ shared_ptr[CBuffer] buf
+
+ check_status(PyForeignBuffer.Make(<uint8_t*> c_addr, c_size,
+ base, &buf))
+ return pyarrow_wrap_buffer(buf)
+
+
+def as_buffer(object o):
+ if isinstance(o, Buffer):
+ return o
+ return py_buffer(o)
+
+
+cdef shared_ptr[CBuffer] as_c_buffer(object o) except *:
+ cdef shared_ptr[CBuffer] buf
+ if isinstance(o, Buffer):
+ buf = (<Buffer> o).buffer
+ if buf == nullptr:
+ raise ValueError("got null buffer")
+ else:
+ buf = GetResultValue(PyBuffer.FromPyObject(o))
+ return buf
+
+
+cdef NativeFile get_native_file(object source, c_bool use_memory_map):
+ try:
+ source_path = _stringify_path(source)
+ except TypeError:
+ if isinstance(source, Buffer):
+ source = BufferReader(source)
+ elif not isinstance(source, NativeFile) and hasattr(source, 'read'):
+ # Optimistically hope this is file-like
+ source = PythonFile(source, mode='r')
+ else:
+ if use_memory_map:
+ source = memory_map(source_path, mode='r')
+ else:
+ source = OSFile(source_path, mode='r')
+
+ return source
+
+
+cdef get_reader(object source, c_bool use_memory_map,
+ shared_ptr[CRandomAccessFile]* reader):
+ cdef NativeFile nf
+
+ nf = get_native_file(source, use_memory_map)
+ reader[0] = nf.get_random_access_file()
+
+
+cdef get_input_stream(object source, c_bool use_memory_map,
+ shared_ptr[CInputStream]* out):
+ """
+ Like get_reader(), but can automatically decompress, and returns
+ an InputStream.
+ """
+ cdef:
+ NativeFile nf
+ Codec codec
+ shared_ptr[CInputStream] input_stream
+
+ try:
+ codec = Codec.detect(source)
+ except TypeError:
+ codec = None
+
+ nf = get_native_file(source, use_memory_map)
+ input_stream = nf.get_input_stream()
+
+ # codec is None if compression can't be detected
+ if codec is not None:
+ input_stream = <shared_ptr[CInputStream]> GetResultValue(
+ CCompressedInputStream.Make(codec.unwrap(), input_stream)
+ )
+
+ out[0] = input_stream
+
+
+cdef get_writer(object source, shared_ptr[COutputStream]* writer):
+ cdef NativeFile nf
+
+ try:
+ source_path = _stringify_path(source)
+ except TypeError:
+ if not isinstance(source, NativeFile) and hasattr(source, 'write'):
+ # Optimistically hope this is file-like
+ source = PythonFile(source, mode='w')
+ else:
+ source = OSFile(source_path, mode='w')
+
+ if isinstance(source, NativeFile):
+ nf = source
+ writer[0] = nf.get_output_stream()
+ else:
+ raise TypeError('Unable to write to object of type: {0}'
+ .format(type(source)))
+
+
+# ---------------------------------------------------------------------
+
+
+def _detect_compression(path):
+ if isinstance(path, str):
+ if path.endswith('.bz2'):
+ return 'bz2'
+ elif path.endswith('.gz'):
+ return 'gzip'
+ elif path.endswith('.lz4'):
+ return 'lz4'
+ elif path.endswith('.zst'):
+ return 'zstd'
+
+
+cdef CCompressionType _ensure_compression(str name) except *:
+ uppercase = name.upper()
+ if uppercase == 'BZ2':
+ return CCompressionType_BZ2
+ elif uppercase == 'GZIP':
+ return CCompressionType_GZIP
+ elif uppercase == 'BROTLI':
+ return CCompressionType_BROTLI
+ elif uppercase == 'LZ4' or uppercase == 'LZ4_FRAME':
+ return CCompressionType_LZ4_FRAME
+ elif uppercase == 'LZ4_RAW':
+ return CCompressionType_LZ4
+ elif uppercase == 'SNAPPY':
+ return CCompressionType_SNAPPY
+ elif uppercase == 'ZSTD':
+ return CCompressionType_ZSTD
+ else:
+ raise ValueError('Invalid value for compression: {!r}'.format(name))
+
+
+cdef class Codec(_Weakrefable):
+ """
+ Compression codec.
+
+ Parameters
+ ----------
+ compression : str
+ Type of compression codec to initialize, valid values are: 'gzip',
+ 'bz2', 'brotli', 'lz4' (or 'lz4_frame'), 'lz4_raw', 'zstd' and
+ 'snappy'.
+ compression_level : int, None
+ Optional parameter specifying how aggressively to compress. The
+ possible ranges and effect of this parameter depend on the specific
+ codec chosen. Higher values compress more but typically use more
+ resources (CPU/RAM). Some codecs support negative values.
+
+ gzip
+ The compression_level maps to the memlevel parameter of
+ deflateInit2. Higher levels use more RAM but are faster
+ and should have higher compression ratios.
+
+ bz2
+ The compression level maps to the blockSize100k parameter of
+ the BZ2_bzCompressInit function. Higher levels use more RAM
+ but are faster and should have higher compression ratios.
+
+ brotli
+ The compression level maps to the BROTLI_PARAM_QUALITY
+ parameter. Higher values are slower and should have higher
+ compression ratios.
+
+ lz4/lz4_frame/lz4_raw
+ The compression level parameter is not supported and must
+ be None
+
+ zstd
+ The compression level maps to the compressionLevel parameter
+ of ZSTD_initCStream. Negative values are supported. Higher
+ values are slower and should have higher compression ratios.
+
+ snappy
+ The compression level parameter is not supported and must
+ be None
+
+
+ Raises
+ ------
+ ValueError
+ If invalid compression value is passed.
+ """
+
+ def __init__(self, str compression not None, compression_level=None):
+ cdef CCompressionType typ = _ensure_compression(compression)
+ if compression_level is not None:
+ self.wrapped = shared_ptr[CCodec](move(GetResultValue(
+ CCodec.CreateWithLevel(typ, compression_level))))
+ else:
+ self.wrapped = shared_ptr[CCodec](move(GetResultValue(
+ CCodec.Create(typ))))
+
+ cdef inline CCodec* unwrap(self) nogil:
+ return self.wrapped.get()
+
+ @staticmethod
+ def detect(path):
+ """
+ Detect and instantiate compression codec based on file extension.
+
+ Parameters
+ ----------
+ path : str, path-like
+ File-path to detect compression from.
+
+ Raises
+ ------
+ TypeError
+ If the passed value is not path-like.
+ ValueError
+ If the compression can't be detected from the path.
+
+ Returns
+ -------
+ Codec
+ """
+ return Codec(_detect_compression(_stringify_path(path)))
+
+ @staticmethod
+ def is_available(str compression not None):
+ """
+ Returns whether the compression support has been built and enabled.
+
+ Parameters
+ ----------
+ compression : str
+ Type of compression codec,
+ refer to Codec docstring for a list of supported ones.
+
+ Returns
+ -------
+ bool
+ """
+ cdef CCompressionType typ = _ensure_compression(compression)
+ return CCodec.IsAvailable(typ)
+
+ @staticmethod
+ def supports_compression_level(str compression not None):
+ """
+ Returns true if the compression level parameter is supported
+ for the given codec.
+
+ Parameters
+ ----------
+ compression : str
+ Type of compression codec,
+ refer to Codec docstring for a list of supported ones.
+ """
+ cdef CCompressionType typ = _ensure_compression(compression)
+ return CCodec.SupportsCompressionLevel(typ)
+
+ @staticmethod
+ def default_compression_level(str compression not None):
+ """
+ Returns the compression level that Arrow will use for the codec if
+ None is specified.
+
+ Parameters
+ ----------
+ compression : str
+ Type of compression codec,
+ refer to Codec docstring for a list of supported ones.
+ """
+ cdef CCompressionType typ = _ensure_compression(compression)
+ return GetResultValue(CCodec.DefaultCompressionLevel(typ))
+
+ @staticmethod
+ def minimum_compression_level(str compression not None):
+ """
+ Returns the smallest valid value for the compression level
+
+ Parameters
+ ----------
+ compression : str
+ Type of compression codec,
+ refer to Codec docstring for a list of supported ones.
+ """
+ cdef CCompressionType typ = _ensure_compression(compression)
+ return GetResultValue(CCodec.MinimumCompressionLevel(typ))
+
+ @staticmethod
+ def maximum_compression_level(str compression not None):
+ """
+ Returns the largest valid value for the compression level
+
+ Parameters
+ ----------
+ compression : str
+ Type of compression codec,
+ refer to Codec docstring for a list of supported ones.
+ """
+ cdef CCompressionType typ = _ensure_compression(compression)
+ return GetResultValue(CCodec.MaximumCompressionLevel(typ))
+
+ @property
+ def name(self):
+ """Returns the name of the codec"""
+ return frombytes(self.unwrap().name())
+
+ @property
+ def compression_level(self):
+ """Returns the compression level parameter of the codec"""
+ # compression_level() returns an integer, not bytes
+ return self.unwrap().compression_level()
+
+ def compress(self, object buf, asbytes=False, memory_pool=None):
+ """
+ Compress data from buffer-like object.
+
+ Parameters
+ ----------
+ buf : pyarrow.Buffer, bytes, or other object supporting buffer protocol
+ asbytes : bool, default False
+ Return result as Python bytes object, otherwise Buffer
+ memory_pool : MemoryPool, default None
+ Memory pool to use for buffer allocations, if any
+
+ Returns
+ -------
+ compressed : pyarrow.Buffer or bytes (if asbytes=True)
+ """
+ cdef:
+ shared_ptr[CBuffer] owned_buf
+ CBuffer* c_buf
+ PyObject* pyobj
+ ResizableBuffer out_buf
+ int64_t max_output_size
+ int64_t output_length
+ uint8_t* output_buffer = NULL
+
+ owned_buf = as_c_buffer(buf)
+ c_buf = owned_buf.get()
+
+ max_output_size = self.wrapped.get().MaxCompressedLen(
+ c_buf.size(), c_buf.data()
+ )
+
+ if asbytes:
+ pyobj = PyBytes_FromStringAndSizeNative(NULL, max_output_size)
+ output_buffer = <uint8_t*> cp.PyBytes_AS_STRING(<object> pyobj)
+ else:
+ out_buf = allocate_buffer(
+ max_output_size, memory_pool=memory_pool, resizable=True
+ )
+ output_buffer = out_buf.buffer.get().mutable_data()
+
+ with nogil:
+ output_length = GetResultValue(
+ self.unwrap().Compress(
+ c_buf.size(),
+ c_buf.data(),
+ max_output_size,
+ output_buffer
+ )
+ )
+
+ if asbytes:
+ cp._PyBytes_Resize(&pyobj, <Py_ssize_t> output_length)
+ return PyObject_to_object(pyobj)
+ else:
+ out_buf.resize(output_length)
+ return out_buf
+
+ def decompress(self, object buf, decompressed_size=None, asbytes=False,
+ memory_pool=None):
+ """
+ Decompress data from buffer-like object.
+
+ Parameters
+ ----------
+ buf : pyarrow.Buffer, bytes, or memoryview-compatible object
+ decompressed_size : int64_t, default None
+ If not specified, will be computed if the codec is able to
+ determine the uncompressed buffer size.
+ asbytes : boolean, default False
+ Return result as Python bytes object, otherwise Buffer
+ memory_pool : MemoryPool, default None
+ Memory pool to use for buffer allocations, if any.
+
+ Returns
+ -------
+ uncompressed : pyarrow.Buffer or bytes (if asbytes=True)
+ """
+ cdef:
+ shared_ptr[CBuffer] owned_buf
+ CBuffer* c_buf
+ Buffer out_buf
+ int64_t output_size
+ uint8_t* output_buffer = NULL
+
+ owned_buf = as_c_buffer(buf)
+ c_buf = owned_buf.get()
+
+ if decompressed_size is None:
+ raise ValueError(
+ "Must pass decompressed_size for {} codec".format(self)
+ )
+
+ output_size = decompressed_size
+
+ if asbytes:
+ pybuf = cp.PyBytes_FromStringAndSize(NULL, output_size)
+ output_buffer = <uint8_t*> cp.PyBytes_AS_STRING(pybuf)
+ else:
+ out_buf = allocate_buffer(output_size, memory_pool=memory_pool)
+ output_buffer = out_buf.buffer.get().mutable_data()
+
+ with nogil:
+ GetResultValue(
+ self.unwrap().Decompress(
+ c_buf.size(),
+ c_buf.data(),
+ output_size,
+ output_buffer
+ )
+ )
+
+ return pybuf if asbytes else out_buf
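+
+# A sketch of direct Codec usage; the is_available() guard keeps the
+# example safe on builds without zstd support:
+#
+#   import pyarrow as pa
+#   if pa.Codec.is_available('zstd'):
+#       codec = pa.Codec('zstd', compression_level=3)
+#       data = b'a' * 1000
+#       comp = codec.compress(data, asbytes=True)
+#       roundtrip = codec.decompress(comp, decompressed_size=len(data),
+#                                    asbytes=True)
+#       assert roundtrip == data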
+
+
+def compress(object buf, codec='lz4', asbytes=False, memory_pool=None):
+ """
+ Compress data from buffer-like object.
+
+ Parameters
+ ----------
+ buf : pyarrow.Buffer, bytes, or other object supporting buffer protocol
+ codec : str, default 'lz4'
+ Compression codec.
+ Supported types: {'brotli', 'gzip', 'lz4', 'lz4_raw', 'snappy', 'zstd'}
+ asbytes : bool, default False
+ Return result as Python bytes object, otherwise Buffer.
+ memory_pool : MemoryPool, default None
+ Memory pool to use for buffer allocations, if any.
+
+ Returns
+ -------
+ compressed : pyarrow.Buffer or bytes (if asbytes=True)
+ """
+ cdef Codec coder = Codec(codec)
+ return coder.compress(buf, asbytes=asbytes, memory_pool=memory_pool)
+
+
+def decompress(object buf, decompressed_size=None, codec='lz4',
+ asbytes=False, memory_pool=None):
+ """
+ Decompress data from buffer-like object.
+
+ Parameters
+ ----------
+ buf : pyarrow.Buffer, bytes, or memoryview-compatible object
+ Input object to decompress data from.
+ decompressed_size : int64_t, default None
+ If not specified, will be computed if the codec is able to determine
+ the uncompressed buffer size.
+ codec : str, default 'lz4'
+ Compression codec.
+ Supported types: {'brotli', 'gzip', 'lz4', 'lz4_raw', 'snappy', 'zstd'}
+ asbytes : bool, default False
+ Return result as Python bytes object, otherwise Buffer.
+ memory_pool : MemoryPool, default None
+ Memory pool to use for buffer allocations, if any.
+
+ Returns
+ -------
+ uncompressed : pyarrow.Buffer or bytes (if asbytes=True)
+ """
+ cdef Codec decoder = Codec(codec)
+ return decoder.decompress(buf, asbytes=asbytes, memory_pool=memory_pool,
+ decompressed_size=decompressed_size)
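+
+# The module-level helpers above in one line each (note that
+# decompressed_size must be passed back explicitly):
+#
+#   import pyarrow as pa
+#   data = b'x' * 4096
+#   comp = pa.compress(data, codec='lz4', asbytes=True)
+#   out = pa.decompress(comp, decompressed_size=len(data),
+#                       codec='lz4', asbytes=True)
+#   assert out == data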
+
+
+def input_stream(source, compression='detect', buffer_size=None):
+ """
+ Create an Arrow input stream.
+
+ Parameters
+ ----------
+ source : str, Path, buffer, file-like object, ...
+ The source to open for reading.
+ compression : str, optional, default 'detect'
+ The compression algorithm to use for on-the-fly decompression.
+ If "detect" and source is a file path, then compression will be
+ chosen based on the file extension.
+ If None, no compression will be applied.
+ Otherwise, a well-known algorithm name must be supplied (e.g. "gzip").
+ buffer_size : int, default None
+ If None or 0, no buffering will happen. Otherwise the size of the
+ temporary read buffer.
+ """
+ cdef NativeFile stream
+
+ try:
+ source_path = _stringify_path(source)
+ except TypeError:
+ source_path = None
+
+ if isinstance(source, NativeFile):
+ stream = source
+ elif source_path is not None:
+ stream = OSFile(source_path, 'r')
+ elif isinstance(source, (Buffer, memoryview)):
+ stream = BufferReader(as_buffer(source))
+ elif (hasattr(source, 'read') and
+ hasattr(source, 'close') and
+ hasattr(source, 'closed')):
+ stream = PythonFile(source, 'r')
+ else:
+ raise TypeError("pa.input_stream() called with instance of '{}'"
+ .format(source.__class__))
+
+ if compression == 'detect':
+ # detect for OSFile too
+ compression = _detect_compression(source_path)
+
+ if buffer_size is not None and buffer_size != 0:
+ stream = BufferedInputStream(stream, buffer_size)
+
+ if compression is not None:
+ stream = CompressedInputStream(stream, compression)
+
+ return stream
+
+
+def output_stream(source, compression='detect', buffer_size=None):
+ """
+ Create an Arrow output stream.
+
+ Parameters
+ ----------
+ source : str, Path, buffer, file-like object, ...
+ The source to open for writing.
+ compression : str, optional, default 'detect'
+ The compression algorithm to use for on-the-fly compression.
+ If "detect" and source is a file path, then compression will be
+ chosen based on the file extension.
+ If None, no compression will be applied.
+ Otherwise, a well-known algorithm name must be supplied (e.g. "gzip").
+ buffer_size : int, default None
+ If None or 0, no buffering will happen. Otherwise the size of the
+ temporary write buffer.
+ """
+ cdef NativeFile stream
+
+ try:
+ source_path = _stringify_path(source)
+ except TypeError:
+ source_path = None
+
+ if isinstance(source, NativeFile):
+ stream = source
+ elif source_path is not None:
+ stream = OSFile(source_path, 'w')
+ elif isinstance(source, (Buffer, memoryview)):
+ stream = FixedSizeBufferWriter(as_buffer(source))
+ elif (hasattr(source, 'write') and
+ hasattr(source, 'close') and
+ hasattr(source, 'closed')):
+ stream = PythonFile(source, 'w')
+ else:
+ raise TypeError("pa.output_stream() called with instance of '{}'"
+ .format(source.__class__))
+
+ if compression == 'detect':
+ compression = _detect_compression(source_path)
+
+ if buffer_size is not None and buffer_size != 0:
+ stream = BufferedOutputStream(stream, buffer_size)
+
+ if compression is not None:
+ stream = CompressedOutputStream(stream, compression)
+
+ return stream
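+
+# A sketch of extension-based compression detection with the two factories
+# above; 'data.txt.gz' is a hypothetical local path:
+#
+#   import pyarrow as pa
+#   with pa.output_stream('data.txt.gz') as out:    # gzip chosen from '.gz'
+#       out.write(b'some text')
+#   with pa.input_stream('data.txt.gz') as f:
+#       f.read()                                    # b'some text'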
diff --git a/src/arrow/python/pyarrow/ipc.pxi b/src/arrow/python/pyarrow/ipc.pxi
new file mode 100644
index 000000000..9304bbb97
--- /dev/null
+++ b/src/arrow/python/pyarrow/ipc.pxi
@@ -0,0 +1,1009 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import namedtuple
+import warnings
+
+
+cpdef enum MetadataVersion:
+ V1 = <char> CMetadataVersion_V1
+ V2 = <char> CMetadataVersion_V2
+ V3 = <char> CMetadataVersion_V3
+ V4 = <char> CMetadataVersion_V4
+ V5 = <char> CMetadataVersion_V5
+
+
+cdef object _wrap_metadata_version(CMetadataVersion version):
+ return MetadataVersion(<char> version)
+
+
+cdef CMetadataVersion _unwrap_metadata_version(
+ MetadataVersion version) except *:
+ if version == MetadataVersion.V1:
+ return CMetadataVersion_V1
+ elif version == MetadataVersion.V2:
+ return CMetadataVersion_V2
+ elif version == MetadataVersion.V3:
+ return CMetadataVersion_V3
+ elif version == MetadataVersion.V4:
+ return CMetadataVersion_V4
+ elif version == MetadataVersion.V5:
+ return CMetadataVersion_V5
+ raise ValueError("Not a metadata version: " + repr(version))
+
+
+_WriteStats = namedtuple(
+ 'WriteStats',
+ ('num_messages', 'num_record_batches', 'num_dictionary_batches',
+ 'num_dictionary_deltas', 'num_replaced_dictionaries'))
+
+
+class WriteStats(_WriteStats):
+ """IPC write statistics
+
+ Parameters
+ ----------
+ num_messages : number of messages.
+ num_record_batches : number of record batches.
+ num_dictionary_batches : number of dictionary batches.
+ num_dictionary_deltas : number of dictionary deltas.
+ num_replaced_dictionaries : number of replaced dictionaries.
+ """
+ __slots__ = ()
+
+
+@staticmethod
+cdef _wrap_write_stats(CIpcWriteStats c):
+ return WriteStats(c.num_messages, c.num_record_batches,
+ c.num_dictionary_batches, c.num_dictionary_deltas,
+ c.num_replaced_dictionaries)
+
+
+_ReadStats = namedtuple(
+ 'ReadStats',
+ ('num_messages', 'num_record_batches', 'num_dictionary_batches',
+ 'num_dictionary_deltas', 'num_replaced_dictionaries'))
+
+
+class ReadStats(_ReadStats):
+ """IPC read statistics
+
+ Parameters
+ ----------
+ num_messages : number of messages.
+ num_record_batches : number of record batches.
+ num_dictionary_batches : number of dictionary batches.
+ num_dictionary_deltas : number of dictionary deltas.
+ num_replaced_dictionaries : number of replaced dictionaries.
+ """
+ __slots__ = ()
+
+
+@staticmethod
+cdef _wrap_read_stats(CIpcReadStats c):
+ return ReadStats(c.num_messages, c.num_record_batches,
+ c.num_dictionary_batches, c.num_dictionary_deltas,
+ c.num_replaced_dictionaries)
+
+
+cdef class IpcWriteOptions(_Weakrefable):
+ """
+ Serialization options for the IPC format.
+
+ Parameters
+ ----------
+ metadata_version : MetadataVersion, default MetadataVersion.V5
+ The metadata version to write. V5 is the current and latest,
+ V4 is the pre-1.0 metadata version (with incompatible Union layout).
+ allow_64bit : bool, default False
+ If true, allow field lengths that don't fit in a signed 32-bit int.
+ use_legacy_format : bool, default False
+ Whether to use the pre-Arrow 0.15 IPC format.
+ compression : str, Codec, or None
+ Compression codec to use for record batch buffers.
+ If None then batch buffers will be uncompressed.
+ Must be "lz4", "zstd" or None.
+ To specify a compression_level, use `pyarrow.Codec`.
+ use_threads : bool
+ Whether to use the global CPU thread pool to parallelize any
+ computational tasks like compression.
+ emit_dictionary_deltas : bool
+ Whether to emit dictionary deltas. Default is false for maximum
+ stream compatibility.
+ """
+ __slots__ = ()
+
+ # cdef block is in lib.pxd
+
+ def __init__(self, *, metadata_version=MetadataVersion.V5,
+ bint allow_64bit=False, use_legacy_format=False,
+ compression=None, bint use_threads=True,
+ bint emit_dictionary_deltas=False):
+ self.c_options = CIpcWriteOptions.Defaults()
+ self.allow_64bit = allow_64bit
+ self.use_legacy_format = use_legacy_format
+ self.metadata_version = metadata_version
+ if compression is not None:
+ self.compression = compression
+ self.use_threads = use_threads
+ self.emit_dictionary_deltas = emit_dictionary_deltas
+
+ @property
+ def allow_64bit(self):
+ return self.c_options.allow_64bit
+
+ @allow_64bit.setter
+ def allow_64bit(self, bint value):
+ self.c_options.allow_64bit = value
+
+ @property
+ def use_legacy_format(self):
+ return self.c_options.write_legacy_ipc_format
+
+ @use_legacy_format.setter
+ def use_legacy_format(self, bint value):
+ self.c_options.write_legacy_ipc_format = value
+
+ @property
+ def metadata_version(self):
+ return _wrap_metadata_version(self.c_options.metadata_version)
+
+ @metadata_version.setter
+ def metadata_version(self, value):
+ self.c_options.metadata_version = _unwrap_metadata_version(value)
+
+ @property
+ def compression(self):
+ if self.c_options.codec == nullptr:
+ return None
+ else:
+ return frombytes(self.c_options.codec.get().name())
+
+ @compression.setter
+ def compression(self, value):
+ if value is None:
+ self.c_options.codec.reset()
+ elif isinstance(value, str):
+ self.c_options.codec = shared_ptr[CCodec](GetResultValue(
+ CCodec.Create(_ensure_compression(value))).release())
+ elif isinstance(value, Codec):
+ self.c_options.codec = (<Codec>value).wrapped
+ else:
+ raise TypeError(
+ "Property `compression` must be None, str, or pyarrow.Codec")
+
+ @property
+ def use_threads(self):
+ return self.c_options.use_threads
+
+ @use_threads.setter
+ def use_threads(self, bint value):
+ self.c_options.use_threads = value
+
+ @property
+ def emit_dictionary_deltas(self):
+ return self.c_options.emit_dictionary_deltas
+
+ @emit_dictionary_deltas.setter
+ def emit_dictionary_deltas(self, bint value):
+ self.c_options.emit_dictionary_deltas = value
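+
+# A sketch of passing these options through the public pa.ipc.new_stream()
+# factory (assumes the zstd codec is available in this build):
+#
+#   import pyarrow as pa
+#   batch = pa.record_batch([pa.array([1, 2, 3])], names=['x'])
+#   opts = pa.ipc.IpcWriteOptions(compression='zstd')
+#   sink = pa.BufferOutputStream()
+#   with pa.ipc.new_stream(sink, batch.schema, options=opts) as writer:
+#       writer.write_batch(batch)
+#   pa.ipc.open_stream(sink.getvalue()).read_all()  # decompressed transparently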
+
+
+cdef class Message(_Weakrefable):
+ """
+ Container for an Arrow IPC message with metadata and optional body
+ """
+
+ def __cinit__(self):
+ pass
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use "
+ "`pyarrow.ipc.read_message` function instead."
+ .format(self.__class__.__name__))
+
+ @property
+ def type(self):
+ return frombytes(FormatMessageType(self.message.get().type()))
+
+ @property
+ def metadata(self):
+ return pyarrow_wrap_buffer(self.message.get().metadata())
+
+ @property
+ def metadata_version(self):
+ return _wrap_metadata_version(self.message.get().metadata_version())
+
+ @property
+ def body(self):
+ cdef shared_ptr[CBuffer] body = self.message.get().body()
+ if body.get() == NULL:
+ return None
+ else:
+ return pyarrow_wrap_buffer(body)
+
+ def equals(self, Message other):
+ """
+ Returns True if the message contents (metadata and body) are identical
+
+ Parameters
+ ----------
+ other : Message
+
+ Returns
+ -------
+ are_equal : bool
+ """
+ cdef c_bool result
+ with nogil:
+ result = self.message.get().Equals(deref(other.message.get()))
+ return result
+
+ def serialize_to(self, NativeFile sink, alignment=8, memory_pool=None):
+ """
+ Write message to generic OutputStream
+
+ Parameters
+ ----------
+ sink : NativeFile
+ alignment : int, default 8
+ Byte alignment for metadata and body
+ memory_pool : MemoryPool, default None
+ Uses default memory pool if not specified
+ """
+ cdef:
+ int64_t output_length = 0
+ COutputStream* out
+ CIpcWriteOptions options
+
+ options.alignment = alignment
+ out = sink.get_output_stream().get()
+ with nogil:
+ check_status(self.message.get()
+ .SerializeTo(out, options, &output_length))
+
+ def serialize(self, alignment=8, memory_pool=None):
+ """
+ Write message as encapsulated IPC message
+
+ Parameters
+ ----------
+ alignment : int, default 8
+ Byte alignment for metadata and body
+ memory_pool : MemoryPool, default None
+ Uses default memory pool if not specified
+
+ Returns
+ -------
+ serialized : Buffer
+ """
+ stream = BufferOutputStream(memory_pool)
+ self.serialize_to(stream, alignment=alignment, memory_pool=memory_pool)
+ return stream.getvalue()
+
+ def __repr__(self):
+ if self.message == nullptr:
+ return """pyarrow.Message(uninitialized)"""
+
+ metadata_len = self.metadata.size
+ body = self.body
+ body_len = 0 if body is None else body.size
+
+ return """pyarrow.Message
+type: {0}
+metadata length: {1}
+body length: {2}""".format(self.type, metadata_len, body_len)
+
+
+cdef class MessageReader(_Weakrefable):
+ """
+ Interface for reading Message objects from some source (like an
+ InputStream)
+ """
+ cdef:
+ unique_ptr[CMessageReader] reader
+
+ def __cinit__(self):
+ pass
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use "
+ "`pyarrow.ipc.MessageReader.open_stream` function "
+ "instead.".format(self.__class__.__name__))
+
+ @staticmethod
+ def open_stream(source):
+ """
+ Open stream from source.
+
+ Parameters
+ ----------
+ source : a readable source, like an InputStream
+ """
+ cdef:
+ MessageReader result = MessageReader.__new__(MessageReader)
+ shared_ptr[CInputStream] in_stream
+ unique_ptr[CMessageReader] reader
+
+ _get_input_stream(source, &in_stream)
+ with nogil:
+ reader = CMessageReader.Open(in_stream)
+ result.reader.reset(reader.release())
+
+ return result
+
+ def __iter__(self):
+ while True:
+ yield self.read_next_message()
+
+ def read_next_message(self):
+ """
+ Read next Message from the stream.
+
+ Raises
+ ------
+ StopIteration : at end of stream
+ """
+ cdef Message result = Message.__new__(Message)
+
+ with nogil:
+ result.message = move(GetResultValue(self.reader.get()
+ .ReadNextMessage()))
+
+ if result.message.get() == NULL:
+ raise StopIteration
+
+ return result
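+
+# A sketch of iterating raw IPC messages with the reader above; the first
+# message of a stream is always the schema:
+#
+#   import pyarrow as pa
+#   batch = pa.record_batch([pa.array([1, 2, 3])], names=['x'])
+#   sink = pa.BufferOutputStream()
+#   with pa.ipc.new_stream(sink, batch.schema) as writer:
+#       writer.write_batch(batch)
+#   reader = pa.ipc.MessageReader.open_stream(sink.getvalue())
+#   first = reader.read_next_message()
+#   first.type                                   # 'schema'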
+
+# ----------------------------------------------------------------------
+# File and stream readers and writers
+
+cdef class _CRecordBatchWriter(_Weakrefable):
+ """The base RecordBatchWriter wrapper.
+
+ Provides common implementations of convenience methods. Should not
+ be instantiated directly by user code.
+ """
+
+ # cdef block is in lib.pxd
+
+ def write(self, table_or_batch):
+ """
+ Write RecordBatch or Table to stream.
+
+ Parameters
+ ----------
+ table_or_batch : {RecordBatch, Table}
+ """
+ if isinstance(table_or_batch, RecordBatch):
+ self.write_batch(table_or_batch)
+ elif isinstance(table_or_batch, Table):
+ self.write_table(table_or_batch)
+ else:
+ raise ValueError(type(table_or_batch))
+
+ def write_batch(self, RecordBatch batch):
+ """
+ Write RecordBatch to stream.
+
+ Parameters
+ ----------
+ batch : RecordBatch
+ """
+ with nogil:
+ check_status(self.writer.get()
+ .WriteRecordBatch(deref(batch.batch)))
+
+ def write_table(self, Table table, max_chunksize=None, **kwargs):
+ """
+ Write Table to stream in (contiguous) RecordBatch objects.
+
+ Parameters
+ ----------
+ table : Table
+ max_chunksize : int, default None
+ Maximum size for RecordBatch chunks. Individual chunks may be
+ smaller depending on the chunk layout of individual columns.
+ """
+ cdef:
+ # max_chunksize must be > 0 to have any impact
+ int64_t c_max_chunksize = -1
+
+ if 'chunksize' in kwargs:
+ max_chunksize = kwargs['chunksize']
+ msg = ('The parameter chunksize is deprecated for the write_table '
+ 'methods as of 0.15, please use parameter '
+ 'max_chunksize instead')
+ warnings.warn(msg, FutureWarning)
+
+ if max_chunksize is not None:
+ c_max_chunksize = max_chunksize
+
+ with nogil:
+ check_status(self.writer.get().WriteTable(table.table[0],
+ c_max_chunksize))
+
+ def close(self):
+ """
+ Close stream and write end-of-stream 0 marker.
+ """
+ with nogil:
+ check_status(self.writer.get().Close())
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ @property
+ def stats(self):
+ """
+ Current IPC write statistics.
+ """
+ if not self.writer:
+ raise ValueError("Operation on closed writer")
+ return _wrap_write_stats(self.writer.get().stats())
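+
+# A sketch of the write_table()/stats API on the writer wrapper above, as
+# obtained from the public pa.ipc.new_stream() factory:
+#
+#   import pyarrow as pa
+#   table = pa.table({'x': list(range(5))})
+#   sink = pa.BufferOutputStream()
+#   with pa.ipc.new_stream(sink, table.schema) as writer:
+#       writer.write_table(table, max_chunksize=2)   # 2-row chunks
+#       writer.stats.num_record_batches               # 3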
+
+
+cdef class _RecordBatchStreamWriter(_CRecordBatchWriter):
+ cdef:
+ CIpcWriteOptions options
+ bint closed
+
+ def __cinit__(self):
+ pass
+
+ def __dealloc__(self):
+ pass
+
+ @property
+ def _use_legacy_format(self):
+ # For testing (see test_ipc.py)
+ return self.options.write_legacy_ipc_format
+
+ @property
+ def _metadata_version(self):
+ # For testing (see test_ipc.py)
+ return _wrap_metadata_version(self.options.metadata_version)
+
+ def _open(self, sink, Schema schema not None,
+ IpcWriteOptions options=IpcWriteOptions()):
+ cdef:
+ shared_ptr[COutputStream] c_sink
+
+ self.options = options.c_options
+ get_writer(sink, &c_sink)
+ with nogil:
+ self.writer = GetResultValue(
+ MakeStreamWriter(c_sink, schema.sp_schema,
+ self.options))
+
+
+cdef _get_input_stream(object source, shared_ptr[CInputStream]* out):
+ try:
+ source = as_buffer(source)
+ except TypeError:
+ # Non-buffer-like
+ pass
+
+ get_input_stream(source, True, out)
+
+
+class _ReadPandasMixin:
+
+ def read_pandas(self, **options):
+ """
+ Read contents of stream to a pandas.DataFrame.
+
+ Read all record batches as a pyarrow.Table then convert it to a
+ pandas.DataFrame using Table.to_pandas.
+
+ Parameters
+ ----------
+ **options : arguments to forward to Table.to_pandas
+
+ Returns
+ -------
+ df : pandas.DataFrame
+ """
+ table = self.read_all()
+ return table.to_pandas(**options)
+
+
+cdef class RecordBatchReader(_Weakrefable):
+ """Base class for reading stream of record batches.
+
+ Provides common implementations of convenience methods. Should not
+ be instantiated directly by user code.
+ """
+
+ # cdef block is in lib.pxd
+
+ def __iter__(self):
+ while True:
+ try:
+ yield self.read_next_batch()
+ except StopIteration:
+ return
+
+ @property
+ def schema(self):
+ """
+ Shared schema of the record batches in the stream.
+ """
+ cdef shared_ptr[CSchema] c_schema
+
+ with nogil:
+ c_schema = self.reader.get().schema()
+
+ return pyarrow_wrap_schema(c_schema)
+
+ def get_next_batch(self):
+ warnings.warn('Please use read_next_batch instead of '
+ 'get_next_batch', FutureWarning)
+ return self.read_next_batch()
+
+ def read_next_batch(self):
+ """
+ Read next RecordBatch from the stream.
+
+ Raises
+ ------
+ StopIteration:
+ At end of stream.
+ """
+ cdef shared_ptr[CRecordBatch] batch
+
+ with nogil:
+ check_status(self.reader.get().ReadNext(&batch))
+
+ if batch.get() == NULL:
+ raise StopIteration
+
+ return pyarrow_wrap_batch(batch)
+
+ def read_all(self):
+ """
+ Read all record batches as a pyarrow.Table.
+ """
+ cdef shared_ptr[CTable] table
+ with nogil:
+ check_status(self.reader.get().ReadAll(&table))
+ return pyarrow_wrap_table(table)
+
+ read_pandas = _ReadPandasMixin.read_pandas
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ def _export_to_c(self, uintptr_t out_ptr):
+ """
+ Export to a C ArrowArrayStream struct, given its pointer.
+
+ Parameters
+ ----------
+ out_ptr: int
+ The raw pointer to a C ArrowArrayStream struct.
+
+ Be careful: if you don't pass the ArrowArrayStream struct to a
+ consumer, array memory will leak. This is a low-level function
+ intended for expert users.
+ """
+ with nogil:
+ check_status(ExportRecordBatchReader(
+ self.reader, <ArrowArrayStream*> out_ptr))
+
+ @staticmethod
+ def _import_from_c(uintptr_t in_ptr):
+ """
+ Import RecordBatchReader from a C ArrowArrayStream struct,
+ given its pointer.
+
+ Parameters
+ ----------
+ in_ptr: int
+ The raw pointer to a C ArrowArrayStream struct.
+
+ This is a low-level function intended for expert users.
+ """
+ cdef:
+ shared_ptr[CRecordBatchReader] c_reader
+ RecordBatchReader self
+
+ with nogil:
+ c_reader = GetResultValue(ImportRecordBatchReader(
+ <ArrowArrayStream*> in_ptr))
+
+ self = RecordBatchReader.__new__(RecordBatchReader)
+ self.reader = c_reader
+ return self
+
+ @staticmethod
+ def from_batches(schema, batches):
+ """
+ Create RecordBatchReader from an iterable of batches.
+
+ Parameters
+ ----------
+ schema : Schema
+ The shared schema of the record batches
+ batches : Iterable[RecordBatch]
+ The batches that this reader will return.
+
+ Returns
+ -------
+ reader : RecordBatchReader
+ """
+ cdef:
+ shared_ptr[CSchema] c_schema
+ shared_ptr[CRecordBatchReader] c_reader
+ RecordBatchReader self
+
+ c_schema = pyarrow_unwrap_schema(schema)
+ c_reader = GetResultValue(CPyRecordBatchReader.Make(
+ c_schema, batches))
+
+ self = RecordBatchReader.__new__(RecordBatchReader)
+ self.reader = c_reader
+ return self
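+
+# A minimal sketch of building a reader from in-memory batches:
+#
+#   import pyarrow as pa
+#   batch = pa.record_batch([pa.array([1, 2, 3])], names=['x'])
+#   reader = pa.RecordBatchReader.from_batches(batch.schema, [batch, batch])
+#   reader.read_all().num_rows    # 6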
+
+
+cdef class _RecordBatchStreamReader(RecordBatchReader):
+ cdef:
+ shared_ptr[CInputStream] in_stream
+ CIpcReadOptions options
+ CRecordBatchStreamReader* stream_reader
+
+ def __cinit__(self):
+ pass
+
+ def _open(self, source):
+ _get_input_stream(source, &self.in_stream)
+ with nogil:
+ self.reader = GetResultValue(CRecordBatchStreamReader.Open(
+ self.in_stream, self.options))
+ self.stream_reader = <CRecordBatchStreamReader*> self.reader.get()
+
+ @property
+ def stats(self):
+ """
+ Current IPC read statistics.
+ """
+ if not self.reader:
+ raise ValueError("Operation on closed reader")
+ return _wrap_read_stats(self.stream_reader.stats())
+
+
+cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter):
+
+ def _open(self, sink, Schema schema not None,
+ IpcWriteOptions options=IpcWriteOptions()):
+ cdef:
+ shared_ptr[COutputStream] c_sink
+
+ self.options = options.c_options
+ get_writer(sink, &c_sink)
+ with nogil:
+ self.writer = GetResultValue(
+ MakeFileWriter(c_sink, schema.sp_schema, self.options))
+
+
+cdef class _RecordBatchFileReader(_Weakrefable):
+ cdef:
+ shared_ptr[CRecordBatchFileReader] reader
+ shared_ptr[CRandomAccessFile] file
+ CIpcReadOptions options
+
+ cdef readonly:
+ Schema schema
+
+ def __cinit__(self):
+ pass
+
+ def _open(self, source, footer_offset=None):
+ try:
+ source = as_buffer(source)
+ except TypeError:
+ pass
+
+ get_reader(source, True, &self.file)
+
+ cdef int64_t offset = 0
+ if footer_offset is not None:
+ offset = footer_offset
+
+ with nogil:
+ if offset != 0:
+ self.reader = GetResultValue(
+ CRecordBatchFileReader.Open2(self.file.get(), offset,
+ self.options))
+
+ else:
+ self.reader = GetResultValue(
+ CRecordBatchFileReader.Open(self.file.get(),
+ self.options))
+
+ self.schema = pyarrow_wrap_schema(self.reader.get().schema())
+
+ @property
+ def num_record_batches(self):
+ return self.reader.get().num_record_batches()
+
+ def get_batch(self, int i):
+ cdef shared_ptr[CRecordBatch] batch
+
+ if i < 0 or i >= self.num_record_batches:
+ raise ValueError('Batch number {0} out of range'.format(i))
+
+ with nogil:
+ batch = GetResultValue(self.reader.get().ReadRecordBatch(i))
+
+ return pyarrow_wrap_batch(batch)
+
+ # TODO(wesm): ARROW-503: Function was renamed. Remove after a period of
+ # time has passed
+ get_record_batch = get_batch
+
+ def read_all(self):
+ """
+ Read all record batches as a pyarrow.Table
+ """
+ cdef:
+ vector[shared_ptr[CRecordBatch]] batches
+ shared_ptr[CTable] table
+ int i, nbatches
+
+ nbatches = self.num_record_batches
+
+ batches.resize(nbatches)
+ with nogil:
+ for i in range(nbatches):
+ batches[i] = GetResultValue(self.reader.get()
+ .ReadRecordBatch(i))
+ table = GetResultValue(
+ CTable.FromRecordBatches(self.schema.sp_schema, move(batches)))
+
+ return pyarrow_wrap_table(table)
+
+ read_pandas = _ReadPandasMixin.read_pandas
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+ @property
+ def stats(self):
+ """
+ Current IPC read statistics.
+ """
+ if not self.reader:
+ raise ValueError("Operation on closed reader")
+ return _wrap_read_stats(self.reader.get().stats())
+
+
+def get_tensor_size(Tensor tensor):
+ """
+ Return total size of serialized Tensor including metadata and padding.
+
+ Parameters
+ ----------
+ tensor : Tensor
+        The tensor for which we want to know the size.
+ """
+ cdef int64_t size
+ with nogil:
+ check_status(GetTensorSize(deref(tensor.tp), &size))
+ return size
+
+
+def get_record_batch_size(RecordBatch batch):
+ """
+ Return total size of serialized RecordBatch including metadata and padding.
+
+ Parameters
+ ----------
+ batch : RecordBatch
+        The record batch for which we want to know the size.
+ """
+ cdef int64_t size
+ with nogil:
+ check_status(GetRecordBatchSize(deref(batch.batch), &size))
+ return size
+
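A short, illustrative sketch of the two size helpers above, assuming NumPy is
available alongside pyarrow:

    import numpy as np
    import pyarrow as pa
    import pyarrow.ipc as ipc

    batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"])
    tensor = pa.Tensor.from_numpy(np.arange(6, dtype=np.int64).reshape(2, 3))

    # Size of each object once serialized, including IPC metadata and padding.
    print(ipc.get_record_batch_size(batch))
    print(ipc.get_tensor_size(tensor))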
+
+def write_tensor(Tensor tensor, NativeFile dest):
+ """
+    Write a pyarrow.Tensor to a pyarrow.NativeFile object at its current
+    position.
+
+ Parameters
+ ----------
+ tensor : pyarrow.Tensor
+ dest : pyarrow.NativeFile
+
+ Returns
+ -------
+ bytes_written : int
+ Total number of bytes written to the file
+ """
+ cdef:
+ int32_t metadata_length
+ int64_t body_length
+
+ handle = dest.get_output_stream()
+
+ with nogil:
+ check_status(
+ WriteTensor(deref(tensor.tp), handle.get(),
+ &metadata_length, &body_length))
+
+ return metadata_length + body_length
+
+
+cdef NativeFile as_native_file(source):
+ if not isinstance(source, NativeFile):
+ if hasattr(source, 'read'):
+ source = PythonFile(source)
+ else:
+ source = BufferReader(source)
+
+ if not isinstance(source, NativeFile):
+ raise ValueError('Unable to read message from object with type: {0}'
+ .format(type(source)))
+ return source
+
+
+def read_tensor(source):
+    """Read a pyarrow.Tensor from a pyarrow.NativeFile object at its current
+    position. If the file source supports zero copy (e.g. a memory map), then
+    this operation does not allocate any memory. This function does not
+    assume that the stream is aligned.
+
+ Parameters
+ ----------
+ source : pyarrow.NativeFile
+
+ Returns
+ -------
+ tensor : Tensor
+
+ """
+ cdef:
+ shared_ptr[CTensor] sp_tensor
+ CInputStream* c_stream
+ NativeFile nf = as_native_file(source)
+
+ c_stream = nf.get_input_stream().get()
+ with nogil:
+ sp_tensor = GetResultValue(ReadTensor(c_stream))
+ return pyarrow_wrap_tensor(sp_tensor)
+
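A small round-trip sketch for write_tensor()/read_tensor() against in-memory
buffers (illustrative only):

    import numpy as np
    import pyarrow as pa
    import pyarrow.ipc as ipc

    tensor = pa.Tensor.from_numpy(
        np.arange(12, dtype=np.float64).reshape(3, 4))

    # Serialize the tensor into an in-memory output stream ...
    sink = pa.BufferOutputStream()
    ipc.write_tensor(tensor, sink)

    # ... and read it back. With a zero-copy source such as a BufferReader or
    # a memory map, no data is copied.
    restored = ipc.read_tensor(pa.BufferReader(sink.getvalue()))
    assert restored.equals(tensor)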
+
+def read_message(source):
+ """
+ Read length-prefixed message from file or buffer-like object
+
+ Parameters
+ ----------
+ source : pyarrow.NativeFile, file-like object, or buffer-like object
+
+ Returns
+ -------
+ message : Message
+ """
+ cdef:
+ Message result = Message.__new__(Message)
+ CInputStream* c_stream
+
+ cdef NativeFile nf = as_native_file(source)
+ c_stream = nf.get_input_stream().get()
+
+ with nogil:
+ result.message = move(
+ GetResultValue(ReadMessage(c_stream, c_default_memory_pool())))
+
+ if result.message == nullptr:
+ raise EOFError("End of Arrow stream")
+
+ return result
+
+
+def read_schema(obj, DictionaryMemo dictionary_memo=None):
+ """
+ Read Schema from message or buffer
+
+ Parameters
+ ----------
+ obj : buffer or Message
+ dictionary_memo : DictionaryMemo, optional
+ Needed to be able to reconstruct dictionary-encoded fields
+ with read_record_batch
+
+ Returns
+ -------
+ schema : Schema
+ """
+ cdef:
+ shared_ptr[CSchema] result
+ shared_ptr[CRandomAccessFile] cpp_file
+ CDictionaryMemo temp_memo
+ CDictionaryMemo* arg_dict_memo
+
+ if isinstance(obj, Message):
+ raise NotImplementedError(type(obj))
+
+ get_reader(obj, True, &cpp_file)
+
+ if dictionary_memo is not None:
+ arg_dict_memo = dictionary_memo.memo
+ else:
+ arg_dict_memo = &temp_memo
+
+ with nogil:
+ result = GetResultValue(ReadSchema(cpp_file.get(), arg_dict_memo))
+
+ return pyarrow_wrap_schema(result)
+
+
+def read_record_batch(obj, Schema schema,
+ DictionaryMemo dictionary_memo=None):
+ """
+    Read a RecordBatch from a message, given a known schema. If reading data
+    from a complete IPC stream, use ipc.open_stream instead.
+
+ Parameters
+ ----------
+ obj : Message or Buffer-like
+ schema : Schema
+ dictionary_memo : DictionaryMemo, optional
+ If message contains dictionaries, must pass a populated
+ DictionaryMemo
+
+ Returns
+ -------
+ batch : RecordBatch
+ """
+ cdef:
+ shared_ptr[CRecordBatch] result
+ Message message
+ CDictionaryMemo temp_memo
+ CDictionaryMemo* arg_dict_memo
+
+ if isinstance(obj, Message):
+ message = obj
+ else:
+ message = read_message(obj)
+
+ if dictionary_memo is not None:
+ arg_dict_memo = dictionary_memo.memo
+ else:
+ arg_dict_memo = &temp_memo
+
+ with nogil:
+ result = GetResultValue(
+ ReadRecordBatch(deref(message.message.get()),
+ schema.sp_schema,
+ arg_dict_memo,
+ CIpcReadOptions.Defaults()))
+
+ return pyarrow_wrap_batch(result)
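As an illustration of the message-level helpers above, a single batch can be
serialized and read back, assuming no dictionary-encoded columns (those need
a populated DictionaryMemo):

    import pyarrow as pa
    import pyarrow.ipc as ipc

    batch = pa.record_batch([pa.array(["a", "b"])], names=["col"])

    # RecordBatch.serialize() produces the IPC message for a single batch;
    # read_record_batch() turns that message back into a batch, given the
    # schema.
    buf = batch.serialize()
    restored = ipc.read_record_batch(buf, batch.schema)
    assert restored.equals(batch)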
diff --git a/src/arrow/python/pyarrow/ipc.py b/src/arrow/python/pyarrow/ipc.py
new file mode 100644
index 000000000..cb28a0b5f
--- /dev/null
+++ b/src/arrow/python/pyarrow/ipc.py
@@ -0,0 +1,233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Arrow file and stream reader/writer classes, and other messaging tools
+
+import os
+
+import pyarrow as pa
+
+from pyarrow.lib import (IpcWriteOptions, ReadStats, WriteStats, # noqa
+ Message, MessageReader,
+ RecordBatchReader, _ReadPandasMixin,
+ MetadataVersion,
+ read_message, read_record_batch, read_schema,
+ read_tensor, write_tensor,
+ get_record_batch_size, get_tensor_size)
+import pyarrow.lib as lib
+
+
+class RecordBatchStreamReader(lib._RecordBatchStreamReader):
+ """
+ Reader for the Arrow streaming binary format.
+
+ Parameters
+ ----------
+ source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object
+ Either an in-memory buffer, or a readable file object.
+ """
+
+ def __init__(self, source):
+ self._open(source)
+
+
+_ipc_writer_class_doc = """\
+Parameters
+----------
+sink : str, pyarrow.NativeFile, or file-like Python object
+ Either a file path, or a writable file object.
+schema : pyarrow.Schema
+ The Arrow schema for data to be written to the file.
+options : pyarrow.ipc.IpcWriteOptions
+ Options for IPC serialization.
+
+ If None, default values will be used: the legacy format will not
+ be used unless overridden by setting the environment variable
+ ARROW_PRE_0_15_IPC_FORMAT=1, and the V5 metadata version will be
+ used unless overridden by setting the environment variable
+ ARROW_PRE_1_0_METADATA_VERSION=1.
+use_legacy_format : bool, default None
+ Deprecated in favor of setting options. Cannot be provided with
+ options.
+
+ If None, False will be used unless this default is overridden by
+ setting the environment variable ARROW_PRE_0_15_IPC_FORMAT=1"""
+
+
+class RecordBatchStreamWriter(lib._RecordBatchStreamWriter):
+ __doc__ = """Writer for the Arrow streaming binary format
+
+{}""".format(_ipc_writer_class_doc)
+
+ def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
+ options = _get_legacy_format_default(use_legacy_format, options)
+ self._open(sink, schema, options=options)
+
+
+class RecordBatchFileReader(lib._RecordBatchFileReader):
+ """
+ Class for reading Arrow record batch data from the Arrow binary file format
+
+ Parameters
+ ----------
+ source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object
+ Either an in-memory buffer, or a readable file object
+ footer_offset : int, default None
+ If the file is embedded in some larger file, this is the byte offset to
+ the very end of the file data
+ """
+
+ def __init__(self, source, footer_offset=None):
+ self._open(source, footer_offset=footer_offset)
+
+
+class RecordBatchFileWriter(lib._RecordBatchFileWriter):
+
+ __doc__ = """Writer to create the Arrow binary file format
+
+{}""".format(_ipc_writer_class_doc)
+
+ def __init__(self, sink, schema, *, use_legacy_format=None, options=None):
+ options = _get_legacy_format_default(use_legacy_format, options)
+ self._open(sink, schema, options=options)
+
+
+def _get_legacy_format_default(use_legacy_format, options):
+ if use_legacy_format is not None and options is not None:
+ raise ValueError(
+ "Can provide at most one of options and use_legacy_format")
+ elif options:
+ if not isinstance(options, IpcWriteOptions):
+ raise TypeError("expected IpcWriteOptions, got {}"
+ .format(type(options)))
+ return options
+
+ metadata_version = MetadataVersion.V5
+ if use_legacy_format is None:
+ use_legacy_format = \
+ bool(int(os.environ.get('ARROW_PRE_0_15_IPC_FORMAT', '0')))
+ if bool(int(os.environ.get('ARROW_PRE_1_0_METADATA_VERSION', '0'))):
+ metadata_version = MetadataVersion.V4
+ return IpcWriteOptions(use_legacy_format=use_legacy_format,
+ metadata_version=metadata_version)
+
+
+def new_stream(sink, schema, *, use_legacy_format=None, options=None):
+ return RecordBatchStreamWriter(sink, schema,
+ use_legacy_format=use_legacy_format,
+ options=options)
+
+
+new_stream.__doc__ = """\
+Create an Arrow columnar IPC stream writer instance
+
+{}""".format(_ipc_writer_class_doc)
+
+
+def open_stream(source):
+ """
+ Create reader for Arrow streaming format.
+
+ Parameters
+ ----------
+ source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object
+ Either an in-memory buffer, or a readable file object.
+
+ Returns
+ -------
+ reader : RecordBatchStreamReader
+ """
+ return RecordBatchStreamReader(source)
+
+
+def new_file(sink, schema, *, use_legacy_format=None, options=None):
+ return RecordBatchFileWriter(sink, schema,
+ use_legacy_format=use_legacy_format,
+ options=options)
+
+
+new_file.__doc__ = """\
+Create an Arrow columnar IPC file writer instance
+
+{}""".format(_ipc_writer_class_doc)
+
+
+def open_file(source, footer_offset=None):
+ """
+ Create reader for Arrow file format.
+
+ Parameters
+ ----------
+ source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object
+ Either an in-memory buffer, or a readable file object.
+ footer_offset : int, default None
+ If the file is embedded in some larger file, this is the byte offset to
+ the very end of the file data.
+
+ Returns
+ -------
+ reader : RecordBatchFileReader
+ """
+ return RecordBatchFileReader(source, footer_offset=footer_offset)
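A minimal file-format round trip with new_file()/open_file(), shown here
against an in-memory buffer (illustrative only):

    import pyarrow as pa
    import pyarrow.ipc as ipc

    batch = pa.record_batch([pa.array([1.0, 2.0, None])], names=["v"])

    sink = pa.BufferOutputStream()
    with ipc.new_file(sink, batch.schema) as writer:
        writer.write_batch(batch)

    with ipc.open_file(pa.BufferReader(sink.getvalue())) as reader:
        assert reader.num_record_batches == 1
        table = reader.read_all()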
+
+
+def serialize_pandas(df, *, nthreads=None, preserve_index=None):
+ """
+ Serialize a pandas DataFrame into a buffer protocol compatible object.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ nthreads : int, default None
+ Number of threads to use for conversion to Arrow, default all CPUs.
+ preserve_index : bool, default None
+ The default of None will store the index as a column, except for
+ RangeIndex which is stored as metadata only. If True, always
+ preserve the pandas index data as a column. If False, no index
+ information is saved and the result will have a default RangeIndex.
+
+ Returns
+ -------
+ buf : buffer
+ An object compatible with the buffer protocol.
+ """
+ batch = pa.RecordBatch.from_pandas(df, nthreads=nthreads,
+ preserve_index=preserve_index)
+ sink = pa.BufferOutputStream()
+ with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
+ writer.write_batch(batch)
+ return sink.getvalue()
+
+
+def deserialize_pandas(buf, *, use_threads=True):
+ """Deserialize a buffer protocol compatible object into a pandas DataFrame.
+
+ Parameters
+ ----------
+ buf : buffer
+ An object compatible with the buffer protocol.
+ use_threads : bool, default True
+ Whether to parallelize the conversion using multiple threads.
+
+ Returns
+ -------
+ df : pandas.DataFrame
+ """
+ buffer_reader = pa.BufferReader(buf)
+ with pa.RecordBatchStreamReader(buffer_reader) as reader:
+ table = reader.read_all()
+ return table.to_pandas(use_threads=use_threads)
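A sketch of the pandas round trip offered by serialize_pandas() and
deserialize_pandas(), assuming pandas is installed:

    import pandas as pd
    import pyarrow.ipc as ipc

    df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

    buf = ipc.serialize_pandas(df)            # Arrow IPC stream in a Buffer
    roundtripped = ipc.deserialize_pandas(buf)
    pd.testing.assert_frame_equal(df, roundtripped)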
diff --git a/src/arrow/python/pyarrow/json.py b/src/arrow/python/pyarrow/json.py
new file mode 100644
index 000000000..a864f5d99
--- /dev/null
+++ b/src/arrow/python/pyarrow/json.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from pyarrow._json import ReadOptions, ParseOptions, read_json # noqa
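A small illustrative use of read_json(), assuming a pyarrow build with JSON
support; the input is newline-delimited JSON, one object per line:

    import pyarrow as pa
    from pyarrow import json as pa_json

    data = b'{"x": 1, "y": "a"}\n{"x": 2, "y": "b"}\n'
    table = pa_json.read_json(pa.BufferReader(data))
    print(table.schema)   # x: int64, y: string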
diff --git a/src/arrow/python/pyarrow/jvm.py b/src/arrow/python/pyarrow/jvm.py
new file mode 100644
index 000000000..161c5ff4d
--- /dev/null
+++ b/src/arrow/python/pyarrow/jvm.py
@@ -0,0 +1,335 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Functions to interact with Arrow memory allocated by Arrow Java.
+
+These functions convert the objects holding the metadata; the actual
+data is not copied at all.
+
+This will only work with a JVM running in the same process, such as one
+provided through jpype. Modules that talk to a remote JVM, like py4j, will
+not work, as the memory addresses they report are not reachable in the
+Python process.
+"""
+
+import pyarrow as pa
+
+
+class _JvmBufferNanny:
+ """
+    An object that keeps an org.apache.arrow.memory.ArrowBuf's underlying
+ memory alive.
+ """
+ ref_manager = None
+
+ def __init__(self, jvm_buf):
+ ref_manager = jvm_buf.getReferenceManager()
+ # Will raise a java.lang.IllegalArgumentException if the buffer
+ # is already freed. It seems that exception cannot easily be
+ # caught...
+ ref_manager.retain()
+ self.ref_manager = ref_manager
+
+ def __del__(self):
+ if self.ref_manager is not None:
+ self.ref_manager.release()
+
+
+def jvm_buffer(jvm_buf):
+ """
+    Construct an Arrow buffer from an org.apache.arrow.memory.ArrowBuf
+
+ Parameters
+ ----------
+
+ jvm_buf: org.apache.arrow.memory.ArrowBuf
+ Arrow Buffer representation on the JVM.
+
+ Returns
+ -------
+ pyarrow.Buffer
+ Python Buffer that references the JVM memory.
+ """
+ nanny = _JvmBufferNanny(jvm_buf)
+ address = jvm_buf.memoryAddress()
+ size = jvm_buf.capacity()
+ return pa.foreign_buffer(address, size, base=nanny)
+
+
+def _from_jvm_int_type(jvm_type):
+ """
+ Convert a JVM int type to its Python equivalent.
+
+ Parameters
+ ----------
+ jvm_type : org.apache.arrow.vector.types.pojo.ArrowType$Int
+
+ Returns
+ -------
+ typ : pyarrow.DataType
+ """
+
+ bit_width = jvm_type.getBitWidth()
+ if jvm_type.getIsSigned():
+ if bit_width == 8:
+ return pa.int8()
+ elif bit_width == 16:
+ return pa.int16()
+ elif bit_width == 32:
+ return pa.int32()
+ elif bit_width == 64:
+ return pa.int64()
+ else:
+ if bit_width == 8:
+ return pa.uint8()
+ elif bit_width == 16:
+ return pa.uint16()
+ elif bit_width == 32:
+ return pa.uint32()
+ elif bit_width == 64:
+ return pa.uint64()
+
+
+def _from_jvm_float_type(jvm_type):
+ """
+ Convert a JVM float type to its Python equivalent.
+
+ Parameters
+ ----------
+ jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$FloatingPoint
+
+ Returns
+ -------
+ typ: pyarrow.DataType
+ """
+ precision = jvm_type.getPrecision().toString()
+ if precision == 'HALF':
+ return pa.float16()
+ elif precision == 'SINGLE':
+ return pa.float32()
+ elif precision == 'DOUBLE':
+ return pa.float64()
+
+
+def _from_jvm_time_type(jvm_type):
+ """
+ Convert a JVM time type to its Python equivalent.
+
+ Parameters
+ ----------
+ jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Time
+
+ Returns
+ -------
+ typ: pyarrow.DataType
+ """
+ time_unit = jvm_type.getUnit().toString()
+ if time_unit == 'SECOND':
+ assert jvm_type.getBitWidth() == 32
+ return pa.time32('s')
+ elif time_unit == 'MILLISECOND':
+ assert jvm_type.getBitWidth() == 32
+ return pa.time32('ms')
+ elif time_unit == 'MICROSECOND':
+ assert jvm_type.getBitWidth() == 64
+ return pa.time64('us')
+ elif time_unit == 'NANOSECOND':
+ assert jvm_type.getBitWidth() == 64
+ return pa.time64('ns')
+
+
+def _from_jvm_timestamp_type(jvm_type):
+ """
+ Convert a JVM timestamp type to its Python equivalent.
+
+ Parameters
+ ----------
+ jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Timestamp
+
+ Returns
+ -------
+ typ: pyarrow.DataType
+ """
+ time_unit = jvm_type.getUnit().toString()
+ timezone = jvm_type.getTimezone()
+ if timezone is not None:
+ timezone = str(timezone)
+ if time_unit == 'SECOND':
+ return pa.timestamp('s', tz=timezone)
+ elif time_unit == 'MILLISECOND':
+ return pa.timestamp('ms', tz=timezone)
+ elif time_unit == 'MICROSECOND':
+ return pa.timestamp('us', tz=timezone)
+ elif time_unit == 'NANOSECOND':
+ return pa.timestamp('ns', tz=timezone)
+
+
+def _from_jvm_date_type(jvm_type):
+ """
+ Convert a JVM date type to its Python equivalent
+
+ Parameters
+ ----------
+ jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Date
+
+ Returns
+ -------
+ typ: pyarrow.DataType
+ """
+ day_unit = jvm_type.getUnit().toString()
+ if day_unit == 'DAY':
+ return pa.date32()
+ elif day_unit == 'MILLISECOND':
+ return pa.date64()
+
+
+def field(jvm_field):
+ """
+    Construct a Field from an org.apache.arrow.vector.types.pojo.Field
+ instance.
+
+ Parameters
+ ----------
+ jvm_field: org.apache.arrow.vector.types.pojo.Field
+
+ Returns
+ -------
+ pyarrow.Field
+ """
+ name = str(jvm_field.getName())
+ jvm_type = jvm_field.getType()
+
+ typ = None
+ if not jvm_type.isComplex():
+ type_str = jvm_type.getTypeID().toString()
+ if type_str == 'Null':
+ typ = pa.null()
+ elif type_str == 'Int':
+ typ = _from_jvm_int_type(jvm_type)
+ elif type_str == 'FloatingPoint':
+ typ = _from_jvm_float_type(jvm_type)
+ elif type_str == 'Utf8':
+ typ = pa.string()
+ elif type_str == 'Binary':
+ typ = pa.binary()
+ elif type_str == 'FixedSizeBinary':
+ typ = pa.binary(jvm_type.getByteWidth())
+ elif type_str == 'Bool':
+ typ = pa.bool_()
+ elif type_str == 'Time':
+ typ = _from_jvm_time_type(jvm_type)
+ elif type_str == 'Timestamp':
+ typ = _from_jvm_timestamp_type(jvm_type)
+ elif type_str == 'Date':
+ typ = _from_jvm_date_type(jvm_type)
+ elif type_str == 'Decimal':
+ typ = pa.decimal128(jvm_type.getPrecision(), jvm_type.getScale())
+ else:
+ raise NotImplementedError(
+ "Unsupported JVM type: {}".format(type_str))
+ else:
+ # TODO: The following JVM types are not implemented:
+ # Struct, List, FixedSizeList, Union, Dictionary
+ raise NotImplementedError(
+ "JVM field conversion only implemented for primitive types.")
+
+ nullable = jvm_field.isNullable()
+ jvm_metadata = jvm_field.getMetadata()
+ if jvm_metadata.isEmpty():
+ metadata = None
+ else:
+ metadata = {str(entry.getKey()): str(entry.getValue())
+ for entry in jvm_metadata.entrySet()}
+ return pa.field(name, typ, nullable, metadata)
+
+
+def schema(jvm_schema):
+ """
+    Construct a Schema from an org.apache.arrow.vector.types.pojo.Schema
+ instance.
+
+ Parameters
+ ----------
+ jvm_schema: org.apache.arrow.vector.types.pojo.Schema
+
+ Returns
+ -------
+ pyarrow.Schema
+ """
+ fields = jvm_schema.getFields()
+ fields = [field(f) for f in fields]
+ jvm_metadata = jvm_schema.getCustomMetadata()
+ if jvm_metadata.isEmpty():
+ metadata = None
+ else:
+ metadata = {str(entry.getKey()): str(entry.getValue())
+ for entry in jvm_metadata.entrySet()}
+ return pa.schema(fields, metadata)
+
+
+def array(jvm_array):
+ """
+    Construct a (Python) Array from its JVM equivalent.
+
+ Parameters
+ ----------
+ jvm_array : org.apache.arrow.vector.ValueVector
+
+ Returns
+ -------
+ array : Array
+ """
+ if jvm_array.getField().getType().isComplex():
+ minor_type_str = jvm_array.getMinorType().toString()
+ raise NotImplementedError(
+ "Cannot convert JVM Arrow array of type {},"
+ " complex types not yet implemented.".format(minor_type_str))
+ dtype = field(jvm_array.getField()).type
+ buffers = [jvm_buffer(buf)
+ for buf in list(jvm_array.getBuffers(False))]
+
+ # If JVM has an empty Vector, buffer list will be empty so create manually
+ if len(buffers) == 0:
+ return pa.array([], type=dtype)
+
+ length = jvm_array.getValueCount()
+ null_count = jvm_array.getNullCount()
+ return pa.Array.from_buffers(dtype, length, buffers, null_count)
+
+
+def record_batch(jvm_vector_schema_root):
+ """
+ Construct a (Python) RecordBatch from a JVM VectorSchemaRoot
+
+ Parameters
+ ----------
+ jvm_vector_schema_root : org.apache.arrow.vector.VectorSchemaRoot
+
+ Returns
+ -------
+ record_batch: pyarrow.RecordBatch
+ """
+ pa_schema = schema(jvm_vector_schema_root.getSchema())
+
+ arrays = []
+ for name in pa_schema.names:
+ arrays.append(array(jvm_vector_schema_root.getVector(name)))
+
+ return pa.RecordBatch.from_arrays(
+ arrays,
+ pa_schema.names,
+ metadata=pa_schema.metadata
+ )
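A rough sketch of the intended workflow, assuming jpype and the Arrow Java
libraries are available in this process; the jar path and the Java-side
vector construction below are placeholders, not prescribed by this module:

    import jpype
    import jpype.imports
    import pyarrow.jvm as pa_jvm

    # Start a JVM in this process with the Arrow Java jars on the classpath.
    jpype.startJVM(classpath=["/path/to/arrow-java-jars/*"])

    from org.apache.arrow.memory import RootAllocator    # noqa: E402
    from org.apache.arrow.vector import IntVector        # noqa: E402

    allocator = RootAllocator(2 ** 30)   # 1 GiB allocation limit
    vector = IntVector("ints", allocator)
    vector.allocateNew(3)
    for i, v in enumerate([1, 2, 3]):
        vector.setSafe(i, v)
    vector.setValueCount(3)

    # Zero-copy view of the JVM-allocated memory as a pyarrow.Array.
    arr = pa_jvm.array(vector)
    print(arr)   # [1, 2, 3]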
diff --git a/src/arrow/python/pyarrow/lib.pxd b/src/arrow/python/pyarrow/lib.pxd
new file mode 100644
index 000000000..e3b07f404
--- /dev/null
+++ b/src/arrow/python/pyarrow/lib.pxd
@@ -0,0 +1,604 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: language_level = 3
+
+from cpython cimport PyObject
+from libcpp cimport nullptr
+from libcpp.cast cimport dynamic_cast
+from pyarrow.includes.common cimport *
+from pyarrow.includes.libarrow cimport *
+
+
+cdef extern from "Python.h":
+ int PySlice_Check(object)
+
+
+cdef int check_status(const CStatus& status) nogil except -1
+
+
+cdef class _Weakrefable:
+ cdef object __weakref__
+
+
+cdef class IpcWriteOptions(_Weakrefable):
+ cdef:
+ CIpcWriteOptions c_options
+
+
+cdef class Message(_Weakrefable):
+ cdef:
+ unique_ptr[CMessage] message
+
+
+cdef class MemoryPool(_Weakrefable):
+ cdef:
+ CMemoryPool* pool
+
+ cdef void init(self, CMemoryPool* pool)
+
+
+cdef CMemoryPool* maybe_unbox_memory_pool(MemoryPool memory_pool)
+
+
+cdef class DataType(_Weakrefable):
+ cdef:
+ shared_ptr[CDataType] sp_type
+ CDataType* type
+ bytes pep3118_format
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *
+ cdef Field field(self, int i)
+
+
+cdef class ListType(DataType):
+ cdef:
+ const CListType* list_type
+
+
+cdef class LargeListType(DataType):
+ cdef:
+ const CLargeListType* list_type
+
+
+cdef class MapType(DataType):
+ cdef:
+ const CMapType* map_type
+
+
+cdef class FixedSizeListType(DataType):
+ cdef:
+ const CFixedSizeListType* list_type
+
+
+cdef class StructType(DataType):
+ cdef:
+ const CStructType* struct_type
+
+ cdef Field field_by_name(self, name)
+
+
+cdef class DictionaryMemo(_Weakrefable):
+ cdef:
+ # Even though the CDictionaryMemo instance is private, we allocate
+ # it on the heap so as to avoid C++ ABI issues with Python wheels.
+ shared_ptr[CDictionaryMemo] sp_memo
+ CDictionaryMemo* memo
+
+
+cdef class DictionaryType(DataType):
+ cdef:
+ const CDictionaryType* dict_type
+
+
+cdef class TimestampType(DataType):
+ cdef:
+ const CTimestampType* ts_type
+
+
+cdef class Time32Type(DataType):
+ cdef:
+ const CTime32Type* time_type
+
+
+cdef class Time64Type(DataType):
+ cdef:
+ const CTime64Type* time_type
+
+
+cdef class DurationType(DataType):
+ cdef:
+ const CDurationType* duration_type
+
+
+cdef class FixedSizeBinaryType(DataType):
+ cdef:
+ const CFixedSizeBinaryType* fixed_size_binary_type
+
+
+cdef class Decimal128Type(FixedSizeBinaryType):
+ cdef:
+ const CDecimal128Type* decimal128_type
+
+
+cdef class Decimal256Type(FixedSizeBinaryType):
+ cdef:
+ const CDecimal256Type* decimal256_type
+
+
+cdef class BaseExtensionType(DataType):
+ cdef:
+ const CExtensionType* ext_type
+
+
+cdef class ExtensionType(BaseExtensionType):
+ cdef:
+ const CPyExtensionType* cpy_ext_type
+
+
+cdef class PyExtensionType(ExtensionType):
+ pass
+
+
+cdef class _Metadata(_Weakrefable):
+ # required because KeyValueMetadata also extends collections.abc.Mapping
+ # and the first parent class must be an extension type
+ pass
+
+
+cdef class KeyValueMetadata(_Metadata):
+ cdef:
+ shared_ptr[const CKeyValueMetadata] wrapped
+ const CKeyValueMetadata* metadata
+
+ cdef void init(self, const shared_ptr[const CKeyValueMetadata]& wrapped)
+
+ @staticmethod
+ cdef wrap(const shared_ptr[const CKeyValueMetadata]& sp)
+ cdef inline shared_ptr[const CKeyValueMetadata] unwrap(self) nogil
+
+
+cdef class Field(_Weakrefable):
+ cdef:
+ shared_ptr[CField] sp_field
+ CField* field
+
+ cdef readonly:
+ DataType type
+
+ cdef void init(self, const shared_ptr[CField]& field)
+
+
+cdef class Schema(_Weakrefable):
+ cdef:
+ shared_ptr[CSchema] sp_schema
+ CSchema* schema
+
+ cdef void init(self, const vector[shared_ptr[CField]]& fields)
+ cdef void init_schema(self, const shared_ptr[CSchema]& schema)
+
+
+cdef class Scalar(_Weakrefable):
+ cdef:
+ shared_ptr[CScalar] wrapped
+
+ cdef void init(self, const shared_ptr[CScalar]& wrapped)
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CScalar]& wrapped)
+
+ cdef inline shared_ptr[CScalar] unwrap(self) nogil
+
+
+cdef class _PandasConvertible(_Weakrefable):
+ pass
+
+
+cdef class Array(_PandasConvertible):
+ cdef:
+ shared_ptr[CArray] sp_array
+ CArray* ap
+
+ cdef readonly:
+ DataType type
+ # To allow Table to propagate metadata to pandas.Series
+ object _name
+
+ cdef void init(self, const shared_ptr[CArray]& sp_array) except *
+ cdef getitem(self, int64_t i)
+ cdef int64_t length(self)
+
+
+cdef class Tensor(_Weakrefable):
+ cdef:
+ shared_ptr[CTensor] sp_tensor
+ CTensor* tp
+
+ cdef readonly:
+ DataType type
+
+ cdef void init(self, const shared_ptr[CTensor]& sp_tensor)
+
+
+cdef class SparseCSRMatrix(_Weakrefable):
+ cdef:
+ shared_ptr[CSparseCSRMatrix] sp_sparse_tensor
+ CSparseCSRMatrix* stp
+
+ cdef readonly:
+ DataType type
+
+ cdef void init(self, const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor)
+
+
+cdef class SparseCSCMatrix(_Weakrefable):
+ cdef:
+ shared_ptr[CSparseCSCMatrix] sp_sparse_tensor
+ CSparseCSCMatrix* stp
+
+ cdef readonly:
+ DataType type
+
+ cdef void init(self, const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor)
+
+
+cdef class SparseCOOTensor(_Weakrefable):
+ cdef:
+ shared_ptr[CSparseCOOTensor] sp_sparse_tensor
+ CSparseCOOTensor* stp
+
+ cdef readonly:
+ DataType type
+
+ cdef void init(self, const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor)
+
+
+cdef class SparseCSFTensor(_Weakrefable):
+ cdef:
+ shared_ptr[CSparseCSFTensor] sp_sparse_tensor
+ CSparseCSFTensor* stp
+
+ cdef readonly:
+ DataType type
+
+ cdef void init(self, const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor)
+
+
+cdef class NullArray(Array):
+ pass
+
+
+cdef class BooleanArray(Array):
+ pass
+
+
+cdef class NumericArray(Array):
+ pass
+
+
+cdef class IntegerArray(NumericArray):
+ pass
+
+
+cdef class FloatingPointArray(NumericArray):
+ pass
+
+
+cdef class Int8Array(IntegerArray):
+ pass
+
+
+cdef class UInt8Array(IntegerArray):
+ pass
+
+
+cdef class Int16Array(IntegerArray):
+ pass
+
+
+cdef class UInt16Array(IntegerArray):
+ pass
+
+
+cdef class Int32Array(IntegerArray):
+ pass
+
+
+cdef class UInt32Array(IntegerArray):
+ pass
+
+
+cdef class Int64Array(IntegerArray):
+ pass
+
+
+cdef class UInt64Array(IntegerArray):
+ pass
+
+
+cdef class HalfFloatArray(FloatingPointArray):
+ pass
+
+
+cdef class FloatArray(FloatingPointArray):
+ pass
+
+
+cdef class DoubleArray(FloatingPointArray):
+ pass
+
+
+cdef class FixedSizeBinaryArray(Array):
+ pass
+
+
+cdef class Decimal128Array(FixedSizeBinaryArray):
+ pass
+
+
+cdef class Decimal256Array(FixedSizeBinaryArray):
+ pass
+
+
+cdef class StructArray(Array):
+ pass
+
+
+cdef class BaseListArray(Array):
+ pass
+
+
+cdef class ListArray(BaseListArray):
+ pass
+
+
+cdef class LargeListArray(BaseListArray):
+ pass
+
+
+cdef class MapArray(Array):
+ pass
+
+
+cdef class FixedSizeListArray(Array):
+ pass
+
+
+cdef class UnionArray(Array):
+ pass
+
+
+cdef class StringArray(Array):
+ pass
+
+
+cdef class BinaryArray(Array):
+ pass
+
+
+cdef class DictionaryArray(Array):
+ cdef:
+ object _indices, _dictionary
+
+
+cdef class ExtensionArray(Array):
+ pass
+
+
+cdef class MonthDayNanoIntervalArray(Array):
+ pass
+
+
+cdef wrap_array_output(PyObject* output)
+cdef wrap_datum(const CDatum& datum)
+
+
+cdef class ChunkedArray(_PandasConvertible):
+ cdef:
+ shared_ptr[CChunkedArray] sp_chunked_array
+ CChunkedArray* chunked_array
+
+ cdef readonly:
+ # To allow Table to propagate metadata to pandas.Series
+ object _name
+
+ cdef void init(self, const shared_ptr[CChunkedArray]& chunked_array)
+ cdef getitem(self, int64_t i)
+
+
+cdef class Table(_PandasConvertible):
+ cdef:
+ shared_ptr[CTable] sp_table
+ CTable* table
+
+ cdef void init(self, const shared_ptr[CTable]& table)
+
+
+cdef class RecordBatch(_PandasConvertible):
+ cdef:
+ shared_ptr[CRecordBatch] sp_batch
+ CRecordBatch* batch
+ Schema _schema
+
+ cdef void init(self, const shared_ptr[CRecordBatch]& table)
+
+
+cdef class Buffer(_Weakrefable):
+ cdef:
+ shared_ptr[CBuffer] buffer
+ Py_ssize_t shape[1]
+ Py_ssize_t strides[1]
+
+ cdef void init(self, const shared_ptr[CBuffer]& buffer)
+ cdef getitem(self, int64_t i)
+
+
+cdef class ResizableBuffer(Buffer):
+
+ cdef void init_rz(self, const shared_ptr[CResizableBuffer]& buffer)
+
+
+cdef class NativeFile(_Weakrefable):
+ cdef:
+ shared_ptr[CInputStream] input_stream
+ shared_ptr[CRandomAccessFile] random_access
+ shared_ptr[COutputStream] output_stream
+ bint is_readable
+ bint is_writable
+ bint is_seekable
+ bint own_file
+
+ # By implementing these "virtual" functions (all functions in Cython
+ # extension classes are technically virtual in the C++ sense) we can expose
+ # the arrow::io abstract file interfaces to other components throughout the
+ # suite of Arrow C++ libraries
+ cdef set_random_access_file(self, shared_ptr[CRandomAccessFile] handle)
+ cdef set_input_stream(self, shared_ptr[CInputStream] handle)
+ cdef set_output_stream(self, shared_ptr[COutputStream] handle)
+
+ cdef shared_ptr[CRandomAccessFile] get_random_access_file(self) except *
+ cdef shared_ptr[CInputStream] get_input_stream(self) except *
+ cdef shared_ptr[COutputStream] get_output_stream(self) except *
+
+
+cdef class BufferedInputStream(NativeFile):
+ pass
+
+
+cdef class BufferedOutputStream(NativeFile):
+ pass
+
+
+cdef class CompressedInputStream(NativeFile):
+ pass
+
+
+cdef class CompressedOutputStream(NativeFile):
+ pass
+
+
+cdef class _CRecordBatchWriter(_Weakrefable):
+ cdef:
+ shared_ptr[CRecordBatchWriter] writer
+
+
+cdef class RecordBatchReader(_Weakrefable):
+ cdef:
+ shared_ptr[CRecordBatchReader] reader
+
+
+cdef class Codec(_Weakrefable):
+ cdef:
+ shared_ptr[CCodec] wrapped
+
+ cdef inline CCodec* unwrap(self) nogil
+
+
+# This class is only used internally for now
+cdef class StopToken:
+ cdef:
+ CStopToken stop_token
+
+ cdef void init(self, CStopToken stop_token)
+
+
+cdef get_input_stream(object source, c_bool use_memory_map,
+ shared_ptr[CInputStream]* reader)
+cdef get_reader(object source, c_bool use_memory_map,
+ shared_ptr[CRandomAccessFile]* reader)
+cdef get_writer(object source, shared_ptr[COutputStream]* writer)
+cdef NativeFile get_native_file(object source, c_bool use_memory_map)
+
+cdef shared_ptr[CInputStream] native_transcoding_input_stream(
+ shared_ptr[CInputStream] stream, src_encoding,
+ dest_encoding) except *
+
+# Default is allow_none=False
+cpdef DataType ensure_type(object type, bint allow_none=*)
+
+cdef timeunit_to_string(TimeUnit unit)
+cdef TimeUnit string_to_timeunit(unit) except *
+
+# Exceptions may be raised when converting dict values, so need to
+# check exception state on return
+cdef shared_ptr[const CKeyValueMetadata] pyarrow_unwrap_metadata(
+ object meta) except *
+cdef object pyarrow_wrap_metadata(
+ const shared_ptr[const CKeyValueMetadata]& meta)
+
+#
+# Public Cython API for 3rd party code
+#
+# If you add functions to this list, please also update
+# `cpp/src/arrow/python/pyarrow.{h, cc}`
+#
+
+# Wrapping C++ -> Python
+
+cdef public object pyarrow_wrap_buffer(const shared_ptr[CBuffer]& buf)
+cdef public object pyarrow_wrap_resizable_buffer(
+ const shared_ptr[CResizableBuffer]& buf)
+
+cdef public object pyarrow_wrap_data_type(const shared_ptr[CDataType]& type)
+cdef public object pyarrow_wrap_field(const shared_ptr[CField]& field)
+cdef public object pyarrow_wrap_schema(const shared_ptr[CSchema]& type)
+
+cdef public object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar)
+
+cdef public object pyarrow_wrap_array(const shared_ptr[CArray]& sp_array)
+cdef public object pyarrow_wrap_chunked_array(
+ const shared_ptr[CChunkedArray]& sp_array)
+
+cdef public object pyarrow_wrap_sparse_coo_tensor(
+ const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor)
+cdef public object pyarrow_wrap_sparse_csc_matrix(
+ const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor)
+cdef public object pyarrow_wrap_sparse_csf_tensor(
+ const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor)
+cdef public object pyarrow_wrap_sparse_csr_matrix(
+ const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor)
+cdef public object pyarrow_wrap_tensor(const shared_ptr[CTensor]& sp_tensor)
+
+cdef public object pyarrow_wrap_batch(const shared_ptr[CRecordBatch]& cbatch)
+cdef public object pyarrow_wrap_table(const shared_ptr[CTable]& ctable)
+
+# Unwrapping Python -> C++
+
+cdef public shared_ptr[CBuffer] pyarrow_unwrap_buffer(object buffer)
+
+cdef public shared_ptr[CDataType] pyarrow_unwrap_data_type(object data_type)
+cdef public shared_ptr[CField] pyarrow_unwrap_field(object field)
+cdef public shared_ptr[CSchema] pyarrow_unwrap_schema(object schema)
+
+cdef public shared_ptr[CScalar] pyarrow_unwrap_scalar(object scalar)
+
+cdef public shared_ptr[CArray] pyarrow_unwrap_array(object array)
+cdef public shared_ptr[CChunkedArray] pyarrow_unwrap_chunked_array(
+ object array)
+
+cdef public shared_ptr[CSparseCOOTensor] pyarrow_unwrap_sparse_coo_tensor(
+ object sparse_tensor)
+cdef public shared_ptr[CSparseCSCMatrix] pyarrow_unwrap_sparse_csc_matrix(
+ object sparse_tensor)
+cdef public shared_ptr[CSparseCSFTensor] pyarrow_unwrap_sparse_csf_tensor(
+ object sparse_tensor)
+cdef public shared_ptr[CSparseCSRMatrix] pyarrow_unwrap_sparse_csr_matrix(
+ object sparse_tensor)
+cdef public shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor)
+
+cdef public shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch)
+cdef public shared_ptr[CTable] pyarrow_unwrap_table(object table)
diff --git a/src/arrow/python/pyarrow/lib.pyx b/src/arrow/python/pyarrow/lib.pyx
new file mode 100644
index 000000000..0c9cbcc5b
--- /dev/null
+++ b/src/arrow/python/pyarrow/lib.pyx
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: profile = False
+# cython: nonecheck = True
+# distutils: language = c++
+
+import datetime
+import decimal as _pydecimal
+import numpy as np
+import os
+import sys
+
+from cython.operator cimport dereference as deref
+from pyarrow.includes.libarrow cimport *
+from pyarrow.includes.common cimport PyObject_to_object
+cimport pyarrow.includes.libarrow as libarrow
+cimport cpython as cp
+
+# Initialize NumPy C API
+arrow_init_numpy()
+# Initialize PyArrow C++ API
+# (used from some of our C++ code, see e.g. ARROW-5260)
+import_pyarrow()
+
+
+MonthDayNano = NewMonthDayNanoTupleType()
+
+
+def cpu_count():
+ """
+ Return the number of threads to use in parallel operations.
+
+ The number of threads is determined at startup by inspecting the
+ ``OMP_NUM_THREADS`` and ``OMP_THREAD_LIMIT`` environment variables.
+ If neither is present, it will default to the number of hardware threads
+ on the system. It can be modified at runtime by calling
+ :func:`set_cpu_count()`.
+
+ See Also
+ --------
+ set_cpu_count : Modify the size of this pool.
+ io_thread_count : The analogous function for the I/O thread pool.
+ """
+ return GetCpuThreadPoolCapacity()
+
+
+def set_cpu_count(int count):
+ """
+ Set the number of threads to use in parallel operations.
+
+ Parameters
+ ----------
+ count : int
+ The number of concurrent threads that should be used.
+
+ See Also
+ --------
+ cpu_count : Get the size of this pool.
+ set_io_thread_count : The analogous function for the I/O thread pool.
+ """
+ if count < 1:
+ raise ValueError("CPU count must be strictly positive")
+ check_status(SetCpuThreadPoolCapacity(count))
+
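A brief illustration of the thread-pool helpers defined above:

    import pyarrow as pa

    # The global CPU thread pool drives parallel work such as
    # Table.to_pandas(use_threads=True).
    print(pa.cpu_count())

    pa.set_cpu_count(2)   # e.g. shrink the pool inside a constrained worker
    assert pa.cpu_count() == 2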
+
+Type_NA = _Type_NA
+Type_BOOL = _Type_BOOL
+Type_UINT8 = _Type_UINT8
+Type_INT8 = _Type_INT8
+Type_UINT16 = _Type_UINT16
+Type_INT16 = _Type_INT16
+Type_UINT32 = _Type_UINT32
+Type_INT32 = _Type_INT32
+Type_UINT64 = _Type_UINT64
+Type_INT64 = _Type_INT64
+Type_HALF_FLOAT = _Type_HALF_FLOAT
+Type_FLOAT = _Type_FLOAT
+Type_DOUBLE = _Type_DOUBLE
+Type_DECIMAL128 = _Type_DECIMAL128
+Type_DECIMAL256 = _Type_DECIMAL256
+Type_DATE32 = _Type_DATE32
+Type_DATE64 = _Type_DATE64
+Type_TIMESTAMP = _Type_TIMESTAMP
+Type_TIME32 = _Type_TIME32
+Type_TIME64 = _Type_TIME64
+Type_DURATION = _Type_DURATION
+Type_INTERVAL_MONTH_DAY_NANO = _Type_INTERVAL_MONTH_DAY_NANO
+Type_BINARY = _Type_BINARY
+Type_STRING = _Type_STRING
+Type_LARGE_BINARY = _Type_LARGE_BINARY
+Type_LARGE_STRING = _Type_LARGE_STRING
+Type_FIXED_SIZE_BINARY = _Type_FIXED_SIZE_BINARY
+Type_LIST = _Type_LIST
+Type_LARGE_LIST = _Type_LARGE_LIST
+Type_MAP = _Type_MAP
+Type_FIXED_SIZE_LIST = _Type_FIXED_SIZE_LIST
+Type_STRUCT = _Type_STRUCT
+Type_SPARSE_UNION = _Type_SPARSE_UNION
+Type_DENSE_UNION = _Type_DENSE_UNION
+Type_DICTIONARY = _Type_DICTIONARY
+
+UnionMode_SPARSE = _UnionMode_SPARSE
+UnionMode_DENSE = _UnionMode_DENSE
+
+
+def _pc():
+ import pyarrow.compute as pc
+ return pc
+
+
+# Assorted compatibility helpers
+include "compat.pxi"
+
+# Exception types and Status handling
+include "error.pxi"
+
+# Configuration information
+include "config.pxi"
+
+# pandas API shim
+include "pandas-shim.pxi"
+
+# Memory pools and allocation
+include "memory.pxi"
+
+# DataType, Field, Schema
+include "types.pxi"
+
+# Array scalar values
+include "scalar.pxi"
+
+# Array types
+include "array.pxi"
+
+# Builders
+include "builder.pxi"
+
+# Column, Table, Record Batch
+include "table.pxi"
+
+# Tensors
+include "tensor.pxi"
+
+# File IO
+include "io.pxi"
+
+# IPC / Messaging
+include "ipc.pxi"
+
+# Python serialization
+include "serialization.pxi"
+
+# Micro-benchmark routines
+include "benchmark.pxi"
+
+# Public API
+include "public-api.pxi"
diff --git a/src/arrow/python/pyarrow/memory.pxi b/src/arrow/python/pyarrow/memory.pxi
new file mode 100644
index 000000000..8ccb35058
--- /dev/null
+++ b/src/arrow/python/pyarrow/memory.pxi
@@ -0,0 +1,249 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# cython: profile=False
+# distutils: language = c++
+# cython: embedsignature = True
+
+
+cdef class MemoryPool(_Weakrefable):
+ """
+ Base class for memory allocation.
+
+ Besides tracking its number of allocated bytes, a memory pool also
+ takes care of the required 64-byte alignment for Arrow data.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, "
+ "use pyarrow.*_memory_pool instead."
+ .format(self.__class__.__name__))
+
+ cdef void init(self, CMemoryPool* pool):
+ self.pool = pool
+
+ def release_unused(self):
+ """
+ Attempt to return to the OS any memory being held onto by the pool.
+
+ This function should not be called except potentially for
+ benchmarking or debugging as it could be expensive and detrimental to
+ performance.
+
+ This is best effort and may not have any effect on some memory pools
+ or in some situations (e.g. fragmentation).
+ """
+        cdef CMemoryPool* pool = self.pool
+ with nogil:
+ pool.ReleaseUnused()
+
+ def bytes_allocated(self):
+ """
+ Return the number of bytes that are currently allocated from this
+ memory pool.
+ """
+ return self.pool.bytes_allocated()
+
+ def max_memory(self):
+ """
+ Return the peak memory allocation in this memory pool.
+ This can be an approximate number in multi-threaded applications.
+
+ None is returned if the pool implementation doesn't know how to
+ compute this number.
+ """
+ ret = self.pool.max_memory()
+ return ret if ret >= 0 else None
+
+ @property
+ def backend_name(self):
+ """
+ The name of the backend used by this MemoryPool (e.g. "jemalloc").
+ """
+ return frombytes(self.pool.backend_name())
+
+
+cdef CMemoryPool* maybe_unbox_memory_pool(MemoryPool memory_pool):
+ if memory_pool is None:
+ return c_get_memory_pool()
+ else:
+ return memory_pool.pool
+
+
+cdef class LoggingMemoryPool(MemoryPool):
+ cdef:
+ unique_ptr[CLoggingMemoryPool] logging_pool
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, "
+ "use pyarrow.logging_memory_pool instead."
+ .format(self.__class__.__name__))
+
+
+cdef class ProxyMemoryPool(MemoryPool):
+ """
+ Memory pool implementation that tracks the number of bytes and
+ maximum memory allocated through its direct calls, while redirecting
+ to another memory pool.
+ """
+ cdef:
+ unique_ptr[CProxyMemoryPool] proxy_pool
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, "
+ "use pyarrow.proxy_memory_pool instead."
+ .format(self.__class__.__name__))
+
+
+def default_memory_pool():
+ """
+ Return the process-global memory pool.
+ """
+ cdef:
+ MemoryPool pool = MemoryPool.__new__(MemoryPool)
+ pool.init(c_get_memory_pool())
+ return pool
+
+
+def proxy_memory_pool(MemoryPool parent):
+ """
+ Create and return a MemoryPool instance that redirects to the
+ *parent*, but with separate allocation statistics.
+
+ Parameters
+ ----------
+ parent : MemoryPool
+ The real memory pool that should be used for allocations.
+ """
+ cdef ProxyMemoryPool out = ProxyMemoryPool.__new__(ProxyMemoryPool)
+ out.proxy_pool.reset(new CProxyMemoryPool(parent.pool))
+ out.init(out.proxy_pool.get())
+ return out
+
+
+def logging_memory_pool(MemoryPool parent):
+ """
+ Create and return a MemoryPool instance that redirects to the
+ *parent*, but also dumps allocation logs on stderr.
+
+ Parameters
+ ----------
+ parent : MemoryPool
+ The real memory pool that should be used for allocations.
+ """
+ cdef LoggingMemoryPool out = LoggingMemoryPool.__new__(
+ LoggingMemoryPool, parent)
+ out.logging_pool.reset(new CLoggingMemoryPool(parent.pool))
+ out.init(out.logging_pool.get())
+ return out
+
+
+def system_memory_pool():
+ """
+ Return a memory pool based on the C malloc heap.
+ """
+ cdef:
+ MemoryPool pool = MemoryPool.__new__(MemoryPool)
+ pool.init(c_system_memory_pool())
+ return pool
+
+
+def jemalloc_memory_pool():
+ """
+ Return a memory pool based on the jemalloc heap.
+
+ NotImplementedError is raised if jemalloc support is not enabled.
+ """
+ cdef:
+ CMemoryPool* c_pool
+ MemoryPool pool = MemoryPool.__new__(MemoryPool)
+ check_status(c_jemalloc_memory_pool(&c_pool))
+ pool.init(c_pool)
+ return pool
+
+
+def mimalloc_memory_pool():
+ """
+ Return a memory pool based on the mimalloc heap.
+
+ NotImplementedError is raised if mimalloc support is not enabled.
+ """
+ cdef:
+ CMemoryPool* c_pool
+ MemoryPool pool = MemoryPool.__new__(MemoryPool)
+ check_status(c_mimalloc_memory_pool(&c_pool))
+ pool.init(c_pool)
+ return pool
+
+
+def set_memory_pool(MemoryPool pool):
+ """
+ Set the default memory pool.
+
+ Parameters
+ ----------
+ pool : MemoryPool
+ The memory pool that should be used by default.
+ """
+ c_set_default_memory_pool(pool.pool)
+
+
+cdef MemoryPool _default_memory_pool = default_memory_pool()
+cdef LoggingMemoryPool _logging_memory_pool = logging_memory_pool(
+ _default_memory_pool)
+
+
+def log_memory_allocations(enable=True):
+ """
+ Enable or disable memory allocator logging for debugging purposes
+
+ Parameters
+ ----------
+ enable : bool, default True
+ Pass False to disable logging
+ """
+ if enable:
+ set_memory_pool(_logging_memory_pool)
+ else:
+ set_memory_pool(_default_memory_pool)
+
+
+def total_allocated_bytes():
+ """
+ Return the currently allocated bytes from the default memory pool.
+ Other memory pools may not be accounted for.
+ """
+ cdef CMemoryPool* pool = c_get_memory_pool()
+ return pool.bytes_allocated()
+
+
+def jemalloc_set_decay_ms(decay_ms):
+ """
+    Set arenas.dirty_decay_ms and arenas.muzzy_decay_ms to the indicated
+    number of milliseconds. A value of 0 (the default) results in dirty /
+    muzzy memory pages being released right away to the OS, while a higher
+    value will result in a time-based decay. See the jemalloc docs for more
+    information.
+
+ It's best to set this at the start of your application.
+
+ Parameters
+ ----------
+ decay_ms : int
+ Number of milliseconds to set for jemalloc decay conf parameters. Note
+        that this change will only affect future memory arenas.
+ """
+ check_status(c_jemalloc_set_decay_ms(decay_ms))
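An illustrative look at the memory-pool APIs above; exact byte counts depend
on the allocator backend:

    import pyarrow as pa

    before = pa.total_allocated_bytes()
    arr = pa.array(range(1_000_000))
    # Roughly the memory now backing `arr`.
    print(pa.total_allocated_bytes() - before)

    pool = pa.default_memory_pool()
    print(pool.backend_name, pool.bytes_allocated(), pool.max_memory())

    # A proxy pool keeps its own statistics while delegating allocations
    # to its parent.
    proxy = pa.proxy_memory_pool(pool)
    scratch = pa.array([1, 2, 3], memory_pool=proxy)
    print(proxy.bytes_allocated())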
diff --git a/src/arrow/python/pyarrow/orc.py b/src/arrow/python/pyarrow/orc.py
new file mode 100644
index 000000000..87f805886
--- /dev/null
+++ b/src/arrow/python/pyarrow/orc.py
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from numbers import Integral
+import warnings
+
+from pyarrow.lib import Table
+import pyarrow._orc as _orc
+
+
+class ORCFile:
+ """
+ Reader interface for a single ORC file
+
+ Parameters
+ ----------
+ source : str or pyarrow.io.NativeFile
+ Readable source. For passing Python file objects or byte buffers,
+ see pyarrow.io.PythonFileInterface or pyarrow.io.BufferReader.
+ """
+
+ def __init__(self, source):
+ self.reader = _orc.ORCReader()
+ self.reader.open(source)
+
+ @property
+ def metadata(self):
+ """The file metadata, as an arrow KeyValueMetadata"""
+ return self.reader.metadata()
+
+ @property
+ def schema(self):
+ """The file schema, as an arrow schema"""
+ return self.reader.schema()
+
+ @property
+ def nrows(self):
+ """The number of rows in the file"""
+ return self.reader.nrows()
+
+ @property
+ def nstripes(self):
+ """The number of stripes in the file"""
+ return self.reader.nstripes()
+
+ def _select_names(self, columns=None):
+ if columns is None:
+ return None
+
+ schema = self.schema
+ names = []
+ for col in columns:
+ if isinstance(col, Integral):
+ col = int(col)
+ if 0 <= col < len(schema):
+ col = schema[col].name
+ names.append(col)
+ else:
+ raise ValueError("Column indices must be in 0 <= ind < %d,"
+ " got %d" % (len(schema), col))
+ else:
+ return columns
+
+ return names
+
+ def read_stripe(self, n, columns=None):
+ """Read a single stripe from the file.
+
+ Parameters
+ ----------
+ n : int
+ The stripe index
+ columns : list
+ If not None, only these columns will be read from the stripe. A
+ column name may be a prefix of a nested field, e.g. 'a' will select
+ 'a.b', 'a.c', and 'a.d.e'
+
+ Returns
+ -------
+ pyarrow.lib.RecordBatch
+ Content of the stripe as a RecordBatch.
+ """
+ columns = self._select_names(columns)
+ return self.reader.read_stripe(n, columns=columns)
+
+ def read(self, columns=None):
+ """Read the whole file.
+
+ Parameters
+ ----------
+ columns : list
+ If not None, only these columns will be read from the file. A
+ column name may be a prefix of a nested field, e.g. 'a' will select
+ 'a.b', 'a.c', and 'a.d.e'
+
+ Returns
+ -------
+ pyarrow.lib.Table
+ Content of the file as a Table.
+ """
+ columns = self._select_names(columns)
+ return self.reader.read(columns=columns)
+
+
+class ORCWriter:
+ """
+ Writer interface for a single ORC file
+
+ Parameters
+ ----------
+ where : str or pyarrow.io.NativeFile
+ Writable target. For passing Python file objects or byte buffers,
+ see pyarrow.io.PythonFileInterface, pyarrow.io.BufferOutputStream
+ or pyarrow.io.FixedSizeBufferWriter.
+ """
+
+ def __init__(self, where):
+ self.writer = _orc.ORCWriter()
+ self.writer.open(where)
+
+ def write(self, table):
+ """
+ Write the table into an ORC file. The schema of the table must
+ be equal to the schema used when opening the ORC file.
+
+ Parameters
+ ----------
+        table : pyarrow.lib.Table
+ The table to be written into the ORC file
+ """
+ self.writer.write(table)
+
+ def close(self):
+ """
+ Close the ORC file
+ """
+ self.writer.close()
+
+
+def write_table(table, where):
+ """
+ Write a table into an ORC file
+
+ Parameters
+ ----------
+ table : pyarrow.lib.Table
+ The table to be written into the ORC file
+ where : str or pyarrow.io.NativeFile
+ Writable target. For passing Python file objects or byte buffers,
+ see pyarrow.io.PythonFileInterface, pyarrow.io.BufferOutputStream
+ or pyarrow.io.FixedSizeBufferWriter.
+ """
+ if isinstance(where, Table):
+ warnings.warn(
+ "The order of the arguments has changed. Pass as "
+ "'write_table(table, where)' instead. The old order will raise "
+ "an error in the future.", FutureWarning, stacklevel=2
+ )
+ table, where = where, table
+ writer = ORCWriter(where)
+ writer.write(table)
+ writer.close()
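A small end-to-end sketch, assuming a pyarrow build with ORC support; the
file name is a placeholder:

    import pyarrow as pa
    from pyarrow import orc

    table = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})
    orc.write_table(table, "example.orc")

    f = orc.ORCFile("example.orc")
    print(f.nrows, f.nstripes)
    ids_only = f.read(columns=["id"])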
diff --git a/src/arrow/python/pyarrow/pandas-shim.pxi b/src/arrow/python/pyarrow/pandas-shim.pxi
new file mode 100644
index 000000000..0e7cfe937
--- /dev/null
+++ b/src/arrow/python/pyarrow/pandas-shim.pxi
@@ -0,0 +1,254 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pandas lazy-loading API shim that reduces API call and import overhead
+
+import warnings
+
+
+cdef class _PandasAPIShim(object):
+ """
+ Lazy pandas importer that isolates usages of pandas APIs and avoids
+ importing pandas until it's actually needed
+ """
+ cdef:
+ bint _tried_importing_pandas
+ bint _have_pandas
+
+ cdef readonly:
+ object _loose_version, _version
+ object _pd, _types_api, _compat_module
+ object _data_frame, _index, _series, _categorical_type
+ object _datetimetz_type, _extension_array, _extension_dtype
+ object _array_like_types, _is_extension_array_dtype
+ bint has_sparse
+ bint _pd024
+
+ def __init__(self):
+ self._tried_importing_pandas = False
+ self._have_pandas = 0
+
+ cdef _import_pandas(self, bint raise_):
+ try:
+ import pandas as pd
+ import pyarrow.pandas_compat as pdcompat
+ except ImportError:
+ self._have_pandas = False
+ if raise_:
+ raise
+ else:
+ return
+
+ from pyarrow.vendored.version import Version
+
+ self._pd = pd
+ self._version = pd.__version__
+ self._loose_version = Version(pd.__version__)
+
+ if self._loose_version < Version('0.23.0'):
+ self._have_pandas = False
+ if raise_:
+ raise ImportError(
+ "pyarrow requires pandas 0.23.0 or above, pandas {} is "
+ "installed".format(self._version)
+ )
+ else:
+ warnings.warn(
+ "pyarrow requires pandas 0.23.0 or above, pandas {} is "
+ "installed. Therefore, pandas-specific integration is not "
+ "used.".format(self._version), stacklevel=2)
+ return
+
+ self._compat_module = pdcompat
+ self._data_frame = pd.DataFrame
+ self._index = pd.Index
+ self._categorical_type = pd.Categorical
+ self._series = pd.Series
+ self._extension_array = pd.api.extensions.ExtensionArray
+ self._array_like_types = (
+ self._series, self._index, self._categorical_type,
+ self._extension_array)
+ self._extension_dtype = pd.api.extensions.ExtensionDtype
+ if self._loose_version >= Version('0.24.0'):
+ self._is_extension_array_dtype = \
+ pd.api.types.is_extension_array_dtype
+ else:
+ self._is_extension_array_dtype = None
+
+ self._types_api = pd.api.types
+ self._datetimetz_type = pd.api.types.DatetimeTZDtype
+ self._have_pandas = True
+
+ if self._loose_version > Version('0.25'):
+ self.has_sparse = False
+ else:
+ self.has_sparse = True
+
+ self._pd024 = self._loose_version >= Version('0.24')
+
+ cdef inline _check_import(self, bint raise_=True):
+ if self._tried_importing_pandas:
+ if not self._have_pandas and raise_:
+ self._import_pandas(raise_)
+ return
+
+ self._tried_importing_pandas = True
+ self._import_pandas(raise_)
+
+ def series(self, *args, **kwargs):
+ self._check_import()
+ return self._series(*args, **kwargs)
+
+ def data_frame(self, *args, **kwargs):
+ self._check_import()
+ return self._data_frame(*args, **kwargs)
+
+ cdef inline bint _have_pandas_internal(self):
+ if not self._tried_importing_pandas:
+ self._check_import(raise_=False)
+ return self._have_pandas
+
+ @property
+ def have_pandas(self):
+ return self._have_pandas_internal()
+
+ @property
+ def compat(self):
+ self._check_import()
+ return self._compat_module
+
+ @property
+ def pd(self):
+ self._check_import()
+ return self._pd
+
+ cpdef infer_dtype(self, obj):
+ self._check_import()
+ try:
+ return self._types_api.infer_dtype(obj, skipna=False)
+ except AttributeError:
+ return self._pd.lib.infer_dtype(obj)
+
+ cpdef pandas_dtype(self, dtype):
+ self._check_import()
+ try:
+ return self._types_api.pandas_dtype(dtype)
+ except AttributeError:
+ return None
+
+ @property
+ def loose_version(self):
+ self._check_import()
+ return self._loose_version
+
+ @property
+ def version(self):
+ self._check_import()
+ return self._version
+
+ @property
+ def categorical_type(self):
+ self._check_import()
+ return self._categorical_type
+
+ @property
+ def datetimetz_type(self):
+ self._check_import()
+ return self._datetimetz_type
+
+ @property
+ def extension_dtype(self):
+ self._check_import()
+ return self._extension_dtype
+
+ cpdef is_array_like(self, obj):
+ self._check_import()
+ return isinstance(obj, self._array_like_types)
+
+ cpdef is_categorical(self, obj):
+ if self._have_pandas_internal():
+ return isinstance(obj, self._categorical_type)
+ else:
+ return False
+
+ cpdef is_datetimetz(self, obj):
+ if self._have_pandas_internal():
+ return isinstance(obj, self._datetimetz_type)
+ else:
+ return False
+
+ cpdef is_extension_array_dtype(self, obj):
+ self._check_import()
+ if self._is_extension_array_dtype:
+ return self._is_extension_array_dtype(obj)
+ else:
+ return False
+
+ cpdef is_sparse(self, obj):
+ if self._have_pandas_internal():
+ return self._types_api.is_sparse(obj)
+ else:
+ return False
+
+ cpdef is_data_frame(self, obj):
+ if self._have_pandas_internal():
+ return isinstance(obj, self._data_frame)
+ else:
+ return False
+
+ cpdef is_series(self, obj):
+ if self._have_pandas_internal():
+ return isinstance(obj, self._series)
+ else:
+ return False
+
+ cpdef is_index(self, obj):
+ if self._have_pandas_internal():
+ return isinstance(obj, self._index)
+ else:
+ return False
+
+ cpdef get_values(self, obj):
+ """
+ Get the underlying array values of a pandas Series or Index, in the
+ form we need them (np.ndarray or pandas ExtensionArray).
+
+ Assumes obj is a pandas Series or Index.
+ """
+ self._check_import()
+ if isinstance(obj.dtype, (self.pd.api.types.IntervalDtype,
+ self.pd.api.types.PeriodDtype)):
+ if self._pd024:
+ # only since pandas 0.24, interval and period are stored as
+ # such in Series
+ return obj.array
+ return obj.values
+
+ def assert_frame_equal(self, *args, **kwargs):
+ self._check_import()
+ return self._pd.util.testing.assert_frame_equal(*args, **kwargs)
+
+ def get_rangeindex_attribute(self, level, name):
+ # public start/stop/step attributes added in pandas 0.25.0
+ self._check_import()
+ if hasattr(level, name):
+ return getattr(level, name)
+ return getattr(level, '_' + name)
+
+
+cdef _PandasAPIShim pandas_api = _PandasAPIShim()
+_pandas_api = pandas_api
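+
+# Illustrative usage sketch (not part of the shim itself); it assumes pandas
+# is installed and uses only attributes defined above:
+#
+#   from pyarrow.lib import _pandas_api
+#   _pandas_api.have_pandas                      # triggers the lazy import check
+#   df = _pandas_api.data_frame({'a': [1, 2]})   # forwards to pd.DataFrame(...)
+#   _pandas_api.is_data_frame(df)                # -> True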
diff --git a/src/arrow/python/pyarrow/pandas_compat.py b/src/arrow/python/pyarrow/pandas_compat.py
new file mode 100644
index 000000000..e4b13175f
--- /dev/null
+++ b/src/arrow/python/pyarrow/pandas_compat.py
@@ -0,0 +1,1226 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import ast
+from collections.abc import Sequence
+from concurrent import futures
+# import threading submodule upfront to avoid partially initialized
+# module bug (ARROW-11983)
+import concurrent.futures.thread # noqa
+from copy import deepcopy
+from itertools import zip_longest
+import json
+import operator
+import re
+import warnings
+
+import numpy as np
+
+import pyarrow as pa
+from pyarrow.lib import _pandas_api, builtin_pickle, frombytes # noqa
+
+
+_logical_type_map = {}
+
+
+def get_logical_type_map():
+ global _logical_type_map
+
+ if not _logical_type_map:
+ _logical_type_map.update({
+ pa.lib.Type_NA: 'empty',
+ pa.lib.Type_BOOL: 'bool',
+ pa.lib.Type_INT8: 'int8',
+ pa.lib.Type_INT16: 'int16',
+ pa.lib.Type_INT32: 'int32',
+ pa.lib.Type_INT64: 'int64',
+ pa.lib.Type_UINT8: 'uint8',
+ pa.lib.Type_UINT16: 'uint16',
+ pa.lib.Type_UINT32: 'uint32',
+ pa.lib.Type_UINT64: 'uint64',
+ pa.lib.Type_HALF_FLOAT: 'float16',
+ pa.lib.Type_FLOAT: 'float32',
+ pa.lib.Type_DOUBLE: 'float64',
+ pa.lib.Type_DATE32: 'date',
+ pa.lib.Type_DATE64: 'date',
+ pa.lib.Type_TIME32: 'time',
+ pa.lib.Type_TIME64: 'time',
+ pa.lib.Type_BINARY: 'bytes',
+ pa.lib.Type_FIXED_SIZE_BINARY: 'bytes',
+ pa.lib.Type_STRING: 'unicode',
+ })
+ return _logical_type_map
+
+
+def get_logical_type(arrow_type):
+ logical_type_map = get_logical_type_map()
+
+ try:
+ return logical_type_map[arrow_type.id]
+ except KeyError:
+ if isinstance(arrow_type, pa.lib.DictionaryType):
+ return 'categorical'
+ elif isinstance(arrow_type, pa.lib.ListType):
+ return 'list[{}]'.format(get_logical_type(arrow_type.value_type))
+ elif isinstance(arrow_type, pa.lib.TimestampType):
+ return 'datetimetz' if arrow_type.tz is not None else 'datetime'
+ elif isinstance(arrow_type, pa.lib.Decimal128Type):
+ return 'decimal'
+ return 'object'
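+
+# A few illustrative mappings (sketch; assumes pyarrow is importable as pa):
+#
+#   >>> get_logical_type(pa.int64())
+#   'int64'
+#   >>> get_logical_type(pa.timestamp('ns', tz='UTC'))
+#   'datetimetz'
+#   >>> get_logical_type(pa.list_(pa.string()))
+#   'list[unicode]'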
+
+
+_numpy_logical_type_map = {
+ np.bool_: 'bool',
+ np.int8: 'int8',
+ np.int16: 'int16',
+ np.int32: 'int32',
+ np.int64: 'int64',
+ np.uint8: 'uint8',
+ np.uint16: 'uint16',
+ np.uint32: 'uint32',
+ np.uint64: 'uint64',
+ np.float32: 'float32',
+ np.float64: 'float64',
+ 'datetime64[D]': 'date',
+ np.unicode_: 'string',
+ np.bytes_: 'bytes',
+}
+
+
+def get_logical_type_from_numpy(pandas_collection):
+ try:
+ return _numpy_logical_type_map[pandas_collection.dtype.type]
+ except KeyError:
+ if hasattr(pandas_collection.dtype, 'tz'):
+ return 'datetimetz'
+ # See https://github.com/pandas-dev/pandas/issues/24739
+ if str(pandas_collection.dtype) == 'datetime64[ns]':
+ return 'datetime64[ns]'
+ result = _pandas_api.infer_dtype(pandas_collection)
+ if result == 'string':
+ return 'unicode'
+ return result
+
+
+def get_extension_dtype_info(column):
+ dtype = column.dtype
+ if str(dtype) == 'category':
+ cats = getattr(column, 'cat', column)
+ assert cats is not None
+ metadata = {
+ 'num_categories': len(cats.categories),
+ 'ordered': cats.ordered,
+ }
+ physical_dtype = str(cats.codes.dtype)
+ elif hasattr(dtype, 'tz'):
+ metadata = {'timezone': pa.lib.tzinfo_to_string(dtype.tz)}
+ physical_dtype = 'datetime64[ns]'
+ else:
+ metadata = None
+ physical_dtype = str(dtype)
+ return physical_dtype, metadata
+
+
+def get_column_metadata(column, name, arrow_type, field_name):
+ """Construct the metadata for a given column
+
+ Parameters
+ ----------
+ column : pandas.Series or pandas.Index
+ name : str
+ arrow_type : pyarrow.DataType
+ field_name : str
+ The name of the field in the arrow Table's schema. Equivalent to
+ `name` when `column` is a `Series`; when `column` is a pandas Index,
+ `field_name` will not be the same as `name`.
+
+ Returns
+ -------
+ dict
+ """
+ logical_type = get_logical_type(arrow_type)
+
+ string_dtype, extra_metadata = get_extension_dtype_info(column)
+ if logical_type == 'decimal':
+ extra_metadata = {
+ 'precision': arrow_type.precision,
+ 'scale': arrow_type.scale,
+ }
+ string_dtype = 'object'
+
+ if name is not None and not isinstance(name, str):
+ raise TypeError(
+ 'Column name must be a string. Got column {} of type {}'.format(
+ name, type(name).__name__
+ )
+ )
+
+ assert field_name is None or isinstance(field_name, str), \
+ str(type(field_name))
+ return {
+ 'name': name,
+ 'field_name': 'None' if field_name is None else field_name,
+ 'pandas_type': logical_type,
+ 'numpy_type': string_dtype,
+ 'metadata': extra_metadata,
+ }
+
+
+def construct_metadata(columns_to_convert, df, column_names, index_levels,
+ index_descriptors, preserve_index, types):
+ """Returns a dictionary containing enough metadata to reconstruct a pandas
+ DataFrame as an Arrow Table, including index columns.
+
+ Parameters
+ ----------
+ columns_to_convert : list[pd.Series]
+ df : pandas.DataFrame
+ index_levels : List[pd.Index]
+ index_descriptors : List[Dict]
+ preserve_index : bool
+ types : List[pyarrow.DataType]
+
+ Returns
+ -------
+ dict
+ """
+ num_serialized_index_levels = len([descr for descr in index_descriptors
+ if not isinstance(descr, dict)])
+ # Use ntypes instead of Python shorthand notation [:-len(x)] as [:-0]
+ # behaves differently to what we want.
+ ntypes = len(types)
+ df_types = types[:ntypes - num_serialized_index_levels]
+ index_types = types[ntypes - num_serialized_index_levels:]
+
+ column_metadata = []
+ for col, sanitized_name, arrow_type in zip(columns_to_convert,
+ column_names, df_types):
+ metadata = get_column_metadata(col, name=sanitized_name,
+ arrow_type=arrow_type,
+ field_name=sanitized_name)
+ column_metadata.append(metadata)
+
+ index_column_metadata = []
+ if preserve_index is not False:
+ for level, arrow_type, descriptor in zip(index_levels, index_types,
+ index_descriptors):
+ if isinstance(descriptor, dict):
+ # The index is represented in a non-serialized fashion,
+ # e.g. RangeIndex
+ continue
+ metadata = get_column_metadata(level, name=level.name,
+ arrow_type=arrow_type,
+ field_name=descriptor)
+ index_column_metadata.append(metadata)
+
+ column_indexes = []
+
+ levels = getattr(df.columns, 'levels', [df.columns])
+ names = getattr(df.columns, 'names', [df.columns.name])
+ for level, name in zip(levels, names):
+ metadata = _get_simple_index_descriptor(level, name)
+ column_indexes.append(metadata)
+ else:
+ index_descriptors = index_column_metadata = column_indexes = []
+
+ return {
+ b'pandas': json.dumps({
+ 'index_columns': index_descriptors,
+ 'column_indexes': column_indexes,
+ 'columns': column_metadata + index_column_metadata,
+ 'creator': {
+ 'library': 'pyarrow',
+ 'version': pa.__version__
+ },
+ 'pandas_version': _pandas_api.version
+ }).encode('utf8')
+ }
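+
+# For illustration, the b'pandas' metadata produced above decodes to JSON of
+# roughly this shape (a sketch; actual values depend on the DataFrame):
+#
+#   {
+#       "index_columns": [{"kind": "range", "name": null, ...}],
+#       "column_indexes": [...],
+#       "columns": [
+#           {"name": "a", "field_name": "a", "pandas_type": "int64",
+#            "numpy_type": "int64", "metadata": null},
+#           ...
+#       ],
+#       "creator": {"library": "pyarrow", "version": "..."},
+#       "pandas_version": "..."
+#   }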
+
+
+def _get_simple_index_descriptor(level, name):
+ string_dtype, extra_metadata = get_extension_dtype_info(level)
+ pandas_type = get_logical_type_from_numpy(level)
+ if 'mixed' in pandas_type:
+ warnings.warn(
+ "The DataFrame has column names of mixed type. They will be "
+ "converted to strings and not roundtrip correctly.",
+ UserWarning, stacklevel=4)
+ if pandas_type == 'unicode':
+ assert not extra_metadata
+ extra_metadata = {'encoding': 'UTF-8'}
+ return {
+ 'name': name,
+ 'field_name': name,
+ 'pandas_type': pandas_type,
+ 'numpy_type': string_dtype,
+ 'metadata': extra_metadata,
+ }
+
+
+def _column_name_to_strings(name):
+ """Convert a column name (or level) to either a string or a recursive
+ collection of strings.
+
+ Parameters
+ ----------
+ name : str or tuple
+
+ Returns
+ -------
+ value : str or tuple
+
+ Examples
+ --------
+ >>> name = 'foo'
+ >>> _column_name_to_strings(name)
+ 'foo'
+ >>> name = ('foo', 'bar')
+ >>> _column_name_to_strings(name)
+ "('foo', 'bar')"
+ >>> import pandas as pd
+ >>> name = (1, pd.Timestamp('2017-02-01 00:00:00'))
+ >>> _column_name_to_strings(name)
+ "('1', '2017-02-01 00:00:00')"
+ """
+ if isinstance(name, str):
+ return name
+ elif isinstance(name, bytes):
+ # XXX: should we assume that bytes in Python 3 are UTF-8?
+ return name.decode('utf8')
+ elif isinstance(name, tuple):
+ return str(tuple(map(_column_name_to_strings, name)))
+ elif isinstance(name, Sequence):
+ raise TypeError("Unsupported type for MultiIndex level")
+ elif name is None:
+ return None
+ return str(name)
+
+
+def _index_level_name(index, i, column_names):
+ """Return the name of an index level or a default name if `index.name` is
+ None or is already a column name.
+
+ Parameters
+ ----------
+ index : pandas.Index
+ i : int
+
+ Returns
+ -------
+ name : str
+ """
+ if index.name is not None and index.name not in column_names:
+ return index.name
+ else:
+ return '__index_level_{:d}__'.format(i)
+
+
+def _get_columns_to_convert(df, schema, preserve_index, columns):
+ columns = _resolve_columns_of_interest(df, schema, columns)
+
+ if not df.columns.is_unique:
+ raise ValueError(
+ 'Duplicate column names found: {}'.format(list(df.columns))
+ )
+
+ if schema is not None:
+ return _get_columns_to_convert_given_schema(df, schema, preserve_index)
+
+ column_names = []
+
+ index_levels = (
+ _get_index_level_values(df.index) if preserve_index is not False
+ else []
+ )
+
+ columns_to_convert = []
+ convert_fields = []
+
+ for name in columns:
+ col = df[name]
+ name = _column_name_to_strings(name)
+
+ if _pandas_api.is_sparse(col):
+ raise TypeError(
+ "Sparse pandas data (column {}) not supported.".format(name))
+
+ columns_to_convert.append(col)
+ convert_fields.append(None)
+ column_names.append(name)
+
+ index_descriptors = []
+ index_column_names = []
+ for i, index_level in enumerate(index_levels):
+ name = _index_level_name(index_level, i, column_names)
+ if (isinstance(index_level, _pandas_api.pd.RangeIndex) and
+ preserve_index is None):
+ descr = _get_range_index_descriptor(index_level)
+ else:
+ columns_to_convert.append(index_level)
+ convert_fields.append(None)
+ descr = name
+ index_column_names.append(name)
+ index_descriptors.append(descr)
+
+ all_names = column_names + index_column_names
+
+ # all_names : all of the columns in the resulting table including the data
+ # columns and serialized index columns
+ # column_names : the names of the data columns
+ # index_column_names : the names of the serialized index columns
+ # index_descriptors : descriptions of each index to be used for
+ # reconstruction
+ # index_levels : the extracted index level values
+ # columns_to_convert : assembled raw data (both data columns and indexes)
+ # to be converted to Arrow format
+ # convert_fields : schema fields (or None) to use for coercion / casting
+ # during serialization, if a Schema was provided
+ return (all_names, column_names, index_column_names, index_descriptors,
+ index_levels, columns_to_convert, convert_fields)
+
+
+def _get_columns_to_convert_given_schema(df, schema, preserve_index):
+ """
+ Specialized version of _get_columns_to_convert in case a Schema is
+ specified.
+ In that case, the Schema is used as the single point of truth for the
+ table structure (types, which columns are included, order of columns, ...).
+ """
+ column_names = []
+ columns_to_convert = []
+ convert_fields = []
+ index_descriptors = []
+ index_column_names = []
+ index_levels = []
+
+ for name in schema.names:
+ try:
+ col = df[name]
+ is_index = False
+ except KeyError:
+ try:
+ col = _get_index_level(df, name)
+ except (KeyError, IndexError):
+ # name not found as index level
+ raise KeyError(
+ "name '{}' present in the specified schema is not found "
+ "in the columns or index".format(name))
+ if preserve_index is False:
+ raise ValueError(
+ "name '{}' present in the specified schema corresponds "
+ "to the index, but 'preserve_index=False' was "
+ "specified".format(name))
+ elif (preserve_index is None and
+ isinstance(col, _pandas_api.pd.RangeIndex)):
+ raise ValueError(
+ "name '{}' is present in the specified schema, but it "
+ "is a RangeIndex, which is stored as metadata only and "
+ "not converted to a column in the Table. Specify "
+ "'preserve_index=True' to force it to be added as a "
+ "column, or remove it from the specified "
+ "schema".format(name))
+ is_index = True
+
+ name = _column_name_to_strings(name)
+
+ if _pandas_api.is_sparse(col):
+ raise TypeError(
+ "Sparse pandas data (column {}) not supported.".format(name))
+
+ field = schema.field(name)
+ columns_to_convert.append(col)
+ convert_fields.append(field)
+ column_names.append(name)
+
+ if is_index:
+ index_column_names.append(name)
+ index_descriptors.append(name)
+ index_levels.append(col)
+
+ all_names = column_names + index_column_names
+
+ return (all_names, column_names, index_column_names, index_descriptors,
+ index_levels, columns_to_convert, convert_fields)
+
+
+def _get_index_level(df, name):
+ """
+ Get the index level of a DataFrame given 'name' (column name in an arrow
+ Schema).
+ """
+ key = name
+ if name not in df.index.names and _is_generated_index_name(name):
+ # we know we have an autogenerated name => extract number and get
+ # the index level positionally
+ key = int(name[len("__index_level_"):-2])
+ return df.index.get_level_values(key)
+
+
+def _level_name(name):
+ # preserve type when default serializable, otherwise str it
+ try:
+ json.dumps(name)
+ return name
+ except TypeError:
+ return str(name)
+
+
+def _get_range_index_descriptor(level):
+ # public start/stop/step attributes added in pandas 0.25.0
+ return {
+ 'kind': 'range',
+ 'name': _level_name(level.name),
+ 'start': _pandas_api.get_rangeindex_attribute(level, 'start'),
+ 'stop': _pandas_api.get_rangeindex_attribute(level, 'stop'),
+ 'step': _pandas_api.get_rangeindex_attribute(level, 'step')
+ }
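+
+# Illustrative descriptor for a hypothetical pd.RangeIndex(0, 10, 2, name='r'):
+#
+#   {'kind': 'range', 'name': 'r', 'start': 0, 'stop': 10, 'step': 2}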
+
+
+def _get_index_level_values(index):
+ n = len(getattr(index, 'levels', [index]))
+ return [index.get_level_values(i) for i in range(n)]
+
+
+def _resolve_columns_of_interest(df, schema, columns):
+ if schema is not None and columns is not None:
+ raise ValueError('Schema and columns arguments are mutually '
+ 'exclusive, pass only one of them')
+ elif schema is not None:
+ columns = schema.names
+ elif columns is not None:
+ columns = [c for c in columns if c in df.columns]
+ else:
+ columns = df.columns
+
+ return columns
+
+
+def dataframe_to_types(df, preserve_index, columns=None):
+ (all_names,
+ column_names,
+ _,
+ index_descriptors,
+ index_columns,
+ columns_to_convert,
+ _) = _get_columns_to_convert(df, None, preserve_index, columns)
+
+ types = []
+ # If pandas knows type, skip conversion
+ for c in columns_to_convert:
+ values = c.values
+ if _pandas_api.is_categorical(values):
+ type_ = pa.array(c, from_pandas=True).type
+ elif _pandas_api.is_extension_array_dtype(values):
+ type_ = pa.array(c.head(0), from_pandas=True).type
+ else:
+ values, type_ = get_datetimetz_type(values, c.dtype, None)
+ type_ = pa.lib._ndarray_to_arrow_type(values, type_)
+ if type_ is None:
+ type_ = pa.array(c, from_pandas=True).type
+ types.append(type_)
+
+ metadata = construct_metadata(
+ columns_to_convert, df, column_names, index_columns,
+ index_descriptors, preserve_index, types
+ )
+
+ return all_names, types, metadata
+
+
+def dataframe_to_arrays(df, schema, preserve_index, nthreads=1, columns=None,
+ safe=True):
+ (all_names,
+ column_names,
+ index_column_names,
+ index_descriptors,
+ index_columns,
+ columns_to_convert,
+ convert_fields) = _get_columns_to_convert(df, schema, preserve_index,
+ columns)
+
+ # NOTE(wesm): If nthreads=None, then we use a heuristic to decide whether
+ # using a thread pool is worth it. Currently the heuristic is whether the
+ # nrows > 100 * ncols and ncols > 1.
+ if nthreads is None:
+ nrows, ncols = len(df), len(df.columns)
+ if nrows > ncols * 100 and ncols > 1:
+ nthreads = pa.cpu_count()
+ else:
+ nthreads = 1
+
+ def convert_column(col, field):
+ if field is None:
+ field_nullable = True
+ type_ = None
+ else:
+ field_nullable = field.nullable
+ type_ = field.type
+
+ try:
+ result = pa.array(col, type=type_, from_pandas=True, safe=safe)
+ except (pa.ArrowInvalid,
+ pa.ArrowNotImplementedError,
+ pa.ArrowTypeError) as e:
+ e.args += ("Conversion failed for column {!s} with type {!s}"
+ .format(col.name, col.dtype),)
+ raise e
+ if not field_nullable and result.null_count > 0:
+ raise ValueError("Field {} was non-nullable but pandas column "
+ "had {} null values".format(str(field),
+ result.null_count))
+ return result
+
+ def _can_definitely_zero_copy(arr):
+ return (isinstance(arr, np.ndarray) and
+ arr.flags.contiguous and
+ issubclass(arr.dtype.type, np.integer))
+
+ if nthreads == 1:
+ arrays = [convert_column(c, f)
+ for c, f in zip(columns_to_convert, convert_fields)]
+ else:
+ arrays = []
+ with futures.ThreadPoolExecutor(nthreads) as executor:
+ for c, f in zip(columns_to_convert, convert_fields):
+ if _can_definitely_zero_copy(c.values):
+ arrays.append(convert_column(c, f))
+ else:
+ arrays.append(executor.submit(convert_column, c, f))
+
+ for i, maybe_fut in enumerate(arrays):
+ if isinstance(maybe_fut, futures.Future):
+ arrays[i] = maybe_fut.result()
+
+ types = [x.type for x in arrays]
+
+ if schema is None:
+ fields = []
+ for name, type_ in zip(all_names, types):
+ name = name if name is not None else 'None'
+ fields.append(pa.field(name, type_))
+ schema = pa.schema(fields)
+
+ pandas_metadata = construct_metadata(
+ columns_to_convert, df, column_names, index_columns,
+ index_descriptors, preserve_index, types
+ )
+ metadata = deepcopy(schema.metadata) if schema.metadata else dict()
+ metadata.update(pandas_metadata)
+ schema = schema.with_metadata(metadata)
+
+ return arrays, schema
+
+
+def get_datetimetz_type(values, dtype, type_):
+ if values.dtype.type != np.datetime64:
+ return values, type_
+
+ if _pandas_api.is_datetimetz(dtype) and type_ is None:
+ # If no user type passed, construct a tz-aware timestamp type
+ tz = dtype.tz
+ unit = dtype.unit
+ type_ = pa.timestamp(unit, tz)
+ elif type_ is None:
+ # Trust the NumPy dtype
+ type_ = pa.from_numpy_dtype(values.dtype)
+
+ return values, type_
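+
+# Sketch of the two branches above (example dtypes are illustrative):
+#
+#   - tz-aware pandas dtype, no explicit type:
+#       dtype = pd.DatetimeTZDtype('ns', 'UTC')     ->  pa.timestamp('ns', 'UTC')
+#   - plain datetime64[ns] values, no explicit type:
+#       values.dtype = np.dtype('datetime64[ns]')   ->  pa.timestamp('ns')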
+
+# ----------------------------------------------------------------------
+# Converting pandas.DataFrame to a dict containing only NumPy arrays or other
+# objects friendly to pyarrow.serialize
+
+
+def dataframe_to_serialized_dict(frame):
+ block_manager = frame._data
+
+ blocks = []
+ axes = [ax for ax in block_manager.axes]
+
+ for block in block_manager.blocks:
+ values = block.values
+ block_data = {}
+
+ if _pandas_api.is_datetimetz(values.dtype):
+ block_data['timezone'] = pa.lib.tzinfo_to_string(values.tz)
+ if hasattr(values, 'values'):
+ values = values.values
+ elif _pandas_api.is_categorical(values):
+ block_data.update(dictionary=values.categories,
+ ordered=values.ordered)
+ values = values.codes
+ block_data.update(
+ placement=block.mgr_locs.as_array,
+ block=values
+ )
+
+ # If we are dealing with an object array, pickle it instead.
+ if values.dtype == np.dtype(object):
+ block_data['object'] = None
+ block_data['block'] = builtin_pickle.dumps(
+ values, protocol=builtin_pickle.HIGHEST_PROTOCOL)
+
+ blocks.append(block_data)
+
+ return {
+ 'blocks': blocks,
+ 'axes': axes
+ }
+
+
+def serialized_dict_to_dataframe(data):
+ import pandas.core.internals as _int
+ reconstructed_blocks = [_reconstruct_block(block)
+ for block in data['blocks']]
+
+ block_mgr = _int.BlockManager(reconstructed_blocks, data['axes'])
+ return _pandas_api.data_frame(block_mgr)
+
+
+def _reconstruct_block(item, columns=None, extension_columns=None):
+ """
+ Construct a pandas Block from the `item` dictionary coming from pyarrow's
+ serialization or returned by arrow::python::ConvertTableToPandas.
+
+ This function takes care of converting dictionary types to pandas
+ categorical, Timestamp-with-timezones to the proper pandas Block, and
+ conversion to pandas ExtensionBlock
+
+ Parameters
+ ----------
+ item : dict
+ For basic types, this is a dictionary in the form of
+ {'block': np.ndarray of values, 'placement': pandas block placement}.
+ Additional keys are present for other types (dictionary, timezone,
+ object).
+ columns :
+ Column names of the table being constructed, used for extension types
+ extension_columns : dict
+ Dictionary of {column_name: pandas_dtype} that includes all columns
+ and corresponding dtypes that will be converted to a pandas
+ ExtensionBlock.
+
+ Returns
+ -------
+ pandas Block
+
+ """
+ import pandas.core.internals as _int
+
+ block_arr = item.get('block', None)
+ placement = item['placement']
+ if 'dictionary' in item:
+ cat = _pandas_api.categorical_type.from_codes(
+ block_arr, categories=item['dictionary'],
+ ordered=item['ordered'])
+ block = _int.make_block(cat, placement=placement)
+ elif 'timezone' in item:
+ dtype = make_datetimetz(item['timezone'])
+ block = _int.make_block(block_arr, placement=placement,
+ klass=_int.DatetimeTZBlock,
+ dtype=dtype)
+ elif 'object' in item:
+ block = _int.make_block(builtin_pickle.loads(block_arr),
+ placement=placement)
+ elif 'py_array' in item:
+ # create ExtensionBlock
+ arr = item['py_array']
+ assert len(placement) == 1
+ name = columns[placement[0]]
+ pandas_dtype = extension_columns[name]
+ if not hasattr(pandas_dtype, '__from_arrow__'):
+ raise ValueError("This column does not support to be converted "
+ "to a pandas ExtensionArray")
+ pd_ext_arr = pandas_dtype.__from_arrow__(arr)
+ block = _int.make_block(pd_ext_arr, placement=placement)
+ else:
+ block = _int.make_block(block_arr, placement=placement)
+
+ return block
+
+
+def make_datetimetz(tz):
+ tz = pa.lib.string_to_tzinfo(tz)
+ return _pandas_api.datetimetz_type('ns', tz=tz)
+
+
+# ----------------------------------------------------------------------
+# Converting pyarrow.Table efficiently to pandas.DataFrame
+
+
+def table_to_blockmanager(options, table, categories=None,
+ ignore_metadata=False, types_mapper=None):
+ from pandas.core.internals import BlockManager
+
+ all_columns = []
+ column_indexes = []
+ pandas_metadata = table.schema.pandas_metadata
+
+ if not ignore_metadata and pandas_metadata is not None:
+ all_columns = pandas_metadata['columns']
+ column_indexes = pandas_metadata.get('column_indexes', [])
+ index_descriptors = pandas_metadata['index_columns']
+ table = _add_any_metadata(table, pandas_metadata)
+ table, index = _reconstruct_index(table, index_descriptors,
+ all_columns)
+ ext_columns_dtypes = _get_extension_dtypes(
+ table, all_columns, types_mapper)
+ else:
+ index = _pandas_api.pd.RangeIndex(table.num_rows)
+ ext_columns_dtypes = _get_extension_dtypes(table, [], types_mapper)
+
+ _check_data_column_metadata_consistency(all_columns)
+ columns = _deserialize_column_index(table, all_columns, column_indexes)
+ blocks = _table_to_blocks(options, table, categories, ext_columns_dtypes)
+
+ axes = [columns, index]
+ return BlockManager(blocks, axes)
+
+
+# Set of the string repr of all numpy dtypes that can be stored in a pandas
+# dataframe (complex not included since not supported by Arrow)
+_pandas_supported_numpy_types = {
+ str(np.dtype(typ))
+ for typ in (np.sctypes['int'] + np.sctypes['uint'] + np.sctypes['float'] +
+ ['object', 'bool'])
+}
+
+
+def _get_extension_dtypes(table, columns_metadata, types_mapper=None):
+ """
+ Based on the stored column pandas metadata and the extension types
+ in the arrow schema, infer which columns should be converted to a
+ pandas extension dtype.
+
+ The 'numpy_type' field in the column metadata stores the string
+ representation of the original pandas dtype (and, despite its name,
+ not the 'pandas_type' field).
+ Based on this string representation, a pandas/numpy dtype is constructed
+ and then we can check if this dtype supports conversion from arrow.
+
+ """
+ ext_columns = {}
+
+ # older pandas version that does not yet support extension dtypes
+ if _pandas_api.extension_dtype is None:
+ return ext_columns
+
+ # infer the extension columns from the pandas metadata
+ for col_meta in columns_metadata:
+ name = col_meta['name']
+ dtype = col_meta['numpy_type']
+ if dtype not in _pandas_supported_numpy_types:
+ # pandas_dtype is expensive, so avoid doing this for types
+ # that are certainly numpy dtypes
+ pandas_dtype = _pandas_api.pandas_dtype(dtype)
+ if isinstance(pandas_dtype, _pandas_api.extension_dtype):
+ if hasattr(pandas_dtype, "__from_arrow__"):
+ ext_columns[name] = pandas_dtype
+
+ # infer from extension type in the schema
+ for field in table.schema:
+ typ = field.type
+ if isinstance(typ, pa.BaseExtensionType):
+ try:
+ pandas_dtype = typ.to_pandas_dtype()
+ except NotImplementedError:
+ pass
+ else:
+ ext_columns[field.name] = pandas_dtype
+
+ # use the specified mapping of built-in arrow types to pandas dtypes
+ if types_mapper:
+ for field in table.schema:
+ typ = field.type
+ pandas_dtype = types_mapper(typ)
+ if pandas_dtype is not None:
+ ext_columns[field.name] = pandas_dtype
+
+ return ext_columns
+
+
+def _check_data_column_metadata_consistency(all_columns):
+ # It can never be the case in a released version of pyarrow that
+ # c['name'] is None *and* 'field_name' is not a key in the column metadata,
+ # because the change to allow c['name'] to be None and the change to add
+ # 'field_name' are in the same release (0.8.0)
+ assert all(
+ (c['name'] is None and 'field_name' in c) or c['name'] is not None
+ for c in all_columns
+ )
+
+
+def _deserialize_column_index(block_table, all_columns, column_indexes):
+ column_strings = [frombytes(x) if isinstance(x, bytes) else x
+ for x in block_table.column_names]
+ if all_columns:
+ columns_name_dict = {
+ c.get('field_name', _column_name_to_strings(c['name'])): c['name']
+ for c in all_columns
+ }
+ columns_values = [
+ columns_name_dict.get(name, name) for name in column_strings
+ ]
+ else:
+ columns_values = column_strings
+
+ # If we're passed multiple column indexes then evaluate with
+ # ast.literal_eval, since the column index values show up as a list of
+ # tuples
+ to_pair = ast.literal_eval if len(column_indexes) > 1 else lambda x: (x,)
+
+ # Create the column index
+
+ # Construct the base index
+ if not columns_values:
+ columns = _pandas_api.pd.Index(columns_values)
+ else:
+ columns = _pandas_api.pd.MultiIndex.from_tuples(
+ list(map(to_pair, columns_values)),
+ names=[col_index['name'] for col_index in column_indexes] or None,
+ )
+
+ # if we're reconstructing the index
+ if len(column_indexes) > 0:
+ columns = _reconstruct_columns_from_metadata(columns, column_indexes)
+
+ # ARROW-1751: flatten a single level column MultiIndex for pandas 0.21.0
+ columns = _flatten_single_level_multiindex(columns)
+
+ return columns
+
+
+def _reconstruct_index(table, index_descriptors, all_columns):
+ # 0. 'field_name' is the name of the column in the arrow Table
+ # 1. 'name' is the user-facing name of the column, that is, it came from
+ # pandas
+ # 2. 'field_name' and 'name' differ for index columns
+ # 3. We fall back on c['name'] for backwards compatibility
+ field_name_to_metadata = {
+ c.get('field_name', c['name']): c
+ for c in all_columns
+ }
+
+ # Build up a list of index columns and names while removing those columns
+ # from the original table
+ index_arrays = []
+ index_names = []
+ result_table = table
+ for descr in index_descriptors:
+ if isinstance(descr, str):
+ result_table, index_level, index_name = _extract_index_level(
+ table, result_table, descr, field_name_to_metadata)
+ if index_level is None:
+ # ARROW-1883: the serialized index column was not found
+ continue
+ elif descr['kind'] == 'range':
+ index_name = descr['name']
+ index_level = _pandas_api.pd.RangeIndex(descr['start'],
+ descr['stop'],
+ step=descr['step'],
+ name=index_name)
+ if len(index_level) != len(table):
+ # Possibly the result of munged metadata
+ continue
+ else:
+ raise ValueError("Unrecognized index kind: {}"
+ .format(descr['kind']))
+ index_arrays.append(index_level)
+ index_names.append(index_name)
+
+ pd = _pandas_api.pd
+
+ # Reconstruct the row index
+ if len(index_arrays) > 1:
+ index = pd.MultiIndex.from_arrays(index_arrays, names=index_names)
+ elif len(index_arrays) == 1:
+ index = index_arrays[0]
+ if not isinstance(index, pd.Index):
+ # Box anything that wasn't boxed above
+ index = pd.Index(index, name=index_names[0])
+ else:
+ index = pd.RangeIndex(table.num_rows)
+
+ return result_table, index
+
+
+def _extract_index_level(table, result_table, field_name,
+ field_name_to_metadata):
+ logical_name = field_name_to_metadata[field_name]['name']
+ index_name = _backwards_compatible_index_name(field_name, logical_name)
+ i = table.schema.get_field_index(field_name)
+
+ if i == -1:
+ # The serialized index column was removed by the user
+ return result_table, None, None
+
+ pd = _pandas_api.pd
+
+ col = table.column(i)
+ values = col.to_pandas().values
+
+ if hasattr(values, 'flags') and not values.flags.writeable:
+ # ARROW-1054: in pandas 0.19.2, factorize will reject
+ # non-writeable arrays when calling MultiIndex.from_arrays
+ values = values.copy()
+
+ if isinstance(col.type, pa.lib.TimestampType) and col.type.tz is not None:
+ index_level = make_tz_aware(pd.Series(values), col.type.tz)
+ else:
+ index_level = pd.Series(values, dtype=values.dtype)
+ result_table = result_table.remove_column(
+ result_table.schema.get_field_index(field_name)
+ )
+ return result_table, index_level, index_name
+
+
+def _backwards_compatible_index_name(raw_name, logical_name):
+ """Compute the name of an index column that is compatible with older
+ versions of :mod:`pyarrow`.
+
+ Parameters
+ ----------
+ raw_name : str
+ logical_name : str
+
+ Returns
+ -------
+ result : str
+
+ Notes
+ -----
+ * Part of :func:`~pyarrow.pandas_compat.table_to_blockmanager`
+ """
+ # Part of table_to_blockmanager
+ if raw_name == logical_name and _is_generated_index_name(raw_name):
+ return None
+ else:
+ return logical_name
+
+
+def _is_generated_index_name(name):
+ pattern = r'^__index_level_\d+__$'
+ return re.match(pattern, name) is not None
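+
+# Examples of the pattern above:
+#   _is_generated_index_name('__index_level_0__')  ->  True
+#   _is_generated_index_name('my_index')           ->  False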
+
+
+_pandas_logical_type_map = {
+ 'date': 'datetime64[D]',
+ 'datetime': 'datetime64[ns]',
+ 'unicode': np.unicode_,
+ 'bytes': np.bytes_,
+ 'string': np.str_,
+ 'integer': np.int64,
+ 'floating': np.float64,
+ 'empty': np.object_,
+}
+
+
+def _pandas_type_to_numpy_type(pandas_type):
+ """Get the numpy dtype that corresponds to a pandas type.
+
+ Parameters
+ ----------
+ pandas_type : str
+ The result of a call to pandas.lib.infer_dtype.
+
+ Returns
+ -------
+ dtype : np.dtype
+ The dtype that corresponds to `pandas_type`.
+ """
+ try:
+ return _pandas_logical_type_map[pandas_type]
+ except KeyError:
+ if 'mixed' in pandas_type:
+ # catching 'mixed', 'mixed-integer' and 'mixed-integer-float'
+ return np.object_
+ return np.dtype(pandas_type)
+
+
+def _get_multiindex_codes(mi):
+ # compat for pandas < 0.24 (MI labels renamed to codes).
+ if isinstance(mi, _pandas_api.pd.MultiIndex):
+ return mi.codes if hasattr(mi, 'codes') else mi.labels
+ else:
+ return None
+
+
+def _reconstruct_columns_from_metadata(columns, column_indexes):
+ """Construct a pandas MultiIndex from `columns` and column index metadata
+ in `column_indexes`.
+
+ Parameters
+ ----------
+ columns : List[pd.Index]
+ The columns coming from a pyarrow.Table
+ column_indexes : List[Dict[str, str]]
+ The column index metadata deserialized from the JSON schema metadata
+ in a :class:`~pyarrow.Table`.
+
+ Returns
+ -------
+ result : MultiIndex
+ The index reconstructed using `column_indexes` metadata with levels of
+ the correct type.
+
+ Notes
+ -----
+ * Part of :func:`~pyarrow.pandas_compat.table_to_blockmanager`
+ """
+ pd = _pandas_api.pd
+ # Get levels and labels, and provide sane defaults if the index has a
+ # single level to avoid if/else spaghetti.
+ levels = getattr(columns, 'levels', None) or [columns]
+ labels = _get_multiindex_codes(columns) or [
+ pd.RangeIndex(len(level)) for level in levels
+ ]
+
+ # Convert each level to the dtype provided in the metadata
+ levels_dtypes = [
+ (level, col_index.get('pandas_type', str(level.dtype)),
+ col_index.get('numpy_type', None))
+ for level, col_index in zip_longest(
+ levels, column_indexes, fillvalue={}
+ )
+ ]
+
+ new_levels = []
+ encoder = operator.methodcaller('encode', 'UTF-8')
+
+ for level, pandas_dtype, numpy_dtype in levels_dtypes:
+ dtype = _pandas_type_to_numpy_type(pandas_dtype)
+ # Since our metadata is UTF-8 encoded, Python turns things that were
+ # bytes into unicode strings when json.loads-ing them. We need to
+ # convert them back to bytes to preserve metadata.
+ if dtype == np.bytes_:
+ level = level.map(encoder)
+ elif level.dtype != dtype:
+ level = level.astype(dtype)
+ # ARROW-9096: if original DataFrame was upcast we keep that
+ if level.dtype != numpy_dtype:
+ level = level.astype(numpy_dtype)
+
+ new_levels.append(level)
+
+ return pd.MultiIndex(new_levels, labels, names=columns.names)
+
+
+def _table_to_blocks(options, block_table, categories, extension_columns):
+ # Part of table_to_blockmanager
+
+ # Convert an arrow table to Block from the internal pandas API
+ columns = block_table.column_names
+ result = pa.lib.table_to_blocks(options, block_table, categories,
+ list(extension_columns.keys()))
+ return [_reconstruct_block(item, columns, extension_columns)
+ for item in result]
+
+
+def _flatten_single_level_multiindex(index):
+ pd = _pandas_api.pd
+ if isinstance(index, pd.MultiIndex) and index.nlevels == 1:
+ levels, = index.levels
+ labels, = _get_multiindex_codes(index)
+ # ARROW-9096: use levels.dtype to match cast with original DataFrame
+ dtype = levels.dtype
+
+ # Cheaply check that we do not somehow have duplicate column names
+ if not index.is_unique:
+ raise ValueError('Found non-unique column index')
+
+ return pd.Index(
+ [levels[_label] if _label != -1 else None for _label in labels],
+ dtype=dtype,
+ name=index.names[0]
+ )
+ return index
+
+
+def _add_any_metadata(table, pandas_metadata):
+ modified_columns = {}
+ modified_fields = {}
+
+ schema = table.schema
+
+ index_columns = pandas_metadata['index_columns']
+ # only take index columns into account if they are an actual table column
+ index_columns = [idx_col for idx_col in index_columns
+ if isinstance(idx_col, str)]
+ n_index_levels = len(index_columns)
+ n_columns = len(pandas_metadata['columns']) - n_index_levels
+
+ # Add time zones
+ for i, col_meta in enumerate(pandas_metadata['columns']):
+
+ raw_name = col_meta.get('field_name')
+ if not raw_name:
+ # deal with metadata written with arrow < 0.8 or fastparquet
+ raw_name = col_meta['name']
+ if i >= n_columns:
+ # index columns
+ raw_name = index_columns[i - n_columns]
+ if raw_name is None:
+ raw_name = 'None'
+
+ idx = schema.get_field_index(raw_name)
+ if idx != -1:
+ if col_meta['pandas_type'] == 'datetimetz':
+ col = table[idx]
+ if not isinstance(col.type, pa.lib.TimestampType):
+ continue
+ metadata = col_meta['metadata']
+ if not metadata:
+ continue
+ metadata_tz = metadata.get('timezone')
+ if metadata_tz and metadata_tz != col.type.tz:
+ converted = col.to_pandas()
+ tz_aware_type = pa.timestamp('ns', tz=metadata_tz)
+ with_metadata = pa.Array.from_pandas(converted,
+ type=tz_aware_type)
+
+ modified_fields[idx] = pa.field(schema[idx].name,
+ tz_aware_type)
+ modified_columns[idx] = with_metadata
+
+ if len(modified_columns) > 0:
+ columns = []
+ fields = []
+ for i in range(len(table.schema)):
+ if i in modified_columns:
+ columns.append(modified_columns[i])
+ fields.append(modified_fields[i])
+ else:
+ columns.append(table[i])
+ fields.append(table.schema[i])
+ return pa.Table.from_arrays(columns, schema=pa.schema(fields))
+ else:
+ return table
+
+
+# ----------------------------------------------------------------------
+# Helper functions used in lib
+
+
+def make_tz_aware(series, tz):
+ """
+ Make a datetime64 Series timezone-aware for the given tz
+ """
+ tz = pa.lib.string_to_tzinfo(tz)
+ series = (series.dt.tz_localize('utc')
+ .dt.tz_convert(tz))
+ return series
diff --git a/src/arrow/python/pyarrow/parquet.py b/src/arrow/python/pyarrow/parquet.py
new file mode 100644
index 000000000..8041b4e3c
--- /dev/null
+++ b/src/arrow/python/pyarrow/parquet.py
@@ -0,0 +1,2299 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from collections import defaultdict
+from concurrent import futures
+from functools import partial, reduce
+
+import json
+from collections.abc import Collection
+import numpy as np
+import os
+import re
+import operator
+import urllib.parse
+import warnings
+
+import pyarrow as pa
+import pyarrow.lib as lib
+import pyarrow._parquet as _parquet
+
+from pyarrow._parquet import (ParquetReader, Statistics, # noqa
+ FileMetaData, RowGroupMetaData,
+ ColumnChunkMetaData,
+ ParquetSchema, ColumnSchema)
+from pyarrow.fs import (LocalFileSystem, FileSystem,
+ _resolve_filesystem_and_path, _ensure_filesystem)
+from pyarrow import filesystem as legacyfs
+from pyarrow.util import guid, _is_path_like, _stringify_path
+
+_URI_STRIP_SCHEMES = ('hdfs',)
+
+
+def _parse_uri(path):
+ path = _stringify_path(path)
+ parsed_uri = urllib.parse.urlparse(path)
+ if parsed_uri.scheme in _URI_STRIP_SCHEMES:
+ return parsed_uri.path
+ else:
+ # ARROW-4073: On Windows returning the path with the scheme
+ # stripped removes the drive letter, if any
+ return path
+
+
+def _get_filesystem_and_path(passed_filesystem, path):
+ if passed_filesystem is None:
+ return legacyfs.resolve_filesystem_and_path(path, passed_filesystem)
+ else:
+ passed_filesystem = legacyfs._ensure_filesystem(passed_filesystem)
+ parsed_path = _parse_uri(path)
+ return passed_filesystem, parsed_path
+
+
+def _check_contains_null(val):
+ if isinstance(val, bytes):
+ for byte in val:
+ if isinstance(byte, bytes):
+ compare_to = chr(0)
+ else:
+ compare_to = 0
+ if byte == compare_to:
+ return True
+ elif isinstance(val, str):
+ return '\x00' in val
+ return False
+
+
+def _check_filters(filters, check_null_strings=True):
+ """
+ Check if filters are well-formed.
+ """
+ if filters is not None:
+ if len(filters) == 0 or any(len(f) == 0 for f in filters):
+ raise ValueError("Malformed filters")
+ if isinstance(filters[0][0], str):
+ # We have encountered the situation where we have one nesting level
+ # too few:
+ # We have [(,,), ..] instead of [[(,,), ..]]
+ filters = [filters]
+ if check_null_strings:
+ for conjunction in filters:
+ for col, op, val in conjunction:
+ if (
+ isinstance(val, list) and
+ all(_check_contains_null(v) for v in val) or
+ _check_contains_null(val)
+ ):
+ raise NotImplementedError(
+ "Null-terminated binary strings are not supported "
+ "as filter values."
+ )
+ return filters
+
+
+_DNF_filter_doc = """Predicates are expressed in disjunctive normal form (DNF), like
+ ``[[('x', '=', 0), ...], ...]``. DNF allows arbitrary boolean logical
+ combinations of single column predicates. The innermost tuples each
+ describe a single column predicate. The list of inner predicates is
+ interpreted as a conjunction (AND), forming a more selective,
+ multiple-column predicate. Finally, the outermost list combines these
+ filters as a disjunction (OR).
+
+ Predicates may also be passed as List[Tuple]. This form is interpreted
+ as a single conjunction. To express OR in predicates, one must
+ use the (preferred) List[List[Tuple]] notation.
+
+ Each tuple has format: (``key``, ``op``, ``value``) and compares the
+ ``key`` with the ``value``.
+ The supported ``op`` are: ``=`` or ``==``, ``!=``, ``<``, ``>``, ``<=``,
+ ``>=``, ``in`` and ``not in``. If the ``op`` is ``in`` or ``not in``, the
+ ``value`` must be a collection such as a ``list``, a ``set`` or a
+ ``tuple``.
+
+ Examples:
+
+ .. code-block:: python
+
+ ('x', '=', 0)
+ ('y', 'in', ['a', 'b', 'c'])
+ ('z', 'not in', {'a','b'})
+
+ """
+
+
+def _filters_to_expression(filters):
+ """
+ Check if filters are well-formed.
+
+ See _DNF_filter_doc above for more details.
+ """
+ import pyarrow.dataset as ds
+
+ if isinstance(filters, ds.Expression):
+ return filters
+
+ filters = _check_filters(filters, check_null_strings=False)
+
+ def convert_single_predicate(col, op, val):
+ field = ds.field(col)
+
+ if op == "=" or op == "==":
+ return field == val
+ elif op == "!=":
+ return field != val
+ elif op == '<':
+ return field < val
+ elif op == '>':
+ return field > val
+ elif op == '<=':
+ return field <= val
+ elif op == '>=':
+ return field >= val
+ elif op == 'in':
+ return field.isin(val)
+ elif op == 'not in':
+ return ~field.isin(val)
+ else:
+ raise ValueError(
+ '"{0}" is not a valid operator in predicates.'.format(
+ (col, op, val)))
+
+ disjunction_members = []
+
+ for conjunction in filters:
+ conjunction_members = [
+ convert_single_predicate(col, op, val)
+ for col, op, val in conjunction
+ ]
+
+ disjunction_members.append(reduce(operator.and_, conjunction_members))
+
+ return reduce(operator.or_, disjunction_members)
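+
+# Illustrative conversion (a sketch using hypothetical column names):
+#
+#   _filters_to_expression([[('x', '=', 0), ('y', 'in', ['a', 'b'])]])
+#   # is equivalent to the dataset expression:
+#   #   (ds.field('x') == 0) & (ds.field('y').isin(['a', 'b']))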
+
+
+# ----------------------------------------------------------------------
+# Reading a single Parquet file
+
+
+class ParquetFile:
+ """
+ Reader interface for a single Parquet file.
+
+ Parameters
+ ----------
+ source : str, pathlib.Path, pyarrow.NativeFile, or file-like object
+ Readable source. To pass bytes or a buffer-like object containing a
+ Parquet file, use pyarrow.BufferReader.
+ metadata : FileMetaData, default None
+ Use existing metadata object, rather than reading from file.
+ common_metadata : FileMetaData, default None
+ Will be used in reads for pandas schema metadata if not found in the
+ main file's metadata; it has no other uses at the moment.
+ memory_map : bool, default False
+ If the source is a file path, use a memory map to read file, which can
+ improve performance in some environments.
+ buffer_size : int, default 0
+ If positive, perform read buffering when deserializing individual
+ column chunks. Otherwise IO calls are unbuffered.
+ pre_buffer : bool, default False
+ Coalesce and issue file reads in parallel to improve performance on
+ high-latency filesystems (e.g. S3). If True, Arrow will use a
+ background I/O thread pool.
+ read_dictionary : list
+ List of column names to read directly as DictionaryArray.
+ coerce_int96_timestamp_unit : str, default None.
+ Cast timestamps that are stored in INT96 format to a particular
+ resolution (e.g. 'ms'). Setting to None is equivalent to 'ns'
+ and therefore INT96 timestamps will be inferred as timestamps
+ in nanoseconds.
+ """
+
+ def __init__(self, source, metadata=None, common_metadata=None,
+ read_dictionary=None, memory_map=False, buffer_size=0,
+ pre_buffer=False, coerce_int96_timestamp_unit=None):
+ self.reader = ParquetReader()
+ self.reader.open(
+ source, use_memory_map=memory_map,
+ buffer_size=buffer_size, pre_buffer=pre_buffer,
+ read_dictionary=read_dictionary, metadata=metadata,
+ coerce_int96_timestamp_unit=coerce_int96_timestamp_unit
+ )
+ self.common_metadata = common_metadata
+ self._nested_paths_by_prefix = self._build_nested_paths()
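+
+ # Illustrative usage sketch ('example.parquet' is a hypothetical path):
+ #
+ #   pf = ParquetFile('example.parquet')
+ #   table = pf.read(columns=['a'])               # whole file as a Table
+ #   for batch in pf.iter_batches(batch_size=1024):
+ #       ...                                      # each item is a pyarrow.RecordBatch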
+
+ def _build_nested_paths(self):
+ paths = self.reader.column_paths
+
+ result = defaultdict(list)
+
+ for i, path in enumerate(paths):
+ key = path[0]
+ rest = path[1:]
+ while True:
+ result[key].append(i)
+
+ if not rest:
+ break
+
+ key = '.'.join((key, rest[0]))
+ rest = rest[1:]
+
+ return result
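+
+ # For illustration, given column paths [('a', 'b'), ('a', 'c'), ('d',)],
+ # the mapping built above is:
+ #   {'a': [0, 1], 'a.b': [0], 'a.c': [1], 'd': [2]}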
+
+ @property
+ def metadata(self):
+ return self.reader.metadata
+
+ @property
+ def schema(self):
+ """
+ Return the Parquet schema, unconverted to Arrow types
+ """
+ return self.metadata.schema
+
+ @property
+ def schema_arrow(self):
+ """
+ Return the inferred Arrow schema, converted from the whole Parquet
+ file's schema
+ """
+ return self.reader.schema_arrow
+
+ @property
+ def num_row_groups(self):
+ return self.reader.num_row_groups
+
+ def read_row_group(self, i, columns=None, use_threads=True,
+ use_pandas_metadata=False):
+ """
+ Read a single row group from a Parquet file.
+
+ Parameters
+ ----------
+ i : int
+ Index of the individual row group that we want to read.
+ columns : list
+ If not None, only these columns will be read from the row group. A
+ column name may be a prefix of a nested field, e.g. 'a' will select
+ 'a.b', 'a.c', and 'a.d.e'.
+ use_threads : bool, default True
+ Perform multi-threaded column reads.
+ use_pandas_metadata : bool, default False
+ If True and file has custom pandas schema metadata, ensure that
+ index columns are also loaded.
+
+ Returns
+ -------
+ pyarrow.table.Table
+ Content of the row group as a table (of columns)
+ """
+ column_indices = self._get_column_indices(
+ columns, use_pandas_metadata=use_pandas_metadata)
+ return self.reader.read_row_group(i, column_indices=column_indices,
+ use_threads=use_threads)
+
+ def read_row_groups(self, row_groups, columns=None, use_threads=True,
+ use_pandas_metadata=False):
+ """
+ Read multiple row groups from a Parquet file.
+
+ Parameters
+ ----------
+ row_groups : list
+ Only these row groups will be read from the file.
+ columns : list
+ If not None, only these columns will be read from the row group. A
+ column name may be a prefix of a nested field, e.g. 'a' will select
+ 'a.b', 'a.c', and 'a.d.e'.
+ use_threads : bool, default True
+ Perform multi-threaded column reads.
+ use_pandas_metadata : bool, default False
+ If True and file has custom pandas schema metadata, ensure that
+ index columns are also loaded.
+
+ Returns
+ -------
+ pyarrow.table.Table
+ Content of the row groups as a table (of columns).
+ """
+ column_indices = self._get_column_indices(
+ columns, use_pandas_metadata=use_pandas_metadata)
+ return self.reader.read_row_groups(row_groups,
+ column_indices=column_indices,
+ use_threads=use_threads)
+
+ def iter_batches(self, batch_size=65536, row_groups=None, columns=None,
+ use_threads=True, use_pandas_metadata=False):
+ """
+ Read streaming batches from a Parquet file
+
+ Parameters
+ ----------
+ batch_size : int, default 64K
+ Maximum number of records to yield per batch. Batches may be
+ smaller if there aren't enough rows in the file.
+ row_groups : list
+ Only these row groups will be read from the file.
+ columns : list
+ If not None, only these columns will be read from the file. A
+ column name may be a prefix of a nested field, e.g. 'a' will select
+ 'a.b', 'a.c', and 'a.d.e'.
+ use_threads : boolean, default True
+ Perform multi-threaded column reads.
+ use_pandas_metadata : boolean, default False
+ If True and file has custom pandas schema metadata, ensure that
+ index columns are also loaded.
+
+ Returns
+ -------
+ iterator of pyarrow.RecordBatch
+ Contents of each batch as a record batch
+ """
+ if row_groups is None:
+ row_groups = range(0, self.metadata.num_row_groups)
+ column_indices = self._get_column_indices(
+ columns, use_pandas_metadata=use_pandas_metadata)
+
+ batches = self.reader.iter_batches(batch_size,
+ row_groups=row_groups,
+ column_indices=column_indices,
+ use_threads=use_threads)
+ return batches
+
+ def read(self, columns=None, use_threads=True, use_pandas_metadata=False):
+ """
+ Read a Table from Parquet format.
+
+ Parameters
+ ----------
+ columns : list
+ If not None, only these columns will be read from the file. A
+ column name may be a prefix of a nested field, e.g. 'a' will select
+ 'a.b', 'a.c', and 'a.d.e'.
+ use_threads : bool, default True
+ Perform multi-threaded column reads.
+ use_pandas_metadata : bool, default False
+ If True and file has custom pandas schema metadata, ensure that
+ index columns are also loaded.
+
+ Returns
+ -------
+ pyarrow.table.Table
+ Content of the file as a table (of columns).
+ """
+ column_indices = self._get_column_indices(
+ columns, use_pandas_metadata=use_pandas_metadata)
+ return self.reader.read_all(column_indices=column_indices,
+ use_threads=use_threads)
+
+ def scan_contents(self, columns=None, batch_size=65536):
+ """
+ Read contents of file for the given columns and batch size.
+
+ Notes
+ -----
+ This function's primary purpose is benchmarking.
+ The scan is executed on a single thread.
+
+ Parameters
+ ----------
+ columns : list of integers, default None
+ Select columns to read, if None scan all columns.
+ batch_size : int, default 64K
+ Number of rows to read at a time internally.
+
+ Returns
+ -------
+ num_rows : int
+ Number of rows in the file.
+ """
+ column_indices = self._get_column_indices(columns)
+ return self.reader.scan_contents(column_indices,
+ batch_size=batch_size)
+
+ def _get_column_indices(self, column_names, use_pandas_metadata=False):
+ if column_names is None:
+ return None
+
+ indices = []
+
+ for name in column_names:
+ if name in self._nested_paths_by_prefix:
+ indices.extend(self._nested_paths_by_prefix[name])
+
+ if use_pandas_metadata:
+ file_keyvalues = self.metadata.metadata
+ common_keyvalues = (self.common_metadata.metadata
+ if self.common_metadata is not None
+ else None)
+
+ if file_keyvalues and b'pandas' in file_keyvalues:
+ index_columns = _get_pandas_index_columns(file_keyvalues)
+ elif common_keyvalues and b'pandas' in common_keyvalues:
+ index_columns = _get_pandas_index_columns(common_keyvalues)
+ else:
+ index_columns = []
+
+ if indices is not None and index_columns:
+ indices += [self.reader.column_name_idx(descr)
+ for descr in index_columns
+ if not isinstance(descr, dict)]
+
+ return indices
+
+
+_SPARK_DISALLOWED_CHARS = re.compile('[ ,;{}()\n\t=]')
+
+
+def _sanitized_spark_field_name(name):
+ return _SPARK_DISALLOWED_CHARS.sub('_', name)
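+
+# For example (sketch): _sanitized_spark_field_name('col name;1') -> 'col_name_1'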
+
+
+def _sanitize_schema(schema, flavor):
+ if 'spark' in flavor:
+ sanitized_fields = []
+
+ schema_changed = False
+
+ for field in schema:
+ name = field.name
+ sanitized_name = _sanitized_spark_field_name(name)
+
+ if sanitized_name != name:
+ schema_changed = True
+ sanitized_field = pa.field(sanitized_name, field.type,
+ field.nullable, field.metadata)
+ sanitized_fields.append(sanitized_field)
+ else:
+ sanitized_fields.append(field)
+
+ new_schema = pa.schema(sanitized_fields, metadata=schema.metadata)
+ return new_schema, schema_changed
+ else:
+ return schema, False
+
+
+def _sanitize_table(table, new_schema, flavor):
+ # TODO: This will not handle prohibited characters in nested field names
+ if 'spark' in flavor:
+ column_data = [table[i] for i in range(table.num_columns)]
+ return pa.Table.from_arrays(column_data, schema=new_schema)
+ else:
+ return table
+
+
+_parquet_writer_arg_docs = """version : {"1.0", "2.4", "2.6"}, default "1.0"
+ Determine which Parquet logical types are available for use, whether the
+ reduced set from the Parquet 1.x.x format or the expanded logical types
+ added in later format versions.
+ Files written with version='2.4' or '2.6' may not be readable in all
+ Parquet implementations, so version='1.0' is likely the choice that
+ maximizes file compatibility.
+ UINT32 and some logical types are only available with version '2.4'.
+ Nanosecond timestamps are only available with version '2.6'.
+ Other features such as compression algorithms or the new serialized
+ data page format must be enabled separately (see 'compression' and
+ 'data_page_version').
+use_dictionary : bool or list
+ Specify if we should use dictionary encoding in general or only for
+ some columns.
+use_deprecated_int96_timestamps : bool, default None
+ Write timestamps to INT96 Parquet format. Defaults to False unless enabled
+ by flavor argument. This takes priority over the coerce_timestamps option.
+coerce_timestamps : str, default None
+ Cast timestamps to a particular resolution. If omitted, defaults are chosen
+ depending on `version`: for ``version='1.0'`` (the default)
+ and ``version='2.4'``, nanoseconds are cast to microseconds ('us'), while
+ for other `version` values, they are written natively without loss
+ of resolution. Seconds are always cast to milliseconds ('ms') by default,
+ as Parquet does not have any temporal type with seconds resolution.
+ If the casting results in loss of data, it will raise an exception
+ unless ``allow_truncated_timestamps=True`` is given.
+ Valid values: {None, 'ms', 'us'}
+data_page_size : int, default None
+ Set a target threshold for the approximate encoded size of data
+ pages within a column chunk (in bytes). If None, use the default data page
+ size of 1MByte.
+allow_truncated_timestamps : bool, default False
+ Allow loss of data when coercing timestamps to a particular
+ resolution. E.g. if microsecond or nanosecond data is lost when coercing to
+ 'ms', do not raise an exception. Note that ``allow_truncated_timestamps=True``
+ only takes effect when ``coerce_timestamps`` is also set; on its own it does
+ not suppress the truncation exception.
+compression : str or dict
+ Specify the compression codec, either on a general basis or per-column.
+ Valid values: {'NONE', 'SNAPPY', 'GZIP', 'BROTLI', 'LZ4', 'ZSTD'}.
+write_statistics : bool or list
+ Specify if we should write statistics in general (default is True) or only
+ for some columns.
+flavor : {'spark'}, default None
+ Sanitize schema or set other compatibility options to work with
+ various target systems.
+filesystem : FileSystem, default None
+ If nothing passed, will be inferred from `where` if path-like, else
+ `where` is already a file-like object so no filesystem is needed.
+compression_level : int or dict, default None
+ Specify the compression level for a codec, either on a general basis or
+ per-column. If None is passed, arrow selects the compression level for
+ the compression codec in use. The compression level has a different
+ meaning for each codec, so you have to read the documentation of the
+ codec you are using.
+ An exception is thrown if the compression codec does not allow specifying
+ a compression level.
+use_byte_stream_split : bool or list, default False
+ Specify if the byte_stream_split encoding should be used in general or
+    only for some columns. If both dictionary and byte_stream_split are
+ enabled, then dictionary is preferred.
+ The byte_stream_split encoding is valid only for floating-point data types
+ and should be combined with a compression codec.
+data_page_version : {"1.0", "2.0"}, default "1.0"
+ The serialized Parquet data page format version to write, defaults to
+ 1.0. This does not impact the file schema logical types and Arrow to
+ Parquet type casting behavior; for that use the "version" option.
+use_compliant_nested_type : bool, default False
+ Whether to write compliant Parquet nested type (lists) as defined
+ `here <https://github.com/apache/parquet-format/blob/master/
+ LogicalTypes.md#nested-types>`_, defaults to ``False``.
+ For ``use_compliant_nested_type=True``, this will write into a list
+ with 3-level structure where the middle level, named ``list``,
+ is a repeated group with a single field named ``element``::
+
+ <list-repetition> group <name> (LIST) {
+ repeated group list {
+ <element-repetition> <element-type> element;
+ }
+ }
+
+ For ``use_compliant_nested_type=False``, this will also write into a list
+ with 3-level structure, where the name of the single field of the middle
+ level ``list`` is taken from the element name for nested columns in Arrow,
+ which defaults to ``item``::
+
+ <list-repetition> group <name> (LIST) {
+ repeated group list {
+ <element-repetition> <element-type> item;
+ }
+ }
+"""
+
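+# Editor's note: a hedged sketch (not part of the module) of how the writer
+# options documented above are typically combined; 'example.parquet' and the
+# table below are hypothetical.
+#
+#     import pyarrow as pa
+#     import pyarrow.parquet as pq
+#     table = pa.table({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
+#     pq.write_table(table, 'example.parquet',
+#                    version='2.6',
+#                    compression={'a': 'zstd', 'b': 'snappy'},
+#                    use_dictionary=['b'])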
+
+class ParquetWriter:
+
+ __doc__ = """
+Class for incrementally building a Parquet file for Arrow tables.
+
+Parameters
+----------
+where : path or file-like object
+schema : arrow Schema
+{}
+writer_engine_version : unused
+**options : dict
+ If options contains a key `metadata_collector` then the
+ corresponding value is assumed to be a list (or any object with
+ `.append` method) that will be filled with the file metadata instance
+ of the written file.
+""".format(_parquet_writer_arg_docs)
+
+ def __init__(self, where, schema, filesystem=None,
+ flavor=None,
+ version='1.0',
+ use_dictionary=True,
+ compression='snappy',
+ write_statistics=True,
+ use_deprecated_int96_timestamps=None,
+ compression_level=None,
+ use_byte_stream_split=False,
+ writer_engine_version=None,
+ data_page_version='1.0',
+ use_compliant_nested_type=False,
+ **options):
+ if use_deprecated_int96_timestamps is None:
+ # Use int96 timestamps for Spark
+ if flavor is not None and 'spark' in flavor:
+ use_deprecated_int96_timestamps = True
+ else:
+ use_deprecated_int96_timestamps = False
+
+ self.flavor = flavor
+ if flavor is not None:
+ schema, self.schema_changed = _sanitize_schema(schema, flavor)
+ else:
+ self.schema_changed = False
+
+ self.schema = schema
+ self.where = where
+
+ # If we open a file using a filesystem, store file handle so we can be
+ # sure to close it when `self.close` is called.
+ self.file_handle = None
+
+ filesystem, path = _resolve_filesystem_and_path(
+ where, filesystem, allow_legacy_filesystem=True
+ )
+ if filesystem is not None:
+ if isinstance(filesystem, legacyfs.FileSystem):
+ # legacy filesystem (eg custom subclass)
+ # TODO deprecate
+ sink = self.file_handle = filesystem.open(path, 'wb')
+ else:
+ # ARROW-10480: do not auto-detect compression. While
+ # a filename like foo.parquet.gz is nonconforming, it
+ # shouldn't implicitly apply compression.
+ sink = self.file_handle = filesystem.open_output_stream(
+ path, compression=None)
+ else:
+ sink = where
+ self._metadata_collector = options.pop('metadata_collector', None)
+ engine_version = 'V2'
+ self.writer = _parquet.ParquetWriter(
+ sink, schema,
+ version=version,
+ compression=compression,
+ use_dictionary=use_dictionary,
+ write_statistics=write_statistics,
+ use_deprecated_int96_timestamps=use_deprecated_int96_timestamps,
+ compression_level=compression_level,
+ use_byte_stream_split=use_byte_stream_split,
+ writer_engine_version=engine_version,
+ data_page_version=data_page_version,
+ use_compliant_nested_type=use_compliant_nested_type,
+ **options)
+ self.is_open = True
+
+ def __del__(self):
+ if getattr(self, 'is_open', False):
+ self.close()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ self.close()
+ # return false since we want to propagate exceptions
+ return False
+
+ def write_table(self, table, row_group_size=None):
+ if self.schema_changed:
+ table = _sanitize_table(table, self.schema, self.flavor)
+ assert self.is_open
+
+ if not table.schema.equals(self.schema, check_metadata=False):
+ msg = ('Table schema does not match schema used to create file: '
+ '\ntable:\n{!s} vs. \nfile:\n{!s}'
+ .format(table.schema, self.schema))
+ raise ValueError(msg)
+
+ self.writer.write_table(table, row_group_size=row_group_size)
+
+ def close(self):
+ if self.is_open:
+ self.writer.close()
+ self.is_open = False
+ if self._metadata_collector is not None:
+ self._metadata_collector.append(self.writer.metadata)
+ if self.file_handle is not None:
+ self.file_handle.close()
+
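+# Editor's note: a minimal sketch (not part of the module) of incremental
+# writing with ParquetWriter; the path and batches are hypothetical.
+#
+#     import pyarrow as pa
+#     import pyarrow.parquet as pq
+#     schema = pa.schema([('x', pa.int64())])
+#     with pq.ParquetWriter('incremental.parquet', schema) as writer:
+#         for start in (0, 5):
+#             batch = pa.table({'x': list(range(start, start + 5))})
+#             writer.write_table(batch)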
+
+def _get_pandas_index_columns(keyvalues):
+ return (json.loads(keyvalues[b'pandas'].decode('utf8'))
+ ['index_columns'])
+
+
+# ----------------------------------------------------------------------
+# Metadata container providing instructions about reading a single Parquet
+# file, possibly part of a partitioned dataset
+
+
+class ParquetDatasetPiece:
+ """
+ DEPRECATED: A single chunk of a potentially larger Parquet dataset to read.
+
+    The arguments indicate whether to read a single row group or all row
+    groups, and whether to add partition keys to the resulting pyarrow.Table.
+
+ .. deprecated:: 5.0
+ Directly constructing a ``ParquetDatasetPiece`` is deprecated, as well
+ as accessing the pieces of a ``ParquetDataset`` object. Specify
+ ``use_legacy_dataset=False`` when constructing the ``ParquetDataset``
+ and use the ``ParquetDataset.fragments`` attribute instead.
+
+ Parameters
+ ----------
+ path : str or pathlib.Path
+ Path to file in the file system where this piece is located.
+ open_file_func : callable
+ Function to use for obtaining file handle to dataset piece.
+ partition_keys : list of tuples
+ Two-element tuples of ``(column name, ordinal index)``.
+ row_group : int, default None
+ Row group to load. By default, reads all row groups.
+    file_options : dict
+        Extra options passed to ParquetFile when the piece is opened.
+ """
+
+ def __init__(self, path, open_file_func=partial(open, mode='rb'),
+ file_options=None, row_group=None, partition_keys=None):
+ warnings.warn(
+ "ParquetDatasetPiece is deprecated as of pyarrow 5.0.0 and will "
+ "be removed in a future version.",
+ DeprecationWarning, stacklevel=2)
+ self._init(
+ path, open_file_func, file_options, row_group, partition_keys)
+
+ @staticmethod
+ def _create(path, open_file_func=partial(open, mode='rb'),
+ file_options=None, row_group=None, partition_keys=None):
+ self = ParquetDatasetPiece.__new__(ParquetDatasetPiece)
+ self._init(
+ path, open_file_func, file_options, row_group, partition_keys)
+ return self
+
+ def _init(self, path, open_file_func, file_options, row_group,
+ partition_keys):
+ self.path = _stringify_path(path)
+ self.open_file_func = open_file_func
+ self.row_group = row_group
+ self.partition_keys = partition_keys or []
+ self.file_options = file_options or {}
+
+ def __eq__(self, other):
+ if not isinstance(other, ParquetDatasetPiece):
+ return False
+ return (self.path == other.path and
+ self.row_group == other.row_group and
+ self.partition_keys == other.partition_keys)
+
+ def __repr__(self):
+ return ('{}({!r}, row_group={!r}, partition_keys={!r})'
+ .format(type(self).__name__, self.path,
+ self.row_group,
+ self.partition_keys))
+
+ def __str__(self):
+ result = ''
+
+ if len(self.partition_keys) > 0:
+ partition_str = ', '.join('{}={}'.format(name, index)
+ for name, index in self.partition_keys)
+ result += 'partition[{}] '.format(partition_str)
+
+ result += self.path
+
+ if self.row_group is not None:
+ result += ' | row_group={}'.format(self.row_group)
+
+ return result
+
+ def get_metadata(self):
+ """
+ Return the file's metadata.
+
+ Returns
+ -------
+ metadata : FileMetaData
+ """
+ f = self.open()
+ return f.metadata
+
+ def open(self):
+ """
+ Return instance of ParquetFile.
+ """
+ reader = self.open_file_func(self.path)
+ if not isinstance(reader, ParquetFile):
+ reader = ParquetFile(reader, **self.file_options)
+ return reader
+
+ def read(self, columns=None, use_threads=True, partitions=None,
+ file=None, use_pandas_metadata=False):
+ """
+ Read this piece as a pyarrow.Table.
+
+ Parameters
+ ----------
+ columns : list of column names, default None
+ use_threads : bool, default True
+ Perform multi-threaded column reads.
+ partitions : ParquetPartitions, default None
+ file : file-like object
+ Passed to ParquetFile.
+ use_pandas_metadata : bool
+            Whether pandas metadata should be used.
+
+ Returns
+ -------
+ table : pyarrow.Table
+ """
+ if self.open_file_func is not None:
+ reader = self.open()
+ elif file is not None:
+ reader = ParquetFile(file, **self.file_options)
+ else:
+ # try to read the local path
+ reader = ParquetFile(self.path, **self.file_options)
+
+ options = dict(columns=columns,
+ use_threads=use_threads,
+ use_pandas_metadata=use_pandas_metadata)
+
+ if self.row_group is not None:
+ table = reader.read_row_group(self.row_group, **options)
+ else:
+ table = reader.read(**options)
+
+ if len(self.partition_keys) > 0:
+ if partitions is None:
+ raise ValueError('Must pass partition sets')
+
+ # Here, the index is the categorical code of the partition where
+ # this piece is located. Suppose we had
+ #
+ # /foo=a/0.parq
+ # /foo=b/0.parq
+ # /foo=c/0.parq
+ #
+ # Then we assign a=0, b=1, c=2. And the resulting Table pieces will
+ # have a DictionaryArray column named foo having the constant index
+ # value as indicated. The distinct categories of the partition have
+ # been computed in the ParquetManifest
+ for i, (name, index) in enumerate(self.partition_keys):
+ # The partition code is the same for all values in this piece
+ indices = np.full(len(table), index, dtype='i4')
+
+ # This is set of all partition values, computed as part of the
+ # manifest, so ['a', 'b', 'c'] as in our example above.
+ dictionary = partitions.levels[i].dictionary
+
+ arr = pa.DictionaryArray.from_arrays(indices, dictionary)
+ table = table.append_column(name, arr)
+
+ return table
+
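+# Editor's note: a hedged sketch (not part of the module) of the dictionary
+# encoding applied to partition columns in ParquetDatasetPiece.read above;
+# the values are hypothetical.
+#
+#     import numpy as np
+#     import pyarrow as pa
+#     indices = np.full(4, 1, dtype='i4')       # all rows come from foo=b
+#     dictionary = pa.array(['a', 'b', 'c'])    # distinct partition values
+#     arr = pa.DictionaryArray.from_arrays(indices, dictionary)
+#     # arr is a dictionary-encoded column with the constant value 'b'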
+
+class PartitionSet:
+ """
+ A data structure for cataloguing the observed Parquet partitions at a
+ particular level. So if we have
+
+ /foo=a/bar=0
+ /foo=a/bar=1
+ /foo=a/bar=2
+ /foo=b/bar=0
+ /foo=b/bar=1
+ /foo=b/bar=2
+
+ Then we have two partition sets, one for foo, another for bar. As we visit
+ levels of the partition hierarchy, a PartitionSet tracks the distinct
+    values and assigns categorical codes to use when reading the pieces.
+
+ Parameters
+ ----------
+ name : str
+ Name of the partition set. Under which key to collect all values.
+ keys : list
+ All possible values that have been collected for that partition set.
+ """
+
+ def __init__(self, name, keys=None):
+ self.name = name
+ self.keys = keys or []
+ self.key_indices = {k: i for i, k in enumerate(self.keys)}
+ self._dictionary = None
+
+ def get_index(self, key):
+ """
+ Get the index of the partition value if it is known, otherwise assign
+ one
+
+ Parameters
+ ----------
+        key : The value for which we want to know the index.
+ """
+ if key in self.key_indices:
+ return self.key_indices[key]
+ else:
+ index = len(self.key_indices)
+ self.keys.append(key)
+ self.key_indices[key] = index
+ return index
+
+ @property
+ def dictionary(self):
+ if self._dictionary is not None:
+ return self._dictionary
+
+ if len(self.keys) == 0:
+ raise ValueError('No known partition keys')
+
+ # Only integer and string partition types are supported right now
+ try:
+ integer_keys = [int(x) for x in self.keys]
+ dictionary = lib.array(integer_keys)
+ except ValueError:
+ dictionary = lib.array(self.keys)
+
+ self._dictionary = dictionary
+ return dictionary
+
+ @property
+ def is_sorted(self):
+ return list(self.keys) == sorted(self.keys)
+
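+# Editor's note: a small hedged sketch (not part of the module) of how
+# PartitionSet assigns categorical codes and builds its dictionary; the
+# keys below are hypothetical.
+#
+#     ps = PartitionSet('year')
+#     assert ps.get_index('2009') == 0
+#     assert ps.get_index('2010') == 1
+#     assert ps.get_index('2009') == 0
+#     # string keys that parse as integers are stored as an integer array
+#     assert ps.dictionary.to_pylist() == [2009, 2010]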
+
+class ParquetPartitions:
+
+ def __init__(self):
+ self.levels = []
+ self.partition_names = set()
+
+ def __len__(self):
+ return len(self.levels)
+
+ def __getitem__(self, i):
+ return self.levels[i]
+
+ def equals(self, other):
+ if not isinstance(other, ParquetPartitions):
+ raise TypeError('`other` must be an instance of ParquetPartitions')
+
+ return (self.levels == other.levels and
+ self.partition_names == other.partition_names)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def get_index(self, level, name, key):
+ """
+ Record a partition value at a particular level, returning the distinct
+ code for that value at that level.
+
+ Example:
+
+ partitions.get_index(1, 'foo', 'a') returns 0
+ partitions.get_index(1, 'foo', 'b') returns 1
+ partitions.get_index(1, 'foo', 'c') returns 2
+ partitions.get_index(1, 'foo', 'a') returns 0
+
+ Parameters
+ ----------
+ level : int
+ The nesting level of the partition we are observing
+ name : str
+ The partition name
+ key : str or int
+ The partition value
+ """
+ if level == len(self.levels):
+ if name in self.partition_names:
+ raise ValueError('{} was the name of the partition in '
+ 'another level'.format(name))
+
+ part_set = PartitionSet(name)
+ self.levels.append(part_set)
+ self.partition_names.add(name)
+
+ return self.levels[level].get_index(key)
+
+ def filter_accepts_partition(self, part_key, filter, level):
+ p_column, p_value_index = part_key
+ f_column, op, f_value = filter
+ if p_column != f_column:
+ return True
+
+ f_type = type(f_value)
+
+ if op in {'in', 'not in'}:
+ if not isinstance(f_value, Collection):
+                raise TypeError(
+                    "'{}' object is not a collection"
+                    .format(f_type.__name__))
+ if not f_value:
+ raise ValueError("Cannot use empty collection as filter value")
+ if len({type(item) for item in f_value}) != 1:
+                raise ValueError(
+                    "All elements of the collection '{}' must be "
+                    "of the same type".format(f_value))
+ f_type = type(next(iter(f_value)))
+
+ elif not isinstance(f_value, str) and isinstance(f_value, Collection):
+            raise ValueError(
+                "Op '{}' not supported with a collection value".format(op))
+
+ p_value = f_type(self.levels[level]
+ .dictionary[p_value_index].as_py())
+
+ if op == "=" or op == "==":
+ return p_value == f_value
+ elif op == "!=":
+ return p_value != f_value
+ elif op == '<':
+ return p_value < f_value
+ elif op == '>':
+ return p_value > f_value
+ elif op == '<=':
+ return p_value <= f_value
+ elif op == '>=':
+ return p_value >= f_value
+ elif op == 'in':
+ return p_value in f_value
+ elif op == 'not in':
+ return p_value not in f_value
+ else:
+            raise ValueError(
+                "'{}' is not a valid operator in predicates."
+                .format(filter[1]))
+
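+# Editor's note: a hedged sketch (not part of the module) of the filter
+# tuple format evaluated by ParquetPartitions.filter_accepts_partition
+# above; the column names and values are hypothetical. Filters are given
+# in disjunctive normal form: a list of AND-ed tuple groups, OR-ed together.
+#
+#     filters = [
+#         [('year', '=', 2020), ('month', 'in', {11, 12})],  # AND group
+#         [('year', '>', 2020)],                             # OR'd group
+#     ]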
+
+class ParquetManifest:
+
+ def __init__(self, dirpath, open_file_func=None, filesystem=None,
+ pathsep='/', partition_scheme='hive', metadata_nthreads=1):
+ filesystem, dirpath = _get_filesystem_and_path(filesystem, dirpath)
+ self.filesystem = filesystem
+ self.open_file_func = open_file_func
+ self.pathsep = pathsep
+ self.dirpath = _stringify_path(dirpath)
+ self.partition_scheme = partition_scheme
+ self.partitions = ParquetPartitions()
+ self.pieces = []
+ self._metadata_nthreads = metadata_nthreads
+ self._thread_pool = futures.ThreadPoolExecutor(
+ max_workers=metadata_nthreads)
+
+ self.common_metadata_path = None
+ self.metadata_path = None
+
+ self._visit_level(0, self.dirpath, [])
+
+        # Due to concurrency, pieces will potentially be out of order if the
+ # dataset is partitioned so we sort them to yield stable results
+ self.pieces.sort(key=lambda piece: piece.path)
+
+ if self.common_metadata_path is None:
+ # _common_metadata is a subset of _metadata
+ self.common_metadata_path = self.metadata_path
+
+ self._thread_pool.shutdown()
+
+ def _visit_level(self, level, base_path, part_keys):
+ fs = self.filesystem
+
+ _, directories, files = next(fs.walk(base_path))
+
+ filtered_files = []
+ for path in files:
+ full_path = self.pathsep.join((base_path, path))
+ if path.endswith('_common_metadata'):
+ self.common_metadata_path = full_path
+ elif path.endswith('_metadata'):
+ self.metadata_path = full_path
+ elif self._should_silently_exclude(path):
+ continue
+ else:
+ filtered_files.append(full_path)
+
+ # ARROW-1079: Filter out "private" directories starting with underscore
+ filtered_directories = [self.pathsep.join((base_path, x))
+ for x in directories
+ if not _is_private_directory(x)]
+
+ filtered_files.sort()
+ filtered_directories.sort()
+
+ if len(filtered_files) > 0 and len(filtered_directories) > 0:
+ raise ValueError('Found files in an intermediate '
+ 'directory: {}'.format(base_path))
+ elif len(filtered_directories) > 0:
+ self._visit_directories(level, filtered_directories, part_keys)
+ else:
+ self._push_pieces(filtered_files, part_keys)
+
+ def _should_silently_exclude(self, file_name):
+ return (file_name.endswith('.crc') or # Checksums
+ file_name.endswith('_$folder$') or # HDFS directories in S3
+ file_name.startswith('.') or # Hidden files starting with .
+ file_name.startswith('_') or # Hidden files starting with _
+ file_name in EXCLUDED_PARQUET_PATHS)
+
+ def _visit_directories(self, level, directories, part_keys):
+ futures_list = []
+ for path in directories:
+ head, tail = _path_split(path, self.pathsep)
+ name, key = _parse_hive_partition(tail)
+
+ index = self.partitions.get_index(level, name, key)
+ dir_part_keys = part_keys + [(name, index)]
+            # If you have fewer threads than levels, the wait call will block
+ # indefinitely due to multiple waits within a thread.
+ if level < self._metadata_nthreads:
+ future = self._thread_pool.submit(self._visit_level,
+ level + 1,
+ path,
+ dir_part_keys)
+ futures_list.append(future)
+ else:
+ self._visit_level(level + 1, path, dir_part_keys)
+ if futures_list:
+ futures.wait(futures_list)
+
+ def _parse_partition(self, dirname):
+ if self.partition_scheme == 'hive':
+ return _parse_hive_partition(dirname)
+ else:
+            raise NotImplementedError('partition scheme: {}'
+ .format(self.partition_scheme))
+
+ def _push_pieces(self, files, part_keys):
+ self.pieces.extend([
+ ParquetDatasetPiece._create(path, partition_keys=part_keys,
+ open_file_func=self.open_file_func)
+ for path in files
+ ])
+
+
+def _parse_hive_partition(value):
+ if '=' not in value:
+ raise ValueError('Directory name did not appear to be a '
+ 'partition: {}'.format(value))
+ return value.split('=', 1)
+
+
+def _is_private_directory(x):
+ _, tail = os.path.split(x)
+ return (tail.startswith('_') or tail.startswith('.')) and '=' not in tail
+
+
+def _path_split(path, sep):
+ i = path.rfind(sep) + 1
+ head, tail = path[:i], path[i:]
+ head = head.rstrip(sep)
+ return head, tail
+
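+# Editor's note: a tiny hedged sketch (not part of the module) of the
+# hive-style path helpers above; the path is hypothetical.
+#
+#     assert _parse_hive_partition('year=2009') == ['year', '2009']
+#     assert _path_split('/data/year=2009', '/') == ('/data', 'year=2009')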
+
+EXCLUDED_PARQUET_PATHS = {'_SUCCESS'}
+
+
+class _ParquetDatasetMetadata:
+ __slots__ = ('fs', 'memory_map', 'read_dictionary', 'common_metadata',
+ 'buffer_size')
+
+
+def _open_dataset_file(dataset, path, meta=None):
+ if (dataset.fs is not None and
+ not isinstance(dataset.fs, legacyfs.LocalFileSystem)):
+ path = dataset.fs.open(path, mode='rb')
+ return ParquetFile(
+ path,
+ metadata=meta,
+ memory_map=dataset.memory_map,
+ read_dictionary=dataset.read_dictionary,
+ common_metadata=dataset.common_metadata,
+ buffer_size=dataset.buffer_size
+ )
+
+
+_DEPR_MSG = (
+ "'{}' attribute is deprecated as of pyarrow 5.0.0 and will be removed "
+ "in a future version.{}"
+)
+
+
+_read_docstring_common = """\
+read_dictionary : list, default None
+ List of names or column paths (for nested types) to read directly
+ as DictionaryArray. Only supported for BYTE_ARRAY storage. To read
+ a flat column as dictionary-encoded pass the column name. For
+ nested types, you must pass the full column "path", which could be
+ something like level1.level2.list.item. Refer to the Parquet
+ file's schema to obtain the paths.
+memory_map : bool, default False
+ If the source is a file path, use a memory map to read file, which can
+ improve performance in some environments.
+buffer_size : int, default 0
+ If positive, perform read buffering when deserializing individual
+ column chunks. Otherwise IO calls are unbuffered.
+partitioning : Partitioning or str or list of str, default "hive"
+ The partitioning scheme for a partitioned dataset. The default of "hive"
+ assumes directory names with key=value pairs like "/year=2009/month=11".
+ In addition, a scheme like "/2009/11" is also supported, in which case
+ you need to specify the field names or a full schema. See the
+ ``pyarrow.dataset.partitioning()`` function for more details."""
+
+
+class ParquetDataset:
+
+ __doc__ = """
+Encapsulates details of reading a complete Parquet dataset possibly
+consisting of multiple files and partitions in subdirectories.
+
+Parameters
+----------
+path_or_paths : str or List[str]
+ A directory name, single file name, or list of file names.
+filesystem : FileSystem, default None
+ If nothing passed, paths assumed to be found in the local on-disk
+ filesystem.
+metadata : pyarrow.parquet.FileMetaData
+ Use metadata obtained elsewhere to validate file schemas.
+schema : pyarrow.parquet.Schema
+ Use schema obtained elsewhere to validate file schemas. Alternative to
+ metadata parameter.
+split_row_groups : bool, default False
+ Divide files into pieces for each row group in the file.
+validate_schema : bool, default True
+ Check that individual file schemas are all the same / compatible.
+filters : List[Tuple] or List[List[Tuple]] or None (default)
+ Rows which do not match the filter predicate will be removed from scanned
+ data. Partition keys embedded in a nested directory structure will be
+ exploited to avoid loading files at all if they contain no matching rows.
+ If `use_legacy_dataset` is True, filters can only reference partition
+ keys and only a hive-style directory structure is supported. When
+ setting `use_legacy_dataset` to False, also within-file level filtering
+ and different partitioning schemes are supported.
+
+ {1}
+metadata_nthreads : int, default 1
+    The number of threads in the thread pool used to read the dataset
+    metadata. Increasing this can help when reading partitioned datasets.
+{0}
+use_legacy_dataset : bool, default True
+ Set to False to enable the new code path (experimental, using the
+    new Arrow Dataset API). Among other things, this allows passing
+    `filters` for all columns and not only the partition keys, and enables
+ different partitioning schemes, etc.
+pre_buffer : bool, default True
+ Coalesce and issue file reads in parallel to improve performance on
+ high-latency filesystems (e.g. S3). If True, Arrow will use a
+ background I/O thread pool. This option is only supported for
+ use_legacy_dataset=False. If using a filesystem layer that itself
+ performs readahead (e.g. fsspec's S3FS), disable readahead for best
+ results.
+coerce_int96_timestamp_unit : str, default None.
+ Cast timestamps that are stored in INT96 format to a particular resolution
+ (e.g. 'ms'). Setting to None is equivalent to 'ns' and therefore INT96
+    timestamps will be inferred as timestamps in nanoseconds.
+""".format(_read_docstring_common, _DNF_filter_doc)
+
+ def __new__(cls, path_or_paths=None, filesystem=None, schema=None,
+ metadata=None, split_row_groups=False, validate_schema=True,
+ filters=None, metadata_nthreads=1, read_dictionary=None,
+ memory_map=False, buffer_size=0, partitioning="hive",
+ use_legacy_dataset=None, pre_buffer=True,
+ coerce_int96_timestamp_unit=None):
+ if use_legacy_dataset is None:
+ # if a new filesystem is passed -> default to new implementation
+ if isinstance(filesystem, FileSystem):
+ use_legacy_dataset = False
+ # otherwise the default is still True
+ else:
+ use_legacy_dataset = True
+
+ if not use_legacy_dataset:
+ return _ParquetDatasetV2(
+ path_or_paths, filesystem=filesystem,
+ filters=filters,
+ partitioning=partitioning,
+ read_dictionary=read_dictionary,
+ memory_map=memory_map,
+ buffer_size=buffer_size,
+ pre_buffer=pre_buffer,
+ coerce_int96_timestamp_unit=coerce_int96_timestamp_unit,
+ # unsupported keywords
+ schema=schema, metadata=metadata,
+ split_row_groups=split_row_groups,
+ validate_schema=validate_schema,
+ metadata_nthreads=metadata_nthreads
+ )
+ self = object.__new__(cls)
+ return self
+
+ def __init__(self, path_or_paths, filesystem=None, schema=None,
+ metadata=None, split_row_groups=False, validate_schema=True,
+ filters=None, metadata_nthreads=1, read_dictionary=None,
+ memory_map=False, buffer_size=0, partitioning="hive",
+ use_legacy_dataset=True, pre_buffer=True,
+ coerce_int96_timestamp_unit=None):
+ if partitioning != "hive":
+ raise ValueError(
+ 'Only "hive" for hive-like partitioning is supported when '
+ 'using use_legacy_dataset=True')
+ self._metadata = _ParquetDatasetMetadata()
+ a_path = path_or_paths
+ if isinstance(a_path, list):
+ a_path = a_path[0]
+
+ self._metadata.fs, _ = _get_filesystem_and_path(filesystem, a_path)
+ if isinstance(path_or_paths, list):
+ self.paths = [_parse_uri(path) for path in path_or_paths]
+ else:
+ self.paths = _parse_uri(path_or_paths)
+
+ self._metadata.read_dictionary = read_dictionary
+ self._metadata.memory_map = memory_map
+ self._metadata.buffer_size = buffer_size
+
+ (self._pieces,
+ self._partitions,
+ self.common_metadata_path,
+ self.metadata_path) = _make_manifest(
+ path_or_paths, self._fs, metadata_nthreads=metadata_nthreads,
+ open_file_func=partial(_open_dataset_file, self._metadata)
+ )
+
+ if self.common_metadata_path is not None:
+ with self._fs.open(self.common_metadata_path) as f:
+ self._metadata.common_metadata = read_metadata(
+ f,
+ memory_map=memory_map
+ )
+ else:
+ self._metadata.common_metadata = None
+
+ if metadata is None and self.metadata_path is not None:
+ with self._fs.open(self.metadata_path) as f:
+ self.metadata = read_metadata(f, memory_map=memory_map)
+ else:
+ self.metadata = metadata
+
+ self.schema = schema
+
+ self.split_row_groups = split_row_groups
+
+ if split_row_groups:
+ raise NotImplementedError("split_row_groups not yet implemented")
+
+ if filters is not None:
+ filters = _check_filters(filters)
+ self._filter(filters)
+
+ if validate_schema:
+ self.validate_schemas()
+
+ def equals(self, other):
+ if not isinstance(other, ParquetDataset):
+ raise TypeError('`other` must be an instance of ParquetDataset')
+
+ if self._fs.__class__ != other._fs.__class__:
+ return False
+ for prop in ('paths', '_pieces', '_partitions',
+ 'common_metadata_path', 'metadata_path',
+ 'common_metadata', 'metadata', 'schema',
+ 'split_row_groups'):
+ if getattr(self, prop) != getattr(other, prop):
+ return False
+ for prop in ('memory_map', 'buffer_size'):
+ if getattr(self._metadata, prop) != getattr(other._metadata, prop):
+ return False
+
+ return True
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def validate_schemas(self):
+ if self.metadata is None and self.schema is None:
+ if self.common_metadata is not None:
+ self.schema = self.common_metadata.schema
+ else:
+ self.schema = self._pieces[0].get_metadata().schema
+ elif self.schema is None:
+ self.schema = self.metadata.schema
+
+ # Verify schemas are all compatible
+ dataset_schema = self.schema.to_arrow_schema()
+ # Exclude the partition columns from the schema, they are provided
+ # by the path, not the DatasetPiece
+ if self._partitions is not None:
+ for partition_name in self._partitions.partition_names:
+ if dataset_schema.get_field_index(partition_name) != -1:
+ field_idx = dataset_schema.get_field_index(partition_name)
+ dataset_schema = dataset_schema.remove(field_idx)
+
+ for piece in self._pieces:
+ file_metadata = piece.get_metadata()
+ file_schema = file_metadata.schema.to_arrow_schema()
+ if not dataset_schema.equals(file_schema, check_metadata=False):
+ raise ValueError('Schema in {!s} was different. \n'
+ '{!s}\n\nvs\n\n{!s}'
+ .format(piece, file_schema,
+ dataset_schema))
+
+ def read(self, columns=None, use_threads=True, use_pandas_metadata=False):
+ """
+ Read multiple Parquet files as a single pyarrow.Table.
+
+ Parameters
+ ----------
+ columns : List[str]
+            Names of columns to read from the dataset.
+ use_threads : bool, default True
+ Perform multi-threaded column reads
+ use_pandas_metadata : bool, default False
+ Passed through to each dataset piece.
+
+ Returns
+ -------
+ pyarrow.Table
+ Content of the file as a table (of columns).
+ """
+ tables = []
+ for piece in self._pieces:
+ table = piece.read(columns=columns, use_threads=use_threads,
+ partitions=self._partitions,
+ use_pandas_metadata=use_pandas_metadata)
+ tables.append(table)
+
+ all_data = lib.concat_tables(tables)
+
+ if use_pandas_metadata:
+ # We need to ensure that this metadata is set in the Table's schema
+ # so that Table.to_pandas will construct pandas.DataFrame with the
+ # right index
+ common_metadata = self._get_common_pandas_metadata()
+ current_metadata = all_data.schema.metadata or {}
+
+ if common_metadata and b'pandas' not in current_metadata:
+ all_data = all_data.replace_schema_metadata({
+ b'pandas': common_metadata})
+
+ return all_data
+
+ def read_pandas(self, **kwargs):
+ """
+ Read dataset including pandas metadata, if any. Other arguments passed
+ through to ParquetDataset.read, see docstring for further details.
+
+ Parameters
+ ----------
+ **kwargs : optional
+ All additional options to pass to the reader.
+
+ Returns
+ -------
+ pyarrow.Table
+ Content of the file as a table (of columns).
+ """
+ return self.read(use_pandas_metadata=True, **kwargs)
+
+ def _get_common_pandas_metadata(self):
+ if self.common_metadata is None:
+ return None
+
+ keyvalues = self.common_metadata.metadata
+ return keyvalues.get(b'pandas', None)
+
+ def _filter(self, filters):
+ accepts_filter = self._partitions.filter_accepts_partition
+
+ def one_filter_accepts(piece, filter):
+ return all(accepts_filter(part_key, filter, level)
+ for level, part_key in enumerate(piece.partition_keys))
+
+ def all_filters_accept(piece):
+ return any(all(one_filter_accepts(piece, f) for f in conjunction)
+ for conjunction in filters)
+
+ self._pieces = [p for p in self._pieces if all_filters_accept(p)]
+
+ @property
+ def pieces(self):
+ warnings.warn(
+ _DEPR_MSG.format(
+ "ParquetDataset.pieces",
+ " Specify 'use_legacy_dataset=False' while constructing the "
+ "ParquetDataset, and then use the '.fragments' attribute "
+ "instead."),
+ DeprecationWarning, stacklevel=2)
+ return self._pieces
+
+ @property
+ def partitions(self):
+ warnings.warn(
+ _DEPR_MSG.format(
+ "ParquetDataset.partitions",
+ " Specify 'use_legacy_dataset=False' while constructing the "
+ "ParquetDataset, and then use the '.partitioning' attribute "
+ "instead."),
+ DeprecationWarning, stacklevel=2)
+ return self._partitions
+
+ @property
+ def memory_map(self):
+ warnings.warn(
+ _DEPR_MSG.format("ParquetDataset.memory_map", ""),
+ DeprecationWarning, stacklevel=2)
+ return self._metadata.memory_map
+
+ @property
+ def read_dictionary(self):
+ warnings.warn(
+ _DEPR_MSG.format("ParquetDataset.read_dictionary", ""),
+ DeprecationWarning, stacklevel=2)
+ return self._metadata.read_dictionary
+
+ @property
+ def buffer_size(self):
+ warnings.warn(
+ _DEPR_MSG.format("ParquetDataset.buffer_size", ""),
+ DeprecationWarning, stacklevel=2)
+ return self._metadata.buffer_size
+
+ _fs = property(
+ operator.attrgetter('_metadata.fs')
+ )
+
+ @property
+ def fs(self):
+ warnings.warn(
+ _DEPR_MSG.format(
+ "ParquetDataset.fs",
+ " Specify 'use_legacy_dataset=False' while constructing the "
+ "ParquetDataset, and then use the '.filesystem' attribute "
+ "instead."),
+ DeprecationWarning, stacklevel=2)
+ return self._metadata.fs
+
+ common_metadata = property(
+ operator.attrgetter('_metadata.common_metadata')
+ )
+
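+# Editor's note: a hedged sketch (not part of the module) of reading a
+# hive-partitioned directory with the legacy ParquetDataset; the root path
+# and partition column are hypothetical.
+#
+#     import pyarrow.parquet as pq
+#     dataset = pq.ParquetDataset('dataset_root/',
+#                                 filters=[('year', '>=', 2019)],
+#                                 use_legacy_dataset=True)
+#     table = dataset.read(columns=['x'])   # partition column is appended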
+
+def _make_manifest(path_or_paths, fs, pathsep='/', metadata_nthreads=1,
+ open_file_func=None):
+ partitions = None
+ common_metadata_path = None
+ metadata_path = None
+
+ if isinstance(path_or_paths, list) and len(path_or_paths) == 1:
+ # Dask passes a directory as a list of length 1
+ path_or_paths = path_or_paths[0]
+
+ if _is_path_like(path_or_paths) and fs.isdir(path_or_paths):
+ manifest = ParquetManifest(path_or_paths, filesystem=fs,
+ open_file_func=open_file_func,
+ pathsep=getattr(fs, "pathsep", "/"),
+ metadata_nthreads=metadata_nthreads)
+ common_metadata_path = manifest.common_metadata_path
+ metadata_path = manifest.metadata_path
+ pieces = manifest.pieces
+ partitions = manifest.partitions
+ else:
+ if not isinstance(path_or_paths, list):
+ path_or_paths = [path_or_paths]
+
+ # List of paths
+ if len(path_or_paths) == 0:
+ raise ValueError('Must pass at least one file path')
+
+ pieces = []
+ for path in path_or_paths:
+ if not fs.isfile(path):
+ raise OSError('Passed non-file path: {}'
+ .format(path))
+ piece = ParquetDatasetPiece._create(
+ path, open_file_func=open_file_func)
+ pieces.append(piece)
+
+ return pieces, partitions, common_metadata_path, metadata_path
+
+
+def _is_local_file_system(fs):
+ return isinstance(fs, LocalFileSystem) or isinstance(
+ fs, legacyfs.LocalFileSystem
+ )
+
+
+class _ParquetDatasetV2:
+ """
+ ParquetDataset shim using the Dataset API under the hood.
+ """
+
+ def __init__(self, path_or_paths, filesystem=None, filters=None,
+ partitioning="hive", read_dictionary=None, buffer_size=None,
+ memory_map=False, ignore_prefixes=None, pre_buffer=True,
+ coerce_int96_timestamp_unit=None, **kwargs):
+ import pyarrow.dataset as ds
+
+        # Raise error for unsupported keywords
+ for keyword, default in [
+ ("schema", None), ("metadata", None),
+ ("split_row_groups", False), ("validate_schema", True),
+ ("metadata_nthreads", 1)]:
+ if keyword in kwargs and kwargs[keyword] is not default:
+ raise ValueError(
+ "Keyword '{0}' is not yet supported with the new "
+ "Dataset API".format(keyword))
+
+ # map format arguments
+ read_options = {
+ "pre_buffer": pre_buffer,
+ "coerce_int96_timestamp_unit": coerce_int96_timestamp_unit
+ }
+ if buffer_size:
+ read_options.update(use_buffered_stream=True,
+ buffer_size=buffer_size)
+ if read_dictionary is not None:
+ read_options.update(dictionary_columns=read_dictionary)
+
+ # map filters to Expressions
+ self._filters = filters
+ self._filter_expression = filters and _filters_to_expression(filters)
+
+ # map old filesystems to new one
+ if filesystem is not None:
+ filesystem = _ensure_filesystem(
+ filesystem, use_mmap=memory_map)
+ elif filesystem is None and memory_map:
+ # if memory_map is specified, assume local file system (string
+ # path can in principle be URI for any filesystem)
+ filesystem = LocalFileSystem(use_mmap=memory_map)
+
+ # This needs to be checked after _ensure_filesystem, because that
+ # handles the case of an fsspec LocalFileSystem
+ if (
+ hasattr(path_or_paths, "__fspath__") and
+ filesystem is not None and
+ not _is_local_file_system(filesystem)
+ ):
+ raise TypeError(
+ "Path-like objects with __fspath__ must only be used with "
+ f"local file systems, not {type(filesystem)}"
+ )
+
+ # check for single fragment dataset
+ single_file = None
+ if isinstance(path_or_paths, list):
+ if len(path_or_paths) == 1:
+ single_file = path_or_paths[0]
+ else:
+ if _is_path_like(path_or_paths):
+ path_or_paths = _stringify_path(path_or_paths)
+ if filesystem is None:
+ # path might be a URI describing the FileSystem as well
+ try:
+ filesystem, path_or_paths = FileSystem.from_uri(
+ path_or_paths)
+ except ValueError:
+ filesystem = LocalFileSystem(use_mmap=memory_map)
+ if filesystem.get_file_info(path_or_paths).is_file:
+ single_file = path_or_paths
+ else:
+ single_file = path_or_paths
+
+ if single_file is not None:
+ self._enable_parallel_column_conversion = True
+ read_options.update(enable_parallel_column_conversion=True)
+
+ parquet_format = ds.ParquetFileFormat(**read_options)
+ fragment = parquet_format.make_fragment(single_file, filesystem)
+
+ self._dataset = ds.FileSystemDataset(
+ [fragment], schema=fragment.physical_schema,
+ format=parquet_format,
+ filesystem=fragment.filesystem
+ )
+ return
+ else:
+ self._enable_parallel_column_conversion = False
+
+ parquet_format = ds.ParquetFileFormat(**read_options)
+
+ # check partitioning to enable dictionary encoding
+ if partitioning == "hive":
+ partitioning = ds.HivePartitioning.discover(
+ infer_dictionary=True)
+
+ self._dataset = ds.dataset(path_or_paths, filesystem=filesystem,
+ format=parquet_format,
+ partitioning=partitioning,
+ ignore_prefixes=ignore_prefixes)
+
+ @property
+ def schema(self):
+ return self._dataset.schema
+
+ def read(self, columns=None, use_threads=True, use_pandas_metadata=False):
+ """
+ Read (multiple) Parquet files as a single pyarrow.Table.
+
+ Parameters
+ ----------
+ columns : List[str]
+ Names of columns to read from the dataset. The partition fields
+ are not automatically included (in contrast to when setting
+ ``use_legacy_dataset=True``).
+ use_threads : bool, default True
+ Perform multi-threaded column reads.
+ use_pandas_metadata : bool, default False
+ If True and file has custom pandas schema metadata, ensure that
+ index columns are also loaded.
+
+ Returns
+ -------
+ pyarrow.Table
+ Content of the file as a table (of columns).
+ """
+ # if use_pandas_metadata, we need to include index columns in the
+ # column selection, to be able to restore those in the pandas DataFrame
+ metadata = self.schema.metadata
+ if columns is not None and use_pandas_metadata:
+ if metadata and b'pandas' in metadata:
+ # RangeIndex can be represented as dict instead of column name
+ index_columns = [
+ col for col in _get_pandas_index_columns(metadata)
+ if not isinstance(col, dict)
+ ]
+ columns = (
+ list(columns) + list(set(index_columns) - set(columns))
+ )
+
+ if self._enable_parallel_column_conversion:
+ if use_threads:
+                # Per-column parallelism is already enabled at the file
+                # level; disable the outer thread pool to avoid contention.
+ use_threads = False
+
+ table = self._dataset.to_table(
+ columns=columns, filter=self._filter_expression,
+ use_threads=use_threads
+ )
+
+ # if use_pandas_metadata, restore the pandas metadata (which gets
+ # lost if doing a specific `columns` selection in to_table)
+ if use_pandas_metadata:
+ if metadata and b"pandas" in metadata:
+ new_metadata = table.schema.metadata or {}
+ new_metadata.update({b"pandas": metadata[b"pandas"]})
+ table = table.replace_schema_metadata(new_metadata)
+
+ return table
+
+ def read_pandas(self, **kwargs):
+ """
+ Read dataset including pandas metadata, if any. Other arguments passed
+ through to ParquetDataset.read, see docstring for further details.
+ """
+ return self.read(use_pandas_metadata=True, **kwargs)
+
+ @property
+ def pieces(self):
+ warnings.warn(
+ _DEPR_MSG.format("ParquetDataset.pieces",
+ " Use the '.fragments' attribute instead"),
+ DeprecationWarning, stacklevel=2)
+ return list(self._dataset.get_fragments())
+
+ @property
+ def fragments(self):
+ return list(self._dataset.get_fragments())
+
+ @property
+ def files(self):
+ return self._dataset.files
+
+ @property
+ def filesystem(self):
+ return self._dataset.filesystem
+
+ @property
+ def partitioning(self):
+ """
+ The partitioning of the Dataset source, if discovered.
+ """
+ return self._dataset.partitioning
+
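+# Editor's note: a hedged sketch (not part of the module) of the
+# dataset-backed code path wrapped by _ParquetDatasetV2; the root path and
+# columns are hypothetical.
+#
+#     import pyarrow.parquet as pq
+#     dataset = pq.ParquetDataset('dataset_root/',
+#                                 use_legacy_dataset=False,
+#                                 filters=[('year', '>=', 2019)])
+#     print(dataset.files)                  # discovered data files
+#     table = dataset.read(columns=['x'])   # partition fields not auto-added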
+
+_read_table_docstring = """
+{0}
+
+Parameters
+----------
+source : str, pyarrow.NativeFile, or file-like object
+    If a string is passed, it can be a single file name or directory name.
+    For file-like objects, only read a single file. Use pyarrow.BufferReader
+    to read a file contained in a bytes or buffer-like object.
+columns : list
+ If not None, only these columns will be read from the file. A column
+ name may be a prefix of a nested field, e.g. 'a' will select 'a.b',
+ 'a.c', and 'a.d.e'. If empty, no columns will be read. Note
+ that the table will still have the correct num_rows set despite having
+ no columns.
+use_threads : bool, default True
+ Perform multi-threaded column reads.
+metadata : FileMetaData
+    Use existing file metadata if it was separately computed.
+{1}
+use_legacy_dataset : bool, default False
+ By default, `read_table` uses the new Arrow Datasets API since
+ pyarrow 1.0.0. Among other things, this allows to pass `filters`
+ for all columns and not only the partition keys, enables
+ different partitioning schemes, etc.
+ Set to True to use the legacy behaviour.
+ignore_prefixes : list, optional
+ Files matching any of these prefixes will be ignored by the
+ discovery process if use_legacy_dataset=False.
+ This is matched to the basename of a path.
+ By default this is ['.', '_'].
+ Note that discovery happens only if a directory is passed as source.
+filesystem : FileSystem, default None
+ If nothing passed, paths assumed to be found in the local on-disk
+ filesystem.
+filters : List[Tuple] or List[List[Tuple]] or None (default)
+ Rows which do not match the filter predicate will be removed from scanned
+ data. Partition keys embedded in a nested directory structure will be
+ exploited to avoid loading files at all if they contain no matching rows.
+ If `use_legacy_dataset` is True, filters can only reference partition
+ keys and only a hive-style directory structure is supported. When
+ setting `use_legacy_dataset` to False, also within-file level filtering
+ and different partitioning schemes are supported.
+
+ {3}
+pre_buffer : bool, default True
+ Coalesce and issue file reads in parallel to improve performance on
+ high-latency filesystems (e.g. S3). If True, Arrow will use a
+ background I/O thread pool. This option is only supported for
+ use_legacy_dataset=False. If using a filesystem layer that itself
+ performs readahead (e.g. fsspec's S3FS), disable readahead for best
+ results.
+coerce_int96_timestamp_unit : str, default None.
+ Cast timestamps that are stored in INT96 format to a particular
+ resolution (e.g. 'ms'). Setting to None is equivalent to 'ns'
+    and therefore INT96 timestamps will be inferred as timestamps
+ in nanoseconds.
+
+Returns
+-------
+{2}
+"""
+
+
+def read_table(source, columns=None, use_threads=True, metadata=None,
+ use_pandas_metadata=False, memory_map=False,
+ read_dictionary=None, filesystem=None, filters=None,
+ buffer_size=0, partitioning="hive", use_legacy_dataset=False,
+ ignore_prefixes=None, pre_buffer=True,
+ coerce_int96_timestamp_unit=None):
+ if not use_legacy_dataset:
+ if metadata is not None:
+ raise ValueError(
+ "The 'metadata' keyword is no longer supported with the new "
+ "datasets-based implementation. Specify "
+ "'use_legacy_dataset=True' to temporarily recover the old "
+ "behaviour."
+ )
+ try:
+ dataset = _ParquetDatasetV2(
+ source,
+ filesystem=filesystem,
+ partitioning=partitioning,
+ memory_map=memory_map,
+ read_dictionary=read_dictionary,
+ buffer_size=buffer_size,
+ filters=filters,
+ ignore_prefixes=ignore_prefixes,
+ pre_buffer=pre_buffer,
+ coerce_int96_timestamp_unit=coerce_int96_timestamp_unit
+ )
+ except ImportError:
+ # fall back on ParquetFile for simple cases when pyarrow.dataset
+ # module is not available
+ if filters is not None:
+ raise ValueError(
+ "the 'filters' keyword is not supported when the "
+ "pyarrow.dataset module is not available"
+ )
+ if partitioning != "hive":
+ raise ValueError(
+ "the 'partitioning' keyword is not supported when the "
+ "pyarrow.dataset module is not available"
+ )
+ filesystem, path = _resolve_filesystem_and_path(source, filesystem)
+ if filesystem is not None:
+ source = filesystem.open_input_file(path)
+ # TODO test that source is not a directory or a list
+ dataset = ParquetFile(
+ source, metadata=metadata, read_dictionary=read_dictionary,
+ memory_map=memory_map, buffer_size=buffer_size,
+ pre_buffer=pre_buffer,
+ coerce_int96_timestamp_unit=coerce_int96_timestamp_unit
+ )
+
+ return dataset.read(columns=columns, use_threads=use_threads,
+ use_pandas_metadata=use_pandas_metadata)
+
+ if ignore_prefixes is not None:
+ raise ValueError(
+ "The 'ignore_prefixes' keyword is only supported when "
+ "use_legacy_dataset=False")
+
+ if _is_path_like(source):
+ pf = ParquetDataset(
+ source, metadata=metadata, memory_map=memory_map,
+ read_dictionary=read_dictionary,
+ buffer_size=buffer_size,
+ filesystem=filesystem, filters=filters,
+ partitioning=partitioning,
+ coerce_int96_timestamp_unit=coerce_int96_timestamp_unit
+ )
+ else:
+ pf = ParquetFile(
+ source, metadata=metadata,
+ read_dictionary=read_dictionary,
+ memory_map=memory_map,
+ buffer_size=buffer_size,
+ coerce_int96_timestamp_unit=coerce_int96_timestamp_unit
+ )
+ return pf.read(columns=columns, use_threads=use_threads,
+ use_pandas_metadata=use_pandas_metadata)
+
+
+read_table.__doc__ = _read_table_docstring.format(
+ """Read a Table from Parquet format
+
+Note: starting with pyarrow 1.0, the default for `use_legacy_dataset` is
+switched to False.""",
+ "\n".join((_read_docstring_common,
+ """use_pandas_metadata : bool, default False
+ If True and file has custom pandas schema metadata, ensure that
+ index columns are also loaded.""")),
+ """pyarrow.Table
+ Content of the file as a table (of columns)""",
+ _DNF_filter_doc)
+
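+# Editor's note: a hedged sketch (not part of the module) of typical
+# read_table usage; the file name is hypothetical.
+#
+#     import pyarrow.parquet as pq
+#     table = pq.read_table('example.parquet', columns=['a'])
+#     df = table.to_pandas()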
+
+def read_pandas(source, columns=None, **kwargs):
+ return read_table(
+ source, columns=columns, use_pandas_metadata=True, **kwargs
+ )
+
+
+read_pandas.__doc__ = _read_table_docstring.format(
+ 'Read a Table from Parquet format, also reading DataFrame\n'
+ 'index values if known in the file metadata',
+ "\n".join((_read_docstring_common,
+ """**kwargs : additional options for :func:`read_table`""")),
+ """pyarrow.Table
+ Content of the file as a Table of Columns, including DataFrame
+ indexes as columns""",
+ _DNF_filter_doc)
+
+
+def write_table(table, where, row_group_size=None, version='1.0',
+ use_dictionary=True, compression='snappy',
+ write_statistics=True,
+ use_deprecated_int96_timestamps=None,
+ coerce_timestamps=None,
+ allow_truncated_timestamps=False,
+ data_page_size=None, flavor=None,
+ filesystem=None,
+ compression_level=None,
+ use_byte_stream_split=False,
+ data_page_version='1.0',
+ use_compliant_nested_type=False,
+ **kwargs):
+ row_group_size = kwargs.pop('chunk_size', row_group_size)
+ use_int96 = use_deprecated_int96_timestamps
+ try:
+ with ParquetWriter(
+ where, table.schema,
+ filesystem=filesystem,
+ version=version,
+ flavor=flavor,
+ use_dictionary=use_dictionary,
+ write_statistics=write_statistics,
+ coerce_timestamps=coerce_timestamps,
+ data_page_size=data_page_size,
+ allow_truncated_timestamps=allow_truncated_timestamps,
+ compression=compression,
+ use_deprecated_int96_timestamps=use_int96,
+ compression_level=compression_level,
+ use_byte_stream_split=use_byte_stream_split,
+ data_page_version=data_page_version,
+ use_compliant_nested_type=use_compliant_nested_type,
+ **kwargs) as writer:
+ writer.write_table(table, row_group_size=row_group_size)
+ except Exception:
+ if _is_path_like(where):
+ try:
+ os.remove(_stringify_path(where))
+ except os.error:
+ pass
+ raise
+
+
+write_table.__doc__ = """
+Write a Table to Parquet format.
+
+Parameters
+----------
+table : pyarrow.Table
+where : string or pyarrow.NativeFile
+row_group_size : int
+    The number of rows per row group.
+{}
+**kwargs : optional
+ Additional options for ParquetWriter
+""".format(_parquet_writer_arg_docs)
+
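+# Editor's note: a hedged sketch (not part of the module) of the
+# row_group_size knob handled by write_table above ('chunk_size' is accepted
+# as an alias); the file name is hypothetical.
+#
+#     import pyarrow as pa
+#     import pyarrow.parquet as pq
+#     table = pa.table({'x': list(range(10))})
+#     pq.write_table(table, 'rows.parquet', row_group_size=4)
+#     assert pq.read_metadata('rows.parquet').num_row_groups == 3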
+
+def _mkdir_if_not_exists(fs, path):
+ if fs._isfilestore() and not fs.exists(path):
+ try:
+ fs.mkdir(path)
+ except OSError:
+ assert fs.exists(path)
+
+
+def write_to_dataset(table, root_path, partition_cols=None,
+ partition_filename_cb=None, filesystem=None,
+ use_legacy_dataset=None, **kwargs):
+ """Wrapper around parquet.write_table for writing a Table to
+ Parquet format by partitions.
+ For each combination of partition columns and values,
+    subdirectories are created in the following
+ manner:
+
+ root_dir/
+ group1=value1
+ group2=value1
+ <uuid>.parquet
+ group2=value2
+ <uuid>.parquet
+ group1=valueN
+ group2=value1
+ <uuid>.parquet
+ group2=valueN
+ <uuid>.parquet
+
+ Parameters
+ ----------
+ table : pyarrow.Table
+ root_path : str, pathlib.Path
+ The root directory of the dataset
+ filesystem : FileSystem, default None
+ If nothing passed, paths assumed to be found in the local on-disk
+ filesystem
+ partition_cols : list,
+ Column names by which to partition the dataset
+ Columns are partitioned in the order they are given
+ partition_filename_cb : callable,
+ A callback function that takes the partition key(s) as an argument
+        and allows you to override the partition filename. If nothing is
+ passed, the filename will consist of a uuid.
+ use_legacy_dataset : bool
+ Default is True unless a ``pyarrow.fs`` filesystem is passed.
+ Set to False to enable the new code path (experimental, using the
+ new Arrow Dataset API). This is more efficient when using partition
+ columns, but does not (yet) support `partition_filename_cb` and
+ `metadata_collector` keywords.
+ **kwargs : dict,
+ Additional kwargs for write_table function. See docstring for
+ `write_table` or `ParquetWriter` for more information.
+ Using `metadata_collector` in kwargs allows one to collect the
+ file metadata instances of dataset pieces. The file paths in the
+ ColumnChunkMetaData will be set relative to `root_path`.
+ """
+ if use_legacy_dataset is None:
+ # if a new filesystem is passed -> default to new implementation
+ if isinstance(filesystem, FileSystem):
+ use_legacy_dataset = False
+ # otherwise the default is still True
+ else:
+ use_legacy_dataset = True
+
+ if not use_legacy_dataset:
+ import pyarrow.dataset as ds
+
+ # extract non-file format options
+ schema = kwargs.pop("schema", None)
+ use_threads = kwargs.pop("use_threads", True)
+
+ # raise for unsupported keywords
+ msg = (
+ "The '{}' argument is not supported with the new dataset "
+ "implementation."
+ )
+ metadata_collector = kwargs.pop('metadata_collector', None)
+ file_visitor = None
+ if metadata_collector is not None:
+ def file_visitor(written_file):
+ metadata_collector.append(written_file.metadata)
+ if partition_filename_cb is not None:
+ raise ValueError(msg.format("partition_filename_cb"))
+
+ # map format arguments
+ parquet_format = ds.ParquetFileFormat()
+ write_options = parquet_format.make_write_options(**kwargs)
+
+ # map old filesystems to new one
+ if filesystem is not None:
+ filesystem = _ensure_filesystem(filesystem)
+
+ partitioning = None
+ if partition_cols:
+ part_schema = table.select(partition_cols).schema
+ partitioning = ds.partitioning(part_schema, flavor="hive")
+
+ ds.write_dataset(
+ table, root_path, filesystem=filesystem,
+ format=parquet_format, file_options=write_options, schema=schema,
+ partitioning=partitioning, use_threads=use_threads,
+ file_visitor=file_visitor)
+ return
+
+ fs, root_path = legacyfs.resolve_filesystem_and_path(root_path, filesystem)
+
+ _mkdir_if_not_exists(fs, root_path)
+
+ metadata_collector = kwargs.pop('metadata_collector', None)
+
+ if partition_cols is not None and len(partition_cols) > 0:
+ df = table.to_pandas()
+ partition_keys = [df[col] for col in partition_cols]
+ data_df = df.drop(partition_cols, axis='columns')
+ data_cols = df.columns.drop(partition_cols)
+ if len(data_cols) == 0:
+ raise ValueError('No data left to save outside partition columns')
+
+ subschema = table.schema
+
+ # ARROW-2891: Ensure the output_schema is preserved when writing a
+ # partitioned dataset
+ for col in table.schema.names:
+ if col in partition_cols:
+ subschema = subschema.remove(subschema.get_field_index(col))
+
+ for keys, subgroup in data_df.groupby(partition_keys):
+ if not isinstance(keys, tuple):
+ keys = (keys,)
+ subdir = '/'.join(
+ ['{colname}={value}'.format(colname=name, value=val)
+ for name, val in zip(partition_cols, keys)])
+ subtable = pa.Table.from_pandas(subgroup, schema=subschema,
+ safe=False)
+ _mkdir_if_not_exists(fs, '/'.join([root_path, subdir]))
+ if partition_filename_cb:
+ outfile = partition_filename_cb(keys)
+ else:
+ outfile = guid() + '.parquet'
+ relative_path = '/'.join([subdir, outfile])
+ full_path = '/'.join([root_path, relative_path])
+ with fs.open(full_path, 'wb') as f:
+ write_table(subtable, f, metadata_collector=metadata_collector,
+ **kwargs)
+ if metadata_collector is not None:
+ metadata_collector[-1].set_file_path(relative_path)
+ else:
+ if partition_filename_cb:
+ outfile = partition_filename_cb(None)
+ else:
+ outfile = guid() + '.parquet'
+ full_path = '/'.join([root_path, outfile])
+ with fs.open(full_path, 'wb') as f:
+ write_table(table, f, metadata_collector=metadata_collector,
+ **kwargs)
+ if metadata_collector is not None:
+ metadata_collector[-1].set_file_path(outfile)
+
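+# Editor's note: a hedged sketch (not part of the module) of partitioned
+# writing with write_to_dataset; the table and root path are hypothetical.
+#
+#     import pyarrow as pa
+#     import pyarrow.parquet as pq
+#     table = pa.table({'year': [2019, 2019, 2020], 'x': [1, 2, 3]})
+#     pq.write_to_dataset(table, 'dataset_root', partition_cols=['year'])
+#     # produces dataset_root/year=2019/<uuid>.parquet and
+#     #          dataset_root/year=2020/<uuid>.parquet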
+
+def write_metadata(schema, where, metadata_collector=None, **kwargs):
+ """
+ Write metadata-only Parquet file from schema. This can be used with
+ `write_to_dataset` to generate `_common_metadata` and `_metadata` sidecar
+ files.
+
+ Parameters
+ ----------
+ schema : pyarrow.Schema
+ where : string or pyarrow.NativeFile
+ metadata_collector : list
+ where to collect metadata information.
+ **kwargs : dict,
+ Additional kwargs for ParquetWriter class. See docstring for
+ `ParquetWriter` for more information.
+
+ Examples
+ --------
+
+ Write a dataset and collect metadata information.
+
+ >>> metadata_collector = []
+ >>> write_to_dataset(
+ ... table, root_path,
+ ... metadata_collector=metadata_collector, **writer_kwargs)
+
+ Write the `_common_metadata` parquet file without row groups statistics.
+
+ >>> write_metadata(
+ ... table.schema, root_path / '_common_metadata', **writer_kwargs)
+
+ Write the `_metadata` parquet file with row groups statistics.
+
+ >>> write_metadata(
+ ... table.schema, root_path / '_metadata',
+ ... metadata_collector=metadata_collector, **writer_kwargs)
+ """
+ writer = ParquetWriter(where, schema, **kwargs)
+ writer.close()
+
+ if metadata_collector is not None:
+ # ParquetWriter doesn't expose the metadata until it's written. Write
+ # it and read it again.
+ metadata = read_metadata(where)
+ for m in metadata_collector:
+ metadata.append_row_groups(m)
+ metadata.write_metadata_file(where)
+
+
+def read_metadata(where, memory_map=False):
+ """
+ Read FileMetadata from footer of a single Parquet file.
+
+ Parameters
+ ----------
+ where : str (filepath) or file-like object
+ memory_map : bool, default False
+ Create memory map when the source is a file path.
+
+ Returns
+ -------
+ metadata : FileMetadata
+ """
+ return ParquetFile(where, memory_map=memory_map).metadata
+
+
+def read_schema(where, memory_map=False):
+ """
+ Read effective Arrow schema from Parquet file metadata.
+
+ Parameters
+ ----------
+ where : str (filepath) or file-like object
+ memory_map : bool, default False
+ Create memory map when the source is a file path.
+
+ Returns
+ -------
+ schema : pyarrow.Schema
+ """
+ return ParquetFile(where, memory_map=memory_map).schema.to_arrow_schema()
diff --git a/src/arrow/python/pyarrow/plasma.py b/src/arrow/python/pyarrow/plasma.py
new file mode 100644
index 000000000..239d29094
--- /dev/null
+++ b/src/arrow/python/pyarrow/plasma.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import contextlib
+import os
+import pyarrow as pa
+import shutil
+import subprocess
+import sys
+import tempfile
+import time
+
+from pyarrow._plasma import (ObjectID, ObjectNotAvailable, # noqa
+ PlasmaBuffer, PlasmaClient, connect,
+ PlasmaObjectExists, PlasmaObjectNotFound,
+ PlasmaStoreFull)
+
+
+# The Plasma TensorFlow Operator needs to be compiled on the end user's
+# machine since the TensorFlow ABI is not stable between versions.
+# The following code checks if the operator is already present. If not,
+# the function build_plasma_tensorflow_op can be used to compile it.
+
+
+TF_PLASMA_OP_PATH = os.path.join(pa.__path__[0], "tensorflow", "plasma_op.so")
+
+
+tf_plasma_op = None
+
+
+def load_plasma_tensorflow_op():
+ global tf_plasma_op
+ import tensorflow as tf
+ tf_plasma_op = tf.load_op_library(TF_PLASMA_OP_PATH)
+
+
+def build_plasma_tensorflow_op():
+ global tf_plasma_op
+ try:
+ import tensorflow as tf
+ print("TensorFlow version: " + tf.__version__)
+ except ImportError:
+ pass
+ else:
+ print("Compiling Plasma TensorFlow Op...")
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ cc_path = os.path.join(dir_path, "tensorflow", "plasma_op.cc")
+ so_path = os.path.join(dir_path, "tensorflow", "plasma_op.so")
+ tf_cflags = tf.sysconfig.get_compile_flags()
+ if sys.platform == 'darwin':
+ tf_cflags = ["-undefined", "dynamic_lookup"] + tf_cflags
+ cmd = ["g++", "-std=c++11", "-g", "-shared", cc_path,
+ "-o", so_path, "-DNDEBUG", "-I" + pa.get_include()]
+ cmd += ["-L" + dir for dir in pa.get_library_dirs()]
+ cmd += ["-lplasma", "-larrow_python", "-larrow", "-fPIC"]
+ cmd += tf_cflags
+ cmd += tf.sysconfig.get_link_flags()
+ cmd += ["-O2"]
+ if tf.test.is_built_with_cuda():
+ cmd += ["-DGOOGLE_CUDA"]
+ print("Running command " + str(cmd))
+ subprocess.check_call(cmd)
+ tf_plasma_op = tf.load_op_library(TF_PLASMA_OP_PATH)
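A short usage sketch for the two helpers above (assumes TensorFlow and a C++ toolchain are installed; illustrative only):

    import pyarrow.plasma as plasma

    plasma.build_plasma_tensorflow_op()  # compiles tensorflow/plasma_op.so in place
    # or, if the shared library was built previously:
    plasma.load_plasma_tensorflow_op()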
+
+
+@contextlib.contextmanager
+def start_plasma_store(plasma_store_memory,
+ use_valgrind=False, use_profiler=False,
+ plasma_directory=None, use_hugepages=False,
+ external_store=None):
+    """Start a plasma store process.
+
+    Args:
+ plasma_store_memory (int): Capacity of the plasma store in bytes.
+ use_valgrind (bool): True if the plasma store should be started inside
+ of valgrind. If this is True, use_profiler must be False.
+ use_profiler (bool): True if the plasma store should be started inside
+ a profiler. If this is True, use_valgrind must be False.
+ plasma_directory (str): Directory where plasma memory mapped files
+ will be stored.
+ use_hugepages (bool): True if the plasma store should use huge pages.
+ external_store (str): External store to use for evicted objects.
+    Returns:
+ A tuple of the name of the plasma store socket and the process ID of
+ the plasma store process.
+ """
+ if use_valgrind and use_profiler:
+ raise Exception("Cannot use valgrind and profiler at the same time.")
+
+ tmpdir = tempfile.mkdtemp(prefix='test_plasma-')
+ try:
+ plasma_store_name = os.path.join(tmpdir, 'plasma.sock')
+ plasma_store_executable = os.path.join(
+ pa.__path__[0], "plasma-store-server")
+ if not os.path.exists(plasma_store_executable):
+ # Fallback to sys.prefix/bin/ (conda)
+ plasma_store_executable = os.path.join(
+ sys.prefix, "bin", "plasma-store-server")
+ command = [plasma_store_executable,
+ "-s", plasma_store_name,
+ "-m", str(plasma_store_memory)]
+ if plasma_directory:
+ command += ["-d", plasma_directory]
+ if use_hugepages:
+ command += ["-h"]
+ if external_store is not None:
+ command += ["-e", external_store]
+ stdout_file = None
+ stderr_file = None
+ if use_valgrind:
+ command = ["valgrind",
+ "--track-origins=yes",
+ "--leak-check=full",
+ "--show-leak-kinds=all",
+ "--leak-check-heuristics=stdstring",
+ "--error-exitcode=1"] + command
+ proc = subprocess.Popen(command, stdout=stdout_file,
+ stderr=stderr_file)
+ time.sleep(1.0)
+ elif use_profiler:
+ command = ["valgrind", "--tool=callgrind"] + command
+ proc = subprocess.Popen(command, stdout=stdout_file,
+ stderr=stderr_file)
+ time.sleep(1.0)
+ else:
+ proc = subprocess.Popen(command, stdout=stdout_file,
+ stderr=stderr_file)
+ time.sleep(0.1)
+ rc = proc.poll()
+ if rc is not None:
+ raise RuntimeError("plasma_store exited unexpectedly with "
+ "code %d" % (rc,))
+
+ yield plasma_store_name, proc
+ finally:
+ if proc.poll() is None:
+ proc.kill()
+ shutil.rmtree(tmpdir)
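For reference, a minimal sketch of driving this context manager from Python (assumes the `plasma-store-server` executable ships with the installed pyarrow build):

    import pyarrow.plasma as plasma

    with plasma.start_plasma_store(plasma_store_memory=10**8) as (name, proc):
        client = plasma.connect(name)
        object_id = client.put({'answer': 42})   # store a Python object
        assert client.get(object_id) == {'answer': 42}
        client.disconnect()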
diff --git a/src/arrow/python/pyarrow/public-api.pxi b/src/arrow/python/pyarrow/public-api.pxi
new file mode 100644
index 000000000..c427fb9f5
--- /dev/null
+++ b/src/arrow/python/pyarrow/public-api.pxi
@@ -0,0 +1,418 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from libcpp.memory cimport shared_ptr
+from pyarrow.includes.libarrow cimport (CArray, CDataType, CField,
+ CRecordBatch, CSchema,
+ CTable, CTensor,
+ CSparseCOOTensor, CSparseCSRMatrix,
+ CSparseCSCMatrix, CSparseCSFTensor)
+
+# You cannot assign something to a dereferenced pointer in Cython, so these
+# methods don't use Status to indicate a successful operation.
+
+
+cdef api bint pyarrow_is_buffer(object buffer):
+ return isinstance(buffer, Buffer)
+
+
+cdef api shared_ptr[CBuffer] pyarrow_unwrap_buffer(object buffer):
+ cdef Buffer buf
+ if pyarrow_is_buffer(buffer):
+ buf = <Buffer>(buffer)
+ return buf.buffer
+
+ return shared_ptr[CBuffer]()
+
+
+cdef api object pyarrow_wrap_buffer(const shared_ptr[CBuffer]& buf):
+ cdef Buffer result = Buffer.__new__(Buffer)
+ result.init(buf)
+ return result
+
+
+cdef api object pyarrow_wrap_resizable_buffer(
+ const shared_ptr[CResizableBuffer]& buf):
+ cdef ResizableBuffer result = ResizableBuffer.__new__(ResizableBuffer)
+ result.init_rz(buf)
+ return result
+
+
+cdef api bint pyarrow_is_data_type(object type_):
+ return isinstance(type_, DataType)
+
+
+cdef api shared_ptr[CDataType] pyarrow_unwrap_data_type(
+ object data_type):
+ cdef DataType type_
+ if pyarrow_is_data_type(data_type):
+ type_ = <DataType>(data_type)
+ return type_.sp_type
+
+ return shared_ptr[CDataType]()
+
+
+# Workaround for Cython parsing bug
+# https://github.com/cython/cython/issues/2143
+ctypedef const CPyExtensionType* _CPyExtensionTypePtr
+
+
+cdef api object pyarrow_wrap_data_type(
+ const shared_ptr[CDataType]& type):
+ cdef:
+ const CExtensionType* ext_type
+ const CPyExtensionType* cpy_ext_type
+ DataType out
+
+ if type.get() == NULL:
+ return None
+
+ if type.get().id() == _Type_DICTIONARY:
+ out = DictionaryType.__new__(DictionaryType)
+ elif type.get().id() == _Type_LIST:
+ out = ListType.__new__(ListType)
+ elif type.get().id() == _Type_LARGE_LIST:
+ out = LargeListType.__new__(LargeListType)
+ elif type.get().id() == _Type_MAP:
+ out = MapType.__new__(MapType)
+ elif type.get().id() == _Type_FIXED_SIZE_LIST:
+ out = FixedSizeListType.__new__(FixedSizeListType)
+ elif type.get().id() == _Type_STRUCT:
+ out = StructType.__new__(StructType)
+ elif type.get().id() == _Type_SPARSE_UNION:
+ out = SparseUnionType.__new__(SparseUnionType)
+ elif type.get().id() == _Type_DENSE_UNION:
+ out = DenseUnionType.__new__(DenseUnionType)
+ elif type.get().id() == _Type_TIMESTAMP:
+ out = TimestampType.__new__(TimestampType)
+ elif type.get().id() == _Type_DURATION:
+ out = DurationType.__new__(DurationType)
+ elif type.get().id() == _Type_FIXED_SIZE_BINARY:
+ out = FixedSizeBinaryType.__new__(FixedSizeBinaryType)
+ elif type.get().id() == _Type_DECIMAL128:
+ out = Decimal128Type.__new__(Decimal128Type)
+ elif type.get().id() == _Type_DECIMAL256:
+ out = Decimal256Type.__new__(Decimal256Type)
+ elif type.get().id() == _Type_EXTENSION:
+ ext_type = <const CExtensionType*> type.get()
+ cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type)
+ if cpy_ext_type != nullptr:
+ return cpy_ext_type.GetInstance()
+ else:
+ out = BaseExtensionType.__new__(BaseExtensionType)
+ else:
+ out = DataType.__new__(DataType)
+
+ out.init(type)
+ return out
+
+
+cdef object pyarrow_wrap_metadata(
+ const shared_ptr[const CKeyValueMetadata]& meta):
+ if meta.get() == nullptr:
+ return None
+ else:
+ return KeyValueMetadata.wrap(meta)
+
+
+cdef api bint pyarrow_is_metadata(object metadata):
+ return isinstance(metadata, KeyValueMetadata)
+
+
+cdef shared_ptr[const CKeyValueMetadata] pyarrow_unwrap_metadata(object meta):
+ cdef shared_ptr[const CKeyValueMetadata] c_meta
+ if pyarrow_is_metadata(meta):
+ c_meta = (<KeyValueMetadata>meta).unwrap()
+ return c_meta
+
+
+cdef api bint pyarrow_is_field(object field):
+ return isinstance(field, Field)
+
+
+cdef api shared_ptr[CField] pyarrow_unwrap_field(object field):
+ cdef Field field_
+ if pyarrow_is_field(field):
+ field_ = <Field>(field)
+ return field_.sp_field
+
+ return shared_ptr[CField]()
+
+
+cdef api object pyarrow_wrap_field(const shared_ptr[CField]& field):
+ if field.get() == NULL:
+ return None
+ cdef Field out = Field.__new__(Field)
+ out.init(field)
+ return out
+
+
+cdef api bint pyarrow_is_schema(object schema):
+ return isinstance(schema, Schema)
+
+
+cdef api shared_ptr[CSchema] pyarrow_unwrap_schema(object schema):
+ cdef Schema sch
+ if pyarrow_is_schema(schema):
+ sch = <Schema>(schema)
+ return sch.sp_schema
+
+ return shared_ptr[CSchema]()
+
+
+cdef api object pyarrow_wrap_schema(const shared_ptr[CSchema]& schema):
+ cdef Schema out = Schema.__new__(Schema)
+ out.init_schema(schema)
+ return out
+
+
+cdef api bint pyarrow_is_array(object array):
+ return isinstance(array, Array)
+
+
+cdef api shared_ptr[CArray] pyarrow_unwrap_array(object array):
+ cdef Array arr
+ if pyarrow_is_array(array):
+ arr = <Array>(array)
+ return arr.sp_array
+
+ return shared_ptr[CArray]()
+
+
+cdef api object pyarrow_wrap_array(const shared_ptr[CArray]& sp_array):
+ if sp_array.get() == NULL:
+ raise ValueError('Array was NULL')
+
+ klass = get_array_class_from_type(sp_array.get().type())
+
+ cdef Array arr = klass.__new__(klass)
+ arr.init(sp_array)
+ return arr
+
+
+cdef api bint pyarrow_is_chunked_array(object array):
+ return isinstance(array, ChunkedArray)
+
+
+cdef api shared_ptr[CChunkedArray] pyarrow_unwrap_chunked_array(object array):
+ cdef ChunkedArray arr
+ if pyarrow_is_chunked_array(array):
+ arr = <ChunkedArray>(array)
+ return arr.sp_chunked_array
+
+ return shared_ptr[CChunkedArray]()
+
+
+cdef api object pyarrow_wrap_chunked_array(
+ const shared_ptr[CChunkedArray]& sp_array):
+ if sp_array.get() == NULL:
+ raise ValueError('ChunkedArray was NULL')
+
+ cdef CDataType* data_type = sp_array.get().type().get()
+
+ if data_type == NULL:
+ raise ValueError('ChunkedArray data type was NULL')
+
+ cdef ChunkedArray arr = ChunkedArray.__new__(ChunkedArray)
+ arr.init(sp_array)
+ return arr
+
+
+cdef api bint pyarrow_is_scalar(object value):
+ return isinstance(value, Scalar)
+
+
+cdef api shared_ptr[CScalar] pyarrow_unwrap_scalar(object scalar):
+ if pyarrow_is_scalar(scalar):
+ return (<Scalar> scalar).unwrap()
+ return shared_ptr[CScalar]()
+
+
+cdef api object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar):
+ if sp_scalar.get() == NULL:
+ raise ValueError('Scalar was NULL')
+
+ cdef CDataType* data_type = sp_scalar.get().type.get()
+
+ if data_type == NULL:
+ raise ValueError('Scalar data type was NULL')
+
+ if data_type.id() == _Type_NA:
+ return _NULL
+
+ if data_type.id() not in _scalar_classes:
+ raise ValueError('Scalar type not supported')
+
+ klass = _scalar_classes[data_type.id()]
+
+ cdef Scalar scalar = klass.__new__(klass)
+ scalar.init(sp_scalar)
+ return scalar
+
+
+cdef api bint pyarrow_is_tensor(object tensor):
+ return isinstance(tensor, Tensor)
+
+
+cdef api shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor):
+ cdef Tensor ten
+ if pyarrow_is_tensor(tensor):
+ ten = <Tensor>(tensor)
+ return ten.sp_tensor
+
+ return shared_ptr[CTensor]()
+
+
+cdef api object pyarrow_wrap_tensor(
+ const shared_ptr[CTensor]& sp_tensor):
+ if sp_tensor.get() == NULL:
+ raise ValueError('Tensor was NULL')
+
+ cdef Tensor tensor = Tensor.__new__(Tensor)
+ tensor.init(sp_tensor)
+ return tensor
+
+
+cdef api bint pyarrow_is_sparse_coo_tensor(object sparse_tensor):
+ return isinstance(sparse_tensor, SparseCOOTensor)
+
+cdef api shared_ptr[CSparseCOOTensor] pyarrow_unwrap_sparse_coo_tensor(
+ object sparse_tensor):
+ cdef SparseCOOTensor sten
+ if pyarrow_is_sparse_coo_tensor(sparse_tensor):
+ sten = <SparseCOOTensor>(sparse_tensor)
+ return sten.sp_sparse_tensor
+
+ return shared_ptr[CSparseCOOTensor]()
+
+cdef api object pyarrow_wrap_sparse_coo_tensor(
+ const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor):
+ if sp_sparse_tensor.get() == NULL:
+ raise ValueError('SparseCOOTensor was NULL')
+
+ cdef SparseCOOTensor sparse_tensor = SparseCOOTensor.__new__(
+ SparseCOOTensor)
+ sparse_tensor.init(sp_sparse_tensor)
+ return sparse_tensor
+
+
+cdef api bint pyarrow_is_sparse_csr_matrix(object sparse_tensor):
+ return isinstance(sparse_tensor, SparseCSRMatrix)
+
+cdef api shared_ptr[CSparseCSRMatrix] pyarrow_unwrap_sparse_csr_matrix(
+ object sparse_tensor):
+ cdef SparseCSRMatrix sten
+ if pyarrow_is_sparse_csr_matrix(sparse_tensor):
+ sten = <SparseCSRMatrix>(sparse_tensor)
+ return sten.sp_sparse_tensor
+
+ return shared_ptr[CSparseCSRMatrix]()
+
+cdef api object pyarrow_wrap_sparse_csr_matrix(
+ const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor):
+ if sp_sparse_tensor.get() == NULL:
+ raise ValueError('SparseCSRMatrix was NULL')
+
+ cdef SparseCSRMatrix sparse_tensor = SparseCSRMatrix.__new__(
+ SparseCSRMatrix)
+ sparse_tensor.init(sp_sparse_tensor)
+ return sparse_tensor
+
+
+cdef api bint pyarrow_is_sparse_csc_matrix(object sparse_tensor):
+ return isinstance(sparse_tensor, SparseCSCMatrix)
+
+cdef api shared_ptr[CSparseCSCMatrix] pyarrow_unwrap_sparse_csc_matrix(
+ object sparse_tensor):
+ cdef SparseCSCMatrix sten
+ if pyarrow_is_sparse_csc_matrix(sparse_tensor):
+ sten = <SparseCSCMatrix>(sparse_tensor)
+ return sten.sp_sparse_tensor
+
+ return shared_ptr[CSparseCSCMatrix]()
+
+cdef api object pyarrow_wrap_sparse_csc_matrix(
+ const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor):
+ if sp_sparse_tensor.get() == NULL:
+ raise ValueError('SparseCSCMatrix was NULL')
+
+ cdef SparseCSCMatrix sparse_tensor = SparseCSCMatrix.__new__(
+ SparseCSCMatrix)
+ sparse_tensor.init(sp_sparse_tensor)
+ return sparse_tensor
+
+
+cdef api bint pyarrow_is_sparse_csf_tensor(object sparse_tensor):
+ return isinstance(sparse_tensor, SparseCSFTensor)
+
+cdef api shared_ptr[CSparseCSFTensor] pyarrow_unwrap_sparse_csf_tensor(
+ object sparse_tensor):
+ cdef SparseCSFTensor sten
+ if pyarrow_is_sparse_csf_tensor(sparse_tensor):
+ sten = <SparseCSFTensor>(sparse_tensor)
+ return sten.sp_sparse_tensor
+
+ return shared_ptr[CSparseCSFTensor]()
+
+cdef api object pyarrow_wrap_sparse_csf_tensor(
+ const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor):
+ if sp_sparse_tensor.get() == NULL:
+ raise ValueError('SparseCSFTensor was NULL')
+
+ cdef SparseCSFTensor sparse_tensor = SparseCSFTensor.__new__(
+ SparseCSFTensor)
+ sparse_tensor.init(sp_sparse_tensor)
+ return sparse_tensor
+
+
+cdef api bint pyarrow_is_table(object table):
+ return isinstance(table, Table)
+
+
+cdef api shared_ptr[CTable] pyarrow_unwrap_table(object table):
+ cdef Table tab
+ if pyarrow_is_table(table):
+ tab = <Table>(table)
+ return tab.sp_table
+
+ return shared_ptr[CTable]()
+
+
+cdef api object pyarrow_wrap_table(const shared_ptr[CTable]& ctable):
+ cdef Table table = Table.__new__(Table)
+ table.init(ctable)
+ return table
+
+
+cdef api bint pyarrow_is_batch(object batch):
+ return isinstance(batch, RecordBatch)
+
+
+cdef api shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch):
+ cdef RecordBatch bat
+ if pyarrow_is_batch(batch):
+ bat = <RecordBatch>(batch)
+ return bat.sp_batch
+
+ return shared_ptr[CRecordBatch]()
+
+
+cdef api object pyarrow_wrap_batch(
+ const shared_ptr[CRecordBatch]& cbatch):
+ cdef RecordBatch batch = RecordBatch.__new__(RecordBatch)
+ batch.init(cbatch)
+ return batch
diff --git a/src/arrow/python/pyarrow/scalar.pxi b/src/arrow/python/pyarrow/scalar.pxi
new file mode 100644
index 000000000..80fcc0028
--- /dev/null
+++ b/src/arrow/python/pyarrow/scalar.pxi
@@ -0,0 +1,1048 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import collections
+
+
+cdef class Scalar(_Weakrefable):
+ """
+ The base class for scalars.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use "
+ "pa.scalar() instead.".format(self.__class__.__name__))
+
+ cdef void init(self, const shared_ptr[CScalar]& wrapped):
+ self.wrapped = wrapped
+
+ @staticmethod
+ cdef wrap(const shared_ptr[CScalar]& wrapped):
+ cdef:
+ Scalar self
+ Type type_id = wrapped.get().type.get().id()
+
+ if type_id == _Type_NA:
+ return _NULL
+
+ try:
+ typ = _scalar_classes[type_id]
+ except KeyError:
+ raise NotImplementedError(
+ "Wrapping scalar of type " +
+ frombytes(wrapped.get().type.get().ToString()))
+ self = typ.__new__(typ)
+ self.init(wrapped)
+
+ return self
+
+ cdef inline shared_ptr[CScalar] unwrap(self) nogil:
+ return self.wrapped
+
+ @property
+ def type(self):
+ """
+ Data type of the Scalar object.
+ """
+ return pyarrow_wrap_data_type(self.wrapped.get().type)
+
+ @property
+ def is_valid(self):
+ """
+ Holds a valid (non-null) value.
+ """
+ return self.wrapped.get().is_valid
+
+ def cast(self, object target_type):
+ """
+ Attempt a safe cast to target data type.
+ """
+ cdef:
+ DataType type = ensure_type(target_type)
+ shared_ptr[CScalar] result
+
+ with nogil:
+ result = GetResultValue(self.wrapped.get().CastTo(type.sp_type))
+
+ return Scalar.wrap(result)
+
+ def __repr__(self):
+ return '<pyarrow.{}: {!r}>'.format(
+ self.__class__.__name__, self.as_py()
+ )
+
+ def __str__(self):
+ return str(self.as_py())
+
+ def equals(self, Scalar other not None):
+ return self.wrapped.get().Equals(other.unwrap().get()[0])
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def __hash__(self):
+ cdef CScalarHash hasher
+ return hasher(self.wrapped)
+
+ def __reduce__(self):
+ return scalar, (self.as_py(), self.type)
+
+ def as_py(self):
+ raise NotImplementedError()
+
+
+_NULL = NA = None
+
+
+cdef class NullScalar(Scalar):
+ """
+ Concrete class for null scalars.
+ """
+
+ def __cinit__(self):
+ global NA
+ if NA is not None:
+ raise RuntimeError('Cannot create multiple NullScalar instances')
+ self.init(shared_ptr[CScalar](new CNullScalar()))
+
+ def __init__(self):
+ pass
+
+ def as_py(self):
+ """
+ Return this value as a Python None.
+ """
+ return None
+
+
+_NULL = NA = NullScalar()
+
+
+cdef class BooleanScalar(Scalar):
+ """
+ Concrete class for boolean scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python bool.
+ """
+ cdef CBooleanScalar* sp = <CBooleanScalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class UInt8Scalar(Scalar):
+ """
+ Concrete class for uint8 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python int.
+ """
+ cdef CUInt8Scalar* sp = <CUInt8Scalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class Int8Scalar(Scalar):
+ """
+ Concrete class for int8 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python int.
+ """
+ cdef CInt8Scalar* sp = <CInt8Scalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class UInt16Scalar(Scalar):
+ """
+ Concrete class for uint16 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python int.
+ """
+ cdef CUInt16Scalar* sp = <CUInt16Scalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class Int16Scalar(Scalar):
+ """
+ Concrete class for int16 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python int.
+ """
+ cdef CInt16Scalar* sp = <CInt16Scalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class UInt32Scalar(Scalar):
+ """
+ Concrete class for uint32 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python int.
+ """
+ cdef CUInt32Scalar* sp = <CUInt32Scalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class Int32Scalar(Scalar):
+ """
+ Concrete class for int32 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python int.
+ """
+ cdef CInt32Scalar* sp = <CInt32Scalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class UInt64Scalar(Scalar):
+ """
+ Concrete class for uint64 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python int.
+ """
+ cdef CUInt64Scalar* sp = <CUInt64Scalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class Int64Scalar(Scalar):
+ """
+ Concrete class for int64 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python int.
+ """
+ cdef CInt64Scalar* sp = <CInt64Scalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class HalfFloatScalar(Scalar):
+ """
+    Concrete class for half-precision float scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python float.
+ """
+ cdef CHalfFloatScalar* sp = <CHalfFloatScalar*> self.wrapped.get()
+ return PyHalf_FromHalf(sp.value) if sp.is_valid else None
+
+
+cdef class FloatScalar(Scalar):
+ """
+ Concrete class for float scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python float.
+ """
+ cdef CFloatScalar* sp = <CFloatScalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class DoubleScalar(Scalar):
+ """
+ Concrete class for double scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python float.
+ """
+ cdef CDoubleScalar* sp = <CDoubleScalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+
+cdef class Decimal128Scalar(Scalar):
+ """
+ Concrete class for decimal128 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python Decimal.
+ """
+ cdef:
+ CDecimal128Scalar* sp = <CDecimal128Scalar*> self.wrapped.get()
+ CDecimal128Type* dtype = <CDecimal128Type*> sp.type.get()
+ if sp.is_valid:
+ return _pydecimal.Decimal(
+ frombytes(sp.value.ToString(dtype.scale()))
+ )
+ else:
+ return None
+
+
+cdef class Decimal256Scalar(Scalar):
+ """
+ Concrete class for decimal256 scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python Decimal.
+ """
+ cdef:
+ CDecimal256Scalar* sp = <CDecimal256Scalar*> self.wrapped.get()
+ CDecimal256Type* dtype = <CDecimal256Type*> sp.type.get()
+ if sp.is_valid:
+ return _pydecimal.Decimal(
+ frombytes(sp.value.ToString(dtype.scale()))
+ )
+ else:
+ return None
+
+
+cdef class Date32Scalar(Scalar):
+ """
+ Concrete class for date32 scalars.
+ """
+
+ def as_py(self):
+ """
+        Return this value as a Python datetime.date instance.
+ """
+ cdef CDate32Scalar* sp = <CDate32Scalar*> self.wrapped.get()
+
+ if sp.is_valid:
+            # sp.value is the number of days since the UNIX epoch
+ return (
+ datetime.date(1970, 1, 1) + datetime.timedelta(days=sp.value)
+ )
+ else:
+ return None
+
+
+cdef class Date64Scalar(Scalar):
+ """
+ Concrete class for date64 scalars.
+ """
+
+ def as_py(self):
+ """
+        Return this value as a Python datetime.date instance.
+ """
+ cdef CDate64Scalar* sp = <CDate64Scalar*> self.wrapped.get()
+
+ if sp.is_valid:
+ return (
+ datetime.date(1970, 1, 1) +
+ datetime.timedelta(days=sp.value / 86400000)
+ )
+ else:
+ return None
+
+
+def _datetime_from_int(int64_t value, TimeUnit unit, tzinfo=None):
+ if unit == TimeUnit_SECOND:
+ delta = datetime.timedelta(seconds=value)
+ elif unit == TimeUnit_MILLI:
+ delta = datetime.timedelta(milliseconds=value)
+ elif unit == TimeUnit_MICRO:
+ delta = datetime.timedelta(microseconds=value)
+ else:
+ # TimeUnit_NANO: prefer pandas timestamps if available
+ if _pandas_api.have_pandas:
+ return _pandas_api.pd.Timestamp(value, tz=tzinfo, unit='ns')
+ # otherwise safely truncate to microsecond resolution datetime
+ if value % 1000 != 0:
+ raise ValueError(
+ "Nanosecond resolution temporal type {} is not safely "
+ "convertible to microseconds to convert to datetime.datetime. "
+ "Install pandas to return as Timestamp with nanosecond "
+ "support or access the .value attribute.".format(value)
+ )
+ delta = datetime.timedelta(microseconds=value // 1000)
+
+ dt = datetime.datetime(1970, 1, 1) + delta
+ # adjust timezone if set to the datatype
+ if tzinfo is not None:
+ dt = tzinfo.fromutc(dt)
+
+ return dt
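A small illustration of the unit handling above, as observed through `pa.scalar(...).as_py()` (outputs indicative; the nanosecond case assumes pandas is importable):

    import pyarrow as pa

    pa.scalar(1500, type=pa.timestamp('ms')).as_py()
    # -> datetime.datetime(1970, 1, 1, 0, 0, 1, 500000)

    pa.scalar(1500, type=pa.timestamp('ns')).as_py()
    # -> pandas.Timestamp('1970-01-01 00:00:00.000001500') with pandas installed;
    #    otherwise a ValueError, because 1500 ns cannot be expressed in microseconds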
+
+
+cdef class Time32Scalar(Scalar):
+ """
+ Concrete class for time32 scalars.
+ """
+
+ def as_py(self):
+ """
+        Return this value as a Python datetime.time instance.
+ """
+ cdef:
+ CTime32Scalar* sp = <CTime32Scalar*> self.wrapped.get()
+ CTime32Type* dtype = <CTime32Type*> sp.type.get()
+
+ if sp.is_valid:
+ return _datetime_from_int(sp.value, unit=dtype.unit()).time()
+ else:
+ return None
+
+
+cdef class Time64Scalar(Scalar):
+ """
+ Concrete class for time64 scalars.
+ """
+
+ def as_py(self):
+ """
+        Return this value as a Python datetime.time instance.
+ """
+ cdef:
+ CTime64Scalar* sp = <CTime64Scalar*> self.wrapped.get()
+ CTime64Type* dtype = <CTime64Type*> sp.type.get()
+
+ if sp.is_valid:
+ return _datetime_from_int(sp.value, unit=dtype.unit()).time()
+ else:
+ return None
+
+
+cdef class TimestampScalar(Scalar):
+ """
+ Concrete class for timestamp scalars.
+ """
+
+ @property
+ def value(self):
+ cdef CTimestampScalar* sp = <CTimestampScalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+ def as_py(self):
+ """
+ Return this value as a Pandas Timestamp instance (if units are
+ nanoseconds and pandas is available), otherwise as a Python
+ datetime.datetime instance.
+ """
+ cdef:
+ CTimestampScalar* sp = <CTimestampScalar*> self.wrapped.get()
+ CTimestampType* dtype = <CTimestampType*> sp.type.get()
+
+ if not sp.is_valid:
+ return None
+
+ if not dtype.timezone().empty():
+ tzinfo = string_to_tzinfo(frombytes(dtype.timezone()))
+ else:
+ tzinfo = None
+
+ return _datetime_from_int(sp.value, unit=dtype.unit(), tzinfo=tzinfo)
+
+
+cdef class DurationScalar(Scalar):
+ """
+ Concrete class for duration scalars.
+ """
+
+ @property
+ def value(self):
+ cdef CDurationScalar* sp = <CDurationScalar*> self.wrapped.get()
+ return sp.value if sp.is_valid else None
+
+ def as_py(self):
+ """
+ Return this value as a Pandas Timedelta instance (if units are
+ nanoseconds and pandas is available), otherwise as a Python
+ datetime.timedelta instance.
+ """
+ cdef:
+ CDurationScalar* sp = <CDurationScalar*> self.wrapped.get()
+ CDurationType* dtype = <CDurationType*> sp.type.get()
+ TimeUnit unit = dtype.unit()
+
+ if not sp.is_valid:
+ return None
+
+ if unit == TimeUnit_SECOND:
+ return datetime.timedelta(seconds=sp.value)
+ elif unit == TimeUnit_MILLI:
+ return datetime.timedelta(milliseconds=sp.value)
+ elif unit == TimeUnit_MICRO:
+ return datetime.timedelta(microseconds=sp.value)
+ else:
+ # TimeUnit_NANO: prefer pandas timestamps if available
+ if _pandas_api.have_pandas:
+ return _pandas_api.pd.Timedelta(sp.value, unit='ns')
+ # otherwise safely truncate to microsecond resolution timedelta
+ if sp.value % 1000 != 0:
+ raise ValueError(
+ "Nanosecond duration {} is not safely convertible to "
+ "microseconds to convert to datetime.timedelta. Install "
+ "pandas to return as Timedelta with nanosecond support or "
+ "access the .value attribute.".format(sp.value)
+ )
+ return datetime.timedelta(microseconds=sp.value // 1000)
+
+
+cdef class MonthDayNanoIntervalScalar(Scalar):
+ """
+ Concrete class for month, day, nanosecond interval scalars.
+ """
+
+ @property
+ def value(self):
+ """
+ Same as self.as_py()
+ """
+ return self.as_py()
+
+ def as_py(self):
+ """
+ Return this value as a pyarrow.MonthDayNano.
+ """
+ cdef:
+ PyObject* val
+ CMonthDayNanoIntervalScalar* scalar
+ scalar = <CMonthDayNanoIntervalScalar*>self.wrapped.get()
+ val = GetResultValue(MonthDayNanoIntervalScalarToPyObject(
+ deref(scalar)))
+ return PyObject_to_object(val)
+
+
+cdef class BinaryScalar(Scalar):
+ """
+ Concrete class for binary-like scalars.
+ """
+
+ def as_buffer(self):
+ """
+ Return a view over this value as a Buffer object.
+ """
+ cdef CBaseBinaryScalar* sp = <CBaseBinaryScalar*> self.wrapped.get()
+ return pyarrow_wrap_buffer(sp.value) if sp.is_valid else None
+
+ def as_py(self):
+ """
+ Return this value as a Python bytes.
+ """
+ buffer = self.as_buffer()
+ return None if buffer is None else buffer.to_pybytes()
+
+
+cdef class LargeBinaryScalar(BinaryScalar):
+ pass
+
+
+cdef class FixedSizeBinaryScalar(BinaryScalar):
+ pass
+
+
+cdef class StringScalar(BinaryScalar):
+ """
+ Concrete class for string-like (utf8) scalars.
+ """
+
+ def as_py(self):
+ """
+ Return this value as a Python string.
+ """
+ buffer = self.as_buffer()
+ return None if buffer is None else str(buffer, 'utf8')
+
+
+cdef class LargeStringScalar(StringScalar):
+ pass
+
+
+cdef class ListScalar(Scalar):
+ """
+ Concrete class for list-like scalars.
+ """
+
+ @property
+ def values(self):
+ cdef CBaseListScalar* sp = <CBaseListScalar*> self.wrapped.get()
+ if sp.is_valid:
+ return pyarrow_wrap_array(sp.value)
+ else:
+ return None
+
+ def __len__(self):
+ """
+ Return the number of values.
+ """
+ return len(self.values)
+
+ def __getitem__(self, i):
+ """
+ Return the value at the given index.
+ """
+ return self.values[_normalize_index(i, len(self))]
+
+ def __iter__(self):
+ """
+ Iterate over this element's values.
+ """
+ return iter(self.values)
+
+ def as_py(self):
+ """
+ Return this value as a Python list.
+ """
+ arr = self.values
+ return None if arr is None else arr.to_pylist()
+
+
+cdef class FixedSizeListScalar(ListScalar):
+ pass
+
+
+cdef class LargeListScalar(ListScalar):
+ pass
+
+
+cdef class StructScalar(Scalar, collections.abc.Mapping):
+ """
+ Concrete class for struct scalars.
+ """
+
+ def __len__(self):
+ cdef CStructScalar* sp = <CStructScalar*> self.wrapped.get()
+ return sp.value.size()
+
+ def __iter__(self):
+ cdef:
+ CStructScalar* sp = <CStructScalar*> self.wrapped.get()
+ CStructType* dtype = <CStructType*> sp.type.get()
+ vector[shared_ptr[CField]] fields = dtype.fields()
+
+ for i in range(dtype.num_fields()):
+ yield frombytes(fields[i].get().name())
+
+ def items(self):
+ return ((key, self[i]) for i, key in enumerate(self))
+
+ def __contains__(self, key):
+ return key in list(self)
+
+ def __getitem__(self, key):
+ """
+ Return the child value for the given field.
+
+ Parameters
+ ----------
+        key : Union[int, str]
+ Index / position or name of the field.
+
+ Returns
+ -------
+ result : Scalar
+ """
+ cdef:
+ CFieldRef ref
+ CStructScalar* sp = <CStructScalar*> self.wrapped.get()
+
+ if isinstance(key, (bytes, str)):
+ ref = CFieldRef(<c_string> tobytes(key))
+ elif isinstance(key, int):
+ ref = CFieldRef(<int> key)
+ else:
+ raise TypeError('Expected integer or string index')
+
+ try:
+ return Scalar.wrap(GetResultValue(sp.field(ref)))
+ except ArrowInvalid as exc:
+ if isinstance(key, int):
+ raise IndexError(key) from exc
+ else:
+ raise KeyError(key) from exc
+
+ def as_py(self):
+ """
+ Return this value as a Python dict.
+ """
+ if self.is_valid:
+ try:
+ return {k: self[k].as_py() for k in self.keys()}
+ except KeyError:
+ raise ValueError(
+ "Converting to Python dictionary is not supported when "
+ "duplicate field names are present")
+ else:
+ return None
+
+ def _as_py_tuple(self):
+ # a version that returns a tuple instead of dict to support repr/str
+ # with the presence of duplicate field names
+ if self.is_valid:
+ return [(key, self[i].as_py()) for i, key in enumerate(self)]
+ else:
+ return None
+
+ def __repr__(self):
+ return '<pyarrow.{}: {!r}>'.format(
+ self.__class__.__name__, self._as_py_tuple()
+ )
+
+ def __str__(self):
+ return str(self._as_py_tuple())
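Because `StructScalar` implements the `Mapping` ABC, it behaves like a read-only dict; a quick sketch (outputs indicative):

    import pyarrow as pa

    s = pa.scalar({'x': 1, 'y': 'a'})
    list(s)         # -> ['x', 'y']
    s['x'].as_py()  # -> 1
    s[1].as_py()    # -> 'a'  (integer indexing is also supported)
    s.as_py()       # -> {'x': 1, 'y': 'a'}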
+
+
+cdef class MapScalar(ListScalar):
+ """
+ Concrete class for map scalars.
+ """
+
+ def __getitem__(self, i):
+ """
+ Return the value at the given index.
+ """
+ arr = self.values
+ if arr is None:
+ raise IndexError(i)
+ dct = arr[_normalize_index(i, len(arr))]
+ return (dct['key'], dct['value'])
+
+ def __iter__(self):
+ """
+ Iterate over this element's values.
+ """
+ arr = self.values
+        if arr is None:
+            return
+ for k, v in zip(arr.field('key'), arr.field('value')):
+ yield (k.as_py(), v.as_py())
+
+ def as_py(self):
+ """
+ Return this value as a Python list.
+ """
+ cdef CStructScalar* sp = <CStructScalar*> self.wrapped.get()
+ return list(self) if sp.is_valid else None
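A map scalar is most easily obtained by indexing a MapArray; a short sketch of the (key, value) access described above (outputs indicative):

    import pyarrow as pa

    arr = pa.array([[('a', 1), ('b', 2)]],
                   type=pa.map_(pa.string(), pa.int64()))
    m = arr[0]    # MapScalar for the first list of pairs
    m[0]          # -> (<pyarrow.StringScalar: 'a'>, <pyarrow.Int64Scalar: 1>)
    list(m)       # -> [('a', 1), ('b', 2)]
    m.as_py()     # -> [('a', 1), ('b', 2)]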
+
+
+cdef class DictionaryScalar(Scalar):
+ """
+ Concrete class for dictionary-encoded scalars.
+ """
+
+ @classmethod
+ def _reconstruct(cls, type, is_valid, index, dictionary):
+ cdef:
+ CDictionaryScalarIndexAndDictionary value
+ shared_ptr[CDictionaryScalar] wrapped
+ DataType type_
+ Scalar index_
+ Array dictionary_
+
+ type_ = ensure_type(type, allow_none=False)
+ if not isinstance(type_, DictionaryType):
+ raise TypeError('Must pass a DictionaryType instance')
+
+ if isinstance(index, Scalar):
+            if not index.type.equals(type_.index_type):
+ raise TypeError("The Scalar value passed as index must have "
+ "identical type to the dictionary type's "
+ "index_type")
+ index_ = index
+ else:
+ index_ = scalar(index, type=type_.index_type)
+
+ if isinstance(dictionary, Array):
+            if not dictionary.type.equals(type_.value_type):
+ raise TypeError("The Array passed as dictionary must have "
+ "identical type to the dictionary type's "
+ "value_type")
+ dictionary_ = dictionary
+ else:
+ dictionary_ = array(dictionary, type=type_.value_type)
+
+ value.index = pyarrow_unwrap_scalar(index_)
+ value.dictionary = pyarrow_unwrap_array(dictionary_)
+
+ wrapped = make_shared[CDictionaryScalar](
+ value, pyarrow_unwrap_data_type(type_), <c_bool>(is_valid)
+ )
+ return Scalar.wrap(<shared_ptr[CScalar]> wrapped)
+
+ def __reduce__(self):
+ return DictionaryScalar._reconstruct, (
+ self.type, self.is_valid, self.index, self.dictionary
+ )
+
+ @property
+ def index(self):
+ """
+ Return this value's underlying index as a scalar.
+ """
+ cdef CDictionaryScalar* sp = <CDictionaryScalar*> self.wrapped.get()
+ return Scalar.wrap(sp.value.index)
+
+ @property
+ def value(self):
+ """
+ Return the encoded value as a scalar.
+ """
+ cdef CDictionaryScalar* sp = <CDictionaryScalar*> self.wrapped.get()
+ return Scalar.wrap(GetResultValue(sp.GetEncodedValue()))
+
+ @property
+ def dictionary(self):
+ cdef CDictionaryScalar* sp = <CDictionaryScalar*> self.wrapped.get()
+ return pyarrow_wrap_array(sp.value.dictionary)
+
+ def as_py(self):
+ """
+ Return this encoded value as a Python object.
+ """
+ return self.value.as_py() if self.is_valid else None
+
+ @property
+ def index_value(self):
+        warnings.warn("`index_value` property is deprecated as of 1.0.0, "
+ "please use the `index` property instead",
+ FutureWarning)
+ return self.index
+
+ @property
+ def dictionary_value(self):
+ warnings.warn("`dictionary_value` property is deprecated as of 1.0.0, "
+ "please use the `value` property instead", FutureWarning)
+ return self.value
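For illustration, indexing a dictionary-encoded array yields these scalars; a small sketch (outputs indicative):

    import pyarrow as pa

    arr = pa.array(['a', 'b', 'a']).dictionary_encode()
    s = arr[0]                # DictionaryScalar
    s.index.as_py()           # -> 0    (position in the dictionary)
    s.value.as_py()           # -> 'a'  (the decoded value)
    s.dictionary.to_pylist()  # -> ['a', 'b']
    s.as_py()                 # -> 'a'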
+
+
+cdef class UnionScalar(Scalar):
+ """
+ Concrete class for Union scalars.
+ """
+
+ @property
+ def value(self):
+ """
+ Return underlying value as a scalar.
+ """
+ cdef CUnionScalar* sp = <CUnionScalar*> self.wrapped.get()
+ return Scalar.wrap(sp.value) if sp.is_valid else None
+
+ def as_py(self):
+ """
+ Return underlying value as a Python object.
+ """
+ value = self.value
+ return None if value is None else value.as_py()
+
+ @property
+ def type_code(self):
+ """
+ Return the union type code for this scalar.
+ """
+ cdef CUnionScalar* sp = <CUnionScalar*> self.wrapped.get()
+ return sp.type_code
+
+
+cdef class ExtensionScalar(Scalar):
+ """
+ Concrete class for Extension scalars.
+ """
+
+ @property
+ def value(self):
+ """
+ Return storage value as a scalar.
+ """
+ cdef CExtensionScalar* sp = <CExtensionScalar*> self.wrapped.get()
+ return Scalar.wrap(sp.value) if sp.is_valid else None
+
+ def as_py(self):
+ """
+ Return this scalar as a Python object.
+ """
+ # XXX should there be a hook to wrap the result in a custom class?
+ value = self.value
+ return None if value is None else value.as_py()
+
+ @staticmethod
+ def from_storage(BaseExtensionType typ, value):
+ """
+ Construct ExtensionScalar from type and storage value.
+
+ Parameters
+ ----------
+ typ : DataType
+ The extension type for the result scalar.
+ value : object
+ The storage value for the result scalar.
+
+ Returns
+ -------
+ ext_scalar : ExtensionScalar
+ """
+ cdef:
+ shared_ptr[CExtensionScalar] sp_scalar
+ CExtensionScalar* ext_scalar
+
+ if value is None:
+ storage = None
+ elif isinstance(value, Scalar):
+ if value.type != typ.storage_type:
+ raise TypeError("Incompatible storage type {0} "
+ "for extension type {1}"
+ .format(value.type, typ))
+ storage = value
+ else:
+ storage = scalar(value, typ.storage_type)
+
+ sp_scalar = make_shared[CExtensionScalar](typ.sp_type)
+ ext_scalar = sp_scalar.get()
+ ext_scalar.is_valid = storage is not None and storage.is_valid
+ if ext_scalar.is_valid:
+ ext_scalar.value = pyarrow_unwrap_scalar(storage)
+ check_status(ext_scalar.Validate())
+ return pyarrow_wrap_scalar(<shared_ptr[CScalar]> sp_scalar)
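A sketch of `from_storage` with a hypothetical extension type (the `UuidType` class below is made up for illustration and is not part of pyarrow):

    import pyarrow as pa

    class UuidType(pa.PyExtensionType):
        # Hypothetical extension type backed by 16-byte binary storage.
        def __init__(self):
            super().__init__(pa.binary(16))

        def __reduce__(self):
            return UuidType, ()

    s = pa.ExtensionScalar.from_storage(UuidType(), b'0123456789abcdef')
    s.value.as_py()  # -> b'0123456789abcdef' (the storage scalar)
    s.as_py()        # -> b'0123456789abcdef'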
+
+
+cdef dict _scalar_classes = {
+ _Type_BOOL: BooleanScalar,
+ _Type_UINT8: UInt8Scalar,
+ _Type_UINT16: UInt16Scalar,
+ _Type_UINT32: UInt32Scalar,
+ _Type_UINT64: UInt64Scalar,
+ _Type_INT8: Int8Scalar,
+ _Type_INT16: Int16Scalar,
+ _Type_INT32: Int32Scalar,
+ _Type_INT64: Int64Scalar,
+ _Type_HALF_FLOAT: HalfFloatScalar,
+ _Type_FLOAT: FloatScalar,
+ _Type_DOUBLE: DoubleScalar,
+ _Type_DECIMAL128: Decimal128Scalar,
+ _Type_DECIMAL256: Decimal256Scalar,
+ _Type_DATE32: Date32Scalar,
+ _Type_DATE64: Date64Scalar,
+ _Type_TIME32: Time32Scalar,
+ _Type_TIME64: Time64Scalar,
+ _Type_TIMESTAMP: TimestampScalar,
+ _Type_DURATION: DurationScalar,
+ _Type_BINARY: BinaryScalar,
+ _Type_LARGE_BINARY: LargeBinaryScalar,
+ _Type_FIXED_SIZE_BINARY: FixedSizeBinaryScalar,
+ _Type_STRING: StringScalar,
+ _Type_LARGE_STRING: LargeStringScalar,
+ _Type_LIST: ListScalar,
+ _Type_LARGE_LIST: LargeListScalar,
+ _Type_FIXED_SIZE_LIST: FixedSizeListScalar,
+ _Type_STRUCT: StructScalar,
+ _Type_MAP: MapScalar,
+ _Type_DICTIONARY: DictionaryScalar,
+ _Type_SPARSE_UNION: UnionScalar,
+ _Type_DENSE_UNION: UnionScalar,
+ _Type_INTERVAL_MONTH_DAY_NANO: MonthDayNanoIntervalScalar,
+ _Type_EXTENSION: ExtensionScalar,
+}
+
+
+def scalar(value, type=None, *, from_pandas=None, MemoryPool memory_pool=None):
+ """
+ Create a pyarrow.Scalar instance from a Python object.
+
+ Parameters
+ ----------
+ value : Any
+ Python object coercible to arrow's type system.
+ type : pyarrow.DataType
+ Explicit type to attempt to coerce to, otherwise will be inferred from
+ the value.
+ from_pandas : bool, default None
+ Use pandas's semantics for inferring nulls from values in
+ ndarray-like data. Defaults to False if not passed explicitly by user,
+ or True if a pandas object is passed in.
+ memory_pool : pyarrow.MemoryPool, optional
+ If not passed, will allocate memory from the currently-set default
+ memory pool.
+
+ Returns
+ -------
+ scalar : pyarrow.Scalar
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+
+ >>> pa.scalar(42)
+ <pyarrow.Int64Scalar: 42>
+
+ >>> pa.scalar("string")
+ <pyarrow.StringScalar: 'string'>
+
+ >>> pa.scalar([1, 2])
+ <pyarrow.ListScalar: [1, 2]>
+
+ >>> pa.scalar([1, 2], type=pa.list_(pa.int16()))
+ <pyarrow.ListScalar: [1, 2]>
+ """
+ cdef:
+ DataType ty
+ PyConversionOptions options
+ shared_ptr[CScalar] scalar
+ shared_ptr[CArray] array
+ shared_ptr[CChunkedArray] chunked
+ bint is_pandas_object = False
+ CMemoryPool* pool
+
+ type = ensure_type(type, allow_none=True)
+ pool = maybe_unbox_memory_pool(memory_pool)
+
+ if _is_array_like(value):
+ value = get_values(value, &is_pandas_object)
+
+ options.size = 1
+
+ if type is not None:
+ ty = ensure_type(type)
+ options.type = ty.sp_type
+
+ if from_pandas is None:
+ options.from_pandas = is_pandas_object
+ else:
+ options.from_pandas = from_pandas
+
+ value = [value]
+ with nogil:
+ chunked = GetResultValue(ConvertPySequence(value, None, options, pool))
+
+ # get the first chunk
+ assert chunked.get().num_chunks() == 1
+ array = chunked.get().chunk(0)
+
+ # retrieve the scalar from the first position
+ scalar = GetResultValue(array.get().GetScalar(0))
+ return Scalar.wrap(scalar)
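To round out the factory above, a couple of behaviours worth noting (a sketch; outputs indicative):

    import pyarrow as pa

    pa.scalar(None)                  # -> the NullScalar singleton, pa.NA
    pa.scalar(1).cast(pa.float64())  # -> <pyarrow.DoubleScalar: 1.0>
    pa.scalar(1) == pa.scalar(1)     # -> True (delegates to Scalar.equals)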
diff --git a/src/arrow/python/pyarrow/serialization.pxi b/src/arrow/python/pyarrow/serialization.pxi
new file mode 100644
index 000000000..c03721578
--- /dev/null
+++ b/src/arrow/python/pyarrow/serialization.pxi
@@ -0,0 +1,556 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from cpython.ref cimport PyObject
+
+import warnings
+
+
+def _deprecate_serialization(name):
+ msg = (
+ "'pyarrow.{}' is deprecated as of 2.0.0 and will be removed in a "
+ "future version. Use pickle or the pyarrow IPC functionality instead."
+ ).format(name)
+ warnings.warn(msg, FutureWarning, stacklevel=3)
+
+
+def is_named_tuple(cls):
+ """
+ Return True if cls is a namedtuple and False otherwise.
+ """
+ b = cls.__bases__
+ if len(b) != 1 or b[0] != tuple:
+ return False
+ f = getattr(cls, "_fields", None)
+ if not isinstance(f, tuple):
+ return False
+ return all(isinstance(n, str) for n in f)
+
+
+class SerializationCallbackError(ArrowSerializationError):
+ def __init__(self, message, example_object):
+ ArrowSerializationError.__init__(self, message)
+ self.example_object = example_object
+
+
+class DeserializationCallbackError(ArrowSerializationError):
+ def __init__(self, message, type_id):
+ ArrowSerializationError.__init__(self, message)
+ self.type_id = type_id
+
+
+cdef class SerializationContext(_Weakrefable):
+ cdef:
+ object type_to_type_id
+ object whitelisted_types
+ object types_to_pickle
+ object custom_serializers
+ object custom_deserializers
+ object pickle_serializer
+ object pickle_deserializer
+
+ def __init__(self):
+ # Types with special serialization handlers
+ self.type_to_type_id = dict()
+ self.whitelisted_types = dict()
+ self.types_to_pickle = set()
+ self.custom_serializers = dict()
+ self.custom_deserializers = dict()
+ self.pickle_serializer = pickle.dumps
+ self.pickle_deserializer = pickle.loads
+
+ def set_pickle(self, serializer, deserializer):
+ """
+ Set the serializer and deserializer to use for objects that are to be
+ pickled.
+
+ Parameters
+ ----------
+ serializer : callable
+ The serializer to use (e.g., pickle.dumps or cloudpickle.dumps).
+ deserializer : callable
+            The deserializer to use (e.g., pickle.loads or cloudpickle.loads).
+ """
+ self.pickle_serializer = serializer
+ self.pickle_deserializer = deserializer
+
+ def clone(self):
+ """
+ Return copy of this SerializationContext.
+
+ Returns
+ -------
+ clone : SerializationContext
+ """
+ result = SerializationContext()
+ result.type_to_type_id = self.type_to_type_id.copy()
+ result.whitelisted_types = self.whitelisted_types.copy()
+ result.types_to_pickle = self.types_to_pickle.copy()
+ result.custom_serializers = self.custom_serializers.copy()
+ result.custom_deserializers = self.custom_deserializers.copy()
+ result.pickle_serializer = self.pickle_serializer
+ result.pickle_deserializer = self.pickle_deserializer
+
+ return result
+
+ def register_type(self, type_, type_id, pickle=False,
+ custom_serializer=None, custom_deserializer=None):
+ r"""
+ EXPERIMENTAL: Add type to the list of types we can serialize.
+
+ Parameters
+ ----------
+ type\_ : type
+ The type that we can serialize.
+ type_id : string
+ A string used to identify the type.
+ pickle : bool
+ True if the serialization should be done with pickle.
+ False if it should be done efficiently with Arrow.
+ custom_serializer : callable
+ This argument is optional, but can be provided to
+ serialize objects of the class in a particular way.
+ custom_deserializer : callable
+ This argument is optional, but can be provided to
+ deserialize objects of the class in a particular way.
+ """
+ if not isinstance(type_id, str):
+ raise TypeError("The type_id argument must be a string. The value "
+ "passed in has type {}.".format(type(type_id)))
+
+ self.type_to_type_id[type_] = type_id
+ self.whitelisted_types[type_id] = type_
+ if pickle:
+ self.types_to_pickle.add(type_id)
+ if custom_serializer is not None:
+ self.custom_serializers[type_id] = custom_serializer
+ self.custom_deserializers[type_id] = custom_deserializer
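A sketch of registering a custom class on a context (this uses the deprecated serialization machinery, so expect a FutureWarning; `Point` and its handlers are made up for illustration):

    import pyarrow as pa

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    context = pa.SerializationContext()
    context.register_type(
        Point, 'Point',
        custom_serializer=lambda p: {'x': p.x, 'y': p.y},
        custom_deserializer=lambda data: Point(data['x'], data['y']))

    buf = context.serialize(Point(1, 2)).to_buffer()
    restored = context.deserialize(buf)
    assert (restored.x, restored.y) == (1, 2)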
+
+ def _serialize_callback(self, obj):
+ found = False
+ for type_ in type(obj).__mro__:
+ if type_ in self.type_to_type_id:
+ found = True
+ break
+
+ if not found:
+ raise SerializationCallbackError(
+ "pyarrow does not know how to "
+ "serialize objects of type {}.".format(type(obj)), obj
+ )
+
+ # use the closest match to type(obj)
+ type_id = self.type_to_type_id[type_]
+ if type_id in self.types_to_pickle:
+ serialized_obj = {"data": self.pickle_serializer(obj),
+ "pickle": True}
+ elif type_id in self.custom_serializers:
+ serialized_obj = {"data": self.custom_serializers[type_id](obj)}
+ else:
+ if is_named_tuple(type_):
+ serialized_obj = {}
+ serialized_obj["_pa_getnewargs_"] = obj.__getnewargs__()
+ elif hasattr(obj, "__dict__"):
+ serialized_obj = obj.__dict__
+ else:
+ msg = "We do not know how to serialize " \
+ "the object '{}'".format(obj)
+ raise SerializationCallbackError(msg, obj)
+ return dict(serialized_obj, **{"_pytype_": type_id})
+
+ def _deserialize_callback(self, serialized_obj):
+ type_id = serialized_obj["_pytype_"]
+ if isinstance(type_id, bytes):
+ # ARROW-4675: Python 2 serialized, read in Python 3
+ type_id = frombytes(type_id)
+
+ if "pickle" in serialized_obj:
+ # The object was pickled, so unpickle it.
+ obj = self.pickle_deserializer(serialized_obj["data"])
+ else:
+ assert type_id not in self.types_to_pickle
+ if type_id not in self.whitelisted_types:
+ msg = "Type ID " + type_id + " not registered in " \
+ "deserialization callback"
+ raise DeserializationCallbackError(msg, type_id)
+ type_ = self.whitelisted_types[type_id]
+ if type_id in self.custom_deserializers:
+ obj = self.custom_deserializers[type_id](
+ serialized_obj["data"])
+ else:
+ # In this case, serialized_obj should just be
+ # the __dict__ field.
+ if "_pa_getnewargs_" in serialized_obj:
+ obj = type_.__new__(
+ type_, *serialized_obj["_pa_getnewargs_"])
+ else:
+ obj = type_.__new__(type_)
+ serialized_obj.pop("_pytype_")
+ obj.__dict__.update(serialized_obj)
+ return obj
+
+ def serialize(self, obj):
+ """
+ Call pyarrow.serialize and pass this SerializationContext.
+ """
+ return serialize(obj, context=self)
+
+ def serialize_to(self, object value, sink):
+ """
+ Call pyarrow.serialize_to and pass this SerializationContext.
+ """
+ return serialize_to(value, sink, context=self)
+
+ def deserialize(self, what):
+ """
+ Call pyarrow.deserialize and pass this SerializationContext.
+ """
+ return deserialize(what, context=self)
+
+ def deserialize_components(self, what):
+ """
+ Call pyarrow.deserialize_components and pass this SerializationContext.
+ """
+ return deserialize_components(what, context=self)
+
+
+_default_serialization_context = SerializationContext()
+_default_context_initialized = False
+
+
+def _get_default_context():
+ global _default_context_initialized
+ from pyarrow.serialization import _register_default_serialization_handlers
+ if not _default_context_initialized:
+ _register_default_serialization_handlers(
+ _default_serialization_context)
+ _default_context_initialized = True
+ return _default_serialization_context
+
+
+cdef class SerializedPyObject(_Weakrefable):
+ """
+ Arrow-serialized representation of Python object.
+ """
+ cdef:
+ CSerializedPyObject data
+
+ cdef readonly:
+ object base
+
+ @property
+ def total_bytes(self):
+ cdef CMockOutputStream mock_stream
+ with nogil:
+ check_status(self.data.WriteTo(&mock_stream))
+
+ return mock_stream.GetExtentBytesWritten()
+
+ def write_to(self, sink):
+ """
+ Write serialized object to a sink.
+ """
+ cdef shared_ptr[COutputStream] stream
+ get_writer(sink, &stream)
+ self._write_to(stream.get())
+
+ cdef _write_to(self, COutputStream* stream):
+ with nogil:
+ check_status(self.data.WriteTo(stream))
+
+ def deserialize(self, SerializationContext context=None):
+ """
+ Convert back to Python object.
+ """
+ cdef PyObject* result
+
+ if context is None:
+ context = _get_default_context()
+
+ with nogil:
+ check_status(DeserializeObject(context, self.data,
+ <PyObject*> self.base, &result))
+
+ # PyObject_to_object is necessary to avoid a memory leak;
+        # also unpack the single-element list the object was wrapped in
+        # by _serialize
+ return PyObject_to_object(result)[0]
+
+ def to_buffer(self, nthreads=1):
+ """
+ Write serialized data as Buffer.
+ """
+ cdef Buffer output = allocate_buffer(self.total_bytes)
+ sink = FixedSizeBufferWriter(output)
+ if nthreads > 1:
+ sink.set_memcopy_threads(nthreads)
+ self.write_to(sink)
+ return output
+
+ @staticmethod
+ def from_components(components):
+ """
+ Reconstruct SerializedPyObject from output of
+ SerializedPyObject.to_components.
+ """
+ cdef:
+ int num_tensors = components['num_tensors']
+ int num_ndarrays = components['num_ndarrays']
+ int num_buffers = components['num_buffers']
+ list buffers = components['data']
+ SparseTensorCounts num_sparse_tensors = SparseTensorCounts()
+ SerializedPyObject result = SerializedPyObject()
+
+ num_sparse_tensors.coo = components['num_sparse_tensors']['coo']
+ num_sparse_tensors.csr = components['num_sparse_tensors']['csr']
+ num_sparse_tensors.csc = components['num_sparse_tensors']['csc']
+ num_sparse_tensors.csf = components['num_sparse_tensors']['csf']
+ num_sparse_tensors.ndim_csf = \
+ components['num_sparse_tensors']['ndim_csf']
+
+ with nogil:
+ check_status(GetSerializedFromComponents(num_tensors,
+ num_sparse_tensors,
+ num_ndarrays,
+ num_buffers,
+ buffers, &result.data))
+
+ return result
+
+ def to_components(self, memory_pool=None):
+ """
+ Return the decomposed dict representation of the serialized object
+ containing a collection of Buffer objects which maximize opportunities
+ for zero-copy.
+
+ Parameters
+ ----------
+        memory_pool : MemoryPool, default None
+            Pool to use for necessary allocations.
+
+        Returns
+        -------
+        components : dict
+            The decomposed components of the serialized object.
+        """
+ cdef PyObject* result
+ cdef CMemoryPool* c_pool = maybe_unbox_memory_pool(memory_pool)
+
+ with nogil:
+ check_status(self.data.GetComponents(c_pool, &result))
+
+ return PyObject_to_object(result)
+
+
+def serialize(object value, SerializationContext context=None):
+ """
+ DEPRECATED: Serialize a general Python sequence for transient storage
+ and transport.
+
+ .. deprecated:: 2.0
+ The custom serialization functionality is deprecated in pyarrow 2.0,
+ and will be removed in a future version. Use the standard library
+ ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for
+ more).
+
+ Notes
+ -----
+ This function produces data that is incompatible with the standard
+ Arrow IPC binary protocol, i.e. it cannot be used with ipc.open_stream or
+ ipc.open_file. You can use deserialize, deserialize_from, or
+ deserialize_components to read it.
+
+ Parameters
+ ----------
+ value : object
+ Python object for the sequence that is to be serialized.
+ context : SerializationContext
+ Custom serialization and deserialization context, uses a default
+ context with some standard type handlers if not specified.
+
+ Returns
+ -------
+ serialized : SerializedPyObject
+
+ """
+ _deprecate_serialization("serialize")
+ return _serialize(value, context)
+
+
+def _serialize(object value, SerializationContext context=None):
+ cdef SerializedPyObject serialized = SerializedPyObject()
+ wrapped_value = [value]
+
+ if context is None:
+ context = _get_default_context()
+
+ with nogil:
+ check_status(SerializeObject(context, wrapped_value, &serialized.data))
+ return serialized
+
+
+def serialize_to(object value, sink, SerializationContext context=None):
+ """
+ DEPRECATED: Serialize a Python sequence to a file.
+
+ .. deprecated:: 2.0
+ The custom serialization functionality is deprecated in pyarrow 2.0,
+ and will be removed in a future version. Use the standard library
+ ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for
+ more).
+
+ Parameters
+ ----------
+ value : object
+ Python object for the sequence that is to be serialized.
+ sink : NativeFile or file-like
+ File the sequence will be written to.
+ context : SerializationContext
+ Custom serialization and deserialization context, uses a default
+ context with some standard type handlers if not specified.
+ """
+ _deprecate_serialization("serialize_to")
+ serialized = _serialize(value, context)
+ serialized.write_to(sink)
+
+
+def read_serialized(source, base=None):
+ """
+ DEPRECATED: Read serialized Python sequence from file-like object.
+
+ .. deprecated:: 2.0
+ The custom serialization functionality is deprecated in pyarrow 2.0,
+ and will be removed in a future version. Use the standard library
+ ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for
+ more).
+
+ Parameters
+ ----------
+ source : NativeFile
+ File to read the sequence from.
+ base : object
+ This object will be the base object of all the numpy arrays
+ contained in the sequence.
+
+ Returns
+ -------
+    serialized : SerializedPyObject
+        The serialized data.
+ """
+ _deprecate_serialization("read_serialized")
+ return _read_serialized(source, base=base)
+
+
+def _read_serialized(source, base=None):
+ cdef shared_ptr[CRandomAccessFile] stream
+ get_reader(source, True, &stream)
+
+ cdef SerializedPyObject serialized = SerializedPyObject()
+ serialized.base = base
+ with nogil:
+ check_status(ReadSerializedObject(stream.get(), &serialized.data))
+
+ return serialized
+
+
+def deserialize_from(source, object base, SerializationContext context=None):
+ """
+ DEPRECATED: Deserialize a Python sequence from a file.
+
+ .. deprecated:: 2.0
+ The custom serialization functionality is deprecated in pyarrow 2.0,
+ and will be removed in a future version. Use the standard library
+ ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for
+ more).
+
+    This can only interact with data produced by pyarrow.serialize or
+ pyarrow.serialize_to.
+
+ Parameters
+ ----------
+ source : NativeFile
+ File to read the sequence from.
+ base : object
+ This object will be the base object of all the numpy arrays
+ contained in the sequence.
+ context : SerializationContext
+ Custom serialization and deserialization context.
+
+ Returns
+ -------
+ object
+ Python object for the deserialized sequence.
+ """
+ _deprecate_serialization("deserialize_from")
+ serialized = _read_serialized(source, base=base)
+ return serialized.deserialize(context)
+
+
+def deserialize_components(components, SerializationContext context=None):
+ """
+ DEPRECATED: Reconstruct Python object from output of
+ SerializedPyObject.to_components.
+
+ .. deprecated:: 2.0
+ The custom serialization functionality is deprecated in pyarrow 2.0,
+ and will be removed in a future version. Use the standard library
+ ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for
+ more).
+
+ Parameters
+ ----------
+ components : dict
+ Output of SerializedPyObject.to_components
+ context : SerializationContext, default None
+
+ Returns
+ -------
+ object : the Python object that was originally serialized
+ """
+ _deprecate_serialization("deserialize_components")
+ serialized = SerializedPyObject.from_components(components)
+ return serialized.deserialize(context)
+
+
+def deserialize(obj, SerializationContext context=None):
+ """
+ DEPRECATED: Deserialize Python object from Buffer or other Python
+ object supporting the buffer protocol.
+
+ .. deprecated:: 2.0
+ The custom serialization functionality is deprecated in pyarrow 2.0,
+ and will be removed in a future version. Use the standard library
+ ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for
+ more).
+
+ This only can interact with data produced by pyarrow.serialize or
+ pyarrow.serialize_to.
+
+ Parameters
+ ----------
+ obj : pyarrow.Buffer or Python object supporting buffer protocol
+ context : SerializationContext
+ Custom serialization and deserialization context.
+
+ Returns
+ -------
+ deserialized : object
+ """
+ _deprecate_serialization("deserialize")
+ return _deserialize(obj, context=context)
+
+
+def _deserialize(obj, SerializationContext context=None):
+ source = BufferReader(obj)
+ serialized = _read_serialized(source, base=obj)
+ return serialized.deserialize(context)
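+
+
+# A minimal round-trip sketch for the deprecated API above (values are
+# illustrative only); both calls emit a FutureWarning:
+#
+#   import pyarrow as pa
+#   buf = pa.serialize({'a': [1, 2, 3]}).to_buffer()
+#   restored = pa.deserialize(buf)   # -> {'a': [1, 2, 3]}
+#
+# pickle or the Arrow IPC stream/file APIs are the supported replacements.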
diff --git a/src/arrow/python/pyarrow/serialization.py b/src/arrow/python/pyarrow/serialization.py
new file mode 100644
index 000000000..d59a13166
--- /dev/null
+++ b/src/arrow/python/pyarrow/serialization.py
@@ -0,0 +1,504 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import collections
+import warnings
+
+import numpy as np
+
+import pyarrow as pa
+from pyarrow.lib import SerializationContext, py_buffer, builtin_pickle
+
+try:
+ import cloudpickle
+except ImportError:
+ cloudpickle = builtin_pickle
+
+
+try:
+    # This function is only available in recent numpy versions.
+ # See also: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py
+ from numpy.lib.format import descr_to_dtype
+except ImportError:
+ def descr_to_dtype(descr):
+ '''
+ descr may be stored as dtype.descr, which is a list of (name, format,
+ [shape]) tuples where format may be a str or a tuple. Offsets are not
+ explicitly saved, rather empty fields with name, format == '', '|Vn'
+ are added as padding. This function reverses the process, eliminating
+ the empty padding fields.
+ '''
+ if isinstance(descr, str):
+ # No padding removal needed
+ return np.dtype(descr)
+ elif isinstance(descr, tuple):
+ # subtype, will always have a shape descr[1]
+ dt = descr_to_dtype(descr[0])
+ return np.dtype((dt, descr[1]))
+ fields = []
+ offset = 0
+ for field in descr:
+ if len(field) == 2:
+ name, descr_str = field
+ dt = descr_to_dtype(descr_str)
+ else:
+ name, descr_str, shape = field
+ dt = np.dtype((descr_to_dtype(descr_str), shape))
+
+            # Ignore padding bytes, which will be void bytes with '' as name.
+            # (Once support for blank names is removed, only "if name == ''"
+            # will be needed.)
+ is_pad = (name == '' and dt.type is np.void and dt.names is None)
+ if not is_pad:
+ fields.append((name, dt, offset))
+
+ offset += dt.itemsize
+
+ names, formats, offsets = zip(*fields)
+ # names may be (title, names) tuples
+ nametups = (n if isinstance(n, tuple) else (None, n) for n in names)
+ titles, names = zip(*nametups)
+ return np.dtype({'names': names, 'formats': formats, 'titles': titles,
+ 'offsets': offsets, 'itemsize': offset})
+
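+
+# Round-trip sketch for dtype_to_descr/descr_to_dtype (illustrative; the
+# fallback above mirrors the numpy implementation):
+#
+#   descr = np.lib.format.dtype_to_descr(np.dtype([('x', '<i4'),
+#                                                   ('y', '<f8')]))
+#   descr_to_dtype(descr)   # -> dtype([('x', '<i4'), ('y', '<f8')])
+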
+
+def _deprecate_serialization(name):
+ msg = (
+ "'pyarrow.{}' is deprecated as of 2.0.0 and will be removed in a "
+ "future version. Use pickle or the pyarrow IPC functionality instead."
+ ).format(name)
+ warnings.warn(msg, FutureWarning, stacklevel=3)
+
+
+# ----------------------------------------------------------------------
+# Set up serialization for numpy with dtype object (primitive types are
+# handled efficiently with Arrow's Tensor facilities, see
+# python_to_arrow.cc)
+
+def _serialize_numpy_array_list(obj):
+ if obj.dtype.str != '|O':
+        # Make the array c_contiguous if necessary so that we can change
+        # the view.
+ if not obj.flags.c_contiguous:
+ obj = np.ascontiguousarray(obj)
+ return obj.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype)
+ else:
+ return obj.tolist(), np.lib.format.dtype_to_descr(obj.dtype)
+
+
+def _deserialize_numpy_array_list(data):
+ if data[1] != '|O':
+ assert data[0].dtype == np.uint8
+ return data[0].view(descr_to_dtype(data[1]))
+ else:
+ return np.array(data[0], dtype=np.dtype(data[1]))
+
+
+def _serialize_numpy_matrix(obj):
+ if obj.dtype.str != '|O':
+        # Make the array c_contiguous if necessary so that we can change
+        # the view.
+ if not obj.flags.c_contiguous:
+ obj = np.ascontiguousarray(obj.A)
+ return obj.A.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype)
+ else:
+ return obj.A.tolist(), np.lib.format.dtype_to_descr(obj.dtype)
+
+
+def _deserialize_numpy_matrix(data):
+ if data[1] != '|O':
+ assert data[0].dtype == np.uint8
+ return np.matrix(data[0].view(descr_to_dtype(data[1])),
+ copy=False)
+ else:
+ return np.matrix(data[0], dtype=np.dtype(data[1]), copy=False)
+
+
+# ----------------------------------------------------------------------
+# pyarrow.RecordBatch-specific serialization matters
+
+def _serialize_pyarrow_recordbatch(batch):
+ output_stream = pa.BufferOutputStream()
+ with pa.RecordBatchStreamWriter(output_stream, schema=batch.schema) as wr:
+ wr.write_batch(batch)
+ return output_stream.getvalue() # This will also close the stream.
+
+
+def _deserialize_pyarrow_recordbatch(buf):
+ with pa.RecordBatchStreamReader(buf) as reader:
+ return reader.read_next_batch()
+
+
+# ----------------------------------------------------------------------
+# pyarrow.Array-specific serialization matters
+
+def _serialize_pyarrow_array(array):
+    # TODO(suquark): implement more efficient array serialization.
+ batch = pa.RecordBatch.from_arrays([array], [''])
+ return _serialize_pyarrow_recordbatch(batch)
+
+
+def _deserialize_pyarrow_array(buf):
+    # TODO(suquark): implement more efficient array deserialization.
+ batch = _deserialize_pyarrow_recordbatch(buf)
+ return batch.columns[0]
+
+
+# ----------------------------------------------------------------------
+# pyarrow.Table-specific serialization matters
+
+def _serialize_pyarrow_table(table):
+ output_stream = pa.BufferOutputStream()
+ with pa.RecordBatchStreamWriter(output_stream, schema=table.schema) as wr:
+ wr.write_table(table)
+ return output_stream.getvalue() # This will also close the stream.
+
+
+def _deserialize_pyarrow_table(buf):
+ with pa.RecordBatchStreamReader(buf) as reader:
+ return reader.read_all()
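+
+
+# The two helpers above are thin wrappers around the public IPC stream API;
+# an equivalent round trip looks like this (sketch; ``table`` is assumed to
+# be a pyarrow.Table):
+#
+#   sink = pa.BufferOutputStream()
+#   with pa.RecordBatchStreamWriter(sink, schema=table.schema) as writer:
+#       writer.write_table(table)
+#   buf = sink.getvalue()
+#   with pa.RecordBatchStreamReader(buf) as reader:
+#       restored = reader.read_all()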
+
+
+def _pickle_to_buffer(x):
+ pickled = builtin_pickle.dumps(x, protocol=builtin_pickle.HIGHEST_PROTOCOL)
+ return py_buffer(pickled)
+
+
+def _load_pickle_from_buffer(data):
+ as_memoryview = memoryview(data)
+ return builtin_pickle.loads(as_memoryview)
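+
+
+# Sketch of how the two pickle helpers round-trip an arbitrary picklable
+# object through a pyarrow Buffer (illustrative):
+#
+#   buf = _pickle_to_buffer({'key': 'value'})   # pyarrow.Buffer
+#   _load_pickle_from_buffer(buf)               # -> {'key': 'value'}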
+
+
+# ----------------------------------------------------------------------
+# pandas-specific serialization matters
+
+def _register_custom_pandas_handlers(context):
+ # ARROW-1784, faster path for pandas-only visibility
+
+ try:
+ import pandas as pd
+ except ImportError:
+ return
+
+ import pyarrow.pandas_compat as pdcompat
+
+ sparse_type_error_msg = (
+ '{0} serialization is not supported.\n'
+ 'Note that {0} is planned to be deprecated '
+        'in future pandas releases.\n'
+ 'See https://github.com/pandas-dev/pandas/issues/19239 '
+ 'for more information.'
+ )
+
+ def _serialize_pandas_dataframe(obj):
+ if (pdcompat._pandas_api.has_sparse and
+ isinstance(obj, pd.SparseDataFrame)):
+ raise NotImplementedError(
+ sparse_type_error_msg.format('SparseDataFrame')
+ )
+
+ return pdcompat.dataframe_to_serialized_dict(obj)
+
+ def _deserialize_pandas_dataframe(data):
+ return pdcompat.serialized_dict_to_dataframe(data)
+
+ def _serialize_pandas_series(obj):
+ if (pdcompat._pandas_api.has_sparse and
+ isinstance(obj, pd.SparseSeries)):
+ raise NotImplementedError(
+ sparse_type_error_msg.format('SparseSeries')
+ )
+
+ return _serialize_pandas_dataframe(pd.DataFrame({obj.name: obj}))
+
+ def _deserialize_pandas_series(data):
+ deserialized = _deserialize_pandas_dataframe(data)
+ return deserialized[deserialized.columns[0]]
+
+ context.register_type(
+ pd.Series, 'pd.Series',
+ custom_serializer=_serialize_pandas_series,
+ custom_deserializer=_deserialize_pandas_series)
+
+ context.register_type(
+ pd.Index, 'pd.Index',
+ custom_serializer=_pickle_to_buffer,
+ custom_deserializer=_load_pickle_from_buffer)
+
+ if hasattr(pd.core, 'arrays'):
+ if hasattr(pd.core.arrays, 'interval'):
+ context.register_type(
+ pd.core.arrays.interval.IntervalArray,
+ 'pd.core.arrays.interval.IntervalArray',
+ custom_serializer=_pickle_to_buffer,
+ custom_deserializer=_load_pickle_from_buffer)
+
+ if hasattr(pd.core.arrays, 'period'):
+ context.register_type(
+ pd.core.arrays.period.PeriodArray,
+ 'pd.core.arrays.period.PeriodArray',
+ custom_serializer=_pickle_to_buffer,
+ custom_deserializer=_load_pickle_from_buffer)
+
+ if hasattr(pd.core.arrays, 'datetimes'):
+ context.register_type(
+ pd.core.arrays.datetimes.DatetimeArray,
+ 'pd.core.arrays.datetimes.DatetimeArray',
+ custom_serializer=_pickle_to_buffer,
+ custom_deserializer=_load_pickle_from_buffer)
+
+ context.register_type(
+ pd.DataFrame, 'pd.DataFrame',
+ custom_serializer=_serialize_pandas_dataframe,
+ custom_deserializer=_deserialize_pandas_dataframe)
+
+
+def register_torch_serialization_handlers(serialization_context):
+ # ----------------------------------------------------------------------
+ # Set up serialization for pytorch tensors
+ _deprecate_serialization("register_torch_serialization_handlers")
+
+ try:
+ import torch
+
+ def _serialize_torch_tensor(obj):
+ if obj.is_sparse:
+ return pa.SparseCOOTensor.from_numpy(
+ obj._values().detach().numpy(),
+ obj._indices().detach().numpy().T,
+ shape=list(obj.shape))
+ else:
+ return obj.detach().numpy()
+
+ def _deserialize_torch_tensor(data):
+ if isinstance(data, pa.SparseCOOTensor):
+ return torch.sparse_coo_tensor(
+ indices=data.to_numpy()[1].T,
+ values=data.to_numpy()[0][:, 0],
+ size=data.shape)
+ else:
+ return torch.from_numpy(data)
+
+ for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
+ torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
+ torch.IntTensor, torch.LongTensor, torch.Tensor]:
+ serialization_context.register_type(
+ t, "torch." + t.__name__,
+ custom_serializer=_serialize_torch_tensor,
+ custom_deserializer=_deserialize_torch_tensor)
+ except ImportError:
+ # no torch
+ pass
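+
+
+# Usage sketch (assumes torch is installed); every call below emits a
+# FutureWarning because the custom serialization machinery is deprecated:
+#
+#   import torch
+#   context = pa.default_serialization_context()
+#   register_torch_serialization_handlers(context)
+#   buf = pa.serialize(torch.arange(3), context=context).to_buffer()
+#   tensor = pa.deserialize(buf, context=context)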
+
+
+def _register_collections_serialization_handlers(serialization_context):
+ def _serialize_deque(obj):
+ return list(obj)
+
+ def _deserialize_deque(data):
+ return collections.deque(data)
+
+ serialization_context.register_type(
+ collections.deque, "collections.deque",
+ custom_serializer=_serialize_deque,
+ custom_deserializer=_deserialize_deque)
+
+ def _serialize_ordered_dict(obj):
+ return list(obj.keys()), list(obj.values())
+
+ def _deserialize_ordered_dict(data):
+ return collections.OrderedDict(zip(data[0], data[1]))
+
+ serialization_context.register_type(
+ collections.OrderedDict, "collections.OrderedDict",
+ custom_serializer=_serialize_ordered_dict,
+ custom_deserializer=_deserialize_ordered_dict)
+
+ def _serialize_default_dict(obj):
+ return list(obj.keys()), list(obj.values()), obj.default_factory
+
+ def _deserialize_default_dict(data):
+ return collections.defaultdict(data[2], zip(data[0], data[1]))
+
+ serialization_context.register_type(
+ collections.defaultdict, "collections.defaultdict",
+ custom_serializer=_serialize_default_dict,
+ custom_deserializer=_deserialize_default_dict)
+
+ def _serialize_counter(obj):
+ return list(obj.keys()), list(obj.values())
+
+ def _deserialize_counter(data):
+ return collections.Counter(dict(zip(data[0], data[1])))
+
+ serialization_context.register_type(
+ collections.Counter, "collections.Counter",
+ custom_serializer=_serialize_counter,
+ custom_deserializer=_deserialize_counter)
+
+
+# ----------------------------------------------------------------------
+# Set up serialization for scipy sparse matrices. Primitive types are handled
+# efficiently with Arrow's SparseTensor facilities, see numpy_convert.cc)
+
+def _register_scipy_handlers(serialization_context):
+ try:
+ from scipy.sparse import (csr_matrix, csc_matrix, coo_matrix,
+ isspmatrix_coo, isspmatrix_csr,
+ isspmatrix_csc, isspmatrix)
+
+ def _serialize_scipy_sparse(obj):
+ if isspmatrix_coo(obj):
+ return 'coo', pa.SparseCOOTensor.from_scipy(obj)
+
+ elif isspmatrix_csr(obj):
+ return 'csr', pa.SparseCSRMatrix.from_scipy(obj)
+
+ elif isspmatrix_csc(obj):
+ return 'csc', pa.SparseCSCMatrix.from_scipy(obj)
+
+            elif isspmatrix(obj):
+                # Other scipy sparse formats are converted to COO first.
+                return 'coo', pa.SparseCOOTensor.from_scipy(obj.tocoo())
+
+            else:
+                raise NotImplementedError(
+                    "Serialization of {} is not supported.".format(type(obj)))
+
+ def _deserialize_scipy_sparse(data):
+ if data[0] == 'coo':
+ return data[1].to_scipy()
+
+ elif data[0] == 'csr':
+ return data[1].to_scipy()
+
+ elif data[0] == 'csc':
+ return data[1].to_scipy()
+
+ else:
+ return data[1].to_scipy()
+
+ serialization_context.register_type(
+ coo_matrix, 'scipy.sparse.coo.coo_matrix',
+ custom_serializer=_serialize_scipy_sparse,
+ custom_deserializer=_deserialize_scipy_sparse)
+
+ serialization_context.register_type(
+ csr_matrix, 'scipy.sparse.csr.csr_matrix',
+ custom_serializer=_serialize_scipy_sparse,
+ custom_deserializer=_deserialize_scipy_sparse)
+
+ serialization_context.register_type(
+ csc_matrix, 'scipy.sparse.csc.csc_matrix',
+ custom_serializer=_serialize_scipy_sparse,
+ custom_deserializer=_deserialize_scipy_sparse)
+
+ except ImportError:
+ # no scipy
+ pass
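+
+
+# Usage sketch (assumes scipy is installed); the default context registers
+# these handlers automatically:
+#
+#   from scipy.sparse import coo_matrix
+#   context = pa.default_serialization_context()
+#   mat = coo_matrix(([1.0, 2.0], ([0, 1], [1, 0])), shape=(2, 2))
+#   buf = pa.serialize(mat, context=context).to_buffer()
+#   restored = pa.deserialize(buf, context=context)   # COO matrix again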
+
+
+# ----------------------------------------------------------------------
+# Set up serialization for pydata/sparse tensors.
+
+def _register_pydata_sparse_handlers(serialization_context):
+ try:
+ import sparse
+
+ def _serialize_pydata_sparse(obj):
+ if isinstance(obj, sparse.COO):
+ return 'coo', pa.SparseCOOTensor.from_pydata_sparse(obj)
+ else:
+                raise NotImplementedError(
+                    "Serialization of {} is not supported.".format(type(obj)))
+
+ def _deserialize_pydata_sparse(data):
+ if data[0] == 'coo':
+ data_array, coords = data[1].to_numpy()
+ return sparse.COO(
+ data=data_array[:, 0],
+ coords=coords.T, shape=data[1].shape)
+
+ serialization_context.register_type(
+ sparse.COO, 'sparse.COO',
+ custom_serializer=_serialize_pydata_sparse,
+ custom_deserializer=_deserialize_pydata_sparse)
+
+ except ImportError:
+ # no pydata/sparse
+ pass
+
+
+def _register_default_serialization_handlers(serialization_context):
+
+ # ----------------------------------------------------------------------
+ # Set up serialization for primitive datatypes
+
+ # TODO(pcm): This is currently a workaround until arrow supports
+ # arbitrary precision integers. This is only called on long integers,
+ # see the associated case in the append method in python_to_arrow.cc
+ serialization_context.register_type(
+ int, "int",
+ custom_serializer=lambda obj: str(obj),
+ custom_deserializer=lambda data: int(data))
+
+ serialization_context.register_type(
+ type(lambda: 0), "function",
+ pickle=True)
+
+ serialization_context.register_type(type, "type", pickle=True)
+
+ serialization_context.register_type(
+ np.matrix, 'np.matrix',
+ custom_serializer=_serialize_numpy_matrix,
+ custom_deserializer=_deserialize_numpy_matrix)
+
+ serialization_context.register_type(
+ np.ndarray, 'np.array',
+ custom_serializer=_serialize_numpy_array_list,
+ custom_deserializer=_deserialize_numpy_array_list)
+
+ serialization_context.register_type(
+ pa.Array, 'pyarrow.Array',
+ custom_serializer=_serialize_pyarrow_array,
+ custom_deserializer=_deserialize_pyarrow_array)
+
+ serialization_context.register_type(
+ pa.RecordBatch, 'pyarrow.RecordBatch',
+ custom_serializer=_serialize_pyarrow_recordbatch,
+ custom_deserializer=_deserialize_pyarrow_recordbatch)
+
+ serialization_context.register_type(
+ pa.Table, 'pyarrow.Table',
+ custom_serializer=_serialize_pyarrow_table,
+ custom_deserializer=_deserialize_pyarrow_table)
+
+ _register_collections_serialization_handlers(serialization_context)
+ _register_custom_pandas_handlers(serialization_context)
+ _register_scipy_handlers(serialization_context)
+ _register_pydata_sparse_handlers(serialization_context)
+
+
+def register_default_serialization_handlers(serialization_context):
+ _deprecate_serialization("register_default_serialization_handlers")
+ _register_default_serialization_handlers(serialization_context)
+
+
+def default_serialization_context():
+ _deprecate_serialization("default_serialization_context")
+ context = SerializationContext()
+ _register_default_serialization_handlers(context)
+ return context
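+
+
+# End-to-end sketch of the deprecated workflow (illustrative values only;
+# every call emits a FutureWarning):
+#
+#   import collections
+#   context = pa.default_serialization_context()
+#   data = collections.OrderedDict([('a', np.arange(3)), ('b', 'text')])
+#   buf = pa.serialize(data, context=context).to_buffer()
+#   restored = pa.deserialize(buf, context=context)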
diff --git a/src/arrow/python/pyarrow/table.pxi b/src/arrow/python/pyarrow/table.pxi
new file mode 100644
index 000000000..8105ce482
--- /dev/null
+++ b/src/arrow/python/pyarrow/table.pxi
@@ -0,0 +1,2389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import warnings
+
+
+cdef class ChunkedArray(_PandasConvertible):
+ """
+    An array-like composed of a (possibly empty) collection of pyarrow.Arrays
+
+ Warnings
+ --------
+ Do not call this class's constructor directly.
+ """
+
+ def __cinit__(self):
+ self.chunked_array = NULL
+
+ def __init__(self):
+ raise TypeError("Do not call ChunkedArray's constructor directly, use "
+ "`chunked_array` function instead.")
+
+ cdef void init(self, const shared_ptr[CChunkedArray]& chunked_array):
+ self.sp_chunked_array = chunked_array
+ self.chunked_array = chunked_array.get()
+
+ def __reduce__(self):
+ return chunked_array, (self.chunks, self.type)
+
+ @property
+ def data(self):
+ import warnings
+ warnings.warn("Calling .data on ChunkedArray is provided for "
+ "compatibility after Column was removed, simply drop "
+ "this attribute", FutureWarning)
+ return self
+
+ @property
+ def type(self):
+ return pyarrow_wrap_data_type(self.sp_chunked_array.get().type())
+
+ def length(self):
+ return self.chunked_array.length()
+
+ def __len__(self):
+ return self.length()
+
+ def __repr__(self):
+ type_format = object.__repr__(self)
+ return '{0}\n{1}'.format(type_format, str(self))
+
+ def to_string(self, *, int indent=0, int window=10,
+ c_bool skip_new_lines=False):
+ """
+ Render a "pretty-printed" string representation of the ChunkedArray
+
+ Parameters
+ ----------
+        indent : int
+            How much to indent the content of the array to the right,
+            by default ``0``.
+        window : int
+            How many items to preview at the beginning and end
+            of the array when the array is bigger than the window.
+            The other elements are replaced by an ellipsis.
+        skip_new_lines : bool
+            If True, render the array as a single line of text;
+            otherwise put each element on its own line.
+ """
+ cdef:
+ c_string result
+ PrettyPrintOptions options
+
+ with nogil:
+ options = PrettyPrintOptions(indent, window)
+ options.skip_new_lines = skip_new_lines
+ check_status(
+ PrettyPrint(
+ deref(self.chunked_array),
+ options,
+ &result
+ )
+ )
+
+ return frombytes(result, safe=True)
+
+ def format(self, **kwargs):
+ import warnings
+ warnings.warn('ChunkedArray.format is deprecated, '
+ 'use ChunkedArray.to_string')
+ return self.to_string(**kwargs)
+
+ def __str__(self):
+ return self.to_string()
+
+ def validate(self, *, full=False):
+ """
+ Perform validation checks. An exception is raised if validation fails.
+
+ By default only cheap validation checks are run. Pass `full=True`
+ for thorough validation checks (potentially O(n)).
+
+ Parameters
+ ----------
+ full: bool, default False
+ If True, run expensive checks, otherwise cheap checks only.
+
+ Raises
+ ------
+ ArrowInvalid
+ """
+ if full:
+ with nogil:
+ check_status(self.sp_chunked_array.get().ValidateFull())
+ else:
+ with nogil:
+ check_status(self.sp_chunked_array.get().Validate())
+
+ @property
+ def null_count(self):
+ """
+ Number of null entries
+
+ Returns
+ -------
+ int
+ """
+ return self.chunked_array.null_count()
+
+ @property
+ def nbytes(self):
+ """
+ Total number of bytes consumed by the elements of the chunked array.
+ """
+ size = 0
+ for chunk in self.iterchunks():
+ size += chunk.nbytes
+ return size
+
+ def __sizeof__(self):
+ return super(ChunkedArray, self).__sizeof__() + self.nbytes
+
+ def __iter__(self):
+ for chunk in self.iterchunks():
+ for item in chunk:
+ yield item
+
+ def __getitem__(self, key):
+ """
+ Slice or return value at given index
+
+ Parameters
+ ----------
+ key : integer or slice
+ Slices with step not equal to 1 (or None) will produce a copy
+ rather than a zero-copy view
+
+ Returns
+ -------
+ value : Scalar (index) or ChunkedArray (slice)
+ """
+ if isinstance(key, slice):
+ return _normalize_slice(self, key)
+
+ return self.getitem(_normalize_index(key, self.chunked_array.length()))
+
+ cdef getitem(self, int64_t index):
+ cdef int j
+
+ for j in range(self.num_chunks):
+ if index < self.chunked_array.chunk(j).get().length():
+ return self.chunk(j)[index]
+ else:
+ index -= self.chunked_array.chunk(j).get().length()
+
+ def is_null(self, *, nan_is_null=False):
+ """
+ Return boolean array indicating the null values.
+
+ Parameters
+ ----------
+ nan_is_null : bool (optional, default False)
+ Whether floating-point NaN values should also be considered null.
+
+ Returns
+ -------
+ array : boolean Array or ChunkedArray
+ """
+ options = _pc().NullOptions(nan_is_null=nan_is_null)
+ return _pc().call_function('is_null', [self], options)
+
+ def is_valid(self):
+ """
+ Return boolean array indicating the non-null values.
+ """
+ return _pc().is_valid(self)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def fill_null(self, fill_value):
+ """
+ See pyarrow.compute.fill_null docstring for usage.
+ """
+ return _pc().fill_null(self, fill_value)
+
+ def equals(self, ChunkedArray other):
+ """
+ Return whether the contents of two chunked arrays are equal.
+
+ Parameters
+ ----------
+ other : pyarrow.ChunkedArray
+ Chunked array to compare against.
+
+ Returns
+ -------
+ are_equal : bool
+ """
+ if other is None:
+ return False
+
+ cdef:
+ CChunkedArray* this_arr = self.chunked_array
+ CChunkedArray* other_arr = other.chunked_array
+ c_bool result
+
+ with nogil:
+ result = this_arr.Equals(deref(other_arr))
+
+ return result
+
+ def _to_pandas(self, options, **kwargs):
+ return _array_like_to_pandas(self, options)
+
+ def to_numpy(self):
+ """
+ Return a NumPy copy of this array (experimental).
+
+ Returns
+ -------
+ array : numpy.ndarray
+ """
+ cdef:
+ PyObject* out
+ PandasOptions c_options
+ object values
+
+ if self.type.id == _Type_EXTENSION:
+ storage_array = chunked_array(
+ [chunk.storage for chunk in self.iterchunks()],
+ type=self.type.storage_type
+ )
+ return storage_array.to_numpy()
+
+ with nogil:
+ check_status(
+ ConvertChunkedArrayToPandas(
+ c_options,
+ self.sp_chunked_array,
+ self,
+ &out
+ )
+ )
+
+ # wrap_array_output uses pandas to convert to Categorical, here
+ # always convert to numpy array
+ values = PyObject_to_object(out)
+
+ if isinstance(values, dict):
+ values = np.take(values['dictionary'], values['indices'])
+
+ return values
+
+ def __array__(self, dtype=None):
+ values = self.to_numpy()
+ if dtype is None:
+ return values
+ return values.astype(dtype)
+
+ def cast(self, object target_type, safe=True):
+ """
+ Cast array values to another data type
+
+ See pyarrow.compute.cast for usage
+ """
+ return _pc().cast(self, target_type, safe=safe)
+
+ def dictionary_encode(self, null_encoding='mask'):
+ """
+ Compute dictionary-encoded representation of array
+
+ Returns
+ -------
+ pyarrow.ChunkedArray
+ Same chunking as the input, all chunks share a common dictionary.
+ """
+ options = _pc().DictionaryEncodeOptions(null_encoding)
+ return _pc().call_function('dictionary_encode', [self], options)
+
+ def flatten(self, MemoryPool memory_pool=None):
+ """
+ Flatten this ChunkedArray. If it has a struct type, the column is
+ flattened into one array per struct field.
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise use default pool
+
+ Returns
+ -------
+ result : List[ChunkedArray]
+ """
+ cdef:
+ vector[shared_ptr[CChunkedArray]] flattened
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+
+ with nogil:
+ flattened = GetResultValue(self.chunked_array.Flatten(pool))
+
+ return [pyarrow_wrap_chunked_array(col) for col in flattened]
+
+ def combine_chunks(self, MemoryPool memory_pool=None):
+ """
+ Flatten this ChunkedArray into a single non-chunked array.
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise use default pool
+
+ Returns
+ -------
+ result : Array
+ """
+        return concat_arrays(self.chunks, memory_pool=memory_pool)
+
+ def unique(self):
+ """
+ Compute distinct elements in array
+
+ Returns
+ -------
+ pyarrow.Array
+ """
+ return _pc().call_function('unique', [self])
+
+ def value_counts(self):
+ """
+ Compute counts of unique elements in array.
+
+ Returns
+ -------
+ An array of <input type "Values", int64_t "Counts"> structs
+ """
+ return _pc().call_function('value_counts', [self])
+
+ def slice(self, offset=0, length=None):
+ """
+ Compute zero-copy slice of this ChunkedArray
+
+ Parameters
+ ----------
+ offset : int, default 0
+ Offset from start of array to slice
+ length : int, default None
+            Length of slice (default is until end of array starting from
+ offset)
+
+ Returns
+ -------
+ sliced : ChunkedArray
+ """
+ cdef shared_ptr[CChunkedArray] result
+
+ if offset < 0:
+ raise IndexError('Offset must be non-negative')
+
+ offset = min(len(self), offset)
+ if length is None:
+ result = self.chunked_array.Slice(offset)
+ else:
+ result = self.chunked_array.Slice(offset, length)
+
+ return pyarrow_wrap_chunked_array(result)
+
+ def filter(self, mask, object null_selection_behavior="drop"):
+ """
+ Select values from a chunked array. See pyarrow.compute.filter for full
+ usage.
+ """
+ return _pc().filter(self, mask, null_selection_behavior)
+
+ def index(self, value, start=None, end=None, *, memory_pool=None):
+ """
+ Find the first index of a value.
+
+ See pyarrow.compute.index for full usage.
+ """
+ return _pc().index(self, value, start, end, memory_pool=memory_pool)
+
+ def take(self, object indices):
+ """
+ Select values from a chunked array. See pyarrow.compute.take for full
+ usage.
+ """
+ return _pc().take(self, indices)
+
+ def drop_null(self):
+ """
+ Remove missing values from a chunked array.
+ See pyarrow.compute.drop_null for full description.
+ """
+ return _pc().drop_null(self)
+
+ def unify_dictionaries(self, MemoryPool memory_pool=None):
+ """
+ Unify dictionaries across all chunks.
+
+ This method returns an equivalent chunked array, but where all
+ chunks share the same dictionary values. Dictionary indices are
+ transposed accordingly.
+
+ If there are no dictionaries in the chunked array, it is returned
+ unchanged.
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise use default pool
+
+ Returns
+ -------
+ result : ChunkedArray
+ """
+ cdef:
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+ shared_ptr[CChunkedArray] c_result
+
+ with nogil:
+ c_result = GetResultValue(CDictionaryUnifier.UnifyChunkedArray(
+ self.sp_chunked_array, pool))
+
+ return pyarrow_wrap_chunked_array(c_result)
+
+ @property
+ def num_chunks(self):
+ """
+ Number of underlying chunks
+
+ Returns
+ -------
+ int
+ """
+ return self.chunked_array.num_chunks()
+
+ def chunk(self, i):
+ """
+ Select a chunk by its index
+
+ Parameters
+ ----------
+ i : int
+
+ Returns
+ -------
+ pyarrow.Array
+ """
+ if i >= self.num_chunks or i < 0:
+ raise IndexError('Chunk index out of range.')
+
+ return pyarrow_wrap_array(self.chunked_array.chunk(i))
+
+ @property
+ def chunks(self):
+ return list(self.iterchunks())
+
+ def iterchunks(self):
+ for i in range(self.num_chunks):
+ yield self.chunk(i)
+
+ def to_pylist(self):
+ """
+ Convert to a list of native Python objects.
+ """
+ result = []
+ for i in range(self.num_chunks):
+ result += self.chunk(i).to_pylist()
+ return result
+
+
+def chunked_array(arrays, type=None):
+ """
+ Construct chunked array from list of array-like objects
+
+ Parameters
+ ----------
+ arrays : Array, list of Array, or values coercible to arrays
+ Must all be the same data type. Can be empty only if type also passed.
+ type : DataType or string coercible to DataType
+
+ Returns
+ -------
+ ChunkedArray
+ """
+ cdef:
+ Array arr
+ vector[shared_ptr[CArray]] c_arrays
+ shared_ptr[CChunkedArray] sp_chunked_array
+
+ type = ensure_type(type, allow_none=True)
+
+ if isinstance(arrays, Array):
+ arrays = [arrays]
+
+ for x in arrays:
+ arr = x if isinstance(x, Array) else array(x, type=type)
+
+ if type is None:
+            # Infer the type from the first chunk and coerce subsequent
+            # arrays to it; this allows more flexible chunked array
+            # construction and spares the inference overhead after the
+            # first chunk.
+ type = arr.type
+ else:
+ if arr.type != type:
+ raise TypeError(
+ "All array chunks must have type {}".format(type)
+ )
+
+ c_arrays.push_back(arr.sp_array)
+
+ if c_arrays.size() == 0 and type is None:
+ raise ValueError("When passing an empty collection of arrays "
+ "you must also pass the data type")
+
+ sp_chunked_array.reset(
+ new CChunkedArray(c_arrays, pyarrow_unwrap_data_type(type))
+ )
+ with nogil:
+ check_status(sp_chunked_array.get().Validate())
+
+ return pyarrow_wrap_chunked_array(sp_chunked_array)
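+
+
+# Usage sketch for chunked_array (illustrative values, user-level API):
+#
+#   arr = pa.chunked_array([[1, 2, 3], [4, 5]])
+#   arr.num_chunks               # -> 2
+#   arr.type                     # -> DataType(int64)
+#   arr.slice(1, 3).to_pylist()  # -> [2, 3, 4]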
+
+
+cdef _schema_from_arrays(arrays, names, metadata, shared_ptr[CSchema]* schema):
+ cdef:
+ Py_ssize_t K = len(arrays)
+ c_string c_name
+ shared_ptr[CDataType] c_type
+ shared_ptr[const CKeyValueMetadata] c_meta
+ vector[shared_ptr[CField]] c_fields
+
+ if metadata is not None:
+ c_meta = KeyValueMetadata(metadata).unwrap()
+
+ if K == 0:
+ if names is None or len(names) == 0:
+ schema.reset(new CSchema(c_fields, c_meta))
+ return arrays
+ else:
+ raise ValueError('Length of names ({}) does not match '
+ 'length of arrays ({})'.format(len(names), K))
+
+ c_fields.resize(K)
+
+ if names is None:
+ raise ValueError('Must pass names or schema when constructing '
+ 'Table or RecordBatch.')
+
+ if len(names) != K:
+ raise ValueError('Length of names ({}) does not match '
+ 'length of arrays ({})'.format(len(names), K))
+
+ converted_arrays = []
+ for i in range(K):
+ val = arrays[i]
+ if not isinstance(val, (Array, ChunkedArray)):
+ val = array(val)
+
+ c_type = (<DataType> val.type).sp_type
+
+ if names[i] is None:
+ c_name = b'None'
+ else:
+ c_name = tobytes(names[i])
+ c_fields[i].reset(new CField(c_name, c_type, True))
+ converted_arrays.append(val)
+
+ schema.reset(new CSchema(c_fields, c_meta))
+ return converted_arrays
+
+
+cdef _sanitize_arrays(arrays, names, schema, metadata,
+ shared_ptr[CSchema]* c_schema):
+ cdef Schema cy_schema
+ if schema is None:
+ converted_arrays = _schema_from_arrays(arrays, names, metadata,
+ c_schema)
+ else:
+ if names is not None:
+ raise ValueError('Cannot pass both schema and names')
+ if metadata is not None:
+ raise ValueError('Cannot pass both schema and metadata')
+ cy_schema = schema
+
+ if len(schema) != len(arrays):
+ raise ValueError('Schema and number of arrays unequal')
+
+ c_schema[0] = cy_schema.sp_schema
+ converted_arrays = []
+ for i, item in enumerate(arrays):
+ item = asarray(item, type=schema[i].type)
+ converted_arrays.append(item)
+ return converted_arrays
+
+
+cdef class RecordBatch(_PandasConvertible):
+ """
+    A collection of equal-length columns representing a batch of rows.
+
+ Warnings
+ --------
+ Do not call this class's constructor directly, use one of the
+ ``RecordBatch.from_*`` functions instead.
+ """
+
+ def __cinit__(self):
+ self.batch = NULL
+ self._schema = None
+
+ def __init__(self):
+ raise TypeError("Do not call RecordBatch's constructor directly, use "
+ "one of the `RecordBatch.from_*` functions instead.")
+
+ cdef void init(self, const shared_ptr[CRecordBatch]& batch):
+ self.sp_batch = batch
+ self.batch = batch.get()
+
+ @staticmethod
+ def from_pydict(mapping, schema=None, metadata=None):
+ """
+        Construct a RecordBatch from a mapping of strings to Arrow arrays
+        or Python lists.
+
+ Parameters
+ ----------
+ mapping : dict or Mapping
+ A mapping of strings to Arrays or Python lists.
+ schema : Schema, default None
+ If not passed, will be inferred from the Mapping values.
+ metadata : dict or Mapping, default None
+ Optional metadata for the schema (if inferred).
+
+ Returns
+ -------
+ RecordBatch
+ """
+
+ return _from_pydict(cls=RecordBatch,
+ mapping=mapping,
+ schema=schema,
+ metadata=metadata)
+
+ def __reduce__(self):
+ return _reconstruct_record_batch, (self.columns, self.schema)
+
+ def __len__(self):
+ return self.batch.num_rows()
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def to_string(self, show_metadata=False):
+ # Use less verbose schema output.
+ schema_as_string = self.schema.to_string(
+ show_field_metadata=show_metadata,
+ show_schema_metadata=show_metadata
+ )
+ return 'pyarrow.{}\n{}'.format(type(self).__name__, schema_as_string)
+
+ def __repr__(self):
+ return self.to_string()
+
+ def validate(self, *, full=False):
+ """
+ Perform validation checks. An exception is raised if validation fails.
+
+ By default only cheap validation checks are run. Pass `full=True`
+ for thorough validation checks (potentially O(n)).
+
+ Parameters
+ ----------
+ full: bool, default False
+ If True, run expensive checks, otherwise cheap checks only.
+
+ Raises
+ ------
+ ArrowInvalid
+ """
+ if full:
+ with nogil:
+ check_status(self.batch.ValidateFull())
+ else:
+ with nogil:
+ check_status(self.batch.Validate())
+
+ def replace_schema_metadata(self, metadata=None):
+ """
+        Create shallow copy of record batch by replacing schema
+        key-value metadata with the indicated new metadata (which may be
+        None, in which case any existing metadata is deleted).
+
+ Parameters
+ ----------
+ metadata : dict, default None
+
+ Returns
+ -------
+ shallow_copy : RecordBatch
+ """
+ cdef:
+ shared_ptr[const CKeyValueMetadata] c_meta
+ shared_ptr[CRecordBatch] c_batch
+
+ metadata = ensure_metadata(metadata, allow_none=True)
+ c_meta = pyarrow_unwrap_metadata(metadata)
+ with nogil:
+ c_batch = self.batch.ReplaceSchemaMetadata(c_meta)
+
+ return pyarrow_wrap_batch(c_batch)
+
+ @property
+ def num_columns(self):
+ """
+ Number of columns
+
+ Returns
+ -------
+ int
+ """
+ return self.batch.num_columns()
+
+ @property
+ def num_rows(self):
+ """
+ Number of rows
+
+ Due to the definition of a RecordBatch, all columns have the same
+ number of rows.
+
+ Returns
+ -------
+ int
+ """
+ return len(self)
+
+ @property
+ def schema(self):
+ """
+ Schema of the RecordBatch and its columns
+
+ Returns
+ -------
+ pyarrow.Schema
+ """
+ if self._schema is None:
+ self._schema = pyarrow_wrap_schema(self.batch.schema())
+
+ return self._schema
+
+ def field(self, i):
+ """
+ Select a schema field by its column name or numeric index
+
+ Parameters
+ ----------
+ i : int or string
+ The index or name of the field to retrieve
+
+ Returns
+ -------
+ pyarrow.Field
+ """
+ return self.schema.field(i)
+
+ @property
+ def columns(self):
+ """
+ List of all columns in numerical order
+
+ Returns
+ -------
+ list of pa.Array
+ """
+ return [self.column(i) for i in range(self.num_columns)]
+
+ def _ensure_integer_index(self, i):
+ """
+ Ensure integer index (convert string column name to integer if needed).
+ """
+ if isinstance(i, (bytes, str)):
+ field_indices = self.schema.get_all_field_indices(i)
+
+ if len(field_indices) == 0:
+ raise KeyError(
+ "Field \"{}\" does not exist in record batch schema"
+ .format(i))
+ elif len(field_indices) > 1:
+ raise KeyError(
+ "Field \"{}\" exists {} times in record batch schema"
+ .format(i, len(field_indices)))
+ else:
+ return field_indices[0]
+ elif isinstance(i, int):
+ return i
+ else:
+ raise TypeError("Index must either be string or integer")
+
+ def column(self, i):
+ """
+ Select single column from record batch
+
+ Parameters
+ ----------
+ i : int or string
+ The index or name of the column to retrieve.
+
+ Returns
+ -------
+ column : pyarrow.Array
+ """
+ return self._column(self._ensure_integer_index(i))
+
+ def _column(self, int i):
+ """
+ Select single column from record batch by its numeric index.
+
+ Parameters
+ ----------
+ i : int
+ The index of the column to retrieve.
+
+ Returns
+ -------
+ column : pyarrow.Array
+ """
+ cdef int index = <int> _normalize_index(i, self.num_columns)
+ cdef Array result = pyarrow_wrap_array(self.batch.column(index))
+ result._name = self.schema[index].name
+ return result
+
+ @property
+ def nbytes(self):
+ """
+ Total number of bytes consumed by the elements of the record batch.
+ """
+ size = 0
+ for i in range(self.num_columns):
+ size += self.column(i).nbytes
+ return size
+
+ def __sizeof__(self):
+ return super(RecordBatch, self).__sizeof__() + self.nbytes
+
+ def __getitem__(self, key):
+ """
+ Slice or return column at given index or column name
+
+ Parameters
+ ----------
+ key : integer, str, or slice
+ Slices with step not equal to 1 (or None) will produce a copy
+ rather than a zero-copy view
+
+ Returns
+ -------
+ value : Array (index/column) or RecordBatch (slice)
+ """
+ if isinstance(key, slice):
+ return _normalize_slice(self, key)
+ else:
+ return self.column(key)
+
+ def serialize(self, memory_pool=None):
+ """
+ Write RecordBatch to Buffer as encapsulated IPC message.
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ Uses default memory pool if not specified
+
+ Returns
+ -------
+ serialized : Buffer
+ """
+ cdef shared_ptr[CBuffer] buffer
+ cdef CIpcWriteOptions options = CIpcWriteOptions.Defaults()
+ options.memory_pool = maybe_unbox_memory_pool(memory_pool)
+
+ with nogil:
+ buffer = GetResultValue(
+ SerializeRecordBatch(deref(self.batch), options))
+ return pyarrow_wrap_buffer(buffer)
+
+ def slice(self, offset=0, length=None):
+ """
+ Compute zero-copy slice of this RecordBatch
+
+ Parameters
+ ----------
+ offset : int, default 0
+ Offset from start of record batch to slice
+ length : int, default None
+ Length of slice (default is until end of batch starting from
+ offset)
+
+ Returns
+ -------
+ sliced : RecordBatch
+ """
+ cdef shared_ptr[CRecordBatch] result
+
+ if offset < 0:
+ raise IndexError('Offset must be non-negative')
+
+ offset = min(len(self), offset)
+ if length is None:
+ result = self.batch.Slice(offset)
+ else:
+ result = self.batch.Slice(offset, length)
+
+ return pyarrow_wrap_batch(result)
+
+ def filter(self, mask, object null_selection_behavior="drop"):
+ """
+ Select record from a record batch. See pyarrow.compute.filter for full
+ usage.
+ """
+ return _pc().filter(self, mask, null_selection_behavior)
+
+ def equals(self, object other, bint check_metadata=False):
+ """
+ Check if contents of two record batches are equal.
+
+ Parameters
+ ----------
+ other : pyarrow.RecordBatch
+ RecordBatch to compare against.
+ check_metadata : bool, default False
+ Whether schema metadata equality should be checked as well.
+
+ Returns
+ -------
+ are_equal : bool
+ """
+ cdef:
+ CRecordBatch* this_batch = self.batch
+ shared_ptr[CRecordBatch] other_batch = pyarrow_unwrap_batch(other)
+ c_bool result
+
+ if not other_batch:
+ return False
+
+ with nogil:
+ result = this_batch.Equals(deref(other_batch), check_metadata)
+
+ return result
+
+ def take(self, object indices):
+ """
+ Select records from a RecordBatch. See pyarrow.compute.take for full
+ usage.
+ """
+ return _pc().take(self, indices)
+
+ def drop_null(self):
+ """
+ Remove missing values from a RecordBatch.
+ See pyarrow.compute.drop_null for full usage.
+ """
+ return _pc().drop_null(self)
+
+ def to_pydict(self):
+ """
+ Convert the RecordBatch to a dict or OrderedDict.
+
+ Returns
+ -------
+ dict
+ """
+ entries = []
+ for i in range(self.batch.num_columns()):
+ name = bytes(self.batch.column_name(i)).decode('utf8')
+ column = self[i].to_pylist()
+ entries.append((name, column))
+ return ordered_dict(entries)
+
+ def _to_pandas(self, options, **kwargs):
+ return Table.from_batches([self])._to_pandas(options, **kwargs)
+
+ @classmethod
+ def from_pandas(cls, df, Schema schema=None, preserve_index=None,
+ nthreads=None, columns=None):
+ """
+ Convert pandas.DataFrame to an Arrow RecordBatch
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ schema : pyarrow.Schema, optional
+ The expected schema of the RecordBatch. This can be used to
+ indicate the type of columns if we cannot infer it automatically.
+ If passed, the output will have exactly this schema. Columns
+ specified in the schema that are not found in the DataFrame columns
+ or its index will raise an error. Additional columns or index
+ levels in the DataFrame which are not specified in the schema will
+ be ignored.
+ preserve_index : bool, optional
+ Whether to store the index as an additional column in the resulting
+ ``RecordBatch``. The default of None will store the index as a
+ column, except for RangeIndex which is stored as metadata only. Use
+ ``preserve_index=True`` to force it to be stored as a column.
+ nthreads : int, default None (may use up to system CPU count threads)
+ If greater than 1, convert columns to Arrow in parallel using
+ indicated number of threads
+ columns : list, optional
+            List of columns to be converted. If None, use all columns.
+
+ Returns
+ -------
+ pyarrow.RecordBatch
+ """
+ from pyarrow.pandas_compat import dataframe_to_arrays
+ arrays, schema = dataframe_to_arrays(
+ df, schema, preserve_index, nthreads=nthreads, columns=columns
+ )
+ return cls.from_arrays(arrays, schema=schema)
+
+ @staticmethod
+ def from_arrays(list arrays, names=None, schema=None, metadata=None):
+ """
+ Construct a RecordBatch from multiple pyarrow.Arrays
+
+ Parameters
+ ----------
+ arrays : list of pyarrow.Array
+ One for each field in RecordBatch
+ names : list of str, optional
+ Names for the batch fields. If not passed, schema must be passed
+ schema : Schema, default None
+ Schema for the created batch. If not passed, names must be passed
+ metadata : dict or Mapping, default None
+ Optional metadata for the schema (if inferred).
+
+ Returns
+ -------
+ pyarrow.RecordBatch
+ """
+ cdef:
+ Array arr
+ shared_ptr[CSchema] c_schema
+ vector[shared_ptr[CArray]] c_arrays
+ int64_t num_rows
+
+ if len(arrays) > 0:
+ num_rows = len(arrays[0])
+ else:
+ num_rows = 0
+
+ if isinstance(names, Schema):
+ import warnings
+ warnings.warn("Schema passed to names= option, please "
+ "pass schema= explicitly. "
+ "Will raise exception in future", FutureWarning)
+ schema = names
+ names = None
+
+ converted_arrays = _sanitize_arrays(arrays, names, schema, metadata,
+ &c_schema)
+
+ c_arrays.reserve(len(arrays))
+ for arr in converted_arrays:
+ if len(arr) != num_rows:
+ raise ValueError('Arrays were not all the same length: '
+ '{0} vs {1}'.format(len(arr), num_rows))
+ c_arrays.push_back(arr.sp_array)
+
+ result = pyarrow_wrap_batch(CRecordBatch.Make(c_schema, num_rows,
+ c_arrays))
+ result.validate()
+ return result
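+
+    # Usage sketch for from_arrays (illustrative values, user-level API):
+    #
+    #   batch = pa.RecordBatch.from_arrays(
+    #       [pa.array([1, 2, 3]), pa.array(['a', 'b', 'c'])],
+    #       names=['ints', 'strs'])
+    #   batch.num_rows       # -> 3
+    #   batch.schema.names   # -> ['ints', 'strs']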
+
+ @staticmethod
+ def from_struct_array(StructArray struct_array):
+ """
+ Construct a RecordBatch from a StructArray.
+
+ Each field in the StructArray will become a column in the resulting
+ ``RecordBatch``.
+
+ Parameters
+ ----------
+ struct_array : StructArray
+ Array to construct the record batch from.
+
+ Returns
+ -------
+ pyarrow.RecordBatch
+ """
+ cdef:
+ shared_ptr[CRecordBatch] c_record_batch
+ with nogil:
+ c_record_batch = GetResultValue(
+ CRecordBatch.FromStructArray(struct_array.sp_array))
+ return pyarrow_wrap_batch(c_record_batch)
+
+ def _export_to_c(self, uintptr_t out_ptr, uintptr_t out_schema_ptr=0):
+ """
+ Export to a C ArrowArray struct, given its pointer.
+
+ If a C ArrowSchema struct pointer is also given, the record batch
+ schema is exported to it at the same time.
+
+ Parameters
+ ----------
+ out_ptr: int
+ The raw pointer to a C ArrowArray struct.
+ out_schema_ptr: int (optional)
+ The raw pointer to a C ArrowSchema struct.
+
+ Be careful: if you don't pass the ArrowArray struct to a consumer,
+ array memory will leak. This is a low-level function intended for
+ expert users.
+ """
+ with nogil:
+ check_status(ExportRecordBatch(deref(self.sp_batch),
+ <ArrowArray*> out_ptr,
+ <ArrowSchema*> out_schema_ptr))
+
+ @staticmethod
+ def _import_from_c(uintptr_t in_ptr, schema):
+ """
+ Import RecordBatch from a C ArrowArray struct, given its pointer
+ and the imported schema.
+
+ Parameters
+ ----------
+ in_ptr: int
+ The raw pointer to a C ArrowArray struct.
+        schema : Schema or int
+ Either a Schema object, or the raw pointer to a C ArrowSchema
+ struct.
+
+ This is a low-level function intended for expert users.
+ """
+ cdef:
+ shared_ptr[CRecordBatch] c_batch
+
+ c_schema = pyarrow_unwrap_schema(schema)
+ if c_schema == nullptr:
+ # Not a Schema object, perhaps a raw ArrowSchema pointer
+ schema_ptr = <uintptr_t> schema
+ with nogil:
+ c_batch = GetResultValue(ImportRecordBatch(
+ <ArrowArray*> in_ptr, <ArrowSchema*> schema_ptr))
+ else:
+ with nogil:
+ c_batch = GetResultValue(ImportRecordBatch(
+ <ArrowArray*> in_ptr, c_schema))
+ return pyarrow_wrap_batch(c_batch)
+
+
+def _reconstruct_record_batch(columns, schema):
+ """
+ Internal: reconstruct RecordBatch from pickled components.
+ """
+ return RecordBatch.from_arrays(columns, schema=schema)
+
+
+def table_to_blocks(options, Table table, categories, extension_columns):
+ cdef:
+ PyObject* result_obj
+ shared_ptr[CTable] c_table
+ CMemoryPool* pool
+ PandasOptions c_options = _convert_pandas_options(options)
+
+ if categories is not None:
+ c_options.categorical_columns = {tobytes(cat) for cat in categories}
+ if extension_columns is not None:
+ c_options.extension_columns = {tobytes(col)
+ for col in extension_columns}
+
+ # ARROW-3789(wesm); Convert date/timestamp types to datetime64[ns]
+ c_options.coerce_temporal_nanoseconds = True
+
+ if c_options.self_destruct:
+ # Move the shared_ptr, table is now unsafe to use further
+ c_table = move(table.sp_table)
+ table.table = NULL
+ else:
+ c_table = table.sp_table
+
+ with nogil:
+ check_status(
+ libarrow.ConvertTableToPandas(c_options, move(c_table),
+ &result_obj)
+ )
+
+ return PyObject_to_object(result_obj)
+
+
+cdef class Table(_PandasConvertible):
+ """
+ A collection of top-level named, equal length Arrow arrays.
+
+    Warnings
+    --------
+ Do not call this class's constructor directly, use one of the ``from_*``
+ methods instead.
+ """
+
+ def __cinit__(self):
+ self.table = NULL
+
+ def __init__(self):
+ raise TypeError("Do not call Table's constructor directly, use one of "
+ "the `Table.from_*` functions instead.")
+
+ def to_string(self, *, show_metadata=False, preview_cols=0):
+ """
+ Return human-readable string representation of Table.
+
+ Parameters
+ ----------
+        show_metadata : bool, default False
+ Display Field-level and Schema-level KeyValueMetadata.
+ preview_cols : int, default 0
+ Display values of the columns for the first N columns.
+
+ Returns
+ -------
+ str
+ """
+ # Use less verbose schema output.
+ schema_as_string = self.schema.to_string(
+ show_field_metadata=show_metadata,
+ show_schema_metadata=show_metadata
+ )
+ title = 'pyarrow.{}\n{}'.format(type(self).__name__, schema_as_string)
+ pieces = [title]
+ if preview_cols:
+ pieces.append('----')
+ for i in range(min(self.num_columns, preview_cols)):
+ pieces.append('{}: {}'.format(
+ self.field(i).name,
+ self.column(i).to_string(indent=0, skip_new_lines=True)
+ ))
+ if preview_cols < self.num_columns:
+ pieces.append('...')
+ return '\n'.join(pieces)
+
+ def __repr__(self):
+ if self.table == NULL:
+ raise ValueError("Table's internal pointer is NULL, do not use "
+ "any methods or attributes on this object")
+ return self.to_string(preview_cols=10)
+
+ cdef void init(self, const shared_ptr[CTable]& table):
+ self.sp_table = table
+ self.table = table.get()
+
+ def validate(self, *, full=False):
+ """
+ Perform validation checks. An exception is raised if validation fails.
+
+ By default only cheap validation checks are run. Pass `full=True`
+ for thorough validation checks (potentially O(n)).
+
+ Parameters
+ ----------
+ full: bool, default False
+ If True, run expensive checks, otherwise cheap checks only.
+
+ Raises
+ ------
+ ArrowInvalid
+ """
+ if full:
+ with nogil:
+ check_status(self.table.ValidateFull())
+ else:
+ with nogil:
+ check_status(self.table.Validate())
+
+ def __reduce__(self):
+ # Reduce the columns as ChunkedArrays to avoid serializing schema
+ # data twice
+ columns = [col for col in self.columns]
+ return _reconstruct_table, (columns, self.schema)
+
+ def __getitem__(self, key):
+ """
+ Slice or return column at given index or column name.
+
+ Parameters
+ ----------
+ key : integer, str, or slice
+ Slices with step not equal to 1 (or None) will produce a copy
+ rather than a zero-copy view.
+
+ Returns
+ -------
+ ChunkedArray (index/column) or Table (slice)
+ """
+ if isinstance(key, slice):
+ return _normalize_slice(self, key)
+ else:
+ return self.column(key)
+
+ def slice(self, offset=0, length=None):
+ """
+ Compute zero-copy slice of this Table.
+
+ Parameters
+ ----------
+ offset : int, default 0
+ Offset from start of table to slice.
+ length : int, default None
+ Length of slice (default is until end of table starting from
+ offset).
+
+ Returns
+ -------
+ Table
+ """
+ cdef shared_ptr[CTable] result
+
+ if offset < 0:
+ raise IndexError('Offset must be non-negative')
+
+ offset = min(len(self), offset)
+ if length is None:
+ result = self.table.Slice(offset)
+ else:
+ result = self.table.Slice(offset, length)
+
+ return pyarrow_wrap_table(result)
+
+ def filter(self, mask, object null_selection_behavior="drop"):
+ """
+ Select records from a Table. See :func:`pyarrow.compute.filter` for
+ full usage.
+ """
+ return _pc().filter(self, mask, null_selection_behavior)
+
+ def take(self, object indices):
+ """
+ Select records from a Table. See :func:`pyarrow.compute.take` for full
+ usage.
+ """
+ return _pc().take(self, indices)
+
+ def drop_null(self):
+ """
+ Remove missing values from a Table.
+ See :func:`pyarrow.compute.drop_null` for full usage.
+ """
+ return _pc().drop_null(self)
+
+ def select(self, object columns):
+ """
+ Select columns of the Table.
+
+ Returns a new Table with the specified columns, and metadata
+ preserved.
+
+ Parameters
+ ----------
+ columns : list-like
+ The column names or integer indices to select.
+
+ Returns
+ -------
+ Table
+ """
+ cdef:
+ shared_ptr[CTable] c_table
+ vector[int] c_indices
+
+ for idx in columns:
+ idx = self._ensure_integer_index(idx)
+ idx = _normalize_index(idx, self.num_columns)
+ c_indices.push_back(<int> idx)
+
+ with nogil:
+ c_table = GetResultValue(self.table.SelectColumns(move(c_indices)))
+
+ return pyarrow_wrap_table(c_table)
+
+ def replace_schema_metadata(self, metadata=None):
+ """
+        Create shallow copy of table by replacing schema
+        key-value metadata with the indicated new metadata (which may be
+        None, in which case any existing metadata is deleted).
+
+ Parameters
+ ----------
+ metadata : dict, default None
+
+ Returns
+ -------
+ Table
+ """
+ cdef:
+ shared_ptr[const CKeyValueMetadata] c_meta
+ shared_ptr[CTable] c_table
+
+ metadata = ensure_metadata(metadata, allow_none=True)
+ c_meta = pyarrow_unwrap_metadata(metadata)
+ with nogil:
+ c_table = self.table.ReplaceSchemaMetadata(c_meta)
+
+ return pyarrow_wrap_table(c_table)
+
+ def flatten(self, MemoryPool memory_pool=None):
+ """
+ Flatten this Table.
+
+ Each column with a struct type is flattened
+ into one column per struct field. Other columns are left unchanged.
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise use default pool
+
+ Returns
+ -------
+ Table
+ """
+ cdef:
+ shared_ptr[CTable] flattened
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+
+ with nogil:
+ flattened = GetResultValue(self.table.Flatten(pool))
+
+ return pyarrow_wrap_table(flattened)
+
+ def combine_chunks(self, MemoryPool memory_pool=None):
+ """
+ Make a new table by combining the chunks this table has.
+
+ All the underlying chunks in the ChunkedArray of each column are
+ concatenated into zero or one chunk.
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise use default pool.
+
+ Returns
+ -------
+ Table
+ """
+ cdef:
+ shared_ptr[CTable] combined
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+
+ with nogil:
+ combined = GetResultValue(self.table.CombineChunks(pool))
+
+ return pyarrow_wrap_table(combined)
+
+ def unify_dictionaries(self, MemoryPool memory_pool=None):
+ """
+ Unify dictionaries across all chunks.
+
+ This method returns an equivalent table, but where all chunks of
+ each column share the same dictionary values. Dictionary indices
+ are transposed accordingly.
+
+ Columns without dictionaries are returned unchanged.
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise use default pool
+
+ Returns
+ -------
+ Table
+ """
+ cdef:
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+ shared_ptr[CTable] c_result
+
+ with nogil:
+ c_result = GetResultValue(CDictionaryUnifier.UnifyTable(
+ deref(self.table), pool))
+
+ return pyarrow_wrap_table(c_result)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def equals(self, Table other, bint check_metadata=False):
+ """
+ Check if contents of two tables are equal.
+
+ Parameters
+ ----------
+ other : pyarrow.Table
+ Table to compare against.
+ check_metadata : bool, default False
+ Whether schema metadata equality should be checked as well.
+
+ Returns
+ -------
+ bool
+ """
+ if other is None:
+ return False
+
+ cdef:
+ CTable* this_table = self.table
+ CTable* other_table = other.table
+ c_bool result
+
+ with nogil:
+ result = this_table.Equals(deref(other_table), check_metadata)
+
+ return result
+
+ def cast(self, Schema target_schema, bint safe=True):
+ """
+ Cast table values to another schema.
+
+ Parameters
+ ----------
+ target_schema : Schema
+ Schema to cast to, the names and order of fields must match.
+ safe : bool, default True
+ Check for overflows or other unsafe conversions.
+
+ Returns
+ -------
+ Table
+ """
+ cdef:
+ ChunkedArray column, casted
+ Field field
+ list newcols = []
+
+ if self.schema.names != target_schema.names:
+ raise ValueError("Target schema's field names are not matching "
+ "the table's field names: {!r}, {!r}"
+ .format(self.schema.names, target_schema.names))
+
+ for column, field in zip(self.itercolumns(), target_schema):
+ casted = column.cast(field.type, safe=safe)
+ newcols.append(casted)
+
+ return Table.from_arrays(newcols, schema=target_schema)
+
+ @classmethod
+ def from_pandas(cls, df, Schema schema=None, preserve_index=None,
+ nthreads=None, columns=None, bint safe=True):
+ """
+ Convert pandas.DataFrame to an Arrow Table.
+
+ The column types in the resulting Arrow Table are inferred from the
+ dtypes of the pandas.Series in the DataFrame. In the case of non-object
+ Series, the NumPy dtype is translated to its Arrow equivalent. In the
+ case of `object`, we need to guess the datatype by looking at the
+ Python objects in this Series.
+
+ Be aware that Series of the `object` dtype don't carry enough
+ information to always lead to a meaningful Arrow type. In the case that
+ we cannot infer a type, e.g. because the DataFrame is of length 0 or
+ the Series only contains None/nan objects, the type is set to
+ null. This behavior can be avoided by constructing an explicit schema
+ and passing it to this function.
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ schema : pyarrow.Schema, optional
+ The expected schema of the Arrow Table. This can be used to
+ indicate the type of columns if we cannot infer it automatically.
+ If passed, the output will have exactly this schema. Columns
+ specified in the schema that are not found in the DataFrame columns
+ or its index will raise an error. Additional columns or index
+ levels in the DataFrame which are not specified in the schema will
+ be ignored.
+ preserve_index : bool, optional
+ Whether to store the index as an additional column in the resulting
+ ``Table``. The default of None will store the index as a column,
+ except for RangeIndex which is stored as metadata only. Use
+ ``preserve_index=True`` to force it to be stored as a column.
+ nthreads : int, default None (may use up to system CPU count threads)
+ If greater than 1, convert columns to Arrow in parallel using
+ indicated number of threads.
+ columns : list, optional
+ List of columns to be converted. If None, use all columns.
+ safe : bool, default True
+ Check for overflows or other unsafe conversions.
+
+ Returns
+ -------
+ Table
+
+ Examples
+ --------
+
+ >>> import pandas as pd
+ >>> import pyarrow as pa
+ >>> df = pd.DataFrame({
+ ... 'int': [1, 2],
+ ... 'str': ['a', 'b']
+ ... })
+ >>> pa.Table.from_pandas(df)
+ <pyarrow.lib.Table object at 0x7f05d1fb1b40>
+ """
+ from pyarrow.pandas_compat import dataframe_to_arrays
+ arrays, schema = dataframe_to_arrays(
+ df,
+ schema=schema,
+ preserve_index=preserve_index,
+ nthreads=nthreads,
+ columns=columns,
+ safe=safe
+ )
+ return cls.from_arrays(arrays, schema=schema)
+
+ @staticmethod
+ def from_arrays(arrays, names=None, schema=None, metadata=None):
+ """
+ Construct a Table from Arrow arrays.
+
+ Parameters
+ ----------
+ arrays : list of pyarrow.Array or pyarrow.ChunkedArray
+ Equal-length arrays that should form the table.
+ names : list of str, optional
+ Names for the table columns. If not passed, schema must be passed.
+ schema : Schema, default None
+ Schema for the created table. If not passed, names must be passed.
+ metadata : dict or Mapping, default None
+ Optional metadata for the schema (if inferred).
+
+ Returns
+ -------
+ Table
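+
+ Examples
+ --------
+ A small illustrative example; the arrays and names are arbitrary:
+
+ >>> import pyarrow as pa
+ >>> nums = pa.array([1, 2])
+ >>> labels = pa.array(['x', 'y'])
+ >>> pa.Table.from_arrays([nums, labels], names=['num', 'label']).num_rows
+ 2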
+ """
+ cdef:
+ vector[shared_ptr[CChunkedArray]] columns
+ shared_ptr[CSchema] c_schema
+ int i, K = <int> len(arrays)
+
+ converted_arrays = _sanitize_arrays(arrays, names, schema, metadata,
+ &c_schema)
+
+ columns.reserve(K)
+ for item in converted_arrays:
+ if isinstance(item, Array):
+ columns.push_back(
+ make_shared[CChunkedArray](
+ (<Array> item).sp_array
+ )
+ )
+ elif isinstance(item, ChunkedArray):
+ columns.push_back((<ChunkedArray> item).sp_chunked_array)
+ else:
+ raise TypeError(type(item))
+
+ result = pyarrow_wrap_table(CTable.Make(c_schema, columns))
+ result.validate()
+ return result
+
+ @staticmethod
+ def from_pydict(mapping, schema=None, metadata=None):
+ """
+ Construct a Table from a mapping of strings to Arrays or Python lists.
+
+ Parameters
+ ----------
+ mapping : dict or Mapping
+ A mapping of strings to Arrays or Python lists.
+ schema : Schema, default None
+ If not passed, will be inferred from the Mapping values.
+ metadata : dict or Mapping, default None
+ Optional metadata for the schema (if inferred).
+
+ Returns
+ -------
+ Table
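+
+ Examples
+ --------
+ A small illustrative example; the mapping is arbitrary:
+
+ >>> import pyarrow as pa
+ >>> pa.Table.from_pydict({'a': [1, 2], 'b': ['x', 'y']}).column_names
+ ['a', 'b']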
+ """
+
+ return _from_pydict(cls=Table,
+ mapping=mapping,
+ schema=schema,
+ metadata=metadata)
+
+ @staticmethod
+ def from_batches(batches, Schema schema=None):
+ """
+ Construct a Table from a sequence or iterator of Arrow RecordBatches.
+
+ Parameters
+ ----------
+ batches : sequence or iterator of RecordBatch
+ Sequence of RecordBatch to be converted; all schemas must be equal.
+ schema : Schema, default None
+ If not passed, will be inferred from the first RecordBatch.
+
+ Returns
+ -------
+ Table
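+
+ Examples
+ --------
+ A small illustrative example; the batch contents are arbitrary:
+
+ >>> import pyarrow as pa
+ >>> batch = pa.RecordBatch.from_arrays([pa.array([1, 2])], names=['a'])
+ >>> pa.Table.from_batches([batch, batch]).num_rows
+ 4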
+ """
+ cdef:
+ vector[shared_ptr[CRecordBatch]] c_batches
+ shared_ptr[CTable] c_table
+ shared_ptr[CSchema] c_schema
+ RecordBatch batch
+
+ for batch in batches:
+ c_batches.push_back(batch.sp_batch)
+
+ if schema is None:
+ if c_batches.size() == 0:
+ raise ValueError('Must pass schema, or at least '
+ 'one RecordBatch')
+ c_schema = c_batches[0].get().schema()
+ else:
+ c_schema = schema.sp_schema
+
+ with nogil:
+ c_table = GetResultValue(
+ CTable.FromRecordBatches(c_schema, move(c_batches)))
+
+ return pyarrow_wrap_table(c_table)
+
+ def to_batches(self, max_chunksize=None, **kwargs):
+ """
+ Convert Table to list of (contiguous) RecordBatch objects.
+
+ Parameters
+ ----------
+ max_chunksize : int, default None
+ Maximum size for RecordBatch chunks. Individual chunks may be
+ smaller depending on the chunk layout of individual columns.
+
+ Returns
+ -------
+ list of RecordBatch
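+
+ Examples
+ --------
+ A small illustrative example; the chunk size is arbitrary:
+
+ >>> import pyarrow as pa
+ >>> table = pa.table({'a': list(range(10))})
+ >>> [b.num_rows for b in table.to_batches(max_chunksize=4)]
+ [4, 4, 2]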
+ """
+ cdef:
+ unique_ptr[TableBatchReader] reader
+ int64_t c_max_chunksize
+ list result = []
+ shared_ptr[CRecordBatch] batch
+
+ reader.reset(new TableBatchReader(deref(self.table)))
+
+ if 'chunksize' in kwargs:
+ max_chunksize = kwargs['chunksize']
+ msg = ('The parameter chunksize is deprecated for '
+ 'pyarrow.Table.to_batches as of 0.15, please use '
+ 'the parameter max_chunksize instead')
+ warnings.warn(msg, FutureWarning)
+
+ if max_chunksize is not None:
+ c_max_chunksize = max_chunksize
+ reader.get().set_chunksize(c_max_chunksize)
+
+ while True:
+ with nogil:
+ check_status(reader.get().ReadNext(&batch))
+
+ if batch.get() == NULL:
+ break
+
+ result.append(pyarrow_wrap_batch(batch))
+
+ return result
+
+ def _to_pandas(self, options, categories=None, ignore_metadata=False,
+ types_mapper=None):
+ from pyarrow.pandas_compat import table_to_blockmanager
+ mgr = table_to_blockmanager(
+ options, self, categories,
+ ignore_metadata=ignore_metadata,
+ types_mapper=types_mapper)
+ return pandas_api.data_frame(mgr)
+
+ def to_pydict(self):
+ """
+ Convert the Table to a dict or OrderedDict.
+
+ Returns
+ -------
+ dict
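+
+ Examples
+ --------
+ A small illustrative example; the data is arbitrary:
+
+ >>> import pyarrow as pa
+ >>> table = pa.table({'a': [1, 2], 'b': ['x', 'y']})
+ >>> table.to_pydict()['a']
+ [1, 2]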
+ """
+ cdef:
+ size_t i
+ size_t num_columns = self.table.num_columns()
+ list entries = []
+ ChunkedArray column
+
+ for i in range(num_columns):
+ column = self.column(i)
+ entries.append((self.field(i).name, column.to_pylist()))
+
+ return ordered_dict(entries)
+
+ @property
+ def schema(self):
+ """
+ Schema of the table and its columns.
+
+ Returns
+ -------
+ Schema
+ """
+ return pyarrow_wrap_schema(self.table.schema())
+
+ def field(self, i):
+ """
+ Select a schema field by its column name or numeric index.
+
+ Parameters
+ ----------
+ i : int or string
+ The index or name of the field to retrieve.
+
+ Returns
+ -------
+ Field
+ """
+ return self.schema.field(i)
+
+ def _ensure_integer_index(self, i):
+ """
+ Ensure integer index (convert string column name to integer if needed).
+ """
+ if isinstance(i, (bytes, str)):
+ field_indices = self.schema.get_all_field_indices(i)
+
+ if len(field_indices) == 0:
+ raise KeyError("Field \"{}\" does not exist in table schema"
+ .format(i))
+ elif len(field_indices) > 1:
+ raise KeyError("Field \"{}\" exists {} times in table schema"
+ .format(i, len(field_indices)))
+ else:
+ return field_indices[0]
+ elif isinstance(i, int):
+ return i
+ else:
+ raise TypeError("Index must either be string or integer")
+
+ def column(self, i):
+ """
+ Select a column by its column name, or numeric index.
+
+ Parameters
+ ----------
+ i : int or string
+ The index or name of the column to retrieve.
+
+ Returns
+ -------
+ ChunkedArray
+ """
+ return self._column(self._ensure_integer_index(i))
+
+ def _column(self, int i):
+ """
+ Select a column by its numeric index.
+
+ Parameters
+ ----------
+ i : int
+ The index of the column to retrieve.
+
+ Returns
+ -------
+ ChunkedArray
+ """
+ cdef int index = <int> _normalize_index(i, self.num_columns)
+ cdef ChunkedArray result = pyarrow_wrap_chunked_array(
+ self.table.column(index))
+ result._name = self.schema[index].name
+ return result
+
+ def itercolumns(self):
+ """
+ Iterator over all columns in their numerical order.
+
+ Yields
+ ------
+ ChunkedArray
+ """
+ for i in range(self.num_columns):
+ yield self._column(i)
+
+ @property
+ def columns(self):
+ """
+ List of all columns in numerical order.
+
+ Returns
+ -------
+ list of ChunkedArray
+ """
+ return [self._column(i) for i in range(self.num_columns)]
+
+ @property
+ def num_columns(self):
+ """
+ Number of columns in this table.
+
+ Returns
+ -------
+ int
+ """
+ return self.table.num_columns()
+
+ @property
+ def num_rows(self):
+ """
+ Number of rows in this table.
+
+ Due to the definition of a table, all columns have the same number of
+ rows.
+
+ Returns
+ -------
+ int
+ """
+ return self.table.num_rows()
+
+ def __len__(self):
+ return self.num_rows
+
+ @property
+ def shape(self):
+ """
+ Dimensions of the table: (#rows, #columns).
+
+ Returns
+ -------
+ (int, int)
+ Number of rows and number of columns.
+ """
+ return (self.num_rows, self.num_columns)
+
+ @property
+ def nbytes(self):
+ """
+ Total number of bytes consumed by the elements of the table.
+
+ Returns
+ -------
+ int
+ """
+ size = 0
+ for column in self.itercolumns():
+ size += column.nbytes
+ return size
+
+ def __sizeof__(self):
+ return super(Table, self).__sizeof__() + self.nbytes
+
+ def add_column(self, int i, field_, column):
+ """
+ Add column to Table at position.
+
+ A new table is returned with the column added, the original table
+ object is left unchanged.
+
+ Parameters
+ ----------
+ i : int
+ Index to place the column at.
+ field_ : str or Field
+ If a string is passed then the type is deduced from the column
+ data.
+ column : Array, list of Array, or values coercible to arrays
+ Column data.
+
+ Returns
+ -------
+ Table
+ New table with the passed column added.
+ """
+ cdef:
+ shared_ptr[CTable] c_table
+ Field c_field
+ ChunkedArray c_arr
+
+ if isinstance(column, ChunkedArray):
+ c_arr = column
+ else:
+ c_arr = chunked_array(column)
+
+ if isinstance(field_, Field):
+ c_field = field_
+ else:
+ c_field = field(field_, c_arr.type)
+
+ with nogil:
+ c_table = GetResultValue(self.table.AddColumn(
+ i, c_field.sp_field, c_arr.sp_chunked_array))
+
+ return pyarrow_wrap_table(c_table)
+
+ def append_column(self, field_, column):
+ """
+ Append column at end of columns.
+
+ Parameters
+ ----------
+ field_ : str or Field
+ If a string is passed then the type is deduced from the column
+ data.
+ column : Array, list of Array, or values coercible to arrays
+ Column data.
+
+ Returns
+ -------
+ Table
+ New table with the passed column added.
+ """
+ return self.add_column(self.num_columns, field_, column)
+
+ def remove_column(self, int i):
+ """
+ Create new Table with the indicated column removed.
+
+ Parameters
+ ----------
+ i : int
+ Index of column to remove.
+
+ Returns
+ -------
+ Table
+ New table without the column.
+ """
+ cdef shared_ptr[CTable] c_table
+
+ with nogil:
+ c_table = GetResultValue(self.table.RemoveColumn(i))
+
+ return pyarrow_wrap_table(c_table)
+
+ def set_column(self, int i, field_, column):
+ """
+ Replace column in Table at position.
+
+ Parameters
+ ----------
+ i : int
+ Index to place the column at.
+ field_ : str or Field
+ If a string is passed then the type is deduced from the column
+ data.
+ column : Array, list of Array, or values coercible to arrays
+ Column data.
+
+ Returns
+ -------
+ Table
+ New table with the passed column set.
+ """
+ cdef:
+ shared_ptr[CTable] c_table
+ Field c_field
+ ChunkedArray c_arr
+
+ if isinstance(column, ChunkedArray):
+ c_arr = column
+ else:
+ c_arr = chunked_array(column)
+
+ if isinstance(field_, Field):
+ c_field = field_
+ else:
+ c_field = field(field_, c_arr.type)
+
+ with nogil:
+ c_table = GetResultValue(self.table.SetColumn(
+ i, c_field.sp_field, c_arr.sp_chunked_array))
+
+ return pyarrow_wrap_table(c_table)
+
+ @property
+ def column_names(self):
+ """
+ Names of the table's columns.
+
+ Returns
+ -------
+ list of str
+ """
+ names = self.table.ColumnNames()
+ return [frombytes(name) for name in names]
+
+ def rename_columns(self, names):
+ """
+ Create new table with columns renamed to provided names.
+
+ Parameters
+ ----------
+ names : list of str
+ List of new column names.
+
+ Returns
+ -------
+ Table
+ """
+ cdef:
+ shared_ptr[CTable] c_table
+ vector[c_string] c_names
+
+ for name in names:
+ c_names.push_back(tobytes(name))
+
+ with nogil:
+ c_table = GetResultValue(self.table.RenameColumns(move(c_names)))
+
+ return pyarrow_wrap_table(c_table)
+
+ def drop(self, columns):
+ """
+ Drop one or more columns and return a new table.
+
+ Parameters
+ ----------
+ columns : list of str
+ List of field names referencing existing columns.
+
+ Raises
+ ------
+ KeyError
+ If any of the passed column names do not exist.
+
+ Returns
+ -------
+ Table
+ New table without the columns.
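+
+ Examples
+ --------
+ A small illustrative example; the column names are arbitrary:
+
+ >>> import pyarrow as pa
+ >>> table = pa.table({'a': [1], 'b': [2], 'c': [3]})
+ >>> table.drop(['a', 'c']).column_names
+ ['b']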
+ """
+ indices = []
+ for col in columns:
+ idx = self.schema.get_field_index(col)
+ if idx == -1:
+ raise KeyError("Column {!r} not found".format(col))
+ indices.append(idx)
+
+ indices.sort()
+ indices.reverse()
+
+ table = self
+ for idx in indices:
+ table = table.remove_column(idx)
+
+ return table
+
+
+def _reconstruct_table(arrays, schema):
+ """
+ Internal: reconstruct pa.Table from pickled components.
+ """
+ return Table.from_arrays(arrays, schema=schema)
+
+
+def record_batch(data, names=None, schema=None, metadata=None):
+ """
+ Create a pyarrow.RecordBatch from a Python data structure or sequence
+ of arrays.
+
+ Parameters
+ ----------
+ data : pandas.DataFrame, list
+ A DataFrame or list of arrays or chunked arrays.
+ names : list, default None
+ Column names if list of arrays passed as data. Mutually exclusive with
+ 'schema' argument.
+ schema : Schema, default None
+ The expected schema of the RecordBatch. If not passed, will be inferred
+ from the data. Mutually exclusive with 'names' argument.
+ metadata : dict or Mapping, default None
+ Optional metadata for the schema (if schema not passed).
+
+ Returns
+ -------
+ RecordBatch
+
+ See Also
+ --------
+ RecordBatch.from_arrays, RecordBatch.from_pandas, table
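+
+ Examples
+ --------
+ A small illustrative example; the arrays and names are arbitrary:
+
+ >>> import pyarrow as pa
+ >>> batch = pa.record_batch([pa.array([1, 2]), pa.array(['x', 'y'])],
+ ... names=['a', 'b'])
+ >>> batch.num_rows
+ 2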
+ """
+ # accept schema as first argument for backwards compatibility / usability
+ if isinstance(names, Schema) and schema is None:
+ schema = names
+ names = None
+
+ if isinstance(data, (list, tuple)):
+ return RecordBatch.from_arrays(data, names=names, schema=schema,
+ metadata=metadata)
+ elif _pandas_api.is_data_frame(data):
+ return RecordBatch.from_pandas(data, schema=schema)
+ else:
+ raise TypeError("Expected pandas DataFrame or list of arrays")
+
+
+def table(data, names=None, schema=None, metadata=None, nthreads=None):
+ """
+ Create a pyarrow.Table from a Python data structure or sequence of arrays.
+
+ Parameters
+ ----------
+ data : pandas.DataFrame, dict, list
+ A DataFrame, mapping of strings to Arrays or Python lists, or list of
+ arrays or chunked arrays.
+ names : list, default None
+ Column names if list of arrays passed as data. Mutually exclusive with
+ 'schema' argument.
+ schema : Schema, default None
+ The expected schema of the Arrow Table. If not passed, will be inferred
+ from the data. Mutually exclusive with 'names' argument.
+ If passed, the output will have exactly this schema (raising an error
+ when columns are not found in the data and ignoring additional data not
+ specified in the schema, when data is a dict or DataFrame).
+ metadata : dict or Mapping, default None
+ Optional metadata for the schema (if schema not passed).
+ nthreads : int, default None (may use up to system CPU count threads)
+ For pandas.DataFrame inputs: if greater than 1, convert columns to
+ Arrow in parallel using indicated number of threads.
+
+ Returns
+ -------
+ Table
+
+ See Also
+ --------
+ Table.from_arrays, Table.from_pandas, Table.from_pydict
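+
+ Examples
+ --------
+ A small illustrative example; the mapping is arbitrary:
+
+ >>> import pyarrow as pa
+ >>> pa.table({'a': [1, 2], 'b': ['x', 'y']}).num_columns
+ 2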
+ """
+ # accept schema as first argument for backwards compatibility / usability
+ if isinstance(names, Schema) and schema is None:
+ schema = names
+ names = None
+
+ if isinstance(data, (list, tuple)):
+ return Table.from_arrays(data, names=names, schema=schema,
+ metadata=metadata)
+ elif isinstance(data, dict):
+ if names is not None:
+ raise ValueError(
+ "The 'names' argument is not valid when passing a dictionary")
+ return Table.from_pydict(data, schema=schema, metadata=metadata)
+ elif _pandas_api.is_data_frame(data):
+ if names is not None or metadata is not None:
+ raise ValueError(
+ "The 'names' and 'metadata' arguments are not valid when "
+ "passing a pandas DataFrame")
+ return Table.from_pandas(data, schema=schema, nthreads=nthreads)
+ else:
+ raise TypeError(
+ "Expected pandas DataFrame, python dictionary or list of arrays")
+
+
+def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None):
+ """
+ Concatenate pyarrow.Table objects.
+
+ If promote==False, a zero-copy concatenation will be performed. The schemas
+ of all the Tables must be the same (except the metadata), otherwise an
+ exception will be raised. The result Table will share the metadata with the
+ first table.
+
+ If promote==True, any null type arrays will be cast to the type of other
+ arrays in the column of the same name. If a table is missing a particular
+ field, null values of the appropriate type will be generated to take the
+ place of the missing field. The new schema will share the metadata with the
+ first table. Each field in the new schema will share the metadata with the
+ first table which has the field defined. Note that type promotions may
+ involve additional allocations on the given ``memory_pool``.
+
+ Parameters
+ ----------
+ tables : iterable of pyarrow.Table objects
+ Pyarrow tables to concatenate into a single Table.
+ promote : bool, default False
+ If True, concatenate tables with null-filling and null type promotion.
+ memory_pool : MemoryPool, default None
+ For memory allocations, if required, otherwise use default pool.
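+
+ Examples
+ --------
+ A small illustrative example; the table contents are arbitrary:
+
+ >>> import pyarrow as pa
+ >>> t1 = pa.table({'a': [1, 2]})
+ >>> t2 = pa.table({'a': [3, 4]})
+ >>> pa.concat_tables([t1, t2]).num_rows
+ 4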
+ """
+ cdef:
+ vector[shared_ptr[CTable]] c_tables
+ shared_ptr[CTable] c_result_table
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+ Table table
+ CConcatenateTablesOptions options = (
+ CConcatenateTablesOptions.Defaults())
+
+ for table in tables:
+ c_tables.push_back(table.sp_table)
+
+ with nogil:
+ options.unify_schemas = promote
+ c_result_table = GetResultValue(
+ ConcatenateTables(c_tables, options, pool))
+
+ return pyarrow_wrap_table(c_result_table)
+
+
+def _from_pydict(cls, mapping, schema, metadata):
+ """
+ Construct a Table or RecordBatch from a mapping of strings to Arrays
+ or Python lists.
+
+ Parameters
+ ----------
+ cls : Class Table/RecordBatch
+ mapping : dict or Mapping
+ A mapping of strings to Arrays or Python lists.
+ schema : Schema, default None
+ If not passed, will be inferred from the Mapping values.
+ metadata : dict or Mapping, default None
+ Optional metadata for the schema (if inferred).
+
+ Returns
+ -------
+ Table/RecordBatch
+ """
+
+ arrays = []
+ if schema is None:
+ names = []
+ for k, v in mapping.items():
+ names.append(k)
+ arrays.append(asarray(v))
+ return cls.from_arrays(arrays, names, metadata=metadata)
+ elif isinstance(schema, Schema):
+ for field in schema:
+ try:
+ v = mapping[field.name]
+ except KeyError:
+ try:
+ v = mapping[tobytes(field.name)]
+ except KeyError:
+ present = mapping.keys()
+ missing = [n for n in schema.names if n not in present]
+ raise KeyError(
+ "The passed mapping doesn't contain the "
+ "following field(s) of the schema: {}".
+ format(', '.join(missing))
+ )
+ arrays.append(asarray(v, type=field.type))
+ # Will raise if metadata is not None
+ return cls.from_arrays(arrays, schema=schema, metadata=metadata)
+ else:
+ raise TypeError('Schema must be an instance of pyarrow.Schema')
diff --git a/src/arrow/python/pyarrow/tensor.pxi b/src/arrow/python/pyarrow/tensor.pxi
new file mode 100644
index 000000000..42fd44741
--- /dev/null
+++ b/src/arrow/python/pyarrow/tensor.pxi
@@ -0,0 +1,1025 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+cdef class Tensor(_Weakrefable):
+ """
+ An n-dimensional array, also known as a Tensor.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call Tensor's constructor directly, use one "
+ "of the `pyarrow.Tensor.from_*` functions instead.")
+
+ cdef void init(self, const shared_ptr[CTensor]& sp_tensor):
+ self.sp_tensor = sp_tensor
+ self.tp = sp_tensor.get()
+ self.type = pyarrow_wrap_data_type(self.tp.type())
+
+ def __repr__(self):
+ return """<pyarrow.Tensor>
+type: {0.type}
+shape: {0.shape}
+strides: {0.strides}""".format(self)
+
+ @staticmethod
+ def from_numpy(obj, dim_names=None):
+ """
+ Create a Tensor from a numpy array.
+
+ Parameters
+ ----------
+ obj : numpy.ndarray
+ The source numpy array
+ dim_names : list, optional
+ Names of each dimension of the Tensor.
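+
+ Examples
+ --------
+ A small illustrative example; the array and dimension names are
+ arbitrary:
+
+ >>> import numpy as np
+ >>> import pyarrow as pa
+ >>> tensor = pa.Tensor.from_numpy(np.array([[1, 2], [3, 4]]),
+ ... dim_names=['row', 'col'])
+ >>> tensor.shape
+ (2, 2)
+ >>> tensor.dim_names
+ ['row', 'col']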
+ """
+ cdef:
+ vector[c_string] c_dim_names
+ shared_ptr[CTensor] ctensor
+
+ if dim_names is not None:
+ for x in dim_names:
+ c_dim_names.push_back(tobytes(x))
+
+ check_status(NdarrayToTensor(c_default_memory_pool(), obj,
+ c_dim_names, &ctensor))
+ return pyarrow_wrap_tensor(ctensor)
+
+ def to_numpy(self):
+ """
+ Convert arrow::Tensor to numpy.ndarray with zero copy
+ """
+ cdef PyObject* out
+
+ check_status(TensorToNdarray(self.sp_tensor, self, &out))
+ return PyObject_to_object(out)
+
+ def equals(self, Tensor other):
+ """
+ Return true if the tensors contain exactly equal data
+ """
+ return self.tp.Equals(deref(other.tp))
+
+ def __eq__(self, other):
+ if isinstance(other, Tensor):
+ return self.equals(other)
+ else:
+ return NotImplemented
+
+ def dim_name(self, i):
+ return frombytes(self.tp.dim_name(i))
+
+ @property
+ def dim_names(self):
+ return [frombytes(x) for x in tuple(self.tp.dim_names())]
+
+ @property
+ def is_mutable(self):
+ return self.tp.is_mutable()
+
+ @property
+ def is_contiguous(self):
+ return self.tp.is_contiguous()
+
+ @property
+ def ndim(self):
+ return self.tp.ndim()
+
+ @property
+ def size(self):
+ return self.tp.size()
+
+ @property
+ def shape(self):
+ # Cython knows how to convert a vector[T] to a Python list
+ return tuple(self.tp.shape())
+
+ @property
+ def strides(self):
+ return tuple(self.tp.strides())
+
+ def __getbuffer__(self, cp.Py_buffer* buffer, int flags):
+ buffer.buf = <char *> self.tp.data().get().data()
+ pep3118_format = self.type.pep3118_format
+ if pep3118_format is None:
+ raise NotImplementedError("type %s not supported for buffer "
+ "protocol" % (self.type,))
+ buffer.format = pep3118_format
+ buffer.itemsize = self.type.bit_width // 8
+ buffer.internal = NULL
+ buffer.len = self.tp.size() * buffer.itemsize
+ buffer.ndim = self.tp.ndim()
+ buffer.obj = self
+ if self.tp.is_mutable():
+ buffer.readonly = 0
+ else:
+ buffer.readonly = 1
+ # NOTE: This assumes Py_ssize_t == int64_t, and that the shape
+ # and strides arrays lifetime is tied to the tensor's
+ buffer.shape = <Py_ssize_t *> &self.tp.shape()[0]
+ buffer.strides = <Py_ssize_t *> &self.tp.strides()[0]
+ buffer.suboffsets = NULL
+
+
+ctypedef CSparseCOOIndex* _CSparseCOOIndexPtr
+
+
+cdef class SparseCOOTensor(_Weakrefable):
+ """
+ A sparse COO tensor.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call SparseCOOTensor's constructor directly, "
+ "use one of the `pyarrow.SparseCOOTensor.from_*` "
+ "functions instead.")
+
+ cdef void init(self, const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor):
+ self.sp_sparse_tensor = sp_sparse_tensor
+ self.stp = sp_sparse_tensor.get()
+ self.type = pyarrow_wrap_data_type(self.stp.type())
+
+ def __repr__(self):
+ return """<pyarrow.SparseCOOTensor>
+type: {0.type}
+shape: {0.shape}""".format(self)
+
+ @classmethod
+ def from_dense_numpy(cls, obj, dim_names=None):
+ """
+ Convert numpy.ndarray to arrow::SparseCOOTensor
+ """
+ return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names))
+
+ @staticmethod
+ def from_numpy(data, coords, shape, dim_names=None):
+ """
+ Create arrow::SparseCOOTensor from numpy.ndarrays
+
+ Parameters
+ ----------
+ data : numpy.ndarray
+ Data used to populate the rows.
+ coords : numpy.ndarray
+ Coordinates of the data.
+ shape : tuple
+ Shape of the tensor.
+ dim_names : list, optional
+ Names of the dimensions.
+ """
+ cdef shared_ptr[CSparseCOOTensor] csparse_tensor
+ cdef vector[int64_t] c_shape
+ cdef vector[c_string] c_dim_names
+
+ for x in shape:
+ c_shape.push_back(x)
+ if dim_names is not None:
+ for x in dim_names:
+ c_dim_names.push_back(tobytes(x))
+
+ # Enforce precondition for SparseCOOTensor indices
+ coords = np.require(coords, dtype='i8', requirements='C')
+ if coords.ndim != 2:
+ raise ValueError("Expected 2-dimensional array for "
+ "SparseCOOTensor indices")
+
+ check_status(NdarraysToSparseCOOTensor(c_default_memory_pool(),
+ data, coords, c_shape,
+ c_dim_names, &csparse_tensor))
+ return pyarrow_wrap_sparse_coo_tensor(csparse_tensor)
+
+ @staticmethod
+ def from_scipy(obj, dim_names=None):
+ """
+ Convert scipy.sparse.coo_matrix to arrow::SparseCOOTensor
+
+ Parameters
+ ----------
+ obj : scipy.sparse.coo_matrix
+ The scipy matrix that should be converted.
+ dim_names : list, optional
+ Names of the dimensions.
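+
+ Examples
+ --------
+ A small illustrative example (requires scipy; values are arbitrary):
+
+ >>> import numpy as np
+ >>> import pyarrow as pa
+ >>> from scipy.sparse import coo_matrix
+ >>> m = coo_matrix((np.array([1.0, 2.0]),
+ ... (np.array([0, 1]), np.array([1, 0]))), shape=(2, 2))
+ >>> pa.SparseCOOTensor.from_scipy(m).non_zero_length
+ 2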
+ """
+ import scipy.sparse
+ if not isinstance(obj, scipy.sparse.coo_matrix):
+ raise TypeError(
+ "Expected scipy.sparse.coo_matrix, got {}".format(type(obj)))
+
+ cdef shared_ptr[CSparseCOOTensor] csparse_tensor
+ cdef vector[int64_t] c_shape
+ cdef vector[c_string] c_dim_names
+
+ for x in obj.shape:
+ c_shape.push_back(x)
+ if dim_names is not None:
+ for x in dim_names:
+ c_dim_names.push_back(tobytes(x))
+
+ row = obj.row
+ col = obj.col
+
+ # When SciPy's coo_matrix has canonical format, its indices matrix is
+ # sorted in column-major order. Arrow's SparseCOOIndex, by contrast, is
+ # sorted in row-major order when canonical, so we sort the indices
+ # matrix into row-major order here to preserve canonical form.
+ if obj.has_canonical_format:
+ order = np.lexsort((col, row)) # sort in row-major order
+ row = row[order]
+ col = col[order]
+ coords = np.vstack([row, col]).T
+ coords = np.require(coords, dtype='i8', requirements='C')
+
+ check_status(NdarraysToSparseCOOTensor(c_default_memory_pool(),
+ obj.data, coords, c_shape,
+ c_dim_names, &csparse_tensor))
+ return pyarrow_wrap_sparse_coo_tensor(csparse_tensor)
+
+ @staticmethod
+ def from_pydata_sparse(obj, dim_names=None):
+ """
+ Convert pydata/sparse.COO to arrow::SparseCOOTensor.
+
+ Parameters
+ ----------
+ obj : pydata.sparse.COO
+ The sparse multidimensional array that should be converted.
+ dim_names : list, optional
+ Names of the dimensions.
+ """
+ import sparse
+ if not isinstance(obj, sparse.COO):
+ raise TypeError(
+ "Expected sparse.COO, got {}".format(type(obj)))
+
+ cdef shared_ptr[CSparseCOOTensor] csparse_tensor
+ cdef vector[int64_t] c_shape
+ cdef vector[c_string] c_dim_names
+
+ for x in obj.shape:
+ c_shape.push_back(x)
+ if dim_names is not None:
+ for x in dim_names:
+ c_dim_names.push_back(tobytes(x))
+
+ coords = np.require(obj.coords.T, dtype='i8', requirements='C')
+
+ check_status(NdarraysToSparseCOOTensor(c_default_memory_pool(),
+ obj.data, coords, c_shape,
+ c_dim_names, &csparse_tensor))
+ return pyarrow_wrap_sparse_coo_tensor(csparse_tensor)
+
+ @staticmethod
+ def from_tensor(obj):
+ """
+ Convert arrow::Tensor to arrow::SparseCOOTensor.
+
+ Parameters
+ ----------
+ obj : Tensor
+ The tensor that should be converted.
+ """
+ cdef shared_ptr[CSparseCOOTensor] csparse_tensor
+ cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj)
+
+ with nogil:
+ check_status(TensorToSparseCOOTensor(ctensor, &csparse_tensor))
+
+ return pyarrow_wrap_sparse_coo_tensor(csparse_tensor)
+
+ def to_numpy(self):
+ """
+ Convert arrow::SparseCOOTensor to numpy.ndarrays with zero copy.
+ """
+ cdef PyObject* out_data
+ cdef PyObject* out_coords
+
+ check_status(SparseCOOTensorToNdarray(self.sp_sparse_tensor, self,
+ &out_data, &out_coords))
+ return PyObject_to_object(out_data), PyObject_to_object(out_coords)
+
+ def to_scipy(self):
+ """
+ Convert arrow::SparseCOOTensor to scipy.sparse.coo_matrix.
+ """
+ from scipy.sparse import coo_matrix
+ cdef PyObject* out_data
+ cdef PyObject* out_coords
+
+ check_status(SparseCOOTensorToNdarray(self.sp_sparse_tensor, self,
+ &out_data, &out_coords))
+ data = PyObject_to_object(out_data)
+ coords = PyObject_to_object(out_coords)
+ row, col = coords[:, 0], coords[:, 1]
+ result = coo_matrix((data[:, 0], (row, col)), shape=self.shape)
+
+ # As described in from_scipy above, the indices matrix was sorted into
+ # row-major order when the source coo_matrix had canonical format.
+ # Call sum_duplicates() so that the resulting coo_matrix is also in
+ # canonical format.
+ if self.has_canonical_format:
+ result.sum_duplicates()
+ return result
+
+ def to_pydata_sparse(self):
+ """
+ Convert arrow::SparseCOOTensor to pydata/sparse.COO.
+ """
+ from sparse import COO
+ cdef PyObject* out_data
+ cdef PyObject* out_coords
+
+ check_status(SparseCOOTensorToNdarray(self.sp_sparse_tensor, self,
+ &out_data, &out_coords))
+ data = PyObject_to_object(out_data)
+ coords = PyObject_to_object(out_coords)
+ result = COO(data=data[:, 0], coords=coords.T, shape=self.shape)
+ return result
+
+ def to_tensor(self):
+ """
+ Convert arrow::SparseCOOTensor to arrow::Tensor.
+ """
+
+ cdef shared_ptr[CTensor] ctensor
+ with nogil:
+ ctensor = GetResultValue(self.stp.ToTensor())
+
+ return pyarrow_wrap_tensor(ctensor)
+
+ def equals(self, SparseCOOTensor other):
+ """
+ Return true if the sparse tensors contain exactly equal data.
+ """
+ return self.stp.Equals(deref(other.stp))
+
+ def __eq__(self, other):
+ if isinstance(other, SparseCOOTensor):
+ return self.equals(other)
+ else:
+ return NotImplemented
+
+ @property
+ def is_mutable(self):
+ return self.stp.is_mutable()
+
+ @property
+ def ndim(self):
+ return self.stp.ndim()
+
+ @property
+ def shape(self):
+ # Cython knows how to convert a vector[T] to a Python list
+ return tuple(self.stp.shape())
+
+ @property
+ def size(self):
+ return self.stp.size()
+
+ def dim_name(self, i):
+ return frombytes(self.stp.dim_name(i))
+
+ @property
+ def dim_names(self):
+ return tuple(frombytes(x) for x in tuple(self.stp.dim_names()))
+
+ @property
+ def non_zero_length(self):
+ return self.stp.non_zero_length()
+
+ @property
+ def has_canonical_format(self):
+ cdef:
+ _CSparseCOOIndexPtr csi
+
+ csi = <_CSparseCOOIndexPtr>(self.stp.sparse_index().get())
+ if csi != nullptr:
+ return csi.is_canonical()
+ return True
+
+cdef class SparseCSRMatrix(_Weakrefable):
+ """
+ A sparse CSR matrix.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call SparseCSRMatrix's constructor directly, "
+ "use one of the `pyarrow.SparseCSRMatrix.from_*` "
+ "functions instead.")
+
+ cdef void init(self, const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor):
+ self.sp_sparse_tensor = sp_sparse_tensor
+ self.stp = sp_sparse_tensor.get()
+ self.type = pyarrow_wrap_data_type(self.stp.type())
+
+ def __repr__(self):
+ return """<pyarrow.SparseCSRMatrix>
+type: {0.type}
+shape: {0.shape}""".format(self)
+
+ @classmethod
+ def from_dense_numpy(cls, obj, dim_names=None):
+ """
+ Convert numpy.ndarray to arrow::SparseCSRMatrix
+
+ Parameters
+ ----------
+ obj : numpy.ndarray
+ The dense numpy array that should be converted.
+ dim_names : list, optional
+ The names of the dimensions.
+ """
+ return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names))
+
+ @staticmethod
+ def from_numpy(data, indptr, indices, shape, dim_names=None):
+ """
+ Create arrow::SparseCSRMatrix from numpy.ndarrays.
+
+ Parameters
+ ----------
+ data : numpy.ndarray
+ Data used to populate the sparse matrix.
+ indptr : numpy.ndarray
+ Range of the rows.
+ The i-th row spans from `indptr[i]` to `indptr[i+1]` in the data.
+ indices : numpy.ndarray
+ Column indices of the corresponding non-zero values.
+ shape : tuple
+ Shape of the matrix.
+ dim_names : list, optional
+ Names of the dimensions.
+ """
+ cdef shared_ptr[CSparseCSRMatrix] csparse_tensor
+ cdef vector[int64_t] c_shape
+ cdef vector[c_string] c_dim_names
+
+ for x in shape:
+ c_shape.push_back(x)
+ if dim_names is not None:
+ for x in dim_names:
+ c_dim_names.push_back(tobytes(x))
+
+ # Enforce precondition for SparseCSRMatrix indices
+ indptr = np.require(indptr, dtype='i8')
+ indices = np.require(indices, dtype='i8')
+ if indptr.ndim != 1:
+ raise ValueError("Expected 1-dimensional array for "
+ "SparseCSRMatrix indptr")
+ if indices.ndim != 1:
+ raise ValueError("Expected 1-dimensional array for "
+ "SparseCSRMatrix indices")
+
+ check_status(NdarraysToSparseCSRMatrix(c_default_memory_pool(),
+ data, indptr, indices, c_shape,
+ c_dim_names, &csparse_tensor))
+ return pyarrow_wrap_sparse_csr_matrix(csparse_tensor)
+
+ @staticmethod
+ def from_scipy(obj, dim_names=None):
+ """
+ Convert scipy.sparse.csr_matrix to arrow::SparseCSRMatrix.
+
+ Parameters
+ ----------
+ obj : scipy.sparse.csr_matrix
+ The scipy matrix that should be converted.
+ dim_names : list, optional
+ Names of the dimensions.
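+
+ Examples
+ --------
+ A small illustrative example (requires scipy; values are arbitrary):
+
+ >>> import numpy as np
+ >>> import pyarrow as pa
+ >>> from scipy.sparse import csr_matrix
+ >>> m = csr_matrix(np.array([[0.0, 1.0], [2.0, 0.0]]))
+ >>> pa.SparseCSRMatrix.from_scipy(m).shape
+ (2, 2)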
+ """
+ import scipy.sparse
+ if not isinstance(obj, scipy.sparse.csr_matrix):
+ raise TypeError(
+ "Expected scipy.sparse.csr_matrix, got {}".format(type(obj)))
+
+ cdef shared_ptr[CSparseCSRMatrix] csparse_tensor
+ cdef vector[int64_t] c_shape
+ cdef vector[c_string] c_dim_names
+
+ for x in obj.shape:
+ c_shape.push_back(x)
+ if dim_names is not None:
+ for x in dim_names:
+ c_dim_names.push_back(tobytes(x))
+
+ # Enforce precondition for CSparseCSRMatrix indices
+ indptr = np.require(obj.indptr, dtype='i8')
+ indices = np.require(obj.indices, dtype='i8')
+
+ check_status(NdarraysToSparseCSRMatrix(c_default_memory_pool(),
+ obj.data, indptr, indices,
+ c_shape, c_dim_names,
+ &csparse_tensor))
+ return pyarrow_wrap_sparse_csr_matrix(csparse_tensor)
+
+ @staticmethod
+ def from_tensor(obj):
+ """
+ Convert arrow::Tensor to arrow::SparseCSRMatrix.
+
+ Parameters
+ ----------
+ obj : Tensor
+ The dense tensor that should be converted.
+ """
+ cdef shared_ptr[CSparseCSRMatrix] csparse_tensor
+ cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj)
+
+ with nogil:
+ check_status(TensorToSparseCSRMatrix(ctensor, &csparse_tensor))
+
+ return pyarrow_wrap_sparse_csr_matrix(csparse_tensor)
+
+ def to_numpy(self):
+ """
+ Convert arrow::SparseCSRMatrix to numpy.ndarrays with zero copy.
+ """
+ cdef PyObject* out_data
+ cdef PyObject* out_indptr
+ cdef PyObject* out_indices
+
+ check_status(SparseCSRMatrixToNdarray(self.sp_sparse_tensor, self,
+ &out_data, &out_indptr,
+ &out_indices))
+ return (PyObject_to_object(out_data), PyObject_to_object(out_indptr),
+ PyObject_to_object(out_indices))
+
+ def to_scipy(self):
+ """
+ Convert arrow::SparseCSRMatrix to scipy.sparse.csr_matrix.
+ """
+ from scipy.sparse import csr_matrix
+ cdef PyObject* out_data
+ cdef PyObject* out_indptr
+ cdef PyObject* out_indices
+
+ check_status(SparseCSRMatrixToNdarray(self.sp_sparse_tensor, self,
+ &out_data, &out_indptr,
+ &out_indices))
+
+ data = PyObject_to_object(out_data)
+ indptr = PyObject_to_object(out_indptr)
+ indices = PyObject_to_object(out_indices)
+ result = csr_matrix((data[:, 0], indices, indptr), shape=self.shape)
+ return result
+
+ def to_tensor(self):
+ """
+ Convert arrow::SparseCSRMatrix to arrow::Tensor.
+ """
+ cdef shared_ptr[CTensor] ctensor
+ with nogil:
+ ctensor = GetResultValue(self.stp.ToTensor())
+
+ return pyarrow_wrap_tensor(ctensor)
+
+ def equals(self, SparseCSRMatrix other):
+ """
+ Return true if the sparse tensors contain exactly equal data.
+ """
+ return self.stp.Equals(deref(other.stp))
+
+ def __eq__(self, other):
+ if isinstance(other, SparseCSRMatrix):
+ return self.equals(other)
+ else:
+ return NotImplemented
+
+ @property
+ def is_mutable(self):
+ return self.stp.is_mutable()
+
+ @property
+ def ndim(self):
+ return self.stp.ndim()
+
+ @property
+ def shape(self):
+ # Cython knows how to convert a vector[T] to a Python list
+ return tuple(self.stp.shape())
+
+ @property
+ def size(self):
+ return self.stp.size()
+
+ def dim_name(self, i):
+ return frombytes(self.stp.dim_name(i))
+
+ @property
+ def dim_names(self):
+ return tuple(frombytes(x) for x in tuple(self.stp.dim_names()))
+
+ @property
+ def non_zero_length(self):
+ return self.stp.non_zero_length()
+
+cdef class SparseCSCMatrix(_Weakrefable):
+ """
+ A sparse CSC matrix.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call SparseCSCMatrix's constructor directly, "
+ "use one of the `pyarrow.SparseCSCMatrix.from_*` "
+ "functions instead.")
+
+ cdef void init(self, const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor):
+ self.sp_sparse_tensor = sp_sparse_tensor
+ self.stp = sp_sparse_tensor.get()
+ self.type = pyarrow_wrap_data_type(self.stp.type())
+
+ def __repr__(self):
+ return """<pyarrow.SparseCSCMatrix>
+type: {0.type}
+shape: {0.shape}""".format(self)
+
+ @classmethod
+ def from_dense_numpy(cls, obj, dim_names=None):
+ """
+ Convert numpy.ndarray to arrow::SparseCSCMatrix
+ """
+ return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names))
+
+ @staticmethod
+ def from_numpy(data, indptr, indices, shape, dim_names=None):
+ """
+ Create arrow::SparseCSCMatrix from numpy.ndarrays
+
+ Parameters
+ ----------
+ data : numpy.ndarray
+ Data used to populate the sparse matrix.
+ indptr : numpy.ndarray
+ Range of the columns.
+ The i-th column spans from `indptr[i]` to `indptr[i+1]` in the data.
+ indices : numpy.ndarray
+ Row indices of the corresponding non-zero values.
+ shape : tuple
+ Shape of the matrix.
+ dim_names : list, optional
+ Names of the dimensions.
+ """
+ cdef shared_ptr[CSparseCSCMatrix] csparse_tensor
+ cdef vector[int64_t] c_shape
+ cdef vector[c_string] c_dim_names
+
+ for x in shape:
+ c_shape.push_back(x)
+ if dim_names is not None:
+ for x in dim_names:
+ c_dim_names.push_back(tobytes(x))
+
+ # Enforce precondition for SparseCSCMatrix indices
+ indptr = np.require(indptr, dtype='i8')
+ indices = np.require(indices, dtype='i8')
+ if indptr.ndim != 1:
+ raise ValueError("Expected 1-dimensional array for "
+ "SparseCSCMatrix indptr")
+ if indices.ndim != 1:
+ raise ValueError("Expected 1-dimensional array for "
+ "SparseCSCMatrix indices")
+
+ check_status(NdarraysToSparseCSCMatrix(c_default_memory_pool(),
+ data, indptr, indices, c_shape,
+ c_dim_names, &csparse_tensor))
+ return pyarrow_wrap_sparse_csc_matrix(csparse_tensor)
+
+ @staticmethod
+ def from_scipy(obj, dim_names=None):
+ """
+ Convert scipy.sparse.csc_matrix to arrow::SparseCSCMatrix
+
+ Parameters
+ ----------
+ obj : scipy.sparse.csc_matrix
+ The scipy matrix that should be converted.
+ dim_names : list, optional
+ Names of the dimensions.
+ """
+ import scipy.sparse
+ if not isinstance(obj, scipy.sparse.csc_matrix):
+ raise TypeError(
+ "Expected scipy.sparse.csc_matrix, got {}".format(type(obj)))
+
+ cdef shared_ptr[CSparseCSCMatrix] csparse_tensor
+ cdef vector[int64_t] c_shape
+ cdef vector[c_string] c_dim_names
+
+ for x in obj.shape:
+ c_shape.push_back(x)
+ if dim_names is not None:
+ for x in dim_names:
+ c_dim_names.push_back(tobytes(x))
+
+ # Enforce precondition for CSparseCSCMatrix indices
+ indptr = np.require(obj.indptr, dtype='i8')
+ indices = np.require(obj.indices, dtype='i8')
+
+ check_status(NdarraysToSparseCSCMatrix(c_default_memory_pool(),
+ obj.data, indptr, indices,
+ c_shape, c_dim_names,
+ &csparse_tensor))
+ return pyarrow_wrap_sparse_csc_matrix(csparse_tensor)
+
+ @staticmethod
+ def from_tensor(obj):
+ """
+ Convert arrow::Tensor to arrow::SparseCSCMatrix
+
+ Parameters
+ ----------
+ obj : Tensor
+ The dense tensor that should be converted.
+ """
+ cdef shared_ptr[CSparseCSCMatrix] csparse_tensor
+ cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj)
+
+ with nogil:
+ check_status(TensorToSparseCSCMatrix(ctensor, &csparse_tensor))
+
+ return pyarrow_wrap_sparse_csc_matrix(csparse_tensor)
+
+ def to_numpy(self):
+ """
+ Convert arrow::SparseCSCMatrix to numpy.ndarrays with zero copy
+ """
+ cdef PyObject* out_data
+ cdef PyObject* out_indptr
+ cdef PyObject* out_indices
+
+ check_status(SparseCSCMatrixToNdarray(self.sp_sparse_tensor, self,
+ &out_data, &out_indptr,
+ &out_indices))
+ return (PyObject_to_object(out_data), PyObject_to_object(out_indptr),
+ PyObject_to_object(out_indices))
+
+ def to_scipy(self):
+ """
+ Convert arrow::SparseCSCMatrix to scipy.sparse.csc_matrix
+ """
+ from scipy.sparse import csc_matrix
+ cdef PyObject* out_data
+ cdef PyObject* out_indptr
+ cdef PyObject* out_indices
+
+ check_status(SparseCSCMatrixToNdarray(self.sp_sparse_tensor, self,
+ &out_data, &out_indptr,
+ &out_indices))
+
+ data = PyObject_to_object(out_data)
+ indptr = PyObject_to_object(out_indptr)
+ indices = PyObject_to_object(out_indices)
+ result = csc_matrix((data[:, 0], indices, indptr), shape=self.shape)
+ return result
+
+ def to_tensor(self):
+ """
+ Convert arrow::SparseCSCMatrix to arrow::Tensor
+ """
+
+ cdef shared_ptr[CTensor] ctensor
+ with nogil:
+ ctensor = GetResultValue(self.stp.ToTensor())
+
+ return pyarrow_wrap_tensor(ctensor)
+
+ def equals(self, SparseCSCMatrix other):
+ """
+ Return true if the sparse tensors contain exactly equal data.
+ """
+ return self.stp.Equals(deref(other.stp))
+
+ def __eq__(self, other):
+ if isinstance(other, SparseCSCMatrix):
+ return self.equals(other)
+ else:
+ return NotImplemented
+
+ @property
+ def is_mutable(self):
+ return self.stp.is_mutable()
+
+ @property
+ def ndim(self):
+ return self.stp.ndim()
+
+ @property
+ def shape(self):
+ # Cython knows how to convert a vector[T] to a Python list
+ return tuple(self.stp.shape())
+
+ @property
+ def size(self):
+ return self.stp.size()
+
+ def dim_name(self, i):
+ return frombytes(self.stp.dim_name(i))
+
+ @property
+ def dim_names(self):
+ return tuple(frombytes(x) for x in tuple(self.stp.dim_names()))
+
+ @property
+ def non_zero_length(self):
+ return self.stp.non_zero_length()
+
+
+cdef class SparseCSFTensor(_Weakrefable):
+ """
+ A sparse CSF tensor.
+
+ CSF is a generalization of the compressed sparse row (CSR) index.
+
+ The CSF index recursively compresses each dimension of a tensor into a
+ set of prefix trees. Each path from a root to a leaf forms one tensor
+ non-zero index. CSF is implemented with two sets of buffers and one
+ array of integers.
+ """
+
+ def __init__(self):
+ raise TypeError("Do not call SparseCSFTensor's constructor directly, "
+ "use one of the `pyarrow.SparseCSFTensor.from_*` "
+ "functions instead.")
+
+ cdef void init(self, const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor):
+ self.sp_sparse_tensor = sp_sparse_tensor
+ self.stp = sp_sparse_tensor.get()
+ self.type = pyarrow_wrap_data_type(self.stp.type())
+
+ def __repr__(self):
+ return """<pyarrow.SparseCSFTensor>
+type: {0.type}
+shape: {0.shape}""".format(self)
+
+ @classmethod
+ def from_dense_numpy(cls, obj, dim_names=None):
+ """
+ Convert numpy.ndarray to arrow::SparseCSFTensor
+ """
+ return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names))
+
+ @staticmethod
+ def from_numpy(data, indptr, indices, shape, axis_order=None,
+ dim_names=None):
+ """
+ Create arrow::SparseCSFTensor from numpy.ndarrays
+
+ Parameters
+ ----------
+ data : numpy.ndarray
+ Data used to populate the sparse tensor.
+ indptr : list of numpy.ndarray
+ The sparsity structure.
+ Each pair of consecutive dimensions in the tensor corresponds to
+ a buffer in indptr.
+ A pair of consecutive values `indptr[dim][i]` and
+ `indptr[dim][i + 1]` signifies a range of nodes in
+ `indices[dim + 1]` that are children of the `indices[dim][i]` node.
+ indices : list of numpy.ndarray
+ Stores the values of the nodes.
+ Each tensor dimension corresponds to a buffer in indices.
+ shape : tuple
+ Shape of the tensor.
+ axis_order : list, optional
+ The sequence in which dimensions were traversed to
+ produce the prefix tree.
+ dim_names : list, optional
+ Names of the dimensions.
+ """
+ cdef shared_ptr[CSparseCSFTensor] csparse_tensor
+ cdef vector[int64_t] c_axis_order
+ cdef vector[int64_t] c_shape
+ cdef vector[c_string] c_dim_names
+
+ for x in shape:
+ c_shape.push_back(x)
+ if not axis_order:
+ axis_order = np.argsort(shape)
+ for x in axis_order:
+ c_axis_order.push_back(x)
+ if dim_names is not None:
+ for x in dim_names:
+ c_dim_names.push_back(tobytes(x))
+
+ # Enforce preconditions for SparseCSFTensor indices
+ if not (isinstance(indptr, (list, tuple)) and
+ isinstance(indices, (list, tuple))):
+ raise TypeError("Expected list or tuple, got {}, {}"
+ .format(type(indptr), type(indices)))
+ if len(indptr) != len(shape) - 1:
+ raise ValueError("Expected list of {ndim} np.arrays for "
+ "SparseCSFTensor.indptr".format(ndim=len(shape) - 1))
+ if len(indices) != len(shape):
+ raise ValueError("Expected list of {ndim} np.arrays for "
+ "SparseCSFTensor.indices".format(ndim=len(shape)))
+ if any([x.ndim != 1 for x in indptr]):
+ raise ValueError("Expected a list of 1-dimensional arrays for "
+ "SparseCSFTensor.indptr")
+ if any([x.ndim != 1 for x in indices]):
+ raise ValueError("Expected a list of 1-dimensional arrays for "
+ "SparseCSFTensor.indices")
+ indptr = [np.require(arr, dtype='i8') for arr in indptr]
+ indices = [np.require(arr, dtype='i8') for arr in indices]
+
+ check_status(NdarraysToSparseCSFTensor(c_default_memory_pool(), data,
+ indptr, indices, c_shape,
+ c_axis_order, c_dim_names,
+ &csparse_tensor))
+ return pyarrow_wrap_sparse_csf_tensor(csparse_tensor)
+
+ @staticmethod
+ def from_tensor(obj):
+ """
+ Convert arrow::Tensor to arrow::SparseCSFTensor
+
+ Parameters
+ ----------
+ obj : Tensor
+ The dense tensor that should be converted.
+ """
+ cdef shared_ptr[CSparseCSFTensor] csparse_tensor
+ cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj)
+
+ with nogil:
+ check_status(TensorToSparseCSFTensor(ctensor, &csparse_tensor))
+
+ return pyarrow_wrap_sparse_csf_tensor(csparse_tensor)
+
+ def to_numpy(self):
+ """
+ Convert arrow::SparseCSFTensor to numpy.ndarrays with zero copy
+ """
+ cdef PyObject* out_data
+ cdef PyObject* out_indptr
+ cdef PyObject* out_indices
+
+ check_status(SparseCSFTensorToNdarray(self.sp_sparse_tensor, self,
+ &out_data, &out_indptr,
+ &out_indices))
+ return (PyObject_to_object(out_data), PyObject_to_object(out_indptr),
+ PyObject_to_object(out_indices))
+
+ def to_tensor(self):
+ """
+ Convert arrow::SparseCSFTensor to arrow::Tensor
+ """
+
+ cdef shared_ptr[CTensor] ctensor
+ with nogil:
+ ctensor = GetResultValue(self.stp.ToTensor())
+
+ return pyarrow_wrap_tensor(ctensor)
+
+ def equals(self, SparseCSFTensor other):
+ """
+ Return true if the sparse tensors contain exactly equal data
+ """
+ return self.stp.Equals(deref(other.stp))
+
+ def __eq__(self, other):
+ if isinstance(other, SparseCSFTensor):
+ return self.equals(other)
+ else:
+ return NotImplemented
+
+ @property
+ def is_mutable(self):
+ return self.stp.is_mutable()
+
+ @property
+ def ndim(self):
+ return self.stp.ndim()
+
+ @property
+ def shape(self):
+ # Cython knows how to convert a vector[T] to a Python list
+ return tuple(self.stp.shape())
+
+ @property
+ def size(self):
+ return self.stp.size()
+
+ def dim_name(self, i):
+ return frombytes(self.stp.dim_name(i))
+
+ @property
+ def dim_names(self):
+ return tuple(frombytes(x) for x in tuple(self.stp.dim_names()))
+
+ @property
+ def non_zero_length(self):
+ return self.stp.non_zero_length()
diff --git a/src/arrow/python/pyarrow/tensorflow/plasma_op.cc b/src/arrow/python/pyarrow/tensorflow/plasma_op.cc
new file mode 100644
index 000000000..bf4eec789
--- /dev/null
+++ b/src/arrow/python/pyarrow/tensorflow/plasma_op.cc
@@ -0,0 +1,391 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/stream.h"
+
+#ifdef GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#endif
+
+#include "arrow/adapters/tensorflow/convert.h"
+#include "arrow/api.h"
+#include "arrow/io/api.h"
+#include "arrow/util/logging.h"
+
+// These headers do not include Python.h
+#include "arrow/python/deserialize.h"
+#include "arrow/python/serialize.h"
+
+#include "plasma/client.h"
+
+namespace tf = tensorflow;
+
+using ArrowStatus = arrow::Status;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+using Event = perftools::gputools::Event;
+using Stream = perftools::gputools::Stream;
+
+// NOTE(zongheng): for some reason using unique_ptr or shared_ptr results in
+// CUDA_ERROR_DEINITIALIZED on program exit. I suspect this is because the
+// static object's dtor gets called *after* TensorFlow's own CUDA cleanup.
+// Instead, we use a raw pointer here and manually clean up in the Ops' dtors.
+static Stream* d2h_stream = nullptr;
+static tf::mutex d2h_stream_mu;
+
+// TODO(zongheng): CPU kernels' std::memcpy might be able to be sped up by
+// parallelization.
+
+int64_t get_byte_width(const arrow::DataType& dtype) {
+ return arrow::internal::checked_cast<const arrow::FixedWidthType&>(dtype)
+ .bit_width() / CHAR_BIT;
+}
+
+// Put: tf.Tensor -> plasma.
+template <typename Device>
+class TensorToPlasmaOp : public tf::AsyncOpKernel {
+ public:
+ explicit TensorToPlasmaOp(tf::OpKernelConstruction* context) : tf::AsyncOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("plasma_store_socket_name",
+ &plasma_store_socket_name_));
+ tf::mutex_lock lock(mu_);
+ if (!connected_) {
+ VLOG(1) << "Connecting to Plasma...";
+ ARROW_CHECK_OK(client_.Connect(plasma_store_socket_name_));
+ VLOG(1) << "Connected!";
+ connected_ = true;
+ }
+ }
+
+ ~TensorToPlasmaOp() override {
+ {
+ tf::mutex_lock lock(mu_);
+ ARROW_CHECK_OK(client_.Disconnect());
+ connected_ = false;
+ }
+ {
+ tf::mutex_lock lock(d2h_stream_mu);
+ if (d2h_stream != nullptr) {
+ delete d2h_stream;
+ }
+ }
+ }
+
+ void ComputeAsync(tf::OpKernelContext* context, DoneCallback done) override {
+ const int num_inputs = context->num_inputs();
+ OP_REQUIRES_ASYNC(
+ context, num_inputs >= 2,
+ tf::errors::InvalidArgument("Input should have at least 1 tensor and 1 object_id"),
+ done);
+ const int num_tensors = num_inputs - 1;
+
+ // Check that all tensors have the same dtype
+ tf::DataType tf_dtype = context->input(0).dtype();
+ for (int i = 1; i < num_inputs - 1; i++) {
+ if (tf_dtype != context->input(i).dtype()) {
+ ARROW_CHECK_OK(arrow::Status(arrow::StatusCode::TypeError,
+ "All input tensors must have the same data type"));
+ }
+ }
+
+ std::shared_ptr<arrow::DataType> arrow_dtype;
+ ARROW_CHECK_OK(arrow::adapters::tensorflow::GetArrowType(tf_dtype, &arrow_dtype));
+ int64_t byte_width = get_byte_width(*arrow_dtype);
+
+ std::vector<size_t> offsets;
+ offsets.reserve(num_tensors + 1);
+ offsets.push_back(0);
+ int64_t total_bytes = 0;
+ for (int i = 0; i < num_tensors; ++i) {
+ const size_t s = context->input(i).TotalBytes();
+ CHECK_EQ(s, context->input(i).NumElements() * byte_width);
+ CHECK_GT(s, 0);
+ total_bytes += s;
+ offsets.push_back(total_bytes);
+ }
+
+ const tf::Tensor& plasma_object_id = context->input(num_inputs - 1);
+ CHECK_EQ(plasma_object_id.NumElements(), 1);
+ const std::string& plasma_object_id_str = plasma_object_id.flat<std::string>()(0);
+ VLOG(1) << "plasma_object_id_str: '" << plasma_object_id_str << "'";
+ const plasma::ObjectID object_id =
+ plasma::ObjectID::from_binary(plasma_object_id_str);
+
+ std::vector<int64_t> shape = {total_bytes / byte_width};
+
+ arrow::io::MockOutputStream mock;
+ ARROW_CHECK_OK(arrow::py::WriteNdarrayHeader(arrow_dtype, shape, 0, &mock));
+ int64_t header_size = mock.GetExtentBytesWritten();
+
+ std::shared_ptr<Buffer> data_buffer;
+ {
+ tf::mutex_lock lock(mu_);
+ ARROW_CHECK_OK(client_.Create(object_id, header_size + total_bytes,
+ /*metadata=*/nullptr, 0, &data_buffer));
+ }
+
+ int64_t offset;
+ arrow::io::FixedSizeBufferWriter buf(data_buffer);
+ ARROW_CHECK_OK(arrow::py::WriteNdarrayHeader(arrow_dtype, shape, total_bytes, &buf));
+ ARROW_CHECK_OK(buf.Tell(&offset));
+
+ uint8_t* data = reinterpret_cast<uint8_t*>(data_buffer->mutable_data() + offset);
+
+ auto wrapped_callback = [this, context, done, data_buffer, data, object_id]() {
+ {
+ tf::mutex_lock lock(mu_);
+ ARROW_CHECK_OK(client_.Seal(object_id));
+ ARROW_CHECK_OK(client_.Release(object_id));
+#ifdef GOOGLE_CUDA
+ auto orig_stream = context->op_device_context()->stream();
+ auto stream_executor = orig_stream->parent();
+ CHECK(stream_executor->HostMemoryUnregister(static_cast<void*>(data)));
+#endif
+ }
+ context->SetStatus(tensorflow::Status::OK());
+ done();
+ };
+
+ if (std::is_same<Device, CPUDevice>::value) {
+ for (int i = 0; i < num_tensors; ++i) {
+ const auto& input_tensor = context->input(i);
+ std::memcpy(static_cast<void*>(data + offsets[i]),
+ input_tensor.tensor_data().data(),
+ static_cast<tf::uint64>(offsets[i + 1] - offsets[i]));
+ }
+ wrapped_callback();
+ } else {
+#ifdef GOOGLE_CUDA
+ auto orig_stream = context->op_device_context()->stream();
+ OP_REQUIRES_ASYNC(context, orig_stream != nullptr,
+ tf::errors::Internal("No GPU stream available."), done);
+ auto stream_executor = orig_stream->parent();
+
+ // NOTE(zongheng): this is critical for getting good performance out of D2H
+ // async memcpy. Under the hood it performs cuMemHostRegister(), see:
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gf0a9fe11544326dabd743b7aa6b54223
+ CHECK(stream_executor->HostMemoryRegister(static_cast<void*>(data),
+ static_cast<tf::uint64>(total_bytes)));
+
+ {
+ tf::mutex_lock l(d2h_stream_mu);
+ if (d2h_stream == nullptr) {
+ d2h_stream = new Stream(stream_executor);
+ CHECK(d2h_stream->Init().ok());
+ }
+ }
+
+ // Needed to make sure the input buffers have been computed.
+ // NOTE(ekl): this is unnecessary when the op is behind a NCCL allreduce already
+ CHECK(d2h_stream->ThenWaitFor(orig_stream).ok());
+
+ for (int i = 0; i < num_tensors; ++i) {
+ const auto& input_tensor = context->input(i);
+ auto input_buffer = const_cast<char*>(input_tensor.tensor_data().data());
+ perftools::gputools::DeviceMemoryBase wrapped_src(
+ static_cast<void*>(input_buffer));
+ const bool success =
+ d2h_stream
+ ->ThenMemcpy(static_cast<void*>(data + offsets[i]), wrapped_src,
+ static_cast<tf::uint64>(offsets[i + 1] - offsets[i]))
+ .ok();
+ OP_REQUIRES_ASYNC(context, success,
+ tf::errors::Internal("D2H memcpy failed to be enqueued."), done);
+ }
+ context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+ d2h_stream, std::move(wrapped_callback));
+#endif
+ }
+ }
+
+ private:
+ std::string plasma_store_socket_name_;
+
+ tf::mutex mu_;
+ bool connected_ = false;
+ plasma::PlasmaClient client_ GUARDED_BY(mu_);
+};
+
+static Stream* h2d_stream = nullptr;
+static tf::mutex h2d_stream_mu;
+
+// Get: plasma -> tf.Tensor.
+template <typename Device>
+class PlasmaToTensorOp : public tf::AsyncOpKernel {
+ public:
+ explicit PlasmaToTensorOp(tf::OpKernelConstruction* context) : tf::AsyncOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("plasma_store_socket_name",
+ &plasma_store_socket_name_));
+ tf::mutex_lock lock(mu_);
+ if (!connected_) {
+ VLOG(1) << "Connecting to Plasma...";
+ ARROW_CHECK_OK(client_.Connect(plasma_store_socket_name_));
+ VLOG(1) << "Connected!";
+ connected_ = true;
+ }
+ }
+
+ ~PlasmaToTensorOp() override {
+ {
+ tf::mutex_lock lock(mu_);
+ ARROW_CHECK_OK(client_.Disconnect());
+ connected_ = false;
+ }
+ {
+ tf::mutex_lock lock(h2d_stream_mu);
+ if (h2d_stream != nullptr) {
+ delete h2d_stream;
+ }
+ }
+ }
+
+ void ComputeAsync(tf::OpKernelContext* context, DoneCallback done) override {
+ const tf::Tensor& plasma_object_id = context->input(0);
+ CHECK_EQ(plasma_object_id.NumElements(), 1);
+ const std::string& plasma_object_id_str = plasma_object_id.flat<std::string>()(0);
+
+ VLOG(1) << "plasma_object_id_str: '" << plasma_object_id_str << "'";
+ const plasma::ObjectID object_id =
+ plasma::ObjectID::from_binary(plasma_object_id_str);
+
+ plasma::ObjectBuffer object_buffer;
+ {
+ tf::mutex_lock lock(mu_);
+ // NOTE(zongheng): this is a blocking call. We might want to (1) make
+ // Plasma asynchronous, (2) launch a thread / event here ourselves, or
+ // something like that...
+ ARROW_CHECK_OK(client_.Get(&object_id, /*num_objects=*/1,
+ /*timeout_ms=*/-1, &object_buffer));
+ }
+
+ std::shared_ptr<arrow::Tensor> ndarray;
+ ARROW_CHECK_OK(arrow::py::NdarrayFromBuffer(object_buffer.data, &ndarray));
+
+ int64_t byte_width = get_byte_width(*ndarray->type());
+ const int64_t size_in_bytes = ndarray->data()->size();
+
+ tf::TensorShape shape({static_cast<int64_t>(size_in_bytes / byte_width)});
+
+ const float* plasma_data = reinterpret_cast<const float*>(ndarray->raw_data());
+
+ tf::Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, shape, &output_tensor),
+ done);
+
+ auto wrapped_callback = [this, context, done, plasma_data, object_id]() {
+ {
+ tf::mutex_lock lock(mu_);
+ ARROW_CHECK_OK(client_.Release(object_id));
+#ifdef GOOGLE_CUDA
+ auto orig_stream = context->op_device_context()->stream();
+ auto stream_executor = orig_stream->parent();
+ CHECK(stream_executor->HostMemoryUnregister(
+ const_cast<void*>(static_cast<const void*>(plasma_data))));
+#endif
+ }
+ done();
+ };
+
+ if (std::is_same<Device, CPUDevice>::value) {
+ std::memcpy(
+ reinterpret_cast<void*>(const_cast<char*>(output_tensor->tensor_data().data())),
+ plasma_data, size_in_bytes);
+ wrapped_callback();
+ } else {
+#ifdef GOOGLE_CUDA
+ auto orig_stream = context->op_device_context()->stream();
+ OP_REQUIRES_ASYNC(context, orig_stream != nullptr,
+ tf::errors::Internal("No GPU stream available."), done);
+ auto stream_executor = orig_stream->parent();
+
+ {
+ tf::mutex_lock l(h2d_stream_mu);
+ if (h2d_stream == nullptr) {
+ h2d_stream = new Stream(stream_executor);
+ CHECK(h2d_stream->Init().ok());
+ }
+ }
+
+ // Important. See note in T2P op.
+ CHECK(stream_executor->HostMemoryRegister(
+ const_cast<void*>(static_cast<const void*>(plasma_data)),
+ static_cast<tf::uint64>(size_in_bytes)));
+
+ perftools::gputools::DeviceMemoryBase wrapped_dst(
+ reinterpret_cast<void*>(const_cast<char*>(output_tensor->tensor_data().data())));
+ const bool success =
+ h2d_stream
+ ->ThenMemcpy(&wrapped_dst, static_cast<const void*>(plasma_data),
+ static_cast<tf::uint64>(size_in_bytes))
+ .ok();
+ OP_REQUIRES_ASYNC(context, success,
+ tf::errors::Internal("H2D memcpy failed to be enqueued."), done);
+
+ // Without this sync the main compute stream might proceed to use the
+ // Tensor buffer, but its contents might still be in-flight from our
+ // h2d_stream.
+ CHECK(orig_stream->ThenWaitFor(h2d_stream).ok());
+
+ context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+ h2d_stream, std::move(wrapped_callback));
+#endif
+ }
+ }
+
+ private:
+ std::string plasma_store_socket_name_;
+
+ tf::mutex mu_;
+ bool connected_ = false;
+ plasma::PlasmaClient client_ GUARDED_BY(mu_);
+};
+
+REGISTER_OP("TensorToPlasma")
+ .Input("input_tensor: dtypes")
+ .Input("plasma_object_id: string")
+ .Attr("dtypes: list(type)")
+ .Attr("plasma_store_socket_name: string");
+
+REGISTER_KERNEL_BUILDER(Name("TensorToPlasma").Device(tf::DEVICE_CPU),
+ TensorToPlasmaOp<CPUDevice>);
+#ifdef GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("TensorToPlasma").Device(tf::DEVICE_GPU),
+ TensorToPlasmaOp<GPUDevice>);
+#endif
+
+REGISTER_OP("PlasmaToTensor")
+ .Input("plasma_object_id: string")
+ .Output("tensor: dtype")
+ .Attr("dtype: type")
+ .Attr("plasma_store_socket_name: string");
+
+REGISTER_KERNEL_BUILDER(Name("PlasmaToTensor").Device(tf::DEVICE_CPU),
+ PlasmaToTensorOp<CPUDevice>);
+#ifdef GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("PlasmaToTensor").Device(tf::DEVICE_GPU),
+ PlasmaToTensorOp<GPUDevice>);
+#endif
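
For orientation, these kernels are normally driven from Python through TensorFlow's custom-op loading machinery. The sketch below is illustrative only: the shared-library name, the Plasma socket path and the TF 1.x graph-mode style are assumptions, not part of this patch.

# Minimal usage sketch (assumptions: the kernels above were built into
# "plasma_op.so", a plasma store listens on /tmp/plasma, TF 1.x graph mode).
import numpy as np
import tensorflow as tf

ops = tf.load_op_library("plasma_op.so")

object_id = np.random.bytes(20)   # ObjectID::from_binary expects 20 raw bytes
data = tf.constant(np.arange(100, dtype=np.float32))

# TensorToPlasma takes a list of tensors plus a string scalar holding the id.
write_op = ops.tensor_to_plasma([data], object_id,
                                plasma_store_socket_name="/tmp/plasma")
with tf.control_dependencies([write_op]):
    restored = ops.plasma_to_tensor(object_id, dtype=tf.float32,
                                    plasma_store_socket_name="/tmp/plasma")

with tf.Session() as sess:
    print(sess.run(restored))     # flattened copy of the original values
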
diff --git a/src/arrow/python/pyarrow/tests/__init__.py b/src/arrow/python/pyarrow/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/__init__.py
diff --git a/src/arrow/python/pyarrow/tests/arrow_7980.py b/src/arrow/python/pyarrow/tests/arrow_7980.py
new file mode 100644
index 000000000..c1bc3176d
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/arrow_7980.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file is called from a test in test_schema.py.
+
+import pyarrow as pa
+
+
+# The types for which to_pandas_dtype returns a non-numpy dtype
+cases = [
+ (pa.timestamp('ns', tz='UTC'), "datetime64[ns, UTC]"),
+]
+
+
+for arrow_type, pandas_type in cases:
+ assert str(arrow_type.to_pandas_dtype()) == pandas_type
diff --git a/src/arrow/python/pyarrow/tests/bound_function_visit_strings.pyx b/src/arrow/python/pyarrow/tests/bound_function_visit_strings.pyx
new file mode 100644
index 000000000..90437be8c
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/bound_function_visit_strings.pyx
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language=c++
+# cython: language_level = 3
+
+import pyarrow as pa
+from pyarrow.lib cimport *
+from pyarrow.lib import frombytes, tobytes
+
+# basic test to roundtrip through a BoundFunction
+
+ctypedef CStatus visit_string_cb(const c_string&)
+
+cdef extern from * namespace "arrow::py" nogil:
+ """
+ #include <functional>
+ #include <string>
+ #include <vector>
+
+ #include "arrow/status.h"
+
+ namespace arrow {
+ namespace py {
+
+ Status VisitStrings(const std::vector<std::string>& strs,
+ std::function<Status(const std::string&)> cb) {
+ for (const std::string& str : strs) {
+ RETURN_NOT_OK(cb(str));
+ }
+ return Status::OK();
+ }
+
+ } // namespace py
+ } // namespace arrow
+ """
+ cdef CStatus CVisitStrings" arrow::py::VisitStrings"(
+ vector[c_string], function[visit_string_cb])
+
+
+cdef void _visit_strings_impl(py_cb, const c_string& s) except *:
+ py_cb(frombytes(s))
+
+
+def _visit_strings(strings, cb):
+ cdef:
+ function[visit_string_cb] c_cb
+ vector[c_string] c_strings
+
+ c_cb = BindFunction[visit_string_cb](&_visit_strings_impl, cb)
+ for s in strings:
+ c_strings.push_back(tobytes(s))
+
+ check_status(CVisitStrings(c_strings, c_cb))
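
A rough idea of how this module is exercised once the extension has been compiled (the Cython build step lives in the test utilities and is not shown here; the bare module name below assumes an in-place build):

# Hypothetical driver for the extension defined above.
from bound_function_visit_strings import _visit_strings

seen = []
_visit_strings(["a", "b", "c"], seen.append)   # VisitStrings calls back into Python
assert seen == ["a", "b", "c"]
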
diff --git a/src/arrow/python/pyarrow/tests/conftest.py b/src/arrow/python/pyarrow/tests/conftest.py
new file mode 100644
index 000000000..8fa520b93
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/conftest.py
@@ -0,0 +1,302 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import pathlib
+import subprocess
+from tempfile import TemporaryDirectory
+
+import pytest
+import hypothesis as h
+
+from pyarrow.util import find_free_port
+from pyarrow import Codec
+
+
+# setup hypothesis profiles
+h.settings.register_profile('ci', max_examples=1000)
+h.settings.register_profile('dev', max_examples=50)
+h.settings.register_profile('debug', max_examples=10,
+ verbosity=h.Verbosity.verbose)
+
+# Load the default hypothesis profile. Either set the HYPOTHESIS_PROFILE
+# environment variable or pass the --hypothesis-profile option to pytest.
+# To see the generated examples, try:
+# pytest pyarrow -sv --enable-hypothesis --hypothesis-profile=debug
+h.settings.load_profile(os.environ.get('HYPOTHESIS_PROFILE', 'dev'))
+
+# Set this before the AWS SDK is loaded, to avoid reading in user
+# configuration values.
+os.environ['AWS_CONFIG_FILE'] = "/dev/null"
+
+
+groups = [
+ 'brotli',
+ 'bz2',
+ 'cython',
+ 'dataset',
+ 'hypothesis',
+ 'fastparquet',
+ 'gandiva',
+ 'gzip',
+ 'hdfs',
+ 'large_memory',
+ 'lz4',
+ 'memory_leak',
+ 'nopandas',
+ 'orc',
+ 'pandas',
+ 'parquet',
+ 'plasma',
+ 's3',
+ 'snappy',
+ 'tensorflow',
+ 'flight',
+ 'slow',
+ 'requires_testing_data',
+ 'zstd',
+]
+
+defaults = {
+ 'brotli': Codec.is_available('brotli'),
+ 'bz2': Codec.is_available('bz2'),
+ 'cython': False,
+ 'dataset': False,
+ 'fastparquet': False,
+ 'hypothesis': False,
+ 'gandiva': False,
+ 'gzip': Codec.is_available('gzip'),
+ 'hdfs': False,
+ 'large_memory': False,
+ 'lz4': Codec.is_available('lz4'),
+ 'memory_leak': False,
+ 'orc': False,
+ 'nopandas': False,
+ 'pandas': False,
+ 'parquet': False,
+ 'plasma': False,
+ 's3': False,
+ 'snappy': Codec.is_available('snappy'),
+ 'tensorflow': False,
+ 'flight': False,
+ 'slow': False,
+ 'requires_testing_data': True,
+ 'zstd': Codec.is_available('zstd'),
+}
+
+try:
+ import cython # noqa
+ defaults['cython'] = True
+except ImportError:
+ pass
+
+try:
+ import fastparquet # noqa
+ defaults['fastparquet'] = True
+except ImportError:
+ pass
+
+try:
+ import pyarrow.gandiva # noqa
+ defaults['gandiva'] = True
+except ImportError:
+ pass
+
+try:
+ import pyarrow.dataset # noqa
+ defaults['dataset'] = True
+except ImportError:
+ pass
+
+try:
+ import pyarrow.orc # noqa
+ defaults['orc'] = True
+except ImportError:
+ pass
+
+try:
+ import pandas # noqa
+ defaults['pandas'] = True
+except ImportError:
+ defaults['nopandas'] = True
+
+try:
+ import pyarrow.parquet # noqa
+ defaults['parquet'] = True
+except ImportError:
+ pass
+
+try:
+ import pyarrow.plasma # noqa
+ defaults['plasma'] = True
+except ImportError:
+ pass
+
+try:
+ import tensorflow # noqa
+ defaults['tensorflow'] = True
+except ImportError:
+ pass
+
+try:
+ import pyarrow.flight # noqa
+ defaults['flight'] = True
+except ImportError:
+ pass
+
+try:
+ from pyarrow.fs import S3FileSystem # noqa
+ defaults['s3'] = True
+except ImportError:
+ pass
+
+try:
+ from pyarrow.fs import HadoopFileSystem # noqa
+ defaults['hdfs'] = True
+except ImportError:
+ pass
+
+
+def pytest_addoption(parser):
+ # Create options to selectively enable test groups
+ def bool_env(name, default=None):
+ value = os.environ.get(name.upper())
+ if value is None:
+ return default
+ value = value.lower()
+ if value in {'1', 'true', 'on', 'yes', 'y'}:
+ return True
+ elif value in {'0', 'false', 'off', 'no', 'n'}:
+ return False
+ else:
+ raise ValueError('{}={} is not parsable as boolean'
+ .format(name.upper(), value))
+
+ for group in groups:
+ default = bool_env('PYARROW_TEST_{}'.format(group), defaults[group])
+ parser.addoption('--enable-{}'.format(group),
+ action='store_true', default=default,
+ help=('Enable the {} test group'.format(group)))
+ parser.addoption('--disable-{}'.format(group),
+ action='store_true', default=False,
+ help=('Disable the {} test group'.format(group)))
+
+
+class PyArrowConfig:
+ def __init__(self):
+ self.is_enabled = {}
+
+ def apply_mark(self, mark):
+ group = mark.name
+ if group in groups:
+ self.requires(group)
+
+ def requires(self, group):
+ if not self.is_enabled[group]:
+ pytest.skip('{} NOT enabled'.format(group))
+
+
+def pytest_configure(config):
+ # Apply command-line options to initialize PyArrow-specific config object
+ config.pyarrow = PyArrowConfig()
+
+ for mark in groups:
+ config.addinivalue_line(
+ "markers", mark,
+ )
+
+ enable_flag = '--enable-{}'.format(mark)
+ disable_flag = '--disable-{}'.format(mark)
+
+ is_enabled = (config.getoption(enable_flag) and not
+ config.getoption(disable_flag))
+ config.pyarrow.is_enabled[mark] = is_enabled
+
+
+def pytest_runtest_setup(item):
+ # Apply test markers to skip tests selectively
+ for mark in item.iter_markers():
+ item.config.pyarrow.apply_mark(mark)
+
+
+@pytest.fixture
+def tempdir(tmpdir):
+ # convert pytest's LocalPath to pathlib.Path
+ return pathlib.Path(tmpdir.strpath)
+
+
+@pytest.fixture(scope='session')
+def base_datadir():
+ return pathlib.Path(__file__).parent / 'data'
+
+
+@pytest.fixture(autouse=True)
+def disable_aws_metadata(monkeypatch):
+ """Stop the AWS SDK from trying to contact the EC2 metadata server.
+
+ Otherwise, this causes a 5 second delay in tests that exercise the
+ S3 filesystem.
+ """
+ monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
+
+
+# TODO(kszucs): move the following fixtures to test_fs.py once the previous
+# parquet dataset implementation and hdfs implementation are removed.
+
+@pytest.fixture(scope='session')
+def hdfs_connection():
+ host = os.environ.get('ARROW_HDFS_TEST_HOST', 'default')
+ port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0))
+ user = os.environ.get('ARROW_HDFS_TEST_USER', 'hdfs')
+ return host, port, user
+
+
+@pytest.fixture(scope='session')
+def s3_connection():
+ host, port = 'localhost', find_free_port()
+ access_key, secret_key = 'arrow', 'apachearrow'
+ return host, port, access_key, secret_key
+
+
+@pytest.fixture(scope='session')
+def s3_server(s3_connection):
+ host, port, access_key, secret_key = s3_connection
+
+ address = '{}:{}'.format(host, port)
+ env = os.environ.copy()
+ env.update({
+ 'MINIO_ACCESS_KEY': access_key,
+ 'MINIO_SECRET_KEY': secret_key
+ })
+
+ with TemporaryDirectory() as tempdir:
+ args = ['minio', '--compat', 'server', '--quiet', '--address',
+ address, tempdir]
+ proc = None
+ try:
+ proc = subprocess.Popen(args, env=env)
+ except OSError:
+ pytest.skip('`minio` command cannot be located')
+ else:
+ yield {
+ 'connection': s3_connection,
+ 'process': proc,
+ 'tempdir': tempdir
+ }
+ finally:
+ if proc is not None:
+ proc.kill()
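
Taken together, pytest_addoption, pytest_configure and pytest_runtest_setup mean that a test opts into a group simply by carrying the matching marker, and the group must then be enabled via the corresponding flag or PYARROW_TEST_* variable. A small illustration (the test names are made up):

import pytest

@pytest.mark.parquet      # skipped unless --enable-parquet or PYARROW_TEST_PARQUET=1
def test_needs_parquet():
    import pyarrow.parquet as pq
    assert pq is not None

@pytest.mark.slow         # disabled by default; run with --enable-slow
def test_expensive_path():
    pass
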
diff --git a/src/arrow/python/pyarrow/tests/data/feather/v0.17.0.version=2-compression=lz4.feather b/src/arrow/python/pyarrow/tests/data/feather/v0.17.0.version=2-compression=lz4.feather
new file mode 100644
index 000000000..562b0b2c5
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/feather/v0.17.0.version=2-compression=lz4.feather
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/orc/README.md b/src/arrow/python/pyarrow/tests/data/orc/README.md
new file mode 100644
index 000000000..ccbb0e8b1
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/orc/README.md
@@ -0,0 +1,22 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+The ORC and JSON files come from the `examples` directory in the Apache ORC
+source tree:
+https://github.com/apache/orc/tree/master/examples
diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.jsn.gz b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.jsn.gz
new file mode 100644
index 000000000..91c85cd76
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.jsn.gz
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.orc b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.orc
new file mode 100644
index 000000000..ecdadcbff
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.orc
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.jsn.gz b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.jsn.gz
new file mode 100644
index 000000000..5eab19a41
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.jsn.gz
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.orc b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.orc
new file mode 100644
index 000000000..4fb0beff8
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.orc
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.jsn.gz b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.jsn.gz
new file mode 100644
index 000000000..62dbaba42
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.jsn.gz
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.orc b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.orc
new file mode 100644
index 000000000..f51ffdbd0
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.orc
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/orc/decimal.jsn.gz b/src/arrow/python/pyarrow/tests/data/orc/decimal.jsn.gz
new file mode 100644
index 000000000..e634bd70b
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/orc/decimal.jsn.gz
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/orc/decimal.orc b/src/arrow/python/pyarrow/tests/data/orc/decimal.orc
new file mode 100644
index 000000000..cb0f7b9d7
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/orc/decimal.orc
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.all-named-index.parquet b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.all-named-index.parquet
new file mode 100644
index 000000000..e9efd9b39
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.all-named-index.parquet
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.column-metadata-handling.parquet b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.column-metadata-handling.parquet
new file mode 100644
index 000000000..d48041f51
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.column-metadata-handling.parquet
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.parquet b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.parquet
new file mode 100644
index 000000000..44670bcd1
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.parquet
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.some-named-index.parquet b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.some-named-index.parquet
new file mode 100644
index 000000000..34097ca12
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.some-named-index.parquet
Binary files differ
diff --git a/src/arrow/python/pyarrow/tests/deserialize_buffer.py b/src/arrow/python/pyarrow/tests/deserialize_buffer.py
new file mode 100644
index 000000000..982dc6695
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/deserialize_buffer.py
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file is called from a test in test_serialization.py.
+
+import sys
+
+import pyarrow as pa
+
+with open(sys.argv[1], 'rb') as f:
+ data = f.read()
+ pa.deserialize(data)
diff --git a/src/arrow/python/pyarrow/tests/pandas_examples.py b/src/arrow/python/pyarrow/tests/pandas_examples.py
new file mode 100644
index 000000000..466c14eeb
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/pandas_examples.py
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import OrderedDict
+from datetime import date, time
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+
+
+def dataframe_with_arrays(include_index=False):
+ """
+ Dataframe with numpy array columns of every possible primitive type.
+
+ Returns
+ -------
+ df: pandas.DataFrame
+ schema: pyarrow.Schema
+ Arrow schema definition that is in line with the constructed df.
+ """
+ dtypes = [('i1', pa.int8()), ('i2', pa.int16()),
+ ('i4', pa.int32()), ('i8', pa.int64()),
+ ('u1', pa.uint8()), ('u2', pa.uint16()),
+ ('u4', pa.uint32()), ('u8', pa.uint64()),
+ ('f4', pa.float32()), ('f8', pa.float64())]
+
+ arrays = OrderedDict()
+ fields = []
+ for dtype, arrow_dtype in dtypes:
+ fields.append(pa.field(dtype, pa.list_(arrow_dtype)))
+ arrays[dtype] = [
+ np.arange(10, dtype=dtype),
+ np.arange(5, dtype=dtype),
+ None,
+ np.arange(1, dtype=dtype)
+ ]
+
+ fields.append(pa.field('str', pa.list_(pa.string())))
+ arrays['str'] = [
+ np.array(["1", "ä"], dtype="object"),
+ None,
+ np.array(["1"], dtype="object"),
+ np.array(["1", "2", "3"], dtype="object")
+ ]
+
+ fields.append(pa.field('datetime64', pa.list_(pa.timestamp('ms'))))
+ arrays['datetime64'] = [
+ np.array(['2007-07-13T01:23:34.123456789',
+ None,
+ '2010-08-13T05:46:57.437699912'],
+ dtype='datetime64[ms]'),
+ None,
+ None,
+ np.array(['2007-07-13T02',
+ None,
+ '2010-08-13T05:46:57.437699912'],
+ dtype='datetime64[ms]'),
+ ]
+
+ if include_index:
+ fields.append(pa.field('__index_level_0__', pa.int64()))
+ df = pd.DataFrame(arrays)
+ schema = pa.schema(fields)
+
+ return df, schema
+
+
+def dataframe_with_lists(include_index=False, parquet_compatible=False):
+ """
+ Dataframe with list columns of every possible primitive type.
+
+ Parameters
+ ----------
+ parquet_compatible: bool
+ Exclude list value types not supported by Parquet (e.g. time64[ns]).
+
+ Returns
+ -------
+ df: pandas.DataFrame
+ schema: pyarrow.Schema
+ Arrow schema definition that is in line with the constructed df.
+ """
+ arrays = OrderedDict()
+ fields = []
+
+ fields.append(pa.field('int64', pa.list_(pa.int64())))
+ arrays['int64'] = [
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+ [0, 1, 2, 3, 4],
+ None,
+ [],
+ np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 2,
+ dtype=np.int64)[::2]
+ ]
+ fields.append(pa.field('double', pa.list_(pa.float64())))
+ arrays['double'] = [
+ [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
+ [0., 1., 2., 3., 4.],
+ None,
+ [],
+ np.array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.] * 2)[::2],
+ ]
+ fields.append(pa.field('bytes_list', pa.list_(pa.binary())))
+ arrays['bytes_list'] = [
+ [b"1", b"f"],
+ None,
+ [b"1"],
+ [b"1", b"2", b"3"],
+ [],
+ ]
+ fields.append(pa.field('str_list', pa.list_(pa.string())))
+ arrays['str_list'] = [
+ ["1", "ä"],
+ None,
+ ["1"],
+ ["1", "2", "3"],
+ [],
+ ]
+
+ date_data = [
+ [],
+ [date(2018, 1, 1), date(2032, 12, 30)],
+ [date(2000, 6, 7)],
+ None,
+ [date(1969, 6, 9), date(1972, 7, 3)]
+ ]
+ time_data = [
+ [time(23, 11, 11), time(1, 2, 3), time(23, 59, 59)],
+ [],
+ [time(22, 5, 59)],
+ None,
+ [time(0, 0, 0), time(18, 0, 2), time(12, 7, 3)]
+ ]
+
+ temporal_pairs = [
+ (pa.date32(), date_data),
+ (pa.date64(), date_data),
+ (pa.time32('s'), time_data),
+ (pa.time32('ms'), time_data),
+ (pa.time64('us'), time_data)
+ ]
+ if not parquet_compatible:
+ temporal_pairs += [
+ (pa.time64('ns'), time_data),
+ ]
+
+ for value_type, data in temporal_pairs:
+ field_name = '{}_list'.format(value_type)
+ field_type = pa.list_(value_type)
+ field = pa.field(field_name, field_type)
+ fields.append(field)
+ arrays[field_name] = data
+
+ if include_index:
+ fields.append(pa.field('__index_level_0__', pa.int64()))
+
+ df = pd.DataFrame(arrays)
+ schema = pa.schema(fields)
+
+ return df, schema
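
Both helpers are meant to be used together with the schema they return when converting to Arrow; a sketch of the typical call pattern (not additional test code from this patch):

import pyarrow as pa
from pyarrow.tests.pandas_examples import (dataframe_with_arrays,
                                           dataframe_with_lists)

df, schema = dataframe_with_arrays()
table = pa.Table.from_pandas(df, schema=schema, preserve_index=False)
assert table.schema.equals(schema)

# Restrict the list example to types Parquet can store.
df2, schema2 = dataframe_with_lists(parquet_compatible=True)
table2 = pa.Table.from_pandas(df2, schema=schema2, preserve_index=False)
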
diff --git a/src/arrow/python/pyarrow/tests/pandas_threaded_import.py b/src/arrow/python/pyarrow/tests/pandas_threaded_import.py
new file mode 100644
index 000000000..f44632d74
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/pandas_threaded_import.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file is called from a test in test_pandas.py.
+
+from concurrent.futures import ThreadPoolExecutor
+import faulthandler
+import sys
+
+import pyarrow as pa
+
+num_threads = 60
+timeout = 10 # seconds
+
+
+def thread_func(i):
+ pa.array([i]).to_pandas()
+
+
+def main():
+ # In case of import deadlock, crash after a finite timeout
+ faulthandler.dump_traceback_later(timeout, exit=True)
+ with ThreadPoolExecutor(num_threads) as pool:
+ assert "pandas" not in sys.modules # pandas is imported lazily
+ list(pool.map(thread_func, range(num_threads)))
+ assert "pandas" in sys.modules
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/arrow/python/pyarrow/tests/parquet/common.py b/src/arrow/python/pyarrow/tests/parquet/common.py
new file mode 100644
index 000000000..90bfb55d1
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/common.py
@@ -0,0 +1,177 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import io
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+from pyarrow.tests import util
+
+parametrize_legacy_dataset = pytest.mark.parametrize(
+ "use_legacy_dataset",
+ [True, pytest.param(False, marks=pytest.mark.dataset)])
+parametrize_legacy_dataset_not_supported = pytest.mark.parametrize(
+ "use_legacy_dataset", [True, pytest.param(False, marks=pytest.mark.skip)])
+parametrize_legacy_dataset_fixed = pytest.mark.parametrize(
+ "use_legacy_dataset", [pytest.param(True, marks=pytest.mark.xfail),
+ pytest.param(False, marks=pytest.mark.dataset)])
+
+# Marks all of the tests in this module
+# Ignore these with pytest ... -m 'not parquet'
+pytestmark = pytest.mark.parquet
+
+
+def _write_table(table, path, **kwargs):
+ # So we see the ImportError somewhere
+ import pyarrow.parquet as pq
+ from pyarrow.pandas_compat import _pandas_api
+
+ if _pandas_api.is_data_frame(table):
+ table = pa.Table.from_pandas(table)
+
+ pq.write_table(table, path, **kwargs)
+ return table
+
+
+def _read_table(*args, **kwargs):
+ import pyarrow.parquet as pq
+
+ table = pq.read_table(*args, **kwargs)
+ table.validate(full=True)
+ return table
+
+
+def _roundtrip_table(table, read_table_kwargs=None,
+ write_table_kwargs=None, use_legacy_dataset=True):
+ read_table_kwargs = read_table_kwargs or {}
+ write_table_kwargs = write_table_kwargs or {}
+
+ writer = pa.BufferOutputStream()
+ _write_table(table, writer, **write_table_kwargs)
+ reader = pa.BufferReader(writer.getvalue())
+ return _read_table(reader, use_legacy_dataset=use_legacy_dataset,
+ **read_table_kwargs)
+
+
+def _check_roundtrip(table, expected=None, read_table_kwargs=None,
+ use_legacy_dataset=True, **write_table_kwargs):
+ if expected is None:
+ expected = table
+
+ read_table_kwargs = read_table_kwargs or {}
+
+ # intentionally check twice
+ result = _roundtrip_table(table, read_table_kwargs=read_table_kwargs,
+ write_table_kwargs=write_table_kwargs,
+ use_legacy_dataset=use_legacy_dataset)
+ assert result.equals(expected)
+ result = _roundtrip_table(result, read_table_kwargs=read_table_kwargs,
+ write_table_kwargs=write_table_kwargs,
+ use_legacy_dataset=use_legacy_dataset)
+ assert result.equals(expected)
+
+
+def _roundtrip_pandas_dataframe(df, write_kwargs, use_legacy_dataset=True):
+ table = pa.Table.from_pandas(df)
+ result = _roundtrip_table(
+ table, write_table_kwargs=write_kwargs,
+ use_legacy_dataset=use_legacy_dataset)
+ return result.to_pandas()
+
+
+def _random_integers(size, dtype):
+ # We do not generate integers outside the int64 range
+ platform_int_info = np.iinfo('int_')
+ iinfo = np.iinfo(dtype)
+ return np.random.randint(max(iinfo.min, platform_int_info.min),
+ min(iinfo.max, platform_int_info.max),
+ size=size).astype(dtype)
+
+
+def _test_dataframe(size=10000, seed=0):
+ import pandas as pd
+
+ np.random.seed(seed)
+ df = pd.DataFrame({
+ 'uint8': _random_integers(size, np.uint8),
+ 'uint16': _random_integers(size, np.uint16),
+ 'uint32': _random_integers(size, np.uint32),
+ 'uint64': _random_integers(size, np.uint64),
+ 'int8': _random_integers(size, np.int8),
+ 'int16': _random_integers(size, np.int16),
+ 'int32': _random_integers(size, np.int32),
+ 'int64': _random_integers(size, np.int64),
+ 'float32': np.random.randn(size).astype(np.float32),
+ 'float64': np.arange(size, dtype=np.float64),
+ 'bool': np.random.randn(size) > 0,
+ 'strings': [util.rands(10) for i in range(size)],
+ 'all_none': [None] * size,
+ 'all_none_category': [None] * size
+ })
+
+ # TODO(PARQUET-1015)
+ # df['all_none_category'] = df['all_none_category'].astype('category')
+ return df
+
+
+def make_sample_file(table_or_df):
+ import pyarrow.parquet as pq
+
+ if isinstance(table_or_df, pa.Table):
+ a_table = table_or_df
+ else:
+ a_table = pa.Table.from_pandas(table_or_df)
+
+ buf = io.BytesIO()
+ _write_table(a_table, buf, compression='SNAPPY', version='2.6',
+ coerce_timestamps='ms')
+
+ buf.seek(0)
+ return pq.ParquetFile(buf)
+
+
+def alltypes_sample(size=10000, seed=0, categorical=False):
+ import pandas as pd
+
+ np.random.seed(seed)
+ arrays = {
+ 'uint8': np.arange(size, dtype=np.uint8),
+ 'uint16': np.arange(size, dtype=np.uint16),
+ 'uint32': np.arange(size, dtype=np.uint32),
+ 'uint64': np.arange(size, dtype=np.uint64),
+ 'int8': np.arange(size, dtype=np.int16),
+ 'int16': np.arange(size, dtype=np.int16),
+ 'int32': np.arange(size, dtype=np.int32),
+ 'int64': np.arange(size, dtype=np.int64),
+ 'float32': np.arange(size, dtype=np.float32),
+ 'float64': np.arange(size, dtype=np.float64),
+ 'bool': np.random.randn(size) > 0,
+ # TODO(wesm): Test other timestamp resolutions now that arrow supports
+ # them
+ 'datetime': np.arange("2016-01-01T00:00:00.001", size,
+ dtype='datetime64[ms]'),
+ 'str': pd.Series([str(x) for x in range(size)]),
+ 'empty_str': [''] * size,
+ 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None],
+ 'null': [None] * size,
+ 'null_list': [None] * 2 + [[None] * (x % 4) for x in range(size - 2)],
+ }
+ if categorical:
+ arrays['str_category'] = arrays['str'].astype('category')
+ return pd.DataFrame(arrays)
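
The helpers above get combined in the individual parquet test modules roughly as follows (illustrative only, mirroring the pattern used throughout tests/parquet):

import pyarrow as pa
from pyarrow.tests.parquet.common import (_check_roundtrip, alltypes_sample,
                                          make_sample_file)

df = alltypes_sample(size=100)
table = pa.Table.from_pandas(df)

# Write, read back twice and compare against the original table.
_check_roundtrip(table, version='2.6')

# A ParquetFile wrapping an in-memory sample, handy for metadata assertions.
pf = make_sample_file(df)
assert pf.metadata.num_rows == len(df)
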
diff --git a/src/arrow/python/pyarrow/tests/parquet/conftest.py b/src/arrow/python/pyarrow/tests/parquet/conftest.py
new file mode 100644
index 000000000..1e75493cd
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/conftest.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+from pyarrow.util import guid
+
+
+@pytest.fixture(scope='module')
+def datadir(base_datadir):
+ return base_datadir / 'parquet'
+
+
+@pytest.fixture
+def s3_bucket(s3_server):
+ boto3 = pytest.importorskip('boto3')
+ botocore = pytest.importorskip('botocore')
+
+ host, port, access_key, secret_key = s3_server['connection']
+ s3 = boto3.resource(
+ 's3',
+ endpoint_url='http://{}:{}'.format(host, port),
+ aws_access_key_id=access_key,
+ aws_secret_access_key=secret_key,
+ config=botocore.client.Config(signature_version='s3v4'),
+ region_name='us-east-1'
+ )
+ bucket = s3.Bucket('test-s3fs')
+ try:
+ bucket.create()
+ except Exception:
+ # we get BucketAlreadyOwnedByYou error with fsspec handler
+ pass
+ return 'test-s3fs'
+
+
+@pytest.fixture
+def s3_example_s3fs(s3_server, s3_bucket):
+ s3fs = pytest.importorskip('s3fs')
+
+ host, port, access_key, secret_key = s3_server['connection']
+ fs = s3fs.S3FileSystem(
+ key=access_key,
+ secret=secret_key,
+ client_kwargs={
+ 'endpoint_url': 'http://{}:{}'.format(host, port)
+ }
+ )
+
+ test_path = '{}/{}'.format(s3_bucket, guid())
+
+ fs.mkdir(test_path)
+ yield fs, test_path
+ try:
+ fs.rm(test_path, recursive=True)
+ except FileNotFoundError:
+ pass
+
+
+@pytest.fixture
+def s3_example_fs(s3_server):
+ from pyarrow.fs import FileSystem
+
+ host, port, access_key, secret_key = s3_server['connection']
+ uri = (
+ "s3://{}:{}@mybucket/data.parquet?scheme=http&endpoint_override={}:{}"
+ .format(access_key, secret_key, host, port)
+ )
+ fs, path = FileSystem.from_uri(uri)
+
+ fs.create_dir("mybucket")
+
+ yield fs, uri, path
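
A test consuming s3_example_fs would look roughly like this (hypothetical test, shown only to make the fixture contract concrete):

import pyarrow as pa
import pyarrow.parquet as pq

def test_parquet_on_minio(s3_example_fs):
    fs, uri, path = s3_example_fs            # path is "mybucket/data.parquet"
    table = pa.table({"a": [1, 2, 3]})
    pq.write_table(table, path, filesystem=fs)
    assert pq.read_table(path, filesystem=fs).equals(table)
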
diff --git a/src/arrow/python/pyarrow/tests/parquet/test_basic.py b/src/arrow/python/pyarrow/tests/parquet/test_basic.py
new file mode 100644
index 000000000..cf1aaa21f
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/test_basic.py
@@ -0,0 +1,631 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import OrderedDict
+import io
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+from pyarrow import fs
+from pyarrow.filesystem import LocalFileSystem, FileSystem
+from pyarrow.tests import util
+from pyarrow.tests.parquet.common import (_check_roundtrip, _roundtrip_table,
+ parametrize_legacy_dataset)
+
+try:
+ import pyarrow.parquet as pq
+ from pyarrow.tests.parquet.common import _read_table, _write_table
+except ImportError:
+ pq = None
+
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+
+ from pyarrow.tests.pandas_examples import dataframe_with_lists
+ from pyarrow.tests.parquet.common import alltypes_sample
+except ImportError:
+ pd = tm = None
+
+
+pytestmark = pytest.mark.parquet
+
+
+def test_parquet_invalid_version(tempdir):
+ table = pa.table({'a': [1, 2, 3]})
+ with pytest.raises(ValueError, match="Unsupported Parquet format version"):
+ _write_table(table, tempdir / 'test_version.parquet', version="2.2")
+ with pytest.raises(ValueError, match="Unsupported Parquet data page " +
+ "version"):
+ _write_table(table, tempdir / 'test_version.parquet',
+ data_page_version="2.2")
+
+
+@parametrize_legacy_dataset
+def test_set_data_page_size(use_legacy_dataset):
+ arr = pa.array([1, 2, 3] * 100000)
+ t = pa.Table.from_arrays([arr], names=['f0'])
+
+ # 128K, 512K
+ page_sizes = [2 << 16, 2 << 18]
+ for target_page_size in page_sizes:
+ _check_roundtrip(t, data_page_size=target_page_size,
+ use_legacy_dataset=use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_chunked_table_write(use_legacy_dataset):
+ # ARROW-232
+ tables = []
+ batch = pa.RecordBatch.from_pandas(alltypes_sample(size=10))
+ tables.append(pa.Table.from_batches([batch] * 3))
+ df, _ = dataframe_with_lists()
+ batch = pa.RecordBatch.from_pandas(df)
+ tables.append(pa.Table.from_batches([batch] * 3))
+
+ for data_page_version in ['1.0', '2.0']:
+ for use_dictionary in [True, False]:
+ for table in tables:
+ _check_roundtrip(
+ table, version='2.6',
+ use_legacy_dataset=use_legacy_dataset,
+ data_page_version=data_page_version,
+ use_dictionary=use_dictionary)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_memory_map(tempdir, use_legacy_dataset):
+ df = alltypes_sample(size=10)
+
+ table = pa.Table.from_pandas(df)
+ _check_roundtrip(table, read_table_kwargs={'memory_map': True},
+ version='2.6', use_legacy_dataset=use_legacy_dataset)
+
+ filename = str(tempdir / 'tmp_file')
+ with open(filename, 'wb') as f:
+ _write_table(table, f, version='2.6')
+ table_read = pq.read_pandas(filename, memory_map=True,
+ use_legacy_dataset=use_legacy_dataset)
+ assert table_read.equals(table)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_enable_buffered_stream(tempdir, use_legacy_dataset):
+ df = alltypes_sample(size=10)
+
+ table = pa.Table.from_pandas(df)
+ _check_roundtrip(table, read_table_kwargs={'buffer_size': 1025},
+ version='2.6', use_legacy_dataset=use_legacy_dataset)
+
+ filename = str(tempdir / 'tmp_file')
+ with open(filename, 'wb') as f:
+ _write_table(table, f, version='2.6')
+ table_read = pq.read_pandas(filename, buffer_size=4096,
+ use_legacy_dataset=use_legacy_dataset)
+ assert table_read.equals(table)
+
+
+@parametrize_legacy_dataset
+def test_special_chars_filename(tempdir, use_legacy_dataset):
+ table = pa.Table.from_arrays([pa.array([42])], ["ints"])
+ filename = "foo # bar"
+ path = tempdir / filename
+ assert not path.exists()
+ _write_table(table, str(path))
+ assert path.exists()
+ table_read = _read_table(str(path), use_legacy_dataset=use_legacy_dataset)
+ assert table_read.equals(table)
+
+
+@parametrize_legacy_dataset
+def test_invalid_source(use_legacy_dataset):
+ # Test that we provide a helpful error message pointing out
+ # that None is not a valid Parquet source.
+ #
+ # Depending on use_legacy_dataset the message changes slightly
+ # but in both cases it should point out that None wasn't expected.
+ with pytest.raises(TypeError, match="None"):
+ pq.read_table(None, use_legacy_dataset=use_legacy_dataset)
+
+ with pytest.raises(TypeError, match="None"):
+ pq.ParquetFile(None)
+
+
+@pytest.mark.slow
+def test_file_with_over_int16_max_row_groups():
+ # PARQUET-1857: Parquet encryption support introduced an INT16_MAX upper
+ # limit on the number of row groups, but this limit only impacts files with
+ # encrypted row group metadata because of the int16 row group ordinal used
+ # in the Parquet Thrift metadata. Unencrypted files are not impacted, so
+ # this test checks that it works (even if it isn't a good idea)
+ t = pa.table([list(range(40000))], names=['f0'])
+ _check_roundtrip(t, row_group_size=1)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_empty_table_roundtrip(use_legacy_dataset):
+ df = alltypes_sample(size=10)
+
+ # Create a non-empty table to infer the types correctly, then slice to 0
+ table = pa.Table.from_pandas(df)
+ table = pa.Table.from_arrays(
+ [col.chunk(0)[:0] for col in table.itercolumns()],
+ names=table.schema.names)
+
+ assert table.schema.field('null').type == pa.null()
+ assert table.schema.field('null_list').type == pa.list_(pa.null())
+ _check_roundtrip(
+ table, version='2.6', use_legacy_dataset=use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_empty_table_no_columns(use_legacy_dataset):
+ df = pd.DataFrame()
+ empty = pa.Table.from_pandas(df, preserve_index=False)
+ _check_roundtrip(empty, use_legacy_dataset=use_legacy_dataset)
+
+
+@parametrize_legacy_dataset
+def test_write_nested_zero_length_array_chunk_failure(use_legacy_dataset):
+ # Bug report in ARROW-3792
+ cols = OrderedDict(
+ int32=pa.int32(),
+ list_string=pa.list_(pa.string())
+ )
+ data = [[], [OrderedDict(int32=1, list_string=('G',)), ]]
+
+ # This produces a table with a column like
+ # <Column name='list_string' type=ListType(list<item: string>)>
+ # [
+ # [],
+ # [
+ # [
+ # "G"
+ # ]
+ # ]
+ # ]
+ #
+ # Each column is a ChunkedArray with 2 elements
+ my_arrays = [pa.array(batch, type=pa.struct(cols)).flatten()
+ for batch in data]
+ my_batches = [pa.RecordBatch.from_arrays(batch, schema=pa.schema(cols))
+ for batch in my_arrays]
+ tbl = pa.Table.from_batches(my_batches, pa.schema(cols))
+ _check_roundtrip(tbl, use_legacy_dataset=use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_multiple_path_types(tempdir, use_legacy_dataset):
+ # Test compatibility with PEP 519 path-like objects
+ path = tempdir / 'zzz.parquet'
+ df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)})
+ _write_table(df, path)
+ table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+ # Test compatibility with plain string paths
+ path = str(tempdir) + 'zzz.parquet'
+ df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)})
+ _write_table(df, path)
+ table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+@parametrize_legacy_dataset
+def test_fspath(tempdir, use_legacy_dataset):
+ # ARROW-12472 support __fspath__ objects without using str()
+ path = tempdir / "test.parquet"
+ table = pa.table({"a": [1, 2, 3]})
+ _write_table(table, path)
+
+ fs_protocol_obj = util.FSProtocolClass(path)
+
+ result = _read_table(
+ fs_protocol_obj, use_legacy_dataset=use_legacy_dataset
+ )
+ assert result.equals(table)
+
+ # combined with non-local filesystem raises
+ with pytest.raises(TypeError):
+ _read_table(fs_protocol_obj, filesystem=FileSystem())
+
+
+@pytest.mark.dataset
+@parametrize_legacy_dataset
+@pytest.mark.parametrize("filesystem", [
+ None, fs.LocalFileSystem(), LocalFileSystem._get_instance()
+])
+def test_relative_paths(tempdir, use_legacy_dataset, filesystem):
+ # reading and writing from relative paths
+ table = pa.table({"a": [1, 2, 3]})
+
+ # reading
+ pq.write_table(table, str(tempdir / "data.parquet"))
+ with util.change_cwd(tempdir):
+ result = pq.read_table("data.parquet", filesystem=filesystem,
+ use_legacy_dataset=use_legacy_dataset)
+ assert result.equals(table)
+
+ # writing
+ with util.change_cwd(tempdir):
+ pq.write_table(table, "data2.parquet", filesystem=filesystem)
+ result = pq.read_table(tempdir / "data2.parquet")
+ assert result.equals(table)
+
+
+def test_read_non_existing_file():
+ # ensure we have a proper error message
+ with pytest.raises(FileNotFoundError):
+ pq.read_table('i-am-not-existing.parquet')
+
+
+def test_file_error_python_exception():
+ class BogusFile(io.BytesIO):
+ def read(self, *args):
+ raise ZeroDivisionError("zorglub")
+
+ def seek(self, *args):
+ raise ZeroDivisionError("zorglub")
+
+ # ensure the Python exception is restored
+ with pytest.raises(ZeroDivisionError, match="zorglub"):
+ pq.read_table(BogusFile(b""))
+
+
+@parametrize_legacy_dataset
+def test_parquet_read_from_buffer(tempdir, use_legacy_dataset):
+ # reading from a buffer from python's open()
+ table = pa.table({"a": [1, 2, 3]})
+ pq.write_table(table, str(tempdir / "data.parquet"))
+
+ with open(str(tempdir / "data.parquet"), "rb") as f:
+ result = pq.read_table(f, use_legacy_dataset=use_legacy_dataset)
+ assert result.equals(table)
+
+ with open(str(tempdir / "data.parquet"), "rb") as f:
+ result = pq.read_table(pa.PythonFile(f),
+ use_legacy_dataset=use_legacy_dataset)
+ assert result.equals(table)
+
+
+@parametrize_legacy_dataset
+def test_byte_stream_split(use_legacy_dataset):
+ # This is only a smoke test.
+ arr_float = pa.array(list(map(float, range(100))))
+ arr_int = pa.array(list(map(int, range(100))))
+ data_float = [arr_float, arr_float]
+ table = pa.Table.from_arrays(data_float, names=['a', 'b'])
+
+ # Check with byte_stream_split for both columns.
+ _check_roundtrip(table, expected=table, compression="gzip",
+ use_dictionary=False, use_byte_stream_split=True)
+
+ # Check with byte_stream_split for column 'b' and dictionary
+ # for column 'a'.
+ _check_roundtrip(table, expected=table, compression="gzip",
+ use_dictionary=['a'],
+ use_byte_stream_split=['b'])
+
+ # Check with a collision for both columns.
+ _check_roundtrip(table, expected=table, compression="gzip",
+ use_dictionary=['a', 'b'],
+ use_byte_stream_split=['a', 'b'])
+
+ # Check with mixed column types.
+ mixed_table = pa.Table.from_arrays([arr_float, arr_int],
+ names=['a', 'b'])
+ _check_roundtrip(mixed_table, expected=mixed_table,
+ use_dictionary=['b'],
+ use_byte_stream_split=['a'])
+
+ # Try to use the wrong data type with the byte_stream_split encoding.
+ # This should throw an exception.
+ table = pa.Table.from_arrays([arr_int], names=['tmp'])
+ with pytest.raises(IOError):
+ _check_roundtrip(table, expected=table, use_byte_stream_split=True,
+ use_dictionary=False,
+ use_legacy_dataset=use_legacy_dataset)
+
+
+@parametrize_legacy_dataset
+def test_compression_level(use_legacy_dataset):
+ arr = pa.array(list(map(int, range(1000))))
+ data = [arr, arr]
+ table = pa.Table.from_arrays(data, names=['a', 'b'])
+
+ # Check one compression level.
+ _check_roundtrip(table, expected=table, compression="gzip",
+ compression_level=1,
+ use_legacy_dataset=use_legacy_dataset)
+
+ # Check another one to make sure that compression_level=1 does not
+ # coincide with the default one in Arrow.
+ _check_roundtrip(table, expected=table, compression="gzip",
+ compression_level=5,
+ use_legacy_dataset=use_legacy_dataset)
+
+ # Check that the user can provide a compression per column
+ _check_roundtrip(table, expected=table,
+ compression={'a': "gzip", 'b': "snappy"},
+ use_legacy_dataset=use_legacy_dataset)
+
+ # Check that the user can provide a compression level per column
+ _check_roundtrip(table, expected=table, compression="gzip",
+ compression_level={'a': 2, 'b': 3},
+ use_legacy_dataset=use_legacy_dataset)
+
+ # Check that specifying a compression level for a codec which does not
+ # allow specifying one, or specifying an out-of-range level, results in
+ # an error. Uncompressed, snappy, lz4 and lzo do not support specifying
+ # a compression level. GZIP (zlib) allows a compression level, but as of
+ # version 1.2.11 the valid range is [-1, 9].
+ invalid_combinations = [("snappy", 4), ("lz4", 5), ("gzip", -1337),
+ ("None", 444), ("lzo", 14)]
+ buf = io.BytesIO()
+ for (codec, level) in invalid_combinations:
+ with pytest.raises((ValueError, OSError)):
+ _write_table(table, buf, compression=codec,
+ compression_level=level)
+
+
+def test_sanitized_spark_field_names():
+ a0 = pa.array([0, 1, 2, 3, 4])
+ name = 'prohib; ,\t{}'
+ table = pa.Table.from_arrays([a0], [name])
+
+ result = _roundtrip_table(table, write_table_kwargs={'flavor': 'spark'})
+
+ expected_name = 'prohib______'
+ assert result.schema[0].name == expected_name
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_multithreaded_read(use_legacy_dataset):
+ df = alltypes_sample(size=10000)
+
+ table = pa.Table.from_pandas(df)
+
+ buf = io.BytesIO()
+ _write_table(table, buf, compression='SNAPPY', version='2.6')
+
+ buf.seek(0)
+ table1 = _read_table(
+ buf, use_threads=True, use_legacy_dataset=use_legacy_dataset)
+
+ buf.seek(0)
+ table2 = _read_table(
+ buf, use_threads=False, use_legacy_dataset=use_legacy_dataset)
+
+ assert table1.equals(table2)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_min_chunksize(use_legacy_dataset):
+ data = pd.DataFrame([np.arange(4)], columns=['A', 'B', 'C', 'D'])
+ table = pa.Table.from_pandas(data.reset_index())
+
+ buf = io.BytesIO()
+ _write_table(table, buf, chunk_size=-1)
+
+ buf.seek(0)
+ result = _read_table(buf, use_legacy_dataset=use_legacy_dataset)
+
+ assert result.equals(table)
+
+ with pytest.raises(ValueError):
+ _write_table(table, buf, chunk_size=0)
+
+
+@pytest.mark.pandas
+def test_write_error_deletes_incomplete_file(tempdir):
+ # ARROW-1285
+ df = pd.DataFrame({'a': list('abc'),
+ 'b': list(range(1, 4)),
+ 'c': np.arange(3, 6).astype('u1'),
+ 'd': np.arange(4.0, 7.0, dtype='float64'),
+ 'e': [True, False, True],
+ 'f': pd.Categorical(list('abc')),
+ 'g': pd.date_range('20130101', periods=3),
+ 'h': pd.date_range('20130101', periods=3,
+ tz='US/Eastern'),
+ 'i': pd.date_range('20130101', periods=3, freq='ns')})
+
+ pdf = pa.Table.from_pandas(df)
+
+ filename = tempdir / 'tmp_file'
+ try:
+ _write_table(pdf, filename)
+ except pa.ArrowException:
+ pass
+
+ assert not filename.exists()
+
+
+@parametrize_legacy_dataset
+def test_read_non_existent_file(tempdir, use_legacy_dataset):
+ path = 'non-existent-file.parquet'
+ try:
+ pq.read_table(path, use_legacy_dataset=use_legacy_dataset)
+ except Exception as e:
+ assert path in e.args[0]
+
+
+@parametrize_legacy_dataset
+def test_read_table_doesnt_warn(datadir, use_legacy_dataset):
+ with pytest.warns(None) as record:
+ pq.read_table(datadir / 'v0.7.1.parquet',
+ use_legacy_dataset=use_legacy_dataset)
+
+ assert len(record) == 0
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_zlib_compression_bug(use_legacy_dataset):
+ # ARROW-3514: "zlib deflate failed, output buffer too small"
+ table = pa.Table.from_arrays([pa.array(['abc', 'def'])], ['some_col'])
+ f = io.BytesIO()
+ pq.write_table(table, f, compression='gzip')
+
+ f.seek(0)
+ roundtrip = pq.read_table(f, use_legacy_dataset=use_legacy_dataset)
+ tm.assert_frame_equal(roundtrip.to_pandas(), table.to_pandas())
+
+
+@parametrize_legacy_dataset
+def test_parquet_file_too_small(tempdir, use_legacy_dataset):
+ path = str(tempdir / "test.parquet")
+ # TODO(dataset) with datasets API it raises OSError instead
+ with pytest.raises((pa.ArrowInvalid, OSError),
+ match='size is 0 bytes'):
+ with open(path, 'wb') as f:
+ pass
+ pq.read_table(path, use_legacy_dataset=use_legacy_dataset)
+
+ with pytest.raises((pa.ArrowInvalid, OSError),
+ match='size is 4 bytes'):
+ with open(path, 'wb') as f:
+ f.write(b'ffff')
+ pq.read_table(path, use_legacy_dataset=use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@pytest.mark.fastparquet
+@pytest.mark.filterwarnings("ignore:RangeIndex:FutureWarning")
+@pytest.mark.filterwarnings("ignore:tostring:DeprecationWarning:fastparquet")
+def test_fastparquet_cross_compatibility(tempdir):
+ fp = pytest.importorskip('fastparquet')
+
+ df = pd.DataFrame(
+ {
+ "a": list("abc"),
+ "b": list(range(1, 4)),
+ "c": np.arange(4.0, 7.0, dtype="float64"),
+ "d": [True, False, True],
+ "e": pd.date_range("20130101", periods=3),
+ "f": pd.Categorical(["a", "b", "a"]),
+ # fastparquet writes list as BYTE_ARRAY JSON, so no roundtrip
+ # "g": [[1, 2], None, [1, 2, 3]],
+ }
+ )
+ table = pa.table(df)
+
+ # Arrow -> fastparquet
+ file_arrow = str(tempdir / "cross_compat_arrow.parquet")
+ pq.write_table(table, file_arrow, compression=None)
+
+ fp_file = fp.ParquetFile(file_arrow)
+ df_fp = fp_file.to_pandas()
+ tm.assert_frame_equal(df, df_fp)
+
+ # Fastparquet -> arrow
+ file_fastparquet = str(tempdir / "cross_compat_fastparquet.parquet")
+ fp.write(file_fastparquet, df)
+
+ table_fp = pq.read_pandas(file_fastparquet)
+ # for a fastparquet-written file, categoricals come back as strings
+ # (no arrow schema in parquet metadata)
+ df['f'] = df['f'].astype(object)
+ tm.assert_frame_equal(table_fp.to_pandas(), df)
+
+
+@parametrize_legacy_dataset
+@pytest.mark.parametrize('array_factory', [
+ lambda: pa.array([0, None] * 10),
+ lambda: pa.array([0, None] * 10).dictionary_encode(),
+ lambda: pa.array(["", None] * 10),
+ lambda: pa.array(["", None] * 10).dictionary_encode(),
+])
+@pytest.mark.parametrize('use_dictionary', [False, True])
+@pytest.mark.parametrize('read_dictionary', [False, True])
+def test_buffer_contents(
+ array_factory, use_dictionary, read_dictionary, use_legacy_dataset
+):
+ # Test that null values are deterministically initialized to zero
+ # after a roundtrip through Parquet.
+ # See ARROW-8006 and ARROW-8011.
+ orig_table = pa.Table.from_pydict({"col": array_factory()})
+ bio = io.BytesIO()
+    pq.write_table(orig_table, bio, use_dictionary=use_dictionary)
+ bio.seek(0)
+ read_dictionary = ['col'] if read_dictionary else None
+ table = pq.read_table(bio, use_threads=False,
+ read_dictionary=read_dictionary,
+ use_legacy_dataset=use_legacy_dataset)
+
+ for col in table.columns:
+ [chunk] = col.chunks
+ buf = chunk.buffers()[1]
+ assert buf.to_pybytes() == buf.size * b"\0"
+
+
+def test_parquet_compression_roundtrip(tempdir):
+ # ARROW-10480: ensure even with nonstandard Parquet file naming
+ # conventions, writing and then reading a file works. In
+ # particular, ensure that we don't automatically double-compress
+ # the stream due to auto-detecting the extension in the filename
+ table = pa.table([pa.array(range(4))], names=["ints"])
+ path = tempdir / "arrow-10480.pyarrow.gz"
+ pq.write_table(table, path, compression="GZIP")
+ result = pq.read_table(path)
+ assert result.equals(table)
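+
+
+# Illustrative sketch (not part of the original change): the codec used above
+# is selected by the ``compression=`` keyword only; the ".gz" suffix in the
+# file name is never inspected. A hypothetical helper like the one below
+# could confirm that from the column-chunk metadata of the written file.
+def _example_written_codec(path):
+    meta = pq.ParquetFile(path).metadata
+    # e.g. 'GZIP' for the file written in test_parquet_compression_roundtrip
+    return meta.row_group(0).column(0).compression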
+
+
+def test_empty_row_groups(tempdir):
+ # ARROW-3020
+ table = pa.Table.from_arrays([pa.array([], type='int32')], ['f0'])
+
+ path = tempdir / 'empty_row_groups.parquet'
+
+ num_groups = 3
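+    # each write_table call below emits one (empty) row group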
+ with pq.ParquetWriter(path, table.schema) as writer:
+ for i in range(num_groups):
+ writer.write_table(table)
+
+ reader = pq.ParquetFile(path)
+ assert reader.metadata.num_row_groups == num_groups
+
+ for i in range(num_groups):
+ assert reader.read_row_group(i).equals(table)
+
+
+def test_reads_over_batch(tempdir):
+ data = [None] * (1 << 20)
+ data.append([1])
+    # Large list<int64> column with mostly null entries and a single
+    # non-null value at the end. This should force batched reads when
+    # reading back.
+ table = pa.Table.from_arrays([data], ['column'])
+
+ path = tempdir / 'arrow-11607.parquet'
+ pq.write_table(table, path)
+ table2 = pq.read_table(path)
+ assert table == table2
diff --git a/src/arrow/python/pyarrow/tests/parquet/test_compliant_nested_type.py b/src/arrow/python/pyarrow/tests/parquet/test_compliant_nested_type.py
new file mode 100644
index 000000000..91bc08df6
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/test_compliant_nested_type.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import pyarrow as pa
+from pyarrow.tests.parquet.common import parametrize_legacy_dataset
+
+try:
+ import pyarrow.parquet as pq
+ from pyarrow.tests.parquet.common import (_read_table,
+ _check_roundtrip)
+except ImportError:
+ pq = None
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+
+ from pyarrow.tests.parquet.common import _roundtrip_pandas_dataframe
+except ImportError:
+ pd = tm = None
+
+pytestmark = pytest.mark.parquet
+
+# Tests for ARROW-11497
+_test_data_simple = [
+ {'items': [1, 2]},
+ {'items': [0]},
+]
+
+_test_data_complex = [
+ {'items': [{'name': 'elem1', 'value': '1'},
+ {'name': 'elem2', 'value': '2'}]},
+ {'items': [{'name': 'elem1', 'value': '0'}]},
+]
+
+parametrize_test_data = pytest.mark.parametrize(
+ "test_data", [_test_data_simple, _test_data_complex])
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+@parametrize_test_data
+def test_write_compliant_nested_type_enable(tempdir,
+ use_legacy_dataset, test_data):
+ # prepare dataframe for testing
+ df = pd.DataFrame(data=test_data)
+ # verify that we can read/write pandas df with new flag
+ _roundtrip_pandas_dataframe(df,
+ write_kwargs={
+ 'use_compliant_nested_type': True},
+ use_legacy_dataset=use_legacy_dataset)
+
+ # Write to a parquet file with compliant nested type
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ path = str(tempdir / 'data.parquet')
+ with pq.ParquetWriter(path, table.schema,
+ use_compliant_nested_type=True,
+ version='2.6') as writer:
+ writer.write_table(table)
+ # Read back as a table
+ new_table = _read_table(path)
+    # Validate that the "items" column is compliant with the Parquet nested
+    # format; e.g. for the complex test data it reads back as
+    # list<element: struct<name: string, value: string>>
+ assert isinstance(new_table.schema.types[0], pa.ListType)
+ assert new_table.schema.types[0].value_field.name == 'element'
+
+ # Verify that the new table can be read/written correctly
+ _check_roundtrip(new_table,
+ use_legacy_dataset=use_legacy_dataset,
+ use_compliant_nested_type=True)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+@parametrize_test_data
+def test_write_compliant_nested_type_disable(tempdir,
+ use_legacy_dataset, test_data):
+ # prepare dataframe for testing
+ df = pd.DataFrame(data=test_data)
+ # verify that we can read/write with new flag disabled (default behaviour)
+ _roundtrip_pandas_dataframe(df, write_kwargs={},
+ use_legacy_dataset=use_legacy_dataset)
+
+ # Write to a parquet file while disabling compliant nested type
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ path = str(tempdir / 'data.parquet')
+ with pq.ParquetWriter(path, table.schema, version='2.6') as writer:
+ writer.write_table(table)
+ new_table = _read_table(path)
+
+    # Validate that the "items" column is not compliant with the Parquet
+    # nested format; e.g. for the complex test data it reads back as
+    # list<item: struct<name: string, value: string>>
+ assert isinstance(new_table.schema.types[0], pa.ListType)
+ assert new_table.schema.types[0].value_field.name == 'item'
+
+ # Verify that the new table can be read/written correctly
+ _check_roundtrip(new_table,
+ use_legacy_dataset=use_legacy_dataset,
+ use_compliant_nested_type=False)
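+
+
+# Illustrative sketch (not part of the original change): the flag controls
+# how the repeated list child is named in the Parquet schema ("element" when
+# compliant, "item" otherwise). A hypothetical helper mirroring the checks
+# above could inspect a written file like this:
+def _example_list_child_name(path):
+    arrow_schema = pq.ParquetFile(path).schema_arrow
+    # types[0] is the "items" list column in these tests
+    return arrow_schema.types[0].value_field.name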
diff --git a/src/arrow/python/pyarrow/tests/parquet/test_data_types.py b/src/arrow/python/pyarrow/tests/parquet/test_data_types.py
new file mode 100644
index 000000000..1e2660006
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/test_data_types.py
@@ -0,0 +1,529 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import decimal
+import io
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+from pyarrow.tests import util
+from pyarrow.tests.parquet.common import (_check_roundtrip,
+ parametrize_legacy_dataset)
+
+try:
+ import pyarrow.parquet as pq
+ from pyarrow.tests.parquet.common import _read_table, _write_table
+except ImportError:
+ pq = None
+
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+
+ from pyarrow.tests.pandas_examples import (dataframe_with_arrays,
+ dataframe_with_lists)
+ from pyarrow.tests.parquet.common import alltypes_sample
+except ImportError:
+ pd = tm = None
+
+
+pytestmark = pytest.mark.parquet
+
+
+# General roundtrip of data types
+# -----------------------------------------------------------------------------
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+@pytest.mark.parametrize('chunk_size', [None, 1000])
+def test_parquet_2_0_roundtrip(tempdir, chunk_size, use_legacy_dataset):
+ df = alltypes_sample(size=10000, categorical=True)
+
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df)
+ assert arrow_table.schema.pandas_metadata is not None
+
+ _write_table(arrow_table, filename, version='2.6',
+ coerce_timestamps='ms', chunk_size=chunk_size)
+ table_read = pq.read_pandas(
+ filename, use_legacy_dataset=use_legacy_dataset)
+ assert table_read.schema.pandas_metadata is not None
+
+ read_metadata = table_read.schema.metadata
+ assert arrow_table.schema.metadata == read_metadata
+
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_parquet_1_0_roundtrip(tempdir, use_legacy_dataset):
+ size = 10000
+ np.random.seed(0)
+ df = pd.DataFrame({
+ 'uint8': np.arange(size, dtype=np.uint8),
+ 'uint16': np.arange(size, dtype=np.uint16),
+ 'uint32': np.arange(size, dtype=np.uint32),
+ 'uint64': np.arange(size, dtype=np.uint64),
+        'int8': np.arange(size, dtype=np.int8),
+ 'int16': np.arange(size, dtype=np.int16),
+ 'int32': np.arange(size, dtype=np.int32),
+ 'int64': np.arange(size, dtype=np.int64),
+ 'float32': np.arange(size, dtype=np.float32),
+ 'float64': np.arange(size, dtype=np.float64),
+ 'bool': np.random.randn(size) > 0,
+ 'str': [str(x) for x in range(size)],
+ 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None],
+ 'empty_str': [''] * size
+ })
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df)
+ _write_table(arrow_table, filename, version='1.0')
+ table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+
+ # We pass uint32_t as int64_t if we write Parquet version 1.0
+ df['uint32'] = df['uint32'].values.astype(np.int64)
+
+ tm.assert_frame_equal(df, df_read)
+
+
+# Dictionary
+# -----------------------------------------------------------------------------
+
+
+def _simple_table_write_read(table, use_legacy_dataset):
+ bio = pa.BufferOutputStream()
+ pq.write_table(table, bio)
+ contents = bio.getvalue()
+ return pq.read_table(
+ pa.BufferReader(contents), use_legacy_dataset=use_legacy_dataset
+ )
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_direct_read_dictionary(use_legacy_dataset):
+ # ARROW-3325
+ repeats = 10
+ nunique = 5
+
+    data = [
+        [util.rands(10) for i in range(nunique)] * repeats,
+    ]
+ table = pa.table(data, names=['f0'])
+
+ bio = pa.BufferOutputStream()
+ pq.write_table(table, bio)
+ contents = bio.getvalue()
+
+ result = pq.read_table(pa.BufferReader(contents),
+ read_dictionary=['f0'],
+ use_legacy_dataset=use_legacy_dataset)
+
+ # Compute dictionary-encoded subfield
+ expected = pa.table([table[0].dictionary_encode()], names=['f0'])
+ assert result.equals(expected)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_direct_read_dictionary_subfield(use_legacy_dataset):
+ repeats = 10
+ nunique = 5
+
+ data = [
+ [[util.rands(10)] for i in range(nunique)] * repeats,
+ ]
+ table = pa.table(data, names=['f0'])
+
+ bio = pa.BufferOutputStream()
+ pq.write_table(table, bio)
+ contents = bio.getvalue()
+ result = pq.read_table(pa.BufferReader(contents),
+ read_dictionary=['f0.list.item'],
+ use_legacy_dataset=use_legacy_dataset)
+
+ arr = pa.array(data[0])
+ values_as_dict = arr.values.dictionary_encode()
+
+ inner_indices = values_as_dict.indices.cast('int32')
+ new_values = pa.DictionaryArray.from_arrays(inner_indices,
+ values_as_dict.dictionary)
+
+ offsets = pa.array(range(51), type='int32')
+ expected_arr = pa.ListArray.from_arrays(offsets, new_values)
+ expected = pa.table([expected_arr], names=['f0'])
+
+ assert result.equals(expected)
+ assert result[0].num_chunks == 1
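+
+
+# Illustrative sketch (not part of the original change): the leaf path passed
+# to ``read_dictionary`` above ("f0.list.item") follows the Parquet column
+# paths, which a hypothetical helper could enumerate instead of spelling the
+# path by hand.
+def _example_leaf_column_paths(where):
+    schema = pq.ParquetFile(where).schema
+    # e.g. ['f0.list.item'] for a single list<string> column named "f0"
+    return [schema.column(i).path for i in range(len(schema))]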
+
+
+@parametrize_legacy_dataset
+def test_dictionary_array_automatically_read(use_legacy_dataset):
+ # ARROW-3246
+
+ # Make a large dictionary, a little over 4MB of data
+ dict_length = 4000
+ dict_values = pa.array([('x' * 1000 + '_{}'.format(i))
+ for i in range(dict_length)])
+
+ num_chunks = 10
+ chunk_size = 100
+ chunks = []
+ for i in range(num_chunks):
+ indices = np.random.randint(0, dict_length,
+ size=chunk_size).astype(np.int32)
+ chunks.append(pa.DictionaryArray.from_arrays(pa.array(indices),
+ dict_values))
+
+ table = pa.table([pa.chunked_array(chunks)], names=['f0'])
+ result = _simple_table_write_read(table, use_legacy_dataset)
+
+ assert result.equals(table)
+
+ # The only key in the metadata was the Arrow schema key
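+    # (the serialized "ARROW:schema" entry is consumed when the dictionary
+    # type is restored, so nothing remains in result.schema.metadata)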
+ assert result.schema.metadata is None
+
+
+# Decimal
+# -----------------------------------------------------------------------------
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_decimal_roundtrip(tempdir, use_legacy_dataset):
+ num_values = 10
+
+ columns = {}
+ for precision in range(1, 39):
+ for scale in range(0, precision + 1):
+ with util.random_seed(0):
+ random_decimal_values = [
+ util.randdecimal(precision, scale)
+ for _ in range(num_values)
+ ]
+ column_name = ('dec_precision_{:d}_scale_{:d}'
+ .format(precision, scale))
+ columns[column_name] = random_decimal_values
+
+ expected = pd.DataFrame(columns)
+ filename = tempdir / 'decimals.parquet'
+ string_filename = str(filename)
+ table = pa.Table.from_pandas(expected)
+ _write_table(table, string_filename)
+ result_table = _read_table(
+ string_filename, use_legacy_dataset=use_legacy_dataset)
+ result = result_table.to_pandas()
+ tm.assert_frame_equal(result, expected)
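+
+
+# Note (not part of the original change): Table.from_pandas infers a
+# decimal128 type from the Decimal objects above, which is why the precision
+# loop stops at 38, the decimal128 maximum. For instance,
+# pa.array([decimal.Decimal('1.23')]) is inferred as decimal128(3, 2).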
+
+
+@pytest.mark.pandas
+@pytest.mark.xfail(
+ raises=OSError, reason='Parquet does not support negative scale'
+)
+def test_decimal_roundtrip_negative_scale(tempdir):
+ expected = pd.DataFrame({'decimal_num': [decimal.Decimal('1.23E4')]})
+ filename = tempdir / 'decimals.parquet'
+ string_filename = str(filename)
+ t = pa.Table.from_pandas(expected)
+ _write_table(t, string_filename)
+ result_table = _read_table(string_filename)
+ result = result_table.to_pandas()
+ tm.assert_frame_equal(result, expected)
+
+
+# List types
+# -----------------------------------------------------------------------------
+
+
+@parametrize_legacy_dataset
+@pytest.mark.parametrize('dtype', [int, float])
+def test_single_pylist_column_roundtrip(tempdir, dtype, use_legacy_dataset):
+ filename = tempdir / 'single_{}_column.parquet'.format(dtype.__name__)
+ data = [pa.array(list(map(dtype, range(5))))]
+ table = pa.Table.from_arrays(data, names=['a'])
+ _write_table(table, filename)
+ table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset)
+ for i in range(table.num_columns):
+ col_written = table[i]
+ col_read = table_read[i]
+ assert table.field(i).name == table_read.field(i).name
+ assert col_read.num_chunks == 1
+ data_written = col_written.chunk(0)
+ data_read = col_read.chunk(0)
+ assert data_written.equals(data_read)
+
+
+@parametrize_legacy_dataset
+def test_empty_lists_table_roundtrip(use_legacy_dataset):
+ # ARROW-2744: Shouldn't crash when writing an array of empty lists
+ arr = pa.array([[], []], type=pa.list_(pa.int32()))
+ table = pa.Table.from_arrays([arr], ["A"])
+ _check_roundtrip(table, use_legacy_dataset=use_legacy_dataset)
+
+
+@parametrize_legacy_dataset
+def test_nested_list_nonnullable_roundtrip_bug(use_legacy_dataset):
+ # Reproduce failure in ARROW-5630
+ typ = pa.list_(pa.field("item", pa.float32(), False))
+ num_rows = 10000
+ t = pa.table([
+ pa.array(([[0] * ((i + 5) % 10) for i in range(0, 10)] *
+ (num_rows // 10)), type=typ)
+ ], ['a'])
+ _check_roundtrip(
+ t, data_page_size=4096, use_legacy_dataset=use_legacy_dataset)
+
+
+@parametrize_legacy_dataset
+def test_nested_list_struct_multiple_batches_roundtrip(
+ tempdir, use_legacy_dataset
+):
+ # Reproduce failure in ARROW-11024
+ data = [[{'x': 'abc', 'y': 'abc'}]]*100 + [[{'x': 'abc', 'y': 'gcb'}]]*100
+ table = pa.table([pa.array(data)], names=['column'])
+ _check_roundtrip(
+ table, row_group_size=20, use_legacy_dataset=use_legacy_dataset)
+
+ # Reproduce failure in ARROW-11069 (plain non-nested structs with strings)
+ data = pa.array(
+ [{'a': '1', 'b': '2'}, {'a': '3', 'b': '4'}, {'a': '5', 'b': '6'}]*10
+ )
+ table = pa.table({'column': data})
+ _check_roundtrip(
+ table, row_group_size=10, use_legacy_dataset=use_legacy_dataset)
+
+
+def test_writing_empty_lists():
+ # ARROW-2591: [Python] Segmentation fault issue in pq.write_table
+ arr1 = pa.array([[], []], pa.list_(pa.int32()))
+ table = pa.Table.from_arrays([arr1], ['list(int32)'])
+ _check_roundtrip(table)
+
+
+@pytest.mark.pandas
+def test_column_of_arrays(tempdir):
+ df, schema = dataframe_with_arrays()
+
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df, schema=schema)
+ _write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms')
+ table_read = _read_table(filename)
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+@pytest.mark.pandas
+def test_column_of_lists(tempdir):
+ df, schema = dataframe_with_lists(parquet_compatible=True)
+
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df, schema=schema)
+ _write_table(arrow_table, filename, version='2.6')
+ table_read = _read_table(filename)
+ df_read = table_read.to_pandas()
+
+ tm.assert_frame_equal(df, df_read)
+
+
+def test_large_list_records():
+ # This was fixed in PARQUET-1100
+
+ list_lengths = np.random.randint(0, 500, size=50)
+ list_lengths[::10] = 0
+
+ list_values = [list(map(int, np.random.randint(0, 100, size=x)))
+ if i % 8 else None
+ for i, x in enumerate(list_lengths)]
+
+ a1 = pa.array(list_values)
+
+ table = pa.Table.from_arrays([a1], ['int_lists'])
+ _check_roundtrip(table)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_parquet_nested_convenience(tempdir, use_legacy_dataset):
+ # ARROW-1684
+ df = pd.DataFrame({
+ 'a': [[1, 2, 3], None, [4, 5], []],
+ 'b': [[1.], None, None, [6., 7.]],
+ })
+
+ path = str(tempdir / 'nested_convenience.parquet')
+
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ _write_table(table, path)
+
+ read = pq.read_table(
+ path, columns=['a'], use_legacy_dataset=use_legacy_dataset)
+ tm.assert_frame_equal(read.to_pandas(), df[['a']])
+
+ read = pq.read_table(
+ path, columns=['a', 'b'], use_legacy_dataset=use_legacy_dataset)
+ tm.assert_frame_equal(read.to_pandas(), df)
+
+
+# Binary
+# -----------------------------------------------------------------------------
+
+
+def test_fixed_size_binary():
+ t0 = pa.binary(10)
+ data = [b'fooooooooo', None, b'barooooooo', b'quxooooooo']
+ a0 = pa.array(data, type=t0)
+
+ table = pa.Table.from_arrays([a0],
+ ['binary[10]'])
+ _check_roundtrip(table)
+
+
+# Large types
+# -----------------------------------------------------------------------------
+
+
+@pytest.mark.slow
+@pytest.mark.large_memory
+def test_large_table_int32_overflow():
+ size = np.iinfo('int32').max + 1
+
+ arr = np.ones(size, dtype='uint8')
+
+ parr = pa.array(arr, type=pa.uint8())
+
+ table = pa.Table.from_arrays([parr], names=['one'])
+ f = io.BytesIO()
+ _write_table(table, f)
+
+
+def _simple_table_roundtrip(table, use_legacy_dataset=False, **write_kwargs):
+ stream = pa.BufferOutputStream()
+ _write_table(table, stream, **write_kwargs)
+ buf = stream.getvalue()
+ return _read_table(buf, use_legacy_dataset=use_legacy_dataset)
+
+
+@pytest.mark.slow
+@pytest.mark.large_memory
+@parametrize_legacy_dataset
+def test_byte_array_exactly_2gb(use_legacy_dataset):
+ # Test edge case reported in ARROW-3762
+ val = b'x' * (1 << 10)
+
+ base = pa.array([val] * ((1 << 21) - 1))
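+    # base holds (2**21 - 1) values of 2**10 bytes each, i.e. 2**31 - 2**10
+    # bytes; the extra value appended per case below lands the chunk just
+    # below, exactly at, and just above the 2**31 byte boundary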
+ cases = [
+ [b'x' * 1023], # 2^31 - 1
+ [b'x' * 1024], # 2^31
+ [b'x' * 1025] # 2^31 + 1
+ ]
+ for case in cases:
+ values = pa.chunked_array([base, pa.array(case)])
+ t = pa.table([values], names=['f0'])
+ result = _simple_table_roundtrip(
+ t, use_legacy_dataset=use_legacy_dataset, use_dictionary=False)
+ assert t.equals(result)
+
+
+@pytest.mark.slow
+@pytest.mark.pandas
+@pytest.mark.large_memory
+@parametrize_legacy_dataset
+def test_binary_array_overflow_to_chunked(use_legacy_dataset):
+ # ARROW-3762
+
+ # 2^31 + 1 bytes
+ values = [b'x'] + [
+ b'x' * (1 << 20)
+ ] * 2 * (1 << 10)
+ df = pd.DataFrame({'byte_col': values})
+
+ tbl = pa.Table.from_pandas(df, preserve_index=False)
+ read_tbl = _simple_table_roundtrip(
+ tbl, use_legacy_dataset=use_legacy_dataset)
+
+ col0_data = read_tbl[0]
+ assert isinstance(col0_data, pa.ChunkedArray)
+
+ # Split up into 2GB chunks
+ assert col0_data.num_chunks == 2
+
+ assert tbl.equals(read_tbl)
+
+
+@pytest.mark.slow
+@pytest.mark.pandas
+@pytest.mark.large_memory
+@parametrize_legacy_dataset
+def test_list_of_binary_large_cell(use_legacy_dataset):
+ # ARROW-4688
+ data = []
+
+ # TODO(wesm): handle chunked children
+ # 2^31 - 1 bytes in a single cell
+ # data.append([b'x' * (1 << 20)] * 2047 + [b'x' * ((1 << 20) - 1)])
+
+    # A little under 2GB total, in cells of approximately 10MB each
+ data.extend([[b'x' * 1000000] * 10] * 214)
+
+ arr = pa.array(data)
+ table = pa.Table.from_arrays([arr], ['chunky_cells'])
+ read_table = _simple_table_roundtrip(
+ table, use_legacy_dataset=use_legacy_dataset)
+ assert table.equals(read_table)
+
+
+def test_large_binary():
+ data = [b'foo', b'bar'] * 50
+ for type in [pa.large_binary(), pa.large_string()]:
+ arr = pa.array(data, type=type)
+ table = pa.Table.from_arrays([arr], names=['strs'])
+ for use_dictionary in [False, True]:
+ _check_roundtrip(table, use_dictionary=use_dictionary)
+
+
+@pytest.mark.slow
+@pytest.mark.large_memory
+def test_large_binary_huge():
+ s = b'xy' * 997
+ data = [s] * ((1 << 33) // len(s))
+ for type in [pa.large_binary(), pa.large_string()]:
+ arr = pa.array(data, type=type)
+ table = pa.Table.from_arrays([arr], names=['strs'])
+ for use_dictionary in [False, True]:
+ _check_roundtrip(table, use_dictionary=use_dictionary)
+ del arr, table
+
+
+@pytest.mark.large_memory
+def test_large_binary_overflow():
+ s = b'x' * (1 << 31)
+ arr = pa.array([s], type=pa.large_binary())
+ table = pa.Table.from_arrays([arr], names=['strs'])
+ for use_dictionary in [False, True]:
+ writer = pa.BufferOutputStream()
+ with pytest.raises(
+ pa.ArrowInvalid,
+ match="Parquet cannot store strings with size 2GB or more"):
+ _write_table(table, writer, use_dictionary=use_dictionary)
diff --git a/src/arrow/python/pyarrow/tests/parquet/test_dataset.py b/src/arrow/python/pyarrow/tests/parquet/test_dataset.py
new file mode 100644
index 000000000..82f7e5814
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/test_dataset.py
@@ -0,0 +1,1661 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import os
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+from pyarrow import fs
+from pyarrow.filesystem import LocalFileSystem
+from pyarrow.tests import util
+from pyarrow.tests.parquet.common import (
+ parametrize_legacy_dataset, parametrize_legacy_dataset_fixed,
+ parametrize_legacy_dataset_not_supported)
+from pyarrow.util import guid
+from pyarrow.vendored.version import Version
+
+try:
+ import pyarrow.parquet as pq
+ from pyarrow.tests.parquet.common import (
+ _read_table, _test_dataframe, _write_table)
+except ImportError:
+ pq = None
+
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+
+except ImportError:
+ pd = tm = None
+
+pytestmark = pytest.mark.parquet
+
+
+@pytest.mark.pandas
+def test_parquet_piece_read(tempdir):
+ df = _test_dataframe(1000)
+ table = pa.Table.from_pandas(df)
+
+ path = tempdir / 'parquet_piece_read.parquet'
+ _write_table(table, path, version='2.6')
+
+ with pytest.warns(DeprecationWarning):
+ piece1 = pq.ParquetDatasetPiece(path)
+
+ result = piece1.read()
+ assert result.equals(table)
+
+
+@pytest.mark.pandas
+def test_parquet_piece_open_and_get_metadata(tempdir):
+ df = _test_dataframe(100)
+ table = pa.Table.from_pandas(df)
+
+ path = tempdir / 'parquet_piece_read.parquet'
+ _write_table(table, path, version='2.6')
+
+ with pytest.warns(DeprecationWarning):
+ piece = pq.ParquetDatasetPiece(path)
+ table1 = piece.read()
+ assert isinstance(table1, pa.Table)
+ meta1 = piece.get_metadata()
+ assert isinstance(meta1, pq.FileMetaData)
+
+ assert table.equals(table1)
+
+
+@pytest.mark.filterwarnings("ignore:ParquetDatasetPiece:DeprecationWarning")
+def test_parquet_piece_basics():
+ path = '/baz.parq'
+
+ piece1 = pq.ParquetDatasetPiece(path)
+ piece2 = pq.ParquetDatasetPiece(path, row_group=1)
+ piece3 = pq.ParquetDatasetPiece(
+ path, row_group=1, partition_keys=[('foo', 0), ('bar', 1)])
+
+ assert str(piece1) == path
+ assert str(piece2) == '/baz.parq | row_group=1'
+ assert str(piece3) == 'partition[foo=0, bar=1] /baz.parq | row_group=1'
+
+ assert piece1 == piece1
+ assert piece2 == piece2
+ assert piece3 == piece3
+ assert piece1 != piece3
+
+
+def test_partition_set_dictionary_type():
+ set1 = pq.PartitionSet('key1', ['foo', 'bar', 'baz'])
+ set2 = pq.PartitionSet('key2', [2007, 2008, 2009])
+
+ assert isinstance(set1.dictionary, pa.StringArray)
+ assert isinstance(set2.dictionary, pa.IntegerArray)
+
+ set3 = pq.PartitionSet('key2', [datetime.datetime(2007, 1, 1)])
+ with pytest.raises(TypeError):
+ set3.dictionary
+
+
+@parametrize_legacy_dataset_fixed
+def test_filesystem_uri(tempdir, use_legacy_dataset):
+ table = pa.table({"a": [1, 2, 3]})
+
+ directory = tempdir / "data_dir"
+ directory.mkdir()
+ path = directory / "data.parquet"
+ pq.write_table(table, str(path))
+
+ # filesystem object
+ result = pq.read_table(
+ path, filesystem=fs.LocalFileSystem(),
+ use_legacy_dataset=use_legacy_dataset)
+ assert result.equals(table)
+
+ # filesystem URI
+ result = pq.read_table(
+ "data_dir/data.parquet", filesystem=util._filesystem_uri(tempdir),
+ use_legacy_dataset=use_legacy_dataset)
+ assert result.equals(table)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_read_partitioned_directory(tempdir, use_legacy_dataset):
+ fs = LocalFileSystem._get_instance()
+ _partition_test_for_filesystem(fs, tempdir, use_legacy_dataset)
+
+
+@pytest.mark.filterwarnings("ignore:'ParquetDataset:DeprecationWarning")
+@pytest.mark.pandas
+def test_create_parquet_dataset_multi_threaded(tempdir):
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ _partition_test_for_filesystem(fs, base_path)
+
+ manifest = pq.ParquetManifest(base_path, filesystem=fs,
+ metadata_nthreads=1)
+ dataset = pq.ParquetDataset(base_path, filesystem=fs, metadata_nthreads=16)
+ assert len(dataset.pieces) > 0
+ partitions = dataset.partitions
+ assert len(partitions.partition_names) > 0
+ assert partitions.partition_names == manifest.partitions.partition_names
+ assert len(partitions.levels) == len(manifest.partitions.levels)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_read_partitioned_columns_selection(tempdir, use_legacy_dataset):
+ # ARROW-3861 - do not include partition columns in resulting table when
+ # `columns` keyword was passed without those columns
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+ _partition_test_for_filesystem(fs, base_path)
+
+ dataset = pq.ParquetDataset(
+ base_path, use_legacy_dataset=use_legacy_dataset)
+ result = dataset.read(columns=["values"])
+ if use_legacy_dataset:
+ # ParquetDataset implementation always includes the partition columns
+ # automatically, and we can't easily "fix" this since dask relies on
+ # this behaviour (ARROW-8644)
+ assert result.column_names == ["values", "foo", "bar"]
+ else:
+ assert result.column_names == ["values"]
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_filters_equivalency(tempdir, use_legacy_dataset):
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ integer_keys = [0, 1]
+ string_keys = ['a', 'b', 'c']
+ boolean_keys = [True, False]
+ partition_spec = [
+ ['integer', integer_keys],
+ ['string', string_keys],
+ ['boolean', boolean_keys]
+ ]
+
+ df = pd.DataFrame({
+ 'integer': np.array(integer_keys, dtype='i4').repeat(15),
+ 'string': np.tile(np.tile(np.array(string_keys, dtype=object), 5), 2),
+ 'boolean': np.tile(np.tile(np.array(boolean_keys, dtype='bool'), 5),
+ 3),
+ }, columns=['integer', 'string', 'boolean'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ # Old filters syntax:
+ # integer == 1 AND string != b AND boolean == True
+ dataset = pq.ParquetDataset(
+ base_path, filesystem=fs,
+ filters=[('integer', '=', 1), ('string', '!=', 'b'),
+ ('boolean', '==', 'True')],
+ use_legacy_dataset=use_legacy_dataset,
+ )
+ table = dataset.read()
+ result_df = (table.to_pandas().reset_index(drop=True))
+
+ assert 0 not in result_df['integer'].values
+ assert 'b' not in result_df['string'].values
+ assert False not in result_df['boolean'].values
+
+ # filters in disjunctive normal form:
+ # (integer == 1 AND string != b AND boolean == True) OR
+ # (integer == 2 AND boolean == False)
+ # TODO(ARROW-3388): boolean columns are reconstructed as string
+ filters = [
+ [
+ ('integer', '=', 1),
+ ('string', '!=', 'b'),
+ ('boolean', '==', 'True')
+ ],
+ [('integer', '=', 0), ('boolean', '==', 'False')]
+ ]
+ dataset = pq.ParquetDataset(
+ base_path, filesystem=fs, filters=filters,
+ use_legacy_dataset=use_legacy_dataset)
+ table = dataset.read()
+ result_df = table.to_pandas().reset_index(drop=True)
+
+ # Check that all rows in the DF fulfill the filter
+ # Pandas 0.23.x has problems with indexing constant memoryviews in
+ # categoricals. Thus we need to make an explicit copy here with np.array.
+ df_filter_1 = (np.array(result_df['integer']) == 1) \
+ & (np.array(result_df['string']) != 'b') \
+ & (np.array(result_df['boolean']) == 'True')
+ df_filter_2 = (np.array(result_df['integer']) == 0) \
+ & (np.array(result_df['boolean']) == 'False')
+ assert df_filter_1.sum() > 0
+ assert df_filter_2.sum() > 0
+ assert result_df.shape[0] == (df_filter_1.sum() + df_filter_2.sum())
+
+ if use_legacy_dataset:
+ # Check for \0 in predicate values. Until they are correctly
+ # implemented in ARROW-3391, they would otherwise lead to weird
+ # results with the current code.
+ with pytest.raises(NotImplementedError):
+ filters = [[('string', '==', b'1\0a')]]
+ pq.ParquetDataset(base_path, filesystem=fs, filters=filters)
+ with pytest.raises(NotImplementedError):
+ filters = [[('string', '==', '1\0a')]]
+ pq.ParquetDataset(base_path, filesystem=fs, filters=filters)
+ else:
+ for filters in [[[('string', '==', b'1\0a')]],
+ [[('string', '==', '1\0a')]]]:
+ dataset = pq.ParquetDataset(
+ base_path, filesystem=fs, filters=filters,
+ use_legacy_dataset=False)
+ assert dataset.read().num_rows == 0
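+
+
+# Illustrative sketch (not part of the original change): under the non-legacy
+# code path, a DNF filter like the one above is evaluated as a
+# pyarrow.dataset expression, roughly equivalent to this hand-written one
+# (``ds`` is an assumed alias for pyarrow.dataset):
+def _example_dnf_as_expression():
+    import pyarrow.dataset as ds
+    return (
+        ((ds.field('integer') == 1)
+         & (ds.field('string') != 'b')
+         & (ds.field('boolean') == 'True'))
+        | ((ds.field('integer') == 0) & (ds.field('boolean') == 'False'))
+    )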
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_filters_cutoff_exclusive_integer(tempdir, use_legacy_dataset):
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ integer_keys = [0, 1, 2, 3, 4]
+ partition_spec = [
+ ['integers', integer_keys],
+ ]
+ N = 5
+
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'integers': np.array(integer_keys, dtype='i4'),
+ }, columns=['index', 'integers'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ dataset = pq.ParquetDataset(
+ base_path, filesystem=fs,
+ filters=[
+ ('integers', '<', 4),
+ ('integers', '>', 1),
+ ],
+ use_legacy_dataset=use_legacy_dataset
+ )
+ table = dataset.read()
+ result_df = (table.to_pandas()
+ .sort_values(by='index')
+ .reset_index(drop=True))
+
+    result_list = [int(x) for x in result_df['integers'].values]
+ assert result_list == [2, 3]
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+@pytest.mark.xfail(
+    # different error with use_legacy_dataset=False because result_df is
+    # no longer categorical
+ raises=(TypeError, AssertionError),
+ reason='Loss of type information in creation of categoricals.'
+)
+def test_filters_cutoff_exclusive_datetime(tempdir, use_legacy_dataset):
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ date_keys = [
+ datetime.date(2018, 4, 9),
+ datetime.date(2018, 4, 10),
+ datetime.date(2018, 4, 11),
+ datetime.date(2018, 4, 12),
+ datetime.date(2018, 4, 13)
+ ]
+ partition_spec = [
+ ['dates', date_keys]
+ ]
+ N = 5
+
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'dates': np.array(date_keys, dtype='datetime64'),
+ }, columns=['index', 'dates'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ dataset = pq.ParquetDataset(
+ base_path, filesystem=fs,
+ filters=[
+ ('dates', '<', "2018-04-12"),
+ ('dates', '>', "2018-04-10")
+ ],
+ use_legacy_dataset=use_legacy_dataset
+ )
+ table = dataset.read()
+ result_df = (table.to_pandas()
+ .sort_values(by='index')
+ .reset_index(drop=True))
+
+ expected = pd.Categorical(
+ np.array([datetime.date(2018, 4, 11)], dtype='datetime64'),
+ categories=np.array(date_keys, dtype='datetime64'))
+
+ assert result_df['dates'].values == expected
+
+
+@pytest.mark.pandas
+@pytest.mark.dataset
+def test_filters_inclusive_datetime(tempdir):
+ # ARROW-11480
+ path = tempdir / 'timestamps.parquet'
+
+ pd.DataFrame({
+ "dates": pd.date_range("2020-01-01", periods=10, freq="D"),
+ "id": range(10)
+ }).to_parquet(path, use_deprecated_int96_timestamps=True)
+
+ table = pq.read_table(path, filters=[
+ ("dates", "<=", datetime.datetime(2020, 1, 5))
+ ])
+
+ assert table.column('id').to_pylist() == [0, 1, 2, 3, 4]
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_filters_inclusive_integer(tempdir, use_legacy_dataset):
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ integer_keys = [0, 1, 2, 3, 4]
+ partition_spec = [
+ ['integers', integer_keys],
+ ]
+ N = 5
+
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'integers': np.array(integer_keys, dtype='i4'),
+ }, columns=['index', 'integers'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ dataset = pq.ParquetDataset(
+ base_path, filesystem=fs,
+ filters=[
+ ('integers', '<=', 3),
+ ('integers', '>=', 2),
+ ],
+ use_legacy_dataset=use_legacy_dataset
+ )
+ table = dataset.read()
+ result_df = (table.to_pandas()
+ .sort_values(by='index')
+ .reset_index(drop=True))
+
+    result_list = [int(x) for x in result_df['integers'].values]
+ assert result_list == [2, 3]
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_filters_inclusive_set(tempdir, use_legacy_dataset):
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ integer_keys = [0, 1]
+ string_keys = ['a', 'b', 'c']
+ boolean_keys = [True, False]
+ partition_spec = [
+ ['integer', integer_keys],
+ ['string', string_keys],
+ ['boolean', boolean_keys]
+ ]
+
+ df = pd.DataFrame({
+ 'integer': np.array(integer_keys, dtype='i4').repeat(15),
+ 'string': np.tile(np.tile(np.array(string_keys, dtype=object), 5), 2),
+ 'boolean': np.tile(np.tile(np.array(boolean_keys, dtype='bool'), 5),
+ 3),
+ }, columns=['integer', 'string', 'boolean'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ dataset = pq.ParquetDataset(
+ base_path, filesystem=fs,
+ filters=[('string', 'in', 'ab')],
+ use_legacy_dataset=use_legacy_dataset
+ )
+ table = dataset.read()
+ result_df = (table.to_pandas().reset_index(drop=True))
+
+ assert 'a' in result_df['string'].values
+ assert 'b' in result_df['string'].values
+ assert 'c' not in result_df['string'].values
+
+ dataset = pq.ParquetDataset(
+ base_path, filesystem=fs,
+ filters=[('integer', 'in', [1]), ('string', 'in', ('a', 'b')),
+ ('boolean', 'not in', {False})],
+ use_legacy_dataset=use_legacy_dataset
+ )
+ table = dataset.read()
+ result_df = (table.to_pandas().reset_index(drop=True))
+
+ assert 0 not in result_df['integer'].values
+ assert 'c' not in result_df['string'].values
+ assert False not in result_df['boolean'].values
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_filters_invalid_pred_op(tempdir, use_legacy_dataset):
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ integer_keys = [0, 1, 2, 3, 4]
+ partition_spec = [
+ ['integers', integer_keys],
+ ]
+ N = 5
+
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'integers': np.array(integer_keys, dtype='i4'),
+ }, columns=['index', 'integers'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ with pytest.raises(TypeError):
+ pq.ParquetDataset(base_path,
+ filesystem=fs,
+ filters=[('integers', 'in', 3), ],
+ use_legacy_dataset=use_legacy_dataset)
+
+ with pytest.raises(ValueError):
+ pq.ParquetDataset(base_path,
+ filesystem=fs,
+ filters=[('integers', '=<', 3), ],
+ use_legacy_dataset=use_legacy_dataset)
+
+ if use_legacy_dataset:
+ with pytest.raises(ValueError):
+ pq.ParquetDataset(base_path,
+ filesystem=fs,
+ filters=[('integers', 'in', set()), ],
+ use_legacy_dataset=use_legacy_dataset)
+ else:
+ # Dataset API returns empty table instead
+ dataset = pq.ParquetDataset(base_path,
+ filesystem=fs,
+ filters=[('integers', 'in', set()), ],
+ use_legacy_dataset=use_legacy_dataset)
+ assert dataset.read().num_rows == 0
+
+ if use_legacy_dataset:
+ with pytest.raises(ValueError):
+ pq.ParquetDataset(base_path,
+ filesystem=fs,
+ filters=[('integers', '!=', {3})],
+ use_legacy_dataset=use_legacy_dataset)
+ else:
+ dataset = pq.ParquetDataset(base_path,
+ filesystem=fs,
+ filters=[('integers', '!=', {3})],
+ use_legacy_dataset=use_legacy_dataset)
+ with pytest.raises(NotImplementedError):
+ assert dataset.read().num_rows == 0
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset_fixed
+def test_filters_invalid_column(tempdir, use_legacy_dataset):
+ # ARROW-5572 - raise error on invalid name in filter specification
+ # works with new dataset / xfail with legacy implementation
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ integer_keys = [0, 1, 2, 3, 4]
+ partition_spec = [['integers', integer_keys]]
+ N = 5
+
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'integers': np.array(integer_keys, dtype='i4'),
+ }, columns=['index', 'integers'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ msg = r"No match for FieldRef.Name\(non_existent_column\)"
+ with pytest.raises(ValueError, match=msg):
+ pq.ParquetDataset(base_path, filesystem=fs,
+ filters=[('non_existent_column', '<', 3), ],
+ use_legacy_dataset=use_legacy_dataset).read()
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_filters_read_table(tempdir, use_legacy_dataset):
+ # test that filters keyword is passed through in read_table
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ integer_keys = [0, 1, 2, 3, 4]
+ partition_spec = [
+ ['integers', integer_keys],
+ ]
+ N = 5
+
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'integers': np.array(integer_keys, dtype='i4'),
+ }, columns=['index', 'integers'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ table = pq.read_table(
+ base_path, filesystem=fs, filters=[('integers', '<', 3)],
+ use_legacy_dataset=use_legacy_dataset)
+ assert table.num_rows == 3
+
+ table = pq.read_table(
+ base_path, filesystem=fs, filters=[[('integers', '<', 3)]],
+ use_legacy_dataset=use_legacy_dataset)
+ assert table.num_rows == 3
+
+ table = pq.read_pandas(
+ base_path, filters=[('integers', '<', 3)],
+ use_legacy_dataset=use_legacy_dataset)
+ assert table.num_rows == 3
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset_fixed
+def test_partition_keys_with_underscores(tempdir, use_legacy_dataset):
+ # ARROW-5666 - partition field values with underscores preserve underscores
+ # xfail with legacy dataset -> they get interpreted as integers
+ fs = LocalFileSystem._get_instance()
+ base_path = tempdir
+
+ string_keys = ["2019_2", "2019_3"]
+ partition_spec = [
+ ['year_week', string_keys],
+ ]
+ N = 2
+
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'year_week': np.array(string_keys, dtype='object'),
+ }, columns=['index', 'year_week'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ dataset = pq.ParquetDataset(
+ base_path, use_legacy_dataset=use_legacy_dataset)
+ result = dataset.read()
+ assert result.column("year_week").to_pylist() == string_keys
+
+
+@pytest.mark.s3
+@parametrize_legacy_dataset
+def test_read_s3fs(s3_example_s3fs, use_legacy_dataset):
+ fs, path = s3_example_s3fs
+ path = path + "/test.parquet"
+ table = pa.table({"a": [1, 2, 3]})
+ _write_table(table, path, filesystem=fs)
+
+ result = _read_table(
+ path, filesystem=fs, use_legacy_dataset=use_legacy_dataset
+ )
+ assert result.equals(table)
+
+
+@pytest.mark.s3
+@parametrize_legacy_dataset
+def test_read_directory_s3fs(s3_example_s3fs, use_legacy_dataset):
+ fs, directory = s3_example_s3fs
+ path = directory + "/test.parquet"
+ table = pa.table({"a": [1, 2, 3]})
+ _write_table(table, path, filesystem=fs)
+
+ result = _read_table(
+ directory, filesystem=fs, use_legacy_dataset=use_legacy_dataset
+ )
+ assert result.equals(table)
+
+
+@pytest.mark.pandas
+@pytest.mark.s3
+@parametrize_legacy_dataset
+def test_read_partitioned_directory_s3fs_wrapper(
+ s3_example_s3fs, use_legacy_dataset
+):
+ import s3fs
+
+ from pyarrow.filesystem import S3FSWrapper
+
+ if Version(s3fs.__version__) >= Version("0.5"):
+ pytest.skip("S3FSWrapper no longer working for s3fs 0.5+")
+
+ fs, path = s3_example_s3fs
+ with pytest.warns(FutureWarning):
+ wrapper = S3FSWrapper(fs)
+ _partition_test_for_filesystem(wrapper, path)
+
+ # Check that we can auto-wrap
+ dataset = pq.ParquetDataset(
+ path, filesystem=fs, use_legacy_dataset=use_legacy_dataset
+ )
+ dataset.read()
+
+
+@pytest.mark.pandas
+@pytest.mark.s3
+@parametrize_legacy_dataset
+def test_read_partitioned_directory_s3fs(s3_example_s3fs, use_legacy_dataset):
+ fs, path = s3_example_s3fs
+ _partition_test_for_filesystem(
+ fs, path, use_legacy_dataset=use_legacy_dataset
+ )
+
+
+def _partition_test_for_filesystem(fs, base_path, use_legacy_dataset=True):
+ foo_keys = [0, 1]
+ bar_keys = ['a', 'b', 'c']
+ partition_spec = [
+ ['foo', foo_keys],
+ ['bar', bar_keys]
+ ]
+ N = 30
+
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'foo': np.array(foo_keys, dtype='i4').repeat(15),
+ 'bar': np.tile(np.tile(np.array(bar_keys, dtype=object), 5), 2),
+ 'values': np.random.randn(N)
+ }, columns=['index', 'foo', 'bar', 'values'])
+
+ _generate_partition_directories(fs, base_path, partition_spec, df)
+
+ dataset = pq.ParquetDataset(
+ base_path, filesystem=fs, use_legacy_dataset=use_legacy_dataset)
+ table = dataset.read()
+ result_df = (table.to_pandas()
+ .sort_values(by='index')
+ .reset_index(drop=True))
+
+ expected_df = (df.sort_values(by='index')
+ .reset_index(drop=True)
+ .reindex(columns=result_df.columns))
+
+ expected_df['foo'] = pd.Categorical(df['foo'], categories=foo_keys)
+ expected_df['bar'] = pd.Categorical(df['bar'], categories=bar_keys)
+
+ assert (result_df.columns == ['index', 'values', 'foo', 'bar']).all()
+
+ tm.assert_frame_equal(result_df, expected_df)
+
+
+def _generate_partition_directories(fs, base_dir, partition_spec, df):
+    # partition_spec : list of [name, values] pairs, e.g.
+    #                  [['foo', [0, 1, 2]], ['bar', ['a', 'b', 'c']]]
+    # df : pandas DataFrame whose matching rows are written to each
+    #      leaf partition directory
+ DEPTH = len(partition_spec)
+
+ pathsep = getattr(fs, "pathsep", getattr(fs, "sep", "/"))
+
+ def _visit_level(base_dir, level, part_keys):
+ name, values = partition_spec[level]
+ for value in values:
+ this_part_keys = part_keys + [(name, value)]
+
+ level_dir = pathsep.join([
+ str(base_dir),
+ '{}={}'.format(name, value)
+ ])
+ fs.mkdir(level_dir)
+
+ if level == DEPTH - 1:
+ # Generate example data
+ file_path = pathsep.join([level_dir, guid()])
+ filtered_df = _filter_partition(df, this_part_keys)
+ part_table = pa.Table.from_pandas(filtered_df)
+ with fs.open(file_path, 'wb') as f:
+ _write_table(part_table, f)
+ assert fs.exists(file_path)
+
+ file_success = pathsep.join([level_dir, '_SUCCESS'])
+ with fs.open(file_success, 'wb') as f:
+ pass
+ else:
+ _visit_level(level_dir, level + 1, this_part_keys)
+ file_success = pathsep.join([level_dir, '_SUCCESS'])
+ with fs.open(file_success, 'wb') as f:
+ pass
+
+ _visit_level(base_dir, 0, [])
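+
+
+# The helper above lays the data out hive-style, e.g. for
+# partition_spec = [['foo', [0, 1]], ['bar', ['a', 'b', 'c']]]:
+#
+#   base_dir/foo=0/bar=a/<guid>     (one parquet file per leaf directory)
+#   base_dir/foo=0/bar=a/_SUCCESS
+#   base_dir/foo=0/bar=b/...
+#   base_dir/foo=0/_SUCCESS
+#   base_dir/foo=1/...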
+
+
+def _test_read_common_metadata_files(fs, base_path):
+ import pandas as pd
+
+ import pyarrow.parquet as pq
+
+ N = 100
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'values': np.random.randn(N)
+ }, columns=['index', 'values'])
+
+ base_path = str(base_path)
+ data_path = os.path.join(base_path, 'data.parquet')
+
+ table = pa.Table.from_pandas(df)
+
+ with fs.open(data_path, 'wb') as f:
+ _write_table(table, f)
+
+ metadata_path = os.path.join(base_path, '_common_metadata')
+ with fs.open(metadata_path, 'wb') as f:
+ pq.write_metadata(table.schema, f)
+
+ dataset = pq.ParquetDataset(base_path, filesystem=fs)
+ assert dataset.common_metadata_path == str(metadata_path)
+
+ with fs.open(data_path) as f:
+ common_schema = pq.read_metadata(f).schema
+ assert dataset.schema.equals(common_schema)
+
+ # handle list of one directory
+ dataset2 = pq.ParquetDataset([base_path], filesystem=fs)
+ assert dataset2.schema.equals(dataset.schema)
+
+
+@pytest.mark.pandas
+def test_read_common_metadata_files(tempdir):
+ fs = LocalFileSystem._get_instance()
+ _test_read_common_metadata_files(fs, tempdir)
+
+
+@pytest.mark.pandas
+def test_read_metadata_files(tempdir):
+ fs = LocalFileSystem._get_instance()
+
+ N = 100
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'values': np.random.randn(N)
+ }, columns=['index', 'values'])
+
+ data_path = tempdir / 'data.parquet'
+
+ table = pa.Table.from_pandas(df)
+
+ with fs.open(data_path, 'wb') as f:
+ _write_table(table, f)
+
+ metadata_path = tempdir / '_metadata'
+ with fs.open(metadata_path, 'wb') as f:
+ pq.write_metadata(table.schema, f)
+
+ dataset = pq.ParquetDataset(tempdir, filesystem=fs)
+ assert dataset.metadata_path == str(metadata_path)
+
+ with fs.open(data_path) as f:
+ metadata_schema = pq.read_metadata(f).schema
+ assert dataset.schema.equals(metadata_schema)
+
+
+def _filter_partition(df, part_keys):
+ predicate = np.ones(len(df), dtype=bool)
+
+ to_drop = []
+ for name, value in part_keys:
+ to_drop.append(name)
+
+ # to avoid pandas warning
+ if isinstance(value, (datetime.date, datetime.datetime)):
+ value = pd.Timestamp(value)
+
+ predicate &= df[name] == value
+
+ return df[predicate].drop(to_drop, axis=1)
+
+
+@parametrize_legacy_dataset
+@pytest.mark.pandas
+def test_filter_before_validate_schema(tempdir, use_legacy_dataset):
+ # ARROW-4076 apply filter before schema validation
+ # to avoid checking unneeded schemas
+
+    # create partitioned dataset with mismatching schemas which would
+    # otherwise raise if all schemas were validated first
+ dir1 = tempdir / 'A=0'
+ dir1.mkdir()
+ table1 = pa.Table.from_pandas(pd.DataFrame({'B': [1, 2, 3]}))
+ pq.write_table(table1, dir1 / 'data.parquet')
+
+ dir2 = tempdir / 'A=1'
+ dir2.mkdir()
+ table2 = pa.Table.from_pandas(pd.DataFrame({'B': ['a', 'b', 'c']}))
+ pq.write_table(table2, dir2 / 'data.parquet')
+
+ # read single file using filter
+ table = pq.read_table(tempdir, filters=[[('A', '==', 0)]],
+ use_legacy_dataset=use_legacy_dataset)
+ assert table.column('B').equals(pa.chunked_array([[1, 2, 3]]))
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_read_multiple_files(tempdir, use_legacy_dataset):
+ nfiles = 10
+ size = 5
+
+ dirpath = tempdir / guid()
+ dirpath.mkdir()
+
+ test_data = []
+ paths = []
+ for i in range(nfiles):
+ df = _test_dataframe(size, seed=i)
+
+ # Hack so that we don't have a dtype cast in v1 files
+ df['uint32'] = df['uint32'].astype(np.int64)
+
+ path = dirpath / '{}.parquet'.format(i)
+
+ table = pa.Table.from_pandas(df)
+ _write_table(table, path)
+
+ test_data.append(table)
+ paths.append(path)
+
+ # Write a _SUCCESS.crc file
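+    # (such marker files must be ignored when the directory is read below)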
+ (dirpath / '_SUCCESS.crc').touch()
+
+ def read_multiple_files(paths, columns=None, use_threads=True, **kwargs):
+ dataset = pq.ParquetDataset(
+ paths, use_legacy_dataset=use_legacy_dataset, **kwargs)
+ return dataset.read(columns=columns, use_threads=use_threads)
+
+ result = read_multiple_files(paths)
+ expected = pa.concat_tables(test_data)
+
+ assert result.equals(expected)
+
+ # Read with provided metadata
+ # TODO(dataset) specifying metadata not yet supported
+ metadata = pq.read_metadata(paths[0])
+ if use_legacy_dataset:
+ result2 = read_multiple_files(paths, metadata=metadata)
+ assert result2.equals(expected)
+
+ result3 = pq.ParquetDataset(dirpath, schema=metadata.schema).read()
+ assert result3.equals(expected)
+ else:
+ with pytest.raises(ValueError, match="no longer supported"):
+ pq.read_table(paths, metadata=metadata, use_legacy_dataset=False)
+
+ # Read column subset
+ to_read = [0, 2, 6, result.num_columns - 1]
+
+ col_names = [result.field(i).name for i in to_read]
+ out = pq.read_table(
+ dirpath, columns=col_names, use_legacy_dataset=use_legacy_dataset
+ )
+ expected = pa.Table.from_arrays([result.column(i) for i in to_read],
+ names=col_names,
+ metadata=result.schema.metadata)
+ assert out.equals(expected)
+
+ # Read with multiple threads
+ pq.read_table(
+ dirpath, use_threads=True, use_legacy_dataset=use_legacy_dataset
+ )
+
+ # Test failure modes with non-uniform metadata
+ bad_apple = _test_dataframe(size, seed=i).iloc[:, :4]
+ bad_apple_path = tempdir / '{}.parquet'.format(guid())
+
+ t = pa.Table.from_pandas(bad_apple)
+ _write_table(t, bad_apple_path)
+
+ if not use_legacy_dataset:
+ # TODO(dataset) Dataset API skips bad files
+ return
+
+ bad_meta = pq.read_metadata(bad_apple_path)
+
+ with pytest.raises(ValueError):
+ read_multiple_files(paths + [bad_apple_path])
+
+ with pytest.raises(ValueError):
+ read_multiple_files(paths, metadata=bad_meta)
+
+ mixed_paths = [bad_apple_path, paths[0]]
+
+ with pytest.raises(ValueError):
+ read_multiple_files(mixed_paths, schema=bad_meta.schema)
+
+ with pytest.raises(ValueError):
+ read_multiple_files(mixed_paths)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_dataset_read_pandas(tempdir, use_legacy_dataset):
+ nfiles = 5
+ size = 5
+
+ dirpath = tempdir / guid()
+ dirpath.mkdir()
+
+ test_data = []
+ frames = []
+ paths = []
+ for i in range(nfiles):
+ df = _test_dataframe(size, seed=i)
+ df.index = np.arange(i * size, (i + 1) * size)
+ df.index.name = 'index'
+
+ path = dirpath / '{}.parquet'.format(i)
+
+ table = pa.Table.from_pandas(df)
+ _write_table(table, path)
+ test_data.append(table)
+ frames.append(df)
+ paths.append(path)
+
+ dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset)
+ columns = ['uint8', 'strings']
+ result = dataset.read_pandas(columns=columns).to_pandas()
+ expected = pd.concat([x[columns] for x in frames])
+
+ tm.assert_frame_equal(result, expected)
+
+ # also be able to pass the columns as a set (ARROW-12314)
+ result = dataset.read_pandas(columns=set(columns)).to_pandas()
+ assert result.shape == expected.shape
+ # column order can be different because of using a set
+ tm.assert_frame_equal(result.reindex(columns=expected.columns), expected)
+
+
+@pytest.mark.filterwarnings("ignore:'ParquetDataset:DeprecationWarning")
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_dataset_memory_map(tempdir, use_legacy_dataset):
+ # ARROW-2627: Check that we can use ParquetDataset with memory-mapping
+ dirpath = tempdir / guid()
+ dirpath.mkdir()
+
+ df = _test_dataframe(10, seed=0)
+ path = dirpath / '{}.parquet'.format(0)
+ table = pa.Table.from_pandas(df)
+ _write_table(table, path, version='2.6')
+
+ dataset = pq.ParquetDataset(
+ dirpath, memory_map=True, use_legacy_dataset=use_legacy_dataset)
+ assert dataset.read().equals(table)
+ if use_legacy_dataset:
+ assert dataset.pieces[0].read().equals(table)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_dataset_enable_buffered_stream(tempdir, use_legacy_dataset):
+ dirpath = tempdir / guid()
+ dirpath.mkdir()
+
+ df = _test_dataframe(10, seed=0)
+ path = dirpath / '{}.parquet'.format(0)
+ table = pa.Table.from_pandas(df)
+ _write_table(table, path, version='2.6')
+
+ with pytest.raises(ValueError):
+ pq.ParquetDataset(
+ dirpath, buffer_size=-64,
+ use_legacy_dataset=use_legacy_dataset)
+
+ for buffer_size in [128, 1024]:
+ dataset = pq.ParquetDataset(
+ dirpath, buffer_size=buffer_size,
+ use_legacy_dataset=use_legacy_dataset)
+ assert dataset.read().equals(table)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_dataset_enable_pre_buffer(tempdir, use_legacy_dataset):
+ dirpath = tempdir / guid()
+ dirpath.mkdir()
+
+ df = _test_dataframe(10, seed=0)
+ path = dirpath / '{}.parquet'.format(0)
+ table = pa.Table.from_pandas(df)
+ _write_table(table, path, version='2.6')
+
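+    # pre_buffer coalesces column-chunk reads into larger I/O requests;
+    # either setting must return the same data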
+ for pre_buffer in (True, False):
+ dataset = pq.ParquetDataset(
+ dirpath, pre_buffer=pre_buffer,
+ use_legacy_dataset=use_legacy_dataset)
+ assert dataset.read().equals(table)
+ actual = pq.read_table(dirpath, pre_buffer=pre_buffer,
+ use_legacy_dataset=use_legacy_dataset)
+ assert actual.equals(table)
+
+
+def _make_example_multifile_dataset(base_path, nfiles=10, file_nrows=5):
+ test_data = []
+ paths = []
+ for i in range(nfiles):
+ df = _test_dataframe(file_nrows, seed=i)
+ path = base_path / '{}.parquet'.format(i)
+
+ test_data.append(_write_table(df, path))
+ paths.append(path)
+ return paths
+
+
+def _assert_dataset_paths(dataset, paths, use_legacy_dataset):
+ if use_legacy_dataset:
+ assert set(map(str, paths)) == {x.path for x in dataset._pieces}
+ else:
+ paths = [str(path.as_posix()) for path in paths]
+ assert set(paths) == set(dataset._dataset.files)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+@pytest.mark.parametrize('dir_prefix', ['_', '.'])
+def test_ignore_private_directories(tempdir, dir_prefix, use_legacy_dataset):
+ dirpath = tempdir / guid()
+ dirpath.mkdir()
+
+ paths = _make_example_multifile_dataset(dirpath, nfiles=10,
+ file_nrows=5)
+
+ # private directory
+ (dirpath / '{}staging'.format(dir_prefix)).mkdir()
+
+ dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset)
+
+ _assert_dataset_paths(dataset, paths, use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_ignore_hidden_files_dot(tempdir, use_legacy_dataset):
+ dirpath = tempdir / guid()
+ dirpath.mkdir()
+
+ paths = _make_example_multifile_dataset(dirpath, nfiles=10,
+ file_nrows=5)
+
+ with (dirpath / '.DS_Store').open('wb') as f:
+ f.write(b'gibberish')
+
+ with (dirpath / '.private').open('wb') as f:
+ f.write(b'gibberish')
+
+ dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset)
+
+ _assert_dataset_paths(dataset, paths, use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_ignore_hidden_files_underscore(tempdir, use_legacy_dataset):
+ dirpath = tempdir / guid()
+ dirpath.mkdir()
+
+ paths = _make_example_multifile_dataset(dirpath, nfiles=10,
+ file_nrows=5)
+
+ with (dirpath / '_committed_123').open('wb') as f:
+ f.write(b'abcd')
+
+ with (dirpath / '_started_321').open('wb') as f:
+ f.write(b'abcd')
+
+ dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset)
+
+ _assert_dataset_paths(dataset, paths, use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+@pytest.mark.parametrize('dir_prefix', ['_', '.'])
+def test_ignore_no_private_directories_in_base_path(
+ tempdir, dir_prefix, use_legacy_dataset
+):
+ # ARROW-8427 - don't ignore explicitly listed files if parent directory
+ # is a private directory
+ dirpath = tempdir / "{0}data".format(dir_prefix) / guid()
+ dirpath.mkdir(parents=True)
+
+ paths = _make_example_multifile_dataset(dirpath, nfiles=10,
+ file_nrows=5)
+
+ dataset = pq.ParquetDataset(paths, use_legacy_dataset=use_legacy_dataset)
+ _assert_dataset_paths(dataset, paths, use_legacy_dataset)
+
+ # ARROW-9644 - don't ignore the whole dataset when the base path itself
+ # contains an underscore-prefixed directory
+ dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset)
+ _assert_dataset_paths(dataset, paths, use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset_fixed
+def test_ignore_custom_prefixes(tempdir, use_legacy_dataset):
+ # ARROW-9573 - allow override of default ignore_prefixes
+ part = ["xxx"] * 3 + ["yyy"] * 3
+ table = pa.table([
+ pa.array(range(len(part))),
+ pa.array(part).dictionary_encode(),
+ ], names=['index', '_part'])
+
+ # TODO use_legacy_dataset ARROW-10247
+ pq.write_to_dataset(table, str(tempdir), partition_cols=['_part'])
+
+ private_duplicate = tempdir / '_private_duplicate'
+ private_duplicate.mkdir()
+ pq.write_to_dataset(table, str(private_duplicate),
+ partition_cols=['_part'])
+
+ read = pq.read_table(
+ tempdir, use_legacy_dataset=use_legacy_dataset,
+ ignore_prefixes=['_private'])
+
+ assert read.equals(table)
+
+
+@parametrize_legacy_dataset_fixed
+def test_empty_directory(tempdir, use_legacy_dataset):
+ # ARROW-5310 - reading empty directory
+ # fails with legacy implementation
+ empty_dir = tempdir / 'dataset'
+ empty_dir.mkdir()
+
+ dataset = pq.ParquetDataset(
+ empty_dir, use_legacy_dataset=use_legacy_dataset)
+ result = dataset.read()
+ assert result.num_rows == 0
+ assert result.num_columns == 0
+
+
+def _test_write_to_dataset_with_partitions(base_path,
+ use_legacy_dataset=True,
+ filesystem=None,
+ schema=None,
+ index_name=None):
+ import pandas as pd
+ import pandas.testing as tm
+
+ import pyarrow.parquet as pq
+
+ # ARROW-1400
+ output_df = pd.DataFrame({'group1': list('aaabbbbccc'),
+ 'group2': list('eefeffgeee'),
+ 'num': list(range(10)),
+ 'nan': [np.nan] * 10,
+ 'date': np.arange('2017-01-01', '2017-01-11',
+ dtype='datetime64[D]')})
+ cols = output_df.columns.tolist()
+ partition_by = ['group1', 'group2']
+ output_table = pa.Table.from_pandas(output_df, schema=schema, safe=False,
+ preserve_index=False)
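+ # write_to_dataset with partition columns produces a Hive-style layout
+ # (e.g. group1=a/group2=e/<uuid>.parquet); the partition values are encoded
+ # in the directory names rather than stored in the data files.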
+ pq.write_to_dataset(output_table, base_path, partition_by,
+ filesystem=filesystem,
+ use_legacy_dataset=use_legacy_dataset)
+
+ metadata_path = os.path.join(str(base_path), '_common_metadata')
+
+ if filesystem is not None:
+ with filesystem.open(metadata_path, 'wb') as f:
+ pq.write_metadata(output_table.schema, f)
+ else:
+ pq.write_metadata(output_table.schema, metadata_path)
+
+ # ARROW-2891: Ensure the output_schema is preserved when writing a
+ # partitioned dataset
+ dataset = pq.ParquetDataset(base_path,
+ filesystem=filesystem,
+ validate_schema=True,
+ use_legacy_dataset=use_legacy_dataset)
+ # ARROW-2209: Ensure the dataset schema also includes the partition columns
+ if use_legacy_dataset:
+ dataset_cols = set(dataset.schema.to_arrow_schema().names)
+ else:
+ # NB: the schema property is an Arrow schema, not a Parquet schema
+ dataset_cols = set(dataset.schema.names)
+
+ assert dataset_cols == set(output_table.schema.names)
+
+ input_table = dataset.read()
+ input_df = input_table.to_pandas()
+
+ # Read data back in and compare with original DataFrame
+ # Partitioned columns added to the end of the DataFrame when read
+ input_df_cols = input_df.columns.tolist()
+ assert partition_by == input_df_cols[-1 * len(partition_by):]
+
+ input_df = input_df[cols]
+ # Partitioned columns become 'categorical' dtypes
+ for col in partition_by:
+ output_df[col] = output_df[col].astype('category')
+ tm.assert_frame_equal(output_df, input_df)
+
+
+def _test_write_to_dataset_no_partitions(base_path,
+ use_legacy_dataset=True,
+ filesystem=None):
+ import pandas as pd
+
+ import pyarrow.parquet as pq
+
+ # ARROW-1400
+ output_df = pd.DataFrame({'group1': list('aaabbbbccc'),
+ 'group2': list('eefeffgeee'),
+ 'num': list(range(10)),
+ 'date': np.arange('2017-01-01', '2017-01-11',
+ dtype='datetime64[D]')})
+ cols = output_df.columns.tolist()
+ output_table = pa.Table.from_pandas(output_df)
+
+ if filesystem is None:
+ filesystem = LocalFileSystem._get_instance()
+
+ # Without partitions, append files to root_path
+ n = 5
+ for i in range(n):
+ pq.write_to_dataset(output_table, base_path,
+ filesystem=filesystem)
+ output_files = [file for file in filesystem.ls(str(base_path))
+ if file.endswith(".parquet")]
+ assert len(output_files) == n
+
+ # Deduplicated incoming DataFrame should match
+ # original outgoing DataFrame
+ input_table = pq.ParquetDataset(
+ base_path, filesystem=filesystem,
+ use_legacy_dataset=use_legacy_dataset
+ ).read()
+ input_df = input_table.to_pandas()
+ input_df = input_df.drop_duplicates()
+ input_df = input_df[cols]
+ assert output_df.equals(input_df)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_write_to_dataset_with_partitions(tempdir, use_legacy_dataset):
+ _test_write_to_dataset_with_partitions(str(tempdir), use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_write_to_dataset_with_partitions_and_schema(
+ tempdir, use_legacy_dataset
+):
+ schema = pa.schema([pa.field('group1', type=pa.string()),
+ pa.field('group2', type=pa.string()),
+ pa.field('num', type=pa.int64()),
+ pa.field('nan', type=pa.int32()),
+ pa.field('date', type=pa.timestamp(unit='us'))])
+ _test_write_to_dataset_with_partitions(
+ str(tempdir), use_legacy_dataset, schema=schema)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_write_to_dataset_with_partitions_and_index_name(
+ tempdir, use_legacy_dataset
+):
+ _test_write_to_dataset_with_partitions(
+ str(tempdir), use_legacy_dataset, index_name='index_name')
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_write_to_dataset_no_partitions(tempdir, use_legacy_dataset):
+ _test_write_to_dataset_no_partitions(str(tempdir), use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_write_to_dataset_pathlib(tempdir, use_legacy_dataset):
+ _test_write_to_dataset_with_partitions(
+ tempdir / "test1", use_legacy_dataset)
+ _test_write_to_dataset_no_partitions(
+ tempdir / "test2", use_legacy_dataset)
+
+
+@pytest.mark.pandas
+@pytest.mark.s3
+@parametrize_legacy_dataset
+def test_write_to_dataset_pathlib_nonlocal(
+ tempdir, s3_example_s3fs, use_legacy_dataset
+):
+ # pathlib paths are only accepted for local files
+ fs, _ = s3_example_s3fs
+
+ with pytest.raises(TypeError, match="path-like objects are only allowed"):
+ _test_write_to_dataset_with_partitions(
+ tempdir / "test1", use_legacy_dataset, filesystem=fs)
+
+ with pytest.raises(TypeError, match="path-like objects are only allowed"):
+ _test_write_to_dataset_no_partitions(
+ tempdir / "test2", use_legacy_dataset, filesystem=fs)
+
+
+@pytest.mark.pandas
+@pytest.mark.s3
+@parametrize_legacy_dataset
+def test_write_to_dataset_with_partitions_s3fs(
+ s3_example_s3fs, use_legacy_dataset
+):
+ fs, path = s3_example_s3fs
+
+ _test_write_to_dataset_with_partitions(
+ path, use_legacy_dataset, filesystem=fs)
+
+
+@pytest.mark.pandas
+@pytest.mark.s3
+@parametrize_legacy_dataset
+def test_write_to_dataset_no_partitions_s3fs(
+ s3_example_s3fs, use_legacy_dataset
+):
+ fs, path = s3_example_s3fs
+
+ _test_write_to_dataset_no_partitions(
+ path, use_legacy_dataset, filesystem=fs)
+
+
+@pytest.mark.filterwarnings("ignore:'ParquetDataset:DeprecationWarning")
+@pytest.mark.pandas
+@parametrize_legacy_dataset_not_supported
+def test_write_to_dataset_with_partitions_and_custom_filenames(
+ tempdir, use_legacy_dataset
+):
+ output_df = pd.DataFrame({'group1': list('aaabbbbccc'),
+ 'group2': list('eefeffgeee'),
+ 'num': list(range(10)),
+ 'nan': [np.nan] * 10,
+ 'date': np.arange('2017-01-01', '2017-01-11',
+ dtype='datetime64[D]')})
+ partition_by = ['group1', 'group2']
+ output_table = pa.Table.from_pandas(output_df)
+ path = str(tempdir)
+
+ def partition_filename_callback(keys):
+ return "{}-{}.parquet".format(*keys)
+
+ pq.write_to_dataset(output_table, path,
+ partition_by, partition_filename_callback,
+ use_legacy_dataset=use_legacy_dataset)
+
+ dataset = pq.ParquetDataset(path)
+
+ # ARROW-3538: Ensure partition filenames match the given pattern
+ # defined in the local function partition_filename_callback
+ expected_basenames = [
+ 'a-e.parquet', 'a-f.parquet',
+ 'b-e.parquet', 'b-f.parquet',
+ 'b-g.parquet', 'c-e.parquet'
+ ]
+ output_basenames = [os.path.basename(p.path) for p in dataset.pieces]
+
+ assert sorted(expected_basenames) == sorted(output_basenames)
+
+
+@pytest.mark.dataset
+@pytest.mark.pandas
+def test_write_to_dataset_filesystem(tempdir):
+ df = pd.DataFrame({'A': [1, 2, 3]})
+ table = pa.Table.from_pandas(df)
+ path = str(tempdir)
+
+ pq.write_to_dataset(table, path, filesystem=fs.LocalFileSystem())
+ result = pq.read_table(path)
+ assert result.equals(table)
+
+
+# TODO(dataset) support pickling
+def _make_dataset_for_pickling(tempdir, N=100):
+ path = tempdir / 'data.parquet'
+ fs = LocalFileSystem._get_instance()
+
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'values': np.random.randn(N)
+ }, columns=['index', 'values'])
+ table = pa.Table.from_pandas(df)
+
+ num_groups = 3
+ with pq.ParquetWriter(path, table.schema) as writer:
+ for i in range(num_groups):
+ writer.write_table(table)
+
+ reader = pq.ParquetFile(path)
+ assert reader.metadata.num_row_groups == num_groups
+
+ metadata_path = tempdir / '_metadata'
+ with fs.open(metadata_path, 'wb') as f:
+ pq.write_metadata(table.schema, f)
+
+ dataset = pq.ParquetDataset(tempdir, filesystem=fs)
+ assert dataset.metadata_path == str(metadata_path)
+
+ return dataset
+
+
+def _assert_dataset_is_picklable(dataset, pickler):
+ def is_pickleable(obj):
+ return obj == pickler.loads(pickler.dumps(obj))
+
+ assert is_pickleable(dataset)
+ assert is_pickleable(dataset.metadata)
+ assert is_pickleable(dataset.metadata.schema)
+ assert len(dataset.metadata.schema)
+ for column in dataset.metadata.schema:
+ assert is_pickleable(column)
+
+ for piece in dataset._pieces:
+ assert is_pickleable(piece)
+ metadata = piece.get_metadata()
+ assert metadata.num_row_groups
+ for i in range(metadata.num_row_groups):
+ assert is_pickleable(metadata.row_group(i))
+
+
+@pytest.mark.pandas
+def test_builtin_pickle_dataset(tempdir, datadir):
+ import pickle
+ dataset = _make_dataset_for_pickling(tempdir)
+ _assert_dataset_is_picklable(dataset, pickler=pickle)
+
+
+@pytest.mark.pandas
+def test_cloudpickle_dataset(tempdir, datadir):
+ cp = pytest.importorskip('cloudpickle')
+ dataset = _make_dataset_for_pickling(tempdir)
+ _assert_dataset_is_picklable(dataset, pickler=cp)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_partitioned_dataset(tempdir, use_legacy_dataset):
+ # ARROW-3208: Segmentation fault when reading a partitioned Parquet dataset
+ # and writing it back to a single Parquet file
+ path = tempdir / "ARROW-3208"
+ df = pd.DataFrame({
+ 'one': [-1, 10, 2.5, 100, 1000, 1, 29.2],
+ 'two': [-1, 10, 2, 100, 1000, 1, 11],
+ 'three': [0, 0, 0, 0, 0, 0, 0]
+ })
+ table = pa.Table.from_pandas(df)
+ pq.write_to_dataset(table, root_path=str(path),
+ partition_cols=['one', 'two'])
+ table = pq.ParquetDataset(
+ path, use_legacy_dataset=use_legacy_dataset).read()
+ pq.write_table(table, path / "output.parquet")
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_dataset_read_dictionary(tempdir, use_legacy_dataset):
+ path = tempdir / "ARROW-3325-dataset"
+ t1 = pa.table([[util.rands(10) for i in range(5)] * 10], names=['f0'])
+ t2 = pa.table([[util.rands(10) for i in range(5)] * 10], names=['f0'])
+ # TODO pass use_legacy_dataset (need to fix unique names)
+ pq.write_to_dataset(t1, root_path=str(path))
+ pq.write_to_dataset(t2, root_path=str(path))
+
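+ # read_dictionary=['f0'] asks the reader to return column 'f0'
+ # dictionary-encoded instead of as plain strings.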
+ result = pq.ParquetDataset(
+ path, read_dictionary=['f0'],
+ use_legacy_dataset=use_legacy_dataset).read()
+
+ # The order of the chunks is non-deterministic
+ ex_chunks = [t1[0].chunk(0).dictionary_encode(),
+ t2[0].chunk(0).dictionary_encode()]
+
+ assert result[0].num_chunks == 2
+ c0, c1 = result[0].chunk(0), result[0].chunk(1)
+ if c0.equals(ex_chunks[0]):
+ assert c1.equals(ex_chunks[1])
+ else:
+ assert c0.equals(ex_chunks[1])
+ assert c1.equals(ex_chunks[0])
+
+
+@pytest.mark.dataset
+def test_dataset_unsupported_keywords():
+
+ with pytest.raises(ValueError, match="not yet supported with the new"):
+ pq.ParquetDataset("", use_legacy_dataset=False, schema=pa.schema([]))
+
+ with pytest.raises(ValueError, match="not yet supported with the new"):
+ pq.ParquetDataset("", use_legacy_dataset=False, metadata=pa.schema([]))
+
+ with pytest.raises(ValueError, match="not yet supported with the new"):
+ pq.ParquetDataset("", use_legacy_dataset=False, validate_schema=False)
+
+ with pytest.raises(ValueError, match="not yet supported with the new"):
+ pq.ParquetDataset("", use_legacy_dataset=False, split_row_groups=True)
+
+ with pytest.raises(ValueError, match="not yet supported with the new"):
+ pq.ParquetDataset("", use_legacy_dataset=False, metadata_nthreads=4)
+
+ with pytest.raises(ValueError, match="no longer supported"):
+ pq.read_table("", use_legacy_dataset=False, metadata=pa.schema([]))
+
+
+@pytest.mark.dataset
+def test_dataset_partitioning(tempdir):
+ import pyarrow.dataset as ds
+
+ # create small dataset with directory partitioning
+ root_path = tempdir / "test_partitioning"
+ (root_path / "2012" / "10" / "01").mkdir(parents=True)
+
+ table = pa.table({'a': [1, 2, 3]})
+ pq.write_table(
+ table, str(root_path / "2012" / "10" / "01" / "data.parquet"))
+
+ # This works with new dataset API
+
+ # read_table
+ part = ds.partitioning(field_names=["year", "month", "day"])
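+ # With only field_names given, the partition values are parsed from the
+ # bare directory names ("2012/10/01") and exposed as extra columns.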
+ result = pq.read_table(
+ str(root_path), partitioning=part, use_legacy_dataset=False)
+ assert result.column_names == ["a", "year", "month", "day"]
+
+ result = pq.ParquetDataset(
+ str(root_path), partitioning=part, use_legacy_dataset=False).read()
+ assert result.column_names == ["a", "year", "month", "day"]
+
+ # This raises an error for legacy dataset
+ with pytest.raises(ValueError):
+ pq.read_table(
+ str(root_path), partitioning=part, use_legacy_dataset=True)
+
+ with pytest.raises(ValueError):
+ pq.ParquetDataset(
+ str(root_path), partitioning=part, use_legacy_dataset=True)
+
+
+@pytest.mark.dataset
+def test_parquet_dataset_new_filesystem(tempdir):
+ # Ensure we can pass new FileSystem object to ParquetDataset
+ # (use new implementation automatically without specifying
+ # use_legacy_dataset=False)
+ table = pa.table({'a': [1, 2, 3]})
+ pq.write_table(table, tempdir / 'data.parquet')
+ # don't use plain LocalFileSystem (as that gets mapped to the legacy
+ # implementation)
+ filesystem = fs.SubTreeFileSystem(str(tempdir), fs.LocalFileSystem())
+ dataset = pq.ParquetDataset('.', filesystem=filesystem)
+ result = dataset.read()
+ assert result.equals(table)
+
+
+@pytest.mark.filterwarnings("ignore:'ParquetDataset:DeprecationWarning")
+def test_parquet_dataset_partitions_piece_path_with_fsspec(tempdir):
+ # ARROW-10462 ensure that on Windows we properly use posix-style paths
+ # as used by fsspec
+ fsspec = pytest.importorskip("fsspec")
+ filesystem = fsspec.filesystem('file')
+ table = pa.table({'a': [1, 2, 3]})
+ pq.write_table(table, tempdir / 'data.parquet')
+
+ # pass a posix-style path (using "/" also on Windows)
+ path = str(tempdir).replace("\\", "/")
+ dataset = pq.ParquetDataset(path, filesystem=filesystem)
+ # ensure the piece path is also posix-style
+ expected = path + "/data.parquet"
+ assert dataset.pieces[0].path == expected
+
+
+@pytest.mark.dataset
+def test_parquet_dataset_deprecated_properties(tempdir):
+ table = pa.table({'a': [1, 2, 3]})
+ path = tempdir / 'data.parquet'
+ pq.write_table(table, path)
+ dataset = pq.ParquetDataset(path)
+
+ with pytest.warns(DeprecationWarning, match="'ParquetDataset.pieces"):
+ dataset.pieces
+
+ with pytest.warns(DeprecationWarning, match="'ParquetDataset.partitions"):
+ dataset.partitions
+
+ with pytest.warns(DeprecationWarning, match="'ParquetDataset.memory_map"):
+ dataset.memory_map
+
+ with pytest.warns(DeprecationWarning, match="'ParquetDataset.read_dictio"):
+ dataset.read_dictionary
+
+ with pytest.warns(DeprecationWarning, match="'ParquetDataset.buffer_size"):
+ dataset.buffer_size
+
+ with pytest.warns(DeprecationWarning, match="'ParquetDataset.fs"):
+ dataset.fs
+
+ dataset2 = pq.ParquetDataset(path, use_legacy_dataset=False)
+
+ with pytest.warns(DeprecationWarning, match="'ParquetDataset.pieces"):
+ dataset2.pieces
diff --git a/src/arrow/python/pyarrow/tests/parquet/test_datetime.py b/src/arrow/python/pyarrow/tests/parquet/test_datetime.py
new file mode 100644
index 000000000..f39f1c762
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/test_datetime.py
@@ -0,0 +1,440 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import io
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+from pyarrow.tests.parquet.common import (
+ _check_roundtrip, parametrize_legacy_dataset)
+
+try:
+ import pyarrow.parquet as pq
+ from pyarrow.tests.parquet.common import _read_table, _write_table
+except ImportError:
+ pq = None
+
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+
+ from pyarrow.tests.parquet.common import _roundtrip_pandas_dataframe
+except ImportError:
+ pd = tm = None
+
+
+pytestmark = pytest.mark.parquet
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_pandas_parquet_datetime_tz(use_legacy_dataset):
+ s = pd.Series([datetime.datetime(2017, 9, 6)])
+ s = s.dt.tz_localize('utc')
+
+ s.index = s
+
+ # Both a column and an index to hit both use cases
+ df = pd.DataFrame({'tz_aware': s,
+ 'tz_eastern': s.dt.tz_convert('US/Eastern')},
+ index=s)
+
+ f = io.BytesIO()
+
+ arrow_table = pa.Table.from_pandas(df)
+
+ _write_table(arrow_table, f, coerce_timestamps='ms')
+ f.seek(0)
+
+ table_read = pq.read_pandas(f, use_legacy_dataset=use_legacy_dataset)
+
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_datetime_timezone_tzinfo(use_legacy_dataset):
+ value = datetime.datetime(2018, 1, 1, 1, 23, 45,
+ tzinfo=datetime.timezone.utc)
+ df = pd.DataFrame({'foo': [value]})
+
+ _roundtrip_pandas_dataframe(
+ df, write_kwargs={}, use_legacy_dataset=use_legacy_dataset)
+
+
+@pytest.mark.pandas
+def test_coerce_timestamps(tempdir):
+ from collections import OrderedDict
+
+ # ARROW-622
+ arrays = OrderedDict()
+ fields = [pa.field('datetime64',
+ pa.list_(pa.timestamp('ms')))]
+ arrays['datetime64'] = [
+ np.array(['2007-07-13T01:23:34.123456789',
+ None,
+ '2010-08-13T05:46:57.437699912'],
+ dtype='datetime64[ms]'),
+ None,
+ None,
+ np.array(['2007-07-13T02',
+ None,
+ '2010-08-13T05:46:57.437699912'],
+ dtype='datetime64[ms]'),
+ ]
+
+ df = pd.DataFrame(arrays)
+ schema = pa.schema(fields)
+
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df, schema=schema)
+
+ _write_table(arrow_table, filename, version='2.6', coerce_timestamps='us')
+ table_read = _read_table(filename)
+ df_read = table_read.to_pandas()
+
+ df_expected = df.copy()
+ for i, x in enumerate(df_expected['datetime64']):
+ if isinstance(x, np.ndarray):
+ df_expected['datetime64'][i] = x.astype('M8[us]')
+
+ tm.assert_frame_equal(df_expected, df_read)
+
+ with pytest.raises(ValueError):
+ _write_table(arrow_table, filename, version='2.6',
+ coerce_timestamps='unknown')
+
+
+@pytest.mark.pandas
+def test_coerce_timestamps_truncated(tempdir):
+ """
+ ARROW-2555: Test that we can truncate timestamps when coercing if
+ explicitly allowed.
+ """
+ dt_us = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1,
+ second=1, microsecond=1)
+ dt_ms = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1,
+ second=1)
+
+ fields_us = [pa.field('datetime64', pa.timestamp('us'))]
+ arrays_us = {'datetime64': [dt_us, dt_ms]}
+
+ df_us = pd.DataFrame(arrays_us)
+ schema_us = pa.schema(fields_us)
+
+ filename = tempdir / 'pandas_truncated.parquet'
+ table_us = pa.Table.from_pandas(df_us, schema=schema_us)
+
+ _write_table(table_us, filename, version='2.6', coerce_timestamps='ms',
+ allow_truncated_timestamps=True)
+ table_ms = _read_table(filename)
+ df_ms = table_ms.to_pandas()
+
+ arrays_expected = {'datetime64': [dt_ms, dt_ms]}
+ df_expected = pd.DataFrame(arrays_expected)
+ tm.assert_frame_equal(df_expected, df_ms)
+
+
+@pytest.mark.pandas
+def test_date_time_types(tempdir):
+ t1 = pa.date32()
+ data1 = np.array([17259, 17260, 17261], dtype='int32')
+ a1 = pa.array(data1, type=t1)
+
+ t2 = pa.date64()
+ data2 = data1.astype('int64') * 86400000
+ a2 = pa.array(data2, type=t2)
+
+ t3 = pa.timestamp('us')
+ start = pd.Timestamp('2001-01-01').value / 1000
+ data3 = np.array([start, start + 1, start + 2], dtype='int64')
+ a3 = pa.array(data3, type=t3)
+
+ t4 = pa.time32('ms')
+ data4 = np.arange(3, dtype='i4')
+ a4 = pa.array(data4, type=t4)
+
+ t5 = pa.time64('us')
+ a5 = pa.array(data4.astype('int64'), type=t5)
+
+ t6 = pa.time32('s')
+ a6 = pa.array(data4, type=t6)
+
+ ex_t6 = pa.time32('ms')
+ ex_a6 = pa.array(data4 * 1000, type=ex_t6)
+
+ t7 = pa.timestamp('ns')
+ start = pd.Timestamp('2001-01-01').value
+ data7 = np.array([start, start + 1000, start + 2000],
+ dtype='int64')
+ a7 = pa.array(data7, type=t7)
+
+ table = pa.Table.from_arrays([a1, a2, a3, a4, a5, a6, a7],
+ ['date32', 'date64', 'timestamp[us]',
+ 'time32[s]', 'time64[us]',
+ 'time32_from64[s]',
+ 'timestamp[ns]'])
+
+ # date64 is read back as date32,
+ # and time32[s] is promoted to time32[ms]
+ expected = pa.Table.from_arrays([a1, a1, a3, a4, a5, ex_a6, a7],
+ ['date32', 'date64', 'timestamp[us]',
+ 'time32[s]', 'time64[us]',
+ 'time32_from64[s]',
+ 'timestamp[ns]'])
+
+ _check_roundtrip(table, expected=expected, version='2.6')
+
+ t0 = pa.timestamp('ms')
+ data0 = np.arange(4, dtype='int64')
+ a0 = pa.array(data0, type=t0)
+
+ t1 = pa.timestamp('us')
+ data1 = np.arange(4, dtype='int64')
+ a1 = pa.array(data1, type=t1)
+
+ t2 = pa.timestamp('ns')
+ data2 = np.arange(4, dtype='int64')
+ a2 = pa.array(data2, type=t2)
+
+ table = pa.Table.from_arrays([a0, a1, a2],
+ ['ts[ms]', 'ts[us]', 'ts[ns]'])
+ expected = pa.Table.from_arrays([a0, a1, a2],
+ ['ts[ms]', 'ts[us]', 'ts[ns]'])
+
+ # int64 for all timestamps supported by default
+ filename = tempdir / 'int64_timestamps.parquet'
+ _write_table(table, filename, version='2.6')
+ parquet_schema = pq.ParquetFile(filename).schema
+ for i in range(3):
+ assert parquet_schema.column(i).physical_type == 'INT64'
+ read_table = _read_table(filename)
+ assert read_table.equals(expected)
+
+ t0_ns = pa.timestamp('ns')
+ data0_ns = np.array(data0 * 1000000, dtype='int64')
+ a0_ns = pa.array(data0_ns, type=t0_ns)
+
+ t1_ns = pa.timestamp('ns')
+ data1_ns = np.array(data1 * 1000, dtype='int64')
+ a1_ns = pa.array(data1_ns, type=t1_ns)
+
+ expected = pa.Table.from_arrays([a0_ns, a1_ns, a2],
+ ['ts[ms]', 'ts[us]', 'ts[ns]'])
+
+ # int96 nanosecond timestamps produced upon request
+ filename = tempdir / 'explicit_int96_timestamps.parquet'
+ _write_table(table, filename, version='2.6',
+ use_deprecated_int96_timestamps=True)
+ parquet_schema = pq.ParquetFile(filename).schema
+ for i in range(3):
+ assert parquet_schema.column(i).physical_type == 'INT96'
+ read_table = _read_table(filename)
+ assert read_table.equals(expected)
+
+ # int96 nanosecond timestamps implied by flavor 'spark'
+ filename = tempdir / 'spark_int96_timestamps.parquet'
+ _write_table(table, filename, version='2.6',
+ flavor='spark')
+ parquet_schema = pq.ParquetFile(filename).schema
+ for i in range(3):
+ assert parquet_schema.column(i).physical_type == 'INT96'
+ read_table = _read_table(filename)
+ assert read_table.equals(expected)
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns'])
+def test_coerce_int96_timestamp_unit(unit):
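+ # All four columns encode the same instants at different resolutions; when
+ # the INT96 data is read back with coerce_int96_timestamp_unit=<unit>, each
+ # column is expected to equal the column originally built at that unit.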
+ i_s = pd.Timestamp('2010-01-01').value / 1000000000 # := 1262304000
+
+ d_s = np.arange(i_s, i_s + 10, 1, dtype='int64')
+ d_ms = d_s * 1000
+ d_us = d_ms * 1000
+ d_ns = d_us * 1000
+
+ a_s = pa.array(d_s, type=pa.timestamp('s'))
+ a_ms = pa.array(d_ms, type=pa.timestamp('ms'))
+ a_us = pa.array(d_us, type=pa.timestamp('us'))
+ a_ns = pa.array(d_ns, type=pa.timestamp('ns'))
+
+ arrays = {"s": a_s, "ms": a_ms, "us": a_us, "ns": a_ns}
+ names = ['ts_s', 'ts_ms', 'ts_us', 'ts_ns']
+ table = pa.Table.from_arrays([a_s, a_ms, a_us, a_ns], names)
+
+ # For either Parquet version, coercing to nanoseconds is allowed
+ # if Int96 storage is used
+ expected = pa.Table.from_arrays([arrays.get(unit)]*4, names)
+ read_table_kwargs = {"coerce_int96_timestamp_unit": unit}
+ _check_roundtrip(table, expected,
+ read_table_kwargs=read_table_kwargs,
+ use_deprecated_int96_timestamps=True)
+ _check_roundtrip(table, expected, version='2.6',
+ read_table_kwargs=read_table_kwargs,
+ use_deprecated_int96_timestamps=True)
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize('pq_reader_method', ['ParquetFile', 'read_table'])
+def test_coerce_int96_timestamp_overflow(pq_reader_method, tempdir):
+
+ def get_table(pq_reader_method, filename, **kwargs):
+ if pq_reader_method == "ParquetFile":
+ return pq.ParquetFile(filename, **kwargs).read()
+ elif pq_reader_method == "read_table":
+ return pq.read_table(filename, **kwargs)
+
+ # Recreating the initial JIRA issue referenced in ARROW-12096
+ oob_dts = [
+ datetime.datetime(1000, 1, 1),
+ datetime.datetime(2000, 1, 1),
+ datetime.datetime(3000, 1, 1)
+ ]
+ df = pd.DataFrame({"a": oob_dts})
+ table = pa.table(df)
+
+ filename = tempdir / "test_round_trip_overflow.parquet"
+ pq.write_table(table, filename, use_deprecated_int96_timestamps=True,
+ version="1.0")
+
+ # With the default resolution of ns, we get wrong values for INT96
+ # timestamps that are out of bounds for the nanosecond range
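+ # (datetime64[ns] only spans roughly the years 1677-2262, so the year-1000
+ # and year-3000 values overflow and read back incorrectly)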
+ tab_error = get_table(pq_reader_method, filename)
+ assert tab_error["a"].to_pylist() != oob_dts
+
+ # avoid this overflow by specifying the resolution to use for INT96 values
+ tab_correct = get_table(
+ pq_reader_method, filename, coerce_int96_timestamp_unit="s"
+ )
+ df_correct = tab_correct.to_pandas(timestamp_as_object=True)
+ tm.assert_frame_equal(df, df_correct)
+
+
+def test_timestamp_restore_timezone():
+ # ARROW-5888, restore timezone from serialized metadata
+ ty = pa.timestamp('ms', tz='America/New_York')
+ arr = pa.array([1, 2, 3], type=ty)
+ t = pa.table([arr], names=['f0'])
+ _check_roundtrip(t)
+
+
+def test_timestamp_restore_timezone_nanosecond():
+ # ARROW-9634, also restore timezone for nanosecond data that get stored
+ # as microseconds in the parquet file
+ ty = pa.timestamp('ns', tz='America/New_York')
+ arr = pa.array([1000, 2000, 3000], type=ty)
+ table = pa.table([arr], names=['f0'])
+ ty_us = pa.timestamp('us', tz='America/New_York')
+ expected = pa.table([arr.cast(ty_us)], names=['f0'])
+ _check_roundtrip(table, expected=expected)
+
+
+@pytest.mark.pandas
+def test_list_of_datetime_time_roundtrip():
+ # ARROW-4135
+ times = pd.to_datetime(['09:00', '09:30', '10:00', '10:30', '11:00',
+ '11:30', '12:00'])
+ df = pd.DataFrame({'time': [times.time]})
+ _roundtrip_pandas_dataframe(df, write_kwargs={})
+
+
+@pytest.mark.pandas
+def test_parquet_version_timestamp_differences():
+ i_s = pd.Timestamp('2010-01-01').value / 1000000000 # := 1262304000
+
+ d_s = np.arange(i_s, i_s + 10, 1, dtype='int64')
+ d_ms = d_s * 1000
+ d_us = d_ms * 1000
+ d_ns = d_us * 1000
+
+ a_s = pa.array(d_s, type=pa.timestamp('s'))
+ a_ms = pa.array(d_ms, type=pa.timestamp('ms'))
+ a_us = pa.array(d_us, type=pa.timestamp('us'))
+ a_ns = pa.array(d_ns, type=pa.timestamp('ns'))
+
+ names = ['ts:s', 'ts:ms', 'ts:us', 'ts:ns']
+ table = pa.Table.from_arrays([a_s, a_ms, a_us, a_ns], names)
+
+ # Using Parquet version 1.0, seconds should be coerced to milliseconds
+ # and nanoseconds should be coerced to microseconds by default
+ expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_us], names)
+ _check_roundtrip(table, expected)
+
+ # Using Parquet version 2.6, seconds should be coerced to milliseconds
+ # and nanoseconds should be retained by default
+ expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_ns], names)
+ _check_roundtrip(table, expected, version='2.6')
+
+ # Using Parquet version 1.0, coercing to milliseconds or microseconds
+ # is allowed
+ expected = pa.Table.from_arrays([a_ms, a_ms, a_ms, a_ms], names)
+ _check_roundtrip(table, expected, coerce_timestamps='ms')
+
+ # Using Parquet version 2.6, coercing to milliseconds or microseconds
+ # is allowed
+ expected = pa.Table.from_arrays([a_us, a_us, a_us, a_us], names)
+ _check_roundtrip(table, expected, version='2.6', coerce_timestamps='us')
+
+ # TODO: after pyarrow allows coerce_timestamps='ns', tests like the
+ # following should pass ...
+
+ # Using Parquet version 1.0, coercing to nanoseconds is not allowed
+ # expected = None
+ # with pytest.raises(NotImplementedError):
+ # _roundtrip_table(table, coerce_timestamps='ns')
+
+ # Using Parquet version 2.6, coercing to nanoseconds is allowed
+ # expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names)
+ # _check_roundtrip(table, expected, version='2.6', coerce_timestamps='ns')
+
+ # For either Parquet version, coercing to nanoseconds is allowed
+ # if Int96 storage is used
+ expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names)
+ _check_roundtrip(table, expected,
+ use_deprecated_int96_timestamps=True)
+ _check_roundtrip(table, expected, version='2.6',
+ use_deprecated_int96_timestamps=True)
+
+
+@pytest.mark.pandas
+def test_noncoerced_nanoseconds_written_without_exception(tempdir):
+ # ARROW-1957: the Parquet version 2.6 writer preserves Arrow
+ # nanosecond timestamps by default
+ n = 9
+ df = pd.DataFrame({'x': range(n)},
+ index=pd.date_range('2017-01-01', freq='1n', periods=n))
+ tb = pa.Table.from_pandas(df)
+
+ filename = tempdir / 'written.parquet'
+ try:
+ pq.write_table(tb, filename, version='2.6')
+ except Exception:
+ pass
+ assert filename.exists()
+
+ recovered_table = pq.read_table(filename)
+ assert tb.equals(recovered_table)
+
+ # Loss of data through coercion (without explicit override) still an error
+ filename = tempdir / 'not_written.parquet'
+ with pytest.raises(ValueError):
+ pq.write_table(tb, filename, coerce_timestamps='ms', version='2.6')
diff --git a/src/arrow/python/pyarrow/tests/parquet/test_metadata.py b/src/arrow/python/pyarrow/tests/parquet/test_metadata.py
new file mode 100644
index 000000000..9fa2f6394
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/test_metadata.py
@@ -0,0 +1,524 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+from collections import OrderedDict
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+from pyarrow.tests.parquet.common import _check_roundtrip, make_sample_file
+
+try:
+ import pyarrow.parquet as pq
+ from pyarrow.tests.parquet.common import _write_table
+except ImportError:
+ pq = None
+
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+
+ from pyarrow.tests.parquet.common import alltypes_sample
+except ImportError:
+ pd = tm = None
+
+
+pytestmark = pytest.mark.parquet
+
+
+@pytest.mark.pandas
+def test_parquet_metadata_api():
+ df = alltypes_sample(size=10000)
+ df = df.reindex(columns=sorted(df.columns))
+ df.index = np.random.randint(0, 1000000, size=len(df))
+
+ fileh = make_sample_file(df)
+ ncols = len(df.columns)
+
+ # Series of sniff tests
+ meta = fileh.metadata
+ repr(meta)
+ assert meta.num_rows == len(df)
+ assert meta.num_columns == ncols + 1 # +1 for index
+ assert meta.num_row_groups == 1
+ assert meta.format_version == '2.6'
+ assert 'parquet-cpp' in meta.created_by
+ assert isinstance(meta.serialized_size, int)
+ assert isinstance(meta.metadata, dict)
+
+ # Schema
+ schema = fileh.schema
+ assert meta.schema is schema
+ assert len(schema) == ncols + 1 # +1 for index
+ repr(schema)
+
+ col = schema[0]
+ repr(col)
+ assert col.name == df.columns[0]
+ assert col.max_definition_level == 1
+ assert col.max_repetition_level == 0
+
+ assert col.physical_type == 'BOOLEAN'
+ assert col.converted_type == 'NONE'
+
+ with pytest.raises(IndexError):
+ schema[ncols + 1] # +1 for index
+
+ with pytest.raises(IndexError):
+ schema[-1]
+
+ # Row group
+ for rg in range(meta.num_row_groups):
+ rg_meta = meta.row_group(rg)
+ assert isinstance(rg_meta, pq.RowGroupMetaData)
+ repr(rg_meta)
+
+ for col in range(rg_meta.num_columns):
+ col_meta = rg_meta.column(col)
+ assert isinstance(col_meta, pq.ColumnChunkMetaData)
+ repr(col_meta)
+
+ with pytest.raises(IndexError):
+ meta.row_group(-1)
+
+ with pytest.raises(IndexError):
+ meta.row_group(meta.num_row_groups + 1)
+
+ rg_meta = meta.row_group(0)
+ assert rg_meta.num_rows == len(df)
+ assert rg_meta.num_columns == ncols + 1 # +1 for index
+ assert rg_meta.total_byte_size > 0
+
+ with pytest.raises(IndexError):
+ col_meta = rg_meta.column(-1)
+
+ with pytest.raises(IndexError):
+ col_meta = rg_meta.column(ncols + 2)
+
+ col_meta = rg_meta.column(0)
+ assert col_meta.file_offset > 0
+ assert col_meta.file_path == '' # created from BytesIO
+ assert col_meta.physical_type == 'BOOLEAN'
+ assert col_meta.num_values == 10000
+ assert col_meta.path_in_schema == 'bool'
+ assert col_meta.is_stats_set is True
+ assert isinstance(col_meta.statistics, pq.Statistics)
+ assert col_meta.compression == 'SNAPPY'
+ assert col_meta.encodings == ('PLAIN', 'RLE')
+ assert col_meta.has_dictionary_page is False
+ assert col_meta.dictionary_page_offset is None
+ assert col_meta.data_page_offset > 0
+ assert col_meta.total_compressed_size > 0
+ assert col_meta.total_uncompressed_size > 0
+ with pytest.raises(NotImplementedError):
+ col_meta.has_index_page
+ with pytest.raises(NotImplementedError):
+ col_meta.index_page_offset
+
+
+def test_parquet_metadata_lifetime(tempdir):
+ # ARROW-6642 - ensure that chained access keeps parent objects alive
+ table = pa.table({'a': [1, 2, 3]})
+ pq.write_table(table, tempdir / 'test_metadata_segfault.parquet')
+ parquet_file = pq.ParquetFile(tempdir / 'test_metadata_segfault.parquet')
+ parquet_file.metadata.row_group(0).column(0).statistics
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize(
+ (
+ 'data',
+ 'type',
+ 'physical_type',
+ 'min_value',
+ 'max_value',
+ 'null_count',
+ 'num_values',
+ 'distinct_count'
+ ),
+ [
+ ([1, 2, 2, None, 4], pa.uint8(), 'INT32', 1, 4, 1, 4, 0),
+ ([1, 2, 2, None, 4], pa.uint16(), 'INT32', 1, 4, 1, 4, 0),
+ ([1, 2, 2, None, 4], pa.uint32(), 'INT32', 1, 4, 1, 4, 0),
+ ([1, 2, 2, None, 4], pa.uint64(), 'INT64', 1, 4, 1, 4, 0),
+ ([-1, 2, 2, None, 4], pa.int8(), 'INT32', -1, 4, 1, 4, 0),
+ ([-1, 2, 2, None, 4], pa.int16(), 'INT32', -1, 4, 1, 4, 0),
+ ([-1, 2, 2, None, 4], pa.int32(), 'INT32', -1, 4, 1, 4, 0),
+ ([-1, 2, 2, None, 4], pa.int64(), 'INT64', -1, 4, 1, 4, 0),
+ (
+ [-1.1, 2.2, 2.3, None, 4.4], pa.float32(),
+ 'FLOAT', -1.1, 4.4, 1, 4, 0
+ ),
+ (
+ [-1.1, 2.2, 2.3, None, 4.4], pa.float64(),
+ 'DOUBLE', -1.1, 4.4, 1, 4, 0
+ ),
+ (
+ ['', 'b', chr(1000), None, 'aaa'], pa.binary(),
+ 'BYTE_ARRAY', b'', chr(1000).encode('utf-8'), 1, 4, 0
+ ),
+ (
+ [True, False, False, True, True], pa.bool_(),
+ 'BOOLEAN', False, True, 0, 5, 0
+ ),
+ (
+ [b'\x00', b'b', b'12', None, b'aaa'], pa.binary(),
+ 'BYTE_ARRAY', b'\x00', b'b', 1, 4, 0
+ ),
+ ]
+)
+def test_parquet_column_statistics_api(data, type, physical_type, min_value,
+ max_value, null_count, num_values,
+ distinct_count):
+ df = pd.DataFrame({'data': data})
+ schema = pa.schema([pa.field('data', type)])
+ table = pa.Table.from_pandas(df, schema=schema, safe=False)
+ fileh = make_sample_file(table)
+
+ meta = fileh.metadata
+
+ rg_meta = meta.row_group(0)
+ col_meta = rg_meta.column(0)
+
+ stat = col_meta.statistics
+ assert stat.has_min_max
+ assert _close(type, stat.min, min_value)
+ assert _close(type, stat.max, max_value)
+ assert stat.null_count == null_count
+ assert stat.num_values == num_values
+ # TODO(kszucs): until the parquet-cpp API exposes the HasDistinctCount
+ # method, a missing distinct_count is represented as zero instead of None
+ assert stat.distinct_count == distinct_count
+ assert stat.physical_type == physical_type
+
+
+def _close(type, left, right):
+ if type == pa.float32():
+ return abs(left - right) < 1E-7
+ elif type == pa.float64():
+ return abs(left - right) < 1E-13
+ else:
+ return left == right
+
+
+# ARROW-6339
+@pytest.mark.pandas
+def test_parquet_raise_on_unset_statistics():
+ df = pd.DataFrame({"t": pd.Series([pd.NaT], dtype="datetime64[ns]")})
+ meta = make_sample_file(pa.Table.from_pandas(df)).metadata
+
+ assert not meta.row_group(0).column(0).statistics.has_min_max
+ assert meta.row_group(0).column(0).statistics.max is None
+
+
+def test_statistics_convert_logical_types(tempdir):
+ # ARROW-5166, ARROW-4139
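+ # Statistics for columns with logical types (unsigned ints, strings, time,
+ # timestamp) should be exposed as values of the logical type rather than
+ # as raw physical-type values.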
+
+ # (min, max, type)
+ cases = [(10, 11164359321221007157, pa.uint64()),
+ (10, 4294967295, pa.uint32()),
+ ("ähnlich", "öffentlich", pa.utf8()),
+ (datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000),
+ pa.time32('ms')),
+ (datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000),
+ pa.time64('us')),
+ (datetime.datetime(2019, 6, 24, 0, 0, 0, 1000),
+ datetime.datetime(2019, 6, 25, 0, 0, 0, 1000),
+ pa.timestamp('ms')),
+ (datetime.datetime(2019, 6, 24, 0, 0, 0, 1000),
+ datetime.datetime(2019, 6, 25, 0, 0, 0, 1000),
+ pa.timestamp('us'))]
+
+ for i, (min_val, max_val, typ) in enumerate(cases):
+ t = pa.Table.from_arrays([pa.array([min_val, max_val], type=typ)],
+ ['col'])
+ path = str(tempdir / ('example{}.parquet'.format(i)))
+ pq.write_table(t, path, version='2.6')
+ pf = pq.ParquetFile(path)
+ stats = pf.metadata.row_group(0).column(0).statistics
+ assert stats.min == min_val
+ assert stats.max == max_val
+
+
+def test_parquet_write_disable_statistics(tempdir):
+ table = pa.Table.from_pydict(
+ OrderedDict([
+ ('a', pa.array([1, 2, 3])),
+ ('b', pa.array(['a', 'b', 'c']))
+ ])
+ )
+ _write_table(table, tempdir / 'data.parquet')
+ meta = pq.read_metadata(tempdir / 'data.parquet')
+ for col in [0, 1]:
+ cc = meta.row_group(0).column(col)
+ assert cc.is_stats_set is True
+ assert cc.statistics is not None
+
+ _write_table(table, tempdir / 'data2.parquet', write_statistics=False)
+ meta = pq.read_metadata(tempdir / 'data2.parquet')
+ for col in [0, 1]:
+ cc = meta.row_group(0).column(col)
+ assert cc.is_stats_set is False
+ assert cc.statistics is None
+
+ _write_table(table, tempdir / 'data3.parquet', write_statistics=['a'])
+ meta = pq.read_metadata(tempdir / 'data3.parquet')
+ cc_a = meta.row_group(0).column(0)
+ cc_b = meta.row_group(0).column(1)
+ assert cc_a.is_stats_set is True
+ assert cc_b.is_stats_set is False
+ assert cc_a.statistics is not None
+ assert cc_b.statistics is None
+
+
+def test_field_id_metadata():
+ # ARROW-7080
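+ # Field metadata under the b'PARQUET:field_id' key is carried through the
+ # Parquet schema's field_id and surfaces again in the Arrow schema read
+ # back from the file.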
+ field_id = b'PARQUET:field_id'
+ inner = pa.field('inner', pa.int32(), metadata={field_id: b'100'})
+ middle = pa.field('middle', pa.struct(
+ [inner]), metadata={field_id: b'101'})
+ fields = [
+ pa.field('basic', pa.int32(), metadata={
+ b'other': b'abc', field_id: b'1'}),
+ pa.field(
+ 'list',
+ pa.list_(pa.field('list-inner', pa.int32(),
+ metadata={field_id: b'10'})),
+ metadata={field_id: b'11'}),
+ pa.field('struct', pa.struct([middle]), metadata={field_id: b'102'}),
+ pa.field('no-metadata', pa.int32()),
+ pa.field('non-integral-field-id', pa.int32(),
+ metadata={field_id: b'xyz'}),
+ pa.field('negative-field-id', pa.int32(),
+ metadata={field_id: b'-1000'})
+ ]
+ arrs = [[] for _ in fields]
+ table = pa.table(arrs, schema=pa.schema(fields))
+
+ bio = pa.BufferOutputStream()
+ pq.write_table(table, bio)
+ contents = bio.getvalue()
+
+ pf = pq.ParquetFile(pa.BufferReader(contents))
+ schema = pf.schema_arrow
+
+ assert schema[0].metadata[field_id] == b'1'
+ assert schema[0].metadata[b'other'] == b'abc'
+
+ list_field = schema[1]
+ assert list_field.metadata[field_id] == b'11'
+
+ list_item_field = list_field.type.value_field
+ assert list_item_field.metadata[field_id] == b'10'
+
+ struct_field = schema[2]
+ assert struct_field.metadata[field_id] == b'102'
+
+ struct_middle_field = struct_field.type[0]
+ assert struct_middle_field.metadata[field_id] == b'101'
+
+ struct_inner_field = struct_middle_field.type[0]
+ assert struct_inner_field.metadata[field_id] == b'100'
+
+ assert schema[3].metadata is None
+ # Invalid field_id values are passed through unchanged (ok), but they do
+ # not end up as a field_id in the Parquet schema (not tested here)
+ assert schema[4].metadata[field_id] == b'xyz'
+ assert schema[5].metadata[field_id] == b'-1000'
+
+
+@pytest.mark.pandas
+def test_multi_dataset_metadata(tempdir):
+ filenames = ["ARROW-1983-dataset.0", "ARROW-1983-dataset.1"]
+ metapath = str(tempdir / "_metadata")
+
+ # create a test dataset
+ df = pd.DataFrame({
+ 'one': [1, 2, 3],
+ 'two': [-1, -2, -3],
+ 'three': [[1, 2], [2, 3], [3, 4]],
+ })
+ table = pa.Table.from_pandas(df)
+
+ # write dataset twice and collect/merge metadata
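+ # (metadata_collector receives one FileMetaData object per written file;
+ # set_file_path records the data file's location relative to the _metadata
+ # file so readers can find its row groups)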
+ _meta = None
+ for filename in filenames:
+ meta = []
+ pq.write_table(table, str(tempdir / filename),
+ metadata_collector=meta)
+ meta[0].set_file_path(filename)
+ if _meta is None:
+ _meta = meta[0]
+ else:
+ _meta.append_row_groups(meta[0])
+
+ # Write merged metadata-only file
+ with open(metapath, "wb") as f:
+ _meta.write_metadata_file(f)
+
+ # Read back the metadata
+ meta = pq.read_metadata(metapath)
+ md = meta.to_dict()
+ _md = _meta.to_dict()
+ for key in _md:
+ if key != 'serialized_size':
+ assert _md[key] == md[key]
+ assert _md['num_columns'] == 3
+ assert _md['num_rows'] == 6
+ assert _md['num_row_groups'] == 2
+ assert _md['serialized_size'] == 0
+ assert md['serialized_size'] > 0
+
+
+def test_write_metadata(tempdir):
+ path = str(tempdir / "metadata")
+ schema = pa.schema([("a", "int64"), ("b", "float64")])
+
+ # write a pyarrow schema
+ pq.write_metadata(schema, path)
+ parquet_meta = pq.read_metadata(path)
+ schema_as_arrow = parquet_meta.schema.to_arrow_schema()
+ assert schema_as_arrow.equals(schema)
+
+ # ARROW-8980: Check that the ARROW:schema metadata key was removed
+ if schema_as_arrow.metadata:
+ assert b'ARROW:schema' not in schema_as_arrow.metadata
+
+ # pass through writer keyword arguments
+ for version in ["1.0", "2.0", "2.4", "2.6"]:
+ pq.write_metadata(schema, path, version=version)
+ parquet_meta = pq.read_metadata(path)
+ # The version is stored as a single integer in the Parquet metadata,
+ # so it cannot correctly express dotted format versions
+ expected_version = "1.0" if version == "1.0" else "2.6"
+ assert parquet_meta.format_version == expected_version
+
+ # metadata_collector: list of FileMetaData objects
+ table = pa.table({'a': [1, 2], 'b': [.1, .2]}, schema=schema)
+ pq.write_table(table, tempdir / "data.parquet")
+ parquet_meta = pq.read_metadata(str(tempdir / "data.parquet"))
+ pq.write_metadata(
+ schema, path, metadata_collector=[parquet_meta, parquet_meta]
+ )
+ parquet_meta_mult = pq.read_metadata(path)
+ assert parquet_meta_mult.num_row_groups == 2
+
+ # append metadata with different schema raises an error
+ with pytest.raises(RuntimeError, match="requires equal schemas"):
+ pq.write_metadata(
+ pa.schema([("a", "int32"), ("b", "null")]),
+ path, metadata_collector=[parquet_meta, parquet_meta]
+ )
+
+
+def test_table_large_metadata():
+ # ARROW-8694
+ my_schema = pa.schema([pa.field('f0', 'double')],
+ metadata={'large': 'x' * 10000000})
+
+ table = pa.table([np.arange(10)], schema=my_schema)
+ _check_roundtrip(table)
+
+
+@pytest.mark.pandas
+def test_compare_schemas():
+ df = alltypes_sample(size=10000)
+
+ fileh = make_sample_file(df)
+ fileh2 = make_sample_file(df)
+ fileh3 = make_sample_file(df[df.columns[::2]])
+
+ # ParquetSchema
+ assert isinstance(fileh.schema, pq.ParquetSchema)
+ assert fileh.schema.equals(fileh.schema)
+ assert fileh.schema == fileh.schema
+ assert fileh.schema.equals(fileh2.schema)
+ assert fileh.schema == fileh2.schema
+ assert fileh.schema != 'arbitrary object'
+ assert not fileh.schema.equals(fileh3.schema)
+ assert fileh.schema != fileh3.schema
+
+ # ColumnSchema
+ assert isinstance(fileh.schema[0], pq.ColumnSchema)
+ assert fileh.schema[0].equals(fileh.schema[0])
+ assert fileh.schema[0] == fileh.schema[0]
+ assert not fileh.schema[0].equals(fileh.schema[1])
+ assert fileh.schema[0] != fileh.schema[1]
+ assert fileh.schema[0] != 'arbitrary object'
+
+
+@pytest.mark.pandas
+def test_read_schema(tempdir):
+ N = 100
+ df = pd.DataFrame({
+ 'index': np.arange(N),
+ 'values': np.random.randn(N)
+ }, columns=['index', 'values'])
+
+ data_path = tempdir / 'test.parquet'
+
+ table = pa.Table.from_pandas(df)
+ _write_table(table, data_path)
+
+ read1 = pq.read_schema(data_path)
+ read2 = pq.read_schema(data_path, memory_map=True)
+ assert table.schema.equals(read1)
+ assert table.schema.equals(read2)
+
+ assert table.schema.metadata[b'pandas'] == read1.metadata[b'pandas']
+
+
+def test_parquet_metadata_empty_to_dict(tempdir):
+ # https://issues.apache.org/jira/browse/ARROW-10146
+ table = pa.table({"a": pa.array([], type="int64")})
+ pq.write_table(table, tempdir / "data.parquet")
+ metadata = pq.read_metadata(tempdir / "data.parquet")
+ # ensure this doesn't error / statistics set to None
+ metadata_dict = metadata.to_dict()
+ assert len(metadata_dict["row_groups"]) == 1
+ assert len(metadata_dict["row_groups"][0]["columns"]) == 1
+ assert metadata_dict["row_groups"][0]["columns"][0]["statistics"] is None
+
+
+@pytest.mark.slow
+@pytest.mark.large_memory
+def test_metadata_exceeds_message_size():
+ # ARROW-13655: Thrift may enable a default message size that limits
+ # the size of Parquet metadata that can be written.
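+ # Appending the same row groups thousands of times inflates the footer far
+ # beyond a typical default Thrift message-size limit; reading the merged
+ # metadata back should still succeed.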
+ NCOLS = 1000
+ NREPEATS = 4000
+
+ table = pa.table({str(i): np.random.randn(10) for i in range(NCOLS)})
+
+ with pa.BufferOutputStream() as out:
+ pq.write_table(table, out)
+ buf = out.getvalue()
+
+ original_metadata = pq.read_metadata(pa.BufferReader(buf))
+ metadata = pq.read_metadata(pa.BufferReader(buf))
+ for i in range(NREPEATS):
+ metadata.append_row_groups(original_metadata)
+
+ with pa.BufferOutputStream() as out:
+ metadata.write_metadata_file(out)
+ buf = out.getvalue()
+
+ metadata = pq.read_metadata(pa.BufferReader(buf))
diff --git a/src/arrow/python/pyarrow/tests/parquet/test_pandas.py b/src/arrow/python/pyarrow/tests/parquet/test_pandas.py
new file mode 100644
index 000000000..5fc1b7060
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/test_pandas.py
@@ -0,0 +1,687 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import io
+import json
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+from pyarrow.fs import LocalFileSystem, SubTreeFileSystem
+from pyarrow.tests.parquet.common import (
+ parametrize_legacy_dataset, parametrize_legacy_dataset_not_supported)
+from pyarrow.util import guid
+from pyarrow.vendored.version import Version
+
+try:
+ import pyarrow.parquet as pq
+ from pyarrow.tests.parquet.common import (_read_table, _test_dataframe,
+ _write_table)
+except ImportError:
+ pq = None
+
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+
+ from pyarrow.tests.parquet.common import (_roundtrip_pandas_dataframe,
+ alltypes_sample)
+except ImportError:
+ pd = tm = None
+
+
+pytestmark = pytest.mark.parquet
+
+
+@pytest.mark.pandas
+def test_pandas_parquet_custom_metadata(tempdir):
+ df = alltypes_sample(size=10000)
+
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df)
+ assert b'pandas' in arrow_table.schema.metadata
+
+ _write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms')
+
+ metadata = pq.read_metadata(filename).metadata
+ assert b'pandas' in metadata
+
+ js = json.loads(metadata[b'pandas'].decode('utf8'))
+ assert js['index_columns'] == [{'kind': 'range',
+ 'name': None,
+ 'start': 0, 'stop': 10000,
+ 'step': 1}]
+
+
+@pytest.mark.pandas
+def test_merging_parquet_tables_with_different_pandas_metadata(tempdir):
+ # ARROW-3728: Merging Parquet Files - Pandas Meta in Schema Mismatch
+ schema = pa.schema([
+ pa.field('int', pa.int16()),
+ pa.field('float', pa.float32()),
+ pa.field('string', pa.string())
+ ])
+ df1 = pd.DataFrame({
+ 'int': np.arange(3, dtype=np.uint8),
+ 'float': np.arange(3, dtype=np.float32),
+ 'string': ['ABBA', 'EDDA', 'ACDC']
+ })
+ df2 = pd.DataFrame({
+ 'int': [4, 5],
+ 'float': [1.1, None],
+ 'string': [None, None]
+ })
+ table1 = pa.Table.from_pandas(df1, schema=schema, preserve_index=False)
+ table2 = pa.Table.from_pandas(df2, schema=schema, preserve_index=False)
+
+ assert not table1.schema.equals(table2.schema, check_metadata=True)
+ assert table1.schema.equals(table2.schema)
+
+ writer = pq.ParquetWriter(tempdir / 'merged.parquet', schema=schema)
+ writer.write_table(table1)
+ writer.write_table(table2)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_pandas_parquet_column_multiindex(tempdir, use_legacy_dataset):
+ df = alltypes_sample(size=10)
+ df.columns = pd.MultiIndex.from_tuples(
+ list(zip(df.columns, df.columns[::-1])),
+ names=['level_1', 'level_2']
+ )
+
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df)
+ assert arrow_table.schema.pandas_metadata is not None
+
+ _write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms')
+
+ table_read = pq.read_pandas(
+ filename, use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_pandas_parquet_2_0_roundtrip_read_pandas_no_index_written(
+ tempdir, use_legacy_dataset
+):
+ df = alltypes_sample(size=10000)
+
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+ js = arrow_table.schema.pandas_metadata
+ assert not js['index_columns']
+ # ARROW-2170
+ # While index_columns should be empty, columns still needs to be filled.
+ assert js['columns']
+
+ _write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms')
+ table_read = pq.read_pandas(
+ filename, use_legacy_dataset=use_legacy_dataset)
+
+ js = table_read.schema.pandas_metadata
+ assert not js['index_columns']
+
+ read_metadata = table_read.schema.metadata
+ assert arrow_table.schema.metadata == read_metadata
+
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+# TODO(dataset) duplicate column selection actually gives duplicate columns now
+@pytest.mark.pandas
+@parametrize_legacy_dataset_not_supported
+def test_pandas_column_selection(tempdir, use_legacy_dataset):
+ size = 10000
+ np.random.seed(0)
+ df = pd.DataFrame({
+ 'uint8': np.arange(size, dtype=np.uint8),
+ 'uint16': np.arange(size, dtype=np.uint16)
+ })
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df)
+ _write_table(arrow_table, filename)
+ table_read = _read_table(
+ filename, columns=['uint8'], use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+
+ tm.assert_frame_equal(df[['uint8']], df_read)
+
+ # ARROW-4267: Selecting the same column twice should still result in the
+ # column being read only once.
+ table_read = _read_table(
+ filename, columns=['uint8', 'uint8'],
+ use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+
+ tm.assert_frame_equal(df[['uint8']], df_read)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_pandas_parquet_native_file_roundtrip(tempdir, use_legacy_dataset):
+ df = _test_dataframe(10000)
+ arrow_table = pa.Table.from_pandas(df)
+ imos = pa.BufferOutputStream()
+ _write_table(arrow_table, imos, version='2.6')
+ buf = imos.getvalue()
+ reader = pa.BufferReader(buf)
+ df_read = _read_table(
+ reader, use_legacy_dataset=use_legacy_dataset).to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_read_pandas_column_subset(tempdir, use_legacy_dataset):
+ df = _test_dataframe(10000)
+ arrow_table = pa.Table.from_pandas(df)
+ imos = pa.BufferOutputStream()
+ _write_table(arrow_table, imos, version='2.6')
+ buf = imos.getvalue()
+ reader = pa.BufferReader(buf)
+ df_read = pq.read_pandas(
+ reader, columns=['strings', 'uint8'],
+ use_legacy_dataset=use_legacy_dataset
+ ).to_pandas()
+ tm.assert_frame_equal(df[['strings', 'uint8']], df_read)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_pandas_parquet_empty_roundtrip(tempdir, use_legacy_dataset):
+ df = _test_dataframe(0)
+ arrow_table = pa.Table.from_pandas(df)
+ imos = pa.BufferOutputStream()
+ _write_table(arrow_table, imos, version='2.6')
+ buf = imos.getvalue()
+ reader = pa.BufferReader(buf)
+ df_read = _read_table(
+ reader, use_legacy_dataset=use_legacy_dataset).to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+@pytest.mark.pandas
+def test_pandas_can_write_nested_data(tempdir):
+ data = {
+ "agg_col": [
+ {"page_type": 1},
+ {"record_type": 1},
+ {"non_consecutive_home": 0},
+ ],
+ "uid_first": "1001"
+ }
+ df = pd.DataFrame(data=data)
+ arrow_table = pa.Table.from_pandas(df)
+ imos = pa.BufferOutputStream()
+ # This succeeds under V2
+ _write_table(arrow_table, imos)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_pandas_parquet_pyfile_roundtrip(tempdir, use_legacy_dataset):
+ filename = tempdir / 'pandas_pyfile_roundtrip.parquet'
+ size = 5
+ df = pd.DataFrame({
+ 'int64': np.arange(size, dtype=np.int64),
+ 'float32': np.arange(size, dtype=np.float32),
+ 'float64': np.arange(size, dtype=np.float64),
+ 'bool': np.random.randn(size) > 0,
+ 'strings': ['foo', 'bar', None, 'baz', 'qux']
+ })
+
+ arrow_table = pa.Table.from_pandas(df)
+
+ with filename.open('wb') as f:
+ _write_table(arrow_table, f, version="1.0")
+
+ data = io.BytesIO(filename.read_bytes())
+
+ table_read = _read_table(data, use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_pandas_parquet_configuration_options(tempdir, use_legacy_dataset):
+ size = 10000
+ np.random.seed(0)
+ df = pd.DataFrame({
+ 'uint8': np.arange(size, dtype=np.uint8),
+ 'uint16': np.arange(size, dtype=np.uint16),
+ 'uint32': np.arange(size, dtype=np.uint32),
+ 'uint64': np.arange(size, dtype=np.uint64),
+ 'int8': np.arange(size, dtype=np.int16),
+ 'int16': np.arange(size, dtype=np.int16),
+ 'int32': np.arange(size, dtype=np.int32),
+ 'int64': np.arange(size, dtype=np.int64),
+ 'float32': np.arange(size, dtype=np.float32),
+ 'float64': np.arange(size, dtype=np.float64),
+ 'bool': np.random.randn(size) > 0
+ })
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df)
+
+ for use_dictionary in [True, False]:
+ _write_table(arrow_table, filename, version='2.6',
+ use_dictionary=use_dictionary)
+ table_read = _read_table(
+ filename, use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+ for write_statistics in [True, False]:
+ _write_table(arrow_table, filename, version='2.6',
+ write_statistics=write_statistics)
+ table_read = _read_table(filename,
+ use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+ for compression in ['NONE', 'SNAPPY', 'GZIP', 'LZ4', 'ZSTD']:
+ if (compression != 'NONE' and
+ not pa.lib.Codec.is_available(compression)):
+ continue
+ _write_table(arrow_table, filename, version='2.6',
+ compression=compression)
+ table_read = _read_table(
+ filename, use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df, df_read)
+
+
+@pytest.mark.pandas
+def test_spark_flavor_preserves_pandas_metadata():
+ df = _test_dataframe(size=100)
+ df.index = np.arange(0, 10 * len(df), 10)
+ df.index.name = 'foo'
+
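+    # The 'spark' flavor sanitizes field names and changes timestamp storage,
+    # but the pandas metadata (including the named index) should survive.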
+ result = _roundtrip_pandas_dataframe(df, {'version': '2.0',
+ 'flavor': 'spark'})
+ tm.assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_index_column_name_duplicate(tempdir, use_legacy_dataset):
+ data = {
+ 'close': {
+ pd.Timestamp('2017-06-30 01:31:00'): 154.99958999999998,
+ pd.Timestamp('2017-06-30 01:32:00'): 154.99958999999998,
+ },
+ 'time': {
+ pd.Timestamp('2017-06-30 01:31:00'): pd.Timestamp(
+ '2017-06-30 01:31:00'
+ ),
+ pd.Timestamp('2017-06-30 01:32:00'): pd.Timestamp(
+ '2017-06-30 01:32:00'
+ ),
+ }
+ }
+ path = str(tempdir / 'data.parquet')
+ dfx = pd.DataFrame(data).set_index('time', drop=False)
+ tdfx = pa.Table.from_pandas(dfx)
+ _write_table(tdfx, path)
+ arrow_table = _read_table(path, use_legacy_dataset=use_legacy_dataset)
+ result_df = arrow_table.to_pandas()
+ tm.assert_frame_equal(result_df, dfx)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_multiindex_duplicate_values(tempdir, use_legacy_dataset):
+ num_rows = 3
+ numbers = list(range(num_rows))
+ index = pd.MultiIndex.from_arrays(
+ [['foo', 'foo', 'bar'], numbers],
+ names=['foobar', 'some_numbers'],
+ )
+
+ df = pd.DataFrame({'numbers': numbers}, index=index)
+ table = pa.Table.from_pandas(df)
+
+ filename = tempdir / 'dup_multi_index_levels.parquet'
+
+ _write_table(table, filename)
+ result_table = _read_table(filename, use_legacy_dataset=use_legacy_dataset)
+ assert table.equals(result_table)
+
+ result_df = result_table.to_pandas()
+ tm.assert_frame_equal(result_df, df)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_backwards_compatible_index_naming(datadir, use_legacy_dataset):
+ expected_string = b"""\
+carat cut color clarity depth table price x y z
+ 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43
+ 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31
+ 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31
+ 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63
+ 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75
+ 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48
+ 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47
+ 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53
+ 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49
+ 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39"""
+ expected = pd.read_csv(io.BytesIO(expected_string), sep=r'\s{2,}',
+ index_col=None, header=0, engine='python')
+ table = _read_table(
+ datadir / 'v0.7.1.parquet', use_legacy_dataset=use_legacy_dataset)
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_backwards_compatible_index_multi_level_named(
+ datadir, use_legacy_dataset
+):
+ expected_string = b"""\
+carat cut color clarity depth table price x y z
+ 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43
+ 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31
+ 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31
+ 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63
+ 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75
+ 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48
+ 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47
+ 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53
+ 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49
+ 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39"""
+ expected = pd.read_csv(
+ io.BytesIO(expected_string), sep=r'\s{2,}',
+ index_col=['cut', 'color', 'clarity'],
+ header=0, engine='python'
+ ).sort_index()
+
+ table = _read_table(datadir / 'v0.7.1.all-named-index.parquet',
+ use_legacy_dataset=use_legacy_dataset)
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_backwards_compatible_index_multi_level_some_named(
+ datadir, use_legacy_dataset
+):
+ expected_string = b"""\
+carat cut color clarity depth table price x y z
+ 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43
+ 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31
+ 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31
+ 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63
+ 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75
+ 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48
+ 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47
+ 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53
+ 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49
+ 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39"""
+ expected = pd.read_csv(
+ io.BytesIO(expected_string),
+ sep=r'\s{2,}', index_col=['cut', 'color', 'clarity'],
+ header=0, engine='python'
+ ).sort_index()
+ expected.index = expected.index.set_names(['cut', None, 'clarity'])
+
+ table = _read_table(datadir / 'v0.7.1.some-named-index.parquet',
+ use_legacy_dataset=use_legacy_dataset)
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_backwards_compatible_column_metadata_handling(
+ datadir, use_legacy_dataset
+):
+ expected = pd.DataFrame(
+ {'a': [1, 2, 3], 'b': [.1, .2, .3],
+ 'c': pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')})
+ expected.index = pd.MultiIndex.from_arrays(
+ [['a', 'b', 'c'],
+ pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')],
+ names=['index', None])
+
+ path = datadir / 'v0.7.1.column-metadata-handling.parquet'
+ table = _read_table(path, use_legacy_dataset=use_legacy_dataset)
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, expected)
+
+ table = _read_table(
+ path, columns=['a'], use_legacy_dataset=use_legacy_dataset)
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, expected[['a']].reset_index(drop=True))
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_categorical_index_survives_roundtrip(use_legacy_dataset):
+ # ARROW-3652, addressed by ARROW-3246
+ df = pd.DataFrame([['a', 'b'], ['c', 'd']], columns=['c1', 'c2'])
+ df['c1'] = df['c1'].astype('category')
+ df = df.set_index(['c1'])
+
+ table = pa.Table.from_pandas(df)
+ bos = pa.BufferOutputStream()
+ pq.write_table(table, bos)
+ ref_df = pq.read_pandas(
+ bos.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas()
+ assert isinstance(ref_df.index, pd.CategoricalIndex)
+ assert ref_df.index.equals(df.index)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_categorical_order_survives_roundtrip(use_legacy_dataset):
+ # ARROW-6302
+ df = pd.DataFrame({"a": pd.Categorical(
+ ["a", "b", "c", "a"], categories=["b", "c", "d"], ordered=True)})
+
+ table = pa.Table.from_pandas(df)
+ bos = pa.BufferOutputStream()
+ pq.write_table(table, bos)
+
+ contents = bos.getvalue()
+ result = pq.read_pandas(
+ contents, use_legacy_dataset=use_legacy_dataset).to_pandas()
+
+ tm.assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_pandas_categorical_na_type_row_groups(use_legacy_dataset):
+ # ARROW-5085
+ df = pd.DataFrame({"col": [None] * 100, "int": [1.0] * 100})
+ df_category = df.astype({"col": "category", "int": "category"})
+ table = pa.Table.from_pandas(df)
+ table_cat = pa.Table.from_pandas(df_category)
+ buf = pa.BufferOutputStream()
+
+    # The write itself should succeed despite the all-null categorical column
+ pq.write_table(table_cat, buf, version='2.6', chunk_size=10)
+ result = pq.read_table(
+ buf.getvalue(), use_legacy_dataset=use_legacy_dataset)
+
+ # Result is non-categorical
+ assert result[0].equals(table[0])
+ assert result[1].equals(table[1])
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_pandas_categorical_roundtrip(use_legacy_dataset):
+ # ARROW-5480, this was enabled by ARROW-3246
+
+ # Have one of the categories unobserved and include a null (-1)
+ codes = np.array([2, 0, 0, 2, 0, -1, 2], dtype='int32')
+ categories = ['foo', 'bar', 'baz']
+ df = pd.DataFrame({'x': pd.Categorical.from_codes(
+ codes, categories=categories)})
+
+ buf = pa.BufferOutputStream()
+ pq.write_table(pa.table(df), buf)
+
+ result = pq.read_table(
+ buf.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas()
+ assert result.x.dtype == 'category'
+ assert (result.x.cat.categories == categories).all()
+ tm.assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_write_to_dataset_pandas_preserve_extensiondtypes(
+ tempdir, use_legacy_dataset
+):
+ # ARROW-8251 - preserve pandas extension dtypes in roundtrip
+ if Version(pd.__version__) < Version("1.0.0"):
+ pytest.skip("__arrow_array__ added to pandas in 1.0.0")
+
+ df = pd.DataFrame({'part': 'a', "col": [1, 2, 3]})
+ df['col'] = df['col'].astype("Int64")
+ table = pa.table(df)
+
+ pq.write_to_dataset(
+ table, str(tempdir / "case1"), partition_cols=['part'],
+ use_legacy_dataset=use_legacy_dataset
+ )
+ result = pq.read_table(
+ str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset
+ ).to_pandas()
+ tm.assert_frame_equal(result[["col"]], df[["col"]])
+
+ pq.write_to_dataset(
+ table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset
+ )
+ result = pq.read_table(
+ str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset
+ ).to_pandas()
+ tm.assert_frame_equal(result[["col"]], df[["col"]])
+
+ pq.write_table(table, str(tempdir / "data.parquet"))
+ result = pq.read_table(
+ str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset
+ ).to_pandas()
+ tm.assert_frame_equal(result[["col"]], df[["col"]])
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_write_to_dataset_pandas_preserve_index(tempdir, use_legacy_dataset):
+ # ARROW-8251 - preserve pandas index in roundtrip
+
+ df = pd.DataFrame({'part': ['a', 'a', 'b'], "col": [1, 2, 3]})
+ df.index = pd.Index(['a', 'b', 'c'], name="idx")
+ table = pa.table(df)
+ df_cat = df[["col", "part"]].copy()
+ df_cat["part"] = df_cat["part"].astype("category")
+
+ pq.write_to_dataset(
+ table, str(tempdir / "case1"), partition_cols=['part'],
+ use_legacy_dataset=use_legacy_dataset
+ )
+ result = pq.read_table(
+ str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset
+ ).to_pandas()
+ tm.assert_frame_equal(result, df_cat)
+
+ pq.write_to_dataset(
+ table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset
+ )
+ result = pq.read_table(
+ str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset
+ ).to_pandas()
+ tm.assert_frame_equal(result, df)
+
+ pq.write_table(table, str(tempdir / "data.parquet"))
+ result = pq.read_table(
+ str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset
+ ).to_pandas()
+ tm.assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize('preserve_index', [True, False, None])
+def test_dataset_read_pandas_common_metadata(tempdir, preserve_index):
+ # ARROW-1103
+ nfiles = 5
+ size = 5
+
+ dirpath = tempdir / guid()
+ dirpath.mkdir()
+
+ test_data = []
+ frames = []
+ paths = []
+ for i in range(nfiles):
+ df = _test_dataframe(size, seed=i)
+ df.index = pd.Index(np.arange(i * size, (i + 1) * size), name='index')
+
+ path = dirpath / '{}.parquet'.format(i)
+
+ table = pa.Table.from_pandas(df, preserve_index=preserve_index)
+
+ # Obliterate metadata
+ table = table.replace_schema_metadata(None)
+ assert table.schema.metadata is None
+
+ _write_table(table, path)
+ test_data.append(table)
+ frames.append(df)
+ paths.append(path)
+
+ # Write _metadata common file
+ table_for_metadata = pa.Table.from_pandas(
+ df, preserve_index=preserve_index
+ )
+ pq.write_metadata(table_for_metadata.schema, dirpath / '_metadata')
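+    # The per-file pandas metadata was stripped above, so ParquetDataset has
+    # to pick up the schema (and index information) from the common
+    # '_metadata' sidecar file written here.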
+
+ dataset = pq.ParquetDataset(dirpath)
+ columns = ['uint8', 'strings']
+ result = dataset.read_pandas(columns=columns).to_pandas()
+ expected = pd.concat([x[columns] for x in frames])
+ expected.index.name = (
+ df.index.name if preserve_index is not False else None)
+ tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.pandas
+def test_read_pandas_passthrough_keywords(tempdir):
+ # ARROW-11464 - previously not all keywords were passed through (such as
+ # the filesystem keyword)
+ df = pd.DataFrame({'a': [1, 2, 3]})
+
+ filename = tempdir / 'data.parquet'
+ _write_table(df, filename)
+
+ result = pq.read_pandas(
+ 'data.parquet',
+ filesystem=SubTreeFileSystem(str(tempdir), LocalFileSystem())
+ )
+ assert result.equals(pa.table(df))
diff --git a/src/arrow/python/pyarrow/tests/parquet/test_parquet_file.py b/src/arrow/python/pyarrow/tests/parquet/test_parquet_file.py
new file mode 100644
index 000000000..3a4d4aaa2
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/test_parquet_file.py
@@ -0,0 +1,276 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import io
+import os
+
+import pytest
+
+import pyarrow as pa
+
+try:
+ import pyarrow.parquet as pq
+ from pyarrow.tests.parquet.common import _write_table
+except ImportError:
+ pq = None
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+
+ from pyarrow.tests.parquet.common import alltypes_sample
+except ImportError:
+ pd = tm = None
+
+pytestmark = pytest.mark.parquet
+
+
+@pytest.mark.pandas
+def test_pass_separate_metadata():
+ # ARROW-471
+ df = alltypes_sample(size=10000)
+
+ a_table = pa.Table.from_pandas(df)
+
+ buf = io.BytesIO()
+ _write_table(a_table, buf, compression='snappy', version='2.6')
+
+ buf.seek(0)
+ metadata = pq.read_metadata(buf)
+
+ buf.seek(0)
+
+ fileh = pq.ParquetFile(buf, metadata=metadata)
+
+ tm.assert_frame_equal(df, fileh.read().to_pandas())
+
+
+@pytest.mark.pandas
+def test_read_single_row_group():
+ # ARROW-471
+ N, K = 10000, 4
+ df = alltypes_sample(size=N)
+
+ a_table = pa.Table.from_pandas(df)
+
+ buf = io.BytesIO()
+ _write_table(a_table, buf, row_group_size=N / K,
+ compression='snappy', version='2.6')
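+    # Writing with row_group_size=N / K splits the N rows into K row groups,
+    # which the assertions below rely on.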
+
+ buf.seek(0)
+
+ pf = pq.ParquetFile(buf)
+
+ assert pf.num_row_groups == K
+
+ row_groups = [pf.read_row_group(i) for i in range(K)]
+ result = pa.concat_tables(row_groups)
+ tm.assert_frame_equal(df, result.to_pandas())
+
+
+@pytest.mark.pandas
+def test_read_single_row_group_with_column_subset():
+ N, K = 10000, 4
+ df = alltypes_sample(size=N)
+ a_table = pa.Table.from_pandas(df)
+
+ buf = io.BytesIO()
+ _write_table(a_table, buf, row_group_size=N / K,
+ compression='snappy', version='2.6')
+
+ buf.seek(0)
+ pf = pq.ParquetFile(buf)
+
+ cols = list(df.columns[:2])
+ row_groups = [pf.read_row_group(i, columns=cols) for i in range(K)]
+ result = pa.concat_tables(row_groups)
+ tm.assert_frame_equal(df[cols], result.to_pandas())
+
+    # ARROW-4267: selecting duplicate columns must still read each column
+    # only once.
+ row_groups = [pf.read_row_group(i, columns=cols + cols) for i in range(K)]
+ result = pa.concat_tables(row_groups)
+ tm.assert_frame_equal(df[cols], result.to_pandas())
+
+
+@pytest.mark.pandas
+def test_read_multiple_row_groups():
+ N, K = 10000, 4
+ df = alltypes_sample(size=N)
+
+ a_table = pa.Table.from_pandas(df)
+
+ buf = io.BytesIO()
+ _write_table(a_table, buf, row_group_size=N / K,
+ compression='snappy', version='2.6')
+
+ buf.seek(0)
+
+ pf = pq.ParquetFile(buf)
+
+ assert pf.num_row_groups == K
+
+ result = pf.read_row_groups(range(K))
+ tm.assert_frame_equal(df, result.to_pandas())
+
+
+@pytest.mark.pandas
+def test_read_multiple_row_groups_with_column_subset():
+ N, K = 10000, 4
+ df = alltypes_sample(size=N)
+ a_table = pa.Table.from_pandas(df)
+
+ buf = io.BytesIO()
+ _write_table(a_table, buf, row_group_size=N / K,
+ compression='snappy', version='2.6')
+
+ buf.seek(0)
+ pf = pq.ParquetFile(buf)
+
+ cols = list(df.columns[:2])
+ result = pf.read_row_groups(range(K), columns=cols)
+ tm.assert_frame_equal(df[cols], result.to_pandas())
+
+    # ARROW-4267: selecting duplicate columns must still read each column
+    # only once.
+ result = pf.read_row_groups(range(K), columns=cols + cols)
+ tm.assert_frame_equal(df[cols], result.to_pandas())
+
+
+@pytest.mark.pandas
+def test_scan_contents():
+ N, K = 10000, 4
+ df = alltypes_sample(size=N)
+ a_table = pa.Table.from_pandas(df)
+
+ buf = io.BytesIO()
+ _write_table(a_table, buf, row_group_size=N / K,
+ compression='snappy', version='2.6')
+
+ buf.seek(0)
+ pf = pq.ParquetFile(buf)
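+    # scan_contents() scans the file contents and returns the number of rows
+    # scanned, optionally restricted to a subset of columns.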
+
+ assert pf.scan_contents() == 10000
+ assert pf.scan_contents(df.columns[:4]) == 10000
+
+
+def test_parquet_file_pass_directory_instead_of_file(tempdir):
+ # ARROW-7208
+ path = tempdir / 'directory'
+ os.mkdir(str(path))
+
+ with pytest.raises(IOError, match="Expected file path"):
+ pq.ParquetFile(path)
+
+
+def test_read_column_invalid_index():
+ table = pa.table([pa.array([4, 5]), pa.array(["foo", "bar"])],
+ names=['ints', 'strs'])
+ bio = pa.BufferOutputStream()
+ pq.write_table(table, bio)
+ f = pq.ParquetFile(bio.getvalue())
+ assert f.reader.read_column(0).to_pylist() == [4, 5]
+ assert f.reader.read_column(1).to_pylist() == ["foo", "bar"]
+ for index in (-1, 2):
+ with pytest.raises((ValueError, IndexError)):
+ f.reader.read_column(index)
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize('batch_size', [300, 1000, 1300])
+def test_iter_batches_columns_reader(tempdir, batch_size):
+ total_size = 3000
+ chunk_size = 1000
+ # TODO: Add categorical support
+ df = alltypes_sample(size=total_size)
+
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df)
+ _write_table(arrow_table, filename, version='2.6',
+ coerce_timestamps='ms', chunk_size=chunk_size)
+
+ file_ = pq.ParquetFile(filename)
+ for columns in [df.columns[:10], df.columns[10:]]:
+ batches = file_.iter_batches(batch_size=batch_size, columns=columns)
+ batch_starts = range(0, total_size+batch_size, batch_size)
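+        # Each batch starts at a multiple of batch_size; the final batch may
+        # be shorter when total_size is not an exact multiple of batch_size.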
+ for batch, start in zip(batches, batch_starts):
+ end = min(total_size, start + batch_size)
+ tm.assert_frame_equal(
+ batch.to_pandas(),
+ df.iloc[start:end, :].loc[:, columns].reset_index(drop=True)
+ )
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize('chunk_size', [1000])
+def test_iter_batches_reader(tempdir, chunk_size):
+ df = alltypes_sample(size=10000, categorical=True)
+
+ filename = tempdir / 'pandas_roundtrip.parquet'
+ arrow_table = pa.Table.from_pandas(df)
+ assert arrow_table.schema.pandas_metadata is not None
+
+ _write_table(arrow_table, filename, version='2.6',
+ coerce_timestamps='ms', chunk_size=chunk_size)
+
+ file_ = pq.ParquetFile(filename)
+
+ def get_all_batches(f):
+ for row_group in range(f.num_row_groups):
+ batches = f.iter_batches(
+ batch_size=900,
+ row_groups=[row_group],
+ )
+
+ for batch in batches:
+ yield batch
+
+ batches = list(get_all_batches(file_))
+ batch_no = 0
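+    # With chunk_size=1000 and batch_size=900, each row group should yield
+    # two batches: the first 900 rows, then the remaining 100 rows.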
+
+ for i in range(file_.num_row_groups):
+ tm.assert_frame_equal(
+ batches[batch_no].to_pandas(),
+ file_.read_row_groups([i]).to_pandas().head(900)
+ )
+
+ batch_no += 1
+
+ tm.assert_frame_equal(
+ batches[batch_no].to_pandas().reset_index(drop=True),
+ file_.read_row_groups([i]).to_pandas().iloc[900:].reset_index(
+ drop=True
+ )
+ )
+
+ batch_no += 1
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize('pre_buffer', [False, True])
+def test_pre_buffer(pre_buffer):
+ N, K = 10000, 4
+ df = alltypes_sample(size=N)
+ a_table = pa.Table.from_pandas(df)
+
+ buf = io.BytesIO()
+ _write_table(a_table, buf, row_group_size=N / K,
+ compression='snappy', version='2.6')
+
+ buf.seek(0)
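+    # pre_buffer only controls whether column data is coalesced and prefetched
+    # up front; the rows read back should be identical either way.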
+ pf = pq.ParquetFile(buf, pre_buffer=pre_buffer)
+ assert pf.read().num_rows == N
diff --git a/src/arrow/python/pyarrow/tests/parquet/test_parquet_writer.py b/src/arrow/python/pyarrow/tests/parquet/test_parquet_writer.py
new file mode 100644
index 000000000..9be7634c8
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/parquet/test_parquet_writer.py
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import pyarrow as pa
+from pyarrow import fs
+from pyarrow.filesystem import FileSystem, LocalFileSystem
+from pyarrow.tests.parquet.common import parametrize_legacy_dataset
+
+try:
+ import pyarrow.parquet as pq
+ from pyarrow.tests.parquet.common import _read_table, _test_dataframe
+except ImportError:
+ pq = None
+
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+
+except ImportError:
+ pd = tm = None
+
+pytestmark = pytest.mark.parquet
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_parquet_incremental_file_build(tempdir, use_legacy_dataset):
+ df = _test_dataframe(100)
+ df['unique_id'] = 0
+
+ arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+ out = pa.BufferOutputStream()
+
+ writer = pq.ParquetWriter(out, arrow_table.schema, version='2.6')
+
+ frames = []
+ for i in range(10):
+ df['unique_id'] = i
+ arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+ writer.write_table(arrow_table)
+
+ frames.append(df.copy())
+
+ writer.close()
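+    # close() finalizes the file by writing the Parquet footer; the buffer is
+    # only a complete Parquet file after this point.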
+
+ buf = out.getvalue()
+ result = _read_table(
+ pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
+
+ expected = pd.concat(frames, ignore_index=True)
+ tm.assert_frame_equal(result.to_pandas(), expected)
+
+
+def test_validate_schema_write_table(tempdir):
+ # ARROW-2926
+ simple_fields = [
+ pa.field('POS', pa.uint32()),
+ pa.field('desc', pa.string())
+ ]
+
+ simple_schema = pa.schema(simple_fields)
+
+ # simple_table schema does not match simple_schema
+ simple_from_array = [pa.array([1]), pa.array(['bla'])]
+ simple_table = pa.Table.from_arrays(simple_from_array, ['POS', 'desc'])
+
+ path = tempdir / 'simple_validate_schema.parquet'
+
+ with pq.ParquetWriter(path, simple_schema,
+ version='2.6',
+ compression='snappy', flavor='spark') as w:
+ with pytest.raises(ValueError):
+ w.write_table(simple_table)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_parquet_writer_context_obj(tempdir, use_legacy_dataset):
+ df = _test_dataframe(100)
+ df['unique_id'] = 0
+
+ arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+ out = pa.BufferOutputStream()
+
+ with pq.ParquetWriter(out, arrow_table.schema, version='2.6') as writer:
+
+ frames = []
+ for i in range(10):
+ df['unique_id'] = i
+ arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+ writer.write_table(arrow_table)
+
+ frames.append(df.copy())
+
+ buf = out.getvalue()
+ result = _read_table(
+ pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
+
+ expected = pd.concat(frames, ignore_index=True)
+ tm.assert_frame_equal(result.to_pandas(), expected)
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_parquet_writer_context_obj_with_exception(
+ tempdir, use_legacy_dataset
+):
+ df = _test_dataframe(100)
+ df['unique_id'] = 0
+
+ arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+ out = pa.BufferOutputStream()
+ error_text = 'Artificial Error'
+
+ try:
+ with pq.ParquetWriter(out,
+ arrow_table.schema,
+ version='2.6') as writer:
+
+ frames = []
+ for i in range(10):
+ df['unique_id'] = i
+ arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+ writer.write_table(arrow_table)
+ frames.append(df.copy())
+ if i == 5:
+ raise ValueError(error_text)
+ except Exception as e:
+ assert str(e) == error_text
+
+ buf = out.getvalue()
+ result = _read_table(
+ pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
+
+ expected = pd.concat(frames, ignore_index=True)
+ tm.assert_frame_equal(result.to_pandas(), expected)
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize("filesystem", [
+ None,
+ LocalFileSystem._get_instance(),
+ fs.LocalFileSystem(),
+])
+def test_parquet_writer_filesystem_local(tempdir, filesystem):
+ df = _test_dataframe(100)
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ path = str(tempdir / 'data.parquet')
+
+ with pq.ParquetWriter(
+ path, table.schema, filesystem=filesystem, version='2.6'
+ ) as writer:
+ writer.write_table(table)
+
+ result = _read_table(path).to_pandas()
+ tm.assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@pytest.mark.s3
+def test_parquet_writer_filesystem_s3(s3_example_fs):
+ df = _test_dataframe(100)
+ table = pa.Table.from_pandas(df, preserve_index=False)
+
+ fs, uri, path = s3_example_fs
+
+ with pq.ParquetWriter(
+ path, table.schema, filesystem=fs, version='2.6'
+ ) as writer:
+ writer.write_table(table)
+
+ result = _read_table(uri).to_pandas()
+ tm.assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@pytest.mark.s3
+def test_parquet_writer_filesystem_s3_uri(s3_example_fs):
+ df = _test_dataframe(100)
+ table = pa.Table.from_pandas(df, preserve_index=False)
+
+ fs, uri, path = s3_example_fs
+
+ with pq.ParquetWriter(uri, table.schema, version='2.6') as writer:
+ writer.write_table(table)
+
+ result = _read_table(path, filesystem=fs).to_pandas()
+ tm.assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@pytest.mark.s3
+def test_parquet_writer_filesystem_s3fs(s3_example_s3fs):
+ df = _test_dataframe(100)
+ table = pa.Table.from_pandas(df, preserve_index=False)
+
+ fs, directory = s3_example_s3fs
+ path = directory + "/test.parquet"
+
+ with pq.ParquetWriter(
+ path, table.schema, filesystem=fs, version='2.6'
+ ) as writer:
+ writer.write_table(table)
+
+ result = _read_table(path, filesystem=fs).to_pandas()
+ tm.assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+def test_parquet_writer_filesystem_buffer_raises():
+ df = _test_dataframe(100)
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ filesystem = fs.LocalFileSystem()
+
+    # Should raise ValueError when a filesystem is passed together with a
+    # file-like object
+ with pytest.raises(ValueError, match="specified path is file-like"):
+ pq.ParquetWriter(
+ pa.BufferOutputStream(), table.schema, filesystem=filesystem
+ )
+
+
+@pytest.mark.pandas
+@parametrize_legacy_dataset
+def test_parquet_writer_with_caller_provided_filesystem(use_legacy_dataset):
+ out = pa.BufferOutputStream()
+
+ class CustomFS(FileSystem):
+ def __init__(self):
+ self.path = None
+ self.mode = None
+
+ def open(self, path, mode='rb'):
+ self.path = path
+ self.mode = mode
+ return out
+
+ fs = CustomFS()
+ fname = 'expected_fname.parquet'
+ df = _test_dataframe(100)
+ table = pa.Table.from_pandas(df, preserve_index=False)
+
+ with pq.ParquetWriter(fname, table.schema, filesystem=fs, version='2.6') \
+ as writer:
+ writer.write_table(table)
+
+ assert fs.path == fname
+ assert fs.mode == 'wb'
+ assert out.closed
+
+ buf = out.getvalue()
+ table_read = _read_table(
+ pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset)
+ df_read = table_read.to_pandas()
+ tm.assert_frame_equal(df_read, df)
+
+    # Should raise ValueError when a filesystem is passed together with a
+    # file-like object.
+    expected_msg = ("filesystem passed but where is file-like, so"
+                    " there is nothing to open with filesystem.")
+    with pytest.raises(ValueError) as err_info:
+        pq.ParquetWriter(pa.BufferOutputStream(), table.schema, filesystem=fs)
+    assert str(err_info.value) == expected_msg
diff --git a/src/arrow/python/pyarrow/tests/pyarrow_cython_example.pyx b/src/arrow/python/pyarrow/tests/pyarrow_cython_example.pyx
new file mode 100644
index 000000000..08f5e17a9
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/pyarrow_cython_example.pyx
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# distutils: language=c++
+# cython: language_level = 3
+
+from pyarrow.lib cimport *
+
+
+def get_array_length(obj):
+ # An example function accessing both the pyarrow Cython API
+ # and the Arrow C++ API
+ cdef shared_ptr[CArray] arr = pyarrow_unwrap_array(obj)
+ if arr.get() == NULL:
+ raise TypeError("not an array")
+ return arr.get().length()
+
+
+def make_null_array(length):
+ # An example function that returns a PyArrow object without PyArrow
+ # being imported explicitly at the Python level.
+ cdef shared_ptr[CArray] null_array
+ null_array.reset(new CNullArray(length))
+ return pyarrow_wrap_array(null_array)
+
+
+def cast_scalar(scalar, to_type):
+ cdef:
+ shared_ptr[CScalar] c_scalar
+ shared_ptr[CDataType] c_type
+ CResult[shared_ptr[CScalar]] c_result
+
+ c_scalar = pyarrow_unwrap_scalar(scalar)
+ if c_scalar.get() == NULL:
+ raise TypeError("not a scalar")
+ c_type = pyarrow_unwrap_data_type(to_type)
+ if c_type.get() == NULL:
+ raise TypeError("not a type")
+ c_result = c_scalar.get().CastTo(c_type)
+ c_scalar = GetResultValue(c_result)
+ return pyarrow_wrap_scalar(c_scalar)
diff --git a/src/arrow/python/pyarrow/tests/strategies.py b/src/arrow/python/pyarrow/tests/strategies.py
new file mode 100644
index 000000000..d314785ff
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/strategies.py
@@ -0,0 +1,419 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+
+import pytz
+import hypothesis as h
+import hypothesis.strategies as st
+import hypothesis.extra.numpy as npst
+import hypothesis.extra.pytz as tzst
+import numpy as np
+
+import pyarrow as pa
+
+
+# TODO(kszucs): alphanum_text, surrogate_text
+custom_text = st.text(
+ alphabet=st.characters(
+ min_codepoint=0x41,
+ max_codepoint=0x7E
+ )
+)
+
+null_type = st.just(pa.null())
+bool_type = st.just(pa.bool_())
+
+binary_type = st.just(pa.binary())
+string_type = st.just(pa.string())
+large_binary_type = st.just(pa.large_binary())
+large_string_type = st.just(pa.large_string())
+fixed_size_binary_type = st.builds(
+ pa.binary,
+ st.integers(min_value=0, max_value=16)
+)
+binary_like_types = st.one_of(
+ binary_type,
+ string_type,
+ large_binary_type,
+ large_string_type,
+ fixed_size_binary_type
+)
+
+signed_integer_types = st.sampled_from([
+ pa.int8(),
+ pa.int16(),
+ pa.int32(),
+ pa.int64()
+])
+unsigned_integer_types = st.sampled_from([
+ pa.uint8(),
+ pa.uint16(),
+ pa.uint32(),
+ pa.uint64()
+])
+integer_types = st.one_of(signed_integer_types, unsigned_integer_types)
+
+floating_types = st.sampled_from([
+ pa.float16(),
+ pa.float32(),
+ pa.float64()
+])
+decimal128_type = st.builds(
+ pa.decimal128,
+ precision=st.integers(min_value=1, max_value=38),
+ scale=st.integers(min_value=1, max_value=38)
+)
+decimal256_type = st.builds(
+ pa.decimal256,
+ precision=st.integers(min_value=1, max_value=76),
+ scale=st.integers(min_value=1, max_value=76)
+)
+numeric_types = st.one_of(integer_types, floating_types,
+ decimal128_type, decimal256_type)
+
+date_types = st.sampled_from([
+ pa.date32(),
+ pa.date64()
+])
+time_types = st.sampled_from([
+ pa.time32('s'),
+ pa.time32('ms'),
+ pa.time64('us'),
+ pa.time64('ns')
+])
+timestamp_types = st.builds(
+ pa.timestamp,
+ unit=st.sampled_from(['s', 'ms', 'us', 'ns']),
+ tz=tzst.timezones()
+)
+duration_types = st.builds(
+ pa.duration,
+ st.sampled_from(['s', 'ms', 'us', 'ns'])
+)
+interval_types = st.sampled_from([
+    pa.month_day_nano_interval()
+])
+temporal_types = st.one_of(
+ date_types,
+ time_types,
+ timestamp_types,
+ duration_types,
+ interval_types
+)
+
+primitive_types = st.one_of(
+ null_type,
+ bool_type,
+ numeric_types,
+ temporal_types,
+ binary_like_types
+)
+
+metadata = st.dictionaries(st.text(), st.text())
+
+
+@st.composite
+def fields(draw, type_strategy=primitive_types):
+ name = draw(custom_text)
+ typ = draw(type_strategy)
+ if pa.types.is_null(typ):
+ nullable = True
+ else:
+ nullable = draw(st.booleans())
+ meta = draw(metadata)
+ return pa.field(name, type=typ, nullable=nullable, metadata=meta)
+
+
+def list_types(item_strategy=primitive_types):
+ return (
+ st.builds(pa.list_, item_strategy) |
+ st.builds(pa.large_list, item_strategy) |
+ st.builds(
+ pa.list_,
+ item_strategy,
+ st.integers(min_value=0, max_value=16)
+ )
+ )
+
+
+@st.composite
+def struct_types(draw, item_strategy=primitive_types):
+ fields_strategy = st.lists(fields(item_strategy))
+ fields_rendered = draw(fields_strategy)
+ field_names = [field.name for field in fields_rendered]
+ # check that field names are unique, see ARROW-9997
+ h.assume(len(set(field_names)) == len(field_names))
+ return pa.struct(fields_rendered)
+
+
+def dictionary_types(key_strategy=None, value_strategy=None):
+ key_strategy = key_strategy or signed_integer_types
+ value_strategy = value_strategy or st.one_of(
+ bool_type,
+ integer_types,
+ st.sampled_from([pa.float32(), pa.float64()]),
+ binary_type,
+ string_type,
+ fixed_size_binary_type,
+ )
+ return st.builds(pa.dictionary, key_strategy, value_strategy)
+
+
+@st.composite
+def map_types(draw, key_strategy=primitive_types,
+ item_strategy=primitive_types):
+ key_type = draw(key_strategy)
+ h.assume(not pa.types.is_null(key_type))
+ value_type = draw(item_strategy)
+ return pa.map_(key_type, value_type)
+
+
+# union type
+# extension type
+
+
+def schemas(type_strategy=primitive_types, max_fields=None):
+ children = st.lists(fields(type_strategy), max_size=max_fields)
+ return st.builds(pa.schema, children)
+
+
+all_types = st.deferred(
+ lambda: (
+ primitive_types |
+ list_types() |
+ struct_types() |
+ dictionary_types() |
+ map_types() |
+ list_types(all_types) |
+ struct_types(all_types)
+ )
+)
+all_fields = fields(all_types)
+all_schemas = schemas(all_types)
+
+
+_default_array_sizes = st.integers(min_value=0, max_value=20)
+
+
+@st.composite
+def _pylist(draw, value_type, size, nullable=True):
+ arr = draw(arrays(value_type, size=size, nullable=False))
+ return arr.to_pylist()
+
+
+@st.composite
+def _pymap(draw, key_type, value_type, size, nullable=True):
+ length = draw(size)
+ keys = draw(_pylist(key_type, size=length, nullable=False))
+ values = draw(_pylist(value_type, size=length, nullable=nullable))
+ return list(zip(keys, values))
+
+
+@st.composite
+def arrays(draw, type, size=None, nullable=True):
+ if isinstance(type, st.SearchStrategy):
+ ty = draw(type)
+ elif isinstance(type, pa.DataType):
+ ty = type
+ else:
+ raise TypeError('Type must be a pyarrow DataType')
+
+ if isinstance(size, st.SearchStrategy):
+ size = draw(size)
+ elif size is None:
+ size = draw(_default_array_sizes)
+ elif not isinstance(size, int):
+ raise TypeError('Size must be an integer')
+
+ if pa.types.is_null(ty):
+ h.assume(nullable)
+ value = st.none()
+ elif pa.types.is_boolean(ty):
+ value = st.booleans()
+ elif pa.types.is_integer(ty):
+ values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
+ return pa.array(values, type=ty)
+ elif pa.types.is_floating(ty):
+ values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,)))
+ # Workaround ARROW-4952: no easy way to assert array equality
+ # in a NaN-tolerant way.
+ values[np.isnan(values)] = -42.0
+ return pa.array(values, type=ty)
+ elif pa.types.is_decimal(ty):
+ # TODO(kszucs): properly limit the precision
+ # value = st.decimals(places=type.scale, allow_infinity=False)
+ h.reject()
+ elif pa.types.is_time(ty):
+ value = st.times()
+ elif pa.types.is_date(ty):
+ value = st.dates()
+ elif pa.types.is_timestamp(ty):
+ min_int64 = -(2**63)
+ max_int64 = 2**63 - 1
+ min_datetime = datetime.datetime.fromtimestamp(min_int64 // 10**9)
+ max_datetime = datetime.datetime.fromtimestamp(max_int64 // 10**9)
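+        # Clamp generated datetimes to the range representable as
+        # nanoseconds-since-epoch in a signed 64-bit integer, so the values
+        # fit even the most restrictive ('ns') timestamp unit.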
+ try:
+ offset_hours = int(ty.tz)
+ tz = pytz.FixedOffset(offset_hours * 60)
+ except ValueError:
+ tz = pytz.timezone(ty.tz)
+ value = st.datetimes(timezones=st.just(tz), min_value=min_datetime,
+ max_value=max_datetime)
+ elif pa.types.is_duration(ty):
+ value = st.timedeltas()
+ elif pa.types.is_binary(ty) or pa.types.is_large_binary(ty):
+ value = st.binary()
+ elif pa.types.is_string(ty) or pa.types.is_large_string(ty):
+ value = st.text()
+ elif pa.types.is_fixed_size_binary(ty):
+ value = st.binary(min_size=ty.byte_width, max_size=ty.byte_width)
+ elif pa.types.is_list(ty):
+ value = _pylist(ty.value_type, size=size, nullable=nullable)
+ elif pa.types.is_large_list(ty):
+ value = _pylist(ty.value_type, size=size, nullable=nullable)
+ elif pa.types.is_fixed_size_list(ty):
+ value = _pylist(ty.value_type, size=ty.list_size, nullable=nullable)
+ elif pa.types.is_dictionary(ty):
+ values = _pylist(ty.value_type, size=size, nullable=nullable)
+ return pa.array(draw(values), type=ty)
+ elif pa.types.is_map(ty):
+ value = _pymap(ty.key_type, ty.item_type, size=_default_array_sizes,
+ nullable=nullable)
+ elif pa.types.is_struct(ty):
+ h.assume(len(ty) > 0)
+ fields, child_arrays = [], []
+ for field in ty:
+ fields.append(field)
+ child_arrays.append(draw(arrays(field.type, size=size)))
+ return pa.StructArray.from_arrays(child_arrays, fields=fields)
+ else:
+ raise NotImplementedError(ty)
+
+ if nullable:
+ value = st.one_of(st.none(), value)
+ values = st.lists(value, min_size=size, max_size=size)
+
+ return pa.array(draw(values), type=ty)
+
+
+@st.composite
+def chunked_arrays(draw, type, min_chunks=0, max_chunks=None, chunk_size=None):
+ if isinstance(type, st.SearchStrategy):
+ type = draw(type)
+
+ # TODO(kszucs): remove it, field metadata is not kept
+ h.assume(not pa.types.is_struct(type))
+
+ chunk = arrays(type, size=chunk_size)
+ chunks = st.lists(chunk, min_size=min_chunks, max_size=max_chunks)
+
+ return pa.chunked_array(draw(chunks), type=type)
+
+
+@st.composite
+def record_batches(draw, type, rows=None, max_fields=None):
+ if isinstance(rows, st.SearchStrategy):
+ rows = draw(rows)
+ elif rows is None:
+ rows = draw(_default_array_sizes)
+ elif not isinstance(rows, int):
+ raise TypeError('Rows must be an integer')
+
+ schema = draw(schemas(type, max_fields=max_fields))
+ children = [draw(arrays(field.type, size=rows)) for field in schema]
+    # TODO(kszucs): the names and schema arguments are not consistent with
+    # Table.from_arrays' arguments
+ return pa.RecordBatch.from_arrays(children, names=schema)
+
+
+@st.composite
+def tables(draw, type, rows=None, max_fields=None):
+ if isinstance(rows, st.SearchStrategy):
+ rows = draw(rows)
+ elif rows is None:
+ rows = draw(_default_array_sizes)
+ elif not isinstance(rows, int):
+ raise TypeError('Rows must be an integer')
+
+ schema = draw(schemas(type, max_fields=max_fields))
+ children = [draw(arrays(field.type, size=rows)) for field in schema]
+ return pa.Table.from_arrays(children, schema=schema)
+
+
+all_arrays = arrays(all_types)
+all_chunked_arrays = chunked_arrays(all_types)
+all_record_batches = record_batches(all_types)
+all_tables = tables(all_types)
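+
+# A minimal usage sketch (illustrative only, not part of the test suite):
+# these strategies plug straight into hypothesis.given, e.g.
+#
+#     @h.given(all_arrays)
+#     def test_pylist_roundtrip_length(arr):
+#         assert len(arr.to_pylist()) == len(arr)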
+
+
+# Define the same strategies as above for the pandas tests, excluding types
+# whose conversion hits known issues.
+
+pandas_compatible_primitive_types = st.one_of(
+ null_type,
+ bool_type,
+ integer_types,
+ st.sampled_from([pa.float32(), pa.float64()]),
+ decimal128_type,
+ date_types,
+ time_types,
+ # Need to exclude timestamp and duration types otherwise hypothesis
+ # discovers ARROW-10210
+ # timestamp_types,
+ # duration_types
+ interval_types,
+ binary_type,
+ string_type,
+ large_binary_type,
+ large_string_type,
+)
+
+# Need to exclude floating point types otherwise hypothesis discovers
+# ARROW-10211
+pandas_compatible_dictionary_value_types = st.one_of(
+ bool_type,
+ integer_types,
+ binary_type,
+ string_type,
+ fixed_size_binary_type,
+)
+
+
+def pandas_compatible_list_types(
+ item_strategy=pandas_compatible_primitive_types
+):
+ # Need to exclude fixed size list type otherwise hypothesis discovers
+ # ARROW-10194
+ return (
+ st.builds(pa.list_, item_strategy) |
+ st.builds(pa.large_list, item_strategy)
+ )
+
+
+pandas_compatible_types = st.deferred(
+ lambda: st.one_of(
+ pandas_compatible_primitive_types,
+ pandas_compatible_list_types(pandas_compatible_primitive_types),
+ struct_types(pandas_compatible_primitive_types),
+ dictionary_types(
+ value_strategy=pandas_compatible_dictionary_value_types
+ ),
+ pandas_compatible_list_types(pandas_compatible_types),
+ struct_types(pandas_compatible_types)
+ )
+)
diff --git a/src/arrow/python/pyarrow/tests/test_adhoc_memory_leak.py b/src/arrow/python/pyarrow/tests/test_adhoc_memory_leak.py
new file mode 100644
index 000000000..cd381cf42
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_adhoc_memory_leak.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import numpy as np
+import pyarrow as pa
+
+import pyarrow.tests.util as test_util
+
+try:
+ import pandas as pd
+except ImportError:
+ pass
+
+
+@pytest.mark.memory_leak
+@pytest.mark.pandas
+def test_deserialize_pandas_arrow_7956():
+ df = pd.DataFrame({'a': np.arange(10000),
+ 'b': [test_util.rands(5) for _ in range(10000)]})
+
+ def action():
+ df_bytes = pa.ipc.serialize_pandas(df).to_pybytes()
+ buf = pa.py_buffer(df_bytes)
+ pa.ipc.deserialize_pandas(buf)
+
+ # Abort at 128MB threshold
+ test_util.memory_leak_check(action, threshold=1 << 27, iterations=100)
diff --git a/src/arrow/python/pyarrow/tests/test_array.py b/src/arrow/python/pyarrow/tests/test_array.py
new file mode 100644
index 000000000..9a1f41efe
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_array.py
@@ -0,0 +1,3064 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections.abc import Iterable
+import datetime
+import decimal
+import hypothesis as h
+import hypothesis.strategies as st
+import itertools
+import pickle
+import pytest
+import struct
+import sys
+import weakref
+
+import numpy as np
+try:
+ import pickle5
+except ImportError:
+ pickle5 = None
+import pytz
+
+import pyarrow as pa
+import pyarrow.tests.strategies as past
+
+
+def test_total_bytes_allocated():
+ assert pa.total_allocated_bytes() == 0
+
+
+def test_weakref():
+ arr = pa.array([1, 2, 3])
+ wr = weakref.ref(arr)
+ assert wr() is not None
+ del arr
+ assert wr() is None
+
+
+def test_getitem_NULL():
+ arr = pa.array([1, None, 2])
+ assert arr[1].as_py() is None
+ assert arr[1].is_valid is False
+ assert isinstance(arr[1], pa.Int64Scalar)
+
+
+def test_constructor_raises():
+    # This can happen through wrong capitalization, e.g. pa.Array instead
+    # of pa.array.
+ # ARROW-2638: prevent calling extension class constructors directly
+ with pytest.raises(TypeError):
+ pa.Array([1, 2])
+
+
+def test_list_format():
+ arr = pa.array([[1], None, [2, 3, None]])
+ result = arr.to_string()
+ expected = """\
+[
+ [
+ 1
+ ],
+ null,
+ [
+ 2,
+ 3,
+ null
+ ]
+]"""
+ assert result == expected
+
+
+def test_string_format():
+ arr = pa.array(['', None, 'foo'])
+ result = arr.to_string()
+ expected = """\
+[
+ "",
+ null,
+ "foo"
+]"""
+ assert result == expected
+
+
+def test_long_array_format():
+ arr = pa.array(range(100))
+ result = arr.to_string(window=2)
+ expected = """\
+[
+ 0,
+ 1,
+ ...
+ 98,
+ 99
+]"""
+ assert result == expected
+
+
+def test_binary_format():
+ arr = pa.array([b'\x00', b'', None, b'\x01foo', b'\x80\xff'])
+ result = arr.to_string()
+ expected = """\
+[
+ 00,
+ ,
+ null,
+ 01666F6F,
+ 80FF
+]"""
+ assert result == expected
+
+
+def test_binary_total_values_length():
+ arr = pa.array([b'0000', None, b'11111', b'222222', b'3333333'],
+ type='binary')
+ large_arr = pa.array([b'0000', None, b'11111', b'222222', b'3333333'],
+ type='large_binary')
+
+ assert arr.total_values_length == 22
+ assert arr.slice(1, 3).total_values_length == 11
+ assert large_arr.total_values_length == 22
+ assert large_arr.slice(1, 3).total_values_length == 11
+
+
+def test_to_numpy_zero_copy():
+ arr = pa.array(range(10))
+
+ np_arr = arr.to_numpy()
+
+ # check for zero copy (both arrays using same memory)
+ arrow_buf = arr.buffers()[1]
+ assert arrow_buf.address == np_arr.ctypes.data
+
+ arr = None
+ import gc
+ gc.collect()
+
+ # Ensure base is still valid
+ assert np_arr.base is not None
+ expected = np.arange(10)
+ np.testing.assert_array_equal(np_arr, expected)
+
+
+def test_to_numpy_unsupported_types():
+ # ARROW-2871: Some primitive types are not yet supported in to_numpy
+ bool_arr = pa.array([True, False, True])
+
+ with pytest.raises(ValueError):
+ bool_arr.to_numpy()
+
+ result = bool_arr.to_numpy(zero_copy_only=False)
+ expected = np.array([True, False, True])
+ np.testing.assert_array_equal(result, expected)
+
+ null_arr = pa.array([None, None, None])
+
+ with pytest.raises(ValueError):
+ null_arr.to_numpy()
+
+ result = null_arr.to_numpy(zero_copy_only=False)
+ expected = np.array([None, None, None], dtype=object)
+ np.testing.assert_array_equal(result, expected)
+
+ arr = pa.array([1, 2, None])
+
+ with pytest.raises(ValueError, match="with 1 nulls"):
+ arr.to_numpy()
+
+
+def test_to_numpy_writable():
+ arr = pa.array(range(10))
+ np_arr = arr.to_numpy()
+
+ # by default not writable for zero-copy conversion
+ with pytest.raises(ValueError):
+ np_arr[0] = 10
+
+ np_arr2 = arr.to_numpy(zero_copy_only=False, writable=True)
+ np_arr2[0] = 10
+ assert arr[0].as_py() == 0
+
+ # when asking for writable, cannot do zero-copy
+ with pytest.raises(ValueError):
+ arr.to_numpy(zero_copy_only=True, writable=True)
+
+
+@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns'])
+def test_to_numpy_datetime64(unit):
+ arr = pa.array([1, 2, 3], pa.timestamp(unit))
+ expected = np.array([1, 2, 3], dtype="datetime64[{}]".format(unit))
+ np_arr = arr.to_numpy()
+ np.testing.assert_array_equal(np_arr, expected)
+
+
+@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns'])
+def test_to_numpy_timedelta64(unit):
+ arr = pa.array([1, 2, 3], pa.duration(unit))
+ expected = np.array([1, 2, 3], dtype="timedelta64[{}]".format(unit))
+ np_arr = arr.to_numpy()
+ np.testing.assert_array_equal(np_arr, expected)
+
+
+def test_to_numpy_dictionary():
+ # ARROW-7591
+ arr = pa.array(["a", "b", "a"]).dictionary_encode()
+ expected = np.array(["a", "b", "a"], dtype=object)
+ np_arr = arr.to_numpy(zero_copy_only=False)
+ np.testing.assert_array_equal(np_arr, expected)
+
+
+@pytest.mark.pandas
+def test_to_pandas_zero_copy():
+ import gc
+
+ arr = pa.array(range(10))
+
+ for i in range(10):
+ series = arr.to_pandas()
+ assert sys.getrefcount(series) == 2
+ series = None # noqa
+
+ assert sys.getrefcount(arr) == 2
+
+ for i in range(10):
+ arr = pa.array(range(10))
+ series = arr.to_pandas()
+ arr = None
+ gc.collect()
+
+ # Ensure base is still valid
+
+ # Because of py.test's assert inspection magic, if you put getrefcount
+ # on the line being examined, it will be 1 higher than you expect
+ base_refcount = sys.getrefcount(series.values.base)
+ assert base_refcount == 2
+ series.sum()
+
+
+@pytest.mark.nopandas
+@pytest.mark.pandas
+def test_asarray():
+ # ensure this is tested both when pandas is present or not (ARROW-6564)
+
+ arr = pa.array(range(4))
+
+ # The iterator interface gives back an array of Int64Value's
+ np_arr = np.asarray([_ for _ in arr])
+ assert np_arr.tolist() == [0, 1, 2, 3]
+ assert np_arr.dtype == np.dtype('O')
+ assert type(np_arr[0]) == pa.lib.Int64Value
+
+ # Calling with the arrow array gives back an array with 'int64' dtype
+ np_arr = np.asarray(arr)
+ assert np_arr.tolist() == [0, 1, 2, 3]
+ assert np_arr.dtype == np.dtype('int64')
+
+ # An optional type can be specified when calling np.asarray
+ np_arr = np.asarray(arr, dtype='str')
+ assert np_arr.tolist() == ['0', '1', '2', '3']
+
+ # If PyArrow array has null values, numpy type will be changed as needed
+ # to support nulls.
+ arr = pa.array([0, 1, 2, None])
+ assert arr.type == pa.int64()
+ np_arr = np.asarray(arr)
+ elements = np_arr.tolist()
+ assert elements[:3] == [0., 1., 2.]
+ assert np.isnan(elements[3])
+ assert np_arr.dtype == np.dtype('float64')
+
+ # DictionaryType data will be converted to dense numpy array
+ arr = pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, 2, 0, 1]), pa.array(['a', 'b', 'c']))
+ np_arr = np.asarray(arr)
+ assert np_arr.dtype == np.dtype('object')
+ assert np_arr.tolist() == ['a', 'b', 'c', 'a', 'b']
+
+
+@pytest.mark.parametrize('ty', [
+ None,
+ pa.null(),
+ pa.int8(),
+ pa.string()
+])
+def test_nulls(ty):
+ arr = pa.nulls(3, type=ty)
+ expected = pa.array([None, None, None], type=ty)
+
+ assert len(arr) == 3
+ assert arr.equals(expected)
+
+ if ty is None:
+ assert arr.type == pa.null()
+ else:
+ assert arr.type == ty
+
+
+def test_array_from_scalar():
+ today = datetime.date.today()
+ now = datetime.datetime.now()
+ now_utc = now.replace(tzinfo=pytz.utc)
+ now_with_tz = now_utc.astimezone(pytz.timezone('US/Eastern'))
+ oneday = datetime.timedelta(days=1)
+
+ cases = [
+ (None, 1, pa.array([None])),
+ (None, 10, pa.nulls(10)),
+ (-1, 3, pa.array([-1, -1, -1], type=pa.int64())),
+ (2.71, 2, pa.array([2.71, 2.71], type=pa.float64())),
+ ("string", 4, pa.array(["string"] * 4)),
+ (
+ pa.scalar(8, type=pa.uint8()),
+ 17,
+ pa.array([8] * 17, type=pa.uint8())
+ ),
+ (pa.scalar(None), 3, pa.array([None, None, None])),
+ (pa.scalar(True), 11, pa.array([True] * 11)),
+ (today, 2, pa.array([today] * 2)),
+ (now, 10, pa.array([now] * 10)),
+ (
+ now_with_tz,
+ 2,
+ pa.array(
+ [now_utc] * 2,
+ type=pa.timestamp('us', tz=pytz.timezone('US/Eastern'))
+ )
+ ),
+ (now.time(), 9, pa.array([now.time()] * 9)),
+ (oneday, 4, pa.array([oneday] * 4)),
+ (False, 9, pa.array([False] * 9)),
+ ([1, 2], 2, pa.array([[1, 2], [1, 2]])),
+ (
+ pa.scalar([-1, 3], type=pa.large_list(pa.int8())),
+ 5,
+ pa.array([[-1, 3]] * 5, type=pa.large_list(pa.int8()))
+ ),
+ ({'a': 1, 'b': 2}, 3, pa.array([{'a': 1, 'b': 2}] * 3))
+ ]
+
+ for value, size, expected in cases:
+ arr = pa.repeat(value, size)
+ assert len(arr) == size
+ assert arr.type.equals(expected.type)
+ assert arr.equals(expected)
+ if expected.type == pa.null():
+ assert arr.null_count == size
+ else:
+ assert arr.null_count == 0
+
+
+def test_array_from_dictionary_scalar():
+ dictionary = ['foo', 'bar', 'baz']
+ arr = pa.DictionaryArray.from_arrays([2, 1, 2, 0], dictionary=dictionary)
+
+ result = pa.repeat(arr[0], 5)
+ expected = pa.DictionaryArray.from_arrays([2] * 5, dictionary=dictionary)
+ assert result.equals(expected)
+
+ result = pa.repeat(arr[3], 5)
+ expected = pa.DictionaryArray.from_arrays([0] * 5, dictionary=dictionary)
+ assert result.equals(expected)
+
+
+def test_array_getitem():
+ arr = pa.array(range(10, 15))
+ lst = arr.to_pylist()
+
+ for idx in range(-len(arr), len(arr)):
+ assert arr[idx].as_py() == lst[idx]
+ for idx in range(-2 * len(arr), -len(arr)):
+ with pytest.raises(IndexError):
+ arr[idx]
+ for idx in range(len(arr), 2 * len(arr)):
+ with pytest.raises(IndexError):
+ arr[idx]
+
+ # check that numpy scalars are supported
+ for idx in range(-len(arr), len(arr)):
+ assert arr[np.int32(idx)].as_py() == lst[idx]
+
+
+def test_array_slice():
+ arr = pa.array(range(10))
+
+ sliced = arr.slice(2)
+ expected = pa.array(range(2, 10))
+ assert sliced.equals(expected)
+
+ sliced2 = arr.slice(2, 4)
+ expected2 = pa.array(range(2, 6))
+ assert sliced2.equals(expected2)
+
+ # 0 offset
+ assert arr.slice(0).equals(arr)
+
+ # Slice past end of array
+ assert len(arr.slice(len(arr))) == 0
+ assert len(arr.slice(len(arr) + 2)) == 0
+ assert len(arr.slice(len(arr) + 2, 100)) == 0
+
+ with pytest.raises(IndexError):
+ arr.slice(-1)
+
+ with pytest.raises(ValueError):
+ arr.slice(2, -1)
+
+ # Test slice notation
+ assert arr[2:].equals(arr.slice(2))
+ assert arr[2:5].equals(arr.slice(2, 3))
+ assert arr[-5:].equals(arr.slice(len(arr) - 5))
+
+ n = len(arr)
+ for start in range(-n * 2, n * 2):
+ for stop in range(-n * 2, n * 2):
+ res = arr[start:stop]
+ res.validate()
+ expected = arr.to_pylist()[start:stop]
+ assert res.to_pylist() == expected
+ assert res.to_numpy().tolist() == expected
+
+
+def test_array_slice_negative_step():
+ # ARROW-2714
+ np_arr = np.arange(20)
+ arr = pa.array(np_arr)
+ chunked_arr = pa.chunked_array([arr])
+
+ cases = [
+ slice(None, None, -1),
+ slice(None, 6, -2),
+ slice(10, 6, -2),
+ slice(8, None, -2),
+ slice(2, 10, -2),
+ slice(10, 2, -2),
+ slice(None, None, 2),
+ slice(0, 10, 2),
+ ]
+
+ for case in cases:
+ result = arr[case]
+ expected = pa.array(np_arr[case])
+ assert result.equals(expected)
+
+ result = pa.record_batch([arr], names=['f0'])[case]
+ expected = pa.record_batch([expected], names=['f0'])
+ assert result.equals(expected)
+
+ result = chunked_arr[case]
+ expected = pa.chunked_array([np_arr[case]])
+ assert result.equals(expected)
+
+
+def test_array_diff():
+ # ARROW-6252
+ arr1 = pa.array(['foo'], type=pa.utf8())
+ arr2 = pa.array(['foo', 'bar', None], type=pa.utf8())
+ arr3 = pa.array([1, 2, 3])
+ arr4 = pa.array([[], [1], None], type=pa.list_(pa.int64()))
+
+ assert arr1.diff(arr1) == ''
+ assert arr1.diff(arr2) == '''
+@@ -1, +1 @@
++"bar"
++null
+'''
+ assert arr1.diff(arr3).strip() == '# Array types differed: string vs int64'
+ assert arr1.diff(arr4).strip() == ('# Array types differed: string vs '
+ 'list<item: int64>')
+
+
+def test_array_iter():
+ arr = pa.array(range(10))
+
+ for i, j in zip(range(10), arr):
+ assert i == j.as_py()
+
+ assert isinstance(arr, Iterable)
+
+
+def test_struct_array_slice():
+ # ARROW-2311: slicing nested arrays needs special care
+ ty = pa.struct([pa.field('a', pa.int8()),
+ pa.field('b', pa.float32())])
+ arr = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty)
+ assert arr[1:].to_pylist() == [{'a': 3, 'b': 4.5},
+ {'a': 5, 'b': 6.5}]
+
+
+def test_array_factory_invalid_type():
+
+ class MyObject:
+ pass
+
+ arr = np.array([MyObject()])
+ with pytest.raises(ValueError):
+ pa.array(arr)
+
+
+def test_array_ref_to_ndarray_base():
+ arr = np.array([1, 2, 3])
+
+ refcount = sys.getrefcount(arr)
+ arr2 = pa.array(arr) # noqa
+ assert sys.getrefcount(arr) == (refcount + 1)
+
+
+def test_array_eq():
+ # ARROW-2150 / ARROW-9445: we define the __eq__ behavior to be
+ # data equality (not element-wise equality)
+ arr1 = pa.array([1, 2, 3], type=pa.int32())
+ arr2 = pa.array([1, 2, 3], type=pa.int32())
+ arr3 = pa.array([1, 2, 3], type=pa.int64())
+
+ assert (arr1 == arr2) is True
+ assert (arr1 != arr2) is False
+ assert (arr1 == arr3) is False
+ assert (arr1 != arr3) is True
+
+ assert (arr1 == 1) is False
+ assert (arr1 == None) is False # noqa: E711
+
+
+def test_array_from_buffers():
+ values_buf = pa.py_buffer(np.int16([4, 5, 6, 7]))
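+    # LSB-first validity bitmap 0b00001101: bit 1 is clear, so the second
+    # value comes out as null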
+ nulls_buf = pa.py_buffer(np.uint8([0b00001101]))
+ arr = pa.Array.from_buffers(pa.int16(), 4, [nulls_buf, values_buf])
+ assert arr.type == pa.int16()
+ assert arr.to_pylist() == [4, None, 6, 7]
+
+ arr = pa.Array.from_buffers(pa.int16(), 4, [None, values_buf])
+ assert arr.type == pa.int16()
+ assert arr.to_pylist() == [4, 5, 6, 7]
+
+ arr = pa.Array.from_buffers(pa.int16(), 3, [nulls_buf, values_buf],
+ offset=1)
+ assert arr.type == pa.int16()
+ assert arr.to_pylist() == [None, 6, 7]
+
+ with pytest.raises(TypeError):
+ pa.Array.from_buffers(pa.int16(), 3, ['', ''], offset=1)
+
+
+def test_string_binary_from_buffers():
+ array = pa.array(["a", None, "b", "c"])
+
+ buffers = array.buffers()
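+    # A string array carries three buffers: validity bitmap, value offsets
+    # and character data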
+ copied = pa.StringArray.from_buffers(
+ len(array), buffers[1], buffers[2], buffers[0], array.null_count,
+ array.offset)
+ assert copied.to_pylist() == ["a", None, "b", "c"]
+
+ binary_copy = pa.Array.from_buffers(pa.binary(), len(array),
+ array.buffers(), array.null_count,
+ array.offset)
+ assert binary_copy.to_pylist() == [b"a", None, b"b", b"c"]
+
+ copied = pa.StringArray.from_buffers(
+ len(array), buffers[1], buffers[2], buffers[0])
+ assert copied.to_pylist() == ["a", None, "b", "c"]
+
+ sliced = array[1:]
+ buffers = sliced.buffers()
+ copied = pa.StringArray.from_buffers(
+ len(sliced), buffers[1], buffers[2], buffers[0], -1, sliced.offset)
+ assert copied.to_pylist() == [None, "b", "c"]
+ assert copied.null_count == 1
+
+ # Slice but exclude all null entries so that we don't need to pass
+ # the null bitmap.
+ sliced = array[2:]
+ buffers = sliced.buffers()
+ copied = pa.StringArray.from_buffers(
+ len(sliced), buffers[1], buffers[2], None, -1, sliced.offset)
+ assert copied.to_pylist() == ["b", "c"]
+ assert copied.null_count == 0
+
+
+@pytest.mark.parametrize('list_type_factory', [pa.list_, pa.large_list])
+def test_list_from_buffers(list_type_factory):
+ ty = list_type_factory(pa.int16())
+ array = pa.array([[0, 1, 2], None, [], [3, 4, 5]], type=ty)
+ assert array.type == ty
+
+ buffers = array.buffers()
+
+ with pytest.raises(ValueError):
+ # No children
+ pa.Array.from_buffers(ty, 4, [None, buffers[1]])
+
+ child = pa.Array.from_buffers(pa.int16(), 6, buffers[2:])
+ copied = pa.Array.from_buffers(ty, 4, buffers[:2], children=[child])
+ assert copied.equals(array)
+
+ with pytest.raises(ValueError):
+ # too many children
+ pa.Array.from_buffers(ty, 4, [None, buffers[1]],
+ children=[child, child])
+
+
+def test_struct_from_buffers():
+ ty = pa.struct([pa.field('a', pa.int16()), pa.field('b', pa.utf8())])
+ array = pa.array([{'a': 0, 'b': 'foo'}, None, {'a': 5, 'b': ''}],
+ type=ty)
+ buffers = array.buffers()
+
+ with pytest.raises(ValueError):
+ # No children
+ pa.Array.from_buffers(ty, 3, [None, buffers[1]])
+
+ children = [pa.Array.from_buffers(pa.int16(), 3, buffers[1:3]),
+ pa.Array.from_buffers(pa.utf8(), 3, buffers[3:])]
+ copied = pa.Array.from_buffers(ty, 3, buffers[:1], children=children)
+ assert copied.equals(array)
+
+ with pytest.raises(ValueError):
+        # not enough children
+ pa.Array.from_buffers(ty, 3, [buffers[0]],
+ children=children[:1])
+
+
+def test_struct_from_arrays():
+ a = pa.array([4, 5, 6], type=pa.int64())
+ b = pa.array(["bar", None, ""])
+ c = pa.array([[1, 2], None, [3, None]])
+ expected_list = [
+ {'a': 4, 'b': 'bar', 'c': [1, 2]},
+ {'a': 5, 'b': None, 'c': None},
+ {'a': 6, 'b': '', 'c': [3, None]},
+ ]
+
+ # From field names
+ arr = pa.StructArray.from_arrays([a, b, c], ["a", "b", "c"])
+ assert arr.type == pa.struct(
+ [("a", a.type), ("b", b.type), ("c", c.type)])
+ assert arr.to_pylist() == expected_list
+
+ with pytest.raises(ValueError):
+ pa.StructArray.from_arrays([a, b, c], ["a", "b"])
+
+ arr = pa.StructArray.from_arrays([], [])
+ assert arr.type == pa.struct([])
+ assert arr.to_pylist() == []
+
+ # From fields
+ fa = pa.field("a", a.type, nullable=False)
+ fb = pa.field("b", b.type)
+ fc = pa.field("c", c.type)
+ arr = pa.StructArray.from_arrays([a, b, c], fields=[fa, fb, fc])
+ assert arr.type == pa.struct([fa, fb, fc])
+ assert not arr.type[0].nullable
+ assert arr.to_pylist() == expected_list
+
+ with pytest.raises(ValueError):
+ pa.StructArray.from_arrays([a, b, c], fields=[fa, fb])
+
+ arr = pa.StructArray.from_arrays([], fields=[])
+ assert arr.type == pa.struct([])
+ assert arr.to_pylist() == []
+
+ # Inconsistent fields
+ fa2 = pa.field("a", pa.int32())
+ with pytest.raises(ValueError, match="int64 vs int32"):
+ pa.StructArray.from_arrays([a, b, c], fields=[fa2, fb, fc])
+
+ arrays = [a, b, c]
+ fields = [fa, fb, fc]
+ # With mask
+ mask = pa.array([True, False, False])
+ arr = pa.StructArray.from_arrays(arrays, fields=fields, mask=mask)
+ assert arr.to_pylist() == [None] + expected_list[1:]
+
+ arr = pa.StructArray.from_arrays(arrays, names=['a', 'b', 'c'], mask=mask)
+ assert arr.to_pylist() == [None] + expected_list[1:]
+
+ # Bad masks
+ with pytest.raises(ValueError, match='Mask must be'):
+ pa.StructArray.from_arrays(arrays, fields, mask=[True, False, False])
+
+ with pytest.raises(ValueError, match='not contain nulls'):
+ pa.StructArray.from_arrays(
+ arrays, fields, mask=pa.array([True, False, None]))
+
+ with pytest.raises(ValueError, match='Mask must be'):
+ pa.StructArray.from_arrays(
+ arrays, fields, mask=pa.chunked_array([mask]))
+
+
+def test_struct_array_from_chunked():
+ # ARROW-11780
+ # Check that we don't segfault when trying to build
+ # a StructArray from a chunked array.
+ chunked_arr = pa.chunked_array([[1, 2, 3], [4, 5, 6]])
+
+ with pytest.raises(TypeError, match="Expected Array"):
+ pa.StructArray.from_arrays([chunked_arr], ["foo"])
+
+
+def test_dictionary_from_numpy():
+ indices = np.repeat([0, 1, 2], 2)
+ dictionary = np.array(['foo', 'bar', 'baz'], dtype=object)
+ mask = np.array([False, False, True, False, False, False])
+
+ d1 = pa.DictionaryArray.from_arrays(indices, dictionary)
+ d2 = pa.DictionaryArray.from_arrays(indices, dictionary, mask=mask)
+
+ assert d1.indices.to_pylist() == indices.tolist()
+    assert d2.indices.to_pylist() == indices.tolist()
+ assert d1.dictionary.to_pylist() == dictionary.tolist()
+ assert d2.dictionary.to_pylist() == dictionary.tolist()
+
+ for i in range(len(indices)):
+ assert d1[i].as_py() == dictionary[indices[i]]
+
+ if mask[i]:
+ assert d2[i].as_py() is None
+ else:
+ assert d2[i].as_py() == dictionary[indices[i]]
+
+
+def test_dictionary_to_numpy():
+ expected = pa.array(
+ ["foo", "bar", None, "foo"]
+ ).to_numpy(zero_copy_only=False)
+ a = pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, None, 0]),
+ pa.array(['foo', 'bar'])
+ )
+ np.testing.assert_array_equal(a.to_numpy(zero_copy_only=False),
+ expected)
+
+ with pytest.raises(pa.ArrowInvalid):
+        # If this ever stops raising, make sure to check the actual result:
+        # currently to_numpy assumes there are no nulls when
+        # zero_copy_only=True (decoding the DictionaryArray is what handles
+        # nulls, and that path is only taken with zero_copy_only=False)
+ a.to_numpy(zero_copy_only=True)
+
+ anonulls = pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, 1, 0]),
+ pa.array(['foo', 'bar'])
+ )
+ expected = pa.array(
+ ["foo", "bar", "bar", "foo"]
+ ).to_numpy(zero_copy_only=False)
+ np.testing.assert_array_equal(anonulls.to_numpy(zero_copy_only=False),
+ expected)
+
+ with pytest.raises(pa.ArrowInvalid):
+ anonulls.to_numpy(zero_copy_only=True)
+
+ afloat = pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, 1, 0]),
+ pa.array([13.7, 11.0])
+ )
+ expected = pa.array([13.7, 11.0, 11.0, 13.7]).to_numpy()
+ np.testing.assert_array_equal(afloat.to_numpy(zero_copy_only=True),
+ expected)
+ np.testing.assert_array_equal(afloat.to_numpy(zero_copy_only=False),
+ expected)
+
+ afloat2 = pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, None, 0]),
+ pa.array([13.7, 11.0])
+ )
+ expected = pa.array(
+ [13.7, 11.0, None, 13.7]
+ ).to_numpy(zero_copy_only=False)
+ np.testing.assert_allclose(
+ afloat2.to_numpy(zero_copy_only=False),
+ expected,
+ equal_nan=True
+ )
+
+ # Testing for integers can reveal problems related to dealing
+ # with None values, as a numpy array of int dtype
+ # can't contain NaN nor None.
+ aints = pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, None, 0]),
+ pa.array([7, 11])
+ )
+ expected = pa.array([7, 11, None, 7]).to_numpy(zero_copy_only=False)
+ np.testing.assert_allclose(
+ aints.to_numpy(zero_copy_only=False),
+ expected,
+ equal_nan=True
+ )
+
+
+def test_dictionary_from_boxed_arrays():
+ indices = np.repeat([0, 1, 2], 2)
+ dictionary = np.array(['foo', 'bar', 'baz'], dtype=object)
+
+ iarr = pa.array(indices)
+ darr = pa.array(dictionary)
+
+ d1 = pa.DictionaryArray.from_arrays(iarr, darr)
+
+ assert d1.indices.to_pylist() == indices.tolist()
+ assert d1.dictionary.to_pylist() == dictionary.tolist()
+
+ for i in range(len(indices)):
+ assert d1[i].as_py() == dictionary[indices[i]]
+
+
+def test_dictionary_from_arrays_boundscheck():
+ indices1 = pa.array([0, 1, 2, 0, 1, 2])
+ indices2 = pa.array([0, -1, 2])
+ indices3 = pa.array([0, 1, 2, 3])
+
+ dictionary = pa.array(['foo', 'bar', 'baz'])
+
+ # Works fine
+ pa.DictionaryArray.from_arrays(indices1, dictionary)
+
+ with pytest.raises(pa.ArrowException):
+ pa.DictionaryArray.from_arrays(indices2, dictionary)
+
+ with pytest.raises(pa.ArrowException):
+ pa.DictionaryArray.from_arrays(indices3, dictionary)
+
+    # If we are confident that the indices are "safe", we can pass safe=False
+    # to disable bounds checking
+ pa.DictionaryArray.from_arrays(indices2, dictionary, safe=False)
+
+
+def test_dictionary_indices():
+ # https://issues.apache.org/jira/browse/ARROW-6882
+ indices = pa.array([0, 1, 2, 0, 1, 2])
+ dictionary = pa.array(['foo', 'bar', 'baz'])
+ arr = pa.DictionaryArray.from_arrays(indices, dictionary)
+ arr.indices.validate(full=True)
+
+
+@pytest.mark.parametrize(('list_array_type', 'list_type_factory'),
+ [(pa.ListArray, pa.list_),
+ (pa.LargeListArray, pa.large_list)])
+def test_list_from_arrays(list_array_type, list_type_factory):
+ offsets_arr = np.array([0, 2, 5, 8], dtype='i4')
+ offsets = pa.array(offsets_arr, type='int32')
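+    # The offsets delimit each list: values[0:2], values[2:5] and values[5:8]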
+ pyvalues = [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h']
+ values = pa.array(pyvalues, type='binary')
+
+ result = list_array_type.from_arrays(offsets, values)
+ expected = pa.array([pyvalues[:2], pyvalues[2:5], pyvalues[5:8]],
+ type=list_type_factory(pa.binary()))
+
+ assert result.equals(expected)
+
+ # With nulls
+ offsets = [0, None, 2, 6]
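+    # A None offset yields a null list slot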
+ values = [b'a', b'b', b'c', b'd', b'e', b'f']
+
+ result = list_array_type.from_arrays(offsets, values)
+ expected = pa.array([values[:2], None, values[2:]],
+ type=list_type_factory(pa.binary()))
+
+ assert result.equals(expected)
+
+ # Another edge case
+ offsets2 = [0, 2, None, 6]
+ result = list_array_type.from_arrays(offsets2, values)
+ expected = pa.array([values[:2], values[2:], None],
+ type=list_type_factory(pa.binary()))
+ assert result.equals(expected)
+
+ # raise on invalid array
+ offsets = [1, 3, 10]
+ values = np.arange(5)
+ with pytest.raises(ValueError):
+ list_array_type.from_arrays(offsets, values)
+
+ # Non-monotonic offsets
+ offsets = [0, 3, 2, 6]
+ values = list(range(6))
+ result = list_array_type.from_arrays(offsets, values)
+ with pytest.raises(ValueError):
+ result.validate(full=True)
+
+
+def test_map_from_arrays():
+ offsets_arr = np.array([0, 2, 5, 8], dtype='i4')
+ offsets = pa.array(offsets_arr, type='int32')
+ pykeys = [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h']
+ pyitems = list(range(len(pykeys)))
+ pypairs = list(zip(pykeys, pyitems))
+ pyentries = [pypairs[:2], pypairs[2:5], pypairs[5:8]]
+ keys = pa.array(pykeys, type='binary')
+ items = pa.array(pyitems, type='i4')
+
+ result = pa.MapArray.from_arrays(offsets, keys, items)
+ expected = pa.array(pyentries, type=pa.map_(pa.binary(), pa.int32()))
+
+ assert result.equals(expected)
+
+ # With nulls
+ offsets = [0, None, 2, 6]
+ pykeys = [b'a', b'b', b'c', b'd', b'e', b'f']
+ pyitems = [1, 2, 3, None, 4, 5]
+ pypairs = list(zip(pykeys, pyitems))
+ pyentries = [pypairs[:2], None, pypairs[2:]]
+ keys = pa.array(pykeys, type='binary')
+ items = pa.array(pyitems, type='i4')
+
+ result = pa.MapArray.from_arrays(offsets, keys, items)
+ expected = pa.array(pyentries, type=pa.map_(pa.binary(), pa.int32()))
+
+ assert result.equals(expected)
+
+ # check invalid usage
+
+ offsets = [0, 1, 3, 5]
+ keys = np.arange(5)
+ items = np.arange(5)
+ _ = pa.MapArray.from_arrays(offsets, keys, items)
+
+ # raise on invalid offsets
+ with pytest.raises(ValueError):
+ pa.MapArray.from_arrays(offsets + [6], keys, items)
+
+ # raise on length of keys != items
+ with pytest.raises(ValueError):
+ pa.MapArray.from_arrays(offsets, keys, np.concatenate([items, items]))
+
+ # raise on keys with null
+ keys_with_null = list(keys)[:-1] + [None]
+ assert len(keys_with_null) == len(items)
+ with pytest.raises(ValueError):
+ pa.MapArray.from_arrays(offsets, keys_with_null, items)
+
+
+def test_fixed_size_list_from_arrays():
+ values = pa.array(range(12), pa.int64())
+ result = pa.FixedSizeListArray.from_arrays(values, 4)
+ assert result.to_pylist() == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
+ assert result.type.equals(pa.list_(pa.int64(), 4))
+
+ # raise on invalid values / list_size
+ with pytest.raises(ValueError):
+ pa.FixedSizeListArray.from_arrays(values, -4)
+
+ with pytest.raises(ValueError):
+ # array with list size 0 cannot be constructed with from_arrays
+ pa.FixedSizeListArray.from_arrays(pa.array([], pa.int64()), 0)
+
+ with pytest.raises(ValueError):
+ # length of values not multiple of 5
+ pa.FixedSizeListArray.from_arrays(values, 5)
+
+
+def test_variable_list_from_arrays():
+ values = pa.array([1, 2, 3, 4], pa.int64())
+ offsets = pa.array([0, 2, 4])
+ result = pa.ListArray.from_arrays(offsets, values)
+ assert result.to_pylist() == [[1, 2], [3, 4]]
+ assert result.type.equals(pa.list_(pa.int64()))
+
+ offsets = pa.array([0, None, 2, 4])
+ result = pa.ListArray.from_arrays(offsets, values)
+ assert result.to_pylist() == [[1, 2], None, [3, 4]]
+
+ # raise if offset out of bounds
+ with pytest.raises(ValueError):
+ pa.ListArray.from_arrays(pa.array([-1, 2, 4]), values)
+
+ with pytest.raises(ValueError):
+ pa.ListArray.from_arrays(pa.array([0, 2, 5]), values)
+
+
+def test_union_from_dense():
+ binary = pa.array([b'a', b'b', b'c', b'd'], type='binary')
+ int64 = pa.array([1, 2, 3], type='int64')
+ types = pa.array([0, 1, 0, 0, 1, 1, 0], type='int8')
+ logical_types = pa.array([11, 13, 11, 11, 13, 13, 11], type='int8')
+ value_offsets = pa.array([0, 0, 1, 2, 1, 2, 3], type='int32')
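+    # Dense layout: each type id selects a child array and the matching
+    # value offset indexes into that child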
+ py_value = [b'a', 1, b'b', b'c', 2, 3, b'd']
+
+ def check_result(result, expected_field_names, expected_type_codes,
+ expected_type_code_values):
+ result.validate(full=True)
+ actual_field_names = [result.type[i].name
+ for i in range(result.type.num_fields)]
+ assert actual_field_names == expected_field_names
+ assert result.type.mode == "dense"
+ assert result.type.type_codes == expected_type_codes
+ assert result.to_pylist() == py_value
+ assert expected_type_code_values.equals(result.type_codes)
+ assert value_offsets.equals(result.offsets)
+ assert result.field(0).equals(binary)
+ assert result.field(1).equals(int64)
+ with pytest.raises(KeyError):
+ result.field(-1)
+ with pytest.raises(KeyError):
+ result.field(2)
+
+ # without field names and type codes
+ check_result(pa.UnionArray.from_dense(types, value_offsets,
+ [binary, int64]),
+ expected_field_names=['0', '1'],
+ expected_type_codes=[0, 1],
+ expected_type_code_values=types)
+
+ # with field names
+ check_result(pa.UnionArray.from_dense(types, value_offsets,
+ [binary, int64],
+ ['bin', 'int']),
+ expected_field_names=['bin', 'int'],
+ expected_type_codes=[0, 1],
+ expected_type_code_values=types)
+
+ # with type codes
+ check_result(pa.UnionArray.from_dense(logical_types, value_offsets,
+ [binary, int64],
+ type_codes=[11, 13]),
+ expected_field_names=['0', '1'],
+ expected_type_codes=[11, 13],
+ expected_type_code_values=logical_types)
+
+ # with field names and type codes
+ check_result(pa.UnionArray.from_dense(logical_types, value_offsets,
+ [binary, int64],
+ ['bin', 'int'], [11, 13]),
+ expected_field_names=['bin', 'int'],
+ expected_type_codes=[11, 13],
+ expected_type_code_values=logical_types)
+
+ # Bad type ids
+ arr = pa.UnionArray.from_dense(logical_types, value_offsets,
+ [binary, int64])
+ with pytest.raises(pa.ArrowInvalid):
+ arr.validate(full=True)
+ arr = pa.UnionArray.from_dense(types, value_offsets, [binary, int64],
+ type_codes=[11, 13])
+ with pytest.raises(pa.ArrowInvalid):
+ arr.validate(full=True)
+
+ # Offset larger than child size
+ bad_offsets = pa.array([0, 0, 1, 2, 1, 2, 4], type='int32')
+ arr = pa.UnionArray.from_dense(types, bad_offsets, [binary, int64])
+ with pytest.raises(pa.ArrowInvalid):
+ arr.validate(full=True)
+
+
+def test_union_from_sparse():
+ binary = pa.array([b'a', b' ', b'b', b'c', b' ', b' ', b'd'],
+ type='binary')
+ int64 = pa.array([0, 1, 0, 0, 2, 3, 0], type='int64')
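+    # Sparse layout: every child has the same length as the union, with
+    # placeholder values in the slots that belong to the other child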
+ types = pa.array([0, 1, 0, 0, 1, 1, 0], type='int8')
+ logical_types = pa.array([11, 13, 11, 11, 13, 13, 11], type='int8')
+ py_value = [b'a', 1, b'b', b'c', 2, 3, b'd']
+
+ def check_result(result, expected_field_names, expected_type_codes,
+ expected_type_code_values):
+ result.validate(full=True)
+ assert result.to_pylist() == py_value
+ actual_field_names = [result.type[i].name
+ for i in range(result.type.num_fields)]
+ assert actual_field_names == expected_field_names
+ assert result.type.mode == "sparse"
+ assert result.type.type_codes == expected_type_codes
+ assert expected_type_code_values.equals(result.type_codes)
+ assert result.field(0).equals(binary)
+ assert result.field(1).equals(int64)
+ with pytest.raises(pa.ArrowTypeError):
+ result.offsets
+ with pytest.raises(KeyError):
+ result.field(-1)
+ with pytest.raises(KeyError):
+ result.field(2)
+
+ # without field names and type codes
+ check_result(pa.UnionArray.from_sparse(types, [binary, int64]),
+ expected_field_names=['0', '1'],
+ expected_type_codes=[0, 1],
+ expected_type_code_values=types)
+
+ # with field names
+ check_result(pa.UnionArray.from_sparse(types, [binary, int64],
+ ['bin', 'int']),
+ expected_field_names=['bin', 'int'],
+ expected_type_codes=[0, 1],
+ expected_type_code_values=types)
+
+ # with type codes
+ check_result(pa.UnionArray.from_sparse(logical_types, [binary, int64],
+ type_codes=[11, 13]),
+ expected_field_names=['0', '1'],
+ expected_type_codes=[11, 13],
+ expected_type_code_values=logical_types)
+
+ # with field names and type codes
+ check_result(pa.UnionArray.from_sparse(logical_types, [binary, int64],
+ ['bin', 'int'],
+ [11, 13]),
+ expected_field_names=['bin', 'int'],
+ expected_type_codes=[11, 13],
+ expected_type_code_values=logical_types)
+
+ # Bad type ids
+ arr = pa.UnionArray.from_sparse(logical_types, [binary, int64])
+ with pytest.raises(pa.ArrowInvalid):
+ arr.validate(full=True)
+ arr = pa.UnionArray.from_sparse(types, [binary, int64],
+ type_codes=[11, 13])
+ with pytest.raises(pa.ArrowInvalid):
+ arr.validate(full=True)
+
+ # Invalid child length
+ with pytest.raises(pa.ArrowInvalid):
+ arr = pa.UnionArray.from_sparse(logical_types, [binary, int64[1:]])
+
+
+def test_union_array_to_pylist_with_nulls():
+ # ARROW-9556
+ arr = pa.UnionArray.from_sparse(
+ pa.array([0, 1, 0, 0, 1], type=pa.int8()),
+ [
+ pa.array([0.0, 1.1, None, 3.3, 4.4]),
+ pa.array([True, None, False, True, False]),
+ ]
+ )
+ assert arr.to_pylist() == [0.0, None, None, 3.3, False]
+
+ arr = pa.UnionArray.from_dense(
+ pa.array([0, 1, 0, 0, 0, 1, 1], type=pa.int8()),
+ pa.array([0, 0, 1, 2, 3, 1, 2], type=pa.int32()),
+ [
+ pa.array([0.0, 1.1, None, 3.3]),
+ pa.array([True, None, False])
+ ]
+ )
+ assert arr.to_pylist() == [0.0, True, 1.1, None, 3.3, None, False]
+
+
+def test_union_array_slice():
+ # ARROW-2314
+ arr = pa.UnionArray.from_sparse(pa.array([0, 0, 1, 1], type=pa.int8()),
+ [pa.array(["a", "b", "c", "d"]),
+ pa.array([1, 2, 3, 4])])
+ assert arr[1:].to_pylist() == ["b", 3, 4]
+
+ binary = pa.array([b'a', b'b', b'c', b'd'], type='binary')
+ int64 = pa.array([1, 2, 3], type='int64')
+ types = pa.array([0, 1, 0, 0, 1, 1, 0], type='int8')
+ value_offsets = pa.array([0, 0, 2, 1, 1, 2, 3], type='int32')
+
+ arr = pa.UnionArray.from_dense(types, value_offsets, [binary, int64])
+ lst = arr.to_pylist()
+ for i in range(len(arr)):
+ for j in range(i, len(arr)):
+ assert arr[i:j].to_pylist() == lst[i:j]
+
+
+def _check_cast_case(case, *, safe=True, check_array_construction=True):
+ in_data, in_type, out_data, out_type = case
+ if isinstance(out_data, pa.Array):
+ assert out_data.type == out_type
+ expected = out_data
+ else:
+ expected = pa.array(out_data, type=out_type)
+
+ # check casting an already created array
+ if isinstance(in_data, pa.Array):
+ assert in_data.type == in_type
+ in_arr = in_data
+ else:
+ in_arr = pa.array(in_data, type=in_type)
+ casted = in_arr.cast(out_type, safe=safe)
+ casted.validate(full=True)
+ assert casted.equals(expected)
+
+ # constructing an array with out type which optionally involves casting
+ # for more see ARROW-1949
+ if check_array_construction:
+ in_arr = pa.array(in_data, type=out_type, safe=safe)
+ assert in_arr.equals(expected)
+
+
+def test_cast_integers_safe():
+ safe_cases = [
+ (np.array([0, 1, 2, 3], dtype='i1'), 'int8',
+ np.array([0, 1, 2, 3], dtype='i4'), pa.int32()),
+ (np.array([0, 1, 2, 3], dtype='i1'), 'int8',
+ np.array([0, 1, 2, 3], dtype='u4'), pa.uint16()),
+ (np.array([0, 1, 2, 3], dtype='i1'), 'int8',
+ np.array([0, 1, 2, 3], dtype='u1'), pa.uint8()),
+ (np.array([0, 1, 2, 3], dtype='i1'), 'int8',
+ np.array([0, 1, 2, 3], dtype='f8'), pa.float64())
+ ]
+
+ for case in safe_cases:
+ _check_cast_case(case)
+
+ unsafe_cases = [
+ (np.array([50000], dtype='i4'), 'int32', 'int16'),
+ (np.array([70000], dtype='i4'), 'int32', 'uint16'),
+ (np.array([-1], dtype='i4'), 'int32', 'uint16'),
+ (np.array([50000], dtype='u2'), 'uint16', 'int16')
+ ]
+ for in_data, in_type, out_type in unsafe_cases:
+ in_arr = pa.array(in_data, type=in_type)
+
+ with pytest.raises(pa.ArrowInvalid):
+ in_arr.cast(out_type)
+
+
+def test_cast_none():
+ # ARROW-3735: Ensure that calling cast(None) doesn't segfault.
+ arr = pa.array([1, 2, 3])
+
+ with pytest.raises(ValueError):
+ arr.cast(None)
+
+
+def test_cast_list_to_primitive():
+ # ARROW-8070: cast segfaults on unsupported cast from list<binary> to utf8
+ arr = pa.array([[1, 2], [3, 4]])
+ with pytest.raises(NotImplementedError):
+ arr.cast(pa.int8())
+
+ arr = pa.array([[b"a", b"b"], [b"c"]], pa.list_(pa.binary()))
+ with pytest.raises(NotImplementedError):
+ arr.cast(pa.binary())
+
+
+def test_slice_chunked_array_zero_chunks():
+ # ARROW-8911
+ arr = pa.chunked_array([], type='int8')
+ assert arr.num_chunks == 0
+
+ result = arr[:]
+ assert result.equals(arr)
+
+ # Do not crash
+ arr[:5]
+
+
+def test_cast_chunked_array():
+ arrays = [pa.array([1, 2, 3]), pa.array([4, 5, 6])]
+ carr = pa.chunked_array(arrays)
+
+ target = pa.float64()
+ casted = carr.cast(target)
+ expected = pa.chunked_array([x.cast(target) for x in arrays])
+ assert casted.equals(expected)
+
+
+def test_cast_chunked_array_empty():
+ # ARROW-8142
+ for typ1, typ2 in [(pa.dictionary(pa.int8(), pa.string()), pa.string()),
+ (pa.int64(), pa.int32())]:
+
+ arr = pa.chunked_array([], type=typ1)
+ result = arr.cast(typ2)
+ expected = pa.chunked_array([], type=typ2)
+ assert result.equals(expected)
+
+
+def test_chunked_array_data_warns():
+ with pytest.warns(FutureWarning):
+ res = pa.chunked_array([[]]).data
+ assert isinstance(res, pa.ChunkedArray)
+
+
+def test_cast_integers_unsafe():
+ # We let NumPy do the unsafe casting
+ unsafe_cases = [
+ (np.array([50000], dtype='i4'), 'int32',
+ np.array([50000], dtype='i2'), pa.int16()),
+ (np.array([70000], dtype='i4'), 'int32',
+ np.array([70000], dtype='u2'), pa.uint16()),
+ (np.array([-1], dtype='i4'), 'int32',
+ np.array([-1], dtype='u2'), pa.uint16()),
+ (np.array([50000], dtype='u2'), pa.uint16(),
+ np.array([50000], dtype='i2'), pa.int16())
+ ]
+
+ for case in unsafe_cases:
+ _check_cast_case(case, safe=False)
+
+
+def test_floating_point_truncate_safe():
+ safe_cases = [
+ (np.array([1.0, 2.0, 3.0], dtype='float32'), 'float32',
+ np.array([1, 2, 3], dtype='i4'), pa.int32()),
+ (np.array([1.0, 2.0, 3.0], dtype='float64'), 'float64',
+ np.array([1, 2, 3], dtype='i4'), pa.int32()),
+ (np.array([-10.0, 20.0, -30.0], dtype='float64'), 'float64',
+ np.array([-10, 20, -30], dtype='i4'), pa.int32()),
+ ]
+ for case in safe_cases:
+ _check_cast_case(case, safe=True)
+
+
+def test_floating_point_truncate_unsafe():
+ unsafe_cases = [
+ (np.array([1.1, 2.2, 3.3], dtype='float32'), 'float32',
+ np.array([1, 2, 3], dtype='i4'), pa.int32()),
+ (np.array([1.1, 2.2, 3.3], dtype='float64'), 'float64',
+ np.array([1, 2, 3], dtype='i4'), pa.int32()),
+ (np.array([-10.1, 20.2, -30.3], dtype='float64'), 'float64',
+ np.array([-10, 20, -30], dtype='i4'), pa.int32()),
+ ]
+ for case in unsafe_cases:
+ # test safe casting raises
+ with pytest.raises(pa.ArrowInvalid, match='truncated'):
+ _check_cast_case(case, safe=True)
+
+ # test unsafe casting truncates
+ _check_cast_case(case, safe=False)
+
+
+def test_decimal_to_int_safe():
+ safe_cases = [
+ (
+ [decimal.Decimal("123456"), None, decimal.Decimal("-912345")],
+ pa.decimal128(32, 5),
+ [123456, None, -912345],
+ pa.int32()
+ ),
+ (
+ [decimal.Decimal("1234"), None, decimal.Decimal("-9123")],
+ pa.decimal128(19, 10),
+ [1234, None, -9123],
+ pa.int16()
+ ),
+ (
+ [decimal.Decimal("123"), None, decimal.Decimal("-91")],
+ pa.decimal128(19, 10),
+ [123, None, -91],
+ pa.int8()
+ ),
+ ]
+ for case in safe_cases:
+ _check_cast_case(case)
+ _check_cast_case(case, safe=True)
+
+
+def test_decimal_to_int_value_out_of_bounds():
+ out_of_bounds_cases = [
+ (
+ np.array([
+ decimal.Decimal("1234567890123"),
+ None,
+ decimal.Decimal("-912345678901234")
+ ]),
+ pa.decimal128(32, 5),
+ [1912276171, None, -135950322],
+ pa.int32()
+ ),
+ (
+ [decimal.Decimal("123456"), None, decimal.Decimal("-912345678")],
+ pa.decimal128(32, 5),
+ [-7616, None, -19022],
+ pa.int16()
+ ),
+ (
+ [decimal.Decimal("1234"), None, decimal.Decimal("-9123")],
+ pa.decimal128(32, 5),
+ [-46, None, 93],
+ pa.int8()
+ ),
+ ]
+
+ for case in out_of_bounds_cases:
+ # test safe casting raises
+ with pytest.raises(pa.ArrowInvalid,
+ match='Integer value out of bounds'):
+ _check_cast_case(case)
+
+ # XXX `safe=False` can be ignored when constructing an array
+ # from a sequence of Python objects (ARROW-8567)
+ _check_cast_case(case, safe=False, check_array_construction=False)
+
+
+def test_decimal_to_int_non_integer():
+ non_integer_cases = [
+ (
+ [
+ decimal.Decimal("123456.21"),
+ None,
+ decimal.Decimal("-912345.13")
+ ],
+ pa.decimal128(32, 5),
+ [123456, None, -912345],
+ pa.int32()
+ ),
+ (
+ [decimal.Decimal("1234.134"), None, decimal.Decimal("-9123.1")],
+ pa.decimal128(19, 10),
+ [1234, None, -9123],
+ pa.int16()
+ ),
+ (
+ [decimal.Decimal("123.1451"), None, decimal.Decimal("-91.21")],
+ pa.decimal128(19, 10),
+ [123, None, -91],
+ pa.int8()
+ ),
+ ]
+
+ for case in non_integer_cases:
+ # test safe casting raises
+ msg_regexp = 'Rescaling Decimal128 value would cause data loss'
+ with pytest.raises(pa.ArrowInvalid, match=msg_regexp):
+ _check_cast_case(case)
+
+ _check_cast_case(case, safe=False)
+
+
+def test_decimal_to_decimal():
+ arr = pa.array(
+ [decimal.Decimal("1234.12"), None],
+ type=pa.decimal128(19, 10)
+ )
+ result = arr.cast(pa.decimal128(15, 6))
+ expected = pa.array(
+ [decimal.Decimal("1234.12"), None],
+ type=pa.decimal128(15, 6)
+ )
+ assert result.equals(expected)
+
+ msg_regexp = 'Rescaling Decimal128 value would cause data loss'
+ with pytest.raises(pa.ArrowInvalid, match=msg_regexp):
+ result = arr.cast(pa.decimal128(9, 1))
+
+ result = arr.cast(pa.decimal128(9, 1), safe=False)
+ expected = pa.array(
+ [decimal.Decimal("1234.1"), None],
+ type=pa.decimal128(9, 1)
+ )
+ assert result.equals(expected)
+
+ with pytest.raises(pa.ArrowInvalid,
+ match='Decimal value does not fit in precision'):
+ result = arr.cast(pa.decimal128(5, 2))
+
+
+def test_safe_cast_nan_to_int_raises():
+ arr = pa.array([np.nan, 1.])
+
+ with pytest.raises(pa.ArrowInvalid, match='truncated'):
+ arr.cast(pa.int64(), safe=True)
+
+
+def test_cast_signed_to_unsigned():
+ safe_cases = [
+ (np.array([0, 1, 2, 3], dtype='i1'), pa.uint8(),
+ np.array([0, 1, 2, 3], dtype='u1'), pa.uint8()),
+ (np.array([0, 1, 2, 3], dtype='i2'), pa.uint16(),
+ np.array([0, 1, 2, 3], dtype='u2'), pa.uint16())
+ ]
+
+ for case in safe_cases:
+ _check_cast_case(case)
+
+
+def test_cast_from_null():
+ in_data = [None] * 3
+ in_type = pa.null()
+ out_types = [
+ pa.null(),
+ pa.uint8(),
+ pa.float16(),
+ pa.utf8(),
+ pa.binary(),
+ pa.binary(10),
+ pa.list_(pa.int16()),
+ pa.list_(pa.int32(), 4),
+ pa.large_list(pa.uint8()),
+ pa.decimal128(19, 4),
+ pa.timestamp('us'),
+ pa.timestamp('us', tz='UTC'),
+ pa.timestamp('us', tz='Europe/Paris'),
+ pa.duration('us'),
+ pa.month_day_nano_interval(),
+ pa.struct([pa.field('a', pa.int32()),
+ pa.field('b', pa.list_(pa.int8())),
+ pa.field('c', pa.string())]),
+ pa.dictionary(pa.int32(), pa.string()),
+ ]
+ for out_type in out_types:
+ _check_cast_case((in_data, in_type, in_data, out_type))
+
+ out_types = [
+ pa.union([pa.field('a', pa.binary(10)),
+ pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
+ pa.union([pa.field('a', pa.binary(10)),
+ pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
+ ]
+ in_arr = pa.array(in_data, type=pa.null())
+ for out_type in out_types:
+ with pytest.raises(NotImplementedError):
+ in_arr.cast(out_type)
+
+
+def test_cast_string_to_number_roundtrip():
+ cases = [
+ (pa.array(["1", "127", "-128"]),
+ pa.array([1, 127, -128], type=pa.int8())),
+ (pa.array([None, "18446744073709551615"]),
+ pa.array([None, 18446744073709551615], type=pa.uint64())),
+ ]
+ for in_arr, expected in cases:
+ casted = in_arr.cast(expected.type, safe=True)
+ casted.validate(full=True)
+ assert casted.equals(expected)
+ casted_back = casted.cast(in_arr.type, safe=True)
+ casted_back.validate(full=True)
+ assert casted_back.equals(in_arr)
+
+
+def test_cast_dictionary():
+ # cast to the value type
+ arr = pa.array(
+ ["foo", "bar", None],
+ type=pa.dictionary(pa.int64(), pa.string())
+ )
+ expected = pa.array(["foo", "bar", None])
+ assert arr.type == pa.dictionary(pa.int64(), pa.string())
+ assert arr.cast(pa.string()) == expected
+
+ # cast to a different key type
+ for key_type in [pa.int8(), pa.int16(), pa.int32()]:
+ typ = pa.dictionary(key_type, pa.string())
+ expected = pa.array(
+ ["foo", "bar", None],
+ type=pa.dictionary(key_type, pa.string())
+ )
+ assert arr.cast(typ) == expected
+
+ # shouldn't crash (ARROW-7077)
+ with pytest.raises(pa.ArrowInvalid):
+ arr.cast(pa.int32())
+
+
+def test_view():
+ # ARROW-5992
+ arr = pa.array(['foo', 'bar', 'baz'], type=pa.utf8())
+ expected = pa.array(['foo', 'bar', 'baz'], type=pa.binary())
+
+ assert arr.view(pa.binary()).equals(expected)
+ assert arr.view('binary').equals(expected)
+
+
+def test_unique_simple():
+ cases = [
+ (pa.array([1, 2, 3, 1, 2, 3]), pa.array([1, 2, 3])),
+ (pa.array(['foo', None, 'bar', 'foo']),
+ pa.array(['foo', None, 'bar'])),
+ (pa.array(['foo', None, 'bar', 'foo'], pa.large_binary()),
+ pa.array(['foo', None, 'bar'], pa.large_binary())),
+ ]
+ for arr, expected in cases:
+ result = arr.unique()
+ assert result.equals(expected)
+ result = pa.chunked_array([arr]).unique()
+ assert result.equals(expected)
+
+
+def test_value_counts_simple():
+ cases = [
+ (pa.array([1, 2, 3, 1, 2, 3]),
+ pa.array([1, 2, 3]),
+ pa.array([2, 2, 2], type=pa.int64())),
+ (pa.array(['foo', None, 'bar', 'foo']),
+ pa.array(['foo', None, 'bar']),
+ pa.array([2, 1, 1], type=pa.int64())),
+ (pa.array(['foo', None, 'bar', 'foo'], pa.large_binary()),
+ pa.array(['foo', None, 'bar'], pa.large_binary()),
+ pa.array([2, 1, 1], type=pa.int64())),
+ ]
+ for arr, expected_values, expected_counts in cases:
+ for arr_in in (arr, pa.chunked_array([arr])):
+ result = arr_in.value_counts()
+ assert result.type.equals(
+ pa.struct([pa.field("values", arr.type),
+ pa.field("counts", pa.int64())]))
+ assert result.field("values").equals(expected_values)
+ assert result.field("counts").equals(expected_counts)
+
+
+def test_unique_value_counts_dictionary_type():
+ indices = pa.array([3, 0, 0, 0, 1, 1, 3, 0, 1, 3, 0, 1])
+ dictionary = pa.array(['foo', 'bar', 'baz', 'qux'])
+
+ arr = pa.DictionaryArray.from_arrays(indices, dictionary)
+
+ unique_result = arr.unique()
+ expected = pa.DictionaryArray.from_arrays(indices.unique(), dictionary)
+ assert unique_result.equals(expected)
+
+ result = arr.value_counts()
+ assert result.field('values').equals(unique_result)
+ assert result.field('counts').equals(pa.array([3, 5, 4], type='int64'))
+
+ arr = pa.DictionaryArray.from_arrays(
+ pa.array([], type='int64'), dictionary)
+ unique_result = arr.unique()
+ expected = pa.DictionaryArray.from_arrays(pa.array([], type='int64'),
+ pa.array([], type='utf8'))
+ assert unique_result.equals(expected)
+
+ result = arr.value_counts()
+ assert result.field('values').equals(unique_result)
+ assert result.field('counts').equals(pa.array([], type='int64'))
+
+
+def test_dictionary_encode_simple():
+ cases = [
+ (pa.array([1, 2, 3, None, 1, 2, 3]),
+ pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, 2, None, 0, 1, 2], type='int32'),
+ [1, 2, 3])),
+ (pa.array(['foo', None, 'bar', 'foo']),
+ pa.DictionaryArray.from_arrays(
+ pa.array([0, None, 1, 0], type='int32'),
+ ['foo', 'bar'])),
+ (pa.array(['foo', None, 'bar', 'foo'], type=pa.large_binary()),
+ pa.DictionaryArray.from_arrays(
+ pa.array([0, None, 1, 0], type='int32'),
+ pa.array(['foo', 'bar'], type=pa.large_binary()))),
+ ]
+ for arr, expected in cases:
+ result = arr.dictionary_encode()
+ assert result.equals(expected)
+ result = pa.chunked_array([arr]).dictionary_encode()
+ assert result.num_chunks == 1
+ assert result.chunk(0).equals(expected)
+ result = pa.chunked_array([], type=arr.type).dictionary_encode()
+ assert result.num_chunks == 0
+ assert result.type == expected.type
+
+
+def test_dictionary_encode_sliced():
+ cases = [
+ (pa.array([1, 2, 3, None, 1, 2, 3])[1:-1],
+ pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, None, 2, 0], type='int32'),
+ [2, 3, 1])),
+ (pa.array([None, 'foo', 'bar', 'foo', 'xyzzy'])[1:-1],
+ pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, 0], type='int32'),
+ ['foo', 'bar'])),
+ (pa.array([None, 'foo', 'bar', 'foo', 'xyzzy'],
+ type=pa.large_string())[1:-1],
+ pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, 0], type='int32'),
+ pa.array(['foo', 'bar'], type=pa.large_string()))),
+ ]
+ for arr, expected in cases:
+ result = arr.dictionary_encode()
+ assert result.equals(expected)
+ result = pa.chunked_array([arr]).dictionary_encode()
+ assert result.num_chunks == 1
+ assert result.type == expected.type
+ assert result.chunk(0).equals(expected)
+ result = pa.chunked_array([], type=arr.type).dictionary_encode()
+ assert result.num_chunks == 0
+ assert result.type == expected.type
+
+ # ARROW-9143 dictionary_encode after slice was segfaulting
+ array = pa.array(['foo', 'bar', 'baz'])
+ array.slice(1).dictionary_encode()
+
+
+def test_dictionary_encode_zero_length():
+ # User-facing experience of ARROW-7008
+ arr = pa.array([], type=pa.string())
+ encoded = arr.dictionary_encode()
+ assert len(encoded.dictionary) == 0
+ encoded.validate(full=True)
+
+
+def test_dictionary_decode():
+ cases = [
+ (pa.array([1, 2, 3, None, 1, 2, 3]),
+ pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, 2, None, 0, 1, 2], type='int32'),
+ [1, 2, 3])),
+ (pa.array(['foo', None, 'bar', 'foo']),
+ pa.DictionaryArray.from_arrays(
+ pa.array([0, None, 1, 0], type='int32'),
+ ['foo', 'bar'])),
+ (pa.array(['foo', None, 'bar', 'foo'], type=pa.large_binary()),
+ pa.DictionaryArray.from_arrays(
+ pa.array([0, None, 1, 0], type='int32'),
+ pa.array(['foo', 'bar'], type=pa.large_binary()))),
+ ]
+ for expected, arr in cases:
+ result = arr.dictionary_decode()
+ assert result.equals(expected)
+
+
+def test_cast_time32_to_int():
+ arr = pa.array(np.array([0, 1, 2], dtype='int32'),
+ type=pa.time32('s'))
+ expected = pa.array([0, 1, 2], type='i4')
+
+ result = arr.cast('i4')
+ assert result.equals(expected)
+
+
+def test_cast_time64_to_int():
+ arr = pa.array(np.array([0, 1, 2], dtype='int64'),
+ type=pa.time64('us'))
+ expected = pa.array([0, 1, 2], type='i8')
+
+ result = arr.cast('i8')
+ assert result.equals(expected)
+
+
+def test_cast_timestamp_to_int():
+ arr = pa.array(np.array([0, 1, 2], dtype='int64'),
+ type=pa.timestamp('us'))
+ expected = pa.array([0, 1, 2], type='i8')
+
+ result = arr.cast('i8')
+ assert result.equals(expected)
+
+
+def test_cast_date32_to_int():
+ arr = pa.array([0, 1, 2], type='i4')
+
+ result1 = arr.cast('date32')
+ result2 = result1.cast('i4')
+
+ expected1 = pa.array([
+ datetime.date(1970, 1, 1),
+ datetime.date(1970, 1, 2),
+ datetime.date(1970, 1, 3)
+ ]).cast('date32')
+
+ assert result1.equals(expected1)
+ assert result2.equals(arr)
+
+
+def test_cast_duration_to_int():
+ arr = pa.array(np.array([0, 1, 2], dtype='int64'),
+ type=pa.duration('us'))
+ expected = pa.array([0, 1, 2], type='i8')
+
+ result = arr.cast('i8')
+ assert result.equals(expected)
+
+
+def test_cast_binary_to_utf8():
+ binary_arr = pa.array([b'foo', b'bar', b'baz'], type=pa.binary())
+ utf8_arr = binary_arr.cast(pa.utf8())
+ expected = pa.array(['foo', 'bar', 'baz'], type=pa.utf8())
+
+ assert utf8_arr.equals(expected)
+
+ non_utf8_values = [('mañana').encode('utf-16-le')]
+ non_utf8_binary = pa.array(non_utf8_values)
+ assert non_utf8_binary.type == pa.binary()
+ with pytest.raises(ValueError):
+ non_utf8_binary.cast(pa.string())
+
+ non_utf8_all_null = pa.array(non_utf8_values, mask=np.array([True]),
+ type=pa.binary())
+ # No error
+ casted = non_utf8_all_null.cast(pa.string())
+ assert casted.null_count == 1
+
+
+def test_cast_date64_to_int():
+ arr = pa.array(np.array([0, 1, 2], dtype='int64'),
+ type=pa.date64())
+ expected = pa.array([0, 1, 2], type='i8')
+
+ result = arr.cast('i8')
+
+ assert result.equals(expected)
+
+
+def test_date64_from_builtin_datetime():
+ val1 = datetime.datetime(2000, 1, 1, 12, 34, 56, 123456)
+ val2 = datetime.datetime(2000, 1, 1)
+ result = pa.array([val1, val2], type='date64')
+ result2 = pa.array([val1.date(), val2.date()], type='date64')
+
+ assert result.equals(result2)
+
+ as_i8 = result.view('int64')
+ assert as_i8[0].as_py() == as_i8[1].as_py()
+
+
+@pytest.mark.parametrize(('ty', 'values'), [
+ ('bool', [True, False, True]),
+ ('uint8', range(0, 255)),
+ ('int8', range(0, 128)),
+ ('uint16', range(0, 10)),
+ ('int16', range(0, 10)),
+ ('uint32', range(0, 10)),
+ ('int32', range(0, 10)),
+ ('uint64', range(0, 10)),
+ ('int64', range(0, 10)),
+ ('float', [0.0, 0.1, 0.2]),
+ ('double', [0.0, 0.1, 0.2]),
+ ('string', ['a', 'b', 'c']),
+ ('binary', [b'a', b'b', b'c']),
+ (pa.binary(3), [b'abc', b'bcd', b'cde'])
+])
+def test_cast_identities(ty, values):
+ arr = pa.array(values, type=ty)
+ assert arr.cast(ty).equals(arr)
+
+
+pickle_test_parametrize = pytest.mark.parametrize(
+ ('data', 'typ'),
+ [
+ ([True, False, True, True], pa.bool_()),
+ ([1, 2, 4, 6], pa.int64()),
+ ([1.0, 2.5, None], pa.float64()),
+ (['a', None, 'b'], pa.string()),
+ ([], None),
+ ([[1, 2], [3]], pa.list_(pa.int64())),
+ ([[4, 5], [6]], pa.large_list(pa.int16())),
+ ([['a'], None, ['b', 'c']], pa.list_(pa.string())),
+ ([(1, 'a'), (2, 'c'), None],
+ pa.struct([pa.field('a', pa.int64()), pa.field('b', pa.string())]))
+ ]
+)
+
+
+@pickle_test_parametrize
+def test_array_pickle(data, typ):
+ # Allocate here so that we don't have any Arrow data allocated.
+ # This is needed to ensure that allocator tests can be reliable.
+ array = pa.array(data, type=typ)
+ for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+ result = pickle.loads(pickle.dumps(array, proto))
+ assert array.equals(result)
+
+
+def test_array_pickle_dictionary():
+ # not included in the above as dictionary array cannot be created with
+ # the pa.array function
+ array = pa.DictionaryArray.from_arrays([0, 1, 2, 0, 1], ['a', 'b', 'c'])
+ for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+ result = pickle.loads(pickle.dumps(array, proto))
+ assert array.equals(result)
+
+
+@h.given(
+ past.arrays(
+ past.all_types,
+ size=st.integers(min_value=0, max_value=10)
+ )
+)
+def test_pickling(arr):
+ data = pickle.dumps(arr)
+ restored = pickle.loads(data)
+ assert arr.equals(restored)
+
+
+@pickle_test_parametrize
+def test_array_pickle5(data, typ):
+ # Test zero-copy pickling with protocol 5 (PEP 574)
+ picklemod = pickle5 or pickle
+ if pickle5 is None and picklemod.HIGHEST_PROTOCOL < 5:
+ pytest.skip("need pickle5 package or Python 3.8+")
+
+ array = pa.array(data, type=typ)
+ addresses = [buf.address if buf is not None else 0
+ for buf in array.buffers()]
+
+ for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
+ buffers = []
+ pickled = picklemod.dumps(array, proto, buffer_callback=buffers.append)
+ result = picklemod.loads(pickled, buffers=buffers)
+ assert array.equals(result)
+
+ result_addresses = [buf.address if buf is not None else 0
+ for buf in result.buffers()]
+ assert result_addresses == addresses
+
+
+@pytest.mark.parametrize(
+ 'narr',
+ [
+ np.arange(10, dtype=np.int64),
+ np.arange(10, dtype=np.int32),
+ np.arange(10, dtype=np.int16),
+ np.arange(10, dtype=np.int8),
+ np.arange(10, dtype=np.uint64),
+ np.arange(10, dtype=np.uint32),
+ np.arange(10, dtype=np.uint16),
+ np.arange(10, dtype=np.uint8),
+ np.arange(10, dtype=np.float64),
+ np.arange(10, dtype=np.float32),
+ np.arange(10, dtype=np.float16),
+ ]
+)
+def test_to_numpy_roundtrip(narr):
+ arr = pa.array(narr)
+ assert narr.dtype == arr.to_numpy().dtype
+ np.testing.assert_array_equal(narr, arr.to_numpy())
+ np.testing.assert_array_equal(narr[:6], arr[:6].to_numpy())
+ np.testing.assert_array_equal(narr[2:], arr[2:].to_numpy())
+ np.testing.assert_array_equal(narr[2:6], arr[2:6].to_numpy())
+
+
+def test_array_uint64_from_py_over_range():
+ arr = pa.array([2 ** 63], type=pa.uint64())
+ expected = pa.array(np.array([2 ** 63], dtype='u8'))
+ assert arr.equals(expected)
+
+
+def test_array_conversions_no_sentinel_values():
+ arr = np.array([1, 2, 3, 4], dtype='int8')
+ refcount = sys.getrefcount(arr)
+ arr2 = pa.array(arr) # noqa
+ assert sys.getrefcount(arr) == (refcount + 1)
+
+ assert arr2.type == 'int8'
+
+ arr3 = pa.array(np.array([1, np.nan, 2, 3, np.nan, 4], dtype='float32'),
+ type='float32')
+ assert arr3.type == 'float32'
+ assert arr3.null_count == 0
+
+
+def test_time32_time64_from_integer():
+ # ARROW-4111
+ result = pa.array([1, 2, None], type=pa.time32('s'))
+ expected = pa.array([datetime.time(second=1),
+ datetime.time(second=2), None],
+ type=pa.time32('s'))
+ assert result.equals(expected)
+
+ result = pa.array([1, 2, None], type=pa.time32('ms'))
+ expected = pa.array([datetime.time(microsecond=1000),
+ datetime.time(microsecond=2000), None],
+ type=pa.time32('ms'))
+ assert result.equals(expected)
+
+ result = pa.array([1, 2, None], type=pa.time64('us'))
+ expected = pa.array([datetime.time(microsecond=1),
+ datetime.time(microsecond=2), None],
+ type=pa.time64('us'))
+ assert result.equals(expected)
+
+ result = pa.array([1000, 2000, None], type=pa.time64('ns'))
+ expected = pa.array([datetime.time(microsecond=1),
+ datetime.time(microsecond=2), None],
+ type=pa.time64('ns'))
+ assert result.equals(expected)
+
+
+def test_binary_string_pandas_null_sentinels():
+ # ARROW-6227
+ def _check_case(ty):
+ arr = pa.array(['string', np.nan], type=ty, from_pandas=True)
+ expected = pa.array(['string', None], type=ty)
+ assert arr.equals(expected)
+ _check_case('binary')
+ _check_case('utf8')
+
+
+def test_pandas_null_sentinels_raise_error():
+ # ARROW-6227
+ cases = [
+ ([None, np.nan], 'null'),
+ (['string', np.nan], 'binary'),
+ (['string', np.nan], 'utf8'),
+ (['string', np.nan], 'large_binary'),
+ (['string', np.nan], 'large_utf8'),
+ ([b'string', np.nan], pa.binary(6)),
+ ([True, np.nan], pa.bool_()),
+ ([decimal.Decimal('0'), np.nan], pa.decimal128(12, 2)),
+ ([0, np.nan], pa.date32()),
+ ([0, np.nan], pa.date64()),
+ ([0, np.nan], pa.time32('s')),
+ ([0, np.nan], pa.time64('us')),
+ ([0, np.nan], pa.timestamp('us')),
+ ([0, np.nan], pa.duration('us')),
+ ]
+ for case, ty in cases:
+        # Depending on the type, either ValueError or TypeError is raised.
+        # May want to clean that up.
+ with pytest.raises((ValueError, TypeError)):
+ pa.array(case, type=ty)
+
+ # from_pandas option suppresses failure
+ result = pa.array(case, type=ty, from_pandas=True)
+ assert result.null_count == (1 if ty != 'null' else 2)
+
+
+@pytest.mark.pandas
+def test_pandas_null_sentinels_index():
+ # ARROW-7023 - ensure that when passing a pandas Index, "from_pandas"
+ # semantics are used
+ import pandas as pd
+ idx = pd.Index([1, 2, np.nan], dtype=object)
+ result = pa.array(idx)
+ expected = pa.array([1, 2, np.nan], from_pandas=True)
+ assert result.equals(expected)
+
+
+def test_array_from_numpy_datetimeD():
+ arr = np.array([None, datetime.date(2017, 4, 4)], dtype='datetime64[D]')
+
+ result = pa.array(arr)
+ expected = pa.array([None, datetime.date(2017, 4, 4)], type=pa.date32())
+ assert result.equals(expected)
+
+
+def test_array_from_naive_datetimes():
+ arr = pa.array([
+ None,
+ datetime.datetime(2017, 4, 4, 12, 11, 10),
+ datetime.datetime(2018, 1, 1, 0, 2, 0)
+ ])
+ assert arr.type == pa.timestamp('us', tz=None)
+
+
+@pytest.mark.parametrize(('dtype', 'type'), [
+ ('datetime64[s]', pa.timestamp('s')),
+ ('datetime64[ms]', pa.timestamp('ms')),
+ ('datetime64[us]', pa.timestamp('us')),
+ ('datetime64[ns]', pa.timestamp('ns'))
+])
+def test_array_from_numpy_datetime(dtype, type):
+ data = [
+ None,
+ datetime.datetime(2017, 4, 4, 12, 11, 10),
+ datetime.datetime(2018, 1, 1, 0, 2, 0)
+ ]
+
+ # from numpy array
+ arr = pa.array(np.array(data, dtype=dtype))
+ expected = pa.array(data, type=type)
+ assert arr.equals(expected)
+
+ # from list of numpy scalars
+ arr = pa.array(list(np.array(data, dtype=dtype)))
+ assert arr.equals(expected)
+
+
+def test_array_from_different_numpy_datetime_units_raises():
+ data = [
+ None,
+ datetime.datetime(2017, 4, 4, 12, 11, 10),
+ datetime.datetime(2018, 1, 1, 0, 2, 0)
+ ]
+ s = np.array(data, dtype='datetime64[s]')
+ ms = np.array(data, dtype='datetime64[ms]')
+ data = list(s[:2]) + list(ms[2:])
+
+ with pytest.raises(pa.ArrowNotImplementedError):
+ pa.array(data)
+
+
+@pytest.mark.parametrize('unit', ['ns', 'us', 'ms', 's'])
+def test_array_from_list_of_timestamps(unit):
+ n = np.datetime64('NaT', unit)
+ x = np.datetime64('2017-01-01 01:01:01.111111111', unit)
+ y = np.datetime64('2018-11-22 12:24:48.111111111', unit)
+
+ a1 = pa.array([n, x, y])
+ a2 = pa.array([n, x, y], type=pa.timestamp(unit))
+
+ assert a1.type == a2.type
+ assert a1.type.unit == unit
+ assert a1[0] == a2[0]
+
+
+def test_array_from_timestamp_with_generic_unit():
+ n = np.datetime64('NaT')
+ x = np.datetime64('2017-01-01 01:01:01.111111111')
+ y = np.datetime64('2018-11-22 12:24:48.111111111')
+
+ with pytest.raises(pa.ArrowNotImplementedError,
+ match='Unbound or generic datetime64 time unit'):
+ pa.array([n, x, y])
+
+
+@pytest.mark.parametrize(('dtype', 'type'), [
+ ('timedelta64[s]', pa.duration('s')),
+ ('timedelta64[ms]', pa.duration('ms')),
+ ('timedelta64[us]', pa.duration('us')),
+ ('timedelta64[ns]', pa.duration('ns'))
+])
+def test_array_from_numpy_timedelta(dtype, type):
+ data = [
+ None,
+ datetime.timedelta(1),
+ datetime.timedelta(0, 1)
+ ]
+
+ # from numpy array
+ np_arr = np.array(data, dtype=dtype)
+ arr = pa.array(np_arr)
+ assert isinstance(arr, pa.DurationArray)
+ assert arr.type == type
+ expected = pa.array(data, type=type)
+ assert arr.equals(expected)
+ assert arr.to_pylist() == data
+
+ # from list of numpy scalars
+ arr = pa.array(list(np.array(data, dtype=dtype)))
+ assert arr.equals(expected)
+ assert arr.to_pylist() == data
+
+
+def test_array_from_numpy_timedelta_incorrect_unit():
+ # generic (no unit)
+ td = np.timedelta64(1)
+
+ for data in [[td], np.array([td])]:
+ with pytest.raises(NotImplementedError):
+ pa.array(data)
+
+ # unsupported unit
+ td = np.timedelta64(1, 'M')
+ for data in [[td], np.array([td])]:
+ with pytest.raises(NotImplementedError):
+ pa.array(data)
+
+
+def test_array_from_numpy_ascii():
+ arr = np.array(['abcde', 'abc', ''], dtype='|S5')
+
+ arrow_arr = pa.array(arr)
+ assert arrow_arr.type == 'binary'
+ expected = pa.array(['abcde', 'abc', ''], type='binary')
+ assert arrow_arr.equals(expected)
+
+ mask = np.array([False, True, False])
+ arrow_arr = pa.array(arr, mask=mask)
+ expected = pa.array(['abcde', None, ''], type='binary')
+ assert arrow_arr.equals(expected)
+
+ # Strided variant
+ arr = np.array(['abcde', 'abc', ''] * 5, dtype='|S5')[::2]
+ mask = np.array([False, True, False] * 5)[::2]
+ arrow_arr = pa.array(arr, mask=mask)
+
+ expected = pa.array(['abcde', '', None, 'abcde', '', None, 'abcde', ''],
+ type='binary')
+ assert arrow_arr.equals(expected)
+
+ # 0 itemsize
+ arr = np.array(['', '', ''], dtype='|S0')
+ arrow_arr = pa.array(arr)
+ expected = pa.array(['', '', ''], type='binary')
+ assert arrow_arr.equals(expected)
+
+
+def test_interval_array_from_timedelta():
+ data = [
+ None,
+ datetime.timedelta(days=1, seconds=1, microseconds=1,
+ milliseconds=1, minutes=1, hours=1, weeks=1)]
+
+ # From timedelta (explicit type required)
+ arr = pa.array(data, pa.month_day_nano_interval())
+ assert isinstance(arr, pa.MonthDayNanoIntervalArray)
+ assert arr.type == pa.month_day_nano_interval()
+ expected_list = [
+ None,
+ pa.MonthDayNano([0, 8,
+ (datetime.timedelta(seconds=1, microseconds=1,
+ milliseconds=1, minutes=1,
+ hours=1) //
+ datetime.timedelta(microseconds=1)) * 1000])]
+ expected = pa.array(expected_list)
+ assert arr.equals(expected)
+ assert arr.to_pylist() == expected_list
+
+
+@pytest.mark.pandas
+def test_interval_array_from_relativedelta():
+ # dateutil is dependency of pandas
+ from dateutil.relativedelta import relativedelta
+ from pandas import DateOffset
+ data = [
+ None,
+ relativedelta(years=1, months=1,
+ days=1, seconds=1, microseconds=1,
+ minutes=1, hours=1, weeks=1, leapdays=1)]
+ # Note leapdays are ignored.
+
+ # From relativedelta
+ arr = pa.array(data)
+ assert isinstance(arr, pa.MonthDayNanoIntervalArray)
+ assert arr.type == pa.month_day_nano_interval()
+ expected_list = [
+ None,
+ pa.MonthDayNano([13, 8,
+ (datetime.timedelta(seconds=1, microseconds=1,
+ minutes=1, hours=1) //
+ datetime.timedelta(microseconds=1)) * 1000])]
+ expected = pa.array(expected_list)
+ assert arr.equals(expected)
+ assert arr.to_pandas().tolist() == [
+ None, DateOffset(months=13, days=8,
+ microseconds=(
+ datetime.timedelta(seconds=1, microseconds=1,
+ minutes=1, hours=1) //
+ datetime.timedelta(microseconds=1)),
+ nanoseconds=0)]
+ with pytest.raises(ValueError):
+ pa.array([DateOffset(years=((1 << 32) // 12), months=100)])
+ with pytest.raises(ValueError):
+ pa.array([DateOffset(weeks=((1 << 32) // 7), days=100)])
+ with pytest.raises(ValueError):
+ pa.array([DateOffset(seconds=((1 << 64) // 1000000000),
+ nanoseconds=1)])
+ with pytest.raises(ValueError):
+ pa.array([DateOffset(microseconds=((1 << 64) // 100))])
+
+
+@pytest.mark.pandas
+def test_interval_array_from_dateoffset():
+ from pandas.tseries.offsets import DateOffset
+ data = [
+ None,
+ DateOffset(years=1, months=1,
+ days=1, seconds=1, microseconds=1,
+ minutes=1, hours=1, weeks=1, nanoseconds=1),
+ DateOffset()]
+
+ arr = pa.array(data)
+ assert isinstance(arr, pa.MonthDayNanoIntervalArray)
+ assert arr.type == pa.month_day_nano_interval()
+ expected_list = [
+ None,
+ pa.MonthDayNano([13, 8, 3661000001001]),
+ pa.MonthDayNano([0, 0, 0])]
+ expected = pa.array(expected_list)
+ assert arr.equals(expected)
+ assert arr.to_pandas().tolist() == [
+ None, DateOffset(months=13, days=8,
+ microseconds=(
+ datetime.timedelta(seconds=1, microseconds=1,
+ minutes=1, hours=1) //
+ datetime.timedelta(microseconds=1)),
+ nanoseconds=1),
+ DateOffset(months=0, days=0, microseconds=0, nanoseconds=0)]
+
+
+def test_array_from_numpy_unicode():
+ dtypes = ['<U5', '>U5']
+
+ for dtype in dtypes:
+ arr = np.array(['abcde', 'abc', ''], dtype=dtype)
+
+ arrow_arr = pa.array(arr)
+ assert arrow_arr.type == 'utf8'
+ expected = pa.array(['abcde', 'abc', ''], type='utf8')
+ assert arrow_arr.equals(expected)
+
+ mask = np.array([False, True, False])
+ arrow_arr = pa.array(arr, mask=mask)
+ expected = pa.array(['abcde', None, ''], type='utf8')
+ assert arrow_arr.equals(expected)
+
+ # Strided variant
+ arr = np.array(['abcde', 'abc', ''] * 5, dtype=dtype)[::2]
+ mask = np.array([False, True, False] * 5)[::2]
+ arrow_arr = pa.array(arr, mask=mask)
+
+ expected = pa.array(['abcde', '', None, 'abcde', '', None,
+ 'abcde', ''], type='utf8')
+ assert arrow_arr.equals(expected)
+
+ # 0 itemsize
+ arr = np.array(['', '', ''], dtype='<U0')
+ arrow_arr = pa.array(arr)
+ expected = pa.array(['', '', ''], type='utf8')
+ assert arrow_arr.equals(expected)
+
+
+def test_array_string_from_non_string():
+ # ARROW-5682 - when converting to string raise on non string-like dtype
+ with pytest.raises(TypeError):
+ pa.array(np.array([1, 2, 3]), type=pa.string())
+
+
+def test_array_string_from_all_null():
+ # ARROW-5682
+ vals = np.array([None, None], dtype=object)
+ arr = pa.array(vals, type=pa.string())
+ assert arr.null_count == 2
+
+ vals = np.array([np.nan, np.nan], dtype='float64')
+ # by default raises, but accept as all-null when from_pandas=True
+ with pytest.raises(TypeError):
+ pa.array(vals, type=pa.string())
+ arr = pa.array(vals, type=pa.string(), from_pandas=True)
+ assert arr.null_count == 2
+
+
+def test_array_from_masked():
+ ma = np.ma.array([1, 2, 3, 4], dtype='int64',
+ mask=[False, False, True, False])
+ result = pa.array(ma)
+ expected = pa.array([1, 2, None, 4], type='int64')
+ assert expected.equals(result)
+
+ with pytest.raises(ValueError, match="Cannot pass a numpy masked array"):
+ pa.array(ma, mask=np.array([True, False, False, False]))
+
+
+def test_array_from_shrunken_masked():
+ ma = np.ma.array([0], dtype='int64')
+ result = pa.array(ma)
+ expected = pa.array([0], type='int64')
+ assert expected.equals(result)
+
+
+def test_array_from_invalid_dim_raises():
+ msg = "only handle 1-dimensional arrays"
+ arr2d = np.array([[1, 2, 3], [4, 5, 6]])
+ with pytest.raises(ValueError, match=msg):
+ pa.array(arr2d)
+
+ arr0d = np.array(0)
+ with pytest.raises(ValueError, match=msg):
+ pa.array(arr0d)
+
+
+def test_array_from_strided_bool():
+ # ARROW-6325
+ arr = np.ones((3, 2), dtype=bool)
+ result = pa.array(arr[:, 0])
+ expected = pa.array([True, True, True])
+ assert result.equals(expected)
+ result = pa.array(arr[0, :])
+ expected = pa.array([True, True])
+ assert result.equals(expected)
+
+
+def test_array_from_strided():
+ pydata = [
+ ([b"ab", b"cd", b"ef"], (pa.binary(), pa.binary(2))),
+ ([1, 2, 3], (pa.int8(), pa.int16(), pa.int32(), pa.int64())),
+ ([1.0, 2.0, 3.0], (pa.float32(), pa.float64())),
+ (["ab", "cd", "ef"], (pa.utf8(), ))
+ ]
+
+ for values, dtypes in pydata:
+ nparray = np.array(values)
+ for patype in dtypes:
+ for mask in (None, np.array([False, False])):
+ arrow_array = pa.array(nparray[::2], patype,
+ mask=mask)
+ assert values[::2] == arrow_array.to_pylist()
+
+
+def test_boolean_true_count_false_count():
+ # ARROW-9145
+ arr = pa.array([True, True, None, False, None, True] * 1000)
+ assert arr.true_count == 3000
+ assert arr.false_count == 1000
+
+
+def test_buffers_primitive():
+ a = pa.array([1, 2, None, 4], type=pa.int16())
+ buffers = a.buffers()
+ assert len(buffers) == 2
+ null_bitmap = buffers[0].to_pybytes()
+ assert 1 <= len(null_bitmap) <= 64 # XXX this is varying
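+ # Validity bits are LSB-first: elements 0, 1 and 3 are valid, element 2
+ # is null, hence 0b1011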
+ assert bytearray(null_bitmap)[0] == 0b00001011
+
+ # Slicing does not affect the buffers but the offset
+ a_sliced = a[1:]
+ buffers = a_sliced.buffers()
+ assert a_sliced.offset == 1
+ assert len(buffers) == 2
+ null_bitmap = buffers[0].to_pybytes()
+ assert 1 <= len(null_bitmap) <= 64 # XXX this is varying
+ assert bytearray(null_bitmap)[0] == 0b00001011
+
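+ # 'hhxxh': unpack the three valid int16 values; 'xx' skips the two-byte
+ # slot of the null element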
+ assert struct.unpack('hhxxh', buffers[1].to_pybytes()) == (1, 2, 4)
+
+ a = pa.array(np.int8([4, 5, 6]))
+ buffers = a.buffers()
+ assert len(buffers) == 2
+ # No null bitmap from Numpy int array
+ assert buffers[0] is None
+ assert struct.unpack('3b', buffers[1].to_pybytes()) == (4, 5, 6)
+
+ a = pa.array([b'foo!', None, b'bar!!'])
+ buffers = a.buffers()
+ assert len(buffers) == 3
+ null_bitmap = buffers[0].to_pybytes()
+ assert bytearray(null_bitmap)[0] == 0b00000101
+ offsets = buffers[1].to_pybytes()
+ assert struct.unpack('4i', offsets) == (0, 4, 4, 9)
+ values = buffers[2].to_pybytes()
+ assert values == b'foo!bar!!'
+
+
+def test_buffers_nested():
+ a = pa.array([[1, 2], None, [3, None, 4, 5]], type=pa.list_(pa.int64()))
+ buffers = a.buffers()
+ assert len(buffers) == 4
+ # The parent buffers
+ null_bitmap = buffers[0].to_pybytes()
+ assert bytearray(null_bitmap)[0] == 0b00000101
+ offsets = buffers[1].to_pybytes()
+ assert struct.unpack('4i', offsets) == (0, 2, 2, 6)
+ # The child buffers
+ null_bitmap = buffers[2].to_pybytes()
+ assert bytearray(null_bitmap)[0] == 0b00110111
+ values = buffers[3].to_pybytes()
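+ # 'qqq8xqq': unpack the five valid int64 child values; '8x' skips the
+ # eight-byte slot of the null child element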
+ assert struct.unpack('qqq8xqq', values) == (1, 2, 3, 4, 5)
+
+ a = pa.array([(42, None), None, (None, 43)],
+ type=pa.struct([pa.field('a', pa.int8()),
+ pa.field('b', pa.int16())]))
+ buffers = a.buffers()
+ assert len(buffers) == 5
+ # The parent buffer
+ null_bitmap = buffers[0].to_pybytes()
+ assert bytearray(null_bitmap)[0] == 0b00000101
+ # The child buffers: 'a'
+ null_bitmap = buffers[1].to_pybytes()
+ assert bytearray(null_bitmap)[0] == 0b00000011
+ values = buffers[2].to_pybytes()
+ assert struct.unpack('bxx', values) == (42,)
+ # The child buffers: 'b'
+ null_bitmap = buffers[3].to_pybytes()
+ assert bytearray(null_bitmap)[0] == 0b00000110
+ values = buffers[4].to_pybytes()
+ assert struct.unpack('4xh', values) == (43,)
+
+
+def test_nbytes_sizeof():
+ a = pa.array(np.array([4, 5, 6], dtype='int64'))
+ assert a.nbytes == 8 * 3
+ assert sys.getsizeof(a) >= object.__sizeof__(a) + a.nbytes
+ a = pa.array([1, None, 3], type='int64')
+ assert a.nbytes == 8*3 + 1
+ assert sys.getsizeof(a) >= object.__sizeof__(a) + a.nbytes
+ a = pa.array([[1, 2], None, [3, None, 4, 5]], type=pa.list_(pa.int64()))
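+ # 1 validity byte + 4 int32 offsets + 1 child validity byte
+ # + 6 int64 child values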
+ assert a.nbytes == 1 + 4 * 4 + 1 + 6 * 8
+ assert sys.getsizeof(a) >= object.__sizeof__(a) + a.nbytes
+
+
+def test_invalid_tensor_constructor_repr():
+ # ARROW-2638: prevent calling extension class constructors directly
+ with pytest.raises(TypeError):
+ repr(pa.Tensor([1]))
+
+
+def test_invalid_tensor_construction():
+ with pytest.raises(TypeError):
+ pa.Tensor()
+
+
+@pytest.mark.parametrize(('offset_type', 'list_type_factory'),
+ [(pa.int32(), pa.list_), (pa.int64(), pa.large_list)])
+def test_list_array_flatten(offset_type, list_type_factory):
+ typ2 = list_type_factory(
+ list_type_factory(
+ pa.int64()
+ )
+ )
+ arr2 = pa.array([
+ None,
+ [
+ [1, None, 2],
+ None,
+ [3, 4]
+ ],
+ [],
+ [
+ [],
+ [5, 6],
+ None
+ ],
+ [
+ [7, 8]
+ ]
+ ], type=typ2)
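+ # expected offsets: one entry per element plus a final length; null and
+ # empty slots repeat the previous offset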
+ offsets2 = pa.array([0, 0, 3, 3, 6, 7], type=offset_type)
+
+ typ1 = list_type_factory(pa.int64())
+ arr1 = pa.array([
+ [1, None, 2],
+ None,
+ [3, 4],
+ [],
+ [5, 6],
+ None,
+ [7, 8]
+ ], type=typ1)
+ offsets1 = pa.array([0, 3, 3, 5, 5, 7, 7, 9], type=offset_type)
+
+ arr0 = pa.array([
+ 1, None, 2,
+ 3, 4,
+ 5, 6,
+ 7, 8
+ ], type=pa.int64())
+
+ assert arr2.flatten().equals(arr1)
+ assert arr2.offsets.equals(offsets2)
+ assert arr2.values.equals(arr1)
+ assert arr1.flatten().equals(arr0)
+ assert arr1.offsets.equals(offsets1)
+ assert arr1.values.equals(arr0)
+ assert arr2.flatten().flatten().equals(arr0)
+ assert arr2.values.values.equals(arr0)
+
+
+@pytest.mark.parametrize(('offset_type', 'list_type_factory'),
+ [(pa.int32(), pa.list_), (pa.int64(), pa.large_list)])
+def test_list_value_parent_indices(offset_type, list_type_factory):
+ arr = pa.array(
+ [
+ [0, 1, 2],
+ None,
+ [],
+ [3, 4]
+ ], type=list_type_factory(pa.int32()))
+ expected = pa.array([0, 0, 0, 3, 3], type=offset_type)
+ assert arr.value_parent_indices().equals(expected)
+
+
+@pytest.mark.parametrize(('offset_type', 'list_type_factory'),
+ [(pa.int32(), pa.list_), (pa.int64(), pa.large_list)])
+def test_list_value_lengths(offset_type, list_type_factory):
+ arr = pa.array(
+ [
+ [0, 1, 2],
+ None,
+ [],
+ [3, 4]
+ ], type=list_type_factory(pa.int32()))
+ expected = pa.array([3, None, 0, 2], type=offset_type)
+ assert arr.value_lengths().equals(expected)
+
+
+@pytest.mark.parametrize('list_type_factory', [pa.list_, pa.large_list])
+def test_list_array_flatten_non_canonical(list_type_factory):
+ # Non-canonical list array (null elements backed by non-empty sublists)
+ typ = list_type_factory(pa.int64())
+ arr = pa.array([[1], [2, 3], [4, 5, 6]], type=typ)
+ buffers = arr.buffers()[:2]
+ buffers[0] = pa.py_buffer(b"\x05")  # validity bitmap 0b101: element 1 null
+ arr = arr.from_buffers(arr.type, len(arr), buffers, children=[arr.values])
+ assert arr.to_pylist() == [[1], None, [4, 5, 6]]
+ assert arr.offsets.to_pylist() == [0, 1, 3, 6]
+
+ flattened = arr.flatten()
+ flattened.validate(full=True)
+ assert flattened.type == typ.value_type
+ assert flattened.to_pylist() == [1, 4, 5, 6]
+
+ # .values is the physical values array (including masked elements)
+ assert arr.values.to_pylist() == [1, 2, 3, 4, 5, 6]
+
+
+@pytest.mark.parametrize('klass', [pa.ListArray, pa.LargeListArray])
+def test_list_array_values_offsets_sliced(klass):
+ # ARROW-7301
+ arr = klass.from_arrays(offsets=[0, 3, 4, 6], values=[1, 2, 3, 4, 5, 6])
+ assert arr.values.to_pylist() == [1, 2, 3, 4, 5, 6]
+ assert arr.offsets.to_pylist() == [0, 3, 4, 6]
+
+ # After slicing, .values keeps referring to the full values buffer, but
+ # .offsets is sliced as well, so the offsets still point correctly into
+ # the full values array; .flatten() returns only the sliced values.
+ arr2 = arr[1:]
+ assert arr2.values.to_pylist() == [1, 2, 3, 4, 5, 6]
+ assert arr2.offsets.to_pylist() == [3, 4, 6]
+ assert arr2.flatten().to_pylist() == [4, 5, 6]
+ i = arr2.offsets[0].as_py()
+ j = arr2.offsets[1].as_py()
+ assert arr2[0].as_py() == arr2.values[i:j].to_pylist() == [4]
+
+
+def test_fixed_size_list_array_flatten():
+ typ2 = pa.list_(pa.list_(pa.int64(), 2), 3)
+ arr2 = pa.array([
+ [
+ [1, 2],
+ [3, 4],
+ [5, 6],
+ ],
+ None,
+ [
+ [7, None],
+ None,
+ [8, 9]
+ ],
+ ], type=typ2)
+ assert arr2.type.equals(typ2)
+
+ typ1 = pa.list_(pa.int64(), 2)
+ arr1 = pa.array([
+ [1, 2], [3, 4], [5, 6],
+ None, None, None,
+ [7, None], None, [8, 9]
+ ], type=typ1)
+ assert arr1.type.equals(typ1)
+ assert arr2.flatten().equals(arr1)
+
+ typ0 = pa.int64()
+ arr0 = pa.array([
+ 1, 2, 3, 4, 5, 6,
+ None, None, None, None, None, None,
+ 7, None, None, None, 8, 9,
+ ], type=typ0)
+ assert arr0.type.equals(typ0)
+ assert arr1.flatten().equals(arr0)
+ assert arr2.flatten().flatten().equals(arr0)
+
+
+def test_struct_array_flatten():
+ ty = pa.struct([pa.field('x', pa.int16()),
+ pa.field('y', pa.float32())])
+ a = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty)
+ xs, ys = a.flatten()
+ assert xs.type == pa.int16()
+ assert ys.type == pa.float32()
+ assert xs.to_pylist() == [1, 3, 5]
+ assert ys.to_pylist() == [2.5, 4.5, 6.5]
+ xs, ys = a[1:].flatten()
+ assert xs.to_pylist() == [3, 5]
+ assert ys.to_pylist() == [4.5, 6.5]
+
+ a = pa.array([(1, 2.5), None, (3, 4.5)], type=ty)
+ xs, ys = a.flatten()
+ assert xs.to_pylist() == [1, None, 3]
+ assert ys.to_pylist() == [2.5, None, 4.5]
+ xs, ys = a[1:].flatten()
+ assert xs.to_pylist() == [None, 3]
+ assert ys.to_pylist() == [None, 4.5]
+
+ a = pa.array([(1, None), (2, 3.5), (None, 4.5)], type=ty)
+ xs, ys = a.flatten()
+ assert xs.to_pylist() == [1, 2, None]
+ assert ys.to_pylist() == [None, 3.5, 4.5]
+ xs, ys = a[1:].flatten()
+ assert xs.to_pylist() == [2, None]
+ assert ys.to_pylist() == [3.5, 4.5]
+
+ a = pa.array([(1, None), None, (None, 2.5)], type=ty)
+ xs, ys = a.flatten()
+ assert xs.to_pylist() == [1, None, None]
+ assert ys.to_pylist() == [None, None, 2.5]
+ xs, ys = a[1:].flatten()
+ assert xs.to_pylist() == [None, None]
+ assert ys.to_pylist() == [None, 2.5]
+
+
+def test_struct_array_field():
+ ty = pa.struct([pa.field('x', pa.int16()),
+ pa.field('y', pa.float32())])
+ a = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty)
+
+ x0 = a.field(0)
+ y0 = a.field(1)
+ x1 = a.field(-2)
+ y1 = a.field(-1)
+ x2 = a.field('x')
+ y2 = a.field('y')
+
+ assert isinstance(x0, pa.lib.Int16Array)
+ assert isinstance(y1, pa.lib.FloatArray)
+ assert x0.equals(pa.array([1, 3, 5], type=pa.int16()))
+ assert y0.equals(pa.array([2.5, 4.5, 6.5], type=pa.float32()))
+ assert x0.equals(x1)
+ assert x0.equals(x2)
+ assert y0.equals(y1)
+ assert y0.equals(y2)
+
+ for invalid_index in [None, pa.int16()]:
+ with pytest.raises(TypeError):
+ a.field(invalid_index)
+
+ for invalid_index in [3, -3]:
+ with pytest.raises(IndexError):
+ a.field(invalid_index)
+
+ for invalid_name in ['z', '']:
+ with pytest.raises(KeyError):
+ a.field(invalid_name)
+
+
+def test_empty_cast():
+ types = [
+ pa.null(),
+ pa.bool_(),
+ pa.int8(),
+ pa.int16(),
+ pa.int32(),
+ pa.int64(),
+ pa.uint8(),
+ pa.uint16(),
+ pa.uint32(),
+ pa.uint64(),
+ pa.float16(),
+ pa.float32(),
+ pa.float64(),
+ pa.date32(),
+ pa.date64(),
+ pa.binary(),
+ pa.binary(length=4),
+ pa.string(),
+ ]
+
+ for (t1, t2) in itertools.product(types, types):
+ try:
+ # ARROW-4766: Ensure that conversions between supported types don't
+ # segfault on empty arrays of common types
+ pa.array([], type=t1).cast(t2)
+ except (pa.lib.ArrowNotImplementedError, pa.ArrowInvalid):
+ continue
+
+
+def test_nested_dictionary_array():
+ dict_arr = pa.DictionaryArray.from_arrays([0, 1, 0], ['a', 'b'])
+ list_arr = pa.ListArray.from_arrays([0, 2, 3], dict_arr)
+ assert list_arr.to_pylist() == [['a', 'b'], ['a']]
+
+ dict_arr = pa.DictionaryArray.from_arrays([0, 1, 0], ['a', 'b'])
+ dict_arr2 = pa.DictionaryArray.from_arrays([0, 1, 2, 1, 0], dict_arr)
+ assert dict_arr2.to_pylist() == ['a', 'b', 'a', 'b', 'a']
+
+
+def test_array_from_numpy_str_utf8():
+ # ARROW-3890 -- in Python 3, NPY_UNICODE arrays are produced, but in Python
+ # 2 they are NPY_STRING (binary), so we must do UTF-8 validation
+ vec = np.array(["toto", "tata"])
+ vec2 = np.array(["toto", "tata"], dtype=object)
+
+ arr = pa.array(vec, pa.string())
+ arr2 = pa.array(vec2, pa.string())
+ expected = pa.array(["toto", "tata"])
+ assert arr.equals(expected)
+ assert arr2.equals(expected)
+
+ # with mask, separate code path
+ mask = np.array([False, False], dtype=bool)
+ arr = pa.array(vec, pa.string(), mask=mask)
+ assert arr.equals(expected)
+
+ # UTF8 validation failures
+ vec = np.array([('mañana').encode('utf-16-le')])
+ with pytest.raises(ValueError):
+ pa.array(vec, pa.string())
+
+ with pytest.raises(ValueError):
+ pa.array(vec, pa.string(), mask=np.array([False]))
+
+
+@pytest.mark.slow
+@pytest.mark.large_memory
+def test_numpy_binary_overflow_to_chunked():
+ # ARROW-3762, ARROW-5966
+
+ # 2^31 + 1 bytes
+ values = [b'x']
+ unicode_values = ['x']
+
+ # Make 10 unique 1MB strings, then repeat them 2048 times
+ unique_strings = {
+ i: b'x' * ((1 << 20) - 1) + str(i % 10).encode('utf8')
+ for i in range(10)
+ }
+ unicode_unique_strings = {i: x.decode('utf8')
+ for i, x in unique_strings.items()}
+ values += [unique_strings[i % 10] for i in range(1 << 11)]
+ unicode_values += [unicode_unique_strings[i % 10]
+ for i in range(1 << 11)]
+
+ for case, ex_type in [(values, pa.binary()),
+ (unicode_values, pa.utf8())]:
+ arr = np.array(case)
+ arrow_arr = pa.array(arr)
+ arr = None
+
+ assert isinstance(arrow_arr, pa.ChunkedArray)
+ assert arrow_arr.type == ex_type
+
+ # Split up into 16 MB chunks: 128 * 16 = 2048 MB, plus one chunk for
+ # the remainder, so 129
+ assert arrow_arr.num_chunks == 129
+
+ value_index = 0
+ for i in range(arrow_arr.num_chunks):
+ chunk = arrow_arr.chunk(i)
+ for val in chunk:
+ assert val.as_py() == case[value_index]
+ value_index += 1
+
+
+@pytest.mark.large_memory
+def test_list_child_overflow_to_chunked():
+ kilobyte_string = 'x' * 1024
+ two_mega = 2**21
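+ # 2**21 strings of 1 KiB each amount to 2**31 bytes of child data, just
+ # past what the 32-bit offsets of a single chunk can address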
+
+ vals = [[kilobyte_string]] * (two_mega - 1)
+ arr = pa.array(vals)
+ assert isinstance(arr, pa.Array)
+ assert len(arr) == two_mega - 1
+
+ vals = [[kilobyte_string]] * two_mega
+ arr = pa.array(vals)
+ assert isinstance(arr, pa.ChunkedArray)
+ assert len(arr) == two_mega
+ assert len(arr.chunk(0)) == two_mega - 1
+ assert len(arr.chunk(1)) == 1
+
+
+def test_infer_type_masked():
+ # ARROW-5208
+ ty = pa.infer_type(['foo', 'bar', None, 2],
+ mask=[False, False, False, True])
+ assert ty == pa.utf8()
+
+ # all masked
+ ty = pa.infer_type(['foo', 'bar', None, 2],
+ mask=np.array([True, True, True, True]))
+ assert ty == pa.null()
+
+ # length 0
+ assert pa.infer_type([], mask=[]) == pa.null()
+
+
+def test_array_masked():
+ # ARROW-5208
+ arr = pa.array([4, None, 4, 3.],
+ mask=np.array([False, True, False, True]))
+ assert arr.type == pa.int64()
+
+ # ndarray dtype=object argument
+ arr = pa.array(np.array([4, None, 4, 3.], dtype="O"),
+ mask=np.array([False, True, False, True]))
+ assert arr.type == pa.int64()
+
+
+def test_array_supported_masks():
+ # ARROW-13883
+ arr = pa.array([4, None, 4, 3.],
+ mask=np.array([False, True, False, True]))
+ assert arr.to_pylist() == [4, None, 4, None]
+
+ arr = pa.array([4, None, 4, 3],
+ mask=pa.array([False, True, False, True]))
+ assert arr.to_pylist() == [4, None, 4, None]
+
+ arr = pa.array([4, None, 4, 3],
+ mask=[False, True, False, True])
+ assert arr.to_pylist() == [4, None, 4, None]
+
+ arr = pa.array([4, 3, None, 3],
+ mask=[False, True, False, True])
+ assert arr.to_pylist() == [4, None, None, None]
+
+ # Non-boolean values
+ with pytest.raises(pa.ArrowTypeError):
+ arr = pa.array([4, None, 4, 3],
+ mask=pa.array([1.0, 2.0, 3.0, 4.0]))
+
+ with pytest.raises(pa.ArrowTypeError):
+ arr = pa.array([4, None, 4, 3],
+ mask=[1.0, 2.0, 3.0, 4.0])
+
+ with pytest.raises(pa.ArrowTypeError):
+ arr = pa.array([4, None, 4, 3],
+ mask=np.array([1.0, 2.0, 3.0, 4.0]))
+
+ with pytest.raises(pa.ArrowTypeError):
+ arr = pa.array([4, None, 4, 3],
+ mask=pa.array([False, True, False, True],
+ mask=pa.array([True, True, True, True])))
+
+ with pytest.raises(pa.ArrowTypeError):
+ arr = pa.array([4, None, 4, 3],
+ mask=pa.array([False, None, False, True]))
+
+ # Numpy arrays only accept numpy masks
+ with pytest.raises(TypeError):
+ arr = pa.array(np.array([4, None, 4, 3.]),
+ mask=[True, False, True, False])
+
+ with pytest.raises(TypeError):
+ arr = pa.array(np.array([4, None, 4, 3.]),
+ mask=pa.array([True, False, True, False]))
+
+
+def test_binary_array_masked():
+ # ARROW-12431
+ masked_basic = pa.array([b'\x05'], type=pa.binary(1),
+ mask=np.array([False]))
+ assert [b'\x05'] == masked_basic.to_pylist()
+
+ # Fixed Length Binary
+ masked = pa.array(np.array([b'\x05']), type=pa.binary(1),
+ mask=np.array([False]))
+ assert [b'\x05'] == masked.to_pylist()
+
+ masked_nulls = pa.array(np.array([b'\x05']), type=pa.binary(1),
+ mask=np.array([True]))
+ assert [None] == masked_nulls.to_pylist()
+
+ # Variable Length Binary
+ masked = pa.array(np.array([b'\x05']), type=pa.binary(),
+ mask=np.array([False]))
+ assert [b'\x05'] == masked.to_pylist()
+
+ masked_nulls = pa.array(np.array([b'\x05']), type=pa.binary(),
+ mask=np.array([True]))
+ assert [None] == masked_nulls.to_pylist()
+
+ # Fixed Length Binary, copy
+ npa = np.array([b'aaa', b'bbb', b'ccc']*10)
+ arrow_array = pa.array(npa, type=pa.binary(3),
+ mask=np.array([False, False, False]*10))
+ npa[npa == b"bbb"] = b"XXX"
+ assert ([b'aaa', b'bbb', b'ccc']*10) == arrow_array.to_pylist()
+
+
+def test_binary_array_strided():
+ # Masked
+ nparray = np.array([b"ab", b"cd", b"ef"])
+ arrow_array = pa.array(nparray[::2], pa.binary(2),
+ mask=np.array([False, False]))
+ assert [b"ab", b"ef"] == arrow_array.to_pylist()
+
+ # Unmasked
+ nparray = np.array([b"ab", b"cd", b"ef"])
+ arrow_array = pa.array(nparray[::2], pa.binary(2))
+ assert [b"ab", b"ef"] == arrow_array.to_pylist()
+
+
+def test_array_invalid_mask_raises():
+ # ARROW-10742
+ cases = [
+ ([1, 2], np.array([False, False], dtype="O"),
+ TypeError, "must be boolean dtype"),
+
+ ([1, 2], np.array([[False], [False]]),
+ pa.ArrowInvalid, "must be 1D array"),
+
+ ([1, 2, 3], np.array([False, False]),
+ pa.ArrowInvalid, "different length"),
+
+ (np.array([1, 2]), np.array([False, False], dtype="O"),
+ TypeError, "must be boolean dtype"),
+
+ (np.array([1, 2]), np.array([[False], [False]]),
+ ValueError, "must be 1D array"),
+
+ (np.array([1, 2, 3]), np.array([False, False]),
+ ValueError, "different length"),
+ ]
+ for obj, mask, ex, msg in cases:
+ with pytest.raises(ex, match=msg):
+ pa.array(obj, mask=mask)
+
+
+def test_array_from_large_pyints():
+ # ARROW-5430
+ with pytest.raises(OverflowError):
+ # too large for int64 so dtype must be explicitly provided
+ pa.array([int(2 ** 63)])
+
+
+def test_array_protocol():
+
+ class MyArray:
+ def __init__(self, data):
+ self.data = data
+
+ def __arrow_array__(self, type=None):
+ return pa.array(self.data, type=type)
+
+ arr = MyArray(np.array([1, 2, 3], dtype='int64'))
+ result = pa.array(arr)
+ expected = pa.array([1, 2, 3], type=pa.int64())
+ assert result.equals(expected)
+ result = pa.array(arr, type=pa.int64())
+ expected = pa.array([1, 2, 3], type=pa.int64())
+ assert result.equals(expected)
+ result = pa.array(arr, type=pa.float64())
+ expected = pa.array([1, 2, 3], type=pa.float64())
+ assert result.equals(expected)
+
+ # raise error when passing size or mask keywords
+ with pytest.raises(ValueError):
+ pa.array(arr, mask=np.array([True, False, True]))
+ with pytest.raises(ValueError):
+ pa.array(arr, size=3)
+
+ # ensure the return value is an Array
+ class MyArrayInvalid:
+ def __init__(self, data):
+ self.data = data
+
+ def __arrow_array__(self, type=None):
+ return np.array(self.data)
+
+ arr = MyArrayInvalid(np.array([1, 2, 3], dtype='int64'))
+ with pytest.raises(TypeError):
+ pa.array(arr)
+
+ # ARROW-7066 - allow ChunkedArray output
+ class MyArray2:
+ def __init__(self, data):
+ self.data = data
+
+ def __arrow_array__(self, type=None):
+ return pa.chunked_array([self.data], type=type)
+
+ arr = MyArray2(np.array([1, 2, 3], dtype='int64'))
+ result = pa.array(arr)
+ expected = pa.chunked_array([[1, 2, 3]], type=pa.int64())
+ assert result.equals(expected)
+
+
+def test_concat_array():
+ concatenated = pa.concat_arrays(
+ [pa.array([1, 2]), pa.array([3, 4])])
+ assert concatenated.equals(pa.array([1, 2, 3, 4]))
+
+
+def test_concat_array_different_types():
+ with pytest.raises(pa.ArrowInvalid):
+ pa.concat_arrays([pa.array([1]), pa.array([2.])])
+
+
+def test_concat_array_invalid_type():
+ # ARROW-9920 - do not segfault on non-array input
+
+ with pytest.raises(TypeError, match="should contain Array objects"):
+ pa.concat_arrays([None])
+
+ arr = pa.chunked_array([[0, 1], [3, 4]])
+ with pytest.raises(TypeError, match="should contain Array objects"):
+ pa.concat_arrays(arr)
+
+
+@pytest.mark.pandas
+def test_to_pandas_timezone():
+ # https://issues.apache.org/jira/browse/ARROW-6652
+ arr = pa.array([1, 2, 3], type=pa.timestamp('s', tz='Europe/Brussels'))
+ s = arr.to_pandas()
+ assert s.dt.tz is not None
+ arr = pa.chunked_array([arr])
+ s = arr.to_pandas()
+ assert s.dt.tz is not None
diff --git a/src/arrow/python/pyarrow/tests/test_builder.py b/src/arrow/python/pyarrow/tests/test_builder.py
new file mode 100644
index 000000000..50d801026
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_builder.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import weakref
+
+import numpy as np
+
+import pyarrow as pa
+from pyarrow.lib import StringBuilder
+
+
+def test_weakref():
+ sbuilder = StringBuilder()
+ wr = weakref.ref(sbuilder)
+ assert wr() is not None
+ del sbuilder
+ assert wr() is None
+
+
+def test_string_builder_append():
+ sbuilder = StringBuilder()
+ sbuilder.append(b"a byte string")
+ sbuilder.append("a string")
+ sbuilder.append(np.nan)
+ sbuilder.append(None)
+ assert len(sbuilder) == 4
+ assert sbuilder.null_count == 2
+ arr = sbuilder.finish()
+ assert len(sbuilder) == 0
+ assert isinstance(arr, pa.Array)
+ assert arr.null_count == 2
+ assert arr.type == 'str'
+ expected = ["a byte string", "a string", None, None]
+ assert arr.to_pylist() == expected
+
+
+def test_string_builder_append_values():
+ sbuilder = StringBuilder()
+ sbuilder.append_values([np.nan, None, "text", None, "other text"])
+ assert sbuilder.null_count == 3
+ arr = sbuilder.finish()
+ assert arr.null_count == 3
+ expected = [None, None, "text", None, "other text"]
+ assert arr.to_pylist() == expected
+
+
+def test_string_builder_append_after_finish():
+ sbuilder = StringBuilder()
+ sbuilder.append_values([np.nan, None, "text", None, "other text"])
+ arr = sbuilder.finish()
+ sbuilder.append("No effect")
+ expected = [None, None, "text", None, "other text"]
+ assert arr.to_pylist() == expected
diff --git a/src/arrow/python/pyarrow/tests/test_cffi.py b/src/arrow/python/pyarrow/tests/test_cffi.py
new file mode 100644
index 000000000..f0ce42909
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_cffi.py
@@ -0,0 +1,398 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import gc
+
+import pyarrow as pa
+try:
+ from pyarrow.cffi import ffi
+except ImportError:
+ ffi = None
+
+import pytest
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+except ImportError:
+ pd = tm = None
+
+
+needs_cffi = pytest.mark.skipif(ffi is None,
+ reason="test needs cffi package installed")
+
+
+assert_schema_released = pytest.raises(
+ ValueError, match="Cannot import released ArrowSchema")
+
+assert_array_released = pytest.raises(
+ ValueError, match="Cannot import released ArrowArray")
+
+assert_stream_released = pytest.raises(
+ ValueError, match="Cannot import released ArrowArrayStream")
+
+
+class ParamExtType(pa.PyExtensionType):
+
+ def __init__(self, width):
+ self._width = width
+ pa.PyExtensionType.__init__(self, pa.binary(width))
+
+ @property
+ def width(self):
+ return self._width
+
+ def __reduce__(self):
+ return ParamExtType, (self.width,)
+
+
+def make_schema():
+ return pa.schema([('ints', pa.list_(pa.int32()))],
+ metadata={b'key1': b'value1'})
+
+
+def make_extension_schema():
+ return pa.schema([('ext', ParamExtType(3))],
+ metadata={b'key1': b'value1'})
+
+
+def make_batch():
+ return pa.record_batch([[[1], [2, 42]]], make_schema())
+
+
+def make_extension_batch():
+ schema = make_extension_schema()
+ ext_col = schema[0].type.wrap_array(pa.array([b"foo", b"bar"],
+ type=pa.binary(3)))
+ return pa.record_batch([ext_col], schema)
+
+
+def make_batches():
+ schema = make_schema()
+ return [
+ pa.record_batch([[[1], [2, 42]]], schema),
+ pa.record_batch([[None, [], [5, 6]]], schema),
+ ]
+
+
+def make_serialized(schema, batches):
+ with pa.BufferOutputStream() as sink:
+ with pa.ipc.new_stream(sink, schema) as out:
+ for batch in batches:
+ out.write(batch)
+ return sink.getvalue()
+
+
+@needs_cffi
+def test_export_import_type():
+ c_schema = ffi.new("struct ArrowSchema*")
+ ptr_schema = int(ffi.cast("uintptr_t", c_schema))
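+ # The C data interface exports into the ArrowSchema struct, whose raw
+ # address is passed around as a plain Python integer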
+
+ gc.collect() # Make sure no Arrow data dangles in a ref cycle
+ old_allocated = pa.total_allocated_bytes()
+
+ typ = pa.list_(pa.int32())
+ typ._export_to_c(ptr_schema)
+ assert pa.total_allocated_bytes() > old_allocated
+ # Delete and recreate C++ object from exported pointer
+ del typ
+ assert pa.total_allocated_bytes() > old_allocated
+ typ_new = pa.DataType._import_from_c(ptr_schema)
+ assert typ_new == pa.list_(pa.int32())
+ assert pa.total_allocated_bytes() == old_allocated
+ # Now released
+ with assert_schema_released:
+ pa.DataType._import_from_c(ptr_schema)
+
+ # Invalid format string
+ pa.int32()._export_to_c(ptr_schema)
+ bad_format = ffi.new("char[]", b"zzz")
+ c_schema.format = bad_format
+ with pytest.raises(ValueError,
+ match="Invalid or unsupported format string"):
+ pa.DataType._import_from_c(ptr_schema)
+ # Now released
+ with assert_schema_released:
+ pa.DataType._import_from_c(ptr_schema)
+
+
+@needs_cffi
+def test_export_import_field():
+ c_schema = ffi.new("struct ArrowSchema*")
+ ptr_schema = int(ffi.cast("uintptr_t", c_schema))
+
+ gc.collect() # Make sure no Arrow data dangles in a ref cycle
+ old_allocated = pa.total_allocated_bytes()
+
+ field = pa.field("test", pa.list_(pa.int32()), nullable=True)
+ field._export_to_c(ptr_schema)
+ assert pa.total_allocated_bytes() > old_allocated
+ # Delete and recreate C++ object from exported pointer
+ del field
+ assert pa.total_allocated_bytes() > old_allocated
+
+ field_new = pa.Field._import_from_c(ptr_schema)
+ assert field_new == pa.field("test", pa.list_(pa.int32()), nullable=True)
+ assert pa.total_allocated_bytes() == old_allocated
+
+ # Now released
+ with assert_schema_released:
+ pa.Field._import_from_c(ptr_schema)
+
+
+@needs_cffi
+def test_export_import_array():
+ c_schema = ffi.new("struct ArrowSchema*")
+ ptr_schema = int(ffi.cast("uintptr_t", c_schema))
+ c_array = ffi.new("struct ArrowArray*")
+ ptr_array = int(ffi.cast("uintptr_t", c_array))
+
+ gc.collect() # Make sure no Arrow data dangles in a ref cycle
+ old_allocated = pa.total_allocated_bytes()
+
+ # Type is known up front
+ typ = pa.list_(pa.int32())
+ arr = pa.array([[1], [2, 42]], type=typ)
+ py_value = arr.to_pylist()
+ arr._export_to_c(ptr_array)
+ assert pa.total_allocated_bytes() > old_allocated
+ # Delete and recreate C++ object from exported pointer
+ del arr
+ arr_new = pa.Array._import_from_c(ptr_array, typ)
+ assert arr_new.to_pylist() == py_value
+ assert arr_new.type == pa.list_(pa.int32())
+ assert pa.total_allocated_bytes() > old_allocated
+ del arr_new, typ
+ assert pa.total_allocated_bytes() == old_allocated
+ # Now released
+ with assert_array_released:
+ pa.Array._import_from_c(ptr_array, pa.list_(pa.int32()))
+
+ # Type is exported and imported at the same time
+ arr = pa.array([[1], [2, 42]], type=pa.list_(pa.int32()))
+ py_value = arr.to_pylist()
+ arr._export_to_c(ptr_array, ptr_schema)
+ # Delete and recreate C++ objects from exported pointers
+ del arr
+ arr_new = pa.Array._import_from_c(ptr_array, ptr_schema)
+ assert arr_new.to_pylist() == py_value
+ assert arr_new.type == pa.list_(pa.int32())
+ assert pa.total_allocated_bytes() > old_allocated
+ del arr_new
+ assert pa.total_allocated_bytes() == old_allocated
+ # Now released
+ with assert_schema_released:
+ pa.Array._import_from_c(ptr_array, ptr_schema)
+
+
+def check_export_import_schema(schema_factory):
+ c_schema = ffi.new("struct ArrowSchema*")
+ ptr_schema = int(ffi.cast("uintptr_t", c_schema))
+
+ gc.collect() # Make sure no Arrow data dangles in a ref cycle
+ old_allocated = pa.total_allocated_bytes()
+
+ schema_factory()._export_to_c(ptr_schema)
+ assert pa.total_allocated_bytes() > old_allocated
+ # Delete and recreate C++ object from exported pointer
+ schema_new = pa.Schema._import_from_c(ptr_schema)
+ assert schema_new == schema_factory()
+ assert pa.total_allocated_bytes() == old_allocated
+ del schema_new
+ assert pa.total_allocated_bytes() == old_allocated
+ # Now released
+ with assert_schema_released:
+ pa.Schema._import_from_c(ptr_schema)
+
+ # Not a struct type
+ pa.int32()._export_to_c(ptr_schema)
+ with pytest.raises(ValueError,
+ match="ArrowSchema describes non-struct type"):
+ pa.Schema._import_from_c(ptr_schema)
+ # Now released
+ with assert_schema_released:
+ pa.Schema._import_from_c(ptr_schema)
+
+
+@needs_cffi
+def test_export_import_schema():
+ check_export_import_schema(make_schema)
+
+
+@needs_cffi
+def test_export_import_schema_with_extension():
+ check_export_import_schema(make_extension_schema)
+
+
+def check_export_import_batch(batch_factory):
+ c_schema = ffi.new("struct ArrowSchema*")
+ ptr_schema = int(ffi.cast("uintptr_t", c_schema))
+ c_array = ffi.new("struct ArrowArray*")
+ ptr_array = int(ffi.cast("uintptr_t", c_array))
+
+ gc.collect() # Make sure no Arrow data dangles in a ref cycle
+ old_allocated = pa.total_allocated_bytes()
+
+ # Schema is known up front
+ batch = batch_factory()
+ schema = batch.schema
+ py_value = batch.to_pydict()
+ batch._export_to_c(ptr_array)
+ assert pa.total_allocated_bytes() > old_allocated
+ # Delete and recreate C++ object from exported pointer
+ del batch
+ batch_new = pa.RecordBatch._import_from_c(ptr_array, schema)
+ assert batch_new.to_pydict() == py_value
+ assert batch_new.schema == schema
+ assert pa.total_allocated_bytes() > old_allocated
+ del batch_new, schema
+ assert pa.total_allocated_bytes() == old_allocated
+ # Now released
+ with assert_array_released:
+ pa.RecordBatch._import_from_c(ptr_array, make_schema())
+
+ # Type is exported and imported at the same time
+ batch = batch_factory()
+ py_value = batch.to_pydict()
+ batch._export_to_c(ptr_array, ptr_schema)
+ # Delete and recreate C++ objects from exported pointers
+ del batch
+ batch_new = pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
+ assert batch_new.to_pydict() == py_value
+ assert batch_new.schema == batch_factory().schema
+ assert pa.total_allocated_bytes() > old_allocated
+ del batch_new
+ assert pa.total_allocated_bytes() == old_allocated
+ # Now released
+ with assert_schema_released:
+ pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
+
+ # Not a struct type
+ pa.int32()._export_to_c(ptr_schema)
+ batch_factory()._export_to_c(ptr_array)
+ with pytest.raises(ValueError,
+ match="ArrowSchema describes non-struct type"):
+ pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
+ # Now released
+ with assert_schema_released:
+ pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
+
+
+@needs_cffi
+def test_export_import_batch():
+ check_export_import_batch(make_batch)
+
+
+@needs_cffi
+def test_export_import_batch_with_extension():
+ check_export_import_batch(make_extension_batch)
+
+
+def _export_import_batch_reader(ptr_stream, reader_factory):
+ # Prepare input
+ batches = make_batches()
+ schema = batches[0].schema
+
+ reader = reader_factory(schema, batches)
+ reader._export_to_c(ptr_stream)
+ # Delete and recreate C++ object from exported pointer
+ del reader, batches
+
+ reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
+ assert reader_new.schema == schema
+ got_batches = list(reader_new)
+ del reader_new
+ assert got_batches == make_batches()
+
+ # Test read_pandas()
+ if pd is not None:
+ batches = make_batches()
+ schema = batches[0].schema
+ expected_df = pa.Table.from_batches(batches).to_pandas()
+
+ reader = reader_factory(schema, batches)
+ reader._export_to_c(ptr_stream)
+ del reader, batches
+
+ reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
+ got_df = reader_new.read_pandas()
+ del reader_new
+ tm.assert_frame_equal(expected_df, got_df)
+
+
+def make_ipc_stream_reader(schema, batches):
+ return pa.ipc.open_stream(make_serialized(schema, batches))
+
+
+def make_py_record_batch_reader(schema, batches):
+ return pa.ipc.RecordBatchReader.from_batches(schema, batches)
+
+
+@needs_cffi
+@pytest.mark.parametrize('reader_factory',
+ [make_ipc_stream_reader,
+ make_py_record_batch_reader])
+def test_export_import_batch_reader(reader_factory):
+ c_stream = ffi.new("struct ArrowArrayStream*")
+ ptr_stream = int(ffi.cast("uintptr_t", c_stream))
+
+ gc.collect() # Make sure no Arrow data dangles in a ref cycle
+ old_allocated = pa.total_allocated_bytes()
+
+ _export_import_batch_reader(ptr_stream, reader_factory)
+
+ assert pa.total_allocated_bytes() == old_allocated
+
+ # Now released
+ with assert_stream_released:
+ pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
+
+
+@needs_cffi
+def test_imported_batch_reader_error():
+ c_stream = ffi.new("struct ArrowArrayStream*")
+ ptr_stream = int(ffi.cast("uintptr_t", c_stream))
+
+ schema = pa.schema([('foo', pa.int32())])
+ batches = [pa.record_batch([[1, 2, 3]], schema=schema),
+ pa.record_batch([[4, 5, 6]], schema=schema)]
+ buf = make_serialized(schema, batches)
+
+ # Open a corrupt/incomplete stream and export it
+ reader = pa.ipc.open_stream(buf[:-16])
+ reader._export_to_c(ptr_stream)
+ del reader
+
+ reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
+ batch = reader_new.read_next_batch()
+ assert batch == batches[0]
+ with pytest.raises(OSError,
+ match="Expected to be able to read 16 bytes "
+ "for message body, got 8"):
+ reader_new.read_next_batch()
+
+ # Again, but call read_all()
+ reader = pa.ipc.open_stream(buf[:-16])
+ reader._export_to_c(ptr_stream)
+ del reader
+
+ reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream)
+ with pytest.raises(OSError,
+ match="Expected to be able to read 16 bytes "
+ "for message body, got 8"):
+ reader_new.read_all()
diff --git a/src/arrow/python/pyarrow/tests/test_compute.py b/src/arrow/python/pyarrow/tests/test_compute.py
new file mode 100644
index 000000000..be2da31b9
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_compute.py
@@ -0,0 +1,2238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+from functools import lru_cache, partial
+import inspect
+import pickle
+import pytest
+import random
+import sys
+import textwrap
+
+import numpy as np
+
+try:
+ import pandas as pd
+except ImportError:
+ pd = None
+
+import pyarrow as pa
+import pyarrow.compute as pc
+
+all_array_types = [
+ ('bool', [True, False, False, True, True]),
+ ('uint8', np.arange(5)),
+ ('int8', np.arange(5)),
+ ('uint16', np.arange(5)),
+ ('int16', np.arange(5)),
+ ('uint32', np.arange(5)),
+ ('int32', np.arange(5)),
+ ('uint64', np.arange(5, 10)),
+ ('int64', np.arange(5, 10)),
+ ('float', np.arange(0, 0.5, 0.1)),
+ ('double', np.arange(0, 0.5, 0.1)),
+ ('string', ['a', 'b', None, 'ddd', 'ee']),
+ ('binary', [b'a', b'b', b'c', b'ddd', b'ee']),
+ (pa.binary(3), [b'abc', b'bcd', b'cde', b'def', b'efg']),
+ (pa.list_(pa.int8()), [[1, 2], [3, 4], [5, 6], None, [9, 16]]),
+ (pa.large_list(pa.int16()), [[1], [2, 3, 4], [5, 6], None, [9, 16]]),
+ (pa.struct([('a', pa.int8()), ('b', pa.int8())]), [
+ {'a': 1, 'b': 2}, None, {'a': 3, 'b': 4}, None, {'a': 5, 'b': 6}]),
+]
+
+exported_functions = [
+ func for (name, func) in sorted(pc.__dict__.items())
+ if hasattr(func, '__arrow_compute_function__')]
+
+exported_option_classes = [
+ cls for (name, cls) in sorted(pc.__dict__.items())
+ if (isinstance(cls, type) and
+ cls is not pc.FunctionOptions and
+ issubclass(cls, pc.FunctionOptions))]
+
+numerical_arrow_types = [
+ pa.int8(),
+ pa.int16(),
+ pa.int64(),
+ pa.uint8(),
+ pa.uint16(),
+ pa.uint64(),
+ pa.float32(),
+ pa.float64()
+]
+
+
+def test_exported_functions():
+ # Check that all exported concrete functions can be called with
+ # the right number of arguments.
+ # Note that unregistered functions (e.g. with a mismatching name)
+ # will raise KeyError.
+ functions = exported_functions
+ assert len(functions) >= 10
+ for func in functions:
+ arity = func.__arrow_compute_function__['arity']
+ if arity is Ellipsis:
+ args = [object()] * 3
+ else:
+ args = [object()] * arity
+ with pytest.raises(TypeError,
+ match="Got unexpected argument type "
+ "<class 'object'> for compute function"):
+ func(*args)
+
+
+def test_exported_option_classes():
+ classes = exported_option_classes
+ assert len(classes) >= 10
+ for cls in classes:
+ # Option classes must have an introspectable constructor signature,
+ # and that signature should not have any *args or **kwargs.
+ sig = inspect.signature(cls)
+ for param in sig.parameters.values():
+ assert param.kind not in (param.VAR_POSITIONAL,
+ param.VAR_KEYWORD)
+
+
+def test_option_class_equality():
+ options = [
+ pc.ArraySortOptions(),
+ pc.AssumeTimezoneOptions("UTC"),
+ pc.CastOptions.safe(pa.int8()),
+ pc.CountOptions(),
+ pc.DayOfWeekOptions(count_from_zero=False, week_start=0),
+ pc.DictionaryEncodeOptions(),
+ pc.ElementWiseAggregateOptions(skip_nulls=True),
+ pc.ExtractRegexOptions("pattern"),
+ pc.FilterOptions(),
+ pc.IndexOptions(pa.scalar(1)),
+ pc.JoinOptions(),
+ pc.MakeStructOptions(["field", "names"],
+ field_nullability=[True, True],
+ field_metadata=[pa.KeyValueMetadata({"a": "1"}),
+ pa.KeyValueMetadata({"b": "2"})]),
+ pc.MatchSubstringOptions("pattern"),
+ pc.ModeOptions(),
+ pc.NullOptions(),
+ pc.PadOptions(5),
+ pc.PartitionNthOptions(1, null_placement="at_start"),
+ pc.QuantileOptions(),
+ pc.ReplaceSliceOptions(0, 1, "a"),
+ pc.ReplaceSubstringOptions("a", "b"),
+ pc.RoundOptions(2, "towards_infinity"),
+ pc.RoundToMultipleOptions(100, "towards_infinity"),
+ pc.ScalarAggregateOptions(),
+ pc.SelectKOptions(0, sort_keys=[("b", "ascending")]),
+ pc.SetLookupOptions(pa.array([1])),
+ pc.SliceOptions(0, 1, 1),
+ pc.SortOptions([("dummy", "descending")], null_placement="at_start"),
+ pc.SplitOptions(),
+ pc.SplitPatternOptions("pattern"),
+ pc.StrftimeOptions(),
+ pc.StrptimeOptions("%Y", "s"),
+ pc.TakeOptions(),
+ pc.TDigestOptions(),
+ pc.TrimOptions(" "),
+ pc.VarianceOptions(),
+ pc.WeekOptions(week_starts_monday=True, count_from_zero=False,
+ first_week_is_fully_in_year=False),
+ ]
+ # TODO: We should test on Windows once ARROW-13168 is resolved.
+ # Timezone database is not available on Windows yet
+ if sys.platform != 'win32':
+ options.append(pc.AssumeTimezoneOptions("Europe/Ljubljana"))
+
+ classes = {type(option) for option in options}
+ for cls in exported_option_classes:
+ # Timezone database is not available on Windows yet
+ if cls not in classes and sys.platform != 'win32' and \
+ cls != pc.AssumeTimezoneOptions:
+ try:
+ options.append(cls())
+ except TypeError:
+ pytest.fail(f"Options class is not tested: {cls}")
+ for option in options:
+ assert option == option
+ assert repr(option).startswith(option.__class__.__name__)
+ buf = option.serialize()
+ deserialized = pc.FunctionOptions.deserialize(buf)
+ assert option == deserialized
+ assert repr(option) == repr(deserialized)
+ for option1, option2 in zip(options, options[1:]):
+ assert option1 != option2
+
+ assert repr(pc.IndexOptions(pa.scalar(1))) == "IndexOptions(value=int64:1)"
+ assert repr(pc.ArraySortOptions()) == \
+ "ArraySortOptions(order=Ascending, null_placement=AtEnd)"
+
+
+def test_list_functions():
+ assert len(pc.list_functions()) > 10
+ assert "add" in pc.list_functions()
+
+
+def _check_get_function(name, expected_func_cls, expected_ker_cls,
+ min_num_kernels=1):
+ func = pc.get_function(name)
+ assert isinstance(func, expected_func_cls)
+ n = func.num_kernels
+ assert n >= min_num_kernels
+ assert n == len(func.kernels)
+ assert all(isinstance(ker, expected_ker_cls) for ker in func.kernels)
+
+
+def test_get_function_scalar():
+ _check_get_function("add", pc.ScalarFunction, pc.ScalarKernel, 8)
+
+
+def test_get_function_vector():
+ _check_get_function("unique", pc.VectorFunction, pc.VectorKernel, 8)
+
+
+def test_get_function_scalar_aggregate():
+ _check_get_function("mean", pc.ScalarAggregateFunction,
+ pc.ScalarAggregateKernel, 8)
+
+
+def test_get_function_hash_aggregate():
+ _check_get_function("hash_sum", pc.HashAggregateFunction,
+ pc.HashAggregateKernel, 1)
+
+
+def test_call_function_with_memory_pool():
+ arr = pa.array(["foo", "bar", "baz"])
+ indices = np.array([2, 2, 1])
+ result1 = arr.take(indices)
+ result2 = pc.call_function('take', [arr, indices],
+ memory_pool=pa.default_memory_pool())
+ expected = pa.array(["baz", "baz", "bar"])
+ assert result1.equals(expected)
+ assert result2.equals(expected)
+
+ result3 = pc.take(arr, indices, memory_pool=pa.default_memory_pool())
+ assert result3.equals(expected)
+
+
+def test_pickle_functions():
+ # Pickle registered functions
+ for name in pc.list_functions():
+ func = pc.get_function(name)
+ reconstructed = pickle.loads(pickle.dumps(func))
+ assert type(reconstructed) is type(func)
+ assert reconstructed.name == func.name
+ assert reconstructed.arity == func.arity
+ assert reconstructed.num_kernels == func.num_kernels
+
+
+def test_pickle_global_functions():
+ # Pickle global wrappers (manual or automatic) of registered functions
+ for name in pc.list_functions():
+ func = getattr(pc, name)
+ reconstructed = pickle.loads(pickle.dumps(func))
+ assert reconstructed is func
+
+
+def test_function_attributes():
+ # Sanity check attributes of registered functions
+ for name in pc.list_functions():
+ func = pc.get_function(name)
+ assert isinstance(func, pc.Function)
+ assert func.name == name
+ kernels = func.kernels
+ assert func.num_kernels == len(kernels)
+ assert all(isinstance(ker, pc.Kernel) for ker in kernels)
+ if func.arity is not Ellipsis:
+ assert func.arity >= 1
+ repr(func)
+ for ker in kernels:
+ repr(ker)
+
+
+def test_input_type_conversion():
+ # Automatic array conversion from Python
+ arr = pc.add([1, 2], [4, None])
+ assert arr.to_pylist() == [5, None]
+ # Automatic scalar conversion from Python
+ arr = pc.add([1, 2], 4)
+ assert arr.to_pylist() == [5, 6]
+ # Other scalar type
+ assert pc.equal(["foo", "bar", None],
+ "foo").to_pylist() == [True, False, None]
+
+
+@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
+def test_sum_array(arrow_type):
+ arr = pa.array([1, 2, 3, 4], type=arrow_type)
+ assert arr.sum().as_py() == 10
+ assert pc.sum(arr).as_py() == 10
+
+ arr = pa.array([1, 2, 3, 4, None], type=arrow_type)
+ assert arr.sum().as_py() == 10
+ assert pc.sum(arr).as_py() == 10
+
+ arr = pa.array([None], type=arrow_type)
+ assert arr.sum().as_py() is None # noqa: E711
+ assert pc.sum(arr).as_py() is None # noqa: E711
+ assert arr.sum(min_count=0).as_py() == 0
+ assert pc.sum(arr, min_count=0).as_py() == 0
+
+ arr = pa.array([], type=arrow_type)
+ assert arr.sum().as_py() is None # noqa: E711
+ assert arr.sum(min_count=0).as_py() == 0
+ assert pc.sum(arr, min_count=0).as_py() == 0
+
+
+@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
+def test_sum_chunked_array(arrow_type):
+ arr = pa.chunked_array([pa.array([1, 2, 3, 4], type=arrow_type)])
+ assert pc.sum(arr).as_py() == 10
+
+ arr = pa.chunked_array([
+ pa.array([1, 2], type=arrow_type), pa.array([3, 4], type=arrow_type)
+ ])
+ assert pc.sum(arr).as_py() == 10
+
+ arr = pa.chunked_array([
+ pa.array([1, 2], type=arrow_type),
+ pa.array([], type=arrow_type),
+ pa.array([3, 4], type=arrow_type)
+ ])
+ assert pc.sum(arr).as_py() == 10
+
+ arr = pa.chunked_array((), type=arrow_type)
+ assert arr.num_chunks == 0
+ assert pc.sum(arr).as_py() is None # noqa: E711
+ assert pc.sum(arr, min_count=0).as_py() == 0
+
+
+def test_mode_array():
+ # ARROW-9917
+ arr = pa.array([1, 1, 3, 4, 3, 5], type='int64')
+ mode = pc.mode(arr)
+ assert len(mode) == 1
+ assert mode[0].as_py() == {"mode": 1, "count": 2}
+
+ mode = pc.mode(arr, n=2)
+ assert len(mode) == 2
+ assert mode[0].as_py() == {"mode": 1, "count": 2}
+ assert mode[1].as_py() == {"mode": 3, "count": 2}
+
+ arr = pa.array([], type='int64')
+ assert len(pc.mode(arr)) == 0
+
+ arr = pa.array([1, 1, 3, 4, 3, None], type='int64')
+ mode = pc.mode(arr, skip_nulls=False)
+ assert len(mode) == 0
+ mode = pc.mode(arr, min_count=6)
+ assert len(mode) == 0
+ mode = pc.mode(arr, skip_nulls=False, min_count=5)
+ assert len(mode) == 0
+
+
+def test_mode_chunked_array():
+ # ARROW-9917
+ arr = pa.chunked_array([pa.array([1, 1, 3, 4, 3, 5], type='int64')])
+ mode = pc.mode(arr)
+ assert len(mode) == 1
+ assert mode[0].as_py() == {"mode": 1, "count": 2}
+
+ mode = pc.mode(arr, n=2)
+ assert len(mode) == 2
+ assert mode[0].as_py() == {"mode": 1, "count": 2}
+ assert mode[1].as_py() == {"mode": 3, "count": 2}
+
+ arr = pa.chunked_array((), type='int64')
+ assert arr.num_chunks == 0
+ assert len(pc.mode(arr)) == 0
+
+
+def test_variance():
+ data = [1, 2, 3, 4, 5, 6, 7, 8]
+ assert pc.variance(data).as_py() == 5.25
+ assert pc.variance(data, ddof=0).as_py() == 5.25
+ assert pc.variance(data, ddof=1).as_py() == 6.0
+
+
+def test_count_substring():
+ for (ty, offset) in [(pa.string(), pa.int32()),
+ (pa.large_string(), pa.int64())]:
+ arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None], type=ty)
+
+ result = pc.count_substring(arr, "ab")
+ expected = pa.array([1, 1, 2, 0, 0, None], type=offset)
+ assert expected.equals(result)
+
+ result = pc.count_substring(arr, "ab", ignore_case=True)
+ expected = pa.array([1, 1, 2, 0, 1, None], type=offset)
+ assert expected.equals(result)
+
+
+def test_count_substring_regex():
+ for (ty, offset) in [(pa.string(), pa.int32()),
+ (pa.large_string(), pa.int64())]:
+ arr = pa.array(["ab", "cab", "baAacaa", "ba", "AB", None], type=ty)
+
+ result = pc.count_substring_regex(arr, "a+")
+ expected = pa.array([1, 1, 3, 1, 0, None], type=offset)
+ assert expected.equals(result)
+
+ result = pc.count_substring_regex(arr, "a+", ignore_case=True)
+ expected = pa.array([1, 1, 2, 1, 1, None], type=offset)
+ assert expected.equals(result)
+
+
+def test_find_substring():
+ for ty in [pa.string(), pa.binary(), pa.large_string(), pa.large_binary()]:
+ arr = pa.array(["ab", "cab", "ba", None], type=ty)
+ result = pc.find_substring(arr, "ab")
+ assert result.to_pylist() == [0, 1, -1, None]
+
+ result = pc.find_substring_regex(arr, "a?b")
+ assert result.to_pylist() == [0, 1, 0, None]
+
+ arr = pa.array(["ab*", "cAB*", "ba", "aB?"], type=ty)
+ result = pc.find_substring(arr, "aB*", ignore_case=True)
+ assert result.to_pylist() == [0, 1, -1, -1]
+
+ result = pc.find_substring_regex(arr, "a?b", ignore_case=True)
+ assert result.to_pylist() == [0, 1, 0, 0]
+
+
+def test_match_like():
+ arr = pa.array(["ab", "ba%", "ba", "ca%d", None])
+ result = pc.match_like(arr, r"_a\%%")
+ expected = pa.array([False, True, False, True, None])
+ assert expected.equals(result)
+
+ arr = pa.array(["aB", "bA%", "ba", "ca%d", None])
+ result = pc.match_like(arr, r"_a\%%", ignore_case=True)
+ expected = pa.array([False, True, False, True, None])
+ assert expected.equals(result)
+ result = pc.match_like(arr, r"_a\%%", ignore_case=False)
+ expected = pa.array([False, False, False, True, None])
+ assert expected.equals(result)
+
+
+def test_match_substring():
+ arr = pa.array(["ab", "abc", "ba", None])
+ result = pc.match_substring(arr, "ab")
+ expected = pa.array([True, True, False, None])
+ assert expected.equals(result)
+
+ arr = pa.array(["áB", "Ábc", "ba", None])
+ result = pc.match_substring(arr, "áb", ignore_case=True)
+ expected = pa.array([True, True, False, None])
+ assert expected.equals(result)
+ result = pc.match_substring(arr, "áb", ignore_case=False)
+ expected = pa.array([False, False, False, None])
+ assert expected.equals(result)
+
+
+def test_match_substring_regex():
+ arr = pa.array(["ab", "abc", "ba", "c", None])
+ result = pc.match_substring_regex(arr, "^a?b")
+ expected = pa.array([True, True, True, False, None])
+ assert expected.equals(result)
+
+ arr = pa.array(["aB", "Abc", "BA", "c", None])
+ result = pc.match_substring_regex(arr, "^a?b", ignore_case=True)
+ expected = pa.array([True, True, True, False, None])
+ assert expected.equals(result)
+ result = pc.match_substring_regex(arr, "^a?b", ignore_case=False)
+ expected = pa.array([False, False, False, False, None])
+ assert expected.equals(result)
+
+
+def test_trim():
+ # \u3000 is unicode whitespace
+ arr = pa.array([" foo", None, " \u3000foo bar \t"])
+ result = pc.utf8_trim_whitespace(arr)
+ expected = pa.array(["foo", None, "foo bar"])
+ assert expected.equals(result)
+
+ arr = pa.array([" foo", None, " \u3000foo bar \t"])
+ result = pc.ascii_trim_whitespace(arr)
+ expected = pa.array(["foo", None, "\u3000foo bar"])
+ assert expected.equals(result)
+
+ arr = pa.array([" foo", None, " \u3000foo bar \t"])
+ result = pc.utf8_trim(arr, characters=' f\u3000')
+ expected = pa.array(["oo", None, "oo bar \t"])
+ assert expected.equals(result)
+
+
+def test_slice_compatibility():
+ arr = pa.array(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])
+ for start in range(-6, 6):
+ for stop in range(-6, 6):
+ for step in [-3, -2, -1, 1, 2, 3]:
+ expected = pa.array([k.as_py()[start:stop:step]
+ for k in arr])
+ result = pc.utf8_slice_codeunits(
+ arr, start=start, stop=stop, step=step)
+ assert expected.equals(result)
+
+
+def test_split_pattern():
+ arr = pa.array(["-foo---bar--", "---foo---b"])
+ result = pc.split_pattern(arr, pattern="---")
+ expected = pa.array([["-foo", "bar--"], ["", "foo", "b"]])
+ assert expected.equals(result)
+
+ result = pc.split_pattern(arr, pattern="---", max_splits=1)
+ expected = pa.array([["-foo", "bar--"], ["", "foo---b"]])
+ assert expected.equals(result)
+
+ result = pc.split_pattern(arr, pattern="---", max_splits=1, reverse=True)
+ expected = pa.array([["-foo", "bar--"], ["---foo", "b"]])
+ assert expected.equals(result)
+
+
+def test_split_whitespace_utf8():
+ arr = pa.array(["foo bar", " foo \u3000\tb"])
+ result = pc.utf8_split_whitespace(arr)
+ expected = pa.array([["foo", "bar"], ["", "foo", "b"]])
+ assert expected.equals(result)
+
+ result = pc.utf8_split_whitespace(arr, max_splits=1)
+ expected = pa.array([["foo", "bar"], ["", "foo \u3000\tb"]])
+ assert expected.equals(result)
+
+ result = pc.utf8_split_whitespace(arr, max_splits=1, reverse=True)
+ expected = pa.array([["foo", "bar"], [" foo", "b"]])
+ assert expected.equals(result)
+
+
+def test_split_whitespace_ascii():
+ arr = pa.array(["foo bar", " foo \u3000\tb"])
+ result = pc.ascii_split_whitespace(arr)
+ expected = pa.array([["foo", "bar"], ["", "foo", "\u3000", "b"]])
+ assert expected.equals(result)
+
+ result = pc.ascii_split_whitespace(arr, max_splits=1)
+ expected = pa.array([["foo", "bar"], ["", "foo \u3000\tb"]])
+ assert expected.equals(result)
+
+ result = pc.ascii_split_whitespace(arr, max_splits=1, reverse=True)
+ expected = pa.array([["foo", "bar"], [" foo \u3000", "b"]])
+ assert expected.equals(result)
+
+
+def test_split_pattern_regex():
+ arr = pa.array(["-foo---bar--", "---foo---b"])
+ result = pc.split_pattern_regex(arr, pattern="-+")
+ expected = pa.array([["", "foo", "bar", ""], ["", "foo", "b"]])
+ assert expected.equals(result)
+
+ result = pc.split_pattern_regex(arr, pattern="-+", max_splits=1)
+ expected = pa.array([["", "foo---bar--"], ["", "foo---b"]])
+ assert expected.equals(result)
+
+ with pytest.raises(NotImplementedError,
+ match="Cannot split in reverse with regex"):
+ result = pc.split_pattern_regex(
+ arr, pattern="---", max_splits=1, reverse=True)
+
+
+def test_min_max():
+ # An example generated function wrapper with possible options
+ data = [4, 5, 6, None, 1]
+ s = pc.min_max(data)
+ assert s.as_py() == {'min': 1, 'max': 6}
+ s = pc.min_max(data, options=pc.ScalarAggregateOptions())
+ assert s.as_py() == {'min': 1, 'max': 6}
+ s = pc.min_max(data, options=pc.ScalarAggregateOptions(skip_nulls=True))
+ assert s.as_py() == {'min': 1, 'max': 6}
+ s = pc.min_max(data, options=pc.ScalarAggregateOptions(skip_nulls=False))
+ assert s.as_py() == {'min': None, 'max': None}
+
+ # Options as dict of kwargs
+ s = pc.min_max(data, options={'skip_nulls': False})
+ assert s.as_py() == {'min': None, 'max': None}
+ # Options as named functions arguments
+ s = pc.min_max(data, skip_nulls=False)
+ assert s.as_py() == {'min': None, 'max': None}
+
+ # Both options and named arguments
+ with pytest.raises(TypeError):
+ s = pc.min_max(
+ data, options=pc.ScalarAggregateOptions(), skip_nulls=False)
+
+ # Wrong options type
+ options = pc.TakeOptions()
+ with pytest.raises(TypeError):
+ s = pc.min_max(data, options=options)
+
+ # Missing argument
+ with pytest.raises(ValueError,
+ match="Function min_max accepts 1 argument"):
+ s = pc.min_max()
+
+
+def test_any():
+ # ARROW-1846
+
+ options = pc.ScalarAggregateOptions(skip_nulls=False, min_count=0)
+
+ a = pa.array([], type='bool')
+ assert pc.any(a).as_py() is None
+ assert pc.any(a, min_count=0).as_py() is False
+ assert pc.any(a, options=options).as_py() is False
+
+ a = pa.array([False, None, True])
+ assert pc.any(a).as_py() is True
+ assert pc.any(a, options=options).as_py() is True
+
+ a = pa.array([False, None, False])
+ assert pc.any(a).as_py() is False
+ assert pc.any(a, options=options).as_py() is None
+
+
+def test_all():
+ # ARROW-10301
+
+ options = pc.ScalarAggregateOptions(skip_nulls=False, min_count=0)
+
+ a = pa.array([], type='bool')
+ assert pc.all(a).as_py() is None
+ assert pc.all(a, min_count=0).as_py() is True
+ assert pc.all(a, options=options).as_py() is True
+
+ a = pa.array([False, True])
+ assert pc.all(a).as_py() is False
+ assert pc.all(a, options=options).as_py() is False
+
+ a = pa.array([True, None])
+ assert pc.all(a).as_py() is True
+ assert pc.all(a, options=options).as_py() is None
+
+ a = pa.chunked_array([[True], [True, None]])
+ assert pc.all(a).as_py() is True
+ assert pc.all(a, options=options).as_py() is None
+
+ a = pa.chunked_array([[True], [False]])
+ assert pc.all(a).as_py() is False
+ assert pc.all(a, options=options).as_py() is False
+
+
+def test_is_valid():
+ # An example generated function wrapper without options
+ data = [4, 5, None]
+ assert pc.is_valid(data).to_pylist() == [True, True, False]
+
+ with pytest.raises(TypeError):
+ pc.is_valid(data, options=None)
+
+
+def test_generated_docstrings():
+ assert pc.min_max.__doc__ == textwrap.dedent("""\
+ Compute the minimum and maximum values of a numeric array.
+
+ Null values are ignored by default.
+ This can be changed through ScalarAggregateOptions.
+
+ Parameters
+ ----------
+ array : Array-like
+ Argument to compute function
+ memory_pool : pyarrow.MemoryPool, optional
+ If not passed, will allocate memory from the default memory pool.
+ options : pyarrow.compute.ScalarAggregateOptions, optional
+ Parameters altering compute function semantics.
+ skip_nulls : optional
+ Parameter for ScalarAggregateOptions constructor. Either `options`
+ or `skip_nulls` can be passed, but not both at the same time.
+ min_count : optional
+ Parameter for ScalarAggregateOptions constructor. Either `options`
+ or `min_count` can be passed, but not both at the same time.
+ """)
+ assert pc.add.__doc__ == textwrap.dedent("""\
+ Add the arguments element-wise.
+
+ Results will wrap around on integer overflow.
+ Use function "add_checked" if you want overflow
+ to return an error.
+
+ Parameters
+ ----------
+ x : Array-like or scalar-like
+ Argument to compute function
+ y : Array-like or scalar-like
+ Argument to compute function
+ memory_pool : pyarrow.MemoryPool, optional
+ If not passed, will allocate memory from the default memory pool.
+ """)
+
+
+def test_generated_signatures():
+ # The self-documentation provided by signatures should show acceptable
+ # options and their default values.
+ sig = inspect.signature(pc.add)
+ assert str(sig) == "(x, y, *, memory_pool=None)"
+ sig = inspect.signature(pc.min_max)
+ assert str(sig) == ("(array, *, memory_pool=None, "
+ "options=None, skip_nulls=True, min_count=1)")
+ sig = inspect.signature(pc.quantile)
+ assert str(sig) == ("(array, *, memory_pool=None, "
+ "options=None, q=0.5, interpolation='linear', "
+ "skip_nulls=True, min_count=0)")
+ sig = inspect.signature(pc.binary_join_element_wise)
+ assert str(sig) == ("(*strings, memory_pool=None, options=None, "
+ "null_handling='emit_null', null_replacement='')")
+
+
+# We use isprintable to find out about codepoints that Python doesn't know, but
+# utf8proc does (or in a future version of Python the other way around).
+# These codepoints cannot be compared between Arrow and the Python
+# implementation.
+@lru_cache()
+def find_new_unicode_codepoints():
+ new = set()
+ characters = [chr(c) for c in range(0x80, 0x11000)
+ if not (0xD800 <= c < 0xE000)]
+ is_printable = pc.utf8_is_printable(pa.array(characters)).to_pylist()
+ for i, c in enumerate(characters):
+ if is_printable[i] != c.isprintable():
+ new.add(ord(c))
+ return new
+
+
+# Python claims these are not alpha, not sure why; they are in
+# gc='Other Letter': https://graphemica.com/%E1%B3%B2
+unknown_issue_is_alpha = {0x1cf2, 0x1cf3}
+# utf8proc does not know if codepoints are lower case
+utf8proc_issue_is_lower = {
+ 0xaa, 0xba, 0x2b0, 0x2b1, 0x2b2, 0x2b3, 0x2b4,
+ 0x2b5, 0x2b6, 0x2b7, 0x2b8, 0x2c0, 0x2c1, 0x2e0,
+ 0x2e1, 0x2e2, 0x2e3, 0x2e4, 0x37a, 0x1d2c, 0x1d2d,
+ 0x1d2e, 0x1d2f, 0x1d30, 0x1d31, 0x1d32, 0x1d33,
+ 0x1d34, 0x1d35, 0x1d36, 0x1d37, 0x1d38, 0x1d39,
+ 0x1d3a, 0x1d3b, 0x1d3c, 0x1d3d, 0x1d3e, 0x1d3f,
+ 0x1d40, 0x1d41, 0x1d42, 0x1d43, 0x1d44, 0x1d45,
+ 0x1d46, 0x1d47, 0x1d48, 0x1d49, 0x1d4a, 0x1d4b,
+ 0x1d4c, 0x1d4d, 0x1d4e, 0x1d4f, 0x1d50, 0x1d51,
+ 0x1d52, 0x1d53, 0x1d54, 0x1d55, 0x1d56, 0x1d57,
+ 0x1d58, 0x1d59, 0x1d5a, 0x1d5b, 0x1d5c, 0x1d5d,
+ 0x1d5e, 0x1d5f, 0x1d60, 0x1d61, 0x1d62, 0x1d63,
+ 0x1d64, 0x1d65, 0x1d66, 0x1d67, 0x1d68, 0x1d69,
+ 0x1d6a, 0x1d78, 0x1d9b, 0x1d9c, 0x1d9d, 0x1d9e,
+ 0x1d9f, 0x1da0, 0x1da1, 0x1da2, 0x1da3, 0x1da4,
+ 0x1da5, 0x1da6, 0x1da7, 0x1da8, 0x1da9, 0x1daa,
+ 0x1dab, 0x1dac, 0x1dad, 0x1dae, 0x1daf, 0x1db0,
+ 0x1db1, 0x1db2, 0x1db3, 0x1db4, 0x1db5, 0x1db6,
+ 0x1db7, 0x1db8, 0x1db9, 0x1dba, 0x1dbb, 0x1dbc,
+ 0x1dbd, 0x1dbe, 0x1dbf, 0x2071, 0x207f, 0x2090,
+ 0x2091, 0x2092, 0x2093, 0x2094, 0x2095, 0x2096,
+ 0x2097, 0x2098, 0x2099, 0x209a, 0x209b, 0x209c,
+ 0x2c7c, 0x2c7d, 0xa69c, 0xa69d, 0xa770, 0xa7f8,
+ 0xa7f9, 0xab5c, 0xab5d, 0xab5e, 0xab5f, }
+# utf8proc does not store if a codepoint is numeric
+numeric_info_missing = {
+ 0x3405, 0x3483, 0x382a, 0x3b4d, 0x4e00, 0x4e03,
+ 0x4e07, 0x4e09, 0x4e5d, 0x4e8c, 0x4e94, 0x4e96,
+ 0x4ebf, 0x4ec0, 0x4edf, 0x4ee8, 0x4f0d, 0x4f70,
+ 0x5104, 0x5146, 0x5169, 0x516b, 0x516d, 0x5341,
+ 0x5343, 0x5344, 0x5345, 0x534c, 0x53c1, 0x53c2,
+ 0x53c3, 0x53c4, 0x56db, 0x58f1, 0x58f9, 0x5e7a,
+ 0x5efe, 0x5eff, 0x5f0c, 0x5f0d, 0x5f0e, 0x5f10,
+ 0x62fe, 0x634c, 0x67d2, 0x6f06, 0x7396, 0x767e,
+ 0x8086, 0x842c, 0x8cae, 0x8cb3, 0x8d30, 0x9621,
+ 0x9646, 0x964c, 0x9678, 0x96f6, 0xf96b, 0xf973,
+ 0xf978, 0xf9b2, 0xf9d1, 0xf9d3, 0xf9fd, 0x10fc5,
+ 0x10fc6, 0x10fc7, 0x10fc8, 0x10fc9, 0x10fca,
+ 0x10fcb, }
+# utf8proc has no digit/numeric information
+digit_info_missing = {
+ 0xb2, 0xb3, 0xb9, 0x1369, 0x136a, 0x136b, 0x136c,
+ 0x136d, 0x136e, 0x136f, 0x1370, 0x1371, 0x19da, 0x2070,
+ 0x2074, 0x2075, 0x2076, 0x2077, 0x2078, 0x2079, 0x2080,
+ 0x2081, 0x2082, 0x2083, 0x2084, 0x2085, 0x2086, 0x2087,
+ 0x2088, 0x2089, 0x2460, 0x2461, 0x2462, 0x2463, 0x2464,
+ 0x2465, 0x2466, 0x2467, 0x2468, 0x2474, 0x2475, 0x2476,
+ 0x2477, 0x2478, 0x2479, 0x247a, 0x247b, 0x247c, 0x2488,
+ 0x2489, 0x248a, 0x248b, 0x248c, 0x248d, 0x248e, 0x248f,
+ 0x2490, 0x24ea, 0x24f5, 0x24f6, 0x24f7, 0x24f8, 0x24f9,
+ 0x24fa, 0x24fb, 0x24fc, 0x24fd, 0x24ff, 0x2776, 0x2777,
+ 0x2778, 0x2779, 0x277a, 0x277b, 0x277c, 0x277d, 0x277e,
+ 0x2780, 0x2781, 0x2782, 0x2783, 0x2784, 0x2785, 0x2786,
+ 0x2787, 0x2788, 0x278a, 0x278b, 0x278c, 0x278d, 0x278e,
+ 0x278f, 0x2790, 0x2791, 0x2792, 0x10a40, 0x10a41,
+ 0x10a42, 0x10a43, 0x10e60, 0x10e61, 0x10e62, 0x10e63,
+ 0x10e64, 0x10e65, 0x10e66, 0x10e67, 0x10e68, }
+
+codepoints_ignore = {
+ 'is_alnum': numeric_info_missing | digit_info_missing |
+ unknown_issue_is_alpha,
+ 'is_alpha': unknown_issue_is_alpha,
+ 'is_digit': digit_info_missing,
+ 'is_numeric': numeric_info_missing,
+ 'is_lower': utf8proc_issue_is_lower
+}
+
+
+@pytest.mark.parametrize('function_name', ['is_alnum', 'is_alpha',
+ 'is_ascii', 'is_decimal',
+ 'is_digit', 'is_lower',
+ 'is_numeric', 'is_printable',
+ 'is_space', 'is_upper', ])
+@pytest.mark.parametrize('variant', ['ascii', 'utf8'])
+def test_string_py_compat_boolean(function_name, variant):
+ arrow_name = variant + "_" + function_name
+ py_name = function_name.replace('_', '')
+ ignore = codepoints_ignore.get(function_name, set()) | \
+ find_new_unicode_codepoints()
+ for i in range(128 if variant == 'ascii' else 0x11000):
+ if i in range(0xD800, 0xE000):
+ continue # surrogates are not valid Unicode scalar values in UTF-8 strings
+ # the issues we know of, we skip
+ if i in ignore:
+ continue
+ # Compare results with the equivalent Python predicate
+ # (except "is_space" where functions are known to be incompatible)
+ c = chr(i)
+ if hasattr(pc, arrow_name) and function_name != 'is_space':
+ ar = pa.array([c])
+ arrow_func = getattr(pc, arrow_name)
+ assert arrow_func(ar)[0].as_py() == getattr(c, py_name)()
+
+
+def test_pad():
+ arr = pa.array([None, 'a', 'abcd'])
+ assert pc.ascii_center(arr, width=3).tolist() == [None, ' a ', 'abcd']
+ assert pc.ascii_lpad(arr, width=3).tolist() == [None, ' a', 'abcd']
+ assert pc.ascii_rpad(arr, width=3).tolist() == [None, 'a ', 'abcd']
+
+ arr = pa.array([None, 'á', 'abcd'])
+ assert pc.utf8_center(arr, width=3).tolist() == [None, ' á ', 'abcd']
+ assert pc.utf8_lpad(arr, width=3).tolist() == [None, ' á', 'abcd']
+ assert pc.utf8_rpad(arr, width=3).tolist() == [None, 'á ', 'abcd']
+
+
+@pytest.mark.pandas
+def test_replace_slice():
+ offsets = range(-3, 4)
+
+ arr = pa.array([None, '', 'a', 'ab', 'abc', 'abcd', 'abcde'])
+ series = arr.to_pandas()
+ for start in offsets:
+ for stop in offsets:
+ expected = series.str.slice_replace(start, stop, 'XX')
+ actual = pc.binary_replace_slice(
+ arr, start=start, stop=stop, replacement='XX')
+ assert actual.tolist() == expected.tolist()
+
+ arr = pa.array([None, '', 'π', 'πb', 'πbθ', 'πbθd', 'πbθde'])
+ series = arr.to_pandas()
+ for start in offsets:
+ for stop in offsets:
+ expected = series.str.slice_replace(start, stop, 'XX')
+ actual = pc.utf8_replace_slice(
+ arr, start=start, stop=stop, replacement='XX')
+ assert actual.tolist() == expected.tolist()
+
+
+def test_replace_plain():
+ ar = pa.array(['foo', 'food', None])
+ ar = pc.replace_substring(ar, pattern='foo', replacement='bar')
+ assert ar.tolist() == ['bar', 'bard', None]
+
+
+def test_replace_regex():
+ ar = pa.array(['foo', 'mood', None])
+ ar = pc.replace_substring_regex(ar, pattern='(.)oo', replacement=r'\100')
+ assert ar.tolist() == ['f00', 'm00d', None]
+
+
+def test_extract_regex():
+ ar = pa.array(['a1', 'zb2z'])
+ struct = pc.extract_regex(ar, pattern=r'(?P<letter>[ab])(?P<digit>\d)')
+ assert struct.tolist() == [{'letter': 'a', 'digit': '1'},
+ {'letter': 'b', 'digit': '2'}]
+
+
+def test_binary_join():
+ ar_list = pa.array([['foo', 'bar'], None, []])
+ expected = pa.array(['foo-bar', None, ''])
+ assert pc.binary_join(ar_list, '-').equals(expected)
+
+ separator_array = pa.array(['1', '2'], type=pa.binary())
+ expected = pa.array(['a1b', 'c2d'], type=pa.binary())
+ ar_list = pa.array([['a', 'b'], ['c', 'd']], type=pa.list_(pa.binary()))
+ assert pc.binary_join(ar_list, separator_array).equals(expected)
+
+
+def test_binary_join_element_wise():
+ null = pa.scalar(None, type=pa.string())
+ arrs = [[None, 'a', 'b'], ['c', None, 'd'], [None, '-', '--']]
+ assert pc.binary_join_element_wise(*arrs).to_pylist() == \
+ [None, None, 'b--d']
+ assert pc.binary_join_element_wise('a', 'b', '-').as_py() == 'a-b'
+ assert pc.binary_join_element_wise('a', null, '-').as_py() is None
+ assert pc.binary_join_element_wise('a', 'b', null).as_py() is None
+
+ skip = pc.JoinOptions(null_handling='skip')
+ assert pc.binary_join_element_wise(*arrs, options=skip).to_pylist() == \
+ [None, 'a', 'b--d']
+ assert pc.binary_join_element_wise(
+ 'a', 'b', '-', options=skip).as_py() == 'a-b'
+ assert pc.binary_join_element_wise(
+ 'a', null, '-', options=skip).as_py() == 'a'
+ assert pc.binary_join_element_wise(
+ 'a', 'b', null, options=skip).as_py() is None
+
+ replace = pc.JoinOptions(null_handling='replace', null_replacement='spam')
+ assert pc.binary_join_element_wise(*arrs, options=replace).to_pylist() == \
+ [None, 'a-spam', 'b--d']
+ assert pc.binary_join_element_wise(
+ 'a', 'b', '-', options=replace).as_py() == 'a-b'
+ assert pc.binary_join_element_wise(
+ 'a', null, '-', options=replace).as_py() == 'a-spam'
+ assert pc.binary_join_element_wise(
+ 'a', 'b', null, options=replace).as_py() is None
+
+
+@pytest.mark.parametrize(('ty', 'values'), all_array_types)
+def test_take(ty, values):
+ arr = pa.array(values, type=ty)
+ for indices_type in [pa.int8(), pa.int64()]:
+ indices = pa.array([0, 4, 2, None], type=indices_type)
+ result = arr.take(indices)
+ result.validate()
+ expected = pa.array([values[0], values[4], values[2], None], type=ty)
+ assert result.equals(expected)
+
+ # empty indices
+ indices = pa.array([], type=indices_type)
+ result = arr.take(indices)
+ result.validate()
+ expected = pa.array([], type=ty)
+ assert result.equals(expected)
+
+ indices = pa.array([2, 5])
+ with pytest.raises(IndexError):
+ arr.take(indices)
+
+ indices = pa.array([2, -1])
+ with pytest.raises(IndexError):
+ arr.take(indices)
+
+
+def test_take_indices_types():
+ arr = pa.array(range(5))
+
+ for indices_type in ['uint8', 'int8', 'uint16', 'int16',
+ 'uint32', 'int32', 'uint64', 'int64']:
+ indices = pa.array([0, 4, 2, None], type=indices_type)
+ result = arr.take(indices)
+ result.validate()
+ expected = pa.array([0, 4, 2, None])
+ assert result.equals(expected)
+
+ for indices_type in [pa.float32(), pa.float64()]:
+ indices = pa.array([0, 4, 2], type=indices_type)
+ with pytest.raises(NotImplementedError):
+ arr.take(indices)
+
+
+def test_take_on_chunked_array():
+ # ARROW-9504
+ arr = pa.chunked_array([
+ ["a", "b", "c", "d", "e"],
+ ["f", "g", "h", "i", "j"]
+ ])
+
+ indices = np.array([0, 5, 1, 6, 9, 2])
+ result = arr.take(indices)
+ expected = pa.chunked_array([["a", "f", "b", "g", "j", "c"]])
+ assert result.equals(expected)
+
+ indices = pa.chunked_array([[1], [9, 2]])
+ result = arr.take(indices)
+ expected = pa.chunked_array([
+ ["b"],
+ ["j", "c"]
+ ])
+ assert result.equals(expected)
+
+
+@pytest.mark.parametrize('ordered', [False, True])
+def test_take_dictionary(ordered):
+ arr = pa.DictionaryArray.from_arrays([0, 1, 2, 0, 1, 2], ['a', 'b', 'c'],
+ ordered=ordered)
+ result = arr.take(pa.array([0, 1, 3]))
+ result.validate()
+ assert result.to_pylist() == ['a', 'b', 'a']
+ assert result.dictionary.to_pylist() == ['a', 'b', 'c']
+ assert result.type.ordered is ordered
+
+
+def test_take_null_type():
+ # ARROW-10027
+ arr = pa.array([None] * 10)
+ chunked_arr = pa.chunked_array([[None] * 5] * 2)
+ batch = pa.record_batch([arr], names=['a'])
+ table = pa.table({'a': arr})
+
+ indices = pa.array([1, 3, 7, None])
+ assert len(arr.take(indices)) == 4
+ assert len(chunked_arr.take(indices)) == 4
+ assert len(batch.take(indices).column(0)) == 4
+ assert len(table.take(indices).column(0)) == 4
+
+
+@pytest.mark.parametrize(('ty', 'values'), all_array_types)
+def test_drop_null(ty, values):
+ arr = pa.array(values, type=ty)
+ result = arr.drop_null()
+ result.validate(full=True)
+ indices = [i for i in range(len(arr)) if arr[i].is_valid]
+ expected = arr.take(pa.array(indices))
+ assert result.equals(expected)
+
+
+def test_drop_null_chunked_array():
+ arr = pa.chunked_array([["a", None], ["c", "d", None], [None], []])
+ expected_drop = pa.chunked_array([["a"], ["c", "d"], [], []])
+
+ result = arr.drop_null()
+ assert result.equals(expected_drop)
+
+
+def test_drop_null_record_batch():
+ batch = pa.record_batch(
+ [pa.array(["a", None, "c", "d", None])], names=["a'"])
+ result = batch.drop_null()
+ expected = pa.record_batch([pa.array(["a", "c", "d"])], names=["a'"])
+ assert result.equals(expected)
+
+ batch = pa.record_batch(
+ [pa.array(["a", None, "c", "d", None]),
+ pa.array([None, None, "c", None, "e"])], names=["a'", "b'"])
+
+ result = batch.drop_null()
+ expected = pa.record_batch(
+ [pa.array(["c"]), pa.array(["c"])], names=["a'", "b'"])
+ assert result.equals(expected)
+
+
+def test_drop_null_table():
+ table = pa.table([pa.array(["a", None, "c", "d", None])], names=["a"])
+ expected = pa.table([pa.array(["a", "c", "d"])], names=["a"])
+ result = table.drop_null()
+ assert result.equals(expected)
+
+ table = pa.table([pa.chunked_array([["a", None], ["c", "d", None]]),
+ pa.chunked_array([["a", None], [None, "d", None]]),
+ pa.chunked_array([["a"], ["b"], [None], ["d", None]])],
+ names=["a", "b", "c"])
+ expected = pa.table([pa.array(["a", "d"]),
+ pa.array(["a", "d"]),
+ pa.array(["a", "d"])],
+ names=["a", "b", "c"])
+ result = table.drop_null()
+ assert result.equals(expected)
+
+ table = pa.table([pa.chunked_array([["a", "b"], ["c", "d", "e"]]),
+ pa.chunked_array([["A"], ["B"], [None], ["D", None]]),
+ pa.chunked_array([["a`", None], ["c`", "d`", None]])],
+ names=["a", "b", "c"])
+ expected = pa.table([pa.array(["a", "d"]),
+ pa.array(["A", "D"]),
+ pa.array(["a`", "d`"])],
+ names=["a", "b", "c"])
+ result = table.drop_null()
+ assert result.equals(expected)
+
+
+def test_drop_null_null_type():
+ arr = pa.array([None] * 10)
+ chunked_arr = pa.chunked_array([[None] * 5] * 2)
+ batch = pa.record_batch([arr], names=['a'])
+ table = pa.table({'a': arr})
+
+ assert len(arr.drop_null()) == 0
+ assert len(chunked_arr.drop_null()) == 0
+ assert len(batch.drop_null().column(0)) == 0
+ assert len(table.drop_null().column(0)) == 0
+
+
+@pytest.mark.parametrize(('ty', 'values'), all_array_types)
+def test_filter(ty, values):
+ arr = pa.array(values, type=ty)
+
+ mask = pa.array([True, False, False, True, None])
+ result = arr.filter(mask, null_selection_behavior='drop')
+ result.validate()
+ assert result.equals(pa.array([values[0], values[3]], type=ty))
+ result = arr.filter(mask, null_selection_behavior='emit_null')
+ result.validate()
+ assert result.equals(pa.array([values[0], values[3], None], type=ty))
+
+ # non-boolean dtype
+ mask = pa.array([0, 1, 0, 1, 0])
+ with pytest.raises(NotImplementedError):
+ arr.filter(mask)
+
+ # wrong length
+ mask = pa.array([True, False, True])
+ with pytest.raises(ValueError, match="must all be the same length"):
+ arr.filter(mask)
+
+
+def test_filter_chunked_array():
+ arr = pa.chunked_array([["a", None], ["c", "d", "e"]])
+ expected_drop = pa.chunked_array([["a"], ["e"]])
+ expected_null = pa.chunked_array([["a"], [None, "e"]])
+
+ for mask in [
+ # mask is array
+ pa.array([True, False, None, False, True]),
+ # mask is chunked array
+ pa.chunked_array([[True, False, None], [False, True]]),
+ # mask is python object
+ [True, False, None, False, True]
+ ]:
+ result = arr.filter(mask)
+ assert result.equals(expected_drop)
+ result = arr.filter(mask, null_selection_behavior="emit_null")
+ assert result.equals(expected_null)
+
+
+def test_filter_record_batch():
+ batch = pa.record_batch(
+ [pa.array(["a", None, "c", "d", "e"])], names=["a'"])
+
+ # mask is array
+ mask = pa.array([True, False, None, False, True])
+ result = batch.filter(mask)
+ expected = pa.record_batch([pa.array(["a", "e"])], names=["a'"])
+ assert result.equals(expected)
+
+ result = batch.filter(mask, null_selection_behavior="emit_null")
+ expected = pa.record_batch([pa.array(["a", None, "e"])], names=["a'"])
+ assert result.equals(expected)
+
+
+def test_filter_table():
+ table = pa.table([pa.array(["a", None, "c", "d", "e"])], names=["a"])
+ expected_drop = pa.table([pa.array(["a", "e"])], names=["a"])
+ expected_null = pa.table([pa.array(["a", None, "e"])], names=["a"])
+
+ for mask in [
+ # mask is array
+ pa.array([True, False, None, False, True]),
+ # mask is chunked array
+ pa.chunked_array([[True, False], [None, False, True]]),
+ # mask is python object
+ [True, False, None, False, True]
+ ]:
+ result = table.filter(mask)
+ assert result.equals(expected_drop)
+ result = table.filter(mask, null_selection_behavior="emit_null")
+ assert result.equals(expected_null)
+
+
+def test_filter_errors():
+ arr = pa.chunked_array([["a", None], ["c", "d", "e"]])
+ batch = pa.record_batch(
+ [pa.array(["a", None, "c", "d", "e"])], names=["a'"])
+ table = pa.table([pa.array(["a", None, "c", "d", "e"])], names=["a"])
+
+ for obj in [arr, batch, table]:
+ # non-boolean dtype
+ mask = pa.array([0, 1, 0, 1, 0])
+ with pytest.raises(NotImplementedError):
+ obj.filter(mask)
+
+ # wrong length
+ mask = pa.array([True, False, True])
+ with pytest.raises(pa.ArrowInvalid,
+ match="must all be the same length"):
+ obj.filter(mask)
+
+
+def test_filter_null_type():
+ # ARROW-10027
+ arr = pa.array([None] * 10)
+ chunked_arr = pa.chunked_array([[None] * 5] * 2)
+ batch = pa.record_batch([arr], names=['a'])
+ table = pa.table({'a': arr})
+
+ mask = pa.array([True, False] * 5)
+ assert len(arr.filter(mask)) == 5
+ assert len(chunked_arr.filter(mask)) == 5
+ assert len(batch.filter(mask).column(0)) == 5
+ assert len(table.filter(mask).column(0)) == 5
+
+
+@pytest.mark.parametrize("typ", ["array", "chunked_array"])
+def test_compare_array(typ):
+ if typ == "array":
+ def con(values):
+ return pa.array(values)
+ else:
+ def con(values):
+ return pa.chunked_array([values])
+
+ arr1 = con([1, 2, 3, 4, None])
+ arr2 = con([1, 1, 4, None, 4])
+
+ result = pc.equal(arr1, arr2)
+ assert result.equals(con([True, False, False, None, None]))
+
+ result = pc.not_equal(arr1, arr2)
+ assert result.equals(con([False, True, True, None, None]))
+
+ result = pc.less(arr1, arr2)
+ assert result.equals(con([False, False, True, None, None]))
+
+ result = pc.less_equal(arr1, arr2)
+ assert result.equals(con([True, False, True, None, None]))
+
+ result = pc.greater(arr1, arr2)
+ assert result.equals(con([False, True, False, None, None]))
+
+ result = pc.greater_equal(arr1, arr2)
+ assert result.equals(con([True, True, False, None, None]))
+
+
+@pytest.mark.parametrize("typ", ["array", "chunked_array"])
+def test_compare_string_scalar(typ):
+ if typ == "array":
+ def con(values):
+ return pa.array(values)
+ else:
+ def con(values):
+ return pa.chunked_array([values])
+
+ arr = con(['a', 'b', 'c', None])
+ scalar = pa.scalar('b')
+
+ result = pc.equal(arr, scalar)
+ assert result.equals(con([False, True, False, None]))
+
+ if typ == "array":
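+ # Comparing against a null scalar yields an all-null result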
+ nascalar = pa.scalar(None, type="string")
+ result = pc.equal(arr, nascalar)
+ isnull = pc.is_null(result)
+ assert isnull.equals(con([True, True, True, True]))
+
+ result = pc.not_equal(arr, scalar)
+ assert result.equals(con([True, False, True, None]))
+
+ result = pc.less(arr, scalar)
+ assert result.equals(con([True, False, False, None]))
+
+ result = pc.less_equal(arr, scalar)
+ assert result.equals(con([True, True, False, None]))
+
+ result = pc.greater(arr, scalar)
+ assert result.equals(con([False, False, True, None]))
+
+ result = pc.greater_equal(arr, scalar)
+ assert result.equals(con([False, True, True, None]))
+
+
+@pytest.mark.parametrize("typ", ["array", "chunked_array"])
+def test_compare_scalar(typ):
+ if typ == "array":
+ def con(values):
+ return pa.array(values)
+ else:
+ def con(values):
+ return pa.chunked_array([values])
+
+ arr = con([1, 2, 3, None])
+ scalar = pa.scalar(2)
+
+ result = pc.equal(arr, scalar)
+ assert result.equals(con([False, True, False, None]))
+
+ if typ == "array":
+ nascalar = pa.scalar(None, type="int64")
+ result = pc.equal(arr, nascalar)
+ assert result.to_pylist() == [None, None, None, None]
+
+ result = pc.not_equal(arr, scalar)
+ assert result.equals(con([True, False, True, None]))
+
+ result = pc.less(arr, scalar)
+ assert result.equals(con([True, False, False, None]))
+
+ result = pc.less_equal(arr, scalar)
+ assert result.equals(con([True, True, False, None]))
+
+ result = pc.greater(arr, scalar)
+ assert result.equals(con([False, False, True, None]))
+
+ result = pc.greater_equal(arr, scalar)
+ assert result.equals(con([False, True, True, None]))
+
+
+def test_compare_chunked_array_mixed():
+ arr = pa.array([1, 2, 3, 4, None])
+ arr_chunked = pa.chunked_array([[1, 2, 3], [4, None]])
+ arr_chunked2 = pa.chunked_array([[1, 2], [3, 4, None]])
+
+ expected = pa.chunked_array([[True, True, True, True, None]])
+
+ for left, right in [
+ (arr, arr_chunked),
+ (arr_chunked, arr),
+ (arr_chunked, arr_chunked2),
+ ]:
+ result = pc.equal(left, right)
+ assert result.equals(expected)
+
+
+def test_arithmetic_add():
+ left = pa.array([1, 2, 3, 4, 5])
+ right = pa.array([0, -1, 1, 2, 3])
+ result = pc.add(left, right)
+ expected = pa.array([1, 1, 4, 6, 8])
+ assert result.equals(expected)
+
+
+def test_arithmetic_subtract():
+ left = pa.array([1, 2, 3, 4, 5])
+ right = pa.array([0, -1, 1, 2, 3])
+ result = pc.subtract(left, right)
+ expected = pa.array([1, 3, 2, 2, 2])
+ assert result.equals(expected)
+
+
+def test_arithmetic_multiply():
+ left = pa.array([1, 2, 3, 4, 5])
+ right = pa.array([0, -1, 1, 2, 3])
+ result = pc.multiply(left, right)
+ expected = pa.array([0, -2, 3, 8, 15])
+ assert result.equals(expected)
+
+
+@pytest.mark.parametrize("ty", ["round", "round_to_multiple"])
+def test_round_to_integer(ty):
+ if ty == "round":
+ round = pc.round
+ RoundOptions = partial(pc.RoundOptions, ndigits=0)
+ elif ty == "round_to_multiple":
+ round = pc.round_to_multiple
+ RoundOptions = partial(pc.RoundToMultipleOptions, multiple=1)
+
+ values = [3.2, 3.5, 3.7, 4.5, -3.2, -3.5, -3.7, None]
+ rmode_and_expected = {
+ "down": [3, 3, 3, 4, -4, -4, -4, None],
+ "up": [4, 4, 4, 5, -3, -3, -3, None],
+ "towards_zero": [3, 3, 3, 4, -3, -3, -3, None],
+ "towards_infinity": [4, 4, 4, 5, -4, -4, -4, None],
+ "half_down": [3, 3, 4, 4, -3, -4, -4, None],
+ "half_up": [3, 4, 4, 5, -3, -3, -4, None],
+ "half_towards_zero": [3, 3, 4, 4, -3, -3, -4, None],
+ "half_towards_infinity": [3, 4, 4, 5, -3, -4, -4, None],
+ "half_to_even": [3, 4, 4, 4, -3, -4, -4, None],
+ "half_to_odd": [3, 3, 4, 5, -3, -3, -4, None],
+ }
+ for round_mode, expected in rmode_and_expected.items():
+ options = RoundOptions(round_mode=round_mode)
+ result = round(values, options=options)
+ np.testing.assert_array_equal(result, pa.array(expected))
+
+
+def test_round():
+ values = [320, 3.5, 3.075, 4.5, -3.212, -35.1234, -3.045, None]
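+ # Negative ndigits rounds to the left of the decimal point (tens, hundreds)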
+ ndigits_and_expected = {
+ -2: [300, 0, 0, 0, -0, -0, -0, None],
+ -1: [320, 0, 0, 0, -0, -40, -0, None],
+ 0: [320, 4, 3, 5, -3, -35, -3, None],
+ 1: [320, 3.5, 3.1, 4.5, -3.2, -35.1, -3, None],
+ 2: [320, 3.5, 3.08, 4.5, -3.21, -35.12, -3.05, None],
+ }
+ for ndigits, expected in ndigits_and_expected.items():
+ options = pc.RoundOptions(ndigits, "half_towards_infinity")
+ result = pc.round(values, options=options)
+ np.testing.assert_allclose(result, pa.array(expected), equal_nan=True)
+
+
+def test_round_to_multiple():
+ values = [320, 3.5, 3.075, 4.5, -3.212, -35.1234, -3.045, None]
+ multiple_and_expected = {
+ 2: [320, 4, 4, 4, -4, -36, -4, None],
+ 0.05: [320, 3.5, 3.1, 4.5, -3.2, -35.1, -3.05, None],
+ 0.1: [320, 3.5, 3.1, 4.5, -3.2, -35.1, -3, None],
+ 10: [320, 0, 0, 0, -0, -40, -0, None],
+ 100: [300, 0, 0, 0, -0, -0, -0, None],
+ }
+ for multiple, expected in multiple_and_expected.items():
+ options = pc.RoundToMultipleOptions(multiple, "half_towards_infinity")
+ result = pc.round_to_multiple(values, options=options)
+ np.testing.assert_allclose(result, pa.array(expected), equal_nan=True)
+
+ with pytest.raises(pa.ArrowInvalid, match="multiple must be positive"):
+ pc.round_to_multiple(values, multiple=-2)
+
+
+def test_is_null():
+ arr = pa.array([1, 2, 3, None])
+ result = arr.is_null()
+ expected = pa.array([False, False, False, True])
+ assert result.equals(expected)
+ assert result.equals(pc.is_null(arr))
+ result = arr.is_valid()
+ expected = pa.array([True, True, True, False])
+ assert result.equals(expected)
+ assert result.equals(pc.is_valid(arr))
+
+ arr = pa.chunked_array([[1, 2], [3, None]])
+ result = arr.is_null()
+ expected = pa.chunked_array([[False, False], [False, True]])
+ assert result.equals(expected)
+ result = arr.is_valid()
+ expected = pa.chunked_array([[True, True], [True, False]])
+ assert result.equals(expected)
+
+ arr = pa.array([1, 2, 3, None, np.nan])
+ result = arr.is_null()
+ expected = pa.array([False, False, False, True, False])
+ assert result.equals(expected)
+
+ result = arr.is_null(nan_is_null=True)
+ expected = pa.array([False, False, False, True, True])
+ assert result.equals(expected)
+
+
+def test_fill_null():
+ arr = pa.array([1, 2, None, 4], type=pa.int8())
+ fill_value = pa.array([5], type=pa.int8())
+ with pytest.raises(pa.ArrowInvalid,
+ match="Array arguments must all be the same length"):
+ arr.fill_null(fill_value)
+
+ arr = pa.array([None, None, None, None], type=pa.null())
+ fill_value = pa.scalar(None, type=pa.null())
+ result = arr.fill_null(fill_value)
+ expected = pa.array([None, None, None, None])
+ assert result.equals(expected)
+
+ arr = pa.array(['a', 'bb', None])
+ result = arr.fill_null('ccc')
+ expected = pa.array(['a', 'bb', 'ccc'])
+ assert result.equals(expected)
+
+ arr = pa.array([b'a', b'bb', None], type=pa.large_binary())
+ result = arr.fill_null('ccc')
+ expected = pa.array([b'a', b'bb', b'ccc'], type=pa.large_binary())
+ assert result.equals(expected)
+
+ arr = pa.array(['a', 'bb', None])
+ result = arr.fill_null(None)
+ expected = pa.array(['a', 'bb', None])
+ assert result.equals(expected)
+
+
+@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
+def test_fill_null_array(arrow_type):
+ arr = pa.array([1, 2, None, 4], type=arrow_type)
+ fill_value = pa.scalar(5, type=arrow_type)
+ result = arr.fill_null(fill_value)
+ expected = pa.array([1, 2, 5, 4], type=arrow_type)
+ assert result.equals(expected)
+
+ # Implicit conversions
+ result = arr.fill_null(5)
+ assert result.equals(expected)
+
+ # ARROW-9451: Unsigned integers allow this for some reason
+ if not pa.types.is_unsigned_integer(arr.type):
+ with pytest.raises((ValueError, TypeError)):
+ arr.fill_null('5')
+
+ result = arr.fill_null(pa.scalar(5, type='int8'))
+ assert result.equals(expected)
+
+
+@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
+def test_fill_null_chunked_array(arrow_type):
+ fill_value = pa.scalar(5, type=arrow_type)
+ arr = pa.chunked_array([pa.array([None, 2, 3, 4], type=arrow_type)])
+ result = arr.fill_null(fill_value)
+ expected = pa.chunked_array([pa.array([5, 2, 3, 4], type=arrow_type)])
+ assert result.equals(expected)
+
+ arr = pa.chunked_array([
+ pa.array([1, 2], type=arrow_type),
+ pa.array([], type=arrow_type),
+ pa.array([None, 4], type=arrow_type)
+ ])
+ expected = pa.chunked_array([
+ pa.array([1, 2], type=arrow_type),
+ pa.array([], type=arrow_type),
+ pa.array([5, 4], type=arrow_type)
+ ])
+ result = arr.fill_null(fill_value)
+ assert result.equals(expected)
+
+ # Implicit conversions
+ result = arr.fill_null(5)
+ assert result.equals(expected)
+
+ result = arr.fill_null(pa.scalar(5, type='int8'))
+ assert result.equals(expected)
+
+
+def test_logical():
+ a = pa.array([True, False, False, None])
+ b = pa.array([True, True, False, True])
+
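+ # and_/or_ propagate nulls, while the Kleene variants only return null
+ # when the result cannot be inferred from the non-null operand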
+ assert pc.and_(a, b) == pa.array([True, False, False, None])
+ assert pc.and_kleene(a, b) == pa.array([True, False, False, None])
+
+ assert pc.or_(a, b) == pa.array([True, True, False, None])
+ assert pc.or_kleene(a, b) == pa.array([True, True, False, True])
+
+ assert pc.xor(a, b) == pa.array([False, True, False, None])
+
+ assert pc.invert(a) == pa.array([False, True, True, None])
+
+
+def test_cast():
+ arr = pa.array([2 ** 63 - 1], type='int64')
+
+ with pytest.raises(pa.ArrowInvalid):
+ pc.cast(arr, 'int32')
+
+ assert pc.cast(arr, 'int32', safe=False) == pa.array([-1], type='int32')
+
+ arr = pa.array([datetime(2010, 1, 1), datetime(2015, 1, 1)])
+ expected = pa.array([1262304000000, 1420070400000], type='timestamp[ms]')
+ assert pc.cast(arr, 'timestamp[ms]') == expected
+
+ arr = pa.array([[1, 2], [3, 4, 5]], type=pa.large_list(pa.int8()))
+ expected = pa.array([["1", "2"], ["3", "4", "5"]],
+ type=pa.list_(pa.utf8()))
+ assert pc.cast(arr, expected.type) == expected
+
+
+def test_strptime():
+ arr = pa.array(["5/1/2020", None, "12/13/1900"])
+
+ got = pc.strptime(arr, format='%m/%d/%Y', unit='s')
+ expected = pa.array([datetime(2020, 5, 1), None, datetime(1900, 12, 13)],
+ type=pa.timestamp('s'))
+ assert got == expected
+
+
+# TODO: We should test on windows once ARROW-13168 is resolved.
+@pytest.mark.pandas
+@pytest.mark.skipif(sys.platform == 'win32',
+ reason="Timezone database is not available on Windows yet")
+def test_strftime():
+ from pyarrow.vendored.version import Version
+
+ def _fix_timestamp(s):
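+ # Older pandas formats missing values as the string "NaT"; map them
+ # back to real nulls so the comparison with Arrow's output works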
+ if Version(pd.__version__) < Version("1.0.0"):
+ return s.to_series().replace("NaT", pd.NaT)
+ else:
+ return s
+
+ times = ["2018-03-10 09:00", "2038-01-31 12:23", None]
+ timezones = ["CET", "UTC", "Europe/Ljubljana"]
+
+ formats = ["%a", "%A", "%w", "%d", "%b", "%B", "%m", "%y", "%Y", "%H",
+ "%I", "%p", "%M", "%z", "%Z", "%j", "%U", "%W", "%c", "%x",
+ "%X", "%%", "%G", "%V", "%u"]
+
+ for timezone in timezones:
+ ts = pd.to_datetime(times).tz_localize(timezone)
+ for unit in ["s", "ms", "us", "ns"]:
+ tsa = pa.array(ts, type=pa.timestamp(unit, timezone))
+ for fmt in formats:
+ options = pc.StrftimeOptions(fmt)
+ result = pc.strftime(tsa, options=options)
+ expected = pa.array(_fix_timestamp(ts.strftime(fmt)))
+ assert result.equals(expected)
+
+ fmt = "%Y-%m-%dT%H:%M:%S"
+
+ # Default format
+ tsa = pa.array(ts, type=pa.timestamp("s", timezone))
+ result = pc.strftime(tsa, options=pc.StrftimeOptions())
+ expected = pa.array(_fix_timestamp(ts.strftime(fmt)))
+ assert result.equals(expected)
+
+ # Default format plus timezone
+ tsa = pa.array(ts, type=pa.timestamp("s", timezone))
+ result = pc.strftime(tsa, options=pc.StrftimeOptions(fmt + "%Z"))
+ expected = pa.array(_fix_timestamp(ts.strftime(fmt + "%Z")))
+ assert result.equals(expected)
+
+ # Pandas %S is equivalent to %S in arrow for unit="s"
+ tsa = pa.array(ts, type=pa.timestamp("s", timezone))
+ options = pc.StrftimeOptions("%S")
+ result = pc.strftime(tsa, options=options)
+ expected = pa.array(_fix_timestamp(ts.strftime("%S")))
+ assert result.equals(expected)
+
+ # Pandas %S.%f is equivalent to %S in arrow for unit="us"
+ tsa = pa.array(ts, type=pa.timestamp("us", timezone))
+ options = pc.StrftimeOptions("%S")
+ result = pc.strftime(tsa, options=options)
+ expected = pa.array(_fix_timestamp(ts.strftime("%S.%f")))
+ assert result.equals(expected)
+
+ # Test setting locale
+ tsa = pa.array(ts, type=pa.timestamp("s", timezone))
+ options = pc.StrftimeOptions(fmt, locale="C")
+ result = pc.strftime(tsa, options=options)
+ expected = pa.array(_fix_timestamp(ts.strftime(fmt)))
+ assert result.equals(expected)
+
+ # Test timestamps without timezone
+ fmt = "%Y-%m-%dT%H:%M:%S"
+ ts = pd.to_datetime(times)
+ tsa = pa.array(ts, type=pa.timestamp("s"))
+ result = pc.strftime(tsa, options=pc.StrftimeOptions(fmt))
+ expected = pa.array(_fix_timestamp(ts.strftime(fmt)))
+
+ assert result.equals(expected)
+ with pytest.raises(pa.ArrowInvalid,
+ match="Timezone not present, cannot convert to string"):
+ pc.strftime(tsa, options=pc.StrftimeOptions(fmt + "%Z"))
+ with pytest.raises(pa.ArrowInvalid,
+ match="Timezone not present, cannot convert to string"):
+ pc.strftime(tsa, options=pc.StrftimeOptions(fmt + "%z"))
+
+
+def _check_datetime_components(timestamps, timezone=None):
+ from pyarrow.vendored.version import Version
+
+ ts = pd.to_datetime(timestamps).tz_localize(
+ "UTC").tz_convert(timezone).to_series()
+ tsa = pa.array(ts, pa.timestamp("ns", tz=timezone))
+
+ subseconds = ((ts.dt.microsecond * 10 ** 3 +
+ ts.dt.nanosecond) * 10 ** -9).round(9)
+ iso_calendar_fields = [
+ pa.field('iso_year', pa.int64()),
+ pa.field('iso_week', pa.int64()),
+ pa.field('iso_day_of_week', pa.int64())
+ ]
+
+ if Version(pd.__version__) < Version("1.1.0"):
+ # https://github.com/pandas-dev/pandas/issues/33206
+ iso_year = ts.map(lambda x: x.isocalendar()[0]).astype("int64")
+ iso_week = ts.map(lambda x: x.isocalendar()[1]).astype("int64")
+ iso_day = ts.map(lambda x: x.isocalendar()[2]).astype("int64")
+ else:
+ # Casting is required because pandas isocalendar returns int32
+ # while arrow isocalendar returns int64.
+ iso_year = ts.dt.isocalendar()["year"].astype("int64")
+ iso_week = ts.dt.isocalendar()["week"].astype("int64")
+ iso_day = ts.dt.isocalendar()["day"].astype("int64")
+
+ iso_calendar = pa.StructArray.from_arrays(
+ [iso_year, iso_week, iso_day],
+ fields=iso_calendar_fields)
+
+ assert pc.year(tsa).equals(pa.array(ts.dt.year))
+ assert pc.month(tsa).equals(pa.array(ts.dt.month))
+ assert pc.day(tsa).equals(pa.array(ts.dt.day))
+ assert pc.day_of_week(tsa).equals(pa.array(ts.dt.dayofweek))
+ assert pc.day_of_year(tsa).equals(pa.array(ts.dt.dayofyear))
+ assert pc.iso_year(tsa).equals(pa.array(iso_year))
+ assert pc.iso_week(tsa).equals(pa.array(iso_week))
+ assert pc.iso_calendar(tsa).equals(iso_calendar)
+ assert pc.quarter(tsa).equals(pa.array(ts.dt.quarter))
+ assert pc.hour(tsa).equals(pa.array(ts.dt.hour))
+ assert pc.minute(tsa).equals(pa.array(ts.dt.minute))
+ assert pc.second(tsa).equals(pa.array(ts.dt.second.values))
+ assert pc.millisecond(tsa).equals(pa.array(ts.dt.microsecond // 10 ** 3))
+ assert pc.microsecond(tsa).equals(pa.array(ts.dt.microsecond % 10 ** 3))
+ assert pc.nanosecond(tsa).equals(pa.array(ts.dt.nanosecond))
+ assert pc.subsecond(tsa).equals(pa.array(subseconds))
+
+ day_of_week_options = pc.DayOfWeekOptions(
+ count_from_zero=False, week_start=1)
+ assert pc.day_of_week(tsa, options=day_of_week_options).equals(
+ pa.array(ts.dt.dayofweek + 1))
+
+ week_options = pc.WeekOptions(
+ week_starts_monday=True, count_from_zero=False,
+ first_week_is_fully_in_year=False)
+ assert pc.week(tsa, options=week_options).equals(pa.array(iso_week))
+
+
+@pytest.mark.pandas
+def test_extract_datetime_components():
+ from pyarrow.vendored.version import Version
+
+ timestamps = ["1970-01-01T00:00:59.123456789",
+ "2000-02-29T23:23:23.999999999",
+ "2033-05-18T03:33:20.000000000",
+ "2020-01-01T01:05:05.001",
+ "2019-12-31T02:10:10.002",
+ "2019-12-30T03:15:15.003",
+ "2009-12-31T04:20:20.004132",
+ "2010-01-01T05:25:25.005321",
+ "2010-01-03T06:30:30.006163",
+ "2010-01-04T07:35:35",
+ "2006-01-01T08:40:40",
+ "2005-12-31T09:45:45",
+ "2008-12-28",
+ "2008-12-29",
+ "2012-01-01 01:02:03"]
+ timezones = ["UTC", "US/Central", "Asia/Kolkata",
+ "Etc/GMT-4", "Etc/GMT+4", "Australia/Broken_Hill"]
+
+ # Test timezone naive timestamp array
+ _check_datetime_components(timestamps)
+
+ # Test timezone aware timestamp array
+ if sys.platform == 'win32':
+ # TODO: We should test on windows once ARROW-13168 is resolved.
+ pytest.skip('Timezone database is not available on Windows yet')
+ elif Version(pd.__version__) < Version('1.0.0'):
+ pytest.skip('Pandas < 1.0 extracts time components incorrectly.')
+ else:
+ for timezone in timezones:
+ _check_datetime_components(timestamps, timezone)
+
+
+# TODO: We should test on windows once ARROW-13168 is resolved.
+@pytest.mark.pandas
+@pytest.mark.skipif(sys.platform == 'win32',
+ reason="Timezone database is not available on Windows yet")
+def test_assume_timezone():
+ from pyarrow.vendored.version import Version
+
+ ts_type = pa.timestamp("ns")
+ timestamps = pd.to_datetime(["1970-01-01T00:00:59.123456789",
+ "2000-02-29T23:23:23.999999999",
+ "2033-05-18T03:33:20.000000000",
+ "2020-01-01T01:05:05.001",
+ "2019-12-31T02:10:10.002",
+ "2019-12-30T03:15:15.003",
+ "2009-12-31T04:20:20.004132",
+ "2010-01-01T05:25:25.005321",
+ "2010-01-03T06:30:30.006163",
+ "2010-01-04T07:35:35",
+ "2006-01-01T08:40:40",
+ "2005-12-31T09:45:45",
+ "2008-12-28",
+ "2008-12-29",
+ "2012-01-01 01:02:03"])
+ nonexistent = pd.to_datetime(["2015-03-29 02:30:00",
+ "2015-03-29 03:30:00"])
+ ambiguous = pd.to_datetime(["2018-10-28 01:20:00",
+ "2018-10-28 02:36:00",
+ "2018-10-28 03:46:00"])
+ ambiguous_array = pa.array(ambiguous, type=ts_type)
+ nonexistent_array = pa.array(nonexistent, type=ts_type)
+
+ for timezone in ["UTC", "US/Central", "Asia/Kolkata"]:
+ options = pc.AssumeTimezoneOptions(timezone)
+ ta = pa.array(timestamps, type=ts_type)
+ expected = timestamps.tz_localize(timezone)
+ result = pc.assume_timezone(ta, options=options)
+ assert result.equals(pa.array(expected))
+
+ ta_zoned = pa.array(timestamps, type=pa.timestamp("ns", timezone))
+ with pytest.raises(pa.ArrowInvalid, match="already have a timezone:"):
+ pc.assume_timezone(ta_zoned, options=options)
+
+ invalid_options = pc.AssumeTimezoneOptions("Europe/Brusselsss")
+ with pytest.raises(ValueError, match="not found in timezone database"):
+ pc.assume_timezone(ta, options=invalid_options)
+
+ timezone = "Europe/Brussels"
+
+ # nonexistent parameter was introduced in Pandas 0.24.0
+ if Version(pd.__version__) >= Version("0.24.0"):
+ options_nonexistent_raise = pc.AssumeTimezoneOptions(timezone)
+ options_nonexistent_earliest = pc.AssumeTimezoneOptions(
+ timezone, ambiguous="raise", nonexistent="earliest")
+ options_nonexistent_latest = pc.AssumeTimezoneOptions(
+ timezone, ambiguous="raise", nonexistent="latest")
+
+ with pytest.raises(ValueError,
+ match="Timestamp doesn't exist in "
+ f"timezone '{timezone}'"):
+ pc.assume_timezone(nonexistent_array,
+ options=options_nonexistent_raise)
+
+ expected = pa.array(nonexistent.tz_localize(
+ timezone, nonexistent="shift_forward"))
+ result = pc.assume_timezone(
+ nonexistent_array, options=options_nonexistent_latest)
+ assert result.equals(expected)
+
+ expected = pa.array(nonexistent.tz_localize(
+ timezone, nonexistent="shift_backward"))
+ result = pc.assume_timezone(
+ nonexistent_array, options=options_nonexistent_earliest)
+ assert result.equals(expected)
+
+ options_ambiguous_raise = pc.AssumeTimezoneOptions(timezone)
+ options_ambiguous_latest = pc.AssumeTimezoneOptions(
+ timezone, ambiguous="latest", nonexistent="raise")
+ options_ambiguous_earliest = pc.AssumeTimezoneOptions(
+ timezone, ambiguous="earliest", nonexistent="raise")
+
+ with pytest.raises(ValueError,
+ match="Timestamp is ambiguous in "
+ f"timezone '{timezone}'"):
+ pc.assume_timezone(ambiguous_array, options=options_ambiguous_raise)
+
+ expected = ambiguous.tz_localize(timezone, ambiguous=[True, True, True])
+ result = pc.assume_timezone(
+ ambiguous_array, options=options_ambiguous_earliest)
+ assert result.equals(pa.array(expected))
+
+ expected = ambiguous.tz_localize(timezone, ambiguous=[False, False, False])
+ result = pc.assume_timezone(
+ ambiguous_array, options=options_ambiguous_latest)
+ assert result.equals(pa.array(expected))
+
+
+def test_count():
+ arr = pa.array([1, 2, 3, None, None])
+ assert pc.count(arr).as_py() == 3
+ assert pc.count(arr, mode='only_valid').as_py() == 3
+ assert pc.count(arr, mode='only_null').as_py() == 2
+ assert pc.count(arr, mode='all').as_py() == 5
+
+
+def test_index():
+ arr = pa.array([0, 1, None, 3, 4], type=pa.int64())
+ assert pc.index(arr, pa.scalar(0)).as_py() == 0
+ assert pc.index(arr, pa.scalar(2, type=pa.int8())).as_py() == -1
+ assert pc.index(arr, 4).as_py() == 4
+ assert arr.index(3, start=2).as_py() == 3
+ assert arr.index(None).as_py() == -1
+
+ arr = pa.chunked_array([[1, 2], [1, 3]], type=pa.int64())
+ assert arr.index(1).as_py() == 0
+ assert arr.index(1, start=2).as_py() == 2
+ assert arr.index(1, start=1, end=2).as_py() == -1
+
+
+def check_partition_nth(data, indices, pivot, null_placement):
+ indices = indices.to_pylist()
+ assert len(indices) == len(data)
+ assert sorted(indices) == list(range(len(data)))
+ until_pivot = [data[indices[i]] for i in range(pivot)]
+ after_pivot = [data[indices[i]] for i in range(pivot, len(data))]
+ p = data[indices[pivot]]
+ if p is None:
+ if null_placement == "at_start":
+ assert all(v is None for v in until_pivot)
+ else:
+ assert all(v is None for v in after_pivot)
+ else:
+ if null_placement == "at_start":
+ assert all(v is None or v <= p for v in until_pivot)
+ assert all(v >= p for v in after_pivot)
+ else:
+ assert all(v <= p for v in until_pivot)
+ assert all(v is None or v >= p for v in after_pivot)
+
+
+def test_partition_nth():
+ data = list(range(100, 140))
+ random.shuffle(data)
+ pivot = 10
+ indices = pc.partition_nth_indices(data, pivot=pivot)
+ check_partition_nth(data, indices, pivot, "at_end")
+
+
+def test_partition_nth_null_placement():
+ data = list(range(10)) + [None] * 10
+ random.shuffle(data)
+
+ for pivot in (0, 7, 13, 19):
+ for null_placement in ("at_start", "at_end"):
+ indices = pc.partition_nth_indices(data, pivot=pivot,
+ null_placement=null_placement)
+ check_partition_nth(data, indices, pivot, null_placement)
+
+
+def test_select_k_array():
+ def validate_select_k(select_k_indices, arr, order, stable_sort=False):
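+ # select_k_unstable gives no ordering guarantee among equal values, so
+ # compare the selected values rather than the indices unless the sort
+ # is known to be stable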
+ sorted_indices = pc.sort_indices(arr, sort_keys=[("dummy", order)])
+ head_k_indices = sorted_indices.slice(0, len(select_k_indices))
+ if stable_sort:
+ assert select_k_indices == head_k_indices
+ else:
+ expected = pc.take(arr, head_k_indices)
+ actual = pc.take(arr, select_k_indices)
+ assert actual == expected
+
+ arr = pa.array([1, 2, None, 0])
+ for k in [0, 2, 4]:
+ for order in ["descending", "ascending"]:
+ result = pc.select_k_unstable(
+ arr, k=k, sort_keys=[("dummy", order)])
+ validate_select_k(result, arr, order)
+
+ result = pc.top_k_unstable(arr, k=k)
+ validate_select_k(result, arr, "descending")
+
+ result = pc.bottom_k_unstable(arr, k=k)
+ validate_select_k(result, arr, "ascending")
+
+ result = pc.select_k_unstable(
+ arr, options=pc.SelectKOptions(
+ k=2, sort_keys=[("dummy", "descending")])
+ )
+ validate_select_k(result, arr, "descending")
+
+ result = pc.select_k_unstable(
+ arr, options=pc.SelectKOptions(k=2, sort_keys=[("dummy", "ascending")])
+ )
+ validate_select_k(result, arr, "ascending")
+
+
+def test_select_k_table():
+ def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False):
+ sorted_indices = pc.sort_indices(tbl, sort_keys=sort_keys)
+ head_k_indices = sorted_indices.slice(0, len(select_k_indices))
+ if stable_sort:
+ assert select_k_indices == head_k_indices
+ else:
+ expected = pc.take(tbl, head_k_indices)
+ actual = pc.take(tbl, select_k_indices)
+ assert actual == expected
+
+ table = pa.table({"a": [1, 2, 0], "b": [1, 0, 1]})
+ for k in [0, 2, 4]:
+ result = pc.select_k_unstable(
+ table, k=k, sort_keys=[("a", "ascending")])
+ validate_select_k(result, table, sort_keys=[("a", "ascending")])
+
+ result = pc.select_k_unstable(
+ table, k=k, sort_keys=[("a", "ascending"), ("b", "ascending")])
+ validate_select_k(
+ result, table, sort_keys=[("a", "ascending"), ("b", "ascending")])
+
+ result = pc.top_k_unstable(table, k=k, sort_keys=["a"])
+ validate_select_k(result, table, sort_keys=[("a", "descending")])
+
+ result = pc.bottom_k_unstable(table, k=k, sort_keys=["a", "b"])
+ validate_select_k(
+ result, table, sort_keys=[("a", "ascending"), ("b", "ascending")])
+
+ with pytest.raises(ValueError,
+ match="select_k_unstable requires a nonnegative `k`"):
+ pc.select_k_unstable(table)
+
+ with pytest.raises(ValueError,
+ match="select_k_unstable requires a "
+ "non-empty `sort_keys`"):
+ pc.select_k_unstable(table, k=2, sort_keys=[])
+
+ with pytest.raises(ValueError, match="not a valid sort order"):
+ pc.select_k_unstable(table, k=k, sort_keys=[("a", "nonscending")])
+
+ with pytest.raises(ValueError, match="Nonexistent sort key column"):
+ pc.select_k_unstable(table, k=k, sort_keys=[("unknown", "ascending")])
+
+
+def test_array_sort_indices():
+ arr = pa.array([1, 2, None, 0])
+ result = pc.array_sort_indices(arr)
+ assert result.to_pylist() == [3, 0, 1, 2]
+ result = pc.array_sort_indices(arr, order="ascending")
+ assert result.to_pylist() == [3, 0, 1, 2]
+ result = pc.array_sort_indices(arr, order="descending")
+ assert result.to_pylist() == [1, 0, 3, 2]
+ result = pc.array_sort_indices(arr, order="descending",
+ null_placement="at_start")
+ assert result.to_pylist() == [2, 1, 0, 3]
+
+ with pytest.raises(ValueError, match="not a valid sort order"):
+ pc.array_sort_indices(arr, order="nonscending")
+
+
+def test_sort_indices_array():
+ arr = pa.array([1, 2, None, 0])
+ result = pc.sort_indices(arr)
+ assert result.to_pylist() == [3, 0, 1, 2]
+ result = pc.sort_indices(arr, sort_keys=[("dummy", "ascending")])
+ assert result.to_pylist() == [3, 0, 1, 2]
+ result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")])
+ assert result.to_pylist() == [1, 0, 3, 2]
+ result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")],
+ null_placement="at_start")
+ assert result.to_pylist() == [2, 1, 0, 3]
+ # Using SortOptions
+ result = pc.sort_indices(
+ arr, options=pc.SortOptions(sort_keys=[("dummy", "descending")])
+ )
+ assert result.to_pylist() == [1, 0, 3, 2]
+ result = pc.sort_indices(
+ arr, options=pc.SortOptions(sort_keys=[("dummy", "descending")],
+ null_placement="at_start")
+ )
+ assert result.to_pylist() == [2, 1, 0, 3]
+
+
+def test_sort_indices_table():
+ table = pa.table({"a": [1, 1, None, 0], "b": [1, 0, 0, 1]})
+
+ result = pc.sort_indices(table, sort_keys=[("a", "ascending")])
+ assert result.to_pylist() == [3, 0, 1, 2]
+ result = pc.sort_indices(table, sort_keys=[("a", "ascending")],
+ null_placement="at_start")
+ assert result.to_pylist() == [2, 3, 0, 1]
+
+ result = pc.sort_indices(
+ table, sort_keys=[("a", "descending"), ("b", "ascending")]
+ )
+ assert result.to_pylist() == [1, 0, 3, 2]
+ result = pc.sort_indices(
+ table, sort_keys=[("a", "descending"), ("b", "ascending")],
+ null_placement="at_start"
+ )
+ assert result.to_pylist() == [2, 1, 0, 3]
+
+ with pytest.raises(ValueError, match="Must specify one or more sort keys"):
+ pc.sort_indices(table)
+
+ with pytest.raises(ValueError, match="Nonexistent sort key column"):
+ pc.sort_indices(table, sort_keys=[("unknown", "ascending")])
+
+ with pytest.raises(ValueError, match="not a valid sort order"):
+ pc.sort_indices(table, sort_keys=[("a", "nonscending")])
+
+
+def test_is_in():
+ arr = pa.array([1, 2, None, 1, 2, 3])
+
+ result = pc.is_in(arr, value_set=pa.array([1, 3, None]))
+ assert result.to_pylist() == [True, False, True, True, False, True]
+
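+ # With skip_nulls=True the null in value_set no longer matches input nulls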
+ result = pc.is_in(arr, value_set=pa.array([1, 3, None]), skip_nulls=True)
+ assert result.to_pylist() == [True, False, False, True, False, True]
+
+ result = pc.is_in(arr, value_set=pa.array([1, 3]))
+ assert result.to_pylist() == [True, False, False, True, False, True]
+
+ result = pc.is_in(arr, value_set=pa.array([1, 3]), skip_nulls=True)
+ assert result.to_pylist() == [True, False, False, True, False, True]
+
+
+def test_index_in():
+ arr = pa.array([1, 2, None, 1, 2, 3])
+
+ result = pc.index_in(arr, value_set=pa.array([1, 3, None]))
+ assert result.to_pylist() == [0, None, 2, 0, None, 1]
+
+ result = pc.index_in(arr, value_set=pa.array([1, 3, None]),
+ skip_nulls=True)
+ assert result.to_pylist() == [0, None, None, 0, None, 1]
+
+ result = pc.index_in(arr, value_set=pa.array([1, 3]))
+ assert result.to_pylist() == [0, None, None, 0, None, 1]
+
+ result = pc.index_in(arr, value_set=pa.array([1, 3]), skip_nulls=True)
+ assert result.to_pylist() == [0, None, None, 0, None, 1]
+
+
+def test_quantile():
+ arr = pa.array([1, 2, 3, 4])
+
+ result = pc.quantile(arr)
+ assert result.to_pylist() == [2.5]
+
+ result = pc.quantile(arr, interpolation='lower')
+ assert result.to_pylist() == [2]
+ result = pc.quantile(arr, interpolation='higher')
+ assert result.to_pylist() == [3]
+ result = pc.quantile(arr, interpolation='nearest')
+ assert result.to_pylist() == [3]
+ result = pc.quantile(arr, interpolation='midpoint')
+ assert result.to_pylist() == [2.5]
+ result = pc.quantile(arr, interpolation='linear')
+ assert result.to_pylist() == [2.5]
+
+ arr = pa.array([1, 2])
+
+ result = pc.quantile(arr, q=[0.25, 0.5, 0.75])
+ assert result.to_pylist() == [1.25, 1.5, 1.75]
+
+ result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='lower')
+ assert result.to_pylist() == [1, 1, 1]
+ result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='higher')
+ assert result.to_pylist() == [2, 2, 2]
+ result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='midpoint')
+ assert result.to_pylist() == [1.5, 1.5, 1.5]
+ result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='nearest')
+ assert result.to_pylist() == [1, 1, 2]
+ result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='linear')
+ assert result.to_pylist() == [1.25, 1.5, 1.75]
+
+ with pytest.raises(ValueError, match="Quantile must be between 0 and 1"):
+ pc.quantile(arr, q=1.1)
+ with pytest.raises(ValueError, match="not a valid quantile interpolation"):
+ pc.quantile(arr, interpolation='zzz')
+
+
+def test_tdigest():
+ arr = pa.array([1, 2, 3, 4])
+ result = pc.tdigest(arr)
+ assert result.to_pylist() == [2.5]
+
+ arr = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])
+ result = pc.tdigest(arr)
+ assert result.to_pylist() == [2.5]
+
+ arr = pa.array([1, 2, 3, 4])
+ result = pc.tdigest(arr, q=[0, 0.5, 1])
+ assert result.to_pylist() == [1, 2.5, 4]
+
+ arr = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])
+ result = pc.tdigest(arr, q=[0, 0.5, 1])
+ assert result.to_pylist() == [1, 2.5, 4]
+
+
+def test_fill_null_segfault():
+ # ARROW-12672
+ arr = pa.array([None], pa.bool_()).fill_null(False)
+ result = arr.cast(pa.int8())
+ assert result == pa.array([0], pa.int8())
+
+
+def test_min_max_element_wise():
+ arr1 = pa.array([1, 2, 3])
+ arr2 = pa.array([3, 1, 2])
+ arr3 = pa.array([2, 3, None])
+
+ result = pc.max_element_wise(arr1, arr2)
+ assert result == pa.array([3, 2, 3])
+ result = pc.min_element_wise(arr1, arr2)
+ assert result == pa.array([1, 1, 2])
+
+ result = pc.max_element_wise(arr1, arr2, arr3)
+ assert result == pa.array([3, 3, 3])
+ result = pc.min_element_wise(arr1, arr2, arr3)
+ assert result == pa.array([1, 1, 2])
+
+ # with specifying the option
+ result = pc.max_element_wise(arr1, arr3, skip_nulls=True)
+ assert result == pa.array([2, 3, 3])
+ result = pc.min_element_wise(arr1, arr3, skip_nulls=True)
+ assert result == pa.array([1, 2, 3])
+ result = pc.max_element_wise(
+ arr1, arr3, options=pc.ElementWiseAggregateOptions())
+ assert result == pa.array([2, 3, 3])
+ result = pc.min_element_wise(
+ arr1, arr3, options=pc.ElementWiseAggregateOptions())
+ assert result == pa.array([1, 2, 3])
+
+ # not skipping nulls
+ result = pc.max_element_wise(arr1, arr3, skip_nulls=False)
+ assert result == pa.array([2, 3, None])
+ result = pc.min_element_wise(arr1, arr3, skip_nulls=False)
+ assert result == pa.array([1, 2, None])
+
+
+def test_make_struct():
+ assert pc.make_struct(1, 'a').as_py() == {'0': 1, '1': 'a'}
+
+ assert pc.make_struct(1, 'a', field_names=['i', 's']).as_py() == {
+ 'i': 1, 's': 'a'}
+
+ assert pc.make_struct([1, 2, 3],
+ "a b c".split()) == pa.StructArray.from_arrays([
+ [1, 2, 3],
+ "a b c".split()], names='0 1'.split())
+
+ with pytest.raises(ValueError,
+ match="Array arguments must all be the same length"):
+ pc.make_struct([1, 2, 3, 4], "a b c".split())
+
+ with pytest.raises(ValueError, match="0 arguments but 2 field names"):
+ pc.make_struct(field_names=['one', 'two'])
+
+
+def test_case_when():
+ assert pc.case_when(pc.make_struct([True, False, None],
+ [False, True, None]),
+ [1, 2, 3],
+ [11, 12, 13]) == pa.array([1, 12, None])
+
+
+def test_list_element():
+ element_type = pa.struct([('a', pa.float64()), ('b', pa.int8())])
+ list_type = pa.list_(element_type)
+ l1 = [{'a': .4, 'b': 2}, None, {'a': .2, 'b': 4}, None, {'a': 5.6, 'b': 6}]
+ l2 = [None, {'a': .52, 'b': 3}, {'a': .7, 'b': 4}, None, {'a': .6, 'b': 8}]
+ lists = pa.array([l1, l2], list_type)
+
+ index = 1
+ result = pa.compute.list_element(lists, index)
+ expected = pa.array([None, {'a': 0.52, 'b': 3}], element_type)
+ assert result.equals(expected)
+
+ index = 4
+ result = pa.compute.list_element(lists, index)
+ expected = pa.array([{'a': 5.6, 'b': 6}, {'a': .6, 'b': 8}], element_type)
+ assert result.equals(expected)
+
+
+def test_count_distinct():
+ seed = datetime.now()
+ samples = [seed.replace(year=y) for y in range(1992, 2092)]
+ arr = pa.array(samples, pa.timestamp("ns"))
+ result = pa.compute.count_distinct(arr)
+ expected = pa.scalar(len(samples), type=pa.int64())
+ assert result.equals(expected)
+
+
+def test_count_distinct_options():
+ arr = pa.array([1, 2, 3, None, None])
+ assert pc.count_distinct(arr).as_py() == 3
+ assert pc.count_distinct(arr, mode='only_valid').as_py() == 3
+ assert pc.count_distinct(arr, mode='only_null').as_py() == 1
+ assert pc.count_distinct(arr, mode='all').as_py() == 4
diff --git a/src/arrow/python/pyarrow/tests/test_convert_builtin.py b/src/arrow/python/pyarrow/tests/test_convert_builtin.py
new file mode 100644
index 000000000..7a355390a
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_convert_builtin.py
@@ -0,0 +1,2309 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import collections
+import datetime
+import decimal
+import itertools
+import math
+import re
+
+import hypothesis as h
+import numpy as np
+import pytz
+import pytest
+
+from pyarrow.pandas_compat import _pandas_api # noqa
+import pyarrow as pa
+import pyarrow.tests.strategies as past
+
+
+int_type_pairs = [
+ (np.int8, pa.int8()),
+ (np.int16, pa.int16()),
+ (np.int32, pa.int32()),
+ (np.int64, pa.int64()),
+ (np.uint8, pa.uint8()),
+ (np.uint16, pa.uint16()),
+ (np.uint32, pa.uint32()),
+ (np.uint64, pa.uint64())]
+
+
+np_int_types, pa_int_types = zip(*int_type_pairs)
+
+
+class StrangeIterable:
+ def __init__(self, lst):
+ self.lst = lst
+
+ def __iter__(self):
+ return self.lst.__iter__()
+
+
+class MyInt:
+ def __init__(self, value):
+ self.value = value
+
+ def __int__(self):
+ return self.value
+
+
+class MyBrokenInt:
+ def __int__(self):
+ 1/0 # MARKER
+
+
+def check_struct_type(ty, expected):
+ """
+ Check that a struct type is as expected, without taking field order into account.
+ """
+ assert pa.types.is_struct(ty)
+ assert set(ty) == set(expected)
+
+
+def test_iterable_types():
+ arr1 = pa.array(StrangeIterable([0, 1, 2, 3]))
+ arr2 = pa.array((0, 1, 2, 3))
+
+ assert arr1.equals(arr2)
+
+
+def test_empty_iterable():
+ arr = pa.array(StrangeIterable([]))
+ assert len(arr) == 0
+ assert arr.null_count == 0
+ assert arr.type == pa.null()
+ assert arr.to_pylist() == []
+
+
+def test_limited_iterator_types():
+ arr1 = pa.array(iter(range(3)), type=pa.int64(), size=3)
+ arr2 = pa.array((0, 1, 2))
+ assert arr1.equals(arr2)
+
+
+def test_limited_iterator_size_overflow():
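+    # declared size smaller than the iterator: only the first `size` items are consumed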
+ arr1 = pa.array(iter(range(3)), type=pa.int64(), size=2)
+ arr2 = pa.array((0, 1))
+ assert arr1.equals(arr2)
+
+
+def test_limited_iterator_size_underflow():
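+    # declared size larger than the iterator: the result stops at the actual length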
+ arr1 = pa.array(iter(range(3)), type=pa.int64(), size=10)
+ arr2 = pa.array((0, 1, 2))
+ assert arr1.equals(arr2)
+
+
+def test_iterator_without_size():
+ expected = pa.array((0, 1, 2))
+ arr1 = pa.array(iter(range(3)))
+ assert arr1.equals(expected)
+ # Same with explicit type
+ arr1 = pa.array(iter(range(3)), type=pa.int64())
+ assert arr1.equals(expected)
+
+
+def test_infinite_iterator():
+ expected = pa.array((0, 1, 2))
+ arr1 = pa.array(itertools.count(0), size=3)
+ assert arr1.equals(expected)
+ # Same with explicit type
+ arr1 = pa.array(itertools.count(0), type=pa.int64(), size=3)
+ assert arr1.equals(expected)
+
+
+def _as_list(xs):
+ return xs
+
+
+def _as_tuple(xs):
+ return tuple(xs)
+
+
+def _as_deque(xs):
+ # deque is a sequence while neither tuple nor list
+ return collections.deque(xs)
+
+
+def _as_dict_values(xs):
+ # a dict values object is not a sequence, just a regular iterable
+ dct = {k: v for k, v in enumerate(xs)}
+ return dct.values()
+
+
+def _as_numpy_array(xs):
+ arr = np.empty(len(xs), dtype=object)
+ arr[:] = xs
+ return arr
+
+
+def _as_set(xs):
+ return set(xs)
+
+
+SEQUENCE_TYPES = [_as_list, _as_tuple, _as_numpy_array]
+ITERABLE_TYPES = [_as_set, _as_dict_values] + SEQUENCE_TYPES
+COLLECTIONS_TYPES = [_as_deque] + ITERABLE_TYPES
+
+parametrize_with_iterable_types = pytest.mark.parametrize(
+ "seq", ITERABLE_TYPES
+)
+
+parametrize_with_sequence_types = pytest.mark.parametrize(
+ "seq", SEQUENCE_TYPES
+)
+
+parametrize_with_collections_types = pytest.mark.parametrize(
+ "seq", COLLECTIONS_TYPES
+)
+
+
+@parametrize_with_collections_types
+def test_sequence_types(seq):
+ arr1 = pa.array(seq([1, 2, 3]))
+ arr2 = pa.array([1, 2, 3])
+
+ assert arr1.equals(arr2)
+
+
+@parametrize_with_iterable_types
+def test_nested_sequence_types(seq):
+ arr1 = pa.array([seq([1, 2, 3])])
+ arr2 = pa.array([[1, 2, 3]])
+
+ assert arr1.equals(arr2)
+
+
+@parametrize_with_sequence_types
+def test_sequence_boolean(seq):
+ expected = [True, None, False, None]
+ arr = pa.array(seq(expected))
+ assert len(arr) == 4
+ assert arr.null_count == 2
+ assert arr.type == pa.bool_()
+ assert arr.to_pylist() == expected
+
+
+@parametrize_with_sequence_types
+def test_sequence_numpy_boolean(seq):
+ expected = [np.bool_(True), None, np.bool_(False), None]
+ arr = pa.array(seq(expected))
+ assert arr.type == pa.bool_()
+ assert arr.to_pylist() == [True, None, False, None]
+
+
+@parametrize_with_sequence_types
+def test_sequence_mixed_numpy_python_bools(seq):
+ values = np.array([True, False])
+ arr = pa.array(seq([values[0], None, values[1], True, False]))
+ assert arr.type == pa.bool_()
+ assert arr.to_pylist() == [True, None, False, True, False]
+
+
+@parametrize_with_collections_types
+def test_empty_list(seq):
+ arr = pa.array(seq([]))
+ assert len(arr) == 0
+ assert arr.null_count == 0
+ assert arr.type == pa.null()
+ assert arr.to_pylist() == []
+
+
+@parametrize_with_sequence_types
+def test_nested_lists(seq):
+ data = [[], [1, 2], None]
+ arr = pa.array(seq(data))
+ assert len(arr) == 3
+ assert arr.null_count == 1
+ assert arr.type == pa.list_(pa.int64())
+ assert arr.to_pylist() == data
+ # With explicit type
+ arr = pa.array(seq(data), type=pa.list_(pa.int32()))
+ assert len(arr) == 3
+ assert arr.null_count == 1
+ assert arr.type == pa.list_(pa.int32())
+ assert arr.to_pylist() == data
+
+
+@parametrize_with_sequence_types
+def test_nested_large_lists(seq):
+ data = [[], [1, 2], None]
+ arr = pa.array(seq(data), type=pa.large_list(pa.int16()))
+ assert len(arr) == 3
+ assert arr.null_count == 1
+ assert arr.type == pa.large_list(pa.int16())
+ assert arr.to_pylist() == data
+
+
+@parametrize_with_collections_types
+def test_list_with_non_list(seq):
+ # List types don't accept non-sequences
+ with pytest.raises(TypeError):
+ pa.array(seq([[], [1, 2], 3]), type=pa.list_(pa.int64()))
+ with pytest.raises(TypeError):
+ pa.array(seq([[], [1, 2], 3]), type=pa.large_list(pa.int64()))
+
+
+@parametrize_with_sequence_types
+def test_nested_arrays(seq):
+ arr = pa.array(seq([np.array([], dtype=np.int64),
+ np.array([1, 2], dtype=np.int64), None]))
+ assert len(arr) == 3
+ assert arr.null_count == 1
+ assert arr.type == pa.list_(pa.int64())
+ assert arr.to_pylist() == [[], [1, 2], None]
+
+
+@parametrize_with_sequence_types
+def test_nested_fixed_size_list(seq):
+ # sequence of lists
+ data = [[1, 2], [3, None], None]
+ arr = pa.array(seq(data), type=pa.list_(pa.int64(), 2))
+ assert len(arr) == 3
+ assert arr.null_count == 1
+ assert arr.type == pa.list_(pa.int64(), 2)
+ assert arr.to_pylist() == data
+
+ # sequence of numpy arrays
+ data = [np.array([1, 2], dtype='int64'), np.array([3, 4], dtype='int64'),
+ None]
+ arr = pa.array(seq(data), type=pa.list_(pa.int64(), 2))
+ assert len(arr) == 3
+ assert arr.null_count == 1
+ assert arr.type == pa.list_(pa.int64(), 2)
+ assert arr.to_pylist() == [[1, 2], [3, 4], None]
+
+ # incorrect length of the lists or arrays
+ for data in [[[1, 2, 3]], [np.array([1, 2, 4], dtype='int64')]]:
+ with pytest.raises(
+ ValueError, match="Length of item not correct: expected 2"):
+ pa.array(seq(data), type=pa.list_(pa.int64(), 2))
+
+ # with list size of 0
+ data = [[], [], None]
+ arr = pa.array(seq(data), type=pa.list_(pa.int64(), 0))
+ assert len(arr) == 3
+ assert arr.null_count == 1
+ assert arr.type == pa.list_(pa.int64(), 0)
+ assert arr.to_pylist() == [[], [], None]
+
+
+@parametrize_with_sequence_types
+def test_sequence_all_none(seq):
+ arr = pa.array(seq([None, None]))
+ assert len(arr) == 2
+ assert arr.null_count == 2
+ assert arr.type == pa.null()
+ assert arr.to_pylist() == [None, None]
+
+
+@parametrize_with_sequence_types
+@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs)
+def test_sequence_integer(seq, np_scalar_pa_type):
+ np_scalar, pa_type = np_scalar_pa_type
+ expected = [1, None, 3, None,
+ np.iinfo(np_scalar).min, np.iinfo(np_scalar).max]
+ arr = pa.array(seq(expected), type=pa_type)
+ assert len(arr) == 6
+ assert arr.null_count == 2
+ assert arr.type == pa_type
+ assert arr.to_pylist() == expected
+
+
+@parametrize_with_collections_types
+@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs)
+def test_sequence_integer_np_nan(seq, np_scalar_pa_type):
+ # ARROW-2806: numpy.nan is a double value and thus should produce
+ # a double array.
+ _, pa_type = np_scalar_pa_type
+ with pytest.raises(ValueError):
+ pa.array(seq([np.nan]), type=pa_type, from_pandas=False)
+
+ arr = pa.array(seq([np.nan]), type=pa_type, from_pandas=True)
+ expected = [None]
+ assert len(arr) == 1
+ assert arr.null_count == 1
+ assert arr.type == pa_type
+ assert arr.to_pylist() == expected
+
+
+@parametrize_with_sequence_types
+@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs)
+def test_sequence_integer_nested_np_nan(seq, np_scalar_pa_type):
+ # ARROW-2806: numpy.nan is a double value and thus should produce
+ # a double array.
+ _, pa_type = np_scalar_pa_type
+ with pytest.raises(ValueError):
+ pa.array(seq([[np.nan]]), type=pa.list_(pa_type), from_pandas=False)
+
+ arr = pa.array(seq([[np.nan]]), type=pa.list_(pa_type), from_pandas=True)
+ expected = [[None]]
+ assert len(arr) == 1
+ assert arr.null_count == 0
+ assert arr.type == pa.list_(pa_type)
+ assert arr.to_pylist() == expected
+
+
+@parametrize_with_sequence_types
+def test_sequence_integer_inferred(seq):
+ expected = [1, None, 3, None]
+ arr = pa.array(seq(expected))
+ assert len(arr) == 4
+ assert arr.null_count == 2
+ assert arr.type == pa.int64()
+ assert arr.to_pylist() == expected
+
+
+@parametrize_with_sequence_types
+@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs)
+def test_sequence_numpy_integer(seq, np_scalar_pa_type):
+ np_scalar, pa_type = np_scalar_pa_type
+ expected = [np_scalar(1), None, np_scalar(3), None,
+ np_scalar(np.iinfo(np_scalar).min),
+ np_scalar(np.iinfo(np_scalar).max)]
+ arr = pa.array(seq(expected), type=pa_type)
+ assert len(arr) == 6
+ assert arr.null_count == 2
+ assert arr.type == pa_type
+ assert arr.to_pylist() == expected
+
+
+@parametrize_with_sequence_types
+@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs)
+def test_sequence_numpy_integer_inferred(seq, np_scalar_pa_type):
+ np_scalar, pa_type = np_scalar_pa_type
+ expected = [np_scalar(1), None, np_scalar(3), None]
+ expected += [np_scalar(np.iinfo(np_scalar).min),
+ np_scalar(np.iinfo(np_scalar).max)]
+ arr = pa.array(seq(expected))
+ assert len(arr) == 6
+ assert arr.null_count == 2
+ assert arr.type == pa_type
+ assert arr.to_pylist() == expected
+
+
+@parametrize_with_sequence_types
+def test_sequence_custom_integers(seq):
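+    # objects implementing __int__ are converted through the Python int protocol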
+ expected = [0, 42, 2**33 + 1, -2**63]
+ data = list(map(MyInt, expected))
+ arr = pa.array(seq(data), type=pa.int64())
+ assert arr.to_pylist() == expected
+
+
+@parametrize_with_collections_types
+def test_broken_integers(seq):
+ data = [MyBrokenInt()]
+ with pytest.raises(pa.ArrowInvalid, match="tried to convert to int"):
+ pa.array(seq(data), type=pa.int64())
+
+
+def test_numpy_scalars_mixed_type():
+ # ARROW-4324
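+    # int32 and float32 scalars promote to float64, since float32 cannot
+    # represent every int32 value exactly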
+ data = [np.int32(10), np.float32(0.5)]
+ arr = pa.array(data)
+ expected = pa.array([10, 0.5], type="float64")
+ assert arr.equals(expected)
+
+ # ARROW-9490
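+    # int8 and float32 promote to float32, which can represent every int8 value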
+ data = [np.int8(10), np.float32(0.5)]
+ arr = pa.array(data)
+ expected = pa.array([10, 0.5], type="float32")
+ assert arr.equals(expected)
+
+
+@pytest.mark.xfail(reason="Type inference for uint64 not implemented",
+ raises=OverflowError)
+def test_uint64_max_convert():
+ data = [0, np.iinfo(np.uint64).max]
+
+ arr = pa.array(data, type=pa.uint64())
+ expected = pa.array(np.array(data, dtype='uint64'))
+ assert arr.equals(expected)
+
+ arr_inferred = pa.array(data)
+ assert arr_inferred.equals(expected)
+
+
+@pytest.mark.parametrize("bits", [8, 16, 32, 64])
+def test_signed_integer_overflow(bits):
+ ty = getattr(pa, "int%d" % bits)()
+ # XXX ideally would always raise OverflowError
+ with pytest.raises((OverflowError, pa.ArrowInvalid)):
+ pa.array([2 ** (bits - 1)], ty)
+ with pytest.raises((OverflowError, pa.ArrowInvalid)):
+ pa.array([-2 ** (bits - 1) - 1], ty)
+
+
+@pytest.mark.parametrize("bits", [8, 16, 32, 64])
+def test_unsigned_integer_overflow(bits):
+ ty = getattr(pa, "uint%d" % bits)()
+ # XXX ideally would always raise OverflowError
+ with pytest.raises((OverflowError, pa.ArrowInvalid)):
+ pa.array([2 ** bits], ty)
+ with pytest.raises((OverflowError, pa.ArrowInvalid)):
+ pa.array([-1], ty)
+
+
+@parametrize_with_collections_types
+@pytest.mark.parametrize("typ", pa_int_types)
+def test_integer_from_string_error(seq, typ):
+ # ARROW-9451: pa.array(['1'], type=pa.uint32()) should not succeed
+ with pytest.raises(pa.ArrowInvalid):
+ pa.array(seq(['1']), type=typ)
+
+
+def test_convert_with_mask():
+ data = [1, 2, 3, 4, 5]
+ mask = np.array([False, True, False, False, True])
+
+ result = pa.array(data, mask=mask)
+ expected = pa.array([1, None, 3, 4, None])
+
+ assert result.equals(expected)
+
+ # Mask wrong length
+ with pytest.raises(ValueError):
+ pa.array(data, mask=mask[1:])
+
+
+def test_garbage_collection():
+ import gc
+
+ # Force the cyclic garbage collector to run
+ gc.collect()
+
+ bytes_before = pa.total_allocated_bytes()
+ pa.array([1, None, 3, None])
+ gc.collect()
+ assert pa.total_allocated_bytes() == bytes_before
+
+
+def test_sequence_double():
+ data = [1.5, 1., None, 2.5, None, None]
+ arr = pa.array(data)
+ assert len(arr) == 6
+ assert arr.null_count == 3
+ assert arr.type == pa.float64()
+ assert arr.to_pylist() == data
+
+
+def test_double_auto_coerce_from_integer():
+ # Done as part of ARROW-2814
+ data = [1.5, 1., None, 2.5, None, None]
+ arr = pa.array(data)
+
+ data2 = [1.5, 1, None, 2.5, None, None]
+ arr2 = pa.array(data2)
+
+ assert arr.equals(arr2)
+
+ data3 = [1, 1.5, None, 2.5, None, None]
+ arr3 = pa.array(data3)
+
+ data4 = [1., 1.5, None, 2.5, None, None]
+ arr4 = pa.array(data4)
+
+ assert arr3.equals(arr4)
+
+
+def test_double_integer_coerce_representable_range():
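+    # 2**53 is the largest integer magnitude exactly representable in a double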
+ valid_values = [1.5, 1, 2, None, 1 << 53, -(1 << 53)]
+ invalid_values = [1.5, 1, 2, None, (1 << 53) + 1]
+ invalid_values2 = [1.5, 1, 2, None, -((1 << 53) + 1)]
+
+ # it works
+ pa.array(valid_values)
+
+ # it fails
+ with pytest.raises(ValueError):
+ pa.array(invalid_values)
+
+ with pytest.raises(ValueError):
+ pa.array(invalid_values2)
+
+
+def test_float32_integer_coerce_representable_range():
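+    # 2**24 is the largest integer magnitude exactly representable in a float32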
+ f32 = np.float32
+ valid_values = [f32(1.5), 1 << 24, -(1 << 24)]
+ invalid_values = [f32(1.5), (1 << 24) + 1]
+ invalid_values2 = [f32(1.5), -((1 << 24) + 1)]
+
+ # it works
+ pa.array(valid_values, type=pa.float32())
+
+ # it fails
+ with pytest.raises(ValueError):
+ pa.array(invalid_values, type=pa.float32())
+
+ with pytest.raises(ValueError):
+ pa.array(invalid_values2, type=pa.float32())
+
+
+def test_mixed_sequence_errors():
+ with pytest.raises(ValueError, match="tried to convert to boolean"):
+ pa.array([True, 'foo'], type=pa.bool_())
+
+ with pytest.raises(ValueError, match="tried to convert to float32"):
+ pa.array([1.5, 'foo'], type=pa.float32())
+
+ with pytest.raises(ValueError, match="tried to convert to double"):
+ pa.array([1.5, 'foo'])
+
+
+@parametrize_with_sequence_types
+@pytest.mark.parametrize("np_scalar,pa_type", [
+ (np.float16, pa.float16()),
+ (np.float32, pa.float32()),
+ (np.float64, pa.float64())
+])
+@pytest.mark.parametrize("from_pandas", [True, False])
+def test_sequence_numpy_double(seq, np_scalar, pa_type, from_pandas):
+ data = [np_scalar(1.5), np_scalar(1), None, np_scalar(2.5), None, np.nan]
+ arr = pa.array(seq(data), from_pandas=from_pandas)
+ assert len(arr) == 6
+ if from_pandas:
+ assert arr.null_count == 3
+ else:
+ assert arr.null_count == 2
+ if from_pandas:
+ # The NaN is skipped in type inference, otherwise it forces a
+ # float64 promotion
+ assert arr.type == pa_type
+ else:
+ assert arr.type == pa.float64()
+
+ assert arr.to_pylist()[:4] == data[:4]
+ if from_pandas:
+ assert arr.to_pylist()[5] is None
+ else:
+ assert np.isnan(arr.to_pylist()[5])
+
+
+@pytest.mark.parametrize("from_pandas", [True, False])
+@pytest.mark.parametrize("inner_seq", [np.array, list])
+def test_ndarray_nested_numpy_double(from_pandas, inner_seq):
+ # ARROW-2806
+ data = np.array([
+ inner_seq([1., 2.]),
+ inner_seq([1., 2., 3.]),
+ inner_seq([np.nan]),
+ None
+ ], dtype=object)
+ arr = pa.array(data, from_pandas=from_pandas)
+ assert len(arr) == 4
+ assert arr.null_count == 1
+ assert arr.type == pa.list_(pa.float64())
+ if from_pandas:
+ assert arr.to_pylist() == [[1.0, 2.0], [1.0, 2.0, 3.0], [None], None]
+ else:
+ np.testing.assert_equal(arr.to_pylist(),
+ [[1., 2.], [1., 2., 3.], [np.nan], None])
+
+
+def test_nested_ndarray_in_object_array():
+ # ARROW-4350
+ arr = np.empty(2, dtype=object)
+ arr[:] = [np.array([1, 2], dtype=np.int64),
+ np.array([2, 3], dtype=np.int64)]
+
+ arr2 = np.empty(2, dtype=object)
+ arr2[0] = [3, 4]
+ arr2[1] = [5, 6]
+
+ expected_type = pa.list_(pa.list_(pa.int64()))
+ assert pa.infer_type([arr]) == expected_type
+
+ result = pa.array([arr, arr2])
+ expected = pa.array([[[1, 2], [2, 3]], [[3, 4], [5, 6]]],
+ type=expected_type)
+
+ assert result.equals(expected)
+
+ # test case for len-1 arrays to ensure they are interpreted as
+ # sublists and not scalars
+ arr = np.empty(2, dtype=object)
+ arr[:] = [np.array([1]), np.array([2])]
+ result = pa.array([arr, arr])
+ assert result.to_pylist() == [[[1], [2]], [[1], [2]]]
+
+
+@pytest.mark.xfail(reason=("Type inference for multidimensional ndarray "
+ "not yet implemented"),
+ raises=AssertionError)
+def test_multidimensional_ndarray_as_nested_list():
+ # TODO(wesm): see ARROW-5645
+ arr = np.array([[1, 2], [2, 3]], dtype=np.int64)
+ arr2 = np.array([[3, 4], [5, 6]], dtype=np.int64)
+
+ expected_type = pa.list_(pa.list_(pa.int64()))
+ assert pa.infer_type([arr]) == expected_type
+
+ result = pa.array([arr, arr2])
+ expected = pa.array([[[1, 2], [2, 3]], [[3, 4], [5, 6]]],
+ type=expected_type)
+
+ assert result.equals(expected)
+
+
+@pytest.mark.parametrize(('data', 'value_type'), [
+ ([True, False], pa.bool_()),
+ ([None, None], pa.null()),
+ ([1, 2, None], pa.int8()),
+ ([1, 2., 3., None], pa.float32()),
+ ([datetime.date.today(), None], pa.date32()),
+ ([None, datetime.date.today()], pa.date64()),
+ ([datetime.time(1, 1, 1), None], pa.time32('s')),
+ ([None, datetime.time(2, 2, 2)], pa.time64('us')),
+ ([datetime.datetime.now(), None], pa.timestamp('us')),
+ ([datetime.timedelta(seconds=10)], pa.duration('s')),
+ ([b"a", b"b"], pa.binary()),
+ ([b"aaa", b"bbb", b"ccc"], pa.binary(3)),
+ ([b"a", b"b", b"c"], pa.large_binary()),
+ (["a", "b", "c"], pa.string()),
+ (["a", "b", "c"], pa.large_string()),
+ (
+ [{"a": 1, "b": 2}, None, {"a": 5, "b": None}],
+ pa.struct([('a', pa.int8()), ('b', pa.int16())])
+ )
+])
+def test_list_array_from_object_ndarray(data, value_type):
+ ty = pa.list_(value_type)
+ ndarray = np.array(data, dtype=object)
+ arr = pa.array([ndarray], type=ty)
+ assert arr.type.equals(ty)
+ assert arr.to_pylist() == [data]
+
+
+@pytest.mark.parametrize(('data', 'value_type'), [
+ ([[1, 2], [3]], pa.list_(pa.int64())),
+ ([[1, 2], [3, 4]], pa.list_(pa.int64(), 2)),
+ ([[1], [2, 3]], pa.large_list(pa.int64()))
+])
+def test_nested_list_array_from_object_ndarray(data, value_type):
+ ndarray = np.empty(len(data), dtype=object)
+ ndarray[:] = [np.array(item, dtype=object) for item in data]
+
+ ty = pa.list_(value_type)
+ arr = pa.array([ndarray], type=ty)
+ assert arr.type.equals(ty)
+ assert arr.to_pylist() == [data]
+
+
+def test_array_ignore_nan_from_pandas():
+ # See ARROW-4324, this reverts logic that was introduced in
+ # ARROW-2240
+ with pytest.raises(ValueError):
+ pa.array([np.nan, 'str'])
+
+ arr = pa.array([np.nan, 'str'], from_pandas=True)
+ expected = pa.array([None, 'str'])
+ assert arr.equals(expected)
+
+
+def test_nested_ndarray_different_dtypes():
+ data = [
+ np.array([1, 2, 3], dtype='int64'),
+ None,
+ np.array([4, 5, 6], dtype='uint32')
+ ]
+
+ arr = pa.array(data)
+ expected = pa.array([[1, 2, 3], None, [4, 5, 6]],
+ type=pa.list_(pa.int64()))
+ assert arr.equals(expected)
+
+ t2 = pa.list_(pa.uint32())
+ arr2 = pa.array(data, type=t2)
+ expected2 = expected.cast(t2)
+ assert arr2.equals(expected2)
+
+
+def test_sequence_unicode():
+ data = ['foo', 'bar', None, 'mañana']
+ arr = pa.array(data)
+ assert len(arr) == 4
+ assert arr.null_count == 1
+ assert arr.type == pa.string()
+ assert arr.to_pylist() == data
+
+
+def check_array_mixed_unicode_bytes(binary_type, string_type):
+ values = ['qux', b'foo', bytearray(b'barz')]
+ b_values = [b'qux', b'foo', b'barz']
+ u_values = ['qux', 'foo', 'barz']
+
+ arr = pa.array(values)
+ expected = pa.array(b_values, type=pa.binary())
+ assert arr.type == pa.binary()
+ assert arr.equals(expected)
+
+ arr = pa.array(values, type=binary_type)
+ expected = pa.array(b_values, type=binary_type)
+ assert arr.type == binary_type
+ assert arr.equals(expected)
+
+ arr = pa.array(values, type=string_type)
+ expected = pa.array(u_values, type=string_type)
+ assert arr.type == string_type
+ assert arr.equals(expected)
+
+
+def test_array_mixed_unicode_bytes():
+ check_array_mixed_unicode_bytes(pa.binary(), pa.string())
+ check_array_mixed_unicode_bytes(pa.large_binary(), pa.large_string())
+
+
+@pytest.mark.large_memory
+@pytest.mark.parametrize("ty", [pa.large_binary(), pa.large_string()])
+def test_large_binary_array(ty):
+ # Construct a large binary array with more than 4GB of data
+ s = b"0123456789abcdefghijklmnopqrstuvwxyz" * 10
+ nrepeats = math.ceil((2**32 + 5) / len(s))
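+    # large_binary/large_string use 64-bit offsets, so more than 4GB of data
+    # fits in a single array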
+ data = [s] * nrepeats
+ arr = pa.array(data, type=ty)
+ assert isinstance(arr, pa.Array)
+ assert arr.type == ty
+ assert len(arr) == nrepeats
+
+
+@pytest.mark.slow
+@pytest.mark.large_memory
+@pytest.mark.parametrize("ty", [pa.large_binary(), pa.large_string()])
+def test_large_binary_value(ty):
+ # Construct a large binary array with a single value larger than 4GB
+ s = b"0123456789abcdefghijklmnopqrstuvwxyz"
+ nrepeats = math.ceil((2**32 + 5) / len(s))
+ arr = pa.array([b"foo", s * nrepeats, None, b"bar"], type=ty)
+ assert isinstance(arr, pa.Array)
+ assert arr.type == ty
+ assert len(arr) == 4
+ buf = arr[1].as_buffer()
+ assert len(buf) == len(s) * nrepeats
+
+
+@pytest.mark.large_memory
+@pytest.mark.parametrize("ty", [pa.binary(), pa.string()])
+def test_string_too_large(ty):
+ # Construct a binary array with a single value larger than 4GB
+ s = b"0123456789abcdefghijklmnopqrstuvwxyz"
+ nrepeats = math.ceil((2**32 + 5) / len(s))
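+    # plain binary/string arrays use 32-bit offsets, so a single value this
+    # large cannot be stored and conversion raises ArrowCapacityError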
+ with pytest.raises(pa.ArrowCapacityError):
+ pa.array([b"foo", s * nrepeats, None, b"bar"], type=ty)
+
+
+def test_sequence_bytes():
+ u1 = b'ma\xc3\xb1ana'
+
+ data = [b'foo',
+ memoryview(b'dada'),
+ memoryview(b'd-a-t-a')[::2], # non-contiguous is made contiguous
+            u1.decode('utf-8'),  # unicode gets encoded to utf-8 bytes
+ bytearray(b'bar'),
+ None]
+ for ty in [None, pa.binary(), pa.large_binary()]:
+ arr = pa.array(data, type=ty)
+ assert len(arr) == 6
+ assert arr.null_count == 1
+        assert arr.type == (ty or pa.binary())
+ assert arr.to_pylist() == [b'foo', b'dada', b'data', u1, b'bar', None]
+
+
+@pytest.mark.parametrize("ty", [pa.string(), pa.large_string()])
+def test_sequence_utf8_to_unicode(ty):
+ # ARROW-1225
+ data = [b'foo', None, b'bar']
+ arr = pa.array(data, type=ty)
+ assert arr.type == ty
+ assert arr[0].as_py() == 'foo'
+
+    # test that a non-utf8 byte string is rejected
+ val = ('mañana').encode('utf-16-le')
+ with pytest.raises(pa.ArrowInvalid):
+ pa.array([val], type=ty)
+
+
+def test_sequence_fixed_size_bytes():
+ data = [b'foof', None, bytearray(b'barb'), b'2346']
+ arr = pa.array(data, type=pa.binary(4))
+ assert len(arr) == 4
+ assert arr.null_count == 1
+ assert arr.type == pa.binary(4)
+ assert arr.to_pylist() == [b'foof', None, b'barb', b'2346']
+
+
+def test_fixed_size_bytes_does_not_accept_varying_lengths():
+ data = [b'foo', None, b'barb', b'2346']
+ with pytest.raises(pa.ArrowInvalid):
+ pa.array(data, type=pa.binary(4))
+
+
+def test_fixed_size_binary_length_check():
+ # ARROW-10193
+ data = [b'\x19h\r\x9e\x00\x00\x00\x00\x01\x9b\x9fA']
+ assert len(data[0]) == 12
+ ty = pa.binary(12)
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == data
+
+
+def test_sequence_date():
+ data = [datetime.date(2000, 1, 1), None, datetime.date(1970, 1, 1),
+ datetime.date(2040, 2, 26)]
+ arr = pa.array(data)
+ assert len(arr) == 4
+ assert arr.type == pa.date32()
+ assert arr.null_count == 1
+ assert arr[0].as_py() == datetime.date(2000, 1, 1)
+ assert arr[1].as_py() is None
+ assert arr[2].as_py() == datetime.date(1970, 1, 1)
+ assert arr[3].as_py() == datetime.date(2040, 2, 26)
+
+
+@pytest.mark.parametrize('input',
+ [(pa.date32(), [10957, None]),
+ (pa.date64(), [10957 * 86400000, None])])
+def test_sequence_explicit_types(input):
+ t, ex_values = input
+ data = [datetime.date(2000, 1, 1), None]
+ arr = pa.array(data, type=t)
+ arr2 = pa.array(ex_values, type=t)
+
+ for x in [arr, arr2]:
+ assert len(x) == 2
+ assert x.type == t
+ assert x.null_count == 1
+ assert x[0].as_py() == datetime.date(2000, 1, 1)
+ assert x[1].as_py() is None
+
+
+def test_date32_overflow():
+ # Overflow
+ data3 = [2**32, None]
+ with pytest.raises((OverflowError, pa.ArrowException)):
+ pa.array(data3, type=pa.date32())
+
+
+@pytest.mark.parametrize(('time_type', 'unit', 'int_type'), [
+ (pa.time32, 's', 'int32'),
+ (pa.time32, 'ms', 'int32'),
+ (pa.time64, 'us', 'int64'),
+ (pa.time64, 'ns', 'int64'),
+])
+def test_sequence_time_with_timezone(time_type, unit, int_type):
+ def expected_integer_value(t):
+        # only use with timezone-naive time objects: the value is not
+        # adjusted for any UTC offset
+ units = ['s', 'ms', 'us', 'ns']
+ multiplier = 10**(units.index(unit) * 3)
+ if t is None:
+ return None
+ seconds = (
+ t.hour * 3600 +
+ t.minute * 60 +
+ t.second +
+ t.microsecond * 10**-6
+ )
+ return int(seconds * multiplier)
+
+ def expected_time_value(t):
+        # only use with timezone-naive time objects: the value is not
+        # adjusted for the object's tzinfo
+ if unit == 's':
+ return t.replace(microsecond=0)
+ elif unit == 'ms':
+ return t.replace(microsecond=(t.microsecond // 1000) * 1000)
+ else:
+ return t
+
+    # only timezone-naive times are supported in Arrow
+ data = [
+ datetime.time(8, 23, 34, 123456),
+ datetime.time(5, 0, 0, 1000),
+ None,
+ datetime.time(1, 11, 56, 432539),
+ datetime.time(23, 10, 0, 437699)
+ ]
+
+ ty = time_type(unit)
+ arr = pa.array(data, type=ty)
+ assert len(arr) == 5
+ assert arr.type == ty
+ assert arr.null_count == 1
+
+    # test that the underlying integers encode the time of day in the given unit
+ values = arr.cast(int_type)
+ expected = list(map(expected_integer_value, data))
+ assert values.to_pylist() == expected
+
+    # test that the scalars are timezone-naive datetime.time objects
+ assert arr[0].as_py() == expected_time_value(data[0])
+ assert arr[1].as_py() == expected_time_value(data[1])
+ assert arr[2].as_py() is None
+ assert arr[3].as_py() == expected_time_value(data[3])
+ assert arr[4].as_py() == expected_time_value(data[4])
+
+
+def test_sequence_timestamp():
+ data = [
+ datetime.datetime(2007, 7, 13, 1, 23, 34, 123456),
+ None,
+ datetime.datetime(2006, 1, 13, 12, 34, 56, 432539),
+ datetime.datetime(2010, 8, 13, 5, 46, 57, 437699)
+ ]
+ arr = pa.array(data)
+ assert len(arr) == 4
+ assert arr.type == pa.timestamp('us')
+ assert arr.null_count == 1
+ assert arr[0].as_py() == datetime.datetime(2007, 7, 13, 1,
+ 23, 34, 123456)
+ assert arr[1].as_py() is None
+ assert arr[2].as_py() == datetime.datetime(2006, 1, 13, 12,
+ 34, 56, 432539)
+ assert arr[3].as_py() == datetime.datetime(2010, 8, 13, 5,
+ 46, 57, 437699)
+
+
+@pytest.mark.parametrize('timezone', [
+ None,
+ 'UTC',
+ 'Etc/GMT-1',
+ 'Europe/Budapest',
+])
+@pytest.mark.parametrize('unit', [
+ 's',
+ 'ms',
+ 'us',
+ 'ns'
+])
+def test_sequence_timestamp_with_timezone(timezone, unit):
+ def expected_integer_value(dt):
+ units = ['s', 'ms', 'us', 'ns']
+ multiplier = 10**(units.index(unit) * 3)
+ if dt is None:
+ return None
+ else:
+ # avoid float precision issues
+ ts = decimal.Decimal(str(dt.timestamp()))
+ return int(ts * multiplier)
+
+ def expected_datetime_value(dt):
+ if dt is None:
+ return None
+
+ if unit == 's':
+ dt = dt.replace(microsecond=0)
+ elif unit == 'ms':
+ dt = dt.replace(microsecond=(dt.microsecond // 1000) * 1000)
+
+ # adjust the timezone
+ if timezone is None:
+ # make datetime timezone unaware
+ return dt.replace(tzinfo=None)
+ else:
+ # convert to the expected timezone
+ return dt.astimezone(pytz.timezone(timezone))
+
+ data = [
+ datetime.datetime(2007, 7, 13, 8, 23, 34, 123456), # naive
+ pytz.utc.localize(
+ datetime.datetime(2008, 1, 5, 5, 0, 0, 1000)
+ ),
+ None,
+ pytz.timezone('US/Eastern').localize(
+ datetime.datetime(2006, 1, 13, 12, 34, 56, 432539)
+ ),
+ pytz.timezone('Europe/Moscow').localize(
+ datetime.datetime(2010, 8, 13, 5, 0, 0, 437699)
+ ),
+ ]
+ utcdata = [
+ pytz.utc.localize(data[0]),
+ data[1],
+ None,
+ data[3].astimezone(pytz.utc),
+ data[4].astimezone(pytz.utc),
+ ]
+
+ ty = pa.timestamp(unit, tz=timezone)
+ arr = pa.array(data, type=ty)
+ assert len(arr) == 5
+ assert arr.type == ty
+ assert arr.null_count == 1
+
+ # test that the underlying integers are UTC values
+ values = arr.cast('int64')
+ expected = list(map(expected_integer_value, utcdata))
+ assert values.to_pylist() == expected
+
+ # test that the scalars are datetimes with the correct timezone
+ for i in range(len(arr)):
+ assert arr[i].as_py() == expected_datetime_value(utcdata[i])
+
+
+@pytest.mark.parametrize('timezone', [
+ None,
+ 'UTC',
+ 'Etc/GMT-1',
+ 'Europe/Budapest',
+])
+def test_pyarrow_ignore_timezone_environment_variable(monkeypatch, timezone):
+ # note that any non-empty value will evaluate to true
+ monkeypatch.setenv("PYARROW_IGNORE_TIMEZONE", "1")
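+    # the input values' tzinfo is ignored: naive wall-clock values are treated
+    # as UTC and then presented in the column's timezone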
+ data = [
+ datetime.datetime(2007, 7, 13, 8, 23, 34, 123456), # naive
+ pytz.utc.localize(
+ datetime.datetime(2008, 1, 5, 5, 0, 0, 1000)
+ ),
+ pytz.timezone('US/Eastern').localize(
+ datetime.datetime(2006, 1, 13, 12, 34, 56, 432539)
+ ),
+ pytz.timezone('Europe/Moscow').localize(
+ datetime.datetime(2010, 8, 13, 5, 0, 0, 437699)
+ ),
+ ]
+
+ expected = [dt.replace(tzinfo=None) for dt in data]
+ if timezone is not None:
+ tzinfo = pytz.timezone(timezone)
+ expected = [tzinfo.fromutc(dt) for dt in expected]
+
+ ty = pa.timestamp('us', tz=timezone)
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == expected
+
+
+def test_sequence_timestamp_with_timezone_inference():
+ data = [
+ datetime.datetime(2007, 7, 13, 8, 23, 34, 123456), # naive
+ pytz.utc.localize(
+ datetime.datetime(2008, 1, 5, 5, 0, 0, 1000)
+ ),
+ None,
+ pytz.timezone('US/Eastern').localize(
+ datetime.datetime(2006, 1, 13, 12, 34, 56, 432539)
+ ),
+ pytz.timezone('Europe/Moscow').localize(
+ datetime.datetime(2010, 8, 13, 5, 0, 0, 437699)
+ ),
+ ]
+ expected = [
+ pa.timestamp('us', tz=None),
+ pa.timestamp('us', tz='UTC'),
+ pa.timestamp('us', tz=None),
+ pa.timestamp('us', tz='US/Eastern'),
+ pa.timestamp('us', tz='Europe/Moscow')
+ ]
+ for dt, expected_type in zip(data, expected):
+ prepended = [dt] + data
+ arr = pa.array(prepended)
+ assert arr.type == expected_type
+
+
+@pytest.mark.pandas
+def test_sequence_timestamp_from_mixed_builtin_and_pandas_datetimes():
+ import pandas as pd
+
+ data = [
+ pd.Timestamp(1184307814123456123, tz=pytz.timezone('US/Eastern'),
+ unit='ns'),
+ datetime.datetime(2007, 7, 13, 8, 23, 34, 123456), # naive
+ pytz.utc.localize(
+ datetime.datetime(2008, 1, 5, 5, 0, 0, 1000)
+ ),
+ None,
+ ]
+ utcdata = [
+ data[0].astimezone(pytz.utc),
+ pytz.utc.localize(data[1]),
+ data[2].astimezone(pytz.utc),
+ None,
+ ]
+
+ arr = pa.array(data)
+ assert arr.type == pa.timestamp('us', tz='US/Eastern')
+
+ values = arr.cast('int64')
+ expected = [int(dt.timestamp() * 10**6) if dt else None for dt in utcdata]
+ assert values.to_pylist() == expected
+
+
+def test_sequence_timestamp_out_of_bounds_nanosecond():
+ # https://issues.apache.org/jira/browse/ARROW-9768
+ # datetime outside of range supported for nanosecond resolution
+ data = [datetime.datetime(2262, 4, 12)]
+ with pytest.raises(ValueError, match="out of bounds"):
+ pa.array(data, type=pa.timestamp('ns'))
+
+ # with microsecond resolution it works fine
+ arr = pa.array(data, type=pa.timestamp('us'))
+ assert arr.to_pylist() == data
+
+    # case where the naive value is within bounds, but its UTC-converted value is not
+ tz = datetime.timezone(datetime.timedelta(hours=-1))
+ data = [datetime.datetime(2262, 4, 11, 23, tzinfo=tz)]
+ with pytest.raises(ValueError, match="out of bounds"):
+ pa.array(data, type=pa.timestamp('ns'))
+
+ arr = pa.array(data, type=pa.timestamp('us'))
+ assert arr.to_pylist()[0] == datetime.datetime(2262, 4, 12)
+
+
+def test_sequence_numpy_timestamp():
+ data = [
+ np.datetime64(datetime.datetime(2007, 7, 13, 1, 23, 34, 123456)),
+ None,
+ np.datetime64(datetime.datetime(2006, 1, 13, 12, 34, 56, 432539)),
+ np.datetime64(datetime.datetime(2010, 8, 13, 5, 46, 57, 437699))
+ ]
+ arr = pa.array(data)
+ assert len(arr) == 4
+ assert arr.type == pa.timestamp('us')
+ assert arr.null_count == 1
+ assert arr[0].as_py() == datetime.datetime(2007, 7, 13, 1,
+ 23, 34, 123456)
+ assert arr[1].as_py() is None
+ assert arr[2].as_py() == datetime.datetime(2006, 1, 13, 12,
+ 34, 56, 432539)
+ assert arr[3].as_py() == datetime.datetime(2010, 8, 13, 5,
+ 46, 57, 437699)
+
+
+class MyDate(datetime.date):
+ pass
+
+
+class MyDatetime(datetime.datetime):
+ pass
+
+
+class MyTimedelta(datetime.timedelta):
+ pass
+
+
+def test_datetime_subclassing():
+ data = [
+ MyDate(2007, 7, 13),
+ ]
+ date_type = pa.date32()
+ arr_date = pa.array(data, type=date_type)
+ assert len(arr_date) == 1
+ assert arr_date.type == date_type
+ assert arr_date[0].as_py() == datetime.date(2007, 7, 13)
+
+ data = [
+ MyDatetime(2007, 7, 13, 1, 23, 34, 123456),
+ ]
+
+ s = pa.timestamp('s')
+ ms = pa.timestamp('ms')
+ us = pa.timestamp('us')
+
+ arr_s = pa.array(data, type=s)
+ assert len(arr_s) == 1
+ assert arr_s.type == s
+ assert arr_s[0].as_py() == datetime.datetime(2007, 7, 13, 1,
+ 23, 34, 0)
+
+ arr_ms = pa.array(data, type=ms)
+ assert len(arr_ms) == 1
+ assert arr_ms.type == ms
+ assert arr_ms[0].as_py() == datetime.datetime(2007, 7, 13, 1,
+ 23, 34, 123000)
+
+ arr_us = pa.array(data, type=us)
+ assert len(arr_us) == 1
+ assert arr_us.type == us
+ assert arr_us[0].as_py() == datetime.datetime(2007, 7, 13, 1,
+ 23, 34, 123456)
+
+ data = [
+ MyTimedelta(123, 456, 1002),
+ ]
+
+ s = pa.duration('s')
+ ms = pa.duration('ms')
+ us = pa.duration('us')
+
+ arr_s = pa.array(data)
+ assert len(arr_s) == 1
+ assert arr_s.type == us
+ assert arr_s[0].as_py() == datetime.timedelta(123, 456, 1002)
+
+ arr_s = pa.array(data, type=s)
+ assert len(arr_s) == 1
+ assert arr_s.type == s
+ assert arr_s[0].as_py() == datetime.timedelta(123, 456)
+
+ arr_ms = pa.array(data, type=ms)
+ assert len(arr_ms) == 1
+ assert arr_ms.type == ms
+ assert arr_ms[0].as_py() == datetime.timedelta(123, 456, 1000)
+
+ arr_us = pa.array(data, type=us)
+ assert len(arr_us) == 1
+ assert arr_us.type == us
+ assert arr_us[0].as_py() == datetime.timedelta(123, 456, 1002)
+
+
+@pytest.mark.xfail(not _pandas_api.have_pandas,
+ reason="pandas required for nanosecond conversion")
+def test_sequence_timestamp_nanoseconds():
+ inputs = [
+ [datetime.datetime(2007, 7, 13, 1, 23, 34, 123456)],
+ [MyDatetime(2007, 7, 13, 1, 23, 34, 123456)]
+ ]
+
+ for data in inputs:
+ ns = pa.timestamp('ns')
+ arr_ns = pa.array(data, type=ns)
+ assert len(arr_ns) == 1
+ assert arr_ns.type == ns
+ assert arr_ns[0].as_py() == datetime.datetime(2007, 7, 13, 1,
+ 23, 34, 123456)
+
+
+@pytest.mark.pandas
+def test_sequence_timestamp_from_int_with_unit():
+ # TODO(wesm): This test might be rewritten to assert the actual behavior
+ # when pandas is not installed
+
+ data = [1]
+
+ s = pa.timestamp('s')
+ ms = pa.timestamp('ms')
+ us = pa.timestamp('us')
+ ns = pa.timestamp('ns')
+
+ arr_s = pa.array(data, type=s)
+ assert len(arr_s) == 1
+ assert arr_s.type == s
+ assert repr(arr_s[0]) == (
+ "<pyarrow.TimestampScalar: datetime.datetime(1970, 1, 1, 0, 0, 1)>"
+ )
+ assert str(arr_s[0]) == "1970-01-01 00:00:01"
+
+ arr_ms = pa.array(data, type=ms)
+ assert len(arr_ms) == 1
+ assert arr_ms.type == ms
+ assert repr(arr_ms[0].as_py()) == (
+ "datetime.datetime(1970, 1, 1, 0, 0, 0, 1000)"
+ )
+ assert str(arr_ms[0]) == "1970-01-01 00:00:00.001000"
+
+ arr_us = pa.array(data, type=us)
+ assert len(arr_us) == 1
+ assert arr_us.type == us
+ assert repr(arr_us[0].as_py()) == (
+ "datetime.datetime(1970, 1, 1, 0, 0, 0, 1)"
+ )
+ assert str(arr_us[0]) == "1970-01-01 00:00:00.000001"
+
+ arr_ns = pa.array(data, type=ns)
+ assert len(arr_ns) == 1
+ assert arr_ns.type == ns
+ assert repr(arr_ns[0].as_py()) == (
+ "Timestamp('1970-01-01 00:00:00.000000001')"
+ )
+ assert str(arr_ns[0]) == "1970-01-01 00:00:00.000000001"
+
+ expected_exc = TypeError
+
+ class CustomClass():
+ pass
+
+ for ty in [ns, pa.date32(), pa.date64()]:
+ with pytest.raises(expected_exc):
+ pa.array([1, CustomClass()], type=ty)
+
+
+@pytest.mark.parametrize('np_scalar', [True, False])
+def test_sequence_duration(np_scalar):
+ td1 = datetime.timedelta(2, 3601, 1)
+ td2 = datetime.timedelta(1, 100, 1000)
+ if np_scalar:
+ data = [np.timedelta64(td1), None, np.timedelta64(td2)]
+ else:
+ data = [td1, None, td2]
+
+ arr = pa.array(data)
+ assert len(arr) == 3
+ assert arr.type == pa.duration('us')
+ assert arr.null_count == 1
+ assert arr[0].as_py() == td1
+ assert arr[1].as_py() is None
+ assert arr[2].as_py() == td2
+
+
+@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns'])
+def test_sequence_duration_with_unit(unit):
+ data = [
+ datetime.timedelta(3, 22, 1001),
+ ]
+ expected = {'s': datetime.timedelta(3, 22),
+ 'ms': datetime.timedelta(3, 22, 1000),
+ 'us': datetime.timedelta(3, 22, 1001),
+ 'ns': datetime.timedelta(3, 22, 1001)}
+
+ ty = pa.duration(unit)
+
+ arr_s = pa.array(data, type=ty)
+ assert len(arr_s) == 1
+ assert arr_s.type == ty
+ assert arr_s[0].as_py() == expected[unit]
+
+
+@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns'])
+def test_sequence_duration_from_int_with_unit(unit):
+ data = [5]
+
+ ty = pa.duration(unit)
+ arr = pa.array(data, type=ty)
+ assert len(arr) == 1
+ assert arr.type == ty
+ assert arr[0].value == 5
+
+
+def test_sequence_duration_nested_lists():
+ td1 = datetime.timedelta(1, 1, 1000)
+ td2 = datetime.timedelta(1, 100)
+
+ data = [[td1, None], [td1, td2]]
+
+ arr = pa.array(data)
+ assert len(arr) == 2
+ assert arr.type == pa.list_(pa.duration('us'))
+ assert arr.to_pylist() == data
+
+ arr = pa.array(data, type=pa.list_(pa.duration('ms')))
+ assert len(arr) == 2
+ assert arr.type == pa.list_(pa.duration('ms'))
+ assert arr.to_pylist() == data
+
+
+def test_sequence_duration_nested_lists_numpy():
+ td1 = datetime.timedelta(1, 1, 1000)
+ td2 = datetime.timedelta(1, 100)
+
+ data = [[np.timedelta64(td1), None],
+ [np.timedelta64(td1), np.timedelta64(td2)]]
+
+ arr = pa.array(data)
+ assert len(arr) == 2
+ assert arr.type == pa.list_(pa.duration('us'))
+ assert arr.to_pylist() == [[td1, None], [td1, td2]]
+
+ data = [np.array([np.timedelta64(td1), None], dtype='timedelta64[us]'),
+ np.array([np.timedelta64(td1), np.timedelta64(td2)])]
+
+ arr = pa.array(data)
+ assert len(arr) == 2
+ assert arr.type == pa.list_(pa.duration('us'))
+ assert arr.to_pylist() == [[td1, None], [td1, td2]]
+
+
+def test_sequence_nesting_levels():
+ data = [1, 2, None]
+ arr = pa.array(data)
+ assert arr.type == pa.int64()
+ assert arr.to_pylist() == data
+
+ data = [[1], [2], None]
+ arr = pa.array(data)
+ assert arr.type == pa.list_(pa.int64())
+ assert arr.to_pylist() == data
+
+ data = [[1], [2, 3, 4], [None]]
+ arr = pa.array(data)
+ assert arr.type == pa.list_(pa.int64())
+ assert arr.to_pylist() == data
+
+ data = [None, [[None, 1]], [[2, 3, 4], None], [None]]
+ arr = pa.array(data)
+ assert arr.type == pa.list_(pa.list_(pa.int64()))
+ assert arr.to_pylist() == data
+
+ exceptions = (pa.ArrowInvalid, pa.ArrowTypeError)
+
+ # Mixed nesting levels are rejected
+ with pytest.raises(exceptions):
+ pa.array([1, 2, [1]])
+
+ with pytest.raises(exceptions):
+ pa.array([1, 2, []])
+
+ with pytest.raises(exceptions):
+ pa.array([[1], [2], [None, [1]]])
+
+
+def test_sequence_mixed_types_fails():
+ data = ['a', 1, 2.0]
+ with pytest.raises(pa.ArrowTypeError):
+ pa.array(data)
+
+
+def test_sequence_mixed_types_with_specified_type_fails():
+ data = ['-10', '-5', {'a': 1}, '0', '5', '10']
+
+ type = pa.string()
+ with pytest.raises(TypeError):
+ pa.array(data, type=type)
+
+
+def test_sequence_decimal():
+ data = [decimal.Decimal('1234.183'), decimal.Decimal('8094.234')]
+ for type in [pa.decimal128, pa.decimal256]:
+ arr = pa.array(data, type=type(precision=7, scale=3))
+ assert arr.to_pylist() == data
+
+
+def test_sequence_decimal_different_precisions():
+ data = [
+ decimal.Decimal('1234234983.183'), decimal.Decimal('80943244.234')
+ ]
+ for type in [pa.decimal128, pa.decimal256]:
+ arr = pa.array(data, type=type(precision=13, scale=3))
+ assert arr.to_pylist() == data
+
+
+def test_sequence_decimal_no_scale():
+ data = [decimal.Decimal('1234234983'), decimal.Decimal('8094324')]
+ for type in [pa.decimal128, pa.decimal256]:
+ arr = pa.array(data, type=type(precision=10))
+ assert arr.to_pylist() == data
+
+
+def test_sequence_decimal_negative():
+ data = [decimal.Decimal('-1234.234983'), decimal.Decimal('-8.094324')]
+ for type in [pa.decimal128, pa.decimal256]:
+ arr = pa.array(data, type=type(precision=10, scale=6))
+ assert arr.to_pylist() == data
+
+
+def test_sequence_decimal_no_whole_part():
+ data = [decimal.Decimal('-.4234983'), decimal.Decimal('.0103943')]
+ for type in [pa.decimal128, pa.decimal256]:
+ arr = pa.array(data, type=type(precision=7, scale=7))
+ assert arr.to_pylist() == data
+
+
+def test_sequence_decimal_large_integer():
+ data = [decimal.Decimal('-394029506937548693.42983'),
+ decimal.Decimal('32358695912932.01033')]
+ for type in [pa.decimal128, pa.decimal256]:
+ arr = pa.array(data, type=type(precision=23, scale=5))
+ assert arr.to_pylist() == data
+
+
+def test_sequence_decimal_from_integers():
+ data = [0, 1, -39402950693754869342983]
+ expected = [decimal.Decimal(x) for x in data]
+ for type in [pa.decimal128, pa.decimal256]:
+ arr = pa.array(data, type=type(precision=28, scale=5))
+ assert arr.to_pylist() == expected
+
+
+def test_sequence_decimal_too_high_precision():
+ # ARROW-6989 python decimal has too high precision
+ with pytest.raises(ValueError, match="precision out of range"):
+ pa.array([decimal.Decimal('1' * 80)])
+
+
+def test_sequence_decimal_infer():
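+    # the inferred scale equals the number of fractional digits of the Decimal;
+    # trailing zeros are preserved and exponents are taken into account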
+ for data, typ in [
+ # simple case
+ (decimal.Decimal('1.234'), pa.decimal128(4, 3)),
+ # trailing zeros
+ (decimal.Decimal('12300'), pa.decimal128(5, 0)),
+ (decimal.Decimal('12300.0'), pa.decimal128(6, 1)),
+ # scientific power notation
+ (decimal.Decimal('1.23E+4'), pa.decimal128(5, 0)),
+ (decimal.Decimal('123E+2'), pa.decimal128(5, 0)),
+ (decimal.Decimal('123E+4'), pa.decimal128(7, 0)),
+ # leading zeros
+ (decimal.Decimal('0.0123'), pa.decimal128(4, 4)),
+ (decimal.Decimal('0.01230'), pa.decimal128(5, 5)),
+ (decimal.Decimal('1.230E-2'), pa.decimal128(5, 5)),
+ ]:
+ assert pa.infer_type([data]) == typ
+ arr = pa.array([data])
+ assert arr.type == typ
+ assert arr.to_pylist()[0] == data
+
+
+def test_sequence_decimal_infer_mixed():
+ # ARROW-12150 - ensure mixed precision gets correctly inferred to
+ # common type that can hold all input values
+ cases = [
+ ([decimal.Decimal('1.234'), decimal.Decimal('3.456')],
+ pa.decimal128(4, 3)),
+ ([decimal.Decimal('1.234'), decimal.Decimal('456.7')],
+ pa.decimal128(6, 3)),
+ ([decimal.Decimal('123.4'), decimal.Decimal('4.567')],
+ pa.decimal128(6, 3)),
+ ([decimal.Decimal('123e2'), decimal.Decimal('4567e3')],
+ pa.decimal128(7, 0)),
+ ([decimal.Decimal('123e4'), decimal.Decimal('4567e2')],
+ pa.decimal128(7, 0)),
+ ([decimal.Decimal('0.123'), decimal.Decimal('0.04567')],
+ pa.decimal128(5, 5)),
+ ([decimal.Decimal('0.001'), decimal.Decimal('1.01E5')],
+ pa.decimal128(9, 3)),
+ ]
+ for data, typ in cases:
+ assert pa.infer_type(data) == typ
+ arr = pa.array(data)
+ assert arr.type == typ
+ assert arr.to_pylist() == data
+
+
+def test_sequence_decimal_given_type():
+ for data, typs, wrong_typs in [
+ # simple case
+ (
+ decimal.Decimal('1.234'),
+ [pa.decimal128(4, 3), pa.decimal128(5, 3), pa.decimal128(5, 4)],
+ [pa.decimal128(4, 2), pa.decimal128(4, 4)]
+ ),
+ # trailing zeros
+ (
+ decimal.Decimal('12300'),
+ [pa.decimal128(5, 0), pa.decimal128(6, 0), pa.decimal128(3, -2)],
+ [pa.decimal128(4, 0), pa.decimal128(3, -3)]
+ ),
+ # scientific power notation
+ (
+ decimal.Decimal('1.23E+4'),
+ [pa.decimal128(5, 0), pa.decimal128(6, 0), pa.decimal128(3, -2)],
+ [pa.decimal128(4, 0), pa.decimal128(3, -3)]
+ ),
+ ]:
+ for typ in typs:
+ arr = pa.array([data], type=typ)
+ assert arr.type == typ
+ assert arr.to_pylist()[0] == data
+ for typ in wrong_typs:
+ with pytest.raises(ValueError):
+ pa.array([data], type=typ)
+
+
+def test_range_types():
+ arr1 = pa.array(range(3))
+ arr2 = pa.array((0, 1, 2))
+ assert arr1.equals(arr2)
+
+
+def test_empty_range():
+ arr = pa.array(range(0))
+ assert len(arr) == 0
+ assert arr.null_count == 0
+ assert arr.type == pa.null()
+ assert arr.to_pylist() == []
+
+
+def test_structarray():
+ arr = pa.StructArray.from_arrays([], names=[])
+ assert arr.type == pa.struct([])
+ assert len(arr) == 0
+ assert arr.to_pylist() == []
+
+ ints = pa.array([None, 2, 3], type=pa.int64())
+ strs = pa.array(['a', None, 'c'], type=pa.string())
+ bools = pa.array([True, False, None], type=pa.bool_())
+ arr = pa.StructArray.from_arrays(
+ [ints, strs, bools],
+ ['ints', 'strs', 'bools'])
+
+ expected = [
+ {'ints': None, 'strs': 'a', 'bools': True},
+ {'ints': 2, 'strs': None, 'bools': False},
+ {'ints': 3, 'strs': 'c', 'bools': None},
+ ]
+
+ pylist = arr.to_pylist()
+ assert pylist == expected, (pylist, expected)
+
+ # len(names) != len(arrays)
+ with pytest.raises(ValueError):
+ pa.StructArray.from_arrays([ints], ['ints', 'strs'])
+
+
+def test_struct_from_dicts():
+ ty = pa.struct([pa.field('a', pa.int32()),
+ pa.field('b', pa.string()),
+ pa.field('c', pa.bool_())])
+ arr = pa.array([], type=ty)
+ assert arr.to_pylist() == []
+
+ data = [{'a': 5, 'b': 'foo', 'c': True},
+ {'a': 6, 'b': 'bar', 'c': False}]
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == data
+
+ # With omitted values
+ data = [{'a': 5, 'c': True},
+ None,
+ {},
+ {'a': None, 'b': 'bar'}]
+ arr = pa.array(data, type=ty)
+ expected = [{'a': 5, 'b': None, 'c': True},
+ None,
+ {'a': None, 'b': None, 'c': None},
+ {'a': None, 'b': 'bar', 'c': None}]
+ assert arr.to_pylist() == expected
+
+
+def test_struct_from_dicts_bytes_keys():
+ # ARROW-6878
+ ty = pa.struct([pa.field('a', pa.int32()),
+ pa.field('b', pa.string()),
+ pa.field('c', pa.bool_())])
+ arr = pa.array([], type=ty)
+ assert arr.to_pylist() == []
+
+ data = [{b'a': 5, b'b': 'foo'},
+ {b'a': 6, b'c': False}]
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == [
+ {'a': 5, 'b': 'foo', 'c': None},
+ {'a': 6, 'b': None, 'c': False},
+ ]
+
+
+def test_struct_from_tuples():
+ ty = pa.struct([pa.field('a', pa.int32()),
+ pa.field('b', pa.string()),
+ pa.field('c', pa.bool_())])
+
+ data = [(5, 'foo', True),
+ (6, 'bar', False)]
+ expected = [{'a': 5, 'b': 'foo', 'c': True},
+ {'a': 6, 'b': 'bar', 'c': False}]
+ arr = pa.array(data, type=ty)
+
+ data_as_ndarray = np.empty(len(data), dtype=object)
+ data_as_ndarray[:] = data
+ arr2 = pa.array(data_as_ndarray, type=ty)
+ assert arr.to_pylist() == expected
+
+ assert arr.equals(arr2)
+
+ # With omitted values
+ data = [(5, 'foo', None),
+ None,
+ (6, None, False)]
+ expected = [{'a': 5, 'b': 'foo', 'c': None},
+ None,
+ {'a': 6, 'b': None, 'c': False}]
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == expected
+
+ # Invalid tuple size
+ for tup in [(5, 'foo'), (), ('5', 'foo', True, None)]:
+ with pytest.raises(ValueError, match="(?i)tuple size"):
+ pa.array([tup], type=ty)
+
+
+def test_struct_from_list_of_pairs():
+ ty = pa.struct([
+ pa.field('a', pa.int32()),
+ pa.field('b', pa.string()),
+ pa.field('c', pa.bool_())
+ ])
+ data = [
+ [('a', 5), ('b', 'foo'), ('c', True)],
+ [('a', 6), ('b', 'bar'), ('c', False)],
+ None
+ ]
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == [
+ {'a': 5, 'b': 'foo', 'c': True},
+ {'a': 6, 'b': 'bar', 'c': False},
+ None
+ ]
+
+ # test with duplicated field names
+ ty = pa.struct([
+ pa.field('a', pa.int32()),
+ pa.field('a', pa.string()),
+ pa.field('b', pa.bool_())
+ ])
+ data = [
+ [('a', 5), ('a', 'foo'), ('b', True)],
+ [('a', 6), ('a', 'bar'), ('b', False)],
+ ]
+ arr = pa.array(data, type=ty)
+ with pytest.raises(ValueError):
+ # TODO(kszucs): ARROW-9997
+ arr.to_pylist()
+
+ # test with empty elements
+ ty = pa.struct([
+ pa.field('a', pa.int32()),
+ pa.field('b', pa.string()),
+ pa.field('c', pa.bool_())
+ ])
+ data = [
+ [],
+ [('a', 5), ('b', 'foo'), ('c', True)],
+ [('a', 2), ('b', 'baz')],
+ [('a', 1), ('b', 'bar'), ('c', False), ('d', 'julia')],
+ ]
+ expected = [
+ {'a': None, 'b': None, 'c': None},
+ {'a': 5, 'b': 'foo', 'c': True},
+ {'a': 2, 'b': 'baz', 'c': None},
+ {'a': 1, 'b': 'bar', 'c': False},
+ ]
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == expected
+
+
+def test_struct_from_list_of_pairs_errors():
+ ty = pa.struct([
+ pa.field('a', pa.int32()),
+ pa.field('b', pa.string()),
+ pa.field('c', pa.bool_())
+ ])
+
+ # test that it raises if the key doesn't match the expected field name
+ data = [
+ [],
+ [('a', 5), ('c', True), ('b', None)],
+ ]
+ msg = "The expected field name is `b` but `c` was given"
+ with pytest.raises(ValueError, match=msg):
+ pa.array(data, type=ty)
+
+ # test various errors both at the first position and after because of key
+ # type inference
+ template = (
+ r"Could not convert {} with type {}: was expecting tuple of "
+ r"(key, value) pair"
+ )
+ cases = [
+ tuple(), # empty key-value pair
+        ('a',),  # missing value
+        tuple('unknown-key'),  # iterates to a tuple of characters, not a pair
+ 'string', # not a tuple
+ ]
+ for key_value_pair in cases:
+ msg = re.escape(template.format(
+ repr(key_value_pair), type(key_value_pair).__name__
+ ))
+
+ with pytest.raises(TypeError, match=msg):
+ pa.array([
+ [key_value_pair],
+ [('a', 5), ('b', 'foo'), ('c', None)],
+ ], type=ty)
+
+ with pytest.raises(TypeError, match=msg):
+ pa.array([
+ [('a', 5), ('b', 'foo'), ('c', None)],
+ [key_value_pair],
+ ], type=ty)
+
+
+def test_struct_from_mixed_sequence():
+ # It is forbidden to mix dicts and tuples when initializing a struct array
+ ty = pa.struct([pa.field('a', pa.int32()),
+ pa.field('b', pa.string()),
+ pa.field('c', pa.bool_())])
+ data = [(5, 'foo', True),
+ {'a': 6, 'b': 'bar', 'c': False}]
+ with pytest.raises(TypeError):
+ pa.array(data, type=ty)
+
+
+def test_struct_from_dicts_inference():
+ expected_type = pa.struct([pa.field('a', pa.int64()),
+ pa.field('b', pa.string()),
+ pa.field('c', pa.bool_())])
+ data = [{'a': 5, 'b': 'foo', 'c': True},
+ {'a': 6, 'b': 'bar', 'c': False}]
+
+ arr = pa.array(data)
+ check_struct_type(arr.type, expected_type)
+ assert arr.to_pylist() == data
+
+ # With omitted values
+ data = [{'a': 5, 'c': True},
+ None,
+ {},
+ {'a': None, 'b': 'bar'}]
+ expected = [{'a': 5, 'b': None, 'c': True},
+ None,
+ {'a': None, 'b': None, 'c': None},
+ {'a': None, 'b': 'bar', 'c': None}]
+
+ arr = pa.array(data)
+ data_as_ndarray = np.empty(len(data), dtype=object)
+ data_as_ndarray[:] = data
+ arr2 = pa.array(data)
+
+ check_struct_type(arr.type, expected_type)
+ assert arr.to_pylist() == expected
+ assert arr.equals(arr2)
+
+ # Nested
+ expected_type = pa.struct([
+ pa.field('a', pa.struct([pa.field('aa', pa.list_(pa.int64())),
+ pa.field('ab', pa.bool_())])),
+ pa.field('b', pa.string())])
+ data = [{'a': {'aa': [5, 6], 'ab': True}, 'b': 'foo'},
+ {'a': {'aa': None, 'ab': False}, 'b': None},
+ {'a': None, 'b': 'bar'}]
+ arr = pa.array(data)
+
+ assert arr.to_pylist() == data
+
+ # Edge cases
+ arr = pa.array([{}])
+ assert arr.type == pa.struct([])
+ assert arr.to_pylist() == [{}]
+
+ # Mixing structs and scalars is rejected
+ with pytest.raises((pa.ArrowInvalid, pa.ArrowTypeError)):
+ pa.array([1, {'a': 2}])
+
+
+def test_structarray_from_arrays_coerce():
+ # ARROW-1706
+ ints = [None, 2, 3]
+ strs = ['a', None, 'c']
+ bools = [True, False, None]
+ ints_nonnull = [1, 2, 3]
+
+ arrays = [ints, strs, bools, ints_nonnull]
+ result = pa.StructArray.from_arrays(arrays,
+ ['ints', 'strs', 'bools',
+ 'int_nonnull'])
+ expected = pa.StructArray.from_arrays(
+ [pa.array(ints, type='int64'),
+ pa.array(strs, type='utf8'),
+ pa.array(bools),
+ pa.array(ints_nonnull, type='int64')],
+ ['ints', 'strs', 'bools', 'int_nonnull'])
+
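+    # omitting both names and fields raises ValueError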
+ with pytest.raises(ValueError):
+ pa.StructArray.from_arrays(arrays)
+
+ assert result.equals(expected)
+
+
+def test_decimal_array_with_none_and_nan():
+ values = [decimal.Decimal('1.234'), None, np.nan, decimal.Decimal('nan')]
+
+ with pytest.raises(TypeError):
+ # ARROW-6227: Without from_pandas=True, NaN is considered a float
+ array = pa.array(values)
+
+ array = pa.array(values, from_pandas=True)
+ assert array.type == pa.decimal128(4, 3)
+ assert array.to_pylist() == values[:2] + [None, None]
+
+ array = pa.array(values, type=pa.decimal128(10, 4), from_pandas=True)
+ assert array.to_pylist() == [decimal.Decimal('1.2340'), None, None, None]
+
+
+def test_map_from_dicts():
+ data = [[{'key': b'a', 'value': 1}, {'key': b'b', 'value': 2}],
+ [{'key': b'c', 'value': 3}],
+ [{'key': b'd', 'value': 4}, {'key': b'e', 'value': 5},
+ {'key': b'f', 'value': None}],
+ [{'key': b'g', 'value': 7}]]
+ expected = [[(d['key'], d['value']) for d in entry] for entry in data]
+
+ arr = pa.array(expected, type=pa.map_(pa.binary(), pa.int32()))
+
+ assert arr.to_pylist() == expected
+
+ # With omitted values
+ data[1] = None
+ expected[1] = None
+
+ arr = pa.array(expected, type=pa.map_(pa.binary(), pa.int32()))
+
+ assert arr.to_pylist() == expected
+
+ # Invalid dictionary
+ for entry in [[{'value': 5}], [{}], [{'k': 1, 'v': 2}]]:
+ with pytest.raises(ValueError, match="Invalid Map"):
+ pa.array([entry], type=pa.map_('i4', 'i4'))
+
+ # Invalid dictionary types
+ for entry in [[{'key': '1', 'value': 5}], [{'key': {'value': 2}}]]:
+ with pytest.raises(pa.ArrowInvalid, match="tried to convert to int"):
+ pa.array([entry], type=pa.map_('i4', 'i4'))
+
+
+def test_map_from_tuples():
+ expected = [[(b'a', 1), (b'b', 2)],
+ [(b'c', 3)],
+ [(b'd', 4), (b'e', 5), (b'f', None)],
+ [(b'g', 7)]]
+
+ arr = pa.array(expected, type=pa.map_(pa.binary(), pa.int32()))
+
+ assert arr.to_pylist() == expected
+
+ # With omitted values
+ expected[1] = None
+
+ arr = pa.array(expected, type=pa.map_(pa.binary(), pa.int32()))
+
+ assert arr.to_pylist() == expected
+
+ # Invalid tuple size
+ for entry in [[(5,)], [()], [('5', 'foo', True)]]:
+ with pytest.raises(ValueError, match="(?i)tuple size"):
+ pa.array([entry], type=pa.map_('i4', 'i4'))
+
+
+def test_dictionary_from_boolean():
+ typ = pa.dictionary(pa.int8(), value_type=pa.bool_())
+ a = pa.array([False, False, True, False, True], type=typ)
+ assert isinstance(a.type, pa.DictionaryType)
+ assert a.type.equals(typ)
+
+ expected_indices = pa.array([0, 0, 1, 0, 1], type=pa.int8())
+ expected_dictionary = pa.array([False, True], type=pa.bool_())
+ assert a.indices.equals(expected_indices)
+ assert a.dictionary.equals(expected_dictionary)
+
+
+@pytest.mark.parametrize('value_type', [
+ pa.int8(),
+ pa.int16(),
+ pa.int32(),
+ pa.int64(),
+ pa.uint8(),
+ pa.uint16(),
+ pa.uint32(),
+ pa.uint64(),
+ pa.float32(),
+ pa.float64(),
+])
+def test_dictionary_from_integers(value_type):
+ typ = pa.dictionary(pa.int8(), value_type=value_type)
+ a = pa.array([1, 2, 1, 1, 2, 3], type=typ)
+ assert isinstance(a.type, pa.DictionaryType)
+ assert a.type.equals(typ)
+
+ expected_indices = pa.array([0, 1, 0, 0, 1, 2], type=pa.int8())
+ expected_dictionary = pa.array([1, 2, 3], type=value_type)
+ assert a.indices.equals(expected_indices)
+ assert a.dictionary.equals(expected_dictionary)
+
+
+@pytest.mark.parametrize('input_index_type', [
+ pa.int8(),
+ pa.int16(),
+ pa.int32(),
+ pa.int64()
+])
+def test_dictionary_index_type(input_index_type):
+ # dictionary array is constructed using adaptive index type builder,
+ # but the input index type is considered as the minimal width type to use
+
+ typ = pa.dictionary(input_index_type, value_type=pa.int64())
+ arr = pa.array(range(10), type=typ)
+ assert arr.type.equals(typ)
+
+
+def test_dictionary_is_always_adaptive():
+ # dictionary array is constructed using adaptive index type builder,
+ # meaning that the output index type may be wider than the given index type
+ # since it depends on the input data
+ typ = pa.dictionary(pa.int8(), value_type=pa.int64())
+
+ a = pa.array(range(2**7), type=typ)
+ expected = pa.dictionary(pa.int8(), pa.int64())
+ assert a.type.equals(expected)
+
+ a = pa.array(range(2**7 + 1), type=typ)
+ expected = pa.dictionary(pa.int16(), pa.int64())
+ assert a.type.equals(expected)
+
+
+def test_dictionary_from_strings():
+ for value_type in [pa.binary(), pa.string()]:
+ typ = pa.dictionary(pa.int8(), value_type)
+ a = pa.array(["", "a", "bb", "a", "bb", "ccc"], type=typ)
+
+ assert isinstance(a.type, pa.DictionaryType)
+
+ expected_indices = pa.array([0, 1, 2, 1, 2, 3], type=pa.int8())
+ expected_dictionary = pa.array(["", "a", "bb", "ccc"], type=value_type)
+ assert a.indices.equals(expected_indices)
+ assert a.dictionary.equals(expected_dictionary)
+
+ # fixed size binary type
+ typ = pa.dictionary(pa.int8(), pa.binary(3))
+ a = pa.array(["aaa", "aaa", "bbb", "ccc", "bbb"], type=typ)
+ assert isinstance(a.type, pa.DictionaryType)
+
+ expected_indices = pa.array([0, 0, 1, 2, 1], type=pa.int8())
+ expected_dictionary = pa.array(["aaa", "bbb", "ccc"], type=pa.binary(3))
+ assert a.indices.equals(expected_indices)
+ assert a.dictionary.equals(expected_dictionary)
+
+
+@pytest.mark.parametrize(('unit', 'expected'), [
+ ('s', datetime.timedelta(seconds=-2147483000)),
+ ('ms', datetime.timedelta(milliseconds=-2147483000)),
+ ('us', datetime.timedelta(microseconds=-2147483000)),
+ ('ns', datetime.timedelta(microseconds=-2147483))
+])
+def test_duration_array_roundtrip_corner_cases(unit, expected):
+ # Corner case discovered by hypothesis: implicit conversions to unsigned
+ # values used to produce wrong values with the wrong sign.
+ ty = pa.duration(unit)
+ arr = pa.array([-2147483000], type=ty)
+ restored = pa.array(arr.to_pylist(), type=ty)
+ assert arr.equals(restored)
+
+ expected_list = [expected]
+ if unit == 'ns':
+ # if pandas is available then a pandas Timedelta is returned
+ try:
+ import pandas as pd
+ except ImportError:
+ pass
+ else:
+ expected_list = [pd.Timedelta(-2147483000, unit='ns')]
+
+ assert restored.to_pylist() == expected_list
+
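+# Illustrative arithmetic (not part of the original test suite): one way such
+# a sign bug shows up is a negative count being reinterpreted as an unsigned
+# 64-bit integer, which turns it into a huge positive value.
+assert -2147483000 % 2**64 == 2**64 - 2147483000
+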
+
+@pytest.mark.pandas
+def test_roundtrip_nanosecond_resolution_pandas_temporal_objects():
+ # corner case discovered by hypothesis: preserving the nanoseconds on
+ # conversion from a list of Timedelta and Timestamp objects
+ import pandas as pd
+
+ ty = pa.duration('ns')
+ arr = pa.array([9223371273709551616], type=ty)
+ data = arr.to_pylist()
+ assert isinstance(data[0], pd.Timedelta)
+ restored = pa.array(data, type=ty)
+ assert arr.equals(restored)
+ assert restored.to_pylist() == [
+ pd.Timedelta(9223371273709551616, unit='ns')
+ ]
+
+ ty = pa.timestamp('ns')
+ arr = pa.array([9223371273709551616], type=ty)
+ data = arr.to_pylist()
+ assert isinstance(data[0], pd.Timestamp)
+ restored = pa.array(data, type=ty)
+ assert arr.equals(restored)
+ assert restored.to_pylist() == [
+ pd.Timestamp(9223371273709551616, unit='ns')
+ ]
+
+ ty = pa.timestamp('ns', tz='US/Eastern')
+ value = 1604119893000000000
+ arr = pa.array([value], type=ty)
+ data = arr.to_pylist()
+ assert isinstance(data[0], pd.Timestamp)
+ restored = pa.array(data, type=ty)
+ assert arr.equals(restored)
+ assert restored.to_pylist() == [
+ pd.Timestamp(value, unit='ns').tz_localize(
+ "UTC").tz_convert('US/Eastern')
+ ]
+
+
+@h.given(past.all_arrays)
+def test_array_to_pylist_roundtrip(arr):
+ seq = arr.to_pylist()
+ restored = pa.array(seq, type=arr.type)
+ assert restored.equals(arr)
+
+
+@pytest.mark.large_memory
+def test_auto_chunking_binary_like():
+ # values sized so that 20 * v1 + v2 lands just below the 2**31 - 1 byte limit
+ v1 = b'x' * 100000000
+ v2 = b'x' * 147483646
+
+ # single chunk
+ one_chunk_data = [v1] * 20 + [b'', None, v2]
+ arr = pa.array(one_chunk_data, type=pa.binary())
+ assert isinstance(arr, pa.Array)
+ assert len(arr) == 23
+ assert arr[20].as_py() == b''
+ assert arr[21].as_py() is None
+ assert arr[22].as_py() == v2
+
+ # two chunks
+ two_chunk_data = one_chunk_data + [b'two']
+ arr = pa.array(two_chunk_data, type=pa.binary())
+ assert isinstance(arr, pa.ChunkedArray)
+ assert arr.num_chunks == 2
+ assert len(arr.chunk(0)) == 23
+ assert len(arr.chunk(1)) == 1
+ assert arr.chunk(0)[20].as_py() == b''
+ assert arr.chunk(0)[21].as_py() is None
+ assert arr.chunk(0)[22].as_py() == v2
+ assert arr.chunk(1).to_pylist() == [b'two']
+
+ # three chunks
+ three_chunk_data = one_chunk_data * 2 + [b'three', b'three']
+ arr = pa.array(three_chunk_data, type=pa.binary())
+ assert isinstance(arr, pa.ChunkedArray)
+ assert arr.num_chunks == 3
+ assert len(arr.chunk(0)) == 23
+ assert len(arr.chunk(1)) == 23
+ assert len(arr.chunk(2)) == 2
+ for i in range(2):
+ assert arr.chunk(i)[20].as_py() == b''
+ assert arr.chunk(i)[21].as_py() is None
+ assert arr.chunk(i)[22].as_py() == v2
+ assert arr.chunk(2).to_pylist() == [b'three', b'three']
+
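+# Illustrative arithmetic (not part of the original test suite): the values in
+# test_auto_chunking_binary_like are sized so the one-chunk payload sits just
+# below the 2**31 - 1 byte capacity of 32-bit binary offsets, so any further
+# non-empty value forces a second chunk.
+assert 20 * 100000000 + 147483646 == 2**31 - 2
+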
+
+@pytest.mark.large_memory
+def test_auto_chunking_list_of_binary():
+ # ARROW-6281
+ vals = [['x' * 1024]] * ((2 << 20) + 1)
+ arr = pa.array(vals)
+ assert isinstance(arr, pa.ChunkedArray)
+ assert arr.num_chunks == 2
+ assert len(arr.chunk(0)) == 2**21 - 1
+ assert len(arr.chunk(1)) == 2
+ assert arr.chunk(1).to_pylist() == [['x' * 1024]] * 2
+
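+# Illustrative arithmetic (not part of the original test suite): each row in
+# test_auto_chunking_list_of_binary contributes 1024 bytes to the child binary
+# array, so the first chunk closes after 2**21 - 1 rows, just before the
+# 2**31 - 1 byte offset limit would be exceeded.
+assert (2**21 - 1) * 1024 <= 2**31 - 1 < 2**21 * 1024
+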
+
+@pytest.mark.large_memory
+def test_auto_chunking_list_like():
+ item = np.ones((2**28,), dtype='uint8')
+ data = [item] * (2**3 - 1)
+ arr = pa.array(data, type=pa.list_(pa.uint8()))
+ assert isinstance(arr, pa.Array)
+ assert len(arr) == 7
+
+ item = np.ones((2**28,), dtype='uint8')
+ data = [item] * 2**3
+ arr = pa.array(data, type=pa.list_(pa.uint8()))
+ assert isinstance(arr, pa.ChunkedArray)
+ assert arr.num_chunks == 2
+ assert len(arr.chunk(0)) == 7
+ assert len(arr.chunk(1)) == 1
+ chunk = arr.chunk(1)
+ scalar = chunk[0]
+ assert isinstance(scalar, pa.ListScalar)
+ expected = pa.array(item, type=pa.uint8())
+ assert scalar.values == expected
+
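+# Illustrative arithmetic (not part of the original test suite): in
+# test_auto_chunking_list_like each list element holds 2**28 child values, so
+# seven elements stay within the 2**31 - 1 child-length limit while eight
+# exceed it, producing the 7 + 1 chunk split.
+assert 7 * 2**28 <= 2**31 - 1 < 8 * 2**28
+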
+
+@pytest.mark.slow
+@pytest.mark.large_memory
+def test_auto_chunking_map_type():
+ # takes ~20 minutes locally
+ ty = pa.map_(pa.int8(), pa.int8())
+ item = [(1, 1)] * 2**28
+ data = [item] * 2**3
+ arr = pa.array(data, type=ty)
+ assert isinstance(arr, pa.ChunkedArray)
+ assert len(arr.chunk(0)) == 7
+ assert len(arr.chunk(1)) == 1
+
+
+@pytest.mark.large_memory
+@pytest.mark.parametrize(('ty', 'char'), [
+ (pa.string(), 'x'),
+ (pa.binary(), b'x'),
+])
+def test_nested_auto_chunking(ty, char):
+ v1 = char * 100000000
+ v2 = char * 147483646
+
+ struct_type = pa.struct([
+ pa.field('bool', pa.bool_()),
+ pa.field('integer', pa.int64()),
+ pa.field('string-like', ty),
+ ])
+
+ data = [{'bool': True, 'integer': 1, 'string-like': v1}] * 20
+ data.append({'bool': True, 'integer': 1, 'string-like': v2})
+ arr = pa.array(data, type=struct_type)
+ assert isinstance(arr, pa.Array)
+
+ data.append({'bool': True, 'integer': 1, 'string-like': char})
+ arr = pa.array(data, type=struct_type)
+ assert isinstance(arr, pa.ChunkedArray)
+ assert arr.num_chunks == 2
+ assert len(arr.chunk(0)) == 21
+ assert len(arr.chunk(1)) == 1
+ assert arr.chunk(1)[0].as_py() == {
+ 'bool': True,
+ 'integer': 1,
+ 'string-like': char
+ }
+
+
+@pytest.mark.large_memory
+def test_array_from_pylist_data_overflow():
+ # Regression test for ARROW-12983
+ # Data buffer overflow - should result in chunked array
+ items = [b'a' * 4096] * (2 ** 19)
+ arr = pa.array(items, type=pa.string())
+ assert isinstance(arr, pa.ChunkedArray)
+ assert len(arr) == 2**19
+ assert len(arr.chunks) > 1
+
+ mask = np.zeros(2**19, bool)
+ arr = pa.array(items, mask=mask, type=pa.string())
+ assert isinstance(arr, pa.ChunkedArray)
+ assert len(arr) == 2**19
+ assert len(arr.chunks) > 1
+
+ arr = pa.array(items, type=pa.binary())
+ assert isinstance(arr, pa.ChunkedArray)
+ assert len(arr) == 2**19
+ assert len(arr.chunks) > 1
+
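+# Illustrative arithmetic (not part of the original test suite): the payload
+# in test_array_from_pylist_data_overflow totals exactly 2**31 bytes, one byte
+# more than a single 32-bit-offset string/binary array can address.
+assert 4096 * 2**19 == 2**31  # capacity of one chunk is 2**31 - 1 bytes
+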
+
+@pytest.mark.slow
+@pytest.mark.large_memory
+def test_array_from_pylist_offset_overflow():
+ # Regression test for ARROW-12983
+ # Offset buffer overflow - should result in chunked array
+ # Note this doesn't apply to primitive arrays
+ items = [b'a'] * (2 ** 31)
+ arr = pa.array(items, type=pa.string())
+ assert isinstance(arr, pa.ChunkedArray)
+ assert len(arr) == 2**31
+ assert len(arr.chunks) > 1
+
+ mask = np.zeros(2**31, bool)
+ arr = pa.array(items, mask=mask, type=pa.string())
+ assert isinstance(arr, pa.ChunkedArray)
+ assert len(arr) == 2**31
+ assert len(arr.chunks) > 1
+
+ arr = pa.array(items, type=pa.binary())
+ assert isinstance(arr, pa.ChunkedArray)
+ assert len(arr) == 2**31
+ assert len(arr.chunks) > 1
diff --git a/src/arrow/python/pyarrow/tests/test_csv.py b/src/arrow/python/pyarrow/tests/test_csv.py
new file mode 100644
index 000000000..b6cca243b
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_csv.py
@@ -0,0 +1,1824 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import abc
+import bz2
+from datetime import date, datetime
+from decimal import Decimal
+import gc
+import gzip
+import io
+import itertools
+import os
+import pickle
+import shutil
+import signal
+import string
+import tempfile
+import threading
+import time
+import unittest
+import weakref
+
+import pytest
+
+import numpy as np
+
+import pyarrow as pa
+from pyarrow.csv import (
+ open_csv, read_csv, ReadOptions, ParseOptions, ConvertOptions, ISO8601,
+ write_csv, WriteOptions, CSVWriter)
+from pyarrow.tests import util
+
+
+def generate_col_names():
+ # 'a', 'b'... 'z', then 'aa', 'ab'...
+ letters = string.ascii_lowercase
+ yield from letters
+ for first in letters:
+ for second in letters:
+ yield first + second
+
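+# Illustrative usage (not part of the original test suite; _EXAMPLE_NAMES is a
+# name introduced here): the first 28 generated names are the single letters
+# 'a'..'z' followed by the two-letter names 'aa' and 'ab'.
+_EXAMPLE_NAMES = list(itertools.islice(generate_col_names(), 28))
+assert _EXAMPLE_NAMES[:2] == ['a', 'b'] and _EXAMPLE_NAMES[26:] == ['aa', 'ab']
+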
+
+def make_random_csv(num_cols=2, num_rows=10, linesep='\r\n', write_names=True):
+ arr = np.random.RandomState(42).randint(0, 1000, size=(num_cols, num_rows))
+ csv = io.StringIO()
+ col_names = list(itertools.islice(generate_col_names(), num_cols))
+ if write_names:
+ csv.write(",".join(col_names))
+ csv.write(linesep)
+ for row in arr.T:
+ csv.write(",".join(map(str, row)))
+ csv.write(linesep)
+ csv = csv.getvalue().encode()
+ columns = [pa.array(a, type=pa.int64()) for a in arr]
+ expected = pa.Table.from_arrays(columns, col_names)
+ return csv, expected
+
+
+def make_empty_csv(column_names):
+ csv = io.StringIO()
+ csv.write(",".join(column_names))
+ csv.write("\n")
+ return csv.getvalue().encode()
+
+
+def check_options_class(cls, **attr_values):
+ """
+ Check setting and getting attributes of an *Options class.
+ """
+ opts = cls()
+
+ for name, values in attr_values.items():
+ assert getattr(opts, name) == values[0], \
+ "incorrect default value for " + name
+ for v in values:
+ setattr(opts, name, v)
+ assert getattr(opts, name) == v, "failed setting value"
+
+ with pytest.raises(AttributeError):
+ opts.zzz_non_existent = True
+
+ # Check constructor named arguments
+ non_defaults = {name: values[1] for name, values in attr_values.items()}
+ opts = cls(**non_defaults)
+ for name, value in non_defaults.items():
+ assert getattr(opts, name) == value
+
+
+# The various options classes need to be picklable for use with datasets
+def check_options_class_pickling(cls, **attr_values):
+ opts = cls(**attr_values)
+ new_opts = pickle.loads(pickle.dumps(opts,
+ protocol=pickle.HIGHEST_PROTOCOL))
+ for name, value in attr_values.items():
+ assert getattr(new_opts, name) == value
+
+
+def test_read_options():
+ cls = ReadOptions
+ opts = cls()
+
+ check_options_class(cls, use_threads=[True, False],
+ skip_rows=[0, 3],
+ column_names=[[], ["ab", "cd"]],
+ autogenerate_column_names=[False, True],
+ encoding=['utf8', 'utf16'],
+ skip_rows_after_names=[0, 27])
+
+ check_options_class_pickling(cls, use_threads=True,
+ skip_rows=3,
+ column_names=["ab", "cd"],
+ autogenerate_column_names=False,
+ encoding='utf16',
+ skip_rows_after_names=27)
+
+ assert opts.block_size > 0
+ opts.block_size = 12345
+ assert opts.block_size == 12345
+
+ opts = cls(block_size=1234)
+ assert opts.block_size == 1234
+
+ opts.validate()
+
+ match = "ReadOptions: block_size must be at least 1: 0"
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.block_size = 0
+ opts.validate()
+
+ match = "ReadOptions: skip_rows cannot be negative: -1"
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.skip_rows = -1
+ opts.validate()
+
+ match = "ReadOptions: skip_rows_after_names cannot be negative: -1"
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.skip_rows_after_names = -1
+ opts.validate()
+
+ match = "ReadOptions: autogenerate_column_names cannot be true when" \
+ " column_names are provided"
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.autogenerate_column_names = True
+ opts.column_names = ('a', 'b')
+ opts.validate()
+
+
+def test_parse_options():
+ cls = ParseOptions
+
+ check_options_class(cls, delimiter=[',', 'x'],
+ escape_char=[False, 'y'],
+ quote_char=['"', 'z', False],
+ double_quote=[True, False],
+ newlines_in_values=[False, True],
+ ignore_empty_lines=[True, False])
+
+ check_options_class_pickling(cls, delimiter='x',
+ escape_char='y',
+ quote_char=False,
+ double_quote=False,
+ newlines_in_values=True,
+ ignore_empty_lines=False)
+
+ cls().validate()
+ opts = cls()
+ opts.delimiter = "\t"
+ opts.validate()
+
+ match = "ParseOptions: delimiter cannot be \\\\r or \\\\n"
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.delimiter = "\n"
+ opts.validate()
+
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.delimiter = "\r"
+ opts.validate()
+
+ match = "ParseOptions: quote_char cannot be \\\\r or \\\\n"
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.quote_char = "\n"
+ opts.validate()
+
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.quote_char = "\r"
+ opts.validate()
+
+ match = "ParseOptions: escape_char cannot be \\\\r or \\\\n"
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.escape_char = "\n"
+ opts.validate()
+
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.escape_char = "\r"
+ opts.validate()
+
+
+def test_convert_options():
+ cls = ConvertOptions
+ opts = cls()
+
+ check_options_class(
+ cls, check_utf8=[True, False],
+ strings_can_be_null=[False, True],
+ quoted_strings_can_be_null=[True, False],
+ decimal_point=['.', ','],
+ include_columns=[[], ['def', 'abc']],
+ include_missing_columns=[False, True],
+ auto_dict_encode=[False, True],
+ timestamp_parsers=[[], [ISO8601, '%y-%m']])
+
+ check_options_class_pickling(
+ cls, check_utf8=False,
+ strings_can_be_null=True,
+ quoted_strings_can_be_null=False,
+ decimal_point=',',
+ include_columns=['def', 'abc'],
+ include_missing_columns=False,
+ auto_dict_encode=True,
+ timestamp_parsers=[ISO8601, '%y-%m'])
+
+ with pytest.raises(ValueError):
+ opts.decimal_point = '..'
+
+ assert opts.auto_dict_max_cardinality > 0
+ opts.auto_dict_max_cardinality = 99999
+ assert opts.auto_dict_max_cardinality == 99999
+
+ assert opts.column_types == {}
+ # Pass column_types as mapping
+ opts.column_types = {'b': pa.int16(), 'c': pa.float32()}
+ assert opts.column_types == {'b': pa.int16(), 'c': pa.float32()}
+ opts.column_types = {'v': 'int16', 'w': 'null'}
+ assert opts.column_types == {'v': pa.int16(), 'w': pa.null()}
+ # Pass column_types as schema
+ schema = pa.schema([('a', pa.int32()), ('b', pa.string())])
+ opts.column_types = schema
+ assert opts.column_types == {'a': pa.int32(), 'b': pa.string()}
+ # Pass column_types as sequence
+ opts.column_types = [('x', pa.binary())]
+ assert opts.column_types == {'x': pa.binary()}
+
+ with pytest.raises(TypeError, match='DataType expected'):
+ opts.column_types = {'a': None}
+ with pytest.raises(TypeError):
+ opts.column_types = 0
+
+ assert isinstance(opts.null_values, list)
+ assert '' in opts.null_values
+ assert 'N/A' in opts.null_values
+ opts.null_values = ['xxx', 'yyy']
+ assert opts.null_values == ['xxx', 'yyy']
+
+ assert isinstance(opts.true_values, list)
+ opts.true_values = ['xxx', 'yyy']
+ assert opts.true_values == ['xxx', 'yyy']
+
+ assert isinstance(opts.false_values, list)
+ opts.false_values = ['xxx', 'yyy']
+ assert opts.false_values == ['xxx', 'yyy']
+
+ assert opts.timestamp_parsers == []
+ opts.timestamp_parsers = [ISO8601]
+ assert opts.timestamp_parsers == [ISO8601]
+
+ opts = cls(column_types={'a': pa.null()},
+ null_values=['N', 'nn'], true_values=['T', 'tt'],
+ false_values=['F', 'ff'], auto_dict_max_cardinality=999,
+ timestamp_parsers=[ISO8601, '%Y-%m-%d'])
+ assert opts.column_types == {'a': pa.null()}
+ assert opts.null_values == ['N', 'nn']
+ assert opts.false_values == ['F', 'ff']
+ assert opts.true_values == ['T', 'tt']
+ assert opts.auto_dict_max_cardinality == 999
+ assert opts.timestamp_parsers == [ISO8601, '%Y-%m-%d']
+
+
+def test_write_options():
+ cls = WriteOptions
+ opts = cls()
+
+ check_options_class(
+ cls, include_header=[True, False])
+
+ assert opts.batch_size > 0
+ opts.batch_size = 12345
+ assert opts.batch_size == 12345
+
+ opts = cls(batch_size=9876)
+ assert opts.batch_size == 9876
+
+ opts.validate()
+
+ match = "WriteOptions: batch_size must be at least 1: 0"
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ opts = cls()
+ opts.batch_size = 0
+ opts.validate()
+
+
+class BaseTestCSV(abc.ABC):
+ """Common tests which are shared by streaming and non streaming readers"""
+
+ @abc.abstractmethod
+ def read_bytes(self, b, **kwargs):
+ """
+ :param b: bytes to be parsed
+ :param kwargs: arguments passed on to open the csv file
+ :return: b parsed as a single RecordBatch
+ """
+ raise NotImplementedError
+
+ @property
+ @abc.abstractmethod
+ def use_threads(self):
+ """Whether this test is multi-threaded"""
+ raise NotImplementedError
+
+ @staticmethod
+ def check_names(table, names):
+ assert table.num_columns == len(names)
+ assert table.column_names == names
+
+ def test_header_skip_rows(self):
+ rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n"
+
+ opts = ReadOptions()
+ opts.skip_rows = 1
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["ef", "gh"])
+ assert table.to_pydict() == {
+ "ef": ["ij", "mn"],
+ "gh": ["kl", "op"],
+ }
+
+ opts.skip_rows = 3
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["mn", "op"])
+ assert table.to_pydict() == {
+ "mn": [],
+ "op": [],
+ }
+
+ opts.skip_rows = 4
+ with pytest.raises(pa.ArrowInvalid):
+ # Not enough rows
+ table = self.read_bytes(rows, read_options=opts)
+
+ # Can skip rows with a different number of columns
+ rows = b"abcd\n,,,,,\nij,kl\nmn,op\n"
+ opts.skip_rows = 2
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["ij", "kl"])
+ assert table.to_pydict() == {
+ "ij": ["mn"],
+ "kl": ["op"],
+ }
+
+ # Can skip all rows exactly when columns are given
+ opts.skip_rows = 4
+ opts.column_names = ['ij', 'kl']
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["ij", "kl"])
+ assert table.to_pydict() == {
+ "ij": [],
+ "kl": [],
+ }
+
+ def test_skip_rows_after_names(self):
+ rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n"
+
+ opts = ReadOptions()
+ opts.skip_rows_after_names = 1
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["ab", "cd"])
+ assert table.to_pydict() == {
+ "ab": ["ij", "mn"],
+ "cd": ["kl", "op"],
+ }
+
+ # Can skip exact number of rows
+ opts.skip_rows_after_names = 3
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["ab", "cd"])
+ assert table.to_pydict() == {
+ "ab": [],
+ "cd": [],
+ }
+
+ # Can skip beyond all rows
+ opts.skip_rows_after_names = 4
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["ab", "cd"])
+ assert table.to_pydict() == {
+ "ab": [],
+ "cd": [],
+ }
+
+ # Can skip rows with a different number of columns
+ rows = b"abcd\n,,,,,\nij,kl\nmn,op\n"
+ opts.skip_rows_after_names = 2
+ opts.column_names = ["f0", "f1"]
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["f0", "f1"])
+ assert table.to_pydict() == {
+ "f0": ["ij", "mn"],
+ "f1": ["kl", "op"],
+ }
+ opts = ReadOptions()
+
+ # Can skip rows with new lines in the value
+ rows = b'ab,cd\n"e\nf","g\n\nh"\n"ij","k\nl"\nmn,op'
+ opts.skip_rows_after_names = 2
+ parse_opts = ParseOptions()
+ parse_opts.newlines_in_values = True
+ table = self.read_bytes(rows, read_options=opts,
+ parse_options=parse_opts)
+ self.check_names(table, ["ab", "cd"])
+ assert table.to_pydict() == {
+ "ab": ["mn"],
+ "cd": ["op"],
+ }
+
+ # Can skip rows when block ends in middle of quoted value
+ opts.skip_rows_after_names = 2
+ opts.block_size = 26
+ table = self.read_bytes(rows, read_options=opts,
+ parse_options=parse_opts)
+ self.check_names(table, ["ab", "cd"])
+ assert table.to_pydict() == {
+ "ab": ["mn"],
+ "cd": ["op"],
+ }
+ opts = ReadOptions()
+
+ # Can skip rows beyond the first block, without the lexer (default parsing)
+ rows, expected = make_random_csv(num_cols=5, num_rows=1000)
+ opts.skip_rows_after_names = 900
+ opts.block_size = len(rows) / 11
+ table = self.read_bytes(rows, read_options=opts)
+ assert table.schema == expected.schema
+ assert table.num_rows == 100
+ table_dict = table.to_pydict()
+ for name, values in expected.to_pydict().items():
+ assert values[900:] == table_dict[name]
+
+ # Can skip rows beyond the first block, with the lexer (newlines_in_values)
+ table = self.read_bytes(rows, read_options=opts,
+ parse_options=parse_opts)
+ assert table.schema == expected.schema
+ assert table.num_rows == 100
+ table_dict = table.to_pydict()
+ for name, values in expected.to_pydict().items():
+ assert values[900:] == table_dict[name]
+
+ # Skip rows and skip rows after names
+ rows, expected = make_random_csv(num_cols=5, num_rows=200,
+ write_names=False)
+ opts = ReadOptions()
+ opts.skip_rows = 37
+ opts.skip_rows_after_names = 41
+ opts.column_names = expected.schema.names
+ table = self.read_bytes(rows, read_options=opts,
+ parse_options=parse_opts)
+ assert table.schema == expected.schema
+ assert (table.num_rows ==
+ expected.num_rows - opts.skip_rows -
+ opts.skip_rows_after_names)
+ table_dict = table.to_pydict()
+ for name, values in expected.to_pydict().items():
+ assert (values[opts.skip_rows + opts.skip_rows_after_names:] ==
+ table_dict[name])
+
+ def test_row_number_offset_in_errors(self):
+ # Row numbers are only correctly counted in serial reads
+ def format_msg(msg_format, row, *args):
+ if self.use_threads:
+ row_info = ""
+ else:
+ row_info = "Row #{}: ".format(row)
+ return msg_format.format(row_info, *args)
+
+ csv, _ = make_random_csv(4, 100, write_names=True)
+
+ read_options = ReadOptions()
+ read_options.block_size = len(csv) / 3
+ convert_options = ConvertOptions()
+ convert_options.column_types = {"a": pa.int32()}
+
+ # Test without skip_rows and column names in the csv
+ csv_bad_columns = csv + b"1,2\r\n"
+ message_columns = format_msg("{}Expected 4 columns, got 2", 102)
+ with pytest.raises(pa.ArrowInvalid, match=message_columns):
+ self.read_bytes(csv_bad_columns,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ csv_bad_type = csv + b"a,b,c,d\r\n"
+ message_value = format_msg(
+ "In CSV column #0: {}"
+ "CSV conversion error to int32: invalid value 'a'",
+ 102, csv)
+ with pytest.raises(pa.ArrowInvalid, match=message_value):
+ self.read_bytes(csv_bad_type,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ long_row = (b"this is a long row" * 15) + b",3\r\n"
+ csv_bad_columns_long = csv + long_row
+ message_long = format_msg("{}Expected 4 columns, got 2: {} ...", 102,
+ long_row[0:96].decode("utf-8"))
+ with pytest.raises(pa.ArrowInvalid, match=message_long):
+ self.read_bytes(csv_bad_columns_long,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ # Test skipping rows after the names
+ read_options.skip_rows_after_names = 47
+
+ with pytest.raises(pa.ArrowInvalid, match=message_columns):
+ self.read_bytes(csv_bad_columns,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ with pytest.raises(pa.ArrowInvalid, match=message_value):
+ self.read_bytes(csv_bad_type,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ with pytest.raises(pa.ArrowInvalid, match=message_long):
+ self.read_bytes(csv_bad_columns_long,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ read_options.skip_rows_after_names = 0
+
+ # Test without skip_rows and column names not in the csv
+ csv, _ = make_random_csv(4, 100, write_names=False)
+ read_options.column_names = ["a", "b", "c", "d"]
+ csv_bad_columns = csv + b"1,2\r\n"
+ message_columns = format_msg("{}Expected 4 columns, got 2", 101)
+ with pytest.raises(pa.ArrowInvalid, match=message_columns):
+ self.read_bytes(csv_bad_columns,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ csv_bad_columns_long = csv + long_row
+ message_long = format_msg("{}Expected 4 columns, got 2: {} ...", 101,
+ long_row[0:96].decode("utf-8"))
+ with pytest.raises(pa.ArrowInvalid, match=message_long):
+ self.read_bytes(csv_bad_columns_long,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ csv_bad_type = csv + b"a,b,c,d\r\n"
+ message_value = format_msg(
+ "In CSV column #0: {}"
+ "CSV conversion error to int32: invalid value 'a'",
+ 101)
+ message_value = message_value.format(len(csv))
+ with pytest.raises(pa.ArrowInvalid, match=message_value):
+ self.read_bytes(csv_bad_type,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ # Test with skip_rows and column names not in the csv
+ read_options.skip_rows = 23
+ with pytest.raises(pa.ArrowInvalid, match=message_columns):
+ self.read_bytes(csv_bad_columns,
+ read_options=read_options,
+ convert_options=convert_options)
+
+ with pytest.raises(pa.ArrowInvalid, match=message_value):
+ self.read_bytes(csv_bad_type,
+ read_options=read_options,
+ convert_options=convert_options)
+
+
+class BaseCSVTableRead(BaseTestCSV):
+
+ def read_csv(self, csv, *args, validate_full=True, **kwargs):
+ """
+ Reads the CSV file into memory using pyarrow's read_csv
+ csv The CSV bytes
+ args Positional arguments to be forwarded to pyarrow's read_csv
+ validate_full Whether or not to fully validate the resulting table
+ kwargs Keyword arguments to be forwarded to pyarrow's read_csv
+ """
+ assert isinstance(self.use_threads, bool) # sanity check
+ read_options = kwargs.setdefault('read_options', ReadOptions())
+ read_options.use_threads = self.use_threads
+ table = read_csv(csv, *args, **kwargs)
+ table.validate(full=validate_full)
+ return table
+
+ def read_bytes(self, b, **kwargs):
+ return self.read_csv(pa.py_buffer(b), **kwargs)
+
+ def test_file_object(self):
+ data = b"a,b\n1,2\n"
+ expected_data = {'a': [1], 'b': [2]}
+ bio = io.BytesIO(data)
+ table = self.read_csv(bio)
+ assert table.to_pydict() == expected_data
+ # Text files not allowed
+ sio = io.StringIO(data.decode())
+ with pytest.raises(TypeError):
+ self.read_csv(sio)
+
+ def test_header(self):
+ rows = b"abc,def,gh\n"
+ table = self.read_bytes(rows)
+ assert isinstance(table, pa.Table)
+ self.check_names(table, ["abc", "def", "gh"])
+ assert table.num_rows == 0
+
+ def test_bom(self):
+ rows = b"\xef\xbb\xbfa,b\n1,2\n"
+ expected_data = {'a': [1], 'b': [2]}
+ table = self.read_bytes(rows)
+ assert table.to_pydict() == expected_data
+
+ def test_one_chunk(self):
+ # ARROW-7661: lack of newline at end of file should not produce
+ # an additional chunk.
+ rows = [b"a,b", b"1,2", b"3,4", b"56,78"]
+ for line_ending in [b'\n', b'\r', b'\r\n']:
+ for file_ending in [b'', line_ending]:
+ data = line_ending.join(rows) + file_ending
+ table = self.read_bytes(data)
+ assert len(table.to_batches()) == 1
+ assert table.to_pydict() == {
+ "a": [1, 3, 56],
+ "b": [2, 4, 78],
+ }
+
+ def test_header_column_names(self):
+ rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n"
+
+ opts = ReadOptions()
+ opts.column_names = ["x", "y"]
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["x", "y"])
+ assert table.to_pydict() == {
+ "x": ["ab", "ef", "ij", "mn"],
+ "y": ["cd", "gh", "kl", "op"],
+ }
+
+ opts.skip_rows = 3
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["x", "y"])
+ assert table.to_pydict() == {
+ "x": ["mn"],
+ "y": ["op"],
+ }
+
+ opts.skip_rows = 4
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["x", "y"])
+ assert table.to_pydict() == {
+ "x": [],
+ "y": [],
+ }
+
+ opts.skip_rows = 5
+ with pytest.raises(pa.ArrowInvalid):
+ # Not enough rows
+ table = self.read_bytes(rows, read_options=opts)
+
+ # Unexpected number of columns
+ opts.skip_rows = 0
+ opts.column_names = ["x", "y", "z"]
+ with pytest.raises(pa.ArrowInvalid,
+ match="Expected 3 columns, got 2"):
+ table = self.read_bytes(rows, read_options=opts)
+
+ # Can skip rows with a different number of columns
+ rows = b"abcd\n,,,,,\nij,kl\nmn,op\n"
+ opts.skip_rows = 2
+ opts.column_names = ["x", "y"]
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["x", "y"])
+ assert table.to_pydict() == {
+ "x": ["ij", "mn"],
+ "y": ["kl", "op"],
+ }
+
+ def test_header_autogenerate_column_names(self):
+ rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n"
+
+ opts = ReadOptions()
+ opts.autogenerate_column_names = True
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["f0", "f1"])
+ assert table.to_pydict() == {
+ "f0": ["ab", "ef", "ij", "mn"],
+ "f1": ["cd", "gh", "kl", "op"],
+ }
+
+ opts.skip_rows = 3
+ table = self.read_bytes(rows, read_options=opts)
+ self.check_names(table, ["f0", "f1"])
+ assert table.to_pydict() == {
+ "f0": ["mn"],
+ "f1": ["op"],
+ }
+
+ # Not enough rows, impossible to infer number of columns
+ opts.skip_rows = 4
+ with pytest.raises(pa.ArrowInvalid):
+ table = self.read_bytes(rows, read_options=opts)
+
+ def test_include_columns(self):
+ rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n"
+
+ convert_options = ConvertOptions()
+ convert_options.include_columns = ['ab']
+ table = self.read_bytes(rows, convert_options=convert_options)
+ self.check_names(table, ["ab"])
+ assert table.to_pydict() == {
+ "ab": ["ef", "ij", "mn"],
+ }
+
+ # Order of include_columns is respected, regardless of CSV order
+ convert_options.include_columns = ['cd', 'ab']
+ table = self.read_bytes(rows, convert_options=convert_options)
+ schema = pa.schema([('cd', pa.string()),
+ ('ab', pa.string())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ "cd": ["gh", "kl", "op"],
+ "ab": ["ef", "ij", "mn"],
+ }
+
+ # Include a column not in the CSV file => raises by default
+ convert_options.include_columns = ['xx', 'ab', 'yy']
+ with pytest.raises(KeyError,
+ match="Column 'xx' in include_columns "
+ "does not exist in CSV file"):
+ self.read_bytes(rows, convert_options=convert_options)
+
+ def test_include_missing_columns(self):
+ rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n"
+
+ read_options = ReadOptions()
+ convert_options = ConvertOptions()
+ convert_options.include_columns = ['xx', 'ab', 'yy']
+ convert_options.include_missing_columns = True
+ table = self.read_bytes(rows, read_options=read_options,
+ convert_options=convert_options)
+ schema = pa.schema([('xx', pa.null()),
+ ('ab', pa.string()),
+ ('yy', pa.null())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ "xx": [None, None, None],
+ "ab": ["ef", "ij", "mn"],
+ "yy": [None, None, None],
+ }
+
+ # Combining with `column_names`
+ read_options.column_names = ["xx", "yy"]
+ convert_options.include_columns = ["yy", "cd"]
+ table = self.read_bytes(rows, read_options=read_options,
+ convert_options=convert_options)
+ schema = pa.schema([('yy', pa.string()),
+ ('cd', pa.null())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ "yy": ["cd", "gh", "kl", "op"],
+ "cd": [None, None, None, None],
+ }
+
+ # And with `column_types` as well
+ convert_options.column_types = {"yy": pa.binary(),
+ "cd": pa.int32()}
+ table = self.read_bytes(rows, read_options=read_options,
+ convert_options=convert_options)
+ schema = pa.schema([('yy', pa.binary()),
+ ('cd', pa.int32())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ "yy": [b"cd", b"gh", b"kl", b"op"],
+ "cd": [None, None, None, None],
+ }
+
+ def test_simple_ints(self):
+ # Infer integer columns
+ rows = b"a,b,c\n1,2,3\n4,5,6\n"
+ table = self.read_bytes(rows)
+ schema = pa.schema([('a', pa.int64()),
+ ('b', pa.int64()),
+ ('c', pa.int64())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [1, 4],
+ 'b': [2, 5],
+ 'c': [3, 6],
+ }
+
+ def test_simple_varied(self):
+ # Infer various kinds of data
+ rows = b"a,b,c,d\n1,2,3,0\n4.0,-5,foo,True\n"
+ table = self.read_bytes(rows)
+ schema = pa.schema([('a', pa.float64()),
+ ('b', pa.int64()),
+ ('c', pa.string()),
+ ('d', pa.bool_())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [1.0, 4.0],
+ 'b': [2, -5],
+ 'c': ["3", "foo"],
+ 'd': [False, True],
+ }
+
+ def test_simple_nulls(self):
+ # Infer various kinds of data, with nulls
+ rows = (b"a,b,c,d,e,f\n"
+ b"1,2,,,3,N/A\n"
+ b"nan,-5,foo,,nan,TRUE\n"
+ b"4.5,#N/A,nan,,\xff,false\n")
+ table = self.read_bytes(rows)
+ schema = pa.schema([('a', pa.float64()),
+ ('b', pa.int64()),
+ ('c', pa.string()),
+ ('d', pa.null()),
+ ('e', pa.binary()),
+ ('f', pa.bool_())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [1.0, None, 4.5],
+ 'b': [2, -5, None],
+ 'c': ["", "foo", "nan"],
+ 'd': [None, None, None],
+ 'e': [b"3", b"nan", b"\xff"],
+ 'f': [None, True, False],
+ }
+
+ def test_decimal_point(self):
+ # Infer floats with a custom decimal point
+ parse_options = ParseOptions(delimiter=';')
+ rows = b"a;b\n1.25;2,5\nNA;-3\n-4;NA"
+
+ table = self.read_bytes(rows, parse_options=parse_options)
+ schema = pa.schema([('a', pa.float64()),
+ ('b', pa.string())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [1.25, None, -4.0],
+ 'b': ["2,5", "-3", "NA"],
+ }
+
+ convert_options = ConvertOptions(decimal_point=',')
+ table = self.read_bytes(rows, parse_options=parse_options,
+ convert_options=convert_options)
+ schema = pa.schema([('a', pa.string()),
+ ('b', pa.float64())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': ["1.25", "NA", "-4"],
+ 'b': [2.5, -3.0, None],
+ }
+
+ def test_simple_timestamps(self):
+ # Infer a timestamp column
+ rows = (b"a,b,c\n"
+ b"1970,1970-01-01 00:00:00,1970-01-01 00:00:00.123\n"
+ b"1989,1989-07-14 01:00:00,1989-07-14 01:00:00.123456\n")
+ table = self.read_bytes(rows)
+ schema = pa.schema([('a', pa.int64()),
+ ('b', pa.timestamp('s')),
+ ('c', pa.timestamp('ns'))])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [1970, 1989],
+ 'b': [datetime(1970, 1, 1), datetime(1989, 7, 14, 1)],
+ 'c': [datetime(1970, 1, 1, 0, 0, 0, 123000),
+ datetime(1989, 7, 14, 1, 0, 0, 123456)],
+ }
+
+ def test_timestamp_parsers(self):
+ # Infer timestamps with custom parsers
+ rows = b"a,b\n1970/01/01,1980-01-01 00\n1970/01/02,1980-01-02 00\n"
+ opts = ConvertOptions()
+
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.string()),
+ ('b', pa.timestamp('s'))])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': ['1970/01/01', '1970/01/02'],
+ 'b': [datetime(1980, 1, 1), datetime(1980, 1, 2)],
+ }
+
+ opts.timestamp_parsers = ['%Y/%m/%d']
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.timestamp('s')),
+ ('b', pa.string())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [datetime(1970, 1, 1), datetime(1970, 1, 2)],
+ 'b': ['1980-01-01 00', '1980-01-02 00'],
+ }
+
+ opts.timestamp_parsers = ['%Y/%m/%d', ISO8601]
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.timestamp('s')),
+ ('b', pa.timestamp('s'))])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [datetime(1970, 1, 1), datetime(1970, 1, 2)],
+ 'b': [datetime(1980, 1, 1), datetime(1980, 1, 2)],
+ }
+
+ def test_dates(self):
+ # Dates are inferred as date32 by default
+ rows = b"a,b\n1970-01-01,1970-01-02\n1971-01-01,1971-01-02\n"
+ table = self.read_bytes(rows)
+ schema = pa.schema([('a', pa.date32()),
+ ('b', pa.date32())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [date(1970, 1, 1), date(1971, 1, 1)],
+ 'b': [date(1970, 1, 2), date(1971, 1, 2)],
+ }
+
+ # Can ask for date types explicitly
+ opts = ConvertOptions()
+ opts.column_types = {'a': pa.date32(), 'b': pa.date64()}
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.date32()),
+ ('b', pa.date64())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [date(1970, 1, 1), date(1971, 1, 1)],
+ 'b': [date(1970, 1, 2), date(1971, 1, 2)],
+ }
+
+ # Can ask for timestamp types explicitly
+ opts = ConvertOptions()
+ opts.column_types = {'a': pa.timestamp('s'), 'b': pa.timestamp('ms')}
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.timestamp('s')),
+ ('b', pa.timestamp('ms'))])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [datetime(1970, 1, 1), datetime(1971, 1, 1)],
+ 'b': [datetime(1970, 1, 2), datetime(1971, 1, 2)],
+ }
+
+ def test_times(self):
+ # Times are inferred as time32[s] by default
+ from datetime import time
+
+ rows = b"a,b\n12:34:56,12:34:56.789\n23:59:59,23:59:59.999\n"
+ table = self.read_bytes(rows)
+ # Column 'b' has subseconds, so cannot be inferred as time32[s]
+ schema = pa.schema([('a', pa.time32('s')),
+ ('b', pa.string())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [time(12, 34, 56), time(23, 59, 59)],
+ 'b': ["12:34:56.789", "23:59:59.999"],
+ }
+
+ # Can ask for time types explicitly
+ opts = ConvertOptions()
+ opts.column_types = {'a': pa.time64('us'), 'b': pa.time32('ms')}
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.time64('us')),
+ ('b', pa.time32('ms'))])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [time(12, 34, 56), time(23, 59, 59)],
+ 'b': [time(12, 34, 56, 789000), time(23, 59, 59, 999000)],
+ }
+
+ def test_auto_dict_encode(self):
+ opts = ConvertOptions(auto_dict_encode=True)
+ rows = "a,b\nab,1\ncdé,2\ncdé,3\nab,4".encode()
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.dictionary(pa.int32(), pa.string())),
+ ('b', pa.int64())])
+ expected = {
+ 'a': ["ab", "cdé", "cdé", "ab"],
+ 'b': [1, 2, 3, 4],
+ }
+ assert table.schema == schema
+ assert table.to_pydict() == expected
+
+ opts.auto_dict_max_cardinality = 2
+ table = self.read_bytes(rows, convert_options=opts)
+ assert table.schema == schema
+ assert table.to_pydict() == expected
+
+ # Cardinality above max => plain-encoded
+ opts.auto_dict_max_cardinality = 1
+ table = self.read_bytes(rows, convert_options=opts)
+ assert table.schema == pa.schema([('a', pa.string()),
+ ('b', pa.int64())])
+ assert table.to_pydict() == expected
+
+ # With invalid UTF8, not checked
+ opts.auto_dict_max_cardinality = 50
+ opts.check_utf8 = False
+ rows = b"a,b\nab,1\ncd\xff,2\nab,3"
+ table = self.read_bytes(rows, convert_options=opts,
+ validate_full=False)
+ assert table.schema == schema
+ dict_values = table['a'].chunk(0).dictionary
+ assert len(dict_values) == 2
+ assert dict_values[0].as_py() == "ab"
+ assert dict_values[1].as_buffer() == b"cd\xff"
+
+ # With invalid UTF8, checked
+ opts.check_utf8 = True
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.dictionary(pa.int32(), pa.binary())),
+ ('b', pa.int64())])
+ expected = {
+ 'a': [b"ab", b"cd\xff", b"ab"],
+ 'b': [1, 2, 3],
+ }
+ assert table.schema == schema
+ assert table.to_pydict() == expected
+
+ def test_custom_nulls(self):
+ # Infer nulls with custom values
+ opts = ConvertOptions(null_values=['Xxx', 'Zzz'])
+ rows = b"""a,b,c,d\nZzz,"Xxx",1,2\nXxx,#N/A,,Zzz\n"""
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.null()),
+ ('b', pa.string()),
+ ('c', pa.string()),
+ ('d', pa.int64())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [None, None],
+ 'b': ["Xxx", "#N/A"],
+ 'c': ["1", ""],
+ 'd': [2, None],
+ }
+
+ opts = ConvertOptions(null_values=['Xxx', 'Zzz'],
+ strings_can_be_null=True)
+ table = self.read_bytes(rows, convert_options=opts)
+ assert table.to_pydict() == {
+ 'a': [None, None],
+ 'b': [None, "#N/A"],
+ 'c': ["1", ""],
+ 'd': [2, None],
+ }
+ opts.quoted_strings_can_be_null = False
+ table = self.read_bytes(rows, convert_options=opts)
+ assert table.to_pydict() == {
+ 'a': [None, None],
+ 'b': ["Xxx", "#N/A"],
+ 'c': ["1", ""],
+ 'd': [2, None],
+ }
+
+ opts = ConvertOptions(null_values=[])
+ rows = b"a,b\n#N/A,\n"
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.string()),
+ ('b', pa.string())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': ["#N/A"],
+ 'b': [""],
+ }
+
+ def test_custom_bools(self):
+ # Infer booleans with custom values
+ opts = ConvertOptions(true_values=['T', 'yes'],
+ false_values=['F', 'no'])
+ rows = (b"a,b,c\n"
+ b"True,T,t\n"
+ b"False,F,f\n"
+ b"True,yes,yes\n"
+ b"False,no,no\n"
+ b"N/A,N/A,N/A\n")
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.string()),
+ ('b', pa.bool_()),
+ ('c', pa.string())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': ["True", "False", "True", "False", "N/A"],
+ 'b': [True, False, True, False, None],
+ 'c': ["t", "f", "yes", "no", "N/A"],
+ }
+
+ def test_column_types(self):
+ # Ask for specific column types in ConvertOptions
+ opts = ConvertOptions(column_types={'b': 'float32',
+ 'c': 'string',
+ 'd': 'boolean',
+ 'e': pa.decimal128(11, 2),
+ 'zz': 'null'})
+ rows = b"a,b,c,d,e\n1,2,3,true,1.0\n4,-5,6,false,0\n"
+ table = self.read_bytes(rows, convert_options=opts)
+ schema = pa.schema([('a', pa.int64()),
+ ('b', pa.float32()),
+ ('c', pa.string()),
+ ('d', pa.bool_()),
+ ('e', pa.decimal128(11, 2))])
+ expected = {
+ 'a': [1, 4],
+ 'b': [2.0, -5.0],
+ 'c': ["3", "6"],
+ 'd': [True, False],
+ 'e': [Decimal("1.00"), Decimal("0.00")]
+ }
+ assert table.schema == schema
+ assert table.to_pydict() == expected
+ # Pass column_types as schema
+ opts = ConvertOptions(
+ column_types=pa.schema([('b', pa.float32()),
+ ('c', pa.string()),
+ ('d', pa.bool_()),
+ ('e', pa.decimal128(11, 2)),
+ ('zz', pa.bool_())]))
+ table = self.read_bytes(rows, convert_options=opts)
+ assert table.schema == schema
+ assert table.to_pydict() == expected
+ # One of the columns in column_types fails converting
+ rows = b"a,b,c,d,e\n1,XXX,3,true,5\n4,-5,6,false,7\n"
+ with pytest.raises(pa.ArrowInvalid) as exc:
+ self.read_bytes(rows, convert_options=opts)
+ err = str(exc.value)
+ assert "In CSV column #1: " in err
+ assert "CSV conversion error to float: invalid value 'XXX'" in err
+
+ def test_column_types_dict(self):
+ # Ask for dict-encoded column types in ConvertOptions
+ column_types = [
+ ('a', pa.dictionary(pa.int32(), pa.utf8())),
+ ('b', pa.dictionary(pa.int32(), pa.int64())),
+ ('c', pa.dictionary(pa.int32(), pa.decimal128(11, 2))),
+ ('d', pa.dictionary(pa.int32(), pa.large_utf8()))]
+
+ opts = ConvertOptions(column_types=dict(column_types))
+ rows = (b"a,b,c,d\n"
+ b"abc,123456,1.0,zz\n"
+ b"defg,123456,0.5,xx\n"
+ b"abc,N/A,1.0,xx\n")
+ table = self.read_bytes(rows, convert_options=opts)
+
+ schema = pa.schema(column_types)
+ expected = {
+ 'a': ["abc", "defg", "abc"],
+ 'b': [123456, 123456, None],
+ 'c': [Decimal("1.00"), Decimal("0.50"), Decimal("1.00")],
+ 'd': ["zz", "xx", "xx"],
+ }
+ assert table.schema == schema
+ assert table.to_pydict() == expected
+
+ # Unsupported index type
+ column_types[0] = ('a', pa.dictionary(pa.int8(), pa.utf8()))
+
+ opts = ConvertOptions(column_types=dict(column_types))
+ with pytest.raises(NotImplementedError):
+ table = self.read_bytes(rows, convert_options=opts)
+
+ def test_column_types_with_column_names(self):
+ # When both `column_names` and `column_types` are given, names
+ # in `column_types` should refer to names in `column_names`
+ rows = b"a,b\nc,d\ne,f\n"
+ read_options = ReadOptions(column_names=['x', 'y'])
+ convert_options = ConvertOptions(column_types={'x': pa.binary()})
+ table = self.read_bytes(rows, read_options=read_options,
+ convert_options=convert_options)
+ schema = pa.schema([('x', pa.binary()),
+ ('y', pa.string())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'x': [b'a', b'c', b'e'],
+ 'y': ['b', 'd', 'f'],
+ }
+
+ def test_no_ending_newline(self):
+ # No \n after last line
+ rows = b"a,b,c\n1,2,3\n4,5,6"
+ table = self.read_bytes(rows)
+ assert table.to_pydict() == {
+ 'a': [1, 4],
+ 'b': [2, 5],
+ 'c': [3, 6],
+ }
+
+ def test_trivial(self):
+ # A bit pointless, but at least it shouldn't crash
+ rows = b",\n\n"
+ table = self.read_bytes(rows)
+ assert table.to_pydict() == {'': []}
+
+ def test_empty_lines(self):
+ rows = b"a,b\n\r1,2\r\n\r\n3,4\r\n"
+ table = self.read_bytes(rows)
+ assert table.to_pydict() == {
+ 'a': [1, 3],
+ 'b': [2, 4],
+ }
+ parse_options = ParseOptions(ignore_empty_lines=False)
+ table = self.read_bytes(rows, parse_options=parse_options)
+ assert table.to_pydict() == {
+ 'a': [None, 1, None, 3],
+ 'b': [None, 2, None, 4],
+ }
+ read_options = ReadOptions(skip_rows=2)
+ table = self.read_bytes(rows, parse_options=parse_options,
+ read_options=read_options)
+ assert table.to_pydict() == {
+ '1': [None, 3],
+ '2': [None, 4],
+ }
+
+ def test_invalid_csv(self):
+ # Various CSV errors
+ rows = b"a,b,c\n1,2\n4,5,6\n"
+ with pytest.raises(pa.ArrowInvalid, match="Expected 3 columns, got 2"):
+ self.read_bytes(rows)
+ rows = b"a,b,c\n1,2,3\n4"
+ with pytest.raises(pa.ArrowInvalid, match="Expected 3 columns, got 1"):
+ self.read_bytes(rows)
+ for rows in [b"", b"\n", b"\r\n", b"\r", b"\n\n"]:
+ with pytest.raises(pa.ArrowInvalid, match="Empty CSV file"):
+ self.read_bytes(rows)
+
+ def test_options_delimiter(self):
+ rows = b"a;b,c\nde,fg;eh\n"
+ table = self.read_bytes(rows)
+ assert table.to_pydict() == {
+ 'a;b': ['de'],
+ 'c': ['fg;eh'],
+ }
+ opts = ParseOptions(delimiter=';')
+ table = self.read_bytes(rows, parse_options=opts)
+ assert table.to_pydict() == {
+ 'a': ['de,fg'],
+ 'b,c': ['eh'],
+ }
+
+ def test_small_random_csv(self):
+ csv, expected = make_random_csv(num_cols=2, num_rows=10)
+ table = self.read_bytes(csv)
+ assert table.schema == expected.schema
+ assert table.equals(expected)
+ assert table.to_pydict() == expected.to_pydict()
+
+ def test_stress_block_sizes(self):
+ # Test a number of small block sizes to stress block stitching
+ csv_base, expected = make_random_csv(num_cols=2, num_rows=500)
+ block_sizes = [11, 12, 13, 17, 37, 111]
+ csvs = [csv_base, csv_base.rstrip(b'\r\n')]
+ for csv in csvs:
+ for block_size in block_sizes:
+ read_options = ReadOptions(block_size=block_size)
+ table = self.read_bytes(csv, read_options=read_options)
+ assert table.schema == expected.schema
+ if not table.equals(expected):
+ # Better error output
+ assert table.to_pydict() == expected.to_pydict()
+
+ def test_stress_convert_options_blowup(self):
+ # ARROW-6481: A convert_options with a very large number of columns
+ # should not blow up memory use or CPU time.
+ try:
+ clock = time.thread_time
+ except AttributeError:
+ clock = time.time
+ num_columns = 10000
+ col_names = ["K{}".format(i) for i in range(num_columns)]
+ csv = make_empty_csv(col_names)
+ t1 = clock()
+ convert_options = ConvertOptions(
+ column_types={k: pa.string() for k in col_names[::2]})
+ table = self.read_bytes(csv, convert_options=convert_options)
+ dt = clock() - t1
+ # Check that processing time didn't blow up.
+ # This is a conservative check (it takes less than 300 ms
+ # in debug mode on my local machine).
+ assert dt <= 10.0
+ # Check result
+ assert table.num_columns == num_columns
+ assert table.num_rows == 0
+ assert table.column_names == col_names
+
+ def test_cancellation(self):
+ if (threading.current_thread().ident !=
+ threading.main_thread().ident):
+ pytest.skip("test only works from main Python thread")
+ # Skips test if not available
+ raise_signal = util.get_raise_signal()
+
+ # Make the interruptible workload large enough to not finish
+ # before the interrupt comes, even in release mode on fast machines.
+ last_duration = 0.0
+ workload_size = 100_000
+
+ while last_duration < 1.0:
+ print("workload size:", workload_size)
+ large_csv = b"a,b,c\n" + b"1,2,3\n" * workload_size
+ t1 = time.time()
+ self.read_bytes(large_csv)
+ last_duration = time.time() - t1
+ workload_size = workload_size * 3
+
+ def signal_from_thread():
+ time.sleep(0.2)
+ raise_signal(signal.SIGINT)
+
+ t1 = time.time()
+ try:
+ try:
+ t = threading.Thread(target=signal_from_thread)
+ with pytest.raises(KeyboardInterrupt) as exc_info:
+ t.start()
+ self.read_bytes(large_csv)
+ finally:
+ t.join()
+ except KeyboardInterrupt:
+ # In case KeyboardInterrupt didn't interrupt `self.read_bytes`
+ # above, at least prevent it from stopping the test suite
+ pytest.fail("KeyboardInterrupt didn't interrupt CSV reading")
+ dt = time.time() - t1
+ # Interruption should have arrived timely
+ assert dt <= 1.0
+ e = exc_info.value.__context__
+ assert isinstance(e, pa.ArrowCancelled)
+ assert e.signum == signal.SIGINT
+
+ def test_cancellation_disabled(self):
+ # ARROW-12622: reader would segfault when the cancelling signal
+ # handler was not enabled (e.g. if disabled, or if not on the
+ # main thread)
+ t = threading.Thread(
+ target=lambda: self.read_bytes(b"f64\n0.1"))
+ t.start()
+ t.join()
+
+
+class TestSerialCSVTableRead(BaseCSVTableRead):
+ @property
+ def use_threads(self):
+ return False
+
+
+class TestThreadedCSVTableRead(BaseCSVTableRead):
+ @property
+ def use_threads(self):
+ return True
+
+
+class BaseStreamingCSVRead(BaseTestCSV):
+
+ def open_csv(self, csv, *args, **kwargs):
+ """
+ Reads the CSV file into memory using pyarrow's open_csv
+ csv The CSV bytes
+ args Positional arguments to be forwarded to pyarrow's open_csv
+ kwargs Keyword arguments to be forwarded to pyarrow's open_csv
+ """
+ read_options = kwargs.setdefault('read_options', ReadOptions())
+ read_options.use_threads = self.use_threads
+ return open_csv(csv, *args, **kwargs)
+
+ def open_bytes(self, b, **kwargs):
+ return self.open_csv(pa.py_buffer(b), **kwargs)
+
+ def check_reader(self, reader, expected_schema, expected_data):
+ assert reader.schema == expected_schema
+ batches = list(reader)
+ assert len(batches) == len(expected_data)
+ for batch, expected_batch in zip(batches, expected_data):
+ batch.validate(full=True)
+ assert batch.schema == expected_schema
+ assert batch.to_pydict() == expected_batch
+
+ def read_bytes(self, b, **kwargs):
+ return self.open_bytes(b, **kwargs).read_all()
+
+ def test_file_object(self):
+ data = b"a,b\n1,2\n3,4\n"
+ expected_data = {'a': [1, 3], 'b': [2, 4]}
+ bio = io.BytesIO(data)
+ reader = self.open_csv(bio)
+ expected_schema = pa.schema([('a', pa.int64()),
+ ('b', pa.int64())])
+ self.check_reader(reader, expected_schema, [expected_data])
+
+ def test_header(self):
+ rows = b"abc,def,gh\n"
+ reader = self.open_bytes(rows)
+ expected_schema = pa.schema([('abc', pa.null()),
+ ('def', pa.null()),
+ ('gh', pa.null())])
+ self.check_reader(reader, expected_schema, [])
+
+ def test_inference(self):
+ # Inference is done on first block
+ rows = b"a,b\n123,456\nabc,de\xff\ngh,ij\n"
+ expected_schema = pa.schema([('a', pa.string()),
+ ('b', pa.binary())])
+
+ read_options = ReadOptions()
+ read_options.block_size = len(rows)
+ reader = self.open_bytes(rows, read_options=read_options)
+ self.check_reader(reader, expected_schema,
+ [{'a': ['123', 'abc', 'gh'],
+ 'b': [b'456', b'de\xff', b'ij']}])
+
+ read_options.block_size = len(rows) - 1
+ reader = self.open_bytes(rows, read_options=read_options)
+ self.check_reader(reader, expected_schema,
+ [{'a': ['123', 'abc'],
+ 'b': [b'456', b'de\xff']},
+ {'a': ['gh'],
+ 'b': [b'ij']}])
+
+ def test_inference_failure(self):
+ # Inference on first block, then conversion failure on second block
+ rows = b"a,b\n123,456\nabc,de\xff\ngh,ij\n"
+ read_options = ReadOptions()
+ read_options.block_size = len(rows) - 7
+ reader = self.open_bytes(rows, read_options=read_options)
+ expected_schema = pa.schema([('a', pa.int64()),
+ ('b', pa.int64())])
+ assert reader.schema == expected_schema
+ assert reader.read_next_batch().to_pydict() == {
+ 'a': [123], 'b': [456]
+ }
+ # Second block
+ with pytest.raises(ValueError,
+ match="CSV conversion error to int64"):
+ reader.read_next_batch()
+ # EOF
+ with pytest.raises(StopIteration):
+ reader.read_next_batch()
+
+ def test_invalid_csv(self):
+ # CSV errors on first block
+ rows = b"a,b\n1,2,3\n4,5\n6,7\n"
+ read_options = ReadOptions()
+ read_options.block_size = 10
+ with pytest.raises(pa.ArrowInvalid,
+ match="Expected 2 columns, got 3"):
+ reader = self.open_bytes(
+ rows, read_options=read_options)
+
+ # CSV errors on second block
+ rows = b"a,b\n1,2\n3,4,5\n6,7\n"
+ read_options.block_size = 8
+ reader = self.open_bytes(rows, read_options=read_options)
+ assert reader.read_next_batch().to_pydict() == {'a': [1], 'b': [2]}
+ with pytest.raises(pa.ArrowInvalid,
+ match="Expected 2 columns, got 3"):
+ reader.read_next_batch()
+ # Cannot continue after a parse error
+ with pytest.raises(StopIteration):
+ reader.read_next_batch()
+
+ def test_options_delimiter(self):
+ rows = b"a;b,c\nde,fg;eh\n"
+ reader = self.open_bytes(rows)
+ expected_schema = pa.schema([('a;b', pa.string()),
+ ('c', pa.string())])
+ self.check_reader(reader, expected_schema,
+ [{'a;b': ['de'],
+ 'c': ['fg;eh']}])
+
+ opts = ParseOptions(delimiter=';')
+ reader = self.open_bytes(rows, parse_options=opts)
+ expected_schema = pa.schema([('a', pa.string()),
+ ('b,c', pa.string())])
+ self.check_reader(reader, expected_schema,
+ [{'a': ['de,fg'],
+ 'b,c': ['eh']}])
+
+ def test_no_ending_newline(self):
+ # No \n after last line
+ rows = b"a,b,c\n1,2,3\n4,5,6"
+ reader = self.open_bytes(rows)
+ expected_schema = pa.schema([('a', pa.int64()),
+ ('b', pa.int64()),
+ ('c', pa.int64())])
+ self.check_reader(reader, expected_schema,
+ [{'a': [1, 4],
+ 'b': [2, 5],
+ 'c': [3, 6]}])
+
+ def test_empty_file(self):
+ with pytest.raises(ValueError, match="Empty CSV file"):
+ self.open_bytes(b"")
+
+ def test_column_options(self):
+ # With column_names
+ rows = b"1,2,3\n4,5,6"
+ read_options = ReadOptions()
+ read_options.column_names = ['d', 'e', 'f']
+ reader = self.open_bytes(rows, read_options=read_options)
+ expected_schema = pa.schema([('d', pa.int64()),
+ ('e', pa.int64()),
+ ('f', pa.int64())])
+ self.check_reader(reader, expected_schema,
+ [{'d': [1, 4],
+ 'e': [2, 5],
+ 'f': [3, 6]}])
+
+ # With include_columns
+ convert_options = ConvertOptions()
+ convert_options.include_columns = ['f', 'e']
+ reader = self.open_bytes(rows, read_options=read_options,
+ convert_options=convert_options)
+ expected_schema = pa.schema([('f', pa.int64()),
+ ('e', pa.int64())])
+ self.check_reader(reader, expected_schema,
+ [{'e': [2, 5],
+ 'f': [3, 6]}])
+
+ # With column_types
+ convert_options.column_types = {'e': pa.string()}
+ reader = self.open_bytes(rows, read_options=read_options,
+ convert_options=convert_options)
+ expected_schema = pa.schema([('f', pa.int64()),
+ ('e', pa.string())])
+ self.check_reader(reader, expected_schema,
+ [{'e': ["2", "5"],
+ 'f': [3, 6]}])
+
+ # Missing columns in include_columns
+ convert_options.include_columns = ['g', 'f', 'e']
+ with pytest.raises(
+ KeyError,
+ match="Column 'g' in include_columns does not exist"):
+ reader = self.open_bytes(rows, read_options=read_options,
+ convert_options=convert_options)
+
+ convert_options.include_missing_columns = True
+ reader = self.open_bytes(rows, read_options=read_options,
+ convert_options=convert_options)
+ expected_schema = pa.schema([('g', pa.null()),
+ ('f', pa.int64()),
+ ('e', pa.string())])
+ self.check_reader(reader, expected_schema,
+ [{'g': [None, None],
+ 'e': ["2", "5"],
+ 'f': [3, 6]}])
+
+ convert_options.column_types = {'e': pa.string(), 'g': pa.float64()}
+ reader = self.open_bytes(rows, read_options=read_options,
+ convert_options=convert_options)
+ expected_schema = pa.schema([('g', pa.float64()),
+ ('f', pa.int64()),
+ ('e', pa.string())])
+ self.check_reader(reader, expected_schema,
+ [{'g': [None, None],
+ 'e': ["2", "5"],
+ 'f': [3, 6]}])
+
+ def test_encoding(self):
+ # latin-1 (invalid utf-8)
+ rows = b"a,b\nun,\xe9l\xe9phant"
+ read_options = ReadOptions()
+ reader = self.open_bytes(rows, read_options=read_options)
+ expected_schema = pa.schema([('a', pa.string()),
+ ('b', pa.binary())])
+ self.check_reader(reader, expected_schema,
+ [{'a': ["un"],
+ 'b': [b"\xe9l\xe9phant"]}])
+
+ read_options.encoding = 'latin1'
+ reader = self.open_bytes(rows, read_options=read_options)
+ expected_schema = pa.schema([('a', pa.string()),
+ ('b', pa.string())])
+ self.check_reader(reader, expected_schema,
+ [{'a': ["un"],
+ 'b': ["éléphant"]}])
+
+ # utf-16
+ rows = (b'\xff\xfea\x00,\x00b\x00\n\x00u\x00n\x00,'
+ b'\x00\xe9\x00l\x00\xe9\x00p\x00h\x00a\x00n\x00t\x00')
+ read_options.encoding = 'utf16'
+ reader = self.open_bytes(rows, read_options=read_options)
+ expected_schema = pa.schema([('a', pa.string()),
+ ('b', pa.string())])
+ self.check_reader(reader, expected_schema,
+ [{'a': ["un"],
+ 'b': ["éléphant"]}])
+
+ def test_small_random_csv(self):
+ csv, expected = make_random_csv(num_cols=2, num_rows=10)
+ reader = self.open_bytes(csv)
+ table = reader.read_all()
+ assert table.schema == expected.schema
+ assert table.equals(expected)
+ assert table.to_pydict() == expected.to_pydict()
+
+ def test_stress_block_sizes(self):
+ # Test a number of small block sizes to stress block stitching
+ csv_base, expected = make_random_csv(num_cols=2, num_rows=500)
+ block_sizes = [19, 21, 23, 26, 37, 111]
+ csvs = [csv_base, csv_base.rstrip(b'\r\n')]
+ for csv in csvs:
+ for block_size in block_sizes:
+ # Need at least two lines for type inference
+ assert csv[:block_size].count(b'\n') >= 2
+ read_options = ReadOptions(block_size=block_size)
+ reader = self.open_bytes(
+ csv, read_options=read_options)
+ table = reader.read_all()
+ assert table.schema == expected.schema
+ if not table.equals(expected):
+ # Better error output
+ assert table.to_pydict() == expected.to_pydict()
+
+ def test_batch_lifetime(self):
+ gc.collect()
+ old_allocated = pa.total_allocated_bytes()
+
+ # Memory occupation should not grow with CSV file size
+ def check_one_batch(reader, expected):
+ batch = reader.read_next_batch()
+ assert batch.to_pydict() == expected
+
+ rows = b"10,11\n12,13\n14,15\n16,17\n"
+ read_options = ReadOptions()
+ read_options.column_names = ['a', 'b']
+ read_options.block_size = 6
+ reader = self.open_bytes(rows, read_options=read_options)
+ check_one_batch(reader, {'a': [10], 'b': [11]})
+ allocated_after_first_batch = pa.total_allocated_bytes()
+ check_one_batch(reader, {'a': [12], 'b': [13]})
+ assert pa.total_allocated_bytes() <= allocated_after_first_batch
+ check_one_batch(reader, {'a': [14], 'b': [15]})
+ assert pa.total_allocated_bytes() <= allocated_after_first_batch
+ check_one_batch(reader, {'a': [16], 'b': [17]})
+ assert pa.total_allocated_bytes() <= allocated_after_first_batch
+ with pytest.raises(StopIteration):
+ reader.read_next_batch()
+ assert pa.total_allocated_bytes() == old_allocated
+ reader = None
+ assert pa.total_allocated_bytes() == old_allocated
+
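The batch-lifetime check above relies on the streaming reader materializing only one parsed block at a time. A minimal standalone sketch of the same incremental pattern, using the public csv.open_csv entry point rather than this suite's open_bytes helper (the sample data and block size are illustrative), could look like:

import io
import pyarrow as pa
from pyarrow import csv

data = b"a,b\n" + b"".join(b"%d,%d\n" % (i, 2 * i) for i in range(1000))
reader = csv.open_csv(io.BytesIO(data),
                      read_options=csv.ReadOptions(block_size=64))
while True:
    try:
        batch = reader.read_next_batch()   # one parsed block per iteration
    except StopIteration:
        break
    assert isinstance(batch, pa.RecordBatch)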
+ def test_header_skip_rows(self):
+ super().test_header_skip_rows()
+
+ rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n"
+
+ # Skipping all rows immediately results in end of iteration
+ opts = ReadOptions()
+ opts.skip_rows = 4
+ opts.column_names = ['ab', 'cd']
+ reader = self.open_bytes(rows, read_options=opts)
+ with pytest.raises(StopIteration):
+ assert reader.read_next_batch()
+
+ def test_skip_rows_after_names(self):
+ super().test_skip_rows_after_names()
+
+ rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n"
+
+ # Skipping all rows immediately results in end of iteration
+ opts = ReadOptions()
+ opts.skip_rows_after_names = 3
+ reader = self.open_bytes(rows, read_options=opts)
+ with pytest.raises(StopIteration):
+ assert reader.read_next_batch()
+
+ # Skipping beyond all rows immediately results in end of iteration
+ opts.skip_rows_after_names = 99999
+ reader = self.open_bytes(rows, read_options=opts)
+ with pytest.raises(StopIteration):
+ assert reader.read_next_batch()
+
+
+class TestSerialStreamingCSVRead(BaseStreamingCSVRead, unittest.TestCase):
+ @property
+ def use_threads(self):
+ return False
+
+
+class TestThreadedStreamingCSVRead(BaseStreamingCSVRead, unittest.TestCase):
+ @property
+ def use_threads(self):
+ return True
+
+
+class BaseTestCompressedCSVRead:
+
+ def setUp(self):
+ self.tmpdir = tempfile.mkdtemp(prefix='arrow-csv-test-')
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdir)
+
+ def read_csv(self, csv_path):
+ try:
+ return read_csv(csv_path)
+ except pa.ArrowNotImplementedError as e:
+ pytest.skip(str(e))
+
+ def test_random_csv(self):
+ csv, expected = make_random_csv(num_cols=2, num_rows=100)
+ csv_path = os.path.join(self.tmpdir, self.csv_filename)
+ self.write_file(csv_path, csv)
+ table = self.read_csv(csv_path)
+ table.validate(full=True)
+ assert table.schema == expected.schema
+ assert table.equals(expected)
+ assert table.to_pydict() == expected.to_pydict()
+
+
+class TestGZipCSVRead(BaseTestCompressedCSVRead, unittest.TestCase):
+ csv_filename = "compressed.csv.gz"
+
+ def write_file(self, path, contents):
+ with gzip.open(path, 'wb', 3) as f:
+ f.write(contents)
+
+ def test_concatenated(self):
+ # ARROW-5974
+ csv_path = os.path.join(self.tmpdir, self.csv_filename)
+ with gzip.open(csv_path, 'wb', 3) as f:
+ f.write(b"ab,cd\nef,gh\n")
+ with gzip.open(csv_path, 'ab', 3) as f:
+ f.write(b"ij,kl\nmn,op\n")
+ table = self.read_csv(csv_path)
+ assert table.to_pydict() == {
+ 'ab': ['ef', 'ij', 'mn'],
+ 'cd': ['gh', 'kl', 'op'],
+ }
+
+
+class TestBZ2CSVRead(BaseTestCompressedCSVRead, unittest.TestCase):
+ csv_filename = "compressed.csv.bz2"
+
+ def write_file(self, path, contents):
+ with bz2.BZ2File(path, 'w') as f:
+ f.write(contents)
+
+
+def test_read_csv_does_not_close_passed_file_handles():
+ # ARROW-4823
+ buf = io.BytesIO(b"a,b,c\n1,2,3\n4,5,6")
+ read_csv(buf)
+ assert not buf.closed
+
+
+def test_write_read_round_trip():
+ t = pa.Table.from_arrays([[1, 2, 3], ["a", "b", "c"]], ["c1", "c2"])
+ record_batch = t.to_batches(max_chunksize=4)[0]
+ for data in [t, record_batch]:
+ # Test with header
+ buf = io.BytesIO()
+ write_csv(data, buf, WriteOptions(include_header=True))
+ buf.seek(0)
+ assert t == read_csv(buf)
+
+ # Test without header
+ buf = io.BytesIO()
+ write_csv(data, buf, WriteOptions(include_header=False))
+ buf.seek(0)
+
+ read_options = ReadOptions(column_names=t.column_names)
+ assert t == read_csv(buf, read_options=read_options)
+
+ # Test with writer
+ for read_options, write_options in [
+ (None, WriteOptions(include_header=True)),
+ (ReadOptions(column_names=t.column_names),
+ WriteOptions(include_header=False)),
+ ]:
+ buf = io.BytesIO()
+ with CSVWriter(buf, t.schema, write_options=write_options) as writer:
+ writer.write_table(t)
+ buf.seek(0)
+ assert t == read_csv(buf, read_options=read_options)
+
+ buf = io.BytesIO()
+ with CSVWriter(buf, t.schema, write_options=write_options) as writer:
+ for batch in t.to_batches(max_chunksize=1):
+ writer.write_batch(batch)
+ buf.seek(0)
+ assert t == read_csv(buf, read_options=read_options)
+
+
+def test_read_csv_reference_cycle():
+ # ARROW-13187
+ def inner():
+ buf = io.BytesIO(b"a,b,c\n1,2,3\n4,5,6")
+ table = read_csv(buf)
+ return weakref.ref(table)
+
+ with util.disabled_gc():
+ wr = inner()
+ assert wr() is None
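util.disabled_gc() is a helper from this test suite; a rough standalone equivalent of the reference-cycle check, pausing the cyclic collector with gc.disable() directly, might be:

import gc
import io
import weakref
from pyarrow import csv

gc.disable()
try:
    table = csv.read_csv(io.BytesIO(b"a,b,c\n1,2,3\n4,5,6"))
    wr = weakref.ref(table)
    del table
    # With no reference cycle, the table is freed by refcounting alone,
    # without help from the cyclic garbage collector.
    assert wr() is None
finally:
    gc.enable()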
diff --git a/src/arrow/python/pyarrow/tests/test_cuda.py b/src/arrow/python/pyarrow/tests/test_cuda.py
new file mode 100644
index 000000000..2ba2f8267
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_cuda.py
@@ -0,0 +1,792 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+UNTESTED:
+read_message
+"""
+
+import sys
+import sysconfig
+
+import pytest
+
+import pyarrow as pa
+import numpy as np
+
+
+cuda = pytest.importorskip("pyarrow.cuda")
+
+platform = sysconfig.get_platform()
+# TODO: enable ppc64 when Arrow C++ supports IPC in ppc64 systems:
+has_ipc_support = platform == 'linux-x86_64' # or 'ppc64' in platform
+
+cuda_ipc = pytest.mark.skipif(
+ not has_ipc_support,
+ reason='CUDA IPC not supported in platform `%s`' % (platform))
+
+global_context = None # for flake8
+global_context1 = None # for flake8
+
+
+def setup_module(module):
+ module.global_context = cuda.Context(0)
+ module.global_context1 = cuda.Context(cuda.Context.get_num_devices() - 1)
+
+
+def teardown_module(module):
+ del module.global_context
+
+
+def test_Context():
+ assert cuda.Context.get_num_devices() > 0
+ assert global_context.device_number == 0
+ assert global_context1.device_number == cuda.Context.get_num_devices() - 1
+
+ with pytest.raises(ValueError,
+ match=("device_number argument must "
+ "be non-negative less than")):
+ cuda.Context(cuda.Context.get_num_devices())
+
+
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_manage_allocate_free_host(size):
+ buf = cuda.new_host_buffer(size)
+ arr = np.frombuffer(buf, dtype=np.uint8)
+ arr[size//4:3*size//4] = 1
+ arr_cp = arr.copy()
+ arr2 = np.frombuffer(buf, dtype=np.uint8)
+ np.testing.assert_equal(arr2, arr_cp)
+ assert buf.size == size
+
+
+def test_context_allocate_del():
+ bytes_allocated = global_context.bytes_allocated
+ cudabuf = global_context.new_buffer(128)
+ assert global_context.bytes_allocated == bytes_allocated + 128
+ del cudabuf
+ assert global_context.bytes_allocated == bytes_allocated
+
+
+def make_random_buffer(size, target='host'):
+ """Return a host or device buffer with random data.
+ """
+ if target == 'host':
+ assert size >= 0
+ buf = pa.allocate_buffer(size)
+ assert buf.size == size
+ arr = np.frombuffer(buf, dtype=np.uint8)
+ assert arr.size == size
+ arr[:] = np.random.randint(low=1, high=255, size=size, dtype=np.uint8)
+ assert arr.sum() > 0 or size == 0
+ arr_ = np.frombuffer(buf, dtype=np.uint8)
+ np.testing.assert_equal(arr, arr_)
+ return arr, buf
+ elif target == 'device':
+ arr, buf = make_random_buffer(size, target='host')
+ dbuf = global_context.new_buffer(size)
+ assert dbuf.size == size
+ dbuf.copy_from_host(buf, position=0, nbytes=size)
+ return arr, dbuf
+ raise ValueError('invalid target value')
+
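The host-to-device-and-back round trip that make_random_buffer and the tests below exercise reduces to the following sketch (assumes a CUDA-capable device 0 is available; the array contents are illustrative):

import numpy as np
import pyarrow as pa
from pyarrow import cuda

ctx = cuda.Context(0)
host = np.arange(64, dtype=np.uint8)
dbuf = ctx.new_buffer(host.nbytes)                    # device allocation
dbuf.copy_from_host(pa.py_buffer(host.tobytes()),
                    position=0, nbytes=host.nbytes)
back = np.frombuffer(dbuf.copy_to_host(), dtype=np.uint8)
np.testing.assert_equal(host, back)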
+
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_context_device_buffer(size):
+    # Creating device buffer from host buffer:
+ arr, buf = make_random_buffer(size)
+ cudabuf = global_context.buffer_from_data(buf)
+ assert cudabuf.size == size
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+ # CudaBuffer does not support buffer protocol
+ with pytest.raises(BufferError):
+ memoryview(cudabuf)
+
+ # Creating device buffer from array:
+ cudabuf = global_context.buffer_from_data(arr)
+ assert cudabuf.size == size
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+ # Creating device buffer from bytes:
+ cudabuf = global_context.buffer_from_data(arr.tobytes())
+ assert cudabuf.size == size
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+ # Creating a device buffer from another device buffer, view:
+ cudabuf2 = cudabuf.slice(0, cudabuf.size)
+ assert cudabuf2.size == size
+ arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+ if size > 1:
+ cudabuf2.copy_from_host(arr[size//2:])
+ arr3 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(np.concatenate((arr[size//2:], arr[size//2:])),
+ arr3)
+ cudabuf2.copy_from_host(arr[:size//2]) # restoring arr
+
+ # Creating a device buffer from another device buffer, copy:
+ cudabuf2 = global_context.buffer_from_data(cudabuf)
+ assert cudabuf2.size == size
+ arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+ cudabuf2.copy_from_host(arr[size//2:])
+ arr3 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr3)
+
+ # Slice of a device buffer
+ cudabuf2 = cudabuf.slice(0, cudabuf.size+10)
+ assert cudabuf2.size == size
+ arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+ cudabuf2 = cudabuf.slice(size//4, size+10)
+ assert cudabuf2.size == size - size//4
+ arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr[size//4:], arr2)
+
+ # Creating a device buffer from a slice of host buffer
+ soffset = size//4
+ ssize = 2*size//4
+ cudabuf = global_context.buffer_from_data(buf, offset=soffset,
+ size=ssize)
+ assert cudabuf.size == ssize
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
+
+ cudabuf = global_context.buffer_from_data(buf.slice(offset=soffset,
+ length=ssize))
+ assert cudabuf.size == ssize
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
+
+ # Creating a device buffer from a slice of an array
+ cudabuf = global_context.buffer_from_data(arr, offset=soffset, size=ssize)
+ assert cudabuf.size == ssize
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
+
+ cudabuf = global_context.buffer_from_data(arr[soffset:soffset+ssize])
+ assert cudabuf.size == ssize
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
+
+ # Creating a device buffer from a slice of bytes
+ cudabuf = global_context.buffer_from_data(arr.tobytes(),
+ offset=soffset,
+ size=ssize)
+ assert cudabuf.size == ssize
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr[soffset:soffset + ssize], arr2)
+
+ # Creating a device buffer from size
+ cudabuf = global_context.new_buffer(size)
+ assert cudabuf.size == size
+
+ # Creating device buffer from a slice of another device buffer:
+ cudabuf = global_context.buffer_from_data(arr)
+ cudabuf2 = cudabuf.slice(soffset, ssize)
+ assert cudabuf2.size == ssize
+ arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr[soffset:soffset+ssize], arr2)
+
+ # Creating device buffer from HostBuffer
+
+ buf = cuda.new_host_buffer(size)
+ arr_ = np.frombuffer(buf, dtype=np.uint8)
+ arr_[:] = arr
+ cudabuf = global_context.buffer_from_data(buf)
+ assert cudabuf.size == size
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+ # Creating device buffer from HostBuffer slice
+
+ cudabuf = global_context.buffer_from_data(buf, offset=soffset, size=ssize)
+ assert cudabuf.size == ssize
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr[soffset:soffset+ssize], arr2)
+
+ cudabuf = global_context.buffer_from_data(
+ buf.slice(offset=soffset, length=ssize))
+ assert cudabuf.size == ssize
+ arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr[soffset:soffset+ssize], arr2)
+
+
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_context_from_object(size):
+ ctx = global_context
+ arr, cbuf = make_random_buffer(size, target='device')
+ dtype = arr.dtype
+
+ # Creating device buffer from a CUDA host buffer
+ hbuf = cuda.new_host_buffer(size * arr.dtype.itemsize)
+ np.frombuffer(hbuf, dtype=dtype)[:] = arr
+ cbuf2 = ctx.buffer_from_object(hbuf)
+ assert cbuf2.size == cbuf.size
+ arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
+ np.testing.assert_equal(arr, arr2)
+
+ # Creating device buffer from a device buffer
+ cbuf2 = ctx.buffer_from_object(cbuf2)
+ assert cbuf2.size == cbuf.size
+ arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
+ np.testing.assert_equal(arr, arr2)
+
+ # Trying to create a device buffer from a Buffer
+ with pytest.raises(pa.ArrowTypeError,
+ match=('buffer is not backed by a CudaBuffer')):
+ ctx.buffer_from_object(pa.py_buffer(b"123"))
+
+ # Trying to create a device buffer from numpy.array
+ with pytest.raises(pa.ArrowTypeError,
+ match=("cannot create device buffer view from "
+ ".* \'numpy.ndarray\'")):
+ ctx.buffer_from_object(np.array([1, 2, 3]))
+
+
+def test_foreign_buffer():
+ ctx = global_context
+ dtype = np.dtype(np.uint8)
+ size = 10
+ hbuf = cuda.new_host_buffer(size * dtype.itemsize)
+
+ # test host buffer memory reference counting
+ rc = sys.getrefcount(hbuf)
+ fbuf = ctx.foreign_buffer(hbuf.address, hbuf.size, hbuf)
+ assert sys.getrefcount(hbuf) == rc + 1
+ del fbuf
+ assert sys.getrefcount(hbuf) == rc
+
+ # test postponed deallocation of host buffer memory
+ fbuf = ctx.foreign_buffer(hbuf.address, hbuf.size, hbuf)
+ del hbuf
+ fbuf.copy_to_host()
+
+    # test that deallocating the host buffer memory makes it inaccessible
+ hbuf = cuda.new_host_buffer(size * dtype.itemsize)
+ fbuf = ctx.foreign_buffer(hbuf.address, hbuf.size)
+ del hbuf
+ with pytest.raises(pa.ArrowIOError,
+ match=('Cuda error ')):
+ fbuf.copy_to_host()
+
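The third argument to foreign_buffer is the object owning the memory; passing it keeps that owner alive for as long as the resulting buffer exists, which is what the refcount assertions above check. A condensed sketch of the intended usage, with a host buffer standing in for memory owned by another library:

from pyarrow import cuda

ctx = cuda.Context(0)
owner = cuda.new_host_buffer(16)   # stand-in for externally owned memory
fbuf = ctx.foreign_buffer(owner.address, owner.size, owner)
del owner                          # fine: fbuf keeps the owner alive
fbuf.copy_to_host()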
+
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_CudaBuffer(size):
+ arr, buf = make_random_buffer(size)
+ assert arr.tobytes() == buf.to_pybytes()
+ cbuf = global_context.buffer_from_data(buf)
+ assert cbuf.size == size
+ assert not cbuf.is_cpu
+ assert arr.tobytes() == cbuf.to_pybytes()
+ if size > 0:
+ assert cbuf.address > 0
+
+ for i in range(size):
+ assert cbuf[i] == arr[i]
+
+ for s in [
+ slice(None),
+ slice(size//4, size//2),
+ ]:
+ assert cbuf[s].to_pybytes() == arr[s].tobytes()
+
+ sbuf = cbuf.slice(size//4, size//2)
+ assert sbuf.parent == cbuf
+
+ with pytest.raises(TypeError,
+ match="Do not call CudaBuffer's constructor directly"):
+ cuda.CudaBuffer()
+
+
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_HostBuffer(size):
+ arr, buf = make_random_buffer(size)
+ assert arr.tobytes() == buf.to_pybytes()
+ hbuf = cuda.new_host_buffer(size)
+ np.frombuffer(hbuf, dtype=np.uint8)[:] = arr
+ assert hbuf.size == size
+ assert hbuf.is_cpu
+ assert arr.tobytes() == hbuf.to_pybytes()
+ for i in range(size):
+ assert hbuf[i] == arr[i]
+ for s in [
+ slice(None),
+ slice(size//4, size//2),
+ ]:
+ assert hbuf[s].to_pybytes() == arr[s].tobytes()
+
+ sbuf = hbuf.slice(size//4, size//2)
+ assert sbuf.parent == hbuf
+
+ del hbuf
+
+ with pytest.raises(TypeError,
+ match="Do not call HostBuffer's constructor directly"):
+ cuda.HostBuffer()
+
+
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_copy_from_to_host(size):
+
+    # Create a buffer on the host containing range(size)
+    buf = pa.allocate_buffer(size, resizable=True)  # on the host
+ assert isinstance(buf, pa.Buffer)
+ assert not isinstance(buf, cuda.CudaBuffer)
+ arr = np.frombuffer(buf, dtype=np.uint8)
+ assert arr.size == size
+ arr[:] = range(size)
+ arr_ = np.frombuffer(buf, dtype=np.uint8)
+ np.testing.assert_equal(arr, arr_)
+
+ device_buffer = global_context.new_buffer(size)
+ assert isinstance(device_buffer, cuda.CudaBuffer)
+ assert isinstance(device_buffer, pa.Buffer)
+ assert device_buffer.size == size
+ assert not device_buffer.is_cpu
+
+ device_buffer.copy_from_host(buf, position=0, nbytes=size)
+
+ buf2 = device_buffer.copy_to_host(position=0, nbytes=size)
+ arr2 = np.frombuffer(buf2, dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_copy_to_host(size):
+ arr, dbuf = make_random_buffer(size, target='device')
+
+ buf = dbuf.copy_to_host()
+ assert buf.is_cpu
+ np.testing.assert_equal(arr, np.frombuffer(buf, dtype=np.uint8))
+
+ buf = dbuf.copy_to_host(position=size//4)
+ assert buf.is_cpu
+ np.testing.assert_equal(arr[size//4:], np.frombuffer(buf, dtype=np.uint8))
+
+ buf = dbuf.copy_to_host(position=size//4, nbytes=size//8)
+ assert buf.is_cpu
+ np.testing.assert_equal(arr[size//4:size//4+size//8],
+ np.frombuffer(buf, dtype=np.uint8))
+
+ buf = dbuf.copy_to_host(position=size//4, nbytes=0)
+ assert buf.is_cpu
+ assert buf.size == 0
+
+ for (position, nbytes) in [
+ (size+2, -1), (-2, -1), (size+1, 0), (-3, 0),
+ ]:
+ with pytest.raises(ValueError,
+ match='position argument is out-of-range'):
+ dbuf.copy_to_host(position=position, nbytes=nbytes)
+
+ for (position, nbytes) in [
+ (0, size+1), (size//2, (size+1)//2+1), (size, 1)
+ ]:
+ with pytest.raises(ValueError,
+ match=('requested more to copy than'
+ ' available from device buffer')):
+ dbuf.copy_to_host(position=position, nbytes=nbytes)
+
+ buf = pa.allocate_buffer(size//4)
+ dbuf.copy_to_host(buf=buf)
+ np.testing.assert_equal(arr[:size//4], np.frombuffer(buf, dtype=np.uint8))
+
+ if size < 12:
+ return
+
+ dbuf.copy_to_host(buf=buf, position=12)
+ np.testing.assert_equal(arr[12:12+size//4],
+ np.frombuffer(buf, dtype=np.uint8))
+
+ dbuf.copy_to_host(buf=buf, nbytes=12)
+ np.testing.assert_equal(arr[:12], np.frombuffer(buf, dtype=np.uint8)[:12])
+
+ dbuf.copy_to_host(buf=buf, nbytes=12, position=6)
+ np.testing.assert_equal(arr[6:6+12],
+ np.frombuffer(buf, dtype=np.uint8)[:12])
+
+ for (position, nbytes) in [
+ (0, size+10), (10, size-5),
+ (0, size//2), (size//4, size//4+1)
+ ]:
+ with pytest.raises(ValueError,
+ match=('requested copy does not '
+ 'fit into host buffer')):
+ dbuf.copy_to_host(buf=buf, position=position, nbytes=nbytes)
+
+
+@pytest.mark.parametrize("dest_ctx", ['same', 'another'])
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_copy_from_device(dest_ctx, size):
+ arr, buf = make_random_buffer(size=size, target='device')
+ lst = arr.tolist()
+ if dest_ctx == 'another':
+ dest_ctx = global_context1
+ if buf.context.device_number == dest_ctx.device_number:
+ pytest.skip("not a multi-GPU system")
+ else:
+ dest_ctx = buf.context
+ dbuf = dest_ctx.new_buffer(size)
+
+ def put(*args, **kwargs):
+ dbuf.copy_from_device(buf, *args, **kwargs)
+ rbuf = dbuf.copy_to_host()
+ return np.frombuffer(rbuf, dtype=np.uint8).tolist()
+ assert put() == lst
+ if size > 4:
+ assert put(position=size//4) == lst[:size//4]+lst[:-size//4]
+ assert put() == lst
+ assert put(position=1, nbytes=size//2) == \
+ lst[:1] + lst[:size//2] + lst[-(size-size//2-1):]
+
+ for (position, nbytes) in [
+ (size+2, -1), (-2, -1), (size+1, 0), (-3, 0),
+ ]:
+ with pytest.raises(ValueError,
+ match='position argument is out-of-range'):
+ put(position=position, nbytes=nbytes)
+
+ for (position, nbytes) in [
+ (0, size+1),
+ ]:
+ with pytest.raises(ValueError,
+ match=('requested more to copy than'
+ ' available from device buffer')):
+ put(position=position, nbytes=nbytes)
+
+ if size < 4:
+ return
+
+ for (position, nbytes) in [
+ (size//2, (size+1)//2+1)
+ ]:
+ with pytest.raises(ValueError,
+ match=('requested more to copy than'
+ ' available in device buffer')):
+ put(position=position, nbytes=nbytes)
+
+
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_copy_from_host(size):
+ arr, buf = make_random_buffer(size=size, target='host')
+ lst = arr.tolist()
+ dbuf = global_context.new_buffer(size)
+
+ def put(*args, **kwargs):
+ dbuf.copy_from_host(buf, *args, **kwargs)
+ rbuf = dbuf.copy_to_host()
+ return np.frombuffer(rbuf, dtype=np.uint8).tolist()
+ assert put() == lst
+ if size > 4:
+ assert put(position=size//4) == lst[:size//4]+lst[:-size//4]
+ assert put() == lst
+ assert put(position=1, nbytes=size//2) == \
+ lst[:1] + lst[:size//2] + lst[-(size-size//2-1):]
+
+ for (position, nbytes) in [
+ (size+2, -1), (-2, -1), (size+1, 0), (-3, 0),
+ ]:
+ with pytest.raises(ValueError,
+ match='position argument is out-of-range'):
+ put(position=position, nbytes=nbytes)
+
+ for (position, nbytes) in [
+ (0, size+1),
+ ]:
+ with pytest.raises(ValueError,
+ match=('requested more to copy than'
+ ' available from host buffer')):
+ put(position=position, nbytes=nbytes)
+
+ if size < 4:
+ return
+
+ for (position, nbytes) in [
+ (size//2, (size+1)//2+1)
+ ]:
+ with pytest.raises(ValueError,
+ match=('requested more to copy than'
+ ' available in device buffer')):
+ put(position=position, nbytes=nbytes)
+
+
+def test_BufferWriter():
+ def allocate(size):
+ cbuf = global_context.new_buffer(size)
+ writer = cuda.BufferWriter(cbuf)
+ return cbuf, writer
+
+ def test_writes(total_size, chunksize, buffer_size=0):
+ cbuf, writer = allocate(total_size)
+ arr, buf = make_random_buffer(size=total_size, target='host')
+
+ if buffer_size > 0:
+ writer.buffer_size = buffer_size
+
+ position = writer.tell()
+ assert position == 0
+ writer.write(buf.slice(length=chunksize))
+ assert writer.tell() == chunksize
+ writer.seek(0)
+ position = writer.tell()
+ assert position == 0
+
+ while position < total_size:
+ bytes_to_write = min(chunksize, total_size - position)
+ writer.write(buf.slice(offset=position, length=bytes_to_write))
+ position += bytes_to_write
+
+ writer.flush()
+ assert cbuf.size == total_size
+ cbuf.context.synchronize()
+ buf2 = cbuf.copy_to_host()
+ cbuf.context.synchronize()
+ assert buf2.size == total_size
+ arr2 = np.frombuffer(buf2, dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+ total_size, chunk_size = 1 << 16, 1000
+ test_writes(total_size, chunk_size)
+ test_writes(total_size, chunk_size, total_size // 16)
+
+ cbuf, writer = allocate(100)
+ writer.write(np.arange(100, dtype=np.uint8))
+ writer.writeat(50, np.arange(25, dtype=np.uint8))
+ writer.write(np.arange(25, dtype=np.uint8))
+ writer.flush()
+
+ arr = np.frombuffer(cbuf.copy_to_host(), np.uint8)
+ np.testing.assert_equal(arr[:50], np.arange(50, dtype=np.uint8))
+ np.testing.assert_equal(arr[50:75], np.arange(25, dtype=np.uint8))
+ np.testing.assert_equal(arr[75:], np.arange(25, dtype=np.uint8))
+
+
+def test_BufferWriter_edge_cases():
+ # edge cases, see cuda-test.cc for more information:
+ size = 1000
+ cbuf = global_context.new_buffer(size)
+ writer = cuda.BufferWriter(cbuf)
+ arr, buf = make_random_buffer(size=size, target='host')
+
+ assert writer.buffer_size == 0
+ writer.buffer_size = 100
+ assert writer.buffer_size == 100
+
+ writer.write(buf.slice(length=0))
+ assert writer.tell() == 0
+
+ writer.write(buf.slice(length=10))
+ writer.buffer_size = 200
+ assert writer.buffer_size == 200
+ assert writer.num_bytes_buffered == 0
+
+ writer.write(buf.slice(offset=10, length=300))
+ assert writer.num_bytes_buffered == 0
+
+ writer.write(buf.slice(offset=310, length=200))
+ assert writer.num_bytes_buffered == 0
+
+ writer.write(buf.slice(offset=510, length=390))
+ writer.write(buf.slice(offset=900, length=100))
+
+ writer.flush()
+
+ buf2 = cbuf.copy_to_host()
+ assert buf2.size == size
+ arr2 = np.frombuffer(buf2, dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+
+def test_BufferReader():
+ size = 1000
+ arr, cbuf = make_random_buffer(size=size, target='device')
+
+ reader = cuda.BufferReader(cbuf)
+ reader.seek(950)
+ assert reader.tell() == 950
+
+ data = reader.read(100)
+ assert len(data) == 50
+ assert reader.tell() == 1000
+
+ reader.seek(925)
+ arr2 = np.zeros(100, dtype=np.uint8)
+ n = reader.readinto(arr2)
+ assert n == 75
+ assert reader.tell() == 1000
+ np.testing.assert_equal(arr[925:], arr2[:75])
+
+ reader.seek(0)
+ assert reader.tell() == 0
+ buf2 = reader.read_buffer()
+ arr2 = np.frombuffer(buf2.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+
+def test_BufferReader_zero_size():
+ arr, cbuf = make_random_buffer(size=0, target='device')
+ reader = cuda.BufferReader(cbuf)
+ reader.seek(0)
+ data = reader.read()
+ assert len(data) == 0
+ assert reader.tell() == 0
+ buf2 = reader.read_buffer()
+ arr2 = np.frombuffer(buf2.copy_to_host(), dtype=np.uint8)
+ np.testing.assert_equal(arr, arr2)
+
+
+def make_recordbatch(length):
+ schema = pa.schema([pa.field('f0', pa.int16()),
+ pa.field('f1', pa.int16())])
+ a0 = pa.array(np.random.randint(0, 255, size=length, dtype=np.int16))
+ a1 = pa.array(np.random.randint(0, 255, size=length, dtype=np.int16))
+ batch = pa.record_batch([a0, a1], schema=schema)
+ return batch
+
+
+def test_batch_serialize():
+ batch = make_recordbatch(10)
+ hbuf = batch.serialize()
+ cbuf = cuda.serialize_record_batch(batch, global_context)
+
+ # Test that read_record_batch works properly
+ cbatch = cuda.read_record_batch(cbuf, batch.schema)
+ assert isinstance(cbatch, pa.RecordBatch)
+ assert batch.schema == cbatch.schema
+ assert batch.num_columns == cbatch.num_columns
+ assert batch.num_rows == cbatch.num_rows
+
+ # Deserialize CUDA-serialized batch on host
+ buf = cbuf.copy_to_host()
+ assert hbuf.equals(buf)
+ batch2 = pa.ipc.read_record_batch(buf, batch.schema)
+ assert hbuf.equals(batch2.serialize())
+
+ assert batch.num_columns == batch2.num_columns
+ assert batch.num_rows == batch2.num_rows
+ assert batch.column(0).equals(batch2.column(0))
+ assert batch.equals(batch2)
+
+
+def make_table():
+ a0 = pa.array([0, 1, 42, None], type=pa.int16())
+ a1 = pa.array([[0, 1], [2], [], None], type=pa.list_(pa.int32()))
+ a2 = pa.array([("ab", True), ("cde", False), (None, None), None],
+ type=pa.struct([("strs", pa.utf8()),
+ ("bools", pa.bool_())]))
+ # Dictionaries are validated on the IPC read path, but that can produce
+ # issues for GPU-located dictionaries. Check that they work fine.
+ a3 = pa.DictionaryArray.from_arrays(
+ indices=[0, 1, 1, None],
+ dictionary=pa.array(['foo', 'bar']))
+ a4 = pa.DictionaryArray.from_arrays(
+ indices=[2, 1, 2, None],
+ dictionary=a1)
+ a5 = pa.DictionaryArray.from_arrays(
+ indices=[2, 1, 0, None],
+ dictionary=a2)
+
+ arrays = [a0, a1, a2, a3, a4, a5]
+ schema = pa.schema([('f{}'.format(i), arr.type)
+ for i, arr in enumerate(arrays)])
+ batch = pa.record_batch(arrays, schema=schema)
+ table = pa.Table.from_batches([batch])
+ return table
+
+
+def make_table_cuda():
+ htable = make_table()
+ # Serialize the host table to bytes
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, htable.schema) as out:
+ out.write_table(htable)
+ hbuf = pa.py_buffer(sink.getvalue().to_pybytes())
+
+ # Copy the host bytes to a device buffer
+ dbuf = global_context.new_buffer(len(hbuf))
+ dbuf.copy_from_host(hbuf, nbytes=len(hbuf))
+ # Deserialize the device buffer into a Table
+ dtable = pa.ipc.open_stream(cuda.BufferReader(dbuf)).read_all()
+ return hbuf, htable, dbuf, dtable
+
+
+def test_table_deserialize():
+ # ARROW-9659: make sure that we can deserialize a GPU-located table
+ # without crashing when initializing or validating the underlying arrays.
+ hbuf, htable, dbuf, dtable = make_table_cuda()
+ # Assert basic fields the same between host and device tables
+ assert htable.schema == dtable.schema
+ assert htable.num_rows == dtable.num_rows
+ assert htable.num_columns == dtable.num_columns
+ # Assert byte-level equality
+ assert hbuf.equals(dbuf.copy_to_host())
+ # Copy DtoH and assert the tables are still equivalent
+ assert htable.equals(pa.ipc.open_stream(
+ dbuf.copy_to_host()
+ ).read_all())
+
+
+def test_create_table_with_device_buffers():
+ # ARROW-11872: make sure that we can create an Arrow Table from
+ # GPU-located Arrays without crashing.
+ hbuf, htable, dbuf, dtable = make_table_cuda()
+ # Construct a new Table from the device Table
+ dtable2 = pa.Table.from_arrays(dtable.columns, dtable.column_names)
+ # Assert basic fields the same between host and device tables
+ assert htable.schema == dtable2.schema
+ assert htable.num_rows == dtable2.num_rows
+ assert htable.num_columns == dtable2.num_columns
+ # Assert byte-level equality
+ assert hbuf.equals(dbuf.copy_to_host())
+ # Copy DtoH and assert the tables are still equivalent
+ assert htable.equals(pa.ipc.open_stream(
+ dbuf.copy_to_host()
+ ).read_all())
+
+
+def other_process_for_test_IPC(handle_buffer, expected_arr):
+ other_context = pa.cuda.Context(0)
+ ipc_handle = pa.cuda.IpcMemHandle.from_buffer(handle_buffer)
+ ipc_buf = other_context.open_ipc_buffer(ipc_handle)
+ ipc_buf.context.synchronize()
+ buf = ipc_buf.copy_to_host()
+ assert buf.size == expected_arr.size, repr((buf.size, expected_arr.size))
+ arr = np.frombuffer(buf, dtype=expected_arr.dtype)
+ np.testing.assert_equal(arr, expected_arr)
+
+
+@cuda_ipc
+@pytest.mark.parametrize("size", [0, 1, 1000])
+def test_IPC(size):
+ import multiprocessing
+ ctx = multiprocessing.get_context('spawn')
+ arr, cbuf = make_random_buffer(size=size, target='device')
+ ipc_handle = cbuf.export_for_ipc()
+ handle_buffer = ipc_handle.serialize()
+ p = ctx.Process(target=other_process_for_test_IPC,
+ args=(handle_buffer, arr))
+ p.start()
+ p.join()
+ assert p.exitcode == 0
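The IPC test splits export and import across two processes because a CUDA IPC handle generally cannot be reopened in the process that exported it. A condensed sketch of the two sides (buffer size and device index are illustrative):

from pyarrow import cuda

def export_side(nbytes=1024):
    """Producing process: keep `cbuf` alive while the handle is in use."""
    ctx = cuda.Context(0)
    cbuf = ctx.new_buffer(nbytes)
    handle_bytes = cbuf.export_for_ipc().serialize()
    return cbuf, handle_bytes      # handle_bytes is what gets sent across

def import_side(handle_bytes):
    """Consuming process (must differ from the exporting process)."""
    ctx = cuda.Context(0)
    handle = cuda.IpcMemHandle.from_buffer(handle_bytes)
    return ctx.open_ipc_buffer(handle).copy_to_host()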
diff --git a/src/arrow/python/pyarrow/tests/test_cuda_numba_interop.py b/src/arrow/python/pyarrow/tests/test_cuda_numba_interop.py
new file mode 100644
index 000000000..ff1722d27
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_cuda_numba_interop.py
@@ -0,0 +1,235 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import pyarrow as pa
+import numpy as np
+
+dtypes = ['uint8', 'int16', 'float32']
+cuda = pytest.importorskip("pyarrow.cuda")
+nb_cuda = pytest.importorskip("numba.cuda")
+
+from numba.cuda.cudadrv.devicearray import DeviceNDArray # noqa: E402
+
+
+context_choices = None
+context_choice_ids = ['pyarrow.cuda', 'numba.cuda']
+
+
+def setup_module(module):
+ np.random.seed(1234)
+ ctx1 = cuda.Context()
+ nb_ctx1 = ctx1.to_numba()
+ nb_ctx2 = nb_cuda.current_context()
+ ctx2 = cuda.Context.from_numba(nb_ctx2)
+ module.context_choices = [(ctx1, nb_ctx1), (ctx2, nb_ctx2)]
+
+
+def teardown_module(module):
+ del module.context_choices
+
+
+@pytest.mark.parametrize("c", range(len(context_choice_ids)),
+ ids=context_choice_ids)
+def test_context(c):
+ ctx, nb_ctx = context_choices[c]
+ assert ctx.handle == nb_ctx.handle.value
+ assert ctx.handle == ctx.to_numba().handle.value
+ ctx2 = cuda.Context.from_numba(nb_ctx)
+ assert ctx.handle == ctx2.handle
+ size = 10
+ buf = ctx.new_buffer(size)
+ assert ctx.handle == buf.context.handle
+
+
+def make_random_buffer(size, target='host', dtype='uint8', ctx=None):
+ """Return a host or device buffer with random data.
+ """
+ dtype = np.dtype(dtype)
+ if target == 'host':
+ assert size >= 0
+ buf = pa.allocate_buffer(size*dtype.itemsize)
+ arr = np.frombuffer(buf, dtype=dtype)
+ arr[:] = np.random.randint(low=0, high=255, size=size,
+ dtype=np.uint8)
+ return arr, buf
+ elif target == 'device':
+ arr, buf = make_random_buffer(size, target='host', dtype=dtype)
+ dbuf = ctx.new_buffer(size * dtype.itemsize)
+ dbuf.copy_from_host(buf, position=0, nbytes=buf.size)
+ return arr, dbuf
+ raise ValueError('invalid target value')
+
+
+@pytest.mark.parametrize("c", range(len(context_choice_ids)),
+ ids=context_choice_ids)
+@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
+@pytest.mark.parametrize("size", [0, 1, 8, 1000])
+def test_from_object(c, dtype, size):
+ ctx, nb_ctx = context_choices[c]
+ arr, cbuf = make_random_buffer(size, target='device', dtype=dtype, ctx=ctx)
+
+ # Creating device buffer from numba DeviceNDArray:
+ darr = nb_cuda.to_device(arr)
+ cbuf2 = ctx.buffer_from_object(darr)
+ assert cbuf2.size == cbuf.size
+ arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
+ np.testing.assert_equal(arr, arr2)
+
+ # Creating device buffer from a slice of numba DeviceNDArray:
+ if size >= 8:
+ # 1-D arrays
+ for s in [slice(size//4, None, None),
+ slice(size//4, -(size//4), None)]:
+ cbuf2 = ctx.buffer_from_object(darr[s])
+ arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
+ np.testing.assert_equal(arr[s], arr2)
+
+        # cannot test negative strides due to a numba bug; see numba issue 3705
+ if 0:
+ rdarr = darr[::-1]
+ cbuf2 = ctx.buffer_from_object(rdarr)
+ assert cbuf2.size == cbuf.size
+ arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
+ np.testing.assert_equal(arr, arr2)
+
+ with pytest.raises(ValueError,
+ match=('array data is non-contiguous')):
+ ctx.buffer_from_object(darr[::2])
+
+ # a rectangular 2-D array
+ s1 = size//4
+ s2 = size//s1
+ assert s1 * s2 == size
+ cbuf2 = ctx.buffer_from_object(darr.reshape(s1, s2))
+ assert cbuf2.size == cbuf.size
+ arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
+ np.testing.assert_equal(arr, arr2)
+
+ with pytest.raises(ValueError,
+ match=('array data is non-contiguous')):
+ ctx.buffer_from_object(darr.reshape(s1, s2)[:, ::2])
+
+ # a 3-D array
+ s1 = 4
+ s2 = size//8
+ s3 = size//(s1*s2)
+ assert s1 * s2 * s3 == size
+ cbuf2 = ctx.buffer_from_object(darr.reshape(s1, s2, s3))
+ assert cbuf2.size == cbuf.size
+ arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
+ np.testing.assert_equal(arr, arr2)
+
+ with pytest.raises(ValueError,
+ match=('array data is non-contiguous')):
+ ctx.buffer_from_object(darr.reshape(s1, s2, s3)[::2])
+
+    # Creating a device buffer from an object implementing the CUDA array
+    # interface:
+ class MyObj:
+ def __init__(self, darr):
+ self.darr = darr
+
+ @property
+ def __cuda_array_interface__(self):
+ return self.darr.__cuda_array_interface__
+
+ cbuf2 = ctx.buffer_from_object(MyObj(darr))
+ assert cbuf2.size == cbuf.size
+ arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype)
+ np.testing.assert_equal(arr, arr2)
+
+
+@pytest.mark.parametrize("c", range(len(context_choice_ids)),
+ ids=context_choice_ids)
+@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
+def test_numba_memalloc(c, dtype):
+ ctx, nb_ctx = context_choices[c]
+ dtype = np.dtype(dtype)
+ # Allocate memory using numba context
+    # Warning: this allocation is not tracked by the pyarrow Context
+    # (e.g. its bytes_allocated does not change)
+ size = 10
+ mem = nb_ctx.memalloc(size * dtype.itemsize)
+ darr = DeviceNDArray((size,), (dtype.itemsize,), dtype, gpu_data=mem)
+ darr[:5] = 99
+ darr[5:] = 88
+ np.testing.assert_equal(darr.copy_to_host()[:5], 99)
+ np.testing.assert_equal(darr.copy_to_host()[5:], 88)
+
+ # wrap numba allocated memory with CudaBuffer
+ cbuf = cuda.CudaBuffer.from_numba(mem)
+ arr2 = np.frombuffer(cbuf.copy_to_host(), dtype=dtype)
+ np.testing.assert_equal(arr2, darr.copy_to_host())
+
+
+@pytest.mark.parametrize("c", range(len(context_choice_ids)),
+ ids=context_choice_ids)
+@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
+def test_pyarrow_memalloc(c, dtype):
+ ctx, nb_ctx = context_choices[c]
+ size = 10
+ arr, cbuf = make_random_buffer(size, target='device', dtype=dtype, ctx=ctx)
+
+ # wrap CudaBuffer with numba device array
+ mem = cbuf.to_numba()
+ darr = DeviceNDArray(arr.shape, arr.strides, arr.dtype, gpu_data=mem)
+ np.testing.assert_equal(darr.copy_to_host(), arr)
+
+
+@pytest.mark.parametrize("c", range(len(context_choice_ids)),
+ ids=context_choice_ids)
+@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
+def test_numba_context(c, dtype):
+ ctx, nb_ctx = context_choices[c]
+ size = 10
+ with nb_cuda.gpus[0]:
+ arr, cbuf = make_random_buffer(size, target='device',
+ dtype=dtype, ctx=ctx)
+ assert cbuf.context.handle == nb_ctx.handle.value
+ mem = cbuf.to_numba()
+ darr = DeviceNDArray(arr.shape, arr.strides, arr.dtype, gpu_data=mem)
+ np.testing.assert_equal(darr.copy_to_host(), arr)
+ darr[0] = 99
+ cbuf.context.synchronize()
+ arr2 = np.frombuffer(cbuf.copy_to_host(), dtype=dtype)
+ assert arr2[0] == 99
+
+
+@pytest.mark.parametrize("c", range(len(context_choice_ids)),
+ ids=context_choice_ids)
+@pytest.mark.parametrize("dtype", dtypes, ids=dtypes)
+def test_pyarrow_jit(c, dtype):
+ ctx, nb_ctx = context_choices[c]
+
+ @nb_cuda.jit
+ def increment_by_one(an_array):
+ pos = nb_cuda.grid(1)
+ if pos < an_array.size:
+ an_array[pos] += 1
+
+    # applying a numba.cuda kernel to memory held by a CudaBuffer
+ size = 10
+ arr, cbuf = make_random_buffer(size, target='device', dtype=dtype, ctx=ctx)
+ threadsperblock = 32
+ blockspergrid = (arr.size + (threadsperblock - 1)) // threadsperblock
+ mem = cbuf.to_numba()
+ darr = DeviceNDArray(arr.shape, arr.strides, arr.dtype, gpu_data=mem)
+ increment_by_one[blockspergrid, threadsperblock](darr)
+ cbuf.context.synchronize()
+ arr1 = np.frombuffer(cbuf.copy_to_host(), dtype=arr.dtype)
+ np.testing.assert_equal(arr1, arr + 1)
diff --git a/src/arrow/python/pyarrow/tests/test_cython.py b/src/arrow/python/pyarrow/tests/test_cython.py
new file mode 100644
index 000000000..e202b417a
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_cython.py
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import shutil
+import subprocess
+import sys
+
+import pytest
+
+import pyarrow as pa
+import pyarrow.tests.util as test_util
+
+
+here = os.path.dirname(os.path.abspath(__file__))
+test_ld_path = os.environ.get('PYARROW_TEST_LD_PATH', '')
+if os.name == 'posix':
+ compiler_opts = ['-std=c++11']
+else:
+ compiler_opts = []
+
+
+setup_template = """if 1:
+ from setuptools import setup
+ from Cython.Build import cythonize
+
+ import numpy as np
+
+ import pyarrow as pa
+
+ ext_modules = cythonize({pyx_file!r})
+ compiler_opts = {compiler_opts!r}
+ custom_ld_path = {test_ld_path!r}
+
+ for ext in ext_modules:
+ # XXX required for numpy/numpyconfig.h,
+ # included from arrow/python/api.h
+ ext.include_dirs.append(np.get_include())
+ ext.include_dirs.append(pa.get_include())
+ ext.libraries.extend(pa.get_libraries())
+ ext.library_dirs.extend(pa.get_library_dirs())
+ if custom_ld_path:
+ ext.library_dirs.append(custom_ld_path)
+ ext.extra_compile_args.extend(compiler_opts)
+ print("Extension module:",
+ ext, ext.include_dirs, ext.libraries, ext.library_dirs)
+
+ setup(
+ ext_modules=ext_modules,
+ )
+"""
+
+
+def check_cython_example_module(mod):
+ arr = pa.array([1, 2, 3])
+ assert mod.get_array_length(arr) == 3
+ with pytest.raises(TypeError, match="not an array"):
+ mod.get_array_length(None)
+
+ scal = pa.scalar(123)
+ cast_scal = mod.cast_scalar(scal, pa.utf8())
+ assert cast_scal == pa.scalar("123")
+ with pytest.raises(NotImplementedError,
+ match="casting scalars of type int64 to type list"):
+ mod.cast_scalar(scal, pa.list_(pa.int64()))
+
+
+@pytest.mark.cython
+def test_cython_api(tmpdir):
+ """
+ Basic test for the Cython API.
+ """
+ # Fail early if cython is not found
+ import cython # noqa
+
+ with tmpdir.as_cwd():
+ # Set up temporary workspace
+ pyx_file = 'pyarrow_cython_example.pyx'
+ shutil.copyfile(os.path.join(here, pyx_file),
+ os.path.join(str(tmpdir), pyx_file))
+ # Create setup.py file
+ setup_code = setup_template.format(pyx_file=pyx_file,
+ compiler_opts=compiler_opts,
+ test_ld_path=test_ld_path)
+ with open('setup.py', 'w') as f:
+ f.write(setup_code)
+
+ # ARROW-2263: Make environment with this pyarrow/ package first on the
+ # PYTHONPATH, for local dev environments
+ subprocess_env = test_util.get_modified_env_with_pythonpath()
+
+ # Compile extension module
+ subprocess.check_call([sys.executable, 'setup.py',
+ 'build_ext', '--inplace'],
+ env=subprocess_env)
+
+ # Check basic functionality
+ orig_path = sys.path[:]
+ sys.path.insert(0, str(tmpdir))
+ try:
+ mod = __import__('pyarrow_cython_example')
+ check_cython_example_module(mod)
+ finally:
+ sys.path = orig_path
+
+ # Check the extension module is loadable from a subprocess without
+ # pyarrow imported first.
+ code = """if 1:
+ import sys
+
+ mod = __import__({mod_name!r})
+ arr = mod.make_null_array(5)
+ assert mod.get_array_length(arr) == 5
+ assert arr.null_count == 5
+ """.format(mod_name='pyarrow_cython_example')
+
+ if sys.platform == 'win32':
+ delim, var = ';', 'PATH'
+ else:
+ delim, var = ':', 'LD_LIBRARY_PATH'
+
+ subprocess_env[var] = delim.join(
+ pa.get_library_dirs() + [subprocess_env.get(var, '')]
+ )
+
+ subprocess.check_call([sys.executable, '-c', code],
+ stdout=subprocess.PIPE,
+ env=subprocess_env)
+
+
+@pytest.mark.cython
+def test_visit_strings(tmpdir):
+ with tmpdir.as_cwd():
+ # Set up temporary workspace
+ pyx_file = 'bound_function_visit_strings.pyx'
+ shutil.copyfile(os.path.join(here, pyx_file),
+ os.path.join(str(tmpdir), pyx_file))
+ # Create setup.py file
+ setup_code = setup_template.format(pyx_file=pyx_file,
+ compiler_opts=compiler_opts,
+ test_ld_path=test_ld_path)
+ with open('setup.py', 'w') as f:
+ f.write(setup_code)
+
+ subprocess_env = test_util.get_modified_env_with_pythonpath()
+
+ # Compile extension module
+ subprocess.check_call([sys.executable, 'setup.py',
+ 'build_ext', '--inplace'],
+ env=subprocess_env)
+
+ sys.path.insert(0, str(tmpdir))
+ mod = __import__('bound_function_visit_strings')
+
+ strings = ['a', 'b', 'c']
+ visited = []
+ mod._visit_strings(strings, visited.append)
+
+ assert visited == strings
+
+ with pytest.raises(ValueError, match="wtf"):
+ def raise_on_b(s):
+ if s == 'b':
+ raise ValueError('wtf')
+
+ mod._visit_strings(strings, raise_on_b)
diff --git a/src/arrow/python/pyarrow/tests/test_dataset.py b/src/arrow/python/pyarrow/tests/test_dataset.py
new file mode 100644
index 000000000..20b12316b
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_dataset.py
@@ -0,0 +1,3976 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import contextlib
+import os
+import posixpath
+import pathlib
+import pickle
+import textwrap
+import tempfile
+import threading
+import time
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+import pyarrow.csv
+import pyarrow.feather
+import pyarrow.fs as fs
+from pyarrow.tests.util import (change_cwd, _filesystem_uri,
+ FSProtocolClass, ProxyHandler)
+
+try:
+ import pandas as pd
+except ImportError:
+ pd = None
+
+try:
+ import pyarrow.dataset as ds
+except ImportError:
+ ds = None
+
+# Marks all of the tests in this module
+# Ignore these with pytest ... -m 'not dataset'
+pytestmark = pytest.mark.dataset
+
+
+def _generate_data(n):
+ import datetime
+ import itertools
+
+ day = datetime.datetime(2000, 1, 1)
+ interval = datetime.timedelta(days=5)
+ colors = itertools.cycle(['green', 'blue', 'yellow', 'red', 'orange'])
+
+ data = []
+ for i in range(n):
+ data.append((day, i, float(i), next(colors)))
+ day += interval
+
+ return pd.DataFrame(data, columns=['date', 'index', 'value', 'color'])
+
+
+def _table_from_pandas(df):
+ schema = pa.schema([
+ pa.field('date', pa.date32()),
+ pa.field('index', pa.int64()),
+ pa.field('value', pa.float64()),
+ pa.field('color', pa.string()),
+ ])
+ table = pa.Table.from_pandas(df, schema=schema, preserve_index=False)
+ return table.replace_schema_metadata()
+
+
+@pytest.fixture
+@pytest.mark.parquet
+def mockfs():
+ import pyarrow.parquet as pq
+
+ mockfs = fs._MockFileSystem()
+
+ directories = [
+ 'subdir/1/xxx',
+ 'subdir/2/yyy',
+ ]
+
+ for i, directory in enumerate(directories):
+ path = '{}/file{}.parquet'.format(directory, i)
+ mockfs.create_dir(directory)
+ with mockfs.open_output_stream(path) as out:
+ data = [
+ list(range(5)),
+ list(map(float, range(5))),
+ list(map(str, range(5))),
+ [i] * 5
+ ]
+ schema = pa.schema([
+ pa.field('i64', pa.int64()),
+ pa.field('f64', pa.float64()),
+ pa.field('str', pa.string()),
+ pa.field('const', pa.int64()),
+ ])
+ batch = pa.record_batch(data, schema=schema)
+ table = pa.Table.from_batches([batch])
+
+ pq.write_table(table, out)
+
+ return mockfs
+
+
+@pytest.fixture
+def open_logging_fs(monkeypatch):
+ from pyarrow.fs import PyFileSystem, LocalFileSystem
+ from .test_fs import ProxyHandler
+
+ localfs = LocalFileSystem()
+
+ def normalized(paths):
+ return {localfs.normalize_path(str(p)) for p in paths}
+
+ opened = set()
+
+ def open_input_file(self, path):
+ path = localfs.normalize_path(str(path))
+ opened.add(path)
+ return self._fs.open_input_file(path)
+
+    # patch ProxyHandler to log calls to open_input_file
+ monkeypatch.setattr(ProxyHandler, "open_input_file", open_input_file)
+ fs = PyFileSystem(ProxyHandler(localfs))
+
+ @contextlib.contextmanager
+ def assert_opens(expected_opened):
+ opened.clear()
+ try:
+ yield
+ finally:
+ assert normalized(opened) == normalized(expected_opened)
+
+ return fs, assert_opens
+
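A hypothetical test consuming this fixture might look like the sketch below (the test name and file path are illustrative); assert_opens fails if the set of files opened inside the block differs from the expected set:

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

def test_scan_opens_only_the_data_file(open_logging_fs, tmp_path):
    fs, assert_opens = open_logging_fs
    path = str(tmp_path / "data.parquet")
    pq.write_table(pa.table({"a": [1, 2, 3]}), path)
    with assert_opens([path]):
        ds.dataset(path, format="parquet", filesystem=fs).to_table()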
+
+@pytest.fixture(scope='module')
+def multisourcefs(request):
+ request.config.pyarrow.requires('pandas')
+ request.config.pyarrow.requires('parquet')
+ import pyarrow.parquet as pq
+
+ df = _generate_data(1000)
+ mockfs = fs._MockFileSystem()
+
+    # simply split the dataframe into four chunks and construct a data source
+    # from each chunk, writing each into its own directory
+ df_a, df_b, df_c, df_d = np.array_split(df, 4)
+
+ # create a directory containing a flat sequence of parquet files without
+ # any partitioning involved
+ mockfs.create_dir('plain')
+ for i, chunk in enumerate(np.array_split(df_a, 10)):
+ path = 'plain/chunk-{}.parquet'.format(i)
+ with mockfs.open_output_stream(path) as out:
+ pq.write_table(_table_from_pandas(chunk), out)
+
+ # create one with schema partitioning by weekday and color
+ mockfs.create_dir('schema')
+ for part, chunk in df_b.groupby([df_b.date.dt.dayofweek, df_b.color]):
+ folder = 'schema/{}/{}'.format(*part)
+ path = '{}/chunk.parquet'.format(folder)
+ mockfs.create_dir(folder)
+ with mockfs.open_output_stream(path) as out:
+ pq.write_table(_table_from_pandas(chunk), out)
+
+ # create one with hive partitioning by year and month
+ mockfs.create_dir('hive')
+ for part, chunk in df_c.groupby([df_c.date.dt.year, df_c.date.dt.month]):
+ folder = 'hive/year={}/month={}'.format(*part)
+ path = '{}/chunk.parquet'.format(folder)
+ mockfs.create_dir(folder)
+ with mockfs.open_output_stream(path) as out:
+ pq.write_table(_table_from_pandas(chunk), out)
+
+ # create one with hive partitioning by color
+ mockfs.create_dir('hive_color')
+ for part, chunk in df_d.groupby(["color"]):
+ folder = 'hive_color/color={}'.format(*part)
+ path = '{}/chunk.parquet'.format(folder)
+ mockfs.create_dir(folder)
+ with mockfs.open_output_stream(path) as out:
+ pq.write_table(_table_from_pandas(chunk), out)
+
+ return mockfs
+
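The hive-style paths written above ('hive/year=.../month=...') use the layout that dataset discovery can map back to columns; a sketch against this fixture, with the partitioning flavor passed explicitly, might be:

import pyarrow.dataset as ds

def test_hive_partition_fields_recovered(multisourcefs):
    dataset = ds.dataset('hive', filesystem=multisourcefs,
                         format='parquet', partitioning='hive')
    # 'year' and 'month' are inferred from the year=.../month=... directories
    assert {'year', 'month'} <= set(dataset.schema.names)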
+
+@pytest.fixture
+def dataset(mockfs):
+ format = ds.ParquetFileFormat()
+ selector = fs.FileSelector('subdir', recursive=True)
+ options = ds.FileSystemFactoryOptions('subdir')
+ options.partitioning = ds.DirectoryPartitioning(
+ pa.schema([
+ pa.field('group', pa.int32()),
+ pa.field('key', pa.string())
+ ])
+ )
+ factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
+ return factory.finish()
+
+
+@pytest.fixture(params=[
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False)
+], ids=['threaded-async', 'threaded-sync', 'serial-async', 'serial-sync'])
+def dataset_reader(request):
+ '''
+ Fixture which allows dataset scanning operations to be
+ run with/without threads and with/without async
+ '''
+ use_threads, use_async = request.param
+
+ class reader:
+
+ def __init__(self):
+ self.use_threads = use_threads
+ self.use_async = use_async
+
+ def _patch_kwargs(self, kwargs):
+ if 'use_threads' in kwargs:
+ raise Exception(
+ ('Invalid use of dataset_reader, do not specify'
+ ' use_threads'))
+ if 'use_async' in kwargs:
+ raise Exception(
+ 'Invalid use of dataset_reader, do not specify use_async')
+ kwargs['use_threads'] = use_threads
+ kwargs['use_async'] = use_async
+
+ def to_table(self, dataset, **kwargs):
+ self._patch_kwargs(kwargs)
+ return dataset.to_table(**kwargs)
+
+ def to_batches(self, dataset, **kwargs):
+ self._patch_kwargs(kwargs)
+ return dataset.to_batches(**kwargs)
+
+ def scanner(self, dataset, **kwargs):
+ self._patch_kwargs(kwargs)
+ return dataset.scanner(**kwargs)
+
+ def head(self, dataset, num_rows, **kwargs):
+ self._patch_kwargs(kwargs)
+ return dataset.head(num_rows, **kwargs)
+
+ def take(self, dataset, indices, **kwargs):
+ self._patch_kwargs(kwargs)
+ return dataset.take(indices, **kwargs)
+
+ def count_rows(self, dataset, **kwargs):
+ self._patch_kwargs(kwargs)
+ return dataset.count_rows(**kwargs)
+
+ return reader()
+
+
+def test_filesystem_dataset(mockfs):
+ schema = pa.schema([
+ pa.field('const', pa.int64())
+ ])
+ file_format = ds.ParquetFileFormat()
+ paths = ['subdir/1/xxx/file0.parquet', 'subdir/2/yyy/file1.parquet']
+ partitions = [ds.field('part') == x for x in range(1, 3)]
+ fragments = [file_format.make_fragment(path, mockfs, part)
+ for path, part in zip(paths, partitions)]
+ root_partition = ds.field('level') == ds.scalar(1337)
+
+ dataset_from_fragments = ds.FileSystemDataset(
+ fragments, schema=schema, format=file_format,
+ filesystem=mockfs, root_partition=root_partition,
+ )
+ dataset_from_paths = ds.FileSystemDataset.from_paths(
+ paths, schema=schema, format=file_format, filesystem=mockfs,
+ partitions=partitions, root_partition=root_partition,
+ )
+
+ for dataset in [dataset_from_fragments, dataset_from_paths]:
+ assert isinstance(dataset, ds.FileSystemDataset)
+ assert isinstance(dataset.format, ds.ParquetFileFormat)
+ assert dataset.partition_expression.equals(root_partition)
+ assert set(dataset.files) == set(paths)
+
+ fragments = list(dataset.get_fragments())
+ for fragment, partition, path in zip(fragments, partitions, paths):
+ assert fragment.partition_expression.equals(partition)
+ assert fragment.path == path
+ assert isinstance(fragment.format, ds.ParquetFileFormat)
+ assert isinstance(fragment, ds.ParquetFileFragment)
+ assert fragment.row_groups == [0]
+ assert fragment.num_row_groups == 1
+
+ row_group_fragments = list(fragment.split_by_row_group())
+ assert fragment.num_row_groups == len(row_group_fragments) == 1
+ assert isinstance(row_group_fragments[0], ds.ParquetFileFragment)
+ assert row_group_fragments[0].path == path
+ assert row_group_fragments[0].row_groups == [0]
+ assert row_group_fragments[0].num_row_groups == 1
+
+ fragments = list(dataset.get_fragments(filter=ds.field("const") == 0))
+ assert len(fragments) == 2
+
+ # the root_partition keyword has a default
+ dataset = ds.FileSystemDataset(
+ fragments, schema=schema, format=file_format, filesystem=mockfs
+ )
+ assert dataset.partition_expression.equals(ds.scalar(True))
+
+ # from_paths partitions have defaults
+ dataset = ds.FileSystemDataset.from_paths(
+ paths, schema=schema, format=file_format, filesystem=mockfs
+ )
+ assert dataset.partition_expression.equals(ds.scalar(True))
+ for fragment in dataset.get_fragments():
+ assert fragment.partition_expression.equals(ds.scalar(True))
+
+ # validation of required arguments
+ with pytest.raises(TypeError, match="incorrect type"):
+ ds.FileSystemDataset(fragments, file_format, schema)
+ # validation of root_partition
+ with pytest.raises(TypeError, match="incorrect type"):
+ ds.FileSystemDataset(fragments, schema=schema,
+ format=file_format, root_partition=1)
+ # missing required argument in from_paths
+ with pytest.raises(TypeError, match="incorrect type"):
+ ds.FileSystemDataset.from_paths(fragments, format=file_format)
+
+
+def test_filesystem_dataset_no_filesystem_interaction(dataset_reader):
+ # ARROW-8283
+ schema = pa.schema([
+ pa.field('f1', pa.int64())
+ ])
+ file_format = ds.IpcFileFormat()
+ paths = ['nonexistingfile.arrow']
+
+ # creating the dataset itself doesn't raise
+ dataset = ds.FileSystemDataset.from_paths(
+ paths, schema=schema, format=file_format,
+ filesystem=fs.LocalFileSystem(),
+ )
+
+ # getting fragments also doesn't raise
+ dataset.get_fragments()
+
+ # scanning does raise
+ with pytest.raises(FileNotFoundError):
+ dataset_reader.to_table(dataset)
+
+
+def test_dataset(dataset, dataset_reader):
+ assert isinstance(dataset, ds.Dataset)
+ assert isinstance(dataset.schema, pa.Schema)
+
+    # TODO(kszucs): test that non-boolean Exprs passed as filter do raise
+ expected_i64 = pa.array([0, 1, 2, 3, 4], type=pa.int64())
+ expected_f64 = pa.array([0, 1, 2, 3, 4], type=pa.float64())
+
+ for batch in dataset_reader.to_batches(dataset):
+ assert isinstance(batch, pa.RecordBatch)
+ assert batch.column(0).equals(expected_i64)
+ assert batch.column(1).equals(expected_f64)
+
+ for batch in dataset_reader.scanner(dataset).scan_batches():
+ assert isinstance(batch, ds.TaggedRecordBatch)
+ assert isinstance(batch.fragment, ds.Fragment)
+
+ table = dataset_reader.to_table(dataset)
+ assert isinstance(table, pa.Table)
+ assert len(table) == 10
+
+ condition = ds.field('i64') == 1
+ result = dataset.to_table(use_threads=True, filter=condition).to_pydict()
+
+ # don't rely on the scanning order
+ assert result['i64'] == [1, 1]
+ assert result['f64'] == [1., 1.]
+ assert sorted(result['group']) == [1, 2]
+ assert sorted(result['key']) == ['xxx', 'yyy']
+
+
+def test_scanner(dataset, dataset_reader):
+ scanner = dataset_reader.scanner(
+ dataset, memory_pool=pa.default_memory_pool())
+ assert isinstance(scanner, ds.Scanner)
+
+ with pytest.raises(pa.ArrowInvalid):
+ dataset_reader.scanner(dataset, columns=['unknown'])
+
+ scanner = dataset_reader.scanner(dataset, columns=['i64'],
+ memory_pool=pa.default_memory_pool())
+ assert scanner.dataset_schema == dataset.schema
+ assert scanner.projected_schema == pa.schema([("i64", pa.int64())])
+
+ assert isinstance(scanner, ds.Scanner)
+ table = scanner.to_table()
+ for batch in scanner.to_batches():
+ assert batch.schema == scanner.projected_schema
+ assert batch.num_columns == 1
+ assert table == scanner.to_reader().read_all()
+
+ assert table.schema == scanner.projected_schema
+ for i in range(table.num_rows):
+ indices = pa.array([i])
+ assert table.take(indices) == scanner.take(indices)
+ with pytest.raises(pa.ArrowIndexError):
+ scanner.take(pa.array([table.num_rows]))
+
+ assert table.num_rows == scanner.count_rows()
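+
+
+# A minimal, self-contained sketch of the scanner flow asserted above:
+# project a column from an in-memory table, materialize it, and count rows.
+# The `_example_*` helpers in this file are illustrative sketches that only
+# use APIs exercised by the surrounding tests; they are not fixtures.
+def _example_scanner_basics():
+    import pyarrow as pa
+    import pyarrow.dataset as ds
+
+    table = pa.table({'i64': [1, 2, 3], 'f64': [1.0, 2.0, 3.0]})
+    dataset = ds.dataset(table)
+
+    # project a single column; the projected schema only contains 'i64'
+    scanner = dataset.scanner(columns=['i64'])
+    assert scanner.projected_schema == pa.schema([('i64', pa.int64())])
+    assert scanner.to_table().num_rows == 3
+    assert scanner.count_rows() == 3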
+
+
+def test_head(dataset, dataset_reader):
+ result = dataset_reader.head(dataset, 0)
+ assert result == pa.Table.from_batches([], schema=dataset.schema)
+
+ result = dataset_reader.head(dataset, 1, columns=['i64']).to_pydict()
+ assert result == {'i64': [0]}
+
+ result = dataset_reader.head(dataset, 2, columns=['i64'],
+ filter=ds.field('i64') > 1).to_pydict()
+ assert result == {'i64': [2, 3]}
+
+ result = dataset_reader.head(dataset, 1024, columns=['i64']).to_pydict()
+ assert result == {'i64': list(range(5)) * 2}
+
+ fragment = next(dataset.get_fragments())
+ result = fragment.head(1, columns=['i64']).to_pydict()
+ assert result == {'i64': [0]}
+
+ result = fragment.head(1024, columns=['i64']).to_pydict()
+ assert result == {'i64': list(range(5))}
+
+
+def test_take(dataset, dataset_reader):
+ fragment = next(dataset.get_fragments())
+ indices = pa.array([1, 3])
+ assert dataset_reader.take(
+ fragment, indices) == dataset_reader.to_table(fragment).take(indices)
+ with pytest.raises(IndexError):
+ dataset_reader.take(fragment, pa.array([5]))
+
+ indices = pa.array([1, 7])
+ assert dataset_reader.take(
+ dataset, indices) == dataset_reader.to_table(dataset).take(indices)
+ with pytest.raises(IndexError):
+ dataset_reader.take(dataset, pa.array([10]))
+
+
+def test_count_rows(dataset, dataset_reader):
+ fragment = next(dataset.get_fragments())
+ assert dataset_reader.count_rows(fragment) == 5
+ assert dataset_reader.count_rows(
+ fragment, filter=ds.field("i64") == 4) == 1
+
+ assert dataset_reader.count_rows(dataset) == 10
+ # Filter on partition key
+ assert dataset_reader.count_rows(
+ dataset, filter=ds.field("group") == 1) == 5
+ # Filter on data
+ assert dataset_reader.count_rows(dataset, filter=ds.field("i64") >= 3) == 4
+ assert dataset_reader.count_rows(dataset, filter=ds.field("i64") < 0) == 0
+
+
+def test_abstract_classes():
+ classes = [
+ ds.FileFormat,
+ ds.Scanner,
+ ds.Partitioning,
+ ]
+ for klass in classes:
+ with pytest.raises(TypeError):
+ klass()
+
+
+def test_partitioning():
+ schema = pa.schema([
+ pa.field('i64', pa.int64()),
+ pa.field('f64', pa.float64())
+ ])
+ for klass in [ds.DirectoryPartitioning, ds.HivePartitioning]:
+ partitioning = klass(schema)
+ assert isinstance(partitioning, ds.Partitioning)
+
+ partitioning = ds.DirectoryPartitioning(
+ pa.schema([
+ pa.field('group', pa.int64()),
+ pa.field('key', pa.float64())
+ ])
+ )
+ assert partitioning.dictionaries is None
+ expr = partitioning.parse('/3/3.14')
+ assert isinstance(expr, ds.Expression)
+
+ expected = (ds.field('group') == 3) & (ds.field('key') == 3.14)
+ assert expr.equals(expected)
+
+ with pytest.raises(pa.ArrowInvalid):
+ partitioning.parse('/prefix/3/aaa')
+
+ expr = partitioning.parse('/3')
+ expected = ds.field('group') == 3
+ assert expr.equals(expected)
+
+ partitioning = ds.HivePartitioning(
+ pa.schema([
+ pa.field('alpha', pa.int64()),
+ pa.field('beta', pa.int64())
+ ]),
+ null_fallback='xyz'
+ )
+ assert partitioning.dictionaries is None
+ expr = partitioning.parse('/alpha=0/beta=3')
+ expected = (
+ (ds.field('alpha') == ds.scalar(0)) &
+ (ds.field('beta') == ds.scalar(3))
+ )
+ assert expr.equals(expected)
+
+ expr = partitioning.parse('/alpha=xyz/beta=3')
+ expected = (
+ (ds.field('alpha').is_null() & (ds.field('beta') == ds.scalar(3)))
+ )
+ assert expr.equals(expected)
+
+ for shouldfail in ['/alpha=one/beta=2', '/alpha=one', '/beta=two']:
+ with pytest.raises(pa.ArrowInvalid):
+ partitioning.parse(shouldfail)
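+
+
+# Sketch of how a partitioning object maps a path segment to a filter
+# expression, mirroring the parse() assertions above (illustrative only).
+def _example_partitioning_parse():
+    import pyarrow as pa
+    import pyarrow.dataset as ds
+
+    part = ds.HivePartitioning(pa.schema([pa.field('year', pa.int64())]))
+    expr = part.parse('/year=2021')
+    # the parsed expression is an ordinary dataset filter expression
+    assert expr.equals(ds.field('year') == 2021)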
+
+
+def test_expression_serialization():
+ a = ds.scalar(1)
+ b = ds.scalar(1.1)
+ c = ds.scalar(True)
+ d = ds.scalar("string")
+ e = ds.scalar(None)
+ f = ds.scalar({'a': 1})
+ g = ds.scalar(pa.scalar(1))
+
+ all_exprs = [a, b, c, d, e, f, g, a == b, a > b, a & b, a | b, ~c,
+ d.is_valid(), a.cast(pa.int32(), safe=False),
+ a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
+ ds.field('i64') > 5, ds.field('i64') == 5,
+ ds.field('i64') == 7, ds.field('i64').is_null()]
+ for expr in all_exprs:
+ assert isinstance(expr, ds.Expression)
+ restored = pickle.loads(pickle.dumps(expr))
+ assert expr.equals(restored)
+
+
+def test_expression_construction():
+ zero = ds.scalar(0)
+ one = ds.scalar(1)
+ true = ds.scalar(True)
+ false = ds.scalar(False)
+ string = ds.scalar("string")
+ field = ds.field("field")
+
+ zero | one == string
+ ~true == false
+ for typ in ("bool", pa.bool_()):
+ field.cast(typ) == true
+
+ field.isin([1, 2])
+
+ with pytest.raises(TypeError):
+ field.isin(1)
+
+ with pytest.raises(pa.ArrowInvalid):
+ field != object()
+
+
+def test_expression_boolean_operators():
+ # https://issues.apache.org/jira/browse/ARROW-11412
+ true = ds.scalar(True)
+ false = ds.scalar(False)
+
+ with pytest.raises(ValueError, match="cannot be evaluated to python True"):
+ true and false
+
+ with pytest.raises(ValueError, match="cannot be evaluated to python True"):
+ true or false
+
+ with pytest.raises(ValueError, match="cannot be evaluated to python True"):
+ bool(true)
+
+ with pytest.raises(ValueError, match="cannot be evaluated to python True"):
+ not true
+
+
+def test_expression_arithmetic_operators():
+ dataset = ds.dataset(pa.table({'a': [1, 2, 3], 'b': [2, 2, 2]}))
+ a = ds.field("a")
+ b = ds.field("b")
+ result = dataset.to_table(columns={
+ "a+1": a + 1,
+ "b-a": b - a,
+ "a*2": a * 2,
+ "a/b": a.cast("float64") / b,
+ })
+ expected = pa.table({
+ "a+1": [2, 3, 4], "b-a": [1, 0, -1],
+ "a*2": [2, 4, 6], "a/b": [0.5, 1.0, 1.5],
+ })
+ assert result.equals(expected)
+
+
+def test_partition_keys():
+ a, b, c = [ds.field(f) == f for f in 'abc']
+ assert ds._get_partition_keys(a) == {'a': 'a'}
+ assert ds._get_partition_keys(a & b & c) == {f: f for f in 'abc'}
+
+ nope = ds.field('d') >= 3
+ assert ds._get_partition_keys(nope) == {}
+ assert ds._get_partition_keys(a & nope) == {'a': 'a'}
+
+ null = ds.field('a').is_null()
+ assert ds._get_partition_keys(null) == {'a': None}
+
+
+def test_parquet_read_options():
+ opts1 = ds.ParquetReadOptions()
+ opts2 = ds.ParquetReadOptions(dictionary_columns=['a', 'b'])
+ opts3 = ds.ParquetReadOptions(coerce_int96_timestamp_unit="ms")
+
+ assert opts1.dictionary_columns == set()
+
+ assert opts2.dictionary_columns == {'a', 'b'}
+
+ assert opts1.coerce_int96_timestamp_unit == "ns"
+ assert opts3.coerce_int96_timestamp_unit == "ms"
+
+ assert opts1 == opts1
+ assert opts1 != opts2
+ assert opts1 != opts3
+
+
+def test_parquet_file_format_read_options():
+ pff1 = ds.ParquetFileFormat()
+ pff2 = ds.ParquetFileFormat(dictionary_columns={'a'})
+ pff3 = ds.ParquetFileFormat(coerce_int96_timestamp_unit="s")
+
+ assert pff1.read_options == ds.ParquetReadOptions()
+ assert pff2.read_options == ds.ParquetReadOptions(dictionary_columns=['a'])
+ assert pff3.read_options == ds.ParquetReadOptions(
+ coerce_int96_timestamp_unit="s")
+
+
+def test_parquet_scan_options():
+ opts1 = ds.ParquetFragmentScanOptions()
+ opts2 = ds.ParquetFragmentScanOptions(buffer_size=4096)
+ opts3 = ds.ParquetFragmentScanOptions(
+ buffer_size=2**13, use_buffered_stream=True)
+ opts4 = ds.ParquetFragmentScanOptions(buffer_size=2**13, pre_buffer=True)
+
+ assert opts1.use_buffered_stream is False
+ assert opts1.buffer_size == 2**13
+ assert opts1.pre_buffer is False
+
+ assert opts2.use_buffered_stream is False
+ assert opts2.buffer_size == 2**12
+ assert opts2.pre_buffer is False
+
+ assert opts3.use_buffered_stream is True
+ assert opts3.buffer_size == 2**13
+ assert opts3.pre_buffer is False
+
+ assert opts4.use_buffered_stream is False
+ assert opts4.buffer_size == 2**13
+ assert opts4.pre_buffer is True
+
+ assert opts1 == opts1
+ assert opts1 != opts2
+ assert opts2 != opts3
+ assert opts3 != opts4
+
+
+def test_file_format_pickling():
+ formats = [
+ ds.IpcFileFormat(),
+ ds.CsvFileFormat(),
+ ds.CsvFileFormat(pa.csv.ParseOptions(delimiter='\t',
+ ignore_empty_lines=True)),
+ ds.CsvFileFormat(read_options=pa.csv.ReadOptions(
+ skip_rows=3, column_names=['foo'])),
+ ds.CsvFileFormat(read_options=pa.csv.ReadOptions(
+ skip_rows=3, block_size=2**20)),
+ ds.ParquetFileFormat(),
+ ds.ParquetFileFormat(dictionary_columns={'a'}),
+ ds.ParquetFileFormat(use_buffered_stream=True),
+ ds.ParquetFileFormat(
+ use_buffered_stream=True,
+ buffer_size=4096,
+ )
+ ]
+ try:
+ formats.append(ds.OrcFileFormat())
+ except (ImportError, AttributeError):
+ # catch AttributeError for Python 3.6
+ pass
+
+ for file_format in formats:
+ assert pickle.loads(pickle.dumps(file_format)) == file_format
+
+
+def test_fragment_scan_options_pickling():
+ options = [
+ ds.CsvFragmentScanOptions(),
+ ds.CsvFragmentScanOptions(
+ convert_options=pa.csv.ConvertOptions(strings_can_be_null=True)),
+ ds.CsvFragmentScanOptions(
+ read_options=pa.csv.ReadOptions(block_size=2**16)),
+ ds.ParquetFragmentScanOptions(buffer_size=4096),
+ ds.ParquetFragmentScanOptions(pre_buffer=True),
+ ]
+ for option in options:
+ assert pickle.loads(pickle.dumps(option)) == option
+
+
+@pytest.mark.parametrize('paths_or_selector', [
+ fs.FileSelector('subdir', recursive=True),
+ [
+ 'subdir/1/xxx/file0.parquet',
+ 'subdir/2/yyy/file1.parquet',
+ ]
+])
+@pytest.mark.parametrize('pre_buffer', [False, True])
+def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
+ format = ds.ParquetFileFormat(
+ read_options=ds.ParquetReadOptions(dictionary_columns={"str"}),
+ pre_buffer=pre_buffer
+ )
+
+ options = ds.FileSystemFactoryOptions('subdir')
+ options.partitioning = ds.DirectoryPartitioning(
+ pa.schema([
+ pa.field('group', pa.int32()),
+ pa.field('key', pa.string())
+ ])
+ )
+ assert options.partition_base_dir == 'subdir'
+ assert options.selector_ignore_prefixes == ['.', '_']
+ assert options.exclude_invalid_files is False
+
+ factory = ds.FileSystemDatasetFactory(
+ mockfs, paths_or_selector, format, options
+ )
+ inspected_schema = factory.inspect()
+
+ assert factory.inspect().equals(pa.schema([
+ pa.field('i64', pa.int64()),
+ pa.field('f64', pa.float64()),
+ pa.field('str', pa.dictionary(pa.int32(), pa.string())),
+ pa.field('const', pa.int64()),
+ pa.field('group', pa.int32()),
+ pa.field('key', pa.string()),
+ ]), check_metadata=False)
+
+ assert isinstance(factory.inspect_schemas(), list)
+ assert isinstance(factory.finish(inspected_schema),
+ ds.FileSystemDataset)
+ assert factory.root_partition.equals(ds.scalar(True))
+
+ dataset = factory.finish()
+ assert isinstance(dataset, ds.FileSystemDataset)
+
+ scanner = dataset.scanner()
+ expected_i64 = pa.array([0, 1, 2, 3, 4], type=pa.int64())
+ expected_f64 = pa.array([0, 1, 2, 3, 4], type=pa.float64())
+ expected_str = pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, 2, 3, 4], type=pa.int32()),
+ pa.array("0 1 2 3 4".split(), type=pa.string())
+ )
+ iterator = scanner.scan_batches()
+ for (batch, fragment), group, key in zip(iterator, [1, 2], ['xxx', 'yyy']):
+ expected_group = pa.array([group] * 5, type=pa.int32())
+ expected_key = pa.array([key] * 5, type=pa.string())
+ expected_const = pa.array([group - 1] * 5, type=pa.int64())
+ # Can't compare or really introspect expressions from Python
+ assert fragment.partition_expression is not None
+ assert batch.num_columns == 6
+ assert batch[0].equals(expected_i64)
+ assert batch[1].equals(expected_f64)
+ assert batch[2].equals(expected_str)
+ assert batch[3].equals(expected_const)
+ assert batch[4].equals(expected_group)
+ assert batch[5].equals(expected_key)
+
+ table = dataset.to_table()
+ assert isinstance(table, pa.Table)
+ assert len(table) == 10
+ assert table.num_columns == 6
+
+
+def test_make_fragment(multisourcefs):
+ parquet_format = ds.ParquetFileFormat()
+ dataset = ds.dataset('/plain', filesystem=multisourcefs,
+ format=parquet_format)
+
+ for path in dataset.files:
+ fragment = parquet_format.make_fragment(path, multisourcefs)
+ assert fragment.row_groups == [0]
+
+ row_group_fragment = parquet_format.make_fragment(path, multisourcefs,
+ row_groups=[0])
+ for f in [fragment, row_group_fragment]:
+ assert isinstance(f, ds.ParquetFileFragment)
+ assert f.path == path
+ assert isinstance(f.filesystem, type(multisourcefs))
+ assert row_group_fragment.row_groups == [0]
+
+
+def test_make_csv_fragment_from_buffer(dataset_reader):
+ content = textwrap.dedent("""
+ alpha,num,animal
+ a,12,dog
+ b,11,cat
+ c,10,rabbit
+ """)
+ buffer = pa.py_buffer(content.encode('utf-8'))
+
+ csv_format = ds.CsvFileFormat()
+ fragment = csv_format.make_fragment(buffer)
+
+ expected = pa.table([['a', 'b', 'c'],
+ [12, 11, 10],
+ ['dog', 'cat', 'rabbit']],
+ names=['alpha', 'num', 'animal'])
+ assert dataset_reader.to_table(fragment).equals(expected)
+
+ pickled = pickle.loads(pickle.dumps(fragment))
+ assert dataset_reader.to_table(pickled).equals(fragment.to_table())
+
+
+@pytest.mark.parquet
+def test_make_parquet_fragment_from_buffer(dataset_reader):
+ import pyarrow.parquet as pq
+
+ arrays = [
+ pa.array(['a', 'b', 'c']),
+ pa.array([12, 11, 10]),
+ pa.array(['dog', 'cat', 'rabbit'])
+ ]
+ dictionary_arrays = [
+ arrays[0].dictionary_encode(),
+ arrays[1],
+ arrays[2].dictionary_encode()
+ ]
+ dictionary_format = ds.ParquetFileFormat(
+ read_options=ds.ParquetReadOptions(
+ dictionary_columns=['alpha', 'animal']
+ ),
+ use_buffered_stream=True,
+ buffer_size=4096,
+ )
+
+ cases = [
+ (arrays, ds.ParquetFileFormat()),
+ (dictionary_arrays, dictionary_format)
+ ]
+ for arrays, format_ in cases:
+ table = pa.table(arrays, names=['alpha', 'num', 'animal'])
+
+ out = pa.BufferOutputStream()
+ pq.write_table(table, out)
+ buffer = out.getvalue()
+
+ fragment = format_.make_fragment(buffer)
+ assert dataset_reader.to_table(fragment).equals(table)
+
+ pickled = pickle.loads(pickle.dumps(fragment))
+ assert dataset_reader.to_table(pickled).equals(table)
+
+
+def _create_dataset_for_fragments(tempdir, chunk_size=None, filesystem=None):
+ import pyarrow.parquet as pq
+
+ table = pa.table(
+ [range(8), [1] * 8, ['a'] * 4 + ['b'] * 4],
+ names=['f1', 'f2', 'part']
+ )
+
+ path = str(tempdir / "test_parquet_dataset")
+
+ # write_to_dataset currently requires pandas
+ pq.write_to_dataset(table, path,
+ partition_cols=["part"], chunk_size=chunk_size)
+ dataset = ds.dataset(
+ path, format="parquet", partitioning="hive", filesystem=filesystem
+ )
+
+ return table, dataset
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments(tempdir, dataset_reader):
+ table, dataset = _create_dataset_for_fragments(tempdir)
+
+ # list fragments
+ fragments = list(dataset.get_fragments())
+ assert len(fragments) == 2
+ f = fragments[0]
+
+ physical_names = ['f1', 'f2']
+ # file's schema does not include partition column
+ assert f.physical_schema.names == physical_names
+ assert f.format.inspect(f.path, f.filesystem) == f.physical_schema
+ assert f.partition_expression.equals(ds.field('part') == 'a')
+
+ # By default, the partition column is not part of the schema.
+ result = dataset_reader.to_table(f)
+ assert result.column_names == physical_names
+ assert result.equals(table.remove_column(2).slice(0, 4))
+
+    # scanning a fragment includes the partition columns when the dataset
+    # schema (which contains them) is passed explicitly
+ result = dataset_reader.to_table(f, schema=dataset.schema)
+ assert result.column_names == ['f1', 'f2', 'part']
+ assert result.equals(table.slice(0, 4))
+ assert f.physical_schema == result.schema.remove(2)
+
+    # scanning a fragment honors the filter predicate
+ result = dataset_reader.to_table(
+ f, schema=dataset.schema, filter=ds.field('f1') < 2)
+ assert result.column_names == ['f1', 'f2', 'part']
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_implicit_cast(tempdir):
+ # ARROW-8693
+ import pyarrow.parquet as pq
+
+ table = pa.table([range(8), [1] * 4 + [2] * 4], names=['col', 'part'])
+ path = str(tempdir / "test_parquet_dataset")
+ pq.write_to_dataset(table, path, partition_cols=["part"])
+
+ part = ds.partitioning(pa.schema([('part', 'int8')]), flavor="hive")
+ dataset = ds.dataset(path, format="parquet", partitioning=part)
+ fragments = dataset.get_fragments(filter=ds.field("part") >= 2)
+ assert len(list(fragments)) == 1
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_reconstruct(tempdir, dataset_reader):
+ table, dataset = _create_dataset_for_fragments(tempdir)
+
+ def assert_yields_projected(fragment, row_slice,
+ columns=None, filter=None):
+ actual = fragment.to_table(
+ schema=table.schema, columns=columns, filter=filter)
+ column_names = columns if columns else table.column_names
+ assert actual.column_names == column_names
+
+ expected = table.slice(*row_slice).select(column_names)
+ assert actual.equals(expected)
+
+ fragment = list(dataset.get_fragments())[0]
+ parquet_format = fragment.format
+
+ # test pickle roundtrip
+ pickled_fragment = pickle.loads(pickle.dumps(fragment))
+ assert dataset_reader.to_table(
+ pickled_fragment) == dataset_reader.to_table(fragment)
+
+ # manually re-construct a fragment, with explicit schema
+ new_fragment = parquet_format.make_fragment(
+ fragment.path, fragment.filesystem,
+ partition_expression=fragment.partition_expression)
+ assert dataset_reader.to_table(new_fragment).equals(
+ dataset_reader.to_table(fragment))
+ assert_yields_projected(new_fragment, (0, 4))
+
+ # filter / column projection, inspected schema
+ new_fragment = parquet_format.make_fragment(
+ fragment.path, fragment.filesystem,
+ partition_expression=fragment.partition_expression)
+ assert_yields_projected(new_fragment, (0, 2), filter=ds.field('f1') < 2)
+
+ # filter requiring cast / column projection, inspected schema
+ new_fragment = parquet_format.make_fragment(
+ fragment.path, fragment.filesystem,
+ partition_expression=fragment.partition_expression)
+ assert_yields_projected(new_fragment, (0, 2),
+ columns=['f1'], filter=ds.field('f1') < 2.0)
+
+ # filter on the partition column
+ new_fragment = parquet_format.make_fragment(
+ fragment.path, fragment.filesystem,
+ partition_expression=fragment.partition_expression)
+ assert_yields_projected(new_fragment, (0, 4),
+ filter=ds.field('part') == 'a')
+
+    # Fragments do not expose the partition columns unless a schema that
+    # contains them is passed to the `to_table(schema=...)` method.
+ pattern = (r'No match for FieldRef.Name\(part\) in ' +
+ fragment.physical_schema.to_string(False, False, False))
+ with pytest.raises(ValueError, match=pattern):
+ new_fragment = parquet_format.make_fragment(
+ fragment.path, fragment.filesystem,
+ partition_expression=fragment.partition_expression)
+ dataset_reader.to_table(new_fragment, filter=ds.field('part') == 'a')
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_parquet_row_groups(tempdir, dataset_reader):
+ table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2)
+
+ fragment = list(dataset.get_fragments())[0]
+
+ # list and scan row group fragments
+ row_group_fragments = list(fragment.split_by_row_group())
+ assert len(row_group_fragments) == fragment.num_row_groups == 2
+ result = dataset_reader.to_table(
+ row_group_fragments[0], schema=dataset.schema)
+ assert result.column_names == ['f1', 'f2', 'part']
+ assert len(result) == 2
+ assert result.equals(table.slice(0, 2))
+
+ assert row_group_fragments[0].row_groups is not None
+ assert row_group_fragments[0].num_row_groups == 1
+ assert row_group_fragments[0].row_groups[0].statistics == {
+ 'f1': {'min': 0, 'max': 1},
+ 'f2': {'min': 1, 'max': 1},
+ }
+
+ fragment = list(dataset.get_fragments(filter=ds.field('f1') < 1))[0]
+ row_group_fragments = list(fragment.split_by_row_group(ds.field('f1') < 1))
+ assert len(row_group_fragments) == 1
+ result = dataset_reader.to_table(
+ row_group_fragments[0], filter=ds.field('f1') < 1)
+ assert len(result) == 1
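+
+
+# Self-contained sketch of the row-group splitting flow tested above, against
+# a plain local parquet file instead of the partitioned fixture (illustrative
+# only; the file and column names are arbitrary).
+def _example_split_by_row_group():
+    import tempfile
+
+    import pyarrow as pa
+    import pyarrow.dataset as ds
+    import pyarrow.parquet as pq
+
+    with tempfile.TemporaryDirectory() as tmp:
+        path = tmp + '/rows.parquet'
+        # four rows written as two row groups of two rows each
+        pq.write_table(pa.table({'a': [1, 2, 3, 4]}), path, row_group_size=2)
+
+        fragment = next(ds.dataset(path, format='parquet').get_fragments())
+        pieces = list(fragment.split_by_row_group())
+        assert len(pieces) == fragment.num_row_groups == 2
+        # per-row-group statistics become available on the split fragments
+        stats = pieces[0].row_groups[0].statistics
+        assert stats == {'a': {'min': 1, 'max': 2}}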
+
+
+@pytest.mark.parquet
+def test_fragments_parquet_num_row_groups(tempdir):
+ import pyarrow.parquet as pq
+
+ table = pa.table({'a': range(8)})
+ pq.write_table(table, tempdir / "test.parquet", row_group_size=2)
+ dataset = ds.dataset(tempdir / "test.parquet", format="parquet")
+ original_fragment = list(dataset.get_fragments())[0]
+
+ # create fragment with subset of row groups
+ fragment = original_fragment.format.make_fragment(
+ original_fragment.path, original_fragment.filesystem,
+ row_groups=[1, 3])
+ assert fragment.num_row_groups == 2
+ # ensure that parsing metadata preserves correct number of row groups
+ fragment.ensure_complete_metadata()
+ assert fragment.num_row_groups == 2
+ assert len(fragment.row_groups) == 2
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_parquet_row_groups_dictionary(tempdir, dataset_reader):
+ import pandas as pd
+
+ df = pd.DataFrame(dict(col1=['a', 'b'], col2=[1, 2]))
+ df['col1'] = df['col1'].astype("category")
+
+ import pyarrow.parquet as pq
+ pq.write_table(pa.table(df), tempdir / "test_filter_dictionary.parquet")
+
+ import pyarrow.dataset as ds
+ dataset = ds.dataset(tempdir / 'test_filter_dictionary.parquet')
+ result = dataset_reader.to_table(dataset, filter=ds.field("col1") == "a")
+
+ assert (df.iloc[0] == result.to_pandas()).all().all()
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_parquet_ensure_metadata(tempdir, open_logging_fs):
+ fs, assert_opens = open_logging_fs
+ _, dataset = _create_dataset_for_fragments(
+ tempdir, chunk_size=2, filesystem=fs
+ )
+ fragment = list(dataset.get_fragments())[0]
+
+ # with default discovery, no metadata loaded
+ with assert_opens([fragment.path]):
+ fragment.ensure_complete_metadata()
+ assert fragment.row_groups == [0, 1]
+
+ # second time -> use cached / no file IO
+ with assert_opens([]):
+ fragment.ensure_complete_metadata()
+
+ # recreate fragment with row group ids
+ new_fragment = fragment.format.make_fragment(
+ fragment.path, fragment.filesystem, row_groups=[0, 1]
+ )
+ assert new_fragment.row_groups == fragment.row_groups
+
+ # collect metadata
+ new_fragment.ensure_complete_metadata()
+ row_group = new_fragment.row_groups[0]
+ assert row_group.id == 0
+ assert row_group.num_rows == 2
+ assert row_group.statistics is not None
+
+ # pickling preserves row group ids
+ pickled_fragment = pickle.loads(pickle.dumps(new_fragment))
+ with assert_opens([fragment.path]):
+ assert pickled_fragment.row_groups == [0, 1]
+ row_group = pickled_fragment.row_groups[0]
+ assert row_group.id == 0
+ assert row_group.statistics is not None
+
+
+def _create_dataset_all_types(tempdir, chunk_size=None):
+ import pyarrow.parquet as pq
+
+ table = pa.table(
+ [
+ pa.array([True, None, False], pa.bool_()),
+ pa.array([1, 10, 42], pa.int8()),
+ pa.array([1, 10, 42], pa.uint8()),
+ pa.array([1, 10, 42], pa.int16()),
+ pa.array([1, 10, 42], pa.uint16()),
+ pa.array([1, 10, 42], pa.int32()),
+ pa.array([1, 10, 42], pa.uint32()),
+ pa.array([1, 10, 42], pa.int64()),
+ pa.array([1, 10, 42], pa.uint64()),
+ pa.array([1.0, 10.0, 42.0], pa.float32()),
+ pa.array([1.0, 10.0, 42.0], pa.float64()),
+ pa.array(['a', None, 'z'], pa.utf8()),
+ pa.array(['a', None, 'z'], pa.binary()),
+ pa.array([1, 10, 42], pa.timestamp('s')),
+ pa.array([1, 10, 42], pa.timestamp('ms')),
+ pa.array([1, 10, 42], pa.timestamp('us')),
+ pa.array([1, 10, 42], pa.date32()),
+ pa.array([1, 10, 4200000000], pa.date64()),
+ pa.array([1, 10, 42], pa.time32('s')),
+ pa.array([1, 10, 42], pa.time64('us')),
+ ],
+ names=[
+ 'boolean',
+ 'int8',
+ 'uint8',
+ 'int16',
+ 'uint16',
+ 'int32',
+ 'uint32',
+ 'int64',
+ 'uint64',
+ 'float',
+ 'double',
+ 'utf8',
+ 'binary',
+ 'ts[s]',
+ 'ts[ms]',
+ 'ts[us]',
+ 'date32',
+ 'date64',
+ 'time32',
+ 'time64',
+ ]
+ )
+
+ path = str(tempdir / "test_parquet_dataset_all_types")
+
+ # write_to_dataset currently requires pandas
+ pq.write_to_dataset(table, path, chunk_size=chunk_size)
+
+ return table, ds.dataset(path, format="parquet", partitioning="hive")
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_parquet_fragment_statistics(tempdir):
+ table, dataset = _create_dataset_all_types(tempdir)
+
+ fragment = list(dataset.get_fragments())[0]
+
+    import datetime
+
+ def dt_s(x): return datetime.datetime(1970, 1, 1, 0, 0, x)
+ def dt_ms(x): return datetime.datetime(1970, 1, 1, 0, 0, 0, x*1000)
+ def dt_us(x): return datetime.datetime(1970, 1, 1, 0, 0, 0, x)
+ date = datetime.date
+ time = datetime.time
+
+ # list and scan row group fragments
+ row_group_fragments = list(fragment.split_by_row_group())
+ assert row_group_fragments[0].row_groups is not None
+ row_group = row_group_fragments[0].row_groups[0]
+ assert row_group.num_rows == 3
+ assert row_group.total_byte_size > 1000
+ assert row_group.statistics == {
+ 'boolean': {'min': False, 'max': True},
+ 'int8': {'min': 1, 'max': 42},
+ 'uint8': {'min': 1, 'max': 42},
+ 'int16': {'min': 1, 'max': 42},
+ 'uint16': {'min': 1, 'max': 42},
+ 'int32': {'min': 1, 'max': 42},
+ 'uint32': {'min': 1, 'max': 42},
+ 'int64': {'min': 1, 'max': 42},
+ 'uint64': {'min': 1, 'max': 42},
+ 'float': {'min': 1.0, 'max': 42.0},
+ 'double': {'min': 1.0, 'max': 42.0},
+ 'utf8': {'min': 'a', 'max': 'z'},
+ 'binary': {'min': b'a', 'max': b'z'},
+ 'ts[s]': {'min': dt_s(1), 'max': dt_s(42)},
+ 'ts[ms]': {'min': dt_ms(1), 'max': dt_ms(42)},
+ 'ts[us]': {'min': dt_us(1), 'max': dt_us(42)},
+ 'date32': {'min': date(1970, 1, 2), 'max': date(1970, 2, 12)},
+ 'date64': {'min': date(1970, 1, 1), 'max': date(1970, 2, 18)},
+ 'time32': {'min': time(0, 0, 1), 'max': time(0, 0, 42)},
+ 'time64': {'min': time(0, 0, 0, 1), 'max': time(0, 0, 0, 42)},
+ }
+
+
+@pytest.mark.parquet
+def test_parquet_fragment_statistics_nulls(tempdir):
+ import pyarrow.parquet as pq
+
+ table = pa.table({'a': [0, 1, None, None], 'b': ['a', 'b', None, None]})
+ pq.write_table(table, tempdir / "test.parquet", row_group_size=2)
+
+ dataset = ds.dataset(tempdir / "test.parquet", format="parquet")
+ fragments = list(dataset.get_fragments())[0].split_by_row_group()
+ # second row group has all nulls -> no statistics
+ assert fragments[1].row_groups[0].statistics == {}
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_parquet_empty_row_group_statistics(tempdir):
+ df = pd.DataFrame({"a": ["a", "b", "b"], "b": [4, 5, 6]})[:0]
+ df.to_parquet(tempdir / "test.parquet", engine="pyarrow")
+
+ dataset = ds.dataset(tempdir / "test.parquet", format="parquet")
+ fragments = list(dataset.get_fragments())[0].split_by_row_group()
+    # The only row group is empty
+ assert fragments[0].row_groups[0].statistics == {}
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_parquet_row_groups_predicate(tempdir):
+ table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2)
+
+ fragment = list(dataset.get_fragments())[0]
+ assert fragment.partition_expression.equals(ds.field('part') == 'a')
+
+ # predicate may reference a partition field not present in the
+ # physical_schema if an explicit schema is provided to split_by_row_group
+
+ # filter matches partition_expression: all row groups
+ row_group_fragments = list(
+ fragment.split_by_row_group(filter=ds.field('part') == 'a',
+ schema=dataset.schema))
+ assert len(row_group_fragments) == 2
+
+ # filter contradicts partition_expression: no row groups
+ row_group_fragments = list(
+ fragment.split_by_row_group(filter=ds.field('part') == 'b',
+ schema=dataset.schema))
+ assert len(row_group_fragments) == 0
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_parquet_row_groups_reconstruct(tempdir, dataset_reader):
+ table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2)
+
+ fragment = list(dataset.get_fragments())[0]
+ parquet_format = fragment.format
+ row_group_fragments = list(fragment.split_by_row_group())
+
+ # test pickle roundtrip
+ pickled_fragment = pickle.loads(pickle.dumps(fragment))
+ assert dataset_reader.to_table(
+ pickled_fragment) == dataset_reader.to_table(fragment)
+
+ # manually re-construct row group fragments
+ new_fragment = parquet_format.make_fragment(
+ fragment.path, fragment.filesystem,
+ partition_expression=fragment.partition_expression,
+ row_groups=[0])
+ result = dataset_reader.to_table(new_fragment)
+ assert result.equals(dataset_reader.to_table(row_group_fragments[0]))
+
+ # manually re-construct a row group fragment with filter/column projection
+ new_fragment = parquet_format.make_fragment(
+ fragment.path, fragment.filesystem,
+ partition_expression=fragment.partition_expression,
+ row_groups={1})
+ result = dataset_reader.to_table(
+ new_fragment, schema=table.schema, columns=['f1', 'part'],
+ filter=ds.field('f1') < 3, )
+ assert result.column_names == ['f1', 'part']
+ assert len(result) == 1
+
+ # out of bounds row group index
+ new_fragment = parquet_format.make_fragment(
+ fragment.path, fragment.filesystem,
+ partition_expression=fragment.partition_expression,
+ row_groups={2})
+ with pytest.raises(IndexError, match="references row group 2"):
+ dataset_reader.to_table(new_fragment)
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_parquet_subset_ids(tempdir, open_logging_fs,
+ dataset_reader):
+ fs, assert_opens = open_logging_fs
+ table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=1,
+ filesystem=fs)
+ fragment = list(dataset.get_fragments())[0]
+
+ # select with row group ids
+ subfrag = fragment.subset(row_group_ids=[0, 3])
+ with assert_opens([]):
+ assert subfrag.num_row_groups == 2
+ assert subfrag.row_groups == [0, 3]
+ assert subfrag.row_groups[0].statistics is not None
+
+ # check correct scan result of subset
+ result = dataset_reader.to_table(subfrag)
+ assert result.to_pydict() == {"f1": [0, 3], "f2": [1, 1]}
+
+ # empty list of ids
+ subfrag = fragment.subset(row_group_ids=[])
+ assert subfrag.num_row_groups == 0
+ assert subfrag.row_groups == []
+ result = dataset_reader.to_table(subfrag, schema=dataset.schema)
+ assert result.num_rows == 0
+ assert result.equals(table[:0])
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_parquet_subset_filter(tempdir, open_logging_fs,
+ dataset_reader):
+ fs, assert_opens = open_logging_fs
+ table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=1,
+ filesystem=fs)
+ fragment = list(dataset.get_fragments())[0]
+
+ # select with filter
+ subfrag = fragment.subset(ds.field("f1") >= 1)
+ with assert_opens([]):
+ assert subfrag.num_row_groups == 3
+ assert len(subfrag.row_groups) == 3
+ assert subfrag.row_groups[0].statistics is not None
+
+ # check correct scan result of subset
+ result = dataset_reader.to_table(subfrag)
+ assert result.to_pydict() == {"f1": [1, 2, 3], "f2": [1, 1, 1]}
+
+ # filter that results in empty selection
+ subfrag = fragment.subset(ds.field("f1") > 5)
+ assert subfrag.num_row_groups == 0
+ assert subfrag.row_groups == []
+ result = dataset_reader.to_table(subfrag, schema=dataset.schema)
+ assert result.num_rows == 0
+ assert result.equals(table[:0])
+
+ # passing schema to ensure filter on partition expression works
+ subfrag = fragment.subset(ds.field("part") == "a", schema=dataset.schema)
+ assert subfrag.num_row_groups == 4
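+
+
+# Sketch of Fragment.subset() with explicit row group ids on a local parquet
+# file, complementing the fixture-based subset tests above (illustrative only).
+def _example_fragment_subset():
+    import tempfile
+
+    import pyarrow as pa
+    import pyarrow.dataset as ds
+    import pyarrow.parquet as pq
+
+    with tempfile.TemporaryDirectory() as tmp:
+        path = tmp + '/subset.parquet'
+        # one row per row group, so row group ids map directly to rows
+        pq.write_table(pa.table({'a': [0, 1, 2, 3]}), path, row_group_size=1)
+
+        fragment = next(ds.dataset(path, format='parquet').get_fragments())
+        # keep only the first and last of the four row groups
+        subfrag = fragment.subset(row_group_ids=[0, 3])
+        assert subfrag.num_row_groups == 2
+        assert subfrag.to_table().to_pydict() == {'a': [0, 3]}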
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_parquet_subset_invalid(tempdir):
+ _, dataset = _create_dataset_for_fragments(tempdir, chunk_size=1)
+ fragment = list(dataset.get_fragments())[0]
+
+    # passing neither or both of filter / row_group_ids raises
+ with pytest.raises(ValueError):
+ fragment.subset(ds.field("f1") >= 1, row_group_ids=[1, 2])
+
+ with pytest.raises(ValueError):
+ fragment.subset()
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_fragments_repr(tempdir, dataset):
+ # partitioned parquet dataset
+ fragment = list(dataset.get_fragments())[0]
+ assert (
+ repr(fragment) ==
+ "<pyarrow.dataset.ParquetFileFragment path=subdir/1/xxx/file0.parquet "
+ "partition=[key=xxx, group=1]>"
+ )
+
+ # single-file parquet dataset (no partition information in repr)
+ table, path = _create_single_file(tempdir)
+ dataset = ds.dataset(path, format="parquet")
+ fragment = list(dataset.get_fragments())[0]
+ assert (
+ repr(fragment) ==
+ "<pyarrow.dataset.ParquetFileFragment path={}>".format(
+ dataset.filesystem.normalize_path(str(path)))
+ )
+
+ # non-parquet format
+ path = tempdir / "data.feather"
+ pa.feather.write_feather(table, path)
+ dataset = ds.dataset(path, format="feather")
+ fragment = list(dataset.get_fragments())[0]
+ assert (
+ repr(fragment) ==
+ "<pyarrow.dataset.FileFragment type=ipc path={}>".format(
+ dataset.filesystem.normalize_path(str(path)))
+ )
+
+
+def test_partitioning_factory(mockfs):
+ paths_or_selector = fs.FileSelector('subdir', recursive=True)
+ format = ds.ParquetFileFormat()
+
+ options = ds.FileSystemFactoryOptions('subdir')
+ partitioning_factory = ds.DirectoryPartitioning.discover(['group', 'key'])
+ assert isinstance(partitioning_factory, ds.PartitioningFactory)
+ options.partitioning_factory = partitioning_factory
+
+ factory = ds.FileSystemDatasetFactory(
+ mockfs, paths_or_selector, format, options
+ )
+ inspected_schema = factory.inspect()
+ # i64/f64 from data, group/key from "/1/xxx" and "/2/yyy" paths
+ expected_schema = pa.schema([
+ ("i64", pa.int64()),
+ ("f64", pa.float64()),
+ ("str", pa.string()),
+ ("const", pa.int64()),
+ ("group", pa.int32()),
+ ("key", pa.string()),
+ ])
+ assert inspected_schema.equals(expected_schema)
+
+ hive_partitioning_factory = ds.HivePartitioning.discover()
+ assert isinstance(hive_partitioning_factory, ds.PartitioningFactory)
+
+
+@pytest.mark.parametrize('infer_dictionary', [False, True])
+def test_partitioning_factory_dictionary(mockfs, infer_dictionary):
+ paths_or_selector = fs.FileSelector('subdir', recursive=True)
+ format = ds.ParquetFileFormat()
+ options = ds.FileSystemFactoryOptions('subdir')
+
+ options.partitioning_factory = ds.DirectoryPartitioning.discover(
+ ['group', 'key'], infer_dictionary=infer_dictionary)
+
+ factory = ds.FileSystemDatasetFactory(
+ mockfs, paths_or_selector, format, options)
+
+ inferred_schema = factory.inspect()
+ if infer_dictionary:
+ expected_type = pa.dictionary(pa.int32(), pa.string())
+ assert inferred_schema.field('key').type == expected_type
+
+ table = factory.finish().to_table().combine_chunks()
+ actual = table.column('key').chunk(0)
+ expected = pa.array(['xxx'] * 5 + ['yyy'] * 5).dictionary_encode()
+ assert actual.equals(expected)
+
+ # ARROW-9345 ensure filtering on the partition field works
+ table = factory.finish().to_table(filter=ds.field('key') == 'xxx')
+ actual = table.column('key').chunk(0)
+ expected = expected.slice(0, 5)
+ assert actual.equals(expected)
+ else:
+ assert inferred_schema.field('key').type == pa.string()
+
+
+def test_partitioning_factory_segment_encoding():
+ mockfs = fs._MockFileSystem()
+ format = ds.IpcFileFormat()
+ schema = pa.schema([("i64", pa.int64())])
+ table = pa.table([pa.array(range(10))], schema=schema)
+ partition_schema = pa.schema(
+ [("date", pa.timestamp("s")), ("string", pa.string())])
+ string_partition_schema = pa.schema(
+ [("date", pa.string()), ("string", pa.string())])
+ full_schema = pa.schema(list(schema) + list(partition_schema))
+ for directory in [
+ "directory/2021-05-04 00%3A00%3A00/%24",
+ "hive/date=2021-05-04 00%3A00%3A00/string=%24",
+ ]:
+ mockfs.create_dir(directory)
+ with mockfs.open_output_stream(directory + "/0.feather") as sink:
+ with pa.ipc.new_file(sink, schema) as writer:
+ writer.write_table(table)
+ writer.close()
+
+ # Directory
+ selector = fs.FileSelector("directory", recursive=True)
+ options = ds.FileSystemFactoryOptions("directory")
+ options.partitioning_factory = ds.DirectoryPartitioning.discover(
+ schema=partition_schema)
+ factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
+ inferred_schema = factory.inspect()
+ assert inferred_schema == full_schema
+ actual = factory.finish().to_table(columns={
+ "date_int": ds.field("date").cast(pa.int64()),
+ })
+ assert actual[0][0].as_py() == 1620086400
+
+ options.partitioning_factory = ds.DirectoryPartitioning.discover(
+ ["date", "string"], segment_encoding="none")
+ factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
+ fragments = list(factory.finish().get_fragments())
+ assert fragments[0].partition_expression.equals(
+ (ds.field("date") == "2021-05-04 00%3A00%3A00") &
+ (ds.field("string") == "%24"))
+
+ options.partitioning = ds.DirectoryPartitioning(
+ string_partition_schema, segment_encoding="none")
+ factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
+ fragments = list(factory.finish().get_fragments())
+ assert fragments[0].partition_expression.equals(
+ (ds.field("date") == "2021-05-04 00%3A00%3A00") &
+ (ds.field("string") == "%24"))
+
+ options.partitioning_factory = ds.DirectoryPartitioning.discover(
+ schema=partition_schema, segment_encoding="none")
+ factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
+ with pytest.raises(pa.ArrowInvalid,
+ match="Could not cast segments for partition field"):
+ inferred_schema = factory.inspect()
+
+ # Hive
+ selector = fs.FileSelector("hive", recursive=True)
+ options = ds.FileSystemFactoryOptions("hive")
+ options.partitioning_factory = ds.HivePartitioning.discover(
+ schema=partition_schema)
+ factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
+ inferred_schema = factory.inspect()
+ assert inferred_schema == full_schema
+ actual = factory.finish().to_table(columns={
+ "date_int": ds.field("date").cast(pa.int64()),
+ })
+ assert actual[0][0].as_py() == 1620086400
+
+ options.partitioning_factory = ds.HivePartitioning.discover(
+ segment_encoding="none")
+ factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
+ fragments = list(factory.finish().get_fragments())
+ assert fragments[0].partition_expression.equals(
+ (ds.field("date") == "2021-05-04 00%3A00%3A00") &
+ (ds.field("string") == "%24"))
+
+ options.partitioning = ds.HivePartitioning(
+ string_partition_schema, segment_encoding="none")
+ factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
+ fragments = list(factory.finish().get_fragments())
+ assert fragments[0].partition_expression.equals(
+ (ds.field("date") == "2021-05-04 00%3A00%3A00") &
+ (ds.field("string") == "%24"))
+
+ options.partitioning_factory = ds.HivePartitioning.discover(
+ schema=partition_schema, segment_encoding="none")
+ factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
+ with pytest.raises(pa.ArrowInvalid,
+ match="Could not cast segments for partition field"):
+ inferred_schema = factory.inspect()
+
+
+def test_dictionary_partitioning_outer_nulls_raises(tempdir):
+ table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z']})
+ part = ds.partitioning(
+ pa.schema([pa.field('a', pa.string()), pa.field('b', pa.string())]))
+ with pytest.raises(pa.ArrowInvalid):
+ ds.write_dataset(table, tempdir, format='parquet', partitioning=part)
+
+
+def _has_subdirs(basedir):
+ elements = os.listdir(basedir)
+    return any(os.path.isdir(os.path.join(basedir, el)) for el in elements)
+
+
+def _do_list_all_dirs(basedir, path_so_far, result):
+ for f in os.listdir(basedir):
+ true_nested = os.path.join(basedir, f)
+ if os.path.isdir(true_nested):
+ norm_nested = posixpath.join(path_so_far, f)
+ if _has_subdirs(true_nested):
+ _do_list_all_dirs(true_nested, norm_nested, result)
+ else:
+ result.append(norm_nested)
+
+
+def _list_all_dirs(basedir):
+ result = []
+ _do_list_all_dirs(basedir, '', result)
+ return result
+
+
+def _check_dataset_directories(tempdir, expected_directories):
+ actual_directories = set(_list_all_dirs(tempdir))
+ assert actual_directories == set(expected_directories)
+
+
+def test_dictionary_partitioning_inner_nulls(tempdir):
+ table = pa.table({'a': ['x', 'y', 'z'], 'b': ['x', 'y', None]})
+ part = ds.partitioning(
+ pa.schema([pa.field('a', pa.string()), pa.field('b', pa.string())]))
+ ds.write_dataset(table, tempdir, format='parquet', partitioning=part)
+ _check_dataset_directories(tempdir, ['x/x', 'y/y', 'z'])
+
+
+def test_hive_partitioning_nulls(tempdir):
+ table = pa.table({'a': ['x', None, 'z'], 'b': ['x', 'y', None]})
+ part = ds.HivePartitioning(pa.schema(
+ [pa.field('a', pa.string()), pa.field('b', pa.string())]), None, 'xyz')
+ ds.write_dataset(table, tempdir, format='parquet', partitioning=part)
+ _check_dataset_directories(tempdir, ['a=x/b=x', 'a=xyz/b=y', 'a=z/b=xyz'])
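+
+
+# End-to-end sketch condensing the write/read round trip used throughout
+# these tests: write a hive-partitioned dataset, then read it back with
+# partition discovery (illustrative only; names and values are arbitrary).
+def _example_write_and_read_hive_partitioned():
+    import tempfile
+
+    import pyarrow as pa
+    import pyarrow.dataset as ds
+
+    with tempfile.TemporaryDirectory() as tmp:
+        table = pa.table({'part': ['a', 'a', 'b'], 'value': [1, 2, 3]})
+        part = ds.partitioning(
+            pa.schema([pa.field('part', pa.string())]), flavor='hive')
+        ds.write_dataset(table, tmp, format='parquet', partitioning=part)
+
+        # rediscover the partition field from the "part=..." directory names
+        dataset = ds.dataset(tmp, format='parquet', partitioning='hive')
+        result = dataset.to_table(filter=ds.field('part') == 'b')
+        assert result.column('value').to_pylist() == [3]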
+
+
+def test_partitioning_function():
+ schema = pa.schema([("year", pa.int16()), ("month", pa.int8())])
+ names = ["year", "month"]
+
+ # default DirectoryPartitioning
+ part = ds.partitioning(schema)
+ assert isinstance(part, ds.DirectoryPartitioning)
+ part = ds.partitioning(schema, dictionaries="infer")
+ assert isinstance(part, ds.PartitioningFactory)
+ part = ds.partitioning(field_names=names)
+ assert isinstance(part, ds.PartitioningFactory)
+ # needs schema or list of names
+ with pytest.raises(ValueError):
+ ds.partitioning()
+ with pytest.raises(ValueError, match="Expected list"):
+ ds.partitioning(field_names=schema)
+ with pytest.raises(ValueError, match="Cannot specify both"):
+ ds.partitioning(schema, field_names=schema)
+
+ # Hive partitioning
+ part = ds.partitioning(schema, flavor="hive")
+ assert isinstance(part, ds.HivePartitioning)
+ part = ds.partitioning(schema, dictionaries="infer", flavor="hive")
+ assert isinstance(part, ds.PartitioningFactory)
+ part = ds.partitioning(flavor="hive")
+ assert isinstance(part, ds.PartitioningFactory)
+ # cannot pass list of names
+ with pytest.raises(ValueError):
+ ds.partitioning(names, flavor="hive")
+ with pytest.raises(ValueError, match="Cannot specify 'field_names'"):
+ ds.partitioning(field_names=names, flavor="hive")
+
+ # unsupported flavor
+ with pytest.raises(ValueError):
+ ds.partitioning(schema, flavor="unsupported")
+
+
+def test_directory_partitioning_dictionary_key(mockfs):
+ # ARROW-8088 specifying partition key as dictionary type
+ schema = pa.schema([
+ pa.field('group', pa.dictionary(pa.int8(), pa.int32())),
+ pa.field('key', pa.dictionary(pa.int8(), pa.string()))
+ ])
+ part = ds.DirectoryPartitioning.discover(schema=schema)
+
+ dataset = ds.dataset(
+ "subdir", format="parquet", filesystem=mockfs, partitioning=part
+ )
+ assert dataset.partitioning.schema == schema
+ table = dataset.to_table()
+
+ assert table.column('group').type.equals(schema.types[0])
+ assert table.column('group').to_pylist() == [1] * 5 + [2] * 5
+ assert table.column('key').type.equals(schema.types[1])
+ assert table.column('key').to_pylist() == ['xxx'] * 5 + ['yyy'] * 5
+
+
+def test_hive_partitioning_dictionary_key(multisourcefs):
+ # ARROW-8088 specifying partition key as dictionary type
+ schema = pa.schema([
+ pa.field('year', pa.dictionary(pa.int8(), pa.int16())),
+ pa.field('month', pa.dictionary(pa.int8(), pa.int16()))
+ ])
+ part = ds.HivePartitioning.discover(schema=schema)
+
+ dataset = ds.dataset(
+ "hive", format="parquet", filesystem=multisourcefs, partitioning=part
+ )
+ assert dataset.partitioning.schema == schema
+ table = dataset.to_table()
+
+ year_dictionary = list(range(2006, 2011))
+ month_dictionary = list(range(1, 13))
+ assert table.column('year').type.equals(schema.types[0])
+ for chunk in table.column('year').chunks:
+ actual = chunk.dictionary.to_pylist()
+ actual.sort()
+ assert actual == year_dictionary
+ assert table.column('month').type.equals(schema.types[1])
+ for chunk in table.column('month').chunks:
+ actual = chunk.dictionary.to_pylist()
+ actual.sort()
+ assert actual == month_dictionary
+
+
+def _create_single_file(base_dir, table=None, row_group_size=None):
+ import pyarrow.parquet as pq
+ if table is None:
+ table = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5})
+ path = base_dir / "test.parquet"
+ pq.write_table(table, path, row_group_size=row_group_size)
+ return table, path
+
+
+def _create_directory_of_files(base_dir):
+ import pyarrow.parquet as pq
+ table1 = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5})
+ path1 = base_dir / "test1.parquet"
+ pq.write_table(table1, path1)
+ table2 = pa.table({'a': range(9, 18), 'b': [0.] * 4 + [1.] * 5})
+ path2 = base_dir / "test2.parquet"
+ pq.write_table(table2, path2)
+ return (table1, table2), (path1, path2)
+
+
+def _check_dataset(dataset, table, dataset_reader):
+ # also test that pickle roundtrip keeps the functionality
+ for d in [dataset, pickle.loads(pickle.dumps(dataset))]:
+        assert d.schema.equals(table.schema)
+        assert dataset_reader.to_table(d).equals(table)
+
+
+def _check_dataset_from_path(path, table, dataset_reader, **kwargs):
+ # pathlib object
+ assert isinstance(path, pathlib.Path)
+
+ # accept Path, str, List[Path], List[str]
+ for p in [path, str(path), [path], [str(path)]]:
+        dataset = ds.dataset(p, **kwargs)
+ assert isinstance(dataset, ds.FileSystemDataset)
+ _check_dataset(dataset, table, dataset_reader)
+
+ # relative string path
+ with change_cwd(path.parent):
+ dataset = ds.dataset(path.name, **kwargs)
+ assert isinstance(dataset, ds.FileSystemDataset)
+ _check_dataset(dataset, table, dataset_reader)
+
+
+@pytest.mark.parquet
+def test_open_dataset_single_file(tempdir, dataset_reader):
+ table, path = _create_single_file(tempdir)
+ _check_dataset_from_path(path, table, dataset_reader)
+
+
+@pytest.mark.parquet
+def test_deterministic_row_order(tempdir, dataset_reader):
+ # ARROW-8447 Ensure that dataset.to_table (and Scanner::ToTable) returns a
+ # deterministic row ordering. This is achieved by constructing a single
+ # parquet file with one row per RowGroup.
+ table, path = _create_single_file(tempdir, row_group_size=1)
+ _check_dataset_from_path(path, table, dataset_reader)
+
+
+@pytest.mark.parquet
+def test_open_dataset_directory(tempdir, dataset_reader):
+ tables, _ = _create_directory_of_files(tempdir)
+ table = pa.concat_tables(tables)
+ _check_dataset_from_path(tempdir, table, dataset_reader)
+
+
+@pytest.mark.parquet
+def test_open_dataset_list_of_files(tempdir, dataset_reader):
+ tables, (path1, path2) = _create_directory_of_files(tempdir)
+ table = pa.concat_tables(tables)
+
+ datasets = [
+ ds.dataset([path1, path2]),
+ ds.dataset([str(path1), str(path2)])
+ ]
+ datasets += [
+ pickle.loads(pickle.dumps(d)) for d in datasets
+ ]
+
+ for dataset in datasets:
+ assert dataset.schema.equals(table.schema)
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(table)
+
+
+@pytest.mark.parquet
+def test_open_dataset_filesystem_fspath(tempdir):
+ # single file
+ table, path = _create_single_file(tempdir)
+
+ fspath = FSProtocolClass(path)
+
+ # filesystem inferred from path
+ dataset1 = ds.dataset(fspath)
+ assert dataset1.schema.equals(table.schema)
+
+ # filesystem specified
+ dataset2 = ds.dataset(fspath, filesystem=fs.LocalFileSystem())
+ assert dataset2.schema.equals(table.schema)
+
+ # passing different filesystem
+ with pytest.raises(TypeError):
+ ds.dataset(fspath, filesystem=fs._MockFileSystem())
+
+
+def test_construct_from_single_file(tempdir, dataset_reader):
+ directory = tempdir / 'single-file'
+ directory.mkdir()
+ table, path = _create_single_file(directory)
+ relative_path = path.relative_to(directory)
+
+ # instantiate from a single file
+ d1 = ds.dataset(path)
+ # instantiate from a single file with a filesystem object
+ d2 = ds.dataset(path, filesystem=fs.LocalFileSystem())
+ # instantiate from a single file with prefixed filesystem URI
+ d3 = ds.dataset(str(relative_path), filesystem=_filesystem_uri(directory))
+ # pickle roundtrip
+ d4 = pickle.loads(pickle.dumps(d1))
+
+ assert dataset_reader.to_table(d1) == dataset_reader.to_table(
+ d2) == dataset_reader.to_table(d3) == dataset_reader.to_table(d4)
+
+
+def test_construct_from_single_directory(tempdir, dataset_reader):
+ directory = tempdir / 'single-directory'
+ directory.mkdir()
+ tables, paths = _create_directory_of_files(directory)
+
+ d1 = ds.dataset(directory)
+ d2 = ds.dataset(directory, filesystem=fs.LocalFileSystem())
+ d3 = ds.dataset(directory.name, filesystem=_filesystem_uri(tempdir))
+ t1 = dataset_reader.to_table(d1)
+ t2 = dataset_reader.to_table(d2)
+ t3 = dataset_reader.to_table(d3)
+ assert t1 == t2 == t3
+
+ # test pickle roundtrip
+ for d in [d1, d2, d3]:
+ restored = pickle.loads(pickle.dumps(d))
+ assert dataset_reader.to_table(restored) == t1
+
+
+def test_construct_from_list_of_files(tempdir, dataset_reader):
+ # instantiate from a list of files
+ directory = tempdir / 'list-of-files'
+ directory.mkdir()
+ tables, paths = _create_directory_of_files(directory)
+
+ relative_paths = [p.relative_to(tempdir) for p in paths]
+ with change_cwd(tempdir):
+ d1 = ds.dataset(relative_paths)
+ t1 = dataset_reader.to_table(d1)
+ assert len(t1) == sum(map(len, tables))
+
+ d2 = ds.dataset(relative_paths, filesystem=_filesystem_uri(tempdir))
+ t2 = dataset_reader.to_table(d2)
+ d3 = ds.dataset(paths)
+ t3 = dataset_reader.to_table(d3)
+ d4 = ds.dataset(paths, filesystem=fs.LocalFileSystem())
+ t4 = dataset_reader.to_table(d4)
+
+ assert t1 == t2 == t3 == t4
+
+
+def test_construct_from_list_of_mixed_paths_fails(mockfs):
+    # instantiate from a list of mixed existing and non-existing paths
+ files = [
+ 'subdir/1/xxx/file0.parquet',
+ 'subdir/1/xxx/doesnt-exist.parquet',
+ ]
+ with pytest.raises(FileNotFoundError, match='doesnt-exist'):
+ ds.dataset(files, filesystem=mockfs)
+
+
+def test_construct_from_mixed_child_datasets(mockfs):
+    # instantiate a union dataset from a list of child datasets
+ a = ds.dataset(['subdir/1/xxx/file0.parquet',
+ 'subdir/2/yyy/file1.parquet'], filesystem=mockfs)
+ b = ds.dataset('subdir', filesystem=mockfs)
+
+ dataset = ds.dataset([a, b])
+
+ assert isinstance(dataset, ds.UnionDataset)
+ assert len(list(dataset.get_fragments())) == 4
+
+ table = dataset.to_table()
+ assert len(table) == 20
+ assert table.num_columns == 4
+
+ assert len(dataset.children) == 2
+ for child in dataset.children:
+ assert child.files == ['subdir/1/xxx/file0.parquet',
+ 'subdir/2/yyy/file1.parquet']
+
+
+def test_construct_empty_dataset():
+ empty = ds.dataset([])
+ table = empty.to_table()
+ assert table.num_rows == 0
+ assert table.num_columns == 0
+
+
+def test_construct_dataset_with_invalid_schema():
+ empty = ds.dataset([], schema=pa.schema([
+ ('a', pa.int64()),
+ ('a', pa.string())
+ ]))
+ with pytest.raises(ValueError, match='Multiple matches for .*a.* in '):
+ empty.to_table()
+
+
+def test_construct_from_invalid_sources_raise(multisourcefs):
+ child1 = ds.FileSystemDatasetFactory(
+ multisourcefs,
+ fs.FileSelector('/plain'),
+ format=ds.ParquetFileFormat()
+ )
+ child2 = ds.FileSystemDatasetFactory(
+ multisourcefs,
+ fs.FileSelector('/schema'),
+ format=ds.ParquetFileFormat()
+ )
+ batch1 = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+ batch2 = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["b"])
+
+ with pytest.raises(TypeError, match='Expected.*FileSystemDatasetFactory'):
+ ds.dataset([child1, child2])
+
+ expected = (
+ "Expected a list of path-like or dataset objects, or a list "
+ "of batches or tables. The given list contains the following "
+ "types: int"
+ )
+ with pytest.raises(TypeError, match=expected):
+ ds.dataset([1, 2, 3])
+
+ expected = (
+ "Expected a path-like, list of path-likes or a list of Datasets "
+ "instead of the given type: NoneType"
+ )
+ with pytest.raises(TypeError, match=expected):
+ ds.dataset(None)
+
+ expected = (
+ "Expected a path-like, list of path-likes or a list of Datasets "
+ "instead of the given type: generator"
+ )
+ with pytest.raises(TypeError, match=expected):
+ ds.dataset((batch1 for _ in range(3)))
+
+ expected = (
+ "Must provide schema to construct in-memory dataset from an empty list"
+ )
+ with pytest.raises(ValueError, match=expected):
+ ds.InMemoryDataset([])
+
+ expected = (
+ "Item has schema\nb: int64\nwhich does not match expected schema\n"
+ "a: int64"
+ )
+ with pytest.raises(TypeError, match=expected):
+ ds.dataset([batch1, batch2])
+
+ expected = (
+ "Expected a list of path-like or dataset objects, or a list of "
+ "batches or tables. The given list contains the following types:"
+ )
+ with pytest.raises(TypeError, match=expected):
+ ds.dataset([batch1, 0])
+
+ expected = (
+ "Expected a list of tables or batches. The given list contains a int"
+ )
+ with pytest.raises(TypeError, match=expected):
+ ds.InMemoryDataset([batch1, 0])
+
+
+def test_construct_in_memory(dataset_reader):
+ batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+ table = pa.Table.from_batches([batch])
+
+ assert ds.dataset([], schema=pa.schema([])).to_table() == pa.table([])
+
+ for source in (batch, table, [batch], [table]):
+ dataset = ds.dataset(source)
+ assert dataset_reader.to_table(dataset) == table
+ assert len(list(dataset.get_fragments())) == 1
+ assert next(dataset.get_fragments()).to_table() == table
+ assert pa.Table.from_batches(list(dataset.to_batches())) == table
+
+
+@pytest.mark.parametrize('use_threads,use_async',
+ [(False, False), (False, True),
+ (True, False), (True, True)])
+def test_scan_iterator(use_threads, use_async):
+ batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"])
+ table = pa.Table.from_batches([batch])
+ # When constructed from readers/iterators, should be one-shot
+ match = "OneShotFragment was already scanned"
+ for factory, schema in (
+ (lambda: pa.ipc.RecordBatchReader.from_batches(
+ batch.schema, [batch]), None),
+ (lambda: (batch for _ in range(1)), batch.schema),
+ ):
+ # Scanning the fragment consumes the underlying iterator
+ scanner = ds.Scanner.from_batches(
+ factory(), schema=schema, use_threads=use_threads,
+ use_async=use_async)
+ assert scanner.to_table() == table
+ with pytest.raises(pa.ArrowInvalid, match=match):
+ scanner.to_table()
+
+
+def _create_partitioned_dataset(basedir):
+ import pyarrow.parquet as pq
+ table = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5})
+
+ path = basedir / "dataset-partitioned"
+ path.mkdir()
+
+ for i in range(3):
+ part = path / "part={}".format(i)
+ part.mkdir()
+ pq.write_table(table.slice(3*i, 3), part / "test.parquet")
+
+ full_table = table.append_column(
+ "part", pa.array(np.repeat([0, 1, 2], 3), type=pa.int32()))
+
+ return full_table, path
+
+
+@pytest.mark.parquet
+def test_open_dataset_partitioned_directory(tempdir, dataset_reader):
+ full_table, path = _create_partitioned_dataset(tempdir)
+
+ # no partitioning specified, just read all individual files
+ table = full_table.select(['a', 'b'])
+ _check_dataset_from_path(path, table, dataset_reader)
+
+ # specify partition scheme with discovery
+ dataset = ds.dataset(
+ str(path), partitioning=ds.partitioning(flavor="hive"))
+ assert dataset.schema.equals(full_table.schema)
+
+ # specify partition scheme with discovery and relative path
+ with change_cwd(tempdir):
+ dataset = ds.dataset("dataset-partitioned/",
+ partitioning=ds.partitioning(flavor="hive"))
+ assert dataset.schema.equals(full_table.schema)
+
+ # specify partition scheme with string short-cut
+ dataset = ds.dataset(str(path), partitioning="hive")
+ assert dataset.schema.equals(full_table.schema)
+
+    # specify the partition scheme explicitly
+ dataset = ds.dataset(
+ str(path),
+ partitioning=ds.partitioning(
+ pa.schema([("part", pa.int8())]), flavor="hive"))
+ expected_schema = table.schema.append(pa.field("part", pa.int8()))
+ assert dataset.schema.equals(expected_schema)
+
+ result = dataset.to_table()
+ expected = table.append_column(
+ "part", pa.array(np.repeat([0, 1, 2], 3), type=pa.int8()))
+ assert result.equals(expected)
+
+
+@pytest.mark.parquet
+def test_open_dataset_filesystem(tempdir):
+ # single file
+ table, path = _create_single_file(tempdir)
+
+ # filesystem inferred from path
+ dataset1 = ds.dataset(str(path))
+ assert dataset1.schema.equals(table.schema)
+
+ # filesystem specified
+ dataset2 = ds.dataset(str(path), filesystem=fs.LocalFileSystem())
+ assert dataset2.schema.equals(table.schema)
+
+ # local filesystem specified with relative path
+ with change_cwd(tempdir):
+ dataset3 = ds.dataset("test.parquet", filesystem=fs.LocalFileSystem())
+ assert dataset3.schema.equals(table.schema)
+
+ # passing different filesystem
+ with pytest.raises(FileNotFoundError):
+ ds.dataset(str(path), filesystem=fs._MockFileSystem())
+
+
+@pytest.mark.parquet
+def test_open_dataset_unsupported_format(tempdir):
+ _, path = _create_single_file(tempdir)
+ with pytest.raises(ValueError, match="format 'blabla' is not supported"):
+ ds.dataset([path], format="blabla")
+
+
+@pytest.mark.parquet
+def test_open_union_dataset(tempdir, dataset_reader):
+ _, path = _create_single_file(tempdir)
+ dataset = ds.dataset(path)
+
+ union = ds.dataset([dataset, dataset])
+ assert isinstance(union, ds.UnionDataset)
+
+ pickled = pickle.loads(pickle.dumps(union))
+ assert dataset_reader.to_table(pickled) == dataset_reader.to_table(union)
+
+
+def test_open_union_dataset_with_additional_kwargs(multisourcefs):
+ child = ds.dataset('/plain', filesystem=multisourcefs, format='parquet')
+ with pytest.raises(ValueError, match="cannot pass any additional"):
+ ds.dataset([child], format="parquet")
+
+
+def test_open_dataset_non_existing_file():
+ # ARROW-8213: Opening a dataset with a local incorrect path gives confusing
+ # error message
+ with pytest.raises(FileNotFoundError):
+ ds.dataset('i-am-not-existing.parquet', format='parquet')
+
+ with pytest.raises(pa.ArrowInvalid, match='cannot be relative'):
+ ds.dataset('file:i-am-not-existing.parquet', format='parquet')
+
+
+@pytest.mark.parquet
+@pytest.mark.parametrize('partitioning', ["directory", "hive"])
+@pytest.mark.parametrize('null_fallback', ['xyz', None])
+@pytest.mark.parametrize('infer_dictionary', [False, True])
+@pytest.mark.parametrize('partition_keys', [
+ (["A", "B", "C"], [1, 2, 3]),
+ ([1, 2, 3], ["A", "B", "C"]),
+ (["A", "B", "C"], ["D", "E", "F"]),
+ ([1, 2, 3], [4, 5, 6]),
+ ([1, None, 3], ["A", "B", "C"]),
+ ([1, 2, 3], ["A", None, "C"]),
+ ([None, 2, 3], [None, 2, 3]),
+])
+def test_partition_discovery(
+ tempdir, partitioning, null_fallback, infer_dictionary, partition_keys
+):
+ # ARROW-9288 / ARROW-9476
+ import pyarrow.parquet as pq
+
+ table = pa.table({'a': range(9), 'b': [0.0] * 4 + [1.0] * 5})
+
+ has_null = None in partition_keys[0] or None in partition_keys[1]
+ if partitioning == "directory" and has_null:
+ # Directory partitioning has no way to represent null partition values
+ return
+
+ if partitioning == "directory":
+ partitioning = ds.DirectoryPartitioning.discover(
+ ["part1", "part2"], infer_dictionary=infer_dictionary)
+ fmt = "{0}/{1}"
+ null_value = None
+ else:
+ if null_fallback:
+ partitioning = ds.HivePartitioning.discover(
+ infer_dictionary=infer_dictionary, null_fallback=null_fallback
+ )
+ else:
+ partitioning = ds.HivePartitioning.discover(
+ infer_dictionary=infer_dictionary)
+ fmt = "part1={0}/part2={1}"
+ if null_fallback:
+ null_value = null_fallback
+ else:
+ null_value = "__HIVE_DEFAULT_PARTITION__"
+
+ basepath = tempdir / "dataset"
+ basepath.mkdir()
+
+ part_keys1, part_keys2 = partition_keys
+ for part1 in part_keys1:
+ for part2 in part_keys2:
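+ # the partition keys used here are never falsy except for None, so `or`
+ # substitutes the null placeholder only for missing values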
+ path = basepath / \
+ fmt.format(part1 or null_value, part2 or null_value)
+ path.mkdir(parents=True)
+ pq.write_table(table, path / "test.parquet")
+
+ dataset = ds.dataset(str(basepath), partitioning=partitioning)
+
+ def expected_type(key):
+ if infer_dictionary:
+ value_type = pa.string() if isinstance(key, str) else pa.int32()
+ return pa.dictionary(pa.int32(), value_type)
+ else:
+ return pa.string() if isinstance(key, str) else pa.int32()
+ expected_schema = table.schema.append(
+ pa.field("part1", expected_type(part_keys1[0]))
+ ).append(
+ pa.field("part2", expected_type(part_keys2[0]))
+ )
+ assert dataset.schema.equals(expected_schema)
+
+
+@pytest.mark.pandas
+def test_dataset_partitioned_dictionary_type_reconstruct(tempdir):
+ # https://issues.apache.org/jira/browse/ARROW-11400
+ table = pa.table({'part': np.repeat(['A', 'B'], 5), 'col': range(10)})
+ part = ds.partitioning(table.select(['part']).schema, flavor="hive")
+ ds.write_dataset(table, tempdir, partitioning=part, format="feather")
+
+ dataset = ds.dataset(
+ tempdir, format="feather",
+ partitioning=ds.HivePartitioning.discover(infer_dictionary=True)
+ )
+ expected = pa.table(
+ {'col': table['col'], 'part': table['part'].dictionary_encode()}
+ )
+ assert dataset.to_table().equals(expected)
+ fragment = list(dataset.get_fragments())[0]
+ assert fragment.to_table(schema=dataset.schema).equals(expected[:5])
+ part_expr = fragment.partition_expression
+
+ restored = pickle.loads(pickle.dumps(dataset))
+ assert restored.to_table().equals(expected)
+
+ restored = pickle.loads(pickle.dumps(fragment))
+ assert restored.to_table(schema=dataset.schema).equals(expected[:5])
+ # to_pandas call triggers computation of the actual dictionary values
+ assert restored.to_table(schema=dataset.schema).to_pandas().equals(
+ expected[:5].to_pandas()
+ )
+ assert restored.partition_expression.equals(part_expr)
+
+
+@pytest.fixture
+def s3_example_simple(s3_server):
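+ # Fixture: create a small parquet file in a fresh "mybucket" on the test S3
+ # server and return the table plus the connection details needed to reach it.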
+ from pyarrow.fs import FileSystem
+ import pyarrow.parquet as pq
+
+ host, port, access_key, secret_key = s3_server['connection']
+ uri = (
+ "s3://{}:{}@mybucket/data.parquet?scheme=http&endpoint_override={}:{}"
+ .format(access_key, secret_key, host, port)
+ )
+
+ fs, path = FileSystem.from_uri(uri)
+
+ fs.create_dir("mybucket")
+ table = pa.table({'a': [1, 2, 3]})
+ with fs.open_output_stream("mybucket/data.parquet") as out:
+ pq.write_table(table, out)
+
+ return table, path, fs, uri, host, port, access_key, secret_key
+
+
+@pytest.mark.parquet
+@pytest.mark.s3
+def test_open_dataset_from_uri_s3(s3_example_simple, dataset_reader):
+ # open dataset from non-localfs string path
+ table, path, fs, uri, _, _, _, _ = s3_example_simple
+
+ # full string URI
+ dataset = ds.dataset(uri, format="parquet")
+ assert dataset_reader.to_table(dataset).equals(table)
+
+ # passing filesystem object
+ dataset = ds.dataset(path, format="parquet", filesystem=fs)
+ assert dataset_reader.to_table(dataset).equals(table)
+
+
+@pytest.mark.parquet
+@pytest.mark.s3 # still needed to create the data
+def test_open_dataset_from_uri_s3_fsspec(s3_example_simple):
+ table, path, _, _, host, port, access_key, secret_key = s3_example_simple
+ s3fs = pytest.importorskip("s3fs")
+
+ from pyarrow.fs import PyFileSystem, FSSpecHandler
+
+ fs = s3fs.S3FileSystem(
+ key=access_key,
+ secret=secret_key,
+ client_kwargs={
+ 'endpoint_url': 'http://{}:{}'.format(host, port)
+ }
+ )
+
+ # passing as fsspec filesystem
+ dataset = ds.dataset(path, format="parquet", filesystem=fs)
+ assert dataset.to_table().equals(table)
+
+ # wrapping the fsspec filesystem in a PyFileSystem via FSSpecHandler
+ fs = PyFileSystem(FSSpecHandler(fs))
+ dataset = ds.dataset(path, format="parquet", filesystem=fs)
+ assert dataset.to_table().equals(table)
+
+
+@pytest.mark.parquet
+@pytest.mark.s3
+def test_open_dataset_from_s3_with_filesystem_uri(s3_server):
+ from pyarrow.fs import FileSystem
+ import pyarrow.parquet as pq
+
+ host, port, access_key, secret_key = s3_server['connection']
+ bucket = 'theirbucket'
+ path = 'nested/folder/data.parquet'
+ uri = "s3://{}:{}@{}/{}?scheme=http&endpoint_override={}:{}".format(
+ access_key, secret_key, bucket, path, host, port
+ )
+
+ fs, path = FileSystem.from_uri(uri)
+ assert path == 'theirbucket/nested/folder/data.parquet'
+
+ fs.create_dir(bucket)
+
+ table = pa.table({'a': [1, 2, 3]})
+ with fs.open_output_stream(path) as out:
+ pq.write_table(table, out)
+
+ # full string URI
+ dataset = ds.dataset(uri, format="parquet")
+ assert dataset.to_table().equals(table)
+
+ # passing the filesystem as a URI
+ template = (
+ "s3://{}:{}@{{}}?scheme=http&endpoint_override={}:{}".format(
+ access_key, secret_key, host, port
+ )
+ )
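+ # various splits of the same full path between the filesystem URI prefix
+ # and the dataset path passed to ds.dataset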
+ cases = [
+ ('theirbucket/nested/folder/', '/data.parquet'),
+ ('theirbucket/nested/folder', 'data.parquet'),
+ ('theirbucket/nested/', 'folder/data.parquet'),
+ ('theirbucket/nested', 'folder/data.parquet'),
+ ('theirbucket', '/nested/folder/data.parquet'),
+ ('theirbucket', 'nested/folder/data.parquet'),
+ ]
+ for prefix, path in cases:
+ uri = template.format(prefix)
+ dataset = ds.dataset(path, filesystem=uri, format="parquet")
+ assert dataset.to_table().equals(table)
+
+ with pytest.raises(pa.ArrowInvalid, match='Missing bucket name'):
+ uri = template.format('/')
+ ds.dataset('/theirbucket/nested/folder/data.parquet', filesystem=uri)
+
+ error = (
+ "The path component of the filesystem URI must point to a directory "
+ "but it has a type: `{}`. The path component is `{}` and the given "
+ "filesystem URI is `{}`"
+ )
+
+ path = 'theirbucket/doesnt/exist'
+ uri = template.format(path)
+ with pytest.raises(ValueError) as exc:
+ ds.dataset('data.parquet', filesystem=uri)
+ assert str(exc.value) == error.format('NotFound', path, uri)
+
+ path = 'theirbucket/nested/folder/data.parquet'
+ uri = template.format(path)
+ with pytest.raises(ValueError) as exc:
+ ds.dataset('data.parquet', filesystem=uri)
+ assert str(exc.value) == error.format('File', path, uri)
+
+
+@pytest.mark.parquet
+def test_open_dataset_from_fsspec(tempdir):
+ table, path = _create_single_file(tempdir)
+
+ fsspec = pytest.importorskip("fsspec")
+
+ localfs = fsspec.filesystem("file")
+ dataset = ds.dataset(path, filesystem=localfs)
+ assert dataset.schema.equals(table.schema)
+
+
+@pytest.mark.pandas
+def test_filter_timestamp(tempdir, dataset_reader):
+ # ARROW-11379
+ path = tempdir / "test_partition_timestamps"
+
+ table = pa.table({
+ "dates": ['2012-01-01', '2012-01-02'] * 5,
+ "id": range(10)})
+
+ # write dataset partitioned on dates (as strings)
+ part = ds.partitioning(table.select(['dates']).schema, flavor="hive")
+ ds.write_dataset(table, path, partitioning=part, format="feather")
+
+ # read dataset partitioned on dates (as timestamps)
+ part = ds.partitioning(pa.schema([("dates", pa.timestamp("s"))]),
+ flavor="hive")
+ dataset = ds.dataset(path, format="feather", partitioning=part)
+
+ condition = ds.field("dates") > pd.Timestamp("2012-01-01")
+ table = dataset_reader.to_table(dataset, filter=condition)
+ assert table.column('id').to_pylist() == [1, 3, 5, 7, 9]
+
+ import datetime
+ condition = ds.field("dates") > datetime.datetime(2012, 1, 1)
+ table = dataset_reader.to_table(dataset, filter=condition)
+ assert table.column('id').to_pylist() == [1, 3, 5, 7, 9]
+
+
+@pytest.mark.parquet
+def test_filter_implicit_cast(tempdir, dataset_reader):
+ # ARROW-7652
+ table = pa.table({'a': pa.array([0, 1, 2, 3, 4, 5], type=pa.int8())})
+ _, path = _create_single_file(tempdir, table)
+ dataset = ds.dataset(str(path))
+
+ filter_ = ds.field('a') > 2
+ assert len(dataset_reader.to_table(dataset, filter=filter_)) == 3
+
+
+def test_dataset_union(multisourcefs):
+ child = ds.FileSystemDatasetFactory(
+ multisourcefs, fs.FileSelector('/plain'),
+ format=ds.ParquetFileFormat()
+ )
+ factory = ds.UnionDatasetFactory([child])
+
+ # TODO(bkietz) reintroduce factory.children property
+ assert len(factory.inspect_schemas()) == 1
+ assert all(isinstance(s, pa.Schema) for s in factory.inspect_schemas())
+ assert factory.inspect_schemas()[0].equals(child.inspect())
+ assert factory.inspect().equals(child.inspect())
+ assert isinstance(factory.finish(), ds.Dataset)
+
+
+def test_union_dataset_from_other_datasets(tempdir, multisourcefs):
+ child1 = ds.dataset('/plain', filesystem=multisourcefs, format='parquet')
+ child2 = ds.dataset('/schema', filesystem=multisourcefs, format='parquet',
+ partitioning=['week', 'color'])
+ child3 = ds.dataset('/hive', filesystem=multisourcefs, format='parquet',
+ partitioning='hive')
+
+ assert child1.schema != child2.schema != child3.schema
+
+ assembled = ds.dataset([child1, child2, child3])
+ assert isinstance(assembled, ds.UnionDataset)
+
+ msg = 'cannot pass any additional arguments'
+ with pytest.raises(ValueError, match=msg):
+ ds.dataset([child1, child2], filesystem=multisourcefs)
+
+ expected_schema = pa.schema([
+ ('date', pa.date32()),
+ ('index', pa.int64()),
+ ('value', pa.float64()),
+ ('color', pa.string()),
+ ('week', pa.int32()),
+ ('year', pa.int32()),
+ ('month', pa.int32()),
+ ])
+ assert assembled.schema.equals(expected_schema)
+ assert assembled.to_table().schema.equals(expected_schema)
+
+ assembled = ds.dataset([child1, child3])
+ expected_schema = pa.schema([
+ ('date', pa.date32()),
+ ('index', pa.int64()),
+ ('value', pa.float64()),
+ ('color', pa.string()),
+ ('year', pa.int32()),
+ ('month', pa.int32()),
+ ])
+ assert assembled.schema.equals(expected_schema)
+ assert assembled.to_table().schema.equals(expected_schema)
+
+ expected_schema = pa.schema([
+ ('month', pa.int32()),
+ ('color', pa.string()),
+ ('date', pa.date32()),
+ ])
+ assembled = ds.dataset([child1, child3], schema=expected_schema)
+ assert assembled.to_table().schema.equals(expected_schema)
+
+ expected_schema = pa.schema([
+ ('month', pa.int32()),
+ ('color', pa.string()),
+ ('unknown', pa.string()) # fill with nulls
+ ])
+ assembled = ds.dataset([child1, child3], schema=expected_schema)
+ assert assembled.to_table().schema.equals(expected_schema)
+
+ # incompatible schemas, date and index columns have conflicting types
+ table = pa.table([range(9), [0.] * 4 + [1.] * 5, 'abcdefghj'],
+ names=['date', 'value', 'index'])
+ _, path = _create_single_file(tempdir, table=table)
+ child4 = ds.dataset(path)
+
+ with pytest.raises(pa.ArrowInvalid, match='Unable to merge'):
+ ds.dataset([child1, child4])
+
+
+def test_dataset_from_a_list_of_local_directories_raises(multisourcefs):
+ msg = 'points to a directory, but only file paths are supported'
+ with pytest.raises(IsADirectoryError, match=msg):
+ ds.dataset(['/plain', '/schema', '/hive'], filesystem=multisourcefs)
+
+
+def test_union_dataset_filesystem_datasets(multisourcefs):
+ # without partitioning
+ dataset = ds.dataset([
+ ds.dataset('/plain', filesystem=multisourcefs),
+ ds.dataset('/schema', filesystem=multisourcefs),
+ ds.dataset('/hive', filesystem=multisourcefs),
+ ])
+ expected_schema = pa.schema([
+ ('date', pa.date32()),
+ ('index', pa.int64()),
+ ('value', pa.float64()),
+ ('color', pa.string()),
+ ])
+ assert dataset.schema.equals(expected_schema)
+
+ # with hive partitioning for two hive sources
+ dataset = ds.dataset([
+ ds.dataset('/plain', filesystem=multisourcefs),
+ ds.dataset('/schema', filesystem=multisourcefs),
+ ds.dataset('/hive', filesystem=multisourcefs, partitioning='hive')
+ ])
+ expected_schema = pa.schema([
+ ('date', pa.date32()),
+ ('index', pa.int64()),
+ ('value', pa.float64()),
+ ('color', pa.string()),
+ ('year', pa.int32()),
+ ('month', pa.int32()),
+ ])
+ assert dataset.schema.equals(expected_schema)
+
+
+@pytest.mark.parquet
+def test_specified_schema(tempdir, dataset_reader):
+ import pyarrow.parquet as pq
+
+ table = pa.table({'a': [1, 2, 3], 'b': [.1, .2, .3]})
+ pq.write_table(table, tempdir / "data.parquet")
+
+ def _check_dataset(schema, expected, expected_schema=None):
+ dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema)
+ if expected_schema is not None:
+ assert dataset.schema.equals(expected_schema)
+ else:
+ assert dataset.schema.equals(schema)
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(expected)
+
+ # no schema specified
+ schema = None
+ expected = table
+ _check_dataset(schema, expected, expected_schema=table.schema)
+
+ # identical schema specified
+ schema = table.schema
+ expected = table
+ _check_dataset(schema, expected)
+
+ # Specifying schema with changed column order
+ schema = pa.schema([('b', 'float64'), ('a', 'int64')])
+ expected = pa.table([[.1, .2, .3], [1, 2, 3]], names=['b', 'a'])
+ _check_dataset(schema, expected)
+
+ # Specifying schema with missing column
+ schema = pa.schema([('a', 'int64')])
+ expected = pa.table([[1, 2, 3]], names=['a'])
+ _check_dataset(schema, expected)
+
+ # Specifying schema with additional column
+ schema = pa.schema([('a', 'int64'), ('c', 'int32')])
+ expected = pa.table([[1, 2, 3],
+ pa.array([None, None, None], type='int32')],
+ names=['a', 'c'])
+ _check_dataset(schema, expected)
+
+ # Specifying with differing field types
+ schema = pa.schema([('a', 'int32'), ('b', 'float64')])
+ dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema)
+ expected = pa.table([table['a'].cast('int32'),
+ table['b']],
+ names=['a', 'b'])
+ _check_dataset(schema, expected)
+
+ # Specifying with incompatible schema
+ schema = pa.schema([('a', pa.list_(pa.int32())), ('b', 'float64')])
+ dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema)
+ assert dataset.schema.equals(schema)
+ with pytest.raises(NotImplementedError,
+ match='Unsupported cast from int64 to list'):
+ dataset_reader.to_table(dataset)
+
+
+@pytest.mark.parquet
+def test_incompatible_schema_hang(tempdir, dataset_reader):
+ # ARROW-13480: deadlock when reading past an errored fragment
+ import pyarrow.parquet as pq
+
+ fn = tempdir / "data.parquet"
+ table = pa.table({'a': [1, 2, 3]})
+ pq.write_table(table, fn)
+
+ schema = pa.schema([('a', pa.null())])
+ dataset = ds.dataset([str(fn)] * 100, schema=schema)
+ assert dataset.schema.equals(schema)
+ scanner = dataset_reader.scanner(dataset)
+ reader = scanner.to_reader()
+ with pytest.raises(NotImplementedError,
+ match='Unsupported cast from int64 to null'):
+ reader.read_all()
+
+
+def test_ipc_format(tempdir, dataset_reader):
+ table = pa.table({'a': pa.array([1, 2, 3], type="int8"),
+ 'b': pa.array([.1, .2, .3], type="float64")})
+
+ path = str(tempdir / 'test.arrow')
+ with pa.output_stream(path) as sink:
+ writer = pa.RecordBatchFileWriter(sink, table.schema)
+ writer.write_batch(table.to_batches()[0])
+ writer.close()
+
+ dataset = ds.dataset(path, format=ds.IpcFileFormat())
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(table)
+
+ for format_str in ["ipc", "arrow"]:
+ dataset = ds.dataset(path, format=format_str)
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(table)
+
+
+@pytest.mark.orc
+def test_orc_format(tempdir, dataset_reader):
+ from pyarrow import orc
+ table = pa.table({'a': pa.array([1, 2, 3], type="int8"),
+ 'b': pa.array([.1, .2, .3], type="float64")})
+
+ path = str(tempdir / 'test.orc')
+ orc.write_table(table, path)
+
+ dataset = ds.dataset(path, format=ds.OrcFileFormat())
+ result = dataset_reader.to_table(dataset)
+ result.validate(full=True)
+ assert result.equals(table)
+
+ dataset = ds.dataset(path, format="orc")
+ result = dataset_reader.to_table(dataset)
+ result.validate(full=True)
+ assert result.equals(table)
+
+ result = dataset_reader.to_table(dataset, columns=["b"])
+ result.validate(full=True)
+ assert result.equals(table.select(["b"]))
+
+ result = dataset_reader.to_table(
+ dataset, columns={"b2": ds.field("b") * 2}
+ )
+ result.validate(full=True)
+ assert result.equals(
+ pa.table({'b2': pa.array([.2, .4, .6], type="float64")})
+ )
+
+ assert dataset_reader.count_rows(dataset) == 3
+ assert dataset_reader.count_rows(dataset, filter=ds.field("a") > 2) == 1
+
+
+@pytest.mark.orc
+def test_orc_scan_options(tempdir, dataset_reader):
+ from pyarrow import orc
+ table = pa.table({'a': pa.array([1, 2, 3], type="int8"),
+ 'b': pa.array([.1, .2, .3], type="float64")})
+
+ path = str(tempdir / 'test.orc')
+ orc.write_table(table, path)
+
+ dataset = ds.dataset(path, format="orc")
+ result = list(dataset_reader.to_batches(dataset))
+ assert len(result) == 1
+ assert result[0].num_rows == 3
+ assert result[0].equals(table.to_batches()[0])
+ # TODO batch_size is not yet supported (ARROW-14153)
+ # result = list(dataset_reader.to_batches(dataset, batch_size=2))
+ # assert len(result) == 2
+ # assert result[0].num_rows == 2
+ # assert result[0].equals(table.slice(0, 2).to_batches()[0])
+ # assert result[1].num_rows == 1
+ # assert result[1].equals(table.slice(2, 1).to_batches()[0])
+
+
+def test_orc_format_not_supported():
+ try:
+ from pyarrow.dataset import OrcFileFormat # noqa
+ except (ImportError, AttributeError):
+ # catch AttributeError for Python 3.6
+ # ORC is not available, test error message
+ with pytest.raises(
+ ValueError, match="not built with support for the ORC file"
+ ):
+ ds.dataset(".", format="orc")
+
+
+@pytest.mark.pandas
+def test_csv_format(tempdir, dataset_reader):
+ table = pa.table({'a': pa.array([1, 2, 3], type="int64"),
+ 'b': pa.array([.1, .2, .3], type="float64")})
+
+ path = str(tempdir / 'test.csv')
+ table.to_pandas().to_csv(path, index=False)
+
+ dataset = ds.dataset(path, format=ds.CsvFileFormat())
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(table)
+
+ dataset = ds.dataset(path, format='csv')
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(table)
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize("compression", [
+ "bz2",
+ "gzip",
+ "lz4",
+ "zstd",
+])
+def test_csv_format_compressed(tempdir, compression, dataset_reader):
+ if not pyarrow.Codec.is_available(compression):
+ pytest.skip("{} support is not built".format(compression))
+ table = pa.table({'a': pa.array([1, 2, 3], type="int64"),
+ 'b': pa.array([.1, .2, .3], type="float64")})
+ filesystem = fs.LocalFileSystem()
+ suffix = compression if compression != 'gzip' else 'gz'
+ path = str(tempdir / f'test.csv.{suffix}')
+ with filesystem.open_output_stream(path, compression=compression) as sink:
+ # https://github.com/pandas-dev/pandas/issues/23854
+ # With the Pandas version used on CI (anything < 1.2), Pandas tries to
+ # write a str to the binary sink, so encode the CSV manually
+ csv_str = table.to_pandas().to_csv(index=False)
+ sink.write(csv_str.encode('utf-8'))
+
+ dataset = ds.dataset(path, format=ds.CsvFileFormat())
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(table)
+
+
+def test_csv_format_options(tempdir, dataset_reader):
+ path = str(tempdir / 'test.csv')
+ with open(path, 'w') as sink:
+ sink.write('skipped\ncol0\nfoo\nbar\n')
+ dataset = ds.dataset(path, format='csv')
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(
+ pa.table({'skipped': pa.array(['col0', 'foo', 'bar'])}))
+
+ dataset = ds.dataset(path, format=ds.CsvFileFormat(
+ read_options=pa.csv.ReadOptions(skip_rows=1)))
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(pa.table({'col0': pa.array(['foo', 'bar'])}))
+
+ dataset = ds.dataset(path, format=ds.CsvFileFormat(
+ read_options=pa.csv.ReadOptions(column_names=['foo'])))
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(
+ pa.table({'foo': pa.array(['skipped', 'col0', 'foo', 'bar'])}))
+
+
+def test_csv_fragment_options(tempdir, dataset_reader):
+ path = str(tempdir / 'test.csv')
+ with open(path, 'w') as sink:
+ sink.write('col0\nfoo\nspam\nMYNULL\n')
+ dataset = ds.dataset(path, format='csv')
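+ # scan-time options: treat the literal string 'MYNULL' as a null value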
+ convert_options = pyarrow.csv.ConvertOptions(null_values=['MYNULL'],
+ strings_can_be_null=True)
+ options = ds.CsvFragmentScanOptions(
+ convert_options=convert_options,
+ read_options=pa.csv.ReadOptions(block_size=2**16))
+ result = dataset_reader.to_table(dataset, fragment_scan_options=options)
+ assert result.equals(pa.table({'col0': pa.array(['foo', 'spam', None])}))
+
+ csv_format = ds.CsvFileFormat(convert_options=convert_options)
+ dataset = ds.dataset(path, format=csv_format)
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(pa.table({'col0': pa.array(['foo', 'spam', None])}))
+
+ options = ds.CsvFragmentScanOptions()
+ result = dataset_reader.to_table(dataset, fragment_scan_options=options)
+ assert result.equals(
+ pa.table({'col0': pa.array(['foo', 'spam', 'MYNULL'])}))
+
+
+def test_feather_format(tempdir, dataset_reader):
+ from pyarrow.feather import write_feather
+
+ table = pa.table({'a': pa.array([1, 2, 3], type="int8"),
+ 'b': pa.array([.1, .2, .3], type="float64")})
+
+ basedir = tempdir / "feather_dataset"
+ basedir.mkdir()
+ write_feather(table, str(basedir / "data.feather"))
+
+ dataset = ds.dataset(basedir, format=ds.IpcFileFormat())
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(table)
+
+ dataset = ds.dataset(basedir, format="feather")
+ result = dataset_reader.to_table(dataset)
+ assert result.equals(table)
+
+ # ARROW-8641 - column selection order
+ result = dataset_reader.to_table(dataset, columns=["b", "a"])
+ assert result.column_names == ["b", "a"]
+ result = dataset_reader.to_table(dataset, columns=["a", "a"])
+ assert result.column_names == ["a", "a"]
+
+ # error with Feather v1 files
+ write_feather(table, str(basedir / "data1.feather"), version=1)
+ with pytest.raises(ValueError):
+ dataset_reader.to_table(ds.dataset(basedir, format="feather"))
+
+
+def _create_parquet_dataset_simple(root_path):
+ """
+ Creates a simple (flat files, no nested partitioning) Parquet dataset
+ """
+ import pyarrow.parquet as pq
+
+ metadata_collector = []
+
+ for i in range(4):
+ table = pa.table({'f1': [i] * 10, 'f2': np.random.randn(10)})
+ pq.write_to_dataset(
+ table, str(root_path), metadata_collector=metadata_collector
+ )
+
+ metadata_path = str(root_path / '_metadata')
+ # write _metadata file
+ pq.write_metadata(
+ table.schema, metadata_path,
+ metadata_collector=metadata_collector
+ )
+ return metadata_path, table
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas # write_to_dataset currently requires pandas
+def test_parquet_dataset_factory(tempdir):
+ root_path = tempdir / "test_parquet_dataset"
+ metadata_path, table = _create_parquet_dataset_simple(root_path)
+ dataset = ds.parquet_dataset(metadata_path)
+ assert dataset.schema.equals(table.schema)
+ assert len(dataset.files) == 4
+ result = dataset.to_table()
+ assert result.num_rows == 40
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas # write_to_dataset currently requires pandas
+@pytest.mark.parametrize('use_legacy_dataset', [False, True])
+def test_parquet_dataset_factory_roundtrip(tempdir, use_legacy_dataset):
+ # Simple test to ensure we can roundtrip dataset to
+ # _metadata/common_metadata and back. A more complex test
+ # using partitioning will have to wait for ARROW-13269. The
+ # above test (test_parquet_dataset_factory) will not work
+ # when legacy is False as there is no "append" equivalent in
+ # the new dataset until ARROW-12358
+ import pyarrow.parquet as pq
+ root_path = tempdir / "test_parquet_dataset"
+ table = pa.table({'f1': [0] * 10, 'f2': np.random.randn(10)})
+ metadata_collector = []
+ pq.write_to_dataset(
+ table, str(root_path), metadata_collector=metadata_collector,
+ use_legacy_dataset=use_legacy_dataset
+ )
+ metadata_path = str(root_path / '_metadata')
+ # write _metadata file
+ pq.write_metadata(
+ table.schema, metadata_path,
+ metadata_collector=metadata_collector
+ )
+ dataset = ds.parquet_dataset(metadata_path)
+ assert dataset.schema.equals(table.schema)
+ result = dataset.to_table()
+ assert result.num_rows == 10
+
+
+def test_parquet_dataset_factory_order(tempdir):
+ # The order of the fragments in the dataset should match the order of the
+ # row groups in the _metadata file.
+ import pyarrow.parquet as pq
+ metadatas = []
+ # Create a dataset where f1 is incrementing from 0 to 100 spread across
+ # 10 files. Put the row groups in the correct order in _metadata
+ for i in range(10):
+ table = pa.table(
+ {'f1': list(range(i*10, (i+1)*10))})
+ table_path = tempdir / f'{i}.parquet'
+ pq.write_table(table, table_path, metadata_collector=metadatas)
+ metadatas[-1].set_file_path(f'{i}.parquet')
+ metadata_path = str(tempdir / '_metadata')
+ pq.write_metadata(table.schema, metadata_path, metadatas)
+ dataset = ds.parquet_dataset(metadata_path)
+ # Ensure the table contains values from 0-100 in the right order
+ scanned_table = dataset.to_table()
+ scanned_col = scanned_table.column('f1').to_pylist()
+ assert scanned_col == list(range(0, 100))
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_parquet_dataset_factory_invalid(tempdir):
+ root_path = tempdir / "test_parquet_dataset_invalid"
+ metadata_path, table = _create_parquet_dataset_simple(root_path)
+ # remove one of the files
+ list(root_path.glob("*.parquet"))[0].unlink()
+ dataset = ds.parquet_dataset(metadata_path)
+ assert dataset.schema.equals(table.schema)
+ assert len(dataset.files) == 4
+ with pytest.raises(FileNotFoundError):
+ dataset.to_table()
+
+
+def _create_metadata_file(root_path):
+ # create _metadata file from existing parquet dataset
+ import pyarrow.parquet as pq
+
+ parquet_paths = list(sorted(root_path.rglob("*.parquet")))
+ schema = pq.ParquetFile(parquet_paths[0]).schema.to_arrow_schema()
+
+ metadata_collector = []
+ for path in parquet_paths:
+ metadata = pq.ParquetFile(path).metadata
+ metadata.set_file_path(str(path.relative_to(root_path)))
+ metadata_collector.append(metadata)
+
+ metadata_path = root_path / "_metadata"
+ pq.write_metadata(
+ schema, metadata_path, metadata_collector=metadata_collector
+ )
+ return metadata_path
+
+
+def _create_parquet_dataset_partitioned(root_path):
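+ # Write a 20-row table partitioned on "part" (values 'a' and 'b'), generate a
+ # _metadata file for it, and return the metadata path plus the source table.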
+ import pyarrow.parquet as pq
+
+ table = pa.table([
+ pa.array(range(20)), pa.array(np.random.randn(20)),
+ pa.array(np.repeat(['a', 'b'], 10))],
+ names=["f1", "f2", "part"]
+ )
+ table = table.replace_schema_metadata({"key": "value"})
+ pq.write_to_dataset(table, str(root_path), partition_cols=['part'])
+ return _create_metadata_file(root_path), table
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_parquet_dataset_factory_partitioned(tempdir):
+ root_path = tempdir / "test_parquet_dataset_factory_partitioned"
+ metadata_path, table = _create_parquet_dataset_partitioned(root_path)
+
+ partitioning = ds.partitioning(flavor="hive")
+ dataset = ds.parquet_dataset(metadata_path, partitioning=partitioning)
+
+ assert dataset.schema.equals(table.schema)
+ assert len(dataset.files) == 2
+ result = dataset.to_table()
+ assert result.num_rows == 20
+
+ # the partitioned dataset does not preserve order
+ result = result.to_pandas().sort_values("f1").reset_index(drop=True)
+ expected = table.to_pandas()
+ pd.testing.assert_frame_equal(result, expected)
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_parquet_dataset_factory_metadata(tempdir):
+ # ensure ParquetDatasetFactory preserves metadata (ARROW-9363)
+ root_path = tempdir / "test_parquet_dataset_factory_metadata"
+ metadata_path, table = _create_parquet_dataset_partitioned(root_path)
+
+ dataset = ds.parquet_dataset(metadata_path, partitioning="hive")
+ assert dataset.schema.equals(table.schema)
+ assert b"key" in dataset.schema.metadata
+
+ fragments = list(dataset.get_fragments())
+ assert b"key" in fragments[0].physical_schema.metadata
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_parquet_dataset_lazy_filtering(tempdir, open_logging_fs):
+ fs, assert_opens = open_logging_fs
+
+ # Test to ensure that no IO happens when filtering a dataset
+ # created with ParquetDatasetFactory from a _metadata file
+
+ root_path = tempdir / "test_parquet_dataset_lazy_filtering"
+ metadata_path, _ = _create_parquet_dataset_simple(root_path)
+
+ # creating the dataset should only open the metadata file
+ with assert_opens([metadata_path]):
+ dataset = ds.parquet_dataset(
+ metadata_path,
+ partitioning=ds.partitioning(flavor="hive"),
+ filesystem=fs)
+
+ # materializing fragments should not open any file
+ with assert_opens([]):
+ fragments = list(dataset.get_fragments())
+
+ # filtering fragments should not open any file
+ with assert_opens([]):
+ list(dataset.get_fragments(ds.field("f1") > 15))
+
+ # splitting by row group should still not open any file
+ with assert_opens([]):
+ fragments[0].split_by_row_group(ds.field("f1") > 15)
+
+ # ensuring metadata of a split fragment should also not open any file
+ with assert_opens([]):
+ rg_fragments = fragments[0].split_by_row_group()
+ rg_fragments[0].ensure_complete_metadata()
+
+ # FIXME(bkietz) on Windows this results in FileNotFoundErrors.
+ # but actually scanning does open files
+ # with assert_opens([f.path for f in fragments]):
+ # dataset.to_table()
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_dataset_schema_metadata(tempdir, dataset_reader):
+ # ARROW-8802
+ df = pd.DataFrame({'a': [1, 2, 3]})
+ path = tempdir / "test.parquet"
+ df.to_parquet(path)
+ dataset = ds.dataset(path)
+
+ schema = dataset_reader.to_table(dataset).schema
+ projected_schema = dataset_reader.to_table(dataset, columns=["a"]).schema
+
+ # ensure the pandas metadata is included in the schema
+ assert b"pandas" in schema.metadata
+ # ensure it is still there in a projected schema (with column selection)
+ assert schema.equals(projected_schema, check_metadata=True)
+
+
+@pytest.mark.parquet
+def test_filter_mismatching_schema(tempdir, dataset_reader):
+ # ARROW-9146
+ import pyarrow.parquet as pq
+
+ table = pa.table({"col": pa.array([1, 2, 3, 4], type='int32')})
+ pq.write_table(table, str(tempdir / "data.parquet"))
+
+ # specifying explicit schema, but that mismatches the schema of the data
+ schema = pa.schema([("col", pa.int64())])
+ dataset = ds.dataset(
+ tempdir / "data.parquet", format="parquet", schema=schema)
+
+ # filtering on a column with such type mismatch should implicitly
+ # cast the column
+ filtered = dataset_reader.to_table(dataset, filter=ds.field("col") > 2)
+ assert filtered["col"].equals(table["col"].cast('int64').slice(2))
+
+ fragment = list(dataset.get_fragments())[0]
+ filtered = dataset_reader.to_table(
+ fragment, filter=ds.field("col") > 2, schema=schema)
+ assert filtered["col"].equals(table["col"].cast('int64').slice(2))
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_dataset_project_only_partition_columns(tempdir, dataset_reader):
+ # ARROW-8729
+ import pyarrow.parquet as pq
+
+ table = pa.table({'part': 'a a b b'.split(), 'col': list(range(4))})
+
+ path = str(tempdir / 'test_dataset')
+ pq.write_to_dataset(table, path, partition_cols=['part'])
+ dataset = ds.dataset(path, partitioning='hive')
+
+ all_cols = dataset_reader.to_table(dataset)
+ part_only = dataset_reader.to_table(dataset, columns=['part'])
+
+ assert all_cols.column('part').equals(part_only.column('part'))
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_dataset_project_null_column(tempdir, dataset_reader):
+ import pandas as pd
+ df = pd.DataFrame({"col": np.array([None, None, None], dtype='object')})
+
+ f = tempdir / "test_dataset_project_null_column.parquet"
+ df.to_parquet(f, engine="pyarrow")
+
+ dataset = ds.dataset(f, format="parquet",
+ schema=pa.schema([("col", pa.int64())]))
+ expected = pa.table({'col': pa.array([None, None, None], pa.int64())})
+ assert dataset_reader.to_table(dataset).equals(expected)
+
+
+def test_dataset_project_columns(tempdir, dataset_reader):
+ # basic column re-projection with expressions
+ from pyarrow import feather
+ table = pa.table({"A": [1, 2, 3], "B": [1., 2., 3.], "C": ["a", "b", "c"]})
+ feather.write_feather(table, tempdir / "data.feather")
+
+ dataset = ds.dataset(tempdir / "data.feather", format="feather")
+ result = dataset_reader.to_table(dataset, columns={
+ 'A_renamed': ds.field('A'),
+ 'B_as_int': ds.field('B').cast("int32", safe=False),
+ 'C_is_a': ds.field('C') == 'a'
+ })
+ expected = pa.table({
+ "A_renamed": [1, 2, 3],
+ "B_as_int": pa.array([1, 2, 3], type="int32"),
+ "C_is_a": [True, False, False],
+ })
+ assert result.equals(expected)
+
+ # raise proper error when not passing an expression
+ with pytest.raises(TypeError, match="Expected an Expression"):
+ dataset_reader.to_table(dataset, columns={"A": "A"})
+
+
+@pytest.mark.pandas
+@pytest.mark.parquet
+def test_dataset_preserved_partitioning(tempdir):
+ # ARROW-8655
+
+ # through discovery, but without partitioning
+ _, path = _create_single_file(tempdir)
+ dataset = ds.dataset(path)
+ assert dataset.partitioning is None
+
+ # through discovery, with hive partitioning but not specified
+ full_table, path = _create_partitioned_dataset(tempdir)
+ dataset = ds.dataset(path)
+ assert dataset.partitioning is None
+
+ # through discovery, with hive partitioning (from a partitioning factory)
+ dataset = ds.dataset(path, partitioning="hive")
+ part = dataset.partitioning
+ assert part is not None
+ assert isinstance(part, ds.HivePartitioning)
+ assert part.schema == pa.schema([("part", pa.int32())])
+ assert len(part.dictionaries) == 1
+ assert part.dictionaries[0] == pa.array([0, 1, 2], pa.int32())
+
+ # through discovery, with hive partitioning (from a partitioning object)
+ part = ds.partitioning(pa.schema([("part", pa.int32())]), flavor="hive")
+ assert isinstance(part, ds.HivePartitioning) # not a factory
+ assert part.dictionaries is None
+ dataset = ds.dataset(path, partitioning=part)
+ part = dataset.partitioning
+ assert isinstance(part, ds.HivePartitioning)
+ assert part.schema == pa.schema([("part", pa.int32())])
+ # TODO is this expected?
+ assert part.dictionaries is None
+
+ # through manual creation -> not available
+ dataset = ds.dataset(path, partitioning="hive")
+ dataset2 = ds.FileSystemDataset(
+ list(dataset.get_fragments()), schema=dataset.schema,
+ format=dataset.format, filesystem=dataset.filesystem
+ )
+ assert dataset2.partitioning is None
+
+ # through discovery with ParquetDatasetFactory
+ root_path = tempdir / "data-partitioned-metadata"
+ metadata_path, _ = _create_parquet_dataset_partitioned(root_path)
+ dataset = ds.parquet_dataset(metadata_path, partitioning="hive")
+ part = dataset.partitioning
+ assert part is not None
+ assert isinstance(part, ds.HivePartitioning)
+ assert part.schema == pa.schema([("part", pa.string())])
+ assert len(part.dictionaries) == 1
+ # will be fixed by ARROW-13153 (order is not preserved at the moment)
+ # assert part.dictionaries[0] == pa.array(["a", "b"], pa.string())
+ assert set(part.dictionaries[0].to_pylist()) == {"a", "b"}
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_write_to_dataset_given_null_just_works(tempdir):
+ import pyarrow.parquet as pq
+
+ schema = pa.schema([
+ pa.field('col', pa.int64()),
+ pa.field('part', pa.dictionary(pa.int32(), pa.string()))
+ ])
+ table = pa.table({'part': [None, None, 'a', 'a'],
+ 'col': list(range(4))}, schema=schema)
+
+ path = str(tempdir / 'test_dataset')
+ pq.write_to_dataset(table, path, partition_cols=[
+ 'part'], use_legacy_dataset=False)
+
+ actual_table = pq.read_table(tempdir / 'test_dataset')
+ # column.equals can handle the difference in chunking but not the fact
+ # that `part` will have different dictionaries for the two chunks
+ assert actual_table.column('part').to_pylist(
+ ) == table.column('part').to_pylist()
+ assert actual_table.column('col').equals(table.column('col'))
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_legacy_write_to_dataset_drops_null(tempdir):
+ import pyarrow.parquet as pq
+
+ schema = pa.schema([
+ pa.field('col', pa.int64()),
+ pa.field('part', pa.dictionary(pa.int32(), pa.string()))
+ ])
+ table = pa.table({'part': ['a', 'a', None, None],
+ 'col': list(range(4))}, schema=schema)
+ expected = pa.table(
+ {'part': ['a', 'a'], 'col': list(range(2))}, schema=schema)
+
+ path = str(tempdir / 'test_dataset')
+ pq.write_to_dataset(table, path, partition_cols=[
+ 'part'], use_legacy_dataset=True)
+
+ actual = pq.read_table(tempdir / 'test_dataset')
+ assert actual == expected
+
+
+def _sort_table(tab, sort_col):
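+ # Sort `tab` ascending on `sort_col` so tables can be compared ignoring row order.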
+ import pyarrow.compute as pc
+ sorted_indices = pc.sort_indices(
+ tab, options=pc.SortOptions([(sort_col, 'ascending')]))
+ return pc.take(tab, sorted_indices)
+
+
+def _check_dataset_roundtrip(dataset, base_dir, expected_files, sort_col,
+ base_dir_path=None, partitioning=None):
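+ # Write `dataset` to `base_dir` as feather, check that the expected files were
+ # created, then read it back and compare the (sorted) contents.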
+ base_dir_path = base_dir_path or base_dir
+
+ ds.write_dataset(dataset, base_dir, format="feather",
+ partitioning=partitioning, use_threads=False)
+
+ # check that all files are present
+ file_paths = list(base_dir_path.rglob("*"))
+ assert set(file_paths) == set(expected_files)
+
+ # check that reading back in as dataset gives the same result
+ dataset2 = ds.dataset(
+ base_dir_path, format="feather", partitioning=partitioning)
+
+ assert _sort_table(dataset2.to_table(), sort_col).equals(
+ _sort_table(dataset.to_table(), sort_col))
+
+
+@pytest.mark.parquet
+def test_write_dataset(tempdir):
+ # manually create a written dataset and read as dataset object
+ directory = tempdir / 'single-file'
+ directory.mkdir()
+ _ = _create_single_file(directory)
+ dataset = ds.dataset(directory)
+
+ # full string path
+ target = tempdir / 'single-file-target'
+ expected_files = [target / "part-0.feather"]
+ _check_dataset_roundtrip(dataset, str(target), expected_files, 'a', target)
+
+ # pathlib path object
+ target = tempdir / 'single-file-target2'
+ expected_files = [target / "part-0.feather"]
+ _check_dataset_roundtrip(dataset, target, expected_files, 'a', target)
+
+ # TODO
+ # # relative path
+ # target = tempdir / 'single-file-target3'
+ # expected_files = [target / "part-0.ipc"]
+ # _check_dataset_roundtrip(
+ # dataset, './single-file-target3', expected_files, target)
+
+ # Directory of files
+ directory = tempdir / 'single-directory'
+ directory.mkdir()
+ _ = _create_directory_of_files(directory)
+ dataset = ds.dataset(directory)
+
+ target = tempdir / 'single-directory-target'
+ expected_files = [target / "part-0.feather"]
+ _check_dataset_roundtrip(dataset, str(target), expected_files, 'a', target)
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_write_dataset_partitioned(tempdir):
+ directory = tempdir / "partitioned"
+ _ = _create_parquet_dataset_partitioned(directory)
+ partitioning = ds.partitioning(flavor="hive")
+ dataset = ds.dataset(directory, partitioning=partitioning)
+
+ # hive partitioning
+ target = tempdir / 'partitioned-hive-target'
+ expected_paths = [
+ target / "part=a", target / "part=a" / "part-0.feather",
+ target / "part=b", target / "part=b" / "part-0.feather"
+ ]
+ partitioning_schema = ds.partitioning(
+ pa.schema([("part", pa.string())]), flavor="hive")
+ _check_dataset_roundtrip(
+ dataset, str(target), expected_paths, 'f1', target,
+ partitioning=partitioning_schema)
+
+ # directory partitioning
+ target = tempdir / 'partitioned-dir-target'
+ expected_paths = [
+ target / "a", target / "a" / "part-0.feather",
+ target / "b", target / "b" / "part-0.feather"
+ ]
+ partitioning_schema = ds.partitioning(
+ pa.schema([("part", pa.string())]))
+ _check_dataset_roundtrip(
+ dataset, str(target), expected_paths, 'f1', target,
+ partitioning=partitioning_schema)
+
+
+def test_write_dataset_with_field_names(tempdir):
+ table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z']})
+
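+ # a plain list of field names defaults to directory (non-hive) partitioning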
+ ds.write_dataset(table, tempdir, format='parquet',
+ partitioning=["b"])
+
+ load_back = ds.dataset(tempdir, partitioning=["b"])
+ files = load_back.files
+ partitioning_dirs = {
+ str(pathlib.Path(f).relative_to(tempdir).parent) for f in files
+ }
+ assert partitioning_dirs == {"x", "y", "z"}
+
+ load_back_table = load_back.to_table()
+ assert load_back_table.equals(table)
+
+
+def test_write_dataset_with_field_names_hive(tempdir):
+ table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z']})
+
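+ # partitioning_flavor="hive" switches the same field list to key=value directories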
+ ds.write_dataset(table, tempdir, format='parquet',
+ partitioning=["b"], partitioning_flavor="hive")
+
+ load_back = ds.dataset(tempdir, partitioning="hive")
+ files = load_back.files
+ partitioning_dirs = {
+ str(pathlib.Path(f).relative_to(tempdir).parent) for f in files
+ }
+ assert partitioning_dirs == {"b=x", "b=y", "b=z"}
+
+ load_back_table = load_back.to_table()
+ assert load_back_table.equals(table)
+
+
+def test_write_dataset_with_scanner(tempdir):
+ table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z'],
+ 'c': [1, 2, 3]})
+
+ ds.write_dataset(table, tempdir, format='parquet',
+ partitioning=["b"])
+
+ dataset = ds.dataset(tempdir, partitioning=["b"])
+
+ with tempfile.TemporaryDirectory() as tempdir2:
+ ds.write_dataset(dataset.scanner(columns=["b", "c"], use_async=True),
+ tempdir2, format='parquet', partitioning=["b"])
+
+ load_back = ds.dataset(tempdir2, partitioning=["b"])
+ load_back_table = load_back.to_table()
+ assert dict(load_back_table.to_pydict()
+ ) == table.drop(["a"]).to_pydict()
+
+
+def test_write_dataset_with_backpressure(tempdir):
+ consumer_gate = threading.Event()
+
+ # A filesystem that blocks all writes so that we can build
+ # up backpressure. The writes are released at the end of
+ # the test.
+ class GatingFs(ProxyHandler):
+ def open_output_stream(self, path, metadata):
+ # Block until the end of the test
+ consumer_gate.wait()
+ return self._fs.open_output_stream(path, metadata=metadata)
+ gating_fs = fs.PyFileSystem(GatingFs(fs.LocalFileSystem()))
+
+ schema = pa.schema([pa.field('data', pa.int32())])
+ # By default, the dataset writer will queue up 64Mi rows so
+ # with batches of 1M it should only fit ~67 batches
+ batch = pa.record_batch([pa.array(list(range(1_000_000)))], schema=schema)
+ batches_read = 0
+ min_backpressure = 67
+ end = 200
+
+ def counting_generator():
+ nonlocal batches_read
+ while batches_read < end:
+ time.sleep(0.01)
+ batches_read += 1
+ yield batch
+
+ scanner = ds.Scanner.from_batches(
+ counting_generator(), schema=schema, use_threads=True,
+ use_async=True)
+
+ write_thread = threading.Thread(
+ target=lambda: ds.write_dataset(
+ scanner, str(tempdir), format='parquet', filesystem=gating_fs))
+ write_thread.start()
+
+ try:
+ start = time.time()
+
+ def duration():
+ return time.time() - start
+
+ # This test is timing dependent. There is no signal from the C++
+ # when backpressure has been hit. We don't know exactly when
+ # backpressure will be hit because it may take some time for the
+ # signal to get from the sink to the scanner.
+ #
+ # The test may emit false positives on slow systems. It could
+ # theoretically emit a false negative if the scanner managed to read
+ # and emit all 200 batches before the backpressure signal had a chance
+ # to propagate but the 0.01s delay in the generator should make that
+ # scenario unlikely.
+ last_value = 0
+ backpressure_probably_hit = False
+ while duration() < 10:
+ if batches_read > min_backpressure:
+ if batches_read == last_value:
+ backpressure_probably_hit = True
+ break
+ last_value = batches_read
+ time.sleep(0.5)
+
+ assert backpressure_probably_hit
+
+ finally:
+ consumer_gate.set()
+ write_thread.join()
+ assert batches_read == end
+
+
+def test_write_dataset_with_dataset(tempdir):
+ table = pa.table({'b': ['x', 'y', 'z'], 'c': [1, 2, 3]})
+
+ ds.write_dataset(table, tempdir, format='parquet',
+ partitioning=["b"])
+
+ dataset = ds.dataset(tempdir, partitioning=["b"])
+
+ with tempfile.TemporaryDirectory() as tempdir2:
+ ds.write_dataset(dataset, tempdir2,
+ format='parquet', partitioning=["b"])
+
+ load_back = ds.dataset(tempdir2, partitioning=["b"])
+ load_back_table = load_back.to_table()
+ assert dict(load_back_table.to_pydict()) == table.to_pydict()
+
+
+@pytest.mark.pandas
+def test_write_dataset_existing_data(tempdir):
+ directory = tempdir / 'ds'
+ table = pa.table({'b': ['x', 'y', 'z'], 'c': [1, 2, 3]})
+ partitioning = ds.partitioning(schema=pa.schema(
+ [pa.field('c', pa.int64())]), flavor='hive')
+
+ def compare_tables_ignoring_order(t1, t2):
+ df1 = t1.to_pandas().sort_values('b').reset_index(drop=True)
+ df2 = t2.to_pandas().sort_values('b').reset_index(drop=True)
+ assert df1.equals(df2)
+
+ # First write is ok
+ ds.write_dataset(table, directory, partitioning=partitioning, format='ipc')
+
+ table = pa.table({'b': ['a', 'b', 'c'], 'c': [2, 3, 4]})
+
+ # Second write should fail
+ with pytest.raises(pa.ArrowInvalid):
+ ds.write_dataset(table, directory,
+ partitioning=partitioning, format='ipc')
+
+ extra_table = pa.table({'b': ['e']})
+ extra_file = directory / 'c=2' / 'foo.arrow'
+ pyarrow.feather.write_feather(extra_table, extra_file)
+
+ # Should be ok and overwrite with overwrite behavior
+ ds.write_dataset(table, directory, partitioning=partitioning,
+ format='ipc',
+ existing_data_behavior='overwrite_or_ignore')
+
+ overwritten = pa.table(
+ {'b': ['e', 'x', 'a', 'b', 'c'], 'c': [2, 1, 2, 3, 4]})
+ readback = ds.dataset(tempdir, format='ipc',
+ partitioning=partitioning).to_table()
+ compare_tables_ignoring_order(readback, overwritten)
+ assert extra_file.exists()
+
+ # Should be ok and delete matching with delete_matching
+ ds.write_dataset(table, directory, partitioning=partitioning,
+ format='ipc', existing_data_behavior='delete_matching')
+
+ overwritten = pa.table({'b': ['x', 'a', 'b', 'c'], 'c': [1, 2, 3, 4]})
+ readback = ds.dataset(tempdir, format='ipc',
+ partitioning=partitioning).to_table()
+ compare_tables_ignoring_order(readback, overwritten)
+ assert not extra_file.exists()
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_write_dataset_partitioned_dict(tempdir):
+ directory = tempdir / "partitioned"
+ _ = _create_parquet_dataset_partitioned(directory)
+
+ # directory partitioning, dictionary partition columns
+ dataset = ds.dataset(
+ directory,
+ partitioning=ds.HivePartitioning.discover(infer_dictionary=True))
+ target = tempdir / 'partitioned-dir-target'
+ expected_paths = [
+ target / "a", target / "a" / "part-0.feather",
+ target / "b", target / "b" / "part-0.feather"
+ ]
+ partitioning = ds.partitioning(pa.schema([
+ dataset.schema.field('part')]),
+ dictionaries={'part': pa.array(['a', 'b'])})
+ # NB: dictionaries required here since we use partitioning to parse
+ # directories in _check_dataset_roundtrip (not currently required for
+ # the formatting step)
+ _check_dataset_roundtrip(
+ dataset, str(target), expected_paths, 'f1', target,
+ partitioning=partitioning)
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_write_dataset_use_threads(tempdir):
+ directory = tempdir / "partitioned"
+ _ = _create_parquet_dataset_partitioned(directory)
+ dataset = ds.dataset(directory, partitioning="hive")
+
+ partitioning = ds.partitioning(
+ pa.schema([("part", pa.string())]), flavor="hive")
+
+ target1 = tempdir / 'partitioned1'
+ paths_written = []
+
+ def file_visitor(written_file):
+ paths_written.append(written_file.path)
+
+ ds.write_dataset(
+ dataset, target1, format="feather", partitioning=partitioning,
+ use_threads=True, file_visitor=file_visitor
+ )
+
+ expected_paths = {
+ target1 / 'part=a' / 'part-0.feather',
+ target1 / 'part=b' / 'part-0.feather'
+ }
+ paths_written_set = set(map(pathlib.Path, paths_written))
+ assert paths_written_set == expected_paths
+
+ target2 = tempdir / 'partitioned2'
+ ds.write_dataset(
+ dataset, target2, format="feather", partitioning=partitioning,
+ use_threads=False
+ )
+
+ # check that reading in gives same result
+ result1 = ds.dataset(target1, format="feather", partitioning=partitioning)
+ result2 = ds.dataset(target2, format="feather", partitioning=partitioning)
+ assert result1.to_table().equals(result2.to_table())
+
+
+def test_write_table(tempdir):
+ table = pa.table([
+ pa.array(range(20)), pa.array(np.random.randn(20)),
+ pa.array(np.repeat(['a', 'b'], 10))
+ ], names=["f1", "f2", "part"])
+
+ base_dir = tempdir / 'single'
+ ds.write_dataset(table, base_dir,
+ basename_template='dat_{i}.arrow', format="feather")
+ # check that all files are present
+ file_paths = list(base_dir.rglob("*"))
+ expected_paths = [base_dir / "dat_0.arrow"]
+ assert set(file_paths) == set(expected_paths)
+ # check Table roundtrip
+ result = ds.dataset(base_dir, format="ipc").to_table()
+ assert result.equals(table)
+
+ # with partitioning
+ base_dir = tempdir / 'partitioned'
+ expected_paths = [
+ base_dir / "part=a", base_dir / "part=a" / "dat_0.arrow",
+ base_dir / "part=b", base_dir / "part=b" / "dat_0.arrow"
+ ]
+
+ visited_paths = []
+
+ def file_visitor(written_file):
+ visited_paths.append(written_file.path)
+
+ partitioning = ds.partitioning(
+ pa.schema([("part", pa.string())]), flavor="hive")
+ ds.write_dataset(table, base_dir, format="feather",
+ basename_template='dat_{i}.arrow',
+ partitioning=partitioning, file_visitor=file_visitor)
+ file_paths = list(base_dir.rglob("*"))
+ assert set(file_paths) == set(expected_paths)
+ result = ds.dataset(base_dir, format="ipc", partitioning=partitioning)
+ assert result.to_table().equals(table)
+ assert len(visited_paths) == 2
+ for visited_path in visited_paths:
+ assert pathlib.Path(visited_path) in expected_paths
+
+
+def test_write_table_multiple_fragments(tempdir):
+ table = pa.table([
+ pa.array(range(10)), pa.array(np.random.randn(10)),
+ pa.array(np.repeat(['a', 'b'], 5))
+ ], names=["f1", "f2", "part"])
+ table = pa.concat_tables([table]*2)
+
+ # Table with multiple batches written as single Fragment by default
+ base_dir = tempdir / 'single'
+ ds.write_dataset(table, base_dir, format="feather")
+ assert set(base_dir.rglob("*")) == set([base_dir / "part-0.feather"])
+ assert ds.dataset(base_dir, format="ipc").to_table().equals(table)
+
+ # Same for single-element list of Table
+ base_dir = tempdir / 'single-list'
+ ds.write_dataset([table], base_dir, format="feather")
+ assert set(base_dir.rglob("*")) == set([base_dir / "part-0.feather"])
+ assert ds.dataset(base_dir, format="ipc").to_table().equals(table)
+
+ # Provide list of batches to write multiple fragments
+ base_dir = tempdir / 'multiple'
+ ds.write_dataset(table.to_batches(), base_dir, format="feather")
+ assert set(base_dir.rglob("*")) == set(
+ [base_dir / "part-0.feather"])
+ assert ds.dataset(base_dir, format="ipc").to_table().equals(table)
+
+ # Provide list of tables to write multiple fragments
+ base_dir = tempdir / 'multiple-table'
+ ds.write_dataset([table, table], base_dir, format="feather")
+ assert set(base_dir.rglob("*")) == set(
+ [base_dir / "part-0.feather"])
+ assert ds.dataset(base_dir, format="ipc").to_table().equals(
+ pa.concat_tables([table]*2)
+ )
+
+
+def test_write_iterable(tempdir):
+ table = pa.table([
+ pa.array(range(20)), pa.array(np.random.randn(20)),
+ pa.array(np.repeat(['a', 'b'], 10))
+ ], names=["f1", "f2", "part"])
+
+ base_dir = tempdir / 'inmemory_iterable'
+ ds.write_dataset((batch for batch in table.to_batches()), base_dir,
+ schema=table.schema,
+ basename_template='dat_{i}.arrow', format="feather")
+ result = ds.dataset(base_dir, format="ipc").to_table()
+ assert result.equals(table)
+
+ base_dir = tempdir / 'inmemory_reader'
+ reader = pa.ipc.RecordBatchReader.from_batches(table.schema,
+ table.to_batches())
+ ds.write_dataset(reader, base_dir,
+ basename_template='dat_{i}.arrow', format="feather")
+ result = ds.dataset(base_dir, format="ipc").to_table()
+ assert result.equals(table)
+
+
+def test_write_scanner(tempdir, dataset_reader):
+ if not dataset_reader.use_async:
+ pytest.skip(
+ ('ARROW-13338: Write dataset with scanner does not'
+ ' support synchronous scan'))
+
+ table = pa.table([
+ pa.array(range(20)), pa.array(np.random.randn(20)),
+ pa.array(np.repeat(['a', 'b'], 10))
+ ], names=["f1", "f2", "part"])
+ dataset = ds.dataset(table)
+
+ base_dir = tempdir / 'dataset_from_scanner'
+ ds.write_dataset(dataset_reader.scanner(
+ dataset), base_dir, format="feather")
+ result = dataset_reader.to_table(ds.dataset(base_dir, format="ipc"))
+ assert result.equals(table)
+
+ # scanner with different projected_schema
+ base_dir = tempdir / 'dataset_from_scanner2'
+ ds.write_dataset(dataset_reader.scanner(dataset, columns=["f1"]),
+ base_dir, format="feather")
+ result = dataset_reader.to_table(ds.dataset(base_dir, format="ipc"))
+ assert result.equals(table.select(["f1"]))
+
+ # schema not allowed when writing a scanner
+ with pytest.raises(ValueError, match="Cannot specify a schema"):
+ ds.write_dataset(dataset_reader.scanner(dataset), base_dir,
+ schema=table.schema, format="feather")
+
+
+def test_write_table_partitioned_dict(tempdir):
+ # ensure writing table partitioned on a dictionary column works without
+ # specifying the dictionary values explicitly
+ table = pa.table([
+ pa.array(range(20)),
+ pa.array(np.repeat(['a', 'b'], 10)).dictionary_encode(),
+ ], names=['col', 'part'])
+
+ partitioning = ds.partitioning(table.select(["part"]).schema)
+
+ base_dir = tempdir / "dataset"
+ ds.write_dataset(
+ table, base_dir, format="feather", partitioning=partitioning
+ )
+
+ # check roundtrip
+ partitioning_read = ds.DirectoryPartitioning.discover(
+ ["part"], infer_dictionary=True)
+ result = ds.dataset(
+ base_dir, format="ipc", partitioning=partitioning_read
+ ).to_table()
+ assert result.equals(table)
+
+
+@pytest.mark.parquet
+def test_write_dataset_parquet(tempdir):
+ import pyarrow.parquet as pq
+
+ table = pa.table([
+ pa.array(range(20)), pa.array(np.random.randn(20)),
+ pa.array(np.repeat(['a', 'b'], 10))
+ ], names=["f1", "f2", "part"])
+
+ # using default "parquet" format string
+
+ base_dir = tempdir / 'parquet_dataset'
+ ds.write_dataset(table, base_dir, format="parquet")
+ # check that all files are present
+ file_paths = list(base_dir.rglob("*"))
+ expected_paths = [base_dir / "part-0.parquet"]
+ assert set(file_paths) == set(expected_paths)
+ # check Table roundtrip
+ result = ds.dataset(base_dir, format="parquet").to_table()
+ assert result.equals(table)
+
+ # using custom options
+ for version in ["1.0", "2.4", "2.6"]:
+ format = ds.ParquetFileFormat()
+ opts = format.make_write_options(version=version)
+ base_dir = tempdir / 'parquet_dataset_version{0}'.format(version)
+ ds.write_dataset(table, base_dir, format=format, file_options=opts)
+ meta = pq.read_metadata(base_dir / "part-0.parquet")
+ expected_version = "1.0" if version == "1.0" else "2.6"
+ assert meta.format_version == expected_version
+
+
+def test_write_dataset_csv(tempdir):
+ table = pa.table([
+ pa.array(range(20)), pa.array(np.random.randn(20)),
+ pa.array(np.repeat(['a', 'b'], 10))
+ ], names=["f1", "f2", "chr1"])
+
+ base_dir = tempdir / 'csv_dataset'
+ ds.write_dataset(table, base_dir, format="csv")
+ # check that all files are present
+ file_paths = list(base_dir.rglob("*"))
+ expected_paths = [base_dir / "part-0.csv"]
+ assert set(file_paths) == set(expected_paths)
+ # check Table roundtrip
+ result = ds.dataset(base_dir, format="csv").to_table()
+ assert result.equals(table)
+
+ # using custom options
+ format = ds.CsvFileFormat(read_options=pyarrow.csv.ReadOptions(
+ column_names=table.schema.names))
+ opts = format.make_write_options(include_header=False)
+ base_dir = tempdir / 'csv_dataset_noheader'
+ ds.write_dataset(table, base_dir, format=format, file_options=opts)
+ result = ds.dataset(base_dir, format=format).to_table()
+ assert result.equals(table)
+
+
+@pytest.mark.parquet
+def test_write_dataset_parquet_file_visitor(tempdir):
+ table = pa.table([
+ pa.array(range(20)), pa.array(np.random.randn(20)),
+ pa.array(np.repeat(['a', 'b'], 10))
+ ], names=["f1", "f2", "part"])
+
+ visitor_called = False
+
+ def file_visitor(written_file):
+ nonlocal visitor_called
+ if (written_file.metadata is not None and
+ written_file.metadata.num_columns == 3):
+ visitor_called = True
+
+ base_dir = tempdir / 'parquet_dataset'
+ ds.write_dataset(table, base_dir, format="parquet",
+ file_visitor=file_visitor)
+
+ assert visitor_called
+
+
+def test_partition_dataset_parquet_file_visitor(tempdir):
+ f1_vals = [item for chunk in range(4) for item in [chunk] * 10]
+ f2_vals = [item*10 for chunk in range(4) for item in [chunk] * 10]
+ table = pa.table({'f1': f1_vals, 'f2': f2_vals,
+ 'part': np.repeat(['a', 'b'], 20)})
+
+ root_path = tempdir / 'partitioned'
+ partitioning = ds.partitioning(
+ pa.schema([("part", pa.string())]), flavor="hive")
+
+ paths_written = []
+
+ sample_metadata = None
+
+ def file_visitor(written_file):
+ nonlocal sample_metadata
+ if written_file.metadata:
+ sample_metadata = written_file.metadata
+ paths_written.append(written_file.path)
+
+ ds.write_dataset(
+ table, root_path, format="parquet", partitioning=partitioning,
+ use_threads=True, file_visitor=file_visitor
+ )
+
+ expected_paths = {
+ root_path / 'part=a' / 'part-0.parquet',
+ root_path / 'part=b' / 'part-0.parquet'
+ }
+ paths_written_set = set(map(pathlib.Path, paths_written))
+ assert paths_written_set == expected_paths
+ assert sample_metadata is not None
+ assert sample_metadata.num_columns == 2
+
+
+@pytest.mark.parquet
+@pytest.mark.pandas
+def test_write_dataset_arrow_schema_metadata(tempdir):
+    # ensure the Arrow schema is serialized in the Parquet metadata, so the
+    # roundtrip is correct (e.g. a non-UTC timezone is preserved)
+ import pyarrow.parquet as pq
+
+ table = pa.table({"a": [pd.Timestamp("2012-01-01", tz="Europe/Brussels")]})
+ assert table["a"].type.tz == "Europe/Brussels"
+
+ ds.write_dataset(table, tempdir, format="parquet")
+ result = pq.read_table(tempdir / "part-0.parquet")
+ assert result["a"].type.tz == "Europe/Brussels"
+
+
+def test_write_dataset_schema_metadata(tempdir):
+ # ensure that schema metadata gets written
+ from pyarrow import feather
+
+ table = pa.table({'a': [1, 2, 3]})
+ table = table.replace_schema_metadata({b'key': b'value'})
+ ds.write_dataset(table, tempdir, format="feather")
+
+ schema = feather.read_table(tempdir / "part-0.feather").schema
+ assert schema.metadata == {b'key': b'value'}
+
+
+@pytest.mark.parquet
+def test_write_dataset_schema_metadata_parquet(tempdir):
+ # ensure that schema metadata gets written
+ import pyarrow.parquet as pq
+
+ table = pa.table({'a': [1, 2, 3]})
+ table = table.replace_schema_metadata({b'key': b'value'})
+ ds.write_dataset(table, tempdir, format="parquet")
+
+ schema = pq.read_table(tempdir / "part-0.parquet").schema
+ assert schema.metadata == {b'key': b'value'}
+
+
+@pytest.mark.parquet
+@pytest.mark.s3
+def test_write_dataset_s3(s3_example_simple):
+ # write dataset with s3 filesystem
+ _, _, fs, _, host, port, access_key, secret_key = s3_example_simple
+ uri_template = (
+ "s3://{}:{}@{{}}?scheme=http&endpoint_override={}:{}".format(
+ access_key, secret_key, host, port)
+ )
+
+ table = pa.table([
+ pa.array(range(20)), pa.array(np.random.randn(20)),
+ pa.array(np.repeat(['a', 'b'], 10))],
+ names=["f1", "f2", "part"]
+ )
+ part = ds.partitioning(pa.schema([("part", pa.string())]), flavor="hive")
+
+ # writing with filesystem object
+ ds.write_dataset(
+ table, "mybucket/dataset", filesystem=fs, format="feather",
+ partitioning=part
+ )
+    # check roundtrip
+ result = ds.dataset(
+ "mybucket/dataset", filesystem=fs, format="ipc", partitioning="hive"
+ ).to_table()
+ assert result.equals(table)
+
+ # writing with URI
+ uri = uri_template.format("mybucket/dataset2")
+ ds.write_dataset(table, uri, format="feather", partitioning=part)
+    # check roundtrip
+ result = ds.dataset(
+ "mybucket/dataset2", filesystem=fs, format="ipc", partitioning="hive"
+ ).to_table()
+ assert result.equals(table)
+
+ # writing with path + URI as filesystem
+ uri = uri_template.format("mybucket")
+ ds.write_dataset(
+ table, "dataset3", filesystem=uri, format="feather", partitioning=part
+ )
+    # check roundtrip
+ result = ds.dataset(
+ "mybucket/dataset3", filesystem=fs, format="ipc", partitioning="hive"
+ ).to_table()
+ assert result.equals(table)
+
+
+@pytest.mark.parquet
+def test_dataset_null_to_dictionary_cast(tempdir, dataset_reader):
+ # ARROW-12420
+ import pyarrow.parquet as pq
+
+ table = pa.table({"a": [None, None]})
+ pq.write_table(table, tempdir / "test.parquet")
+
+ schema = pa.schema([
+ pa.field("a", pa.dictionary(pa.int32(), pa.string()))
+ ])
+ fsds = ds.FileSystemDataset.from_paths(
+ paths=[tempdir / "test.parquet"],
+ schema=schema,
+ format=ds.ParquetFileFormat(),
+ filesystem=fs.LocalFileSystem(),
+ )
+ table = dataset_reader.to_table(fsds)
+ assert table.schema == schema
diff --git a/src/arrow/python/pyarrow/tests/test_deprecations.py b/src/arrow/python/pyarrow/tests/test_deprecations.py
new file mode 100644
index 000000000..b16528937
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_deprecations.py
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Check that various deprecation warnings are raised
+
+# flake8: noqa
+
+import pyarrow as pa
+import pytest
diff --git a/src/arrow/python/pyarrow/tests/test_extension_type.py b/src/arrow/python/pyarrow/tests/test_extension_type.py
new file mode 100644
index 000000000..4ea6e94e4
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_extension_type.py
@@ -0,0 +1,779 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pickle
+import weakref
+
+import numpy as np
+import pyarrow as pa
+
+import pytest
+
+
+class IntegerType(pa.PyExtensionType):
+
+ def __init__(self):
+ pa.PyExtensionType.__init__(self, pa.int64())
+
+ def __reduce__(self):
+ return IntegerType, ()
+
+
+class UuidType(pa.PyExtensionType):
+
+ def __init__(self):
+ pa.PyExtensionType.__init__(self, pa.binary(16))
+
+ def __reduce__(self):
+ return UuidType, ()
+
+
+class ParamExtType(pa.PyExtensionType):
+
+ def __init__(self, width):
+ self._width = width
+ pa.PyExtensionType.__init__(self, pa.binary(width))
+
+ @property
+ def width(self):
+ return self._width
+
+ def __reduce__(self):
+ return ParamExtType, (self.width,)
+
+
+class MyStructType(pa.PyExtensionType):
+ storage_type = pa.struct([('left', pa.int64()),
+ ('right', pa.int64())])
+
+ def __init__(self):
+ pa.PyExtensionType.__init__(self, self.storage_type)
+
+ def __reduce__(self):
+ return MyStructType, ()
+
+
+class MyListType(pa.PyExtensionType):
+
+ def __init__(self, storage_type):
+ pa.PyExtensionType.__init__(self, storage_type)
+
+ def __reduce__(self):
+ return MyListType, (self.storage_type,)
+
+
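+# Helpers to round-trip a record batch through an in-memory IPC stream.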
+def ipc_write_batch(batch):
+ stream = pa.BufferOutputStream()
+ writer = pa.RecordBatchStreamWriter(stream, batch.schema)
+ writer.write_batch(batch)
+ writer.close()
+ return stream.getvalue()
+
+
+def ipc_read_batch(buf):
+ reader = pa.RecordBatchStreamReader(buf)
+ return reader.read_next_batch()
+
+
+def test_ext_type_basics():
+ ty = UuidType()
+ assert ty.extension_name == "arrow.py_extension_type"
+
+
+def test_ext_type_str():
+ ty = IntegerType()
+ expected = "extension<arrow.py_extension_type<IntegerType>>"
+ assert str(ty) == expected
+ assert pa.DataType.__str__(ty) == expected
+
+
+def test_ext_type_repr():
+ ty = IntegerType()
+ assert repr(ty) == "IntegerType(DataType(int64))"
+
+
+def test_ext_type__lifetime():
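+    # The type instance should be garbage-collected once the last reference
+    # is dropped.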
+ ty = UuidType()
+ wr = weakref.ref(ty)
+ del ty
+ assert wr() is None
+
+
+def test_ext_type__storage_type():
+ ty = UuidType()
+ assert ty.storage_type == pa.binary(16)
+ assert ty.__class__ is UuidType
+ ty = ParamExtType(5)
+ assert ty.storage_type == pa.binary(5)
+ assert ty.__class__ is ParamExtType
+
+
+def test_uuid_type_pickle():
+ for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+ ty = UuidType()
+ ser = pickle.dumps(ty, protocol=proto)
+ del ty
+ ty = pickle.loads(ser)
+ wr = weakref.ref(ty)
+ assert ty.extension_name == "arrow.py_extension_type"
+ del ty
+ assert wr() is None
+
+
+def test_ext_type_equality():
+ a = ParamExtType(5)
+ b = ParamExtType(6)
+ c = ParamExtType(6)
+ assert a != b
+ assert b == c
+ d = UuidType()
+ e = UuidType()
+ assert a != d
+ assert d == e
+
+
+def test_ext_array_basics():
+ ty = ParamExtType(3)
+ storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+ arr = pa.ExtensionArray.from_storage(ty, storage)
+ arr.validate()
+ assert arr.type is ty
+ assert arr.storage.equals(storage)
+
+
+def test_ext_array_lifetime():
+ ty = ParamExtType(3)
+ storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+ arr = pa.ExtensionArray.from_storage(ty, storage)
+
+ refs = [weakref.ref(ty), weakref.ref(arr), weakref.ref(storage)]
+ del ty, storage, arr
+ for ref in refs:
+ assert ref() is None
+
+
+def test_ext_array_to_pylist():
+ ty = ParamExtType(3)
+ storage = pa.array([b"foo", b"bar", None], type=pa.binary(3))
+ arr = pa.ExtensionArray.from_storage(ty, storage)
+
+ assert arr.to_pylist() == [b"foo", b"bar", None]
+
+
+def test_ext_array_errors():
+ ty = ParamExtType(4)
+ storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+ with pytest.raises(TypeError, match="Incompatible storage type"):
+ pa.ExtensionArray.from_storage(ty, storage)
+
+
+def test_ext_array_equality():
+ storage1 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
+ storage2 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
+ storage3 = pa.array([], type=pa.binary(16))
+ ty1 = UuidType()
+ ty2 = ParamExtType(16)
+
+ a = pa.ExtensionArray.from_storage(ty1, storage1)
+ b = pa.ExtensionArray.from_storage(ty1, storage2)
+ assert a.equals(b)
+ c = pa.ExtensionArray.from_storage(ty1, storage3)
+ assert not a.equals(c)
+ d = pa.ExtensionArray.from_storage(ty2, storage1)
+ assert not a.equals(d)
+ e = pa.ExtensionArray.from_storage(ty2, storage2)
+ assert d.equals(e)
+ f = pa.ExtensionArray.from_storage(ty2, storage3)
+ assert not d.equals(f)
+
+
+def test_ext_array_wrap_array():
+ ty = ParamExtType(3)
+ storage = pa.array([b"foo", b"bar", None], type=pa.binary(3))
+ arr = ty.wrap_array(storage)
+ arr.validate(full=True)
+ assert isinstance(arr, pa.ExtensionArray)
+ assert arr.type == ty
+ assert arr.storage == storage
+
+ storage = pa.chunked_array([[b"abc", b"def"], [b"ghi"]],
+ type=pa.binary(3))
+ arr = ty.wrap_array(storage)
+ arr.validate(full=True)
+ assert isinstance(arr, pa.ChunkedArray)
+ assert arr.type == ty
+ assert arr.chunk(0).storage == storage.chunk(0)
+ assert arr.chunk(1).storage == storage.chunk(1)
+
+ # Wrong storage type
+ storage = pa.array([b"foo", b"bar", None])
+ with pytest.raises(TypeError, match="Incompatible storage type"):
+ ty.wrap_array(storage)
+
+ # Not an array or chunked array
+ with pytest.raises(TypeError, match="Expected array or chunked array"):
+ ty.wrap_array(None)
+
+
+def test_ext_scalar_from_array():
+ data = [b"0123456789abcdef", b"0123456789abcdef",
+ b"zyxwvutsrqponmlk", None]
+ storage = pa.array(data, type=pa.binary(16))
+ ty1 = UuidType()
+ ty2 = ParamExtType(16)
+
+ a = pa.ExtensionArray.from_storage(ty1, storage)
+ b = pa.ExtensionArray.from_storage(ty2, storage)
+
+ scalars_a = list(a)
+ assert len(scalars_a) == 4
+
+ for s, val in zip(scalars_a, data):
+ assert isinstance(s, pa.ExtensionScalar)
+ assert s.is_valid == (val is not None)
+ assert s.type == ty1
+ if val is not None:
+ assert s.value == pa.scalar(val, storage.type)
+ else:
+ assert s.value is None
+ assert s.as_py() == val
+
+ scalars_b = list(b)
+ assert len(scalars_b) == 4
+
+ for sa, sb in zip(scalars_a, scalars_b):
+ assert sa.is_valid == sb.is_valid
+ assert sa.as_py() == sb.as_py()
+ assert sa != sb
+
+
+def test_ext_scalar_from_storage():
+ ty = UuidType()
+
+ s = pa.ExtensionScalar.from_storage(ty, None)
+ assert isinstance(s, pa.ExtensionScalar)
+ assert s.type == ty
+ assert s.is_valid is False
+ assert s.value is None
+
+ s = pa.ExtensionScalar.from_storage(ty, b"0123456789abcdef")
+ assert isinstance(s, pa.ExtensionScalar)
+ assert s.type == ty
+ assert s.is_valid is True
+ assert s.value == pa.scalar(b"0123456789abcdef", ty.storage_type)
+
+ s = pa.ExtensionScalar.from_storage(ty, pa.scalar(None, ty.storage_type))
+ assert isinstance(s, pa.ExtensionScalar)
+ assert s.type == ty
+ assert s.is_valid is False
+ assert s.value is None
+
+ s = pa.ExtensionScalar.from_storage(
+ ty, pa.scalar(b"0123456789abcdef", ty.storage_type))
+ assert isinstance(s, pa.ExtensionScalar)
+ assert s.type == ty
+ assert s.is_valid is True
+ assert s.value == pa.scalar(b"0123456789abcdef", ty.storage_type)
+
+
+def test_ext_array_pickling():
+ for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+ ty = ParamExtType(3)
+ storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+ arr = pa.ExtensionArray.from_storage(ty, storage)
+ ser = pickle.dumps(arr, protocol=proto)
+ del ty, storage, arr
+ arr = pickle.loads(ser)
+ arr.validate()
+ assert isinstance(arr, pa.ExtensionArray)
+ assert arr.type == ParamExtType(3)
+ assert arr.type.storage_type == pa.binary(3)
+ assert arr.storage.type == pa.binary(3)
+ assert arr.storage.to_pylist() == [b"foo", b"bar"]
+
+
+def test_ext_array_conversion_to_numpy():
+ storage1 = pa.array([1, 2, 3], type=pa.int64())
+ storage2 = pa.array([b"123", b"456", b"789"], type=pa.binary(3))
+ ty1 = IntegerType()
+ ty2 = ParamExtType(3)
+
+ arr1 = pa.ExtensionArray.from_storage(ty1, storage1)
+ arr2 = pa.ExtensionArray.from_storage(ty2, storage2)
+
+ result = arr1.to_numpy()
+ expected = np.array([1, 2, 3], dtype="int64")
+ np.testing.assert_array_equal(result, expected)
+
+ with pytest.raises(ValueError, match="zero_copy_only was True"):
+ arr2.to_numpy()
+ result = arr2.to_numpy(zero_copy_only=False)
+ expected = np.array([b"123", b"456", b"789"])
+ np.testing.assert_array_equal(result, expected)
+
+
+@pytest.mark.pandas
+def test_ext_array_conversion_to_pandas():
+ import pandas as pd
+
+ storage1 = pa.array([1, 2, 3], type=pa.int64())
+ storage2 = pa.array([b"123", b"456", b"789"], type=pa.binary(3))
+ ty1 = IntegerType()
+ ty2 = ParamExtType(3)
+
+ arr1 = pa.ExtensionArray.from_storage(ty1, storage1)
+ arr2 = pa.ExtensionArray.from_storage(ty2, storage2)
+
+ result = arr1.to_pandas()
+ expected = pd.Series([1, 2, 3], dtype="int64")
+ pd.testing.assert_series_equal(result, expected)
+
+ result = arr2.to_pandas()
+ expected = pd.Series([b"123", b"456", b"789"], dtype=object)
+ pd.testing.assert_series_equal(result, expected)
+
+
+def test_cast_kernel_on_extension_arrays():
+ # test array casting
+ storage = pa.array([1, 2, 3, 4], pa.int64())
+ arr = pa.ExtensionArray.from_storage(IntegerType(), storage)
+
+ # test that no allocation happens during identity cast
+ allocated_before_cast = pa.total_allocated_bytes()
+ casted = arr.cast(pa.int64())
+ assert pa.total_allocated_bytes() == allocated_before_cast
+
+ cases = [
+ (pa.int64(), pa.Int64Array),
+ (pa.int32(), pa.Int32Array),
+ (pa.int16(), pa.Int16Array),
+ (pa.uint64(), pa.UInt64Array),
+ (pa.uint32(), pa.UInt32Array),
+ (pa.uint16(), pa.UInt16Array)
+ ]
+ for typ, klass in cases:
+ casted = arr.cast(typ)
+ assert casted.type == typ
+ assert isinstance(casted, klass)
+
+ # test chunked array casting
+ arr = pa.chunked_array([arr, arr])
+ casted = arr.cast(pa.int16())
+ assert casted.type == pa.int16()
+ assert isinstance(casted, pa.ChunkedArray)
+
+
+def test_casting_to_extension_type_raises():
+ arr = pa.array([1, 2, 3, 4], pa.int64())
+ with pytest.raises(pa.ArrowNotImplementedError):
+ arr.cast(IntegerType())
+
+
+def example_batch():
+ ty = ParamExtType(3)
+ storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+ arr = pa.ExtensionArray.from_storage(ty, storage)
+ return pa.RecordBatch.from_arrays([arr], ["exts"])
+
+
+def check_example_batch(batch):
+ arr = batch.column(0)
+ assert isinstance(arr, pa.ExtensionArray)
+ assert arr.type.storage_type == pa.binary(3)
+ assert arr.storage.to_pylist() == [b"foo", b"bar"]
+ return arr
+
+
+def test_ipc():
+ batch = example_batch()
+ buf = ipc_write_batch(batch)
+ del batch
+
+ batch = ipc_read_batch(buf)
+ arr = check_example_batch(batch)
+ assert arr.type == ParamExtType(3)
+
+
+def test_ipc_unknown_type():
+ batch = example_batch()
+ buf = ipc_write_batch(batch)
+ del batch
+
+ orig_type = ParamExtType
+ try:
+ # Simulate the original Python type being unavailable.
+ # Deserialization should not fail but return a placeholder type.
+ del globals()['ParamExtType']
+
+ batch = ipc_read_batch(buf)
+ arr = check_example_batch(batch)
+ assert isinstance(arr.type, pa.UnknownExtensionType)
+
+ # Can be serialized again
+ buf2 = ipc_write_batch(batch)
+ del batch, arr
+
+ batch = ipc_read_batch(buf2)
+ arr = check_example_batch(batch)
+ assert isinstance(arr.type, pa.UnknownExtensionType)
+ finally:
+ globals()['ParamExtType'] = orig_type
+
+ # Deserialize again with the type restored
+ batch = ipc_read_batch(buf2)
+ arr = check_example_batch(batch)
+ assert arr.type == ParamExtType(3)
+
+
+class PeriodArray(pa.ExtensionArray):
+ pass
+
+
+class PeriodType(pa.ExtensionType):
+ def __init__(self, freq):
+ # attributes need to be set first before calling
+ # super init (as that calls serialize)
+ self._freq = freq
+ pa.ExtensionType.__init__(self, pa.int64(), 'test.period')
+
+ @property
+ def freq(self):
+ return self._freq
+
+ def __arrow_ext_serialize__(self):
+ return "freq={}".format(self.freq).encode()
+
+ @classmethod
+ def __arrow_ext_deserialize__(cls, storage_type, serialized):
+ serialized = serialized.decode()
+ assert serialized.startswith("freq=")
+ freq = serialized.split('=')[1]
+ return PeriodType(freq)
+
+ def __eq__(self, other):
+ if isinstance(other, pa.BaseExtensionType):
+ return (type(self) == type(other) and
+ self.freq == other.freq)
+ else:
+ return NotImplemented
+
+
+class PeriodTypeWithClass(PeriodType):
+ def __init__(self, freq):
+ PeriodType.__init__(self, freq)
+
+ def __arrow_ext_class__(self):
+ return PeriodArray
+
+ @classmethod
+ def __arrow_ext_deserialize__(cls, storage_type, serialized):
+ freq = PeriodType.__arrow_ext_deserialize__(
+ storage_type, serialized).freq
+ return PeriodTypeWithClass(freq)
+
+
+@pytest.fixture(params=[PeriodType('D'), PeriodTypeWithClass('D')])
+def registered_period_type(request):
+ # setup
+ period_type = request.param
+ period_class = period_type.__arrow_ext_class__()
+ pa.register_extension_type(period_type)
+ yield period_type, period_class
+ # teardown
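+    # Some tests unregister 'test.period' themselves, hence the KeyError guard.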
+ try:
+ pa.unregister_extension_type('test.period')
+ except KeyError:
+ pass
+
+
+def test_generic_ext_type():
+ period_type = PeriodType('D')
+ assert period_type.extension_name == "test.period"
+ assert period_type.storage_type == pa.int64()
+    # the default extension array class is expected
+ assert period_type.__arrow_ext_class__() == pa.ExtensionArray
+
+
+def test_generic_ext_type_ipc(registered_period_type):
+ period_type, period_class = registered_period_type
+ storage = pa.array([1, 2, 3, 4], pa.int64())
+ arr = pa.ExtensionArray.from_storage(period_type, storage)
+ batch = pa.RecordBatch.from_arrays([arr], ["ext"])
+    # check the built array has exactly the expected class
+ assert type(arr) == period_class
+
+ buf = ipc_write_batch(batch)
+ del batch
+ batch = ipc_read_batch(buf)
+
+ result = batch.column(0)
+ # check the deserialized array class is the expected one
+ assert type(result) == period_class
+ assert result.type.extension_name == "test.period"
+ assert arr.storage.to_pylist() == [1, 2, 3, 4]
+
+ # we get back an actual PeriodType
+ assert isinstance(result.type, PeriodType)
+ assert result.type.freq == 'D'
+ assert result.type == period_type
+
+    # using a different parametrization than the one it was registered with
+ period_type_H = period_type.__class__('H')
+ assert period_type_H.extension_name == "test.period"
+ assert period_type_H.freq == 'H'
+
+ arr = pa.ExtensionArray.from_storage(period_type_H, storage)
+ batch = pa.RecordBatch.from_arrays([arr], ["ext"])
+
+ buf = ipc_write_batch(batch)
+ del batch
+ batch = ipc_read_batch(buf)
+ result = batch.column(0)
+ assert isinstance(result.type, PeriodType)
+ assert result.type.freq == 'H'
+ assert type(result) == period_class
+
+
+def test_generic_ext_type_ipc_unknown(registered_period_type):
+ period_type, _ = registered_period_type
+ storage = pa.array([1, 2, 3, 4], pa.int64())
+ arr = pa.ExtensionArray.from_storage(period_type, storage)
+ batch = pa.RecordBatch.from_arrays([arr], ["ext"])
+
+ buf = ipc_write_batch(batch)
+ del batch
+
+ # unregister type before loading again => reading unknown extension type
+    # as plain array (but the metadata in the schema's field is preserved)
+ pa.unregister_extension_type('test.period')
+
+ batch = ipc_read_batch(buf)
+ result = batch.column(0)
+
+ assert isinstance(result, pa.Int64Array)
+ ext_field = batch.schema.field('ext')
+ assert ext_field.metadata == {
+ b'ARROW:extension:metadata': b'freq=D',
+ b'ARROW:extension:name': b'test.period'
+ }
+
+
+def test_generic_ext_type_equality():
+ period_type = PeriodType('D')
+ assert period_type.extension_name == "test.period"
+
+ period_type2 = PeriodType('D')
+ period_type3 = PeriodType('H')
+ assert period_type == period_type2
+ assert not period_type == period_type3
+
+
+def test_generic_ext_type_register(registered_period_type):
+ # test that trying to register other type does not segfault
+ with pytest.raises(TypeError):
+ pa.register_extension_type(pa.string())
+
+    # registering a second time raises KeyError
+ period_type = PeriodType('D')
+ with pytest.raises(KeyError):
+ pa.register_extension_type(period_type)
+
+
+@pytest.mark.parquet
+def test_parquet_period(tmpdir, registered_period_type):
+ # Parquet support for primitive extension types
+ period_type, period_class = registered_period_type
+ storage = pa.array([1, 2, 3, 4], pa.int64())
+ arr = pa.ExtensionArray.from_storage(period_type, storage)
+ table = pa.table([arr], names=["ext"])
+
+ import pyarrow.parquet as pq
+
+ filename = tmpdir / 'period_extension_type.parquet'
+ pq.write_table(table, filename)
+
+ # Stored in parquet as storage type but with extension metadata saved
+ # in the serialized arrow schema
+ meta = pq.read_metadata(filename)
+ assert meta.schema.column(0).physical_type == "INT64"
+ assert b"ARROW:schema" in meta.metadata
+
+ import base64
+ decoded_schema = base64.b64decode(meta.metadata[b"ARROW:schema"])
+ schema = pa.ipc.read_schema(pa.BufferReader(decoded_schema))
+ # Since the type could be reconstructed, the extension type metadata is
+ # absent.
+ assert schema.field("ext").metadata == {}
+
+ # When reading in, properly create extension type if it is registered
+ result = pq.read_table(filename)
+ assert result.schema.field("ext").type == period_type
+ assert result.schema.field("ext").metadata == {}
+ # Get the exact array class defined by the registered type.
+ result_array = result.column("ext").chunk(0)
+ assert type(result_array) is period_class
+
+ # When the type is not registered, read in as storage type
+ pa.unregister_extension_type(period_type.extension_name)
+ result = pq.read_table(filename)
+ assert result.schema.field("ext").type == pa.int64()
+ # The extension metadata is present for roundtripping.
+ assert result.schema.field("ext").metadata == {
+ b'ARROW:extension:metadata': b'freq=D',
+ b'ARROW:extension:name': b'test.period'
+ }
+
+
+@pytest.mark.parquet
+def test_parquet_extension_with_nested_storage(tmpdir):
+ # Parquet support for extension types with nested storage type
+ import pyarrow.parquet as pq
+
+ struct_array = pa.StructArray.from_arrays(
+ [pa.array([0, 1], type="int64"), pa.array([4, 5], type="int64")],
+ names=["left", "right"])
+ list_array = pa.array([[1, 2, 3], [4, 5]], type=pa.list_(pa.int32()))
+
+ mystruct_array = pa.ExtensionArray.from_storage(MyStructType(),
+ struct_array)
+ mylist_array = pa.ExtensionArray.from_storage(
+ MyListType(list_array.type), list_array)
+
+ orig_table = pa.table({'structs': mystruct_array,
+ 'lists': mylist_array})
+ filename = tmpdir / 'nested_extension_storage.parquet'
+ pq.write_table(orig_table, filename)
+
+ table = pq.read_table(filename)
+ assert table.column('structs').type == mystruct_array.type
+ assert table.column('lists').type == mylist_array.type
+ assert table == orig_table
+
+
+@pytest.mark.parquet
+def test_parquet_nested_extension(tmpdir):
+ # Parquet support for extension types nested in struct or list
+ import pyarrow.parquet as pq
+
+ ext_type = IntegerType()
+ storage = pa.array([4, 5, 6, 7], type=pa.int64())
+ ext_array = pa.ExtensionArray.from_storage(ext_type, storage)
+
+ # Struct of extensions
+ struct_array = pa.StructArray.from_arrays(
+ [storage, ext_array],
+ names=['ints', 'exts'])
+
+ orig_table = pa.table({'structs': struct_array})
+ filename = tmpdir / 'struct_of_ext.parquet'
+ pq.write_table(orig_table, filename)
+
+ table = pq.read_table(filename)
+ assert table.column(0).type == struct_array.type
+ assert table == orig_table
+
+ # List of extensions
+ list_array = pa.ListArray.from_arrays([0, 1, None, 3], ext_array)
+
+ orig_table = pa.table({'lists': list_array})
+ filename = tmpdir / 'list_of_ext.parquet'
+ pq.write_table(orig_table, filename)
+
+ table = pq.read_table(filename)
+ assert table.column(0).type == list_array.type
+ assert table == orig_table
+
+ # Large list of extensions
+ list_array = pa.LargeListArray.from_arrays([0, 1, None, 3], ext_array)
+
+ orig_table = pa.table({'lists': list_array})
+ filename = tmpdir / 'list_of_ext.parquet'
+ pq.write_table(orig_table, filename)
+
+ table = pq.read_table(filename)
+ assert table.column(0).type == list_array.type
+ assert table == orig_table
+
+
+@pytest.mark.parquet
+def test_parquet_extension_nested_in_extension(tmpdir):
+ # Parquet support for extension<list<extension>>
+ import pyarrow.parquet as pq
+
+ inner_ext_type = IntegerType()
+ inner_storage = pa.array([4, 5, 6, 7], type=pa.int64())
+ inner_ext_array = pa.ExtensionArray.from_storage(inner_ext_type,
+ inner_storage)
+
+ list_array = pa.ListArray.from_arrays([0, 1, None, 3], inner_ext_array)
+ mylist_array = pa.ExtensionArray.from_storage(
+ MyListType(list_array.type), list_array)
+
+ orig_table = pa.table({'lists': mylist_array})
+ filename = tmpdir / 'ext_of_list_of_ext.parquet'
+ pq.write_table(orig_table, filename)
+
+ table = pq.read_table(filename)
+ assert table.column(0).type == mylist_array.type
+ assert table == orig_table
+
+
+def test_to_numpy():
+ period_type = PeriodType('D')
+ storage = pa.array([1, 2, 3, 4], pa.int64())
+ arr = pa.ExtensionArray.from_storage(period_type, storage)
+
+ expected = storage.to_numpy()
+ result = arr.to_numpy()
+ np.testing.assert_array_equal(result, expected)
+
+ result = np.asarray(arr)
+ np.testing.assert_array_equal(result, expected)
+
+ # chunked array
+ a1 = pa.chunked_array([arr, arr])
+ a2 = pa.chunked_array([arr, arr], type=period_type)
+ expected = np.hstack([expected, expected])
+
+ for charr in [a1, a2]:
+ assert charr.type == period_type
+ for result in [np.asarray(charr), charr.to_numpy()]:
+ assert result.dtype == np.int64
+ np.testing.assert_array_equal(result, expected)
+
+ # zero chunks
+ charr = pa.chunked_array([], type=period_type)
+ assert charr.type == period_type
+
+ for result in [np.asarray(charr), charr.to_numpy()]:
+ assert result.dtype == np.int64
+ np.testing.assert_array_equal(result, np.array([], dtype='int64'))
+
+
+def test_empty_take():
+ # https://issues.apache.org/jira/browse/ARROW-13474
+ ext_type = IntegerType()
+ storage = pa.array([], type=pa.int64())
+ empty_arr = pa.ExtensionArray.from_storage(ext_type, storage)
+
+ result = empty_arr.filter(pa.array([], pa.bool_()))
+ assert len(result) == 0
+ assert result.equals(empty_arr)
+
+ result = empty_arr.take(pa.array([], pa.int32()))
+ assert len(result) == 0
+ assert result.equals(empty_arr)
diff --git a/src/arrow/python/pyarrow/tests/test_feather.py b/src/arrow/python/pyarrow/tests/test_feather.py
new file mode 100644
index 000000000..3d0451ee3
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_feather.py
@@ -0,0 +1,799 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import io
+import os
+import sys
+import tempfile
+import pytest
+import hypothesis as h
+import hypothesis.strategies as st
+
+import numpy as np
+
+import pyarrow as pa
+import pyarrow.tests.strategies as past
+from pyarrow.feather import (read_feather, write_feather, read_table,
+ FeatherDataset)
+
+
+try:
+ from pandas.testing import assert_frame_equal
+ import pandas as pd
+ import pyarrow.pandas_compat
+except ImportError:
+ pass
+
+
+@pytest.fixture(scope='module')
+def datadir(base_datadir):
+ return base_datadir / 'feather'
+
+
+def random_path(prefix='feather_'):
+ return tempfile.mktemp(prefix=prefix)
+
+
+@pytest.fixture(scope="module", params=[1, 2])
+def version(request):
+ yield request.param
+
+
+@pytest.fixture(scope="module", params=[None, "uncompressed", "lz4", "zstd"])
+def compression(request):
+ if request.param in ['lz4', 'zstd'] and not pa.Codec.is_available(
+ request.param):
+ pytest.skip(f'{request.param} is not available')
+ yield request.param
+
+
+TEST_FILES = None
+
+
+def setup_module(module):
+ global TEST_FILES
+ TEST_FILES = []
+
+
+def teardown_module(module):
+ for path in TEST_FILES:
+ try:
+ os.remove(path)
+ except os.error:
+ pass
+
+
+@pytest.mark.pandas
+def test_file_not_exist():
+ with pytest.raises(pa.ArrowIOError):
+ read_feather('test_invalid_file')
+
+
+def _check_pandas_roundtrip(df, expected=None, path=None,
+ columns=None, use_threads=False,
+ version=None, compression=None,
+ compression_level=None):
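+    # Write the DataFrame to Feather with the given options, read it back and
+    # compare against `expected` (the input itself by default).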
+ if path is None:
+ path = random_path()
+
+ TEST_FILES.append(path)
+ write_feather(df, path, compression=compression,
+ compression_level=compression_level, version=version)
+ if not os.path.exists(path):
+ raise Exception('file not written')
+
+ result = read_feather(path, columns, use_threads=use_threads)
+ if expected is None:
+ expected = df
+
+ assert_frame_equal(result, expected)
+
+
+def _check_arrow_roundtrip(table, path=None, compression=None):
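+    # Write the Table to Feather and check that reading it back gives an
+    # equal table.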
+ if path is None:
+ path = random_path()
+
+ TEST_FILES.append(path)
+ write_feather(table, path, compression=compression)
+ if not os.path.exists(path):
+ raise Exception('file not written')
+
+ result = read_table(path)
+ assert result.equals(table)
+
+
+def _assert_error_on_write(df, exc, path=None, version=2):
+    # check that the expected exception is raised when writing
+
+ if path is None:
+ path = random_path()
+
+ TEST_FILES.append(path)
+
+ def f():
+ write_feather(df, path, version=version)
+
+ pytest.raises(exc, f)
+
+
+def test_dataset(version):
+ num_values = (100, 100)
+ num_files = 5
+ paths = [random_path() for i in range(num_files)]
+ data = {
+ "col_" + str(i): np.random.randn(num_values[0])
+ for i in range(num_values[1])
+ }
+ table = pa.table(data)
+
+ TEST_FILES.extend(paths)
+ for index, path in enumerate(paths):
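+        # Each file receives a contiguous slice of rows from the source table.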
+ rows = (
+ index * (num_values[0] // num_files),
+ (index + 1) * (num_values[0] // num_files),
+ )
+
+ write_feather(table[rows[0]: rows[1]], path, version=version)
+
+ data = FeatherDataset(paths).read_table()
+ assert data.equals(table)
+
+
+@pytest.mark.pandas
+def test_float_no_nulls(version):
+ data = {}
+ numpy_dtypes = ['f4', 'f8']
+ num_values = 100
+
+ for dtype in numpy_dtypes:
+ values = np.random.randn(num_values)
+ data[dtype] = values.astype(dtype)
+
+ df = pd.DataFrame(data)
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_read_table(version):
+ num_values = (100, 100)
+ path = random_path()
+
+ TEST_FILES.append(path)
+
+ values = np.random.randint(0, 100, size=num_values)
+ columns = ['col_' + str(i) for i in range(100)]
+ table = pa.Table.from_arrays(values, columns)
+
+ write_feather(table, path, version=version)
+
+ result = read_table(path)
+ assert result.equals(table)
+
+ # Test without memory mapping
+ result = read_table(path, memory_map=False)
+ assert result.equals(table)
+
+ result = read_feather(path, memory_map=False)
+ assert_frame_equal(table.to_pandas(), result)
+
+
+@pytest.mark.pandas
+def test_float_nulls(version):
+ num_values = 100
+
+ path = random_path()
+ TEST_FILES.append(path)
+
+ null_mask = np.random.randint(0, 10, size=num_values) < 3
+ dtypes = ['f4', 'f8']
+ expected_cols = []
+
+ arrays = []
+ for name in dtypes:
+ values = np.random.randn(num_values).astype(name)
+ arrays.append(pa.array(values, mask=null_mask))
+
+ values[null_mask] = np.nan
+
+ expected_cols.append(values)
+
+ table = pa.table(arrays, names=dtypes)
+ _check_arrow_roundtrip(table)
+
+ df = table.to_pandas()
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_integer_no_nulls(version):
+ data, arr = {}, []
+
+ numpy_dtypes = ['i1', 'i2', 'i4', 'i8',
+ 'u1', 'u2', 'u4', 'u8']
+ num_values = 100
+
+ for dtype in numpy_dtypes:
+ values = np.random.randint(0, 100, size=num_values)
+ data[dtype] = values.astype(dtype)
+ arr.append(values.astype(dtype))
+
+ df = pd.DataFrame(data)
+ _check_pandas_roundtrip(df, version=version)
+
+ table = pa.table(arr, names=numpy_dtypes)
+ _check_arrow_roundtrip(table)
+
+
+@pytest.mark.pandas
+def test_platform_numpy_integers(version):
+ data = {}
+
+ numpy_dtypes = ['longlong']
+ num_values = 100
+
+ for dtype in numpy_dtypes:
+ values = np.random.randint(0, 100, size=num_values)
+ data[dtype] = values.astype(dtype)
+
+ df = pd.DataFrame(data)
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_integer_with_nulls(version):
+ # pandas requires upcast to float dtype
+ path = random_path()
+ TEST_FILES.append(path)
+
+ int_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8']
+ num_values = 100
+
+ arrays = []
+ null_mask = np.random.randint(0, 10, size=num_values) < 3
+ expected_cols = []
+ for name in int_dtypes:
+ values = np.random.randint(0, 100, size=num_values)
+ arrays.append(pa.array(values, mask=null_mask))
+
+ expected = values.astype('f8')
+ expected[null_mask] = np.nan
+
+ expected_cols.append(expected)
+
+ table = pa.table(arrays, names=int_dtypes)
+ _check_arrow_roundtrip(table)
+
+ df = table.to_pandas()
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_boolean_no_nulls(version):
+ num_values = 100
+
+ np.random.seed(0)
+
+ df = pd.DataFrame({'bools': np.random.randn(num_values) > 0})
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_boolean_nulls(version):
+ # pandas requires upcast to object dtype
+ path = random_path()
+ TEST_FILES.append(path)
+
+ num_values = 100
+ np.random.seed(0)
+
+ mask = np.random.randint(0, 10, size=num_values) < 3
+ values = np.random.randint(0, 10, size=num_values) < 5
+
+ table = pa.table([pa.array(values, mask=mask)], names=['bools'])
+ _check_arrow_roundtrip(table)
+
+ df = table.to_pandas()
+ _check_pandas_roundtrip(df, version=version)
+
+
+def test_buffer_bounds_error(version):
+ # ARROW-1676
+ path = random_path()
+ TEST_FILES.append(path)
+
+ for i in range(16, 256):
+ table = pa.Table.from_arrays(
+ [pa.array([None] + list(range(i)), type=pa.float64())],
+ names=["arr"]
+ )
+ _check_arrow_roundtrip(table)
+
+
+def test_boolean_object_nulls(version):
+ repeats = 100
+ table = pa.Table.from_arrays(
+ [np.array([False, None, True] * repeats, dtype=object)],
+ names=["arr"]
+ )
+ _check_arrow_roundtrip(table)
+
+
+@pytest.mark.pandas
+def test_delete_partial_file_on_error(version):
+ if sys.platform == 'win32':
+ pytest.skip('Windows hangs on to file handle for some reason')
+
+ class CustomClass:
+ pass
+
+ # strings will fail
+ df = pd.DataFrame(
+ {
+ 'numbers': range(5),
+ 'strings': [b'foo', None, 'bar', CustomClass(), np.nan]},
+ columns=['numbers', 'strings'])
+
+ path = random_path()
+ try:
+ write_feather(df, path, version=version)
+ except Exception:
+ pass
+
+ assert not os.path.exists(path)
+
+
+@pytest.mark.pandas
+def test_strings(version):
+ repeats = 1000
+
+    # Mixed bytes and unicode strings are coerced to binary
+ values = [b'foo', None, 'bar', 'qux', np.nan]
+ df = pd.DataFrame({'strings': values * repeats})
+
+ ex_values = [b'foo', None, b'bar', b'qux', np.nan]
+ expected = pd.DataFrame({'strings': ex_values * repeats})
+ _check_pandas_roundtrip(df, expected, version=version)
+
+ # embedded nulls are ok
+ values = ['foo', None, 'bar', 'qux', None]
+ df = pd.DataFrame({'strings': values * repeats})
+ expected = pd.DataFrame({'strings': values * repeats})
+ _check_pandas_roundtrip(df, expected, version=version)
+
+ values = ['foo', None, 'bar', 'qux', np.nan]
+ df = pd.DataFrame({'strings': values * repeats})
+ expected = pd.DataFrame({'strings': values * repeats})
+ _check_pandas_roundtrip(df, expected, version=version)
+
+
+@pytest.mark.pandas
+def test_empty_strings(version):
+ df = pd.DataFrame({'strings': [''] * 10})
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_all_none(version):
+ df = pd.DataFrame({'all_none': [None] * 10})
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_all_null_category(version):
+ # ARROW-1188
+ df = pd.DataFrame({"A": (1, 2, 3), "B": (None, None, None)})
+ df = df.assign(B=df.B.astype("category"))
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_multithreaded_read(version):
+ data = {'c{}'.format(i): [''] * 10
+ for i in range(100)}
+ df = pd.DataFrame(data)
+ _check_pandas_roundtrip(df, use_threads=True, version=version)
+
+
+@pytest.mark.pandas
+def test_nan_as_null(version):
+ # Create a nan that is not numpy.nan
+ values = np.array(['foo', np.nan, np.nan * 2, 'bar'] * 10)
+ df = pd.DataFrame({'strings': values})
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_category(version):
+ repeats = 1000
+ values = ['foo', None, 'bar', 'qux', np.nan]
+ df = pd.DataFrame({'strings': values * repeats})
+ df['strings'] = df['strings'].astype('category')
+
+ values = ['foo', None, 'bar', 'qux', None]
+ expected = pd.DataFrame({'strings': pd.Categorical(values * repeats)})
+ _check_pandas_roundtrip(df, expected, version=version)
+
+
+@pytest.mark.pandas
+def test_timestamp(version):
+ df = pd.DataFrame({'naive': pd.date_range('2016-03-28', periods=10)})
+ df['with_tz'] = (df.naive.dt.tz_localize('utc')
+ .dt.tz_convert('America/Los_Angeles'))
+
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_timestamp_with_nulls(version):
+ df = pd.DataFrame({'test': [pd.Timestamp(2016, 1, 1),
+ None,
+ pd.Timestamp(2016, 1, 3)]})
+ df['with_tz'] = df.test.dt.tz_localize('utc')
+
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+@pytest.mark.xfail(reason="not supported", raises=TypeError)
+def test_timedelta_with_nulls_v1():
+ df = pd.DataFrame({'test': [pd.Timedelta('1 day'),
+ None,
+ pd.Timedelta('3 day')]})
+ _check_pandas_roundtrip(df, version=1)
+
+
+@pytest.mark.pandas
+def test_timedelta_with_nulls():
+ df = pd.DataFrame({'test': [pd.Timedelta('1 day'),
+ None,
+ pd.Timedelta('3 day')]})
+ _check_pandas_roundtrip(df, version=2)
+
+
+@pytest.mark.pandas
+def test_out_of_float64_timestamp_with_nulls(version):
+ df = pd.DataFrame(
+ {'test': pd.DatetimeIndex([1451606400000000001,
+ None, 14516064000030405])})
+ df['with_tz'] = df.test.dt.tz_localize('utc')
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.pandas
+def test_non_string_columns(version):
+ df = pd.DataFrame({0: [1, 2, 3, 4],
+ 1: [True, False, True, False]})
+
+ expected = df.rename(columns=str)
+ _check_pandas_roundtrip(df, expected, version=version)
+
+
+@pytest.mark.pandas
+@pytest.mark.skipif(not os.path.supports_unicode_filenames,
+ reason='unicode filenames not supported')
+def test_unicode_filename(version):
+ # GH #209
+ name = (b'Besa_Kavaj\xc3\xab.feather').decode('utf-8')
+ df = pd.DataFrame({'foo': [1, 2, 3, 4]})
+ _check_pandas_roundtrip(df, path=random_path(prefix=name),
+ version=version)
+
+
+@pytest.mark.pandas
+def test_read_columns(version):
+ df = pd.DataFrame({
+ 'foo': [1, 2, 3, 4],
+ 'boo': [5, 6, 7, 8],
+ 'woo': [1, 3, 5, 7]
+ })
+ expected = df[['boo', 'woo']]
+
+ _check_pandas_roundtrip(df, expected, version=version,
+ columns=['boo', 'woo'])
+
+
+def test_overwritten_file(version):
+ path = random_path()
+ TEST_FILES.append(path)
+
+ num_values = 100
+ np.random.seed(0)
+
+ values = np.random.randint(0, 10, size=num_values)
+
+ table = pa.table({'ints': values})
+ write_feather(table, path)
+
+ table = pa.table({'more_ints': values[0:num_values//2]})
+ _check_arrow_roundtrip(table, path=path)
+
+
+@pytest.mark.pandas
+def test_filelike_objects(version):
+ buf = io.BytesIO()
+
+ # the copy makes it non-strided
+ df = pd.DataFrame(np.arange(12).reshape(4, 3),
+ columns=['a', 'b', 'c']).copy()
+ write_feather(df, buf, version=version)
+
+ buf.seek(0)
+
+ result = read_feather(buf)
+ assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@pytest.mark.filterwarnings("ignore:Sparse:FutureWarning")
+@pytest.mark.filterwarnings("ignore:DataFrame.to_sparse:FutureWarning")
+def test_sparse_dataframe(version):
+ if not pa.pandas_compat._pandas_api.has_sparse:
+ pytest.skip("version of pandas does not support SparseDataFrame")
+ # GH #221
+ data = {'A': [0, 1, 2],
+ 'B': [1, 0, 1]}
+ df = pd.DataFrame(data).to_sparse(fill_value=1)
+ expected = df.to_dense()
+ _check_pandas_roundtrip(df, expected, version=version)
+
+
+@pytest.mark.pandas
+def test_duplicate_columns_pandas():
+
+ # https://github.com/wesm/feather/issues/53
+ # not currently able to handle duplicate columns
+ df = pd.DataFrame(np.arange(12).reshape(4, 3),
+ columns=list('aaa')).copy()
+ _assert_error_on_write(df, ValueError)
+
+
+def test_duplicate_columns():
+ # only works for version 2
+ table = pa.table([[1, 2, 3], [4, 5, 6], [7, 8, 9]], names=['a', 'a', 'b'])
+ _check_arrow_roundtrip(table)
+ _assert_error_on_write(table, ValueError, version=1)
+
+
+@pytest.mark.pandas
+def test_unsupported():
+ # https://github.com/wesm/feather/issues/240
+ # serializing actual python objects
+
+ # custom python objects
+ class A:
+ pass
+
+ df = pd.DataFrame({'a': [A(), A()]})
+ _assert_error_on_write(df, ValueError)
+
+    # a mix of strings and non-strings
+ df = pd.DataFrame({'a': ['a', 1, 2.0]})
+ _assert_error_on_write(df, TypeError)
+
+
+@pytest.mark.pandas
+def test_v2_set_chunksize():
+ df = pd.DataFrame({'A': np.arange(1000)})
+ table = pa.table(df)
+
+ buf = io.BytesIO()
+ write_feather(table, buf, chunksize=250, version=2)
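+    # 1000 rows written with chunksize=250 should yield 4 record batches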
+
+ result = buf.getvalue()
+
+ ipc_file = pa.ipc.open_file(pa.BufferReader(result))
+ assert ipc_file.num_record_batches == 4
+ assert len(ipc_file.get_batch(0)) == 250
+
+
+@pytest.mark.pandas
+@pytest.mark.lz4
+@pytest.mark.snappy
+@pytest.mark.zstd
+def test_v2_compression_options():
+ df = pd.DataFrame({'A': np.arange(1000)})
+
+ cases = [
+ # compression, compression_level
+ ('uncompressed', None),
+ ('lz4', None),
+ ('zstd', 1),
+ ('zstd', 10)
+ ]
+
+ for compression, compression_level in cases:
+ _check_pandas_roundtrip(df, compression=compression,
+ compression_level=compression_level)
+
+ buf = io.BytesIO()
+
+ # LZ4 doesn't support compression_level
+ with pytest.raises(pa.ArrowInvalid,
+ match="doesn't support setting a compression level"):
+ write_feather(df, buf, compression='lz4', compression_level=10)
+
+ # Trying to compress with V1
+ with pytest.raises(
+ ValueError,
+ match="Feather V1 files do not support compression option"):
+ write_feather(df, buf, compression='lz4', version=1)
+
+ # Trying to set chunksize with V1
+ with pytest.raises(
+ ValueError,
+ match="Feather V1 files do not support chunksize option"):
+ write_feather(df, buf, chunksize=4096, version=1)
+
+ # Unsupported compressor
+ with pytest.raises(ValueError,
+ match='compression="snappy" not supported'):
+ write_feather(df, buf, compression='snappy')
+
+
+def test_v2_lz4_default_compression():
+ # ARROW-8750: Make sure that the compression=None option selects lz4 if
+ # it's available
+ if not pa.Codec.is_available('lz4_frame'):
+ pytest.skip("LZ4 compression support is not built in C++")
+
+ # some highly compressible data
+ t = pa.table([np.repeat(0, 100000)], names=['f0'])
+
+ buf = io.BytesIO()
+ write_feather(t, buf)
+ default_result = buf.getvalue()
+
+ buf = io.BytesIO()
+ write_feather(t, buf, compression='uncompressed')
+ uncompressed_result = buf.getvalue()
+
+ assert len(default_result) < len(uncompressed_result)
+
+
+def test_v1_unsupported_types():
+ table = pa.table([pa.array([[1, 2, 3], [], None])], names=['f0'])
+
+ buf = io.BytesIO()
+ with pytest.raises(TypeError,
+ match=("Unsupported Feather V1 type: "
+ "list<item: int64>. "
+ "Use V2 format to serialize all Arrow types.")):
+ write_feather(table, buf, version=1)
+
+
+@pytest.mark.slow
+@pytest.mark.pandas
+def test_large_dataframe(version):
+ df = pd.DataFrame({'A': np.arange(400000000)})
+ _check_pandas_roundtrip(df, version=version)
+
+
+@pytest.mark.large_memory
+@pytest.mark.pandas
+def test_chunked_binary_error_message():
+ # ARROW-3058: As Feather does not yet support chunked columns, we at least
+ # make sure it's clear to the user what is going on
+
+ # 2^31 + 1 bytes
+ values = [b'x'] + [
+ b'x' * (1 << 20)
+ ] * 2 * (1 << 10)
+ df = pd.DataFrame({'byte_col': values})
+
+ # Works fine with version 2
+ buf = io.BytesIO()
+ write_feather(df, buf, version=2)
+ result = read_feather(pa.BufferReader(buf.getvalue()))
+ assert_frame_equal(result, df)
+
+ with pytest.raises(ValueError, match="'byte_col' exceeds 2GB maximum "
+ "capacity of a Feather binary column. This restriction "
+ "may be lifted in the future"):
+ write_feather(df, io.BytesIO(), version=1)
+
+
+def test_feather_without_pandas(tempdir, version):
+ # ARROW-8345
+ table = pa.table([pa.array([1, 2, 3])], names=['f0'])
+ path = str(tempdir / "data.feather")
+ _check_arrow_roundtrip(table, path)
+
+
+@pytest.mark.pandas
+def test_read_column_selection(version):
+ # ARROW-8641
+ df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=['a', 'b', 'c'])
+
+ # select columns as string names or integer indices
+ _check_pandas_roundtrip(
+ df, columns=['a', 'c'], expected=df[['a', 'c']], version=version)
+ _check_pandas_roundtrip(
+ df, columns=[0, 2], expected=df[['a', 'c']], version=version)
+
+    # a different column order is preserved
+ _check_pandas_roundtrip(
+ df, columns=['b', 'a'], expected=df[['b', 'a']], version=version)
+ _check_pandas_roundtrip(
+ df, columns=[1, 0], expected=df[['b', 'a']], version=version)
+
+
+def test_read_column_duplicated_selection(tempdir, version):
+ # duplicated columns in the column selection
+ table = pa.table([[1, 2, 3], [4, 5, 6], [7, 8, 9]], names=['a', 'b', 'c'])
+ path = str(tempdir / "data.feather")
+ write_feather(table, path, version=version)
+
+ expected = pa.table([[1, 2, 3], [4, 5, 6], [1, 2, 3]],
+ names=['a', 'b', 'a'])
+ for col_selection in [['a', 'b', 'a'], [0, 1, 0]]:
+ result = read_table(path, columns=col_selection)
+ assert result.equals(expected)
+
+
+def test_read_column_duplicated_in_file(tempdir):
+ # duplicated columns in feather file (only works for feather v2)
+ table = pa.table([[1, 2, 3], [4, 5, 6], [7, 8, 9]], names=['a', 'b', 'a'])
+ path = str(tempdir / "data.feather")
+ write_feather(table, path, version=2)
+
+    # reading without a column selection works fine
+ result = read_table(path)
+ assert result.equals(table)
+
+ # selection with indices works
+ result = read_table(path, columns=[0, 2])
+ assert result.column_names == ['a', 'a']
+
+ # selection with column names errors
+ with pytest.raises(ValueError):
+ read_table(path, columns=['a', 'b'])
+
+
+def test_nested_types(compression):
+ # https://issues.apache.org/jira/browse/ARROW-8860
+ table = pa.table({'col': pa.StructArray.from_arrays(
+ [[0, 1, 2], [1, 2, 3]], names=["f1", "f2"])})
+ _check_arrow_roundtrip(table, compression=compression)
+
+ table = pa.table({'col': pa.array([[1, 2], [3, 4]])})
+ _check_arrow_roundtrip(table, compression=compression)
+
+ table = pa.table({'col': pa.array([[[1, 2], [3, 4]], [[5, 6], None]])})
+ _check_arrow_roundtrip(table, compression=compression)
+
+
+@h.given(past.all_tables, st.sampled_from(["uncompressed", "lz4", "zstd"]))
+def test_roundtrip(table, compression):
+ _check_arrow_roundtrip(table, compression=compression)
+
+
+@pytest.mark.lz4
+def test_feather_v017_experimental_compression_backward_compatibility(datadir):
+ # ARROW-11163 - ensure newer pyarrow versions can read the old feather
+ # files from version 0.17.0 with experimental compression support (before
+ # it was officially added to IPC format in 1.0.0)
+
+ # file generated with:
+ # table = pa.table({'a': range(5)})
+ # from pyarrow import feather
+ # feather.write_feather(
+ # table, "v0.17.0.version=2-compression=lz4.feather",
+ # compression="lz4", version=2)
+ expected = pa.table({'a': range(5)})
+ result = read_table(datadir / "v0.17.0.version=2-compression=lz4.feather")
+ assert result.equals(expected)
diff --git a/src/arrow/python/pyarrow/tests/test_filesystem.py b/src/arrow/python/pyarrow/tests/test_filesystem.py
new file mode 100644
index 000000000..3d54f33e1
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_filesystem.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+import pyarrow as pa
+from pyarrow import filesystem
+
+import pytest
+
+
+def test_filesystem_deprecated():
+ with pytest.warns(FutureWarning):
+ filesystem.LocalFileSystem()
+
+ with pytest.warns(FutureWarning):
+ filesystem.LocalFileSystem.get_instance()
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7),
+ reason="getattr needs Python 3.7")
+def test_filesystem_deprecated_toplevel():
+
+ with pytest.warns(FutureWarning):
+ pa.localfs
+
+ with pytest.warns(FutureWarning):
+ pa.FileSystem
+
+ with pytest.warns(FutureWarning):
+ pa.LocalFileSystem
+
+ with pytest.warns(FutureWarning):
+ pa.HadoopFileSystem
+
+
+def test_resolve_uri():
+ uri = "file:///home/user/myfile.parquet"
+ fs, path = filesystem.resolve_filesystem_and_path(uri)
+ assert isinstance(fs, filesystem.LocalFileSystem)
+ assert path == "/home/user/myfile.parquet"
+
+
+def test_resolve_local_path():
+ for uri in ['/home/user/myfile.parquet',
+ 'myfile.parquet',
+ 'my # file ? parquet',
+ 'C:/Windows/myfile.parquet',
+ r'C:\\Windows\\myfile.parquet',
+ ]:
+ fs, path = filesystem.resolve_filesystem_and_path(uri)
+ assert isinstance(fs, filesystem.LocalFileSystem)
+ assert path == uri
diff --git a/src/arrow/python/pyarrow/tests/test_flight.py b/src/arrow/python/pyarrow/tests/test_flight.py
new file mode 100644
index 000000000..5c40467a5
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_flight.py
@@ -0,0 +1,2047 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import ast
+import base64
+import itertools
+import os
+import signal
+import struct
+import tempfile
+import threading
+import time
+import traceback
+import json
+
+import numpy as np
+import pytest
+import pyarrow as pa
+
+from pyarrow.lib import tobytes
+from pyarrow.util import pathlib, find_free_port
+from pyarrow.tests import util
+
+try:
+ from pyarrow import flight
+ from pyarrow.flight import (
+ FlightClient, FlightServerBase,
+ ServerAuthHandler, ClientAuthHandler,
+ ServerMiddleware, ServerMiddlewareFactory,
+ ClientMiddleware, ClientMiddlewareFactory,
+ )
+except ImportError:
+ flight = None
+ FlightClient, FlightServerBase = object, object
+ ServerAuthHandler, ClientAuthHandler = object, object
+ ServerMiddleware, ServerMiddlewareFactory = object, object
+ ClientMiddleware, ClientMiddlewareFactory = object, object
+
+# Marks all of the tests in this module.
+# Deselect them with pytest ... -m 'not flight'
+pytestmark = pytest.mark.flight
+
+
+def test_import():
+ # So we see the ImportError somewhere
+ import pyarrow.flight # noqa
+
+
+def resource_root():
+ """Get the path to the test resources directory."""
+ if not os.environ.get("ARROW_TEST_DATA"):
+ raise RuntimeError("Test resources not found; set "
+ "ARROW_TEST_DATA to <repo root>/testing/data")
+ return pathlib.Path(os.environ["ARROW_TEST_DATA"]) / "flight"
+
+
+def read_flight_resource(path):
+ """Get the contents of a test resource file."""
+ root = resource_root()
+ if not root:
+ return None
+ try:
+ with (root / path).open("rb") as f:
+ return f.read()
+ except FileNotFoundError:
+ raise RuntimeError(
+ "Test resource {} not found; did you initialize the "
+ "test resource submodule?\n{}".format(root / path,
+ traceback.format_exc()))
+
+
+def example_tls_certs():
+ """Get the paths to test TLS certificates."""
+ return {
+ "root_cert": read_flight_resource("root-ca.pem"),
+ "certificates": [
+ flight.CertKeyPair(
+ cert=read_flight_resource("cert0.pem"),
+ key=read_flight_resource("cert0.key"),
+ ),
+ flight.CertKeyPair(
+ cert=read_flight_resource("cert1.pem"),
+ key=read_flight_resource("cert1.key"),
+ ),
+ ]
+ }
+
+
+def simple_ints_table():
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ return pa.Table.from_arrays(data, names=['some_ints'])
+
+
+def simple_dicts_table():
+ dict_values = pa.array(["foo", "baz", "quux"], type=pa.utf8())
+ data = [
+ pa.chunked_array([
+ pa.DictionaryArray.from_arrays([1, 0, None], dict_values),
+ pa.DictionaryArray.from_arrays([2, 1], dict_values)
+ ])
+ ]
+ return pa.Table.from_arrays(data, names=['some_dicts'])
+
+
+class ConstantFlightServer(FlightServerBase):
+ """A Flight server that always returns the same data.
+
+ See ARROW-4796: this server implementation will segfault if Flight
+ does not properly hold a reference to the Table object.
+ """
+
+ CRITERIA = b"the expected criteria"
+
+ def __init__(self, location=None, options=None, **kwargs):
+ super().__init__(location, **kwargs)
+ # Ticket -> Table
+ self.table_factories = {
+ b'ints': simple_ints_table,
+ b'dicts': simple_dicts_table,
+ }
+ self.options = options
+
+ def list_flights(self, context, criteria):
+ if criteria == self.CRITERIA:
+ yield flight.FlightInfo(
+ pa.schema([]),
+ flight.FlightDescriptor.for_path('/foo'),
+ [],
+ -1, -1
+ )
+
+ def do_get(self, context, ticket):
+ # Return a fresh table, so that Flight is the only one keeping a
+ # reference.
+ table = self.table_factories[ticket.ticket]()
+ return flight.RecordBatchStream(table, options=self.options)
+
+
+class MetadataFlightServer(FlightServerBase):
+ """A Flight server that numbers incoming/outgoing data."""
+
+ def __init__(self, options=None, **kwargs):
+ super().__init__(**kwargs)
+ self.options = options
+
+ def do_get(self, context, ticket):
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+ return flight.GeneratorStream(
+ table.schema,
+ self.number_batches(table),
+ options=self.options)
+
+ def do_put(self, context, descriptor, reader, writer):
+ counter = 0
+ expected_data = [-10, -5, 0, 5, 10]
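+ # Each uploaded batch arrives with app metadata holding the batch
+ # index as a little-endian int32; verify it matches our own counter
+ # and echo the index back to the client.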
+ while True:
+ try:
+ batch, buf = reader.read_chunk()
+ assert batch.equals(pa.RecordBatch.from_arrays(
+ [pa.array([expected_data[counter]])],
+ ['a']
+ ))
+ assert buf is not None
+ client_counter, = struct.unpack('<i', buf.to_pybytes())
+ assert counter == client_counter
+ writer.write(struct.pack('<i', counter))
+ counter += 1
+ except StopIteration:
+ return
+
+ @staticmethod
+ def number_batches(table):
+ for idx, batch in enumerate(table.to_batches()):
+ buf = struct.pack('<i', idx)
+ yield batch, buf
+
+
+class EchoFlightServer(FlightServerBase):
+ """A Flight server that returns the last data uploaded."""
+
+ def __init__(self, location=None, expected_schema=None, **kwargs):
+ super().__init__(location, **kwargs)
+ self.last_message = None
+ self.expected_schema = expected_schema
+
+ def do_get(self, context, ticket):
+ return flight.RecordBatchStream(self.last_message)
+
+ def do_put(self, context, descriptor, reader, writer):
+ if self.expected_schema:
+ assert self.expected_schema == reader.schema
+ self.last_message = reader.read_all()
+
+ def do_exchange(self, context, descriptor, reader, writer):
+ for chunk in reader:
+ pass
+
+
+class EchoStreamFlightServer(EchoFlightServer):
+ """An echo server that streams individual record batches."""
+
+ def do_get(self, context, ticket):
+ return flight.GeneratorStream(
+ self.last_message.schema,
+ self.last_message.to_batches(max_chunksize=1024))
+
+ def list_actions(self, context):
+ return []
+
+ def do_action(self, context, action):
+ if action.type == "who-am-i":
+ return [context.peer_identity(), context.peer().encode("utf-8")]
+ raise NotImplementedError
+
+
+class GetInfoFlightServer(FlightServerBase):
+ """A Flight server that tests GetFlightInfo."""
+
+ def get_flight_info(self, context, descriptor):
+ return flight.FlightInfo(
+ pa.schema([('a', pa.int32())]),
+ descriptor,
+ [
+ flight.FlightEndpoint(b'', ['grpc://test']),
+ flight.FlightEndpoint(
+ b'',
+ [flight.Location.for_grpc_tcp('localhost', 5005)],
+ ),
+ ],
+ -1,
+ -1,
+ )
+
+ def get_schema(self, context, descriptor):
+ info = self.get_flight_info(context, descriptor)
+ return flight.SchemaResult(info.schema)
+
+
+class ListActionsFlightServer(FlightServerBase):
+ """A Flight server that tests ListActions."""
+
+ @classmethod
+ def expected_actions(cls):
+ return [
+ ("action-1", "description"),
+ ("action-2", ""),
+ flight.ActionType("action-3", "more detail"),
+ ]
+
+ def list_actions(self, context):
+ yield from self.expected_actions()
+
+
+class ListActionsErrorFlightServer(FlightServerBase):
+ """A Flight server that tests ListActions."""
+
+ def list_actions(self, context):
+ yield ("action-1", "")
+ yield "foo"
+
+
+class CheckTicketFlightServer(FlightServerBase):
+ """A Flight server that compares the given ticket to an expected value."""
+
+ def __init__(self, expected_ticket, location=None, **kwargs):
+ super().__init__(location, **kwargs)
+ self.expected_ticket = expected_ticket
+
+ def do_get(self, context, ticket):
+ assert self.expected_ticket == ticket.ticket
+ data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
+ table = pa.Table.from_arrays(data1, names=['a'])
+ return flight.RecordBatchStream(table)
+
+ def do_put(self, context, descriptor, reader, writer):
+ self.last_message = reader.read_all()
+
+
+class InvalidStreamFlightServer(FlightServerBase):
+ """A Flight server that tries to return messages with differing schemas."""
+
+ schema = pa.schema([('a', pa.int32())])
+
+ def do_get(self, context, ticket):
+ data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
+ data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0], type=pa.float64())]
+ assert data1[0].type != data2[0].type
+ table1 = pa.Table.from_arrays(data1, names=['a'])
+ table2 = pa.Table.from_arrays(data2, names=['a'])
+ assert table1.schema == self.schema
+
+ return flight.GeneratorStream(self.schema, [table1, table2])
+
+
+class NeverSendsDataFlightServer(FlightServerBase):
+ """A Flight server that never actually yields data."""
+
+ schema = pa.schema([('a', pa.int32())])
+
+ def do_get(self, context, ticket):
+ if ticket.ticket == b'yield_data':
+ # Check that the server handler will ignore empty tables
+ # up to a certain extent
+ data = [
+ self.schema.empty_table(),
+ self.schema.empty_table(),
+ pa.RecordBatch.from_arrays([range(5)], schema=self.schema),
+ ]
+ return flight.GeneratorStream(self.schema, data)
+ return flight.GeneratorStream(
+ self.schema, itertools.repeat(self.schema.empty_table()))
+
+
+class SlowFlightServer(FlightServerBase):
+ """A Flight server that delays its responses to test timeouts."""
+
+ def do_get(self, context, ticket):
+ return flight.GeneratorStream(pa.schema([('a', pa.int32())]),
+ self.slow_stream())
+
+ def do_action(self, context, action):
+ time.sleep(0.5)
+ return []
+
+ @staticmethod
+ def slow_stream():
+ data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
+ yield pa.Table.from_arrays(data1, names=['a'])
+ # The second message should never get sent; the client should
+ # cancel before we send this
+ time.sleep(10)
+ yield pa.Table.from_arrays(data1, names=['a'])
+
+
+class ErrorFlightServer(FlightServerBase):
+ """A Flight server that uses all the Flight-specific errors."""
+
+ def do_action(self, context, action):
+ if action.type == "internal":
+ raise flight.FlightInternalError("foo")
+ elif action.type == "timedout":
+ raise flight.FlightTimedOutError("foo")
+ elif action.type == "cancel":
+ raise flight.FlightCancelledError("foo")
+ elif action.type == "unauthenticated":
+ raise flight.FlightUnauthenticatedError("foo")
+ elif action.type == "unauthorized":
+ raise flight.FlightUnauthorizedError("foo")
+ elif action.type == "protobuf":
+ err_msg = b'this is an error message'
+ raise flight.FlightUnauthorizedError("foo", err_msg)
+ raise NotImplementedError
+
+ def list_flights(self, context, criteria):
+ yield flight.FlightInfo(
+ pa.schema([]),
+ flight.FlightDescriptor.for_path('/foo'),
+ [],
+ -1, -1
+ )
+ raise flight.FlightInternalError("foo")
+
+ def do_put(self, context, descriptor, reader, writer):
+ if descriptor.command == b"internal":
+ raise flight.FlightInternalError("foo")
+ elif descriptor.command == b"timedout":
+ raise flight.FlightTimedOutError("foo")
+ elif descriptor.command == b"cancel":
+ raise flight.FlightCancelledError("foo")
+ elif descriptor.command == b"unauthenticated":
+ raise flight.FlightUnauthenticatedError("foo")
+ elif descriptor.command == b"unauthorized":
+ raise flight.FlightUnauthorizedError("foo")
+ elif descriptor.command == b"protobuf":
+ err_msg = b'this is an error message'
+ raise flight.FlightUnauthorizedError("foo", err_msg)
+
+
+class ExchangeFlightServer(FlightServerBase):
+ """A server for testing DoExchange."""
+
+ def __init__(self, options=None, **kwargs):
+ super().__init__(**kwargs)
+ self.options = options
+
+ def do_exchange(self, context, descriptor, reader, writer):
+ if descriptor.descriptor_type != flight.DescriptorType.CMD:
+ raise pa.ArrowInvalid("Must provide a command descriptor")
+ elif descriptor.command == b"echo":
+ return self.exchange_echo(context, reader, writer)
+ elif descriptor.command == b"get":
+ return self.exchange_do_get(context, reader, writer)
+ elif descriptor.command == b"put":
+ return self.exchange_do_put(context, reader, writer)
+ elif descriptor.command == b"transform":
+ return self.exchange_transform(context, reader, writer)
+ else:
+ raise pa.ArrowInvalid(
+ "Unknown command: {}".format(descriptor.command))
+
+ def exchange_do_get(self, context, reader, writer):
+ """Emulate DoGet with DoExchange."""
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024))
+ ], names=["a"])
+ writer.begin(data.schema)
+ writer.write_table(data)
+
+ def exchange_do_put(self, context, reader, writer):
+ """Emulate DoPut with DoExchange."""
+ num_batches = 0
+ for chunk in reader:
+ if not chunk.data:
+ raise pa.ArrowInvalid("All chunks must have data.")
+ num_batches += 1
+ writer.write_metadata(str(num_batches).encode("utf-8"))
+
+ def exchange_echo(self, context, reader, writer):
+ """Run a simple echo server."""
+ started = False
+ for chunk in reader:
+ if not started and chunk.data:
+ writer.begin(chunk.data.schema, options=self.options)
+ started = True
+ if chunk.app_metadata and chunk.data:
+ writer.write_with_metadata(chunk.data, chunk.app_metadata)
+ elif chunk.app_metadata:
+ writer.write_metadata(chunk.app_metadata)
+ elif chunk.data:
+ writer.write_batch(chunk.data)
+ else:
+ assert False, "Should not happen"
+
+ def exchange_transform(self, context, reader, writer):
+ """Sum rows in an uploaded table."""
+ for field in reader.schema:
+ if not pa.types.is_integer(field.type):
+ raise pa.ArrowInvalid("Invalid field: " + repr(field))
+ table = reader.read_all()
+ sums = [0] * table.num_rows
+ for column in table:
+ for row, value in enumerate(column):
+ sums[row] += value.as_py()
+ result = pa.Table.from_arrays([pa.array(sums)], names=["sum"])
+ writer.begin(result.schema)
+ writer.write_table(result)
+
+
+class HttpBasicServerAuthHandler(ServerAuthHandler):
+ """An example implementation of HTTP basic authentication."""
+
+ def __init__(self, creds):
+ super().__init__()
+ self.creds = creds
+
+ def authenticate(self, outgoing, incoming):
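+ # The client sends a serialized BasicAuth message; validate the
+ # username/password pair and hand back the username as the token.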
+ buf = incoming.read()
+ auth = flight.BasicAuth.deserialize(buf)
+ if auth.username not in self.creds:
+ raise flight.FlightUnauthenticatedError("unknown user")
+ if self.creds[auth.username] != auth.password:
+ raise flight.FlightUnauthenticatedError("wrong password")
+ outgoing.write(tobytes(auth.username))
+
+ def is_valid(self, token):
+ if not token:
+ raise flight.FlightUnauthenticatedError("token not provided")
+ if token not in self.creds:
+ raise flight.FlightUnauthenticatedError("unknown user")
+ return token
+
+
+class HttpBasicClientAuthHandler(ClientAuthHandler):
+ """An example implementation of HTTP basic authentication."""
+
+ def __init__(self, username, password):
+ super().__init__()
+ self.basic_auth = flight.BasicAuth(username, password)
+ self.token = None
+
+ def authenticate(self, outgoing, incoming):
+ auth = self.basic_auth.serialize()
+ outgoing.write(auth)
+ self.token = incoming.read()
+
+ def get_token(self):
+ return self.token
+
+
+class TokenServerAuthHandler(ServerAuthHandler):
+ """An example implementation of authentication via handshake."""
+
+ def __init__(self, creds):
+ super().__init__()
+ self.creds = creds
+
+ def authenticate(self, outgoing, incoming):
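+ # Handshake: read the username and password as two separate
+ # messages, then issue a base64-encoded token of the form
+ # b'secret:<username>'.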
+ username = incoming.read()
+ password = incoming.read()
+ if username in self.creds and self.creds[username] == password:
+ outgoing.write(base64.b64encode(b'secret:' + username))
+ else:
+ raise flight.FlightUnauthenticatedError(
+ "invalid username/password")
+
+ def is_valid(self, token):
+ token = base64.b64decode(token)
+ if not token.startswith(b'secret:'):
+ raise flight.FlightUnauthenticatedError("invalid token")
+ return token[7:]
+
+
+class TokenClientAuthHandler(ClientAuthHandler):
+ """An example implementation of authentication via handshake."""
+
+ def __init__(self, username, password):
+ super().__init__()
+ self.username = username
+ self.password = password
+ self.token = b''
+
+ def authenticate(self, outgoing, incoming):
+ outgoing.write(self.username)
+ outgoing.write(self.password)
+ self.token = incoming.read()
+
+ def get_token(self):
+ return self.token
+
+
+class NoopAuthHandler(ServerAuthHandler):
+ """A no-op auth handler."""
+
+ def authenticate(self, outgoing, incoming):
+ """Do nothing."""
+
+ def is_valid(self, token):
+ """
+ Returning an empty string.
+ Returning None causes Type error.
+ """
+ return ""
+
+
+def case_insensitive_header_lookup(headers, lookup_key):
+ """Lookup the value of given key in the given headers.
+ The key lookup is case insensitive.
+ """
+ for key in headers:
+ if key.lower() == lookup_key.lower():
+ return headers.get(key)
+
+
+class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory):
+ """ClientMiddlewareFactory that creates ClientAuthHeaderMiddleware."""
+
+ def __init__(self):
+ self.call_credential = []
+
+ def start_call(self, info):
+ return ClientHeaderAuthMiddleware(self)
+
+ def set_call_credential(self, call_credential):
+ self.call_credential = call_credential
+
+
+class ClientHeaderAuthMiddleware(ClientMiddleware):
+ """
+ ClientMiddleware that extracts the authorization header
+ from the server.
+
+ This is an example of a ClientMiddleware that can extract
+ the bearer-token authorization header from a server that
+ uses HTTP header authentication.
+
+ Parameters
+ ----------
+ factory : ClientHeaderAuthMiddlewareFactory
+ This factory is used to set call credentials if an
+ authorization header is found in the headers from the server.
+ """
+
+ def __init__(self, factory):
+ self.factory = factory
+
+ def received_headers(self, headers):
+ auth_header = case_insensitive_header_lookup(headers, 'Authorization')
+ self.factory.set_call_credential([
+ b'authorization',
+ auth_header[0].encode("utf-8")])
+
+
+class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory):
+ """Validates incoming username and password."""
+
+ def start_call(self, info, headers):
+ auth_header = case_insensitive_header_lookup(
+ headers,
+ 'Authorization'
+ )
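+ # The header value is either "Basic <base64 user:pass>" or
+ # "Bearer <token>"; split the scheme from the credentials.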
+ values = auth_header[0].split(' ')
+ token = ''
+ error_message = 'Invalid credentials'
+
+ if values[0] == 'Basic':
+ decoded = base64.b64decode(values[1])
+ pair = decoded.decode("utf-8").split(':')
+ if not (pair[0] == 'test' and pair[1] == 'password'):
+ raise flight.FlightUnauthenticatedError(error_message)
+ token = 'token1234'
+ elif values[0] == 'Bearer':
+ token = values[1]
+ if not token == 'token1234':
+ raise flight.FlightUnauthenticatedError(error_message)
+ else:
+ raise flight.FlightUnauthenticatedError(error_message)
+
+ return HeaderAuthServerMiddleware(token)
+
+
+class HeaderAuthServerMiddleware(ServerMiddleware):
+ """A ServerMiddleware that transports incoming username and passowrd."""
+
+ def __init__(self, token):
+ self.token = token
+
+ def sending_headers(self):
+ return {'authorization': 'Bearer ' + self.token}
+
+
+class HeaderAuthFlightServer(FlightServerBase):
+ """A Flight server that tests with basic token authentication. """
+
+ def do_action(self, context, action):
+ middleware = context.get_middleware("auth")
+ if middleware:
+ auth_header = case_insensitive_header_lookup(
+ middleware.sending_headers(), 'Authorization')
+ values = auth_header.split(' ')
+ return [values[1].encode("utf-8")]
+ raise flight.FlightUnauthenticatedError(
+ 'No token auth middleware found.')
+
+
+class ArbitraryHeadersServerMiddlewareFactory(ServerMiddlewareFactory):
+ """A ServerMiddlewareFactory that transports arbitrary headers."""
+
+ def start_call(self, info, headers):
+ return ArbitraryHeadersServerMiddleware(headers)
+
+
+class ArbitraryHeadersServerMiddleware(ServerMiddleware):
+ """A ServerMiddleware that transports arbitrary headers."""
+
+ def __init__(self, incoming):
+ self.incoming = incoming
+
+ def sending_headers(self):
+ return self.incoming
+
+
+class ArbitraryHeadersFlightServer(FlightServerBase):
+ """A Flight server that tests multiple arbitrary headers."""
+
+ def do_action(self, context, action):
+ middleware = context.get_middleware("arbitrary-headers")
+ if middleware:
+ headers = middleware.sending_headers()
+ header_1 = case_insensitive_header_lookup(
+ headers,
+ 'test-header-1'
+ )
+ header_2 = case_insensitive_header_lookup(
+ headers,
+ 'test-header-2'
+ )
+ value1 = header_1[0].encode("utf-8")
+ value2 = header_2[0].encode("utf-8")
+ return [value1, value2]
+ raise flight.FlightServerError("No headers middleware found")
+
+
+class HeaderServerMiddleware(ServerMiddleware):
+ """Expose a per-call value to the RPC method body."""
+
+ def __init__(self, special_value):
+ self.special_value = special_value
+
+
+class HeaderServerMiddlewareFactory(ServerMiddlewareFactory):
+ """Expose a per-call hard-coded value to the RPC method body."""
+
+ def start_call(self, info, headers):
+ return HeaderServerMiddleware("right value")
+
+
+class HeaderFlightServer(FlightServerBase):
+ """Echo back the per-call hard-coded value."""
+
+ def do_action(self, context, action):
+ middleware = context.get_middleware("test")
+ if middleware:
+ return [middleware.special_value.encode()]
+ return [b""]
+
+
+class MultiHeaderFlightServer(FlightServerBase):
+ """Test sending/receiving multiple (binary-valued) headers."""
+
+ def do_action(self, context, action):
+ middleware = context.get_middleware("test")
+ headers = repr(middleware.client_headers).encode("utf-8")
+ return [headers]
+
+
+class SelectiveAuthServerMiddlewareFactory(ServerMiddlewareFactory):
+ """Deny access to certain methods based on a header."""
+
+ def start_call(self, info, headers):
+ if info.method == flight.FlightMethod.LIST_ACTIONS:
+ # No auth needed
+ return
+
+ token = headers.get("x-auth-token")
+ if not token:
+ raise flight.FlightUnauthenticatedError("No token")
+
+ token = token[0]
+ if token != "password":
+ raise flight.FlightUnauthenticatedError("Invalid token")
+
+ return HeaderServerMiddleware(token)
+
+
+class SelectiveAuthClientMiddlewareFactory(ClientMiddlewareFactory):
+ def start_call(self, info):
+ return SelectiveAuthClientMiddleware()
+
+
+class SelectiveAuthClientMiddleware(ClientMiddleware):
+ def sending_headers(self):
+ return {
+ "x-auth-token": "password",
+ }
+
+
+class RecordingServerMiddlewareFactory(ServerMiddlewareFactory):
+ """Record what methods were called."""
+
+ def __init__(self):
+ super().__init__()
+ self.methods = []
+
+ def start_call(self, info, headers):
+ self.methods.append(info.method)
+ return None
+
+
+class RecordingClientMiddlewareFactory(ClientMiddlewareFactory):
+ """Record what methods were called."""
+
+ def __init__(self):
+ super().__init__()
+ self.methods = []
+
+ def start_call(self, info):
+ self.methods.append(info.method)
+ return None
+
+
+class MultiHeaderClientMiddlewareFactory(ClientMiddlewareFactory):
+ """Test sending/receiving multiple (binary-valued) headers."""
+
+ def __init__(self):
+ # Read in test_middleware_multi_header below.
+ # The middleware instance will update this value.
+ self.last_headers = {}
+
+ def start_call(self, info):
+ return MultiHeaderClientMiddleware(self)
+
+
+class MultiHeaderClientMiddleware(ClientMiddleware):
+ """Test sending/receiving multiple (binary-valued) headers."""
+
+ EXPECTED = {
+ "x-text": ["foo", "bar"],
+ "x-binary-bin": [b"\x00", b"\x01"],
+ }
+
+ def __init__(self, factory):
+ self.factory = factory
+
+ def sending_headers(self):
+ return self.EXPECTED
+
+ def received_headers(self, headers):
+ # Let the test code know what the last set of headers we
+ # received were.
+ self.factory.last_headers = headers
+
+
+class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory):
+ """Test sending/receiving multiple (binary-valued) headers."""
+
+ def start_call(self, info, headers):
+ return MultiHeaderServerMiddleware(headers)
+
+
+class MultiHeaderServerMiddleware(ServerMiddleware):
+ """Test sending/receiving multiple (binary-valued) headers."""
+
+ def __init__(self, client_headers):
+ self.client_headers = client_headers
+
+ def sending_headers(self):
+ return MultiHeaderClientMiddleware.EXPECTED
+
+
+class LargeMetadataFlightServer(FlightServerBase):
+ """Regression test for ARROW-13253."""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
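+ # Metadata payload just over 2 GiB, large enough to overflow a
+ # signed 32-bit length (see ARROW-13253).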
+ self._metadata = b' ' * (2 ** 31 + 1)
+
+ def do_get(self, context, ticket):
+ schema = pa.schema([('a', pa.int64())])
+ return flight.GeneratorStream(schema, [
+ (pa.record_batch([[1]], schema=schema), self._metadata),
+ ])
+
+ def do_exchange(self, context, descriptor, reader, writer):
+ writer.write_metadata(self._metadata)
+
+
+def test_flight_server_location_argument():
+ locations = [
+ None,
+ 'grpc://localhost:0',
+ ('localhost', find_free_port()),
+ ]
+ for location in locations:
+ with FlightServerBase(location) as server:
+ assert isinstance(server, FlightServerBase)
+
+
+def test_server_exit_reraises_exception():
+ with pytest.raises(ValueError):
+ with FlightServerBase():
+ raise ValueError()
+
+
+@pytest.mark.slow
+def test_client_wait_for_available():
+ location = ('localhost', find_free_port())
+ server = None
+
+ def serve():
+ nonlocal server
+ time.sleep(0.5)
+ server = FlightServerBase(location)
+ server.serve()
+
+ client = FlightClient(location)
+ thread = threading.Thread(target=serve, daemon=True)
+ thread.start()
+
+ started = time.time()
+ client.wait_for_available(timeout=5)
+ elapsed = time.time() - started
+ assert elapsed >= 0.5
+
+
+def test_flight_list_flights():
+ """Try a simple list_flights call."""
+ with ConstantFlightServer() as server:
+ client = flight.connect(('localhost', server.port))
+ assert list(client.list_flights()) == []
+ flights = client.list_flights(ConstantFlightServer.CRITERIA)
+ assert len(list(flights)) == 1
+
+
+def test_flight_do_get_ints():
+ """Try a simple do_get call."""
+ table = simple_ints_table()
+
+ with ConstantFlightServer() as server:
+ client = flight.connect(('localhost', server.port))
+ data = client.do_get(flight.Ticket(b'ints')).read_all()
+ assert data.equals(table)
+
+ options = pa.ipc.IpcWriteOptions(
+ metadata_version=pa.ipc.MetadataVersion.V4)
+ with ConstantFlightServer(options=options) as server:
+ client = flight.connect(('localhost', server.port))
+ data = client.do_get(flight.Ticket(b'ints')).read_all()
+ assert data.equals(table)
+
+ # Also test via RecordBatchReader interface
+ data = client.do_get(flight.Ticket(b'ints')).to_reader().read_all()
+ assert data.equals(table)
+
+ with pytest.raises(flight.FlightServerError,
+ match="expected IpcWriteOptions, got <class 'int'>"):
+ with ConstantFlightServer(options=42) as server:
+ client = flight.connect(('localhost', server.port))
+ data = client.do_get(flight.Ticket(b'ints')).read_all()
+
+
+@pytest.mark.pandas
+def test_do_get_ints_pandas():
+ """Try a simple do_get call."""
+ table = simple_ints_table()
+
+ with ConstantFlightServer() as server:
+ client = flight.connect(('localhost', server.port))
+ data = client.do_get(flight.Ticket(b'ints')).read_pandas()
+ assert list(data['some_ints']) == table.column(0).to_pylist()
+
+
+def test_flight_do_get_dicts():
+ table = simple_dicts_table()
+
+ with ConstantFlightServer() as server:
+ client = flight.connect(('localhost', server.port))
+ data = client.do_get(flight.Ticket(b'dicts')).read_all()
+ assert data.equals(table)
+
+
+def test_flight_do_get_ticket():
+ """Make sure Tickets get passed to the server."""
+ data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
+ table = pa.Table.from_arrays(data1, names=['a'])
+ with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server:
+ client = flight.connect(('localhost', server.port))
+ data = client.do_get(flight.Ticket(b'the-ticket')).read_all()
+ assert data.equals(table)
+
+
+def test_flight_get_info():
+ """Make sure FlightEndpoint accepts string and object URIs."""
+ with GetInfoFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ info = client.get_flight_info(flight.FlightDescriptor.for_command(b''))
+ assert info.total_records == -1
+ assert info.total_bytes == -1
+ assert info.schema == pa.schema([('a', pa.int32())])
+ assert len(info.endpoints) == 2
+ assert len(info.endpoints[0].locations) == 1
+ assert info.endpoints[0].locations[0] == flight.Location('grpc://test')
+ assert info.endpoints[1].locations[0] == \
+ flight.Location.for_grpc_tcp('localhost', 5005)
+
+
+def test_flight_get_schema():
+ """Make sure GetSchema returns correct schema."""
+ with GetInfoFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ info = client.get_schema(flight.FlightDescriptor.for_command(b''))
+ assert info.schema == pa.schema([('a', pa.int32())])
+
+
+def test_list_actions():
+ """Make sure the return type of ListActions is validated."""
+ # ARROW-6392
+ with ListActionsErrorFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ with pytest.raises(
+ flight.FlightServerError,
+ match=("Results of list_actions must be "
+ "ActionType or tuple")
+ ):
+ list(client.list_actions())
+
+ with ListActionsFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ assert list(client.list_actions()) == \
+ ListActionsFlightServer.expected_actions()
+
+
+class ConvenienceServer(FlightServerBase):
+ """
+ Server for testing various implementation conveniences (auto-boxing, etc.)
+ """
+
+ @property
+ def simple_action_results(self):
+ return [b'foo', b'bar', b'baz']
+
+ def do_action(self, context, action):
+ if action.type == 'simple-action':
+ return self.simple_action_results
+ elif action.type == 'echo':
+ return [action.body]
+ elif action.type == 'bad-action':
+ return ['foo']
+ elif action.type == 'arrow-exception':
+ raise pa.ArrowMemoryError()
+
+
+def test_do_action_result_convenience():
+ with ConvenienceServer() as server:
+ client = FlightClient(('localhost', server.port))
+
+ # do_action as action type without body
+ results = [x.body for x in client.do_action('simple-action')]
+ assert results == server.simple_action_results
+
+ # do_action with tuple of type and body
+ body = b'the-body'
+ results = [x.body for x in client.do_action(('echo', body))]
+ assert results == [body]
+
+
+def test_nicer_server_exceptions():
+ with ConvenienceServer() as server:
+ client = FlightClient(('localhost', server.port))
+ with pytest.raises(flight.FlightServerError,
+ match="a bytes-like object is required"):
+ list(client.do_action('bad-action'))
+ # While Flight/C++ sends across the original status code, it
+ # doesn't get mapped to the equivalent code here, since we
+ # want to be able to distinguish between client- and server-
+ # side errors.
+ with pytest.raises(flight.FlightServerError,
+ match="ArrowMemoryError"):
+ list(client.do_action('arrow-exception'))
+
+
+def test_get_port():
+ """Make sure port() works."""
+ server = GetInfoFlightServer("grpc://localhost:0")
+ try:
+ assert server.port > 0
+ finally:
+ server.shutdown()
+
+
+@pytest.mark.skipif(os.name == 'nt',
+ reason="Unix sockets can't be tested on Windows")
+def test_flight_domain_socket():
+ """Try a simple do_get call over a Unix domain socket."""
+ with tempfile.NamedTemporaryFile() as sock:
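+ # We only need a unique path for the Unix domain socket, so close
+ # (and delete) the temporary file right away.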
+ sock.close()
+ location = flight.Location.for_grpc_unix(sock.name)
+ with ConstantFlightServer(location=location):
+ client = FlightClient(location)
+
+ reader = client.do_get(flight.Ticket(b'ints'))
+ table = simple_ints_table()
+ assert reader.schema.equals(table.schema)
+ data = reader.read_all()
+ assert data.equals(table)
+
+ reader = client.do_get(flight.Ticket(b'dicts'))
+ table = simple_dicts_table()
+ assert reader.schema.equals(table.schema)
+ data = reader.read_all()
+ assert data.equals(table)
+
+
+@pytest.mark.slow
+def test_flight_large_message():
+ """Try sending/receiving a large message via Flight.
+
+ See ARROW-4421: by default, gRPC won't allow us to send messages >
+ 4MiB in size.
+ """
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024 * 1024))
+ ], names=['a'])
+
+ with EchoFlightServer(expected_schema=data.schema) as server:
+ client = FlightClient(('localhost', server.port))
+ writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
+ # Write a single giant chunk
+ writer.write_table(data, 10 * 1024 * 1024)
+ writer.close()
+ result = client.do_get(flight.Ticket(b'')).read_all()
+ assert result.equals(data)
+
+
+def test_flight_generator_stream():
+ """Try downloading a flight of RecordBatches in a GeneratorStream."""
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024))
+ ], names=['a'])
+
+ with EchoStreamFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+ data.schema)
+ writer.write_table(data)
+ writer.close()
+ result = client.do_get(flight.Ticket(b'')).read_all()
+ assert result.equals(data)
+
+
+def test_flight_invalid_generator_stream():
+ """Try streaming data with mismatched schemas."""
+ with InvalidStreamFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ with pytest.raises(pa.ArrowException):
+ client.do_get(flight.Ticket(b'')).read_all()
+
+
+def test_timeout_fires():
+ """Make sure timeouts fire on slow requests."""
+ # Do this in a separate thread so that if it fails, we don't hang
+ # the entire test process
+ with SlowFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ action = flight.Action("", b"")
+ options = flight.FlightCallOptions(timeout=0.2)
+ # gRPC error messages change based on version, so don't look
+ # for a particular error
+ with pytest.raises(flight.FlightTimedOutError):
+ list(client.do_action(action, options=options))
+
+
+def test_timeout_passes():
+ """Make sure timeouts do not fire on fast requests."""
+ with ConstantFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ options = flight.FlightCallOptions(timeout=5.0)
+ client.do_get(flight.Ticket(b'ints'), options=options).read_all()
+
+
+basic_auth_handler = HttpBasicServerAuthHandler(creds={
+ b"test": b"p4ssw0rd",
+})
+
+token_auth_handler = TokenServerAuthHandler(creds={
+ b"test": b"p4ssw0rd",
+})
+
+
+@pytest.mark.slow
+def test_http_basic_unauth():
+ """Test that auth fails when not authenticated."""
+ with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
+ client = FlightClient(('localhost', server.port))
+ action = flight.Action("who-am-i", b"")
+ with pytest.raises(flight.FlightUnauthenticatedError,
+ match=".*unauthenticated.*"):
+ list(client.do_action(action))
+
+
+@pytest.mark.skipif(os.name == 'nt',
+ reason="ARROW-10013: gRPC on Windows corrupts peer()")
+def test_http_basic_auth():
+ """Test a Python implementation of HTTP basic authentication."""
+ with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
+ client = FlightClient(('localhost', server.port))
+ action = flight.Action("who-am-i", b"")
+ client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd'))
+ results = client.do_action(action)
+ identity = next(results)
+ assert identity.body.to_pybytes() == b'test'
+ peer_address = next(results)
+ assert peer_address.body.to_pybytes() != b''
+
+
+def test_http_basic_auth_invalid_password():
+ """Test that auth fails with the wrong password."""
+ with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
+ client = FlightClient(('localhost', server.port))
+ action = flight.Action("who-am-i", b"")
+ with pytest.raises(flight.FlightUnauthenticatedError,
+ match=".*wrong password.*"):
+ client.authenticate(HttpBasicClientAuthHandler('test', 'wrong'))
+ next(client.do_action(action))
+
+
+def test_token_auth():
+ """Test an auth mechanism that uses a handshake."""
+ with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
+ client = FlightClient(('localhost', server.port))
+ action = flight.Action("who-am-i", b"")
+ client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd'))
+ identity = next(client.do_action(action))
+ assert identity.body.to_pybytes() == b'test'
+
+
+def test_token_auth_invalid():
+ """Test an auth mechanism that uses a handshake."""
+ with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
+ client = FlightClient(('localhost', server.port))
+ with pytest.raises(flight.FlightUnauthenticatedError):
+ client.authenticate(TokenClientAuthHandler('test', 'wrong'))
+
+
+header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory()
+no_op_auth_handler = NoopAuthHandler()
+
+
+def test_authenticate_basic_token():
+ """Test authenticate_basic_token with bearer token and auth headers."""
+ with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+ "auth": HeaderAuthServerMiddlewareFactory()
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ token_pair = client.authenticate_basic_token(b'test', b'password')
+ assert token_pair[0] == b'authorization'
+ assert token_pair[1] == b'Bearer token1234'
+
+
+def test_authenticate_basic_token_invalid_password():
+ """Test authenticate_basic_token with an invalid password."""
+ with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+ "auth": HeaderAuthServerMiddlewareFactory()
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ with pytest.raises(flight.FlightUnauthenticatedError):
+ client.authenticate_basic_token(b'test', b'badpassword')
+
+
+def test_authenticate_basic_token_and_action():
+ """Test authenticate_basic_token and doAction after authentication."""
+ with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+ "auth": HeaderAuthServerMiddlewareFactory()
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ token_pair = client.authenticate_basic_token(b'test', b'password')
+ assert token_pair[0] == b'authorization'
+ assert token_pair[1] == b'Bearer token1234'
+ options = flight.FlightCallOptions(headers=[token_pair])
+ result = list(client.do_action(
+ action=flight.Action('test-action', b''), options=options))
+ assert result[0].body.to_pybytes() == b'token1234'
+
+
+def test_authenticate_basic_token_with_client_middleware():
+ """Test authenticate_basic_token with client middleware
+ to intercept authorization header returned by the
+ HTTP header auth enabled server.
+ """
+ with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
+ "auth": HeaderAuthServerMiddlewareFactory()
+ }) as server:
+ client_auth_middleware = ClientHeaderAuthMiddlewareFactory()
+ client = FlightClient(
+ ('localhost', server.port),
+ middleware=[client_auth_middleware]
+ )
+ encoded_credentials = base64.b64encode(b'test:password')
+ options = flight.FlightCallOptions(headers=[
+ (b'authorization', b'Basic ' + encoded_credentials)
+ ])
+ result = list(client.do_action(
+ action=flight.Action('test-action', b''), options=options))
+ assert result[0].body.to_pybytes() == b'token1234'
+ assert client_auth_middleware.call_credential[0] == b'authorization'
+ assert client_auth_middleware.call_credential[1] == \
+ b'Bearer ' + b'token1234'
+ result2 = list(client.do_action(
+ action=flight.Action('test-action', b''), options=options))
+ assert result2[0].body.to_pybytes() == b'token1234'
+ assert client_auth_middleware.call_credential[0] == b'authorization'
+ assert client_auth_middleware.call_credential[1] == \
+ b'Bearer ' + b'token1234'
+
+
+def test_arbitrary_headers_in_flight_call_options():
+ """Test passing multiple arbitrary headers to the middleware."""
+ with ArbitraryHeadersFlightServer(
+ auth_handler=no_op_auth_handler,
+ middleware={
+ "auth": HeaderAuthServerMiddlewareFactory(),
+ "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ token_pair = client.authenticate_basic_token(b'test', b'password')
+ assert token_pair[0] == b'authorization'
+ assert token_pair[1] == b'Bearer token1234'
+ options = flight.FlightCallOptions(headers=[
+ token_pair,
+ (b'test-header-1', b'value1'),
+ (b'test-header-2', b'value2')
+ ])
+ result = list(client.do_action(flight.Action(
+ "test-action", b""), options=options))
+ assert result[0].body.to_pybytes() == b'value1'
+ assert result[1].body.to_pybytes() == b'value2'
+
+
+def test_location_invalid():
+ """Test constructing invalid URIs."""
+ with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"):
+ flight.connect("%")
+
+ with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"):
+ ConstantFlightServer("%")
+
+
+def test_location_unknown_scheme():
+ """Test creating locations for unknown schemes."""
+ assert flight.Location("s3://foo").uri == b"s3://foo"
+ assert flight.Location("https://example.com/bar.parquet").uri == \
+ b"https://example.com/bar.parquet"
+
+
+@pytest.mark.slow
+@pytest.mark.requires_testing_data
+def test_tls_fails():
+ """Make sure clients cannot connect when cert verification fails."""
+ certs = example_tls_certs()
+
+ with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
+ # Ensure client doesn't connect when certificate verification
+ # fails (this is a slow test since gRPC does retry a few times)
+ client = FlightClient("grpc+tls://localhost:" + str(s.port))
+
+ # gRPC error messages change based on version, so don't look
+ # for a particular error
+ with pytest.raises(flight.FlightUnavailableError):
+ client.do_get(flight.Ticket(b'ints')).read_all()
+
+
+@pytest.mark.requires_testing_data
+def test_tls_do_get():
+ """Try a simple do_get call over TLS."""
+ table = simple_ints_table()
+ certs = example_tls_certs()
+
+ with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
+ client = FlightClient(('localhost', s.port),
+ tls_root_certs=certs["root_cert"])
+ data = client.do_get(flight.Ticket(b'ints')).read_all()
+ assert data.equals(table)
+
+
+@pytest.mark.requires_testing_data
+def test_tls_disable_server_verification():
+ """Try a simple do_get call over TLS with server verification disabled."""
+ table = simple_ints_table()
+ certs = example_tls_certs()
+
+ with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
+ try:
+ client = FlightClient(('localhost', s.port),
+ disable_server_verification=True)
+ except NotImplementedError:
+ pytest.skip('disable_server_verification feature is not available')
+ data = client.do_get(flight.Ticket(b'ints')).read_all()
+ assert data.equals(table)
+
+
+@pytest.mark.requires_testing_data
+def test_tls_override_hostname():
+ """Check that incorrectly overriding the hostname fails."""
+ certs = example_tls_certs()
+
+ with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
+ client = flight.connect(('localhost', s.port),
+ tls_root_certs=certs["root_cert"],
+ override_hostname="fakehostname")
+ with pytest.raises(flight.FlightUnavailableError):
+ client.do_get(flight.Ticket(b'ints'))
+
+
+def test_flight_do_get_metadata():
+ """Try a simple do_get call with metadata."""
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+
+ batches = []
+ with MetadataFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ reader = client.do_get(flight.Ticket(b''))
+ idx = 0
+ while True:
+ try:
+ batch, metadata = reader.read_chunk()
+ batches.append(batch)
+ server_idx, = struct.unpack('<i', metadata.to_pybytes())
+ assert idx == server_idx
+ idx += 1
+ except StopIteration:
+ break
+ data = pa.Table.from_batches(batches)
+ assert data.equals(table)
+
+
+def test_flight_do_get_metadata_v4():
+ """Try a simple do_get call with V4 metadata version."""
+ table = pa.Table.from_arrays(
+ [pa.array([-10, -5, 0, 5, 10])], names=['a'])
+ options = pa.ipc.IpcWriteOptions(
+ metadata_version=pa.ipc.MetadataVersion.V4)
+ with MetadataFlightServer(options=options) as server:
+ client = FlightClient(('localhost', server.port))
+ reader = client.do_get(flight.Ticket(b''))
+ data = reader.read_all()
+ assert data.equals(table)
+
+
+def test_flight_do_put_metadata():
+ """Try a simple do_put call with metadata."""
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+
+ with MetadataFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ writer, metadata_reader = client.do_put(
+ flight.FlightDescriptor.for_path(''),
+ table.schema)
+ with writer:
+ for idx, batch in enumerate(table.to_batches(max_chunksize=1)):
+ metadata = struct.pack('<i', idx)
+ writer.write_with_metadata(batch, metadata)
+ buf = metadata_reader.read()
+ assert buf is not None
+ server_idx, = struct.unpack('<i', buf.to_pybytes())
+ assert idx == server_idx
+
+
+def test_flight_do_put_limit():
+ """Try a simple do_put call with a size limit."""
+ large_batch = pa.RecordBatch.from_arrays([
+ pa.array(np.ones(768, dtype=np.int64())),
+ ], names=['a'])
+
+ with EchoFlightServer() as server:
+ client = FlightClient(('localhost', server.port),
+ write_size_limit_bytes=4096)
+ writer, metadata_reader = client.do_put(
+ flight.FlightDescriptor.for_path(''),
+ large_batch.schema)
+ with writer:
+ with pytest.raises(flight.FlightWriteSizeExceededError,
+ match="exceeded soft limit") as excinfo:
+ writer.write_batch(large_batch)
+ assert excinfo.value.limit == 4096
+ smaller_batches = [
+ large_batch.slice(0, 384),
+ large_batch.slice(384),
+ ]
+ for batch in smaller_batches:
+ writer.write_batch(batch)
+ expected = pa.Table.from_batches([large_batch])
+ actual = client.do_get(flight.Ticket(b'')).read_all()
+ assert expected == actual
+
+
+@pytest.mark.slow
+def test_cancel_do_get():
+ """Test canceling a DoGet operation on the client side."""
+ with ConstantFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ reader = client.do_get(flight.Ticket(b'ints'))
+ reader.cancel()
+ with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
+ reader.read_chunk()
+
+
+@pytest.mark.slow
+def test_cancel_do_get_threaded():
+ """Test canceling a DoGet operation from another thread."""
+ with SlowFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ reader = client.do_get(flight.Ticket(b'ints'))
+
+ read_first_message = threading.Event()
+ stream_canceled = threading.Event()
+ result_lock = threading.Lock()
+ raised_proper_exception = threading.Event()
+
+ def block_read():
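+ # Read one chunk, signal the main thread, wait for it to cancel
+ # the stream, then expect the next read to fail with
+ # FlightCancelledError.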
+ reader.read_chunk()
+ read_first_message.set()
+ stream_canceled.wait(timeout=5)
+ try:
+ reader.read_chunk()
+ except flight.FlightCancelledError:
+ with result_lock:
+ raised_proper_exception.set()
+
+ thread = threading.Thread(target=block_read, daemon=True)
+ thread.start()
+ read_first_message.wait(timeout=5)
+ reader.cancel()
+ stream_canceled.set()
+ thread.join(timeout=1)
+
+ with result_lock:
+ assert raised_proper_exception.is_set()
+
+
+def test_roundtrip_types():
+ """Make sure serializable types round-trip."""
+ ticket = flight.Ticket("foo")
+ assert ticket == flight.Ticket.deserialize(ticket.serialize())
+
+ desc = flight.FlightDescriptor.for_command("test")
+ assert desc == flight.FlightDescriptor.deserialize(desc.serialize())
+
+ desc = flight.FlightDescriptor.for_path("a", "b", "test.arrow")
+ assert desc == flight.FlightDescriptor.deserialize(desc.serialize())
+
+ info = flight.FlightInfo(
+ pa.schema([('a', pa.int32())]),
+ desc,
+ [
+ flight.FlightEndpoint(b'', ['grpc://test']),
+ flight.FlightEndpoint(
+ b'',
+ [flight.Location.for_grpc_tcp('localhost', 5005)],
+ ),
+ ],
+ -1,
+ -1,
+ )
+ info2 = flight.FlightInfo.deserialize(info.serialize())
+ assert info.schema == info2.schema
+ assert info.descriptor == info2.descriptor
+ assert info.total_bytes == info2.total_bytes
+ assert info.total_records == info2.total_records
+ assert info.endpoints == info2.endpoints
+
+
+def test_roundtrip_errors():
+ """Ensure that Flight errors propagate from server to client."""
+ with ErrorFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+
+ with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
+ list(client.do_action(flight.Action("internal", b"")))
+ with pytest.raises(flight.FlightTimedOutError, match=".*foo.*"):
+ list(client.do_action(flight.Action("timedout", b"")))
+ with pytest.raises(flight.FlightCancelledError, match=".*foo.*"):
+ list(client.do_action(flight.Action("cancel", b"")))
+ with pytest.raises(flight.FlightUnauthenticatedError, match=".*foo.*"):
+ list(client.do_action(flight.Action("unauthenticated", b"")))
+ with pytest.raises(flight.FlightUnauthorizedError, match=".*foo.*"):
+ list(client.do_action(flight.Action("unauthorized", b"")))
+ with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
+ list(client.list_flights())
+
+ data = [pa.array([-10, -5, 0, 5, 10])]
+ table = pa.Table.from_arrays(data, names=['a'])
+
+ exceptions = {
+ 'internal': flight.FlightInternalError,
+ 'timedout': flight.FlightTimedOutError,
+ 'cancel': flight.FlightCancelledError,
+ 'unauthenticated': flight.FlightUnauthenticatedError,
+ 'unauthorized': flight.FlightUnauthorizedError,
+ }
+
+ for command, exception in exceptions.items():
+
+ with pytest.raises(exception, match=".*foo.*"):
+ writer, reader = client.do_put(
+ flight.FlightDescriptor.for_command(command),
+ table.schema)
+ writer.write_table(table)
+ writer.close()
+
+ with pytest.raises(exception, match=".*foo.*"):
+ writer, reader = client.do_put(
+ flight.FlightDescriptor.for_command(command),
+ table.schema)
+ writer.close()
+
+
+def test_do_put_independent_read_write():
+ """Ensure that separate threads can read/write on a DoPut."""
+ # ARROW-6063: previously this would cause gRPC to abort when the
+ # writer was closed (due to simultaneous reads), or would hang
+ # forever.
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+
+ with MetadataFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ writer, metadata_reader = client.do_put(
+ flight.FlightDescriptor.for_path(''),
+ table.schema)
+
+ count = [0]
+
+ def _reader_thread():
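+ # Drain metadata acknowledgements from the server until the
+ # stream ends (read() returns None).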
+ while metadata_reader.read() is not None:
+ count[0] += 1
+
+ thread = threading.Thread(target=_reader_thread)
+ thread.start()
+
+ batches = table.to_batches(max_chunksize=1)
+ with writer:
+ for idx, batch in enumerate(batches):
+ metadata = struct.pack('<i', idx)
+ writer.write_with_metadata(batch, metadata)
+ # Causes the server to stop writing and end the call
+ writer.done_writing()
+ # Thus reader thread will break out of loop
+ thread.join()
+ # writer.close() won't segfault since reader thread has
+ # stopped
+ assert count[0] == len(batches)
+
+
+def test_server_middleware_same_thread():
+ """Ensure that server middleware run on the same thread as the RPC."""
+ with HeaderFlightServer(middleware={
+ "test": HeaderServerMiddlewareFactory(),
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ results = list(client.do_action(flight.Action(b"test", b"")))
+ assert len(results) == 1
+ value = results[0].body.to_pybytes()
+ assert b"right value" == value
+
+
+def test_middleware_reject():
+ """Test rejecting an RPC with server middleware."""
+ with HeaderFlightServer(middleware={
+ "test": SelectiveAuthServerMiddlewareFactory(),
+ }) as server:
+ client = FlightClient(('localhost', server.port))
+ # The middleware allows this through without auth.
+ with pytest.raises(pa.ArrowNotImplementedError):
+ list(client.list_actions())
+
+ # But not anything else.
+ with pytest.raises(flight.FlightUnauthenticatedError):
+ list(client.do_action(flight.Action(b"", b"")))
+
+ client = FlightClient(
+ ('localhost', server.port),
+ middleware=[SelectiveAuthClientMiddlewareFactory()]
+ )
+ response = next(client.do_action(flight.Action(b"", b"")))
+ assert b"password" == response.body.to_pybytes()
+
+
+def test_middleware_mapping():
+ """Test that middleware records methods correctly."""
+ server_middleware = RecordingServerMiddlewareFactory()
+ client_middleware = RecordingClientMiddlewareFactory()
+ with FlightServerBase(middleware={"test": server_middleware}) as server:
+ client = FlightClient(
+ ('localhost', server.port),
+ middleware=[client_middleware]
+ )
+
+ descriptor = flight.FlightDescriptor.for_command(b"")
+ with pytest.raises(NotImplementedError):
+ list(client.list_flights())
+ with pytest.raises(NotImplementedError):
+ client.get_flight_info(descriptor)
+ with pytest.raises(NotImplementedError):
+ client.get_schema(descriptor)
+ with pytest.raises(NotImplementedError):
+ client.do_get(flight.Ticket(b""))
+ with pytest.raises(NotImplementedError):
+ writer, _ = client.do_put(descriptor, pa.schema([]))
+ writer.close()
+ with pytest.raises(NotImplementedError):
+ list(client.do_action(flight.Action(b"", b"")))
+ with pytest.raises(NotImplementedError):
+ list(client.list_actions())
+ with pytest.raises(NotImplementedError):
+ writer, _ = client.do_exchange(descriptor)
+ writer.close()
+
+ expected = [
+ flight.FlightMethod.LIST_FLIGHTS,
+ flight.FlightMethod.GET_FLIGHT_INFO,
+ flight.FlightMethod.GET_SCHEMA,
+ flight.FlightMethod.DO_GET,
+ flight.FlightMethod.DO_PUT,
+ flight.FlightMethod.DO_ACTION,
+ flight.FlightMethod.LIST_ACTIONS,
+ flight.FlightMethod.DO_EXCHANGE,
+ ]
+ assert server_middleware.methods == expected
+ assert client_middleware.methods == expected
+
+
+def test_extra_info():
+ with ErrorFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ try:
+ list(client.do_action(flight.Action("protobuf", b"")))
+ assert False
+ except flight.FlightUnauthorizedError as e:
+ assert e.extra_info is not None
+ ei = e.extra_info
+ assert ei == b'this is an error message'
+
+
+@pytest.mark.requires_testing_data
+def test_mtls():
+ """Test mutual TLS (mTLS) with gRPC."""
+ certs = example_tls_certs()
+ table = simple_ints_table()
+
+ with ConstantFlightServer(
+ tls_certificates=[certs["certificates"][0]],
+ verify_client=True,
+ root_certificates=certs["root_cert"]) as s:
+ client = FlightClient(
+ ('localhost', s.port),
+ tls_root_certs=certs["root_cert"],
+ cert_chain=certs["certificates"][0].cert,
+ private_key=certs["certificates"][0].key)
+ data = client.do_get(flight.Ticket(b'ints')).read_all()
+ assert data.equals(table)
+
+
+def test_doexchange_get():
+ """Emulate DoGet with DoExchange."""
+ expected = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024))
+ ], names=["a"])
+
+ with ExchangeFlightServer() as server:
+ client = FlightClient(("localhost", server.port))
+ descriptor = flight.FlightDescriptor.for_command(b"get")
+ writer, reader = client.do_exchange(descriptor)
+ with writer:
+ table = reader.read_all()
+ assert expected == table
+
+
+def test_doexchange_put():
+ """Emulate DoPut with DoExchange."""
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024))
+ ], names=["a"])
+ batches = data.to_batches(max_chunksize=512)
+
+ with ExchangeFlightServer() as server:
+ client = FlightClient(("localhost", server.port))
+ descriptor = flight.FlightDescriptor.for_command(b"put")
+ writer, reader = client.do_exchange(descriptor)
+ with writer:
+ writer.begin(data.schema)
+ for batch in batches:
+ writer.write_batch(batch)
+ writer.done_writing()
+ chunk = reader.read_chunk()
+ assert chunk.data is None
+ expected_buf = str(len(batches)).encode("utf-8")
+ assert chunk.app_metadata == expected_buf
+
+
+def test_doexchange_echo():
+ """Try a DoExchange echo server."""
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024))
+ ], names=["a"])
+ batches = data.to_batches(max_chunksize=512)
+
+ with ExchangeFlightServer() as server:
+ client = FlightClient(("localhost", server.port))
+ descriptor = flight.FlightDescriptor.for_command(b"echo")
+ writer, reader = client.do_exchange(descriptor)
+ with writer:
+ # Read/write metadata before starting data.
+ for i in range(10):
+ buf = str(i).encode("utf-8")
+ writer.write_metadata(buf)
+ chunk = reader.read_chunk()
+ assert chunk.data is None
+ assert chunk.app_metadata == buf
+
+ # Now write data without metadata.
+ writer.begin(data.schema)
+ for batch in batches:
+ writer.write_batch(batch)
+ assert reader.schema == data.schema
+ chunk = reader.read_chunk()
+ assert chunk.data == batch
+ assert chunk.app_metadata is None
+
+ # And write data with metadata.
+ for i, batch in enumerate(batches):
+ buf = str(i).encode("utf-8")
+ writer.write_with_metadata(batch, buf)
+ chunk = reader.read_chunk()
+ assert chunk.data == batch
+ assert chunk.app_metadata == buf
+
+
+def test_doexchange_echo_v4():
+ """Try a DoExchange echo server using the V4 metadata version."""
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 10 * 1024))
+ ], names=["a"])
+ batches = data.to_batches(max_chunksize=512)
+
+ options = pa.ipc.IpcWriteOptions(
+ metadata_version=pa.ipc.MetadataVersion.V4)
+ with ExchangeFlightServer(options=options) as server:
+ client = FlightClient(("localhost", server.port))
+ descriptor = flight.FlightDescriptor.for_command(b"echo")
+ writer, reader = client.do_exchange(descriptor)
+ with writer:
+ # Now write data without metadata.
+ writer.begin(data.schema, options=options)
+ for batch in batches:
+ writer.write_batch(batch)
+ assert reader.schema == data.schema
+ chunk = reader.read_chunk()
+ assert chunk.data == batch
+ assert chunk.app_metadata is None
+
+
+def test_doexchange_transform():
+ """Transform a table with a service."""
+ data = pa.Table.from_arrays([
+ pa.array(range(0, 1024)),
+ pa.array(range(1, 1025)),
+ pa.array(range(2, 1026)),
+ ], names=["a", "b", "c"])
+ expected = pa.Table.from_arrays([
+ pa.array(range(3, 1024 * 3 + 3, 3)),
+ ], names=["sum"])
+
+ with ExchangeFlightServer() as server:
+ client = FlightClient(("localhost", server.port))
+ descriptor = flight.FlightDescriptor.for_command(b"transform")
+ writer, reader = client.do_exchange(descriptor)
+ with writer:
+ writer.begin(data.schema)
+ writer.write_table(data)
+ writer.done_writing()
+ table = reader.read_all()
+ assert expected == table
+
+
+def test_middleware_multi_header():
+ """Test sending/receiving multiple (binary-valued) headers."""
+ with MultiHeaderFlightServer(middleware={
+ "test": MultiHeaderServerMiddlewareFactory(),
+ }) as server:
+ headers = MultiHeaderClientMiddlewareFactory()
+ client = FlightClient(('localhost', server.port), middleware=[headers])
+ response = next(client.do_action(flight.Action(b"", b"")))
+ # The server echoes the headers it got back to us.
+ raw_headers = response.body.to_pybytes().decode("utf-8")
+ client_headers = ast.literal_eval(raw_headers)
+ # Don't directly compare; gRPC may add headers like User-Agent.
+ for header, values in MultiHeaderClientMiddleware.EXPECTED.items():
+ assert client_headers.get(header) == values
+ assert headers.last_headers.get(header) == values
+
+
+@pytest.mark.requires_testing_data
+def test_generic_options():
+ """Test setting generic client options."""
+ certs = example_tls_certs()
+
+ with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
+ # Try setting a string argument that will make requests fail
+ options = [("grpc.ssl_target_name_override", "fakehostname")]
+ client = flight.connect(('localhost', s.port),
+ tls_root_certs=certs["root_cert"],
+ generic_options=options)
+ with pytest.raises(flight.FlightUnavailableError):
+ client.do_get(flight.Ticket(b'ints'))
+ # Try setting an int argument that will make requests fail
+ options = [("grpc.max_receive_message_length", 32)]
+ client = flight.connect(('localhost', s.port),
+ tls_root_certs=certs["root_cert"],
+ generic_options=options)
+ with pytest.raises(pa.ArrowInvalid):
+ client.do_get(flight.Ticket(b'ints'))
+
+
+class CancelFlightServer(FlightServerBase):
+ """A server for testing StopToken."""
+
+ def do_get(self, context, ticket):
+ schema = pa.schema([])
+ rb = pa.RecordBatch.from_arrays([], schema=schema)
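+ # Stream empty batches forever so a client-side read blocks until the
+ # call is cancelled.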
+ return flight.GeneratorStream(schema, itertools.repeat(rb))
+
+ def do_exchange(self, context, descriptor, reader, writer):
+ schema = pa.schema([])
+ rb = pa.RecordBatch.from_arrays([], schema=schema)
+ writer.begin(schema)
+ while not context.is_cancelled():
+ writer.write_batch(rb)
+ time.sleep(0.5)
+
+
+def test_interrupt():
+ if threading.current_thread().ident != threading.main_thread().ident:
+ pytest.skip("test only works from main Python thread")
+ # Skips the test if raising signals is not available
+ raise_signal = util.get_raise_signal()
+
+ def signal_from_thread():
+ time.sleep(0.5)
+ raise_signal(signal.SIGINT)
+
+ exc_types = (KeyboardInterrupt, pa.ArrowCancelled)
+
+ def test(read_all):
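+ # Raise SIGINT from a helper thread and check that read_all() is
+ # interrupted with KeyboardInterrupt or ArrowCancelled.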
+ try:
+ try:
+ t = threading.Thread(target=signal_from_thread)
+ with pytest.raises(exc_types) as exc_info:
+ t.start()
+ read_all()
+ finally:
+ t.join()
+ except KeyboardInterrupt:
+ # In case KeyboardInterrupt didn't interrupt read_all
+ # above, at least prevent it from stopping the test suite
+ pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all")
+ e = exc_info.value.__context__
+ assert isinstance(e, pa.ArrowCancelled) or \
+ isinstance(e, KeyboardInterrupt)
+
+ with CancelFlightServer() as server:
+ client = FlightClient(("localhost", server.port))
+
+ reader = client.do_get(flight.Ticket(b""))
+ test(reader.read_all)
+
+ descriptor = flight.FlightDescriptor.for_command(b"echo")
+ writer, reader = client.do_exchange(descriptor)
+ test(reader.read_all)
+
+
+def test_never_sends_data():
+ # Regression test for ARROW-12779
+ match = "application server implementation error"
+ with NeverSendsDataFlightServer() as server:
+ client = flight.connect(('localhost', server.port))
+ with pytest.raises(flight.FlightServerError, match=match):
+ client.do_get(flight.Ticket(b'')).read_all()
+
+ # Check that the server handler will ignore empty tables
+ # to a certain extent
+ table = client.do_get(flight.Ticket(b'yield_data')).read_all()
+ assert table.num_rows == 5
+
+
+@pytest.mark.large_memory
+@pytest.mark.slow
+def test_large_descriptor():
+ # Regression test for ARROW-13253. Placed here with appropriate marks
+ # since some CI pipelines can't run the C++ equivalent
+ large_descriptor = flight.FlightDescriptor.for_command(
+ b' ' * (2 ** 31 + 1))
+ with FlightServerBase() as server:
+ client = flight.connect(('localhost', server.port))
+ with pytest.raises(OSError,
+ match="Failed to serialize Flight descriptor"):
+ writer, _ = client.do_put(large_descriptor, pa.schema([]))
+ writer.close()
+ with pytest.raises(pa.ArrowException,
+ match="Failed to serialize Flight descriptor"):
+ client.do_exchange(large_descriptor)
+
+
+@pytest.mark.large_memory
+@pytest.mark.slow
+def test_large_metadata_client():
+ # Regression test for ARROW-13253
+ descriptor = flight.FlightDescriptor.for_command(b'')
+ metadata = b' ' * (2 ** 31 + 1)
+ with EchoFlightServer() as server:
+ client = flight.connect(('localhost', server.port))
+ with pytest.raises(pa.ArrowCapacityError,
+ match="app_metadata size overflow"):
+ writer, _ = client.do_put(descriptor, pa.schema([]))
+ with writer:
+ writer.write_metadata(metadata)
+ writer.close()
+ with pytest.raises(pa.ArrowCapacityError,
+ match="app_metadata size overflow"):
+ writer, reader = client.do_exchange(descriptor)
+ with writer:
+ writer.write_metadata(metadata)
+
+ del metadata
+ with LargeMetadataFlightServer() as server:
+ client = flight.connect(('localhost', server.port))
+ with pytest.raises(flight.FlightServerError,
+ match="app_metadata size overflow"):
+ reader = client.do_get(flight.Ticket(b''))
+ reader.read_all()
+ with pytest.raises(pa.ArrowException,
+ match="app_metadata size overflow"):
+ writer, reader = client.do_exchange(descriptor)
+ with writer:
+ reader.read_all()
+
+
+class ActionNoneFlightServer(EchoFlightServer):
+ """A server that implements a side effect to a non iterable action."""
+ VALUES = []
+
+ def do_action(self, context, action):
+ if action.type == "get_value":
+ return [json.dumps(self.VALUES).encode('utf-8')]
+ elif action.type == "append":
+ self.VALUES.append(True)
+ return None
+ raise NotImplementedError
+
+
+def test_none_action_side_effect():
+ """Ensure that actions are executed even when we don't consume iterator.
+
+ See https://issues.apache.org/jira/browse/ARROW-14255
+ """
+
+ with ActionNoneFlightServer() as server:
+ client = FlightClient(('localhost', server.port))
+ client.do_action(flight.Action("append", b""))
+ r = client.do_action(flight.Action("get_value", b""))
+ assert json.loads(next(r).body.to_pybytes()) == [True]
diff --git a/src/arrow/python/pyarrow/tests/test_fs.py b/src/arrow/python/pyarrow/tests/test_fs.py
new file mode 100644
index 000000000..48bdae8a5
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_fs.py
@@ -0,0 +1,1714 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime, timezone, timedelta
+import gzip
+import os
+import pathlib
+import pickle
+import subprocess
+import sys
+import time
+
+import pytest
+import weakref
+
+import pyarrow as pa
+from pyarrow.tests.test_io import assert_file_not_found
+from pyarrow.tests.util import _filesystem_uri, ProxyHandler
+from pyarrow.vendored.version import Version
+
+from pyarrow.fs import (FileType, FileInfo, FileSelector, FileSystem,
+ LocalFileSystem, SubTreeFileSystem, _MockFileSystem,
+ FileSystemHandler, PyFileSystem, FSSpecHandler,
+ copy_files)
+
+
+class DummyHandler(FileSystemHandler):
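+ """A minimal FileSystemHandler with canned responses, used to
+ exercise PyFileSystem in the tests below."""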
+ def __init__(self, value=42):
+ self._value = value
+
+ def __eq__(self, other):
+ if isinstance(other, FileSystemHandler):
+ return self._value == other._value
+ return NotImplemented
+
+ def __ne__(self, other):
+ if isinstance(other, FileSystemHandler):
+ return self._value != other._value
+ return NotImplemented
+
+ def get_type_name(self):
+ return "dummy"
+
+ def normalize_path(self, path):
+ return path
+
+ def get_file_info(self, paths):
+ info = []
+ for path in paths:
+ if "file" in path:
+ info.append(FileInfo(path, FileType.File))
+ elif "dir" in path:
+ info.append(FileInfo(path, FileType.Directory))
+ elif "notfound" in path:
+ info.append(FileInfo(path, FileType.NotFound))
+ elif "badtype" in path:
+ # Will raise when converting
+ info.append(object())
+ else:
+ raise IOError
+ return info
+
+ def get_file_info_selector(self, selector):
+ if selector.base_dir != "somedir":
+ if selector.allow_not_found:
+ return []
+ else:
+ raise FileNotFoundError(selector.base_dir)
+ infos = [
+ FileInfo("somedir/file1", FileType.File, size=123),
+ FileInfo("somedir/subdir1", FileType.Directory),
+ ]
+ if selector.recursive:
+ infos += [
+ FileInfo("somedir/subdir1/file2", FileType.File, size=456),
+ ]
+ return infos
+
+ def create_dir(self, path, recursive):
+ if path == "recursive":
+ assert recursive is True
+ elif path == "non-recursive":
+ assert recursive is False
+ else:
+ raise IOError
+
+ def delete_dir(self, path):
+ assert path == "delete_dir"
+
+ def delete_dir_contents(self, path):
+ if not path.strip("/"):
+ raise ValueError
+ assert path == "delete_dir_contents"
+
+ def delete_root_dir_contents(self):
+ pass
+
+ def delete_file(self, path):
+ assert path == "delete_file"
+
+ def move(self, src, dest):
+ assert src == "move_from"
+ assert dest == "move_to"
+
+ def copy_file(self, src, dest):
+ assert src == "copy_file_from"
+ assert dest == "copy_file_to"
+
+ def open_input_stream(self, path):
+ if "notfound" in path:
+ raise FileNotFoundError(path)
+ data = "{0}:input_stream".format(path).encode('utf8')
+ return pa.BufferReader(data)
+
+ def open_input_file(self, path):
+ if "notfound" in path:
+ raise FileNotFoundError(path)
+ data = "{0}:input_file".format(path).encode('utf8')
+ return pa.BufferReader(data)
+
+ def open_output_stream(self, path, metadata):
+ if "notfound" in path:
+ raise FileNotFoundError(path)
+ return pa.BufferOutputStream()
+
+ def open_append_stream(self, path, metadata):
+ if "notfound" in path:
+ raise FileNotFoundError(path)
+ return pa.BufferOutputStream()
+
+
+@pytest.fixture
+def localfs(request, tempdir):
+ return dict(
+ fs=LocalFileSystem(),
+ pathfn=lambda p: (tempdir / p).as_posix(),
+ allow_move_dir=True,
+ allow_append_to_file=True,
+ )
+
+
+@pytest.fixture
+def py_localfs(request, tempdir):
+ return dict(
+ fs=PyFileSystem(ProxyHandler(LocalFileSystem())),
+ pathfn=lambda p: (tempdir / p).as_posix(),
+ allow_move_dir=True,
+ allow_append_to_file=True,
+ )
+
+
+@pytest.fixture
+def mockfs(request):
+ return dict(
+ fs=_MockFileSystem(),
+ pathfn=lambda p: p,
+ allow_move_dir=True,
+ allow_append_to_file=True,
+ )
+
+
+@pytest.fixture
+def py_mockfs(request):
+ return dict(
+ fs=PyFileSystem(ProxyHandler(_MockFileSystem())),
+ pathfn=lambda p: p,
+ allow_move_dir=True,
+ allow_append_to_file=True,
+ )
+
+
+@pytest.fixture
+def localfs_with_mmap(request, tempdir):
+ return dict(
+ fs=LocalFileSystem(use_mmap=True),
+ pathfn=lambda p: (tempdir / p).as_posix(),
+ allow_move_dir=True,
+ allow_append_to_file=True,
+ )
+
+
+@pytest.fixture
+def subtree_localfs(request, tempdir, localfs):
+ return dict(
+ fs=SubTreeFileSystem(str(tempdir), localfs['fs']),
+ pathfn=lambda p: p,
+ allow_move_dir=True,
+ allow_append_to_file=True,
+ )
+
+
+@pytest.fixture
+def s3fs(request, s3_server):
+ request.config.pyarrow.requires('s3')
+ from pyarrow.fs import S3FileSystem
+
+ host, port, access_key, secret_key = s3_server['connection']
+ bucket = 'pyarrow-filesystem/'
+
+ fs = S3FileSystem(
+ access_key=access_key,
+ secret_key=secret_key,
+ endpoint_override='{}:{}'.format(host, port),
+ scheme='http'
+ )
+ fs.create_dir(bucket)
+
+ yield dict(
+ fs=fs,
+ pathfn=bucket.__add__,
+ allow_move_dir=False,
+ allow_append_to_file=False,
+ )
+ fs.delete_dir(bucket)
+
+
+@pytest.fixture
+def subtree_s3fs(request, s3fs):
+ prefix = 'pyarrow-filesystem/prefix/'
+ return dict(
+ fs=SubTreeFileSystem(prefix, s3fs['fs']),
+ pathfn=prefix.__add__,
+ allow_move_dir=False,
+ allow_append_to_file=False,
+ )
+
+
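+# S3-style policy for a "limited" user that may work with objects but
+# cannot create buckets (see _configure_limited_user below).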
+_minio_limited_policy = """{
+ "Version": "2012-10-17",
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Action": [
+ "s3:ListAllMyBuckets",
+ "s3:PutObject",
+ "s3:GetObject",
+ "s3:ListBucket",
+ "s3:PutObjectTagging",
+ "s3:DeleteObject",
+ "s3:GetObjectVersion"
+ ],
+ "Resource": [
+ "arn:aws:s3:::*"
+ ]
+ }
+ ]
+}"""
+
+
+def _run_mc_command(mcdir, *args):
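+ """Run the minio client ('mc') with the given config dir, raising
+ ChildProcessError if the command fails."""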
+ full_args = ['mc', '-C', mcdir] + list(args)
+ proc = subprocess.Popen(full_args, stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE, encoding='utf-8')
+ retval = proc.wait(10)
+ cmd_str = ' '.join(full_args)
+ print(f'Cmd: {cmd_str}')
+ print(f' Return: {retval}')
+ print(f' Stdout: {proc.stdout.read()}')
+ print(f' Stderr: {proc.stderr.read()}')
+ if retval != 0:
+ raise ChildProcessError("Could not run mc")
+
+
+def _wait_for_minio_startup(mcdir, address, access_key, secret_key):
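+ """Register the minio alias with 'mc', retrying for up to ~10 seconds
+ while the server starts up."""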
+ start = time.time()
+ while time.time() - start < 10:
+ try:
+ _run_mc_command(mcdir, 'alias', 'set', 'myminio',
+ f'http://{address}', access_key, secret_key)
+ return
+ except ChildProcessError:
+ time.sleep(1)
+ raise Exception("mc command could not connect to local minio")
+
+
+def _configure_limited_user(tmpdir, address, access_key, secret_key):
+ """
+ Attempts to use the mc command to configure the minio server
+ with a special user limited:limited123 which does not have
+ permission to create buckets. This mirrors some real-life S3
+ configurations where users are given strict permissions.
+
+ Arrow S3 operations should still work in such a configuration
+ (e.g. see ARROW-13685)
+ """
+ try:
+ mcdir = os.path.join(tmpdir, 'mc')
+ os.mkdir(mcdir)
+ policy_path = os.path.join(tmpdir, 'limited-buckets-policy.json')
+ with open(policy_path, mode='w') as policy_file:
+ policy_file.write(_minio_limited_policy)
+ # The s3_server fixture starts the minio process but
+ # it takes a few moments for the process to become available
+ _wait_for_minio_startup(mcdir, address, access_key, secret_key)
+ # These commands create a limited user with a specific
+ # policy and create a sample bucket for that user to
+ # write to
+ _run_mc_command(mcdir, 'admin', 'policy', 'add',
+ 'myminio/', 'no-create-buckets', policy_path)
+ _run_mc_command(mcdir, 'admin', 'user', 'add',
+ 'myminio/', 'limited', 'limited123')
+ _run_mc_command(mcdir, 'admin', 'policy', 'set',
+ 'myminio', 'no-create-buckets', 'user=limited')
+ _run_mc_command(mcdir, 'mb', 'myminio/existing-bucket')
+ return True
+ except FileNotFoundError:
+ # If mc is not found, skip these tests
+ return False
+
+
+@pytest.fixture(scope='session')
+def limited_s3_user(request, s3_server):
+ if sys.platform == 'win32':
+ # Can't rely on the FileNotFoundError check because
+ # there is sometimes an mc command on Windows
+ # which is unrelated to the minio mc
+ pytest.skip('The mc command is not installed on Windows')
+ request.config.pyarrow.requires('s3')
+ tempdir = s3_server['tempdir']
+ host, port, access_key, secret_key = s3_server['connection']
+ address = '{}:{}'.format(host, port)
+ if not _configure_limited_user(tempdir, address, access_key, secret_key):
+ pytest.skip('Could not locate mc command to configure limited user')
+
+
+@pytest.fixture
+def hdfs(request, hdfs_connection):
+ request.config.pyarrow.requires('hdfs')
+ if not pa.have_libhdfs():
+ pytest.skip('Cannot locate libhdfs')
+
+ from pyarrow.fs import HadoopFileSystem
+
+ host, port, user = hdfs_connection
+ fs = HadoopFileSystem(host, port=port, user=user)
+
+ return dict(
+ fs=fs,
+ pathfn=lambda p: p,
+ allow_move_dir=True,
+ allow_append_to_file=True,
+ )
+
+
+@pytest.fixture
+def py_fsspec_localfs(request, tempdir):
+ fsspec = pytest.importorskip("fsspec")
+ fs = fsspec.filesystem('file')
+ return dict(
+ fs=PyFileSystem(FSSpecHandler(fs)),
+ pathfn=lambda p: (tempdir / p).as_posix(),
+ allow_move_dir=True,
+ allow_append_to_file=True,
+ )
+
+
+@pytest.fixture
+def py_fsspec_memoryfs(request, tempdir):
+ fsspec = pytest.importorskip("fsspec", minversion="0.7.5")
+ if fsspec.__version__ == "0.8.5":
+ # see https://issues.apache.org/jira/browse/ARROW-10934
+ pytest.skip("Bug in fsspec 0.8.5 for in-memory filesystem")
+ fs = fsspec.filesystem('memory')
+ return dict(
+ fs=PyFileSystem(FSSpecHandler(fs)),
+ pathfn=lambda p: p,
+ allow_move_dir=True,
+ allow_append_to_file=True,
+ )
+
+
+@pytest.fixture
+def py_fsspec_s3fs(request, s3_server):
+ s3fs = pytest.importorskip("s3fs")
+ if (sys.version_info < (3, 7) and
+ Version(s3fs.__version__) >= Version("0.5")):
+ pytest.skip("s3fs>=0.5 version is async and requires Python >= 3.7")
+
+ host, port, access_key, secret_key = s3_server['connection']
+ bucket = 'pyarrow-filesystem/'
+
+ fs = s3fs.S3FileSystem(
+ key=access_key,
+ secret=secret_key,
+ client_kwargs=dict(endpoint_url='http://{}:{}'.format(host, port))
+ )
+ fs = PyFileSystem(FSSpecHandler(fs))
+ fs.create_dir(bucket)
+
+ yield dict(
+ fs=fs,
+ pathfn=bucket.__add__,
+ allow_move_dir=False,
+ allow_append_to_file=True,
+ )
+ fs.delete_dir(bucket)
+
+
+@pytest.fixture(params=[
+ pytest.param(
+ pytest.lazy_fixture('localfs'),
+ id='LocalFileSystem()'
+ ),
+ pytest.param(
+ pytest.lazy_fixture('localfs_with_mmap'),
+ id='LocalFileSystem(use_mmap=True)'
+ ),
+ pytest.param(
+ pytest.lazy_fixture('subtree_localfs'),
+ id='SubTreeFileSystem(LocalFileSystem())'
+ ),
+ pytest.param(
+ pytest.lazy_fixture('s3fs'),
+ id='S3FileSystem',
+ marks=pytest.mark.s3
+ ),
+ pytest.param(
+ pytest.lazy_fixture('hdfs'),
+ id='HadoopFileSystem',
+ marks=pytest.mark.hdfs
+ ),
+ pytest.param(
+ pytest.lazy_fixture('mockfs'),
+ id='_MockFileSystem()'
+ ),
+ pytest.param(
+ pytest.lazy_fixture('py_localfs'),
+ id='PyFileSystem(ProxyHandler(LocalFileSystem()))'
+ ),
+ pytest.param(
+ pytest.lazy_fixture('py_mockfs'),
+ id='PyFileSystem(ProxyHandler(_MockFileSystem()))'
+ ),
+ pytest.param(
+ pytest.lazy_fixture('py_fsspec_localfs'),
+ id='PyFileSystem(FSSpecHandler(fsspec.LocalFileSystem()))'
+ ),
+ pytest.param(
+ pytest.lazy_fixture('py_fsspec_memoryfs'),
+ id='PyFileSystem(FSSpecHandler(fsspec.filesystem("memory")))'
+ ),
+ pytest.param(
+ pytest.lazy_fixture('py_fsspec_s3fs'),
+ id='PyFileSystem(FSSpecHandler(s3fs.S3FileSystem()))',
+ marks=pytest.mark.s3
+ ),
+])
+def filesystem_config(request):
+ return request.param
+
+
+@pytest.fixture
+def fs(request, filesystem_config):
+ return filesystem_config['fs']
+
+
+@pytest.fixture
+def pathfn(request, filesystem_config):
+ return filesystem_config['pathfn']
+
+
+@pytest.fixture
+def allow_move_dir(request, filesystem_config):
+ return filesystem_config['allow_move_dir']
+
+
+@pytest.fixture
+def allow_append_to_file(request, filesystem_config):
+ return filesystem_config['allow_append_to_file']
+
+
+def check_mtime(file_info):
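+ """Check that mtime is an aware UTC datetime consistent with mtime_ns."""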
+ assert isinstance(file_info.mtime, datetime)
+ assert isinstance(file_info.mtime_ns, int)
+ assert file_info.mtime_ns >= 0
+ assert file_info.mtime_ns == pytest.approx(
+ file_info.mtime.timestamp() * 1e9)
+ # It's an aware UTC datetime
+ tzinfo = file_info.mtime.tzinfo
+ assert tzinfo is not None
+ assert tzinfo.utcoffset(None) == timedelta(0)
+
+
+def check_mtime_absent(file_info):
+ assert file_info.mtime is None
+ assert file_info.mtime_ns is None
+
+
+def check_mtime_or_absent(file_info):
+ if file_info.mtime is None:
+ check_mtime_absent(file_info)
+ else:
+ check_mtime(file_info)
+
+
+def skip_fsspec_s3fs(fs):
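+ """Mark the current test as an expected failure under fsspec's s3fs."""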
+ if fs.type_name == "py::fsspec+s3":
+ pytest.xfail(reason="Not working with fsspec's s3fs")
+
+
+@pytest.mark.s3
+def test_s3fs_limited_permissions_create_bucket(s3_server, limited_s3_user):
+ from pyarrow.fs import S3FileSystem
+
+ host, port, _, _ = s3_server['connection']
+
+ fs = S3FileSystem(
+ access_key='limited',
+ secret_key='limited123',
+ endpoint_override='{}:{}'.format(host, port),
+ scheme='http'
+ )
+ fs.create_dir('existing-bucket/test')
+
+
+def test_file_info_constructor():
+ dt = datetime.fromtimestamp(1568799826, timezone.utc)
+
+ info = FileInfo("foo/bar")
+ assert info.path == "foo/bar"
+ assert info.base_name == "bar"
+ assert info.type == FileType.Unknown
+ assert info.size is None
+ check_mtime_absent(info)
+
+ info = FileInfo("foo/baz.txt", type=FileType.File, size=123,
+ mtime=1568799826.5)
+ assert info.path == "foo/baz.txt"
+ assert info.base_name == "baz.txt"
+ assert info.type == FileType.File
+ assert info.size == 123
+ assert info.mtime_ns == 1568799826500000000
+ check_mtime(info)
+
+ info = FileInfo("foo", type=FileType.Directory, mtime=dt)
+ assert info.path == "foo"
+ assert info.base_name == "foo"
+ assert info.type == FileType.Directory
+ assert info.size is None
+ assert info.mtime == dt
+ assert info.mtime_ns == 1568799826000000000
+ check_mtime(info)
+
+
+def test_cannot_instantiate_base_filesystem():
+ with pytest.raises(TypeError):
+ FileSystem()
+
+
+def test_filesystem_equals():
+ fs0 = LocalFileSystem()
+ fs1 = LocalFileSystem()
+ fs2 = _MockFileSystem()
+
+ assert fs0.equals(fs0)
+ assert fs0.equals(fs1)
+ with pytest.raises(TypeError):
+ fs0.equals('string')
+ assert fs0 == fs0 == fs1
+ assert fs0 != 4
+
+ assert fs2 == fs2
+ assert fs2 != _MockFileSystem()
+
+ assert SubTreeFileSystem('/base', fs0) == SubTreeFileSystem('/base', fs0)
+ assert SubTreeFileSystem('/base', fs0) != SubTreeFileSystem('/base', fs2)
+ assert SubTreeFileSystem('/base', fs0) != SubTreeFileSystem('/other', fs0)
+
+
+def test_subtree_filesystem():
+ localfs = LocalFileSystem()
+
+ subfs = SubTreeFileSystem('/base', localfs)
+ assert subfs.base_path == '/base/'
+ assert subfs.base_fs == localfs
+ assert repr(subfs).startswith('SubTreeFileSystem(base_path=/base/, '
+ 'base_fs=<pyarrow._fs.LocalFileSystem')
+
+ subfs = SubTreeFileSystem('/another/base/', LocalFileSystem())
+ assert subfs.base_path == '/another/base/'
+ assert subfs.base_fs == localfs
+ assert repr(subfs).startswith('SubTreeFileSystem(base_path=/another/base/,'
+ ' base_fs=<pyarrow._fs.LocalFileSystem')
+
+
+def test_filesystem_pickling(fs):
+ if fs.type_name.split('::')[-1] == 'mock':
+ pytest.xfail(reason='MockFileSystem is not serializable')
+
+ serialized = pickle.dumps(fs)
+ restored = pickle.loads(serialized)
+ assert isinstance(restored, FileSystem)
+ assert restored.equals(fs)
+
+
+def test_filesystem_is_functional_after_pickling(fs, pathfn):
+ if fs.type_name.split('::')[-1] == 'mock':
+ pytest.xfail(reason='MockFileSystem is not serializable')
+ skip_fsspec_s3fs(fs)
+
+ aaa = pathfn('a/aa/aaa/')
+ bb = pathfn('a/bb')
+ c = pathfn('c.txt')
+
+ fs.create_dir(aaa)
+ with fs.open_output_stream(bb):
+ pass # touch
+ with fs.open_output_stream(c) as fp:
+ fp.write(b'test')
+
+ restored = pickle.loads(pickle.dumps(fs))
+ aaa_info, bb_info, c_info = restored.get_file_info([aaa, bb, c])
+ assert aaa_info.type == FileType.Directory
+ assert bb_info.type == FileType.File
+ assert c_info.type == FileType.File
+
+
+def test_type_name():
+ fs = LocalFileSystem()
+ assert fs.type_name == "local"
+ fs = _MockFileSystem()
+ assert fs.type_name == "mock"
+
+
+def test_normalize_path(fs):
+ # Trivial path names (without separators) should generally be
+ # already normalized. Just a sanity check.
+ assert fs.normalize_path("foo") == "foo"
+
+
+def test_non_path_like_input_raises(fs):
+ class Path:
+ pass
+
+ invalid_paths = [1, 1.1, Path(), tuple(), {}, [], lambda: 1,
+ pathlib.Path()]
+ for path in invalid_paths:
+ with pytest.raises(TypeError):
+ fs.create_dir(path)
+
+
+def test_get_file_info(fs, pathfn):
+ aaa = pathfn('a/aa/aaa/')
+ bb = pathfn('a/bb')
+ c = pathfn('c.txt')
+ zzz = pathfn('zzz')
+
+ fs.create_dir(aaa)
+ with fs.open_output_stream(bb):
+ pass # touch
+ with fs.open_output_stream(c) as fp:
+ fp.write(b'test')
+
+ aaa_info, bb_info, c_info, zzz_info = fs.get_file_info([aaa, bb, c, zzz])
+
+ assert aaa_info.path == aaa
+ assert 'aaa' in repr(aaa_info)
+ assert aaa_info.extension == ''
+ if fs.type_name == "py::fsspec+s3":
+ # s3fs doesn't create empty directories
+ assert aaa_info.type == FileType.NotFound
+ else:
+ assert aaa_info.type == FileType.Directory
+ assert 'FileType.Directory' in repr(aaa_info)
+ assert aaa_info.size is None
+ check_mtime_or_absent(aaa_info)
+
+ assert bb_info.path == str(bb)
+ assert bb_info.base_name == 'bb'
+ assert bb_info.extension == ''
+ assert bb_info.type == FileType.File
+ assert 'FileType.File' in repr(bb_info)
+ assert bb_info.size == 0
+ if fs.type_name not in ["py::fsspec+memory", "py::fsspec+s3"]:
+ check_mtime(bb_info)
+
+ assert c_info.path == str(c)
+ assert c_info.base_name == 'c.txt'
+ assert c_info.extension == 'txt'
+ assert c_info.type == FileType.File
+ assert 'FileType.File' in repr(c_info)
+ assert c_info.size == 4
+ if fs.type_name not in ["py::fsspec+memory", "py::fsspec+s3"]:
+ check_mtime(c_info)
+
+ assert zzz_info.path == str(zzz)
+ assert zzz_info.base_name == 'zzz'
+ assert zzz_info.extension == ''
+ assert zzz_info.type == FileType.NotFound
+ assert zzz_info.size is None
+ assert zzz_info.mtime is None
+ assert 'FileType.NotFound' in repr(zzz_info)
+ check_mtime_absent(zzz_info)
+
+ # with single path
+ aaa_info2 = fs.get_file_info(aaa)
+ assert aaa_info.path == aaa_info2.path
+ assert aaa_info.type == aaa_info2.type
+
+
+def test_get_file_info_with_selector(fs, pathfn):
+ base_dir = pathfn('selector-dir/')
+ file_a = pathfn('selector-dir/test_file_a')
+ file_b = pathfn('selector-dir/test_file_b')
+ dir_a = pathfn('selector-dir/test_dir_a')
+ file_c = pathfn('selector-dir/test_dir_a/test_file_c')
+ dir_b = pathfn('selector-dir/test_dir_b')
+
+ try:
+ fs.create_dir(base_dir)
+ with fs.open_output_stream(file_a):
+ pass
+ with fs.open_output_stream(file_b):
+ pass
+ fs.create_dir(dir_a)
+ with fs.open_output_stream(file_c):
+ pass
+ fs.create_dir(dir_b)
+
+ # recursive selector
+ selector = FileSelector(base_dir, allow_not_found=False,
+ recursive=True)
+ assert selector.base_dir == base_dir
+
+ infos = fs.get_file_info(selector)
+ if fs.type_name == "py::fsspec+s3":
+ # s3fs only lists directories if they are not empty, but depending
+ # on the s3fs/fsspec version combo, it includes the base_dir
+ # (https://github.com/dask/s3fs/issues/393)
+ assert (len(infos) == 4) or (len(infos) == 5)
+ else:
+ assert len(infos) == 5
+
+ for info in infos:
+ if (info.path.endswith(file_a) or info.path.endswith(file_b) or
+ info.path.endswith(file_c)):
+ assert info.type == FileType.File
+ elif (info.path.rstrip("/").endswith(dir_a) or
+ info.path.rstrip("/").endswith(dir_b)):
+ assert info.type == FileType.Directory
+ elif (fs.type_name == "py::fsspec+s3" and
+ info.path.rstrip("/").endswith("selector-dir")):
+ # s3fs can include base dir, see above
+ assert info.type == FileType.Directory
+ else:
+ raise ValueError('unexpected path {}'.format(info.path))
+ check_mtime_or_absent(info)
+
+ # non-recursive selector -> not selecting the nested file_c
+ selector = FileSelector(base_dir, recursive=False)
+
+ infos = fs.get_file_info(selector)
+ if fs.type_name == "py::fsspec+s3":
+ # s3fs only lists directories if they are not empty
+ # + for s3fs 0.5.2 all directories are dropped because of a buggy
+ # side effect of the previous find() call
+ # (https://github.com/dask/s3fs/issues/410)
+ assert (len(infos) == 3) or (len(infos) == 2)
+ else:
+ assert len(infos) == 4
+
+ finally:
+ fs.delete_dir(base_dir)
+
+
+def test_create_dir(fs, pathfn):
+ # s3fs fails to delete a dir if it is empty
+ # (https://github.com/dask/s3fs/issues/317)
+ skip_fsspec_s3fs(fs)
+ d = pathfn('test-directory/')
+
+ with pytest.raises(pa.ArrowIOError):
+ fs.delete_dir(d)
+
+ fs.create_dir(d)
+ fs.delete_dir(d)
+
+ d = pathfn('deeply/nested/test-directory/')
+ fs.create_dir(d, recursive=True)
+ fs.delete_dir(d)
+
+
+def test_delete_dir(fs, pathfn):
+ skip_fsspec_s3fs(fs)
+
+ d = pathfn('directory/')
+ nd = pathfn('directory/nested/')
+
+ fs.create_dir(nd)
+ fs.delete_dir(d)
+ with pytest.raises(pa.ArrowIOError):
+ fs.delete_dir(nd)
+ with pytest.raises(pa.ArrowIOError):
+ fs.delete_dir(d)
+
+
+def test_delete_dir_contents(fs, pathfn):
+ skip_fsspec_s3fs(fs)
+
+ d = pathfn('directory/')
+ nd = pathfn('directory/nested/')
+
+ fs.create_dir(nd)
+ fs.delete_dir_contents(d)
+ with pytest.raises(pa.ArrowIOError):
+ fs.delete_dir(nd)
+ fs.delete_dir(d)
+ with pytest.raises(pa.ArrowIOError):
+ fs.delete_dir(d)
+
+
+def _check_root_dir_contents(config):
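+ """Check that deleting the contents of the root directory requires
+ accept_root_dir=True."""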
+ fs = config['fs']
+ pathfn = config['pathfn']
+
+ d = pathfn('directory/')
+ nd = pathfn('directory/nested/')
+
+ fs.create_dir(nd)
+ with pytest.raises(pa.ArrowInvalid):
+ fs.delete_dir_contents("")
+ with pytest.raises(pa.ArrowInvalid):
+ fs.delete_dir_contents("/")
+ with pytest.raises(pa.ArrowInvalid):
+ fs.delete_dir_contents("//")
+
+ fs.delete_dir_contents("", accept_root_dir=True)
+ fs.delete_dir_contents("/", accept_root_dir=True)
+ fs.delete_dir_contents("//", accept_root_dir=True)
+ with pytest.raises(pa.ArrowIOError):
+ fs.delete_dir(d)
+
+
+def test_delete_root_dir_contents(mockfs, py_mockfs):
+ _check_root_dir_contents(mockfs)
+ _check_root_dir_contents(py_mockfs)
+
+
+def test_copy_file(fs, pathfn):
+ s = pathfn('test-copy-source-file')
+ t = pathfn('test-copy-target-file')
+
+ with fs.open_output_stream(s):
+ pass
+
+ fs.copy_file(s, t)
+ fs.delete_file(s)
+ fs.delete_file(t)
+
+
+def test_move_directory(fs, pathfn, allow_move_dir):
+ # move directory (doesn't work with S3)
+ s = pathfn('source-dir/')
+ t = pathfn('target-dir/')
+
+ fs.create_dir(s)
+
+ if allow_move_dir:
+ fs.move(s, t)
+ with pytest.raises(pa.ArrowIOError):
+ fs.delete_dir(s)
+ fs.delete_dir(t)
+ else:
+ with pytest.raises(pa.ArrowIOError):
+ fs.move(s, t)
+
+
+def test_move_file(fs, pathfn):
+ # s3fs fails moving a file with recursive=True on the latest 0.5 version
+ # (https://github.com/dask/s3fs/issues/394)
+ skip_fsspec_s3fs(fs)
+
+ s = pathfn('test-move-source-file')
+ t = pathfn('test-move-target-file')
+
+ with fs.open_output_stream(s):
+ pass
+
+ fs.move(s, t)
+ with pytest.raises(pa.ArrowIOError):
+ fs.delete_file(s)
+ fs.delete_file(t)
+
+
+def test_delete_file(fs, pathfn):
+ p = pathfn('test-delete-target-file')
+ with fs.open_output_stream(p):
+ pass
+
+ fs.delete_file(p)
+ with pytest.raises(pa.ArrowIOError):
+ fs.delete_file(p)
+
+ d = pathfn('test-delete-nested')
+ fs.create_dir(d)
+ f = pathfn('test-delete-nested/target-file')
+ with fs.open_output_stream(f) as s:
+ s.write(b'data')
+
+ fs.delete_dir(d)
+
+
+def identity(v):
+ return v
+
+
+@pytest.mark.gzip
+@pytest.mark.parametrize(
+ ('compression', 'buffer_size', 'compressor'),
+ [
+ (None, None, identity),
+ (None, 64, identity),
+ ('gzip', None, gzip.compress),
+ ('gzip', 256, gzip.compress),
+ ]
+)
+def test_open_input_stream(fs, pathfn, compression, buffer_size, compressor):
+ p = pathfn('open-input-stream')
+
+ data = b'some data for reading\n' * 512
+ with fs.open_output_stream(p) as s:
+ s.write(compressor(data))
+
+ with fs.open_input_stream(p, compression, buffer_size) as s:
+ result = s.read()
+
+ assert result == data
+
+
+def test_open_input_file(fs, pathfn):
+ p = pathfn('open-input-file')
+
+ data = b'some data' * 1024
+ with fs.open_output_stream(p) as s:
+ s.write(data)
+
+ read_from = len(b'some data') * 512
+ with fs.open_input_file(p) as f:
+ f.seek(read_from)
+ result = f.read()
+
+ assert result == data[read_from:]
+
+
+@pytest.mark.gzip
+@pytest.mark.parametrize(
+ ('compression', 'buffer_size', 'decompressor'),
+ [
+ (None, None, identity),
+ (None, 64, identity),
+ ('gzip', None, gzip.decompress),
+ ('gzip', 256, gzip.decompress),
+ ]
+)
+def test_open_output_stream(fs, pathfn, compression, buffer_size,
+ decompressor):
+ p = pathfn('open-output-stream')
+
+ data = b'some data for writing' * 1024
+ with fs.open_output_stream(p, compression, buffer_size) as f:
+ f.write(data)
+
+ with fs.open_input_stream(p, compression, buffer_size) as f:
+ assert f.read(len(data)) == data
+
+
+@pytest.mark.gzip
+@pytest.mark.parametrize(
+ ('compression', 'buffer_size', 'compressor', 'decompressor'),
+ [
+ (None, None, identity, identity),
+ (None, 64, identity, identity),
+ ('gzip', None, gzip.compress, gzip.decompress),
+ ('gzip', 256, gzip.compress, gzip.decompress),
+ ]
+)
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+def test_open_append_stream(fs, pathfn, compression, buffer_size, compressor,
+ decompressor, allow_append_to_file):
+ p = pathfn('open-append-stream')
+
+ initial = compressor(b'already existing')
+ with fs.open_output_stream(p) as s:
+ s.write(initial)
+
+ if allow_append_to_file:
+ with fs.open_append_stream(p, compression=compression,
+ buffer_size=buffer_size) as f:
+ f.write(b'\nnewly added')
+
+ with fs.open_input_stream(p) as f:
+ result = f.read()
+
+ result = decompressor(result)
+ assert result == b'already existing\nnewly added'
+ else:
+ with pytest.raises(pa.ArrowNotImplementedError):
+ fs.open_append_stream(p, compression=compression,
+ buffer_size=buffer_size)
+
+
+def test_open_output_stream_metadata(fs, pathfn):
+ p = pathfn('open-output-stream-metadata')
+ metadata = {'Content-Type': 'x-pyarrow/test'}
+
+ data = b'some data'
+ with fs.open_output_stream(p, metadata=metadata) as f:
+ f.write(data)
+
+ with fs.open_input_stream(p) as f:
+ assert f.read() == data
+ got_metadata = f.metadata()
+
+ if fs.type_name == 's3' or 'mock' in fs.type_name:
+ for k, v in metadata.items():
+ assert got_metadata[k] == v.encode()
+ else:
+ assert got_metadata == {}
+
+
+def test_localfs_options():
+ # LocalFileSystem instantiation
+ LocalFileSystem(use_mmap=False)
+
+ with pytest.raises(TypeError):
+ LocalFileSystem(xxx=False)
+
+
+def test_localfs_errors(localfs):
+ # Local filesystem errors should raise the right Python exceptions
+ # (e.g. FileNotFoundError)
+ fs = localfs['fs']
+ with assert_file_not_found():
+ fs.open_input_stream('/non/existent/file')
+ with assert_file_not_found():
+ fs.open_output_stream('/non/existent/file')
+ with assert_file_not_found():
+ fs.create_dir('/non/existent/dir', recursive=False)
+ with assert_file_not_found():
+ fs.delete_dir('/non/existent/dir')
+ with assert_file_not_found():
+ fs.delete_file('/non/existent/dir')
+ with assert_file_not_found():
+ fs.move('/non/existent', '/xxx')
+ with assert_file_not_found():
+ fs.copy_file('/non/existent', '/xxx')
+
+
+def test_localfs_file_info(localfs):
+ fs = localfs['fs']
+
+ file_path = pathlib.Path(__file__)
+ dir_path = file_path.parent
+ [file_info, dir_info] = fs.get_file_info([file_path.as_posix(),
+ dir_path.as_posix()])
+ assert file_info.size == file_path.stat().st_size
+ assert file_info.mtime_ns == file_path.stat().st_mtime_ns
+ check_mtime(file_info)
+ assert dir_info.mtime_ns == dir_path.stat().st_mtime_ns
+ check_mtime(dir_info)
+
+
+def test_mockfs_mtime_roundtrip(mockfs):
+ dt = datetime.fromtimestamp(1568799826, timezone.utc)
+ fs = _MockFileSystem(dt)
+
+ with fs.open_output_stream('foo'):
+ pass
+ [info] = fs.get_file_info(['foo'])
+ assert info.mtime == dt
+
+
+@pytest.mark.s3
+def test_s3_options():
+ from pyarrow.fs import S3FileSystem
+
+ fs = S3FileSystem(access_key='access', secret_key='secret',
+ session_token='token', region='us-east-2',
+ scheme='https', endpoint_override='localhost:8999')
+ assert isinstance(fs, S3FileSystem)
+ assert fs.region == 'us-east-2'
+ assert pickle.loads(pickle.dumps(fs)) == fs
+
+ fs = S3FileSystem(role_arn='role', session_name='session',
+ external_id='id', load_frequency=100)
+ assert isinstance(fs, S3FileSystem)
+ assert pickle.loads(pickle.dumps(fs)) == fs
+
+ fs = S3FileSystem(anonymous=True)
+ assert isinstance(fs, S3FileSystem)
+ assert pickle.loads(pickle.dumps(fs)) == fs
+
+ fs = S3FileSystem(background_writes=True,
+ default_metadata={"ACL": "authenticated-read",
+ "Content-Type": "text/plain"})
+ assert isinstance(fs, S3FileSystem)
+ assert pickle.loads(pickle.dumps(fs)) == fs
+
+ with pytest.raises(ValueError):
+ S3FileSystem(access_key='access')
+ with pytest.raises(ValueError):
+ S3FileSystem(secret_key='secret')
+ with pytest.raises(ValueError):
+ S3FileSystem(access_key='access', session_token='token')
+ with pytest.raises(ValueError):
+ S3FileSystem(secret_key='secret', session_token='token')
+ with pytest.raises(ValueError):
+ S3FileSystem(
+ access_key='access', secret_key='secret', role_arn='arn'
+ )
+ with pytest.raises(ValueError):
+ S3FileSystem(
+ access_key='access', secret_key='secret', anonymous=True
+ )
+ with pytest.raises(ValueError):
+ S3FileSystem(role_arn="arn", anonymous=True)
+ with pytest.raises(ValueError):
+ S3FileSystem(default_metadata=["foo", "bar"])
+
+
+@pytest.mark.s3
+def test_s3_proxy_options(monkeypatch):
+ from pyarrow.fs import S3FileSystem
+
+ # The following two are equivalent:
+ proxy_opts_1_dict = {'scheme': 'http', 'host': 'localhost', 'port': 8999}
+ proxy_opts_1_str = 'http://localhost:8999'
+ # The following two are equivalent:
+ proxy_opts_2_dict = {'scheme': 'https', 'host': 'localhost', 'port': 8080}
+ proxy_opts_2_str = 'https://localhost:8080'
+
+ # Check dict case for 'proxy_options'
+ fs = S3FileSystem(proxy_options=proxy_opts_1_dict)
+ assert isinstance(fs, S3FileSystem)
+ assert pickle.loads(pickle.dumps(fs)) == fs
+
+ fs = S3FileSystem(proxy_options=proxy_opts_2_dict)
+ assert isinstance(fs, S3FileSystem)
+ assert pickle.loads(pickle.dumps(fs)) == fs
+
+ # Check str case for 'proxy_options'
+ fs = S3FileSystem(proxy_options=proxy_opts_1_str)
+ assert isinstance(fs, S3FileSystem)
+ assert pickle.loads(pickle.dumps(fs)) == fs
+
+ fs = S3FileSystem(proxy_options=proxy_opts_2_str)
+ assert isinstance(fs, S3FileSystem)
+ assert pickle.loads(pickle.dumps(fs)) == fs
+
+ # Check that two FSs using the same proxy_options dict are equal
+ fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_1_dict)
+ assert fs1 == fs2
+ assert pickle.loads(pickle.dumps(fs1)) == fs2
+ assert pickle.loads(pickle.dumps(fs2)) == fs1
+
+ fs1 = S3FileSystem(proxy_options=proxy_opts_2_dict)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_2_dict)
+ assert fs1 == fs2
+ assert pickle.loads(pickle.dumps(fs1)) == fs2
+ assert pickle.loads(pickle.dumps(fs2)) == fs1
+
+ # Check that two FSs using the same proxy_options str are equal
+ fs1 = S3FileSystem(proxy_options=proxy_opts_1_str)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_1_str)
+ assert fs1 == fs2
+ assert pickle.loads(pickle.dumps(fs1)) == fs2
+ assert pickle.loads(pickle.dumps(fs2)) == fs1
+
+ fs1 = S3FileSystem(proxy_options=proxy_opts_2_str)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_2_str)
+ assert fs1 == fs2
+ assert pickle.loads(pickle.dumps(fs1)) == fs2
+ assert pickle.loads(pickle.dumps(fs2)) == fs1
+
+ # Check that two FSs using equivalent proxy_options
+ # (one dict, one str) are equal
+ fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_1_str)
+ assert fs1 == fs2
+ assert pickle.loads(pickle.dumps(fs1)) == fs2
+ assert pickle.loads(pickle.dumps(fs2)) == fs1
+
+ fs1 = S3FileSystem(proxy_options=proxy_opts_2_dict)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_2_str)
+ assert fs1 == fs2
+ assert pickle.loads(pickle.dumps(fs1)) == fs2
+ assert pickle.loads(pickle.dumps(fs2)) == fs1
+
+ # Check that two FSs using nonequivalent proxy_options are not equal
+ fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_2_dict)
+ assert fs1 != fs2
+ assert pickle.loads(pickle.dumps(fs1)) != fs2
+ assert pickle.loads(pickle.dumps(fs2)) != fs1
+
+ fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_2_str)
+ assert fs1 != fs2
+ assert pickle.loads(pickle.dumps(fs1)) != fs2
+ assert pickle.loads(pickle.dumps(fs2)) != fs1
+
+ fs1 = S3FileSystem(proxy_options=proxy_opts_1_str)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_2_dict)
+ assert fs1 != fs2
+ assert pickle.loads(pickle.dumps(fs1)) != fs2
+ assert pickle.loads(pickle.dumps(fs2)) != fs1
+
+ fs1 = S3FileSystem(proxy_options=proxy_opts_1_str)
+ fs2 = S3FileSystem(proxy_options=proxy_opts_2_str)
+ assert fs1 != fs2
+ assert pickle.loads(pickle.dumps(fs1)) != fs2
+ assert pickle.loads(pickle.dumps(fs2)) != fs1
+
+ # Check that two FSs (one using proxy_options and the other not)
+ # are not equal
+ fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict)
+ fs2 = S3FileSystem()
+ assert fs1 != fs2
+ assert pickle.loads(pickle.dumps(fs1)) != fs2
+ assert pickle.loads(pickle.dumps(fs2)) != fs1
+
+ fs1 = S3FileSystem(proxy_options=proxy_opts_1_str)
+ fs2 = S3FileSystem()
+ assert fs1 != fs2
+ assert pickle.loads(pickle.dumps(fs1)) != fs2
+ assert pickle.loads(pickle.dumps(fs2)) != fs1
+
+ fs1 = S3FileSystem(proxy_options=proxy_opts_2_dict)
+ fs2 = S3FileSystem()
+ assert fs1 != fs2
+ assert pickle.loads(pickle.dumps(fs1)) != fs2
+ assert pickle.loads(pickle.dumps(fs2)) != fs1
+
+ fs1 = S3FileSystem(proxy_options=proxy_opts_2_str)
+ fs2 = S3FileSystem()
+ assert fs1 != fs2
+ assert pickle.loads(pickle.dumps(fs1)) != fs2
+ assert pickle.loads(pickle.dumps(fs2)) != fs1
+
+ # Only dict and str are supported
+ with pytest.raises(TypeError):
+ S3FileSystem(proxy_options=('http', 'localhost', 9090))
+ # Missing scheme
+ with pytest.raises(KeyError):
+ S3FileSystem(proxy_options={'host': 'localhost', 'port': 9090})
+ # Missing host
+ with pytest.raises(KeyError):
+ S3FileSystem(proxy_options={'scheme': 'https', 'port': 9090})
+ # Missing port
+ with pytest.raises(KeyError):
+ S3FileSystem(proxy_options={'scheme': 'http', 'host': 'localhost'})
+ # Invalid proxy URI (invalid scheme htttps)
+ with pytest.raises(pa.ArrowInvalid):
+ S3FileSystem(proxy_options='htttps://localhost:9000')
+ # Invalid proxy_options dict (invalid scheme htttps)
+ with pytest.raises(pa.ArrowInvalid):
+ S3FileSystem(proxy_options={'scheme': 'htttp', 'host': 'localhost',
+ 'port': 8999})
+
+
+@pytest.mark.hdfs
+def test_hdfs_options(hdfs_connection):
+ from pyarrow.fs import HadoopFileSystem
+ if not pa.have_libhdfs():
+ pytest.skip('Cannot locate libhdfs')
+
+ host, port, user = hdfs_connection
+
+ replication = 2
+ buffer_size = 64*1024
+ default_block_size = 128*1024**2
+ uri = ('hdfs://{}:{}/?user={}&replication={}&buffer_size={}'
+ '&default_block_size={}')
+
+ hdfs1 = HadoopFileSystem(host, port, user='libhdfs',
+ replication=replication, buffer_size=buffer_size,
+ default_block_size=default_block_size)
+ hdfs2 = HadoopFileSystem.from_uri(uri.format(
+ host, port, 'libhdfs', replication, buffer_size, default_block_size
+ ))
+ hdfs3 = HadoopFileSystem.from_uri(uri.format(
+ host, port, 'me', replication, buffer_size, default_block_size
+ ))
+ hdfs4 = HadoopFileSystem.from_uri(uri.format(
+ host, port, 'me', replication + 1, buffer_size, default_block_size
+ ))
+ hdfs5 = HadoopFileSystem(host, port)
+ hdfs6 = HadoopFileSystem.from_uri('hdfs://{}:{}'.format(host, port))
+ hdfs7 = HadoopFileSystem(host, port, user='localuser')
+ hdfs8 = HadoopFileSystem(host, port, user='localuser',
+ kerb_ticket="cache_path")
+ hdfs9 = HadoopFileSystem(host, port, user='localuser',
+ kerb_ticket=pathlib.Path("cache_path"))
+ hdfs10 = HadoopFileSystem(host, port, user='localuser',
+ kerb_ticket="cache_path2")
+ hdfs11 = HadoopFileSystem(host, port, user='localuser',
+ kerb_ticket="cache_path",
+ extra_conf={'hdfs_token': 'abcd'})
+
+ assert hdfs1 == hdfs2
+ assert hdfs5 == hdfs6
+ assert hdfs6 != hdfs7
+ assert hdfs2 != hdfs3
+ assert hdfs3 != hdfs4
+ assert hdfs7 != hdfs5
+ assert hdfs2 != hdfs3
+ assert hdfs3 != hdfs4
+ assert hdfs7 != hdfs8
+ assert hdfs8 == hdfs9
+ assert hdfs10 != hdfs9
+ assert hdfs11 != hdfs8
+
+ with pytest.raises(TypeError):
+ HadoopFileSystem()
+ with pytest.raises(TypeError):
+ HadoopFileSystem.from_uri(3)
+
+ for fs in [hdfs1, hdfs2, hdfs3, hdfs4, hdfs5, hdfs6, hdfs7, hdfs8,
+ hdfs9, hdfs10, hdfs11]:
+ assert pickle.loads(pickle.dumps(fs)) == fs
+
+ host, port, user = hdfs_connection
+
+ hdfs = HadoopFileSystem(host, port, user=user)
+ assert hdfs.get_file_info(FileSelector('/'))
+
+ hdfs = HadoopFileSystem.from_uri(
+ "hdfs://{}:{}/?user={}".format(host, port, user)
+ )
+ assert hdfs.get_file_info(FileSelector('/'))
+
+
+@pytest.mark.parametrize(('uri', 'expected_klass', 'expected_path'), [
+ # leading slashes are removed intentionally, because MockFileSystem doesn't
+ # have a distinction between relative and absolute paths
+ ('mock:', _MockFileSystem, ''),
+ ('mock:foo/bar', _MockFileSystem, 'foo/bar'),
+ ('mock:/foo/bar', _MockFileSystem, 'foo/bar'),
+ ('mock:///foo/bar', _MockFileSystem, 'foo/bar'),
+ ('file:/', LocalFileSystem, '/'),
+ ('file:///', LocalFileSystem, '/'),
+ ('file:/foo/bar', LocalFileSystem, '/foo/bar'),
+ ('file:///foo/bar', LocalFileSystem, '/foo/bar'),
+ ('/', LocalFileSystem, '/'),
+ ('/foo/bar', LocalFileSystem, '/foo/bar'),
+])
+def test_filesystem_from_uri(uri, expected_klass, expected_path):
+ fs, path = FileSystem.from_uri(uri)
+ assert isinstance(fs, expected_klass)
+ assert path == expected_path
+
+
+@pytest.mark.parametrize(
+ 'path',
+ ['', '/', 'foo/bar', '/foo/bar', __file__]
+)
+def test_filesystem_from_path_object(path):
+ p = pathlib.Path(path)
+ fs, path = FileSystem.from_uri(p)
+ assert isinstance(fs, LocalFileSystem)
+ assert path == p.resolve().absolute().as_posix()
+
+
+@pytest.mark.s3
+def test_filesystem_from_uri_s3(s3_server):
+ from pyarrow.fs import S3FileSystem
+
+ host, port, access_key, secret_key = s3_server['connection']
+
+ uri = "s3://{}:{}@mybucket/foo/bar?scheme=http&endpoint_override={}:{}" \
+ .format(access_key, secret_key, host, port)
+
+ fs, path = FileSystem.from_uri(uri)
+ assert isinstance(fs, S3FileSystem)
+ assert path == "mybucket/foo/bar"
+
+ fs.create_dir(path)
+ [info] = fs.get_file_info([path])
+ assert info.path == path
+ assert info.type == FileType.Directory
+
+
+def test_py_filesystem():
+ handler = DummyHandler()
+ fs = PyFileSystem(handler)
+ assert isinstance(fs, PyFileSystem)
+ assert fs.type_name == "py::dummy"
+ assert fs.handler is handler
+
+ with pytest.raises(TypeError):
+ PyFileSystem(None)
+
+
+def test_py_filesystem_equality():
+ handler1 = DummyHandler(1)
+ handler2 = DummyHandler(2)
+ handler3 = DummyHandler(2)
+ fs1 = PyFileSystem(handler1)
+ fs2 = PyFileSystem(handler1)
+ fs3 = PyFileSystem(handler2)
+ fs4 = PyFileSystem(handler3)
+
+ assert fs2 is not fs1
+ assert fs3 is not fs2
+ assert fs4 is not fs3
+ assert fs2 == fs1 # Same handler
+ assert fs3 != fs2 # Unequal handlers
+ assert fs4 == fs3 # Equal handlers
+
+ assert fs1 != LocalFileSystem()
+ assert fs1 != object()
+
+
+def test_py_filesystem_pickling():
+ handler = DummyHandler()
+ fs = PyFileSystem(handler)
+
+ serialized = pickle.dumps(fs)
+ restored = pickle.loads(serialized)
+ assert isinstance(restored, FileSystem)
+ assert restored == fs
+ assert restored.handler == handler
+ assert restored.type_name == "py::dummy"
+
+
+def test_py_filesystem_lifetime():
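+ # The PyFileSystem keeps its handler alive; dropping the filesystem
+ # releases the handler.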
+ handler = DummyHandler()
+ fs = PyFileSystem(handler)
+ assert isinstance(fs, PyFileSystem)
+ wr = weakref.ref(handler)
+ handler = None
+ assert wr() is not None
+ fs = None
+ assert wr() is None
+
+ # Taking the .handler attribute doesn't wreck reference counts
+ handler = DummyHandler()
+ fs = PyFileSystem(handler)
+ wr = weakref.ref(handler)
+ handler = None
+ assert wr() is fs.handler
+ assert wr() is not None
+ fs = None
+ assert wr() is None
+
+
+def test_py_filesystem_get_file_info():
+ handler = DummyHandler()
+ fs = PyFileSystem(handler)
+
+ [info] = fs.get_file_info(['some/dir'])
+ assert info.path == 'some/dir'
+ assert info.type == FileType.Directory
+
+ [info] = fs.get_file_info(['some/file'])
+ assert info.path == 'some/file'
+ assert info.type == FileType.File
+
+ [info] = fs.get_file_info(['notfound'])
+ assert info.path == 'notfound'
+ assert info.type == FileType.NotFound
+
+ with pytest.raises(TypeError):
+ fs.get_file_info(['badtype'])
+
+ with pytest.raises(IOError):
+ fs.get_file_info(['xxx'])
+
+
+def test_py_filesystem_get_file_info_selector():
+ handler = DummyHandler()
+ fs = PyFileSystem(handler)
+
+ selector = FileSelector(base_dir="somedir")
+ infos = fs.get_file_info(selector)
+ assert len(infos) == 2
+ assert infos[0].path == "somedir/file1"
+ assert infos[0].type == FileType.File
+ assert infos[0].size == 123
+ assert infos[1].path == "somedir/subdir1"
+ assert infos[1].type == FileType.Directory
+ assert infos[1].size is None
+
+ selector = FileSelector(base_dir="somedir", recursive=True)
+ infos = fs.get_file_info(selector)
+ assert len(infos) == 3
+ assert infos[0].path == "somedir/file1"
+ assert infos[1].path == "somedir/subdir1"
+ assert infos[2].path == "somedir/subdir1/file2"
+
+ selector = FileSelector(base_dir="notfound")
+ with pytest.raises(FileNotFoundError):
+ fs.get_file_info(selector)
+
+ selector = FileSelector(base_dir="notfound", allow_not_found=True)
+ assert fs.get_file_info(selector) == []
+
+
+def test_py_filesystem_ops():
+ handler = DummyHandler()
+ fs = PyFileSystem(handler)
+
+ fs.create_dir("recursive", recursive=True)
+ fs.create_dir("non-recursive", recursive=False)
+ with pytest.raises(IOError):
+ fs.create_dir("foobar")
+
+ fs.delete_dir("delete_dir")
+ fs.delete_dir_contents("delete_dir_contents")
+ for path in ("", "/", "//"):
+ with pytest.raises(ValueError):
+ fs.delete_dir_contents(path)
+ fs.delete_dir_contents(path, accept_root_dir=True)
+ fs.delete_file("delete_file")
+ fs.move("move_from", "move_to")
+ fs.copy_file("copy_file_from", "copy_file_to")
+
+
+def test_py_open_input_stream():
+ fs = PyFileSystem(DummyHandler())
+
+ with fs.open_input_stream("somefile") as f:
+ assert f.read() == b"somefile:input_stream"
+ with pytest.raises(FileNotFoundError):
+ fs.open_input_stream("notfound")
+
+
+def test_py_open_input_file():
+ fs = PyFileSystem(DummyHandler())
+
+ with fs.open_input_file("somefile") as f:
+ assert f.read() == b"somefile:input_file"
+ with pytest.raises(FileNotFoundError):
+ fs.open_input_file("notfound")
+
+
+def test_py_open_output_stream():
+ fs = PyFileSystem(DummyHandler())
+
+ with fs.open_output_stream("somefile") as f:
+ f.write(b"data")
+
+
+@pytest.mark.filterwarnings("ignore::FutureWarning")
+def test_py_open_append_stream():
+ fs = PyFileSystem(DummyHandler())
+
+ with fs.open_append_stream("somefile") as f:
+ f.write(b"data")
+
+
+@pytest.mark.s3
+def test_s3_real_aws():
+ # Exercise connection code with an AWS-backed S3 bucket.
+ # This is a minimal integration check for ARROW-9261 and similar issues.
+ from pyarrow.fs import S3FileSystem
+ default_region = (os.environ.get('PYARROW_TEST_S3_REGION') or
+ 'us-east-1')
+ fs = S3FileSystem(anonymous=True)
+ assert fs.region == default_region
+
+ fs = S3FileSystem(anonymous=True, region='us-east-2')
+ entries = fs.get_file_info(FileSelector('ursa-labs-taxi-data'))
+ assert len(entries) > 0
+ with fs.open_input_stream('ursa-labs-taxi-data/2019/06/data.parquet') as f:
+ md = f.metadata()
+ assert 'Content-Type' in md
+ assert md['Last-Modified'] == b'2020-01-17T16:26:28Z'
+ # For some reason, the header value is quoted
+ # (both with AWS and Minio)
+ assert md['ETag'] == b'"f1efd5d76cb82861e1542117bfa52b90-8"'
+
+
+@pytest.mark.s3
+def test_s3_real_aws_region_selection():
+ # Taken from a registry of open S3-hosted datasets
+ # at https://github.com/awslabs/open-data-registry
+ fs, path = FileSystem.from_uri('s3://mf-nwp-models/README.txt')
+ assert fs.region == 'eu-west-1'
+ with fs.open_input_stream(path) as f:
+ assert b"Meteo-France Atmospheric models on AWS" in f.read(50)
+
+ # Passing an explicit region disables auto-selection
+ fs, path = FileSystem.from_uri(
+ 's3://mf-nwp-models/README.txt?region=us-east-2')
+ assert fs.region == 'us-east-2'
+ # Reading from the wrong region may still work for public buckets...
+
+ # Non-existent bucket (hopefully; otherwise this test needs fixing)
+ with pytest.raises(IOError, match="Bucket '.*' not found"):
+ FileSystem.from_uri('s3://x-arrow-non-existent-bucket')
+ fs, path = FileSystem.from_uri(
+ 's3://x-arrow-non-existent-bucket?region=us-east-3')
+ assert fs.region == 'us-east-3'
+
+
+@pytest.mark.s3
+def test_copy_files(s3_connection, s3fs, tempdir):
+ fs = s3fs["fs"]
+ pathfn = s3fs["pathfn"]
+
+ # create test file on S3 filesystem
+ path = pathfn('c.txt')
+ with fs.open_output_stream(path) as f:
+ f.write(b'test')
+
+ # create URI for created file
+ host, port, access_key, secret_key = s3_connection
+ source_uri = (
+ f"s3://{access_key}:{secret_key}@{path}"
+ f"?scheme=http&endpoint_override={host}:{port}"
+ )
+ # copy from S3 URI to local file
+ local_path1 = str(tempdir / "c_copied1.txt")
+ copy_files(source_uri, local_path1)
+
+ localfs = LocalFileSystem()
+ with localfs.open_input_stream(local_path1) as f:
+ assert f.read() == b"test"
+
+ # copy from S3 path+filesystem to local file
+ local_path2 = str(tempdir / "c_copied2.txt")
+ copy_files(path, local_path2, source_filesystem=fs)
+ with localfs.open_input_stream(local_path2) as f:
+ assert f.read() == b"test"
+
+ # copy to local file with URI
+ local_path3 = str(tempdir / "c_copied3.txt")
+ destination_uri = _filesystem_uri(local_path3) # file://
+ copy_files(source_uri, destination_uri)
+
+ with localfs.open_input_stream(local_path3) as f:
+ assert f.read() == b"test"
+
+ # copy to local file with path+filesystem
+ local_path4 = str(tempdir / "c_copied4.txt")
+ copy_files(source_uri, local_path4, destination_filesystem=localfs)
+
+ with localfs.open_input_stream(local_path4) as f:
+ assert f.read() == b"test"
+
+ # copy with additional options
+ local_path5 = str(tempdir / "c_copied5.txt")
+ copy_files(source_uri, local_path5, chunk_size=1, use_threads=False)
+
+ with localfs.open_input_stream(local_path5) as f:
+ assert f.read() == b"test"
+
+
+def test_copy_files_directory(tempdir):
+ localfs = LocalFileSystem()
+
+ # create source directory with 2 files
+ source_dir = tempdir / "source"
+ source_dir.mkdir()
+ with localfs.open_output_stream(str(source_dir / "file1")) as f:
+ f.write(b'test1')
+ with localfs.open_output_stream(str(source_dir / "file2")) as f:
+ f.write(b'test2')
+
+ def check_copied_files(destination_dir):
+ with localfs.open_input_stream(str(destination_dir / "file1")) as f:
+ assert f.read() == b"test1"
+ with localfs.open_input_stream(str(destination_dir / "file2")) as f:
+ assert f.read() == b"test2"
+
+ # Copy directory with local file paths
+ destination_dir1 = tempdir / "destination1"
+ # TODO: do we need to create the destination directory first?
+ destination_dir1.mkdir()
+ copy_files(str(source_dir), str(destination_dir1))
+ check_copied_files(destination_dir1)
+
+ # Copy directory with path+filesystem
+ destination_dir2 = tempdir / "destination2"
+ destination_dir2.mkdir()
+ copy_files(str(source_dir), str(destination_dir2),
+ source_filesystem=localfs, destination_filesystem=localfs)
+ check_copied_files(destination_dir2)
+
+ # Copy directory with URI
+ destination_dir3 = tempdir / "destination3"
+ destination_dir3.mkdir()
+ source_uri = _filesystem_uri(str(source_dir)) # file://
+ destination_uri = _filesystem_uri(str(destination_dir3))
+ copy_files(source_uri, destination_uri)
+ check_copied_files(destination_dir3)
+
+ # Copy directory with Path objects
+ destination_dir4 = tempdir / "destination4"
+ destination_dir4.mkdir()
+ copy_files(source_dir, destination_dir4)
+ check_copied_files(destination_dir4)
+
+ # copy with additional non-default options
+ destination_dir5 = tempdir / "destination5"
+ destination_dir5.mkdir()
+ copy_files(source_dir, destination_dir5, chunk_size=1, use_threads=False)
+ check_copied_files(destination_dir5)
diff --git a/src/arrow/python/pyarrow/tests/test_gandiva.py b/src/arrow/python/pyarrow/tests/test_gandiva.py
new file mode 100644
index 000000000..6522c233a
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_gandiva.py
@@ -0,0 +1,391 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import pytest
+
+import pyarrow as pa
+
+
+@pytest.mark.gandiva
+def test_tree_exp_builder():
+ import pyarrow.gandiva as gandiva
+
+ builder = gandiva.TreeExprBuilder()
+
+ field_a = pa.field('a', pa.int32())
+ field_b = pa.field('b', pa.int32())
+
+ schema = pa.schema([field_a, field_b])
+
+ field_result = pa.field('res', pa.int32())
+
+ node_a = builder.make_field(field_a)
+ node_b = builder.make_field(field_b)
+
+ assert node_a.return_type() == field_a.type
+
+ condition = builder.make_function("greater_than", [node_a, node_b],
+ pa.bool_())
+ if_node = builder.make_if(condition, node_a, node_b, pa.int32())
+
+ expr = builder.make_expression(if_node, field_result)
+
+ assert expr.result().type == pa.int32()
+
+ projector = gandiva.make_projector(
+ schema, [expr], pa.default_memory_pool())
+
+ # Gandiva generates a compute kernel function named `@expr_X`
+ assert projector.llvm_ir.find("@expr_") != -1
+
+ a = pa.array([10, 12, -20, 5], type=pa.int32())
+ b = pa.array([5, 15, 15, 17], type=pa.int32())
+ e = pa.array([10, 15, 15, 17], type=pa.int32())
+ input_batch = pa.RecordBatch.from_arrays([a, b], names=['a', 'b'])
+
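+ # evaluate() returns one output array per expression passed to
+ # make_projector, hence the single-element unpacking below.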
+ r, = projector.evaluate(input_batch)
+ assert r.equals(e)
+
+
+@pytest.mark.gandiva
+def test_table():
+ import pyarrow.gandiva as gandiva
+
+ table = pa.Table.from_arrays([pa.array([1.0, 2.0]), pa.array([3.0, 4.0])],
+ ['a', 'b'])
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ node_b = builder.make_field(table.schema.field("b"))
+
+ sum = builder.make_function("add", [node_a, node_b], pa.float64())
+
+ field_result = pa.field("c", pa.float64())
+ expr = builder.make_expression(sum, field_result)
+
+ projector = gandiva.make_projector(
+ table.schema, [expr], pa.default_memory_pool())
+
+ # TODO: Add .evaluate function which can take Tables instead of
+ # RecordBatches
+ r, = projector.evaluate(table.to_batches()[0])
+
+ e = pa.array([4.0, 6.0])
+ assert r.equals(e)
+
+
+@pytest.mark.gandiva
+def test_filter():
+ import pyarrow.gandiva as gandiva
+
+ table = pa.Table.from_arrays([pa.array([1.0 * i for i in range(10000)])],
+ ['a'])
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ thousand = builder.make_literal(1000.0, pa.float64())
+ cond = builder.make_function("less_than", [node_a, thousand], pa.bool_())
+ condition = builder.make_condition(cond)
+
+ assert condition.result().type == pa.bool_()
+
+ filter = gandiva.make_filter(table.schema, condition)
+ # Gandiva generates a compute kernel function named `@expr_X`
+ assert filter.llvm_ir.find("@expr_") != -1
+
+ result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
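+ # evaluate() returns a SelectionVector holding the indices of matching rows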
+ assert result.to_array().equals(pa.array(range(1000), type=pa.uint32()))
+
+
+@pytest.mark.gandiva
+def test_in_expr():
+ import pyarrow.gandiva as gandiva
+
+ arr = pa.array(["ga", "an", "nd", "di", "iv", "va"])
+ table = pa.Table.from_arrays([arr], ["a"])
+
+ # string
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ cond = builder.make_in_expression(node_a, ["an", "nd"], pa.string())
+ condition = builder.make_condition(cond)
+ filter = gandiva.make_filter(table.schema, condition)
+ result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
+ assert result.to_array().equals(pa.array([1, 2], type=pa.uint32()))
+
+ # int32
+ arr = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
+ table = pa.Table.from_arrays([arr.cast(pa.int32())], ["a"])
+ node_a = builder.make_field(table.schema.field("a"))
+ cond = builder.make_in_expression(node_a, [1, 5], pa.int32())
+ condition = builder.make_condition(cond)
+ filter = gandiva.make_filter(table.schema, condition)
+ result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
+ assert result.to_array().equals(pa.array([1, 3, 4, 8], type=pa.uint32()))
+
+ # int64
+ arr = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
+ table = pa.Table.from_arrays([arr], ["a"])
+ node_a = builder.make_field(table.schema.field("a"))
+ cond = builder.make_in_expression(node_a, [1, 5], pa.int64())
+ condition = builder.make_condition(cond)
+ filter = gandiva.make_filter(table.schema, condition)
+ result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
+ assert result.to_array().equals(pa.array([1, 3, 4, 8], type=pa.uint32()))
+
+
+@pytest.mark.skip(reason="Gandiva C++ does not have *real* binary, "
+ "time and date support yet.")
+def test_in_expr_todo():
+ import pyarrow.gandiva as gandiva
+ # TODO: Implement reasonable support for timestamp, time & date.
+ # Current exceptions:
+ # pyarrow.lib.ArrowException: ExpressionValidationError:
+ # Evaluation expression for IN clause returns XXXX values are of typeXXXX
+
+ # binary
+ arr = pa.array([b"ga", b"an", b"nd", b"di", b"iv", b"va"])
+ table = pa.Table.from_arrays([arr], ["a"])
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ cond = builder.make_in_expression(node_a, [b'an', b'nd'], pa.binary())
+ condition = builder.make_condition(cond)
+
+ filter = gandiva.make_filter(table.schema, condition)
+ result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
+ assert result.to_array().equals(pa.array([1, 2], type=pa.uint32()))
+
+ # timestamp
+ datetime_1 = datetime.datetime.utcfromtimestamp(1542238951.621877)
+ datetime_2 = datetime.datetime.utcfromtimestamp(1542238911.621877)
+ datetime_3 = datetime.datetime.utcfromtimestamp(1542238051.621877)
+
+ arr = pa.array([datetime_1, datetime_2, datetime_3])
+ table = pa.Table.from_arrays([arr], ["a"])
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ cond = builder.make_in_expression(node_a, [datetime_2], pa.timestamp('ms'))
+ condition = builder.make_condition(cond)
+
+ filter = gandiva.make_filter(table.schema, condition)
+ result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
+ assert list(result.to_array()) == [1]
+
+ # time
+ time_1 = datetime_1.time()
+ time_2 = datetime_2.time()
+ time_3 = datetime_3.time()
+
+ arr = pa.array([time_1, time_2, time_3])
+ table = pa.Table.from_arrays([arr], ["a"])
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ cond = builder.make_in_expression(node_a, [time_2], pa.time64('ms'))
+ condition = builder.make_condition(cond)
+
+ filter = gandiva.make_filter(table.schema, condition)
+ result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
+ assert list(result.to_array()) == [1]
+
+ # date
+ date_1 = datetime_1.date()
+ date_2 = datetime_2.date()
+ date_3 = datetime_3.date()
+
+ arr = pa.array([date_1, date_2, date_3])
+ table = pa.Table.from_arrays([arr], ["a"])
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ cond = builder.make_in_expression(node_a, [date_2], pa.date32())
+ condition = builder.make_condition(cond)
+
+ filter = gandiva.make_filter(table.schema, condition)
+ result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
+ assert list(result.to_array()) == [1]
+
+
+@pytest.mark.gandiva
+def test_boolean():
+ import pyarrow.gandiva as gandiva
+
+ table = pa.Table.from_arrays([
+ pa.array([1., 31., 46., 3., 57., 44., 22.]),
+ pa.array([5., 45., 36., 73., 83., 23., 76.])],
+ ['a', 'b'])
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ node_b = builder.make_field(table.schema.field("b"))
+ fifty = builder.make_literal(50.0, pa.float64())
+ eleven = builder.make_literal(11.0, pa.float64())
+
+ cond_1 = builder.make_function("less_than", [node_a, fifty], pa.bool_())
+ cond_2 = builder.make_function("greater_than", [node_a, node_b],
+ pa.bool_())
+ cond_3 = builder.make_function("less_than", [node_b, eleven], pa.bool_())
+ cond = builder.make_or([builder.make_and([cond_1, cond_2]), cond_3])
+ condition = builder.make_condition(cond)
+
+ filter = gandiva.make_filter(table.schema, condition)
+ result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
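+ # Row 0 matches via b < 11; rows 2 and 5 match via (a < 50 and a > b)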
+ assert result.to_array().equals(pa.array([0, 2, 5], type=pa.uint32()))
+
+
+@pytest.mark.gandiva
+def test_literals():
+ import pyarrow.gandiva as gandiva
+
+ builder = gandiva.TreeExprBuilder()
+
+ builder.make_literal(True, pa.bool_())
+ builder.make_literal(0, pa.uint8())
+ builder.make_literal(1, pa.uint16())
+ builder.make_literal(2, pa.uint32())
+ builder.make_literal(3, pa.uint64())
+ builder.make_literal(4, pa.int8())
+ builder.make_literal(5, pa.int16())
+ builder.make_literal(6, pa.int32())
+ builder.make_literal(7, pa.int64())
+ builder.make_literal(8.0, pa.float32())
+ builder.make_literal(9.0, pa.float64())
+ builder.make_literal("hello", pa.string())
+ builder.make_literal(b"world", pa.binary())
+
+ builder.make_literal(True, "bool")
+ builder.make_literal(0, "uint8")
+ builder.make_literal(1, "uint16")
+ builder.make_literal(2, "uint32")
+ builder.make_literal(3, "uint64")
+ builder.make_literal(4, "int8")
+ builder.make_literal(5, "int16")
+ builder.make_literal(6, "int32")
+ builder.make_literal(7, "int64")
+ builder.make_literal(8.0, "float32")
+ builder.make_literal(9.0, "float64")
+ builder.make_literal("hello", "string")
+ builder.make_literal(b"world", "binary")
+
+ with pytest.raises(TypeError):
+ builder.make_literal("hello", pa.int64())
+ with pytest.raises(TypeError):
+ builder.make_literal(True, None)
+
+
+@pytest.mark.gandiva
+def test_regex():
+ import pyarrow.gandiva as gandiva
+
+ elements = ["park", "sparkle", "bright spark and fire", "spark"]
+ data = pa.array(elements, type=pa.string())
+ table = pa.Table.from_arrays([data], names=['a'])
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ regex = builder.make_literal("%spark%", pa.string())
+ like = builder.make_function("like", [node_a, regex], pa.bool_())
+
+ field_result = pa.field("b", pa.bool_())
+ expr = builder.make_expression(like, field_result)
+
+ projector = gandiva.make_projector(
+ table.schema, [expr], pa.default_memory_pool())
+
+ r, = projector.evaluate(table.to_batches()[0])
+ b = pa.array([False, True, True, True], type=pa.bool_())
+ assert r.equals(b)
+
+
+@pytest.mark.gandiva
+def test_get_registered_function_signatures():
+ import pyarrow.gandiva as gandiva
+ signatures = gandiva.get_registered_function_signatures()
+
+ assert type(signatures[0].return_type()) is pa.DataType
+ assert type(signatures[0].param_types()) is list
+ assert hasattr(signatures[0], "name")
+
+
+@pytest.mark.gandiva
+def test_filter_project():
+ import pyarrow.gandiva as gandiva
+ mpool = pa.default_memory_pool()
+ # Create a table with some sample data
+ array0 = pa.array([10, 12, -20, 5, 21, 29], pa.int32())
+ array1 = pa.array([5, 15, 15, 17, 12, 3], pa.int32())
+ array2 = pa.array([1, 25, 11, 30, -21, None], pa.int32())
+
+ table = pa.Table.from_arrays([array0, array1, array2], ['a', 'b', 'c'])
+
+ field_result = pa.field("res", pa.int32())
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field("a"))
+ node_b = builder.make_field(table.schema.field("b"))
+ node_c = builder.make_field(table.schema.field("c"))
+
+ greater_than_function = builder.make_function("greater_than",
+ [node_a, node_b], pa.bool_())
+ filter_condition = builder.make_condition(
+ greater_than_function)
+
+ project_condition = builder.make_function("less_than",
+ [node_b, node_c], pa.bool_())
+ if_node = builder.make_if(project_condition,
+ node_b, node_c, pa.int32())
+ expr = builder.make_expression(if_node, field_result)
+
+ # Build a filter for the expressions.
+ filter = gandiva.make_filter(table.schema, filter_condition)
+
+ # Build a projector for the expressions.
+ projector = gandiva.make_projector(
+ table.schema, [expr], mpool, "UINT32")
+
+ # Evaluate filter
+ selection_vector = filter.evaluate(table.to_batches()[0], mpool)
+
+ # Evaluate project
+ r, = projector.evaluate(
+ table.to_batches()[0], selection_vector)
+
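+ # The filter keeps rows where a > b (rows 0, 4 and 5); for those rows the
+ # projector returns b if b < c else c, i.e. 1, -21 and null (c is null in
+ # row 5).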
+ exp = pa.array([1, -21, None], pa.int32())
+ assert r.equals(exp)
+
+
+@pytest.mark.gandiva
+def test_to_string():
+ import pyarrow.gandiva as gandiva
+ builder = gandiva.TreeExprBuilder()
+
+ assert str(builder.make_literal(2.0, pa.float64())
+ ).startswith('(const double) 2 raw(')
+ assert str(builder.make_literal(2, pa.int64())) == '(const int64) 2'
+ assert str(builder.make_field(pa.field('x', pa.float64()))) == '(double) x'
+ assert str(builder.make_field(pa.field('y', pa.string()))) == '(string) y'
+
+ field_z = builder.make_field(pa.field('z', pa.bool_()))
+ func_node = builder.make_function('not', [field_z], pa.bool_())
+ assert str(func_node) == 'bool not((bool) z)'
+
+ field_y = builder.make_field(pa.field('y', pa.bool_()))
+ and_node = builder.make_and([func_node, field_y])
+ assert str(and_node) == 'bool not((bool) z) && (bool) y'
diff --git a/src/arrow/python/pyarrow/tests/test_hdfs.py b/src/arrow/python/pyarrow/tests/test_hdfs.py
new file mode 100644
index 000000000..c71353b45
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_hdfs.py
@@ -0,0 +1,447 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import pickle
+import random
+import unittest
+from io import BytesIO
+from os.path import join as pjoin
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+from pyarrow.pandas_compat import _pandas_api
+from pyarrow.tests import util
+from pyarrow.tests.parquet.common import _test_dataframe
+from pyarrow.tests.parquet.test_dataset import (
+ _test_read_common_metadata_files, _test_write_to_dataset_with_partitions,
+ _test_write_to_dataset_no_partitions
+)
+from pyarrow.util import guid
+
+# ----------------------------------------------------------------------
+# HDFS tests
+
+
+def check_libhdfs_present():
+ if not pa.have_libhdfs():
+ message = 'No libhdfs available on system'
+ if os.environ.get('PYARROW_HDFS_TEST_LIBHDFS_REQUIRE'):
+ pytest.fail(message)
+ else:
+ pytest.skip(message)
+
+
+def hdfs_test_client():
+ host = os.environ.get('ARROW_HDFS_TEST_HOST', 'default')
+ user = os.environ.get('ARROW_HDFS_TEST_USER', None)
+ try:
+ port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0))
+ except ValueError:
+ raise ValueError('Env variable ARROW_HDFS_TEST_PORT was not '
+ 'an integer')
+
+ with pytest.warns(FutureWarning):
+ return pa.hdfs.connect(host, port, user)
+
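+# The legacy pa.hdfs.connect API used above is deprecated (hence the
+# FutureWarning). A rough equivalent with the newer pyarrow.fs layer, shown
+# only as an illustrative sketch and not exercised by these tests:
+#
+#   from pyarrow import fs
+#   hdfs = fs.HadoopFileSystem(host, port=port, user=user)
+#   hdfs.get_file_info(fs.FileSelector('/tmp'))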
+
+@pytest.mark.hdfs
+class HdfsTestCases:
+
+ def _make_test_file(self, hdfs, test_name, test_path, test_data):
+ base_path = pjoin(self.tmp_path, test_name)
+ hdfs.mkdir(base_path)
+
+ full_path = pjoin(base_path, test_path)
+
+ with hdfs.open(full_path, 'wb') as f:
+ f.write(test_data)
+
+ return full_path
+
+ @classmethod
+ def setUpClass(cls):
+ cls.check_driver()
+ cls.hdfs = hdfs_test_client()
+ cls.tmp_path = '/tmp/pyarrow-test-{}'.format(random.randint(0, 1000))
+ cls.hdfs.mkdir(cls.tmp_path)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.hdfs.delete(cls.tmp_path, recursive=True)
+ cls.hdfs.close()
+
+ def test_pickle(self):
+ s = pickle.dumps(self.hdfs)
+ h2 = pickle.loads(s)
+ assert h2.is_open
+ assert h2.host == self.hdfs.host
+ assert h2.port == self.hdfs.port
+ assert h2.user == self.hdfs.user
+ assert h2.kerb_ticket == self.hdfs.kerb_ticket
+ # smoketest unpickled client works
+ h2.ls(self.tmp_path)
+
+ def test_cat(self):
+ path = pjoin(self.tmp_path, 'cat-test')
+
+ data = b'foobarbaz'
+ with self.hdfs.open(path, 'wb') as f:
+ f.write(data)
+
+ contents = self.hdfs.cat(path)
+ assert contents == data
+
+ def test_capacity_space(self):
+ capacity = self.hdfs.get_capacity()
+ space_used = self.hdfs.get_space_used()
+ disk_free = self.hdfs.df()
+
+ assert capacity > 0
+ assert capacity > space_used
+ assert disk_free == (capacity - space_used)
+
+ def test_close(self):
+ client = hdfs_test_client()
+ assert client.is_open
+ client.close()
+ assert not client.is_open
+
+ with pytest.raises(Exception):
+ client.ls('/')
+
+ def test_mkdir(self):
+ path = pjoin(self.tmp_path, 'test-dir/test-dir')
+ parent_path = pjoin(self.tmp_path, 'test-dir')
+
+ self.hdfs.mkdir(path)
+ assert self.hdfs.exists(path)
+
+ self.hdfs.delete(parent_path, recursive=True)
+ assert not self.hdfs.exists(path)
+
+ def test_mv_rename(self):
+ path = pjoin(self.tmp_path, 'mv-test')
+ new_path = pjoin(self.tmp_path, 'mv-new-test')
+
+ data = b'foobarbaz'
+ with self.hdfs.open(path, 'wb') as f:
+ f.write(data)
+
+ assert self.hdfs.exists(path)
+ self.hdfs.mv(path, new_path)
+ assert not self.hdfs.exists(path)
+ assert self.hdfs.exists(new_path)
+
+ assert self.hdfs.cat(new_path) == data
+
+ self.hdfs.rename(new_path, path)
+ assert self.hdfs.cat(path) == data
+
+ def test_info(self):
+ path = pjoin(self.tmp_path, 'info-base')
+ file_path = pjoin(path, 'ex')
+ self.hdfs.mkdir(path)
+
+ data = b'foobarbaz'
+ with self.hdfs.open(file_path, 'wb') as f:
+ f.write(data)
+
+ path_info = self.hdfs.info(path)
+ file_path_info = self.hdfs.info(file_path)
+
+ assert path_info['kind'] == 'directory'
+
+ assert file_path_info['kind'] == 'file'
+ assert file_path_info['size'] == len(data)
+
+ def test_exists_isdir_isfile(self):
+ dir_path = pjoin(self.tmp_path, 'info-base')
+ file_path = pjoin(dir_path, 'ex')
+ missing_path = pjoin(dir_path, 'this-path-is-missing')
+
+ self.hdfs.mkdir(dir_path)
+ with self.hdfs.open(file_path, 'wb') as f:
+ f.write(b'foobarbaz')
+
+ assert self.hdfs.exists(dir_path)
+ assert self.hdfs.exists(file_path)
+ assert not self.hdfs.exists(missing_path)
+
+ assert self.hdfs.isdir(dir_path)
+ assert not self.hdfs.isdir(file_path)
+ assert not self.hdfs.isdir(missing_path)
+
+ assert not self.hdfs.isfile(dir_path)
+ assert self.hdfs.isfile(file_path)
+ assert not self.hdfs.isfile(missing_path)
+
+ def test_disk_usage(self):
+ path = pjoin(self.tmp_path, 'disk-usage-base')
+ p1 = pjoin(path, 'p1')
+ p2 = pjoin(path, 'p2')
+
+ subdir = pjoin(path, 'subdir')
+ p3 = pjoin(subdir, 'p3')
+
+ if self.hdfs.exists(path):
+ self.hdfs.delete(path, True)
+
+ self.hdfs.mkdir(path)
+ self.hdfs.mkdir(subdir)
+
+ data = b'foobarbaz'
+
+ for file_path in [p1, p2, p3]:
+ with self.hdfs.open(file_path, 'wb') as f:
+ f.write(data)
+
+ assert self.hdfs.disk_usage(path) == len(data) * 3
+
+ def test_ls(self):
+ base_path = pjoin(self.tmp_path, 'ls-test')
+ self.hdfs.mkdir(base_path)
+
+ dir_path = pjoin(base_path, 'a-dir')
+ f1_path = pjoin(base_path, 'a-file-1')
+
+ self.hdfs.mkdir(dir_path)
+
+ f = self.hdfs.open(f1_path, 'wb')
+ f.write(b'a' * 10)
+
+ contents = sorted(self.hdfs.ls(base_path, False))
+ assert contents == [dir_path, f1_path]
+
+ def test_chmod_chown(self):
+ path = pjoin(self.tmp_path, 'chmod-test')
+ with self.hdfs.open(path, 'wb') as f:
+ f.write(b'a' * 10)
+
+ def test_download_upload(self):
+ base_path = pjoin(self.tmp_path, 'upload-test')
+
+ data = b'foobarbaz'
+ buf = BytesIO(data)
+ buf.seek(0)
+
+ self.hdfs.upload(base_path, buf)
+
+ out_buf = BytesIO()
+ self.hdfs.download(base_path, out_buf)
+ out_buf.seek(0)
+ assert out_buf.getvalue() == data
+
+ def test_file_context_manager(self):
+ path = pjoin(self.tmp_path, 'ctx-manager')
+
+ data = b'foo'
+ with self.hdfs.open(path, 'wb') as f:
+ f.write(data)
+
+ with self.hdfs.open(path, 'rb') as f:
+ assert f.size() == 3
+ result = f.read(10)
+ assert result == data
+
+ def test_open_not_exist(self):
+ path = pjoin(self.tmp_path, 'does-not-exist-123')
+
+ with pytest.raises(FileNotFoundError):
+ self.hdfs.open(path)
+
+ def test_open_write_error(self):
+ with pytest.raises((FileExistsError, IsADirectoryError)):
+ self.hdfs.open('/', 'wb')
+
+ def test_read_whole_file(self):
+ path = pjoin(self.tmp_path, 'read-whole-file')
+
+ data = b'foo' * 1000
+ with self.hdfs.open(path, 'wb') as f:
+ f.write(data)
+
+ with self.hdfs.open(path, 'rb') as f:
+ result = f.read()
+
+ assert result == data
+
+ def _write_multiple_hdfs_pq_files(self, tmpdir):
+ import pyarrow.parquet as pq
+ nfiles = 10
+ size = 5
+ test_data = []
+ for i in range(nfiles):
+ df = _test_dataframe(size, seed=i)
+
+ df['index'] = np.arange(i * size, (i + 1) * size)
+
+ # Hack so that we don't have a dtype cast in v1 files
+ df['uint32'] = df['uint32'].astype(np.int64)
+
+ path = pjoin(tmpdir, '{}.parquet'.format(i))
+
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ with self.hdfs.open(path, 'wb') as f:
+ pq.write_table(table, f)
+
+ test_data.append(table)
+
+ expected = pa.concat_tables(test_data)
+ return expected
+
+ @pytest.mark.pandas
+ @pytest.mark.parquet
+ def test_read_multiple_parquet_files(self):
+
+ tmpdir = pjoin(self.tmp_path, 'multi-parquet-' + guid())
+
+ self.hdfs.mkdir(tmpdir)
+
+ expected = self._write_multiple_hdfs_pq_files(tmpdir)
+ result = self.hdfs.read_parquet(tmpdir)
+
+ _pandas_api.assert_frame_equal(result.to_pandas()
+ .sort_values(by='index')
+ .reset_index(drop=True),
+ expected.to_pandas())
+
+ @pytest.mark.pandas
+ @pytest.mark.parquet
+ def test_read_multiple_parquet_files_with_uri(self):
+ import pyarrow.parquet as pq
+
+ tmpdir = pjoin(self.tmp_path, 'multi-parquet-uri-' + guid())
+
+ self.hdfs.mkdir(tmpdir)
+
+ expected = self._write_multiple_hdfs_pq_files(tmpdir)
+ path = _get_hdfs_uri(tmpdir)
+ result = pq.read_table(path)
+
+ _pandas_api.assert_frame_equal(result.to_pandas()
+ .sort_values(by='index')
+ .reset_index(drop=True),
+ expected.to_pandas())
+
+ @pytest.mark.pandas
+ @pytest.mark.parquet
+ def test_read_write_parquet_files_with_uri(self):
+ import pyarrow.parquet as pq
+
+ tmpdir = pjoin(self.tmp_path, 'uri-parquet-' + guid())
+ self.hdfs.mkdir(tmpdir)
+ path = _get_hdfs_uri(pjoin(tmpdir, 'test.parquet'))
+
+ size = 5
+ df = _test_dataframe(size, seed=0)
+ # Hack so that we don't have a dtype cast in v1 files
+ df['uint32'] = df['uint32'].astype(np.int64)
+ table = pa.Table.from_pandas(df, preserve_index=False)
+
+ pq.write_table(table, path, filesystem=self.hdfs)
+
+ result = pq.read_table(
+ path, filesystem=self.hdfs, use_legacy_dataset=True
+ ).to_pandas()
+
+ _pandas_api.assert_frame_equal(result, df)
+
+ @pytest.mark.parquet
+ @pytest.mark.pandas
+ def test_read_common_metadata_files(self):
+ tmpdir = pjoin(self.tmp_path, 'common-metadata-' + guid())
+ self.hdfs.mkdir(tmpdir)
+ _test_read_common_metadata_files(self.hdfs, tmpdir)
+
+ @pytest.mark.parquet
+ @pytest.mark.pandas
+ def test_write_to_dataset_with_partitions(self):
+ tmpdir = pjoin(self.tmp_path, 'write-partitions-' + guid())
+ self.hdfs.mkdir(tmpdir)
+ _test_write_to_dataset_with_partitions(
+ tmpdir, filesystem=self.hdfs)
+
+ @pytest.mark.parquet
+ @pytest.mark.pandas
+ def test_write_to_dataset_no_partitions(self):
+ tmpdir = pjoin(self.tmp_path, 'write-no_partitions-' + guid())
+ self.hdfs.mkdir(tmpdir)
+ _test_write_to_dataset_no_partitions(
+ tmpdir, filesystem=self.hdfs)
+
+
+class TestLibHdfs(HdfsTestCases, unittest.TestCase):
+
+ @classmethod
+ def check_driver(cls):
+ check_libhdfs_present()
+
+ def test_orphaned_file(self):
+ hdfs = hdfs_test_client()
+ file_path = self._make_test_file(hdfs, 'orphaned_file_test', 'fname',
+ b'foobarbaz')
+
+ f = hdfs.open(file_path)
+ hdfs = None
+ f = None # noqa
+
+
+def _get_hdfs_uri(path):
+ host = os.environ.get('ARROW_HDFS_TEST_HOST', 'localhost')
+ try:
+ port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0))
+ except ValueError:
+ raise ValueError('Env variable ARROW_HDFS_TEST_PORT was not '
+ 'an integer')
+ uri = "hdfs://{}:{}{}".format(host, port, path)
+
+ return uri
+
+
+@pytest.mark.hdfs
+@pytest.mark.pandas
+@pytest.mark.parquet
+@pytest.mark.fastparquet
+def test_fastparquet_read_with_hdfs():
+ from pandas.testing import assert_frame_equal
+
+ check_libhdfs_present()
+ try:
+ import snappy # noqa
+ except ImportError:
+ pytest.skip('fastparquet test requires snappy')
+
+ import pyarrow.parquet as pq
+ fastparquet = pytest.importorskip('fastparquet')
+
+ fs = hdfs_test_client()
+
+ df = util.make_dataframe()
+
+ table = pa.Table.from_pandas(df)
+
+ path = '/tmp/testing.parquet'
+ with fs.open(path, 'wb') as f:
+ pq.write_table(table, f)
+
+ parquet_file = fastparquet.ParquetFile(path, open_with=fs.open)
+
+ result = parquet_file.to_pandas()
+ assert_frame_equal(result, df)
diff --git a/src/arrow/python/pyarrow/tests/test_io.py b/src/arrow/python/pyarrow/tests/test_io.py
new file mode 100644
index 000000000..ea1e5e557
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_io.py
@@ -0,0 +1,1886 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import bz2
+from contextlib import contextmanager
+from io import (BytesIO, StringIO, TextIOWrapper, BufferedIOBase, IOBase)
+import itertools
+import gc
+import gzip
+import os
+import pathlib
+import pickle
+import pytest
+import sys
+import tempfile
+import weakref
+
+import numpy as np
+
+from pyarrow.util import guid
+from pyarrow import Codec
+import pyarrow as pa
+
+
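+# Creates a sparse file slightly larger than 4 GiB so that seek offsets do not
+# fit in 32 bits; this verifies the readers use 64-bit file positions.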
+def check_large_seeks(file_factory):
+ if sys.platform in ('win32', 'darwin'):
+ pytest.skip("need sparse file support")
+ try:
+ filename = tempfile.mktemp(prefix='test_io')
+ with open(filename, 'wb') as f:
+ f.truncate(2 ** 32 + 10)
+ f.seek(2 ** 32 + 5)
+ f.write(b'mark\n')
+ with file_factory(filename) as f:
+ assert f.seek(2 ** 32 + 5) == 2 ** 32 + 5
+ assert f.tell() == 2 ** 32 + 5
+ assert f.read(5) == b'mark\n'
+ assert f.tell() == 2 ** 32 + 10
+ finally:
+ os.unlink(filename)
+
+
+@contextmanager
+def assert_file_not_found():
+ with pytest.raises(FileNotFoundError):
+ yield
+
+
+# ----------------------------------------------------------------------
+# Python file-like objects
+
+
+def test_python_file_write():
+ buf = BytesIO()
+
+ f = pa.PythonFile(buf)
+
+ assert f.tell() == 0
+
+ s1 = b'enga\xc3\xb1ado'
+ s2 = b'foobar'
+
+ f.write(s1)
+ assert f.tell() == len(s1)
+
+ f.write(s2)
+
+ expected = s1 + s2
+
+ result = buf.getvalue()
+ assert result == expected
+
+ assert not f.closed
+ f.close()
+ assert f.closed
+
+ with pytest.raises(TypeError, match="binary file expected"):
+ pa.PythonFile(StringIO())
+
+
+def test_python_file_read():
+ data = b'some sample data'
+
+ buf = BytesIO(data)
+ f = pa.PythonFile(buf, mode='r')
+
+ assert f.size() == len(data)
+
+ assert f.tell() == 0
+
+ assert f.read(4) == b'some'
+ assert f.tell() == 4
+
+ f.seek(0)
+ assert f.tell() == 0
+
+ f.seek(5)
+ assert f.tell() == 5
+
+ v = f.read(50)
+ assert v == b'sample data'
+ assert len(v) == 11
+
+ assert f.size() == len(data)
+
+ assert not f.closed
+ f.close()
+ assert f.closed
+
+ with pytest.raises(TypeError, match="binary file expected"):
+ pa.PythonFile(StringIO(), mode='r')
+
+
+def test_python_file_read_at():
+ data = b'some sample data'
+
+ buf = BytesIO(data)
+ f = pa.PythonFile(buf, mode='r')
+
+ # test simple read at
+ v = f.read_at(nbytes=5, offset=3)
+ assert v == b'e sam'
+ assert len(v) == 5
+
+ # test reading entire file when nbytes > len(file)
+ w = f.read_at(nbytes=50, offset=0)
+ assert w == data
+ assert len(w) == 16
+
+
+def test_python_file_readall():
+ data = b'some sample data'
+
+ buf = BytesIO(data)
+ with pa.PythonFile(buf, mode='r') as f:
+ assert f.readall() == data
+
+
+def test_python_file_readinto():
+ length = 10
+ data = b'some sample data longer than 10'
+ dst_buf = bytearray(length)
+ src_buf = BytesIO(data)
+
+ with pa.PythonFile(src_buf, mode='r') as f:
+ assert f.readinto(dst_buf) == 10
+
+ assert dst_buf[:length] == data[:length]
+ assert len(dst_buf) == length
+
+
+def test_python_file_read_buffer():
+ length = 10
+ data = b'0123456798'
+ dst_buf = bytearray(data)
+
+ class DuckReader:
+ def close(self):
+ pass
+
+ @property
+ def closed(self):
+ return False
+
+ def read_buffer(self, nbytes):
+ assert nbytes == length
+ return memoryview(dst_buf)[:nbytes]
+
+ duck_reader = DuckReader()
+ with pa.PythonFile(duck_reader, mode='r') as f:
+ buf = f.read_buffer(length)
+ assert len(buf) == length
+ assert memoryview(buf).tobytes() == dst_buf[:length]
+ # buf should point to the same memory, so modifying it
+ memoryview(buf)[0] = ord(b'x')
+ # should modify the original
+ assert dst_buf[0] == ord(b'x')
+
+
+def test_python_file_correct_abc():
+ with pa.PythonFile(BytesIO(b''), mode='r') as f:
+ assert isinstance(f, BufferedIOBase)
+ assert isinstance(f, IOBase)
+
+
+def test_python_file_iterable():
+ data = b'''line1
+ line2
+ line3
+ '''
+
+ buf = BytesIO(data)
+ buf2 = BytesIO(data)
+
+ with pa.PythonFile(buf, mode='r') as f:
+ for read, expected in zip(f, buf2):
+ assert read == expected
+
+
+def test_python_file_large_seeks():
+ def factory(filename):
+ return pa.PythonFile(open(filename, 'rb'))
+
+ check_large_seeks(factory)
+
+
+def test_bytes_reader():
+ # Like a BytesIO, but zero-copy underneath for C++ consumers
+ data = b'some sample data'
+ f = pa.BufferReader(data)
+ assert f.tell() == 0
+
+ assert f.size() == len(data)
+
+ assert f.read(4) == b'some'
+ assert f.tell() == 4
+
+ f.seek(0)
+ assert f.tell() == 0
+
+ f.seek(0, 2)
+ assert f.tell() == len(data)
+
+ f.seek(5)
+ assert f.tell() == 5
+
+ assert f.read(50) == b'sample data'
+
+ assert not f.closed
+ f.close()
+ assert f.closed
+
+
+def test_bytes_reader_non_bytes():
+ with pytest.raises(TypeError):
+ pa.BufferReader('some sample data')
+
+
+def test_bytes_reader_retains_parent_reference():
+ import gc
+
+ # ARROW-421
+ def get_buffer():
+ data = b'some sample data' * 1000
+ reader = pa.BufferReader(data)
+ reader.seek(5)
+ return reader.read_buffer(6)
+
+ buf = get_buffer()
+ gc.collect()
+ assert buf.to_pybytes() == b'sample'
+ assert buf.parent is not None
+
+
+def test_python_file_implicit_mode(tmpdir):
+ path = os.path.join(str(tmpdir), 'foo.txt')
+ with open(path, 'wb') as f:
+ pf = pa.PythonFile(f)
+ assert pf.writable()
+ assert not pf.readable()
+ assert not pf.seekable() # PyOutputStream isn't seekable
+ f.write(b'foobar\n')
+
+ with open(path, 'rb') as f:
+ pf = pa.PythonFile(f)
+ assert pf.readable()
+ assert not pf.writable()
+ assert pf.seekable()
+ assert pf.read() == b'foobar\n'
+
+ bio = BytesIO()
+ pf = pa.PythonFile(bio)
+ assert pf.writable()
+ assert not pf.readable()
+ assert not pf.seekable()
+ pf.write(b'foobar\n')
+ assert bio.getvalue() == b'foobar\n'
+
+
+def test_python_file_writelines(tmpdir):
+ lines = [b'line1\n', b'line2\n', b'line3']
+ path = os.path.join(str(tmpdir), 'foo.txt')
+ with open(path, 'wb') as f:
+ try:
+ f = pa.PythonFile(f, mode='w')
+ assert f.writable()
+ f.writelines(lines)
+ finally:
+ f.close()
+
+ with open(path, 'rb') as f:
+ try:
+ f = pa.PythonFile(f, mode='r')
+ assert f.readable()
+ assert f.read() == b''.join(lines)
+ finally:
+ f.close()
+
+
+def test_python_file_closing():
+ bio = BytesIO()
+ pf = pa.PythonFile(bio)
+ wr = weakref.ref(pf)
+ del pf
+ assert wr() is None # object was destroyed
+ assert not bio.closed
+ pf = pa.PythonFile(bio)
+ pf.close()
+ assert bio.closed
+
+
+# ----------------------------------------------------------------------
+# Buffers
+
+
+def test_buffer_bytes():
+ val = b'some data'
+
+ buf = pa.py_buffer(val)
+ assert isinstance(buf, pa.Buffer)
+ assert not buf.is_mutable
+ assert buf.is_cpu
+
+ result = buf.to_pybytes()
+
+ assert result == val
+
+ # Check that buffers survive a pickle roundtrip
+ result_buf = pickle.loads(pickle.dumps(buf))
+ result = result_buf.to_pybytes()
+ assert result == val
+
+
+def test_buffer_memoryview():
+ val = b'some data'
+
+ buf = pa.py_buffer(val)
+ assert isinstance(buf, pa.Buffer)
+ assert not buf.is_mutable
+ assert buf.is_cpu
+
+ result = memoryview(buf)
+
+ assert result == val
+
+
+def test_buffer_bytearray():
+ val = bytearray(b'some data')
+
+ buf = pa.py_buffer(val)
+ assert isinstance(buf, pa.Buffer)
+ assert buf.is_mutable
+ assert buf.is_cpu
+
+ result = bytearray(buf)
+
+ assert result == val
+
+
+def test_buffer_invalid():
+ with pytest.raises(TypeError,
+ match="(bytes-like object|buffer interface)"):
+ pa.py_buffer(None)
+
+
+def test_buffer_weakref():
+ buf = pa.py_buffer(b'some data')
+ wr = weakref.ref(buf)
+ assert wr() is not None
+ del buf
+ assert wr() is None
+
+
+@pytest.mark.parametrize('val, expected_hex_buffer',
+ [(b'check', b'636865636B'),
+ (b'\a0', b'0730'),
+ (b'', b'')])
+def test_buffer_hex(val, expected_hex_buffer):
+ buf = pa.py_buffer(val)
+ assert buf.hex() == expected_hex_buffer
+
+
+def test_buffer_to_numpy():
+ # Make sure creating a numpy array from an arrow buffer works
+ byte_array = bytearray(20)
+ byte_array[0] = 42
+ buf = pa.py_buffer(byte_array)
+ array = np.frombuffer(buf, dtype="uint8")
+ assert array[0] == byte_array[0]
+ byte_array[0] += 1
+ assert array[0] == byte_array[0]
+ assert array.base == buf
+
+
+def test_buffer_from_numpy():
+ # C-contiguous
+ arr = np.arange(12, dtype=np.int8).reshape((3, 4))
+ buf = pa.py_buffer(arr)
+ assert buf.is_cpu
+ assert buf.is_mutable
+ assert buf.to_pybytes() == arr.tobytes()
+ # F-contiguous; note strides information is lost
+ buf = pa.py_buffer(arr.T)
+ assert buf.is_cpu
+ assert buf.is_mutable
+ assert buf.to_pybytes() == arr.tobytes()
+ # Non-contiguous
+ with pytest.raises(ValueError, match="not contiguous"):
+ buf = pa.py_buffer(arr.T[::2])
+
+
+def test_buffer_address():
+ b1 = b'some data!'
+ b2 = bytearray(b1)
+ b3 = bytearray(b1)
+
+ buf1 = pa.py_buffer(b1)
+ buf2 = pa.py_buffer(b1)
+ buf3 = pa.py_buffer(b2)
+ buf4 = pa.py_buffer(b3)
+
+ assert buf1.address > 0
+ assert buf1.address == buf2.address
+ assert buf3.address != buf2.address
+ assert buf4.address != buf3.address
+
+ arr = np.arange(5)
+ buf = pa.py_buffer(arr)
+ assert buf.address == arr.ctypes.data
+
+
+def test_buffer_equals():
+ # Buffer.equals() returns true iff the buffers have the same contents
+ def eq(a, b):
+ assert a.equals(b)
+ assert a == b
+ assert not (a != b)
+
+ def ne(a, b):
+ assert not a.equals(b)
+ assert not (a == b)
+ assert a != b
+
+ b1 = b'some data!'
+ b2 = bytearray(b1)
+ b3 = bytearray(b1)
+ b3[0] = 42
+ buf1 = pa.py_buffer(b1)
+ buf2 = pa.py_buffer(b2)
+ buf3 = pa.py_buffer(b2)
+ buf4 = pa.py_buffer(b3)
+ buf5 = pa.py_buffer(np.frombuffer(b2, dtype=np.int16))
+ eq(buf1, buf1)
+ eq(buf1, buf2)
+ eq(buf2, buf3)
+ ne(buf2, buf4)
+ # The element type does not matter; only the raw bytes are compared
+ eq(buf2, buf5)
+
+
+def test_buffer_eq_bytes():
+ buf = pa.py_buffer(b'some data')
+ assert buf == b'some data'
+ assert buf == bytearray(b'some data')
+ assert buf != b'some dat1'
+
+ with pytest.raises(TypeError):
+ buf == 'some data'
+
+
+def test_buffer_getitem():
+ data = bytearray(b'some data!')
+ buf = pa.py_buffer(data)
+
+ n = len(data)
+ for ix in range(-n, n - 1):
+ assert buf[ix] == data[ix]
+
+ with pytest.raises(IndexError):
+ buf[n]
+
+ with pytest.raises(IndexError):
+ buf[-n - 1]
+
+
+def test_buffer_slicing():
+ data = b'some data!'
+ buf = pa.py_buffer(data)
+
+ sliced = buf.slice(2)
+ expected = pa.py_buffer(b'me data!')
+ assert sliced.equals(expected)
+
+ sliced2 = buf.slice(2, 4)
+ expected2 = pa.py_buffer(b'me d')
+ assert sliced2.equals(expected2)
+
+ # 0 offset
+ assert buf.slice(0).equals(buf)
+
+ # Slice past end of buffer
+ assert len(buf.slice(len(buf))) == 0
+
+ with pytest.raises(IndexError):
+ buf.slice(-1)
+
+ # Test slice notation
+ assert buf[2:].equals(buf.slice(2))
+ assert buf[2:5].equals(buf.slice(2, 3))
+ assert buf[-5:].equals(buf.slice(len(buf) - 5))
+ with pytest.raises(IndexError):
+ buf[::-1]
+ with pytest.raises(IndexError):
+ buf[::2]
+
+ n = len(buf)
+ for start in range(-n * 2, n * 2):
+ for stop in range(-n * 2, n * 2):
+ assert buf[start:stop].to_pybytes() == buf.to_pybytes()[start:stop]
+
+
+def test_buffer_hashing():
+ # Buffers are unhashable
+ with pytest.raises(TypeError, match="unhashable"):
+ hash(pa.py_buffer(b'123'))
+
+
+def test_buffer_protocol_respects_immutability():
+ # ARROW-3228; NumPy's frombuffer ctor determines whether a buffer-like
+ # object is mutable by first attempting to get a writable buffer via the
+ # buffer protocol (PyObject_GetBuffer). If that fails, it assumes that the
+ # object is immutable.
+ a = b'12345'
+ arrow_ref = pa.py_buffer(a)
+ numpy_ref = np.frombuffer(arrow_ref, dtype=np.uint8)
+ assert not numpy_ref.flags.writeable
+
+
+def test_foreign_buffer():
+ obj = np.array([1, 2], dtype=np.int32)
+ addr = obj.__array_interface__["data"][0]
+ size = obj.nbytes
+ buf = pa.foreign_buffer(addr, size, obj)
+ wr = weakref.ref(obj)
+ del obj
+ assert np.frombuffer(buf, dtype=np.int32).tolist() == [1, 2]
+ assert wr() is not None
+ del buf
+ assert wr() is None
+
+
+def test_allocate_buffer():
+ buf = pa.allocate_buffer(100)
+ assert buf.size == 100
+ assert buf.is_mutable
+ assert buf.parent is None
+
+ bit = b'abcde'
+ writer = pa.FixedSizeBufferWriter(buf)
+ writer.write(bit)
+
+ assert buf.to_pybytes()[:5] == bit
+
+
+def test_allocate_buffer_resizable():
+ buf = pa.allocate_buffer(100, resizable=True)
+ assert isinstance(buf, pa.ResizableBuffer)
+
+ buf.resize(200)
+ assert buf.size == 200
+
+
+@pytest.mark.parametrize("compression", [
+ pytest.param(
+ "bz2", marks=pytest.mark.xfail(raises=pa.lib.ArrowNotImplementedError)
+ ),
+ "brotli",
+ "gzip",
+ "lz4",
+ "zstd",
+ "snappy"
+])
+def test_compress_decompress(compression):
+ if not Codec.is_available(compression):
+ pytest.skip("{} support is not built".format(compression))
+
+ INPUT_SIZE = 10000
+ test_data = (np.random.randint(0, 255, size=INPUT_SIZE)
+ .astype(np.uint8)
+ .tobytes())
+ test_buf = pa.py_buffer(test_data)
+
+ compressed_buf = pa.compress(test_buf, codec=compression)
+ compressed_bytes = pa.compress(test_data, codec=compression,
+ asbytes=True)
+
+ assert isinstance(compressed_bytes, bytes)
+
+ decompressed_buf = pa.decompress(compressed_buf, INPUT_SIZE,
+ codec=compression)
+ decompressed_bytes = pa.decompress(compressed_bytes, INPUT_SIZE,
+ codec=compression, asbytes=True)
+
+ assert isinstance(decompressed_bytes, bytes)
+
+ assert decompressed_buf.equals(test_buf)
+ assert decompressed_bytes == test_data
+
+ with pytest.raises(ValueError):
+ pa.decompress(compressed_bytes, codec=compression)
+
+
+@pytest.mark.parametrize("compression", [
+ pytest.param(
+ "bz2", marks=pytest.mark.xfail(raises=pa.lib.ArrowNotImplementedError)
+ ),
+ "brotli",
+ "gzip",
+ "lz4",
+ "zstd",
+ "snappy"
+])
+def test_compression_level(compression):
+ if not Codec.is_available(compression):
+ pytest.skip("{} support is not built".format(compression))
+
+ # These codecs do not support a compression level
+ no_level = ['snappy', 'lz4']
+ if compression in no_level:
+ assert not Codec.supports_compression_level(compression)
+ with pytest.raises(ValueError):
+ Codec(compression, 0)
+ with pytest.raises(ValueError):
+ Codec.minimum_compression_level(compression)
+ with pytest.raises(ValueError):
+ Codec.maximum_compression_level(compression)
+ with pytest.raises(ValueError):
+ Codec.default_compression_level(compression)
+ return
+
+ INPUT_SIZE = 10000
+ test_data = (np.random.randint(0, 255, size=INPUT_SIZE)
+ .astype(np.uint8)
+ .tobytes())
+ test_buf = pa.py_buffer(test_data)
+
+ min_level = Codec.minimum_compression_level(compression)
+ max_level = Codec.maximum_compression_level(compression)
+ default_level = Codec.default_compression_level(compression)
+
+ assert min_level < max_level
+ assert default_level >= min_level
+ assert default_level <= max_level
+
+ for compression_level in range(min_level, max_level+1):
+ codec = Codec(compression, compression_level)
+ compressed_buf = codec.compress(test_buf)
+ compressed_bytes = codec.compress(test_data, asbytes=True)
+ assert isinstance(compressed_bytes, bytes)
+ decompressed_buf = codec.decompress(compressed_buf, INPUT_SIZE)
+ decompressed_bytes = codec.decompress(compressed_bytes, INPUT_SIZE,
+ asbytes=True)
+
+ assert isinstance(decompressed_bytes, bytes)
+
+ assert decompressed_buf.equals(test_buf)
+ assert decompressed_bytes == test_data
+
+ with pytest.raises(ValueError):
+ codec.decompress(compressed_bytes)
+
+ # The ability to set a seed this way is not present on older versions of
+ # numpy (currently in our python 3.6 CI build). Some inputs might just
+ # happen to compress the same between the two levels so using seeded
+ # random numbers is necessary to help get more reliable results
+ #
+ # The goal of this part is to ensure the compression_level is being
+ # passed down to the C++ layer, not to verify the compression algs
+ # themselves
+ if not hasattr(np.random, 'default_rng'):
+ pytest.skip('Requires newer version of numpy')
+ rng = np.random.default_rng(seed=42)
+ values = rng.integers(0, 100, 1000)
+ arr = pa.array(values)
+ hard_to_compress_buffer = arr.buffers()[1]
+
+ weak_codec = Codec(compression, min_level)
+ weakly_compressed_buf = weak_codec.compress(hard_to_compress_buffer)
+
+ strong_codec = Codec(compression, max_level)
+ strongly_compressed_buf = strong_codec.compress(hard_to_compress_buffer)
+
+ assert len(weakly_compressed_buf) > len(strongly_compressed_buf)
+
+
+def test_buffer_memoryview_is_immutable():
+ val = b'some data'
+
+ buf = pa.py_buffer(val)
+ assert not buf.is_mutable
+ assert isinstance(buf, pa.Buffer)
+
+ result = memoryview(buf)
+ assert result.readonly
+
+ with pytest.raises(TypeError) as exc:
+ result[0] = b'h'
+ assert 'cannot modify read-only' in str(exc.value)
+
+ b = bytes(buf)
+ with pytest.raises(TypeError) as exc:
+ b[0] = b'h'
+ assert 'cannot modify read-only' in str(exc.value)
+
+
+def test_uninitialized_buffer():
+ # ARROW-2039: calling Buffer() directly creates an uninitialized object
+ # ARROW-2638: prevent calling extension class constructors directly
+ with pytest.raises(TypeError):
+ pa.Buffer()
+
+
+def test_memory_output_stream():
+ # 10 bytes
+ val = b'dataabcdef'
+ f = pa.BufferOutputStream()
+
+ K = 1000
+ for i in range(K):
+ f.write(val)
+
+ buf = f.getvalue()
+ assert len(buf) == len(val) * K
+ assert buf.to_pybytes() == val * K
+
+
+def test_inmemory_write_after_closed():
+ f = pa.BufferOutputStream()
+ f.write(b'ok')
+ assert not f.closed
+ f.getvalue()
+ assert f.closed
+
+ with pytest.raises(ValueError):
+ f.write(b'not ok')
+
+
+def test_buffer_protocol_ref_counting():
+ def make_buffer(bytes_obj):
+ return bytearray(pa.py_buffer(bytes_obj))
+
+ buf = make_buffer(b'foo')
+ gc.collect()
+ assert buf == b'foo'
+
+ # ARROW-1053
+ val = b'foo'
+ refcount_before = sys.getrefcount(val)
+ for i in range(10):
+ make_buffer(val)
+ gc.collect()
+ assert refcount_before == sys.getrefcount(val)
+
+
+def test_nativefile_write_memoryview():
+ f = pa.BufferOutputStream()
+ data = b'ok'
+
+ arr = np.frombuffer(data, dtype='S1')
+
+ f.write(arr)
+ f.write(bytearray(data))
+ f.write(pa.py_buffer(data))
+ with pytest.raises(TypeError):
+ f.write(data.decode('utf8'))
+
+ buf = f.getvalue()
+
+ assert buf.to_pybytes() == data * 3
+
+
+# ----------------------------------------------------------------------
+# Mock output stream
+
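+# MockOutputStream only counts the bytes written to it without storing them,
+# which makes it useful for computing the size of a serialized stream ahead
+# of time.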
+
+def test_mock_output_stream():
+ # Make sure that the MockOutputStream and the BufferOutputStream record the
+ # same size
+
+ # 10 bytes
+ val = b'dataabcdef'
+
+ f1 = pa.MockOutputStream()
+ f2 = pa.BufferOutputStream()
+
+ K = 1000
+ for i in range(K):
+ f1.write(val)
+ f2.write(val)
+
+ assert f1.size() == len(f2.getvalue())
+
+ # Do the same test with a table
+ record_batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ['a'])
+
+ f1 = pa.MockOutputStream()
+ f2 = pa.BufferOutputStream()
+
+ stream_writer1 = pa.RecordBatchStreamWriter(f1, record_batch.schema)
+ stream_writer2 = pa.RecordBatchStreamWriter(f2, record_batch.schema)
+
+ stream_writer1.write_batch(record_batch)
+ stream_writer2.write_batch(record_batch)
+ stream_writer1.close()
+ stream_writer2.close()
+
+ assert f1.size() == len(f2.getvalue())
+
+
+# ----------------------------------------------------------------------
+# OS files and memory maps
+
+
+@pytest.fixture
+def sample_disk_data(request, tmpdir):
+ SIZE = 4096
+ arr = np.random.randint(0, 256, size=SIZE).astype('u1')
+ data = arr.tobytes()[:SIZE]
+
+ path = os.path.join(str(tmpdir), guid())
+
+ with open(path, 'wb') as f:
+ f.write(data)
+
+ def teardown():
+ _try_delete(path)
+
+ request.addfinalizer(teardown)
+ return path, data
+
+
+def _check_native_file_reader(FACTORY, sample_data,
+ allow_read_out_of_bounds=True):
+ path, data = sample_data
+
+ f = FACTORY(path, mode='r')
+
+ assert f.read(10) == data[:10]
+ assert f.read(0) == b''
+ assert f.tell() == 10
+
+ assert f.read() == data[10:]
+
+ assert f.size() == len(data)
+
+ f.seek(0)
+ assert f.tell() == 0
+
+ # Seeking past end of file not supported in memory maps
+ if allow_read_out_of_bounds:
+ f.seek(len(data) + 1)
+ assert f.tell() == len(data) + 1
+ assert f.read(5) == b''
+
+ # Test whence argument of seek, ARROW-1287
+ assert f.seek(3) == 3
+ assert f.seek(3, os.SEEK_CUR) == 6
+ assert f.tell() == 6
+
+ ex_length = len(data) - 2
+ assert f.seek(-2, os.SEEK_END) == ex_length
+ assert f.tell() == ex_length
+
+
+def test_memory_map_reader(sample_disk_data):
+ _check_native_file_reader(pa.memory_map, sample_disk_data,
+ allow_read_out_of_bounds=False)
+
+
+def test_memory_map_retain_buffer_reference(sample_disk_data):
+ path, data = sample_disk_data
+
+ cases = []
+ with pa.memory_map(path, 'rb') as f:
+ cases.append((f.read_buffer(100), data[:100]))
+ cases.append((f.read_buffer(100), data[100:200]))
+ cases.append((f.read_buffer(100), data[200:300]))
+
+ # Call gc.collect() for good measure
+ gc.collect()
+
+ for buf, expected in cases:
+ assert buf.to_pybytes() == expected
+
+
+def test_os_file_reader(sample_disk_data):
+ _check_native_file_reader(pa.OSFile, sample_disk_data)
+
+
+def test_os_file_large_seeks():
+ check_large_seeks(pa.OSFile)
+
+
+def _try_delete(path):
+ try:
+ os.remove(path)
+ except os.error:
+ pass
+
+
+def test_memory_map_writer(tmpdir):
+ SIZE = 4096
+ arr = np.random.randint(0, 256, size=SIZE).astype('u1')
+ data = arr.tobytes()[:SIZE]
+
+ path = os.path.join(str(tmpdir), guid())
+ with open(path, 'wb') as f:
+ f.write(data)
+
+ f = pa.memory_map(path, mode='r+b')
+
+ f.seek(10)
+ f.write(b'peekaboo')
+ assert f.tell() == 18
+
+ f.seek(10)
+ assert f.read(8) == b'peekaboo'
+
+ f2 = pa.memory_map(path, mode='r+b')
+
+ f2.seek(10)
+ f2.write(b'booapeak')
+ f2.seek(10)
+
+ f.seek(10)
+ assert f.read(8) == b'booapeak'
+
+ # Unlike OSFile, opening a memory map in 'w' mode does not truncate the file
+ f3 = pa.memory_map(path, mode='w')
+ f3.write(b'foo')
+
+ with pa.memory_map(path) as f4:
+ assert f4.size() == SIZE
+
+ with pytest.raises(IOError):
+ f3.read(5)
+
+ f.seek(0)
+ assert f.read(3) == b'foo'
+
+
+def test_memory_map_resize(tmpdir):
+ SIZE = 4096
+ arr = np.random.randint(0, 256, size=SIZE).astype(np.uint8)
+ data1 = arr.tobytes()[:(SIZE // 2)]
+ data2 = arr.tobytes()[(SIZE // 2):]
+
+ path = os.path.join(str(tmpdir), guid())
+
+ mmap = pa.create_memory_map(path, SIZE // 2)
+ mmap.write(data1)
+
+ mmap.resize(SIZE)
+ mmap.write(data2)
+
+ mmap.close()
+
+ with open(path, 'rb') as f:
+ assert f.read() == arr.tobytes()
+
+
+def test_memory_zero_length(tmpdir):
+ path = os.path.join(str(tmpdir), guid())
+ f = open(path, 'wb')
+ f.close()
+ with pa.memory_map(path, mode='r+b') as memory_map:
+ assert memory_map.size() == 0
+
+
+def test_memory_map_large_seeks():
+ check_large_seeks(pa.memory_map)
+
+
+def test_memory_map_close_remove(tmpdir):
+ # ARROW-6740: should be able to delete closed memory-mapped file (Windows)
+ path = os.path.join(str(tmpdir), guid())
+ mmap = pa.create_memory_map(path, 4096)
+ mmap.close()
+ assert mmap.closed
+ os.remove(path) # Shouldn't fail
+
+
+def test_memory_map_deref_remove(tmpdir):
+ path = os.path.join(str(tmpdir), guid())
+ pa.create_memory_map(path, 4096)
+ os.remove(path) # Shouldn't fail
+
+
+def test_os_file_writer(tmpdir):
+ SIZE = 4096
+ arr = np.random.randint(0, 256, size=SIZE).astype('u1')
+ data = arr.tobytes()[:SIZE]
+
+ path = os.path.join(str(tmpdir), guid())
+ with open(path, 'wb') as f:
+ f.write(data)
+
+ # Truncates file
+ f2 = pa.OSFile(path, mode='w')
+ f2.write(b'foo')
+
+ with pa.OSFile(path) as f3:
+ assert f3.size() == 3
+
+ with pytest.raises(IOError):
+ f2.read(5)
+
+
+def test_native_file_write_reject_unicode():
+ # ARROW-3227
+ nf = pa.BufferOutputStream()
+ with pytest.raises(TypeError):
+ nf.write('foo')
+
+
+def test_native_file_modes(tmpdir):
+ path = os.path.join(str(tmpdir), guid())
+ with open(path, 'wb') as f:
+ f.write(b'foooo')
+
+ with pa.OSFile(path, mode='r') as f:
+ assert f.mode == 'rb'
+ assert f.readable()
+ assert not f.writable()
+ assert f.seekable()
+
+ with pa.OSFile(path, mode='rb') as f:
+ assert f.mode == 'rb'
+ assert f.readable()
+ assert not f.writable()
+ assert f.seekable()
+
+ with pa.OSFile(path, mode='w') as f:
+ assert f.mode == 'wb'
+ assert not f.readable()
+ assert f.writable()
+ assert not f.seekable()
+
+ with pa.OSFile(path, mode='wb') as f:
+ assert f.mode == 'wb'
+ assert not f.readable()
+ assert f.writable()
+ assert not f.seekable()
+
+ with open(path, 'wb') as f:
+ f.write(b'foooo')
+
+ with pa.memory_map(path, 'r') as f:
+ assert f.mode == 'rb'
+ assert f.readable()
+ assert not f.writable()
+ assert f.seekable()
+
+ with pa.memory_map(path, 'r+') as f:
+ assert f.mode == 'rb+'
+ assert f.readable()
+ assert f.writable()
+ assert f.seekable()
+
+ with pa.memory_map(path, 'r+b') as f:
+ assert f.mode == 'rb+'
+ assert f.readable()
+ assert f.writable()
+ assert f.seekable()
+
+
+def test_native_file_permissions(tmpdir):
+ # ARROW-10124: permissions of created files should follow umask
+ cur_umask = os.umask(0o002)
+ os.umask(cur_umask)
+
+ path = os.path.join(str(tmpdir), guid())
+ with pa.OSFile(path, mode='w'):
+ pass
+ assert os.stat(path).st_mode & 0o777 == 0o666 & ~cur_umask
+
+ path = os.path.join(str(tmpdir), guid())
+ with pa.memory_map(path, 'w'):
+ pass
+ assert os.stat(path).st_mode & 0o777 == 0o666 & ~cur_umask
+
+
+def test_native_file_raises_ValueError_after_close(tmpdir):
+ path = os.path.join(str(tmpdir), guid())
+ with open(path, 'wb') as f:
+ f.write(b'foooo')
+
+ with pa.OSFile(path, mode='rb') as os_file:
+ assert not os_file.closed
+ assert os_file.closed
+
+ with pa.memory_map(path, mode='rb') as mmap_file:
+ assert not mmap_file.closed
+ assert mmap_file.closed
+
+ files = [os_file,
+ mmap_file]
+
+ methods = [('tell', ()),
+ ('seek', (0,)),
+ ('size', ()),
+ ('flush', ()),
+ ('readable', ()),
+ ('writable', ()),
+ ('seekable', ())]
+
+ for f in files:
+ for method, args in methods:
+ with pytest.raises(ValueError):
+ getattr(f, method)(*args)
+
+
+def test_native_file_TextIOWrapper(tmpdir):
+ data = ('foooo\n'
+ 'barrr\n'
+ 'bazzz\n')
+
+ path = os.path.join(str(tmpdir), guid())
+ with open(path, 'wb') as f:
+ f.write(data.encode('utf-8'))
+
+ with TextIOWrapper(pa.OSFile(path, mode='rb')) as fil:
+ assert fil.readable()
+ res = fil.read()
+ assert res == data
+ assert fil.closed
+
+ with TextIOWrapper(pa.OSFile(path, mode='rb')) as fil:
+ # Iteration works
+ lines = list(fil)
+ assert ''.join(lines) == data
+
+ # Writing
+ path2 = os.path.join(str(tmpdir), guid())
+ with TextIOWrapper(pa.OSFile(path2, mode='wb')) as fil:
+ assert fil.writable()
+ fil.write(data)
+
+ with TextIOWrapper(pa.OSFile(path2, mode='rb')) as fil:
+ res = fil.read()
+ assert res == data
+
+
+def test_native_file_open_error():
+ with assert_file_not_found():
+ pa.OSFile('non_existent_file', 'rb')
+ with assert_file_not_found():
+ pa.memory_map('non_existent_file', 'rb')
+
+
+# ----------------------------------------------------------------------
+# Buffered streams
+
+def test_buffered_input_stream():
+ raw = pa.BufferReader(b"123456789")
+ f = pa.BufferedInputStream(raw, buffer_size=4)
+ assert f.read(2) == b"12"
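+ # only 2 bytes were requested, but the wrapper read a full buffer_size
+ # chunk ahead from the raw stream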
+ assert raw.tell() == 4
+ f.close()
+ assert f.closed
+ assert raw.closed
+
+
+def test_buffered_input_stream_detach_seekable():
+ # detach() to a seekable file (io::RandomAccessFile in C++)
+ f = pa.BufferedInputStream(pa.BufferReader(b"123456789"), buffer_size=4)
+ assert f.read(2) == b"12"
+ raw = f.detach()
+ assert f.closed
+ assert not raw.closed
+ assert raw.seekable()
+ assert raw.read(4) == b"5678"
+ raw.seek(2)
+ assert raw.read(4) == b"3456"
+
+
+def test_buffered_input_stream_detach_non_seekable():
+ # detach() to a non-seekable file (io::InputStream in C++)
+ f = pa.BufferedInputStream(
+ pa.BufferedInputStream(pa.BufferReader(b"123456789"), buffer_size=4),
+ buffer_size=4)
+ assert f.read(2) == b"12"
+ raw = f.detach()
+ assert f.closed
+ assert not raw.closed
+ assert not raw.seekable()
+ assert raw.read(4) == b"5678"
+ with pytest.raises(EnvironmentError):
+ raw.seek(2)
+
+
+def test_buffered_output_stream():
+ np_buf = np.zeros(100, dtype=np.int8) # zero-initialized buffer
+ buf = pa.py_buffer(np_buf)
+
+ raw = pa.FixedSizeBufferWriter(buf)
+ f = pa.BufferedOutputStream(raw, buffer_size=4)
+ f.write(b"12")
+ assert np_buf[:4].tobytes() == b'\0\0\0\0'
+ f.flush()
+ assert np_buf[:4].tobytes() == b'12\0\0'
+ f.write(b"3456789")
+ f.close()
+ assert f.closed
+ assert raw.closed
+ assert np_buf[:10].tobytes() == b'123456789\0'
+
+
+def test_buffered_output_stream_detach():
+ np_buf = np.zeros(100, dtype=np.int8) # zero-initialized buffer
+ buf = pa.py_buffer(np_buf)
+
+ f = pa.BufferedOutputStream(pa.FixedSizeBufferWriter(buf), buffer_size=4)
+ f.write(b"12")
+ assert np_buf[:4].tobytes() == b'\0\0\0\0'
+ raw = f.detach()
+ assert f.closed
+ assert not raw.closed
+ assert np_buf[:4].tobytes() == b'12\0\0'
+
+
+# ----------------------------------------------------------------------
+# Compressed input and output streams
+
+def check_compressed_input(data, fn, compression):
+ raw = pa.OSFile(fn, mode="rb")
+ with pa.CompressedInputStream(raw, compression) as compressed:
+ assert not compressed.closed
+ assert compressed.readable()
+ assert not compressed.writable()
+ assert not compressed.seekable()
+ got = compressed.read()
+ assert got == data
+ assert compressed.closed
+ assert raw.closed
+
+ # Same with read_buffer()
+ raw = pa.OSFile(fn, mode="rb")
+ with pa.CompressedInputStream(raw, compression) as compressed:
+ buf = compressed.read_buffer()
+ assert isinstance(buf, pa.Buffer)
+ assert buf.to_pybytes() == data
+
+
+@pytest.mark.gzip
+def test_compressed_input_gzip(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ fn = str(tmpdir / "compressed_input_test.gz")
+ with gzip.open(fn, "wb") as f:
+ f.write(data)
+ check_compressed_input(data, fn, "gzip")
+
+
+def test_compressed_input_bz2(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ fn = str(tmpdir / "compressed_input_test.bz2")
+ with bz2.BZ2File(fn, "w") as f:
+ f.write(data)
+ try:
+ check_compressed_input(data, fn, "bz2")
+ except NotImplementedError as e:
+ pytest.skip(str(e))
+
+
+@pytest.mark.gzip
+def test_compressed_input_openfile(tmpdir):
+ if not Codec.is_available("gzip"):
+ pytest.skip("gzip support is not built")
+
+ data = b"some test data\n" * 10 + b"eof\n"
+ fn = str(tmpdir / "test_compressed_input_openfile.gz")
+ with gzip.open(fn, "wb") as f:
+ f.write(data)
+
+ with pa.CompressedInputStream(fn, "gzip") as compressed:
+ buf = compressed.read_buffer()
+ assert buf.to_pybytes() == data
+ assert compressed.closed
+
+ with pa.CompressedInputStream(pathlib.Path(fn), "gzip") as compressed:
+ buf = compressed.read_buffer()
+ assert buf.to_pybytes() == data
+ assert compressed.closed
+
+ f = open(fn, "rb")
+ with pa.CompressedInputStream(f, "gzip") as compressed:
+ buf = compressed.read_buffer()
+ assert buf.to_pybytes() == data
+ assert f.closed
+
+
+def check_compressed_concatenated(data, fn, compression):
+ raw = pa.OSFile(fn, mode="rb")
+ with pa.CompressedInputStream(raw, compression) as compressed:
+ got = compressed.read()
+ assert got == data
+
+
+@pytest.mark.gzip
+def test_compressed_concatenated_gzip(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ fn = str(tmpdir / "compressed_input_test2.gz")
+ with gzip.open(fn, "wb") as f:
+ f.write(data[:50])
+ with gzip.open(fn, "ab") as f:
+ f.write(data[50:])
+ check_compressed_concatenated(data, fn, "gzip")
+
+
+@pytest.mark.gzip
+def test_compressed_input_invalid():
+ data = b"foo" * 10
+ raw = pa.BufferReader(data)
+ with pytest.raises(ValueError):
+ pa.CompressedInputStream(raw, "unknown_compression")
+ with pytest.raises(TypeError):
+ pa.CompressedInputStream(raw, None)
+
+ with pa.CompressedInputStream(raw, "gzip") as compressed:
+ with pytest.raises(IOError, match="zlib inflate failed"):
+ compressed.read()
+
+
+def make_compressed_output(data, fn, compression):
+ raw = pa.BufferOutputStream()
+ with pa.CompressedOutputStream(raw, compression) as compressed:
+ assert not compressed.closed
+ assert not compressed.readable()
+ assert compressed.writable()
+ assert not compressed.seekable()
+ compressed.write(data)
+ assert compressed.closed
+ assert raw.closed
+ with open(fn, "wb") as f:
+ f.write(raw.getvalue())
+
+
+@pytest.mark.gzip
+def test_compressed_output_gzip(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ fn = str(tmpdir / "compressed_output_test.gz")
+ make_compressed_output(data, fn, "gzip")
+ with gzip.open(fn, "rb") as f:
+ got = f.read()
+ assert got == data
+
+
+def test_compressed_output_bz2(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ fn = str(tmpdir / "compressed_output_test.bz2")
+ try:
+ make_compressed_output(data, fn, "bz2")
+ except NotImplementedError as e:
+ pytest.skip(str(e))
+ with bz2.BZ2File(fn, "r") as f:
+ got = f.read()
+ assert got == data
+
+
+def test_output_stream_constructor(tmpdir):
+ if not Codec.is_available("gzip"):
+ pytest.skip("gzip support is not built")
+ with pa.CompressedOutputStream(tmpdir / "ctor.gz", "gzip") as stream:
+ stream.write(b"test")
+ with (tmpdir / "ctor2.gz").open("wb") as f:
+ with pa.CompressedOutputStream(f, "gzip") as stream:
+ stream.write(b"test")
+
+
+@pytest.mark.parametrize(("path", "expected_compression"), [
+ ("file.bz2", "bz2"),
+ ("file.lz4", "lz4"),
+ (pathlib.Path("file.gz"), "gzip"),
+ (pathlib.Path("path/to/file.zst"), "zstd"),
+])
+def test_compression_detection(path, expected_compression):
+ if not Codec.is_available(expected_compression):
+ with pytest.raises(pa.lib.ArrowNotImplementedError):
+ Codec.detect(path)
+ else:
+ codec = Codec.detect(path)
+ assert isinstance(codec, Codec)
+ assert codec.name == expected_compression
+
+
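+# A minimal sketch (illustration only; never collected by pytest) of one-shot
+# compression with the Codec API exercised above.  It assumes gzip support is
+# built and passes decompressed_size explicitly, since some codecs do not
+# record the uncompressed length.
+def _codec_one_shot_sketch():
+    if not Codec.is_available("gzip"):
+        return
+    data = b"some test data\n" * 10
+    codec = Codec("gzip")
+    # compress to bytes, then restore and verify the round trip
+    compressed = codec.compress(data, asbytes=True)
+    restored = codec.decompress(compressed, decompressed_size=len(data),
+                                asbytes=True)
+    assert restored == data
+
+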
+def test_unknown_compression_raises():
+ with pytest.raises(ValueError):
+ Codec.is_available('unknown')
+ with pytest.raises(TypeError):
+ Codec(None)
+ with pytest.raises(ValueError):
+ Codec('unknown')
+
+
+@pytest.mark.parametrize("compression", [
+ "bz2",
+ "brotli",
+ "gzip",
+ "lz4",
+ "zstd",
+ pytest.param(
+ "snappy",
+ marks=pytest.mark.xfail(raises=pa.lib.ArrowNotImplementedError)
+ )
+])
+def test_compressed_roundtrip(compression):
+ if not Codec.is_available(compression):
+ pytest.skip("{} support is not built".format(compression))
+
+ data = b"some test data\n" * 10 + b"eof\n"
+ raw = pa.BufferOutputStream()
+ with pa.CompressedOutputStream(raw, compression) as compressed:
+ compressed.write(data)
+
+ cdata = raw.getvalue()
+ assert len(cdata) < len(data)
+ raw = pa.BufferReader(cdata)
+ with pa.CompressedInputStream(raw, compression) as compressed:
+ got = compressed.read()
+ assert got == data
+
+
+@pytest.mark.parametrize(
+ "compression",
+ ["bz2", "brotli", "gzip", "lz4", "zstd"]
+)
+def test_compressed_recordbatch_stream(compression):
+ if not Codec.is_available(compression):
+ pytest.skip("{} support is not built".format(compression))
+
+ # ARROW-4836: roundtrip a RecordBatch through a compressed stream
+ table = pa.Table.from_arrays([pa.array([1, 2, 3, 4, 5])], ['a'])
+ raw = pa.BufferOutputStream()
+ stream = pa.CompressedOutputStream(raw, compression)
+ writer = pa.RecordBatchStreamWriter(stream, table.schema)
+ writer.write_table(table, max_chunksize=3)
+ writer.close()
+ stream.close() # Flush data
+ buf = raw.getvalue()
+ stream = pa.CompressedInputStream(pa.BufferReader(buf), compression)
+ got_table = pa.RecordBatchStreamReader(stream).read_all()
+ assert got_table == table
+
+
+# ----------------------------------------------------------------------
+# Transform input streams
+
+unicode_transcoding_example = (
+ "Dès Noël où un zéphyr haï me vêt de glaçons würmiens "
+ "je dîne d’exquis rôtis de bœuf au kir à l’aÿ d’âge mûr & cætera !"
+)
+
+
+def check_transcoding(data, src_encoding, dest_encoding, chunk_sizes):
+ chunk_sizes = iter(chunk_sizes)
+ stream = pa.transcoding_input_stream(
+ pa.BufferReader(data.encode(src_encoding)),
+ src_encoding, dest_encoding)
+ out = []
+ while True:
+ buf = stream.read(next(chunk_sizes))
+ out.append(buf)
+ if not buf:
+ break
+ out = b''.join(out)
+ assert out.decode(dest_encoding) == data
+
+
+@pytest.mark.parametrize('src_encoding, dest_encoding',
+ [('utf-8', 'utf-16'),
+ ('utf-16', 'utf-8'),
+ ('utf-8', 'utf-32-le'),
+ ('utf-8', 'utf-32-be'),
+ ])
+def test_transcoding_input_stream(src_encoding, dest_encoding):
+ # All at once
+ check_transcoding(unicode_transcoding_example,
+ src_encoding, dest_encoding, [1000, 0])
+ # Incremental
+ check_transcoding(unicode_transcoding_example,
+ src_encoding, dest_encoding,
+ itertools.cycle([1, 2, 3, 5]))
+
+
+@pytest.mark.parametrize('src_encoding, dest_encoding',
+ [('utf-8', 'utf-8'),
+ ('utf-8', 'UTF8')])
+def test_transcoding_no_ops(src_encoding, dest_encoding):
+ # No indirection is wasted when a trivial transcoding is requested
+ stream = pa.BufferReader(b"abc123")
+ assert pa.transcoding_input_stream(
+ stream, src_encoding, dest_encoding) is stream
+
+
+@pytest.mark.parametrize('src_encoding, dest_encoding',
+ [('utf-8', 'ascii'),
+ ('utf-8', 'latin-1'),
+ ])
+def test_transcoding_encoding_error(src_encoding, dest_encoding):
+ # Character \u0100 cannot be represented in the destination encoding
+ stream = pa.transcoding_input_stream(
+ pa.BufferReader("\u0100".encode(src_encoding)),
+ src_encoding,
+ dest_encoding)
+ with pytest.raises(UnicodeEncodeError):
+ stream.read(1)
+
+
+@pytest.mark.parametrize('src_encoding, dest_encoding',
+ [('utf-8', 'utf-16'),
+ ('utf-16', 'utf-8'),
+ ])
+def test_transcoding_decoding_error(src_encoding, dest_encoding):
+ # The given bytestring is not valid in the source encoding
+ stream = pa.transcoding_input_stream(
+ pa.BufferReader(b"\xff\xff\xff\xff"),
+ src_encoding,
+ dest_encoding)
+ with pytest.raises(UnicodeError):
+ stream.read(1)
+
+
+# ----------------------------------------------------------------------
+# High-level API
+
+@pytest.mark.gzip
+def test_input_stream_buffer():
+ data = b"some test data\n" * 10 + b"eof\n"
+ for arg in [pa.py_buffer(data), memoryview(data)]:
+ stream = pa.input_stream(arg)
+ assert stream.read() == data
+
+ gz_data = gzip.compress(data)
+ stream = pa.input_stream(memoryview(gz_data))
+ assert stream.read() == gz_data
+ stream = pa.input_stream(memoryview(gz_data), compression='gzip')
+ assert stream.read() == data
+
+
+def test_input_stream_duck_typing():
+ # Accept objects having the right file-like methods...
+ class DuckReader:
+
+ def close(self):
+ pass
+
+ @property
+ def closed(self):
+ return False
+
+ def read(self, nbytes=None):
+ return b'hello'
+
+ stream = pa.input_stream(DuckReader())
+ assert stream.read(5) == b'hello'
+
+
+def test_input_stream_file_path(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ file_path = tmpdir / 'input_stream'
+ with open(str(file_path), 'wb') as f:
+ f.write(data)
+
+ stream = pa.input_stream(file_path)
+ assert stream.read() == data
+ stream = pa.input_stream(str(file_path))
+ assert stream.read() == data
+ stream = pa.input_stream(pathlib.Path(str(file_path)))
+ assert stream.read() == data
+
+
+@pytest.mark.gzip
+def test_input_stream_file_path_compressed(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ gz_data = gzip.compress(data)
+ file_path = tmpdir / 'input_stream.gz'
+ with open(str(file_path), 'wb') as f:
+ f.write(gz_data)
+
+ stream = pa.input_stream(file_path)
+ assert stream.read() == data
+ stream = pa.input_stream(str(file_path))
+ assert stream.read() == data
+ stream = pa.input_stream(pathlib.Path(str(file_path)))
+ assert stream.read() == data
+
+ stream = pa.input_stream(file_path, compression='gzip')
+ assert stream.read() == data
+ stream = pa.input_stream(file_path, compression=None)
+ assert stream.read() == gz_data
+
+
+def test_input_stream_file_path_buffered(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ file_path = tmpdir / 'input_stream.buffered'
+ with open(str(file_path), 'wb') as f:
+ f.write(data)
+
+ stream = pa.input_stream(file_path, buffer_size=32)
+ assert isinstance(stream, pa.BufferedInputStream)
+ assert stream.read() == data
+ stream = pa.input_stream(str(file_path), buffer_size=64)
+ assert isinstance(stream, pa.BufferedInputStream)
+ assert stream.read() == data
+ stream = pa.input_stream(pathlib.Path(str(file_path)), buffer_size=1024)
+ assert isinstance(stream, pa.BufferedInputStream)
+ assert stream.read() == data
+
+ unbuffered_stream = pa.input_stream(file_path, buffer_size=0)
+ assert isinstance(unbuffered_stream, pa.OSFile)
+
+ msg = 'Buffer size must be larger than zero'
+ with pytest.raises(ValueError, match=msg):
+ pa.input_stream(file_path, buffer_size=-1)
+ with pytest.raises(TypeError):
+ pa.input_stream(file_path, buffer_size='million')
+
+
+@pytest.mark.gzip
+def test_input_stream_file_path_compressed_and_buffered(tmpdir):
+ data = b"some test data\n" * 100 + b"eof\n"
+ gz_data = gzip.compress(data)
+ file_path = tmpdir / 'input_stream_compressed_and_buffered.gz'
+ with open(str(file_path), 'wb') as f:
+ f.write(gz_data)
+
+ stream = pa.input_stream(file_path, buffer_size=32, compression='gzip')
+ assert stream.read() == data
+ stream = pa.input_stream(str(file_path), buffer_size=64)
+ assert stream.read() == data
+ stream = pa.input_stream(pathlib.Path(str(file_path)), buffer_size=1024)
+ assert stream.read() == data
+
+
+@pytest.mark.gzip
+def test_input_stream_python_file(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ bio = BytesIO(data)
+
+ stream = pa.input_stream(bio)
+ assert stream.read() == data
+
+ gz_data = gzip.compress(data)
+ bio = BytesIO(gz_data)
+ stream = pa.input_stream(bio)
+ assert stream.read() == gz_data
+ bio.seek(0)
+ stream = pa.input_stream(bio, compression='gzip')
+ assert stream.read() == data
+
+ file_path = tmpdir / 'input_stream'
+ with open(str(file_path), 'wb') as f:
+ f.write(data)
+ with open(str(file_path), 'rb') as f:
+ stream = pa.input_stream(f)
+ assert stream.read() == data
+
+
+@pytest.mark.gzip
+def test_input_stream_native_file():
+ data = b"some test data\n" * 10 + b"eof\n"
+ gz_data = gzip.compress(data)
+ reader = pa.BufferReader(gz_data)
+ stream = pa.input_stream(reader)
+ assert stream is reader
+ reader = pa.BufferReader(gz_data)
+ stream = pa.input_stream(reader, compression='gzip')
+ assert stream.read() == data
+
+
+def test_input_stream_errors(tmpdir):
+ buf = memoryview(b"")
+ with pytest.raises(ValueError):
+ pa.input_stream(buf, compression="foo")
+
+ for arg in [bytearray(), StringIO()]:
+ with pytest.raises(TypeError):
+ pa.input_stream(arg)
+
+ with assert_file_not_found():
+ pa.input_stream("non_existent_file")
+
+ with open(str(tmpdir / 'new_file'), 'wb') as f:
+ with pytest.raises(TypeError, match="readable file expected"):
+ pa.input_stream(f)
+
+
+def test_output_stream_buffer():
+ data = b"some test data\n" * 10 + b"eof\n"
+ buf = bytearray(len(data))
+ stream = pa.output_stream(pa.py_buffer(buf))
+ stream.write(data)
+ assert buf == data
+
+ buf = bytearray(len(data))
+ stream = pa.output_stream(memoryview(buf))
+ stream.write(data)
+ assert buf == data
+
+
+def test_output_stream_duck_typing():
+ # Accept objects having the right file-like methods...
+ class DuckWriter:
+ def __init__(self):
+ self.buf = pa.BufferOutputStream()
+
+ def close(self):
+ pass
+
+ @property
+ def closed(self):
+ return False
+
+ def write(self, data):
+ self.buf.write(data)
+
+ duck_writer = DuckWriter()
+ stream = pa.output_stream(duck_writer)
+ assert stream.write(b'hello')
+ assert duck_writer.buf.getvalue().to_pybytes() == b'hello'
+
+
+def test_output_stream_file_path(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ file_path = tmpdir / 'output_stream'
+
+ def check_data(file_path, data):
+ with pa.output_stream(file_path) as stream:
+ stream.write(data)
+ with open(str(file_path), 'rb') as f:
+ assert f.read() == data
+
+ check_data(file_path, data)
+ check_data(str(file_path), data)
+ check_data(pathlib.Path(str(file_path)), data)
+
+
+@pytest.mark.gzip
+def test_output_stream_file_path_compressed(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ file_path = tmpdir / 'output_stream.gz'
+
+ def check_data(file_path, data, **kwargs):
+ with pa.output_stream(file_path, **kwargs) as stream:
+ stream.write(data)
+ with open(str(file_path), 'rb') as f:
+ return f.read()
+
+ assert gzip.decompress(check_data(file_path, data)) == data
+ assert gzip.decompress(check_data(str(file_path), data)) == data
+ assert gzip.decompress(
+ check_data(pathlib.Path(str(file_path)), data)) == data
+
+ assert gzip.decompress(
+ check_data(file_path, data, compression='gzip')) == data
+ assert check_data(file_path, data, compression=None) == data
+
+ with pytest.raises(ValueError, match='Invalid value for compression'):
+ assert check_data(file_path, data, compression='rabbit') == data
+
+
+def test_output_stream_file_path_buffered(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+ file_path = tmpdir / 'output_stream.buffered'
+
+ def check_data(file_path, data, **kwargs):
+ with pa.output_stream(file_path, **kwargs) as stream:
+ if kwargs.get('buffer_size', 0) > 0:
+ assert isinstance(stream, pa.BufferedOutputStream)
+ stream.write(data)
+ with open(str(file_path), 'rb') as f:
+ return f.read()
+
+ unbuffered_stream = pa.output_stream(file_path, buffer_size=0)
+ assert isinstance(unbuffered_stream, pa.OSFile)
+
+ msg = 'Buffer size must be larger than zero'
+ with pytest.raises(ValueError, match=msg):
+ assert check_data(file_path, data, buffer_size=-128) == data
+
+ assert check_data(file_path, data, buffer_size=32) == data
+ assert check_data(file_path, data, buffer_size=1024) == data
+ assert check_data(str(file_path), data, buffer_size=32) == data
+
+ result = check_data(pathlib.Path(str(file_path)), data, buffer_size=32)
+ assert result == data
+
+
+@pytest.mark.gzip
+def test_output_stream_file_path_compressed_and_buffered(tmpdir):
+ data = b"some test data\n" * 100 + b"eof\n"
+ file_path = tmpdir / 'output_stream_compressed_and_buffered.gz'
+
+ def check_data(file_path, data, **kwargs):
+ with pa.output_stream(file_path, **kwargs) as stream:
+ stream.write(data)
+ with open(str(file_path), 'rb') as f:
+ return f.read()
+
+ result = check_data(file_path, data, buffer_size=32)
+ assert gzip.decompress(result) == data
+
+ result = check_data(file_path, data, buffer_size=1024)
+ assert gzip.decompress(result) == data
+
+ result = check_data(file_path, data, buffer_size=1024, compression='gzip')
+ assert gzip.decompress(result) == data
+
+
+def test_output_stream_destructor(tmpdir):
+ # The wrapper returned by pa.output_stream() should respect Python
+ # file semantics, i.e. destroying it should close the underlying
+ # file cleanly.
+ data = b"some test data\n"
+ file_path = tmpdir / 'output_stream.buffered'
+
+ def check_data(file_path, data, **kwargs):
+ stream = pa.output_stream(file_path, **kwargs)
+ stream.write(data)
+ del stream
+ gc.collect()
+ with open(str(file_path), 'rb') as f:
+ return f.read()
+
+ assert check_data(file_path, data, buffer_size=0) == data
+ assert check_data(file_path, data, buffer_size=1024) == data
+
+
+@pytest.mark.gzip
+def test_output_stream_python_file(tmpdir):
+ data = b"some test data\n" * 10 + b"eof\n"
+
+ def check_data(data, **kwargs):
+ # XXX cannot use BytesIO because stream.close() is necessary
+ # to finish writing compressed data, but it will also close the
+ # underlying BytesIO
+ fn = str(tmpdir / 'output_stream_file')
+ with open(fn, 'wb') as f:
+ with pa.output_stream(f, **kwargs) as stream:
+ stream.write(data)
+ with open(fn, 'rb') as f:
+ return f.read()
+
+ assert check_data(data) == data
+ assert gzip.decompress(check_data(data, compression='gzip')) == data
+
+
+def test_output_stream_errors(tmpdir):
+ buf = memoryview(bytearray())
+ with pytest.raises(ValueError):
+ pa.output_stream(buf, compression="foo")
+
+ for arg in [bytearray(), StringIO()]:
+ with pytest.raises(TypeError):
+ pa.output_stream(arg)
+
+ fn = str(tmpdir / 'new_file')
+ with open(fn, 'wb') as f:
+ pass
+ with open(fn, 'rb') as f:
+ with pytest.raises(TypeError, match="writable file expected"):
+ pa.output_stream(f)
diff --git a/src/arrow/python/pyarrow/tests/test_ipc.py b/src/arrow/python/pyarrow/tests/test_ipc.py
new file mode 100644
index 000000000..87944bcc0
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_ipc.py
@@ -0,0 +1,999 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import UserList
+import io
+import pathlib
+import pytest
+import socket
+import threading
+import weakref
+
+import numpy as np
+
+import pyarrow as pa
+from pyarrow.tests.util import changed_environ
+
+
+try:
+ from pandas.testing import assert_frame_equal, assert_series_equal
+ import pandas as pd
+except ImportError:
+ pass
+
+
+class IpcFixture:
+ write_stats = None
+
+ def __init__(self, sink_factory=lambda: io.BytesIO()):
+ self._sink_factory = sink_factory
+ self.sink = self.get_sink()
+
+ def get_sink(self):
+ return self._sink_factory()
+
+ def get_source(self):
+ return self.sink.getvalue()
+
+ def write_batches(self, num_batches=5, as_table=False):
+ nrows = 5
+ schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())])
+
+ writer = self._get_writer(self.sink, schema)
+
+ batches = []
+ for i in range(num_batches):
+ batch = pa.record_batch(
+ [np.random.randn(nrows),
+ ['foo', None, 'bar', 'bazbaz', 'qux']],
+ schema=schema)
+ batches.append(batch)
+
+ if as_table:
+ table = pa.Table.from_batches(batches)
+ writer.write_table(table)
+ else:
+ for batch in batches:
+ writer.write_batch(batch)
+
+ self.write_stats = writer.stats
+ writer.close()
+ return batches
+
+
+class FileFormatFixture(IpcFixture):
+
+ def _get_writer(self, sink, schema):
+ return pa.ipc.new_file(sink, schema)
+
+ def _check_roundtrip(self, as_table=False):
+ batches = self.write_batches(as_table=as_table)
+ file_contents = pa.BufferReader(self.get_source())
+
+ reader = pa.ipc.open_file(file_contents)
+
+ assert reader.num_record_batches == len(batches)
+
+ for i, written_batch in enumerate(batches):
+ # read each batch back and check it matches what was written
+ read_batch = reader.get_batch(i)
+ assert written_batch.equals(read_batch)
+ assert reader.schema.equals(batches[0].schema)
+
+ assert isinstance(reader.stats, pa.ipc.ReadStats)
+ assert isinstance(self.write_stats, pa.ipc.WriteStats)
+ assert tuple(reader.stats) == tuple(self.write_stats)
+
+
+class StreamFormatFixture(IpcFixture):
+
+ # ARROW-6474: for testing writing the old IPC protocol with a 4-byte prefix
+ use_legacy_ipc_format = False
+ # ARROW-9395: for testing writing an old metadata version
+ options = None
+
+ def _get_writer(self, sink, schema):
+ return pa.ipc.new_stream(
+ sink,
+ schema,
+ use_legacy_format=self.use_legacy_ipc_format,
+ options=self.options,
+ )
+
+
+class MessageFixture(IpcFixture):
+
+ def _get_writer(self, sink, schema):
+ return pa.RecordBatchStreamWriter(sink, schema)
+
+
+@pytest.fixture
+def ipc_fixture():
+ return IpcFixture()
+
+
+@pytest.fixture
+def file_fixture():
+ return FileFormatFixture()
+
+
+@pytest.fixture
+def stream_fixture():
+ return StreamFormatFixture()
+
+
+def test_empty_file():
+ buf = b''
+ with pytest.raises(pa.ArrowInvalid):
+ pa.ipc.open_file(pa.BufferReader(buf))
+
+
+def test_file_simple_roundtrip(file_fixture):
+ file_fixture._check_roundtrip(as_table=False)
+
+
+def test_file_write_table(file_fixture):
+ file_fixture._check_roundtrip(as_table=True)
+
+
+@pytest.mark.parametrize("sink_factory", [
+ lambda: io.BytesIO(),
+ lambda: pa.BufferOutputStream()
+])
+def test_file_read_all(sink_factory):
+ fixture = FileFormatFixture(sink_factory)
+
+ batches = fixture.write_batches()
+ file_contents = pa.BufferReader(fixture.get_source())
+
+ reader = pa.ipc.open_file(file_contents)
+
+ result = reader.read_all()
+ expected = pa.Table.from_batches(batches)
+ assert result.equals(expected)
+
+
+def test_open_file_from_buffer(file_fixture):
+ # ARROW-2859; APIs accept the buffer protocol
+ file_fixture.write_batches()
+ source = file_fixture.get_source()
+
+ reader1 = pa.ipc.open_file(source)
+ reader2 = pa.ipc.open_file(pa.BufferReader(source))
+ reader3 = pa.RecordBatchFileReader(source)
+
+ result1 = reader1.read_all()
+ result2 = reader2.read_all()
+ result3 = reader3.read_all()
+
+ assert result1.equals(result2)
+ assert result1.equals(result3)
+
+ st1 = reader1.stats
+ assert st1.num_messages == 6
+ assert st1.num_record_batches == 5
+ assert reader2.stats == st1
+ assert reader3.stats == st1
+
+
+@pytest.mark.pandas
+def test_file_read_pandas(file_fixture):
+ frames = [batch.to_pandas() for batch in file_fixture.write_batches()]
+
+ file_contents = pa.BufferReader(file_fixture.get_source())
+ reader = pa.ipc.open_file(file_contents)
+ result = reader.read_pandas()
+
+ expected = pd.concat(frames).reset_index(drop=True)
+ assert_frame_equal(result, expected)
+
+
+def test_file_pathlib(file_fixture, tmpdir):
+ file_fixture.write_batches()
+ source = file_fixture.get_source()
+
+ path = tmpdir.join('file.arrow').strpath
+ with open(path, 'wb') as f:
+ f.write(source)
+
+ t1 = pa.ipc.open_file(pathlib.Path(path)).read_all()
+ t2 = pa.ipc.open_file(pa.OSFile(path)).read_all()
+
+ assert t1.equals(t2)
+
+
+def test_empty_stream():
+ buf = io.BytesIO(b'')
+ with pytest.raises(pa.ArrowInvalid):
+ pa.ipc.open_stream(buf)
+
+
+@pytest.mark.pandas
+def test_stream_categorical_roundtrip(stream_fixture):
+ df = pd.DataFrame({
+ 'one': np.random.randn(5),
+ 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
+ categories=['foo', 'bar'],
+ ordered=True)
+ })
+ batch = pa.RecordBatch.from_pandas(df)
+ with stream_fixture._get_writer(stream_fixture.sink, batch.schema) as wr:
+ wr.write_batch(batch)
+
+ table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
+ .read_all())
+ assert_frame_equal(table.to_pandas(), df)
+
+
+def test_open_stream_from_buffer(stream_fixture):
+ # ARROW-2859
+ stream_fixture.write_batches()
+ source = stream_fixture.get_source()
+
+ reader1 = pa.ipc.open_stream(source)
+ reader2 = pa.ipc.open_stream(pa.BufferReader(source))
+ reader3 = pa.RecordBatchStreamReader(source)
+
+ result1 = reader1.read_all()
+ result2 = reader2.read_all()
+ result3 = reader3.read_all()
+
+ assert result1.equals(result2)
+ assert result1.equals(result3)
+
+ st1 = reader1.stats
+ assert st1.num_messages == 6
+ assert st1.num_record_batches == 5
+ assert reader2.stats == st1
+ assert reader3.stats == st1
+
+ assert tuple(st1) == tuple(stream_fixture.write_stats)
+
+
+@pytest.mark.pandas
+def test_stream_write_dispatch(stream_fixture):
+ # ARROW-1616
+ df = pd.DataFrame({
+ 'one': np.random.randn(5),
+ 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
+ categories=['foo', 'bar'],
+ ordered=True)
+ })
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ batch = pa.RecordBatch.from_pandas(df, preserve_index=False)
+ with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr:
+ wr.write(table)
+ wr.write(batch)
+
+ table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
+ .read_all())
+ assert_frame_equal(table.to_pandas(),
+ pd.concat([df, df], ignore_index=True))
+
+
+@pytest.mark.pandas
+def test_stream_write_table_batches(stream_fixture):
+ # ARROW-504
+ df = pd.DataFrame({
+ 'one': np.random.randn(20),
+ })
+
+ b1 = pa.RecordBatch.from_pandas(df[:10], preserve_index=False)
+ b2 = pa.RecordBatch.from_pandas(df, preserve_index=False)
+
+ table = pa.Table.from_batches([b1, b2, b1])
+
+ with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr:
+ wr.write_table(table, max_chunksize=15)
+
+ batches = list(pa.ipc.open_stream(stream_fixture.get_source()))
+
+ assert list(map(len, batches)) == [10, 15, 5, 10]
+ result_table = pa.Table.from_batches(batches)
+ assert_frame_equal(result_table.to_pandas(),
+ pd.concat([df[:10], df, df[:10]],
+ ignore_index=True))
+
+
+@pytest.mark.parametrize('use_legacy_ipc_format', [False, True])
+def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format):
+ stream_fixture.use_legacy_ipc_format = use_legacy_ipc_format
+ batches = stream_fixture.write_batches()
+ file_contents = pa.BufferReader(stream_fixture.get_source())
+ reader = pa.ipc.open_stream(file_contents)
+
+ assert reader.schema.equals(batches[0].schema)
+
+ total = 0
+ for i, next_batch in enumerate(reader):
+ assert next_batch.equals(batches[i])
+ total += 1
+
+ assert total == len(batches)
+
+ with pytest.raises(StopIteration):
+ reader.read_next_batch()
+
+
+@pytest.mark.zstd
+def test_compression_roundtrip():
+ sink = io.BytesIO()
+ values = np.random.randint(0, 10, 10000)
+ table = pa.Table.from_arrays([values], names=["values"])
+
+ options = pa.ipc.IpcWriteOptions(compression='zstd')
+ with pa.ipc.RecordBatchFileWriter(
+ sink, table.schema, options=options) as writer:
+ writer.write_table(table)
+ len1 = len(sink.getvalue())
+
+ sink2 = io.BytesIO()
+ codec = pa.Codec('zstd', compression_level=5)
+ options = pa.ipc.IpcWriteOptions(compression=codec)
+ with pa.ipc.RecordBatchFileWriter(
+ sink2, table.schema, options=options) as writer:
+ writer.write_table(table)
+ len2 = len(sink2.getvalue())
+
+ # In theory len2 should be less than len1, but for this test we just want
+ # to ensure that compression_level is correctly passed down to the C++
+ # layer, so we don't care whether it makes the result better or worse
+ assert len2 != len1
+
+ t1 = pa.ipc.open_file(sink).read_all()
+ t2 = pa.ipc.open_file(sink2).read_all()
+
+ assert t1 == t2
+
+
+def test_write_options():
+ options = pa.ipc.IpcWriteOptions()
+ assert options.allow_64bit is False
+ assert options.use_legacy_format is False
+ assert options.metadata_version == pa.ipc.MetadataVersion.V5
+
+ options.allow_64bit = True
+ assert options.allow_64bit is True
+
+ options.use_legacy_format = True
+ assert options.use_legacy_format is True
+
+ options.metadata_version = pa.ipc.MetadataVersion.V4
+ assert options.metadata_version == pa.ipc.MetadataVersion.V4
+ for value in ('V5', 42):
+ with pytest.raises((TypeError, ValueError)):
+ options.metadata_version = value
+
+ assert options.compression is None
+ for value in ['lz4', 'zstd']:
+ if pa.Codec.is_available(value):
+ options.compression = value
+ assert options.compression == value
+ options.compression = value.upper()
+ assert options.compression == value
+ options.compression = None
+ assert options.compression is None
+
+ with pytest.raises(TypeError):
+ options.compression = 0
+
+ assert options.use_threads is True
+ options.use_threads = False
+ assert options.use_threads is False
+
+ if pa.Codec.is_available('lz4'):
+ options = pa.ipc.IpcWriteOptions(
+ metadata_version=pa.ipc.MetadataVersion.V4,
+ allow_64bit=True,
+ use_legacy_format=True,
+ compression='lz4',
+ use_threads=False)
+ assert options.metadata_version == pa.ipc.MetadataVersion.V4
+ assert options.allow_64bit is True
+ assert options.use_legacy_format is True
+ assert options.compression == 'lz4'
+ assert options.use_threads is False
+
+
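+# A minimal sketch (illustration only; never collected by pytest) of handing
+# an IpcWriteOptions instance to a stream writer; pa.ipc.new_file() accepts
+# the same `options` keyword.
+def _ipc_write_options_sketch():
+    schema = pa.schema([('x', pa.int64())])
+    options = pa.ipc.IpcWriteOptions(
+        metadata_version=pa.ipc.MetadataVersion.V5, use_threads=False)
+    sink = pa.BufferOutputStream()
+    with pa.ipc.new_stream(sink, schema, options=options) as writer:
+        writer.write_batch(pa.record_batch([[1, 2, 3]], schema=schema))
+    # read the stream back to confirm the options produced a valid stream
+    reader = pa.ipc.open_stream(sink.getvalue())
+    assert reader.read_all().num_rows == 3
+
+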
+def test_write_options_legacy_exclusive(stream_fixture):
+ with pytest.raises(
+ ValueError,
+ match="provide at most one of options and use_legacy_format"):
+ stream_fixture.use_legacy_ipc_format = True
+ stream_fixture.options = pa.ipc.IpcWriteOptions()
+ stream_fixture.write_batches()
+
+
+@pytest.mark.parametrize('options', [
+ pa.ipc.IpcWriteOptions(),
+ pa.ipc.IpcWriteOptions(allow_64bit=True),
+ pa.ipc.IpcWriteOptions(use_legacy_format=True),
+ pa.ipc.IpcWriteOptions(metadata_version=pa.ipc.MetadataVersion.V4),
+ pa.ipc.IpcWriteOptions(use_legacy_format=True,
+ metadata_version=pa.ipc.MetadataVersion.V4),
+])
+def test_stream_options_roundtrip(stream_fixture, options):
+ stream_fixture.use_legacy_ipc_format = None
+ stream_fixture.options = options
+ batches = stream_fixture.write_batches()
+ file_contents = pa.BufferReader(stream_fixture.get_source())
+
+ message = pa.ipc.read_message(stream_fixture.get_source())
+ assert message.metadata_version == options.metadata_version
+
+ reader = pa.ipc.open_stream(file_contents)
+
+ assert reader.schema.equals(batches[0].schema)
+
+ total = 0
+ for i, next_batch in enumerate(reader):
+ assert next_batch.equals(batches[i])
+ total += 1
+
+ assert total == len(batches)
+
+ with pytest.raises(StopIteration):
+ reader.read_next_batch()
+
+
+def test_dictionary_delta(stream_fixture):
+ ty = pa.dictionary(pa.int8(), pa.utf8())
+ data = [["foo", "foo", None],
+ ["foo", "bar", "foo"], # potential delta
+ ["foo", "bar"],
+ ["foo", None, "bar", "quux"], # potential delta
+ ["bar", "quux"], # replacement
+ ]
+ batches = [
+ pa.RecordBatch.from_arrays([pa.array(v, type=ty)], names=['dicts'])
+ for v in data]
+ schema = batches[0].schema
+
+ def write_batches():
+ with stream_fixture._get_writer(pa.MockOutputStream(),
+ schema) as writer:
+ for batch in batches:
+ writer.write_batch(batch)
+ return writer.stats
+
+ st = write_batches()
+ assert st.num_record_batches == 5
+ assert st.num_dictionary_batches == 4
+ assert st.num_replaced_dictionaries == 3
+ assert st.num_dictionary_deltas == 0
+
+ stream_fixture.use_legacy_ipc_format = None
+ stream_fixture.options = pa.ipc.IpcWriteOptions(
+ emit_dictionary_deltas=True)
+ st = write_batches()
+ assert st.num_record_batches == 5
+ assert st.num_dictionary_batches == 4
+ assert st.num_replaced_dictionaries == 1
+ assert st.num_dictionary_deltas == 2
+
+
+def test_envvar_set_legacy_ipc_format():
+ schema = pa.schema([pa.field('foo', pa.int32())])
+
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+
+ with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+
+ with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+
+ with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
+ with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+
+
+def test_stream_read_all(stream_fixture):
+ batches = stream_fixture.write_batches()
+ file_contents = pa.BufferReader(stream_fixture.get_source())
+ reader = pa.ipc.open_stream(file_contents)
+
+ result = reader.read_all()
+ expected = pa.Table.from_batches(batches)
+ assert result.equals(expected)
+
+
+@pytest.mark.pandas
+def test_stream_read_pandas(stream_fixture):
+ frames = [batch.to_pandas() for batch in stream_fixture.write_batches()]
+ file_contents = stream_fixture.get_source()
+ reader = pa.ipc.open_stream(file_contents)
+ result = reader.read_pandas()
+
+ expected = pd.concat(frames).reset_index(drop=True)
+ assert_frame_equal(result, expected)
+
+
+@pytest.fixture
+def example_messages(stream_fixture):
+ batches = stream_fixture.write_batches()
+ file_contents = stream_fixture.get_source()
+ buf_reader = pa.BufferReader(file_contents)
+ reader = pa.MessageReader.open_stream(buf_reader)
+ return batches, list(reader)
+
+
+def test_message_ctors_no_segfault():
+ with pytest.raises(TypeError):
+ repr(pa.Message())
+
+ with pytest.raises(TypeError):
+ repr(pa.MessageReader())
+
+
+def test_message_reader(example_messages):
+ _, messages = example_messages
+
+ assert len(messages) == 6
+ assert messages[0].type == 'schema'
+ assert isinstance(messages[0].metadata, pa.Buffer)
+ assert isinstance(messages[0].body, pa.Buffer)
+ assert messages[0].metadata_version == pa.MetadataVersion.V5
+
+ for msg in messages[1:]:
+ assert msg.type == 'record batch'
+ assert isinstance(msg.metadata, pa.Buffer)
+ assert isinstance(msg.body, pa.Buffer)
+ assert msg.metadata_version == pa.MetadataVersion.V5
+
+
+def test_message_serialize_read_message(example_messages):
+ _, messages = example_messages
+
+ msg = messages[0]
+ buf = msg.serialize()
+ reader = pa.BufferReader(buf.to_pybytes() * 2)
+
+ restored = pa.ipc.read_message(buf)
+ restored2 = pa.ipc.read_message(reader)
+ restored3 = pa.ipc.read_message(buf.to_pybytes())
+ restored4 = pa.ipc.read_message(reader)
+
+ assert msg.equals(restored)
+ assert msg.equals(restored2)
+ assert msg.equals(restored3)
+ assert msg.equals(restored4)
+
+ with pytest.raises(pa.ArrowInvalid, match="Corrupted message"):
+ pa.ipc.read_message(pa.BufferReader(b'ab'))
+
+ with pytest.raises(EOFError):
+ pa.ipc.read_message(reader)
+
+
+@pytest.mark.gzip
+def test_message_read_from_compressed(example_messages):
+ # Part of ARROW-5910
+ _, messages = example_messages
+ for message in messages:
+ raw_out = pa.BufferOutputStream()
+ with pa.output_stream(raw_out, compression='gzip') as compressed_out:
+ message.serialize_to(compressed_out)
+
+ compressed_buf = raw_out.getvalue()
+
+ result = pa.ipc.read_message(pa.input_stream(compressed_buf,
+ compression='gzip'))
+ assert result.equals(message)
+
+
+def test_message_read_record_batch(example_messages):
+ batches, messages = example_messages
+
+ for batch, message in zip(batches, messages[1:]):
+ read_batch = pa.ipc.read_record_batch(message, batch.schema)
+ assert read_batch.equals(batch)
+
+
+def test_read_record_batch_on_stream_error_message():
+ # ARROW-5374
+ batch = pa.record_batch([pa.array([b"foo"], type=pa.utf8())],
+ names=['strs'])
+ stream = pa.BufferOutputStream()
+ with pa.ipc.new_stream(stream, batch.schema) as writer:
+ writer.write_batch(batch)
+ buf = stream.getvalue()
+ with pytest.raises(IOError,
+ match="type record batch but got schema"):
+ pa.ipc.read_record_batch(buf, batch.schema)
+
+
+# ----------------------------------------------------------------------
+# Socket streaming tests
+
+
+class StreamReaderServer(threading.Thread):
+
+ def init(self, do_read_all):
+ self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._sock.bind(('127.0.0.1', 0))
+ self._sock.listen(1)
+ host, port = self._sock.getsockname()
+ self._do_read_all = do_read_all
+ self._schema = None
+ self._batches = []
+ self._table = None
+ return port
+
+ def run(self):
+ connection, client_address = self._sock.accept()
+ try:
+ source = connection.makefile(mode='rb')
+ reader = pa.ipc.open_stream(source)
+ self._schema = reader.schema
+ if self._do_read_all:
+ self._table = reader.read_all()
+ else:
+ for batch in reader:
+ self._batches.append(batch)
+ finally:
+ connection.close()
+
+ def get_result(self):
+ return (self._schema, self._table if self._do_read_all
+ else self._batches)
+
+
+class SocketStreamFixture(IpcFixture):
+
+ def __init__(self):
+ # XXX(wesm): test will decide when to start socket server. This should
+ # probably be refactored
+ pass
+
+ def start_server(self, do_read_all):
+ self._server = StreamReaderServer()
+ port = self._server.init(do_read_all)
+ self._server.start()
+ self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._sock.connect(('127.0.0.1', port))
+ self.sink = self.get_sink()
+
+ def stop_and_get_result(self):
+ import struct
+ self.sink.write(struct.pack('Q', 0))
+ self.sink.flush()
+ self._sock.close()
+ self._server.join()
+ return self._server.get_result()
+
+ def get_sink(self):
+ return self._sock.makefile(mode='wb')
+
+ def _get_writer(self, sink, schema):
+ return pa.RecordBatchStreamWriter(sink, schema)
+
+
+@pytest.fixture
+def socket_fixture():
+ return SocketStreamFixture()
+
+
+def test_socket_simple_roundtrip(socket_fixture):
+ socket_fixture.start_server(do_read_all=False)
+ writer_batches = socket_fixture.write_batches()
+ reader_schema, reader_batches = socket_fixture.stop_and_get_result()
+
+ assert reader_schema.equals(writer_batches[0].schema)
+ assert len(reader_batches) == len(writer_batches)
+ for i, batch in enumerate(writer_batches):
+ assert reader_batches[i].equals(batch)
+
+
+def test_socket_read_all(socket_fixture):
+ socket_fixture.start_server(do_read_all=True)
+ writer_batches = socket_fixture.write_batches()
+ _, result = socket_fixture.stop_and_get_result()
+
+ expected = pa.Table.from_batches(writer_batches)
+ assert result.equals(expected)
+
+
+# ----------------------------------------------------------------------
+# Miscellaneous IPC tests
+
+@pytest.mark.pandas
+def test_ipc_file_stream_has_eos():
+ # ARROW-5395
+ df = pd.DataFrame({'foo': [1.5]})
+ batch = pa.RecordBatch.from_pandas(df)
+ sink = pa.BufferOutputStream()
+ write_file(batch, sink)
+ buffer = sink.getvalue()
+
+ # skip the file magic
+ reader = pa.ipc.open_stream(buffer[8:])
+
+ # will fail if encounters footer data instead of eos
+ rdf = reader.read_pandas()
+
+ assert_frame_equal(df, rdf)
+
+
+@pytest.mark.pandas
+def test_ipc_zero_copy_numpy():
+ df = pd.DataFrame({'foo': [1.5]})
+
+ batch = pa.RecordBatch.from_pandas(df)
+ sink = pa.BufferOutputStream()
+ write_file(batch, sink)
+ buffer = sink.getvalue()
+ reader = pa.BufferReader(buffer)
+
+ batches = read_file(reader)
+
+ data = batches[0].to_pandas()
+ rdf = pd.DataFrame(data)
+ assert_frame_equal(df, rdf)
+
+
+def test_ipc_stream_no_batches():
+ # ARROW-2307
+ table = pa.Table.from_arrays([pa.array([1, 2, 3, 4]),
+ pa.array(['foo', 'bar', 'baz', 'qux'])],
+ names=['a', 'b'])
+
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, table.schema):
+ pass
+
+ source = sink.getvalue()
+ with pa.ipc.open_stream(source) as reader:
+ result = reader.read_all()
+
+ assert result.schema.equals(table.schema)
+ assert len(result) == 0
+
+
+@pytest.mark.pandas
+def test_get_record_batch_size():
+ N = 10
+ itemsize = 8
+ df = pd.DataFrame({'foo': np.random.randn(N)})
+
+ batch = pa.RecordBatch.from_pandas(df)
+ assert pa.ipc.get_record_batch_size(batch) > (N * itemsize)
+
+
+@pytest.mark.pandas
+def _check_serialize_pandas_round_trip(df, use_threads=False):
+ buf = pa.serialize_pandas(df, nthreads=2 if use_threads else 1)
+ result = pa.deserialize_pandas(buf, use_threads=use_threads)
+ assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+def test_pandas_serialize_round_trip():
+ index = pd.Index([1, 2, 3], name='my_index')
+ columns = ['foo', 'bar']
+ df = pd.DataFrame(
+ {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')},
+ index=index, columns=columns
+ )
+ _check_serialize_pandas_round_trip(df)
+
+
+@pytest.mark.pandas
+def test_pandas_serialize_round_trip_nthreads():
+ index = pd.Index([1, 2, 3], name='my_index')
+ columns = ['foo', 'bar']
+ df = pd.DataFrame(
+ {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')},
+ index=index, columns=columns
+ )
+ _check_serialize_pandas_round_trip(df, use_threads=True)
+
+
+@pytest.mark.pandas
+def test_pandas_serialize_round_trip_multi_index():
+ index1 = pd.Index([1, 2, 3], name='level_1')
+ index2 = pd.Index(list('def'), name=None)
+ index = pd.MultiIndex.from_arrays([index1, index2])
+
+ columns = ['foo', 'bar']
+ df = pd.DataFrame(
+ {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')},
+ index=index,
+ columns=columns,
+ )
+ _check_serialize_pandas_round_trip(df)
+
+
+@pytest.mark.pandas
+def test_serialize_pandas_empty_dataframe():
+ df = pd.DataFrame()
+ _check_serialize_pandas_round_trip(df)
+
+
+@pytest.mark.pandas
+def test_pandas_serialize_round_trip_not_string_columns():
+ df = pd.DataFrame(list(zip([1.5, 1.6, 1.7], 'abc')))
+ buf = pa.serialize_pandas(df)
+ result = pa.deserialize_pandas(buf)
+ assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+def test_serialize_pandas_no_preserve_index():
+ df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3])
+ expected = pd.DataFrame({'a': [1, 2, 3]})
+
+ buf = pa.serialize_pandas(df, preserve_index=False)
+ result = pa.deserialize_pandas(buf)
+ assert_frame_equal(result, expected)
+
+ buf = pa.serialize_pandas(df, preserve_index=True)
+ result = pa.deserialize_pandas(buf)
+ assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@pytest.mark.filterwarnings("ignore:'pyarrow:FutureWarning")
+def test_serialize_with_pandas_objects():
+ df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3])
+ s = pd.Series([1, 2, 3, 4])
+
+ data = {
+ 'a_series': df['a'],
+ 'a_frame': df,
+ 's_series': s
+ }
+
+ serialized = pa.serialize(data).to_buffer()
+ deserialized = pa.deserialize(serialized)
+ assert_frame_equal(deserialized['a_frame'], df)
+
+ assert_series_equal(deserialized['a_series'], df['a'])
+ assert deserialized['a_series'].name == 'a'
+
+ assert_series_equal(deserialized['s_series'], s)
+ assert deserialized['s_series'].name is None
+
+
+@pytest.mark.pandas
+def test_schema_batch_serialize_methods():
+ nrows = 5
+ df = pd.DataFrame({
+ 'one': np.random.randn(nrows),
+ 'two': ['foo', np.nan, 'bar', 'bazbaz', 'qux']})
+ batch = pa.RecordBatch.from_pandas(df)
+
+ s_schema = batch.schema.serialize()
+ s_batch = batch.serialize()
+
+ recons_schema = pa.ipc.read_schema(s_schema)
+ recons_batch = pa.ipc.read_record_batch(s_batch, recons_schema)
+ assert recons_batch.equals(batch)
+
+
+def test_schema_serialization_with_metadata():
+ field_metadata = {b'foo': b'bar', b'kind': b'field'}
+ schema_metadata = {b'foo': b'bar', b'kind': b'schema'}
+
+ f0 = pa.field('a', pa.int8())
+ f1 = pa.field('b', pa.string(), metadata=field_metadata)
+
+ schema = pa.schema([f0, f1], metadata=schema_metadata)
+
+ s_schema = schema.serialize()
+ recons_schema = pa.ipc.read_schema(s_schema)
+
+ assert recons_schema.equals(schema)
+ assert recons_schema.metadata == schema_metadata
+ assert recons_schema[0].metadata is None
+ assert recons_schema[1].metadata == field_metadata
+
+
+def test_deprecated_pyarrow_ns_apis():
+ table = pa.table([pa.array([1, 2, 3, 4])], names=['a'])
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, table.schema) as writer:
+ writer.write(table)
+
+ with pytest.warns(FutureWarning,
+ match="please use pyarrow.ipc.open_stream"):
+ pa.open_stream(sink.getvalue())
+
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_file(sink, table.schema) as writer:
+ writer.write(table)
+ with pytest.warns(FutureWarning, match="please use pyarrow.ipc.open_file"):
+ pa.open_file(sink.getvalue())
+
+
+def write_file(batch, sink):
+ with pa.ipc.new_file(sink, batch.schema) as writer:
+ writer.write_batch(batch)
+
+
+def read_file(source):
+ with pa.ipc.open_file(source) as reader:
+ return [reader.get_batch(i) for i in range(reader.num_record_batches)]
+
+
+def test_write_empty_ipc_file():
+ # ARROW-3894: IPC file was not being properly initialized when no record
+ # batches are being written
+ schema = pa.schema([('field', pa.int64())])
+
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_file(sink, schema):
+ pass
+
+ buf = sink.getvalue()
+ with pa.RecordBatchFileReader(pa.BufferReader(buf)) as reader:
+ table = reader.read_all()
+ assert len(table) == 0
+ assert table.schema.equals(schema)
+
+
+def test_py_record_batch_reader():
+ def make_schema():
+ return pa.schema([('field', pa.int64())])
+
+ def make_batches():
+ schema = make_schema()
+ batch1 = pa.record_batch([[1, 2, 3]], schema=schema)
+ batch2 = pa.record_batch([[4, 5]], schema=schema)
+ return [batch1, batch2]
+
+ # With iterable
+ batches = UserList(make_batches()) # weakrefable
+ wr = weakref.ref(batches)
+
+ with pa.ipc.RecordBatchReader.from_batches(make_schema(),
+ batches) as reader:
+ batches = None
+ assert wr() is not None
+ assert list(reader) == make_batches()
+ assert wr() is None
+
+ # With iterator
+ batches = iter(UserList(make_batches())) # weakrefable
+ wr = weakref.ref(batches)
+
+ with pa.ipc.RecordBatchReader.from_batches(make_schema(),
+ batches) as reader:
+ batches = None
+ assert wr() is not None
+ assert list(reader) == make_batches()
+ assert wr() is None
diff --git a/src/arrow/python/pyarrow/tests/test_json.py b/src/arrow/python/pyarrow/tests/test_json.py
new file mode 100644
index 000000000..6ce584e51
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_json.py
@@ -0,0 +1,310 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import OrderedDict
+import io
+import itertools
+import json
+import string
+import unittest
+
+import numpy as np
+import pytest
+
+import pyarrow as pa
+from pyarrow.json import read_json, ReadOptions, ParseOptions
+
+
+def generate_col_names():
+ # 'a', 'b'... 'z', then 'aa', 'ab'...
+ letters = string.ascii_lowercase
+ yield from letters
+ for first in letters:
+ for second in letters:
+ yield first + second
+
+
+def make_random_json(num_cols=2, num_rows=10, linesep='\r\n'):
+ arr = np.random.RandomState(42).randint(0, 1000, size=(num_cols, num_rows))
+ col_names = list(itertools.islice(generate_col_names(), num_cols))
+ lines = []
+ for row in arr.T:
+ json_obj = OrderedDict([(k, int(v)) for (k, v) in zip(col_names, row)])
+ lines.append(json.dumps(json_obj))
+ data = linesep.join(lines).encode()
+ columns = [pa.array(col, type=pa.int64()) for col in arr]
+ expected = pa.Table.from_arrays(columns, col_names)
+ return data, expected
+
+
+def test_read_options():
+ cls = ReadOptions
+ opts = cls()
+
+ assert opts.block_size > 0
+ opts.block_size = 12345
+ assert opts.block_size == 12345
+
+ assert opts.use_threads is True
+ opts.use_threads = False
+ assert opts.use_threads is False
+
+ opts = cls(block_size=1234, use_threads=False)
+ assert opts.block_size == 1234
+ assert opts.use_threads is False
+
+
+def test_parse_options():
+ cls = ParseOptions
+ opts = cls()
+ assert opts.newlines_in_values is False
+ assert opts.explicit_schema is None
+
+ opts.newlines_in_values = True
+ assert opts.newlines_in_values is True
+
+ schema = pa.schema([pa.field('foo', pa.int32())])
+ opts.explicit_schema = schema
+ assert opts.explicit_schema == schema
+
+ assert opts.unexpected_field_behavior == "infer"
+ for value in ["ignore", "error", "infer"]:
+ opts.unexpected_field_behavior = value
+ assert opts.unexpected_field_behavior == value
+
+ with pytest.raises(ValueError):
+ opts.unexpected_field_behavior = "invalid-value"
+
+
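+# A minimal sketch (illustration only; never collected by pytest) tying the
+# options above to the high-level read_json() call: parse an in-memory buffer
+# with an explicit schema and single-threaded reads.
+def _read_json_options_sketch():
+    rows = b'{"a": 1, "b": "x"}\n{"a": 2, "b": "y"}\n'
+    schema = pa.schema([('a', pa.int64()), ('b', pa.string())])
+    table = read_json(pa.py_buffer(rows),
+                      read_options=ReadOptions(use_threads=False),
+                      parse_options=ParseOptions(explicit_schema=schema))
+    # both columns are covered by the explicit schema, so it is used as-is
+    assert table.schema == schema
+    assert table.to_pydict() == {'a': [1, 2], 'b': ['x', 'y']}
+
+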
+class BaseTestJSONRead:
+
+ def read_bytes(self, b, **kwargs):
+ return self.read_json(pa.py_buffer(b), **kwargs)
+
+ def check_names(self, table, names):
+ assert table.num_columns == len(names)
+ assert [c.name for c in table.columns] == names
+
+ def test_file_object(self):
+ data = b'{"a": 1, "b": 2}\n'
+ expected_data = {'a': [1], 'b': [2]}
+ bio = io.BytesIO(data)
+ table = self.read_json(bio)
+ assert table.to_pydict() == expected_data
+ # Text files not allowed
+ sio = io.StringIO(data.decode())
+ with pytest.raises(TypeError):
+ self.read_json(sio)
+
+ def test_block_sizes(self):
+ rows = b'{"a": 1}\n{"a": 2}\n{"a": 3}'
+ read_options = ReadOptions()
+ parse_options = ParseOptions()
+
+ for data in [rows, rows + b'\n']:
+ for newlines_in_values in [False, True]:
+ parse_options.newlines_in_values = newlines_in_values
+ read_options.block_size = 4
+ with pytest.raises(ValueError,
+ match="try to increase block size"):
+ self.read_bytes(data, read_options=read_options,
+ parse_options=parse_options)
+
+ # Validate reader behavior with various block sizes.
+ # There used to be bugs in this area.
+ for block_size in range(9, 20):
+ read_options.block_size = block_size
+ table = self.read_bytes(data, read_options=read_options,
+ parse_options=parse_options)
+ assert table.to_pydict() == {'a': [1, 2, 3]}
+
+ def test_no_newline_at_end(self):
+ rows = b'{"a": 1,"b": 2, "c": 3}\n{"a": 4,"b": 5, "c": 6}'
+ table = self.read_bytes(rows)
+ assert table.to_pydict() == {
+ 'a': [1, 4],
+ 'b': [2, 5],
+ 'c': [3, 6],
+ }
+
+ def test_simple_ints(self):
+ # Infer integer columns
+ rows = b'{"a": 1,"b": 2, "c": 3}\n{"a": 4,"b": 5, "c": 6}\n'
+ table = self.read_bytes(rows)
+ schema = pa.schema([('a', pa.int64()),
+ ('b', pa.int64()),
+ ('c', pa.int64())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [1, 4],
+ 'b': [2, 5],
+ 'c': [3, 6],
+ }
+
+ def test_simple_varied(self):
+ # Infer various kinds of data
+ rows = (b'{"a": 1,"b": 2, "c": "3", "d": false}\n'
+ b'{"a": 4.0, "b": -5, "c": "foo", "d": true}\n')
+ table = self.read_bytes(rows)
+ schema = pa.schema([('a', pa.float64()),
+ ('b', pa.int64()),
+ ('c', pa.string()),
+ ('d', pa.bool_())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [1.0, 4.0],
+ 'b': [2, -5],
+ 'c': ["3", "foo"],
+ 'd': [False, True],
+ }
+
+ def test_simple_nulls(self):
+ # Infer various kinds of data, with nulls
+ rows = (b'{"a": 1, "b": 2, "c": null, "d": null, "e": null}\n'
+ b'{"a": null, "b": -5, "c": "foo", "d": null, "e": true}\n'
+ b'{"a": 4.5, "b": null, "c": "nan", "d": null,"e": false}\n')
+ table = self.read_bytes(rows)
+ schema = pa.schema([('a', pa.float64()),
+ ('b', pa.int64()),
+ ('c', pa.string()),
+ ('d', pa.null()),
+ ('e', pa.bool_())])
+ assert table.schema == schema
+ assert table.to_pydict() == {
+ 'a': [1.0, None, 4.5],
+ 'b': [2, -5, None],
+ 'c': [None, "foo", "nan"],
+ 'd': [None, None, None],
+ 'e': [None, True, False],
+ }
+
+ def test_empty_lists(self):
+ # ARROW-10955: Infer list(null)
+ rows = b'{"a": []}'
+ table = self.read_bytes(rows)
+ schema = pa.schema([('a', pa.list_(pa.null()))])
+ assert table.schema == schema
+ assert table.to_pydict() == {'a': [[]]}
+
+ def test_empty_rows(self):
+ rows = b'{}\n{}\n'
+ table = self.read_bytes(rows)
+ schema = pa.schema([])
+ assert table.schema == schema
+ assert table.num_columns == 0
+ assert table.num_rows == 2
+
+ def test_reconcile_across_blocks(self):
+ # ARROW-12065: reconciling inferred types across blocks
+ first_row = b'{ }\n'
+ read_options = ReadOptions(block_size=len(first_row))
+ for next_rows, expected_pylist in [
+ (b'{"a": 0}', [None, 0]),
+ (b'{"a": []}', [None, []]),
+ (b'{"a": []}\n{"a": [[1]]}', [None, [], [[1]]]),
+ (b'{"a": {}}', [None, {}]),
+ (b'{"a": {}}\n{"a": {"b": {"c": 1}}}',
+ [None, {"b": None}, {"b": {"c": 1}}]),
+ ]:
+ table = self.read_bytes(first_row + next_rows,
+ read_options=read_options)
+ expected = {"a": expected_pylist}
+ assert table.to_pydict() == expected
+ # Check that the issue was exercised
+ assert table.column("a").num_chunks > 1
+
+ def test_explicit_schema_with_unexpected_behaviour(self):
+ # infer by default
+ rows = (b'{"foo": "bar", "num": 0}\n'
+ b'{"foo": "baz", "num": 1}\n')
+ schema = pa.schema([
+ ('foo', pa.binary())
+ ])
+
+ opts = ParseOptions(explicit_schema=schema)
+ table = self.read_bytes(rows, parse_options=opts)
+ assert table.schema == pa.schema([
+ ('foo', pa.binary()),
+ ('num', pa.int64())
+ ])
+ assert table.to_pydict() == {
+ 'foo': [b'bar', b'baz'],
+ 'num': [0, 1],
+ }
+
+ # ignore the unexpected fields
+ opts = ParseOptions(explicit_schema=schema,
+ unexpected_field_behavior="ignore")
+ table = self.read_bytes(rows, parse_options=opts)
+ assert table.schema == pa.schema([
+ ('foo', pa.binary()),
+ ])
+ assert table.to_pydict() == {
+ 'foo': [b'bar', b'baz'],
+ }
+
+ # raise error
+ opts = ParseOptions(explicit_schema=schema,
+ unexpected_field_behavior="error")
+ with pytest.raises(pa.ArrowInvalid,
+ match="JSON parse error: unexpected field"):
+ self.read_bytes(rows, parse_options=opts)
+
+ def test_small_random_json(self):
+ data, expected = make_random_json(num_cols=2, num_rows=10)
+ table = self.read_bytes(data)
+ assert table.schema == expected.schema
+ assert table.equals(expected)
+ assert table.to_pydict() == expected.to_pydict()
+
+ def test_stress_block_sizes(self):
+ # Test a number of small block sizes to stress block stitching
+ data_base, expected = make_random_json(num_cols=2, num_rows=100)
+ read_options = ReadOptions()
+ parse_options = ParseOptions()
+
+ for data in [data_base, data_base.rstrip(b'\r\n')]:
+ for newlines_in_values in [False, True]:
+ parse_options.newlines_in_values = newlines_in_values
+ for block_size in [22, 23, 37]:
+ read_options.block_size = block_size
+ table = self.read_bytes(data, read_options=read_options,
+ parse_options=parse_options)
+ assert table.schema == expected.schema
+ if not table.equals(expected):
+ # Better error output
+ assert table.to_pydict() == expected.to_pydict()
+
+
+class TestSerialJSONRead(BaseTestJSONRead, unittest.TestCase):
+
+ def read_json(self, *args, **kwargs):
+ read_options = kwargs.setdefault('read_options', ReadOptions())
+ read_options.use_threads = False
+ table = read_json(*args, **kwargs)
+ table.validate(full=True)
+ return table
+
+
+class TestParallelJSONRead(BaseTestJSONRead, unittest.TestCase):
+
+ def read_json(self, *args, **kwargs):
+ read_options = kwargs.setdefault('read_options', ReadOptions())
+ read_options.use_threads = True
+ table = read_json(*args, **kwargs)
+ table.validate(full=True)
+ return table
diff --git a/src/arrow/python/pyarrow/tests/test_jvm.py b/src/arrow/python/pyarrow/tests/test_jvm.py
new file mode 100644
index 000000000..c5996f921
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_jvm.py
@@ -0,0 +1,433 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import os
+import pyarrow as pa
+import pyarrow.jvm as pa_jvm
+import pytest
+import sys
+import xml.etree.ElementTree as ET
+
+
+jpype = pytest.importorskip("jpype")
+
+
+@pytest.fixture(scope="session")
+def root_allocator():
+ # This test requires Arrow Java to be built in the same source tree
+ try:
+ arrow_dir = os.environ["ARROW_SOURCE_DIR"]
+ except KeyError:
+ arrow_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..')
+ pom_path = os.path.join(arrow_dir, 'java', 'pom.xml')
+ tree = ET.parse(pom_path)
+ version = tree.getroot().find(
+ 'POM:version',
+ namespaces={
+ 'POM': 'http://maven.apache.org/POM/4.0.0'
+ }).text
+ jar_path = os.path.join(
+ arrow_dir, 'java', 'tools', 'target',
+ 'arrow-tools-{}-jar-with-dependencies.jar'.format(version))
+ jar_path = os.getenv("ARROW_TOOLS_JAR", jar_path)
+ kwargs = {}
+ # This will be the default behaviour in jpype 0.8+
+ kwargs['convertStrings'] = False
+ jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path,
+ **kwargs)
+ return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize)
+
+
+def test_jvm_buffer(root_allocator):
+ # Create a Java buffer
+ jvm_buffer = root_allocator.buffer(8)
+ for i in range(8):
+ jvm_buffer.setByte(i, 8 - i)
+
+ orig_refcnt = jvm_buffer.refCnt()
+
+ # Convert to Python
+ buf = pa_jvm.jvm_buffer(jvm_buffer)
+
+ # Check its content
+ assert buf.to_pybytes() == b'\x08\x07\x06\x05\x04\x03\x02\x01'
+
+ # Check Java buffer lifetime is tied to PyArrow buffer lifetime
+ assert jvm_buffer.refCnt() == orig_refcnt + 1
+ del buf
+ assert jvm_buffer.refCnt() == orig_refcnt
+
+
+def test_jvm_buffer_released(root_allocator):
+ import jpype.imports # noqa
+ from java.lang import IllegalArgumentException
+
+ jvm_buffer = root_allocator.buffer(8)
+ jvm_buffer.release()
+
+ with pytest.raises(IllegalArgumentException):
+ pa_jvm.jvm_buffer(jvm_buffer)
+
+
+def _jvm_field(jvm_spec):
+ om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
+ pojo_Field = jpype.JClass('org.apache.arrow.vector.types.pojo.Field')
+ return om.readValue(jvm_spec, pojo_Field)
+
+
+def _jvm_schema(jvm_spec, metadata=None):
+ field = _jvm_field(jvm_spec)
+ schema_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Schema')
+ fields = jpype.JClass('java.util.ArrayList')()
+ fields.add(field)
+ if metadata:
+ dct = jpype.JClass('java.util.HashMap')()
+ for k, v in metadata.items():
+ dct.put(k, v)
+ return schema_cls(fields, dct)
+ else:
+ return schema_cls(fields)
+
+
+# In the following, we use the JSON serialization of the Field objects in Java.
+# This ensures that we do not rely on the exact mechanics of constructing
+# them in Java code and also lets us define them as test parameters
+# without having to invoke the JVM.
+#
+# The specifications were created using:
+#
+# om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
+# field = … # Code to instantiate the field
+# jvm_spec = om.writeValueAsString(field)
+@pytest.mark.parametrize('pa_type,jvm_spec', [
+ (pa.null(), '{"name":"null"}'),
+ (pa.bool_(), '{"name":"bool"}'),
+ (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'),
+ (pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'),
+ (pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'),
+ (pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'),
+ (pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'),
+ (pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'),
+ (pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'),
+ (pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'),
+ (pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'),
+ (pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'),
+ (pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'),
+ (pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'),
+ (pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'),
+ (pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'),
+ (pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'),
+ (pa.timestamp('s'), '{"name":"timestamp","unit":"SECOND",'
+ '"timezone":null}'),
+ (pa.timestamp('ms'), '{"name":"timestamp","unit":"MILLISECOND",'
+ '"timezone":null}'),
+ (pa.timestamp('us'), '{"name":"timestamp","unit":"MICROSECOND",'
+ '"timezone":null}'),
+ (pa.timestamp('ns'), '{"name":"timestamp","unit":"NANOSECOND",'
+ '"timezone":null}'),
+ (pa.timestamp('ns', tz='UTC'), '{"name":"timestamp","unit":"NANOSECOND"'
+ ',"timezone":"UTC"}'),
+ (pa.timestamp('ns', tz='Europe/Paris'), '{"name":"timestamp",'
+ '"unit":"NANOSECOND","timezone":"Europe/Paris"}'),
+ (pa.date32(), '{"name":"date","unit":"DAY"}'),
+ (pa.date64(), '{"name":"date","unit":"MILLISECOND"}'),
+ (pa.decimal128(19, 4), '{"name":"decimal","precision":19,"scale":4}'),
+ (pa.string(), '{"name":"utf8"}'),
+ (pa.binary(), '{"name":"binary"}'),
+ (pa.binary(10), '{"name":"fixedsizebinary","byteWidth":10}'),
+ # TODO(ARROW-2609): complex types that have children
+ # pa.list_(pa.int32()),
+ # pa.struct([pa.field('a', pa.int32()),
+ # pa.field('b', pa.int8()),
+ # pa.field('c', pa.string())]),
+ # pa.union([pa.field('a', pa.binary(10)),
+ # pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
+ # pa.union([pa.field('a', pa.binary(10)),
+ # pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
+ # TODO: DictionaryType requires a vector in the type
+ # pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])),
+])
+@pytest.mark.parametrize('nullable', [True, False])
+def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable):
+ if pa_type == pa.null() and not nullable:
+ return
+ spec = {
+ 'name': 'field_name',
+ 'nullable': nullable,
+ 'type': json.loads(jvm_spec),
+ # TODO: This needs to be set for complex types
+ 'children': []
+ }
+ jvm_field = _jvm_field(json.dumps(spec))
+ result = pa_jvm.field(jvm_field)
+ expected_field = pa.field('field_name', pa_type, nullable=nullable)
+ assert result == expected_field
+
+ jvm_schema = _jvm_schema(json.dumps(spec))
+ result = pa_jvm.schema(jvm_schema)
+ assert result == pa.schema([expected_field])
+
+ # Schema with custom metadata
+ jvm_schema = _jvm_schema(json.dumps(spec), {'meta': 'data'})
+ result = pa_jvm.schema(jvm_schema)
+ assert result == pa.schema([expected_field], {'meta': 'data'})
+
+ # Schema with custom field metadata
+ spec['metadata'] = [{'key': 'field meta', 'value': 'field data'}]
+ jvm_schema = _jvm_schema(json.dumps(spec))
+ result = pa_jvm.schema(jvm_schema)
+ expected_field = expected_field.with_metadata(
+ {'field meta': 'field data'})
+ assert result == pa.schema([expected_field])
+
+
+# These test parameters mostly use an integer range as an input as this is
+# often the only type that is understood by both Python and Java
+# implementations of Arrow.
+@pytest.mark.parametrize('pa_type,py_data,jvm_type', [
+ (pa.bool_(), [True, False, True, True], 'BitVector'),
+ (pa.uint8(), list(range(128)), 'UInt1Vector'),
+ (pa.uint16(), list(range(128)), 'UInt2Vector'),
+ (pa.int32(), list(range(128)), 'IntVector'),
+ (pa.int64(), list(range(128)), 'BigIntVector'),
+ (pa.float32(), list(range(128)), 'Float4Vector'),
+ (pa.float64(), list(range(128)), 'Float8Vector'),
+ (pa.timestamp('s'), list(range(128)), 'TimeStampSecVector'),
+ (pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector'),
+ (pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector'),
+ (pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector'),
+ # TODO(ARROW-2605): These types miss a conversion from pure Python objects
+ # * pa.time32('s')
+ # * pa.time32('ms')
+ # * pa.time64('us')
+ # * pa.time64('ns')
+ (pa.date32(), list(range(128)), 'DateDayVector'),
+ (pa.date64(), list(range(128)), 'DateMilliVector'),
+ # TODO(ARROW-2606): pa.decimal128(19, 4)
+])
+def test_jvm_array(root_allocator, pa_type, py_data, jvm_type):
+ # Create vector
+ cls = "org.apache.arrow.vector.{}".format(jvm_type)
+ jvm_vector = jpype.JClass(cls)("vector", root_allocator)
+ jvm_vector.allocateNew(len(py_data))
+ for i, val in enumerate(py_data):
+ # char and int are ambiguous overloads for these two setSafe calls
+ if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
+ val = jpype.JInt(val)
+ jvm_vector.setSafe(i, val)
+ jvm_vector.setValueCount(len(py_data))
+
+ py_array = pa.array(py_data, type=pa_type)
+ jvm_array = pa_jvm.array(jvm_vector)
+
+ assert py_array.equals(jvm_array)
+
+
+def test_jvm_array_empty(root_allocator):
+ cls = "org.apache.arrow.vector.{}".format('IntVector')
+ jvm_vector = jpype.JClass(cls)("vector", root_allocator)
+ jvm_vector.allocateNew()
+ jvm_array = pa_jvm.array(jvm_vector)
+ assert len(jvm_array) == 0
+ assert jvm_array.type == pa.int32()
+
+
+# These test parameters mostly use an integer range as an input as this is
+# often the only type that is understood by both Python and Java
+# implementations of Arrow.
+@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [
+ # TODO: null
+ (pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'),
+ (
+ pa.uint8(),
+ list(range(128)),
+ 'UInt1Vector',
+ '{"name":"int","bitWidth":8,"isSigned":false}'
+ ),
+ (
+ pa.uint16(),
+ list(range(128)),
+ 'UInt2Vector',
+ '{"name":"int","bitWidth":16,"isSigned":false}'
+ ),
+ (
+ pa.uint32(),
+ list(range(128)),
+ 'UInt4Vector',
+ '{"name":"int","bitWidth":32,"isSigned":false}'
+ ),
+ (
+ pa.uint64(),
+ list(range(128)),
+ 'UInt8Vector',
+ '{"name":"int","bitWidth":64,"isSigned":false}'
+ ),
+ (
+ pa.int8(),
+ list(range(128)),
+ 'TinyIntVector',
+ '{"name":"int","bitWidth":8,"isSigned":true}'
+ ),
+ (
+ pa.int16(),
+ list(range(128)),
+ 'SmallIntVector',
+ '{"name":"int","bitWidth":16,"isSigned":true}'
+ ),
+ (
+ pa.int32(),
+ list(range(128)),
+ 'IntVector',
+ '{"name":"int","bitWidth":32,"isSigned":true}'
+ ),
+ (
+ pa.int64(),
+ list(range(128)),
+ 'BigIntVector',
+ '{"name":"int","bitWidth":64,"isSigned":true}'
+ ),
+ # TODO: float16
+ (
+ pa.float32(),
+ list(range(128)),
+ 'Float4Vector',
+ '{"name":"floatingpoint","precision":"SINGLE"}'
+ ),
+ (
+ pa.float64(),
+ list(range(128)),
+ 'Float8Vector',
+ '{"name":"floatingpoint","precision":"DOUBLE"}'
+ ),
+ (
+ pa.timestamp('s'),
+ list(range(128)),
+ 'TimeStampSecVector',
+ '{"name":"timestamp","unit":"SECOND","timezone":null}'
+ ),
+ (
+ pa.timestamp('ms'),
+ list(range(128)),
+ 'TimeStampMilliVector',
+ '{"name":"timestamp","unit":"MILLISECOND","timezone":null}'
+ ),
+ (
+ pa.timestamp('us'),
+ list(range(128)),
+ 'TimeStampMicroVector',
+ '{"name":"timestamp","unit":"MICROSECOND","timezone":null}'
+ ),
+ (
+ pa.timestamp('ns'),
+ list(range(128)),
+ 'TimeStampNanoVector',
+ '{"name":"timestamp","unit":"NANOSECOND","timezone":null}'
+ ),
+ # TODO(ARROW-2605): These types miss a conversion from pure Python objects
+ # * pa.time32('s')
+ # * pa.time32('ms')
+ # * pa.time64('us')
+ # * pa.time64('ns')
+ (
+ pa.date32(),
+ list(range(128)),
+ 'DateDayVector',
+ '{"name":"date","unit":"DAY"}'
+ ),
+ (
+ pa.date64(),
+ list(range(128)),
+ 'DateMilliVector',
+ '{"name":"date","unit":"MILLISECOND"}'
+ ),
+ # TODO(ARROW-2606): pa.decimal128(19, 4)
+])
+def test_jvm_record_batch(root_allocator, pa_type, py_data, jvm_type,
+ jvm_spec):
+ # Create vector
+ cls = "org.apache.arrow.vector.{}".format(jvm_type)
+ jvm_vector = jpype.JClass(cls)("vector", root_allocator)
+ jvm_vector.allocateNew(len(py_data))
+ for i, val in enumerate(py_data):
+ if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
+ val = jpype.JInt(val)
+ jvm_vector.setSafe(i, val)
+ jvm_vector.setValueCount(len(py_data))
+
+ # Create field
+ spec = {
+ 'name': 'field_name',
+ 'nullable': False,
+ 'type': json.loads(jvm_spec),
+ # TODO: This needs to be set for complex types
+ 'children': []
+ }
+ jvm_field = _jvm_field(json.dumps(spec))
+
+ # Create VectorSchemaRoot
+ jvm_fields = jpype.JClass('java.util.ArrayList')()
+ jvm_fields.add(jvm_field)
+ jvm_vectors = jpype.JClass('java.util.ArrayList')()
+ jvm_vectors.add(jvm_vector)
+ jvm_vsr = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')
+ jvm_vsr = jvm_vsr(jvm_fields, jvm_vectors, len(py_data))
+
+ py_record_batch = pa.RecordBatch.from_arrays(
+ [pa.array(py_data, type=pa_type)],
+ ['col']
+ )
+ jvm_record_batch = pa_jvm.record_batch(jvm_vsr)
+
+ assert py_record_batch.equals(jvm_record_batch)
+
+
+def _string_to_varchar_holder(ra, string):
+ nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
+ holder = jpype.JClass(nvch_cls)()
+ if string is None:
+ holder.isSet = 0
+ else:
+ holder.isSet = 1
+ value = jpype.JClass("java.lang.String")(string)
+ std_charsets = jpype.JClass("java.nio.charset.StandardCharsets")
+ bytes_ = value.getBytes(std_charsets.UTF_8)
+ holder.buffer = ra.buffer(len(bytes_))
+ holder.buffer.setBytes(0, bytes_, 0, len(bytes_))
+ holder.start = 0
+ holder.end = len(bytes_)
+ return holder
+
+
+# TODO(ARROW-2607)
+@pytest.mark.xfail(reason="from_buffers is not yet supported for "
+ "non-primitive arrays")
+def test_jvm_string_array(root_allocator):
+ data = ["string", None, "töst"]
+ cls = "org.apache.arrow.vector.VarCharVector"
+ jvm_vector = jpype.JClass(cls)("vector", root_allocator)
+ jvm_vector.allocateNew()
+
+ for i, string in enumerate(data):
+ holder = _string_to_varchar_holder(root_allocator, string)
+ jvm_vector.setSafe(i, holder)
+ jvm_vector.setValueCount(i + 1)
+
+ py_array = pa.array(data, type=pa.string())
+ jvm_array = pa_jvm.array(jvm_vector)
+
+ assert py_array.equals(jvm_array)
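+
+
+# The xfail above refers to pa.Array.from_buffers, which pyarrow.jvm relies
+# on to reassemble arrays from the Java buffers. A minimal sketch of the
+# primitive case that is supported, assuming a little-endian platform:
+#
+#   import pyarrow as pa
+#   buf = pa.py_buffer(b'\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00')
+#   arr = pa.Array.from_buffers(pa.int32(), 3, [None, buf])
+#   assert arr.to_pylist() == [1, 2, 3]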
diff --git a/src/arrow/python/pyarrow/tests/test_memory.py b/src/arrow/python/pyarrow/tests/test_memory.py
new file mode 100644
index 000000000..b8dd7344f
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_memory.py
@@ -0,0 +1,161 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import contextlib
+import os
+import subprocess
+import sys
+import weakref
+
+import pyarrow as pa
+
+
+possible_backends = ["system", "jemalloc", "mimalloc"]
+
+should_have_jemalloc = sys.platform == "linux"
+should_have_mimalloc = sys.platform == "win32"
+
+
+@contextlib.contextmanager
+def allocate_bytes(pool, nbytes):
+ """
+ Temporarily allocate *nbytes* from the given *pool*.
+ """
+ arr = pa.array([b"x" * nbytes], type=pa.binary(), memory_pool=pool)
+ # Fetch the values buffer from the varbinary array and release the rest,
+ # to get the desired allocation amount
+ buf = arr.buffers()[2]
+ arr = None
+ assert len(buf) == nbytes
+ try:
+ yield
+ finally:
+ buf = None
+
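+# For a variable-size binary array, Array.buffers() returns three buffers
+# (validity bitmap, int32 offsets, values), so buffers()[2] above is the
+# values buffer holding the raw data bytes. A quick sketch:
+#
+#   arr = pa.array([b"xy"], type=pa.binary())
+#   validity, offsets, values = arr.buffers()
+#   assert values.to_pybytes()[:2] == b"xy"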
+
+def check_allocated_bytes(pool):
+ """
+ Check allocation stats on *pool*.
+ """
+ allocated_before = pool.bytes_allocated()
+ max_mem_before = pool.max_memory()
+ with allocate_bytes(pool, 512):
+ assert pool.bytes_allocated() == allocated_before + 512
+ new_max_memory = pool.max_memory()
+ assert pool.max_memory() >= max_mem_before
+ assert pool.bytes_allocated() == allocated_before
+ assert pool.max_memory() == new_max_memory
+
+
+def test_default_allocated_bytes():
+ pool = pa.default_memory_pool()
+ with allocate_bytes(pool, 1024):
+ check_allocated_bytes(pool)
+ assert pool.bytes_allocated() == pa.total_allocated_bytes()
+
+
+def test_proxy_memory_pool():
+ pool = pa.proxy_memory_pool(pa.default_memory_pool())
+ check_allocated_bytes(pool)
+ wr = weakref.ref(pool)
+ assert wr() is not None
+ del pool
+ assert wr() is None
+
+
+def test_logging_memory_pool(capfd):
+ pool = pa.logging_memory_pool(pa.default_memory_pool())
+ check_allocated_bytes(pool)
+ out, err = capfd.readouterr()
+ assert err == ""
+ assert out.count("Allocate:") > 0
+ assert out.count("Allocate:") == out.count("Free:")
+
+
+def test_set_memory_pool():
+ old_pool = pa.default_memory_pool()
+ pool = pa.proxy_memory_pool(old_pool)
+ pa.set_memory_pool(pool)
+ try:
+ allocated_before = pool.bytes_allocated()
+ with allocate_bytes(None, 512):
+ assert pool.bytes_allocated() == allocated_before + 512
+ assert pool.bytes_allocated() == allocated_before
+ finally:
+ pa.set_memory_pool(old_pool)
+
+
+def test_default_backend_name():
+ pool = pa.default_memory_pool()
+ assert pool.backend_name in possible_backends
+
+
+def test_release_unused():
+ pool = pa.default_memory_pool()
+ pool.release_unused()
+
+
+def check_env_var(name, expected, *, expect_warning=False):
+ code = f"""if 1:
+ import pyarrow as pa
+
+ pool = pa.default_memory_pool()
+ assert pool.backend_name in {expected!r}, pool.backend_name
+ """
+ env = dict(os.environ)
+ env['ARROW_DEFAULT_MEMORY_POOL'] = name
+ res = subprocess.run([sys.executable, "-c", code], env=env,
+ universal_newlines=True, stderr=subprocess.PIPE)
+ if res.returncode != 0:
+ print(res.stderr, file=sys.stderr)
+ res.check_returncode() # fail
+ errlines = res.stderr.splitlines()
+ if expect_warning:
+ assert len(errlines) == 1
+ assert f"Unsupported backend '{name}'" in errlines[0]
+ else:
+ assert len(errlines) == 0
+
+
+def test_env_var():
+ check_env_var("system", ["system"])
+ if should_have_jemalloc:
+ check_env_var("jemalloc", ["jemalloc"])
+ if should_have_mimalloc:
+ check_env_var("mimalloc", ["mimalloc"])
+ check_env_var("nonexistent", possible_backends, expect_warning=True)
+
+
+def test_specific_memory_pools():
+ specific_pools = set()
+
+ def check(factory, name, *, can_fail=False):
+ if can_fail:
+ try:
+ pool = factory()
+ except NotImplementedError:
+ return
+ else:
+ pool = factory()
+ assert pool.backend_name == name
+ specific_pools.add(pool)
+
+ check(pa.system_memory_pool, "system")
+ check(pa.jemalloc_memory_pool, "jemalloc",
+ can_fail=not should_have_jemalloc)
+ check(pa.mimalloc_memory_pool, "mimalloc",
+ can_fail=not should_have_mimalloc)
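+
+
+# A minimal sketch of the MemoryPool statistics the helpers above rely on,
+# using only documented pyarrow APIs:
+#
+#   import pyarrow as pa
+#   pool = pa.default_memory_pool()
+#   buf = pa.allocate_buffer(512, memory_pool=pool)
+#   assert pool.bytes_allocated() >= 512
+#   del buf
+#   print(pool.backend_name, pool.max_memory())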
diff --git a/src/arrow/python/pyarrow/tests/test_misc.py b/src/arrow/python/pyarrow/tests/test_misc.py
new file mode 100644
index 000000000..012f15e16
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_misc.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import subprocess
+import sys
+
+import pytest
+
+import pyarrow as pa
+
+
+def test_get_include():
+ include_dir = pa.get_include()
+ assert os.path.exists(os.path.join(include_dir, 'arrow', 'api.h'))
+
+
+@pytest.mark.skipif('sys.platform != "win32"')
+def test_get_library_dirs_win32():
+ assert any(os.path.exists(os.path.join(directory, 'arrow.lib'))
+ for directory in pa.get_library_dirs())
+
+
+def test_cpu_count():
+ n = pa.cpu_count()
+ assert n > 0
+ try:
+ pa.set_cpu_count(n + 5)
+ assert pa.cpu_count() == n + 5
+ finally:
+ pa.set_cpu_count(n)
+
+
+def test_io_thread_count():
+ n = pa.io_thread_count()
+ assert n > 0
+ try:
+ pa.set_io_thread_count(n + 5)
+ assert pa.io_thread_count() == n + 5
+ finally:
+ pa.set_io_thread_count(n)
+
+
+def test_build_info():
+ assert isinstance(pa.cpp_build_info, pa.BuildInfo)
+ assert isinstance(pa.cpp_version_info, pa.VersionInfo)
+ assert isinstance(pa.cpp_version, str)
+ assert isinstance(pa.__version__, str)
+ assert pa.cpp_build_info.version_info == pa.cpp_version_info
+
+ # assert pa.version == pa.__version__ # XXX currently false
+
+
+def test_runtime_info():
+ info = pa.runtime_info()
+ assert isinstance(info, pa.RuntimeInfo)
+ possible_simd_levels = ('none', 'sse4_2', 'avx', 'avx2', 'avx512')
+ assert info.simd_level in possible_simd_levels
+ assert info.detected_simd_level in possible_simd_levels
+
+ if info.simd_level != 'none':
+ env = os.environ.copy()
+ env['ARROW_USER_SIMD_LEVEL'] = 'none'
+ code = f"""if 1:
+ import pyarrow as pa
+
+ info = pa.runtime_info()
+ assert info.simd_level == 'none', info.simd_level
+ assert info.detected_simd_level == {info.detected_simd_level!r},\
+ info.detected_simd_level
+ """
+ subprocess.check_call([sys.executable, "-c", code], env=env)
+
+
+@pytest.mark.parametrize('klass', [
+ pa.Field,
+ pa.Schema,
+ pa.ChunkedArray,
+ pa.RecordBatch,
+ pa.Table,
+ pa.Buffer,
+ pa.Array,
+ pa.Tensor,
+ pa.DataType,
+ pa.ListType,
+ pa.LargeListType,
+ pa.FixedSizeListType,
+ pa.UnionType,
+ pa.SparseUnionType,
+ pa.DenseUnionType,
+ pa.StructType,
+ pa.Time32Type,
+ pa.Time64Type,
+ pa.TimestampType,
+ pa.Decimal128Type,
+ pa.Decimal256Type,
+ pa.DictionaryType,
+ pa.FixedSizeBinaryType,
+ pa.NullArray,
+ pa.NumericArray,
+ pa.IntegerArray,
+ pa.FloatingPointArray,
+ pa.BooleanArray,
+ pa.Int8Array,
+ pa.Int16Array,
+ pa.Int32Array,
+ pa.Int64Array,
+ pa.UInt8Array,
+ pa.UInt16Array,
+ pa.UInt32Array,
+ pa.UInt64Array,
+ pa.ListArray,
+ pa.LargeListArray,
+ pa.MapArray,
+ pa.FixedSizeListArray,
+ pa.UnionArray,
+ pa.BinaryArray,
+ pa.StringArray,
+ pa.FixedSizeBinaryArray,
+ pa.DictionaryArray,
+ pa.Date32Array,
+ pa.Date64Array,
+ pa.TimestampArray,
+ pa.Time32Array,
+ pa.Time64Array,
+ pa.DurationArray,
+ pa.Decimal128Array,
+ pa.Decimal256Array,
+ pa.StructArray,
+ pa.Scalar,
+ pa.BooleanScalar,
+ pa.Int8Scalar,
+ pa.Int16Scalar,
+ pa.Int32Scalar,
+ pa.Int64Scalar,
+ pa.UInt8Scalar,
+ pa.UInt16Scalar,
+ pa.UInt32Scalar,
+ pa.UInt64Scalar,
+ pa.HalfFloatScalar,
+ pa.FloatScalar,
+ pa.DoubleScalar,
+ pa.Decimal128Scalar,
+ pa.Decimal256Scalar,
+ pa.Date32Scalar,
+ pa.Date64Scalar,
+ pa.Time32Scalar,
+ pa.Time64Scalar,
+ pa.TimestampScalar,
+ pa.DurationScalar,
+ pa.StringScalar,
+ pa.BinaryScalar,
+ pa.FixedSizeBinaryScalar,
+ pa.ListScalar,
+ pa.LargeListScalar,
+ pa.MapScalar,
+ pa.FixedSizeListScalar,
+ pa.UnionScalar,
+ pa.StructScalar,
+ pa.DictionaryScalar,
+ pa.ipc.Message,
+ pa.ipc.MessageReader,
+ pa.MemoryPool,
+ pa.LoggingMemoryPool,
+ pa.ProxyMemoryPool,
+])
+def test_extension_type_constructor_errors(klass):
+ # ARROW-2638: prevent calling extension class constructors directly
+ msg = "Do not call {cls}'s constructor directly, use .* instead."
+ with pytest.raises(TypeError, match=msg.format(cls=klass.__name__)):
+ klass()
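+
+
+# These classes are meant to be created through factory functions rather
+# than their constructors; a minimal sketch of the intended usage:
+#
+#   import pyarrow as pa
+#   ty = pa.int32()                  # not pa.DataType()
+#   arr = pa.array([1, 2, 3], ty)    # not pa.Int32Array()
+#   scalar = pa.scalar(1, ty)        # not pa.Int32Scalar()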
diff --git a/src/arrow/python/pyarrow/tests/test_orc.py b/src/arrow/python/pyarrow/tests/test_orc.py
new file mode 100644
index 000000000..f38121ffb
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_orc.py
@@ -0,0 +1,271 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import decimal
+import datetime
+
+import pyarrow as pa
+
+
+# Marks all of the tests in this module
+# Ignore these with pytest ... -m 'not orc'
+pytestmark = pytest.mark.orc
+
+
+try:
+ from pandas.testing import assert_frame_equal
+ import pandas as pd
+except ImportError:
+ pass
+
+
+@pytest.fixture(scope="module")
+def datadir(base_datadir):
+ return base_datadir / "orc"
+
+
+def fix_example_values(actual_cols, expected_cols):
+ """
+ Fix type of expected values (as read from JSON) according to
+ actual ORC datatype.
+ """
+ for name in expected_cols:
+ expected = expected_cols[name]
+ actual = actual_cols[name]
+ if (name == "map" and
+ [d.keys() == {'key', 'value'} for m in expected for d in m]):
+ # convert [{'key': k, 'value': v}, ...] to [(k, v), ...]
+ for i, m in enumerate(expected):
+ expected_cols[name][i] = [(d['key'], d['value']) for d in m]
+ continue
+
+ typ = actual[0].__class__
+ if issubclass(typ, datetime.datetime):
+ # timestamp fields are represented as strings in JSON files
+ expected = pd.to_datetime(expected)
+ elif issubclass(typ, datetime.date):
+ # date fields are represented as strings in JSON files
+ expected = expected.dt.date
+ elif typ is decimal.Decimal:
+ converted_decimals = [None] * len(expected)
+ # decimal fields are represented as reals in JSON files
+ for i, (d, v) in enumerate(zip(actual, expected)):
+ if not pd.isnull(v):
+ exp = d.as_tuple().exponent
+ factor = 10 ** -exp
+ converted_decimals[i] = (
+ decimal.Decimal(round(v * factor)).scaleb(exp))
+ expected = pd.Series(converted_decimals)
+
+ expected_cols[name] = expected
+
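+# Worked example of the decimal fix-up above: for an actual value of
+# Decimal('1.5000') (exponent -4) and a JSON value of 1.5, the expression
+# Decimal(round(1.5 * 10 ** 4)).scaleb(-4) yields Decimal('1.5000').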
+
+def check_example_values(orc_df, expected_df, start=None, stop=None):
+ if start is not None or stop is not None:
+ expected_df = expected_df[start:stop].reset_index(drop=True)
+ assert_frame_equal(orc_df, expected_df, check_dtype=False)
+
+
+def check_example_file(orc_path, expected_df, need_fix=False):
+ """
+ Check an ORC file against the expected data frame.
+ """
+ from pyarrow import orc
+
+ orc_file = orc.ORCFile(orc_path)
+ # Exercise ORCFile.read()
+ table = orc_file.read()
+ assert isinstance(table, pa.Table)
+ table.validate()
+
+ # This workaround is needed because of ARROW-3080
+ orc_df = pd.DataFrame(table.to_pydict())
+
+ assert set(expected_df.columns) == set(orc_df.columns)
+
+ # reorder columns if necessary
+ if not orc_df.columns.equals(expected_df.columns):
+ expected_df = expected_df.reindex(columns=orc_df.columns)
+
+ if need_fix:
+ fix_example_values(orc_df, expected_df)
+
+ check_example_values(orc_df, expected_df)
+ # Exercise ORCFile.read_stripe()
+ json_pos = 0
+ for i in range(orc_file.nstripes):
+ batch = orc_file.read_stripe(i)
+ check_example_values(pd.DataFrame(batch.to_pydict()),
+ expected_df,
+ start=json_pos,
+ stop=json_pos + len(batch))
+ json_pos += len(batch)
+ assert json_pos == orc_file.nrows
+
+
+@pytest.mark.pandas
+@pytest.mark.parametrize('filename', [
+ 'TestOrcFile.test1.orc',
+ 'TestOrcFile.testDate1900.orc',
+ 'decimal.orc'
+])
+def test_example_using_json(filename, datadir):
+ """
+ Check an ORC file example against the equivalent JSON file, as given
+ in the Apache ORC repository (the JSON file has one JSON object per
+ line, corresponding to one row in the ORC file).
+ """
+ # Read JSON file
+ path = datadir / filename
+ table = pd.read_json(str(path.with_suffix('.jsn.gz')), lines=True)
+ check_example_file(path, table, need_fix=True)
+
+
+def test_orcfile_empty(datadir):
+ from pyarrow import orc
+
+ table = orc.ORCFile(datadir / "TestOrcFile.emptyFile.orc").read()
+ assert table.num_rows == 0
+
+ expected_schema = pa.schema([
+ ("boolean1", pa.bool_()),
+ ("byte1", pa.int8()),
+ ("short1", pa.int16()),
+ ("int1", pa.int32()),
+ ("long1", pa.int64()),
+ ("float1", pa.float32()),
+ ("double1", pa.float64()),
+ ("bytes1", pa.binary()),
+ ("string1", pa.string()),
+ ("middle", pa.struct(
+ [("list", pa.list_(
+ pa.struct([("int1", pa.int32()),
+ ("string1", pa.string())])))
+ ])),
+ ("list", pa.list_(
+ pa.struct([("int1", pa.int32()),
+ ("string1", pa.string())])
+ )),
+ ("map", pa.map_(pa.string(),
+ pa.struct([("int1", pa.int32()),
+ ("string1", pa.string())])
+ )),
+ ])
+ assert table.schema == expected_schema
+
+
+def test_orcfile_readwrite():
+ from pyarrow import orc
+
+ buffer_output_stream = pa.BufferOutputStream()
+ a = pa.array([1, None, 3, None])
+ b = pa.array([None, "Arrow", None, "ORC"])
+ table = pa.table({"int64": a, "utf8": b})
+ orc.write_table(table, buffer_output_stream)
+ buffer_reader = pa.BufferReader(buffer_output_stream.getvalue())
+ orc_file = orc.ORCFile(buffer_reader)
+ output_table = orc_file.read()
+ assert table.equals(output_table)
+
+ # deprecated keyword order
+ buffer_output_stream = pa.BufferOutputStream()
+ with pytest.warns(FutureWarning):
+ orc.write_table(buffer_output_stream, table)
+ buffer_reader = pa.BufferReader(buffer_output_stream.getvalue())
+ orc_file = orc.ORCFile(buffer_reader)
+ output_table = orc_file.read()
+ assert table.equals(output_table)
+
+
+def test_column_selection(tempdir):
+ from pyarrow import orc
+
+ # create a table with nested types
+ inner = pa.field('inner', pa.int64())
+ middle = pa.field('middle', pa.struct([inner]))
+ fields = [
+ pa.field('basic', pa.int32()),
+ pa.field(
+ 'list', pa.list_(pa.field('item', pa.int32()))
+ ),
+ pa.field(
+ 'struct', pa.struct([middle, pa.field('inner2', pa.int64())])
+ ),
+ pa.field(
+ 'list-struct', pa.list_(pa.field(
+ 'item', pa.struct([
+ pa.field('inner1', pa.int64()),
+ pa.field('inner2', pa.int64())
+ ])
+ ))
+ ),
+ pa.field('basic2', pa.int64()),
+ ]
+ arrs = [
+ [0], [[1, 2]], [{"middle": {"inner": 3}, "inner2": 4}],
+ [[{"inner1": 5, "inner2": 6}, {"inner1": 7, "inner2": 8}]], [9]]
+ table = pa.table(arrs, schema=pa.schema(fields))
+
+ path = str(tempdir / 'test.orc')
+ orc.write_table(table, path)
+ orc_file = orc.ORCFile(path)
+
+ # default selecting all columns
+ result1 = orc_file.read()
+ assert result1.equals(table)
+
+ # selecting with column names
+ result2 = orc_file.read(columns=["basic", "basic2"])
+ assert result2.equals(table.select(["basic", "basic2"]))
+
+ result3 = orc_file.read(columns=["list", "struct", "basic2"])
+ assert result3.equals(table.select(["list", "struct", "basic2"]))
+
+ # using dotted paths
+ result4 = orc_file.read(columns=["struct.middle.inner"])
+ expected4 = pa.table({"struct": [{"middle": {"inner": 3}}]})
+ assert result4.equals(expected4)
+
+ result5 = orc_file.read(columns=["struct.inner2"])
+ expected5 = pa.table({"struct": [{"inner2": 4}]})
+ assert result5.equals(expected5)
+
+ result6 = orc_file.read(
+ columns=["list", "struct.middle.inner", "struct.inner2"]
+ )
+ assert result6.equals(table.select(["list", "struct"]))
+
+ result7 = orc_file.read(columns=["list-struct.inner1"])
+ expected7 = pa.table({"list-struct": [[{"inner1": 5}, {"inner1": 7}]]})
+ assert result7.equals(expected7)
+
+ # selecting with (Arrow-based) field indices
+ result2 = orc_file.read(columns=[0, 4])
+ assert result2.equals(table.select(["basic", "basic2"]))
+
+ result3 = orc_file.read(columns=[1, 2, 3])
+ assert result3.equals(table.select(["list", "struct", "list-struct"]))
+
+ # error on non-existing name or index
+ with pytest.raises(IOError):
+ # liborc returns ParseError, which gets translated into IOError
+ # instead of ValueError
+ orc_file.read(columns=["wrong"])
+
+ with pytest.raises(ValueError):
+ orc_file.read(columns=[5])
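+
+
+# Note on the index-based selection above: the indices refer to positions in
+# the Arrow schema written to the file, so with five top-level fields index 0
+# maps to 'basic' and 4 to 'basic2', which is why columns=[5] is rejected
+# with ValueError.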
diff --git a/src/arrow/python/pyarrow/tests/test_pandas.py b/src/arrow/python/pyarrow/tests/test_pandas.py
new file mode 100644
index 000000000..112c7938e
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_pandas.py
@@ -0,0 +1,4386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import gc
+import decimal
+import json
+import multiprocessing as mp
+import sys
+
+from collections import OrderedDict
+from datetime import date, datetime, time, timedelta, timezone
+
+import hypothesis as h
+import hypothesis.extra.pytz as tzst
+import hypothesis.strategies as st
+import numpy as np
+import numpy.testing as npt
+import pytest
+import pytz
+
+from pyarrow.pandas_compat import get_logical_type, _pandas_api
+from pyarrow.tests.util import invoke_script, random_ascii, rands
+import pyarrow.tests.strategies as past
+from pyarrow.vendored.version import Version
+
+import pyarrow as pa
+try:
+ from pyarrow import parquet as pq
+except ImportError:
+ pass
+
+try:
+ import pandas as pd
+ import pandas.testing as tm
+ from .pandas_examples import dataframe_with_arrays, dataframe_with_lists
+except ImportError:
+ pass
+
+
+# Marks all of the tests in this module
+pytestmark = pytest.mark.pandas
+
+
+def _alltypes_example(size=100):
+ return pd.DataFrame({
+ 'uint8': np.arange(size, dtype=np.uint8),
+ 'uint16': np.arange(size, dtype=np.uint16),
+ 'uint32': np.arange(size, dtype=np.uint32),
+ 'uint64': np.arange(size, dtype=np.uint64),
+ 'int8': np.arange(size, dtype=np.int8),
+ 'int16': np.arange(size, dtype=np.int16),
+ 'int32': np.arange(size, dtype=np.int32),
+ 'int64': np.arange(size, dtype=np.int64),
+ 'float32': np.arange(size, dtype=np.float32),
+ 'float64': np.arange(size, dtype=np.float64),
+ 'bool': np.random.randn(size) > 0,
+ # TODO(wesm): pandas only supports ns resolution; Arrow supports s, ms,
+ # us and ns
+ 'datetime': np.arange("2016-01-01T00:00:00.001", size,
+ dtype='datetime64[ms]'),
+ 'str': [str(x) for x in range(size)],
+ 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None],
+ 'empty_str': [''] * size
+ })
+
+
+def _check_pandas_roundtrip(df, expected=None, use_threads=False,
+ expected_schema=None,
+ check_dtype=True, schema=None,
+ preserve_index=False,
+ as_batch=False):
+ klass = pa.RecordBatch if as_batch else pa.Table
+ table = klass.from_pandas(df, schema=schema,
+ preserve_index=preserve_index,
+ nthreads=2 if use_threads else 1)
+ result = table.to_pandas(use_threads=use_threads)
+
+ if expected_schema:
+ # all callers of _check_pandas_roundtrip pass expected_schema
+ # without the pandas-generated key-value metadata
+ assert table.schema.equals(expected_schema)
+
+ if expected is None:
+ expected = df
+
+ tm.assert_frame_equal(result, expected, check_dtype=check_dtype,
+ check_index_type=('equiv' if preserve_index
+ else False))
+
+
+def _check_series_roundtrip(s, type_=None, expected_pa_type=None):
+ arr = pa.array(s, from_pandas=True, type=type_)
+
+ if type_ is not None and expected_pa_type is None:
+ expected_pa_type = type_
+
+ if expected_pa_type is not None:
+ assert arr.type == expected_pa_type
+
+ result = pd.Series(arr.to_pandas(), name=s.name)
+ tm.assert_series_equal(s, result)
+
+
+def _check_array_roundtrip(values, expected=None, mask=None,
+ type=None):
+ arr = pa.array(values, from_pandas=True, mask=mask, type=type)
+ result = arr.to_pandas()
+
+ values_nulls = pd.isnull(values)
+ if mask is None:
+ assert arr.null_count == values_nulls.sum()
+ else:
+ assert arr.null_count == (mask | values_nulls).sum()
+
+ if expected is None:
+ if mask is None:
+ expected = pd.Series(values)
+ else:
+ expected = pd.Series(np.ma.masked_array(values, mask=mask))
+
+ tm.assert_series_equal(pd.Series(result), expected, check_names=False)
+
+
+def _check_array_from_pandas_roundtrip(np_array, type=None):
+ arr = pa.array(np_array, from_pandas=True, type=type)
+ result = arr.to_pandas()
+ npt.assert_array_equal(result, np_array)
+
+
+class TestConvertMetadata:
+ """
+ Conversion tests for Pandas metadata & indices.
+ """
+
+ def test_non_string_columns(self):
+ df = pd.DataFrame({0: [1, 2, 3]})
+ table = pa.Table.from_pandas(df)
+ assert table.field(0).name == '0'
+
+ def test_from_pandas_with_columns(self):
+ df = pd.DataFrame({0: [1, 2, 3], 1: [1, 3, 3], 2: [2, 4, 5]},
+ columns=[1, 0])
+
+ table = pa.Table.from_pandas(df, columns=[0, 1])
+ expected = pa.Table.from_pandas(df[[0, 1]])
+ assert expected.equals(table)
+
+ record_batch_table = pa.RecordBatch.from_pandas(df, columns=[0, 1])
+ record_batch_expected = pa.RecordBatch.from_pandas(df[[0, 1]])
+ assert record_batch_expected.equals(record_batch_table)
+
+ def test_column_index_names_are_preserved(self):
+ df = pd.DataFrame({'data': [1, 2, 3]})
+ df.columns.names = ['a']
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_range_index_shortcut(self):
+ # ARROW-1639
+ index_name = 'foo'
+ df = pd.DataFrame({'a': [1, 2, 3, 4]},
+ index=pd.RangeIndex(0, 8, step=2, name=index_name))
+
+ df2 = pd.DataFrame({'a': [4, 5, 6, 7]},
+ index=pd.RangeIndex(0, 4))
+
+ table = pa.Table.from_pandas(df)
+ table_no_index_name = pa.Table.from_pandas(df2)
+
+ # The RangeIndex is tracked in the metadata only
+ assert len(table.schema) == 1
+
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, df)
+ assert isinstance(result.index, pd.RangeIndex)
+ assert _pandas_api.get_rangeindex_attribute(result.index, 'step') == 2
+ assert result.index.name == index_name
+
+ result2 = table_no_index_name.to_pandas()
+ tm.assert_frame_equal(result2, df2)
+ assert isinstance(result2.index, pd.RangeIndex)
+ assert _pandas_api.get_rangeindex_attribute(result2.index, 'step') == 1
+ assert result2.index.name is None
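+
+ # The RangeIndex is described in the pandas metadata rather than stored
+ # as a physical column; a minimal sketch of that entry for df2 (format as
+ # defined by the pandas metadata spec):
+ #
+ #   md = pa.Table.from_pandas(df2).schema.pandas_metadata
+ #   # md['index_columns'] == [{'kind': 'range', 'name': None,
+ #   #                          'start': 0, 'stop': 4, 'step': 1}]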
+
+ def test_range_index_force_serialization(self):
+ # ARROW-5427: preserve_index=True will force the RangeIndex to
+ # be serialized as a column rather than tracked more
+ # efficiently as metadata
+ df = pd.DataFrame({'a': [1, 2, 3, 4]},
+ index=pd.RangeIndex(0, 8, step=2, name='foo'))
+
+ table = pa.Table.from_pandas(df, preserve_index=True)
+ assert table.num_columns == 2
+ assert 'foo' in table.column_names
+
+ restored = table.to_pandas()
+ tm.assert_frame_equal(restored, df)
+
+ def test_rangeindex_doesnt_warn(self):
+ # ARROW-5606: pandas 0.25 deprecated the private _start/_stop/_step
+ # attributes; this can be removed once support for pandas < 0.25 is dropped
+ df = pd.DataFrame(np.random.randn(4, 2), columns=['a', 'b'])
+
+ with pytest.warns(None) as record:
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ assert len(record) == 0
+
+ def test_multiindex_columns(self):
+ columns = pd.MultiIndex.from_arrays([
+ ['one', 'two'], ['X', 'Y']
+ ])
+ df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns)
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_multiindex_columns_with_dtypes(self):
+ columns = pd.MultiIndex.from_arrays(
+ [
+ ['one', 'two'],
+ pd.DatetimeIndex(['2017-08-01', '2017-08-02']),
+ ],
+ names=['level_1', 'level_2'],
+ )
+ df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns)
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_multiindex_with_column_dtype_object(self):
+ # ARROW-3651 & ARROW-9096
+ # Bug when dtype of the columns is object.
+
+ # underlying dtype: integer
+ df = pd.DataFrame([1], columns=pd.Index([1], dtype=object))
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ # underlying dtype: floating
+ df = pd.DataFrame([1], columns=pd.Index([1.1], dtype=object))
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ # underlying dtype: datetime
+ # ARROW-9096: a simple roundtrip now works
+ df = pd.DataFrame([1], columns=pd.Index(
+ [datetime(2018, 1, 1)], dtype="object"))
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_multiindex_columns_unicode(self):
+ columns = pd.MultiIndex.from_arrays([['あ', 'い'], ['X', 'Y']])
+ df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns)
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_multiindex_doesnt_warn(self):
+ # ARROW-3953: pandas 0.24 renamed MultiIndex .labels to .codes
+ columns = pd.MultiIndex.from_arrays([['one', 'two'], ['X', 'Y']])
+ df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns)
+
+ with pytest.warns(None) as record:
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ assert len(record) == 0
+
+ def test_integer_index_column(self):
+ df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')])
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_index_metadata_field_name(self):
+ # test None case, and strangely named non-index columns
+ df = pd.DataFrame(
+ [(1, 'a', 3.1), (2, 'b', 2.2), (3, 'c', 1.3)],
+ index=pd.MultiIndex.from_arrays(
+ [['c', 'b', 'a'], [3, 2, 1]],
+ names=[None, 'foo']
+ ),
+ columns=['a', None, '__index_level_0__'],
+ )
+ with pytest.warns(UserWarning):
+ t = pa.Table.from_pandas(df, preserve_index=True)
+ js = t.schema.pandas_metadata
+
+ col1, col2, col3, idx0, foo = js['columns']
+
+ assert col1['name'] == 'a'
+ assert col1['name'] == col1['field_name']
+
+ assert col2['name'] is None
+ assert col2['field_name'] == 'None'
+
+ assert col3['name'] == '__index_level_0__'
+ assert col3['name'] == col3['field_name']
+
+ idx0_descr, foo_descr = js['index_columns']
+ assert idx0_descr == '__index_level_0__'
+ assert idx0['field_name'] == idx0_descr
+ assert idx0['name'] is None
+
+ assert foo_descr == 'foo'
+ assert foo['field_name'] == foo_descr
+ assert foo['name'] == foo_descr
+
+ def test_categorical_column_index(self):
+ df = pd.DataFrame(
+ [(1, 'a', 2.0), (2, 'b', 3.0), (3, 'c', 4.0)],
+ columns=pd.Index(list('def'), dtype='category')
+ )
+ t = pa.Table.from_pandas(df, preserve_index=True)
+ js = t.schema.pandas_metadata
+
+ column_indexes, = js['column_indexes']
+ assert column_indexes['name'] is None
+ assert column_indexes['pandas_type'] == 'categorical'
+ assert column_indexes['numpy_type'] == 'int8'
+
+ md = column_indexes['metadata']
+ assert md['num_categories'] == 3
+ assert md['ordered'] is False
+
+ def test_string_column_index(self):
+ df = pd.DataFrame(
+ [(1, 'a', 2.0), (2, 'b', 3.0), (3, 'c', 4.0)],
+ columns=pd.Index(list('def'), name='stringz')
+ )
+ t = pa.Table.from_pandas(df, preserve_index=True)
+ js = t.schema.pandas_metadata
+
+ column_indexes, = js['column_indexes']
+ assert column_indexes['name'] == 'stringz'
+ assert column_indexes['name'] == column_indexes['field_name']
+ assert column_indexes['numpy_type'] == 'object'
+ assert column_indexes['pandas_type'] == 'unicode'
+
+ md = column_indexes['metadata']
+
+ assert len(md) == 1
+ assert md['encoding'] == 'UTF-8'
+
+ def test_datetimetz_column_index(self):
+ df = pd.DataFrame(
+ [(1, 'a', 2.0), (2, 'b', 3.0), (3, 'c', 4.0)],
+ columns=pd.date_range(
+ start='2017-01-01', periods=3, tz='America/New_York'
+ )
+ )
+ t = pa.Table.from_pandas(df, preserve_index=True)
+ js = t.schema.pandas_metadata
+
+ column_indexes, = js['column_indexes']
+ assert column_indexes['name'] is None
+ assert column_indexes['pandas_type'] == 'datetimetz'
+ assert column_indexes['numpy_type'] == 'datetime64[ns]'
+
+ md = column_indexes['metadata']
+ assert md['timezone'] == 'America/New_York'
+
+ def test_datetimetz_row_index(self):
+ df = pd.DataFrame({
+ 'a': pd.date_range(
+ start='2017-01-01', periods=3, tz='America/New_York'
+ )
+ })
+ df = df.set_index('a')
+
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_categorical_row_index(self):
+ df = pd.DataFrame({'a': [1, 2, 3], 'b': [1, 2, 3]})
+ df['a'] = df.a.astype('category')
+ df = df.set_index('a')
+
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_duplicate_column_names_does_not_crash(self):
+ df = pd.DataFrame([(1, 'a'), (2, 'b')], columns=list('aa'))
+ with pytest.raises(ValueError):
+ pa.Table.from_pandas(df)
+
+ def test_dictionary_indices_boundscheck(self):
+ # ARROW-1658. No validation of indices leads to segfaults in pandas
+ indices = [[0, 1], [0, -1]]
+
+ for inds in indices:
+ arr = pa.DictionaryArray.from_arrays(inds, ['a'], safe=False)
+ batch = pa.RecordBatch.from_arrays([arr], ['foo'])
+ table = pa.Table.from_batches([batch, batch, batch])
+
+ with pytest.raises(IndexError):
+ arr.to_pandas()
+
+ with pytest.raises(IndexError):
+ table.to_pandas()
+
+ def test_unicode_with_unicode_column_and_index(self):
+ df = pd.DataFrame({'あ': ['い']}, index=['う'])
+
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_mixed_column_names(self):
+ # mixed type column names are not reconstructed exactly
+ df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
+
+ for cols in [['あ', b'a'], [1, '2'], [1, 1.5]]:
+ df.columns = pd.Index(cols, dtype=object)
+
+ # assert that the from_pandas raises the warning
+ with pytest.warns(UserWarning):
+ pa.Table.from_pandas(df)
+
+ expected = df.copy()
+ expected.columns = df.columns.values.astype(str)
+ with pytest.warns(UserWarning):
+ _check_pandas_roundtrip(df, expected=expected,
+ preserve_index=True)
+
+ def test_binary_column_name(self):
+ column_data = ['い']
+ key = 'あ'.encode()
+ data = {key: column_data}
+ df = pd.DataFrame(data)
+
+ # we can't use _check_pandas_roundtrip here because our metadata
+ # is always decoded as utf8: even if binary goes in, utf8 comes out
+ t = pa.Table.from_pandas(df, preserve_index=True)
+ df2 = t.to_pandas()
+ assert df.values[0] == df2.values[0]
+ assert df.index.values[0] == df2.index.values[0]
+ assert df.columns[0] == key
+
+ def test_multiindex_duplicate_values(self):
+ num_rows = 3
+ numbers = list(range(num_rows))
+ index = pd.MultiIndex.from_arrays(
+ [['foo', 'foo', 'bar'], numbers],
+ names=['foobar', 'some_numbers'],
+ )
+
+ df = pd.DataFrame({'numbers': numbers}, index=index)
+
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+ def test_metadata_with_mixed_types(self):
+ df = pd.DataFrame({'data': [b'some_bytes', 'some_unicode']})
+ table = pa.Table.from_pandas(df)
+ js = table.schema.pandas_metadata
+ assert 'mixed' not in js
+ data_column = js['columns'][0]
+ assert data_column['pandas_type'] == 'bytes'
+ assert data_column['numpy_type'] == 'object'
+
+ def test_ignore_metadata(self):
+ df = pd.DataFrame({'a': [1, 2, 3], 'b': ['foo', 'bar', 'baz']},
+ index=['one', 'two', 'three'])
+ table = pa.Table.from_pandas(df)
+
+ result = table.to_pandas(ignore_metadata=True)
+ expected = (table.cast(table.schema.remove_metadata())
+ .to_pandas())
+
+ tm.assert_frame_equal(result, expected)
+
+ def test_list_metadata(self):
+ df = pd.DataFrame({'data': [[1], [2, 3, 4], [5] * 7]})
+ schema = pa.schema([pa.field('data', type=pa.list_(pa.int64()))])
+ table = pa.Table.from_pandas(df, schema=schema)
+ js = table.schema.pandas_metadata
+ assert 'mixed' not in js
+ data_column = js['columns'][0]
+ assert data_column['pandas_type'] == 'list[int64]'
+ assert data_column['numpy_type'] == 'object'
+
+ def test_struct_metadata(self):
+ df = pd.DataFrame({'dicts': [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]})
+ table = pa.Table.from_pandas(df)
+ pandas_metadata = table.schema.pandas_metadata
+ assert pandas_metadata['columns'][0]['pandas_type'] == 'object'
+
+ def test_decimal_metadata(self):
+ expected = pd.DataFrame({
+ 'decimals': [
+ decimal.Decimal('394092382910493.12341234678'),
+ -decimal.Decimal('314292388910493.12343437128'),
+ ]
+ })
+ table = pa.Table.from_pandas(expected)
+ js = table.schema.pandas_metadata
+ assert 'mixed' not in js
+ data_column = js['columns'][0]
+ assert data_column['pandas_type'] == 'decimal'
+ assert data_column['numpy_type'] == 'object'
+ assert data_column['metadata'] == {'precision': 26, 'scale': 11}
+
+ def test_table_column_subset_metadata(self):
+ # ARROW-1883
+ # non-default index
+ for index in [
+ pd.Index(['a', 'b', 'c'], name='index'),
+ pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')]:
+ df = pd.DataFrame({'a': [1, 2, 3],
+ 'b': [.1, .2, .3]}, index=index)
+ table = pa.Table.from_pandas(df)
+
+ table_subset = table.remove_column(1)
+ result = table_subset.to_pandas()
+ expected = df[['a']]
+ if isinstance(df.index, pd.DatetimeIndex):
+ df.index.freq = None
+ tm.assert_frame_equal(result, expected)
+
+ table_subset2 = table_subset.remove_column(1)
+ result = table_subset2.to_pandas()
+ tm.assert_frame_equal(result, df[['a']].reset_index(drop=True))
+
+ def test_to_pandas_column_subset_multiindex(self):
+ # ARROW-10122
+ df = pd.DataFrame(
+ {"first": list(range(5)),
+ "second": list(range(5)),
+ "value": np.arange(5)}
+ )
+ table = pa.Table.from_pandas(df.set_index(["first", "second"]))
+
+ subset = table.select(["first", "value"])
+ result = subset.to_pandas()
+ expected = df[["first", "value"]].set_index("first")
+ tm.assert_frame_equal(result, expected)
+
+ def test_empty_list_metadata(self):
+ # Create table with array of empty lists, forced to have type
+ # list(string) in pyarrow
+ c1 = [["test"], ["a", "b"], None]
+ c2 = [[], [], []]
+ arrays = OrderedDict([
+ ('c1', pa.array(c1, type=pa.list_(pa.string()))),
+ ('c2', pa.array(c2, type=pa.list_(pa.string()))),
+ ])
+ rb = pa.RecordBatch.from_arrays(
+ list(arrays.values()),
+ list(arrays.keys())
+ )
+ tbl = pa.Table.from_batches([rb])
+
+ # First roundtrip changes schema, because pandas cannot preserve the
+ # type of empty lists
+ df = tbl.to_pandas()
+ tbl2 = pa.Table.from_pandas(df)
+ md2 = tbl2.schema.pandas_metadata
+
+ # Second roundtrip
+ df2 = tbl2.to_pandas()
+ expected = pd.DataFrame(OrderedDict([('c1', c1), ('c2', c2)]))
+
+ tm.assert_frame_equal(df2, expected)
+
+ assert md2['columns'] == [
+ {
+ 'name': 'c1',
+ 'field_name': 'c1',
+ 'metadata': None,
+ 'numpy_type': 'object',
+ 'pandas_type': 'list[unicode]',
+ },
+ {
+ 'name': 'c2',
+ 'field_name': 'c2',
+ 'metadata': None,
+ 'numpy_type': 'object',
+ 'pandas_type': 'list[empty]',
+ }
+ ]
+
+ def test_metadata_pandas_version(self):
+ df = pd.DataFrame({'a': [1, 2, 3], 'b': [1, 2, 3]})
+ table = pa.Table.from_pandas(df)
+ assert table.schema.pandas_metadata['pandas_version'] is not None
+
+ def test_mismatch_metadata_schema(self):
+ # ARROW-10511
+ # It is possible that the metadata and the actual schema do not fully
+ # match (e.g. no timezone information for a tz-aware column)
+ # -> the to_pandas() conversion should not fail on that
+ df = pd.DataFrame({"datetime": pd.date_range("2020-01-01", periods=3)})
+
+ # OPTION 1: casting after conversion
+ table = pa.Table.from_pandas(df)
+ # cast the "datetime" column to be tz-aware
+ new_col = table["datetime"].cast(pa.timestamp('ns', tz="UTC"))
+ new_table1 = table.set_column(
+ 0, pa.field("datetime", new_col.type), new_col
+ )
+
+ # OPTION 2: specify schema during conversion
+ schema = pa.schema([("datetime", pa.timestamp('ns', tz="UTC"))])
+ new_table2 = pa.Table.from_pandas(df, schema=schema)
+
+ expected = df.copy()
+ expected["datetime"] = expected["datetime"].dt.tz_localize("UTC")
+
+ for new_table in [new_table1, new_table2]:
+ # ensure the new table still has the pandas metadata
+ assert new_table.schema.pandas_metadata is not None
+ # convert to pandas
+ result = new_table.to_pandas()
+ tm.assert_frame_equal(result, expected)
+
+
+class TestConvertPrimitiveTypes:
+ """
+ Conversion tests for primitive (e.g. numeric) types.
+ """
+
+ def test_float_no_nulls(self):
+ data = {}
+ fields = []
+ dtypes = [('f2', pa.float16()),
+ ('f4', pa.float32()),
+ ('f8', pa.float64())]
+ num_values = 100
+
+ for numpy_dtype, arrow_dtype in dtypes:
+ values = np.random.randn(num_values)
+ data[numpy_dtype] = values.astype(numpy_dtype)
+ fields.append(pa.field(numpy_dtype, arrow_dtype))
+
+ df = pd.DataFrame(data)
+ schema = pa.schema(fields)
+ _check_pandas_roundtrip(df, expected_schema=schema)
+
+ def test_float_nulls(self):
+ num_values = 100
+
+ null_mask = np.random.randint(0, 10, size=num_values) < 3
+ dtypes = [('f2', pa.float16()),
+ ('f4', pa.float32()),
+ ('f8', pa.float64())]
+ names = ['f2', 'f4', 'f8']
+ expected_cols = []
+
+ arrays = []
+ fields = []
+ for name, arrow_dtype in dtypes:
+ values = np.random.randn(num_values).astype(name)
+
+ arr = pa.array(values, from_pandas=True, mask=null_mask)
+ arrays.append(arr)
+ fields.append(pa.field(name, arrow_dtype))
+ values[null_mask] = np.nan
+
+ expected_cols.append(values)
+
+ ex_frame = pd.DataFrame(dict(zip(names, expected_cols)),
+ columns=names)
+
+ table = pa.Table.from_arrays(arrays, names)
+ assert table.schema.equals(pa.schema(fields))
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, ex_frame)
+
+ def test_float_nulls_to_ints(self):
+ # ARROW-2135
+ df = pd.DataFrame({"a": [1.0, 2.0, np.NaN]})
+ schema = pa.schema([pa.field("a", pa.int16(), nullable=True)])
+ table = pa.Table.from_pandas(df, schema=schema, safe=False)
+ assert table[0].to_pylist() == [1, 2, None]
+ tm.assert_frame_equal(df, table.to_pandas())
+
+ def test_float_nulls_to_boolean(self):
+ s = pd.Series([0.0, 1.0, 2.0, None, -3.0])
+ expected = pd.Series([False, True, True, None, True])
+ _check_array_roundtrip(s, expected=expected, type=pa.bool_())
+
+ def test_series_from_pandas_false_respected(self):
+ # Check that explicit from_pandas=False is respected
+ s = pd.Series([0.0, np.nan])
+ arr = pa.array(s, from_pandas=False)
+ assert arr.null_count == 0
+ assert np.isnan(arr[1].as_py())
+
+ def test_integer_no_nulls(self):
+ data = OrderedDict()
+ fields = []
+
+ numpy_dtypes = [
+ ('i1', pa.int8()), ('i2', pa.int16()),
+ ('i4', pa.int32()), ('i8', pa.int64()),
+ ('u1', pa.uint8()), ('u2', pa.uint16()),
+ ('u4', pa.uint32()), ('u8', pa.uint64()),
+ ('longlong', pa.int64()), ('ulonglong', pa.uint64())
+ ]
+ num_values = 100
+
+ for dtype, arrow_dtype in numpy_dtypes:
+ info = np.iinfo(dtype)
+ values = np.random.randint(max(info.min, np.iinfo(np.int_).min),
+ min(info.max, np.iinfo(np.int_).max),
+ size=num_values)
+ data[dtype] = values.astype(dtype)
+ fields.append(pa.field(dtype, arrow_dtype))
+
+ df = pd.DataFrame(data)
+ schema = pa.schema(fields)
+ _check_pandas_roundtrip(df, expected_schema=schema)
+
+ def test_all_integer_types(self):
+ # Test all Numpy integer aliases
+ data = OrderedDict()
+ numpy_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8',
+ 'byte', 'ubyte', 'short', 'ushort', 'intc', 'uintc',
+ 'int_', 'uint', 'longlong', 'ulonglong']
+ for dtype in numpy_dtypes:
+ data[dtype] = np.arange(12, dtype=dtype)
+ df = pd.DataFrame(data)
+ _check_pandas_roundtrip(df)
+
+ # Do the same with pa.array()
+ # (for some reason, it doesn't use the same code paths at all)
+ for np_arr in data.values():
+ arr = pa.array(np_arr)
+ assert arr.to_pylist() == np_arr.tolist()
+
+ def test_integer_byteorder(self):
+ # Byteswapped arrays are not supported yet
+ int_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8']
+ for dt in int_dtypes:
+ for order in '=<>':
+ data = np.array([1, 2, 42], dtype=order + dt)
+ for np_arr in (data, data[::2]):
+ if data.dtype.isnative:
+ arr = pa.array(data)
+ assert arr.to_pylist() == data.tolist()
+ else:
+ with pytest.raises(NotImplementedError):
+ arr = pa.array(data)
+
+ def test_integer_with_nulls(self):
+ # pandas requires upcast to float dtype
+
+ int_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8']
+ num_values = 100
+
+ null_mask = np.random.randint(0, 10, size=num_values) < 3
+
+ expected_cols = []
+ arrays = []
+ for name in int_dtypes:
+ values = np.random.randint(0, 100, size=num_values)
+
+ arr = pa.array(values, mask=null_mask)
+ arrays.append(arr)
+
+ expected = values.astype('f8')
+ expected[null_mask] = np.nan
+
+ expected_cols.append(expected)
+
+ ex_frame = pd.DataFrame(dict(zip(int_dtypes, expected_cols)),
+ columns=int_dtypes)
+
+ table = pa.Table.from_arrays(arrays, int_dtypes)
+ result = table.to_pandas()
+
+ tm.assert_frame_equal(result, ex_frame)
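+
+ # Why the upcast: a nullable integer column has no exact NumPy dtype, so
+ # the default to_pandas() conversion falls back to float64 with NaN
+ # marking the nulls, e.g.:
+ #
+ #   pa.table({'x': pa.array([1, None])}).to_pandas()['x'].dtype
+ #   # -> dtype('float64')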
+
+ def test_array_from_pandas_type_cast(self):
+ arr = np.arange(10, dtype='int64')
+
+ target_type = pa.int8()
+
+ result = pa.array(arr, type=target_type)
+ expected = pa.array(arr.astype('int8'))
+ assert result.equals(expected)
+
+ def test_boolean_no_nulls(self):
+ num_values = 100
+
+ np.random.seed(0)
+
+ df = pd.DataFrame({'bools': np.random.randn(num_values) > 0})
+ field = pa.field('bools', pa.bool_())
+ schema = pa.schema([field])
+ _check_pandas_roundtrip(df, expected_schema=schema)
+
+ def test_boolean_nulls(self):
+ # pandas requires upcast to object dtype
+ num_values = 100
+ np.random.seed(0)
+
+ mask = np.random.randint(0, 10, size=num_values) < 3
+ values = np.random.randint(0, 10, size=num_values) < 5
+
+ arr = pa.array(values, mask=mask)
+
+ expected = values.astype(object)
+ expected[mask] = None
+
+ field = pa.field('bools', pa.bool_())
+ schema = pa.schema([field])
+ ex_frame = pd.DataFrame({'bools': expected})
+
+ table = pa.Table.from_arrays([arr], ['bools'])
+ assert table.schema.equals(schema)
+ result = table.to_pandas()
+
+ tm.assert_frame_equal(result, ex_frame)
+
+ def test_boolean_to_int(self):
+ # test from dtype=bool
+ s = pd.Series([True, True, False, True, True] * 2)
+ expected = pd.Series([1, 1, 0, 1, 1] * 2)
+ _check_array_roundtrip(s, expected=expected, type=pa.int64())
+
+ def test_boolean_objects_to_int(self):
+ # test from dtype=object
+ s = pd.Series([True, True, False, True, True] * 2, dtype=object)
+ expected = pd.Series([1, 1, 0, 1, 1] * 2)
+ expected_msg = 'Expected integer, got bool'
+ with pytest.raises(pa.ArrowTypeError, match=expected_msg):
+ _check_array_roundtrip(s, expected=expected, type=pa.int64())
+
+ def test_boolean_nulls_to_float(self):
+ # test from dtype=object
+ s = pd.Series([True, True, False, None, True] * 2)
+ expected = pd.Series([1.0, 1.0, 0.0, None, 1.0] * 2)
+ _check_array_roundtrip(s, expected=expected, type=pa.float64())
+
+ def test_boolean_multiple_columns(self):
+ # ARROW-6325 (multiple columns resulting in strided conversion)
+ df = pd.DataFrame(np.ones((3, 2), dtype='bool'), columns=['a', 'b'])
+ _check_pandas_roundtrip(df)
+
+ def test_float_object_nulls(self):
+ arr = np.array([None, 1.5, np.float64(3.5)] * 5, dtype=object)
+ df = pd.DataFrame({'floats': arr})
+ expected = pd.DataFrame({'floats': pd.to_numeric(arr)})
+ field = pa.field('floats', pa.float64())
+ schema = pa.schema([field])
+ _check_pandas_roundtrip(df, expected=expected,
+ expected_schema=schema)
+
+ def test_float_with_null_as_integer(self):
+ # ARROW-2298
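+        # NaN entries are treated as nulls, so the remaining float values can
+        # be cast safely to the requested integer types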
+ s = pd.Series([np.nan, 1., 2., np.nan])
+
+ types = [pa.int8(), pa.int16(), pa.int32(), pa.int64(),
+ pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
+ for ty in types:
+ result = pa.array(s, type=ty)
+ expected = pa.array([None, 1, 2, None], type=ty)
+ assert result.equals(expected)
+
+ df = pd.DataFrame({'has_nulls': s})
+ schema = pa.schema([pa.field('has_nulls', ty)])
+ result = pa.Table.from_pandas(df, schema=schema,
+ preserve_index=False)
+ assert result[0].chunk(0).equals(expected)
+
+ def test_int_object_nulls(self):
+ arr = np.array([None, 1, np.int64(3)] * 5, dtype=object)
+ df = pd.DataFrame({'ints': arr})
+ expected = pd.DataFrame({'ints': pd.to_numeric(arr)})
+ field = pa.field('ints', pa.int64())
+ schema = pa.schema([field])
+ _check_pandas_roundtrip(df, expected=expected,
+ expected_schema=schema)
+
+ def test_boolean_object_nulls(self):
+ arr = np.array([False, None, True] * 100, dtype=object)
+ df = pd.DataFrame({'bools': arr})
+ field = pa.field('bools', pa.bool_())
+ schema = pa.schema([field])
+ _check_pandas_roundtrip(df, expected_schema=schema)
+
+ def test_all_nulls_cast_numeric(self):
+ arr = np.array([None], dtype=object)
+
+ def _check_type(t):
+ a2 = pa.array(arr, type=t)
+ assert a2.type == t
+ assert a2[0].as_py() is None
+
+ _check_type(pa.int32())
+ _check_type(pa.float64())
+
+ def test_half_floats_from_numpy(self):
+ arr = np.array([1.5, np.nan], dtype=np.float16)
+ a = pa.array(arr, type=pa.float16())
+ x, y = a.to_pylist()
+ assert isinstance(x, np.float16)
+ assert x == 1.5
+ assert isinstance(y, np.float16)
+ assert np.isnan(y)
+
+ a = pa.array(arr, type=pa.float16(), from_pandas=True)
+ x, y = a.to_pylist()
+ assert isinstance(x, np.float16)
+ assert x == 1.5
+ assert y is None
+
+
+@pytest.mark.parametrize('dtype',
+ ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8'])
+def test_array_integer_object_nulls_option(dtype):
+ num_values = 100
+
+ null_mask = np.random.randint(0, 10, size=num_values) < 3
+ values = np.random.randint(0, 100, size=num_values, dtype=dtype)
+
+ array = pa.array(values, mask=null_mask)
+
+ if null_mask.any():
+ expected = values.astype('O')
+ expected[null_mask] = None
+ else:
+ expected = values
+
+ result = array.to_pandas(integer_object_nulls=True)
+
+ np.testing.assert_equal(result, expected)
+
+
+@pytest.mark.parametrize('dtype',
+ ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8'])
+def test_table_integer_object_nulls_option(dtype):
+ num_values = 100
+
+ null_mask = np.random.randint(0, 10, size=num_values) < 3
+ values = np.random.randint(0, 100, size=num_values, dtype=dtype)
+
+ array = pa.array(values, mask=null_mask)
+
+ if null_mask.any():
+ expected = values.astype('O')
+ expected[null_mask] = None
+ else:
+ expected = values
+
+ expected = pd.DataFrame({dtype: expected})
+
+ table = pa.Table.from_arrays([array], [dtype])
+ result = table.to_pandas(integer_object_nulls=True)
+
+ tm.assert_frame_equal(result, expected)
+
+
+class TestConvertDateTimeLikeTypes:
+ """
+    Conversion tests for timestamp, date, time, duration and interval types.
+ """
+
+ def test_timestamps_notimezone_no_nulls(self):
+ df = pd.DataFrame({
+ 'datetime64': np.array([
+ '2007-07-13T01:23:34.123456789',
+ '2006-01-13T12:34:56.432539784',
+ '2010-08-13T05:46:57.437699912'],
+ dtype='datetime64[ns]')
+ })
+ field = pa.field('datetime64', pa.timestamp('ns'))
+ schema = pa.schema([field])
+ _check_pandas_roundtrip(
+ df,
+ expected_schema=schema,
+ )
+
+ def test_timestamps_notimezone_nulls(self):
+ df = pd.DataFrame({
+ 'datetime64': np.array([
+ '2007-07-13T01:23:34.123456789',
+ None,
+ '2010-08-13T05:46:57.437699912'],
+ dtype='datetime64[ns]')
+ })
+ field = pa.field('datetime64', pa.timestamp('ns'))
+ schema = pa.schema([field])
+ _check_pandas_roundtrip(
+ df,
+ expected_schema=schema,
+ )
+
+ def test_timestamps_with_timezone(self):
+ df = pd.DataFrame({
+ 'datetime64': np.array([
+ '2007-07-13T01:23:34.123',
+ '2006-01-13T12:34:56.432',
+ '2010-08-13T05:46:57.437'],
+ dtype='datetime64[ms]')
+ })
+ df['datetime64'] = df['datetime64'].dt.tz_localize('US/Eastern')
+ _check_pandas_roundtrip(df)
+
+ _check_series_roundtrip(df['datetime64'])
+
+        # drop in a null and use ns resolution instead of ms
+ df = pd.DataFrame({
+ 'datetime64': np.array([
+ '2007-07-13T01:23:34.123456789',
+ None,
+ '2006-01-13T12:34:56.432539784',
+ '2010-08-13T05:46:57.437699912'],
+ dtype='datetime64[ns]')
+ })
+ df['datetime64'] = df['datetime64'].dt.tz_localize('US/Eastern')
+
+ _check_pandas_roundtrip(df)
+
+ def test_python_datetime(self):
+ # ARROW-2106
+ date_array = [datetime.today() + timedelta(days=x) for x in range(10)]
+ df = pd.DataFrame({
+ 'datetime': pd.Series(date_array, dtype=object)
+ })
+
+ table = pa.Table.from_pandas(df)
+ assert isinstance(table[0].chunk(0), pa.TimestampArray)
+
+ result = table.to_pandas()
+ expected_df = pd.DataFrame({
+ 'datetime': date_array
+ })
+ tm.assert_frame_equal(expected_df, result)
+
+ def test_python_datetime_with_pytz_tzinfo(self):
+ for tz in [pytz.utc, pytz.timezone('US/Eastern'), pytz.FixedOffset(1)]:
+ values = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=tz)]
+ df = pd.DataFrame({'datetime': values})
+ _check_pandas_roundtrip(df)
+
+ @h.given(st.none() | tzst.timezones())
+ def test_python_datetime_with_pytz_timezone(self, tz):
+ values = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=tz)]
+ df = pd.DataFrame({'datetime': values})
+ _check_pandas_roundtrip(df)
+
+ def test_python_datetime_with_timezone_tzinfo(self):
+ from datetime import timezone
+
+ if Version(pd.__version__) > Version("0.25.0"):
+ # older pandas versions fail on datetime.timezone.utc (as in input)
+ # vs pytz.UTC (as in result)
+ values = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=timezone.utc)]
+ # also test with index to ensure both paths roundtrip (ARROW-9962)
+ df = pd.DataFrame({'datetime': values}, index=values)
+ _check_pandas_roundtrip(df, preserve_index=True)
+
+        # a datetime.timezone offset roundtrips as pytz.FixedOffset
+ hours = 1
+ tz_timezone = timezone(timedelta(hours=hours))
+ tz_pytz = pytz.FixedOffset(hours * 60)
+ values = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=tz_timezone)]
+ values_exp = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=tz_pytz)]
+ df = pd.DataFrame({'datetime': values}, index=values)
+ df_exp = pd.DataFrame({'datetime': values_exp}, index=values_exp)
+ _check_pandas_roundtrip(df, expected=df_exp, preserve_index=True)
+
+ def test_python_datetime_subclass(self):
+
+ class MyDatetime(datetime):
+ # see https://github.com/pandas-dev/pandas/issues/21142
+ nanosecond = 0.0
+
+ date_array = [MyDatetime(2000, 1, 1, 1, 1, 1)]
+ df = pd.DataFrame({"datetime": pd.Series(date_array, dtype=object)})
+
+ table = pa.Table.from_pandas(df)
+ assert isinstance(table[0].chunk(0), pa.TimestampArray)
+
+ result = table.to_pandas()
+ expected_df = pd.DataFrame({"datetime": date_array})
+
+ # https://github.com/pandas-dev/pandas/issues/21142
+ expected_df["datetime"] = pd.to_datetime(expected_df["datetime"])
+
+ tm.assert_frame_equal(expected_df, result)
+
+ def test_python_date_subclass(self):
+
+ class MyDate(date):
+ pass
+
+ date_array = [MyDate(2000, 1, 1)]
+ df = pd.DataFrame({"date": pd.Series(date_array, dtype=object)})
+
+ table = pa.Table.from_pandas(df)
+ assert isinstance(table[0].chunk(0), pa.Date32Array)
+
+ result = table.to_pandas()
+ expected_df = pd.DataFrame(
+ {"date": np.array([date(2000, 1, 1)], dtype=object)}
+ )
+ tm.assert_frame_equal(expected_df, result)
+
+ def test_datetime64_to_date32(self):
+ # ARROW-1718
+ arr = pa.array([date(2017, 10, 23), None])
+ c = pa.chunked_array([arr])
+ s = c.to_pandas()
+
+ arr2 = pa.Array.from_pandas(s, type=pa.date32())
+
+ assert arr2.equals(arr.cast('date32'))
+
+ @pytest.mark.parametrize('mask', [
+ None,
+ np.array([True, False, False, True, False, False]),
+ ])
+ def test_pandas_datetime_to_date64(self, mask):
+ s = pd.to_datetime([
+ '2018-05-10T00:00:00',
+ '2018-05-11T00:00:00',
+ '2018-05-12T00:00:00',
+ '2018-05-10T10:24:01',
+ '2018-05-11T10:24:01',
+ '2018-05-12T10:24:01',
+ ])
+ arr = pa.Array.from_pandas(s, type=pa.date64(), mask=mask)
+
+ data = np.array([
+ date(2018, 5, 10),
+ date(2018, 5, 11),
+ date(2018, 5, 12),
+ date(2018, 5, 10),
+ date(2018, 5, 11),
+ date(2018, 5, 12),
+ ])
+ expected = pa.array(data, mask=mask, type=pa.date64())
+
+ assert arr.equals(expected)
+
+ def test_array_types_date_as_object(self):
+ data = [date(2000, 1, 1),
+ None,
+ date(1970, 1, 1),
+ date(2040, 2, 26)]
+ expected_d = np.array(['2000-01-01', None, '1970-01-01',
+ '2040-02-26'], dtype='datetime64[D]')
+
+ expected_ns = np.array(['2000-01-01', None, '1970-01-01',
+ '2040-02-26'], dtype='datetime64[ns]')
+
+ objects = [pa.array(data),
+ pa.chunked_array([data])]
+
+ for obj in objects:
+ result = obj.to_pandas()
+ expected_obj = expected_d.astype(object)
+ assert result.dtype == expected_obj.dtype
+ npt.assert_array_equal(result, expected_obj)
+
+ result = obj.to_pandas(date_as_object=False)
+ assert result.dtype == expected_ns.dtype
+ npt.assert_array_equal(result, expected_ns)
+
+ def test_table_convert_date_as_object(self):
+ df = pd.DataFrame({
+ 'date': [date(2000, 1, 1),
+ None,
+ date(1970, 1, 1),
+ date(2040, 2, 26)]})
+
+ table = pa.Table.from_pandas(df, preserve_index=False)
+
+ df_datetime = table.to_pandas(date_as_object=False)
+ df_object = table.to_pandas()
+
+ tm.assert_frame_equal(df.astype('datetime64[ns]'), df_datetime,
+ check_dtype=True)
+ tm.assert_frame_equal(df, df_object, check_dtype=True)
+
+ def test_date_infer(self):
+ df = pd.DataFrame({
+ 'date': [date(2000, 1, 1),
+ None,
+ date(1970, 1, 1),
+ date(2040, 2, 26)]})
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ field = pa.field('date', pa.date32())
+
+ # schema's metadata is generated by from_pandas conversion
+ expected_schema = pa.schema([field], metadata=table.schema.metadata)
+ assert table.schema.equals(expected_schema)
+
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, df)
+
+ def test_date_mask(self):
+ arr = np.array([date(2017, 4, 3), date(2017, 4, 4)],
+ dtype='datetime64[D]')
+ mask = [True, False]
+ result = pa.array(arr, mask=np.array(mask))
+ expected = np.array([None, date(2017, 4, 4)], dtype='datetime64[D]')
+ expected = pa.array(expected, from_pandas=True)
+ assert expected.equals(result)
+
+ def test_date_objects_typed(self):
+ arr = np.array([
+ date(2017, 4, 3),
+ None,
+ date(2017, 4, 4),
+ date(2017, 4, 5)], dtype=object)
+
+ arr_i4 = np.array([17259, -1, 17260, 17261], dtype='int32')
+ arr_i8 = arr_i4.astype('int64') * 86400000
+ mask = np.array([False, True, False, False])
+
+ t32 = pa.date32()
+ t64 = pa.date64()
+
+ a32 = pa.array(arr, type=t32)
+ a64 = pa.array(arr, type=t64)
+
+ a32_expected = pa.array(arr_i4, mask=mask, type=t32)
+ a64_expected = pa.array(arr_i8, mask=mask, type=t64)
+
+ assert a32.equals(a32_expected)
+ assert a64.equals(a64_expected)
+
+ # Test converting back to pandas
+ colnames = ['date32', 'date64']
+ table = pa.Table.from_arrays([a32, a64], colnames)
+
+ ex_values = (np.array(['2017-04-03', '2017-04-04', '2017-04-04',
+ '2017-04-05'],
+ dtype='datetime64[D]'))
+ ex_values[1] = pd.NaT.value
+
+ ex_datetime64ns = ex_values.astype('datetime64[ns]')
+ expected_pandas = pd.DataFrame({'date32': ex_datetime64ns,
+ 'date64': ex_datetime64ns},
+ columns=colnames)
+ table_pandas = table.to_pandas(date_as_object=False)
+ tm.assert_frame_equal(table_pandas, expected_pandas)
+
+ table_pandas_objects = table.to_pandas()
+ ex_objects = ex_values.astype('object')
+ expected_pandas_objects = pd.DataFrame({'date32': ex_objects,
+ 'date64': ex_objects},
+ columns=colnames)
+ tm.assert_frame_equal(table_pandas_objects,
+ expected_pandas_objects)
+
+ def test_pandas_null_values(self):
+ # ARROW-842
+ pd_NA = getattr(pd, 'NA', None)
+ values = np.array([datetime(2000, 1, 1), pd.NaT, pd_NA], dtype=object)
+ values_with_none = np.array([datetime(2000, 1, 1), None, None],
+ dtype=object)
+ result = pa.array(values, from_pandas=True)
+ expected = pa.array(values_with_none, from_pandas=True)
+ assert result.equals(expected)
+ assert result.null_count == 2
+
+ # ARROW-9407
+ assert pa.array([pd.NaT], from_pandas=True).type == pa.null()
+ assert pa.array([pd_NA], from_pandas=True).type == pa.null()
+
+ def test_dates_from_integers(self):
+ t1 = pa.date32()
+ t2 = pa.date64()
+
+ arr = np.array([17259, 17260, 17261], dtype='int32')
+ arr2 = arr.astype('int64') * 86400000
+
+ a1 = pa.array(arr, type=t1)
+ a2 = pa.array(arr2, type=t2)
+
+ expected = date(2017, 4, 3)
+ assert a1[0].as_py() == expected
+ assert a2[0].as_py() == expected
+
+ def test_pytime_from_pandas(self):
+ pytimes = [time(1, 2, 3, 1356),
+ time(4, 5, 6, 1356)]
+
+ # microseconds
+ t1 = pa.time64('us')
+
+ aobjs = np.array(pytimes + [None], dtype=object)
+ parr = pa.array(aobjs)
+ assert parr.type == t1
+ assert parr[0].as_py() == pytimes[0]
+ assert parr[1].as_py() == pytimes[1]
+ assert parr[2].as_py() is None
+
+ # DataFrame
+ df = pd.DataFrame({'times': aobjs})
+ batch = pa.RecordBatch.from_pandas(df)
+ assert batch[0].equals(parr)
+
+ # Test ndarray of int64 values
+ arr = np.array([_pytime_to_micros(v) for v in pytimes],
+ dtype='int64')
+
+ a1 = pa.array(arr, type=pa.time64('us'))
+ assert a1[0].as_py() == pytimes[0]
+
+ a2 = pa.array(arr * 1000, type=pa.time64('ns'))
+ assert a2[0].as_py() == pytimes[0]
+
+ a3 = pa.array((arr / 1000).astype('i4'),
+ type=pa.time32('ms'))
+ assert a3[0].as_py() == pytimes[0].replace(microsecond=1000)
+
+ a4 = pa.array((arr / 1000000).astype('i4'),
+ type=pa.time32('s'))
+ assert a4[0].as_py() == pytimes[0].replace(microsecond=0)
+
+ def test_arrow_time_to_pandas(self):
+ pytimes = [time(1, 2, 3, 1356),
+ time(4, 5, 6, 1356),
+ time(0, 0, 0)]
+
+ expected = np.array(pytimes[:2] + [None])
+ expected_ms = np.array([x.replace(microsecond=1000)
+ for x in pytimes[:2]] +
+ [None])
+ expected_s = np.array([x.replace(microsecond=0)
+ for x in pytimes[:2]] +
+ [None])
+
+ arr = np.array([_pytime_to_micros(v) for v in pytimes],
+ dtype='int64')
+
+ null_mask = np.array([False, False, True], dtype=bool)
+
+ a1 = pa.array(arr, mask=null_mask, type=pa.time64('us'))
+ a2 = pa.array(arr * 1000, mask=null_mask,
+ type=pa.time64('ns'))
+
+ a3 = pa.array((arr / 1000).astype('i4'), mask=null_mask,
+ type=pa.time32('ms'))
+ a4 = pa.array((arr / 1000000).astype('i4'), mask=null_mask,
+ type=pa.time32('s'))
+
+ names = ['time64[us]', 'time64[ns]', 'time32[ms]', 'time32[s]']
+ batch = pa.RecordBatch.from_arrays([a1, a2, a3, a4], names)
+
+ for arr, expected_values in [(a1, expected),
+ (a2, expected),
+ (a3, expected_ms),
+ (a4, expected_s)]:
+ result_pandas = arr.to_pandas()
+ assert (result_pandas.values == expected_values).all()
+
+ df = batch.to_pandas()
+ expected_df = pd.DataFrame({'time64[us]': expected,
+ 'time64[ns]': expected,
+ 'time32[ms]': expected_ms,
+ 'time32[s]': expected_s},
+ columns=names)
+
+ tm.assert_frame_equal(df, expected_df)
+
+ def test_numpy_datetime64_columns(self):
+ datetime64_ns = np.array([
+ '2007-07-13T01:23:34.123456789',
+ None,
+ '2006-01-13T12:34:56.432539784',
+ '2010-08-13T05:46:57.437699912'],
+ dtype='datetime64[ns]')
+ _check_array_from_pandas_roundtrip(datetime64_ns)
+
+ datetime64_us = np.array([
+ '2007-07-13T01:23:34.123456',
+ None,
+ '2006-01-13T12:34:56.432539',
+ '2010-08-13T05:46:57.437699'],
+ dtype='datetime64[us]')
+ _check_array_from_pandas_roundtrip(datetime64_us)
+
+ datetime64_ms = np.array([
+ '2007-07-13T01:23:34.123',
+ None,
+ '2006-01-13T12:34:56.432',
+ '2010-08-13T05:46:57.437'],
+ dtype='datetime64[ms]')
+ _check_array_from_pandas_roundtrip(datetime64_ms)
+
+ datetime64_s = np.array([
+ '2007-07-13T01:23:34',
+ None,
+ '2006-01-13T12:34:56',
+ '2010-08-13T05:46:57'],
+ dtype='datetime64[s]')
+ _check_array_from_pandas_roundtrip(datetime64_s)
+
+ def test_timestamp_to_pandas_ns(self):
+ # non-ns timestamp gets cast to ns on conversion to pandas
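+        # (the pandas versions targeted here only support datetime64[ns])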
+ arr = pa.array([1, 2, 3], pa.timestamp('ms'))
+ expected = pd.Series(pd.to_datetime([1, 2, 3], unit='ms'))
+ s = arr.to_pandas()
+ tm.assert_series_equal(s, expected)
+ arr = pa.chunked_array([arr])
+ s = arr.to_pandas()
+ tm.assert_series_equal(s, expected)
+
+ def test_timestamp_to_pandas_out_of_bounds(self):
+ # ARROW-7758 check for out of bounds timestamps for non-ns timestamps
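+        # (pandas' datetime64[ns] only covers roughly years 1677-2262,
+        # so year 1 cannot be represented)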
+
+ for unit in ['s', 'ms', 'us']:
+ for tz in [None, 'America/New_York']:
+ arr = pa.array([datetime(1, 1, 1)], pa.timestamp(unit, tz=tz))
+ table = pa.table({'a': arr})
+
+ msg = "would result in out of bounds timestamp"
+ with pytest.raises(ValueError, match=msg):
+ arr.to_pandas()
+
+ with pytest.raises(ValueError, match=msg):
+ table.to_pandas()
+
+ with pytest.raises(ValueError, match=msg):
+ # chunked array
+ table.column('a').to_pandas()
+
+ # just ensure those don't give an error, but do not
+ # check actual garbage output
+ arr.to_pandas(safe=False)
+ table.to_pandas(safe=False)
+ table.column('a').to_pandas(safe=False)
+
+ def test_timestamp_to_pandas_empty_chunked(self):
+ # ARROW-7907 table with chunked array with 0 chunks
+ table = pa.table({'a': pa.chunked_array([], type=pa.timestamp('us'))})
+ result = table.to_pandas()
+ expected = pd.DataFrame({'a': pd.Series([], dtype="datetime64[ns]")})
+ tm.assert_frame_equal(result, expected)
+
+ @pytest.mark.parametrize('dtype', [pa.date32(), pa.date64()])
+ def test_numpy_datetime64_day_unit(self, dtype):
+ datetime64_d = np.array([
+ '2007-07-13',
+ None,
+ '2006-01-15',
+ '2010-08-19'],
+ dtype='datetime64[D]')
+ _check_array_from_pandas_roundtrip(datetime64_d, type=dtype)
+
+ def test_array_from_pandas_date_with_mask(self):
+ m = np.array([True, False, True])
+ data = pd.Series([
+ date(1990, 1, 1),
+ date(1991, 1, 1),
+ date(1992, 1, 1)
+ ])
+
+ result = pa.Array.from_pandas(data, mask=m)
+
+ expected = pd.Series([None, date(1991, 1, 1), None])
+ assert pa.Array.from_pandas(expected).equals(result)
+
+ @pytest.mark.skipif(
+ Version('1.16.0') <= Version(np.__version__) < Version('1.16.1'),
+ reason='Until numpy/numpy#12745 is resolved')
+ def test_fixed_offset_timezone(self):
+ df = pd.DataFrame({
+ 'a': [
+ pd.Timestamp('2012-11-11 00:00:00+01:00'),
+ pd.NaT
+ ]
+ })
+ _check_pandas_roundtrip(df)
+ _check_serialize_components_roundtrip(df)
+
+ def test_timedeltas_no_nulls(self):
+ df = pd.DataFrame({
+ 'timedelta64': np.array([0, 3600000000000, 7200000000000],
+ dtype='timedelta64[ns]')
+ })
+ field = pa.field('timedelta64', pa.duration('ns'))
+ schema = pa.schema([field])
+ _check_pandas_roundtrip(
+ df,
+ expected_schema=schema,
+ )
+
+ def test_timedeltas_nulls(self):
+ df = pd.DataFrame({
+ 'timedelta64': np.array([0, None, 7200000000000],
+ dtype='timedelta64[ns]')
+ })
+ field = pa.field('timedelta64', pa.duration('ns'))
+ schema = pa.schema([field])
+ _check_pandas_roundtrip(
+ df,
+ expected_schema=schema,
+ )
+
+ def test_month_day_nano_interval(self):
+ from pandas.tseries.offsets import DateOffset
+ df = pd.DataFrame({
+ 'date_offset': [None,
+ DateOffset(days=3600, months=3600, microseconds=3,
+ nanoseconds=600)]
+ })
+ schema = pa.schema([('date_offset', pa.month_day_nano_interval())])
+ _check_pandas_roundtrip(
+ df,
+ expected_schema=schema)
+
+
+# ----------------------------------------------------------------------
+# Conversion tests for string and binary types.
+
+
+class TestConvertStringLikeTypes:
+
+ def test_pandas_unicode(self):
+ repeats = 1000
+ values = ['foo', None, 'bar', 'mañana', np.nan]
+ df = pd.DataFrame({'strings': values * repeats})
+ field = pa.field('strings', pa.string())
+ schema = pa.schema([field])
+
+ _check_pandas_roundtrip(df, expected_schema=schema)
+
+ def test_bytes_to_binary(self):
+ values = ['qux', b'foo', None, bytearray(b'barz'), 'qux', np.nan]
+ df = pd.DataFrame({'strings': values})
+
+ table = pa.Table.from_pandas(df)
+ assert table[0].type == pa.binary()
+
+ values2 = [b'qux', b'foo', None, b'barz', b'qux', np.nan]
+ expected = pd.DataFrame({'strings': values2})
+ _check_pandas_roundtrip(df, expected)
+
+ @pytest.mark.large_memory
+ def test_bytes_exceed_2gb(self):
+ v1 = b'x' * 100000000
+ v2 = b'x' * 147483646
+
+ # ARROW-2227, hit exactly 2GB on the nose
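+        # (binary arrays use 32-bit value offsets, so a single chunk holds at
+        # most ~2 GiB and the conversion has to split into multiple chunks)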
+ df = pd.DataFrame({
+ 'strings': [v1] * 20 + [v2] + ['x'] * 20
+ })
+ arr = pa.array(df['strings'])
+ assert isinstance(arr, pa.ChunkedArray)
+ assert arr.num_chunks == 2
+ arr = None
+
+ table = pa.Table.from_pandas(df)
+ assert table[0].num_chunks == 2
+
+ @pytest.mark.large_memory
+ @pytest.mark.parametrize('char', ['x', b'x'])
+ def test_auto_chunking_pandas_series_of_strings(self, char):
+ # ARROW-2367
+ v1 = char * 100000000
+ v2 = char * 147483646
+
+ df = pd.DataFrame({
+ 'strings': [[v1]] * 20 + [[v2]] + [[b'x']]
+ })
+ arr = pa.array(df['strings'], from_pandas=True)
+ assert isinstance(arr, pa.ChunkedArray)
+ assert arr.num_chunks == 2
+ assert len(arr.chunk(0)) == 21
+ assert len(arr.chunk(1)) == 1
+
+ def test_fixed_size_bytes(self):
+ values = [b'foo', None, bytearray(b'bar'), None, None, b'hey']
+ df = pd.DataFrame({'strings': values})
+ schema = pa.schema([pa.field('strings', pa.binary(3))])
+ table = pa.Table.from_pandas(df, schema=schema)
+ assert table.schema[0].type == schema[0].type
+ assert table.schema[0].name == schema[0].name
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, df)
+
+ def test_fixed_size_bytes_does_not_accept_varying_lengths(self):
+ values = [b'foo', None, b'ba', None, None, b'hey']
+ df = pd.DataFrame({'strings': values})
+ schema = pa.schema([pa.field('strings', pa.binary(3))])
+ with pytest.raises(pa.ArrowInvalid):
+ pa.Table.from_pandas(df, schema=schema)
+
+ def test_variable_size_bytes(self):
+ s = pd.Series([b'123', b'', b'a', None])
+ _check_series_roundtrip(s, type_=pa.binary())
+
+ def test_binary_from_bytearray(self):
+ s = pd.Series([bytearray(b'123'), bytearray(b''), bytearray(b'a'),
+ None])
+ # Explicitly set type
+ _check_series_roundtrip(s, type_=pa.binary())
+ # Infer type from bytearrays
+ _check_series_roundtrip(s, expected_pa_type=pa.binary())
+
+ def test_large_binary(self):
+ s = pd.Series([b'123', b'', b'a', None])
+ _check_series_roundtrip(s, type_=pa.large_binary())
+ df = pd.DataFrame({'a': s})
+ _check_pandas_roundtrip(
+ df, schema=pa.schema([('a', pa.large_binary())]))
+
+ def test_large_string(self):
+ s = pd.Series(['123', '', 'a', None])
+ _check_series_roundtrip(s, type_=pa.large_string())
+ df = pd.DataFrame({'a': s})
+ _check_pandas_roundtrip(
+ df, schema=pa.schema([('a', pa.large_string())]))
+
+ def test_table_empty_str(self):
+ values = ['', '', '', '', '']
+ df = pd.DataFrame({'strings': values})
+ field = pa.field('strings', pa.string())
+ schema = pa.schema([field])
+ table = pa.Table.from_pandas(df, schema=schema)
+
+ result1 = table.to_pandas(strings_to_categorical=False)
+ expected1 = pd.DataFrame({'strings': values})
+ tm.assert_frame_equal(result1, expected1, check_dtype=True)
+
+ result2 = table.to_pandas(strings_to_categorical=True)
+ expected2 = pd.DataFrame({'strings': pd.Categorical(values)})
+ tm.assert_frame_equal(result2, expected2, check_dtype=True)
+
+ def test_selective_categoricals(self):
+ values = ['', '', '', '', '']
+ df = pd.DataFrame({'strings': values})
+ field = pa.field('strings', pa.string())
+ schema = pa.schema([field])
+ table = pa.Table.from_pandas(df, schema=schema)
+ expected_str = pd.DataFrame({'strings': values})
+ expected_cat = pd.DataFrame({'strings': pd.Categorical(values)})
+
+ result1 = table.to_pandas(categories=['strings'])
+ tm.assert_frame_equal(result1, expected_cat, check_dtype=True)
+ result2 = table.to_pandas(categories=[])
+ tm.assert_frame_equal(result2, expected_str, check_dtype=True)
+ result3 = table.to_pandas(categories=('strings',))
+ tm.assert_frame_equal(result3, expected_cat, check_dtype=True)
+ result4 = table.to_pandas(categories=tuple())
+ tm.assert_frame_equal(result4, expected_str, check_dtype=True)
+
+ def test_to_pandas_categorical_zero_length(self):
+ # ARROW-3586
+ array = pa.array([], type=pa.int32())
+ table = pa.Table.from_arrays(arrays=[array], names=['col'])
+ # This would segfault under 0.11.0
+ table.to_pandas(categories=['col'])
+
+ def test_to_pandas_categories_already_dictionary(self):
+ # Showed up in ARROW-6434, ARROW-6435
+ array = pa.array(['foo', 'foo', 'foo', 'bar']).dictionary_encode()
+ table = pa.Table.from_arrays(arrays=[array], names=['col'])
+ result = table.to_pandas(categories=['col'])
+ assert table.to_pandas().equals(result)
+
+ def test_table_str_to_categorical_without_na(self):
+ values = ['a', 'a', 'b', 'b', 'c']
+ df = pd.DataFrame({'strings': values})
+ field = pa.field('strings', pa.string())
+ schema = pa.schema([field])
+ table = pa.Table.from_pandas(df, schema=schema)
+
+ result = table.to_pandas(strings_to_categorical=True)
+ expected = pd.DataFrame({'strings': pd.Categorical(values)})
+ tm.assert_frame_equal(result, expected, check_dtype=True)
+
+ with pytest.raises(pa.ArrowInvalid):
+ table.to_pandas(strings_to_categorical=True,
+ zero_copy_only=True)
+
+ def test_table_str_to_categorical_with_na(self):
+ values = [None, 'a', 'b', np.nan]
+ df = pd.DataFrame({'strings': values})
+ field = pa.field('strings', pa.string())
+ schema = pa.schema([field])
+ table = pa.Table.from_pandas(df, schema=schema)
+
+ result = table.to_pandas(strings_to_categorical=True)
+ expected = pd.DataFrame({'strings': pd.Categorical(values)})
+ tm.assert_frame_equal(result, expected, check_dtype=True)
+
+ with pytest.raises(pa.ArrowInvalid):
+ table.to_pandas(strings_to_categorical=True,
+ zero_copy_only=True)
+
+ # Regression test for ARROW-2101
+ def test_array_of_bytes_to_strings(self):
+ converted = pa.array(np.array([b'x'], dtype=object), pa.string())
+ assert converted.type == pa.string()
+
+ # Make sure that if an ndarray of bytes is passed to the array
+ # constructor and the type is string, it will fail if those bytes
+ # cannot be converted to utf-8
+ def test_array_of_bytes_to_strings_bad_data(self):
+ with pytest.raises(
+ pa.lib.ArrowInvalid,
+ match="was not a utf8 string"):
+ pa.array(np.array([b'\x80\x81'], dtype=object), pa.string())
+
+ def test_numpy_string_array_to_fixed_size_binary(self):
+ arr = np.array([b'foo', b'bar', b'baz'], dtype='|S3')
+
+ converted = pa.array(arr, type=pa.binary(3))
+ expected = pa.array(list(arr), type=pa.binary(3))
+ assert converted.equals(expected)
+
+ mask = np.array([False, True, False])
+ converted = pa.array(arr, type=pa.binary(3), mask=mask)
+ expected = pa.array([b'foo', None, b'baz'], type=pa.binary(3))
+ assert converted.equals(expected)
+
+ with pytest.raises(pa.lib.ArrowInvalid,
+ match=r'Got bytestring of length 3 \(expected 4\)'):
+ arr = np.array([b'foo', b'bar', b'baz'], dtype='|S3')
+ pa.array(arr, type=pa.binary(4))
+
+ with pytest.raises(
+ pa.lib.ArrowInvalid,
+ match=r'Got bytestring of length 12 \(expected 3\)'):
+ arr = np.array([b'foo', b'bar', b'baz'], dtype='|U3')
+ pa.array(arr, type=pa.binary(3))
+
+
+class TestConvertDecimalTypes:
+ """
+    Conversion tests for decimal types.
+ """
+ decimal32 = [
+ decimal.Decimal('-1234.123'),
+ decimal.Decimal('1234.439')
+ ]
+ decimal64 = [
+ decimal.Decimal('-129934.123331'),
+ decimal.Decimal('129534.123731')
+ ]
+ decimal128 = [
+ decimal.Decimal('394092382910493.12341234678'),
+ decimal.Decimal('-314292388910493.12343437128')
+ ]
+
+ @pytest.mark.parametrize(('values', 'expected_type'), [
+ pytest.param(decimal32, pa.decimal128(7, 3), id='decimal32'),
+ pytest.param(decimal64, pa.decimal128(12, 6), id='decimal64'),
+ pytest.param(decimal128, pa.decimal128(26, 11), id='decimal128')
+ ])
+ def test_decimal_from_pandas(self, values, expected_type):
+ expected = pd.DataFrame({'decimals': values})
+ table = pa.Table.from_pandas(expected, preserve_index=False)
+ field = pa.field('decimals', expected_type)
+
+ # schema's metadata is generated by from_pandas conversion
+ expected_schema = pa.schema([field], metadata=table.schema.metadata)
+ assert table.schema.equals(expected_schema)
+
+ @pytest.mark.parametrize('values', [
+ pytest.param(decimal32, id='decimal32'),
+ pytest.param(decimal64, id='decimal64'),
+ pytest.param(decimal128, id='decimal128')
+ ])
+ def test_decimal_to_pandas(self, values):
+ expected = pd.DataFrame({'decimals': values})
+ converted = pa.Table.from_pandas(expected)
+ df = converted.to_pandas()
+ tm.assert_frame_equal(df, expected)
+
+ def test_decimal_fails_with_truncation(self):
+ data1 = [decimal.Decimal('1.234')]
+ type1 = pa.decimal128(10, 2)
+ with pytest.raises(pa.ArrowInvalid):
+ pa.array(data1, type=type1)
+
+ data2 = [decimal.Decimal('1.2345')]
+ type2 = pa.decimal128(10, 3)
+ with pytest.raises(pa.ArrowInvalid):
+ pa.array(data2, type=type2)
+
+ def test_decimal_with_different_precisions(self):
+ data = [
+ decimal.Decimal('0.01'),
+ decimal.Decimal('0.001'),
+ ]
+ series = pd.Series(data)
+ array = pa.array(series)
+ assert array.to_pylist() == data
+ assert array.type == pa.decimal128(3, 3)
+
+ array = pa.array(data, type=pa.decimal128(12, 5))
+ expected = [decimal.Decimal('0.01000'), decimal.Decimal('0.00100')]
+ assert array.to_pylist() == expected
+
+ def test_decimal_with_None_explicit_type(self):
+ series = pd.Series([decimal.Decimal('3.14'), None])
+ _check_series_roundtrip(series, type_=pa.decimal128(12, 5))
+
+ # Test that having all None values still produces decimal array
+ series = pd.Series([None] * 2)
+ _check_series_roundtrip(series, type_=pa.decimal128(12, 5))
+
+ def test_decimal_with_None_infer_type(self):
+ series = pd.Series([decimal.Decimal('3.14'), None])
+ _check_series_roundtrip(series, expected_pa_type=pa.decimal128(3, 2))
+
+ def test_strided_objects(self, tmpdir):
+ # see ARROW-3053
+ data = {
+ 'a': {0: 'a'},
+ 'b': {0: decimal.Decimal('0.0')}
+ }
+
+ # This yields strided objects
+ df = pd.DataFrame.from_dict(data)
+ _check_pandas_roundtrip(df)
+
+
+class TestConvertListTypes:
+ """
+ Conversion tests for list<> types.
+ """
+
+ def test_column_of_arrays(self):
+ df, schema = dataframe_with_arrays()
+ _check_pandas_roundtrip(df, schema=schema, expected_schema=schema)
+ table = pa.Table.from_pandas(df, schema=schema, preserve_index=False)
+
+ # schema's metadata is generated by from_pandas conversion
+ expected_schema = schema.with_metadata(table.schema.metadata)
+ assert table.schema.equals(expected_schema)
+
+ for column in df.columns:
+ field = schema.field(column)
+ _check_array_roundtrip(df[column], type=field.type)
+
+ def test_column_of_arrays_to_py(self):
+ # Test regression in ARROW-1199 not caught in above test
+ dtype = 'i1'
+ arr = np.array([
+ np.arange(10, dtype=dtype),
+ np.arange(5, dtype=dtype),
+ None,
+ np.arange(1, dtype=dtype)
+ ], dtype=object)
+ type_ = pa.list_(pa.int8())
+ parr = pa.array(arr, type=type_)
+
+ assert parr[0].as_py() == list(range(10))
+ assert parr[1].as_py() == list(range(5))
+ assert parr[2].as_py() is None
+ assert parr[3].as_py() == [0]
+
+ def test_column_of_boolean_list(self):
+ # ARROW-4370: Table to pandas conversion fails for list of bool
+ array = pa.array([[True, False], [True]], type=pa.list_(pa.bool_()))
+ table = pa.Table.from_arrays([array], names=['col1'])
+ df = table.to_pandas()
+
+ expected_df = pd.DataFrame({'col1': [[True, False], [True]]})
+ tm.assert_frame_equal(df, expected_df)
+
+ s = table[0].to_pandas()
+ tm.assert_series_equal(pd.Series(s), df['col1'], check_names=False)
+
+ def test_column_of_decimal_list(self):
+ array = pa.array([[decimal.Decimal('1'), decimal.Decimal('2')],
+ [decimal.Decimal('3.3')]],
+ type=pa.list_(pa.decimal128(2, 1)))
+ table = pa.Table.from_arrays([array], names=['col1'])
+ df = table.to_pandas()
+
+ expected_df = pd.DataFrame(
+ {'col1': [[decimal.Decimal('1'), decimal.Decimal('2')],
+ [decimal.Decimal('3.3')]]})
+ tm.assert_frame_equal(df, expected_df)
+
+ def test_nested_types_from_ndarray_null_entries(self):
+ # Root cause of ARROW-6435
+ s = pd.Series(np.array([np.nan, np.nan], dtype=object))
+
+ for ty in [pa.list_(pa.int64()),
+ pa.large_list(pa.int64()),
+ pa.struct([pa.field('f0', 'int32')])]:
+ result = pa.array(s, type=ty)
+ expected = pa.array([None, None], type=ty)
+ assert result.equals(expected)
+
+ with pytest.raises(TypeError):
+ pa.array(s.values, type=ty)
+
+ def test_column_of_lists(self):
+ df, schema = dataframe_with_lists()
+ _check_pandas_roundtrip(df, schema=schema, expected_schema=schema)
+ table = pa.Table.from_pandas(df, schema=schema, preserve_index=False)
+
+ # schema's metadata is generated by from_pandas conversion
+ expected_schema = schema.with_metadata(table.schema.metadata)
+ assert table.schema.equals(expected_schema)
+
+ for column in df.columns:
+ field = schema.field(column)
+ _check_array_roundtrip(df[column], type=field.type)
+
+ def test_column_of_lists_first_empty(self):
+ # ARROW-2124
+ num_lists = [[], [2, 3, 4], [3, 6, 7, 8], [], [2]]
+ series = pd.Series([np.array(s, dtype=float) for s in num_lists])
+ arr = pa.array(series)
+ result = pd.Series(arr.to_pandas())
+ tm.assert_series_equal(result, series)
+
+ def test_column_of_lists_chunked(self):
+ # ARROW-1357
+ df = pd.DataFrame({
+ 'lists': np.array([
+ [1, 2],
+ None,
+ [2, 3],
+ [4, 5],
+ [6, 7],
+ [8, 9]
+ ], dtype=object)
+ })
+
+ schema = pa.schema([
+ pa.field('lists', pa.list_(pa.int64()))
+ ])
+
+ t1 = pa.Table.from_pandas(df[:2], schema=schema)
+ t2 = pa.Table.from_pandas(df[2:], schema=schema)
+
+ table = pa.concat_tables([t1, t2])
+ result = table.to_pandas()
+
+ tm.assert_frame_equal(result, df)
+
+ def test_empty_column_of_lists_chunked(self):
+ df = pd.DataFrame({
+ 'lists': np.array([], dtype=object)
+ })
+
+ schema = pa.schema([
+ pa.field('lists', pa.list_(pa.int64()))
+ ])
+
+ table = pa.Table.from_pandas(df, schema=schema)
+ result = table.to_pandas()
+
+ tm.assert_frame_equal(result, df)
+
+ def test_column_of_lists_chunked2(self):
+ data1 = [[0, 1], [2, 3], [4, 5], [6, 7], [10, 11],
+ [12, 13], [14, 15], [16, 17]]
+ data2 = [[8, 9], [18, 19]]
+
+ a1 = pa.array(data1)
+ a2 = pa.array(data2)
+
+ t1 = pa.Table.from_arrays([a1], names=['a'])
+ t2 = pa.Table.from_arrays([a2], names=['a'])
+
+ concatenated = pa.concat_tables([t1, t2])
+
+ result = concatenated.to_pandas()
+ expected = pd.DataFrame({'a': data1 + data2})
+
+ tm.assert_frame_equal(result, expected)
+
+ def test_column_of_lists_strided(self):
+ df, schema = dataframe_with_lists()
+ df = pd.concat([df] * 6, ignore_index=True)
+
+ arr = df['int64'].values[::3]
+ assert arr.strides[0] != 8
+
+ _check_array_roundtrip(arr)
+
+ def test_nested_lists_all_none(self):
+ data = np.array([[None, None], None], dtype=object)
+
+ arr = pa.array(data)
+ expected = pa.array(list(data))
+ assert arr.equals(expected)
+ assert arr.type == pa.list_(pa.null())
+
+ data2 = np.array([None, None, [None, None],
+ np.array([None, None], dtype=object)],
+ dtype=object)
+ arr = pa.array(data2)
+ expected = pa.array([None, None, [None, None], [None, None]])
+ assert arr.equals(expected)
+
+ def test_nested_lists_all_empty(self):
+ # ARROW-2128
+ data = pd.Series([[], [], []])
+ arr = pa.array(data)
+ expected = pa.array(list(data))
+ assert arr.equals(expected)
+ assert arr.type == pa.list_(pa.null())
+
+ def test_nested_list_first_empty(self):
+ # ARROW-2711
+ data = pd.Series([[], ["a"]])
+ arr = pa.array(data)
+ expected = pa.array(list(data))
+ assert arr.equals(expected)
+ assert arr.type == pa.list_(pa.string())
+
+ def test_nested_smaller_ints(self):
+        # ARROW-1345, ARROW-2008: type inference used to have bugs here
+ data = pd.Series([np.array([1, 2, 3], dtype='i1'), None])
+ result = pa.array(data)
+ result2 = pa.array(data.values)
+ expected = pa.array([[1, 2, 3], None], type=pa.list_(pa.int8()))
+ assert result.equals(expected)
+ assert result2.equals(expected)
+
+ data3 = pd.Series([np.array([1, 2, 3], dtype='f4'), None])
+ result3 = pa.array(data3)
+ expected3 = pa.array([[1, 2, 3], None], type=pa.list_(pa.float32()))
+ assert result3.equals(expected3)
+
+ def test_infer_lists(self):
+ data = OrderedDict([
+ ('nan_ints', [[None, 1], [2, 3]]),
+ ('ints', [[0, 1], [2, 3]]),
+ ('strs', [[None, 'b'], ['c', 'd']]),
+ ('nested_strs', [[[None, 'b'], ['c', 'd']], None])
+ ])
+ df = pd.DataFrame(data)
+
+ expected_schema = pa.schema([
+ pa.field('nan_ints', pa.list_(pa.int64())),
+ pa.field('ints', pa.list_(pa.int64())),
+ pa.field('strs', pa.list_(pa.string())),
+ pa.field('nested_strs', pa.list_(pa.list_(pa.string())))
+ ])
+
+ _check_pandas_roundtrip(df, expected_schema=expected_schema)
+
+ def test_fixed_size_list(self):
+ # ARROW-7365
+ fixed_ty = pa.list_(pa.int64(), list_size=4)
+ variable_ty = pa.list_(pa.int64())
+
+ data = [[0, 1, 2, 3], None, [4, 5, 6, 7], [8, 9, 10, 11]]
+ fixed_arr = pa.array(data, type=fixed_ty)
+ variable_arr = pa.array(data, type=variable_ty)
+
+ result = fixed_arr.to_pandas()
+ expected = variable_arr.to_pandas()
+
+        for left, right in zip(result, expected):
+            if left is None:
+                assert right is None
+            else:
+                npt.assert_array_equal(left, right)
+
+ def test_infer_numpy_array(self):
+ data = OrderedDict([
+ ('ints', [
+ np.array([0, 1], dtype=np.int64),
+ np.array([2, 3], dtype=np.int64)
+ ])
+ ])
+ df = pd.DataFrame(data)
+ expected_schema = pa.schema([
+ pa.field('ints', pa.list_(pa.int64()))
+ ])
+
+ _check_pandas_roundtrip(df, expected_schema=expected_schema)
+
+ def test_to_list_of_structs_pandas(self):
+ ints = pa.array([1, 2, 3], pa.int32())
+ strings = pa.array([['a', 'b'], ['c', 'd'], ['e', 'f']],
+ pa.list_(pa.string()))
+ structs = pa.StructArray.from_arrays([ints, strings], ['f1', 'f2'])
+ data = pa.ListArray.from_arrays([0, 1, 3], structs)
+
+ expected = pd.Series([
+ [{'f1': 1, 'f2': ['a', 'b']}],
+ [{'f1': 2, 'f2': ['c', 'd']},
+ {'f1': 3, 'f2': ['e', 'f']}]
+ ])
+
+ series = pd.Series(data.to_pandas())
+ tm.assert_series_equal(series, expected)
+
+ @pytest.mark.parametrize('t,data,expected', [
+ (
+ pa.int64,
+ [[1, 2], [3], None],
+ [None, [3], None]
+ ),
+ (
+ pa.string,
+ [['aaa', 'bb'], ['c'], None],
+ [None, ['c'], None]
+ ),
+ (
+ pa.null,
+ [[None, None], [None], None],
+ [None, [None], None]
+ )
+ ])
+ def test_array_from_pandas_typed_array_with_mask(self, t, data, expected):
+ m = np.array([True, False, True])
+
+ s = pd.Series(data)
+ result = pa.Array.from_pandas(s, mask=m, type=pa.list_(t()))
+
+ assert pa.Array.from_pandas(expected,
+ type=pa.list_(t())).equals(result)
+
+ def test_empty_list_roundtrip(self):
+ empty_list_array = np.empty((3,), dtype=object)
+ empty_list_array.fill([])
+
+ df = pd.DataFrame({'a': np.array(['1', '2', '3']),
+ 'b': empty_list_array})
+ tbl = pa.Table.from_pandas(df)
+
+ result = tbl.to_pandas()
+
+ tm.assert_frame_equal(result, df)
+
+ def test_array_from_nested_arrays(self):
+ df, schema = dataframe_with_arrays()
+ for field in schema:
+ arr = df[field.name].values
+ expected = pa.array(list(arr), type=field.type)
+ result = pa.array(arr)
+ assert result.type == field.type # == list<scalar>
+ assert result.equals(expected)
+
+ def test_nested_large_list(self):
+ s = (pa.array([[[1, 2, 3], [4]], None],
+ type=pa.large_list(pa.large_list(pa.int64())))
+ .to_pandas())
+ tm.assert_series_equal(
+ s, pd.Series([[[1, 2, 3], [4]], None], dtype=object),
+ check_names=False)
+
+ def test_large_binary_list(self):
+ for list_type_factory in (pa.list_, pa.large_list):
+ s = (pa.array([["aa", "bb"], None, ["cc"], []],
+ type=list_type_factory(pa.large_binary()))
+ .to_pandas())
+ tm.assert_series_equal(
+ s, pd.Series([[b"aa", b"bb"], None, [b"cc"], []]),
+ check_names=False)
+ s = (pa.array([["aa", "bb"], None, ["cc"], []],
+ type=list_type_factory(pa.large_string()))
+ .to_pandas())
+ tm.assert_series_equal(
+ s, pd.Series([["aa", "bb"], None, ["cc"], []]),
+ check_names=False)
+
+ def test_list_of_dictionary(self):
+ child = pa.array(["foo", "bar", None, "foo"]).dictionary_encode()
+ arr = pa.ListArray.from_arrays([0, 1, 3, 3, 4], child)
+
+ # Expected a Series of lists
+ expected = pd.Series(arr.to_pylist())
+ tm.assert_series_equal(arr.to_pandas(), expected)
+
+ # Same but with nulls
+ arr = arr.take([0, 1, None, 3])
+ expected[2] = None
+ tm.assert_series_equal(arr.to_pandas(), expected)
+
+ @pytest.mark.large_memory
+ def test_auto_chunking_on_list_overflow(self):
+ # ARROW-9976
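+        # (list arrays use 32-bit offsets; 2**21 rows of 2**10 values are
+        # 2**31 child elements, one more than a single chunk can address)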
+ n = 2**21
+ df = pd.DataFrame.from_dict({
+ "a": list(np.zeros((n, 2**10), dtype='uint8')),
+ "b": range(n)
+ })
+ table = pa.Table.from_pandas(df)
+
+ column_a = table[0]
+ assert column_a.num_chunks == 2
+ assert len(column_a.chunk(0)) == 2**21 - 1
+ assert len(column_a.chunk(1)) == 1
+
+ def test_map_array_roundtrip(self):
+ data = [[(b'a', 1), (b'b', 2)],
+ [(b'c', 3)],
+ [(b'd', 4), (b'e', 5), (b'f', 6)],
+ [(b'g', 7)]]
+
+ df = pd.DataFrame({"map": data})
+ schema = pa.schema([("map", pa.map_(pa.binary(), pa.int32()))])
+
+ _check_pandas_roundtrip(df, schema=schema)
+
+ def test_map_array_chunked(self):
+ data1 = [[(b'a', 1), (b'b', 2)],
+ [(b'c', 3)],
+ [(b'd', 4), (b'e', 5), (b'f', 6)],
+ [(b'g', 7)]]
+ data2 = [[(k, v * 2) for k, v in row] for row in data1]
+
+ arr1 = pa.array(data1, type=pa.map_(pa.binary(), pa.int32()))
+ arr2 = pa.array(data2, type=pa.map_(pa.binary(), pa.int32()))
+ arr = pa.chunked_array([arr1, arr2])
+
+ expected = pd.Series(data1 + data2)
+ actual = arr.to_pandas()
+ tm.assert_series_equal(actual, expected, check_names=False)
+
+ def test_map_array_with_nulls(self):
+ data = [[(b'a', 1), (b'b', 2)],
+ None,
+ [(b'd', 4), (b'e', 5), (b'f', None)],
+ [(b'g', 7)]]
+
+ # None value in item array causes upcast to float
+ expected = [[(k, float(v) if v is not None else None) for k, v in row]
+ if row is not None else None for row in data]
+ expected = pd.Series(expected)
+
+ arr = pa.array(data, type=pa.map_(pa.binary(), pa.int32()))
+ actual = arr.to_pandas()
+ tm.assert_series_equal(actual, expected, check_names=False)
+
+ def test_map_array_dictionary_encoded(self):
+ offsets = pa.array([0, 3, 5])
+ items = pa.array(['a', 'b', 'c', 'a', 'd']).dictionary_encode()
+ keys = pa.array(list(range(len(items))))
+ arr = pa.MapArray.from_arrays(offsets, keys, items)
+
+ # Dictionary encoded values converted to dense
+ expected = pd.Series(
+ [[(0, 'a'), (1, 'b'), (2, 'c')], [(3, 'a'), (4, 'd')]])
+
+ actual = arr.to_pandas()
+ tm.assert_series_equal(actual, expected, check_names=False)
+
+
+class TestConvertStructTypes:
+ """
+ Conversion tests for struct types.
+ """
+
+ def test_pandas_roundtrip(self):
+ df = pd.DataFrame({'dicts': [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]})
+
+ expected_schema = pa.schema([
+ ('dicts', pa.struct([('a', pa.int64()), ('b', pa.int64())])),
+ ])
+
+ _check_pandas_roundtrip(df, expected_schema=expected_schema)
+
+ # specifying schema explicitly in from_pandas
+ _check_pandas_roundtrip(
+ df, schema=expected_schema, expected_schema=expected_schema)
+
+ def test_to_pandas(self):
+ ints = pa.array([None, 2, 3], type=pa.int64())
+ strs = pa.array(['a', None, 'c'], type=pa.string())
+ bools = pa.array([True, False, None], type=pa.bool_())
+ arr = pa.StructArray.from_arrays(
+ [ints, strs, bools],
+ ['ints', 'strs', 'bools'])
+
+ expected = pd.Series([
+ {'ints': None, 'strs': 'a', 'bools': True},
+ {'ints': 2, 'strs': None, 'bools': False},
+ {'ints': 3, 'strs': 'c', 'bools': None},
+ ])
+
+ series = pd.Series(arr.to_pandas())
+ tm.assert_series_equal(series, expected)
+
+ def test_to_pandas_multiple_chunks(self):
+ # ARROW-11855
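+        # the allocation bookkeeping below also checks that the conversion
+        # does not leak Arrow memory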
+ gc.collect()
+ bytes_start = pa.total_allocated_bytes()
+ ints1 = pa.array([1], type=pa.int64())
+ ints2 = pa.array([2], type=pa.int64())
+ arr1 = pa.StructArray.from_arrays([ints1], ['ints'])
+ arr2 = pa.StructArray.from_arrays([ints2], ['ints'])
+ arr = pa.chunked_array([arr1, arr2])
+
+ expected = pd.Series([
+ {'ints': 1},
+ {'ints': 2}
+ ])
+
+ series = pd.Series(arr.to_pandas())
+ tm.assert_series_equal(series, expected)
+
+ del series
+ del arr
+ del arr1
+ del arr2
+ del ints1
+ del ints2
+ bytes_end = pa.total_allocated_bytes()
+ assert bytes_end == bytes_start
+
+ def test_from_numpy(self):
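+        # the second field is declared with a (title, name) tuple; only the
+        # field name 'y' is used for the conversion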
+ dt = np.dtype([('x', np.int32),
+ (('y_title', 'y'), np.bool_)])
+ ty = pa.struct([pa.field('x', pa.int32()),
+ pa.field('y', pa.bool_())])
+
+ data = np.array([], dtype=dt)
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == []
+
+ data = np.array([(42, True), (43, False)], dtype=dt)
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == [{'x': 42, 'y': True},
+ {'x': 43, 'y': False}]
+
+ # With mask
+ arr = pa.array(data, mask=np.bool_([False, True]), type=ty)
+ assert arr.to_pylist() == [{'x': 42, 'y': True}, None]
+
+ # Trivial struct type
+ dt = np.dtype([])
+ ty = pa.struct([])
+
+ data = np.array([], dtype=dt)
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == []
+
+ data = np.array([(), ()], dtype=dt)
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == [{}, {}]
+
+ def test_from_numpy_nested(self):
+ # Note: an object field inside a struct
+ dt = np.dtype([('x', np.dtype([('xx', np.int8),
+ ('yy', np.bool_)])),
+ ('y', np.int16),
+ ('z', np.object_)])
+ # Note: itemsize is not a multiple of sizeof(object)
+ assert dt.itemsize == 12
+ ty = pa.struct([pa.field('x', pa.struct([pa.field('xx', pa.int8()),
+ pa.field('yy', pa.bool_())])),
+ pa.field('y', pa.int16()),
+ pa.field('z', pa.string())])
+
+ data = np.array([], dtype=dt)
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == []
+
+ data = np.array([
+ ((1, True), 2, 'foo'),
+ ((3, False), 4, 'bar')], dtype=dt)
+ arr = pa.array(data, type=ty)
+ assert arr.to_pylist() == [
+ {'x': {'xx': 1, 'yy': True}, 'y': 2, 'z': 'foo'},
+ {'x': {'xx': 3, 'yy': False}, 'y': 4, 'z': 'bar'}]
+
+ @pytest.mark.slow
+ @pytest.mark.large_memory
+ def test_from_numpy_large(self):
+ # Exercise rechunking + nulls
+        target_size = 3 * 1024**3  # ~3 GiB
+ dt = np.dtype([('x', np.float64), ('y', 'object')])
+ bs = 65536 - dt.itemsize
+ block = b'.' * bs
+ n = target_size // (bs + dt.itemsize)
+ data = np.zeros(n, dtype=dt)
+ data['x'] = np.random.random_sample(n)
+ data['y'] = block
+ # Add implicit nulls
+ data['x'][data['x'] < 0.2] = np.nan
+
+ ty = pa.struct([pa.field('x', pa.float64()),
+ pa.field('y', pa.binary())])
+ arr = pa.array(data, type=ty, from_pandas=True)
+ assert arr.num_chunks == 2
+
+ def iter_chunked_array(arr):
+ for chunk in arr.iterchunks():
+ yield from chunk
+
+ def check(arr, data, mask=None):
+ assert len(arr) == len(data)
+ xs = data['x']
+ ys = data['y']
+ for i, obj in enumerate(iter_chunked_array(arr)):
+ try:
+ d = obj.as_py()
+ if mask is not None and mask[i]:
+ assert d is None
+ else:
+ x = xs[i]
+ if np.isnan(x):
+ assert d['x'] is None
+ else:
+ assert d['x'] == x
+ assert d['y'] == ys[i]
+ except Exception:
+ print("Failed at index", i)
+ raise
+
+ check(arr, data)
+ del arr
+
+ # Now with explicit mask
+ mask = np.random.random_sample(n) < 0.2
+ arr = pa.array(data, type=ty, mask=mask, from_pandas=True)
+ assert arr.num_chunks == 2
+
+ check(arr, data, mask)
+ del arr
+
+ def test_from_numpy_bad_input(self):
+ ty = pa.struct([pa.field('x', pa.int32()),
+ pa.field('y', pa.bool_())])
+ dt = np.dtype([('x', np.int32),
+ ('z', np.bool_)])
+
+ data = np.array([], dtype=dt)
+ with pytest.raises(ValueError,
+ match="Missing field 'y'"):
+ pa.array(data, type=ty)
+ data = np.int32([])
+ with pytest.raises(TypeError,
+ match="Expected struct array"):
+ pa.array(data, type=ty)
+
+ def test_from_tuples(self):
+ df = pd.DataFrame({'tuples': [(1, 2), (3, 4)]})
+ expected_df = pd.DataFrame(
+ {'tuples': [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]})
+
+ # conversion from tuples works when specifying expected struct type
+ struct_type = pa.struct([('a', pa.int64()), ('b', pa.int64())])
+
+ arr = np.asarray(df['tuples'])
+ _check_array_roundtrip(
+ arr, expected=expected_df['tuples'], type=struct_type)
+
+ expected_schema = pa.schema([('tuples', struct_type)])
+ _check_pandas_roundtrip(
+ df, expected=expected_df, schema=expected_schema,
+ expected_schema=expected_schema)
+
+ def test_struct_of_dictionary(self):
+ names = ['ints', 'strs']
+ children = [pa.array([456, 789, 456]).dictionary_encode(),
+ pa.array(["foo", "foo", None]).dictionary_encode()]
+ arr = pa.StructArray.from_arrays(children, names=names)
+
+ # Expected a Series of {field name: field value} dicts
+ rows_as_tuples = zip(*(child.to_pylist() for child in children))
+ rows_as_dicts = [dict(zip(names, row)) for row in rows_as_tuples]
+
+ expected = pd.Series(rows_as_dicts)
+ tm.assert_series_equal(arr.to_pandas(), expected)
+
+ # Same but with nulls
+ arr = arr.take([0, None, 2])
+ expected[1] = None
+ tm.assert_series_equal(arr.to_pandas(), expected)
+
+
+class TestZeroCopyConversion:
+ """
+ Tests that zero-copy conversion works with some types.
+ """
+
+ def test_zero_copy_success(self):
+ result = pa.array([0, 1, 2]).to_pandas(zero_copy_only=True)
+ npt.assert_array_equal(result, [0, 1, 2])
+
+ def test_zero_copy_dictionaries(self):
+ arr = pa.DictionaryArray.from_arrays(
+ np.array([0, 0]),
+ np.array([5]))
+
+ result = arr.to_pandas(zero_copy_only=True)
+ values = pd.Categorical([5, 5])
+
+ tm.assert_series_equal(pd.Series(result), pd.Series(values),
+ check_names=False)
+
+ def test_zero_copy_timestamp(self):
+ arr = np.array(['2007-07-13'], dtype='datetime64[ns]')
+ result = pa.array(arr).to_pandas(zero_copy_only=True)
+ npt.assert_array_equal(result, arr)
+
+ def test_zero_copy_duration(self):
+ arr = np.array([1], dtype='timedelta64[ns]')
+ result = pa.array(arr).to_pandas(zero_copy_only=True)
+ npt.assert_array_equal(result, arr)
+
+ def check_zero_copy_failure(self, arr):
+ with pytest.raises(pa.ArrowInvalid):
+ arr.to_pandas(zero_copy_only=True)
+
+ def test_zero_copy_failure_on_object_types(self):
+ self.check_zero_copy_failure(pa.array(['A', 'B', 'C']))
+
+ def test_zero_copy_failure_with_int_when_nulls(self):
+ self.check_zero_copy_failure(pa.array([0, 1, None]))
+
+ def test_zero_copy_failure_with_float_when_nulls(self):
+ self.check_zero_copy_failure(pa.array([0.0, 1.0, None]))
+
+ def test_zero_copy_failure_on_bool_types(self):
+ self.check_zero_copy_failure(pa.array([True, False]))
+
+ def test_zero_copy_failure_on_list_types(self):
+ arr = pa.array([[1, 2], [8, 9]], type=pa.list_(pa.int64()))
+ self.check_zero_copy_failure(arr)
+
+ def test_zero_copy_failure_on_timestamp_with_nulls(self):
+ arr = np.array([1, None], dtype='datetime64[ns]')
+ self.check_zero_copy_failure(pa.array(arr))
+
+ def test_zero_copy_failure_on_duration_with_nulls(self):
+ arr = np.array([1, None], dtype='timedelta64[ns]')
+ self.check_zero_copy_failure(pa.array(arr))
+
+
+def _non_threaded_conversion():
+ df = _alltypes_example()
+ _check_pandas_roundtrip(df, use_threads=False)
+ _check_pandas_roundtrip(df, use_threads=False, as_batch=True)
+
+
+def _threaded_conversion():
+ df = _alltypes_example()
+ _check_pandas_roundtrip(df, use_threads=True)
+ _check_pandas_roundtrip(df, use_threads=True, as_batch=True)
+
+
+class TestConvertMisc:
+ """
+ Miscellaneous conversion tests.
+ """
+
+ type_pairs = [
+ (np.int8, pa.int8()),
+ (np.int16, pa.int16()),
+ (np.int32, pa.int32()),
+ (np.int64, pa.int64()),
+ (np.uint8, pa.uint8()),
+ (np.uint16, pa.uint16()),
+ (np.uint32, pa.uint32()),
+ (np.uint64, pa.uint64()),
+ (np.float16, pa.float16()),
+ (np.float32, pa.float32()),
+ (np.float64, pa.float64()),
+ # XXX unsupported
+ # (np.dtype([('a', 'i2')]), pa.struct([pa.field('a', pa.int16())])),
+ (np.object_, pa.string()),
+ (np.object_, pa.binary()),
+ (np.object_, pa.binary(10)),
+ (np.object_, pa.list_(pa.int64())),
+ ]
+
+ def test_all_none_objects(self):
+ df = pd.DataFrame({'a': [None, None, None]})
+ _check_pandas_roundtrip(df)
+
+ def test_all_none_category(self):
+ df = pd.DataFrame({'a': [None, None, None]})
+ df['a'] = df['a'].astype('category')
+ _check_pandas_roundtrip(df)
+
+ def test_empty_arrays(self):
+ for dtype, pa_type in self.type_pairs:
+ arr = np.array([], dtype=dtype)
+ _check_array_roundtrip(arr, type=pa_type)
+
+ def test_non_threaded_conversion(self):
+ _non_threaded_conversion()
+
+ def test_threaded_conversion_multiprocess(self):
+ # Parallel conversion should work from child processes too (ARROW-2963)
+ pool = mp.Pool(2)
+ try:
+ pool.apply(_threaded_conversion)
+ finally:
+ pool.close()
+ pool.join()
+
+ def test_category(self):
+ repeats = 5
+ v1 = ['foo', None, 'bar', 'qux', np.nan]
+ v2 = [4, 5, 6, 7, 8]
+ v3 = [b'foo', None, b'bar', b'qux', np.nan]
+
+ arrays = {
+ 'cat_strings': pd.Categorical(v1 * repeats),
+ 'cat_strings_with_na': pd.Categorical(v1 * repeats,
+ categories=['foo', 'bar']),
+ 'cat_ints': pd.Categorical(v2 * repeats),
+ 'cat_binary': pd.Categorical(v3 * repeats),
+ 'cat_strings_ordered': pd.Categorical(
+ v1 * repeats, categories=['bar', 'qux', 'foo'],
+ ordered=True),
+ 'ints': v2 * repeats,
+ 'ints2': v2 * repeats,
+ 'strings': v1 * repeats,
+ 'strings2': v1 * repeats,
+ 'strings3': v3 * repeats}
+ df = pd.DataFrame(arrays)
+ _check_pandas_roundtrip(df)
+
+ for k in arrays:
+ _check_array_roundtrip(arrays[k])
+
+ def test_category_implicit_from_pandas(self):
+ # ARROW-3374
+ def _check(v):
+ arr = pa.array(v)
+ result = arr.to_pandas()
+ tm.assert_series_equal(pd.Series(result), pd.Series(v))
+
+ arrays = [
+ pd.Categorical(['a', 'b', 'c'], categories=['a', 'b']),
+ pd.Categorical(['a', 'b', 'c'], categories=['a', 'b'],
+ ordered=True)
+ ]
+ for arr in arrays:
+ _check(arr)
+
+ def test_empty_category(self):
+ # ARROW-2443
+ df = pd.DataFrame({'cat': pd.Categorical([])})
+ _check_pandas_roundtrip(df)
+
+ def test_category_zero_chunks(self):
+ # ARROW-5952
+ for pa_type, dtype in [(pa.string(), 'object'), (pa.int64(), 'int64')]:
+ a = pa.chunked_array([], pa.dictionary(pa.int8(), pa_type))
+ result = a.to_pandas()
+ expected = pd.Categorical([], categories=np.array([], dtype=dtype))
+ tm.assert_series_equal(pd.Series(result), pd.Series(expected))
+
+ table = pa.table({'a': a})
+ result = table.to_pandas()
+ expected = pd.DataFrame({'a': expected})
+ tm.assert_frame_equal(result, expected)
+
+ @pytest.mark.parametrize(
+ "data,error_type",
+ [
+ ({"a": ["a", 1, 2.0]}, pa.ArrowTypeError),
+ ({"a": ["a", 1, 2.0]}, pa.ArrowTypeError),
+ ({"a": [1, True]}, pa.ArrowTypeError),
+ ({"a": [True, "a"]}, pa.ArrowInvalid),
+ ({"a": [1, "a"]}, pa.ArrowInvalid),
+ ({"a": [1.0, "a"]}, pa.ArrowInvalid),
+ ],
+ )
+ def test_mixed_types_fails(self, data, error_type):
+ df = pd.DataFrame(data)
+ msg = "Conversion failed for column a with type object"
+ with pytest.raises(error_type, match=msg):
+ pa.Table.from_pandas(df)
+
+ def test_strided_data_import(self):
+ cases = []
+
+ columns = ['a', 'b', 'c']
+ N, K = 100, 3
+ random_numbers = np.random.randn(N, K).copy() * 100
+
+ numeric_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8',
+ 'f4', 'f8']
+
+ for type_name in numeric_dtypes:
+ cases.append(random_numbers.astype(type_name))
+
+ # strings
+ cases.append(np.array([random_ascii(10) for i in range(N * K)],
+ dtype=object)
+ .reshape(N, K).copy())
+
+ # booleans
+ boolean_objects = (np.array([True, False, True] * N, dtype=object)
+ .reshape(N, K).copy())
+
+        # add some nulls, so the dtype comes back as object
+ boolean_objects[5] = None
+ cases.append(boolean_objects)
+
+ cases.append(np.arange("2016-01-01T00:00:00.001", N * K,
+ dtype='datetime64[ms]')
+ .reshape(N, K).copy())
+
+ strided_mask = (random_numbers > 0).astype(bool)[:, 0]
+
+ for case in cases:
+ df = pd.DataFrame(case, columns=columns)
+ col = df['a']
+
+ _check_pandas_roundtrip(df)
+ _check_array_roundtrip(col)
+ _check_array_roundtrip(col, mask=strided_mask)
+
+ def test_all_nones(self):
+ def _check_series(s):
+ converted = pa.array(s)
+ assert isinstance(converted, pa.NullArray)
+ assert len(converted) == 3
+ assert converted.null_count == 3
+ for item in converted:
+ assert item is pa.NA
+
+ _check_series(pd.Series([None] * 3, dtype=object))
+ _check_series(pd.Series([np.nan] * 3, dtype=object))
+ _check_series(pd.Series([None, np.nan, None], dtype=object))
+
+ def test_partial_schema(self):
+ data = OrderedDict([
+ ('a', [0, 1, 2, 3, 4]),
+ ('b', np.array([-10, -5, 0, 5, 10], dtype=np.int32)),
+ ('c', [-10, -5, 0, 5, 10])
+ ])
+ df = pd.DataFrame(data)
+
+ partial_schema = pa.schema([
+ pa.field('c', pa.int64()),
+ pa.field('a', pa.int64())
+ ])
+
+ _check_pandas_roundtrip(df, schema=partial_schema,
+ expected=df[['c', 'a']],
+ expected_schema=partial_schema)
+
+ def test_table_batch_empty_dataframe(self):
+ df = pd.DataFrame({})
+ _check_pandas_roundtrip(df)
+ _check_pandas_roundtrip(df, as_batch=True)
+
+ df2 = pd.DataFrame({}, index=[0, 1, 2])
+ _check_pandas_roundtrip(df2, preserve_index=True)
+ _check_pandas_roundtrip(df2, as_batch=True, preserve_index=True)
+
+ def test_convert_empty_table(self):
+ arr = pa.array([], type=pa.int64())
+ empty_objects = pd.Series(np.array([], dtype=object))
+ tm.assert_series_equal(arr.to_pandas(),
+ pd.Series(np.array([], dtype=np.int64)))
+ arr = pa.array([], type=pa.string())
+ tm.assert_series_equal(arr.to_pandas(), empty_objects)
+ arr = pa.array([], type=pa.list_(pa.int64()))
+ tm.assert_series_equal(arr.to_pandas(), empty_objects)
+ arr = pa.array([], type=pa.struct([pa.field('a', pa.int64())]))
+ tm.assert_series_equal(arr.to_pandas(), empty_objects)
+
+ def test_non_natural_stride(self):
+ """
+ ARROW-2172: converting from a Numpy array with a stride that's
+ not a multiple of itemsize.
+ """
+ dtype = np.dtype([('x', np.int32), ('y', np.int16)])
+ data = np.array([(42, -1), (-43, 2)], dtype=dtype)
+ assert data.strides == (6,)
+ arr = pa.array(data['x'], type=pa.int32())
+ assert arr.to_pylist() == [42, -43]
+ arr = pa.array(data['y'], type=pa.int16())
+ assert arr.to_pylist() == [-1, 2]
+
+ def test_array_from_strided_numpy_array(self):
+ # ARROW-5651
+ np_arr = np.arange(0, 10, dtype=np.float32)[1:-1:2]
+ pa_arr = pa.array(np_arr, type=pa.float64())
+ expected = pa.array([1.0, 3.0, 5.0, 7.0], type=pa.float64())
+ assert pa_arr.equals(expected)
+
+ def test_safe_unsafe_casts(self):
+ # ARROW-2799
+ df = pd.DataFrame({
+ 'A': list('abc'),
+ 'B': np.linspace(0, 1, 3)
+ })
+
+ schema = pa.schema([
+ pa.field('A', pa.string()),
+ pa.field('B', pa.int32())
+ ])
+
+ with pytest.raises(ValueError):
+ pa.Table.from_pandas(df, schema=schema)
+
+ table = pa.Table.from_pandas(df, schema=schema, safe=False)
+ assert table.column('B').type == pa.int32()
+
+ def test_error_sparse(self):
+ # ARROW-2818
+ try:
+ df = pd.DataFrame({'a': pd.arrays.SparseArray([1, np.nan, 3])})
+ except AttributeError:
+ # pandas.arrays module introduced in pandas 0.24
+ df = pd.DataFrame({'a': pd.SparseArray([1, np.nan, 3])})
+ with pytest.raises(TypeError, match="Sparse pandas data"):
+ pa.Table.from_pandas(df)
+
+
+def test_safe_cast_from_float_with_nans_to_int():
+ # TODO(kszucs): write tests for creating Date32 and Date64 arrays, see
+ # ARROW-4258 and https://github.com/apache/arrow/pull/3395
+ values = pd.Series([1, 2, None, 4])
+ arr = pa.Array.from_pandas(values, type=pa.int32(), safe=True)
+ expected = pa.array([1, 2, None, 4], type=pa.int32())
+ assert arr.equals(expected)
+
+
+def _fully_loaded_dataframe_example():
+ index = pd.MultiIndex.from_arrays([
+ pd.date_range('2000-01-01', periods=5).repeat(2),
+ np.tile(np.array(['foo', 'bar'], dtype=object), 5)
+ ])
+
+ c1 = pd.date_range('2000-01-01', periods=10)
+ data = {
+ 0: c1,
+ 1: c1.tz_localize('utc'),
+ 2: c1.tz_localize('US/Eastern'),
+ 3: c1[::2].tz_localize('utc').repeat(2).astype('category'),
+ 4: ['foo', 'bar'] * 5,
+ 5: pd.Series(['foo', 'bar'] * 5).astype('category').values,
+ 6: [True, False] * 5,
+ 7: np.random.randn(10),
+ 8: np.random.randint(0, 100, size=10),
+ 9: pd.period_range('2013', periods=10, freq='M')
+ }
+
+ if Version(pd.__version__) >= Version('0.21'):
+ # There is an issue with pickling IntervalIndex in pandas 0.20.x
+ data[10] = pd.interval_range(start=1, freq=1, periods=10)
+
+ return pd.DataFrame(data, index=index)
+
+
+@pytest.mark.parametrize('columns', ([b'foo'], ['foo']))
+def test_roundtrip_with_bytes_unicode(columns):
+ df = pd.DataFrame(columns=columns)
+ table1 = pa.Table.from_pandas(df)
+ table2 = pa.Table.from_pandas(table1.to_pandas())
+ assert table1.equals(table2)
+ assert table1.schema.equals(table2.schema)
+ assert table1.schema.metadata == table2.schema.metadata
+
+
+def _check_serialize_components_roundtrip(pd_obj):
+ with pytest.warns(FutureWarning):
+ ctx = pa.default_serialization_context()
+
+ with pytest.warns(FutureWarning):
+ components = ctx.serialize(pd_obj).to_components()
+ with pytest.warns(FutureWarning):
+ deserialized = ctx.deserialize_components(components)
+
+ if isinstance(pd_obj, pd.DataFrame):
+ tm.assert_frame_equal(pd_obj, deserialized)
+ else:
+ tm.assert_series_equal(pd_obj, deserialized)
+
+
+@pytest.mark.skipif(
+ Version('1.16.0') <= Version(np.__version__) < Version('1.16.1'),
+ reason='Until numpy/numpy#12745 is resolved')
+def test_serialize_deserialize_pandas():
+ # ARROW-1784, serialize and deserialize DataFrame by decomposing
+ # BlockManager
+ df = _fully_loaded_dataframe_example()
+ _check_serialize_components_roundtrip(df)
+
+
+def test_serialize_deserialize_empty_pandas():
+ # ARROW-7996, serialize and deserialize empty pandas objects
+ df = pd.DataFrame({'col1': [], 'col2': [], 'col3': []})
+ _check_serialize_components_roundtrip(df)
+
+ series = pd.Series([], dtype=np.float32, name='col')
+ _check_serialize_components_roundtrip(series)
+
+
+def _pytime_from_micros(val):
+ microseconds = val % 1000000
+ val //= 1000000
+ seconds = val % 60
+ val //= 60
+ minutes = val % 60
+ hours = val // 60
+ return time(hours, minutes, seconds, microseconds)
+
+
+def _pytime_to_micros(pytime):
+ return (pytime.hour * 3600000000 +
+ pytime.minute * 60000000 +
+ pytime.second * 1000000 +
+ pytime.microsecond)
+
+
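+# Illustrative sanity sketch for the two helpers above (underscore prefixed,
+# so it is not collected by pytest): a microsecond-of-day count should survive
+# the round trip through datetime.time exactly.
+def _pytime_roundtrip_sketch():
+    micros = (13 * 3600 + 37 * 60 + 42) * 1000000 + 123456  # 13:37:42.123456
+    assert _pytime_from_micros(micros) == time(13, 37, 42, 123456)
+    assert _pytime_to_micros(_pytime_from_micros(micros)) == micros
+
+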
+def test_convert_unsupported_type_error_message():
+ # ARROW-1454
+
+ # custom python objects
+ class A:
+ pass
+
+ df = pd.DataFrame({'a': [A(), A()]})
+
+ msg = 'Conversion failed for column a with type object'
+ with pytest.raises(ValueError, match=msg):
+ pa.Table.from_pandas(df)
+
+ # period unsupported for pandas <= 0.25
+ if Version(pd.__version__) <= Version('0.25'):
+ df = pd.DataFrame({
+ 'a': pd.period_range('2000-01-01', periods=20),
+ })
+
+ msg = 'Conversion failed for column a with type (period|object)'
+ with pytest.raises((TypeError, ValueError), match=msg):
+ pa.Table.from_pandas(df)
+
+
+# ----------------------------------------------------------------------
+# Hypothesis tests
+
+
+@h.given(past.arrays(past.pandas_compatible_types))
+def test_array_to_pandas_roundtrip(arr):
+ s = arr.to_pandas()
+ restored = pa.array(s, type=arr.type, from_pandas=True)
+ assert restored.equals(arr)
+
+
+# ----------------------------------------------------------------------
+# Test object deduplication in to_pandas
+
+
+def _generate_dedup_example(nunique, repeats):
+ unique_values = [rands(10) for i in range(nunique)]
+ return unique_values * repeats
+
+
+def _assert_nunique(obj, expected):
+ assert len({id(x) for x in obj}) == expected
+
+
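+# Minimal illustration of the behaviour the tests below check (underscore
+# prefixed, so not collected by pytest): with the default
+# deduplicate_objects=True, repeated values convert to the same Python object,
+# while disabling it produces one object per row.
+def _deduplication_sketch():
+    arr = pa.array(['spam-and-eggs'] * 3)
+    deduped = arr.to_pandas()
+    assert len({id(x) for x in deduped}) == 1
+    plain = arr.to_pandas(deduplicate_objects=False)
+    assert len({id(x) for x in plain}) == 3
+
+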
+def test_to_pandas_deduplicate_strings_array_types():
+ nunique = 100
+ repeats = 10
+ values = _generate_dedup_example(nunique, repeats)
+
+ for arr in [pa.array(values, type=pa.binary()),
+ pa.array(values, type=pa.utf8()),
+ pa.chunked_array([values, values])]:
+ _assert_nunique(arr.to_pandas(), nunique)
+ _assert_nunique(arr.to_pandas(deduplicate_objects=False), len(arr))
+
+
+def test_to_pandas_deduplicate_strings_table_types():
+ nunique = 100
+ repeats = 10
+ values = _generate_dedup_example(nunique, repeats)
+
+ arr = pa.array(values)
+ rb = pa.RecordBatch.from_arrays([arr], ['foo'])
+ tbl = pa.Table.from_batches([rb])
+
+ for obj in [rb, tbl]:
+ _assert_nunique(obj.to_pandas()['foo'], nunique)
+ _assert_nunique(obj.to_pandas(deduplicate_objects=False)['foo'],
+ len(obj))
+
+
+def test_to_pandas_deduplicate_integers_as_objects():
+ nunique = 100
+ repeats = 10
+
+ # Python automatically caches small integers, so use large ones here
+ unique_values = list(np.random.randint(10000000, 1000000000, size=nunique))
+ unique_values[nunique // 2] = None
+
+ arr = pa.array(unique_values * repeats)
+
+ _assert_nunique(arr.to_pandas(integer_object_nulls=True), nunique)
+ _assert_nunique(arr.to_pandas(integer_object_nulls=True,
+ deduplicate_objects=False),
+ # Account for None
+ (nunique - 1) * repeats + 1)
+
+
+def test_to_pandas_deduplicate_date_time():
+ nunique = 100
+ repeats = 10
+
+ unique_values = list(range(nunique))
+
+ cases = [
+ # raw type, array type, to_pandas options
+ ('int32', 'date32', {'date_as_object': True}),
+ ('int64', 'date64', {'date_as_object': True}),
+ ('int32', 'time32[ms]', {}),
+ ('int64', 'time64[us]', {})
+ ]
+
+ for raw_type, array_type, pandas_options in cases:
+ raw_arr = pa.array(unique_values * repeats, type=raw_type)
+ casted_arr = raw_arr.cast(array_type)
+
+ _assert_nunique(casted_arr.to_pandas(**pandas_options),
+ nunique)
+ _assert_nunique(casted_arr.to_pandas(deduplicate_objects=False,
+ **pandas_options),
+ len(casted_arr))
+
+
+# ---------------------------------------------------------------------
+
+def test_table_from_pandas_checks_field_nullability():
+ # ARROW-2136
+ df = pd.DataFrame({'a': [1.2, 2.1, 3.1],
+ 'b': [np.nan, 'string', 'foo']})
+ schema = pa.schema([pa.field('a', pa.float64(), nullable=False),
+ pa.field('b', pa.utf8(), nullable=False)])
+
+ with pytest.raises(ValueError):
+ pa.Table.from_pandas(df, schema=schema)
+
+
+def test_table_from_pandas_keeps_column_order_of_dataframe():
+ df1 = pd.DataFrame(OrderedDict([
+ ('partition', [0, 0, 1, 1]),
+ ('arrays', [[0, 1, 2], [3, 4], None, None]),
+ ('floats', [None, None, 1.1, 3.3])
+ ]))
+ df2 = df1[['floats', 'partition', 'arrays']]
+
+ schema1 = pa.schema([
+ ('partition', pa.int64()),
+ ('arrays', pa.list_(pa.int64())),
+ ('floats', pa.float64()),
+ ])
+ schema2 = pa.schema([
+ ('floats', pa.float64()),
+ ('partition', pa.int64()),
+ ('arrays', pa.list_(pa.int64()))
+ ])
+
+ table1 = pa.Table.from_pandas(df1, preserve_index=False)
+ table2 = pa.Table.from_pandas(df2, preserve_index=False)
+
+ assert table1.schema.equals(schema1)
+ assert table2.schema.equals(schema2)
+
+
+def test_table_from_pandas_keeps_column_order_of_schema():
+ # ARROW-3766
+ df = pd.DataFrame(OrderedDict([
+ ('partition', [0, 0, 1, 1]),
+ ('arrays', [[0, 1, 2], [3, 4], None, None]),
+ ('floats', [None, None, 1.1, 3.3])
+ ]))
+
+ schema = pa.schema([
+ ('floats', pa.float64()),
+ ('arrays', pa.list_(pa.int32())),
+ ('partition', pa.int32())
+ ])
+
+ df1 = df[df.partition == 0]
+ df2 = df[df.partition == 1][['floats', 'partition', 'arrays']]
+
+ table1 = pa.Table.from_pandas(df1, schema=schema, preserve_index=False)
+ table2 = pa.Table.from_pandas(df2, schema=schema, preserve_index=False)
+
+ assert table1.schema.equals(schema)
+ assert table1.schema.equals(table2.schema)
+
+
+def test_table_from_pandas_columns_argument_only_does_filtering():
+ df = pd.DataFrame(OrderedDict([
+ ('partition', [0, 0, 1, 1]),
+ ('arrays', [[0, 1, 2], [3, 4], None, None]),
+ ('floats', [None, None, 1.1, 3.3])
+ ]))
+
+ columns1 = ['arrays', 'floats', 'partition']
+ schema1 = pa.schema([
+ ('arrays', pa.list_(pa.int64())),
+ ('floats', pa.float64()),
+ ('partition', pa.int64())
+ ])
+
+ columns2 = ['floats', 'partition']
+ schema2 = pa.schema([
+ ('floats', pa.float64()),
+ ('partition', pa.int64())
+ ])
+
+ table1 = pa.Table.from_pandas(df, columns=columns1, preserve_index=False)
+ table2 = pa.Table.from_pandas(df, columns=columns2, preserve_index=False)
+
+ assert table1.schema.equals(schema1)
+ assert table2.schema.equals(schema2)
+
+
+def test_table_from_pandas_columns_and_schema_are_mutually_exclusive():
+ df = pd.DataFrame(OrderedDict([
+ ('partition', [0, 0, 1, 1]),
+ ('arrays', [[0, 1, 2], [3, 4], None, None]),
+ ('floats', [None, None, 1.1, 3.3])
+ ]))
+ schema = pa.schema([
+ ('partition', pa.int32()),
+ ('arrays', pa.list_(pa.int32())),
+ ('floats', pa.float64()),
+ ])
+ columns = ['arrays', 'floats']
+
+ with pytest.raises(ValueError):
+ pa.Table.from_pandas(df, schema=schema, columns=columns)
+
+
+def test_table_from_pandas_keeps_schema_nullability():
+ # ARROW-5169
+ df = pd.DataFrame({'a': [1, 2, 3, 4]})
+
+ schema = pa.schema([
+ pa.field('a', pa.int64(), nullable=False),
+ ])
+
+ table = pa.Table.from_pandas(df)
+ assert table.schema.field('a').nullable is True
+ table = pa.Table.from_pandas(df, schema=schema)
+ assert table.schema.field('a').nullable is False
+
+
+def test_table_from_pandas_schema_index_columns():
+ # ARROW-5220
+ df = pd.DataFrame({'a': [1, 2, 3], 'b': [0.1, 0.2, 0.3]})
+
+ schema = pa.schema([
+ ('a', pa.int64()),
+ ('b', pa.float64()),
+ ('index', pa.int32()),
+ ])
+
+ # schema includes index with name not in dataframe
+ with pytest.raises(KeyError, match="name 'index' present in the"):
+ pa.Table.from_pandas(df, schema=schema)
+
+ df.index.name = 'index'
+
+ # schema includes correct index name -> roundtrip works
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=True,
+ expected_schema=schema)
+
+ # schema includes correct index name but preserve_index=False
+ with pytest.raises(ValueError, match="'preserve_index=False' was"):
+ pa.Table.from_pandas(df, schema=schema, preserve_index=False)
+
+ # in case of preserve_index=None -> RangeIndex serialized as metadata
+ # clashes with the index in the schema
+ with pytest.raises(ValueError, match="name 'index' is present in the "
+ "schema, but it is a RangeIndex"):
+ pa.Table.from_pandas(df, schema=schema, preserve_index=None)
+
+ df.index = pd.Index([0, 1, 2], name='index')
+
+ # for non-RangeIndex, both preserve_index=None and True work
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=None,
+ expected_schema=schema)
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=True,
+ expected_schema=schema)
+
+ # schema has different order (index column not at the end)
+ schema = pa.schema([
+ ('index', pa.int32()),
+ ('a', pa.int64()),
+ ('b', pa.float64()),
+ ])
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=None,
+ expected_schema=schema)
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=True,
+ expected_schema=schema)
+
+ # schema does not include the index -> index is not included as column
+ # even though preserve_index=True/None
+ schema = pa.schema([
+ ('a', pa.int64()),
+ ('b', pa.float64()),
+ ])
+ expected = df.copy()
+ expected = expected.reset_index(drop=True)
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=None,
+ expected_schema=schema, expected=expected)
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=True,
+ expected_schema=schema, expected=expected)
+
+ # dataframe with a MultiIndex
+ df.index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)],
+ names=['level1', 'level2'])
+ schema = pa.schema([
+ ('level1', pa.string()),
+ ('level2', pa.int64()),
+ ('a', pa.int64()),
+ ('b', pa.float64()),
+ ])
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=True,
+ expected_schema=schema)
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=None,
+ expected_schema=schema)
+
+ # only one of the levels of the MultiIndex is included
+ schema = pa.schema([
+ ('level2', pa.int64()),
+ ('a', pa.int64()),
+ ('b', pa.float64()),
+ ])
+ expected = df.copy()
+ expected = expected.reset_index('level1', drop=True)
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=True,
+ expected_schema=schema, expected=expected)
+ _check_pandas_roundtrip(df, schema=schema, preserve_index=None,
+ expected_schema=schema, expected=expected)
+
+
+def test_table_from_pandas_schema_index_columns__unnamed_index():
+ # ARROW-6999 - unnamed indices in specified schema
+ df = pd.DataFrame({'a': [1, 2, 3], 'b': [0.1, 0.2, 0.3]})
+
+ expected_schema = pa.schema([
+ ('a', pa.int64()),
+ ('b', pa.float64()),
+ ('__index_level_0__', pa.int64()),
+ ])
+
+ schema = pa.Schema.from_pandas(df, preserve_index=True)
+ table = pa.Table.from_pandas(df, preserve_index=True, schema=schema)
+ assert table.schema.remove_metadata().equals(expected_schema)
+
+ # non-RangeIndex (preserved by default)
+ df = pd.DataFrame({'a': [1, 2, 3], 'b': [0.1, 0.2, 0.3]}, index=[0, 1, 2])
+ schema = pa.Schema.from_pandas(df)
+ table = pa.Table.from_pandas(df, schema=schema)
+ assert table.schema.remove_metadata().equals(expected_schema)
+
+
+def test_table_from_pandas_schema_with_custom_metadata():
+ # ARROW-7087 - metadata disappear from pandas
+ df = pd.DataFrame()
+ schema = pa.Schema.from_pandas(df).with_metadata({'meta': 'True'})
+ table = pa.Table.from_pandas(df, schema=schema)
+ assert table.schema.metadata.get(b'meta') == b'True'
+
+
+def test_table_from_pandas_schema_field_order_metadata():
+ # ARROW-10532
+ # ensure that a different field order in specified schema doesn't
+ # mangle metadata
+ df = pd.DataFrame({
+ "datetime": pd.date_range("2020-01-01T00:00:00Z", freq="H", periods=2),
+ "float": np.random.randn(2)
+ })
+
+ schema = pa.schema([
+ pa.field("float", pa.float32(), nullable=True),
+ pa.field("datetime", pa.timestamp("s", tz="UTC"), nullable=False)
+ ])
+
+ table = pa.Table.from_pandas(df, schema=schema)
+ assert table.schema.equals(schema)
+ metadata_float = table.schema.pandas_metadata["columns"][0]
+ assert metadata_float["name"] == "float"
+ assert metadata_float["metadata"] is None
+ metadata_datetime = table.schema.pandas_metadata["columns"][1]
+ assert metadata_datetime["name"] == "datetime"
+ assert metadata_datetime["metadata"] == {'timezone': 'UTC'}
+
+ result = table.to_pandas()
+ expected = df[["float", "datetime"]].astype({"float": "float32"})
+ tm.assert_frame_equal(result, expected)
+
+
+# ----------------------------------------------------------------------
+# RecordBatch, Table
+
+
+def test_recordbatch_from_to_pandas():
+ data = pd.DataFrame({
+ 'c1': np.array([1, 2, 3, 4, 5], dtype='int64'),
+ 'c2': np.array([1, 2, 3, 4, 5], dtype='uint32'),
+ 'c3': np.random.randn(5),
+ 'c4': ['foo', 'bar', None, 'baz', 'qux'],
+ 'c5': [False, True, False, True, False]
+ })
+
+ batch = pa.RecordBatch.from_pandas(data)
+ result = batch.to_pandas()
+ tm.assert_frame_equal(data, result)
+
+
+def test_recordbatchlist_to_pandas():
+ data1 = pd.DataFrame({
+ 'c1': np.array([1, 1, 2], dtype='uint32'),
+ 'c2': np.array([1.0, 2.0, 3.0], dtype='float64'),
+ 'c3': [True, None, False],
+ 'c4': ['foo', 'bar', None]
+ })
+
+ data2 = pd.DataFrame({
+ 'c1': np.array([3, 5], dtype='uint32'),
+ 'c2': np.array([4.0, 5.0], dtype='float64'),
+ 'c3': [True, True],
+ 'c4': ['baz', 'qux']
+ })
+
+ batch1 = pa.RecordBatch.from_pandas(data1)
+ batch2 = pa.RecordBatch.from_pandas(data2)
+
+ table = pa.Table.from_batches([batch1, batch2])
+ result = table.to_pandas()
+ data = pd.concat([data1, data2]).reset_index(drop=True)
+ tm.assert_frame_equal(data, result)
+
+
+def test_recordbatch_table_pass_name_to_pandas():
+ rb = pa.record_batch([pa.array([1, 2, 3, 4])], names=['a0'])
+ t = pa.table([pa.array([1, 2, 3, 4])], names=['a0'])
+ assert rb[0].to_pandas().name == 'a0'
+ assert t[0].to_pandas().name == 'a0'
+
+
+# ----------------------------------------------------------------------
+# Metadata serialization
+
+
+@pytest.mark.parametrize(
+ ('type', 'expected'),
+ [
+ (pa.null(), 'empty'),
+ (pa.bool_(), 'bool'),
+ (pa.int8(), 'int8'),
+ (pa.int16(), 'int16'),
+ (pa.int32(), 'int32'),
+ (pa.int64(), 'int64'),
+ (pa.uint8(), 'uint8'),
+ (pa.uint16(), 'uint16'),
+ (pa.uint32(), 'uint32'),
+ (pa.uint64(), 'uint64'),
+ (pa.float16(), 'float16'),
+ (pa.float32(), 'float32'),
+ (pa.float64(), 'float64'),
+ (pa.date32(), 'date'),
+ (pa.date64(), 'date'),
+ (pa.binary(), 'bytes'),
+ (pa.binary(length=4), 'bytes'),
+ (pa.string(), 'unicode'),
+ (pa.list_(pa.list_(pa.int16())), 'list[list[int16]]'),
+ (pa.decimal128(18, 3), 'decimal'),
+ (pa.timestamp('ms'), 'datetime'),
+ (pa.timestamp('us', 'UTC'), 'datetimetz'),
+ (pa.time32('s'), 'time'),
+ (pa.time64('us'), 'time')
+ ]
+)
+def test_logical_type(type, expected):
+ assert get_logical_type(type) == expected
+
+
+# ----------------------------------------------------------------------
+# to_pandas uses MemoryPool
+
+def test_array_uses_memory_pool():
+ # ARROW-6570
+ N = 10000
+ arr = pa.array(np.arange(N, dtype=np.int64),
+ mask=np.random.randint(0, 2, size=N).astype(np.bool_))
+
+ # In case the garbage collector is lagging behind
+ gc.collect()
+
+ prior_allocation = pa.total_allocated_bytes()
+
+ x = arr.to_pandas()
+ assert pa.total_allocated_bytes() == (prior_allocation + N * 8)
+ x = None # noqa
+ gc.collect()
+
+ assert pa.total_allocated_bytes() == prior_allocation
+
+ # zero copy does not allocate memory
+ arr = pa.array(np.arange(N, dtype=np.int64))
+
+ prior_allocation = pa.total_allocated_bytes()
+ x = arr.to_pandas() # noqa
+ assert pa.total_allocated_bytes() == prior_allocation
+
+
+def test_singleton_blocks_zero_copy():
+ # Part of ARROW-3789
+ t = pa.table([pa.array(np.arange(1000, dtype=np.int64))], ['f0'])
+
+ # Zero copy if split_blocks=True
+ _check_to_pandas_memory_unchanged(t, split_blocks=True)
+
+ prior_allocation = pa.total_allocated_bytes()
+ result = t.to_pandas()
+ assert result['f0'].values.flags.writeable
+ assert pa.total_allocated_bytes() > prior_allocation
+
+
+def _check_to_pandas_memory_unchanged(obj, **kwargs):
+ prior_allocation = pa.total_allocated_bytes()
+ x = obj.to_pandas(**kwargs) # noqa
+
+ # Memory allocation unchanged -- either zero copy or self-destructing
+ assert pa.total_allocated_bytes() == prior_allocation
+
+
+def test_to_pandas_split_blocks():
+ # ARROW-3789
+ t = pa.table([
+ pa.array([1, 2, 3, 4, 5], type='i1'),
+ pa.array([1, 2, 3, 4, 5], type='i4'),
+ pa.array([1, 2, 3, 4, 5], type='i8'),
+ pa.array([1, 2, 3, 4, 5], type='f4'),
+ pa.array([1, 2, 3, 4, 5], type='f8'),
+ pa.array([1, 2, 3, 4, 5], type='f8'),
+ pa.array([1, 2, 3, 4, 5], type='f8'),
+ pa.array([1, 2, 3, 4, 5], type='f8'),
+ ], ['f{}'.format(i) for i in range(8)])
+
+ _check_blocks_created(t, 8)
+ _check_to_pandas_memory_unchanged(t, split_blocks=True)
+
+
+def _check_blocks_created(t, number):
+ x = t.to_pandas(split_blocks=True)
+ assert len(x._data.blocks) == number
+
+
+def test_to_pandas_self_destruct():
+ K = 50
+
+ def _make_table():
+ return pa.table([
+ # Slice to force a copy
+ pa.array(np.random.randn(10000)[::2])
+ for i in range(K)
+ ], ['f{}'.format(i) for i in range(K)])
+
+ t = _make_table()
+ _check_to_pandas_memory_unchanged(t, split_blocks=True, self_destruct=True)
+
+ # Check non-split-block behavior
+ t = _make_table()
+ _check_to_pandas_memory_unchanged(t, self_destruct=True)
+
+
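+# How the two options exercised above are typically combined in user code
+# (illustrative sketch, not collected by pytest): split_blocks avoids
+# consolidating columns into large 2D blocks, and self_destruct releases each
+# Arrow buffer as soon as it has been converted, after which the Table must
+# not be used again.
+def _self_destruct_usage_sketch():
+    t = pa.table({'x': np.arange(1000, dtype=np.float64)})
+    df = t.to_pandas(split_blocks=True, self_destruct=True)
+    # `t` is no longer safe to touch after a self-destructing conversion
+    return df
+
+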
+def test_table_uses_memory_pool():
+ N = 10000
+ arr = pa.array(np.arange(N, dtype=np.int64))
+ t = pa.table([arr, arr, arr], ['f0', 'f1', 'f2'])
+
+ prior_allocation = pa.total_allocated_bytes()
+ x = t.to_pandas()
+
+ assert pa.total_allocated_bytes() == (prior_allocation + 3 * N * 8)
+
+ # Check successful garbage collection
+ x = None # noqa
+ gc.collect()
+ assert pa.total_allocated_bytes() == prior_allocation
+
+
+def test_object_leak_in_numpy_array():
+ # ARROW-6876
+ arr = pa.array([{'a': 1}])
+ np_arr = arr.to_pandas()
+ assert np_arr.dtype == np.dtype('object')
+ obj = np_arr[0]
+ refcount = sys.getrefcount(obj)
+ assert sys.getrefcount(obj) == refcount
+ del np_arr
+ assert sys.getrefcount(obj) == refcount - 1
+
+
+def test_object_leak_in_dataframe():
+ # ARROW-6876
+ arr = pa.array([{'a': 1}])
+ table = pa.table([arr], ['f0'])
+ col = table.to_pandas()['f0']
+ assert col.dtype == np.dtype('object')
+ obj = col[0]
+ refcount = sys.getrefcount(obj)
+ assert sys.getrefcount(obj) == refcount
+ del col
+ assert sys.getrefcount(obj) == refcount - 1
+
+
+# ----------------------------------------------------------------------
+# Some nested array tests
+
+
+def test_array_from_py_float32():
+ data = [[1.2, 3.4], [9.0, 42.0]]
+
+ t = pa.float32()
+
+ arr1 = pa.array(data[0], type=t)
+ arr2 = pa.array(data, type=pa.list_(t))
+
+ expected1 = np.array(data[0], dtype=np.float32)
+ expected2 = pd.Series([np.array(data[0], dtype=np.float32),
+ np.array(data[1], dtype=np.float32)])
+
+ assert arr1.type == t
+ assert arr1.equals(pa.array(expected1))
+ assert arr2.equals(pa.array(expected2))
+
+
+# ----------------------------------------------------------------------
+# Timestamp tests
+
+
+def test_cast_timestamp_unit():
+ # ARROW-1680
+ val = datetime.now()
+ s = pd.Series([val])
+ s_nyc = s.dt.tz_localize('tzlocal()').dt.tz_convert('America/New_York')
+
+ us_with_tz = pa.timestamp('us', tz='America/New_York')
+
+ arr = pa.Array.from_pandas(s_nyc, type=us_with_tz)
+
+ # ARROW-1906
+ assert arr.type == us_with_tz
+
+ arr2 = pa.Array.from_pandas(s, type=pa.timestamp('us'))
+
+ assert arr[0].as_py() == s_nyc[0].to_pydatetime()
+ assert arr2[0].as_py() == s[0].to_pydatetime()
+
+ # Disallow truncation
+ arr = pa.array([123123], type='int64').cast(pa.timestamp('ms'))
+ expected = pa.array([123], type='int64').cast(pa.timestamp('s'))
+
+ # sanity check that the cast worked right
+ assert arr.type == pa.timestamp('ms')
+
+ target = pa.timestamp('s')
+ with pytest.raises(ValueError):
+ arr.cast(target)
+
+ result = arr.cast(target, safe=False)
+ assert result.equals(expected)
+
+ # ARROW-1949
+ series = pd.Series([pd.Timestamp(1), pd.Timestamp(10), pd.Timestamp(1000)])
+ expected = pa.array([0, 0, 1], type=pa.timestamp('us'))
+
+ with pytest.raises(ValueError):
+ pa.array(series, type=pa.timestamp('us'))
+
+ with pytest.raises(ValueError):
+ pa.Array.from_pandas(series, type=pa.timestamp('us'))
+
+ result = pa.Array.from_pandas(series, type=pa.timestamp('us'), safe=False)
+ assert result.equals(expected)
+
+ result = pa.array(series, type=pa.timestamp('us'), safe=False)
+ assert result.equals(expected)
+
+
+def test_nested_with_timestamp_tz_round_trip():
+ ts = pd.Timestamp.now()
+ ts_dt = ts.to_pydatetime()
+ arr = pa.array([ts_dt], type=pa.timestamp('us', tz='America/New_York'))
+ struct = pa.StructArray.from_arrays([arr, arr], ['start', 'stop'])
+
+ result = struct.to_pandas()
+ restored = pa.array(result)
+ assert restored.equals(struct)
+
+
+def test_nested_with_timestamp_tz():
+ # ARROW-7723
+ ts = pd.Timestamp.now()
+ ts_dt = ts.to_pydatetime()
+
+ # XXX: Ensure that this data does not get promoted to nanoseconds (and thus
+ # integers) to preserve behavior in 0.15.1
+ for unit in ['s', 'ms', 'us']:
+ if unit in ['s', 'ms']:
+ # For the coarser units, drop microseconds so that the timezone
+ # conversion check is not affected by sub-second differences
+ def truncate(x): return x.replace(microsecond=0)
+ else:
+ def truncate(x): return x
+ arr = pa.array([ts], type=pa.timestamp(unit))
+ arr2 = pa.array([ts], type=pa.timestamp(unit, tz='America/New_York'))
+
+ arr3 = pa.StructArray.from_arrays([arr, arr], ['start', 'stop'])
+ arr4 = pa.StructArray.from_arrays([arr2, arr2], ['start', 'stop'])
+
+ result = arr3.to_pandas()
+ assert isinstance(result[0]['start'], datetime)
+ assert result[0]['start'].tzinfo is None
+ assert isinstance(result[0]['stop'], datetime)
+ assert result[0]['stop'].tzinfo is None
+
+ result = arr4.to_pandas()
+ assert isinstance(result[0]['start'], datetime)
+ assert result[0]['start'].tzinfo is not None
+ utc_dt = result[0]['start'].astimezone(timezone.utc)
+ assert truncate(utc_dt).replace(tzinfo=None) == truncate(ts_dt)
+ assert isinstance(result[0]['stop'], datetime)
+ assert result[0]['stop'].tzinfo is not None
+
+ # same conversion for table
+ result = pa.table({'a': arr3}).to_pandas()
+ assert isinstance(result['a'][0]['start'], datetime)
+ assert result['a'][0]['start'].tzinfo is None
+ assert isinstance(result['a'][0]['stop'], datetime)
+ assert result['a'][0]['stop'].tzinfo is None
+
+ result = pa.table({'a': arr4}).to_pandas()
+ assert isinstance(result['a'][0]['start'], datetime)
+ assert result['a'][0]['start'].tzinfo is not None
+ assert isinstance(result['a'][0]['stop'], datetime)
+ assert result['a'][0]['stop'].tzinfo is not None
+
+
+# ----------------------------------------------------------------------
+# DictionaryArray tests
+
+
+def test_dictionary_with_pandas():
+ src_indices = np.repeat([0, 1, 2], 2)
+ dictionary = np.array(['foo', 'bar', 'baz'], dtype=object)
+ mask = np.array([False, False, True, False, False, False])
+
+ for index_type in ['uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32',
+ 'uint64', 'int64']:
+ indices = src_indices.astype(index_type)
+ d1 = pa.DictionaryArray.from_arrays(indices, dictionary)
+ d2 = pa.DictionaryArray.from_arrays(indices, dictionary, mask=mask)
+
+ if index_type[0] == 'u':
+ # TODO: unsigned dictionary indices to pandas
+ with pytest.raises(TypeError):
+ d1.to_pandas()
+ continue
+
+ pandas1 = d1.to_pandas()
+ ex_pandas1 = pd.Categorical.from_codes(indices, categories=dictionary)
+
+ tm.assert_series_equal(pd.Series(pandas1), pd.Series(ex_pandas1))
+
+ pandas2 = d2.to_pandas()
+ assert pandas2.isnull().sum() == 1
+
+ # Unsigned integers converted to signed
+ signed_indices = indices
+ if index_type[0] == 'u':
+ signed_indices = indices.astype(index_type[1:])
+ ex_pandas2 = pd.Categorical.from_codes(np.where(mask, -1,
+ signed_indices),
+ categories=dictionary)
+
+ tm.assert_series_equal(pd.Series(pandas2), pd.Series(ex_pandas2))
+
+
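+# The pandas<->Arrow mapping exercised above, reduced to its essentials
+# (illustrative sketch, not collected by pytest): a Categorical becomes a
+# DictionaryArray whose index type follows the Categorical's codes dtype
+# (int8 here), and it converts back to a categorical-backed result.
+def _dictionary_categorical_sketch():
+    cat = pd.Categorical(['x', 'y', 'x'])
+    dict_arr = pa.array(cat)
+    assert dict_arr.type.equals(pa.dictionary(pa.int8(), pa.string()))
+    roundtripped = dict_arr.to_pandas()
+    tm.assert_series_equal(pd.Series(roundtripped), pd.Series(cat))
+
+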
+def random_strings(n, item_size, pct_null=0, dictionary=None):
+ if dictionary is not None:
+ result = dictionary[np.random.randint(0, len(dictionary), size=n)]
+ else:
+ result = np.array([random_ascii(item_size) for i in range(n)],
+ dtype=object)
+
+ if pct_null > 0:
+ result[np.random.rand(n) < pct_null] = None
+
+ return result
+
+
+def test_variable_dictionary_to_pandas():
+ np.random.seed(12345)
+
+ d1 = pa.array(random_strings(100, 32), type='string')
+ d2 = pa.array(random_strings(100, 16), type='string')
+ d3 = pa.array(random_strings(10000, 10), type='string')
+
+ a1 = pa.DictionaryArray.from_arrays(
+ np.random.randint(0, len(d1), size=1000, dtype='i4'),
+ d1
+ )
+ a2 = pa.DictionaryArray.from_arrays(
+ np.random.randint(0, len(d2), size=1000, dtype='i4'),
+ d2
+ )
+
+ # With some nulls
+ a3 = pa.DictionaryArray.from_arrays(
+ np.random.randint(0, len(d3), size=1000, dtype='i4'), d3)
+
+ i4 = pa.array(
+ np.random.randint(0, len(d3), size=1000, dtype='i4'),
+ mask=np.random.rand(1000) < 0.1
+ )
+ a4 = pa.DictionaryArray.from_arrays(i4, d3)
+
+ expected_dict = pa.concat_arrays([d1, d2, d3])
+
+ a = pa.chunked_array([a1, a2, a3, a4])
+ a_dense = pa.chunked_array([a1.cast('string'),
+ a2.cast('string'),
+ a3.cast('string'),
+ a4.cast('string')])
+
+ result = a.to_pandas()
+ result_dense = a_dense.to_pandas()
+
+ assert (result.cat.categories == expected_dict.to_pandas()).all()
+
+ expected_dense = result.astype('str')
+ expected_dense[result_dense.isnull()] = None
+ tm.assert_series_equal(result_dense, expected_dense)
+
+
+def test_dictionary_encoded_nested_to_pandas():
+ # ARROW-6899
+ child = pa.array(['a', 'a', 'a', 'b', 'b']).dictionary_encode()
+
+ arr = pa.ListArray.from_arrays([0, 3, 5], child)
+
+ result = arr.to_pandas()
+ expected = pd.Series([np.array(['a', 'a', 'a'], dtype=object),
+ np.array(['b', 'b'], dtype=object)])
+
+ tm.assert_series_equal(result, expected)
+
+
+def test_dictionary_from_pandas():
+ cat = pd.Categorical(['a', 'b', 'a'])
+ expected_type = pa.dictionary(pa.int8(), pa.string())
+
+ result = pa.array(cat)
+ assert result.to_pylist() == ['a', 'b', 'a']
+ assert result.type.equals(expected_type)
+
+ # with missing values in categorical
+ cat = pd.Categorical(['a', 'b', None, 'a'])
+
+ result = pa.array(cat)
+ assert result.to_pylist() == ['a', 'b', None, 'a']
+ assert result.type.equals(expected_type)
+
+ # with additional mask
+ result = pa.array(cat, mask=np.array([False, False, False, True]))
+ assert result.to_pylist() == ['a', 'b', None, None]
+ assert result.type.equals(expected_type)
+
+
+def test_dictionary_from_pandas_specified_type():
+ # ARROW-7168 - ensure specified type is always respected
+
+ # the same as cat = pd.Categorical(['a', 'b']) but explicit about dtypes
+ cat = pd.Categorical.from_codes(
+ np.array([0, 1], dtype='int8'), np.array(['a', 'b'], dtype=object))
+
+ # different index type -> allow this
+ # (the type of the 'codes' in pandas is not part of the data type)
+ typ = pa.dictionary(index_type=pa.int16(), value_type=pa.string())
+ result = pa.array(cat, type=typ)
+ assert result.type.equals(typ)
+ assert result.to_pylist() == ['a', 'b']
+
+ # mismatching values type -> raise error
+ typ = pa.dictionary(index_type=pa.int8(), value_type=pa.int64())
+ with pytest.raises(pa.ArrowInvalid):
+ result = pa.array(cat, type=typ)
+
+ # mismatching order -> raise error (for now a deprecation warning)
+ typ = pa.dictionary(
+ index_type=pa.int8(), value_type=pa.string(), ordered=True)
+ with pytest.warns(FutureWarning, match="The 'ordered' flag of the passed"):
+ result = pa.array(cat, type=typ)
+ assert result.to_pylist() == ['a', 'b']
+
+ # with mask
+ typ = pa.dictionary(index_type=pa.int16(), value_type=pa.string())
+ result = pa.array(cat, type=typ, mask=np.array([False, True]))
+ assert result.type.equals(typ)
+ assert result.to_pylist() == ['a', None]
+
+ # empty categorical -> be flexible in values type to allow
+ cat = pd.Categorical([])
+
+ typ = pa.dictionary(index_type=pa.int8(), value_type=pa.string())
+ result = pa.array(cat, type=typ)
+ assert result.type.equals(typ)
+ assert result.to_pylist() == []
+ typ = pa.dictionary(index_type=pa.int8(), value_type=pa.int64())
+ result = pa.array(cat, type=typ)
+ assert result.type.equals(typ)
+ assert result.to_pylist() == []
+
+ # passing non-dictionary type
+ cat = pd.Categorical(['a', 'b'])
+ result = pa.array(cat, type=pa.string())
+ expected = pa.array(['a', 'b'], type=pa.string())
+ assert result.equals(expected)
+ assert result.to_pylist() == ['a', 'b']
+
+
+# ----------------------------------------------------------------------
+# Array protocol in pandas conversions tests
+
+
+def test_array_protocol():
+ if Version(pd.__version__) < Version('0.24.0'):
+ pytest.skip('IntegerArray only introduced in 0.24')
+
+ df = pd.DataFrame({'a': pd.Series([1, 2, None], dtype='Int64')})
+
+ if Version(pd.__version__) < Version('0.26.0.dev'):
+ # with pandas<=0.25, trying to convert nullable integer errors
+ with pytest.raises(TypeError):
+ pa.table(df)
+ else:
+ # __arrow_array__ added to pandas IntegerArray in 0.26.0.dev
+
+ # default conversion
+ result = pa.table(df)
+ expected = pa.array([1, 2, None], pa.int64())
+ assert result[0].chunk(0).equals(expected)
+
+ # with specifying schema
+ schema = pa.schema([('a', pa.float64())])
+ result = pa.table(df, schema=schema)
+ expected2 = pa.array([1, 2, None], pa.float64())
+ assert result[0].chunk(0).equals(expected2)
+
+ # pass Series to pa.array
+ result = pa.array(df['a'])
+ assert result.equals(expected)
+ result = pa.array(df['a'], type=pa.float64())
+ assert result.equals(expected2)
+
+ # pass actual ExtensionArray to pa.array
+ result = pa.array(df['a'].values)
+ assert result.equals(expected)
+ result = pa.array(df['a'].values, type=pa.float64())
+ assert result.equals(expected2)
+
+
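+# The protocol itself is small (illustrative sketch; the two underscore
+# prefixed names below exist only for this illustration and are not collected
+# by pytest): any object exposing __arrow_array__(type=None) and returning a
+# pyarrow Array is picked up by pa.array() instead of the generic sequence
+# conversion path.
+class _ArrowConvertibleSketch:
+
+    def __init__(self, values):
+        self._values = values
+
+    def __arrow_array__(self, type=None):
+        return pa.array(self._values, type=type)
+
+
+def _array_protocol_usage_sketch():
+    wrapped = _ArrowConvertibleSketch([1, 2, None])
+    assert pa.array(wrapped).equals(pa.array([1, 2, None]))
+
+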
+class DummyExtensionType(pa.PyExtensionType):
+
+ def __init__(self):
+ pa.PyExtensionType.__init__(self, pa.int64())
+
+ def __reduce__(self):
+ return DummyExtensionType, ()
+
+
+def PandasArray__arrow_array__(self, type=None):
+ # hardcode dummy return regardless of self - we only want to check that
+ # this method is correctly called
+ storage = pa.array([1, 2, 3], type=pa.int64())
+ return pa.ExtensionArray.from_storage(DummyExtensionType(), storage)
+
+
+def test_array_protocol_pandas_extension_types(monkeypatch):
+ # ARROW-7022 - ensure protocol works for Period / Interval extension dtypes
+
+ if Version(pd.__version__) < Version('0.24.0'):
+ pytest.skip('Period/IntervalArray only introduced in 0.24')
+
+ storage = pa.array([1, 2, 3], type=pa.int64())
+ expected = pa.ExtensionArray.from_storage(DummyExtensionType(), storage)
+
+ monkeypatch.setattr(pd.arrays.PeriodArray, "__arrow_array__",
+ PandasArray__arrow_array__, raising=False)
+ monkeypatch.setattr(pd.arrays.IntervalArray, "__arrow_array__",
+ PandasArray__arrow_array__, raising=False)
+ for arr in [pd.period_range("2012-01-01", periods=3, freq="D").array,
+ pd.interval_range(1, 4).array]:
+ result = pa.array(arr)
+ assert result.equals(expected)
+ result = pa.array(pd.Series(arr))
+ assert result.equals(expected)
+ result = pa.array(pd.Index(arr))
+ assert result.equals(expected)
+ result = pa.table(pd.DataFrame({'a': arr})).column('a').chunk(0)
+ assert result.equals(expected)
+
+
+# ----------------------------------------------------------------------
+# Pandas ExtensionArray support
+
+
+def _Int64Dtype__from_arrow__(self, array):
+ # for the test, only deal with a single chunk for now
+ # TODO: do we require handling of chunked arrays in the protocol?
+ if isinstance(array, pa.Array):
+ arr = array
+ else:
+ # ChunkedArray - here only deal with a single chunk for the test
+ arr = array.chunk(0)
+ buflist = arr.buffers()
+ data = np.frombuffer(buflist[-1], dtype='int64')[
+ arr.offset:arr.offset + len(arr)]
+ bitmask = buflist[0]
+ if bitmask is not None:
+ mask = pa.BooleanArray.from_buffers(
+ pa.bool_(), len(arr), [None, bitmask])
+ mask = np.asarray(mask)
+ else:
+ mask = np.ones(len(arr), dtype=bool)
+ int_arr = pd.arrays.IntegerArray(data.copy(), ~mask, copy=False)
+ return int_arr
+
+
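+# Roughly how the hook above gets exercised (illustrative sketch, not
+# collected by pytest): during to_pandas, pyarrow looks up an ExtensionDtype
+# via the pandas metadata or a types_mapper and hands it the Arrow data.
+def _from_arrow_usage_sketch():
+    arr = pa.array([1, None, 3], type=pa.int64())
+    result = _Int64Dtype__from_arrow__(pd.Int64Dtype(), arr)
+    assert result.isna().tolist() == [False, True, False]
+    assert result.dtype == pd.Int64Dtype()
+
+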
+def test_convert_to_extension_array(monkeypatch):
+ if Version(pd.__version__) < Version("0.26.0.dev"):
+ pytest.skip("Conversion from IntegerArray to arrow not yet supported")
+
+ import pandas.core.internals as _int
+
+ # table converted from dataframe with extension types (so pandas_metadata
+ # has this information)
+ df = pd.DataFrame(
+ {'a': [1, 2, 3], 'b': pd.array([2, 3, 4], dtype='Int64'),
+ 'c': [4, 5, 6]})
+ table = pa.table(df)
+
+ # Int64Dtype is recognized -> convert to extension block by default
+ # for a proper roundtrip
+ result = table.to_pandas()
+ assert not isinstance(result._data.blocks[0], _int.ExtensionBlock)
+ assert result._data.blocks[0].values.dtype == np.dtype("int64")
+ assert isinstance(result._data.blocks[1], _int.ExtensionBlock)
+ tm.assert_frame_equal(result, df)
+
+ # test with missing values
+ df2 = pd.DataFrame({'a': pd.array([1, 2, None], dtype='Int64')})
+ table2 = pa.table(df2)
+ result = table2.to_pandas()
+ assert isinstance(result._data.blocks[0], _int.ExtensionBlock)
+ tm.assert_frame_equal(result, df2)
+
+ # monkeypatch pandas Int64Dtype to *not* have the protocol method
+ if Version(pd.__version__) < Version("1.3.0.dev"):
+ monkeypatch.delattr(
+ pd.core.arrays.integer._IntegerDtype, "__from_arrow__")
+ else:
+ monkeypatch.delattr(
+ pd.core.arrays.integer.NumericDtype, "__from_arrow__")
+ # Int64Dtype has no __from_arrow__ -> use normal conversion
+ result = table.to_pandas()
+ assert len(result._data.blocks) == 1
+ assert not isinstance(result._data.blocks[0], _int.ExtensionBlock)
+
+
+class MyCustomIntegerType(pa.PyExtensionType):
+
+ def __init__(self):
+ pa.PyExtensionType.__init__(self, pa.int64())
+
+ def __reduce__(self):
+ return MyCustomIntegerType, ()
+
+ def to_pandas_dtype(self):
+ return pd.Int64Dtype()
+
+
+def test_conversion_extensiontype_to_extensionarray(monkeypatch):
+ # converting extension type to linked pandas ExtensionDtype/Array
+ import pandas.core.internals as _int
+
+ if Version(pd.__version__) < Version("0.24.0"):
+ pytest.skip("ExtensionDtype introduced in pandas 0.24")
+
+ storage = pa.array([1, 2, 3, 4], pa.int64())
+ arr = pa.ExtensionArray.from_storage(MyCustomIntegerType(), storage)
+ table = pa.table({'a': arr})
+
+ if Version(pd.__version__) < Version("0.26.0.dev"):
+ # ensure pandas Int64Dtype has the protocol method (for older pandas)
+ monkeypatch.setattr(
+ pd.Int64Dtype, '__from_arrow__', _Int64Dtype__from_arrow__,
+ raising=False)
+
+ # extension type points to Int64Dtype, which knows how to create a
+ # pandas ExtensionArray
+ result = arr.to_pandas()
+ assert isinstance(result._data.blocks[0], _int.ExtensionBlock)
+ expected = pd.Series([1, 2, 3, 4], dtype='Int64')
+ tm.assert_series_equal(result, expected)
+
+ result = table.to_pandas()
+ assert isinstance(result._data.blocks[0], _int.ExtensionBlock)
+ expected = pd.DataFrame({'a': pd.array([1, 2, 3, 4], dtype='Int64')})
+ tm.assert_frame_equal(result, expected)
+
+ # monkeypatch pandas Int64Dtype to *not* have the protocol method
+ # (remove the method added above for old pandas, and the real method
+ # for recent pandas)
+ if Version(pd.__version__) < Version("0.26.0.dev"):
+ monkeypatch.delattr(pd.Int64Dtype, "__from_arrow__")
+ elif Version(pd.__version__) < Version("1.3.0.dev"):
+ monkeypatch.delattr(
+ pd.core.arrays.integer._IntegerDtype, "__from_arrow__")
+ else:
+ monkeypatch.delattr(
+ pd.core.arrays.integer.NumericDtype, "__from_arrow__")
+
+ result = arr.to_pandas()
+ assert not isinstance(result._data.blocks[0], _int.ExtensionBlock)
+ expected = pd.Series([1, 2, 3, 4])
+ tm.assert_series_equal(result, expected)
+
+ with pytest.raises(ValueError):
+ table.to_pandas()
+
+
+def test_to_pandas_extension_dtypes_mapping():
+ if Version(pd.__version__) < Version("0.26.0.dev"):
+ pytest.skip("Conversion to pandas IntegerArray not yet supported")
+
+ table = pa.table({'a': pa.array([1, 2, 3], pa.int64())})
+
+ # default use numpy dtype
+ result = table.to_pandas()
+ assert result['a'].dtype == np.dtype('int64')
+
+ # specify to override the default
+ result = table.to_pandas(types_mapper={pa.int64(): pd.Int64Dtype()}.get)
+ assert isinstance(result['a'].dtype, pd.Int64Dtype)
+
+ # types that return None in function get normal conversion
+ table = pa.table({'a': pa.array([1, 2, 3], pa.int32())})
+ result = table.to_pandas(types_mapper={pa.int64(): pd.Int64Dtype()}.get)
+ assert result['a'].dtype == np.dtype('int32')
+
+ # `types_mapper` overrules the pandas metadata
+ table = pa.table(pd.DataFrame({'a': pd.array([1, 2, 3], dtype="Int64")}))
+ result = table.to_pandas()
+ assert isinstance(result['a'].dtype, pd.Int64Dtype)
+ result = table.to_pandas(
+ types_mapper={pa.int64(): pd.PeriodDtype('D')}.get)
+ assert isinstance(result['a'].dtype, pd.PeriodDtype)
+
+
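+# A types_mapper is just a callable from a pyarrow DataType to an optional
+# pandas dtype, so a dict's .get method covers the common case (illustrative
+# sketch, not collected by pytest).
+def _types_mapper_sketch():
+    mapper = {pa.int64(): pd.Int64Dtype(), pa.int32(): pd.Int32Dtype()}.get
+    table = pa.table({'a': pa.array([1, None], pa.int64())})
+    return table.to_pandas(types_mapper=mapper)
+
+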
+def test_array_to_pandas():
+ if Version(pd.__version__) < Version("1.1"):
+ pytest.skip("ExtensionDtype to_pandas method missing")
+
+ for arr in [pd.period_range("2012-01-01", periods=3, freq="D").array,
+ pd.interval_range(1, 4).array]:
+ result = pa.array(arr).to_pandas()
+ expected = pd.Series(arr)
+ tm.assert_series_equal(result, expected)
+
+ # TODO implement proper conversion for chunked array
+ # result = pa.table({"col": arr})["col"].to_pandas()
+ # expected = pd.Series(arr, name="col")
+ # tm.assert_series_equal(result, expected)
+
+
+# ----------------------------------------------------------------------
+# Legacy metadata compatibility tests
+
+
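+# For contrast with the hand-built legacy payloads below: current pyarrow
+# stores the equivalent information as JSON under the b'pandas' key of the
+# schema metadata and exposes it parsed via Schema.pandas_metadata
+# (illustrative sketch, not collected by pytest).
+def _current_pandas_metadata_sketch():
+    table = pa.table(pd.DataFrame({'a': [1, 2]}))
+    meta = table.schema.pandas_metadata
+    assert set(meta) >= {'index_columns', 'columns', 'pandas_version'}
+    return meta
+
+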
+def test_metadata_compat_range_index_pre_0_12():
+ # Forward compatibility for metadata created from pandas.RangeIndex
+ # prior to pyarrow 0.13.0
+ a_values = ['foo', 'bar', None, 'baz']
+ b_values = ['a', 'a', 'b', 'b']
+ a_arrow = pa.array(a_values, type='utf8')
+ b_arrow = pa.array(b_values, type='utf8')
+
+ rng_index_arrow = pa.array([0, 2, 4, 6], type='int64')
+
+ gen_name_0 = '__index_level_0__'
+ gen_name_1 = '__index_level_1__'
+
+ # Case 1: named RangeIndex
+ e1 = pd.DataFrame({
+ 'a': a_values
+ }, index=pd.RangeIndex(0, 8, step=2, name='qux'))
+ t1 = pa.Table.from_arrays([a_arrow, rng_index_arrow],
+ names=['a', 'qux'])
+ t1 = t1.replace_schema_metadata({
+ b'pandas': json.dumps(
+ {'index_columns': ['qux'],
+ 'column_indexes': [{'name': None,
+ 'field_name': None,
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': {'encoding': 'UTF-8'}}],
+ 'columns': [{'name': 'a',
+ 'field_name': 'a',
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': None},
+ {'name': 'qux',
+ 'field_name': 'qux',
+ 'pandas_type': 'int64',
+ 'numpy_type': 'int64',
+ 'metadata': None}],
+ 'pandas_version': '0.23.4'}
+ )})
+ r1 = t1.to_pandas()
+ tm.assert_frame_equal(r1, e1)
+
+ # Case 2: named RangeIndex, but conflicts with an actual column
+ e2 = pd.DataFrame({
+ 'qux': a_values
+ }, index=pd.RangeIndex(0, 8, step=2, name='qux'))
+ t2 = pa.Table.from_arrays([a_arrow, rng_index_arrow],
+ names=['qux', gen_name_0])
+ t2 = t2.replace_schema_metadata({
+ b'pandas': json.dumps(
+ {'index_columns': [gen_name_0],
+ 'column_indexes': [{'name': None,
+ 'field_name': None,
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': {'encoding': 'UTF-8'}}],
+ 'columns': [{'name': 'a',
+ 'field_name': 'a',
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': None},
+ {'name': 'qux',
+ 'field_name': gen_name_0,
+ 'pandas_type': 'int64',
+ 'numpy_type': 'int64',
+ 'metadata': None}],
+ 'pandas_version': '0.23.4'}
+ )})
+ r2 = t2.to_pandas()
+ tm.assert_frame_equal(r2, e2)
+
+ # Case 3: unnamed RangeIndex
+ e3 = pd.DataFrame({
+ 'a': a_values
+ }, index=pd.RangeIndex(0, 8, step=2, name=None))
+ t3 = pa.Table.from_arrays([a_arrow, rng_index_arrow],
+ names=['a', gen_name_0])
+ t3 = t3.replace_schema_metadata({
+ b'pandas': json.dumps(
+ {'index_columns': [gen_name_0],
+ 'column_indexes': [{'name': None,
+ 'field_name': None,
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': {'encoding': 'UTF-8'}}],
+ 'columns': [{'name': 'a',
+ 'field_name': 'a',
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': None},
+ {'name': None,
+ 'field_name': gen_name_0,
+ 'pandas_type': 'int64',
+ 'numpy_type': 'int64',
+ 'metadata': None}],
+ 'pandas_version': '0.23.4'}
+ )})
+ r3 = t3.to_pandas()
+ tm.assert_frame_equal(r3, e3)
+
+ # Case 4: MultiIndex with named RangeIndex
+ e4 = pd.DataFrame({
+ 'a': a_values
+ }, index=[pd.RangeIndex(0, 8, step=2, name='qux'), b_values])
+ t4 = pa.Table.from_arrays([a_arrow, rng_index_arrow, b_arrow],
+ names=['a', 'qux', gen_name_1])
+ t4 = t4.replace_schema_metadata({
+ b'pandas': json.dumps(
+ {'index_columns': ['qux', gen_name_1],
+ 'column_indexes': [{'name': None,
+ 'field_name': None,
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': {'encoding': 'UTF-8'}}],
+ 'columns': [{'name': 'a',
+ 'field_name': 'a',
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': None},
+ {'name': 'qux',
+ 'field_name': 'qux',
+ 'pandas_type': 'int64',
+ 'numpy_type': 'int64',
+ 'metadata': None},
+ {'name': None,
+ 'field_name': gen_name_1,
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': None}],
+ 'pandas_version': '0.23.4'}
+ )})
+ r4 = t4.to_pandas()
+ tm.assert_frame_equal(r4, e4)
+
+ # Case 4: MultiIndex with unnamed RangeIndex
+ e5 = pd.DataFrame({
+ 'a': a_values
+ }, index=[pd.RangeIndex(0, 8, step=2, name=None), b_values])
+ t5 = pa.Table.from_arrays([a_arrow, rng_index_arrow, b_arrow],
+ names=['a', gen_name_0, gen_name_1])
+ t5 = t5.replace_schema_metadata({
+ b'pandas': json.dumps(
+ {'index_columns': [gen_name_0, gen_name_1],
+ 'column_indexes': [{'name': None,
+ 'field_name': None,
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': {'encoding': 'UTF-8'}}],
+ 'columns': [{'name': 'a',
+ 'field_name': 'a',
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': None},
+ {'name': None,
+ 'field_name': gen_name_0,
+ 'pandas_type': 'int64',
+ 'numpy_type': 'int64',
+ 'metadata': None},
+ {'name': None,
+ 'field_name': gen_name_1,
+ 'pandas_type': 'unicode',
+ 'numpy_type': 'object',
+ 'metadata': None}],
+ 'pandas_version': '0.23.4'}
+ )})
+ r5 = t5.to_pandas()
+ tm.assert_frame_equal(r5, e5)
+
+
+def test_metadata_compat_missing_field_name():
+ # Combination of missing field name but with index column as metadata.
+ # This combo occurs in the latest versions of fastparquet (0.3.2), but not
+ # in pyarrow itself (since field_name was added in 0.8, index as metadata
+ # only added later)
+
+ a_values = [1, 2, 3, 4]
+ b_values = ['a', 'b', 'c', 'd']
+ a_arrow = pa.array(a_values, type='int64')
+ b_arrow = pa.array(b_values, type='utf8')
+
+ expected = pd.DataFrame({
+ 'a': a_values,
+ 'b': b_values,
+ }, index=pd.RangeIndex(0, 8, step=2, name='qux'))
+ table = pa.table({'a': a_arrow, 'b': b_arrow})
+
+ # metadata generated by fastparquet 0.3.2 with missing field_names
+ table = table.replace_schema_metadata({
+ b'pandas': json.dumps({
+ 'column_indexes': [
+ {'field_name': None,
+ 'metadata': None,
+ 'name': None,
+ 'numpy_type': 'object',
+ 'pandas_type': 'mixed-integer'}
+ ],
+ 'columns': [
+ {'metadata': None,
+ 'name': 'a',
+ 'numpy_type': 'int64',
+ 'pandas_type': 'int64'},
+ {'metadata': None,
+ 'name': 'b',
+ 'numpy_type': 'object',
+ 'pandas_type': 'unicode'}
+ ],
+ 'index_columns': [
+ {'kind': 'range',
+ 'name': 'qux',
+ 'start': 0,
+ 'step': 2,
+ 'stop': 8}
+ ],
+ 'pandas_version': '0.25.0'}
+
+ )})
+ result = table.to_pandas()
+ tm.assert_frame_equal(result, expected)
+
+
+def test_metadata_index_name_not_json_serializable():
+ name = np.int64(6) # not json serializable by default
+ table = pa.table(pd.DataFrame(index=pd.RangeIndex(0, 4, name=name)))
+ metadata = table.schema.pandas_metadata
+ assert metadata['index_columns'][0]['name'] == '6'
+
+
+def test_metadata_index_name_is_json_serializable():
+ name = 6 # json serializable by default
+ table = pa.table(pd.DataFrame(index=pd.RangeIndex(0, 4, name=name)))
+ metadata = table.schema.pandas_metadata
+ assert metadata['index_columns'][0]['name'] == 6
+
+
+def make_df_with_timestamps():
+ # Some of the millisecond timestamps deliberately don't fit in the range
+ # that is possible with nanosecond timestamps.
+ df = pd.DataFrame({
+ 'dateTimeMs': [
+ np.datetime64('0001-01-01 00:00', 'ms'),
+ np.datetime64('2012-05-02 12:35', 'ms'),
+ np.datetime64('2012-05-03 15:42', 'ms'),
+ np.datetime64('3000-05-03 15:42', 'ms'),
+ ],
+ 'dateTimeNs': [
+ np.datetime64('1991-01-01 00:00', 'ns'),
+ np.datetime64('2012-05-02 12:35', 'ns'),
+ np.datetime64('2012-05-03 15:42', 'ns'),
+ np.datetime64('2050-05-03 15:42', 'ns'),
+ ],
+ })
+ # Not part of what we're testing, just ensuring that the inputs are what we
+ # expect.
+ assert (df.dateTimeMs.dtype, df.dateTimeNs.dtype) == (
+ # O == object, <M8[ns] == datetime64[ns]
+ np.dtype("O"), np.dtype("<M8[ns]")
+ )
+ return df
+
+
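+# Why the option matters (illustrative sketch, not collected by pytest):
+# without timestamp_as_object=True the millisecond column is cast to
+# nanoseconds, which cannot represent year 3000, so the default conversion
+# is expected to raise instead of silently losing data.
+def _timestamp_out_of_range_default_sketch():
+    table = pa.Table.from_pandas(make_df_with_timestamps())
+    with pytest.raises(pa.ArrowInvalid):
+        table.to_pandas()
+
+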
+@pytest.mark.parquet
+def test_timestamp_as_object_parquet(tempdir):
+ # Timestamps can be stored as Parquet and reloaded into Pandas with no loss
+ # of information if the timestamp_as_object option is True.
+ df = make_df_with_timestamps()
+ table = pa.Table.from_pandas(df)
+ filename = tempdir / "timestamps_from_pandas.parquet"
+ pq.write_table(table, filename, version="2.0")
+ result = pq.read_table(filename)
+ df2 = result.to_pandas(timestamp_as_object=True)
+ tm.assert_frame_equal(df, df2)
+
+
+def test_timestamp_as_object_out_of_range():
+ # Out of range timestamps can be converted to Arrow and reloaded into Pandas
+ # with no loss of information if the timestamp_as_object option is True.
+ df = make_df_with_timestamps()
+ table = pa.Table.from_pandas(df)
+ df2 = table.to_pandas(timestamp_as_object=True)
+ tm.assert_frame_equal(df, df2)
+
+
+@pytest.mark.parametrize("resolution", ["s", "ms", "us"])
+@pytest.mark.parametrize("tz", [None, "America/New_York"])
+# One datetime outside nanosecond range, one inside nanosecond range:
+@pytest.mark.parametrize("dt", [datetime(1553, 1, 1), datetime(2020, 1, 1)])
+def test_timestamp_as_object_non_nanosecond(resolution, tz, dt):
+ # Timestamps can be converted to Arrow and reloaded into Pandas with no loss
+ # of information if the timestamp_as_object option is True.
+ arr = pa.array([dt], type=pa.timestamp(resolution, tz=tz))
+ table = pa.table({'a': arr})
+
+ for result in [
+ arr.to_pandas(timestamp_as_object=True),
+ table.to_pandas(timestamp_as_object=True)['a']
+ ]:
+ assert result.dtype == object
+ assert isinstance(result[0], datetime)
+ if tz:
+ assert result[0].tzinfo is not None
+ expected = result[0].tzinfo.fromutc(dt)
+ else:
+ assert result[0].tzinfo is None
+ expected = dt
+ assert result[0] == expected
+
+
+def test_threaded_pandas_import():
+ invoke_script("pandas_threaded_import.py")
diff --git a/src/arrow/python/pyarrow/tests/test_plasma.py b/src/arrow/python/pyarrow/tests/test_plasma.py
new file mode 100644
index 000000000..ed08a6872
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_plasma.py
@@ -0,0 +1,1073 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import multiprocessing
+import os
+import pytest
+import random
+import signal
+import struct
+import subprocess
+import sys
+import time
+
+import numpy as np
+import pyarrow as pa
+
+
+DEFAULT_PLASMA_STORE_MEMORY = 10 ** 8
+USE_VALGRIND = os.getenv("PLASMA_VALGRIND") == "1"
+EXTERNAL_STORE = "hashtable://test"
+SMALL_OBJECT_SIZE = 9000
+
+
+def random_name():
+ return str(random.randint(0, 99999999))
+
+
+def random_object_id():
+ import pyarrow.plasma as plasma
+ return plasma.ObjectID(np.random.bytes(20))
+
+
+def generate_metadata(length):
+ metadata = bytearray(length)
+ if length > 0:
+ metadata[0] = random.randint(0, 255)
+ metadata[-1] = random.randint(0, 255)
+ for _ in range(100):
+ metadata[random.randint(0, length - 1)] = random.randint(0, 255)
+ return metadata
+
+
+def write_to_data_buffer(buff, length):
+ array = np.frombuffer(buff, dtype="uint8")
+ if length > 0:
+ array[0] = random.randint(0, 255)
+ array[-1] = random.randint(0, 255)
+ for _ in range(100):
+ array[random.randint(0, length - 1)] = random.randint(0, 255)
+
+
+def create_object_with_id(client, object_id, data_size, metadata_size,
+ seal=True):
+ metadata = generate_metadata(metadata_size)
+ memory_buffer = client.create(object_id, data_size, metadata)
+ write_to_data_buffer(memory_buffer, data_size)
+ if seal:
+ client.seal(object_id)
+ return memory_buffer, metadata
+
+
+def create_object(client, data_size, metadata_size=0, seal=True):
+ object_id = random_object_id()
+ memory_buffer, metadata = create_object_with_id(client, object_id,
+ data_size, metadata_size,
+ seal=seal)
+ return object_id, memory_buffer, metadata
+
+
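+# The typical client-side flow wrapped by the helpers above (illustrative
+# sketch, assuming a store is already running and `store_name` points at its
+# socket):
+#
+#     client = plasma.connect(store_name)
+#     object_id = random_object_id()
+#     buf = client.create(object_id, 1000)        # writable, still unsealed
+#     np.frombuffer(buf, dtype="uint8")[:] = 0    # fill the data buffer
+#     client.seal(object_id)                      # now immutable, shareable
+#     [sealed] = client.get_buffers([object_id])  # read-only view of the data
+
+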
+@pytest.mark.plasma
+class TestPlasmaClient:
+
+ def setup_method(self, test_method):
+ import pyarrow.plasma as plasma
+ # Start Plasma store.
+ self.plasma_store_ctx = plasma.start_plasma_store(
+ plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
+ use_valgrind=USE_VALGRIND)
+ self.plasma_store_name, self.p = self.plasma_store_ctx.__enter__()
+ # Connect to Plasma.
+ self.plasma_client = plasma.connect(self.plasma_store_name)
+ self.plasma_client2 = plasma.connect(self.plasma_store_name)
+
+ def teardown_method(self, test_method):
+ try:
+ # Check that the Plasma store is still alive.
+ assert self.p.poll() is None
+ # Ensure Valgrind and/or coverage have a clean exit
+ # Valgrind misses SIGTERM if it is delivered before the
+ # event loop is ready; this race condition is mitigated
+ # but not solved by time.sleep().
+ if USE_VALGRIND:
+ time.sleep(1.0)
+ self.p.send_signal(signal.SIGTERM)
+ self.p.wait(timeout=5)
+ assert self.p.returncode == 0
+ finally:
+ self.plasma_store_ctx.__exit__(None, None, None)
+
+ def test_connection_failure_raises_exception(self):
+ import pyarrow.plasma as plasma
+ # ARROW-1264
+ with pytest.raises(IOError):
+ plasma.connect('unknown-store-name', num_retries=1)
+
+ def test_create(self):
+ # Create an object id string.
+ object_id = random_object_id()
+ # Create a new buffer and write to it.
+ length = 50
+ memory_buffer = np.frombuffer(self.plasma_client.create(object_id,
+ length),
+ dtype="uint8")
+ for i in range(length):
+ memory_buffer[i] = i % 256
+ # Seal the object.
+ self.plasma_client.seal(object_id)
+ # Get the object.
+ memory_buffer = np.frombuffer(
+ self.plasma_client.get_buffers([object_id])[0], dtype="uint8")
+ for i in range(length):
+ assert memory_buffer[i] == i % 256
+
+ def test_create_with_metadata(self):
+ for length in range(0, 1000, 3):
+ # Create an object id string.
+ object_id = random_object_id()
+ # Create a random metadata string.
+ metadata = generate_metadata(length)
+ # Create a new buffer and write to it.
+ memory_buffer = np.frombuffer(self.plasma_client.create(object_id,
+ length,
+ metadata),
+ dtype="uint8")
+ for i in range(length):
+ memory_buffer[i] = i % 256
+ # Seal the object.
+ self.plasma_client.seal(object_id)
+ # Get the object.
+ memory_buffer = np.frombuffer(
+ self.plasma_client.get_buffers([object_id])[0], dtype="uint8")
+ for i in range(length):
+ assert memory_buffer[i] == i % 256
+ # Get the metadata.
+ metadata_buffer = np.frombuffer(
+ self.plasma_client.get_metadata([object_id])[0], dtype="uint8")
+ assert len(metadata) == len(metadata_buffer)
+ for i in range(len(metadata)):
+ assert metadata[i] == metadata_buffer[i]
+
+ def test_create_existing(self):
+        # This test exercises the code path in which we create an object
+        # with an ID that already exists.
+ length = 100
+ for _ in range(1000):
+ object_id = random_object_id()
+ self.plasma_client.create(object_id, length,
+ generate_metadata(length))
+ try:
+ self.plasma_client.create(object_id, length,
+ generate_metadata(length))
+ # TODO(pcm): Introduce a more specific error type here.
+ except pa.lib.ArrowException:
+ pass
+ else:
+ assert False
+
+ def test_create_and_seal(self):
+
+ # Create a bunch of objects.
+ object_ids = []
+ for i in range(1000):
+ object_id = random_object_id()
+ object_ids.append(object_id)
+ self.plasma_client.create_and_seal(object_id, i * b'a', i * b'b')
+
+ for i in range(1000):
+ [data_tuple] = self.plasma_client.get_buffers([object_ids[i]],
+ with_meta=True)
+ assert data_tuple[1].to_pybytes() == i * b'a'
+ assert (self.plasma_client.get_metadata(
+ [object_ids[i]])[0].to_pybytes() ==
+ i * b'b')
+
+ # Make sure that creating the same object twice raises an exception.
+ object_id = random_object_id()
+ self.plasma_client.create_and_seal(object_id, b'a', b'b')
+ with pytest.raises(pa.plasma.PlasmaObjectExists):
+ self.plasma_client.create_and_seal(object_id, b'a', b'b')
+
+ # Make sure that these objects can be evicted.
+ big_object = DEFAULT_PLASMA_STORE_MEMORY // 10 * b'a'
+ object_ids = []
+ for _ in range(20):
+ object_id = random_object_id()
+ object_ids.append(object_id)
+            self.plasma_client.create_and_seal(object_id, big_object,
+                                               big_object)
+ for i in range(10):
+ assert not self.plasma_client.contains(object_ids[i])
+
+ def test_get(self):
+ num_object_ids = 60
+ # Test timing out of get with various timeouts.
+ for timeout in [0, 10, 100, 1000]:
+ object_ids = [random_object_id() for _ in range(num_object_ids)]
+ results = self.plasma_client.get_buffers(object_ids,
+ timeout_ms=timeout)
+ assert results == num_object_ids * [None]
+
+ data_buffers = []
+ metadata_buffers = []
+ for i in range(num_object_ids):
+ if i % 2 == 0:
+ data_buffer, metadata_buffer = create_object_with_id(
+ self.plasma_client, object_ids[i], 2000, 2000)
+ data_buffers.append(data_buffer)
+ metadata_buffers.append(metadata_buffer)
+
+ # Test timing out from some but not all get calls with various
+ # timeouts.
+ for timeout in [0, 10, 100, 1000]:
+ data_results = self.plasma_client.get_buffers(object_ids,
+ timeout_ms=timeout)
+ # metadata_results = self.plasma_client.get_metadata(
+ # object_ids, timeout_ms=timeout)
+ for i in range(num_object_ids):
+ if i % 2 == 0:
+ array1 = np.frombuffer(data_buffers[i // 2], dtype="uint8")
+ array2 = np.frombuffer(data_results[i], dtype="uint8")
+ np.testing.assert_equal(array1, array2)
+ # TODO(rkn): We should compare the metadata as well. But
+ # currently the types are different (e.g., memoryview
+ # versus bytearray).
+ # assert plasma.buffers_equal(
+ # metadata_buffers[i // 2], metadata_results[i])
+ else:
+                    assert data_results[i] is None
+
+ # Test trying to get an object that was created by the same client but
+ # not sealed.
+ object_id = random_object_id()
+ self.plasma_client.create(object_id, 10, b"metadata")
+ assert self.plasma_client.get_buffers(
+ [object_id], timeout_ms=0, with_meta=True)[0][1] is None
+ assert self.plasma_client.get_buffers(
+ [object_id], timeout_ms=1, with_meta=True)[0][1] is None
+ self.plasma_client.seal(object_id)
+ assert self.plasma_client.get_buffers(
+ [object_id], timeout_ms=0, with_meta=True)[0][1] is not None
+
+ def test_buffer_lifetime(self):
+ # ARROW-2195
+ arr = pa.array([1, 12, 23, 3, 34], pa.int32())
+ batch = pa.RecordBatch.from_arrays([arr], ['field1'])
+
+ # Serialize RecordBatch into Plasma store
+ sink = pa.MockOutputStream()
+ writer = pa.RecordBatchStreamWriter(sink, batch.schema)
+ writer.write_batch(batch)
+ writer.close()
+
+ object_id = random_object_id()
+ data_buffer = self.plasma_client.create(object_id, sink.size())
+ stream = pa.FixedSizeBufferWriter(data_buffer)
+ writer = pa.RecordBatchStreamWriter(stream, batch.schema)
+ writer.write_batch(batch)
+ writer.close()
+ self.plasma_client.seal(object_id)
+ del data_buffer
+
+ # Unserialize RecordBatch from Plasma store
+ [data_buffer] = self.plasma_client2.get_buffers([object_id])
+ reader = pa.RecordBatchStreamReader(data_buffer)
+ read_batch = reader.read_next_batch()
+ # Lose reference to returned buffer. The RecordBatch must still
+ # be backed by valid memory.
+ del data_buffer, reader
+
+ assert read_batch.equals(batch)
+
+ def test_put_and_get(self):
+ for value in [["hello", "world", 3, 1.0], None, "hello"]:
+ object_id = self.plasma_client.put(value)
+ [result] = self.plasma_client.get([object_id])
+ assert result == value
+
+ result = self.plasma_client.get(object_id)
+ assert result == value
+
+ object_id = random_object_id()
+ [result] = self.plasma_client.get([object_id], timeout_ms=0)
+ assert result == pa.plasma.ObjectNotAvailable
+
+ @pytest.mark.filterwarnings(
+ "ignore:'pyarrow.deserialize':FutureWarning")
+ def test_put_and_get_raw_buffer(self):
+ temp_id = random_object_id()
+ use_meta = b"RAW"
+
+ def deserialize_or_output(data_tuple):
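+            # Return raw bytes for RAW-tagged buffers, else deserialize.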
+ if data_tuple[0] == use_meta:
+ return data_tuple[1].to_pybytes()
+ else:
+ if data_tuple[1] is None:
+ return pa.plasma.ObjectNotAvailable
+ else:
+ return pa.deserialize(data_tuple[1])
+
+ for value in [b"Bytes Test", temp_id.binary(), 10 * b"\x00", 123]:
+ if isinstance(value, bytes):
+ object_id = self.plasma_client.put_raw_buffer(
+ value, metadata=use_meta)
+ else:
+ object_id = self.plasma_client.put(value)
+ [result] = self.plasma_client.get_buffers([object_id],
+ with_meta=True)
+ result = deserialize_or_output(result)
+ assert result == value
+
+ object_id = random_object_id()
+ [result] = self.plasma_client.get_buffers([object_id],
+ timeout_ms=0,
+ with_meta=True)
+ result = deserialize_or_output(result)
+ assert result == pa.plasma.ObjectNotAvailable
+
+ @pytest.mark.filterwarnings(
+ "ignore:'serialization_context':FutureWarning")
+ def test_put_and_get_serialization_context(self):
+
+ class CustomType:
+ def __init__(self, val):
+ self.val = val
+
+ val = CustomType(42)
+
+ with pytest.raises(pa.ArrowSerializationError):
+ self.plasma_client.put(val)
+
+ serialization_context = pa.lib.SerializationContext()
+ serialization_context.register_type(CustomType, 20*"\x00")
+
+ object_id = self.plasma_client.put(
+ val, None, serialization_context=serialization_context)
+
+ with pytest.raises(pa.ArrowSerializationError):
+            self.plasma_client.get(object_id)
+
+ result = self.plasma_client.get(
+ object_id, -1, serialization_context=serialization_context)
+ assert result.val == val.val
+
+ def test_store_arrow_objects(self):
+ data = np.random.randn(10, 4)
+ # Write an arrow object.
+ object_id = random_object_id()
+ tensor = pa.Tensor.from_numpy(data)
+ data_size = pa.ipc.get_tensor_size(tensor)
+ buf = self.plasma_client.create(object_id, data_size)
+ stream = pa.FixedSizeBufferWriter(buf)
+ pa.ipc.write_tensor(tensor, stream)
+ self.plasma_client.seal(object_id)
+ # Read the arrow object.
+ [tensor] = self.plasma_client.get_buffers([object_id])
+ reader = pa.BufferReader(tensor)
+ array = pa.ipc.read_tensor(reader).to_numpy()
+ # Assert that they are equal.
+ np.testing.assert_equal(data, array)
+
+ @pytest.mark.pandas
+ def test_store_pandas_dataframe(self):
+ import pandas as pd
+ import pyarrow.plasma as plasma
+ d = {'one': pd.Series([1., 2., 3.], index=['a', 'b', 'c']),
+ 'two': pd.Series([1., 2., 3., 4.], index=['a', 'b', 'c', 'd'])}
+ df = pd.DataFrame(d)
+
+ # Write the DataFrame.
+ record_batch = pa.RecordBatch.from_pandas(df)
+ # Determine the size.
+ s = pa.MockOutputStream()
+ stream_writer = pa.RecordBatchStreamWriter(s, record_batch.schema)
+ stream_writer.write_batch(record_batch)
+ data_size = s.size()
+ object_id = plasma.ObjectID(np.random.bytes(20))
+
+ buf = self.plasma_client.create(object_id, data_size)
+ stream = pa.FixedSizeBufferWriter(buf)
+ stream_writer = pa.RecordBatchStreamWriter(stream, record_batch.schema)
+ stream_writer.write_batch(record_batch)
+
+ self.plasma_client.seal(object_id)
+
+ # Read the DataFrame.
+ [data] = self.plasma_client.get_buffers([object_id])
+ reader = pa.RecordBatchStreamReader(pa.BufferReader(data))
+ result = reader.read_next_batch().to_pandas()
+
+ pd.testing.assert_frame_equal(df, result)
+
+ def test_pickle_object_ids(self):
+ # This can be used for sharing object IDs between processes.
+ import pickle
+ object_id = random_object_id()
+ data = pickle.dumps(object_id)
+ object_id2 = pickle.loads(data)
+ assert object_id == object_id2
+
+ def test_store_full(self):
+        # The store is started with DEFAULT_PLASMA_STORE_MEMORY (100 MB), so
+        # make sure that create raises an exception when the store is full.
+ def assert_create_raises_plasma_full(unit_test, size):
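+            # Creating data plus metadata totaling `size` bytes must raise.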
+ partial_size = np.random.randint(size)
+ try:
+ _, memory_buffer, _ = create_object(unit_test.plasma_client,
+ partial_size,
+ size - partial_size)
+ # TODO(pcm): More specific error here.
+ except pa.lib.ArrowException:
+ pass
+ else:
+ # For some reason the above didn't throw an exception, so fail.
+ assert False
+
+ PERCENT = DEFAULT_PLASMA_STORE_MEMORY // 100
+
+ # Create a list to keep some of the buffers in scope.
+ memory_buffers = []
+ _, memory_buffer, _ = create_object(self.plasma_client, 50 * PERCENT)
+ memory_buffers.append(memory_buffer)
+ # Remaining space is 50%. Make sure that we can't create an
+ # object of size 50% + 1, but we can create one of size 20%.
+ assert_create_raises_plasma_full(
+ self, 50 * PERCENT + SMALL_OBJECT_SIZE)
+ _, memory_buffer, _ = create_object(self.plasma_client, 20 * PERCENT)
+ del memory_buffer
+ _, memory_buffer, _ = create_object(self.plasma_client, 20 * PERCENT)
+ del memory_buffer
+ assert_create_raises_plasma_full(
+ self, 50 * PERCENT + SMALL_OBJECT_SIZE)
+
+ _, memory_buffer, _ = create_object(self.plasma_client, 20 * PERCENT)
+ memory_buffers.append(memory_buffer)
+ # Remaining space is 30%.
+ assert_create_raises_plasma_full(
+ self, 30 * PERCENT + SMALL_OBJECT_SIZE)
+
+ _, memory_buffer, _ = create_object(self.plasma_client, 10 * PERCENT)
+ memory_buffers.append(memory_buffer)
+ # Remaining space is 20%.
+ assert_create_raises_plasma_full(
+ self, 20 * PERCENT + SMALL_OBJECT_SIZE)
+
+ def test_contains(self):
+ fake_object_ids = [random_object_id() for _ in range(100)]
+ real_object_ids = [random_object_id() for _ in range(100)]
+ for object_id in real_object_ids:
+ assert self.plasma_client.contains(object_id) is False
+ self.plasma_client.create(object_id, 100)
+ self.plasma_client.seal(object_id)
+ assert self.plasma_client.contains(object_id)
+ for object_id in fake_object_ids:
+ assert not self.plasma_client.contains(object_id)
+ for object_id in real_object_ids:
+ assert self.plasma_client.contains(object_id)
+
+ def test_hash(self):
+ # Check the hash of an object that doesn't exist.
+ object_id1 = random_object_id()
+ try:
+ self.plasma_client.hash(object_id1)
+ # TODO(pcm): Introduce a more specific error type here
+ except pa.lib.ArrowException:
+ pass
+ else:
+ assert False
+
+ length = 1000
+ # Create a random object, and check that the hash function always
+ # returns the same value.
+ metadata = generate_metadata(length)
+ memory_buffer = np.frombuffer(self.plasma_client.create(object_id1,
+ length,
+ metadata),
+ dtype="uint8")
+ for i in range(length):
+ memory_buffer[i] = i % 256
+ self.plasma_client.seal(object_id1)
+ assert (self.plasma_client.hash(object_id1) ==
+ self.plasma_client.hash(object_id1))
+
+ # Create a second object with the same value as the first, and check
+ # that their hashes are equal.
+ object_id2 = random_object_id()
+ memory_buffer = np.frombuffer(self.plasma_client.create(object_id2,
+ length,
+ metadata),
+ dtype="uint8")
+ for i in range(length):
+ memory_buffer[i] = i % 256
+ self.plasma_client.seal(object_id2)
+ assert (self.plasma_client.hash(object_id1) ==
+ self.plasma_client.hash(object_id2))
+
+ # Create a third object with a different value from the first two, and
+ # check that its hash is different.
+ object_id3 = random_object_id()
+ metadata = generate_metadata(length)
+ memory_buffer = np.frombuffer(self.plasma_client.create(object_id3,
+ length,
+ metadata),
+ dtype="uint8")
+ for i in range(length):
+ memory_buffer[i] = (i + 1) % 256
+ self.plasma_client.seal(object_id3)
+ assert (self.plasma_client.hash(object_id1) !=
+ self.plasma_client.hash(object_id3))
+
+ # Create a fourth object with the same value as the third, but
+ # different metadata. Check that its hash is different from any of the
+ # previous three.
+ object_id4 = random_object_id()
+ metadata4 = generate_metadata(length)
+ memory_buffer = np.frombuffer(self.plasma_client.create(object_id4,
+ length,
+ metadata4),
+ dtype="uint8")
+ for i in range(length):
+ memory_buffer[i] = (i + 1) % 256
+ self.plasma_client.seal(object_id4)
+ assert (self.plasma_client.hash(object_id1) !=
+ self.plasma_client.hash(object_id4))
+ assert (self.plasma_client.hash(object_id3) !=
+ self.plasma_client.hash(object_id4))
+
+ def test_many_hashes(self):
+ hashes = []
+ length = 2 ** 10
+
+ for i in range(256):
+ object_id = random_object_id()
+ memory_buffer = np.frombuffer(self.plasma_client.create(object_id,
+ length),
+ dtype="uint8")
+ for j in range(length):
+ memory_buffer[j] = i
+ self.plasma_client.seal(object_id)
+ hashes.append(self.plasma_client.hash(object_id))
+
+ # Create objects of varying length. Each pair has two bits different.
+ for i in range(length):
+ object_id = random_object_id()
+ memory_buffer = np.frombuffer(self.plasma_client.create(object_id,
+ length),
+ dtype="uint8")
+ for j in range(length):
+ memory_buffer[j] = 0
+ memory_buffer[i] = 1
+ self.plasma_client.seal(object_id)
+ hashes.append(self.plasma_client.hash(object_id))
+
+ # Create objects of varying length, all with value 0.
+ for i in range(length):
+ object_id = random_object_id()
+ memory_buffer = np.frombuffer(self.plasma_client.create(object_id,
+ i),
+ dtype="uint8")
+ for j in range(i):
+ memory_buffer[j] = 0
+ self.plasma_client.seal(object_id)
+ hashes.append(self.plasma_client.hash(object_id))
+
+ # Check that all hashes were unique.
+ assert len(set(hashes)) == 256 + length + length
+
+ # def test_individual_delete(self):
+ # length = 100
+ # # Create an object id string.
+ # object_id = random_object_id()
+ # # Create a random metadata string.
+ # metadata = generate_metadata(100)
+ # # Create a new buffer and write to it.
+ # memory_buffer = self.plasma_client.create(object_id, length,
+ # metadata)
+ # for i in range(length):
+ # memory_buffer[i] = chr(i % 256)
+ # # Seal the object.
+ # self.plasma_client.seal(object_id)
+ # # Check that the object is present.
+ # assert self.plasma_client.contains(object_id)
+ # # Delete the object.
+ # self.plasma_client.delete(object_id)
+ # # Make sure the object is no longer present.
+ # self.assertFalse(self.plasma_client.contains(object_id))
+ #
+ # def test_delete(self):
+ # # Create some objects.
+ # object_ids = [random_object_id() for _ in range(100)]
+ # for object_id in object_ids:
+ # length = 100
+ # # Create a random metadata string.
+ # metadata = generate_metadata(100)
+ # # Create a new buffer and write to it.
+ # memory_buffer = self.plasma_client.create(object_id, length,
+ # metadata)
+ # for i in range(length):
+ # memory_buffer[i] = chr(i % 256)
+ # # Seal the object.
+ # self.plasma_client.seal(object_id)
+ # # Check that the object is present.
+ # assert self.plasma_client.contains(object_id)
+ #
+ # # Delete the objects and make sure they are no longer present.
+ # for object_id in object_ids:
+ # # Delete the object.
+ # self.plasma_client.delete(object_id)
+ # # Make sure the object is no longer present.
+ # self.assertFalse(self.plasma_client.contains(object_id))
+
+ def test_illegal_functionality(self):
+ # Create an object id string.
+ object_id = random_object_id()
+ # Create a new buffer and write to it.
+ length = 1000
+ memory_buffer = self.plasma_client.create(object_id, length)
+ # Make sure we cannot access memory out of bounds.
+ with pytest.raises(Exception):
+ memory_buffer[length]
+ # Seal the object.
+ self.plasma_client.seal(object_id)
+ # This test is commented out because it currently fails.
+        # # Make sure the object is read-only now.
+ # def illegal_assignment():
+ # memory_buffer[0] = chr(0)
+ # with pytest.raises(Exception):
+ # illegal_assignment()
+ # Get the object.
+ memory_buffer = self.plasma_client.get_buffers([object_id])[0]
+
+ # Make sure the object is read only.
+ def illegal_assignment():
+ memory_buffer[0] = chr(0)
+ with pytest.raises(Exception):
+ illegal_assignment()
+
+ def test_evict(self):
+ client = self.plasma_client2
+ object_id1 = random_object_id()
+ b1 = client.create(object_id1, 1000)
+ client.seal(object_id1)
+ del b1
+ assert client.evict(1) == 1000
+
+ object_id2 = random_object_id()
+ object_id3 = random_object_id()
+ b2 = client.create(object_id2, 999)
+ b3 = client.create(object_id3, 998)
+ client.seal(object_id3)
+ del b3
+ assert client.evict(1000) == 998
+
+ object_id4 = random_object_id()
+ b4 = client.create(object_id4, 997)
+ client.seal(object_id4)
+ del b4
+ client.seal(object_id2)
+ del b2
+ assert client.evict(1) == 997
+ assert client.evict(1) == 999
+
+ object_id5 = random_object_id()
+ object_id6 = random_object_id()
+ object_id7 = random_object_id()
+ b5 = client.create(object_id5, 996)
+ b6 = client.create(object_id6, 995)
+ b7 = client.create(object_id7, 994)
+ client.seal(object_id5)
+ client.seal(object_id6)
+ client.seal(object_id7)
+ del b5
+ del b6
+ del b7
+ assert client.evict(2000) == 996 + 995 + 994
+
+ # Mitigate valgrind-induced slowness
+ SUBSCRIBE_TEST_SIZES = ([1, 10, 100, 1000] if USE_VALGRIND
+ else [1, 10, 100, 1000, 10000])
+
+ def test_subscribe(self):
+ # Subscribe to notifications from the Plasma Store.
+ self.plasma_client.subscribe()
+ for i in self.SUBSCRIBE_TEST_SIZES:
+ object_ids = [random_object_id() for _ in range(i)]
+ metadata_sizes = [np.random.randint(1000) for _ in range(i)]
+ data_sizes = [np.random.randint(1000) for _ in range(i)]
+ for j in range(i):
+ self.plasma_client.create(
+ object_ids[j], data_sizes[j],
+ metadata=bytearray(np.random.bytes(metadata_sizes[j])))
+ self.plasma_client.seal(object_ids[j])
+ # Check that we received notifications for all of the objects.
+ for j in range(i):
+ notification_info = self.plasma_client.get_next_notification()
+ recv_objid, recv_dsize, recv_msize = notification_info
+ assert object_ids[j] == recv_objid
+ assert data_sizes[j] == recv_dsize
+ assert metadata_sizes[j] == recv_msize
+
+ def test_subscribe_socket(self):
+ # Subscribe to notifications from the Plasma Store.
+ self.plasma_client.subscribe()
+ rsock = self.plasma_client.get_notification_socket()
+ for i in self.SUBSCRIBE_TEST_SIZES:
+ # Get notification from socket.
+ object_ids = [random_object_id() for _ in range(i)]
+ metadata_sizes = [np.random.randint(1000) for _ in range(i)]
+ data_sizes = [np.random.randint(1000) for _ in range(i)]
+
+ for j in range(i):
+ self.plasma_client.create(
+ object_ids[j], data_sizes[j],
+ metadata=bytearray(np.random.bytes(metadata_sizes[j])))
+ self.plasma_client.seal(object_ids[j])
+
+ # Check that we received notifications for all of the objects.
+ for j in range(i):
+ # Assume the plasma store will not be full,
+ # so we always get the data size instead of -1.
+ msg_len, = struct.unpack('L', rsock.recv(8))
+ content = rsock.recv(msg_len)
+ recv_objids, recv_dsizes, recv_msizes = (
+ self.plasma_client.decode_notifications(content))
+ assert object_ids[j] == recv_objids[0]
+ assert data_sizes[j] == recv_dsizes[0]
+ assert metadata_sizes[j] == recv_msizes[0]
+
+ def test_subscribe_deletions(self):
+ # Subscribe to notifications from the Plasma Store. We use
+ # plasma_client2 to make sure that all used objects will get evicted
+ # properly.
+ self.plasma_client2.subscribe()
+ for i in self.SUBSCRIBE_TEST_SIZES:
+ object_ids = [random_object_id() for _ in range(i)]
+ # Add 1 to the sizes to make sure we have nonzero object sizes.
+ metadata_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
+ data_sizes = [np.random.randint(1000) + 1 for _ in range(i)]
+ for j in range(i):
+ x = self.plasma_client2.create(
+ object_ids[j], data_sizes[j],
+ metadata=bytearray(np.random.bytes(metadata_sizes[j])))
+ self.plasma_client2.seal(object_ids[j])
+ del x
+ # Check that we received notifications for creating all of the
+ # objects.
+ for j in range(i):
+ notification_info = self.plasma_client2.get_next_notification()
+ recv_objid, recv_dsize, recv_msize = notification_info
+ assert object_ids[j] == recv_objid
+ assert data_sizes[j] == recv_dsize
+ assert metadata_sizes[j] == recv_msize
+
+ # Check that we receive notifications for deleting all objects, as
+ # we evict them.
+ for j in range(i):
+ assert (self.plasma_client2.evict(1) ==
+ data_sizes[j] + metadata_sizes[j])
+ notification_info = self.plasma_client2.get_next_notification()
+ recv_objid, recv_dsize, recv_msize = notification_info
+ assert object_ids[j] == recv_objid
+ assert -1 == recv_dsize
+ assert -1 == recv_msize
+
+ # Test multiple deletion notifications. The first 9 object IDs have
+ # size 0, and the last has a nonzero size. When Plasma evicts 1 byte,
+ # it will evict all objects, so we should receive deletion
+ # notifications for each.
+ num_object_ids = 10
+ object_ids = [random_object_id() for _ in range(num_object_ids)]
+ metadata_sizes = [0] * (num_object_ids - 1)
+ data_sizes = [0] * (num_object_ids - 1)
+ metadata_sizes.append(np.random.randint(1000))
+ data_sizes.append(np.random.randint(1000))
+ for i in range(num_object_ids):
+ x = self.plasma_client2.create(
+ object_ids[i], data_sizes[i],
+ metadata=bytearray(np.random.bytes(metadata_sizes[i])))
+ self.plasma_client2.seal(object_ids[i])
+ del x
+ for i in range(num_object_ids):
+ notification_info = self.plasma_client2.get_next_notification()
+ recv_objid, recv_dsize, recv_msize = notification_info
+ assert object_ids[i] == recv_objid
+ assert data_sizes[i] == recv_dsize
+ assert metadata_sizes[i] == recv_msize
+ assert (self.plasma_client2.evict(1) ==
+ data_sizes[-1] + metadata_sizes[-1])
+ for i in range(num_object_ids):
+ notification_info = self.plasma_client2.get_next_notification()
+ recv_objid, recv_dsize, recv_msize = notification_info
+ assert object_ids[i] == recv_objid
+ assert -1 == recv_dsize
+ assert -1 == recv_msize
+
+ def test_use_full_memory(self):
+ # Fill the object store up with a large number of small objects and let
+ # them go out of scope.
+ for _ in range(100):
+ create_object(
+ self.plasma_client2,
+ np.random.randint(1, DEFAULT_PLASMA_STORE_MEMORY // 20), 0)
+ # Create large objects that require the full object store size, and
+ # verify that they fit.
+ for _ in range(2):
+ create_object(self.plasma_client2, DEFAULT_PLASMA_STORE_MEMORY, 0)
+ # Verify that an object that is too large does not fit.
+ # Also verifies that the right error is thrown, and does not
+ # create the object ID prematurely.
+ object_id = random_object_id()
+        for _ in range(3):
+ with pytest.raises(pa.plasma.PlasmaStoreFull):
+ self.plasma_client2.create(
+ object_id, DEFAULT_PLASMA_STORE_MEMORY + SMALL_OBJECT_SIZE)
+
+ @staticmethod
+ def _client_blocked_in_get(plasma_store_name, object_id):
+ import pyarrow.plasma as plasma
+ client = plasma.connect(plasma_store_name)
+ # Try to get an object ID that doesn't exist. This should block.
+ client.get([object_id])
+
+ def test_client_death_during_get(self):
+ object_id = random_object_id()
+
+ p = multiprocessing.Process(target=self._client_blocked_in_get,
+ args=(self.plasma_store_name, object_id))
+ p.start()
+ # Make sure the process is running.
+ time.sleep(0.2)
+ assert p.is_alive()
+
+ # Kill the client process.
+ p.terminate()
+ # Wait a little for the store to process the disconnect event.
+ time.sleep(0.1)
+
+ # Create the object.
+ self.plasma_client.put(1, object_id=object_id)
+
+ # Check that the store is still alive. This will raise an exception if
+ # the store is dead.
+ self.plasma_client.contains(random_object_id())
+
+ @staticmethod
+ def _client_get_multiple(plasma_store_name, object_ids):
+ import pyarrow.plasma as plasma
+ client = plasma.connect(plasma_store_name)
+ # Try to get an object ID that doesn't exist. This should block.
+ client.get(object_ids)
+
+ def test_client_getting_multiple_objects(self):
+ object_ids = [random_object_id() for _ in range(10)]
+
+ p = multiprocessing.Process(target=self._client_get_multiple,
+ args=(self.plasma_store_name, object_ids))
+ p.start()
+ # Make sure the process is running.
+ time.sleep(0.2)
+ assert p.is_alive()
+
+ # Create the objects one by one.
+ for object_id in object_ids:
+ self.plasma_client.put(1, object_id=object_id)
+
+ # Check that the store is still alive. This will raise an exception if
+ # the store is dead.
+ self.plasma_client.contains(random_object_id())
+
+ # Make sure that the blocked client finishes.
+ start_time = time.time()
+ while True:
+ if time.time() - start_time > 5:
+ raise Exception("Timing out while waiting for blocked client "
+ "to finish.")
+            if not p.is_alive():
+                break
+            time.sleep(0.1)
+
+
+@pytest.mark.plasma
+class TestEvictionToExternalStore:
+
+ def setup_method(self, test_method):
+ import pyarrow.plasma as plasma
+ # Start Plasma store.
+ self.plasma_store_ctx = plasma.start_plasma_store(
+ plasma_store_memory=1000 * 1024,
+ use_valgrind=USE_VALGRIND,
+ external_store=EXTERNAL_STORE)
+ self.plasma_store_name, self.p = self.plasma_store_ctx.__enter__()
+ # Connect to Plasma.
+ self.plasma_client = plasma.connect(self.plasma_store_name)
+
+ def teardown_method(self, test_method):
+ try:
+ # Check that the Plasma store is still alive.
+ assert self.p.poll() is None
+ self.p.send_signal(signal.SIGTERM)
+ self.p.wait(timeout=5)
+ finally:
+ self.plasma_store_ctx.__exit__(None, None, None)
+
+ def test_eviction(self):
+ client = self.plasma_client
+
+ object_ids = [random_object_id() for _ in range(0, 20)]
+ data = b'x' * 100 * 1024
+ metadata = b''
+
+ for i in range(0, 20):
+ # Test for object non-existence.
+ assert not client.contains(object_ids[i])
+
+ # Create and seal the object.
+ client.create_and_seal(object_ids[i], data, metadata)
+
+ # Test that the client can get the object.
+ assert client.contains(object_ids[i])
+
+ for i in range(0, 20):
+            # Since we access the objects sequentially, every access is a
+            # cache "miss" owing to LRU eviction. The client first tries the
+            # plasma store and then falls back to the external store, so the
+            # fetch should succeed, though it may evict the next few objects.
+ [result] = client.get_buffers([object_ids[i]])
+ assert result.to_pybytes() == data
+
+ # Make sure we still cannot fetch objects that do not exist
+ [result] = client.get_buffers([random_object_id()], timeout_ms=100)
+ assert result is None
+
+
+@pytest.mark.plasma
+def test_object_id_size():
+ import pyarrow.plasma as plasma
+ with pytest.raises(ValueError):
+ plasma.ObjectID("hello")
+ plasma.ObjectID(20 * b"0")
+
+
+@pytest.mark.plasma
+def test_object_id_equality_operators():
+ import pyarrow.plasma as plasma
+
+ oid1 = plasma.ObjectID(20 * b'0')
+ oid2 = plasma.ObjectID(20 * b'0')
+ oid3 = plasma.ObjectID(19 * b'0' + b'1')
+
+ assert oid1 == oid2
+ assert oid2 != oid3
+ assert oid1 != 'foo'
+
+
+@pytest.mark.xfail(reason="often fails on travis")
+@pytest.mark.skipif(not os.path.exists("/mnt/hugepages"),
+ reason="requires hugepage support")
+def test_use_huge_pages():
+ import pyarrow.plasma as plasma
+ with plasma.start_plasma_store(
+ plasma_store_memory=2*10**9,
+ plasma_directory="/mnt/hugepages",
+ use_hugepages=True) as (plasma_store_name, p):
+ plasma_client = plasma.connect(plasma_store_name)
+ create_object(plasma_client, 10**8)
+
+
+# Check that a plasma client cannot be destroyed before all the
+# PlasmaBuffers that hold handles to it are destroyed; see ARROW-2448.
+@pytest.mark.plasma
+def test_plasma_client_sharing():
+ import pyarrow.plasma as plasma
+
+ with plasma.start_plasma_store(
+ plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY) \
+ as (plasma_store_name, p):
+ plasma_client = plasma.connect(plasma_store_name)
+ object_id = plasma_client.put(np.zeros(3))
+ buf = plasma_client.get(object_id)
+ del plasma_client
+ assert (buf == np.zeros(3)).all()
+ del buf # This segfaulted pre ARROW-2448.
+
+
+@pytest.mark.plasma
+def test_plasma_list():
+ import pyarrow.plasma as plasma
+
+ with plasma.start_plasma_store(
+ plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY) \
+ as (plasma_store_name, p):
+ plasma_client = plasma.connect(plasma_store_name)
+
+ # Test sizes
+ u, _, _ = create_object(plasma_client, 11, metadata_size=7, seal=False)
+ l1 = plasma_client.list()
+ assert l1[u]["data_size"] == 11
+ assert l1[u]["metadata_size"] == 7
+
+ # Test ref_count
+ v = plasma_client.put(np.zeros(3))
+ # Ref count has already been released
+ # XXX flaky test, disabled (ARROW-3344)
+ # l2 = plasma_client.list()
+ # assert l2[v]["ref_count"] == 0
+ a = plasma_client.get(v)
+ l3 = plasma_client.list()
+ assert l3[v]["ref_count"] == 1
+ del a
+
+ # Test state
+ w, _, _ = create_object(plasma_client, 3, metadata_size=0, seal=False)
+ l4 = plasma_client.list()
+ assert l4[w]["state"] == "created"
+ plasma_client.seal(w)
+ l5 = plasma_client.list()
+ assert l5[w]["state"] == "sealed"
+
+ # Test timestamps
+ slack = 1.5 # seconds
+ t1 = time.time()
+ x, _, _ = create_object(plasma_client, 3, metadata_size=0, seal=False)
+ t2 = time.time()
+ l6 = plasma_client.list()
+ assert t1 - slack <= l6[x]["create_time"] <= t2 + slack
+ time.sleep(2.0)
+ t3 = time.time()
+ plasma_client.seal(x)
+ t4 = time.time()
+ l7 = plasma_client.list()
+ assert t3 - t2 - slack <= l7[x]["construct_duration"]
+ assert l7[x]["construct_duration"] <= t4 - t1 + slack
+
+
+@pytest.mark.plasma
+def test_object_id_randomness():
+ cmd = "from pyarrow import plasma; print(plasma.ObjectID.from_random())"
+ first_object_id = subprocess.check_output([sys.executable, "-c", cmd])
+ second_object_id = subprocess.check_output([sys.executable, "-c", cmd])
+ assert first_object_id != second_object_id
+
+
+@pytest.mark.plasma
+def test_store_capacity():
+ import pyarrow.plasma as plasma
+ with plasma.start_plasma_store(plasma_store_memory=10000) as (name, p):
+ plasma_client = plasma.connect(name)
+ assert plasma_client.store_capacity() == 10000
diff --git a/src/arrow/python/pyarrow/tests/test_plasma_tf_op.py b/src/arrow/python/pyarrow/tests/test_plasma_tf_op.py
new file mode 100644
index 000000000..53ecae217
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_plasma_tf_op.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import pytest
+
+
+def run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
+ client, use_gpu, dtype):
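+    """Round-trip two tensors through the Plasma TF op and verify them."""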
+ FORCE_DEVICE = '/gpu' if use_gpu else '/cpu'
+
+ object_id = np.random.bytes(20)
+
+ data = np.random.randn(3, 244, 244).astype(dtype)
+ ones = np.ones((3, 244, 244)).astype(dtype)
+
+ sess = tf.Session(config=tf.ConfigProto(
+ allow_soft_placement=True, log_device_placement=True))
+
+ def ToPlasma():
+ data_tensor = tf.constant(data)
+ ones_tensor = tf.constant(ones)
+ return plasma.tf_plasma_op.tensor_to_plasma(
+ [data_tensor, ones_tensor],
+ object_id,
+ plasma_store_socket_name=plasma_store_name)
+
+ def FromPlasma():
+ return plasma.tf_plasma_op.plasma_to_tensor(
+ object_id,
+ dtype=tf.as_dtype(dtype),
+ plasma_store_socket_name=plasma_store_name)
+
+ with tf.device(FORCE_DEVICE):
+ to_plasma = ToPlasma()
+ from_plasma = FromPlasma()
+
+ z = from_plasma + 1
+
+ sess.run(to_plasma)
+ # NOTE(zongheng): currently it returns a flat 1D tensor.
+ # So reshape manually.
+ out = sess.run(from_plasma)
+
+ out = np.split(out, 2)
+ out0 = out[0].reshape(3, 244, 244)
+ out1 = out[1].reshape(3, 244, 244)
+
+ sess.run(z)
+
+ assert np.array_equal(data, out0), "Data not equal!"
+ assert np.array_equal(ones, out1), "Data not equal!"
+
+ # Try getting the data from Python
+ plasma_object_id = plasma.ObjectID(object_id)
+ obj = client.get(plasma_object_id)
+
+ # Deserialized Tensor should be 64-byte aligned.
+ assert obj.ctypes.data % 64 == 0
+
+ result = np.split(obj, 2)
+ result0 = result[0].reshape(3, 244, 244)
+ result1 = result[1].reshape(3, 244, 244)
+
+ assert np.array_equal(data, result0), "Data not equal!"
+ assert np.array_equal(ones, result1), "Data not equal!"
+
+
+@pytest.mark.plasma
+@pytest.mark.tensorflow
+@pytest.mark.skip(reason='Until ARROW-4259 is resolved')
+def test_plasma_tf_op(use_gpu=False):
+ import pyarrow.plasma as plasma
+ import tensorflow as tf
+
+ plasma.build_plasma_tensorflow_op()
+
+ if plasma.tf_plasma_op is None:
+ pytest.skip("TensorFlow Op not found")
+
+ with plasma.start_plasma_store(10**8) as (plasma_store_name, p):
+ client = plasma.connect(plasma_store_name)
+ for dtype in [np.float32, np.float64,
+ np.int8, np.int16, np.int32, np.int64]:
+ run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name,
+ client, use_gpu, dtype)
+
+ # Make sure the objects have been released.
+ for _, info in client.list().items():
+ assert info['ref_count'] == 0
diff --git a/src/arrow/python/pyarrow/tests/test_scalars.py b/src/arrow/python/pyarrow/tests/test_scalars.py
new file mode 100644
index 000000000..778ce1066
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_scalars.py
@@ -0,0 +1,687 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import decimal
+import pickle
+import pytest
+import weakref
+
+import numpy as np
+
+import pyarrow as pa
+
+
+@pytest.mark.parametrize(['value', 'ty', 'klass', 'deprecated'], [
+ (False, None, pa.BooleanScalar, pa.BooleanValue),
+ (True, None, pa.BooleanScalar, pa.BooleanValue),
+ (1, None, pa.Int64Scalar, pa.Int64Value),
+ (-1, None, pa.Int64Scalar, pa.Int64Value),
+ (1, pa.int8(), pa.Int8Scalar, pa.Int8Value),
+ (1, pa.uint8(), pa.UInt8Scalar, pa.UInt8Value),
+ (1, pa.int16(), pa.Int16Scalar, pa.Int16Value),
+ (1, pa.uint16(), pa.UInt16Scalar, pa.UInt16Value),
+ (1, pa.int32(), pa.Int32Scalar, pa.Int32Value),
+ (1, pa.uint32(), pa.UInt32Scalar, pa.UInt32Value),
+ (1, pa.int64(), pa.Int64Scalar, pa.Int64Value),
+ (1, pa.uint64(), pa.UInt64Scalar, pa.UInt64Value),
+ (1.0, None, pa.DoubleScalar, pa.DoubleValue),
+ (np.float16(1.0), pa.float16(), pa.HalfFloatScalar, pa.HalfFloatValue),
+ (1.0, pa.float32(), pa.FloatScalar, pa.FloatValue),
+ (decimal.Decimal("1.123"), None, pa.Decimal128Scalar, pa.Decimal128Value),
+ (decimal.Decimal("1.1234567890123456789012345678901234567890"),
+ None, pa.Decimal256Scalar, pa.Decimal256Value),
+ ("string", None, pa.StringScalar, pa.StringValue),
+ (b"bytes", None, pa.BinaryScalar, pa.BinaryValue),
+ ("largestring", pa.large_string(), pa.LargeStringScalar,
+ pa.LargeStringValue),
+ (b"largebytes", pa.large_binary(), pa.LargeBinaryScalar,
+ pa.LargeBinaryValue),
+ (b"abc", pa.binary(3), pa.FixedSizeBinaryScalar, pa.FixedSizeBinaryValue),
+ ([1, 2, 3], None, pa.ListScalar, pa.ListValue),
+ ([1, 2, 3, 4], pa.large_list(pa.int8()), pa.LargeListScalar,
+ pa.LargeListValue),
+ ([1, 2, 3, 4, 5], pa.list_(pa.int8(), 5), pa.FixedSizeListScalar,
+ pa.FixedSizeListValue),
+ (datetime.date.today(), None, pa.Date32Scalar, pa.Date32Value),
+ (datetime.date.today(), pa.date64(), pa.Date64Scalar, pa.Date64Value),
+ (datetime.datetime.now(), None, pa.TimestampScalar, pa.TimestampValue),
+ (datetime.datetime.now().time().replace(microsecond=0), pa.time32('s'),
+ pa.Time32Scalar, pa.Time32Value),
+ (datetime.datetime.now().time(), None, pa.Time64Scalar, pa.Time64Value),
+ (datetime.timedelta(days=1), None, pa.DurationScalar, pa.DurationValue),
+ (pa.MonthDayNano([1, -1, -10100]), None,
+ pa.MonthDayNanoIntervalScalar, None),
+ ({'a': 1, 'b': [1, 2]}, None, pa.StructScalar, pa.StructValue),
+ ([('a', 1), ('b', 2)], pa.map_(pa.string(), pa.int8()), pa.MapScalar,
+ pa.MapValue),
+])
+def test_basics(value, ty, klass, deprecated):
+ s = pa.scalar(value, type=ty)
+ assert isinstance(s, klass)
+ assert s.as_py() == value
+ assert s == pa.scalar(value, type=ty)
+ assert s != value
+ assert s != "else"
+ assert hash(s) == hash(s)
+ assert s.is_valid is True
+ assert s != None # noqa: E711
+ if deprecated is not None:
+ with pytest.warns(FutureWarning):
+ assert isinstance(s, deprecated)
+
+ s = pa.scalar(None, type=s.type)
+ assert s.is_valid is False
+ assert s.as_py() is None
+ assert s != pa.scalar(value, type=ty)
+
+ # test pickle roundtrip
+ restored = pickle.loads(pickle.dumps(s))
+ assert s.equals(restored)
+
+ # test that scalars are weak-referenceable
+ wr = weakref.ref(s)
+ assert wr() is not None
+ del s
+ assert wr() is None
+
+
+def test_null_singleton():
+ with pytest.raises(RuntimeError):
+ pa.NullScalar()
+
+
+def test_nulls():
+ null = pa.scalar(None)
+ assert null is pa.NA
+ assert null.as_py() is None
+ assert null != "something"
+ assert (null == pa.scalar(None)) is True
+ assert (null == 0) is False
+ assert pa.NA == pa.NA
+ assert pa.NA not in [5]
+
+ arr = pa.array([None, None])
+ for v in arr:
+ assert v is pa.NA
+ assert v.as_py() is None
+
+ # test pickle roundtrip
+ restored = pickle.loads(pickle.dumps(null))
+ assert restored.equals(null)
+
+ # test that scalars are weak-referenceable
+ wr = weakref.ref(null)
+ assert wr() is not None
+ del null
+ assert wr() is not None # singleton
+
+
+def test_hashing():
+ # ARROW-640
+ values = list(range(500))
+ arr = pa.array(values + values)
+ set_from_array = set(arr)
+ assert isinstance(set_from_array, set)
+ assert len(set_from_array) == 500
+
+
+def test_bool():
+ false = pa.scalar(False)
+ true = pa.scalar(True)
+
+ assert isinstance(false, pa.BooleanScalar)
+ assert isinstance(true, pa.BooleanScalar)
+
+ assert repr(true) == "<pyarrow.BooleanScalar: True>"
+ assert str(true) == "True"
+ assert repr(false) == "<pyarrow.BooleanScalar: False>"
+ assert str(false) == "False"
+
+ assert true.as_py() is True
+ assert false.as_py() is False
+
+
+def test_numerics():
+ # int64
+ s = pa.scalar(1)
+ assert isinstance(s, pa.Int64Scalar)
+ assert repr(s) == "<pyarrow.Int64Scalar: 1>"
+ assert str(s) == "1"
+ assert s.as_py() == 1
+
+ with pytest.raises(OverflowError):
+ pa.scalar(-1, type='uint8')
+
+ # float64
+ s = pa.scalar(1.5)
+ assert isinstance(s, pa.DoubleScalar)
+ assert repr(s) == "<pyarrow.DoubleScalar: 1.5>"
+ assert str(s) == "1.5"
+ assert s.as_py() == 1.5
+
+ # float16
+ s = pa.scalar(np.float16(0.5), type='float16')
+ assert isinstance(s, pa.HalfFloatScalar)
+ assert repr(s) == "<pyarrow.HalfFloatScalar: 0.5>"
+ assert str(s) == "0.5"
+ assert s.as_py() == 0.5
+
+
+def test_decimal128():
+ v = decimal.Decimal("1.123")
+ s = pa.scalar(v)
+ assert isinstance(s, pa.Decimal128Scalar)
+ assert s.as_py() == v
+ assert s.type == pa.decimal128(4, 3)
+
+ v = decimal.Decimal("1.1234")
+ with pytest.raises(pa.ArrowInvalid):
+ pa.scalar(v, type=pa.decimal128(4, scale=3))
+ with pytest.raises(pa.ArrowInvalid):
+ pa.scalar(v, type=pa.decimal128(5, scale=3))
+
+ s = pa.scalar(v, type=pa.decimal128(5, scale=4))
+ assert isinstance(s, pa.Decimal128Scalar)
+ assert s.as_py() == v
+
+
+def test_decimal256():
+ v = decimal.Decimal("1234567890123456789012345678901234567890.123")
+ s = pa.scalar(v)
+ assert isinstance(s, pa.Decimal256Scalar)
+ assert s.as_py() == v
+ assert s.type == pa.decimal256(43, 3)
+
+ v = decimal.Decimal("1.1234")
+ with pytest.raises(pa.ArrowInvalid):
+ pa.scalar(v, type=pa.decimal256(4, scale=3))
+ with pytest.raises(pa.ArrowInvalid):
+ pa.scalar(v, type=pa.decimal256(5, scale=3))
+
+ s = pa.scalar(v, type=pa.decimal256(5, scale=4))
+ assert isinstance(s, pa.Decimal256Scalar)
+ assert s.as_py() == v
+
+
+def test_date():
+ # ARROW-5125
+ d1 = datetime.date(3200, 1, 1)
+ d2 = datetime.date(1960, 1, 1)
+
+ for ty in [pa.date32(), pa.date64()]:
+ for d in [d1, d2]:
+ s = pa.scalar(d, type=ty)
+ assert s.as_py() == d
+
+
+def test_date_cast():
+    # ARROW-10472 - casting of scalars doesn't segfault
+ scalar = pa.scalar(datetime.datetime(2012, 1, 1), type=pa.timestamp("us"))
+ expected = datetime.date(2012, 1, 1)
+ for ty in [pa.date32(), pa.date64()]:
+ result = scalar.cast(ty)
+ assert result.as_py() == expected
+
+
+def test_time():
+ t1 = datetime.time(18, 0)
+ t2 = datetime.time(21, 0)
+
+ types = [pa.time32('s'), pa.time32('ms'), pa.time64('us'), pa.time64('ns')]
+ for ty in types:
+ for t in [t1, t2]:
+ s = pa.scalar(t, type=ty)
+ assert s.as_py() == t
+
+
+def test_cast():
+ val = pa.scalar(5, type='int8')
+ assert val.cast('int64') == pa.scalar(5, type='int64')
+ assert val.cast('uint32') == pa.scalar(5, type='uint32')
+ assert val.cast('string') == pa.scalar('5', type='string')
+ with pytest.raises(ValueError):
+ pa.scalar('foo').cast('int32')
+
+
+@pytest.mark.pandas
+def test_timestamp():
+ import pandas as pd
+ arr = pd.date_range('2000-01-01 12:34:56', periods=10).values
+
+ units = ['ns', 'us', 'ms', 's']
+
+ for i, unit in enumerate(units):
+ dtype = 'datetime64[{}]'.format(unit)
+ arrow_arr = pa.Array.from_pandas(arr.astype(dtype))
+ expected = pd.Timestamp('2000-01-01 12:34:56')
+
+ assert arrow_arr[0].as_py() == expected
+ assert arrow_arr[0].value * 1000**i == expected.value
+
+ tz = 'America/New_York'
+ arrow_type = pa.timestamp(unit, tz=tz)
+
+ dtype = 'datetime64[{}]'.format(unit)
+ arrow_arr = pa.Array.from_pandas(arr.astype(dtype), type=arrow_type)
+ expected = (pd.Timestamp('2000-01-01 12:34:56')
+ .tz_localize('utc')
+ .tz_convert(tz))
+
+ assert arrow_arr[0].as_py() == expected
+ assert arrow_arr[0].value * 1000**i == expected.value
+
+
+@pytest.mark.nopandas
+def test_timestamp_nanos_nopandas():
+ # ARROW-5450
+ import pytz
+ tz = 'America/New_York'
+ ty = pa.timestamp('ns', tz=tz)
+
+ # 2000-01-01 00:00:00 + 1 microsecond
+ s = pa.scalar(946684800000000000 + 1000, type=ty)
+
+ tzinfo = pytz.timezone(tz)
+ expected = datetime.datetime(2000, 1, 1, microsecond=1, tzinfo=tzinfo)
+ expected = tzinfo.fromutc(expected)
+ result = s.as_py()
+ assert result == expected
+ assert result.year == 1999
+ assert result.hour == 19
+
+ # Non-zero nanos yields ValueError
+ s = pa.scalar(946684800000000001, type=ty)
+ with pytest.raises(ValueError):
+ s.as_py()
+
+
+def test_timestamp_no_overflow():
+ # ARROW-5450
+ import pytz
+
+ timestamps = [
+ datetime.datetime(1, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
+ datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=pytz.utc),
+ datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc),
+ ]
+ for ts in timestamps:
+ s = pa.scalar(ts, type=pa.timestamp("us", tz="UTC"))
+ assert s.as_py() == ts
+
+
+def test_duration():
+ arr = np.array([0, 3600000000000], dtype='timedelta64[ns]')
+
+ units = ['us', 'ms', 's']
+
+ for i, unit in enumerate(units):
+ dtype = 'timedelta64[{}]'.format(unit)
+ arrow_arr = pa.array(arr.astype(dtype))
+ expected = datetime.timedelta(seconds=60*60)
+ assert isinstance(arrow_arr[1].as_py(), datetime.timedelta)
+ assert arrow_arr[1].as_py() == expected
+ assert (arrow_arr[1].value * 1000**(i+1) ==
+ expected.total_seconds() * 1e9)
+
+
+@pytest.mark.pandas
+def test_duration_nanos_pandas():
+ import pandas as pd
+ arr = pa.array([0, 3600000000000], type=pa.duration('ns'))
+ expected = pd.Timedelta('1 hour')
+ assert isinstance(arr[1].as_py(), pd.Timedelta)
+ assert arr[1].as_py() == expected
+ assert arr[1].value == expected.value
+
+ # Non-zero nanos work fine
+ arr = pa.array([946684800000000001], type=pa.duration('ns'))
+ assert arr[0].as_py() == pd.Timedelta(946684800000000001, unit='ns')
+
+
+@pytest.mark.nopandas
+def test_duration_nanos_nopandas():
+ arr = pa.array([0, 3600000000000], pa.duration('ns'))
+ expected = datetime.timedelta(seconds=60*60)
+ assert isinstance(arr[1].as_py(), datetime.timedelta)
+ assert arr[1].as_py() == expected
+ assert arr[1].value == expected.total_seconds() * 1e9
+
+ # Non-zero nanos yields ValueError
+ arr = pa.array([946684800000000001], type=pa.duration('ns'))
+ with pytest.raises(ValueError):
+ arr[0].as_py()
+
+
+def test_month_day_nano_interval():
+ triple = pa.MonthDayNano([-3600, 1800, -50])
+ arr = pa.array([triple])
+ assert isinstance(arr[0].as_py(), pa.MonthDayNano)
+ assert arr[0].as_py() == triple
+ assert arr[0].value == triple
+
+
+@pytest.mark.parametrize('value', ['foo', 'mañana'])
+@pytest.mark.parametrize(('ty', 'scalar_typ'), [
+ (pa.string(), pa.StringScalar),
+ (pa.large_string(), pa.LargeStringScalar)
+])
+def test_string(value, ty, scalar_typ):
+ s = pa.scalar(value, type=ty)
+ assert isinstance(s, scalar_typ)
+ assert s.as_py() == value
+ assert s.as_py() != 'something'
+ assert repr(value) in repr(s)
+ assert str(s) == str(value)
+
+ buf = s.as_buffer()
+ assert isinstance(buf, pa.Buffer)
+ assert buf.to_pybytes() == value.encode()
+
+
+@pytest.mark.parametrize('value', [b'foo', b'bar'])
+@pytest.mark.parametrize(('ty', 'scalar_typ'), [
+ (pa.binary(), pa.BinaryScalar),
+ (pa.large_binary(), pa.LargeBinaryScalar)
+])
+def test_binary(value, ty, scalar_typ):
+ s = pa.scalar(value, type=ty)
+ assert isinstance(s, scalar_typ)
+ assert s.as_py() == value
+ assert str(s) == str(value)
+ assert repr(value) in repr(s)
+ assert s.as_py() == value
+ assert s != b'xxxxx'
+
+ buf = s.as_buffer()
+ assert isinstance(buf, pa.Buffer)
+ assert buf.to_pybytes() == value
+
+
+def test_fixed_size_binary():
+ s = pa.scalar(b'foof', type=pa.binary(4))
+ assert isinstance(s, pa.FixedSizeBinaryScalar)
+ assert s.as_py() == b'foof'
+
+ with pytest.raises(pa.ArrowInvalid):
+ pa.scalar(b'foof5', type=pa.binary(4))
+
+
+@pytest.mark.parametrize(('ty', 'klass'), [
+ (pa.list_(pa.string()), pa.ListScalar),
+ (pa.large_list(pa.string()), pa.LargeListScalar)
+])
+def test_list(ty, klass):
+ v = ['foo', None]
+ s = pa.scalar(v, type=ty)
+ assert s.type == ty
+ assert len(s) == 2
+ assert isinstance(s.values, pa.Array)
+ assert s.values.to_pylist() == v
+ assert isinstance(s, klass)
+ assert repr(v) in repr(s)
+ assert s.as_py() == v
+ assert s[0].as_py() == 'foo'
+ assert s[1].as_py() is None
+ assert s[-1] == s[1]
+ assert s[-2] == s[0]
+ with pytest.raises(IndexError):
+ s[-3]
+ with pytest.raises(IndexError):
+ s[2]
+
+
+def test_list_from_numpy():
+    s = pa.scalar(np.array([1, 2, 3], dtype=np.int64))
+ assert s.type == pa.list_(pa.int64())
+ assert s.as_py() == [1, 2, 3]
+
+
+@pytest.mark.pandas
+def test_list_from_pandas():
+ import pandas as pd
+
+ s = pa.scalar(pd.Series([1, 2, 3]))
+ assert s.as_py() == [1, 2, 3]
+
+ cases = [
+ (np.nan, 'null'),
+ (['string', np.nan], pa.list_(pa.binary())),
+ (['string', np.nan], pa.list_(pa.utf8())),
+ ([b'string', np.nan], pa.list_(pa.binary(6))),
+ ([True, np.nan], pa.list_(pa.bool_())),
+ ([decimal.Decimal('0'), np.nan], pa.list_(pa.decimal128(12, 2))),
+ ]
+ for case, ty in cases:
+        # Either ValueError or TypeError is raised; this could be unified.
+ with pytest.raises((ValueError, TypeError)):
+ pa.scalar(case, type=ty)
+
+ # from_pandas option suppresses failure
+ s = pa.scalar(case, type=ty, from_pandas=True)
+
+
+def test_fixed_size_list():
+ s = pa.scalar([1, None, 3], type=pa.list_(pa.int64(), 3))
+
+ assert len(s) == 3
+ assert isinstance(s, pa.FixedSizeListScalar)
+ assert repr(s) == "<pyarrow.FixedSizeListScalar: [1, None, 3]>"
+ assert s.as_py() == [1, None, 3]
+ assert s[0].as_py() == 1
+ assert s[1].as_py() is None
+ assert s[-1] == s[2]
+ with pytest.raises(IndexError):
+ s[-4]
+ with pytest.raises(IndexError):
+ s[3]
+
+
+def test_struct():
+ ty = pa.struct([
+ pa.field('x', pa.int16()),
+ pa.field('y', pa.float32())
+ ])
+
+ v = {'x': 2, 'y': 3.5}
+ s = pa.scalar(v, type=ty)
+ assert list(s) == list(s.keys()) == ['x', 'y']
+ assert list(s.values()) == [
+ pa.scalar(2, type=pa.int16()),
+ pa.scalar(3.5, type=pa.float32())
+ ]
+ assert list(s.items()) == [
+ ('x', pa.scalar(2, type=pa.int16())),
+ ('y', pa.scalar(3.5, type=pa.float32()))
+ ]
+ assert 'x' in s
+ assert 'y' in s
+ assert 'z' not in s
+ assert 0 not in s
+
+ assert s.as_py() == v
+ assert repr(s) != repr(v)
+ assert repr(s.as_py()) == repr(v)
+ assert len(s) == 2
+ assert isinstance(s['x'], pa.Int16Scalar)
+ assert isinstance(s['y'], pa.FloatScalar)
+ assert s['x'].as_py() == 2
+ assert s['y'].as_py() == 3.5
+
+ with pytest.raises(KeyError):
+ s['non-existent']
+
+ s = pa.scalar(None, type=ty)
+ assert list(s) == list(s.keys()) == ['x', 'y']
+ assert s.as_py() is None
+ assert 'x' in s
+ assert 'y' in s
+ assert isinstance(s['x'], pa.Int16Scalar)
+ assert isinstance(s['y'], pa.FloatScalar)
+ assert s['x'].is_valid is False
+ assert s['y'].is_valid is False
+ assert s['x'].as_py() is None
+ assert s['y'].as_py() is None
+
+
+def test_struct_duplicate_fields():
+ ty = pa.struct([
+ pa.field('x', pa.int16()),
+ pa.field('y', pa.float32()),
+ pa.field('x', pa.int64()),
+ ])
+ s = pa.scalar([('x', 1), ('y', 2.0), ('x', 3)], type=ty)
+
+ assert list(s) == list(s.keys()) == ['x', 'y', 'x']
+ assert len(s) == 3
+ assert s == s
+ assert list(s.items()) == [
+ ('x', pa.scalar(1, pa.int16())),
+ ('y', pa.scalar(2.0, pa.float32())),
+ ('x', pa.scalar(3, pa.int64()))
+ ]
+
+ assert 'x' in s
+ assert 'y' in s
+ assert 'z' not in s
+ assert 0 not in s
+
+ # getitem with field names fails for duplicate fields, works for others
+ with pytest.raises(KeyError):
+ s['x']
+
+ assert isinstance(s['y'], pa.FloatScalar)
+ assert s['y'].as_py() == 2.0
+
+ # getitem with integer index works for all fields
+ assert isinstance(s[0], pa.Int16Scalar)
+ assert s[0].as_py() == 1
+ assert isinstance(s[1], pa.FloatScalar)
+ assert s[1].as_py() == 2.0
+ assert isinstance(s[2], pa.Int64Scalar)
+ assert s[2].as_py() == 3
+
+ assert "pyarrow.StructScalar" in repr(s)
+
+ with pytest.raises(ValueError, match="duplicate field names"):
+ s.as_py()
+
+
+def test_map():
+ ty = pa.map_(pa.string(), pa.int8())
+ v = [('a', 1), ('b', 2)]
+ s = pa.scalar(v, type=ty)
+
+ assert len(s) == 2
+ assert isinstance(s, pa.MapScalar)
+ assert isinstance(s.values, pa.Array)
+ assert repr(s) == "<pyarrow.MapScalar: [('a', 1), ('b', 2)]>"
+ assert s.values.to_pylist() == [
+ {'key': 'a', 'value': 1},
+ {'key': 'b', 'value': 2}
+ ]
+
+ # test iteration
+ for i, j in zip(s, v):
+ assert i == j
+
+ assert s.as_py() == v
+ assert s[1] == (
+ pa.scalar('b', type=pa.string()),
+ pa.scalar(2, type=pa.int8())
+ )
+ assert s[-1] == s[1]
+ assert s[-2] == s[0]
+ with pytest.raises(IndexError):
+ s[-3]
+ with pytest.raises(IndexError):
+ s[2]
+
+ restored = pickle.loads(pickle.dumps(s))
+ assert restored.equals(s)
+
+
+def test_dictionary():
+ indices = pa.array([2, None, 1, 2, 0, None])
+ dictionary = pa.array(['foo', 'bar', 'baz'])
+
+ arr = pa.DictionaryArray.from_arrays(indices, dictionary)
+ expected = ['baz', None, 'bar', 'baz', 'foo', None]
+ assert arr.to_pylist() == expected
+
+ for j, (i, v) in enumerate(zip(indices, expected)):
+ s = arr[j]
+
+ assert s.as_py() == v
+ assert s.value.as_py() == v
+ assert s.index.equals(i)
+ assert s.dictionary.equals(dictionary)
+
+ with pytest.warns(FutureWarning):
+ assert s.index_value.equals(i)
+ with pytest.warns(FutureWarning):
+ assert s.dictionary_value.as_py() == v
+
+ restored = pickle.loads(pickle.dumps(s))
+ assert restored.equals(s)
+
+
+def test_union():
+ # sparse
+ arr = pa.UnionArray.from_sparse(
+ pa.array([0, 0, 1, 1], type=pa.int8()),
+ [
+ pa.array(["a", "b", "c", "d"]),
+ pa.array([1, 2, 3, 4])
+ ]
+ )
+ for s in arr:
+ assert isinstance(s, pa.UnionScalar)
+ assert s.type.equals(arr.type)
+ assert s.is_valid is True
+ with pytest.raises(pa.ArrowNotImplementedError):
+ pickle.loads(pickle.dumps(s))
+
+ assert arr[0].type_code == 0
+ assert arr[0].as_py() == "a"
+ assert arr[1].type_code == 0
+ assert arr[1].as_py() == "b"
+ assert arr[2].type_code == 1
+ assert arr[2].as_py() == 3
+ assert arr[3].type_code == 1
+ assert arr[3].as_py() == 4
+
+ # dense
+ arr = pa.UnionArray.from_dense(
+ types=pa.array([0, 1, 0, 0, 1, 1, 0], type='int8'),
+ value_offsets=pa.array([0, 0, 2, 1, 1, 2, 3], type='int32'),
+ children=[
+ pa.array([b'a', b'b', b'c', b'd'], type='binary'),
+ pa.array([1, 2, 3], type='int64')
+ ]
+ )
+ for s in arr:
+ assert isinstance(s, pa.UnionScalar)
+ assert s.type.equals(arr.type)
+ assert s.is_valid is True
+ with pytest.raises(pa.ArrowNotImplementedError):
+ pickle.loads(pickle.dumps(s))
+
+ assert arr[0].type_code == 0
+ assert arr[0].as_py() == b'a'
+ assert arr[5].type_code == 1
+ assert arr[5].as_py() == 3
diff --git a/src/arrow/python/pyarrow/tests/test_schema.py b/src/arrow/python/pyarrow/tests/test_schema.py
new file mode 100644
index 000000000..f26eaaf5f
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_schema.py
@@ -0,0 +1,730 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import OrderedDict
+import pickle
+import sys
+import weakref
+
+import pytest
+import numpy as np
+import pyarrow as pa
+
+import pyarrow.tests.util as test_util
+from pyarrow.vendored.version import Version
+
+
+def test_schema_constructor_errors():
+ msg = ("Do not call Schema's constructor directly, use `pyarrow.schema` "
+ "instead")
+ with pytest.raises(TypeError, match=msg):
+ pa.Schema()
+
+
+def test_type_integers():
+ dtypes = ['int8', 'int16', 'int32', 'int64',
+ 'uint8', 'uint16', 'uint32', 'uint64']
+
+ for name in dtypes:
+ factory = getattr(pa, name)
+ t = factory()
+ assert str(t) == name
+
+
+def test_type_to_pandas_dtype():
+ M8_ns = np.dtype('datetime64[ns]')
+ cases = [
+ (pa.null(), np.object_),
+ (pa.bool_(), np.bool_),
+ (pa.int8(), np.int8),
+ (pa.int16(), np.int16),
+ (pa.int32(), np.int32),
+ (pa.int64(), np.int64),
+ (pa.uint8(), np.uint8),
+ (pa.uint16(), np.uint16),
+ (pa.uint32(), np.uint32),
+ (pa.uint64(), np.uint64),
+ (pa.float16(), np.float16),
+ (pa.float32(), np.float32),
+ (pa.float64(), np.float64),
+ (pa.date32(), M8_ns),
+ (pa.date64(), M8_ns),
+ (pa.timestamp('ms'), M8_ns),
+ (pa.binary(), np.object_),
+ (pa.binary(12), np.object_),
+ (pa.string(), np.object_),
+ (pa.list_(pa.int8()), np.object_),
+ # (pa.list_(pa.int8(), 2), np.object_), # TODO needs pandas conversion
+ (pa.map_(pa.int64(), pa.float64()), np.object_),
+ ]
+ for arrow_type, numpy_type in cases:
+ assert arrow_type.to_pandas_dtype() == numpy_type
+
+
+@pytest.mark.pandas
+def test_type_to_pandas_dtype_check_import():
+ # ARROW-7980
+ test_util.invoke_script('arrow_7980.py')
+
+
+def test_type_list():
+ value_type = pa.int32()
+ list_type = pa.list_(value_type)
+ assert str(list_type) == 'list<item: int32>'
+
+ field = pa.field('my_item', pa.string())
+ l2 = pa.list_(field)
+ assert str(l2) == 'list<my_item: string>'
+
+
+def test_type_comparisons():
+ val = pa.int32()
+ assert val == pa.int32()
+ assert val == 'int32'
+ assert val != 5
+
+
+def test_type_for_alias():
+ cases = [
+ ('i1', pa.int8()),
+ ('int8', pa.int8()),
+ ('i2', pa.int16()),
+ ('int16', pa.int16()),
+ ('i4', pa.int32()),
+ ('int32', pa.int32()),
+ ('i8', pa.int64()),
+ ('int64', pa.int64()),
+ ('u1', pa.uint8()),
+ ('uint8', pa.uint8()),
+ ('u2', pa.uint16()),
+ ('uint16', pa.uint16()),
+ ('u4', pa.uint32()),
+ ('uint32', pa.uint32()),
+ ('u8', pa.uint64()),
+ ('uint64', pa.uint64()),
+ ('f4', pa.float32()),
+ ('float32', pa.float32()),
+ ('f8', pa.float64()),
+ ('float64', pa.float64()),
+ ('date32', pa.date32()),
+ ('date64', pa.date64()),
+ ('string', pa.string()),
+ ('str', pa.string()),
+ ('binary', pa.binary()),
+ ('time32[s]', pa.time32('s')),
+ ('time32[ms]', pa.time32('ms')),
+ ('time64[us]', pa.time64('us')),
+ ('time64[ns]', pa.time64('ns')),
+ ('timestamp[s]', pa.timestamp('s')),
+ ('timestamp[ms]', pa.timestamp('ms')),
+ ('timestamp[us]', pa.timestamp('us')),
+ ('timestamp[ns]', pa.timestamp('ns')),
+ ('duration[s]', pa.duration('s')),
+ ('duration[ms]', pa.duration('ms')),
+ ('duration[us]', pa.duration('us')),
+ ('duration[ns]', pa.duration('ns')),
+ ('month_day_nano_interval', pa.month_day_nano_interval()),
+ ]
+
+ for val, expected in cases:
+ assert pa.type_for_alias(val) == expected
+
+
+def test_type_string():
+ t = pa.string()
+ assert str(t) == 'string'
+
+
+def test_type_timestamp_with_tz():
+ tz = 'America/Los_Angeles'
+ t = pa.timestamp('ns', tz=tz)
+ assert t.unit == 'ns'
+ assert t.tz == tz
+
+
+def test_time_types():
+ t1 = pa.time32('s')
+ t2 = pa.time32('ms')
+ t3 = pa.time64('us')
+ t4 = pa.time64('ns')
+
+ assert t1.unit == 's'
+ assert t2.unit == 'ms'
+ assert t3.unit == 'us'
+ assert t4.unit == 'ns'
+
+ assert str(t1) == 'time32[s]'
+ assert str(t4) == 'time64[ns]'
+
+ with pytest.raises(ValueError):
+ pa.time32('us')
+
+ with pytest.raises(ValueError):
+ pa.time64('s')
+
+
+def test_from_numpy_dtype():
+ cases = [
+ (np.dtype('bool'), pa.bool_()),
+ (np.dtype('int8'), pa.int8()),
+ (np.dtype('int16'), pa.int16()),
+ (np.dtype('int32'), pa.int32()),
+ (np.dtype('int64'), pa.int64()),
+ (np.dtype('uint8'), pa.uint8()),
+ (np.dtype('uint16'), pa.uint16()),
+ (np.dtype('uint32'), pa.uint32()),
+ (np.dtype('float16'), pa.float16()),
+ (np.dtype('float32'), pa.float32()),
+ (np.dtype('float64'), pa.float64()),
+ (np.dtype('U'), pa.string()),
+ (np.dtype('S'), pa.binary()),
+ (np.dtype('datetime64[s]'), pa.timestamp('s')),
+ (np.dtype('datetime64[ms]'), pa.timestamp('ms')),
+ (np.dtype('datetime64[us]'), pa.timestamp('us')),
+ (np.dtype('datetime64[ns]'), pa.timestamp('ns')),
+ (np.dtype('timedelta64[s]'), pa.duration('s')),
+ (np.dtype('timedelta64[ms]'), pa.duration('ms')),
+ (np.dtype('timedelta64[us]'), pa.duration('us')),
+ (np.dtype('timedelta64[ns]'), pa.duration('ns')),
+ ]
+
+ for dt, pt in cases:
+ result = pa.from_numpy_dtype(dt)
+ assert result == pt
+
+ # Things convertible to numpy dtypes work
+ assert pa.from_numpy_dtype('U') == pa.string()
+ assert pa.from_numpy_dtype(np.str_) == pa.string()
+ assert pa.from_numpy_dtype('int32') == pa.int32()
+ assert pa.from_numpy_dtype(bool) == pa.bool_()
+
+ with pytest.raises(NotImplementedError):
+ pa.from_numpy_dtype(np.dtype('O'))
+
+ with pytest.raises(TypeError):
+ pa.from_numpy_dtype('not_convertible_to_dtype')
+
+
+def test_schema():
+ fields = [
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('baz', pa.list_(pa.int8()))
+ ]
+ sch = pa.schema(fields)
+
+ assert sch.names == ['foo', 'bar', 'baz']
+ assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
+
+ assert len(sch) == 3
+ assert sch[0].name == 'foo'
+ assert sch[0].type == fields[0].type
+ assert sch.field('foo').name == 'foo'
+ assert sch.field('foo').type == fields[0].type
+
+ assert repr(sch) == """\
+foo: int32
+bar: string
+baz: list<item: int8>
+ child 0, item: int8"""
+
+ with pytest.raises(TypeError):
+ pa.schema([None])
+
+
+def test_schema_weakref():
+ fields = [
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('baz', pa.list_(pa.int8()))
+ ]
+ schema = pa.schema(fields)
+ wr = weakref.ref(schema)
+ assert wr() is not None
+ del schema
+ assert wr() is None
+
+
+def test_schema_to_string_with_metadata():
+ lorem = """\
+Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nulla accumsan vel
+turpis et mollis. Aliquam tincidunt arcu id tortor blandit blandit. Donec
+eget leo quis lectus scelerisque varius. Class aptent taciti sociosqu ad
+litora torquent per conubia nostra, per inceptos himenaeos. Praesent
+faucibus, diam eu volutpat iaculis, tellus est porta ligula, a efficitur
+turpis nulla facilisis quam. Aliquam vitae lorem erat. Proin a dolor ac libero
+dignissim mollis vitae eu mauris. Quisque posuere tellus vitae massa
+pellentesque sagittis. Aenean feugiat, diam ac dignissim fermentum, lorem
+sapien commodo massa, vel volutpat orci nisi eu justo. Nulla non blandit
+sapien. Quisque pretium vestibulum urna eu vehicula."""
+ # ARROW-7063
+ my_schema = pa.schema([pa.field("foo", "int32", False,
+ metadata={"key1": "value1"}),
+ pa.field("bar", "string", True,
+ metadata={"key3": "value3"})],
+ metadata={"lorem": lorem})
+
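+ # By default long metadata values are truncated: only the first 65
+ # characters are shown, followed by "' + <number of remaining characters>".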
+ assert my_schema.to_string() == """\
+foo: int32 not null
+ -- field metadata --
+ key1: 'value1'
+bar: string
+ -- field metadata --
+ key3: 'value3'
+-- schema metadata --
+lorem: '""" + lorem[:65] + "' + " + str(len(lorem) - 65)
+
+ # Metadata that exactly fits
+ result = pa.schema([('f0', 'int32')],
+ metadata={'key': 'value' + 'x' * 62}).to_string()
+ assert result == """\
+f0: int32
+-- schema metadata --
+key: 'valuexxxxxxxxxxxxxxxxxxxxxxxxxxxxx\
+xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'"""
+
+ assert my_schema.to_string(truncate_metadata=False) == """\
+foo: int32 not null
+ -- field metadata --
+ key1: 'value1'
+bar: string
+ -- field metadata --
+ key3: 'value3'
+-- schema metadata --
+lorem: '{}'""".format(lorem)
+
+ assert my_schema.to_string(truncate_metadata=False,
+ show_field_metadata=False) == """\
+foo: int32 not null
+bar: string
+-- schema metadata --
+lorem: '{}'""".format(lorem)
+
+ assert my_schema.to_string(truncate_metadata=False,
+ show_schema_metadata=False) == """\
+foo: int32 not null
+ -- field metadata --
+ key1: 'value1'
+bar: string
+ -- field metadata --
+ key3: 'value3'"""
+
+ assert my_schema.to_string(truncate_metadata=False,
+ show_field_metadata=False,
+ show_schema_metadata=False) == """\
+foo: int32 not null
+bar: string"""
+
+
+def test_schema_from_tuples():
+ fields = [
+ ('foo', pa.int32()),
+ ('bar', pa.string()),
+ ('baz', pa.list_(pa.int8())),
+ ]
+ sch = pa.schema(fields)
+ assert sch.names == ['foo', 'bar', 'baz']
+ assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
+ assert len(sch) == 3
+ assert repr(sch) == """\
+foo: int32
+bar: string
+baz: list<item: int8>
+ child 0, item: int8"""
+
+ with pytest.raises(TypeError):
+ pa.schema([('foo', None)])
+
+
+def test_schema_from_mapping():
+ fields = OrderedDict([
+ ('foo', pa.int32()),
+ ('bar', pa.string()),
+ ('baz', pa.list_(pa.int8())),
+ ])
+ sch = pa.schema(fields)
+ assert sch.names == ['foo', 'bar', 'baz']
+ assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
+ assert len(sch) == 3
+ assert repr(sch) == """\
+foo: int32
+bar: string
+baz: list<item: int8>
+ child 0, item: int8"""
+
+ fields = OrderedDict([('foo', None)])
+ with pytest.raises(TypeError):
+ pa.schema(fields)
+
+
+def test_schema_duplicate_fields():
+ fields = [
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('foo', pa.list_(pa.int8())),
+ ]
+ sch = pa.schema(fields)
+ assert sch.names == ['foo', 'bar', 'foo']
+ assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())]
+ assert len(sch) == 3
+ assert repr(sch) == """\
+foo: int32
+bar: string
+foo: list<item: int8>
+ child 0, item: int8"""
+
+ assert sch[0].name == 'foo'
+ assert sch[0].type == fields[0].type
+ with pytest.warns(FutureWarning):
+ assert sch.field_by_name('bar') == fields[1]
+ with pytest.warns(FutureWarning):
+ assert sch.field_by_name('xxx') is None
+ with pytest.warns((UserWarning, FutureWarning)):
+ assert sch.field_by_name('foo') is None
+
+ # Schema::GetFieldIndex
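+ # 'foo' appears twice, so a lookup by name is ambiguous and returns -1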
+ assert sch.get_field_index('foo') == -1
+
+ # Schema::GetAllFieldIndices
+ assert sch.get_all_field_indices('foo') == [0, 2]
+
+
+def test_field_flatten():
+ f0 = pa.field('foo', pa.int32()).with_metadata({b'foo': b'bar'})
+ assert f0.flatten() == [f0]
+
+ f1 = pa.field('bar', pa.float64(), nullable=False)
+ ff = pa.field('ff', pa.struct([f0, f1]), nullable=False)
+ assert ff.flatten() == [
+ pa.field('ff.foo', pa.int32()).with_metadata({b'foo': b'bar'}),
+ pa.field('ff.bar', pa.float64(), nullable=False)]  # non-nullable parent keeps child non-nullable
+
+ # Nullable parent makes flattened child nullable
+ ff = pa.field('ff', pa.struct([f0, f1]))
+ assert ff.flatten() == [
+ pa.field('ff.foo', pa.int32()).with_metadata({b'foo': b'bar'}),
+ pa.field('ff.bar', pa.float64())]
+
+ fff = pa.field('fff', pa.struct([ff]))
+ assert fff.flatten() == [pa.field('fff.ff', pa.struct([f0, f1]))]
+
+
+def test_schema_add_remove_metadata():
+ fields = [
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('baz', pa.list_(pa.int8()))
+ ]
+
+ s1 = pa.schema(fields)
+
+ assert s1.metadata is None
+
+ metadata = {b'foo': b'bar', b'pandas': b'badger'}
+
+ s2 = s1.with_metadata(metadata)
+ assert s2.metadata == metadata
+
+ s3 = s2.remove_metadata()
+ assert s3.metadata is None
+
+ # idempotent
+ s4 = s3.remove_metadata()
+ assert s4.metadata is None
+
+
+def test_schema_equals():
+ fields = [
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('baz', pa.list_(pa.int8()))
+ ]
+ metadata = {b'foo': b'bar', b'pandas': b'badger'}
+
+ sch1 = pa.schema(fields)
+ sch2 = pa.schema(fields)
+ sch3 = pa.schema(fields, metadata=metadata)
+ sch4 = pa.schema(fields, metadata=metadata)
+
+ assert sch1.equals(sch2, check_metadata=True)
+ assert sch3.equals(sch4, check_metadata=True)
+ assert sch1.equals(sch3)
+ assert not sch1.equals(sch3, check_metadata=True)
+
+ del fields[-1]
+ sch3 = pa.schema(fields)
+ assert not sch1.equals(sch3)
+
+
+def test_schema_equals_propagates_check_metadata():
+ # ARROW-4088
+ schema1 = pa.schema([
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string())
+ ])
+ schema2 = pa.schema([
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string(), metadata={'a': 'alpha'}),
+ ])
+ assert not schema1.equals(schema2, check_metadata=True)
+ assert schema1.equals(schema2)
+
+
+def test_schema_equals_invalid_type():
+ # ARROW-5873
+ schema = pa.schema([pa.field("a", pa.int64())])
+
+ for val in [None, 'string', pa.array([1, 2])]:
+ with pytest.raises(TypeError):
+ schema.equals(val)
+
+
+def test_schema_equality_operators():
+ fields = [
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('baz', pa.list_(pa.int8()))
+ ]
+ metadata = {b'foo': b'bar', b'pandas': b'badger'}
+
+ sch1 = pa.schema(fields)
+ sch2 = pa.schema(fields)
+ sch3 = pa.schema(fields, metadata=metadata)
+ sch4 = pa.schema(fields, metadata=metadata)
+
+ assert sch1 == sch2
+ assert sch3 == sch4
+
+ # __eq__ and __ne__ do not check metadata
+ assert sch1 == sch3
+ assert not sch1 != sch3
+
+ assert sch2 == sch4
+
+ # comparison with other types doesn't raise
+ assert sch1 != []
+ assert sch3 != 'foo'
+
+
+def test_schema_get_fields():
+ fields = [
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('baz', pa.list_(pa.int8()))
+ ]
+
+ schema = pa.schema(fields)
+
+ assert schema.field('foo').name == 'foo'
+ assert schema.field(0).name == 'foo'
+ assert schema.field(-1).name == 'baz'
+
+ with pytest.raises(KeyError):
+ schema.field('other')
+ with pytest.raises(TypeError):
+ schema.field(0.0)
+ with pytest.raises(IndexError):
+ schema.field(4)
+
+
+def test_schema_negative_indexing():
+ fields = [
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('baz', pa.list_(pa.int8()))
+ ]
+
+ schema = pa.schema(fields)
+
+ assert schema[-1].equals(schema[2])
+ assert schema[-2].equals(schema[1])
+ assert schema[-3].equals(schema[0])
+
+ with pytest.raises(IndexError):
+ schema[-4]
+
+ with pytest.raises(IndexError):
+ schema[3]
+
+
+def test_schema_repr_with_dictionaries():
+ fields = [
+ pa.field('one', pa.dictionary(pa.int16(), pa.string())),
+ pa.field('two', pa.int32())
+ ]
+ sch = pa.schema(fields)
+
+ expected = (
+ """\
+one: dictionary<values=string, indices=int16, ordered=0>
+two: int32""")
+
+ assert repr(sch) == expected
+
+
+def test_type_schema_pickling():
+ cases = [
+ pa.int8(),
+ pa.string(),
+ pa.binary(),
+ pa.binary(10),
+ pa.list_(pa.string()),
+ pa.map_(pa.string(), pa.int8()),
+ pa.struct([
+ pa.field('a', 'int8'),
+ pa.field('b', 'string')
+ ]),
+ pa.union([
+ pa.field('a', pa.int8()),
+ pa.field('b', pa.int16())
+ ], pa.lib.UnionMode_SPARSE),
+ pa.union([
+ pa.field('a', pa.int8()),
+ pa.field('b', pa.int16())
+ ], pa.lib.UnionMode_DENSE),
+ pa.time32('s'),
+ pa.time64('us'),
+ pa.date32(),
+ pa.date64(),
+ pa.timestamp('ms'),
+ pa.timestamp('ns'),
+ pa.decimal128(12, 2),
+ pa.decimal256(76, 38),
+ pa.field('a', 'string', metadata={b'foo': b'bar'}),
+ pa.list_(pa.field("element", pa.int64())),
+ pa.large_list(pa.field("element", pa.int64())),
+ pa.map_(pa.field("key", pa.string(), nullable=False),
+ pa.field("value", pa.int8()))
+ ]
+
+ for val in cases:
+ roundtripped = pickle.loads(pickle.dumps(val))
+ assert val == roundtripped
+
+ fields = []
+ for i, f in enumerate(cases):
+ if isinstance(f, pa.Field):
+ fields.append(f)
+ else:
+ fields.append(pa.field('_f{}'.format(i), f))
+
+ schema = pa.schema(fields, metadata={b'foo': b'bar'})
+ roundtripped = pickle.loads(pickle.dumps(schema))
+ assert schema == roundtripped
+
+
+def test_empty_table():
+ schema1 = pa.schema([
+ pa.field('f0', pa.int64()),
+ pa.field('f1', pa.dictionary(pa.int32(), pa.string())),
+ pa.field('f2', pa.list_(pa.list_(pa.int64()))),
+ ])
+ # test it preserves field nullability
+ schema2 = pa.schema([
+ pa.field('a', pa.int64(), nullable=False),
+ pa.field('b', pa.int64())
+ ])
+
+ for schema in [schema1, schema2]:
+ table = schema.empty_table()
+ assert isinstance(table, pa.Table)
+ assert table.num_rows == 0
+ assert table.schema == schema
+
+
+@pytest.mark.pandas
+def test_schema_from_pandas():
+ import pandas as pd
+ inputs = [
+ list(range(10)),
+ pd.Categorical(list(range(10))),
+ ['foo', 'bar', None, 'baz', 'qux'],
+ np.array([
+ '2007-07-13T01:23:34.123456789',
+ '2006-01-13T12:34:56.432539784',
+ '2010-08-13T05:46:57.437699912'
+ ], dtype='datetime64[ns]'),
+ ]
+ if Version(pd.__version__) >= Version('1.0.0'):
+ inputs.append(pd.array([1, 2, None], dtype=pd.Int32Dtype()))
+ for data in inputs:
+ df = pd.DataFrame({'a': data})
+ schema = pa.Schema.from_pandas(df)
+ expected = pa.Table.from_pandas(df).schema
+ assert schema == expected
+
+
+def test_schema_sizeof():
+ schema = pa.schema([
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ ])
+
+ assert sys.getsizeof(schema) > 30
+
+ schema2 = schema.with_metadata({"key": "some metadata"})
+ assert sys.getsizeof(schema2) > sys.getsizeof(schema)
+ schema3 = schema.with_metadata({"key": "some more metadata"})
+ assert sys.getsizeof(schema3) > sys.getsizeof(schema2)
+
+
+def test_schema_merge():
+ a = pa.schema([
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('baz', pa.list_(pa.int8()))
+ ])
+ b = pa.schema([
+ pa.field('foo', pa.int32()),
+ pa.field('qux', pa.bool_())
+ ])
+ c = pa.schema([
+ pa.field('quux', pa.dictionary(pa.int32(), pa.string()))
+ ])
+ d = pa.schema([
+ pa.field('foo', pa.int64()),
+ pa.field('qux', pa.bool_())
+ ])
+
+ result = pa.unify_schemas([a, b, c])
+ expected = pa.schema([
+ pa.field('foo', pa.int32()),
+ pa.field('bar', pa.string()),
+ pa.field('baz', pa.list_(pa.int8())),
+ pa.field('qux', pa.bool_()),
+ pa.field('quux', pa.dictionary(pa.int32(), pa.string()))
+ ])
+ assert result.equals(expected)
+
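+ # 'foo' is int32 in b but int64 in d, so the schemas cannot be unified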
+ with pytest.raises(pa.ArrowInvalid):
+ pa.unify_schemas([b, d])
+
+ # ARROW-14002: Try with tuple instead of list
+ result = pa.unify_schemas((a, b, c))
+ assert result.equals(expected)
+
+
+def test_undecodable_metadata():
+ # ARROW-10214: undecodable metadata shouldn't fail repr()
+ data1 = b'abcdef\xff\x00'
+ data2 = b'ghijkl\xff\x00'
+ schema = pa.schema(
+ [pa.field('ints', pa.int16(), metadata={'key': data1})],
+ metadata={'key': data2})
+ assert 'abcdef' in str(schema)
+ assert 'ghijkl' in str(schema)
diff --git a/src/arrow/python/pyarrow/tests/test_serialization.py b/src/arrow/python/pyarrow/tests/test_serialization.py
new file mode 100644
index 000000000..750827a50
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_serialization.py
@@ -0,0 +1,1233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import pytest
+
+import collections
+import datetime
+import os
+import pathlib
+import pickle
+import subprocess
+import string
+import sys
+
+import pyarrow as pa
+import numpy as np
+
+import pyarrow.tests.util as test_util
+
+try:
+ import torch
+except ImportError:
+ torch = None
+ # Blacklist the module in case `import torch` is costly before
+ # failing (ARROW-2071)
+ sys.modules['torch'] = None
+
+try:
+ from scipy.sparse import coo_matrix, csr_matrix, csc_matrix
+except ImportError:
+ coo_matrix = None
+ csr_matrix = None
+ csc_matrix = None
+
+try:
+ import sparse
+except ImportError:
+ sparse = None
+
+
+# ignore all serialization deprecation warnings in this file, we test that the
+# warnings are actually raised in test_serialization_deprecated.py
+pytestmark = pytest.mark.filterwarnings("ignore:'pyarrow:FutureWarning")
+
+
+def assert_equal(obj1, obj2):
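+ # Recursively compare two objects, dispatching on their types: torch
+ # tensors, numpy arrays, dicts, (named) tuples, lists, Arrow containers
+ # and plain values each get an appropriate equality check.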
+ if torch is not None and torch.is_tensor(obj1) and torch.is_tensor(obj2):
+ if obj1.is_sparse:
+ obj1 = obj1.to_dense()
+ if obj2.is_sparse:
+ obj2 = obj2.to_dense()
+ assert torch.equal(obj1, obj2)
+ return
+ module_numpy = (type(obj1).__module__ == np.__name__ or
+ type(obj2).__module__ == np.__name__)
+ if module_numpy:
+ empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or
+ (hasattr(obj2, "shape") and obj2.shape == ()))
+ if empty_shape:
+ # Special case: np.testing.assert_equal currently fails here
+ # because we do not properly handle different numerical types.
+ assert obj1 == obj2, ("Objects {} and {} are "
+ "different.".format(obj1, obj2))
+ else:
+ np.testing.assert_equal(obj1, obj2)
+ elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
+ special_keys = ["_pytype_"]
+ assert (set(list(obj1.__dict__.keys()) + special_keys) ==
+ set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} "
+ "and {} are "
+ "different."
+ .format(
+ obj1,
+ obj2))
+ if obj1.__dict__ == {}:
+ print("WARNING: Empty dict in ", obj1)
+ for key in obj1.__dict__.keys():
+ if key not in special_keys:
+ assert_equal(obj1.__dict__[key], obj2.__dict__[key])
+ elif type(obj1) is dict or type(obj2) is dict:
+ assert_equal(obj1.keys(), obj2.keys())
+ for key in obj1.keys():
+ assert_equal(obj1[key], obj2[key])
+ elif type(obj1) is list or type(obj2) is list:
+ assert len(obj1) == len(obj2), ("Objects {} and {} are lists with "
+ "different lengths."
+ .format(obj1, obj2))
+ for i in range(len(obj1)):
+ assert_equal(obj1[i], obj2[i])
+ elif type(obj1) is tuple or type(obj2) is tuple:
+ assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with "
+ "different lengths."
+ .format(obj1, obj2))
+ for i in range(len(obj1)):
+ assert_equal(obj1[i], obj2[i])
+ elif (pa.lib.is_named_tuple(type(obj1)) or
+ pa.lib.is_named_tuple(type(obj2))):
+ assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples "
+ "with different lengths."
+ .format(obj1, obj2))
+ for i in range(len(obj1)):
+ assert_equal(obj1[i], obj2[i])
+ elif isinstance(obj1, pa.Array) and isinstance(obj2, pa.Array):
+ assert obj1.equals(obj2)
+ elif isinstance(obj1, pa.Tensor) and isinstance(obj2, pa.Tensor):
+ assert obj1.equals(obj2)
+ elif isinstance(obj1, pa.SparseCOOTensor) and \
+ isinstance(obj2, pa.SparseCOOTensor):
+ assert obj1.equals(obj2)
+ elif isinstance(obj1, pa.SparseCSRMatrix) and \
+ isinstance(obj2, pa.SparseCSRMatrix):
+ assert obj1.equals(obj2)
+ elif isinstance(obj1, pa.SparseCSCMatrix) and \
+ isinstance(obj2, pa.SparseCSCMatrix):
+ assert obj1.equals(obj2)
+ elif isinstance(obj1, pa.SparseCSFTensor) and \
+ isinstance(obj2, pa.SparseCSFTensor):
+ assert obj1.equals(obj2)
+ elif isinstance(obj1, pa.RecordBatch) and isinstance(obj2, pa.RecordBatch):
+ assert obj1.equals(obj2)
+ elif isinstance(obj1, pa.Table) and isinstance(obj2, pa.Table):
+ assert obj1.equals(obj2)
+ else:
+ assert type(obj1) == type(obj2) and obj1 == obj2, \
+ "Objects {} and {} are different.".format(obj1, obj2)
+
+
+PRIMITIVE_OBJECTS = [
+ 0, 0.0, 0.9, 1 << 62, 1 << 999,
+ [1 << 100, [1 << 100]], "a", string.printable, "\u262F",
+ "hello world", "hello world", "\xff\xfe\x9c\x001\x000\x00",
+ None, True, False, [], (), {}, {(1, 2): 1}, {(): 2},
+ [1, "hello", 3.0], "\u262F", 42.0, (1.0, "hi"),
+ [1, 2, 3, None], [(None,), 3, 1.0], ["h", "e", "l", "l", "o", None],
+ (None, None), ("hello", None), (True, False),
+ {True: "hello", False: "world"}, {"hello": "world", 1: 42, 2.5: 45},
+ {"hello": {2, 3}, "world": {42.0}, "this": None},
+ np.int8(3), np.int32(4), np.int64(5),
+ np.uint8(3), np.uint32(4), np.uint64(5),
+ np.float16(1.9), np.float32(1.9),
+ np.float64(1.9), np.zeros([8, 20]),
+ np.random.normal(size=[17, 10]), np.array(["hi", 3]),
+ np.array(["hi", 3], dtype=object),
+ np.random.normal(size=[15, 13]).T
+]
+
+
+index_types = ('i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8')
+tensor_types = ('i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8',
+ 'f2', 'f4', 'f8')
+
+PRIMITIVE_OBJECTS += [0, np.array([["hi", "hi"], [1.3, 1]])]
+
+
+COMPLEX_OBJECTS = [
+ [[[[[[[[[[[[]]]]]]]]]]]],
+ {"obj{}".format(i): np.random.normal(size=[4, 4]) for i in range(5)},
+ # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
+ # (): {(): {}}}}}}}}}}}}},
+ ((((((((((),),),),),),),),),),
+ {"a": {"b": {"c": {"d": {}}}}},
+]
+
+
+class Foo:
+ def __init__(self, value=0):
+ self.value = value
+
+ def __hash__(self):
+ return hash(self.value)
+
+ def __eq__(self, other):
+ return other.value == self.value
+
+
+class Bar:
+ def __init__(self):
+ for i, val in enumerate(COMPLEX_OBJECTS):
+ setattr(self, "field{}".format(i), val)
+
+
+class Baz:
+ def __init__(self):
+ self.foo = Foo()
+ self.bar = Bar()
+
+ def method(self, arg):
+ pass
+
+
+class Qux:
+ def __init__(self):
+ self.objs = [Foo(1), Foo(42)]
+
+
+class SubQux(Qux):
+ def __init__(self):
+ Qux.__init__(self)
+
+
+class SubQuxPickle(Qux):
+ def __init__(self):
+ Qux.__init__(self)
+
+
+class CustomError(Exception):
+ pass
+
+
+Point = collections.namedtuple("Point", ["x", "y"])
+NamedTupleExample = collections.namedtuple(
+ "Example", "field1, field2, field3, field4, field5")
+
+
+CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22),
+ Foo(), Bar(), Baz(), Qux(), SubQux(), SubQuxPickle(),
+ NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]),
+ collections.OrderedDict([("hello", 1), ("world", 2)]),
+ collections.deque([1, 2, 3, "a", "b", "c", 3.5]),
+ collections.Counter([1, 1, 1, 2, 2, 3, "a", "b"])]
+
+
+def make_serialization_context():
+ with pytest.warns(FutureWarning):
+ context = pa.default_serialization_context()
+
+ context.register_type(Foo, "Foo")
+ context.register_type(Bar, "Bar")
+ context.register_type(Baz, "Baz")
+ context.register_type(Qux, "Qux")
+ context.register_type(SubQux, "SubQux")
+ context.register_type(SubQuxPickle, "SubQuxPickle", pickle=True)
+ context.register_type(Exception, "Exception")
+ context.register_type(CustomError, "CustomError")
+ context.register_type(Point, "Point")
+ context.register_type(NamedTupleExample, "NamedTupleExample")
+
+ return context
+
+
+global_serialization_context = make_serialization_context()
+
+
+def serialization_roundtrip(value, scratch_buffer,
+ context=global_serialization_context):
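+ # Serialize `value` into the scratch buffer, read it back, and compare
+ # the result with the original; also exercise the to/from components path.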
+ writer = pa.FixedSizeBufferWriter(scratch_buffer)
+ pa.serialize_to(value, writer, context=context)
+
+ reader = pa.BufferReader(scratch_buffer)
+ result = pa.deserialize_from(reader, None, context=context)
+ assert_equal(value, result)
+
+ _check_component_roundtrip(value, context=context)
+
+
+def _check_component_roundtrip(value, context=global_serialization_context):
+ # Test to/from components
+ serialized = pa.serialize(value, context=context)
+ components = serialized.to_components()
+ from_comp = pa.SerializedPyObject.from_components(components)
+ recons = from_comp.deserialize(context=context)
+ assert_equal(value, recons)
+
+
+@pytest.fixture(scope='session')
+def large_buffer(size=32*1024*1024):
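+ # Session-scoped 32 MiB scratch buffer reused by the roundtrip tests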
+ yield pa.allocate_buffer(size)
+
+
+def large_memory_map(tmpdir_factory, size=100*1024*1024):
+ path = (tmpdir_factory.mktemp('data')
+ .join('pyarrow-serialization-tmp-file').strpath)
+
+ # Create a large memory mapped file
+ with open(path, 'wb') as f:
+ f.write(np.random.randint(0, 256, size=size)
+ .astype('u1')
+ .tobytes()
+ [:size])
+ return path
+
+
+def test_clone():
+ context = pa.SerializationContext()
+
+ class Foo:
+ pass
+
+ def custom_serializer(obj):
+ return 0
+
+ def custom_deserializer(serialized_obj):
+ return (serialized_obj, 'a')
+
+ context.register_type(Foo, 'Foo', custom_serializer=custom_serializer,
+ custom_deserializer=custom_deserializer)
+
+ new_context = context.clone()
+
+ f = Foo()
+ serialized = pa.serialize(f, context=context)
+ deserialized = serialized.deserialize(context=context)
+ assert deserialized == (0, 'a')
+
+ serialized = pa.serialize(f, context=new_context)
+ deserialized = serialized.deserialize(context=new_context)
+ assert deserialized == (0, 'a')
+
+
+def test_primitive_serialization_notbroken(large_buffer):
+ serialization_roundtrip({(1, 2): 2}, large_buffer)
+
+
+def test_primitive_serialization_broken(large_buffer):
+ serialization_roundtrip({(): 2}, large_buffer)
+
+
+def test_primitive_serialization(large_buffer):
+ for obj in PRIMITIVE_OBJECTS:
+ serialization_roundtrip(obj, large_buffer)
+
+
+def test_integer_limits(large_buffer):
+ # Check that Numpy scalars can be represented up to their limit values
+ # (except np.uint64 which is limited to 2**63 - 1)
+ for dt in [np.int8, np.int16, np.int32, np.int64,
+ np.uint8, np.uint16, np.uint32, np.uint64]:
+ scal = dt(np.iinfo(dt).min)
+ serialization_roundtrip(scal, large_buffer)
+ if dt is not np.uint64:
+ scal = dt(np.iinfo(dt).max)
+ serialization_roundtrip(scal, large_buffer)
+ else:
+ scal = dt(2**63 - 1)
+ serialization_roundtrip(scal, large_buffer)
+ for v in (2**63, 2**64 - 1):
+ scal = dt(v)
+ with pytest.raises(pa.ArrowInvalid):
+ pa.serialize(scal)
+
+
+def test_serialize_to_buffer():
+ for nthreads in [1, 4]:
+ for value in COMPLEX_OBJECTS:
+ buf = pa.serialize(value).to_buffer(nthreads=nthreads)
+ result = pa.deserialize(buf)
+ assert_equal(value, result)
+
+
+def test_complex_serialization(large_buffer):
+ for obj in COMPLEX_OBJECTS:
+ serialization_roundtrip(obj, large_buffer)
+
+
+def test_custom_serialization(large_buffer):
+ for obj in CUSTOM_OBJECTS:
+ serialization_roundtrip(obj, large_buffer)
+
+
+def test_default_dict_serialization(large_buffer):
+ pytest.importorskip("cloudpickle")
+
+ obj = collections.defaultdict(lambda: 0, [("hello", 1), ("world", 2)])
+ serialization_roundtrip(obj, large_buffer)
+
+
+def test_numpy_serialization(large_buffer):
+ for t in ["bool", "int8", "uint8", "int16", "uint16", "int32",
+ "uint32", "float16", "float32", "float64", "<U1", "<U2", "<U3",
+ "<U4", "|S1", "|S2", "|S3", "|S4", "|O",
+ np.dtype([('a', 'int64'), ('b', 'float')]),
+ np.dtype([('x', 'uint32'), ('y', '<U8')])]:
+ obj = np.random.randint(0, 10, size=(100, 100)).astype(t)
+ serialization_roundtrip(obj, large_buffer)
+ obj = obj[1:99, 10:90]
+ serialization_roundtrip(obj, large_buffer)
+
+
+def test_datetime_serialization(large_buffer):
+ data = [
+ # Principia Mathematica published
+ datetime.datetime(year=1687, month=7, day=5),
+
+ # Some random date
+ datetime.datetime(year=1911, month=6, day=3, hour=4,
+ minute=55, second=44),
+ # End of WWI
+ datetime.datetime(year=1918, month=11, day=11),
+
+ # Beginning of UNIX time
+ datetime.datetime(year=1970, month=1, day=1),
+
+ # The Berlin wall falls
+ datetime.datetime(year=1989, month=11, day=9),
+
+ # Another random date
+ datetime.datetime(year=2011, month=6, day=3, hour=4,
+ minute=0, second=3),
+ # Another random date
+ datetime.datetime(year=1970, month=1, day=3, hour=4,
+ minute=0, second=0)
+ ]
+ for d in data:
+ serialization_roundtrip(d, large_buffer)
+
+
+def test_torch_serialization(large_buffer):
+ pytest.importorskip("torch")
+
+ serialization_context = pa.default_serialization_context()
+ pa.register_torch_serialization_handlers(serialization_context)
+
+ # Dense tensors:
+
+ # These are the only types that are supported for the
+ # PyTorch to NumPy conversion
+ for t in ["float32", "float64",
+ "uint8", "int16", "int32", "int64"]:
+ obj = torch.from_numpy(np.random.randn(1000).astype(t))
+ serialization_roundtrip(obj, large_buffer,
+ context=serialization_context)
+
+ tensor_requiring_grad = torch.randn(10, 10, requires_grad=True)
+ serialization_roundtrip(tensor_requiring_grad, large_buffer,
+ context=serialization_context)
+
+ # Sparse tensors:
+
+ # These are the only types that are supported for the
+ # PyTorch to NumPy conversion
+ for t in ["float32", "float64",
+ "uint8", "int16", "int32", "int64"]:
+ i = torch.LongTensor([[0, 2], [1, 0], [1, 2]])
+ v = torch.from_numpy(np.array([3, 4, 5]).astype(t))
+ obj = torch.sparse_coo_tensor(i.t(), v, torch.Size([2, 3]))
+ serialization_roundtrip(obj, large_buffer,
+ context=serialization_context)
+
+
+@pytest.mark.skipif(not torch or not torch.cuda.is_available(),
+ reason="requires pytorch with CUDA")
+def test_torch_cuda():
+ # ARROW-2920: This used to segfault if torch is not imported
+ # before pyarrow
+ # Note that this test will only catch the issue if it is run
+ # with a pyarrow that has been built in the manylinux1 environment
+ torch.nn.Conv2d(64, 2, kernel_size=3, stride=1,
+ padding=1, bias=False).cuda()
+
+
+def test_numpy_immutable(large_buffer):
+ obj = np.zeros([10])
+
+ writer = pa.FixedSizeBufferWriter(large_buffer)
+ pa.serialize_to(obj, writer, global_serialization_context)
+
+ reader = pa.BufferReader(large_buffer)
+ result = pa.deserialize_from(reader, None, global_serialization_context)
+ with pytest.raises(ValueError):
+ result[0] = 1.0
+
+
+def test_numpy_base_object(tmpdir):
+ # ARROW-2040: deserialized Numpy array should keep a reference to the
+ # owner of its memory
+ path = os.path.join(str(tmpdir), 'zzz.bin')
+ data = np.arange(12, dtype=np.int32)
+
+ with open(path, 'wb') as f:
+ f.write(pa.serialize(data).to_buffer())
+
+ serialized = pa.read_serialized(pa.OSFile(path))
+ result = serialized.deserialize()
+ assert_equal(result, data)
+ serialized = None
+ assert_equal(result, data)
+ assert result.base is not None
+
+
+# see https://issues.apache.org/jira/browse/ARROW-1695
+def test_serialization_callback_numpy():
+
+ class DummyClass:
+ pass
+
+ def serialize_dummy_class(obj):
+ x = np.zeros(4)
+ return x
+
+ def deserialize_dummy_class(serialized_obj):
+ return serialized_obj
+
+ context = pa.default_serialization_context()
+ context.register_type(DummyClass, "DummyClass",
+ custom_serializer=serialize_dummy_class,
+ custom_deserializer=deserialize_dummy_class)
+
+ pa.serialize(DummyClass(), context=context)
+
+
+def test_numpy_subclass_serialization():
+ # Check that we can properly serialize subclasses of np.ndarray.
+ class CustomNDArray(np.ndarray):
+ def __new__(cls, input_array):
+ array = np.asarray(input_array).view(cls)
+ return array
+
+ def serializer(obj):
+ return {'numpy': obj.view(np.ndarray)}
+
+ def deserializer(data):
+ array = data['numpy'].view(CustomNDArray)
+ return array
+
+ context = pa.default_serialization_context()
+
+ context.register_type(CustomNDArray, 'CustomNDArray',
+ custom_serializer=serializer,
+ custom_deserializer=deserializer)
+
+ x = CustomNDArray(np.zeros(3))
+ serialized = pa.serialize(x, context=context).to_buffer()
+ new_x = pa.deserialize(serialized, context=context)
+ assert type(new_x) == CustomNDArray
+ assert np.alltrue(new_x.view(np.ndarray) == np.zeros(3))
+
+
+@pytest.mark.parametrize('tensor_type', tensor_types)
+@pytest.mark.parametrize('index_type', index_types)
+def test_sparse_coo_tensor_serialization(index_type, tensor_type):
+ tensor_dtype = np.dtype(tensor_type)
+ index_dtype = np.dtype(index_type)
+ data = np.array([[1, 2, 3, 4, 5, 6]]).T.astype(tensor_dtype)
+ coords = np.array([
+ [0, 0, 2, 3, 1, 3],
+ [0, 2, 0, 4, 5, 5],
+ ]).T.astype(index_dtype)
+ shape = (4, 6)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = pa.SparseCOOTensor.from_numpy(data, coords,
+ shape, dim_names)
+
+ context = pa.default_serialization_context()
+ serialized = pa.serialize(sparse_tensor, context=context).to_buffer()
+ result = pa.deserialize(serialized)
+ assert_equal(result, sparse_tensor)
+ assert isinstance(result, pa.SparseCOOTensor)
+
+ data_result, coords_result = result.to_numpy()
+ assert np.array_equal(data_result, data)
+ assert np.array_equal(coords_result, coords)
+ assert result.dim_names == dim_names
+
+
+@pytest.mark.parametrize('tensor_type', tensor_types)
+@pytest.mark.parametrize('index_type', index_types)
+def test_sparse_coo_tensor_components_serialization(large_buffer,
+ index_type, tensor_type):
+ tensor_dtype = np.dtype(tensor_type)
+ index_dtype = np.dtype(index_type)
+ data = np.array([[1, 2, 3, 4, 5, 6]]).T.astype(tensor_dtype)
+ coords = np.array([
+ [0, 0, 2, 3, 1, 3],
+ [0, 2, 0, 4, 5, 5],
+ ]).T.astype(index_dtype)
+ shape = (4, 6)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = pa.SparseCOOTensor.from_numpy(data, coords,
+ shape, dim_names)
+ serialization_roundtrip(sparse_tensor, large_buffer)
+
+
+@pytest.mark.skipif(not coo_matrix, reason="requires scipy")
+def test_scipy_sparse_coo_tensor_serialization():
+ data = np.array([1, 2, 3, 4, 5, 6])
+ row = np.array([0, 0, 2, 3, 1, 3])
+ col = np.array([0, 2, 0, 4, 5, 5])
+ shape = (4, 6)
+
+ sparse_array = coo_matrix((data, (row, col)), shape=shape)
+ serialized = pa.serialize(sparse_array)
+ result = serialized.deserialize()
+
+ assert np.array_equal(sparse_array.toarray(), result.toarray())
+
+
+@pytest.mark.skipif(not sparse, reason="requires pydata/sparse")
+def test_pydata_sparse_sparse_coo_tensor_serialization():
+ data = np.array([1, 2, 3, 4, 5, 6])
+ coords = np.array([
+ [0, 0, 2, 3, 1, 3],
+ [0, 2, 0, 4, 5, 5],
+ ])
+ shape = (4, 6)
+
+ sparse_array = sparse.COO(data=data, coords=coords, shape=shape)
+ serialized = pa.serialize(sparse_array)
+ result = serialized.deserialize()
+
+ assert np.array_equal(sparse_array.todense(), result.todense())
+
+
+@pytest.mark.parametrize('tensor_type', tensor_types)
+@pytest.mark.parametrize('index_type', index_types)
+def test_sparse_csr_matrix_serialization(index_type, tensor_type):
+ tensor_dtype = np.dtype(tensor_type)
+ index_dtype = np.dtype(index_type)
+ data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(tensor_dtype)
+ indptr = np.array([0, 2, 3, 4, 6]).astype(index_dtype)
+ indices = np.array([0, 2, 5, 0, 4, 5]).astype(index_dtype)
+ shape = (4, 6)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = pa.SparseCSRMatrix.from_numpy(data, indptr, indices,
+ shape, dim_names)
+
+ context = pa.default_serialization_context()
+ serialized = pa.serialize(sparse_tensor, context=context).to_buffer()
+ result = pa.deserialize(serialized)
+ assert_equal(result, sparse_tensor)
+ assert isinstance(result, pa.SparseCSRMatrix)
+
+ data_result, indptr_result, indices_result = result.to_numpy()
+ assert np.array_equal(data_result, data)
+ assert np.array_equal(indptr_result, indptr)
+ assert np.array_equal(indices_result, indices)
+ assert result.dim_names == dim_names
+
+
+@pytest.mark.parametrize('tensor_type', tensor_types)
+@pytest.mark.parametrize('index_type', index_types)
+def test_sparse_csr_matrix_components_serialization(large_buffer,
+ index_type, tensor_type):
+ tensor_dtype = np.dtype(tensor_type)
+ index_dtype = np.dtype(index_type)
+ data = np.array([8, 2, 5, 3, 4, 6]).astype(tensor_dtype)
+ indptr = np.array([0, 2, 3, 4, 6]).astype(index_dtype)
+ indices = np.array([0, 2, 5, 0, 4, 5]).astype(index_dtype)
+ shape = (4, 6)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = pa.SparseCSRMatrix.from_numpy(data, indptr, indices,
+ shape, dim_names)
+ serialization_roundtrip(sparse_tensor, large_buffer)
+
+
+@pytest.mark.skipif(not csr_matrix, reason="requires scipy")
+def test_scipy_sparse_csr_matrix_serialization():
+ data = np.array([8, 2, 5, 3, 4, 6])
+ indptr = np.array([0, 2, 3, 4, 6])
+ indices = np.array([0, 2, 5, 0, 4, 5])
+ shape = (4, 6)
+
+ sparse_array = csr_matrix((data, indices, indptr), shape=shape)
+ serialized = pa.serialize(sparse_array)
+ result = serialized.deserialize()
+
+ assert np.array_equal(sparse_array.toarray(), result.toarray())
+
+
+@pytest.mark.parametrize('tensor_type', tensor_types)
+@pytest.mark.parametrize('index_type', index_types)
+def test_sparse_csc_matrix_serialization(index_type, tensor_type):
+ tensor_dtype = np.dtype(tensor_type)
+ index_dtype = np.dtype(index_type)
+ data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(tensor_dtype)
+ indptr = np.array([0, 2, 3, 4, 6]).astype(index_dtype)
+ indices = np.array([0, 2, 5, 0, 4, 5]).astype(index_dtype)
+ shape = (6, 4)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = pa.SparseCSCMatrix.from_numpy(data, indptr, indices,
+ shape, dim_names)
+
+ context = pa.default_serialization_context()
+ serialized = pa.serialize(sparse_tensor, context=context).to_buffer()
+ result = pa.deserialize(serialized)
+ assert_equal(result, sparse_tensor)
+ assert isinstance(result, pa.SparseCSCMatrix)
+
+ data_result, indptr_result, indices_result = result.to_numpy()
+ assert np.array_equal(data_result, data)
+ assert np.array_equal(indptr_result, indptr)
+ assert np.array_equal(indices_result, indices)
+ assert result.dim_names == dim_names
+
+
+@pytest.mark.parametrize('tensor_type', tensor_types)
+@pytest.mark.parametrize('index_type', index_types)
+def test_sparse_csc_matrix_components_serialization(large_buffer,
+ index_type, tensor_type):
+ tensor_dtype = np.dtype(tensor_type)
+ index_dtype = np.dtype(index_type)
+ data = np.array([8, 2, 5, 3, 4, 6]).astype(tensor_dtype)
+ indptr = np.array([0, 2, 3, 6]).astype(index_dtype)
+ indices = np.array([0, 2, 2, 0, 1, 2]).astype(index_dtype)
+ shape = (3, 3)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = pa.SparseCSCMatrix.from_numpy(data, indptr, indices,
+ shape, dim_names)
+ serialization_roundtrip(sparse_tensor, large_buffer)
+
+
+@pytest.mark.skipif(not csc_matrix, reason="requires scipy")
+def test_scipy_sparse_csc_matrix_serialization():
+ data = np.array([8, 2, 5, 3, 4, 6])
+ indptr = np.array([0, 2, 3, 4, 6])
+ indices = np.array([0, 2, 5, 0, 4, 5])
+ shape = (6, 4)
+
+ sparse_array = csc_matrix((data, indices, indptr), shape=shape)
+ serialized = pa.serialize(sparse_array)
+ result = serialized.deserialize()
+
+ assert np.array_equal(sparse_array.toarray(), result.toarray())
+
+
+@pytest.mark.parametrize('tensor_type', tensor_types)
+@pytest.mark.parametrize('index_type', index_types)
+def test_sparse_csf_tensor_serialization(index_type, tensor_type):
+ tensor_dtype = np.dtype(tensor_type)
+ index_dtype = np.dtype(index_type)
+ data = np.array([[1, 2, 3, 4, 5, 6, 7, 8]]).T.astype(tensor_dtype)
+ indptr = [
+ np.array([0, 2, 3]),
+ np.array([0, 1, 3, 4]),
+ np.array([0, 2, 4, 5, 8]),
+ ]
+ indices = [
+ np.array([0, 1]),
+ np.array([0, 1, 1]),
+ np.array([0, 0, 1, 1]),
+ np.array([1, 2, 0, 2, 0, 0, 1, 2]),
+ ]
+ indptr = [x.astype(index_dtype) for x in indptr]
+ indices = [x.astype(index_dtype) for x in indices]
+ shape = (2, 3, 4, 5)
+ axis_order = (0, 1, 2, 3)
+ dim_names = ("a", "b", "c", "d")
+
+ for ndim in [2, 3, 4]:
+ sparse_tensor = pa.SparseCSFTensor.from_numpy(data, indptr[:ndim - 1],
+ indices[:ndim],
+ shape[:ndim],
+ axis_order[:ndim],
+ dim_names[:ndim])
+
+ context = pa.default_serialization_context()
+ serialized = pa.serialize(sparse_tensor, context=context).to_buffer()
+ result = pa.deserialize(serialized)
+ assert_equal(result, sparse_tensor)
+ assert isinstance(result, pa.SparseCSFTensor)
+
+
+@pytest.mark.parametrize('tensor_type', tensor_types)
+@pytest.mark.parametrize('index_type', index_types)
+def test_sparse_csf_tensor_components_serialization(large_buffer,
+ index_type, tensor_type):
+ tensor_dtype = np.dtype(tensor_type)
+ index_dtype = np.dtype(index_type)
+ data = np.array([[1, 2, 3, 4, 5, 6, 7, 8]]).T.astype(tensor_dtype)
+ indptr = [
+ np.array([0, 2, 3]),
+ np.array([0, 1, 3, 4]),
+ np.array([0, 2, 4, 5, 8]),
+ ]
+ indices = [
+ np.array([0, 1]),
+ np.array([0, 1, 1]),
+ np.array([0, 0, 1, 1]),
+ np.array([1, 2, 0, 2, 0, 0, 1, 2]),
+ ]
+ indptr = [x.astype(index_dtype) for x in indptr]
+ indices = [x.astype(index_dtype) for x in indices]
+ shape = (2, 3, 4, 5)
+ axis_order = (0, 1, 2, 3)
+ dim_names = ("a", "b", "c", "d")
+
+ for ndim in [2, 3, 4]:
+ sparse_tensor = pa.SparseCSFTensor.from_numpy(data, indptr[:ndim - 1],
+ indices[:ndim],
+ shape[:ndim],
+ axis_order[:ndim],
+ dim_names[:ndim])
+
+ serialization_roundtrip(sparse_tensor, large_buffer)
+
+
+@pytest.mark.filterwarnings(
+ "ignore:the matrix subclass:PendingDeprecationWarning")
+def test_numpy_matrix_serialization(tmpdir):
+ class CustomType:
+ def __init__(self, val):
+ self.val = val
+
+ rec_type = np.dtype([('x', 'int64'), ('y', 'double'), ('z', '<U4')])
+
+ path = os.path.join(str(tmpdir), 'pyarrow_npmatrix_serialization_test.bin')
+ array = np.random.randint(low=-1, high=1, size=(2, 2))
+
+ for data_type in [str, int, float, rec_type, CustomType]:
+ matrix = np.matrix(array.astype(data_type))
+
+ with open(path, 'wb') as f:
+ f.write(pa.serialize(matrix).to_buffer())
+
+ serialized = pa.read_serialized(pa.OSFile(path))
+ result = serialized.deserialize()
+ assert_equal(result, matrix)
+ assert_equal(result.dtype, matrix.dtype)
+ serialized = None
+ assert_equal(result, matrix)
+ assert result.base is not None
+
+
+def test_pyarrow_objects_serialization(large_buffer):
+ # NOTE: We have to create these objects inside the test function,
+ # otherwise their allocations affect 'test_total_bytes_allocated'.
+ pyarrow_objects = [
+ pa.array([1, 2, 3, 4]), pa.array(['1', 'never U+1F631', '',
+ "233 * U+1F600"]),
+ pa.array([1, None, 2, 3]),
+ pa.Tensor.from_numpy(np.random.rand(2, 3, 4)),
+ pa.RecordBatch.from_arrays(
+ [pa.array([1, None, 2, 3]),
+ pa.array(['1', 'never U+1F631', '', "233 * u1F600"])],
+ ['a', 'b']),
+ pa.Table.from_arrays([pa.array([1, None, 2, 3]),
+ pa.array(['1', 'never U+1F631', '',
+ "233 * u1F600"])],
+ ['a', 'b'])
+ ]
+ for obj in pyarrow_objects:
+ serialization_roundtrip(obj, large_buffer)
+
+
+def test_buffer_serialization():
+
+ class BufferClass:
+ pass
+
+ def serialize_buffer_class(obj):
+ return pa.py_buffer(b"hello")
+
+ def deserialize_buffer_class(serialized_obj):
+ return serialized_obj
+
+ context = pa.default_serialization_context()
+ context.register_type(
+ BufferClass, "BufferClass",
+ custom_serializer=serialize_buffer_class,
+ custom_deserializer=deserialize_buffer_class)
+
+ b = pa.serialize(BufferClass(), context=context).to_buffer()
+ assert pa.deserialize(b, context=context).to_pybytes() == b"hello"
+
+
+@pytest.mark.skip(reason="extensive memory requirements")
+def test_arrow_limits():
+ def huge_memory_map(temp_dir):
+ return large_memory_map(temp_dir, 100 * 1024 * 1024 * 1024)
+
+ with pa.memory_map(huge_memory_map, mode="r+") as mmap:
+ # Test that objects that are too large for Arrow throw a Python
+ # exception. These tests give out of memory errors on Travis and need
+ # to be run on a machine with lots of RAM.
+ x = 2 ** 29 * [1.0]
+ serialization_roundtrip(x, mmap)
+ del x
+ x = 2 ** 29 * ["s"]
+ serialization_roundtrip(x, mmap)
+ del x
+ x = 2 ** 29 * [["1"], 2, 3, [{"s": 4}]]
+ serialization_roundtrip(x, mmap)
+ del x
+ x = 2 ** 29 * [{"s": 1}] + 2 ** 29 * [1.0]
+ serialization_roundtrip(x, mmap)
+ del x
+ x = np.zeros(2 ** 25)
+ serialization_roundtrip(x, mmap)
+ del x
+ x = [np.zeros(2 ** 18) for _ in range(2 ** 7)]
+ serialization_roundtrip(x, mmap)
+ del x
+
+
+def test_serialization_callback_error():
+
+ class TempClass:
+ pass
+
+ # Pass a SerializationContext into serialize, but TempClass
+ # is not registered
+ serialization_context = pa.SerializationContext()
+ val = TempClass()
+ with pytest.raises(pa.SerializationCallbackError) as err:
+ serialized_object = pa.serialize(val, serialization_context)
+ assert err.value.example_object == val
+
+ serialization_context.register_type(TempClass, "TempClass")
+ serialized_object = pa.serialize(TempClass(), serialization_context)
+ deserialization_context = pa.SerializationContext()
+
+ # Pass a Serialization Context into deserialize, but TempClass
+ # is not registered
+ with pytest.raises(pa.DeserializationCallbackError) as err:
+ serialized_object.deserialize(deserialization_context)
+ assert err.value.type_id == "TempClass"
+
+ class TempClass2:
+ pass
+
+ # Make sure that we receive an error when we use an inappropriate value for
+ # the type_id argument.
+ with pytest.raises(TypeError):
+ serialization_context.register_type(TempClass2, 1)
+
+
+def test_fallback_to_subclasses():
+
+ class SubFoo(Foo):
+ def __init__(self):
+ Foo.__init__(self)
+
+ # should be able to serialize/deserialize an instance
+ # if a base class has been registered
+ serialization_context = pa.SerializationContext()
+ serialization_context.register_type(Foo, "Foo")
+
+ subfoo = SubFoo()
+ # should fall back to the Foo serializer
+ serialized_object = pa.serialize(subfoo, serialization_context)
+
+ reconstructed_object = serialized_object.deserialize(
+ serialization_context
+ )
+ assert type(reconstructed_object) == Foo
+
+
+class Serializable:
+ pass
+
+
+def serialize_serializable(obj):
+ return {"type": type(obj), "data": obj.__dict__}
+
+
+def deserialize_serializable(obj):
+ val = obj["type"].__new__(obj["type"])
+ val.__dict__.update(obj["data"])
+ return val
+
+
+class SerializableClass(Serializable):
+ def __init__(self):
+ self.value = 3
+
+
+def test_serialize_subclasses():
+
+ # This test shows how subclasses can be handled in an idiomatic way
+ # by having only a serializer for the base class
+
+ # This technique should, however, be used with care, since pickling
+ # type(obj) with cloudpickle will include the full class definition
+ # in the serialized representation.
+ # This means the class definition is part of every instance of the
+ # object, which in general is not desirable; registering all subclasses
+ # with register_type will result in faster and more memory
+ # efficient serialization.
+
+ context = pa.default_serialization_context()
+ context.register_type(
+ Serializable, "Serializable",
+ custom_serializer=serialize_serializable,
+ custom_deserializer=deserialize_serializable)
+
+ a = SerializableClass()
+ serialized = pa.serialize(a, context=context)
+
+ deserialized = serialized.deserialize(context=context)
+ assert type(deserialized).__name__ == SerializableClass.__name__
+ assert deserialized.value == 3
+
+
+def test_serialize_to_components_invalid_cases():
+ buf = pa.py_buffer(b'hello')
+
+ components = {
+ 'num_tensors': 0,
+ 'num_sparse_tensors': {
+ 'coo': 0, 'csr': 0, 'csc': 0, 'csf': 0, 'ndim_csf': 0
+ },
+ 'num_ndarrays': 0,
+ 'num_buffers': 1,
+ 'data': [buf]
+ }
+
+ with pytest.raises(pa.ArrowInvalid):
+ pa.deserialize_components(components)
+
+ components = {
+ 'num_tensors': 0,
+ 'num_sparse_tensors': {
+ 'coo': 0, 'csr': 0, 'csc': 0, 'csf': 0, 'ndim_csf': 0
+ },
+ 'num_ndarrays': 1,
+ 'num_buffers': 0,
+ 'data': [buf, buf]
+ }
+
+ with pytest.raises(pa.ArrowInvalid):
+ pa.deserialize_components(components)
+
+
+def test_deserialize_components_in_different_process():
+ arr = pa.array([1, 2, 5, 6], type=pa.int8())
+ ser = pa.serialize(arr)
+ data = pickle.dumps(ser.to_components(), protocol=-1)
+
+ code = """if 1:
+ import pickle
+
+ import pyarrow as pa
+
+ data = {!r}
+ components = pickle.loads(data)
+ arr = pa.deserialize_components(components)
+
+ assert arr.to_pylist() == [1, 2, 5, 6], arr
+ """.format(data)
+
+ subprocess_env = test_util.get_modified_env_with_pythonpath()
+ print("** sys.path =", sys.path)
+ print("** setting PYTHONPATH to:", subprocess_env['PYTHONPATH'])
+ subprocess.check_call([sys.executable, "-c", code], env=subprocess_env)
+
+
+def test_serialize_read_concatenated_records():
+ # ARROW-1996 -- see stream alignment work in ARROW-2840, ARROW-3212
+ f = pa.BufferOutputStream()
+ pa.serialize_to(12, f)
+ pa.serialize_to(23, f)
+ buf = f.getvalue()
+
+ f = pa.BufferReader(buf)
+ pa.read_serialized(f).deserialize()
+ pa.read_serialized(f).deserialize()
+
+
+def deserialize_regex(serialized, q):
+ import pyarrow as pa
+ q.put(pa.deserialize(serialized))
+
+
+def test_deserialize_in_different_process():
+ from multiprocessing import Process, Queue
+ import re
+
+ regex = re.compile(r"\d+\.\d*")
+
+ serialization_context = pa.SerializationContext()
+ serialization_context.register_type(type(regex), "Regex", pickle=True)
+
+ serialized = pa.serialize(regex, serialization_context)
+ serialized_bytes = serialized.to_buffer().to_pybytes()
+
+ q = Queue()
+ p = Process(target=deserialize_regex, args=(serialized_bytes, q))
+ p.start()
+ assert q.get().pattern == regex.pattern
+ p.join()
+
+
+def test_deserialize_buffer_in_different_process():
+ import tempfile
+
+ f = tempfile.NamedTemporaryFile(delete=False)
+ b = pa.serialize(pa.py_buffer(b'hello')).to_buffer()
+ f.write(b.to_pybytes())
+ f.close()
+
+ test_util.invoke_script('deserialize_buffer.py', f.name)
+
+
+def test_set_pickle():
+ # Use a custom type to trigger pickling.
+ class Foo:
+ pass
+
+ context = pa.SerializationContext()
+ context.register_type(Foo, 'Foo', pickle=True)
+
+ test_object = Foo()
+
+ # Define a custom serializer and deserializer to use in place of pickle.
+
+ def dumps1(obj):
+ return b'custom'
+
+ def loads1(serialized_obj):
+ return serialized_obj + b' serialization 1'
+
+ # Test that setting a custom pickler changes the behavior.
+ context.set_pickle(dumps1, loads1)
+ serialized = pa.serialize(test_object, context=context).to_buffer()
+ deserialized = pa.deserialize(serialized.to_pybytes(), context=context)
+ assert deserialized == b'custom serialization 1'
+
+ # Define another custom serializer and deserializer.
+
+ def dumps2(obj):
+ return b'custom'
+
+ def loads2(serialized_obj):
+ return serialized_obj + b' serialization 2'
+
+ # Test that setting another custom pickler changes the behavior again.
+ context.set_pickle(dumps2, loads2)
+ serialized = pa.serialize(test_object, context=context).to_buffer()
+ deserialized = pa.deserialize(serialized.to_pybytes(), context=context)
+ assert deserialized == b'custom serialization 2'
+
+
+def test_path_objects(tmpdir):
+ # Test compatibility with PEP 519 path-like objects
+ p = pathlib.Path(tmpdir) / 'zzz.bin'
+ obj = 1234
+ pa.serialize_to(obj, p)
+ res = pa.deserialize_from(p, None)
+ assert res == obj
+
+
+def test_tensor_alignment():
+ # Deserialized numpy arrays should be 64-byte aligned.
+ x = np.random.normal(size=(10, 20, 30))
+ y = pa.deserialize(pa.serialize(x).to_buffer())
+ assert y.ctypes.data % 64 == 0
+
+ xs = [np.random.normal(size=i) for i in range(100)]
+ ys = pa.deserialize(pa.serialize(xs).to_buffer())
+ for y in ys:
+ assert y.ctypes.data % 64 == 0
+
+ xs = [np.random.normal(size=i * (1,)) for i in range(20)]
+ ys = pa.deserialize(pa.serialize(xs).to_buffer())
+ for y in ys:
+ assert y.ctypes.data % 64 == 0
+
+ xs = [np.random.normal(size=i * (5,)) for i in range(1, 8)]
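+ # Slice along every axis to get non-contiguous inputs; the deserialized
+ # copies should still be 64-byte aligned.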
+ xs = [xs[i][(i + 1) * (slice(1, 3),)] for i in range(len(xs))]
+ ys = pa.deserialize(pa.serialize(xs).to_buffer())
+ for y in ys:
+ assert y.ctypes.data % 64 == 0
+
+
+def test_empty_tensor():
+ # ARROW-8122, serialize and deserialize empty tensors
+ x = np.array([], dtype=np.float64)
+ y = pa.deserialize(pa.serialize(x).to_buffer())
+ np.testing.assert_array_equal(x, y)
+
+ x = np.array([[], [], []], dtype=np.float64)
+ y = pa.deserialize(pa.serialize(x).to_buffer())
+ np.testing.assert_array_equal(x, y)
+
+ x = np.array([[], [], []], dtype=np.float64).T
+ y = pa.deserialize(pa.serialize(x).to_buffer())
+ np.testing.assert_array_equal(x, y)
+
+
+def test_serialization_determinism():
+ for obj in COMPLEX_OBJECTS:
+ buf1 = pa.serialize(obj).to_buffer()
+ buf2 = pa.serialize(obj).to_buffer()
+ assert buf1.to_pybytes() == buf2.to_pybytes()
+
+
+def test_serialize_recursive_objects():
+ class ClassA:
+ pass
+
+ # Make a list that contains itself.
+ lst = []
+ lst.append(lst)
+
+ # Make an object that contains itself as a field.
+ a1 = ClassA()
+ a1.field = a1
+
+ # Make two objects that contain each other as fields.
+ a2 = ClassA()
+ a3 = ClassA()
+ a2.field = a3
+ a3.field = a2
+
+ # Make a dictionary that contains itself.
+ d1 = {}
+ d1["key"] = d1
+
+ # Make a numpy array that contains itself.
+ arr = np.array([None], dtype=object)
+ arr[0] = arr
+
+ # Create a list of recursive objects.
+ recursive_objects = [lst, a1, a2, a3, d1, arr]
+
+ # Check that exceptions are thrown when we serialize the recursive
+ # objects.
+ for obj in recursive_objects:
+ with pytest.raises(Exception):
+ pa.serialize(obj).deserialize()
diff --git a/src/arrow/python/pyarrow/tests/test_serialization_deprecated.py b/src/arrow/python/pyarrow/tests/test_serialization_deprecated.py
new file mode 100644
index 000000000..cd4b3ed78
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_serialization_deprecated.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+
+import pytest
+
+import pyarrow as pa
+
+
+def test_serialization_deprecated():
+ with pytest.warns(FutureWarning):
+ ser = pa.serialize(1)
+
+ with pytest.warns(FutureWarning):
+ pa.deserialize(ser.to_buffer())
+
+ f = pa.BufferOutputStream()
+ with pytest.warns(FutureWarning):
+ pa.serialize_to(12, f)
+
+ buf = f.getvalue()
+ f = pa.BufferReader(buf)
+ with pytest.warns(FutureWarning):
+ pa.read_serialized(f).deserialize()
+
+ with pytest.warns(FutureWarning):
+ pa.default_serialization_context()
+
+ context = pa.lib.SerializationContext()
+ with pytest.warns(FutureWarning):
+ pa.register_default_serialization_handlers(context)
+
+
+@pytest.mark.skipif(sys.version_info < (3, 7),
+ reason="getattr needs Python 3.7")
+def test_serialization_deprecated_toplevel():
+ with pytest.warns(FutureWarning):
+ pa.SerializedPyObject()
+
+ with pytest.warns(FutureWarning):
+ pa.SerializationContext()
diff --git a/src/arrow/python/pyarrow/tests/test_sparse_tensor.py b/src/arrow/python/pyarrow/tests/test_sparse_tensor.py
new file mode 100644
index 000000000..aa7da0a74
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_sparse_tensor.py
@@ -0,0 +1,491 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import sys
+import weakref
+
+import numpy as np
+import pyarrow as pa
+
+try:
+ from scipy.sparse import csr_matrix, coo_matrix
+except ImportError:
+ coo_matrix = None
+ csr_matrix = None
+
+try:
+ import sparse
+except ImportError:
+ sparse = None
+
+
+tensor_type_pairs = [
+ ('i1', pa.int8()),
+ ('i2', pa.int16()),
+ ('i4', pa.int32()),
+ ('i8', pa.int64()),
+ ('u1', pa.uint8()),
+ ('u2', pa.uint16()),
+ ('u4', pa.uint32()),
+ ('u8', pa.uint64()),
+ ('f2', pa.float16()),
+ ('f4', pa.float32()),
+ ('f8', pa.float64())
+]
+
+
+@pytest.mark.parametrize('sparse_tensor_type', [
+ pa.SparseCSRMatrix,
+ pa.SparseCSCMatrix,
+ pa.SparseCOOTensor,
+ pa.SparseCSFTensor,
+])
+def test_sparse_tensor_attrs(sparse_tensor_type):
+ data = np.array([
+ [8, 0, 2, 0, 0, 0],
+ [0, 0, 0, 0, 0, 5],
+ [3, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 4, 6],
+ ])
+ dim_names = ('x', 'y')
+ sparse_tensor = sparse_tensor_type.from_dense_numpy(data, dim_names)
+
+ assert sparse_tensor.ndim == 2
+ assert sparse_tensor.size == 24
+ assert sparse_tensor.shape == data.shape
+ assert sparse_tensor.is_mutable
+ assert sparse_tensor.dim_name(0) == dim_names[0]
+ assert sparse_tensor.dim_names == dim_names
+ assert sparse_tensor.non_zero_length == 6
+
+ wr = weakref.ref(sparse_tensor)
+ assert wr() is not None
+ del sparse_tensor
+ assert wr() is None
+
+
+def test_sparse_coo_tensor_base_object():
+ expected_data = np.array([[8, 2, 5, 3, 4, 6]]).T
+ expected_coords = np.array([
+ [0, 0, 1, 2, 3, 3],
+ [0, 2, 5, 0, 4, 5],
+ ]).T
+ array = np.array([
+ [8, 0, 2, 0, 0, 0],
+ [0, 0, 0, 0, 0, 5],
+ [3, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 4, 6],
+ ])
+ sparse_tensor = pa.SparseCOOTensor.from_dense_numpy(array)
+ n = sys.getrefcount(sparse_tensor)
+ result_data, result_coords = sparse_tensor.to_numpy()
+ assert sparse_tensor.has_canonical_format
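+ # to_numpy() returns two arrays (data, coords) that each keep the sparse
+ # tensor alive as their base object, hence the refcount grows by 2.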
+ assert sys.getrefcount(sparse_tensor) == n + 2
+
+ sparse_tensor = None
+ assert np.array_equal(expected_data, result_data)
+ assert np.array_equal(expected_coords, result_coords)
+ assert result_coords.flags.c_contiguous # row-major
+
+
+def test_sparse_csr_matrix_base_object():
+ data = np.array([[8, 2, 5, 3, 4, 6]]).T
+ indptr = np.array([0, 2, 3, 4, 6])
+ indices = np.array([0, 2, 5, 0, 4, 5])
+ array = np.array([
+ [8, 0, 2, 0, 0, 0],
+ [0, 0, 0, 0, 0, 5],
+ [3, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 4, 6],
+ ])
+ sparse_tensor = pa.SparseCSRMatrix.from_dense_numpy(array)
+ n = sys.getrefcount(sparse_tensor)
+ result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
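+ # to_numpy() returns three arrays (data, indptr, indices), each holding a
+ # base reference to the sparse tensor, hence the refcount grows by 3.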
+ assert sys.getrefcount(sparse_tensor) == n + 3
+
+ sparse_tensor = None
+ assert np.array_equal(data, result_data)
+ assert np.array_equal(indptr, result_indptr)
+ assert np.array_equal(indices, result_indices)
+
+
+def test_sparse_csf_tensor_base_object():
+ data = np.array([[8, 2, 5, 3, 4, 6]]).T
+ indptr = [np.array([0, 2, 3, 4, 6])]
+ indices = [
+ np.array([0, 1, 2, 3]),
+ np.array([0, 2, 5, 0, 4, 5])
+ ]
+ array = np.array([
+ [8, 0, 2, 0, 0, 0],
+ [0, 0, 0, 0, 0, 5],
+ [3, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 4, 6],
+ ])
+ sparse_tensor = pa.SparseCSFTensor.from_dense_numpy(array)
+ n = sys.getrefcount(sparse_tensor)
+ result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
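+ # data plus one indptr array and two indices arrays each hold a base
+ # reference to the sparse tensor, hence the refcount grows by 4.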
+ assert sys.getrefcount(sparse_tensor) == n + 4
+
+ sparse_tensor = None
+ assert np.array_equal(data, result_data)
+ assert np.array_equal(indptr[0], result_indptr[0])
+ assert np.array_equal(indices[0], result_indices[0])
+ assert np.array_equal(indices[1], result_indices[1])
+
+
+@pytest.mark.parametrize('sparse_tensor_type', [
+ pa.SparseCSRMatrix,
+ pa.SparseCSCMatrix,
+ pa.SparseCOOTensor,
+ pa.SparseCSFTensor,
+])
+def test_sparse_tensor_equals(sparse_tensor_type):
+ def eq(a, b):
+ assert a.equals(b)
+ assert a == b
+ assert not (a != b)
+
+ def ne(a, b):
+ assert not a.equals(b)
+ assert not (a == b)
+ assert a != b
+
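+ # A non-contiguous slice and its contiguous copy should produce equal
+ # sparse tensors.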
+ data = np.random.randn(10, 6)[::, ::2]
+ sparse_tensor1 = sparse_tensor_type.from_dense_numpy(data)
+ sparse_tensor2 = sparse_tensor_type.from_dense_numpy(
+ np.ascontiguousarray(data))
+ eq(sparse_tensor1, sparse_tensor2)
+ data = data.copy()
+ data[9, 0] = 1.0
+ sparse_tensor2 = sparse_tensor_type.from_dense_numpy(
+ np.ascontiguousarray(data))
+ ne(sparse_tensor1, sparse_tensor2)
+
+
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_sparse_coo_tensor_from_dense(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ expected_data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
+ expected_coords = np.array([
+ [0, 0, 1, 2, 3, 3],
+ [0, 2, 5, 0, 4, 5],
+ ]).T
+ array = np.array([
+ [8, 0, 2, 0, 0, 0],
+ [0, 0, 0, 0, 0, 5],
+ [3, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 4, 6],
+ ]).astype(dtype)
+ tensor = pa.Tensor.from_numpy(array)
+
+ # Test from numpy array
+ sparse_tensor = pa.SparseCOOTensor.from_dense_numpy(array)
+ repr(sparse_tensor)
+ result_data, result_coords = sparse_tensor.to_numpy()
+ assert sparse_tensor.type == arrow_type
+ assert np.array_equal(expected_data, result_data)
+ assert np.array_equal(expected_coords, result_coords)
+
+ # Test from Tensor
+ sparse_tensor = pa.SparseCOOTensor.from_tensor(tensor)
+ repr(sparse_tensor)
+ result_data, result_coords = sparse_tensor.to_numpy()
+ assert sparse_tensor.type == arrow_type
+ assert np.array_equal(expected_data, result_data)
+ assert np.array_equal(expected_coords, result_coords)
+
+
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_sparse_csr_matrix_from_dense(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
+ indptr = np.array([0, 2, 3, 4, 6])
+ indices = np.array([0, 2, 5, 0, 4, 5])
+ array = np.array([
+ [8, 0, 2, 0, 0, 0],
+ [0, 0, 0, 0, 0, 5],
+ [3, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 4, 6],
+ ]).astype(dtype)
+ tensor = pa.Tensor.from_numpy(array)
+
+ # Test from numpy array
+ sparse_tensor = pa.SparseCSRMatrix.from_dense_numpy(array)
+ repr(sparse_tensor)
+ result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
+ assert sparse_tensor.type == arrow_type
+ assert np.array_equal(data, result_data)
+ assert np.array_equal(indptr, result_indptr)
+ assert np.array_equal(indices, result_indices)
+
+ # Test from Tensor
+ sparse_tensor = pa.SparseCSRMatrix.from_tensor(tensor)
+ repr(sparse_tensor)
+ result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
+ assert sparse_tensor.type == arrow_type
+ assert np.array_equal(data, result_data)
+ assert np.array_equal(indptr, result_indptr)
+ assert np.array_equal(indices, result_indices)
+
+
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_sparse_csf_tensor_from_dense_numpy(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
+ indptr = [np.array([0, 2, 3, 4, 6])]
+ indices = [
+ np.array([0, 1, 2, 3]),
+ np.array([0, 2, 5, 0, 4, 5])
+ ]
+ array = np.array([
+ [8, 0, 2, 0, 0, 0],
+ [0, 0, 0, 0, 0, 5],
+ [3, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 4, 6],
+ ]).astype(dtype)
+
+ # Test from numpy array
+ sparse_tensor = pa.SparseCSFTensor.from_dense_numpy(array)
+ repr(sparse_tensor)
+ result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
+ assert sparse_tensor.type == arrow_type
+ assert np.array_equal(data, result_data)
+ assert np.array_equal(indptr[0], result_indptr[0])
+ assert np.array_equal(indices[0], result_indices[0])
+ assert np.array_equal(indices[1], result_indices[1])
+
+
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_sparse_csf_tensor_from_dense_tensor(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
+ indptr = [np.array([0, 2, 3, 4, 6])]
+ indices = [
+ np.array([0, 1, 2, 3]),
+ np.array([0, 2, 5, 0, 4, 5])
+ ]
+ array = np.array([
+ [8, 0, 2, 0, 0, 0],
+ [0, 0, 0, 0, 0, 5],
+ [3, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 4, 6],
+ ]).astype(dtype)
+ tensor = pa.Tensor.from_numpy(array)
+
+ # Test from Tensor
+ sparse_tensor = pa.SparseCSFTensor.from_tensor(tensor)
+ repr(sparse_tensor)
+ result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
+ assert sparse_tensor.type == arrow_type
+ assert np.array_equal(data, result_data)
+ assert np.array_equal(indptr[0], result_indptr[0])
+ assert np.array_equal(indices[0], result_indices[0])
+ assert np.array_equal(indices[1], result_indices[1])
+
+
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_sparse_coo_tensor_numpy_roundtrip(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = np.array([[1, 2, 3, 4, 5, 6]]).T.astype(dtype)
+ coords = np.array([
+ [0, 0, 2, 3, 1, 3],
+ [0, 2, 0, 4, 5, 5],
+ ]).T
+ shape = (4, 6)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = pa.SparseCOOTensor.from_numpy(data, coords, shape,
+ dim_names)
+ repr(sparse_tensor)
+ result_data, result_coords = sparse_tensor.to_numpy()
+ assert sparse_tensor.type == arrow_type
+ assert np.array_equal(data, result_data)
+ assert np.array_equal(coords, result_coords)
+ assert sparse_tensor.dim_names == dim_names
+
+
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_sparse_csr_matrix_numpy_roundtrip(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
+ indptr = np.array([0, 2, 3, 4, 6])
+ indices = np.array([0, 2, 5, 0, 4, 5])
+ shape = (4, 6)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = pa.SparseCSRMatrix.from_numpy(data, indptr, indices,
+ shape, dim_names)
+ repr(sparse_tensor)
+ result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
+ assert sparse_tensor.type == arrow_type
+ assert np.array_equal(data, result_data)
+ assert np.array_equal(indptr, result_indptr)
+ assert np.array_equal(indices, result_indices)
+ assert sparse_tensor.dim_names == dim_names
+
+
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_sparse_csf_tensor_numpy_roundtrip(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype)
+ indptr = [np.array([0, 2, 3, 4, 6])]
+ indices = [
+ np.array([0, 1, 2, 3]),
+ np.array([0, 2, 5, 0, 4, 5])
+ ]
+ axis_order = (0, 1)
+ shape = (4, 6)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = pa.SparseCSFTensor.from_numpy(data, indptr, indices,
+ shape, axis_order,
+ dim_names)
+ repr(sparse_tensor)
+ result_data, result_indptr, result_indices = sparse_tensor.to_numpy()
+ assert sparse_tensor.type == arrow_type
+ assert np.array_equal(data, result_data)
+ assert np.array_equal(indptr[0], result_indptr[0])
+ assert np.array_equal(indices[0], result_indices[0])
+ assert np.array_equal(indices[1], result_indices[1])
+ assert sparse_tensor.dim_names == dim_names
+
+
+@pytest.mark.parametrize('sparse_tensor_type', [
+ pa.SparseCSRMatrix,
+ pa.SparseCSCMatrix,
+ pa.SparseCOOTensor,
+ pa.SparseCSFTensor,
+])
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_dense_to_sparse_tensor(dtype_str, arrow_type, sparse_tensor_type):
+ dtype = np.dtype(dtype_str)
+ array = np.array([[4, 0, 9, 0],
+ [0, 7, 0, 0],
+ [0, 0, 0, 0],
+ [0, 0, 0, 5]]).astype(dtype)
+ dim_names = ('x', 'y')
+
+ sparse_tensor = sparse_tensor_type.from_dense_numpy(array, dim_names)
+ tensor = sparse_tensor.to_tensor()
+ result_array = tensor.to_numpy()
+
+ assert sparse_tensor.type == arrow_type
+ assert tensor.type == arrow_type
+ assert sparse_tensor.dim_names == dim_names
+ assert np.array_equal(array, result_array)
+
+
+@pytest.mark.skipif(not coo_matrix, reason="requires scipy")
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_sparse_coo_tensor_scipy_roundtrip(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = np.array([1, 2, 3, 4, 5, 6]).astype(dtype)
+ row = np.array([0, 0, 2, 3, 1, 3])
+ col = np.array([0, 2, 0, 4, 5, 5])
+ shape = (4, 6)
+ dim_names = ('x', 'y')
+
+ # non-canonical sparse coo matrix
+ scipy_matrix = coo_matrix((data, (row, col)), shape=shape)
+ sparse_tensor = pa.SparseCOOTensor.from_scipy(scipy_matrix,
+ dim_names=dim_names)
+ out_scipy_matrix = sparse_tensor.to_scipy()
+
+ assert not scipy_matrix.has_canonical_format
+ assert not sparse_tensor.has_canonical_format
+ assert not out_scipy_matrix.has_canonical_format
+ assert sparse_tensor.type == arrow_type
+ assert sparse_tensor.dim_names == dim_names
+ assert scipy_matrix.dtype == out_scipy_matrix.dtype
+ assert np.array_equal(scipy_matrix.data, out_scipy_matrix.data)
+ assert np.array_equal(scipy_matrix.row, out_scipy_matrix.row)
+ assert np.array_equal(scipy_matrix.col, out_scipy_matrix.col)
+
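+ # scipy sparse has limited float16 support, so densify via float32 here.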
+ if dtype_str == 'f2':
+ dense_array = \
+ scipy_matrix.astype(np.float32).toarray().astype(np.float16)
+ else:
+ dense_array = scipy_matrix.toarray()
+ assert np.array_equal(dense_array, sparse_tensor.to_tensor().to_numpy())
+
+ # canonical sparse coo matrix
+ scipy_matrix.sum_duplicates()
+ sparse_tensor = pa.SparseCOOTensor.from_scipy(scipy_matrix,
+ dim_names=dim_names)
+ out_scipy_matrix = sparse_tensor.to_scipy()
+
+ assert scipy_matrix.has_canonical_format
+ assert sparse_tensor.has_canonical_format
+ assert out_scipy_matrix.has_canonical_format
+
+
+@pytest.mark.skipif(not csr_matrix, reason="requires scipy")
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_sparse_csr_matrix_scipy_roundtrip(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = np.array([8, 2, 5, 3, 4, 6]).astype(dtype)
+ indptr = np.array([0, 2, 3, 4, 6])
+ indices = np.array([0, 2, 5, 0, 4, 5])
+ shape = (4, 6)
+ dim_names = ('x', 'y')
+
+ sparse_array = csr_matrix((data, indices, indptr), shape=shape)
+ sparse_tensor = pa.SparseCSRMatrix.from_scipy(sparse_array,
+ dim_names=dim_names)
+ out_sparse_array = sparse_tensor.to_scipy()
+
+ assert sparse_tensor.type == arrow_type
+ assert sparse_tensor.dim_names == dim_names
+ assert sparse_array.dtype == out_sparse_array.dtype
+ assert np.array_equal(sparse_array.data, out_sparse_array.data)
+ assert np.array_equal(sparse_array.indptr, out_sparse_array.indptr)
+ assert np.array_equal(sparse_array.indices, out_sparse_array.indices)
+
+ if dtype_str == 'f2':
+ dense_array = \
+ sparse_array.astype(np.float32).toarray().astype(np.float16)
+ else:
+ dense_array = sparse_array.toarray()
+ assert np.array_equal(dense_array, sparse_tensor.to_tensor().to_numpy())
+
+
+@pytest.mark.skipif(not sparse, reason="requires pydata/sparse")
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_pydata_sparse_sparse_coo_tensor_roundtrip(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = np.array([1, 2, 3, 4, 5, 6]).astype(dtype)
+ coords = np.array([
+ [0, 0, 2, 3, 1, 3],
+ [0, 2, 0, 4, 5, 5],
+ ])
+ shape = (4, 6)
+ dim_names = ("x", "y")
+
+ sparse_array = sparse.COO(data=data, coords=coords, shape=shape)
+ sparse_tensor = pa.SparseCOOTensor.from_pydata_sparse(sparse_array,
+ dim_names=dim_names)
+ out_sparse_array = sparse_tensor.to_pydata_sparse()
+
+ assert sparse_tensor.type == arrow_type
+ assert sparse_tensor.dim_names == dim_names
+ assert sparse_array.dtype == out_sparse_array.dtype
+ assert np.array_equal(sparse_array.data, out_sparse_array.data)
+ assert np.array_equal(sparse_array.coords, out_sparse_array.coords)
+ assert np.array_equal(sparse_array.todense(),
+ sparse_tensor.to_tensor().to_numpy())
diff --git a/src/arrow/python/pyarrow/tests/test_strategies.py b/src/arrow/python/pyarrow/tests/test_strategies.py
new file mode 100644
index 000000000..14fc94992
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_strategies.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import hypothesis as h
+
+import pyarrow as pa
+import pyarrow.tests.strategies as past
+
+
+@h.given(past.all_types)
+def test_types(ty):
+ assert isinstance(ty, pa.lib.DataType)
+
+
+@h.given(past.all_fields)
+def test_fields(field):
+ assert isinstance(field, pa.lib.Field)
+
+
+@h.given(past.all_schemas)
+def test_schemas(schema):
+ assert isinstance(schema, pa.lib.Schema)
+
+
+@h.given(past.all_arrays)
+def test_arrays(array):
+ assert isinstance(array, pa.lib.Array)
+
+
+@h.given(past.arrays(past.primitive_types, nullable=False))
+def test_array_nullability(array):
+ assert array.null_count == 0
+
+
+@h.given(past.all_chunked_arrays)
+def test_chunked_arrays(chunked_array):
+ assert isinstance(chunked_array, pa.lib.ChunkedArray)
+
+
+@h.given(past.all_record_batches)
+def test_record_batches(record_batch):
+ assert isinstance(record_batch, pa.lib.RecordBatch)
+
+
+@h.given(past.all_tables)
+def test_tables(table):
+ assert isinstance(table, pa.lib.Table)
diff --git a/src/arrow/python/pyarrow/tests/test_table.py b/src/arrow/python/pyarrow/tests/test_table.py
new file mode 100644
index 000000000..ef41a733d
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_table.py
@@ -0,0 +1,1748 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import OrderedDict
+from collections.abc import Iterable
+import pickle
+import sys
+import weakref
+
+import numpy as np
+import pytest
+import pyarrow as pa
+
+
+def test_chunked_array_basics():
+ data = pa.chunked_array([], type=pa.string())
+ assert data.type == pa.string()
+ assert data.to_pylist() == []
+ data.validate()
+
+ data2 = pa.chunked_array([], type='binary')
+ assert data2.type == pa.binary()
+
+ with pytest.raises(ValueError):
+ pa.chunked_array([])
+
+ data = pa.chunked_array([
+ [1, 2, 3],
+ [4, 5, 6],
+ [7, 8, 9]
+ ])
+ assert isinstance(data.chunks, list)
+ assert all(isinstance(c, pa.lib.Int64Array) for c in data.chunks)
+ assert all(isinstance(c, pa.lib.Int64Array) for c in data.iterchunks())
+ assert len(data.chunks) == 3
+ assert data.nbytes == sum(c.nbytes for c in data.iterchunks())
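+ # __sizeof__ accounts for the Arrow buffer memory as well as the Python
+ # object overhead.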
+ assert sys.getsizeof(data) >= object.__sizeof__(data) + data.nbytes
+ data.validate()
+
+ wr = weakref.ref(data)
+ assert wr() is not None
+ del data
+ assert wr() is None
+
+
+def test_chunked_array_construction():
+ arr = pa.chunked_array([
+ [1, 2, 3],
+ [4, 5, 6],
+ [7, 8, 9],
+ ])
+ assert arr.type == pa.int64()
+ assert len(arr) == 9
+ assert len(arr.chunks) == 3
+
+ arr = pa.chunked_array([
+ [1, 2, 3],
+ [4., 5., 6.],
+ [7, 8, 9],
+ ])
+ assert arr.type == pa.int64()
+ assert len(arr) == 9
+ assert len(arr.chunks) == 3
+
+ arr = pa.chunked_array([
+ [1, 2, 3],
+ [4., 5., 6.],
+ [7, 8, 9],
+ ], type=pa.int8())
+ assert arr.type == pa.int8()
+ assert len(arr) == 9
+ assert len(arr.chunks) == 3
+
+ arr = pa.chunked_array([
+ [1, 2, 3],
+ []
+ ])
+ assert arr.type == pa.int64()
+ assert len(arr) == 3
+ assert len(arr.chunks) == 2
+
+ msg = (
+ "When passing an empty collection of arrays you must also pass the "
+ "data type"
+ )
+ with pytest.raises(ValueError, match=msg):
+ assert pa.chunked_array([])
+
+ assert pa.chunked_array([], type=pa.string()).type == pa.string()
+ assert pa.chunked_array([[]]).type == pa.null()
+ assert pa.chunked_array([[]], type=pa.string()).type == pa.string()
+
+
+def test_combine_chunks():
+ # ARROW-77363
+ arr = pa.array([1, 2])
+ chunked_arr = pa.chunked_array([arr, arr])
+ res = chunked_arr.combine_chunks()
+ expected = pa.array([1, 2, 1, 2])
+ assert res.equals(expected)
+
+
+def test_chunked_array_to_numpy():
+ data = pa.chunked_array([
+ [1, 2, 3],
+ [4, 5, 6],
+ []
+ ])
+ arr1 = np.asarray(data)
+ arr2 = data.to_numpy()
+
+ assert isinstance(arr2, np.ndarray)
+ assert arr2.shape == (6,)
+ assert np.array_equal(arr1, arr2)
+
+
+def test_chunked_array_mismatch_types():
+ with pytest.raises(TypeError):
+ # Given array types are different
+ pa.chunked_array([
+ pa.array([1, 2, 3]),
+ pa.array([1., 2., 3.])
+ ])
+
+ with pytest.raises(TypeError):
+ # Given array type is different from explicit type argument
+ pa.chunked_array([pa.array([1, 2, 3])], type=pa.float64())
+
+
+def test_chunked_array_str():
+ data = [
+ pa.array([1, 2, 3]),
+ pa.array([4, 5, 6])
+ ]
+ data = pa.chunked_array(data)
+ assert str(data) == """[
+ [
+ 1,
+ 2,
+ 3
+ ],
+ [
+ 4,
+ 5,
+ 6
+ ]
+]"""
+
+
+def test_chunked_array_getitem():
+ data = [
+ pa.array([1, 2, 3]),
+ pa.array([4, 5, 6])
+ ]
+ data = pa.chunked_array(data)
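+ # Indexing resolves the logical index across chunk boundaries; negative
+ # indices count from the end.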
+ assert data[1].as_py() == 2
+ assert data[-1].as_py() == 6
+ assert data[-6].as_py() == 1
+ with pytest.raises(IndexError):
+ data[6]
+ with pytest.raises(IndexError):
+ data[-7]
+ # Ensure this works with numpy scalars
+ assert data[np.int32(1)].as_py() == 2
+
+ data_slice = data[2:4]
+ assert data_slice.to_pylist() == [3, 4]
+
+ data_slice = data[4:-1]
+ assert data_slice.to_pylist() == [5]
+
+ data_slice = data[99:99]
+ assert data_slice.type == data.type
+ assert data_slice.to_pylist() == []
+
+
+def test_chunked_array_slice():
+ data = [
+ pa.array([1, 2, 3]),
+ pa.array([4, 5, 6])
+ ]
+ data = pa.chunked_array(data)
+
+ data_slice = data.slice(len(data))
+ assert data_slice.type == data.type
+ assert data_slice.to_pylist() == []
+
+ data_slice = data.slice(len(data) + 10)
+ assert data_slice.type == data.type
+ assert data_slice.to_pylist() == []
+
+ table = pa.Table.from_arrays([data], names=["a"])
+ table_slice = table.slice(len(table))
+ assert len(table_slice) == 0
+
+ table = pa.Table.from_arrays([data], names=["a"])
+ table_slice = table.slice(len(table) + 10)
+ assert len(table_slice) == 0
+
+
+def test_chunked_array_iter():
+ data = [
+ pa.array([0]),
+ pa.array([1, 2, 3]),
+ pa.array([4, 5, 6]),
+ pa.array([7, 8, 9])
+ ]
+ arr = pa.chunked_array(data)
+
+ for i, j in zip(range(10), arr):
+ assert i == j.as_py()
+
+ assert isinstance(arr, Iterable)
+
+
+def test_chunked_array_equals():
+ def eq(xarrs, yarrs):
+ if isinstance(xarrs, pa.ChunkedArray):
+ x = xarrs
+ else:
+ x = pa.chunked_array(xarrs)
+ if isinstance(yarrs, pa.ChunkedArray):
+ y = yarrs
+ else:
+ y = pa.chunked_array(yarrs)
+ assert x.equals(y)
+ assert y.equals(x)
+ assert x == y
+ assert x != str(y)
+
+ def ne(xarrs, yarrs):
+ if isinstance(xarrs, pa.ChunkedArray):
+ x = xarrs
+ else:
+ x = pa.chunked_array(xarrs)
+ if isinstance(yarrs, pa.ChunkedArray):
+ y = yarrs
+ else:
+ y = pa.chunked_array(yarrs)
+ assert not x.equals(y)
+ assert not y.equals(x)
+ assert x != y
+
+ eq(pa.chunked_array([], type=pa.int32()),
+ pa.chunked_array([], type=pa.int32()))
+ ne(pa.chunked_array([], type=pa.int32()),
+ pa.chunked_array([], type=pa.int64()))
+
+ a = pa.array([0, 2], type=pa.int32())
+ b = pa.array([0, 2], type=pa.int64())
+ c = pa.array([0, 3], type=pa.int32())
+ d = pa.array([0, 2, 0, 3], type=pa.int32())
+
+ eq([a], [a])
+ ne([a], [b])
+ eq([a, c], [a, c])
+ eq([a, c], [d])
+ ne([c, a], [a, c])
+
+ # ARROW-4822
+ assert not pa.chunked_array([], type=pa.int32()).equals(None)
+
+
+@pytest.mark.parametrize(
+ ('data', 'typ'),
+ [
+ ([True, False, True, True], pa.bool_()),
+ ([1, 2, 4, 6], pa.int64()),
+ ([1.0, 2.5, None], pa.float64()),
+ (['a', None, 'b'], pa.string()),
+ ([], pa.list_(pa.uint8())),
+ ([[1, 2], [3]], pa.list_(pa.int64())),
+ ([['a'], None, ['b', 'c']], pa.list_(pa.string())),
+ ([(1, 'a'), (2, 'c'), None],
+ pa.struct([pa.field('a', pa.int64()), pa.field('b', pa.string())]))
+ ]
+)
+def test_chunked_array_pickle(data, typ):
+ arrays = []
+ while data:
+ arrays.append(pa.array(data[:2], type=typ))
+ data = data[2:]
+ array = pa.chunked_array(arrays, type=typ)
+ array.validate()
+ result = pickle.loads(pickle.dumps(array))
+ result.validate()
+ assert result.equals(array)
+
+
+@pytest.mark.pandas
+def test_chunked_array_to_pandas():
+ import pandas as pd
+
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.table(data, names=['a'])
+ col = table.column(0)
+ assert isinstance(col, pa.ChunkedArray)
+ series = col.to_pandas()
+ assert isinstance(series, pd.Series)
+ assert series.shape == (5,)
+ assert series[0] == -10
+ assert series.name == 'a'
+
+
+@pytest.mark.pandas
+def test_chunked_array_to_pandas_preserve_name():
+ # https://issues.apache.org/jira/browse/ARROW-7709
+ import pandas as pd
+ import pandas.testing as tm
+
+ for data in [
+ pa.array([1, 2, 3]),
+ pa.array(pd.Categorical(["a", "b", "a"])),
+ pa.array(pd.date_range("2012", periods=3)),
+ pa.array(pd.date_range("2012", periods=3, tz="Europe/Brussels")),
+ pa.array([1, 2, 3], pa.timestamp("ms")),
+ pa.array([1, 2, 3], pa.timestamp("ms", "Europe/Brussels"))]:
+ table = pa.table({"name": data})
+ result = table.column("name").to_pandas()
+ assert result.name == "name"
+ expected = pd.Series(data.to_pandas(), name="name")
+ tm.assert_series_equal(result, expected)
+
+
+@pytest.mark.pandas
+@pytest.mark.nopandas
+def test_chunked_array_asarray():
+ # ensure this is tested both when pandas is present and when it is not (ARROW-6564)
+
+ data = [
+ pa.array([0]),
+ pa.array([1, 2, 3])
+ ]
+ chunked_arr = pa.chunked_array(data)
+
+ np_arr = np.asarray(chunked_arr)
+ assert np_arr.tolist() == [0, 1, 2, 3]
+ assert np_arr.dtype == np.dtype('int64')
+
+ # An optional type can be specified when calling np.asarray
+ np_arr = np.asarray(chunked_arr, dtype='str')
+ assert np_arr.tolist() == ['0', '1', '2', '3']
+
+ # The dtype is promoted to float64 when there are nulls
+ data = [
+ pa.array([1, None]),
+ pa.array([1, 2, 3])
+ ]
+ chunked_arr = pa.chunked_array(data)
+
+ np_arr = np.asarray(chunked_arr)
+ elements = np_arr.tolist()
+ assert elements[0] == 1.
+ assert np.isnan(elements[1])
+ assert elements[2:] == [1., 2., 3.]
+ assert np_arr.dtype == np.dtype('float64')
+
+ # DictionaryType data will be converted to dense numpy array
+ arr = pa.DictionaryArray.from_arrays(
+ pa.array([0, 1, 2, 0, 1]), pa.array(['a', 'b', 'c']))
+ chunked_arr = pa.chunked_array([arr, arr])
+ np_arr = np.asarray(chunked_arr)
+ assert np_arr.dtype == np.dtype('object')
+ assert np_arr.tolist() == ['a', 'b', 'c', 'a', 'b'] * 2
+
+
+def test_chunked_array_flatten():
+ ty = pa.struct([pa.field('x', pa.int16()),
+ pa.field('y', pa.float32())])
+ a = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty)
+ carr = pa.chunked_array(a)
+ x, y = carr.flatten()
+ assert x.equals(pa.chunked_array(pa.array([1, 3, 5], type=pa.int16())))
+ assert y.equals(pa.chunked_array(pa.array([2.5, 4.5, 6.5],
+ type=pa.float32())))
+
+ # Empty column
+ a = pa.array([], type=ty)
+ carr = pa.chunked_array(a)
+ x, y = carr.flatten()
+ assert x.equals(pa.chunked_array(pa.array([], type=pa.int16())))
+ assert y.equals(pa.chunked_array(pa.array([], type=pa.float32())))
+
+
+def test_chunked_array_unify_dictionaries():
+ arr = pa.chunked_array([
+ pa.array(["foo", "bar", None, "foo"]).dictionary_encode(),
+ pa.array(["quux", None, "foo"]).dictionary_encode(),
+ ])
+ assert arr.chunk(0).dictionary.equals(pa.array(["foo", "bar"]))
+ assert arr.chunk(1).dictionary.equals(pa.array(["quux", "foo"]))
+ arr = arr.unify_dictionaries()
+ expected_dict = pa.array(["foo", "bar", "quux"])
+ assert arr.chunk(0).dictionary.equals(expected_dict)
+ assert arr.chunk(1).dictionary.equals(expected_dict)
+ assert arr.to_pylist() == ["foo", "bar", None, "foo", "quux", None, "foo"]
+
+
+def test_recordbatch_basics():
+ data = [
+ pa.array(range(5), type='int16'),
+ pa.array([-10, -5, 0, None, 10], type='int32')
+ ]
+
+ batch = pa.record_batch(data, ['c0', 'c1'])
+ assert not batch.schema.metadata
+
+ assert len(batch) == 5
+ assert batch.num_rows == 5
+ assert batch.num_columns == len(data)
+ # (only the second array has a null bitmap)
+ assert batch.nbytes == (5 * 2) + (5 * 4 + 1)
+ assert sys.getsizeof(batch) >= object.__sizeof__(batch) + batch.nbytes
+ pydict = batch.to_pydict()
+ assert pydict == OrderedDict([
+ ('c0', [0, 1, 2, 3, 4]),
+ ('c1', [-10, -5, 0, None, 10])
+ ])
+ if sys.version_info >= (3, 7):
+ assert type(pydict) == dict
+ else:
+ assert type(pydict) == OrderedDict
+
+ with pytest.raises(IndexError):
+ # bounds checking
+ batch[2]
+
+ # Schema passed explicitly
+ schema = pa.schema([pa.field('c0', pa.int16(),
+ metadata={'key': 'value'}),
+ pa.field('c1', pa.int32())],
+ metadata={b'foo': b'bar'})
+ batch = pa.record_batch(data, schema=schema)
+ assert batch.schema == schema
+ # schema as first positional argument
+ batch = pa.record_batch(data, schema)
+ assert batch.schema == schema
+ assert str(batch) == """pyarrow.RecordBatch
+c0: int16
+c1: int32"""
+
+ assert batch.to_string(show_metadata=True) == """\
+pyarrow.RecordBatch
+c0: int16
+ -- field metadata --
+ key: 'value'
+c1: int32
+-- schema metadata --
+foo: 'bar'"""
+
+ wr = weakref.ref(batch)
+ assert wr() is not None
+ del batch
+ assert wr() is None
+
+
+def test_recordbatch_equals():
+ data1 = [
+ pa.array(range(5), type='int16'),
+ pa.array([-10, -5, 0, None, 10], type='int32')
+ ]
+ data2 = [
+ pa.array(['a', 'b', 'c']),
+ pa.array([['d'], ['e'], ['f']]),
+ ]
+ column_names = ['c0', 'c1']
+
+ batch = pa.record_batch(data1, column_names)
+ assert batch == pa.record_batch(data1, column_names)
+ assert batch.equals(pa.record_batch(data1, column_names))
+
+ assert batch != pa.record_batch(data2, column_names)
+ assert not batch.equals(pa.record_batch(data2, column_names))
+
+ batch_meta = pa.record_batch(data1, names=column_names,
+ metadata={'key': 'value'})
+ assert batch_meta.equals(batch)
+ assert not batch_meta.equals(batch, check_metadata=True)
+
+ # ARROW-8889
+ assert not batch.equals(None)
+ assert batch != "foo"
+
+
+def test_recordbatch_take():
+ batch = pa.record_batch(
+ [pa.array([1, 2, 3, None, 5]),
+ pa.array(['a', 'b', 'c', 'd', 'e'])],
+ ['f1', 'f2'])
+ assert batch.take(pa.array([2, 3])).equals(batch.slice(2, 2))
+ assert batch.take(pa.array([2, None])).equals(
+ pa.record_batch([pa.array([3, None]), pa.array(['c', None])],
+ ['f1', 'f2']))
+
+
+def test_recordbatch_column_sets_private_name():
+ # ARROW-6429
+ rb = pa.record_batch([pa.array([1, 2, 3, 4])], names=['a0'])
+ assert rb[0]._name == 'a0'
+
+
+def test_recordbatch_from_arrays_validate_schema():
+ # ARROW-6263
+ arr = pa.array([1, 2])
+ schema = pa.schema([pa.field('f0', pa.list_(pa.utf8()))])
+ with pytest.raises(NotImplementedError):
+ pa.record_batch([arr], schema=schema)
+
+
+def test_recordbatch_from_arrays_validate_lengths():
+ # ARROW-2820
+ data = [pa.array([1]), pa.array(["tokyo", "like", "happy"]),
+ pa.array(["derek"])]
+
+ with pytest.raises(ValueError):
+ pa.record_batch(data, ['id', 'tags', 'name'])
+
+
+def test_recordbatch_no_fields():
+ batch = pa.record_batch([], [])
+
+ assert len(batch) == 0
+ assert batch.num_rows == 0
+ assert batch.num_columns == 0
+
+
+def test_recordbatch_from_arrays_invalid_names():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ with pytest.raises(ValueError):
+ pa.record_batch(data, names=['a', 'b', 'c'])
+
+ with pytest.raises(ValueError):
+ pa.record_batch(data, names=['a'])
+
+
+def test_recordbatch_empty_metadata():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+
+ batch = pa.record_batch(data, ['c0', 'c1'])
+ assert batch.schema.metadata is None
+
+
+def test_recordbatch_pickle():
+ data = [
+ pa.array(range(5), type='int8'),
+ pa.array([-10, -5, 0, 5, 10], type='float32')
+ ]
+ fields = [
+ pa.field('ints', pa.int8()),
+ pa.field('floats', pa.float32()),
+ ]
+ schema = pa.schema(fields, metadata={b'foo': b'bar'})
+ batch = pa.record_batch(data, schema=schema)
+
+ result = pickle.loads(pickle.dumps(batch))
+ assert result.equals(batch)
+ assert result.schema == schema
+
+
+def test_recordbatch_get_field():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array(range(5, 10))
+ ]
+ batch = pa.RecordBatch.from_arrays(data, names=('a', 'b', 'c'))
+
+ assert batch.field('a').equals(batch.schema.field('a'))
+ assert batch.field(0).equals(batch.schema.field('a'))
+
+ with pytest.raises(KeyError):
+ batch.field('d')
+
+ with pytest.raises(TypeError):
+ batch.field(None)
+
+ with pytest.raises(IndexError):
+ batch.field(4)
+
+
+def test_recordbatch_select_column():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array(range(5, 10))
+ ]
+ batch = pa.RecordBatch.from_arrays(data, names=('a', 'b', 'c'))
+
+ assert batch.column('a').equals(batch.column(0))
+
+ with pytest.raises(
+ KeyError, match='Field "d" does not exist in record batch schema'):
+ batch.column('d')
+
+ with pytest.raises(TypeError):
+ batch.column(None)
+
+ with pytest.raises(IndexError):
+ batch.column(4)
+
+
+def test_recordbatch_from_struct_array_invalid():
+ with pytest.raises(TypeError):
+ pa.RecordBatch.from_struct_array(pa.array(range(5)))
+
+
+def test_recordbatch_from_struct_array():
+ struct_array = pa.array(
+ [{"ints": 1}, {"floats": 1.0}],
+ type=pa.struct([("ints", pa.int32()), ("floats", pa.float32())]),
+ )
+ result = pa.RecordBatch.from_struct_array(struct_array)
+ assert result.equals(pa.RecordBatch.from_arrays(
+ [
+ pa.array([1, None], type=pa.int32()),
+ pa.array([None, 1.0], type=pa.float32()),
+ ], ["ints", "floats"]
+ ))
+
+
+def _table_like_slice_tests(factory):
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ names = ['c0', 'c1']
+
+ obj = factory(data, names=names)
+
+ sliced = obj.slice(2)
+ assert sliced.num_rows == 3
+
+ expected = factory([x.slice(2) for x in data], names=names)
+ assert sliced.equals(expected)
+
+ sliced2 = obj.slice(2, 2)
+ expected2 = factory([x.slice(2, 2) for x in data], names=names)
+ assert sliced2.equals(expected2)
+
+ # 0 offset
+ assert obj.slice(0).equals(obj)
+
+ # Slice past end of array
+ assert len(obj.slice(len(obj))) == 0
+
+ with pytest.raises(IndexError):
+ obj.slice(-1)
+
+ # Check __getitem__-based slicing
+ assert obj.slice(0, 0).equals(obj[:0])
+ assert obj.slice(0, 2).equals(obj[:2])
+ assert obj.slice(2, 2).equals(obj[2:4])
+ assert obj.slice(2, len(obj) - 2).equals(obj[2:])
+ assert obj.slice(len(obj) - 2, 2).equals(obj[-2:])
+ assert obj.slice(len(obj) - 4, 2).equals(obj[-4:-2])
+
+
+def test_recordbatch_slice_getitem():
+ return _table_like_slice_tests(pa.RecordBatch.from_arrays)
+
+
+def test_table_slice_getitem():
+ return _table_like_slice_tests(pa.table)
+
+
+@pytest.mark.pandas
+def test_slice_zero_length_table():
+ # ARROW-7907: a segfault on this code was fixed after 0.16.0
+ table = pa.table({'a': pa.array([], type=pa.timestamp('us'))})
+ table_slice = table.slice(0, 0)
+ table_slice.to_pandas()
+
+ table = pa.table({'a': pa.chunked_array([], type=pa.string())})
+ table.to_pandas()
+
+
+def test_recordbatchlist_schema_equals():
+ a1 = np.array([1], dtype='uint32')
+ a2 = np.array([4.0, 5.0], dtype='float64')
+ batch1 = pa.record_batch([pa.array(a1)], ['c1'])
+ batch2 = pa.record_batch([pa.array(a2)], ['c1'])
+
+ with pytest.raises(pa.ArrowInvalid):
+ pa.Table.from_batches([batch1, batch2])
+
+
+def test_table_column_sets_private_name():
+ # ARROW-6429
+ t = pa.table([pa.array([1, 2, 3, 4])], names=['a0'])
+ assert t[0]._name == 'a0'
+
+
+def test_table_equals():
+ table = pa.Table.from_arrays([], names=[])
+ assert table.equals(table)
+
+ # ARROW-4822
+ assert not table.equals(None)
+
+ other = pa.Table.from_arrays([], names=[], metadata={'key': 'value'})
+ assert not table.equals(other, check_metadata=True)
+ assert table.equals(other)
+
+
+def test_table_from_batches_and_schema():
+ schema = pa.schema([
+ pa.field('a', pa.int64()),
+ pa.field('b', pa.float64()),
+ ])
+ batch = pa.record_batch([pa.array([1]), pa.array([3.14])],
+ names=['a', 'b'])
+ table = pa.Table.from_batches([batch], schema)
+ assert table.schema.equals(schema)
+ assert table.column(0) == pa.chunked_array([[1]])
+ assert table.column(1) == pa.chunked_array([[3.14]])
+
+ incompatible_schema = pa.schema([pa.field('a', pa.int64())])
+ with pytest.raises(pa.ArrowInvalid):
+ pa.Table.from_batches([batch], incompatible_schema)
+
+ incompatible_batch = pa.record_batch([pa.array([1])], ['a'])
+ with pytest.raises(pa.ArrowInvalid):
+ pa.Table.from_batches([incompatible_batch], schema)
+
+
+@pytest.mark.pandas
+def test_table_to_batches():
+ from pandas.testing import assert_frame_equal
+ import pandas as pd
+
+ df1 = pd.DataFrame({'a': list(range(10))})
+ df2 = pd.DataFrame({'a': list(range(10, 30))})
+
+ batch1 = pa.RecordBatch.from_pandas(df1, preserve_index=False)
+ batch2 = pa.RecordBatch.from_pandas(df2, preserve_index=False)
+
+ table = pa.Table.from_batches([batch1, batch2, batch1])
+
+ expected_df = pd.concat([df1, df2, df1], ignore_index=True)
+
+ batches = table.to_batches()
+ assert len(batches) == 3
+
+ assert_frame_equal(pa.Table.from_batches(batches).to_pandas(),
+ expected_df)
+
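+ # The 20-row chunk is split into 15 + 5 so that no batch exceeds
+ # max_chunksize; the 10-row chunks are kept as-is.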
+ batches = table.to_batches(max_chunksize=15)
+ assert list(map(len, batches)) == [10, 15, 5, 10]
+
+ assert_frame_equal(table.to_pandas(), expected_df)
+ assert_frame_equal(pa.Table.from_batches(batches).to_pandas(),
+ expected_df)
+
+ table_from_iter = pa.Table.from_batches(iter([batch1, batch2, batch1]))
+ assert table.equals(table_from_iter)
+
+
+def test_table_basics():
+ data = [
+ pa.array(range(5), type='int64'),
+ pa.array([-10, -5, 0, 5, 10], type='int64')
+ ]
+ table = pa.table(data, names=('a', 'b'))
+ table.validate()
+ assert len(table) == 5
+ assert table.num_rows == 5
+ assert table.num_columns == 2
+ assert table.shape == (5, 2)
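+ # two int64 columns of five 8-byte values each, no null bitmaps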
+ assert table.nbytes == 2 * (5 * 8)
+ assert sys.getsizeof(table) >= object.__sizeof__(table) + table.nbytes
+ pydict = table.to_pydict()
+ assert pydict == OrderedDict([
+ ('a', [0, 1, 2, 3, 4]),
+ ('b', [-10, -5, 0, 5, 10])
+ ])
+ if sys.version_info >= (3, 7):
+ assert type(pydict) == dict
+ else:
+ assert type(pydict) == OrderedDict
+
+ columns = []
+ for col in table.itercolumns():
+ columns.append(col)
+ for chunk in col.iterchunks():
+ assert chunk is not None
+
+ with pytest.raises(IndexError):
+ col.chunk(-1)
+
+ with pytest.raises(IndexError):
+ col.chunk(col.num_chunks)
+
+ assert table.columns == columns
+ assert table == pa.table(columns, names=table.column_names)
+ assert table != pa.table(columns[1:], names=table.column_names[1:])
+ assert table != columns
+
+ wr = weakref.ref(table)
+ assert wr() is not None
+ del table
+ assert wr() is None
+
+
+def test_table_from_arrays_preserves_column_metadata():
+ # Added to test https://issues.apache.org/jira/browse/ARROW-3866
+ arr0 = pa.array([1, 2])
+ arr1 = pa.array([3, 4])
+ field0 = pa.field('field1', pa.int64(), metadata=dict(a="A", b="B"))
+ field1 = pa.field('field2', pa.int64(), nullable=False)
+ table = pa.Table.from_arrays([arr0, arr1],
+ schema=pa.schema([field0, field1]))
+ assert b"a" in table.field(0).metadata
+ assert table.field(1).nullable is False
+
+
+def test_table_from_arrays_invalid_names():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ with pytest.raises(ValueError):
+ pa.Table.from_arrays(data, names=['a', 'b', 'c'])
+
+ with pytest.raises(ValueError):
+ pa.Table.from_arrays(data, names=['a'])
+
+
+def test_table_from_lists():
+ data = [
+ list(range(5)),
+ [-10, -5, 0, 5, 10]
+ ]
+
+ result = pa.table(data, names=['a', 'b'])
+ expected = pa.Table.from_arrays(data, names=['a', 'b'])
+ assert result.equals(expected)
+
+ schema = pa.schema([
+ pa.field('a', pa.uint16()),
+ pa.field('b', pa.int64())
+ ])
+ result = pa.table(data, schema=schema)
+ expected = pa.Table.from_arrays(data, schema=schema)
+ assert result.equals(expected)
+
+
+def test_table_pickle():
+ data = [
+ pa.chunked_array([[1, 2], [3, 4]], type=pa.uint32()),
+ pa.chunked_array([["some", "strings", None, ""]], type=pa.string()),
+ ]
+ schema = pa.schema([pa.field('ints', pa.uint32()),
+ pa.field('strs', pa.string())],
+ metadata={b'foo': b'bar'})
+ table = pa.Table.from_arrays(data, schema=schema)
+
+ result = pickle.loads(pickle.dumps(table))
+ result.validate()
+ assert result.equals(table)
+
+
+def test_table_get_field():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array(range(5, 10))
+ ]
+ table = pa.Table.from_arrays(data, names=('a', 'b', 'c'))
+
+ assert table.field('a').equals(table.schema.field('a'))
+ assert table.field(0).equals(table.schema.field('a'))
+
+ with pytest.raises(KeyError):
+ table.field('d')
+
+ with pytest.raises(TypeError):
+ table.field(None)
+
+ with pytest.raises(IndexError):
+ table.field(4)
+
+
+def test_table_select_column():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array(range(5, 10))
+ ]
+ table = pa.Table.from_arrays(data, names=('a', 'b', 'c'))
+
+ assert table.column('a').equals(table.column(0))
+
+ with pytest.raises(KeyError,
+ match='Field "d" does not exist in table schema'):
+ table.column('d')
+
+ with pytest.raises(TypeError):
+ table.column(None)
+
+ with pytest.raises(IndexError):
+ table.column(4)
+
+
+def test_table_column_with_duplicates():
+ # ARROW-8209
+ table = pa.table([pa.array([1, 2, 3]),
+ pa.array([4, 5, 6]),
+ pa.array([7, 8, 9])], names=['a', 'b', 'a'])
+
+ with pytest.raises(KeyError,
+ match='Field "a" exists 2 times in table schema'):
+ table.column('a')
+
+
+def test_table_add_column():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array(range(5, 10))
+ ]
+ table = pa.Table.from_arrays(data, names=('a', 'b', 'c'))
+
+ new_field = pa.field('d', data[1].type)
+ t2 = table.add_column(3, new_field, data[1])
+ t3 = table.append_column(new_field, data[1])
+
+ expected = pa.Table.from_arrays(data + [data[1]],
+ names=('a', 'b', 'c', 'd'))
+ assert t2.equals(expected)
+ assert t3.equals(expected)
+
+ t4 = table.add_column(0, new_field, data[1])
+ expected = pa.Table.from_arrays([data[1]] + data,
+ names=('d', 'a', 'b', 'c'))
+ assert t4.equals(expected)
+
+
+def test_table_set_column():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array(range(5, 10))
+ ]
+ table = pa.Table.from_arrays(data, names=('a', 'b', 'c'))
+
+ new_field = pa.field('d', data[1].type)
+ t2 = table.set_column(0, new_field, data[1])
+
+ expected_data = list(data)
+ expected_data[0] = data[1]
+ expected = pa.Table.from_arrays(expected_data,
+ names=('d', 'b', 'c'))
+ assert t2.equals(expected)
+
+
+def test_table_drop():
+ """ drop one or more columns given labels"""
+ a = pa.array(range(5))
+ b = pa.array([-10, -5, 0, 5, 10])
+ c = pa.array(range(5, 10))
+
+ table = pa.Table.from_arrays([a, b, c], names=('a', 'b', 'c'))
+ t2 = table.drop(['a', 'b'])
+
+ exp = pa.Table.from_arrays([c], names=('c',))
+ assert exp.equals(t2)
+
+ # -- raise KeyError if column not in Table
+ with pytest.raises(KeyError, match="Column 'd' not found"):
+ table.drop(['d'])
+
+
+def test_table_remove_column():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array(range(5, 10))
+ ]
+ table = pa.Table.from_arrays(data, names=('a', 'b', 'c'))
+
+ t2 = table.remove_column(0)
+ t2.validate()
+ expected = pa.Table.from_arrays(data[1:], names=('b', 'c'))
+ assert t2.equals(expected)
+
+
+def test_table_remove_column_empty():
+ # ARROW-1865
+ data = [
+ pa.array(range(5)),
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+
+ t2 = table.remove_column(0)
+ t2.validate()
+ assert len(t2) == len(table)
+
+ t3 = t2.add_column(0, table.field(0), table[0])
+ t3.validate()
+ assert t3.equals(table)
+
+
+def test_empty_table_with_names():
+ # ARROW-13784
+ data = []
+ names = ["a", "b"]
+ message = (
+ 'Length of names [(]2[)] does not match length of arrays [(]0[)]')
+ with pytest.raises(ValueError, match=message):
+ pa.Table.from_arrays(data, names=names)
+
+
+def test_empty_table():
+ table = pa.table([])
+
+ assert table.column_names == []
+ assert table.equals(pa.Table.from_arrays([], []))
+
+
+def test_table_rename_columns():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array(range(5, 10))
+ ]
+ table = pa.Table.from_arrays(data, names=['a', 'b', 'c'])
+ assert table.column_names == ['a', 'b', 'c']
+
+ t2 = table.rename_columns(['eh', 'bee', 'sea'])
+ t2.validate()
+ assert t2.column_names == ['eh', 'bee', 'sea']
+
+ expected = pa.Table.from_arrays(data, names=['eh', 'bee', 'sea'])
+ assert t2.equals(expected)
+
+
+def test_table_flatten():
+ ty1 = pa.struct([pa.field('x', pa.int16()),
+ pa.field('y', pa.float32())])
+ ty2 = pa.struct([pa.field('nest', ty1)])
+ a = pa.array([(1, 2.5), (3, 4.5)], type=ty1)
+ b = pa.array([((11, 12.5),), ((13, 14.5),)], type=ty2)
+ c = pa.array([False, True], type=pa.bool_())
+
+ table = pa.Table.from_arrays([a, b, c], names=['a', 'b', 'c'])
+ t2 = table.flatten()
+ t2.validate()
+ expected = pa.Table.from_arrays([
+ pa.array([1, 3], type=pa.int16()),
+ pa.array([2.5, 4.5], type=pa.float32()),
+ pa.array([(11, 12.5), (13, 14.5)], type=ty1),
+ c],
+ names=['a.x', 'a.y', 'b.nest', 'c'])
+ assert t2.equals(expected)
+
+
+def test_table_combine_chunks():
+ batch1 = pa.record_batch([pa.array([1]), pa.array(["a"])],
+ names=['f1', 'f2'])
+ batch2 = pa.record_batch([pa.array([2]), pa.array(["b"])],
+ names=['f1', 'f2'])
+ table = pa.Table.from_batches([batch1, batch2])
+ combined = table.combine_chunks()
+ combined.validate()
+ assert combined.equals(table)
+ for c in combined.columns:
+ assert c.num_chunks == 1
+
+
+def test_table_unify_dictionaries():
+ batch1 = pa.record_batch([
+ pa.array(["foo", "bar", None, "foo"]).dictionary_encode(),
+ pa.array([123, 456, 456, 789]).dictionary_encode(),
+ pa.array([True, False, None, None])], names=['a', 'b', 'c'])
+ batch2 = pa.record_batch([
+ pa.array(["quux", "foo", None, "quux"]).dictionary_encode(),
+ pa.array([456, 789, 789, None]).dictionary_encode(),
+ pa.array([False, None, None, True])], names=['a', 'b', 'c'])
+
+ table = pa.Table.from_batches([batch1, batch2])
+ table = table.replace_schema_metadata({b"key1": b"value1"})
+ assert table.column(0).chunk(0).dictionary.equals(
+ pa.array(["foo", "bar"]))
+ assert table.column(0).chunk(1).dictionary.equals(
+ pa.array(["quux", "foo"]))
+ assert table.column(1).chunk(0).dictionary.equals(
+ pa.array([123, 456, 789]))
+ assert table.column(1).chunk(1).dictionary.equals(
+ pa.array([456, 789]))
+
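+ # Unifying rewrites each chunk's indices against a shared dictionary;
+ # the logical values are unchanged (checked via to_pydict below).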
+ table = table.unify_dictionaries(pa.default_memory_pool())
+ expected_dict_0 = pa.array(["foo", "bar", "quux"])
+ expected_dict_1 = pa.array([123, 456, 789])
+ assert table.column(0).chunk(0).dictionary.equals(expected_dict_0)
+ assert table.column(0).chunk(1).dictionary.equals(expected_dict_0)
+ assert table.column(1).chunk(0).dictionary.equals(expected_dict_1)
+ assert table.column(1).chunk(1).dictionary.equals(expected_dict_1)
+
+ assert table.to_pydict() == {
+ 'a': ["foo", "bar", None, "foo", "quux", "foo", None, "quux"],
+ 'b': [123, 456, 456, 789, 456, 789, 789, None],
+ 'c': [True, False, None, None, False, None, None, True],
+ }
+ assert table.schema.metadata == {b"key1": b"value1"}
+
+
+def test_concat_tables():
+ data = [
+ list(range(5)),
+ [-10., -5., 0., 5., 10.]
+ ]
+ data2 = [
+ list(range(5, 10)),
+ [1., 2., 3., 4., 5.]
+ ]
+
+ t1 = pa.Table.from_arrays([pa.array(x) for x in data],
+ names=('a', 'b'))
+ t2 = pa.Table.from_arrays([pa.array(x) for x in data2],
+ names=('a', 'b'))
+
+ result = pa.concat_tables([t1, t2])
+ result.validate()
+ assert len(result) == 10
+
+ expected = pa.Table.from_arrays([pa.array(x + y)
+ for x, y in zip(data, data2)],
+ names=('a', 'b'))
+
+ assert result.equals(expected)
+
+
+def test_concat_tables_none_table():
+ # ARROW-11997
+ with pytest.raises(AttributeError):
+ pa.concat_tables([None])
+
+
+@pytest.mark.pandas
+def test_concat_tables_with_different_schema_metadata():
+ import pandas as pd
+
+ schema = pa.schema([
+ pa.field('a', pa.string()),
+ pa.field('b', pa.string()),
+ ])
+
+ values = list('abcdefgh')
+ df1 = pd.DataFrame({'a': values, 'b': values})
+ df2 = pd.DataFrame({'a': [np.nan] * 8, 'b': values})
+
+ table1 = pa.Table.from_pandas(df1, schema=schema, preserve_index=False)
+ table2 = pa.Table.from_pandas(df2, schema=schema, preserve_index=False)
+ assert table1.schema.equals(table2.schema)
+ assert not table1.schema.equals(table2.schema, check_metadata=True)
+
+ table3 = pa.concat_tables([table1, table2])
+ assert table1.schema.equals(table3.schema, check_metadata=True)
+ assert table2.schema.equals(table3.schema)
+
+
+def test_concat_tables_with_promotion():
+ t1 = pa.Table.from_arrays(
+ [pa.array([1, 2], type=pa.int64())], ["int64_field"])
+ t2 = pa.Table.from_arrays(
+ [pa.array([1.0, 2.0], type=pa.float32())], ["float_field"])
+
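+ # promote=True unifies the two schemas; columns missing from a table are
+ # filled with nulls in the concatenated result.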
+ result = pa.concat_tables([t1, t2], promote=True)
+
+ assert result.equals(pa.Table.from_arrays([
+ pa.array([1, 2, None, None], type=pa.int64()),
+ pa.array([None, None, 1.0, 2.0], type=pa.float32()),
+ ], ["int64_field", "float_field"]))
+
+
+def test_concat_tables_with_promotion_error():
+ t1 = pa.Table.from_arrays(
+ [pa.array([1, 2], type=pa.int64())], ["f"])
+ t2 = pa.Table.from_arrays(
+ [pa.array([1, 2], type=pa.float32())], ["f"])
+
+ with pytest.raises(pa.ArrowInvalid):
+ pa.concat_tables([t1, t2], promote=True)
+
+
+def test_table_negative_indexing():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ pa.array([1.0, 2.0, 3.0, 4.0, 5.0]),
+ pa.array(['ab', 'bc', 'cd', 'de', 'ef']),
+ ]
+ table = pa.Table.from_arrays(data, names=tuple('abcd'))
+
+ assert table[-1].equals(table[3])
+ assert table[-2].equals(table[2])
+ assert table[-3].equals(table[1])
+ assert table[-4].equals(table[0])
+
+ with pytest.raises(IndexError):
+ table[-5]
+
+ with pytest.raises(IndexError):
+ table[4]
+
+
+def test_table_cast_to_incompatible_schema():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ ]
+ table = pa.Table.from_arrays(data, names=tuple('ab'))
+
+ target_schema1 = pa.schema([
+ pa.field('A', pa.int32()),
+ pa.field('b', pa.int16()),
+ ])
+ target_schema2 = pa.schema([
+ pa.field('a', pa.int32()),
+ ])
+ message = ("Target schema's field names are not matching the table's "
+ "field names:.*")
+ with pytest.raises(ValueError, match=message):
+ table.cast(target_schema1)
+ with pytest.raises(ValueError, match=message):
+ table.cast(target_schema2)
+
+
+def test_table_safe_casting():
+ data = [
+ pa.array(range(5), type=pa.int64()),
+ pa.array([-10, -5, 0, 5, 10], type=pa.int32()),
+ pa.array([1.0, 2.0, 3.0, 4.0, 5.0], type=pa.float64()),
+ pa.array(['ab', 'bc', 'cd', 'de', 'ef'], type=pa.string())
+ ]
+ table = pa.Table.from_arrays(data, names=tuple('abcd'))
+
+ expected_data = [
+ pa.array(range(5), type=pa.int32()),
+ pa.array([-10, -5, 0, 5, 10], type=pa.int16()),
+ pa.array([1, 2, 3, 4, 5], type=pa.int64()),
+ pa.array(['ab', 'bc', 'cd', 'de', 'ef'], type=pa.string())
+ ]
+ expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd'))
+
+ target_schema = pa.schema([
+ pa.field('a', pa.int32()),
+ pa.field('b', pa.int16()),
+ pa.field('c', pa.int64()),
+ pa.field('d', pa.string())
+ ])
+ casted_table = table.cast(target_schema)
+
+ assert casted_table.equals(expected_table)
+
+
+def test_table_unsafe_casting():
+ data = [
+ pa.array(range(5), type=pa.int64()),
+ pa.array([-10, -5, 0, 5, 10], type=pa.int32()),
+ pa.array([1.1, 2.2, 3.3, 4.4, 5.5], type=pa.float64()),
+ pa.array(['ab', 'bc', 'cd', 'de', 'ef'], type=pa.string())
+ ]
+ table = pa.Table.from_arrays(data, names=tuple('abcd'))
+
+ expected_data = [
+ pa.array(range(5), type=pa.int32()),
+ pa.array([-10, -5, 0, 5, 10], type=pa.int16()),
+ pa.array([1, 2, 3, 4, 5], type=pa.int64()),
+ pa.array(['ab', 'bc', 'cd', 'de', 'ef'], type=pa.string())
+ ]
+ expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd'))
+
+ target_schema = pa.schema([
+ pa.field('a', pa.int32()),
+ pa.field('b', pa.int16()),
+ pa.field('c', pa.int64()),
+ pa.field('d', pa.string())
+ ])
+
+ with pytest.raises(pa.ArrowInvalid, match='truncated'):
+ table.cast(target_schema)
+
+ casted_table = table.cast(target_schema, safe=False)
+ assert casted_table.equals(expected_table)
+
+
+def test_invalid_table_construct():
+ array = np.array([0, 1], dtype=np.uint8)
+ u8 = pa.uint8()
+ arrays = [pa.array(array, type=u8), pa.array(array[1:], type=u8)]
+
+ with pytest.raises(pa.lib.ArrowInvalid):
+ pa.Table.from_arrays(arrays, names=["a1", "a2"])
+
+
+@pytest.mark.parametrize('data, klass', [
+ ((['', 'foo', 'bar'], [4.5, 5, None]), list),
+ ((['', 'foo', 'bar'], [4.5, 5, None]), pa.array),
+ (([[''], ['foo', 'bar']], [[4.5], [5., None]]), pa.chunked_array),
+])
+def test_from_arrays_schema(data, klass):
+ data = [klass(data[0]), klass(data[1])]
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float32())])
+
+ table = pa.Table.from_arrays(data, schema=schema)
+ assert table.num_columns == 2
+ assert table.num_rows == 3
+ assert table.schema == schema
+
+ # length of data and schema not matching
+ schema = pa.schema([('strs', pa.utf8())])
+ with pytest.raises(ValueError):
+ pa.Table.from_arrays(data, schema=schema)
+
+ # with different but compatible schema
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float32())])
+ table = pa.Table.from_arrays(data, schema=schema)
+ assert pa.types.is_float32(table.column('floats').type)
+ assert table.num_columns == 2
+ assert table.num_rows == 3
+ assert table.schema == schema
+
+ # with different and incompatible schema
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.timestamp('s'))])
+ with pytest.raises((NotImplementedError, TypeError)):
+ pa.Table.from_pydict(data, schema=schema)
+
+ # Cannot pass both schema and metadata / names
+ with pytest.raises(ValueError):
+ pa.Table.from_arrays(data, schema=schema, names=['strs', 'floats'])
+
+ with pytest.raises(ValueError):
+ pa.Table.from_arrays(data, schema=schema, metadata={b'foo': b'bar'})
+
+
+@pytest.mark.parametrize(
+ ('cls'),
+ [
+ (pa.Table),
+ (pa.RecordBatch)
+ ]
+)
+def test_table_from_pydict(cls):
+ table = cls.from_pydict({})
+ assert table.num_columns == 0
+ assert table.num_rows == 0
+ assert table.schema == pa.schema([])
+ assert table.to_pydict() == {}
+
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64())])
+
+ # With lists as values
+ data = OrderedDict([('strs', ['', 'foo', 'bar']),
+ ('floats', [4.5, 5, None])])
+ table = cls.from_pydict(data)
+ assert table.num_columns == 2
+ assert table.num_rows == 3
+ assert table.schema == schema
+ assert table.to_pydict() == data
+
+ # With metadata and inferred schema
+ metadata = {b'foo': b'bar'}
+ schema = schema.with_metadata(metadata)
+ table = cls.from_pydict(data, metadata=metadata)
+ assert table.schema == schema
+ assert table.schema.metadata == metadata
+ assert table.to_pydict() == data
+
+ # With explicit schema
+ table = cls.from_pydict(data, schema=schema)
+ assert table.schema == schema
+ assert table.schema.metadata == metadata
+ assert table.to_pydict() == data
+
+ # Cannot pass both schema and metadata
+ with pytest.raises(ValueError):
+ cls.from_pydict(data, schema=schema, metadata=metadata)
+
+ # Non-convertible values given schema
+ with pytest.raises(TypeError):
+ cls.from_pydict({'c0': [0, 1, 2]},
+ schema=pa.schema([("c0", pa.string())]))
+
+ # Missing schema fields from the passed mapping
+ with pytest.raises(KeyError, match="doesn't contain.* c, d"):
+ cls.from_pydict(
+ {'a': [1, 2, 3], 'b': [3, 4, 5]},
+ schema=pa.schema([
+ ('a', pa.int64()),
+ ('c', pa.int32()),
+ ('d', pa.int16())
+ ])
+ )
+
+ # Passed wrong schema type
+ with pytest.raises(TypeError):
+ cls.from_pydict({'a': [1, 2, 3]}, schema={})
+
+
+@pytest.mark.parametrize('data, klass', [
+ ((['', 'foo', 'bar'], [4.5, 5, None]), pa.array),
+ (([[''], ['foo', 'bar']], [[4.5], [5., None]]), pa.chunked_array),
+])
+def test_table_from_pydict_arrow_arrays(data, klass):
+ data = OrderedDict([('strs', klass(data[0])), ('floats', klass(data[1]))])
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64())])
+
+ # With arrays as values
+ table = pa.Table.from_pydict(data)
+ assert table.num_columns == 2
+ assert table.num_rows == 3
+ assert table.schema == schema
+
+ # With explicit (matching) schema
+ table = pa.Table.from_pydict(data, schema=schema)
+ assert table.num_columns == 2
+ assert table.num_rows == 3
+ assert table.schema == schema
+
+ # with different but compatible schema
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float32())])
+ table = pa.Table.from_pydict(data, schema=schema)
+ assert pa.types.is_float32(table.column('floats').type)
+ assert table.num_columns == 2
+ assert table.num_rows == 3
+ assert table.schema == schema
+
+ # with different and incompatible schema
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.timestamp('s'))])
+ with pytest.raises((NotImplementedError, TypeError)):
+ pa.Table.from_pydict(data, schema=schema)
+
+
+@pytest.mark.parametrize('data, klass', [
+ ((['', 'foo', 'bar'], [4.5, 5, None]), list),
+ ((['', 'foo', 'bar'], [4.5, 5, None]), pa.array),
+ (([[''], ['foo', 'bar']], [[4.5], [5., None]]), pa.chunked_array),
+])
+def test_table_from_pydict_schema(data, klass):
+ # passed schema is source of truth for the columns
+
+ data = OrderedDict([('strs', klass(data[0])), ('floats', klass(data[1]))])
+
+ # schema has columns not present in data -> error
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64()),
+ ('ints', pa.int64())])
+ with pytest.raises(KeyError, match='ints'):
+ pa.Table.from_pydict(data, schema=schema)
+
+ # data has columns not present in schema -> ignored
+ schema = pa.schema([('strs', pa.utf8())])
+ table = pa.Table.from_pydict(data, schema=schema)
+ assert table.num_columns == 1
+ assert table.schema == schema
+ assert table.column_names == ['strs']
+
+
+@pytest.mark.pandas
+def test_table_from_pandas_schema():
+ # passed schema is source of truth for the columns
+ import pandas as pd
+
+ df = pd.DataFrame(OrderedDict([('strs', ['', 'foo', 'bar']),
+ ('floats', [4.5, 5, None])]))
+
+ # with different but compatible schema
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float32())])
+ table = pa.Table.from_pandas(df, schema=schema)
+ assert pa.types.is_float32(table.column('floats').type)
+ assert table.schema.remove_metadata() == schema
+
+ # with different and incompatible schema
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.timestamp('s'))])
+ with pytest.raises((NotImplementedError, TypeError)):
+ pa.Table.from_pandas(df, schema=schema)
+
+ # schema has columns not present in data -> error
+ schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64()),
+ ('ints', pa.int64())])
+ with pytest.raises(KeyError, match='ints'):
+ pa.Table.from_pandas(df, schema=schema)
+
+ # data has columns not present in schema -> ignored
+ schema = pa.schema([('strs', pa.utf8())])
+ table = pa.Table.from_pandas(df, schema=schema)
+ assert table.num_columns == 1
+ assert table.schema.remove_metadata() == schema
+ assert table.column_names == ['strs']
+
+
+@pytest.mark.pandas
+def test_table_factory_function():
+ import pandas as pd
+
+ # Put in the wrong order to make sure the result lines up with the schema
+ d = OrderedDict([('b', ['a', 'b', 'c']), ('a', [1, 2, 3])])
+
+ d_explicit = {'b': pa.array(['a', 'b', 'c'], type='string'),
+ 'a': pa.array([1, 2, 3], type='int32')}
+
+ schema = pa.schema([('a', pa.int32()), ('b', pa.string())])
+
+ df = pd.DataFrame(d)
+ table1 = pa.table(df)
+ table2 = pa.Table.from_pandas(df)
+ assert table1.equals(table2)
+ table1 = pa.table(df, schema=schema)
+ table2 = pa.Table.from_pandas(df, schema=schema)
+ assert table1.equals(table2)
+
+ table1 = pa.table(d_explicit)
+ table2 = pa.Table.from_pydict(d_explicit)
+ assert table1.equals(table2)
+
+ # schema coerces type
+ table1 = pa.table(d, schema=schema)
+ table2 = pa.Table.from_pydict(d, schema=schema)
+ assert table1.equals(table2)
+
+
+def test_table_factory_function_args():
+ # from_pydict not accepting names:
+ with pytest.raises(ValueError):
+ pa.table({'a': [1, 2, 3]}, names=['a'])
+
+ # backwards compatibility for schema as first positional argument
+ schema = pa.schema([('a', pa.int32())])
+ table = pa.table({'a': pa.array([1, 2, 3], type=pa.int64())}, schema)
+ assert table.column('a').type == pa.int32()
+
+ # from_arrays: accept both names and schema as positional first argument
+ data = [pa.array([1, 2, 3], type='int64')]
+ names = ['a']
+ table = pa.table(data, names)
+ assert table.column_names == names
+ schema = pa.schema([('a', pa.int64())])
+ table = pa.table(data, schema)
+ assert table.column_names == names
+
+
+@pytest.mark.pandas
+def test_table_factory_function_args_pandas():
+ import pandas as pd
+
+ # from_pandas not accepting names or metadata:
+ with pytest.raises(ValueError):
+ pa.table(pd.DataFrame({'a': [1, 2, 3]}), names=['a'])
+
+ with pytest.raises(ValueError):
+ pa.table(pd.DataFrame({'a': [1, 2, 3]}), metadata={b'foo': b'bar'})
+
+ # backwards compatibility for schema as first positional argument
+ schema = pa.schema([('a', pa.int32())])
+ table = pa.table(pd.DataFrame({'a': [1, 2, 3]}), schema)
+ assert table.column('a').type == pa.int32()
+
+
+def test_factory_functions_invalid_input():
+ with pytest.raises(TypeError, match="Expected pandas DataFrame, python"):
+ pa.table("invalid input")
+
+ with pytest.raises(TypeError, match="Expected pandas DataFrame"):
+ pa.record_batch("invalid input")
+
+
+def test_table_repr_to_string():
+ # Schema passed explicitly
+ schema = pa.schema([pa.field('c0', pa.int16(),
+ metadata={'key': 'value'}),
+ pa.field('c1', pa.int32())],
+ metadata={b'foo': b'bar'})
+
+ tab = pa.table([pa.array([1, 2, 3, 4], type='int16'),
+ pa.array([10, 20, 30, 40], type='int32')], schema=schema)
+ assert str(tab) == """pyarrow.Table
+c0: int16
+c1: int32
+----
+c0: [[1,2,3,4]]
+c1: [[10,20,30,40]]"""
+
+ assert tab.to_string(show_metadata=True) == """\
+pyarrow.Table
+c0: int16
+ -- field metadata --
+ key: 'value'
+c1: int32
+-- schema metadata --
+foo: 'bar'"""
+
+ assert tab.to_string(preview_cols=5) == """\
+pyarrow.Table
+c0: int16
+c1: int32
+----
+c0: [[1,2,3,4]]
+c1: [[10,20,30,40]]"""
+
+ assert tab.to_string(preview_cols=1) == """\
+pyarrow.Table
+c0: int16
+c1: int32
+----
+c0: [[1,2,3,4]]
+..."""
+
+
+def test_table_repr_to_string_ellipsis():
+ # Schema passed explicitly
+ schema = pa.schema([pa.field('c0', pa.int16(),
+ metadata={'key': 'value'}),
+ pa.field('c1', pa.int32())],
+ metadata={b'foo': b'bar'})
+
+ tab = pa.table([pa.array([1, 2, 3, 4]*10, type='int16'),
+ pa.array([10, 20, 30, 40]*10, type='int32')],
+ schema=schema)
+ assert str(tab) == """pyarrow.Table
+c0: int16
+c1: int32
+----
+c0: [[1,2,3,4,1,2,3,4,1,2,...,3,4,1,2,3,4,1,2,3,4]]
+c1: [[10,20,30,40,10,20,30,40,10,20,...,30,40,10,20,30,40,10,20,30,40]]"""
+
+
+def test_table_function_unicode_schema():
+ col_a = "äääh"
+ col_b = "öööf"
+
+ # Put in the wrong order to make sure the result lines up with the schema
+ d = OrderedDict([(col_b, ['a', 'b', 'c']), (col_a, [1, 2, 3])])
+
+ schema = pa.schema([(col_a, pa.int32()), (col_b, pa.string())])
+
+ result = pa.table(d, schema=schema)
+ assert result[0].chunk(0).equals(pa.array([1, 2, 3], type='int32'))
+ assert result[1].chunk(0).equals(pa.array(['a', 'b', 'c'], type='string'))
+
+
+def test_table_take_vanilla_functionality():
+ table = pa.table(
+ [pa.array([1, 2, 3, None, 5]),
+ pa.array(['a', 'b', 'c', 'd', 'e'])],
+ ['f1', 'f2'])
+
+ assert table.take(pa.array([2, 3])).equals(table.slice(2, 2))
+
+
+def test_table_take_null_index():
+ table = pa.table(
+ [pa.array([1, 2, 3, None, 5]),
+ pa.array(['a', 'b', 'c', 'd', 'e'])],
+ ['f1', 'f2'])
+
+ result_with_null_index = pa.table(
+ [pa.array([1, None]),
+ pa.array(['a', None])],
+ ['f1', 'f2'])
+
+ assert table.take(pa.array([0, None])).equals(result_with_null_index)
+
+
+def test_table_take_non_consecutive():
+ table = pa.table(
+ [pa.array([1, 2, 3, None, 5]),
+ pa.array(['a', 'b', 'c', 'd', 'e'])],
+ ['f1', 'f2'])
+
+ result_non_consecutive = pa.table(
+ [pa.array([2, None]),
+ pa.array(['b', 'd'])],
+ ['f1', 'f2'])
+
+ assert table.take(pa.array([1, 3])).equals(result_non_consecutive)
+
+
+def test_table_select():
+ a1 = pa.array([1, 2, 3, None, 5])
+ a2 = pa.array(['a', 'b', 'c', 'd', 'e'])
+ a3 = pa.array([[1, 2], [3, 4], [5, 6], None, [9, 10]])
+ table = pa.table([a1, a2, a3], ['f1', 'f2', 'f3'])
+
+ # selecting with string names
+ result = table.select(['f1'])
+ expected = pa.table([a1], ['f1'])
+ assert result.equals(expected)
+
+ result = table.select(['f3', 'f2'])
+ expected = pa.table([a3, a2], ['f3', 'f2'])
+ assert result.equals(expected)
+
+ # selecting with integer indices
+ result = table.select([0])
+ expected = pa.table([a1], ['f1'])
+ assert result.equals(expected)
+
+ result = table.select([2, 1])
+ expected = pa.table([a3, a2], ['f3', 'f2'])
+ assert result.equals(expected)
+
+ # preserve metadata
+ table2 = table.replace_schema_metadata({"a": "test"})
+ result = table2.select(["f1", "f2"])
+ assert b"a" in result.schema.metadata
+
+ # selecting non-existing column raises
+ with pytest.raises(KeyError, match='Field "f5" does not exist'):
+ table.select(['f5'])
+
+ with pytest.raises(IndexError, match="index out of bounds"):
+ table.select([5])
+
+ # duplicate selection gives duplicated names in resulting table
+ result = table.select(['f2', 'f2'])
+ expected = pa.table([a2, a2], ['f2', 'f2'])
+ assert result.equals(expected)
+
+ # selecting a duplicated column name raises
+ table = pa.table([a1, a2, a3], ['f1', 'f2', 'f1'])
+ with pytest.raises(KeyError, match='Field "f1" exists 2 times'):
+ table.select(['f1'])
+
+ result = table.select(['f2'])
+ expected = pa.table([a2], ['f2'])
+ assert result.equals(expected)
diff --git a/src/arrow/python/pyarrow/tests/test_tensor.py b/src/arrow/python/pyarrow/tests/test_tensor.py
new file mode 100644
index 000000000..aee46bc93
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_tensor.py
@@ -0,0 +1,216 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import sys
+import pytest
+import weakref
+
+import numpy as np
+import pyarrow as pa
+
+
+tensor_type_pairs = [
+ ('i1', pa.int8()),
+ ('i2', pa.int16()),
+ ('i4', pa.int32()),
+ ('i8', pa.int64()),
+ ('u1', pa.uint8()),
+ ('u2', pa.uint16()),
+ ('u4', pa.uint32()),
+ ('u8', pa.uint64()),
+ ('f2', pa.float16()),
+ ('f4', pa.float32()),
+ ('f8', pa.float64())
+]
+
+
+def test_tensor_attrs():
+ data = np.random.randn(10, 4)
+
+ tensor = pa.Tensor.from_numpy(data)
+
+ assert tensor.ndim == 2
+ assert tensor.dim_names == []
+ assert tensor.size == 40
+ assert tensor.shape == data.shape
+ assert tensor.strides == data.strides
+
+ assert tensor.is_contiguous
+ assert tensor.is_mutable
+
+ # not writeable
+ data2 = data.copy()
+ data2.flags.writeable = False
+ tensor = pa.Tensor.from_numpy(data2)
+ assert not tensor.is_mutable
+
+ # With dim_names
+ tensor = pa.Tensor.from_numpy(data, dim_names=('x', 'y'))
+ assert tensor.ndim == 2
+ assert tensor.dim_names == ['x', 'y']
+ assert tensor.dim_name(0) == 'x'
+ assert tensor.dim_name(1) == 'y'
+
+ wr = weakref.ref(tensor)
+ assert wr() is not None
+ del tensor
+ assert wr() is None
+
+
+def test_tensor_base_object():
+ tensor = pa.Tensor.from_numpy(np.random.randn(10, 4))
+ n = sys.getrefcount(tensor)
+ array = tensor.to_numpy() # noqa
+ assert sys.getrefcount(tensor) == n + 1
+
+
+@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs)
+def test_tensor_numpy_roundtrip(dtype_str, arrow_type):
+ dtype = np.dtype(dtype_str)
+ data = (100 * np.random.randn(10, 4)).astype(dtype)
+
+ tensor = pa.Tensor.from_numpy(data)
+ assert tensor.type == arrow_type
+
+ repr(tensor)
+
+ result = tensor.to_numpy()
+ assert (data == result).all()
+
+
+def test_tensor_ipc_roundtrip(tmpdir):
+ data = np.random.randn(10, 4)
+ tensor = pa.Tensor.from_numpy(data)
+
+ path = os.path.join(str(tmpdir), 'pyarrow-tensor-ipc-roundtrip')
+ mmap = pa.create_memory_map(path, 1024)
+
+ pa.ipc.write_tensor(tensor, mmap)
+
+ mmap.seek(0)
+ result = pa.ipc.read_tensor(mmap)
+
+ assert result.equals(tensor)
+
+
+@pytest.mark.gzip
+def test_tensor_ipc_read_from_compressed(tempdir):
+ # ARROW-5910
+ data = np.random.randn(10, 4)
+ tensor = pa.Tensor.from_numpy(data)
+
+ path = tempdir / 'tensor-compressed-file'
+
+ out_stream = pa.output_stream(path, compression='gzip')
+ pa.ipc.write_tensor(tensor, out_stream)
+ out_stream.close()
+
+ result = pa.ipc.read_tensor(pa.input_stream(path, compression='gzip'))
+ assert result.equals(tensor)
+
+
+def test_tensor_ipc_strided(tmpdir):
+ data1 = np.random.randn(10, 4)
+ tensor1 = pa.Tensor.from_numpy(data1[::2])
+
+ data2 = np.random.randn(10, 6, 4)
+ tensor2 = pa.Tensor.from_numpy(data2[::, ::2, ::])
+
+ path = os.path.join(str(tmpdir), 'pyarrow-tensor-ipc-strided')
+ mmap = pa.create_memory_map(path, 2048)
+
+ for tensor in [tensor1, tensor2]:
+ mmap.seek(0)
+ pa.ipc.write_tensor(tensor, mmap)
+
+ mmap.seek(0)
+ result = pa.ipc.read_tensor(mmap)
+
+ assert result.equals(tensor)
+
+
+def test_tensor_equals():
+ def eq(a, b):
+ assert a.equals(b)
+ assert a == b
+ assert not (a != b)
+
+ def ne(a, b):
+ assert not a.equals(b)
+ assert not (a == b)
+ assert a != b
+
+ data = np.random.randn(10, 6, 4)[::, ::2, ::]
+ tensor1 = pa.Tensor.from_numpy(data)
+ tensor2 = pa.Tensor.from_numpy(np.ascontiguousarray(data))
+ eq(tensor1, tensor2)
+ data = data.copy()
+ data[9, 0, 0] = 1.0
+ tensor2 = pa.Tensor.from_numpy(np.ascontiguousarray(data))
+ ne(tensor1, tensor2)
+
+
+def test_tensor_hashing():
+ # Tensors are unhashable
+ with pytest.raises(TypeError, match="unhashable"):
+ hash(pa.Tensor.from_numpy(np.arange(10)))
+
+
+def test_tensor_size():
+ data = np.random.randn(10, 4)
+ tensor = pa.Tensor.from_numpy(data)
+ assert pa.ipc.get_tensor_size(tensor) > (data.size * 8)
+
+
+def test_read_tensor(tmpdir):
+ # Create and write a tensor
+ data = np.random.randn(10, 4)
+ tensor = pa.Tensor.from_numpy(data)
+ data_size = pa.ipc.get_tensor_size(tensor)
+ path = os.path.join(str(tmpdir), 'pyarrow-tensor-ipc-read-tensor')
+ write_mmap = pa.create_memory_map(path, data_size)
+ pa.ipc.write_tensor(tensor, write_mmap)
+ # Try to read tensor
+ read_mmap = pa.memory_map(path, mode='r')
+ array = pa.ipc.read_tensor(read_mmap).to_numpy()
+ np.testing.assert_equal(data, array)
+
+
+def test_tensor_memoryview():
+ # Tensors support the PEP 3118 buffer protocol
+ for dtype, expected_format in [(np.int8, '=b'),
+ (np.int64, '=q'),
+ (np.uint64, '=Q'),
+ (np.float16, 'e'),
+ (np.float64, 'd'),
+ ]:
+ data = np.arange(10, dtype=dtype)
+ dtype = data.dtype
+ lst = data.tolist()
+ tensor = pa.Tensor.from_numpy(data)
+ m = memoryview(tensor)
+ assert m.format == expected_format
+ assert m.shape == data.shape
+ assert m.strides == data.strides
+ assert m.ndim == 1
+ assert m.nbytes == data.nbytes
+ assert m.itemsize == data.itemsize
+ assert m.itemsize * 8 == tensor.type.bit_width
+ assert np.frombuffer(m, dtype).tolist() == lst
+ del tensor, data
+ assert np.frombuffer(m, dtype).tolist() == lst
diff --git a/src/arrow/python/pyarrow/tests/test_types.py b/src/arrow/python/pyarrow/tests/test_types.py
new file mode 100644
index 000000000..07715b985
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_types.py
@@ -0,0 +1,1067 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import OrderedDict
+from collections.abc import Iterator
+from functools import partial
+import datetime
+import sys
+
+import pickle
+import pytest
+import pytz
+import hypothesis as h
+import hypothesis.strategies as st
+import hypothesis.extra.pytz as tzst
+import weakref
+
+import numpy as np
+import pyarrow as pa
+import pyarrow.types as types
+import pyarrow.tests.strategies as past
+
+
+def get_many_types():
+ # returning them from a function is required because of pa.dictionary
+ # type holds a pyarrow array and test_array.py::test_toal_bytes_allocated
+ # checks that the default memory pool has zero allocated bytes
+ return (
+ pa.null(),
+ pa.bool_(),
+ pa.int32(),
+ pa.time32('s'),
+ pa.time64('us'),
+ pa.date32(),
+ pa.timestamp('us'),
+ pa.timestamp('us', tz='UTC'),
+ pa.timestamp('us', tz='Europe/Paris'),
+ pa.duration('s'),
+ pa.float16(),
+ pa.float32(),
+ pa.float64(),
+ pa.decimal128(19, 4),
+ pa.decimal256(76, 38),
+ pa.string(),
+ pa.binary(),
+ pa.binary(10),
+ pa.large_string(),
+ pa.large_binary(),
+ pa.list_(pa.int32()),
+ pa.list_(pa.int32(), 2),
+ pa.large_list(pa.uint16()),
+ pa.map_(pa.string(), pa.int32()),
+ pa.map_(pa.field('key', pa.int32(), nullable=False),
+ pa.field('value', pa.int32())),
+ pa.struct([pa.field('a', pa.int32()),
+ pa.field('b', pa.int8()),
+ pa.field('c', pa.string())]),
+ pa.struct([pa.field('a', pa.int32(), nullable=False),
+ pa.field('b', pa.int8(), nullable=False),
+ pa.field('c', pa.string())]),
+ pa.union([pa.field('a', pa.binary(10)),
+ pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
+ pa.union([pa.field('a', pa.binary(10)),
+ pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE,
+ type_codes=[4, 8]),
+ pa.union([pa.field('a', pa.binary(10)),
+ pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
+ pa.union([pa.field('a', pa.binary(10), nullable=False),
+ pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
+ pa.dictionary(pa.int32(), pa.string())
+ )
+
+
+def test_is_boolean():
+ assert types.is_boolean(pa.bool_())
+ assert not types.is_boolean(pa.int8())
+
+
+def test_is_integer():
+ signed_ints = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
+ unsigned_ints = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
+
+ for t in signed_ints + unsigned_ints:
+ assert types.is_integer(t)
+
+ for t in signed_ints:
+ assert types.is_signed_integer(t)
+ assert not types.is_unsigned_integer(t)
+
+ for t in unsigned_ints:
+ assert types.is_unsigned_integer(t)
+ assert not types.is_signed_integer(t)
+
+ assert not types.is_integer(pa.float32())
+ assert not types.is_signed_integer(pa.float32())
+
+
+def test_is_floating():
+ for t in [pa.float16(), pa.float32(), pa.float64()]:
+ assert types.is_floating(t)
+
+ assert not types.is_floating(pa.int32())
+
+
+def test_is_null():
+ assert types.is_null(pa.null())
+ assert not types.is_null(pa.list_(pa.int32()))
+
+
+def test_null_field_may_not_be_non_nullable():
+ # ARROW-7273
+ with pytest.raises(ValueError):
+ pa.field('f0', pa.null(), nullable=False)
+
+
+def test_is_decimal():
+ decimal128 = pa.decimal128(19, 4)
+ decimal256 = pa.decimal256(76, 38)
+ int32 = pa.int32()
+
+ assert types.is_decimal(decimal128)
+ assert types.is_decimal(decimal256)
+ assert not types.is_decimal(int32)
+
+ assert types.is_decimal128(decimal128)
+ assert not types.is_decimal128(decimal256)
+ assert not types.is_decimal128(int32)
+
+ assert not types.is_decimal256(decimal128)
+ assert types.is_decimal256(decimal256)
+ assert not types.is_decimal256(int32)
+
+
+def test_is_list():
+ a = pa.list_(pa.int32())
+ b = pa.large_list(pa.int32())
+ c = pa.list_(pa.int32(), 3)
+
+ assert types.is_list(a)
+ assert not types.is_large_list(a)
+ assert not types.is_fixed_size_list(a)
+ assert types.is_large_list(b)
+ assert not types.is_list(b)
+ assert not types.is_fixed_size_list(b)
+ assert types.is_fixed_size_list(c)
+ assert not types.is_list(c)
+ assert not types.is_large_list(c)
+
+ assert not types.is_list(pa.int32())
+
+
+def test_is_map():
+ m = pa.map_(pa.utf8(), pa.int32())
+
+ assert types.is_map(m)
+ assert not types.is_map(pa.int32())
+
+ fields = pa.map_(pa.field('key_name', pa.utf8(), nullable=False),
+ pa.field('value_name', pa.int32()))
+ assert types.is_map(fields)
+
+ entries_type = pa.struct([pa.field('key', pa.int8()),
+ pa.field('value', pa.int8())])
+ list_type = pa.list_(entries_type)
+ assert not types.is_map(list_type)
+
+
+def test_is_dictionary():
+ assert types.is_dictionary(pa.dictionary(pa.int32(), pa.string()))
+ assert not types.is_dictionary(pa.int32())
+
+
+def test_is_nested_or_struct():
+ struct_ex = pa.struct([pa.field('a', pa.int32()),
+ pa.field('b', pa.int8()),
+ pa.field('c', pa.string())])
+
+ assert types.is_struct(struct_ex)
+ assert not types.is_struct(pa.list_(pa.int32()))
+
+ assert types.is_nested(struct_ex)
+ assert types.is_nested(pa.list_(pa.int32()))
+ assert types.is_nested(pa.large_list(pa.int32()))
+ assert not types.is_nested(pa.int32())
+
+
+def test_is_union():
+ for mode in [pa.lib.UnionMode_SPARSE, pa.lib.UnionMode_DENSE]:
+ assert types.is_union(pa.union([pa.field('a', pa.int32()),
+ pa.field('b', pa.int8()),
+ pa.field('c', pa.string())],
+ mode=mode))
+ assert not types.is_union(pa.list_(pa.int32()))
+
+
+# TODO(wesm): is_map, once implemented
+
+
+def test_is_binary_string():
+ assert types.is_binary(pa.binary())
+ assert not types.is_binary(pa.string())
+ assert not types.is_binary(pa.large_binary())
+ assert not types.is_binary(pa.large_string())
+
+ assert types.is_string(pa.string())
+ assert types.is_unicode(pa.string())
+ assert not types.is_string(pa.binary())
+ assert not types.is_string(pa.large_string())
+ assert not types.is_string(pa.large_binary())
+
+ assert types.is_large_binary(pa.large_binary())
+ assert not types.is_large_binary(pa.large_string())
+ assert not types.is_large_binary(pa.binary())
+ assert not types.is_large_binary(pa.string())
+
+ assert types.is_large_string(pa.large_string())
+ assert not types.is_large_string(pa.large_binary())
+ assert not types.is_large_string(pa.string())
+ assert not types.is_large_string(pa.binary())
+
+ assert types.is_fixed_size_binary(pa.binary(5))
+ assert not types.is_fixed_size_binary(pa.binary())
+
+
+def test_is_temporal_date_time_timestamp():
+ date_types = [pa.date32(), pa.date64()]
+ time_types = [pa.time32('s'), pa.time64('ns')]
+ timestamp_types = [pa.timestamp('ms')]
+ duration_types = [pa.duration('ms')]
+ interval_types = [pa.month_day_nano_interval()]
+
+ for case in (date_types + time_types + timestamp_types + duration_types +
+ interval_types):
+ assert types.is_temporal(case)
+
+ for case in date_types:
+ assert types.is_date(case)
+ assert not types.is_time(case)
+ assert not types.is_timestamp(case)
+ assert not types.is_duration(case)
+ assert not types.is_interval(case)
+
+ for case in time_types:
+ assert types.is_time(case)
+ assert not types.is_date(case)
+ assert not types.is_timestamp(case)
+ assert not types.is_duration(case)
+ assert not types.is_interval(case)
+
+ for case in timestamp_types:
+ assert types.is_timestamp(case)
+ assert not types.is_date(case)
+ assert not types.is_time(case)
+ assert not types.is_duration(case)
+ assert not types.is_interval(case)
+
+ for case in duration_types:
+ assert types.is_duration(case)
+ assert not types.is_date(case)
+ assert not types.is_time(case)
+ assert not types.is_timestamp(case)
+ assert not types.is_interval(case)
+
+ for case in interval_types:
+ assert types.is_interval(case)
+ assert not types.is_date(case)
+ assert not types.is_time(case)
+ assert not types.is_timestamp(case)
+
+ assert not types.is_temporal(pa.int32())
+
+
+def test_is_primitive():
+ assert types.is_primitive(pa.int32())
+ assert not types.is_primitive(pa.list_(pa.int32()))
+
+
+@pytest.mark.parametrize(('tz', 'expected'), [
+ (pytz.utc, 'UTC'),
+ (pytz.timezone('Europe/Paris'), 'Europe/Paris'),
+ # StaticTzInfo.tzname() returns '-09', so we need to infer the timezone's
+ # name from the tzinfo.zone attribute
+ (pytz.timezone('Etc/GMT-9'), 'Etc/GMT-9'),
+ (pytz.FixedOffset(180), '+03:00'),
+ (datetime.timezone.utc, 'UTC'),
+ (datetime.timezone(datetime.timedelta(hours=1, minutes=30)), '+01:30')
+])
+def test_tzinfo_to_string(tz, expected):
+ assert pa.lib.tzinfo_to_string(tz) == expected
+
+
+def test_tzinfo_to_string_errors():
+ msg = "Not an instance of datetime.tzinfo"
+ with pytest.raises(TypeError):
+ pa.lib.tzinfo_to_string("Europe/Budapest")
+
+ if sys.version_info >= (3, 8):
+ # before 3.8 it was only possible to create timezone objects with whole
+ # number of minutes
+ tz = datetime.timezone(datetime.timedelta(hours=1, seconds=30))
+ msg = "Offset must represent whole number of minutes"
+ with pytest.raises(ValueError, match=msg):
+ pa.lib.tzinfo_to_string(tz)
+
+
+@h.given(tzst.timezones())
+def test_pytz_timezone_roundtrip(tz):
+ timezone_string = pa.lib.tzinfo_to_string(tz)
+ timezone_tzinfo = pa.lib.string_to_tzinfo(timezone_string)
+ assert timezone_tzinfo == tz
+
+
+def test_convert_custom_tzinfo_objects_to_string():
+ class CorrectTimezone1(datetime.tzinfo):
+ """
+ Conversion uses utcoffset()
+ """
+
+ def tzname(self, dt):
+ return None
+
+ def utcoffset(self, dt):
+ return datetime.timedelta(hours=-3, minutes=30)
+
+ class CorrectTimezone2(datetime.tzinfo):
+ """
+ Conversion uses tzname()
+ """
+
+ def tzname(self, dt):
+ return "+03:00"
+
+ def utcoffset(self, dt):
+ return datetime.timedelta(hours=3)
+
+ class BuggyTimezone1(datetime.tzinfo):
+ """
+ Unable to infer name or offset
+ """
+
+ def tzname(self, dt):
+ return None
+
+ def utcoffset(self, dt):
+ return None
+
+ class BuggyTimezone2(datetime.tzinfo):
+ """
+ Wrong offset type
+ """
+
+ def tzname(self, dt):
+ return None
+
+ def utcoffset(self, dt):
+ return "one hour"
+
+ class BuggyTimezone3(datetime.tzinfo):
+ """
+ Wrong timezone name type
+ """
+
+ def tzname(self, dt):
+ return 240
+
+ def utcoffset(self, dt):
+ return None
+
+ assert pa.lib.tzinfo_to_string(CorrectTimezone1()) == "-02:30"
+ assert pa.lib.tzinfo_to_string(CorrectTimezone2()) == "+03:00"
+
+ msg = (r"Object returned by tzinfo.utcoffset\(None\) is not an instance "
+ r"of datetime.timedelta")
+ for wrong in [BuggyTimezone1(), BuggyTimezone2(), BuggyTimezone3()]:
+ with pytest.raises(ValueError, match=msg):
+ pa.lib.tzinfo_to_string(wrong)
+
+
+@pytest.mark.parametrize(('string', 'expected'), [
+ ('UTC', pytz.utc),
+ ('Europe/Paris', pytz.timezone('Europe/Paris')),
+ ('+03:00', pytz.FixedOffset(180)),
+ ('+01:30', pytz.FixedOffset(90)),
+ ('-02:00', pytz.FixedOffset(-120))
+])
+def test_string_to_tzinfo(string, expected):
+ result = pa.lib.string_to_tzinfo(string)
+ assert result == expected
+
+
+@pytest.mark.parametrize('tz,name', [
+ (pytz.FixedOffset(90), '+01:30'),
+ (pytz.FixedOffset(-90), '-01:30'),
+ (pytz.utc, 'UTC'),
+ (pytz.timezone('America/New_York'), 'America/New_York')
+])
+def test_timezone_string_roundtrip(tz, name):
+ assert pa.lib.tzinfo_to_string(tz) == name
+ assert pa.lib.string_to_tzinfo(name) == tz
+
+
+def test_timestamp():
+ for unit in ('s', 'ms', 'us', 'ns'):
+ for tz in (None, 'UTC', 'Europe/Paris'):
+ ty = pa.timestamp(unit, tz=tz)
+ assert ty.unit == unit
+ assert ty.tz == tz
+
+ for invalid_unit in ('m', 'arbit', 'rary'):
+ with pytest.raises(ValueError, match='Invalid time unit'):
+ pa.timestamp(invalid_unit)
+
+
+def test_time32_units():
+ for valid_unit in ('s', 'ms'):
+ ty = pa.time32(valid_unit)
+ assert ty.unit == valid_unit
+
+ for invalid_unit in ('m', 'us', 'ns'):
+ error_msg = 'Invalid time unit for time32: {!r}'.format(invalid_unit)
+ with pytest.raises(ValueError, match=error_msg):
+ pa.time32(invalid_unit)
+
+
+def test_time64_units():
+ for valid_unit in ('us', 'ns'):
+ ty = pa.time64(valid_unit)
+ assert ty.unit == valid_unit
+
+ for invalid_unit in ('m', 's', 'ms'):
+ error_msg = 'Invalid time unit for time64: {!r}'.format(invalid_unit)
+ with pytest.raises(ValueError, match=error_msg):
+ pa.time64(invalid_unit)
+
+
+def test_duration():
+ for unit in ('s', 'ms', 'us', 'ns'):
+ ty = pa.duration(unit)
+ assert ty.unit == unit
+
+ for invalid_unit in ('m', 'arbit', 'rary'):
+ with pytest.raises(ValueError, match='Invalid time unit'):
+ pa.duration(invalid_unit)
+
+
+def test_list_type():
+ ty = pa.list_(pa.int64())
+ assert isinstance(ty, pa.ListType)
+ assert ty.value_type == pa.int64()
+ assert ty.value_field == pa.field("item", pa.int64(), nullable=True)
+
+ with pytest.raises(TypeError):
+ pa.list_(None)
+
+
+def test_large_list_type():
+ ty = pa.large_list(pa.utf8())
+ assert isinstance(ty, pa.LargeListType)
+ assert ty.value_type == pa.utf8()
+ assert ty.value_field == pa.field("item", pa.utf8(), nullable=True)
+
+ with pytest.raises(TypeError):
+ pa.large_list(None)
+
+
+def test_map_type():
+ ty = pa.map_(pa.utf8(), pa.int32())
+ assert isinstance(ty, pa.MapType)
+ assert ty.key_type == pa.utf8()
+ assert ty.key_field == pa.field("key", pa.utf8(), nullable=False)
+ assert ty.item_type == pa.int32()
+ assert ty.item_field == pa.field("value", pa.int32(), nullable=True)
+
+ with pytest.raises(TypeError):
+ pa.map_(None)
+ with pytest.raises(TypeError):
+ pa.map_(pa.int32(), None)
+ with pytest.raises(TypeError):
+ pa.map_(pa.field("name", pa.string(), nullable=True), pa.int64())
+
+
+def test_fixed_size_list_type():
+ ty = pa.list_(pa.float64(), 2)
+ assert isinstance(ty, pa.FixedSizeListType)
+ assert ty.value_type == pa.float64()
+ assert ty.value_field == pa.field("item", pa.float64(), nullable=True)
+ assert ty.list_size == 2
+
+ with pytest.raises(ValueError):
+ pa.list_(pa.float64(), -2)
+
+
+def test_struct_type():
+ fields = [
+ # Duplicate field name on purpose
+ pa.field('a', pa.int64()),
+ pa.field('a', pa.int32()),
+ pa.field('b', pa.int32())
+ ]
+ ty = pa.struct(fields)
+
+ assert len(ty) == ty.num_fields == 3
+ assert list(ty) == fields
+ assert ty[0].name == 'a'
+ assert ty[2].type == pa.int32()
+ with pytest.raises(IndexError):
+ assert ty[3]
+
+ assert ty['b'] == ty[2]
+
+ # Not found
+ with pytest.raises(KeyError):
+ ty['c']
+
+ # Neither integer nor string
+ with pytest.raises(TypeError):
+ ty[None]
+
+ for a, b in zip(ty, fields):
+ assert a == b
+
+ # Construct from list of tuples
+ ty = pa.struct([('a', pa.int64()),
+ ('a', pa.int32()),
+ ('b', pa.int32())])
+ assert list(ty) == fields
+ for a, b in zip(ty, fields):
+ assert a == b
+
+ # Construct from mapping
+ fields = [pa.field('a', pa.int64()),
+ pa.field('b', pa.int32())]
+ ty = pa.struct(OrderedDict([('a', pa.int64()),
+ ('b', pa.int32())]))
+ assert list(ty) == fields
+ for a, b in zip(ty, fields):
+ assert a == b
+
+ # Invalid args
+ with pytest.raises(TypeError):
+ pa.struct([('a', None)])
+
+
+def test_struct_duplicate_field_names():
+ fields = [
+ pa.field('a', pa.int64()),
+ pa.field('b', pa.int32()),
+ pa.field('a', pa.int32())
+ ]
+ ty = pa.struct(fields)
+
+ # Duplicate
+ with pytest.warns(UserWarning):
+ with pytest.raises(KeyError):
+ ty['a']
+
+ # StructType::GetFieldIndex
+ assert ty.get_field_index('a') == -1
+
+ # StructType::GetAllFieldIndices
+ assert ty.get_all_field_indices('a') == [0, 2]
+
+
+def test_union_type():
+ def check_fields(ty, fields):
+ assert ty.num_fields == len(fields)
+ assert [ty[i] for i in range(ty.num_fields)] == fields
+
+ fields = [pa.field('x', pa.list_(pa.int32())),
+ pa.field('y', pa.binary())]
+ type_codes = [5, 9]
+
+ sparse_factories = [
+ partial(pa.union, mode='sparse'),
+ partial(pa.union, mode=pa.lib.UnionMode_SPARSE),
+ pa.sparse_union,
+ ]
+
+ dense_factories = [
+ partial(pa.union, mode='dense'),
+ partial(pa.union, mode=pa.lib.UnionMode_DENSE),
+ pa.dense_union,
+ ]
+
+ for factory in sparse_factories:
+ ty = factory(fields)
+ assert isinstance(ty, pa.SparseUnionType)
+ assert ty.mode == 'sparse'
+ check_fields(ty, fields)
+ assert ty.type_codes == [0, 1]
+ ty = factory(fields, type_codes=type_codes)
+ assert ty.mode == 'sparse'
+ check_fields(ty, fields)
+ assert ty.type_codes == type_codes
+ # Invalid number of type codes
+ with pytest.raises(ValueError):
+ factory(fields, type_codes=type_codes[1:])
+
+ for factory in dense_factories:
+ ty = factory(fields)
+ assert isinstance(ty, pa.DenseUnionType)
+ assert ty.mode == 'dense'
+ check_fields(ty, fields)
+ assert ty.type_codes == [0, 1]
+ ty = factory(fields, type_codes=type_codes)
+ assert ty.mode == 'dense'
+ check_fields(ty, fields)
+ assert ty.type_codes == type_codes
+ # Invalid number of type codes
+ with pytest.raises(ValueError):
+ factory(fields, type_codes=type_codes[1:])
+
+ for mode in ('unknown', 2):
+ with pytest.raises(ValueError, match='Invalid union mode'):
+ pa.union(fields, mode=mode)
+
+
+def test_dictionary_type():
+ ty0 = pa.dictionary(pa.int32(), pa.string())
+ assert ty0.index_type == pa.int32()
+ assert ty0.value_type == pa.string()
+ assert ty0.ordered is False
+
+ ty1 = pa.dictionary(pa.int8(), pa.float64(), ordered=True)
+ assert ty1.index_type == pa.int8()
+ assert ty1.value_type == pa.float64()
+ assert ty1.ordered is True
+
+ # construct from non-arrow objects
+ ty2 = pa.dictionary('int8', 'string')
+ assert ty2.index_type == pa.int8()
+ assert ty2.value_type == pa.string()
+ assert ty2.ordered is False
+
+ # allow unsigned integers for index type
+ ty3 = pa.dictionary(pa.uint32(), pa.string())
+ assert ty3.index_type == pa.uint32()
+ assert ty3.value_type == pa.string()
+ assert ty3.ordered is False
+
+ # invalid index type raises
+ with pytest.raises(TypeError):
+ pa.dictionary(pa.string(), pa.int64())
+
+
+def test_dictionary_ordered_equals():
+ # Python side checking of ARROW-6345
+ d1 = pa.dictionary('int32', 'binary', ordered=True)
+ d2 = pa.dictionary('int32', 'binary', ordered=False)
+ d3 = pa.dictionary('int8', 'binary', ordered=True)
+ d4 = pa.dictionary('int32', 'binary', ordered=True)
+
+ assert not d1.equals(d2)
+ assert not d1.equals(d3)
+ assert d1.equals(d4)
+
+
+def test_types_hashable():
+ many_types = get_many_types()
+ in_dict = {}
+ for i, type_ in enumerate(many_types):
+ assert hash(type_) == hash(type_)
+ in_dict[type_] = i
+ assert len(in_dict) == len(many_types)
+ for i, type_ in enumerate(many_types):
+ assert in_dict[type_] == i
+
+
+def test_types_picklable():
+ for ty in get_many_types():
+ data = pickle.dumps(ty)
+ assert pickle.loads(data) == ty
+
+
+def test_types_weakref():
+ for ty in get_many_types():
+ wr = weakref.ref(ty)
+ assert wr() is not None
+ # Note that ty may be a singleton and therefore outlive this loop
+
+ wr = weakref.ref(pa.int32())
+ assert wr() is not None # singleton
+ wr = weakref.ref(pa.list_(pa.int32()))
+ assert wr() is None # not a singleton
+
+
+def test_fields_hashable():
+ in_dict = {}
+ fields = [pa.field('a', pa.int32()),
+ pa.field('a', pa.int64()),
+ pa.field('a', pa.int64(), nullable=False),
+ pa.field('b', pa.int32()),
+ pa.field('b', pa.int32(), nullable=False)]
+ for i, field in enumerate(fields):
+ in_dict[field] = i
+ assert len(in_dict) == len(fields)
+ for i, field in enumerate(fields):
+ assert in_dict[field] == i
+
+
+def test_fields_weakrefable():
+ field = pa.field('a', pa.int32())
+ wr = weakref.ref(field)
+ assert wr() is not None
+ del field
+ assert wr() is None
+
+
+@pytest.mark.parametrize('t,check_func', [
+ (pa.date32(), types.is_date32),
+ (pa.date64(), types.is_date64),
+ (pa.time32('s'), types.is_time32),
+ (pa.time64('ns'), types.is_time64),
+ (pa.int8(), types.is_int8),
+ (pa.int16(), types.is_int16),
+ (pa.int32(), types.is_int32),
+ (pa.int64(), types.is_int64),
+ (pa.uint8(), types.is_uint8),
+ (pa.uint16(), types.is_uint16),
+ (pa.uint32(), types.is_uint32),
+ (pa.uint64(), types.is_uint64),
+ (pa.float16(), types.is_float16),
+ (pa.float32(), types.is_float32),
+ (pa.float64(), types.is_float64)
+])
+def test_exact_primitive_types(t, check_func):
+ assert check_func(t)
+
+
+def test_type_id():
+ # enum values are not exposed publicly
+ for ty in get_many_types():
+ assert isinstance(ty.id, int)
+
+
+def test_bit_width():
+ for ty, expected in [(pa.bool_(), 1),
+ (pa.int8(), 8),
+ (pa.uint32(), 32),
+ (pa.float16(), 16),
+ (pa.decimal128(19, 4), 128),
+ (pa.decimal256(76, 38), 256),
+ (pa.binary(42), 42 * 8)]:
+ assert ty.bit_width == expected
+ for ty in [pa.binary(), pa.string(), pa.list_(pa.int16())]:
+ with pytest.raises(ValueError, match="fixed width"):
+ ty.bit_width
+
+
+def test_fixed_size_binary_byte_width():
+ ty = pa.binary(5)
+ assert ty.byte_width == 5
+
+
+def test_decimal_properties():
+ ty = pa.decimal128(19, 4)
+ assert ty.byte_width == 16
+ assert ty.precision == 19
+ assert ty.scale == 4
+ ty = pa.decimal256(76, 38)
+ assert ty.byte_width == 32
+ assert ty.precision == 76
+ assert ty.scale == 38
+
+
+def test_decimal_overflow():
+ pa.decimal128(1, 0)
+ pa.decimal128(38, 0)
+ for i in (0, -1, 39):
+ with pytest.raises(ValueError):
+ pa.decimal128(i, 0)
+
+ pa.decimal256(1, 0)
+ pa.decimal256(76, 0)
+ for i in (0, -1, 77):
+ with pytest.raises(ValueError):
+ pa.decimal256(i, 0)
+
+
+def test_type_equality_operators():
+ many_types = get_many_types()
+ non_pyarrow = ('foo', 16, {'s', 'e', 't'})
+
+ for index, ty in enumerate(many_types):
+ # could use two parametrization levels,
+ # but that'd bloat pytest's output
+ for i, other in enumerate(many_types + non_pyarrow):
+ if i == index:
+ assert ty == other
+ else:
+ assert ty != other
+
+
+def test_key_value_metadata():
+ m = pa.KeyValueMetadata({'a': 'A', 'b': 'B'})
+ assert len(m) == 2
+ assert m['a'] == b'A'
+ assert m[b'a'] == b'A'
+ assert m['b'] == b'B'
+ assert 'a' in m
+ assert b'a' in m
+ assert 'c' not in m
+
+ m1 = pa.KeyValueMetadata({'a': 'A', 'b': 'B'})
+ m2 = pa.KeyValueMetadata(a='A', b='B')
+ m3 = pa.KeyValueMetadata([('a', 'A'), ('b', 'B')])
+
+ assert m1 != 2
+ assert m1 == m2
+ assert m2 == m3
+ assert m1 == {'a': 'A', 'b': 'B'}
+ assert m1 != {'a': 'A', 'b': 'C'}
+
+ with pytest.raises(TypeError):
+ pa.KeyValueMetadata({'a': 1})
+ with pytest.raises(TypeError):
+ pa.KeyValueMetadata({1: 'a'})
+ with pytest.raises(TypeError):
+ pa.KeyValueMetadata(a=1)
+
+ expected = [(b'a', b'A'), (b'b', b'B')]
+ result = [(k, v) for k, v in m3.items()]
+ assert result == expected
+ assert list(m3.items()) == expected
+ assert list(m3.keys()) == [b'a', b'b']
+ assert list(m3.values()) == [b'A', b'B']
+ assert len(m3) == 2
+
+ # test duplicate key support
+ md = pa.KeyValueMetadata([
+ ('a', 'alpha'),
+ ('b', 'beta'),
+ ('a', 'Alpha'),
+ ('a', 'ALPHA'),
+ ])
+
+ expected = [
+ (b'a', b'alpha'),
+ (b'b', b'beta'),
+ (b'a', b'Alpha'),
+ (b'a', b'ALPHA')
+ ]
+ assert len(md) == 4
+ assert isinstance(md.keys(), Iterator)
+ assert isinstance(md.values(), Iterator)
+ assert isinstance(md.items(), Iterator)
+ assert list(md.items()) == expected
+ assert list(md.keys()) == [k for k, _ in expected]
+ assert list(md.values()) == [v for _, v in expected]
+
+ # first occurrence
+ assert md['a'] == b'alpha'
+ assert md['b'] == b'beta'
+ assert md.get_all('a') == [b'alpha', b'Alpha', b'ALPHA']
+ assert md.get_all('b') == [b'beta']
+ assert md.get_all('unknown') == []
+
+ with pytest.raises(KeyError):
+ md = pa.KeyValueMetadata([
+ ('a', 'alpha'),
+ ('b', 'beta'),
+ ('a', 'Alpha'),
+ ('a', 'ALPHA'),
+ ], b='BETA')
+
+
+def test_key_value_metadata_duplicates():
+ meta = pa.KeyValueMetadata({'a': '1', 'b': '2'})
+
+ with pytest.raises(KeyError):
+ pa.KeyValueMetadata(meta, a='3')
+
+
+def test_field_basic():
+ t = pa.string()
+ f = pa.field('foo', t)
+
+ assert f.name == 'foo'
+ assert f.nullable
+ assert f.type is t
+ assert repr(f) == "pyarrow.Field<foo: string>"
+
+ f = pa.field('foo', t, False)
+ assert not f.nullable
+
+ with pytest.raises(TypeError):
+ pa.field('foo', None)
+
+
+def test_field_equals():
+ meta1 = {b'foo': b'bar'}
+ meta2 = {b'bizz': b'bazz'}
+
+ f1 = pa.field('a', pa.int8(), nullable=True)
+ f2 = pa.field('a', pa.int8(), nullable=True)
+ f3 = pa.field('a', pa.int8(), nullable=False)
+ f4 = pa.field('a', pa.int16(), nullable=False)
+ f5 = pa.field('b', pa.int16(), nullable=False)
+ f6 = pa.field('a', pa.int8(), nullable=True, metadata=meta1)
+ f7 = pa.field('a', pa.int8(), nullable=True, metadata=meta1)
+ f8 = pa.field('a', pa.int8(), nullable=True, metadata=meta2)
+
+ assert f1.equals(f2)
+ assert f6.equals(f7)
+ assert not f1.equals(f3)
+ assert not f1.equals(f4)
+ assert not f3.equals(f4)
+ assert not f4.equals(f5)
+
+ # No metadata in f1, but metadata in f6
+ assert f1.equals(f6)
+ assert not f1.equals(f6, check_metadata=True)
+
+ # Different metadata
+ assert f6.equals(f7)
+ assert f7.equals(f8)
+ assert not f7.equals(f8, check_metadata=True)
+
+
+def test_field_equality_operators():
+ f1 = pa.field('a', pa.int8(), nullable=True)
+ f2 = pa.field('a', pa.int8(), nullable=True)
+ f3 = pa.field('b', pa.int8(), nullable=True)
+ f4 = pa.field('b', pa.int8(), nullable=False)
+
+ assert f1 == f2
+ assert f1 != f3
+ assert f3 != f4
+ assert f1 != 'foo'
+
+
+def test_field_metadata():
+ f1 = pa.field('a', pa.int8())
+ f2 = pa.field('a', pa.int8(), metadata={})
+ f3 = pa.field('a', pa.int8(), metadata={b'bizz': b'bazz'})
+
+ assert f1.metadata is None
+ assert f2.metadata == {}
+ assert f3.metadata[b'bizz'] == b'bazz'
+
+
+def test_field_add_remove_metadata():
+ import collections
+
+ f0 = pa.field('foo', pa.int32())
+
+ assert f0.metadata is None
+
+ metadata = {b'foo': b'bar', b'pandas': b'badger'}
+ metadata2 = collections.OrderedDict([
+ (b'a', b'alpha'),
+ (b'b', b'beta')
+ ])
+
+ f1 = f0.with_metadata(metadata)
+ assert f1.metadata == metadata
+
+ f2 = f0.with_metadata(metadata2)
+ assert f2.metadata == metadata2
+
+ with pytest.raises(TypeError):
+ f0.with_metadata([1, 2, 3])
+
+ f3 = f1.remove_metadata()
+ assert f3.metadata is None
+
+ # idempotent
+ f4 = f3.remove_metadata()
+ assert f4.metadata is None
+
+ f5 = pa.field('foo', pa.int32(), True, metadata)
+ f6 = f0.with_metadata(metadata)
+ assert f5.equals(f6)
+
+
+def test_field_modified_copies():
+ f0 = pa.field('foo', pa.int32(), True)
+ f0_ = pa.field('foo', pa.int32(), True)
+ assert f0.equals(f0_)
+
+ f1 = pa.field('foo', pa.int64(), True)
+ f1_ = f0.with_type(pa.int64())
+ assert f1.equals(f1_)
+ # Original instance is unmodified
+ assert f0.equals(f0_)
+
+ f2 = pa.field('foo', pa.int32(), False)
+ f2_ = f0.with_nullable(False)
+ assert f2.equals(f2_)
+ # Original instance is unmodified
+ assert f0.equals(f0_)
+
+ f3 = pa.field('bar', pa.int32(), True)
+ f3_ = f0.with_name('bar')
+ assert f3.equals(f3_)
+ # Original instance is unmodified
+ assert f0.equals(f0_)
+
+
+def test_is_integer_value():
+ assert pa.types.is_integer_value(1)
+ assert pa.types.is_integer_value(np.int64(1))
+ assert not pa.types.is_integer_value('1')
+
+
+def test_is_float_value():
+ assert not pa.types.is_float_value(1)
+ assert pa.types.is_float_value(1.)
+ assert pa.types.is_float_value(np.float64(1))
+ assert not pa.types.is_float_value('1.0')
+
+
+def test_is_boolean_value():
+ assert not pa.types.is_boolean_value(1)
+ assert pa.types.is_boolean_value(True)
+ assert pa.types.is_boolean_value(False)
+ assert pa.types.is_boolean_value(np.bool_(True))
+ assert pa.types.is_boolean_value(np.bool_(False))
+
+
+@h.given(
+ past.all_types |
+ past.all_fields |
+ past.all_schemas
+)
+@h.example(
+ pa.field(name='', type=pa.null(), metadata={'0': '', '': ''})
+)
+def test_pickling(field):
+ data = pickle.dumps(field)
+ assert pickle.loads(data) == field
+
+
+@h.given(
+ st.lists(past.all_types) |
+ st.lists(past.all_fields) |
+ st.lists(past.all_schemas)
+)
+def test_hashing(items):
+ h.assume(
+ # well, this is still O(n^2), but makes the input unique
+ all(not a.equals(b) for i, a in enumerate(items) for b in items[:i])
+ )
+
+ container = {}
+ for i, item in enumerate(items):
+ assert hash(item) == hash(item)
+ container[item] = i
+
+ assert len(container) == len(items)
+
+ for i, item in enumerate(items):
+ assert container[item] == i
diff --git a/src/arrow/python/pyarrow/tests/test_util.py b/src/arrow/python/pyarrow/tests/test_util.py
new file mode 100644
index 000000000..2b351a534
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_util.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import gc
+import signal
+import sys
+import weakref
+
+import pytest
+
+from pyarrow import util
+from pyarrow.tests.util import disabled_gc
+
+
+def exhibit_signal_refcycle():
+ # Put an object in the frame locals and return a weakref to it.
+ # If `signal.getsignal` has a bug where it creates a reference cycle
+ # keeping alive the current execution frames, `obj` will not be
+ # destroyed immediately when this function returns.
+ obj = set()
+ signal.getsignal(signal.SIGINT)
+ return weakref.ref(obj)
+
+
+def test_signal_refcycle():
+ # Test possible workaround for https://bugs.python.org/issue42248
+ with disabled_gc():
+ wr = exhibit_signal_refcycle()
+ if wr() is None:
+ pytest.skip(
+ "Python version does not have the bug we're testing for")
+
+ gc.collect()
+ with disabled_gc():
+ wr = exhibit_signal_refcycle()
+ assert wr() is not None
+ util._break_traceback_cycle_from_frame(sys._getframe(0))
+ assert wr() is None
diff --git a/src/arrow/python/pyarrow/tests/util.py b/src/arrow/python/pyarrow/tests/util.py
new file mode 100644
index 000000000..281de69e3
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/util.py
@@ -0,0 +1,331 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Utility functions for testing
+"""
+
+import contextlib
+import decimal
+import gc
+import numpy as np
+import os
+import random
+import signal
+import string
+import subprocess
+import sys
+
+import pytest
+
+import pyarrow as pa
+import pyarrow.fs
+
+
+def randsign():
+ """Randomly choose either 1 or -1.
+
+ Returns
+ -------
+ sign : int
+ """
+ return random.choice((-1, 1))
+
+
+@contextlib.contextmanager
+def random_seed(seed):
+ """Set the random seed inside of a context manager.
+
+ Parameters
+ ----------
+ seed : int
+ The seed to set
+
+ Notes
+ -----
+ This function is useful when you want to set a random seed but not affect
+ the random state of other functions using the random module.
+ """
+ original_state = random.getstate()
+ random.seed(seed)
+ try:
+ yield
+ finally:
+ random.setstate(original_state)
+
+
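+# Usage sketch (illustrative only, not called by the test suite): draws made
+# under the same seed are reproducible, while the global random state around
+# the block is left untouched.
+def _random_seed_usage_sketch():
+ with random_seed(42):
+ first = [random.random() for _ in range(3)]
+ with random_seed(42):
+ second = [random.random() for _ in range(3)]
+ # Both draws used the same seed, so they are identical.
+ assert first == second
+
+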
+def randdecimal(precision, scale):
+ """Generate a random decimal value with specified precision and scale.
+
+ Parameters
+ ----------
+ precision : int
+ The maximum number of digits to generate. Must be an integer between 1
+ and 38 inclusive.
+ scale : int
+ The maximum number of digits following the decimal point. Must be an
+ integer greater than or equal to 0.
+
+ Returns
+ -------
+ decimal_value : decimal.Decimal
+ A random decimal.Decimal object with the specified precision and scale.
+ """
+ assert 1 <= precision <= 38, 'precision must be between 1 and 38 inclusive'
+ if scale < 0:
+ raise ValueError(
+ 'randdecimal does not yet support generating decimals with '
+ 'negative scale'
+ )
+ max_whole_value = 10 ** (precision - scale) - 1
+ whole = random.randint(-max_whole_value, max_whole_value)
+
+ if not scale:
+ return decimal.Decimal(whole)
+
+ max_fractional_value = 10 ** scale - 1
+ fractional = random.randint(0, max_fractional_value)
+
+ return decimal.Decimal(
+ '{}.{}'.format(whole, str(fractional).rjust(scale, '0'))
+ )
+
+
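+# Usage sketch (illustrative only): values produced by `randdecimal` fit the
+# Arrow decimal type with the same precision and scale, e.g. decimal128(9, 2).
+def _randdecimal_usage_sketch(length=5):
+ values = [randdecimal(9, 2) for _ in range(length)]
+ return pa.array(values, type=pa.decimal128(9, 2))
+
+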
+def random_ascii(length):
+ return bytes(np.random.randint(65, 123, size=length, dtype='i1'))
+
+
+def rands(nchars):
+ """
+ Generate one random string.
+ """
+ RANDS_CHARS = np.array(
+ list(string.ascii_letters + string.digits), dtype=(np.str_, 1))
+ return "".join(np.random.choice(RANDS_CHARS, nchars))
+
+
+def make_dataframe():
+ import pandas as pd
+
+ N = 30
+ df = pd.DataFrame(
+ {col: np.random.randn(N) for col in string.ascii_uppercase[:4]},
+ index=pd.Index([rands(10) for _ in range(N)])
+ )
+ return df
+
+
+def memory_leak_check(f, metric='rss', threshold=1 << 17, iterations=10,
+ check_interval=1):
+ """
+ Execute the function and try to detect a clear memory leak, either internal
+ to Arrow or caused by a reference counting problem in the Python binding
+ implementation. Raises an exception if a leak is detected.
+
+ Parameters
+ ----------
+ f : callable
+ Function to invoke on each iteration
+ metric : {'rss', 'vms', 'shared'}, default 'rss'
+ Attribute of psutil.Process.memory_info to use for determining current
+ memory use
+ threshold : int, default 128K
+ Threshold in number of bytes to consider a leak
+ iterations : int, default 10
+ Total number of invocations of f
+ check_interval : int, default 1
+ Number of invocations of f in between each memory use check
+ """
+ import psutil
+ proc = psutil.Process()
+
+ def _get_use():
+ gc.collect()
+ return getattr(proc.memory_info(), metric)
+
+ baseline_use = _get_use()
+
+ def _leak_check():
+ current_use = _get_use()
+ if current_use - baseline_use > threshold:
+ raise Exception("Memory leak detected. "
+ "Departure from baseline {} after {} iterations"
+ .format(current_use - baseline_use, i))
+
+ for i in range(iterations):
+ f()
+ if i % check_interval == 0:
+ _leak_check()
+
+
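+# Usage sketch (illustrative only): repeatedly build a small array and fail
+# if the process RSS grows past the default threshold. Requires psutil, just
+# like memory_leak_check itself.
+def _memory_leak_check_usage_sketch():
+ memory_leak_check(lambda: pa.array(range(1000)), iterations=5)
+
+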
+def get_modified_env_with_pythonpath():
+ # Prepend pyarrow root directory to PYTHONPATH
+ env = os.environ.copy()
+ existing_pythonpath = env.get('PYTHONPATH', '')
+
+ module_path = os.path.abspath(
+ os.path.dirname(os.path.dirname(pa.__file__)))
+
+ if existing_pythonpath:
+ new_pythonpath = os.pathsep.join((module_path, existing_pythonpath))
+ else:
+ new_pythonpath = module_path
+ env['PYTHONPATH'] = new_pythonpath
+ return env
+
+
+def invoke_script(script_name, *args):
+ subprocess_env = get_modified_env_with_pythonpath()
+
+ dir_path = os.path.dirname(os.path.realpath(__file__))
+ python_file = os.path.join(dir_path, script_name)
+
+ cmd = [sys.executable, python_file]
+ cmd.extend(args)
+
+ subprocess.check_call(cmd, env=subprocess_env)
+
+
+@contextlib.contextmanager
+def changed_environ(name, value):
+ """
+ Temporarily set environment variable *name* to *value*.
+ """
+ orig_value = os.environ.get(name)
+ os.environ[name] = value
+ try:
+ yield
+ finally:
+ if orig_value is None:
+ del os.environ[name]
+ else:
+ os.environ[name] = orig_value
+
+
+@contextlib.contextmanager
+def change_cwd(path):
+ curdir = os.getcwd()
+ os.chdir(str(path))
+ try:
+ yield
+ finally:
+ os.chdir(curdir)
+
+
+@contextlib.contextmanager
+def disabled_gc():
+ gc.disable()
+ try:
+ yield
+ finally:
+ gc.enable()
+
+
+def _filesystem_uri(path):
+ # URIs on Windows must follow 'file:///C:...' or 'file:/C:...' patterns.
+ if os.name == 'nt':
+ uri = 'file:///{}'.format(path)
+ else:
+ uri = 'file://{}'.format(path)
+ return uri
+
+
+class FSProtocolClass:
+ def __init__(self, path):
+ self._path = path
+
+ def __fspath__(self):
+ return str(self._path)
+
+
+class ProxyHandler(pyarrow.fs.FileSystemHandler):
+ """
+ A filesystem handler that proxies to an underlying filesystem. Useful
+ for wrapping an existing filesystem with partial behavior changes.
+ """
+
+ def __init__(self, fs):
+ self._fs = fs
+
+ def __eq__(self, other):
+ if isinstance(other, ProxyHandler):
+ return self._fs == other._fs
+ return NotImplemented
+
+ def __ne__(self, other):
+ if isinstance(other, ProxyHandler):
+ return self._fs != other._fs
+ return NotImplemented
+
+ def get_type_name(self):
+ return "proxy::" + self._fs.type_name
+
+ def normalize_path(self, path):
+ return self._fs.normalize_path(path)
+
+ def get_file_info(self, paths):
+ return self._fs.get_file_info(paths)
+
+ def get_file_info_selector(self, selector):
+ return self._fs.get_file_info(selector)
+
+ def create_dir(self, path, recursive):
+ return self._fs.create_dir(path, recursive=recursive)
+
+ def delete_dir(self, path):
+ return self._fs.delete_dir(path)
+
+ def delete_dir_contents(self, path):
+ return self._fs.delete_dir_contents(path)
+
+ def delete_root_dir_contents(self):
+ return self._fs.delete_dir_contents("", accept_root_dir=True)
+
+ def delete_file(self, path):
+ return self._fs.delete_file(path)
+
+ def move(self, src, dest):
+ return self._fs.move(src, dest)
+
+ def copy_file(self, src, dest):
+ return self._fs.copy_file(src, dest)
+
+ def open_input_stream(self, path):
+ return self._fs.open_input_stream(path)
+
+ def open_input_file(self, path):
+ return self._fs.open_input_file(path)
+
+ def open_output_stream(self, path, metadata):
+ return self._fs.open_output_stream(path, metadata=metadata)
+
+ def open_append_stream(self, path, metadata):
+ return self._fs.open_append_stream(path, metadata=metadata)
+
+
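+# Usage sketch (illustrative only): a ProxyHandler becomes a regular pyarrow
+# filesystem once wrapped in pyarrow.fs.PyFileSystem.
+def _proxy_handler_usage_sketch():
+ local = pyarrow.fs.LocalFileSystem()
+ return pyarrow.fs.PyFileSystem(ProxyHandler(local))
+
+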
+def get_raise_signal():
+ if sys.version_info >= (3, 8):
+ return signal.raise_signal
+ elif os.name == 'nt':
+ # On Windows, os.kill() doesn't actually send a signal,
+ # it just terminates the process with the given exit code.
+ pytest.skip("test requires Python 3.8+ on Windows")
+ else:
+ # On Unix, emulate raise_signal() with os.kill().
+ def raise_signal(signum):
+ os.kill(os.getpid(), signum)
+ return raise_signal
diff --git a/src/arrow/python/pyarrow/types.pxi b/src/arrow/python/pyarrow/types.pxi
new file mode 100644
index 000000000..8795e4d3a
--- /dev/null
+++ b/src/arrow/python/pyarrow/types.pxi
@@ -0,0 +1,2930 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import atexit
+from collections.abc import Mapping
+import re
+import sys
+import warnings
+
+
+# These are imprecise because the type (in pandas 0.x) depends on the presence
+# of nulls
+cdef dict _pandas_type_map = {
+ _Type_NA: np.object_, # NaNs
+ _Type_BOOL: np.bool_,
+ _Type_INT8: np.int8,
+ _Type_INT16: np.int16,
+ _Type_INT32: np.int32,
+ _Type_INT64: np.int64,
+ _Type_UINT8: np.uint8,
+ _Type_UINT16: np.uint16,
+ _Type_UINT32: np.uint32,
+ _Type_UINT64: np.uint64,
+ _Type_HALF_FLOAT: np.float16,
+ _Type_FLOAT: np.float32,
+ _Type_DOUBLE: np.float64,
+ _Type_DATE32: np.dtype('datetime64[ns]'),
+ _Type_DATE64: np.dtype('datetime64[ns]'),
+ _Type_TIMESTAMP: np.dtype('datetime64[ns]'),
+ _Type_DURATION: np.dtype('timedelta64[ns]'),
+ _Type_BINARY: np.object_,
+ _Type_FIXED_SIZE_BINARY: np.object_,
+ _Type_STRING: np.object_,
+ _Type_LIST: np.object_,
+ _Type_MAP: np.object_,
+ _Type_DECIMAL128: np.object_,
+}
+
+cdef dict _pep3118_type_map = {
+ _Type_INT8: b'b',
+ _Type_INT16: b'h',
+ _Type_INT32: b'i',
+ _Type_INT64: b'q',
+ _Type_UINT8: b'B',
+ _Type_UINT16: b'H',
+ _Type_UINT32: b'I',
+ _Type_UINT64: b'Q',
+ _Type_HALF_FLOAT: b'e',
+ _Type_FLOAT: b'f',
+ _Type_DOUBLE: b'd',
+}
+
+
+cdef bytes _datatype_to_pep3118(CDataType* type):
+ """
+ Construct a PEP 3118 format string describing the given datatype.
+ None is returned for unsupported types.
+ """
+ try:
+ char = _pep3118_type_map[type.id()]
+ except KeyError:
+ return None
+ else:
+ if char in b'bBhHiIqQ':
+ # Use "standard" int widths, not native
+ return b'=' + char
+ else:
+ return char
+
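+# For example, _datatype_to_pep3118 yields b'=i' for int32 (the '=' forces
+# standard integer widths), b'e' for half-float, and None for any type
+# without an entry in _pep3118_type_map (nested, parametric, binary-like).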
+
+def _is_primitive(Type type):
+ # This is simply a redirect, the official API is in pyarrow.types.
+ return is_primitive(type)
+
+
+# Workaround for Cython parsing bug
+# https://github.com/cython/cython/issues/2143
+ctypedef CFixedWidthType* _CFixedWidthTypePtr
+
+
+cdef class DataType(_Weakrefable):
+ """
+ Base class of all Arrow data types.
+
+ Each data type is an *instance* of this class.
+ """
+
+ def __cinit__(self):
+ pass
+
+ def __init__(self):
+ raise TypeError("Do not call {}'s constructor directly, use public "
+ "functions like pyarrow.int64, pyarrow.list_, etc. "
+ "instead.".format(self.__class__.__name__))
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ assert type != nullptr
+ self.sp_type = type
+ self.type = type.get()
+ self.pep3118_format = _datatype_to_pep3118(self.type)
+
+ cdef Field field(self, int i):
+ cdef int index = <int> _normalize_index(i, self.type.num_fields())
+ return pyarrow_wrap_field(self.type.field(index))
+
+ @property
+ def id(self):
+ return self.type.id()
+
+ @property
+ def bit_width(self):
+ cdef _CFixedWidthTypePtr ty
+ ty = dynamic_cast[_CFixedWidthTypePtr](self.type)
+ if ty == nullptr:
+ raise ValueError("Non-fixed width type")
+ return ty.bit_width()
+
+ @property
+ def num_children(self):
+ """
+ The number of child fields.
+ """
+ warnings.warn("num_children is deprecated, use num_fields",
+ FutureWarning)
+ return self.num_fields
+
+ @property
+ def num_fields(self):
+ """
+ The number of child fields.
+ """
+ return self.type.num_fields()
+
+ @property
+ def num_buffers(self):
+ """
+ Number of data buffers required to construct an Array of this
+ type, excluding children.
+ """
+ return self.type.layout().buffers.size()
+
+ def __str__(self):
+ return frombytes(self.type.ToString(), safe=True)
+
+ def __hash__(self):
+ return hash(str(self))
+
+ def __reduce__(self):
+ return type_for_alias, (str(self),)
+
+ def __repr__(self):
+ return '{0.__class__.__name__}({0})'.format(self)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except (TypeError, ValueError):
+ return NotImplemented
+
+ def equals(self, other):
+ """
+ Return true if type is equivalent to passed value.
+
+ Parameters
+ ----------
+ other : DataType or string convertible to DataType
+
+ Returns
+ -------
+ is_equal : bool
+ """
+ cdef DataType other_type
+
+ other_type = ensure_type(other)
+ return self.type.Equals(deref(other_type.type))
+
+ def to_pandas_dtype(self):
+ """
+ Return the equivalent NumPy / Pandas dtype.
+ """
+ cdef Type type_id = self.type.id()
+ if type_id in _pandas_type_map:
+ return _pandas_type_map[type_id]
+ else:
+ raise NotImplementedError(str(self))
+
+ def _export_to_c(self, uintptr_t out_ptr):
+ """
+ Export to a C ArrowSchema struct, given its pointer.
+
+ Be careful: if you don't pass the ArrowSchema struct to a consumer,
+ its memory will leak. This is a low-level function intended for
+ expert users.
+ """
+ check_status(ExportType(deref(self.type), <ArrowSchema*> out_ptr))
+
+ @staticmethod
+ def _import_from_c(uintptr_t in_ptr):
+ """
+ Import DataType from a C ArrowSchema struct, given its pointer.
+
+ This is a low-level function intended for expert users.
+ """
+ result = GetResultValue(ImportType(<ArrowSchema*> in_ptr))
+ return pyarrow_wrap_data_type(result)
+
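+# Usage sketch for the DataType API above (assumes ``import pyarrow as pa``;
+# exact reprs may vary between versions):
+#
+#   t = pa.int32()
+#   t.bit_width           # -> 32
+#   t.num_buffers         # -> 2 (validity bitmap + values)
+#   t.equals('int32')     # -> True (strings are resolved via type_for_alias)
+#   t.to_pandas_dtype()   # -> <class 'numpy.int32'>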
+
+cdef class DictionaryMemo(_Weakrefable):
+ """
+ Tracking container for dictionary-encoded fields.
+ """
+
+ def __cinit__(self):
+ self.sp_memo.reset(new CDictionaryMemo())
+ self.memo = self.sp_memo.get()
+
+
+cdef class DictionaryType(DataType):
+ """
+ Concrete class for dictionary data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.dict_type = <const CDictionaryType*> type.get()
+
+ def __reduce__(self):
+ return dictionary, (self.index_type, self.value_type, self.ordered)
+
+ @property
+ def ordered(self):
+ """
+ Whether the dictionary is ordered, i.e. whether the ordering of values
+ in the dictionary is important.
+ """
+ return self.dict_type.ordered()
+
+ @property
+ def index_type(self):
+ """
+ The data type of dictionary indices (a signed integer type).
+ """
+ return pyarrow_wrap_data_type(self.dict_type.index_type())
+
+ @property
+ def value_type(self):
+ """
+ The dictionary value type.
+
+ The dictionary values are found in an instance of DictionaryArray.
+ """
+ return pyarrow_wrap_data_type(self.dict_type.value_type())
+
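+# Usage sketch for DictionaryType (assumes ``import pyarrow as pa``):
+#
+#   d = pa.dictionary(pa.int16(), pa.string(), ordered=True)
+#   d.index_type   # -> DataType(int16)
+#   d.value_type   # -> DataType(string)
+#   d.ordered      # -> True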
+
+cdef class ListType(DataType):
+ """
+ Concrete class for list data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.list_type = <const CListType*> type.get()
+
+ def __reduce__(self):
+ return list_, (self.value_field,)
+
+ @property
+ def value_field(self):
+ return pyarrow_wrap_field(self.list_type.value_field())
+
+ @property
+ def value_type(self):
+ """
+ The data type of list values.
+ """
+ return pyarrow_wrap_data_type(self.list_type.value_type())
+
+
+cdef class LargeListType(DataType):
+ """
+ Concrete class for large list data types
+ (like ListType, but with 64-bit offsets).
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.list_type = <const CLargeListType*> type.get()
+
+ def __reduce__(self):
+ return large_list, (self.value_field,)
+
+ @property
+ def value_field(self):
+ return pyarrow_wrap_field(self.list_type.value_field())
+
+ @property
+ def value_type(self):
+ """
+ The data type of large list values.
+ """
+ return pyarrow_wrap_data_type(self.list_type.value_type())
+
+
+cdef class MapType(DataType):
+ """
+ Concrete class for map data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.map_type = <const CMapType*> type.get()
+
+ def __reduce__(self):
+ return map_, (self.key_field, self.item_field)
+
+ @property
+ def key_field(self):
+ """
+ The field for keys in the map entries.
+ """
+ return pyarrow_wrap_field(self.map_type.key_field())
+
+ @property
+ def key_type(self):
+ """
+ The data type of keys in the map entries.
+ """
+ return pyarrow_wrap_data_type(self.map_type.key_type())
+
+ @property
+ def item_field(self):
+ """
+ The field for items in the map entries.
+ """
+ return pyarrow_wrap_field(self.map_type.item_field())
+
+ @property
+ def item_type(self):
+ """
+ The data type of items in the map entries.
+ """
+ return pyarrow_wrap_data_type(self.map_type.item_type())
+
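+# Usage sketch for MapType (assumes ``import pyarrow as pa``):
+#
+#   m = pa.map_(pa.string(), pa.int32())
+#   m.key_type             # -> DataType(string)
+#   m.item_type            # -> DataType(int32)
+#   m.key_field.nullable   # -> False (map keys are never nullable)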
+
+cdef class FixedSizeListType(DataType):
+ """
+ Concrete class for fixed size list data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.list_type = <const CFixedSizeListType*> type.get()
+
+ def __reduce__(self):
+ return list_, (self.value_type, self.list_size)
+
+ @property
+ def value_field(self):
+ return pyarrow_wrap_field(self.list_type.value_field())
+
+ @property
+ def value_type(self):
+ """
+ The data type of fixed size list values.
+ """
+ return pyarrow_wrap_data_type(self.list_type.value_type())
+
+ @property
+ def list_size(self):
+ """
+ The size of the fixed size lists.
+ """
+ return self.list_type.list_size()
+
+
+cdef class StructType(DataType):
+ """
+ Concrete class for struct data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.struct_type = <const CStructType*> type.get()
+
+ cdef Field field_by_name(self, name):
+ """
+ Return a child field by its name rather than its index.
+ """
+ cdef vector[shared_ptr[CField]] fields
+
+ fields = self.struct_type.GetAllFieldsByName(tobytes(name))
+ if fields.size() == 0:
+ raise KeyError(name)
+ elif fields.size() > 1:
+ warnings.warn("Struct field name corresponds to more "
+ "than one field", UserWarning)
+ raise KeyError(name)
+ else:
+ return pyarrow_wrap_field(fields[0])
+
+ def get_field_index(self, name):
+ """
+ Return index of field with given unique name. Returns -1 if not found
+ or if duplicated
+ """
+ return self.struct_type.GetFieldIndex(tobytes(name))
+
+ def get_all_field_indices(self, name):
+ """
+ Return sorted list of indices for fields with the given name
+ """
+ return self.struct_type.GetAllFieldIndices(tobytes(name))
+
+ def __len__(self):
+ """
+ Like num_fields().
+ """
+ return self.type.num_fields()
+
+ def __iter__(self):
+ """
+ Iterate over struct fields, in order.
+ """
+ for i in range(len(self)):
+ yield self[i]
+
+ def __getitem__(self, i):
+ """
+ Return the struct field with the given index or name.
+ """
+ if isinstance(i, (bytes, str)):
+ return self.field_by_name(i)
+ elif isinstance(i, int):
+ return self.field(i)
+ else:
+ raise TypeError('Expected integer or string index')
+
+ def __reduce__(self):
+ return struct, (list(self),)
+
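+# Usage sketch for StructType (assumes ``import pyarrow as pa``):
+#
+#   s = pa.struct([('a', pa.int32()), ('b', pa.string())])
+#   s.get_field_index('b')   # -> 1
+#   s['a'].type              # -> DataType(int32)
+#   [f.name for f in s]      # -> ['a', 'b']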
+
+cdef class UnionType(DataType):
+ """
+ Base class for union data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+
+ @property
+ def mode(self):
+ """
+ The mode of the union ("dense" or "sparse").
+ """
+ cdef CUnionType* type = <CUnionType*> self.sp_type.get()
+ cdef int mode = type.mode()
+ if mode == _UnionMode_DENSE:
+ return 'dense'
+ if mode == _UnionMode_SPARSE:
+ return 'sparse'
+ assert 0
+
+ @property
+ def type_codes(self):
+ """
+ The type code to indicate each data type in this union.
+ """
+ cdef CUnionType* type = <CUnionType*> self.sp_type.get()
+ return type.type_codes()
+
+ def __len__(self):
+ """
+ Like num_fields().
+ """
+ return self.type.num_fields()
+
+ def __iter__(self):
+ """
+ Iterate over union members, in order.
+ """
+ for i in range(len(self)):
+ yield self[i]
+
+ def __getitem__(self, i):
+ """
+ Return a child field by its index.
+ """
+ return self.field(i)
+
+ def __reduce__(self):
+ return union, (list(self), self.mode, self.type_codes)
+
+
+cdef class SparseUnionType(UnionType):
+ """
+ Concrete class for sparse union types.
+ """
+
+
+cdef class DenseUnionType(UnionType):
+ """
+ Concrete class for dense union types.
+ """
+
+
+cdef class TimestampType(DataType):
+ """
+ Concrete class for timestamp data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.ts_type = <const CTimestampType*> type.get()
+
+ @property
+ def unit(self):
+ """
+ The timestamp unit ('s', 'ms', 'us' or 'ns').
+ """
+ return timeunit_to_string(self.ts_type.unit())
+
+ @property
+ def tz(self):
+ """
+ The timestamp time zone, if any, or None.
+ """
+ if self.ts_type.timezone().size() > 0:
+ return frombytes(self.ts_type.timezone())
+ else:
+ return None
+
+ def to_pandas_dtype(self):
+ """
+ Return the equivalent NumPy / Pandas dtype.
+ """
+ if self.tz is None:
+ return _pandas_type_map[_Type_TIMESTAMP]
+ else:
+ # Return DatetimeTZ
+ from pyarrow.pandas_compat import make_datetimetz
+ return make_datetimetz(self.tz)
+
+ def __reduce__(self):
+ return timestamp, (self.unit, self.tz)
+
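+# Usage sketch for TimestampType (assumes ``import pyarrow as pa``):
+#
+#   t = pa.timestamp('ns', tz='UTC')
+#   t.unit                 # -> 'ns'
+#   t.tz                   # -> 'UTC'
+#   pa.timestamp('s').tz   # -> None (time zone naive)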
+
+cdef class Time32Type(DataType):
+ """
+ Concrete class for time32 data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.time_type = <const CTime32Type*> type.get()
+
+ @property
+ def unit(self):
+ """
+ The time unit ('s', 'ms', 'us' or 'ns').
+ """
+ return timeunit_to_string(self.time_type.unit())
+
+
+cdef class Time64Type(DataType):
+ """
+ Concrete class for time64 data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.time_type = <const CTime64Type*> type.get()
+
+ @property
+ def unit(self):
+ """
+ The time unit ('s', 'ms', 'us' or 'ns').
+ """
+ return timeunit_to_string(self.time_type.unit())
+
+
+cdef class DurationType(DataType):
+ """
+ Concrete class for duration data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.duration_type = <const CDurationType*> type.get()
+
+ @property
+ def unit(self):
+ """
+ The duration unit ('s', 'ms', 'us' or 'ns').
+ """
+ return timeunit_to_string(self.duration_type.unit())
+
+
+cdef class FixedSizeBinaryType(DataType):
+ """
+ Concrete class for fixed-size binary data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.fixed_size_binary_type = (
+ <const CFixedSizeBinaryType*> type.get())
+
+ def __reduce__(self):
+ return binary, (self.byte_width,)
+
+ @property
+ def byte_width(self):
+ """
+ The binary size in bytes.
+ """
+ return self.fixed_size_binary_type.byte_width()
+
+
+cdef class Decimal128Type(FixedSizeBinaryType):
+ """
+ Concrete class for decimal128 data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ FixedSizeBinaryType.init(self, type)
+ self.decimal128_type = <const CDecimal128Type*> type.get()
+
+ def __reduce__(self):
+ return decimal128, (self.precision, self.scale)
+
+ @property
+ def precision(self):
+ """
+ The decimal precision, in number of decimal digits (an integer).
+ """
+ return self.decimal128_type.precision()
+
+ @property
+ def scale(self):
+ """
+ The decimal scale (an integer).
+ """
+ return self.decimal128_type.scale()
+
+
+cdef class Decimal256Type(FixedSizeBinaryType):
+ """
+ Concrete class for Decimal256 data types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ FixedSizeBinaryType.init(self, type)
+ self.decimal256_type = <const CDecimal256Type*> type.get()
+
+ def __reduce__(self):
+ return decimal256, (self.precision, self.scale)
+
+ @property
+ def precision(self):
+ """
+ The decimal precision, in number of decimal digits (an integer).
+ """
+ return self.decimal256_type.precision()
+
+ @property
+ def scale(self):
+ """
+ The decimal scale (an integer).
+ """
+ return self.decimal256_type.scale()
+
+
+cdef class BaseExtensionType(DataType):
+ """
+ Concrete base class for extension types.
+ """
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ DataType.init(self, type)
+ self.ext_type = <const CExtensionType*> type.get()
+
+ @property
+ def extension_name(self):
+ """
+ The extension type name.
+ """
+ return frombytes(self.ext_type.extension_name())
+
+ @property
+ def storage_type(self):
+ """
+ The underlying storage type.
+ """
+ return pyarrow_wrap_data_type(self.ext_type.storage_type())
+
+ def wrap_array(self, storage):
+ """
+ Wrap the given storage array as an extension array.
+
+ Parameters
+ ----------
+ storage : Array or ChunkedArray
+
+ Returns
+ -------
+ array : Array or ChunkedArray
+ Extension array wrapping the storage array
+ """
+ cdef:
+ shared_ptr[CDataType] c_storage_type
+
+ if isinstance(storage, Array):
+ c_storage_type = (<Array> storage).ap.type()
+ elif isinstance(storage, ChunkedArray):
+ c_storage_type = (<ChunkedArray> storage).chunked_array.type()
+ else:
+ raise TypeError(
+ f"Expected array or chunked array, got {storage.__class__}")
+
+ if not c_storage_type.get().Equals(deref(self.ext_type)
+ .storage_type()):
+ raise TypeError(
+ f"Incompatible storage type for {self}: "
+ f"expected {self.storage_type}, got {storage.type}")
+
+ if isinstance(storage, Array):
+ return pyarrow_wrap_array(
+ self.ext_type.WrapArray(
+ self.sp_type, (<Array> storage).sp_array))
+ else:
+ return pyarrow_wrap_chunked_array(
+ self.ext_type.WrapArray(
+ self.sp_type, (<ChunkedArray> storage).sp_chunked_array))
+
+
+cdef class ExtensionType(BaseExtensionType):
+ """
+ Concrete base class for Python-defined extension types.
+
+ Parameters
+ ----------
+ storage_type : DataType
+ extension_name : str
+ """
+
+ def __cinit__(self):
+ if type(self) is ExtensionType:
+ raise TypeError("Can only instantiate subclasses of "
+ "ExtensionType")
+
+ def __init__(self, DataType storage_type, extension_name):
+ """
+ Initialize an extension type instance.
+
+ This should be called at the end of the subclass'
+ ``__init__`` method.
+ """
+ cdef:
+ shared_ptr[CExtensionType] cpy_ext_type
+ c_string c_extension_name
+
+ c_extension_name = tobytes(extension_name)
+
+ assert storage_type is not None
+ check_status(CPyExtensionType.FromClass(
+ storage_type.sp_type, c_extension_name, type(self),
+ &cpy_ext_type))
+ self.init(<shared_ptr[CDataType]> cpy_ext_type)
+
+ cdef void init(self, const shared_ptr[CDataType]& type) except *:
+ BaseExtensionType.init(self, type)
+ self.cpy_ext_type = <const CPyExtensionType*> type.get()
+ # Store weakref and serialized version of self on C++ type instance
+ check_status(self.cpy_ext_type.SetInstance(self))
+
+ def __eq__(self, other):
+ # Default implementation to avoid infinite recursion through
+ # DataType.__eq__ -> ExtensionType::ExtensionEquals -> DataType.__eq__
+ if isinstance(other, ExtensionType):
+ return (type(self) == type(other) and
+ self.extension_name == other.extension_name and
+ self.storage_type == other.storage_type)
+ else:
+ return NotImplemented
+
+ def __repr__(self):
+ fmt = '{0.__class__.__name__}({1})'
+ return fmt.format(self, repr(self.storage_type))
+
+ def __arrow_ext_serialize__(self):
+ """
+ Serialized representation of metadata to reconstruct the type object.
+
+ This method should return a bytes object, and those serialized bytes
+ are stored in the custom metadata of the Field holding an extension
+ type in an IPC message.
+ The bytes are passed to ``__arrow_ext_deserialize__`` and should hold
+ sufficient information to reconstruct the data type instance.
+ """
+ raise NotImplementedError
+
+ @classmethod
+ def __arrow_ext_deserialize__(cls, storage_type, serialized):
+ """
+ Return an extension type instance from the storage type and serialized
+ metadata.
+
+ This method should return an instance of the ExtensionType subclass
+ that matches the passed storage type and serialized metadata (the
+ return value of ``__arrow_ext_serialize__``).
+ """
+ raise NotImplementedError
+
+ def __arrow_ext_class__(self):
+ """Return an extension array class to be used for building or
+ deserializing arrays with this extension type.
+
+ This method should return a subclass of the ExtensionArray class. By
+ default, if not specialized in the extension implementation, an
+ extension type array will be a built-in ExtensionArray instance.
+ """
+ return ExtensionArray
+
+
+cdef class PyExtensionType(ExtensionType):
+ """
+ Concrete base class for Python-defined extension types based on pickle
+ for (de)serialization.
+
+ Parameters
+ ----------
+ storage_type : DataType
+ The storage type for which the extension is built.
+ """
+
+ def __cinit__(self):
+ if type(self) is PyExtensionType:
+ raise TypeError("Can only instantiate subclasses of "
+ "PyExtensionType")
+
+ def __init__(self, DataType storage_type):
+ ExtensionType.__init__(self, storage_type, "arrow.py_extension_type")
+
+ def __reduce__(self):
+ raise NotImplementedError("Please implement {0}.__reduce__"
+ .format(type(self).__name__))
+
+ def __arrow_ext_serialize__(self):
+ return builtin_pickle.dumps(self)
+
+ @classmethod
+ def __arrow_ext_deserialize__(cls, storage_type, serialized):
+ try:
+ ty = builtin_pickle.loads(serialized)
+ except Exception:
+ # For some reason, it's impossible to deserialize the
+ # ExtensionType instance. Perhaps the serialized data is
+ # corrupt, or more likely the type is being deserialized
+ # in an environment where the original Python class or module
+ # is not available. Fall back on a generic BaseExtensionType.
+ return UnknownExtensionType(storage_type, serialized)
+
+ if ty.storage_type != storage_type:
+ raise TypeError("Expected storage type {0} but got {1}"
+ .format(ty.storage_type, storage_type))
+ return ty
+
+
+cdef class UnknownExtensionType(PyExtensionType):
+ """
+ A concrete class for Python-defined extension types that refer to
+ an unknown Python implementation.
+
+ Parameters
+ ----------
+ storage_type : DataType
+ The storage type for which the extension is built.
+ serialized : bytes
+ The serialized output.
+ """
+
+ cdef:
+ bytes serialized
+
+ def __init__(self, DataType storage_type, serialized):
+ self.serialized = serialized
+ PyExtensionType.__init__(self, storage_type)
+
+ def __arrow_ext_serialize__(self):
+ return self.serialized
+
+
+_python_extension_types_registry = []
+
+
+def register_extension_type(ext_type):
+ """
+ Register a Python extension type.
+
+ Registration is based on the extension name (so different registered types
+ need unique extension names). Registration needs an extension type
+ instance, but then works for any instance of the same subclass regardless
+ of parametrization of the type.
+
+ Parameters
+ ----------
+ ext_type : BaseExtensionType instance
+ The extension type instance to register.
+
+ """
+ cdef:
+ DataType _type = ensure_type(ext_type, allow_none=False)
+
+ if not isinstance(_type, BaseExtensionType):
+ raise TypeError("Only extension types can be registered")
+
+ # register on the C++ side
+ check_status(
+ RegisterPyExtensionType(<shared_ptr[CDataType]> _type.sp_type))
+
+ # register on the python side
+ _python_extension_types_registry.append(_type)
+
+
+def unregister_extension_type(type_name):
+ """
+ Unregister a Python extension type.
+
+ Parameters
+ ----------
+ type_name : str
+ The name of the ExtensionType subclass to unregister.
+
+ """
+ cdef:
+ c_string c_type_name = tobytes(type_name)
+ check_status(UnregisterPyExtensionType(c_type_name))
+
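+# A minimal registration sketch using the ExtensionType machinery above
+# (``RationalType`` and the ``"example.rational"`` name are hypothetical;
+# assumes ``import pyarrow as pa``):
+#
+#   class RationalType(pa.ExtensionType):
+#       def __init__(self):
+#           storage = pa.struct([('num', pa.int64()), ('den', pa.int64())])
+#           super().__init__(storage, 'example.rational')
+#
+#       def __arrow_ext_serialize__(self):
+#           return b''   # no parameters to serialize
+#
+#       @classmethod
+#       def __arrow_ext_deserialize__(cls, storage_type, serialized):
+#           return RationalType()
+#
+#   pa.register_extension_type(RationalType())
+#   # ... later, if no longer needed:
+#   pa.unregister_extension_type('example.rational')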
+
+cdef class KeyValueMetadata(_Metadata, Mapping):
+ """
+ KeyValueMetadata
+
+ Parameters
+ ----------
+ __arg0__ : dict or iterable of (key, value) pairs
+ The key-value metadata; pass a list of tuples to allow duplicate keys.
+ **kwargs : optional
+ additional key-value metadata
+ """
+
+ def __init__(self, __arg0__=None, **kwargs):
+ cdef:
+ vector[c_string] keys, values
+ shared_ptr[const CKeyValueMetadata] result
+
+ items = []
+ if __arg0__ is not None:
+ other = (__arg0__.items() if isinstance(__arg0__, Mapping)
+ else __arg0__)
+ items.extend((tobytes(k), v) for k, v in other)
+
+ prior_keys = {k for k, v in items}
+ for k, v in kwargs.items():
+ k = tobytes(k)
+ if k in prior_keys:
+ raise KeyError("Duplicate key {}, "
+ "use pass all items as list of tuples if you "
+ "intend to have duplicate keys")
+ items.append((k, v))
+
+ keys.reserve(len(items))
+ for key, value in items:
+ keys.push_back(tobytes(key))
+ values.push_back(tobytes(value))
+ result.reset(new CKeyValueMetadata(move(keys), move(values)))
+ self.init(result)
+
+ cdef void init(self, const shared_ptr[const CKeyValueMetadata]& wrapped):
+ self.wrapped = wrapped
+ self.metadata = wrapped.get()
+
+ @staticmethod
+ cdef wrap(const shared_ptr[const CKeyValueMetadata]& sp):
+ cdef KeyValueMetadata self = KeyValueMetadata.__new__(KeyValueMetadata)
+ self.init(sp)
+ return self
+
+ cdef inline shared_ptr[const CKeyValueMetadata] unwrap(self) nogil:
+ return self.wrapped
+
+ def equals(self, KeyValueMetadata other):
+ return self.metadata.Equals(deref(other.wrapped))
+
+ def __repr__(self):
+ return str(self)
+
+ def __str__(self):
+ return frombytes(self.metadata.ToString(), safe=True)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ pass
+
+ if isinstance(other, Mapping):
+ try:
+ other = KeyValueMetadata(other)
+ return self.equals(other)
+ except TypeError:
+ pass
+
+ return NotImplemented
+
+ def __len__(self):
+ return self.metadata.size()
+
+ def __contains__(self, key):
+ return self.metadata.Contains(tobytes(key))
+
+ def __getitem__(self, key):
+ return GetResultValue(self.metadata.Get(tobytes(key)))
+
+ def __iter__(self):
+ return self.keys()
+
+ def __reduce__(self):
+ return KeyValueMetadata, (list(self.items()),)
+
+ def key(self, i):
+ return self.metadata.key(i)
+
+ def value(self, i):
+ return self.metadata.value(i)
+
+ def keys(self):
+ for i in range(self.metadata.size()):
+ yield self.metadata.key(i)
+
+ def values(self):
+ for i in range(self.metadata.size()):
+ yield self.metadata.value(i)
+
+ def items(self):
+ for i in range(self.metadata.size()):
+ yield (self.metadata.key(i), self.metadata.value(i))
+
+ def get_all(self, key):
+ key = tobytes(key)
+ return [v for k, v in self.items() if k == key]
+
+ def to_dict(self):
+ """
+ Convert KeyValueMetadata to dict. If a key occurs twice, the value for
+ the first one is returned
+ """
+ cdef object key # to force coercion to Python
+ result = ordered_dict()
+ for i in range(self.metadata.size()):
+ key = self.metadata.key(i)
+ if key not in result:
+ result[key] = self.metadata.value(i)
+ return result
+
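+# Usage sketch for KeyValueMetadata (assumes ``import pyarrow as pa``; keys
+# and values are stored as bytes):
+#
+#   md = pa.KeyValueMetadata({'origin': 'sensor-1'}, unit='cm')
+#   md['unit']     # -> b'cm'
+#   md.to_dict()   # -> {b'origin': b'sensor-1', b'unit': b'cm'}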
+
+cdef KeyValueMetadata ensure_metadata(object meta, c_bool allow_none=False):
+ if allow_none and meta is None:
+ return None
+ elif isinstance(meta, KeyValueMetadata):
+ return meta
+ else:
+ return KeyValueMetadata(meta)
+
+
+cdef class Field(_Weakrefable):
+ """
+ A named field, with a data type, nullability, and optional metadata.
+
+ Notes
+ -----
+ Do not use this class's constructor directly; use pyarrow.field
+ """
+
+ def __cinit__(self):
+ pass
+
+ def __init__(self):
+ raise TypeError("Do not call Field's constructor directly, use "
+ "`pyarrow.field` instead.")
+
+ cdef void init(self, const shared_ptr[CField]& field):
+ self.sp_field = field
+ self.field = field.get()
+ self.type = pyarrow_wrap_data_type(field.get().type())
+
+ def equals(self, Field other, bint check_metadata=False):
+ """
+ Test if this field is equal to the other
+
+ Parameters
+ ----------
+ other : pyarrow.Field
+ check_metadata : bool, default False
+ Whether Field metadata equality should be checked as well.
+
+ Returns
+ -------
+ is_equal : bool
+ """
+ return self.field.Equals(deref(other.field), check_metadata)
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def __reduce__(self):
+ return field, (self.name, self.type, self.nullable, self.metadata)
+
+ def __str__(self):
+ return 'pyarrow.Field<{0}>'.format(
+ frombytes(self.field.ToString(), safe=True))
+
+ def __repr__(self):
+ return self.__str__()
+
+ def __hash__(self):
+ return hash((self.field.name(), self.type, self.field.nullable()))
+
+ @property
+ def nullable(self):
+ return self.field.nullable()
+
+ @property
+ def name(self):
+ return frombytes(self.field.name())
+
+ @property
+ def metadata(self):
+ wrapped = pyarrow_wrap_metadata(self.field.metadata())
+ if wrapped is not None:
+ return wrapped.to_dict()
+ else:
+ return wrapped
+
+ def add_metadata(self, metadata):
+ warnings.warn("The 'add_metadata' method is deprecated, use "
+ "'with_metadata' instead", FutureWarning, stacklevel=2)
+ return self.with_metadata(metadata)
+
+ def with_metadata(self, metadata):
+ """
+ Add metadata as dict of string keys and values to Field
+
+ Parameters
+ ----------
+ metadata : dict
+ Keys and values must be string-like / coercible to bytes
+
+ Returns
+ -------
+ field : pyarrow.Field
+ """
+ cdef shared_ptr[CField] c_field
+
+ meta = ensure_metadata(metadata, allow_none=False)
+ with nogil:
+ c_field = self.field.WithMetadata(meta.unwrap())
+
+ return pyarrow_wrap_field(c_field)
+
+ def remove_metadata(self):
+ """
+ Create new field without metadata, if any
+
+ Returns
+ -------
+ field : pyarrow.Field
+ """
+ cdef shared_ptr[CField] new_field
+ with nogil:
+ new_field = self.field.RemoveMetadata()
+ return pyarrow_wrap_field(new_field)
+
+ def with_type(self, DataType new_type):
+ """
+ A copy of this field with the replaced type
+
+ Parameters
+ ----------
+ new_type : pyarrow.DataType
+
+ Returns
+ -------
+ field : pyarrow.Field
+ """
+ cdef:
+ shared_ptr[CField] c_field
+ shared_ptr[CDataType] c_datatype
+
+ c_datatype = pyarrow_unwrap_data_type(new_type)
+ with nogil:
+ c_field = self.field.WithType(c_datatype)
+
+ return pyarrow_wrap_field(c_field)
+
+ def with_name(self, name):
+ """
+ A copy of this field with the replaced name
+
+ Parameters
+ ----------
+ name : str
+
+ Returns
+ -------
+ field : pyarrow.Field
+ """
+ cdef:
+ shared_ptr[CField] c_field
+
+ c_field = self.field.WithName(tobytes(name))
+
+ return pyarrow_wrap_field(c_field)
+
+ def with_nullable(self, nullable):
+ """
+ A copy of this field with the replaced nullability
+
+ Parameters
+ ----------
+ nullable : bool
+
+ Returns
+ -------
+ field: pyarrow.Field
+ """
+ cdef:
+ shared_ptr[CField] c_field
+ c_bool c_nullable
+
+ c_nullable = bool(nullable)
+ with nogil:
+ c_field = self.field.WithNullable(c_nullable)
+
+ return pyarrow_wrap_field(c_field)
+
+ def flatten(self):
+ """
+ Flatten this field. If a struct field, individual child fields
+ will be returned with their names prefixed by the parent's name.
+
+ Returns
+ -------
+ fields : List[pyarrow.Field]
+ """
+ cdef vector[shared_ptr[CField]] flattened
+ with nogil:
+ flattened = self.field.Flatten()
+ return [pyarrow_wrap_field(f) for f in flattened]
+
+ def _export_to_c(self, uintptr_t out_ptr):
+ """
+ Export to a C ArrowSchema struct, given its pointer.
+
+ Be careful: if you don't pass the ArrowSchema struct to a consumer,
+ its memory will leak. This is a low-level function intended for
+ expert users.
+ """
+ check_status(ExportField(deref(self.field), <ArrowSchema*> out_ptr))
+
+ @staticmethod
+ def _import_from_c(uintptr_t in_ptr):
+ """
+ Import Field from a C ArrowSchema struct, given its pointer.
+
+ This is a low-level function intended for expert users.
+ """
+ with nogil:
+ result = GetResultValue(ImportField(<ArrowSchema*> in_ptr))
+ return pyarrow_wrap_field(result)
+
+
+cdef class Schema(_Weakrefable):
+
+ def __cinit__(self):
+ pass
+
+ def __init__(self):
+ raise TypeError("Do not call Schema's constructor directly, use "
+ "`pyarrow.schema` instead.")
+
+ def __len__(self):
+ return self.schema.num_fields()
+
+ def __getitem__(self, key):
+ # access by integer index
+ return self._field(key)
+
+ def __iter__(self):
+ for i in range(len(self)):
+ yield self[i]
+
+ cdef void init(self, const vector[shared_ptr[CField]]& fields):
+ self.schema = new CSchema(fields)
+ self.sp_schema.reset(self.schema)
+
+ cdef void init_schema(self, const shared_ptr[CSchema]& schema):
+ self.schema = schema.get()
+ self.sp_schema = schema
+
+ def __reduce__(self):
+ return schema, (list(self), self.metadata)
+
+ def __hash__(self):
+ return hash((tuple(self), self.metadata))
+
+ def __sizeof__(self):
+ size = 0
+ if self.metadata:
+ for key, value in self.metadata.items():
+ size += sys.getsizeof(key)
+ size += sys.getsizeof(value)
+
+ return size + super(Schema, self).__sizeof__()
+
+ @property
+ def pandas_metadata(self):
+ """
+ Return deserialized-from-JSON pandas metadata field (if it exists)
+ """
+ metadata = self.metadata
+ key = b'pandas'
+ if metadata is None or key not in metadata:
+ return None
+
+ import json
+ return json.loads(metadata[key].decode('utf8'))
+
+ @property
+ def names(self):
+ """
+ The schema's field names.
+
+ Returns
+ -------
+ list of str
+ """
+ cdef int i
+ result = []
+ for i in range(self.schema.num_fields()):
+ name = frombytes(self.schema.field(i).get().name())
+ result.append(name)
+ return result
+
+ @property
+ def types(self):
+ """
+ The schema's field types.
+
+ Returns
+ -------
+ list of DataType
+ """
+ return [field.type for field in self]
+
+ @property
+ def metadata(self):
+ wrapped = pyarrow_wrap_metadata(self.schema.metadata())
+ if wrapped is not None:
+ return wrapped.to_dict()
+ else:
+ return wrapped
+
+ def __eq__(self, other):
+ try:
+ return self.equals(other)
+ except TypeError:
+ return NotImplemented
+
+ def empty_table(self):
+ """
+ Provide an empty table according to the schema.
+
+ Returns
+ -------
+ table: pyarrow.Table
+ """
+ arrays = [_empty_array(field.type) for field in self]
+ return Table.from_arrays(arrays, schema=self)
+
+ def equals(self, Schema other not None, bint check_metadata=False):
+ """
+ Test if this schema is equal to the other
+
+ Parameters
+ ----------
+ other : pyarrow.Schema
+ check_metadata : bool, default False
+ Key/value metadata must be equal too
+
+ Returns
+ -------
+ is_equal : bool
+ """
+ return self.sp_schema.get().Equals(deref(other.schema),
+ check_metadata)
+
+ @classmethod
+ def from_pandas(cls, df, preserve_index=None):
+ """
+ Returns implied schema from dataframe
+
+ Parameters
+ ----------
+ df : pandas.DataFrame
+ preserve_index : bool, optional
+ Whether to store the index as an additional column (or columns, for
+ MultiIndex) in the resulting `Table`.
+ The default of None will store the index as a column, except for
+ RangeIndex which is stored as metadata only. Use
+ ``preserve_index=True`` to force it to be stored as a column.
+
+ Returns
+ -------
+ pyarrow.Schema
+
+ Examples
+ --------
+
+ >>> import pandas as pd
+ >>> import pyarrow as pa
+ >>> df = pd.DataFrame({
+ ... 'int': [1, 2],
+ ... 'str': ['a', 'b']
+ ... })
+ >>> pa.Schema.from_pandas(df)
+ int: int64
+ str: string
+ __index_level_0__: int64
+ """
+ from pyarrow.pandas_compat import dataframe_to_types
+ names, types, metadata = dataframe_to_types(
+ df,
+ preserve_index=preserve_index
+ )
+ fields = []
+ for name, type_ in zip(names, types):
+ fields.append(field(name, type_))
+ return schema(fields, metadata)
+
+ def field(self, i):
+ """
+ Select a field by its column name or numeric index.
+
+ Parameters
+ ----------
+ i : int or string
+
+ Returns
+ -------
+ pyarrow.Field
+ """
+ if isinstance(i, (bytes, str)):
+ field_index = self.get_field_index(i)
+ if field_index < 0:
+ raise KeyError("Column {} does not exist in schema".format(i))
+ else:
+ return self._field(field_index)
+ elif isinstance(i, int):
+ return self._field(i)
+ else:
+ raise TypeError("Index must either be string or integer")
+
+ def _field(self, int i):
+ """Select a field by its numeric index."""
+ cdef int index = <int> _normalize_index(i, self.schema.num_fields())
+ return pyarrow_wrap_field(self.schema.field(index))
+
+ def field_by_name(self, name):
+ """
+ Access a field by its name rather than the column index.
+
+ Parameters
+ ----------
+ name: str
+
+ Returns
+ -------
+ field: pyarrow.Field
+ """
+ cdef:
+ vector[shared_ptr[CField]] results
+
+ warnings.warn(
+ "The 'field_by_name' method is deprecated, use 'field' instead",
+ FutureWarning, stacklevel=2)
+
+ results = self.schema.GetAllFieldsByName(tobytes(name))
+ if results.size() == 0:
+ return None
+ elif results.size() > 1:
+ warnings.warn("Schema field name corresponds to more "
+ "than one field", UserWarning)
+ return None
+ else:
+ return pyarrow_wrap_field(results[0])
+
+ def get_field_index(self, name):
+ """
+ Return index of field with given unique name. Returns -1 if not found
+ or if duplicated
+ """
+ return self.schema.GetFieldIndex(tobytes(name))
+
+ def get_all_field_indices(self, name):
+ """
+ Return sorted list of indices for fields with the given name
+ """
+ return self.schema.GetAllFieldIndices(tobytes(name))
+
+ def append(self, Field field):
+ """
+ Append a field at the end of the schema.
+
+ In contrast to Python's ``list.append()``, it returns a new
+ object, leaving the original Schema unmodified.
+
+ Parameters
+ ----------
+ field: Field
+
+ Returns
+ -------
+ schema: Schema
+ New object with appended field.
+ """
+ return self.insert(self.schema.num_fields(), field)
+
+ def insert(self, int i, Field field):
+ """
+ Add a field at position i to the schema.
+
+ Parameters
+ ----------
+ i: int
+ field: Field
+
+ Returns
+ -------
+ schema: Schema
+ """
+ cdef:
+ shared_ptr[CSchema] new_schema
+ shared_ptr[CField] c_field
+
+ c_field = field.sp_field
+
+ with nogil:
+ new_schema = GetResultValue(self.schema.AddField(i, c_field))
+
+ return pyarrow_wrap_schema(new_schema)
+
+ def remove(self, int i):
+ """
+ Remove the field at index i from the schema.
+
+ Parameters
+ ----------
+ i: int
+
+ Returns
+ -------
+ schema: Schema
+ """
+ cdef shared_ptr[CSchema] new_schema
+
+ with nogil:
+ new_schema = GetResultValue(self.schema.RemoveField(i))
+
+ return pyarrow_wrap_schema(new_schema)
+
+ def set(self, int i, Field field):
+ """
+ Replace a field at position i in the schema.
+
+ Parameters
+ ----------
+ i: int
+ field: Field
+
+ Returns
+ -------
+ schema: Schema
+ """
+ cdef:
+ shared_ptr[CSchema] new_schema
+ shared_ptr[CField] c_field
+
+ c_field = field.sp_field
+
+ with nogil:
+ new_schema = GetResultValue(self.schema.SetField(i, c_field))
+
+ return pyarrow_wrap_schema(new_schema)
+
+ def add_metadata(self, metadata):
+ warnings.warn("The 'add_metadata' method is deprecated, use "
+ "'with_metadata' instead", FutureWarning, stacklevel=2)
+ return self.with_metadata(metadata)
+
+ def with_metadata(self, metadata):
+ """
+ Add metadata as dict of string keys and values to Schema
+
+ Parameters
+ ----------
+ metadata : dict
+ Keys and values must be string-like / coercible to bytes
+
+ Returns
+ -------
+ schema : pyarrow.Schema
+ """
+ cdef shared_ptr[CSchema] c_schema
+
+ meta = ensure_metadata(metadata, allow_none=False)
+ with nogil:
+ c_schema = self.schema.WithMetadata(meta.unwrap())
+
+ return pyarrow_wrap_schema(c_schema)
+
+ def serialize(self, memory_pool=None):
+ """
+ Write Schema to Buffer as encapsulated IPC message
+
+ Parameters
+ ----------
+ memory_pool : MemoryPool, default None
+ Uses default memory pool if not specified
+
+ Returns
+ -------
+ serialized : Buffer
+ """
+ cdef:
+ shared_ptr[CBuffer] buffer
+ CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
+
+ with nogil:
+ buffer = GetResultValue(SerializeSchema(deref(self.schema),
+ pool))
+ return pyarrow_wrap_buffer(buffer)
+
+ def remove_metadata(self):
+ """
+ Create new schema without metadata, if any
+
+ Returns
+ -------
+ schema : pyarrow.Schema
+ """
+ cdef shared_ptr[CSchema] new_schema
+ with nogil:
+ new_schema = self.schema.RemoveMetadata()
+ return pyarrow_wrap_schema(new_schema)
+
+ def to_string(self, truncate_metadata=True, show_field_metadata=True,
+ show_schema_metadata=True):
+ """
+ Return human-readable representation of Schema
+
+ Parameters
+ ----------
+ truncate_metadata : boolean, default True
+ Limit metadata key/value display to a single line of ~80 characters
+ or less
+ show_field_metadata : boolean, default True
+ Display Field-level KeyValueMetadata
+ show_schema_metadata : boolean, default True
+ Display Schema-level KeyValueMetadata
+
+ Returns
+ -------
+ str : the formatted output
+ """
+ cdef:
+ c_string result
+ PrettyPrintOptions options = PrettyPrintOptions.Defaults()
+
+ options.indent = 0
+ options.truncate_metadata = truncate_metadata
+ options.show_field_metadata = show_field_metadata
+ options.show_schema_metadata = show_schema_metadata
+
+ with nogil:
+ check_status(
+ PrettyPrint(
+ deref(self.schema),
+ options,
+ &result
+ )
+ )
+
+ return frombytes(result, safe=True)
+
+ def _export_to_c(self, uintptr_t out_ptr):
+ """
+ Export to a C ArrowSchema struct, given its pointer.
+
+ Be careful: if you don't pass the ArrowSchema struct to a consumer,
+ its memory will leak. This is a low-level function intended for
+ expert users.
+ """
+ check_status(ExportSchema(deref(self.schema), <ArrowSchema*> out_ptr))
+
+ @staticmethod
+ def _import_from_c(uintptr_t in_ptr):
+ """
+ Import Schema from a C ArrowSchema struct, given its pointer.
+
+ This is a low-level function intended for expert users.
+ """
+ with nogil:
+ result = GetResultValue(ImportSchema(<ArrowSchema*> in_ptr))
+ return pyarrow_wrap_schema(result)
+
+ def __str__(self):
+ return self.to_string()
+
+ def __repr__(self):
+ return self.__str__()
+
+
+def unify_schemas(schemas):
+ """
+ Unify schemas by merging fields by name.
+
+ The resulting schema will contain the union of fields from all schemas.
+ Fields with the same name will be merged. Note that two fields with
+ different types will fail merging.
+
+ - The unified field will inherit the metadata from the schema where
+ that field is first defined.
+ - The first N fields in the schema will be ordered the same as the
+ N fields in the first schema.
+
+ The resulting schema will inherit its metadata from the first input
+ schema.
+
+ Parameters
+ ----------
+ schemas : list of Schema
+ Schemas to merge into a single one.
+
+ Returns
+ -------
+ Schema
+
+ Raises
+ ------
+ ArrowInvalid :
+ If any input schema contains fields with duplicate names.
+ If Fields of the same name are not mergeable.
+ """
+ cdef:
+ Schema schema
+ vector[shared_ptr[CSchema]] c_schemas
+ for schema in schemas:
+ c_schemas.push_back(pyarrow_unwrap_schema(schema))
+ return pyarrow_wrap_schema(GetResultValue(UnifySchemas(c_schemas)))
+
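+# Usage sketch for unify_schemas (assumes ``import pyarrow as pa``):
+#
+#   a = pa.schema([('id', pa.int64())])
+#   b = pa.schema([('id', pa.int64()), ('name', pa.string())])
+#   pa.unify_schemas([a, b]).names   # -> ['id', 'name']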
+
+cdef dict _type_cache = {}
+
+
+cdef DataType primitive_type(Type type):
+ if type in _type_cache:
+ return _type_cache[type]
+
+ cdef DataType out = DataType.__new__(DataType)
+ out.init(GetPrimitiveType(type))
+
+ _type_cache[type] = out
+ return out
+
+
+# -----------------------------------------------------------
+# Type factory functions
+
+
+def field(name, type, bint nullable=True, metadata=None):
+ """
+ Create a pyarrow.Field instance.
+
+ Parameters
+ ----------
+ name : str or bytes
+ Name of the field.
+ type : pyarrow.DataType
+ Arrow datatype of the field.
+ nullable : bool, default True
+ Whether the field's values are nullable.
+ metadata : dict, default None
+ Optional field metadata, the keys and values must be coercible to
+ bytes.
+
+ Returns
+ -------
+ field : pyarrow.Field
+ """
+ cdef:
+ Field result = Field.__new__(Field)
+ DataType _type = ensure_type(type, allow_none=False)
+ shared_ptr[const CKeyValueMetadata] c_meta
+
+ metadata = ensure_metadata(metadata, allow_none=True)
+ c_meta = pyarrow_unwrap_metadata(metadata)
+
+ if _type.type.id() == _Type_NA and not nullable:
+ raise ValueError("A null type field may not be non-nullable")
+
+ result.sp_field.reset(
+ new CField(tobytes(name), _type.sp_type, nullable, c_meta)
+ )
+ result.field = result.sp_field.get()
+ result.type = _type
+
+ return result
+
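+# Usage sketch for field() (assumes ``import pyarrow as pa``; metadata
+# values come back as bytes):
+#
+#   f = pa.field('x', pa.int64(), nullable=False, metadata={'origin': 'sensor'})
+#   f.nullable   # -> False
+#   f.metadata   # -> {b'origin': b'sensor'}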
+
+cdef set PRIMITIVE_TYPES = set([
+ _Type_NA, _Type_BOOL,
+ _Type_UINT8, _Type_INT8,
+ _Type_UINT16, _Type_INT16,
+ _Type_UINT32, _Type_INT32,
+ _Type_UINT64, _Type_INT64,
+ _Type_TIMESTAMP, _Type_DATE32,
+ _Type_TIME32, _Type_TIME64,
+ _Type_DATE64,
+ _Type_HALF_FLOAT,
+ _Type_FLOAT,
+ _Type_DOUBLE])
+
+
+def null():
+ """
+ Create instance of null type.
+ """
+ return primitive_type(_Type_NA)
+
+
+def bool_():
+ """
+ Create instance of boolean type.
+ """
+ return primitive_type(_Type_BOOL)
+
+
+def uint8():
+ """
+ Create instance of unsigned int8 type.
+ """
+ return primitive_type(_Type_UINT8)
+
+
+def int8():
+ """
+ Create instance of signed int8 type.
+ """
+ return primitive_type(_Type_INT8)
+
+
+def uint16():
+ """
+ Create instance of unsigned int16 type.
+ """
+ return primitive_type(_Type_UINT16)
+
+
+def int16():
+ """
+ Create instance of signed int16 type.
+ """
+ return primitive_type(_Type_INT16)
+
+
+def uint32():
+ """
+ Create instance of unsigned int32 type.
+ """
+ return primitive_type(_Type_UINT32)
+
+
+def int32():
+ """
+ Create instance of signed int32 type.
+ """
+ return primitive_type(_Type_INT32)
+
+
+def uint64():
+ """
+ Create instance of unsigned int64 type.
+ """
+ return primitive_type(_Type_UINT64)
+
+
+def int64():
+ """
+ Create instance of signed int64 type.
+ """
+ return primitive_type(_Type_INT64)
+
+
+cdef dict _timestamp_type_cache = {}
+cdef dict _time_type_cache = {}
+cdef dict _duration_type_cache = {}
+
+
+cdef timeunit_to_string(TimeUnit unit):
+ if unit == TimeUnit_SECOND:
+ return 's'
+ elif unit == TimeUnit_MILLI:
+ return 'ms'
+ elif unit == TimeUnit_MICRO:
+ return 'us'
+ elif unit == TimeUnit_NANO:
+ return 'ns'
+
+
+cdef TimeUnit string_to_timeunit(unit) except *:
+ if unit == 's':
+ return TimeUnit_SECOND
+ elif unit == 'ms':
+ return TimeUnit_MILLI
+ elif unit == 'us':
+ return TimeUnit_MICRO
+ elif unit == 'ns':
+ return TimeUnit_NANO
+ else:
+ raise ValueError(f"Invalid time unit: {unit!r}")
+
+
+def tzinfo_to_string(tz):
+ """
+ Convert a time zone object into a string indicating the name of a time
+ zone, one of:
+ * A name as used in the Olson time zone database (the "tz database" or
+ "tzdata"), such as "America/New_York"
+ * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30
+
+ Parameters
+ ----------
+ tz : datetime.tzinfo
+ Time zone object
+
+ Returns
+ -------
+ name : str
+ Time zone name
+ """
+ return frombytes(GetResultValue(TzinfoToString(<PyObject*>tz)))
+
+
+def string_to_tzinfo(name):
+ """
+ Convert a time zone name into a time zone object.
+
+ Supported input strings are:
+ * A name as used in the Olson time zone database (the "tz database" or
+ "tzdata"), such as "America/New_York"
+ * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30
+
+ Parameters
+ ----------
+ name: str
+ Time zone name.
+
+ Returns
+ -------
+ tz : datetime.tzinfo
+ Time zone object
+ """
+ cdef PyObject* tz = GetResultValue(StringToTzinfo(name.encode('utf-8')))
+ return PyObject_to_object(tz)
+
+
+def timestamp(unit, tz=None):
+ """
+ Create instance of timestamp type with resolution and optional time zone.
+
+ Parameters
+ ----------
+ unit : str
+ one of 's' [second], 'ms' [millisecond], 'us' [microsecond], or 'ns'
+ [nanosecond]
+ tz : str, default None
+ Time zone name. None indicates time zone naive
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> pa.timestamp('us')
+ TimestampType(timestamp[us])
+ >>> pa.timestamp('s', tz='America/New_York')
+ TimestampType(timestamp[s, tz=America/New_York])
+ >>> pa.timestamp('s', tz='+07:30')
+ TimestampType(timestamp[s, tz=+07:30])
+
+ Returns
+ -------
+ timestamp_type : TimestampType
+ """
+ cdef:
+ TimeUnit unit_code
+ c_string c_timezone
+
+ unit_code = string_to_timeunit(unit)
+
+ cdef TimestampType out = TimestampType.__new__(TimestampType)
+
+ if tz is None:
+ out.init(ctimestamp(unit_code))
+ if unit_code in _timestamp_type_cache:
+ return _timestamp_type_cache[unit_code]
+ _timestamp_type_cache[unit_code] = out
+ else:
+ if not isinstance(tz, (bytes, str)):
+ tz = tzinfo_to_string(tz)
+
+ c_timezone = tobytes(tz)
+ out.init(ctimestamp(unit_code, c_timezone))
+
+ return out
+
+
+def time32(unit):
+ """
+ Create instance of 32-bit time (time of day) type with unit resolution.
+
+ Parameters
+ ----------
+ unit : str
+ one of 's' [second], or 'ms' [millisecond]
+
+ Returns
+ -------
+ type : pyarrow.Time32Type
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> pa.time32('s')
+ Time32Type(time32[s])
+ >>> pa.time32('ms')
+ Time32Type(time32[ms])
+ """
+ cdef TimeUnit unit_code
+
+ if unit == 's':
+ unit_code = TimeUnit_SECOND
+ elif unit == 'ms':
+ unit_code = TimeUnit_MILLI
+ else:
+ raise ValueError(f"Invalid time unit for time32: {unit!r}")
+
+ if unit_code in _time_type_cache:
+ return _time_type_cache[unit_code]
+
+ cdef Time32Type out = Time32Type.__new__(Time32Type)
+
+ out.init(ctime32(unit_code))
+ _time_type_cache[unit_code] = out
+
+ return out
+
+
+def time64(unit):
+ """
+ Create instance of 64-bit time (time of day) type with unit resolution.
+
+ Parameters
+ ----------
+ unit : str
+ One of 'us' [microsecond], or 'ns' [nanosecond].
+
+ Returns
+ -------
+ type : pyarrow.Time64Type
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> pa.time64('us')
+ Time64Type(time64[us])
+ >>> pa.time64('ns')
+ Time64Type(time64[ns])
+ """
+ cdef TimeUnit unit_code
+
+ if unit == 'us':
+ unit_code = TimeUnit_MICRO
+ elif unit == 'ns':
+ unit_code = TimeUnit_NANO
+ else:
+ raise ValueError(f"Invalid time unit for time64: {unit!r}")
+
+ if unit_code in _time_type_cache:
+ return _time_type_cache[unit_code]
+
+ cdef Time64Type out = Time64Type.__new__(Time64Type)
+
+ out.init(ctime64(unit_code))
+ _time_type_cache[unit_code] = out
+
+ return out
+
+
+def duration(unit):
+ """
+ Create instance of a duration type with unit resolution.
+
+ Parameters
+ ----------
+ unit : str
+ One of 's' [second], 'ms' [millisecond], 'us' [microsecond], or
+ 'ns' [nanosecond].
+
+ Returns
+ -------
+ type : pyarrow.DurationType
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> pa.duration('us')
+ DurationType(duration[us])
+ >>> pa.duration('s')
+ DurationType(duration[s])
+ """
+ cdef:
+ TimeUnit unit_code
+
+ unit_code = string_to_timeunit(unit)
+
+ if unit_code in _duration_type_cache:
+ return _duration_type_cache[unit_code]
+
+ cdef DurationType out = DurationType.__new__(DurationType)
+
+ out.init(cduration(unit_code))
+ _duration_type_cache[unit_code] = out
+
+ return out
+
+
+def month_day_nano_interval():
+ """
+ Create instance of an interval type representing months, days and
+ nanoseconds between two dates.
+ """
+ return primitive_type(_Type_INTERVAL_MONTH_DAY_NANO)
+
+
+def date32():
+ """
+ Create instance of 32-bit date (days since UNIX epoch 1970-01-01).
+ """
+ return primitive_type(_Type_DATE32)
+
+
+def date64():
+ """
+ Create instance of 64-bit date (milliseconds since UNIX epoch 1970-01-01).
+ """
+ return primitive_type(_Type_DATE64)
+
+
+def float16():
+ """
+ Create half-precision floating point type.
+ """
+ return primitive_type(_Type_HALF_FLOAT)
+
+
+def float32():
+ """
+ Create single-precision floating point type.
+ """
+ return primitive_type(_Type_FLOAT)
+
+
+def float64():
+ """
+ Create double-precision floating point type.
+ """
+ return primitive_type(_Type_DOUBLE)
+
+
+cpdef DataType decimal128(int precision, int scale=0):
+ """
+ Create decimal type with precision and scale and 128-bit width.
+
+ Arrow decimals are fixed-point decimal numbers encoded as a scaled
+ integer. The precision is the number of significant digits that the
+ decimal type can represent; the scale is the number of digits after
+ the decimal point (note the scale can be negative).
+
+ As an example, ``decimal128(7, 3)`` can exactly represent the numbers
+ 1234.567 and -1234.567 (encoded internally as the 128-bit integers
+ 1234567 and -1234567, respectively), but neither 12345.67 nor 123.4567.
+
+ ``decimal128(5, -3)`` can exactly represent the number 12345000
+ (encoded internally as the 128-bit integer 12345), but neither
+ 123450000 nor 1234500.
+
+ If you need a precision higher than 38 significant digits, consider
+ using ``decimal256``.
+
+ Parameters
+ ----------
+ precision : int
+ Must be between 1 and 38
+ scale : int
+
+ Returns
+ -------
+ decimal_type : Decimal128Type
+ """
+ cdef shared_ptr[CDataType] decimal_type
+ if precision < 1 or precision > 38:
+ raise ValueError("precision should be between 1 and 38")
+ decimal_type.reset(new CDecimal128Type(precision, scale))
+ return pyarrow_wrap_data_type(decimal_type)
+
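+# Usage sketch for decimal128 (assumes ``import pyarrow as pa``):
+#
+#   t = pa.decimal128(7, 3)
+#   (t.precision, t.scale)   # -> (7, 3)
+#   t.byte_width             # -> 16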
+
+cpdef DataType decimal256(int precision, int scale=0):
+ """
+ Create decimal type with precision and scale and 256-bit width.
+
+ Arrow decimals are fixed-point decimal numbers encoded as a scaled
+ integer. The precision is the number of significant digits that the
+ decimal type can represent; the scale is the number of digits after
+ the decimal point (note the scale can be negative).
+
+ For most use cases, the maximum precision offered by ``decimal128``
+ is sufficient, and it will result in a more compact and more efficient
+ encoding. ``decimal256`` is useful if you need a precision higher
+ than 38 significant digits.
+
+ Parameters
+ ----------
+ precision : int
+ Must be between 1 and 76
+ scale : int
+
+ Returns
+ -------
+ decimal_type : Decimal256Type
+ """
+ cdef shared_ptr[CDataType] decimal_type
+ if precision < 1 or precision > 76:
+ raise ValueError("precision should be between 1 and 76")
+ decimal_type.reset(new CDecimal256Type(precision, scale))
+ return pyarrow_wrap_data_type(decimal_type)
+
+
+def string():
+ """
+ Create UTF8 variable-length string type.
+ """
+ return primitive_type(_Type_STRING)
+
+
+def utf8():
+ """
+ Alias for string().
+ """
+ return string()
+
+
+def binary(int length=-1):
+ """
+ Create variable-length binary type.
+
+ Parameters
+ ----------
+ length : int, optional, default -1
+ If length == -1 then return a variable length binary type. If length is
+ greater than or equal to 0 then return a fixed size binary type of
+ width `length`.
+ """
+ if length == -1:
+ return primitive_type(_Type_BINARY)
+
+ cdef shared_ptr[CDataType] fixed_size_binary_type
+ fixed_size_binary_type.reset(new CFixedSizeBinaryType(length))
+ return pyarrow_wrap_data_type(fixed_size_binary_type)
+
+
+def large_binary():
+ """
+ Create large variable-length binary type.
+
+ This data type may not be supported by all Arrow implementations. Unless
+ you need to represent data larger than 2GB, you should prefer binary().
+ """
+ return primitive_type(_Type_LARGE_BINARY)
+
+
+def large_string():
+ """
+ Create large UTF8 variable-length string type.
+
+ This data type may not be supported by all Arrow implementations. Unless
+ you need to represent data larger than 2GB, you should prefer string().
+ """
+ return primitive_type(_Type_LARGE_STRING)
+
+
+def large_utf8():
+ """
+ Alias for large_string().
+ """
+ return large_string()
+
+
+def list_(value_type, int list_size=-1):
+ """
+ Create ListType instance from child data type or field.
+
+ Parameters
+ ----------
+ value_type : DataType or Field
+ list_size : int, optional, default -1
+ If list_size == -1 then return a variable length list type. If list_size
+ is greater than or equal to 0 then return a fixed size list type.
+
+ Returns
+ -------
+ list_type : DataType
+ """
+ cdef:
+ Field _field
+ shared_ptr[CDataType] list_type
+
+ if isinstance(value_type, DataType):
+ _field = field('item', value_type)
+ elif isinstance(value_type, Field):
+ _field = value_type
+ else:
+ raise TypeError('List requires DataType or Field')
+
+ if list_size == -1:
+ list_type.reset(new CListType(_field.sp_field))
+ else:
+ if list_size < 0:
+ raise ValueError("list_size should be a positive integer")
+ list_type.reset(new CFixedSizeListType(_field.sp_field, list_size))
+
+ return pyarrow_wrap_data_type(list_type)
+
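+# Usage sketch for list_() (assumes ``import pyarrow as pa``; reprs may vary
+# between versions):
+#
+#   pa.list_(pa.int32())      # variable-size -> ListType(list<item: int32>)
+#   pa.list_(pa.int32(), 3)   # fixed-size -> FixedSizeListType(fixed_size_list<item: int32>[3])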
+
+cpdef LargeListType large_list(value_type):
+ """
+ Create LargeListType instance from child data type or field.
+
+ This data type may not be supported by all Arrow implementations.
+ Unless you need to represent data larger than 2**31 elements, you should
+ prefer list_().
+
+ Parameters
+ ----------
+ value_type : DataType or Field
+
+ Returns
+ -------
+ list_type : DataType
+ """
+ cdef:
+ DataType data_type
+ Field _field
+ shared_ptr[CDataType] list_type
+ LargeListType out = LargeListType.__new__(LargeListType)
+
+ if isinstance(value_type, DataType):
+ _field = field('item', value_type)
+ elif isinstance(value_type, Field):
+ _field = value_type
+ else:
+ raise TypeError('List requires DataType or Field')
+
+ list_type.reset(new CLargeListType(_field.sp_field))
+ out.init(list_type)
+ return out
+
+
+cpdef MapType map_(key_type, item_type, keys_sorted=False):
+ """
+ Create MapType instance from key and item data types or fields.
+
+ Parameters
+ ----------
+ key_type : DataType
+ item_type : DataType
+ keys_sorted : bool
+
+ Returns
+ -------
+ map_type : DataType
+ """
+ cdef:
+ Field _key_field
+ Field _item_field
+ shared_ptr[CDataType] map_type
+ MapType out = MapType.__new__(MapType)
+
+ if isinstance(key_type, Field):
+ if key_type.nullable:
+ raise TypeError('Map key field should be non-nullable')
+ _key_field = key_type
+ else:
+ _key_field = field('key', ensure_type(key_type, allow_none=False),
+ nullable=False)
+
+ if isinstance(item_type, Field):
+ _item_field = item_type
+ else:
+ _item_field = field('value', ensure_type(item_type, allow_none=False))
+
+ map_type.reset(new CMapType(_key_field.sp_field, _item_field.sp_field,
+ keys_sorted))
+ out.init(map_type)
+ return out
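+
+
+# A brief sketch of map_(), assuming ``import pyarrow as pa``; the key field is
+# forced to be non-nullable as enforced above, and the repr and the
+# ``keys_sorted`` attribute access are illustrative.
+#
+# >>> import pyarrow as pa
+# >>> pa.map_(pa.string(), pa.int32())
+# MapType(map<string, int32>)
+# >>> pa.map_(pa.string(), pa.int32(), keys_sorted=True).keys_sorted
+# True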
+
+
+cpdef DictionaryType dictionary(index_type, value_type, bint ordered=False):
+ """
+ Dictionary (categorical, or simply encoded) type.
+
+ Parameters
+ ----------
+ index_type : DataType
+ value_type : DataType
+ ordered : bool
+
+ Returns
+ -------
+ type : DictionaryType
+ """
+ cdef:
+ DataType _index_type = ensure_type(index_type, allow_none=False)
+ DataType _value_type = ensure_type(value_type, allow_none=False)
+ DictionaryType out = DictionaryType.__new__(DictionaryType)
+ shared_ptr[CDataType] dict_type
+
+ if _index_type.id not in {
+ Type_INT8, Type_INT16, Type_INT32, Type_INT64,
+ Type_UINT8, Type_UINT16, Type_UINT32, Type_UINT64,
+ }:
+ raise TypeError("The dictionary index type should be integer.")
+
+ dict_type.reset(new CDictionaryType(_index_type.sp_type,
+ _value_type.sp_type, ordered == 1))
+ out.init(dict_type)
+ return out
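+
+
+# An illustrative sketch of dictionary(), assuming ``import pyarrow as pa``;
+# the repr is indicative, not captured output.
+#
+# >>> import pyarrow as pa
+# >>> pa.dictionary(pa.int8(), pa.string())
+# DictionaryType(dictionary<values=string, indices=int8, ordered=0>)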
+
+
+def struct(fields):
+ """
+ Create StructType instance from fields.
+
+ A struct is a nested type parameterized by an ordered sequence of types
+ (which can all be distinct), called its fields.
+
+ Parameters
+ ----------
+ fields : iterable of Fields or tuples, or mapping of strings to DataTypes
+ Each field must have a UTF8-encoded name, and these field names are
+ part of the type metadata.
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> fields = [
+ ... ('f1', pa.int32()),
+ ... ('f2', pa.string()),
+ ... ]
+ >>> struct_type = pa.struct(fields)
+ >>> struct_type
+ StructType(struct<f1: int32, f2: string>)
+ >>> fields = [
+ ... pa.field('f1', pa.int32()),
+ ... pa.field('f2', pa.string(), nullable=False),
+ ... ]
+ >>> pa.struct(fields)
+ StructType(struct<f1: int32, f2: string not null>)
+
+ Returns
+ -------
+ type : DataType
+ """
+ cdef:
+ Field py_field
+ vector[shared_ptr[CField]] c_fields
+ cdef shared_ptr[CDataType] struct_type
+
+ if isinstance(fields, Mapping):
+ fields = fields.items()
+
+ for item in fields:
+ if isinstance(item, tuple):
+ py_field = field(*item)
+ else:
+ py_field = item
+ c_fields.push_back(py_field.sp_field)
+
+ struct_type.reset(new CStructType(c_fields))
+ return pyarrow_wrap_data_type(struct_type)
+
+
+cdef _extract_union_params(child_fields, type_codes,
+ vector[shared_ptr[CField]]* c_fields,
+ vector[int8_t]* c_type_codes):
+ cdef:
+ Field child_field
+
+ for child_field in child_fields:
+ c_fields[0].push_back(child_field.sp_field)
+
+ if type_codes is not None:
+ if len(type_codes) != <Py_ssize_t>(c_fields.size()):
+ raise ValueError("type_codes should have the same length "
+ "as fields")
+ for code in type_codes:
+ c_type_codes[0].push_back(code)
+ else:
+ c_type_codes[0] = range(c_fields.size())
+
+
+def sparse_union(child_fields, type_codes=None):
+ """
+ Create SparseUnionType from child fields.
+
+ A sparse union is a nested type where each logical value is taken from
+ a single child. A buffer of 8-bit type ids indicates which child
+ a given logical value is to be taken from.
+
+ In a sparse union, each child array should have the same length as the
+ union array, regardless of the actual number of union values that
+ refer to it.
+
+ Parameters
+ ----------
+ child_fields : sequence of Field values
+ Each field must have a UTF8-encoded name, and these field names are
+ part of the type metadata.
+ type_codes : list of integers, default None
+
+ Returns
+ -------
+ type : SparseUnionType
+ """
+ cdef:
+ vector[shared_ptr[CField]] c_fields
+ vector[int8_t] c_type_codes
+
+ _extract_union_params(child_fields, type_codes,
+ &c_fields, &c_type_codes)
+
+ return pyarrow_wrap_data_type(
+ CMakeSparseUnionType(move(c_fields), move(c_type_codes)))
+
+
+def dense_union(child_fields, type_codes=None):
+ """
+ Create DenseUnionType from child fields.
+
+ A dense union is a nested type where each logical value is taken from
+ a single child, at a specific offset. A buffer of 8-bit type ids
+ indicates which child a given logical value is to be taken from,
+ and a buffer of 32-bit offsets indicates at which physical position
+ in the given child array the logical value is to be taken from.
+
+ Unlike a sparse union, a dense union allows encoding only the child array
+ values which are actually referred to by the union array. This is
+ counterbalanced by the additional footprint of the offsets buffer, and
+ the additional indirection cost when looking up values.
+
+ Parameters
+ ----------
+ child_fields : sequence of Field values
+ Each field must have a UTF8-encoded name, and these field names are
+ part of the type metadata.
+ type_codes : list of integers, default None
+
+ Returns
+ -------
+ type : DenseUnionType
+ """
+ cdef:
+ vector[shared_ptr[CField]] c_fields
+ vector[int8_t] c_type_codes
+
+ _extract_union_params(child_fields, type_codes,
+ &c_fields, &c_type_codes)
+
+ return pyarrow_wrap_data_type(
+ CMakeDenseUnionType(move(c_fields), move(c_type_codes)))
+
+
+def union(child_fields, mode, type_codes=None):
+ """
+ Create UnionType from child fields.
+
+ A union is a nested type where each logical value is taken from a
+ single child. A buffer of 8-bit type ids indicates which child
+ a given logical value is to be taken from.
+
+ Unions come in two flavors: sparse and dense
+ (see also `pyarrow.sparse_union` and `pyarrow.dense_union`).
+
+ Parameters
+ ----------
+ child_fields : sequence of Field values
+ Each field must have a UTF8-encoded name, and these field names are
+ part of the type metadata.
+ mode : str
+ Must be 'sparse' or 'dense'
+ type_codes : list of integers, default None
+
+ Returns
+ -------
+ type : UnionType
+ """
+ cdef:
+ Field child_field
+ vector[shared_ptr[CField]] c_fields
+ vector[int8_t] c_type_codes
+ shared_ptr[CDataType] union_type
+ int i
+
+ if isinstance(mode, int):
+ if mode not in (_UnionMode_SPARSE, _UnionMode_DENSE):
+ raise ValueError("Invalid union mode {0!r}".format(mode))
+ else:
+ if mode == 'sparse':
+ mode = _UnionMode_SPARSE
+ elif mode == 'dense':
+ mode = _UnionMode_DENSE
+ else:
+ raise ValueError("Invalid union mode {0!r}".format(mode))
+
+ if mode == _UnionMode_SPARSE:
+ return sparse_union(child_fields, type_codes)
+ else:
+ return dense_union(child_fields, type_codes)
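+
+
+# A small sketch of union(), assuming ``import pyarrow as pa``; both spellings
+# of ``mode`` simply delegate to sparse_union()/dense_union(), and the reprs
+# are illustrative.
+#
+# >>> import pyarrow as pa
+# >>> fields = [pa.field('a', pa.int32()), pa.field('b', pa.string())]
+# >>> pa.union(fields, mode='sparse')
+# SparseUnionType(sparse_union<a: int32=0, b: string=1>)
+# >>> pa.union(fields, mode='dense', type_codes=[5, 7])
+# DenseUnionType(dense_union<a: int32=5, b: string=7>)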
+
+
+cdef dict _type_aliases = {
+ 'null': null,
+ 'bool': bool_,
+ 'boolean': bool_,
+ 'i1': int8,
+ 'int8': int8,
+ 'i2': int16,
+ 'int16': int16,
+ 'i4': int32,
+ 'int32': int32,
+ 'i8': int64,
+ 'int64': int64,
+ 'u1': uint8,
+ 'uint8': uint8,
+ 'u2': uint16,
+ 'uint16': uint16,
+ 'u4': uint32,
+ 'uint32': uint32,
+ 'u8': uint64,
+ 'uint64': uint64,
+ 'f2': float16,
+ 'halffloat': float16,
+ 'float16': float16,
+ 'f4': float32,
+ 'float': float32,
+ 'float32': float32,
+ 'f8': float64,
+ 'double': float64,
+ 'float64': float64,
+ 'string': string,
+ 'str': string,
+ 'utf8': string,
+ 'binary': binary,
+ 'large_string': large_string,
+ 'large_str': large_string,
+ 'large_utf8': large_string,
+ 'large_binary': large_binary,
+ 'date32': date32,
+ 'date64': date64,
+ 'date32[day]': date32,
+ 'date64[ms]': date64,
+ 'time32[s]': time32('s'),
+ 'time32[ms]': time32('ms'),
+ 'time64[us]': time64('us'),
+ 'time64[ns]': time64('ns'),
+ 'timestamp[s]': timestamp('s'),
+ 'timestamp[ms]': timestamp('ms'),
+ 'timestamp[us]': timestamp('us'),
+ 'timestamp[ns]': timestamp('ns'),
+ 'duration[s]': duration('s'),
+ 'duration[ms]': duration('ms'),
+ 'duration[us]': duration('us'),
+ 'duration[ns]': duration('ns'),
+ 'month_day_nano_interval': month_day_nano_interval(),
+}
+
+
+def type_for_alias(name):
+ """
+ Return DataType given a string alias if one exists.
+
+ Parameters
+ ----------
+ name : str
+ The alias of the DataType that should be retrieved.
+
+ Returns
+ -------
+ type : DataType
+ """
+ name = name.lower()
+ try:
+ alias = _type_aliases[name]
+ except KeyError:
+ raise ValueError('No type alias for {0}'.format(name))
+
+ if isinstance(alias, DataType):
+ return alias
+ return alias()
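+
+
+# A short sketch of type_for_alias(), assuming ``import pyarrow as pa``; it
+# resolves both factory aliases and the pre-built parametric instances from
+# _type_aliases above. Reprs are illustrative.
+#
+# >>> import pyarrow as pa
+# >>> pa.type_for_alias('i4')             # factory alias: called
+# DataType(int32)
+# >>> pa.type_for_alias('timestamp[ms]')  # DataType instance: returned as-is
+# TimestampType(timestamp[ms])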
+
+
+cpdef DataType ensure_type(object ty, bint allow_none=False):
+ if allow_none and ty is None:
+ return None
+ elif isinstance(ty, DataType):
+ return ty
+ elif isinstance(ty, str):
+ return type_for_alias(ty)
+ else:
+ raise TypeError('DataType expected, got {!r}'.format(type(ty)))
+
+
+def schema(fields, metadata=None):
+ """
+ Construct pyarrow.Schema from collection of fields.
+
+ Parameters
+ ----------
+ fields : iterable of Fields or tuples, or mapping of strings to DataTypes
+ metadata : dict, default None
+ Keys and values must be coercible to bytes.
+
+ Examples
+ --------
+ >>> import pyarrow as pa
+ >>> pa.schema([
+ ... ('some_int', pa.int32()),
+ ... ('some_string', pa.string())
+ ... ])
+ some_int: int32
+ some_string: string
+ >>> pa.schema([
+ ... pa.field('some_int', pa.int32()),
+ ... pa.field('some_string', pa.string())
+ ... ])
+ some_int: int32
+ some_string: string
+
+ Returns
+ -------
+ schema : pyarrow.Schema
+ """
+ cdef:
+ shared_ptr[const CKeyValueMetadata] c_meta
+ shared_ptr[CSchema] c_schema
+ Schema result
+ Field py_field
+ vector[shared_ptr[CField]] c_fields
+
+ if isinstance(fields, Mapping):
+ fields = fields.items()
+
+ for item in fields:
+ if isinstance(item, tuple):
+ py_field = field(*item)
+ else:
+ py_field = item
+ if py_field is None:
+ raise TypeError("field or tuple expected, got None")
+ c_fields.push_back(py_field.sp_field)
+
+ metadata = ensure_metadata(metadata, allow_none=True)
+ c_meta = pyarrow_unwrap_metadata(metadata)
+
+ c_schema.reset(new CSchema(c_fields, c_meta))
+ result = Schema.__new__(Schema)
+ result.init_schema(c_schema)
+
+ return result
+
+
+def from_numpy_dtype(object dtype):
+ """
+ Convert NumPy dtype to pyarrow.DataType.
+
+ Parameters
+ ----------
+    dtype : NumPy dtype
+        The NumPy dtype to convert.
+ """
+ cdef shared_ptr[CDataType] c_type
+ dtype = np.dtype(dtype)
+ with nogil:
+ check_status(NumPyDtypeToArrow(dtype, &c_type))
+
+ return pyarrow_wrap_data_type(c_type)
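+
+
+# A minimal sketch of from_numpy_dtype(), assuming ``import pyarrow as pa`` and
+# ``import numpy as np``; reprs are illustrative.
+#
+# >>> import numpy as np
+# >>> import pyarrow as pa
+# >>> pa.from_numpy_dtype(np.int32)
+# DataType(int32)
+# >>> pa.from_numpy_dtype('datetime64[ms]')
+# TimestampType(timestamp[ms])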
+
+
+def is_boolean_value(object obj):
+ """
+ Check if the object is a boolean.
+
+ Parameters
+ ----------
+ obj : object
+ The object to check
+ """
+ return IsPyBool(obj)
+
+
+def is_integer_value(object obj):
+ """
+ Check if the object is an integer.
+
+ Parameters
+ ----------
+ obj : object
+ The object to check
+ """
+ return IsPyInt(obj)
+
+
+def is_float_value(object obj):
+ """
+ Check if the object is a float.
+
+ Parameters
+ ----------
+ obj : object
+ The object to check
+ """
+ return IsPyFloat(obj)
+
+
+cdef class _ExtensionRegistryNanny(_Weakrefable):
+ # Keep the registry alive until we have unregistered PyExtensionType
+ cdef:
+ shared_ptr[CExtensionTypeRegistry] registry
+
+ def __cinit__(self):
+ self.registry = CExtensionTypeRegistry.GetGlobalRegistry()
+
+ def release_registry(self):
+ self.registry.reset()
+
+
+_registry_nanny = _ExtensionRegistryNanny()
+
+
+def _register_py_extension_type():
+ cdef:
+ DataType storage_type
+ shared_ptr[CExtensionType] cpy_ext_type
+ c_string c_extension_name = tobytes("arrow.py_extension_type")
+
+ # Make a dummy C++ ExtensionType
+ storage_type = null()
+ check_status(CPyExtensionType.FromClass(
+ storage_type.sp_type, c_extension_name, PyExtensionType,
+ &cpy_ext_type))
+ check_status(
+ RegisterPyExtensionType(<shared_ptr[CDataType]> cpy_ext_type))
+
+
+def _unregister_py_extension_types():
+ # This needs to be done explicitly before the Python interpreter is
+ # finalized. If the C++ type is destroyed later in the process
+ # teardown stage, it will invoke CPython APIs such as Py_DECREF
+ # with a destroyed interpreter.
+ unregister_extension_type("arrow.py_extension_type")
+ for ext_type in _python_extension_types_registry:
+ try:
+ unregister_extension_type(ext_type.extension_name)
+ except KeyError:
+ pass
+ _registry_nanny.release_registry()
+
+
+_register_py_extension_type()
+atexit.register(_unregister_py_extension_types)
diff --git a/src/arrow/python/pyarrow/types.py b/src/arrow/python/pyarrow/types.py
new file mode 100644
index 000000000..f239c883b
--- /dev/null
+++ b/src/arrow/python/pyarrow/types.py
@@ -0,0 +1,550 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Tools for dealing with Arrow type metadata in Python
+
+
+from pyarrow.lib import (is_boolean_value, # noqa
+ is_integer_value,
+ is_float_value)
+
+import pyarrow.lib as lib
+
+
+_SIGNED_INTEGER_TYPES = {lib.Type_INT8, lib.Type_INT16, lib.Type_INT32,
+ lib.Type_INT64}
+_UNSIGNED_INTEGER_TYPES = {lib.Type_UINT8, lib.Type_UINT16, lib.Type_UINT32,
+ lib.Type_UINT64}
+_INTEGER_TYPES = _SIGNED_INTEGER_TYPES | _UNSIGNED_INTEGER_TYPES
+_FLOATING_TYPES = {lib.Type_HALF_FLOAT, lib.Type_FLOAT, lib.Type_DOUBLE}
+_DECIMAL_TYPES = {lib.Type_DECIMAL128, lib.Type_DECIMAL256}
+_DATE_TYPES = {lib.Type_DATE32, lib.Type_DATE64}
+_TIME_TYPES = {lib.Type_TIME32, lib.Type_TIME64}
+_INTERVAL_TYPES = {lib.Type_INTERVAL_MONTH_DAY_NANO}
+_TEMPORAL_TYPES = ({lib.Type_TIMESTAMP,
+ lib.Type_DURATION} | _TIME_TYPES | _DATE_TYPES |
+ _INTERVAL_TYPES)
+_UNION_TYPES = {lib.Type_SPARSE_UNION, lib.Type_DENSE_UNION}
+_NESTED_TYPES = {lib.Type_LIST, lib.Type_LARGE_LIST, lib.Type_STRUCT,
+ lib.Type_MAP} | _UNION_TYPES
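+
+
+# A brief sketch of how these predicate helpers are typically used, assuming
+# ``import pyarrow as pa`` and ``import pyarrow.types as types``; the outputs
+# are illustrative.
+#
+# >>> import pyarrow as pa
+# >>> import pyarrow.types as types
+# >>> types.is_integer(pa.int32())
+# True
+# >>> types.is_nested(pa.list_(pa.int32()))
+# True
+# >>> types.is_temporal(pa.string())
+# False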
+
+
+def is_null(t):
+ """
+ Return True if value is an instance of a null type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_NA
+
+
+def is_boolean(t):
+ """
+ Return True if value is an instance of a boolean type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_BOOL
+
+
+def is_integer(t):
+ """
+ Return True if value is an instance of any integer type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _INTEGER_TYPES
+
+
+def is_signed_integer(t):
+ """
+ Return True if value is an instance of any signed integer type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _SIGNED_INTEGER_TYPES
+
+
+def is_unsigned_integer(t):
+ """
+ Return True if value is an instance of any unsigned integer type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _UNSIGNED_INTEGER_TYPES
+
+
+def is_int8(t):
+ """
+ Return True if value is an instance of an int8 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_INT8
+
+
+def is_int16(t):
+ """
+ Return True if value is an instance of an int16 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_INT16
+
+
+def is_int32(t):
+ """
+ Return True if value is an instance of an int32 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_INT32
+
+
+def is_int64(t):
+ """
+ Return True if value is an instance of an int64 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_INT64
+
+
+def is_uint8(t):
+ """
+    Return True if value is an instance of a uint8 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_UINT8
+
+
+def is_uint16(t):
+ """
+    Return True if value is an instance of a uint16 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_UINT16
+
+
+def is_uint32(t):
+ """
+    Return True if value is an instance of a uint32 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_UINT32
+
+
+def is_uint64(t):
+ """
+    Return True if value is an instance of a uint64 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_UINT64
+
+
+def is_floating(t):
+ """
+ Return True if value is an instance of a floating point numeric type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _FLOATING_TYPES
+
+
+def is_float16(t):
+ """
+ Return True if value is an instance of a float16 (half-precision) type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_HALF_FLOAT
+
+
+def is_float32(t):
+ """
+ Return True if value is an instance of a float32 (single precision) type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_FLOAT
+
+
+def is_float64(t):
+ """
+ Return True if value is an instance of a float64 (double precision) type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_DOUBLE
+
+
+def is_list(t):
+ """
+ Return True if value is an instance of a list type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_LIST
+
+
+def is_large_list(t):
+ """
+ Return True if value is an instance of a large list type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_LARGE_LIST
+
+
+def is_fixed_size_list(t):
+ """
+ Return True if value is an instance of a fixed size list type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_FIXED_SIZE_LIST
+
+
+def is_struct(t):
+ """
+ Return True if value is an instance of a struct type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_STRUCT
+
+
+def is_union(t):
+ """
+ Return True if value is an instance of a union type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _UNION_TYPES
+
+
+def is_nested(t):
+ """
+ Return True if value is an instance of a nested type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _NESTED_TYPES
+
+
+def is_temporal(t):
+ """
+    Return True if value is an instance of a date, time, timestamp,
+    duration or interval type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _TEMPORAL_TYPES
+
+
+def is_timestamp(t):
+ """
+ Return True if value is an instance of a timestamp type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_TIMESTAMP
+
+
+def is_duration(t):
+ """
+ Return True if value is an instance of a duration type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_DURATION
+
+
+def is_time(t):
+ """
+ Return True if value is an instance of a time type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _TIME_TYPES
+
+
+def is_time32(t):
+ """
+ Return True if value is an instance of a time32 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_TIME32
+
+
+def is_time64(t):
+ """
+ Return True if value is an instance of a time64 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_TIME64
+
+
+def is_binary(t):
+ """
+ Return True if value is an instance of a variable-length binary type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_BINARY
+
+
+def is_large_binary(t):
+ """
+ Return True if value is an instance of a large variable-length
+ binary type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_LARGE_BINARY
+
+
+def is_unicode(t):
+ """
+ Alias for is_string.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return is_string(t)
+
+
+def is_string(t):
+ """
+ Return True if value is an instance of string (utf8 unicode) type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_STRING
+
+
+def is_large_unicode(t):
+ """
+ Alias for is_large_string.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return is_large_string(t)
+
+
+def is_large_string(t):
+ """
+ Return True if value is an instance of large string (utf8 unicode) type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_LARGE_STRING
+
+
+def is_fixed_size_binary(t):
+ """
+ Return True if value is an instance of a fixed size binary type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_FIXED_SIZE_BINARY
+
+
+def is_date(t):
+ """
+ Return True if value is an instance of a date type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _DATE_TYPES
+
+
+def is_date32(t):
+ """
+ Return True if value is an instance of a date32 (days) type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_DATE32
+
+
+def is_date64(t):
+ """
+ Return True if value is an instance of a date64 (milliseconds) type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_DATE64
+
+
+def is_map(t):
+ """
+ Return True if value is an instance of a map logical type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_MAP
+
+
+def is_decimal(t):
+ """
+ Return True if value is an instance of a decimal type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id in _DECIMAL_TYPES
+
+
+def is_decimal128(t):
+ """
+    Return True if value is an instance of a decimal128 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_DECIMAL128
+
+
+def is_decimal256(t):
+ """
+    Return True if value is an instance of a decimal256 type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_DECIMAL256
+
+
+def is_dictionary(t):
+ """
+ Return True if value is an instance of a dictionary-encoded type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return t.id == lib.Type_DICTIONARY
+
+
+def is_interval(t):
+ """
+ Return True if the value is an instance of an interval type.
+
+ Parameters
+ ----------
+    t : DataType
+ """
+ return t.id == lib.Type_INTERVAL_MONTH_DAY_NANO
+
+
+def is_primitive(t):
+ """
+ Return True if the value is an instance of a primitive type.
+
+ Parameters
+ ----------
+ t : DataType
+ """
+ return lib._is_primitive(t.id)
diff --git a/src/arrow/python/pyarrow/util.py b/src/arrow/python/pyarrow/util.py
new file mode 100644
index 000000000..69bde250c
--- /dev/null
+++ b/src/arrow/python/pyarrow/util.py
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Miscellaneous utility code
+
+import contextlib
+import functools
+import gc
+import pathlib
+import socket
+import sys
+import types
+import warnings
+
+
+_DEPR_MSG = (
+ "pyarrow.{} is deprecated as of {}, please use pyarrow.{} instead."
+)
+
+
+def implements(f):
+ def decorator(g):
+ g.__doc__ = f.__doc__
+ return g
+ return decorator
+
+
+def _deprecate_api(old_name, new_name, api, next_version):
+ msg = _DEPR_MSG.format(old_name, next_version, new_name)
+
+ def wrapper(*args, **kwargs):
+ warnings.warn(msg, FutureWarning)
+ return api(*args, **kwargs)
+ return wrapper
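+
+
+# A hypothetical sketch of how _deprecate_api() is meant to be used: keep an
+# old name importable while forwarding to the new API. ``old_filter`` and
+# ``new_filter`` are made-up names used purely for illustration.
+#
+#   def new_filter(table):
+#       ...
+#
+#   old_filter = _deprecate_api("old_filter", "new_filter",
+#                               new_filter, "99.0.0")
+#   old_filter(table)  # emits a FutureWarning, then calls new_filter(table)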
+
+
+def _deprecate_class(old_name, new_class, next_version,
+ instancecheck=True):
+    """
+    Return a replacement class that raises a FutureWarning when the
+    deprecated class is used in an isinstance check, delegating the actual
+    check to *new_class*.
+    """
+ class _DeprecatedMeta(type):
+ def __instancecheck__(self, other):
+ warnings.warn(
+ _DEPR_MSG.format(old_name, next_version, new_class.__name__),
+ FutureWarning,
+ stacklevel=2
+ )
+ return isinstance(other, new_class)
+
+ return _DeprecatedMeta(old_name, (new_class,), {})
+
+
+def _is_iterable(obj):
+ try:
+ iter(obj)
+ return True
+ except TypeError:
+ return False
+
+
+def _is_path_like(path):
+    # The PEP 519 __fspath__ protocol is only available from Python 3.6, so
+    # pathlib.Path is also checked explicitly for earlier versions
+ return (isinstance(path, str) or
+ hasattr(path, '__fspath__') or
+ isinstance(path, pathlib.Path))
+
+
+def _stringify_path(path):
+ """
+ Convert *path* to a string or unicode path if possible.
+ """
+ if isinstance(path, str):
+ return path
+
+ # checking whether path implements the filesystem protocol
+ try:
+ return path.__fspath__() # new in python 3.6
+ except AttributeError:
+        # fallback pathlib check for Python versions earlier than 3.6
+ if isinstance(path, pathlib.Path):
+ return str(path)
+
+ raise TypeError("not a path-like object")
+
+
+def product(seq):
+ """
+ Return a product of sequence items.
+ """
+ return functools.reduce(lambda a, b: a*b, seq, 1)
+
+
+def get_contiguous_span(shape, strides, itemsize):
+ """
+ Return a contiguous span of N-D array data.
+
+ Parameters
+ ----------
+    shape : tuple
+        The shape of the array.
+    strides : tuple
+        The strides of the array, in bytes.
+    itemsize : int
+        The size of each array element, in bytes.
+
+ Returns
+ -------
+ start, end : int
+        The start and end byte offsets of the contiguous span.
+ """
+ if not strides:
+ start = 0
+ end = itemsize * product(shape)
+ else:
+ start = 0
+ end = itemsize
+ for i, dim in enumerate(shape):
+ if dim == 0:
+ start = end = 0
+ break
+ stride = strides[i]
+ if stride > 0:
+ end += stride * (dim - 1)
+ elif stride < 0:
+ start += stride * (dim - 1)
+ if end - start != itemsize * product(shape):
+ raise ValueError('array data is non-contiguous')
+ return start, end
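+
+
+# A worked example of get_contiguous_span(), with the arithmetic spelled out;
+# the shapes and strides are chosen purely for illustration. For a
+# C-contiguous float64 array of shape (2, 3): itemsize == 8 and
+# strides == (24, 8), so end = 8 + 24*(2-1) + 8*(3-1) = 48, which equals
+# itemsize * product(shape) == 8 * 6, giving the span (0, 48).
+#
+# >>> get_contiguous_span((2, 3), (24, 8), 8)
+# (0, 48)
+# >>> get_contiguous_span((2, 3), (48, 16), 8)  # e.g. arr[:, ::2] of a (2, 6) array
+# Traceback (most recent call last):
+#   ...
+# ValueError: array data is non-contiguous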
+
+
+def find_free_port():
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ with contextlib.closing(sock) as sock:
+ sock.bind(('', 0))
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ return sock.getsockname()[1]
+
+
+def guid():
+ from uuid import uuid4
+ return uuid4().hex
+
+
+def _break_traceback_cycle_from_frame(frame):
+ # Clear local variables in all inner frames, so as to break the
+ # reference cycle.
+ this_frame = sys._getframe(0)
+ refs = gc.get_referrers(frame)
+ while refs:
+ for frame in refs:
+ if frame is not this_frame and isinstance(frame, types.FrameType):
+ break
+ else:
+ # No frame found in referrers (finished?)
+ break
+ refs = None
+ # Clear the frame locals, to try and break the cycle (it is
+ # somewhere along the chain of execution frames).
+ frame.clear()
+ # To visit the inner frame, we need to find it among the
+    # referrers of this frame (while `frame.f_back` would let
+ # us visit the outer frame).
+ refs = gc.get_referrers(frame)
+ refs = frame = this_frame = None
diff --git a/src/arrow/python/pyarrow/vendored/__init__.py b/src/arrow/python/pyarrow/vendored/__init__.py
new file mode 100644
index 000000000..13a83393a
--- /dev/null
+++ b/src/arrow/python/pyarrow/vendored/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/src/arrow/python/pyarrow/vendored/version.py b/src/arrow/python/pyarrow/vendored/version.py
new file mode 100644
index 000000000..b74f1da97
--- /dev/null
+++ b/src/arrow/python/pyarrow/vendored/version.py
@@ -0,0 +1,545 @@
+# Vendored from https://github.com/pypa/packaging,
+# changeset b5878c977206f60302536db969a8cef420853ade
+
+# This file is dual licensed under the terms of the Apache License, Version
+# 2.0, and the BSD License. See the LICENSE file in the root of the
+# `packaging` repository for complete details.
+
+import collections
+import itertools
+import re
+import warnings
+
+__all__ = ["parse", "Version", "LegacyVersion",
+ "InvalidVersion", "VERSION_PATTERN"]
+
+
+class InfinityType:
+ def __repr__(self):
+ return "Infinity"
+
+ def __hash__(self):
+ return hash(repr(self))
+
+ def __lt__(self, other):
+ return False
+
+ def __le__(self, other):
+ return False
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__)
+
+ def __ne__(self, other):
+ return not isinstance(other, self.__class__)
+
+ def __gt__(self, other):
+ return True
+
+ def __ge__(self, other):
+ return True
+
+ def __neg__(self):
+ return NegativeInfinity
+
+
+Infinity = InfinityType()
+
+
+class NegativeInfinityType:
+ def __repr__(self):
+ return "-Infinity"
+
+ def __hash__(self):
+ return hash(repr(self))
+
+ def __lt__(self, other):
+ return True
+
+ def __le__(self, other):
+ return True
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__)
+
+ def __ne__(self, other):
+ return not isinstance(other, self.__class__)
+
+ def __gt__(self, other):
+ return False
+
+ def __ge__(self, other):
+ return False
+
+ def __neg__(self):
+ return Infinity
+
+
+NegativeInfinity = NegativeInfinityType()
+
+
+_Version = collections.namedtuple(
+ "_Version", ["epoch", "release", "dev", "pre", "post", "local"]
+)
+
+
+def parse(version):
+ """
+ Parse the given version string and return either a :class:`Version` object
+ or a :class:`LegacyVersion` object depending on if the given version is
+ a valid PEP 440 version or a legacy version.
+ """
+ try:
+ return Version(version)
+ except InvalidVersion:
+ return LegacyVersion(version)
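+
+
+# A small sketch of parse(), with illustrative reprs; a string that is not a
+# valid PEP 440 version falls back to LegacyVersion (which itself emits a
+# DeprecationWarning).
+#
+# >>> parse("1.0.0")
+# <Version('1.0.0')>
+# >>> parse("2.0.0.dev1+local.tag")
+# <Version('2.0.0.dev1+local.tag')>
+# >>> parse("french toast")
+# <LegacyVersion('french toast')>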
+
+
+class InvalidVersion(ValueError):
+ """
+ An invalid version was found, users should refer to PEP 440.
+ """
+
+
+class _BaseVersion:
+
+ def __hash__(self):
+ return hash(self._key)
+
+ # Please keep the duplicated `isinstance` check
+ # in the six comparisons hereunder
+ # unless you find a way to avoid adding overhead function calls.
+ def __lt__(self, other):
+ if not isinstance(other, _BaseVersion):
+ return NotImplemented
+
+ return self._key < other._key
+
+ def __le__(self, other):
+ if not isinstance(other, _BaseVersion):
+ return NotImplemented
+
+ return self._key <= other._key
+
+ def __eq__(self, other):
+ if not isinstance(other, _BaseVersion):
+ return NotImplemented
+
+ return self._key == other._key
+
+ def __ge__(self, other):
+ if not isinstance(other, _BaseVersion):
+ return NotImplemented
+
+ return self._key >= other._key
+
+ def __gt__(self, other):
+ if not isinstance(other, _BaseVersion):
+ return NotImplemented
+
+ return self._key > other._key
+
+ def __ne__(self, other):
+ if not isinstance(other, _BaseVersion):
+ return NotImplemented
+
+ return self._key != other._key
+
+
+class LegacyVersion(_BaseVersion):
+ def __init__(self, version):
+ self._version = str(version)
+ self._key = _legacy_cmpkey(self._version)
+
+ warnings.warn(
+ "Creating a LegacyVersion has been deprecated and will be "
+ "removed in the next major release",
+ DeprecationWarning,
+ )
+
+ def __str__(self):
+ return self._version
+
+ def __repr__(self):
+ return f"<LegacyVersion('{self}')>"
+
+ @property
+ def public(self):
+ return self._version
+
+ @property
+ def base_version(self):
+ return self._version
+
+ @property
+ def epoch(self):
+ return -1
+
+ @property
+ def release(self):
+ return None
+
+ @property
+ def pre(self):
+ return None
+
+ @property
+ def post(self):
+ return None
+
+ @property
+ def dev(self):
+ return None
+
+ @property
+ def local(self):
+ return None
+
+ @property
+ def is_prerelease(self):
+ return False
+
+ @property
+ def is_postrelease(self):
+ return False
+
+ @property
+ def is_devrelease(self):
+ return False
+
+
+_legacy_version_component_re = re.compile(
+ r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE)
+
+_legacy_version_replacement_map = {
+ "pre": "c",
+ "preview": "c",
+ "-": "final-",
+ "rc": "c",
+ "dev": "@",
+}
+
+
+def _parse_version_parts(s):
+ for part in _legacy_version_component_re.split(s):
+ part = _legacy_version_replacement_map.get(part, part)
+
+ if not part or part == ".":
+ continue
+
+ if part[:1] in "0123456789":
+ # pad for numeric comparison
+ yield part.zfill(8)
+ else:
+ yield "*" + part
+
+ # ensure that alpha/beta/candidate are before final
+ yield "*final"
+
+
+def _legacy_cmpkey(version):
+
+    # We hardcode an epoch of -1 here. A PEP 440 version can only have an
+    # epoch greater than or equal to 0. This will effectively sort the
+    # LegacyVersion, which uses the de facto standard originally implemented
+    # by setuptools, before all PEP 440 versions.
+ epoch = -1
+
+    # This scheme is taken from pkg_resources.parse_version of setuptools,
+    # prior to its adoption of the packaging library.
+ parts = []
+ for part in _parse_version_parts(version.lower()):
+ if part.startswith("*"):
+ # remove "-" before a prerelease tag
+ if part < "*final":
+ while parts and parts[-1] == "*final-":
+ parts.pop()
+
+ # remove trailing zeros from each series of numeric parts
+ while parts and parts[-1] == "00000000":
+ parts.pop()
+
+ parts.append(part)
+
+ return epoch, tuple(parts)
+
+
+# Deliberately not anchored to the start and end of the string, to make it
+# easier for 3rd party code to reuse
+VERSION_PATTERN = r"""
+ v?
+ (?:
+ (?:(?P<epoch>[0-9]+)!)? # epoch
+ (?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
+ (?P<pre> # pre-release
+ [-_\.]?
+ (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
+ [-_\.]?
+ (?P<pre_n>[0-9]+)?
+ )?
+ (?P<post> # post release
+ (?:-(?P<post_n1>[0-9]+))
+ |
+ (?:
+ [-_\.]?
+ (?P<post_l>post|rev|r)
+ [-_\.]?
+ (?P<post_n2>[0-9]+)?
+ )
+ )?
+ (?P<dev> # dev release
+ [-_\.]?
+ (?P<dev_l>dev)
+ [-_\.]?
+ (?P<dev_n>[0-9]+)?
+ )?
+ )
+ (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
+"""
+
+
+class Version(_BaseVersion):
+
+ _regex = re.compile(r"^\s*" + VERSION_PATTERN +
+ r"\s*$", re.VERBOSE | re.IGNORECASE)
+
+ def __init__(self, version):
+
+ # Validate the version and parse it into pieces
+ match = self._regex.search(version)
+ if not match:
+ raise InvalidVersion(f"Invalid version: '{version}'")
+
+ # Store the parsed out pieces of the version
+ self._version = _Version(
+ epoch=int(match.group("epoch")) if match.group("epoch") else 0,
+ release=tuple(int(i) for i in match.group("release").split(".")),
+ pre=_parse_letter_version(
+ match.group("pre_l"), match.group("pre_n")),
+ post=_parse_letter_version(
+ match.group("post_l"), match.group(
+ "post_n1") or match.group("post_n2")
+ ),
+ dev=_parse_letter_version(
+ match.group("dev_l"), match.group("dev_n")),
+ local=_parse_local_version(match.group("local")),
+ )
+
+ # Generate a key which will be used for sorting
+ self._key = _cmpkey(
+ self._version.epoch,
+ self._version.release,
+ self._version.pre,
+ self._version.post,
+ self._version.dev,
+ self._version.local,
+ )
+
+ def __repr__(self):
+ return f"<Version('{self}')>"
+
+ def __str__(self):
+ parts = []
+
+ # Epoch
+ if self.epoch != 0:
+ parts.append(f"{self.epoch}!")
+
+ # Release segment
+ parts.append(".".join(str(x) for x in self.release))
+
+ # Pre-release
+ if self.pre is not None:
+ parts.append("".join(str(x) for x in self.pre))
+
+ # Post-release
+ if self.post is not None:
+ parts.append(f".post{self.post}")
+
+ # Development release
+ if self.dev is not None:
+ parts.append(f".dev{self.dev}")
+
+ # Local version segment
+ if self.local is not None:
+ parts.append(f"+{self.local}")
+
+ return "".join(parts)
+
+    @property
+    def epoch(self):
+        return self._version.epoch
+
+    @property
+    def release(self):
+        return self._version.release
+
+    @property
+    def pre(self):
+        return self._version.pre
+
+ @property
+ def post(self):
+ return self._version.post[1] if self._version.post else None
+
+ @property
+ def dev(self):
+ return self._version.dev[1] if self._version.dev else None
+
+ @property
+ def local(self):
+ if self._version.local:
+ return ".".join(str(x) for x in self._version.local)
+ else:
+ return None
+
+ @property
+ def public(self):
+ return str(self).split("+", 1)[0]
+
+ @property
+ def base_version(self):
+ parts = []
+
+ # Epoch
+ if self.epoch != 0:
+ parts.append(f"{self.epoch}!")
+
+ # Release segment
+ parts.append(".".join(str(x) for x in self.release))
+
+ return "".join(parts)
+
+ @property
+ def is_prerelease(self):
+ return self.dev is not None or self.pre is not None
+
+ @property
+ def is_postrelease(self):
+ return self.post is not None
+
+ @property
+ def is_devrelease(self):
+ return self.dev is not None
+
+ @property
+ def major(self):
+ return self.release[0] if len(self.release) >= 1 else 0
+
+ @property
+ def minor(self):
+ return self.release[1] if len(self.release) >= 2 else 0
+
+ @property
+ def micro(self):
+ return self.release[2] if len(self.release) >= 3 else 0
+
+
+def _parse_letter_version(letter, number):
+
+ if letter:
+ # We consider there to be an implicit 0 in a pre-release if there is
+ # not a numeral associated with it.
+ if number is None:
+ number = 0
+
+ # We normalize any letters to their lower case form
+ letter = letter.lower()
+
+ # We consider some words to be alternate spellings of other words and
+ # in those cases we want to normalize the spellings to our preferred
+ # spelling.
+ if letter == "alpha":
+ letter = "a"
+ elif letter == "beta":
+ letter = "b"
+ elif letter in ["c", "pre", "preview"]:
+ letter = "rc"
+ elif letter in ["rev", "r"]:
+ letter = "post"
+
+ return letter, int(number)
+ if not letter and number:
+ # We assume if we are given a number, but we are not given a letter
+ # then this is using the implicit post release syntax (e.g. 1.0-1)
+ letter = "post"
+
+ return letter, int(number)
+
+ return None
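+
+
+# A few illustrative input/output pairs for _parse_letter_version(), following
+# the normalization rules above:
+#
+#   ("alpha", None) -> ("a", 0)      # implicit 0, spelling normalized
+#   ("rev", "4")    -> ("post", 4)   # "rev"/"r" are post-release spellings
+#   (None, "1")     -> ("post", 1)   # implicit post-release syntax (e.g. 1.0-1)
+#   (None, None)    -> None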
+
+
+_local_version_separators = re.compile(r"[\._-]")
+
+
+def _parse_local_version(local):
+ """
+ Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
+ """
+ if local is not None:
+ return tuple(
+ part.lower() if not part.isdigit() else int(part)
+ for part in _local_version_separators.split(local)
+ )
+ return None
+
+
+def _cmpkey(epoch, release, pre, post, dev, local):
+
+ # When we compare a release version, we want to compare it with all of the
+    # trailing zeros removed. So we'll reverse the list, drop all the now-
+    # leading zeros until we come to something non-zero, then re-reverse the
+    # rest back into the correct order and make it a tuple to use as our
+    # sorting key.
+ _release = tuple(
+ reversed(list(itertools.dropwhile(lambda x: x == 0,
+ reversed(release))))
+ )
+
+ # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
+ # We'll do this by abusing the pre segment, but we _only_ want to do this
+ # if there is not a pre or a post segment. If we have one of those then
+ # the normal sorting rules will handle this case correctly.
+ if pre is None and post is None and dev is not None:
+ _pre = NegativeInfinity
+ # Versions without a pre-release (except as noted above) should sort after
+ # those with one.
+ elif pre is None:
+ _pre = Infinity
+ else:
+ _pre = pre
+
+ # Versions without a post segment should sort before those with one.
+ if post is None:
+ _post = NegativeInfinity
+
+ else:
+ _post = post
+
+ # Versions without a development segment should sort after those with one.
+ if dev is None:
+ _dev = Infinity
+
+ else:
+ _dev = dev
+
+ if local is None:
+ # Versions without a local segment should sort before those with one.
+ _local = NegativeInfinity
+ else:
+ # Versions with a local segment need that segment parsed to implement
+ # the sorting rules in PEP440.
+ # - Alpha numeric segments sort before numeric segments
+ # - Alpha numeric segments sort lexicographically
+ # - Numeric segments sort numerically
+ # - Shorter versions sort before longer versions when the prefixes
+ # match exactly
+ _local = tuple(
+ (i, "") if isinstance(i, int) else (NegativeInfinity, i)
+ for i in local
+ )
+
+ return epoch, _release, _pre, _post, _dev, _local