diff options
Diffstat (limited to '')
152 files changed, 91241 insertions, 0 deletions
diff --git a/src/arrow/python/pyarrow/__init__.pxd b/src/arrow/python/pyarrow/__init__.pxd new file mode 100644 index 000000000..8cc54b4c6 --- /dev/null +++ b/src/arrow/python/pyarrow/__init__.pxd @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from libcpp.memory cimport shared_ptr +from pyarrow.includes.libarrow cimport (CArray, CBuffer, CDataType, + CField, CRecordBatch, CSchema, + CTable, CTensor, CSparseCOOTensor, + CSparseCSRMatrix, CSparseCSCMatrix, + CSparseCSFTensor) + +cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": + cdef int import_pyarrow() except -1 + cdef object wrap_buffer(const shared_ptr[CBuffer]& buffer) + cdef object wrap_data_type(const shared_ptr[CDataType]& type) + cdef object wrap_field(const shared_ptr[CField]& field) + cdef object wrap_schema(const shared_ptr[CSchema]& schema) + cdef object wrap_array(const shared_ptr[CArray]& sp_array) + cdef object wrap_tensor(const shared_ptr[CTensor]& sp_tensor) + cdef object wrap_sparse_tensor_coo( + const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor) + cdef object wrap_sparse_tensor_csr( + const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor) + cdef object wrap_sparse_tensor_csc( + const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor) + cdef object wrap_sparse_tensor_csf( + const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor) + cdef object wrap_table(const shared_ptr[CTable]& ctable) + cdef object wrap_batch(const shared_ptr[CRecordBatch]& cbatch) diff --git a/src/arrow/python/pyarrow/__init__.py b/src/arrow/python/pyarrow/__init__.py new file mode 100644 index 000000000..1ec229d53 --- /dev/null +++ b/src/arrow/python/pyarrow/__init__.py @@ -0,0 +1,511 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# flake8: noqa + +""" +PyArrow is the python implementation of Apache Arrow. + +Apache Arrow is a cross-language development platform for in-memory data. +It specifies a standardized language-independent columnar memory format for +flat and hierarchical data, organized for efficient analytic operations on +modern hardware. It also provides computational libraries and zero-copy +streaming messaging and interprocess communication. + +For more information see the official page at https://arrow.apache.org +""" + +import gc as _gc +import os as _os +import sys as _sys +import warnings as _warnings + +try: + from ._generated_version import version as __version__ +except ImportError: + # Package is not installed, parse git tag at runtime + try: + import setuptools_scm + # Code duplicated from setup.py to avoid a dependency on each other + + def parse_git(root, **kwargs): + """ + Parse function for setuptools_scm that ignores tags for non-C++ + subprojects, e.g. apache-arrow-js-XXX tags. + """ + from setuptools_scm.git import parse + kwargs['describe_command'] = \ + "git describe --dirty --tags --long --match 'apache-arrow-[0-9].*'" + return parse(root, **kwargs) + __version__ = setuptools_scm.get_version('../', + parse=parse_git) + except ImportError: + __version__ = None + +# ARROW-8684: Disable GC while initializing Cython extension module, +# to workaround Cython bug in https://github.com/cython/cython/issues/3603 +_gc_enabled = _gc.isenabled() +_gc.disable() +import pyarrow.lib as _lib +if _gc_enabled: + _gc.enable() + +from pyarrow.lib import (BuildInfo, RuntimeInfo, MonthDayNano, + VersionInfo, cpp_build_info, cpp_version, + cpp_version_info, runtime_info, cpu_count, + set_cpu_count, enable_signal_handlers, + io_thread_count, set_io_thread_count) + + +def show_versions(): + """ + Print various version information, to help with error reporting. + """ + # TODO: CPU information and flags + print("pyarrow version info\n--------------------") + print("Package kind: {}".format(cpp_build_info.package_kind + if len(cpp_build_info.package_kind) > 0 + else "not indicated")) + print("Arrow C++ library version: {0}".format(cpp_build_info.version)) + print("Arrow C++ compiler: {0} {1}" + .format(cpp_build_info.compiler_id, cpp_build_info.compiler_version)) + print("Arrow C++ compiler flags: {0}" + .format(cpp_build_info.compiler_flags)) + print("Arrow C++ git revision: {0}".format(cpp_build_info.git_id)) + print("Arrow C++ git description: {0}" + .format(cpp_build_info.git_description)) + + +from pyarrow.lib import (null, bool_, + int8, int16, int32, int64, + uint8, uint16, uint32, uint64, + time32, time64, timestamp, date32, date64, duration, + month_day_nano_interval, + float16, float32, float64, + binary, string, utf8, + large_binary, large_string, large_utf8, + decimal128, decimal256, + list_, large_list, map_, struct, + union, sparse_union, dense_union, + dictionary, + field, + type_for_alias, + DataType, DictionaryType, StructType, + ListType, LargeListType, MapType, FixedSizeListType, + UnionType, SparseUnionType, DenseUnionType, + TimestampType, Time32Type, Time64Type, DurationType, + FixedSizeBinaryType, Decimal128Type, Decimal256Type, + BaseExtensionType, ExtensionType, + PyExtensionType, UnknownExtensionType, + register_extension_type, unregister_extension_type, + DictionaryMemo, + KeyValueMetadata, + Field, + Schema, + schema, + unify_schemas, + Array, Tensor, + array, chunked_array, record_batch, nulls, repeat, + SparseCOOTensor, SparseCSRMatrix, SparseCSCMatrix, + SparseCSFTensor, + infer_type, from_numpy_dtype, + NullArray, + NumericArray, IntegerArray, FloatingPointArray, + BooleanArray, + Int8Array, UInt8Array, + Int16Array, UInt16Array, + Int32Array, UInt32Array, + Int64Array, UInt64Array, + ListArray, LargeListArray, MapArray, + FixedSizeListArray, UnionArray, + BinaryArray, StringArray, + LargeBinaryArray, LargeStringArray, + FixedSizeBinaryArray, + DictionaryArray, + Date32Array, Date64Array, TimestampArray, + Time32Array, Time64Array, DurationArray, + MonthDayNanoIntervalArray, + Decimal128Array, Decimal256Array, StructArray, ExtensionArray, + scalar, NA, _NULL as NULL, Scalar, + NullScalar, BooleanScalar, + Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, + UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar, + HalfFloatScalar, FloatScalar, DoubleScalar, + Decimal128Scalar, Decimal256Scalar, + ListScalar, LargeListScalar, FixedSizeListScalar, + Date32Scalar, Date64Scalar, + Time32Scalar, Time64Scalar, + TimestampScalar, DurationScalar, + MonthDayNanoIntervalScalar, + BinaryScalar, LargeBinaryScalar, + StringScalar, LargeStringScalar, + FixedSizeBinaryScalar, DictionaryScalar, + MapScalar, StructScalar, UnionScalar, + ExtensionScalar) + +# Buffers, allocation +from pyarrow.lib import (Buffer, ResizableBuffer, foreign_buffer, py_buffer, + Codec, compress, decompress, allocate_buffer) + +from pyarrow.lib import (MemoryPool, LoggingMemoryPool, ProxyMemoryPool, + total_allocated_bytes, set_memory_pool, + default_memory_pool, system_memory_pool, + jemalloc_memory_pool, mimalloc_memory_pool, + logging_memory_pool, proxy_memory_pool, + log_memory_allocations, jemalloc_set_decay_ms) + +# I/O +from pyarrow.lib import (NativeFile, PythonFile, + BufferedInputStream, BufferedOutputStream, + CompressedInputStream, CompressedOutputStream, + TransformInputStream, transcoding_input_stream, + FixedSizeBufferWriter, + BufferReader, BufferOutputStream, + OSFile, MemoryMappedFile, memory_map, + create_memory_map, MockOutputStream, + input_stream, output_stream) + +from pyarrow._hdfsio import HdfsFile, have_libhdfs + +from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table, + concat_arrays, concat_tables) + +# Exceptions +from pyarrow.lib import (ArrowCancelled, + ArrowCapacityError, + ArrowException, + ArrowKeyError, + ArrowIndexError, + ArrowInvalid, + ArrowIOError, + ArrowMemoryError, + ArrowNotImplementedError, + ArrowTypeError, + ArrowSerializationError) + +# Serialization +from pyarrow.lib import (deserialize_from, deserialize, + deserialize_components, + serialize, serialize_to, read_serialized, + SerializationCallbackError, + DeserializationCallbackError) + +import pyarrow.hdfs as hdfs + +from pyarrow.ipc import serialize_pandas, deserialize_pandas +import pyarrow.ipc as ipc + +from pyarrow.serialization import (default_serialization_context, + register_default_serialization_handlers, + register_torch_serialization_handlers) + +import pyarrow.types as types + + +# deprecated top-level access + + +from pyarrow.filesystem import FileSystem as _FileSystem +from pyarrow.filesystem import LocalFileSystem as _LocalFileSystem +from pyarrow.hdfs import HadoopFileSystem as _HadoopFileSystem + +from pyarrow.lib import SerializationContext as _SerializationContext +from pyarrow.lib import SerializedPyObject as _SerializedPyObject + + +_localfs = _LocalFileSystem._get_instance() + + +_msg = ( + "pyarrow.{0} is deprecated as of 2.0.0, please use pyarrow.fs.{1} instead." +) + +_serialization_msg = ( + "'pyarrow.{0}' is deprecated and will be removed in a future version. " + "Use pickle or the pyarrow IPC functionality instead." +) + +_deprecated = { + "localfs": (_localfs, "LocalFileSystem"), + "FileSystem": (_FileSystem, "FileSystem"), + "LocalFileSystem": (_LocalFileSystem, "LocalFileSystem"), + "HadoopFileSystem": (_HadoopFileSystem, "HadoopFileSystem"), +} + +_serialization_deprecatd = { + "SerializationContext": _SerializationContext, + "SerializedPyObject": _SerializedPyObject, +} + +if _sys.version_info >= (3, 7): + def __getattr__(name): + if name in _deprecated: + obj, new_name = _deprecated[name] + _warnings.warn(_msg.format(name, new_name), + FutureWarning, stacklevel=2) + return obj + elif name in _serialization_deprecatd: + _warnings.warn(_serialization_msg.format(name), + FutureWarning, stacklevel=2) + return _serialization_deprecatd[name] + + raise AttributeError( + "module 'pyarrow' has no attribute '{0}'".format(name) + ) +else: + localfs = _localfs + FileSystem = _FileSystem + LocalFileSystem = _LocalFileSystem + HadoopFileSystem = _HadoopFileSystem + SerializationContext = _SerializationContext + SerializedPyObject = _SerializedPyObject + + +# Entry point for starting the plasma store + + +def _plasma_store_entry_point(): + """Entry point for starting the plasma store. + + This can be used by invoking e.g. + ``plasma_store -s /tmp/plasma -m 1000000000`` + from the command line and will start the plasma_store executable with the + given arguments. + """ + import pyarrow + plasma_store_executable = _os.path.join(pyarrow.__path__[0], + "plasma-store-server") + _os.execv(plasma_store_executable, _sys.argv) + + +# ---------------------------------------------------------------------- +# Deprecations + +from pyarrow.util import _deprecate_api, _deprecate_class + +read_message = _deprecate_api("read_message", "ipc.read_message", + ipc.read_message, "0.17.0") + +read_record_batch = _deprecate_api("read_record_batch", + "ipc.read_record_batch", + ipc.read_record_batch, "0.17.0") + +read_schema = _deprecate_api("read_schema", "ipc.read_schema", + ipc.read_schema, "0.17.0") + +read_tensor = _deprecate_api("read_tensor", "ipc.read_tensor", + ipc.read_tensor, "0.17.0") + +write_tensor = _deprecate_api("write_tensor", "ipc.write_tensor", + ipc.write_tensor, "0.17.0") + +get_record_batch_size = _deprecate_api("get_record_batch_size", + "ipc.get_record_batch_size", + ipc.get_record_batch_size, "0.17.0") + +get_tensor_size = _deprecate_api("get_tensor_size", + "ipc.get_tensor_size", + ipc.get_tensor_size, "0.17.0") + +open_stream = _deprecate_api("open_stream", "ipc.open_stream", + ipc.open_stream, "0.17.0") + +open_file = _deprecate_api("open_file", "ipc.open_file", ipc.open_file, + "0.17.0") + + +def _deprecate_scalar(ty, symbol): + return _deprecate_class("{}Value".format(ty), symbol, "1.0.0") + + +ArrayValue = _deprecate_class("ArrayValue", Scalar, "1.0.0") +NullType = _deprecate_class("NullType", NullScalar, "1.0.0") + +BooleanValue = _deprecate_scalar("Boolean", BooleanScalar) +Int8Value = _deprecate_scalar("Int8", Int8Scalar) +Int16Value = _deprecate_scalar("Int16", Int16Scalar) +Int32Value = _deprecate_scalar("Int32", Int32Scalar) +Int64Value = _deprecate_scalar("Int64", Int64Scalar) +UInt8Value = _deprecate_scalar("UInt8", UInt8Scalar) +UInt16Value = _deprecate_scalar("UInt16", UInt16Scalar) +UInt32Value = _deprecate_scalar("UInt32", UInt32Scalar) +UInt64Value = _deprecate_scalar("UInt64", UInt64Scalar) +HalfFloatValue = _deprecate_scalar("HalfFloat", HalfFloatScalar) +FloatValue = _deprecate_scalar("Float", FloatScalar) +DoubleValue = _deprecate_scalar("Double", DoubleScalar) +ListValue = _deprecate_scalar("List", ListScalar) +LargeListValue = _deprecate_scalar("LargeList", LargeListScalar) +MapValue = _deprecate_scalar("Map", MapScalar) +FixedSizeListValue = _deprecate_scalar("FixedSizeList", FixedSizeListScalar) +BinaryValue = _deprecate_scalar("Binary", BinaryScalar) +StringValue = _deprecate_scalar("String", StringScalar) +LargeBinaryValue = _deprecate_scalar("LargeBinary", LargeBinaryScalar) +LargeStringValue = _deprecate_scalar("LargeString", LargeStringScalar) +FixedSizeBinaryValue = _deprecate_scalar("FixedSizeBinary", + FixedSizeBinaryScalar) +Decimal128Value = _deprecate_scalar("Decimal128", Decimal128Scalar) +Decimal256Value = _deprecate_scalar("Decimal256", Decimal256Scalar) +UnionValue = _deprecate_scalar("Union", UnionScalar) +StructValue = _deprecate_scalar("Struct", StructScalar) +DictionaryValue = _deprecate_scalar("Dictionary", DictionaryScalar) +Date32Value = _deprecate_scalar("Date32", Date32Scalar) +Date64Value = _deprecate_scalar("Date64", Date64Scalar) +Time32Value = _deprecate_scalar("Time32", Time32Scalar) +Time64Value = _deprecate_scalar("Time64", Time64Scalar) +TimestampValue = _deprecate_scalar("Timestamp", TimestampScalar) +DurationValue = _deprecate_scalar("Duration", DurationScalar) + + +# TODO: Deprecate these somehow in the pyarrow namespace +from pyarrow.ipc import (Message, MessageReader, MetadataVersion, + RecordBatchFileReader, RecordBatchFileWriter, + RecordBatchStreamReader, RecordBatchStreamWriter) + +# ---------------------------------------------------------------------- +# Returning absolute path to the pyarrow include directory (if bundled, e.g. in +# wheels) + + +def get_include(): + """ + Return absolute path to directory containing Arrow C++ include + headers. Similar to numpy.get_include + """ + return _os.path.join(_os.path.dirname(__file__), 'include') + + +def _get_pkg_config_executable(): + return _os.environ.get('PKG_CONFIG', 'pkg-config') + + +def _has_pkg_config(pkgname): + import subprocess + try: + return subprocess.call([_get_pkg_config_executable(), + '--exists', pkgname]) == 0 + except FileNotFoundError: + return False + + +def _read_pkg_config_variable(pkgname, cli_args): + import subprocess + cmd = [_get_pkg_config_executable(), pkgname] + cli_args + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = proc.communicate() + if proc.returncode != 0: + raise RuntimeError("pkg-config failed: " + err.decode('utf8')) + return out.rstrip().decode('utf8') + + +def get_libraries(): + """ + Return list of library names to include in the `libraries` argument for C + or Cython extensions using pyarrow + """ + return ['arrow', 'arrow_python'] + + +def create_library_symlinks(): + """ + With Linux and macOS wheels, the bundled shared libraries have an embedded + ABI version like libarrow.so.17 or libarrow.17.dylib and so linking to them + with -larrow won't work unless we create symlinks at locations like + site-packages/pyarrow/libarrow.so. This unfortunate workaround addresses + prior problems we had with shipping two copies of the shared libraries to + permit third party projects like turbodbc to build their C++ extensions + against the pyarrow wheels. + + This function must only be invoked once and only when the shared libraries + are bundled with the Python package, which should only apply to wheel-based + installs. It requires write access to the site-packages/pyarrow directory + and so depending on your system may need to be run with root. + """ + import glob + if _sys.platform == 'win32': + return + package_cwd = _os.path.dirname(__file__) + + if _sys.platform == 'linux': + bundled_libs = glob.glob(_os.path.join(package_cwd, '*.so.*')) + + def get_symlink_path(hard_path): + return hard_path.rsplit('.', 1)[0] + else: + bundled_libs = glob.glob(_os.path.join(package_cwd, '*.*.dylib')) + + def get_symlink_path(hard_path): + return '.'.join((hard_path.rsplit('.', 2)[0], 'dylib')) + + for lib_hard_path in bundled_libs: + symlink_path = get_symlink_path(lib_hard_path) + if _os.path.exists(symlink_path): + continue + try: + _os.symlink(lib_hard_path, symlink_path) + except PermissionError: + print("Tried creating symlink {}. If you need to link to " + "bundled shared libraries, run " + "pyarrow.create_library_symlinks() as root") + + +def get_library_dirs(): + """ + Return lists of directories likely to contain Arrow C++ libraries for + linking C or Cython extensions using pyarrow + """ + package_cwd = _os.path.dirname(__file__) + library_dirs = [package_cwd] + + def append_library_dir(library_dir): + if library_dir not in library_dirs: + library_dirs.append(library_dir) + + # Search library paths via pkg-config. This is necessary if the user + # installed libarrow and the other shared libraries manually and they + # are not shipped inside the pyarrow package (see also ARROW-2976). + pkg_config_executable = _os.environ.get('PKG_CONFIG') or 'pkg-config' + for pkgname in ["arrow", "arrow_python"]: + if _has_pkg_config(pkgname): + library_dir = _read_pkg_config_variable(pkgname, + ["--libs-only-L"]) + # pkg-config output could be empty if Arrow is installed + # as a system package. + if library_dir: + if not library_dir.startswith("-L"): + raise ValueError( + "pkg-config --libs-only-L returned unexpected " + "value {!r}".format(library_dir)) + append_library_dir(library_dir[2:]) + + if _sys.platform == 'win32': + # TODO(wesm): Is this necessary, or does setuptools within a conda + # installation add Library\lib to the linker path for MSVC? + python_base_install = _os.path.dirname(_sys.executable) + library_dir = _os.path.join(python_base_install, 'Library', 'lib') + + if _os.path.exists(_os.path.join(library_dir, 'arrow.lib')): + append_library_dir(library_dir) + + # ARROW-4074: Allow for ARROW_HOME to be set to some other directory + if _os.environ.get('ARROW_HOME'): + append_library_dir(_os.path.join(_os.environ['ARROW_HOME'], 'lib')) + else: + # Python wheels bundle the Arrow libraries in the pyarrow directory. + append_library_dir(_os.path.dirname(_os.path.abspath(__file__))) + + return library_dirs diff --git a/src/arrow/python/pyarrow/_compute.pxd b/src/arrow/python/pyarrow/_compute.pxd new file mode 100644 index 000000000..8358271ef --- /dev/null +++ b/src/arrow/python/pyarrow/_compute.pxd @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.lib cimport * +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * + + +cdef class FunctionOptions(_Weakrefable): + cdef: + unique_ptr[CFunctionOptions] wrapped + + cdef const CFunctionOptions* get_options(self) except NULL + cdef void init(self, unique_ptr[CFunctionOptions] options) diff --git a/src/arrow/python/pyarrow/_compute.pyx b/src/arrow/python/pyarrow/_compute.pyx new file mode 100644 index 000000000..d62c9c0ee --- /dev/null +++ b/src/arrow/python/pyarrow/_compute.pyx @@ -0,0 +1,1296 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +import sys + +from cython.operator cimport dereference as deref + +from collections import namedtuple + +from pyarrow.lib import frombytes, tobytes, ordered_dict +from pyarrow.lib cimport * +from pyarrow.includes.libarrow cimport * +import pyarrow.lib as lib + +import numpy as np + + +cdef wrap_scalar_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ scalar Function in a ScalarFunction object. + """ + cdef ScalarFunction func = ScalarFunction.__new__(ScalarFunction) + func.init(sp_func) + return func + + +cdef wrap_vector_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ vector Function in a VectorFunction object. + """ + cdef VectorFunction func = VectorFunction.__new__(VectorFunction) + func.init(sp_func) + return func + + +cdef wrap_scalar_aggregate_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ aggregate Function in a ScalarAggregateFunction object. + """ + cdef ScalarAggregateFunction func = \ + ScalarAggregateFunction.__new__(ScalarAggregateFunction) + func.init(sp_func) + return func + + +cdef wrap_hash_aggregate_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ aggregate Function in a HashAggregateFunction object. + """ + cdef HashAggregateFunction func = \ + HashAggregateFunction.__new__(HashAggregateFunction) + func.init(sp_func) + return func + + +cdef wrap_meta_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ meta Function in a MetaFunction object. + """ + cdef MetaFunction func = MetaFunction.__new__(MetaFunction) + func.init(sp_func) + return func + + +cdef wrap_function(const shared_ptr[CFunction]& sp_func): + """ + Wrap a C++ Function in a Function object. + + This dispatches to specialized wrappers depending on the function kind. + """ + if sp_func.get() == NULL: + raise ValueError("Function was NULL") + + cdef FunctionKind c_kind = sp_func.get().kind() + if c_kind == FunctionKind_SCALAR: + return wrap_scalar_function(sp_func) + elif c_kind == FunctionKind_VECTOR: + return wrap_vector_function(sp_func) + elif c_kind == FunctionKind_SCALAR_AGGREGATE: + return wrap_scalar_aggregate_function(sp_func) + elif c_kind == FunctionKind_HASH_AGGREGATE: + return wrap_hash_aggregate_function(sp_func) + elif c_kind == FunctionKind_META: + return wrap_meta_function(sp_func) + else: + raise NotImplementedError("Unknown Function::Kind") + + +cdef wrap_scalar_kernel(const CScalarKernel* c_kernel): + if c_kernel == NULL: + raise ValueError("Kernel was NULL") + cdef ScalarKernel kernel = ScalarKernel.__new__(ScalarKernel) + kernel.init(c_kernel) + return kernel + + +cdef wrap_vector_kernel(const CVectorKernel* c_kernel): + if c_kernel == NULL: + raise ValueError("Kernel was NULL") + cdef VectorKernel kernel = VectorKernel.__new__(VectorKernel) + kernel.init(c_kernel) + return kernel + + +cdef wrap_scalar_aggregate_kernel(const CScalarAggregateKernel* c_kernel): + if c_kernel == NULL: + raise ValueError("Kernel was NULL") + cdef ScalarAggregateKernel kernel = \ + ScalarAggregateKernel.__new__(ScalarAggregateKernel) + kernel.init(c_kernel) + return kernel + + +cdef wrap_hash_aggregate_kernel(const CHashAggregateKernel* c_kernel): + if c_kernel == NULL: + raise ValueError("Kernel was NULL") + cdef HashAggregateKernel kernel = \ + HashAggregateKernel.__new__(HashAggregateKernel) + kernel.init(c_kernel) + return kernel + + +cdef class Kernel(_Weakrefable): + """ + A kernel object. + + Kernels handle the execution of a Function for a certain signature. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly" + .format(self.__class__.__name__)) + + +cdef class ScalarKernel(Kernel): + cdef const CScalarKernel* kernel + + cdef void init(self, const CScalarKernel* kernel) except *: + self.kernel = kernel + + def __repr__(self): + return ("ScalarKernel<{}>" + .format(frombytes(self.kernel.signature.get().ToString()))) + + +cdef class VectorKernel(Kernel): + cdef const CVectorKernel* kernel + + cdef void init(self, const CVectorKernel* kernel) except *: + self.kernel = kernel + + def __repr__(self): + return ("VectorKernel<{}>" + .format(frombytes(self.kernel.signature.get().ToString()))) + + +cdef class ScalarAggregateKernel(Kernel): + cdef const CScalarAggregateKernel* kernel + + cdef void init(self, const CScalarAggregateKernel* kernel) except *: + self.kernel = kernel + + def __repr__(self): + return ("ScalarAggregateKernel<{}>" + .format(frombytes(self.kernel.signature.get().ToString()))) + + +cdef class HashAggregateKernel(Kernel): + cdef const CHashAggregateKernel* kernel + + cdef void init(self, const CHashAggregateKernel* kernel) except *: + self.kernel = kernel + + def __repr__(self): + return ("HashAggregateKernel<{}>" + .format(frombytes(self.kernel.signature.get().ToString()))) + + +FunctionDoc = namedtuple( + "FunctionDoc", + ("summary", "description", "arg_names", "options_class")) + + +cdef class Function(_Weakrefable): + """ + A compute function. + + A function implements a certain logical computation over a range of + possible input signatures. Each signature accepts a range of input + types and is implemented by a given Kernel. + + Functions can be of different kinds: + + * "scalar" functions apply an item-wise computation over all items + of their inputs. Each item in the output only depends on the values + of the inputs at the same position. Examples: addition, comparisons, + string predicates... + + * "vector" functions apply a collection-wise computation, such that + each item in the output may depend on the values of several items + in each input. Examples: dictionary encoding, sorting, extracting + unique values... + + * "scalar_aggregate" functions reduce the dimensionality of the inputs by + applying a reduction function. Examples: sum, min_max, mode... + + * "hash_aggregate" functions apply a reduction function to an input + subdivided by grouping criteria. They may not be directly called. + Examples: hash_sum, hash_min_max... + + * "meta" functions dispatch to other functions. + """ + + cdef: + shared_ptr[CFunction] sp_func + CFunction* base_func + + _kind_map = { + FunctionKind_SCALAR: "scalar", + FunctionKind_VECTOR: "vector", + FunctionKind_SCALAR_AGGREGATE: "scalar_aggregate", + FunctionKind_HASH_AGGREGATE: "hash_aggregate", + FunctionKind_META: "meta", + } + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly" + .format(self.__class__.__name__)) + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + self.sp_func = sp_func + self.base_func = sp_func.get() + + def __repr__(self): + return ("arrow.compute.Function<name={}, kind={}, " + "arity={}, num_kernels={}>" + .format(self.name, self.kind, self.arity, self.num_kernels)) + + def __reduce__(self): + # Reduction uses the global registry + return get_function, (self.name,) + + @property + def name(self): + """ + The function name. + """ + return frombytes(self.base_func.name()) + + @property + def arity(self): + """ + The function arity. + + If Ellipsis (i.e. `...`) is returned, the function takes a variable + number of arguments. + """ + cdef CArity arity = self.base_func.arity() + if arity.is_varargs: + return ... + else: + return arity.num_args + + @property + def kind(self): + """ + The function kind. + """ + cdef FunctionKind c_kind = self.base_func.kind() + try: + return self._kind_map[c_kind] + except KeyError: + raise NotImplementedError("Unknown Function::Kind") + + @property + def _doc(self): + """ + The C++-like function documentation (for internal use). + """ + cdef CFunctionDoc c_doc = self.base_func.doc() + return FunctionDoc(frombytes(c_doc.summary), + frombytes(c_doc.description), + [frombytes(s) for s in c_doc.arg_names], + frombytes(c_doc.options_class)) + + @property + def num_kernels(self): + """ + The number of kernels implementing this function. + """ + return self.base_func.num_kernels() + + def call(self, args, FunctionOptions options=None, + MemoryPool memory_pool=None): + """ + Call the function on the given arguments. + """ + cdef: + const CFunctionOptions* c_options = NULL + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + CExecContext c_exec_ctx = CExecContext(pool) + vector[CDatum] c_args + CDatum result + + _pack_compute_args(args, &c_args) + + if options is not None: + c_options = options.get_options() + + with nogil: + result = GetResultValue( + self.base_func.Execute(c_args, c_options, &c_exec_ctx) + ) + + return wrap_datum(result) + + +cdef class ScalarFunction(Function): + cdef const CScalarFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = <const CScalarFunction*> sp_func.get() + + @property + def kernels(self): + """ + The kernels implementing this function. + """ + cdef vector[const CScalarKernel*] kernels = self.func.kernels() + return [wrap_scalar_kernel(k) for k in kernels] + + +cdef class VectorFunction(Function): + cdef const CVectorFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = <const CVectorFunction*> sp_func.get() + + @property + def kernels(self): + """ + The kernels implementing this function. + """ + cdef vector[const CVectorKernel*] kernels = self.func.kernels() + return [wrap_vector_kernel(k) for k in kernels] + + +cdef class ScalarAggregateFunction(Function): + cdef const CScalarAggregateFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = <const CScalarAggregateFunction*> sp_func.get() + + @property + def kernels(self): + """ + The kernels implementing this function. + """ + cdef vector[const CScalarAggregateKernel*] kernels = \ + self.func.kernels() + return [wrap_scalar_aggregate_kernel(k) for k in kernels] + + +cdef class HashAggregateFunction(Function): + cdef const CHashAggregateFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = <const CHashAggregateFunction*> sp_func.get() + + @property + def kernels(self): + """ + The kernels implementing this function. + """ + cdef vector[const CHashAggregateKernel*] kernels = self.func.kernels() + return [wrap_hash_aggregate_kernel(k) for k in kernels] + + +cdef class MetaFunction(Function): + cdef const CMetaFunction* func + + cdef void init(self, const shared_ptr[CFunction]& sp_func) except *: + Function.init(self, sp_func) + self.func = <const CMetaFunction*> sp_func.get() + + # Since num_kernels is exposed, also expose a kernels property + @property + def kernels(self): + """ + The kernels implementing this function. + """ + return [] + + +cdef _pack_compute_args(object values, vector[CDatum]* out): + for val in values: + if isinstance(val, (list, np.ndarray)): + val = lib.asarray(val) + + if isinstance(val, Array): + out.push_back(CDatum((<Array> val).sp_array)) + continue + elif isinstance(val, ChunkedArray): + out.push_back(CDatum((<ChunkedArray> val).sp_chunked_array)) + continue + elif isinstance(val, Scalar): + out.push_back(CDatum((<Scalar> val).unwrap())) + continue + elif isinstance(val, RecordBatch): + out.push_back(CDatum((<RecordBatch> val).sp_batch)) + continue + elif isinstance(val, Table): + out.push_back(CDatum((<Table> val).sp_table)) + continue + else: + # Is it a Python scalar? + try: + scal = lib.scalar(val) + except Exception: + # Raise dedicated error below + pass + else: + out.push_back(CDatum((<Scalar> scal).unwrap())) + continue + + raise TypeError(f"Got unexpected argument type {type(val)} " + "for compute function") + + +cdef class FunctionRegistry(_Weakrefable): + cdef CFunctionRegistry* registry + + def __init__(self): + self.registry = GetFunctionRegistry() + + def list_functions(self): + """ + Return all function names in the registry. + """ + cdef vector[c_string] names = self.registry.GetFunctionNames() + return [frombytes(name) for name in names] + + def get_function(self, name): + """ + Look up a function by name in the registry. + + Parameters + ---------- + name : str + The name of the function to lookup + """ + cdef: + c_string c_name = tobytes(name) + shared_ptr[CFunction] func + with nogil: + func = GetResultValue(self.registry.GetFunction(c_name)) + return wrap_function(func) + + +cdef FunctionRegistry _global_func_registry = FunctionRegistry() + + +def function_registry(): + return _global_func_registry + + +def get_function(name): + """ + Get a function by name. + + The function is looked up in the global registry + (as returned by `function_registry()`). + + Parameters + ---------- + name : str + The name of the function to lookup + """ + return _global_func_registry.get_function(name) + + +def list_functions(): + """ + Return all function names in the global registry. + """ + return _global_func_registry.list_functions() + + +def call_function(name, args, options=None, memory_pool=None): + """ + Call a named function. + + The function is looked up in the global registry + (as returned by `function_registry()`). + + Parameters + ---------- + name : str + The name of the function to call. + args : list + The arguments to the function. + options : optional + options provided to the function. + memory_pool : MemoryPool, optional + memory pool to use for allocations during function execution. + """ + func = _global_func_registry.get_function(name) + return func.call(args, options=options, memory_pool=memory_pool) + + +cdef class FunctionOptions(_Weakrefable): + __slots__ = () # avoid mistakingly creating attributes + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.wrapped.get() + + cdef void init(self, unique_ptr[CFunctionOptions] options): + self.wrapped = move(options) + + def serialize(self): + cdef: + CResult[shared_ptr[CBuffer]] res = self.get_options().Serialize() + shared_ptr[CBuffer] c_buf = GetResultValue(res) + return pyarrow_wrap_buffer(c_buf) + + @staticmethod + def deserialize(buf): + """ + Deserialize options for a function. + + Parameters + ---------- + buf : Buffer + The buffer containing the data to deserialize. + """ + cdef: + shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf) + CResult[unique_ptr[CFunctionOptions]] maybe_options = \ + DeserializeFunctionOptions(deref(c_buf)) + unique_ptr[CFunctionOptions] c_options + c_options = move(GetResultValue(move(maybe_options))) + type_name = frombytes(c_options.get().options_type().type_name()) + module = globals() + if type_name not in module: + raise ValueError(f'Cannot deserialize "{type_name}"') + klass = module[type_name] + options = klass.__new__(klass) + (<FunctionOptions> options).init(move(c_options)) + return options + + def __repr__(self): + type_name = self.__class__.__name__ + # Remove {} so we can use our own braces + string_repr = frombytes(self.get_options().ToString())[1:-1] + return f"{type_name}({string_repr})" + + def __eq__(self, FunctionOptions other): + return self.get_options().Equals(deref(other.get_options())) + + +def _raise_invalid_function_option(value, description, *, + exception_class=ValueError): + raise exception_class(f"\"{value}\" is not a valid {description}") + + +# NOTE: +# To properly expose the constructor signature of FunctionOptions +# subclasses, we use a two-level inheritance: +# 1. a C extension class that implements option validation and setting +# (won't expose function signatures because of +# https://github.com/cython/cython/issues/3873) +# 2. a Python derived class that implements the constructor + +cdef class _CastOptions(FunctionOptions): + cdef CCastOptions* options + + cdef void init(self, unique_ptr[CFunctionOptions] options): + FunctionOptions.init(self, move(options)) + self.options = <CCastOptions*> self.wrapped.get() + + def _set_options(self, DataType target_type, allow_int_overflow, + allow_time_truncate, allow_time_overflow, + allow_decimal_truncate, allow_float_truncate, + allow_invalid_utf8): + self.init(unique_ptr[CFunctionOptions](new CCastOptions())) + self._set_type(target_type) + if allow_int_overflow is not None: + self.allow_int_overflow = allow_int_overflow + if allow_time_truncate is not None: + self.allow_time_truncate = allow_time_truncate + if allow_time_overflow is not None: + self.allow_time_overflow = allow_time_overflow + if allow_decimal_truncate is not None: + self.allow_decimal_truncate = allow_decimal_truncate + if allow_float_truncate is not None: + self.allow_float_truncate = allow_float_truncate + if allow_invalid_utf8 is not None: + self.allow_invalid_utf8 = allow_invalid_utf8 + + def _set_type(self, target_type=None): + if target_type is not None: + deref(self.options).to_type = \ + (<DataType> ensure_type(target_type)).sp_type + + def _set_safe(self): + self.init(unique_ptr[CFunctionOptions]( + new CCastOptions(CCastOptions.Safe()))) + + def _set_unsafe(self): + self.init(unique_ptr[CFunctionOptions]( + new CCastOptions(CCastOptions.Unsafe()))) + + def is_safe(self): + return not (deref(self.options).allow_int_overflow or + deref(self.options).allow_time_truncate or + deref(self.options).allow_time_overflow or + deref(self.options).allow_decimal_truncate or + deref(self.options).allow_float_truncate or + deref(self.options).allow_invalid_utf8) + + @property + def allow_int_overflow(self): + return deref(self.options).allow_int_overflow + + @allow_int_overflow.setter + def allow_int_overflow(self, c_bool flag): + deref(self.options).allow_int_overflow = flag + + @property + def allow_time_truncate(self): + return deref(self.options).allow_time_truncate + + @allow_time_truncate.setter + def allow_time_truncate(self, c_bool flag): + deref(self.options).allow_time_truncate = flag + + @property + def allow_time_overflow(self): + return deref(self.options).allow_time_overflow + + @allow_time_overflow.setter + def allow_time_overflow(self, c_bool flag): + deref(self.options).allow_time_overflow = flag + + @property + def allow_decimal_truncate(self): + return deref(self.options).allow_decimal_truncate + + @allow_decimal_truncate.setter + def allow_decimal_truncate(self, c_bool flag): + deref(self.options).allow_decimal_truncate = flag + + @property + def allow_float_truncate(self): + return deref(self.options).allow_float_truncate + + @allow_float_truncate.setter + def allow_float_truncate(self, c_bool flag): + deref(self.options).allow_float_truncate = flag + + @property + def allow_invalid_utf8(self): + return deref(self.options).allow_invalid_utf8 + + @allow_invalid_utf8.setter + def allow_invalid_utf8(self, c_bool flag): + deref(self.options).allow_invalid_utf8 = flag + + +class CastOptions(_CastOptions): + + def __init__(self, target_type=None, *, allow_int_overflow=None, + allow_time_truncate=None, allow_time_overflow=None, + allow_decimal_truncate=None, allow_float_truncate=None, + allow_invalid_utf8=None): + self._set_options(target_type, allow_int_overflow, allow_time_truncate, + allow_time_overflow, allow_decimal_truncate, + allow_float_truncate, allow_invalid_utf8) + + @staticmethod + def safe(target_type=None): + """" + Create a CastOptions for a safe cast. + + Parameters + ---------- + target_type : optional + Target cast type for the safe cast. + """ + self = CastOptions() + self._set_safe() + self._set_type(target_type) + return self + + @staticmethod + def unsafe(target_type=None): + """" + Create a CastOptions for an unsafe cast. + + Parameters + ---------- + target_type : optional + Target cast type for the unsafe cast. + """ + self = CastOptions() + self._set_unsafe() + self._set_type(target_type) + return self + + +cdef class _ElementWiseAggregateOptions(FunctionOptions): + def _set_options(self, skip_nulls): + self.wrapped.reset(new CElementWiseAggregateOptions(skip_nulls)) + + +class ElementWiseAggregateOptions(_ElementWiseAggregateOptions): + def __init__(self, *, skip_nulls=True): + self._set_options(skip_nulls) + + +cdef CRoundMode unwrap_round_mode(round_mode) except *: + if round_mode == "down": + return CRoundMode_DOWN + elif round_mode == "up": + return CRoundMode_UP + elif round_mode == "towards_zero": + return CRoundMode_TOWARDS_ZERO + elif round_mode == "towards_infinity": + return CRoundMode_TOWARDS_INFINITY + elif round_mode == "half_down": + return CRoundMode_HALF_DOWN + elif round_mode == "half_up": + return CRoundMode_HALF_UP + elif round_mode == "half_towards_zero": + return CRoundMode_HALF_TOWARDS_ZERO + elif round_mode == "half_towards_infinity": + return CRoundMode_HALF_TOWARDS_INFINITY + elif round_mode == "half_to_even": + return CRoundMode_HALF_TO_EVEN + elif round_mode == "half_to_odd": + return CRoundMode_HALF_TO_ODD + _raise_invalid_function_option(round_mode, "round mode") + + +cdef class _RoundOptions(FunctionOptions): + def _set_options(self, ndigits, round_mode): + self.wrapped.reset( + new CRoundOptions(ndigits, unwrap_round_mode(round_mode)) + ) + + +class RoundOptions(_RoundOptions): + def __init__(self, ndigits=0, round_mode="half_to_even"): + self._set_options(ndigits, round_mode) + + +cdef class _RoundToMultipleOptions(FunctionOptions): + def _set_options(self, multiple, round_mode): + self.wrapped.reset( + new CRoundToMultipleOptions(multiple, + unwrap_round_mode(round_mode)) + ) + + +class RoundToMultipleOptions(_RoundToMultipleOptions): + def __init__(self, multiple=1.0, round_mode="half_to_even"): + self._set_options(multiple, round_mode) + + +cdef class _JoinOptions(FunctionOptions): + _null_handling_map = { + "emit_null": CJoinNullHandlingBehavior_EMIT_NULL, + "skip": CJoinNullHandlingBehavior_SKIP, + "replace": CJoinNullHandlingBehavior_REPLACE, + } + + def _set_options(self, null_handling, null_replacement): + try: + self.wrapped.reset( + new CJoinOptions(self._null_handling_map[null_handling], + tobytes(null_replacement)) + ) + except KeyError: + _raise_invalid_function_option(null_handling, "null handling") + + +class JoinOptions(_JoinOptions): + def __init__(self, null_handling="emit_null", null_replacement=""): + self._set_options(null_handling, null_replacement) + + +cdef class _MatchSubstringOptions(FunctionOptions): + def _set_options(self, pattern, ignore_case): + self.wrapped.reset( + new CMatchSubstringOptions(tobytes(pattern), ignore_case) + ) + + +class MatchSubstringOptions(_MatchSubstringOptions): + def __init__(self, pattern, *, ignore_case=False): + self._set_options(pattern, ignore_case) + + +cdef class _PadOptions(FunctionOptions): + def _set_options(self, width, padding): + self.wrapped.reset(new CPadOptions(width, tobytes(padding))) + + +class PadOptions(_PadOptions): + def __init__(self, width, padding=' '): + self._set_options(width, padding) + + +cdef class _TrimOptions(FunctionOptions): + def _set_options(self, characters): + self.wrapped.reset(new CTrimOptions(tobytes(characters))) + + +class TrimOptions(_TrimOptions): + def __init__(self, characters): + self._set_options(tobytes(characters)) + + +cdef class _ReplaceSliceOptions(FunctionOptions): + def _set_options(self, start, stop, replacement): + self.wrapped.reset( + new CReplaceSliceOptions(start, stop, tobytes(replacement)) + ) + + +class ReplaceSliceOptions(_ReplaceSliceOptions): + def __init__(self, start, stop, replacement): + self._set_options(start, stop, replacement) + + +cdef class _ReplaceSubstringOptions(FunctionOptions): + def _set_options(self, pattern, replacement, max_replacements): + self.wrapped.reset( + new CReplaceSubstringOptions(tobytes(pattern), + tobytes(replacement), + max_replacements) + ) + + +class ReplaceSubstringOptions(_ReplaceSubstringOptions): + def __init__(self, pattern, replacement, *, max_replacements=-1): + self._set_options(pattern, replacement, max_replacements) + + +cdef class _ExtractRegexOptions(FunctionOptions): + def _set_options(self, pattern): + self.wrapped.reset(new CExtractRegexOptions(tobytes(pattern))) + + +class ExtractRegexOptions(_ExtractRegexOptions): + def __init__(self, pattern): + self._set_options(pattern) + + +cdef class _SliceOptions(FunctionOptions): + def _set_options(self, start, stop, step): + self.wrapped.reset(new CSliceOptions(start, stop, step)) + + +class SliceOptions(_SliceOptions): + def __init__(self, start, stop=sys.maxsize, step=1): + self._set_options(start, stop, step) + + +cdef class _FilterOptions(FunctionOptions): + _null_selection_map = { + "drop": CFilterNullSelectionBehavior_DROP, + "emit_null": CFilterNullSelectionBehavior_EMIT_NULL, + } + + def _set_options(self, null_selection_behavior): + try: + self.wrapped.reset( + new CFilterOptions( + self._null_selection_map[null_selection_behavior] + ) + ) + except KeyError: + _raise_invalid_function_option(null_selection_behavior, + "null selection behavior") + + +class FilterOptions(_FilterOptions): + def __init__(self, null_selection_behavior="drop"): + self._set_options(null_selection_behavior) + + +cdef class _DictionaryEncodeOptions(FunctionOptions): + _null_encoding_map = { + "encode": CDictionaryEncodeNullEncodingBehavior_ENCODE, + "mask": CDictionaryEncodeNullEncodingBehavior_MASK, + } + + def _set_options(self, null_encoding): + try: + self.wrapped.reset( + new CDictionaryEncodeOptions( + self._null_encoding_map[null_encoding] + ) + ) + except KeyError: + _raise_invalid_function_option(null_encoding, "null encoding") + + +class DictionaryEncodeOptions(_DictionaryEncodeOptions): + def __init__(self, null_encoding="mask"): + self._set_options(null_encoding) + + +cdef class _TakeOptions(FunctionOptions): + def _set_options(self, boundscheck): + self.wrapped.reset(new CTakeOptions(boundscheck)) + + +class TakeOptions(_TakeOptions): + def __init__(self, *, boundscheck=True): + self._set_options(boundscheck) + + +cdef class _MakeStructOptions(FunctionOptions): + def _set_options(self, field_names, field_nullability, field_metadata): + cdef: + vector[c_string] c_field_names + vector[shared_ptr[const CKeyValueMetadata]] c_field_metadata + for name in field_names: + c_field_names.push_back(tobytes(name)) + for metadata in field_metadata: + c_field_metadata.push_back(pyarrow_unwrap_metadata(metadata)) + self.wrapped.reset( + new CMakeStructOptions(c_field_names, field_nullability, + c_field_metadata) + ) + + +class MakeStructOptions(_MakeStructOptions): + def __init__(self, field_names, *, field_nullability=None, + field_metadata=None): + if field_nullability is None: + field_nullability = [True] * len(field_names) + if field_metadata is None: + field_metadata = [None] * len(field_names) + self._set_options(field_names, field_nullability, field_metadata) + + +cdef class _ScalarAggregateOptions(FunctionOptions): + def _set_options(self, skip_nulls, min_count): + self.wrapped.reset(new CScalarAggregateOptions(skip_nulls, min_count)) + + +class ScalarAggregateOptions(_ScalarAggregateOptions): + def __init__(self, *, skip_nulls=True, min_count=1): + self._set_options(skip_nulls, min_count) + + +cdef class _CountOptions(FunctionOptions): + _mode_map = { + "only_valid": CCountMode_ONLY_VALID, + "only_null": CCountMode_ONLY_NULL, + "all": CCountMode_ALL, + } + + def _set_options(self, mode): + try: + self.wrapped.reset(new CCountOptions(self._mode_map[mode])) + except KeyError: + _raise_invalid_function_option(mode, "count mode") + + +class CountOptions(_CountOptions): + def __init__(self, mode="only_valid"): + self._set_options(mode) + + +cdef class _IndexOptions(FunctionOptions): + def _set_options(self, scalar): + self.wrapped.reset(new CIndexOptions(pyarrow_unwrap_scalar(scalar))) + + +class IndexOptions(_IndexOptions): + """ + Options for the index kernel. + + Parameters + ---------- + value : Scalar + The value to search for. + """ + + def __init__(self, value): + self._set_options(value) + + +cdef class _ModeOptions(FunctionOptions): + def _set_options(self, n, skip_nulls, min_count): + self.wrapped.reset(new CModeOptions(n, skip_nulls, min_count)) + + +class ModeOptions(_ModeOptions): + def __init__(self, n=1, *, skip_nulls=True, min_count=0): + self._set_options(n, skip_nulls, min_count) + + +cdef class _SetLookupOptions(FunctionOptions): + def _set_options(self, value_set, c_bool skip_nulls): + cdef unique_ptr[CDatum] valset + if isinstance(value_set, Array): + valset.reset(new CDatum((<Array> value_set).sp_array)) + elif isinstance(value_set, ChunkedArray): + valset.reset( + new CDatum((<ChunkedArray> value_set).sp_chunked_array) + ) + elif isinstance(value_set, Scalar): + valset.reset(new CDatum((<Scalar> value_set).unwrap())) + else: + _raise_invalid_function_option(value_set, "value set", + exception_class=TypeError) + + self.wrapped.reset(new CSetLookupOptions(deref(valset), skip_nulls)) + + +class SetLookupOptions(_SetLookupOptions): + def __init__(self, value_set, *, skip_nulls=False): + self._set_options(value_set, skip_nulls) + + +cdef class _StrptimeOptions(FunctionOptions): + _unit_map = { + "s": TimeUnit_SECOND, + "ms": TimeUnit_MILLI, + "us": TimeUnit_MICRO, + "ns": TimeUnit_NANO, + } + + def _set_options(self, format, unit): + try: + self.wrapped.reset( + new CStrptimeOptions(tobytes(format), self._unit_map[unit]) + ) + except KeyError: + _raise_invalid_function_option(unit, "time unit") + + +class StrptimeOptions(_StrptimeOptions): + def __init__(self, format, unit): + self._set_options(format, unit) + + +cdef class _StrftimeOptions(FunctionOptions): + def _set_options(self, format, locale): + self.wrapped.reset( + new CStrftimeOptions(tobytes(format), tobytes(locale)) + ) + + +class StrftimeOptions(_StrftimeOptions): + def __init__(self, format="%Y-%m-%dT%H:%M:%S", locale="C"): + self._set_options(format, locale) + + +cdef class _DayOfWeekOptions(FunctionOptions): + def _set_options(self, count_from_zero, week_start): + self.wrapped.reset( + new CDayOfWeekOptions(count_from_zero, week_start) + ) + + +class DayOfWeekOptions(_DayOfWeekOptions): + def __init__(self, *, count_from_zero=True, week_start=1): + self._set_options(count_from_zero, week_start) + + +cdef class _WeekOptions(FunctionOptions): + def _set_options(self, week_starts_monday, count_from_zero, + first_week_is_fully_in_year): + self.wrapped.reset( + new CWeekOptions(week_starts_monday, count_from_zero, + first_week_is_fully_in_year) + ) + + +class WeekOptions(_WeekOptions): + def __init__(self, *, week_starts_monday=True, count_from_zero=False, + first_week_is_fully_in_year=False): + self._set_options(week_starts_monday, + count_from_zero, first_week_is_fully_in_year) + + +cdef class _AssumeTimezoneOptions(FunctionOptions): + _ambiguous_map = { + "raise": CAssumeTimezoneAmbiguous_AMBIGUOUS_RAISE, + "earliest": CAssumeTimezoneAmbiguous_AMBIGUOUS_EARLIEST, + "latest": CAssumeTimezoneAmbiguous_AMBIGUOUS_LATEST, + } + _nonexistent_map = { + "raise": CAssumeTimezoneNonexistent_NONEXISTENT_RAISE, + "earliest": CAssumeTimezoneNonexistent_NONEXISTENT_EARLIEST, + "latest": CAssumeTimezoneNonexistent_NONEXISTENT_LATEST, + } + + def _set_options(self, timezone, ambiguous, nonexistent): + if ambiguous not in self._ambiguous_map: + _raise_invalid_function_option(ambiguous, + "'ambiguous' timestamp handling") + if nonexistent not in self._nonexistent_map: + _raise_invalid_function_option(nonexistent, + "'nonexistent' timestamp handling") + self.wrapped.reset( + new CAssumeTimezoneOptions(tobytes(timezone), + self._ambiguous_map[ambiguous], + self._nonexistent_map[nonexistent]) + ) + + +class AssumeTimezoneOptions(_AssumeTimezoneOptions): + def __init__(self, timezone, *, ambiguous="raise", nonexistent="raise"): + self._set_options(timezone, ambiguous, nonexistent) + + +cdef class _NullOptions(FunctionOptions): + def _set_options(self, nan_is_null): + self.wrapped.reset(new CNullOptions(nan_is_null)) + + +class NullOptions(_NullOptions): + def __init__(self, *, nan_is_null=False): + self._set_options(nan_is_null) + + +cdef class _VarianceOptions(FunctionOptions): + def _set_options(self, ddof, skip_nulls, min_count): + self.wrapped.reset(new CVarianceOptions(ddof, skip_nulls, min_count)) + + +class VarianceOptions(_VarianceOptions): + def __init__(self, *, ddof=0, skip_nulls=True, min_count=0): + self._set_options(ddof, skip_nulls, min_count) + + +cdef class _SplitOptions(FunctionOptions): + def _set_options(self, max_splits, reverse): + self.wrapped.reset(new CSplitOptions(max_splits, reverse)) + + +class SplitOptions(_SplitOptions): + def __init__(self, *, max_splits=-1, reverse=False): + self._set_options(max_splits, reverse) + + +cdef class _SplitPatternOptions(FunctionOptions): + def _set_options(self, pattern, max_splits, reverse): + self.wrapped.reset( + new CSplitPatternOptions(tobytes(pattern), max_splits, reverse) + ) + + +class SplitPatternOptions(_SplitPatternOptions): + def __init__(self, pattern, *, max_splits=-1, reverse=False): + self._set_options(pattern, max_splits, reverse) + + +cdef CSortOrder unwrap_sort_order(order) except *: + if order == "ascending": + return CSortOrder_Ascending + elif order == "descending": + return CSortOrder_Descending + _raise_invalid_function_option(order, "sort order") + + +cdef CNullPlacement unwrap_null_placement(null_placement) except *: + if null_placement == "at_start": + return CNullPlacement_AtStart + elif null_placement == "at_end": + return CNullPlacement_AtEnd + _raise_invalid_function_option(null_placement, "null placement") + + +cdef class _PartitionNthOptions(FunctionOptions): + def _set_options(self, pivot, null_placement): + self.wrapped.reset(new CPartitionNthOptions( + pivot, unwrap_null_placement(null_placement))) + + +class PartitionNthOptions(_PartitionNthOptions): + def __init__(self, pivot, *, null_placement="at_end"): + self._set_options(pivot, null_placement) + + +cdef class _ArraySortOptions(FunctionOptions): + def _set_options(self, order, null_placement): + self.wrapped.reset(new CArraySortOptions( + unwrap_sort_order(order), unwrap_null_placement(null_placement))) + + +class ArraySortOptions(_ArraySortOptions): + def __init__(self, order="ascending", *, null_placement="at_end"): + self._set_options(order, null_placement) + + +cdef class _SortOptions(FunctionOptions): + def _set_options(self, sort_keys, null_placement): + cdef vector[CSortKey] c_sort_keys + for name, order in sort_keys: + c_sort_keys.push_back( + CSortKey(tobytes(name), unwrap_sort_order(order)) + ) + self.wrapped.reset(new CSortOptions( + c_sort_keys, unwrap_null_placement(null_placement))) + + +class SortOptions(_SortOptions): + def __init__(self, sort_keys, *, null_placement="at_end"): + self._set_options(sort_keys, null_placement) + + +cdef class _SelectKOptions(FunctionOptions): + def _set_options(self, k, sort_keys): + cdef vector[CSortKey] c_sort_keys + for name, order in sort_keys: + c_sort_keys.push_back( + CSortKey(tobytes(name), unwrap_sort_order(order)) + ) + self.wrapped.reset(new CSelectKOptions(k, c_sort_keys)) + + +class SelectKOptions(_SelectKOptions): + def __init__(self, k, sort_keys): + self._set_options(k, sort_keys) + + +cdef class _QuantileOptions(FunctionOptions): + _interp_map = { + "linear": CQuantileInterp_LINEAR, + "lower": CQuantileInterp_LOWER, + "higher": CQuantileInterp_HIGHER, + "nearest": CQuantileInterp_NEAREST, + "midpoint": CQuantileInterp_MIDPOINT, + } + + def _set_options(self, quantiles, interp, skip_nulls, min_count): + try: + self.wrapped.reset( + new CQuantileOptions(quantiles, self._interp_map[interp], + skip_nulls, min_count) + ) + except KeyError: + _raise_invalid_function_option(interp, "quantile interpolation") + + +class QuantileOptions(_QuantileOptions): + def __init__(self, q=0.5, *, interpolation="linear", skip_nulls=True, + min_count=0): + if not isinstance(q, (list, tuple, np.ndarray)): + q = [q] + self._set_options(q, interpolation, skip_nulls, min_count) + + +cdef class _TDigestOptions(FunctionOptions): + def _set_options(self, quantiles, delta, buffer_size, skip_nulls, + min_count): + self.wrapped.reset( + new CTDigestOptions(quantiles, delta, buffer_size, skip_nulls, + min_count) + ) + + +class TDigestOptions(_TDigestOptions): + def __init__(self, q=0.5, *, delta=100, buffer_size=500, skip_nulls=True, + min_count=0): + if not isinstance(q, (list, tuple, np.ndarray)): + q = [q] + self._set_options(q, delta, buffer_size, skip_nulls, min_count) diff --git a/src/arrow/python/pyarrow/_csv.pxd b/src/arrow/python/pyarrow/_csv.pxd new file mode 100644 index 000000000..b2fe7d639 --- /dev/null +++ b/src/arrow/python/pyarrow/_csv.pxd @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport _Weakrefable + + +cdef class ConvertOptions(_Weakrefable): + cdef: + unique_ptr[CCSVConvertOptions] options + + @staticmethod + cdef ConvertOptions wrap(CCSVConvertOptions options) + + +cdef class ParseOptions(_Weakrefable): + cdef: + unique_ptr[CCSVParseOptions] options + + @staticmethod + cdef ParseOptions wrap(CCSVParseOptions options) + + +cdef class ReadOptions(_Weakrefable): + cdef: + unique_ptr[CCSVReadOptions] options + public object encoding + + @staticmethod + cdef ReadOptions wrap(CCSVReadOptions options) + + +cdef class WriteOptions(_Weakrefable): + cdef: + unique_ptr[CCSVWriteOptions] options + + @staticmethod + cdef WriteOptions wrap(CCSVWriteOptions options) diff --git a/src/arrow/python/pyarrow/_csv.pyx b/src/arrow/python/pyarrow/_csv.pyx new file mode 100644 index 000000000..19ade4324 --- /dev/null +++ b/src/arrow/python/pyarrow/_csv.pyx @@ -0,0 +1,1077 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ +# cython: language_level = 3 + +from cython.operator cimport dereference as deref + +import codecs +from collections.abc import Mapping + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema, + RecordBatchReader, ensure_type, + maybe_unbox_memory_pool, get_input_stream, + get_writer, native_transcoding_input_stream, + pyarrow_unwrap_batch, pyarrow_unwrap_schema, + pyarrow_unwrap_table, pyarrow_wrap_schema, + pyarrow_wrap_table, pyarrow_wrap_data_type, + pyarrow_unwrap_data_type, Table, RecordBatch, + StopToken, _CRecordBatchWriter) +from pyarrow.lib import frombytes, tobytes, SignalStopHandler +from pyarrow.util import _stringify_path + + +cdef unsigned char _single_char(s) except 0: + val = ord(s) + if val == 0 or val > 127: + raise ValueError("Expecting an ASCII character") + return <unsigned char> val + + +cdef class ReadOptions(_Weakrefable): + """ + Options for reading CSV files. + + Parameters + ---------- + use_threads : bool, optional (default True) + Whether to use multiple threads to accelerate reading + block_size : int, optional + How much bytes to process at a time from the input stream. + This will determine multi-threading granularity as well as + the size of individual record batches or table chunks. + Minimum valid value for block size is 1 + skip_rows : int, optional (default 0) + The number of rows to skip before the column names (if any) + and the CSV data. + skip_rows_after_names : int, optional (default 0) + The number of rows to skip after the column names. + This number can be larger than the number of rows in one + block, and empty rows are counted. + The order of application is as follows: + - `skip_rows` is applied (if non-zero); + - column names aread (unless `column_names` is set); + - `skip_rows_after_names` is applied (if non-zero). + column_names : list, optional + The column names of the target table. If empty, fall back on + `autogenerate_column_names`. + autogenerate_column_names : bool, optional (default False) + Whether to autogenerate column names if `column_names` is empty. + If true, column names will be of the form "f0", "f1"... + If false, column names will be read from the first CSV row + after `skip_rows`. + encoding : str, optional (default 'utf8') + The character encoding of the CSV data. Columns that cannot + decode using this encoding can still be read as Binary. + """ + + # Avoid mistakingly creating attributes + __slots__ = () + + # __init__() is not called when unpickling, initialize storage here + def __cinit__(self, *argw, **kwargs): + self.options.reset(new CCSVReadOptions(CCSVReadOptions.Defaults())) + + def __init__(self, *, use_threads=None, block_size=None, skip_rows=None, + column_names=None, autogenerate_column_names=None, + encoding='utf8', skip_rows_after_names=None): + if use_threads is not None: + self.use_threads = use_threads + if block_size is not None: + self.block_size = block_size + if skip_rows is not None: + self.skip_rows = skip_rows + if column_names is not None: + self.column_names = column_names + if autogenerate_column_names is not None: + self.autogenerate_column_names= autogenerate_column_names + # Python-specific option + self.encoding = encoding + if skip_rows_after_names is not None: + self.skip_rows_after_names = skip_rows_after_names + + @property + def use_threads(self): + """ + Whether to use multiple threads to accelerate reading. + """ + return deref(self.options).use_threads + + @use_threads.setter + def use_threads(self, value): + deref(self.options).use_threads = value + + @property + def block_size(self): + """ + How much bytes to process at a time from the input stream. + This will determine multi-threading granularity as well as + the size of individual record batches or table chunks. + """ + return deref(self.options).block_size + + @block_size.setter + def block_size(self, value): + deref(self.options).block_size = value + + @property + def skip_rows(self): + """ + The number of rows to skip before the column names (if any) + and the CSV data. + See `skip_rows_after_names` for interaction description + """ + return deref(self.options).skip_rows + + @skip_rows.setter + def skip_rows(self, value): + deref(self.options).skip_rows = value + + @property + def column_names(self): + """ + The column names of the target table. If empty, fall back on + `autogenerate_column_names`. + """ + return [frombytes(s) for s in deref(self.options).column_names] + + @column_names.setter + def column_names(self, value): + deref(self.options).column_names.clear() + for item in value: + deref(self.options).column_names.push_back(tobytes(item)) + + @property + def autogenerate_column_names(self): + """ + Whether to autogenerate column names if `column_names` is empty. + If true, column names will be of the form "f0", "f1"... + If false, column names will be read from the first CSV row + after `skip_rows`. + """ + return deref(self.options).autogenerate_column_names + + @autogenerate_column_names.setter + def autogenerate_column_names(self, value): + deref(self.options).autogenerate_column_names = value + + @property + def skip_rows_after_names(self): + """ + The number of rows to skip after the column names. + This number can be larger than the number of rows in one + block, and empty rows are counted. + The order of application is as follows: + - `skip_rows` is applied (if non-zero); + - column names aread (unless `column_names` is set); + - `skip_rows_after_names` is applied (if non-zero). + """ + return deref(self.options).skip_rows_after_names + + @skip_rows_after_names.setter + def skip_rows_after_names(self, value): + deref(self.options).skip_rows_after_names = value + + def validate(self): + check_status(deref(self.options).Validate()) + + def equals(self, ReadOptions other): + return ( + self.use_threads == other.use_threads and + self.block_size == other.block_size and + self.skip_rows == other.skip_rows and + self.column_names == other.column_names and + self.autogenerate_column_names == + other.autogenerate_column_names and + self.encoding == other.encoding and + self.skip_rows_after_names == other.skip_rows_after_names + ) + + @staticmethod + cdef ReadOptions wrap(CCSVReadOptions options): + out = ReadOptions() + out.options.reset(new CCSVReadOptions(move(options))) + out.encoding = 'utf8' # No way to know this + return out + + def __getstate__(self): + return (self.use_threads, self.block_size, self.skip_rows, + self.column_names, self.autogenerate_column_names, + self.encoding, self.skip_rows_after_names) + + def __setstate__(self, state): + (self.use_threads, self.block_size, self.skip_rows, + self.column_names, self.autogenerate_column_names, + self.encoding, self.skip_rows_after_names) = state + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + +cdef class ParseOptions(_Weakrefable): + """ + Options for parsing CSV files. + + Parameters + ---------- + delimiter : 1-character string, optional (default ',') + The character delimiting individual cells in the CSV data. + quote_char : 1-character string or False, optional (default '"') + The character used optionally for quoting CSV values + (False if quoting is not allowed). + double_quote : bool, optional (default True) + Whether two quotes in a quoted CSV value denote a single quote + in the data. + escape_char : 1-character string or False, optional (default False) + The character used optionally for escaping special characters + (False if escaping is not allowed). + newlines_in_values : bool, optional (default False) + Whether newline characters are allowed in CSV values. + Setting this to True reduces the performance of multi-threaded + CSV reading. + ignore_empty_lines : bool, optional (default True) + Whether empty lines are ignored in CSV input. + If False, an empty line is interpreted as containing a single empty + value (assuming a one-column CSV file). + """ + __slots__ = () + + def __cinit__(self, *argw, **kwargs): + self.options.reset(new CCSVParseOptions(CCSVParseOptions.Defaults())) + + def __init__(self, *, delimiter=None, quote_char=None, double_quote=None, + escape_char=None, newlines_in_values=None, + ignore_empty_lines=None): + if delimiter is not None: + self.delimiter = delimiter + if quote_char is not None: + self.quote_char = quote_char + if double_quote is not None: + self.double_quote = double_quote + if escape_char is not None: + self.escape_char = escape_char + if newlines_in_values is not None: + self.newlines_in_values = newlines_in_values + if ignore_empty_lines is not None: + self.ignore_empty_lines = ignore_empty_lines + + @property + def delimiter(self): + """ + The character delimiting individual cells in the CSV data. + """ + return chr(deref(self.options).delimiter) + + @delimiter.setter + def delimiter(self, value): + deref(self.options).delimiter = _single_char(value) + + @property + def quote_char(self): + """ + The character used optionally for quoting CSV values + (False if quoting is not allowed). + """ + if deref(self.options).quoting: + return chr(deref(self.options).quote_char) + else: + return False + + @quote_char.setter + def quote_char(self, value): + if value is False: + deref(self.options).quoting = False + else: + deref(self.options).quote_char = _single_char(value) + deref(self.options).quoting = True + + @property + def double_quote(self): + """ + Whether two quotes in a quoted CSV value denote a single quote + in the data. + """ + return deref(self.options).double_quote + + @double_quote.setter + def double_quote(self, value): + deref(self.options).double_quote = value + + @property + def escape_char(self): + """ + The character used optionally for escaping special characters + (False if escaping is not allowed). + """ + if deref(self.options).escaping: + return chr(deref(self.options).escape_char) + else: + return False + + @escape_char.setter + def escape_char(self, value): + if value is False: + deref(self.options).escaping = False + else: + deref(self.options).escape_char = _single_char(value) + deref(self.options).escaping = True + + @property + def newlines_in_values(self): + """ + Whether newline characters are allowed in CSV values. + Setting this to True reduces the performance of multi-threaded + CSV reading. + """ + return deref(self.options).newlines_in_values + + @newlines_in_values.setter + def newlines_in_values(self, value): + deref(self.options).newlines_in_values = value + + @property + def ignore_empty_lines(self): + """ + Whether empty lines are ignored in CSV input. + If False, an empty line is interpreted as containing a single empty + value (assuming a one-column CSV file). + """ + return deref(self.options).ignore_empty_lines + + @ignore_empty_lines.setter + def ignore_empty_lines(self, value): + deref(self.options).ignore_empty_lines = value + + def validate(self): + check_status(deref(self.options).Validate()) + + def equals(self, ParseOptions other): + return ( + self.delimiter == other.delimiter and + self.quote_char == other.quote_char and + self.double_quote == other.double_quote and + self.escape_char == other.escape_char and + self.newlines_in_values == other.newlines_in_values and + self.ignore_empty_lines == other.ignore_empty_lines + ) + + @staticmethod + cdef ParseOptions wrap(CCSVParseOptions options): + out = ParseOptions() + out.options.reset(new CCSVParseOptions(move(options))) + return out + + def __getstate__(self): + return (self.delimiter, self.quote_char, self.double_quote, + self.escape_char, self.newlines_in_values, + self.ignore_empty_lines) + + def __setstate__(self, state): + (self.delimiter, self.quote_char, self.double_quote, + self.escape_char, self.newlines_in_values, + self.ignore_empty_lines) = state + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + +cdef class _ISO8601(_Weakrefable): + """ + A special object indicating ISO-8601 parsing. + """ + __slots__ = () + + def __str__(self): + return 'ISO8601' + + def __eq__(self, other): + return isinstance(other, _ISO8601) + + +ISO8601 = _ISO8601() + + +cdef class ConvertOptions(_Weakrefable): + """ + Options for converting CSV data. + + Parameters + ---------- + check_utf8 : bool, optional (default True) + Whether to check UTF8 validity of string columns. + column_types : pa.Schema or dict, optional + Explicitly map column names to column types. Passing this argument + disables type inference on the defined columns. + null_values : list, optional + A sequence of strings that denote nulls in the data + (defaults are appropriate in most cases). Note that by default, + string columns are not checked for null values. To enable + null checking for those, specify ``strings_can_be_null=True``. + true_values : list, optional + A sequence of strings that denote true booleans in the data + (defaults are appropriate in most cases). + false_values : list, optional + A sequence of strings that denote false booleans in the data + (defaults are appropriate in most cases). + decimal_point : 1-character string, optional (default '.') + The character used as decimal point in floating-point and decimal + data. + timestamp_parsers : list, optional + A sequence of strptime()-compatible format strings, tried in order + when attempting to infer or convert timestamp values (the special + value ISO8601() can also be given). By default, a fast built-in + ISO-8601 parser is used. + strings_can_be_null : bool, optional (default False) + Whether string / binary columns can have null values. + If true, then strings in null_values are considered null for + string columns. + If false, then all strings are valid string values. + quoted_strings_can_be_null : bool, optional (default True) + Whether quoted values can be null. + If true, then strings in "null_values" are also considered null + when they appear quoted in the CSV file. Otherwise, quoted values + are never considered null. + auto_dict_encode : bool, optional (default False) + Whether to try to automatically dict-encode string / binary data. + If true, then when type inference detects a string or binary column, + it it dict-encoded up to `auto_dict_max_cardinality` distinct values + (per chunk), after which it switches to regular encoding. + This setting is ignored for non-inferred columns (those in + `column_types`). + auto_dict_max_cardinality : int, optional + The maximum dictionary cardinality for `auto_dict_encode`. + This value is per chunk. + include_columns : list, optional + The names of columns to include in the Table. + If empty, the Table will include all columns from the CSV file. + If not empty, only these columns will be included, in this order. + include_missing_columns : bool, optional (default False) + If false, columns in `include_columns` but not in the CSV file will + error out. + If true, columns in `include_columns` but not in the CSV file will + produce a column of nulls (whose type is selected using + `column_types`, or null by default). + This option is ignored if `include_columns` is empty. + """ + # Avoid mistakingly creating attributes + __slots__ = () + + def __cinit__(self, *argw, **kwargs): + self.options.reset( + new CCSVConvertOptions(CCSVConvertOptions.Defaults())) + + def __init__(self, *, check_utf8=None, column_types=None, null_values=None, + true_values=None, false_values=None, decimal_point=None, + strings_can_be_null=None, quoted_strings_can_be_null=None, + include_columns=None, include_missing_columns=None, + auto_dict_encode=None, auto_dict_max_cardinality=None, + timestamp_parsers=None): + if check_utf8 is not None: + self.check_utf8 = check_utf8 + if column_types is not None: + self.column_types = column_types + if null_values is not None: + self.null_values = null_values + if true_values is not None: + self.true_values = true_values + if false_values is not None: + self.false_values = false_values + if decimal_point is not None: + self.decimal_point = decimal_point + if strings_can_be_null is not None: + self.strings_can_be_null = strings_can_be_null + if quoted_strings_can_be_null is not None: + self.quoted_strings_can_be_null = quoted_strings_can_be_null + if include_columns is not None: + self.include_columns = include_columns + if include_missing_columns is not None: + self.include_missing_columns = include_missing_columns + if auto_dict_encode is not None: + self.auto_dict_encode = auto_dict_encode + if auto_dict_max_cardinality is not None: + self.auto_dict_max_cardinality = auto_dict_max_cardinality + if timestamp_parsers is not None: + self.timestamp_parsers = timestamp_parsers + + @property + def check_utf8(self): + """ + Whether to check UTF8 validity of string columns. + """ + return deref(self.options).check_utf8 + + @check_utf8.setter + def check_utf8(self, value): + deref(self.options).check_utf8 = value + + @property + def strings_can_be_null(self): + """ + Whether string / binary columns can have null values. + """ + return deref(self.options).strings_can_be_null + + @strings_can_be_null.setter + def strings_can_be_null(self, value): + deref(self.options).strings_can_be_null = value + + @property + def quoted_strings_can_be_null(self): + """ + Whether quoted values can be null. + """ + return deref(self.options).quoted_strings_can_be_null + + @quoted_strings_can_be_null.setter + def quoted_strings_can_be_null(self, value): + deref(self.options).quoted_strings_can_be_null = value + + @property + def column_types(self): + """ + Explicitly map column names to column types. + """ + d = {frombytes(item.first): pyarrow_wrap_data_type(item.second) + for item in deref(self.options).column_types} + return d + + @column_types.setter + def column_types(self, value): + cdef: + shared_ptr[CDataType] typ + + if isinstance(value, Mapping): + value = value.items() + + deref(self.options).column_types.clear() + for item in value: + if isinstance(item, Field): + k = item.name + v = item.type + else: + k, v = item + typ = pyarrow_unwrap_data_type(ensure_type(v)) + assert typ != NULL + deref(self.options).column_types[tobytes(k)] = typ + + @property + def null_values(self): + """ + A sequence of strings that denote nulls in the data. + """ + return [frombytes(x) for x in deref(self.options).null_values] + + @null_values.setter + def null_values(self, value): + deref(self.options).null_values = [tobytes(x) for x in value] + + @property + def true_values(self): + """ + A sequence of strings that denote true booleans in the data. + """ + return [frombytes(x) for x in deref(self.options).true_values] + + @true_values.setter + def true_values(self, value): + deref(self.options).true_values = [tobytes(x) for x in value] + + @property + def false_values(self): + """ + A sequence of strings that denote false booleans in the data. + """ + return [frombytes(x) for x in deref(self.options).false_values] + + @false_values.setter + def false_values(self, value): + deref(self.options).false_values = [tobytes(x) for x in value] + + @property + def decimal_point(self): + """ + The character used as decimal point in floating-point and decimal + data. + """ + return chr(deref(self.options).decimal_point) + + @decimal_point.setter + def decimal_point(self, value): + deref(self.options).decimal_point = _single_char(value) + + @property + def auto_dict_encode(self): + """ + Whether to try to automatically dict-encode string / binary data. + """ + return deref(self.options).auto_dict_encode + + @auto_dict_encode.setter + def auto_dict_encode(self, value): + deref(self.options).auto_dict_encode = value + + @property + def auto_dict_max_cardinality(self): + """ + The maximum dictionary cardinality for `auto_dict_encode`. + + This value is per chunk. + """ + return deref(self.options).auto_dict_max_cardinality + + @auto_dict_max_cardinality.setter + def auto_dict_max_cardinality(self, value): + deref(self.options).auto_dict_max_cardinality = value + + @property + def include_columns(self): + """ + The names of columns to include in the Table. + + If empty, the Table will include all columns from the CSV file. + If not empty, only these columns will be included, in this order. + """ + return [frombytes(s) for s in deref(self.options).include_columns] + + @include_columns.setter + def include_columns(self, value): + deref(self.options).include_columns.clear() + for item in value: + deref(self.options).include_columns.push_back(tobytes(item)) + + @property + def include_missing_columns(self): + """ + If false, columns in `include_columns` but not in the CSV file will + error out. + If true, columns in `include_columns` but not in the CSV file will + produce a null column (whose type is selected using `column_types`, + or null by default). + This option is ignored if `include_columns` is empty. + """ + return deref(self.options).include_missing_columns + + @include_missing_columns.setter + def include_missing_columns(self, value): + deref(self.options).include_missing_columns = value + + @property + def timestamp_parsers(self): + """ + A sequence of strptime()-compatible format strings, tried in order + when attempting to infer or convert timestamp values (the special + value ISO8601() can also be given). By default, a fast built-in + ISO-8601 parser is used. + """ + cdef: + shared_ptr[CTimestampParser] c_parser + c_string kind + + parsers = [] + for c_parser in deref(self.options).timestamp_parsers: + kind = deref(c_parser).kind() + if kind == b'strptime': + parsers.append(frombytes(deref(c_parser).format())) + else: + assert kind == b'iso8601' + parsers.append(ISO8601) + + return parsers + + @timestamp_parsers.setter + def timestamp_parsers(self, value): + cdef: + vector[shared_ptr[CTimestampParser]] c_parsers + + for v in value: + if isinstance(v, str): + c_parsers.push_back(CTimestampParser.MakeStrptime(tobytes(v))) + elif v == ISO8601: + c_parsers.push_back(CTimestampParser.MakeISO8601()) + else: + raise TypeError("Expected list of str or ISO8601 objects") + + deref(self.options).timestamp_parsers = move(c_parsers) + + @staticmethod + cdef ConvertOptions wrap(CCSVConvertOptions options): + out = ConvertOptions() + out.options.reset(new CCSVConvertOptions(move(options))) + return out + + def validate(self): + check_status(deref(self.options).Validate()) + + def equals(self, ConvertOptions other): + return ( + self.check_utf8 == other.check_utf8 and + self.column_types == other.column_types and + self.null_values == other.null_values and + self.true_values == other.true_values and + self.false_values == other.false_values and + self.decimal_point == other.decimal_point and + self.timestamp_parsers == other.timestamp_parsers and + self.strings_can_be_null == other.strings_can_be_null and + self.quoted_strings_can_be_null == + other.quoted_strings_can_be_null and + self.auto_dict_encode == other.auto_dict_encode and + self.auto_dict_max_cardinality == + other.auto_dict_max_cardinality and + self.include_columns == other.include_columns and + self.include_missing_columns == other.include_missing_columns + ) + + def __getstate__(self): + return (self.check_utf8, self.column_types, self.null_values, + self.true_values, self.false_values, self.decimal_point, + self.timestamp_parsers, self.strings_can_be_null, + self.quoted_strings_can_be_null, self.auto_dict_encode, + self.auto_dict_max_cardinality, self.include_columns, + self.include_missing_columns) + + def __setstate__(self, state): + (self.check_utf8, self.column_types, self.null_values, + self.true_values, self.false_values, self.decimal_point, + self.timestamp_parsers, self.strings_can_be_null, + self.quoted_strings_can_be_null, self.auto_dict_encode, + self.auto_dict_max_cardinality, self.include_columns, + self.include_missing_columns) = state + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + +cdef _get_reader(input_file, ReadOptions read_options, + shared_ptr[CInputStream]* out): + use_memory_map = False + get_input_stream(input_file, use_memory_map, out) + if read_options is not None: + out[0] = native_transcoding_input_stream(out[0], + read_options.encoding, + 'utf8') + + +cdef _get_read_options(ReadOptions read_options, CCSVReadOptions* out): + if read_options is None: + out[0] = CCSVReadOptions.Defaults() + else: + out[0] = deref(read_options.options) + + +cdef _get_parse_options(ParseOptions parse_options, CCSVParseOptions* out): + if parse_options is None: + out[0] = CCSVParseOptions.Defaults() + else: + out[0] = deref(parse_options.options) + + +cdef _get_convert_options(ConvertOptions convert_options, + CCSVConvertOptions* out): + if convert_options is None: + out[0] = CCSVConvertOptions.Defaults() + else: + out[0] = deref(convert_options.options) + + +cdef class CSVStreamingReader(RecordBatchReader): + """An object that reads record batches incrementally from a CSV file. + + Should not be instantiated directly by user code. + """ + cdef readonly: + Schema schema + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, " + "use pyarrow.csv.open_csv() instead." + .format(self.__class__.__name__)) + + # Note about cancellation: we cannot create a SignalStopHandler + # by default here, as several CSVStreamingReader instances may be + # created (including by the same thread). Handling cancellation + # would require having the user pass the SignalStopHandler. + # (in addition to solving ARROW-11853) + + cdef _open(self, shared_ptr[CInputStream] stream, + CCSVReadOptions c_read_options, + CCSVParseOptions c_parse_options, + CCSVConvertOptions c_convert_options, + MemoryPool memory_pool): + cdef: + shared_ptr[CSchema] c_schema + CIOContext io_context + + io_context = CIOContext(maybe_unbox_memory_pool(memory_pool)) + + with nogil: + self.reader = <shared_ptr[CRecordBatchReader]> GetResultValue( + CCSVStreamingReader.Make( + io_context, stream, + move(c_read_options), move(c_parse_options), + move(c_convert_options))) + c_schema = self.reader.get().schema() + + self.schema = pyarrow_wrap_schema(c_schema) + + +def read_csv(input_file, read_options=None, parse_options=None, + convert_options=None, MemoryPool memory_pool=None): + """ + Read a Table from a stream of CSV data. + + Parameters + ---------- + input_file : string, path or file-like object + The location of CSV data. If a string or path, and if it ends + with a recognized compressed file extension (e.g. ".gz" or ".bz2"), + the data is automatically decompressed when reading. + read_options : pyarrow.csv.ReadOptions, optional + Options for the CSV reader (see pyarrow.csv.ReadOptions constructor + for defaults) + parse_options : pyarrow.csv.ParseOptions, optional + Options for the CSV parser + (see pyarrow.csv.ParseOptions constructor for defaults) + convert_options : pyarrow.csv.ConvertOptions, optional + Options for converting CSV data + (see pyarrow.csv.ConvertOptions constructor for defaults) + memory_pool : MemoryPool, optional + Pool to allocate Table memory from + + Returns + ------- + :class:`pyarrow.Table` + Contents of the CSV file as a in-memory table. + """ + cdef: + shared_ptr[CInputStream] stream + CCSVReadOptions c_read_options + CCSVParseOptions c_parse_options + CCSVConvertOptions c_convert_options + CIOContext io_context + shared_ptr[CCSVReader] reader + shared_ptr[CTable] table + + _get_reader(input_file, read_options, &stream) + _get_read_options(read_options, &c_read_options) + _get_parse_options(parse_options, &c_parse_options) + _get_convert_options(convert_options, &c_convert_options) + + with SignalStopHandler() as stop_handler: + io_context = CIOContext( + maybe_unbox_memory_pool(memory_pool), + (<StopToken> stop_handler.stop_token).stop_token) + reader = GetResultValue(CCSVReader.Make( + io_context, stream, + c_read_options, c_parse_options, c_convert_options)) + + with nogil: + table = GetResultValue(reader.get().Read()) + + return pyarrow_wrap_table(table) + + +def open_csv(input_file, read_options=None, parse_options=None, + convert_options=None, MemoryPool memory_pool=None): + """ + Open a streaming reader of CSV data. + + Reading using this function is always single-threaded. + + Parameters + ---------- + input_file : string, path or file-like object + The location of CSV data. If a string or path, and if it ends + with a recognized compressed file extension (e.g. ".gz" or ".bz2"), + the data is automatically decompressed when reading. + read_options : pyarrow.csv.ReadOptions, optional + Options for the CSV reader (see pyarrow.csv.ReadOptions constructor + for defaults) + parse_options : pyarrow.csv.ParseOptions, optional + Options for the CSV parser + (see pyarrow.csv.ParseOptions constructor for defaults) + convert_options : pyarrow.csv.ConvertOptions, optional + Options for converting CSV data + (see pyarrow.csv.ConvertOptions constructor for defaults) + memory_pool : MemoryPool, optional + Pool to allocate Table memory from + + Returns + ------- + :class:`pyarrow.csv.CSVStreamingReader` + """ + cdef: + shared_ptr[CInputStream] stream + CCSVReadOptions c_read_options + CCSVParseOptions c_parse_options + CCSVConvertOptions c_convert_options + CSVStreamingReader reader + + _get_reader(input_file, read_options, &stream) + _get_read_options(read_options, &c_read_options) + _get_parse_options(parse_options, &c_parse_options) + _get_convert_options(convert_options, &c_convert_options) + + reader = CSVStreamingReader.__new__(CSVStreamingReader) + reader._open(stream, move(c_read_options), move(c_parse_options), + move(c_convert_options), memory_pool) + return reader + + +cdef class WriteOptions(_Weakrefable): + """ + Options for writing CSV files. + + Parameters + ---------- + include_header : bool, optional (default True) + Whether to write an initial header line with column names + batch_size : int, optional (default 1024) + How many rows to process together when converting and writing + CSV data + """ + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, *, include_header=None, batch_size=None): + self.options.reset(new CCSVWriteOptions(CCSVWriteOptions.Defaults())) + if include_header is not None: + self.include_header = include_header + if batch_size is not None: + self.batch_size = batch_size + + @property + def include_header(self): + """ + Whether to write an initial header line with column names. + """ + return deref(self.options).include_header + + @include_header.setter + def include_header(self, value): + deref(self.options).include_header = value + + @property + def batch_size(self): + """ + How many rows to process together when converting and writing + CSV data. + """ + return deref(self.options).batch_size + + @batch_size.setter + def batch_size(self, value): + deref(self.options).batch_size = value + + @staticmethod + cdef WriteOptions wrap(CCSVWriteOptions options): + out = WriteOptions() + out.options.reset(new CCSVWriteOptions(move(options))) + return out + + def validate(self): + check_status(self.options.get().Validate()) + + +cdef _get_write_options(WriteOptions write_options, CCSVWriteOptions* out): + if write_options is None: + out[0] = CCSVWriteOptions.Defaults() + else: + out[0] = deref(write_options.options) + + +def write_csv(data, output_file, write_options=None, + MemoryPool memory_pool=None): + """ + Write record batch or table to a CSV file. + + Parameters + ---------- + data : pyarrow.RecordBatch or pyarrow.Table + The data to write. + output_file : string, path, pyarrow.NativeFile, or file-like object + The location where to write the CSV data. + write_options : pyarrow.csv.WriteOptions + Options to configure writing the CSV data. + memory_pool : MemoryPool, optional + Pool for temporary allocations. + """ + cdef: + shared_ptr[COutputStream] stream + CCSVWriteOptions c_write_options + CMemoryPool* c_memory_pool + CRecordBatch* batch + CTable* table + _get_write_options(write_options, &c_write_options) + + get_writer(output_file, &stream) + c_memory_pool = maybe_unbox_memory_pool(memory_pool) + c_write_options.io_context = CIOContext(c_memory_pool) + if isinstance(data, RecordBatch): + batch = pyarrow_unwrap_batch(data).get() + with nogil: + check_status(WriteCSV(deref(batch), c_write_options, stream.get())) + elif isinstance(data, Table): + table = pyarrow_unwrap_table(data).get() + with nogil: + check_status(WriteCSV(deref(table), c_write_options, stream.get())) + else: + raise TypeError(f"Expected Table or RecordBatch, got '{type(data)}'") + + +cdef class CSVWriter(_CRecordBatchWriter): + """ + Writer to create a CSV file. + + Parameters + ---------- + sink : str, path, pyarrow.OutputStream or file-like object + The location where to write the CSV data. + schema : pyarrow.Schema + The schema of the data to be written. + write_options : pyarrow.csv.WriteOptions + Options to configure writing the CSV data. + memory_pool : MemoryPool, optional + Pool for temporary allocations. + """ + + def __init__(self, sink, Schema schema, *, + WriteOptions write_options=None, MemoryPool memory_pool=None): + cdef: + shared_ptr[COutputStream] c_stream + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + CCSVWriteOptions c_write_options + CMemoryPool* c_memory_pool = maybe_unbox_memory_pool(memory_pool) + _get_write_options(write_options, &c_write_options) + c_write_options.io_context = CIOContext(c_memory_pool) + get_writer(sink, &c_stream) + with nogil: + self.writer = GetResultValue(MakeCSVWriter( + c_stream, c_schema, c_write_options)) diff --git a/src/arrow/python/pyarrow/_cuda.pxd b/src/arrow/python/pyarrow/_cuda.pxd new file mode 100644 index 000000000..6acb8826d --- /dev/null +++ b/src/arrow/python/pyarrow/_cuda.pxd @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.lib cimport * +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_cuda cimport * + + +cdef class Context(_Weakrefable): + cdef: + shared_ptr[CCudaContext] context + int device_number + + cdef void init(self, const shared_ptr[CCudaContext]& ctx) + + +cdef class IpcMemHandle(_Weakrefable): + cdef: + shared_ptr[CCudaIpcMemHandle] handle + + cdef void init(self, shared_ptr[CCudaIpcMemHandle]& h) + + +cdef class CudaBuffer(Buffer): + cdef: + shared_ptr[CCudaBuffer] cuda_buffer + object base + + cdef void init_cuda(self, + const shared_ptr[CCudaBuffer]& buffer, + object base) + + +cdef class HostBuffer(Buffer): + cdef: + shared_ptr[CCudaHostBuffer] host_buffer + + cdef void init_host(self, const shared_ptr[CCudaHostBuffer]& buffer) + + +cdef class BufferReader(NativeFile): + cdef: + CCudaBufferReader* reader + CudaBuffer buffer + + +cdef class BufferWriter(NativeFile): + cdef: + CCudaBufferWriter* writer + CudaBuffer buffer diff --git a/src/arrow/python/pyarrow/_cuda.pyx b/src/arrow/python/pyarrow/_cuda.pyx new file mode 100644 index 000000000..1b66b9508 --- /dev/null +++ b/src/arrow/python/pyarrow/_cuda.pyx @@ -0,0 +1,1060 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from pyarrow.lib import tobytes +from pyarrow.lib cimport * +from pyarrow.includes.libarrow_cuda cimport * +from pyarrow.lib import py_buffer, allocate_buffer, as_buffer, ArrowTypeError +from pyarrow.util import get_contiguous_span +cimport cpython as cp + + +cdef class Context(_Weakrefable): + """ + CUDA driver context. + """ + + def __init__(self, *args, **kwargs): + """ + Create a CUDA driver context for a particular device. + + If a CUDA context handle is passed, it is wrapped, otherwise + a default CUDA context for the given device is requested. + + Parameters + ---------- + device_number : int (default 0) + Specify the GPU device for which the CUDA driver context is + requested. + handle : int, optional + Specify CUDA handle for a shared context that has been created + by another library. + """ + # This method exposed because autodoc doesn't pick __cinit__ + + def __cinit__(self, int device_number=0, uintptr_t handle=0): + cdef CCudaDeviceManager* manager + manager = GetResultValue(CCudaDeviceManager.Instance()) + cdef int n = manager.num_devices() + if device_number >= n or device_number < 0: + self.context.reset() + raise ValueError('device_number argument must be ' + 'non-negative less than %s' % (n)) + if handle == 0: + self.context = GetResultValue(manager.GetContext(device_number)) + else: + self.context = GetResultValue(manager.GetSharedContext( + device_number, <void*>handle)) + self.device_number = device_number + + @staticmethod + def from_numba(context=None): + """ + Create a Context instance from a Numba CUDA context. + + Parameters + ---------- + context : {numba.cuda.cudadrv.driver.Context, None} + A Numba CUDA context instance. + If None, the current Numba context is used. + + Returns + ------- + shared_context : pyarrow.cuda.Context + Context instance. + """ + if context is None: + import numba.cuda + context = numba.cuda.current_context() + return Context(device_number=context.device.id, + handle=context.handle.value) + + def to_numba(self): + """ + Convert Context to a Numba CUDA context. + + Returns + ------- + context : numba.cuda.cudadrv.driver.Context + Numba CUDA context instance. + """ + import ctypes + import numba.cuda + device = numba.cuda.gpus[self.device_number] + handle = ctypes.c_void_p(self.handle) + context = numba.cuda.cudadrv.driver.Context(device, handle) + + class DummyPendingDeallocs(object): + # Context is managed by pyarrow + def add_item(self, *args, **kwargs): + pass + + context.deallocations = DummyPendingDeallocs() + return context + + @staticmethod + def get_num_devices(): + """ Return the number of GPU devices. + """ + cdef CCudaDeviceManager* manager + manager = GetResultValue(CCudaDeviceManager.Instance()) + return manager.num_devices() + + @property + def device_number(self): + """ Return context device number. + """ + return self.device_number + + @property + def handle(self): + """ Return pointer to context handle. + """ + return <uintptr_t>self.context.get().handle() + + cdef void init(self, const shared_ptr[CCudaContext]& ctx): + self.context = ctx + + def synchronize(self): + """Blocks until the device has completed all preceding requested + tasks. + """ + check_status(self.context.get().Synchronize()) + + @property + def bytes_allocated(self): + """Return the number of allocated bytes. + """ + return self.context.get().bytes_allocated() + + def get_device_address(self, uintptr_t address): + """Return the device address that is reachable from kernels running in + the context + + Parameters + ---------- + address : int + Specify memory address value + + Returns + ------- + device_address : int + Device address accessible from device context + + Notes + ----- + The device address is defined as a memory address accessible + by device. While it is often a device memory address but it + can be also a host memory address, for instance, when the + memory is allocated as host memory (using cudaMallocHost or + cudaHostAlloc) or as managed memory (using cudaMallocManaged) + or the host memory is page-locked (using cudaHostRegister). + """ + return GetResultValue(self.context.get().GetDeviceAddress(address)) + + def new_buffer(self, int64_t nbytes): + """Return new device buffer. + + Parameters + ---------- + nbytes : int + Specify the number of bytes to be allocated. + + Returns + ------- + buf : CudaBuffer + Allocated buffer. + """ + cdef: + shared_ptr[CCudaBuffer] cudabuf + with nogil: + cudabuf = GetResultValue(self.context.get().Allocate(nbytes)) + return pyarrow_wrap_cudabuffer(cudabuf) + + def foreign_buffer(self, address, size, base=None): + """ + Create device buffer from address and size as a view. + + The caller is responsible for allocating and freeing the + memory. When `address==size==0` then a new zero-sized buffer + is returned. + + Parameters + ---------- + address : int + Specify the starting address of the buffer. The address can + refer to both device or host memory but it must be + accessible from device after mapping it with + `get_device_address` method. + size : int + Specify the size of device buffer in bytes. + base : {None, object} + Specify object that owns the referenced memory. + + Returns + ------- + cbuf : CudaBuffer + Device buffer as a view of device reachable memory. + + """ + if not address and size == 0: + return self.new_buffer(0) + cdef: + uintptr_t c_addr = self.get_device_address(address) + int64_t c_size = size + shared_ptr[CCudaBuffer] cudabuf + + cudabuf = GetResultValue(self.context.get().View( + <uint8_t*>c_addr, c_size)) + return pyarrow_wrap_cudabuffer_base(cudabuf, base) + + def open_ipc_buffer(self, ipc_handle): + """ Open existing CUDA IPC memory handle + + Parameters + ---------- + ipc_handle : IpcMemHandle + Specify opaque pointer to CUipcMemHandle (driver API). + + Returns + ------- + buf : CudaBuffer + referencing device buffer + """ + handle = pyarrow_unwrap_cudaipcmemhandle(ipc_handle) + cdef shared_ptr[CCudaBuffer] cudabuf + with nogil: + cudabuf = GetResultValue( + self.context.get().OpenIpcBuffer(handle.get()[0])) + return pyarrow_wrap_cudabuffer(cudabuf) + + def buffer_from_data(self, object data, int64_t offset=0, int64_t size=-1): + """Create device buffer and initialize with data. + + Parameters + ---------- + data : {CudaBuffer, HostBuffer, Buffer, array-like} + Specify data to be copied to device buffer. + offset : int + Specify the offset of input buffer for device data + buffering. Default: 0. + size : int + Specify the size of device buffer in bytes. Default: all + (starting from input offset) + + Returns + ------- + cbuf : CudaBuffer + Device buffer with copied data. + """ + is_host_data = not pyarrow_is_cudabuffer(data) + buf = as_buffer(data) if is_host_data else data + + bsize = buf.size + if offset < 0 or (bsize and offset >= bsize): + raise ValueError('offset argument is out-of-range') + if size < 0: + size = bsize - offset + elif offset + size > bsize: + raise ValueError( + 'requested larger slice than available in device buffer') + + if offset != 0 or size != bsize: + buf = buf.slice(offset, size) + + result = self.new_buffer(size) + if is_host_data: + result.copy_from_host(buf, position=0, nbytes=size) + else: + result.copy_from_device(buf, position=0, nbytes=size) + return result + + def buffer_from_object(self, obj): + """Create device buffer view of arbitrary object that references + device accessible memory. + + When the object contains a non-contiguous view of device + accessible memory then the returned device buffer will contain + contiguous view of the memory, that is, including the + intermediate data that is otherwise invisible to the input + object. + + Parameters + ---------- + obj : {object, Buffer, HostBuffer, CudaBuffer, ...} + Specify an object that holds (device or host) address that + can be accessed from device. This includes objects with + types defined in pyarrow.cuda as well as arbitrary objects + that implement the CUDA array interface as defined by numba. + + Returns + ------- + cbuf : CudaBuffer + Device buffer as a view of device accessible memory. + + """ + if isinstance(obj, HostBuffer): + return self.foreign_buffer(obj.address, obj.size, base=obj) + elif isinstance(obj, Buffer): + return CudaBuffer.from_buffer(obj) + elif isinstance(obj, CudaBuffer): + return obj + elif hasattr(obj, '__cuda_array_interface__'): + desc = obj.__cuda_array_interface__ + addr = desc['data'][0] + if addr is None: + return self.new_buffer(0) + import numpy as np + start, end = get_contiguous_span( + desc['shape'], desc.get('strides'), + np.dtype(desc['typestr']).itemsize) + return self.foreign_buffer(addr + start, end - start, base=obj) + raise ArrowTypeError('cannot create device buffer view from' + ' `%s` object' % (type(obj))) + + +cdef class IpcMemHandle(_Weakrefable): + """A serializable container for a CUDA IPC handle. + """ + cdef void init(self, shared_ptr[CCudaIpcMemHandle]& h): + self.handle = h + + @staticmethod + def from_buffer(Buffer opaque_handle): + """Create IpcMemHandle from opaque buffer (e.g. from another + process) + + Parameters + ---------- + opaque_handle : + a CUipcMemHandle as a const void* + + Results + ------- + ipc_handle : IpcMemHandle + """ + c_buf = pyarrow_unwrap_buffer(opaque_handle) + cdef: + shared_ptr[CCudaIpcMemHandle] handle + + handle = GetResultValue( + CCudaIpcMemHandle.FromBuffer(c_buf.get().data())) + return pyarrow_wrap_cudaipcmemhandle(handle) + + def serialize(self, pool=None): + """Write IpcMemHandle to a Buffer + + Parameters + ---------- + pool : {MemoryPool, None} + Specify a pool to allocate memory from + + Returns + ------- + buf : Buffer + The serialized buffer. + """ + cdef CMemoryPool* pool_ = maybe_unbox_memory_pool(pool) + cdef shared_ptr[CBuffer] buf + cdef CCudaIpcMemHandle* h = self.handle.get() + with nogil: + buf = GetResultValue(h.Serialize(pool_)) + return pyarrow_wrap_buffer(buf) + + +cdef class CudaBuffer(Buffer): + """An Arrow buffer with data located in a GPU device. + + To create a CudaBuffer instance, use Context.device_buffer(). + + The memory allocated in a CudaBuffer is freed when the buffer object + is deleted. + """ + + def __init__(self): + raise TypeError("Do not call CudaBuffer's constructor directly, use " + "`<pyarrow.Context instance>.device_buffer`" + " method instead.") + + cdef void init_cuda(self, + const shared_ptr[CCudaBuffer]& buffer, + object base): + self.cuda_buffer = buffer + self.init(<shared_ptr[CBuffer]> buffer) + self.base = base + + @staticmethod + def from_buffer(buf): + """ Convert back generic buffer into CudaBuffer + + Parameters + ---------- + buf : Buffer + Specify buffer containing CudaBuffer + + Returns + ------- + dbuf : CudaBuffer + Resulting device buffer. + """ + c_buf = pyarrow_unwrap_buffer(buf) + cuda_buffer = GetResultValue(CCudaBuffer.FromBuffer(c_buf)) + return pyarrow_wrap_cudabuffer(cuda_buffer) + + @staticmethod + def from_numba(mem): + """Create a CudaBuffer view from numba MemoryPointer instance. + + Parameters + ---------- + mem : numba.cuda.cudadrv.driver.MemoryPointer + + Returns + ------- + cbuf : CudaBuffer + Device buffer as a view of numba MemoryPointer. + """ + ctx = Context.from_numba(mem.context) + if mem.device_pointer.value is None and mem.size==0: + return ctx.new_buffer(0) + return ctx.foreign_buffer(mem.device_pointer.value, mem.size, base=mem) + + def to_numba(self): + """Return numba memory pointer of CudaBuffer instance. + """ + import ctypes + from numba.cuda.cudadrv.driver import MemoryPointer + return MemoryPointer(self.context.to_numba(), + pointer=ctypes.c_void_p(self.address), + size=self.size) + + cdef getitem(self, int64_t i): + return self.copy_to_host(position=i, nbytes=1)[0] + + def copy_to_host(self, int64_t position=0, int64_t nbytes=-1, + Buffer buf=None, + MemoryPool memory_pool=None, c_bool resizable=False): + """Copy memory from GPU device to CPU host + + Caller is responsible for ensuring that all tasks affecting + the memory are finished. Use + + `<CudaBuffer instance>.context.synchronize()` + + when needed. + + Parameters + ---------- + position : int + Specify the starting position of the source data in GPU + device buffer. Default: 0. + nbytes : int + Specify the number of bytes to copy. Default: -1 (all from + the position until host buffer is full). + buf : Buffer + Specify a pre-allocated output buffer in host. Default: None + (allocate new output buffer). + memory_pool : MemoryPool + resizable : bool + Specify extra arguments to allocate_buffer. Used only when + buf is None. + + Returns + ------- + buf : Buffer + Output buffer in host. + + """ + if position < 0 or (self.size and position > self.size) \ + or (self.size == 0 and position != 0): + raise ValueError('position argument is out-of-range') + cdef: + int64_t c_nbytes + if buf is None: + if nbytes < 0: + # copy all starting from position to new host buffer + c_nbytes = self.size - position + else: + if nbytes > self.size - position: + raise ValueError( + 'requested more to copy than available from ' + 'device buffer') + # copy nbytes starting from position to new host buffeer + c_nbytes = nbytes + buf = allocate_buffer(c_nbytes, memory_pool=memory_pool, + resizable=resizable) + else: + if nbytes < 0: + # copy all from position until given host buffer is full + c_nbytes = min(self.size - position, buf.size) + else: + if nbytes > buf.size: + raise ValueError( + 'requested copy does not fit into host buffer') + # copy nbytes from position to given host buffer + c_nbytes = nbytes + + cdef: + shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf) + int64_t c_position = position + with nogil: + check_status(self.cuda_buffer.get() + .CopyToHost(c_position, c_nbytes, + c_buf.get().mutable_data())) + return buf + + def copy_from_host(self, data, int64_t position=0, int64_t nbytes=-1): + """Copy data from host to device. + + The device buffer must be pre-allocated. + + Parameters + ---------- + data : {Buffer, array-like} + Specify data in host. It can be array-like that is valid + argument to py_buffer + position : int + Specify the starting position of the copy in device buffer. + Default: 0. + nbytes : int + Specify the number of bytes to copy. Default: -1 (all from + source until device buffer, starting from position, is full) + + Returns + ------- + nbytes : int + Number of bytes copied. + """ + if position < 0 or position > self.size: + raise ValueError('position argument is out-of-range') + cdef: + int64_t c_nbytes + buf = as_buffer(data) + + if nbytes < 0: + # copy from host buffer to device buffer starting from + # position until device buffer is full + c_nbytes = min(self.size - position, buf.size) + else: + if nbytes > buf.size: + raise ValueError( + 'requested more to copy than available from host buffer') + if nbytes > self.size - position: + raise ValueError( + 'requested more to copy than available in device buffer') + # copy nbytes from host buffer to device buffer starting + # from position + c_nbytes = nbytes + + cdef: + shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(buf) + int64_t c_position = position + with nogil: + check_status(self.cuda_buffer.get(). + CopyFromHost(c_position, c_buf.get().data(), + c_nbytes)) + return c_nbytes + + def copy_from_device(self, buf, int64_t position=0, int64_t nbytes=-1): + """Copy data from device to device. + + Parameters + ---------- + buf : CudaBuffer + Specify source device buffer. + position : int + Specify the starting position of the copy in device buffer. + Default: 0. + nbytes : int + Specify the number of bytes to copy. Default: -1 (all from + source until device buffer, starting from position, is full) + + Returns + ------- + nbytes : int + Number of bytes copied. + + """ + if position < 0 or position > self.size: + raise ValueError('position argument is out-of-range') + cdef: + int64_t c_nbytes + + if nbytes < 0: + # copy from source device buffer to device buffer starting + # from position until device buffer is full + c_nbytes = min(self.size - position, buf.size) + else: + if nbytes > buf.size: + raise ValueError( + 'requested more to copy than available from device buffer') + if nbytes > self.size - position: + raise ValueError( + 'requested more to copy than available in device buffer') + # copy nbytes from source device buffer to device buffer + # starting from position + c_nbytes = nbytes + + cdef: + shared_ptr[CCudaBuffer] c_buf = pyarrow_unwrap_cudabuffer(buf) + int64_t c_position = position + shared_ptr[CCudaContext] c_src_ctx = pyarrow_unwrap_cudacontext( + buf.context) + void* c_source_data = <void*>(c_buf.get().address()) + + if self.context.handle != buf.context.handle: + with nogil: + check_status(self.cuda_buffer.get(). + CopyFromAnotherDevice(c_src_ctx, c_position, + c_source_data, c_nbytes)) + else: + with nogil: + check_status(self.cuda_buffer.get(). + CopyFromDevice(c_position, c_source_data, + c_nbytes)) + return c_nbytes + + def export_for_ipc(self): + """ + Expose this device buffer as IPC memory which can be used in other + processes. + + After calling this function, this device memory will not be + freed when the CudaBuffer is destructed. + + Returns + ------- + ipc_handle : IpcMemHandle + The exported IPC handle + + """ + cdef shared_ptr[CCudaIpcMemHandle] handle + with nogil: + handle = GetResultValue(self.cuda_buffer.get().ExportForIpc()) + return pyarrow_wrap_cudaipcmemhandle(handle) + + @property + def context(self): + """Returns the CUDA driver context of this buffer. + """ + return pyarrow_wrap_cudacontext(self.cuda_buffer.get().context()) + + def slice(self, offset=0, length=None): + """Return slice of device buffer + + Parameters + ---------- + offset : int, default 0 + Specify offset from the start of device buffer to slice + length : int, default None + Specify the length of slice (default is until end of device + buffer starting from offset). If the length is larger than + the data available, the returned slice will have a size of + the available data starting from the offset. + + Returns + ------- + sliced : CudaBuffer + Zero-copy slice of device buffer. + + """ + if offset < 0 or (self.size and offset >= self.size): + raise ValueError('offset argument is out-of-range') + cdef int64_t offset_ = offset + cdef int64_t size + if length is None: + size = self.size - offset_ + elif offset + length <= self.size: + size = length + else: + size = self.size - offset + parent = pyarrow_unwrap_cudabuffer(self) + return pyarrow_wrap_cudabuffer(make_shared[CCudaBuffer](parent, + offset_, size)) + + def to_pybytes(self): + """Return device buffer content as Python bytes. + """ + return self.copy_to_host().to_pybytes() + + def __getbuffer__(self, cp.Py_buffer* buffer, int flags): + # Device buffer contains data pointers on the device. Hence, + # cannot support buffer protocol PEP-3118 for CudaBuffer. + raise BufferError('buffer protocol for device buffer not supported') + + +cdef class HostBuffer(Buffer): + """Device-accessible CPU memory created using cudaHostAlloc. + + To create a HostBuffer instance, use + + cuda.new_host_buffer(<nbytes>) + """ + + def __init__(self): + raise TypeError("Do not call HostBuffer's constructor directly," + " use `cuda.new_host_buffer` function instead.") + + cdef void init_host(self, const shared_ptr[CCudaHostBuffer]& buffer): + self.host_buffer = buffer + self.init(<shared_ptr[CBuffer]> buffer) + + @property + def size(self): + return self.host_buffer.get().size() + + +cdef class BufferReader(NativeFile): + """File interface for zero-copy read from CUDA buffers. + + Note: Read methods return pointers to device memory. This means + you must be careful using this interface with any Arrow code which + may expect to be able to do anything other than pointer arithmetic + on the returned buffers. + """ + + def __cinit__(self, CudaBuffer obj): + self.buffer = obj + self.reader = new CCudaBufferReader(self.buffer.buffer) + self.set_random_access_file( + shared_ptr[CRandomAccessFile](self.reader)) + self.is_readable = True + + def read_buffer(self, nbytes=None): + """Return a slice view of the underlying device buffer. + + The slice will start at the current reader position and will + have specified size in bytes. + + Parameters + ---------- + nbytes : int, default None + Specify the number of bytes to read. Default: None (read all + remaining bytes). + + Returns + ------- + cbuf : CudaBuffer + New device buffer. + + """ + cdef: + int64_t c_nbytes + int64_t bytes_read = 0 + shared_ptr[CCudaBuffer] output + + if nbytes is None: + c_nbytes = self.size() - self.tell() + else: + c_nbytes = nbytes + + with nogil: + output = static_pointer_cast[CCudaBuffer, CBuffer]( + GetResultValue(self.reader.Read(c_nbytes))) + + return pyarrow_wrap_cudabuffer(output) + + +cdef class BufferWriter(NativeFile): + """File interface for writing to CUDA buffers. + + By default writes are unbuffered. Use set_buffer_size to enable + buffering. + """ + + def __cinit__(self, CudaBuffer buffer): + self.buffer = buffer + self.writer = new CCudaBufferWriter(self.buffer.cuda_buffer) + self.set_output_stream(shared_ptr[COutputStream](self.writer)) + self.is_writable = True + + def writeat(self, int64_t position, object data): + """Write data to buffer starting from position. + + Parameters + ---------- + position : int + Specify device buffer position where the data will be + written. + data : array-like + Specify data, the data instance must implement buffer + protocol. + """ + cdef: + Buffer buf = as_buffer(data) + const uint8_t* c_data = buf.buffer.get().data() + int64_t c_size = buf.buffer.get().size() + + with nogil: + check_status(self.writer.WriteAt(position, c_data, c_size)) + + def flush(self): + """ Flush the buffer stream """ + with nogil: + check_status(self.writer.Flush()) + + def seek(self, int64_t position, int whence=0): + # TODO: remove this method after NativeFile.seek supports + # writable files. + cdef int64_t offset + + with nogil: + if whence == 0: + offset = position + elif whence == 1: + offset = GetResultValue(self.writer.Tell()) + offset = offset + position + else: + with gil: + raise ValueError("Invalid value of whence: {0}" + .format(whence)) + check_status(self.writer.Seek(offset)) + return self.tell() + + @property + def buffer_size(self): + """Returns size of host (CPU) buffer, 0 for unbuffered + """ + return self.writer.buffer_size() + + @buffer_size.setter + def buffer_size(self, int64_t buffer_size): + """Set CPU buffer size to limit calls to cudaMemcpy + + Parameters + ---------- + buffer_size : int + Specify the size of CPU buffer to allocate in bytes. + """ + with nogil: + check_status(self.writer.SetBufferSize(buffer_size)) + + @property + def num_bytes_buffered(self): + """Returns number of bytes buffered on host + """ + return self.writer.num_bytes_buffered() + +# Functions + + +def new_host_buffer(const int64_t size, int device=0): + """Return buffer with CUDA-accessible memory on CPU host + + Parameters + ---------- + size : int + Specify the number of bytes to be allocated. + device : int + Specify GPU device number. + + Returns + ------- + dbuf : HostBuffer + Allocated host buffer + """ + cdef shared_ptr[CCudaHostBuffer] buffer + with nogil: + buffer = GetResultValue(AllocateCudaHostBuffer(device, size)) + return pyarrow_wrap_cudahostbuffer(buffer) + + +def serialize_record_batch(object batch, object ctx): + """ Write record batch message to GPU device memory + + Parameters + ---------- + batch : RecordBatch + Record batch to write + ctx : Context + CUDA Context to allocate device memory from + + Returns + ------- + dbuf : CudaBuffer + device buffer which contains the record batch message + """ + cdef shared_ptr[CCudaBuffer] buffer + cdef CRecordBatch* batch_ = pyarrow_unwrap_batch(batch).get() + cdef CCudaContext* ctx_ = pyarrow_unwrap_cudacontext(ctx).get() + with nogil: + buffer = GetResultValue(CudaSerializeRecordBatch(batch_[0], ctx_)) + return pyarrow_wrap_cudabuffer(buffer) + + +def read_message(object source, pool=None): + """ Read Arrow IPC message located on GPU device + + Parameters + ---------- + source : {CudaBuffer, cuda.BufferReader} + Device buffer or reader of device buffer. + pool : MemoryPool (optional) + Pool to allocate CPU memory for the metadata + + Returns + ------- + message : Message + The deserialized message, body still on device + """ + cdef: + Message result = Message.__new__(Message) + cdef CMemoryPool* pool_ = maybe_unbox_memory_pool(pool) + if not isinstance(source, BufferReader): + reader = BufferReader(source) + with nogil: + result.message = move( + GetResultValue(ReadMessage(reader.reader, pool_))) + return result + + +def read_record_batch(object buffer, object schema, *, + DictionaryMemo dictionary_memo=None, pool=None): + """Construct RecordBatch referencing IPC message located on CUDA device. + + While the metadata is copied to host memory for deserialization, + the record batch data remains on the device. + + Parameters + ---------- + buffer : + Device buffer containing the complete IPC message + schema : Schema + The schema for the record batch + dictionary_memo : DictionaryMemo, optional + If message contains dictionaries, must pass a populated + DictionaryMemo + pool : MemoryPool (optional) + Pool to allocate metadata from + + Returns + ------- + batch : RecordBatch + Reconstructed record batch, with device pointers + + """ + cdef: + shared_ptr[CSchema] schema_ = pyarrow_unwrap_schema(schema) + shared_ptr[CCudaBuffer] buffer_ = pyarrow_unwrap_cudabuffer(buffer) + CDictionaryMemo temp_memo + CDictionaryMemo* arg_dict_memo + CMemoryPool* pool_ = maybe_unbox_memory_pool(pool) + shared_ptr[CRecordBatch] batch + + if dictionary_memo is not None: + arg_dict_memo = dictionary_memo.memo + else: + arg_dict_memo = &temp_memo + + with nogil: + batch = GetResultValue(CudaReadRecordBatch( + schema_, arg_dict_memo, buffer_, pool_)) + return pyarrow_wrap_batch(batch) + + +# Public API + + +cdef public api bint pyarrow_is_buffer(object buffer): + return isinstance(buffer, Buffer) + +# cudabuffer + +cdef public api bint pyarrow_is_cudabuffer(object buffer): + return isinstance(buffer, CudaBuffer) + + +cdef public api object \ + pyarrow_wrap_cudabuffer_base(const shared_ptr[CCudaBuffer]& buf, base): + cdef CudaBuffer result = CudaBuffer.__new__(CudaBuffer) + result.init_cuda(buf, base) + return result + + +cdef public api object \ + pyarrow_wrap_cudabuffer(const shared_ptr[CCudaBuffer]& buf): + cdef CudaBuffer result = CudaBuffer.__new__(CudaBuffer) + result.init_cuda(buf, None) + return result + + +cdef public api shared_ptr[CCudaBuffer] pyarrow_unwrap_cudabuffer(object obj): + if pyarrow_is_cudabuffer(obj): + return (<CudaBuffer>obj).cuda_buffer + raise TypeError('expected CudaBuffer instance, got %s' + % (type(obj).__name__)) + +# cudahostbuffer + +cdef public api bint pyarrow_is_cudahostbuffer(object buffer): + return isinstance(buffer, HostBuffer) + + +cdef public api object \ + pyarrow_wrap_cudahostbuffer(const shared_ptr[CCudaHostBuffer]& buf): + cdef HostBuffer result = HostBuffer.__new__(HostBuffer) + result.init_host(buf) + return result + + +cdef public api shared_ptr[CCudaHostBuffer] \ + pyarrow_unwrap_cudahostbuffer(object obj): + if pyarrow_is_cudahostbuffer(obj): + return (<HostBuffer>obj).host_buffer + raise TypeError('expected HostBuffer instance, got %s' + % (type(obj).__name__)) + +# cudacontext + +cdef public api bint pyarrow_is_cudacontext(object ctx): + return isinstance(ctx, Context) + + +cdef public api object \ + pyarrow_wrap_cudacontext(const shared_ptr[CCudaContext]& ctx): + cdef Context result = Context.__new__(Context) + result.init(ctx) + return result + + +cdef public api shared_ptr[CCudaContext] \ + pyarrow_unwrap_cudacontext(object obj): + if pyarrow_is_cudacontext(obj): + return (<Context>obj).context + raise TypeError('expected Context instance, got %s' + % (type(obj).__name__)) + +# cudaipcmemhandle + +cdef public api bint pyarrow_is_cudaipcmemhandle(object handle): + return isinstance(handle, IpcMemHandle) + + +cdef public api object \ + pyarrow_wrap_cudaipcmemhandle(shared_ptr[CCudaIpcMemHandle]& h): + cdef IpcMemHandle result = IpcMemHandle.__new__(IpcMemHandle) + result.init(h) + return result + + +cdef public api shared_ptr[CCudaIpcMemHandle] \ + pyarrow_unwrap_cudaipcmemhandle(object obj): + if pyarrow_is_cudaipcmemhandle(obj): + return (<IpcMemHandle>obj).handle + raise TypeError('expected IpcMemHandle instance, got %s' + % (type(obj).__name__)) diff --git a/src/arrow/python/pyarrow/_dataset.pxd b/src/arrow/python/pyarrow/_dataset.pxd new file mode 100644 index 000000000..875e13f87 --- /dev/null +++ b/src/arrow/python/pyarrow/_dataset.pxd @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +"""Dataset is currently unstable. APIs subject to change without notice.""" + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow_dataset cimport * +from pyarrow.lib cimport * + + +cdef class FragmentScanOptions(_Weakrefable): + + cdef: + shared_ptr[CFragmentScanOptions] wrapped + + cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp) + + @staticmethod + cdef wrap(const shared_ptr[CFragmentScanOptions]& sp) + + +cdef class FileFormat(_Weakrefable): + + cdef: + shared_ptr[CFileFormat] wrapped + CFileFormat* format + + cdef void init(self, const shared_ptr[CFileFormat]& sp) + + @staticmethod + cdef wrap(const shared_ptr[CFileFormat]& sp) + + cdef inline shared_ptr[CFileFormat] unwrap(self) + + cdef _set_default_fragment_scan_options(self, FragmentScanOptions options) diff --git a/src/arrow/python/pyarrow/_dataset.pyx b/src/arrow/python/pyarrow/_dataset.pyx new file mode 100644 index 000000000..459c3b8fb --- /dev/null +++ b/src/arrow/python/pyarrow/_dataset.pyx @@ -0,0 +1,3408 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +"""Dataset is currently unstable. APIs subject to change without notice.""" + +from cpython.object cimport Py_LT, Py_EQ, Py_GT, Py_LE, Py_NE, Py_GE +from cython.operator cimport dereference as deref + +import collections +import os +import warnings + +import pyarrow as pa +from pyarrow.lib cimport * +from pyarrow.lib import ArrowTypeError, frombytes, tobytes +from pyarrow.includes.libarrow_dataset cimport * +from pyarrow._fs cimport FileSystem, FileInfo, FileSelector +from pyarrow._csv cimport ( + ConvertOptions, ParseOptions, ReadOptions, WriteOptions) +from pyarrow.util import _is_iterable, _is_path_like, _stringify_path + +from pyarrow._parquet cimport ( + _create_writer_properties, _create_arrow_writer_properties, + FileMetaData, RowGroupMetaData, ColumnChunkMetaData +) + + +def _forbid_instantiation(klass, subclasses_instead=True): + msg = '{} is an abstract class thus cannot be initialized.'.format( + klass.__name__ + ) + if subclasses_instead: + subclasses = [cls.__name__ for cls in klass.__subclasses__] + msg += ' Use one of the subclasses instead: {}'.format( + ', '.join(subclasses) + ) + raise TypeError(msg) + + +_orc_fileformat = None +_orc_imported = False + + +def _get_orc_fileformat(): + """ + Import OrcFileFormat on first usage (to avoid circular import issue + when `pyarrow._dataset_orc` would be imported first) + """ + global _orc_fileformat + global _orc_imported + if not _orc_imported: + try: + from pyarrow._dataset_orc import OrcFileFormat + _orc_fileformat = OrcFileFormat + except ImportError as e: + _orc_fileformat = None + finally: + _orc_imported = True + return _orc_fileformat + + +cdef CFileSource _make_file_source(object file, FileSystem filesystem=None): + + cdef: + CFileSource c_source + shared_ptr[CFileSystem] c_filesystem + c_string c_path + shared_ptr[CRandomAccessFile] c_file + shared_ptr[CBuffer] c_buffer + + if isinstance(file, Buffer): + c_buffer = pyarrow_unwrap_buffer(file) + c_source = CFileSource(move(c_buffer)) + + elif _is_path_like(file): + if filesystem is None: + raise ValueError("cannot construct a FileSource from " + "a path without a FileSystem") + c_filesystem = filesystem.unwrap() + c_path = tobytes(_stringify_path(file)) + c_source = CFileSource(move(c_path), move(c_filesystem)) + + elif hasattr(file, 'read'): + # Optimistically hope this is file-like + c_file = get_native_file(file, False).get_random_access_file() + c_source = CFileSource(move(c_file)) + + else: + raise TypeError("cannot construct a FileSource " + "from " + str(file)) + + return c_source + + +cdef CSegmentEncoding _get_segment_encoding(str segment_encoding): + if segment_encoding == "none": + return CSegmentEncodingNone + elif segment_encoding == "uri": + return CSegmentEncodingUri + raise ValueError(f"Unknown segment encoding: {segment_encoding}") + + +cdef class Expression(_Weakrefable): + """ + A logical expression to be evaluated against some input. + + To create an expression: + + - Use the factory function ``pyarrow.dataset.scalar()`` to create a + scalar (not necessary when combined, see example below). + - Use the factory function ``pyarrow.dataset.field()`` to reference + a field (column in table). + - Compare fields and scalars with ``<``, ``<=``, ``==``, ``>=``, ``>``. + - Combine expressions using python operators ``&`` (logical and), + ``|`` (logical or) and ``~`` (logical not). + Note: python keywords ``and``, ``or`` and ``not`` cannot be used + to combine expressions. + - Check whether the expression is contained in a list of values with + the ``pyarrow.dataset.Expression.isin()`` member function. + + Examples + -------- + + >>> import pyarrow.dataset as ds + >>> (ds.field("a") < ds.scalar(3)) | (ds.field("b") > 7) + <pyarrow.dataset.Expression ((a < 3:int64) or (b > 7:int64))> + >>> ds.field('a') != 3 + <pyarrow.dataset.Expression (a != 3)> + >>> ds.field('a').isin([1, 2, 3]) + <pyarrow.dataset.Expression (a is in [ + 1, + 2, + 3 + ])> + """ + cdef: + CExpression expr + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const CExpression& sp): + self.expr = sp + + @staticmethod + cdef wrap(const CExpression& sp): + cdef Expression self = Expression.__new__(Expression) + self.init(sp) + return self + + cdef inline CExpression unwrap(self): + return self.expr + + def equals(self, Expression other): + return self.expr.Equals(other.unwrap()) + + def __str__(self): + return frombytes(self.expr.ToString()) + + def __repr__(self): + return "<pyarrow.dataset.{0} {1}>".format( + self.__class__.__name__, str(self) + ) + + @staticmethod + def _deserialize(Buffer buffer not None): + return Expression.wrap(GetResultValue(CDeserializeExpression( + pyarrow_unwrap_buffer(buffer)))) + + def __reduce__(self): + buffer = pyarrow_wrap_buffer(GetResultValue( + CSerializeExpression(self.expr))) + return Expression._deserialize, (buffer,) + + @staticmethod + cdef Expression _expr_or_scalar(object expr): + if isinstance(expr, Expression): + return (<Expression> expr) + return (<Expression> Expression._scalar(expr)) + + @staticmethod + cdef Expression _call(str function_name, list arguments, + shared_ptr[CFunctionOptions] options=( + <shared_ptr[CFunctionOptions]> nullptr)): + cdef: + vector[CExpression] c_arguments + + for argument in arguments: + c_arguments.push_back((<Expression> argument).expr) + + return Expression.wrap(CMakeCallExpression(tobytes(function_name), + move(c_arguments), options)) + + def __richcmp__(self, other, int op): + other = Expression._expr_or_scalar(other) + return Expression._call({ + Py_EQ: "equal", + Py_NE: "not_equal", + Py_GT: "greater", + Py_GE: "greater_equal", + Py_LT: "less", + Py_LE: "less_equal", + }[op], [self, other]) + + def __bool__(self): + raise ValueError( + "An Expression cannot be evaluated to python True or False. " + "If you are using the 'and', 'or' or 'not' operators, use '&', " + "'|' or '~' instead." + ) + + def __invert__(self): + return Expression._call("invert", [self]) + + def __and__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("and_kleene", [self, other]) + + def __or__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("or_kleene", [self, other]) + + def __add__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("add_checked", [self, other]) + + def __mul__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("multiply_checked", [self, other]) + + def __sub__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("subtract_checked", [self, other]) + + def __truediv__(Expression self, other): + other = Expression._expr_or_scalar(other) + return Expression._call("divide_checked", [self, other]) + + def is_valid(self): + """Checks whether the expression is not-null (valid)""" + return Expression._call("is_valid", [self]) + + def is_null(self, bint nan_is_null=False): + """Checks whether the expression is null""" + cdef: + shared_ptr[CFunctionOptions] c_options + + c_options.reset(new CNullOptions(nan_is_null)) + return Expression._call("is_null", [self], c_options) + + def cast(self, type, bint safe=True): + """Explicitly change the expression's data type""" + cdef shared_ptr[CCastOptions] c_options + c_options.reset(new CCastOptions(safe)) + c_options.get().to_type = pyarrow_unwrap_data_type(ensure_type(type)) + return Expression._call("cast", [self], + <shared_ptr[CFunctionOptions]> c_options) + + def isin(self, values): + """Checks whether the expression is contained in values""" + cdef: + shared_ptr[CFunctionOptions] c_options + CDatum c_values + + if not isinstance(values, pa.Array): + values = pa.array(values) + + c_values = CDatum(pyarrow_unwrap_array(values)) + c_options.reset(new CSetLookupOptions(c_values, True)) + return Expression._call("is_in", [self], c_options) + + @staticmethod + def _field(str name not None): + return Expression.wrap(CMakeFieldExpression(tobytes(name))) + + @staticmethod + def _scalar(value): + cdef: + Scalar scalar + + if isinstance(value, Scalar): + scalar = value + else: + scalar = pa.scalar(value) + + return Expression.wrap(CMakeScalarExpression(scalar.unwrap())) + + +_deserialize = Expression._deserialize +cdef Expression _true = Expression._scalar(True) + + +cdef class Dataset(_Weakrefable): + """ + Collection of data fragments and potentially child datasets. + + Arrow Datasets allow you to query against data that has been split across + multiple files. This sharding of data may indicate partitioning, which + can accelerate queries that only touch some partitions (files). + """ + + cdef: + shared_ptr[CDataset] wrapped + CDataset* dataset + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CDataset]& sp): + self.wrapped = sp + self.dataset = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CDataset]& sp): + type_name = frombytes(sp.get().type_name()) + + classes = { + 'union': UnionDataset, + 'filesystem': FileSystemDataset, + 'in-memory': InMemoryDataset, + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef Dataset self = class_.__new__(class_) + self.init(sp) + return self + + cdef shared_ptr[CDataset] unwrap(self) nogil: + return self.wrapped + + @property + def partition_expression(self): + """ + An Expression which evaluates to true for all data viewed by this + Dataset. + """ + return Expression.wrap(self.dataset.partition_expression()) + + def replace_schema(self, Schema schema not None): + """ + Return a copy of this Dataset with a different schema. + + The copy will view the same Fragments. If the new schema is not + compatible with the original dataset's schema then an error will + be raised. + """ + cdef shared_ptr[CDataset] copy = GetResultValue( + self.dataset.ReplaceSchema(pyarrow_unwrap_schema(schema))) + return Dataset.wrap(move(copy)) + + def get_fragments(self, Expression filter=None): + """Returns an iterator over the fragments in this dataset. + + Parameters + ---------- + filter : Expression, default None + Return fragments matching the optional filter, either using the + partition_expression or internal information like Parquet's + statistics. + + Returns + ------- + fragments : iterator of Fragment + """ + cdef: + CExpression c_filter + CFragmentIterator c_iterator + + if filter is None: + c_fragments = move(GetResultValue(self.dataset.GetFragments())) + else: + c_filter = _bind(filter, self.schema) + c_fragments = move(GetResultValue( + self.dataset.GetFragments(c_filter))) + + for maybe_fragment in c_fragments: + yield Fragment.wrap(GetResultValue(move(maybe_fragment))) + + def scanner(self, **kwargs): + """Builds a scan operation against the dataset. + + Data is not loaded immediately. Instead, this produces a Scanner, + which exposes further operations (e.g. loading all data as a + table, counting rows). + + Parameters + ---------- + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 1M + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + use_async : bool, default False + If enabled, an async scanner will be used that should offer + better performance with high-latency/highly-parallel filesystems + (e.g. S3) + + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + + Returns + ------- + scanner : Scanner + + Examples + -------- + >>> import pyarrow.dataset as ds + >>> dataset = ds.dataset("path/to/dataset") + + Selecting a subset of the columns: + + >>> dataset.scanner(columns=["A", "B"]).to_table() + + Projecting selected columns using an expression: + + >>> dataset.scanner(columns={ + ... "A_int": ds.field("A").cast("int64"), + ... }).to_table() + + Filtering rows while scanning: + + >>> dataset.scanner(filter=ds.field("A") > 0).to_table() + """ + return Scanner.from_dataset(self, **kwargs) + + def to_batches(self, **kwargs): + """Read the dataset as materialized record batches. + + See scanner method parameters documentation. + + Returns + ------- + record_batches : iterator of RecordBatch + """ + return self.scanner(**kwargs).to_batches() + + def to_table(self, **kwargs): + """Read the dataset to an arrow table. + + Note that this method reads all the selected data from the dataset + into memory. + + See scanner method parameters documentation. + + Returns + ------- + table : Table instance + """ + return self.scanner(**kwargs).to_table() + + def take(self, object indices, **kwargs): + """Select rows of data by index. + + See scanner method parameters documentation. + + Returns + ------- + table : Table instance + """ + return self.scanner(**kwargs).take(indices) + + def head(self, int num_rows, **kwargs): + """Load the first N rows of the dataset. + + See scanner method parameters documentation. + + Returns + ------- + table : Table instance + """ + return self.scanner(**kwargs).head(num_rows) + + def count_rows(self, **kwargs): + """Count rows matching the scanner filter. + + See scanner method parameters documentation. + + Returns + ------- + count : int + """ + return self.scanner(**kwargs).count_rows() + + @property + def schema(self): + """The common schema of the full Dataset""" + return pyarrow_wrap_schema(self.dataset.schema()) + + +cdef class InMemoryDataset(Dataset): + """ + A Dataset wrapping in-memory data. + + Parameters + ---------- + source : The data for this dataset. + Can be a RecordBatch, Table, list of + RecordBatch/Table, iterable of RecordBatch, or a RecordBatchReader. + If an iterable is provided, the schema must also be provided. + schema : Schema, optional + Only required if passing an iterable as the source. + """ + + cdef: + CInMemoryDataset* in_memory_dataset + + def __init__(self, source, Schema schema=None): + cdef: + RecordBatchReader reader + shared_ptr[CInMemoryDataset] in_memory_dataset + + if isinstance(source, (pa.RecordBatch, pa.Table)): + source = [source] + + if isinstance(source, (list, tuple)): + batches = [] + for item in source: + if isinstance(item, pa.RecordBatch): + batches.append(item) + elif isinstance(item, pa.Table): + batches.extend(item.to_batches()) + else: + raise TypeError( + 'Expected a list of tables or batches. The given list ' + 'contains a ' + type(item).__name__) + if schema is None: + schema = item.schema + elif not schema.equals(item.schema): + raise ArrowTypeError( + f'Item has schema\n{item.schema}\nwhich does not ' + f'match expected schema\n{schema}') + if not batches and schema is None: + raise ValueError('Must provide schema to construct in-memory ' + 'dataset from an empty list') + table = pa.Table.from_batches(batches, schema=schema) + in_memory_dataset = make_shared[CInMemoryDataset]( + pyarrow_unwrap_table(table)) + else: + raise TypeError( + 'Expected a table, batch, or list of tables/batches ' + 'instead of the given type: ' + + type(source).__name__ + ) + + self.init(<shared_ptr[CDataset]> in_memory_dataset) + + cdef void init(self, const shared_ptr[CDataset]& sp): + Dataset.init(self, sp) + self.in_memory_dataset = <CInMemoryDataset*> sp.get() + + +cdef class UnionDataset(Dataset): + """ + A Dataset wrapping child datasets. + + Children's schemas must agree with the provided schema. + + Parameters + ---------- + schema : Schema + A known schema to conform to. + children : list of Dataset + One or more input children + """ + + cdef: + CUnionDataset* union_dataset + + def __init__(self, Schema schema not None, children): + cdef: + Dataset child + CDatasetVector c_children + shared_ptr[CUnionDataset] union_dataset + + for child in children: + c_children.push_back(child.wrapped) + + union_dataset = GetResultValue(CUnionDataset.Make( + pyarrow_unwrap_schema(schema), move(c_children))) + self.init(<shared_ptr[CDataset]> union_dataset) + + cdef void init(self, const shared_ptr[CDataset]& sp): + Dataset.init(self, sp) + self.union_dataset = <CUnionDataset*> sp.get() + + def __reduce__(self): + return UnionDataset, (self.schema, self.children) + + @property + def children(self): + cdef CDatasetVector children = self.union_dataset.children() + return [Dataset.wrap(children[i]) for i in range(children.size())] + + +cdef class FileSystemDataset(Dataset): + """ + A Dataset of file fragments. + + A FileSystemDataset is composed of one or more FileFragment. + + Parameters + ---------- + fragments : list[Fragments] + List of fragments to consume. + schema : Schema + The top-level schema of the Dataset. + format : FileFormat + File format of the fragments, currently only ParquetFileFormat, + IpcFileFormat, and CsvFileFormat are supported. + filesystem : FileSystem + FileSystem of the fragments. + root_partition : Expression, optional + The top-level partition of the DataDataset. + """ + + cdef: + CFileSystemDataset* filesystem_dataset + + def __init__(self, fragments, Schema schema, FileFormat format, + FileSystem filesystem=None, root_partition=None): + cdef: + FileFragment fragment=None + vector[shared_ptr[CFileFragment]] c_fragments + CResult[shared_ptr[CDataset]] result + shared_ptr[CFileSystem] c_filesystem + + if root_partition is None: + root_partition = _true + elif not isinstance(root_partition, Expression): + raise TypeError( + "Argument 'root_partition' has incorrect type (expected " + "Epression, got {0})".format(type(root_partition)) + ) + + for fragment in fragments: + c_fragments.push_back( + static_pointer_cast[CFileFragment, CFragment]( + fragment.unwrap())) + + if filesystem is None: + filesystem = fragment.filesystem + + if filesystem is not None: + c_filesystem = filesystem.unwrap() + + result = CFileSystemDataset.Make( + pyarrow_unwrap_schema(schema), + (<Expression> root_partition).unwrap(), + format.unwrap(), + c_filesystem, + c_fragments + ) + self.init(GetResultValue(result)) + + @property + def filesystem(self): + return FileSystem.wrap(self.filesystem_dataset.filesystem()) + + @property + def partitioning(self): + """ + The partitioning of the Dataset source, if discovered. + + If the FileSystemDataset is created using the ``dataset()`` factory + function with a partitioning specified, this will return the + finalized Partitioning object from the dataset discovery. In all + other cases, this returns None. + """ + c_partitioning = self.filesystem_dataset.partitioning() + if c_partitioning.get() == nullptr: + return None + try: + return Partitioning.wrap(c_partitioning) + except TypeError: + # e.g. type_name "default" + return None + + cdef void init(self, const shared_ptr[CDataset]& sp): + Dataset.init(self, sp) + self.filesystem_dataset = <CFileSystemDataset*> sp.get() + + def __reduce__(self): + return FileSystemDataset, ( + list(self.get_fragments()), + self.schema, + self.format, + self.filesystem, + self.partition_expression + ) + + @classmethod + def from_paths(cls, paths, schema=None, format=None, + filesystem=None, partitions=None, root_partition=None): + """A Dataset created from a list of paths on a particular filesystem. + + Parameters + ---------- + paths : list of str + List of file paths to create the fragments from. + schema : Schema + The top-level schema of the DataDataset. + format : FileFormat + File format to create fragments from, currently only + ParquetFileFormat, IpcFileFormat, and CsvFileFormat are supported. + filesystem : FileSystem + The filesystem which files are from. + partitions : List[Expression], optional + Attach additional partition information for the file paths. + root_partition : Expression, optional + The top-level partition of the DataDataset. + """ + cdef: + FileFragment fragment + + if root_partition is None: + root_partition = _true + + for arg, class_, name in [ + (schema, Schema, 'schema'), + (format, FileFormat, 'format'), + (filesystem, FileSystem, 'filesystem'), + (root_partition, Expression, 'root_partition') + ]: + if not isinstance(arg, class_): + raise TypeError( + "Argument '{0}' has incorrect type (expected {1}, " + "got {2})".format(name, class_.__name__, type(arg)) + ) + + partitions = partitions or [_true] * len(paths) + + if len(paths) != len(partitions): + raise ValueError( + 'The number of files resulting from paths_or_selector ' + 'must be equal to the number of partitions.' + ) + + fragments = [ + format.make_fragment(path, filesystem, partitions[i]) + for i, path in enumerate(paths) + ] + return FileSystemDataset(fragments, schema, format, + filesystem, root_partition) + + @property + def files(self): + """List of the files""" + cdef vector[c_string] files = self.filesystem_dataset.files() + return [frombytes(f) for f in files] + + @property + def format(self): + """The FileFormat of this source.""" + return FileFormat.wrap(self.filesystem_dataset.format()) + + +cdef CExpression _bind(Expression filter, Schema schema) except *: + assert schema is not None + + if filter is None: + return _true.unwrap() + + return GetResultValue(filter.unwrap().Bind( + deref(pyarrow_unwrap_schema(schema).get()))) + + +cdef class FileWriteOptions(_Weakrefable): + + cdef: + shared_ptr[CFileWriteOptions] wrapped + CFileWriteOptions* c_options + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CFileWriteOptions]& sp): + self.wrapped = sp + self.c_options = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CFileWriteOptions]& sp): + type_name = frombytes(sp.get().type_name()) + + classes = { + 'csv': CsvFileWriteOptions, + 'ipc': IpcFileWriteOptions, + 'parquet': ParquetFileWriteOptions, + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef FileWriteOptions self = class_.__new__(class_) + self.init(sp) + return self + + @property + def format(self): + return FileFormat.wrap(self.c_options.format()) + + cdef inline shared_ptr[CFileWriteOptions] unwrap(self): + return self.wrapped + + +cdef class FileFormat(_Weakrefable): + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CFileFormat]& sp): + self.wrapped = sp + self.format = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CFileFormat]& sp): + type_name = frombytes(sp.get().type_name()) + + classes = { + 'ipc': IpcFileFormat, + 'csv': CsvFileFormat, + 'parquet': ParquetFileFormat, + 'orc': _get_orc_fileformat(), + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef FileFormat self = class_.__new__(class_) + self.init(sp) + return self + + cdef inline shared_ptr[CFileFormat] unwrap(self): + return self.wrapped + + def inspect(self, file, filesystem=None): + """Infer the schema of a file.""" + c_source = _make_file_source(file, filesystem) + c_schema = GetResultValue(self.format.Inspect(c_source)) + return pyarrow_wrap_schema(move(c_schema)) + + def make_fragment(self, file, filesystem=None, + Expression partition_expression=None): + """ + Make a FileFragment of this FileFormat. The filter may not reference + fields absent from the provided schema. If no schema is provided then + one will be inferred. + """ + if partition_expression is None: + partition_expression = _true + + c_source = _make_file_source(file, filesystem) + c_fragment = <shared_ptr[CFragment]> GetResultValue( + self.format.MakeFragment(move(c_source), + partition_expression.unwrap(), + <shared_ptr[CSchema]>nullptr)) + return Fragment.wrap(move(c_fragment)) + + def make_write_options(self): + return FileWriteOptions.wrap(self.format.DefaultWriteOptions()) + + @property + def default_extname(self): + return frombytes(self.format.type_name()) + + @property + def default_fragment_scan_options(self): + return FragmentScanOptions.wrap( + self.wrapped.get().default_fragment_scan_options) + + @default_fragment_scan_options.setter + def default_fragment_scan_options(self, FragmentScanOptions options): + if options is None: + self.wrapped.get().default_fragment_scan_options =\ + <shared_ptr[CFragmentScanOptions]>nullptr + else: + self._set_default_fragment_scan_options(options) + + cdef _set_default_fragment_scan_options(self, FragmentScanOptions options): + raise ValueError(f"Cannot set fragment scan options for " + f"'{options.type_name}' on {self.__class__.__name__}") + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + +cdef class Fragment(_Weakrefable): + """Fragment of data from a Dataset.""" + + cdef: + shared_ptr[CFragment] wrapped + CFragment* fragment + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CFragment]& sp): + self.wrapped = sp + self.fragment = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CFragment]& sp): + type_name = frombytes(sp.get().type_name()) + + classes = { + # IpcFileFormat and CsvFileFormat do not have corresponding + # subclasses of FileFragment + 'ipc': FileFragment, + 'csv': FileFragment, + 'parquet': ParquetFileFragment, + } + + class_ = classes.get(type_name, None) + if class_ is None: + class_ = Fragment + + cdef Fragment self = class_.__new__(class_) + self.init(sp) + return self + + cdef inline shared_ptr[CFragment] unwrap(self): + return self.wrapped + + @property + def physical_schema(self): + """Return the physical schema of this Fragment. This schema can be + different from the dataset read schema.""" + cdef: + CResult[shared_ptr[CSchema]] maybe_schema + with nogil: + maybe_schema = self.fragment.ReadPhysicalSchema() + return pyarrow_wrap_schema(GetResultValue(maybe_schema)) + + @property + def partition_expression(self): + """An Expression which evaluates to true for all data viewed by this + Fragment. + """ + return Expression.wrap(self.fragment.partition_expression()) + + def scanner(self, Schema schema=None, **kwargs): + """Builds a scan operation against the dataset. + + Data is not loaded immediately. Instead, this produces a Scanner, + which exposes further operations (e.g. loading all data as a + table, counting rows). + + Parameters + ---------- + schema : Schema + Schema to use for scanning. This is used to unify a Fragment to + it's Dataset's schema. If not specified this will use the + Fragment's physical schema which might differ for each Fragment. + columns : list of str, default None + The columns to project. This can be a list of column names to + include (order and duplicates will be preserved), or a dictionary + with {new_column_name: expression} values for more advanced + projections. + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 1M + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + fragment_scan_options : FragmentScanOptions, default None + Options specific to a particular scan and fragment type, which + can change between different scans of the same dataset. + + Returns + ------- + scanner : Scanner + + """ + return Scanner.from_fragment(self, schema=schema, **kwargs) + + def to_batches(self, Schema schema=None, **kwargs): + """Read the fragment as materialized record batches. + + See scanner method parameters documentation. + + Returns + ------- + record_batches : iterator of RecordBatch + """ + return self.scanner(schema=schema, **kwargs).to_batches() + + def to_table(self, Schema schema=None, **kwargs): + """Convert this Fragment into a Table. + + Use this convenience utility with care. This will serially materialize + the Scan result in memory before creating the Table. + + See scanner method parameters documentation. + + Returns + ------- + table : Table + """ + return self.scanner(schema=schema, **kwargs).to_table() + + def take(self, object indices, **kwargs): + """Select rows of data by index. + + See scanner method parameters documentation. + + Returns + ------- + table : Table instance + """ + return self.scanner(**kwargs).take(indices) + + def head(self, int num_rows, **kwargs): + """Load the first N rows of the fragment. + + See scanner method parameters documentation. + + Returns + ------- + table : Table instance + """ + return self.scanner(**kwargs).head(num_rows) + + def count_rows(self, **kwargs): + """Count rows matching the scanner filter. + + See scanner method parameters documentation. + + Returns + ------- + count : int + """ + return self.scanner(**kwargs).count_rows() + + +cdef class FileFragment(Fragment): + """A Fragment representing a data file.""" + + cdef: + CFileFragment* file_fragment + + cdef void init(self, const shared_ptr[CFragment]& sp): + Fragment.init(self, sp) + self.file_fragment = <CFileFragment*> sp.get() + + def __repr__(self): + type_name = frombytes(self.fragment.type_name()) + if type_name != "parquet": + typ = f" type={type_name}" + else: + # parquet has a subclass -> type embedded in class name + typ = "" + partition_dict = _get_partition_keys(self.partition_expression) + partition = ", ".join( + [f"{key}={val}" for key, val in partition_dict.items()] + ) + if partition: + partition = f" partition=[{partition}]" + return "<pyarrow.dataset.{0}{1} path={2}{3}>".format( + self.__class__.__name__, typ, self.path, partition + ) + + def __reduce__(self): + buffer = self.buffer + return self.format.make_fragment, ( + self.path if buffer is None else buffer, + self.filesystem, + self.partition_expression + ) + + @property + def path(self): + """ + The path of the data file viewed by this fragment, if it views a + file. If instead it views a buffer, this will be "<Buffer>". + """ + return frombytes(self.file_fragment.source().path()) + + @property + def filesystem(self): + """ + The FileSystem containing the data file viewed by this fragment, if + it views a file. If instead it views a buffer, this will be None. + """ + cdef: + shared_ptr[CFileSystem] c_fs + c_fs = self.file_fragment.source().filesystem() + + if c_fs.get() == nullptr: + return None + + return FileSystem.wrap(c_fs) + + @property + def buffer(self): + """ + The buffer viewed by this fragment, if it views a buffer. If + instead it views a file, this will be None. + """ + cdef: + shared_ptr[CBuffer] c_buffer + c_buffer = self.file_fragment.source().buffer() + + if c_buffer.get() == nullptr: + return None + + return pyarrow_wrap_buffer(c_buffer) + + @property + def format(self): + """ + The format of the data file viewed by this fragment. + """ + return FileFormat.wrap(self.file_fragment.format()) + + +class RowGroupInfo: + """ + A wrapper class for RowGroup information + + Parameters + ---------- + id : the group id. + metadata : the rowgroup metadata. + schema : schema of the rows. + """ + + def __init__(self, id, metadata, schema): + self.id = id + self.metadata = metadata + self.schema = schema + + @property + def num_rows(self): + return self.metadata.num_rows + + @property + def total_byte_size(self): + return self.metadata.total_byte_size + + @property + def statistics(self): + def name_stats(i): + col = self.metadata.column(i) + + stats = col.statistics + if stats is None or not stats.has_min_max: + return None, None + + name = col.path_in_schema + field_index = self.schema.get_field_index(name) + if field_index < 0: + return None, None + + typ = self.schema.field(field_index).type + return col.path_in_schema, { + 'min': pa.scalar(stats.min, type=typ).as_py(), + 'max': pa.scalar(stats.max, type=typ).as_py() + } + + return { + name: stats for name, stats + in map(name_stats, range(self.metadata.num_columns)) + if stats is not None + } + + def __repr__(self): + return "RowGroupInfo({})".format(self.id) + + def __eq__(self, other): + if isinstance(other, int): + return self.id == other + if not isinstance(other, RowGroupInfo): + return False + return self.id == other.id + + +cdef class FragmentScanOptions(_Weakrefable): + """Scan options specific to a particular fragment and scan operation.""" + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp): + self.wrapped = sp + + @staticmethod + cdef wrap(const shared_ptr[CFragmentScanOptions]& sp): + if not sp: + return None + + type_name = frombytes(sp.get().type_name()) + + classes = { + 'csv': CsvFragmentScanOptions, + 'parquet': ParquetFragmentScanOptions, + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef FragmentScanOptions self = class_.__new__(class_) + self.init(sp) + return self + + @property + def type_name(self): + return frombytes(self.wrapped.get().type_name()) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + +cdef class ParquetFileFragment(FileFragment): + """A Fragment representing a parquet file.""" + + cdef: + CParquetFileFragment* parquet_file_fragment + + cdef void init(self, const shared_ptr[CFragment]& sp): + FileFragment.init(self, sp) + self.parquet_file_fragment = <CParquetFileFragment*> sp.get() + + def __reduce__(self): + buffer = self.buffer + row_groups = [row_group.id for row_group in self.row_groups] + return self.format.make_fragment, ( + self.path if buffer is None else buffer, + self.filesystem, + self.partition_expression, + row_groups + ) + + def ensure_complete_metadata(self): + """ + Ensure that all metadata (statistics, physical schema, ...) have + been read and cached in this fragment. + """ + check_status(self.parquet_file_fragment.EnsureCompleteMetadata()) + + @property + def row_groups(self): + metadata = self.metadata + cdef vector[int] row_groups = self.parquet_file_fragment.row_groups() + return [RowGroupInfo(i, metadata.row_group(i), self.physical_schema) + for i in row_groups] + + @property + def metadata(self): + self.ensure_complete_metadata() + cdef FileMetaData metadata = FileMetaData() + metadata.init(self.parquet_file_fragment.metadata()) + return metadata + + @property + def num_row_groups(self): + """ + Return the number of row groups viewed by this fragment (not the + number of row groups in the origin file). + """ + self.ensure_complete_metadata() + return self.parquet_file_fragment.row_groups().size() + + def split_by_row_group(self, Expression filter=None, + Schema schema=None): + """ + Split the fragment into multiple fragments. + + Yield a Fragment wrapping each row group in this ParquetFileFragment. + Row groups will be excluded whose metadata contradicts the optional + filter. + + Parameters + ---------- + filter : Expression, default None + Only include the row groups which satisfy this predicate (using + the Parquet RowGroup statistics). + schema : Schema, default None + Schema to use when filtering row groups. Defaults to the + Fragment's phsyical schema + + Returns + ------- + A list of Fragments + """ + cdef: + vector[shared_ptr[CFragment]] c_fragments + CExpression c_filter + shared_ptr[CFragment] c_fragment + + schema = schema or self.physical_schema + c_filter = _bind(filter, schema) + with nogil: + c_fragments = move(GetResultValue( + self.parquet_file_fragment.SplitByRowGroup(move(c_filter)))) + + return [Fragment.wrap(c_fragment) for c_fragment in c_fragments] + + def subset(self, Expression filter=None, Schema schema=None, + object row_group_ids=None): + """ + Create a subset of the fragment (viewing a subset of the row groups). + + Subset can be specified by either a filter predicate (with optional + schema) or by a list of row group IDs. Note that when using a filter, + the resulting fragment can be empty (viewing no row groups). + + Parameters + ---------- + filter : Expression, default None + Only include the row groups which satisfy this predicate (using + the Parquet RowGroup statistics). + schema : Schema, default None + Schema to use when filtering row groups. Defaults to the + Fragment's phsyical schema + row_group_ids : list of ints + The row group IDs to include in the subset. Can only be specified + if `filter` is None. + + Returns + ------- + ParquetFileFragment + """ + cdef: + CExpression c_filter + vector[int] c_row_group_ids + shared_ptr[CFragment] c_fragment + + if filter is not None and row_group_ids is not None: + raise ValueError( + "Cannot specify both 'filter' and 'row_group_ids'." + ) + + if filter is not None: + schema = schema or self.physical_schema + c_filter = _bind(filter, schema) + with nogil: + c_fragment = move(GetResultValue( + self.parquet_file_fragment.SubsetWithFilter( + move(c_filter)))) + elif row_group_ids is not None: + c_row_group_ids = [ + <int> row_group for row_group in sorted(set(row_group_ids)) + ] + with nogil: + c_fragment = move(GetResultValue( + self.parquet_file_fragment.SubsetWithIds( + move(c_row_group_ids)))) + else: + raise ValueError( + "Need to specify one of 'filter' or 'row_group_ids'" + ) + + return Fragment.wrap(c_fragment) + + +cdef class ParquetReadOptions(_Weakrefable): + """ + Parquet format specific options for reading. + + Parameters + ---------- + dictionary_columns : list of string, default None + Names of columns which should be dictionary encoded as + they are read. + coerce_int96_timestamp_unit : str, default None. + Cast timestamps that are stored in INT96 format to a particular + resolution (e.g. 'ms'). Setting to None is equivalent to 'ns' + and therefore INT96 timestamps will be infered as timestamps + in nanoseconds. + """ + + cdef public: + set dictionary_columns + TimeUnit _coerce_int96_timestamp_unit + + # Also see _PARQUET_READ_OPTIONS + def __init__(self, dictionary_columns=None, + coerce_int96_timestamp_unit=None): + self.dictionary_columns = set(dictionary_columns or set()) + self.coerce_int96_timestamp_unit = coerce_int96_timestamp_unit + + @property + def coerce_int96_timestamp_unit(self): + return timeunit_to_string(self._coerce_int96_timestamp_unit) + + @coerce_int96_timestamp_unit.setter + def coerce_int96_timestamp_unit(self, unit): + if unit is not None: + self._coerce_int96_timestamp_unit = string_to_timeunit(unit) + else: + self._coerce_int96_timestamp_unit = TimeUnit_NANO + + def equals(self, ParquetReadOptions other): + return (self.dictionary_columns == other.dictionary_columns and + self.coerce_int96_timestamp_unit == + other.coerce_int96_timestamp_unit) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return False + + def __repr__(self): + return ( + f"<ParquetReadOptions" + f" dictionary_columns={self.dictionary_columns}" + f" coerce_int96_timestamp_unit={self.coerce_int96_timestamp_unit}>" + ) + + +cdef class ParquetFileWriteOptions(FileWriteOptions): + + cdef: + CParquetFileWriteOptions* parquet_options + object _properties + + def update(self, **kwargs): + arrow_fields = { + "use_deprecated_int96_timestamps", + "coerce_timestamps", + "allow_truncated_timestamps", + } + + setters = set() + for name, value in kwargs.items(): + if name not in self._properties: + raise TypeError("unexpected parquet write option: " + name) + self._properties[name] = value + if name in arrow_fields: + setters.add(self._set_arrow_properties) + else: + setters.add(self._set_properties) + + for setter in setters: + setter() + + def _set_properties(self): + cdef CParquetFileWriteOptions* opts = self.parquet_options + + opts.writer_properties = _create_writer_properties( + use_dictionary=self._properties["use_dictionary"], + compression=self._properties["compression"], + version=self._properties["version"], + write_statistics=self._properties["write_statistics"], + data_page_size=self._properties["data_page_size"], + compression_level=self._properties["compression_level"], + use_byte_stream_split=( + self._properties["use_byte_stream_split"] + ), + data_page_version=self._properties["data_page_version"], + ) + + def _set_arrow_properties(self): + cdef CParquetFileWriteOptions* opts = self.parquet_options + + opts.arrow_writer_properties = _create_arrow_writer_properties( + use_deprecated_int96_timestamps=( + self._properties["use_deprecated_int96_timestamps"] + ), + coerce_timestamps=self._properties["coerce_timestamps"], + allow_truncated_timestamps=( + self._properties["allow_truncated_timestamps"] + ), + writer_engine_version="V2", + use_compliant_nested_type=( + self._properties["use_compliant_nested_type"] + ) + ) + + cdef void init(self, const shared_ptr[CFileWriteOptions]& sp): + FileWriteOptions.init(self, sp) + self.parquet_options = <CParquetFileWriteOptions*> sp.get() + self._properties = dict( + use_dictionary=True, + compression="snappy", + version="1.0", + write_statistics=None, + data_page_size=None, + compression_level=None, + use_byte_stream_split=False, + data_page_version="1.0", + use_deprecated_int96_timestamps=False, + coerce_timestamps=None, + allow_truncated_timestamps=False, + use_compliant_nested_type=False, + ) + self._set_properties() + self._set_arrow_properties() + + +cdef set _PARQUET_READ_OPTIONS = { + 'dictionary_columns', 'coerce_int96_timestamp_unit' +} + + +cdef class ParquetFileFormat(FileFormat): + """ + FileFormat for Parquet + + Parameters + ---------- + read_options : ParquetReadOptions + Read options for the file. + default_fragment_scan_options : ParquetFragmentScanOptions + Scan Options for the file. + **kwargs : dict + Additional options for read option or scan option. + """ + + cdef: + CParquetFileFormat* parquet_format + + def __init__(self, read_options=None, + default_fragment_scan_options=None, **kwargs): + cdef: + shared_ptr[CParquetFileFormat] wrapped + CParquetFileFormatReaderOptions* options + + # Read/scan options + read_options_args = {option: kwargs[option] for option in kwargs + if option in _PARQUET_READ_OPTIONS} + scan_args = {option: kwargs[option] for option in kwargs + if option not in _PARQUET_READ_OPTIONS} + if read_options and read_options_args: + duplicates = ', '.join(sorted(read_options_args)) + raise ValueError(f'If `read_options` is given, ' + f'cannot specify {duplicates}') + if default_fragment_scan_options and scan_args: + duplicates = ', '.join(sorted(scan_args)) + raise ValueError(f'If `default_fragment_scan_options` is given, ' + f'cannot specify {duplicates}') + + if read_options is None: + read_options = ParquetReadOptions(**read_options_args) + elif isinstance(read_options, dict): + # For backwards compatibility + duplicates = [] + for option, value in read_options.items(): + if option in _PARQUET_READ_OPTIONS: + read_options_args[option] = value + else: + duplicates.append(option) + scan_args[option] = value + if duplicates: + duplicates = ", ".join(duplicates) + warnings.warn(f'The scan options {duplicates} should be ' + 'specified directly as keyword arguments') + read_options = ParquetReadOptions(**read_options_args) + elif not isinstance(read_options, ParquetReadOptions): + raise TypeError('`read_options` must be either a dictionary or an ' + 'instance of ParquetReadOptions') + + if default_fragment_scan_options is None: + default_fragment_scan_options = ParquetFragmentScanOptions( + **scan_args) + elif isinstance(default_fragment_scan_options, dict): + default_fragment_scan_options = ParquetFragmentScanOptions( + **default_fragment_scan_options) + elif not isinstance(default_fragment_scan_options, + ParquetFragmentScanOptions): + raise TypeError('`default_fragment_scan_options` must be either a ' + 'dictionary or an instance of ' + 'ParquetFragmentScanOptions') + + wrapped = make_shared[CParquetFileFormat]() + options = &(wrapped.get().reader_options) + if read_options.dictionary_columns is not None: + for column in read_options.dictionary_columns: + options.dict_columns.insert(tobytes(column)) + options.coerce_int96_timestamp_unit = \ + read_options._coerce_int96_timestamp_unit + + self.init(<shared_ptr[CFileFormat]> wrapped) + self.default_fragment_scan_options = default_fragment_scan_options + + cdef void init(self, const shared_ptr[CFileFormat]& sp): + FileFormat.init(self, sp) + self.parquet_format = <CParquetFileFormat*> sp.get() + + @property + def read_options(self): + cdef CParquetFileFormatReaderOptions* options + options = &self.parquet_format.reader_options + parquet_read_options = ParquetReadOptions( + dictionary_columns={frombytes(col) + for col in options.dict_columns}, + ) + # Read options getter/setter works with strings so setting + # the private property which uses the C Type + parquet_read_options._coerce_int96_timestamp_unit = \ + options.coerce_int96_timestamp_unit + return parquet_read_options + + def make_write_options(self, **kwargs): + opts = FileFormat.make_write_options(self) + (<ParquetFileWriteOptions> opts).update(**kwargs) + return opts + + cdef _set_default_fragment_scan_options(self, FragmentScanOptions options): + if options.type_name == 'parquet': + self.parquet_format.default_fragment_scan_options = options.wrapped + else: + super()._set_default_fragment_scan_options(options) + + def equals(self, ParquetFileFormat other): + return ( + self.read_options.equals(other.read_options) and + self.default_fragment_scan_options == + other.default_fragment_scan_options + ) + + def __reduce__(self): + return ParquetFileFormat, (self.read_options, + self.default_fragment_scan_options) + + def __repr__(self): + return f"<ParquetFileFormat read_options={self.read_options}>" + + def make_fragment(self, file, filesystem=None, + Expression partition_expression=None, row_groups=None): + cdef: + vector[int] c_row_groups + + if partition_expression is None: + partition_expression = _true + + if row_groups is None: + return super().make_fragment(file, filesystem, + partition_expression) + + c_source = _make_file_source(file, filesystem) + c_row_groups = [<int> row_group for row_group in set(row_groups)] + + c_fragment = <shared_ptr[CFragment]> GetResultValue( + self.parquet_format.MakeFragment(move(c_source), + partition_expression.unwrap(), + <shared_ptr[CSchema]>nullptr, + move(c_row_groups))) + return Fragment.wrap(move(c_fragment)) + + +cdef class ParquetFragmentScanOptions(FragmentScanOptions): + """ + Scan-specific options for Parquet fragments. + + Parameters + ---------- + use_buffered_stream : bool, default False + Read files through buffered input streams rather than loading entire + row groups at once. This may be enabled to reduce memory overhead. + Disabled by default. + buffer_size : int, default 8192 + Size of buffered stream, if enabled. Default is 8KB. + pre_buffer : bool, default False + If enabled, pre-buffer the raw Parquet data instead of issuing one + read per column chunk. This can improve performance on high-latency + filesystems. + enable_parallel_column_conversion : bool, default False + EXPERIMENTAL: Parallelize conversion across columns. This option is + ignored if a scan is already parallelized across input files to avoid + thread contention. This option will be removed after support is added + for simultaneous parallelization across files and columns. + """ + + cdef: + CParquetFragmentScanOptions* parquet_options + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, bint use_buffered_stream=False, + buffer_size=8192, + bint pre_buffer=False, + bint enable_parallel_column_conversion=False): + self.init(shared_ptr[CFragmentScanOptions]( + new CParquetFragmentScanOptions())) + self.use_buffered_stream = use_buffered_stream + self.buffer_size = buffer_size + self.pre_buffer = pre_buffer + self.enable_parallel_column_conversion = \ + enable_parallel_column_conversion + + cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp): + FragmentScanOptions.init(self, sp) + self.parquet_options = <CParquetFragmentScanOptions*> sp.get() + + cdef CReaderProperties* reader_properties(self): + return self.parquet_options.reader_properties.get() + + cdef ArrowReaderProperties* arrow_reader_properties(self): + return self.parquet_options.arrow_reader_properties.get() + + @property + def use_buffered_stream(self): + return self.reader_properties().is_buffered_stream_enabled() + + @use_buffered_stream.setter + def use_buffered_stream(self, bint use_buffered_stream): + if use_buffered_stream: + self.reader_properties().enable_buffered_stream() + else: + self.reader_properties().disable_buffered_stream() + + @property + def buffer_size(self): + return self.reader_properties().buffer_size() + + @buffer_size.setter + def buffer_size(self, buffer_size): + if buffer_size <= 0: + raise ValueError("Buffer size must be larger than zero") + self.reader_properties().set_buffer_size(buffer_size) + + @property + def pre_buffer(self): + return self.arrow_reader_properties().pre_buffer() + + @pre_buffer.setter + def pre_buffer(self, bint pre_buffer): + self.arrow_reader_properties().set_pre_buffer(pre_buffer) + + @property + def enable_parallel_column_conversion(self): + return self.parquet_options.enable_parallel_column_conversion + + @enable_parallel_column_conversion.setter + def enable_parallel_column_conversion( + self, bint enable_parallel_column_conversion): + self.parquet_options.enable_parallel_column_conversion = \ + enable_parallel_column_conversion + + def equals(self, ParquetFragmentScanOptions other): + return ( + self.use_buffered_stream == other.use_buffered_stream and + self.buffer_size == other.buffer_size and + self.pre_buffer == other.pre_buffer and + self.enable_parallel_column_conversion == + other.enable_parallel_column_conversion + ) + + def __reduce__(self): + return ParquetFragmentScanOptions, ( + self.use_buffered_stream, self.buffer_size, self.pre_buffer, + self.enable_parallel_column_conversion + ) + + +cdef class IpcFileWriteOptions(FileWriteOptions): + + def __init__(self): + _forbid_instantiation(self.__class__) + + +cdef class IpcFileFormat(FileFormat): + + def __init__(self): + self.init(shared_ptr[CFileFormat](new CIpcFileFormat())) + + def equals(self, IpcFileFormat other): + return True + + @property + def default_extname(self): + return "feather" + + def __reduce__(self): + return IpcFileFormat, tuple() + + +cdef class CsvFileFormat(FileFormat): + """ + FileFormat for CSV files. + + Parameters + ---------- + parse_options : ParseOptions + Options regarding CSV parsing. + convert_options : ConvertOptions + Options regarding value conversion. + read_options : ReadOptions + General read options. + default_fragment_scan_options : CsvFragmentScanOptions + Default options for fragments scan. + """ + cdef: + CCsvFileFormat* csv_format + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, ParseOptions parse_options=None, + default_fragment_scan_options=None, + ConvertOptions convert_options=None, + ReadOptions read_options=None): + self.init(shared_ptr[CFileFormat](new CCsvFileFormat())) + if parse_options is not None: + self.parse_options = parse_options + if convert_options is not None or read_options is not None: + if default_fragment_scan_options: + raise ValueError('If `default_fragment_scan_options` is ' + 'given, cannot specify convert_options ' + 'or read_options') + self.default_fragment_scan_options = CsvFragmentScanOptions( + convert_options=convert_options, read_options=read_options) + elif isinstance(default_fragment_scan_options, dict): + self.default_fragment_scan_options = CsvFragmentScanOptions( + **default_fragment_scan_options) + elif isinstance(default_fragment_scan_options, CsvFragmentScanOptions): + self.default_fragment_scan_options = default_fragment_scan_options + elif default_fragment_scan_options is not None: + raise TypeError('`default_fragment_scan_options` must be either ' + 'a dictionary or an instance of ' + 'CsvFragmentScanOptions') + + cdef void init(self, const shared_ptr[CFileFormat]& sp): + FileFormat.init(self, sp) + self.csv_format = <CCsvFileFormat*> sp.get() + + def make_write_options(self, **kwargs): + cdef CsvFileWriteOptions opts = \ + <CsvFileWriteOptions> FileFormat.make_write_options(self) + opts.write_options = WriteOptions(**kwargs) + return opts + + @property + def parse_options(self): + return ParseOptions.wrap(self.csv_format.parse_options) + + @parse_options.setter + def parse_options(self, ParseOptions parse_options not None): + self.csv_format.parse_options = deref(parse_options.options) + + cdef _set_default_fragment_scan_options(self, FragmentScanOptions options): + if options.type_name == 'csv': + self.csv_format.default_fragment_scan_options = options.wrapped + else: + super()._set_default_fragment_scan_options(options) + + def equals(self, CsvFileFormat other): + return ( + self.parse_options.equals(other.parse_options) and + self.default_fragment_scan_options == + other.default_fragment_scan_options) + + def __reduce__(self): + return CsvFileFormat, (self.parse_options, + self.default_fragment_scan_options) + + def __repr__(self): + return f"<CsvFileFormat parse_options={self.parse_options}>" + + +cdef class CsvFragmentScanOptions(FragmentScanOptions): + """ + Scan-specific options for CSV fragments. + + Parameters + ---------- + convert_options : ConvertOptions + Options regarding value conversion. + read_options : ReadOptions + General read options. + """ + + cdef: + CCsvFragmentScanOptions* csv_options + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, ConvertOptions convert_options=None, + ReadOptions read_options=None): + self.init(shared_ptr[CFragmentScanOptions]( + new CCsvFragmentScanOptions())) + if convert_options is not None: + self.convert_options = convert_options + if read_options is not None: + self.read_options = read_options + + cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp): + FragmentScanOptions.init(self, sp) + self.csv_options = <CCsvFragmentScanOptions*> sp.get() + + @property + def convert_options(self): + return ConvertOptions.wrap(self.csv_options.convert_options) + + @convert_options.setter + def convert_options(self, ConvertOptions convert_options not None): + self.csv_options.convert_options = deref(convert_options.options) + + @property + def read_options(self): + return ReadOptions.wrap(self.csv_options.read_options) + + @read_options.setter + def read_options(self, ReadOptions read_options not None): + self.csv_options.read_options = deref(read_options.options) + + def equals(self, CsvFragmentScanOptions other): + return ( + other and + self.convert_options.equals(other.convert_options) and + self.read_options.equals(other.read_options)) + + def __reduce__(self): + return CsvFragmentScanOptions, (self.convert_options, + self.read_options) + + +cdef class CsvFileWriteOptions(FileWriteOptions): + cdef: + CCsvFileWriteOptions* csv_options + object _properties + + def __init__(self): + _forbid_instantiation(self.__class__) + + @property + def write_options(self): + return WriteOptions.wrap(deref(self.csv_options.write_options)) + + @write_options.setter + def write_options(self, WriteOptions write_options not None): + self.csv_options.write_options.reset( + new CCSVWriteOptions(deref(write_options.options))) + + cdef void init(self, const shared_ptr[CFileWriteOptions]& sp): + FileWriteOptions.init(self, sp) + self.csv_options = <CCsvFileWriteOptions*> sp.get() + + +cdef class Partitioning(_Weakrefable): + + cdef: + shared_ptr[CPartitioning] wrapped + CPartitioning* partitioning + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef init(self, const shared_ptr[CPartitioning]& sp): + self.wrapped = sp + self.partitioning = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CPartitioning]& sp): + type_name = frombytes(sp.get().type_name()) + + classes = { + 'directory': DirectoryPartitioning, + 'hive': HivePartitioning, + } + + class_ = classes.get(type_name, None) + if class_ is None: + raise TypeError(type_name) + + cdef Partitioning self = class_.__new__(class_) + self.init(sp) + return self + + cdef inline shared_ptr[CPartitioning] unwrap(self): + return self.wrapped + + def parse(self, path): + cdef CResult[CExpression] result + result = self.partitioning.Parse(tobytes(path)) + return Expression.wrap(GetResultValue(result)) + + @property + def schema(self): + """The arrow Schema attached to the partitioning.""" + return pyarrow_wrap_schema(self.partitioning.schema()) + + +cdef class PartitioningFactory(_Weakrefable): + + cdef: + shared_ptr[CPartitioningFactory] wrapped + CPartitioningFactory* factory + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef init(self, const shared_ptr[CPartitioningFactory]& sp): + self.wrapped = sp + self.factory = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CPartitioningFactory]& sp): + cdef PartitioningFactory self = PartitioningFactory.__new__( + PartitioningFactory + ) + self.init(sp) + return self + + cdef inline shared_ptr[CPartitioningFactory] unwrap(self): + return self.wrapped + + @property + def type_name(self): + return frombytes(self.factory.type_name()) + + +cdef vector[shared_ptr[CArray]] _partitioning_dictionaries( + Schema schema, dictionaries) except *: + cdef: + vector[shared_ptr[CArray]] c_dictionaries + + dictionaries = dictionaries or {} + + for field in schema: + dictionary = dictionaries.get(field.name) + + if (isinstance(field.type, pa.DictionaryType) and + dictionary is not None): + c_dictionaries.push_back(pyarrow_unwrap_array(dictionary)) + else: + c_dictionaries.push_back(<shared_ptr[CArray]> nullptr) + + return c_dictionaries + + +cdef class DirectoryPartitioning(Partitioning): + """ + A Partitioning based on a specified Schema. + + The DirectoryPartitioning expects one segment in the file path for each + field in the schema (all fields are required to be present). + For example given schema<year:int16, month:int8> the path "/2009/11" would + be parsed to ("year"_ == 2009 and "month"_ == 11). + + Parameters + ---------- + schema : Schema + The schema that describes the partitions present in the file path. + dictionaries : Dict[str, Array] + If the type of any field of `schema` is a dictionary type, the + corresponding entry of `dictionaries` must be an array containing + every value which may be taken by the corresponding column or an + error will be raised in parsing. + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + DirectoryPartitioning + + Examples + -------- + >>> from pyarrow.dataset import DirectoryPartitioning + >>> partition = DirectoryPartitioning( + ... pa.schema([("year", pa.int16()), ("month", pa.int8())])) + >>> print(partitioning.parse("/2009/11")) + ((year == 2009:int16) and (month == 11:int8)) + """ + + cdef: + CDirectoryPartitioning* directory_partitioning + + def __init__(self, Schema schema not None, dictionaries=None, + segment_encoding="uri"): + cdef: + shared_ptr[CDirectoryPartitioning] c_partitioning + CKeyValuePartitioningOptions c_options + + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + c_partitioning = make_shared[CDirectoryPartitioning]( + pyarrow_unwrap_schema(schema), + _partitioning_dictionaries(schema, dictionaries), + c_options, + ) + self.init(<shared_ptr[CPartitioning]> c_partitioning) + + cdef init(self, const shared_ptr[CPartitioning]& sp): + Partitioning.init(self, sp) + self.directory_partitioning = <CDirectoryPartitioning*> sp.get() + + @staticmethod + def discover(field_names=None, infer_dictionary=False, + max_partition_dictionary_size=0, + schema=None, segment_encoding="uri"): + """ + Discover a DirectoryPartitioning. + + Parameters + ---------- + field_names : list of str + The names to associate with the values from the subdirectory names. + If schema is given, will be populated from the schema. + infer_dictionary : bool, default False + When inferring a schema for partition fields, yield dictionary + encoded types instead of plain types. This can be more efficient + when materializing virtual columns, and Expressions parsed by the + finished Partitioning will include dictionaries of all unique + inspected values for each field. + max_partition_dictionary_size : int, default 0 + Synonymous with infer_dictionary for backwards compatibility with + 1.0: setting this to -1 or None is equivalent to passing + infer_dictionary=True. + schema : Schema, default None + Use this schema instead of inferring a schema from partition + values. Partition values will be validated against this schema + before accumulation into the Partitioning's dictionary. + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + PartitioningFactory + To be used in the FileSystemFactoryOptions. + """ + cdef: + CPartitioningFactoryOptions c_options + vector[c_string] c_field_names + + if max_partition_dictionary_size in {-1, None}: + infer_dictionary = True + elif max_partition_dictionary_size != 0: + raise NotImplementedError("max_partition_dictionary_size must be " + "0, -1, or None") + + if infer_dictionary: + c_options.infer_dictionary = True + + if schema: + c_options.schema = pyarrow_unwrap_schema(schema) + c_field_names = [tobytes(f.name) for f in schema] + elif not field_names: + raise ValueError( + "Neither field_names nor schema was passed; " + "cannot infer field_names") + else: + c_field_names = [tobytes(s) for s in field_names] + + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + + return PartitioningFactory.wrap( + CDirectoryPartitioning.MakeFactory(c_field_names, c_options)) + + @property + def dictionaries(self): + """ + The unique values for each partition field, if available. + + Those values are only available if the Partitioning object was + created through dataset discovery from a PartitioningFactory, or + if the dictionaries were manually specified in the constructor. + If not available, this returns None. + """ + cdef vector[shared_ptr[CArray]] c_arrays + c_arrays = self.directory_partitioning.dictionaries() + res = [] + for arr in c_arrays: + if arr.get() == nullptr: + # Partitioning object has not been created through + # inspected Factory + return None + res.append(pyarrow_wrap_array(arr)) + return res + + +cdef class HivePartitioning(Partitioning): + """ + A Partitioning for "/$key=$value/" nested directories as found in + Apache Hive. + + Multi-level, directory based partitioning scheme originating from + Apache Hive with all data files stored in the leaf directories. Data is + partitioned by static values of a particular column in the schema. + Partition keys are represented in the form $key=$value in directory names. + Field order is ignored, as are missing or unrecognized field names. + + For example, given schema<year:int16, month:int8, day:int8>, a possible + path would be "/year=2009/month=11/day=15". + + Parameters + ---------- + schema : Schema + The schema that describes the partitions present in the file path. + dictionaries : Dict[str, Array] + If the type of any field of `schema` is a dictionary type, the + corresponding entry of `dictionaries` must be an array containing + every value which may be taken by the corresponding column or an + error will be raised in parsing. + null_fallback : str, default "__HIVE_DEFAULT_PARTITION__" + If any field is None then this fallback will be used as a label + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + HivePartitioning + + Examples + -------- + >>> from pyarrow.dataset import HivePartitioning + >>> partitioning = HivePartitioning( + ... pa.schema([("year", pa.int16()), ("month", pa.int8())])) + >>> print(partitioning.parse("/year=2009/month=11")) + ((year == 2009:int16) and (month == 11:int8)) + + """ + + cdef: + CHivePartitioning* hive_partitioning + + def __init__(self, + Schema schema not None, + dictionaries=None, + null_fallback="__HIVE_DEFAULT_PARTITION__", + segment_encoding="uri"): + + cdef: + shared_ptr[CHivePartitioning] c_partitioning + CHivePartitioningOptions c_options + + c_options.null_fallback = tobytes(null_fallback) + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + + c_partitioning = make_shared[CHivePartitioning]( + pyarrow_unwrap_schema(schema), + _partitioning_dictionaries(schema, dictionaries), + c_options, + ) + self.init(<shared_ptr[CPartitioning]> c_partitioning) + + cdef init(self, const shared_ptr[CPartitioning]& sp): + Partitioning.init(self, sp) + self.hive_partitioning = <CHivePartitioning*> sp.get() + + @staticmethod + def discover(infer_dictionary=False, + max_partition_dictionary_size=0, + null_fallback="__HIVE_DEFAULT_PARTITION__", + schema=None, + segment_encoding="uri"): + """ + Discover a HivePartitioning. + + Parameters + ---------- + infer_dictionary : bool, default False + When inferring a schema for partition fields, yield dictionary + encoded types instead of plain. This can be more efficient when + materializing virtual columns, and Expressions parsed by the + finished Partitioning will include dictionaries of all unique + inspected values for each field. + max_partition_dictionary_size : int, default 0 + Synonymous with infer_dictionary for backwards compatibility with + 1.0: setting this to -1 or None is equivalent to passing + infer_dictionary=True. + null_fallback : str, default "__HIVE_DEFAULT_PARTITION__" + When inferring a schema for partition fields this value will be + replaced by null. The default is set to __HIVE_DEFAULT_PARTITION__ + for compatibility with Spark + schema : Schema, default None + Use this schema instead of inferring a schema from partition + values. Partition values will be validated against this schema + before accumulation into the Partitioning's dictionary. + segment_encoding : str, default "uri" + After splitting paths into segments, decode the segments. Valid + values are "uri" (URI-decode segments) and "none" (leave as-is). + + Returns + ------- + PartitioningFactory + To be used in the FileSystemFactoryOptions. + """ + cdef: + CHivePartitioningFactoryOptions c_options + + if max_partition_dictionary_size in {-1, None}: + infer_dictionary = True + elif max_partition_dictionary_size != 0: + raise NotImplementedError("max_partition_dictionary_size must be " + "0, -1, or None") + + if infer_dictionary: + c_options.infer_dictionary = True + + c_options.null_fallback = tobytes(null_fallback) + + if schema: + c_options.schema = pyarrow_unwrap_schema(schema) + + c_options.segment_encoding = _get_segment_encoding(segment_encoding) + + return PartitioningFactory.wrap( + CHivePartitioning.MakeFactory(c_options)) + + @property + def dictionaries(self): + """ + The unique values for each partition field, if available. + + Those values are only available if the Partitioning object was + created through dataset discovery from a PartitioningFactory, or + if the dictionaries were manually specified in the constructor. + If not available, this returns None. + """ + cdef vector[shared_ptr[CArray]] c_arrays + c_arrays = self.hive_partitioning.dictionaries() + res = [] + for arr in c_arrays: + if arr.get() == nullptr: + # Partitioning object has not been created through + # inspected Factory + return None + res.append(pyarrow_wrap_array(arr)) + return res + + +cdef class DatasetFactory(_Weakrefable): + """ + DatasetFactory is used to create a Dataset, inspect the Schema + of the fragments contained in it, and declare a partitioning. + """ + + cdef: + shared_ptr[CDatasetFactory] wrapped + CDatasetFactory* factory + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef init(self, const shared_ptr[CDatasetFactory]& sp): + self.wrapped = sp + self.factory = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CDatasetFactory]& sp): + cdef DatasetFactory self = \ + DatasetFactory.__new__(DatasetFactory) + self.init(sp) + return self + + cdef inline shared_ptr[CDatasetFactory] unwrap(self) nogil: + return self.wrapped + + @property + def root_partition(self): + return Expression.wrap(self.factory.root_partition()) + + @root_partition.setter + def root_partition(self, Expression expr): + check_status(self.factory.SetRootPartition(expr.unwrap())) + + def inspect_schemas(self): + cdef CResult[vector[shared_ptr[CSchema]]] result + cdef CInspectOptions options + with nogil: + result = self.factory.InspectSchemas(options) + + schemas = [] + for s in GetResultValue(result): + schemas.append(pyarrow_wrap_schema(s)) + return schemas + + def inspect(self): + """ + Inspect all data fragments and return a common Schema. + + Returns + ------- + Schema + """ + cdef: + CInspectOptions options + CResult[shared_ptr[CSchema]] result + with nogil: + result = self.factory.Inspect(options) + return pyarrow_wrap_schema(GetResultValue(result)) + + def finish(self, Schema schema=None): + """ + Create a Dataset using the inspected schema or an explicit schema + (if given). + + Parameters + ---------- + schema : Schema, default None + The schema to conform the source to. If None, the inspected + schema is used. + + Returns + ------- + Dataset + """ + cdef: + shared_ptr[CSchema] sp_schema + CResult[shared_ptr[CDataset]] result + + if schema is not None: + sp_schema = pyarrow_unwrap_schema(schema) + with nogil: + result = self.factory.FinishWithSchema(sp_schema) + else: + with nogil: + result = self.factory.Finish() + + return Dataset.wrap(GetResultValue(result)) + + +cdef class FileSystemFactoryOptions(_Weakrefable): + """ + Influences the discovery of filesystem paths. + + Parameters + ---------- + partition_base_dir : str, optional + For the purposes of applying the partitioning, paths will be + stripped of the partition_base_dir. Files not matching the + partition_base_dir prefix will be skipped for partitioning discovery. + The ignored files will still be part of the Dataset, but will not + have partition information. + partitioning : Partitioning/PartitioningFactory, optional + Apply the Partitioning to every discovered Fragment. See Partitioning or + PartitioningFactory documentation. + exclude_invalid_files : bool, optional (default True) + If True, invalid files will be excluded (file format specific check). + This will incur IO for each files in a serial and single threaded + fashion. Disabling this feature will skip the IO, but unsupported + files may be present in the Dataset (resulting in an error at scan + time). + selector_ignore_prefixes : list, optional + When discovering from a Selector (and not from an explicit file list), + ignore files and directories matching any of these prefixes. + By default this is ['.', '_']. + """ + + cdef: + CFileSystemFactoryOptions options + + __slots__ = () # avoid mistakingly creating attributes + + def __init__(self, partition_base_dir=None, partitioning=None, + exclude_invalid_files=None, + list selector_ignore_prefixes=None): + if isinstance(partitioning, PartitioningFactory): + self.partitioning_factory = partitioning + elif isinstance(partitioning, Partitioning): + self.partitioning = partitioning + + if partition_base_dir is not None: + self.partition_base_dir = partition_base_dir + if exclude_invalid_files is not None: + self.exclude_invalid_files = exclude_invalid_files + if selector_ignore_prefixes is not None: + self.selector_ignore_prefixes = selector_ignore_prefixes + + cdef inline CFileSystemFactoryOptions unwrap(self): + return self.options + + @property + def partitioning(self): + """Partitioning to apply to discovered files. + + NOTE: setting this property will overwrite partitioning_factory. + """ + c_partitioning = self.options.partitioning.partitioning() + if c_partitioning.get() == nullptr: + return None + return Partitioning.wrap(c_partitioning) + + @partitioning.setter + def partitioning(self, Partitioning value): + self.options.partitioning = (<Partitioning> value).unwrap() + + @property + def partitioning_factory(self): + """PartitioningFactory to apply to discovered files and + discover a Partitioning. + + NOTE: setting this property will overwrite partitioning. + """ + c_factory = self.options.partitioning.factory() + if c_factory.get() == nullptr: + return None + return PartitioningFactory.wrap(c_factory) + + @partitioning_factory.setter + def partitioning_factory(self, PartitioningFactory value): + self.options.partitioning = (<PartitioningFactory> value).unwrap() + + @property + def partition_base_dir(self): + """ + Base directory to strip paths before applying the partitioning. + """ + return frombytes(self.options.partition_base_dir) + + @partition_base_dir.setter + def partition_base_dir(self, value): + self.options.partition_base_dir = tobytes(value) + + @property + def exclude_invalid_files(self): + """Whether to exclude invalid files.""" + return self.options.exclude_invalid_files + + @exclude_invalid_files.setter + def exclude_invalid_files(self, bint value): + self.options.exclude_invalid_files = value + + @property + def selector_ignore_prefixes(self): + """ + List of prefixes. Files matching one of those prefixes will be + ignored by the discovery process. + """ + return [frombytes(p) for p in self.options.selector_ignore_prefixes] + + @selector_ignore_prefixes.setter + def selector_ignore_prefixes(self, values): + self.options.selector_ignore_prefixes = [tobytes(v) for v in values] + + +cdef class FileSystemDatasetFactory(DatasetFactory): + """ + Create a DatasetFactory from a list of paths with schema inspection. + + Parameters + ---------- + filesystem : pyarrow.fs.FileSystem + Filesystem to discover. + paths_or_selector : pyarrow.fs.Selector or list of path-likes + Either a Selector object or a list of path-like objects. + format : FileFormat + Currently only ParquetFileFormat and IpcFileFormat are supported. + options : FileSystemFactoryOptions, optional + Various flags influencing the discovery of filesystem paths. + """ + + cdef: + CFileSystemDatasetFactory* filesystem_factory + + def __init__(self, FileSystem filesystem not None, paths_or_selector, + FileFormat format not None, + FileSystemFactoryOptions options=None): + cdef: + vector[c_string] paths + CFileSelector c_selector + CResult[shared_ptr[CDatasetFactory]] result + shared_ptr[CFileSystem] c_filesystem + shared_ptr[CFileFormat] c_format + CFileSystemFactoryOptions c_options + + options = options or FileSystemFactoryOptions() + c_options = options.unwrap() + c_filesystem = filesystem.unwrap() + c_format = format.unwrap() + + if isinstance(paths_or_selector, FileSelector): + with nogil: + c_selector = (<FileSelector> paths_or_selector).selector + result = CFileSystemDatasetFactory.MakeFromSelector( + c_filesystem, + c_selector, + c_format, + c_options + ) + elif isinstance(paths_or_selector, (list, tuple)): + paths = [tobytes(s) for s in paths_or_selector] + with nogil: + result = CFileSystemDatasetFactory.MakeFromPaths( + c_filesystem, + paths, + c_format, + c_options + ) + else: + raise TypeError('Must pass either paths or a FileSelector, but ' + 'passed {}'.format(type(paths_or_selector))) + + self.init(GetResultValue(result)) + + cdef init(self, shared_ptr[CDatasetFactory]& sp): + DatasetFactory.init(self, sp) + self.filesystem_factory = <CFileSystemDatasetFactory*> sp.get() + + +cdef class UnionDatasetFactory(DatasetFactory): + """ + Provides a way to inspect/discover a Dataset's expected schema before + materialization. + + Parameters + ---------- + factories : list of DatasetFactory + """ + + cdef: + CUnionDatasetFactory* union_factory + + def __init__(self, list factories): + cdef: + DatasetFactory factory + vector[shared_ptr[CDatasetFactory]] c_factories + for factory in factories: + c_factories.push_back(factory.unwrap()) + self.init(GetResultValue(CUnionDatasetFactory.Make(c_factories))) + + cdef init(self, const shared_ptr[CDatasetFactory]& sp): + DatasetFactory.init(self, sp) + self.union_factory = <CUnionDatasetFactory*> sp.get() + + +cdef class ParquetFactoryOptions(_Weakrefable): + """ + Influences the discovery of parquet dataset. + + Parameters + ---------- + partition_base_dir : str, optional + For the purposes of applying the partitioning, paths will be + stripped of the partition_base_dir. Files not matching the + partition_base_dir prefix will be skipped for partitioning discovery. + The ignored files will still be part of the Dataset, but will not + have partition information. + partitioning : Partitioning, PartitioningFactory, optional + The partitioning scheme applied to fragments, see ``Partitioning``. + validate_column_chunk_paths : bool, default False + Assert that all ColumnChunk paths are consistent. The parquet spec + allows for ColumnChunk data to be stored in multiple files, but + ParquetDatasetFactory supports only a single file with all ColumnChunk + data. If this flag is set construction of a ParquetDatasetFactory will + raise an error if ColumnChunk data is not resident in a single file. + """ + + cdef: + CParquetFactoryOptions options + + __slots__ = () # avoid mistakingly creating attributes + + def __init__(self, partition_base_dir=None, partitioning=None, + validate_column_chunk_paths=False): + if isinstance(partitioning, PartitioningFactory): + self.partitioning_factory = partitioning + elif isinstance(partitioning, Partitioning): + self.partitioning = partitioning + + if partition_base_dir is not None: + self.partition_base_dir = partition_base_dir + + self.options.validate_column_chunk_paths = validate_column_chunk_paths + + cdef inline CParquetFactoryOptions unwrap(self): + return self.options + + @property + def partitioning(self): + """Partitioning to apply to discovered files. + + NOTE: setting this property will overwrite partitioning_factory. + """ + c_partitioning = self.options.partitioning.partitioning() + if c_partitioning.get() == nullptr: + return None + return Partitioning.wrap(c_partitioning) + + @partitioning.setter + def partitioning(self, Partitioning value): + self.options.partitioning = (<Partitioning> value).unwrap() + + @property + def partitioning_factory(self): + """PartitioningFactory to apply to discovered files and + discover a Partitioning. + + NOTE: setting this property will overwrite partitioning. + """ + c_factory = self.options.partitioning.factory() + if c_factory.get() == nullptr: + return None + return PartitioningFactory.wrap(c_factory) + + @partitioning_factory.setter + def partitioning_factory(self, PartitioningFactory value): + self.options.partitioning = (<PartitioningFactory> value).unwrap() + + @property + def partition_base_dir(self): + """ + Base directory to strip paths before applying the partitioning. + """ + return frombytes(self.options.partition_base_dir) + + @partition_base_dir.setter + def partition_base_dir(self, value): + self.options.partition_base_dir = tobytes(value) + + @property + def validate_column_chunk_paths(self): + """ + Base directory to strip paths before applying the partitioning. + """ + return self.options.validate_column_chunk_paths + + @validate_column_chunk_paths.setter + def validate_column_chunk_paths(self, value): + self.options.validate_column_chunk_paths = value + + +cdef class ParquetDatasetFactory(DatasetFactory): + """ + Create a ParquetDatasetFactory from a Parquet `_metadata` file. + + Parameters + ---------- + metadata_path : str + Path to the `_metadata` parquet metadata-only file generated with + `pyarrow.parquet.write_metadata`. + filesystem : pyarrow.fs.FileSystem + Filesystem to read the metadata_path from, and subsequent parquet + files. + format : ParquetFileFormat + Parquet format options. + options : ParquetFactoryOptions, optional + Various flags influencing the discovery of filesystem paths. + """ + + cdef: + CParquetDatasetFactory* parquet_factory + + def __init__(self, metadata_path, FileSystem filesystem not None, + FileFormat format not None, + ParquetFactoryOptions options=None): + cdef: + c_string path + shared_ptr[CFileSystem] c_filesystem + shared_ptr[CParquetFileFormat] c_format + CResult[shared_ptr[CDatasetFactory]] result + CParquetFactoryOptions c_options + + c_path = tobytes(metadata_path) + c_filesystem = filesystem.unwrap() + c_format = static_pointer_cast[CParquetFileFormat, CFileFormat]( + format.unwrap()) + options = options or ParquetFactoryOptions() + c_options = options.unwrap() + + result = CParquetDatasetFactory.MakeFromMetaDataPath( + c_path, c_filesystem, c_format, c_options) + self.init(GetResultValue(result)) + + cdef init(self, shared_ptr[CDatasetFactory]& sp): + DatasetFactory.init(self, sp) + self.parquet_factory = <CParquetDatasetFactory*> sp.get() + + +cdef class RecordBatchIterator(_Weakrefable): + """An iterator over a sequence of record batches.""" + cdef: + # An object that must be kept alive with the iterator. + object iterator_owner + # Iterator is a non-POD type and Cython uses offsetof, leading + # to a compiler warning unless wrapped like so + shared_ptr[CRecordBatchIterator] iterator + + def __init__(self): + _forbid_instantiation(self.__class__, subclasses_instead=False) + + @staticmethod + cdef wrap(object owner, CRecordBatchIterator iterator): + cdef RecordBatchIterator self = \ + RecordBatchIterator.__new__(RecordBatchIterator) + self.iterator_owner = owner + self.iterator = make_shared[CRecordBatchIterator](move(iterator)) + return self + + def __iter__(self): + return self + + def __next__(self): + cdef shared_ptr[CRecordBatch] record_batch + with nogil: + record_batch = GetResultValue(move(self.iterator.get().Next())) + if record_batch == NULL: + raise StopIteration + return pyarrow_wrap_batch(record_batch) + + +class TaggedRecordBatch(collections.namedtuple( + "TaggedRecordBatch", ["record_batch", "fragment"])): + """ + A combination of a record batch and the fragment it came from. + + Parameters + ---------- + record_batch : The record batch. + fragment : fragment of the record batch. + """ + + +cdef class TaggedRecordBatchIterator(_Weakrefable): + """An iterator over a sequence of record batches with fragments.""" + cdef: + object iterator_owner + shared_ptr[CTaggedRecordBatchIterator] iterator + + def __init__(self): + _forbid_instantiation(self.__class__, subclasses_instead=False) + + @staticmethod + cdef wrap(object owner, CTaggedRecordBatchIterator iterator): + cdef TaggedRecordBatchIterator self = \ + TaggedRecordBatchIterator.__new__(TaggedRecordBatchIterator) + self.iterator_owner = owner + self.iterator = make_shared[CTaggedRecordBatchIterator]( + move(iterator)) + return self + + def __iter__(self): + return self + + def __next__(self): + cdef CTaggedRecordBatch batch + with nogil: + batch = GetResultValue(move(self.iterator.get().Next())) + if batch.record_batch == NULL: + raise StopIteration + return TaggedRecordBatch( + record_batch=pyarrow_wrap_batch(batch.record_batch), + fragment=Fragment.wrap(batch.fragment)) + + +_DEFAULT_BATCH_SIZE = 2**20 + + +cdef void _populate_builder(const shared_ptr[CScannerBuilder]& ptr, + object columns=None, Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + bint use_threads=True, bint use_async=False, + MemoryPool memory_pool=None, + FragmentScanOptions fragment_scan_options=None)\ + except *: + cdef: + CScannerBuilder *builder + vector[CExpression] c_exprs + + builder = ptr.get() + + check_status(builder.Filter(_bind( + filter, pyarrow_wrap_schema(builder.schema())))) + + if columns is not None: + if isinstance(columns, dict): + for expr in columns.values(): + if not isinstance(expr, Expression): + raise TypeError( + "Expected an Expression for a 'column' dictionary " + "value, got {} instead".format(type(expr)) + ) + c_exprs.push_back((<Expression> expr).unwrap()) + + check_status( + builder.Project(c_exprs, [tobytes(c) for c in columns.keys()]) + ) + elif isinstance(columns, list): + check_status(builder.ProjectColumns([tobytes(c) for c in columns])) + else: + raise ValueError( + "Expected a list or a dict for 'columns', " + "got {} instead.".format(type(columns)) + ) + + check_status(builder.BatchSize(batch_size)) + check_status(builder.UseThreads(use_threads)) + check_status(builder.UseAsync(use_async)) + if memory_pool: + check_status(builder.Pool(maybe_unbox_memory_pool(memory_pool))) + if fragment_scan_options: + check_status( + builder.FragmentScanOptions(fragment_scan_options.wrapped)) + + +cdef class Scanner(_Weakrefable): + """A materialized scan operation with context and options bound. + + A scanner is the class that glues the scan tasks, data fragments and data + sources together. + + Parameters + ---------- + dataset : Dataset + Dataset to scan. + columns : list of str or dict, default None + The columns to project. This can be a list of column names to include + (order and duplicates will be preserved), or a dictionary with + {new_column_name: expression} values for more advanced projections. + The columns will be passed down to Datasets and corresponding data + fragments to avoid loading, copying, and deserializing columns + that will not be required further down the compute chain. + By default all of the available columns are projected. Raises + an exception if any of the referenced column names does not exist + in the dataset's Schema. + filter : Expression, default None + Scan will return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the + partition information or internal metadata found in the data + source, e.g. Parquet statistics. Otherwise filters the loaded + RecordBatches before yielding them. + batch_size : int, default 1M + The maximum row count for scanned record batches. If scanned + record batches are overflowing memory then this method can be + called to reduce their size. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + use_async : bool, default False + If enabled, an async scanner will be used that should offer + better performance with high-latency/highly-parallel filesystems + (e.g. S3) + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + """ + + cdef: + shared_ptr[CScanner] wrapped + CScanner* scanner + + def __init__(self): + _forbid_instantiation(self.__class__) + + cdef void init(self, const shared_ptr[CScanner]& sp): + self.wrapped = sp + self.scanner = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CScanner]& sp): + cdef Scanner self = Scanner.__new__(Scanner) + self.init(sp) + return self + + cdef inline shared_ptr[CScanner] unwrap(self): + return self.wrapped + + @staticmethod + def from_dataset(Dataset dataset not None, + bint use_threads=True, bint use_async=False, + MemoryPool memory_pool=None, + object columns=None, Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + FragmentScanOptions fragment_scan_options=None): + """ + Create Scanner from Dataset, + refer to Scanner class doc for additional details on Scanner. + + Parameters + ---------- + dataset : Dataset + Dataset to scan. + columns : list of str or dict, default None + The columns to project. + filter : Expression, default None + Scan will return only the rows matching the filter. + batch_size : int, default 1M + The maximum row count for scanned record batches. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + use_async : bool, default False + If enabled, an async scanner will be used that should offer + better performance with high-latency/highly-parallel filesystems + (e.g. S3) + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + fragment_scan_options : FragmentScanOptions + The fragment scan options. + """ + cdef: + shared_ptr[CScanOptions] options = make_shared[CScanOptions]() + shared_ptr[CScannerBuilder] builder + shared_ptr[CScanner] scanner + + builder = make_shared[CScannerBuilder](dataset.unwrap(), options) + _populate_builder(builder, columns=columns, filter=filter, + batch_size=batch_size, use_threads=use_threads, + use_async=use_async, memory_pool=memory_pool, + fragment_scan_options=fragment_scan_options) + + scanner = GetResultValue(builder.get().Finish()) + return Scanner.wrap(scanner) + + @staticmethod + def from_fragment(Fragment fragment not None, Schema schema=None, + bint use_threads=True, bint use_async=False, + MemoryPool memory_pool=None, + object columns=None, Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + FragmentScanOptions fragment_scan_options=None): + """ + Create Scanner from Fragment, + refer to Scanner class doc for additional details on Scanner. + + Parameters + ---------- + fragment : Fragment + fragment to scan. + schema : Schema + The schema of the fragment. + columns : list of str or dict, default None + The columns to project. + filter : Expression, default None + Scan will return only the rows matching the filter. + batch_size : int, default 1M + The maximum row count for scanned record batches. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + use_async : bool, default False + If enabled, an async scanner will be used that should offer + better performance with high-latency/highly-parallel filesystems + (e.g. S3) + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + fragment_scan_options : FragmentScanOptions + The fragment scan options. + """ + cdef: + shared_ptr[CScanOptions] options = make_shared[CScanOptions]() + shared_ptr[CScannerBuilder] builder + shared_ptr[CScanner] scanner + + schema = schema or fragment.physical_schema + + builder = make_shared[CScannerBuilder](pyarrow_unwrap_schema(schema), + fragment.unwrap(), options) + _populate_builder(builder, columns=columns, filter=filter, + batch_size=batch_size, use_threads=use_threads, + use_async=use_async, memory_pool=memory_pool, + fragment_scan_options=fragment_scan_options) + + scanner = GetResultValue(builder.get().Finish()) + return Scanner.wrap(scanner) + + @staticmethod + def from_batches(source, Schema schema=None, bint use_threads=True, + bint use_async=False, + MemoryPool memory_pool=None, object columns=None, + Expression filter=None, + int batch_size=_DEFAULT_BATCH_SIZE, + FragmentScanOptions fragment_scan_options=None): + """ + Create a Scanner from an iterator of batches. + + This creates a scanner which can be used only once. It is + intended to support writing a dataset (which takes a scanner) + from a source which can be read only once (e.g. a + RecordBatchReader or generator). + + Parameters + ---------- + source : Iterator + The iterator of Batches. + schema : Schema + The schema of the batches. + columns : list of str or dict, default None + The columns to project. + filter : Expression, default None + Scan will return only the rows matching the filter. + batch_size : int, default 1M + The maximum row count for scanned record batches. + use_threads : bool, default True + If enabled, then maximum parallelism will be used determined by + the number of available CPU cores. + use_async : bool, default False + If enabled, an async scanner will be used that should offer + better performance with high-latency/highly-parallel filesystems + (e.g. S3) + memory_pool : MemoryPool, default None + For memory allocations, if required. If not specified, uses the + default pool. + fragment_scan_options : FragmentScanOptions + The fragment scan options. + """ + cdef: + shared_ptr[CScanOptions] options = make_shared[CScanOptions]() + shared_ptr[CScannerBuilder] builder + shared_ptr[CScanner] scanner + RecordBatchReader reader + if isinstance(source, pa.ipc.RecordBatchReader): + if schema: + raise ValueError('Cannot specify a schema when providing ' + 'a RecordBatchReader') + reader = source + elif _is_iterable(source): + if schema is None: + raise ValueError('Must provide schema to construct scanner ' + 'from an iterable') + reader = pa.ipc.RecordBatchReader.from_batches(schema, source) + else: + raise TypeError('Expected a RecordBatchReader or an iterable of ' + 'batches instead of the given type: ' + + type(source).__name__) + builder = CScannerBuilder.FromRecordBatchReader(reader.reader) + _populate_builder(builder, columns=columns, filter=filter, + batch_size=batch_size, use_threads=use_threads, + use_async=use_async, memory_pool=memory_pool, + fragment_scan_options=fragment_scan_options) + scanner = GetResultValue(builder.get().Finish()) + return Scanner.wrap(scanner) + + @property + def dataset_schema(self): + """The schema with which batches will be read from fragments.""" + return pyarrow_wrap_schema( + self.scanner.options().get().dataset_schema) + + @property + def projected_schema(self): + """The materialized schema of the data, accounting for projections. + + This is the schema of any data returned from the scanner. + """ + return pyarrow_wrap_schema( + self.scanner.options().get().projected_schema) + + def to_batches(self): + """Consume a Scanner in record batches. + + Returns + ------- + record_batches : iterator of RecordBatch + """ + def _iterator(batch_iter): + for batch in batch_iter: + yield batch.record_batch + # Don't make ourselves a generator so errors are raised immediately + return _iterator(self.scan_batches()) + + def scan_batches(self): + """Consume a Scanner in record batches with corresponding fragments. + + Returns + ------- + record_batches : iterator of TaggedRecordBatch + """ + cdef CTaggedRecordBatchIterator iterator + with nogil: + iterator = move(GetResultValue(self.scanner.ScanBatches())) + # Don't make ourselves a generator so errors are raised immediately + return TaggedRecordBatchIterator.wrap(self, move(iterator)) + + def to_table(self): + """Convert a Scanner into a Table. + + Use this convenience utility with care. This will serially materialize + the Scan result in memory before creating the Table. + + Returns + ------- + table : Table + """ + cdef CResult[shared_ptr[CTable]] result + + with nogil: + result = self.scanner.ToTable() + + return pyarrow_wrap_table(GetResultValue(result)) + + def take(self, object indices): + """Select rows of data by index. + + Will only consume as many batches of the underlying dataset as + needed. Otherwise, this is equivalent to + ``to_table().take(indices)``. + + Returns + ------- + table : Table + """ + cdef CResult[shared_ptr[CTable]] result + cdef shared_ptr[CArray] c_indices = pyarrow_unwrap_array(indices) + with nogil: + result = self.scanner.TakeRows(deref(c_indices)) + return pyarrow_wrap_table(GetResultValue(result)) + + def head(self, int num_rows): + """Load the first N rows of the dataset. + + Returns + ------- + table : Table instance + """ + cdef CResult[shared_ptr[CTable]] result + with nogil: + result = self.scanner.Head(num_rows) + return pyarrow_wrap_table(GetResultValue(result)) + + def count_rows(self): + """Count rows matching the scanner filter. + + Returns + ------- + count : int + """ + cdef CResult[int64_t] result + with nogil: + result = self.scanner.CountRows() + return GetResultValue(result) + + def to_reader(self): + """Consume this scanner as a RecordBatchReader.""" + cdef RecordBatchReader reader + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader = GetResultValue(self.scanner.ToRecordBatchReader()) + return reader + + +def _get_partition_keys(Expression partition_expression): + """ + Extract partition keys (equality constraints between a field and a scalar) + from an expression as a dict mapping the field's name to its value. + + NB: All expressions yielded by a HivePartitioning or DirectoryPartitioning + will be conjunctions of equality conditions and are accessible through this + function. Other subexpressions will be ignored. + + For example, an expression of + <pyarrow.dataset.Expression ((part == A:string) and (year == 2016:int32))> + is converted to {'part': 'A', 'year': 2016} + """ + cdef: + CExpression expr = partition_expression.unwrap() + pair[CFieldRef, CDatum] ref_val + + out = {} + for ref_val in GetResultValue(CExtractKnownFieldValues(expr)).map: + assert ref_val.first.name() != nullptr + assert ref_val.second.kind() == DatumType_SCALAR + val = pyarrow_wrap_scalar(ref_val.second.scalar()) + out[frombytes(deref(ref_val.first.name()))] = val.as_py() + return out + + +ctypedef CParquetFileWriter* _CParquetFileWriterPtr + +cdef class WrittenFile(_Weakrefable): + """ + Metadata information about files written as + part of a dataset write operation + """ + + """The full path to the created file""" + cdef public str path + """ + If the file is a parquet file this will contain the parquet metadata. + This metadata will have the file path attribute set to the path of + the written file. + """ + cdef public object metadata + + def __init__(self, path, metadata): + self.path = path + self.metadata = metadata + +cdef void _filesystemdataset_write_visitor( + dict visit_args, + CFileWriter* file_writer): + cdef: + str path + str base_dir + WrittenFile written_file + FileMetaData parquet_metadata + CParquetFileWriter* parquet_file_writer + + parquet_metadata = None + path = frombytes(deref(file_writer).destination().path) + if deref(deref(file_writer).format()).type_name() == b"parquet": + parquet_file_writer = dynamic_cast[_CParquetFileWriterPtr](file_writer) + with nogil: + metadata = deref( + deref(parquet_file_writer).parquet_writer()).metadata() + if metadata: + base_dir = frombytes(visit_args['base_dir']) + parquet_metadata = FileMetaData() + parquet_metadata.init(metadata) + parquet_metadata.set_file_path(os.path.relpath(path, base_dir)) + written_file = WrittenFile(path, parquet_metadata) + visit_args['file_visitor'](written_file) + + +def _filesystemdataset_write( + Scanner data not None, + object base_dir not None, + str basename_template not None, + FileSystem filesystem not None, + Partitioning partitioning not None, + FileWriteOptions file_options not None, + int max_partitions, + object file_visitor, + str existing_data_behavior not None +): + """ + CFileSystemDataset.Write wrapper + """ + cdef: + CFileSystemDatasetWriteOptions c_options + shared_ptr[CScanner] c_scanner + vector[shared_ptr[CRecordBatch]] c_batches + dict visit_args + + c_options.file_write_options = file_options.unwrap() + c_options.filesystem = filesystem.unwrap() + c_options.base_dir = tobytes(_stringify_path(base_dir)) + c_options.partitioning = partitioning.unwrap() + c_options.max_partitions = max_partitions + c_options.basename_template = tobytes(basename_template) + if existing_data_behavior == 'error': + c_options.existing_data_behavior = ExistingDataBehavior_ERROR + elif existing_data_behavior == 'overwrite_or_ignore': + c_options.existing_data_behavior =\ + ExistingDataBehavior_OVERWRITE_OR_IGNORE + elif existing_data_behavior == 'delete_matching': + c_options.existing_data_behavior = ExistingDataBehavior_DELETE_MATCHING + else: + raise ValueError( + ("existing_data_behavior must be one of 'error', ", + "'overwrite_or_ignore' or 'delete_matching'") + ) + + if file_visitor is not None: + visit_args = {'base_dir': c_options.base_dir, + 'file_visitor': file_visitor} + # Need to use post_finish because parquet metadata is not available + # until after Finish has been called + c_options.writer_post_finish = BindFunction[cb_writer_finish_internal]( + &_filesystemdataset_write_visitor, visit_args) + + c_scanner = data.unwrap() + with nogil: + check_status(CFileSystemDataset.Write(c_options, c_scanner)) diff --git a/src/arrow/python/pyarrow/_dataset_orc.pyx b/src/arrow/python/pyarrow/_dataset_orc.pyx new file mode 100644 index 000000000..40a21ef54 --- /dev/null +++ b/src/arrow/python/pyarrow/_dataset_orc.pyx @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +"""Dataset support for ORC file format.""" + +from pyarrow.lib cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_dataset cimport * + +from pyarrow._dataset cimport FileFormat + + +cdef class OrcFileFormat(FileFormat): + + def __init__(self): + self.init(shared_ptr[CFileFormat](new COrcFileFormat())) + + def equals(self, OrcFileFormat other): + return True + + @property + def default_extname(self): + return "orc" + + def __reduce__(self): + return OrcFileFormat, tuple() diff --git a/src/arrow/python/pyarrow/_feather.pyx b/src/arrow/python/pyarrow/_feather.pyx new file mode 100644 index 000000000..8df7935aa --- /dev/null +++ b/src/arrow/python/pyarrow/_feather.pyx @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# --------------------------------------------------------------------- +# Implement Feather file format + +# cython: profile=False +# distutils: language = c++ +# cython: language_level=3 + +from cython.operator cimport dereference as deref +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_feather cimport * +from pyarrow.lib cimport (check_status, Table, _Weakrefable, + get_writer, get_reader, pyarrow_wrap_table) +from pyarrow.lib import tobytes + + +class FeatherError(Exception): + pass + + +def write_feather(Table table, object dest, compression=None, + compression_level=None, chunksize=None, version=2): + cdef shared_ptr[COutputStream] sink + get_writer(dest, &sink) + + cdef CFeatherProperties properties + if version == 2: + properties.version = kFeatherV2Version + else: + properties.version = kFeatherV1Version + + if compression == 'zstd': + properties.compression = CCompressionType_ZSTD + elif compression == 'lz4': + properties.compression = CCompressionType_LZ4_FRAME + else: + properties.compression = CCompressionType_UNCOMPRESSED + + if chunksize is not None: + properties.chunksize = chunksize + + if compression_level is not None: + properties.compression_level = compression_level + + with nogil: + check_status(WriteFeather(deref(table.table), sink.get(), + properties)) + + +cdef class FeatherReader(_Weakrefable): + cdef: + shared_ptr[CFeatherReader] reader + + def __cinit__(self, source, c_bool use_memory_map): + cdef shared_ptr[CRandomAccessFile] reader + get_reader(source, use_memory_map, &reader) + with nogil: + self.reader = GetResultValue(CFeatherReader.Open(reader)) + + @property + def version(self): + return self.reader.get().version() + + def read(self): + cdef shared_ptr[CTable] sp_table + with nogil: + check_status(self.reader.get() + .Read(&sp_table)) + + return pyarrow_wrap_table(sp_table) + + def read_indices(self, indices): + cdef: + shared_ptr[CTable] sp_table + vector[int] c_indices + + for index in indices: + c_indices.push_back(index) + with nogil: + check_status(self.reader.get() + .Read(c_indices, &sp_table)) + + return pyarrow_wrap_table(sp_table) + + def read_names(self, names): + cdef: + shared_ptr[CTable] sp_table + vector[c_string] c_names + + for name in names: + c_names.push_back(tobytes(name)) + with nogil: + check_status(self.reader.get() + .Read(c_names, &sp_table)) + + return pyarrow_wrap_table(sp_table) diff --git a/src/arrow/python/pyarrow/_flight.pyx b/src/arrow/python/pyarrow/_flight.pyx new file mode 100644 index 000000000..5b00c531d --- /dev/null +++ b/src/arrow/python/pyarrow/_flight.pyx @@ -0,0 +1,2664 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +import collections +import contextlib +import enum +import re +import socket +import time +import threading +import warnings + +from cython.operator cimport dereference as deref +from cython.operator cimport postincrement +from libcpp cimport bool as c_bool + +from pyarrow.lib cimport * +from pyarrow.lib import ArrowException, ArrowInvalid, SignalStopHandler +from pyarrow.lib import as_buffer, frombytes, tobytes +from pyarrow.includes.libarrow_flight cimport * +from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin +import pyarrow.lib as lib + + +cdef CFlightCallOptions DEFAULT_CALL_OPTIONS + + +cdef int check_flight_status(const CStatus& status) nogil except -1: + cdef shared_ptr[FlightStatusDetail] detail + + if status.ok(): + return 0 + + detail = FlightStatusDetail.UnwrapStatus(status) + if detail: + with gil: + message = frombytes(status.message(), safe=True) + detail_msg = detail.get().extra_info() + if detail.get().code() == CFlightStatusInternal: + raise FlightInternalError(message, detail_msg) + elif detail.get().code() == CFlightStatusFailed: + message = _munge_grpc_python_error(message) + raise FlightServerError(message, detail_msg) + elif detail.get().code() == CFlightStatusTimedOut: + raise FlightTimedOutError(message, detail_msg) + elif detail.get().code() == CFlightStatusCancelled: + raise FlightCancelledError(message, detail_msg) + elif detail.get().code() == CFlightStatusUnauthenticated: + raise FlightUnauthenticatedError(message, detail_msg) + elif detail.get().code() == CFlightStatusUnauthorized: + raise FlightUnauthorizedError(message, detail_msg) + elif detail.get().code() == CFlightStatusUnavailable: + raise FlightUnavailableError(message, detail_msg) + + size_detail = FlightWriteSizeStatusDetail.UnwrapStatus(status) + if size_detail: + with gil: + message = frombytes(status.message(), safe=True) + raise FlightWriteSizeExceededError( + message, + size_detail.get().limit(), size_detail.get().actual()) + + return check_status(status) + + +_FLIGHT_SERVER_ERROR_REGEX = re.compile( + r'Flight RPC failed with message: (.*). Detail: ' + r'Python exception: (.*)', + re.DOTALL +) + + +def _munge_grpc_python_error(message): + m = _FLIGHT_SERVER_ERROR_REGEX.match(message) + if m: + return ('Flight RPC failed with Python exception \"{}: {}\"' + .format(m.group(2), m.group(1))) + else: + return message + + +cdef IpcWriteOptions _get_options(options): + return <IpcWriteOptions> _get_legacy_format_default( + use_legacy_format=None, options=options) + + +cdef class FlightCallOptions(_Weakrefable): + """RPC-layer options for a Flight call.""" + + cdef: + CFlightCallOptions options + + def __init__(self, timeout=None, write_options=None, headers=None): + """Create call options. + + Parameters + ---------- + timeout : float, None + A timeout for the call, in seconds. None means that the + timeout defaults to an implementation-specific value. + write_options : pyarrow.ipc.IpcWriteOptions, optional + IPC write options. The default options can be controlled + by environment variables (see pyarrow.ipc). + headers : List[Tuple[str, str]], optional + A list of arbitrary headers as key, value tuples + """ + cdef IpcWriteOptions c_write_options + + if timeout is not None: + self.options.timeout = CTimeoutDuration(timeout) + if write_options is not None: + c_write_options = _get_options(write_options) + self.options.write_options = c_write_options.c_options + if headers is not None: + self.options.headers = headers + + @staticmethod + cdef CFlightCallOptions* unwrap(obj): + if not obj: + return &DEFAULT_CALL_OPTIONS + elif isinstance(obj, FlightCallOptions): + return &((<FlightCallOptions> obj).options) + raise TypeError("Expected a FlightCallOptions object, not " + "'{}'".format(type(obj))) + + +_CertKeyPair = collections.namedtuple('_CertKeyPair', ['cert', 'key']) + + +class CertKeyPair(_CertKeyPair): + """A TLS certificate and key for use in Flight.""" + + +cdef class FlightError(Exception): + cdef dict __dict__ + + def __init__(self, message='', extra_info=b''): + super().__init__(message) + self.extra_info = tobytes(extra_info) + + cdef CStatus to_status(self): + message = tobytes("Flight error: {}".format(str(self))) + return CStatus_UnknownError(message) + +cdef class FlightInternalError(FlightError, ArrowException): + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusInternal, + tobytes(str(self)), self.extra_info) + + +cdef class FlightTimedOutError(FlightError, ArrowException): + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusTimedOut, + tobytes(str(self)), self.extra_info) + + +cdef class FlightCancelledError(FlightError, ArrowException): + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusCancelled, tobytes(str(self)), + self.extra_info) + + +cdef class FlightServerError(FlightError, ArrowException): + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusFailed, tobytes(str(self)), + self.extra_info) + + +cdef class FlightUnauthenticatedError(FlightError, ArrowException): + cdef CStatus to_status(self): + return MakeFlightError( + CFlightStatusUnauthenticated, tobytes(str(self)), self.extra_info) + + +cdef class FlightUnauthorizedError(FlightError, ArrowException): + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusUnauthorized, tobytes(str(self)), + self.extra_info) + + +cdef class FlightUnavailableError(FlightError, ArrowException): + cdef CStatus to_status(self): + return MakeFlightError(CFlightStatusUnavailable, tobytes(str(self)), + self.extra_info) + + +class FlightWriteSizeExceededError(ArrowInvalid): + """A write operation exceeded the client-configured limit.""" + + def __init__(self, message, limit, actual): + super().__init__(message) + self.limit = limit + self.actual = actual + + +cdef class Action(_Weakrefable): + """An action executable on a Flight service.""" + cdef: + CAction action + + def __init__(self, action_type, buf): + """Create an action from a type and a buffer. + + Parameters + ---------- + action_type : bytes or str + buf : Buffer or bytes-like object + """ + self.action.type = tobytes(action_type) + self.action.body = pyarrow_unwrap_buffer(as_buffer(buf)) + + @property + def type(self): + """The action type.""" + return frombytes(self.action.type) + + @property + def body(self): + """The action body (arguments for the action).""" + return pyarrow_wrap_buffer(self.action.body) + + @staticmethod + cdef CAction unwrap(action) except *: + if not isinstance(action, Action): + raise TypeError("Must provide Action, not '{}'".format( + type(action))) + return (<Action> action).action + + +_ActionType = collections.namedtuple('_ActionType', ['type', 'description']) + + +class ActionType(_ActionType): + """A type of action that is executable on a Flight service.""" + + def make_action(self, buf): + """Create an Action with this type. + + Parameters + ---------- + buf : obj + An Arrow buffer or Python bytes or bytes-like object. + """ + return Action(self.type, buf) + + +cdef class Result(_Weakrefable): + """A result from executing an Action.""" + cdef: + unique_ptr[CFlightResult] result + + def __init__(self, buf): + """Create a new result. + + Parameters + ---------- + buf : Buffer or bytes-like object + """ + self.result.reset(new CFlightResult()) + self.result.get().body = pyarrow_unwrap_buffer(as_buffer(buf)) + + @property + def body(self): + """Get the Buffer containing the result.""" + return pyarrow_wrap_buffer(self.result.get().body) + + +cdef class BasicAuth(_Weakrefable): + """A container for basic auth.""" + cdef: + unique_ptr[CBasicAuth] basic_auth + + def __init__(self, username=None, password=None): + """Create a new basic auth object. + + Parameters + ---------- + username : string + password : string + """ + self.basic_auth.reset(new CBasicAuth()) + if username: + self.basic_auth.get().username = tobytes(username) + if password: + self.basic_auth.get().password = tobytes(password) + + @property + def username(self): + """Get the username.""" + return self.basic_auth.get().username + + @property + def password(self): + """Get the password.""" + return self.basic_auth.get().password + + @staticmethod + def deserialize(string): + auth = BasicAuth() + check_flight_status(DeserializeBasicAuth(string, &auth.basic_auth)) + return auth + + def serialize(self): + cdef: + c_string auth + check_flight_status(SerializeBasicAuth(deref(self.basic_auth), &auth)) + return frombytes(auth) + + +class DescriptorType(enum.Enum): + """ + The type of a FlightDescriptor. + + Attributes + ---------- + + UNKNOWN + An unknown descriptor type. + + PATH + A Flight stream represented by a path. + + CMD + A Flight stream represented by an application-defined command. + + """ + + UNKNOWN = 0 + PATH = 1 + CMD = 2 + + +class FlightMethod(enum.Enum): + """The implemented methods in Flight.""" + + INVALID = 0 + HANDSHAKE = 1 + LIST_FLIGHTS = 2 + GET_FLIGHT_INFO = 3 + GET_SCHEMA = 4 + DO_GET = 5 + DO_PUT = 6 + DO_ACTION = 7 + LIST_ACTIONS = 8 + DO_EXCHANGE = 9 + + +cdef wrap_flight_method(CFlightMethod method): + if method == CFlightMethodHandshake: + return FlightMethod.HANDSHAKE + elif method == CFlightMethodListFlights: + return FlightMethod.LIST_FLIGHTS + elif method == CFlightMethodGetFlightInfo: + return FlightMethod.GET_FLIGHT_INFO + elif method == CFlightMethodGetSchema: + return FlightMethod.GET_SCHEMA + elif method == CFlightMethodDoGet: + return FlightMethod.DO_GET + elif method == CFlightMethodDoPut: + return FlightMethod.DO_PUT + elif method == CFlightMethodDoAction: + return FlightMethod.DO_ACTION + elif method == CFlightMethodListActions: + return FlightMethod.LIST_ACTIONS + elif method == CFlightMethodDoExchange: + return FlightMethod.DO_EXCHANGE + return FlightMethod.INVALID + + +cdef class FlightDescriptor(_Weakrefable): + """A description of a data stream available from a Flight service.""" + cdef: + CFlightDescriptor descriptor + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "`pyarrow.flight.FlightDescriptor.for_{path,command}` " + "function instead." + .format(self.__class__.__name__)) + + @staticmethod + def for_path(*path): + """Create a FlightDescriptor for a resource path.""" + cdef FlightDescriptor result = \ + FlightDescriptor.__new__(FlightDescriptor) + result.descriptor.type = CDescriptorTypePath + result.descriptor.path = [tobytes(p) for p in path] + return result + + @staticmethod + def for_command(command): + """Create a FlightDescriptor for an opaque command.""" + cdef FlightDescriptor result = \ + FlightDescriptor.__new__(FlightDescriptor) + result.descriptor.type = CDescriptorTypeCmd + result.descriptor.cmd = tobytes(command) + return result + + @property + def descriptor_type(self): + """Get the type of this descriptor.""" + if self.descriptor.type == CDescriptorTypeUnknown: + return DescriptorType.UNKNOWN + elif self.descriptor.type == CDescriptorTypePath: + return DescriptorType.PATH + elif self.descriptor.type == CDescriptorTypeCmd: + return DescriptorType.CMD + raise RuntimeError("Invalid descriptor type!") + + @property + def command(self): + """Get the command for this descriptor.""" + if self.descriptor_type != DescriptorType.CMD: + return None + return self.descriptor.cmd + + @property + def path(self): + """Get the path for this descriptor.""" + if self.descriptor_type != DescriptorType.PATH: + return None + return self.descriptor.path + + def __repr__(self): + if self.descriptor_type == DescriptorType.PATH: + return "<FlightDescriptor path: {!r}>".format(self.path) + elif self.descriptor_type == DescriptorType.CMD: + return "<FlightDescriptor command: {!r}>".format(self.command) + else: + return "<FlightDescriptor type: {!r}>".format(self.descriptor_type) + + @staticmethod + cdef CFlightDescriptor unwrap(descriptor) except *: + if not isinstance(descriptor, FlightDescriptor): + raise TypeError("Must provide a FlightDescriptor, not '{}'".format( + type(descriptor))) + return (<FlightDescriptor> descriptor).descriptor + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef c_string out + check_flight_status(self.descriptor.SerializeToString(&out)) + return out + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef FlightDescriptor descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + check_flight_status(CFlightDescriptor.Deserialize( + tobytes(serialized), &descriptor.descriptor)) + return descriptor + + def __eq__(self, FlightDescriptor other): + return self.descriptor == other.descriptor + + +cdef class Ticket(_Weakrefable): + """A ticket for requesting a Flight stream.""" + + cdef: + CTicket ticket + + def __init__(self, ticket): + self.ticket.ticket = tobytes(ticket) + + @property + def ticket(self): + return self.ticket.ticket + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef c_string out + check_flight_status(self.ticket.SerializeToString(&out)) + return out + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef: + CTicket c_ticket + Ticket ticket + check_flight_status( + CTicket.Deserialize(tobytes(serialized), &c_ticket)) + ticket = Ticket.__new__(Ticket) + ticket.ticket = c_ticket + return ticket + + def __eq__(self, Ticket other): + return self.ticket == other.ticket + + def __repr__(self): + return '<Ticket {}>'.format(self.ticket.ticket) + + +cdef class Location(_Weakrefable): + """The location of a Flight service.""" + cdef: + CLocation location + + def __init__(self, uri): + check_flight_status(CLocation.Parse(tobytes(uri), &self.location)) + + def __repr__(self): + return '<Location {}>'.format(self.location.ToString()) + + @property + def uri(self): + return self.location.ToString() + + def equals(self, Location other): + return self == other + + def __eq__(self, other): + if not isinstance(other, Location): + return NotImplemented + return self.location.Equals((<Location> other).location) + + @staticmethod + def for_grpc_tcp(host, port): + """Create a Location for a TCP-based gRPC service.""" + cdef: + c_string c_host = tobytes(host) + int c_port = port + Location result = Location.__new__(Location) + check_flight_status( + CLocation.ForGrpcTcp(c_host, c_port, &result.location)) + return result + + @staticmethod + def for_grpc_tls(host, port): + """Create a Location for a TLS-based gRPC service.""" + cdef: + c_string c_host = tobytes(host) + int c_port = port + Location result = Location.__new__(Location) + check_flight_status( + CLocation.ForGrpcTls(c_host, c_port, &result.location)) + return result + + @staticmethod + def for_grpc_unix(path): + """Create a Location for a domain socket-based gRPC service.""" + cdef: + c_string c_path = tobytes(path) + Location result = Location.__new__(Location) + check_flight_status(CLocation.ForGrpcUnix(c_path, &result.location)) + return result + + @staticmethod + cdef Location wrap(CLocation location): + cdef Location result = Location.__new__(Location) + result.location = location + return result + + @staticmethod + cdef CLocation unwrap(object location) except *: + cdef CLocation c_location + if isinstance(location, str): + check_flight_status( + CLocation.Parse(tobytes(location), &c_location)) + return c_location + elif not isinstance(location, Location): + raise TypeError("Must provide a Location, not '{}'".format( + type(location))) + return (<Location> location).location + + +cdef class FlightEndpoint(_Weakrefable): + """A Flight stream, along with the ticket and locations to access it.""" + cdef: + CFlightEndpoint endpoint + + def __init__(self, ticket, locations): + """Create a FlightEndpoint from a ticket and list of locations. + + Parameters + ---------- + ticket : Ticket or bytes + the ticket needed to access this flight + locations : list of string URIs + locations where this flight is available + + Raises + ------ + ArrowException + If one of the location URIs is not a valid URI. + """ + cdef: + CLocation c_location + + if isinstance(ticket, Ticket): + self.endpoint.ticket.ticket = tobytes(ticket.ticket) + else: + self.endpoint.ticket.ticket = tobytes(ticket) + + for location in locations: + if isinstance(location, Location): + c_location = (<Location> location).location + else: + c_location = CLocation() + check_flight_status( + CLocation.Parse(tobytes(location), &c_location)) + self.endpoint.locations.push_back(c_location) + + @property + def ticket(self): + """Get the ticket in this endpoint.""" + return Ticket(self.endpoint.ticket.ticket) + + @property + def locations(self): + return [Location.wrap(location) + for location in self.endpoint.locations] + + def __repr__(self): + return "<FlightEndpoint ticket: {!r} locations: {!r}>".format( + self.ticket, self.locations) + + def __eq__(self, FlightEndpoint other): + return self.endpoint == other.endpoint + + +cdef class SchemaResult(_Weakrefable): + """A result from a getschema request. Holding a schema""" + cdef: + unique_ptr[CSchemaResult] result + + def __init__(self, Schema schema): + """Create a SchemaResult from a schema. + + Parameters + ---------- + schema: Schema + the schema of the data in this flight. + """ + cdef: + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + check_flight_status(CreateSchemaResult(c_schema, &self.result)) + + @property + def schema(self): + """The schema of the data in this flight.""" + cdef: + shared_ptr[CSchema] schema + CDictionaryMemo dummy_memo + + check_flight_status(self.result.get().GetSchema(&dummy_memo, &schema)) + return pyarrow_wrap_schema(schema) + + +cdef class FlightInfo(_Weakrefable): + """A description of a Flight stream.""" + cdef: + unique_ptr[CFlightInfo] info + + def __init__(self, Schema schema, FlightDescriptor descriptor, endpoints, + total_records, total_bytes): + """Create a FlightInfo object from a schema, descriptor, and endpoints. + + Parameters + ---------- + schema : Schema + the schema of the data in this flight. + descriptor : FlightDescriptor + the descriptor for this flight. + endpoints : list of FlightEndpoint + a list of endpoints where this flight is available. + total_records : int + the total records in this flight, or -1 if unknown + total_bytes : int + the total bytes in this flight, or -1 if unknown + """ + cdef: + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + vector[CFlightEndpoint] c_endpoints + + for endpoint in endpoints: + if isinstance(endpoint, FlightEndpoint): + c_endpoints.push_back((<FlightEndpoint> endpoint).endpoint) + else: + raise TypeError('Endpoint {} is not instance of' + ' FlightEndpoint'.format(endpoint)) + + check_flight_status(CreateFlightInfo(c_schema, + descriptor.descriptor, + c_endpoints, + total_records, + total_bytes, &self.info)) + + @property + def total_records(self): + """The total record count of this flight, or -1 if unknown.""" + return self.info.get().total_records() + + @property + def total_bytes(self): + """The size in bytes of the data in this flight, or -1 if unknown.""" + return self.info.get().total_bytes() + + @property + def schema(self): + """The schema of the data in this flight.""" + cdef: + shared_ptr[CSchema] schema + CDictionaryMemo dummy_memo + + check_flight_status(self.info.get().GetSchema(&dummy_memo, &schema)) + return pyarrow_wrap_schema(schema) + + @property + def descriptor(self): + """The descriptor of the data in this flight.""" + cdef FlightDescriptor result = \ + FlightDescriptor.__new__(FlightDescriptor) + result.descriptor = self.info.get().descriptor() + return result + + @property + def endpoints(self): + """The endpoints where this flight is available.""" + # TODO: get Cython to iterate over reference directly + cdef: + vector[CFlightEndpoint] endpoints = self.info.get().endpoints() + FlightEndpoint py_endpoint + + result = [] + for endpoint in endpoints: + py_endpoint = FlightEndpoint.__new__(FlightEndpoint) + py_endpoint.endpoint = endpoint + result.append(py_endpoint) + return result + + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef c_string out + check_flight_status(self.info.get().SerializeToString(&out)) + return out + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef FlightInfo info = FlightInfo.__new__(FlightInfo) + check_flight_status(CFlightInfo.Deserialize( + tobytes(serialized), &info.info)) + return info + + +cdef class FlightStreamChunk(_Weakrefable): + """A RecordBatch with application metadata on the side.""" + cdef: + CFlightStreamChunk chunk + + @property + def data(self): + if self.chunk.data == NULL: + return None + return pyarrow_wrap_batch(self.chunk.data) + + @property + def app_metadata(self): + if self.chunk.app_metadata == NULL: + return None + return pyarrow_wrap_buffer(self.chunk.app_metadata) + + def __iter__(self): + return iter((self.data, self.app_metadata)) + + def __repr__(self): + return "<FlightStreamChunk with data: {} with metadata: {}>".format( + self.chunk.data != NULL, self.chunk.app_metadata != NULL) + + +cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin): + """A reader for Flight streams.""" + + # Needs to be separate class so the "real" class can subclass the + # pure-Python mixin class + + cdef dict __dict__ + cdef shared_ptr[CMetadataRecordBatchReader] reader + + def __iter__(self): + while True: + yield self.read_chunk() + + @property + def schema(self): + """Get the schema for this reader.""" + cdef shared_ptr[CSchema] c_schema + with nogil: + c_schema = GetResultValue(self.reader.get().GetSchema()) + return pyarrow_wrap_schema(c_schema) + + def read_all(self): + """Read the entire contents of the stream as a Table.""" + cdef: + shared_ptr[CTable] c_table + with nogil: + check_flight_status(self.reader.get().ReadAll(&c_table)) + return pyarrow_wrap_table(c_table) + + def read_chunk(self): + """Read the next RecordBatch along with any metadata. + + Returns + ------- + data : RecordBatch + The next RecordBatch in the stream. + app_metadata : Buffer or None + Application-specific metadata for the batch as defined by + Flight. + + Raises + ------ + StopIteration + when the stream is finished + """ + cdef: + FlightStreamChunk chunk = FlightStreamChunk() + + with nogil: + check_flight_status(self.reader.get().Next(&chunk.chunk)) + + if chunk.chunk.data == NULL and chunk.chunk.app_metadata == NULL: + raise StopIteration + + return chunk + + def to_reader(self): + """Convert this reader into a regular RecordBatchReader. + + This may fail if the schema cannot be read from the remote end. + """ + cdef RecordBatchReader reader + reader = RecordBatchReader.__new__(RecordBatchReader) + reader.reader = GetResultValue(MakeRecordBatchReader(self.reader)) + return reader + + +cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader): + """The virtual base class for readers for Flight streams.""" + + +cdef class FlightStreamReader(MetadataRecordBatchReader): + """A reader that can also be canceled.""" + + def cancel(self): + """Cancel the read operation.""" + with nogil: + (<CFlightStreamReader*> self.reader.get()).Cancel() + + def read_all(self): + """Read the entire contents of the stream as a Table.""" + cdef: + shared_ptr[CTable] c_table + CStopToken stop_token + with SignalStopHandler() as stop_handler: + stop_token = (<StopToken> stop_handler.stop_token).stop_token + with nogil: + check_flight_status( + (<CFlightStreamReader*> self.reader.get()) + .ReadAllWithStopToken(&c_table, stop_token)) + return pyarrow_wrap_table(c_table) + + +cdef class MetadataRecordBatchWriter(_CRecordBatchWriter): + """A RecordBatchWriter that also allows writing application metadata. + + This class is a context manager; on exit, close() will be called. + """ + + cdef CMetadataRecordBatchWriter* _writer(self) nogil: + return <CMetadataRecordBatchWriter*> self.writer.get() + + def begin(self, schema: Schema, options=None): + """Prepare to write data to this stream with the given schema.""" + cdef: + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + CIpcWriteOptions c_options = _get_options(options).c_options + with nogil: + check_flight_status(self._writer().Begin(c_schema, c_options)) + + def write_metadata(self, buf): + """Write Flight metadata by itself.""" + cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf)) + with nogil: + check_flight_status( + self._writer().WriteMetadata(c_buf)) + + def write_batch(self, RecordBatch batch): + """ + Write RecordBatch to stream. + + Parameters + ---------- + batch : RecordBatch + """ + # Override superclass method to use check_flight_status so we + # can generate FlightWriteSizeExceededError. We don't do this + # for write_table as callers who intend to handle the error + # and retry with a smaller batch should be working with + # individual batches to have control. + with nogil: + check_flight_status( + self._writer().WriteRecordBatch(deref(batch.batch))) + + def write_table(self, Table table, max_chunksize=None, **kwargs): + """ + Write Table to stream in (contiguous) RecordBatch objects. + + Parameters + ---------- + table : Table + max_chunksize : int, default None + Maximum size for RecordBatch chunks. Individual chunks may be + smaller depending on the chunk layout of individual columns. + """ + cdef: + # max_chunksize must be > 0 to have any impact + int64_t c_max_chunksize = -1 + + if 'chunksize' in kwargs: + max_chunksize = kwargs['chunksize'] + msg = ('The parameter chunksize is deprecated for the write_table ' + 'methods as of 0.15, please use parameter ' + 'max_chunksize instead') + warnings.warn(msg, FutureWarning) + + if max_chunksize is not None: + c_max_chunksize = max_chunksize + + with nogil: + check_flight_status( + self._writer().WriteTable(table.table[0], c_max_chunksize)) + + def close(self): + """ + Close stream and write end-of-stream 0 marker. + """ + with nogil: + check_flight_status(self._writer().Close()) + + def write_with_metadata(self, RecordBatch batch, buf): + """Write a RecordBatch along with Flight metadata. + + Parameters + ---------- + batch : RecordBatch + The next RecordBatch in the stream. + buf : Buffer + Application-specific metadata for the batch as defined by + Flight. + """ + cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf)) + with nogil: + check_flight_status( + self._writer().WriteWithMetadata(deref(batch.batch), c_buf)) + + +cdef class FlightStreamWriter(MetadataRecordBatchWriter): + """A writer that also allows closing the write side of a stream.""" + + def done_writing(self): + """Indicate that the client is done writing, but not done reading.""" + with nogil: + check_flight_status( + (<CFlightStreamWriter*> self.writer.get()).DoneWriting()) + + +cdef class FlightMetadataReader(_Weakrefable): + """A reader for Flight metadata messages sent during a DoPut.""" + + cdef: + unique_ptr[CFlightMetadataReader] reader + + def read(self): + """Read the next metadata message.""" + cdef shared_ptr[CBuffer] buf + with nogil: + check_flight_status(self.reader.get().ReadMetadata(&buf)) + if buf == NULL: + return None + return pyarrow_wrap_buffer(buf) + + +cdef class FlightMetadataWriter(_Weakrefable): + """A sender for Flight metadata messages during a DoPut.""" + + cdef: + unique_ptr[CFlightMetadataWriter] writer + + def write(self, message): + """Write the next metadata message. + + Parameters + ---------- + message : Buffer + """ + cdef shared_ptr[CBuffer] buf = \ + pyarrow_unwrap_buffer(as_buffer(message)) + with nogil: + check_flight_status(self.writer.get().WriteMetadata(deref(buf))) + + +cdef class FlightClient(_Weakrefable): + """A client to a Flight service. + + Connect to a Flight service on the given host and port. + + Parameters + ---------- + location : str, tuple or Location + Location to connect to. Either a gRPC URI like `grpc://localhost:port`, + a tuple of (host, port) pair, or a Location instance. + tls_root_certs : bytes or None + PEM-encoded + cert_chain: bytes or None + Client certificate if using mutual TLS + private_key: bytes or None + Client private key for cert_chain is using mutual TLS + override_hostname : str or None + Override the hostname checked by TLS. Insecure, use with caution. + middleware : list optional, default None + A list of ClientMiddlewareFactory instances. + write_size_limit_bytes : int optional, default None + A soft limit on the size of a data payload sent to the + server. Enabled if positive. If enabled, writing a record + batch that (when serialized) exceeds this limit will raise an + exception; the client can retry the write with a smaller + batch. + disable_server_verification : boolean optional, default False + A flag that indicates that, if the client is connecting + with TLS, that it skips server verification. If this is + enabled, all other TLS settings are overridden. + generic_options : list optional, default None + A list of generic (string, int or string) option tuples passed + to the underlying transport. Effect is implementation + dependent. + """ + cdef: + unique_ptr[CFlightClient] client + + def __init__(self, location, *, tls_root_certs=None, cert_chain=None, + private_key=None, override_hostname=None, middleware=None, + write_size_limit_bytes=None, + disable_server_verification=None, generic_options=None): + if isinstance(location, (bytes, str)): + location = Location(location) + elif isinstance(location, tuple): + host, port = location + if tls_root_certs or disable_server_verification is not None: + location = Location.for_grpc_tls(host, port) + else: + location = Location.for_grpc_tcp(host, port) + elif not isinstance(location, Location): + raise TypeError('`location` argument must be a string, tuple or a ' + 'Location instance') + self.init(location, tls_root_certs, cert_chain, private_key, + override_hostname, middleware, write_size_limit_bytes, + disable_server_verification, generic_options) + + cdef init(self, Location location, tls_root_certs, cert_chain, + private_key, override_hostname, middleware, + write_size_limit_bytes, disable_server_verification, + generic_options): + cdef: + int c_port = 0 + CLocation c_location = Location.unwrap(location) + CFlightClientOptions c_options = CFlightClientOptions.Defaults() + function[cb_client_middleware_start_call] start_call = \ + &_client_middleware_start_call + CIntStringVariant variant + + if tls_root_certs: + c_options.tls_root_certs = tobytes(tls_root_certs) + if cert_chain: + c_options.cert_chain = tobytes(cert_chain) + if private_key: + c_options.private_key = tobytes(private_key) + if override_hostname: + c_options.override_hostname = tobytes(override_hostname) + if disable_server_verification is not None: + c_options.disable_server_verification = disable_server_verification + if middleware: + for factory in middleware: + c_options.middleware.push_back( + <shared_ptr[CClientMiddlewareFactory]> + make_shared[CPyClientMiddlewareFactory]( + <PyObject*> factory, start_call)) + if write_size_limit_bytes is not None: + c_options.write_size_limit_bytes = write_size_limit_bytes + else: + c_options.write_size_limit_bytes = 0 + if generic_options: + for key, value in generic_options: + if isinstance(value, (str, bytes)): + variant = CIntStringVariant(<c_string> tobytes(value)) + else: + variant = CIntStringVariant(<int> value) + c_options.generic_options.push_back( + pair[c_string, CIntStringVariant](tobytes(key), variant)) + + with nogil: + check_flight_status(CFlightClient.Connect(c_location, c_options, + &self.client)) + + def wait_for_available(self, timeout=5): + """Block until the server can be contacted. + + Parameters + ---------- + timeout : int, default 5 + The maximum seconds to wait. + """ + deadline = time.time() + timeout + while True: + try: + list(self.list_flights()) + except FlightUnavailableError: + if time.time() < deadline: + time.sleep(0.025) + continue + else: + raise + except NotImplementedError: + # allow if list_flights is not implemented, because + # the server can be contacted nonetheless + break + else: + break + + @classmethod + def connect(cls, location, tls_root_certs=None, cert_chain=None, + private_key=None, override_hostname=None, + disable_server_verification=None): + warnings.warn("The 'FlightClient.connect' method is deprecated, use " + "FlightClient constructor or pyarrow.flight.connect " + "function instead") + return FlightClient( + location, tls_root_certs=tls_root_certs, + cert_chain=cert_chain, private_key=private_key, + override_hostname=override_hostname, + disable_server_verification=disable_server_verification + ) + + def authenticate(self, auth_handler, options: FlightCallOptions = None): + """Authenticate to the server. + + Parameters + ---------- + auth_handler : ClientAuthHandler + The authentication mechanism to use. + options : FlightCallOptions + Options for this call. + """ + cdef: + unique_ptr[CClientAuthHandler] handler + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + + if not isinstance(auth_handler, ClientAuthHandler): + raise TypeError( + "FlightClient.authenticate takes a ClientAuthHandler, " + "not '{}'".format(type(auth_handler))) + handler.reset((<ClientAuthHandler> auth_handler).to_handler()) + with nogil: + check_flight_status( + self.client.get().Authenticate(deref(c_options), + move(handler))) + + def authenticate_basic_token(self, username, password, + options: FlightCallOptions = None): + """Authenticate to the server with HTTP basic authentication. + + Parameters + ---------- + username : string + Username to authenticate with + password : string + Password to authenticate with + options : FlightCallOptions + Options for this call + + Returns + ------- + tuple : Tuple[str, str] + A tuple representing the FlightCallOptions authorization + header entry of a bearer token. + """ + cdef: + CResult[pair[c_string, c_string]] result + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + c_string user = tobytes(username) + c_string pw = tobytes(password) + + with nogil: + result = self.client.get().AuthenticateBasicToken(deref(c_options), + user, pw) + check_flight_status(result.status()) + + return GetResultValue(result) + + def list_actions(self, options: FlightCallOptions = None): + """List the actions available on a service.""" + cdef: + vector[CActionType] results + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + + with SignalStopHandler() as stop_handler: + c_options.stop_token = \ + (<StopToken> stop_handler.stop_token).stop_token + with nogil: + check_flight_status( + self.client.get().ListActions(deref(c_options), &results)) + + result = [] + for action_type in results: + py_action = ActionType(frombytes(action_type.type), + frombytes(action_type.description)) + result.append(py_action) + + return result + + def do_action(self, action, options: FlightCallOptions = None): + """ + Execute an action on a service. + + Parameters + ---------- + action : str, tuple, or Action + Can be action type name (no body), type and body, or any Action + object + options : FlightCallOptions + RPC options + + Returns + ------- + results : iterator of Result values + """ + cdef: + unique_ptr[CResultStream] results + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + + if isinstance(action, (str, bytes)): + action = Action(action, b'') + elif isinstance(action, tuple): + action = Action(*action) + elif not isinstance(action, Action): + raise TypeError("Action must be Action instance, string, or tuple") + + cdef CAction c_action = Action.unwrap(<Action> action) + with nogil: + check_flight_status( + self.client.get().DoAction( + deref(c_options), c_action, &results)) + + def _do_action_response(): + cdef: + Result result + while True: + result = Result.__new__(Result) + with nogil: + check_flight_status(results.get().Next(&result.result)) + if result.result == NULL: + break + yield result + return _do_action_response() + + def list_flights(self, criteria: bytes = None, + options: FlightCallOptions = None): + """List the flights available on a service.""" + cdef: + unique_ptr[CFlightListing] listing + FlightInfo result + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CCriteria c_criteria + + if criteria: + c_criteria.expression = tobytes(criteria) + + with SignalStopHandler() as stop_handler: + c_options.stop_token = \ + (<StopToken> stop_handler.stop_token).stop_token + with nogil: + check_flight_status( + self.client.get().ListFlights(deref(c_options), + c_criteria, &listing)) + + while True: + result = FlightInfo.__new__(FlightInfo) + with nogil: + check_flight_status(listing.get().Next(&result.info)) + if result.info == NULL: + break + yield result + + def get_flight_info(self, descriptor: FlightDescriptor, + options: FlightCallOptions = None): + """Request information about an available flight.""" + cdef: + FlightInfo result = FlightInfo.__new__(FlightInfo) + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + + with nogil: + check_flight_status(self.client.get().GetFlightInfo( + deref(c_options), c_descriptor, &result.info)) + + return result + + def get_schema(self, descriptor: FlightDescriptor, + options: FlightCallOptions = None): + """Request schema for an available flight.""" + cdef: + SchemaResult result = SchemaResult.__new__(SchemaResult) + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + with nogil: + check_status( + self.client.get() + .GetSchema(deref(c_options), c_descriptor, &result.result) + ) + + return result + + def do_get(self, ticket: Ticket, options: FlightCallOptions = None): + """Request the data for a flight. + + Returns + ------- + reader : FlightStreamReader + """ + cdef: + unique_ptr[CFlightStreamReader] reader + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + + with nogil: + check_flight_status( + self.client.get().DoGet( + deref(c_options), ticket.ticket, &reader)) + result = FlightStreamReader() + result.reader.reset(reader.release()) + return result + + def do_put(self, descriptor: FlightDescriptor, schema: Schema, + options: FlightCallOptions = None): + """Upload data to a flight. + + Returns + ------- + writer : FlightStreamWriter + reader : FlightMetadataReader + """ + cdef: + shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema) + unique_ptr[CFlightStreamWriter] writer + unique_ptr[CFlightMetadataReader] metadata_reader + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + FlightMetadataReader reader = FlightMetadataReader() + + with nogil: + check_flight_status(self.client.get().DoPut( + deref(c_options), + c_descriptor, + c_schema, + &writer, + &reader.reader)) + result = FlightStreamWriter() + result.writer.reset(writer.release()) + return result, reader + + def do_exchange(self, descriptor: FlightDescriptor, + options: FlightCallOptions = None): + """Start a bidirectional data exchange with a server. + + Parameters + ---------- + descriptor : FlightDescriptor + A descriptor for the flight. + options : FlightCallOptions + RPC options. + + Returns + ------- + writer : FlightStreamWriter + reader : FlightStreamReader + """ + cdef: + unique_ptr[CFlightStreamWriter] c_writer + unique_ptr[CFlightStreamReader] c_reader + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + CFlightDescriptor c_descriptor = \ + FlightDescriptor.unwrap(descriptor) + + with nogil: + check_flight_status(self.client.get().DoExchange( + deref(c_options), + c_descriptor, + &c_writer, + &c_reader)) + py_writer = FlightStreamWriter() + py_writer.writer.reset(c_writer.release()) + py_reader = FlightStreamReader() + py_reader.reader.reset(c_reader.release()) + return py_writer, py_reader + + +cdef class FlightDataStream(_Weakrefable): + """Abstract base class for Flight data streams.""" + + cdef CFlightDataStream* to_stream(self) except *: + """Create the C++ data stream for the backing Python object. + + We don't expose the C++ object to Python, so we can manage its + lifetime from the Cython/C++ side. + """ + raise NotImplementedError + + +cdef class RecordBatchStream(FlightDataStream): + """A Flight data stream backed by RecordBatches.""" + cdef: + object data_source + CIpcWriteOptions write_options + + def __init__(self, data_source, options=None): + """Create a RecordBatchStream from a data source. + + Parameters + ---------- + data_source : RecordBatchReader or Table + options : pyarrow.ipc.IpcWriteOptions, optional + """ + if (not isinstance(data_source, RecordBatchReader) and + not isinstance(data_source, lib.Table)): + raise TypeError("Expected RecordBatchReader or Table, " + "but got: {}".format(type(data_source))) + self.data_source = data_source + self.write_options = _get_options(options).c_options + + cdef CFlightDataStream* to_stream(self) except *: + cdef: + shared_ptr[CRecordBatchReader] reader + if isinstance(self.data_source, RecordBatchReader): + reader = (<RecordBatchReader> self.data_source).reader + elif isinstance(self.data_source, lib.Table): + table = (<Table> self.data_source).table + reader.reset(new TableBatchReader(deref(table))) + else: + raise RuntimeError("Can't construct RecordBatchStream " + "from type {}".format(type(self.data_source))) + return new CRecordBatchStream(reader, self.write_options) + + +cdef class GeneratorStream(FlightDataStream): + """A Flight data stream backed by a Python generator.""" + cdef: + shared_ptr[CSchema] schema + object generator + # A substream currently being consumed by the client, if + # present. Produced by the generator. + unique_ptr[CFlightDataStream] current_stream + CIpcWriteOptions c_options + + def __init__(self, schema, generator, options=None): + """Create a GeneratorStream from a Python generator. + + Parameters + ---------- + schema : Schema + The schema for the data to be returned. + + generator : iterator or iterable + The generator should yield other FlightDataStream objects, + Tables, RecordBatches, or RecordBatchReaders. + + options : pyarrow.ipc.IpcWriteOptions, optional + """ + self.schema = pyarrow_unwrap_schema(schema) + self.generator = iter(generator) + self.c_options = _get_options(options).c_options + + cdef CFlightDataStream* to_stream(self) except *: + cdef: + function[cb_data_stream_next] callback = &_data_stream_next + return new CPyGeneratorFlightDataStream(self, self.schema, callback, + self.c_options) + + +cdef class ServerCallContext(_Weakrefable): + """Per-call state/context.""" + cdef: + const CServerCallContext* context + + def peer_identity(self): + """Get the identity of the authenticated peer. + + May be the empty string. + """ + return tobytes(self.context.peer_identity()) + + def peer(self): + """Get the address of the peer.""" + # Set safe=True as gRPC on Windows sometimes gives garbage bytes + return frombytes(self.context.peer(), safe=True) + + def is_cancelled(self): + return self.context.is_cancelled() + + def get_middleware(self, key): + """ + Get a middleware instance by key. + + Returns None if the middleware was not found. + """ + cdef: + CServerMiddleware* c_middleware = \ + self.context.GetMiddleware(CPyServerMiddlewareName) + CPyServerMiddleware* middleware + if c_middleware == NULL: + return None + if c_middleware.name() != CPyServerMiddlewareName: + return None + middleware = <CPyServerMiddleware*> c_middleware + py_middleware = <_ServerMiddlewareWrapper> middleware.py_object() + return py_middleware.middleware.get(key) + + @staticmethod + cdef ServerCallContext wrap(const CServerCallContext& context): + cdef ServerCallContext result = \ + ServerCallContext.__new__(ServerCallContext) + result.context = &context + return result + + +cdef class ServerAuthReader(_Weakrefable): + """A reader for messages from the client during an auth handshake.""" + cdef: + CServerAuthReader* reader + + def read(self): + cdef c_string token + if not self.reader: + raise ValueError("Cannot use ServerAuthReader outside " + "ServerAuthHandler.authenticate") + with nogil: + check_flight_status(self.reader.Read(&token)) + return token + + cdef void poison(self): + """Prevent further usage of this object. + + This object is constructed by taking a pointer to a reference, + so we want to make sure Python users do not access this after + the reference goes away. + """ + self.reader = NULL + + @staticmethod + cdef ServerAuthReader wrap(CServerAuthReader* reader): + cdef ServerAuthReader result = \ + ServerAuthReader.__new__(ServerAuthReader) + result.reader = reader + return result + + +cdef class ServerAuthSender(_Weakrefable): + """A writer for messages to the client during an auth handshake.""" + cdef: + CServerAuthSender* sender + + def write(self, message): + cdef c_string c_message = tobytes(message) + if not self.sender: + raise ValueError("Cannot use ServerAuthSender outside " + "ServerAuthHandler.authenticate") + with nogil: + check_flight_status(self.sender.Write(c_message)) + + cdef void poison(self): + """Prevent further usage of this object. + + This object is constructed by taking a pointer to a reference, + so we want to make sure Python users do not access this after + the reference goes away. + """ + self.sender = NULL + + @staticmethod + cdef ServerAuthSender wrap(CServerAuthSender* sender): + cdef ServerAuthSender result = \ + ServerAuthSender.__new__(ServerAuthSender) + result.sender = sender + return result + + +cdef class ClientAuthReader(_Weakrefable): + """A reader for messages from the server during an auth handshake.""" + cdef: + CClientAuthReader* reader + + def read(self): + cdef c_string token + if not self.reader: + raise ValueError("Cannot use ClientAuthReader outside " + "ClientAuthHandler.authenticate") + with nogil: + check_flight_status(self.reader.Read(&token)) + return token + + cdef void poison(self): + """Prevent further usage of this object. + + This object is constructed by taking a pointer to a reference, + so we want to make sure Python users do not access this after + the reference goes away. + """ + self.reader = NULL + + @staticmethod + cdef ClientAuthReader wrap(CClientAuthReader* reader): + cdef ClientAuthReader result = \ + ClientAuthReader.__new__(ClientAuthReader) + result.reader = reader + return result + + +cdef class ClientAuthSender(_Weakrefable): + """A writer for messages to the server during an auth handshake.""" + cdef: + CClientAuthSender* sender + + def write(self, message): + cdef c_string c_message = tobytes(message) + if not self.sender: + raise ValueError("Cannot use ClientAuthSender outside " + "ClientAuthHandler.authenticate") + with nogil: + check_flight_status(self.sender.Write(c_message)) + + cdef void poison(self): + """Prevent further usage of this object. + + This object is constructed by taking a pointer to a reference, + so we want to make sure Python users do not access this after + the reference goes away. + """ + self.sender = NULL + + @staticmethod + cdef ClientAuthSender wrap(CClientAuthSender* sender): + cdef ClientAuthSender result = \ + ClientAuthSender.__new__(ClientAuthSender) + result.sender = sender + return result + + +cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *: + """Callback for implementing FlightDataStream in Python.""" + cdef: + unique_ptr[CFlightDataStream] data_stream + + py_stream = <object> self + if not isinstance(py_stream, GeneratorStream): + raise RuntimeError("self object in callback is not GeneratorStream") + stream = <GeneratorStream> py_stream + + # The generator is allowed to yield a reader or table which we + # yield from; if that sub-generator is empty, we need to reset and + # try again. However, limit the number of attempts so that we + # don't just spin forever. + max_attempts = 128 + for _ in range(max_attempts): + if stream.current_stream != nullptr: + check_flight_status(stream.current_stream.get().Next(payload)) + # If the stream ended, see if there's another stream from the + # generator + if payload.ipc_message.metadata != nullptr: + return CStatus_OK() + stream.current_stream.reset(nullptr) + + try: + result = next(stream.generator) + except StopIteration: + payload.ipc_message.metadata.reset(<CBuffer*> nullptr) + return CStatus_OK() + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + + if isinstance(result, (list, tuple)): + result, metadata = result + else: + result, metadata = result, None + + if isinstance(result, (Table, RecordBatchReader)): + if metadata: + raise ValueError("Can only return metadata alongside a " + "RecordBatch.") + result = RecordBatchStream(result) + + stream_schema = pyarrow_wrap_schema(stream.schema) + if isinstance(result, FlightDataStream): + if metadata: + raise ValueError("Can only return metadata alongside a " + "RecordBatch.") + data_stream = unique_ptr[CFlightDataStream]( + (<FlightDataStream> result).to_stream()) + substream_schema = pyarrow_wrap_schema(data_stream.get().schema()) + if substream_schema != stream_schema: + raise ValueError("Got a FlightDataStream whose schema " + "does not match the declared schema of this " + "GeneratorStream. " + "Got: {}\nExpected: {}".format( + substream_schema, stream_schema)) + stream.current_stream.reset( + new CPyFlightDataStream(result, move(data_stream))) + # Loop around and try again + continue + elif isinstance(result, RecordBatch): + batch = <RecordBatch> result + if batch.schema != stream_schema: + raise ValueError("Got a RecordBatch whose schema does not " + "match the declared schema of this " + "GeneratorStream. " + "Got: {}\nExpected: {}".format(batch.schema, + stream_schema)) + check_flight_status(GetRecordBatchPayload( + deref(batch.batch), + stream.c_options, + &payload.ipc_message)) + if metadata: + payload.app_metadata = pyarrow_unwrap_buffer( + as_buffer(metadata)) + else: + raise TypeError("GeneratorStream must be initialized with " + "an iterator of FlightDataStream, Table, " + "RecordBatch, or RecordBatchStreamReader objects, " + "not {}.".format(type(result))) + # Don't loop around + return CStatus_OK() + # Ran out of attempts (the RPC handler kept yielding empty tables/readers) + raise RuntimeError("While getting next payload, ran out of attempts to " + "get something to send " + "(application server implementation error)") + + +cdef CStatus _list_flights(void* self, const CServerCallContext& context, + const CCriteria* c_criteria, + unique_ptr[CFlightListing]* listing) except *: + """Callback for implementing ListFlights in Python.""" + cdef: + vector[CFlightInfo] flights + + try: + result = (<object> self).list_flights(ServerCallContext.wrap(context), + c_criteria.expression) + for info in result: + if not isinstance(info, FlightInfo): + raise TypeError("FlightServerBase.list_flights must return " + "FlightInfo instances, but got {}".format( + type(info))) + flights.push_back(deref((<FlightInfo> info).info.get())) + listing.reset(new CSimpleFlightListing(flights)) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _get_flight_info(void* self, const CServerCallContext& context, + CFlightDescriptor c_descriptor, + unique_ptr[CFlightInfo]* info) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + FlightDescriptor py_descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + py_descriptor.descriptor = c_descriptor + try: + result = (<object> self).get_flight_info( + ServerCallContext.wrap(context), + py_descriptor) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + if not isinstance(result, FlightInfo): + raise TypeError("FlightServerBase.get_flight_info must return " + "a FlightInfo instance, but got {}".format( + type(result))) + info.reset(new CFlightInfo(deref((<FlightInfo> result).info.get()))) + return CStatus_OK() + +cdef CStatus _get_schema(void* self, const CServerCallContext& context, + CFlightDescriptor c_descriptor, + unique_ptr[CSchemaResult]* info) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + FlightDescriptor py_descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + py_descriptor.descriptor = c_descriptor + result = (<object> self).get_schema(ServerCallContext.wrap(context), + py_descriptor) + if not isinstance(result, SchemaResult): + raise TypeError("FlightServerBase.get_schema_info must return " + "a SchemaResult instance, but got {}".format( + type(result))) + info.reset(new CSchemaResult(deref((<SchemaResult> result).result.get()))) + return CStatus_OK() + +cdef CStatus _do_put(void* self, const CServerCallContext& context, + unique_ptr[CFlightMessageReader] reader, + unique_ptr[CFlightMetadataWriter] writer) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + MetadataRecordBatchReader py_reader = MetadataRecordBatchReader() + FlightMetadataWriter py_writer = FlightMetadataWriter() + FlightDescriptor descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + + descriptor.descriptor = reader.get().descriptor() + py_reader.reader.reset(reader.release()) + py_writer.writer.reset(writer.release()) + try: + (<object> self).do_put(ServerCallContext.wrap(context), descriptor, + py_reader, py_writer) + return CStatus_OK() + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + + +cdef CStatus _do_get(void* self, const CServerCallContext& context, + CTicket ticket, + unique_ptr[CFlightDataStream]* stream) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + unique_ptr[CFlightDataStream] data_stream + + py_ticket = Ticket(ticket.ticket) + try: + result = (<object> self).do_get(ServerCallContext.wrap(context), + py_ticket) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + if not isinstance(result, FlightDataStream): + raise TypeError("FlightServerBase.do_get must return " + "a FlightDataStream") + data_stream = unique_ptr[CFlightDataStream]( + (<FlightDataStream> result).to_stream()) + stream[0] = unique_ptr[CFlightDataStream]( + new CPyFlightDataStream(result, move(data_stream))) + return CStatus_OK() + + +cdef CStatus _do_exchange(void* self, const CServerCallContext& context, + unique_ptr[CFlightMessageReader] reader, + unique_ptr[CFlightMessageWriter] writer) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + MetadataRecordBatchReader py_reader = MetadataRecordBatchReader() + MetadataRecordBatchWriter py_writer = MetadataRecordBatchWriter() + FlightDescriptor descriptor = \ + FlightDescriptor.__new__(FlightDescriptor) + + descriptor.descriptor = reader.get().descriptor() + py_reader.reader.reset(reader.release()) + py_writer.writer.reset(writer.release()) + try: + (<object> self).do_exchange(ServerCallContext.wrap(context), + descriptor, py_reader, py_writer) + return CStatus_OK() + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + + +cdef CStatus _do_action_result_next( + void* self, + unique_ptr[CFlightResult]* result +) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + CFlightResult* c_result + + try: + action_result = next(<object> self) + if not isinstance(action_result, Result): + action_result = Result(action_result) + c_result = (<Result> action_result).result.get() + result.reset(new CFlightResult(deref(c_result))) + except StopIteration: + result.reset(nullptr) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _do_action(void* self, const CServerCallContext& context, + const CAction& action, + unique_ptr[CResultStream]* result) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + function[cb_result_next] ptr = &_do_action_result_next + py_action = Action(action.type, pyarrow_wrap_buffer(action.body)) + try: + responses = (<object> self).do_action(ServerCallContext.wrap(context), + py_action) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + # Let the application return an iterator or anything convertible + # into one + if responses is None: + # Server didn't return anything + responses = [] + result.reset(new CPyFlightResultStream(iter(responses), ptr)) + return CStatus_OK() + + +cdef CStatus _list_actions(void* self, const CServerCallContext& context, + vector[CActionType]* actions) except *: + """Callback for implementing Flight servers in Python.""" + cdef: + CActionType action_type + # Method should return a list of ActionTypes or similar tuple + try: + result = (<object> self).list_actions(ServerCallContext.wrap(context)) + for action in result: + if not isinstance(action, tuple): + raise TypeError( + "Results of list_actions must be ActionType or tuple") + action_type.type = tobytes(action[0]) + action_type.description = tobytes(action[1]) + actions.push_back(action_type) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _server_authenticate(void* self, CServerAuthSender* outgoing, + CServerAuthReader* incoming) except *: + """Callback for implementing authentication in Python.""" + sender = ServerAuthSender.wrap(outgoing) + reader = ServerAuthReader.wrap(incoming) + try: + (<object> self).authenticate(sender, reader) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + finally: + sender.poison() + reader.poison() + return CStatus_OK() + +cdef CStatus _is_valid(void* self, const c_string& token, + c_string* peer_identity) except *: + """Callback for implementing authentication in Python.""" + cdef c_string c_result + try: + c_result = tobytes((<object> self).is_valid(token)) + peer_identity[0] = c_result + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _client_authenticate(void* self, CClientAuthSender* outgoing, + CClientAuthReader* incoming) except *: + """Callback for implementing authentication in Python.""" + sender = ClientAuthSender.wrap(outgoing) + reader = ClientAuthReader.wrap(incoming) + try: + (<object> self).authenticate(sender, reader) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + finally: + sender.poison() + reader.poison() + return CStatus_OK() + + +cdef CStatus _get_token(void* self, c_string* token) except *: + """Callback for implementing authentication in Python.""" + cdef c_string c_result + try: + c_result = tobytes((<object> self).get_token()) + token[0] = c_result + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _middleware_sending_headers( + void* self, CAddCallHeaders* add_headers) except *: + """Callback for implementing middleware.""" + try: + headers = (<object> self).sending_headers() + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + + if headers: + for header, values in headers.items(): + if isinstance(values, (str, bytes)): + values = (values,) + # Headers in gRPC (and HTTP/1, HTTP/2) are required to be + # valid ASCII. + if isinstance(header, str): + header = header.encode("ascii") + for value in values: + if isinstance(value, str): + value = value.encode("ascii") + # Allow bytes values to pass through. + add_headers.AddHeader(header, value) + + return CStatus_OK() + + +cdef CStatus _middleware_call_completed( + void* self, + const CStatus& call_status) except *: + """Callback for implementing middleware.""" + try: + try: + check_flight_status(call_status) + except Exception as e: + (<object> self).call_completed(e) + else: + (<object> self).call_completed(None) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + return CStatus_OK() + + +cdef CStatus _middleware_received_headers( + void* self, + const CCallHeaders& c_headers) except *: + """Callback for implementing middleware.""" + try: + headers = convert_headers(c_headers) + (<object> self).received_headers(headers) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + return CStatus_OK() + + +cdef dict convert_headers(const CCallHeaders& c_headers): + cdef: + CCallHeaders.const_iterator header_iter = c_headers.cbegin() + headers = {} + while header_iter != c_headers.cend(): + header = c_string(deref(header_iter).first).decode("ascii") + value = c_string(deref(header_iter).second) + if not header.endswith("-bin"): + # Text header values in gRPC (and HTTP/1, HTTP/2) are + # required to be valid ASCII. Binary header values are + # exposed as bytes. + value = value.decode("ascii") + headers.setdefault(header, []).append(value) + postincrement(header_iter) + return headers + + +cdef CStatus _server_middleware_start_call( + void* self, + const CCallInfo& c_info, + const CCallHeaders& c_headers, + shared_ptr[CServerMiddleware]* c_instance) except *: + """Callback for implementing server middleware.""" + instance = None + try: + call_info = wrap_call_info(c_info) + headers = convert_headers(c_headers) + instance = (<object> self).start_call(call_info, headers) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + + if instance: + ServerMiddleware.wrap(instance, c_instance) + + return CStatus_OK() + + +cdef CStatus _client_middleware_start_call( + void* self, + const CCallInfo& c_info, + unique_ptr[CClientMiddleware]* c_instance) except *: + """Callback for implementing client middleware.""" + instance = None + try: + call_info = wrap_call_info(c_info) + instance = (<object> self).start_call(call_info) + except FlightError as flight_error: + return (<FlightError> flight_error).to_status() + + if instance: + ClientMiddleware.wrap(instance, c_instance) + + return CStatus_OK() + + +cdef class ServerAuthHandler(_Weakrefable): + """Authentication middleware for a server. + + To implement an authentication mechanism, subclass this class and + override its methods. + + """ + + def authenticate(self, outgoing, incoming): + """Conduct the handshake with the client. + + May raise an error if the client cannot authenticate. + + Parameters + ---------- + outgoing : ServerAuthSender + A channel to send messages to the client. + incoming : ServerAuthReader + A channel to read messages from the client. + """ + raise NotImplementedError + + def is_valid(self, token): + """Validate a client token, returning their identity. + + May return an empty string (if the auth mechanism does not + name the peer) or raise an exception (if the token is + invalid). + + Parameters + ---------- + token : bytes + The authentication token from the client. + + """ + raise NotImplementedError + + cdef PyServerAuthHandler* to_handler(self): + cdef PyServerAuthHandlerVtable vtable + vtable.authenticate = _server_authenticate + vtable.is_valid = _is_valid + return new PyServerAuthHandler(self, vtable) + + +cdef class ClientAuthHandler(_Weakrefable): + """Authentication plugin for a client.""" + + def authenticate(self, outgoing, incoming): + """Conduct the handshake with the server. + + Parameters + ---------- + outgoing : ClientAuthSender + A channel to send messages to the server. + incoming : ClientAuthReader + A channel to read messages from the server. + """ + raise NotImplementedError + + def get_token(self): + """Get the auth token for a call.""" + raise NotImplementedError + + cdef PyClientAuthHandler* to_handler(self): + cdef PyClientAuthHandlerVtable vtable + vtable.authenticate = _client_authenticate + vtable.get_token = _get_token + return new PyClientAuthHandler(self, vtable) + + +_CallInfo = collections.namedtuple("_CallInfo", ["method"]) + + +class CallInfo(_CallInfo): + """Information about a particular RPC for Flight middleware.""" + + +cdef wrap_call_info(const CCallInfo& c_info): + method = wrap_flight_method(c_info.method) + return CallInfo(method=method) + + +cdef class ClientMiddlewareFactory(_Weakrefable): + """A factory for new middleware instances. + + All middleware methods will be called from the same thread as the + RPC method implementation. That is, thread-locals set in the + client are accessible from the middleware itself. + + """ + + def start_call(self, info): + """Called at the start of an RPC. + + This must be thread-safe and must not raise exceptions. + + Parameters + ---------- + info : CallInfo + Information about the call. + + Returns + ------- + instance : ClientMiddleware + An instance of ClientMiddleware (the instance to use for + the call), or None if this call is not intercepted. + + """ + + +cdef class ClientMiddleware(_Weakrefable): + """Client-side middleware for a call, instantiated per RPC. + + Methods here should be fast and must be infallible: they should + not raise exceptions or stall indefinitely. + + """ + + def sending_headers(self): + """A callback before headers are sent. + + Returns + ------- + headers : dict + A dictionary of header values to add to the request, or + None if no headers are to be added. The dictionary should + have string keys and string or list-of-string values. + + Bytes values are allowed, but the underlying transport may + not support them or may restrict them. For gRPC, binary + values are only allowed on headers ending in "-bin". + + """ + + def received_headers(self, headers): + """A callback when headers are received. + + The default implementation does nothing. + + Parameters + ---------- + headers : dict + A dictionary of headers from the server. Keys are strings + and values are lists of strings (for text headers) or + bytes (for binary headers). + + """ + + def call_completed(self, exception): + """A callback when the call finishes. + + The default implementation does nothing. + + Parameters + ---------- + exception : ArrowException + If the call errored, this is the equivalent + exception. Will be None if the call succeeded. + + """ + + @staticmethod + cdef void wrap(object py_middleware, + unique_ptr[CClientMiddleware]* c_instance): + cdef PyClientMiddlewareVtable vtable + vtable.sending_headers = _middleware_sending_headers + vtable.received_headers = _middleware_received_headers + vtable.call_completed = _middleware_call_completed + c_instance[0].reset(new CPyClientMiddleware(py_middleware, vtable)) + + +cdef class ServerMiddlewareFactory(_Weakrefable): + """A factory for new middleware instances. + + All middleware methods will be called from the same thread as the + RPC method implementation. That is, thread-locals set in the + middleware are accessible from the method itself. + + """ + + def start_call(self, info, headers): + """Called at the start of an RPC. + + This must be thread-safe. + + Parameters + ---------- + info : CallInfo + Information about the call. + headers : dict + A dictionary of headers from the client. Keys are strings + and values are lists of strings (for text headers) or + bytes (for binary headers). + + Returns + ------- + instance : ServerMiddleware + An instance of ServerMiddleware (the instance to use for + the call), or None if this call is not intercepted. + + Raises + ------ + exception : pyarrow.ArrowException + If an exception is raised, the call will be rejected with + the given error. + + """ + + +cdef class ServerMiddleware(_Weakrefable): + """Server-side middleware for a call, instantiated per RPC. + + Methods here should be fast and must be infalliable: they should + not raise exceptions or stall indefinitely. + + """ + + def sending_headers(self): + """A callback before headers are sent. + + Returns + ------- + headers : dict + A dictionary of header values to add to the response, or + None if no headers are to be added. The dictionary should + have string keys and string or list-of-string values. + + Bytes values are allowed, but the underlying transport may + not support them or may restrict them. For gRPC, binary + values are only allowed on headers ending in "-bin". + + """ + + def call_completed(self, exception): + """A callback when the call finishes. + + Parameters + ---------- + exception : pyarrow.ArrowException + If the call errored, this is the equivalent + exception. Will be None if the call succeeded. + + """ + + @staticmethod + cdef void wrap(object py_middleware, + shared_ptr[CServerMiddleware]* c_instance): + cdef PyServerMiddlewareVtable vtable + vtable.sending_headers = _middleware_sending_headers + vtable.call_completed = _middleware_call_completed + c_instance[0].reset(new CPyServerMiddleware(py_middleware, vtable)) + + +cdef class _ServerMiddlewareFactoryWrapper(ServerMiddlewareFactory): + """Wrapper to bundle server middleware into a single C++ one.""" + + cdef: + dict factories + + def __init__(self, dict factories): + self.factories = factories + + def start_call(self, info, headers): + instances = {} + for key, factory in self.factories.items(): + instance = factory.start_call(info, headers) + if instance: + # TODO: prevent duplicate keys + instances[key] = instance + if instances: + wrapper = _ServerMiddlewareWrapper(instances) + return wrapper + return None + + +cdef class _ServerMiddlewareWrapper(ServerMiddleware): + cdef: + dict middleware + + def __init__(self, dict middleware): + self.middleware = middleware + + def sending_headers(self): + headers = collections.defaultdict(list) + for instance in self.middleware.values(): + more_headers = instance.sending_headers() + if not more_headers: + continue + # Manually merge with existing headers (since headers are + # multi-valued) + for key, values in more_headers.items(): + if isinstance(values, (bytes, str)): + values = (values,) + headers[key].extend(values) + return headers + + def call_completed(self, exception): + for instance in self.middleware.values(): + instance.call_completed(exception) + + +cdef class FlightServerBase(_Weakrefable): + """A Flight service definition. + + Override methods to define your Flight service. + + Parameters + ---------- + location : str, tuple or Location optional, default None + Location to serve on. Either a gRPC URI like `grpc://localhost:port`, + a tuple of (host, port) pair, or a Location instance. + If None is passed then the server will be started on localhost with a + system provided random port. + auth_handler : ServerAuthHandler optional, default None + An authentication mechanism to use. May be None. + tls_certificates : list optional, default None + A list of (certificate, key) pairs. + verify_client : boolean optional, default False + If True, then enable mutual TLS: require the client to present + a client certificate, and validate the certificate. + root_certificates : bytes optional, default None + If enabling mutual TLS, this specifies the PEM-encoded root + certificate used to validate client certificates. + middleware : list optional, default None + A dictionary of :class:`ServerMiddlewareFactory` items. The + keys are used to retrieve the middleware instance during calls + (see :meth:`ServerCallContext.get_middleware`). + + """ + + cdef: + unique_ptr[PyFlightServer] server + + def __init__(self, location=None, auth_handler=None, + tls_certificates=None, verify_client=None, + root_certificates=None, middleware=None): + if isinstance(location, (bytes, str)): + location = Location(location) + elif isinstance(location, (tuple, type(None))): + if location is None: + location = ('localhost', 0) + host, port = location + if tls_certificates: + location = Location.for_grpc_tls(host, port) + else: + location = Location.for_grpc_tcp(host, port) + elif not isinstance(location, Location): + raise TypeError('`location` argument must be a string, tuple or a ' + 'Location instance') + self.init(location, auth_handler, tls_certificates, verify_client, + tobytes(root_certificates or b""), middleware) + + cdef init(self, Location location, ServerAuthHandler auth_handler, + list tls_certificates, c_bool verify_client, + bytes root_certificates, dict middleware): + cdef: + PyFlightServerVtable vtable = PyFlightServerVtable() + PyFlightServer* c_server + unique_ptr[CFlightServerOptions] c_options + CCertKeyPair c_cert + function[cb_server_middleware_start_call] start_call = \ + &_server_middleware_start_call + pair[c_string, shared_ptr[CServerMiddlewareFactory]] c_middleware + + c_options.reset(new CFlightServerOptions(Location.unwrap(location))) + # mTLS configuration + c_options.get().verify_client = verify_client + c_options.get().root_certificates = root_certificates + + if auth_handler: + if not isinstance(auth_handler, ServerAuthHandler): + raise TypeError("auth_handler must be a ServerAuthHandler, " + "not a '{}'".format(type(auth_handler))) + c_options.get().auth_handler.reset( + (<ServerAuthHandler> auth_handler).to_handler()) + + if tls_certificates: + for cert, key in tls_certificates: + c_cert.pem_cert = tobytes(cert) + c_cert.pem_key = tobytes(key) + c_options.get().tls_certificates.push_back(c_cert) + + if middleware: + py_middleware = _ServerMiddlewareFactoryWrapper(middleware) + c_middleware.first = CPyServerMiddlewareName + c_middleware.second.reset(new CPyServerMiddlewareFactory( + py_middleware, + start_call)) + c_options.get().middleware.push_back(c_middleware) + + vtable.list_flights = &_list_flights + vtable.get_flight_info = &_get_flight_info + vtable.get_schema = &_get_schema + vtable.do_put = &_do_put + vtable.do_get = &_do_get + vtable.do_exchange = &_do_exchange + vtable.list_actions = &_list_actions + vtable.do_action = &_do_action + + c_server = new PyFlightServer(self, vtable) + self.server.reset(c_server) + with nogil: + check_flight_status(c_server.Init(deref(c_options))) + + @property + def port(self): + """ + Get the port that this server is listening on. + + Returns a non-positive value if the operation is invalid + (e.g. init() was not called or server is listening on a domain + socket). + """ + return self.server.get().port() + + def list_flights(self, context, criteria): + raise NotImplementedError + + def get_flight_info(self, context, descriptor): + raise NotImplementedError + + def get_schema(self, context, descriptor): + raise NotImplementedError + + def do_put(self, context, descriptor, reader, + writer: FlightMetadataWriter): + raise NotImplementedError + + def do_get(self, context, ticket): + raise NotImplementedError + + def do_exchange(self, context, descriptor, reader, writer): + raise NotImplementedError + + def list_actions(self, context): + raise NotImplementedError + + def do_action(self, context, action): + raise NotImplementedError + + def serve(self): + """Start serving. + + This method only returns if shutdown() is called or a signal a + received. + """ + if self.server.get() == nullptr: + raise ValueError("run() on uninitialized FlightServerBase") + with nogil: + check_flight_status(self.server.get().ServeWithSignals()) + + def run(self): + warnings.warn("The 'FlightServer.run' method is deprecated, use " + "FlightServer.serve method instead") + self.serve() + + def shutdown(self): + """Shut down the server, blocking until current requests finish. + + Do not call this directly from the implementation of a Flight + method, as then the server will block forever waiting for that + request to finish. Instead, call this method from a background + thread. + """ + # Must not hold the GIL: shutdown waits for pending RPCs to + # complete. Holding the GIL means Python-implemented Flight + # methods will never get to run, so this will hang + # indefinitely. + if self.server.get() == nullptr: + raise ValueError("shutdown() on uninitialized FlightServerBase") + with nogil: + check_flight_status(self.server.get().Shutdown()) + + def wait(self): + """Block until server is terminated with shutdown.""" + with nogil: + self.server.get().Wait() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() + self.wait() + + +def connect(location, **kwargs): + """ + Connect to the Flight server + Parameters + ---------- + location : str, tuple or Location + Location to connect to. Either a gRPC URI like `grpc://localhost:port`, + a tuple of (host, port) pair, or a Location instance. + tls_root_certs : bytes or None + PEM-encoded + cert_chain: str or None + If provided, enables TLS mutual authentication. + private_key: str or None + If provided, enables TLS mutual authentication. + override_hostname : str or None + Override the hostname checked by TLS. Insecure, use with caution. + middleware : list or None + A list of ClientMiddlewareFactory instances to apply. + write_size_limit_bytes : int or None + A soft limit on the size of a data payload sent to the + server. Enabled if positive. If enabled, writing a record + batch that (when serialized) exceeds this limit will raise an + exception; the client can retry the write with a smaller + batch. + disable_server_verification : boolean or None + Disable verifying the server when using TLS. + Insecure, use with caution. + generic_options : list or None + A list of generic (string, int or string) options to pass to + the underlying transport. + Returns + ------- + client : FlightClient + """ + return FlightClient(location, **kwargs) diff --git a/src/arrow/python/pyarrow/_fs.pxd b/src/arrow/python/pyarrow/_fs.pxd new file mode 100644 index 000000000..4504b78b8 --- /dev/null +++ b/src/arrow/python/pyarrow/_fs.pxd @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow_fs cimport * +from pyarrow.lib import _detect_compression, frombytes, tobytes +from pyarrow.lib cimport * + + +cpdef enum FileType: + NotFound = <int8_t> CFileType_NotFound + Unknown = <int8_t> CFileType_Unknown + File = <int8_t> CFileType_File + Directory = <int8_t> CFileType_Directory + + +cdef class FileInfo(_Weakrefable): + cdef: + CFileInfo info + + @staticmethod + cdef wrap(CFileInfo info) + + cdef inline CFileInfo unwrap(self) nogil + + @staticmethod + cdef CFileInfo unwrap_safe(obj) + + +cdef class FileSelector(_Weakrefable): + cdef: + CFileSelector selector + + @staticmethod + cdef FileSelector wrap(CFileSelector selector) + + cdef inline CFileSelector unwrap(self) nogil + + +cdef class FileSystem(_Weakrefable): + cdef: + shared_ptr[CFileSystem] wrapped + CFileSystem* fs + + cdef init(self, const shared_ptr[CFileSystem]& wrapped) + + @staticmethod + cdef wrap(const shared_ptr[CFileSystem]& sp) + + cdef inline shared_ptr[CFileSystem] unwrap(self) nogil + + +cdef class LocalFileSystem(FileSystem): + cdef: + CLocalFileSystem* localfs + + cdef init(self, const shared_ptr[CFileSystem]& wrapped) + + +cdef class SubTreeFileSystem(FileSystem): + cdef: + CSubTreeFileSystem* subtreefs + + cdef init(self, const shared_ptr[CFileSystem]& wrapped) + + +cdef class _MockFileSystem(FileSystem): + cdef: + CMockFileSystem* mockfs + + cdef init(self, const shared_ptr[CFileSystem]& wrapped) + + +cdef class PyFileSystem(FileSystem): + cdef: + CPyFileSystem* pyfs + + cdef init(self, const shared_ptr[CFileSystem]& wrapped) diff --git a/src/arrow/python/pyarrow/_fs.pyx b/src/arrow/python/pyarrow/_fs.pyx new file mode 100644 index 000000000..34a0ef55b --- /dev/null +++ b/src/arrow/python/pyarrow/_fs.pyx @@ -0,0 +1,1233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from cpython.datetime cimport datetime, PyDateTime_DateTime + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport PyDateTime_to_TimePoint +from pyarrow.lib import _detect_compression, frombytes, tobytes +from pyarrow.lib cimport * +from pyarrow.util import _stringify_path + +from abc import ABC, abstractmethod +from datetime import datetime, timezone +import os +import pathlib +import sys +import warnings + + +cdef _init_ca_paths(): + cdef CFileSystemGlobalOptions options + + import ssl + paths = ssl.get_default_verify_paths() + if paths.cafile: + options.tls_ca_file_path = os.fsencode(paths.cafile) + if paths.capath: + options.tls_ca_dir_path = os.fsencode(paths.capath) + check_status(CFileSystemsInitialize(options)) + + +if sys.platform == 'linux': + # ARROW-9261: On Linux, we may need to fixup the paths to TLS CA certs + # (especially in manylinux packages) since the values hardcoded at + # compile-time in libcurl may be wrong. + _init_ca_paths() + + +cdef inline c_string _path_as_bytes(path) except *: + # handle only abstract paths, not bound to any filesystem like pathlib is, + # so we only accept plain strings + if not isinstance(path, (bytes, str)): + raise TypeError('Path must be a string') + # tobytes always uses utf-8, which is more or less ok, at least on Windows + # since the C++ side then decodes from utf-8. On Unix, os.fsencode may be + # better. + return tobytes(path) + + +cdef object _wrap_file_type(CFileType ty): + return FileType(<int8_t> ty) + + +cdef CFileType _unwrap_file_type(FileType ty) except *: + if ty == FileType.Unknown: + return CFileType_Unknown + elif ty == FileType.NotFound: + return CFileType_NotFound + elif ty == FileType.File: + return CFileType_File + elif ty == FileType.Directory: + return CFileType_Directory + assert 0 + + +cdef class FileInfo(_Weakrefable): + """ + FileSystem entry info. + + Parameters + ---------- + path : str + The full path to the filesystem entry. + type : FileType + The type of the filesystem entry. + mtime : datetime or float, default None + If given, the modification time of the filesystem entry. + If a float is given, it is the number of seconds since the + Unix epoch. + mtime_ns : int, default None + If given, the modification time of the filesystem entry, + in nanoseconds since the Unix epoch. + `mtime` and `mtime_ns` are mutually exclusive. + size : int, default None + If given, the filesystem entry size in bytes. This should only + be given if `type` is `FileType.File`. + + """ + + def __init__(self, path, FileType type=FileType.Unknown, *, + mtime=None, mtime_ns=None, size=None): + self.info.set_path(tobytes(path)) + self.info.set_type(_unwrap_file_type(type)) + if mtime is not None: + if mtime_ns is not None: + raise TypeError("Only one of mtime and mtime_ns " + "can be given") + if isinstance(mtime, datetime): + self.info.set_mtime(PyDateTime_to_TimePoint( + <PyDateTime_DateTime*> mtime)) + else: + self.info.set_mtime(TimePoint_from_s(mtime)) + elif mtime_ns is not None: + self.info.set_mtime(TimePoint_from_ns(mtime_ns)) + if size is not None: + self.info.set_size(size) + + @staticmethod + cdef wrap(CFileInfo info): + cdef FileInfo self = FileInfo.__new__(FileInfo) + self.info = move(info) + return self + + cdef inline CFileInfo unwrap(self) nogil: + return self.info + + @staticmethod + cdef CFileInfo unwrap_safe(obj): + if not isinstance(obj, FileInfo): + raise TypeError("Expected FileInfo instance, got {0}" + .format(type(obj))) + return (<FileInfo> obj).unwrap() + + def __repr__(self): + def getvalue(attr): + try: + return getattr(self, attr) + except ValueError: + return '' + + s = '<FileInfo for {!r}: type={}'.format(self.path, str(self.type)) + if self.is_file: + s += ', size={}'.format(self.size) + s += '>' + return s + + @property + def type(self): + """ + Type of the file. + + The returned enum values can be the following: + + - FileType.NotFound: target does not exist + - FileType.Unknown: target exists but its type is unknown (could be a + special file such as a Unix socket or character device, or + Windows NUL / CON / ...) + - FileType.File: target is a regular file + - FileType.Directory: target is a regular directory + + Returns + ------- + type : FileType + """ + return _wrap_file_type(self.info.type()) + + @property + def is_file(self): + """ + """ + return self.type == FileType.File + + @property + def path(self): + """ + The full file path in the filesystem. + """ + return frombytes(self.info.path()) + + @property + def base_name(self): + """ + The file base name. + + Component after the last directory separator. + """ + return frombytes(self.info.base_name()) + + @property + def size(self): + """ + The size in bytes, if available. + + Only regular files are guaranteed to have a size. + + Returns + ------- + size : int or None + """ + cdef int64_t size + size = self.info.size() + return (size if size != -1 else None) + + @property + def extension(self): + """ + The file extension. + """ + return frombytes(self.info.extension()) + + @property + def mtime(self): + """ + The time of last modification, if available. + + Returns + ------- + mtime : datetime.datetime or None + """ + cdef int64_t nanoseconds + nanoseconds = TimePoint_to_ns(self.info.mtime()) + return (datetime.fromtimestamp(nanoseconds / 1.0e9, timezone.utc) + if nanoseconds != -1 else None) + + @property + def mtime_ns(self): + """ + The time of last modification, if available, expressed in nanoseconds + since the Unix epoch. + + Returns + ------- + mtime_ns : int or None + """ + cdef int64_t nanoseconds + nanoseconds = TimePoint_to_ns(self.info.mtime()) + return (nanoseconds if nanoseconds != -1 else None) + + +cdef class FileSelector(_Weakrefable): + """ + File and directory selector. + + It contains a set of options that describes how to search for files and + directories. + + Parameters + ---------- + base_dir : str + The directory in which to select files. Relative paths also work, use + '.' for the current directory and '..' for the parent. + allow_not_found : bool, default False + The behavior if `base_dir` doesn't exist in the filesystem. + If false, an error is returned. + If true, an empty selection is returned. + recursive : bool, default False + Whether to recurse into subdirectories. + """ + + def __init__(self, base_dir, bint allow_not_found=False, + bint recursive=False): + self.base_dir = base_dir + self.recursive = recursive + self.allow_not_found = allow_not_found + + @staticmethod + cdef FileSelector wrap(CFileSelector wrapped): + cdef FileSelector self = FileSelector.__new__(FileSelector) + self.selector = move(wrapped) + return self + + cdef inline CFileSelector unwrap(self) nogil: + return self.selector + + @property + def base_dir(self): + return frombytes(self.selector.base_dir) + + @base_dir.setter + def base_dir(self, base_dir): + self.selector.base_dir = _path_as_bytes(base_dir) + + @property + def allow_not_found(self): + return self.selector.allow_not_found + + @allow_not_found.setter + def allow_not_found(self, bint allow_not_found): + self.selector.allow_not_found = allow_not_found + + @property + def recursive(self): + return self.selector.recursive + + @recursive.setter + def recursive(self, bint recursive): + self.selector.recursive = recursive + + def __repr__(self): + return ("<FileSelector base_dir={0.base_dir!r} " + "recursive={0.recursive}>".format(self)) + + +cdef class FileSystem(_Weakrefable): + """ + Abstract file system API. + """ + + def __init__(self): + raise TypeError("FileSystem is an abstract class, instantiate one of " + "the subclasses instead: LocalFileSystem or " + "SubTreeFileSystem") + + @staticmethod + def from_uri(uri): + """ + Create a new FileSystem from URI or Path. + + Recognized URI schemes are "file", "mock", "s3fs", "hdfs" and "viewfs". + In addition, the argument can be a pathlib.Path object, or a string + describing an absolute local path. + + Parameters + ---------- + uri : string + URI-based path, for example: file:///some/local/path. + + Returns + ------- + With (filesystem, path) tuple where path is the abstract path inside + the FileSystem instance. + """ + cdef: + c_string path + CResult[shared_ptr[CFileSystem]] result + + if isinstance(uri, pathlib.Path): + # Make absolute + uri = uri.resolve().absolute() + uri = _stringify_path(uri) + result = CFileSystemFromUriOrPath(tobytes(uri), &path) + return FileSystem.wrap(GetResultValue(result)), frombytes(path) + + cdef init(self, const shared_ptr[CFileSystem]& wrapped): + self.wrapped = wrapped + self.fs = wrapped.get() + + @staticmethod + cdef wrap(const shared_ptr[CFileSystem]& sp): + cdef FileSystem self + + typ = frombytes(sp.get().type_name()) + if typ == 'local': + self = LocalFileSystem.__new__(LocalFileSystem) + elif typ == 'mock': + self = _MockFileSystem.__new__(_MockFileSystem) + elif typ == 'subtree': + self = SubTreeFileSystem.__new__(SubTreeFileSystem) + elif typ == 's3': + from pyarrow._s3fs import S3FileSystem + self = S3FileSystem.__new__(S3FileSystem) + elif typ == 'hdfs': + from pyarrow._hdfs import HadoopFileSystem + self = HadoopFileSystem.__new__(HadoopFileSystem) + elif typ.startswith('py::'): + self = PyFileSystem.__new__(PyFileSystem) + else: + raise TypeError('Cannot wrap FileSystem pointer') + + self.init(sp) + return self + + cdef inline shared_ptr[CFileSystem] unwrap(self) nogil: + return self.wrapped + + def equals(self, FileSystem other): + return self.fs.Equals(other.unwrap()) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + @property + def type_name(self): + """ + The filesystem's type name. + """ + return frombytes(self.fs.type_name()) + + def get_file_info(self, paths_or_selector): + """ + Get info for the given files. + + Any symlink is automatically dereferenced, recursively. A non-existing + or unreachable file returns a FileStat object and has a FileType of + value NotFound. An exception indicates a truly exceptional condition + (low-level I/O error, etc.). + + Parameters + ---------- + paths_or_selector: FileSelector, path-like or list of path-likes + Either a selector object, a path-like object or a list of + path-like objects. The selector's base directory will not be + part of the results, even if it exists. If it doesn't exist, + use `allow_not_found`. + + Returns + ------- + FileInfo or list of FileInfo + Single FileInfo object is returned for a single path, otherwise + a list of FileInfo objects is returned. + """ + cdef: + CFileInfo info + c_string path + vector[CFileInfo] infos + vector[c_string] paths + CFileSelector selector + + if isinstance(paths_or_selector, FileSelector): + with nogil: + selector = (<FileSelector>paths_or_selector).selector + infos = GetResultValue(self.fs.GetFileInfo(selector)) + elif isinstance(paths_or_selector, (list, tuple)): + paths = [_path_as_bytes(s) for s in paths_or_selector] + with nogil: + infos = GetResultValue(self.fs.GetFileInfo(paths)) + elif isinstance(paths_or_selector, (bytes, str)): + path =_path_as_bytes(paths_or_selector) + with nogil: + info = GetResultValue(self.fs.GetFileInfo(path)) + return FileInfo.wrap(info) + else: + raise TypeError('Must pass either path(s) or a FileSelector') + + return [FileInfo.wrap(info) for info in infos] + + def create_dir(self, path, *, bint recursive=True): + """ + Create a directory and subdirectories. + + This function succeeds if the directory already exists. + + Parameters + ---------- + path : str + The path of the new directory. + recursive: bool, default True + Create nested directories as well. + """ + cdef c_string directory = _path_as_bytes(path) + with nogil: + check_status(self.fs.CreateDir(directory, recursive=recursive)) + + def delete_dir(self, path): + """Delete a directory and its contents, recursively. + + Parameters + ---------- + path : str + The path of the directory to be deleted. + """ + cdef c_string directory = _path_as_bytes(path) + with nogil: + check_status(self.fs.DeleteDir(directory)) + + def delete_dir_contents(self, path, *, bint accept_root_dir=False): + """Delete a directory's contents, recursively. + + Like delete_dir, but doesn't delete the directory itself. + + Parameters + ---------- + path : str + The path of the directory to be deleted. + accept_root_dir : boolean, default False + Allow deleting the root directory's contents + (if path is empty or "/") + """ + cdef c_string directory = _path_as_bytes(path) + if accept_root_dir and directory.strip(b"/") == b"": + with nogil: + check_status(self.fs.DeleteRootDirContents()) + else: + with nogil: + check_status(self.fs.DeleteDirContents(directory)) + + def move(self, src, dest): + """ + Move / rename a file or directory. + + If the destination exists: + - if it is a non-empty directory, an error is returned + - otherwise, if it has the same type as the source, it is replaced + - otherwise, behavior is unspecified (implementation-dependent). + + Parameters + ---------- + src : str + The path of the file or the directory to be moved. + dest : str + The destination path where the file or directory is moved to. + """ + cdef: + c_string source = _path_as_bytes(src) + c_string destination = _path_as_bytes(dest) + with nogil: + check_status(self.fs.Move(source, destination)) + + def copy_file(self, src, dest): + """ + Copy a file. + + If the destination exists and is a directory, an error is returned. + Otherwise, it is replaced. + + Parameters + ---------- + src : str + The path of the file to be copied from. + dest : str + The destination path where the file is copied to. + """ + cdef: + c_string source = _path_as_bytes(src) + c_string destination = _path_as_bytes(dest) + with nogil: + check_status(self.fs.CopyFile(source, destination)) + + def delete_file(self, path): + """ + Delete a file. + + Parameters + ---------- + path : str + The path of the file to be deleted. + """ + cdef c_string file = _path_as_bytes(path) + with nogil: + check_status(self.fs.DeleteFile(file)) + + def _wrap_input_stream(self, stream, path, compression, buffer_size): + if buffer_size is not None and buffer_size != 0: + stream = BufferedInputStream(stream, buffer_size) + if compression == 'detect': + compression = _detect_compression(path) + if compression is not None: + stream = CompressedInputStream(stream, compression) + return stream + + def _wrap_output_stream(self, stream, path, compression, buffer_size): + if buffer_size is not None and buffer_size != 0: + stream = BufferedOutputStream(stream, buffer_size) + if compression == 'detect': + compression = _detect_compression(path) + if compression is not None: + stream = CompressedOutputStream(stream, compression) + return stream + + def open_input_file(self, path): + """ + Open an input file for random access reading. + + Parameters + ---------- + path : str + The source to open for reading. + + Returns + ------- + stram : NativeFile + """ + cdef: + c_string pathstr = _path_as_bytes(path) + NativeFile stream = NativeFile() + shared_ptr[CRandomAccessFile] in_handle + + with nogil: + in_handle = GetResultValue(self.fs.OpenInputFile(pathstr)) + + stream.set_random_access_file(in_handle) + stream.is_readable = True + return stream + + def open_input_stream(self, path, compression='detect', buffer_size=None): + """ + Open an input stream for sequential reading. + + Parameters + ---------- + source : str + The source to open for reading. + compression : str optional, default 'detect' + The compression algorithm to use for on-the-fly decompression. + If "detect" and source is a file path, then compression will be + chosen based on the file extension. + If None, no compression will be applied. Otherwise, a well-known + algorithm name must be supplied (e.g. "gzip"). + buffer_size : int optional, default None + If None or 0, no buffering will happen. Otherwise the size of the + temporary read buffer. + + Returns + ------- + stream : NativeFile + """ + cdef: + c_string pathstr = _path_as_bytes(path) + NativeFile stream = NativeFile() + shared_ptr[CInputStream] in_handle + + with nogil: + in_handle = GetResultValue(self.fs.OpenInputStream(pathstr)) + + stream.set_input_stream(in_handle) + stream.is_readable = True + + return self._wrap_input_stream( + stream, path=path, compression=compression, buffer_size=buffer_size + ) + + def open_output_stream(self, path, compression='detect', + buffer_size=None, metadata=None): + """ + Open an output stream for sequential writing. + + If the target already exists, existing data is truncated. + + Parameters + ---------- + path : str + The source to open for writing. + compression : str optional, default 'detect' + The compression algorithm to use for on-the-fly compression. + If "detect" and source is a file path, then compression will be + chosen based on the file extension. + If None, no compression will be applied. Otherwise, a well-known + algorithm name must be supplied (e.g. "gzip"). + buffer_size : int optional, default None + If None or 0, no buffering will happen. Otherwise the size of the + temporary write buffer. + metadata : dict optional, default None + If not None, a mapping of string keys to string values. + Some filesystems support storing metadata along the file + (such as "Content-Type"). + Unsupported metadata keys will be ignored. + + Returns + ------- + stream : NativeFile + """ + cdef: + c_string pathstr = _path_as_bytes(path) + NativeFile stream = NativeFile() + shared_ptr[COutputStream] out_handle + shared_ptr[const CKeyValueMetadata] c_metadata + + if metadata is not None: + c_metadata = pyarrow_unwrap_metadata(KeyValueMetadata(metadata)) + + with nogil: + out_handle = GetResultValue( + self.fs.OpenOutputStream(pathstr, c_metadata)) + + stream.set_output_stream(out_handle) + stream.is_writable = True + + return self._wrap_output_stream( + stream, path=path, compression=compression, buffer_size=buffer_size + ) + + def open_append_stream(self, path, compression='detect', + buffer_size=None, metadata=None): + """ + DEPRECATED: Open an output stream for appending. + + If the target doesn't exist, a new empty file is created. + + .. deprecated:: 6.0 + Several filesystems don't support this functionality + and it will be later removed. + + Parameters + ---------- + path : str + The source to open for writing. + compression : str optional, default 'detect' + The compression algorithm to use for on-the-fly compression. + If "detect" and source is a file path, then compression will be + chosen based on the file extension. + If None, no compression will be applied. Otherwise, a well-known + algorithm name must be supplied (e.g. "gzip"). + buffer_size : int optional, default None + If None or 0, no buffering will happen. Otherwise the size of the + temporary write buffer. + metadata : dict optional, default None + If not None, a mapping of string keys to string values. + Some filesystems support storing metadata along the file + (such as "Content-Type"). + Unsupported metadata keys will be ignored. + + Returns + ------- + stream : NativeFile + """ + cdef: + c_string pathstr = _path_as_bytes(path) + NativeFile stream = NativeFile() + shared_ptr[COutputStream] out_handle + shared_ptr[const CKeyValueMetadata] c_metadata + + warnings.warn( + "`open_append_stream` is deprecated as of 6.0.0; several " + "filesystems don't support it and it will be later removed", + FutureWarning) + + if metadata is not None: + c_metadata = pyarrow_unwrap_metadata(KeyValueMetadata(metadata)) + + with nogil: + out_handle = GetResultValue( + self.fs.OpenAppendStream(pathstr, c_metadata)) + + stream.set_output_stream(out_handle) + stream.is_writable = True + + return self._wrap_output_stream( + stream, path=path, compression=compression, buffer_size=buffer_size + ) + + def normalize_path(self, path): + """ + Normalize filesystem path. + + Parameters + ---------- + path : str + The path to normalize + + Returns + ------- + normalized_path : str + The normalized path + """ + cdef: + c_string c_path = _path_as_bytes(path) + c_string c_path_normalized + + c_path_normalized = GetResultValue(self.fs.NormalizePath(c_path)) + return frombytes(c_path_normalized) + + +cdef class LocalFileSystem(FileSystem): + """ + A FileSystem implementation accessing files on the local machine. + + Details such as symlinks are abstracted away (symlinks are always followed, + except when deleting an entry). + + Parameters + ---------- + use_mmap : bool, default False + Whether open_input_stream and open_input_file should return + a mmap'ed file or a regular file. + """ + + def __init__(self, *, use_mmap=False): + cdef: + CLocalFileSystemOptions opts + shared_ptr[CLocalFileSystem] fs + + opts = CLocalFileSystemOptions.Defaults() + opts.use_mmap = use_mmap + + fs = make_shared[CLocalFileSystem](opts) + self.init(<shared_ptr[CFileSystem]> fs) + + cdef init(self, const shared_ptr[CFileSystem]& c_fs): + FileSystem.init(self, c_fs) + self.localfs = <CLocalFileSystem*> c_fs.get() + + @classmethod + def _reconstruct(cls, kwargs): + # __reduce__ doesn't allow passing named arguments directly to the + # reconstructor, hence this wrapper. + return cls(**kwargs) + + def __reduce__(self): + cdef CLocalFileSystemOptions opts = self.localfs.options() + return LocalFileSystem._reconstruct, (dict( + use_mmap=opts.use_mmap),) + + +cdef class SubTreeFileSystem(FileSystem): + """ + Delegates to another implementation after prepending a fixed base path. + + This is useful to expose a logical view of a subtree of a filesystem, + for example a directory in a LocalFileSystem. + + Note, that this makes no security guarantee. For example, symlinks may + allow to "escape" the subtree and access other parts of the underlying + filesystem. + + Parameters + ---------- + base_path : str + The root of the subtree. + base_fs : FileSystem + FileSystem object the operations delegated to. + """ + + def __init__(self, base_path, FileSystem base_fs): + cdef: + c_string pathstr + shared_ptr[CSubTreeFileSystem] wrapped + + pathstr = _path_as_bytes(base_path) + wrapped = make_shared[CSubTreeFileSystem](pathstr, base_fs.wrapped) + + self.init(<shared_ptr[CFileSystem]> wrapped) + + cdef init(self, const shared_ptr[CFileSystem]& wrapped): + FileSystem.init(self, wrapped) + self.subtreefs = <CSubTreeFileSystem*> wrapped.get() + + def __repr__(self): + return ("SubTreeFileSystem(base_path={}, base_fs={}" + .format(self.base_path, self.base_fs)) + + def __reduce__(self): + return SubTreeFileSystem, ( + frombytes(self.subtreefs.base_path()), + FileSystem.wrap(self.subtreefs.base_fs()) + ) + + @property + def base_path(self): + return frombytes(self.subtreefs.base_path()) + + @property + def base_fs(self): + return FileSystem.wrap(self.subtreefs.base_fs()) + + +cdef class _MockFileSystem(FileSystem): + + def __init__(self, datetime current_time=None): + cdef shared_ptr[CMockFileSystem] wrapped + + current_time = current_time or datetime.now() + wrapped = make_shared[CMockFileSystem]( + PyDateTime_to_TimePoint(<PyDateTime_DateTime*> current_time) + ) + + self.init(<shared_ptr[CFileSystem]> wrapped) + + cdef init(self, const shared_ptr[CFileSystem]& wrapped): + FileSystem.init(self, wrapped) + self.mockfs = <CMockFileSystem*> wrapped.get() + + +cdef class PyFileSystem(FileSystem): + """ + A FileSystem with behavior implemented in Python. + + Parameters + ---------- + handler : FileSystemHandler + The handler object implementing custom filesystem behavior. + """ + + def __init__(self, handler): + cdef: + CPyFileSystemVtable vtable + shared_ptr[CPyFileSystem] wrapped + + if not isinstance(handler, FileSystemHandler): + raise TypeError("Expected a FileSystemHandler instance, got {0}" + .format(type(handler))) + + vtable.get_type_name = _cb_get_type_name + vtable.equals = _cb_equals + vtable.get_file_info = _cb_get_file_info + vtable.get_file_info_vector = _cb_get_file_info_vector + vtable.get_file_info_selector = _cb_get_file_info_selector + vtable.create_dir = _cb_create_dir + vtable.delete_dir = _cb_delete_dir + vtable.delete_dir_contents = _cb_delete_dir_contents + vtable.delete_root_dir_contents = _cb_delete_root_dir_contents + vtable.delete_file = _cb_delete_file + vtable.move = _cb_move + vtable.copy_file = _cb_copy_file + vtable.open_input_stream = _cb_open_input_stream + vtable.open_input_file = _cb_open_input_file + vtable.open_output_stream = _cb_open_output_stream + vtable.open_append_stream = _cb_open_append_stream + vtable.normalize_path = _cb_normalize_path + + wrapped = CPyFileSystem.Make(handler, move(vtable)) + self.init(<shared_ptr[CFileSystem]> wrapped) + + cdef init(self, const shared_ptr[CFileSystem]& wrapped): + FileSystem.init(self, wrapped) + self.pyfs = <CPyFileSystem*> wrapped.get() + + @property + def handler(self): + """ + The filesystem's underlying handler. + + Returns + ------- + handler : FileSystemHandler + """ + return <object> self.pyfs.handler() + + def __reduce__(self): + return PyFileSystem, (self.handler,) + + +class FileSystemHandler(ABC): + """ + An abstract class exposing methods to implement PyFileSystem's behavior. + """ + + @abstractmethod + def get_type_name(self): + """ + Implement PyFileSystem.type_name. + """ + + @abstractmethod + def get_file_info(self, paths): + """ + Implement PyFileSystem.get_file_info(paths). + + Parameters + ---------- + paths : paths for which we want to retrieve the info. + """ + + @abstractmethod + def get_file_info_selector(self, selector): + """ + Implement PyFileSystem.get_file_info(selector). + + Parameters + ---------- + selector : selector for which we want to retrieve the info. + """ + + @abstractmethod + def create_dir(self, path, recursive): + """ + Implement PyFileSystem.create_dir(...). + + Parameters + ---------- + path : path of the directory. + recursive : if the parent directories should be created too. + """ + + @abstractmethod + def delete_dir(self, path): + """ + Implement PyFileSystem.delete_dir(...). + + Parameters + ---------- + path : path of the directory. + """ + + @abstractmethod + def delete_dir_contents(self, path): + """ + Implement PyFileSystem.delete_dir_contents(...). + + Parameters + ---------- + path : path of the directory. + """ + + @abstractmethod + def delete_root_dir_contents(self): + """ + Implement PyFileSystem.delete_dir_contents("/", accept_root_dir=True). + """ + + @abstractmethod + def delete_file(self, path): + """ + Implement PyFileSystem.delete_file(...). + + Parameters + ---------- + path : path of the file. + """ + + @abstractmethod + def move(self, src, dest): + """ + Implement PyFileSystem.move(...). + + Parameters + ---------- + src : path of what should be moved. + dest : path of where it should be moved to. + """ + + @abstractmethod + def copy_file(self, src, dest): + """ + Implement PyFileSystem.copy_file(...). + + Parameters + ---------- + src : path of what should be copied. + dest : path of where it should be copied to. + """ + + @abstractmethod + def open_input_stream(self, path): + """ + Implement PyFileSystem.open_input_stream(...). + + Parameters + ---------- + path : path of what should be opened. + """ + + @abstractmethod + def open_input_file(self, path): + """ + Implement PyFileSystem.open_input_file(...). + + Parameters + ---------- + path : path of what should be opened. + """ + + @abstractmethod + def open_output_stream(self, path, metadata): + """ + Implement PyFileSystem.open_output_stream(...). + + Parameters + ---------- + path : path of what should be opened. + metadata : mapping of string keys to string values. + Some filesystems support storing metadata along the file + (such as "Content-Type"). + """ + + @abstractmethod + def open_append_stream(self, path, metadata): + """ + Implement PyFileSystem.open_append_stream(...). + + Parameters + ---------- + path : path of what should be opened. + metadata : mapping of string keys to string values. + Some filesystems support storing metadata along the file + (such as "Content-Type"). + """ + + @abstractmethod + def normalize_path(self, path): + """ + Implement PyFileSystem.normalize_path(...). + + Parameters + ---------- + path : path of what should be normalized. + """ + +# Callback definitions for CPyFileSystemVtable + + +cdef void _cb_get_type_name(handler, c_string* out) except *: + out[0] = tobytes("py::" + handler.get_type_name()) + +cdef c_bool _cb_equals(handler, const CFileSystem& c_other) except False: + if c_other.type_name().startswith(b"py::"): + return <object> (<const CPyFileSystem&> c_other).handler() == handler + + return False + +cdef void _cb_get_file_info(handler, const c_string& path, + CFileInfo* out) except *: + infos = handler.get_file_info([frombytes(path)]) + if not isinstance(infos, list) or len(infos) != 1: + raise TypeError("get_file_info should have returned a 1-element list") + out[0] = FileInfo.unwrap_safe(infos[0]) + +cdef void _cb_get_file_info_vector(handler, const vector[c_string]& paths, + vector[CFileInfo]* out) except *: + py_paths = [frombytes(paths[i]) for i in range(len(paths))] + infos = handler.get_file_info(py_paths) + if not isinstance(infos, list): + raise TypeError("get_file_info should have returned a list") + out[0].clear() + out[0].reserve(len(infos)) + for info in infos: + out[0].push_back(FileInfo.unwrap_safe(info)) + +cdef void _cb_get_file_info_selector(handler, const CFileSelector& selector, + vector[CFileInfo]* out) except *: + infos = handler.get_file_info_selector(FileSelector.wrap(selector)) + if not isinstance(infos, list): + raise TypeError("get_file_info_selector should have returned a list") + out[0].clear() + out[0].reserve(len(infos)) + for info in infos: + out[0].push_back(FileInfo.unwrap_safe(info)) + +cdef void _cb_create_dir(handler, const c_string& path, + c_bool recursive) except *: + handler.create_dir(frombytes(path), recursive) + +cdef void _cb_delete_dir(handler, const c_string& path) except *: + handler.delete_dir(frombytes(path)) + +cdef void _cb_delete_dir_contents(handler, const c_string& path) except *: + handler.delete_dir_contents(frombytes(path)) + +cdef void _cb_delete_root_dir_contents(handler) except *: + handler.delete_root_dir_contents() + +cdef void _cb_delete_file(handler, const c_string& path) except *: + handler.delete_file(frombytes(path)) + +cdef void _cb_move(handler, const c_string& src, + const c_string& dest) except *: + handler.move(frombytes(src), frombytes(dest)) + +cdef void _cb_copy_file(handler, const c_string& src, + const c_string& dest) except *: + handler.copy_file(frombytes(src), frombytes(dest)) + +cdef void _cb_open_input_stream(handler, const c_string& path, + shared_ptr[CInputStream]* out) except *: + stream = handler.open_input_stream(frombytes(path)) + if not isinstance(stream, NativeFile): + raise TypeError("open_input_stream should have returned " + "a PyArrow file") + out[0] = (<NativeFile> stream).get_input_stream() + +cdef void _cb_open_input_file(handler, const c_string& path, + shared_ptr[CRandomAccessFile]* out) except *: + stream = handler.open_input_file(frombytes(path)) + if not isinstance(stream, NativeFile): + raise TypeError("open_input_file should have returned " + "a PyArrow file") + out[0] = (<NativeFile> stream).get_random_access_file() + +cdef void _cb_open_output_stream( + handler, const c_string& path, + const shared_ptr[const CKeyValueMetadata]& metadata, + shared_ptr[COutputStream]* out) except *: + stream = handler.open_output_stream( + frombytes(path), pyarrow_wrap_metadata(metadata)) + if not isinstance(stream, NativeFile): + raise TypeError("open_output_stream should have returned " + "a PyArrow file") + out[0] = (<NativeFile> stream).get_output_stream() + +cdef void _cb_open_append_stream( + handler, const c_string& path, + const shared_ptr[const CKeyValueMetadata]& metadata, + shared_ptr[COutputStream]* out) except *: + stream = handler.open_append_stream( + frombytes(path), pyarrow_wrap_metadata(metadata)) + if not isinstance(stream, NativeFile): + raise TypeError("open_append_stream should have returned " + "a PyArrow file") + out[0] = (<NativeFile> stream).get_output_stream() + +cdef void _cb_normalize_path(handler, const c_string& path, + c_string* out) except *: + out[0] = tobytes(handler.normalize_path(frombytes(path))) + + +def _copy_files(FileSystem source_fs, str source_path, + FileSystem destination_fs, str destination_path, + int64_t chunk_size, c_bool use_threads): + # low-level helper exposed through pyarrow/fs.py::copy_files + cdef: + CFileLocator c_source + vector[CFileLocator] c_sources + CFileLocator c_destination + vector[CFileLocator] c_destinations + FileSystem fs + CStatus c_status + shared_ptr[CFileSystem] c_fs + + c_source.filesystem = source_fs.unwrap() + c_source.path = tobytes(source_path) + c_sources.push_back(c_source) + + c_destination.filesystem = destination_fs.unwrap() + c_destination.path = tobytes(destination_path) + c_destinations.push_back(c_destination) + + with nogil: + check_status(CCopyFiles( + c_sources, c_destinations, + c_default_io_context(), chunk_size, use_threads, + )) + + +def _copy_files_selector(FileSystem source_fs, FileSelector source_sel, + FileSystem destination_fs, str destination_base_dir, + int64_t chunk_size, c_bool use_threads): + # low-level helper exposed through pyarrow/fs.py::copy_files + cdef c_string c_destination_base_dir = tobytes(destination_base_dir) + + with nogil: + check_status(CCopyFilesWithSelector( + source_fs.unwrap(), source_sel.unwrap(), + destination_fs.unwrap(), c_destination_base_dir, + c_default_io_context(), chunk_size, use_threads, + )) diff --git a/src/arrow/python/pyarrow/_hdfs.pyx b/src/arrow/python/pyarrow/_hdfs.pyx new file mode 100644 index 000000000..7a3b974be --- /dev/null +++ b/src/arrow/python/pyarrow/_hdfs.pyx @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.lib cimport check_status +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_fs cimport * +from pyarrow._fs cimport FileSystem + +from pyarrow.lib import frombytes, tobytes +from pyarrow.util import _stringify_path + + +cdef class HadoopFileSystem(FileSystem): + """ + HDFS backed FileSystem implementation + + Parameters + ---------- + host : str + HDFS host to connect to. Set to "default" for fs.defaultFS from + core-site.xml. + port : int, default 8020 + HDFS port to connect to. Set to 0 for default or logical (HA) nodes. + user : str, default None + Username when connecting to HDFS; None implies login user. + replication : int, default 3 + Number of copies each block will have. + buffer_size : int, default 0 + If 0, no buffering will happen otherwise the size of the temporary read + and write buffer. + default_block_size : int, default None + None means the default configuration for HDFS, a typical block size is + 128 MB. + kerb_ticket : string or path, default None + If not None, the path to the Kerberos ticket cache. + extra_conf : dict, default None + Extra key/value pairs for configuration; will override any + hdfs-site.xml properties. + """ + + cdef: + CHadoopFileSystem* hdfs + + def __init__(self, str host, int port=8020, *, str user=None, + int replication=3, int buffer_size=0, + default_block_size=None, kerb_ticket=None, + extra_conf=None): + cdef: + CHdfsOptions options + shared_ptr[CHadoopFileSystem] wrapped + + if not host.startswith(('hdfs://', 'viewfs://')) and host != "default": + # TODO(kszucs): do more sanitization + host = 'hdfs://{}'.format(host) + + options.ConfigureEndPoint(tobytes(host), int(port)) + options.ConfigureReplication(replication) + options.ConfigureBufferSize(buffer_size) + + if user is not None: + options.ConfigureUser(tobytes(user)) + if default_block_size is not None: + options.ConfigureBlockSize(default_block_size) + if kerb_ticket is not None: + options.ConfigureKerberosTicketCachePath( + tobytes(_stringify_path(kerb_ticket))) + if extra_conf is not None: + for k, v in extra_conf.items(): + options.ConfigureExtraConf(tobytes(k), tobytes(v)) + + with nogil: + wrapped = GetResultValue(CHadoopFileSystem.Make(options)) + self.init(<shared_ptr[CFileSystem]> wrapped) + + cdef init(self, const shared_ptr[CFileSystem]& wrapped): + FileSystem.init(self, wrapped) + self.hdfs = <CHadoopFileSystem*> wrapped.get() + + @staticmethod + def from_uri(uri): + """ + Instantiate HadoopFileSystem object from an URI string. + + The following two calls are equivalent + + * ``HadoopFileSystem.from_uri('hdfs://localhost:8020/?user=test\ +&replication=1')`` + * ``HadoopFileSystem('localhost', port=8020, user='test', \ +replication=1)`` + + Parameters + ---------- + uri : str + A string URI describing the connection to HDFS. + In order to change the user, replication, buffer_size or + default_block_size pass the values as query parts. + + Returns + ------- + HadoopFileSystem + """ + cdef: + HadoopFileSystem self = HadoopFileSystem.__new__(HadoopFileSystem) + shared_ptr[CHadoopFileSystem] wrapped + CHdfsOptions options + + options = GetResultValue(CHdfsOptions.FromUriString(tobytes(uri))) + with nogil: + wrapped = GetResultValue(CHadoopFileSystem.Make(options)) + + self.init(<shared_ptr[CFileSystem]> wrapped) + return self + + @classmethod + def _reconstruct(cls, kwargs): + return cls(**kwargs) + + def __reduce__(self): + cdef CHdfsOptions opts = self.hdfs.options() + return ( + HadoopFileSystem._reconstruct, (dict( + host=frombytes(opts.connection_config.host), + port=opts.connection_config.port, + user=frombytes(opts.connection_config.user), + replication=opts.replication, + buffer_size=opts.buffer_size, + default_block_size=opts.default_block_size, + kerb_ticket=frombytes(opts.connection_config.kerb_ticket), + extra_conf={frombytes(k): frombytes(v) + for k, v in opts.connection_config.extra_conf}, + ),) + ) diff --git a/src/arrow/python/pyarrow/_hdfsio.pyx b/src/arrow/python/pyarrow/_hdfsio.pyx new file mode 100644 index 000000000..b864f8a68 --- /dev/null +++ b/src/arrow/python/pyarrow/_hdfsio.pyx @@ -0,0 +1,480 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# ---------------------------------------------------------------------- +# HDFS IO implementation + +# cython: language_level = 3 + +import re + +from pyarrow.lib cimport check_status, _Weakrefable, NativeFile +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_fs cimport * +from pyarrow.lib import frombytes, tobytes, ArrowIOError + +from queue import Queue, Empty as QueueEmpty, Full as QueueFull + + +_HDFS_PATH_RE = re.compile(r'hdfs://(.*):(\d+)(.*)') + + +def have_libhdfs(): + try: + with nogil: + check_status(HaveLibHdfs()) + return True + except Exception: + return False + + +def strip_hdfs_abspath(path): + m = _HDFS_PATH_RE.match(path) + if m: + return m.group(3) + else: + return path + + +cdef class HadoopFileSystem(_Weakrefable): + cdef: + shared_ptr[CIOHadoopFileSystem] client + + cdef readonly: + bint is_open + object host + object user + object kerb_ticket + int port + dict extra_conf + + def _connect(self, host, port, user, kerb_ticket, extra_conf): + cdef HdfsConnectionConfig conf + + if host is not None: + conf.host = tobytes(host) + self.host = host + + conf.port = port + self.port = port + + if user is not None: + conf.user = tobytes(user) + self.user = user + + if kerb_ticket is not None: + conf.kerb_ticket = tobytes(kerb_ticket) + self.kerb_ticket = kerb_ticket + + with nogil: + check_status(HaveLibHdfs()) + + if extra_conf is not None and isinstance(extra_conf, dict): + conf.extra_conf = {tobytes(k): tobytes(v) + for k, v in extra_conf.items()} + self.extra_conf = extra_conf + + with nogil: + check_status(CIOHadoopFileSystem.Connect(&conf, &self.client)) + self.is_open = True + + @classmethod + def connect(cls, *args, **kwargs): + return cls(*args, **kwargs) + + def __dealloc__(self): + if self.is_open: + self.close() + + def close(self): + """ + Disconnect from the HDFS cluster + """ + self._ensure_client() + with nogil: + check_status(self.client.get().Disconnect()) + self.is_open = False + + cdef _ensure_client(self): + if self.client.get() == NULL: + raise IOError('HDFS client improperly initialized') + elif not self.is_open: + raise IOError('HDFS client is closed') + + def exists(self, path): + """ + Returns True if the path is known to the cluster, False if it does not + (or there is an RPC error) + """ + self._ensure_client() + + cdef c_string c_path = tobytes(path) + cdef c_bool result + with nogil: + result = self.client.get().Exists(c_path) + return result + + def isdir(self, path): + cdef HdfsPathInfo info + try: + self._path_info(path, &info) + except ArrowIOError: + return False + return info.kind == ObjectType_DIRECTORY + + def isfile(self, path): + cdef HdfsPathInfo info + try: + self._path_info(path, &info) + except ArrowIOError: + return False + return info.kind == ObjectType_FILE + + def get_capacity(self): + """ + Get reported total capacity of file system + + Returns + ------- + capacity : int + """ + cdef int64_t capacity = 0 + with nogil: + check_status(self.client.get().GetCapacity(&capacity)) + return capacity + + def get_space_used(self): + """ + Get space used on file system + + Returns + ------- + space_used : int + """ + cdef int64_t space_used = 0 + with nogil: + check_status(self.client.get().GetUsed(&space_used)) + return space_used + + def df(self): + """ + Return free space on disk, like the UNIX df command + + Returns + ------- + space : int + """ + return self.get_capacity() - self.get_space_used() + + def rename(self, path, new_path): + cdef c_string c_path = tobytes(path) + cdef c_string c_new_path = tobytes(new_path) + with nogil: + check_status(self.client.get().Rename(c_path, c_new_path)) + + def info(self, path): + """ + Return detailed HDFS information for path + + Parameters + ---------- + path : string + Path to file or directory + + Returns + ------- + path_info : dict + """ + cdef HdfsPathInfo info + self._path_info(path, &info) + return { + 'path': frombytes(info.name), + 'owner': frombytes(info.owner), + 'group': frombytes(info.group), + 'size': info.size, + 'block_size': info.block_size, + 'last_modified': info.last_modified_time, + 'last_accessed': info.last_access_time, + 'replication': info.replication, + 'permissions': info.permissions, + 'kind': ('directory' if info.kind == ObjectType_DIRECTORY + else 'file') + } + + def stat(self, path): + """ + Return basic file system statistics about path + + Parameters + ---------- + path : string + Path to file or directory + + Returns + ------- + stat : dict + """ + cdef FileStatistics info + cdef c_string c_path = tobytes(path) + with nogil: + check_status(self.client.get() + .Stat(c_path, &info)) + return { + 'size': info.size, + 'kind': ('directory' if info.kind == ObjectType_DIRECTORY + else 'file') + } + + cdef _path_info(self, path, HdfsPathInfo* info): + cdef c_string c_path = tobytes(path) + + with nogil: + check_status(self.client.get() + .GetPathInfo(c_path, info)) + + def ls(self, path, bint full_info): + cdef: + c_string c_path = tobytes(path) + vector[HdfsPathInfo] listing + list results = [] + int i + + self._ensure_client() + + with nogil: + check_status(self.client.get() + .ListDirectory(c_path, &listing)) + + cdef const HdfsPathInfo* info + for i in range(<int> listing.size()): + info = &listing[i] + + # Try to trim off the hdfs://HOST:PORT piece + name = strip_hdfs_abspath(frombytes(info.name)) + + if full_info: + kind = ('file' if info.kind == ObjectType_FILE + else 'directory') + + results.append({ + 'kind': kind, + 'name': name, + 'owner': frombytes(info.owner), + 'group': frombytes(info.group), + 'last_modified_time': info.last_modified_time, + 'last_access_time': info.last_access_time, + 'size': info.size, + 'replication': info.replication, + 'block_size': info.block_size, + 'permissions': info.permissions + }) + else: + results.append(name) + + return results + + def chmod(self, path, mode): + """ + Change file permissions + + Parameters + ---------- + path : string + absolute path to file or directory + mode : int + POSIX-like bitmask + """ + self._ensure_client() + cdef c_string c_path = tobytes(path) + cdef int c_mode = mode + with nogil: + check_status(self.client.get() + .Chmod(c_path, c_mode)) + + def chown(self, path, owner=None, group=None): + """ + Change file permissions + + Parameters + ---------- + path : string + absolute path to file or directory + owner : string, default None + New owner, None for no change + group : string, default None + New group, None for no change + """ + cdef: + c_string c_path + c_string c_owner + c_string c_group + const char* c_owner_ptr = NULL + const char* c_group_ptr = NULL + + self._ensure_client() + + c_path = tobytes(path) + if owner is not None: + c_owner = tobytes(owner) + c_owner_ptr = c_owner.c_str() + + if group is not None: + c_group = tobytes(group) + c_group_ptr = c_group.c_str() + + with nogil: + check_status(self.client.get() + .Chown(c_path, c_owner_ptr, c_group_ptr)) + + def mkdir(self, path): + """ + Create indicated directory and any necessary parent directories + """ + self._ensure_client() + cdef c_string c_path = tobytes(path) + with nogil: + check_status(self.client.get() + .MakeDirectory(c_path)) + + def delete(self, path, bint recursive=False): + """ + Delete the indicated file or directory + + Parameters + ---------- + path : string + recursive : boolean, default False + If True, also delete child paths for directories + """ + self._ensure_client() + + cdef c_string c_path = tobytes(path) + with nogil: + check_status(self.client.get() + .Delete(c_path, recursive == 1)) + + def open(self, path, mode='rb', buffer_size=None, replication=None, + default_block_size=None): + """ + Open HDFS file for reading or writing + + Parameters + ---------- + mode : string + Must be one of 'rb', 'wb', 'ab' + + Returns + ------- + handle : HdfsFile + """ + self._ensure_client() + + cdef HdfsFile out = HdfsFile() + + if mode not in ('rb', 'wb', 'ab'): + raise Exception("Mode must be 'rb' (read), " + "'wb' (write, new file), or 'ab' (append)") + + cdef c_string c_path = tobytes(path) + cdef c_bool append = False + + # 0 in libhdfs means "use the default" + cdef int32_t c_buffer_size = buffer_size or 0 + cdef int16_t c_replication = replication or 0 + cdef int64_t c_default_block_size = default_block_size or 0 + + cdef shared_ptr[HdfsOutputStream] wr_handle + cdef shared_ptr[HdfsReadableFile] rd_handle + + if mode in ('wb', 'ab'): + if mode == 'ab': + append = True + + with nogil: + check_status( + self.client.get() + .OpenWritable(c_path, append, c_buffer_size, + c_replication, c_default_block_size, + &wr_handle)) + + out.set_output_stream(<shared_ptr[COutputStream]> wr_handle) + out.is_writable = True + else: + with nogil: + check_status(self.client.get() + .OpenReadable(c_path, &rd_handle)) + + out.set_random_access_file( + <shared_ptr[CRandomAccessFile]> rd_handle) + out.is_readable = True + + assert not out.closed + + if c_buffer_size == 0: + c_buffer_size = 2 ** 16 + + out.mode = mode + out.buffer_size = c_buffer_size + out.parent = _HdfsFileNanny(self, out) + out.own_file = True + + return out + + def download(self, path, stream, buffer_size=None): + with self.open(path, 'rb') as f: + f.download(stream, buffer_size=buffer_size) + + def upload(self, path, stream, buffer_size=None): + """ + Upload file-like object to HDFS path + """ + with self.open(path, 'wb') as f: + f.upload(stream, buffer_size=buffer_size) + + +# ARROW-404: Helper class to ensure that files are closed before the +# client. During deallocation of the extension class, the attributes are +# decref'd which can cause the client to get closed first if the file has the +# last remaining reference +cdef class _HdfsFileNanny(_Weakrefable): + cdef: + object client + object file_handle_ref + + def __cinit__(self, client, file_handle): + import weakref + self.client = client + self.file_handle_ref = weakref.ref(file_handle) + + def __dealloc__(self): + fh = self.file_handle_ref() + if fh: + fh.close() + # avoid cyclic GC + self.file_handle_ref = None + self.client = None + + +cdef class HdfsFile(NativeFile): + cdef readonly: + int32_t buffer_size + object mode + object parent + + def __dealloc__(self): + self.parent = None diff --git a/src/arrow/python/pyarrow/_json.pyx b/src/arrow/python/pyarrow/_json.pyx new file mode 100644 index 000000000..1c08e546e --- /dev/null +++ b/src/arrow/python/pyarrow/_json.pyx @@ -0,0 +1,248 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport (check_status, _Weakrefable, Field, MemoryPool, + ensure_type, maybe_unbox_memory_pool, + get_input_stream, pyarrow_wrap_table, + pyarrow_wrap_data_type, pyarrow_unwrap_data_type, + pyarrow_wrap_schema, pyarrow_unwrap_schema) + + +cdef class ReadOptions(_Weakrefable): + """ + Options for reading JSON files. + + Parameters + ---------- + use_threads : bool, optional (default True) + Whether to use multiple threads to accelerate reading + block_size : int, optional + How much bytes to process at a time from the input stream. + This will determine multi-threading granularity as well as + the size of individual chunks in the Table. + """ + cdef: + CJSONReadOptions options + + # Avoid mistakingly creating attributes + __slots__ = () + + def __init__(self, use_threads=None, block_size=None): + self.options = CJSONReadOptions.Defaults() + if use_threads is not None: + self.use_threads = use_threads + if block_size is not None: + self.block_size = block_size + + @property + def use_threads(self): + """ + Whether to use multiple threads to accelerate reading. + """ + return self.options.use_threads + + @use_threads.setter + def use_threads(self, value): + self.options.use_threads = value + + @property + def block_size(self): + """ + How much bytes to process at a time from the input stream. + + This will determine multi-threading granularity as well as the size of + individual chunks in the Table. + """ + return self.options.block_size + + @block_size.setter + def block_size(self, value): + self.options.block_size = value + + +cdef class ParseOptions(_Weakrefable): + """ + Options for parsing JSON files. + + Parameters + ---------- + explicit_schema : Schema, optional (default None) + Optional explicit schema (no type inference, ignores other fields). + newlines_in_values : bool, optional (default False) + Whether objects may be printed across multiple lines (for example + pretty printed). If false, input must end with an empty line. + unexpected_field_behavior : str, default "infer" + How JSON fields outside of explicit_schema (if given) are treated. + + Possible behaviors: + + - "ignore": unexpected JSON fields are ignored + - "error": error out on unexpected JSON fields + - "infer": unexpected JSON fields are type-inferred and included in + the output + """ + + cdef: + CJSONParseOptions options + + __slots__ = () + + def __init__(self, explicit_schema=None, newlines_in_values=None, + unexpected_field_behavior=None): + self.options = CJSONParseOptions.Defaults() + if explicit_schema is not None: + self.explicit_schema = explicit_schema + if newlines_in_values is not None: + self.newlines_in_values = newlines_in_values + if unexpected_field_behavior is not None: + self.unexpected_field_behavior = unexpected_field_behavior + + @property + def explicit_schema(self): + """ + Optional explicit schema (no type inference, ignores other fields) + """ + if self.options.explicit_schema.get() == NULL: + return None + else: + return pyarrow_wrap_schema(self.options.explicit_schema) + + @explicit_schema.setter + def explicit_schema(self, value): + self.options.explicit_schema = pyarrow_unwrap_schema(value) + + @property + def newlines_in_values(self): + """ + Whether newline characters are allowed in JSON values. + Setting this to True reduces the performance of multi-threaded + JSON reading. + """ + return self.options.newlines_in_values + + @newlines_in_values.setter + def newlines_in_values(self, value): + self.options.newlines_in_values = value + + @property + def unexpected_field_behavior(self): + """ + How JSON fields outside of explicit_schema (if given) are treated. + + Possible behaviors: + + - "ignore": unexpected JSON fields are ignored + - "error": error out on unexpected JSON fields + - "infer": unexpected JSON fields are type-inferred and included in + the output + + Set to "infer" by default. + """ + v = self.options.unexpected_field_behavior + if v == CUnexpectedFieldBehavior_Ignore: + return "ignore" + elif v == CUnexpectedFieldBehavior_Error: + return "error" + elif v == CUnexpectedFieldBehavior_InferType: + return "infer" + else: + raise ValueError('Unexpected value for unexpected_field_behavior') + + @unexpected_field_behavior.setter + def unexpected_field_behavior(self, value): + cdef CUnexpectedFieldBehavior v + + if value == "ignore": + v = CUnexpectedFieldBehavior_Ignore + elif value == "error": + v = CUnexpectedFieldBehavior_Error + elif value == "infer": + v = CUnexpectedFieldBehavior_InferType + else: + raise ValueError( + "Unexpected value `{}` for `unexpected_field_behavior`, pass " + "either `ignore`, `error` or `infer`.".format(value) + ) + + self.options.unexpected_field_behavior = v + + +cdef _get_reader(input_file, shared_ptr[CInputStream]* out): + use_memory_map = False + get_input_stream(input_file, use_memory_map, out) + +cdef _get_read_options(ReadOptions read_options, CJSONReadOptions* out): + if read_options is None: + out[0] = CJSONReadOptions.Defaults() + else: + out[0] = read_options.options + +cdef _get_parse_options(ParseOptions parse_options, CJSONParseOptions* out): + if parse_options is None: + out[0] = CJSONParseOptions.Defaults() + else: + out[0] = parse_options.options + + +def read_json(input_file, read_options=None, parse_options=None, + MemoryPool memory_pool=None): + """ + Read a Table from a stream of JSON data. + + Parameters + ---------- + input_file : str, path or file-like object + The location of JSON data. Currently only the line-delimited JSON + format is supported. + read_options : pyarrow.json.ReadOptions, optional + Options for the JSON reader (see ReadOptions constructor for defaults). + parse_options : pyarrow.json.ParseOptions, optional + Options for the JSON parser + (see ParseOptions constructor for defaults). + memory_pool : MemoryPool, optional + Pool to allocate Table memory from. + + Returns + ------- + :class:`pyarrow.Table` + Contents of the JSON file as a in-memory table. + """ + cdef: + shared_ptr[CInputStream] stream + CJSONReadOptions c_read_options + CJSONParseOptions c_parse_options + shared_ptr[CJSONReader] reader + shared_ptr[CTable] table + + _get_reader(input_file, &stream) + _get_read_options(read_options, &c_read_options) + _get_parse_options(parse_options, &c_parse_options) + + reader = GetResultValue( + CJSONReader.Make(maybe_unbox_memory_pool(memory_pool), + stream, c_read_options, c_parse_options)) + + with nogil: + table = GetResultValue(reader.get().Read()) + + return pyarrow_wrap_table(table) diff --git a/src/arrow/python/pyarrow/_orc.pxd b/src/arrow/python/pyarrow/_orc.pxd new file mode 100644 index 000000000..736622591 --- /dev/null +++ b/src/arrow/python/pyarrow/_orc.pxd @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ +# cython: language_level = 3 + +from libc.string cimport const_char +from libcpp.vector cimport vector as std_vector +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport (CArray, CSchema, CStatus, + CResult, CTable, CMemoryPool, + CKeyValueMetadata, + CRecordBatch, + CTable, + CRandomAccessFile, COutputStream, + TimeUnit) + + +cdef extern from "arrow/adapters/orc/adapter.h" \ + namespace "arrow::adapters::orc" nogil: + + cdef cppclass ORCFileReader: + @staticmethod + CResult[unique_ptr[ORCFileReader]] Open( + const shared_ptr[CRandomAccessFile]& file, + CMemoryPool* pool) + + CResult[shared_ptr[const CKeyValueMetadata]] ReadMetadata() + + CResult[shared_ptr[CSchema]] ReadSchema() + + CResult[shared_ptr[CRecordBatch]] ReadStripe(int64_t stripe) + CResult[shared_ptr[CRecordBatch]] ReadStripe( + int64_t stripe, std_vector[c_string]) + + CResult[shared_ptr[CTable]] Read() + CResult[shared_ptr[CTable]] Read(std_vector[c_string]) + + int64_t NumberOfStripes() + + int64_t NumberOfRows() + + cdef cppclass ORCFileWriter: + @staticmethod + CResult[unique_ptr[ORCFileWriter]] Open(COutputStream* output_stream) + + CStatus Write(const CTable& table) + + CStatus Close() diff --git a/src/arrow/python/pyarrow/_orc.pyx b/src/arrow/python/pyarrow/_orc.pyx new file mode 100644 index 000000000..18ca28682 --- /dev/null +++ b/src/arrow/python/pyarrow/_orc.pyx @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ + +from cython.operator cimport dereference as deref +from libcpp.vector cimport vector as std_vector +from libcpp.utility cimport move +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport (check_status, _Weakrefable, + MemoryPool, maybe_unbox_memory_pool, + Schema, pyarrow_wrap_schema, + KeyValueMetadata, + pyarrow_wrap_batch, + RecordBatch, + Table, + pyarrow_wrap_table, + pyarrow_unwrap_schema, + pyarrow_wrap_metadata, + pyarrow_unwrap_table, + get_reader, + get_writer) +from pyarrow.lib import tobytes + + +cdef class ORCReader(_Weakrefable): + cdef: + object source + CMemoryPool* allocator + unique_ptr[ORCFileReader] reader + + def __cinit__(self, MemoryPool memory_pool=None): + self.allocator = maybe_unbox_memory_pool(memory_pool) + + def open(self, object source, c_bool use_memory_map=True): + cdef: + shared_ptr[CRandomAccessFile] rd_handle + + self.source = source + + get_reader(source, use_memory_map, &rd_handle) + with nogil: + self.reader = move(GetResultValue( + ORCFileReader.Open(rd_handle, self.allocator) + )) + + def metadata(self): + """ + The arrow metadata for this file. + + Returns + ------- + metadata : pyarrow.KeyValueMetadata + """ + cdef: + shared_ptr[const CKeyValueMetadata] sp_arrow_metadata + + with nogil: + sp_arrow_metadata = GetResultValue( + deref(self.reader).ReadMetadata() + ) + + return pyarrow_wrap_metadata(sp_arrow_metadata) + + def schema(self): + """ + The arrow schema for this file. + + Returns + ------- + schema : pyarrow.Schema + """ + cdef: + shared_ptr[CSchema] sp_arrow_schema + + with nogil: + sp_arrow_schema = GetResultValue(deref(self.reader).ReadSchema()) + + return pyarrow_wrap_schema(sp_arrow_schema) + + def nrows(self): + return deref(self.reader).NumberOfRows() + + def nstripes(self): + return deref(self.reader).NumberOfStripes() + + def read_stripe(self, n, columns=None): + cdef: + shared_ptr[CRecordBatch] sp_record_batch + RecordBatch batch + int64_t stripe + std_vector[c_string] c_names + + stripe = n + + if columns is None: + with nogil: + sp_record_batch = GetResultValue( + deref(self.reader).ReadStripe(stripe) + ) + else: + c_names = [tobytes(name) for name in columns] + with nogil: + sp_record_batch = GetResultValue( + deref(self.reader).ReadStripe(stripe, c_names) + ) + + return pyarrow_wrap_batch(sp_record_batch) + + def read(self, columns=None): + cdef: + shared_ptr[CTable] sp_table + std_vector[c_string] c_names + + if columns is None: + with nogil: + sp_table = GetResultValue(deref(self.reader).Read()) + else: + c_names = [tobytes(name) for name in columns] + with nogil: + sp_table = GetResultValue(deref(self.reader).Read(c_names)) + + return pyarrow_wrap_table(sp_table) + +cdef class ORCWriter(_Weakrefable): + cdef: + object source + unique_ptr[ORCFileWriter] writer + shared_ptr[COutputStream] rd_handle + + def open(self, object source): + self.source = source + get_writer(source, &self.rd_handle) + with nogil: + self.writer = move(GetResultValue[unique_ptr[ORCFileWriter]]( + ORCFileWriter.Open(self.rd_handle.get()))) + + def write(self, Table table): + cdef: + shared_ptr[CTable] sp_table + sp_table = pyarrow_unwrap_table(table) + with nogil: + check_status(deref(self.writer).Write(deref(sp_table))) + + def close(self): + with nogil: + check_status(deref(self.writer).Close()) diff --git a/src/arrow/python/pyarrow/_parquet.pxd b/src/arrow/python/pyarrow/_parquet.pxd new file mode 100644 index 000000000..0efbb0c86 --- /dev/null +++ b/src/arrow/python/pyarrow/_parquet.pxd @@ -0,0 +1,559 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ +# cython: language_level = 3 + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport (CChunkedArray, CSchema, CStatus, + CTable, CMemoryPool, CBuffer, + CKeyValueMetadata, + CRandomAccessFile, COutputStream, + TimeUnit, CRecordBatchReader) +from pyarrow.lib cimport _Weakrefable + + +cdef extern from "parquet/api/schema.h" namespace "parquet::schema" nogil: + cdef cppclass Node: + pass + + cdef cppclass GroupNode(Node): + pass + + cdef cppclass PrimitiveNode(Node): + pass + + cdef cppclass ColumnPath: + c_string ToDotString() + vector[c_string] ToDotVector() + + +cdef extern from "parquet/api/schema.h" namespace "parquet" nogil: + enum ParquetType" parquet::Type::type": + ParquetType_BOOLEAN" parquet::Type::BOOLEAN" + ParquetType_INT32" parquet::Type::INT32" + ParquetType_INT64" parquet::Type::INT64" + ParquetType_INT96" parquet::Type::INT96" + ParquetType_FLOAT" parquet::Type::FLOAT" + ParquetType_DOUBLE" parquet::Type::DOUBLE" + ParquetType_BYTE_ARRAY" parquet::Type::BYTE_ARRAY" + ParquetType_FIXED_LEN_BYTE_ARRAY" parquet::Type::FIXED_LEN_BYTE_ARRAY" + + enum ParquetLogicalTypeId" parquet::LogicalType::Type::type": + ParquetLogicalType_UNDEFINED" parquet::LogicalType::Type::UNDEFINED" + ParquetLogicalType_STRING" parquet::LogicalType::Type::STRING" + ParquetLogicalType_MAP" parquet::LogicalType::Type::MAP" + ParquetLogicalType_LIST" parquet::LogicalType::Type::LIST" + ParquetLogicalType_ENUM" parquet::LogicalType::Type::ENUM" + ParquetLogicalType_DECIMAL" parquet::LogicalType::Type::DECIMAL" + ParquetLogicalType_DATE" parquet::LogicalType::Type::DATE" + ParquetLogicalType_TIME" parquet::LogicalType::Type::TIME" + ParquetLogicalType_TIMESTAMP" parquet::LogicalType::Type::TIMESTAMP" + ParquetLogicalType_INT" parquet::LogicalType::Type::INT" + ParquetLogicalType_JSON" parquet::LogicalType::Type::JSON" + ParquetLogicalType_BSON" parquet::LogicalType::Type::BSON" + ParquetLogicalType_UUID" parquet::LogicalType::Type::UUID" + ParquetLogicalType_NONE" parquet::LogicalType::Type::NONE" + + enum ParquetTimeUnit" parquet::LogicalType::TimeUnit::unit": + ParquetTimeUnit_UNKNOWN" parquet::LogicalType::TimeUnit::UNKNOWN" + ParquetTimeUnit_MILLIS" parquet::LogicalType::TimeUnit::MILLIS" + ParquetTimeUnit_MICROS" parquet::LogicalType::TimeUnit::MICROS" + ParquetTimeUnit_NANOS" parquet::LogicalType::TimeUnit::NANOS" + + enum ParquetConvertedType" parquet::ConvertedType::type": + ParquetConvertedType_NONE" parquet::ConvertedType::NONE" + ParquetConvertedType_UTF8" parquet::ConvertedType::UTF8" + ParquetConvertedType_MAP" parquet::ConvertedType::MAP" + ParquetConvertedType_MAP_KEY_VALUE \ + " parquet::ConvertedType::MAP_KEY_VALUE" + ParquetConvertedType_LIST" parquet::ConvertedType::LIST" + ParquetConvertedType_ENUM" parquet::ConvertedType::ENUM" + ParquetConvertedType_DECIMAL" parquet::ConvertedType::DECIMAL" + ParquetConvertedType_DATE" parquet::ConvertedType::DATE" + ParquetConvertedType_TIME_MILLIS" parquet::ConvertedType::TIME_MILLIS" + ParquetConvertedType_TIME_MICROS" parquet::ConvertedType::TIME_MICROS" + ParquetConvertedType_TIMESTAMP_MILLIS \ + " parquet::ConvertedType::TIMESTAMP_MILLIS" + ParquetConvertedType_TIMESTAMP_MICROS \ + " parquet::ConvertedType::TIMESTAMP_MICROS" + ParquetConvertedType_UINT_8" parquet::ConvertedType::UINT_8" + ParquetConvertedType_UINT_16" parquet::ConvertedType::UINT_16" + ParquetConvertedType_UINT_32" parquet::ConvertedType::UINT_32" + ParquetConvertedType_UINT_64" parquet::ConvertedType::UINT_64" + ParquetConvertedType_INT_8" parquet::ConvertedType::INT_8" + ParquetConvertedType_INT_16" parquet::ConvertedType::INT_16" + ParquetConvertedType_INT_32" parquet::ConvertedType::INT_32" + ParquetConvertedType_INT_64" parquet::ConvertedType::INT_64" + ParquetConvertedType_JSON" parquet::ConvertedType::JSON" + ParquetConvertedType_BSON" parquet::ConvertedType::BSON" + ParquetConvertedType_INTERVAL" parquet::ConvertedType::INTERVAL" + + enum ParquetRepetition" parquet::Repetition::type": + ParquetRepetition_REQUIRED" parquet::REPETITION::REQUIRED" + ParquetRepetition_OPTIONAL" parquet::REPETITION::OPTIONAL" + ParquetRepetition_REPEATED" parquet::REPETITION::REPEATED" + + enum ParquetEncoding" parquet::Encoding::type": + ParquetEncoding_PLAIN" parquet::Encoding::PLAIN" + ParquetEncoding_PLAIN_DICTIONARY" parquet::Encoding::PLAIN_DICTIONARY" + ParquetEncoding_RLE" parquet::Encoding::RLE" + ParquetEncoding_BIT_PACKED" parquet::Encoding::BIT_PACKED" + ParquetEncoding_DELTA_BINARY_PACKED \ + " parquet::Encoding::DELTA_BINARY_PACKED" + ParquetEncoding_DELTA_LENGTH_BYTE_ARRAY \ + " parquet::Encoding::DELTA_LENGTH_BYTE_ARRAY" + ParquetEncoding_DELTA_BYTE_ARRAY" parquet::Encoding::DELTA_BYTE_ARRAY" + ParquetEncoding_RLE_DICTIONARY" parquet::Encoding::RLE_DICTIONARY" + ParquetEncoding_BYTE_STREAM_SPLIT \ + " parquet::Encoding::BYTE_STREAM_SPLIT" + + enum ParquetCompression" parquet::Compression::type": + ParquetCompression_UNCOMPRESSED" parquet::Compression::UNCOMPRESSED" + ParquetCompression_SNAPPY" parquet::Compression::SNAPPY" + ParquetCompression_GZIP" parquet::Compression::GZIP" + ParquetCompression_LZO" parquet::Compression::LZO" + ParquetCompression_BROTLI" parquet::Compression::BROTLI" + ParquetCompression_LZ4" parquet::Compression::LZ4" + ParquetCompression_ZSTD" parquet::Compression::ZSTD" + + enum ParquetVersion" parquet::ParquetVersion::type": + ParquetVersion_V1" parquet::ParquetVersion::PARQUET_1_0" + ParquetVersion_V2_0" parquet::ParquetVersion::PARQUET_2_0" + ParquetVersion_V2_4" parquet::ParquetVersion::PARQUET_2_4" + ParquetVersion_V2_6" parquet::ParquetVersion::PARQUET_2_6" + + enum ParquetSortOrder" parquet::SortOrder::type": + ParquetSortOrder_SIGNED" parquet::SortOrder::SIGNED" + ParquetSortOrder_UNSIGNED" parquet::SortOrder::UNSIGNED" + ParquetSortOrder_UNKNOWN" parquet::SortOrder::UNKNOWN" + + cdef cppclass CParquetLogicalType" parquet::LogicalType": + c_string ToString() const + c_string ToJSON() const + ParquetLogicalTypeId type() const + + cdef cppclass CParquetDecimalType \ + " parquet::DecimalLogicalType"(CParquetLogicalType): + int32_t precision() const + int32_t scale() const + + cdef cppclass CParquetIntType \ + " parquet::IntLogicalType"(CParquetLogicalType): + int bit_width() const + c_bool is_signed() const + + cdef cppclass CParquetTimeType \ + " parquet::TimeLogicalType"(CParquetLogicalType): + c_bool is_adjusted_to_utc() const + ParquetTimeUnit time_unit() const + + cdef cppclass CParquetTimestampType \ + " parquet::TimestampLogicalType"(CParquetLogicalType): + c_bool is_adjusted_to_utc() const + ParquetTimeUnit time_unit() const + + cdef cppclass ColumnDescriptor" parquet::ColumnDescriptor": + c_bool Equals(const ColumnDescriptor& other) + + shared_ptr[ColumnPath] path() + int16_t max_definition_level() + int16_t max_repetition_level() + + ParquetType physical_type() + const shared_ptr[const CParquetLogicalType]& logical_type() + ParquetConvertedType converted_type() + const c_string& name() + int type_length() + int type_precision() + int type_scale() + + cdef cppclass SchemaDescriptor: + const ColumnDescriptor* Column(int i) + shared_ptr[Node] schema() + GroupNode* group() + c_bool Equals(const SchemaDescriptor& other) + c_string ToString() + int num_columns() + + cdef c_string FormatStatValue(ParquetType parquet_type, c_string val) + + +cdef extern from "parquet/api/reader.h" namespace "parquet" nogil: + cdef cppclass ColumnReader: + pass + + cdef cppclass BoolReader(ColumnReader): + pass + + cdef cppclass Int32Reader(ColumnReader): + pass + + cdef cppclass Int64Reader(ColumnReader): + pass + + cdef cppclass Int96Reader(ColumnReader): + pass + + cdef cppclass FloatReader(ColumnReader): + pass + + cdef cppclass DoubleReader(ColumnReader): + pass + + cdef cppclass ByteArrayReader(ColumnReader): + pass + + cdef cppclass RowGroupReader: + pass + + cdef cppclass CEncodedStatistics" parquet::EncodedStatistics": + const c_string& max() const + const c_string& min() const + int64_t null_count + int64_t distinct_count + bint has_min + bint has_max + bint has_null_count + bint has_distinct_count + + cdef cppclass ParquetByteArray" parquet::ByteArray": + uint32_t len + const uint8_t* ptr + + cdef cppclass ParquetFLBA" parquet::FLBA": + const uint8_t* ptr + + cdef cppclass CStatistics" parquet::Statistics": + int64_t null_count() const + int64_t distinct_count() const + int64_t num_values() const + bint HasMinMax() + bint HasNullCount() + bint HasDistinctCount() + c_bool Equals(const CStatistics&) const + void Reset() + c_string EncodeMin() + c_string EncodeMax() + CEncodedStatistics Encode() + void SetComparator() + ParquetType physical_type() const + const ColumnDescriptor* descr() const + + cdef cppclass CBoolStatistics" parquet::BoolStatistics"(CStatistics): + c_bool min() + c_bool max() + + cdef cppclass CInt32Statistics" parquet::Int32Statistics"(CStatistics): + int32_t min() + int32_t max() + + cdef cppclass CInt64Statistics" parquet::Int64Statistics"(CStatistics): + int64_t min() + int64_t max() + + cdef cppclass CFloatStatistics" parquet::FloatStatistics"(CStatistics): + float min() + float max() + + cdef cppclass CDoubleStatistics" parquet::DoubleStatistics"(CStatistics): + double min() + double max() + + cdef cppclass CByteArrayStatistics \ + " parquet::ByteArrayStatistics"(CStatistics): + ParquetByteArray min() + ParquetByteArray max() + + cdef cppclass CFLBAStatistics" parquet::FLBAStatistics"(CStatistics): + ParquetFLBA min() + ParquetFLBA max() + + cdef cppclass CColumnChunkMetaData" parquet::ColumnChunkMetaData": + int64_t file_offset() const + const c_string& file_path() const + + ParquetType type() const + int64_t num_values() const + shared_ptr[ColumnPath] path_in_schema() const + bint is_stats_set() const + shared_ptr[CStatistics] statistics() const + ParquetCompression compression() const + const vector[ParquetEncoding]& encodings() const + c_bool Equals(const CColumnChunkMetaData&) const + + int64_t has_dictionary_page() const + int64_t dictionary_page_offset() const + int64_t data_page_offset() const + int64_t index_page_offset() const + int64_t total_compressed_size() const + int64_t total_uncompressed_size() const + + cdef cppclass CRowGroupMetaData" parquet::RowGroupMetaData": + c_bool Equals(const CRowGroupMetaData&) const + int num_columns() + int64_t num_rows() + int64_t total_byte_size() + unique_ptr[CColumnChunkMetaData] ColumnChunk(int i) const + + cdef cppclass CFileMetaData" parquet::FileMetaData": + c_bool Equals(const CFileMetaData&) const + uint32_t size() + int num_columns() + int64_t num_rows() + int num_row_groups() + ParquetVersion version() + const c_string created_by() + int num_schema_elements() + + void set_file_path(const c_string& path) + void AppendRowGroups(const CFileMetaData& other) except + + + unique_ptr[CRowGroupMetaData] RowGroup(int i) + const SchemaDescriptor* schema() + shared_ptr[const CKeyValueMetadata] key_value_metadata() const + void WriteTo(COutputStream* dst) const + + cdef shared_ptr[CFileMetaData] CFileMetaData_Make \ + " parquet::FileMetaData::Make"(const void* serialized_metadata, + uint32_t* metadata_len) + + cdef cppclass CReaderProperties" parquet::ReaderProperties": + c_bool is_buffered_stream_enabled() const + void enable_buffered_stream() + void disable_buffered_stream() + void set_buffer_size(int64_t buf_size) + int64_t buffer_size() const + + CReaderProperties default_reader_properties() + + cdef cppclass ArrowReaderProperties: + ArrowReaderProperties() + void set_read_dictionary(int column_index, c_bool read_dict) + c_bool read_dictionary() + void set_batch_size(int64_t batch_size) + int64_t batch_size() + void set_pre_buffer(c_bool pre_buffer) + c_bool pre_buffer() const + void set_coerce_int96_timestamp_unit(TimeUnit unit) + TimeUnit coerce_int96_timestamp_unit() const + + ArrowReaderProperties default_arrow_reader_properties() + + cdef cppclass ParquetFileReader: + shared_ptr[CFileMetaData] metadata() + + +cdef extern from "parquet/api/writer.h" namespace "parquet" nogil: + cdef cppclass WriterProperties: + cppclass Builder: + Builder* data_page_version(ParquetDataPageVersion version) + Builder* version(ParquetVersion version) + Builder* compression(ParquetCompression codec) + Builder* compression(const c_string& path, + ParquetCompression codec) + Builder* compression_level(int compression_level) + Builder* compression_level(const c_string& path, + int compression_level) + Builder* disable_dictionary() + Builder* enable_dictionary() + Builder* enable_dictionary(const c_string& path) + Builder* disable_statistics() + Builder* enable_statistics() + Builder* enable_statistics(const c_string& path) + Builder* data_pagesize(int64_t size) + Builder* encoding(ParquetEncoding encoding) + Builder* encoding(const c_string& path, + ParquetEncoding encoding) + Builder* write_batch_size(int64_t batch_size) + shared_ptr[WriterProperties] build() + + cdef cppclass ArrowWriterProperties: + cppclass Builder: + Builder() + Builder* disable_deprecated_int96_timestamps() + Builder* enable_deprecated_int96_timestamps() + Builder* coerce_timestamps(TimeUnit unit) + Builder* allow_truncated_timestamps() + Builder* disallow_truncated_timestamps() + Builder* store_schema() + Builder* enable_compliant_nested_types() + Builder* disable_compliant_nested_types() + Builder* set_engine_version(ArrowWriterEngineVersion version) + shared_ptr[ArrowWriterProperties] build() + c_bool support_deprecated_int96_timestamps() + + +cdef extern from "parquet/arrow/reader.h" namespace "parquet::arrow" nogil: + cdef cppclass FileReader: + FileReader(CMemoryPool* pool, unique_ptr[ParquetFileReader] reader) + + CStatus GetSchema(shared_ptr[CSchema]* out) + + CStatus ReadColumn(int i, shared_ptr[CChunkedArray]* out) + CStatus ReadSchemaField(int i, shared_ptr[CChunkedArray]* out) + + int num_row_groups() + CStatus ReadRowGroup(int i, shared_ptr[CTable]* out) + CStatus ReadRowGroup(int i, const vector[int]& column_indices, + shared_ptr[CTable]* out) + + CStatus ReadRowGroups(const vector[int]& row_groups, + shared_ptr[CTable]* out) + CStatus ReadRowGroups(const vector[int]& row_groups, + const vector[int]& column_indices, + shared_ptr[CTable]* out) + + CStatus GetRecordBatchReader(const vector[int]& row_group_indices, + const vector[int]& column_indices, + unique_ptr[CRecordBatchReader]* out) + CStatus GetRecordBatchReader(const vector[int]& row_group_indices, + unique_ptr[CRecordBatchReader]* out) + + CStatus ReadTable(shared_ptr[CTable]* out) + CStatus ReadTable(const vector[int]& column_indices, + shared_ptr[CTable]* out) + + CStatus ScanContents(vector[int] columns, int32_t column_batch_size, + int64_t* num_rows) + + const ParquetFileReader* parquet_reader() + + void set_use_threads(c_bool use_threads) + + void set_batch_size(int64_t batch_size) + + cdef cppclass FileReaderBuilder: + FileReaderBuilder() + CStatus Open(const shared_ptr[CRandomAccessFile]& file, + const CReaderProperties& properties, + const shared_ptr[CFileMetaData]& metadata) + + ParquetFileReader* raw_reader() + FileReaderBuilder* memory_pool(CMemoryPool*) + FileReaderBuilder* properties(const ArrowReaderProperties&) + CStatus Build(unique_ptr[FileReader]* out) + + CStatus FromParquetSchema( + const SchemaDescriptor* parquet_schema, + const ArrowReaderProperties& properties, + const shared_ptr[const CKeyValueMetadata]& key_value_metadata, + shared_ptr[CSchema]* out) + +cdef extern from "parquet/arrow/schema.h" namespace "parquet::arrow" nogil: + + CStatus ToParquetSchema( + const CSchema* arrow_schema, + const ArrowReaderProperties& properties, + const shared_ptr[const CKeyValueMetadata]& key_value_metadata, + shared_ptr[SchemaDescriptor]* out) + + +cdef extern from "parquet/properties.h" namespace "parquet" nogil: + cdef enum ArrowWriterEngineVersion: + V1 "parquet::ArrowWriterProperties::V1", + V2 "parquet::ArrowWriterProperties::V2" + + cdef cppclass ParquetDataPageVersion: + pass + + cdef ParquetDataPageVersion ParquetDataPageVersion_V1 \ + " parquet::ParquetDataPageVersion::V1" + cdef ParquetDataPageVersion ParquetDataPageVersion_V2 \ + " parquet::ParquetDataPageVersion::V2" + +cdef extern from "parquet/arrow/writer.h" namespace "parquet::arrow" nogil: + cdef cppclass FileWriter: + + @staticmethod + CStatus Open(const CSchema& schema, CMemoryPool* pool, + const shared_ptr[COutputStream]& sink, + const shared_ptr[WriterProperties]& properties, + const shared_ptr[ArrowWriterProperties]& arrow_properties, + unique_ptr[FileWriter]* writer) + + CStatus WriteTable(const CTable& table, int64_t chunk_size) + CStatus NewRowGroup(int64_t chunk_size) + CStatus Close() + + const shared_ptr[CFileMetaData] metadata() const + + CStatus WriteMetaDataFile( + const CFileMetaData& file_metadata, + const COutputStream* sink) + + +cdef shared_ptr[WriterProperties] _create_writer_properties( + use_dictionary=*, + compression=*, + version=*, + write_statistics=*, + data_page_size=*, + compression_level=*, + use_byte_stream_split=*, + data_page_version=*) except * + + +cdef shared_ptr[ArrowWriterProperties] _create_arrow_writer_properties( + use_deprecated_int96_timestamps=*, + coerce_timestamps=*, + allow_truncated_timestamps=*, + writer_engine_version=*, + use_compliant_nested_type=*) except * + +cdef class ParquetSchema(_Weakrefable): + cdef: + FileMetaData parent # the FileMetaData owning the SchemaDescriptor + const SchemaDescriptor* schema + +cdef class FileMetaData(_Weakrefable): + cdef: + shared_ptr[CFileMetaData] sp_metadata + CFileMetaData* _metadata + ParquetSchema _schema + + cdef inline init(self, const shared_ptr[CFileMetaData]& metadata): + self.sp_metadata = metadata + self._metadata = metadata.get() + +cdef class RowGroupMetaData(_Weakrefable): + cdef: + int index # for pickling support + unique_ptr[CRowGroupMetaData] up_metadata + CRowGroupMetaData* metadata + FileMetaData parent + +cdef class ColumnChunkMetaData(_Weakrefable): + cdef: + unique_ptr[CColumnChunkMetaData] up_metadata + CColumnChunkMetaData* metadata + RowGroupMetaData parent + + cdef inline init(self, RowGroupMetaData parent, int i): + self.up_metadata = parent.metadata.ColumnChunk(i) + self.metadata = self.up_metadata.get() + self.parent = parent + +cdef class Statistics(_Weakrefable): + cdef: + shared_ptr[CStatistics] statistics + ColumnChunkMetaData parent + + cdef inline init(self, const shared_ptr[CStatistics]& statistics, + ColumnChunkMetaData parent): + self.statistics = statistics + self.parent = parent diff --git a/src/arrow/python/pyarrow/_parquet.pyx b/src/arrow/python/pyarrow/_parquet.pyx new file mode 100644 index 000000000..87dfecf1b --- /dev/null +++ b/src/arrow/python/pyarrow/_parquet.pyx @@ -0,0 +1,1466 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ + +import io +from textwrap import indent +import warnings + +import numpy as np + +from cython.operator cimport dereference as deref +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport (_Weakrefable, Buffer, Array, Schema, + check_status, + MemoryPool, maybe_unbox_memory_pool, + Table, NativeFile, + pyarrow_wrap_chunked_array, + pyarrow_wrap_schema, + pyarrow_wrap_table, + pyarrow_wrap_buffer, + pyarrow_wrap_batch, + NativeFile, get_reader, get_writer, + string_to_timeunit) + +from pyarrow.lib import (ArrowException, NativeFile, BufferOutputStream, + _stringify_path, _datetime_from_int, + tobytes, frombytes) + +cimport cpython as cp + + +cdef class Statistics(_Weakrefable): + def __cinit__(self): + pass + + def __repr__(self): + return """{} + has_min_max: {} + min: {} + max: {} + null_count: {} + distinct_count: {} + num_values: {} + physical_type: {} + logical_type: {} + converted_type (legacy): {}""".format(object.__repr__(self), + self.has_min_max, + self.min, + self.max, + self.null_count, + self.distinct_count, + self.num_values, + self.physical_type, + str(self.logical_type), + self.converted_type) + + def to_dict(self): + d = dict( + has_min_max=self.has_min_max, + min=self.min, + max=self.max, + null_count=self.null_count, + distinct_count=self.distinct_count, + num_values=self.num_values, + physical_type=self.physical_type + ) + return d + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, Statistics other): + return self.statistics.get().Equals(deref(other.statistics.get())) + + @property + def has_min_max(self): + return self.statistics.get().HasMinMax() + + @property + def has_null_count(self): + return self.statistics.get().HasNullCount() + + @property + def has_distinct_count(self): + return self.statistics.get().HasDistinctCount() + + @property + def min_raw(self): + if self.has_min_max: + return _cast_statistic_raw_min(self.statistics.get()) + else: + return None + + @property + def max_raw(self): + if self.has_min_max: + return _cast_statistic_raw_max(self.statistics.get()) + else: + return None + + @property + def min(self): + if self.has_min_max: + return _cast_statistic_min(self.statistics.get()) + else: + return None + + @property + def max(self): + if self.has_min_max: + return _cast_statistic_max(self.statistics.get()) + else: + return None + + @property + def null_count(self): + return self.statistics.get().null_count() + + @property + def distinct_count(self): + return self.statistics.get().distinct_count() + + @property + def num_values(self): + return self.statistics.get().num_values() + + @property + def physical_type(self): + raw_physical_type = self.statistics.get().physical_type() + return physical_type_name_from_enum(raw_physical_type) + + @property + def logical_type(self): + return wrap_logical_type(self.statistics.get().descr().logical_type()) + + @property + def converted_type(self): + raw_converted_type = self.statistics.get().descr().converted_type() + return converted_type_name_from_enum(raw_converted_type) + + +cdef class ParquetLogicalType(_Weakrefable): + cdef: + shared_ptr[const CParquetLogicalType] type + + def __cinit__(self): + pass + + cdef init(self, const shared_ptr[const CParquetLogicalType]& type): + self.type = type + + def __str__(self): + return frombytes(self.type.get().ToString(), safe=True) + + def to_json(self): + return frombytes(self.type.get().ToJSON()) + + @property + def type(self): + return logical_type_name_from_enum(self.type.get().type()) + + +cdef wrap_logical_type(const shared_ptr[const CParquetLogicalType]& type): + cdef ParquetLogicalType out = ParquetLogicalType() + out.init(type) + return out + + +cdef _cast_statistic_raw_min(CStatistics* statistics): + cdef ParquetType physical_type = statistics.physical_type() + cdef uint32_t type_length = statistics.descr().type_length() + if physical_type == ParquetType_BOOLEAN: + return (<CBoolStatistics*> statistics).min() + elif physical_type == ParquetType_INT32: + return (<CInt32Statistics*> statistics).min() + elif physical_type == ParquetType_INT64: + return (<CInt64Statistics*> statistics).min() + elif physical_type == ParquetType_FLOAT: + return (<CFloatStatistics*> statistics).min() + elif physical_type == ParquetType_DOUBLE: + return (<CDoubleStatistics*> statistics).min() + elif physical_type == ParquetType_BYTE_ARRAY: + return _box_byte_array((<CByteArrayStatistics*> statistics).min()) + elif physical_type == ParquetType_FIXED_LEN_BYTE_ARRAY: + return _box_flba((<CFLBAStatistics*> statistics).min(), type_length) + + +cdef _cast_statistic_raw_max(CStatistics* statistics): + cdef ParquetType physical_type = statistics.physical_type() + cdef uint32_t type_length = statistics.descr().type_length() + if physical_type == ParquetType_BOOLEAN: + return (<CBoolStatistics*> statistics).max() + elif physical_type == ParquetType_INT32: + return (<CInt32Statistics*> statistics).max() + elif physical_type == ParquetType_INT64: + return (<CInt64Statistics*> statistics).max() + elif physical_type == ParquetType_FLOAT: + return (<CFloatStatistics*> statistics).max() + elif physical_type == ParquetType_DOUBLE: + return (<CDoubleStatistics*> statistics).max() + elif physical_type == ParquetType_BYTE_ARRAY: + return _box_byte_array((<CByteArrayStatistics*> statistics).max()) + elif physical_type == ParquetType_FIXED_LEN_BYTE_ARRAY: + return _box_flba((<CFLBAStatistics*> statistics).max(), type_length) + + +cdef _cast_statistic_min(CStatistics* statistics): + min_raw = _cast_statistic_raw_min(statistics) + return _box_logical_type_value(min_raw, statistics.descr()) + + +cdef _cast_statistic_max(CStatistics* statistics): + max_raw = _cast_statistic_raw_max(statistics) + return _box_logical_type_value(max_raw, statistics.descr()) + + +cdef _box_logical_type_value(object value, const ColumnDescriptor* descr): + cdef: + const CParquetLogicalType* ltype = descr.logical_type().get() + ParquetTimeUnit time_unit + const CParquetIntType* itype + const CParquetTimestampType* ts_type + + if ltype.type() == ParquetLogicalType_STRING: + return value.decode('utf8') + elif ltype.type() == ParquetLogicalType_TIME: + time_unit = (<const CParquetTimeType*> ltype).time_unit() + if time_unit == ParquetTimeUnit_MILLIS: + return _datetime_from_int(value, unit=TimeUnit_MILLI).time() + else: + return _datetime_from_int(value, unit=TimeUnit_MICRO).time() + elif ltype.type() == ParquetLogicalType_TIMESTAMP: + ts_type = <const CParquetTimestampType*> ltype + time_unit = ts_type.time_unit() + if ts_type.is_adjusted_to_utc(): + import pytz + tzinfo = pytz.utc + else: + tzinfo = None + if time_unit == ParquetTimeUnit_MILLIS: + return _datetime_from_int(value, unit=TimeUnit_MILLI, + tzinfo=tzinfo) + elif time_unit == ParquetTimeUnit_MICROS: + return _datetime_from_int(value, unit=TimeUnit_MICRO, + tzinfo=tzinfo) + elif time_unit == ParquetTimeUnit_NANOS: + return _datetime_from_int(value, unit=TimeUnit_NANO, + tzinfo=tzinfo) + else: + raise ValueError("Unsupported time unit") + elif ltype.type() == ParquetLogicalType_INT: + itype = <const CParquetIntType*> ltype + if not itype.is_signed() and itype.bit_width() == 32: + return int(np.int32(value).view(np.uint32)) + elif not itype.is_signed() and itype.bit_width() == 64: + return int(np.int64(value).view(np.uint64)) + else: + return value + else: + # No logical boxing defined + return value + + +cdef _box_byte_array(ParquetByteArray val): + return cp.PyBytes_FromStringAndSize(<char*> val.ptr, <Py_ssize_t> val.len) + + +cdef _box_flba(ParquetFLBA val, uint32_t len): + return cp.PyBytes_FromStringAndSize(<char*> val.ptr, <Py_ssize_t> len) + + +cdef class ColumnChunkMetaData(_Weakrefable): + def __cinit__(self): + pass + + def __repr__(self): + statistics = indent(repr(self.statistics), 4 * ' ') + return """{0} + file_offset: {1} + file_path: {2} + physical_type: {3} + num_values: {4} + path_in_schema: {5} + is_stats_set: {6} + statistics: +{7} + compression: {8} + encodings: {9} + has_dictionary_page: {10} + dictionary_page_offset: {11} + data_page_offset: {12} + total_compressed_size: {13} + total_uncompressed_size: {14}""".format(object.__repr__(self), + self.file_offset, + self.file_path, + self.physical_type, + self.num_values, + self.path_in_schema, + self.is_stats_set, + statistics, + self.compression, + self.encodings, + self.has_dictionary_page, + self.dictionary_page_offset, + self.data_page_offset, + self.total_compressed_size, + self.total_uncompressed_size) + + def to_dict(self): + statistics = self.statistics.to_dict() if self.is_stats_set else None + d = dict( + file_offset=self.file_offset, + file_path=self.file_path, + physical_type=self.physical_type, + num_values=self.num_values, + path_in_schema=self.path_in_schema, + is_stats_set=self.is_stats_set, + statistics=statistics, + compression=self.compression, + encodings=self.encodings, + has_dictionary_page=self.has_dictionary_page, + dictionary_page_offset=self.dictionary_page_offset, + data_page_offset=self.data_page_offset, + total_compressed_size=self.total_compressed_size, + total_uncompressed_size=self.total_uncompressed_size + ) + return d + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, ColumnChunkMetaData other): + return self.metadata.Equals(deref(other.metadata)) + + @property + def file_offset(self): + return self.metadata.file_offset() + + @property + def file_path(self): + return frombytes(self.metadata.file_path()) + + @property + def physical_type(self): + return physical_type_name_from_enum(self.metadata.type()) + + @property + def num_values(self): + return self.metadata.num_values() + + @property + def path_in_schema(self): + path = self.metadata.path_in_schema().get().ToDotString() + return frombytes(path) + + @property + def is_stats_set(self): + return self.metadata.is_stats_set() + + @property + def statistics(self): + if not self.metadata.is_stats_set(): + return None + statistics = Statistics() + statistics.init(self.metadata.statistics(), self) + return statistics + + @property + def compression(self): + return compression_name_from_enum(self.metadata.compression()) + + @property + def encodings(self): + return tuple(map(encoding_name_from_enum, self.metadata.encodings())) + + @property + def has_dictionary_page(self): + return bool(self.metadata.has_dictionary_page()) + + @property + def dictionary_page_offset(self): + if self.has_dictionary_page: + return self.metadata.dictionary_page_offset() + else: + return None + + @property + def data_page_offset(self): + return self.metadata.data_page_offset() + + @property + def has_index_page(self): + raise NotImplementedError('not supported in parquet-cpp') + + @property + def index_page_offset(self): + raise NotImplementedError("parquet-cpp doesn't return valid values") + + @property + def total_compressed_size(self): + return self.metadata.total_compressed_size() + + @property + def total_uncompressed_size(self): + return self.metadata.total_uncompressed_size() + + +cdef class RowGroupMetaData(_Weakrefable): + def __cinit__(self, FileMetaData parent, int index): + if index < 0 or index >= parent.num_row_groups: + raise IndexError('{0} out of bounds'.format(index)) + self.up_metadata = parent._metadata.RowGroup(index) + self.metadata = self.up_metadata.get() + self.parent = parent + self.index = index + + def __reduce__(self): + return RowGroupMetaData, (self.parent, self.index) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, RowGroupMetaData other): + return self.metadata.Equals(deref(other.metadata)) + + def column(self, int i): + if i < 0 or i >= self.num_columns: + raise IndexError('{0} out of bounds'.format(i)) + chunk = ColumnChunkMetaData() + chunk.init(self, i) + return chunk + + def __repr__(self): + return """{0} + num_columns: {1} + num_rows: {2} + total_byte_size: {3}""".format(object.__repr__(self), + self.num_columns, + self.num_rows, + self.total_byte_size) + + def to_dict(self): + columns = [] + d = dict( + num_columns=self.num_columns, + num_rows=self.num_rows, + total_byte_size=self.total_byte_size, + columns=columns, + ) + for i in range(self.num_columns): + columns.append(self.column(i).to_dict()) + return d + + @property + def num_columns(self): + return self.metadata.num_columns() + + @property + def num_rows(self): + return self.metadata.num_rows() + + @property + def total_byte_size(self): + return self.metadata.total_byte_size() + + +def _reconstruct_filemetadata(Buffer serialized): + cdef: + FileMetaData metadata = FileMetaData.__new__(FileMetaData) + CBuffer *buffer = serialized.buffer.get() + uint32_t metadata_len = <uint32_t>buffer.size() + + metadata.init(CFileMetaData_Make(buffer.data(), &metadata_len)) + + return metadata + + +cdef class FileMetaData(_Weakrefable): + def __cinit__(self): + pass + + def __reduce__(self): + cdef: + NativeFile sink = BufferOutputStream() + COutputStream* c_sink = sink.get_output_stream().get() + with nogil: + self._metadata.WriteTo(c_sink) + + cdef Buffer buffer = sink.getvalue() + return _reconstruct_filemetadata, (buffer,) + + def __repr__(self): + return """{0} + created_by: {1} + num_columns: {2} + num_rows: {3} + num_row_groups: {4} + format_version: {5} + serialized_size: {6}""".format(object.__repr__(self), + self.created_by, self.num_columns, + self.num_rows, self.num_row_groups, + self.format_version, + self.serialized_size) + + def to_dict(self): + row_groups = [] + d = dict( + created_by=self.created_by, + num_columns=self.num_columns, + num_rows=self.num_rows, + num_row_groups=self.num_row_groups, + row_groups=row_groups, + format_version=self.format_version, + serialized_size=self.serialized_size + ) + for i in range(self.num_row_groups): + row_groups.append(self.row_group(i).to_dict()) + return d + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, FileMetaData other): + return self._metadata.Equals(deref(other._metadata)) + + @property + def schema(self): + if self._schema is None: + self._schema = ParquetSchema(self) + return self._schema + + @property + def serialized_size(self): + return self._metadata.size() + + @property + def num_columns(self): + return self._metadata.num_columns() + + @property + def num_rows(self): + return self._metadata.num_rows() + + @property + def num_row_groups(self): + return self._metadata.num_row_groups() + + @property + def format_version(self): + cdef ParquetVersion version = self._metadata.version() + if version == ParquetVersion_V1: + return '1.0' + elif version == ParquetVersion_V2_0: + return 'pseudo-2.0' + elif version == ParquetVersion_V2_4: + return '2.4' + elif version == ParquetVersion_V2_6: + return '2.6' + else: + warnings.warn('Unrecognized file version, assuming 1.0: {}' + .format(version)) + return '1.0' + + @property + def created_by(self): + return frombytes(self._metadata.created_by()) + + @property + def metadata(self): + cdef: + unordered_map[c_string, c_string] metadata + const CKeyValueMetadata* underlying_metadata + underlying_metadata = self._metadata.key_value_metadata().get() + if underlying_metadata != NULL: + underlying_metadata.ToUnorderedMap(&metadata) + return metadata + else: + return None + + def row_group(self, int i): + return RowGroupMetaData(self, i) + + def set_file_path(self, path): + """ + Modify the file_path field of each ColumnChunk in the + FileMetaData to be a particular value + """ + cdef: + c_string c_path = tobytes(path) + self._metadata.set_file_path(c_path) + + def append_row_groups(self, FileMetaData other): + """ + Append row groups of other FileMetaData object + """ + cdef shared_ptr[CFileMetaData] c_metadata + + c_metadata = other.sp_metadata + self._metadata.AppendRowGroups(deref(c_metadata)) + + def write_metadata_file(self, where): + """ + Write the metadata object to a metadata-only file + """ + cdef: + shared_ptr[COutputStream] sink + c_string c_where + + try: + where = _stringify_path(where) + except TypeError: + get_writer(where, &sink) + else: + c_where = tobytes(where) + with nogil: + sink = GetResultValue(FileOutputStream.Open(c_where)) + + with nogil: + check_status( + WriteMetaDataFile(deref(self._metadata), sink.get())) + + +cdef class ParquetSchema(_Weakrefable): + def __cinit__(self, FileMetaData container): + self.parent = container + self.schema = container._metadata.schema() + + def __repr__(self): + return "{0}\n{1}".format( + object.__repr__(self), + frombytes(self.schema.ToString(), safe=True)) + + def __reduce__(self): + return ParquetSchema, (self.parent,) + + def __len__(self): + return self.schema.num_columns() + + def __getitem__(self, i): + return self.column(i) + + @property + def names(self): + return [self[i].name for i in range(len(self))] + + def to_arrow_schema(self): + """ + Convert Parquet schema to effective Arrow schema + + Returns + ------- + schema : pyarrow.Schema + """ + cdef shared_ptr[CSchema] sp_arrow_schema + + with nogil: + check_status(FromParquetSchema( + self.schema, default_arrow_reader_properties(), + self.parent._metadata.key_value_metadata(), + &sp_arrow_schema)) + + return pyarrow_wrap_schema(sp_arrow_schema) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, ParquetSchema other): + """ + Returns True if the Parquet schemas are equal + """ + return self.schema.Equals(deref(other.schema)) + + def column(self, i): + if i < 0 or i >= len(self): + raise IndexError('{0} out of bounds'.format(i)) + + return ColumnSchema(self, i) + + +cdef class ColumnSchema(_Weakrefable): + cdef: + int index + ParquetSchema parent + const ColumnDescriptor* descr + + def __cinit__(self, ParquetSchema schema, int index): + self.parent = schema + self.index = index # for pickling support + self.descr = schema.schema.Column(index) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def __reduce__(self): + return ColumnSchema, (self.parent, self.index) + + def equals(self, ColumnSchema other): + """ + Returns True if the column schemas are equal + """ + return self.descr.Equals(deref(other.descr)) + + def __repr__(self): + physical_type = self.physical_type + converted_type = self.converted_type + if converted_type == 'DECIMAL': + converted_type = 'DECIMAL({0}, {1})'.format(self.precision, + self.scale) + elif physical_type == 'FIXED_LEN_BYTE_ARRAY': + converted_type = ('FIXED_LEN_BYTE_ARRAY(length={0})' + .format(self.length)) + + return """<ParquetColumnSchema> + name: {0} + path: {1} + max_definition_level: {2} + max_repetition_level: {3} + physical_type: {4} + logical_type: {5} + converted_type (legacy): {6}""".format(self.name, self.path, + self.max_definition_level, + self.max_repetition_level, + physical_type, + str(self.logical_type), + converted_type) + + @property + def name(self): + return frombytes(self.descr.name()) + + @property + def path(self): + return frombytes(self.descr.path().get().ToDotString()) + + @property + def max_definition_level(self): + return self.descr.max_definition_level() + + @property + def max_repetition_level(self): + return self.descr.max_repetition_level() + + @property + def physical_type(self): + return physical_type_name_from_enum(self.descr.physical_type()) + + @property + def logical_type(self): + return wrap_logical_type(self.descr.logical_type()) + + @property + def converted_type(self): + return converted_type_name_from_enum(self.descr.converted_type()) + + @property + def logical_type(self): + return wrap_logical_type(self.descr.logical_type()) + + # FIXED_LEN_BYTE_ARRAY attribute + @property + def length(self): + return self.descr.type_length() + + # Decimal attributes + @property + def precision(self): + return self.descr.type_precision() + + @property + def scale(self): + return self.descr.type_scale() + + +cdef physical_type_name_from_enum(ParquetType type_): + return { + ParquetType_BOOLEAN: 'BOOLEAN', + ParquetType_INT32: 'INT32', + ParquetType_INT64: 'INT64', + ParquetType_INT96: 'INT96', + ParquetType_FLOAT: 'FLOAT', + ParquetType_DOUBLE: 'DOUBLE', + ParquetType_BYTE_ARRAY: 'BYTE_ARRAY', + ParquetType_FIXED_LEN_BYTE_ARRAY: 'FIXED_LEN_BYTE_ARRAY', + }.get(type_, 'UNKNOWN') + + +cdef logical_type_name_from_enum(ParquetLogicalTypeId type_): + return { + ParquetLogicalType_UNDEFINED: 'UNDEFINED', + ParquetLogicalType_STRING: 'STRING', + ParquetLogicalType_MAP: 'MAP', + ParquetLogicalType_LIST: 'LIST', + ParquetLogicalType_ENUM: 'ENUM', + ParquetLogicalType_DECIMAL: 'DECIMAL', + ParquetLogicalType_DATE: 'DATE', + ParquetLogicalType_TIME: 'TIME', + ParquetLogicalType_TIMESTAMP: 'TIMESTAMP', + ParquetLogicalType_INT: 'INT', + ParquetLogicalType_JSON: 'JSON', + ParquetLogicalType_BSON: 'BSON', + ParquetLogicalType_UUID: 'UUID', + ParquetLogicalType_NONE: 'NONE', + }.get(type_, 'UNKNOWN') + + +cdef converted_type_name_from_enum(ParquetConvertedType type_): + return { + ParquetConvertedType_NONE: 'NONE', + ParquetConvertedType_UTF8: 'UTF8', + ParquetConvertedType_MAP: 'MAP', + ParquetConvertedType_MAP_KEY_VALUE: 'MAP_KEY_VALUE', + ParquetConvertedType_LIST: 'LIST', + ParquetConvertedType_ENUM: 'ENUM', + ParquetConvertedType_DECIMAL: 'DECIMAL', + ParquetConvertedType_DATE: 'DATE', + ParquetConvertedType_TIME_MILLIS: 'TIME_MILLIS', + ParquetConvertedType_TIME_MICROS: 'TIME_MICROS', + ParquetConvertedType_TIMESTAMP_MILLIS: 'TIMESTAMP_MILLIS', + ParquetConvertedType_TIMESTAMP_MICROS: 'TIMESTAMP_MICROS', + ParquetConvertedType_UINT_8: 'UINT_8', + ParquetConvertedType_UINT_16: 'UINT_16', + ParquetConvertedType_UINT_32: 'UINT_32', + ParquetConvertedType_UINT_64: 'UINT_64', + ParquetConvertedType_INT_8: 'INT_8', + ParquetConvertedType_INT_16: 'INT_16', + ParquetConvertedType_INT_32: 'INT_32', + ParquetConvertedType_INT_64: 'INT_64', + ParquetConvertedType_JSON: 'JSON', + ParquetConvertedType_BSON: 'BSON', + ParquetConvertedType_INTERVAL: 'INTERVAL', + }.get(type_, 'UNKNOWN') + + +cdef encoding_name_from_enum(ParquetEncoding encoding_): + return { + ParquetEncoding_PLAIN: 'PLAIN', + ParquetEncoding_PLAIN_DICTIONARY: 'PLAIN_DICTIONARY', + ParquetEncoding_RLE: 'RLE', + ParquetEncoding_BIT_PACKED: 'BIT_PACKED', + ParquetEncoding_DELTA_BINARY_PACKED: 'DELTA_BINARY_PACKED', + ParquetEncoding_DELTA_LENGTH_BYTE_ARRAY: 'DELTA_LENGTH_BYTE_ARRAY', + ParquetEncoding_DELTA_BYTE_ARRAY: 'DELTA_BYTE_ARRAY', + ParquetEncoding_RLE_DICTIONARY: 'RLE_DICTIONARY', + ParquetEncoding_BYTE_STREAM_SPLIT: 'BYTE_STREAM_SPLIT', + }.get(encoding_, 'UNKNOWN') + + +cdef compression_name_from_enum(ParquetCompression compression_): + return { + ParquetCompression_UNCOMPRESSED: 'UNCOMPRESSED', + ParquetCompression_SNAPPY: 'SNAPPY', + ParquetCompression_GZIP: 'GZIP', + ParquetCompression_LZO: 'LZO', + ParquetCompression_BROTLI: 'BROTLI', + ParquetCompression_LZ4: 'LZ4', + ParquetCompression_ZSTD: 'ZSTD', + }.get(compression_, 'UNKNOWN') + + +cdef int check_compression_name(name) except -1: + if name.upper() not in {'NONE', 'SNAPPY', 'GZIP', 'LZO', 'BROTLI', 'LZ4', + 'ZSTD'}: + raise ArrowException("Unsupported compression: " + name) + return 0 + + +cdef ParquetCompression compression_from_name(name): + name = name.upper() + if name == 'SNAPPY': + return ParquetCompression_SNAPPY + elif name == 'GZIP': + return ParquetCompression_GZIP + elif name == 'LZO': + return ParquetCompression_LZO + elif name == 'BROTLI': + return ParquetCompression_BROTLI + elif name == 'LZ4': + return ParquetCompression_LZ4 + elif name == 'ZSTD': + return ParquetCompression_ZSTD + else: + return ParquetCompression_UNCOMPRESSED + + +cdef class ParquetReader(_Weakrefable): + cdef: + object source + CMemoryPool* pool + unique_ptr[FileReader] reader + FileMetaData _metadata + + cdef public: + _column_idx_map + + def __cinit__(self, MemoryPool memory_pool=None): + self.pool = maybe_unbox_memory_pool(memory_pool) + self._metadata = None + + def open(self, object source not None, bint use_memory_map=True, + read_dictionary=None, FileMetaData metadata=None, + int buffer_size=0, bint pre_buffer=False, + coerce_int96_timestamp_unit=None): + cdef: + shared_ptr[CRandomAccessFile] rd_handle + shared_ptr[CFileMetaData] c_metadata + CReaderProperties properties = default_reader_properties() + ArrowReaderProperties arrow_props = ( + default_arrow_reader_properties()) + c_string path + FileReaderBuilder builder + TimeUnit int96_timestamp_unit_code + + if metadata is not None: + c_metadata = metadata.sp_metadata + + if buffer_size > 0: + properties.enable_buffered_stream() + properties.set_buffer_size(buffer_size) + elif buffer_size == 0: + properties.disable_buffered_stream() + else: + raise ValueError('Buffer size must be larger than zero') + + arrow_props.set_pre_buffer(pre_buffer) + + if coerce_int96_timestamp_unit is None: + # use the default defined in default_arrow_reader_properties() + pass + else: + arrow_props.set_coerce_int96_timestamp_unit( + string_to_timeunit(coerce_int96_timestamp_unit)) + + self.source = source + + get_reader(source, use_memory_map, &rd_handle) + with nogil: + check_status(builder.Open(rd_handle, properties, c_metadata)) + + # Set up metadata + with nogil: + c_metadata = builder.raw_reader().metadata() + self._metadata = result = FileMetaData() + result.init(c_metadata) + + if read_dictionary is not None: + self._set_read_dictionary(read_dictionary, &arrow_props) + + with nogil: + check_status(builder.memory_pool(self.pool) + .properties(arrow_props) + .Build(&self.reader)) + + cdef _set_read_dictionary(self, read_dictionary, + ArrowReaderProperties* props): + for column in read_dictionary: + if not isinstance(column, int): + column = self.column_name_idx(column) + props.set_read_dictionary(column, True) + + @property + def column_paths(self): + cdef: + FileMetaData container = self.metadata + const CFileMetaData* metadata = container._metadata + vector[c_string] path + int i = 0 + + paths = [] + for i in range(0, metadata.num_columns()): + path = (metadata.schema().Column(i) + .path().get().ToDotVector()) + paths.append([frombytes(x) for x in path]) + + return paths + + @property + def metadata(self): + return self._metadata + + @property + def schema_arrow(self): + cdef shared_ptr[CSchema] out + with nogil: + check_status(self.reader.get().GetSchema(&out)) + return pyarrow_wrap_schema(out) + + @property + def num_row_groups(self): + return self.reader.get().num_row_groups() + + def set_use_threads(self, bint use_threads): + self.reader.get().set_use_threads(use_threads) + + def set_batch_size(self, int64_t batch_size): + self.reader.get().set_batch_size(batch_size) + + def iter_batches(self, int64_t batch_size, row_groups, column_indices=None, + bint use_threads=True): + cdef: + vector[int] c_row_groups + vector[int] c_column_indices + shared_ptr[CRecordBatch] record_batch + shared_ptr[TableBatchReader] batch_reader + unique_ptr[CRecordBatchReader] recordbatchreader + + self.set_batch_size(batch_size) + + if use_threads: + self.set_use_threads(use_threads) + + for row_group in row_groups: + c_row_groups.push_back(row_group) + + if column_indices is not None: + for index in column_indices: + c_column_indices.push_back(index) + with nogil: + check_status( + self.reader.get().GetRecordBatchReader( + c_row_groups, c_column_indices, &recordbatchreader + ) + ) + else: + with nogil: + check_status( + self.reader.get().GetRecordBatchReader( + c_row_groups, &recordbatchreader + ) + ) + + while True: + with nogil: + check_status( + recordbatchreader.get().ReadNext(&record_batch) + ) + + if record_batch.get() == NULL: + break + + yield pyarrow_wrap_batch(record_batch) + + def read_row_group(self, int i, column_indices=None, + bint use_threads=True): + return self.read_row_groups([i], column_indices, use_threads) + + def read_row_groups(self, row_groups not None, column_indices=None, + bint use_threads=True): + cdef: + shared_ptr[CTable] ctable + vector[int] c_row_groups + vector[int] c_column_indices + + self.set_use_threads(use_threads) + + for row_group in row_groups: + c_row_groups.push_back(row_group) + + if column_indices is not None: + for index in column_indices: + c_column_indices.push_back(index) + + with nogil: + check_status(self.reader.get() + .ReadRowGroups(c_row_groups, c_column_indices, + &ctable)) + else: + # Read all columns + with nogil: + check_status(self.reader.get() + .ReadRowGroups(c_row_groups, &ctable)) + return pyarrow_wrap_table(ctable) + + def read_all(self, column_indices=None, bint use_threads=True): + cdef: + shared_ptr[CTable] ctable + vector[int] c_column_indices + + self.set_use_threads(use_threads) + + if column_indices is not None: + for index in column_indices: + c_column_indices.push_back(index) + + with nogil: + check_status(self.reader.get() + .ReadTable(c_column_indices, &ctable)) + else: + # Read all columns + with nogil: + check_status(self.reader.get() + .ReadTable(&ctable)) + return pyarrow_wrap_table(ctable) + + def scan_contents(self, column_indices=None, batch_size=65536): + cdef: + vector[int] c_column_indices + int32_t c_batch_size + int64_t c_num_rows + + if column_indices is not None: + for index in column_indices: + c_column_indices.push_back(index) + + c_batch_size = batch_size + + with nogil: + check_status(self.reader.get() + .ScanContents(c_column_indices, c_batch_size, + &c_num_rows)) + + return c_num_rows + + def column_name_idx(self, column_name): + """ + Find the matching index of a column in the schema. + + Parameter + --------- + column_name: str + Name of the column, separation of nesting levels is done via ".". + + Returns + ------- + column_idx: int + Integer index of the position of the column + """ + cdef: + FileMetaData container = self.metadata + const CFileMetaData* metadata = container._metadata + int i = 0 + + if self._column_idx_map is None: + self._column_idx_map = {} + for i in range(0, metadata.num_columns()): + col_bytes = tobytes(metadata.schema().Column(i) + .path().get().ToDotString()) + self._column_idx_map[col_bytes] = i + + return self._column_idx_map[tobytes(column_name)] + + def read_column(self, int column_index): + cdef shared_ptr[CChunkedArray] out + with nogil: + check_status(self.reader.get() + .ReadColumn(column_index, &out)) + return pyarrow_wrap_chunked_array(out) + + def read_schema_field(self, int field_index): + cdef shared_ptr[CChunkedArray] out + with nogil: + check_status(self.reader.get() + .ReadSchemaField(field_index, &out)) + return pyarrow_wrap_chunked_array(out) + + +cdef shared_ptr[WriterProperties] _create_writer_properties( + use_dictionary=None, + compression=None, + version=None, + write_statistics=None, + data_page_size=None, + compression_level=None, + use_byte_stream_split=False, + data_page_version=None) except *: + """General writer properties""" + cdef: + shared_ptr[WriterProperties] properties + WriterProperties.Builder props + + # data_page_version + + if data_page_version is not None: + if data_page_version == "1.0": + props.data_page_version(ParquetDataPageVersion_V1) + elif data_page_version == "2.0": + props.data_page_version(ParquetDataPageVersion_V2) + else: + raise ValueError("Unsupported Parquet data page version: {0}" + .format(data_page_version)) + + # version + + if version is not None: + if version == "1.0": + props.version(ParquetVersion_V1) + elif version in ("2.0", "pseudo-2.0"): + warnings.warn( + "Parquet format '2.0' pseudo version is deprecated, use " + "'2.4' or '2.6' for fine-grained feature selection", + FutureWarning, stacklevel=2) + props.version(ParquetVersion_V2_0) + elif version == "2.4": + props.version(ParquetVersion_V2_4) + elif version == "2.6": + props.version(ParquetVersion_V2_6) + else: + raise ValueError("Unsupported Parquet format version: {0}" + .format(version)) + + # compression + + if isinstance(compression, basestring): + check_compression_name(compression) + props.compression(compression_from_name(compression)) + elif compression is not None: + for column, codec in compression.iteritems(): + check_compression_name(codec) + props.compression(tobytes(column), compression_from_name(codec)) + + if isinstance(compression_level, int): + props.compression_level(compression_level) + elif compression_level is not None: + for column, level in compression_level.iteritems(): + props.compression_level(tobytes(column), level) + + # use_dictionary + + if isinstance(use_dictionary, bool): + if use_dictionary: + props.enable_dictionary() + else: + props.disable_dictionary() + elif use_dictionary is not None: + # Deactivate dictionary encoding by default + props.disable_dictionary() + for column in use_dictionary: + props.enable_dictionary(tobytes(column)) + + # write_statistics + + if isinstance(write_statistics, bool): + if write_statistics: + props.enable_statistics() + else: + props.disable_statistics() + elif write_statistics is not None: + # Deactivate statistics by default and enable for specified columns + props.disable_statistics() + for column in write_statistics: + props.enable_statistics(tobytes(column)) + + # use_byte_stream_split + + if isinstance(use_byte_stream_split, bool): + if use_byte_stream_split: + props.encoding(ParquetEncoding_BYTE_STREAM_SPLIT) + elif use_byte_stream_split is not None: + for column in use_byte_stream_split: + props.encoding(tobytes(column), + ParquetEncoding_BYTE_STREAM_SPLIT) + + if data_page_size is not None: + props.data_pagesize(data_page_size) + + properties = props.build() + + return properties + + +cdef shared_ptr[ArrowWriterProperties] _create_arrow_writer_properties( + use_deprecated_int96_timestamps=False, + coerce_timestamps=None, + allow_truncated_timestamps=False, + writer_engine_version=None, + use_compliant_nested_type=False) except *: + """Arrow writer properties""" + cdef: + shared_ptr[ArrowWriterProperties] arrow_properties + ArrowWriterProperties.Builder arrow_props + + # Store the original Arrow schema so things like dictionary types can + # be automatically reconstructed + arrow_props.store_schema() + + # int96 support + + if use_deprecated_int96_timestamps: + arrow_props.enable_deprecated_int96_timestamps() + else: + arrow_props.disable_deprecated_int96_timestamps() + + # coerce_timestamps + + if coerce_timestamps == 'ms': + arrow_props.coerce_timestamps(TimeUnit_MILLI) + elif coerce_timestamps == 'us': + arrow_props.coerce_timestamps(TimeUnit_MICRO) + elif coerce_timestamps is not None: + raise ValueError('Invalid value for coerce_timestamps: {0}' + .format(coerce_timestamps)) + + # allow_truncated_timestamps + + if allow_truncated_timestamps: + arrow_props.allow_truncated_timestamps() + else: + arrow_props.disallow_truncated_timestamps() + + # use_compliant_nested_type + + if use_compliant_nested_type: + arrow_props.enable_compliant_nested_types() + else: + arrow_props.disable_compliant_nested_types() + + # writer_engine_version + + if writer_engine_version == "V1": + warnings.warn("V1 parquet writer engine is a no-op. Use V2.") + arrow_props.set_engine_version(ArrowWriterEngineVersion.V1) + elif writer_engine_version != "V2": + raise ValueError("Unsupported Writer Engine Version: {0}" + .format(writer_engine_version)) + + arrow_properties = arrow_props.build() + + return arrow_properties + + +cdef class ParquetWriter(_Weakrefable): + cdef: + unique_ptr[FileWriter] writer + shared_ptr[COutputStream] sink + bint own_sink + + cdef readonly: + object use_dictionary + object use_deprecated_int96_timestamps + object use_byte_stream_split + object coerce_timestamps + object allow_truncated_timestamps + object compression + object compression_level + object data_page_version + object use_compliant_nested_type + object version + object write_statistics + object writer_engine_version + int row_group_size + int64_t data_page_size + + def __cinit__(self, where, Schema schema, use_dictionary=None, + compression=None, version=None, + write_statistics=None, + MemoryPool memory_pool=None, + use_deprecated_int96_timestamps=False, + coerce_timestamps=None, + data_page_size=None, + allow_truncated_timestamps=False, + compression_level=None, + use_byte_stream_split=False, + writer_engine_version=None, + data_page_version=None, + use_compliant_nested_type=False): + cdef: + shared_ptr[WriterProperties] properties + shared_ptr[ArrowWriterProperties] arrow_properties + c_string c_where + CMemoryPool* pool + + try: + where = _stringify_path(where) + except TypeError: + get_writer(where, &self.sink) + self.own_sink = False + else: + c_where = tobytes(where) + with nogil: + self.sink = GetResultValue(FileOutputStream.Open(c_where)) + self.own_sink = True + + properties = _create_writer_properties( + use_dictionary=use_dictionary, + compression=compression, + version=version, + write_statistics=write_statistics, + data_page_size=data_page_size, + compression_level=compression_level, + use_byte_stream_split=use_byte_stream_split, + data_page_version=data_page_version + ) + arrow_properties = _create_arrow_writer_properties( + use_deprecated_int96_timestamps=use_deprecated_int96_timestamps, + coerce_timestamps=coerce_timestamps, + allow_truncated_timestamps=allow_truncated_timestamps, + writer_engine_version=writer_engine_version, + use_compliant_nested_type=use_compliant_nested_type + ) + + pool = maybe_unbox_memory_pool(memory_pool) + with nogil: + check_status( + FileWriter.Open(deref(schema.schema), pool, + self.sink, properties, arrow_properties, + &self.writer)) + + def close(self): + with nogil: + check_status(self.writer.get().Close()) + if self.own_sink: + check_status(self.sink.get().Close()) + + def write_table(self, Table table, row_group_size=None): + cdef: + CTable* ctable = table.table + int64_t c_row_group_size + + if row_group_size is None or row_group_size == -1: + c_row_group_size = ctable.num_rows() + elif row_group_size == 0: + raise ValueError('Row group size cannot be 0') + else: + c_row_group_size = row_group_size + + with nogil: + check_status(self.writer.get() + .WriteTable(deref(ctable), c_row_group_size)) + + @property + def metadata(self): + cdef: + shared_ptr[CFileMetaData] metadata + FileMetaData result + with nogil: + metadata = self.writer.get().metadata() + if metadata: + result = FileMetaData() + result.init(metadata) + return result + raise RuntimeError( + 'file metadata is only available after writer close') diff --git a/src/arrow/python/pyarrow/_plasma.pyx b/src/arrow/python/pyarrow/_plasma.pyx new file mode 100644 index 000000000..e38c81f80 --- /dev/null +++ b/src/arrow/python/pyarrow/_plasma.pyx @@ -0,0 +1,867 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ +# cython: language_level = 3 + +from libcpp cimport bool as c_bool, nullptr +from libcpp.memory cimport shared_ptr, unique_ptr, make_shared +from libcpp.string cimport string as c_string +from libcpp.vector cimport vector as c_vector +from libcpp.unordered_map cimport unordered_map +from libc.stdint cimport int64_t, uint8_t, uintptr_t +from cython.operator cimport dereference as deref, preincrement as inc +from cpython.pycapsule cimport * + +from collections.abc import Sequence +import random +import socket +import warnings + +import pyarrow +from pyarrow.lib cimport (Buffer, NativeFile, _Weakrefable, + check_status, pyarrow_wrap_buffer) +from pyarrow.lib import ArrowException, frombytes +from pyarrow.includes.libarrow cimport (CBuffer, CMutableBuffer, + CFixedSizeBufferWriter, CStatus) +from pyarrow.includes.libplasma cimport * + +PLASMA_WAIT_TIMEOUT = 2 ** 30 + + +cdef extern from "plasma/common.h" nogil: + cdef cppclass CCudaIpcPlaceholder" plasma::internal::CudaIpcPlaceholder": + pass + + cdef cppclass CUniqueID" plasma::UniqueID": + + @staticmethod + CUniqueID from_binary(const c_string& binary) + + @staticmethod + CUniqueID from_random() + + c_bool operator==(const CUniqueID& rhs) const + + c_string hex() const + + c_string binary() const + + @staticmethod + int64_t size() + + cdef enum CObjectState" plasma::ObjectState": + PLASMA_CREATED" plasma::ObjectState::PLASMA_CREATED" + PLASMA_SEALED" plasma::ObjectState::PLASMA_SEALED" + + cdef struct CObjectTableEntry" plasma::ObjectTableEntry": + int fd + int device_num + int64_t map_size + ptrdiff_t offset + uint8_t* pointer + int64_t data_size + int64_t metadata_size + int ref_count + int64_t create_time + int64_t construct_duration + CObjectState state + shared_ptr[CCudaIpcPlaceholder] ipc_handle + + ctypedef unordered_map[CUniqueID, unique_ptr[CObjectTableEntry]] \ + CObjectTable" plasma::ObjectTable" + + +cdef extern from "plasma/common.h": + cdef int64_t kDigestSize" plasma::kDigestSize" + +cdef extern from "plasma/client.h" nogil: + + cdef cppclass CPlasmaClient" plasma::PlasmaClient": + + CPlasmaClient() + + CStatus Connect(const c_string& store_socket_name, + const c_string& manager_socket_name, + int release_delay, int num_retries) + + CStatus Create(const CUniqueID& object_id, + int64_t data_size, const uint8_t* metadata, int64_t + metadata_size, const shared_ptr[CBuffer]* data) + + CStatus CreateAndSeal(const CUniqueID& object_id, + const c_string& data, const c_string& metadata) + + CStatus Get(const c_vector[CUniqueID] object_ids, int64_t timeout_ms, + c_vector[CObjectBuffer]* object_buffers) + + CStatus Seal(const CUniqueID& object_id) + + CStatus Evict(int64_t num_bytes, int64_t& num_bytes_evicted) + + CStatus Hash(const CUniqueID& object_id, uint8_t* digest) + + CStatus Release(const CUniqueID& object_id) + + CStatus Contains(const CUniqueID& object_id, c_bool* has_object) + + CStatus List(CObjectTable* objects) + + CStatus Subscribe(int* fd) + + CStatus DecodeNotifications(const uint8_t* buffer, + c_vector[CUniqueID]* object_ids, + c_vector[int64_t]* data_sizes, + c_vector[int64_t]* metadata_sizes) + + CStatus GetNotification(int fd, CUniqueID* object_id, + int64_t* data_size, int64_t* metadata_size) + + CStatus Disconnect() + + CStatus Delete(const c_vector[CUniqueID] object_ids) + + CStatus SetClientOptions(const c_string& client_name, + int64_t limit_output_memory) + + c_string DebugString() + + int64_t store_capacity() + +cdef extern from "plasma/client.h" nogil: + + cdef struct CObjectBuffer" plasma::ObjectBuffer": + shared_ptr[CBuffer] data + shared_ptr[CBuffer] metadata + + +def make_object_id(object_id): + return ObjectID(object_id) + + +cdef class ObjectID(_Weakrefable): + """ + An ObjectID represents a string of bytes used to identify Plasma objects. + """ + + cdef: + CUniqueID data + + def __cinit__(self, object_id): + if (not isinstance(object_id, bytes) or + len(object_id) != CUniqueID.size()): + raise ValueError("Object ID must by 20 bytes," + " is " + str(object_id)) + self.data = CUniqueID.from_binary(object_id) + + def __eq__(self, other): + try: + return self.data == (<ObjectID?>other).data + except TypeError: + return False + + def __hash__(self): + return hash(self.data.binary()) + + def __repr__(self): + return "ObjectID(" + self.data.hex().decode() + ")" + + def __reduce__(self): + return (make_object_id, (self.data.binary(),)) + + def binary(self): + """ + Return the binary representation of this ObjectID. + + Returns + ------- + bytes + Binary representation of the ObjectID. + """ + return self.data.binary() + + @staticmethod + def from_random(): + """ + Returns a randomly generated ObjectID. + + Returns + ------- + ObjectID + A randomly generated ObjectID. + """ + random_id = bytes(bytearray( + random.getrandbits(8) for _ in range(CUniqueID.size()))) + return ObjectID(random_id) + + +cdef class ObjectNotAvailable(_Weakrefable): + """ + Placeholder for an object that was not available within the given timeout. + """ + pass + + +cdef class PlasmaBuffer(Buffer): + """ + This is the type returned by calls to get with a PlasmaClient. + + We define our own class instead of directly returning a buffer object so + that we can add a custom destructor which notifies Plasma that the object + is no longer being used, so the memory in the Plasma store backing the + object can potentially be freed. + + Attributes + ---------- + object_id : ObjectID + The ID of the object in the buffer. + client : PlasmaClient + The PlasmaClient that we use to communicate with the store and manager. + """ + + cdef: + ObjectID object_id + PlasmaClient client + + @staticmethod + cdef PlasmaBuffer create(ObjectID object_id, PlasmaClient client, + const shared_ptr[CBuffer]& buffer): + cdef PlasmaBuffer self = PlasmaBuffer.__new__(PlasmaBuffer) + self.object_id = object_id + self.client = client + self.init(buffer) + return self + + def __init__(self): + raise TypeError("Do not call PlasmaBuffer's constructor directly, use " + "`PlasmaClient.create` instead.") + + def __dealloc__(self): + """ + Notify Plasma that the object is no longer needed. + + If the plasma client has been shut down, then don't do anything. + """ + self.client._release(self.object_id) + + +class PlasmaObjectNotFound(ArrowException): + pass + + +class PlasmaStoreFull(ArrowException): + pass + + +class PlasmaObjectExists(ArrowException): + pass + + +cdef int plasma_check_status(const CStatus& status) nogil except -1: + if status.ok(): + return 0 + + with gil: + message = frombytes(status.message()) + if IsPlasmaObjectExists(status): + raise PlasmaObjectExists(message) + elif IsPlasmaObjectNotFound(status): + raise PlasmaObjectNotFound(message) + elif IsPlasmaStoreFull(status): + raise PlasmaStoreFull(message) + + return check_status(status) + + +def get_socket_from_fd(fileno, family, type): + import socket + return socket.socket(fileno=fileno, family=family, type=type) + + +cdef class PlasmaClient(_Weakrefable): + """ + The PlasmaClient is used to interface with a plasma store and manager. + + The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a + buffer, and get a buffer. Buffers are referred to by object IDs, which are + strings. + """ + + cdef: + shared_ptr[CPlasmaClient] client + int notification_fd + c_string store_socket_name + + def __cinit__(self): + self.client.reset(new CPlasmaClient()) + self.notification_fd = -1 + self.store_socket_name = b"" + + cdef _get_object_buffers(self, object_ids, int64_t timeout_ms, + c_vector[CObjectBuffer]* result): + cdef: + c_vector[CUniqueID] ids + ObjectID object_id + + for object_id in object_ids: + ids.push_back(object_id.data) + with nogil: + plasma_check_status(self.client.get().Get(ids, timeout_ms, result)) + + # XXX C++ API should instead expose some kind of CreateAuto() + cdef _make_mutable_plasma_buffer(self, ObjectID object_id, uint8_t* data, + int64_t size): + cdef shared_ptr[CBuffer] buffer + buffer.reset(new CMutableBuffer(data, size)) + return PlasmaBuffer.create(object_id, self, buffer) + + @property + def store_socket_name(self): + return self.store_socket_name.decode() + + def create(self, ObjectID object_id, int64_t data_size, + c_string metadata=b""): + """ + Create a new buffer in the PlasmaStore for a particular object ID. + + The returned buffer is mutable until seal is called. + + Parameters + ---------- + object_id : ObjectID + The object ID used to identify an object. + size : int + The size in bytes of the created buffer. + metadata : bytes + An optional string of bytes encoding whatever metadata the user + wishes to encode. + + Raises + ------ + PlasmaObjectExists + This exception is raised if the object could not be created because + there already is an object with the same ID in the plasma store. + + PlasmaStoreFull: This exception is raised if the object could + not be created because the plasma store is unable to evict + enough objects to create room for it. + """ + cdef shared_ptr[CBuffer] data + with nogil: + plasma_check_status( + self.client.get().Create(object_id.data, data_size, + <uint8_t*>(metadata.data()), + metadata.size(), &data)) + return self._make_mutable_plasma_buffer(object_id, + data.get().mutable_data(), + data_size) + + def create_and_seal(self, ObjectID object_id, c_string data, + c_string metadata=b""): + """ + Store a new object in the PlasmaStore for a particular object ID. + + Parameters + ---------- + object_id : ObjectID + The object ID used to identify an object. + data : bytes + The object to store. + metadata : bytes + An optional string of bytes encoding whatever metadata the user + wishes to encode. + + Raises + ------ + PlasmaObjectExists + This exception is raised if the object could not be created because + there already is an object with the same ID in the plasma store. + + PlasmaStoreFull: This exception is raised if the object could + not be created because the plasma store is unable to evict + enough objects to create room for it. + """ + with nogil: + plasma_check_status( + self.client.get().CreateAndSeal(object_id.data, data, + metadata)) + + def get_buffers(self, object_ids, timeout_ms=-1, with_meta=False): + """ + Returns data buffer from the PlasmaStore based on object ID. + + If the object has not been sealed yet, this call will block. The + retrieved buffer is immutable. + + Parameters + ---------- + object_ids : list + A list of ObjectIDs used to identify some objects. + timeout_ms : int + The number of milliseconds that the get call should block before + timing out and returning. Pass -1 if the call should block and 0 + if the call should return immediately. + with_meta : bool + + Returns + ------- + list + If with_meta=False, this is a list of PlasmaBuffers for the data + associated with the object_ids and None if the object was not + available. If with_meta=True, this is a list of tuples of + PlasmaBuffer and metadata bytes. + """ + cdef c_vector[CObjectBuffer] object_buffers + self._get_object_buffers(object_ids, timeout_ms, &object_buffers) + result = [] + for i in range(object_buffers.size()): + if object_buffers[i].data.get() != nullptr: + data = pyarrow_wrap_buffer(object_buffers[i].data) + else: + data = None + if not with_meta: + result.append(data) + else: + if object_buffers[i].metadata.get() != nullptr: + size = object_buffers[i].metadata.get().size() + metadata = object_buffers[i].metadata.get().data()[:size] + else: + metadata = None + result.append((metadata, data)) + return result + + def get_metadata(self, object_ids, timeout_ms=-1): + """ + Returns metadata buffer from the PlasmaStore based on object ID. + + If the object has not been sealed yet, this call will block. The + retrieved buffer is immutable. + + Parameters + ---------- + object_ids : list + A list of ObjectIDs used to identify some objects. + timeout_ms : int + The number of milliseconds that the get call should block before + timing out and returning. Pass -1 if the call should block and 0 + if the call should return immediately. + + Returns + ------- + list + List of PlasmaBuffers for the metadata associated with the + object_ids and None if the object was not available. + """ + cdef c_vector[CObjectBuffer] object_buffers + self._get_object_buffers(object_ids, timeout_ms, &object_buffers) + result = [] + for i in range(object_buffers.size()): + if object_buffers[i].metadata.get() != nullptr: + result.append(pyarrow_wrap_buffer(object_buffers[i].metadata)) + else: + result.append(None) + return result + + def put_raw_buffer(self, object value, ObjectID object_id=None, + c_string metadata=b"", int memcopy_threads=6): + """ + Store Python buffer into the object store. + + Parameters + ---------- + value : Python object that implements the buffer protocol + A Python buffer object to store. + object_id : ObjectID, default None + If this is provided, the specified object ID will be used to refer + to the object. + metadata : bytes + An optional string of bytes encoding whatever metadata the user + wishes to encode. + memcopy_threads : int, default 6 + The number of threads to use to write the serialized object into + the object store for large objects. + + Returns + ------- + The object ID associated to the Python buffer object. + """ + cdef ObjectID target_id = (object_id if object_id + else ObjectID.from_random()) + cdef Buffer arrow_buffer = pyarrow.py_buffer(value) + write_buffer = self.create(target_id, len(value), metadata) + stream = pyarrow.FixedSizeBufferWriter(write_buffer) + stream.set_memcopy_threads(memcopy_threads) + stream.write(arrow_buffer) + self.seal(target_id) + return target_id + + def put(self, object value, ObjectID object_id=None, int memcopy_threads=6, + serialization_context=None): + """ + Store a Python value into the object store. + + Parameters + ---------- + value : object + A Python object to store. + object_id : ObjectID, default None + If this is provided, the specified object ID will be used to refer + to the object. + memcopy_threads : int, default 6 + The number of threads to use to write the serialized object into + the object store for large objects. + serialization_context : pyarrow.SerializationContext, default None + Custom serialization and deserialization context. + + Returns + ------- + The object ID associated to the Python object. + """ + cdef ObjectID target_id = (object_id if object_id + else ObjectID.from_random()) + if serialization_context is not None: + warnings.warn( + "'serialization_context' is deprecated and will be removed " + "in a future version.", + FutureWarning, stacklevel=2 + ) + serialized = pyarrow.lib._serialize(value, serialization_context) + buffer = self.create(target_id, serialized.total_bytes) + stream = pyarrow.FixedSizeBufferWriter(buffer) + stream.set_memcopy_threads(memcopy_threads) + serialized.write_to(stream) + self.seal(target_id) + return target_id + + def get(self, object_ids, int timeout_ms=-1, serialization_context=None): + """ + Get one or more Python values from the object store. + + Parameters + ---------- + object_ids : list or ObjectID + Object ID or list of object IDs associated to the values we get + from the store. + timeout_ms : int, default -1 + The number of milliseconds that the get call should block before + timing out and returning. Pass -1 if the call should block and 0 + if the call should return immediately. + serialization_context : pyarrow.SerializationContext, default None + Custom serialization and deserialization context. + + Returns + ------- + list or object + Python value or list of Python values for the data associated with + the object_ids and ObjectNotAvailable if the object was not + available. + """ + if serialization_context is not None: + warnings.warn( + "'serialization_context' is deprecated and will be removed " + "in a future version.", + FutureWarning, stacklevel=2 + ) + if isinstance(object_ids, Sequence): + results = [] + buffers = self.get_buffers(object_ids, timeout_ms) + for i in range(len(object_ids)): + # buffers[i] is None if this object was not available within + # the timeout + if buffers[i]: + val = pyarrow.lib._deserialize(buffers[i], + serialization_context) + results.append(val) + else: + results.append(ObjectNotAvailable) + return results + else: + return self.get([object_ids], timeout_ms, serialization_context)[0] + + def seal(self, ObjectID object_id): + """ + Seal the buffer in the PlasmaStore for a particular object ID. + + Once a buffer has been sealed, the buffer is immutable and can only be + accessed through get. + + Parameters + ---------- + object_id : ObjectID + A string used to identify an object. + """ + with nogil: + plasma_check_status(self.client.get().Seal(object_id.data)) + + def _release(self, ObjectID object_id): + """ + Notify Plasma that the object is no longer needed. + + Parameters + ---------- + object_id : ObjectID + A string used to identify an object. + """ + with nogil: + plasma_check_status(self.client.get().Release(object_id.data)) + + def contains(self, ObjectID object_id): + """ + Check if the object is present and sealed in the PlasmaStore. + + Parameters + ---------- + object_id : ObjectID + A string used to identify an object. + """ + cdef c_bool is_contained + with nogil: + plasma_check_status(self.client.get().Contains(object_id.data, + &is_contained)) + return is_contained + + def hash(self, ObjectID object_id): + """ + Compute the checksum of an object in the object store. + + Parameters + ---------- + object_id : ObjectID + A string used to identify an object. + + Returns + ------- + bytes + A digest string object's hash. If the object isn't in the object + store, the string will have length zero. + """ + cdef c_vector[uint8_t] digest = c_vector[uint8_t](kDigestSize) + with nogil: + plasma_check_status(self.client.get().Hash(object_id.data, + digest.data())) + return bytes(digest[:]) + + def evict(self, int64_t num_bytes): + """ + Evict some objects until to recover some bytes. + + Recover at least num_bytes bytes if possible. + + Parameters + ---------- + num_bytes : int + The number of bytes to attempt to recover. + """ + cdef int64_t num_bytes_evicted = -1 + with nogil: + plasma_check_status( + self.client.get().Evict(num_bytes, num_bytes_evicted)) + return num_bytes_evicted + + def subscribe(self): + """Subscribe to notifications about sealed objects.""" + with nogil: + plasma_check_status( + self.client.get().Subscribe(&self.notification_fd)) + + def get_notification_socket(self): + """ + Get the notification socket. + """ + return get_socket_from_fd(self.notification_fd, + family=socket.AF_UNIX, + type=socket.SOCK_STREAM) + + def decode_notifications(self, const uint8_t* buf): + """ + Get the notification from the buffer. + + Returns + ------- + [ObjectID] + The list of object IDs in the notification message. + c_vector[int64_t] + The data sizes of the objects in the notification message. + c_vector[int64_t] + The metadata sizes of the objects in the notification message. + """ + cdef c_vector[CUniqueID] ids + cdef c_vector[int64_t] data_sizes + cdef c_vector[int64_t] metadata_sizes + with nogil: + status = self.client.get().DecodeNotifications(buf, + &ids, + &data_sizes, + &metadata_sizes) + plasma_check_status(status) + object_ids = [] + for object_id in ids: + object_ids.append(ObjectID(object_id.binary())) + return object_ids, data_sizes, metadata_sizes + + def get_next_notification(self): + """ + Get the next notification from the notification socket. + + Returns + ------- + ObjectID + The object ID of the object that was stored. + int + The data size of the object that was stored. + int + The metadata size of the object that was stored. + """ + cdef ObjectID object_id = ObjectID(CUniqueID.size() * b"\0") + cdef int64_t data_size + cdef int64_t metadata_size + with nogil: + status = self.client.get().GetNotification(self.notification_fd, + &object_id.data, + &data_size, + &metadata_size) + plasma_check_status(status) + return object_id, data_size, metadata_size + + def to_capsule(self): + return PyCapsule_New(<void *>self.client.get(), "plasma", NULL) + + def disconnect(self): + """ + Disconnect this client from the Plasma store. + """ + with nogil: + plasma_check_status(self.client.get().Disconnect()) + + def delete(self, object_ids): + """ + Delete the objects with the given IDs from other object store. + + Parameters + ---------- + object_ids : list + A list of strings used to identify the objects. + """ + cdef c_vector[CUniqueID] ids + cdef ObjectID object_id + for object_id in object_ids: + ids.push_back(object_id.data) + with nogil: + plasma_check_status(self.client.get().Delete(ids)) + + def set_client_options(self, client_name, int64_t limit_output_memory): + cdef c_string name + name = client_name.encode() + with nogil: + plasma_check_status( + self.client.get().SetClientOptions(name, limit_output_memory)) + + def debug_string(self): + cdef c_string result + with nogil: + result = self.client.get().DebugString() + return result.decode() + + def list(self): + """ + Experimental: List the objects in the store. + + Returns + ------- + dict + Dictionary from ObjectIDs to an "info" dictionary describing the + object. The "info" dictionary has the following entries: + + data_size + size of the object in bytes + + metadata_size + size of the object metadata in bytes + + ref_count + Number of clients referencing the object buffer + + create_time + Unix timestamp of the creation of the object + + construct_duration + Time the creation of the object took in seconds + + state + "created" if the object is still being created and + "sealed" if it is already sealed + """ + cdef CObjectTable objects + with nogil: + plasma_check_status(self.client.get().List(&objects)) + result = dict() + cdef ObjectID object_id + cdef CObjectTableEntry entry + it = objects.begin() + while it != objects.end(): + object_id = ObjectID(deref(it).first.binary()) + entry = deref(deref(it).second) + if entry.state == CObjectState.PLASMA_CREATED: + state = "created" + else: + state = "sealed" + result[object_id] = { + "data_size": entry.data_size, + "metadata_size": entry.metadata_size, + "ref_count": entry.ref_count, + "create_time": entry.create_time, + "construct_duration": entry.construct_duration, + "state": state + } + inc(it) + return result + + def store_capacity(self): + """ + Get the memory capacity of the store. + + Returns + ------- + + int + The memory capacity of the store in bytes. + """ + return self.client.get().store_capacity() + + +def connect(store_socket_name, int num_retries=-1): + """ + Return a new PlasmaClient that is connected a plasma store and + optionally a manager. + + Parameters + ---------- + store_socket_name : str + Name of the socket the plasma store is listening at. + num_retries : int, default -1 + Number of times to try to connect to plasma store. Default value of -1 + uses the default (50) + """ + cdef PlasmaClient result = PlasmaClient() + cdef int deprecated_release_delay = 0 + result.store_socket_name = store_socket_name.encode() + with nogil: + plasma_check_status( + result.client.get().Connect(result.store_socket_name, b"", + deprecated_release_delay, num_retries)) + return result diff --git a/src/arrow/python/pyarrow/_s3fs.pyx b/src/arrow/python/pyarrow/_s3fs.pyx new file mode 100644 index 000000000..5829d74d3 --- /dev/null +++ b/src/arrow/python/pyarrow/_s3fs.pyx @@ -0,0 +1,284 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from pyarrow.lib cimport (check_status, pyarrow_wrap_metadata, + pyarrow_unwrap_metadata) +from pyarrow.lib import frombytes, tobytes, KeyValueMetadata +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_fs cimport * +from pyarrow._fs cimport FileSystem + + +cpdef enum S3LogLevel: + Off = <int8_t> CS3LogLevel_Off + Fatal = <int8_t> CS3LogLevel_Fatal + Error = <int8_t> CS3LogLevel_Error + Warn = <int8_t> CS3LogLevel_Warn + Info = <int8_t> CS3LogLevel_Info + Debug = <int8_t> CS3LogLevel_Debug + Trace = <int8_t> CS3LogLevel_Trace + + +def initialize_s3(S3LogLevel log_level=S3LogLevel.Fatal): + """ + Initialize S3 support + + Parameters + ---------- + log_level : S3LogLevel + level of logging + """ + cdef CS3GlobalOptions options + options.log_level = <CS3LogLevel> log_level + check_status(CInitializeS3(options)) + + +def finalize_s3(): + check_status(CFinalizeS3()) + + +cdef class S3FileSystem(FileSystem): + """ + S3-backed FileSystem implementation + + If neither access_key nor secret_key are provided, and role_arn is also not + provided, then attempts to initialize from AWS environment variables, + otherwise both access_key and secret_key must be provided. + + If role_arn is provided instead of access_key and secret_key, temporary + credentials will be fetched by issuing a request to STS to assume the + specified role. + + Note: S3 buckets are special and the operations available on them may be + limited or more expensive than desired. + + Parameters + ---------- + access_key : str, default None + AWS Access Key ID. Pass None to use the standard AWS environment + variables and/or configuration file. + secret_key : str, default None + AWS Secret Access key. Pass None to use the standard AWS environment + variables and/or configuration file. + session_token : str, default None + AWS Session Token. An optional session token, required if access_key + and secret_key are temporary credentials from STS. + anonymous : boolean, default False + Whether to connect anonymously if access_key and secret_key are None. + If true, will not attempt to look up credentials using standard AWS + configuration methods. + role_arn : str, default None + AWS Role ARN. If provided instead of access_key and secret_key, + temporary credentials will be fetched by assuming this role. + session_name : str, default None + An optional identifier for the assumed role session. + external_id : str, default None + An optional unique identifier that might be required when you assume + a role in another account. + load_frequency : int, default 900 + The frequency (in seconds) with which temporary credentials from an + assumed role session will be refreshed. + region : str, default 'us-east-1' + AWS region to connect to. + scheme : str, default 'https' + S3 connection transport scheme. + endpoint_override : str, default None + Override region with a connect string such as "localhost:9000" + background_writes : boolean, default True + Whether file writes will be issued in the background, without + blocking. + default_metadata : mapping or KeyValueMetadata, default None + Default metadata for open_output_stream. This will be ignored if + non-empty metadata is passed to open_output_stream. + proxy_options : dict or str, default None + If a proxy is used, provide the options here. Supported options are: + 'scheme' (str: 'http' or 'https'; required), 'host' (str; required), + 'port' (int; required), 'username' (str; optional), + 'password' (str; optional). + A proxy URI (str) can also be provided, in which case these options + will be derived from the provided URI. + The following are equivalent:: + + S3FileSystem(proxy_options='http://username:password@localhost:8020') + S3FileSystem(proxy_options={'scheme': 'http', 'host': 'localhost', + 'port': 8020, 'username': 'username', + 'password': 'password'}) + """ + + cdef: + CS3FileSystem* s3fs + + def __init__(self, *, access_key=None, secret_key=None, session_token=None, + bint anonymous=False, region=None, scheme=None, + endpoint_override=None, bint background_writes=True, + default_metadata=None, role_arn=None, session_name=None, + external_id=None, load_frequency=900, proxy_options=None): + cdef: + CS3Options options + shared_ptr[CS3FileSystem] wrapped + + if access_key is not None and secret_key is None: + raise ValueError( + 'In order to initialize with explicit credentials both ' + 'access_key and secret_key must be provided, ' + '`secret_key` is not set.' + ) + elif access_key is None and secret_key is not None: + raise ValueError( + 'In order to initialize with explicit credentials both ' + 'access_key and secret_key must be provided, ' + '`access_key` is not set.' + ) + + elif session_token is not None and (access_key is None or + secret_key is None): + raise ValueError( + 'In order to initialize a session with temporary credentials, ' + 'both secret_key and access_key must be provided in addition ' + 'to session_token.' + ) + + elif (access_key is not None or secret_key is not None): + if anonymous: + raise ValueError( + 'Cannot pass anonymous=True together with access_key ' + 'and secret_key.') + + if role_arn: + raise ValueError( + 'Cannot provide role_arn with access_key and secret_key') + + if session_token is None: + session_token = "" + + options = CS3Options.FromAccessKey( + tobytes(access_key), + tobytes(secret_key), + tobytes(session_token) + ) + elif anonymous: + if role_arn: + raise ValueError( + 'Cannot provide role_arn with anonymous=True') + + options = CS3Options.Anonymous() + elif role_arn: + + options = CS3Options.FromAssumeRole( + tobytes(role_arn), + tobytes(session_name), + tobytes(external_id), + load_frequency + ) + else: + options = CS3Options.Defaults() + + if region is not None: + options.region = tobytes(region) + if scheme is not None: + options.scheme = tobytes(scheme) + if endpoint_override is not None: + options.endpoint_override = tobytes(endpoint_override) + if background_writes is not None: + options.background_writes = background_writes + if default_metadata is not None: + if not isinstance(default_metadata, KeyValueMetadata): + default_metadata = KeyValueMetadata(default_metadata) + options.default_metadata = pyarrow_unwrap_metadata( + default_metadata) + + if proxy_options is not None: + if isinstance(proxy_options, dict): + options.proxy_options.scheme = tobytes(proxy_options["scheme"]) + options.proxy_options.host = tobytes(proxy_options["host"]) + options.proxy_options.port = proxy_options["port"] + proxy_username = proxy_options.get("username", None) + if proxy_username: + options.proxy_options.username = tobytes(proxy_username) + proxy_password = proxy_options.get("password", None) + if proxy_password: + options.proxy_options.password = tobytes(proxy_password) + elif isinstance(proxy_options, str): + options.proxy_options = GetResultValue( + CS3ProxyOptions.FromUriString(tobytes(proxy_options))) + else: + raise TypeError( + "'proxy_options': expected 'dict' or 'str', " + f"got {type(proxy_options)} instead.") + + with nogil: + wrapped = GetResultValue(CS3FileSystem.Make(options)) + + self.init(<shared_ptr[CFileSystem]> wrapped) + + cdef init(self, const shared_ptr[CFileSystem]& wrapped): + FileSystem.init(self, wrapped) + self.s3fs = <CS3FileSystem*> wrapped.get() + + @classmethod + def _reconstruct(cls, kwargs): + return cls(**kwargs) + + def __reduce__(self): + cdef CS3Options opts = self.s3fs.options() + + # if creds were explicitly provided, then use them + # else obtain them as they were last time. + if opts.credentials_kind == CS3CredentialsKind_Explicit: + access_key = frombytes(opts.GetAccessKey()) + secret_key = frombytes(opts.GetSecretKey()) + session_token = frombytes(opts.GetSessionToken()) + else: + access_key = None + secret_key = None + session_token = None + + return ( + S3FileSystem._reconstruct, (dict( + access_key=access_key, + secret_key=secret_key, + session_token=session_token, + anonymous=(opts.credentials_kind == + CS3CredentialsKind_Anonymous), + region=frombytes(opts.region), + scheme=frombytes(opts.scheme), + endpoint_override=frombytes(opts.endpoint_override), + role_arn=frombytes(opts.role_arn), + session_name=frombytes(opts.session_name), + external_id=frombytes(opts.external_id), + load_frequency=opts.load_frequency, + background_writes=opts.background_writes, + default_metadata=pyarrow_wrap_metadata(opts.default_metadata), + proxy_options={'scheme': frombytes(opts.proxy_options.scheme), + 'host': frombytes(opts.proxy_options.host), + 'port': opts.proxy_options.port, + 'username': frombytes( + opts.proxy_options.username), + 'password': frombytes( + opts.proxy_options.password)} + ),) + ) + + @property + def region(self): + """ + The AWS region this filesystem connects to. + """ + return frombytes(self.s3fs.region()) diff --git a/src/arrow/python/pyarrow/array.pxi b/src/arrow/python/pyarrow/array.pxi new file mode 100644 index 000000000..97cbce759 --- /dev/null +++ b/src/arrow/python/pyarrow/array.pxi @@ -0,0 +1,2541 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import warnings + + +cdef _sequence_to_array(object sequence, object mask, object size, + DataType type, CMemoryPool* pool, c_bool from_pandas): + cdef: + int64_t c_size + PyConversionOptions options + shared_ptr[CChunkedArray] chunked + + if type is not None: + options.type = type.sp_type + + if size is not None: + options.size = size + + options.from_pandas = from_pandas + options.ignore_timezone = os.environ.get('PYARROW_IGNORE_TIMEZONE', False) + + with nogil: + chunked = GetResultValue( + ConvertPySequence(sequence, mask, options, pool) + ) + + if chunked.get().num_chunks() == 1: + return pyarrow_wrap_array(chunked.get().chunk(0)) + else: + return pyarrow_wrap_chunked_array(chunked) + + +cdef inline _is_array_like(obj): + if isinstance(obj, np.ndarray): + return True + return pandas_api._have_pandas_internal() and pandas_api.is_array_like(obj) + + +def _ndarray_to_arrow_type(object values, DataType type): + return pyarrow_wrap_data_type(_ndarray_to_type(values, type)) + + +cdef shared_ptr[CDataType] _ndarray_to_type(object values, + DataType type) except *: + cdef shared_ptr[CDataType] c_type + + dtype = values.dtype + + if type is None and dtype != object: + with nogil: + check_status(NumPyDtypeToArrow(dtype, &c_type)) + + if type is not None: + c_type = type.sp_type + + return c_type + + +cdef _ndarray_to_array(object values, object mask, DataType type, + c_bool from_pandas, c_bool safe, CMemoryPool* pool): + cdef: + shared_ptr[CChunkedArray] chunked_out + shared_ptr[CDataType] c_type = _ndarray_to_type(values, type) + CCastOptions cast_options = CCastOptions(safe) + + with nogil: + check_status(NdarrayToArrow(pool, values, mask, from_pandas, + c_type, cast_options, &chunked_out)) + + if chunked_out.get().num_chunks() > 1: + return pyarrow_wrap_chunked_array(chunked_out) + else: + return pyarrow_wrap_array(chunked_out.get().chunk(0)) + + +cdef _codes_to_indices(object codes, object mask, DataType type, + MemoryPool memory_pool): + """ + Convert the codes of a pandas Categorical to indices for a pyarrow + DictionaryArray, taking into account missing values + mask + """ + if mask is None: + mask = codes == -1 + else: + mask = mask | (codes == -1) + return array(codes, mask=mask, type=type, memory_pool=memory_pool) + + +def _handle_arrow_array_protocol(obj, type, mask, size): + if mask is not None or size is not None: + raise ValueError( + "Cannot specify a mask or a size when passing an object that is " + "converted with the __arrow_array__ protocol.") + res = obj.__arrow_array__(type=type) + if not isinstance(res, (Array, ChunkedArray)): + raise TypeError("The object's __arrow_array__ method does not " + "return a pyarrow Array or ChunkedArray.") + return res + + +def array(object obj, type=None, mask=None, size=None, from_pandas=None, + bint safe=True, MemoryPool memory_pool=None): + """ + Create pyarrow.Array instance from a Python object. + + Parameters + ---------- + obj : sequence, iterable, ndarray or Series + If both type and size are specified may be a single use iterable. If + not strongly-typed, Arrow type will be inferred for resulting array. + type : pyarrow.DataType + Explicit type to attempt to coerce to, otherwise will be inferred from + the data. + mask : array[bool], optional + Indicate which values are null (True) or not null (False). + size : int64, optional + Size of the elements. If the input is larger than size bail at this + length. For iterators, if size is larger than the input iterator this + will be treated as a "max size", but will involve an initial allocation + of size followed by a resize to the actual size (so if you know the + exact size specifying it correctly will give you better performance). + from_pandas : bool, default None + Use pandas's semantics for inferring nulls from values in + ndarray-like data. If passed, the mask tasks precedence, but + if a value is unmasked (not-null), but still null according to + pandas semantics, then it is null. Defaults to False if not + passed explicitly by user, or True if a pandas object is + passed in. + safe : bool, default True + Check for overflows or other unsafe conversions. + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the currently-set default + memory pool. + + Returns + ------- + array : pyarrow.Array or pyarrow.ChunkedArray + A ChunkedArray instead of an Array is returned if: + + - the object data overflowed binary storage. + - the object's ``__arrow_array__`` protocol method returned a chunked + array. + + Notes + ----- + Localized timestamps will currently be returned as UTC (pandas's native + representation). Timezone-naive data will be implicitly interpreted as + UTC. + + Pandas's DateOffsets and dateutil.relativedelta.relativedelta are by + default converted as MonthDayNanoIntervalArray. relativedelta leapdays + are ignored as are all absolute fields on both objects. datetime.timedelta + can also be converted to MonthDayNanoIntervalArray but this requires + passing MonthDayNanoIntervalType explicitly. + + Converting to dictionary array will promote to a wider integer type for + indices if the number of distinct values cannot be represented, even if + the index type was explicitly set. This means that if there are more than + 127 values the returned dictionary array's index type will be at least + pa.int16() even if pa.int8() was passed to the function. Note that an + explicit index type will not be demoted even if it is wider than required. + + Examples + -------- + >>> import pandas as pd + >>> import pyarrow as pa + >>> pa.array(pd.Series([1, 2])) + <pyarrow.lib.Int64Array object at 0x7f674e4c0e10> + [ + 1, + 2 + ] + + >>> pa.array(["a", "b", "a"], type=pa.dictionary(pa.int8(), pa.string())) + <pyarrow.lib.DictionaryArray object at 0x7feb288d9040> + -- dictionary: + [ + "a", + "b" + ] + -- indices: + [ + 0, + 1, + 0 + ] + + >>> import numpy as np + >>> pa.array(pd.Series([1, 2]), mask=np.array([0, 1], dtype=bool)) + <pyarrow.lib.Int64Array object at 0x7f9019e11208> + [ + 1, + null + ] + + >>> arr = pa.array(range(1024), type=pa.dictionary(pa.int8(), pa.int64())) + >>> arr.type.index_type + DataType(int16) + """ + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + bint is_pandas_object = False + bint c_from_pandas + + type = ensure_type(type, allow_none=True) + + if from_pandas is None: + c_from_pandas = False + else: + c_from_pandas = from_pandas + + if hasattr(obj, '__arrow_array__'): + return _handle_arrow_array_protocol(obj, type, mask, size) + elif _is_array_like(obj): + if mask is not None: + if _is_array_like(mask): + mask = get_values(mask, &is_pandas_object) + else: + raise TypeError("Mask must be a numpy array " + "when converting numpy arrays") + + values = get_values(obj, &is_pandas_object) + if is_pandas_object and from_pandas is None: + c_from_pandas = True + + if isinstance(values, np.ma.MaskedArray): + if mask is not None: + raise ValueError("Cannot pass a numpy masked array and " + "specify a mask at the same time") + else: + # don't use shrunken masks + mask = None if values.mask is np.ma.nomask else values.mask + values = values.data + + if mask is not None: + if mask.dtype != np.bool_: + raise TypeError("Mask must be boolean dtype") + if mask.ndim != 1: + raise ValueError("Mask must be 1D array") + if len(values) != len(mask): + raise ValueError( + "Mask is a different length from sequence being converted") + + if hasattr(values, '__arrow_array__'): + return _handle_arrow_array_protocol(values, type, mask, size) + elif pandas_api.is_categorical(values): + if type is not None: + if type.id != Type_DICTIONARY: + return _ndarray_to_array( + np.asarray(values), mask, type, c_from_pandas, safe, + pool) + index_type = type.index_type + value_type = type.value_type + if values.ordered != type.ordered: + warnings.warn( + "The 'ordered' flag of the passed categorical values " + "does not match the 'ordered' of the specified type. " + "Using the flag of the values, but in the future this " + "mismatch will raise a ValueError.", + FutureWarning, stacklevel=2) + else: + index_type = None + value_type = None + + indices = _codes_to_indices( + values.codes, mask, index_type, memory_pool) + try: + dictionary = array( + values.categories.values, type=value_type, + memory_pool=memory_pool) + except TypeError: + # TODO when removing the deprecation warning, this whole + # try/except can be removed (to bubble the TypeError of + # the first array(..) call) + if value_type is not None: + warnings.warn( + "The dtype of the 'categories' of the passed " + "categorical values ({0}) does not match the " + "specified type ({1}). For now ignoring the specified " + "type, but in the future this mismatch will raise a " + "TypeError".format( + values.categories.dtype, value_type), + FutureWarning, stacklevel=2) + dictionary = array( + values.categories.values, memory_pool=memory_pool) + else: + raise + + return DictionaryArray.from_arrays( + indices, dictionary, ordered=values.ordered, safe=safe) + else: + if pandas_api.have_pandas: + values, type = pandas_api.compat.get_datetimetz_type( + values, obj.dtype, type) + return _ndarray_to_array(values, mask, type, c_from_pandas, safe, + pool) + else: + # ConvertPySequence does strict conversion if type is explicitly passed + return _sequence_to_array(obj, mask, size, type, pool, c_from_pandas) + + +def asarray(values, type=None): + """ + Convert to pyarrow.Array, inferring type if not provided. + + Parameters + ---------- + values : array-like + This can be a sequence, numpy.ndarray, pyarrow.Array or + pyarrow.ChunkedArray. If a ChunkedArray is passed, the output will be + a ChunkedArray, otherwise the output will be a Array. + type : string or DataType + Explicitly construct the array with this type. Attempt to cast if + indicated type is different. + + Returns + ------- + arr : Array or ChunkedArray + """ + if isinstance(values, (Array, ChunkedArray)): + if type is not None and not values.type.equals(type): + values = values.cast(type) + return values + else: + return array(values, type=type) + + +def nulls(size, type=None, MemoryPool memory_pool=None): + """ + Create a strongly-typed Array instance with all elements null. + + Parameters + ---------- + size : int + Array length. + type : pyarrow.DataType, default None + Explicit type for the array. By default use NullType. + memory_pool : MemoryPool, default None + Arrow MemoryPool to use for allocations. Uses the default memory + pool is not passed. + + Returns + ------- + arr : Array + + Examples + -------- + >>> import pyarrow as pa + >>> pa.nulls(10) + <pyarrow.lib.NullArray object at 0x7ffaf04c2e50> + 10 nulls + + >>> pa.nulls(3, pa.uint32()) + <pyarrow.lib.UInt32Array object at 0x7ffaf04c2e50> + [ + null, + null, + null + ] + """ + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + int64_t length = size + shared_ptr[CDataType] ty + shared_ptr[CArray] arr + + type = ensure_type(type, allow_none=True) + if type is None: + type = null() + + ty = pyarrow_unwrap_data_type(type) + with nogil: + arr = GetResultValue(MakeArrayOfNull(ty, length, pool)) + + return pyarrow_wrap_array(arr) + + +def repeat(value, size, MemoryPool memory_pool=None): + """ + Create an Array instance whose slots are the given scalar. + + Parameters + ---------- + value : Scalar-like object + Either a pyarrow.Scalar or any python object coercible to a Scalar. + size : int + Number of times to repeat the scalar in the output Array. + memory_pool : MemoryPool, default None + Arrow MemoryPool to use for allocations. Uses the default memory + pool is not passed. + + Returns + ------- + arr : Array + + Examples + -------- + >>> import pyarrow as pa + >>> pa.repeat(10, 3) + <pyarrow.lib.Int64Array object at 0x7ffac03a2750> + [ + 10, + 10, + 10 + ] + + >>> pa.repeat([1, 2], 2) + <pyarrow.lib.ListArray object at 0x7ffaf04c2e50> + [ + [ + 1, + 2 + ], + [ + 1, + 2 + ] + ] + + >>> pa.repeat("string", 3) + <pyarrow.lib.StringArray object at 0x7ffac03a2750> + [ + "string", + "string", + "string" + ] + + >>> pa.repeat(pa.scalar({'a': 1, 'b': [1, 2]}), 2) + <pyarrow.lib.StructArray object at 0x7ffac03a2750> + -- is_valid: all not null + -- child 0 type: int64 + [ + 1, + 1 + ] + -- child 1 type: list<item: int64> + [ + [ + 1, + 2 + ], + [ + 1, + 2 + ] + ] + """ + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + int64_t length = size + shared_ptr[CArray] c_array + shared_ptr[CScalar] c_scalar + + if not isinstance(value, Scalar): + value = scalar(value, memory_pool=memory_pool) + + c_scalar = (<Scalar> value).unwrap() + with nogil: + c_array = GetResultValue( + MakeArrayFromScalar(deref(c_scalar), length, pool) + ) + + return pyarrow_wrap_array(c_array) + + +def infer_type(values, mask=None, from_pandas=False): + """ + Attempt to infer Arrow data type that can hold the passed Python + sequence type in an Array object + + Parameters + ---------- + values : array-like + Sequence to infer type from. + mask : ndarray (bool type), optional + Optional exclusion mask where True marks null, False non-null. + from_pandas : bool, default False + Use pandas's NA/null sentinel values for type inference. + + Returns + ------- + type : DataType + """ + cdef: + shared_ptr[CDataType] out + c_bool use_pandas_sentinels = from_pandas + + if mask is not None and not isinstance(mask, np.ndarray): + mask = np.array(mask, dtype=bool) + + out = GetResultValue(InferArrowType(values, mask, use_pandas_sentinels)) + return pyarrow_wrap_data_type(out) + + +def _normalize_slice(object arrow_obj, slice key): + """ + Slices with step not equal to 1 (or None) will produce a copy + rather than a zero-copy view + """ + cdef: + Py_ssize_t start, stop, step + Py_ssize_t n = len(arrow_obj) + + start = key.start or 0 + if start < 0: + start += n + if start < 0: + start = 0 + elif start >= n: + start = n + + stop = key.stop if key.stop is not None else n + if stop < 0: + stop += n + if stop < 0: + stop = 0 + elif stop >= n: + stop = n + + step = key.step or 1 + if step != 1: + if step < 0: + # Negative steps require some special handling + if key.start is None: + start = n - 1 + + if key.stop is None: + stop = -1 + + indices = np.arange(start, stop, step) + return arrow_obj.take(indices) + else: + length = max(stop - start, 0) + return arrow_obj.slice(start, length) + + +cdef Py_ssize_t _normalize_index(Py_ssize_t index, + Py_ssize_t length) except -1: + if index < 0: + index += length + if index < 0: + raise IndexError("index out of bounds") + elif index >= length: + raise IndexError("index out of bounds") + return index + + +cdef wrap_datum(const CDatum& datum): + if datum.kind() == DatumType_ARRAY: + return pyarrow_wrap_array(MakeArray(datum.array())) + elif datum.kind() == DatumType_CHUNKED_ARRAY: + return pyarrow_wrap_chunked_array(datum.chunked_array()) + elif datum.kind() == DatumType_RECORD_BATCH: + return pyarrow_wrap_batch(datum.record_batch()) + elif datum.kind() == DatumType_TABLE: + return pyarrow_wrap_table(datum.table()) + elif datum.kind() == DatumType_SCALAR: + return pyarrow_wrap_scalar(datum.scalar()) + else: + raise ValueError("Unable to wrap Datum in a Python object") + + +cdef _append_array_buffers(const CArrayData* ad, list res): + """ + Recursively append Buffer wrappers from *ad* and its children. + """ + cdef size_t i, n + assert ad != NULL + n = ad.buffers.size() + for i in range(n): + buf = ad.buffers[i] + res.append(pyarrow_wrap_buffer(buf) + if buf.get() != NULL else None) + n = ad.child_data.size() + for i in range(n): + _append_array_buffers(ad.child_data[i].get(), res) + + +cdef _reduce_array_data(const CArrayData* ad): + """ + Recursively dissect ArrayData to (pickable) tuples. + """ + cdef size_t i, n + assert ad != NULL + + n = ad.buffers.size() + buffers = [] + for i in range(n): + buf = ad.buffers[i] + buffers.append(pyarrow_wrap_buffer(buf) + if buf.get() != NULL else None) + + children = [] + n = ad.child_data.size() + for i in range(n): + children.append(_reduce_array_data(ad.child_data[i].get())) + + if ad.dictionary.get() != NULL: + dictionary = _reduce_array_data(ad.dictionary.get()) + else: + dictionary = None + + return pyarrow_wrap_data_type(ad.type), ad.length, ad.null_count, \ + ad.offset, buffers, children, dictionary + + +cdef shared_ptr[CArrayData] _reconstruct_array_data(data): + """ + Reconstruct CArrayData objects from the tuple structure generated + by _reduce_array_data. + """ + cdef: + int64_t length, null_count, offset, i + DataType dtype + Buffer buf + vector[shared_ptr[CBuffer]] c_buffers + vector[shared_ptr[CArrayData]] c_children + shared_ptr[CArrayData] c_dictionary + + dtype, length, null_count, offset, buffers, children, dictionary = data + + for i in range(len(buffers)): + buf = buffers[i] + if buf is None: + c_buffers.push_back(shared_ptr[CBuffer]()) + else: + c_buffers.push_back(buf.buffer) + + for i in range(len(children)): + c_children.push_back(_reconstruct_array_data(children[i])) + + if dictionary is not None: + c_dictionary = _reconstruct_array_data(dictionary) + + return CArrayData.MakeWithChildrenAndDictionary( + dtype.sp_type, + length, + c_buffers, + c_children, + c_dictionary, + null_count, + offset) + + +def _restore_array(data): + """ + Reconstruct an Array from pickled ArrayData. + """ + cdef shared_ptr[CArrayData] ad = _reconstruct_array_data(data) + return pyarrow_wrap_array(MakeArray(ad)) + + +cdef class _PandasConvertible(_Weakrefable): + + def to_pandas( + self, + memory_pool=None, + categories=None, + bint strings_to_categorical=False, + bint zero_copy_only=False, + bint integer_object_nulls=False, + bint date_as_object=True, + bint timestamp_as_object=False, + bint use_threads=True, + bint deduplicate_objects=True, + bint ignore_metadata=False, + bint safe=True, + bint split_blocks=False, + bint self_destruct=False, + types_mapper=None + ): + """ + Convert to a pandas-compatible NumPy array or DataFrame, as appropriate + + Parameters + ---------- + memory_pool : MemoryPool, default None + Arrow MemoryPool to use for allocations. Uses the default memory + pool is not passed. + strings_to_categorical : bool, default False + Encode string (UTF8) and binary types to pandas.Categorical. + categories: list, default empty + List of fields that should be returned as pandas.Categorical. Only + applies to table-like data structures. + zero_copy_only : bool, default False + Raise an ArrowException if this function call would require copying + the underlying data. + integer_object_nulls : bool, default False + Cast integers with nulls to objects + date_as_object : bool, default True + Cast dates to objects. If False, convert to datetime64[ns] dtype. + timestamp_as_object : bool, default False + Cast non-nanosecond timestamps (np.datetime64) to objects. This is + useful if you have timestamps that don't fit in the normal date + range of nanosecond timestamps (1678 CE-2262 CE). + If False, all timestamps are converted to datetime64[ns] dtype. + use_threads: bool, default True + Whether to parallelize the conversion using multiple threads. + deduplicate_objects : bool, default False + Do not create multiple copies Python objects when created, to save + on memory use. Conversion will be slower. + ignore_metadata : bool, default False + If True, do not use the 'pandas' metadata to reconstruct the + DataFrame index, if present + safe : bool, default True + For certain data types, a cast is needed in order to store the + data in a pandas DataFrame or Series (e.g. timestamps are always + stored as nanoseconds in pandas). This option controls whether it + is a safe cast or not. + split_blocks : bool, default False + If True, generate one internal "block" for each column when + creating a pandas.DataFrame from a RecordBatch or Table. While this + can temporarily reduce memory note that various pandas operations + can trigger "consolidation" which may balloon memory use. + self_destruct : bool, default False + EXPERIMENTAL: If True, attempt to deallocate the originating Arrow + memory while converting the Arrow object to pandas. If you use the + object after calling to_pandas with this option it will crash your + program. + + Note that you may not see always memory usage improvements. For + example, if multiple columns share an underlying allocation, + memory can't be freed until all columns are converted. + types_mapper : function, default None + A function mapping a pyarrow DataType to a pandas ExtensionDtype. + This can be used to override the default pandas type for conversion + of built-in pyarrow types or in absence of pandas_metadata in the + Table schema. The function receives a pyarrow DataType and is + expected to return a pandas ExtensionDtype or ``None`` if the + default conversion should be used for that type. If you have + a dictionary mapping, you can pass ``dict.get`` as function. + + Returns + ------- + pandas.Series or pandas.DataFrame depending on type of object + """ + options = dict( + pool=memory_pool, + strings_to_categorical=strings_to_categorical, + zero_copy_only=zero_copy_only, + integer_object_nulls=integer_object_nulls, + date_as_object=date_as_object, + timestamp_as_object=timestamp_as_object, + use_threads=use_threads, + deduplicate_objects=deduplicate_objects, + safe=safe, + split_blocks=split_blocks, + self_destruct=self_destruct + ) + return self._to_pandas(options, categories=categories, + ignore_metadata=ignore_metadata, + types_mapper=types_mapper) + + +cdef PandasOptions _convert_pandas_options(dict options): + cdef PandasOptions result + result.pool = maybe_unbox_memory_pool(options['pool']) + result.strings_to_categorical = options['strings_to_categorical'] + result.zero_copy_only = options['zero_copy_only'] + result.integer_object_nulls = options['integer_object_nulls'] + result.date_as_object = options['date_as_object'] + result.timestamp_as_object = options['timestamp_as_object'] + result.use_threads = options['use_threads'] + result.deduplicate_objects = options['deduplicate_objects'] + result.safe_cast = options['safe'] + result.split_blocks = options['split_blocks'] + result.self_destruct = options['self_destruct'] + result.ignore_timezone = os.environ.get('PYARROW_IGNORE_TIMEZONE', False) + return result + + +cdef class Array(_PandasConvertible): + """ + The base class for all Arrow arrays. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use one of " + "the `pyarrow.Array.from_*` functions instead." + .format(self.__class__.__name__)) + + cdef void init(self, const shared_ptr[CArray]& sp_array) except *: + self.sp_array = sp_array + self.ap = sp_array.get() + self.type = pyarrow_wrap_data_type(self.sp_array.get().type()) + + def _debug_print(self): + with nogil: + check_status(DebugPrint(deref(self.ap), 0)) + + def diff(self, Array other): + """ + Compare contents of this array against another one. + + Return string containing the result of arrow::Diff comparing contents + of this array against the other array. + """ + cdef c_string result + with nogil: + result = self.ap.Diff(deref(other.ap)) + return frombytes(result, safe=True) + + def cast(self, object target_type, safe=True): + """ + Cast array values to another data type + + See pyarrow.compute.cast for usage + """ + return _pc().cast(self, target_type, safe=safe) + + def view(self, object target_type): + """ + Return zero-copy "view" of array as another data type. + + The data types must have compatible columnar buffer layouts + + Parameters + ---------- + target_type : DataType + Type to construct view as. + + Returns + ------- + view : Array + """ + cdef DataType type = ensure_type(target_type) + cdef shared_ptr[CArray] result + with nogil: + result = GetResultValue(self.ap.View(type.sp_type)) + return pyarrow_wrap_array(result) + + def sum(self, **kwargs): + """ + Sum the values in a numerical array. + """ + options = _pc().ScalarAggregateOptions(**kwargs) + return _pc().call_function('sum', [self], options) + + def unique(self): + """ + Compute distinct elements in array. + """ + return _pc().call_function('unique', [self]) + + def dictionary_encode(self, null_encoding='mask'): + """ + Compute dictionary-encoded representation of array. + """ + options = _pc().DictionaryEncodeOptions(null_encoding) + return _pc().call_function('dictionary_encode', [self], options) + + def value_counts(self): + """ + Compute counts of unique elements in array. + + Returns + ------- + An array of <input type "Values", int64_t "Counts"> structs + """ + return _pc().call_function('value_counts', [self]) + + @staticmethod + def from_pandas(obj, mask=None, type=None, bint safe=True, + MemoryPool memory_pool=None): + """ + Convert pandas.Series to an Arrow Array. + + This method uses Pandas semantics about what values indicate + nulls. See pyarrow.array for more general conversion from arrays or + sequences to Arrow arrays. + + Parameters + ---------- + obj : ndarray, pandas.Series, array-like + mask : array (boolean), optional + Indicate which values are null (True) or not null (False). + type : pyarrow.DataType + Explicit type to attempt to coerce to, otherwise will be inferred + from the data. + safe : bool, default True + Check for overflows or other unsafe conversions. + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the currently-set default + memory pool. + + Notes + ----- + Localized timestamps will currently be returned as UTC (pandas's native + representation). Timezone-naive data will be implicitly interpreted as + UTC. + + Returns + ------- + array : pyarrow.Array or pyarrow.ChunkedArray + ChunkedArray is returned if object data overflows binary buffer. + """ + return array(obj, mask=mask, type=type, safe=safe, from_pandas=True, + memory_pool=memory_pool) + + def __reduce__(self): + return _restore_array, \ + (_reduce_array_data(self.sp_array.get().data().get()),) + + @staticmethod + def from_buffers(DataType type, length, buffers, null_count=-1, offset=0, + children=None): + """ + Construct an Array from a sequence of buffers. + + The concrete type returned depends on the datatype. + + Parameters + ---------- + type : DataType + The value type of the array. + length : int + The number of values in the array. + buffers : List[Buffer] + The buffers backing this array. + null_count : int, default -1 + The number of null entries in the array. Negative value means that + the null count is not known. + offset : int, default 0 + The array's logical offset (in values, not in bytes) from the + start of each buffer. + children : List[Array], default None + Nested type children with length matching type.num_fields. + + Returns + ------- + array : Array + """ + cdef: + Buffer buf + Array child + vector[shared_ptr[CBuffer]] c_buffers + vector[shared_ptr[CArrayData]] c_child_data + shared_ptr[CArrayData] array_data + + children = children or [] + + if type.num_fields != len(children): + raise ValueError("Type's expected number of children " + "({0}) did not match the passed number " + "({1}).".format(type.num_fields, len(children))) + + if type.num_buffers != len(buffers): + raise ValueError("Type's expected number of buffers " + "({0}) did not match the passed number " + "({1}).".format(type.num_buffers, len(buffers))) + + for buf in buffers: + # None will produce a null buffer pointer + c_buffers.push_back(pyarrow_unwrap_buffer(buf)) + + for child in children: + c_child_data.push_back(child.ap.data()) + + array_data = CArrayData.MakeWithChildren(type.sp_type, length, + c_buffers, c_child_data, + null_count, offset) + cdef Array result = pyarrow_wrap_array(MakeArray(array_data)) + result.validate() + return result + + @property + def null_count(self): + return self.sp_array.get().null_count() + + @property + def nbytes(self): + """ + Total number of bytes consumed by the elements of the array. + """ + size = 0 + for buf in self.buffers(): + if buf is not None: + size += buf.size + return size + + def __sizeof__(self): + return super(Array, self).__sizeof__() + self.nbytes + + def __iter__(self): + for i in range(len(self)): + yield self.getitem(i) + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def to_string(self, *, int indent=0, int window=10, + c_bool skip_new_lines=False): + """ + Render a "pretty-printed" string representation of the Array. + + Parameters + ---------- + indent : int + How much to indent right the content of the array, + by default ``0``. + window : int + How many items to preview at the begin and end + of the array when the arrays is bigger than the window. + The other elements will be ellipsed. + skip_new_lines : bool + If the array should be rendered as a single line of text + or if each element should be on its own line. + """ + cdef: + c_string result + PrettyPrintOptions options + + with nogil: + options = PrettyPrintOptions(indent, window) + options.skip_new_lines = skip_new_lines + check_status( + PrettyPrint( + deref(self.ap), + options, + &result + ) + ) + + return frombytes(result, safe=True) + + def format(self, **kwargs): + import warnings + warnings.warn('Array.format is deprecated, use Array.to_string') + return self.to_string(**kwargs) + + def __str__(self): + return self.to_string() + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + # This also handles comparing with None + # as Array.equals(None) raises a TypeError. + return NotImplemented + + def equals(Array self, Array other not None): + return self.ap.Equals(deref(other.ap)) + + def __len__(self): + return self.length() + + cdef int64_t length(self): + if self.sp_array.get(): + return self.sp_array.get().length() + else: + return 0 + + def is_null(self, *, nan_is_null=False): + """ + Return BooleanArray indicating the null values. + + Parameters + ---------- + nan_is_null : bool (optional, default False) + Whether floating-point NaN values should also be considered null. + + Returns + ------- + array : boolean Array + """ + options = _pc().NullOptions(nan_is_null=nan_is_null) + return _pc().call_function('is_null', [self], options) + + def is_valid(self): + """ + Return BooleanArray indicating the non-null values. + """ + return _pc().is_valid(self) + + def fill_null(self, fill_value): + """ + See pyarrow.compute.fill_null for usage. + """ + return _pc().fill_null(self, fill_value) + + def __getitem__(self, key): + """ + Slice or return value at given index + + Parameters + ---------- + key : integer or slice + Slices with step not equal to 1 (or None) will produce a copy + rather than a zero-copy view + + Returns + ------- + value : Scalar (index) or Array (slice) + """ + if PySlice_Check(key): + return _normalize_slice(self, key) + + return self.getitem(_normalize_index(key, self.length())) + + cdef getitem(self, int64_t i): + return Scalar.wrap(GetResultValue(self.ap.GetScalar(i))) + + def slice(self, offset=0, length=None): + """ + Compute zero-copy slice of this array. + + Parameters + ---------- + offset : int, default 0 + Offset from start of array to slice. + length : int, default None + Length of slice (default is until end of Array starting from + offset). + + Returns + ------- + sliced : RecordBatch + """ + cdef: + shared_ptr[CArray] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + offset = min(len(self), offset) + if length is None: + result = self.ap.Slice(offset) + else: + if length < 0: + raise ValueError('Length must be non-negative') + result = self.ap.Slice(offset, length) + + return pyarrow_wrap_array(result) + + def take(self, object indices): + """ + Select values from an array. See pyarrow.compute.take for full usage. + """ + return _pc().take(self, indices) + + def drop_null(self): + """ + Remove missing values from an array. + """ + return _pc().drop_null(self) + + def filter(self, Array mask, *, null_selection_behavior='drop'): + """ + Select values from an array. See pyarrow.compute.filter for full usage. + """ + return _pc().filter(self, mask, + null_selection_behavior=null_selection_behavior) + + def index(self, value, start=None, end=None, *, memory_pool=None): + """ + Find the first index of a value. + + See pyarrow.compute.index for full usage. + """ + return _pc().index(self, value, start, end, memory_pool=memory_pool) + + def _to_pandas(self, options, **kwargs): + return _array_like_to_pandas(self, options) + + def __array__(self, dtype=None): + values = self.to_numpy(zero_copy_only=False) + if dtype is None: + return values + return values.astype(dtype) + + def to_numpy(self, zero_copy_only=True, writable=False): + """ + Return a NumPy view or copy of this array (experimental). + + By default, tries to return a view of this array. This is only + supported for primitive arrays with the same memory layout as NumPy + (i.e. integers, floating point, ..) and without any nulls. + + Parameters + ---------- + zero_copy_only : bool, default True + If True, an exception will be raised if the conversion to a numpy + array would require copying the underlying data (e.g. in presence + of nulls, or for non-primitive types). + writable : bool, default False + For numpy arrays created with zero copy (view on the Arrow data), + the resulting array is not writable (Arrow data is immutable). + By setting this to True, a copy of the array is made to ensure + it is writable. + + Returns + ------- + array : numpy.ndarray + """ + cdef: + PyObject* out + PandasOptions c_options + object values + + if zero_copy_only and writable: + raise ValueError( + "Cannot return a writable array if asking for zero-copy") + + # If there are nulls and the array is a DictionaryArray + # decoding the dictionary will make sure nulls are correctly handled. + # Decoding a dictionary does imply a copy by the way, + # so it can't be done if the user requested a zero_copy. + c_options.decode_dictionaries = not zero_copy_only + c_options.zero_copy_only = zero_copy_only + + with nogil: + check_status(ConvertArrayToPandas(c_options, self.sp_array, + self, &out)) + + # wrap_array_output uses pandas to convert to Categorical, here + # always convert to numpy array without pandas dependency + array = PyObject_to_object(out) + + if isinstance(array, dict): + array = np.take(array['dictionary'], array['indices']) + + if writable and not array.flags.writeable: + # if the conversion already needed to a copy, writeable is True + array = array.copy() + return array + + def to_pylist(self): + """ + Convert to a list of native Python objects. + + Returns + ------- + lst : list + """ + return [x.as_py() for x in self] + + def tolist(self): + """ + Alias of to_pylist for compatibility with NumPy. + """ + return self.to_pylist() + + def validate(self, *, full=False): + """ + Perform validation checks. An exception is raised if validation fails. + + By default only cheap validation checks are run. Pass `full=True` + for thorough validation checks (potentially O(n)). + + Parameters + ---------- + full: bool, default False + If True, run expensive checks, otherwise cheap checks only. + + Raises + ------ + ArrowInvalid + """ + if full: + with nogil: + check_status(self.ap.ValidateFull()) + else: + with nogil: + check_status(self.ap.Validate()) + + @property + def offset(self): + """ + A relative position into another array's data. + + The purpose is to enable zero-copy slicing. This value defaults to zero + but must be applied on all operations with the physical storage + buffers. + """ + return self.sp_array.get().offset() + + def buffers(self): + """ + Return a list of Buffer objects pointing to this array's physical + storage. + + To correctly interpret these buffers, you need to also apply the offset + multiplied with the size of the stored data type. + """ + res = [] + _append_array_buffers(self.sp_array.get().data().get(), res) + return res + + def _export_to_c(self, uintptr_t out_ptr, uintptr_t out_schema_ptr=0): + """ + Export to a C ArrowArray struct, given its pointer. + + If a C ArrowSchema struct pointer is also given, the array type + is exported to it at the same time. + + Parameters + ---------- + out_ptr: int + The raw pointer to a C ArrowArray struct. + out_schema_ptr: int (optional) + The raw pointer to a C ArrowSchema struct. + + Be careful: if you don't pass the ArrowArray struct to a consumer, + array memory will leak. This is a low-level function intended for + expert users. + """ + with nogil: + check_status(ExportArray(deref(self.sp_array), + <ArrowArray*> out_ptr, + <ArrowSchema*> out_schema_ptr)) + + @staticmethod + def _import_from_c(uintptr_t in_ptr, type): + """ + Import Array from a C ArrowArray struct, given its pointer + and the imported array type. + + Parameters + ---------- + in_ptr: int + The raw pointer to a C ArrowArray struct. + type: DataType or int + Either a DataType object, or the raw pointer to a C ArrowSchema + struct. + + This is a low-level function intended for expert users. + """ + cdef: + shared_ptr[CArray] c_array + + c_type = pyarrow_unwrap_data_type(type) + if c_type == nullptr: + # Not a DataType object, perhaps a raw ArrowSchema pointer + type_ptr = <uintptr_t> type + with nogil: + c_array = GetResultValue(ImportArray(<ArrowArray*> in_ptr, + <ArrowSchema*> type_ptr)) + else: + with nogil: + c_array = GetResultValue(ImportArray(<ArrowArray*> in_ptr, + c_type)) + return pyarrow_wrap_array(c_array) + + +cdef _array_like_to_pandas(obj, options): + cdef: + PyObject* out + PandasOptions c_options = _convert_pandas_options(options) + + original_type = obj.type + name = obj._name + + # ARROW-3789(wesm): Convert date/timestamp types to datetime64[ns] + c_options.coerce_temporal_nanoseconds = True + + if isinstance(obj, Array): + with nogil: + check_status(ConvertArrayToPandas(c_options, + (<Array> obj).sp_array, + obj, &out)) + elif isinstance(obj, ChunkedArray): + with nogil: + check_status(libarrow.ConvertChunkedArrayToPandas( + c_options, + (<ChunkedArray> obj).sp_chunked_array, + obj, &out)) + + arr = wrap_array_output(out) + + if (isinstance(original_type, TimestampType) and + options["timestamp_as_object"]): + # ARROW-5359 - need to specify object dtype to avoid pandas to + # coerce back to ns resolution + dtype = "object" + else: + dtype = None + + result = pandas_api.series(arr, dtype=dtype, name=name) + + if (isinstance(original_type, TimestampType) and + original_type.tz is not None and + # can be object dtype for non-ns and timestamp_as_object=True + result.dtype.kind == "M"): + from pyarrow.pandas_compat import make_tz_aware + result = make_tz_aware(result, original_type.tz) + + return result + + +cdef wrap_array_output(PyObject* output): + cdef object obj = PyObject_to_object(output) + + if isinstance(obj, dict): + return pandas_api.categorical_type(obj['indices'], + categories=obj['dictionary'], + ordered=obj['ordered'], + fastpath=True) + else: + return obj + + +cdef class NullArray(Array): + """ + Concrete class for Arrow arrays of null data type. + """ + + +cdef class BooleanArray(Array): + """ + Concrete class for Arrow arrays of boolean data type. + """ + @property + def false_count(self): + return (<CBooleanArray*> self.ap).false_count() + + @property + def true_count(self): + return (<CBooleanArray*> self.ap).true_count() + + +cdef class NumericArray(Array): + """ + A base class for Arrow numeric arrays. + """ + + +cdef class IntegerArray(NumericArray): + """ + A base class for Arrow integer arrays. + """ + + +cdef class FloatingPointArray(NumericArray): + """ + A base class for Arrow floating-point arrays. + """ + + +cdef class Int8Array(IntegerArray): + """ + Concrete class for Arrow arrays of int8 data type. + """ + + +cdef class UInt8Array(IntegerArray): + """ + Concrete class for Arrow arrays of uint8 data type. + """ + + +cdef class Int16Array(IntegerArray): + """ + Concrete class for Arrow arrays of int16 data type. + """ + + +cdef class UInt16Array(IntegerArray): + """ + Concrete class for Arrow arrays of uint16 data type. + """ + + +cdef class Int32Array(IntegerArray): + """ + Concrete class for Arrow arrays of int32 data type. + """ + + +cdef class UInt32Array(IntegerArray): + """ + Concrete class for Arrow arrays of uint32 data type. + """ + + +cdef class Int64Array(IntegerArray): + """ + Concrete class for Arrow arrays of int64 data type. + """ + + +cdef class UInt64Array(IntegerArray): + """ + Concrete class for Arrow arrays of uint64 data type. + """ + + +cdef class Date32Array(NumericArray): + """ + Concrete class for Arrow arrays of date32 data type. + """ + + +cdef class Date64Array(NumericArray): + """ + Concrete class for Arrow arrays of date64 data type. + """ + + +cdef class TimestampArray(NumericArray): + """ + Concrete class for Arrow arrays of timestamp data type. + """ + + +cdef class Time32Array(NumericArray): + """ + Concrete class for Arrow arrays of time32 data type. + """ + + +cdef class Time64Array(NumericArray): + """ + Concrete class for Arrow arrays of time64 data type. + """ + + +cdef class DurationArray(NumericArray): + """ + Concrete class for Arrow arrays of duration data type. + """ + + +cdef class MonthDayNanoIntervalArray(Array): + """ + Concrete class for Arrow arrays of interval[MonthDayNano] type. + """ + + def to_pylist(self): + """ + Convert to a list of native Python objects. + + pyarrow.MonthDayNano is used as the native representation. + + Returns + ------- + lst : list + """ + cdef: + CResult[PyObject*] maybe_py_list + PyObject* py_list + CMonthDayNanoIntervalArray* array + array = <CMonthDayNanoIntervalArray*>self.sp_array.get() + maybe_py_list = MonthDayNanoIntervalArrayToPyList(deref(array)) + py_list = GetResultValue(maybe_py_list) + return PyObject_to_object(py_list) + + +cdef class HalfFloatArray(FloatingPointArray): + """ + Concrete class for Arrow arrays of float16 data type. + """ + + +cdef class FloatArray(FloatingPointArray): + """ + Concrete class for Arrow arrays of float32 data type. + """ + + +cdef class DoubleArray(FloatingPointArray): + """ + Concrete class for Arrow arrays of float64 data type. + """ + + +cdef class FixedSizeBinaryArray(Array): + """ + Concrete class for Arrow arrays of a fixed-size binary data type. + """ + + +cdef class Decimal128Array(FixedSizeBinaryArray): + """ + Concrete class for Arrow arrays of decimal128 data type. + """ + + +cdef class Decimal256Array(FixedSizeBinaryArray): + """ + Concrete class for Arrow arrays of decimal256 data type. + """ + +cdef class BaseListArray(Array): + + def flatten(self): + """ + Unnest this ListArray/LargeListArray by one level. + + The returned Array is logically a concatenation of all the sub-lists + in this Array. + + Note that this method is different from ``self.values()`` in that + it takes care of the slicing offset as well as null elements backed + by non-empty sub-lists. + + Returns + ------- + result : Array + """ + return _pc().list_flatten(self) + + def value_parent_indices(self): + """ + Return array of same length as list child values array where each + output value is the index of the parent list array slot containing each + child value. + + Examples + -------- + >>> arr = pa.array([[1, 2, 3], [], None, [4]], + ... type=pa.list_(pa.int32())) + >>> arr.value_parent_indices() + <pyarrow.lib.Int32Array object at 0x7efc5db958a0> + [ + 0, + 0, + 0, + 3 + ] + """ + return _pc().list_parent_indices(self) + + def value_lengths(self): + """ + Return integers array with values equal to the respective length of + each list element. Null list values are null in the output. + + Examples + -------- + >>> arr = pa.array([[1, 2, 3], [], None, [4]], + ... type=pa.list_(pa.int32())) + >>> arr.value_lengths() + <pyarrow.lib.Int32Array object at 0x7efc5db95910> + [ + 3, + 0, + null, + 1 + ] + """ + return _pc().list_value_length(self) + + +cdef class ListArray(BaseListArray): + """ + Concrete class for Arrow arrays of a list data type. + """ + + @staticmethod + def from_arrays(offsets, values, MemoryPool pool=None): + """ + Construct ListArray from arrays of int32 offsets and values. + + Parameters + ---------- + offsets : Array (int32 type) + values : Array (any type) + pool : MemoryPool + + Returns + ------- + list_array : ListArray + + Examples + -------- + >>> values = pa.array([1, 2, 3, 4]) + >>> offsets = pa.array([0, 2, 4]) + >>> pa.ListArray.from_arrays(offsets, values) + <pyarrow.lib.ListArray object at 0x7fbde226bf40> + [ + [ + 0, + 1 + ], + [ + 2, + 3 + ] + ] + # nulls in the offsets array become null lists + >>> offsets = pa.array([0, None, 2, 4]) + >>> pa.ListArray.from_arrays(offsets, values) + <pyarrow.lib.ListArray object at 0x7fbde226bf40> + [ + [ + 0, + 1 + ], + null, + [ + 2, + 3 + ] + ] + """ + cdef: + Array _offsets, _values + shared_ptr[CArray] out + cdef CMemoryPool* cpool = maybe_unbox_memory_pool(pool) + + _offsets = asarray(offsets, type='int32') + _values = asarray(values) + + with nogil: + out = GetResultValue( + CListArray.FromArrays(_offsets.ap[0], _values.ap[0], cpool)) + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @property + def values(self): + cdef CListArray* arr = <CListArray*> self.ap + return pyarrow_wrap_array(arr.values()) + + @property + def offsets(self): + """ + Return the offsets as an int32 array. + """ + return pyarrow_wrap_array((<CListArray*> self.ap).offsets()) + + +cdef class LargeListArray(BaseListArray): + """ + Concrete class for Arrow arrays of a large list data type. + + Identical to ListArray, but 64-bit offsets. + """ + + @staticmethod + def from_arrays(offsets, values, MemoryPool pool=None): + """ + Construct LargeListArray from arrays of int64 offsets and values. + + Parameters + ---------- + offsets : Array (int64 type) + values : Array (any type) + pool : MemoryPool + + Returns + ------- + list_array : LargeListArray + """ + cdef: + Array _offsets, _values + shared_ptr[CArray] out + cdef CMemoryPool* cpool = maybe_unbox_memory_pool(pool) + + _offsets = asarray(offsets, type='int64') + _values = asarray(values) + + with nogil: + out = GetResultValue( + CLargeListArray.FromArrays(_offsets.ap[0], _values.ap[0], + cpool)) + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @property + def values(self): + cdef CLargeListArray* arr = <CLargeListArray*> self.ap + return pyarrow_wrap_array(arr.values()) + + @property + def offsets(self): + """ + Return the offsets as an int64 array. + """ + return pyarrow_wrap_array((<CLargeListArray*> self.ap).offsets()) + + +cdef class MapArray(Array): + """ + Concrete class for Arrow arrays of a map data type. + """ + + @staticmethod + def from_arrays(offsets, keys, items, MemoryPool pool=None): + """ + Construct MapArray from arrays of int32 offsets and key, item arrays. + + Parameters + ---------- + offsets : array-like or sequence (int32 type) + keys : array-like or sequence (any type) + items : array-like or sequence (any type) + pool : MemoryPool + + Returns + ------- + map_array : MapArray + """ + cdef: + Array _offsets, _keys, _items + shared_ptr[CArray] out + cdef CMemoryPool* cpool = maybe_unbox_memory_pool(pool) + + _offsets = asarray(offsets, type='int32') + _keys = asarray(keys) + _items = asarray(items) + + with nogil: + out = GetResultValue( + CMapArray.FromArrays(_offsets.sp_array, + _keys.sp_array, + _items.sp_array, cpool)) + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @property + def keys(self): + return pyarrow_wrap_array((<CMapArray*> self.ap).keys()) + + @property + def items(self): + return pyarrow_wrap_array((<CMapArray*> self.ap).items()) + + +cdef class FixedSizeListArray(Array): + """ + Concrete class for Arrow arrays of a fixed size list data type. + """ + + @staticmethod + def from_arrays(values, int32_t list_size): + """ + Construct FixedSizeListArray from array of values and a list length. + + Parameters + ---------- + values : Array (any type) + list_size : int + The fixed length of the lists. + + Returns + ------- + FixedSizeListArray + """ + cdef: + Array _values + CResult[shared_ptr[CArray]] c_result + + _values = asarray(values) + + with nogil: + c_result = CFixedSizeListArray.FromArrays( + _values.sp_array, list_size) + cdef Array result = pyarrow_wrap_array(GetResultValue(c_result)) + result.validate() + return result + + @property + def values(self): + return self.flatten() + + def flatten(self): + """ + Unnest this FixedSizeListArray by one level. + + Returns + ------- + result : Array + """ + cdef CFixedSizeListArray* arr = <CFixedSizeListArray*> self.ap + return pyarrow_wrap_array(arr.values()) + + +cdef class UnionArray(Array): + """ + Concrete class for Arrow arrays of a Union data type. + """ + + def child(self, int pos): + import warnings + warnings.warn("child is deprecated, use field", FutureWarning) + return self.field(pos) + + def field(self, int pos): + """ + Return the given child field as an individual array. + + For sparse unions, the returned array has its offset, length, + and null count adjusted. + + For dense unions, the returned array is unchanged. + """ + cdef shared_ptr[CArray] result + result = (<CUnionArray*> self.ap).field(pos) + if result != NULL: + return pyarrow_wrap_array(result) + raise KeyError("UnionArray does not have child {}".format(pos)) + + @property + def type_codes(self): + """Get the type codes array.""" + buf = pyarrow_wrap_buffer((<CUnionArray*> self.ap).type_codes()) + return Array.from_buffers(int8(), len(self), [None, buf]) + + @property + def offsets(self): + """ + Get the value offsets array (dense arrays only). + + Does not account for any slice offset. + """ + if self.type.mode != "dense": + raise ArrowTypeError("Can only get value offsets for dense arrays") + cdef CDenseUnionArray* dense = <CDenseUnionArray*> self.ap + buf = pyarrow_wrap_buffer(dense.value_offsets()) + return Array.from_buffers(int32(), len(self), [None, buf]) + + @staticmethod + def from_dense(Array types, Array value_offsets, list children, + list field_names=None, list type_codes=None): + """ + Construct dense UnionArray from arrays of int8 types, int32 offsets and + children arrays + + Parameters + ---------- + types : Array (int8 type) + value_offsets : Array (int32 type) + children : list + field_names : list + type_codes : list + + Returns + ------- + union_array : UnionArray + """ + cdef: + shared_ptr[CArray] out + vector[shared_ptr[CArray]] c + Array child + vector[c_string] c_field_names + vector[int8_t] c_type_codes + + for child in children: + c.push_back(child.sp_array) + if field_names is not None: + for x in field_names: + c_field_names.push_back(tobytes(x)) + if type_codes is not None: + for x in type_codes: + c_type_codes.push_back(x) + + with nogil: + out = GetResultValue(CDenseUnionArray.Make( + deref(types.ap), deref(value_offsets.ap), c, c_field_names, + c_type_codes)) + + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + @staticmethod + def from_sparse(Array types, list children, list field_names=None, + list type_codes=None): + """ + Construct sparse UnionArray from arrays of int8 types and children + arrays + + Parameters + ---------- + types : Array (int8 type) + children : list + field_names : list + type_codes : list + + Returns + ------- + union_array : UnionArray + """ + cdef: + shared_ptr[CArray] out + vector[shared_ptr[CArray]] c + Array child + vector[c_string] c_field_names + vector[int8_t] c_type_codes + + for child in children: + c.push_back(child.sp_array) + if field_names is not None: + for x in field_names: + c_field_names.push_back(tobytes(x)) + if type_codes is not None: + for x in type_codes: + c_type_codes.push_back(x) + + with nogil: + out = GetResultValue(CSparseUnionArray.Make( + deref(types.ap), c, c_field_names, c_type_codes)) + + cdef Array result = pyarrow_wrap_array(out) + result.validate() + return result + + +cdef class StringArray(Array): + """ + Concrete class for Arrow arrays of string (or utf8) data type. + """ + + @staticmethod + def from_buffers(int length, Buffer value_offsets, Buffer data, + Buffer null_bitmap=None, int null_count=-1, + int offset=0): + """ + Construct a StringArray from value_offsets and data buffers. + If there are nulls in the data, also a null_bitmap and the matching + null_count must be passed. + + Parameters + ---------- + length : int + value_offsets : Buffer + data : Buffer + null_bitmap : Buffer, optional + null_count : int, default 0 + offset : int, default 0 + + Returns + ------- + string_array : StringArray + """ + return Array.from_buffers(utf8(), length, + [null_bitmap, value_offsets, data], + null_count, offset) + + +cdef class LargeStringArray(Array): + """ + Concrete class for Arrow arrays of large string (or utf8) data type. + """ + + @staticmethod + def from_buffers(int length, Buffer value_offsets, Buffer data, + Buffer null_bitmap=None, int null_count=-1, + int offset=0): + """ + Construct a LargeStringArray from value_offsets and data buffers. + If there are nulls in the data, also a null_bitmap and the matching + null_count must be passed. + + Parameters + ---------- + length : int + value_offsets : Buffer + data : Buffer + null_bitmap : Buffer, optional + null_count : int, default 0 + offset : int, default 0 + + Returns + ------- + string_array : StringArray + """ + return Array.from_buffers(large_utf8(), length, + [null_bitmap, value_offsets, data], + null_count, offset) + + +cdef class BinaryArray(Array): + """ + Concrete class for Arrow arrays of variable-sized binary data type. + """ + @property + def total_values_length(self): + """ + The number of bytes from beginning to end of the data buffer addressed + by the offsets of this BinaryArray. + """ + return (<CBinaryArray*> self.ap).total_values_length() + + +cdef class LargeBinaryArray(Array): + """ + Concrete class for Arrow arrays of large variable-sized binary data type. + """ + @property + def total_values_length(self): + """ + The number of bytes from beginning to end of the data buffer addressed + by the offsets of this LargeBinaryArray. + """ + return (<CLargeBinaryArray*> self.ap).total_values_length() + + +cdef class DictionaryArray(Array): + """ + Concrete class for dictionary-encoded Arrow arrays. + """ + + def dictionary_encode(self): + return self + + def dictionary_decode(self): + """ + Decodes the DictionaryArray to an Array. + """ + return self.dictionary.take(self.indices) + + @property + def dictionary(self): + cdef CDictionaryArray* darr = <CDictionaryArray*>(self.ap) + + if self._dictionary is None: + self._dictionary = pyarrow_wrap_array(darr.dictionary()) + + return self._dictionary + + @property + def indices(self): + cdef CDictionaryArray* darr = <CDictionaryArray*>(self.ap) + + if self._indices is None: + self._indices = pyarrow_wrap_array(darr.indices()) + + return self._indices + + @staticmethod + def from_arrays(indices, dictionary, mask=None, bint ordered=False, + bint from_pandas=False, bint safe=True, + MemoryPool memory_pool=None): + """ + Construct a DictionaryArray from indices and values. + + Parameters + ---------- + indices : pyarrow.Array, numpy.ndarray or pandas.Series, int type + Non-negative integers referencing the dictionary values by zero + based index. + dictionary : pyarrow.Array, ndarray or pandas.Series + The array of values referenced by the indices. + mask : ndarray or pandas.Series, bool type + True values indicate that indices are actually null. + from_pandas : bool, default False + If True, the indices should be treated as though they originated in + a pandas.Categorical (null encoded as -1). + ordered : bool, default False + Set to True if the category values are ordered. + safe : bool, default True + If True, check that the dictionary indices are in range. + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise uses default pool. + + Returns + ------- + dict_array : DictionaryArray + """ + cdef: + Array _indices, _dictionary + shared_ptr[CDataType] c_type + shared_ptr[CArray] c_result + + if isinstance(indices, Array): + if mask is not None: + raise NotImplementedError( + "mask not implemented with Arrow array inputs yet") + _indices = indices + else: + if from_pandas: + _indices = _codes_to_indices(indices, mask, None, memory_pool) + else: + _indices = array(indices, mask=mask, memory_pool=memory_pool) + + if isinstance(dictionary, Array): + _dictionary = dictionary + else: + _dictionary = array(dictionary, memory_pool=memory_pool) + + if not isinstance(_indices, IntegerArray): + raise ValueError('Indices must be integer type') + + cdef c_bool c_ordered = ordered + + c_type.reset(new CDictionaryType(_indices.type.sp_type, + _dictionary.sp_array.get().type(), + c_ordered)) + + if safe: + with nogil: + c_result = GetResultValue( + CDictionaryArray.FromArrays(c_type, _indices.sp_array, + _dictionary.sp_array)) + else: + c_result.reset(new CDictionaryArray(c_type, _indices.sp_array, + _dictionary.sp_array)) + + cdef Array result = pyarrow_wrap_array(c_result) + result.validate() + return result + + +cdef class StructArray(Array): + """ + Concrete class for Arrow arrays of a struct data type. + """ + + def field(self, index): + """ + Retrieves the child array belonging to field. + + Parameters + ---------- + index : Union[int, str] + Index / position or name of the field. + + Returns + ------- + result : Array + """ + cdef: + CStructArray* arr = <CStructArray*> self.ap + shared_ptr[CArray] child + + if isinstance(index, (bytes, str)): + child = arr.GetFieldByName(tobytes(index)) + if child == nullptr: + raise KeyError(index) + elif isinstance(index, int): + child = arr.field( + <int>_normalize_index(index, self.ap.num_fields())) + else: + raise TypeError('Expected integer or string index') + + return pyarrow_wrap_array(child) + + def flatten(self, MemoryPool memory_pool=None): + """ + Return one individual array for each field in the struct. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool. + + Returns + ------- + result : List[Array] + """ + cdef: + vector[shared_ptr[CArray]] arrays + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + CStructArray* sarr = <CStructArray*> self.ap + + with nogil: + arrays = GetResultValue(sarr.Flatten(pool)) + + return [pyarrow_wrap_array(arr) for arr in arrays] + + @staticmethod + def from_arrays(arrays, names=None, fields=None, mask=None, + memory_pool=None): + """ + Construct StructArray from collection of arrays representing + each field in the struct. + + Either field names or field instances must be passed. + + Parameters + ---------- + arrays : sequence of Array + names : List[str] (optional) + Field names for each struct child. + fields : List[Field] (optional) + Field instances for each struct child. + mask : pyarrow.Array[bool] (optional) + Indicate which values are null (True) or not null (False). + memory_pool : MemoryPool (optional) + For memory allocations, if required, otherwise uses default pool. + + Returns + ------- + result : StructArray + """ + cdef: + shared_ptr[CArray] c_array + shared_ptr[CBuffer] c_mask + vector[shared_ptr[CArray]] c_arrays + vector[c_string] c_names + vector[shared_ptr[CField]] c_fields + CResult[shared_ptr[CArray]] c_result + ssize_t num_arrays + ssize_t length + ssize_t i + Field py_field + DataType struct_type + + if names is None and fields is None: + raise ValueError('Must pass either names or fields') + if names is not None and fields is not None: + raise ValueError('Must pass either names or fields, not both') + + if mask is None: + c_mask = shared_ptr[CBuffer]() + elif isinstance(mask, Array): + if mask.type.id != Type_BOOL: + raise ValueError('Mask must be a pyarrow.Array of type bool') + if mask.null_count != 0: + raise ValueError('Mask must not contain nulls') + inverted_mask = _pc().invert(mask, memory_pool=memory_pool) + c_mask = pyarrow_unwrap_buffer(inverted_mask.buffers()[1]) + else: + raise ValueError('Mask must be a pyarrow.Array of type bool') + + arrays = [asarray(x) for x in arrays] + for arr in arrays: + c_array = pyarrow_unwrap_array(arr) + if c_array == nullptr: + raise TypeError(f"Expected Array, got {arr.__class__}") + c_arrays.push_back(c_array) + if names is not None: + for name in names: + c_names.push_back(tobytes(name)) + else: + for item in fields: + if isinstance(item, tuple): + py_field = field(*item) + else: + py_field = item + c_fields.push_back(py_field.sp_field) + + if (c_arrays.size() == 0 and c_names.size() == 0 and + c_fields.size() == 0): + # The C++ side doesn't allow this + return array([], struct([])) + + if names is not None: + # XXX Cannot pass "nullptr" for a shared_ptr<T> argument: + # https://github.com/cython/cython/issues/3020 + c_result = CStructArray.MakeFromFieldNames( + c_arrays, c_names, c_mask, -1, 0) + else: + c_result = CStructArray.MakeFromFields( + c_arrays, c_fields, c_mask, -1, 0) + cdef Array result = pyarrow_wrap_array(GetResultValue(c_result)) + result.validate() + return result + + +cdef class ExtensionArray(Array): + """ + Concrete class for Arrow extension arrays. + """ + + @property + def storage(self): + cdef: + CExtensionArray* ext_array = <CExtensionArray*>(self.ap) + + return pyarrow_wrap_array(ext_array.storage()) + + @staticmethod + def from_storage(BaseExtensionType typ, Array storage): + """ + Construct ExtensionArray from type and storage array. + + Parameters + ---------- + typ : DataType + The extension type for the result array. + storage : Array + The underlying storage for the result array. + + Returns + ------- + ext_array : ExtensionArray + """ + cdef: + shared_ptr[CExtensionArray] ext_array + + if storage.type != typ.storage_type: + raise TypeError("Incompatible storage type {0} " + "for extension type {1}".format(storage.type, typ)) + + ext_array = make_shared[CExtensionArray](typ.sp_type, storage.sp_array) + cdef Array result = pyarrow_wrap_array(<shared_ptr[CArray]> ext_array) + result.validate() + return result + + def _to_pandas(self, options, **kwargs): + pandas_dtype = None + try: + pandas_dtype = self.type.to_pandas_dtype() + except NotImplementedError: + pass + + # pandas ExtensionDtype that implements conversion from pyarrow + if hasattr(pandas_dtype, '__from_arrow__'): + arr = pandas_dtype.__from_arrow__(self) + return pandas_api.series(arr) + + # otherwise convert the storage array with the base implementation + return Array._to_pandas(self.storage, options, **kwargs) + + def to_numpy(self, **kwargs): + """ + Convert extension array to a numpy ndarray. + + See Also + -------- + Array.to_numpy + """ + return self.storage.to_numpy(**kwargs) + + +cdef dict _array_classes = { + _Type_NA: NullArray, + _Type_BOOL: BooleanArray, + _Type_UINT8: UInt8Array, + _Type_UINT16: UInt16Array, + _Type_UINT32: UInt32Array, + _Type_UINT64: UInt64Array, + _Type_INT8: Int8Array, + _Type_INT16: Int16Array, + _Type_INT32: Int32Array, + _Type_INT64: Int64Array, + _Type_DATE32: Date32Array, + _Type_DATE64: Date64Array, + _Type_TIMESTAMP: TimestampArray, + _Type_TIME32: Time32Array, + _Type_TIME64: Time64Array, + _Type_DURATION: DurationArray, + _Type_INTERVAL_MONTH_DAY_NANO: MonthDayNanoIntervalArray, + _Type_HALF_FLOAT: HalfFloatArray, + _Type_FLOAT: FloatArray, + _Type_DOUBLE: DoubleArray, + _Type_LIST: ListArray, + _Type_LARGE_LIST: LargeListArray, + _Type_MAP: MapArray, + _Type_FIXED_SIZE_LIST: FixedSizeListArray, + _Type_SPARSE_UNION: UnionArray, + _Type_DENSE_UNION: UnionArray, + _Type_BINARY: BinaryArray, + _Type_STRING: StringArray, + _Type_LARGE_BINARY: LargeBinaryArray, + _Type_LARGE_STRING: LargeStringArray, + _Type_DICTIONARY: DictionaryArray, + _Type_FIXED_SIZE_BINARY: FixedSizeBinaryArray, + _Type_DECIMAL128: Decimal128Array, + _Type_DECIMAL256: Decimal256Array, + _Type_STRUCT: StructArray, + _Type_EXTENSION: ExtensionArray, +} + + +cdef object get_array_class_from_type( + const shared_ptr[CDataType]& sp_data_type): + cdef CDataType* data_type = sp_data_type.get() + if data_type == NULL: + raise ValueError('Array data type was NULL') + + if data_type.id() == _Type_EXTENSION: + py_ext_data_type = pyarrow_wrap_data_type(sp_data_type) + return py_ext_data_type.__arrow_ext_class__() + else: + return _array_classes[data_type.id()] + + +cdef object get_values(object obj, bint* is_series): + if pandas_api.is_series(obj) or pandas_api.is_index(obj): + result = pandas_api.get_values(obj) + is_series[0] = True + elif isinstance(obj, np.ndarray): + result = obj + is_series[0] = False + else: + result = pandas_api.series(obj).values + is_series[0] = False + + return result + + +def concat_arrays(arrays, MemoryPool memory_pool=None): + """ + Concatenate the given arrays. + + The contents of the input arrays are copied into the returned array. + + Raises + ------ + ArrowInvalid : if not all of the arrays have the same type. + + Parameters + ---------- + arrays : iterable of pyarrow.Array + Arrays to concatenate, must be identically typed. + memory_pool : MemoryPool, default None + For memory allocations. If None, the default pool is used. + """ + cdef: + vector[shared_ptr[CArray]] c_arrays + shared_ptr[CArray] c_concatenated + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + for array in arrays: + if not isinstance(array, Array): + raise TypeError("Iterable should contain Array objects, " + "got {0} instead".format(type(array))) + c_arrays.push_back(pyarrow_unwrap_array(array)) + + with nogil: + c_concatenated = GetResultValue(Concatenate(c_arrays, pool)) + + return pyarrow_wrap_array(c_concatenated) + + +def _empty_array(DataType type): + """ + Create empty array of the given type. + """ + if type.id == Type_DICTIONARY: + arr = DictionaryArray.from_arrays( + _empty_array(type.index_type), _empty_array(type.value_type), + ordered=type.ordered) + else: + arr = array([], type=type) + return arr diff --git a/src/arrow/python/pyarrow/benchmark.pxi b/src/arrow/python/pyarrow/benchmark.pxi new file mode 100644 index 000000000..ab251017d --- /dev/null +++ b/src/arrow/python/pyarrow/benchmark.pxi @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +def benchmark_PandasObjectIsNull(list obj): + Benchmark_PandasObjectIsNull(obj) diff --git a/src/arrow/python/pyarrow/benchmark.py b/src/arrow/python/pyarrow/benchmark.py new file mode 100644 index 000000000..25ee1141f --- /dev/null +++ b/src/arrow/python/pyarrow/benchmark.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# flake8: noqa + + +from pyarrow.lib import benchmark_PandasObjectIsNull diff --git a/src/arrow/python/pyarrow/builder.pxi b/src/arrow/python/pyarrow/builder.pxi new file mode 100644 index 000000000..a34ea5412 --- /dev/null +++ b/src/arrow/python/pyarrow/builder.pxi @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +cdef class StringBuilder(_Weakrefable): + """ + Builder class for UTF8 strings. + + This class exposes facilities for incrementally adding string values and + building the null bitmap for a pyarrow.Array (type='string'). + """ + cdef: + unique_ptr[CStringBuilder] builder + + def __cinit__(self, MemoryPool memory_pool=None): + cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + self.builder.reset(new CStringBuilder(pool)) + + def append(self, value): + """ + Append a single value to the builder. + + The value can either be a string/bytes object or a null value + (np.nan or None). + + Parameters + ---------- + value : string/bytes or np.nan/None + The value to append to the string array builder. + """ + if value is None or value is np.nan: + self.builder.get().AppendNull() + elif isinstance(value, (bytes, str)): + self.builder.get().Append(tobytes(value)) + else: + raise TypeError('StringBuilder only accepts string objects') + + def append_values(self, values): + """ + Append all the values from an iterable. + + Parameters + ---------- + values : iterable of string/bytes or np.nan/None values + The values to append to the string array builder. + """ + for value in values: + self.append(value) + + def finish(self): + """ + Return result of builder as an Array object; also resets the builder. + + Returns + ------- + array : pyarrow.Array + """ + cdef shared_ptr[CArray] out + with nogil: + self.builder.get().Finish(&out) + return pyarrow_wrap_array(out) + + @property + def null_count(self): + return self.builder.get().null_count() + + def __len__(self): + return self.builder.get().length() diff --git a/src/arrow/python/pyarrow/cffi.py b/src/arrow/python/pyarrow/cffi.py new file mode 100644 index 000000000..961b61dee --- /dev/null +++ b/src/arrow/python/pyarrow/cffi.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import + +import cffi + +c_source = """ + struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; + }; + + struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; + }; + + struct ArrowArrayStream { + int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out); + int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out); + + const char* (*get_last_error)(struct ArrowArrayStream*); + + // Release callback + void (*release)(struct ArrowArrayStream*); + // Opaque producer-specific data + void* private_data; + }; + """ + +# TODO use out-of-line mode for faster import and avoid C parsing +ffi = cffi.FFI() +ffi.cdef(c_source) diff --git a/src/arrow/python/pyarrow/compat.pxi b/src/arrow/python/pyarrow/compat.pxi new file mode 100644 index 000000000..a5db5741b --- /dev/null +++ b/src/arrow/python/pyarrow/compat.pxi @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + + +def encode_file_path(path): + if isinstance(path, str): + # POSIX systems can handle utf-8. UTF8 is converted to utf16-le in + # libarrow + encoded_path = path.encode('utf-8') + else: + encoded_path = path + + # Windows file system requires utf-16le for file names; Arrow C++ libraries + # will convert utf8 to utf16 + return encoded_path + + +if sys.version_info >= (3, 7): + # Starting with Python 3.7, dicts are guaranteed to be insertion-ordered. + ordered_dict = dict +else: + import collections + ordered_dict = collections.OrderedDict + + +try: + import pickle5 as builtin_pickle +except ImportError: + import pickle as builtin_pickle + + +try: + import cloudpickle as pickle +except ImportError: + pickle = builtin_pickle + + +def tobytes(o): + if isinstance(o, str): + return o.encode('utf8') + else: + return o + + +def frombytes(o, *, safe=False): + if safe: + return o.decode('utf8', errors='replace') + else: + return o.decode('utf8') diff --git a/src/arrow/python/pyarrow/compat.py b/src/arrow/python/pyarrow/compat.py new file mode 100644 index 000000000..814a51bef --- /dev/null +++ b/src/arrow/python/pyarrow/compat.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# flake8: noqa + +import pyarrow.util as util +import warnings + + +warnings.warn("pyarrow.compat has been deprecated and will be removed in a " + "future release", FutureWarning) + + +guid = util._deprecate_api("compat.guid", "util.guid", + util.guid, "1.0.0") diff --git a/src/arrow/python/pyarrow/compute.py b/src/arrow/python/pyarrow/compute.py new file mode 100644 index 000000000..6e3bd7fca --- /dev/null +++ b/src/arrow/python/pyarrow/compute.py @@ -0,0 +1,759 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyarrow._compute import ( # noqa + Function, + FunctionOptions, + FunctionRegistry, + HashAggregateFunction, + HashAggregateKernel, + Kernel, + ScalarAggregateFunction, + ScalarAggregateKernel, + ScalarFunction, + ScalarKernel, + VectorFunction, + VectorKernel, + # Option classes + ArraySortOptions, + AssumeTimezoneOptions, + CastOptions, + CountOptions, + DayOfWeekOptions, + DictionaryEncodeOptions, + ElementWiseAggregateOptions, + ExtractRegexOptions, + FilterOptions, + IndexOptions, + JoinOptions, + MakeStructOptions, + MatchSubstringOptions, + ModeOptions, + NullOptions, + PadOptions, + PartitionNthOptions, + QuantileOptions, + ReplaceSliceOptions, + ReplaceSubstringOptions, + RoundOptions, + RoundToMultipleOptions, + ScalarAggregateOptions, + SelectKOptions, + SetLookupOptions, + SliceOptions, + SortOptions, + SplitOptions, + SplitPatternOptions, + StrftimeOptions, + StrptimeOptions, + TakeOptions, + TDigestOptions, + TrimOptions, + VarianceOptions, + WeekOptions, + # Functions + call_function, + function_registry, + get_function, + list_functions, +) + +import inspect +from textwrap import dedent +import warnings + +import pyarrow as pa + + +def _get_arg_names(func): + return func._doc.arg_names + + +def _decorate_compute_function(wrapper, exposed_name, func, option_class): + # Decorate the given compute function wrapper with useful metadata + # and documentation. + wrapper.__arrow_compute_function__ = dict(name=func.name, + arity=func.arity) + wrapper.__name__ = exposed_name + wrapper.__qualname__ = exposed_name + + doc_pieces = [] + + cpp_doc = func._doc + summary = cpp_doc.summary + if not summary: + arg_str = "arguments" if func.arity > 1 else "argument" + summary = ("Call compute function {!r} with the given {}" + .format(func.name, arg_str)) + + description = cpp_doc.description + arg_names = _get_arg_names(func) + + doc_pieces.append("""\ + {}. + + """.format(summary)) + + if description: + doc_pieces.append("{}\n\n".format(description)) + + doc_pieces.append("""\ + Parameters + ---------- + """) + + for arg_name in arg_names: + if func.kind in ('vector', 'scalar_aggregate'): + arg_type = 'Array-like' + else: + arg_type = 'Array-like or scalar-like' + doc_pieces.append("""\ + {} : {} + Argument to compute function + """.format(arg_name, arg_type)) + + doc_pieces.append("""\ + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + """) + if option_class is not None: + doc_pieces.append("""\ + options : pyarrow.compute.{0}, optional + Parameters altering compute function semantics. + """.format(option_class.__name__)) + options_sig = inspect.signature(option_class) + for p in options_sig.parameters.values(): + doc_pieces.append("""\ + {0} : optional + Parameter for {1} constructor. Either `options` + or `{0}` can be passed, but not both at the same time. + """.format(p.name, option_class.__name__)) + + wrapper.__doc__ = "".join(dedent(s) for s in doc_pieces) + return wrapper + + +def _get_options_class(func): + class_name = func._doc.options_class + if not class_name: + return None + try: + return globals()[class_name] + except KeyError: + warnings.warn("Python binding for {} not exposed" + .format(class_name), RuntimeWarning) + return None + + +def _handle_options(name, option_class, options, kwargs): + if kwargs: + if options is None: + return option_class(**kwargs) + raise TypeError( + "Function {!r} called with both an 'options' argument " + "and additional named arguments" + .format(name)) + + if options is not None: + if isinstance(options, dict): + return option_class(**options) + elif isinstance(options, option_class): + return options + raise TypeError( + "Function {!r} expected a {} parameter, got {}" + .format(name, option_class, type(options))) + + return options + + +def _make_generic_wrapper(func_name, func, option_class): + if option_class is None: + def wrapper(*args, memory_pool=None): + return func.call(args, None, memory_pool) + else: + def wrapper(*args, memory_pool=None, options=None, **kwargs): + options = _handle_options(func_name, option_class, options, + kwargs) + return func.call(args, options, memory_pool) + return wrapper + + +def _make_signature(arg_names, var_arg_names, option_class): + from inspect import Parameter + params = [] + for name in arg_names: + params.append(Parameter(name, Parameter.POSITIONAL_OR_KEYWORD)) + for name in var_arg_names: + params.append(Parameter(name, Parameter.VAR_POSITIONAL)) + params.append(Parameter("memory_pool", Parameter.KEYWORD_ONLY, + default=None)) + if option_class is not None: + params.append(Parameter("options", Parameter.KEYWORD_ONLY, + default=None)) + options_sig = inspect.signature(option_class) + for p in options_sig.parameters.values(): + # XXX for now, our generic wrappers don't allow positional + # option arguments + params.append(p.replace(kind=Parameter.KEYWORD_ONLY)) + return inspect.Signature(params) + + +def _wrap_function(name, func): + option_class = _get_options_class(func) + arg_names = _get_arg_names(func) + has_vararg = arg_names and arg_names[-1].startswith('*') + if has_vararg: + var_arg_names = [arg_names.pop().lstrip('*')] + else: + var_arg_names = [] + + wrapper = _make_generic_wrapper(name, func, option_class) + wrapper.__signature__ = _make_signature(arg_names, var_arg_names, + option_class) + return _decorate_compute_function(wrapper, name, func, option_class) + + +def _make_global_functions(): + """ + Make global functions wrapping each compute function. + + Note that some of the automatically-generated wrappers may be overriden + by custom versions below. + """ + g = globals() + reg = function_registry() + + # Avoid clashes with Python keywords + rewrites = {'and': 'and_', + 'or': 'or_'} + + for cpp_name in reg.list_functions(): + name = rewrites.get(cpp_name, cpp_name) + func = reg.get_function(cpp_name) + assert name not in g, name + g[cpp_name] = g[name] = _wrap_function(name, func) + + +_make_global_functions() + + +def cast(arr, target_type, safe=True): + """ + Cast array values to another data type. Can also be invoked as an array + instance method. + + Parameters + ---------- + arr : Array or ChunkedArray + target_type : DataType or type string alias + Type to cast to + safe : bool, default True + Check for overflows or other unsafe conversions + + Examples + -------- + >>> from datetime import datetime + >>> import pyarrow as pa + >>> arr = pa.array([datetime(2010, 1, 1), datetime(2015, 1, 1)]) + >>> arr.type + TimestampType(timestamp[us]) + + You can use ``pyarrow.DataType`` objects to specify the target type: + + >>> cast(arr, pa.timestamp('ms')) + <pyarrow.lib.TimestampArray object at 0x7fe93c0f6910> + [ + 2010-01-01 00:00:00.000, + 2015-01-01 00:00:00.000 + ] + + >>> cast(arr, pa.timestamp('ms')).type + TimestampType(timestamp[ms]) + + Alternatively, it is also supported to use the string aliases for these + types: + + >>> arr.cast('timestamp[ms]') + <pyarrow.lib.TimestampArray object at 0x10420eb88> + [ + 1262304000000, + 1420070400000 + ] + >>> arr.cast('timestamp[ms]').type + TimestampType(timestamp[ms]) + + Returns + ------- + casted : Array + """ + if target_type is None: + raise ValueError("Cast target type must not be None") + if safe: + options = CastOptions.safe(target_type) + else: + options = CastOptions.unsafe(target_type) + return call_function("cast", [arr], options) + + +def count_substring(array, pattern, *, ignore_case=False): + """ + Count the occurrences of substring *pattern* in each value of a + string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + pattern to search for exact matches + ignore_case : bool, default False + Ignore case while searching. + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("count_substring", [array], + MatchSubstringOptions(pattern, + ignore_case=ignore_case)) + + +def count_substring_regex(array, pattern, *, ignore_case=False): + """ + Count the non-overlapping matches of regex *pattern* in each value + of a string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + pattern to search for exact matches + ignore_case : bool, default False + Ignore case while searching. + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("count_substring_regex", [array], + MatchSubstringOptions(pattern, + ignore_case=ignore_case)) + + +def find_substring(array, pattern, *, ignore_case=False): + """ + Find the index of the first occurrence of substring *pattern* in each + value of a string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + pattern to search for exact matches + ignore_case : bool, default False + Ignore case while searching. + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("find_substring", [array], + MatchSubstringOptions(pattern, + ignore_case=ignore_case)) + + +def find_substring_regex(array, pattern, *, ignore_case=False): + """ + Find the index of the first match of regex *pattern* in each + value of a string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + regex pattern to search for + ignore_case : bool, default False + Ignore case while searching. + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("find_substring_regex", [array], + MatchSubstringOptions(pattern, + ignore_case=ignore_case)) + + +def match_like(array, pattern, *, ignore_case=False): + """ + Test if the SQL-style LIKE pattern *pattern* matches a value of a + string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + SQL-style LIKE pattern. '%' will match any number of + characters, '_' will match exactly one character, and all + other characters match themselves. To match a literal percent + sign or underscore, precede the character with a backslash. + ignore_case : bool, default False + Ignore case while searching. + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + + """ + return call_function("match_like", [array], + MatchSubstringOptions(pattern, + ignore_case=ignore_case)) + + +def match_substring(array, pattern, *, ignore_case=False): + """ + Test if substring *pattern* is contained within a value of a string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + pattern to search for exact matches + ignore_case : bool, default False + Ignore case while searching. + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("match_substring", [array], + MatchSubstringOptions(pattern, + ignore_case=ignore_case)) + + +def match_substring_regex(array, pattern, *, ignore_case=False): + """ + Test if regex *pattern* matches at any position a value of a string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + regex pattern to search + ignore_case : bool, default False + Ignore case while searching. + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("match_substring_regex", [array], + MatchSubstringOptions(pattern, + ignore_case=ignore_case)) + + +def mode(array, n=1, *, skip_nulls=True, min_count=0): + """ + Return top-n most common values and number of times they occur in a passed + numerical (chunked) array, in descending order of occurrence. If there are + multiple values with same count, the smaller one is returned first. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + n : int, default 1 + Specify the top-n values. + skip_nulls : bool, default True + If True, ignore nulls in the input. Else return an empty array + if any input is null. + min_count : int, default 0 + If there are fewer than this many values in the input, return + an empty array. + + Returns + ------- + An array of <input type "Mode", int64_t "Count"> structs + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> arr = pa.array([1, 1, 2, 2, 3, 2, 2, 2]) + >>> modes = pc.mode(arr, 2) + >>> modes[0] + <pyarrow.StructScalar: {'mode': 2, 'count': 5}> + >>> modes[1] + <pyarrow.StructScalar: {'mode': 1, 'count': 2}> + """ + options = ModeOptions(n, skip_nulls=skip_nulls, min_count=min_count) + return call_function("mode", [array], options) + + +def filter(data, mask, null_selection_behavior='drop'): + """ + Select values (or records) from array- or table-like data given boolean + filter, where true values are selected. + + Parameters + ---------- + data : Array, ChunkedArray, RecordBatch, or Table + mask : Array, ChunkedArray + Must be of boolean type + null_selection_behavior : str, default 'drop' + Configure the behavior on encountering a null slot in the mask. + Allowed values are 'drop' and 'emit_null'. + + - 'drop': nulls will be treated as equivalent to False. + - 'emit_null': nulls will result in a null in the output. + + Returns + ------- + result : depends on inputs + + Examples + -------- + >>> import pyarrow as pa + >>> arr = pa.array(["a", "b", "c", None, "e"]) + >>> mask = pa.array([True, False, None, False, True]) + >>> arr.filter(mask) + <pyarrow.lib.StringArray object at 0x7fa826df9200> + [ + "a", + "e" + ] + >>> arr.filter(mask, null_selection_behavior='emit_null') + <pyarrow.lib.StringArray object at 0x7fa826df9200> + [ + "a", + null, + "e" + ] + """ + options = FilterOptions(null_selection_behavior) + return call_function('filter', [data, mask], options) + + +def index(data, value, start=None, end=None, *, memory_pool=None): + """ + Find the index of the first occurrence of a given value. + + Parameters + ---------- + data : Array or ChunkedArray + value : Scalar-like object + start : int, optional + end : int, optional + memory_pool : MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + + Returns + ------- + index : the index, or -1 if not found + """ + if start is not None: + if end is not None: + data = data.slice(start, end - start) + else: + data = data.slice(start) + elif end is not None: + data = data.slice(0, end) + + if not isinstance(value, pa.Scalar): + value = pa.scalar(value, type=data.type) + elif data.type != value.type: + value = pa.scalar(value.as_py(), type=data.type) + options = IndexOptions(value=value) + result = call_function('index', [data], options, memory_pool) + if start is not None and result.as_py() >= 0: + result = pa.scalar(result.as_py() + start, type=pa.int64()) + return result + + +def take(data, indices, *, boundscheck=True, memory_pool=None): + """ + Select values (or records) from array- or table-like data given integer + selection indices. + + The result will be of the same type(s) as the input, with elements taken + from the input array (or record batch / table fields) at the given + indices. If an index is null then the corresponding value in the output + will be null. + + Parameters + ---------- + data : Array, ChunkedArray, RecordBatch, or Table + indices : Array, ChunkedArray + Must be of integer type + boundscheck : boolean, default True + Whether to boundscheck the indices. If False and there is an out of + bounds index, will likely cause the process to crash. + memory_pool : MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + + Returns + ------- + result : depends on inputs + + Examples + -------- + >>> import pyarrow as pa + >>> arr = pa.array(["a", "b", "c", None, "e", "f"]) + >>> indices = pa.array([0, None, 4, 3]) + >>> arr.take(indices) + <pyarrow.lib.StringArray object at 0x7ffa4fc7d368> + [ + "a", + null, + "e", + null + ] + """ + options = TakeOptions(boundscheck=boundscheck) + return call_function('take', [data, indices], options, memory_pool) + + +def fill_null(values, fill_value): + """ + Replace each null element in values with fill_value. The fill_value must be + the same type as values or able to be implicitly casted to the array's + type. + + This is an alias for :func:`coalesce`. + + Parameters + ---------- + values : Array, ChunkedArray, or Scalar-like object + Each null element is replaced with the corresponding value + from fill_value. + fill_value : Array, ChunkedArray, or Scalar-like object + If not same type as data will attempt to cast. + + Returns + ------- + result : depends on inputs + + Examples + -------- + >>> import pyarrow as pa + >>> arr = pa.array([1, 2, None, 3], type=pa.int8()) + >>> fill_value = pa.scalar(5, type=pa.int8()) + >>> arr.fill_null(fill_value) + pyarrow.lib.Int8Array object at 0x7f95437f01a0> + [ + 1, + 2, + 5, + 3 + ] + """ + if not isinstance(fill_value, (pa.Array, pa.ChunkedArray, pa.Scalar)): + fill_value = pa.scalar(fill_value, type=values.type) + elif values.type != fill_value.type: + fill_value = pa.scalar(fill_value.as_py(), type=values.type) + + return call_function("coalesce", [values, fill_value]) + + +def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None): + """ + Select the indices of the top-k ordered elements from array- or table-like + data. + + This is a specialization for :func:`select_k_unstable`. Output is not + guaranteed to be stable. + + Parameters + ---------- + values : Array, ChunkedArray, RecordBatch, or Table + Data to sort and get top indices from. + k : int + The number of `k` elements to keep. + sort_keys : List-like + Column key names to order by when input is table-like data. + memory_pool : MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + + Returns + ------- + result : Array of indices + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> arr = pa.array(["a", "b", "c", None, "e", "f"]) + >>> pc.top_k_unstable(arr, k=3) + <pyarrow.lib.UInt64Array object at 0x7fdcb19d7f30> + [ + 5, + 4, + 2 + ] + """ + if sort_keys is None: + sort_keys = [] + if isinstance(values, (pa.Array, pa.ChunkedArray)): + sort_keys.append(("dummy", "descending")) + else: + sort_keys = map(lambda key_name: (key_name, "descending"), sort_keys) + options = SelectKOptions(k, sort_keys) + return call_function("select_k_unstable", [values], options, memory_pool) + + +def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None): + """ + Select the indices of the bottom-k ordered elements from + array- or table-like data. + + This is a specialization for :func:`select_k_unstable`. Output is not + guaranteed to be stable. + + Parameters + ---------- + values : Array, ChunkedArray, RecordBatch, or Table + Data to sort and get bottom indices from. + k : int + The number of `k` elements to keep. + sort_keys : List-like + Column key names to order by when input is table-like data. + memory_pool : MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + + Returns + ------- + result : Array of indices + + Examples + -------- + >>> import pyarrow as pa + >>> import pyarrow.compute as pc + >>> arr = pa.array(["a", "b", "c", None, "e", "f"]) + >>> pc.bottom_k_unstable(arr, k=3) + <pyarrow.lib.UInt64Array object at 0x7fdcb19d7fa0> + [ + 0, + 1, + 2 + ] + """ + if sort_keys is None: + sort_keys = [] + if isinstance(values, (pa.Array, pa.ChunkedArray)): + sort_keys.append(("dummy", "ascending")) + else: + sort_keys = map(lambda key_name: (key_name, "ascending"), sort_keys) + options = SelectKOptions(k, sort_keys) + return call_function("select_k_unstable", [values], options, memory_pool) diff --git a/src/arrow/python/pyarrow/config.pxi b/src/arrow/python/pyarrow/config.pxi new file mode 100644 index 000000000..fc88c28d6 --- /dev/null +++ b/src/arrow/python/pyarrow/config.pxi @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyarrow.includes.libarrow cimport GetBuildInfo + +from collections import namedtuple + + +VersionInfo = namedtuple('VersionInfo', ('major', 'minor', 'patch')) + +BuildInfo = namedtuple( + 'BuildInfo', + ('version', 'version_info', 'so_version', 'full_so_version', + 'compiler_id', 'compiler_version', 'compiler_flags', + 'git_id', 'git_description', 'package_kind')) + +RuntimeInfo = namedtuple('RuntimeInfo', + ('simd_level', 'detected_simd_level')) + +cdef _build_info(): + cdef: + const CBuildInfo* c_info + + c_info = &GetBuildInfo() + + return BuildInfo(version=frombytes(c_info.version_string), + version_info=VersionInfo(c_info.version_major, + c_info.version_minor, + c_info.version_patch), + so_version=frombytes(c_info.so_version), + full_so_version=frombytes(c_info.full_so_version), + compiler_id=frombytes(c_info.compiler_id), + compiler_version=frombytes(c_info.compiler_version), + compiler_flags=frombytes(c_info.compiler_flags), + git_id=frombytes(c_info.git_id), + git_description=frombytes(c_info.git_description), + package_kind=frombytes(c_info.package_kind)) + + +cpp_build_info = _build_info() +cpp_version = cpp_build_info.version +cpp_version_info = cpp_build_info.version_info + + +def runtime_info(): + """ + Get runtime information. + + Returns + ------- + info : pyarrow.RuntimeInfo + """ + cdef: + CRuntimeInfo c_info + + c_info = GetRuntimeInfo() + + return RuntimeInfo( + simd_level=frombytes(c_info.simd_level), + detected_simd_level=frombytes(c_info.detected_simd_level)) diff --git a/src/arrow/python/pyarrow/csv.py b/src/arrow/python/pyarrow/csv.py new file mode 100644 index 000000000..e073252cb --- /dev/null +++ b/src/arrow/python/pyarrow/csv.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from pyarrow._csv import ( # noqa + ReadOptions, ParseOptions, ConvertOptions, ISO8601, + open_csv, read_csv, CSVStreamingReader, write_csv, + WriteOptions, CSVWriter) diff --git a/src/arrow/python/pyarrow/cuda.py b/src/arrow/python/pyarrow/cuda.py new file mode 100644 index 000000000..18c530d4a --- /dev/null +++ b/src/arrow/python/pyarrow/cuda.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# flake8: noqa + + +from pyarrow._cuda import (Context, IpcMemHandle, CudaBuffer, + HostBuffer, BufferReader, BufferWriter, + new_host_buffer, + serialize_record_batch, read_message, + read_record_batch) diff --git a/src/arrow/python/pyarrow/dataset.py b/src/arrow/python/pyarrow/dataset.py new file mode 100644 index 000000000..42515a9f4 --- /dev/null +++ b/src/arrow/python/pyarrow/dataset.py @@ -0,0 +1,881 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Dataset is currently unstable. APIs subject to change without notice.""" + +import pyarrow as pa +from pyarrow.util import _is_iterable, _stringify_path, _is_path_like + +from pyarrow._dataset import ( # noqa + CsvFileFormat, + CsvFragmentScanOptions, + Expression, + Dataset, + DatasetFactory, + DirectoryPartitioning, + FileFormat, + FileFragment, + FileSystemDataset, + FileSystemDatasetFactory, + FileSystemFactoryOptions, + FileWriteOptions, + Fragment, + HivePartitioning, + IpcFileFormat, + IpcFileWriteOptions, + InMemoryDataset, + ParquetDatasetFactory, + ParquetFactoryOptions, + ParquetFileFormat, + ParquetFileFragment, + ParquetFileWriteOptions, + ParquetFragmentScanOptions, + ParquetReadOptions, + Partitioning, + PartitioningFactory, + RowGroupInfo, + Scanner, + TaggedRecordBatch, + UnionDataset, + UnionDatasetFactory, + _get_partition_keys, + _filesystemdataset_write, +) + +_orc_available = False +_orc_msg = ( + "The pyarrow installation is not built with support for the ORC file " + "format." +) + +try: + from pyarrow._dataset_orc import OrcFileFormat + _orc_available = True +except ImportError: + pass + + +def __getattr__(name): + if name == "OrcFileFormat" and not _orc_available: + raise ImportError(_orc_msg) + + raise AttributeError( + "module 'pyarrow.dataset' has no attribute '{0}'".format(name) + ) + + +def field(name): + """Reference a named column of the dataset. + + Stores only the field's name. Type and other information is known only when + the expression is bound to a dataset having an explicit scheme. + + Parameters + ---------- + name : string + The name of the field the expression references to. + + Returns + ------- + field_expr : Expression + """ + return Expression._field(name) + + +def scalar(value): + """Expression representing a scalar value. + + Parameters + ---------- + value : bool, int, float or string + Python value of the scalar. Note that only a subset of types are + currently supported. + + Returns + ------- + scalar_expr : Expression + """ + return Expression._scalar(value) + + +def partitioning(schema=None, field_names=None, flavor=None, + dictionaries=None): + """ + Specify a partitioning scheme. + + The supported schemes include: + + - "DirectoryPartitioning": this scheme expects one segment in the file path + for each field in the specified schema (all fields are required to be + present). For example given schema<year:int16, month:int8> the path + "/2009/11" would be parsed to ("year"_ == 2009 and "month"_ == 11). + - "HivePartitioning": a scheme for "/$key=$value/" nested directories as + found in Apache Hive. This is a multi-level, directory based partitioning + scheme. Data is partitioned by static values of a particular column in + the schema. Partition keys are represented in the form $key=$value in + directory names. Field order is ignored, as are missing or unrecognized + field names. + For example, given schema<year:int16, month:int8, day:int8>, a possible + path would be "/year=2009/month=11/day=15" (but the field order does not + need to match). + + Parameters + ---------- + schema : pyarrow.Schema, default None + The schema that describes the partitions present in the file path. + If not specified, and `field_names` and/or `flavor` are specified, + the schema will be inferred from the file path (and a + PartitioningFactory is returned). + field_names : list of str, default None + A list of strings (field names). If specified, the schema's types are + inferred from the file paths (only valid for DirectoryPartitioning). + flavor : str, default None + The default is DirectoryPartitioning. Specify ``flavor="hive"`` for + a HivePartitioning. + dictionaries : Dict[str, Array] + If the type of any field of `schema` is a dictionary type, the + corresponding entry of `dictionaries` must be an array containing + every value which may be taken by the corresponding column or an + error will be raised in parsing. Alternatively, pass `infer` to have + Arrow discover the dictionary values, in which case a + PartitioningFactory is returned. + + Returns + ------- + Partitioning or PartitioningFactory + + Examples + -------- + + Specify the Schema for paths like "/2009/June": + + >>> partitioning(pa.schema([("year", pa.int16()), ("month", pa.string())])) + + or let the types be inferred by only specifying the field names: + + >>> partitioning(field_names=["year", "month"]) + + For paths like "/2009/June", the year will be inferred as int32 while month + will be inferred as string. + + Specify a Schema with dictionary encoding, providing dictionary values: + + >>> partitioning( + ... pa.schema([ + ... ("year", pa.int16()), + ... ("month", pa.dictionary(pa.int8(), pa.string())) + ... ]), + ... dictionaries={ + ... "month": pa.array(["January", "February", "March"]), + ... }) + + Alternatively, specify a Schema with dictionary encoding, but have Arrow + infer the dictionary values: + + >>> partitioning( + ... pa.schema([ + ... ("year", pa.int16()), + ... ("month", pa.dictionary(pa.int8(), pa.string())) + ... ]), + ... dictionaries="infer") + + Create a Hive scheme for a path like "/year=2009/month=11": + + >>> partitioning( + ... pa.schema([("year", pa.int16()), ("month", pa.int8())]), + ... flavor="hive") + + A Hive scheme can also be discovered from the directory structure (and + types will be inferred): + + >>> partitioning(flavor="hive") + + """ + if flavor is None: + # default flavor + if schema is not None: + if field_names is not None: + raise ValueError( + "Cannot specify both 'schema' and 'field_names'") + if dictionaries == 'infer': + return DirectoryPartitioning.discover(schema=schema) + return DirectoryPartitioning(schema, dictionaries) + elif field_names is not None: + if isinstance(field_names, list): + return DirectoryPartitioning.discover(field_names) + else: + raise ValueError( + "Expected list of field names, got {}".format( + type(field_names))) + else: + raise ValueError( + "For the default directory flavor, need to specify " + "a Schema or a list of field names") + elif flavor == 'hive': + if field_names is not None: + raise ValueError("Cannot specify 'field_names' for flavor 'hive'") + elif schema is not None: + if isinstance(schema, pa.Schema): + if dictionaries == 'infer': + return HivePartitioning.discover(schema=schema) + return HivePartitioning(schema, dictionaries) + else: + raise ValueError( + "Expected Schema for 'schema', got {}".format( + type(schema))) + else: + return HivePartitioning.discover() + else: + raise ValueError("Unsupported flavor") + + +def _ensure_partitioning(scheme): + """ + Validate input and return a Partitioning(Factory). + + It passes None through if no partitioning scheme is defined. + """ + if scheme is None: + pass + elif isinstance(scheme, str): + scheme = partitioning(flavor=scheme) + elif isinstance(scheme, list): + scheme = partitioning(field_names=scheme) + elif isinstance(scheme, (Partitioning, PartitioningFactory)): + pass + else: + ValueError("Expected Partitioning or PartitioningFactory, got {}" + .format(type(scheme))) + return scheme + + +def _ensure_format(obj): + if isinstance(obj, FileFormat): + return obj + elif obj == "parquet": + return ParquetFileFormat() + elif obj in {"ipc", "arrow", "feather"}: + return IpcFileFormat() + elif obj == "csv": + return CsvFileFormat() + elif obj == "orc": + if not _orc_available: + raise ValueError(_orc_msg) + return OrcFileFormat() + else: + raise ValueError("format '{}' is not supported".format(obj)) + + +def _ensure_multiple_sources(paths, filesystem=None): + """ + Treat a list of paths as files belonging to a single file system + + If the file system is local then also validates that all paths + are referencing existing *files* otherwise any non-file paths will be + silently skipped (for example on a remote filesystem). + + Parameters + ---------- + paths : list of path-like + Note that URIs are not allowed. + filesystem : FileSystem or str, optional + If an URI is passed, then its path component will act as a prefix for + the file paths. + + Returns + ------- + (FileSystem, list of str) + File system object and a list of normalized paths. + + Raises + ------ + TypeError + If the passed filesystem has wrong type. + IOError + If the file system is local and a referenced path is not available or + not a file. + """ + from pyarrow.fs import ( + LocalFileSystem, SubTreeFileSystem, _MockFileSystem, FileType, + _ensure_filesystem + ) + + if filesystem is None: + # fall back to local file system as the default + filesystem = LocalFileSystem() + else: + # construct a filesystem if it is a valid URI + filesystem = _ensure_filesystem(filesystem) + + is_local = ( + isinstance(filesystem, (LocalFileSystem, _MockFileSystem)) or + (isinstance(filesystem, SubTreeFileSystem) and + isinstance(filesystem.base_fs, LocalFileSystem)) + ) + + # allow normalizing irregular paths such as Windows local paths + paths = [filesystem.normalize_path(_stringify_path(p)) for p in paths] + + # validate that all of the paths are pointing to existing *files* + # possible improvement is to group the file_infos by type and raise for + # multiple paths per error category + if is_local: + for info in filesystem.get_file_info(paths): + file_type = info.type + if file_type == FileType.File: + continue + elif file_type == FileType.NotFound: + raise FileNotFoundError(info.path) + elif file_type == FileType.Directory: + raise IsADirectoryError( + 'Path {} points to a directory, but only file paths are ' + 'supported. To construct a nested or union dataset pass ' + 'a list of dataset objects instead.'.format(info.path) + ) + else: + raise IOError( + 'Path {} exists but its type is unknown (could be a ' + 'special file such as a Unix socket or character device, ' + 'or Windows NUL / CON / ...)'.format(info.path) + ) + + return filesystem, paths + + +def _ensure_single_source(path, filesystem=None): + """ + Treat path as either a recursively traversable directory or a single file. + + Parameters + ---------- + path : path-like + filesystem : FileSystem or str, optional + If an URI is passed, then its path component will act as a prefix for + the file paths. + + Returns + ------- + (FileSystem, list of str or fs.Selector) + File system object and either a single item list pointing to a file or + an fs.Selector object pointing to a directory. + + Raises + ------ + TypeError + If the passed filesystem has wrong type. + FileNotFoundError + If the referenced file or directory doesn't exist. + """ + from pyarrow.fs import FileType, FileSelector, _resolve_filesystem_and_path + + # at this point we already checked that `path` is a path-like + filesystem, path = _resolve_filesystem_and_path(path, filesystem) + + # ensure that the path is normalized before passing to dataset discovery + path = filesystem.normalize_path(path) + + # retrieve the file descriptor + file_info = filesystem.get_file_info(path) + + # depending on the path type either return with a recursive + # directory selector or as a list containing a single file + if file_info.type == FileType.Directory: + paths_or_selector = FileSelector(path, recursive=True) + elif file_info.type == FileType.File: + paths_or_selector = [path] + else: + raise FileNotFoundError(path) + + return filesystem, paths_or_selector + + +def _filesystem_dataset(source, schema=None, filesystem=None, + partitioning=None, format=None, + partition_base_dir=None, exclude_invalid_files=None, + selector_ignore_prefixes=None): + """ + Create a FileSystemDataset which can be used to build a Dataset. + + Parameters are documented in the dataset function. + + Returns + ------- + FileSystemDataset + """ + format = _ensure_format(format or 'parquet') + partitioning = _ensure_partitioning(partitioning) + + if isinstance(source, (list, tuple)): + fs, paths_or_selector = _ensure_multiple_sources(source, filesystem) + else: + fs, paths_or_selector = _ensure_single_source(source, filesystem) + + options = FileSystemFactoryOptions( + partitioning=partitioning, + partition_base_dir=partition_base_dir, + exclude_invalid_files=exclude_invalid_files, + selector_ignore_prefixes=selector_ignore_prefixes + ) + factory = FileSystemDatasetFactory(fs, paths_or_selector, format, options) + + return factory.finish(schema) + + +def _in_memory_dataset(source, schema=None, **kwargs): + if any(v is not None for v in kwargs.values()): + raise ValueError( + "For in-memory datasets, you cannot pass any additional arguments") + return InMemoryDataset(source, schema) + + +def _union_dataset(children, schema=None, **kwargs): + if any(v is not None for v in kwargs.values()): + raise ValueError( + "When passing a list of Datasets, you cannot pass any additional " + "arguments" + ) + + if schema is None: + # unify the children datasets' schemas + schema = pa.unify_schemas([child.schema for child in children]) + + # create datasets with the requested schema + children = [child.replace_schema(schema) for child in children] + + return UnionDataset(schema, children) + + +def parquet_dataset(metadata_path, schema=None, filesystem=None, format=None, + partitioning=None, partition_base_dir=None): + """ + Create a FileSystemDataset from a `_metadata` file created via + `pyarrrow.parquet.write_metadata`. + + Parameters + ---------- + metadata_path : path, + Path pointing to a single file parquet metadata file + schema : Schema, optional + Optionally provide the Schema for the Dataset, in which case it will + not be inferred from the source. + filesystem : FileSystem or URI string, default None + If a single path is given as source and filesystem is None, then the + filesystem will be inferred from the path. + If an URI string is passed, then a filesystem object is constructed + using the URI's optional path component as a directory prefix. See the + examples below. + Note that the URIs on Windows must follow 'file:///C:...' or + 'file:/C:...' patterns. + format : ParquetFileFormat + An instance of a ParquetFileFormat if special options needs to be + passed. + partitioning : Partitioning, PartitioningFactory, str, list of str + The partitioning scheme specified with the ``partitioning()`` + function. A flavor string can be used as shortcut, and with a list of + field names a DirectionaryPartitioning will be inferred. + partition_base_dir : str, optional + For the purposes of applying the partitioning, paths will be + stripped of the partition_base_dir. Files not matching the + partition_base_dir prefix will be skipped for partitioning discovery. + The ignored files will still be part of the Dataset, but will not + have partition information. + + Returns + ------- + FileSystemDataset + """ + from pyarrow.fs import LocalFileSystem, _ensure_filesystem + + if format is None: + format = ParquetFileFormat() + elif not isinstance(format, ParquetFileFormat): + raise ValueError("format argument must be a ParquetFileFormat") + + if filesystem is None: + filesystem = LocalFileSystem() + else: + filesystem = _ensure_filesystem(filesystem) + + metadata_path = filesystem.normalize_path(_stringify_path(metadata_path)) + options = ParquetFactoryOptions( + partition_base_dir=partition_base_dir, + partitioning=_ensure_partitioning(partitioning) + ) + + factory = ParquetDatasetFactory( + metadata_path, filesystem, format, options=options) + return factory.finish(schema) + + +def dataset(source, schema=None, format=None, filesystem=None, + partitioning=None, partition_base_dir=None, + exclude_invalid_files=None, ignore_prefixes=None): + """ + Open a dataset. + + Datasets provides functionality to efficiently work with tabular, + potentially larger than memory and multi-file dataset. + + - A unified interface for different sources, like Parquet and Feather + - Discovery of sources (crawling directories, handle directory-based + partitioned datasets, basic schema normalization) + - Optimized reading with predicate pushdown (filtering rows), projection + (selecting columns), parallel reading or fine-grained managing of tasks. + + Note that this is the high-level API, to have more control over the dataset + construction use the low-level API classes (FileSystemDataset, + FilesystemDatasetFactory, etc.) + + Parameters + ---------- + source : path, list of paths, dataset, list of datasets, (list of) batches\ +or tables, iterable of batches, RecordBatchReader, or URI + Path pointing to a single file: + Open a FileSystemDataset from a single file. + Path pointing to a directory: + The directory gets discovered recursively according to a + partitioning scheme if given. + List of file paths: + Create a FileSystemDataset from explicitly given files. The files + must be located on the same filesystem given by the filesystem + parameter. + Note that in contrary of construction from a single file, passing + URIs as paths is not allowed. + List of datasets: + A nested UnionDataset gets constructed, it allows arbitrary + composition of other datasets. + Note that additional keyword arguments are not allowed. + (List of) batches or tables, iterable of batches, or RecordBatchReader: + Create an InMemoryDataset. If an iterable or empty list is given, + a schema must also be given. If an iterable or RecordBatchReader + is given, the resulting dataset can only be scanned once; further + attempts will raise an error. + schema : Schema, optional + Optionally provide the Schema for the Dataset, in which case it will + not be inferred from the source. + format : FileFormat or str + Currently "parquet" and "ipc"/"arrow"/"feather" are supported. For + Feather, only version 2 files are supported. + filesystem : FileSystem or URI string, default None + If a single path is given as source and filesystem is None, then the + filesystem will be inferred from the path. + If an URI string is passed, then a filesystem object is constructed + using the URI's optional path component as a directory prefix. See the + examples below. + Note that the URIs on Windows must follow 'file:///C:...' or + 'file:/C:...' patterns. + partitioning : Partitioning, PartitioningFactory, str, list of str + The partitioning scheme specified with the ``partitioning()`` + function. A flavor string can be used as shortcut, and with a list of + field names a DirectionaryPartitioning will be inferred. + partition_base_dir : str, optional + For the purposes of applying the partitioning, paths will be + stripped of the partition_base_dir. Files not matching the + partition_base_dir prefix will be skipped for partitioning discovery. + The ignored files will still be part of the Dataset, but will not + have partition information. + exclude_invalid_files : bool, optional (default True) + If True, invalid files will be excluded (file format specific check). + This will incur IO for each files in a serial and single threaded + fashion. Disabling this feature will skip the IO, but unsupported + files may be present in the Dataset (resulting in an error at scan + time). + ignore_prefixes : list, optional + Files matching any of these prefixes will be ignored by the + discovery process. This is matched to the basename of a path. + By default this is ['.', '_']. + Note that discovery happens only if a directory is passed as source. + + Returns + ------- + dataset : Dataset + Either a FileSystemDataset or a UnionDataset depending on the source + parameter. + + Examples + -------- + Opening a single file: + + >>> dataset("path/to/file.parquet", format="parquet") + + Opening a single file with an explicit schema: + + >>> dataset("path/to/file.parquet", schema=myschema, format="parquet") + + Opening a dataset for a single directory: + + >>> dataset("path/to/nyc-taxi/", format="parquet") + >>> dataset("s3://mybucket/nyc-taxi/", format="parquet") + + Opening a dataset from a list of relatives local paths: + + >>> dataset([ + ... "part0/data.parquet", + ... "part1/data.parquet", + ... "part3/data.parquet", + ... ], format='parquet') + + With filesystem provided: + + >>> paths = [ + ... 'part0/data.parquet', + ... 'part1/data.parquet', + ... 'part3/data.parquet', + ... ] + >>> dataset(paths, filesystem='file:///directory/prefix, format='parquet') + + Which is equivalent with: + + >>> fs = SubTreeFileSystem("/directory/prefix", LocalFileSystem()) + >>> dataset(paths, filesystem=fs, format='parquet') + + With a remote filesystem URI: + + >>> paths = [ + ... 'nested/directory/part0/data.parquet', + ... 'nested/directory/part1/data.parquet', + ... 'nested/directory/part3/data.parquet', + ... ] + >>> dataset(paths, filesystem='s3://bucket/', format='parquet') + + Similarly to the local example, the directory prefix may be included in the + filesystem URI: + + >>> dataset(paths, filesystem='s3://bucket/nested/directory', + ... format='parquet') + + Construction of a nested dataset: + + >>> dataset([ + ... dataset("s3://old-taxi-data", format="parquet"), + ... dataset("local/path/to/data", format="ipc") + ... ]) + """ + # collect the keyword arguments for later reuse + kwargs = dict( + schema=schema, + filesystem=filesystem, + partitioning=partitioning, + format=format, + partition_base_dir=partition_base_dir, + exclude_invalid_files=exclude_invalid_files, + selector_ignore_prefixes=ignore_prefixes + ) + + if _is_path_like(source): + return _filesystem_dataset(source, **kwargs) + elif isinstance(source, (tuple, list)): + if all(_is_path_like(elem) for elem in source): + return _filesystem_dataset(source, **kwargs) + elif all(isinstance(elem, Dataset) for elem in source): + return _union_dataset(source, **kwargs) + elif all(isinstance(elem, (pa.RecordBatch, pa.Table)) + for elem in source): + return _in_memory_dataset(source, **kwargs) + else: + unique_types = set(type(elem).__name__ for elem in source) + type_names = ', '.join('{}'.format(t) for t in unique_types) + raise TypeError( + 'Expected a list of path-like or dataset objects, or a list ' + 'of batches or tables. The given list contains the following ' + 'types: {}'.format(type_names) + ) + elif isinstance(source, (pa.RecordBatch, pa.Table)): + return _in_memory_dataset(source, **kwargs) + else: + raise TypeError( + 'Expected a path-like, list of path-likes or a list of Datasets ' + 'instead of the given type: {}'.format(type(source).__name__) + ) + + +def _ensure_write_partitioning(part, schema, flavor): + if isinstance(part, PartitioningFactory): + raise ValueError("A PartitioningFactory cannot be used. " + "Did you call the partitioning function " + "without supplying a schema?") + + if isinstance(part, Partitioning) and flavor: + raise ValueError( + "Providing a partitioning_flavor with " + "a Partitioning object is not supported" + ) + elif isinstance(part, (tuple, list)): + # Name of fields were provided instead of a partitioning object. + # Create a partitioning factory with those field names. + part = partitioning( + schema=pa.schema([schema.field(f) for f in part]), + flavor=flavor + ) + elif part is None: + part = partitioning(pa.schema([]), flavor=flavor) + + if not isinstance(part, Partitioning): + raise ValueError( + "partitioning must be a Partitioning object or " + "a list of column names" + ) + + return part + + +def write_dataset(data, base_dir, basename_template=None, format=None, + partitioning=None, partitioning_flavor=None, schema=None, + filesystem=None, file_options=None, use_threads=True, + max_partitions=None, file_visitor=None, + existing_data_behavior='error'): + """ + Write a dataset to a given format and partitioning. + + Parameters + ---------- + data : Dataset, Table/RecordBatch, RecordBatchReader, list of + Table/RecordBatch, or iterable of RecordBatch + The data to write. This can be a Dataset instance or + in-memory Arrow data. If an iterable is given, the schema must + also be given. + base_dir : str + The root directory where to write the dataset. + basename_template : str, optional + A template string used to generate basenames of written data files. + The token '{i}' will be replaced with an automatically incremented + integer. If not specified, it defaults to + "part-{i}." + format.default_extname + format : FileFormat or str + The format in which to write the dataset. Currently supported: + "parquet", "ipc"/"feather". If a FileSystemDataset is being written + and `format` is not specified, it defaults to the same format as the + specified FileSystemDataset. When writing a Table or RecordBatch, this + keyword is required. + partitioning : Partitioning or list[str], optional + The partitioning scheme specified with the ``partitioning()`` + function or a list of field names. When providing a list of + field names, you can use ``partitioning_flavor`` to drive which + partitioning type should be used. + partitioning_flavor : str, optional + One of the partitioning flavors supported by + ``pyarrow.dataset.partitioning``. If omitted will use the + default of ``partitioning()`` which is directory partitioning. + schema : Schema, optional + filesystem : FileSystem, optional + file_options : FileWriteOptions, optional + FileFormat specific write options, created using the + ``FileFormat.make_write_options()`` function. + use_threads : bool, default True + Write files in parallel. If enabled, then maximum parallelism will be + used determined by the number of available CPU cores. + max_partitions : int, default 1024 + Maximum number of partitions any batch may be written into. + file_visitor : Function + If set, this function will be called with a WrittenFile instance + for each file created during the call. This object will have both + a path attribute and a metadata attribute. + + The path attribute will be a string containing the path to + the created file. + + The metadata attribute will be the parquet metadata of the file. + This metadata will have the file path attribute set and can be used + to build a _metadata file. The metadata attribute will be None if + the format is not parquet. + + Example visitor which simple collects the filenames created:: + + visited_paths = [] + + def file_visitor(written_file): + visited_paths.append(written_file.path) + existing_data_behavior : 'error' | 'overwrite_or_ignore' | \ +'delete_matching' + Controls how the dataset will handle data that already exists in + the destination. The default behavior ('error') is to raise an error + if any data exists in the destination. + + 'overwrite_or_ignore' will ignore any existing data and will + overwrite files with the same name as an output file. Other + existing files will be ignored. This behavior, in combination + with a unique basename_template for each write, will allow for + an append workflow. + + 'delete_matching' is useful when you are writing a partitioned + dataset. The first time each partition directory is encountered + the entire directory will be deleted. This allows you to overwrite + old partitions completely. + """ + from pyarrow.fs import _resolve_filesystem_and_path + + if isinstance(data, (list, tuple)): + schema = schema or data[0].schema + data = InMemoryDataset(data, schema=schema) + elif isinstance(data, (pa.RecordBatch, pa.Table)): + schema = schema or data.schema + data = InMemoryDataset(data, schema=schema) + elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data): + data = Scanner.from_batches(data, schema=schema, use_async=True) + schema = None + elif not isinstance(data, (Dataset, Scanner)): + raise ValueError( + "Only Dataset, Scanner, Table/RecordBatch, RecordBatchReader, " + "a list of Tables/RecordBatches, or iterable of batches are " + "supported." + ) + + if format is None and isinstance(data, FileSystemDataset): + format = data.format + else: + format = _ensure_format(format) + + if file_options is None: + file_options = format.make_write_options() + + if format != file_options.format: + raise TypeError("Supplied FileWriteOptions have format {}, " + "which doesn't match supplied FileFormat {}".format( + format, file_options)) + + if basename_template is None: + basename_template = "part-{i}." + format.default_extname + + if max_partitions is None: + max_partitions = 1024 + + # at this point data is a Scanner or a Dataset, anything else + # was converted to one of those two. So we can grab the schema + # to build the partitioning object from Dataset. + if isinstance(data, Scanner): + partitioning_schema = data.dataset_schema + else: + partitioning_schema = data.schema + partitioning = _ensure_write_partitioning(partitioning, + schema=partitioning_schema, + flavor=partitioning_flavor) + + filesystem, base_dir = _resolve_filesystem_and_path(base_dir, filesystem) + + if isinstance(data, Dataset): + scanner = data.scanner(use_threads=use_threads, use_async=True) + else: + # scanner was passed directly by the user, in which case a schema + # cannot be passed + if schema is not None: + raise ValueError("Cannot specify a schema when writing a Scanner") + scanner = data + + _filesystemdataset_write( + scanner, base_dir, basename_template, filesystem, partitioning, + file_options, max_partitions, file_visitor, existing_data_behavior + ) diff --git a/src/arrow/python/pyarrow/error.pxi b/src/arrow/python/pyarrow/error.pxi new file mode 100644 index 000000000..233b4fb16 --- /dev/null +++ b/src/arrow/python/pyarrow/error.pxi @@ -0,0 +1,242 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from cpython.exc cimport PyErr_CheckSignals, PyErr_SetInterrupt + +from pyarrow.includes.libarrow cimport CStatus, IsPyError, RestorePyError +from pyarrow.includes.common cimport c_string + +from contextlib import contextmanager +import os +import signal +import threading + +from pyarrow.util import _break_traceback_cycle_from_frame + + +class ArrowException(Exception): + pass + + +class ArrowInvalid(ValueError, ArrowException): + pass + + +class ArrowMemoryError(MemoryError, ArrowException): + pass + + +class ArrowKeyError(KeyError, ArrowException): + def __str__(self): + # Override KeyError.__str__, as it uses the repr() of the key + return ArrowException.__str__(self) + + +class ArrowTypeError(TypeError, ArrowException): + pass + + +class ArrowNotImplementedError(NotImplementedError, ArrowException): + pass + + +class ArrowCapacityError(ArrowException): + pass + + +class ArrowIndexError(IndexError, ArrowException): + pass + + +class ArrowSerializationError(ArrowException): + pass + + +class ArrowCancelled(ArrowException): + def __init__(self, message, signum=None): + super().__init__(message) + self.signum = signum + + +# Compatibility alias +ArrowIOError = IOError + + +# This function could be written directly in C++ if we didn't +# define Arrow-specific subclasses (ArrowInvalid etc.) +cdef int check_status(const CStatus& status) nogil except -1: + if status.ok(): + return 0 + + with gil: + if IsPyError(status): + RestorePyError(status) + return -1 + + # We don't use Status::ToString() as it would redundantly include + # the C++ class name. + message = frombytes(status.message(), safe=True) + detail = status.detail() + if detail != nullptr: + message += ". Detail: " + frombytes(detail.get().ToString(), + safe=True) + + if status.IsInvalid(): + raise ArrowInvalid(message) + elif status.IsIOError(): + # Note: OSError constructor is + # OSError(message) + # or + # OSError(errno, message, filename=None) + # or (on Windows) + # OSError(errno, message, filename, winerror) + errno = ErrnoFromStatus(status) + winerror = WinErrorFromStatus(status) + if winerror != 0: + raise IOError(errno, message, None, winerror) + elif errno != 0: + raise IOError(errno, message) + else: + raise IOError(message) + elif status.IsOutOfMemory(): + raise ArrowMemoryError(message) + elif status.IsKeyError(): + raise ArrowKeyError(message) + elif status.IsNotImplemented(): + raise ArrowNotImplementedError(message) + elif status.IsTypeError(): + raise ArrowTypeError(message) + elif status.IsCapacityError(): + raise ArrowCapacityError(message) + elif status.IsIndexError(): + raise ArrowIndexError(message) + elif status.IsSerializationError(): + raise ArrowSerializationError(message) + elif status.IsCancelled(): + signum = SignalFromStatus(status) + if signum > 0: + raise ArrowCancelled(message, signum) + else: + raise ArrowCancelled(message) + else: + message = frombytes(status.ToString(), safe=True) + raise ArrowException(message) + + +# This is an API function for C++ PyArrow +cdef api int pyarrow_internal_check_status(const CStatus& status) \ + nogil except -1: + return check_status(status) + + +cdef class StopToken: + cdef void init(self, CStopToken stop_token): + self.stop_token = move(stop_token) + + +cdef c_bool signal_handlers_enabled = True + + +def enable_signal_handlers(c_bool enable): + """ + Enable or disable interruption of long-running operations. + + By default, certain long running operations will detect user + interruptions, such as by pressing Ctrl-C. This detection relies + on setting a signal handler for the duration of the long-running + operation, and may therefore interfere with other frameworks or + libraries (such as an event loop). + + Parameters + ---------- + enable : bool + Whether to enable user interruption by setting a temporary + signal handler. + """ + global signal_handlers_enabled + signal_handlers_enabled = enable + + +# For internal use + +# Whether we need a workaround for https://bugs.python.org/issue42248 +have_signal_refcycle = (sys.version_info < (3, 8, 10) or + (3, 9) <= sys.version_info < (3, 9, 5) or + sys.version_info[:2] == (3, 10)) + +cdef class SignalStopHandler: + cdef: + StopToken _stop_token + vector[int] _signals + c_bool _enabled + + def __cinit__(self): + self._enabled = False + + self._init_signals() + if have_signal_refcycle: + _break_traceback_cycle_from_frame(sys._getframe(0)) + + self._stop_token = StopToken() + if not self._signals.empty(): + self._stop_token.init(GetResultValue( + SetSignalStopSource()).token()) + self._enabled = True + + def _init_signals(self): + if (signal_handlers_enabled and + threading.current_thread() is threading.main_thread()): + self._signals = [ + sig for sig in (signal.SIGINT, signal.SIGTERM) + if signal.getsignal(sig) not in (signal.SIG_DFL, + signal.SIG_IGN, None)] + + def __enter__(self): + if self._enabled: + check_status(RegisterCancellingSignalHandler(self._signals)) + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + if self._enabled: + UnregisterCancellingSignalHandler() + if isinstance(exc_value, ArrowCancelled): + if exc_value.signum: + # Re-emit the exact same signal. We restored the Python signal + # handler above, so it should receive it. + if os.name == 'nt': + SendSignal(exc_value.signum) + else: + SendSignalToThread(exc_value.signum, threading.get_ident()) + else: + # Simulate Python receiving a SIGINT + # (see https://bugs.python.org/issue43356 for why we can't + # simulate the exact signal number) + PyErr_SetInterrupt() + # Maximize chances of the Python signal handler being executed now. + # Otherwise a potential KeyboardInterrupt might be missed by an + # immediately enclosing try/except block. + PyErr_CheckSignals() + # ArrowCancelled will be re-raised if PyErr_CheckSignals() + # returned successfully. + + def __dealloc__(self): + if self._enabled: + ResetSignalStopSource() + + @property + def stop_token(self): + return self._stop_token diff --git a/src/arrow/python/pyarrow/feather.py b/src/arrow/python/pyarrow/feather.py new file mode 100644 index 000000000..2170a93c3 --- /dev/null +++ b/src/arrow/python/pyarrow/feather.py @@ -0,0 +1,265 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import os + +from pyarrow.pandas_compat import _pandas_api # noqa +from pyarrow.lib import (Codec, Table, # noqa + concat_tables, schema) +import pyarrow.lib as ext +from pyarrow import _feather +from pyarrow._feather import FeatherError # noqa: F401 +from pyarrow.vendored.version import Version + + +def _check_pandas_version(): + if _pandas_api.loose_version < Version('0.17.0'): + raise ImportError("feather requires pandas >= 0.17.0") + + +class FeatherDataset: + """ + Encapsulates details of reading a list of Feather files. + + Parameters + ---------- + path_or_paths : List[str] + A list of file names + validate_schema : bool, default True + Check that individual file schemas are all the same / compatible + """ + + def __init__(self, path_or_paths, validate_schema=True): + self.paths = path_or_paths + self.validate_schema = validate_schema + + def read_table(self, columns=None): + """ + Read multiple feather files as a single pyarrow.Table + + Parameters + ---------- + columns : List[str] + Names of columns to read from the file + + Returns + ------- + pyarrow.Table + Content of the file as a table (of columns) + """ + _fil = read_table(self.paths[0], columns=columns) + self._tables = [_fil] + self.schema = _fil.schema + + for path in self.paths[1:]: + table = read_table(path, columns=columns) + if self.validate_schema: + self.validate_schemas(path, table) + self._tables.append(table) + return concat_tables(self._tables) + + def validate_schemas(self, piece, table): + if not self.schema.equals(table.schema): + raise ValueError('Schema in {!s} was different. \n' + '{!s}\n\nvs\n\n{!s}' + .format(piece, self.schema, + table.schema)) + + def read_pandas(self, columns=None, use_threads=True): + """ + Read multiple Parquet files as a single pandas DataFrame + + Parameters + ---------- + columns : List[str] + Names of columns to read from the file + use_threads : bool, default True + Use multiple threads when converting to pandas + + Returns + ------- + pandas.DataFrame + Content of the file as a pandas DataFrame (of columns) + """ + _check_pandas_version() + return self.read_table(columns=columns).to_pandas( + use_threads=use_threads) + + +def check_chunked_overflow(name, col): + if col.num_chunks == 1: + return + + if col.type in (ext.binary(), ext.string()): + raise ValueError("Column '{}' exceeds 2GB maximum capacity of " + "a Feather binary column. This restriction may be " + "lifted in the future".format(name)) + else: + # TODO(wesm): Not sure when else this might be reached + raise ValueError("Column '{}' of type {} was chunked on conversion " + "to Arrow and cannot be currently written to " + "Feather format".format(name, str(col.type))) + + +_FEATHER_SUPPORTED_CODECS = {'lz4', 'zstd', 'uncompressed'} + + +def write_feather(df, dest, compression=None, compression_level=None, + chunksize=None, version=2): + """ + Write a pandas.DataFrame to Feather format. + + Parameters + ---------- + df : pandas.DataFrame or pyarrow.Table + Data to write out as Feather format. + dest : str + Local destination path. + compression : string, default None + Can be one of {"zstd", "lz4", "uncompressed"}. The default of None uses + LZ4 for V2 files if it is available, otherwise uncompressed. + compression_level : int, default None + Use a compression level particular to the chosen compressor. If None + use the default compression level + chunksize : int, default None + For V2 files, the internal maximum size of Arrow RecordBatch chunks + when writing the Arrow IPC file format. None means use the default, + which is currently 64K + version : int, default 2 + Feather file version. Version 2 is the current. Version 1 is the more + limited legacy format + """ + if _pandas_api.have_pandas: + _check_pandas_version() + if (_pandas_api.has_sparse and + isinstance(df, _pandas_api.pd.SparseDataFrame)): + df = df.to_dense() + + if _pandas_api.is_data_frame(df): + table = Table.from_pandas(df, preserve_index=False) + + if version == 1: + # Version 1 does not chunking + for i, name in enumerate(table.schema.names): + col = table[i] + check_chunked_overflow(name, col) + else: + table = df + + if version == 1: + if len(table.column_names) > len(set(table.column_names)): + raise ValueError("cannot serialize duplicate column names") + + if compression is not None: + raise ValueError("Feather V1 files do not support compression " + "option") + + if chunksize is not None: + raise ValueError("Feather V1 files do not support chunksize " + "option") + else: + if compression is None and Codec.is_available('lz4_frame'): + compression = 'lz4' + elif (compression is not None and + compression not in _FEATHER_SUPPORTED_CODECS): + raise ValueError('compression="{}" not supported, must be ' + 'one of {}'.format(compression, + _FEATHER_SUPPORTED_CODECS)) + + try: + _feather.write_feather(table, dest, compression=compression, + compression_level=compression_level, + chunksize=chunksize, version=version) + except Exception: + if isinstance(dest, str): + try: + os.remove(dest) + except os.error: + pass + raise + + +def read_feather(source, columns=None, use_threads=True, memory_map=True): + """ + Read a pandas.DataFrame from Feather format. To read as pyarrow.Table use + feather.read_table. + + Parameters + ---------- + source : str file path, or file-like object + columns : sequence, optional + Only read a specific set of columns. If not provided, all columns are + read. + use_threads : bool, default True + Whether to parallelize reading using multiple threads. If false the + restriction is only used in the conversion to Pandas and not in the + reading from Feather format. + memory_map : boolean, default True + Use memory mapping when opening file on disk + + Returns + ------- + df : pandas.DataFrame + """ + _check_pandas_version() + return (read_table(source, columns=columns, memory_map=memory_map) + .to_pandas(use_threads=use_threads)) + + +def read_table(source, columns=None, memory_map=True): + """ + Read a pyarrow.Table from Feather format + + Parameters + ---------- + source : str file path, or file-like object + columns : sequence, optional + Only read a specific set of columns. If not provided, all columns are + read. + memory_map : boolean, default True + Use memory mapping when opening file on disk + + Returns + ------- + table : pyarrow.Table + """ + reader = _feather.FeatherReader(source, use_memory_map=memory_map) + + if columns is None: + return reader.read() + + column_types = [type(column) for column in columns] + if all(map(lambda t: t == int, column_types)): + table = reader.read_indices(columns) + elif all(map(lambda t: t == str, column_types)): + table = reader.read_names(columns) + else: + column_type_names = [t.__name__ for t in column_types] + raise TypeError("Columns must be indices or names. " + "Got columns {} of types {}" + .format(columns, column_type_names)) + + # Feather v1 already respects the column selection + if reader.version < 3: + return table + # Feather v2 reads with sorted / deduplicated selection + elif sorted(set(columns)) == columns: + return table + else: + # follow exact order / selection of names + return table.select(columns) diff --git a/src/arrow/python/pyarrow/filesystem.py b/src/arrow/python/pyarrow/filesystem.py new file mode 100644 index 000000000..c2017e42b --- /dev/null +++ b/src/arrow/python/pyarrow/filesystem.py @@ -0,0 +1,511 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import os +import posixpath +import sys +import urllib.parse +import warnings + +from os.path import join as pjoin + +import pyarrow as pa +from pyarrow.util import implements, _stringify_path, _is_path_like, _DEPR_MSG + + +_FS_DEPR_MSG = _DEPR_MSG.format( + "filesystem.LocalFileSystem", "2.0.0", "fs.LocalFileSystem" +) + + +class FileSystem: + """ + Abstract filesystem interface. + """ + + def cat(self, path): + """ + Return contents of file as a bytes object. + + Parameters + ---------- + path : str + File path to read content from. + + Returns + ------- + contents : bytes + """ + with self.open(path, 'rb') as f: + return f.read() + + def ls(self, path): + """ + Return list of file paths. + + Parameters + ---------- + path : str + Directory to list contents from. + """ + raise NotImplementedError + + def delete(self, path, recursive=False): + """ + Delete the indicated file or directory. + + Parameters + ---------- + path : str + Path to delete. + recursive : bool, default False + If True, also delete child paths for directories. + """ + raise NotImplementedError + + def disk_usage(self, path): + """ + Compute bytes used by all contents under indicated path in file tree. + + Parameters + ---------- + path : str + Can be a file path or directory. + + Returns + ------- + usage : int + """ + path = _stringify_path(path) + path_info = self.stat(path) + if path_info['kind'] == 'file': + return path_info['size'] + + total = 0 + for root, directories, files in self.walk(path): + for child_path in files: + abspath = self._path_join(root, child_path) + total += self.stat(abspath)['size'] + + return total + + def _path_join(self, *args): + return self.pathsep.join(args) + + def stat(self, path): + """ + Information about a filesystem entry. + + Returns + ------- + stat : dict + """ + raise NotImplementedError('FileSystem.stat') + + def rm(self, path, recursive=False): + """ + Alias for FileSystem.delete. + """ + return self.delete(path, recursive=recursive) + + def mv(self, path, new_path): + """ + Alias for FileSystem.rename. + """ + return self.rename(path, new_path) + + def rename(self, path, new_path): + """ + Rename file, like UNIX mv command. + + Parameters + ---------- + path : str + Path to alter. + new_path : str + Path to move to. + """ + raise NotImplementedError('FileSystem.rename') + + def mkdir(self, path, create_parents=True): + """ + Create a directory. + + Parameters + ---------- + path : str + Path to the directory. + create_parents : bool, default True + If the parent directories don't exists create them as well. + """ + raise NotImplementedError + + def exists(self, path): + """ + Return True if path exists. + + Parameters + ---------- + path : str + Path to check. + """ + raise NotImplementedError + + def isdir(self, path): + """ + Return True if path is a directory. + + Parameters + ---------- + path : str + Path to check. + """ + raise NotImplementedError + + def isfile(self, path): + """ + Return True if path is a file. + + Parameters + ---------- + path : str + Path to check. + """ + raise NotImplementedError + + def _isfilestore(self): + """ + Returns True if this FileSystem is a unix-style file store with + directories. + """ + raise NotImplementedError + + def read_parquet(self, path, columns=None, metadata=None, schema=None, + use_threads=True, use_pandas_metadata=False): + """ + Read Parquet data from path in file system. Can read from a single file + or a directory of files. + + Parameters + ---------- + path : str + Single file path or directory + columns : List[str], optional + Subset of columns to read. + metadata : pyarrow.parquet.FileMetaData + Known metadata to validate files against. + schema : pyarrow.parquet.Schema + Known schema to validate files against. Alternative to metadata + argument. + use_threads : bool, default True + Perform multi-threaded column reads. + use_pandas_metadata : bool, default False + If True and file has custom pandas schema metadata, ensure that + index columns are also loaded. + + Returns + ------- + table : pyarrow.Table + """ + from pyarrow.parquet import ParquetDataset + dataset = ParquetDataset(path, schema=schema, metadata=metadata, + filesystem=self) + return dataset.read(columns=columns, use_threads=use_threads, + use_pandas_metadata=use_pandas_metadata) + + def open(self, path, mode='rb'): + """ + Open file for reading or writing. + """ + raise NotImplementedError + + @property + def pathsep(self): + return '/' + + +class LocalFileSystem(FileSystem): + + _instance = None + + def __init__(self): + warnings.warn(_FS_DEPR_MSG, FutureWarning, stacklevel=2) + super().__init__() + + @classmethod + def _get_instance(cls): + if cls._instance is None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + cls._instance = LocalFileSystem() + return cls._instance + + @classmethod + def get_instance(cls): + warnings.warn(_FS_DEPR_MSG, FutureWarning, stacklevel=2) + return cls._get_instance() + + @implements(FileSystem.ls) + def ls(self, path): + path = _stringify_path(path) + return sorted(pjoin(path, x) for x in os.listdir(path)) + + @implements(FileSystem.mkdir) + def mkdir(self, path, create_parents=True): + path = _stringify_path(path) + if create_parents: + os.makedirs(path) + else: + os.mkdir(path) + + @implements(FileSystem.isdir) + def isdir(self, path): + path = _stringify_path(path) + return os.path.isdir(path) + + @implements(FileSystem.isfile) + def isfile(self, path): + path = _stringify_path(path) + return os.path.isfile(path) + + @implements(FileSystem._isfilestore) + def _isfilestore(self): + return True + + @implements(FileSystem.exists) + def exists(self, path): + path = _stringify_path(path) + return os.path.exists(path) + + @implements(FileSystem.open) + def open(self, path, mode='rb'): + """ + Open file for reading or writing. + """ + path = _stringify_path(path) + return open(path, mode=mode) + + @property + def pathsep(self): + return os.path.sep + + def walk(self, path): + """ + Directory tree generator, see os.walk. + """ + path = _stringify_path(path) + return os.walk(path) + + +class DaskFileSystem(FileSystem): + """ + Wraps s3fs Dask filesystem implementation like s3fs, gcsfs, etc. + """ + + def __init__(self, fs): + warnings.warn( + "The pyarrow.filesystem.DaskFileSystem/S3FSWrapper are deprecated " + "as of pyarrow 3.0.0, and will be removed in a future version.", + FutureWarning, stacklevel=2) + self.fs = fs + + @implements(FileSystem.isdir) + def isdir(self, path): + raise NotImplementedError("Unsupported file system API") + + @implements(FileSystem.isfile) + def isfile(self, path): + raise NotImplementedError("Unsupported file system API") + + @implements(FileSystem._isfilestore) + def _isfilestore(self): + """ + Object Stores like S3 and GCSFS are based on key lookups, not true + file-paths. + """ + return False + + @implements(FileSystem.delete) + def delete(self, path, recursive=False): + path = _stringify_path(path) + return self.fs.rm(path, recursive=recursive) + + @implements(FileSystem.exists) + def exists(self, path): + path = _stringify_path(path) + return self.fs.exists(path) + + @implements(FileSystem.mkdir) + def mkdir(self, path, create_parents=True): + path = _stringify_path(path) + if create_parents: + return self.fs.mkdirs(path) + else: + return self.fs.mkdir(path) + + @implements(FileSystem.open) + def open(self, path, mode='rb'): + """ + Open file for reading or writing. + """ + path = _stringify_path(path) + return self.fs.open(path, mode=mode) + + def ls(self, path, detail=False): + path = _stringify_path(path) + return self.fs.ls(path, detail=detail) + + def walk(self, path): + """ + Directory tree generator, like os.walk. + """ + path = _stringify_path(path) + return self.fs.walk(path) + + +class S3FSWrapper(DaskFileSystem): + + @implements(FileSystem.isdir) + def isdir(self, path): + path = _sanitize_s3(_stringify_path(path)) + try: + contents = self.fs.ls(path) + if len(contents) == 1 and contents[0] == path: + return False + else: + return True + except OSError: + return False + + @implements(FileSystem.isfile) + def isfile(self, path): + path = _sanitize_s3(_stringify_path(path)) + try: + contents = self.fs.ls(path) + return len(contents) == 1 and contents[0] == path + except OSError: + return False + + def walk(self, path, refresh=False): + """ + Directory tree generator, like os.walk. + + Generator version of what is in s3fs, which yields a flattened list of + files. + """ + path = _sanitize_s3(_stringify_path(path)) + directories = set() + files = set() + + for key in list(self.fs._ls(path, refresh=refresh)): + path = key['Key'] + if key['StorageClass'] == 'DIRECTORY': + directories.add(path) + elif key['StorageClass'] == 'BUCKET': + pass + else: + files.add(path) + + # s3fs creates duplicate 'DIRECTORY' entries + files = sorted([posixpath.split(f)[1] for f in files + if f not in directories]) + directories = sorted([posixpath.split(x)[1] + for x in directories]) + + yield path, directories, files + + for directory in directories: + yield from self.walk(directory, refresh=refresh) + + +def _sanitize_s3(path): + if path.startswith('s3://'): + return path.replace('s3://', '') + else: + return path + + +def _ensure_filesystem(fs): + fs_type = type(fs) + + # If the arrow filesystem was subclassed, assume it supports the full + # interface and return it + if not issubclass(fs_type, FileSystem): + if "fsspec" in sys.modules: + fsspec = sys.modules["fsspec"] + if isinstance(fs, fsspec.AbstractFileSystem): + # for recent fsspec versions that stop inheriting from + # pyarrow.filesystem.FileSystem, still allow fsspec + # filesystems (which should be compatible with our legacy fs) + return fs + + raise OSError('Unrecognized filesystem: {}'.format(fs_type)) + else: + return fs + + +def resolve_filesystem_and_path(where, filesystem=None): + """ + Return filesystem from path which could be an HDFS URI, a local URI, + or a plain filesystem path. + """ + if not _is_path_like(where): + if filesystem is not None: + raise ValueError("filesystem passed but where is file-like, so" + " there is nothing to open with filesystem.") + return filesystem, where + + if filesystem is not None: + filesystem = _ensure_filesystem(filesystem) + if isinstance(filesystem, LocalFileSystem): + path = _stringify_path(where) + elif not isinstance(where, str): + raise TypeError( + "Expected string path; path-like objects are only allowed " + "with a local filesystem" + ) + else: + path = where + return filesystem, path + + path = _stringify_path(where) + + parsed_uri = urllib.parse.urlparse(path) + if parsed_uri.scheme == 'hdfs' or parsed_uri.scheme == 'viewfs': + # Input is hdfs URI such as hdfs://host:port/myfile.parquet + netloc_split = parsed_uri.netloc.split(':') + host = netloc_split[0] + if host == '': + host = 'default' + else: + host = parsed_uri.scheme + "://" + host + port = 0 + if len(netloc_split) == 2 and netloc_split[1].isnumeric(): + port = int(netloc_split[1]) + fs = pa.hdfs._connect(host=host, port=port) + fs_path = parsed_uri.path + elif parsed_uri.scheme == 'file': + # Input is local URI such as file:///home/user/myfile.parquet + fs = LocalFileSystem._get_instance() + fs_path = parsed_uri.path + else: + # Input is local path such as /home/user/myfile.parquet + fs = LocalFileSystem._get_instance() + fs_path = path + + return fs, fs_path diff --git a/src/arrow/python/pyarrow/flight.py b/src/arrow/python/pyarrow/flight.py new file mode 100644 index 000000000..0664ff2c9 --- /dev/null +++ b/src/arrow/python/pyarrow/flight.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyarrow._flight import ( # noqa:F401 + connect, + Action, + ActionType, + BasicAuth, + CallInfo, + CertKeyPair, + ClientAuthHandler, + ClientMiddleware, + ClientMiddlewareFactory, + DescriptorType, + FlightCallOptions, + FlightCancelledError, + FlightClient, + FlightDataStream, + FlightDescriptor, + FlightEndpoint, + FlightError, + FlightInfo, + FlightInternalError, + FlightMetadataReader, + FlightMetadataWriter, + FlightMethod, + FlightServerBase, + FlightServerError, + FlightStreamChunk, + FlightStreamReader, + FlightStreamWriter, + FlightTimedOutError, + FlightUnauthenticatedError, + FlightUnauthorizedError, + FlightUnavailableError, + FlightWriteSizeExceededError, + GeneratorStream, + Location, + MetadataRecordBatchReader, + MetadataRecordBatchWriter, + RecordBatchStream, + Result, + SchemaResult, + ServerAuthHandler, + ServerCallContext, + ServerMiddleware, + ServerMiddlewareFactory, + Ticket, +) diff --git a/src/arrow/python/pyarrow/fs.py b/src/arrow/python/pyarrow/fs.py new file mode 100644 index 000000000..5d3326861 --- /dev/null +++ b/src/arrow/python/pyarrow/fs.py @@ -0,0 +1,405 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +FileSystem abstraction to interact with various local and remote filesystems. +""" + +from pyarrow.util import _is_path_like, _stringify_path + +from pyarrow._fs import ( # noqa + FileSelector, + FileType, + FileInfo, + FileSystem, + LocalFileSystem, + SubTreeFileSystem, + _MockFileSystem, + FileSystemHandler, + PyFileSystem, + _copy_files, + _copy_files_selector, +) + +# For backward compatibility. +FileStats = FileInfo + +_not_imported = [] + +try: + from pyarrow._hdfs import HadoopFileSystem # noqa +except ImportError: + _not_imported.append("HadoopFileSystem") + +try: + from pyarrow._s3fs import ( # noqa + S3FileSystem, S3LogLevel, initialize_s3, finalize_s3) +except ImportError: + _not_imported.append("S3FileSystem") +else: + initialize_s3() + + +def __getattr__(name): + if name in _not_imported: + raise ImportError( + "The pyarrow installation is not built with support for " + "'{0}'".format(name) + ) + + raise AttributeError( + "module 'pyarrow.fs' has no attribute '{0}'".format(name) + ) + + +def _filesystem_from_str(uri): + # instantiate the file system from an uri, if the uri has a path + # component then it will be treated as a path prefix + filesystem, prefix = FileSystem.from_uri(uri) + prefix = filesystem.normalize_path(prefix) + if prefix: + # validate that the prefix is pointing to a directory + prefix_info = filesystem.get_file_info([prefix])[0] + if prefix_info.type != FileType.Directory: + raise ValueError( + "The path component of the filesystem URI must point to a " + "directory but it has a type: `{}`. The path component " + "is `{}` and the given filesystem URI is `{}`".format( + prefix_info.type.name, prefix_info.path, uri + ) + ) + filesystem = SubTreeFileSystem(prefix, filesystem) + return filesystem + + +def _ensure_filesystem( + filesystem, use_mmap=False, allow_legacy_filesystem=False +): + if isinstance(filesystem, FileSystem): + return filesystem + elif isinstance(filesystem, str): + if use_mmap: + raise ValueError( + "Specifying to use memory mapping not supported for " + "filesytem specified as an URI string" + ) + return _filesystem_from_str(filesystem) + + # handle fsspec-compatible filesystems + try: + import fsspec + except ImportError: + pass + else: + if isinstance(filesystem, fsspec.AbstractFileSystem): + if type(filesystem).__name__ == 'LocalFileSystem': + # In case its a simple LocalFileSystem, use native arrow one + return LocalFileSystem(use_mmap=use_mmap) + return PyFileSystem(FSSpecHandler(filesystem)) + + # map old filesystems to new ones + import pyarrow.filesystem as legacyfs + + if isinstance(filesystem, legacyfs.LocalFileSystem): + return LocalFileSystem(use_mmap=use_mmap) + # TODO handle HDFS? + if allow_legacy_filesystem and isinstance(filesystem, legacyfs.FileSystem): + return filesystem + + raise TypeError( + "Unrecognized filesystem: {}. `filesystem` argument must be a " + "FileSystem instance or a valid file system URI'".format( + type(filesystem)) + ) + + +def _resolve_filesystem_and_path( + path, filesystem=None, allow_legacy_filesystem=False +): + """ + Return filesystem/path from path which could be an URI or a plain + filesystem path. + """ + if not _is_path_like(path): + if filesystem is not None: + raise ValueError( + "'filesystem' passed but the specified path is file-like, so" + " there is nothing to open with 'filesystem'." + ) + return filesystem, path + + if filesystem is not None: + filesystem = _ensure_filesystem( + filesystem, allow_legacy_filesystem=allow_legacy_filesystem + ) + if isinstance(filesystem, LocalFileSystem): + path = _stringify_path(path) + elif not isinstance(path, str): + raise TypeError( + "Expected string path; path-like objects are only allowed " + "with a local filesystem" + ) + if not allow_legacy_filesystem: + path = filesystem.normalize_path(path) + return filesystem, path + + path = _stringify_path(path) + + # if filesystem is not given, try to automatically determine one + # first check if the file exists as a local (relative) file path + # if not then try to parse the path as an URI + filesystem = LocalFileSystem() + try: + file_info = filesystem.get_file_info(path) + except OSError: + file_info = None + exists_locally = False + else: + exists_locally = (file_info.type != FileType.NotFound) + + # if the file or directory doesn't exists locally, then assume that + # the path is an URI describing the file system as well + if not exists_locally: + try: + filesystem, path = FileSystem.from_uri(path) + except ValueError as e: + # neither an URI nor a locally existing path, so assume that + # local path was given and propagate a nicer file not found error + # instead of a more confusing scheme parsing error + if "empty scheme" not in str(e): + raise + else: + path = filesystem.normalize_path(path) + + return filesystem, path + + +def copy_files(source, destination, + source_filesystem=None, destination_filesystem=None, + *, chunk_size=1024*1024, use_threads=True): + """ + Copy files between FileSystems. + + This functions allows you to recursively copy directories of files from + one file system to another, such as from S3 to your local machine. + + Parameters + ---------- + source : string + Source file path or URI to a single file or directory. + If a directory, files will be copied recursively from this path. + destination : string + Destination file path or URI. If `source` is a file, `destination` + is also interpreted as the destination file (not directory). + Directories will be created as necessary. + source_filesystem : FileSystem, optional + Source filesystem, needs to be specified if `source` is not a URI, + otherwise inferred. + destination_filesystem : FileSystem, optional + Destination filesystem, needs to be specified if `destination` is not + a URI, otherwise inferred. + chunk_size : int, default 1MB + The maximum size of block to read before flushing to the + destination file. A larger chunk_size will use more memory while + copying but may help accommodate high latency FileSystems. + use_threads : bool, default True + Whether to use multiple threads to accelerate copying. + + Examples + -------- + Copy an S3 bucket's files to a local directory: + + >>> copy_files("s3://your-bucket-name", "local-directory") + + Using a FileSystem object: + + >>> copy_files("your-bucket-name", "local-directory", + ... source_filesystem=S3FileSystem(...)) + + """ + source_fs, source_path = _resolve_filesystem_and_path( + source, source_filesystem + ) + destination_fs, destination_path = _resolve_filesystem_and_path( + destination, destination_filesystem + ) + + file_info = source_fs.get_file_info(source_path) + if file_info.type == FileType.Directory: + source_sel = FileSelector(source_path, recursive=True) + _copy_files_selector(source_fs, source_sel, + destination_fs, destination_path, + chunk_size, use_threads) + else: + _copy_files(source_fs, source_path, + destination_fs, destination_path, + chunk_size, use_threads) + + +class FSSpecHandler(FileSystemHandler): + """ + Handler for fsspec-based Python filesystems. + + https://filesystem-spec.readthedocs.io/en/latest/index.html + + Parameters + ---------- + fs : The FSSpec-compliant filesystem instance. + + Examples + -------- + >>> PyFileSystem(FSSpecHandler(fsspec_fs)) + """ + + def __init__(self, fs): + self.fs = fs + + def __eq__(self, other): + if isinstance(other, FSSpecHandler): + return self.fs == other.fs + return NotImplemented + + def __ne__(self, other): + if isinstance(other, FSSpecHandler): + return self.fs != other.fs + return NotImplemented + + def get_type_name(self): + protocol = self.fs.protocol + if isinstance(protocol, list): + protocol = protocol[0] + return "fsspec+{0}".format(protocol) + + def normalize_path(self, path): + return path + + @staticmethod + def _create_file_info(path, info): + size = info["size"] + if info["type"] == "file": + ftype = FileType.File + elif info["type"] == "directory": + ftype = FileType.Directory + # some fsspec filesystems include a file size for directories + size = None + else: + ftype = FileType.Unknown + return FileInfo(path, ftype, size=size, mtime=info.get("mtime", None)) + + def get_file_info(self, paths): + infos = [] + for path in paths: + try: + info = self.fs.info(path) + except FileNotFoundError: + infos.append(FileInfo(path, FileType.NotFound)) + else: + infos.append(self._create_file_info(path, info)) + return infos + + def get_file_info_selector(self, selector): + if not self.fs.isdir(selector.base_dir): + if self.fs.exists(selector.base_dir): + raise NotADirectoryError(selector.base_dir) + else: + if selector.allow_not_found: + return [] + else: + raise FileNotFoundError(selector.base_dir) + + if selector.recursive: + maxdepth = None + else: + maxdepth = 1 + + infos = [] + selected_files = self.fs.find( + selector.base_dir, maxdepth=maxdepth, withdirs=True, detail=True + ) + for path, info in selected_files.items(): + infos.append(self._create_file_info(path, info)) + + return infos + + def create_dir(self, path, recursive): + # mkdir also raises FileNotFoundError when base directory is not found + try: + self.fs.mkdir(path, create_parents=recursive) + except FileExistsError: + pass + + def delete_dir(self, path): + self.fs.rm(path, recursive=True) + + def _delete_dir_contents(self, path): + for subpath in self.fs.listdir(path, detail=False): + if self.fs.isdir(subpath): + self.fs.rm(subpath, recursive=True) + elif self.fs.isfile(subpath): + self.fs.rm(subpath) + + def delete_dir_contents(self, path): + if path.strip("/") == "": + raise ValueError( + "delete_dir_contents called on path '", path, "'") + self._delete_dir_contents(path) + + def delete_root_dir_contents(self): + self._delete_dir_contents("/") + + def delete_file(self, path): + # fs.rm correctly raises IsADirectoryError when `path` is a directory + # instead of a file and `recursive` is not set to True + if not self.fs.exists(path): + raise FileNotFoundError(path) + self.fs.rm(path) + + def move(self, src, dest): + self.fs.mv(src, dest, recursive=True) + + def copy_file(self, src, dest): + # fs.copy correctly raises IsADirectoryError when `src` is a directory + # instead of a file + self.fs.copy(src, dest) + + # TODO can we read/pass metadata (e.g. Content-Type) in the methods below? + + def open_input_stream(self, path): + from pyarrow import PythonFile + + if not self.fs.isfile(path): + raise FileNotFoundError(path) + + return PythonFile(self.fs.open(path, mode="rb"), mode="r") + + def open_input_file(self, path): + from pyarrow import PythonFile + + if not self.fs.isfile(path): + raise FileNotFoundError(path) + + return PythonFile(self.fs.open(path, mode="rb"), mode="r") + + def open_output_stream(self, path, metadata): + from pyarrow import PythonFile + + return PythonFile(self.fs.open(path, mode="wb"), mode="w") + + def open_append_stream(self, path, metadata): + from pyarrow import PythonFile + + return PythonFile(self.fs.open(path, mode="ab"), mode="w") diff --git a/src/arrow/python/pyarrow/gandiva.pyx b/src/arrow/python/pyarrow/gandiva.pyx new file mode 100644 index 000000000..12d572b33 --- /dev/null +++ b/src/arrow/python/pyarrow/gandiva.pyx @@ -0,0 +1,518 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ +# cython: language_level = 3 + +from libcpp cimport bool as c_bool, nullptr +from libcpp.memory cimport shared_ptr, unique_ptr, make_shared +from libcpp.string cimport string as c_string +from libcpp.vector cimport vector as c_vector +from libcpp.unordered_set cimport unordered_set as c_unordered_set +from libc.stdint cimport int64_t, int32_t, uint8_t, uintptr_t + +from pyarrow.includes.libarrow cimport * +from pyarrow.lib cimport (Array, DataType, Field, MemoryPool, RecordBatch, + Schema, check_status, pyarrow_wrap_array, + pyarrow_wrap_data_type, ensure_type, _Weakrefable, + pyarrow_wrap_field) +from pyarrow.lib import frombytes + +from pyarrow.includes.libgandiva cimport ( + CCondition, CExpression, + CNode, CProjector, CFilter, + CSelectionVector, + CSelectionVector_Mode, + _ensure_selection_mode, + CConfiguration, + CConfigurationBuilder, + TreeExprBuilder_MakeExpression, + TreeExprBuilder_MakeFunction, + TreeExprBuilder_MakeBoolLiteral, + TreeExprBuilder_MakeUInt8Literal, + TreeExprBuilder_MakeUInt16Literal, + TreeExprBuilder_MakeUInt32Literal, + TreeExprBuilder_MakeUInt64Literal, + TreeExprBuilder_MakeInt8Literal, + TreeExprBuilder_MakeInt16Literal, + TreeExprBuilder_MakeInt32Literal, + TreeExprBuilder_MakeInt64Literal, + TreeExprBuilder_MakeFloatLiteral, + TreeExprBuilder_MakeDoubleLiteral, + TreeExprBuilder_MakeStringLiteral, + TreeExprBuilder_MakeBinaryLiteral, + TreeExprBuilder_MakeField, + TreeExprBuilder_MakeIf, + TreeExprBuilder_MakeAnd, + TreeExprBuilder_MakeOr, + TreeExprBuilder_MakeCondition, + TreeExprBuilder_MakeInExpressionInt32, + TreeExprBuilder_MakeInExpressionInt64, + TreeExprBuilder_MakeInExpressionTime32, + TreeExprBuilder_MakeInExpressionTime64, + TreeExprBuilder_MakeInExpressionDate32, + TreeExprBuilder_MakeInExpressionDate64, + TreeExprBuilder_MakeInExpressionTimeStamp, + TreeExprBuilder_MakeInExpressionString, + TreeExprBuilder_MakeInExpressionBinary, + SelectionVector_MakeInt16, + SelectionVector_MakeInt32, + SelectionVector_MakeInt64, + Projector_Make, + Filter_Make, + CFunctionSignature, + GetRegisteredFunctionSignatures) + +cdef class Node(_Weakrefable): + cdef: + shared_ptr[CNode] node + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use the " + "TreeExprBuilder API directly" + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CNode] node): + cdef Node self = Node.__new__(Node) + self.node = node + return self + + def __str__(self): + return self.node.get().ToString().decode() + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def return_type(self): + return pyarrow_wrap_data_type(self.node.get().return_type()) + +cdef class Expression(_Weakrefable): + cdef: + shared_ptr[CExpression] expression + + cdef void init(self, shared_ptr[CExpression] expression): + self.expression = expression + + def __str__(self): + return self.expression.get().ToString().decode() + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def root(self): + return Node.create(self.expression.get().root()) + + def result(self): + return pyarrow_wrap_field(self.expression.get().result()) + +cdef class Condition(_Weakrefable): + cdef: + shared_ptr[CCondition] condition + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use the " + "TreeExprBuilder API instead" + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CCondition] condition): + cdef Condition self = Condition.__new__(Condition) + self.condition = condition + return self + + def __str__(self): + return self.condition.get().ToString().decode() + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def root(self): + return Node.create(self.condition.get().root()) + + def result(self): + return pyarrow_wrap_field(self.condition.get().result()) + +cdef class SelectionVector(_Weakrefable): + cdef: + shared_ptr[CSelectionVector] selection_vector + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly." + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CSelectionVector] selection_vector): + cdef SelectionVector self = SelectionVector.__new__(SelectionVector) + self.selection_vector = selection_vector + return self + + def to_array(self): + cdef shared_ptr[CArray] result = self.selection_vector.get().ToArray() + return pyarrow_wrap_array(result) + +cdef class Projector(_Weakrefable): + cdef: + shared_ptr[CProjector] projector + MemoryPool pool + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "make_projector instead" + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CProjector] projector, MemoryPool pool): + cdef Projector self = Projector.__new__(Projector) + self.projector = projector + self.pool = pool + return self + + @property + def llvm_ir(self): + return self.projector.get().DumpIR().decode() + + def evaluate(self, RecordBatch batch, SelectionVector selection=None): + cdef vector[shared_ptr[CArray]] results + if selection is None: + check_status(self.projector.get().Evaluate( + batch.sp_batch.get()[0], self.pool.pool, &results)) + else: + check_status( + self.projector.get().Evaluate( + batch.sp_batch.get()[0], selection.selection_vector.get(), + self.pool.pool, &results)) + cdef shared_ptr[CArray] result + arrays = [] + for result in results: + arrays.append(pyarrow_wrap_array(result)) + return arrays + +cdef class Filter(_Weakrefable): + cdef: + shared_ptr[CFilter] filter + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "make_filter instead" + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CFilter] filter): + cdef Filter self = Filter.__new__(Filter) + self.filter = filter + return self + + @property + def llvm_ir(self): + return self.filter.get().DumpIR().decode() + + def evaluate(self, RecordBatch batch, MemoryPool pool, dtype='int32'): + cdef: + DataType type = ensure_type(dtype) + shared_ptr[CSelectionVector] selection + + if type.id == _Type_INT16: + check_status(SelectionVector_MakeInt16( + batch.num_rows, pool.pool, &selection)) + elif type.id == _Type_INT32: + check_status(SelectionVector_MakeInt32( + batch.num_rows, pool.pool, &selection)) + elif type.id == _Type_INT64: + check_status(SelectionVector_MakeInt64( + batch.num_rows, pool.pool, &selection)) + else: + raise ValueError("'dtype' of the selection vector should be " + "one of 'int16', 'int32' and 'int64'.") + + check_status(self.filter.get().Evaluate( + batch.sp_batch.get()[0], selection)) + return SelectionVector.create(selection) + + +cdef class TreeExprBuilder(_Weakrefable): + + def make_literal(self, value, dtype): + cdef: + DataType type = ensure_type(dtype) + shared_ptr[CNode] r + + if type.id == _Type_BOOL: + r = TreeExprBuilder_MakeBoolLiteral(value) + elif type.id == _Type_UINT8: + r = TreeExprBuilder_MakeUInt8Literal(value) + elif type.id == _Type_UINT16: + r = TreeExprBuilder_MakeUInt16Literal(value) + elif type.id == _Type_UINT32: + r = TreeExprBuilder_MakeUInt32Literal(value) + elif type.id == _Type_UINT64: + r = TreeExprBuilder_MakeUInt64Literal(value) + elif type.id == _Type_INT8: + r = TreeExprBuilder_MakeInt8Literal(value) + elif type.id == _Type_INT16: + r = TreeExprBuilder_MakeInt16Literal(value) + elif type.id == _Type_INT32: + r = TreeExprBuilder_MakeInt32Literal(value) + elif type.id == _Type_INT64: + r = TreeExprBuilder_MakeInt64Literal(value) + elif type.id == _Type_FLOAT: + r = TreeExprBuilder_MakeFloatLiteral(value) + elif type.id == _Type_DOUBLE: + r = TreeExprBuilder_MakeDoubleLiteral(value) + elif type.id == _Type_STRING: + r = TreeExprBuilder_MakeStringLiteral(value.encode('UTF-8')) + elif type.id == _Type_BINARY: + r = TreeExprBuilder_MakeBinaryLiteral(value) + else: + raise TypeError("Didn't recognize dtype " + str(dtype)) + + return Node.create(r) + + def make_expression(self, Node root_node, Field return_field): + cdef shared_ptr[CExpression] r = TreeExprBuilder_MakeExpression( + root_node.node, return_field.sp_field) + cdef Expression expression = Expression() + expression.init(r) + return expression + + def make_function(self, name, children, DataType return_type): + cdef c_vector[shared_ptr[CNode]] c_children + cdef Node child + for child in children: + c_children.push_back(child.node) + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeFunction( + name.encode(), c_children, return_type.sp_type) + return Node.create(r) + + def make_field(self, Field field): + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeField(field.sp_field) + return Node.create(r) + + def make_if(self, Node condition, Node this_node, + Node else_node, DataType return_type): + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeIf( + condition.node, this_node.node, else_node.node, + return_type.sp_type) + return Node.create(r) + + def make_and(self, children): + cdef c_vector[shared_ptr[CNode]] c_children + cdef Node child + for child in children: + c_children.push_back(child.node) + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeAnd(c_children) + return Node.create(r) + + def make_or(self, children): + cdef c_vector[shared_ptr[CNode]] c_children + cdef Node child + for child in children: + c_children.push_back(child.node) + cdef shared_ptr[CNode] r = TreeExprBuilder_MakeOr(c_children) + return Node.create(r) + + def _make_in_expression_int32(self, Node node, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int32_t] c_values + cdef int32_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionInt32(node.node, c_values) + return Node.create(r) + + def _make_in_expression_int64(self, Node node, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int64_t] c_values + cdef int64_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionInt64(node.node, c_values) + return Node.create(r) + + def _make_in_expression_time32(self, Node node, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int32_t] c_values + cdef int32_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionTime32(node.node, c_values) + return Node.create(r) + + def _make_in_expression_time64(self, Node node, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int64_t] c_values + cdef int64_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionTime64(node.node, c_values) + return Node.create(r) + + def _make_in_expression_date32(self, Node node, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int32_t] c_values + cdef int32_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionDate32(node.node, c_values) + return Node.create(r) + + def _make_in_expression_date64(self, Node node, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int64_t] c_values + cdef int64_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionDate64(node.node, c_values) + return Node.create(r) + + def _make_in_expression_timestamp(self, Node node, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[int64_t] c_values + cdef int64_t v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionTimeStamp(node.node, c_values) + return Node.create(r) + + def _make_in_expression_binary(self, Node node, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[c_string] c_values + cdef c_string v + for v in values: + c_values.insert(v) + r = TreeExprBuilder_MakeInExpressionString(node.node, c_values) + return Node.create(r) + + def _make_in_expression_string(self, Node node, values): + cdef shared_ptr[CNode] r + cdef c_unordered_set[c_string] c_values + cdef c_string _v + for v in values: + _v = v.encode('UTF-8') + c_values.insert(_v) + r = TreeExprBuilder_MakeInExpressionString(node.node, c_values) + return Node.create(r) + + def make_in_expression(self, Node node, values, dtype): + cdef DataType type = ensure_type(dtype) + + if type.id == _Type_INT32: + return self._make_in_expression_int32(node, values) + elif type.id == _Type_INT64: + return self._make_in_expression_int64(node, values) + elif type.id == _Type_TIME32: + return self._make_in_expression_time32(node, values) + elif type.id == _Type_TIME64: + return self._make_in_expression_time64(node, values) + elif type.id == _Type_TIMESTAMP: + return self._make_in_expression_timestamp(node, values) + elif type.id == _Type_DATE32: + return self._make_in_expression_date32(node, values) + elif type.id == _Type_DATE64: + return self._make_in_expression_date64(node, values) + elif type.id == _Type_BINARY: + return self._make_in_expression_binary(node, values) + elif type.id == _Type_STRING: + return self._make_in_expression_string(node, values) + else: + raise TypeError("Data type " + str(dtype) + " not supported.") + + def make_condition(self, Node condition): + cdef shared_ptr[CCondition] r = TreeExprBuilder_MakeCondition( + condition.node) + return Condition.create(r) + +cpdef make_projector(Schema schema, children, MemoryPool pool, + str selection_mode="NONE"): + cdef c_vector[shared_ptr[CExpression]] c_children + cdef Expression child + for child in children: + c_children.push_back(child.expression) + cdef shared_ptr[CProjector] result + check_status( + Projector_Make(schema.sp_schema, c_children, + _ensure_selection_mode(selection_mode), + CConfigurationBuilder.DefaultConfiguration(), + &result)) + return Projector.create(result, pool) + +cpdef make_filter(Schema schema, Condition condition): + cdef shared_ptr[CFilter] result + check_status( + Filter_Make(schema.sp_schema, condition.condition, &result)) + return Filter.create(result) + +cdef class FunctionSignature(_Weakrefable): + """ + Signature of a Gandiva function including name, parameter types + and return type. + """ + + cdef: + shared_ptr[CFunctionSignature] signature + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly." + .format(self.__class__.__name__)) + + @staticmethod + cdef create(shared_ptr[CFunctionSignature] signature): + cdef FunctionSignature self = FunctionSignature.__new__( + FunctionSignature) + self.signature = signature + return self + + def return_type(self): + return pyarrow_wrap_data_type(self.signature.get().ret_type()) + + def param_types(self): + result = [] + cdef vector[shared_ptr[CDataType]] types = \ + self.signature.get().param_types() + for t in types: + result.append(pyarrow_wrap_data_type(t)) + return result + + def name(self): + return self.signature.get().base_name().decode() + + def __repr__(self): + signature = self.signature.get().ToString().decode() + return "FunctionSignature(" + signature + ")" + + +def get_registered_function_signatures(): + """ + Return the function in Gandiva's ExpressionRegistry. + + Returns + ------- + registry: a list of registered function signatures + """ + results = [] + + cdef vector[shared_ptr[CFunctionSignature]] signatures = \ + GetRegisteredFunctionSignatures() + + for signature in signatures: + results.append(FunctionSignature.create(signature)) + + return results diff --git a/src/arrow/python/pyarrow/hdfs.py b/src/arrow/python/pyarrow/hdfs.py new file mode 100644 index 000000000..56667bd5d --- /dev/null +++ b/src/arrow/python/pyarrow/hdfs.py @@ -0,0 +1,240 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import os +import posixpath +import sys +import warnings + +from pyarrow.util import implements, _DEPR_MSG +from pyarrow.filesystem import FileSystem +import pyarrow._hdfsio as _hdfsio + + +class HadoopFileSystem(_hdfsio.HadoopFileSystem, FileSystem): + """ + DEPRECATED: FileSystem interface for HDFS cluster. + + See pyarrow.hdfs.connect for full connection details + + .. deprecated:: 2.0 + ``pyarrow.hdfs.HadoopFileSystem`` is deprecated, + please use ``pyarrow.fs.HadoopFileSystem`` instead. + """ + + def __init__(self, host="default", port=0, user=None, kerb_ticket=None, + driver='libhdfs', extra_conf=None): + warnings.warn( + _DEPR_MSG.format( + "hdfs.HadoopFileSystem", "2.0.0", "fs.HadoopFileSystem"), + FutureWarning, stacklevel=2) + if driver == 'libhdfs': + _maybe_set_hadoop_classpath() + + self._connect(host, port, user, kerb_ticket, extra_conf) + + def __reduce__(self): + return (HadoopFileSystem, (self.host, self.port, self.user, + self.kerb_ticket, self.extra_conf)) + + def _isfilestore(self): + """ + Return True if this is a Unix-style file store with directories. + """ + return True + + @implements(FileSystem.isdir) + def isdir(self, path): + return super().isdir(path) + + @implements(FileSystem.isfile) + def isfile(self, path): + return super().isfile(path) + + @implements(FileSystem.delete) + def delete(self, path, recursive=False): + return super().delete(path, recursive) + + def mkdir(self, path, **kwargs): + """ + Create directory in HDFS. + + Parameters + ---------- + path : str + Directory path to create, including any parent directories. + + Notes + ----- + libhdfs does not support create_parents=False, so we ignore this here + """ + return super().mkdir(path) + + @implements(FileSystem.rename) + def rename(self, path, new_path): + return super().rename(path, new_path) + + @implements(FileSystem.exists) + def exists(self, path): + return super().exists(path) + + def ls(self, path, detail=False): + """ + Retrieve directory contents and metadata, if requested. + + Parameters + ---------- + path : str + HDFS path to retrieve contents of. + detail : bool, default False + If False, only return list of paths. + + Returns + ------- + result : list of dicts (detail=True) or strings (detail=False) + """ + return super().ls(path, detail) + + def walk(self, top_path): + """ + Directory tree generator for HDFS, like os.walk. + + Parameters + ---------- + top_path : str + Root directory for tree traversal. + + Returns + ------- + Generator yielding 3-tuple (dirpath, dirnames, filename) + """ + contents = self.ls(top_path, detail=True) + + directories, files = _libhdfs_walk_files_dirs(top_path, contents) + yield top_path, directories, files + for dirname in directories: + yield from self.walk(self._path_join(top_path, dirname)) + + +def _maybe_set_hadoop_classpath(): + import re + + if re.search(r'hadoop-common[^/]+.jar', os.environ.get('CLASSPATH', '')): + return + + if 'HADOOP_HOME' in os.environ: + if sys.platform != 'win32': + classpath = _derive_hadoop_classpath() + else: + hadoop_bin = '{}/bin/hadoop'.format(os.environ['HADOOP_HOME']) + classpath = _hadoop_classpath_glob(hadoop_bin) + else: + classpath = _hadoop_classpath_glob('hadoop') + + os.environ['CLASSPATH'] = classpath.decode('utf-8') + + +def _derive_hadoop_classpath(): + import subprocess + + find_args = ('find', '-L', os.environ['HADOOP_HOME'], '-name', '*.jar') + find = subprocess.Popen(find_args, stdout=subprocess.PIPE) + xargs_echo = subprocess.Popen(('xargs', 'echo'), + stdin=find.stdout, + stdout=subprocess.PIPE) + jars = subprocess.check_output(('tr', "' '", "':'"), + stdin=xargs_echo.stdout) + hadoop_conf = os.environ["HADOOP_CONF_DIR"] \ + if "HADOOP_CONF_DIR" in os.environ \ + else os.environ["HADOOP_HOME"] + "/etc/hadoop" + return (hadoop_conf + ":").encode("utf-8") + jars + + +def _hadoop_classpath_glob(hadoop_bin): + import subprocess + + hadoop_classpath_args = (hadoop_bin, 'classpath', '--glob') + return subprocess.check_output(hadoop_classpath_args) + + +def _libhdfs_walk_files_dirs(top_path, contents): + files = [] + directories = [] + for c in contents: + scrubbed_name = posixpath.split(c['name'])[1] + if c['kind'] == 'file': + files.append(scrubbed_name) + else: + directories.append(scrubbed_name) + + return directories, files + + +def connect(host="default", port=0, user=None, kerb_ticket=None, + extra_conf=None): + """ + DEPRECATED: Connect to an HDFS cluster. + + All parameters are optional and should only be set if the defaults need + to be overridden. + + Authentication should be automatic if the HDFS cluster uses Kerberos. + However, if a username is specified, then the ticket cache will likely + be required. + + .. deprecated:: 2.0 + ``pyarrow.hdfs.connect`` is deprecated, + please use ``pyarrow.fs.HadoopFileSystem`` instead. + + Parameters + ---------- + host : NameNode. Set to "default" for fs.defaultFS from core-site.xml. + port : NameNode's port. Set to 0 for default or logical (HA) nodes. + user : Username when connecting to HDFS; None implies login user. + kerb_ticket : Path to Kerberos ticket cache. + extra_conf : dict, default None + extra Key/Value pairs for config; Will override any + hdfs-site.xml properties + + Notes + ----- + The first time you call this method, it will take longer than usual due + to JNI spin-up time. + + Returns + ------- + filesystem : HadoopFileSystem + """ + warnings.warn( + _DEPR_MSG.format("hdfs.connect", "2.0.0", "fs.HadoopFileSystem"), + FutureWarning, stacklevel=2 + ) + return _connect( + host=host, port=port, user=user, kerb_ticket=kerb_ticket, + extra_conf=extra_conf + ) + + +def _connect(host="default", port=0, user=None, kerb_ticket=None, + extra_conf=None): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + fs = HadoopFileSystem(host=host, port=port, user=user, + kerb_ticket=kerb_ticket, + extra_conf=extra_conf) + return fs diff --git a/src/arrow/python/pyarrow/includes/__init__.pxd b/src/arrow/python/pyarrow/includes/__init__.pxd new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/arrow/python/pyarrow/includes/__init__.pxd diff --git a/src/arrow/python/pyarrow/includes/common.pxd b/src/arrow/python/pyarrow/includes/common.pxd new file mode 100644 index 000000000..902eaafbb --- /dev/null +++ b/src/arrow/python/pyarrow/includes/common.pxd @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from libc.stdint cimport * +from libcpp cimport bool as c_bool, nullptr +from libcpp.functional cimport function +from libcpp.memory cimport shared_ptr, unique_ptr, make_shared +from libcpp.string cimport string as c_string +from libcpp.utility cimport pair +from libcpp.vector cimport vector +from libcpp.unordered_map cimport unordered_map +from libcpp.unordered_set cimport unordered_set + +from cpython cimport PyObject +from cpython.datetime cimport PyDateTime_DateTime +cimport cpython + + +cdef extern from * namespace "std" nogil: + cdef shared_ptr[T] static_pointer_cast[T, U](shared_ptr[U]) + +# vendored from the cymove project https://github.com/ozars/cymove +cdef extern from * namespace "cymove" nogil: + """ + #include <type_traits> + #include <utility> + namespace cymove { + template <typename T> + inline typename std::remove_reference<T>::type&& cymove(T& t) { + return std::move(t); + } + template <typename T> + inline typename std::remove_reference<T>::type&& cymove(T&& t) { + return std::move(t); + } + } // namespace cymove + """ + cdef T move" cymove::cymove"[T](T) + +cdef extern from * namespace "arrow::py" nogil: + """ + #include <memory> + #include <utility> + + namespace arrow { + namespace py { + template <typename T> + std::shared_ptr<T> to_shared(std::unique_ptr<T>& t) { + return std::move(t); + } + template <typename T> + std::shared_ptr<T> to_shared(std::unique_ptr<T>&& t) { + return std::move(t); + } + } // namespace py + } // namespace arrow + """ + cdef shared_ptr[T] to_shared" arrow::py::to_shared"[T](unique_ptr[T]) + +cdef extern from "arrow/python/platform.h": + pass + +cdef extern from "<Python.h>": + void Py_XDECREF(PyObject* o) + Py_ssize_t Py_REFCNT(PyObject* o) + +cdef extern from "numpy/halffloat.h": + ctypedef uint16_t npy_half + +cdef extern from "arrow/api.h" namespace "arrow" nogil: + # We can later add more of the common status factory methods as needed + cdef CStatus CStatus_OK "arrow::Status::OK"() + + cdef CStatus CStatus_Invalid "arrow::Status::Invalid"() + cdef CStatus CStatus_NotImplemented \ + "arrow::Status::NotImplemented"(const c_string& msg) + cdef CStatus CStatus_UnknownError \ + "arrow::Status::UnknownError"(const c_string& msg) + + cdef cppclass CStatus "arrow::Status": + CStatus() + + c_string ToString() + c_string message() + shared_ptr[CStatusDetail] detail() + + c_bool ok() + c_bool IsIOError() + c_bool IsOutOfMemory() + c_bool IsInvalid() + c_bool IsKeyError() + c_bool IsNotImplemented() + c_bool IsTypeError() + c_bool IsCapacityError() + c_bool IsIndexError() + c_bool IsSerializationError() + c_bool IsCancelled() + + cdef cppclass CStatusDetail "arrow::StatusDetail": + c_string ToString() + + +cdef extern from "arrow/result.h" namespace "arrow" nogil: + cdef cppclass CResult "arrow::Result"[T]: + CResult() + CResult(CStatus) + CResult(T) + c_bool ok() + CStatus status() + T operator*() + + +cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil: + T GetResultValue[T](CResult[T]) except * + cdef function[F] BindFunction[F](void* unbound, object bound, ...) + + +cdef inline object PyObject_to_object(PyObject* o): + # Cast to "object" increments reference count + cdef object result = <object> o + cpython.Py_DECREF(result) + return result diff --git a/src/arrow/python/pyarrow/includes/libarrow.pxd b/src/arrow/python/pyarrow/includes/libarrow.pxd new file mode 100644 index 000000000..815238f11 --- /dev/null +++ b/src/arrow/python/pyarrow/includes/libarrow.pxd @@ -0,0 +1,2615 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from pyarrow.includes.common cimport * + +cdef extern from "arrow/util/key_value_metadata.h" namespace "arrow" nogil: + cdef cppclass CKeyValueMetadata" arrow::KeyValueMetadata": + CKeyValueMetadata() + CKeyValueMetadata(const unordered_map[c_string, c_string]&) + CKeyValueMetadata(const vector[c_string]& keys, + const vector[c_string]& values) + + void reserve(int64_t n) + int64_t size() const + c_string key(int64_t i) const + c_string value(int64_t i) const + int FindKey(const c_string& key) const + + shared_ptr[CKeyValueMetadata] Copy() const + c_bool Equals(const CKeyValueMetadata& other) + void Append(const c_string& key, const c_string& value) + void ToUnorderedMap(unordered_map[c_string, c_string]*) const + c_string ToString() const + + CResult[c_string] Get(const c_string& key) const + CStatus Delete(const c_string& key) + CStatus Set(const c_string& key, const c_string& value) + c_bool Contains(const c_string& key) const + + +cdef extern from "arrow/util/decimal.h" namespace "arrow" nogil: + cdef cppclass CDecimal128" arrow::Decimal128": + c_string ToString(int32_t scale) const + + +cdef extern from "arrow/util/decimal.h" namespace "arrow" nogil: + cdef cppclass CDecimal256" arrow::Decimal256": + c_string ToString(int32_t scale) const + + +cdef extern from "arrow/config.h" namespace "arrow" nogil: + cdef cppclass CBuildInfo" arrow::BuildInfo": + int version + int version_major + int version_minor + int version_patch + c_string version_string + c_string so_version + c_string full_so_version + c_string compiler_id + c_string compiler_version + c_string compiler_flags + c_string git_id + c_string git_description + c_string package_kind + + const CBuildInfo& GetBuildInfo() + + cdef cppclass CRuntimeInfo" arrow::RuntimeInfo": + c_string simd_level + c_string detected_simd_level + + CRuntimeInfo GetRuntimeInfo() + + +cdef extern from "arrow/api.h" namespace "arrow" nogil: + cdef enum Type" arrow::Type::type": + _Type_NA" arrow::Type::NA" + + _Type_BOOL" arrow::Type::BOOL" + + _Type_UINT8" arrow::Type::UINT8" + _Type_INT8" arrow::Type::INT8" + _Type_UINT16" arrow::Type::UINT16" + _Type_INT16" arrow::Type::INT16" + _Type_UINT32" arrow::Type::UINT32" + _Type_INT32" arrow::Type::INT32" + _Type_UINT64" arrow::Type::UINT64" + _Type_INT64" arrow::Type::INT64" + + _Type_HALF_FLOAT" arrow::Type::HALF_FLOAT" + _Type_FLOAT" arrow::Type::FLOAT" + _Type_DOUBLE" arrow::Type::DOUBLE" + + _Type_DECIMAL128" arrow::Type::DECIMAL128" + _Type_DECIMAL256" arrow::Type::DECIMAL256" + + _Type_DATE32" arrow::Type::DATE32" + _Type_DATE64" arrow::Type::DATE64" + _Type_TIMESTAMP" arrow::Type::TIMESTAMP" + _Type_TIME32" arrow::Type::TIME32" + _Type_TIME64" arrow::Type::TIME64" + _Type_DURATION" arrow::Type::DURATION" + _Type_INTERVAL_MONTH_DAY_NANO" arrow::Type::INTERVAL_MONTH_DAY_NANO" + + _Type_BINARY" arrow::Type::BINARY" + _Type_STRING" arrow::Type::STRING" + _Type_LARGE_BINARY" arrow::Type::LARGE_BINARY" + _Type_LARGE_STRING" arrow::Type::LARGE_STRING" + _Type_FIXED_SIZE_BINARY" arrow::Type::FIXED_SIZE_BINARY" + + _Type_LIST" arrow::Type::LIST" + _Type_LARGE_LIST" arrow::Type::LARGE_LIST" + _Type_FIXED_SIZE_LIST" arrow::Type::FIXED_SIZE_LIST" + _Type_STRUCT" arrow::Type::STRUCT" + _Type_SPARSE_UNION" arrow::Type::SPARSE_UNION" + _Type_DENSE_UNION" arrow::Type::DENSE_UNION" + _Type_DICTIONARY" arrow::Type::DICTIONARY" + _Type_MAP" arrow::Type::MAP" + + _Type_EXTENSION" arrow::Type::EXTENSION" + + cdef enum UnionMode" arrow::UnionMode::type": + _UnionMode_SPARSE" arrow::UnionMode::SPARSE" + _UnionMode_DENSE" arrow::UnionMode::DENSE" + + cdef enum TimeUnit" arrow::TimeUnit::type": + TimeUnit_SECOND" arrow::TimeUnit::SECOND" + TimeUnit_MILLI" arrow::TimeUnit::MILLI" + TimeUnit_MICRO" arrow::TimeUnit::MICRO" + TimeUnit_NANO" arrow::TimeUnit::NANO" + + cdef cppclass CBufferSpec" arrow::DataTypeLayout::BufferSpec": + pass + + cdef cppclass CDataTypeLayout" arrow::DataTypeLayout": + vector[CBufferSpec] buffers + c_bool has_dictionary + + cdef cppclass CDataType" arrow::DataType": + Type id() + + c_bool Equals(const CDataType& other) + c_bool Equals(const shared_ptr[CDataType]& other) + + shared_ptr[CField] field(int i) + const vector[shared_ptr[CField]] fields() + int num_fields() + CDataTypeLayout layout() + c_string ToString() + + c_bool is_primitive(Type type) + + cdef cppclass CArrayData" arrow::ArrayData": + shared_ptr[CDataType] type + int64_t length + int64_t null_count + int64_t offset + vector[shared_ptr[CBuffer]] buffers + vector[shared_ptr[CArrayData]] child_data + shared_ptr[CArrayData] dictionary + + @staticmethod + shared_ptr[CArrayData] Make(const shared_ptr[CDataType]& type, + int64_t length, + vector[shared_ptr[CBuffer]]& buffers, + int64_t null_count, + int64_t offset) + + @staticmethod + shared_ptr[CArrayData] MakeWithChildren" Make"( + const shared_ptr[CDataType]& type, + int64_t length, + vector[shared_ptr[CBuffer]]& buffers, + vector[shared_ptr[CArrayData]]& child_data, + int64_t null_count, + int64_t offset) + + @staticmethod + shared_ptr[CArrayData] MakeWithChildrenAndDictionary" Make"( + const shared_ptr[CDataType]& type, + int64_t length, + vector[shared_ptr[CBuffer]]& buffers, + vector[shared_ptr[CArrayData]]& child_data, + shared_ptr[CArrayData]& dictionary, + int64_t null_count, + int64_t offset) + + cdef cppclass CArray" arrow::Array": + shared_ptr[CDataType] type() + + int64_t length() + int64_t null_count() + int64_t offset() + Type type_id() + + int num_fields() + + CResult[shared_ptr[CScalar]] GetScalar(int64_t i) const + + c_string Diff(const CArray& other) + c_bool Equals(const CArray& arr) + c_bool IsNull(int i) + + shared_ptr[CArrayData] data() + + shared_ptr[CArray] Slice(int64_t offset) + shared_ptr[CArray] Slice(int64_t offset, int64_t length) + + CStatus Validate() const + CStatus ValidateFull() const + CResult[shared_ptr[CArray]] View(const shared_ptr[CDataType]& type) + + shared_ptr[CArray] MakeArray(const shared_ptr[CArrayData]& data) + CResult[shared_ptr[CArray]] MakeArrayOfNull( + const shared_ptr[CDataType]& type, int64_t length, CMemoryPool* pool) + + CResult[shared_ptr[CArray]] MakeArrayFromScalar( + const CScalar& scalar, int64_t length, CMemoryPool* pool) + + CStatus DebugPrint(const CArray& arr, int indent) + + cdef cppclass CFixedWidthType" arrow::FixedWidthType"(CDataType): + int bit_width() + + cdef cppclass CNullArray" arrow::NullArray"(CArray): + CNullArray(int64_t length) + + cdef cppclass CDictionaryArray" arrow::DictionaryArray"(CArray): + CDictionaryArray(const shared_ptr[CDataType]& type, + const shared_ptr[CArray]& indices, + const shared_ptr[CArray]& dictionary) + + @staticmethod + CResult[shared_ptr[CArray]] FromArrays( + const shared_ptr[CDataType]& type, + const shared_ptr[CArray]& indices, + const shared_ptr[CArray]& dictionary) + + shared_ptr[CArray] indices() + shared_ptr[CArray] dictionary() + + cdef cppclass CDate32Type" arrow::Date32Type"(CFixedWidthType): + pass + + cdef cppclass CDate64Type" arrow::Date64Type"(CFixedWidthType): + pass + + cdef cppclass CTimestampType" arrow::TimestampType"(CFixedWidthType): + CTimestampType(TimeUnit unit) + TimeUnit unit() + const c_string& timezone() + + cdef cppclass CTime32Type" arrow::Time32Type"(CFixedWidthType): + TimeUnit unit() + + cdef cppclass CTime64Type" arrow::Time64Type"(CFixedWidthType): + TimeUnit unit() + + shared_ptr[CDataType] ctime32" arrow::time32"(TimeUnit unit) + shared_ptr[CDataType] ctime64" arrow::time64"(TimeUnit unit) + + cdef cppclass CDurationType" arrow::DurationType"(CFixedWidthType): + TimeUnit unit() + + shared_ptr[CDataType] cduration" arrow::duration"(TimeUnit unit) + + cdef cppclass CDictionaryType" arrow::DictionaryType"(CFixedWidthType): + CDictionaryType(const shared_ptr[CDataType]& index_type, + const shared_ptr[CDataType]& value_type, + c_bool ordered) + + shared_ptr[CDataType] index_type() + shared_ptr[CDataType] value_type() + c_bool ordered() + + shared_ptr[CDataType] ctimestamp" arrow::timestamp"(TimeUnit unit) + shared_ptr[CDataType] ctimestamp" arrow::timestamp"( + TimeUnit unit, const c_string& timezone) + + cdef cppclass CMemoryPool" arrow::MemoryPool": + int64_t bytes_allocated() + int64_t max_memory() + c_string backend_name() + void ReleaseUnused() + + cdef cppclass CLoggingMemoryPool" arrow::LoggingMemoryPool"(CMemoryPool): + CLoggingMemoryPool(CMemoryPool*) + + cdef cppclass CProxyMemoryPool" arrow::ProxyMemoryPool"(CMemoryPool): + CProxyMemoryPool(CMemoryPool*) + + cdef cppclass CBuffer" arrow::Buffer": + CBuffer(const uint8_t* data, int64_t size) + const uint8_t* data() + uint8_t* mutable_data() + uintptr_t address() + uintptr_t mutable_address() + int64_t size() + shared_ptr[CBuffer] parent() + c_bool is_cpu() const + c_bool is_mutable() const + c_string ToHexString() + c_bool Equals(const CBuffer& other) + + shared_ptr[CBuffer] SliceBuffer(const shared_ptr[CBuffer]& buffer, + int64_t offset, int64_t length) + shared_ptr[CBuffer] SliceBuffer(const shared_ptr[CBuffer]& buffer, + int64_t offset) + + cdef cppclass CMutableBuffer" arrow::MutableBuffer"(CBuffer): + CMutableBuffer(const uint8_t* data, int64_t size) + + cdef cppclass CResizableBuffer" arrow::ResizableBuffer"(CMutableBuffer): + CStatus Resize(const int64_t new_size, c_bool shrink_to_fit) + CStatus Reserve(const int64_t new_size) + + CResult[unique_ptr[CBuffer]] AllocateBuffer(const int64_t size, + CMemoryPool* pool) + + CResult[unique_ptr[CResizableBuffer]] AllocateResizableBuffer( + const int64_t size, CMemoryPool* pool) + + cdef CMemoryPool* c_default_memory_pool" arrow::default_memory_pool"() + cdef CMemoryPool* c_system_memory_pool" arrow::system_memory_pool"() + cdef CStatus c_jemalloc_memory_pool" arrow::jemalloc_memory_pool"( + CMemoryPool** out) + cdef CStatus c_mimalloc_memory_pool" arrow::mimalloc_memory_pool"( + CMemoryPool** out) + + CStatus c_jemalloc_set_decay_ms" arrow::jemalloc_set_decay_ms"(int ms) + + cdef cppclass CListType" arrow::ListType"(CDataType): + CListType(const shared_ptr[CDataType]& value_type) + CListType(const shared_ptr[CField]& field) + shared_ptr[CDataType] value_type() + shared_ptr[CField] value_field() + + cdef cppclass CLargeListType" arrow::LargeListType"(CDataType): + CLargeListType(const shared_ptr[CDataType]& value_type) + CLargeListType(const shared_ptr[CField]& field) + shared_ptr[CDataType] value_type() + shared_ptr[CField] value_field() + + cdef cppclass CMapType" arrow::MapType"(CDataType): + CMapType(const shared_ptr[CField]& key_field, + const shared_ptr[CField]& item_field, c_bool keys_sorted) + shared_ptr[CDataType] key_type() + shared_ptr[CField] key_field() + shared_ptr[CDataType] item_type() + shared_ptr[CField] item_field() + c_bool keys_sorted() + + cdef cppclass CFixedSizeListType" arrow::FixedSizeListType"(CDataType): + CFixedSizeListType(const shared_ptr[CDataType]& value_type, + int32_t list_size) + CFixedSizeListType(const shared_ptr[CField]& field, int32_t list_size) + shared_ptr[CDataType] value_type() + shared_ptr[CField] value_field() + int32_t list_size() + + cdef cppclass CStringType" arrow::StringType"(CDataType): + pass + + cdef cppclass CFixedSizeBinaryType \ + " arrow::FixedSizeBinaryType"(CFixedWidthType): + CFixedSizeBinaryType(int byte_width) + int byte_width() + int bit_width() + + cdef cppclass CDecimal128Type \ + " arrow::Decimal128Type"(CFixedSizeBinaryType): + CDecimal128Type(int precision, int scale) + int precision() + int scale() + + cdef cppclass CDecimal256Type \ + " arrow::Decimal256Type"(CFixedSizeBinaryType): + CDecimal256Type(int precision, int scale) + int precision() + int scale() + + cdef cppclass CField" arrow::Field": + cppclass CMergeOptions "arrow::Field::MergeOptions": + c_bool promote_nullability + + @staticmethod + CMergeOptions Defaults() + + const c_string& name() + shared_ptr[CDataType] type() + c_bool nullable() + + c_string ToString() + c_bool Equals(const CField& other, c_bool check_metadata) + + shared_ptr[const CKeyValueMetadata] metadata() + + CField(const c_string& name, const shared_ptr[CDataType]& type, + c_bool nullable) + + CField(const c_string& name, const shared_ptr[CDataType]& type, + c_bool nullable, const shared_ptr[CKeyValueMetadata]& metadata) + + # Removed const in Cython so don't have to cast to get code to generate + shared_ptr[CField] AddMetadata( + const shared_ptr[CKeyValueMetadata]& metadata) + shared_ptr[CField] WithMetadata( + const shared_ptr[CKeyValueMetadata]& metadata) + shared_ptr[CField] RemoveMetadata() + shared_ptr[CField] WithType(const shared_ptr[CDataType]& type) + shared_ptr[CField] WithName(const c_string& name) + shared_ptr[CField] WithNullable(c_bool nullable) + vector[shared_ptr[CField]] Flatten() + + cdef cppclass CFieldRef" arrow::FieldRef": + CFieldRef() + CFieldRef(c_string name) + CFieldRef(int index) + const c_string* name() const + + cdef cppclass CFieldRefHash" arrow::FieldRef::Hash": + pass + + cdef cppclass CStructType" arrow::StructType"(CDataType): + CStructType(const vector[shared_ptr[CField]]& fields) + + shared_ptr[CField] GetFieldByName(const c_string& name) + vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name) + int GetFieldIndex(const c_string& name) + vector[int] GetAllFieldIndices(const c_string& name) + + cdef cppclass CUnionType" arrow::UnionType"(CDataType): + UnionMode mode() + const vector[int8_t]& type_codes() + const vector[int]& child_ids() + + cdef shared_ptr[CDataType] CMakeSparseUnionType" arrow::sparse_union"( + vector[shared_ptr[CField]] fields, + vector[int8_t] type_codes) + + cdef shared_ptr[CDataType] CMakeDenseUnionType" arrow::dense_union"( + vector[shared_ptr[CField]] fields, + vector[int8_t] type_codes) + + cdef cppclass CSchema" arrow::Schema": + CSchema(const vector[shared_ptr[CField]]& fields) + CSchema(const vector[shared_ptr[CField]]& fields, + const shared_ptr[const CKeyValueMetadata]& metadata) + + # Does not actually exist, but gets Cython to not complain + CSchema(const vector[shared_ptr[CField]]& fields, + const shared_ptr[CKeyValueMetadata]& metadata) + + c_bool Equals(const CSchema& other, c_bool check_metadata) + + shared_ptr[CField] field(int i) + shared_ptr[const CKeyValueMetadata] metadata() + shared_ptr[CField] GetFieldByName(const c_string& name) + vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name) + int GetFieldIndex(const c_string& name) + vector[int] GetAllFieldIndices(const c_string& name) + int num_fields() + c_string ToString() + + CResult[shared_ptr[CSchema]] AddField(int i, + const shared_ptr[CField]& field) + CResult[shared_ptr[CSchema]] RemoveField(int i) + CResult[shared_ptr[CSchema]] SetField(int i, + const shared_ptr[CField]& field) + + # Removed const in Cython so don't have to cast to get code to generate + shared_ptr[CSchema] AddMetadata( + const shared_ptr[CKeyValueMetadata]& metadata) + shared_ptr[CSchema] WithMetadata( + const shared_ptr[CKeyValueMetadata]& metadata) + shared_ptr[CSchema] RemoveMetadata() + + CResult[shared_ptr[CSchema]] UnifySchemas( + const vector[shared_ptr[CSchema]]& schemas) + + cdef cppclass PrettyPrintOptions: + PrettyPrintOptions() + PrettyPrintOptions(int indent_arg) + PrettyPrintOptions(int indent_arg, int window_arg) + int indent + int indent_size + int window + c_string null_rep + c_bool skip_new_lines + c_bool truncate_metadata + c_bool show_field_metadata + c_bool show_schema_metadata + + @staticmethod + PrettyPrintOptions Defaults() + + CStatus PrettyPrint(const CArray& schema, + const PrettyPrintOptions& options, + c_string* result) + CStatus PrettyPrint(const CChunkedArray& schema, + const PrettyPrintOptions& options, + c_string* result) + CStatus PrettyPrint(const CSchema& schema, + const PrettyPrintOptions& options, + c_string* result) + + cdef cppclass CBooleanArray" arrow::BooleanArray"(CArray): + c_bool Value(int i) + int64_t false_count() + int64_t true_count() + + cdef cppclass CUInt8Array" arrow::UInt8Array"(CArray): + uint8_t Value(int i) + + cdef cppclass CInt8Array" arrow::Int8Array"(CArray): + int8_t Value(int i) + + cdef cppclass CUInt16Array" arrow::UInt16Array"(CArray): + uint16_t Value(int i) + + cdef cppclass CInt16Array" arrow::Int16Array"(CArray): + int16_t Value(int i) + + cdef cppclass CUInt32Array" arrow::UInt32Array"(CArray): + uint32_t Value(int i) + + cdef cppclass CInt32Array" arrow::Int32Array"(CArray): + int32_t Value(int i) + + cdef cppclass CUInt64Array" arrow::UInt64Array"(CArray): + uint64_t Value(int i) + + cdef cppclass CInt64Array" arrow::Int64Array"(CArray): + int64_t Value(int i) + + cdef cppclass CDate32Array" arrow::Date32Array"(CArray): + int32_t Value(int i) + + cdef cppclass CDate64Array" arrow::Date64Array"(CArray): + int64_t Value(int i) + + cdef cppclass CTime32Array" arrow::Time32Array"(CArray): + int32_t Value(int i) + + cdef cppclass CTime64Array" arrow::Time64Array"(CArray): + int64_t Value(int i) + + cdef cppclass CTimestampArray" arrow::TimestampArray"(CArray): + int64_t Value(int i) + + cdef cppclass CDurationArray" arrow::DurationArray"(CArray): + int64_t Value(int i) + + cdef cppclass CMonthDayNanoIntervalArray \ + "arrow::MonthDayNanoIntervalArray"(CArray): + pass + + cdef cppclass CHalfFloatArray" arrow::HalfFloatArray"(CArray): + uint16_t Value(int i) + + cdef cppclass CFloatArray" arrow::FloatArray"(CArray): + float Value(int i) + + cdef cppclass CDoubleArray" arrow::DoubleArray"(CArray): + double Value(int i) + + cdef cppclass CFixedSizeBinaryArray" arrow::FixedSizeBinaryArray"(CArray): + const uint8_t* GetValue(int i) + + cdef cppclass CDecimal128Array" arrow::Decimal128Array"( + CFixedSizeBinaryArray + ): + c_string FormatValue(int i) + + cdef cppclass CDecimal256Array" arrow::Decimal256Array"( + CFixedSizeBinaryArray + ): + c_string FormatValue(int i) + + cdef cppclass CListArray" arrow::ListArray"(CArray): + @staticmethod + CResult[shared_ptr[CArray]] FromArrays( + const CArray& offsets, const CArray& values, CMemoryPool* pool) + + const int32_t* raw_value_offsets() + int32_t value_offset(int i) + int32_t value_length(int i) + shared_ptr[CArray] values() + shared_ptr[CArray] offsets() + shared_ptr[CDataType] value_type() + + cdef cppclass CLargeListArray" arrow::LargeListArray"(CArray): + @staticmethod + CResult[shared_ptr[CArray]] FromArrays( + const CArray& offsets, const CArray& values, CMemoryPool* pool) + + int64_t value_offset(int i) + int64_t value_length(int i) + shared_ptr[CArray] values() + shared_ptr[CArray] offsets() + shared_ptr[CDataType] value_type() + + cdef cppclass CFixedSizeListArray" arrow::FixedSizeListArray"(CArray): + @staticmethod + CResult[shared_ptr[CArray]] FromArrays( + const shared_ptr[CArray]& values, int32_t list_size) + + int64_t value_offset(int i) + int64_t value_length(int i) + shared_ptr[CArray] values() + shared_ptr[CDataType] value_type() + + cdef cppclass CMapArray" arrow::MapArray"(CArray): + @staticmethod + CResult[shared_ptr[CArray]] FromArrays( + const shared_ptr[CArray]& offsets, + const shared_ptr[CArray]& keys, + const shared_ptr[CArray]& items, + CMemoryPool* pool) + + shared_ptr[CArray] keys() + shared_ptr[CArray] items() + CMapType* map_type() + int64_t value_offset(int i) + int64_t value_length(int i) + shared_ptr[CArray] values() + shared_ptr[CDataType] value_type() + + cdef cppclass CUnionArray" arrow::UnionArray"(CArray): + shared_ptr[CBuffer] type_codes() + int8_t* raw_type_codes() + int child_id(int64_t index) + shared_ptr[CArray] field(int pos) + const CArray* UnsafeField(int pos) + UnionMode mode() + + cdef cppclass CSparseUnionArray" arrow::SparseUnionArray"(CUnionArray): + @staticmethod + CResult[shared_ptr[CArray]] Make( + const CArray& type_codes, + const vector[shared_ptr[CArray]]& children, + const vector[c_string]& field_names, + const vector[int8_t]& type_codes) + + cdef cppclass CDenseUnionArray" arrow::DenseUnionArray"(CUnionArray): + @staticmethod + CResult[shared_ptr[CArray]] Make( + const CArray& type_codes, + const CArray& value_offsets, + const vector[shared_ptr[CArray]]& children, + const vector[c_string]& field_names, + const vector[int8_t]& type_codes) + + int32_t value_offset(int i) + shared_ptr[CBuffer] value_offsets() + + cdef cppclass CBinaryArray" arrow::BinaryArray"(CArray): + const uint8_t* GetValue(int i, int32_t* length) + shared_ptr[CBuffer] value_data() + int32_t value_offset(int64_t i) + int32_t value_length(int64_t i) + int32_t total_values_length() + + cdef cppclass CLargeBinaryArray" arrow::LargeBinaryArray"(CArray): + const uint8_t* GetValue(int i, int64_t* length) + shared_ptr[CBuffer] value_data() + int64_t value_offset(int64_t i) + int64_t value_length(int64_t i) + int64_t total_values_length() + + cdef cppclass CStringArray" arrow::StringArray"(CBinaryArray): + CStringArray(int64_t length, shared_ptr[CBuffer] value_offsets, + shared_ptr[CBuffer] data, + shared_ptr[CBuffer] null_bitmap, + int64_t null_count, + int64_t offset) + + c_string GetString(int i) + + cdef cppclass CLargeStringArray" arrow::LargeStringArray" \ + (CLargeBinaryArray): + CLargeStringArray(int64_t length, shared_ptr[CBuffer] value_offsets, + shared_ptr[CBuffer] data, + shared_ptr[CBuffer] null_bitmap, + int64_t null_count, + int64_t offset) + + c_string GetString(int i) + + cdef cppclass CStructArray" arrow::StructArray"(CArray): + CStructArray(shared_ptr[CDataType]& type, int64_t length, + vector[shared_ptr[CArray]]& children, + shared_ptr[CBuffer] null_bitmap=nullptr, + int64_t null_count=-1, + int64_t offset=0) + + # XXX Cython crashes if default argument values are declared here + # https://github.com/cython/cython/issues/2167 + @staticmethod + CResult[shared_ptr[CArray]] MakeFromFieldNames "Make"( + vector[shared_ptr[CArray]] children, + vector[c_string] field_names, + shared_ptr[CBuffer] null_bitmap, + int64_t null_count, + int64_t offset) + + @staticmethod + CResult[shared_ptr[CArray]] MakeFromFields "Make"( + vector[shared_ptr[CArray]] children, + vector[shared_ptr[CField]] fields, + shared_ptr[CBuffer] null_bitmap, + int64_t null_count, + int64_t offset) + + shared_ptr[CArray] field(int pos) + shared_ptr[CArray] GetFieldByName(const c_string& name) const + + CResult[vector[shared_ptr[CArray]]] Flatten(CMemoryPool* pool) + + cdef cppclass CChunkedArray" arrow::ChunkedArray": + CChunkedArray(const vector[shared_ptr[CArray]]& arrays) + CChunkedArray(const vector[shared_ptr[CArray]]& arrays, + const shared_ptr[CDataType]& type) + int64_t length() + int64_t null_count() + int num_chunks() + c_bool Equals(const CChunkedArray& other) + + shared_ptr[CArray] chunk(int i) + shared_ptr[CDataType] type() + shared_ptr[CChunkedArray] Slice(int64_t offset, int64_t length) const + shared_ptr[CChunkedArray] Slice(int64_t offset) const + + CResult[vector[shared_ptr[CChunkedArray]]] Flatten(CMemoryPool* pool) + + CStatus Validate() const + CStatus ValidateFull() const + + cdef cppclass CRecordBatch" arrow::RecordBatch": + @staticmethod + shared_ptr[CRecordBatch] Make( + const shared_ptr[CSchema]& schema, int64_t num_rows, + const vector[shared_ptr[CArray]]& columns) + + @staticmethod + CResult[shared_ptr[CRecordBatch]] FromStructArray( + const shared_ptr[CArray]& array) + + c_bool Equals(const CRecordBatch& other, c_bool check_metadata) + + shared_ptr[CSchema] schema() + shared_ptr[CArray] column(int i) + const c_string& column_name(int i) + + const vector[shared_ptr[CArray]]& columns() + + int num_columns() + int64_t num_rows() + + CStatus Validate() const + CStatus ValidateFull() const + + shared_ptr[CRecordBatch] ReplaceSchemaMetadata( + const shared_ptr[CKeyValueMetadata]& metadata) + + shared_ptr[CRecordBatch] Slice(int64_t offset) + shared_ptr[CRecordBatch] Slice(int64_t offset, int64_t length) + + cdef cppclass CTable" arrow::Table": + CTable(const shared_ptr[CSchema]& schema, + const vector[shared_ptr[CChunkedArray]]& columns) + + @staticmethod + shared_ptr[CTable] Make( + const shared_ptr[CSchema]& schema, + const vector[shared_ptr[CChunkedArray]]& columns) + + @staticmethod + shared_ptr[CTable] MakeFromArrays" Make"( + const shared_ptr[CSchema]& schema, + const vector[shared_ptr[CArray]]& arrays) + + @staticmethod + CResult[shared_ptr[CTable]] FromRecordBatches( + const shared_ptr[CSchema]& schema, + const vector[shared_ptr[CRecordBatch]]& batches) + + int num_columns() + int64_t num_rows() + + c_bool Equals(const CTable& other, c_bool check_metadata) + + shared_ptr[CSchema] schema() + shared_ptr[CChunkedArray] column(int i) + shared_ptr[CField] field(int i) + + CResult[shared_ptr[CTable]] AddColumn( + int i, shared_ptr[CField] field, shared_ptr[CChunkedArray] column) + CResult[shared_ptr[CTable]] RemoveColumn(int i) + CResult[shared_ptr[CTable]] SetColumn( + int i, shared_ptr[CField] field, shared_ptr[CChunkedArray] column) + + vector[c_string] ColumnNames() + CResult[shared_ptr[CTable]] RenameColumns(const vector[c_string]&) + CResult[shared_ptr[CTable]] SelectColumns(const vector[int]&) + + CResult[shared_ptr[CTable]] Flatten(CMemoryPool* pool) + + CResult[shared_ptr[CTable]] CombineChunks(CMemoryPool* pool) + + CStatus Validate() const + CStatus ValidateFull() const + + shared_ptr[CTable] ReplaceSchemaMetadata( + const shared_ptr[CKeyValueMetadata]& metadata) + + shared_ptr[CTable] Slice(int64_t offset) + shared_ptr[CTable] Slice(int64_t offset, int64_t length) + + cdef cppclass CRecordBatchReader" arrow::RecordBatchReader": + shared_ptr[CSchema] schema() + CStatus ReadNext(shared_ptr[CRecordBatch]* batch) + CStatus ReadAll(shared_ptr[CTable]* out) + + cdef cppclass TableBatchReader(CRecordBatchReader): + TableBatchReader(const CTable& table) + void set_chunksize(int64_t chunksize) + + cdef cppclass CTensor" arrow::Tensor": + shared_ptr[CDataType] type() + shared_ptr[CBuffer] data() + + const vector[int64_t]& shape() + const vector[int64_t]& strides() + int64_t size() + + int ndim() + const vector[c_string]& dim_names() + const c_string& dim_name(int i) + + c_bool is_mutable() + c_bool is_contiguous() + Type type_id() + c_bool Equals(const CTensor& other) + + cdef cppclass CSparseIndex" arrow::SparseIndex": + pass + + cdef cppclass CSparseCOOIndex" arrow::SparseCOOIndex": + c_bool is_canonical() + + cdef cppclass CSparseCOOTensor" arrow::SparseCOOTensor": + shared_ptr[CDataType] type() + shared_ptr[CBuffer] data() + CResult[shared_ptr[CTensor]] ToTensor() + + shared_ptr[CSparseIndex] sparse_index() + + const vector[int64_t]& shape() + int64_t size() + int64_t non_zero_length() + + int ndim() + const vector[c_string]& dim_names() + const c_string& dim_name(int i) + + c_bool is_mutable() + Type type_id() + c_bool Equals(const CSparseCOOTensor& other) + + cdef cppclass CSparseCSRMatrix" arrow::SparseCSRMatrix": + shared_ptr[CDataType] type() + shared_ptr[CBuffer] data() + CResult[shared_ptr[CTensor]] ToTensor() + + const vector[int64_t]& shape() + int64_t size() + int64_t non_zero_length() + + int ndim() + const vector[c_string]& dim_names() + const c_string& dim_name(int i) + + c_bool is_mutable() + Type type_id() + c_bool Equals(const CSparseCSRMatrix& other) + + cdef cppclass CSparseCSCMatrix" arrow::SparseCSCMatrix": + shared_ptr[CDataType] type() + shared_ptr[CBuffer] data() + CResult[shared_ptr[CTensor]] ToTensor() + + const vector[int64_t]& shape() + int64_t size() + int64_t non_zero_length() + + int ndim() + const vector[c_string]& dim_names() + const c_string& dim_name(int i) + + c_bool is_mutable() + Type type_id() + c_bool Equals(const CSparseCSCMatrix& other) + + cdef cppclass CSparseCSFTensor" arrow::SparseCSFTensor": + shared_ptr[CDataType] type() + shared_ptr[CBuffer] data() + CResult[shared_ptr[CTensor]] ToTensor() + + const vector[int64_t]& shape() + int64_t size() + int64_t non_zero_length() + + int ndim() + const vector[c_string]& dim_names() + const c_string& dim_name(int i) + + c_bool is_mutable() + Type type_id() + c_bool Equals(const CSparseCSFTensor& other) + + cdef cppclass CScalar" arrow::Scalar": + CScalar(shared_ptr[CDataType]) + + shared_ptr[CDataType] type + c_bool is_valid + + c_string ToString() const + c_bool Equals(const CScalar& other) const + CStatus Validate() const + CStatus ValidateFull() const + CResult[shared_ptr[CScalar]] CastTo(shared_ptr[CDataType] to) const + + cdef cppclass CScalarHash" arrow::Scalar::Hash": + size_t operator()(const shared_ptr[CScalar]& scalar) const + + cdef cppclass CNullScalar" arrow::NullScalar"(CScalar): + CNullScalar() + + cdef cppclass CBooleanScalar" arrow::BooleanScalar"(CScalar): + c_bool value + + cdef cppclass CInt8Scalar" arrow::Int8Scalar"(CScalar): + int8_t value + + cdef cppclass CUInt8Scalar" arrow::UInt8Scalar"(CScalar): + uint8_t value + + cdef cppclass CInt16Scalar" arrow::Int16Scalar"(CScalar): + int16_t value + + cdef cppclass CUInt16Scalar" arrow::UInt16Scalar"(CScalar): + uint16_t value + + cdef cppclass CInt32Scalar" arrow::Int32Scalar"(CScalar): + int32_t value + + cdef cppclass CUInt32Scalar" arrow::UInt32Scalar"(CScalar): + uint32_t value + + cdef cppclass CInt64Scalar" arrow::Int64Scalar"(CScalar): + int64_t value + + cdef cppclass CUInt64Scalar" arrow::UInt64Scalar"(CScalar): + uint64_t value + + cdef cppclass CHalfFloatScalar" arrow::HalfFloatScalar"(CScalar): + npy_half value + + cdef cppclass CFloatScalar" arrow::FloatScalar"(CScalar): + float value + + cdef cppclass CDoubleScalar" arrow::DoubleScalar"(CScalar): + double value + + cdef cppclass CDecimal128Scalar" arrow::Decimal128Scalar"(CScalar): + CDecimal128 value + + cdef cppclass CDecimal256Scalar" arrow::Decimal256Scalar"(CScalar): + CDecimal256 value + + cdef cppclass CDate32Scalar" arrow::Date32Scalar"(CScalar): + int32_t value + + cdef cppclass CDate64Scalar" arrow::Date64Scalar"(CScalar): + int64_t value + + cdef cppclass CTime32Scalar" arrow::Time32Scalar"(CScalar): + int32_t value + + cdef cppclass CTime64Scalar" arrow::Time64Scalar"(CScalar): + int64_t value + + cdef cppclass CTimestampScalar" arrow::TimestampScalar"(CScalar): + int64_t value + + cdef cppclass CDurationScalar" arrow::DurationScalar"(CScalar): + int64_t value + + cdef cppclass CMonthDayNanoIntervalScalar \ + "arrow::MonthDayNanoIntervalScalar"(CScalar): + pass + + cdef cppclass CBaseBinaryScalar" arrow::BaseBinaryScalar"(CScalar): + shared_ptr[CBuffer] value + + cdef cppclass CBaseListScalar" arrow::BaseListScalar"(CScalar): + shared_ptr[CArray] value + + cdef cppclass CListScalar" arrow::ListScalar"(CBaseListScalar): + pass + + cdef cppclass CMapScalar" arrow::MapScalar"(CListScalar): + pass + + cdef cppclass CStructScalar" arrow::StructScalar"(CScalar): + vector[shared_ptr[CScalar]] value + CResult[shared_ptr[CScalar]] field(CFieldRef ref) const + + cdef cppclass CDictionaryScalarIndexAndDictionary \ + "arrow::DictionaryScalar::ValueType": + shared_ptr[CScalar] index + shared_ptr[CArray] dictionary + + cdef cppclass CDictionaryScalar" arrow::DictionaryScalar"(CScalar): + CDictionaryScalar(CDictionaryScalarIndexAndDictionary value, + shared_ptr[CDataType], c_bool is_valid) + CDictionaryScalarIndexAndDictionary value + + CResult[shared_ptr[CScalar]] GetEncodedValue() + + cdef cppclass CUnionScalar" arrow::UnionScalar"(CScalar): + shared_ptr[CScalar] value + int8_t type_code + + cdef cppclass CExtensionScalar" arrow::ExtensionScalar"(CScalar): + shared_ptr[CScalar] value + + shared_ptr[CScalar] MakeScalar[Value](Value value) + + cdef cppclass CConcatenateTablesOptions" arrow::ConcatenateTablesOptions": + c_bool unify_schemas + CField.CMergeOptions field_merge_options + + @staticmethod + CConcatenateTablesOptions Defaults() + + CResult[shared_ptr[CTable]] ConcatenateTables( + const vector[shared_ptr[CTable]]& tables, + CConcatenateTablesOptions options, + CMemoryPool* memory_pool) + + cdef cppclass CDictionaryUnifier" arrow::DictionaryUnifier": + @staticmethod + CResult[shared_ptr[CChunkedArray]] UnifyChunkedArray( + shared_ptr[CChunkedArray] array, CMemoryPool* pool) + + @staticmethod + CResult[shared_ptr[CTable]] UnifyTable( + const CTable& table, CMemoryPool* pool) + + +cdef extern from "arrow/builder.h" namespace "arrow" nogil: + + cdef cppclass CArrayBuilder" arrow::ArrayBuilder": + CArrayBuilder(shared_ptr[CDataType], CMemoryPool* pool) + + int64_t length() + int64_t null_count() + CStatus AppendNull() + CStatus Finish(shared_ptr[CArray]* out) + CStatus Reserve(int64_t additional_capacity) + + cdef cppclass CBooleanBuilder" arrow::BooleanBuilder"(CArrayBuilder): + CBooleanBuilder(CMemoryPool* pool) + CStatus Append(const c_bool val) + CStatus Append(const uint8_t val) + + cdef cppclass CInt8Builder" arrow::Int8Builder"(CArrayBuilder): + CInt8Builder(CMemoryPool* pool) + CStatus Append(const int8_t value) + + cdef cppclass CInt16Builder" arrow::Int16Builder"(CArrayBuilder): + CInt16Builder(CMemoryPool* pool) + CStatus Append(const int16_t value) + + cdef cppclass CInt32Builder" arrow::Int32Builder"(CArrayBuilder): + CInt32Builder(CMemoryPool* pool) + CStatus Append(const int32_t value) + + cdef cppclass CInt64Builder" arrow::Int64Builder"(CArrayBuilder): + CInt64Builder(CMemoryPool* pool) + CStatus Append(const int64_t value) + + cdef cppclass CUInt8Builder" arrow::UInt8Builder"(CArrayBuilder): + CUInt8Builder(CMemoryPool* pool) + CStatus Append(const uint8_t value) + + cdef cppclass CUInt16Builder" arrow::UInt16Builder"(CArrayBuilder): + CUInt16Builder(CMemoryPool* pool) + CStatus Append(const uint16_t value) + + cdef cppclass CUInt32Builder" arrow::UInt32Builder"(CArrayBuilder): + CUInt32Builder(CMemoryPool* pool) + CStatus Append(const uint32_t value) + + cdef cppclass CUInt64Builder" arrow::UInt64Builder"(CArrayBuilder): + CUInt64Builder(CMemoryPool* pool) + CStatus Append(const uint64_t value) + + cdef cppclass CHalfFloatBuilder" arrow::HalfFloatBuilder"(CArrayBuilder): + CHalfFloatBuilder(CMemoryPool* pool) + + cdef cppclass CFloatBuilder" arrow::FloatBuilder"(CArrayBuilder): + CFloatBuilder(CMemoryPool* pool) + CStatus Append(const float value) + + cdef cppclass CDoubleBuilder" arrow::DoubleBuilder"(CArrayBuilder): + CDoubleBuilder(CMemoryPool* pool) + CStatus Append(const double value) + + cdef cppclass CBinaryBuilder" arrow::BinaryBuilder"(CArrayBuilder): + CArrayBuilder(shared_ptr[CDataType], CMemoryPool* pool) + CStatus Append(const char* value, int32_t length) + + cdef cppclass CStringBuilder" arrow::StringBuilder"(CBinaryBuilder): + CStringBuilder(CMemoryPool* pool) + + CStatus Append(const c_string& value) + + cdef cppclass CTimestampBuilder "arrow::TimestampBuilder"(CArrayBuilder): + CTimestampBuilder(const shared_ptr[CDataType] typ, CMemoryPool* pool) + CStatus Append(const int64_t value) + + cdef cppclass CDate32Builder "arrow::Date32Builder"(CArrayBuilder): + CDate32Builder(CMemoryPool* pool) + CStatus Append(const int32_t value) + + cdef cppclass CDate64Builder "arrow::Date64Builder"(CArrayBuilder): + CDate64Builder(CMemoryPool* pool) + CStatus Append(const int64_t value) + + +# Use typedef to emulate syntax for std::function<void(..)> +ctypedef void CallbackTransform(object, const shared_ptr[CBuffer]& src, + shared_ptr[CBuffer]* dest) + + +cdef extern from "arrow/util/cancel.h" namespace "arrow" nogil: + cdef cppclass CStopToken "arrow::StopToken": + CStatus Poll() + c_bool IsStopRequested() + + cdef cppclass CStopSource "arrow::StopSource": + CStopToken token() + + CResult[CStopSource*] SetSignalStopSource() + void ResetSignalStopSource() + + CStatus RegisterCancellingSignalHandler(vector[int] signals) + void UnregisterCancellingSignalHandler() + + +cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil: + cdef enum FileMode" arrow::io::FileMode::type": + FileMode_READ" arrow::io::FileMode::READ" + FileMode_WRITE" arrow::io::FileMode::WRITE" + FileMode_READWRITE" arrow::io::FileMode::READWRITE" + + cdef enum ObjectType" arrow::io::ObjectType::type": + ObjectType_FILE" arrow::io::ObjectType::FILE" + ObjectType_DIRECTORY" arrow::io::ObjectType::DIRECTORY" + + cdef cppclass CIOContext" arrow::io::IOContext": + CIOContext() + CIOContext(CStopToken) + CIOContext(CMemoryPool*) + CIOContext(CMemoryPool*, CStopToken) + + CIOContext c_default_io_context "arrow::io::default_io_context"() + int GetIOThreadPoolCapacity() + CStatus SetIOThreadPoolCapacity(int threads) + + cdef cppclass FileStatistics: + int64_t size + ObjectType kind + + cdef cppclass FileInterface: + CStatus Close() + CResult[int64_t] Tell() + FileMode mode() + c_bool closed() + + cdef cppclass Readable: + # put overload under a different name to avoid cython bug with multiple + # layers of inheritance + CResult[shared_ptr[CBuffer]] ReadBuffer" Read"(int64_t nbytes) + CResult[int64_t] Read(int64_t nbytes, uint8_t* out) + + cdef cppclass Seekable: + CStatus Seek(int64_t position) + + cdef cppclass Writable: + CStatus WriteBuffer" Write"(shared_ptr[CBuffer] data) + CStatus Write(const uint8_t* data, int64_t nbytes) + CStatus Flush() + + cdef cppclass COutputStream" arrow::io::OutputStream"(FileInterface, + Writable): + pass + + cdef cppclass CInputStream" arrow::io::InputStream"(FileInterface, + Readable): + CResult[shared_ptr[const CKeyValueMetadata]] ReadMetadata() + + cdef cppclass CRandomAccessFile" arrow::io::RandomAccessFile"(CInputStream, + Seekable): + CResult[int64_t] GetSize() + + CResult[int64_t] ReadAt(int64_t position, int64_t nbytes, + uint8_t* buffer) + CResult[shared_ptr[CBuffer]] ReadAt(int64_t position, int64_t nbytes) + c_bool supports_zero_copy() + + cdef cppclass WritableFile(COutputStream, Seekable): + CStatus WriteAt(int64_t position, const uint8_t* data, + int64_t nbytes) + + cdef cppclass ReadWriteFileInterface(CRandomAccessFile, + WritableFile): + pass + + cdef cppclass CIOFileSystem" arrow::io::FileSystem": + CStatus Stat(const c_string& path, FileStatistics* stat) + + cdef cppclass FileOutputStream(COutputStream): + @staticmethod + CResult[shared_ptr[COutputStream]] Open(const c_string& path) + + int file_descriptor() + + cdef cppclass ReadableFile(CRandomAccessFile): + @staticmethod + CResult[shared_ptr[ReadableFile]] Open(const c_string& path) + + @staticmethod + CResult[shared_ptr[ReadableFile]] Open(const c_string& path, + CMemoryPool* memory_pool) + + int file_descriptor() + + cdef cppclass CMemoryMappedFile \ + " arrow::io::MemoryMappedFile"(ReadWriteFileInterface): + + @staticmethod + CResult[shared_ptr[CMemoryMappedFile]] Create(const c_string& path, + int64_t size) + + @staticmethod + CResult[shared_ptr[CMemoryMappedFile]] Open(const c_string& path, + FileMode mode) + + CStatus Resize(int64_t size) + + int file_descriptor() + + cdef cppclass CCompressedInputStream \ + " arrow::io::CompressedInputStream"(CInputStream): + @staticmethod + CResult[shared_ptr[CCompressedInputStream]] Make( + CCodec* codec, shared_ptr[CInputStream] raw) + + cdef cppclass CCompressedOutputStream \ + " arrow::io::CompressedOutputStream"(COutputStream): + @staticmethod + CResult[shared_ptr[CCompressedOutputStream]] Make( + CCodec* codec, shared_ptr[COutputStream] raw) + + cdef cppclass CBufferedInputStream \ + " arrow::io::BufferedInputStream"(CInputStream): + + @staticmethod + CResult[shared_ptr[CBufferedInputStream]] Create( + int64_t buffer_size, CMemoryPool* pool, + shared_ptr[CInputStream] raw) + + CResult[shared_ptr[CInputStream]] Detach() + + cdef cppclass CBufferedOutputStream \ + " arrow::io::BufferedOutputStream"(COutputStream): + + @staticmethod + CResult[shared_ptr[CBufferedOutputStream]] Create( + int64_t buffer_size, CMemoryPool* pool, + shared_ptr[COutputStream] raw) + + CResult[shared_ptr[COutputStream]] Detach() + + cdef cppclass CTransformInputStreamVTable \ + "arrow::py::TransformInputStreamVTable": + CTransformInputStreamVTable() + function[CallbackTransform] transform + + shared_ptr[CInputStream] MakeTransformInputStream \ + "arrow::py::MakeTransformInputStream"( + shared_ptr[CInputStream] wrapped, CTransformInputStreamVTable vtable, + object method_arg) + + # ---------------------------------------------------------------------- + # HDFS + + CStatus HaveLibHdfs() + CStatus HaveLibHdfs3() + + cdef enum HdfsDriver" arrow::io::HdfsDriver": + HdfsDriver_LIBHDFS" arrow::io::HdfsDriver::LIBHDFS" + HdfsDriver_LIBHDFS3" arrow::io::HdfsDriver::LIBHDFS3" + + cdef cppclass HdfsConnectionConfig: + c_string host + int port + c_string user + c_string kerb_ticket + unordered_map[c_string, c_string] extra_conf + HdfsDriver driver + + cdef cppclass HdfsPathInfo: + ObjectType kind + c_string name + c_string owner + c_string group + int32_t last_modified_time + int32_t last_access_time + int64_t size + int16_t replication + int64_t block_size + int16_t permissions + + cdef cppclass HdfsReadableFile(CRandomAccessFile): + pass + + cdef cppclass HdfsOutputStream(COutputStream): + pass + + cdef cppclass CIOHadoopFileSystem \ + "arrow::io::HadoopFileSystem"(CIOFileSystem): + @staticmethod + CStatus Connect(const HdfsConnectionConfig* config, + shared_ptr[CIOHadoopFileSystem]* client) + + CStatus MakeDirectory(const c_string& path) + + CStatus Delete(const c_string& path, c_bool recursive) + + CStatus Disconnect() + + c_bool Exists(const c_string& path) + + CStatus Chmod(const c_string& path, int mode) + CStatus Chown(const c_string& path, const char* owner, + const char* group) + + CStatus GetCapacity(int64_t* nbytes) + CStatus GetUsed(int64_t* nbytes) + + CStatus ListDirectory(const c_string& path, + vector[HdfsPathInfo]* listing) + + CStatus GetPathInfo(const c_string& path, HdfsPathInfo* info) + + CStatus Rename(const c_string& src, const c_string& dst) + + CStatus OpenReadable(const c_string& path, + shared_ptr[HdfsReadableFile]* handle) + + CStatus OpenWritable(const c_string& path, c_bool append, + int32_t buffer_size, int16_t replication, + int64_t default_block_size, + shared_ptr[HdfsOutputStream]* handle) + + cdef cppclass CBufferReader \ + " arrow::io::BufferReader"(CRandomAccessFile): + CBufferReader(const shared_ptr[CBuffer]& buffer) + CBufferReader(const uint8_t* data, int64_t nbytes) + + cdef cppclass CBufferOutputStream \ + " arrow::io::BufferOutputStream"(COutputStream): + CBufferOutputStream(const shared_ptr[CResizableBuffer]& buffer) + + cdef cppclass CMockOutputStream \ + " arrow::io::MockOutputStream"(COutputStream): + CMockOutputStream() + int64_t GetExtentBytesWritten() + + cdef cppclass CFixedSizeBufferWriter \ + " arrow::io::FixedSizeBufferWriter"(WritableFile): + CFixedSizeBufferWriter(const shared_ptr[CBuffer]& buffer) + + void set_memcopy_threads(int num_threads) + void set_memcopy_blocksize(int64_t blocksize) + void set_memcopy_threshold(int64_t threshold) + + +cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: + cdef enum MessageType" arrow::ipc::MessageType": + MessageType_SCHEMA" arrow::ipc::MessageType::SCHEMA" + MessageType_RECORD_BATCH" arrow::ipc::MessageType::RECORD_BATCH" + MessageType_DICTIONARY_BATCH \ + " arrow::ipc::MessageType::DICTIONARY_BATCH" + + # TODO: use "cpdef enum class" to automatically get a Python wrapper? + # See + # https://github.com/cython/cython/commit/2c7c22f51405299a4e247f78edf52957d30cf71d#diff-61c1365c0f761a8137754bb3a73bfbf7 + ctypedef enum CMetadataVersion" arrow::ipc::MetadataVersion": + CMetadataVersion_V1" arrow::ipc::MetadataVersion::V1" + CMetadataVersion_V2" arrow::ipc::MetadataVersion::V2" + CMetadataVersion_V3" arrow::ipc::MetadataVersion::V3" + CMetadataVersion_V4" arrow::ipc::MetadataVersion::V4" + CMetadataVersion_V5" arrow::ipc::MetadataVersion::V5" + + cdef cppclass CIpcWriteOptions" arrow::ipc::IpcWriteOptions": + c_bool allow_64bit + int max_recursion_depth + int32_t alignment + c_bool write_legacy_ipc_format + CMemoryPool* memory_pool + CMetadataVersion metadata_version + shared_ptr[CCodec] codec + c_bool use_threads + c_bool emit_dictionary_deltas + + @staticmethod + CIpcWriteOptions Defaults() + + cdef cppclass CIpcReadOptions" arrow::ipc::IpcReadOptions": + int max_recursion_depth + CMemoryPool* memory_pool + shared_ptr[unordered_set[int]] included_fields + + @staticmethod + CIpcReadOptions Defaults() + + cdef cppclass CIpcWriteStats" arrow::ipc::WriteStats": + int64_t num_messages + int64_t num_record_batches + int64_t num_dictionary_batches + int64_t num_dictionary_deltas + int64_t num_replaced_dictionaries + + cdef cppclass CIpcReadStats" arrow::ipc::ReadStats": + int64_t num_messages + int64_t num_record_batches + int64_t num_dictionary_batches + int64_t num_dictionary_deltas + int64_t num_replaced_dictionaries + + cdef cppclass CDictionaryMemo" arrow::ipc::DictionaryMemo": + pass + + cdef cppclass CIpcPayload" arrow::ipc::IpcPayload": + MessageType type + shared_ptr[CBuffer] metadata + vector[shared_ptr[CBuffer]] body_buffers + int64_t body_length + + cdef cppclass CMessage" arrow::ipc::Message": + CResult[unique_ptr[CMessage]] Open(shared_ptr[CBuffer] metadata, + shared_ptr[CBuffer] body) + + shared_ptr[CBuffer] body() + + c_bool Equals(const CMessage& other) + + shared_ptr[CBuffer] metadata() + CMetadataVersion metadata_version() + MessageType type() + + CStatus SerializeTo(COutputStream* stream, + const CIpcWriteOptions& options, + int64_t* output_length) + + c_string FormatMessageType(MessageType type) + + cdef cppclass CMessageReader" arrow::ipc::MessageReader": + @staticmethod + unique_ptr[CMessageReader] Open(const shared_ptr[CInputStream]& stream) + + CResult[unique_ptr[CMessage]] ReadNextMessage() + + cdef cppclass CRecordBatchWriter" arrow::ipc::RecordBatchWriter": + CStatus Close() + CStatus WriteRecordBatch(const CRecordBatch& batch) + CStatus WriteTable(const CTable& table, int64_t max_chunksize) + + CIpcWriteStats stats() + + cdef cppclass CRecordBatchStreamReader \ + " arrow::ipc::RecordBatchStreamReader"(CRecordBatchReader): + @staticmethod + CResult[shared_ptr[CRecordBatchReader]] Open( + const shared_ptr[CInputStream], const CIpcReadOptions&) + + @staticmethod + CResult[shared_ptr[CRecordBatchReader]] Open2" Open"( + unique_ptr[CMessageReader] message_reader, + const CIpcReadOptions& options) + + CIpcReadStats stats() + + cdef cppclass CRecordBatchFileReader \ + " arrow::ipc::RecordBatchFileReader": + @staticmethod + CResult[shared_ptr[CRecordBatchFileReader]] Open( + CRandomAccessFile* file, + const CIpcReadOptions& options) + + @staticmethod + CResult[shared_ptr[CRecordBatchFileReader]] Open2" Open"( + CRandomAccessFile* file, int64_t footer_offset, + const CIpcReadOptions& options) + + shared_ptr[CSchema] schema() + + int num_record_batches() + + CResult[shared_ptr[CRecordBatch]] ReadRecordBatch(int i) + + CIpcReadStats stats() + + CResult[shared_ptr[CRecordBatchWriter]] MakeStreamWriter( + shared_ptr[COutputStream] sink, const shared_ptr[CSchema]& schema, + CIpcWriteOptions& options) + + CResult[shared_ptr[CRecordBatchWriter]] MakeFileWriter( + shared_ptr[COutputStream] sink, const shared_ptr[CSchema]& schema, + CIpcWriteOptions& options) + + CResult[unique_ptr[CMessage]] ReadMessage(CInputStream* stream, + CMemoryPool* pool) + + CStatus GetRecordBatchSize(const CRecordBatch& batch, int64_t* size) + CStatus GetTensorSize(const CTensor& tensor, int64_t* size) + + CStatus WriteTensor(const CTensor& tensor, COutputStream* dst, + int32_t* metadata_length, + int64_t* body_length) + + CResult[shared_ptr[CTensor]] ReadTensor(CInputStream* stream) + + CResult[shared_ptr[CRecordBatch]] ReadRecordBatch( + const CMessage& message, const shared_ptr[CSchema]& schema, + CDictionaryMemo* dictionary_memo, + const CIpcReadOptions& options) + + CResult[shared_ptr[CBuffer]] SerializeSchema( + const CSchema& schema, CMemoryPool* pool) + + CResult[shared_ptr[CBuffer]] SerializeRecordBatch( + const CRecordBatch& schema, const CIpcWriteOptions& options) + + CResult[shared_ptr[CSchema]] ReadSchema(CInputStream* stream, + CDictionaryMemo* dictionary_memo) + + CResult[shared_ptr[CRecordBatch]] ReadRecordBatch( + const shared_ptr[CSchema]& schema, + CDictionaryMemo* dictionary_memo, + const CIpcReadOptions& options, + CInputStream* stream) + + CStatus AlignStream(CInputStream* stream, int64_t alignment) + CStatus AlignStream(COutputStream* stream, int64_t alignment) + + cdef CStatus GetRecordBatchPayload \ + " arrow::ipc::GetRecordBatchPayload"( + const CRecordBatch& batch, + const CIpcWriteOptions& options, + CIpcPayload* out) + + +cdef extern from "arrow/util/value_parsing.h" namespace "arrow" nogil: + cdef cppclass CTimestampParser" arrow::TimestampParser": + const char* kind() const + const char* format() const + + @staticmethod + shared_ptr[CTimestampParser] MakeStrptime(c_string format) + + @staticmethod + shared_ptr[CTimestampParser] MakeISO8601() + + +cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil: + + cdef cppclass CCSVParseOptions" arrow::csv::ParseOptions": + unsigned char delimiter + c_bool quoting + unsigned char quote_char + c_bool double_quote + c_bool escaping + unsigned char escape_char + c_bool newlines_in_values + c_bool ignore_empty_lines + + CCSVParseOptions() + CCSVParseOptions(CCSVParseOptions&&) + + @staticmethod + CCSVParseOptions Defaults() + + CStatus Validate() + + cdef cppclass CCSVConvertOptions" arrow::csv::ConvertOptions": + c_bool check_utf8 + unordered_map[c_string, shared_ptr[CDataType]] column_types + vector[c_string] null_values + vector[c_string] true_values + vector[c_string] false_values + c_bool strings_can_be_null + c_bool quoted_strings_can_be_null + vector[shared_ptr[CTimestampParser]] timestamp_parsers + + c_bool auto_dict_encode + int32_t auto_dict_max_cardinality + unsigned char decimal_point + + vector[c_string] include_columns + c_bool include_missing_columns + + CCSVConvertOptions() + CCSVConvertOptions(CCSVConvertOptions&&) + + @staticmethod + CCSVConvertOptions Defaults() + + CStatus Validate() + + cdef cppclass CCSVReadOptions" arrow::csv::ReadOptions": + c_bool use_threads + int32_t block_size + int32_t skip_rows + int32_t skip_rows_after_names + vector[c_string] column_names + c_bool autogenerate_column_names + + CCSVReadOptions() + CCSVReadOptions(CCSVReadOptions&&) + + @staticmethod + CCSVReadOptions Defaults() + + CStatus Validate() + + cdef cppclass CCSVWriteOptions" arrow::csv::WriteOptions": + c_bool include_header + int32_t batch_size + CIOContext io_context + + CCSVWriteOptions() + CCSVWriteOptions(CCSVWriteOptions&&) + + @staticmethod + CCSVWriteOptions Defaults() + + CStatus Validate() + + cdef cppclass CCSVReader" arrow::csv::TableReader": + @staticmethod + CResult[shared_ptr[CCSVReader]] Make( + CIOContext, shared_ptr[CInputStream], + CCSVReadOptions, CCSVParseOptions, CCSVConvertOptions) + + CResult[shared_ptr[CTable]] Read() + + cdef cppclass CCSVStreamingReader" arrow::csv::StreamingReader"( + CRecordBatchReader): + @staticmethod + CResult[shared_ptr[CCSVStreamingReader]] Make( + CIOContext, shared_ptr[CInputStream], + CCSVReadOptions, CCSVParseOptions, CCSVConvertOptions) + + cdef CStatus WriteCSV(CTable&, CCSVWriteOptions& options, COutputStream*) + cdef CStatus WriteCSV( + CRecordBatch&, CCSVWriteOptions& options, COutputStream*) + cdef CResult[shared_ptr[CRecordBatchWriter]] MakeCSVWriter( + shared_ptr[COutputStream], shared_ptr[CSchema], + CCSVWriteOptions& options) + + +cdef extern from "arrow/json/options.h" nogil: + + ctypedef enum CUnexpectedFieldBehavior \ + "arrow::json::UnexpectedFieldBehavior": + CUnexpectedFieldBehavior_Ignore \ + "arrow::json::UnexpectedFieldBehavior::Ignore" + CUnexpectedFieldBehavior_Error \ + "arrow::json::UnexpectedFieldBehavior::Error" + CUnexpectedFieldBehavior_InferType \ + "arrow::json::UnexpectedFieldBehavior::InferType" + + cdef cppclass CJSONReadOptions" arrow::json::ReadOptions": + c_bool use_threads + int32_t block_size + + @staticmethod + CJSONReadOptions Defaults() + + cdef cppclass CJSONParseOptions" arrow::json::ParseOptions": + shared_ptr[CSchema] explicit_schema + c_bool newlines_in_values + CUnexpectedFieldBehavior unexpected_field_behavior + + @staticmethod + CJSONParseOptions Defaults() + + +cdef extern from "arrow/json/reader.h" namespace "arrow::json" nogil: + + cdef cppclass CJSONReader" arrow::json::TableReader": + @staticmethod + CResult[shared_ptr[CJSONReader]] Make( + CMemoryPool*, shared_ptr[CInputStream], + CJSONReadOptions, CJSONParseOptions) + + CResult[shared_ptr[CTable]] Read() + + +cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: + + cdef cppclass CExecContext" arrow::compute::ExecContext": + CExecContext() + CExecContext(CMemoryPool* pool) + + cdef cppclass CKernelSignature" arrow::compute::KernelSignature": + c_string ToString() const + + cdef cppclass CKernel" arrow::compute::Kernel": + shared_ptr[CKernelSignature] signature + + cdef cppclass CArrayKernel" arrow::compute::ArrayKernel"(CKernel): + pass + + cdef cppclass CScalarKernel" arrow::compute::ScalarKernel"(CArrayKernel): + pass + + cdef cppclass CVectorKernel" arrow::compute::VectorKernel"(CArrayKernel): + pass + + cdef cppclass CScalarAggregateKernel \ + " arrow::compute::ScalarAggregateKernel"(CKernel): + pass + + cdef cppclass CHashAggregateKernel \ + " arrow::compute::HashAggregateKernel"(CKernel): + pass + + cdef cppclass CArity" arrow::compute::Arity": + int num_args + c_bool is_varargs + + cdef enum FunctionKind" arrow::compute::Function::Kind": + FunctionKind_SCALAR" arrow::compute::Function::SCALAR" + FunctionKind_VECTOR" arrow::compute::Function::VECTOR" + FunctionKind_SCALAR_AGGREGATE \ + " arrow::compute::Function::SCALAR_AGGREGATE" + FunctionKind_HASH_AGGREGATE \ + " arrow::compute::Function::HASH_AGGREGATE" + FunctionKind_META \ + " arrow::compute::Function::META" + + cdef cppclass CFunctionDoc" arrow::compute::FunctionDoc": + c_string summary + c_string description + vector[c_string] arg_names + c_string options_class + + cdef cppclass CFunctionOptionsType" arrow::compute::FunctionOptionsType": + const char* type_name() const + + cdef cppclass CFunctionOptions" arrow::compute::FunctionOptions": + const CFunctionOptionsType* options_type() const + const char* type_name() const + c_bool Equals(const CFunctionOptions& other) const + c_string ToString() const + unique_ptr[CFunctionOptions] Copy() const + CResult[shared_ptr[CBuffer]] Serialize() const + + @staticmethod + CResult[unique_ptr[CFunctionOptions]] Deserialize( + const c_string& type_name, const CBuffer& buffer) + + cdef cppclass CFunction" arrow::compute::Function": + const c_string& name() const + FunctionKind kind() const + const CArity& arity() const + const CFunctionDoc& doc() const + int num_kernels() const + CResult[CDatum] Execute(const vector[CDatum]& args, + const CFunctionOptions* options, + CExecContext* ctx) const + + cdef cppclass CScalarFunction" arrow::compute::ScalarFunction"(CFunction): + vector[const CScalarKernel*] kernels() const + + cdef cppclass CVectorFunction" arrow::compute::VectorFunction"(CFunction): + vector[const CVectorKernel*] kernels() const + + cdef cppclass CScalarAggregateFunction \ + " arrow::compute::ScalarAggregateFunction"(CFunction): + vector[const CScalarAggregateKernel*] kernels() const + + cdef cppclass CHashAggregateFunction \ + " arrow::compute::HashAggregateFunction"(CFunction): + vector[const CHashAggregateKernel*] kernels() const + + cdef cppclass CMetaFunction" arrow::compute::MetaFunction"(CFunction): + pass + + cdef cppclass CFunctionRegistry" arrow::compute::FunctionRegistry": + CResult[shared_ptr[CFunction]] GetFunction( + const c_string& name) const + vector[c_string] GetFunctionNames() const + int num_functions() const + + CFunctionRegistry* GetFunctionRegistry() + + cdef cppclass CElementWiseAggregateOptions \ + "arrow::compute::ElementWiseAggregateOptions"(CFunctionOptions): + CElementWiseAggregateOptions(c_bool skip_nulls) + c_bool skip_nulls + + ctypedef enum CRoundMode \ + "arrow::compute::RoundMode": + CRoundMode_DOWN \ + "arrow::compute::RoundMode::DOWN" + CRoundMode_UP \ + "arrow::compute::RoundMode::UP" + CRoundMode_TOWARDS_ZERO \ + "arrow::compute::RoundMode::TOWARDS_ZERO" + CRoundMode_TOWARDS_INFINITY \ + "arrow::compute::RoundMode::TOWARDS_INFINITY" + CRoundMode_HALF_DOWN \ + "arrow::compute::RoundMode::HALF_DOWN" + CRoundMode_HALF_UP \ + "arrow::compute::RoundMode::HALF_UP" + CRoundMode_HALF_TOWARDS_ZERO \ + "arrow::compute::RoundMode::HALF_TOWARDS_ZERO" + CRoundMode_HALF_TOWARDS_INFINITY \ + "arrow::compute::RoundMode::HALF_TOWARDS_INFINITY" + CRoundMode_HALF_TO_EVEN \ + "arrow::compute::RoundMode::HALF_TO_EVEN" + CRoundMode_HALF_TO_ODD \ + "arrow::compute::RoundMode::HALF_TO_ODD" + + cdef cppclass CRoundOptions \ + "arrow::compute::RoundOptions"(CFunctionOptions): + CRoundOptions(int64_t ndigits, CRoundMode round_mode) + int64_t ndigits + CRoundMode round_mode + + cdef cppclass CRoundToMultipleOptions \ + "arrow::compute::RoundToMultipleOptions"(CFunctionOptions): + CRoundToMultipleOptions(double multiple, CRoundMode round_mode) + double multiple + CRoundMode round_mode + + cdef enum CJoinNullHandlingBehavior \ + "arrow::compute::JoinOptions::NullHandlingBehavior": + CJoinNullHandlingBehavior_EMIT_NULL \ + "arrow::compute::JoinOptions::EMIT_NULL" + CJoinNullHandlingBehavior_SKIP \ + "arrow::compute::JoinOptions::SKIP" + CJoinNullHandlingBehavior_REPLACE \ + "arrow::compute::JoinOptions::REPLACE" + + cdef cppclass CJoinOptions \ + "arrow::compute::JoinOptions"(CFunctionOptions): + CJoinOptions(CJoinNullHandlingBehavior null_handling, + c_string null_replacement) + CJoinNullHandlingBehavior null_handling + c_string null_replacement + + cdef cppclass CMatchSubstringOptions \ + "arrow::compute::MatchSubstringOptions"(CFunctionOptions): + CMatchSubstringOptions(c_string pattern, c_bool ignore_case) + c_string pattern + c_bool ignore_case + + cdef cppclass CTrimOptions \ + "arrow::compute::TrimOptions"(CFunctionOptions): + CTrimOptions(c_string characters) + c_string characters + + cdef cppclass CPadOptions \ + "arrow::compute::PadOptions"(CFunctionOptions): + CPadOptions(int64_t width, c_string padding) + int64_t width + c_string padding + + cdef cppclass CSliceOptions \ + "arrow::compute::SliceOptions"(CFunctionOptions): + CSliceOptions(int64_t start, int64_t stop, int64_t step) + int64_t start + int64_t stop + int64_t step + + cdef cppclass CSplitOptions \ + "arrow::compute::SplitOptions"(CFunctionOptions): + CSplitOptions(int64_t max_splits, c_bool reverse) + int64_t max_splits + c_bool reverse + + cdef cppclass CSplitPatternOptions \ + "arrow::compute::SplitPatternOptions"(CFunctionOptions): + CSplitPatternOptions(c_string pattern, int64_t max_splits, + c_bool reverse) + int64_t max_splits + c_bool reverse + c_string pattern + + cdef cppclass CReplaceSliceOptions \ + "arrow::compute::ReplaceSliceOptions"(CFunctionOptions): + CReplaceSliceOptions(int64_t start, int64_t stop, c_string replacement) + int64_t start + int64_t stop + c_string replacement + + cdef cppclass CReplaceSubstringOptions \ + "arrow::compute::ReplaceSubstringOptions"(CFunctionOptions): + CReplaceSubstringOptions(c_string pattern, c_string replacement, + int64_t max_replacements) + c_string pattern + c_string replacement + int64_t max_replacements + + cdef cppclass CExtractRegexOptions \ + "arrow::compute::ExtractRegexOptions"(CFunctionOptions): + CExtractRegexOptions(c_string pattern) + c_string pattern + + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): + CCastOptions() + CCastOptions(c_bool safe) + CCastOptions(CCastOptions&& options) + + @staticmethod + CCastOptions Safe() + + @staticmethod + CCastOptions Unsafe() + shared_ptr[CDataType] to_type + c_bool allow_int_overflow + c_bool allow_time_truncate + c_bool allow_time_overflow + c_bool allow_decimal_truncate + c_bool allow_float_truncate + c_bool allow_invalid_utf8 + + cdef enum CFilterNullSelectionBehavior \ + "arrow::compute::FilterOptions::NullSelectionBehavior": + CFilterNullSelectionBehavior_DROP \ + "arrow::compute::FilterOptions::DROP" + CFilterNullSelectionBehavior_EMIT_NULL \ + "arrow::compute::FilterOptions::EMIT_NULL" + + cdef cppclass CFilterOptions \ + " arrow::compute::FilterOptions"(CFunctionOptions): + CFilterOptions() + CFilterOptions(CFilterNullSelectionBehavior null_selection_behavior) + CFilterNullSelectionBehavior null_selection_behavior + + cdef enum CDictionaryEncodeNullEncodingBehavior \ + "arrow::compute::DictionaryEncodeOptions::NullEncodingBehavior": + CDictionaryEncodeNullEncodingBehavior_ENCODE \ + "arrow::compute::DictionaryEncodeOptions::ENCODE" + CDictionaryEncodeNullEncodingBehavior_MASK \ + "arrow::compute::DictionaryEncodeOptions::MASK" + + cdef cppclass CDictionaryEncodeOptions \ + "arrow::compute::DictionaryEncodeOptions"(CFunctionOptions): + CDictionaryEncodeOptions( + CDictionaryEncodeNullEncodingBehavior null_encoding) + CDictionaryEncodeNullEncodingBehavior null_encoding + + cdef cppclass CTakeOptions \ + " arrow::compute::TakeOptions"(CFunctionOptions): + CTakeOptions(c_bool boundscheck) + c_bool boundscheck + + cdef cppclass CStrptimeOptions \ + "arrow::compute::StrptimeOptions"(CFunctionOptions): + CStrptimeOptions(c_string format, TimeUnit unit) + c_string format + TimeUnit unit + + cdef cppclass CStrftimeOptions \ + "arrow::compute::StrftimeOptions"(CFunctionOptions): + CStrftimeOptions(c_string format, c_string locale) + c_string format + c_string locale + + cdef cppclass CDayOfWeekOptions \ + "arrow::compute::DayOfWeekOptions"(CFunctionOptions): + CDayOfWeekOptions(c_bool count_from_zero, uint32_t week_start) + c_bool count_from_zero + uint32_t week_start + + cdef enum CAssumeTimezoneAmbiguous \ + "arrow::compute::AssumeTimezoneOptions::Ambiguous": + CAssumeTimezoneAmbiguous_AMBIGUOUS_RAISE \ + "arrow::compute::AssumeTimezoneOptions::AMBIGUOUS_RAISE" + CAssumeTimezoneAmbiguous_AMBIGUOUS_EARLIEST \ + "arrow::compute::AssumeTimezoneOptions::AMBIGUOUS_EARLIEST" + CAssumeTimezoneAmbiguous_AMBIGUOUS_LATEST \ + "arrow::compute::AssumeTimezoneOptions::AMBIGUOUS_LATEST" + + cdef enum CAssumeTimezoneNonexistent \ + "arrow::compute::AssumeTimezoneOptions::Nonexistent": + CAssumeTimezoneNonexistent_NONEXISTENT_RAISE \ + "arrow::compute::AssumeTimezoneOptions::NONEXISTENT_RAISE" + CAssumeTimezoneNonexistent_NONEXISTENT_EARLIEST \ + "arrow::compute::AssumeTimezoneOptions::NONEXISTENT_EARLIEST" + CAssumeTimezoneNonexistent_NONEXISTENT_LATEST \ + "arrow::compute::AssumeTimezoneOptions::NONEXISTENT_LATEST" + + cdef cppclass CAssumeTimezoneOptions \ + "arrow::compute::AssumeTimezoneOptions"(CFunctionOptions): + CAssumeTimezoneOptions(c_string timezone, + CAssumeTimezoneAmbiguous ambiguous, + CAssumeTimezoneNonexistent nonexistent) + c_string timezone + CAssumeTimezoneAmbiguous ambiguous + CAssumeTimezoneNonexistent nonexistent + + cdef cppclass CWeekOptions \ + "arrow::compute::WeekOptions"(CFunctionOptions): + CWeekOptions(c_bool week_starts_monday, c_bool count_from_zero, + c_bool first_week_is_fully_in_year) + c_bool week_starts_monday + c_bool count_from_zero + c_bool first_week_is_fully_in_year + + cdef cppclass CNullOptions \ + "arrow::compute::NullOptions"(CFunctionOptions): + CNullOptions(c_bool nan_is_null) + c_bool nan_is_null + + cdef cppclass CVarianceOptions \ + "arrow::compute::VarianceOptions"(CFunctionOptions): + CVarianceOptions(int ddof, c_bool skip_nulls, uint32_t min_count) + int ddof + c_bool skip_nulls + uint32_t min_count + + cdef cppclass CScalarAggregateOptions \ + "arrow::compute::ScalarAggregateOptions"(CFunctionOptions): + CScalarAggregateOptions(c_bool skip_nulls, uint32_t min_count) + c_bool skip_nulls + uint32_t min_count + + cdef enum CCountMode "arrow::compute::CountOptions::CountMode": + CCountMode_ONLY_VALID "arrow::compute::CountOptions::ONLY_VALID" + CCountMode_ONLY_NULL "arrow::compute::CountOptions::ONLY_NULL" + CCountMode_ALL "arrow::compute::CountOptions::ALL" + + cdef cppclass CCountOptions \ + "arrow::compute::CountOptions"(CFunctionOptions): + CCountOptions(CCountMode mode) + CCountMode mode + + cdef cppclass CModeOptions \ + "arrow::compute::ModeOptions"(CFunctionOptions): + CModeOptions(int64_t n, c_bool skip_nulls, uint32_t min_count) + int64_t n + c_bool skip_nulls + uint32_t min_count + + cdef cppclass CIndexOptions \ + "arrow::compute::IndexOptions"(CFunctionOptions): + CIndexOptions(shared_ptr[CScalar] value) + shared_ptr[CScalar] value + + cdef cppclass CMakeStructOptions \ + "arrow::compute::MakeStructOptions"(CFunctionOptions): + CMakeStructOptions(vector[c_string] n, + vector[c_bool] r, + vector[shared_ptr[const CKeyValueMetadata]] m) + CMakeStructOptions(vector[c_string] n) + vector[c_string] field_names + vector[c_bool] field_nullability + vector[shared_ptr[const CKeyValueMetadata]] field_metadata + + ctypedef enum CSortOrder" arrow::compute::SortOrder": + CSortOrder_Ascending \ + "arrow::compute::SortOrder::Ascending" + CSortOrder_Descending \ + "arrow::compute::SortOrder::Descending" + + ctypedef enum CNullPlacement" arrow::compute::NullPlacement": + CNullPlacement_AtStart \ + "arrow::compute::NullPlacement::AtStart" + CNullPlacement_AtEnd \ + "arrow::compute::NullPlacement::AtEnd" + + cdef cppclass CPartitionNthOptions \ + "arrow::compute::PartitionNthOptions"(CFunctionOptions): + CPartitionNthOptions(int64_t pivot, CNullPlacement) + int64_t pivot + CNullPlacement null_placement + + cdef cppclass CArraySortOptions \ + "arrow::compute::ArraySortOptions"(CFunctionOptions): + CArraySortOptions(CSortOrder, CNullPlacement) + CSortOrder order + CNullPlacement null_placement + + cdef cppclass CSortKey" arrow::compute::SortKey": + CSortKey(c_string name, CSortOrder order) + c_string name + CSortOrder order + + cdef cppclass CSortOptions \ + "arrow::compute::SortOptions"(CFunctionOptions): + CSortOptions(vector[CSortKey] sort_keys, CNullPlacement) + vector[CSortKey] sort_keys + CNullPlacement null_placement + + cdef cppclass CSelectKOptions \ + "arrow::compute::SelectKOptions"(CFunctionOptions): + CSelectKOptions(int64_t k, vector[CSortKey] sort_keys) + int64_t k + vector[CSortKey] sort_keys + + cdef enum CQuantileInterp \ + "arrow::compute::QuantileOptions::Interpolation": + CQuantileInterp_LINEAR "arrow::compute::QuantileOptions::LINEAR" + CQuantileInterp_LOWER "arrow::compute::QuantileOptions::LOWER" + CQuantileInterp_HIGHER "arrow::compute::QuantileOptions::HIGHER" + CQuantileInterp_NEAREST "arrow::compute::QuantileOptions::NEAREST" + CQuantileInterp_MIDPOINT "arrow::compute::QuantileOptions::MIDPOINT" + + cdef cppclass CQuantileOptions \ + "arrow::compute::QuantileOptions"(CFunctionOptions): + CQuantileOptions(vector[double] q, CQuantileInterp interpolation, + c_bool skip_nulls, uint32_t min_count) + vector[double] q + CQuantileInterp interpolation + c_bool skip_nulls + uint32_t min_count + + cdef cppclass CTDigestOptions \ + "arrow::compute::TDigestOptions"(CFunctionOptions): + CTDigestOptions(vector[double] q, + uint32_t delta, uint32_t buffer_size, + c_bool skip_nulls, uint32_t min_count) + vector[double] q + uint32_t delta + uint32_t buffer_size + c_bool skip_nulls + uint32_t min_count + + cdef enum DatumType" arrow::Datum::type": + DatumType_NONE" arrow::Datum::NONE" + DatumType_SCALAR" arrow::Datum::SCALAR" + DatumType_ARRAY" arrow::Datum::ARRAY" + DatumType_CHUNKED_ARRAY" arrow::Datum::CHUNKED_ARRAY" + DatumType_RECORD_BATCH" arrow::Datum::RECORD_BATCH" + DatumType_TABLE" arrow::Datum::TABLE" + DatumType_COLLECTION" arrow::Datum::COLLECTION" + + cdef cppclass CDatum" arrow::Datum": + CDatum() + CDatum(const shared_ptr[CArray]& value) + CDatum(const shared_ptr[CChunkedArray]& value) + CDatum(const shared_ptr[CScalar]& value) + CDatum(const shared_ptr[CRecordBatch]& value) + CDatum(const shared_ptr[CTable]& value) + + DatumType kind() const + c_string ToString() const + + const shared_ptr[CArrayData]& array() const + const shared_ptr[CChunkedArray]& chunked_array() const + const shared_ptr[CRecordBatch]& record_batch() const + const shared_ptr[CTable]& table() const + const shared_ptr[CScalar]& scalar() const + + cdef cppclass CSetLookupOptions \ + "arrow::compute::SetLookupOptions"(CFunctionOptions): + CSetLookupOptions(CDatum value_set, c_bool skip_nulls) + CDatum value_set + c_bool skip_nulls + + +cdef extern from * namespace "arrow::compute": + # inlined from compute/function_internal.h to avoid exposing + # implementation details + """ + #include "arrow/compute/function.h" + namespace arrow { + namespace compute { + namespace internal { + Result<std::unique_ptr<FunctionOptions>> DeserializeFunctionOptions( + const Buffer& buffer); + } // namespace internal + } // namespace compute + } // namespace arrow + """ + CResult[unique_ptr[CFunctionOptions]] DeserializeFunctionOptions \ + " arrow::compute::internal::DeserializeFunctionOptions"( + const CBuffer& buffer) + + +cdef extern from "arrow/python/api.h" namespace "arrow::py": + # Requires GIL + CResult[shared_ptr[CDataType]] InferArrowType( + object obj, object mask, c_bool pandas_null_sentinels) + + +cdef extern from "arrow/python/api.h" namespace "arrow::py::internal": + object NewMonthDayNanoTupleType() + CResult[PyObject*] MonthDayNanoIntervalArrayToPyList( + const CMonthDayNanoIntervalArray& array) + CResult[PyObject*] MonthDayNanoIntervalScalarToPyObject( + const CMonthDayNanoIntervalScalar& scalar) + + +cdef extern from "arrow/python/api.h" namespace "arrow::py" nogil: + shared_ptr[CDataType] GetPrimitiveType(Type type) + + object PyHalf_FromHalf(npy_half value) + + cdef cppclass PyConversionOptions: + PyConversionOptions() + + shared_ptr[CDataType] type + int64_t size + CMemoryPool* pool + c_bool from_pandas + c_bool ignore_timezone + c_bool strict + + # TODO Some functions below are not actually "nogil" + + CResult[shared_ptr[CChunkedArray]] ConvertPySequence( + object obj, object mask, const PyConversionOptions& options, + CMemoryPool* pool) + + CStatus NumPyDtypeToArrow(object dtype, shared_ptr[CDataType]* type) + + CStatus NdarrayToArrow(CMemoryPool* pool, object ao, object mo, + c_bool from_pandas, + const shared_ptr[CDataType]& type, + shared_ptr[CChunkedArray]* out) + + CStatus NdarrayToArrow(CMemoryPool* pool, object ao, object mo, + c_bool from_pandas, + const shared_ptr[CDataType]& type, + const CCastOptions& cast_options, + shared_ptr[CChunkedArray]* out) + + CStatus NdarrayToTensor(CMemoryPool* pool, object ao, + const vector[c_string]& dim_names, + shared_ptr[CTensor]* out) + + CStatus TensorToNdarray(const shared_ptr[CTensor]& tensor, object base, + PyObject** out) + + CStatus SparseCOOTensorToNdarray( + const shared_ptr[CSparseCOOTensor]& sparse_tensor, object base, + PyObject** out_data, PyObject** out_coords) + + CStatus SparseCSRMatrixToNdarray( + const shared_ptr[CSparseCSRMatrix]& sparse_tensor, object base, + PyObject** out_data, PyObject** out_indptr, PyObject** out_indices) + + CStatus SparseCSCMatrixToNdarray( + const shared_ptr[CSparseCSCMatrix]& sparse_tensor, object base, + PyObject** out_data, PyObject** out_indptr, PyObject** out_indices) + + CStatus SparseCSFTensorToNdarray( + const shared_ptr[CSparseCSFTensor]& sparse_tensor, object base, + PyObject** out_data, PyObject** out_indptr, PyObject** out_indices) + + CStatus NdarraysToSparseCOOTensor(CMemoryPool* pool, object data_ao, + object coords_ao, + const vector[int64_t]& shape, + const vector[c_string]& dim_names, + shared_ptr[CSparseCOOTensor]* out) + + CStatus NdarraysToSparseCSRMatrix(CMemoryPool* pool, object data_ao, + object indptr_ao, object indices_ao, + const vector[int64_t]& shape, + const vector[c_string]& dim_names, + shared_ptr[CSparseCSRMatrix]* out) + + CStatus NdarraysToSparseCSCMatrix(CMemoryPool* pool, object data_ao, + object indptr_ao, object indices_ao, + const vector[int64_t]& shape, + const vector[c_string]& dim_names, + shared_ptr[CSparseCSCMatrix]* out) + + CStatus NdarraysToSparseCSFTensor(CMemoryPool* pool, object data_ao, + object indptr_ao, object indices_ao, + const vector[int64_t]& shape, + const vector[int64_t]& axis_order, + const vector[c_string]& dim_names, + shared_ptr[CSparseCSFTensor]* out) + + CStatus TensorToSparseCOOTensor(shared_ptr[CTensor], + shared_ptr[CSparseCOOTensor]* out) + + CStatus TensorToSparseCSRMatrix(shared_ptr[CTensor], + shared_ptr[CSparseCSRMatrix]* out) + + CStatus TensorToSparseCSCMatrix(shared_ptr[CTensor], + shared_ptr[CSparseCSCMatrix]* out) + + CStatus TensorToSparseCSFTensor(shared_ptr[CTensor], + shared_ptr[CSparseCSFTensor]* out) + + CStatus ConvertArrayToPandas(const PandasOptions& options, + shared_ptr[CArray] arr, + object py_ref, PyObject** out) + + CStatus ConvertChunkedArrayToPandas(const PandasOptions& options, + shared_ptr[CChunkedArray] arr, + object py_ref, PyObject** out) + + CStatus ConvertTableToPandas(const PandasOptions& options, + shared_ptr[CTable] table, + PyObject** out) + + void c_set_default_memory_pool \ + " arrow::py::set_default_memory_pool"(CMemoryPool* pool)\ + + CMemoryPool* c_get_memory_pool \ + " arrow::py::get_memory_pool"() + + cdef cppclass PyBuffer(CBuffer): + @staticmethod + CResult[shared_ptr[CBuffer]] FromPyObject(object obj) + + cdef cppclass PyForeignBuffer(CBuffer): + @staticmethod + CStatus Make(const uint8_t* data, int64_t size, object base, + shared_ptr[CBuffer]* out) + + cdef cppclass PyReadableFile(CRandomAccessFile): + PyReadableFile(object fo) + + cdef cppclass PyOutputStream(COutputStream): + PyOutputStream(object fo) + + cdef cppclass PandasOptions: + CMemoryPool* pool + c_bool strings_to_categorical + c_bool zero_copy_only + c_bool integer_object_nulls + c_bool date_as_object + c_bool timestamp_as_object + c_bool use_threads + c_bool coerce_temporal_nanoseconds + c_bool ignore_timezone + c_bool deduplicate_objects + c_bool safe_cast + c_bool split_blocks + c_bool self_destruct + c_bool decode_dictionaries + unordered_set[c_string] categorical_columns + unordered_set[c_string] extension_columns + + cdef cppclass CSerializedPyObject" arrow::py::SerializedPyObject": + shared_ptr[CRecordBatch] batch + vector[shared_ptr[CTensor]] tensors + + CStatus WriteTo(COutputStream* dst) + CStatus GetComponents(CMemoryPool* pool, PyObject** dst) + + CStatus SerializeObject(object context, object sequence, + CSerializedPyObject* out) + + CStatus DeserializeObject(object context, + const CSerializedPyObject& obj, + PyObject* base, PyObject** out) + + CStatus ReadSerializedObject(CRandomAccessFile* src, + CSerializedPyObject* out) + + cdef cppclass SparseTensorCounts: + SparseTensorCounts() + int coo + int csr + int csc + int csf + int ndim_csf + int num_total_tensors() const + int num_total_buffers() const + + CStatus GetSerializedFromComponents( + int num_tensors, + const SparseTensorCounts& num_sparse_tensors, + int num_ndarrays, + int num_buffers, + object buffers, + CSerializedPyObject* out) + + +cdef extern from "arrow/python/api.h" namespace "arrow::py::internal" nogil: + cdef cppclass CTimePoint "arrow::py::internal::TimePoint": + pass + + CTimePoint PyDateTime_to_TimePoint(PyDateTime_DateTime* pydatetime) + int64_t TimePoint_to_ns(CTimePoint val) + CTimePoint TimePoint_from_s(double val) + CTimePoint TimePoint_from_ns(int64_t val) + + CResult[c_string] TzinfoToString(PyObject* pytzinfo) + CResult[PyObject*] StringToTzinfo(c_string) + + +cdef extern from "arrow/python/init.h": + int arrow_init_numpy() except -1 + + +cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py": + int import_pyarrow() except -1 + + +cdef extern from "arrow/python/common.h" namespace "arrow::py": + c_bool IsPyError(const CStatus& status) + void RestorePyError(const CStatus& status) + + +cdef extern from "arrow/python/inference.h" namespace "arrow::py": + c_bool IsPyBool(object o) + c_bool IsPyInt(object o) + c_bool IsPyFloat(object o) + + +cdef extern from "arrow/python/ipc.h" namespace "arrow::py": + cdef cppclass CPyRecordBatchReader" arrow::py::PyRecordBatchReader" \ + (CRecordBatchReader): + @staticmethod + CResult[shared_ptr[CRecordBatchReader]] Make(shared_ptr[CSchema], + object) + + +cdef extern from "arrow/extension_type.h" namespace "arrow": + cdef cppclass CExtensionTypeRegistry" arrow::ExtensionTypeRegistry": + @staticmethod + shared_ptr[CExtensionTypeRegistry] GetGlobalRegistry() + + cdef cppclass CExtensionType" arrow::ExtensionType"(CDataType): + c_string extension_name() + shared_ptr[CDataType] storage_type() + + @staticmethod + shared_ptr[CArray] WrapArray(shared_ptr[CDataType] ext_type, + shared_ptr[CArray] storage) + + @staticmethod + shared_ptr[CChunkedArray] WrapArray(shared_ptr[CDataType] ext_type, + shared_ptr[CChunkedArray] storage) + + cdef cppclass CExtensionArray" arrow::ExtensionArray"(CArray): + CExtensionArray(shared_ptr[CDataType], shared_ptr[CArray] storage) + + shared_ptr[CArray] storage() + + +cdef extern from "arrow/python/extension_type.h" namespace "arrow::py": + cdef cppclass CPyExtensionType \ + " arrow::py::PyExtensionType"(CExtensionType): + @staticmethod + CStatus FromClass(const shared_ptr[CDataType] storage_type, + const c_string extension_name, object typ, + shared_ptr[CExtensionType]* out) + + @staticmethod + CStatus FromInstance(shared_ptr[CDataType] storage_type, + object inst, shared_ptr[CExtensionType]* out) + + object GetInstance() + CStatus SetInstance(object) + + c_string PyExtensionName() + CStatus RegisterPyExtensionType(shared_ptr[CDataType]) + CStatus UnregisterPyExtensionType(c_string type_name) + + +cdef extern from "arrow/python/benchmark.h" namespace "arrow::py::benchmark": + void Benchmark_PandasObjectIsNull(object lst) except * + + +cdef extern from "arrow/util/compression.h" namespace "arrow" nogil: + cdef enum CCompressionType" arrow::Compression::type": + CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED" + CCompressionType_SNAPPY" arrow::Compression::SNAPPY" + CCompressionType_GZIP" arrow::Compression::GZIP" + CCompressionType_BROTLI" arrow::Compression::BROTLI" + CCompressionType_ZSTD" arrow::Compression::ZSTD" + CCompressionType_LZ4" arrow::Compression::LZ4" + CCompressionType_LZ4_FRAME" arrow::Compression::LZ4_FRAME" + CCompressionType_BZ2" arrow::Compression::BZ2" + + cdef cppclass CCodec" arrow::util::Codec": + @staticmethod + CResult[unique_ptr[CCodec]] Create(CCompressionType codec) + + @staticmethod + CResult[unique_ptr[CCodec]] CreateWithLevel" Create"( + CCompressionType codec, + int compression_level) + + @staticmethod + c_bool SupportsCompressionLevel(CCompressionType codec) + + @staticmethod + CResult[int] MinimumCompressionLevel(CCompressionType codec) + + @staticmethod + CResult[int] MaximumCompressionLevel(CCompressionType codec) + + @staticmethod + CResult[int] DefaultCompressionLevel(CCompressionType codec) + + @staticmethod + c_bool IsAvailable(CCompressionType codec) + + CResult[int64_t] Decompress(int64_t input_len, const uint8_t* input, + int64_t output_len, + uint8_t* output_buffer) + CResult[int64_t] Compress(int64_t input_len, const uint8_t* input, + int64_t output_buffer_len, + uint8_t* output_buffer) + c_string name() const + int compression_level() const + int64_t MaxCompressedLen(int64_t input_len, const uint8_t* input) + + +cdef extern from "arrow/util/io_util.h" namespace "arrow::internal" nogil: + int ErrnoFromStatus(CStatus status) + int WinErrorFromStatus(CStatus status) + int SignalFromStatus(CStatus status) + + CStatus SendSignal(int signum) + CStatus SendSignalToThread(int signum, uint64_t thread_id) + + +cdef extern from "arrow/util/iterator.h" namespace "arrow" nogil: + cdef cppclass CIterator" arrow::Iterator"[T]: + CResult[T] Next() + CStatus Visit[Visitor](Visitor&& visitor) + cppclass RangeIterator: + CResult[T] operator*() + RangeIterator& operator++() + c_bool operator!=(RangeIterator) const + RangeIterator begin() + RangeIterator end() + CIterator[T] MakeVectorIterator[T](vector[T] v) + +cdef extern from "arrow/util/thread_pool.h" namespace "arrow" nogil: + int GetCpuThreadPoolCapacity() + CStatus SetCpuThreadPoolCapacity(int threads) + +cdef extern from "arrow/array/concatenate.h" namespace "arrow" nogil: + CResult[shared_ptr[CArray]] Concatenate( + const vector[shared_ptr[CArray]]& arrays, + CMemoryPool* pool) + +cdef extern from "arrow/c/abi.h": + cdef struct ArrowSchema: + pass + + cdef struct ArrowArray: + pass + + cdef struct ArrowArrayStream: + pass + +cdef extern from "arrow/c/bridge.h" namespace "arrow" nogil: + CStatus ExportType(CDataType&, ArrowSchema* out) + CResult[shared_ptr[CDataType]] ImportType(ArrowSchema*) + + CStatus ExportField(CField&, ArrowSchema* out) + CResult[shared_ptr[CField]] ImportField(ArrowSchema*) + + CStatus ExportSchema(CSchema&, ArrowSchema* out) + CResult[shared_ptr[CSchema]] ImportSchema(ArrowSchema*) + + CStatus ExportArray(CArray&, ArrowArray* out) + CStatus ExportArray(CArray&, ArrowArray* out, ArrowSchema* out_schema) + CResult[shared_ptr[CArray]] ImportArray(ArrowArray*, + shared_ptr[CDataType]) + CResult[shared_ptr[CArray]] ImportArray(ArrowArray*, ArrowSchema*) + + CStatus ExportRecordBatch(CRecordBatch&, ArrowArray* out) + CStatus ExportRecordBatch(CRecordBatch&, ArrowArray* out, + ArrowSchema* out_schema) + CResult[shared_ptr[CRecordBatch]] ImportRecordBatch(ArrowArray*, + shared_ptr[CSchema]) + CResult[shared_ptr[CRecordBatch]] ImportRecordBatch(ArrowArray*, + ArrowSchema*) + + CStatus ExportRecordBatchReader(shared_ptr[CRecordBatchReader], + ArrowArrayStream*) + CResult[shared_ptr[CRecordBatchReader]] ImportRecordBatchReader( + ArrowArrayStream*) diff --git a/src/arrow/python/pyarrow/includes/libarrow_cuda.pxd b/src/arrow/python/pyarrow/includes/libarrow_cuda.pxd new file mode 100644 index 000000000..3ac943cf9 --- /dev/null +++ b/src/arrow/python/pyarrow/includes/libarrow_cuda.pxd @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from pyarrow.includes.libarrow cimport * + +cdef extern from "arrow/gpu/cuda_api.h" namespace "arrow::cuda" nogil: + + cdef cppclass CCudaDeviceManager" arrow::cuda::CudaDeviceManager": + @staticmethod + CResult[CCudaDeviceManager*] Instance() + CResult[shared_ptr[CCudaContext]] GetContext(int gpu_number) + CResult[shared_ptr[CCudaContext]] GetSharedContext(int gpu_number, + void* handle) + CStatus AllocateHost(int device_number, int64_t nbytes, + shared_ptr[CCudaHostBuffer]* buffer) + int num_devices() const + + cdef cppclass CCudaContext" arrow::cuda::CudaContext": + CResult[shared_ptr[CCudaBuffer]] Allocate(int64_t nbytes) + CResult[shared_ptr[CCudaBuffer]] View(uint8_t* data, int64_t nbytes) + CResult[shared_ptr[CCudaBuffer]] OpenIpcBuffer( + const CCudaIpcMemHandle& ipc_handle) + CStatus Synchronize() + int64_t bytes_allocated() const + const void* handle() const + int device_number() const + CResult[uintptr_t] GetDeviceAddress(uintptr_t addr) + + cdef cppclass CCudaIpcMemHandle" arrow::cuda::CudaIpcMemHandle": + @staticmethod + CResult[shared_ptr[CCudaIpcMemHandle]] FromBuffer( + const void* opaque_handle) + CResult[shared_ptr[CBuffer]] Serialize(CMemoryPool* pool) const + + cdef cppclass CCudaBuffer" arrow::cuda::CudaBuffer"(CBuffer): + CCudaBuffer(uint8_t* data, int64_t size, + const shared_ptr[CCudaContext]& context, + c_bool own_data=false, c_bool is_ipc=false) + CCudaBuffer(const shared_ptr[CCudaBuffer]& parent, + const int64_t offset, const int64_t size) + + @staticmethod + CResult[shared_ptr[CCudaBuffer]] FromBuffer(shared_ptr[CBuffer] buf) + + CStatus CopyToHost(const int64_t position, const int64_t nbytes, + void* out) const + CStatus CopyFromHost(const int64_t position, const void* data, + int64_t nbytes) + CStatus CopyFromDevice(const int64_t position, const void* data, + int64_t nbytes) + CStatus CopyFromAnotherDevice(const shared_ptr[CCudaContext]& src_ctx, + const int64_t position, const void* data, + int64_t nbytes) + CResult[shared_ptr[CCudaIpcMemHandle]] ExportForIpc() + shared_ptr[CCudaContext] context() const + + cdef cppclass \ + CCudaHostBuffer" arrow::cuda::CudaHostBuffer"(CMutableBuffer): + pass + + cdef cppclass \ + CCudaBufferReader" arrow::cuda::CudaBufferReader"(CBufferReader): + CCudaBufferReader(const shared_ptr[CBuffer]& buffer) + CResult[int64_t] Read(int64_t nbytes, void* buffer) + CResult[shared_ptr[CBuffer]] Read(int64_t nbytes) + + cdef cppclass \ + CCudaBufferWriter" arrow::cuda::CudaBufferWriter"(WritableFile): + CCudaBufferWriter(const shared_ptr[CCudaBuffer]& buffer) + CStatus Close() + CStatus Write(const void* data, int64_t nbytes) + CStatus WriteAt(int64_t position, const void* data, int64_t nbytes) + CStatus SetBufferSize(const int64_t buffer_size) + int64_t buffer_size() + int64_t num_bytes_buffered() const + + CResult[shared_ptr[CCudaHostBuffer]] AllocateCudaHostBuffer( + int device_number, const int64_t size) + + # Cuda prefix is added to avoid picking up arrow::cuda functions + # from arrow namespace. + CResult[shared_ptr[CCudaBuffer]] \ + CudaSerializeRecordBatch" arrow::cuda::SerializeRecordBatch"\ + (const CRecordBatch& batch, + CCudaContext* ctx) + CResult[shared_ptr[CRecordBatch]] \ + CudaReadRecordBatch" arrow::cuda::ReadRecordBatch"\ + (const shared_ptr[CSchema]& schema, + CDictionaryMemo* dictionary_memo, + const shared_ptr[CCudaBuffer]& buffer, + CMemoryPool* pool) diff --git a/src/arrow/python/pyarrow/includes/libarrow_dataset.pxd b/src/arrow/python/pyarrow/includes/libarrow_dataset.pxd new file mode 100644 index 000000000..abc79fea8 --- /dev/null +++ b/src/arrow/python/pyarrow/includes/libarrow_dataset.pxd @@ -0,0 +1,478 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from libcpp.unordered_map cimport unordered_map + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_fs cimport * +from pyarrow._parquet cimport * + + +cdef extern from "arrow/api.h" namespace "arrow" nogil: + + cdef cppclass CRecordBatchIterator "arrow::RecordBatchIterator"( + CIterator[shared_ptr[CRecordBatch]]): + pass + + +cdef extern from * namespace "arrow::compute": + # inlined from expression_internal.h to avoid + # proliferation of #include <unordered_map> + """ + #include <unordered_map> + + #include "arrow/type.h" + #include "arrow/datum.h" + + namespace arrow { + namespace compute { + struct KnownFieldValues { + std::unordered_map<FieldRef, Datum, FieldRef::Hash> map; + }; + } // namespace compute + } // namespace arrow + """ + cdef struct CKnownFieldValues "arrow::compute::KnownFieldValues": + unordered_map[CFieldRef, CDatum, CFieldRefHash] map + +cdef extern from "arrow/compute/exec/expression.h" \ + namespace "arrow::compute" nogil: + + cdef cppclass CExpression "arrow::compute::Expression": + c_bool Equals(const CExpression& other) const + c_string ToString() const + CResult[CExpression] Bind(const CSchema&) + + cdef CExpression CMakeScalarExpression \ + "arrow::compute::literal"(shared_ptr[CScalar] value) + + cdef CExpression CMakeFieldExpression \ + "arrow::compute::field_ref"(c_string name) + + cdef CExpression CMakeCallExpression \ + "arrow::compute::call"(c_string function, + vector[CExpression] arguments, + shared_ptr[CFunctionOptions] options) + + cdef CResult[shared_ptr[CBuffer]] CSerializeExpression \ + "arrow::compute::Serialize"(const CExpression&) + + cdef CResult[CExpression] CDeserializeExpression \ + "arrow::compute::Deserialize"(shared_ptr[CBuffer]) + + cdef CResult[CKnownFieldValues] \ + CExtractKnownFieldValues "arrow::compute::ExtractKnownFieldValues"( + const CExpression& partition_expression) + +ctypedef CStatus cb_writer_finish_internal(CFileWriter*) +ctypedef void cb_writer_finish(dict, CFileWriter*) + +cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: + + cdef enum ExistingDataBehavior" arrow::dataset::ExistingDataBehavior": + ExistingDataBehavior_DELETE_MATCHING" \ + arrow::dataset::ExistingDataBehavior::kDeleteMatchingPartitions" + ExistingDataBehavior_OVERWRITE_OR_IGNORE" \ + arrow::dataset::ExistingDataBehavior::kOverwriteOrIgnore" + ExistingDataBehavior_ERROR" \ + arrow::dataset::ExistingDataBehavior::kError" + + cdef cppclass CScanOptions "arrow::dataset::ScanOptions": + @staticmethod + shared_ptr[CScanOptions] Make(shared_ptr[CSchema] schema) + + shared_ptr[CSchema] dataset_schema + shared_ptr[CSchema] projected_schema + + cdef cppclass CFragmentScanOptions "arrow::dataset::FragmentScanOptions": + c_string type_name() const + + ctypedef CIterator[shared_ptr[CScanTask]] CScanTaskIterator \ + "arrow::dataset::ScanTaskIterator" + + cdef cppclass CScanTask" arrow::dataset::ScanTask": + CResult[CRecordBatchIterator] Execute() + + cdef cppclass CFragment "arrow::dataset::Fragment": + CResult[shared_ptr[CSchema]] ReadPhysicalSchema() + CResult[CScanTaskIterator] Scan(shared_ptr[CScanOptions] options) + c_bool splittable() const + c_string type_name() const + const CExpression& partition_expression() const + + ctypedef vector[shared_ptr[CFragment]] CFragmentVector \ + "arrow::dataset::FragmentVector" + + ctypedef CIterator[shared_ptr[CFragment]] CFragmentIterator \ + "arrow::dataset::FragmentIterator" + + cdef cppclass CInMemoryFragment "arrow::dataset::InMemoryFragment"( + CFragment): + CInMemoryFragment(vector[shared_ptr[CRecordBatch]] record_batches, + CExpression partition_expression) + + cdef cppclass CTaggedRecordBatch "arrow::dataset::TaggedRecordBatch": + shared_ptr[CRecordBatch] record_batch + shared_ptr[CFragment] fragment + + ctypedef CIterator[CTaggedRecordBatch] CTaggedRecordBatchIterator \ + "arrow::dataset::TaggedRecordBatchIterator" + + cdef cppclass CScanner "arrow::dataset::Scanner": + CScanner(shared_ptr[CDataset], shared_ptr[CScanOptions]) + CScanner(shared_ptr[CFragment], shared_ptr[CScanOptions]) + CResult[CScanTaskIterator] Scan() + CResult[CTaggedRecordBatchIterator] ScanBatches() + CResult[shared_ptr[CTable]] ToTable() + CResult[shared_ptr[CTable]] TakeRows(const CArray& indices) + CResult[shared_ptr[CTable]] Head(int64_t num_rows) + CResult[int64_t] CountRows() + CResult[CFragmentIterator] GetFragments() + CResult[shared_ptr[CRecordBatchReader]] ToRecordBatchReader() + const shared_ptr[CScanOptions]& options() + + cdef cppclass CScannerBuilder "arrow::dataset::ScannerBuilder": + CScannerBuilder(shared_ptr[CDataset], + shared_ptr[CScanOptions] scan_options) + CScannerBuilder(shared_ptr[CSchema], shared_ptr[CFragment], + shared_ptr[CScanOptions] scan_options) + + @staticmethod + shared_ptr[CScannerBuilder] FromRecordBatchReader( + shared_ptr[CRecordBatchReader] reader) + CStatus ProjectColumns "Project"(const vector[c_string]& columns) + CStatus Project(vector[CExpression]& exprs, vector[c_string]& columns) + CStatus Filter(CExpression filter) + CStatus UseThreads(c_bool use_threads) + CStatus UseAsync(c_bool use_async) + CStatus Pool(CMemoryPool* pool) + CStatus BatchSize(int64_t batch_size) + CStatus FragmentScanOptions( + shared_ptr[CFragmentScanOptions] fragment_scan_options) + CResult[shared_ptr[CScanner]] Finish() + shared_ptr[CSchema] schema() const + + ctypedef vector[shared_ptr[CDataset]] CDatasetVector \ + "arrow::dataset::DatasetVector" + + cdef cppclass CDataset "arrow::dataset::Dataset": + const shared_ptr[CSchema] & schema() + CResult[CFragmentIterator] GetFragments() + CResult[CFragmentIterator] GetFragments(CExpression predicate) + const CExpression & partition_expression() + c_string type_name() + + CResult[shared_ptr[CDataset]] ReplaceSchema(shared_ptr[CSchema]) + + CResult[shared_ptr[CScannerBuilder]] NewScan() + + cdef cppclass CInMemoryDataset "arrow::dataset::InMemoryDataset"( + CDataset): + CInMemoryDataset(shared_ptr[CRecordBatchReader]) + CInMemoryDataset(shared_ptr[CTable]) + + cdef cppclass CUnionDataset "arrow::dataset::UnionDataset"( + CDataset): + @staticmethod + CResult[shared_ptr[CUnionDataset]] Make(shared_ptr[CSchema] schema, + CDatasetVector children) + + const CDatasetVector& children() const + + cdef cppclass CInspectOptions "arrow::dataset::InspectOptions": + int fragments + + cdef cppclass CFinishOptions "arrow::dataset::FinishOptions": + shared_ptr[CSchema] schema + CInspectOptions inspect_options + c_bool validate_fragments + + cdef cppclass CDatasetFactory "arrow::dataset::DatasetFactory": + CResult[vector[shared_ptr[CSchema]]] InspectSchemas(CInspectOptions) + CResult[shared_ptr[CSchema]] Inspect(CInspectOptions) + CResult[shared_ptr[CDataset]] FinishWithSchema "Finish"( + const shared_ptr[CSchema]& schema) + CResult[shared_ptr[CDataset]] Finish() + const CExpression& root_partition() + CStatus SetRootPartition(CExpression partition) + + cdef cppclass CUnionDatasetFactory "arrow::dataset::UnionDatasetFactory": + @staticmethod + CResult[shared_ptr[CDatasetFactory]] Make( + vector[shared_ptr[CDatasetFactory]] factories) + + cdef cppclass CFileSource "arrow::dataset::FileSource": + const c_string& path() const + const shared_ptr[CFileSystem]& filesystem() const + const shared_ptr[CBuffer]& buffer() const + # HACK: Cython can't handle all the overloads so don't declare them. + # This means invalid construction of CFileSource won't be caught in + # the C++ generation phase (though it will still be caught when + # the generated C++ is compiled). + CFileSource(...) + + cdef cppclass CFileWriteOptions \ + "arrow::dataset::FileWriteOptions": + const shared_ptr[CFileFormat]& format() const + c_string type_name() const + + cdef cppclass CFileWriter \ + "arrow::dataset::FileWriter": + const shared_ptr[CFileFormat]& format() const + const shared_ptr[CSchema]& schema() const + const shared_ptr[CFileWriteOptions]& options() const + const CFileLocator& destination() const + + cdef cppclass CParquetFileWriter \ + "arrow::dataset::ParquetFileWriter"(CFileWriter): + const shared_ptr[FileWriter]& parquet_writer() const + + cdef cppclass CFileFormat "arrow::dataset::FileFormat": + shared_ptr[CFragmentScanOptions] default_fragment_scan_options + c_string type_name() const + CResult[shared_ptr[CSchema]] Inspect(const CFileSource&) const + CResult[shared_ptr[CFileFragment]] MakeFragment( + CFileSource source, + CExpression partition_expression, + shared_ptr[CSchema] physical_schema) + shared_ptr[CFileWriteOptions] DefaultWriteOptions() + + cdef cppclass CFileFragment "arrow::dataset::FileFragment"( + CFragment): + const CFileSource& source() const + const shared_ptr[CFileFormat]& format() const + + cdef cppclass CParquetFileWriteOptions \ + "arrow::dataset::ParquetFileWriteOptions"(CFileWriteOptions): + shared_ptr[WriterProperties] writer_properties + shared_ptr[ArrowWriterProperties] arrow_writer_properties + + cdef cppclass CParquetFileFragment "arrow::dataset::ParquetFileFragment"( + CFileFragment): + const vector[int]& row_groups() const + shared_ptr[CFileMetaData] metadata() const + CResult[vector[shared_ptr[CFragment]]] SplitByRowGroup( + CExpression predicate) + CResult[shared_ptr[CFragment]] SubsetWithFilter "Subset"( + CExpression predicate) + CResult[shared_ptr[CFragment]] SubsetWithIds "Subset"( + vector[int] row_group_ids) + CStatus EnsureCompleteMetadata() + + cdef cppclass CFileSystemDatasetWriteOptions \ + "arrow::dataset::FileSystemDatasetWriteOptions": + shared_ptr[CFileWriteOptions] file_write_options + shared_ptr[CFileSystem] filesystem + c_string base_dir + shared_ptr[CPartitioning] partitioning + int max_partitions + c_string basename_template + function[cb_writer_finish_internal] writer_pre_finish + function[cb_writer_finish_internal] writer_post_finish + ExistingDataBehavior existing_data_behavior + + cdef cppclass CFileSystemDataset \ + "arrow::dataset::FileSystemDataset"(CDataset): + @staticmethod + CResult[shared_ptr[CDataset]] Make( + shared_ptr[CSchema] schema, + CExpression source_partition, + shared_ptr[CFileFormat] format, + shared_ptr[CFileSystem] filesystem, + vector[shared_ptr[CFileFragment]] fragments) + + @staticmethod + CStatus Write( + const CFileSystemDatasetWriteOptions& write_options, + shared_ptr[CScanner] scanner) + + c_string type() + vector[c_string] files() + const shared_ptr[CFileFormat]& format() const + const shared_ptr[CFileSystem]& filesystem() const + const shared_ptr[CPartitioning]& partitioning() const + + cdef cppclass CParquetFileFormatReaderOptions \ + "arrow::dataset::ParquetFileFormat::ReaderOptions": + unordered_set[c_string] dict_columns + TimeUnit coerce_int96_timestamp_unit + + cdef cppclass CParquetFileFormat "arrow::dataset::ParquetFileFormat"( + CFileFormat): + CParquetFileFormatReaderOptions reader_options + CResult[shared_ptr[CFileFragment]] MakeFragment( + CFileSource source, + CExpression partition_expression, + shared_ptr[CSchema] physical_schema, + vector[int] row_groups) + + cdef cppclass CParquetFragmentScanOptions \ + "arrow::dataset::ParquetFragmentScanOptions"(CFragmentScanOptions): + shared_ptr[CReaderProperties] reader_properties + shared_ptr[ArrowReaderProperties] arrow_reader_properties + c_bool enable_parallel_column_conversion + + cdef cppclass CIpcFileWriteOptions \ + "arrow::dataset::IpcFileWriteOptions"(CFileWriteOptions): + pass + + cdef cppclass CIpcFileFormat "arrow::dataset::IpcFileFormat"( + CFileFormat): + pass + + cdef cppclass COrcFileFormat "arrow::dataset::OrcFileFormat"( + CFileFormat): + pass + + cdef cppclass CCsvFileWriteOptions \ + "arrow::dataset::CsvFileWriteOptions"(CFileWriteOptions): + shared_ptr[CCSVWriteOptions] write_options + CMemoryPool* pool + + cdef cppclass CCsvFileFormat "arrow::dataset::CsvFileFormat"( + CFileFormat): + CCSVParseOptions parse_options + + cdef cppclass CCsvFragmentScanOptions \ + "arrow::dataset::CsvFragmentScanOptions"(CFragmentScanOptions): + CCSVConvertOptions convert_options + CCSVReadOptions read_options + + cdef cppclass CPartitioning "arrow::dataset::Partitioning": + c_string type_name() const + CResult[CExpression] Parse(const c_string & path) const + const shared_ptr[CSchema] & schema() + + cdef cppclass CSegmentEncoding" arrow::dataset::SegmentEncoding": + pass + + CSegmentEncoding CSegmentEncodingNone\ + " arrow::dataset::SegmentEncoding::None" + CSegmentEncoding CSegmentEncodingUri\ + " arrow::dataset::SegmentEncoding::Uri" + + cdef cppclass CKeyValuePartitioningOptions \ + "arrow::dataset::KeyValuePartitioningOptions": + CSegmentEncoding segment_encoding + + cdef cppclass CHivePartitioningOptions \ + "arrow::dataset::HivePartitioningOptions": + CSegmentEncoding segment_encoding + c_string null_fallback + + cdef cppclass CPartitioningFactoryOptions \ + "arrow::dataset::PartitioningFactoryOptions": + c_bool infer_dictionary + shared_ptr[CSchema] schema + CSegmentEncoding segment_encoding + + cdef cppclass CHivePartitioningFactoryOptions \ + "arrow::dataset::HivePartitioningFactoryOptions": + c_bool infer_dictionary + c_string null_fallback + shared_ptr[CSchema] schema + CSegmentEncoding segment_encoding + + cdef cppclass CPartitioningFactory "arrow::dataset::PartitioningFactory": + c_string type_name() const + + cdef cppclass CDirectoryPartitioning \ + "arrow::dataset::DirectoryPartitioning"(CPartitioning): + CDirectoryPartitioning(shared_ptr[CSchema] schema, + vector[shared_ptr[CArray]] dictionaries) + + @staticmethod + shared_ptr[CPartitioningFactory] MakeFactory( + vector[c_string] field_names, CPartitioningFactoryOptions) + + vector[shared_ptr[CArray]] dictionaries() const + + cdef cppclass CHivePartitioning \ + "arrow::dataset::HivePartitioning"(CPartitioning): + CHivePartitioning(shared_ptr[CSchema] schema, + vector[shared_ptr[CArray]] dictionaries, + CHivePartitioningOptions options) + + @staticmethod + shared_ptr[CPartitioningFactory] MakeFactory( + CHivePartitioningFactoryOptions) + + vector[shared_ptr[CArray]] dictionaries() const + + cdef cppclass CPartitioningOrFactory \ + "arrow::dataset::PartitioningOrFactory": + CPartitioningOrFactory(shared_ptr[CPartitioning]) + CPartitioningOrFactory(shared_ptr[CPartitioningFactory]) + CPartitioningOrFactory & operator = (shared_ptr[CPartitioning]) + CPartitioningOrFactory & operator = ( + shared_ptr[CPartitioningFactory]) + shared_ptr[CPartitioning] partitioning() const + shared_ptr[CPartitioningFactory] factory() const + + cdef cppclass CFileSystemFactoryOptions \ + "arrow::dataset::FileSystemFactoryOptions": + CPartitioningOrFactory partitioning + c_string partition_base_dir + c_bool exclude_invalid_files + vector[c_string] selector_ignore_prefixes + + cdef cppclass CFileSystemDatasetFactory \ + "arrow::dataset::FileSystemDatasetFactory"( + CDatasetFactory): + @staticmethod + CResult[shared_ptr[CDatasetFactory]] MakeFromPaths "Make"( + shared_ptr[CFileSystem] filesystem, + vector[c_string] paths, + shared_ptr[CFileFormat] format, + CFileSystemFactoryOptions options + ) + + @staticmethod + CResult[shared_ptr[CDatasetFactory]] MakeFromSelector "Make"( + shared_ptr[CFileSystem] filesystem, + CFileSelector, + shared_ptr[CFileFormat] format, + CFileSystemFactoryOptions options + ) + + cdef cppclass CParquetFactoryOptions \ + "arrow::dataset::ParquetFactoryOptions": + CPartitioningOrFactory partitioning + c_string partition_base_dir + c_bool validate_column_chunk_paths + + cdef cppclass CParquetDatasetFactory \ + "arrow::dataset::ParquetDatasetFactory"(CDatasetFactory): + @staticmethod + CResult[shared_ptr[CDatasetFactory]] MakeFromMetaDataPath "Make"( + const c_string& metadata_path, + shared_ptr[CFileSystem] filesystem, + shared_ptr[CParquetFileFormat] format, + CParquetFactoryOptions options + ) + + @staticmethod + CResult[shared_ptr[CDatasetFactory]] MakeFromMetaDataSource "Make"( + const CFileSource& metadata_path, + const c_string& base_path, + shared_ptr[CFileSystem] filesystem, + shared_ptr[CParquetFileFormat] format, + CParquetFactoryOptions options + ) diff --git a/src/arrow/python/pyarrow/includes/libarrow_feather.pxd b/src/arrow/python/pyarrow/includes/libarrow_feather.pxd new file mode 100644 index 000000000..ddfc8b2e5 --- /dev/null +++ b/src/arrow/python/pyarrow/includes/libarrow_feather.pxd @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from pyarrow.includes.libarrow cimport (CCompressionType, CStatus, CTable, + COutputStream, CResult, shared_ptr, + vector, CRandomAccessFile, CSchema, + c_string) + + +cdef extern from "arrow/ipc/api.h" namespace "arrow::ipc" nogil: + int kFeatherV1Version" arrow::ipc::feather::kFeatherV1Version" + int kFeatherV2Version" arrow::ipc::feather::kFeatherV2Version" + + cdef cppclass CFeatherProperties" arrow::ipc::feather::WriteProperties": + int version + int chunksize + CCompressionType compression + int compression_level + + CStatus WriteFeather" arrow::ipc::feather::WriteTable" \ + (const CTable& table, COutputStream* out, + CFeatherProperties properties) + + cdef cppclass CFeatherReader" arrow::ipc::feather::Reader": + @staticmethod + CResult[shared_ptr[CFeatherReader]] Open( + const shared_ptr[CRandomAccessFile]& file) + int version() + shared_ptr[CSchema] schema() + + CStatus Read(shared_ptr[CTable]* out) + CStatus Read(const vector[int] indices, shared_ptr[CTable]* out) + CStatus Read(const vector[c_string] names, shared_ptr[CTable]* out) diff --git a/src/arrow/python/pyarrow/includes/libarrow_flight.pxd b/src/arrow/python/pyarrow/includes/libarrow_flight.pxd new file mode 100644 index 000000000..2ac737aba --- /dev/null +++ b/src/arrow/python/pyarrow/includes/libarrow_flight.pxd @@ -0,0 +1,560 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * + + +cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: + cdef char* CPyServerMiddlewareName\ + " arrow::py::flight::kPyServerMiddlewareName" + + cdef cppclass CActionType" arrow::flight::ActionType": + c_string type + c_string description + + cdef cppclass CAction" arrow::flight::Action": + c_string type + shared_ptr[CBuffer] body + + cdef cppclass CFlightResult" arrow::flight::Result": + CFlightResult() + CFlightResult(CFlightResult) + shared_ptr[CBuffer] body + + cdef cppclass CBasicAuth" arrow::flight::BasicAuth": + CBasicAuth() + CBasicAuth(CBuffer) + CBasicAuth(CBasicAuth) + c_string username + c_string password + + cdef cppclass CResultStream" arrow::flight::ResultStream": + CStatus Next(unique_ptr[CFlightResult]* result) + + cdef cppclass CDescriptorType \ + " arrow::flight::FlightDescriptor::DescriptorType": + bint operator==(CDescriptorType) + + CDescriptorType CDescriptorTypeUnknown\ + " arrow::flight::FlightDescriptor::UNKNOWN" + CDescriptorType CDescriptorTypePath\ + " arrow::flight::FlightDescriptor::PATH" + CDescriptorType CDescriptorTypeCmd\ + " arrow::flight::FlightDescriptor::CMD" + + cdef cppclass CFlightDescriptor" arrow::flight::FlightDescriptor": + CDescriptorType type + c_string cmd + vector[c_string] path + CStatus SerializeToString(c_string* out) + + @staticmethod + CStatus Deserialize(const c_string& serialized, + CFlightDescriptor* out) + bint operator==(CFlightDescriptor) + + cdef cppclass CTicket" arrow::flight::Ticket": + CTicket() + c_string ticket + bint operator==(CTicket) + CStatus SerializeToString(c_string* out) + + @staticmethod + CStatus Deserialize(const c_string& serialized, CTicket* out) + + cdef cppclass CCriteria" arrow::flight::Criteria": + CCriteria() + c_string expression + + cdef cppclass CLocation" arrow::flight::Location": + CLocation() + c_string ToString() + c_bool Equals(const CLocation& other) + + @staticmethod + CStatus Parse(c_string& uri_string, CLocation* location) + + @staticmethod + CStatus ForGrpcTcp(c_string& host, int port, CLocation* location) + + @staticmethod + CStatus ForGrpcTls(c_string& host, int port, CLocation* location) + + @staticmethod + CStatus ForGrpcUnix(c_string& path, CLocation* location) + + cdef cppclass CFlightEndpoint" arrow::flight::FlightEndpoint": + CFlightEndpoint() + + CTicket ticket + vector[CLocation] locations + + bint operator==(CFlightEndpoint) + + cdef cppclass CFlightInfo" arrow::flight::FlightInfo": + CFlightInfo(CFlightInfo info) + int64_t total_records() + int64_t total_bytes() + CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out) + CFlightDescriptor& descriptor() + const vector[CFlightEndpoint]& endpoints() + CStatus SerializeToString(c_string* out) + + @staticmethod + CStatus Deserialize(const c_string& serialized, + unique_ptr[CFlightInfo]* out) + + cdef cppclass CSchemaResult" arrow::flight::SchemaResult": + CSchemaResult(CSchemaResult result) + CStatus GetSchema(CDictionaryMemo* memo, shared_ptr[CSchema]* out) + + cdef cppclass CFlightListing" arrow::flight::FlightListing": + CStatus Next(unique_ptr[CFlightInfo]* info) + + cdef cppclass CSimpleFlightListing" arrow::flight::SimpleFlightListing": + CSimpleFlightListing(vector[CFlightInfo]&& info) + + cdef cppclass CFlightPayload" arrow::flight::FlightPayload": + shared_ptr[CBuffer] descriptor + shared_ptr[CBuffer] app_metadata + CIpcPayload ipc_message + + cdef cppclass CFlightDataStream" arrow::flight::FlightDataStream": + shared_ptr[CSchema] schema() + CStatus Next(CFlightPayload*) + + cdef cppclass CFlightStreamChunk" arrow::flight::FlightStreamChunk": + CFlightStreamChunk() + shared_ptr[CRecordBatch] data + shared_ptr[CBuffer] app_metadata + + cdef cppclass CMetadataRecordBatchReader \ + " arrow::flight::MetadataRecordBatchReader": + CResult[shared_ptr[CSchema]] GetSchema() + CStatus Next(CFlightStreamChunk* out) + CStatus ReadAll(shared_ptr[CTable]* table) + + CResult[shared_ptr[CRecordBatchReader]] MakeRecordBatchReader\ + " arrow::flight::MakeRecordBatchReader"( + shared_ptr[CMetadataRecordBatchReader]) + + cdef cppclass CMetadataRecordBatchWriter \ + " arrow::flight::MetadataRecordBatchWriter"(CRecordBatchWriter): + CStatus Begin(shared_ptr[CSchema] schema, + const CIpcWriteOptions& options) + CStatus WriteMetadata(shared_ptr[CBuffer] app_metadata) + CStatus WriteWithMetadata(const CRecordBatch& batch, + shared_ptr[CBuffer] app_metadata) + + cdef cppclass CFlightStreamReader \ + " arrow::flight::FlightStreamReader"(CMetadataRecordBatchReader): + void Cancel() + CStatus ReadAllWithStopToken" ReadAll"\ + (shared_ptr[CTable]* table, const CStopToken& stop_token) + + cdef cppclass CFlightMessageReader \ + " arrow::flight::FlightMessageReader"(CMetadataRecordBatchReader): + CFlightDescriptor& descriptor() + + cdef cppclass CFlightMessageWriter \ + " arrow::flight::FlightMessageWriter"(CMetadataRecordBatchWriter): + pass + + cdef cppclass CFlightStreamWriter \ + " arrow::flight::FlightStreamWriter"(CMetadataRecordBatchWriter): + CStatus DoneWriting() + + cdef cppclass CRecordBatchStream \ + " arrow::flight::RecordBatchStream"(CFlightDataStream): + CRecordBatchStream(shared_ptr[CRecordBatchReader]& reader, + const CIpcWriteOptions& options) + + cdef cppclass CFlightMetadataReader" arrow::flight::FlightMetadataReader": + CStatus ReadMetadata(shared_ptr[CBuffer]* out) + + cdef cppclass CFlightMetadataWriter" arrow::flight::FlightMetadataWriter": + CStatus WriteMetadata(const CBuffer& message) + + cdef cppclass CServerAuthReader" arrow::flight::ServerAuthReader": + CStatus Read(c_string* token) + + cdef cppclass CServerAuthSender" arrow::flight::ServerAuthSender": + CStatus Write(c_string& token) + + cdef cppclass CClientAuthReader" arrow::flight::ClientAuthReader": + CStatus Read(c_string* token) + + cdef cppclass CClientAuthSender" arrow::flight::ClientAuthSender": + CStatus Write(c_string& token) + + cdef cppclass CServerAuthHandler" arrow::flight::ServerAuthHandler": + pass + + cdef cppclass CClientAuthHandler" arrow::flight::ClientAuthHandler": + pass + + cdef cppclass CServerCallContext" arrow::flight::ServerCallContext": + c_string& peer_identity() + c_string& peer() + c_bool is_cancelled() + CServerMiddleware* GetMiddleware(const c_string& key) + + cdef cppclass CTimeoutDuration" arrow::flight::TimeoutDuration": + CTimeoutDuration(double) + + cdef cppclass CFlightCallOptions" arrow::flight::FlightCallOptions": + CFlightCallOptions() + CTimeoutDuration timeout + CIpcWriteOptions write_options + vector[pair[c_string, c_string]] headers + CStopToken stop_token + + cdef cppclass CCertKeyPair" arrow::flight::CertKeyPair": + CCertKeyPair() + c_string pem_cert + c_string pem_key + + cdef cppclass CFlightMethod" arrow::flight::FlightMethod": + bint operator==(CFlightMethod) + + CFlightMethod CFlightMethodInvalid\ + " arrow::flight::FlightMethod::Invalid" + CFlightMethod CFlightMethodHandshake\ + " arrow::flight::FlightMethod::Handshake" + CFlightMethod CFlightMethodListFlights\ + " arrow::flight::FlightMethod::ListFlights" + CFlightMethod CFlightMethodGetFlightInfo\ + " arrow::flight::FlightMethod::GetFlightInfo" + CFlightMethod CFlightMethodGetSchema\ + " arrow::flight::FlightMethod::GetSchema" + CFlightMethod CFlightMethodDoGet\ + " arrow::flight::FlightMethod::DoGet" + CFlightMethod CFlightMethodDoPut\ + " arrow::flight::FlightMethod::DoPut" + CFlightMethod CFlightMethodDoAction\ + " arrow::flight::FlightMethod::DoAction" + CFlightMethod CFlightMethodListActions\ + " arrow::flight::FlightMethod::ListActions" + CFlightMethod CFlightMethodDoExchange\ + " arrow::flight::FlightMethod::DoExchange" + + cdef cppclass CCallInfo" arrow::flight::CallInfo": + CFlightMethod method + + # This is really std::unordered_multimap, but Cython has no + # bindings for it, so treat it as an opaque class and bind the + # methods we need + cdef cppclass CCallHeaders" arrow::flight::CallHeaders": + cppclass const_iterator: + pair[c_string, c_string] operator*() + const_iterator operator++() + bint operator==(const_iterator) + bint operator!=(const_iterator) + const_iterator cbegin() + const_iterator cend() + + cdef cppclass CAddCallHeaders" arrow::flight::AddCallHeaders": + void AddHeader(const c_string& key, const c_string& value) + + cdef cppclass CServerMiddleware" arrow::flight::ServerMiddleware": + c_string name() + + cdef cppclass CServerMiddlewareFactory\ + " arrow::flight::ServerMiddlewareFactory": + pass + + cdef cppclass CClientMiddleware" arrow::flight::ClientMiddleware": + pass + + cdef cppclass CClientMiddlewareFactory\ + " arrow::flight::ClientMiddlewareFactory": + pass + + cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions": + CFlightServerOptions(const CLocation& location) + CLocation location + unique_ptr[CServerAuthHandler] auth_handler + vector[CCertKeyPair] tls_certificates + c_bool verify_client + c_string root_certificates + vector[pair[c_string, shared_ptr[CServerMiddlewareFactory]]] middleware + + cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions": + c_string tls_root_certs + c_string cert_chain + c_string private_key + c_string override_hostname + vector[shared_ptr[CClientMiddlewareFactory]] middleware + int64_t write_size_limit_bytes + vector[pair[c_string, CIntStringVariant]] generic_options + c_bool disable_server_verification + + @staticmethod + CFlightClientOptions Defaults() + + cdef cppclass CFlightClient" arrow::flight::FlightClient": + @staticmethod + CStatus Connect(const CLocation& location, + const CFlightClientOptions& options, + unique_ptr[CFlightClient]* client) + + CStatus Authenticate(CFlightCallOptions& options, + unique_ptr[CClientAuthHandler] auth_handler) + + CResult[pair[c_string, c_string]] AuthenticateBasicToken( + CFlightCallOptions& options, + const c_string& username, + const c_string& password) + + CStatus DoAction(CFlightCallOptions& options, CAction& action, + unique_ptr[CResultStream]* results) + CStatus ListActions(CFlightCallOptions& options, + vector[CActionType]* actions) + + CStatus ListFlights(CFlightCallOptions& options, CCriteria criteria, + unique_ptr[CFlightListing]* listing) + CStatus GetFlightInfo(CFlightCallOptions& options, + CFlightDescriptor& descriptor, + unique_ptr[CFlightInfo]* info) + CStatus GetSchema(CFlightCallOptions& options, + CFlightDescriptor& descriptor, + unique_ptr[CSchemaResult]* result) + CStatus DoGet(CFlightCallOptions& options, CTicket& ticket, + unique_ptr[CFlightStreamReader]* stream) + CStatus DoPut(CFlightCallOptions& options, + CFlightDescriptor& descriptor, + shared_ptr[CSchema]& schema, + unique_ptr[CFlightStreamWriter]* stream, + unique_ptr[CFlightMetadataReader]* reader) + CStatus DoExchange(CFlightCallOptions& options, + CFlightDescriptor& descriptor, + unique_ptr[CFlightStreamWriter]* writer, + unique_ptr[CFlightStreamReader]* reader) + + cdef cppclass CFlightStatusCode" arrow::flight::FlightStatusCode": + bint operator==(CFlightStatusCode) + + CFlightStatusCode CFlightStatusInternal \ + " arrow::flight::FlightStatusCode::Internal" + CFlightStatusCode CFlightStatusTimedOut \ + " arrow::flight::FlightStatusCode::TimedOut" + CFlightStatusCode CFlightStatusCancelled \ + " arrow::flight::FlightStatusCode::Cancelled" + CFlightStatusCode CFlightStatusUnauthenticated \ + " arrow::flight::FlightStatusCode::Unauthenticated" + CFlightStatusCode CFlightStatusUnauthorized \ + " arrow::flight::FlightStatusCode::Unauthorized" + CFlightStatusCode CFlightStatusUnavailable \ + " arrow::flight::FlightStatusCode::Unavailable" + CFlightStatusCode CFlightStatusFailed \ + " arrow::flight::FlightStatusCode::Failed" + + cdef cppclass FlightStatusDetail" arrow::flight::FlightStatusDetail": + CFlightStatusCode code() + c_string extra_info() + + @staticmethod + shared_ptr[FlightStatusDetail] UnwrapStatus(const CStatus& status) + + cdef cppclass FlightWriteSizeStatusDetail\ + " arrow::flight::FlightWriteSizeStatusDetail": + int64_t limit() + int64_t actual() + + @staticmethod + shared_ptr[FlightWriteSizeStatusDetail] UnwrapStatus( + const CStatus& status) + + cdef CStatus MakeFlightError" arrow::flight::MakeFlightError" \ + (CFlightStatusCode code, const c_string& message) + + cdef CStatus MakeFlightError" arrow::flight::MakeFlightError" \ + (CFlightStatusCode code, + const c_string& message, + const c_string& extra_info) + +# Callbacks for implementing Flight servers +# Use typedef to emulate syntax for std::function<void(..)> +ctypedef CStatus cb_list_flights(object, const CServerCallContext&, + const CCriteria*, + unique_ptr[CFlightListing]*) +ctypedef CStatus cb_get_flight_info(object, const CServerCallContext&, + const CFlightDescriptor&, + unique_ptr[CFlightInfo]*) +ctypedef CStatus cb_get_schema(object, const CServerCallContext&, + const CFlightDescriptor&, + unique_ptr[CSchemaResult]*) +ctypedef CStatus cb_do_put(object, const CServerCallContext&, + unique_ptr[CFlightMessageReader], + unique_ptr[CFlightMetadataWriter]) +ctypedef CStatus cb_do_get(object, const CServerCallContext&, + const CTicket&, + unique_ptr[CFlightDataStream]*) +ctypedef CStatus cb_do_exchange(object, const CServerCallContext&, + unique_ptr[CFlightMessageReader], + unique_ptr[CFlightMessageWriter]) +ctypedef CStatus cb_do_action(object, const CServerCallContext&, + const CAction&, + unique_ptr[CResultStream]*) +ctypedef CStatus cb_list_actions(object, const CServerCallContext&, + vector[CActionType]*) +ctypedef CStatus cb_result_next(object, unique_ptr[CFlightResult]*) +ctypedef CStatus cb_data_stream_next(object, CFlightPayload*) +ctypedef CStatus cb_server_authenticate(object, CServerAuthSender*, + CServerAuthReader*) +ctypedef CStatus cb_is_valid(object, const c_string&, c_string*) +ctypedef CStatus cb_client_authenticate(object, CClientAuthSender*, + CClientAuthReader*) +ctypedef CStatus cb_get_token(object, c_string*) + +ctypedef CStatus cb_middleware_sending_headers(object, CAddCallHeaders*) +ctypedef CStatus cb_middleware_call_completed(object, const CStatus&) +ctypedef CStatus cb_client_middleware_received_headers( + object, const CCallHeaders&) +ctypedef CStatus cb_server_middleware_start_call( + object, + const CCallInfo&, + const CCallHeaders&, + shared_ptr[CServerMiddleware]*) +ctypedef CStatus cb_client_middleware_start_call( + object, + const CCallInfo&, + unique_ptr[CClientMiddleware]*) + +cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: + cdef cppclass PyFlightServerVtable: + PyFlightServerVtable() + function[cb_list_flights] list_flights + function[cb_get_flight_info] get_flight_info + function[cb_get_schema] get_schema + function[cb_do_put] do_put + function[cb_do_get] do_get + function[cb_do_exchange] do_exchange + function[cb_do_action] do_action + function[cb_list_actions] list_actions + + cdef cppclass PyServerAuthHandlerVtable: + PyServerAuthHandlerVtable() + function[cb_server_authenticate] authenticate + function[cb_is_valid] is_valid + + cdef cppclass PyClientAuthHandlerVtable: + PyClientAuthHandlerVtable() + function[cb_client_authenticate] authenticate + function[cb_get_token] get_token + + cdef cppclass PyFlightServer: + PyFlightServer(object server, PyFlightServerVtable vtable) + + CStatus Init(CFlightServerOptions& options) + int port() + CStatus ServeWithSignals() except * + CStatus Shutdown() + CStatus Wait() + + cdef cppclass PyServerAuthHandler\ + " arrow::py::flight::PyServerAuthHandler"(CServerAuthHandler): + PyServerAuthHandler(object handler, PyServerAuthHandlerVtable vtable) + + cdef cppclass PyClientAuthHandler\ + " arrow::py::flight::PyClientAuthHandler"(CClientAuthHandler): + PyClientAuthHandler(object handler, PyClientAuthHandlerVtable vtable) + + cdef cppclass CPyFlightResultStream\ + " arrow::py::flight::PyFlightResultStream"(CResultStream): + CPyFlightResultStream(object generator, + function[cb_result_next] callback) + + cdef cppclass CPyFlightDataStream\ + " arrow::py::flight::PyFlightDataStream"(CFlightDataStream): + CPyFlightDataStream(object data_source, + unique_ptr[CFlightDataStream] stream) + + cdef cppclass CPyGeneratorFlightDataStream\ + " arrow::py::flight::PyGeneratorFlightDataStream"\ + (CFlightDataStream): + CPyGeneratorFlightDataStream(object generator, + shared_ptr[CSchema] schema, + function[cb_data_stream_next] callback, + const CIpcWriteOptions& options) + + cdef cppclass PyServerMiddlewareVtable\ + " arrow::py::flight::PyServerMiddleware::Vtable": + PyServerMiddlewareVtable() + function[cb_middleware_sending_headers] sending_headers + function[cb_middleware_call_completed] call_completed + + cdef cppclass PyClientMiddlewareVtable\ + " arrow::py::flight::PyClientMiddleware::Vtable": + PyClientMiddlewareVtable() + function[cb_middleware_sending_headers] sending_headers + function[cb_client_middleware_received_headers] received_headers + function[cb_middleware_call_completed] call_completed + + cdef cppclass CPyServerMiddleware\ + " arrow::py::flight::PyServerMiddleware"(CServerMiddleware): + CPyServerMiddleware(object middleware, PyServerMiddlewareVtable vtable) + void* py_object() + + cdef cppclass CPyServerMiddlewareFactory\ + " arrow::py::flight::PyServerMiddlewareFactory"\ + (CServerMiddlewareFactory): + CPyServerMiddlewareFactory( + object factory, + function[cb_server_middleware_start_call] start_call) + + cdef cppclass CPyClientMiddleware\ + " arrow::py::flight::PyClientMiddleware"(CClientMiddleware): + CPyClientMiddleware(object middleware, PyClientMiddlewareVtable vtable) + + cdef cppclass CPyClientMiddlewareFactory\ + " arrow::py::flight::PyClientMiddlewareFactory"\ + (CClientMiddlewareFactory): + CPyClientMiddlewareFactory( + object factory, + function[cb_client_middleware_start_call] start_call) + + cdef CStatus CreateFlightInfo" arrow::py::flight::CreateFlightInfo"( + shared_ptr[CSchema] schema, + CFlightDescriptor& descriptor, + vector[CFlightEndpoint] endpoints, + int64_t total_records, + int64_t total_bytes, + unique_ptr[CFlightInfo]* out) + + cdef CStatus CreateSchemaResult" arrow::py::flight::CreateSchemaResult"( + shared_ptr[CSchema] schema, + unique_ptr[CSchemaResult]* out) + + cdef CStatus DeserializeBasicAuth\ + " arrow::py::flight::DeserializeBasicAuth"( + c_string buf, + unique_ptr[CBasicAuth]* out) + + cdef CStatus SerializeBasicAuth" arrow::py::flight::SerializeBasicAuth"( + CBasicAuth basic_auth, + c_string* out) + + +cdef extern from "arrow/util/variant.h" namespace "arrow" nogil: + cdef cppclass CIntStringVariant" arrow::util::Variant<int, std::string>": + CIntStringVariant() + CIntStringVariant(int) + CIntStringVariant(c_string) diff --git a/src/arrow/python/pyarrow/includes/libarrow_fs.pxd b/src/arrow/python/pyarrow/includes/libarrow_fs.pxd new file mode 100644 index 000000000..9dca5fbf6 --- /dev/null +++ b/src/arrow/python/pyarrow/includes/libarrow_fs.pxd @@ -0,0 +1,296 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * + +cdef extern from "arrow/filesystem/api.h" namespace "arrow::fs" nogil: + + ctypedef enum CFileType "arrow::fs::FileType": + CFileType_NotFound "arrow::fs::FileType::NotFound" + CFileType_Unknown "arrow::fs::FileType::Unknown" + CFileType_File "arrow::fs::FileType::File" + CFileType_Directory "arrow::fs::FileType::Directory" + + cdef cppclass CFileInfo "arrow::fs::FileInfo": + CFileInfo() + CFileInfo(CFileInfo&&) + CFileInfo& operator=(CFileInfo&&) + CFileInfo(const CFileInfo&) + CFileInfo& operator=(const CFileInfo&) + + CFileType type() + void set_type(CFileType type) + c_string path() + void set_path(const c_string& path) + c_string base_name() + int64_t size() + void set_size(int64_t size) + c_string extension() + CTimePoint mtime() + void set_mtime(CTimePoint mtime) + + cdef cppclass CFileSelector "arrow::fs::FileSelector": + CFileSelector() + c_string base_dir + c_bool allow_not_found + c_bool recursive + + cdef cppclass CFileLocator "arrow::fs::FileLocator": + shared_ptr[CFileSystem] filesystem + c_string path + + cdef cppclass CFileSystem "arrow::fs::FileSystem": + shared_ptr[CFileSystem] shared_from_this() + c_string type_name() const + CResult[c_string] NormalizePath(c_string path) + CResult[CFileInfo] GetFileInfo(const c_string& path) + CResult[vector[CFileInfo]] GetFileInfo( + const vector[c_string]& paths) + CResult[vector[CFileInfo]] GetFileInfo(const CFileSelector& select) + CStatus CreateDir(const c_string& path, c_bool recursive) + CStatus DeleteDir(const c_string& path) + CStatus DeleteDirContents(const c_string& path) + CStatus DeleteRootDirContents() + CStatus DeleteFile(const c_string& path) + CStatus DeleteFiles(const vector[c_string]& paths) + CStatus Move(const c_string& src, const c_string& dest) + CStatus CopyFile(const c_string& src, const c_string& dest) + CResult[shared_ptr[CInputStream]] OpenInputStream( + const c_string& path) + CResult[shared_ptr[CRandomAccessFile]] OpenInputFile( + const c_string& path) + CResult[shared_ptr[COutputStream]] OpenOutputStream( + const c_string& path, const shared_ptr[const CKeyValueMetadata]&) + CResult[shared_ptr[COutputStream]] OpenAppendStream( + const c_string& path, const shared_ptr[const CKeyValueMetadata]&) + c_bool Equals(const CFileSystem& other) + c_bool Equals(shared_ptr[CFileSystem] other) + + CResult[shared_ptr[CFileSystem]] CFileSystemFromUri \ + "arrow::fs::FileSystemFromUri"(const c_string& uri, c_string* out_path) + CResult[shared_ptr[CFileSystem]] CFileSystemFromUriOrPath \ + "arrow::fs::FileSystemFromUriOrPath"(const c_string& uri, + c_string* out_path) + + cdef cppclass CFileSystemGlobalOptions \ + "arrow::fs::FileSystemGlobalOptions": + c_string tls_ca_file_path + c_string tls_ca_dir_path + + CStatus CFileSystemsInitialize "arrow::fs::Initialize" \ + (const CFileSystemGlobalOptions& options) + + cdef cppclass CLocalFileSystemOptions "arrow::fs::LocalFileSystemOptions": + c_bool use_mmap + + @staticmethod + CLocalFileSystemOptions Defaults() + + c_bool Equals(const CLocalFileSystemOptions& other) + + cdef cppclass CLocalFileSystem "arrow::fs::LocalFileSystem"(CFileSystem): + CLocalFileSystem() + CLocalFileSystem(CLocalFileSystemOptions) + CLocalFileSystemOptions options() + + cdef cppclass CSubTreeFileSystem \ + "arrow::fs::SubTreeFileSystem"(CFileSystem): + CSubTreeFileSystem(const c_string& base_path, + shared_ptr[CFileSystem] base_fs) + c_string base_path() + shared_ptr[CFileSystem] base_fs() + + ctypedef enum CS3LogLevel "arrow::fs::S3LogLevel": + CS3LogLevel_Off "arrow::fs::S3LogLevel::Off" + CS3LogLevel_Fatal "arrow::fs::S3LogLevel::Fatal" + CS3LogLevel_Error "arrow::fs::S3LogLevel::Error" + CS3LogLevel_Warn "arrow::fs::S3LogLevel::Warn" + CS3LogLevel_Info "arrow::fs::S3LogLevel::Info" + CS3LogLevel_Debug "arrow::fs::S3LogLevel::Debug" + CS3LogLevel_Trace "arrow::fs::S3LogLevel::Trace" + + cdef struct CS3GlobalOptions "arrow::fs::S3GlobalOptions": + CS3LogLevel log_level + + cdef cppclass CS3ProxyOptions "arrow::fs::S3ProxyOptions": + c_string scheme + c_string host + int port + c_string username + c_string password + c_bool Equals(const CS3ProxyOptions& other) + + @staticmethod + CResult[CS3ProxyOptions] FromUriString "FromUri"( + const c_string& uri_string) + + ctypedef enum CS3CredentialsKind "arrow::fs::S3CredentialsKind": + CS3CredentialsKind_Anonymous "arrow::fs::S3CredentialsKind::Anonymous" + CS3CredentialsKind_Default "arrow::fs::S3CredentialsKind::Default" + CS3CredentialsKind_Explicit "arrow::fs::S3CredentialsKind::Explicit" + CS3CredentialsKind_Role "arrow::fs::S3CredentialsKind::Role" + CS3CredentialsKind_WebIdentity \ + "arrow::fs::S3CredentialsKind::WebIdentity" + + cdef cppclass CS3Options "arrow::fs::S3Options": + c_string region + c_string endpoint_override + c_string scheme + c_bool background_writes + shared_ptr[const CKeyValueMetadata] default_metadata + c_string role_arn + c_string session_name + c_string external_id + int load_frequency + CS3ProxyOptions proxy_options + CS3CredentialsKind credentials_kind + void ConfigureDefaultCredentials() + void ConfigureAccessKey(const c_string& access_key, + const c_string& secret_key, + const c_string& session_token) + c_string GetAccessKey() + c_string GetSecretKey() + c_string GetSessionToken() + c_bool Equals(const CS3Options& other) + + @staticmethod + CS3Options Defaults() + + @staticmethod + CS3Options Anonymous() + + @staticmethod + CS3Options FromAccessKey(const c_string& access_key, + const c_string& secret_key, + const c_string& session_token) + + @staticmethod + CS3Options FromAssumeRole(const c_string& role_arn, + const c_string& session_name, + const c_string& external_id, + const int load_frequency) + + cdef cppclass CS3FileSystem "arrow::fs::S3FileSystem"(CFileSystem): + @staticmethod + CResult[shared_ptr[CS3FileSystem]] Make(const CS3Options& options) + CS3Options options() + c_string region() + + cdef CStatus CInitializeS3 "arrow::fs::InitializeS3"( + const CS3GlobalOptions& options) + cdef CStatus CFinalizeS3 "arrow::fs::FinalizeS3"() + + cdef cppclass CHdfsOptions "arrow::fs::HdfsOptions": + HdfsConnectionConfig connection_config + int32_t buffer_size + int16_t replication + int64_t default_block_size + + @staticmethod + CResult[CHdfsOptions] FromUriString "FromUri"( + const c_string& uri_string) + void ConfigureEndPoint(c_string host, int port) + void ConfigureDriver(c_bool use_hdfs3) + void ConfigureReplication(int16_t replication) + void ConfigureUser(c_string user_name) + void ConfigureBufferSize(int32_t buffer_size) + void ConfigureBlockSize(int64_t default_block_size) + void ConfigureKerberosTicketCachePath(c_string path) + void ConfigureExtraConf(c_string key, c_string value) + + cdef cppclass CHadoopFileSystem "arrow::fs::HadoopFileSystem"(CFileSystem): + @staticmethod + CResult[shared_ptr[CHadoopFileSystem]] Make( + const CHdfsOptions& options) + CHdfsOptions options() + + cdef cppclass CMockFileSystem "arrow::fs::internal::MockFileSystem"( + CFileSystem): + CMockFileSystem(CTimePoint current_time) + + CStatus CCopyFiles "arrow::fs::CopyFiles"( + const vector[CFileLocator]& sources, + const vector[CFileLocator]& destinations, + const CIOContext& io_context, + int64_t chunk_size, c_bool use_threads) + CStatus CCopyFilesWithSelector "arrow::fs::CopyFiles"( + const shared_ptr[CFileSystem]& source_fs, + const CFileSelector& source_sel, + const shared_ptr[CFileSystem]& destination_fs, + const c_string& destination_base_dir, + const CIOContext& io_context, + int64_t chunk_size, c_bool use_threads) + + +# Callbacks for implementing Python filesystems +# Use typedef to emulate syntax for std::function<void(..)> +ctypedef void CallbackGetTypeName(object, c_string*) +ctypedef c_bool CallbackEquals(object, const CFileSystem&) + +ctypedef void CallbackGetFileInfo(object, const c_string&, CFileInfo*) +ctypedef void CallbackGetFileInfoVector(object, const vector[c_string]&, + vector[CFileInfo]*) +ctypedef void CallbackGetFileInfoSelector(object, const CFileSelector&, + vector[CFileInfo]*) +ctypedef void CallbackCreateDir(object, const c_string&, c_bool) +ctypedef void CallbackDeleteDir(object, const c_string&) +ctypedef void CallbackDeleteDirContents(object, const c_string&) +ctypedef void CallbackDeleteRootDirContents(object) +ctypedef void CallbackDeleteFile(object, const c_string&) +ctypedef void CallbackMove(object, const c_string&, const c_string&) +ctypedef void CallbackCopyFile(object, const c_string&, const c_string&) + +ctypedef void CallbackOpenInputStream(object, const c_string&, + shared_ptr[CInputStream]*) +ctypedef void CallbackOpenInputFile(object, const c_string&, + shared_ptr[CRandomAccessFile]*) +ctypedef void CallbackOpenOutputStream( + object, const c_string&, const shared_ptr[const CKeyValueMetadata]&, + shared_ptr[COutputStream]*) +ctypedef void CallbackNormalizePath(object, const c_string&, c_string*) + +cdef extern from "arrow/python/filesystem.h" namespace "arrow::py::fs" nogil: + + cdef cppclass CPyFileSystemVtable "arrow::py::fs::PyFileSystemVtable": + PyFileSystemVtable() + function[CallbackGetTypeName] get_type_name + function[CallbackEquals] equals + function[CallbackGetFileInfo] get_file_info + function[CallbackGetFileInfoVector] get_file_info_vector + function[CallbackGetFileInfoSelector] get_file_info_selector + function[CallbackCreateDir] create_dir + function[CallbackDeleteDir] delete_dir + function[CallbackDeleteDirContents] delete_dir_contents + function[CallbackDeleteRootDirContents] delete_root_dir_contents + function[CallbackDeleteFile] delete_file + function[CallbackMove] move + function[CallbackCopyFile] copy_file + function[CallbackOpenInputStream] open_input_stream + function[CallbackOpenInputFile] open_input_file + function[CallbackOpenOutputStream] open_output_stream + function[CallbackOpenOutputStream] open_append_stream + function[CallbackNormalizePath] normalize_path + + cdef cppclass CPyFileSystem "arrow::py::fs::PyFileSystem": + @staticmethod + shared_ptr[CPyFileSystem] Make(object handler, + CPyFileSystemVtable vtable) + + PyObject* handler() diff --git a/src/arrow/python/pyarrow/includes/libgandiva.pxd b/src/arrow/python/pyarrow/includes/libgandiva.pxd new file mode 100644 index 000000000..c75977d37 --- /dev/null +++ b/src/arrow/python/pyarrow/includes/libgandiva.pxd @@ -0,0 +1,286 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from libcpp.string cimport string as c_string +from libcpp.unordered_set cimport unordered_set as c_unordered_set +from libc.stdint cimport int64_t, int32_t, uint8_t, uintptr_t + +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * + +cdef extern from "gandiva/node.h" namespace "gandiva" nogil: + + cdef cppclass CNode" gandiva::Node": + c_string ToString() + shared_ptr[CDataType] return_type() + + cdef cppclass CExpression" gandiva::Expression": + c_string ToString() + shared_ptr[CNode] root() + shared_ptr[CField] result() + + ctypedef vector[shared_ptr[CNode]] CNodeVector" gandiva::NodeVector" + + ctypedef vector[shared_ptr[CExpression]] \ + CExpressionVector" gandiva::ExpressionVector" + +cdef extern from "gandiva/selection_vector.h" namespace "gandiva" nogil: + + cdef cppclass CSelectionVector" gandiva::SelectionVector": + + shared_ptr[CArray] ToArray() + + enum CSelectionVector_Mode" gandiva::SelectionVector::Mode": + CSelectionVector_Mode_NONE" gandiva::SelectionVector::Mode::MODE_NONE" + CSelectionVector_Mode_UINT16" \ + gandiva::SelectionVector::Mode::MODE_UINT16" + CSelectionVector_Mode_UINT32" \ + gandiva::SelectionVector::Mode::MODE_UINT32" + CSelectionVector_Mode_UINT64" \ + gandiva::SelectionVector::Mode::MODE_UINT64" + + cdef CStatus SelectionVector_MakeInt16\ + "gandiva::SelectionVector::MakeInt16"( + int64_t max_slots, CMemoryPool* pool, + shared_ptr[CSelectionVector]* selection_vector) + + cdef CStatus SelectionVector_MakeInt32\ + "gandiva::SelectionVector::MakeInt32"( + int64_t max_slots, CMemoryPool* pool, + shared_ptr[CSelectionVector]* selection_vector) + + cdef CStatus SelectionVector_MakeInt64\ + "gandiva::SelectionVector::MakeInt64"( + int64_t max_slots, CMemoryPool* pool, + shared_ptr[CSelectionVector]* selection_vector) + +cdef inline CSelectionVector_Mode _ensure_selection_mode(str name) except *: + uppercase = name.upper() + if uppercase == 'NONE': + return CSelectionVector_Mode_NONE + elif uppercase == 'UINT16': + return CSelectionVector_Mode_UINT16 + elif uppercase == 'UINT32': + return CSelectionVector_Mode_UINT32 + elif uppercase == 'UINT64': + return CSelectionVector_Mode_UINT64 + else: + raise ValueError('Invalid value for Selection Mode: {!r}'.format(name)) + +cdef inline str _selection_mode_name(CSelectionVector_Mode ctype): + if ctype == CSelectionVector_Mode_NONE: + return 'NONE' + elif ctype == CSelectionVector_Mode_UINT16: + return 'UINT16' + elif ctype == CSelectionVector_Mode_UINT32: + return 'UINT32' + elif ctype == CSelectionVector_Mode_UINT64: + return 'UINT64' + else: + raise RuntimeError('Unexpected CSelectionVector_Mode value') + +cdef extern from "gandiva/condition.h" namespace "gandiva" nogil: + + cdef cppclass CCondition" gandiva::Condition": + c_string ToString() + shared_ptr[CNode] root() + shared_ptr[CField] result() + +cdef extern from "gandiva/arrow.h" namespace "gandiva" nogil: + + ctypedef vector[shared_ptr[CArray]] CArrayVector" gandiva::ArrayVector" + + +cdef extern from "gandiva/tree_expr_builder.h" namespace "gandiva" nogil: + + cdef shared_ptr[CNode] TreeExprBuilder_MakeBoolLiteral \ + "gandiva::TreeExprBuilder::MakeLiteral"(c_bool value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt8Literal \ + "gandiva::TreeExprBuilder::MakeLiteral"(uint8_t value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt16Literal \ + "gandiva::TreeExprBuilder::MakeLiteral"(uint16_t value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt32Literal \ + "gandiva::TreeExprBuilder::MakeLiteral"(uint32_t value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt64Literal \ + "gandiva::TreeExprBuilder::MakeLiteral"(uint64_t value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInt8Literal \ + "gandiva::TreeExprBuilder::MakeLiteral"(int8_t value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInt16Literal \ + "gandiva::TreeExprBuilder::MakeLiteral"(int16_t value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInt32Literal \ + "gandiva::TreeExprBuilder::MakeLiteral"(int32_t value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInt64Literal \ + "gandiva::TreeExprBuilder::MakeLiteral"(int64_t value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeFloatLiteral \ + "gandiva::TreeExprBuilder::MakeLiteral"(float value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeDoubleLiteral \ + "gandiva::TreeExprBuilder::MakeLiteral"(double value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeStringLiteral \ + "gandiva::TreeExprBuilder::MakeStringLiteral"(const c_string& value) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeBinaryLiteral \ + "gandiva::TreeExprBuilder::MakeBinaryLiteral"(const c_string& value) + + cdef shared_ptr[CExpression] TreeExprBuilder_MakeExpression\ + "gandiva::TreeExprBuilder::MakeExpression"( + shared_ptr[CNode] root_node, shared_ptr[CField] result_field) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeFunction \ + "gandiva::TreeExprBuilder::MakeFunction"( + const c_string& name, const CNodeVector& children, + shared_ptr[CDataType] return_type) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeField \ + "gandiva::TreeExprBuilder::MakeField"(shared_ptr[CField] field) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeIf \ + "gandiva::TreeExprBuilder::MakeIf"( + shared_ptr[CNode] condition, shared_ptr[CNode] this_node, + shared_ptr[CNode] else_node, shared_ptr[CDataType] return_type) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeAnd \ + "gandiva::TreeExprBuilder::MakeAnd"(const CNodeVector& children) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeOr \ + "gandiva::TreeExprBuilder::MakeOr"(const CNodeVector& children) + + cdef shared_ptr[CCondition] TreeExprBuilder_MakeCondition \ + "gandiva::TreeExprBuilder::MakeCondition"( + shared_ptr[CNode] condition) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionInt32 \ + "gandiva::TreeExprBuilder::MakeInExpressionInt32"( + shared_ptr[CNode] node, const c_unordered_set[int32_t]& values) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionInt64 \ + "gandiva::TreeExprBuilder::MakeInExpressionInt64"( + shared_ptr[CNode] node, const c_unordered_set[int64_t]& values) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionTime32 \ + "gandiva::TreeExprBuilder::MakeInExpressionTime32"( + shared_ptr[CNode] node, const c_unordered_set[int32_t]& values) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionTime64 \ + "gandiva::TreeExprBuilder::MakeInExpressionTime64"( + shared_ptr[CNode] node, const c_unordered_set[int64_t]& values) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionDate32 \ + "gandiva::TreeExprBuilder::MakeInExpressionDate32"( + shared_ptr[CNode] node, const c_unordered_set[int32_t]& values) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionDate64 \ + "gandiva::TreeExprBuilder::MakeInExpressionDate64"( + shared_ptr[CNode] node, const c_unordered_set[int64_t]& values) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionTimeStamp \ + "gandiva::TreeExprBuilder::MakeInExpressionTimeStamp"( + shared_ptr[CNode] node, const c_unordered_set[int64_t]& values) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionString \ + "gandiva::TreeExprBuilder::MakeInExpressionString"( + shared_ptr[CNode] node, const c_unordered_set[c_string]& values) + + cdef shared_ptr[CNode] TreeExprBuilder_MakeInExpressionBinary \ + "gandiva::TreeExprBuilder::MakeInExpressionBinary"( + shared_ptr[CNode] node, const c_unordered_set[c_string]& values) + +cdef extern from "gandiva/projector.h" namespace "gandiva" nogil: + + cdef cppclass CProjector" gandiva::Projector": + + CStatus Evaluate( + const CRecordBatch& batch, CMemoryPool* pool, + const CArrayVector* output) + + CStatus Evaluate( + const CRecordBatch& batch, + const CSelectionVector* selection, + CMemoryPool* pool, + const CArrayVector* output) + + c_string DumpIR() + + cdef CStatus Projector_Make \ + "gandiva::Projector::Make"( + shared_ptr[CSchema] schema, const CExpressionVector& children, + shared_ptr[CProjector]* projector) + + cdef CStatus Projector_Make \ + "gandiva::Projector::Make"( + shared_ptr[CSchema] schema, const CExpressionVector& children, + CSelectionVector_Mode mode, + shared_ptr[CConfiguration] configuration, + shared_ptr[CProjector]* projector) + +cdef extern from "gandiva/filter.h" namespace "gandiva" nogil: + + cdef cppclass CFilter" gandiva::Filter": + + CStatus Evaluate( + const CRecordBatch& batch, + shared_ptr[CSelectionVector] out_selection) + + c_string DumpIR() + + cdef CStatus Filter_Make \ + "gandiva::Filter::Make"( + shared_ptr[CSchema] schema, shared_ptr[CCondition] condition, + shared_ptr[CFilter]* filter) + +cdef extern from "gandiva/function_signature.h" namespace "gandiva" nogil: + + cdef cppclass CFunctionSignature" gandiva::FunctionSignature": + + CFunctionSignature(const c_string& base_name, + vector[shared_ptr[CDataType]] param_types, + shared_ptr[CDataType] ret_type) + + shared_ptr[CDataType] ret_type() const + + const c_string& base_name() const + + vector[shared_ptr[CDataType]] param_types() const + + c_string ToString() const + +cdef extern from "gandiva/expression_registry.h" namespace "gandiva" nogil: + + cdef vector[shared_ptr[CFunctionSignature]] \ + GetRegisteredFunctionSignatures() + +cdef extern from "gandiva/configuration.h" namespace "gandiva" nogil: + + cdef cppclass CConfiguration" gandiva::Configuration": + pass + + cdef cppclass CConfigurationBuilder \ + " gandiva::ConfigurationBuilder": + @staticmethod + shared_ptr[CConfiguration] DefaultConfiguration() diff --git a/src/arrow/python/pyarrow/includes/libplasma.pxd b/src/arrow/python/pyarrow/includes/libplasma.pxd new file mode 100644 index 000000000..d54e9f484 --- /dev/null +++ b/src/arrow/python/pyarrow/includes/libplasma.pxd @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language = c++ + +from pyarrow.includes.common cimport * + +cdef extern from "plasma/common.h" namespace "plasma" nogil: + cdef c_bool IsPlasmaObjectExists(const CStatus& status) + cdef c_bool IsPlasmaObjectNotFound(const CStatus& status) + cdef c_bool IsPlasmaStoreFull(const CStatus& status) diff --git a/src/arrow/python/pyarrow/io.pxi b/src/arrow/python/pyarrow/io.pxi new file mode 100644 index 000000000..f6c2b4219 --- /dev/null +++ b/src/arrow/python/pyarrow/io.pxi @@ -0,0 +1,2137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Cython wrappers for IO interfaces defined in arrow::io and messaging in +# arrow::ipc + +from libc.stdlib cimport malloc, free + +import codecs +import re +import sys +import threading +import time +import warnings +from io import BufferedIOBase, IOBase, TextIOBase, UnsupportedOperation +from queue import Queue, Empty as QueueEmpty + +from pyarrow.util import _is_path_like, _stringify_path + + +# 64K +DEFAULT_BUFFER_SIZE = 2 ** 16 + + +# To let us get a PyObject* and avoid Cython auto-ref-counting +cdef extern from "Python.h": + PyObject* PyBytes_FromStringAndSizeNative" PyBytes_FromStringAndSize"( + char *v, Py_ssize_t len) except NULL + + +def io_thread_count(): + """ + Return the number of threads to use for I/O operations. + + Many operations, such as scanning a dataset, will implicitly make + use of this pool. The number of threads is set to a fixed value at + startup. It can be modified at runtime by calling + :func:`set_io_thread_count()`. + + See Also + -------- + set_io_thread_count : Modify the size of this pool. + cpu_count : The analogous function for the CPU thread pool. + """ + return GetIOThreadPoolCapacity() + + +def set_io_thread_count(int count): + """ + Set the number of threads to use for I/O operations. + + Many operations, such as scanning a dataset, will implicitly make + use of this pool. + + Parameters + ---------- + count : int + The max number of threads that may be used for I/O. + Must be positive. + + See Also + -------- + io_thread_count : Get the size of this pool. + set_cpu_count : The analogous function for the CPU thread pool. + """ + if count < 1: + raise ValueError("IO thread count must be strictly positive") + check_status(SetIOThreadPoolCapacity(count)) + + +cdef class NativeFile(_Weakrefable): + """ + The base class for all Arrow streams. + + Streams are either readable, writable, or both. + They optionally support seeking. + + While this class exposes methods to read or write data from Python, the + primary intent of using a Arrow stream is to pass it to other Arrow + facilities that will make use of it, such as Arrow IPC routines. + + Be aware that there are subtle differences with regular Python files, + e.g. destroying a writable Arrow stream without closing it explicitly + will not flush any pending data. + """ + + def __cinit__(self): + self.own_file = False + self.is_readable = False + self.is_writable = False + self.is_seekable = False + + def __dealloc__(self): + if self.own_file: + self.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + self.close() + + @property + def mode(self): + """ + The file mode. Currently instances of NativeFile may support: + + * rb: binary read + * wb: binary write + * rb+: binary read and write + """ + # Emulate built-in file modes + if self.is_readable and self.is_writable: + return 'rb+' + elif self.is_readable: + return 'rb' + elif self.is_writable: + return 'wb' + else: + raise ValueError('File object is malformed, has no mode') + + def readable(self): + self._assert_open() + return self.is_readable + + def writable(self): + self._assert_open() + return self.is_writable + + def seekable(self): + self._assert_open() + return self.is_seekable + + def isatty(self): + self._assert_open() + return False + + def fileno(self): + """ + NOT IMPLEMENTED + """ + raise UnsupportedOperation() + + @property + def closed(self): + if self.is_readable: + return self.input_stream.get().closed() + elif self.is_writable: + return self.output_stream.get().closed() + else: + return True + + def close(self): + if not self.closed: + with nogil: + if self.is_readable: + check_status(self.input_stream.get().Close()) + else: + check_status(self.output_stream.get().Close()) + + cdef set_random_access_file(self, shared_ptr[CRandomAccessFile] handle): + self.input_stream = <shared_ptr[CInputStream]> handle + self.random_access = handle + self.is_seekable = True + + cdef set_input_stream(self, shared_ptr[CInputStream] handle): + self.input_stream = handle + self.random_access.reset() + self.is_seekable = False + + cdef set_output_stream(self, shared_ptr[COutputStream] handle): + self.output_stream = handle + + cdef shared_ptr[CRandomAccessFile] get_random_access_file(self) except *: + self._assert_readable() + self._assert_seekable() + return self.random_access + + cdef shared_ptr[CInputStream] get_input_stream(self) except *: + self._assert_readable() + return self.input_stream + + cdef shared_ptr[COutputStream] get_output_stream(self) except *: + self._assert_writable() + return self.output_stream + + def _assert_open(self): + if self.closed: + raise ValueError("I/O operation on closed file") + + def _assert_readable(self): + self._assert_open() + if not self.is_readable: + # XXX UnsupportedOperation + raise IOError("only valid on readable files") + + def _assert_writable(self): + self._assert_open() + if not self.is_writable: + raise IOError("only valid on writable files") + + def _assert_seekable(self): + self._assert_open() + if not self.is_seekable: + raise IOError("only valid on seekable files") + + def size(self): + """ + Return file size + """ + cdef int64_t size + + handle = self.get_random_access_file() + with nogil: + size = GetResultValue(handle.get().GetSize()) + + return size + + def metadata(self): + """ + Return file metadata + """ + cdef: + shared_ptr[const CKeyValueMetadata] c_metadata + + handle = self.get_input_stream() + with nogil: + c_metadata = GetResultValue(handle.get().ReadMetadata()) + + metadata = {} + if c_metadata.get() != nullptr: + for i in range(c_metadata.get().size()): + metadata[frombytes(c_metadata.get().key(i))] = \ + c_metadata.get().value(i) + return metadata + + def tell(self): + """ + Return current stream position + """ + cdef int64_t position + + if self.is_readable: + rd_handle = self.get_random_access_file() + with nogil: + position = GetResultValue(rd_handle.get().Tell()) + else: + wr_handle = self.get_output_stream() + with nogil: + position = GetResultValue(wr_handle.get().Tell()) + + return position + + def seek(self, int64_t position, int whence=0): + """ + Change current file stream position + + Parameters + ---------- + position : int + Byte offset, interpreted relative to value of whence argument + whence : int, default 0 + Point of reference for seek offset + + Notes + ----- + Values of whence: + * 0 -- start of stream (the default); offset should be zero or positive + * 1 -- current stream position; offset may be negative + * 2 -- end of stream; offset is usually negative + + Returns + ------- + new_position : the new absolute stream position + """ + cdef int64_t offset + handle = self.get_random_access_file() + + with nogil: + if whence == 0: + offset = position + elif whence == 1: + offset = GetResultValue(handle.get().Tell()) + offset = offset + position + elif whence == 2: + offset = GetResultValue(handle.get().GetSize()) + offset = offset + position + else: + with gil: + raise ValueError("Invalid value of whence: {0}" + .format(whence)) + check_status(handle.get().Seek(offset)) + + return self.tell() + + def flush(self): + """ + Flush the stream, if applicable. + + An error is raised if stream is not writable. + """ + self._assert_open() + # For IOBase compatibility, flush() on an input stream is a no-op + if self.is_writable: + handle = self.get_output_stream() + with nogil: + check_status(handle.get().Flush()) + + def write(self, data): + """ + Write byte from any object implementing buffer protocol (bytes, + bytearray, ndarray, pyarrow.Buffer) + + Parameters + ---------- + data : bytes-like object or exporter of buffer protocol + + Returns + ------- + nbytes : number of bytes written + """ + self._assert_writable() + handle = self.get_output_stream() + + cdef shared_ptr[CBuffer] buf = as_c_buffer(data) + + with nogil: + check_status(handle.get().WriteBuffer(buf)) + return buf.get().size() + + def read(self, nbytes=None): + """ + Read indicated number of bytes from file, or read all remaining bytes + if no argument passed + + Parameters + ---------- + nbytes : int, default None + + Returns + ------- + data : bytes + """ + cdef: + int64_t c_nbytes + int64_t bytes_read = 0 + PyObject* obj + + if nbytes is None: + if not self.is_seekable: + # Cannot get file size => read chunkwise + bs = 16384 + chunks = [] + while True: + chunk = self.read(bs) + if not chunk: + break + chunks.append(chunk) + return b"".join(chunks) + + c_nbytes = self.size() - self.tell() + else: + c_nbytes = nbytes + + handle = self.get_input_stream() + + # Allocate empty write space + obj = PyBytes_FromStringAndSizeNative(NULL, c_nbytes) + + cdef uint8_t* buf = <uint8_t*> cp.PyBytes_AS_STRING(<object> obj) + with nogil: + bytes_read = GetResultValue(handle.get().Read(c_nbytes, buf)) + + if bytes_read < c_nbytes: + cp._PyBytes_Resize(&obj, <Py_ssize_t> bytes_read) + + return PyObject_to_object(obj) + + def read_at(self, nbytes, offset): + """ + Read indicated number of bytes at offset from the file + + Parameters + ---------- + nbytes : int + offset : int + + Returns + ------- + data : bytes + """ + cdef: + int64_t c_nbytes + int64_t c_offset + int64_t bytes_read = 0 + PyObject* obj + + c_nbytes = nbytes + + c_offset = offset + + handle = self.get_random_access_file() + + # Allocate empty write space + obj = PyBytes_FromStringAndSizeNative(NULL, c_nbytes) + + cdef uint8_t* buf = <uint8_t*> cp.PyBytes_AS_STRING(<object> obj) + with nogil: + bytes_read = GetResultValue(handle.get(). + ReadAt(c_offset, c_nbytes, buf)) + + if bytes_read < c_nbytes: + cp._PyBytes_Resize(&obj, <Py_ssize_t> bytes_read) + + return PyObject_to_object(obj) + + def read1(self, nbytes=None): + """Read and return up to n bytes. + + Alias for read, needed to match the IOBase interface.""" + return self.read(nbytes=None) + + def readall(self): + return self.read() + + def readinto(self, b): + """ + Read into the supplied buffer + + Parameters + ----------- + b: any python object supporting buffer interface + + Returns + -------- + number of bytes written + """ + + cdef: + int64_t bytes_read + uint8_t* buf + Buffer py_buf + int64_t buf_len + + handle = self.get_input_stream() + + py_buf = py_buffer(b) + buf_len = py_buf.size + buf = py_buf.buffer.get().mutable_data() + + with nogil: + bytes_read = GetResultValue(handle.get().Read(buf_len, buf)) + + return bytes_read + + def readline(self, size=None): + """NOT IMPLEMENTED. Read and return a line of bytes from the file. + + If size is specified, read at most size bytes. + + Line terminator is always b"\\n". + """ + + raise UnsupportedOperation() + + def readlines(self, hint=None): + """NOT IMPLEMENTED. Read lines of the file + + Parameters + ----------- + + hint: int maximum number of bytes read until we stop + """ + + raise UnsupportedOperation() + + def __iter__(self): + self._assert_readable() + return self + + def __next__(self): + line = self.readline() + if not line: + raise StopIteration + return line + + def read_buffer(self, nbytes=None): + cdef: + int64_t c_nbytes + int64_t bytes_read = 0 + shared_ptr[CBuffer] output + + handle = self.get_input_stream() + + if nbytes is None: + if not self.is_seekable: + # Cannot get file size => read chunkwise + return py_buffer(self.read()) + c_nbytes = self.size() - self.tell() + else: + c_nbytes = nbytes + + with nogil: + output = GetResultValue(handle.get().ReadBuffer(c_nbytes)) + + return pyarrow_wrap_buffer(output) + + def truncate(self): + """ + NOT IMPLEMENTED + """ + raise UnsupportedOperation() + + def writelines(self, lines): + self._assert_writable() + + for line in lines: + self.write(line) + + def download(self, stream_or_path, buffer_size=None): + """ + Read file completely to local path (rather than reading completely into + memory). First seeks to the beginning of the file. + """ + cdef: + int64_t bytes_read = 0 + uint8_t* buf + + handle = self.get_input_stream() + + buffer_size = buffer_size or DEFAULT_BUFFER_SIZE + + write_queue = Queue(50) + + if not hasattr(stream_or_path, 'read'): + stream = open(stream_or_path, 'wb') + + def cleanup(): + stream.close() + else: + stream = stream_or_path + + def cleanup(): + pass + + done = False + exc_info = None + + def bg_write(): + try: + while not done or write_queue.qsize() > 0: + try: + buf = write_queue.get(timeout=0.01) + except QueueEmpty: + continue + stream.write(buf) + except Exception as e: + exc_info = sys.exc_info() + finally: + cleanup() + + self.seek(0) + + writer_thread = threading.Thread(target=bg_write) + + # This isn't ideal -- PyBytes_FromStringAndSize copies the data from + # the passed buffer, so it's hard for us to avoid doubling the memory + buf = <uint8_t*> malloc(buffer_size) + if buf == NULL: + raise MemoryError("Failed to allocate {0} bytes" + .format(buffer_size)) + + writer_thread.start() + + cdef int64_t total_bytes = 0 + cdef int32_t c_buffer_size = buffer_size + + try: + while True: + with nogil: + bytes_read = GetResultValue( + handle.get().Read(c_buffer_size, buf)) + + total_bytes += bytes_read + + # EOF + if bytes_read == 0: + break + + pybuf = cp.PyBytes_FromStringAndSize(<const char*>buf, + bytes_read) + + if writer_thread.is_alive(): + while write_queue.full(): + time.sleep(0.01) + else: + break + + write_queue.put_nowait(pybuf) + finally: + free(buf) + done = True + + writer_thread.join() + if exc_info is not None: + raise exc_info[0], exc_info[1], exc_info[2] + + def upload(self, stream, buffer_size=None): + """ + Pipe file-like object to file + """ + write_queue = Queue(50) + self._assert_writable() + + buffer_size = buffer_size or DEFAULT_BUFFER_SIZE + + done = False + exc_info = None + + def bg_write(): + try: + while not done or write_queue.qsize() > 0: + try: + buf = write_queue.get(timeout=0.01) + except QueueEmpty: + continue + + self.write(buf) + + except Exception as e: + exc_info = sys.exc_info() + + writer_thread = threading.Thread(target=bg_write) + writer_thread.start() + + try: + while True: + buf = stream.read(buffer_size) + if not buf: + break + + if writer_thread.is_alive(): + while write_queue.full(): + time.sleep(0.01) + else: + break + + write_queue.put_nowait(buf) + finally: + done = True + + writer_thread.join() + if exc_info is not None: + raise exc_info[0], exc_info[1], exc_info[2] + +BufferedIOBase.register(NativeFile) + +# ---------------------------------------------------------------------- +# Python file-like objects + + +cdef class PythonFile(NativeFile): + """ + A stream backed by a Python file object. + + This class allows using Python file objects with arbitrary Arrow + functions, including functions written in another language than Python. + + As a downside, there is a non-zero redirection cost in translating + Arrow stream calls to Python method calls. Furthermore, Python's + Global Interpreter Lock may limit parallelism in some situations. + """ + cdef: + object handle + + def __cinit__(self, handle, mode=None): + self.handle = handle + + if mode is None: + try: + inferred_mode = handle.mode + except AttributeError: + # Not all file-like objects have a mode attribute + # (e.g. BytesIO) + try: + inferred_mode = 'w' if handle.writable() else 'r' + except AttributeError: + raise ValueError("could not infer open mode for file-like " + "object %r, please pass it explicitly" + % (handle,)) + else: + inferred_mode = mode + + if inferred_mode.startswith('w'): + kind = 'w' + elif inferred_mode.startswith('r'): + kind = 'r' + else: + raise ValueError('Invalid file mode: {0}'.format(mode)) + + # If mode was given, check it matches the given file + if mode is not None: + if isinstance(handle, IOBase): + # Python 3 IO object + if kind == 'r': + if not handle.readable(): + raise TypeError("readable file expected") + else: + if not handle.writable(): + raise TypeError("writable file expected") + # (other duck-typed file-like objects are possible) + + # If possible, check the file is a binary file + if isinstance(handle, TextIOBase): + raise TypeError("binary file expected, got text file") + + if kind == 'r': + self.set_random_access_file( + shared_ptr[CRandomAccessFile](new PyReadableFile(handle))) + self.is_readable = True + else: + self.set_output_stream( + shared_ptr[COutputStream](new PyOutputStream(handle))) + self.is_writable = True + + def truncate(self, pos=None): + self.handle.truncate(pos) + + def readline(self, size=None): + return self.handle.readline(size) + + def readlines(self, hint=None): + return self.handle.readlines(hint) + + +cdef class MemoryMappedFile(NativeFile): + """ + A stream that represents a memory-mapped file. + + Supports 'r', 'r+', 'w' modes. + """ + cdef: + shared_ptr[CMemoryMappedFile] handle + object path + + @staticmethod + def create(path, size): + """ + Create a MemoryMappedFile + + Parameters + ---------- + path : str + Where to create the file. + size : int + Size of the memory mapped file. + """ + cdef: + shared_ptr[CMemoryMappedFile] handle + c_string c_path = encode_file_path(path) + int64_t c_size = size + + with nogil: + handle = GetResultValue(CMemoryMappedFile.Create(c_path, c_size)) + + cdef MemoryMappedFile result = MemoryMappedFile() + result.path = path + result.is_readable = True + result.is_writable = True + result.set_output_stream(<shared_ptr[COutputStream]> handle) + result.set_random_access_file(<shared_ptr[CRandomAccessFile]> handle) + result.handle = handle + + return result + + def _open(self, path, mode='r'): + self.path = path + + cdef: + FileMode c_mode + shared_ptr[CMemoryMappedFile] handle + c_string c_path = encode_file_path(path) + + if mode in ('r', 'rb'): + c_mode = FileMode_READ + self.is_readable = True + elif mode in ('w', 'wb'): + c_mode = FileMode_WRITE + self.is_writable = True + elif mode in ('r+', 'r+b', 'rb+'): + c_mode = FileMode_READWRITE + self.is_readable = True + self.is_writable = True + else: + raise ValueError('Invalid file mode: {0}'.format(mode)) + + with nogil: + handle = GetResultValue(CMemoryMappedFile.Open(c_path, c_mode)) + + self.set_output_stream(<shared_ptr[COutputStream]> handle) + self.set_random_access_file(<shared_ptr[CRandomAccessFile]> handle) + self.handle = handle + + def resize(self, new_size): + """ + Resize the map and underlying file. + + Parameters + ---------- + new_size : new size in bytes + """ + check_status(self.handle.get().Resize(new_size)) + + def fileno(self): + self._assert_open() + return self.handle.get().file_descriptor() + + +def memory_map(path, mode='r'): + """ + Open memory map at file path. Size of the memory map cannot change. + + Parameters + ---------- + path : str + mode : {'r', 'r+', 'w'}, default 'r' + Whether the file is opened for reading ('r+'), writing ('w') + or both ('r+'). + + Returns + ------- + mmap : MemoryMappedFile + """ + _check_is_file(path) + + cdef MemoryMappedFile mmap = MemoryMappedFile() + mmap._open(path, mode) + return mmap + + +cdef _check_is_file(path): + if os.path.isdir(path): + raise IOError("Expected file path, but {0} is a directory" + .format(path)) + + +def create_memory_map(path, size): + """ + Create a file of the given size and memory-map it. + + Parameters + ---------- + path : str + The file path to create, on the local filesystem. + size : int + The file size to create. + + Returns + ------- + mmap : MemoryMappedFile + """ + return MemoryMappedFile.create(path, size) + + +cdef class OSFile(NativeFile): + """ + A stream backed by a regular file descriptor. + """ + cdef: + object path + + def __cinit__(self, path, mode='r', MemoryPool memory_pool=None): + _check_is_file(path) + self.path = path + + cdef: + FileMode c_mode + shared_ptr[Readable] handle + c_string c_path = encode_file_path(path) + + if mode in ('r', 'rb'): + self._open_readable(c_path, maybe_unbox_memory_pool(memory_pool)) + elif mode in ('w', 'wb'): + self._open_writable(c_path) + else: + raise ValueError('Invalid file mode: {0}'.format(mode)) + + cdef _open_readable(self, c_string path, CMemoryPool* pool): + cdef shared_ptr[ReadableFile] handle + + with nogil: + handle = GetResultValue(ReadableFile.Open(path, pool)) + + self.is_readable = True + self.set_random_access_file(<shared_ptr[CRandomAccessFile]> handle) + + cdef _open_writable(self, c_string path): + with nogil: + self.output_stream = GetResultValue(FileOutputStream.Open(path)) + self.is_writable = True + + def fileno(self): + self._assert_open() + return self.handle.file_descriptor() + + +cdef class FixedSizeBufferWriter(NativeFile): + """ + A stream writing to a Arrow buffer. + """ + + def __cinit__(self, Buffer buffer): + self.output_stream.reset(new CFixedSizeBufferWriter(buffer.buffer)) + self.is_writable = True + + def set_memcopy_threads(self, int num_threads): + cdef CFixedSizeBufferWriter* writer = \ + <CFixedSizeBufferWriter*> self.output_stream.get() + writer.set_memcopy_threads(num_threads) + + def set_memcopy_blocksize(self, int64_t blocksize): + cdef CFixedSizeBufferWriter* writer = \ + <CFixedSizeBufferWriter*> self.output_stream.get() + writer.set_memcopy_blocksize(blocksize) + + def set_memcopy_threshold(self, int64_t threshold): + cdef CFixedSizeBufferWriter* writer = \ + <CFixedSizeBufferWriter*> self.output_stream.get() + writer.set_memcopy_threshold(threshold) + + +# ---------------------------------------------------------------------- +# Arrow buffers + + +cdef class Buffer(_Weakrefable): + """ + The base class for all Arrow buffers. + + A buffer represents a contiguous memory area. Many buffers will own + their memory, though not all of them do. + """ + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call Buffer's constructor directly, use " + "`pyarrow.py_buffer` function instead.") + + cdef void init(self, const shared_ptr[CBuffer]& buffer): + self.buffer = buffer + self.shape[0] = self.size + self.strides[0] = <Py_ssize_t>(1) + + def __len__(self): + return self.size + + @property + def size(self): + """ + The buffer size in bytes. + """ + return self.buffer.get().size() + + @property + def address(self): + """ + The buffer's address, as an integer. + + The returned address may point to CPU or device memory. + Use `is_cpu()` to disambiguate. + """ + return self.buffer.get().address() + + def hex(self): + """ + Compute hexadecimal representation of the buffer. + + Returns + ------- + : bytes + """ + return self.buffer.get().ToHexString() + + @property + def is_mutable(self): + """ + Whether the buffer is mutable. + """ + return self.buffer.get().is_mutable() + + @property + def is_cpu(self): + """ + Whether the buffer is CPU-accessible. + """ + return self.buffer.get().is_cpu() + + @property + def parent(self): + cdef shared_ptr[CBuffer] parent_buf = self.buffer.get().parent() + + if parent_buf.get() == NULL: + return None + else: + return pyarrow_wrap_buffer(parent_buf) + + def __getitem__(self, key): + if PySlice_Check(key): + if (key.step or 1) != 1: + raise IndexError('only slices with step 1 supported') + return _normalize_slice(self, key) + + return self.getitem(_normalize_index(key, self.size)) + + cdef getitem(self, int64_t i): + return self.buffer.get().data()[i] + + def slice(self, offset=0, length=None): + """ + Slice this buffer. Memory is not copied. + + You can also use the Python slice notation ``buffer[start:stop]``. + + Parameters + ---------- + offset : int, default 0 + Offset from start of buffer to slice. + length : int, default None + Length of slice (default is until end of Buffer starting from + offset). + + Returns + ------- + sliced : Buffer + A logical view over this buffer. + """ + cdef shared_ptr[CBuffer] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + if length is None: + result = SliceBuffer(self.buffer, offset) + else: + result = SliceBuffer(self.buffer, offset, max(length, 0)) + + return pyarrow_wrap_buffer(result) + + def equals(self, Buffer other): + """ + Determine if two buffers contain exactly the same data. + + Parameters + ---------- + other : Buffer + + Returns + ------- + are_equal : True if buffer contents and size are equal + """ + cdef c_bool result = False + with nogil: + result = self.buffer.get().Equals(deref(other.buffer.get())) + return result + + def __eq__(self, other): + if isinstance(other, Buffer): + return self.equals(other) + else: + return self.equals(py_buffer(other)) + + def __reduce_ex__(self, protocol): + if protocol >= 5: + return py_buffer, (builtin_pickle.PickleBuffer(self),) + else: + return py_buffer, (self.to_pybytes(),) + + def to_pybytes(self): + """ + Return this buffer as a Python bytes object. Memory is copied. + """ + return cp.PyBytes_FromStringAndSize( + <const char*>self.buffer.get().data(), + self.buffer.get().size()) + + def __getbuffer__(self, cp.Py_buffer* buffer, int flags): + if self.buffer.get().is_mutable(): + buffer.readonly = 0 + else: + if flags & cp.PyBUF_WRITABLE: + raise BufferError("Writable buffer requested but Arrow " + "buffer was not mutable") + buffer.readonly = 1 + buffer.buf = <char *>self.buffer.get().data() + buffer.format = 'b' + buffer.internal = NULL + buffer.itemsize = 1 + buffer.len = self.size + buffer.ndim = 1 + buffer.obj = self + buffer.shape = self.shape + buffer.strides = self.strides + buffer.suboffsets = NULL + + def __getsegcount__(self, Py_ssize_t *len_out): + if len_out != NULL: + len_out[0] = <Py_ssize_t>self.size + return 1 + + def __getreadbuffer__(self, Py_ssize_t idx, void **p): + if idx != 0: + raise SystemError("accessing non-existent buffer segment") + if p != NULL: + p[0] = <void*> self.buffer.get().data() + return self.size + + def __getwritebuffer__(self, Py_ssize_t idx, void **p): + if not self.buffer.get().is_mutable(): + raise SystemError("trying to write an immutable buffer") + if idx != 0: + raise SystemError("accessing non-existent buffer segment") + if p != NULL: + p[0] = <void*> self.buffer.get().data() + return self.size + + +cdef class ResizableBuffer(Buffer): + """ + A base class for buffers that can be resized. + """ + + cdef void init_rz(self, const shared_ptr[CResizableBuffer]& buffer): + self.init(<shared_ptr[CBuffer]> buffer) + + def resize(self, int64_t new_size, shrink_to_fit=False): + """ + Resize buffer to indicated size. + + Parameters + ---------- + new_size : int + New size of buffer (padding may be added internally). + shrink_to_fit : bool, default False + If this is true, the buffer is shrunk when new_size is less + than the current size. + If this is false, the buffer is never shrunk. + """ + cdef c_bool c_shrink_to_fit = shrink_to_fit + with nogil: + check_status((<CResizableBuffer*> self.buffer.get()) + .Resize(new_size, c_shrink_to_fit)) + + +cdef shared_ptr[CResizableBuffer] _allocate_buffer(CMemoryPool* pool) except *: + with nogil: + return to_shared(GetResultValue(AllocateResizableBuffer(0, pool))) + + +def allocate_buffer(int64_t size, MemoryPool memory_pool=None, + resizable=False): + """ + Allocate a mutable buffer. + + Parameters + ---------- + size : int + Number of bytes to allocate (plus internal padding) + memory_pool : MemoryPool, optional + The pool to allocate memory from. + If not given, the default memory pool is used. + resizable : bool, default False + If true, the returned buffer is resizable. + + Returns + ------- + buffer : Buffer or ResizableBuffer + """ + cdef: + CMemoryPool* cpool = maybe_unbox_memory_pool(memory_pool) + shared_ptr[CResizableBuffer] c_rz_buffer + shared_ptr[CBuffer] c_buffer + + if resizable: + with nogil: + c_rz_buffer = to_shared(GetResultValue( + AllocateResizableBuffer(size, cpool))) + return pyarrow_wrap_resizable_buffer(c_rz_buffer) + else: + with nogil: + c_buffer = to_shared(GetResultValue(AllocateBuffer(size, cpool))) + return pyarrow_wrap_buffer(c_buffer) + + +cdef class BufferOutputStream(NativeFile): + + cdef: + shared_ptr[CResizableBuffer] buffer + + def __cinit__(self, MemoryPool memory_pool=None): + self.buffer = _allocate_buffer(maybe_unbox_memory_pool(memory_pool)) + self.output_stream.reset(new CBufferOutputStream( + <shared_ptr[CResizableBuffer]> self.buffer)) + self.is_writable = True + + def getvalue(self): + """ + Finalize output stream and return result as pyarrow.Buffer. + + Returns + ------- + value : Buffer + """ + with nogil: + check_status(self.output_stream.get().Close()) + return pyarrow_wrap_buffer(<shared_ptr[CBuffer]> self.buffer) + + +cdef class MockOutputStream(NativeFile): + + def __cinit__(self): + self.output_stream.reset(new CMockOutputStream()) + self.is_writable = True + + def size(self): + handle = <CMockOutputStream*> self.output_stream.get() + return handle.GetExtentBytesWritten() + + +cdef class BufferReader(NativeFile): + """ + Zero-copy reader from objects convertible to Arrow buffer. + + Parameters + ---------- + obj : Python bytes or pyarrow.Buffer + """ + cdef: + Buffer buffer + + def __cinit__(self, object obj): + self.buffer = as_buffer(obj) + self.set_random_access_file(shared_ptr[CRandomAccessFile]( + new CBufferReader(self.buffer.buffer))) + self.is_readable = True + + +cdef class CompressedInputStream(NativeFile): + """ + An input stream wrapper which decompresses data on the fly. + + Parameters + ---------- + stream : string, path, pa.NativeFile, or file-like object + Input stream object to wrap with the compression. + compression : str + The compression type ("bz2", "brotli", "gzip", "lz4" or "zstd"). + """ + + def __init__(self, object stream, str compression not None): + cdef: + NativeFile nf + Codec codec = Codec(compression) + shared_ptr[CInputStream] c_reader + shared_ptr[CCompressedInputStream] compressed_stream + nf = get_native_file(stream, False) + c_reader = nf.get_input_stream() + compressed_stream = GetResultValue( + CCompressedInputStream.Make(codec.unwrap(), c_reader) + ) + self.set_input_stream(<shared_ptr[CInputStream]> compressed_stream) + self.is_readable = True + + +cdef class CompressedOutputStream(NativeFile): + """ + An output stream wrapper which compresses data on the fly. + + Parameters + ---------- + stream : string, path, pa.NativeFile, or file-like object + Input stream object to wrap with the compression. + compression : str + The compression type ("bz2", "brotli", "gzip", "lz4" or "zstd"). + """ + + def __init__(self, object stream, str compression not None): + cdef: + Codec codec = Codec(compression) + shared_ptr[COutputStream] c_writer + shared_ptr[CCompressedOutputStream] compressed_stream + get_writer(stream, &c_writer) + compressed_stream = GetResultValue( + CCompressedOutputStream.Make(codec.unwrap(), c_writer) + ) + self.set_output_stream(<shared_ptr[COutputStream]> compressed_stream) + self.is_writable = True + + +ctypedef CBufferedInputStream* _CBufferedInputStreamPtr +ctypedef CBufferedOutputStream* _CBufferedOutputStreamPtr +ctypedef CRandomAccessFile* _RandomAccessFilePtr + + +cdef class BufferedInputStream(NativeFile): + """ + An input stream that performs buffered reads from + an unbuffered input stream, which can mitigate the overhead + of many small reads in some cases. + + Parameters + ---------- + stream : NativeFile + The input stream to wrap with the buffer + buffer_size : int + Size of the temporary read buffer. + memory_pool : MemoryPool + The memory pool used to allocate the buffer. + """ + + def __init__(self, NativeFile stream, int buffer_size, + MemoryPool memory_pool=None): + cdef shared_ptr[CBufferedInputStream] buffered_stream + + if buffer_size <= 0: + raise ValueError('Buffer size must be larger than zero') + buffered_stream = GetResultValue(CBufferedInputStream.Create( + buffer_size, maybe_unbox_memory_pool(memory_pool), + stream.get_input_stream())) + + self.set_input_stream(<shared_ptr[CInputStream]> buffered_stream) + self.is_readable = True + + def detach(self): + """ + Release the raw InputStream. + Further operations on this stream are invalid. + + Returns + ------- + raw : NativeFile + The underlying raw input stream + """ + cdef: + shared_ptr[CInputStream] c_raw + _CBufferedInputStreamPtr buffered + NativeFile raw + + buffered = dynamic_cast[_CBufferedInputStreamPtr]( + self.input_stream.get()) + assert buffered != nullptr + + with nogil: + c_raw = GetResultValue(buffered.Detach()) + + raw = NativeFile() + raw.is_readable = True + # Find out whether the raw stream is a RandomAccessFile + # or a mere InputStream. This helps us support seek() etc. + # selectively. + if dynamic_cast[_RandomAccessFilePtr](c_raw.get()) != nullptr: + raw.set_random_access_file( + static_pointer_cast[CRandomAccessFile, CInputStream](c_raw)) + else: + raw.set_input_stream(c_raw) + return raw + + +cdef class BufferedOutputStream(NativeFile): + """ + An output stream that performs buffered reads from + an unbuffered output stream, which can mitigate the overhead + of many small writes in some cases. + + Parameters + ---------- + stream : NativeFile + The writable output stream to wrap with the buffer + buffer_size : int + Size of the buffer that should be added. + memory_pool : MemoryPool + The memory pool used to allocate the buffer. + """ + + def __init__(self, NativeFile stream, int buffer_size, + MemoryPool memory_pool=None): + cdef shared_ptr[CBufferedOutputStream] buffered_stream + + if buffer_size <= 0: + raise ValueError('Buffer size must be larger than zero') + buffered_stream = GetResultValue(CBufferedOutputStream.Create( + buffer_size, maybe_unbox_memory_pool(memory_pool), + stream.get_output_stream())) + + self.set_output_stream(<shared_ptr[COutputStream]> buffered_stream) + self.is_writable = True + + def detach(self): + """ + Flush any buffered writes and release the raw OutputStream. + Further operations on this stream are invalid. + + Returns + ------- + raw : NativeFile + The underlying raw output stream. + """ + cdef: + shared_ptr[COutputStream] c_raw + _CBufferedOutputStreamPtr buffered + NativeFile raw + + buffered = dynamic_cast[_CBufferedOutputStreamPtr]( + self.output_stream.get()) + assert buffered != nullptr + + with nogil: + c_raw = GetResultValue(buffered.Detach()) + + raw = NativeFile() + raw.is_writable = True + raw.set_output_stream(c_raw) + return raw + + +cdef void _cb_transform(transform_func, const shared_ptr[CBuffer]& src, + shared_ptr[CBuffer]* dest) except *: + py_dest = transform_func(pyarrow_wrap_buffer(src)) + dest[0] = pyarrow_unwrap_buffer(py_buffer(py_dest)) + + +cdef class TransformInputStream(NativeFile): + """ + Transform an input stream. + + Parameters + ---------- + stream : NativeFile + The stream to transform. + transform_func : callable + The transformation to apply. + """ + + def __init__(self, NativeFile stream, transform_func): + self.set_input_stream(TransformInputStream.make_native( + stream.get_input_stream(), transform_func)) + self.is_readable = True + + @staticmethod + cdef shared_ptr[CInputStream] make_native( + shared_ptr[CInputStream] stream, transform_func) except *: + cdef: + shared_ptr[CInputStream] transform_stream + CTransformInputStreamVTable vtable + + vtable.transform = _cb_transform + return MakeTransformInputStream(stream, move(vtable), + transform_func) + + +class Transcoder: + + def __init__(self, decoder, encoder): + self._decoder = decoder + self._encoder = encoder + + def __call__(self, buf): + final = len(buf) == 0 + return self._encoder.encode(self._decoder.decode(buf, final), final) + + +def transcoding_input_stream(stream, src_encoding, dest_encoding): + """ + Add a transcoding transformation to the stream. + Incoming data will be decoded according to ``src_encoding`` and + then re-encoded according to ``dest_encoding``. + + Parameters + ---------- + stream : NativeFile + The stream to which the transformation should be applied. + src_encoding : str + The codec to use when reading data data. + dest_encoding : str + The codec to use for emitted data. + """ + src_codec = codecs.lookup(src_encoding) + dest_codec = codecs.lookup(dest_encoding) + if src_codec.name == dest_codec.name: + # Avoid losing performance on no-op transcoding + # (encoding errors won't be detected) + return stream + return TransformInputStream(stream, + Transcoder(src_codec.incrementaldecoder(), + dest_codec.incrementalencoder())) + + +cdef shared_ptr[CInputStream] native_transcoding_input_stream( + shared_ptr[CInputStream] stream, src_encoding, + dest_encoding) except *: + src_codec = codecs.lookup(src_encoding) + dest_codec = codecs.lookup(dest_encoding) + if src_codec.name == dest_codec.name: + # Avoid losing performance on no-op transcoding + # (encoding errors won't be detected) + return stream + return TransformInputStream.make_native( + stream, Transcoder(src_codec.incrementaldecoder(), + dest_codec.incrementalencoder())) + + +def py_buffer(object obj): + """ + Construct an Arrow buffer from a Python bytes-like or buffer-like object + + Parameters + ---------- + obj : object + the object from which the buffer should be constructed. + """ + cdef shared_ptr[CBuffer] buf + buf = GetResultValue(PyBuffer.FromPyObject(obj)) + return pyarrow_wrap_buffer(buf) + + +def foreign_buffer(address, size, base=None): + """ + Construct an Arrow buffer with the given *address* and *size*. + + The buffer will be optionally backed by the Python *base* object, if given. + The *base* object will be kept alive as long as this buffer is alive, + including across language boundaries (for example if the buffer is + referenced by C++ code). + + Parameters + ---------- + address : int + The starting address of the buffer. The address can + refer to both device or host memory but it must be + accessible from device after mapping it with + `get_device_address` method. + size : int + The size of device buffer in bytes. + base : {None, object} + Object that owns the referenced memory. + """ + cdef: + intptr_t c_addr = address + int64_t c_size = size + shared_ptr[CBuffer] buf + + check_status(PyForeignBuffer.Make(<uint8_t*> c_addr, c_size, + base, &buf)) + return pyarrow_wrap_buffer(buf) + + +def as_buffer(object o): + if isinstance(o, Buffer): + return o + return py_buffer(o) + + +cdef shared_ptr[CBuffer] as_c_buffer(object o) except *: + cdef shared_ptr[CBuffer] buf + if isinstance(o, Buffer): + buf = (<Buffer> o).buffer + if buf == nullptr: + raise ValueError("got null buffer") + else: + buf = GetResultValue(PyBuffer.FromPyObject(o)) + return buf + + +cdef NativeFile get_native_file(object source, c_bool use_memory_map): + try: + source_path = _stringify_path(source) + except TypeError: + if isinstance(source, Buffer): + source = BufferReader(source) + elif not isinstance(source, NativeFile) and hasattr(source, 'read'): + # Optimistically hope this is file-like + source = PythonFile(source, mode='r') + else: + if use_memory_map: + source = memory_map(source_path, mode='r') + else: + source = OSFile(source_path, mode='r') + + return source + + +cdef get_reader(object source, c_bool use_memory_map, + shared_ptr[CRandomAccessFile]* reader): + cdef NativeFile nf + + nf = get_native_file(source, use_memory_map) + reader[0] = nf.get_random_access_file() + + +cdef get_input_stream(object source, c_bool use_memory_map, + shared_ptr[CInputStream]* out): + """ + Like get_reader(), but can automatically decompress, and returns + an InputStream. + """ + cdef: + NativeFile nf + Codec codec + shared_ptr[CInputStream] input_stream + + try: + codec = Codec.detect(source) + except TypeError: + codec = None + + nf = get_native_file(source, use_memory_map) + input_stream = nf.get_input_stream() + + # codec is None if compression can't be detected + if codec is not None: + input_stream = <shared_ptr[CInputStream]> GetResultValue( + CCompressedInputStream.Make(codec.unwrap(), input_stream) + ) + + out[0] = input_stream + + +cdef get_writer(object source, shared_ptr[COutputStream]* writer): + cdef NativeFile nf + + try: + source_path = _stringify_path(source) + except TypeError: + if not isinstance(source, NativeFile) and hasattr(source, 'write'): + # Optimistically hope this is file-like + source = PythonFile(source, mode='w') + else: + source = OSFile(source_path, mode='w') + + if isinstance(source, NativeFile): + nf = source + writer[0] = nf.get_output_stream() + else: + raise TypeError('Unable to read from object of type: {0}' + .format(type(source))) + + +# --------------------------------------------------------------------- + + +def _detect_compression(path): + if isinstance(path, str): + if path.endswith('.bz2'): + return 'bz2' + elif path.endswith('.gz'): + return 'gzip' + elif path.endswith('.lz4'): + return 'lz4' + elif path.endswith('.zst'): + return 'zstd' + + +cdef CCompressionType _ensure_compression(str name) except *: + uppercase = name.upper() + if uppercase == 'BZ2': + return CCompressionType_BZ2 + elif uppercase == 'GZIP': + return CCompressionType_GZIP + elif uppercase == 'BROTLI': + return CCompressionType_BROTLI + elif uppercase == 'LZ4' or uppercase == 'LZ4_FRAME': + return CCompressionType_LZ4_FRAME + elif uppercase == 'LZ4_RAW': + return CCompressionType_LZ4 + elif uppercase == 'SNAPPY': + return CCompressionType_SNAPPY + elif uppercase == 'ZSTD': + return CCompressionType_ZSTD + else: + raise ValueError('Invalid value for compression: {!r}'.format(name)) + + +cdef class Codec(_Weakrefable): + """ + Compression codec. + + Parameters + ---------- + compression : str + Type of compression codec to initialize, valid values are: 'gzip', + 'bz2', 'brotli', 'lz4' (or 'lz4_frame'), 'lz4_raw', 'zstd' and + 'snappy'. + compression_level : int, None + Optional parameter specifying how aggressively to compress. The + possible ranges and effect of this parameter depend on the specific + codec chosen. Higher values compress more but typically use more + resources (CPU/RAM). Some codecs support negative values. + + gzip + The compression_level maps to the memlevel parameter of + deflateInit2. Higher levels use more RAM but are faster + and should have higher compression ratios. + + bz2 + The compression level maps to the blockSize100k parameter of + the BZ2_bzCompressInit function. Higher levels use more RAM + but are faster and should have higher compression ratios. + + brotli + The compression level maps to the BROTLI_PARAM_QUALITY + parameter. Higher values are slower and should have higher + compression ratios. + + lz4/lz4_frame/lz4_raw + The compression level parameter is not supported and must + be None + + zstd + The compression level maps to the compressionLevel parameter + of ZSTD_initCStream. Negative values are supported. Higher + values are slower and should have higher compression ratios. + + snappy + The compression level parameter is not supported and must + be None + + + Raises + ------ + ValueError + If invalid compression value is passed. + """ + + def __init__(self, str compression not None, compression_level=None): + cdef CCompressionType typ = _ensure_compression(compression) + if compression_level is not None: + self.wrapped = shared_ptr[CCodec](move(GetResultValue( + CCodec.CreateWithLevel(typ, compression_level)))) + else: + self.wrapped = shared_ptr[CCodec](move(GetResultValue( + CCodec.Create(typ)))) + + cdef inline CCodec* unwrap(self) nogil: + return self.wrapped.get() + + @staticmethod + def detect(path): + """ + Detect and instantiate compression codec based on file extension. + + Parameters + ---------- + path : str, path-like + File-path to detect compression from. + + Raises + ------ + TypeError + If the passed value is not path-like. + ValueError + If the compression can't be detected from the path. + + Returns + ------- + Codec + """ + return Codec(_detect_compression(_stringify_path(path))) + + @staticmethod + def is_available(str compression not None): + """ + Returns whether the compression support has been built and enabled. + + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + + Returns + ------- + bool + """ + cdef CCompressionType typ = _ensure_compression(compression) + return CCodec.IsAvailable(typ) + + @staticmethod + def supports_compression_level(str compression not None): + """ + Returns true if the compression level parameter is supported + for the given codec. + + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + """ + cdef CCompressionType typ = _ensure_compression(compression) + return CCodec.SupportsCompressionLevel(typ) + + @staticmethod + def default_compression_level(str compression not None): + """ + Returns the compression level that Arrow will use for the codec if + None is specified. + + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + """ + cdef CCompressionType typ = _ensure_compression(compression) + return GetResultValue(CCodec.DefaultCompressionLevel(typ)) + + @staticmethod + def minimum_compression_level(str compression not None): + """ + Returns the smallest valid value for the compression level + + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + """ + cdef CCompressionType typ = _ensure_compression(compression) + return GetResultValue(CCodec.MinimumCompressionLevel(typ)) + + @staticmethod + def maximum_compression_level(str compression not None): + """ + Returns the largest valid value for the compression level + + Parameters + ---------- + compression : str + Type of compression codec, + refer to Codec docstring for a list of supported ones. + """ + cdef CCompressionType typ = _ensure_compression(compression) + return GetResultValue(CCodec.MaximumCompressionLevel(typ)) + + @property + def name(self): + """Returns the name of the codec""" + return frombytes(self.unwrap().name()) + + @property + def compression_level(self): + """Returns the compression level parameter of the codec""" + return frombytes(self.unwrap().compression_level()) + + def compress(self, object buf, asbytes=False, memory_pool=None): + """ + Compress data from buffer-like object. + + Parameters + ---------- + buf : pyarrow.Buffer, bytes, or other object supporting buffer protocol + asbytes : bool, default False + Return result as Python bytes object, otherwise Buffer + memory_pool : MemoryPool, default None + Memory pool to use for buffer allocations, if any + + Returns + ------- + compressed : pyarrow.Buffer or bytes (if asbytes=True) + """ + cdef: + shared_ptr[CBuffer] owned_buf + CBuffer* c_buf + PyObject* pyobj + ResizableBuffer out_buf + int64_t max_output_size + int64_t output_length + uint8_t* output_buffer = NULL + + owned_buf = as_c_buffer(buf) + c_buf = owned_buf.get() + + max_output_size = self.wrapped.get().MaxCompressedLen( + c_buf.size(), c_buf.data() + ) + + if asbytes: + pyobj = PyBytes_FromStringAndSizeNative(NULL, max_output_size) + output_buffer = <uint8_t*> cp.PyBytes_AS_STRING(<object> pyobj) + else: + out_buf = allocate_buffer( + max_output_size, memory_pool=memory_pool, resizable=True + ) + output_buffer = out_buf.buffer.get().mutable_data() + + with nogil: + output_length = GetResultValue( + self.unwrap().Compress( + c_buf.size(), + c_buf.data(), + max_output_size, + output_buffer + ) + ) + + if asbytes: + cp._PyBytes_Resize(&pyobj, <Py_ssize_t> output_length) + return PyObject_to_object(pyobj) + else: + out_buf.resize(output_length) + return out_buf + + def decompress(self, object buf, decompressed_size=None, asbytes=False, + memory_pool=None): + """ + Decompress data from buffer-like object. + + Parameters + ---------- + buf : pyarrow.Buffer, bytes, or memoryview-compatible object + decompressed_size : int64_t, default None + If not specified, will be computed if the codec is able to + determine the uncompressed buffer size. + asbytes : boolean, default False + Return result as Python bytes object, otherwise Buffer + memory_pool : MemoryPool, default None + Memory pool to use for buffer allocations, if any. + + Returns + ------- + uncompressed : pyarrow.Buffer or bytes (if asbytes=True) + """ + cdef: + shared_ptr[CBuffer] owned_buf + CBuffer* c_buf + Buffer out_buf + int64_t output_size + uint8_t* output_buffer = NULL + + owned_buf = as_c_buffer(buf) + c_buf = owned_buf.get() + + if decompressed_size is None: + raise ValueError( + "Must pass decompressed_size for {} codec".format(self) + ) + + output_size = decompressed_size + + if asbytes: + pybuf = cp.PyBytes_FromStringAndSize(NULL, output_size) + output_buffer = <uint8_t*> cp.PyBytes_AS_STRING(pybuf) + else: + out_buf = allocate_buffer(output_size, memory_pool=memory_pool) + output_buffer = out_buf.buffer.get().mutable_data() + + with nogil: + GetResultValue( + self.unwrap().Decompress( + c_buf.size(), + c_buf.data(), + output_size, + output_buffer + ) + ) + + return pybuf if asbytes else out_buf + + +def compress(object buf, codec='lz4', asbytes=False, memory_pool=None): + """ + Compress data from buffer-like object. + + Parameters + ---------- + buf : pyarrow.Buffer, bytes, or other object supporting buffer protocol + codec : str, default 'lz4' + Compression codec. + Supported types: {'brotli, 'gzip', 'lz4', 'lz4_raw', 'snappy', 'zstd'} + asbytes : bool, default False + Return result as Python bytes object, otherwise Buffer. + memory_pool : MemoryPool, default None + Memory pool to use for buffer allocations, if any. + + Returns + ------- + compressed : pyarrow.Buffer or bytes (if asbytes=True) + """ + cdef Codec coder = Codec(codec) + return coder.compress(buf, asbytes=asbytes, memory_pool=memory_pool) + + +def decompress(object buf, decompressed_size=None, codec='lz4', + asbytes=False, memory_pool=None): + """ + Decompress data from buffer-like object. + + Parameters + ---------- + buf : pyarrow.Buffer, bytes, or memoryview-compatible object + Input object to decompress data from. + decompressed_size : int64_t, default None + If not specified, will be computed if the codec is able to determine + the uncompressed buffer size. + codec : str, default 'lz4' + Compression codec. + Supported types: {'brotli, 'gzip', 'lz4', 'lz4_raw', 'snappy', 'zstd'} + asbytes : bool, default False + Return result as Python bytes object, otherwise Buffer. + memory_pool : MemoryPool, default None + Memory pool to use for buffer allocations, if any. + + Returns + ------- + uncompressed : pyarrow.Buffer or bytes (if asbytes=True) + """ + cdef Codec decoder = Codec(codec) + return decoder.decompress(buf, asbytes=asbytes, memory_pool=memory_pool, + decompressed_size=decompressed_size) + + +def input_stream(source, compression='detect', buffer_size=None): + """ + Create an Arrow input stream. + + Parameters + ---------- + source : str, Path, buffer, file-like object, ... + The source to open for reading. + compression : str optional, default 'detect' + The compression algorithm to use for on-the-fly decompression. + If "detect" and source is a file path, then compression will be + chosen based on the file extension. + If None, no compression will be applied. + Otherwise, a well-known algorithm name must be supplied (e.g. "gzip"). + buffer_size : int, default None + If None or 0, no buffering will happen. Otherwise the size of the + temporary read buffer. + """ + cdef NativeFile stream + + try: + source_path = _stringify_path(source) + except TypeError: + source_path = None + + if isinstance(source, NativeFile): + stream = source + elif source_path is not None: + stream = OSFile(source_path, 'r') + elif isinstance(source, (Buffer, memoryview)): + stream = BufferReader(as_buffer(source)) + elif (hasattr(source, 'read') and + hasattr(source, 'close') and + hasattr(source, 'closed')): + stream = PythonFile(source, 'r') + else: + raise TypeError("pa.input_stream() called with instance of '{}'" + .format(source.__class__)) + + if compression == 'detect': + # detect for OSFile too + compression = _detect_compression(source_path) + + if buffer_size is not None and buffer_size != 0: + stream = BufferedInputStream(stream, buffer_size) + + if compression is not None: + stream = CompressedInputStream(stream, compression) + + return stream + + +def output_stream(source, compression='detect', buffer_size=None): + """ + Create an Arrow output stream. + + Parameters + ---------- + source : str, Path, buffer, file-like object, ... + The source to open for writing. + compression : str optional, default 'detect' + The compression algorithm to use for on-the-fly compression. + If "detect" and source is a file path, then compression will be + chosen based on the file extension. + If None, no compression will be applied. + Otherwise, a well-known algorithm name must be supplied (e.g. "gzip"). + buffer_size : int, default None + If None or 0, no buffering will happen. Otherwise the size of the + temporary write buffer. + """ + cdef NativeFile stream + + try: + source_path = _stringify_path(source) + except TypeError: + source_path = None + + if isinstance(source, NativeFile): + stream = source + elif source_path is not None: + stream = OSFile(source_path, 'w') + elif isinstance(source, (Buffer, memoryview)): + stream = FixedSizeBufferWriter(as_buffer(source)) + elif (hasattr(source, 'write') and + hasattr(source, 'close') and + hasattr(source, 'closed')): + stream = PythonFile(source, 'w') + else: + raise TypeError("pa.output_stream() called with instance of '{}'" + .format(source.__class__)) + + if compression == 'detect': + compression = _detect_compression(source_path) + + if buffer_size is not None and buffer_size != 0: + stream = BufferedOutputStream(stream, buffer_size) + + if compression is not None: + stream = CompressedOutputStream(stream, compression) + + return stream diff --git a/src/arrow/python/pyarrow/ipc.pxi b/src/arrow/python/pyarrow/ipc.pxi new file mode 100644 index 000000000..9304bbb97 --- /dev/null +++ b/src/arrow/python/pyarrow/ipc.pxi @@ -0,0 +1,1009 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import namedtuple +import warnings + + +cpdef enum MetadataVersion: + V1 = <char> CMetadataVersion_V1 + V2 = <char> CMetadataVersion_V2 + V3 = <char> CMetadataVersion_V3 + V4 = <char> CMetadataVersion_V4 + V5 = <char> CMetadataVersion_V5 + + +cdef object _wrap_metadata_version(CMetadataVersion version): + return MetadataVersion(<char> version) + + +cdef CMetadataVersion _unwrap_metadata_version( + MetadataVersion version) except *: + if version == MetadataVersion.V1: + return CMetadataVersion_V1 + elif version == MetadataVersion.V2: + return CMetadataVersion_V2 + elif version == MetadataVersion.V3: + return CMetadataVersion_V3 + elif version == MetadataVersion.V4: + return CMetadataVersion_V4 + elif version == MetadataVersion.V5: + return CMetadataVersion_V5 + raise ValueError("Not a metadata version: " + repr(version)) + + +_WriteStats = namedtuple( + 'WriteStats', + ('num_messages', 'num_record_batches', 'num_dictionary_batches', + 'num_dictionary_deltas', 'num_replaced_dictionaries')) + + +class WriteStats(_WriteStats): + """IPC write statistics + + Parameters + ---------- + num_messages : number of messages. + num_record_batches : number of record batches. + num_dictionary_batches : number of dictionary batches. + num_dictionary_deltas : delta of dictionaries. + num_replaced_dictionaries : number of replaced dictionaries. + """ + __slots__ = () + + +@staticmethod +cdef _wrap_write_stats(CIpcWriteStats c): + return WriteStats(c.num_messages, c.num_record_batches, + c.num_dictionary_batches, c.num_dictionary_deltas, + c.num_replaced_dictionaries) + + +_ReadStats = namedtuple( + 'ReadStats', + ('num_messages', 'num_record_batches', 'num_dictionary_batches', + 'num_dictionary_deltas', 'num_replaced_dictionaries')) + + +class ReadStats(_ReadStats): + """IPC read statistics + + Parameters + ---------- + num_messages : number of messages. + num_record_batches : number of record batches. + num_dictionary_batches : number of dictionary batches. + num_dictionary_deltas : delta of dictionaries. + num_replaced_dictionaries : number of replaced dictionaries. + """ + __slots__ = () + + +@staticmethod +cdef _wrap_read_stats(CIpcReadStats c): + return ReadStats(c.num_messages, c.num_record_batches, + c.num_dictionary_batches, c.num_dictionary_deltas, + c.num_replaced_dictionaries) + + +cdef class IpcWriteOptions(_Weakrefable): + """ + Serialization options for the IPC format. + + Parameters + ---------- + metadata_version : MetadataVersion, default MetadataVersion.V5 + The metadata version to write. V5 is the current and latest, + V4 is the pre-1.0 metadata version (with incompatible Union layout). + allow_64bit : bool, default False + If true, allow field lengths that don't fit in a signed 32-bit int. + use_legacy_format : bool, default False + Whether to use the pre-Arrow 0.15 IPC format. + compression : str, Codec, or None + compression codec to use for record batch buffers. + If None then batch buffers will be uncompressed. + Must be "lz4", "zstd" or None. + To specify a compression_level use `pyarrow.Codec` + use_threads : bool + Whether to use the global CPU thread pool to parallelize any + computational tasks like compression. + emit_dictionary_deltas : bool + Whether to emit dictionary deltas. Default is false for maximum + stream compatibility. + """ + __slots__ = () + + # cdef block is in lib.pxd + + def __init__(self, *, metadata_version=MetadataVersion.V5, + bint allow_64bit=False, use_legacy_format=False, + compression=None, bint use_threads=True, + bint emit_dictionary_deltas=False): + self.c_options = CIpcWriteOptions.Defaults() + self.allow_64bit = allow_64bit + self.use_legacy_format = use_legacy_format + self.metadata_version = metadata_version + if compression is not None: + self.compression = compression + self.use_threads = use_threads + self.emit_dictionary_deltas = emit_dictionary_deltas + + @property + def allow_64bit(self): + return self.c_options.allow_64bit + + @allow_64bit.setter + def allow_64bit(self, bint value): + self.c_options.allow_64bit = value + + @property + def use_legacy_format(self): + return self.c_options.write_legacy_ipc_format + + @use_legacy_format.setter + def use_legacy_format(self, bint value): + self.c_options.write_legacy_ipc_format = value + + @property + def metadata_version(self): + return _wrap_metadata_version(self.c_options.metadata_version) + + @metadata_version.setter + def metadata_version(self, value): + self.c_options.metadata_version = _unwrap_metadata_version(value) + + @property + def compression(self): + if self.c_options.codec == nullptr: + return None + else: + return frombytes(self.c_options.codec.get().name()) + + @compression.setter + def compression(self, value): + if value is None: + self.c_options.codec.reset() + elif isinstance(value, str): + self.c_options.codec = shared_ptr[CCodec](GetResultValue( + CCodec.Create(_ensure_compression(value))).release()) + elif isinstance(value, Codec): + self.c_options.codec = (<Codec>value).wrapped + else: + raise TypeError( + "Property `compression` must be None, str, or pyarrow.Codec") + + @property + def use_threads(self): + return self.c_options.use_threads + + @use_threads.setter + def use_threads(self, bint value): + self.c_options.use_threads = value + + @property + def emit_dictionary_deltas(self): + return self.c_options.emit_dictionary_deltas + + @emit_dictionary_deltas.setter + def emit_dictionary_deltas(self, bint value): + self.c_options.emit_dictionary_deltas = value + + +cdef class Message(_Weakrefable): + """ + Container for an Arrow IPC message with metadata and optional body + """ + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "`pyarrow.ipc.read_message` function instead." + .format(self.__class__.__name__)) + + @property + def type(self): + return frombytes(FormatMessageType(self.message.get().type())) + + @property + def metadata(self): + return pyarrow_wrap_buffer(self.message.get().metadata()) + + @property + def metadata_version(self): + return _wrap_metadata_version(self.message.get().metadata_version()) + + @property + def body(self): + cdef shared_ptr[CBuffer] body = self.message.get().body() + if body.get() == NULL: + return None + else: + return pyarrow_wrap_buffer(body) + + def equals(self, Message other): + """ + Returns True if the message contents (metadata and body) are identical + + Parameters + ---------- + other : Message + + Returns + ------- + are_equal : bool + """ + cdef c_bool result + with nogil: + result = self.message.get().Equals(deref(other.message.get())) + return result + + def serialize_to(self, NativeFile sink, alignment=8, memory_pool=None): + """ + Write message to generic OutputStream + + Parameters + ---------- + sink : NativeFile + alignment : int, default 8 + Byte alignment for metadata and body + memory_pool : MemoryPool, default None + Uses default memory pool if not specified + """ + cdef: + int64_t output_length = 0 + COutputStream* out + CIpcWriteOptions options + + options.alignment = alignment + out = sink.get_output_stream().get() + with nogil: + check_status(self.message.get() + .SerializeTo(out, options, &output_length)) + + def serialize(self, alignment=8, memory_pool=None): + """ + Write message as encapsulated IPC message + + Parameters + ---------- + alignment : int, default 8 + Byte alignment for metadata and body + memory_pool : MemoryPool, default None + Uses default memory pool if not specified + + Returns + ------- + serialized : Buffer + """ + stream = BufferOutputStream(memory_pool) + self.serialize_to(stream, alignment=alignment, memory_pool=memory_pool) + return stream.getvalue() + + def __repr__(self): + if self.message == nullptr: + return """pyarrow.Message(uninitialized)""" + + metadata_len = self.metadata.size + body = self.body + body_len = 0 if body is None else body.size + + return """pyarrow.Message +type: {0} +metadata length: {1} +body length: {2}""".format(self.type, metadata_len, body_len) + + +cdef class MessageReader(_Weakrefable): + """ + Interface for reading Message objects from some source (like an + InputStream) + """ + cdef: + unique_ptr[CMessageReader] reader + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "`pyarrow.ipc.MessageReader.open_stream` function " + "instead.".format(self.__class__.__name__)) + + @staticmethod + def open_stream(source): + """ + Open stream from source. + + Parameters + ---------- + source : a readable source, like an InputStream + """ + cdef: + MessageReader result = MessageReader.__new__(MessageReader) + shared_ptr[CInputStream] in_stream + unique_ptr[CMessageReader] reader + + _get_input_stream(source, &in_stream) + with nogil: + reader = CMessageReader.Open(in_stream) + result.reader.reset(reader.release()) + + return result + + def __iter__(self): + while True: + yield self.read_next_message() + + def read_next_message(self): + """ + Read next Message from the stream. + + Raises + ------ + StopIteration : at end of stream + """ + cdef Message result = Message.__new__(Message) + + with nogil: + result.message = move(GetResultValue(self.reader.get() + .ReadNextMessage())) + + if result.message.get() == NULL: + raise StopIteration + + return result + +# ---------------------------------------------------------------------- +# File and stream readers and writers + +cdef class _CRecordBatchWriter(_Weakrefable): + """The base RecordBatchWriter wrapper. + + Provides common implementations of convenience methods. Should not + be instantiated directly by user code. + """ + + # cdef block is in lib.pxd + + def write(self, table_or_batch): + """ + Write RecordBatch or Table to stream. + + Parameters + ---------- + table_or_batch : {RecordBatch, Table} + """ + if isinstance(table_or_batch, RecordBatch): + self.write_batch(table_or_batch) + elif isinstance(table_or_batch, Table): + self.write_table(table_or_batch) + else: + raise ValueError(type(table_or_batch)) + + def write_batch(self, RecordBatch batch): + """ + Write RecordBatch to stream. + + Parameters + ---------- + batch : RecordBatch + """ + with nogil: + check_status(self.writer.get() + .WriteRecordBatch(deref(batch.batch))) + + def write_table(self, Table table, max_chunksize=None, **kwargs): + """ + Write Table to stream in (contiguous) RecordBatch objects. + + Parameters + ---------- + table : Table + max_chunksize : int, default None + Maximum size for RecordBatch chunks. Individual chunks may be + smaller depending on the chunk layout of individual columns. + """ + cdef: + # max_chunksize must be > 0 to have any impact + int64_t c_max_chunksize = -1 + + if 'chunksize' in kwargs: + max_chunksize = kwargs['chunksize'] + msg = ('The parameter chunksize is deprecated for the write_table ' + 'methods as of 0.15, please use parameter ' + 'max_chunksize instead') + warnings.warn(msg, FutureWarning) + + if max_chunksize is not None: + c_max_chunksize = max_chunksize + + with nogil: + check_status(self.writer.get().WriteTable(table.table[0], + c_max_chunksize)) + + def close(self): + """ + Close stream and write end-of-stream 0 marker. + """ + with nogil: + check_status(self.writer.get().Close()) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + @property + def stats(self): + """ + Current IPC write statistics. + """ + if not self.writer: + raise ValueError("Operation on closed writer") + return _wrap_write_stats(self.writer.get().stats()) + + +cdef class _RecordBatchStreamWriter(_CRecordBatchWriter): + cdef: + CIpcWriteOptions options + bint closed + + def __cinit__(self): + pass + + def __dealloc__(self): + pass + + @property + def _use_legacy_format(self): + # For testing (see test_ipc.py) + return self.options.write_legacy_ipc_format + + @property + def _metadata_version(self): + # For testing (see test_ipc.py) + return _wrap_metadata_version(self.options.metadata_version) + + def _open(self, sink, Schema schema not None, + IpcWriteOptions options=IpcWriteOptions()): + cdef: + shared_ptr[COutputStream] c_sink + + self.options = options.c_options + get_writer(sink, &c_sink) + with nogil: + self.writer = GetResultValue( + MakeStreamWriter(c_sink, schema.sp_schema, + self.options)) + + +cdef _get_input_stream(object source, shared_ptr[CInputStream]* out): + try: + source = as_buffer(source) + except TypeError: + # Non-buffer-like + pass + + get_input_stream(source, True, out) + + +class _ReadPandasMixin: + + def read_pandas(self, **options): + """ + Read contents of stream to a pandas.DataFrame. + + Read all record batches as a pyarrow.Table then convert it to a + pandas.DataFrame using Table.to_pandas. + + Parameters + ---------- + **options : arguments to forward to Table.to_pandas + + Returns + ------- + df : pandas.DataFrame + """ + table = self.read_all() + return table.to_pandas(**options) + + +cdef class RecordBatchReader(_Weakrefable): + """Base class for reading stream of record batches. + + Provides common implementations of convenience methods. Should not + be instantiated directly by user code. + """ + + # cdef block is in lib.pxd + + def __iter__(self): + while True: + try: + yield self.read_next_batch() + except StopIteration: + return + + @property + def schema(self): + """ + Shared schema of the record batches in the stream. + """ + cdef shared_ptr[CSchema] c_schema + + with nogil: + c_schema = self.reader.get().schema() + + return pyarrow_wrap_schema(c_schema) + + def get_next_batch(self): + import warnings + warnings.warn('Please use read_next_batch instead of ' + 'get_next_batch', FutureWarning) + return self.read_next_batch() + + def read_next_batch(self): + """ + Read next RecordBatch from the stream. + + Raises + ------ + StopIteration: + At end of stream. + """ + cdef shared_ptr[CRecordBatch] batch + + with nogil: + check_status(self.reader.get().ReadNext(&batch)) + + if batch.get() == NULL: + raise StopIteration + + return pyarrow_wrap_batch(batch) + + def read_all(self): + """ + Read all record batches as a pyarrow.Table. + """ + cdef shared_ptr[CTable] table + with nogil: + check_status(self.reader.get().ReadAll(&table)) + return pyarrow_wrap_table(table) + + read_pandas = _ReadPandasMixin.read_pandas + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def _export_to_c(self, uintptr_t out_ptr): + """ + Export to a C ArrowArrayStream struct, given its pointer. + + Parameters + ---------- + out_ptr: int + The raw pointer to a C ArrowArrayStream struct. + + Be careful: if you don't pass the ArrowArrayStream struct to a + consumer, array memory will leak. This is a low-level function + intended for expert users. + """ + with nogil: + check_status(ExportRecordBatchReader( + self.reader, <ArrowArrayStream*> out_ptr)) + + @staticmethod + def _import_from_c(uintptr_t in_ptr): + """ + Import RecordBatchReader from a C ArrowArrayStream struct, + given its pointer. + + Parameters + ---------- + in_ptr: int + The raw pointer to a C ArrowArrayStream struct. + + This is a low-level function intended for expert users. + """ + cdef: + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader self + + with nogil: + c_reader = GetResultValue(ImportRecordBatchReader( + <ArrowArrayStream*> in_ptr)) + + self = RecordBatchReader.__new__(RecordBatchReader) + self.reader = c_reader + return self + + @staticmethod + def from_batches(schema, batches): + """ + Create RecordBatchReader from an iterable of batches. + + Parameters + ---------- + schema : Schema + The shared schema of the record batches + batches : Iterable[RecordBatch] + The batches that this reader will return. + + Returns + ------- + reader : RecordBatchReader + """ + cdef: + shared_ptr[CSchema] c_schema + shared_ptr[CRecordBatchReader] c_reader + RecordBatchReader self + + c_schema = pyarrow_unwrap_schema(schema) + c_reader = GetResultValue(CPyRecordBatchReader.Make( + c_schema, batches)) + + self = RecordBatchReader.__new__(RecordBatchReader) + self.reader = c_reader + return self + + +cdef class _RecordBatchStreamReader(RecordBatchReader): + cdef: + shared_ptr[CInputStream] in_stream + CIpcReadOptions options + CRecordBatchStreamReader* stream_reader + + def __cinit__(self): + pass + + def _open(self, source): + _get_input_stream(source, &self.in_stream) + with nogil: + self.reader = GetResultValue(CRecordBatchStreamReader.Open( + self.in_stream, self.options)) + self.stream_reader = <CRecordBatchStreamReader*> self.reader.get() + + @property + def stats(self): + """ + Current IPC read statistics. + """ + if not self.reader: + raise ValueError("Operation on closed reader") + return _wrap_read_stats(self.stream_reader.stats()) + + +cdef class _RecordBatchFileWriter(_RecordBatchStreamWriter): + + def _open(self, sink, Schema schema not None, + IpcWriteOptions options=IpcWriteOptions()): + cdef: + shared_ptr[COutputStream] c_sink + + self.options = options.c_options + get_writer(sink, &c_sink) + with nogil: + self.writer = GetResultValue( + MakeFileWriter(c_sink, schema.sp_schema, self.options)) + + +cdef class _RecordBatchFileReader(_Weakrefable): + cdef: + shared_ptr[CRecordBatchFileReader] reader + shared_ptr[CRandomAccessFile] file + CIpcReadOptions options + + cdef readonly: + Schema schema + + def __cinit__(self): + pass + + def _open(self, source, footer_offset=None): + try: + source = as_buffer(source) + except TypeError: + pass + + get_reader(source, True, &self.file) + + cdef int64_t offset = 0 + if footer_offset is not None: + offset = footer_offset + + with nogil: + if offset != 0: + self.reader = GetResultValue( + CRecordBatchFileReader.Open2(self.file.get(), offset, + self.options)) + + else: + self.reader = GetResultValue( + CRecordBatchFileReader.Open(self.file.get(), + self.options)) + + self.schema = pyarrow_wrap_schema(self.reader.get().schema()) + + @property + def num_record_batches(self): + return self.reader.get().num_record_batches() + + def get_batch(self, int i): + cdef shared_ptr[CRecordBatch] batch + + if i < 0 or i >= self.num_record_batches: + raise ValueError('Batch number {0} out of range'.format(i)) + + with nogil: + batch = GetResultValue(self.reader.get().ReadRecordBatch(i)) + + return pyarrow_wrap_batch(batch) + + # TODO(wesm): ARROW-503: Function was renamed. Remove after a period of + # time has passed + get_record_batch = get_batch + + def read_all(self): + """ + Read all record batches as a pyarrow.Table + """ + cdef: + vector[shared_ptr[CRecordBatch]] batches + shared_ptr[CTable] table + int i, nbatches + + nbatches = self.num_record_batches + + batches.resize(nbatches) + with nogil: + for i in range(nbatches): + batches[i] = GetResultValue(self.reader.get() + .ReadRecordBatch(i)) + table = GetResultValue( + CTable.FromRecordBatches(self.schema.sp_schema, move(batches))) + + return pyarrow_wrap_table(table) + + read_pandas = _ReadPandasMixin.read_pandas + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + @property + def stats(self): + """ + Current IPC read statistics. + """ + if not self.reader: + raise ValueError("Operation on closed reader") + return _wrap_read_stats(self.reader.get().stats()) + + +def get_tensor_size(Tensor tensor): + """ + Return total size of serialized Tensor including metadata and padding. + + Parameters + ---------- + tensor : Tensor + The tensor for which we want to known the size. + """ + cdef int64_t size + with nogil: + check_status(GetTensorSize(deref(tensor.tp), &size)) + return size + + +def get_record_batch_size(RecordBatch batch): + """ + Return total size of serialized RecordBatch including metadata and padding. + + Parameters + ---------- + batch : RecordBatch + The recordbatch for which we want to know the size. + """ + cdef int64_t size + with nogil: + check_status(GetRecordBatchSize(deref(batch.batch), &size)) + return size + + +def write_tensor(Tensor tensor, NativeFile dest): + """ + Write pyarrow.Tensor to pyarrow.NativeFile object its current position. + + Parameters + ---------- + tensor : pyarrow.Tensor + dest : pyarrow.NativeFile + + Returns + ------- + bytes_written : int + Total number of bytes written to the file + """ + cdef: + int32_t metadata_length + int64_t body_length + + handle = dest.get_output_stream() + + with nogil: + check_status( + WriteTensor(deref(tensor.tp), handle.get(), + &metadata_length, &body_length)) + + return metadata_length + body_length + + +cdef NativeFile as_native_file(source): + if not isinstance(source, NativeFile): + if hasattr(source, 'read'): + source = PythonFile(source) + else: + source = BufferReader(source) + + if not isinstance(source, NativeFile): + raise ValueError('Unable to read message from object with type: {0}' + .format(type(source))) + return source + + +def read_tensor(source): + """Read pyarrow.Tensor from pyarrow.NativeFile object from current + position. If the file source supports zero copy (e.g. a memory map), then + this operation does not allocate any memory. This function not assume that + the stream is aligned + + Parameters + ---------- + source : pyarrow.NativeFile + + Returns + ------- + tensor : Tensor + + """ + cdef: + shared_ptr[CTensor] sp_tensor + CInputStream* c_stream + NativeFile nf = as_native_file(source) + + c_stream = nf.get_input_stream().get() + with nogil: + sp_tensor = GetResultValue(ReadTensor(c_stream)) + return pyarrow_wrap_tensor(sp_tensor) + + +def read_message(source): + """ + Read length-prefixed message from file or buffer-like object + + Parameters + ---------- + source : pyarrow.NativeFile, file-like object, or buffer-like object + + Returns + ------- + message : Message + """ + cdef: + Message result = Message.__new__(Message) + CInputStream* c_stream + + cdef NativeFile nf = as_native_file(source) + c_stream = nf.get_input_stream().get() + + with nogil: + result.message = move( + GetResultValue(ReadMessage(c_stream, c_default_memory_pool()))) + + if result.message == nullptr: + raise EOFError("End of Arrow stream") + + return result + + +def read_schema(obj, DictionaryMemo dictionary_memo=None): + """ + Read Schema from message or buffer + + Parameters + ---------- + obj : buffer or Message + dictionary_memo : DictionaryMemo, optional + Needed to be able to reconstruct dictionary-encoded fields + with read_record_batch + + Returns + ------- + schema : Schema + """ + cdef: + shared_ptr[CSchema] result + shared_ptr[CRandomAccessFile] cpp_file + CDictionaryMemo temp_memo + CDictionaryMemo* arg_dict_memo + + if isinstance(obj, Message): + raise NotImplementedError(type(obj)) + + get_reader(obj, True, &cpp_file) + + if dictionary_memo is not None: + arg_dict_memo = dictionary_memo.memo + else: + arg_dict_memo = &temp_memo + + with nogil: + result = GetResultValue(ReadSchema(cpp_file.get(), arg_dict_memo)) + + return pyarrow_wrap_schema(result) + + +def read_record_batch(obj, Schema schema, + DictionaryMemo dictionary_memo=None): + """ + Read RecordBatch from message, given a known schema. If reading data from a + complete IPC stream, use ipc.open_stream instead + + Parameters + ---------- + obj : Message or Buffer-like + schema : Schema + dictionary_memo : DictionaryMemo, optional + If message contains dictionaries, must pass a populated + DictionaryMemo + + Returns + ------- + batch : RecordBatch + """ + cdef: + shared_ptr[CRecordBatch] result + Message message + CDictionaryMemo temp_memo + CDictionaryMemo* arg_dict_memo + + if isinstance(obj, Message): + message = obj + else: + message = read_message(obj) + + if dictionary_memo is not None: + arg_dict_memo = dictionary_memo.memo + else: + arg_dict_memo = &temp_memo + + with nogil: + result = GetResultValue( + ReadRecordBatch(deref(message.message.get()), + schema.sp_schema, + arg_dict_memo, + CIpcReadOptions.Defaults())) + + return pyarrow_wrap_batch(result) diff --git a/src/arrow/python/pyarrow/ipc.py b/src/arrow/python/pyarrow/ipc.py new file mode 100644 index 000000000..cb28a0b5f --- /dev/null +++ b/src/arrow/python/pyarrow/ipc.py @@ -0,0 +1,233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Arrow file and stream reader/writer classes, and other messaging tools + +import os + +import pyarrow as pa + +from pyarrow.lib import (IpcWriteOptions, ReadStats, WriteStats, # noqa + Message, MessageReader, + RecordBatchReader, _ReadPandasMixin, + MetadataVersion, + read_message, read_record_batch, read_schema, + read_tensor, write_tensor, + get_record_batch_size, get_tensor_size) +import pyarrow.lib as lib + + +class RecordBatchStreamReader(lib._RecordBatchStreamReader): + """ + Reader for the Arrow streaming binary format. + + Parameters + ---------- + source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object + Either an in-memory buffer, or a readable file object. + """ + + def __init__(self, source): + self._open(source) + + +_ipc_writer_class_doc = """\ +Parameters +---------- +sink : str, pyarrow.NativeFile, or file-like Python object + Either a file path, or a writable file object. +schema : pyarrow.Schema + The Arrow schema for data to be written to the file. +options : pyarrow.ipc.IpcWriteOptions + Options for IPC serialization. + + If None, default values will be used: the legacy format will not + be used unless overridden by setting the environment variable + ARROW_PRE_0_15_IPC_FORMAT=1, and the V5 metadata version will be + used unless overridden by setting the environment variable + ARROW_PRE_1_0_METADATA_VERSION=1. +use_legacy_format : bool, default None + Deprecated in favor of setting options. Cannot be provided with + options. + + If None, False will be used unless this default is overridden by + setting the environment variable ARROW_PRE_0_15_IPC_FORMAT=1""" + + +class RecordBatchStreamWriter(lib._RecordBatchStreamWriter): + __doc__ = """Writer for the Arrow streaming binary format + +{}""".format(_ipc_writer_class_doc) + + def __init__(self, sink, schema, *, use_legacy_format=None, options=None): + options = _get_legacy_format_default(use_legacy_format, options) + self._open(sink, schema, options=options) + + +class RecordBatchFileReader(lib._RecordBatchFileReader): + """ + Class for reading Arrow record batch data from the Arrow binary file format + + Parameters + ---------- + source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object + Either an in-memory buffer, or a readable file object + footer_offset : int, default None + If the file is embedded in some larger file, this is the byte offset to + the very end of the file data + """ + + def __init__(self, source, footer_offset=None): + self._open(source, footer_offset=footer_offset) + + +class RecordBatchFileWriter(lib._RecordBatchFileWriter): + + __doc__ = """Writer to create the Arrow binary file format + +{}""".format(_ipc_writer_class_doc) + + def __init__(self, sink, schema, *, use_legacy_format=None, options=None): + options = _get_legacy_format_default(use_legacy_format, options) + self._open(sink, schema, options=options) + + +def _get_legacy_format_default(use_legacy_format, options): + if use_legacy_format is not None and options is not None: + raise ValueError( + "Can provide at most one of options and use_legacy_format") + elif options: + if not isinstance(options, IpcWriteOptions): + raise TypeError("expected IpcWriteOptions, got {}" + .format(type(options))) + return options + + metadata_version = MetadataVersion.V5 + if use_legacy_format is None: + use_legacy_format = \ + bool(int(os.environ.get('ARROW_PRE_0_15_IPC_FORMAT', '0'))) + if bool(int(os.environ.get('ARROW_PRE_1_0_METADATA_VERSION', '0'))): + metadata_version = MetadataVersion.V4 + return IpcWriteOptions(use_legacy_format=use_legacy_format, + metadata_version=metadata_version) + + +def new_stream(sink, schema, *, use_legacy_format=None, options=None): + return RecordBatchStreamWriter(sink, schema, + use_legacy_format=use_legacy_format, + options=options) + + +new_stream.__doc__ = """\ +Create an Arrow columnar IPC stream writer instance + +{}""".format(_ipc_writer_class_doc) + + +def open_stream(source): + """ + Create reader for Arrow streaming format. + + Parameters + ---------- + source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object + Either an in-memory buffer, or a readable file object. + + Returns + ------- + reader : RecordBatchStreamReader + """ + return RecordBatchStreamReader(source) + + +def new_file(sink, schema, *, use_legacy_format=None, options=None): + return RecordBatchFileWriter(sink, schema, + use_legacy_format=use_legacy_format, + options=options) + + +new_file.__doc__ = """\ +Create an Arrow columnar IPC file writer instance + +{}""".format(_ipc_writer_class_doc) + + +def open_file(source, footer_offset=None): + """ + Create reader for Arrow file format. + + Parameters + ---------- + source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object + Either an in-memory buffer, or a readable file object. + footer_offset : int, default None + If the file is embedded in some larger file, this is the byte offset to + the very end of the file data. + + Returns + ------- + reader : RecordBatchFileReader + """ + return RecordBatchFileReader(source, footer_offset=footer_offset) + + +def serialize_pandas(df, *, nthreads=None, preserve_index=None): + """ + Serialize a pandas DataFrame into a buffer protocol compatible object. + + Parameters + ---------- + df : pandas.DataFrame + nthreads : int, default None + Number of threads to use for conversion to Arrow, default all CPUs. + preserve_index : bool, default None + The default of None will store the index as a column, except for + RangeIndex which is stored as metadata only. If True, always + preserve the pandas index data as a column. If False, no index + information is saved and the result will have a default RangeIndex. + + Returns + ------- + buf : buffer + An object compatible with the buffer protocol. + """ + batch = pa.RecordBatch.from_pandas(df, nthreads=nthreads, + preserve_index=preserve_index) + sink = pa.BufferOutputStream() + with pa.RecordBatchStreamWriter(sink, batch.schema) as writer: + writer.write_batch(batch) + return sink.getvalue() + + +def deserialize_pandas(buf, *, use_threads=True): + """Deserialize a buffer protocol compatible object into a pandas DataFrame. + + Parameters + ---------- + buf : buffer + An object compatible with the buffer protocol. + use_threads : bool, default True + Whether to parallelize the conversion using multiple threads. + + Returns + ------- + df : pandas.DataFrame + """ + buffer_reader = pa.BufferReader(buf) + with pa.RecordBatchStreamReader(buffer_reader) as reader: + table = reader.read_all() + return table.to_pandas(use_threads=use_threads) diff --git a/src/arrow/python/pyarrow/json.py b/src/arrow/python/pyarrow/json.py new file mode 100644 index 000000000..a864f5d99 --- /dev/null +++ b/src/arrow/python/pyarrow/json.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from pyarrow._json import ReadOptions, ParseOptions, read_json # noqa diff --git a/src/arrow/python/pyarrow/jvm.py b/src/arrow/python/pyarrow/jvm.py new file mode 100644 index 000000000..161c5ff4d --- /dev/null +++ b/src/arrow/python/pyarrow/jvm.py @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Functions to interact with Arrow memory allocated by Arrow Java. + +These functions convert the objects holding the metadata, the actual +data is not copied at all. + +This will only work with a JVM running in the same process such as provided +through jpype. Modules that talk to a remote JVM like py4j will not work as the +memory addresses reported by them are not reachable in the python process. +""" + +import pyarrow as pa + + +class _JvmBufferNanny: + """ + An object that keeps a org.apache.arrow.memory.ArrowBuf's underlying + memory alive. + """ + ref_manager = None + + def __init__(self, jvm_buf): + ref_manager = jvm_buf.getReferenceManager() + # Will raise a java.lang.IllegalArgumentException if the buffer + # is already freed. It seems that exception cannot easily be + # caught... + ref_manager.retain() + self.ref_manager = ref_manager + + def __del__(self): + if self.ref_manager is not None: + self.ref_manager.release() + + +def jvm_buffer(jvm_buf): + """ + Construct an Arrow buffer from org.apache.arrow.memory.ArrowBuf + + Parameters + ---------- + + jvm_buf: org.apache.arrow.memory.ArrowBuf + Arrow Buffer representation on the JVM. + + Returns + ------- + pyarrow.Buffer + Python Buffer that references the JVM memory. + """ + nanny = _JvmBufferNanny(jvm_buf) + address = jvm_buf.memoryAddress() + size = jvm_buf.capacity() + return pa.foreign_buffer(address, size, base=nanny) + + +def _from_jvm_int_type(jvm_type): + """ + Convert a JVM int type to its Python equivalent. + + Parameters + ---------- + jvm_type : org.apache.arrow.vector.types.pojo.ArrowType$Int + + Returns + ------- + typ : pyarrow.DataType + """ + + bit_width = jvm_type.getBitWidth() + if jvm_type.getIsSigned(): + if bit_width == 8: + return pa.int8() + elif bit_width == 16: + return pa.int16() + elif bit_width == 32: + return pa.int32() + elif bit_width == 64: + return pa.int64() + else: + if bit_width == 8: + return pa.uint8() + elif bit_width == 16: + return pa.uint16() + elif bit_width == 32: + return pa.uint32() + elif bit_width == 64: + return pa.uint64() + + +def _from_jvm_float_type(jvm_type): + """ + Convert a JVM float type to its Python equivalent. + + Parameters + ---------- + jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$FloatingPoint + + Returns + ------- + typ: pyarrow.DataType + """ + precision = jvm_type.getPrecision().toString() + if precision == 'HALF': + return pa.float16() + elif precision == 'SINGLE': + return pa.float32() + elif precision == 'DOUBLE': + return pa.float64() + + +def _from_jvm_time_type(jvm_type): + """ + Convert a JVM time type to its Python equivalent. + + Parameters + ---------- + jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Time + + Returns + ------- + typ: pyarrow.DataType + """ + time_unit = jvm_type.getUnit().toString() + if time_unit == 'SECOND': + assert jvm_type.getBitWidth() == 32 + return pa.time32('s') + elif time_unit == 'MILLISECOND': + assert jvm_type.getBitWidth() == 32 + return pa.time32('ms') + elif time_unit == 'MICROSECOND': + assert jvm_type.getBitWidth() == 64 + return pa.time64('us') + elif time_unit == 'NANOSECOND': + assert jvm_type.getBitWidth() == 64 + return pa.time64('ns') + + +def _from_jvm_timestamp_type(jvm_type): + """ + Convert a JVM timestamp type to its Python equivalent. + + Parameters + ---------- + jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Timestamp + + Returns + ------- + typ: pyarrow.DataType + """ + time_unit = jvm_type.getUnit().toString() + timezone = jvm_type.getTimezone() + if timezone is not None: + timezone = str(timezone) + if time_unit == 'SECOND': + return pa.timestamp('s', tz=timezone) + elif time_unit == 'MILLISECOND': + return pa.timestamp('ms', tz=timezone) + elif time_unit == 'MICROSECOND': + return pa.timestamp('us', tz=timezone) + elif time_unit == 'NANOSECOND': + return pa.timestamp('ns', tz=timezone) + + +def _from_jvm_date_type(jvm_type): + """ + Convert a JVM date type to its Python equivalent + + Parameters + ---------- + jvm_type: org.apache.arrow.vector.types.pojo.ArrowType$Date + + Returns + ------- + typ: pyarrow.DataType + """ + day_unit = jvm_type.getUnit().toString() + if day_unit == 'DAY': + return pa.date32() + elif day_unit == 'MILLISECOND': + return pa.date64() + + +def field(jvm_field): + """ + Construct a Field from a org.apache.arrow.vector.types.pojo.Field + instance. + + Parameters + ---------- + jvm_field: org.apache.arrow.vector.types.pojo.Field + + Returns + ------- + pyarrow.Field + """ + name = str(jvm_field.getName()) + jvm_type = jvm_field.getType() + + typ = None + if not jvm_type.isComplex(): + type_str = jvm_type.getTypeID().toString() + if type_str == 'Null': + typ = pa.null() + elif type_str == 'Int': + typ = _from_jvm_int_type(jvm_type) + elif type_str == 'FloatingPoint': + typ = _from_jvm_float_type(jvm_type) + elif type_str == 'Utf8': + typ = pa.string() + elif type_str == 'Binary': + typ = pa.binary() + elif type_str == 'FixedSizeBinary': + typ = pa.binary(jvm_type.getByteWidth()) + elif type_str == 'Bool': + typ = pa.bool_() + elif type_str == 'Time': + typ = _from_jvm_time_type(jvm_type) + elif type_str == 'Timestamp': + typ = _from_jvm_timestamp_type(jvm_type) + elif type_str == 'Date': + typ = _from_jvm_date_type(jvm_type) + elif type_str == 'Decimal': + typ = pa.decimal128(jvm_type.getPrecision(), jvm_type.getScale()) + else: + raise NotImplementedError( + "Unsupported JVM type: {}".format(type_str)) + else: + # TODO: The following JVM types are not implemented: + # Struct, List, FixedSizeList, Union, Dictionary + raise NotImplementedError( + "JVM field conversion only implemented for primitive types.") + + nullable = jvm_field.isNullable() + jvm_metadata = jvm_field.getMetadata() + if jvm_metadata.isEmpty(): + metadata = None + else: + metadata = {str(entry.getKey()): str(entry.getValue()) + for entry in jvm_metadata.entrySet()} + return pa.field(name, typ, nullable, metadata) + + +def schema(jvm_schema): + """ + Construct a Schema from a org.apache.arrow.vector.types.pojo.Schema + instance. + + Parameters + ---------- + jvm_schema: org.apache.arrow.vector.types.pojo.Schema + + Returns + ------- + pyarrow.Schema + """ + fields = jvm_schema.getFields() + fields = [field(f) for f in fields] + jvm_metadata = jvm_schema.getCustomMetadata() + if jvm_metadata.isEmpty(): + metadata = None + else: + metadata = {str(entry.getKey()): str(entry.getValue()) + for entry in jvm_metadata.entrySet()} + return pa.schema(fields, metadata) + + +def array(jvm_array): + """ + Construct an (Python) Array from its JVM equivalent. + + Parameters + ---------- + jvm_array : org.apache.arrow.vector.ValueVector + + Returns + ------- + array : Array + """ + if jvm_array.getField().getType().isComplex(): + minor_type_str = jvm_array.getMinorType().toString() + raise NotImplementedError( + "Cannot convert JVM Arrow array of type {}," + " complex types not yet implemented.".format(minor_type_str)) + dtype = field(jvm_array.getField()).type + buffers = [jvm_buffer(buf) + for buf in list(jvm_array.getBuffers(False))] + + # If JVM has an empty Vector, buffer list will be empty so create manually + if len(buffers) == 0: + return pa.array([], type=dtype) + + length = jvm_array.getValueCount() + null_count = jvm_array.getNullCount() + return pa.Array.from_buffers(dtype, length, buffers, null_count) + + +def record_batch(jvm_vector_schema_root): + """ + Construct a (Python) RecordBatch from a JVM VectorSchemaRoot + + Parameters + ---------- + jvm_vector_schema_root : org.apache.arrow.vector.VectorSchemaRoot + + Returns + ------- + record_batch: pyarrow.RecordBatch + """ + pa_schema = schema(jvm_vector_schema_root.getSchema()) + + arrays = [] + for name in pa_schema.names: + arrays.append(array(jvm_vector_schema_root.getVector(name))) + + return pa.RecordBatch.from_arrays( + arrays, + pa_schema.names, + metadata=pa_schema.metadata + ) diff --git a/src/arrow/python/pyarrow/lib.pxd b/src/arrow/python/pyarrow/lib.pxd new file mode 100644 index 000000000..e3b07f404 --- /dev/null +++ b/src/arrow/python/pyarrow/lib.pxd @@ -0,0 +1,604 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: language_level = 3 + +from cpython cimport PyObject +from libcpp cimport nullptr +from libcpp.cast cimport dynamic_cast +from pyarrow.includes.common cimport * +from pyarrow.includes.libarrow cimport * + + +cdef extern from "Python.h": + int PySlice_Check(object) + + +cdef int check_status(const CStatus& status) nogil except -1 + + +cdef class _Weakrefable: + cdef object __weakref__ + + +cdef class IpcWriteOptions(_Weakrefable): + cdef: + CIpcWriteOptions c_options + + +cdef class Message(_Weakrefable): + cdef: + unique_ptr[CMessage] message + + +cdef class MemoryPool(_Weakrefable): + cdef: + CMemoryPool* pool + + cdef void init(self, CMemoryPool* pool) + + +cdef CMemoryPool* maybe_unbox_memory_pool(MemoryPool memory_pool) + + +cdef class DataType(_Weakrefable): + cdef: + shared_ptr[CDataType] sp_type + CDataType* type + bytes pep3118_format + + cdef void init(self, const shared_ptr[CDataType]& type) except * + cdef Field field(self, int i) + + +cdef class ListType(DataType): + cdef: + const CListType* list_type + + +cdef class LargeListType(DataType): + cdef: + const CLargeListType* list_type + + +cdef class MapType(DataType): + cdef: + const CMapType* map_type + + +cdef class FixedSizeListType(DataType): + cdef: + const CFixedSizeListType* list_type + + +cdef class StructType(DataType): + cdef: + const CStructType* struct_type + + cdef Field field_by_name(self, name) + + +cdef class DictionaryMemo(_Weakrefable): + cdef: + # Even though the CDictionaryMemo instance is private, we allocate + # it on the heap so as to avoid C++ ABI issues with Python wheels. + shared_ptr[CDictionaryMemo] sp_memo + CDictionaryMemo* memo + + +cdef class DictionaryType(DataType): + cdef: + const CDictionaryType* dict_type + + +cdef class TimestampType(DataType): + cdef: + const CTimestampType* ts_type + + +cdef class Time32Type(DataType): + cdef: + const CTime32Type* time_type + + +cdef class Time64Type(DataType): + cdef: + const CTime64Type* time_type + + +cdef class DurationType(DataType): + cdef: + const CDurationType* duration_type + + +cdef class FixedSizeBinaryType(DataType): + cdef: + const CFixedSizeBinaryType* fixed_size_binary_type + + +cdef class Decimal128Type(FixedSizeBinaryType): + cdef: + const CDecimal128Type* decimal128_type + + +cdef class Decimal256Type(FixedSizeBinaryType): + cdef: + const CDecimal256Type* decimal256_type + + +cdef class BaseExtensionType(DataType): + cdef: + const CExtensionType* ext_type + + +cdef class ExtensionType(BaseExtensionType): + cdef: + const CPyExtensionType* cpy_ext_type + + +cdef class PyExtensionType(ExtensionType): + pass + + +cdef class _Metadata(_Weakrefable): + # required because KeyValueMetadata also extends collections.abc.Mapping + # and the first parent class must be an extension type + pass + + +cdef class KeyValueMetadata(_Metadata): + cdef: + shared_ptr[const CKeyValueMetadata] wrapped + const CKeyValueMetadata* metadata + + cdef void init(self, const shared_ptr[const CKeyValueMetadata]& wrapped) + + @staticmethod + cdef wrap(const shared_ptr[const CKeyValueMetadata]& sp) + cdef inline shared_ptr[const CKeyValueMetadata] unwrap(self) nogil + + +cdef class Field(_Weakrefable): + cdef: + shared_ptr[CField] sp_field + CField* field + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CField]& field) + + +cdef class Schema(_Weakrefable): + cdef: + shared_ptr[CSchema] sp_schema + CSchema* schema + + cdef void init(self, const vector[shared_ptr[CField]]& fields) + cdef void init_schema(self, const shared_ptr[CSchema]& schema) + + +cdef class Scalar(_Weakrefable): + cdef: + shared_ptr[CScalar] wrapped + + cdef void init(self, const shared_ptr[CScalar]& wrapped) + + @staticmethod + cdef wrap(const shared_ptr[CScalar]& wrapped) + + cdef inline shared_ptr[CScalar] unwrap(self) nogil + + +cdef class _PandasConvertible(_Weakrefable): + pass + + +cdef class Array(_PandasConvertible): + cdef: + shared_ptr[CArray] sp_array + CArray* ap + + cdef readonly: + DataType type + # To allow Table to propagate metadata to pandas.Series + object _name + + cdef void init(self, const shared_ptr[CArray]& sp_array) except * + cdef getitem(self, int64_t i) + cdef int64_t length(self) + + +cdef class Tensor(_Weakrefable): + cdef: + shared_ptr[CTensor] sp_tensor + CTensor* tp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CTensor]& sp_tensor) + + +cdef class SparseCSRMatrix(_Weakrefable): + cdef: + shared_ptr[CSparseCSRMatrix] sp_sparse_tensor + CSparseCSRMatrix* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor) + + +cdef class SparseCSCMatrix(_Weakrefable): + cdef: + shared_ptr[CSparseCSCMatrix] sp_sparse_tensor + CSparseCSCMatrix* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor) + + +cdef class SparseCOOTensor(_Weakrefable): + cdef: + shared_ptr[CSparseCOOTensor] sp_sparse_tensor + CSparseCOOTensor* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor) + + +cdef class SparseCSFTensor(_Weakrefable): + cdef: + shared_ptr[CSparseCSFTensor] sp_sparse_tensor + CSparseCSFTensor* stp + + cdef readonly: + DataType type + + cdef void init(self, const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor) + + +cdef class NullArray(Array): + pass + + +cdef class BooleanArray(Array): + pass + + +cdef class NumericArray(Array): + pass + + +cdef class IntegerArray(NumericArray): + pass + + +cdef class FloatingPointArray(NumericArray): + pass + + +cdef class Int8Array(IntegerArray): + pass + + +cdef class UInt8Array(IntegerArray): + pass + + +cdef class Int16Array(IntegerArray): + pass + + +cdef class UInt16Array(IntegerArray): + pass + + +cdef class Int32Array(IntegerArray): + pass + + +cdef class UInt32Array(IntegerArray): + pass + + +cdef class Int64Array(IntegerArray): + pass + + +cdef class UInt64Array(IntegerArray): + pass + + +cdef class HalfFloatArray(FloatingPointArray): + pass + + +cdef class FloatArray(FloatingPointArray): + pass + + +cdef class DoubleArray(FloatingPointArray): + pass + + +cdef class FixedSizeBinaryArray(Array): + pass + + +cdef class Decimal128Array(FixedSizeBinaryArray): + pass + + +cdef class Decimal256Array(FixedSizeBinaryArray): + pass + + +cdef class StructArray(Array): + pass + + +cdef class BaseListArray(Array): + pass + + +cdef class ListArray(BaseListArray): + pass + + +cdef class LargeListArray(BaseListArray): + pass + + +cdef class MapArray(Array): + pass + + +cdef class FixedSizeListArray(Array): + pass + + +cdef class UnionArray(Array): + pass + + +cdef class StringArray(Array): + pass + + +cdef class BinaryArray(Array): + pass + + +cdef class DictionaryArray(Array): + cdef: + object _indices, _dictionary + + +cdef class ExtensionArray(Array): + pass + + +cdef class MonthDayNanoIntervalArray(Array): + pass + + +cdef wrap_array_output(PyObject* output) +cdef wrap_datum(const CDatum& datum) + + +cdef class ChunkedArray(_PandasConvertible): + cdef: + shared_ptr[CChunkedArray] sp_chunked_array + CChunkedArray* chunked_array + + cdef readonly: + # To allow Table to propagate metadata to pandas.Series + object _name + + cdef void init(self, const shared_ptr[CChunkedArray]& chunked_array) + cdef getitem(self, int64_t i) + + +cdef class Table(_PandasConvertible): + cdef: + shared_ptr[CTable] sp_table + CTable* table + + cdef void init(self, const shared_ptr[CTable]& table) + + +cdef class RecordBatch(_PandasConvertible): + cdef: + shared_ptr[CRecordBatch] sp_batch + CRecordBatch* batch + Schema _schema + + cdef void init(self, const shared_ptr[CRecordBatch]& table) + + +cdef class Buffer(_Weakrefable): + cdef: + shared_ptr[CBuffer] buffer + Py_ssize_t shape[1] + Py_ssize_t strides[1] + + cdef void init(self, const shared_ptr[CBuffer]& buffer) + cdef getitem(self, int64_t i) + + +cdef class ResizableBuffer(Buffer): + + cdef void init_rz(self, const shared_ptr[CResizableBuffer]& buffer) + + +cdef class NativeFile(_Weakrefable): + cdef: + shared_ptr[CInputStream] input_stream + shared_ptr[CRandomAccessFile] random_access + shared_ptr[COutputStream] output_stream + bint is_readable + bint is_writable + bint is_seekable + bint own_file + + # By implementing these "virtual" functions (all functions in Cython + # extension classes are technically virtual in the C++ sense) we can expose + # the arrow::io abstract file interfaces to other components throughout the + # suite of Arrow C++ libraries + cdef set_random_access_file(self, shared_ptr[CRandomAccessFile] handle) + cdef set_input_stream(self, shared_ptr[CInputStream] handle) + cdef set_output_stream(self, shared_ptr[COutputStream] handle) + + cdef shared_ptr[CRandomAccessFile] get_random_access_file(self) except * + cdef shared_ptr[CInputStream] get_input_stream(self) except * + cdef shared_ptr[COutputStream] get_output_stream(self) except * + + +cdef class BufferedInputStream(NativeFile): + pass + + +cdef class BufferedOutputStream(NativeFile): + pass + + +cdef class CompressedInputStream(NativeFile): + pass + + +cdef class CompressedOutputStream(NativeFile): + pass + + +cdef class _CRecordBatchWriter(_Weakrefable): + cdef: + shared_ptr[CRecordBatchWriter] writer + + +cdef class RecordBatchReader(_Weakrefable): + cdef: + shared_ptr[CRecordBatchReader] reader + + +cdef class Codec(_Weakrefable): + cdef: + shared_ptr[CCodec] wrapped + + cdef inline CCodec* unwrap(self) nogil + + +# This class is only used internally for now +cdef class StopToken: + cdef: + CStopToken stop_token + + cdef void init(self, CStopToken stop_token) + + +cdef get_input_stream(object source, c_bool use_memory_map, + shared_ptr[CInputStream]* reader) +cdef get_reader(object source, c_bool use_memory_map, + shared_ptr[CRandomAccessFile]* reader) +cdef get_writer(object source, shared_ptr[COutputStream]* writer) +cdef NativeFile get_native_file(object source, c_bool use_memory_map) + +cdef shared_ptr[CInputStream] native_transcoding_input_stream( + shared_ptr[CInputStream] stream, src_encoding, + dest_encoding) except * + +# Default is allow_none=False +cpdef DataType ensure_type(object type, bint allow_none=*) + +cdef timeunit_to_string(TimeUnit unit) +cdef TimeUnit string_to_timeunit(unit) except * + +# Exceptions may be raised when converting dict values, so need to +# check exception state on return +cdef shared_ptr[const CKeyValueMetadata] pyarrow_unwrap_metadata( + object meta) except * +cdef object pyarrow_wrap_metadata( + const shared_ptr[const CKeyValueMetadata]& meta) + +# +# Public Cython API for 3rd party code +# +# If you add functions to this list, please also update +# `cpp/src/arrow/python/pyarrow.{h, cc}` +# + +# Wrapping C++ -> Python + +cdef public object pyarrow_wrap_buffer(const shared_ptr[CBuffer]& buf) +cdef public object pyarrow_wrap_resizable_buffer( + const shared_ptr[CResizableBuffer]& buf) + +cdef public object pyarrow_wrap_data_type(const shared_ptr[CDataType]& type) +cdef public object pyarrow_wrap_field(const shared_ptr[CField]& field) +cdef public object pyarrow_wrap_schema(const shared_ptr[CSchema]& type) + +cdef public object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar) + +cdef public object pyarrow_wrap_array(const shared_ptr[CArray]& sp_array) +cdef public object pyarrow_wrap_chunked_array( + const shared_ptr[CChunkedArray]& sp_array) + +cdef public object pyarrow_wrap_sparse_coo_tensor( + const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor) +cdef public object pyarrow_wrap_sparse_csc_matrix( + const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor) +cdef public object pyarrow_wrap_sparse_csf_tensor( + const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor) +cdef public object pyarrow_wrap_sparse_csr_matrix( + const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor) +cdef public object pyarrow_wrap_tensor(const shared_ptr[CTensor]& sp_tensor) + +cdef public object pyarrow_wrap_batch(const shared_ptr[CRecordBatch]& cbatch) +cdef public object pyarrow_wrap_table(const shared_ptr[CTable]& ctable) + +# Unwrapping Python -> C++ + +cdef public shared_ptr[CBuffer] pyarrow_unwrap_buffer(object buffer) + +cdef public shared_ptr[CDataType] pyarrow_unwrap_data_type(object data_type) +cdef public shared_ptr[CField] pyarrow_unwrap_field(object field) +cdef public shared_ptr[CSchema] pyarrow_unwrap_schema(object schema) + +cdef public shared_ptr[CScalar] pyarrow_unwrap_scalar(object scalar) + +cdef public shared_ptr[CArray] pyarrow_unwrap_array(object array) +cdef public shared_ptr[CChunkedArray] pyarrow_unwrap_chunked_array( + object array) + +cdef public shared_ptr[CSparseCOOTensor] pyarrow_unwrap_sparse_coo_tensor( + object sparse_tensor) +cdef public shared_ptr[CSparseCSCMatrix] pyarrow_unwrap_sparse_csc_matrix( + object sparse_tensor) +cdef public shared_ptr[CSparseCSFTensor] pyarrow_unwrap_sparse_csf_tensor( + object sparse_tensor) +cdef public shared_ptr[CSparseCSRMatrix] pyarrow_unwrap_sparse_csr_matrix( + object sparse_tensor) +cdef public shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor) + +cdef public shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch) +cdef public shared_ptr[CTable] pyarrow_unwrap_table(object table) diff --git a/src/arrow/python/pyarrow/lib.pyx b/src/arrow/python/pyarrow/lib.pyx new file mode 100644 index 000000000..0c9cbcc5b --- /dev/null +++ b/src/arrow/python/pyarrow/lib.pyx @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile = False +# cython: nonecheck = True +# distutils: language = c++ + +import datetime +import decimal as _pydecimal +import numpy as np +import os +import sys + +from cython.operator cimport dereference as deref +from pyarrow.includes.libarrow cimport * +from pyarrow.includes.common cimport PyObject_to_object +cimport pyarrow.includes.libarrow as libarrow +cimport cpython as cp + +# Initialize NumPy C API +arrow_init_numpy() +# Initialize PyArrow C++ API +# (used from some of our C++ code, see e.g. ARROW-5260) +import_pyarrow() + + +MonthDayNano = NewMonthDayNanoTupleType() + + +def cpu_count(): + """ + Return the number of threads to use in parallel operations. + + The number of threads is determined at startup by inspecting the + ``OMP_NUM_THREADS`` and ``OMP_THREAD_LIMIT`` environment variables. + If neither is present, it will default to the number of hardware threads + on the system. It can be modified at runtime by calling + :func:`set_cpu_count()`. + + See Also + -------- + set_cpu_count : Modify the size of this pool. + io_thread_count : The analogous function for the I/O thread pool. + """ + return GetCpuThreadPoolCapacity() + + +def set_cpu_count(int count): + """ + Set the number of threads to use in parallel operations. + + Parameters + ---------- + count : int + The number of concurrent threads that should be used. + + See Also + -------- + cpu_count : Get the size of this pool. + set_io_thread_count : The analogous function for the I/O thread pool. + """ + if count < 1: + raise ValueError("CPU count must be strictly positive") + check_status(SetCpuThreadPoolCapacity(count)) + + +Type_NA = _Type_NA +Type_BOOL = _Type_BOOL +Type_UINT8 = _Type_UINT8 +Type_INT8 = _Type_INT8 +Type_UINT16 = _Type_UINT16 +Type_INT16 = _Type_INT16 +Type_UINT32 = _Type_UINT32 +Type_INT32 = _Type_INT32 +Type_UINT64 = _Type_UINT64 +Type_INT64 = _Type_INT64 +Type_HALF_FLOAT = _Type_HALF_FLOAT +Type_FLOAT = _Type_FLOAT +Type_DOUBLE = _Type_DOUBLE +Type_DECIMAL128 = _Type_DECIMAL128 +Type_DECIMAL256 = _Type_DECIMAL256 +Type_DATE32 = _Type_DATE32 +Type_DATE64 = _Type_DATE64 +Type_TIMESTAMP = _Type_TIMESTAMP +Type_TIME32 = _Type_TIME32 +Type_TIME64 = _Type_TIME64 +Type_DURATION = _Type_DURATION +Type_INTERVAL_MONTH_DAY_NANO = _Type_INTERVAL_MONTH_DAY_NANO +Type_BINARY = _Type_BINARY +Type_STRING = _Type_STRING +Type_LARGE_BINARY = _Type_LARGE_BINARY +Type_LARGE_STRING = _Type_LARGE_STRING +Type_FIXED_SIZE_BINARY = _Type_FIXED_SIZE_BINARY +Type_LIST = _Type_LIST +Type_LARGE_LIST = _Type_LARGE_LIST +Type_MAP = _Type_MAP +Type_FIXED_SIZE_LIST = _Type_FIXED_SIZE_LIST +Type_STRUCT = _Type_STRUCT +Type_SPARSE_UNION = _Type_SPARSE_UNION +Type_DENSE_UNION = _Type_DENSE_UNION +Type_DICTIONARY = _Type_DICTIONARY + +UnionMode_SPARSE = _UnionMode_SPARSE +UnionMode_DENSE = _UnionMode_DENSE + + +def _pc(): + import pyarrow.compute as pc + return pc + + +# Assorted compatibility helpers +include "compat.pxi" + +# Exception types and Status handling +include "error.pxi" + +# Configuration information +include "config.pxi" + +# pandas API shim +include "pandas-shim.pxi" + +# Memory pools and allocation +include "memory.pxi" + +# DataType, Field, Schema +include "types.pxi" + +# Array scalar values +include "scalar.pxi" + +# Array types +include "array.pxi" + +# Builders +include "builder.pxi" + +# Column, Table, Record Batch +include "table.pxi" + +# Tensors +include "tensor.pxi" + +# File IO +include "io.pxi" + +# IPC / Messaging +include "ipc.pxi" + +# Python serialization +include "serialization.pxi" + +# Micro-benchmark routines +include "benchmark.pxi" + +# Public API +include "public-api.pxi" diff --git a/src/arrow/python/pyarrow/memory.pxi b/src/arrow/python/pyarrow/memory.pxi new file mode 100644 index 000000000..8ccb35058 --- /dev/null +++ b/src/arrow/python/pyarrow/memory.pxi @@ -0,0 +1,249 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True + + +cdef class MemoryPool(_Weakrefable): + """ + Base class for memory allocation. + + Besides tracking its number of allocated bytes, a memory pool also + takes care of the required 64-byte alignment for Arrow data. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, " + "use pyarrow.*_memory_pool instead." + .format(self.__class__.__name__)) + + cdef void init(self, CMemoryPool* pool): + self.pool = pool + + def release_unused(self): + """ + Attempt to return to the OS any memory being held onto by the pool. + + This function should not be called except potentially for + benchmarking or debugging as it could be expensive and detrimental to + performance. + + This is best effort and may not have any effect on some memory pools + or in some situations (e.g. fragmentation). + """ + cdef CMemoryPool* pool = c_get_memory_pool() + with nogil: + pool.ReleaseUnused() + + def bytes_allocated(self): + """ + Return the number of bytes that are currently allocated from this + memory pool. + """ + return self.pool.bytes_allocated() + + def max_memory(self): + """ + Return the peak memory allocation in this memory pool. + This can be an approximate number in multi-threaded applications. + + None is returned if the pool implementation doesn't know how to + compute this number. + """ + ret = self.pool.max_memory() + return ret if ret >= 0 else None + + @property + def backend_name(self): + """ + The name of the backend used by this MemoryPool (e.g. "jemalloc"). + """ + return frombytes(self.pool.backend_name()) + + +cdef CMemoryPool* maybe_unbox_memory_pool(MemoryPool memory_pool): + if memory_pool is None: + return c_get_memory_pool() + else: + return memory_pool.pool + + +cdef class LoggingMemoryPool(MemoryPool): + cdef: + unique_ptr[CLoggingMemoryPool] logging_pool + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, " + "use pyarrow.logging_memory_pool instead." + .format(self.__class__.__name__)) + + +cdef class ProxyMemoryPool(MemoryPool): + """ + Memory pool implementation that tracks the number of bytes and + maximum memory allocated through its direct calls, while redirecting + to another memory pool. + """ + cdef: + unique_ptr[CProxyMemoryPool] proxy_pool + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, " + "use pyarrow.proxy_memory_pool instead." + .format(self.__class__.__name__)) + + +def default_memory_pool(): + """ + Return the process-global memory pool. + """ + cdef: + MemoryPool pool = MemoryPool.__new__(MemoryPool) + pool.init(c_get_memory_pool()) + return pool + + +def proxy_memory_pool(MemoryPool parent): + """ + Create and return a MemoryPool instance that redirects to the + *parent*, but with separate allocation statistics. + + Parameters + ---------- + parent : MemoryPool + The real memory pool that should be used for allocations. + """ + cdef ProxyMemoryPool out = ProxyMemoryPool.__new__(ProxyMemoryPool) + out.proxy_pool.reset(new CProxyMemoryPool(parent.pool)) + out.init(out.proxy_pool.get()) + return out + + +def logging_memory_pool(MemoryPool parent): + """ + Create and return a MemoryPool instance that redirects to the + *parent*, but also dumps allocation logs on stderr. + + Parameters + ---------- + parent : MemoryPool + The real memory pool that should be used for allocations. + """ + cdef LoggingMemoryPool out = LoggingMemoryPool.__new__( + LoggingMemoryPool, parent) + out.logging_pool.reset(new CLoggingMemoryPool(parent.pool)) + out.init(out.logging_pool.get()) + return out + + +def system_memory_pool(): + """ + Return a memory pool based on the C malloc heap. + """ + cdef: + MemoryPool pool = MemoryPool.__new__(MemoryPool) + pool.init(c_system_memory_pool()) + return pool + + +def jemalloc_memory_pool(): + """ + Return a memory pool based on the jemalloc heap. + + NotImplementedError is raised if jemalloc support is not enabled. + """ + cdef: + CMemoryPool* c_pool + MemoryPool pool = MemoryPool.__new__(MemoryPool) + check_status(c_jemalloc_memory_pool(&c_pool)) + pool.init(c_pool) + return pool + + +def mimalloc_memory_pool(): + """ + Return a memory pool based on the mimalloc heap. + + NotImplementedError is raised if mimalloc support is not enabled. + """ + cdef: + CMemoryPool* c_pool + MemoryPool pool = MemoryPool.__new__(MemoryPool) + check_status(c_mimalloc_memory_pool(&c_pool)) + pool.init(c_pool) + return pool + + +def set_memory_pool(MemoryPool pool): + """ + Set the default memory pool. + + Parameters + ---------- + pool : MemoryPool + The memory pool that should be used by default. + """ + c_set_default_memory_pool(pool.pool) + + +cdef MemoryPool _default_memory_pool = default_memory_pool() +cdef LoggingMemoryPool _logging_memory_pool = logging_memory_pool( + _default_memory_pool) + + +def log_memory_allocations(enable=True): + """ + Enable or disable memory allocator logging for debugging purposes + + Parameters + ---------- + enable : bool, default True + Pass False to disable logging + """ + if enable: + set_memory_pool(_logging_memory_pool) + else: + set_memory_pool(_default_memory_pool) + + +def total_allocated_bytes(): + """ + Return the currently allocated bytes from the default memory pool. + Other memory pools may not be accounted for. + """ + cdef CMemoryPool* pool = c_get_memory_pool() + return pool.bytes_allocated() + + +def jemalloc_set_decay_ms(decay_ms): + """ + Set arenas.dirty_decay_ms and arenas.muzzy_decay_ms to indicated number of + milliseconds. A value of 0 (the default) results in dirty / muzzy memory + pages being released right away to the OS, while a higher value will result + in a time-based decay. See the jemalloc docs for more information + + It's best to set this at the start of your application. + + Parameters + ---------- + decay_ms : int + Number of milliseconds to set for jemalloc decay conf parameters. Note + that this change will only affect future memory arenas + """ + check_status(c_jemalloc_set_decay_ms(decay_ms)) diff --git a/src/arrow/python/pyarrow/orc.py b/src/arrow/python/pyarrow/orc.py new file mode 100644 index 000000000..87f805886 --- /dev/null +++ b/src/arrow/python/pyarrow/orc.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from numbers import Integral +import warnings + +from pyarrow.lib import Table +import pyarrow._orc as _orc + + +class ORCFile: + """ + Reader interface for a single ORC file + + Parameters + ---------- + source : str or pyarrow.io.NativeFile + Readable source. For passing Python file objects or byte buffers, + see pyarrow.io.PythonFileInterface or pyarrow.io.BufferReader. + """ + + def __init__(self, source): + self.reader = _orc.ORCReader() + self.reader.open(source) + + @property + def metadata(self): + """The file metadata, as an arrow KeyValueMetadata""" + return self.reader.metadata() + + @property + def schema(self): + """The file schema, as an arrow schema""" + return self.reader.schema() + + @property + def nrows(self): + """The number of rows in the file""" + return self.reader.nrows() + + @property + def nstripes(self): + """The number of stripes in the file""" + return self.reader.nstripes() + + def _select_names(self, columns=None): + if columns is None: + return None + + schema = self.schema + names = [] + for col in columns: + if isinstance(col, Integral): + col = int(col) + if 0 <= col < len(schema): + col = schema[col].name + names.append(col) + else: + raise ValueError("Column indices must be in 0 <= ind < %d," + " got %d" % (len(schema), col)) + else: + return columns + + return names + + def read_stripe(self, n, columns=None): + """Read a single stripe from the file. + + Parameters + ---------- + n : int + The stripe index + columns : list + If not None, only these columns will be read from the stripe. A + column name may be a prefix of a nested field, e.g. 'a' will select + 'a.b', 'a.c', and 'a.d.e' + + Returns + ------- + pyarrow.lib.RecordBatch + Content of the stripe as a RecordBatch. + """ + columns = self._select_names(columns) + return self.reader.read_stripe(n, columns=columns) + + def read(self, columns=None): + """Read the whole file. + + Parameters + ---------- + columns : list + If not None, only these columns will be read from the file. A + column name may be a prefix of a nested field, e.g. 'a' will select + 'a.b', 'a.c', and 'a.d.e' + + Returns + ------- + pyarrow.lib.Table + Content of the file as a Table. + """ + columns = self._select_names(columns) + return self.reader.read(columns=columns) + + +class ORCWriter: + """ + Writer interface for a single ORC file + + Parameters + ---------- + where : str or pyarrow.io.NativeFile + Writable target. For passing Python file objects or byte buffers, + see pyarrow.io.PythonFileInterface, pyarrow.io.BufferOutputStream + or pyarrow.io.FixedSizeBufferWriter. + """ + + def __init__(self, where): + self.writer = _orc.ORCWriter() + self.writer.open(where) + + def write(self, table): + """ + Write the table into an ORC file. The schema of the table must + be equal to the schema used when opening the ORC file. + + Parameters + ---------- + schema : pyarrow.lib.Table + The table to be written into the ORC file + """ + self.writer.write(table) + + def close(self): + """ + Close the ORC file + """ + self.writer.close() + + +def write_table(table, where): + """ + Write a table into an ORC file + + Parameters + ---------- + table : pyarrow.lib.Table + The table to be written into the ORC file + where : str or pyarrow.io.NativeFile + Writable target. For passing Python file objects or byte buffers, + see pyarrow.io.PythonFileInterface, pyarrow.io.BufferOutputStream + or pyarrow.io.FixedSizeBufferWriter. + """ + if isinstance(where, Table): + warnings.warn( + "The order of the arguments has changed. Pass as " + "'write_table(table, where)' instead. The old order will raise " + "an error in the future.", FutureWarning, stacklevel=2 + ) + table, where = where, table + writer = ORCWriter(where) + writer.write(table) + writer.close() diff --git a/src/arrow/python/pyarrow/pandas-shim.pxi b/src/arrow/python/pyarrow/pandas-shim.pxi new file mode 100644 index 000000000..0e7cfe937 --- /dev/null +++ b/src/arrow/python/pyarrow/pandas-shim.pxi @@ -0,0 +1,254 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pandas lazy-loading API shim that reduces API call and import overhead + +import warnings + + +cdef class _PandasAPIShim(object): + """ + Lazy pandas importer that isolates usages of pandas APIs and avoids + importing pandas until it's actually needed + """ + cdef: + bint _tried_importing_pandas + bint _have_pandas + + cdef readonly: + object _loose_version, _version + object _pd, _types_api, _compat_module + object _data_frame, _index, _series, _categorical_type + object _datetimetz_type, _extension_array, _extension_dtype + object _array_like_types, _is_extension_array_dtype + bint has_sparse + bint _pd024 + + def __init__(self): + self._tried_importing_pandas = False + self._have_pandas = 0 + + cdef _import_pandas(self, bint raise_): + try: + import pandas as pd + import pyarrow.pandas_compat as pdcompat + except ImportError: + self._have_pandas = False + if raise_: + raise + else: + return + + from pyarrow.vendored.version import Version + + self._pd = pd + self._version = pd.__version__ + self._loose_version = Version(pd.__version__) + + if self._loose_version < Version('0.23.0'): + self._have_pandas = False + if raise_: + raise ImportError( + "pyarrow requires pandas 0.23.0 or above, pandas {} is " + "installed".format(self._version) + ) + else: + warnings.warn( + "pyarrow requires pandas 0.23.0 or above, pandas {} is " + "installed. Therefore, pandas-specific integration is not " + "used.".format(self._version), stacklevel=2) + return + + self._compat_module = pdcompat + self._data_frame = pd.DataFrame + self._index = pd.Index + self._categorical_type = pd.Categorical + self._series = pd.Series + self._extension_array = pd.api.extensions.ExtensionArray + self._array_like_types = ( + self._series, self._index, self._categorical_type, + self._extension_array) + self._extension_dtype = pd.api.extensions.ExtensionDtype + if self._loose_version >= Version('0.24.0'): + self._is_extension_array_dtype = \ + pd.api.types.is_extension_array_dtype + else: + self._is_extension_array_dtype = None + + self._types_api = pd.api.types + self._datetimetz_type = pd.api.types.DatetimeTZDtype + self._have_pandas = True + + if self._loose_version > Version('0.25'): + self.has_sparse = False + else: + self.has_sparse = True + + self._pd024 = self._loose_version >= Version('0.24') + + cdef inline _check_import(self, bint raise_=True): + if self._tried_importing_pandas: + if not self._have_pandas and raise_: + self._import_pandas(raise_) + return + + self._tried_importing_pandas = True + self._import_pandas(raise_) + + def series(self, *args, **kwargs): + self._check_import() + return self._series(*args, **kwargs) + + def data_frame(self, *args, **kwargs): + self._check_import() + return self._data_frame(*args, **kwargs) + + cdef inline bint _have_pandas_internal(self): + if not self._tried_importing_pandas: + self._check_import(raise_=False) + return self._have_pandas + + @property + def have_pandas(self): + return self._have_pandas_internal() + + @property + def compat(self): + self._check_import() + return self._compat_module + + @property + def pd(self): + self._check_import() + return self._pd + + cpdef infer_dtype(self, obj): + self._check_import() + try: + return self._types_api.infer_dtype(obj, skipna=False) + except AttributeError: + return self._pd.lib.infer_dtype(obj) + + cpdef pandas_dtype(self, dtype): + self._check_import() + try: + return self._types_api.pandas_dtype(dtype) + except AttributeError: + return None + + @property + def loose_version(self): + self._check_import() + return self._loose_version + + @property + def version(self): + self._check_import() + return self._version + + @property + def categorical_type(self): + self._check_import() + return self._categorical_type + + @property + def datetimetz_type(self): + self._check_import() + return self._datetimetz_type + + @property + def extension_dtype(self): + self._check_import() + return self._extension_dtype + + cpdef is_array_like(self, obj): + self._check_import() + return isinstance(obj, self._array_like_types) + + cpdef is_categorical(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._categorical_type) + else: + return False + + cpdef is_datetimetz(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._datetimetz_type) + else: + return False + + cpdef is_extension_array_dtype(self, obj): + self._check_import() + if self._is_extension_array_dtype: + return self._is_extension_array_dtype(obj) + else: + return False + + cpdef is_sparse(self, obj): + if self._have_pandas_internal(): + return self._types_api.is_sparse(obj) + else: + return False + + cpdef is_data_frame(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._data_frame) + else: + return False + + cpdef is_series(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._series) + else: + return False + + cpdef is_index(self, obj): + if self._have_pandas_internal(): + return isinstance(obj, self._index) + else: + return False + + cpdef get_values(self, obj): + """ + Get the underlying array values of a pandas Series or Index in the + format (np.ndarray or pandas ExtensionArray) as we need them. + + Assumes obj is a pandas Series or Index. + """ + self._check_import() + if isinstance(obj.dtype, (self.pd.api.types.IntervalDtype, + self.pd.api.types.PeriodDtype)): + if self._pd024: + # only since pandas 0.24, interval and period are stored as + # such in Series + return obj.array + return obj.values + + def assert_frame_equal(self, *args, **kwargs): + self._check_import() + return self._pd.util.testing.assert_frame_equal + + def get_rangeindex_attribute(self, level, name): + # public start/stop/step attributes added in pandas 0.25.0 + self._check_import() + if hasattr(level, name): + return getattr(level, name) + return getattr(level, '_' + name) + + +cdef _PandasAPIShim pandas_api = _PandasAPIShim() +_pandas_api = pandas_api diff --git a/src/arrow/python/pyarrow/pandas_compat.py b/src/arrow/python/pyarrow/pandas_compat.py new file mode 100644 index 000000000..e4b13175f --- /dev/null +++ b/src/arrow/python/pyarrow/pandas_compat.py @@ -0,0 +1,1226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import ast +from collections.abc import Sequence +from concurrent import futures +# import threading submodule upfront to avoid partially initialized +# module bug (ARROW-11983) +import concurrent.futures.thread # noqa +from copy import deepcopy +from itertools import zip_longest +import json +import operator +import re +import warnings + +import numpy as np + +import pyarrow as pa +from pyarrow.lib import _pandas_api, builtin_pickle, frombytes # noqa + + +_logical_type_map = {} + + +def get_logical_type_map(): + global _logical_type_map + + if not _logical_type_map: + _logical_type_map.update({ + pa.lib.Type_NA: 'empty', + pa.lib.Type_BOOL: 'bool', + pa.lib.Type_INT8: 'int8', + pa.lib.Type_INT16: 'int16', + pa.lib.Type_INT32: 'int32', + pa.lib.Type_INT64: 'int64', + pa.lib.Type_UINT8: 'uint8', + pa.lib.Type_UINT16: 'uint16', + pa.lib.Type_UINT32: 'uint32', + pa.lib.Type_UINT64: 'uint64', + pa.lib.Type_HALF_FLOAT: 'float16', + pa.lib.Type_FLOAT: 'float32', + pa.lib.Type_DOUBLE: 'float64', + pa.lib.Type_DATE32: 'date', + pa.lib.Type_DATE64: 'date', + pa.lib.Type_TIME32: 'time', + pa.lib.Type_TIME64: 'time', + pa.lib.Type_BINARY: 'bytes', + pa.lib.Type_FIXED_SIZE_BINARY: 'bytes', + pa.lib.Type_STRING: 'unicode', + }) + return _logical_type_map + + +def get_logical_type(arrow_type): + logical_type_map = get_logical_type_map() + + try: + return logical_type_map[arrow_type.id] + except KeyError: + if isinstance(arrow_type, pa.lib.DictionaryType): + return 'categorical' + elif isinstance(arrow_type, pa.lib.ListType): + return 'list[{}]'.format(get_logical_type(arrow_type.value_type)) + elif isinstance(arrow_type, pa.lib.TimestampType): + return 'datetimetz' if arrow_type.tz is not None else 'datetime' + elif isinstance(arrow_type, pa.lib.Decimal128Type): + return 'decimal' + return 'object' + + +_numpy_logical_type_map = { + np.bool_: 'bool', + np.int8: 'int8', + np.int16: 'int16', + np.int32: 'int32', + np.int64: 'int64', + np.uint8: 'uint8', + np.uint16: 'uint16', + np.uint32: 'uint32', + np.uint64: 'uint64', + np.float32: 'float32', + np.float64: 'float64', + 'datetime64[D]': 'date', + np.unicode_: 'string', + np.bytes_: 'bytes', +} + + +def get_logical_type_from_numpy(pandas_collection): + try: + return _numpy_logical_type_map[pandas_collection.dtype.type] + except KeyError: + if hasattr(pandas_collection.dtype, 'tz'): + return 'datetimetz' + # See https://github.com/pandas-dev/pandas/issues/24739 + if str(pandas_collection.dtype) == 'datetime64[ns]': + return 'datetime64[ns]' + result = _pandas_api.infer_dtype(pandas_collection) + if result == 'string': + return 'unicode' + return result + + +def get_extension_dtype_info(column): + dtype = column.dtype + if str(dtype) == 'category': + cats = getattr(column, 'cat', column) + assert cats is not None + metadata = { + 'num_categories': len(cats.categories), + 'ordered': cats.ordered, + } + physical_dtype = str(cats.codes.dtype) + elif hasattr(dtype, 'tz'): + metadata = {'timezone': pa.lib.tzinfo_to_string(dtype.tz)} + physical_dtype = 'datetime64[ns]' + else: + metadata = None + physical_dtype = str(dtype) + return physical_dtype, metadata + + +def get_column_metadata(column, name, arrow_type, field_name): + """Construct the metadata for a given column + + Parameters + ---------- + column : pandas.Series or pandas.Index + name : str + arrow_type : pyarrow.DataType + field_name : str + Equivalent to `name` when `column` is a `Series`, otherwise if `column` + is a pandas Index then `field_name` will not be the same as `name`. + This is the name of the field in the arrow Table's schema. + + Returns + ------- + dict + """ + logical_type = get_logical_type(arrow_type) + + string_dtype, extra_metadata = get_extension_dtype_info(column) + if logical_type == 'decimal': + extra_metadata = { + 'precision': arrow_type.precision, + 'scale': arrow_type.scale, + } + string_dtype = 'object' + + if name is not None and not isinstance(name, str): + raise TypeError( + 'Column name must be a string. Got column {} of type {}'.format( + name, type(name).__name__ + ) + ) + + assert field_name is None or isinstance(field_name, str), \ + str(type(field_name)) + return { + 'name': name, + 'field_name': 'None' if field_name is None else field_name, + 'pandas_type': logical_type, + 'numpy_type': string_dtype, + 'metadata': extra_metadata, + } + + +def construct_metadata(columns_to_convert, df, column_names, index_levels, + index_descriptors, preserve_index, types): + """Returns a dictionary containing enough metadata to reconstruct a pandas + DataFrame as an Arrow Table, including index columns. + + Parameters + ---------- + columns_to_convert : list[pd.Series] + df : pandas.DataFrame + index_levels : List[pd.Index] + index_descriptors : List[Dict] + preserve_index : bool + types : List[pyarrow.DataType] + + Returns + ------- + dict + """ + num_serialized_index_levels = len([descr for descr in index_descriptors + if not isinstance(descr, dict)]) + # Use ntypes instead of Python shorthand notation [:-len(x)] as [:-0] + # behaves differently to what we want. + ntypes = len(types) + df_types = types[:ntypes - num_serialized_index_levels] + index_types = types[ntypes - num_serialized_index_levels:] + + column_metadata = [] + for col, sanitized_name, arrow_type in zip(columns_to_convert, + column_names, df_types): + metadata = get_column_metadata(col, name=sanitized_name, + arrow_type=arrow_type, + field_name=sanitized_name) + column_metadata.append(metadata) + + index_column_metadata = [] + if preserve_index is not False: + for level, arrow_type, descriptor in zip(index_levels, index_types, + index_descriptors): + if isinstance(descriptor, dict): + # The index is represented in a non-serialized fashion, + # e.g. RangeIndex + continue + metadata = get_column_metadata(level, name=level.name, + arrow_type=arrow_type, + field_name=descriptor) + index_column_metadata.append(metadata) + + column_indexes = [] + + levels = getattr(df.columns, 'levels', [df.columns]) + names = getattr(df.columns, 'names', [df.columns.name]) + for level, name in zip(levels, names): + metadata = _get_simple_index_descriptor(level, name) + column_indexes.append(metadata) + else: + index_descriptors = index_column_metadata = column_indexes = [] + + return { + b'pandas': json.dumps({ + 'index_columns': index_descriptors, + 'column_indexes': column_indexes, + 'columns': column_metadata + index_column_metadata, + 'creator': { + 'library': 'pyarrow', + 'version': pa.__version__ + }, + 'pandas_version': _pandas_api.version + }).encode('utf8') + } + + +def _get_simple_index_descriptor(level, name): + string_dtype, extra_metadata = get_extension_dtype_info(level) + pandas_type = get_logical_type_from_numpy(level) + if 'mixed' in pandas_type: + warnings.warn( + "The DataFrame has column names of mixed type. They will be " + "converted to strings and not roundtrip correctly.", + UserWarning, stacklevel=4) + if pandas_type == 'unicode': + assert not extra_metadata + extra_metadata = {'encoding': 'UTF-8'} + return { + 'name': name, + 'field_name': name, + 'pandas_type': pandas_type, + 'numpy_type': string_dtype, + 'metadata': extra_metadata, + } + + +def _column_name_to_strings(name): + """Convert a column name (or level) to either a string or a recursive + collection of strings. + + Parameters + ---------- + name : str or tuple + + Returns + ------- + value : str or tuple + + Examples + -------- + >>> name = 'foo' + >>> _column_name_to_strings(name) + 'foo' + >>> name = ('foo', 'bar') + >>> _column_name_to_strings(name) + ('foo', 'bar') + >>> import pandas as pd + >>> name = (1, pd.Timestamp('2017-02-01 00:00:00')) + >>> _column_name_to_strings(name) + ('1', '2017-02-01 00:00:00') + """ + if isinstance(name, str): + return name + elif isinstance(name, bytes): + # XXX: should we assume that bytes in Python 3 are UTF-8? + return name.decode('utf8') + elif isinstance(name, tuple): + return str(tuple(map(_column_name_to_strings, name))) + elif isinstance(name, Sequence): + raise TypeError("Unsupported type for MultiIndex level") + elif name is None: + return None + return str(name) + + +def _index_level_name(index, i, column_names): + """Return the name of an index level or a default name if `index.name` is + None or is already a column name. + + Parameters + ---------- + index : pandas.Index + i : int + + Returns + ------- + name : str + """ + if index.name is not None and index.name not in column_names: + return index.name + else: + return '__index_level_{:d}__'.format(i) + + +def _get_columns_to_convert(df, schema, preserve_index, columns): + columns = _resolve_columns_of_interest(df, schema, columns) + + if not df.columns.is_unique: + raise ValueError( + 'Duplicate column names found: {}'.format(list(df.columns)) + ) + + if schema is not None: + return _get_columns_to_convert_given_schema(df, schema, preserve_index) + + column_names = [] + + index_levels = ( + _get_index_level_values(df.index) if preserve_index is not False + else [] + ) + + columns_to_convert = [] + convert_fields = [] + + for name in columns: + col = df[name] + name = _column_name_to_strings(name) + + if _pandas_api.is_sparse(col): + raise TypeError( + "Sparse pandas data (column {}) not supported.".format(name)) + + columns_to_convert.append(col) + convert_fields.append(None) + column_names.append(name) + + index_descriptors = [] + index_column_names = [] + for i, index_level in enumerate(index_levels): + name = _index_level_name(index_level, i, column_names) + if (isinstance(index_level, _pandas_api.pd.RangeIndex) and + preserve_index is None): + descr = _get_range_index_descriptor(index_level) + else: + columns_to_convert.append(index_level) + convert_fields.append(None) + descr = name + index_column_names.append(name) + index_descriptors.append(descr) + + all_names = column_names + index_column_names + + # all_names : all of the columns in the resulting table including the data + # columns and serialized index columns + # column_names : the names of the data columns + # index_column_names : the names of the serialized index columns + # index_descriptors : descriptions of each index to be used for + # reconstruction + # index_levels : the extracted index level values + # columns_to_convert : assembled raw data (both data columns and indexes) + # to be converted to Arrow format + # columns_fields : specified column to use for coercion / casting + # during serialization, if a Schema was provided + return (all_names, column_names, index_column_names, index_descriptors, + index_levels, columns_to_convert, convert_fields) + + +def _get_columns_to_convert_given_schema(df, schema, preserve_index): + """ + Specialized version of _get_columns_to_convert in case a Schema is + specified. + In that case, the Schema is used as the single point of truth for the + table structure (types, which columns are included, order of columns, ...). + """ + column_names = [] + columns_to_convert = [] + convert_fields = [] + index_descriptors = [] + index_column_names = [] + index_levels = [] + + for name in schema.names: + try: + col = df[name] + is_index = False + except KeyError: + try: + col = _get_index_level(df, name) + except (KeyError, IndexError): + # name not found as index level + raise KeyError( + "name '{}' present in the specified schema is not found " + "in the columns or index".format(name)) + if preserve_index is False: + raise ValueError( + "name '{}' present in the specified schema corresponds " + "to the index, but 'preserve_index=False' was " + "specified".format(name)) + elif (preserve_index is None and + isinstance(col, _pandas_api.pd.RangeIndex)): + raise ValueError( + "name '{}' is present in the schema, but it is a " + "RangeIndex which will not be converted as a column " + "in the Table, but saved as metadata-only not in " + "columns. Specify 'preserve_index=True' to force it " + "being added as a column, or remove it from the " + "specified schema".format(name)) + is_index = True + + name = _column_name_to_strings(name) + + if _pandas_api.is_sparse(col): + raise TypeError( + "Sparse pandas data (column {}) not supported.".format(name)) + + field = schema.field(name) + columns_to_convert.append(col) + convert_fields.append(field) + column_names.append(name) + + if is_index: + index_column_names.append(name) + index_descriptors.append(name) + index_levels.append(col) + + all_names = column_names + index_column_names + + return (all_names, column_names, index_column_names, index_descriptors, + index_levels, columns_to_convert, convert_fields) + + +def _get_index_level(df, name): + """ + Get the index level of a DataFrame given 'name' (column name in an arrow + Schema). + """ + key = name + if name not in df.index.names and _is_generated_index_name(name): + # we know we have an autogenerated name => extract number and get + # the index level positionally + key = int(name[len("__index_level_"):-2]) + return df.index.get_level_values(key) + + +def _level_name(name): + # preserve type when default serializable, otherwise str it + try: + json.dumps(name) + return name + except TypeError: + return str(name) + + +def _get_range_index_descriptor(level): + # public start/stop/step attributes added in pandas 0.25.0 + return { + 'kind': 'range', + 'name': _level_name(level.name), + 'start': _pandas_api.get_rangeindex_attribute(level, 'start'), + 'stop': _pandas_api.get_rangeindex_attribute(level, 'stop'), + 'step': _pandas_api.get_rangeindex_attribute(level, 'step') + } + + +def _get_index_level_values(index): + n = len(getattr(index, 'levels', [index])) + return [index.get_level_values(i) for i in range(n)] + + +def _resolve_columns_of_interest(df, schema, columns): + if schema is not None and columns is not None: + raise ValueError('Schema and columns arguments are mutually ' + 'exclusive, pass only one of them') + elif schema is not None: + columns = schema.names + elif columns is not None: + columns = [c for c in columns if c in df.columns] + else: + columns = df.columns + + return columns + + +def dataframe_to_types(df, preserve_index, columns=None): + (all_names, + column_names, + _, + index_descriptors, + index_columns, + columns_to_convert, + _) = _get_columns_to_convert(df, None, preserve_index, columns) + + types = [] + # If pandas knows type, skip conversion + for c in columns_to_convert: + values = c.values + if _pandas_api.is_categorical(values): + type_ = pa.array(c, from_pandas=True).type + elif _pandas_api.is_extension_array_dtype(values): + type_ = pa.array(c.head(0), from_pandas=True).type + else: + values, type_ = get_datetimetz_type(values, c.dtype, None) + type_ = pa.lib._ndarray_to_arrow_type(values, type_) + if type_ is None: + type_ = pa.array(c, from_pandas=True).type + types.append(type_) + + metadata = construct_metadata( + columns_to_convert, df, column_names, index_columns, + index_descriptors, preserve_index, types + ) + + return all_names, types, metadata + + +def dataframe_to_arrays(df, schema, preserve_index, nthreads=1, columns=None, + safe=True): + (all_names, + column_names, + index_column_names, + index_descriptors, + index_columns, + columns_to_convert, + convert_fields) = _get_columns_to_convert(df, schema, preserve_index, + columns) + + # NOTE(wesm): If nthreads=None, then we use a heuristic to decide whether + # using a thread pool is worth it. Currently the heuristic is whether the + # nrows > 100 * ncols and ncols > 1. + if nthreads is None: + nrows, ncols = len(df), len(df.columns) + if nrows > ncols * 100 and ncols > 1: + nthreads = pa.cpu_count() + else: + nthreads = 1 + + def convert_column(col, field): + if field is None: + field_nullable = True + type_ = None + else: + field_nullable = field.nullable + type_ = field.type + + try: + result = pa.array(col, type=type_, from_pandas=True, safe=safe) + except (pa.ArrowInvalid, + pa.ArrowNotImplementedError, + pa.ArrowTypeError) as e: + e.args += ("Conversion failed for column {!s} with type {!s}" + .format(col.name, col.dtype),) + raise e + if not field_nullable and result.null_count > 0: + raise ValueError("Field {} was non-nullable but pandas column " + "had {} null values".format(str(field), + result.null_count)) + return result + + def _can_definitely_zero_copy(arr): + return (isinstance(arr, np.ndarray) and + arr.flags.contiguous and + issubclass(arr.dtype.type, np.integer)) + + if nthreads == 1: + arrays = [convert_column(c, f) + for c, f in zip(columns_to_convert, convert_fields)] + else: + arrays = [] + with futures.ThreadPoolExecutor(nthreads) as executor: + for c, f in zip(columns_to_convert, convert_fields): + if _can_definitely_zero_copy(c.values): + arrays.append(convert_column(c, f)) + else: + arrays.append(executor.submit(convert_column, c, f)) + + for i, maybe_fut in enumerate(arrays): + if isinstance(maybe_fut, futures.Future): + arrays[i] = maybe_fut.result() + + types = [x.type for x in arrays] + + if schema is None: + fields = [] + for name, type_ in zip(all_names, types): + name = name if name is not None else 'None' + fields.append(pa.field(name, type_)) + schema = pa.schema(fields) + + pandas_metadata = construct_metadata( + columns_to_convert, df, column_names, index_columns, + index_descriptors, preserve_index, types + ) + metadata = deepcopy(schema.metadata) if schema.metadata else dict() + metadata.update(pandas_metadata) + schema = schema.with_metadata(metadata) + + return arrays, schema + + +def get_datetimetz_type(values, dtype, type_): + if values.dtype.type != np.datetime64: + return values, type_ + + if _pandas_api.is_datetimetz(dtype) and type_ is None: + # If no user type passed, construct a tz-aware timestamp type + tz = dtype.tz + unit = dtype.unit + type_ = pa.timestamp(unit, tz) + elif type_ is None: + # Trust the NumPy dtype + type_ = pa.from_numpy_dtype(values.dtype) + + return values, type_ + +# ---------------------------------------------------------------------- +# Converting pandas.DataFrame to a dict containing only NumPy arrays or other +# objects friendly to pyarrow.serialize + + +def dataframe_to_serialized_dict(frame): + block_manager = frame._data + + blocks = [] + axes = [ax for ax in block_manager.axes] + + for block in block_manager.blocks: + values = block.values + block_data = {} + + if _pandas_api.is_datetimetz(values.dtype): + block_data['timezone'] = pa.lib.tzinfo_to_string(values.tz) + if hasattr(values, 'values'): + values = values.values + elif _pandas_api.is_categorical(values): + block_data.update(dictionary=values.categories, + ordered=values.ordered) + values = values.codes + block_data.update( + placement=block.mgr_locs.as_array, + block=values + ) + + # If we are dealing with an object array, pickle it instead. + if values.dtype == np.dtype(object): + block_data['object'] = None + block_data['block'] = builtin_pickle.dumps( + values, protocol=builtin_pickle.HIGHEST_PROTOCOL) + + blocks.append(block_data) + + return { + 'blocks': blocks, + 'axes': axes + } + + +def serialized_dict_to_dataframe(data): + import pandas.core.internals as _int + reconstructed_blocks = [_reconstruct_block(block) + for block in data['blocks']] + + block_mgr = _int.BlockManager(reconstructed_blocks, data['axes']) + return _pandas_api.data_frame(block_mgr) + + +def _reconstruct_block(item, columns=None, extension_columns=None): + """ + Construct a pandas Block from the `item` dictionary coming from pyarrow's + serialization or returned by arrow::python::ConvertTableToPandas. + + This function takes care of converting dictionary types to pandas + categorical, Timestamp-with-timezones to the proper pandas Block, and + conversion to pandas ExtensionBlock + + Parameters + ---------- + item : dict + For basic types, this is a dictionary in the form of + {'block': np.ndarray of values, 'placement': pandas block placement}. + Additional keys are present for other types (dictionary, timezone, + object). + columns : + Column names of the table being constructed, used for extension types + extension_columns : dict + Dictionary of {column_name: pandas_dtype} that includes all columns + and corresponding dtypes that will be converted to a pandas + ExtensionBlock. + + Returns + ------- + pandas Block + + """ + import pandas.core.internals as _int + + block_arr = item.get('block', None) + placement = item['placement'] + if 'dictionary' in item: + cat = _pandas_api.categorical_type.from_codes( + block_arr, categories=item['dictionary'], + ordered=item['ordered']) + block = _int.make_block(cat, placement=placement) + elif 'timezone' in item: + dtype = make_datetimetz(item['timezone']) + block = _int.make_block(block_arr, placement=placement, + klass=_int.DatetimeTZBlock, + dtype=dtype) + elif 'object' in item: + block = _int.make_block(builtin_pickle.loads(block_arr), + placement=placement) + elif 'py_array' in item: + # create ExtensionBlock + arr = item['py_array'] + assert len(placement) == 1 + name = columns[placement[0]] + pandas_dtype = extension_columns[name] + if not hasattr(pandas_dtype, '__from_arrow__'): + raise ValueError("This column does not support to be converted " + "to a pandas ExtensionArray") + pd_ext_arr = pandas_dtype.__from_arrow__(arr) + block = _int.make_block(pd_ext_arr, placement=placement) + else: + block = _int.make_block(block_arr, placement=placement) + + return block + + +def make_datetimetz(tz): + tz = pa.lib.string_to_tzinfo(tz) + return _pandas_api.datetimetz_type('ns', tz=tz) + + +# ---------------------------------------------------------------------- +# Converting pyarrow.Table efficiently to pandas.DataFrame + + +def table_to_blockmanager(options, table, categories=None, + ignore_metadata=False, types_mapper=None): + from pandas.core.internals import BlockManager + + all_columns = [] + column_indexes = [] + pandas_metadata = table.schema.pandas_metadata + + if not ignore_metadata and pandas_metadata is not None: + all_columns = pandas_metadata['columns'] + column_indexes = pandas_metadata.get('column_indexes', []) + index_descriptors = pandas_metadata['index_columns'] + table = _add_any_metadata(table, pandas_metadata) + table, index = _reconstruct_index(table, index_descriptors, + all_columns) + ext_columns_dtypes = _get_extension_dtypes( + table, all_columns, types_mapper) + else: + index = _pandas_api.pd.RangeIndex(table.num_rows) + ext_columns_dtypes = _get_extension_dtypes(table, [], types_mapper) + + _check_data_column_metadata_consistency(all_columns) + columns = _deserialize_column_index(table, all_columns, column_indexes) + blocks = _table_to_blocks(options, table, categories, ext_columns_dtypes) + + axes = [columns, index] + return BlockManager(blocks, axes) + + +# Set of the string repr of all numpy dtypes that can be stored in a pandas +# dataframe (complex not included since not supported by Arrow) +_pandas_supported_numpy_types = { + str(np.dtype(typ)) + for typ in (np.sctypes['int'] + np.sctypes['uint'] + np.sctypes['float'] + + ['object', 'bool']) +} + + +def _get_extension_dtypes(table, columns_metadata, types_mapper=None): + """ + Based on the stored column pandas metadata and the extension types + in the arrow schema, infer which columns should be converted to a + pandas extension dtype. + + The 'numpy_type' field in the column metadata stores the string + representation of the original pandas dtype (and, despite its name, + not the 'pandas_type' field). + Based on this string representation, a pandas/numpy dtype is constructed + and then we can check if this dtype supports conversion from arrow. + + """ + ext_columns = {} + + # older pandas version that does not yet support extension dtypes + if _pandas_api.extension_dtype is None: + return ext_columns + + # infer the extension columns from the pandas metadata + for col_meta in columns_metadata: + name = col_meta['name'] + dtype = col_meta['numpy_type'] + if dtype not in _pandas_supported_numpy_types: + # pandas_dtype is expensive, so avoid doing this for types + # that are certainly numpy dtypes + pandas_dtype = _pandas_api.pandas_dtype(dtype) + if isinstance(pandas_dtype, _pandas_api.extension_dtype): + if hasattr(pandas_dtype, "__from_arrow__"): + ext_columns[name] = pandas_dtype + + # infer from extension type in the schema + for field in table.schema: + typ = field.type + if isinstance(typ, pa.BaseExtensionType): + try: + pandas_dtype = typ.to_pandas_dtype() + except NotImplementedError: + pass + else: + ext_columns[field.name] = pandas_dtype + + # use the specified mapping of built-in arrow types to pandas dtypes + if types_mapper: + for field in table.schema: + typ = field.type + pandas_dtype = types_mapper(typ) + if pandas_dtype is not None: + ext_columns[field.name] = pandas_dtype + + return ext_columns + + +def _check_data_column_metadata_consistency(all_columns): + # It can never be the case in a released version of pyarrow that + # c['name'] is None *and* 'field_name' is not a key in the column metadata, + # because the change to allow c['name'] to be None and the change to add + # 'field_name' are in the same release (0.8.0) + assert all( + (c['name'] is None and 'field_name' in c) or c['name'] is not None + for c in all_columns + ) + + +def _deserialize_column_index(block_table, all_columns, column_indexes): + column_strings = [frombytes(x) if isinstance(x, bytes) else x + for x in block_table.column_names] + if all_columns: + columns_name_dict = { + c.get('field_name', _column_name_to_strings(c['name'])): c['name'] + for c in all_columns + } + columns_values = [ + columns_name_dict.get(name, name) for name in column_strings + ] + else: + columns_values = column_strings + + # If we're passed multiple column indexes then evaluate with + # ast.literal_eval, since the column index values show up as a list of + # tuples + to_pair = ast.literal_eval if len(column_indexes) > 1 else lambda x: (x,) + + # Create the column index + + # Construct the base index + if not columns_values: + columns = _pandas_api.pd.Index(columns_values) + else: + columns = _pandas_api.pd.MultiIndex.from_tuples( + list(map(to_pair, columns_values)), + names=[col_index['name'] for col_index in column_indexes] or None, + ) + + # if we're reconstructing the index + if len(column_indexes) > 0: + columns = _reconstruct_columns_from_metadata(columns, column_indexes) + + # ARROW-1751: flatten a single level column MultiIndex for pandas 0.21.0 + columns = _flatten_single_level_multiindex(columns) + + return columns + + +def _reconstruct_index(table, index_descriptors, all_columns): + # 0. 'field_name' is the name of the column in the arrow Table + # 1. 'name' is the user-facing name of the column, that is, it came from + # pandas + # 2. 'field_name' and 'name' differ for index columns + # 3. We fall back on c['name'] for backwards compatibility + field_name_to_metadata = { + c.get('field_name', c['name']): c + for c in all_columns + } + + # Build up a list of index columns and names while removing those columns + # from the original table + index_arrays = [] + index_names = [] + result_table = table + for descr in index_descriptors: + if isinstance(descr, str): + result_table, index_level, index_name = _extract_index_level( + table, result_table, descr, field_name_to_metadata) + if index_level is None: + # ARROW-1883: the serialized index column was not found + continue + elif descr['kind'] == 'range': + index_name = descr['name'] + index_level = _pandas_api.pd.RangeIndex(descr['start'], + descr['stop'], + step=descr['step'], + name=index_name) + if len(index_level) != len(table): + # Possibly the result of munged metadata + continue + else: + raise ValueError("Unrecognized index kind: {}" + .format(descr['kind'])) + index_arrays.append(index_level) + index_names.append(index_name) + + pd = _pandas_api.pd + + # Reconstruct the row index + if len(index_arrays) > 1: + index = pd.MultiIndex.from_arrays(index_arrays, names=index_names) + elif len(index_arrays) == 1: + index = index_arrays[0] + if not isinstance(index, pd.Index): + # Box anything that wasn't boxed above + index = pd.Index(index, name=index_names[0]) + else: + index = pd.RangeIndex(table.num_rows) + + return result_table, index + + +def _extract_index_level(table, result_table, field_name, + field_name_to_metadata): + logical_name = field_name_to_metadata[field_name]['name'] + index_name = _backwards_compatible_index_name(field_name, logical_name) + i = table.schema.get_field_index(field_name) + + if i == -1: + # The serialized index column was removed by the user + return result_table, None, None + + pd = _pandas_api.pd + + col = table.column(i) + values = col.to_pandas().values + + if hasattr(values, 'flags') and not values.flags.writeable: + # ARROW-1054: in pandas 0.19.2, factorize will reject + # non-writeable arrays when calling MultiIndex.from_arrays + values = values.copy() + + if isinstance(col.type, pa.lib.TimestampType) and col.type.tz is not None: + index_level = make_tz_aware(pd.Series(values), col.type.tz) + else: + index_level = pd.Series(values, dtype=values.dtype) + result_table = result_table.remove_column( + result_table.schema.get_field_index(field_name) + ) + return result_table, index_level, index_name + + +def _backwards_compatible_index_name(raw_name, logical_name): + """Compute the name of an index column that is compatible with older + versions of :mod:`pyarrow`. + + Parameters + ---------- + raw_name : str + logical_name : str + + Returns + ------- + result : str + + Notes + ----- + * Part of :func:`~pyarrow.pandas_compat.table_to_blockmanager` + """ + # Part of table_to_blockmanager + if raw_name == logical_name and _is_generated_index_name(raw_name): + return None + else: + return logical_name + + +def _is_generated_index_name(name): + pattern = r'^__index_level_\d+__$' + return re.match(pattern, name) is not None + + +_pandas_logical_type_map = { + 'date': 'datetime64[D]', + 'datetime': 'datetime64[ns]', + 'unicode': np.unicode_, + 'bytes': np.bytes_, + 'string': np.str_, + 'integer': np.int64, + 'floating': np.float64, + 'empty': np.object_, +} + + +def _pandas_type_to_numpy_type(pandas_type): + """Get the numpy dtype that corresponds to a pandas type. + + Parameters + ---------- + pandas_type : str + The result of a call to pandas.lib.infer_dtype. + + Returns + ------- + dtype : np.dtype + The dtype that corresponds to `pandas_type`. + """ + try: + return _pandas_logical_type_map[pandas_type] + except KeyError: + if 'mixed' in pandas_type: + # catching 'mixed', 'mixed-integer' and 'mixed-integer-float' + return np.object_ + return np.dtype(pandas_type) + + +def _get_multiindex_codes(mi): + # compat for pandas < 0.24 (MI labels renamed to codes). + if isinstance(mi, _pandas_api.pd.MultiIndex): + return mi.codes if hasattr(mi, 'codes') else mi.labels + else: + return None + + +def _reconstruct_columns_from_metadata(columns, column_indexes): + """Construct a pandas MultiIndex from `columns` and column index metadata + in `column_indexes`. + + Parameters + ---------- + columns : List[pd.Index] + The columns coming from a pyarrow.Table + column_indexes : List[Dict[str, str]] + The column index metadata deserialized from the JSON schema metadata + in a :class:`~pyarrow.Table`. + + Returns + ------- + result : MultiIndex + The index reconstructed using `column_indexes` metadata with levels of + the correct type. + + Notes + ----- + * Part of :func:`~pyarrow.pandas_compat.table_to_blockmanager` + """ + pd = _pandas_api.pd + # Get levels and labels, and provide sane defaults if the index has a + # single level to avoid if/else spaghetti. + levels = getattr(columns, 'levels', None) or [columns] + labels = _get_multiindex_codes(columns) or [ + pd.RangeIndex(len(level)) for level in levels + ] + + # Convert each level to the dtype provided in the metadata + levels_dtypes = [ + (level, col_index.get('pandas_type', str(level.dtype)), + col_index.get('numpy_type', None)) + for level, col_index in zip_longest( + levels, column_indexes, fillvalue={} + ) + ] + + new_levels = [] + encoder = operator.methodcaller('encode', 'UTF-8') + + for level, pandas_dtype, numpy_dtype in levels_dtypes: + dtype = _pandas_type_to_numpy_type(pandas_dtype) + # Since our metadata is UTF-8 encoded, Python turns things that were + # bytes into unicode strings when json.loads-ing them. We need to + # convert them back to bytes to preserve metadata. + if dtype == np.bytes_: + level = level.map(encoder) + elif level.dtype != dtype: + level = level.astype(dtype) + # ARROW-9096: if original DataFrame was upcast we keep that + if level.dtype != numpy_dtype: + level = level.astype(numpy_dtype) + + new_levels.append(level) + + return pd.MultiIndex(new_levels, labels, names=columns.names) + + +def _table_to_blocks(options, block_table, categories, extension_columns): + # Part of table_to_blockmanager + + # Convert an arrow table to Block from the internal pandas API + columns = block_table.column_names + result = pa.lib.table_to_blocks(options, block_table, categories, + list(extension_columns.keys())) + return [_reconstruct_block(item, columns, extension_columns) + for item in result] + + +def _flatten_single_level_multiindex(index): + pd = _pandas_api.pd + if isinstance(index, pd.MultiIndex) and index.nlevels == 1: + levels, = index.levels + labels, = _get_multiindex_codes(index) + # ARROW-9096: use levels.dtype to match cast with original DataFrame + dtype = levels.dtype + + # Cheaply check that we do not somehow have duplicate column names + if not index.is_unique: + raise ValueError('Found non-unique column index') + + return pd.Index( + [levels[_label] if _label != -1 else None for _label in labels], + dtype=dtype, + name=index.names[0] + ) + return index + + +def _add_any_metadata(table, pandas_metadata): + modified_columns = {} + modified_fields = {} + + schema = table.schema + + index_columns = pandas_metadata['index_columns'] + # only take index columns into account if they are an actual table column + index_columns = [idx_col for idx_col in index_columns + if isinstance(idx_col, str)] + n_index_levels = len(index_columns) + n_columns = len(pandas_metadata['columns']) - n_index_levels + + # Add time zones + for i, col_meta in enumerate(pandas_metadata['columns']): + + raw_name = col_meta.get('field_name') + if not raw_name: + # deal with metadata written with arrow < 0.8 or fastparquet + raw_name = col_meta['name'] + if i >= n_columns: + # index columns + raw_name = index_columns[i - n_columns] + if raw_name is None: + raw_name = 'None' + + idx = schema.get_field_index(raw_name) + if idx != -1: + if col_meta['pandas_type'] == 'datetimetz': + col = table[idx] + if not isinstance(col.type, pa.lib.TimestampType): + continue + metadata = col_meta['metadata'] + if not metadata: + continue + metadata_tz = metadata.get('timezone') + if metadata_tz and metadata_tz != col.type.tz: + converted = col.to_pandas() + tz_aware_type = pa.timestamp('ns', tz=metadata_tz) + with_metadata = pa.Array.from_pandas(converted, + type=tz_aware_type) + + modified_fields[idx] = pa.field(schema[idx].name, + tz_aware_type) + modified_columns[idx] = with_metadata + + if len(modified_columns) > 0: + columns = [] + fields = [] + for i in range(len(table.schema)): + if i in modified_columns: + columns.append(modified_columns[i]) + fields.append(modified_fields[i]) + else: + columns.append(table[i]) + fields.append(table.schema[i]) + return pa.Table.from_arrays(columns, schema=pa.schema(fields)) + else: + return table + + +# ---------------------------------------------------------------------- +# Helper functions used in lib + + +def make_tz_aware(series, tz): + """ + Make a datetime64 Series timezone-aware for the given tz + """ + tz = pa.lib.string_to_tzinfo(tz) + series = (series.dt.tz_localize('utc') + .dt.tz_convert(tz)) + return series diff --git a/src/arrow/python/pyarrow/parquet.py b/src/arrow/python/pyarrow/parquet.py new file mode 100644 index 000000000..8041b4e3c --- /dev/null +++ b/src/arrow/python/pyarrow/parquet.py @@ -0,0 +1,2299 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from collections import defaultdict +from concurrent import futures +from functools import partial, reduce + +import json +from collections.abc import Collection +import numpy as np +import os +import re +import operator +import urllib.parse +import warnings + +import pyarrow as pa +import pyarrow.lib as lib +import pyarrow._parquet as _parquet + +from pyarrow._parquet import (ParquetReader, Statistics, # noqa + FileMetaData, RowGroupMetaData, + ColumnChunkMetaData, + ParquetSchema, ColumnSchema) +from pyarrow.fs import (LocalFileSystem, FileSystem, + _resolve_filesystem_and_path, _ensure_filesystem) +from pyarrow import filesystem as legacyfs +from pyarrow.util import guid, _is_path_like, _stringify_path + +_URI_STRIP_SCHEMES = ('hdfs',) + + +def _parse_uri(path): + path = _stringify_path(path) + parsed_uri = urllib.parse.urlparse(path) + if parsed_uri.scheme in _URI_STRIP_SCHEMES: + return parsed_uri.path + else: + # ARROW-4073: On Windows returning the path with the scheme + # stripped removes the drive letter, if any + return path + + +def _get_filesystem_and_path(passed_filesystem, path): + if passed_filesystem is None: + return legacyfs.resolve_filesystem_and_path(path, passed_filesystem) + else: + passed_filesystem = legacyfs._ensure_filesystem(passed_filesystem) + parsed_path = _parse_uri(path) + return passed_filesystem, parsed_path + + +def _check_contains_null(val): + if isinstance(val, bytes): + for byte in val: + if isinstance(byte, bytes): + compare_to = chr(0) + else: + compare_to = 0 + if byte == compare_to: + return True + elif isinstance(val, str): + return '\x00' in val + return False + + +def _check_filters(filters, check_null_strings=True): + """ + Check if filters are well-formed. + """ + if filters is not None: + if len(filters) == 0 or any(len(f) == 0 for f in filters): + raise ValueError("Malformed filters") + if isinstance(filters[0][0], str): + # We have encountered the situation where we have one nesting level + # too few: + # We have [(,,), ..] instead of [[(,,), ..]] + filters = [filters] + if check_null_strings: + for conjunction in filters: + for col, op, val in conjunction: + if ( + isinstance(val, list) and + all(_check_contains_null(v) for v in val) or + _check_contains_null(val) + ): + raise NotImplementedError( + "Null-terminated binary strings are not supported " + "as filter values." + ) + return filters + + +_DNF_filter_doc = """Predicates are expressed in disjunctive normal form (DNF), like + ``[[('x', '=', 0), ...], ...]``. DNF allows arbitrary boolean logical + combinations of single column predicates. The innermost tuples each + describe a single column predicate. The list of inner predicates is + interpreted as a conjunction (AND), forming a more selective and + multiple column predicate. Finally, the most outer list combines these + filters as a disjunction (OR). + + Predicates may also be passed as List[Tuple]. This form is interpreted + as a single conjunction. To express OR in predicates, one must + use the (preferred) List[List[Tuple]] notation. + + Each tuple has format: (``key``, ``op``, ``value``) and compares the + ``key`` with the ``value``. + The supported ``op`` are: ``=`` or ``==``, ``!=``, ``<``, ``>``, ``<=``, + ``>=``, ``in`` and ``not in``. If the ``op`` is ``in`` or ``not in``, the + ``value`` must be a collection such as a ``list``, a ``set`` or a + ``tuple``. + + Examples: + + .. code-block:: python + + ('x', '=', 0) + ('y', 'in', ['a', 'b', 'c']) + ('z', 'not in', {'a','b'}) + + """ + + +def _filters_to_expression(filters): + """ + Check if filters are well-formed. + + See _DNF_filter_doc above for more details. + """ + import pyarrow.dataset as ds + + if isinstance(filters, ds.Expression): + return filters + + filters = _check_filters(filters, check_null_strings=False) + + def convert_single_predicate(col, op, val): + field = ds.field(col) + + if op == "=" or op == "==": + return field == val + elif op == "!=": + return field != val + elif op == '<': + return field < val + elif op == '>': + return field > val + elif op == '<=': + return field <= val + elif op == '>=': + return field >= val + elif op == 'in': + return field.isin(val) + elif op == 'not in': + return ~field.isin(val) + else: + raise ValueError( + '"{0}" is not a valid operator in predicates.'.format( + (col, op, val))) + + disjunction_members = [] + + for conjunction in filters: + conjunction_members = [ + convert_single_predicate(col, op, val) + for col, op, val in conjunction + ] + + disjunction_members.append(reduce(operator.and_, conjunction_members)) + + return reduce(operator.or_, disjunction_members) + + +# ---------------------------------------------------------------------- +# Reading a single Parquet file + + +class ParquetFile: + """ + Reader interface for a single Parquet file. + + Parameters + ---------- + source : str, pathlib.Path, pyarrow.NativeFile, or file-like object + Readable source. For passing bytes or buffer-like file containing a + Parquet file, use pyarrow.BufferReader. + metadata : FileMetaData, default None + Use existing metadata object, rather than reading from file. + common_metadata : FileMetaData, default None + Will be used in reads for pandas schema metadata if not found in the + main file's metadata, no other uses at the moment. + memory_map : bool, default False + If the source is a file path, use a memory map to read file, which can + improve performance in some environments. + buffer_size : int, default 0 + If positive, perform read buffering when deserializing individual + column chunks. Otherwise IO calls are unbuffered. + pre_buffer : bool, default False + Coalesce and issue file reads in parallel to improve performance on + high-latency filesystems (e.g. S3). If True, Arrow will use a + background I/O thread pool. + read_dictionary : list + List of column names to read directly as DictionaryArray. + coerce_int96_timestamp_unit : str, default None. + Cast timestamps that are stored in INT96 format to a particular + resolution (e.g. 'ms'). Setting to None is equivalent to 'ns' + and therefore INT96 timestamps will be infered as timestamps + in nanoseconds. + """ + + def __init__(self, source, metadata=None, common_metadata=None, + read_dictionary=None, memory_map=False, buffer_size=0, + pre_buffer=False, coerce_int96_timestamp_unit=None): + self.reader = ParquetReader() + self.reader.open( + source, use_memory_map=memory_map, + buffer_size=buffer_size, pre_buffer=pre_buffer, + read_dictionary=read_dictionary, metadata=metadata, + coerce_int96_timestamp_unit=coerce_int96_timestamp_unit + ) + self.common_metadata = common_metadata + self._nested_paths_by_prefix = self._build_nested_paths() + + def _build_nested_paths(self): + paths = self.reader.column_paths + + result = defaultdict(list) + + for i, path in enumerate(paths): + key = path[0] + rest = path[1:] + while True: + result[key].append(i) + + if not rest: + break + + key = '.'.join((key, rest[0])) + rest = rest[1:] + + return result + + @property + def metadata(self): + return self.reader.metadata + + @property + def schema(self): + """ + Return the Parquet schema, unconverted to Arrow types + """ + return self.metadata.schema + + @property + def schema_arrow(self): + """ + Return the inferred Arrow schema, converted from the whole Parquet + file's schema + """ + return self.reader.schema_arrow + + @property + def num_row_groups(self): + return self.reader.num_row_groups + + def read_row_group(self, i, columns=None, use_threads=True, + use_pandas_metadata=False): + """ + Read a single row group from a Parquet file. + + Parameters + ---------- + i : int + Index of the individual row group that we want to read. + columns : list + If not None, only these columns will be read from the row group. A + column name may be a prefix of a nested field, e.g. 'a' will select + 'a.b', 'a.c', and 'a.d.e'. + use_threads : bool, default True + Perform multi-threaded column reads. + use_pandas_metadata : bool, default False + If True and file has custom pandas schema metadata, ensure that + index columns are also loaded. + + Returns + ------- + pyarrow.table.Table + Content of the row group as a table (of columns) + """ + column_indices = self._get_column_indices( + columns, use_pandas_metadata=use_pandas_metadata) + return self.reader.read_row_group(i, column_indices=column_indices, + use_threads=use_threads) + + def read_row_groups(self, row_groups, columns=None, use_threads=True, + use_pandas_metadata=False): + """ + Read a multiple row groups from a Parquet file. + + Parameters + ---------- + row_groups : list + Only these row groups will be read from the file. + columns : list + If not None, only these columns will be read from the row group. A + column name may be a prefix of a nested field, e.g. 'a' will select + 'a.b', 'a.c', and 'a.d.e'. + use_threads : bool, default True + Perform multi-threaded column reads. + use_pandas_metadata : bool, default False + If True and file has custom pandas schema metadata, ensure that + index columns are also loaded. + + Returns + ------- + pyarrow.table.Table + Content of the row groups as a table (of columns). + """ + column_indices = self._get_column_indices( + columns, use_pandas_metadata=use_pandas_metadata) + return self.reader.read_row_groups(row_groups, + column_indices=column_indices, + use_threads=use_threads) + + def iter_batches(self, batch_size=65536, row_groups=None, columns=None, + use_threads=True, use_pandas_metadata=False): + """ + Read streaming batches from a Parquet file + + Parameters + ---------- + batch_size : int, default 64K + Maximum number of records to yield per batch. Batches may be + smaller if there aren't enough rows in the file. + row_groups : list + Only these row groups will be read from the file. + columns : list + If not None, only these columns will be read from the file. A + column name may be a prefix of a nested field, e.g. 'a' will select + 'a.b', 'a.c', and 'a.d.e'. + use_threads : boolean, default True + Perform multi-threaded column reads. + use_pandas_metadata : boolean, default False + If True and file has custom pandas schema metadata, ensure that + index columns are also loaded. + + Returns + ------- + iterator of pyarrow.RecordBatch + Contents of each batch as a record batch + """ + if row_groups is None: + row_groups = range(0, self.metadata.num_row_groups) + column_indices = self._get_column_indices( + columns, use_pandas_metadata=use_pandas_metadata) + + batches = self.reader.iter_batches(batch_size, + row_groups=row_groups, + column_indices=column_indices, + use_threads=use_threads) + return batches + + def read(self, columns=None, use_threads=True, use_pandas_metadata=False): + """ + Read a Table from Parquet format, + + Parameters + ---------- + columns : list + If not None, only these columns will be read from the file. A + column name may be a prefix of a nested field, e.g. 'a' will select + 'a.b', 'a.c', and 'a.d.e'. + use_threads : bool, default True + Perform multi-threaded column reads. + use_pandas_metadata : bool, default False + If True and file has custom pandas schema metadata, ensure that + index columns are also loaded. + + Returns + ------- + pyarrow.table.Table + Content of the file as a table (of columns). + """ + column_indices = self._get_column_indices( + columns, use_pandas_metadata=use_pandas_metadata) + return self.reader.read_all(column_indices=column_indices, + use_threads=use_threads) + + def scan_contents(self, columns=None, batch_size=65536): + """ + Read contents of file for the given columns and batch size. + + Notes + ----- + This function's primary purpose is benchmarking. + The scan is executed on a single thread. + + Parameters + ---------- + columns : list of integers, default None + Select columns to read, if None scan all columns. + batch_size : int, default 64K + Number of rows to read at a time internally. + + Returns + ------- + num_rows : number of rows in file + """ + column_indices = self._get_column_indices(columns) + return self.reader.scan_contents(column_indices, + batch_size=batch_size) + + def _get_column_indices(self, column_names, use_pandas_metadata=False): + if column_names is None: + return None + + indices = [] + + for name in column_names: + if name in self._nested_paths_by_prefix: + indices.extend(self._nested_paths_by_prefix[name]) + + if use_pandas_metadata: + file_keyvalues = self.metadata.metadata + common_keyvalues = (self.common_metadata.metadata + if self.common_metadata is not None + else None) + + if file_keyvalues and b'pandas' in file_keyvalues: + index_columns = _get_pandas_index_columns(file_keyvalues) + elif common_keyvalues and b'pandas' in common_keyvalues: + index_columns = _get_pandas_index_columns(common_keyvalues) + else: + index_columns = [] + + if indices is not None and index_columns: + indices += [self.reader.column_name_idx(descr) + for descr in index_columns + if not isinstance(descr, dict)] + + return indices + + +_SPARK_DISALLOWED_CHARS = re.compile('[ ,;{}()\n\t=]') + + +def _sanitized_spark_field_name(name): + return _SPARK_DISALLOWED_CHARS.sub('_', name) + + +def _sanitize_schema(schema, flavor): + if 'spark' in flavor: + sanitized_fields = [] + + schema_changed = False + + for field in schema: + name = field.name + sanitized_name = _sanitized_spark_field_name(name) + + if sanitized_name != name: + schema_changed = True + sanitized_field = pa.field(sanitized_name, field.type, + field.nullable, field.metadata) + sanitized_fields.append(sanitized_field) + else: + sanitized_fields.append(field) + + new_schema = pa.schema(sanitized_fields, metadata=schema.metadata) + return new_schema, schema_changed + else: + return schema, False + + +def _sanitize_table(table, new_schema, flavor): + # TODO: This will not handle prohibited characters in nested field names + if 'spark' in flavor: + column_data = [table[i] for i in range(table.num_columns)] + return pa.Table.from_arrays(column_data, schema=new_schema) + else: + return table + + +_parquet_writer_arg_docs = """version : {"1.0", "2.4", "2.6"}, default "1.0" + Determine which Parquet logical types are available for use, whether the + reduced set from the Parquet 1.x.x format or the expanded logical types + added in later format versions. + Files written with version='2.4' or '2.6' may not be readable in all + Parquet implementations, so version='1.0' is likely the choice that + maximizes file compatibility. + UINT32 and some logical types are only available with version '2.4'. + Nanosecond timestamps are only available with version '2.6'. + Other features such as compression algorithms or the new serialized + data page format must be enabled separately (see 'compression' and + 'data_page_version'). +use_dictionary : bool or list + Specify if we should use dictionary encoding in general or only for + some columns. +use_deprecated_int96_timestamps : bool, default None + Write timestamps to INT96 Parquet format. Defaults to False unless enabled + by flavor argument. This take priority over the coerce_timestamps option. +coerce_timestamps : str, default None + Cast timestamps to a particular resolution. If omitted, defaults are chosen + depending on `version`. By default, for ``version='1.0'`` (the default) + and ``version='2.4'``, nanoseconds are cast to microseconds ('us'), while + for other `version` values, they are written natively without loss + of resolution. Seconds are always cast to milliseconds ('ms') by default, + as Parquet does not have any temporal type with seconds resolution. + If the casting results in loss of data, it will raise an exception + unless ``allow_truncated_timestamps=True`` is given. + Valid values: {None, 'ms', 'us'} +data_page_size : int, default None + Set a target threshold for the approximate encoded size of data + pages within a column chunk (in bytes). If None, use the default data page + size of 1MByte. +allow_truncated_timestamps : bool, default False + Allow loss of data when coercing timestamps to a particular + resolution. E.g. if microsecond or nanosecond data is lost when coercing to + 'ms', do not raise an exception. Passing ``allow_truncated_timestamp=True`` + will NOT result in the truncation exception being ignored unless + ``coerce_timestamps`` is not None. +compression : str or dict + Specify the compression codec, either on a general basis or per-column. + Valid values: {'NONE', 'SNAPPY', 'GZIP', 'BROTLI', 'LZ4', 'ZSTD'}. +write_statistics : bool or list + Specify if we should write statistics in general (default is True) or only + for some columns. +flavor : {'spark'}, default None + Sanitize schema or set other compatibility options to work with + various target systems. +filesystem : FileSystem, default None + If nothing passed, will be inferred from `where` if path-like, else + `where` is already a file-like object so no filesystem is needed. +compression_level : int or dict, default None + Specify the compression level for a codec, either on a general basis or + per-column. If None is passed, arrow selects the compression level for + the compression codec in use. The compression level has a different + meaning for each codec, so you have to read the documentation of the + codec you are using. + An exception is thrown if the compression codec does not allow specifying + a compression level. +use_byte_stream_split : bool or list, default False + Specify if the byte_stream_split encoding should be used in general or + only for some columns. If both dictionary and byte_stream_stream are + enabled, then dictionary is preferred. + The byte_stream_split encoding is valid only for floating-point data types + and should be combined with a compression codec. +data_page_version : {"1.0", "2.0"}, default "1.0" + The serialized Parquet data page format version to write, defaults to + 1.0. This does not impact the file schema logical types and Arrow to + Parquet type casting behavior; for that use the "version" option. +use_compliant_nested_type : bool, default False + Whether to write compliant Parquet nested type (lists) as defined + `here <https://github.com/apache/parquet-format/blob/master/ + LogicalTypes.md#nested-types>`_, defaults to ``False``. + For ``use_compliant_nested_type=True``, this will write into a list + with 3-level structure where the middle level, named ``list``, + is a repeated group with a single field named ``element``:: + + <list-repetition> group <name> (LIST) { + repeated group list { + <element-repetition> <element-type> element; + } + } + + For ``use_compliant_nested_type=False``, this will also write into a list + with 3-level structure, where the name of the single field of the middle + level ``list`` is taken from the element name for nested columns in Arrow, + which defaults to ``item``:: + + <list-repetition> group <name> (LIST) { + repeated group list { + <element-repetition> <element-type> item; + } + } +""" + + +class ParquetWriter: + + __doc__ = """ +Class for incrementally building a Parquet file for Arrow tables. + +Parameters +---------- +where : path or file-like object +schema : arrow Schema +{} +writer_engine_version : unused +**options : dict + If options contains a key `metadata_collector` then the + corresponding value is assumed to be a list (or any object with + `.append` method) that will be filled with the file metadata instance + of the written file. +""".format(_parquet_writer_arg_docs) + + def __init__(self, where, schema, filesystem=None, + flavor=None, + version='1.0', + use_dictionary=True, + compression='snappy', + write_statistics=True, + use_deprecated_int96_timestamps=None, + compression_level=None, + use_byte_stream_split=False, + writer_engine_version=None, + data_page_version='1.0', + use_compliant_nested_type=False, + **options): + if use_deprecated_int96_timestamps is None: + # Use int96 timestamps for Spark + if flavor is not None and 'spark' in flavor: + use_deprecated_int96_timestamps = True + else: + use_deprecated_int96_timestamps = False + + self.flavor = flavor + if flavor is not None: + schema, self.schema_changed = _sanitize_schema(schema, flavor) + else: + self.schema_changed = False + + self.schema = schema + self.where = where + + # If we open a file using a filesystem, store file handle so we can be + # sure to close it when `self.close` is called. + self.file_handle = None + + filesystem, path = _resolve_filesystem_and_path( + where, filesystem, allow_legacy_filesystem=True + ) + if filesystem is not None: + if isinstance(filesystem, legacyfs.FileSystem): + # legacy filesystem (eg custom subclass) + # TODO deprecate + sink = self.file_handle = filesystem.open(path, 'wb') + else: + # ARROW-10480: do not auto-detect compression. While + # a filename like foo.parquet.gz is nonconforming, it + # shouldn't implicitly apply compression. + sink = self.file_handle = filesystem.open_output_stream( + path, compression=None) + else: + sink = where + self._metadata_collector = options.pop('metadata_collector', None) + engine_version = 'V2' + self.writer = _parquet.ParquetWriter( + sink, schema, + version=version, + compression=compression, + use_dictionary=use_dictionary, + write_statistics=write_statistics, + use_deprecated_int96_timestamps=use_deprecated_int96_timestamps, + compression_level=compression_level, + use_byte_stream_split=use_byte_stream_split, + writer_engine_version=engine_version, + data_page_version=data_page_version, + use_compliant_nested_type=use_compliant_nested_type, + **options) + self.is_open = True + + def __del__(self): + if getattr(self, 'is_open', False): + self.close() + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.close() + # return false since we want to propagate exceptions + return False + + def write_table(self, table, row_group_size=None): + if self.schema_changed: + table = _sanitize_table(table, self.schema, self.flavor) + assert self.is_open + + if not table.schema.equals(self.schema, check_metadata=False): + msg = ('Table schema does not match schema used to create file: ' + '\ntable:\n{!s} vs. \nfile:\n{!s}' + .format(table.schema, self.schema)) + raise ValueError(msg) + + self.writer.write_table(table, row_group_size=row_group_size) + + def close(self): + if self.is_open: + self.writer.close() + self.is_open = False + if self._metadata_collector is not None: + self._metadata_collector.append(self.writer.metadata) + if self.file_handle is not None: + self.file_handle.close() + + +def _get_pandas_index_columns(keyvalues): + return (json.loads(keyvalues[b'pandas'].decode('utf8')) + ['index_columns']) + + +# ---------------------------------------------------------------------- +# Metadata container providing instructions about reading a single Parquet +# file, possibly part of a partitioned dataset + + +class ParquetDatasetPiece: + """ + DEPRECATED: A single chunk of a potentially larger Parquet dataset to read. + + The arguments will indicate to read either a single row group or all row + groups, and whether to add partition keys to the resulting pyarrow.Table. + + .. deprecated:: 5.0 + Directly constructing a ``ParquetDatasetPiece`` is deprecated, as well + as accessing the pieces of a ``ParquetDataset`` object. Specify + ``use_legacy_dataset=False`` when constructing the ``ParquetDataset`` + and use the ``ParquetDataset.fragments`` attribute instead. + + Parameters + ---------- + path : str or pathlib.Path + Path to file in the file system where this piece is located. + open_file_func : callable + Function to use for obtaining file handle to dataset piece. + partition_keys : list of tuples + Two-element tuples of ``(column name, ordinal index)``. + row_group : int, default None + Row group to load. By default, reads all row groups. + file_options : dict + Options + """ + + def __init__(self, path, open_file_func=partial(open, mode='rb'), + file_options=None, row_group=None, partition_keys=None): + warnings.warn( + "ParquetDatasetPiece is deprecated as of pyarrow 5.0.0 and will " + "be removed in a future version.", + DeprecationWarning, stacklevel=2) + self._init( + path, open_file_func, file_options, row_group, partition_keys) + + @staticmethod + def _create(path, open_file_func=partial(open, mode='rb'), + file_options=None, row_group=None, partition_keys=None): + self = ParquetDatasetPiece.__new__(ParquetDatasetPiece) + self._init( + path, open_file_func, file_options, row_group, partition_keys) + return self + + def _init(self, path, open_file_func, file_options, row_group, + partition_keys): + self.path = _stringify_path(path) + self.open_file_func = open_file_func + self.row_group = row_group + self.partition_keys = partition_keys or [] + self.file_options = file_options or {} + + def __eq__(self, other): + if not isinstance(other, ParquetDatasetPiece): + return False + return (self.path == other.path and + self.row_group == other.row_group and + self.partition_keys == other.partition_keys) + + def __repr__(self): + return ('{}({!r}, row_group={!r}, partition_keys={!r})' + .format(type(self).__name__, self.path, + self.row_group, + self.partition_keys)) + + def __str__(self): + result = '' + + if len(self.partition_keys) > 0: + partition_str = ', '.join('{}={}'.format(name, index) + for name, index in self.partition_keys) + result += 'partition[{}] '.format(partition_str) + + result += self.path + + if self.row_group is not None: + result += ' | row_group={}'.format(self.row_group) + + return result + + def get_metadata(self): + """ + Return the file's metadata. + + Returns + ------- + metadata : FileMetaData + """ + f = self.open() + return f.metadata + + def open(self): + """ + Return instance of ParquetFile. + """ + reader = self.open_file_func(self.path) + if not isinstance(reader, ParquetFile): + reader = ParquetFile(reader, **self.file_options) + return reader + + def read(self, columns=None, use_threads=True, partitions=None, + file=None, use_pandas_metadata=False): + """ + Read this piece as a pyarrow.Table. + + Parameters + ---------- + columns : list of column names, default None + use_threads : bool, default True + Perform multi-threaded column reads. + partitions : ParquetPartitions, default None + file : file-like object + Passed to ParquetFile. + use_pandas_metadata : bool + If pandas metadata should be used or not. + + Returns + ------- + table : pyarrow.Table + """ + if self.open_file_func is not None: + reader = self.open() + elif file is not None: + reader = ParquetFile(file, **self.file_options) + else: + # try to read the local path + reader = ParquetFile(self.path, **self.file_options) + + options = dict(columns=columns, + use_threads=use_threads, + use_pandas_metadata=use_pandas_metadata) + + if self.row_group is not None: + table = reader.read_row_group(self.row_group, **options) + else: + table = reader.read(**options) + + if len(self.partition_keys) > 0: + if partitions is None: + raise ValueError('Must pass partition sets') + + # Here, the index is the categorical code of the partition where + # this piece is located. Suppose we had + # + # /foo=a/0.parq + # /foo=b/0.parq + # /foo=c/0.parq + # + # Then we assign a=0, b=1, c=2. And the resulting Table pieces will + # have a DictionaryArray column named foo having the constant index + # value as indicated. The distinct categories of the partition have + # been computed in the ParquetManifest + for i, (name, index) in enumerate(self.partition_keys): + # The partition code is the same for all values in this piece + indices = np.full(len(table), index, dtype='i4') + + # This is set of all partition values, computed as part of the + # manifest, so ['a', 'b', 'c'] as in our example above. + dictionary = partitions.levels[i].dictionary + + arr = pa.DictionaryArray.from_arrays(indices, dictionary) + table = table.append_column(name, arr) + + return table + + +class PartitionSet: + """ + A data structure for cataloguing the observed Parquet partitions at a + particular level. So if we have + + /foo=a/bar=0 + /foo=a/bar=1 + /foo=a/bar=2 + /foo=b/bar=0 + /foo=b/bar=1 + /foo=b/bar=2 + + Then we have two partition sets, one for foo, another for bar. As we visit + levels of the partition hierarchy, a PartitionSet tracks the distinct + values and assigns categorical codes to use when reading the pieces + + Parameters + ---------- + name : str + Name of the partition set. Under which key to collect all values. + keys : list + All possible values that have been collected for that partition set. + """ + + def __init__(self, name, keys=None): + self.name = name + self.keys = keys or [] + self.key_indices = {k: i for i, k in enumerate(self.keys)} + self._dictionary = None + + def get_index(self, key): + """ + Get the index of the partition value if it is known, otherwise assign + one + + Parameters + ---------- + key : The value for which we want to known the index. + """ + if key in self.key_indices: + return self.key_indices[key] + else: + index = len(self.key_indices) + self.keys.append(key) + self.key_indices[key] = index + return index + + @property + def dictionary(self): + if self._dictionary is not None: + return self._dictionary + + if len(self.keys) == 0: + raise ValueError('No known partition keys') + + # Only integer and string partition types are supported right now + try: + integer_keys = [int(x) for x in self.keys] + dictionary = lib.array(integer_keys) + except ValueError: + dictionary = lib.array(self.keys) + + self._dictionary = dictionary + return dictionary + + @property + def is_sorted(self): + return list(self.keys) == sorted(self.keys) + + +class ParquetPartitions: + + def __init__(self): + self.levels = [] + self.partition_names = set() + + def __len__(self): + return len(self.levels) + + def __getitem__(self, i): + return self.levels[i] + + def equals(self, other): + if not isinstance(other, ParquetPartitions): + raise TypeError('`other` must be an instance of ParquetPartitions') + + return (self.levels == other.levels and + self.partition_names == other.partition_names) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def get_index(self, level, name, key): + """ + Record a partition value at a particular level, returning the distinct + code for that value at that level. + + Example: + + partitions.get_index(1, 'foo', 'a') returns 0 + partitions.get_index(1, 'foo', 'b') returns 1 + partitions.get_index(1, 'foo', 'c') returns 2 + partitions.get_index(1, 'foo', 'a') returns 0 + + Parameters + ---------- + level : int + The nesting level of the partition we are observing + name : str + The partition name + key : str or int + The partition value + """ + if level == len(self.levels): + if name in self.partition_names: + raise ValueError('{} was the name of the partition in ' + 'another level'.format(name)) + + part_set = PartitionSet(name) + self.levels.append(part_set) + self.partition_names.add(name) + + return self.levels[level].get_index(key) + + def filter_accepts_partition(self, part_key, filter, level): + p_column, p_value_index = part_key + f_column, op, f_value = filter + if p_column != f_column: + return True + + f_type = type(f_value) + + if op in {'in', 'not in'}: + if not isinstance(f_value, Collection): + raise TypeError( + "'%s' object is not a collection", f_type.__name__) + if not f_value: + raise ValueError("Cannot use empty collection as filter value") + if len({type(item) for item in f_value}) != 1: + raise ValueError("All elements of the collection '%s' must be" + " of same type", f_value) + f_type = type(next(iter(f_value))) + + elif not isinstance(f_value, str) and isinstance(f_value, Collection): + raise ValueError( + "Op '%s' not supported with a collection value", op) + + p_value = f_type(self.levels[level] + .dictionary[p_value_index].as_py()) + + if op == "=" or op == "==": + return p_value == f_value + elif op == "!=": + return p_value != f_value + elif op == '<': + return p_value < f_value + elif op == '>': + return p_value > f_value + elif op == '<=': + return p_value <= f_value + elif op == '>=': + return p_value >= f_value + elif op == 'in': + return p_value in f_value + elif op == 'not in': + return p_value not in f_value + else: + raise ValueError("'%s' is not a valid operator in predicates.", + filter[1]) + + +class ParquetManifest: + + def __init__(self, dirpath, open_file_func=None, filesystem=None, + pathsep='/', partition_scheme='hive', metadata_nthreads=1): + filesystem, dirpath = _get_filesystem_and_path(filesystem, dirpath) + self.filesystem = filesystem + self.open_file_func = open_file_func + self.pathsep = pathsep + self.dirpath = _stringify_path(dirpath) + self.partition_scheme = partition_scheme + self.partitions = ParquetPartitions() + self.pieces = [] + self._metadata_nthreads = metadata_nthreads + self._thread_pool = futures.ThreadPoolExecutor( + max_workers=metadata_nthreads) + + self.common_metadata_path = None + self.metadata_path = None + + self._visit_level(0, self.dirpath, []) + + # Due to concurrency, pieces will potentially by out of order if the + # dataset is partitioned so we sort them to yield stable results + self.pieces.sort(key=lambda piece: piece.path) + + if self.common_metadata_path is None: + # _common_metadata is a subset of _metadata + self.common_metadata_path = self.metadata_path + + self._thread_pool.shutdown() + + def _visit_level(self, level, base_path, part_keys): + fs = self.filesystem + + _, directories, files = next(fs.walk(base_path)) + + filtered_files = [] + for path in files: + full_path = self.pathsep.join((base_path, path)) + if path.endswith('_common_metadata'): + self.common_metadata_path = full_path + elif path.endswith('_metadata'): + self.metadata_path = full_path + elif self._should_silently_exclude(path): + continue + else: + filtered_files.append(full_path) + + # ARROW-1079: Filter out "private" directories starting with underscore + filtered_directories = [self.pathsep.join((base_path, x)) + for x in directories + if not _is_private_directory(x)] + + filtered_files.sort() + filtered_directories.sort() + + if len(filtered_files) > 0 and len(filtered_directories) > 0: + raise ValueError('Found files in an intermediate ' + 'directory: {}'.format(base_path)) + elif len(filtered_directories) > 0: + self._visit_directories(level, filtered_directories, part_keys) + else: + self._push_pieces(filtered_files, part_keys) + + def _should_silently_exclude(self, file_name): + return (file_name.endswith('.crc') or # Checksums + file_name.endswith('_$folder$') or # HDFS directories in S3 + file_name.startswith('.') or # Hidden files starting with . + file_name.startswith('_') or # Hidden files starting with _ + file_name in EXCLUDED_PARQUET_PATHS) + + def _visit_directories(self, level, directories, part_keys): + futures_list = [] + for path in directories: + head, tail = _path_split(path, self.pathsep) + name, key = _parse_hive_partition(tail) + + index = self.partitions.get_index(level, name, key) + dir_part_keys = part_keys + [(name, index)] + # If you have less threads than levels, the wait call will block + # indefinitely due to multiple waits within a thread. + if level < self._metadata_nthreads: + future = self._thread_pool.submit(self._visit_level, + level + 1, + path, + dir_part_keys) + futures_list.append(future) + else: + self._visit_level(level + 1, path, dir_part_keys) + if futures_list: + futures.wait(futures_list) + + def _parse_partition(self, dirname): + if self.partition_scheme == 'hive': + return _parse_hive_partition(dirname) + else: + raise NotImplementedError('partition schema: {}' + .format(self.partition_scheme)) + + def _push_pieces(self, files, part_keys): + self.pieces.extend([ + ParquetDatasetPiece._create(path, partition_keys=part_keys, + open_file_func=self.open_file_func) + for path in files + ]) + + +def _parse_hive_partition(value): + if '=' not in value: + raise ValueError('Directory name did not appear to be a ' + 'partition: {}'.format(value)) + return value.split('=', 1) + + +def _is_private_directory(x): + _, tail = os.path.split(x) + return (tail.startswith('_') or tail.startswith('.')) and '=' not in tail + + +def _path_split(path, sep): + i = path.rfind(sep) + 1 + head, tail = path[:i], path[i:] + head = head.rstrip(sep) + return head, tail + + +EXCLUDED_PARQUET_PATHS = {'_SUCCESS'} + + +class _ParquetDatasetMetadata: + __slots__ = ('fs', 'memory_map', 'read_dictionary', 'common_metadata', + 'buffer_size') + + +def _open_dataset_file(dataset, path, meta=None): + if (dataset.fs is not None and + not isinstance(dataset.fs, legacyfs.LocalFileSystem)): + path = dataset.fs.open(path, mode='rb') + return ParquetFile( + path, + metadata=meta, + memory_map=dataset.memory_map, + read_dictionary=dataset.read_dictionary, + common_metadata=dataset.common_metadata, + buffer_size=dataset.buffer_size + ) + + +_DEPR_MSG = ( + "'{}' attribute is deprecated as of pyarrow 5.0.0 and will be removed " + "in a future version.{}" +) + + +_read_docstring_common = """\ +read_dictionary : list, default None + List of names or column paths (for nested types) to read directly + as DictionaryArray. Only supported for BYTE_ARRAY storage. To read + a flat column as dictionary-encoded pass the column name. For + nested types, you must pass the full column "path", which could be + something like level1.level2.list.item. Refer to the Parquet + file's schema to obtain the paths. +memory_map : bool, default False + If the source is a file path, use a memory map to read file, which can + improve performance in some environments. +buffer_size : int, default 0 + If positive, perform read buffering when deserializing individual + column chunks. Otherwise IO calls are unbuffered. +partitioning : Partitioning or str or list of str, default "hive" + The partitioning scheme for a partitioned dataset. The default of "hive" + assumes directory names with key=value pairs like "/year=2009/month=11". + In addition, a scheme like "/2009/11" is also supported, in which case + you need to specify the field names or a full schema. See the + ``pyarrow.dataset.partitioning()`` function for more details.""" + + +class ParquetDataset: + + __doc__ = """ +Encapsulates details of reading a complete Parquet dataset possibly +consisting of multiple files and partitions in subdirectories. + +Parameters +---------- +path_or_paths : str or List[str] + A directory name, single file name, or list of file names. +filesystem : FileSystem, default None + If nothing passed, paths assumed to be found in the local on-disk + filesystem. +metadata : pyarrow.parquet.FileMetaData + Use metadata obtained elsewhere to validate file schemas. +schema : pyarrow.parquet.Schema + Use schema obtained elsewhere to validate file schemas. Alternative to + metadata parameter. +split_row_groups : bool, default False + Divide files into pieces for each row group in the file. +validate_schema : bool, default True + Check that individual file schemas are all the same / compatible. +filters : List[Tuple] or List[List[Tuple]] or None (default) + Rows which do not match the filter predicate will be removed from scanned + data. Partition keys embedded in a nested directory structure will be + exploited to avoid loading files at all if they contain no matching rows. + If `use_legacy_dataset` is True, filters can only reference partition + keys and only a hive-style directory structure is supported. When + setting `use_legacy_dataset` to False, also within-file level filtering + and different partitioning schemes are supported. + + {1} +metadata_nthreads : int, default 1 + How many threads to allow the thread pool which is used to read the + dataset metadata. Increasing this is helpful to read partitioned + datasets. +{0} +use_legacy_dataset : bool, default True + Set to False to enable the new code path (experimental, using the + new Arrow Dataset API). Among other things, this allows to pass + `filters` for all columns and not only the partition keys, enables + different partitioning schemes, etc. +pre_buffer : bool, default True + Coalesce and issue file reads in parallel to improve performance on + high-latency filesystems (e.g. S3). If True, Arrow will use a + background I/O thread pool. This option is only supported for + use_legacy_dataset=False. If using a filesystem layer that itself + performs readahead (e.g. fsspec's S3FS), disable readahead for best + results. +coerce_int96_timestamp_unit : str, default None. + Cast timestamps that are stored in INT96 format to a particular resolution + (e.g. 'ms'). Setting to None is equivalent to 'ns' and therefore INT96 + timestamps will be infered as timestamps in nanoseconds. +""".format(_read_docstring_common, _DNF_filter_doc) + + def __new__(cls, path_or_paths=None, filesystem=None, schema=None, + metadata=None, split_row_groups=False, validate_schema=True, + filters=None, metadata_nthreads=1, read_dictionary=None, + memory_map=False, buffer_size=0, partitioning="hive", + use_legacy_dataset=None, pre_buffer=True, + coerce_int96_timestamp_unit=None): + if use_legacy_dataset is None: + # if a new filesystem is passed -> default to new implementation + if isinstance(filesystem, FileSystem): + use_legacy_dataset = False + # otherwise the default is still True + else: + use_legacy_dataset = True + + if not use_legacy_dataset: + return _ParquetDatasetV2( + path_or_paths, filesystem=filesystem, + filters=filters, + partitioning=partitioning, + read_dictionary=read_dictionary, + memory_map=memory_map, + buffer_size=buffer_size, + pre_buffer=pre_buffer, + coerce_int96_timestamp_unit=coerce_int96_timestamp_unit, + # unsupported keywords + schema=schema, metadata=metadata, + split_row_groups=split_row_groups, + validate_schema=validate_schema, + metadata_nthreads=metadata_nthreads + ) + self = object.__new__(cls) + return self + + def __init__(self, path_or_paths, filesystem=None, schema=None, + metadata=None, split_row_groups=False, validate_schema=True, + filters=None, metadata_nthreads=1, read_dictionary=None, + memory_map=False, buffer_size=0, partitioning="hive", + use_legacy_dataset=True, pre_buffer=True, + coerce_int96_timestamp_unit=None): + if partitioning != "hive": + raise ValueError( + 'Only "hive" for hive-like partitioning is supported when ' + 'using use_legacy_dataset=True') + self._metadata = _ParquetDatasetMetadata() + a_path = path_or_paths + if isinstance(a_path, list): + a_path = a_path[0] + + self._metadata.fs, _ = _get_filesystem_and_path(filesystem, a_path) + if isinstance(path_or_paths, list): + self.paths = [_parse_uri(path) for path in path_or_paths] + else: + self.paths = _parse_uri(path_or_paths) + + self._metadata.read_dictionary = read_dictionary + self._metadata.memory_map = memory_map + self._metadata.buffer_size = buffer_size + + (self._pieces, + self._partitions, + self.common_metadata_path, + self.metadata_path) = _make_manifest( + path_or_paths, self._fs, metadata_nthreads=metadata_nthreads, + open_file_func=partial(_open_dataset_file, self._metadata) + ) + + if self.common_metadata_path is not None: + with self._fs.open(self.common_metadata_path) as f: + self._metadata.common_metadata = read_metadata( + f, + memory_map=memory_map + ) + else: + self._metadata.common_metadata = None + + if metadata is None and self.metadata_path is not None: + with self._fs.open(self.metadata_path) as f: + self.metadata = read_metadata(f, memory_map=memory_map) + else: + self.metadata = metadata + + self.schema = schema + + self.split_row_groups = split_row_groups + + if split_row_groups: + raise NotImplementedError("split_row_groups not yet implemented") + + if filters is not None: + filters = _check_filters(filters) + self._filter(filters) + + if validate_schema: + self.validate_schemas() + + def equals(self, other): + if not isinstance(other, ParquetDataset): + raise TypeError('`other` must be an instance of ParquetDataset') + + if self._fs.__class__ != other._fs.__class__: + return False + for prop in ('paths', '_pieces', '_partitions', + 'common_metadata_path', 'metadata_path', + 'common_metadata', 'metadata', 'schema', + 'split_row_groups'): + if getattr(self, prop) != getattr(other, prop): + return False + for prop in ('memory_map', 'buffer_size'): + if getattr(self._metadata, prop) != getattr(other._metadata, prop): + return False + + return True + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def validate_schemas(self): + if self.metadata is None and self.schema is None: + if self.common_metadata is not None: + self.schema = self.common_metadata.schema + else: + self.schema = self._pieces[0].get_metadata().schema + elif self.schema is None: + self.schema = self.metadata.schema + + # Verify schemas are all compatible + dataset_schema = self.schema.to_arrow_schema() + # Exclude the partition columns from the schema, they are provided + # by the path, not the DatasetPiece + if self._partitions is not None: + for partition_name in self._partitions.partition_names: + if dataset_schema.get_field_index(partition_name) != -1: + field_idx = dataset_schema.get_field_index(partition_name) + dataset_schema = dataset_schema.remove(field_idx) + + for piece in self._pieces: + file_metadata = piece.get_metadata() + file_schema = file_metadata.schema.to_arrow_schema() + if not dataset_schema.equals(file_schema, check_metadata=False): + raise ValueError('Schema in {!s} was different. \n' + '{!s}\n\nvs\n\n{!s}' + .format(piece, file_schema, + dataset_schema)) + + def read(self, columns=None, use_threads=True, use_pandas_metadata=False): + """ + Read multiple Parquet files as a single pyarrow.Table. + + Parameters + ---------- + columns : List[str] + Names of columns to read from the file. + use_threads : bool, default True + Perform multi-threaded column reads + use_pandas_metadata : bool, default False + Passed through to each dataset piece. + + Returns + ------- + pyarrow.Table + Content of the file as a table (of columns). + """ + tables = [] + for piece in self._pieces: + table = piece.read(columns=columns, use_threads=use_threads, + partitions=self._partitions, + use_pandas_metadata=use_pandas_metadata) + tables.append(table) + + all_data = lib.concat_tables(tables) + + if use_pandas_metadata: + # We need to ensure that this metadata is set in the Table's schema + # so that Table.to_pandas will construct pandas.DataFrame with the + # right index + common_metadata = self._get_common_pandas_metadata() + current_metadata = all_data.schema.metadata or {} + + if common_metadata and b'pandas' not in current_metadata: + all_data = all_data.replace_schema_metadata({ + b'pandas': common_metadata}) + + return all_data + + def read_pandas(self, **kwargs): + """ + Read dataset including pandas metadata, if any. Other arguments passed + through to ParquetDataset.read, see docstring for further details. + + Parameters + ---------- + **kwargs : optional + All additional options to pass to the reader. + + Returns + ------- + pyarrow.Table + Content of the file as a table (of columns). + """ + return self.read(use_pandas_metadata=True, **kwargs) + + def _get_common_pandas_metadata(self): + if self.common_metadata is None: + return None + + keyvalues = self.common_metadata.metadata + return keyvalues.get(b'pandas', None) + + def _filter(self, filters): + accepts_filter = self._partitions.filter_accepts_partition + + def one_filter_accepts(piece, filter): + return all(accepts_filter(part_key, filter, level) + for level, part_key in enumerate(piece.partition_keys)) + + def all_filters_accept(piece): + return any(all(one_filter_accepts(piece, f) for f in conjunction) + for conjunction in filters) + + self._pieces = [p for p in self._pieces if all_filters_accept(p)] + + @property + def pieces(self): + warnings.warn( + _DEPR_MSG.format( + "ParquetDataset.pieces", + " Specify 'use_legacy_dataset=False' while constructing the " + "ParquetDataset, and then use the '.fragments' attribute " + "instead."), + DeprecationWarning, stacklevel=2) + return self._pieces + + @property + def partitions(self): + warnings.warn( + _DEPR_MSG.format( + "ParquetDataset.partitions", + " Specify 'use_legacy_dataset=False' while constructing the " + "ParquetDataset, and then use the '.partitioning' attribute " + "instead."), + DeprecationWarning, stacklevel=2) + return self._partitions + + @property + def memory_map(self): + warnings.warn( + _DEPR_MSG.format("ParquetDataset.memory_map", ""), + DeprecationWarning, stacklevel=2) + return self._metadata.memory_map + + @property + def read_dictionary(self): + warnings.warn( + _DEPR_MSG.format("ParquetDataset.read_dictionary", ""), + DeprecationWarning, stacklevel=2) + return self._metadata.read_dictionary + + @property + def buffer_size(self): + warnings.warn( + _DEPR_MSG.format("ParquetDataset.buffer_size", ""), + DeprecationWarning, stacklevel=2) + return self._metadata.buffer_size + + _fs = property( + operator.attrgetter('_metadata.fs') + ) + + @property + def fs(self): + warnings.warn( + _DEPR_MSG.format( + "ParquetDataset.fs", + " Specify 'use_legacy_dataset=False' while constructing the " + "ParquetDataset, and then use the '.filesystem' attribute " + "instead."), + DeprecationWarning, stacklevel=2) + return self._metadata.fs + + common_metadata = property( + operator.attrgetter('_metadata.common_metadata') + ) + + +def _make_manifest(path_or_paths, fs, pathsep='/', metadata_nthreads=1, + open_file_func=None): + partitions = None + common_metadata_path = None + metadata_path = None + + if isinstance(path_or_paths, list) and len(path_or_paths) == 1: + # Dask passes a directory as a list of length 1 + path_or_paths = path_or_paths[0] + + if _is_path_like(path_or_paths) and fs.isdir(path_or_paths): + manifest = ParquetManifest(path_or_paths, filesystem=fs, + open_file_func=open_file_func, + pathsep=getattr(fs, "pathsep", "/"), + metadata_nthreads=metadata_nthreads) + common_metadata_path = manifest.common_metadata_path + metadata_path = manifest.metadata_path + pieces = manifest.pieces + partitions = manifest.partitions + else: + if not isinstance(path_or_paths, list): + path_or_paths = [path_or_paths] + + # List of paths + if len(path_or_paths) == 0: + raise ValueError('Must pass at least one file path') + + pieces = [] + for path in path_or_paths: + if not fs.isfile(path): + raise OSError('Passed non-file path: {}' + .format(path)) + piece = ParquetDatasetPiece._create( + path, open_file_func=open_file_func) + pieces.append(piece) + + return pieces, partitions, common_metadata_path, metadata_path + + +def _is_local_file_system(fs): + return isinstance(fs, LocalFileSystem) or isinstance( + fs, legacyfs.LocalFileSystem + ) + + +class _ParquetDatasetV2: + """ + ParquetDataset shim using the Dataset API under the hood. + """ + + def __init__(self, path_or_paths, filesystem=None, filters=None, + partitioning="hive", read_dictionary=None, buffer_size=None, + memory_map=False, ignore_prefixes=None, pre_buffer=True, + coerce_int96_timestamp_unit=None, **kwargs): + import pyarrow.dataset as ds + + # Raise error for not supported keywords + for keyword, default in [ + ("schema", None), ("metadata", None), + ("split_row_groups", False), ("validate_schema", True), + ("metadata_nthreads", 1)]: + if keyword in kwargs and kwargs[keyword] is not default: + raise ValueError( + "Keyword '{0}' is not yet supported with the new " + "Dataset API".format(keyword)) + + # map format arguments + read_options = { + "pre_buffer": pre_buffer, + "coerce_int96_timestamp_unit": coerce_int96_timestamp_unit + } + if buffer_size: + read_options.update(use_buffered_stream=True, + buffer_size=buffer_size) + if read_dictionary is not None: + read_options.update(dictionary_columns=read_dictionary) + + # map filters to Expressions + self._filters = filters + self._filter_expression = filters and _filters_to_expression(filters) + + # map old filesystems to new one + if filesystem is not None: + filesystem = _ensure_filesystem( + filesystem, use_mmap=memory_map) + elif filesystem is None and memory_map: + # if memory_map is specified, assume local file system (string + # path can in principle be URI for any filesystem) + filesystem = LocalFileSystem(use_mmap=memory_map) + + # This needs to be checked after _ensure_filesystem, because that + # handles the case of an fsspec LocalFileSystem + if ( + hasattr(path_or_paths, "__fspath__") and + filesystem is not None and + not _is_local_file_system(filesystem) + ): + raise TypeError( + "Path-like objects with __fspath__ must only be used with " + f"local file systems, not {type(filesystem)}" + ) + + # check for single fragment dataset + single_file = None + if isinstance(path_or_paths, list): + if len(path_or_paths) == 1: + single_file = path_or_paths[0] + else: + if _is_path_like(path_or_paths): + path_or_paths = _stringify_path(path_or_paths) + if filesystem is None: + # path might be a URI describing the FileSystem as well + try: + filesystem, path_or_paths = FileSystem.from_uri( + path_or_paths) + except ValueError: + filesystem = LocalFileSystem(use_mmap=memory_map) + if filesystem.get_file_info(path_or_paths).is_file: + single_file = path_or_paths + else: + single_file = path_or_paths + + if single_file is not None: + self._enable_parallel_column_conversion = True + read_options.update(enable_parallel_column_conversion=True) + + parquet_format = ds.ParquetFileFormat(**read_options) + fragment = parquet_format.make_fragment(single_file, filesystem) + + self._dataset = ds.FileSystemDataset( + [fragment], schema=fragment.physical_schema, + format=parquet_format, + filesystem=fragment.filesystem + ) + return + else: + self._enable_parallel_column_conversion = False + + parquet_format = ds.ParquetFileFormat(**read_options) + + # check partitioning to enable dictionary encoding + if partitioning == "hive": + partitioning = ds.HivePartitioning.discover( + infer_dictionary=True) + + self._dataset = ds.dataset(path_or_paths, filesystem=filesystem, + format=parquet_format, + partitioning=partitioning, + ignore_prefixes=ignore_prefixes) + + @property + def schema(self): + return self._dataset.schema + + def read(self, columns=None, use_threads=True, use_pandas_metadata=False): + """ + Read (multiple) Parquet files as a single pyarrow.Table. + + Parameters + ---------- + columns : List[str] + Names of columns to read from the dataset. The partition fields + are not automatically included (in contrast to when setting + ``use_legacy_dataset=True``). + use_threads : bool, default True + Perform multi-threaded column reads. + use_pandas_metadata : bool, default False + If True and file has custom pandas schema metadata, ensure that + index columns are also loaded. + + Returns + ------- + pyarrow.Table + Content of the file as a table (of columns). + """ + # if use_pandas_metadata, we need to include index columns in the + # column selection, to be able to restore those in the pandas DataFrame + metadata = self.schema.metadata + if columns is not None and use_pandas_metadata: + if metadata and b'pandas' in metadata: + # RangeIndex can be represented as dict instead of column name + index_columns = [ + col for col in _get_pandas_index_columns(metadata) + if not isinstance(col, dict) + ] + columns = ( + list(columns) + list(set(index_columns) - set(columns)) + ) + + if self._enable_parallel_column_conversion: + if use_threads: + # Allow per-column parallelism; would otherwise cause + # contention in the presence of per-file parallelism. + use_threads = False + + table = self._dataset.to_table( + columns=columns, filter=self._filter_expression, + use_threads=use_threads + ) + + # if use_pandas_metadata, restore the pandas metadata (which gets + # lost if doing a specific `columns` selection in to_table) + if use_pandas_metadata: + if metadata and b"pandas" in metadata: + new_metadata = table.schema.metadata or {} + new_metadata.update({b"pandas": metadata[b"pandas"]}) + table = table.replace_schema_metadata(new_metadata) + + return table + + def read_pandas(self, **kwargs): + """ + Read dataset including pandas metadata, if any. Other arguments passed + through to ParquetDataset.read, see docstring for further details. + """ + return self.read(use_pandas_metadata=True, **kwargs) + + @property + def pieces(self): + warnings.warn( + _DEPR_MSG.format("ParquetDataset.pieces", + " Use the '.fragments' attribute instead"), + DeprecationWarning, stacklevel=2) + return list(self._dataset.get_fragments()) + + @property + def fragments(self): + return list(self._dataset.get_fragments()) + + @property + def files(self): + return self._dataset.files + + @property + def filesystem(self): + return self._dataset.filesystem + + @property + def partitioning(self): + """ + The partitioning of the Dataset source, if discovered. + """ + return self._dataset.partitioning + + +_read_table_docstring = """ +{0} + +Parameters +---------- +source : str, pyarrow.NativeFile, or file-like object + If a string passed, can be a single file name or directory name. For + file-like objects, only read a single file. Use pyarrow.BufferReader to + read a file contained in a bytes or buffer-like object. +columns : list + If not None, only these columns will be read from the file. A column + name may be a prefix of a nested field, e.g. 'a' will select 'a.b', + 'a.c', and 'a.d.e'. If empty, no columns will be read. Note + that the table will still have the correct num_rows set despite having + no columns. +use_threads : bool, default True + Perform multi-threaded column reads. +metadata : FileMetaData + If separately computed +{1} +use_legacy_dataset : bool, default False + By default, `read_table` uses the new Arrow Datasets API since + pyarrow 1.0.0. Among other things, this allows to pass `filters` + for all columns and not only the partition keys, enables + different partitioning schemes, etc. + Set to True to use the legacy behaviour. +ignore_prefixes : list, optional + Files matching any of these prefixes will be ignored by the + discovery process if use_legacy_dataset=False. + This is matched to the basename of a path. + By default this is ['.', '_']. + Note that discovery happens only if a directory is passed as source. +filesystem : FileSystem, default None + If nothing passed, paths assumed to be found in the local on-disk + filesystem. +filters : List[Tuple] or List[List[Tuple]] or None (default) + Rows which do not match the filter predicate will be removed from scanned + data. Partition keys embedded in a nested directory structure will be + exploited to avoid loading files at all if they contain no matching rows. + If `use_legacy_dataset` is True, filters can only reference partition + keys and only a hive-style directory structure is supported. When + setting `use_legacy_dataset` to False, also within-file level filtering + and different partitioning schemes are supported. + + {3} +pre_buffer : bool, default True + Coalesce and issue file reads in parallel to improve performance on + high-latency filesystems (e.g. S3). If True, Arrow will use a + background I/O thread pool. This option is only supported for + use_legacy_dataset=False. If using a filesystem layer that itself + performs readahead (e.g. fsspec's S3FS), disable readahead for best + results. +coerce_int96_timestamp_unit : str, default None. + Cast timestamps that are stored in INT96 format to a particular + resolution (e.g. 'ms'). Setting to None is equivalent to 'ns' + and therefore INT96 timestamps will be infered as timestamps + in nanoseconds. + +Returns +------- +{2} +""" + + +def read_table(source, columns=None, use_threads=True, metadata=None, + use_pandas_metadata=False, memory_map=False, + read_dictionary=None, filesystem=None, filters=None, + buffer_size=0, partitioning="hive", use_legacy_dataset=False, + ignore_prefixes=None, pre_buffer=True, + coerce_int96_timestamp_unit=None): + if not use_legacy_dataset: + if metadata is not None: + raise ValueError( + "The 'metadata' keyword is no longer supported with the new " + "datasets-based implementation. Specify " + "'use_legacy_dataset=True' to temporarily recover the old " + "behaviour." + ) + try: + dataset = _ParquetDatasetV2( + source, + filesystem=filesystem, + partitioning=partitioning, + memory_map=memory_map, + read_dictionary=read_dictionary, + buffer_size=buffer_size, + filters=filters, + ignore_prefixes=ignore_prefixes, + pre_buffer=pre_buffer, + coerce_int96_timestamp_unit=coerce_int96_timestamp_unit + ) + except ImportError: + # fall back on ParquetFile for simple cases when pyarrow.dataset + # module is not available + if filters is not None: + raise ValueError( + "the 'filters' keyword is not supported when the " + "pyarrow.dataset module is not available" + ) + if partitioning != "hive": + raise ValueError( + "the 'partitioning' keyword is not supported when the " + "pyarrow.dataset module is not available" + ) + filesystem, path = _resolve_filesystem_and_path(source, filesystem) + if filesystem is not None: + source = filesystem.open_input_file(path) + # TODO test that source is not a directory or a list + dataset = ParquetFile( + source, metadata=metadata, read_dictionary=read_dictionary, + memory_map=memory_map, buffer_size=buffer_size, + pre_buffer=pre_buffer, + coerce_int96_timestamp_unit=coerce_int96_timestamp_unit + ) + + return dataset.read(columns=columns, use_threads=use_threads, + use_pandas_metadata=use_pandas_metadata) + + if ignore_prefixes is not None: + raise ValueError( + "The 'ignore_prefixes' keyword is only supported when " + "use_legacy_dataset=False") + + if _is_path_like(source): + pf = ParquetDataset( + source, metadata=metadata, memory_map=memory_map, + read_dictionary=read_dictionary, + buffer_size=buffer_size, + filesystem=filesystem, filters=filters, + partitioning=partitioning, + coerce_int96_timestamp_unit=coerce_int96_timestamp_unit + ) + else: + pf = ParquetFile( + source, metadata=metadata, + read_dictionary=read_dictionary, + memory_map=memory_map, + buffer_size=buffer_size, + coerce_int96_timestamp_unit=coerce_int96_timestamp_unit + ) + return pf.read(columns=columns, use_threads=use_threads, + use_pandas_metadata=use_pandas_metadata) + + +read_table.__doc__ = _read_table_docstring.format( + """Read a Table from Parquet format + +Note: starting with pyarrow 1.0, the default for `use_legacy_dataset` is +switched to False.""", + "\n".join((_read_docstring_common, + """use_pandas_metadata : bool, default False + If True and file has custom pandas schema metadata, ensure that + index columns are also loaded.""")), + """pyarrow.Table + Content of the file as a table (of columns)""", + _DNF_filter_doc) + + +def read_pandas(source, columns=None, **kwargs): + return read_table( + source, columns=columns, use_pandas_metadata=True, **kwargs + ) + + +read_pandas.__doc__ = _read_table_docstring.format( + 'Read a Table from Parquet format, also reading DataFrame\n' + 'index values if known in the file metadata', + "\n".join((_read_docstring_common, + """**kwargs : additional options for :func:`read_table`""")), + """pyarrow.Table + Content of the file as a Table of Columns, including DataFrame + indexes as columns""", + _DNF_filter_doc) + + +def write_table(table, where, row_group_size=None, version='1.0', + use_dictionary=True, compression='snappy', + write_statistics=True, + use_deprecated_int96_timestamps=None, + coerce_timestamps=None, + allow_truncated_timestamps=False, + data_page_size=None, flavor=None, + filesystem=None, + compression_level=None, + use_byte_stream_split=False, + data_page_version='1.0', + use_compliant_nested_type=False, + **kwargs): + row_group_size = kwargs.pop('chunk_size', row_group_size) + use_int96 = use_deprecated_int96_timestamps + try: + with ParquetWriter( + where, table.schema, + filesystem=filesystem, + version=version, + flavor=flavor, + use_dictionary=use_dictionary, + write_statistics=write_statistics, + coerce_timestamps=coerce_timestamps, + data_page_size=data_page_size, + allow_truncated_timestamps=allow_truncated_timestamps, + compression=compression, + use_deprecated_int96_timestamps=use_int96, + compression_level=compression_level, + use_byte_stream_split=use_byte_stream_split, + data_page_version=data_page_version, + use_compliant_nested_type=use_compliant_nested_type, + **kwargs) as writer: + writer.write_table(table, row_group_size=row_group_size) + except Exception: + if _is_path_like(where): + try: + os.remove(_stringify_path(where)) + except os.error: + pass + raise + + +write_table.__doc__ = """ +Write a Table to Parquet format. + +Parameters +---------- +table : pyarrow.Table +where : string or pyarrow.NativeFile +row_group_size : int + The number of rows per rowgroup +{} +**kwargs : optional + Additional options for ParquetWriter +""".format(_parquet_writer_arg_docs) + + +def _mkdir_if_not_exists(fs, path): + if fs._isfilestore() and not fs.exists(path): + try: + fs.mkdir(path) + except OSError: + assert fs.exists(path) + + +def write_to_dataset(table, root_path, partition_cols=None, + partition_filename_cb=None, filesystem=None, + use_legacy_dataset=None, **kwargs): + """Wrapper around parquet.write_table for writing a Table to + Parquet format by partitions. + For each combination of partition columns and values, + a subdirectories are created in the following + manner: + + root_dir/ + group1=value1 + group2=value1 + <uuid>.parquet + group2=value2 + <uuid>.parquet + group1=valueN + group2=value1 + <uuid>.parquet + group2=valueN + <uuid>.parquet + + Parameters + ---------- + table : pyarrow.Table + root_path : str, pathlib.Path + The root directory of the dataset + filesystem : FileSystem, default None + If nothing passed, paths assumed to be found in the local on-disk + filesystem + partition_cols : list, + Column names by which to partition the dataset + Columns are partitioned in the order they are given + partition_filename_cb : callable, + A callback function that takes the partition key(s) as an argument + and allow you to override the partition filename. If nothing is + passed, the filename will consist of a uuid. + use_legacy_dataset : bool + Default is True unless a ``pyarrow.fs`` filesystem is passed. + Set to False to enable the new code path (experimental, using the + new Arrow Dataset API). This is more efficient when using partition + columns, but does not (yet) support `partition_filename_cb` and + `metadata_collector` keywords. + **kwargs : dict, + Additional kwargs for write_table function. See docstring for + `write_table` or `ParquetWriter` for more information. + Using `metadata_collector` in kwargs allows one to collect the + file metadata instances of dataset pieces. The file paths in the + ColumnChunkMetaData will be set relative to `root_path`. + """ + if use_legacy_dataset is None: + # if a new filesystem is passed -> default to new implementation + if isinstance(filesystem, FileSystem): + use_legacy_dataset = False + # otherwise the default is still True + else: + use_legacy_dataset = True + + if not use_legacy_dataset: + import pyarrow.dataset as ds + + # extract non-file format options + schema = kwargs.pop("schema", None) + use_threads = kwargs.pop("use_threads", True) + + # raise for unsupported keywords + msg = ( + "The '{}' argument is not supported with the new dataset " + "implementation." + ) + metadata_collector = kwargs.pop('metadata_collector', None) + file_visitor = None + if metadata_collector is not None: + def file_visitor(written_file): + metadata_collector.append(written_file.metadata) + if partition_filename_cb is not None: + raise ValueError(msg.format("partition_filename_cb")) + + # map format arguments + parquet_format = ds.ParquetFileFormat() + write_options = parquet_format.make_write_options(**kwargs) + + # map old filesystems to new one + if filesystem is not None: + filesystem = _ensure_filesystem(filesystem) + + partitioning = None + if partition_cols: + part_schema = table.select(partition_cols).schema + partitioning = ds.partitioning(part_schema, flavor="hive") + + ds.write_dataset( + table, root_path, filesystem=filesystem, + format=parquet_format, file_options=write_options, schema=schema, + partitioning=partitioning, use_threads=use_threads, + file_visitor=file_visitor) + return + + fs, root_path = legacyfs.resolve_filesystem_and_path(root_path, filesystem) + + _mkdir_if_not_exists(fs, root_path) + + metadata_collector = kwargs.pop('metadata_collector', None) + + if partition_cols is not None and len(partition_cols) > 0: + df = table.to_pandas() + partition_keys = [df[col] for col in partition_cols] + data_df = df.drop(partition_cols, axis='columns') + data_cols = df.columns.drop(partition_cols) + if len(data_cols) == 0: + raise ValueError('No data left to save outside partition columns') + + subschema = table.schema + + # ARROW-2891: Ensure the output_schema is preserved when writing a + # partitioned dataset + for col in table.schema.names: + if col in partition_cols: + subschema = subschema.remove(subschema.get_field_index(col)) + + for keys, subgroup in data_df.groupby(partition_keys): + if not isinstance(keys, tuple): + keys = (keys,) + subdir = '/'.join( + ['{colname}={value}'.format(colname=name, value=val) + for name, val in zip(partition_cols, keys)]) + subtable = pa.Table.from_pandas(subgroup, schema=subschema, + safe=False) + _mkdir_if_not_exists(fs, '/'.join([root_path, subdir])) + if partition_filename_cb: + outfile = partition_filename_cb(keys) + else: + outfile = guid() + '.parquet' + relative_path = '/'.join([subdir, outfile]) + full_path = '/'.join([root_path, relative_path]) + with fs.open(full_path, 'wb') as f: + write_table(subtable, f, metadata_collector=metadata_collector, + **kwargs) + if metadata_collector is not None: + metadata_collector[-1].set_file_path(relative_path) + else: + if partition_filename_cb: + outfile = partition_filename_cb(None) + else: + outfile = guid() + '.parquet' + full_path = '/'.join([root_path, outfile]) + with fs.open(full_path, 'wb') as f: + write_table(table, f, metadata_collector=metadata_collector, + **kwargs) + if metadata_collector is not None: + metadata_collector[-1].set_file_path(outfile) + + +def write_metadata(schema, where, metadata_collector=None, **kwargs): + """ + Write metadata-only Parquet file from schema. This can be used with + `write_to_dataset` to generate `_common_metadata` and `_metadata` sidecar + files. + + Parameters + ---------- + schema : pyarrow.Schema + where : string or pyarrow.NativeFile + metadata_collector : list + where to collect metadata information. + **kwargs : dict, + Additional kwargs for ParquetWriter class. See docstring for + `ParquetWriter` for more information. + + Examples + -------- + + Write a dataset and collect metadata information. + + >>> metadata_collector = [] + >>> write_to_dataset( + ... table, root_path, + ... metadata_collector=metadata_collector, **writer_kwargs) + + Write the `_common_metadata` parquet file without row groups statistics. + + >>> write_metadata( + ... table.schema, root_path / '_common_metadata', **writer_kwargs) + + Write the `_metadata` parquet file with row groups statistics. + + >>> write_metadata( + ... table.schema, root_path / '_metadata', + ... metadata_collector=metadata_collector, **writer_kwargs) + """ + writer = ParquetWriter(where, schema, **kwargs) + writer.close() + + if metadata_collector is not None: + # ParquetWriter doesn't expose the metadata until it's written. Write + # it and read it again. + metadata = read_metadata(where) + for m in metadata_collector: + metadata.append_row_groups(m) + metadata.write_metadata_file(where) + + +def read_metadata(where, memory_map=False): + """ + Read FileMetadata from footer of a single Parquet file. + + Parameters + ---------- + where : str (filepath) or file-like object + memory_map : bool, default False + Create memory map when the source is a file path. + + Returns + ------- + metadata : FileMetadata + """ + return ParquetFile(where, memory_map=memory_map).metadata + + +def read_schema(where, memory_map=False): + """ + Read effective Arrow schema from Parquet file metadata. + + Parameters + ---------- + where : str (filepath) or file-like object + memory_map : bool, default False + Create memory map when the source is a file path. + + Returns + ------- + schema : pyarrow.Schema + """ + return ParquetFile(where, memory_map=memory_map).schema.to_arrow_schema() diff --git a/src/arrow/python/pyarrow/plasma.py b/src/arrow/python/pyarrow/plasma.py new file mode 100644 index 000000000..239d29094 --- /dev/null +++ b/src/arrow/python/pyarrow/plasma.py @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import contextlib +import os +import pyarrow as pa +import shutil +import subprocess +import sys +import tempfile +import time + +from pyarrow._plasma import (ObjectID, ObjectNotAvailable, # noqa + PlasmaBuffer, PlasmaClient, connect, + PlasmaObjectExists, PlasmaObjectNotFound, + PlasmaStoreFull) + + +# The Plasma TensorFlow Operator needs to be compiled on the end user's +# machine since the TensorFlow ABI is not stable between versions. +# The following code checks if the operator is already present. If not, +# the function build_plasma_tensorflow_op can be used to compile it. + + +TF_PLASMA_OP_PATH = os.path.join(pa.__path__[0], "tensorflow", "plasma_op.so") + + +tf_plasma_op = None + + +def load_plasma_tensorflow_op(): + global tf_plasma_op + import tensorflow as tf + tf_plasma_op = tf.load_op_library(TF_PLASMA_OP_PATH) + + +def build_plasma_tensorflow_op(): + global tf_plasma_op + try: + import tensorflow as tf + print("TensorFlow version: " + tf.__version__) + except ImportError: + pass + else: + print("Compiling Plasma TensorFlow Op...") + dir_path = os.path.dirname(os.path.realpath(__file__)) + cc_path = os.path.join(dir_path, "tensorflow", "plasma_op.cc") + so_path = os.path.join(dir_path, "tensorflow", "plasma_op.so") + tf_cflags = tf.sysconfig.get_compile_flags() + if sys.platform == 'darwin': + tf_cflags = ["-undefined", "dynamic_lookup"] + tf_cflags + cmd = ["g++", "-std=c++11", "-g", "-shared", cc_path, + "-o", so_path, "-DNDEBUG", "-I" + pa.get_include()] + cmd += ["-L" + dir for dir in pa.get_library_dirs()] + cmd += ["-lplasma", "-larrow_python", "-larrow", "-fPIC"] + cmd += tf_cflags + cmd += tf.sysconfig.get_link_flags() + cmd += ["-O2"] + if tf.test.is_built_with_cuda(): + cmd += ["-DGOOGLE_CUDA"] + print("Running command " + str(cmd)) + subprocess.check_call(cmd) + tf_plasma_op = tf.load_op_library(TF_PLASMA_OP_PATH) + + +@contextlib.contextmanager +def start_plasma_store(plasma_store_memory, + use_valgrind=False, use_profiler=False, + plasma_directory=None, use_hugepages=False, + external_store=None): + """Start a plasma store process. + Args: + plasma_store_memory (int): Capacity of the plasma store in bytes. + use_valgrind (bool): True if the plasma store should be started inside + of valgrind. If this is True, use_profiler must be False. + use_profiler (bool): True if the plasma store should be started inside + a profiler. If this is True, use_valgrind must be False. + plasma_directory (str): Directory where plasma memory mapped files + will be stored. + use_hugepages (bool): True if the plasma store should use huge pages. + external_store (str): External store to use for evicted objects. + Return: + A tuple of the name of the plasma store socket and the process ID of + the plasma store process. + """ + if use_valgrind and use_profiler: + raise Exception("Cannot use valgrind and profiler at the same time.") + + tmpdir = tempfile.mkdtemp(prefix='test_plasma-') + try: + plasma_store_name = os.path.join(tmpdir, 'plasma.sock') + plasma_store_executable = os.path.join( + pa.__path__[0], "plasma-store-server") + if not os.path.exists(plasma_store_executable): + # Fallback to sys.prefix/bin/ (conda) + plasma_store_executable = os.path.join( + sys.prefix, "bin", "plasma-store-server") + command = [plasma_store_executable, + "-s", plasma_store_name, + "-m", str(plasma_store_memory)] + if plasma_directory: + command += ["-d", plasma_directory] + if use_hugepages: + command += ["-h"] + if external_store is not None: + command += ["-e", external_store] + stdout_file = None + stderr_file = None + if use_valgrind: + command = ["valgrind", + "--track-origins=yes", + "--leak-check=full", + "--show-leak-kinds=all", + "--leak-check-heuristics=stdstring", + "--error-exitcode=1"] + command + proc = subprocess.Popen(command, stdout=stdout_file, + stderr=stderr_file) + time.sleep(1.0) + elif use_profiler: + command = ["valgrind", "--tool=callgrind"] + command + proc = subprocess.Popen(command, stdout=stdout_file, + stderr=stderr_file) + time.sleep(1.0) + else: + proc = subprocess.Popen(command, stdout=stdout_file, + stderr=stderr_file) + time.sleep(0.1) + rc = proc.poll() + if rc is not None: + raise RuntimeError("plasma_store exited unexpectedly with " + "code %d" % (rc,)) + + yield plasma_store_name, proc + finally: + if proc.poll() is None: + proc.kill() + shutil.rmtree(tmpdir) diff --git a/src/arrow/python/pyarrow/public-api.pxi b/src/arrow/python/pyarrow/public-api.pxi new file mode 100644 index 000000000..c427fb9f5 --- /dev/null +++ b/src/arrow/python/pyarrow/public-api.pxi @@ -0,0 +1,418 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from libcpp.memory cimport shared_ptr +from pyarrow.includes.libarrow cimport (CArray, CDataType, CField, + CRecordBatch, CSchema, + CTable, CTensor, + CSparseCOOTensor, CSparseCSRMatrix, + CSparseCSCMatrix, CSparseCSFTensor) + +# You cannot assign something to a dereferenced pointer in Cython thus these +# methods don't use Status to indicate a successful operation. + + +cdef api bint pyarrow_is_buffer(object buffer): + return isinstance(buffer, Buffer) + + +cdef api shared_ptr[CBuffer] pyarrow_unwrap_buffer(object buffer): + cdef Buffer buf + if pyarrow_is_buffer(buffer): + buf = <Buffer>(buffer) + return buf.buffer + + return shared_ptr[CBuffer]() + + +cdef api object pyarrow_wrap_buffer(const shared_ptr[CBuffer]& buf): + cdef Buffer result = Buffer.__new__(Buffer) + result.init(buf) + return result + + +cdef api object pyarrow_wrap_resizable_buffer( + const shared_ptr[CResizableBuffer]& buf): + cdef ResizableBuffer result = ResizableBuffer.__new__(ResizableBuffer) + result.init_rz(buf) + return result + + +cdef api bint pyarrow_is_data_type(object type_): + return isinstance(type_, DataType) + + +cdef api shared_ptr[CDataType] pyarrow_unwrap_data_type( + object data_type): + cdef DataType type_ + if pyarrow_is_data_type(data_type): + type_ = <DataType>(data_type) + return type_.sp_type + + return shared_ptr[CDataType]() + + +# Workaround for Cython parsing bug +# https://github.com/cython/cython/issues/2143 +ctypedef const CPyExtensionType* _CPyExtensionTypePtr + + +cdef api object pyarrow_wrap_data_type( + const shared_ptr[CDataType]& type): + cdef: + const CExtensionType* ext_type + const CPyExtensionType* cpy_ext_type + DataType out + + if type.get() == NULL: + return None + + if type.get().id() == _Type_DICTIONARY: + out = DictionaryType.__new__(DictionaryType) + elif type.get().id() == _Type_LIST: + out = ListType.__new__(ListType) + elif type.get().id() == _Type_LARGE_LIST: + out = LargeListType.__new__(LargeListType) + elif type.get().id() == _Type_MAP: + out = MapType.__new__(MapType) + elif type.get().id() == _Type_FIXED_SIZE_LIST: + out = FixedSizeListType.__new__(FixedSizeListType) + elif type.get().id() == _Type_STRUCT: + out = StructType.__new__(StructType) + elif type.get().id() == _Type_SPARSE_UNION: + out = SparseUnionType.__new__(SparseUnionType) + elif type.get().id() == _Type_DENSE_UNION: + out = DenseUnionType.__new__(DenseUnionType) + elif type.get().id() == _Type_TIMESTAMP: + out = TimestampType.__new__(TimestampType) + elif type.get().id() == _Type_DURATION: + out = DurationType.__new__(DurationType) + elif type.get().id() == _Type_FIXED_SIZE_BINARY: + out = FixedSizeBinaryType.__new__(FixedSizeBinaryType) + elif type.get().id() == _Type_DECIMAL128: + out = Decimal128Type.__new__(Decimal128Type) + elif type.get().id() == _Type_DECIMAL256: + out = Decimal256Type.__new__(Decimal256Type) + elif type.get().id() == _Type_EXTENSION: + ext_type = <const CExtensionType*> type.get() + cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type) + if cpy_ext_type != nullptr: + return cpy_ext_type.GetInstance() + else: + out = BaseExtensionType.__new__(BaseExtensionType) + else: + out = DataType.__new__(DataType) + + out.init(type) + return out + + +cdef object pyarrow_wrap_metadata( + const shared_ptr[const CKeyValueMetadata]& meta): + if meta.get() == nullptr: + return None + else: + return KeyValueMetadata.wrap(meta) + + +cdef api bint pyarrow_is_metadata(object metadata): + return isinstance(metadata, KeyValueMetadata) + + +cdef shared_ptr[const CKeyValueMetadata] pyarrow_unwrap_metadata(object meta): + cdef shared_ptr[const CKeyValueMetadata] c_meta + if pyarrow_is_metadata(meta): + c_meta = (<KeyValueMetadata>meta).unwrap() + return c_meta + + +cdef api bint pyarrow_is_field(object field): + return isinstance(field, Field) + + +cdef api shared_ptr[CField] pyarrow_unwrap_field(object field): + cdef Field field_ + if pyarrow_is_field(field): + field_ = <Field>(field) + return field_.sp_field + + return shared_ptr[CField]() + + +cdef api object pyarrow_wrap_field(const shared_ptr[CField]& field): + if field.get() == NULL: + return None + cdef Field out = Field.__new__(Field) + out.init(field) + return out + + +cdef api bint pyarrow_is_schema(object schema): + return isinstance(schema, Schema) + + +cdef api shared_ptr[CSchema] pyarrow_unwrap_schema(object schema): + cdef Schema sch + if pyarrow_is_schema(schema): + sch = <Schema>(schema) + return sch.sp_schema + + return shared_ptr[CSchema]() + + +cdef api object pyarrow_wrap_schema(const shared_ptr[CSchema]& schema): + cdef Schema out = Schema.__new__(Schema) + out.init_schema(schema) + return out + + +cdef api bint pyarrow_is_array(object array): + return isinstance(array, Array) + + +cdef api shared_ptr[CArray] pyarrow_unwrap_array(object array): + cdef Array arr + if pyarrow_is_array(array): + arr = <Array>(array) + return arr.sp_array + + return shared_ptr[CArray]() + + +cdef api object pyarrow_wrap_array(const shared_ptr[CArray]& sp_array): + if sp_array.get() == NULL: + raise ValueError('Array was NULL') + + klass = get_array_class_from_type(sp_array.get().type()) + + cdef Array arr = klass.__new__(klass) + arr.init(sp_array) + return arr + + +cdef api bint pyarrow_is_chunked_array(object array): + return isinstance(array, ChunkedArray) + + +cdef api shared_ptr[CChunkedArray] pyarrow_unwrap_chunked_array(object array): + cdef ChunkedArray arr + if pyarrow_is_chunked_array(array): + arr = <ChunkedArray>(array) + return arr.sp_chunked_array + + return shared_ptr[CChunkedArray]() + + +cdef api object pyarrow_wrap_chunked_array( + const shared_ptr[CChunkedArray]& sp_array): + if sp_array.get() == NULL: + raise ValueError('ChunkedArray was NULL') + + cdef CDataType* data_type = sp_array.get().type().get() + + if data_type == NULL: + raise ValueError('ChunkedArray data type was NULL') + + cdef ChunkedArray arr = ChunkedArray.__new__(ChunkedArray) + arr.init(sp_array) + return arr + + +cdef api bint pyarrow_is_scalar(object value): + return isinstance(value, Scalar) + + +cdef api shared_ptr[CScalar] pyarrow_unwrap_scalar(object scalar): + if pyarrow_is_scalar(scalar): + return (<Scalar> scalar).unwrap() + return shared_ptr[CScalar]() + + +cdef api object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar): + if sp_scalar.get() == NULL: + raise ValueError('Scalar was NULL') + + cdef CDataType* data_type = sp_scalar.get().type.get() + + if data_type == NULL: + raise ValueError('Scalar data type was NULL') + + if data_type.id() == _Type_NA: + return _NULL + + if data_type.id() not in _scalar_classes: + raise ValueError('Scalar type not supported') + + klass = _scalar_classes[data_type.id()] + + cdef Scalar scalar = klass.__new__(klass) + scalar.init(sp_scalar) + return scalar + + +cdef api bint pyarrow_is_tensor(object tensor): + return isinstance(tensor, Tensor) + + +cdef api shared_ptr[CTensor] pyarrow_unwrap_tensor(object tensor): + cdef Tensor ten + if pyarrow_is_tensor(tensor): + ten = <Tensor>(tensor) + return ten.sp_tensor + + return shared_ptr[CTensor]() + + +cdef api object pyarrow_wrap_tensor( + const shared_ptr[CTensor]& sp_tensor): + if sp_tensor.get() == NULL: + raise ValueError('Tensor was NULL') + + cdef Tensor tensor = Tensor.__new__(Tensor) + tensor.init(sp_tensor) + return tensor + + +cdef api bint pyarrow_is_sparse_coo_tensor(object sparse_tensor): + return isinstance(sparse_tensor, SparseCOOTensor) + +cdef api shared_ptr[CSparseCOOTensor] pyarrow_unwrap_sparse_coo_tensor( + object sparse_tensor): + cdef SparseCOOTensor sten + if pyarrow_is_sparse_coo_tensor(sparse_tensor): + sten = <SparseCOOTensor>(sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseCOOTensor]() + +cdef api object pyarrow_wrap_sparse_coo_tensor( + const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseCOOTensor was NULL') + + cdef SparseCOOTensor sparse_tensor = SparseCOOTensor.__new__( + SparseCOOTensor) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + +cdef api bint pyarrow_is_sparse_csr_matrix(object sparse_tensor): + return isinstance(sparse_tensor, SparseCSRMatrix) + +cdef api shared_ptr[CSparseCSRMatrix] pyarrow_unwrap_sparse_csr_matrix( + object sparse_tensor): + cdef SparseCSRMatrix sten + if pyarrow_is_sparse_csr_matrix(sparse_tensor): + sten = <SparseCSRMatrix>(sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseCSRMatrix]() + +cdef api object pyarrow_wrap_sparse_csr_matrix( + const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseCSRMatrix was NULL') + + cdef SparseCSRMatrix sparse_tensor = SparseCSRMatrix.__new__( + SparseCSRMatrix) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + +cdef api bint pyarrow_is_sparse_csc_matrix(object sparse_tensor): + return isinstance(sparse_tensor, SparseCSCMatrix) + +cdef api shared_ptr[CSparseCSCMatrix] pyarrow_unwrap_sparse_csc_matrix( + object sparse_tensor): + cdef SparseCSCMatrix sten + if pyarrow_is_sparse_csc_matrix(sparse_tensor): + sten = <SparseCSCMatrix>(sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseCSCMatrix]() + +cdef api object pyarrow_wrap_sparse_csc_matrix( + const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseCSCMatrix was NULL') + + cdef SparseCSCMatrix sparse_tensor = SparseCSCMatrix.__new__( + SparseCSCMatrix) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + +cdef api bint pyarrow_is_sparse_csf_tensor(object sparse_tensor): + return isinstance(sparse_tensor, SparseCSFTensor) + +cdef api shared_ptr[CSparseCSFTensor] pyarrow_unwrap_sparse_csf_tensor( + object sparse_tensor): + cdef SparseCSFTensor sten + if pyarrow_is_sparse_csf_tensor(sparse_tensor): + sten = <SparseCSFTensor>(sparse_tensor) + return sten.sp_sparse_tensor + + return shared_ptr[CSparseCSFTensor]() + +cdef api object pyarrow_wrap_sparse_csf_tensor( + const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor): + if sp_sparse_tensor.get() == NULL: + raise ValueError('SparseCSFTensor was NULL') + + cdef SparseCSFTensor sparse_tensor = SparseCSFTensor.__new__( + SparseCSFTensor) + sparse_tensor.init(sp_sparse_tensor) + return sparse_tensor + + +cdef api bint pyarrow_is_table(object table): + return isinstance(table, Table) + + +cdef api shared_ptr[CTable] pyarrow_unwrap_table(object table): + cdef Table tab + if pyarrow_is_table(table): + tab = <Table>(table) + return tab.sp_table + + return shared_ptr[CTable]() + + +cdef api object pyarrow_wrap_table(const shared_ptr[CTable]& ctable): + cdef Table table = Table.__new__(Table) + table.init(ctable) + return table + + +cdef api bint pyarrow_is_batch(object batch): + return isinstance(batch, RecordBatch) + + +cdef api shared_ptr[CRecordBatch] pyarrow_unwrap_batch(object batch): + cdef RecordBatch bat + if pyarrow_is_batch(batch): + bat = <RecordBatch>(batch) + return bat.sp_batch + + return shared_ptr[CRecordBatch]() + + +cdef api object pyarrow_wrap_batch( + const shared_ptr[CRecordBatch]& cbatch): + cdef RecordBatch batch = RecordBatch.__new__(RecordBatch) + batch.init(cbatch) + return batch diff --git a/src/arrow/python/pyarrow/scalar.pxi b/src/arrow/python/pyarrow/scalar.pxi new file mode 100644 index 000000000..80fcc0028 --- /dev/null +++ b/src/arrow/python/pyarrow/scalar.pxi @@ -0,0 +1,1048 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import collections + + +cdef class Scalar(_Weakrefable): + """ + The base class for scalars. + """ + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use " + "pa.scalar() instead.".format(self.__class__.__name__)) + + cdef void init(self, const shared_ptr[CScalar]& wrapped): + self.wrapped = wrapped + + @staticmethod + cdef wrap(const shared_ptr[CScalar]& wrapped): + cdef: + Scalar self + Type type_id = wrapped.get().type.get().id() + + if type_id == _Type_NA: + return _NULL + + try: + typ = _scalar_classes[type_id] + except KeyError: + raise NotImplementedError( + "Wrapping scalar of type " + + frombytes(wrapped.get().type.get().ToString())) + self = typ.__new__(typ) + self.init(wrapped) + + return self + + cdef inline shared_ptr[CScalar] unwrap(self) nogil: + return self.wrapped + + @property + def type(self): + """ + Data type of the Scalar object. + """ + return pyarrow_wrap_data_type(self.wrapped.get().type) + + @property + def is_valid(self): + """ + Holds a valid (non-null) value. + """ + return self.wrapped.get().is_valid + + def cast(self, object target_type): + """ + Attempt a safe cast to target data type. + """ + cdef: + DataType type = ensure_type(target_type) + shared_ptr[CScalar] result + + with nogil: + result = GetResultValue(self.wrapped.get().CastTo(type.sp_type)) + + return Scalar.wrap(result) + + def __repr__(self): + return '<pyarrow.{}: {!r}>'.format( + self.__class__.__name__, self.as_py() + ) + + def __str__(self): + return str(self.as_py()) + + def equals(self, Scalar other not None): + return self.wrapped.get().Equals(other.unwrap().get()[0]) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def __hash__(self): + cdef CScalarHash hasher + return hasher(self.wrapped) + + def __reduce__(self): + return scalar, (self.as_py(), self.type) + + def as_py(self): + raise NotImplementedError() + + +_NULL = NA = None + + +cdef class NullScalar(Scalar): + """ + Concrete class for null scalars. + """ + + def __cinit__(self): + global NA + if NA is not None: + raise RuntimeError('Cannot create multiple NullScalar instances') + self.init(shared_ptr[CScalar](new CNullScalar())) + + def __init__(self): + pass + + def as_py(self): + """ + Return this value as a Python None. + """ + return None + + +_NULL = NA = NullScalar() + + +cdef class BooleanScalar(Scalar): + """ + Concrete class for boolean scalars. + """ + + def as_py(self): + """ + Return this value as a Python bool. + """ + cdef CBooleanScalar* sp = <CBooleanScalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class UInt8Scalar(Scalar): + """ + Concrete class for uint8 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CUInt8Scalar* sp = <CUInt8Scalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Int8Scalar(Scalar): + """ + Concrete class for int8 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CInt8Scalar* sp = <CInt8Scalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class UInt16Scalar(Scalar): + """ + Concrete class for uint16 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CUInt16Scalar* sp = <CUInt16Scalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Int16Scalar(Scalar): + """ + Concrete class for int16 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CInt16Scalar* sp = <CInt16Scalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class UInt32Scalar(Scalar): + """ + Concrete class for uint32 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CUInt32Scalar* sp = <CUInt32Scalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Int32Scalar(Scalar): + """ + Concrete class for int32 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CInt32Scalar* sp = <CInt32Scalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class UInt64Scalar(Scalar): + """ + Concrete class for uint64 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CUInt64Scalar* sp = <CUInt64Scalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Int64Scalar(Scalar): + """ + Concrete class for int64 scalars. + """ + + def as_py(self): + """ + Return this value as a Python int. + """ + cdef CInt64Scalar* sp = <CInt64Scalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class HalfFloatScalar(Scalar): + """ + Concrete class for float scalars. + """ + + def as_py(self): + """ + Return this value as a Python float. + """ + cdef CHalfFloatScalar* sp = <CHalfFloatScalar*> self.wrapped.get() + return PyHalf_FromHalf(sp.value) if sp.is_valid else None + + +cdef class FloatScalar(Scalar): + """ + Concrete class for float scalars. + """ + + def as_py(self): + """ + Return this value as a Python float. + """ + cdef CFloatScalar* sp = <CFloatScalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class DoubleScalar(Scalar): + """ + Concrete class for double scalars. + """ + + def as_py(self): + """ + Return this value as a Python float. + """ + cdef CDoubleScalar* sp = <CDoubleScalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + +cdef class Decimal128Scalar(Scalar): + """ + Concrete class for decimal128 scalars. + """ + + def as_py(self): + """ + Return this value as a Python Decimal. + """ + cdef: + CDecimal128Scalar* sp = <CDecimal128Scalar*> self.wrapped.get() + CDecimal128Type* dtype = <CDecimal128Type*> sp.type.get() + if sp.is_valid: + return _pydecimal.Decimal( + frombytes(sp.value.ToString(dtype.scale())) + ) + else: + return None + + +cdef class Decimal256Scalar(Scalar): + """ + Concrete class for decimal256 scalars. + """ + + def as_py(self): + """ + Return this value as a Python Decimal. + """ + cdef: + CDecimal256Scalar* sp = <CDecimal256Scalar*> self.wrapped.get() + CDecimal256Type* dtype = <CDecimal256Type*> sp.type.get() + if sp.is_valid: + return _pydecimal.Decimal( + frombytes(sp.value.ToString(dtype.scale())) + ) + else: + return None + + +cdef class Date32Scalar(Scalar): + """ + Concrete class for date32 scalars. + """ + + def as_py(self): + """ + Return this value as a Python datetime.datetime instance. + """ + cdef CDate32Scalar* sp = <CDate32Scalar*> self.wrapped.get() + + if sp.is_valid: + # shift to seconds since epoch + return ( + datetime.date(1970, 1, 1) + datetime.timedelta(days=sp.value) + ) + else: + return None + + +cdef class Date64Scalar(Scalar): + """ + Concrete class for date64 scalars. + """ + + def as_py(self): + """ + Return this value as a Python datetime.datetime instance. + """ + cdef CDate64Scalar* sp = <CDate64Scalar*> self.wrapped.get() + + if sp.is_valid: + return ( + datetime.date(1970, 1, 1) + + datetime.timedelta(days=sp.value / 86400000) + ) + else: + return None + + +def _datetime_from_int(int64_t value, TimeUnit unit, tzinfo=None): + if unit == TimeUnit_SECOND: + delta = datetime.timedelta(seconds=value) + elif unit == TimeUnit_MILLI: + delta = datetime.timedelta(milliseconds=value) + elif unit == TimeUnit_MICRO: + delta = datetime.timedelta(microseconds=value) + else: + # TimeUnit_NANO: prefer pandas timestamps if available + if _pandas_api.have_pandas: + return _pandas_api.pd.Timestamp(value, tz=tzinfo, unit='ns') + # otherwise safely truncate to microsecond resolution datetime + if value % 1000 != 0: + raise ValueError( + "Nanosecond resolution temporal type {} is not safely " + "convertible to microseconds to convert to datetime.datetime. " + "Install pandas to return as Timestamp with nanosecond " + "support or access the .value attribute.".format(value) + ) + delta = datetime.timedelta(microseconds=value // 1000) + + dt = datetime.datetime(1970, 1, 1) + delta + # adjust timezone if set to the datatype + if tzinfo is not None: + dt = tzinfo.fromutc(dt) + + return dt + + +cdef class Time32Scalar(Scalar): + """ + Concrete class for time32 scalars. + """ + + def as_py(self): + """ + Return this value as a Python datetime.timedelta instance. + """ + cdef: + CTime32Scalar* sp = <CTime32Scalar*> self.wrapped.get() + CTime32Type* dtype = <CTime32Type*> sp.type.get() + + if sp.is_valid: + return _datetime_from_int(sp.value, unit=dtype.unit()).time() + else: + return None + + +cdef class Time64Scalar(Scalar): + """ + Concrete class for time64 scalars. + """ + + def as_py(self): + """ + Return this value as a Python datetime.timedelta instance. + """ + cdef: + CTime64Scalar* sp = <CTime64Scalar*> self.wrapped.get() + CTime64Type* dtype = <CTime64Type*> sp.type.get() + + if sp.is_valid: + return _datetime_from_int(sp.value, unit=dtype.unit()).time() + else: + return None + + +cdef class TimestampScalar(Scalar): + """ + Concrete class for timestamp scalars. + """ + + @property + def value(self): + cdef CTimestampScalar* sp = <CTimestampScalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Pandas Timestamp instance (if units are + nanoseconds and pandas is available), otherwise as a Python + datetime.datetime instance. + """ + cdef: + CTimestampScalar* sp = <CTimestampScalar*> self.wrapped.get() + CTimestampType* dtype = <CTimestampType*> sp.type.get() + + if not sp.is_valid: + return None + + if not dtype.timezone().empty(): + tzinfo = string_to_tzinfo(frombytes(dtype.timezone())) + else: + tzinfo = None + + return _datetime_from_int(sp.value, unit=dtype.unit(), tzinfo=tzinfo) + + +cdef class DurationScalar(Scalar): + """ + Concrete class for duration scalars. + """ + + @property + def value(self): + cdef CDurationScalar* sp = <CDurationScalar*> self.wrapped.get() + return sp.value if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Pandas Timedelta instance (if units are + nanoseconds and pandas is available), otherwise as a Python + datetime.timedelta instance. + """ + cdef: + CDurationScalar* sp = <CDurationScalar*> self.wrapped.get() + CDurationType* dtype = <CDurationType*> sp.type.get() + TimeUnit unit = dtype.unit() + + if not sp.is_valid: + return None + + if unit == TimeUnit_SECOND: + return datetime.timedelta(seconds=sp.value) + elif unit == TimeUnit_MILLI: + return datetime.timedelta(milliseconds=sp.value) + elif unit == TimeUnit_MICRO: + return datetime.timedelta(microseconds=sp.value) + else: + # TimeUnit_NANO: prefer pandas timestamps if available + if _pandas_api.have_pandas: + return _pandas_api.pd.Timedelta(sp.value, unit='ns') + # otherwise safely truncate to microsecond resolution timedelta + if sp.value % 1000 != 0: + raise ValueError( + "Nanosecond duration {} is not safely convertible to " + "microseconds to convert to datetime.timedelta. Install " + "pandas to return as Timedelta with nanosecond support or " + "access the .value attribute.".format(sp.value) + ) + return datetime.timedelta(microseconds=sp.value // 1000) + + +cdef class MonthDayNanoIntervalScalar(Scalar): + """ + Concrete class for month, day, nanosecond interval scalars. + """ + + @property + def value(self): + """ + Same as self.as_py() + """ + return self.as_py() + + def as_py(self): + """ + Return this value as a pyarrow.MonthDayNano. + """ + cdef: + PyObject* val + CMonthDayNanoIntervalScalar* scalar + scalar = <CMonthDayNanoIntervalScalar*>self.wrapped.get() + val = GetResultValue(MonthDayNanoIntervalScalarToPyObject( + deref(scalar))) + return PyObject_to_object(val) + + +cdef class BinaryScalar(Scalar): + """ + Concrete class for binary-like scalars. + """ + + def as_buffer(self): + """ + Return a view over this value as a Buffer object. + """ + cdef CBaseBinaryScalar* sp = <CBaseBinaryScalar*> self.wrapped.get() + return pyarrow_wrap_buffer(sp.value) if sp.is_valid else None + + def as_py(self): + """ + Return this value as a Python bytes. + """ + buffer = self.as_buffer() + return None if buffer is None else buffer.to_pybytes() + + +cdef class LargeBinaryScalar(BinaryScalar): + pass + + +cdef class FixedSizeBinaryScalar(BinaryScalar): + pass + + +cdef class StringScalar(BinaryScalar): + """ + Concrete class for string-like (utf8) scalars. + """ + + def as_py(self): + """ + Return this value as a Python string. + """ + buffer = self.as_buffer() + return None if buffer is None else str(buffer, 'utf8') + + +cdef class LargeStringScalar(StringScalar): + pass + + +cdef class ListScalar(Scalar): + """ + Concrete class for list-like scalars. + """ + + @property + def values(self): + cdef CBaseListScalar* sp = <CBaseListScalar*> self.wrapped.get() + if sp.is_valid: + return pyarrow_wrap_array(sp.value) + else: + return None + + def __len__(self): + """ + Return the number of values. + """ + return len(self.values) + + def __getitem__(self, i): + """ + Return the value at the given index. + """ + return self.values[_normalize_index(i, len(self))] + + def __iter__(self): + """ + Iterate over this element's values. + """ + return iter(self.values) + + def as_py(self): + """ + Return this value as a Python list. + """ + arr = self.values + return None if arr is None else arr.to_pylist() + + +cdef class FixedSizeListScalar(ListScalar): + pass + + +cdef class LargeListScalar(ListScalar): + pass + + +cdef class StructScalar(Scalar, collections.abc.Mapping): + """ + Concrete class for struct scalars. + """ + + def __len__(self): + cdef CStructScalar* sp = <CStructScalar*> self.wrapped.get() + return sp.value.size() + + def __iter__(self): + cdef: + CStructScalar* sp = <CStructScalar*> self.wrapped.get() + CStructType* dtype = <CStructType*> sp.type.get() + vector[shared_ptr[CField]] fields = dtype.fields() + + for i in range(dtype.num_fields()): + yield frombytes(fields[i].get().name()) + + def items(self): + return ((key, self[i]) for i, key in enumerate(self)) + + def __contains__(self, key): + return key in list(self) + + def __getitem__(self, key): + """ + Return the child value for the given field. + + Parameters + ---------- + index : Union[int, str] + Index / position or name of the field. + + Returns + ------- + result : Scalar + """ + cdef: + CFieldRef ref + CStructScalar* sp = <CStructScalar*> self.wrapped.get() + + if isinstance(key, (bytes, str)): + ref = CFieldRef(<c_string> tobytes(key)) + elif isinstance(key, int): + ref = CFieldRef(<int> key) + else: + raise TypeError('Expected integer or string index') + + try: + return Scalar.wrap(GetResultValue(sp.field(ref))) + except ArrowInvalid as exc: + if isinstance(key, int): + raise IndexError(key) from exc + else: + raise KeyError(key) from exc + + def as_py(self): + """ + Return this value as a Python dict. + """ + if self.is_valid: + try: + return {k: self[k].as_py() for k in self.keys()} + except KeyError: + raise ValueError( + "Converting to Python dictionary is not supported when " + "duplicate field names are present") + else: + return None + + def _as_py_tuple(self): + # a version that returns a tuple instead of dict to support repr/str + # with the presence of duplicate field names + if self.is_valid: + return [(key, self[i].as_py()) for i, key in enumerate(self)] + else: + return None + + def __repr__(self): + return '<pyarrow.{}: {!r}>'.format( + self.__class__.__name__, self._as_py_tuple() + ) + + def __str__(self): + return str(self._as_py_tuple()) + + +cdef class MapScalar(ListScalar): + """ + Concrete class for map scalars. + """ + + def __getitem__(self, i): + """ + Return the value at the given index. + """ + arr = self.values + if arr is None: + raise IndexError(i) + dct = arr[_normalize_index(i, len(arr))] + return (dct['key'], dct['value']) + + def __iter__(self): + """ + Iterate over this element's values. + """ + arr = self.values + if array is None: + raise StopIteration + for k, v in zip(arr.field('key'), arr.field('value')): + yield (k.as_py(), v.as_py()) + + def as_py(self): + """ + Return this value as a Python list. + """ + cdef CStructScalar* sp = <CStructScalar*> self.wrapped.get() + return list(self) if sp.is_valid else None + + +cdef class DictionaryScalar(Scalar): + """ + Concrete class for dictionary-encoded scalars. + """ + + @classmethod + def _reconstruct(cls, type, is_valid, index, dictionary): + cdef: + CDictionaryScalarIndexAndDictionary value + shared_ptr[CDictionaryScalar] wrapped + DataType type_ + Scalar index_ + Array dictionary_ + + type_ = ensure_type(type, allow_none=False) + if not isinstance(type_, DictionaryType): + raise TypeError('Must pass a DictionaryType instance') + + if isinstance(index, Scalar): + if not index.type.equals(type.index_type): + raise TypeError("The Scalar value passed as index must have " + "identical type to the dictionary type's " + "index_type") + index_ = index + else: + index_ = scalar(index, type=type_.index_type) + + if isinstance(dictionary, Array): + if not dictionary.type.equals(type.value_type): + raise TypeError("The Array passed as dictionary must have " + "identical type to the dictionary type's " + "value_type") + dictionary_ = dictionary + else: + dictionary_ = array(dictionary, type=type_.value_type) + + value.index = pyarrow_unwrap_scalar(index_) + value.dictionary = pyarrow_unwrap_array(dictionary_) + + wrapped = make_shared[CDictionaryScalar]( + value, pyarrow_unwrap_data_type(type_), <c_bool>(is_valid) + ) + return Scalar.wrap(<shared_ptr[CScalar]> wrapped) + + def __reduce__(self): + return DictionaryScalar._reconstruct, ( + self.type, self.is_valid, self.index, self.dictionary + ) + + @property + def index(self): + """ + Return this value's underlying index as a scalar. + """ + cdef CDictionaryScalar* sp = <CDictionaryScalar*> self.wrapped.get() + return Scalar.wrap(sp.value.index) + + @property + def value(self): + """ + Return the encoded value as a scalar. + """ + cdef CDictionaryScalar* sp = <CDictionaryScalar*> self.wrapped.get() + return Scalar.wrap(GetResultValue(sp.GetEncodedValue())) + + @property + def dictionary(self): + cdef CDictionaryScalar* sp = <CDictionaryScalar*> self.wrapped.get() + return pyarrow_wrap_array(sp.value.dictionary) + + def as_py(self): + """ + Return this encoded value as a Python object. + """ + return self.value.as_py() if self.is_valid else None + + @property + def index_value(self): + warnings.warn("`index_value` property is deprecated as of 1.0.0" + "please use the `index` property instead", + FutureWarning) + return self.index + + @property + def dictionary_value(self): + warnings.warn("`dictionary_value` property is deprecated as of 1.0.0, " + "please use the `value` property instead", FutureWarning) + return self.value + + +cdef class UnionScalar(Scalar): + """ + Concrete class for Union scalars. + """ + + @property + def value(self): + """ + Return underlying value as a scalar. + """ + cdef CUnionScalar* sp = <CUnionScalar*> self.wrapped.get() + return Scalar.wrap(sp.value) if sp.is_valid else None + + def as_py(self): + """ + Return underlying value as a Python object. + """ + value = self.value + return None if value is None else value.as_py() + + @property + def type_code(self): + """ + Return the union type code for this scalar. + """ + cdef CUnionScalar* sp = <CUnionScalar*> self.wrapped.get() + return sp.type_code + + +cdef class ExtensionScalar(Scalar): + """ + Concrete class for Extension scalars. + """ + + @property + def value(self): + """ + Return storage value as a scalar. + """ + cdef CExtensionScalar* sp = <CExtensionScalar*> self.wrapped.get() + return Scalar.wrap(sp.value) if sp.is_valid else None + + def as_py(self): + """ + Return this scalar as a Python object. + """ + # XXX should there be a hook to wrap the result in a custom class? + value = self.value + return None if value is None else value.as_py() + + @staticmethod + def from_storage(BaseExtensionType typ, value): + """ + Construct ExtensionScalar from type and storage value. + + Parameters + ---------- + typ : DataType + The extension type for the result scalar. + value : object + The storage value for the result scalar. + + Returns + ------- + ext_scalar : ExtensionScalar + """ + cdef: + shared_ptr[CExtensionScalar] sp_scalar + CExtensionScalar* ext_scalar + + if value is None: + storage = None + elif isinstance(value, Scalar): + if value.type != typ.storage_type: + raise TypeError("Incompatible storage type {0} " + "for extension type {1}" + .format(value.type, typ)) + storage = value + else: + storage = scalar(value, typ.storage_type) + + sp_scalar = make_shared[CExtensionScalar](typ.sp_type) + ext_scalar = sp_scalar.get() + ext_scalar.is_valid = storage is not None and storage.is_valid + if ext_scalar.is_valid: + ext_scalar.value = pyarrow_unwrap_scalar(storage) + check_status(ext_scalar.Validate()) + return pyarrow_wrap_scalar(<shared_ptr[CScalar]> sp_scalar) + + +cdef dict _scalar_classes = { + _Type_BOOL: BooleanScalar, + _Type_UINT8: UInt8Scalar, + _Type_UINT16: UInt16Scalar, + _Type_UINT32: UInt32Scalar, + _Type_UINT64: UInt64Scalar, + _Type_INT8: Int8Scalar, + _Type_INT16: Int16Scalar, + _Type_INT32: Int32Scalar, + _Type_INT64: Int64Scalar, + _Type_HALF_FLOAT: HalfFloatScalar, + _Type_FLOAT: FloatScalar, + _Type_DOUBLE: DoubleScalar, + _Type_DECIMAL128: Decimal128Scalar, + _Type_DECIMAL256: Decimal256Scalar, + _Type_DATE32: Date32Scalar, + _Type_DATE64: Date64Scalar, + _Type_TIME32: Time32Scalar, + _Type_TIME64: Time64Scalar, + _Type_TIMESTAMP: TimestampScalar, + _Type_DURATION: DurationScalar, + _Type_BINARY: BinaryScalar, + _Type_LARGE_BINARY: LargeBinaryScalar, + _Type_FIXED_SIZE_BINARY: FixedSizeBinaryScalar, + _Type_STRING: StringScalar, + _Type_LARGE_STRING: LargeStringScalar, + _Type_LIST: ListScalar, + _Type_LARGE_LIST: LargeListScalar, + _Type_FIXED_SIZE_LIST: FixedSizeListScalar, + _Type_STRUCT: StructScalar, + _Type_MAP: MapScalar, + _Type_DICTIONARY: DictionaryScalar, + _Type_SPARSE_UNION: UnionScalar, + _Type_DENSE_UNION: UnionScalar, + _Type_INTERVAL_MONTH_DAY_NANO: MonthDayNanoIntervalScalar, + _Type_EXTENSION: ExtensionScalar, +} + + +def scalar(value, type=None, *, from_pandas=None, MemoryPool memory_pool=None): + """ + Create a pyarrow.Scalar instance from a Python object. + + Parameters + ---------- + value : Any + Python object coercible to arrow's type system. + type : pyarrow.DataType + Explicit type to attempt to coerce to, otherwise will be inferred from + the value. + from_pandas : bool, default None + Use pandas's semantics for inferring nulls from values in + ndarray-like data. Defaults to False if not passed explicitly by user, + or True if a pandas object is passed in. + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the currently-set default + memory pool. + + Returns + ------- + scalar : pyarrow.Scalar + + Examples + -------- + >>> import pyarrow as pa + + >>> pa.scalar(42) + <pyarrow.Int64Scalar: 42> + + >>> pa.scalar("string") + <pyarrow.StringScalar: 'string'> + + >>> pa.scalar([1, 2]) + <pyarrow.ListScalar: [1, 2]> + + >>> pa.scalar([1, 2], type=pa.list_(pa.int16())) + <pyarrow.ListScalar: [1, 2]> + """ + cdef: + DataType ty + PyConversionOptions options + shared_ptr[CScalar] scalar + shared_ptr[CArray] array + shared_ptr[CChunkedArray] chunked + bint is_pandas_object = False + CMemoryPool* pool + + type = ensure_type(type, allow_none=True) + pool = maybe_unbox_memory_pool(memory_pool) + + if _is_array_like(value): + value = get_values(value, &is_pandas_object) + + options.size = 1 + + if type is not None: + ty = ensure_type(type) + options.type = ty.sp_type + + if from_pandas is None: + options.from_pandas = is_pandas_object + else: + options.from_pandas = from_pandas + + value = [value] + with nogil: + chunked = GetResultValue(ConvertPySequence(value, None, options, pool)) + + # get the first chunk + assert chunked.get().num_chunks() == 1 + array = chunked.get().chunk(0) + + # retrieve the scalar from the first position + scalar = GetResultValue(array.get().GetScalar(0)) + return Scalar.wrap(scalar) diff --git a/src/arrow/python/pyarrow/serialization.pxi b/src/arrow/python/pyarrow/serialization.pxi new file mode 100644 index 000000000..c03721578 --- /dev/null +++ b/src/arrow/python/pyarrow/serialization.pxi @@ -0,0 +1,556 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from cpython.ref cimport PyObject + +import warnings + + +def _deprecate_serialization(name): + msg = ( + "'pyarrow.{}' is deprecated as of 2.0.0 and will be removed in a " + "future version. Use pickle or the pyarrow IPC functionality instead." + ).format(name) + warnings.warn(msg, FutureWarning, stacklevel=3) + + +def is_named_tuple(cls): + """ + Return True if cls is a namedtuple and False otherwise. + """ + b = cls.__bases__ + if len(b) != 1 or b[0] != tuple: + return False + f = getattr(cls, "_fields", None) + if not isinstance(f, tuple): + return False + return all(isinstance(n, str) for n in f) + + +class SerializationCallbackError(ArrowSerializationError): + def __init__(self, message, example_object): + ArrowSerializationError.__init__(self, message) + self.example_object = example_object + + +class DeserializationCallbackError(ArrowSerializationError): + def __init__(self, message, type_id): + ArrowSerializationError.__init__(self, message) + self.type_id = type_id + + +cdef class SerializationContext(_Weakrefable): + cdef: + object type_to_type_id + object whitelisted_types + object types_to_pickle + object custom_serializers + object custom_deserializers + object pickle_serializer + object pickle_deserializer + + def __init__(self): + # Types with special serialization handlers + self.type_to_type_id = dict() + self.whitelisted_types = dict() + self.types_to_pickle = set() + self.custom_serializers = dict() + self.custom_deserializers = dict() + self.pickle_serializer = pickle.dumps + self.pickle_deserializer = pickle.loads + + def set_pickle(self, serializer, deserializer): + """ + Set the serializer and deserializer to use for objects that are to be + pickled. + + Parameters + ---------- + serializer : callable + The serializer to use (e.g., pickle.dumps or cloudpickle.dumps). + deserializer : callable + The deserializer to use (e.g., pickle.dumps or cloudpickle.dumps). + """ + self.pickle_serializer = serializer + self.pickle_deserializer = deserializer + + def clone(self): + """ + Return copy of this SerializationContext. + + Returns + ------- + clone : SerializationContext + """ + result = SerializationContext() + result.type_to_type_id = self.type_to_type_id.copy() + result.whitelisted_types = self.whitelisted_types.copy() + result.types_to_pickle = self.types_to_pickle.copy() + result.custom_serializers = self.custom_serializers.copy() + result.custom_deserializers = self.custom_deserializers.copy() + result.pickle_serializer = self.pickle_serializer + result.pickle_deserializer = self.pickle_deserializer + + return result + + def register_type(self, type_, type_id, pickle=False, + custom_serializer=None, custom_deserializer=None): + r""" + EXPERIMENTAL: Add type to the list of types we can serialize. + + Parameters + ---------- + type\_ : type + The type that we can serialize. + type_id : string + A string used to identify the type. + pickle : bool + True if the serialization should be done with pickle. + False if it should be done efficiently with Arrow. + custom_serializer : callable + This argument is optional, but can be provided to + serialize objects of the class in a particular way. + custom_deserializer : callable + This argument is optional, but can be provided to + deserialize objects of the class in a particular way. + """ + if not isinstance(type_id, str): + raise TypeError("The type_id argument must be a string. The value " + "passed in has type {}.".format(type(type_id))) + + self.type_to_type_id[type_] = type_id + self.whitelisted_types[type_id] = type_ + if pickle: + self.types_to_pickle.add(type_id) + if custom_serializer is not None: + self.custom_serializers[type_id] = custom_serializer + self.custom_deserializers[type_id] = custom_deserializer + + def _serialize_callback(self, obj): + found = False + for type_ in type(obj).__mro__: + if type_ in self.type_to_type_id: + found = True + break + + if not found: + raise SerializationCallbackError( + "pyarrow does not know how to " + "serialize objects of type {}.".format(type(obj)), obj + ) + + # use the closest match to type(obj) + type_id = self.type_to_type_id[type_] + if type_id in self.types_to_pickle: + serialized_obj = {"data": self.pickle_serializer(obj), + "pickle": True} + elif type_id in self.custom_serializers: + serialized_obj = {"data": self.custom_serializers[type_id](obj)} + else: + if is_named_tuple(type_): + serialized_obj = {} + serialized_obj["_pa_getnewargs_"] = obj.__getnewargs__() + elif hasattr(obj, "__dict__"): + serialized_obj = obj.__dict__ + else: + msg = "We do not know how to serialize " \ + "the object '{}'".format(obj) + raise SerializationCallbackError(msg, obj) + return dict(serialized_obj, **{"_pytype_": type_id}) + + def _deserialize_callback(self, serialized_obj): + type_id = serialized_obj["_pytype_"] + if isinstance(type_id, bytes): + # ARROW-4675: Python 2 serialized, read in Python 3 + type_id = frombytes(type_id) + + if "pickle" in serialized_obj: + # The object was pickled, so unpickle it. + obj = self.pickle_deserializer(serialized_obj["data"]) + else: + assert type_id not in self.types_to_pickle + if type_id not in self.whitelisted_types: + msg = "Type ID " + type_id + " not registered in " \ + "deserialization callback" + raise DeserializationCallbackError(msg, type_id) + type_ = self.whitelisted_types[type_id] + if type_id in self.custom_deserializers: + obj = self.custom_deserializers[type_id]( + serialized_obj["data"]) + else: + # In this case, serialized_obj should just be + # the __dict__ field. + if "_pa_getnewargs_" in serialized_obj: + obj = type_.__new__( + type_, *serialized_obj["_pa_getnewargs_"]) + else: + obj = type_.__new__(type_) + serialized_obj.pop("_pytype_") + obj.__dict__.update(serialized_obj) + return obj + + def serialize(self, obj): + """ + Call pyarrow.serialize and pass this SerializationContext. + """ + return serialize(obj, context=self) + + def serialize_to(self, object value, sink): + """ + Call pyarrow.serialize_to and pass this SerializationContext. + """ + return serialize_to(value, sink, context=self) + + def deserialize(self, what): + """ + Call pyarrow.deserialize and pass this SerializationContext. + """ + return deserialize(what, context=self) + + def deserialize_components(self, what): + """ + Call pyarrow.deserialize_components and pass this SerializationContext. + """ + return deserialize_components(what, context=self) + + +_default_serialization_context = SerializationContext() +_default_context_initialized = False + + +def _get_default_context(): + global _default_context_initialized + from pyarrow.serialization import _register_default_serialization_handlers + if not _default_context_initialized: + _register_default_serialization_handlers( + _default_serialization_context) + _default_context_initialized = True + return _default_serialization_context + + +cdef class SerializedPyObject(_Weakrefable): + """ + Arrow-serialized representation of Python object. + """ + cdef: + CSerializedPyObject data + + cdef readonly: + object base + + @property + def total_bytes(self): + cdef CMockOutputStream mock_stream + with nogil: + check_status(self.data.WriteTo(&mock_stream)) + + return mock_stream.GetExtentBytesWritten() + + def write_to(self, sink): + """ + Write serialized object to a sink. + """ + cdef shared_ptr[COutputStream] stream + get_writer(sink, &stream) + self._write_to(stream.get()) + + cdef _write_to(self, COutputStream* stream): + with nogil: + check_status(self.data.WriteTo(stream)) + + def deserialize(self, SerializationContext context=None): + """ + Convert back to Python object. + """ + cdef PyObject* result + + if context is None: + context = _get_default_context() + + with nogil: + check_status(DeserializeObject(context, self.data, + <PyObject*> self.base, &result)) + + # PyObject_to_object is necessary to avoid a memory leak; + # also unpack the list the object was wrapped in in serialize + return PyObject_to_object(result)[0] + + def to_buffer(self, nthreads=1): + """ + Write serialized data as Buffer. + """ + cdef Buffer output = allocate_buffer(self.total_bytes) + sink = FixedSizeBufferWriter(output) + if nthreads > 1: + sink.set_memcopy_threads(nthreads) + self.write_to(sink) + return output + + @staticmethod + def from_components(components): + """ + Reconstruct SerializedPyObject from output of + SerializedPyObject.to_components. + """ + cdef: + int num_tensors = components['num_tensors'] + int num_ndarrays = components['num_ndarrays'] + int num_buffers = components['num_buffers'] + list buffers = components['data'] + SparseTensorCounts num_sparse_tensors = SparseTensorCounts() + SerializedPyObject result = SerializedPyObject() + + num_sparse_tensors.coo = components['num_sparse_tensors']['coo'] + num_sparse_tensors.csr = components['num_sparse_tensors']['csr'] + num_sparse_tensors.csc = components['num_sparse_tensors']['csc'] + num_sparse_tensors.csf = components['num_sparse_tensors']['csf'] + num_sparse_tensors.ndim_csf = \ + components['num_sparse_tensors']['ndim_csf'] + + with nogil: + check_status(GetSerializedFromComponents(num_tensors, + num_sparse_tensors, + num_ndarrays, + num_buffers, + buffers, &result.data)) + + return result + + def to_components(self, memory_pool=None): + """ + Return the decomposed dict representation of the serialized object + containing a collection of Buffer objects which maximize opportunities + for zero-copy. + + Parameters + ---------- + memory_pool : MemoryPool default None + Pool to use for necessary allocations. + + Returns + + """ + cdef PyObject* result + cdef CMemoryPool* c_pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + check_status(self.data.GetComponents(c_pool, &result)) + + return PyObject_to_object(result) + + +def serialize(object value, SerializationContext context=None): + """ + DEPRECATED: Serialize a general Python sequence for transient storage + and transport. + + .. deprecated:: 2.0 + The custom serialization functionality is deprecated in pyarrow 2.0, + and will be removed in a future version. Use the standard library + ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for + more). + + Notes + ----- + This function produces data that is incompatible with the standard + Arrow IPC binary protocol, i.e. it cannot be used with ipc.open_stream or + ipc.open_file. You can use deserialize, deserialize_from, or + deserialize_components to read it. + + Parameters + ---------- + value : object + Python object for the sequence that is to be serialized. + context : SerializationContext + Custom serialization and deserialization context, uses a default + context with some standard type handlers if not specified. + + Returns + ------- + serialized : SerializedPyObject + + """ + _deprecate_serialization("serialize") + return _serialize(value, context) + + +def _serialize(object value, SerializationContext context=None): + cdef SerializedPyObject serialized = SerializedPyObject() + wrapped_value = [value] + + if context is None: + context = _get_default_context() + + with nogil: + check_status(SerializeObject(context, wrapped_value, &serialized.data)) + return serialized + + +def serialize_to(object value, sink, SerializationContext context=None): + """ + DEPRECATED: Serialize a Python sequence to a file. + + .. deprecated:: 2.0 + The custom serialization functionality is deprecated in pyarrow 2.0, + and will be removed in a future version. Use the standard library + ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for + more). + + Parameters + ---------- + value : object + Python object for the sequence that is to be serialized. + sink : NativeFile or file-like + File the sequence will be written to. + context : SerializationContext + Custom serialization and deserialization context, uses a default + context with some standard type handlers if not specified. + """ + _deprecate_serialization("serialize_to") + serialized = _serialize(value, context) + serialized.write_to(sink) + + +def read_serialized(source, base=None): + """ + DEPRECATED: Read serialized Python sequence from file-like object. + + .. deprecated:: 2.0 + The custom serialization functionality is deprecated in pyarrow 2.0, + and will be removed in a future version. Use the standard library + ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for + more). + + Parameters + ---------- + source : NativeFile + File to read the sequence from. + base : object + This object will be the base object of all the numpy arrays + contained in the sequence. + + Returns + ------- + serialized : the serialized data + """ + _deprecate_serialization("read_serialized") + return _read_serialized(source, base=base) + + +def _read_serialized(source, base=None): + cdef shared_ptr[CRandomAccessFile] stream + get_reader(source, True, &stream) + + cdef SerializedPyObject serialized = SerializedPyObject() + serialized.base = base + with nogil: + check_status(ReadSerializedObject(stream.get(), &serialized.data)) + + return serialized + + +def deserialize_from(source, object base, SerializationContext context=None): + """ + DEPRECATED: Deserialize a Python sequence from a file. + + .. deprecated:: 2.0 + The custom serialization functionality is deprecated in pyarrow 2.0, + and will be removed in a future version. Use the standard library + ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for + more). + + This only can interact with data produced by pyarrow.serialize or + pyarrow.serialize_to. + + Parameters + ---------- + source : NativeFile + File to read the sequence from. + base : object + This object will be the base object of all the numpy arrays + contained in the sequence. + context : SerializationContext + Custom serialization and deserialization context. + + Returns + ------- + object + Python object for the deserialized sequence. + """ + _deprecate_serialization("deserialize_from") + serialized = _read_serialized(source, base=base) + return serialized.deserialize(context) + + +def deserialize_components(components, SerializationContext context=None): + """ + DEPRECATED: Reconstruct Python object from output of + SerializedPyObject.to_components. + + .. deprecated:: 2.0 + The custom serialization functionality is deprecated in pyarrow 2.0, + and will be removed in a future version. Use the standard library + ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for + more). + + Parameters + ---------- + components : dict + Output of SerializedPyObject.to_components + context : SerializationContext, default None + + Returns + ------- + object : the Python object that was originally serialized + """ + _deprecate_serialization("deserialize_components") + serialized = SerializedPyObject.from_components(components) + return serialized.deserialize(context) + + +def deserialize(obj, SerializationContext context=None): + """ + DEPRECATED: Deserialize Python object from Buffer or other Python + object supporting the buffer protocol. + + .. deprecated:: 2.0 + The custom serialization functionality is deprecated in pyarrow 2.0, + and will be removed in a future version. Use the standard library + ``pickle`` or the IPC functionality of pyarrow (see :ref:`ipc` for + more). + + This only can interact with data produced by pyarrow.serialize or + pyarrow.serialize_to. + + Parameters + ---------- + obj : pyarrow.Buffer or Python object supporting buffer protocol + context : SerializationContext + Custom serialization and deserialization context. + + Returns + ------- + deserialized : object + """ + _deprecate_serialization("deserialize") + return _deserialize(obj, context=context) + + +def _deserialize(obj, SerializationContext context=None): + source = BufferReader(obj) + serialized = _read_serialized(source, base=obj) + return serialized.deserialize(context) diff --git a/src/arrow/python/pyarrow/serialization.py b/src/arrow/python/pyarrow/serialization.py new file mode 100644 index 000000000..d59a13166 --- /dev/null +++ b/src/arrow/python/pyarrow/serialization.py @@ -0,0 +1,504 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import collections +import warnings + +import numpy as np + +import pyarrow as pa +from pyarrow.lib import SerializationContext, py_buffer, builtin_pickle + +try: + import cloudpickle +except ImportError: + cloudpickle = builtin_pickle + + +try: + # This function is available after numpy-0.16.0. + # See also: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py + from numpy.lib.format import descr_to_dtype +except ImportError: + def descr_to_dtype(descr): + ''' + descr may be stored as dtype.descr, which is a list of (name, format, + [shape]) tuples where format may be a str or a tuple. Offsets are not + explicitly saved, rather empty fields with name, format == '', '|Vn' + are added as padding. This function reverses the process, eliminating + the empty padding fields. + ''' + if isinstance(descr, str): + # No padding removal needed + return np.dtype(descr) + elif isinstance(descr, tuple): + # subtype, will always have a shape descr[1] + dt = descr_to_dtype(descr[0]) + return np.dtype((dt, descr[1])) + fields = [] + offset = 0 + for field in descr: + if len(field) == 2: + name, descr_str = field + dt = descr_to_dtype(descr_str) + else: + name, descr_str, shape = field + dt = np.dtype((descr_to_dtype(descr_str), shape)) + + # Ignore padding bytes, which will be void bytes with '' as name + # Once support for blank names is removed, only "if name == ''" + # needed) + is_pad = (name == '' and dt.type is np.void and dt.names is None) + if not is_pad: + fields.append((name, dt, offset)) + + offset += dt.itemsize + + names, formats, offsets = zip(*fields) + # names may be (title, names) tuples + nametups = (n if isinstance(n, tuple) else (None, n) for n in names) + titles, names = zip(*nametups) + return np.dtype({'names': names, 'formats': formats, 'titles': titles, + 'offsets': offsets, 'itemsize': offset}) + + +def _deprecate_serialization(name): + msg = ( + "'pyarrow.{}' is deprecated as of 2.0.0 and will be removed in a " + "future version. Use pickle or the pyarrow IPC functionality instead." + ).format(name) + warnings.warn(msg, FutureWarning, stacklevel=3) + + +# ---------------------------------------------------------------------- +# Set up serialization for numpy with dtype object (primitive types are +# handled efficiently with Arrow's Tensor facilities, see +# python_to_arrow.cc) + +def _serialize_numpy_array_list(obj): + if obj.dtype.str != '|O': + # Make the array c_contiguous if necessary so that we can call change + # the view. + if not obj.flags.c_contiguous: + obj = np.ascontiguousarray(obj) + return obj.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype) + else: + return obj.tolist(), np.lib.format.dtype_to_descr(obj.dtype) + + +def _deserialize_numpy_array_list(data): + if data[1] != '|O': + assert data[0].dtype == np.uint8 + return data[0].view(descr_to_dtype(data[1])) + else: + return np.array(data[0], dtype=np.dtype(data[1])) + + +def _serialize_numpy_matrix(obj): + if obj.dtype.str != '|O': + # Make the array c_contiguous if necessary so that we can call change + # the view. + if not obj.flags.c_contiguous: + obj = np.ascontiguousarray(obj.A) + return obj.A.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype) + else: + return obj.A.tolist(), np.lib.format.dtype_to_descr(obj.dtype) + + +def _deserialize_numpy_matrix(data): + if data[1] != '|O': + assert data[0].dtype == np.uint8 + return np.matrix(data[0].view(descr_to_dtype(data[1])), + copy=False) + else: + return np.matrix(data[0], dtype=np.dtype(data[1]), copy=False) + + +# ---------------------------------------------------------------------- +# pyarrow.RecordBatch-specific serialization matters + +def _serialize_pyarrow_recordbatch(batch): + output_stream = pa.BufferOutputStream() + with pa.RecordBatchStreamWriter(output_stream, schema=batch.schema) as wr: + wr.write_batch(batch) + return output_stream.getvalue() # This will also close the stream. + + +def _deserialize_pyarrow_recordbatch(buf): + with pa.RecordBatchStreamReader(buf) as reader: + return reader.read_next_batch() + + +# ---------------------------------------------------------------------- +# pyarrow.Array-specific serialization matters + +def _serialize_pyarrow_array(array): + # TODO(suquark): implement more effcient array serialization. + batch = pa.RecordBatch.from_arrays([array], ['']) + return _serialize_pyarrow_recordbatch(batch) + + +def _deserialize_pyarrow_array(buf): + # TODO(suquark): implement more effcient array deserialization. + batch = _deserialize_pyarrow_recordbatch(buf) + return batch.columns[0] + + +# ---------------------------------------------------------------------- +# pyarrow.Table-specific serialization matters + +def _serialize_pyarrow_table(table): + output_stream = pa.BufferOutputStream() + with pa.RecordBatchStreamWriter(output_stream, schema=table.schema) as wr: + wr.write_table(table) + return output_stream.getvalue() # This will also close the stream. + + +def _deserialize_pyarrow_table(buf): + with pa.RecordBatchStreamReader(buf) as reader: + return reader.read_all() + + +def _pickle_to_buffer(x): + pickled = builtin_pickle.dumps(x, protocol=builtin_pickle.HIGHEST_PROTOCOL) + return py_buffer(pickled) + + +def _load_pickle_from_buffer(data): + as_memoryview = memoryview(data) + return builtin_pickle.loads(as_memoryview) + + +# ---------------------------------------------------------------------- +# pandas-specific serialization matters + +def _register_custom_pandas_handlers(context): + # ARROW-1784, faster path for pandas-only visibility + + try: + import pandas as pd + except ImportError: + return + + import pyarrow.pandas_compat as pdcompat + + sparse_type_error_msg = ( + '{0} serialization is not supported.\n' + 'Note that {0} is planned to be deprecated ' + 'in pandas future releases.\n' + 'See https://github.com/pandas-dev/pandas/issues/19239 ' + 'for more information.' + ) + + def _serialize_pandas_dataframe(obj): + if (pdcompat._pandas_api.has_sparse and + isinstance(obj, pd.SparseDataFrame)): + raise NotImplementedError( + sparse_type_error_msg.format('SparseDataFrame') + ) + + return pdcompat.dataframe_to_serialized_dict(obj) + + def _deserialize_pandas_dataframe(data): + return pdcompat.serialized_dict_to_dataframe(data) + + def _serialize_pandas_series(obj): + if (pdcompat._pandas_api.has_sparse and + isinstance(obj, pd.SparseSeries)): + raise NotImplementedError( + sparse_type_error_msg.format('SparseSeries') + ) + + return _serialize_pandas_dataframe(pd.DataFrame({obj.name: obj})) + + def _deserialize_pandas_series(data): + deserialized = _deserialize_pandas_dataframe(data) + return deserialized[deserialized.columns[0]] + + context.register_type( + pd.Series, 'pd.Series', + custom_serializer=_serialize_pandas_series, + custom_deserializer=_deserialize_pandas_series) + + context.register_type( + pd.Index, 'pd.Index', + custom_serializer=_pickle_to_buffer, + custom_deserializer=_load_pickle_from_buffer) + + if hasattr(pd.core, 'arrays'): + if hasattr(pd.core.arrays, 'interval'): + context.register_type( + pd.core.arrays.interval.IntervalArray, + 'pd.core.arrays.interval.IntervalArray', + custom_serializer=_pickle_to_buffer, + custom_deserializer=_load_pickle_from_buffer) + + if hasattr(pd.core.arrays, 'period'): + context.register_type( + pd.core.arrays.period.PeriodArray, + 'pd.core.arrays.period.PeriodArray', + custom_serializer=_pickle_to_buffer, + custom_deserializer=_load_pickle_from_buffer) + + if hasattr(pd.core.arrays, 'datetimes'): + context.register_type( + pd.core.arrays.datetimes.DatetimeArray, + 'pd.core.arrays.datetimes.DatetimeArray', + custom_serializer=_pickle_to_buffer, + custom_deserializer=_load_pickle_from_buffer) + + context.register_type( + pd.DataFrame, 'pd.DataFrame', + custom_serializer=_serialize_pandas_dataframe, + custom_deserializer=_deserialize_pandas_dataframe) + + +def register_torch_serialization_handlers(serialization_context): + # ---------------------------------------------------------------------- + # Set up serialization for pytorch tensors + _deprecate_serialization("register_torch_serialization_handlers") + + try: + import torch + + def _serialize_torch_tensor(obj): + if obj.is_sparse: + return pa.SparseCOOTensor.from_numpy( + obj._values().detach().numpy(), + obj._indices().detach().numpy().T, + shape=list(obj.shape)) + else: + return obj.detach().numpy() + + def _deserialize_torch_tensor(data): + if isinstance(data, pa.SparseCOOTensor): + return torch.sparse_coo_tensor( + indices=data.to_numpy()[1].T, + values=data.to_numpy()[0][:, 0], + size=data.shape) + else: + return torch.from_numpy(data) + + for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor, + torch.ByteTensor, torch.CharTensor, torch.ShortTensor, + torch.IntTensor, torch.LongTensor, torch.Tensor]: + serialization_context.register_type( + t, "torch." + t.__name__, + custom_serializer=_serialize_torch_tensor, + custom_deserializer=_deserialize_torch_tensor) + except ImportError: + # no torch + pass + + +def _register_collections_serialization_handlers(serialization_context): + def _serialize_deque(obj): + return list(obj) + + def _deserialize_deque(data): + return collections.deque(data) + + serialization_context.register_type( + collections.deque, "collections.deque", + custom_serializer=_serialize_deque, + custom_deserializer=_deserialize_deque) + + def _serialize_ordered_dict(obj): + return list(obj.keys()), list(obj.values()) + + def _deserialize_ordered_dict(data): + return collections.OrderedDict(zip(data[0], data[1])) + + serialization_context.register_type( + collections.OrderedDict, "collections.OrderedDict", + custom_serializer=_serialize_ordered_dict, + custom_deserializer=_deserialize_ordered_dict) + + def _serialize_default_dict(obj): + return list(obj.keys()), list(obj.values()), obj.default_factory + + def _deserialize_default_dict(data): + return collections.defaultdict(data[2], zip(data[0], data[1])) + + serialization_context.register_type( + collections.defaultdict, "collections.defaultdict", + custom_serializer=_serialize_default_dict, + custom_deserializer=_deserialize_default_dict) + + def _serialize_counter(obj): + return list(obj.keys()), list(obj.values()) + + def _deserialize_counter(data): + return collections.Counter(dict(zip(data[0], data[1]))) + + serialization_context.register_type( + collections.Counter, "collections.Counter", + custom_serializer=_serialize_counter, + custom_deserializer=_deserialize_counter) + + +# ---------------------------------------------------------------------- +# Set up serialization for scipy sparse matrices. Primitive types are handled +# efficiently with Arrow's SparseTensor facilities, see numpy_convert.cc) + +def _register_scipy_handlers(serialization_context): + try: + from scipy.sparse import (csr_matrix, csc_matrix, coo_matrix, + isspmatrix_coo, isspmatrix_csr, + isspmatrix_csc, isspmatrix) + + def _serialize_scipy_sparse(obj): + if isspmatrix_coo(obj): + return 'coo', pa.SparseCOOTensor.from_scipy(obj) + + elif isspmatrix_csr(obj): + return 'csr', pa.SparseCSRMatrix.from_scipy(obj) + + elif isspmatrix_csc(obj): + return 'csc', pa.SparseCSCMatrix.from_scipy(obj) + + elif isspmatrix(obj): + return 'csr', pa.SparseCOOTensor.from_scipy(obj.to_coo()) + + else: + raise NotImplementedError( + "Serialization of {} is not supported.".format(obj[0])) + + def _deserialize_scipy_sparse(data): + if data[0] == 'coo': + return data[1].to_scipy() + + elif data[0] == 'csr': + return data[1].to_scipy() + + elif data[0] == 'csc': + return data[1].to_scipy() + + else: + return data[1].to_scipy() + + serialization_context.register_type( + coo_matrix, 'scipy.sparse.coo.coo_matrix', + custom_serializer=_serialize_scipy_sparse, + custom_deserializer=_deserialize_scipy_sparse) + + serialization_context.register_type( + csr_matrix, 'scipy.sparse.csr.csr_matrix', + custom_serializer=_serialize_scipy_sparse, + custom_deserializer=_deserialize_scipy_sparse) + + serialization_context.register_type( + csc_matrix, 'scipy.sparse.csc.csc_matrix', + custom_serializer=_serialize_scipy_sparse, + custom_deserializer=_deserialize_scipy_sparse) + + except ImportError: + # no scipy + pass + + +# ---------------------------------------------------------------------- +# Set up serialization for pydata/sparse tensors. + +def _register_pydata_sparse_handlers(serialization_context): + try: + import sparse + + def _serialize_pydata_sparse(obj): + if isinstance(obj, sparse.COO): + return 'coo', pa.SparseCOOTensor.from_pydata_sparse(obj) + else: + raise NotImplementedError( + "Serialization of {} is not supported.".format(sparse.COO)) + + def _deserialize_pydata_sparse(data): + if data[0] == 'coo': + data_array, coords = data[1].to_numpy() + return sparse.COO( + data=data_array[:, 0], + coords=coords.T, shape=data[1].shape) + + serialization_context.register_type( + sparse.COO, 'sparse.COO', + custom_serializer=_serialize_pydata_sparse, + custom_deserializer=_deserialize_pydata_sparse) + + except ImportError: + # no pydata/sparse + pass + + +def _register_default_serialization_handlers(serialization_context): + + # ---------------------------------------------------------------------- + # Set up serialization for primitive datatypes + + # TODO(pcm): This is currently a workaround until arrow supports + # arbitrary precision integers. This is only called on long integers, + # see the associated case in the append method in python_to_arrow.cc + serialization_context.register_type( + int, "int", + custom_serializer=lambda obj: str(obj), + custom_deserializer=lambda data: int(data)) + + serialization_context.register_type( + type(lambda: 0), "function", + pickle=True) + + serialization_context.register_type(type, "type", pickle=True) + + serialization_context.register_type( + np.matrix, 'np.matrix', + custom_serializer=_serialize_numpy_matrix, + custom_deserializer=_deserialize_numpy_matrix) + + serialization_context.register_type( + np.ndarray, 'np.array', + custom_serializer=_serialize_numpy_array_list, + custom_deserializer=_deserialize_numpy_array_list) + + serialization_context.register_type( + pa.Array, 'pyarrow.Array', + custom_serializer=_serialize_pyarrow_array, + custom_deserializer=_deserialize_pyarrow_array) + + serialization_context.register_type( + pa.RecordBatch, 'pyarrow.RecordBatch', + custom_serializer=_serialize_pyarrow_recordbatch, + custom_deserializer=_deserialize_pyarrow_recordbatch) + + serialization_context.register_type( + pa.Table, 'pyarrow.Table', + custom_serializer=_serialize_pyarrow_table, + custom_deserializer=_deserialize_pyarrow_table) + + _register_collections_serialization_handlers(serialization_context) + _register_custom_pandas_handlers(serialization_context) + _register_scipy_handlers(serialization_context) + _register_pydata_sparse_handlers(serialization_context) + + +def register_default_serialization_handlers(serialization_context): + _deprecate_serialization("register_default_serialization_handlers") + _register_default_serialization_handlers(serialization_context) + + +def default_serialization_context(): + _deprecate_serialization("default_serialization_context") + context = SerializationContext() + _register_default_serialization_handlers(context) + return context diff --git a/src/arrow/python/pyarrow/table.pxi b/src/arrow/python/pyarrow/table.pxi new file mode 100644 index 000000000..8105ce482 --- /dev/null +++ b/src/arrow/python/pyarrow/table.pxi @@ -0,0 +1,2389 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import warnings + + +cdef class ChunkedArray(_PandasConvertible): + """ + An array-like composed from a (possibly empty) collection of pyarrow.Arrays + + Warnings + -------- + Do not call this class's constructor directly. + """ + + def __cinit__(self): + self.chunked_array = NULL + + def __init__(self): + raise TypeError("Do not call ChunkedArray's constructor directly, use " + "`chunked_array` function instead.") + + cdef void init(self, const shared_ptr[CChunkedArray]& chunked_array): + self.sp_chunked_array = chunked_array + self.chunked_array = chunked_array.get() + + def __reduce__(self): + return chunked_array, (self.chunks, self.type) + + @property + def data(self): + import warnings + warnings.warn("Calling .data on ChunkedArray is provided for " + "compatibility after Column was removed, simply drop " + "this attribute", FutureWarning) + return self + + @property + def type(self): + return pyarrow_wrap_data_type(self.sp_chunked_array.get().type()) + + def length(self): + return self.chunked_array.length() + + def __len__(self): + return self.length() + + def __repr__(self): + type_format = object.__repr__(self) + return '{0}\n{1}'.format(type_format, str(self)) + + def to_string(self, *, int indent=0, int window=10, + c_bool skip_new_lines=False): + """ + Render a "pretty-printed" string representation of the ChunkedArray + + Parameters + ---------- + indent : int + How much to indent right the content of the array, + by default ``0``. + window : int + How many items to preview at the begin and end + of the array when the arrays is bigger than the window. + The other elements will be ellipsed. + skip_new_lines : bool + If the array should be rendered as a single line of text + or if each element should be on its own line. + """ + cdef: + c_string result + PrettyPrintOptions options + + with nogil: + options = PrettyPrintOptions(indent, window) + options.skip_new_lines = skip_new_lines + check_status( + PrettyPrint( + deref(self.chunked_array), + options, + &result + ) + ) + + return frombytes(result, safe=True) + + def format(self, **kwargs): + import warnings + warnings.warn('ChunkedArray.format is deprecated, ' + 'use ChunkedArray.to_string') + return self.to_string(**kwargs) + + def __str__(self): + return self.to_string() + + def validate(self, *, full=False): + """ + Perform validation checks. An exception is raised if validation fails. + + By default only cheap validation checks are run. Pass `full=True` + for thorough validation checks (potentially O(n)). + + Parameters + ---------- + full: bool, default False + If True, run expensive checks, otherwise cheap checks only. + + Raises + ------ + ArrowInvalid + """ + if full: + with nogil: + check_status(self.sp_chunked_array.get().ValidateFull()) + else: + with nogil: + check_status(self.sp_chunked_array.get().Validate()) + + @property + def null_count(self): + """ + Number of null entries + + Returns + ------- + int + """ + return self.chunked_array.null_count() + + @property + def nbytes(self): + """ + Total number of bytes consumed by the elements of the chunked array. + """ + size = 0 + for chunk in self.iterchunks(): + size += chunk.nbytes + return size + + def __sizeof__(self): + return super(ChunkedArray, self).__sizeof__() + self.nbytes + + def __iter__(self): + for chunk in self.iterchunks(): + for item in chunk: + yield item + + def __getitem__(self, key): + """ + Slice or return value at given index + + Parameters + ---------- + key : integer or slice + Slices with step not equal to 1 (or None) will produce a copy + rather than a zero-copy view + + Returns + ------- + value : Scalar (index) or ChunkedArray (slice) + """ + if isinstance(key, slice): + return _normalize_slice(self, key) + + return self.getitem(_normalize_index(key, self.chunked_array.length())) + + cdef getitem(self, int64_t index): + cdef int j + + for j in range(self.num_chunks): + if index < self.chunked_array.chunk(j).get().length(): + return self.chunk(j)[index] + else: + index -= self.chunked_array.chunk(j).get().length() + + def is_null(self, *, nan_is_null=False): + """ + Return boolean array indicating the null values. + + Parameters + ---------- + nan_is_null : bool (optional, default False) + Whether floating-point NaN values should also be considered null. + + Returns + ------- + array : boolean Array or ChunkedArray + """ + options = _pc().NullOptions(nan_is_null=nan_is_null) + return _pc().call_function('is_null', [self], options) + + def is_valid(self): + """ + Return boolean array indicating the non-null values. + """ + return _pc().is_valid(self) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def fill_null(self, fill_value): + """ + See pyarrow.compute.fill_null docstring for usage. + """ + return _pc().fill_null(self, fill_value) + + def equals(self, ChunkedArray other): + """ + Return whether the contents of two chunked arrays are equal. + + Parameters + ---------- + other : pyarrow.ChunkedArray + Chunked array to compare against. + + Returns + ------- + are_equal : bool + """ + if other is None: + return False + + cdef: + CChunkedArray* this_arr = self.chunked_array + CChunkedArray* other_arr = other.chunked_array + c_bool result + + with nogil: + result = this_arr.Equals(deref(other_arr)) + + return result + + def _to_pandas(self, options, **kwargs): + return _array_like_to_pandas(self, options) + + def to_numpy(self): + """ + Return a NumPy copy of this array (experimental). + + Returns + ------- + array : numpy.ndarray + """ + cdef: + PyObject* out + PandasOptions c_options + object values + + if self.type.id == _Type_EXTENSION: + storage_array = chunked_array( + [chunk.storage for chunk in self.iterchunks()], + type=self.type.storage_type + ) + return storage_array.to_numpy() + + with nogil: + check_status( + ConvertChunkedArrayToPandas( + c_options, + self.sp_chunked_array, + self, + &out + ) + ) + + # wrap_array_output uses pandas to convert to Categorical, here + # always convert to numpy array + values = PyObject_to_object(out) + + if isinstance(values, dict): + values = np.take(values['dictionary'], values['indices']) + + return values + + def __array__(self, dtype=None): + values = self.to_numpy() + if dtype is None: + return values + return values.astype(dtype) + + def cast(self, object target_type, safe=True): + """ + Cast array values to another data type + + See pyarrow.compute.cast for usage + """ + return _pc().cast(self, target_type, safe=safe) + + def dictionary_encode(self, null_encoding='mask'): + """ + Compute dictionary-encoded representation of array + + Returns + ------- + pyarrow.ChunkedArray + Same chunking as the input, all chunks share a common dictionary. + """ + options = _pc().DictionaryEncodeOptions(null_encoding) + return _pc().call_function('dictionary_encode', [self], options) + + def flatten(self, MemoryPool memory_pool=None): + """ + Flatten this ChunkedArray. If it has a struct type, the column is + flattened into one array per struct field. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + result : List[ChunkedArray] + """ + cdef: + vector[shared_ptr[CChunkedArray]] flattened + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + flattened = GetResultValue(self.chunked_array.Flatten(pool)) + + return [pyarrow_wrap_chunked_array(col) for col in flattened] + + def combine_chunks(self, MemoryPool memory_pool=None): + """ + Flatten this ChunkedArray into a single non-chunked array. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + result : Array + """ + return concat_arrays(self.chunks) + + def unique(self): + """ + Compute distinct elements in array + + Returns + ------- + pyarrow.Array + """ + return _pc().call_function('unique', [self]) + + def value_counts(self): + """ + Compute counts of unique elements in array. + + Returns + ------- + An array of <input type "Values", int64_t "Counts"> structs + """ + return _pc().call_function('value_counts', [self]) + + def slice(self, offset=0, length=None): + """ + Compute zero-copy slice of this ChunkedArray + + Parameters + ---------- + offset : int, default 0 + Offset from start of array to slice + length : int, default None + Length of slice (default is until end of batch starting from + offset) + + Returns + ------- + sliced : ChunkedArray + """ + cdef shared_ptr[CChunkedArray] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + offset = min(len(self), offset) + if length is None: + result = self.chunked_array.Slice(offset) + else: + result = self.chunked_array.Slice(offset, length) + + return pyarrow_wrap_chunked_array(result) + + def filter(self, mask, object null_selection_behavior="drop"): + """ + Select values from a chunked array. See pyarrow.compute.filter for full + usage. + """ + return _pc().filter(self, mask, null_selection_behavior) + + def index(self, value, start=None, end=None, *, memory_pool=None): + """ + Find the first index of a value. + + See pyarrow.compute.index for full usage. + """ + return _pc().index(self, value, start, end, memory_pool=memory_pool) + + def take(self, object indices): + """ + Select values from a chunked array. See pyarrow.compute.take for full + usage. + """ + return _pc().take(self, indices) + + def drop_null(self): + """ + Remove missing values from a chunked array. + See pyarrow.compute.drop_null for full description. + """ + return _pc().drop_null(self) + + def unify_dictionaries(self, MemoryPool memory_pool=None): + """ + Unify dictionaries across all chunks. + + This method returns an equivalent chunked array, but where all + chunks share the same dictionary values. Dictionary indices are + transposed accordingly. + + If there are no dictionaries in the chunked array, it is returned + unchanged. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + result : ChunkedArray + """ + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + shared_ptr[CChunkedArray] c_result + + with nogil: + c_result = GetResultValue(CDictionaryUnifier.UnifyChunkedArray( + self.sp_chunked_array, pool)) + + return pyarrow_wrap_chunked_array(c_result) + + @property + def num_chunks(self): + """ + Number of underlying chunks + + Returns + ------- + int + """ + return self.chunked_array.num_chunks() + + def chunk(self, i): + """ + Select a chunk by its index + + Parameters + ---------- + i : int + + Returns + ------- + pyarrow.Array + """ + if i >= self.num_chunks or i < 0: + raise IndexError('Chunk index out of range.') + + return pyarrow_wrap_array(self.chunked_array.chunk(i)) + + @property + def chunks(self): + return list(self.iterchunks()) + + def iterchunks(self): + for i in range(self.num_chunks): + yield self.chunk(i) + + def to_pylist(self): + """ + Convert to a list of native Python objects. + """ + result = [] + for i in range(self.num_chunks): + result += self.chunk(i).to_pylist() + return result + + +def chunked_array(arrays, type=None): + """ + Construct chunked array from list of array-like objects + + Parameters + ---------- + arrays : Array, list of Array, or values coercible to arrays + Must all be the same data type. Can be empty only if type also passed. + type : DataType or string coercible to DataType + + Returns + ------- + ChunkedArray + """ + cdef: + Array arr + vector[shared_ptr[CArray]] c_arrays + shared_ptr[CChunkedArray] sp_chunked_array + + type = ensure_type(type, allow_none=True) + + if isinstance(arrays, Array): + arrays = [arrays] + + for x in arrays: + arr = x if isinstance(x, Array) else array(x, type=type) + + if type is None: + # it allows more flexible chunked array construction from to coerce + # subsequent arrays to the firstly inferred array type + # it also spares the inference overhead after the first chunk + type = arr.type + else: + if arr.type != type: + raise TypeError( + "All array chunks must have type {}".format(type) + ) + + c_arrays.push_back(arr.sp_array) + + if c_arrays.size() == 0 and type is None: + raise ValueError("When passing an empty collection of arrays " + "you must also pass the data type") + + sp_chunked_array.reset( + new CChunkedArray(c_arrays, pyarrow_unwrap_data_type(type)) + ) + with nogil: + check_status(sp_chunked_array.get().Validate()) + + return pyarrow_wrap_chunked_array(sp_chunked_array) + + +cdef _schema_from_arrays(arrays, names, metadata, shared_ptr[CSchema]* schema): + cdef: + Py_ssize_t K = len(arrays) + c_string c_name + shared_ptr[CDataType] c_type + shared_ptr[const CKeyValueMetadata] c_meta + vector[shared_ptr[CField]] c_fields + + if metadata is not None: + c_meta = KeyValueMetadata(metadata).unwrap() + + if K == 0: + if names is None or len(names) == 0: + schema.reset(new CSchema(c_fields, c_meta)) + return arrays + else: + raise ValueError('Length of names ({}) does not match ' + 'length of arrays ({})'.format(len(names), K)) + + c_fields.resize(K) + + if names is None: + raise ValueError('Must pass names or schema when constructing ' + 'Table or RecordBatch.') + + if len(names) != K: + raise ValueError('Length of names ({}) does not match ' + 'length of arrays ({})'.format(len(names), K)) + + converted_arrays = [] + for i in range(K): + val = arrays[i] + if not isinstance(val, (Array, ChunkedArray)): + val = array(val) + + c_type = (<DataType> val.type).sp_type + + if names[i] is None: + c_name = b'None' + else: + c_name = tobytes(names[i]) + c_fields[i].reset(new CField(c_name, c_type, True)) + converted_arrays.append(val) + + schema.reset(new CSchema(c_fields, c_meta)) + return converted_arrays + + +cdef _sanitize_arrays(arrays, names, schema, metadata, + shared_ptr[CSchema]* c_schema): + cdef Schema cy_schema + if schema is None: + converted_arrays = _schema_from_arrays(arrays, names, metadata, + c_schema) + else: + if names is not None: + raise ValueError('Cannot pass both schema and names') + if metadata is not None: + raise ValueError('Cannot pass both schema and metadata') + cy_schema = schema + + if len(schema) != len(arrays): + raise ValueError('Schema and number of arrays unequal') + + c_schema[0] = cy_schema.sp_schema + converted_arrays = [] + for i, item in enumerate(arrays): + item = asarray(item, type=schema[i].type) + converted_arrays.append(item) + return converted_arrays + + +cdef class RecordBatch(_PandasConvertible): + """ + Batch of rows of columns of equal length + + Warnings + -------- + Do not call this class's constructor directly, use one of the + ``RecordBatch.from_*`` functions instead. + """ + + def __cinit__(self): + self.batch = NULL + self._schema = None + + def __init__(self): + raise TypeError("Do not call RecordBatch's constructor directly, use " + "one of the `RecordBatch.from_*` functions instead.") + + cdef void init(self, const shared_ptr[CRecordBatch]& batch): + self.sp_batch = batch + self.batch = batch.get() + + @staticmethod + def from_pydict(mapping, schema=None, metadata=None): + """ + Construct a RecordBatch from Arrow arrays or columns. + + Parameters + ---------- + mapping : dict or Mapping + A mapping of strings to Arrays or Python lists. + schema : Schema, default None + If not passed, will be inferred from the Mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + RecordBatch + """ + + return _from_pydict(cls=RecordBatch, + mapping=mapping, + schema=schema, + metadata=metadata) + + def __reduce__(self): + return _reconstruct_record_batch, (self.columns, self.schema) + + def __len__(self): + return self.batch.num_rows() + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def to_string(self, show_metadata=False): + # Use less verbose schema output. + schema_as_string = self.schema.to_string( + show_field_metadata=show_metadata, + show_schema_metadata=show_metadata + ) + return 'pyarrow.{}\n{}'.format(type(self).__name__, schema_as_string) + + def __repr__(self): + return self.to_string() + + def validate(self, *, full=False): + """ + Perform validation checks. An exception is raised if validation fails. + + By default only cheap validation checks are run. Pass `full=True` + for thorough validation checks (potentially O(n)). + + Parameters + ---------- + full: bool, default False + If True, run expensive checks, otherwise cheap checks only. + + Raises + ------ + ArrowInvalid + """ + if full: + with nogil: + check_status(self.batch.ValidateFull()) + else: + with nogil: + check_status(self.batch.Validate()) + + def replace_schema_metadata(self, metadata=None): + """ + Create shallow copy of record batch by replacing schema + key-value metadata with the indicated new metadata (which may be None, + which deletes any existing metadata + + Parameters + ---------- + metadata : dict, default None + + Returns + ------- + shallow_copy : RecordBatch + """ + cdef: + shared_ptr[const CKeyValueMetadata] c_meta + shared_ptr[CRecordBatch] c_batch + + metadata = ensure_metadata(metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + with nogil: + c_batch = self.batch.ReplaceSchemaMetadata(c_meta) + + return pyarrow_wrap_batch(c_batch) + + @property + def num_columns(self): + """ + Number of columns + + Returns + ------- + int + """ + return self.batch.num_columns() + + @property + def num_rows(self): + """ + Number of rows + + Due to the definition of a RecordBatch, all columns have the same + number of rows. + + Returns + ------- + int + """ + return len(self) + + @property + def schema(self): + """ + Schema of the RecordBatch and its columns + + Returns + ------- + pyarrow.Schema + """ + if self._schema is None: + self._schema = pyarrow_wrap_schema(self.batch.schema()) + + return self._schema + + def field(self, i): + """ + Select a schema field by its column name or numeric index + + Parameters + ---------- + i : int or string + The index or name of the field to retrieve + + Returns + ------- + pyarrow.Field + """ + return self.schema.field(i) + + @property + def columns(self): + """ + List of all columns in numerical order + + Returns + ------- + list of pa.Array + """ + return [self.column(i) for i in range(self.num_columns)] + + def _ensure_integer_index(self, i): + """ + Ensure integer index (convert string column name to integer if needed). + """ + if isinstance(i, (bytes, str)): + field_indices = self.schema.get_all_field_indices(i) + + if len(field_indices) == 0: + raise KeyError( + "Field \"{}\" does not exist in record batch schema" + .format(i)) + elif len(field_indices) > 1: + raise KeyError( + "Field \"{}\" exists {} times in record batch schema" + .format(i, len(field_indices))) + else: + return field_indices[0] + elif isinstance(i, int): + return i + else: + raise TypeError("Index must either be string or integer") + + def column(self, i): + """ + Select single column from record batch + + Parameters + ---------- + i : int or string + The index or name of the column to retrieve. + + Returns + ------- + column : pyarrow.Array + """ + return self._column(self._ensure_integer_index(i)) + + def _column(self, int i): + """ + Select single column from record batch by its numeric index. + + Parameters + ---------- + i : int + The index of the column to retrieve. + + Returns + ------- + column : pyarrow.Array + """ + cdef int index = <int> _normalize_index(i, self.num_columns) + cdef Array result = pyarrow_wrap_array(self.batch.column(index)) + result._name = self.schema[index].name + return result + + @property + def nbytes(self): + """ + Total number of bytes consumed by the elements of the record batch. + """ + size = 0 + for i in range(self.num_columns): + size += self.column(i).nbytes + return size + + def __sizeof__(self): + return super(RecordBatch, self).__sizeof__() + self.nbytes + + def __getitem__(self, key): + """ + Slice or return column at given index or column name + + Parameters + ---------- + key : integer, str, or slice + Slices with step not equal to 1 (or None) will produce a copy + rather than a zero-copy view + + Returns + ------- + value : Array (index/column) or RecordBatch (slice) + """ + if isinstance(key, slice): + return _normalize_slice(self, key) + else: + return self.column(key) + + def serialize(self, memory_pool=None): + """ + Write RecordBatch to Buffer as encapsulated IPC message. + + Parameters + ---------- + memory_pool : MemoryPool, default None + Uses default memory pool if not specified + + Returns + ------- + serialized : Buffer + """ + cdef shared_ptr[CBuffer] buffer + cdef CIpcWriteOptions options = CIpcWriteOptions.Defaults() + options.memory_pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + buffer = GetResultValue( + SerializeRecordBatch(deref(self.batch), options)) + return pyarrow_wrap_buffer(buffer) + + def slice(self, offset=0, length=None): + """ + Compute zero-copy slice of this RecordBatch + + Parameters + ---------- + offset : int, default 0 + Offset from start of record batch to slice + length : int, default None + Length of slice (default is until end of batch starting from + offset) + + Returns + ------- + sliced : RecordBatch + """ + cdef shared_ptr[CRecordBatch] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + offset = min(len(self), offset) + if length is None: + result = self.batch.Slice(offset) + else: + result = self.batch.Slice(offset, length) + + return pyarrow_wrap_batch(result) + + def filter(self, mask, object null_selection_behavior="drop"): + """ + Select record from a record batch. See pyarrow.compute.filter for full + usage. + """ + return _pc().filter(self, mask, null_selection_behavior) + + def equals(self, object other, bint check_metadata=False): + """ + Check if contents of two record batches are equal. + + Parameters + ---------- + other : pyarrow.RecordBatch + RecordBatch to compare against. + check_metadata : bool, default False + Whether schema metadata equality should be checked as well. + + Returns + ------- + are_equal : bool + """ + cdef: + CRecordBatch* this_batch = self.batch + shared_ptr[CRecordBatch] other_batch = pyarrow_unwrap_batch(other) + c_bool result + + if not other_batch: + return False + + with nogil: + result = this_batch.Equals(deref(other_batch), check_metadata) + + return result + + def take(self, object indices): + """ + Select records from a RecordBatch. See pyarrow.compute.take for full + usage. + """ + return _pc().take(self, indices) + + def drop_null(self): + """ + Remove missing values from a RecordBatch. + See pyarrow.compute.drop_null for full usage. + """ + return _pc().drop_null(self) + + def to_pydict(self): + """ + Convert the RecordBatch to a dict or OrderedDict. + + Returns + ------- + dict + """ + entries = [] + for i in range(self.batch.num_columns()): + name = bytes(self.batch.column_name(i)).decode('utf8') + column = self[i].to_pylist() + entries.append((name, column)) + return ordered_dict(entries) + + def _to_pandas(self, options, **kwargs): + return Table.from_batches([self])._to_pandas(options, **kwargs) + + @classmethod + def from_pandas(cls, df, Schema schema=None, preserve_index=None, + nthreads=None, columns=None): + """ + Convert pandas.DataFrame to an Arrow RecordBatch + + Parameters + ---------- + df : pandas.DataFrame + schema : pyarrow.Schema, optional + The expected schema of the RecordBatch. This can be used to + indicate the type of columns if we cannot infer it automatically. + If passed, the output will have exactly this schema. Columns + specified in the schema that are not found in the DataFrame columns + or its index will raise an error. Additional columns or index + levels in the DataFrame which are not specified in the schema will + be ignored. + preserve_index : bool, optional + Whether to store the index as an additional column in the resulting + ``RecordBatch``. The default of None will store the index as a + column, except for RangeIndex which is stored as metadata only. Use + ``preserve_index=True`` to force it to be stored as a column. + nthreads : int, default None (may use up to system CPU count threads) + If greater than 1, convert columns to Arrow in parallel using + indicated number of threads + columns : list, optional + List of column to be converted. If None, use all columns. + + Returns + ------- + pyarrow.RecordBatch + """ + from pyarrow.pandas_compat import dataframe_to_arrays + arrays, schema = dataframe_to_arrays( + df, schema, preserve_index, nthreads=nthreads, columns=columns + ) + return cls.from_arrays(arrays, schema=schema) + + @staticmethod + def from_arrays(list arrays, names=None, schema=None, metadata=None): + """ + Construct a RecordBatch from multiple pyarrow.Arrays + + Parameters + ---------- + arrays : list of pyarrow.Array + One for each field in RecordBatch + names : list of str, optional + Names for the batch fields. If not passed, schema must be passed + schema : Schema, default None + Schema for the created batch. If not passed, names must be passed + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + pyarrow.RecordBatch + """ + cdef: + Array arr + shared_ptr[CSchema] c_schema + vector[shared_ptr[CArray]] c_arrays + int64_t num_rows + + if len(arrays) > 0: + num_rows = len(arrays[0]) + else: + num_rows = 0 + + if isinstance(names, Schema): + import warnings + warnings.warn("Schema passed to names= option, please " + "pass schema= explicitly. " + "Will raise exception in future", FutureWarning) + schema = names + names = None + + converted_arrays = _sanitize_arrays(arrays, names, schema, metadata, + &c_schema) + + c_arrays.reserve(len(arrays)) + for arr in converted_arrays: + if len(arr) != num_rows: + raise ValueError('Arrays were not all the same length: ' + '{0} vs {1}'.format(len(arr), num_rows)) + c_arrays.push_back(arr.sp_array) + + result = pyarrow_wrap_batch(CRecordBatch.Make(c_schema, num_rows, + c_arrays)) + result.validate() + return result + + @staticmethod + def from_struct_array(StructArray struct_array): + """ + Construct a RecordBatch from a StructArray. + + Each field in the StructArray will become a column in the resulting + ``RecordBatch``. + + Parameters + ---------- + struct_array : StructArray + Array to construct the record batch from. + + Returns + ------- + pyarrow.RecordBatch + """ + cdef: + shared_ptr[CRecordBatch] c_record_batch + with nogil: + c_record_batch = GetResultValue( + CRecordBatch.FromStructArray(struct_array.sp_array)) + return pyarrow_wrap_batch(c_record_batch) + + def _export_to_c(self, uintptr_t out_ptr, uintptr_t out_schema_ptr=0): + """ + Export to a C ArrowArray struct, given its pointer. + + If a C ArrowSchema struct pointer is also given, the record batch + schema is exported to it at the same time. + + Parameters + ---------- + out_ptr: int + The raw pointer to a C ArrowArray struct. + out_schema_ptr: int (optional) + The raw pointer to a C ArrowSchema struct. + + Be careful: if you don't pass the ArrowArray struct to a consumer, + array memory will leak. This is a low-level function intended for + expert users. + """ + with nogil: + check_status(ExportRecordBatch(deref(self.sp_batch), + <ArrowArray*> out_ptr, + <ArrowSchema*> out_schema_ptr)) + + @staticmethod + def _import_from_c(uintptr_t in_ptr, schema): + """ + Import RecordBatch from a C ArrowArray struct, given its pointer + and the imported schema. + + Parameters + ---------- + in_ptr: int + The raw pointer to a C ArrowArray struct. + type: Schema or int + Either a Schema object, or the raw pointer to a C ArrowSchema + struct. + + This is a low-level function intended for expert users. + """ + cdef: + shared_ptr[CRecordBatch] c_batch + + c_schema = pyarrow_unwrap_schema(schema) + if c_schema == nullptr: + # Not a Schema object, perhaps a raw ArrowSchema pointer + schema_ptr = <uintptr_t> schema + with nogil: + c_batch = GetResultValue(ImportRecordBatch( + <ArrowArray*> in_ptr, <ArrowSchema*> schema_ptr)) + else: + with nogil: + c_batch = GetResultValue(ImportRecordBatch( + <ArrowArray*> in_ptr, c_schema)) + return pyarrow_wrap_batch(c_batch) + + +def _reconstruct_record_batch(columns, schema): + """ + Internal: reconstruct RecordBatch from pickled components. + """ + return RecordBatch.from_arrays(columns, schema=schema) + + +def table_to_blocks(options, Table table, categories, extension_columns): + cdef: + PyObject* result_obj + shared_ptr[CTable] c_table + CMemoryPool* pool + PandasOptions c_options = _convert_pandas_options(options) + + if categories is not None: + c_options.categorical_columns = {tobytes(cat) for cat in categories} + if extension_columns is not None: + c_options.extension_columns = {tobytes(col) + for col in extension_columns} + + # ARROW-3789(wesm); Convert date/timestamp types to datetime64[ns] + c_options.coerce_temporal_nanoseconds = True + + if c_options.self_destruct: + # Move the shared_ptr, table is now unsafe to use further + c_table = move(table.sp_table) + table.table = NULL + else: + c_table = table.sp_table + + with nogil: + check_status( + libarrow.ConvertTableToPandas(c_options, move(c_table), + &result_obj) + ) + + return PyObject_to_object(result_obj) + + +cdef class Table(_PandasConvertible): + """ + A collection of top-level named, equal length Arrow arrays. + + Warning + ------- + Do not call this class's constructor directly, use one of the ``from_*`` + methods instead. + """ + + def __cinit__(self): + self.table = NULL + + def __init__(self): + raise TypeError("Do not call Table's constructor directly, use one of " + "the `Table.from_*` functions instead.") + + def to_string(self, *, show_metadata=False, preview_cols=0): + """ + Return human-readable string representation of Table. + + Parameters + ---------- + show_metadata : bool, default True + Display Field-level and Schema-level KeyValueMetadata. + preview_cols : int, default 0 + Display values of the columns for the first N columns. + + Returns + ------- + str + """ + # Use less verbose schema output. + schema_as_string = self.schema.to_string( + show_field_metadata=show_metadata, + show_schema_metadata=show_metadata + ) + title = 'pyarrow.{}\n{}'.format(type(self).__name__, schema_as_string) + pieces = [title] + if preview_cols: + pieces.append('----') + for i in range(min(self.num_columns, preview_cols)): + pieces.append('{}: {}'.format( + self.field(i).name, + self.column(i).to_string(indent=0, skip_new_lines=True) + )) + if preview_cols < self.num_columns: + pieces.append('...') + return '\n'.join(pieces) + + def __repr__(self): + if self.table == NULL: + raise ValueError("Table's internal pointer is NULL, do not use " + "any methods or attributes on this object") + return self.to_string(preview_cols=10) + + cdef void init(self, const shared_ptr[CTable]& table): + self.sp_table = table + self.table = table.get() + + def validate(self, *, full=False): + """ + Perform validation checks. An exception is raised if validation fails. + + By default only cheap validation checks are run. Pass `full=True` + for thorough validation checks (potentially O(n)). + + Parameters + ---------- + full: bool, default False + If True, run expensive checks, otherwise cheap checks only. + + Raises + ------ + ArrowInvalid + """ + if full: + with nogil: + check_status(self.table.ValidateFull()) + else: + with nogil: + check_status(self.table.Validate()) + + def __reduce__(self): + # Reduce the columns as ChunkedArrays to avoid serializing schema + # data twice + columns = [col for col in self.columns] + return _reconstruct_table, (columns, self.schema) + + def __getitem__(self, key): + """ + Slice or return column at given index or column name. + + Parameters + ---------- + key : integer, str, or slice + Slices with step not equal to 1 (or None) will produce a copy + rather than a zero-copy view. + + Returns + ------- + ChunkedArray (index/column) or Table (slice) + """ + if isinstance(key, slice): + return _normalize_slice(self, key) + else: + return self.column(key) + + def slice(self, offset=0, length=None): + """ + Compute zero-copy slice of this Table. + + Parameters + ---------- + offset : int, default 0 + Offset from start of table to slice. + length : int, default None + Length of slice (default is until end of table starting from + offset). + + Returns + ------- + Table + """ + cdef shared_ptr[CTable] result + + if offset < 0: + raise IndexError('Offset must be non-negative') + + offset = min(len(self), offset) + if length is None: + result = self.table.Slice(offset) + else: + result = self.table.Slice(offset, length) + + return pyarrow_wrap_table(result) + + def filter(self, mask, object null_selection_behavior="drop"): + """ + Select records from a Table. See :func:`pyarrow.compute.filter` for + full usage. + """ + return _pc().filter(self, mask, null_selection_behavior) + + def take(self, object indices): + """ + Select records from a Table. See :func:`pyarrow.compute.take` for full + usage. + """ + return _pc().take(self, indices) + + def drop_null(self): + """ + Remove missing values from a Table. + See :func:`pyarrow.compute.drop_null` for full usage. + """ + return _pc().drop_null(self) + + def select(self, object columns): + """ + Select columns of the Table. + + Returns a new Table with the specified columns, and metadata + preserved. + + Parameters + ---------- + columns : list-like + The column names or integer indices to select. + + Returns + ------- + Table + """ + cdef: + shared_ptr[CTable] c_table + vector[int] c_indices + + for idx in columns: + idx = self._ensure_integer_index(idx) + idx = _normalize_index(idx, self.num_columns) + c_indices.push_back(<int> idx) + + with nogil: + c_table = GetResultValue(self.table.SelectColumns(move(c_indices))) + + return pyarrow_wrap_table(c_table) + + def replace_schema_metadata(self, metadata=None): + """ + Create shallow copy of table by replacing schema + key-value metadata with the indicated new metadata (which may be None), + which deletes any existing metadata. + + Parameters + ---------- + metadata : dict, default None + + Returns + ------- + Table + """ + cdef: + shared_ptr[const CKeyValueMetadata] c_meta + shared_ptr[CTable] c_table + + metadata = ensure_metadata(metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + with nogil: + c_table = self.table.ReplaceSchemaMetadata(c_meta) + + return pyarrow_wrap_table(c_table) + + def flatten(self, MemoryPool memory_pool=None): + """ + Flatten this Table. + + Each column with a struct type is flattened + into one column per struct field. Other columns are left unchanged. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + Table + """ + cdef: + shared_ptr[CTable] flattened + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + flattened = GetResultValue(self.table.Flatten(pool)) + + return pyarrow_wrap_table(flattened) + + def combine_chunks(self, MemoryPool memory_pool=None): + """ + Make a new table by combining the chunks this table has. + + All the underlying chunks in the ChunkedArray of each column are + concatenated into zero or one chunk. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool. + + Returns + ------- + Table + """ + cdef: + shared_ptr[CTable] combined + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + combined = GetResultValue(self.table.CombineChunks(pool)) + + return pyarrow_wrap_table(combined) + + def unify_dictionaries(self, MemoryPool memory_pool=None): + """ + Unify dictionaries across all chunks. + + This method returns an equivalent table, but where all chunks of + each column share the same dictionary values. Dictionary indices + are transposed accordingly. + + Columns without dictionaries are returned unchanged. + + Parameters + ---------- + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool + + Returns + ------- + Table + """ + cdef: + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + shared_ptr[CTable] c_result + + with nogil: + c_result = GetResultValue(CDictionaryUnifier.UnifyTable( + deref(self.table), pool)) + + return pyarrow_wrap_table(c_result) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def equals(self, Table other, bint check_metadata=False): + """ + Check if contents of two tables are equal. + + Parameters + ---------- + other : pyarrow.Table + Table to compare against. + check_metadata : bool, default False + Whether schema metadata equality should be checked as well. + + Returns + ------- + bool + """ + if other is None: + return False + + cdef: + CTable* this_table = self.table + CTable* other_table = other.table + c_bool result + + with nogil: + result = this_table.Equals(deref(other_table), check_metadata) + + return result + + def cast(self, Schema target_schema, bint safe=True): + """ + Cast table values to another schema. + + Parameters + ---------- + target_schema : Schema + Schema to cast to, the names and order of fields must match. + safe : bool, default True + Check for overflows or other unsafe conversions. + + Returns + ------- + Table + """ + cdef: + ChunkedArray column, casted + Field field + list newcols = [] + + if self.schema.names != target_schema.names: + raise ValueError("Target schema's field names are not matching " + "the table's field names: {!r}, {!r}" + .format(self.schema.names, target_schema.names)) + + for column, field in zip(self.itercolumns(), target_schema): + casted = column.cast(field.type, safe=safe) + newcols.append(casted) + + return Table.from_arrays(newcols, schema=target_schema) + + @classmethod + def from_pandas(cls, df, Schema schema=None, preserve_index=None, + nthreads=None, columns=None, bint safe=True): + """ + Convert pandas.DataFrame to an Arrow Table. + + The column types in the resulting Arrow Table are inferred from the + dtypes of the pandas.Series in the DataFrame. In the case of non-object + Series, the NumPy dtype is translated to its Arrow equivalent. In the + case of `object`, we need to guess the datatype by looking at the + Python objects in this Series. + + Be aware that Series of the `object` dtype don't carry enough + information to always lead to a meaningful Arrow type. In the case that + we cannot infer a type, e.g. because the DataFrame is of length 0 or + the Series only contains None/nan objects, the type is set to + null. This behavior can be avoided by constructing an explicit schema + and passing it to this function. + + Parameters + ---------- + df : pandas.DataFrame + schema : pyarrow.Schema, optional + The expected schema of the Arrow Table. This can be used to + indicate the type of columns if we cannot infer it automatically. + If passed, the output will have exactly this schema. Columns + specified in the schema that are not found in the DataFrame columns + or its index will raise an error. Additional columns or index + levels in the DataFrame which are not specified in the schema will + be ignored. + preserve_index : bool, optional + Whether to store the index as an additional column in the resulting + ``Table``. The default of None will store the index as a column, + except for RangeIndex which is stored as metadata only. Use + ``preserve_index=True`` to force it to be stored as a column. + nthreads : int, default None (may use up to system CPU count threads) + If greater than 1, convert columns to Arrow in parallel using + indicated number of threads. + columns : list, optional + List of column to be converted. If None, use all columns. + safe : bool, default True + Check for overflows or other unsafe conversions. + + Returns + ------- + Table + + Examples + -------- + + >>> import pandas as pd + >>> import pyarrow as pa + >>> df = pd.DataFrame({ + ... 'int': [1, 2], + ... 'str': ['a', 'b'] + ... }) + >>> pa.Table.from_pandas(df) + <pyarrow.lib.Table object at 0x7f05d1fb1b40> + """ + from pyarrow.pandas_compat import dataframe_to_arrays + arrays, schema = dataframe_to_arrays( + df, + schema=schema, + preserve_index=preserve_index, + nthreads=nthreads, + columns=columns, + safe=safe + ) + return cls.from_arrays(arrays, schema=schema) + + @staticmethod + def from_arrays(arrays, names=None, schema=None, metadata=None): + """ + Construct a Table from Arrow arrays. + + Parameters + ---------- + arrays : list of pyarrow.Array or pyarrow.ChunkedArray + Equal-length arrays that should form the table. + names : list of str, optional + Names for the table columns. If not passed, schema must be passed. + schema : Schema, default None + Schema for the created table. If not passed, names must be passed. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table + """ + cdef: + vector[shared_ptr[CChunkedArray]] columns + shared_ptr[CSchema] c_schema + int i, K = <int> len(arrays) + + converted_arrays = _sanitize_arrays(arrays, names, schema, metadata, + &c_schema) + + columns.reserve(K) + for item in converted_arrays: + if isinstance(item, Array): + columns.push_back( + make_shared[CChunkedArray]( + (<Array> item).sp_array + ) + ) + elif isinstance(item, ChunkedArray): + columns.push_back((<ChunkedArray> item).sp_chunked_array) + else: + raise TypeError(type(item)) + + result = pyarrow_wrap_table(CTable.Make(c_schema, columns)) + result.validate() + return result + + @staticmethod + def from_pydict(mapping, schema=None, metadata=None): + """ + Construct a Table from Arrow arrays or columns. + + Parameters + ---------- + mapping : dict or Mapping + A mapping of strings to Arrays or Python lists. + schema : Schema, default None + If not passed, will be inferred from the Mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table + """ + + return _from_pydict(cls=Table, + mapping=mapping, + schema=schema, + metadata=metadata) + + @staticmethod + def from_batches(batches, Schema schema=None): + """ + Construct a Table from a sequence or iterator of Arrow RecordBatches. + + Parameters + ---------- + batches : sequence or iterator of RecordBatch + Sequence of RecordBatch to be converted, all schemas must be equal. + schema : Schema, default None + If not passed, will be inferred from the first RecordBatch. + + Returns + ------- + Table + """ + cdef: + vector[shared_ptr[CRecordBatch]] c_batches + shared_ptr[CTable] c_table + shared_ptr[CSchema] c_schema + RecordBatch batch + + for batch in batches: + c_batches.push_back(batch.sp_batch) + + if schema is None: + if c_batches.size() == 0: + raise ValueError('Must pass schema, or at least ' + 'one RecordBatch') + c_schema = c_batches[0].get().schema() + else: + c_schema = schema.sp_schema + + with nogil: + c_table = GetResultValue( + CTable.FromRecordBatches(c_schema, move(c_batches))) + + return pyarrow_wrap_table(c_table) + + def to_batches(self, max_chunksize=None, **kwargs): + """ + Convert Table to list of (contiguous) RecordBatch objects. + + Parameters + ---------- + max_chunksize : int, default None + Maximum size for RecordBatch chunks. Individual chunks may be + smaller depending on the chunk layout of individual columns. + + Returns + ------- + list of RecordBatch + """ + cdef: + unique_ptr[TableBatchReader] reader + int64_t c_max_chunksize + list result = [] + shared_ptr[CRecordBatch] batch + + reader.reset(new TableBatchReader(deref(self.table))) + + if 'chunksize' in kwargs: + max_chunksize = kwargs['chunksize'] + msg = ('The parameter chunksize is deprecated for ' + 'pyarrow.Table.to_batches as of 0.15, please use ' + 'the parameter max_chunksize instead') + warnings.warn(msg, FutureWarning) + + if max_chunksize is not None: + c_max_chunksize = max_chunksize + reader.get().set_chunksize(c_max_chunksize) + + while True: + with nogil: + check_status(reader.get().ReadNext(&batch)) + + if batch.get() == NULL: + break + + result.append(pyarrow_wrap_batch(batch)) + + return result + + def _to_pandas(self, options, categories=None, ignore_metadata=False, + types_mapper=None): + from pyarrow.pandas_compat import table_to_blockmanager + mgr = table_to_blockmanager( + options, self, categories, + ignore_metadata=ignore_metadata, + types_mapper=types_mapper) + return pandas_api.data_frame(mgr) + + def to_pydict(self): + """ + Convert the Table to a dict or OrderedDict. + + Returns + ------- + dict + """ + cdef: + size_t i + size_t num_columns = self.table.num_columns() + list entries = [] + ChunkedArray column + + for i in range(num_columns): + column = self.column(i) + entries.append((self.field(i).name, column.to_pylist())) + + return ordered_dict(entries) + + @property + def schema(self): + """ + Schema of the table and its columns. + + Returns + ------- + Schema + """ + return pyarrow_wrap_schema(self.table.schema()) + + def field(self, i): + """ + Select a schema field by its column name or numeric index. + + Parameters + ---------- + i : int or string + The index or name of the field to retrieve. + + Returns + ------- + Field + """ + return self.schema.field(i) + + def _ensure_integer_index(self, i): + """ + Ensure integer index (convert string column name to integer if needed). + """ + if isinstance(i, (bytes, str)): + field_indices = self.schema.get_all_field_indices(i) + + if len(field_indices) == 0: + raise KeyError("Field \"{}\" does not exist in table schema" + .format(i)) + elif len(field_indices) > 1: + raise KeyError("Field \"{}\" exists {} times in table schema" + .format(i, len(field_indices))) + else: + return field_indices[0] + elif isinstance(i, int): + return i + else: + raise TypeError("Index must either be string or integer") + + def column(self, i): + """ + Select a column by its column name, or numeric index. + + Parameters + ---------- + i : int or string + The index or name of the column to retrieve. + + Returns + ------- + ChunkedArray + """ + return self._column(self._ensure_integer_index(i)) + + def _column(self, int i): + """ + Select a column by its numeric index. + + Parameters + ---------- + i : int + The index of the column to retrieve. + + Returns + ------- + ChunkedArray + """ + cdef int index = <int> _normalize_index(i, self.num_columns) + cdef ChunkedArray result = pyarrow_wrap_chunked_array( + self.table.column(index)) + result._name = self.schema[index].name + return result + + def itercolumns(self): + """ + Iterator over all columns in their numerical order. + + Yields + ------ + ChunkedArray + """ + for i in range(self.num_columns): + yield self._column(i) + + @property + def columns(self): + """ + List of all columns in numerical order. + + Returns + ------- + list of ChunkedArray + """ + return [self._column(i) for i in range(self.num_columns)] + + @property + def num_columns(self): + """ + Number of columns in this table. + + Returns + ------- + int + """ + return self.table.num_columns() + + @property + def num_rows(self): + """ + Number of rows in this table. + + Due to the definition of a table, all columns have the same number of + rows. + + Returns + ------- + int + """ + return self.table.num_rows() + + def __len__(self): + return self.num_rows + + @property + def shape(self): + """ + Dimensions of the table: (#rows, #columns). + + Returns + ------- + (int, int) + Number of rows and number of columns. + """ + return (self.num_rows, self.num_columns) + + @property + def nbytes(self): + """ + Total number of bytes consumed by the elements of the table. + + Returns + ------- + int + """ + size = 0 + for column in self.itercolumns(): + size += column.nbytes + return size + + def __sizeof__(self): + return super(Table, self).__sizeof__() + self.nbytes + + def add_column(self, int i, field_, column): + """ + Add column to Table at position. + + A new table is returned with the column added, the original table + object is left unchanged. + + Parameters + ---------- + i : int + Index to place the column at. + field_ : str or Field + If a string is passed then the type is deduced from the column + data. + column : Array, list of Array, or values coercible to arrays + Column data. + + Returns + ------- + Table + New table with the passed column added. + """ + cdef: + shared_ptr[CTable] c_table + Field c_field + ChunkedArray c_arr + + if isinstance(column, ChunkedArray): + c_arr = column + else: + c_arr = chunked_array(column) + + if isinstance(field_, Field): + c_field = field_ + else: + c_field = field(field_, c_arr.type) + + with nogil: + c_table = GetResultValue(self.table.AddColumn( + i, c_field.sp_field, c_arr.sp_chunked_array)) + + return pyarrow_wrap_table(c_table) + + def append_column(self, field_, column): + """ + Append column at end of columns. + + Parameters + ---------- + field_ : str or Field + If a string is passed then the type is deduced from the column + data. + column : Array, list of Array, or values coercible to arrays + Column data. + + Returns + ------- + Table + New table with the passed column added. + """ + return self.add_column(self.num_columns, field_, column) + + def remove_column(self, int i): + """ + Create new Table with the indicated column removed. + + Parameters + ---------- + i : int + Index of column to remove. + + Returns + ------- + Table + New table without the column. + """ + cdef shared_ptr[CTable] c_table + + with nogil: + c_table = GetResultValue(self.table.RemoveColumn(i)) + + return pyarrow_wrap_table(c_table) + + def set_column(self, int i, field_, column): + """ + Replace column in Table at position. + + Parameters + ---------- + i : int + Index to place the column at. + field_ : str or Field + If a string is passed then the type is deduced from the column + data. + column : Array, list of Array, or values coercible to arrays + Column data. + + Returns + ------- + Table + New table with the passed column set. + """ + cdef: + shared_ptr[CTable] c_table + Field c_field + ChunkedArray c_arr + + if isinstance(column, ChunkedArray): + c_arr = column + else: + c_arr = chunked_array(column) + + if isinstance(field_, Field): + c_field = field_ + else: + c_field = field(field_, c_arr.type) + + with nogil: + c_table = GetResultValue(self.table.SetColumn( + i, c_field.sp_field, c_arr.sp_chunked_array)) + + return pyarrow_wrap_table(c_table) + + @property + def column_names(self): + """ + Names of the table's columns. + + Returns + ------- + list of str + """ + names = self.table.ColumnNames() + return [frombytes(name) for name in names] + + def rename_columns(self, names): + """ + Create new table with columns renamed to provided names. + + Parameters + ---------- + names : list of str + List of new column names. + + Returns + ------- + Table + """ + cdef: + shared_ptr[CTable] c_table + vector[c_string] c_names + + for name in names: + c_names.push_back(tobytes(name)) + + with nogil: + c_table = GetResultValue(self.table.RenameColumns(move(c_names))) + + return pyarrow_wrap_table(c_table) + + def drop(self, columns): + """ + Drop one or more columns and return a new table. + + Parameters + ---------- + columns : list of str + List of field names referencing existing columns. + + Raises + ------ + KeyError + If any of the passed columns name are not existing. + + Returns + ------- + Table + New table without the columns. + """ + indices = [] + for col in columns: + idx = self.schema.get_field_index(col) + if idx == -1: + raise KeyError("Column {!r} not found".format(col)) + indices.append(idx) + + indices.sort() + indices.reverse() + + table = self + for idx in indices: + table = table.remove_column(idx) + + return table + + +def _reconstruct_table(arrays, schema): + """ + Internal: reconstruct pa.Table from pickled components. + """ + return Table.from_arrays(arrays, schema=schema) + + +def record_batch(data, names=None, schema=None, metadata=None): + """ + Create a pyarrow.RecordBatch from another Python data structure or sequence + of arrays. + + Parameters + ---------- + data : pandas.DataFrame, list + A DataFrame or list of arrays or chunked arrays. + names : list, default None + Column names if list of arrays passed as data. Mutually exclusive with + 'schema' argument. + schema : Schema, default None + The expected schema of the RecordBatch. If not passed, will be inferred + from the data. Mutually exclusive with 'names' argument. + metadata : dict or Mapping, default None + Optional metadata for the schema (if schema not passed). + + Returns + ------- + RecordBatch + + See Also + -------- + RecordBatch.from_arrays, RecordBatch.from_pandas, table + """ + # accept schema as first argument for backwards compatibility / usability + if isinstance(names, Schema) and schema is None: + schema = names + names = None + + if isinstance(data, (list, tuple)): + return RecordBatch.from_arrays(data, names=names, schema=schema, + metadata=metadata) + elif _pandas_api.is_data_frame(data): + return RecordBatch.from_pandas(data, schema=schema) + else: + raise TypeError("Expected pandas DataFrame or list of arrays") + + +def table(data, names=None, schema=None, metadata=None, nthreads=None): + """ + Create a pyarrow.Table from a Python data structure or sequence of arrays. + + Parameters + ---------- + data : pandas.DataFrame, dict, list + A DataFrame, mapping of strings to Arrays or Python lists, or list of + arrays or chunked arrays. + names : list, default None + Column names if list of arrays passed as data. Mutually exclusive with + 'schema' argument. + schema : Schema, default None + The expected schema of the Arrow Table. If not passed, will be inferred + from the data. Mutually exclusive with 'names' argument. + If passed, the output will have exactly this schema (raising an error + when columns are not found in the data and ignoring additional data not + specified in the schema, when data is a dict or DataFrame). + metadata : dict or Mapping, default None + Optional metadata for the schema (if schema not passed). + nthreads : int, default None (may use up to system CPU count threads) + For pandas.DataFrame inputs: if greater than 1, convert columns to + Arrow in parallel using indicated number of threads. + + Returns + ------- + Table + + See Also + -------- + Table.from_arrays, Table.from_pandas, Table.from_pydict + """ + # accept schema as first argument for backwards compatibility / usability + if isinstance(names, Schema) and schema is None: + schema = names + names = None + + if isinstance(data, (list, tuple)): + return Table.from_arrays(data, names=names, schema=schema, + metadata=metadata) + elif isinstance(data, dict): + if names is not None: + raise ValueError( + "The 'names' argument is not valid when passing a dictionary") + return Table.from_pydict(data, schema=schema, metadata=metadata) + elif _pandas_api.is_data_frame(data): + if names is not None or metadata is not None: + raise ValueError( + "The 'names' and 'metadata' arguments are not valid when " + "passing a pandas DataFrame") + return Table.from_pandas(data, schema=schema, nthreads=nthreads) + else: + raise TypeError( + "Expected pandas DataFrame, python dictionary or list of arrays") + + +def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None): + """ + Concatenate pyarrow.Table objects. + + If promote==False, a zero-copy concatenation will be performed. The schemas + of all the Tables must be the same (except the metadata), otherwise an + exception will be raised. The result Table will share the metadata with the + first table. + + If promote==True, any null type arrays will be casted to the type of other + arrays in the column of the same name. If a table is missing a particular + field, null values of the appropriate type will be generated to take the + place of the missing field. The new schema will share the metadata with the + first table. Each field in the new schema will share the metadata with the + first table which has the field defined. Note that type promotions may + involve additional allocations on the given ``memory_pool``. + + Parameters + ---------- + tables : iterable of pyarrow.Table objects + Pyarrow tables to concatenate into a single Table. + promote : bool, default False + If True, concatenate tables with null-filling and null type promotion. + memory_pool : MemoryPool, default None + For memory allocations, if required, otherwise use default pool. + """ + cdef: + vector[shared_ptr[CTable]] c_tables + shared_ptr[CTable] c_result_table + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + Table table + CConcatenateTablesOptions options = ( + CConcatenateTablesOptions.Defaults()) + + for table in tables: + c_tables.push_back(table.sp_table) + + with nogil: + options.unify_schemas = promote + c_result_table = GetResultValue( + ConcatenateTables(c_tables, options, pool)) + + return pyarrow_wrap_table(c_result_table) + + +def _from_pydict(cls, mapping, schema, metadata): + """ + Construct a Table/RecordBatch from Arrow arrays or columns. + + Parameters + ---------- + cls : Class Table/RecordBatch + mapping : dict or Mapping + A mapping of strings to Arrays or Python lists. + schema : Schema, default None + If not passed, will be inferred from the Mapping values. + metadata : dict or Mapping, default None + Optional metadata for the schema (if inferred). + + Returns + ------- + Table/RecordBatch + """ + + arrays = [] + if schema is None: + names = [] + for k, v in mapping.items(): + names.append(k) + arrays.append(asarray(v)) + return cls.from_arrays(arrays, names, metadata=metadata) + elif isinstance(schema, Schema): + for field in schema: + try: + v = mapping[field.name] + except KeyError: + try: + v = mapping[tobytes(field.name)] + except KeyError: + present = mapping.keys() + missing = [n for n in schema.names if n not in present] + raise KeyError( + "The passed mapping doesn't contain the " + "following field(s) of the schema: {}". + format(', '.join(missing)) + ) + arrays.append(asarray(v, type=field.type)) + # Will raise if metadata is not None + return cls.from_arrays(arrays, schema=schema, metadata=metadata) + else: + raise TypeError('Schema must be an instance of pyarrow.Schema') diff --git a/src/arrow/python/pyarrow/tensor.pxi b/src/arrow/python/pyarrow/tensor.pxi new file mode 100644 index 000000000..42fd44741 --- /dev/null +++ b/src/arrow/python/pyarrow/tensor.pxi @@ -0,0 +1,1025 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +cdef class Tensor(_Weakrefable): + """ + A n-dimensional array a.k.a Tensor. + """ + + def __init__(self): + raise TypeError("Do not call Tensor's constructor directly, use one " + "of the `pyarrow.Tensor.from_*` functions instead.") + + cdef void init(self, const shared_ptr[CTensor]& sp_tensor): + self.sp_tensor = sp_tensor + self.tp = sp_tensor.get() + self.type = pyarrow_wrap_data_type(self.tp.type()) + + def __repr__(self): + return """<pyarrow.Tensor> +type: {0.type} +shape: {0.shape} +strides: {0.strides}""".format(self) + + @staticmethod + def from_numpy(obj, dim_names=None): + """ + Create a Tensor from a numpy array. + + Parameters + ---------- + obj : numpy.ndarray + The source numpy array + dim_names : list, optional + Names of each dimension of the Tensor. + """ + cdef: + vector[c_string] c_dim_names + shared_ptr[CTensor] ctensor + + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + check_status(NdarrayToTensor(c_default_memory_pool(), obj, + c_dim_names, &ctensor)) + return pyarrow_wrap_tensor(ctensor) + + def to_numpy(self): + """ + Convert arrow::Tensor to numpy.ndarray with zero copy + """ + cdef PyObject* out + + check_status(TensorToNdarray(self.sp_tensor, self, &out)) + return PyObject_to_object(out) + + def equals(self, Tensor other): + """ + Return true if the tensors contains exactly equal data + """ + return self.tp.Equals(deref(other.tp)) + + def __eq__(self, other): + if isinstance(other, Tensor): + return self.equals(other) + else: + return NotImplemented + + def dim_name(self, i): + return frombytes(self.tp.dim_name(i)) + + @property + def dim_names(self): + return [frombytes(x) for x in tuple(self.tp.dim_names())] + + @property + def is_mutable(self): + return self.tp.is_mutable() + + @property + def is_contiguous(self): + return self.tp.is_contiguous() + + @property + def ndim(self): + return self.tp.ndim() + + @property + def size(self): + return self.tp.size() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.tp.shape()) + + @property + def strides(self): + return tuple(self.tp.strides()) + + def __getbuffer__(self, cp.Py_buffer* buffer, int flags): + buffer.buf = <char *> self.tp.data().get().data() + pep3118_format = self.type.pep3118_format + if pep3118_format is None: + raise NotImplementedError("type %s not supported for buffer " + "protocol" % (self.type,)) + buffer.format = pep3118_format + buffer.itemsize = self.type.bit_width // 8 + buffer.internal = NULL + buffer.len = self.tp.size() * buffer.itemsize + buffer.ndim = self.tp.ndim() + buffer.obj = self + if self.tp.is_mutable(): + buffer.readonly = 0 + else: + buffer.readonly = 1 + # NOTE: This assumes Py_ssize_t == int64_t, and that the shape + # and strides arrays lifetime is tied to the tensor's + buffer.shape = <Py_ssize_t *> &self.tp.shape()[0] + buffer.strides = <Py_ssize_t *> &self.tp.strides()[0] + buffer.suboffsets = NULL + + +ctypedef CSparseCOOIndex* _CSparseCOOIndexPtr + + +cdef class SparseCOOTensor(_Weakrefable): + """ + A sparse COO tensor. + """ + + def __init__(self): + raise TypeError("Do not call SparseCOOTensor's constructor directly, " + "use one of the `pyarrow.SparseCOOTensor.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseCOOTensor]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """<pyarrow.SparseCOOTensor> +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseCOOTensor + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, coords, shape, dim_names=None): + """ + Create arrow::SparseCOOTensor from numpy.ndarrays + + Parameters + ---------- + data : numpy.ndarray + Data used to populate the rows. + coords : numpy.ndarray + Coordinates of the data. + shape : tuple + Shape of the tensor. + dim_names : list, optional + Names of the dimensions. + """ + cdef shared_ptr[CSparseCOOTensor] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for SparseCOOTensor indices + coords = np.require(coords, dtype='i8', requirements='C') + if coords.ndim != 2: + raise ValueError("Expected 2-dimensional array for " + "SparseCOOTensor indices") + + check_status(NdarraysToSparseCOOTensor(c_default_memory_pool(), + data, coords, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_coo_tensor(csparse_tensor) + + @staticmethod + def from_scipy(obj, dim_names=None): + """ + Convert scipy.sparse.coo_matrix to arrow::SparseCOOTensor + + Parameters + ---------- + obj : scipy.sparse.csr_matrix + The scipy matrix that should be converted. + dim_names : list, optional + Names of the dimensions. + """ + import scipy.sparse + if not isinstance(obj, scipy.sparse.coo_matrix): + raise TypeError( + "Expected scipy.sparse.coo_matrix, got {}".format(type(obj))) + + cdef shared_ptr[CSparseCOOTensor] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in obj.shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + row = obj.row + col = obj.col + + # When SciPy's coo_matrix has canonical format, its indices matrix is + # sorted in column-major order. As Arrow's SparseCOOIndex is sorted + # in row-major order if it is canonical, we must sort indices matrix + # into row-major order to keep its canonicalness, here. + if obj.has_canonical_format: + order = np.lexsort((col, row)) # sort in row-major order + row = row[order] + col = col[order] + coords = np.vstack([row, col]).T + coords = np.require(coords, dtype='i8', requirements='C') + + check_status(NdarraysToSparseCOOTensor(c_default_memory_pool(), + obj.data, coords, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_coo_tensor(csparse_tensor) + + @staticmethod + def from_pydata_sparse(obj, dim_names=None): + """ + Convert pydata/sparse.COO to arrow::SparseCOOTensor. + + Parameters + ---------- + obj : pydata.sparse.COO + The sparse multidimensional array that should be converted. + dim_names : list, optional + Names of the dimensions. + """ + import sparse + if not isinstance(obj, sparse.COO): + raise TypeError( + "Expected sparse.COO, got {}".format(type(obj))) + + cdef shared_ptr[CSparseCOOTensor] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in obj.shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + coords = np.require(obj.coords.T, dtype='i8', requirements='C') + + check_status(NdarraysToSparseCOOTensor(c_default_memory_pool(), + obj.data, coords, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_coo_tensor(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseCOOTensor. + + Parameters + ---------- + obj : Tensor + The tensor that should be converted. + """ + cdef shared_ptr[CSparseCOOTensor] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseCOOTensor(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_coo_tensor(csparse_tensor) + + def to_numpy(self): + """ + Convert arrow::SparseCOOTensor to numpy.ndarrays with zero copy. + """ + cdef PyObject* out_data + cdef PyObject* out_coords + + check_status(SparseCOOTensorToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_coords)) + return PyObject_to_object(out_data), PyObject_to_object(out_coords) + + def to_scipy(self): + """ + Convert arrow::SparseCOOTensor to scipy.sparse.coo_matrix. + """ + from scipy.sparse import coo_matrix + cdef PyObject* out_data + cdef PyObject* out_coords + + check_status(SparseCOOTensorToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_coords)) + data = PyObject_to_object(out_data) + coords = PyObject_to_object(out_coords) + row, col = coords[:, 0], coords[:, 1] + result = coo_matrix((data[:, 0], (row, col)), shape=self.shape) + + # As the description in from_scipy above, we sorted indices matrix + # in row-major order if SciPy's coo_matrix has canonical format. + # So, we must call sum_duplicates() to make the result coo_matrix + # has canonical format. + if self.has_canonical_format: + result.sum_duplicates() + return result + + def to_pydata_sparse(self): + """ + Convert arrow::SparseCOOTensor to pydata/sparse.COO. + """ + from sparse import COO + cdef PyObject* out_data + cdef PyObject* out_coords + + check_status(SparseCOOTensorToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_coords)) + data = PyObject_to_object(out_data) + coords = PyObject_to_object(out_coords) + result = COO(data=data[:, 0], coords=coords.T, shape=self.shape) + return result + + def to_tensor(self): + """ + Convert arrow::SparseCOOTensor to arrow::Tensor. + """ + + cdef shared_ptr[CTensor] ctensor + with nogil: + ctensor = GetResultValue(self.stp.ToTensor()) + + return pyarrow_wrap_tensor(ctensor) + + def equals(self, SparseCOOTensor other): + """ + Return true if sparse tensors contains exactly equal data. + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseCOOTensor): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + return tuple(frombytes(x) for x in tuple(self.stp.dim_names())) + + @property + def non_zero_length(self): + return self.stp.non_zero_length() + + @property + def has_canonical_format(self): + cdef: + _CSparseCOOIndexPtr csi + + csi = <_CSparseCOOIndexPtr>(self.stp.sparse_index().get()) + if csi != nullptr: + return csi.is_canonical() + return True + +cdef class SparseCSRMatrix(_Weakrefable): + """ + A sparse CSR matrix. + """ + + def __init__(self): + raise TypeError("Do not call SparseCSRMatrix's constructor directly, " + "use one of the `pyarrow.SparseCSRMatrix.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseCSRMatrix]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """<pyarrow.SparseCSRMatrix> +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseCSRMatrix + + Parameters + ---------- + obj : numpy.ndarray + The dense numpy array that should be converted. + dim_names : list, optional + The names of the dimensions. + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, indptr, indices, shape, dim_names=None): + """ + Create arrow::SparseCSRMatrix from numpy.ndarrays. + + Parameters + ---------- + data : numpy.ndarray + Data used to populate the sparse matrix. + indptr : numpy.ndarray + Range of the rows, + The i-th row spans from `indptr[i]` to `indptr[i+1]` in the data. + indices : numpy.ndarray + Column indices of the corresponding non-zero values. + shape : tuple + Shape of the matrix. + dim_names : list, optional + Names of the dimensions. + """ + cdef shared_ptr[CSparseCSRMatrix] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for SparseCSRMatrix indices + indptr = np.require(indptr, dtype='i8') + indices = np.require(indices, dtype='i8') + if indptr.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseCSRMatrix indptr") + if indices.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseCSRMatrix indices") + + check_status(NdarraysToSparseCSRMatrix(c_default_memory_pool(), + data, indptr, indices, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_csr_matrix(csparse_tensor) + + @staticmethod + def from_scipy(obj, dim_names=None): + """ + Convert scipy.sparse.csr_matrix to arrow::SparseCSRMatrix. + + Parameters + ---------- + obj : scipy.sparse.csr_matrix + The scipy matrix that should be converted. + dim_names : list, optional + Names of the dimensions. + """ + import scipy.sparse + if not isinstance(obj, scipy.sparse.csr_matrix): + raise TypeError( + "Expected scipy.sparse.csr_matrix, got {}".format(type(obj))) + + cdef shared_ptr[CSparseCSRMatrix] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in obj.shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for CSparseCSRMatrix indices + indptr = np.require(obj.indptr, dtype='i8') + indices = np.require(obj.indices, dtype='i8') + + check_status(NdarraysToSparseCSRMatrix(c_default_memory_pool(), + obj.data, indptr, indices, + c_shape, c_dim_names, + &csparse_tensor)) + return pyarrow_wrap_sparse_csr_matrix(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseCSRMatrix. + + Parameters + ---------- + obj : Tensor + The dense tensor that should be converted. + """ + cdef shared_ptr[CSparseCSRMatrix] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseCSRMatrix(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_csr_matrix(csparse_tensor) + + def to_numpy(self): + """ + Convert arrow::SparseCSRMatrix to numpy.ndarrays with zero copy. + """ + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSRMatrixToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + return (PyObject_to_object(out_data), PyObject_to_object(out_indptr), + PyObject_to_object(out_indices)) + + def to_scipy(self): + """ + Convert arrow::SparseCSRMatrix to scipy.sparse.csr_matrix. + """ + from scipy.sparse import csr_matrix + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSRMatrixToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + + data = PyObject_to_object(out_data) + indptr = PyObject_to_object(out_indptr) + indices = PyObject_to_object(out_indices) + result = csr_matrix((data[:, 0], indices, indptr), shape=self.shape) + return result + + def to_tensor(self): + """ + Convert arrow::SparseCSRMatrix to arrow::Tensor. + """ + cdef shared_ptr[CTensor] ctensor + with nogil: + ctensor = GetResultValue(self.stp.ToTensor()) + + return pyarrow_wrap_tensor(ctensor) + + def equals(self, SparseCSRMatrix other): + """ + Return true if sparse tensors contains exactly equal data. + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseCSRMatrix): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + return tuple(frombytes(x) for x in tuple(self.stp.dim_names())) + + @property + def non_zero_length(self): + return self.stp.non_zero_length() + +cdef class SparseCSCMatrix(_Weakrefable): + """ + A sparse CSC matrix. + """ + + def __init__(self): + raise TypeError("Do not call SparseCSCMatrix's constructor directly, " + "use one of the `pyarrow.SparseCSCMatrix.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseCSCMatrix]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """<pyarrow.SparseCSCMatrix> +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseCSCMatrix + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, indptr, indices, shape, dim_names=None): + """ + Create arrow::SparseCSCMatrix from numpy.ndarrays + + Parameters + ---------- + data : numpy.ndarray + Data used to populate the sparse matrix. + indptr : numpy.ndarray + Range of the rows, + The i-th row spans from `indptr[i]` to `indptr[i+1]` in the data. + indices : numpy.ndarray + Column indices of the corresponding non-zero values. + shape : tuple + Shape of the matrix. + dim_names : list, optional + Names of the dimensions. + """ + cdef shared_ptr[CSparseCSCMatrix] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for SparseCSCMatrix indices + indptr = np.require(indptr, dtype='i8') + indices = np.require(indices, dtype='i8') + if indptr.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseCSCMatrix indptr") + if indices.ndim != 1: + raise ValueError("Expected 1-dimensional array for " + "SparseCSCMatrix indices") + + check_status(NdarraysToSparseCSCMatrix(c_default_memory_pool(), + data, indptr, indices, c_shape, + c_dim_names, &csparse_tensor)) + return pyarrow_wrap_sparse_csc_matrix(csparse_tensor) + + @staticmethod + def from_scipy(obj, dim_names=None): + """ + Convert scipy.sparse.csc_matrix to arrow::SparseCSCMatrix + + Parameters + ---------- + obj : scipy.sparse.csc_matrix + The scipy matrix that should be converted. + dim_names : list, optional + Names of the dimensions. + """ + import scipy.sparse + if not isinstance(obj, scipy.sparse.csc_matrix): + raise TypeError( + "Expected scipy.sparse.csc_matrix, got {}".format(type(obj))) + + cdef shared_ptr[CSparseCSCMatrix] csparse_tensor + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in obj.shape: + c_shape.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce precondition for CSparseCSCMatrix indices + indptr = np.require(obj.indptr, dtype='i8') + indices = np.require(obj.indices, dtype='i8') + + check_status(NdarraysToSparseCSCMatrix(c_default_memory_pool(), + obj.data, indptr, indices, + c_shape, c_dim_names, + &csparse_tensor)) + return pyarrow_wrap_sparse_csc_matrix(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseCSCMatrix + + Parameters + ---------- + obj : Tensor + The dense tensor that should be converted. + """ + cdef shared_ptr[CSparseCSCMatrix] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseCSCMatrix(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_csc_matrix(csparse_tensor) + + def to_numpy(self): + """ + Convert arrow::SparseCSCMatrix to numpy.ndarrays with zero copy + """ + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSCMatrixToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + return (PyObject_to_object(out_data), PyObject_to_object(out_indptr), + PyObject_to_object(out_indices)) + + def to_scipy(self): + """ + Convert arrow::SparseCSCMatrix to scipy.sparse.csc_matrix + """ + from scipy.sparse import csc_matrix + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSCMatrixToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + + data = PyObject_to_object(out_data) + indptr = PyObject_to_object(out_indptr) + indices = PyObject_to_object(out_indices) + result = csc_matrix((data[:, 0], indices, indptr), shape=self.shape) + return result + + def to_tensor(self): + """ + Convert arrow::SparseCSCMatrix to arrow::Tensor + """ + + cdef shared_ptr[CTensor] ctensor + with nogil: + ctensor = GetResultValue(self.stp.ToTensor()) + + return pyarrow_wrap_tensor(ctensor) + + def equals(self, SparseCSCMatrix other): + """ + Return true if sparse tensors contains exactly equal data + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseCSCMatrix): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + return tuple(frombytes(x) for x in tuple(self.stp.dim_names())) + + @property + def non_zero_length(self): + return self.stp.non_zero_length() + + +cdef class SparseCSFTensor(_Weakrefable): + """ + A sparse CSF tensor. + + CSF is a generalization of compressed sparse row (CSR) index. + + CSF index recursively compresses each dimension of a tensor into a set + of prefix trees. Each path from a root to leaf forms one tensor + non-zero index. CSF is implemented with two arrays of buffers and one + arrays of integers. + """ + + def __init__(self): + raise TypeError("Do not call SparseCSFTensor's constructor directly, " + "use one of the `pyarrow.SparseCSFTensor.from_*` " + "functions instead.") + + cdef void init(self, const shared_ptr[CSparseCSFTensor]& sp_sparse_tensor): + self.sp_sparse_tensor = sp_sparse_tensor + self.stp = sp_sparse_tensor.get() + self.type = pyarrow_wrap_data_type(self.stp.type()) + + def __repr__(self): + return """<pyarrow.SparseCSFTensor> +type: {0.type} +shape: {0.shape}""".format(self) + + @classmethod + def from_dense_numpy(cls, obj, dim_names=None): + """ + Convert numpy.ndarray to arrow::SparseCSFTensor + """ + return cls.from_tensor(Tensor.from_numpy(obj, dim_names=dim_names)) + + @staticmethod + def from_numpy(data, indptr, indices, shape, axis_order=None, + dim_names=None): + """ + Create arrow::SparseCSFTensor from numpy.ndarrays + + Parameters + ---------- + data : numpy.ndarray + Data used to populate the sparse tensor. + indptr : numpy.ndarray + The sparsity structure. + Each two consecutive dimensions in a tensor correspond to + a buffer in indices. + A pair of consecutive values at `indptr[dim][i]` + `indptr[dim][i + 1]` signify a range of nodes in + `indices[dim + 1]` who are children of `indices[dim][i]` node. + indices : numpy.ndarray + Stores values of nodes. + Each tensor dimension corresponds to a buffer in indptr. + shape : tuple + Shape of the matrix. + axis_order : list, optional + the sequence in which dimensions were traversed to + produce the prefix tree. + dim_names : list, optional + Names of the dimensions. + """ + cdef shared_ptr[CSparseCSFTensor] csparse_tensor + cdef vector[int64_t] c_axis_order + cdef vector[int64_t] c_shape + cdef vector[c_string] c_dim_names + + for x in shape: + c_shape.push_back(x) + if not axis_order: + axis_order = np.argsort(shape) + for x in axis_order: + c_axis_order.push_back(x) + if dim_names is not None: + for x in dim_names: + c_dim_names.push_back(tobytes(x)) + + # Enforce preconditions for SparseCSFTensor indices + if not (isinstance(indptr, (list, tuple)) and + isinstance(indices, (list, tuple))): + raise TypeError("Expected list or tuple, got {}, {}" + .format(type(indptr), type(indices))) + if len(indptr) != len(shape) - 1: + raise ValueError("Expected list of {ndim} np.arrays for " + "SparseCSFTensor.indptr".format(ndim=len(shape))) + if len(indices) != len(shape): + raise ValueError("Expected list of {ndim} np.arrays for " + "SparseCSFTensor.indices".format(ndim=len(shape))) + if any([x.ndim != 1 for x in indptr]): + raise ValueError("Expected a list of 1-dimensional arrays for " + "SparseCSFTensor.indptr") + if any([x.ndim != 1 for x in indices]): + raise ValueError("Expected a list of 1-dimensional arrays for " + "SparseCSFTensor.indices") + indptr = [np.require(arr, dtype='i8') for arr in indptr] + indices = [np.require(arr, dtype='i8') for arr in indices] + + check_status(NdarraysToSparseCSFTensor(c_default_memory_pool(), data, + indptr, indices, c_shape, + c_axis_order, c_dim_names, + &csparse_tensor)) + return pyarrow_wrap_sparse_csf_tensor(csparse_tensor) + + @staticmethod + def from_tensor(obj): + """ + Convert arrow::Tensor to arrow::SparseCSFTensor + + Parameters + ---------- + obj : Tensor + The dense tensor that should be converted. + """ + cdef shared_ptr[CSparseCSFTensor] csparse_tensor + cdef shared_ptr[CTensor] ctensor = pyarrow_unwrap_tensor(obj) + + with nogil: + check_status(TensorToSparseCSFTensor(ctensor, &csparse_tensor)) + + return pyarrow_wrap_sparse_csf_tensor(csparse_tensor) + + def to_numpy(self): + """ + Convert arrow::SparseCSFTensor to numpy.ndarrays with zero copy + """ + cdef PyObject* out_data + cdef PyObject* out_indptr + cdef PyObject* out_indices + + check_status(SparseCSFTensorToNdarray(self.sp_sparse_tensor, self, + &out_data, &out_indptr, + &out_indices)) + return (PyObject_to_object(out_data), PyObject_to_object(out_indptr), + PyObject_to_object(out_indices)) + + def to_tensor(self): + """ + Convert arrow::SparseCSFTensor to arrow::Tensor + """ + + cdef shared_ptr[CTensor] ctensor + with nogil: + ctensor = GetResultValue(self.stp.ToTensor()) + + return pyarrow_wrap_tensor(ctensor) + + def equals(self, SparseCSFTensor other): + """ + Return true if sparse tensors contains exactly equal data + """ + return self.stp.Equals(deref(other.stp)) + + def __eq__(self, other): + if isinstance(other, SparseCSFTensor): + return self.equals(other) + else: + return NotImplemented + + @property + def is_mutable(self): + return self.stp.is_mutable() + + @property + def ndim(self): + return self.stp.ndim() + + @property + def shape(self): + # Cython knows how to convert a vector[T] to a Python list + return tuple(self.stp.shape()) + + @property + def size(self): + return self.stp.size() + + def dim_name(self, i): + return frombytes(self.stp.dim_name(i)) + + @property + def dim_names(self): + return tuple(frombytes(x) for x in tuple(self.stp.dim_names())) + + @property + def non_zero_length(self): + return self.stp.non_zero_length() diff --git a/src/arrow/python/pyarrow/tensorflow/plasma_op.cc b/src/arrow/python/pyarrow/tensorflow/plasma_op.cc new file mode 100644 index 000000000..bf4eec789 --- /dev/null +++ b/src/arrow/python/pyarrow/tensorflow/plasma_op.cc @@ -0,0 +1,391 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/stream.h" + +#ifdef GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/platform/stream_executor.h" +#endif + +#include "arrow/adapters/tensorflow/convert.h" +#include "arrow/api.h" +#include "arrow/io/api.h" +#include "arrow/util/logging.h" + +// These headers do not include Python.h +#include "arrow/python/deserialize.h" +#include "arrow/python/serialize.h" + +#include "plasma/client.h" + +namespace tf = tensorflow; + +using ArrowStatus = arrow::Status; +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +using Event = perftools::gputools::Event; +using Stream = perftools::gputools::Stream; + +// NOTE(zongheng): for some reason using unique_ptr or shared_ptr results in +// CUDA_ERROR_DEINITIALIZED on program exit. I suspect this is because the +// static object's dtor gets called *after* TensorFlow's own CUDA cleanup. +// Instead, we use a raw pointer here and manually clean up in the Ops' dtors. +static Stream* d2h_stream = nullptr; +static tf::mutex d2h_stream_mu; + +// TODO(zongheng): CPU kernels' std::memcpy might be able to be sped up by +// parallelization. + +int64_t get_byte_width(const arrow::DataType& dtype) { + return arrow::internal::checked_cast<const arrow::FixedWidthType&>(dtype) + .bit_width() / CHAR_BIT; +} + +// Put: tf.Tensor -> plasma. +template <typename Device> +class TensorToPlasmaOp : public tf::AsyncOpKernel { + public: + explicit TensorToPlasmaOp(tf::OpKernelConstruction* context) : tf::AsyncOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("plasma_store_socket_name", + &plasma_store_socket_name_)); + tf::mutex_lock lock(mu_); + if (!connected_) { + VLOG(1) << "Connecting to Plasma..."; + ARROW_CHECK_OK(client_.Connect(plasma_store_socket_name_)); + VLOG(1) << "Connected!"; + connected_ = true; + } + } + + ~TensorToPlasmaOp() override { + { + tf::mutex_lock lock(mu_); + ARROW_CHECK_OK(client_.Disconnect()); + connected_ = false; + } + { + tf::mutex_lock lock(d2h_stream_mu); + if (d2h_stream != nullptr) { + delete d2h_stream; + } + } + } + + void ComputeAsync(tf::OpKernelContext* context, DoneCallback done) override { + const int num_inputs = context->num_inputs(); + OP_REQUIRES_ASYNC( + context, num_inputs >= 2, + tf::errors::InvalidArgument("Input should have at least 1 tensor and 1 object_id"), + done); + const int num_tensors = num_inputs - 1; + + // Check that all tensors have the same dtype + tf::DataType tf_dtype = context->input(0).dtype(); + for (int i = 1; i < num_inputs - 1; i++) { + if (tf_dtype != context->input(i).dtype()) { + ARROW_CHECK_OK(arrow::Status(arrow::StatusCode::TypeError, + "All input tensors must have the same data type")); + } + } + + std::shared_ptr<arrow::DataType> arrow_dtype; + ARROW_CHECK_OK(arrow::adapters::tensorflow::GetArrowType(tf_dtype, &arrow_dtype)); + int64_t byte_width = get_byte_width(*arrow_dtype); + + std::vector<size_t> offsets; + offsets.reserve(num_tensors + 1); + offsets.push_back(0); + int64_t total_bytes = 0; + for (int i = 0; i < num_tensors; ++i) { + const size_t s = context->input(i).TotalBytes(); + CHECK_EQ(s, context->input(i).NumElements() * byte_width); + CHECK_GT(s, 0); + total_bytes += s; + offsets.push_back(total_bytes); + } + + const tf::Tensor& plasma_object_id = context->input(num_inputs - 1); + CHECK_EQ(plasma_object_id.NumElements(), 1); + const std::string& plasma_object_id_str = plasma_object_id.flat<std::string>()(0); + VLOG(1) << "plasma_object_id_str: '" << plasma_object_id_str << "'"; + const plasma::ObjectID object_id = + plasma::ObjectID::from_binary(plasma_object_id_str); + + std::vector<int64_t> shape = {total_bytes / byte_width}; + + arrow::io::MockOutputStream mock; + ARROW_CHECK_OK(arrow::py::WriteNdarrayHeader(arrow_dtype, shape, 0, &mock)); + int64_t header_size = mock.GetExtentBytesWritten(); + + std::shared_ptr<Buffer> data_buffer; + { + tf::mutex_lock lock(mu_); + ARROW_CHECK_OK(client_.Create(object_id, header_size + total_bytes, + /*metadata=*/nullptr, 0, &data_buffer)); + } + + int64_t offset; + arrow::io::FixedSizeBufferWriter buf(data_buffer); + ARROW_CHECK_OK(arrow::py::WriteNdarrayHeader(arrow_dtype, shape, total_bytes, &buf)); + ARROW_CHECK_OK(buf.Tell(&offset)); + + uint8_t* data = reinterpret_cast<uint8_t*>(data_buffer->mutable_data() + offset); + + auto wrapped_callback = [this, context, done, data_buffer, data, object_id]() { + { + tf::mutex_lock lock(mu_); + ARROW_CHECK_OK(client_.Seal(object_id)); + ARROW_CHECK_OK(client_.Release(object_id)); +#ifdef GOOGLE_CUDA + auto orig_stream = context->op_device_context()->stream(); + auto stream_executor = orig_stream->parent(); + CHECK(stream_executor->HostMemoryUnregister(static_cast<void*>(data))); +#endif + } + context->SetStatus(tensorflow::Status::OK()); + done(); + }; + + if (std::is_same<Device, CPUDevice>::value) { + for (int i = 0; i < num_tensors; ++i) { + const auto& input_tensor = context->input(i); + std::memcpy(static_cast<void*>(data + offsets[i]), + input_tensor.tensor_data().data(), + static_cast<tf::uint64>(offsets[i + 1] - offsets[i])); + } + wrapped_callback(); + } else { +#ifdef GOOGLE_CUDA + auto orig_stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC(context, orig_stream != nullptr, + tf::errors::Internal("No GPU stream available."), done); + auto stream_executor = orig_stream->parent(); + + // NOTE(zongheng): this is critical of getting good performance out of D2H + // async memcpy. Under the hood it performs cuMemHostRegister(), see: + // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gf0a9fe11544326dabd743b7aa6b54223 + CHECK(stream_executor->HostMemoryRegister(static_cast<void*>(data), + static_cast<tf::uint64>(total_bytes))); + + { + tf::mutex_lock l(d2h_stream_mu); + if (d2h_stream == nullptr) { + d2h_stream = new Stream(stream_executor); + CHECK(d2h_stream->Init().ok()); + } + } + + // Needed to make sure the input buffers have been computed. + // NOTE(ekl): this is unnecessary when the op is behind a NCCL allreduce already + CHECK(d2h_stream->ThenWaitFor(orig_stream).ok()); + + for (int i = 0; i < num_tensors; ++i) { + const auto& input_tensor = context->input(i); + auto input_buffer = const_cast<char*>(input_tensor.tensor_data().data()); + perftools::gputools::DeviceMemoryBase wrapped_src( + static_cast<void*>(input_buffer)); + const bool success = + d2h_stream + ->ThenMemcpy(static_cast<void*>(data + offsets[i]), wrapped_src, + static_cast<tf::uint64>(offsets[i + 1] - offsets[i])) + .ok(); + OP_REQUIRES_ASYNC(context, success, + tf::errors::Internal("D2H memcpy failed to be enqueued."), done); + } + context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + d2h_stream, std::move(wrapped_callback)); +#endif + } + } + + private: + std::string plasma_store_socket_name_; + + tf::mutex mu_; + bool connected_ = false; + plasma::PlasmaClient client_ GUARDED_BY(mu_); +}; + +static Stream* h2d_stream = nullptr; +static tf::mutex h2d_stream_mu; + +// Get: plasma -> tf.Tensor. +template <typename Device> +class PlasmaToTensorOp : public tf::AsyncOpKernel { + public: + explicit PlasmaToTensorOp(tf::OpKernelConstruction* context) : tf::AsyncOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("plasma_store_socket_name", + &plasma_store_socket_name_)); + tf::mutex_lock lock(mu_); + if (!connected_) { + VLOG(1) << "Connecting to Plasma..."; + ARROW_CHECK_OK(client_.Connect(plasma_store_socket_name_)); + VLOG(1) << "Connected!"; + connected_ = true; + } + } + + ~PlasmaToTensorOp() override { + { + tf::mutex_lock lock(mu_); + ARROW_CHECK_OK(client_.Disconnect()); + connected_ = false; + } + { + tf::mutex_lock lock(h2d_stream_mu); + if (h2d_stream != nullptr) { + delete h2d_stream; + } + } + } + + void ComputeAsync(tf::OpKernelContext* context, DoneCallback done) override { + const tf::Tensor& plasma_object_id = context->input(0); + CHECK_EQ(plasma_object_id.NumElements(), 1); + const std::string& plasma_object_id_str = plasma_object_id.flat<std::string>()(0); + + VLOG(1) << "plasma_object_id_str: '" << plasma_object_id_str << "'"; + const plasma::ObjectID object_id = + plasma::ObjectID::from_binary(plasma_object_id_str); + + plasma::ObjectBuffer object_buffer; + { + tf::mutex_lock lock(mu_); + // NOTE(zongheng): this is a blocking call. We might want to (1) make + // Plasma asynchronous, (2) launch a thread / event here ourselves, or + // something like that... + ARROW_CHECK_OK(client_.Get(&object_id, /*num_objects=*/1, + /*timeout_ms=*/-1, &object_buffer)); + } + + std::shared_ptr<arrow::Tensor> ndarray; + ARROW_CHECK_OK(arrow::py::NdarrayFromBuffer(object_buffer.data, &ndarray)); + + int64_t byte_width = get_byte_width(*ndarray->type()); + const int64_t size_in_bytes = ndarray->data()->size(); + + tf::TensorShape shape({static_cast<int64_t>(size_in_bytes / byte_width)}); + + const float* plasma_data = reinterpret_cast<const float*>(ndarray->raw_data()); + + tf::Tensor* output_tensor = nullptr; + OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, shape, &output_tensor), + done); + + auto wrapped_callback = [this, context, done, plasma_data, object_id]() { + { + tf::mutex_lock lock(mu_); + ARROW_CHECK_OK(client_.Release(object_id)); +#ifdef GOOGLE_CUDA + auto orig_stream = context->op_device_context()->stream(); + auto stream_executor = orig_stream->parent(); + CHECK(stream_executor->HostMemoryUnregister( + const_cast<void*>(static_cast<const void*>(plasma_data)))); +#endif + } + done(); + }; + + if (std::is_same<Device, CPUDevice>::value) { + std::memcpy( + reinterpret_cast<void*>(const_cast<char*>(output_tensor->tensor_data().data())), + plasma_data, size_in_bytes); + wrapped_callback(); + } else { +#ifdef GOOGLE_CUDA + auto orig_stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC(context, orig_stream != nullptr, + tf::errors::Internal("No GPU stream available."), done); + auto stream_executor = orig_stream->parent(); + + { + tf::mutex_lock l(h2d_stream_mu); + if (h2d_stream == nullptr) { + h2d_stream = new Stream(stream_executor); + CHECK(h2d_stream->Init().ok()); + } + } + + // Important. See note in T2P op. + CHECK(stream_executor->HostMemoryRegister( + const_cast<void*>(static_cast<const void*>(plasma_data)), + static_cast<tf::uint64>(size_in_bytes))); + + perftools::gputools::DeviceMemoryBase wrapped_dst( + reinterpret_cast<void*>(const_cast<char*>(output_tensor->tensor_data().data()))); + const bool success = + h2d_stream + ->ThenMemcpy(&wrapped_dst, static_cast<const void*>(plasma_data), + static_cast<tf::uint64>(size_in_bytes)) + .ok(); + OP_REQUIRES_ASYNC(context, success, + tf::errors::Internal("H2D memcpy failed to be enqueued."), done); + + // Without this sync the main compute stream might proceed to use the + // Tensor buffer, but its contents might still be in-flight from our + // h2d_stream. + CHECK(orig_stream->ThenWaitFor(h2d_stream).ok()); + + context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + h2d_stream, std::move(wrapped_callback)); +#endif + } + } + + private: + std::string plasma_store_socket_name_; + + tf::mutex mu_; + bool connected_ = false; + plasma::PlasmaClient client_ GUARDED_BY(mu_); +}; + +REGISTER_OP("TensorToPlasma") + .Input("input_tensor: dtypes") + .Input("plasma_object_id: string") + .Attr("dtypes: list(type)") + .Attr("plasma_store_socket_name: string"); + +REGISTER_KERNEL_BUILDER(Name("TensorToPlasma").Device(tf::DEVICE_CPU), + TensorToPlasmaOp<CPUDevice>); +#ifdef GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("TensorToPlasma").Device(tf::DEVICE_GPU), + TensorToPlasmaOp<GPUDevice>); +#endif + +REGISTER_OP("PlasmaToTensor") + .Input("plasma_object_id: string") + .Output("tensor: dtype") + .Attr("dtype: type") + .Attr("plasma_store_socket_name: string"); + +REGISTER_KERNEL_BUILDER(Name("PlasmaToTensor").Device(tf::DEVICE_CPU), + PlasmaToTensorOp<CPUDevice>); +#ifdef GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("PlasmaToTensor").Device(tf::DEVICE_GPU), + PlasmaToTensorOp<GPUDevice>); +#endif diff --git a/src/arrow/python/pyarrow/tests/__init__.py b/src/arrow/python/pyarrow/tests/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/src/arrow/python/pyarrow/tests/__init__.py diff --git a/src/arrow/python/pyarrow/tests/arrow_7980.py b/src/arrow/python/pyarrow/tests/arrow_7980.py new file mode 100644 index 000000000..c1bc3176d --- /dev/null +++ b/src/arrow/python/pyarrow/tests/arrow_7980.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is called from a test in test_schema.py. + +import pyarrow as pa + + +# the types where to_pandas_dtype returns a non-numpy dtype +cases = [ + (pa.timestamp('ns', tz='UTC'), "datetime64[ns, UTC]"), +] + + +for arrow_type, pandas_type in cases: + assert str(arrow_type.to_pandas_dtype()) == pandas_type diff --git a/src/arrow/python/pyarrow/tests/bound_function_visit_strings.pyx b/src/arrow/python/pyarrow/tests/bound_function_visit_strings.pyx new file mode 100644 index 000000000..90437be8c --- /dev/null +++ b/src/arrow/python/pyarrow/tests/bound_function_visit_strings.pyx @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language=c++ +# cython: language_level = 3 + +import pyarrow as pa +from pyarrow.lib cimport * +from pyarrow.lib import frombytes, tobytes + +# basic test to roundtrip through a BoundFunction + +ctypedef CStatus visit_string_cb(const c_string&) + +cdef extern from * namespace "arrow::py" nogil: + """ + #include <functional> + #include <string> + #include <vector> + + #include "arrow/status.h" + + namespace arrow { + namespace py { + + Status VisitStrings(const std::vector<std::string>& strs, + std::function<Status(const std::string&)> cb) { + for (const std::string& str : strs) { + RETURN_NOT_OK(cb(str)); + } + return Status::OK(); + } + + } // namespace py + } // namespace arrow + """ + cdef CStatus CVisitStrings" arrow::py::VisitStrings"( + vector[c_string], function[visit_string_cb]) + + +cdef void _visit_strings_impl(py_cb, const c_string& s) except *: + py_cb(frombytes(s)) + + +def _visit_strings(strings, cb): + cdef: + function[visit_string_cb] c_cb + vector[c_string] c_strings + + c_cb = BindFunction[visit_string_cb](&_visit_strings_impl, cb) + for s in strings: + c_strings.push_back(tobytes(s)) + + check_status(CVisitStrings(c_strings, c_cb)) diff --git a/src/arrow/python/pyarrow/tests/conftest.py b/src/arrow/python/pyarrow/tests/conftest.py new file mode 100644 index 000000000..8fa520b93 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/conftest.py @@ -0,0 +1,302 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import pathlib +import subprocess +from tempfile import TemporaryDirectory + +import pytest +import hypothesis as h + +from pyarrow.util import find_free_port +from pyarrow import Codec + + +# setup hypothesis profiles +h.settings.register_profile('ci', max_examples=1000) +h.settings.register_profile('dev', max_examples=50) +h.settings.register_profile('debug', max_examples=10, + verbosity=h.Verbosity.verbose) + +# load default hypothesis profile, either set HYPOTHESIS_PROFILE environment +# variable or pass --hypothesis-profile option to pytest, to see the generated +# examples try: +# pytest pyarrow -sv --enable-hypothesis --hypothesis-profile=debug +h.settings.load_profile(os.environ.get('HYPOTHESIS_PROFILE', 'dev')) + +# Set this at the beginning before the AWS SDK was loaded to avoid reading in +# user configuration values. +os.environ['AWS_CONFIG_FILE'] = "/dev/null" + + +groups = [ + 'brotli', + 'bz2', + 'cython', + 'dataset', + 'hypothesis', + 'fastparquet', + 'gandiva', + 'gzip', + 'hdfs', + 'large_memory', + 'lz4', + 'memory_leak', + 'nopandas', + 'orc', + 'pandas', + 'parquet', + 'plasma', + 's3', + 'snappy', + 'tensorflow', + 'flight', + 'slow', + 'requires_testing_data', + 'zstd', +] + +defaults = { + 'brotli': Codec.is_available('brotli'), + 'bz2': Codec.is_available('bz2'), + 'cython': False, + 'dataset': False, + 'fastparquet': False, + 'hypothesis': False, + 'gandiva': False, + 'gzip': Codec.is_available('gzip'), + 'hdfs': False, + 'large_memory': False, + 'lz4': Codec.is_available('lz4'), + 'memory_leak': False, + 'orc': False, + 'nopandas': False, + 'pandas': False, + 'parquet': False, + 'plasma': False, + 's3': False, + 'snappy': Codec.is_available('snappy'), + 'tensorflow': False, + 'flight': False, + 'slow': False, + 'requires_testing_data': True, + 'zstd': Codec.is_available('zstd'), +} + +try: + import cython # noqa + defaults['cython'] = True +except ImportError: + pass + +try: + import fastparquet # noqa + defaults['fastparquet'] = True +except ImportError: + pass + +try: + import pyarrow.gandiva # noqa + defaults['gandiva'] = True +except ImportError: + pass + +try: + import pyarrow.dataset # noqa + defaults['dataset'] = True +except ImportError: + pass + +try: + import pyarrow.orc # noqa + defaults['orc'] = True +except ImportError: + pass + +try: + import pandas # noqa + defaults['pandas'] = True +except ImportError: + defaults['nopandas'] = True + +try: + import pyarrow.parquet # noqa + defaults['parquet'] = True +except ImportError: + pass + +try: + import pyarrow.plasma # noqa + defaults['plasma'] = True +except ImportError: + pass + +try: + import tensorflow # noqa + defaults['tensorflow'] = True +except ImportError: + pass + +try: + import pyarrow.flight # noqa + defaults['flight'] = True +except ImportError: + pass + +try: + from pyarrow.fs import S3FileSystem # noqa + defaults['s3'] = True +except ImportError: + pass + +try: + from pyarrow.fs import HadoopFileSystem # noqa + defaults['hdfs'] = True +except ImportError: + pass + + +def pytest_addoption(parser): + # Create options to selectively enable test groups + def bool_env(name, default=None): + value = os.environ.get(name.upper()) + if value is None: + return default + value = value.lower() + if value in {'1', 'true', 'on', 'yes', 'y'}: + return True + elif value in {'0', 'false', 'off', 'no', 'n'}: + return False + else: + raise ValueError('{}={} is not parsable as boolean' + .format(name.upper(), value)) + + for group in groups: + default = bool_env('PYARROW_TEST_{}'.format(group), defaults[group]) + parser.addoption('--enable-{}'.format(group), + action='store_true', default=default, + help=('Enable the {} test group'.format(group))) + parser.addoption('--disable-{}'.format(group), + action='store_true', default=False, + help=('Disable the {} test group'.format(group))) + + +class PyArrowConfig: + def __init__(self): + self.is_enabled = {} + + def apply_mark(self, mark): + group = mark.name + if group in groups: + self.requires(group) + + def requires(self, group): + if not self.is_enabled[group]: + pytest.skip('{} NOT enabled'.format(group)) + + +def pytest_configure(config): + # Apply command-line options to initialize PyArrow-specific config object + config.pyarrow = PyArrowConfig() + + for mark in groups: + config.addinivalue_line( + "markers", mark, + ) + + enable_flag = '--enable-{}'.format(mark) + disable_flag = '--disable-{}'.format(mark) + + is_enabled = (config.getoption(enable_flag) and not + config.getoption(disable_flag)) + config.pyarrow.is_enabled[mark] = is_enabled + + +def pytest_runtest_setup(item): + # Apply test markers to skip tests selectively + for mark in item.iter_markers(): + item.config.pyarrow.apply_mark(mark) + + +@pytest.fixture +def tempdir(tmpdir): + # convert pytest's LocalPath to pathlib.Path + return pathlib.Path(tmpdir.strpath) + + +@pytest.fixture(scope='session') +def base_datadir(): + return pathlib.Path(__file__).parent / 'data' + + +@pytest.fixture(autouse=True) +def disable_aws_metadata(monkeypatch): + """Stop the AWS SDK from trying to contact the EC2 metadata server. + + Otherwise, this causes a 5 second delay in tests that exercise the + S3 filesystem. + """ + monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true") + + +# TODO(kszucs): move the following fixtures to test_fs.py once the previous +# parquet dataset implementation and hdfs implementation are removed. + +@pytest.fixture(scope='session') +def hdfs_connection(): + host = os.environ.get('ARROW_HDFS_TEST_HOST', 'default') + port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0)) + user = os.environ.get('ARROW_HDFS_TEST_USER', 'hdfs') + return host, port, user + + +@pytest.fixture(scope='session') +def s3_connection(): + host, port = 'localhost', find_free_port() + access_key, secret_key = 'arrow', 'apachearrow' + return host, port, access_key, secret_key + + +@pytest.fixture(scope='session') +def s3_server(s3_connection): + host, port, access_key, secret_key = s3_connection + + address = '{}:{}'.format(host, port) + env = os.environ.copy() + env.update({ + 'MINIO_ACCESS_KEY': access_key, + 'MINIO_SECRET_KEY': secret_key + }) + + with TemporaryDirectory() as tempdir: + args = ['minio', '--compat', 'server', '--quiet', '--address', + address, tempdir] + proc = None + try: + proc = subprocess.Popen(args, env=env) + except OSError: + pytest.skip('`minio` command cannot be located') + else: + yield { + 'connection': s3_connection, + 'process': proc, + 'tempdir': tempdir + } + finally: + if proc is not None: + proc.kill() diff --git a/src/arrow/python/pyarrow/tests/data/feather/v0.17.0.version=2-compression=lz4.feather b/src/arrow/python/pyarrow/tests/data/feather/v0.17.0.version=2-compression=lz4.feather Binary files differnew file mode 100644 index 000000000..562b0b2c5 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/feather/v0.17.0.version=2-compression=lz4.feather diff --git a/src/arrow/python/pyarrow/tests/data/orc/README.md b/src/arrow/python/pyarrow/tests/data/orc/README.md new file mode 100644 index 000000000..ccbb0e8b1 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/orc/README.md @@ -0,0 +1,22 @@ +<!--- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +The ORC and JSON files come from the `examples` directory in the Apache ORC +source tree: +https://github.com/apache/orc/tree/master/examples diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.jsn.gz b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.jsn.gz Binary files differnew file mode 100644 index 000000000..91c85cd76 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.jsn.gz diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.orc b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.orc Binary files differnew file mode 100644 index 000000000..ecdadcbff --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.emptyFile.orc diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.jsn.gz b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.jsn.gz Binary files differnew file mode 100644 index 000000000..5eab19a41 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.jsn.gz diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.orc b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.orc Binary files differnew file mode 100644 index 000000000..4fb0beff8 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.test1.orc diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.jsn.gz b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.jsn.gz Binary files differnew file mode 100644 index 000000000..62dbaba42 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.jsn.gz diff --git a/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.orc b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.orc Binary files differnew file mode 100644 index 000000000..f51ffdbd0 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/orc/TestOrcFile.testDate1900.orc diff --git a/src/arrow/python/pyarrow/tests/data/orc/decimal.jsn.gz b/src/arrow/python/pyarrow/tests/data/orc/decimal.jsn.gz Binary files differnew file mode 100644 index 000000000..e634bd70b --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/orc/decimal.jsn.gz diff --git a/src/arrow/python/pyarrow/tests/data/orc/decimal.orc b/src/arrow/python/pyarrow/tests/data/orc/decimal.orc Binary files differnew file mode 100644 index 000000000..cb0f7b9d7 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/orc/decimal.orc diff --git a/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.all-named-index.parquet b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.all-named-index.parquet Binary files differnew file mode 100644 index 000000000..e9efd9b39 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.all-named-index.parquet diff --git a/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.column-metadata-handling.parquet b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.column-metadata-handling.parquet Binary files differnew file mode 100644 index 000000000..d48041f51 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.column-metadata-handling.parquet diff --git a/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.parquet b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.parquet Binary files differnew file mode 100644 index 000000000..44670bcd1 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.parquet diff --git a/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.some-named-index.parquet b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.some-named-index.parquet Binary files differnew file mode 100644 index 000000000..34097ca12 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/data/parquet/v0.7.1.some-named-index.parquet diff --git a/src/arrow/python/pyarrow/tests/deserialize_buffer.py b/src/arrow/python/pyarrow/tests/deserialize_buffer.py new file mode 100644 index 000000000..982dc6695 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/deserialize_buffer.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is called from a test in test_serialization.py. + +import sys + +import pyarrow as pa + +with open(sys.argv[1], 'rb') as f: + data = f.read() + pa.deserialize(data) diff --git a/src/arrow/python/pyarrow/tests/pandas_examples.py b/src/arrow/python/pyarrow/tests/pandas_examples.py new file mode 100644 index 000000000..466c14eeb --- /dev/null +++ b/src/arrow/python/pyarrow/tests/pandas_examples.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import OrderedDict +from datetime import date, time + +import numpy as np +import pandas as pd +import pyarrow as pa + + +def dataframe_with_arrays(include_index=False): + """ + Dataframe with numpy arrays columns of every possible primitive type. + + Returns + ------- + df: pandas.DataFrame + schema: pyarrow.Schema + Arrow schema definition that is in line with the constructed df. + """ + dtypes = [('i1', pa.int8()), ('i2', pa.int16()), + ('i4', pa.int32()), ('i8', pa.int64()), + ('u1', pa.uint8()), ('u2', pa.uint16()), + ('u4', pa.uint32()), ('u8', pa.uint64()), + ('f4', pa.float32()), ('f8', pa.float64())] + + arrays = OrderedDict() + fields = [] + for dtype, arrow_dtype in dtypes: + fields.append(pa.field(dtype, pa.list_(arrow_dtype))) + arrays[dtype] = [ + np.arange(10, dtype=dtype), + np.arange(5, dtype=dtype), + None, + np.arange(1, dtype=dtype) + ] + + fields.append(pa.field('str', pa.list_(pa.string()))) + arrays['str'] = [ + np.array(["1", "ä"], dtype="object"), + None, + np.array(["1"], dtype="object"), + np.array(["1", "2", "3"], dtype="object") + ] + + fields.append(pa.field('datetime64', pa.list_(pa.timestamp('ms')))) + arrays['datetime64'] = [ + np.array(['2007-07-13T01:23:34.123456789', + None, + '2010-08-13T05:46:57.437699912'], + dtype='datetime64[ms]'), + None, + None, + np.array(['2007-07-13T02', + None, + '2010-08-13T05:46:57.437699912'], + dtype='datetime64[ms]'), + ] + + if include_index: + fields.append(pa.field('__index_level_0__', pa.int64())) + df = pd.DataFrame(arrays) + schema = pa.schema(fields) + + return df, schema + + +def dataframe_with_lists(include_index=False, parquet_compatible=False): + """ + Dataframe with list columns of every possible primitive type. + + Returns + ------- + df: pandas.DataFrame + schema: pyarrow.Schema + Arrow schema definition that is in line with the constructed df. + parquet_compatible: bool + Exclude types not supported by parquet + """ + arrays = OrderedDict() + fields = [] + + fields.append(pa.field('int64', pa.list_(pa.int64()))) + arrays['int64'] = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [0, 1, 2, 3, 4], + None, + [], + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9] * 2, + dtype=np.int64)[::2] + ] + fields.append(pa.field('double', pa.list_(pa.float64()))) + arrays['double'] = [ + [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], + [0., 1., 2., 3., 4.], + None, + [], + np.array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.] * 2)[::2], + ] + fields.append(pa.field('bytes_list', pa.list_(pa.binary()))) + arrays['bytes_list'] = [ + [b"1", b"f"], + None, + [b"1"], + [b"1", b"2", b"3"], + [], + ] + fields.append(pa.field('str_list', pa.list_(pa.string()))) + arrays['str_list'] = [ + ["1", "ä"], + None, + ["1"], + ["1", "2", "3"], + [], + ] + + date_data = [ + [], + [date(2018, 1, 1), date(2032, 12, 30)], + [date(2000, 6, 7)], + None, + [date(1969, 6, 9), date(1972, 7, 3)] + ] + time_data = [ + [time(23, 11, 11), time(1, 2, 3), time(23, 59, 59)], + [], + [time(22, 5, 59)], + None, + [time(0, 0, 0), time(18, 0, 2), time(12, 7, 3)] + ] + + temporal_pairs = [ + (pa.date32(), date_data), + (pa.date64(), date_data), + (pa.time32('s'), time_data), + (pa.time32('ms'), time_data), + (pa.time64('us'), time_data) + ] + if not parquet_compatible: + temporal_pairs += [ + (pa.time64('ns'), time_data), + ] + + for value_type, data in temporal_pairs: + field_name = '{}_list'.format(value_type) + field_type = pa.list_(value_type) + field = pa.field(field_name, field_type) + fields.append(field) + arrays[field_name] = data + + if include_index: + fields.append(pa.field('__index_level_0__', pa.int64())) + + df = pd.DataFrame(arrays) + schema = pa.schema(fields) + + return df, schema diff --git a/src/arrow/python/pyarrow/tests/pandas_threaded_import.py b/src/arrow/python/pyarrow/tests/pandas_threaded_import.py new file mode 100644 index 000000000..f44632d74 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/pandas_threaded_import.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is called from a test in test_pandas.py. + +from concurrent.futures import ThreadPoolExecutor +import faulthandler +import sys + +import pyarrow as pa + +num_threads = 60 +timeout = 10 # seconds + + +def thread_func(i): + pa.array([i]).to_pandas() + + +def main(): + # In case of import deadlock, crash after a finite timeout + faulthandler.dump_traceback_later(timeout, exit=True) + with ThreadPoolExecutor(num_threads) as pool: + assert "pandas" not in sys.modules # pandas is imported lazily + list(pool.map(thread_func, range(num_threads))) + assert "pandas" in sys.modules + + +if __name__ == "__main__": + main() diff --git a/src/arrow/python/pyarrow/tests/parquet/common.py b/src/arrow/python/pyarrow/tests/parquet/common.py new file mode 100644 index 000000000..90bfb55d1 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/common.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import io + +import numpy as np +import pytest + +import pyarrow as pa +from pyarrow.tests import util + +parametrize_legacy_dataset = pytest.mark.parametrize( + "use_legacy_dataset", + [True, pytest.param(False, marks=pytest.mark.dataset)]) +parametrize_legacy_dataset_not_supported = pytest.mark.parametrize( + "use_legacy_dataset", [True, pytest.param(False, marks=pytest.mark.skip)]) +parametrize_legacy_dataset_fixed = pytest.mark.parametrize( + "use_legacy_dataset", [pytest.param(True, marks=pytest.mark.xfail), + pytest.param(False, marks=pytest.mark.dataset)]) + +# Marks all of the tests in this module +# Ignore these with pytest ... -m 'not parquet' +pytestmark = pytest.mark.parquet + + +def _write_table(table, path, **kwargs): + # So we see the ImportError somewhere + import pyarrow.parquet as pq + from pyarrow.pandas_compat import _pandas_api + + if _pandas_api.is_data_frame(table): + table = pa.Table.from_pandas(table) + + pq.write_table(table, path, **kwargs) + return table + + +def _read_table(*args, **kwargs): + import pyarrow.parquet as pq + + table = pq.read_table(*args, **kwargs) + table.validate(full=True) + return table + + +def _roundtrip_table(table, read_table_kwargs=None, + write_table_kwargs=None, use_legacy_dataset=True): + read_table_kwargs = read_table_kwargs or {} + write_table_kwargs = write_table_kwargs or {} + + writer = pa.BufferOutputStream() + _write_table(table, writer, **write_table_kwargs) + reader = pa.BufferReader(writer.getvalue()) + return _read_table(reader, use_legacy_dataset=use_legacy_dataset, + **read_table_kwargs) + + +def _check_roundtrip(table, expected=None, read_table_kwargs=None, + use_legacy_dataset=True, **write_table_kwargs): + if expected is None: + expected = table + + read_table_kwargs = read_table_kwargs or {} + + # intentionally check twice + result = _roundtrip_table(table, read_table_kwargs=read_table_kwargs, + write_table_kwargs=write_table_kwargs, + use_legacy_dataset=use_legacy_dataset) + assert result.equals(expected) + result = _roundtrip_table(result, read_table_kwargs=read_table_kwargs, + write_table_kwargs=write_table_kwargs, + use_legacy_dataset=use_legacy_dataset) + assert result.equals(expected) + + +def _roundtrip_pandas_dataframe(df, write_kwargs, use_legacy_dataset=True): + table = pa.Table.from_pandas(df) + result = _roundtrip_table( + table, write_table_kwargs=write_kwargs, + use_legacy_dataset=use_legacy_dataset) + return result.to_pandas() + + +def _random_integers(size, dtype): + # We do not generate integers outside the int64 range + platform_int_info = np.iinfo('int_') + iinfo = np.iinfo(dtype) + return np.random.randint(max(iinfo.min, platform_int_info.min), + min(iinfo.max, platform_int_info.max), + size=size).astype(dtype) + + +def _test_dataframe(size=10000, seed=0): + import pandas as pd + + np.random.seed(seed) + df = pd.DataFrame({ + 'uint8': _random_integers(size, np.uint8), + 'uint16': _random_integers(size, np.uint16), + 'uint32': _random_integers(size, np.uint32), + 'uint64': _random_integers(size, np.uint64), + 'int8': _random_integers(size, np.int8), + 'int16': _random_integers(size, np.int16), + 'int32': _random_integers(size, np.int32), + 'int64': _random_integers(size, np.int64), + 'float32': np.random.randn(size).astype(np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0, + 'strings': [util.rands(10) for i in range(size)], + 'all_none': [None] * size, + 'all_none_category': [None] * size + }) + + # TODO(PARQUET-1015) + # df['all_none_category'] = df['all_none_category'].astype('category') + return df + + +def make_sample_file(table_or_df): + import pyarrow.parquet as pq + + if isinstance(table_or_df, pa.Table): + a_table = table_or_df + else: + a_table = pa.Table.from_pandas(table_or_df) + + buf = io.BytesIO() + _write_table(a_table, buf, compression='SNAPPY', version='2.6', + coerce_timestamps='ms') + + buf.seek(0) + return pq.ParquetFile(buf) + + +def alltypes_sample(size=10000, seed=0, categorical=False): + import pandas as pd + + np.random.seed(seed) + arrays = { + 'uint8': np.arange(size, dtype=np.uint8), + 'uint16': np.arange(size, dtype=np.uint16), + 'uint32': np.arange(size, dtype=np.uint32), + 'uint64': np.arange(size, dtype=np.uint64), + 'int8': np.arange(size, dtype=np.int16), + 'int16': np.arange(size, dtype=np.int16), + 'int32': np.arange(size, dtype=np.int32), + 'int64': np.arange(size, dtype=np.int64), + 'float32': np.arange(size, dtype=np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0, + # TODO(wesm): Test other timestamp resolutions now that arrow supports + # them + 'datetime': np.arange("2016-01-01T00:00:00.001", size, + dtype='datetime64[ms]'), + 'str': pd.Series([str(x) for x in range(size)]), + 'empty_str': [''] * size, + 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None], + 'null': [None] * size, + 'null_list': [None] * 2 + [[None] * (x % 4) for x in range(size - 2)], + } + if categorical: + arrays['str_category'] = arrays['str'].astype('category') + return pd.DataFrame(arrays) diff --git a/src/arrow/python/pyarrow/tests/parquet/conftest.py b/src/arrow/python/pyarrow/tests/parquet/conftest.py new file mode 100644 index 000000000..1e75493cd --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/conftest.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from pyarrow.util import guid + + +@pytest.fixture(scope='module') +def datadir(base_datadir): + return base_datadir / 'parquet' + + +@pytest.fixture +def s3_bucket(s3_server): + boto3 = pytest.importorskip('boto3') + botocore = pytest.importorskip('botocore') + + host, port, access_key, secret_key = s3_server['connection'] + s3 = boto3.resource( + 's3', + endpoint_url='http://{}:{}'.format(host, port), + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + config=botocore.client.Config(signature_version='s3v4'), + region_name='us-east-1' + ) + bucket = s3.Bucket('test-s3fs') + try: + bucket.create() + except Exception: + # we get BucketAlreadyOwnedByYou error with fsspec handler + pass + return 'test-s3fs' + + +@pytest.fixture +def s3_example_s3fs(s3_server, s3_bucket): + s3fs = pytest.importorskip('s3fs') + + host, port, access_key, secret_key = s3_server['connection'] + fs = s3fs.S3FileSystem( + key=access_key, + secret=secret_key, + client_kwargs={ + 'endpoint_url': 'http://{}:{}'.format(host, port) + } + ) + + test_path = '{}/{}'.format(s3_bucket, guid()) + + fs.mkdir(test_path) + yield fs, test_path + try: + fs.rm(test_path, recursive=True) + except FileNotFoundError: + pass + + +@pytest.fixture +def s3_example_fs(s3_server): + from pyarrow.fs import FileSystem + + host, port, access_key, secret_key = s3_server['connection'] + uri = ( + "s3://{}:{}@mybucket/data.parquet?scheme=http&endpoint_override={}:{}" + .format(access_key, secret_key, host, port) + ) + fs, path = FileSystem.from_uri(uri) + + fs.create_dir("mybucket") + + yield fs, uri, path diff --git a/src/arrow/python/pyarrow/tests/parquet/test_basic.py b/src/arrow/python/pyarrow/tests/parquet/test_basic.py new file mode 100644 index 000000000..cf1aaa21f --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/test_basic.py @@ -0,0 +1,631 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import OrderedDict +import io + +import numpy as np +import pytest + +import pyarrow as pa +from pyarrow import fs +from pyarrow.filesystem import LocalFileSystem, FileSystem +from pyarrow.tests import util +from pyarrow.tests.parquet.common import (_check_roundtrip, _roundtrip_table, + parametrize_legacy_dataset) + +try: + import pyarrow.parquet as pq + from pyarrow.tests.parquet.common import _read_table, _write_table +except ImportError: + pq = None + + +try: + import pandas as pd + import pandas.testing as tm + + from pyarrow.tests.pandas_examples import dataframe_with_lists + from pyarrow.tests.parquet.common import alltypes_sample +except ImportError: + pd = tm = None + + +pytestmark = pytest.mark.parquet + + +def test_parquet_invalid_version(tempdir): + table = pa.table({'a': [1, 2, 3]}) + with pytest.raises(ValueError, match="Unsupported Parquet format version"): + _write_table(table, tempdir / 'test_version.parquet', version="2.2") + with pytest.raises(ValueError, match="Unsupported Parquet data page " + + "version"): + _write_table(table, tempdir / 'test_version.parquet', + data_page_version="2.2") + + +@parametrize_legacy_dataset +def test_set_data_page_size(use_legacy_dataset): + arr = pa.array([1, 2, 3] * 100000) + t = pa.Table.from_arrays([arr], names=['f0']) + + # 128K, 512K + page_sizes = [2 << 16, 2 << 18] + for target_page_size in page_sizes: + _check_roundtrip(t, data_page_size=target_page_size, + use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_chunked_table_write(use_legacy_dataset): + # ARROW-232 + tables = [] + batch = pa.RecordBatch.from_pandas(alltypes_sample(size=10)) + tables.append(pa.Table.from_batches([batch] * 3)) + df, _ = dataframe_with_lists() + batch = pa.RecordBatch.from_pandas(df) + tables.append(pa.Table.from_batches([batch] * 3)) + + for data_page_version in ['1.0', '2.0']: + for use_dictionary in [True, False]: + for table in tables: + _check_roundtrip( + table, version='2.6', + use_legacy_dataset=use_legacy_dataset, + data_page_version=data_page_version, + use_dictionary=use_dictionary) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_memory_map(tempdir, use_legacy_dataset): + df = alltypes_sample(size=10) + + table = pa.Table.from_pandas(df) + _check_roundtrip(table, read_table_kwargs={'memory_map': True}, + version='2.6', use_legacy_dataset=use_legacy_dataset) + + filename = str(tempdir / 'tmp_file') + with open(filename, 'wb') as f: + _write_table(table, f, version='2.6') + table_read = pq.read_pandas(filename, memory_map=True, + use_legacy_dataset=use_legacy_dataset) + assert table_read.equals(table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_enable_buffered_stream(tempdir, use_legacy_dataset): + df = alltypes_sample(size=10) + + table = pa.Table.from_pandas(df) + _check_roundtrip(table, read_table_kwargs={'buffer_size': 1025}, + version='2.6', use_legacy_dataset=use_legacy_dataset) + + filename = str(tempdir / 'tmp_file') + with open(filename, 'wb') as f: + _write_table(table, f, version='2.6') + table_read = pq.read_pandas(filename, buffer_size=4096, + use_legacy_dataset=use_legacy_dataset) + assert table_read.equals(table) + + +@parametrize_legacy_dataset +def test_special_chars_filename(tempdir, use_legacy_dataset): + table = pa.Table.from_arrays([pa.array([42])], ["ints"]) + filename = "foo # bar" + path = tempdir / filename + assert not path.exists() + _write_table(table, str(path)) + assert path.exists() + table_read = _read_table(str(path), use_legacy_dataset=use_legacy_dataset) + assert table_read.equals(table) + + +@parametrize_legacy_dataset +def test_invalid_source(use_legacy_dataset): + # Test that we provide an helpful error message pointing out + # that None wasn't expected when trying to open a Parquet None file. + # + # Depending on use_legacy_dataset the message changes slightly + # but in both cases it should point out that None wasn't expected. + with pytest.raises(TypeError, match="None"): + pq.read_table(None, use_legacy_dataset=use_legacy_dataset) + + with pytest.raises(TypeError, match="None"): + pq.ParquetFile(None) + + +@pytest.mark.slow +def test_file_with_over_int16_max_row_groups(): + # PARQUET-1857: Parquet encryption support introduced a INT16_MAX upper + # limit on the number of row groups, but this limit only impacts files with + # encrypted row group metadata because of the int16 row group ordinal used + # in the Parquet Thrift metadata. Unencrypted files are not impacted, so + # this test checks that it works (even if it isn't a good idea) + t = pa.table([list(range(40000))], names=['f0']) + _check_roundtrip(t, row_group_size=1) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_empty_table_roundtrip(use_legacy_dataset): + df = alltypes_sample(size=10) + + # Create a non-empty table to infer the types correctly, then slice to 0 + table = pa.Table.from_pandas(df) + table = pa.Table.from_arrays( + [col.chunk(0)[:0] for col in table.itercolumns()], + names=table.schema.names) + + assert table.schema.field('null').type == pa.null() + assert table.schema.field('null_list').type == pa.list_(pa.null()) + _check_roundtrip( + table, version='2.6', use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_empty_table_no_columns(use_legacy_dataset): + df = pd.DataFrame() + empty = pa.Table.from_pandas(df, preserve_index=False) + _check_roundtrip(empty, use_legacy_dataset=use_legacy_dataset) + + +@parametrize_legacy_dataset +def test_write_nested_zero_length_array_chunk_failure(use_legacy_dataset): + # Bug report in ARROW-3792 + cols = OrderedDict( + int32=pa.int32(), + list_string=pa.list_(pa.string()) + ) + data = [[], [OrderedDict(int32=1, list_string=('G',)), ]] + + # This produces a table with a column like + # <Column name='list_string' type=ListType(list<item: string>)> + # [ + # [], + # [ + # [ + # "G" + # ] + # ] + # ] + # + # Each column is a ChunkedArray with 2 elements + my_arrays = [pa.array(batch, type=pa.struct(cols)).flatten() + for batch in data] + my_batches = [pa.RecordBatch.from_arrays(batch, schema=pa.schema(cols)) + for batch in my_arrays] + tbl = pa.Table.from_batches(my_batches, pa.schema(cols)) + _check_roundtrip(tbl, use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_multiple_path_types(tempdir, use_legacy_dataset): + # Test compatibility with PEP 519 path-like objects + path = tempdir / 'zzz.parquet' + df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)}) + _write_table(df, path) + table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + # Test compatibility with plain string paths + path = str(tempdir) + 'zzz.parquet' + df = pd.DataFrame({'x': np.arange(10, dtype=np.int64)}) + _write_table(df, path) + table_read = _read_table(path, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@parametrize_legacy_dataset +def test_fspath(tempdir, use_legacy_dataset): + # ARROW-12472 support __fspath__ objects without using str() + path = tempdir / "test.parquet" + table = pa.table({"a": [1, 2, 3]}) + _write_table(table, path) + + fs_protocol_obj = util.FSProtocolClass(path) + + result = _read_table( + fs_protocol_obj, use_legacy_dataset=use_legacy_dataset + ) + assert result.equals(table) + + # combined with non-local filesystem raises + with pytest.raises(TypeError): + _read_table(fs_protocol_obj, filesystem=FileSystem()) + + +@pytest.mark.dataset +@parametrize_legacy_dataset +@pytest.mark.parametrize("filesystem", [ + None, fs.LocalFileSystem(), LocalFileSystem._get_instance() +]) +def test_relative_paths(tempdir, use_legacy_dataset, filesystem): + # reading and writing from relative paths + table = pa.table({"a": [1, 2, 3]}) + + # reading + pq.write_table(table, str(tempdir / "data.parquet")) + with util.change_cwd(tempdir): + result = pq.read_table("data.parquet", filesystem=filesystem, + use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + # writing + with util.change_cwd(tempdir): + pq.write_table(table, "data2.parquet", filesystem=filesystem) + result = pq.read_table(tempdir / "data2.parquet") + assert result.equals(table) + + +def test_read_non_existing_file(): + # ensure we have a proper error message + with pytest.raises(FileNotFoundError): + pq.read_table('i-am-not-existing.parquet') + + +def test_file_error_python_exception(): + class BogusFile(io.BytesIO): + def read(self, *args): + raise ZeroDivisionError("zorglub") + + def seek(self, *args): + raise ZeroDivisionError("zorglub") + + # ensure the Python exception is restored + with pytest.raises(ZeroDivisionError, match="zorglub"): + pq.read_table(BogusFile(b"")) + + +@parametrize_legacy_dataset +def test_parquet_read_from_buffer(tempdir, use_legacy_dataset): + # reading from a buffer from python's open() + table = pa.table({"a": [1, 2, 3]}) + pq.write_table(table, str(tempdir / "data.parquet")) + + with open(str(tempdir / "data.parquet"), "rb") as f: + result = pq.read_table(f, use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + with open(str(tempdir / "data.parquet"), "rb") as f: + result = pq.read_table(pa.PythonFile(f), + use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + +@parametrize_legacy_dataset +def test_byte_stream_split(use_legacy_dataset): + # This is only a smoke test. + arr_float = pa.array(list(map(float, range(100)))) + arr_int = pa.array(list(map(int, range(100)))) + data_float = [arr_float, arr_float] + table = pa.Table.from_arrays(data_float, names=['a', 'b']) + + # Check with byte_stream_split for both columns. + _check_roundtrip(table, expected=table, compression="gzip", + use_dictionary=False, use_byte_stream_split=True) + + # Check with byte_stream_split for column 'b' and dictionary + # for column 'a'. + _check_roundtrip(table, expected=table, compression="gzip", + use_dictionary=['a'], + use_byte_stream_split=['b']) + + # Check with a collision for both columns. + _check_roundtrip(table, expected=table, compression="gzip", + use_dictionary=['a', 'b'], + use_byte_stream_split=['a', 'b']) + + # Check with mixed column types. + mixed_table = pa.Table.from_arrays([arr_float, arr_int], + names=['a', 'b']) + _check_roundtrip(mixed_table, expected=mixed_table, + use_dictionary=['b'], + use_byte_stream_split=['a']) + + # Try to use the wrong data type with the byte_stream_split encoding. + # This should throw an exception. + table = pa.Table.from_arrays([arr_int], names=['tmp']) + with pytest.raises(IOError): + _check_roundtrip(table, expected=table, use_byte_stream_split=True, + use_dictionary=False, + use_legacy_dataset=use_legacy_dataset) + + +@parametrize_legacy_dataset +def test_compression_level(use_legacy_dataset): + arr = pa.array(list(map(int, range(1000)))) + data = [arr, arr] + table = pa.Table.from_arrays(data, names=['a', 'b']) + + # Check one compression level. + _check_roundtrip(table, expected=table, compression="gzip", + compression_level=1, + use_legacy_dataset=use_legacy_dataset) + + # Check another one to make sure that compression_level=1 does not + # coincide with the default one in Arrow. + _check_roundtrip(table, expected=table, compression="gzip", + compression_level=5, + use_legacy_dataset=use_legacy_dataset) + + # Check that the user can provide a compression per column + _check_roundtrip(table, expected=table, + compression={'a': "gzip", 'b': "snappy"}, + use_legacy_dataset=use_legacy_dataset) + + # Check that the user can provide a compression level per column + _check_roundtrip(table, expected=table, compression="gzip", + compression_level={'a': 2, 'b': 3}, + use_legacy_dataset=use_legacy_dataset) + + # Check that specifying a compression level for a codec which does allow + # specifying one, results into an error. + # Uncompressed, snappy, lz4 and lzo do not support specifying a compression + # level. + # GZIP (zlib) allows for specifying a compression level but as of up + # to version 1.2.11 the valid range is [-1, 9]. + invalid_combinations = [("snappy", 4), ("lz4", 5), ("gzip", -1337), + ("None", 444), ("lzo", 14)] + buf = io.BytesIO() + for (codec, level) in invalid_combinations: + with pytest.raises((ValueError, OSError)): + _write_table(table, buf, compression=codec, + compression_level=level) + + +def test_sanitized_spark_field_names(): + a0 = pa.array([0, 1, 2, 3, 4]) + name = 'prohib; ,\t{}' + table = pa.Table.from_arrays([a0], [name]) + + result = _roundtrip_table(table, write_table_kwargs={'flavor': 'spark'}) + + expected_name = 'prohib______' + assert result.schema[0].name == expected_name + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_multithreaded_read(use_legacy_dataset): + df = alltypes_sample(size=10000) + + table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(table, buf, compression='SNAPPY', version='2.6') + + buf.seek(0) + table1 = _read_table( + buf, use_threads=True, use_legacy_dataset=use_legacy_dataset) + + buf.seek(0) + table2 = _read_table( + buf, use_threads=False, use_legacy_dataset=use_legacy_dataset) + + assert table1.equals(table2) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_min_chunksize(use_legacy_dataset): + data = pd.DataFrame([np.arange(4)], columns=['A', 'B', 'C', 'D']) + table = pa.Table.from_pandas(data.reset_index()) + + buf = io.BytesIO() + _write_table(table, buf, chunk_size=-1) + + buf.seek(0) + result = _read_table(buf, use_legacy_dataset=use_legacy_dataset) + + assert result.equals(table) + + with pytest.raises(ValueError): + _write_table(table, buf, chunk_size=0) + + +@pytest.mark.pandas +def test_write_error_deletes_incomplete_file(tempdir): + # ARROW-1285 + df = pd.DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, freq='ns')}) + + pdf = pa.Table.from_pandas(df) + + filename = tempdir / 'tmp_file' + try: + _write_table(pdf, filename) + except pa.ArrowException: + pass + + assert not filename.exists() + + +@parametrize_legacy_dataset +def test_read_non_existent_file(tempdir, use_legacy_dataset): + path = 'non-existent-file.parquet' + try: + pq.read_table(path, use_legacy_dataset=use_legacy_dataset) + except Exception as e: + assert path in e.args[0] + + +@parametrize_legacy_dataset +def test_read_table_doesnt_warn(datadir, use_legacy_dataset): + with pytest.warns(None) as record: + pq.read_table(datadir / 'v0.7.1.parquet', + use_legacy_dataset=use_legacy_dataset) + + assert len(record) == 0 + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_zlib_compression_bug(use_legacy_dataset): + # ARROW-3514: "zlib deflate failed, output buffer too small" + table = pa.Table.from_arrays([pa.array(['abc', 'def'])], ['some_col']) + f = io.BytesIO() + pq.write_table(table, f, compression='gzip') + + f.seek(0) + roundtrip = pq.read_table(f, use_legacy_dataset=use_legacy_dataset) + tm.assert_frame_equal(roundtrip.to_pandas(), table.to_pandas()) + + +@parametrize_legacy_dataset +def test_parquet_file_too_small(tempdir, use_legacy_dataset): + path = str(tempdir / "test.parquet") + # TODO(dataset) with datasets API it raises OSError instead + with pytest.raises((pa.ArrowInvalid, OSError), + match='size is 0 bytes'): + with open(path, 'wb') as f: + pass + pq.read_table(path, use_legacy_dataset=use_legacy_dataset) + + with pytest.raises((pa.ArrowInvalid, OSError), + match='size is 4 bytes'): + with open(path, 'wb') as f: + f.write(b'ffff') + pq.read_table(path, use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +@pytest.mark.fastparquet +@pytest.mark.filterwarnings("ignore:RangeIndex:FutureWarning") +@pytest.mark.filterwarnings("ignore:tostring:DeprecationWarning:fastparquet") +def test_fastparquet_cross_compatibility(tempdir): + fp = pytest.importorskip('fastparquet') + + df = pd.DataFrame( + { + "a": list("abc"), + "b": list(range(1, 4)), + "c": np.arange(4.0, 7.0, dtype="float64"), + "d": [True, False, True], + "e": pd.date_range("20130101", periods=3), + "f": pd.Categorical(["a", "b", "a"]), + # fastparquet writes list as BYTE_ARRAY JSON, so no roundtrip + # "g": [[1, 2], None, [1, 2, 3]], + } + ) + table = pa.table(df) + + # Arrow -> fastparquet + file_arrow = str(tempdir / "cross_compat_arrow.parquet") + pq.write_table(table, file_arrow, compression=None) + + fp_file = fp.ParquetFile(file_arrow) + df_fp = fp_file.to_pandas() + tm.assert_frame_equal(df, df_fp) + + # Fastparquet -> arrow + file_fastparquet = str(tempdir / "cross_compat_fastparquet.parquet") + fp.write(file_fastparquet, df) + + table_fp = pq.read_pandas(file_fastparquet) + # for fastparquet written file, categoricals comes back as strings + # (no arrow schema in parquet metadata) + df['f'] = df['f'].astype(object) + tm.assert_frame_equal(table_fp.to_pandas(), df) + + +@parametrize_legacy_dataset +@pytest.mark.parametrize('array_factory', [ + lambda: pa.array([0, None] * 10), + lambda: pa.array([0, None] * 10).dictionary_encode(), + lambda: pa.array(["", None] * 10), + lambda: pa.array(["", None] * 10).dictionary_encode(), +]) +@pytest.mark.parametrize('use_dictionary', [False, True]) +@pytest.mark.parametrize('read_dictionary', [False, True]) +def test_buffer_contents( + array_factory, use_dictionary, read_dictionary, use_legacy_dataset +): + # Test that null values are deterministically initialized to zero + # after a roundtrip through Parquet. + # See ARROW-8006 and ARROW-8011. + orig_table = pa.Table.from_pydict({"col": array_factory()}) + bio = io.BytesIO() + pq.write_table(orig_table, bio, use_dictionary=True) + bio.seek(0) + read_dictionary = ['col'] if read_dictionary else None + table = pq.read_table(bio, use_threads=False, + read_dictionary=read_dictionary, + use_legacy_dataset=use_legacy_dataset) + + for col in table.columns: + [chunk] = col.chunks + buf = chunk.buffers()[1] + assert buf.to_pybytes() == buf.size * b"\0" + + +def test_parquet_compression_roundtrip(tempdir): + # ARROW-10480: ensure even with nonstandard Parquet file naming + # conventions, writing and then reading a file works. In + # particular, ensure that we don't automatically double-compress + # the stream due to auto-detecting the extension in the filename + table = pa.table([pa.array(range(4))], names=["ints"]) + path = tempdir / "arrow-10480.pyarrow.gz" + pq.write_table(table, path, compression="GZIP") + result = pq.read_table(path) + assert result.equals(table) + + +def test_empty_row_groups(tempdir): + # ARROW-3020 + table = pa.Table.from_arrays([pa.array([], type='int32')], ['f0']) + + path = tempdir / 'empty_row_groups.parquet' + + num_groups = 3 + with pq.ParquetWriter(path, table.schema) as writer: + for i in range(num_groups): + writer.write_table(table) + + reader = pq.ParquetFile(path) + assert reader.metadata.num_row_groups == num_groups + + for i in range(num_groups): + assert reader.read_row_group(i).equals(table) + + +def test_reads_over_batch(tempdir): + data = [None] * (1 << 20) + data.append([1]) + # Large list<int64> with mostly nones and one final + # value. This should force batched reads when + # reading back. + table = pa.Table.from_arrays([data], ['column']) + + path = tempdir / 'arrow-11607.parquet' + pq.write_table(table, path) + table2 = pq.read_table(path) + assert table == table2 diff --git a/src/arrow/python/pyarrow/tests/parquet/test_compliant_nested_type.py b/src/arrow/python/pyarrow/tests/parquet/test_compliant_nested_type.py new file mode 100644 index 000000000..91bc08df6 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/test_compliant_nested_type.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import pyarrow as pa +from pyarrow.tests.parquet.common import parametrize_legacy_dataset + +try: + import pyarrow.parquet as pq + from pyarrow.tests.parquet.common import (_read_table, + _check_roundtrip) +except ImportError: + pq = None + +try: + import pandas as pd + import pandas.testing as tm + + from pyarrow.tests.parquet.common import _roundtrip_pandas_dataframe +except ImportError: + pd = tm = None + +pytestmark = pytest.mark.parquet + +# Tests for ARROW-11497 +_test_data_simple = [ + {'items': [1, 2]}, + {'items': [0]}, +] + +_test_data_complex = [ + {'items': [{'name': 'elem1', 'value': '1'}, + {'name': 'elem2', 'value': '2'}]}, + {'items': [{'name': 'elem1', 'value': '0'}]}, +] + +parametrize_test_data = pytest.mark.parametrize( + "test_data", [_test_data_simple, _test_data_complex]) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@parametrize_test_data +def test_write_compliant_nested_type_enable(tempdir, + use_legacy_dataset, test_data): + # prepare dataframe for testing + df = pd.DataFrame(data=test_data) + # verify that we can read/write pandas df with new flag + _roundtrip_pandas_dataframe(df, + write_kwargs={ + 'use_compliant_nested_type': True}, + use_legacy_dataset=use_legacy_dataset) + + # Write to a parquet file with compliant nested type + table = pa.Table.from_pandas(df, preserve_index=False) + path = str(tempdir / 'data.parquet') + with pq.ParquetWriter(path, table.schema, + use_compliant_nested_type=True, + version='2.6') as writer: + writer.write_table(table) + # Read back as a table + new_table = _read_table(path) + # Validate that "items" columns compliant to Parquet nested format + # Should be like this: list<element: struct<name: string, value: string>> + assert isinstance(new_table.schema.types[0], pa.ListType) + assert new_table.schema.types[0].value_field.name == 'element' + + # Verify that the new table can be read/written correctly + _check_roundtrip(new_table, + use_legacy_dataset=use_legacy_dataset, + use_compliant_nested_type=True) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@parametrize_test_data +def test_write_compliant_nested_type_disable(tempdir, + use_legacy_dataset, test_data): + # prepare dataframe for testing + df = pd.DataFrame(data=test_data) + # verify that we can read/write with new flag disabled (default behaviour) + _roundtrip_pandas_dataframe(df, write_kwargs={}, + use_legacy_dataset=use_legacy_dataset) + + # Write to a parquet file while disabling compliant nested type + table = pa.Table.from_pandas(df, preserve_index=False) + path = str(tempdir / 'data.parquet') + with pq.ParquetWriter(path, table.schema, version='2.6') as writer: + writer.write_table(table) + new_table = _read_table(path) + + # Validate that "items" columns is not compliant to Parquet nested format + # Should be like this: list<item: struct<name: string, value: string>> + assert isinstance(new_table.schema.types[0], pa.ListType) + assert new_table.schema.types[0].value_field.name == 'item' + + # Verify that the new table can be read/written correctly + _check_roundtrip(new_table, + use_legacy_dataset=use_legacy_dataset, + use_compliant_nested_type=False) diff --git a/src/arrow/python/pyarrow/tests/parquet/test_data_types.py b/src/arrow/python/pyarrow/tests/parquet/test_data_types.py new file mode 100644 index 000000000..1e2660006 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/test_data_types.py @@ -0,0 +1,529 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import decimal +import io + +import numpy as np +import pytest + +import pyarrow as pa +from pyarrow.tests import util +from pyarrow.tests.parquet.common import (_check_roundtrip, + parametrize_legacy_dataset) + +try: + import pyarrow.parquet as pq + from pyarrow.tests.parquet.common import _read_table, _write_table +except ImportError: + pq = None + + +try: + import pandas as pd + import pandas.testing as tm + + from pyarrow.tests.pandas_examples import (dataframe_with_arrays, + dataframe_with_lists) + from pyarrow.tests.parquet.common import alltypes_sample +except ImportError: + pd = tm = None + + +pytestmark = pytest.mark.parquet + + +# General roundtrip of data types +# ----------------------------------------------------------------------------- + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@pytest.mark.parametrize('chunk_size', [None, 1000]) +def test_parquet_2_0_roundtrip(tempdir, chunk_size, use_legacy_dataset): + df = alltypes_sample(size=10000, categorical=True) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + assert arrow_table.schema.pandas_metadata is not None + + _write_table(arrow_table, filename, version='2.6', + coerce_timestamps='ms', chunk_size=chunk_size) + table_read = pq.read_pandas( + filename, use_legacy_dataset=use_legacy_dataset) + assert table_read.schema.pandas_metadata is not None + + read_metadata = table_read.schema.metadata + assert arrow_table.schema.metadata == read_metadata + + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_1_0_roundtrip(tempdir, use_legacy_dataset): + size = 10000 + np.random.seed(0) + df = pd.DataFrame({ + 'uint8': np.arange(size, dtype=np.uint8), + 'uint16': np.arange(size, dtype=np.uint16), + 'uint32': np.arange(size, dtype=np.uint32), + 'uint64': np.arange(size, dtype=np.uint64), + 'int8': np.arange(size, dtype=np.int16), + 'int16': np.arange(size, dtype=np.int16), + 'int32': np.arange(size, dtype=np.int32), + 'int64': np.arange(size, dtype=np.int64), + 'float32': np.arange(size, dtype=np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0, + 'str': [str(x) for x in range(size)], + 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None], + 'empty_str': [''] * size + }) + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + _write_table(arrow_table, filename, version='1.0') + table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + + # We pass uint32_t as int64_t if we write Parquet version 1.0 + df['uint32'] = df['uint32'].values.astype(np.int64) + + tm.assert_frame_equal(df, df_read) + + +# Dictionary +# ----------------------------------------------------------------------------- + + +def _simple_table_write_read(table, use_legacy_dataset): + bio = pa.BufferOutputStream() + pq.write_table(table, bio) + contents = bio.getvalue() + return pq.read_table( + pa.BufferReader(contents), use_legacy_dataset=use_legacy_dataset + ) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_direct_read_dictionary(use_legacy_dataset): + # ARROW-3325 + repeats = 10 + nunique = 5 + + data = [ + [util.rands(10) for i in range(nunique)] * repeats, + + ] + table = pa.table(data, names=['f0']) + + bio = pa.BufferOutputStream() + pq.write_table(table, bio) + contents = bio.getvalue() + + result = pq.read_table(pa.BufferReader(contents), + read_dictionary=['f0'], + use_legacy_dataset=use_legacy_dataset) + + # Compute dictionary-encoded subfield + expected = pa.table([table[0].dictionary_encode()], names=['f0']) + assert result.equals(expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_direct_read_dictionary_subfield(use_legacy_dataset): + repeats = 10 + nunique = 5 + + data = [ + [[util.rands(10)] for i in range(nunique)] * repeats, + ] + table = pa.table(data, names=['f0']) + + bio = pa.BufferOutputStream() + pq.write_table(table, bio) + contents = bio.getvalue() + result = pq.read_table(pa.BufferReader(contents), + read_dictionary=['f0.list.item'], + use_legacy_dataset=use_legacy_dataset) + + arr = pa.array(data[0]) + values_as_dict = arr.values.dictionary_encode() + + inner_indices = values_as_dict.indices.cast('int32') + new_values = pa.DictionaryArray.from_arrays(inner_indices, + values_as_dict.dictionary) + + offsets = pa.array(range(51), type='int32') + expected_arr = pa.ListArray.from_arrays(offsets, new_values) + expected = pa.table([expected_arr], names=['f0']) + + assert result.equals(expected) + assert result[0].num_chunks == 1 + + +@parametrize_legacy_dataset +def test_dictionary_array_automatically_read(use_legacy_dataset): + # ARROW-3246 + + # Make a large dictionary, a little over 4MB of data + dict_length = 4000 + dict_values = pa.array([('x' * 1000 + '_{}'.format(i)) + for i in range(dict_length)]) + + num_chunks = 10 + chunk_size = 100 + chunks = [] + for i in range(num_chunks): + indices = np.random.randint(0, dict_length, + size=chunk_size).astype(np.int32) + chunks.append(pa.DictionaryArray.from_arrays(pa.array(indices), + dict_values)) + + table = pa.table([pa.chunked_array(chunks)], names=['f0']) + result = _simple_table_write_read(table, use_legacy_dataset) + + assert result.equals(table) + + # The only key in the metadata was the Arrow schema key + assert result.schema.metadata is None + + +# Decimal +# ----------------------------------------------------------------------------- + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_decimal_roundtrip(tempdir, use_legacy_dataset): + num_values = 10 + + columns = {} + for precision in range(1, 39): + for scale in range(0, precision + 1): + with util.random_seed(0): + random_decimal_values = [ + util.randdecimal(precision, scale) + for _ in range(num_values) + ] + column_name = ('dec_precision_{:d}_scale_{:d}' + .format(precision, scale)) + columns[column_name] = random_decimal_values + + expected = pd.DataFrame(columns) + filename = tempdir / 'decimals.parquet' + string_filename = str(filename) + table = pa.Table.from_pandas(expected) + _write_table(table, string_filename) + result_table = _read_table( + string_filename, use_legacy_dataset=use_legacy_dataset) + result = result_table.to_pandas() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@pytest.mark.xfail( + raises=OSError, reason='Parquet does not support negative scale' +) +def test_decimal_roundtrip_negative_scale(tempdir): + expected = pd.DataFrame({'decimal_num': [decimal.Decimal('1.23E4')]}) + filename = tempdir / 'decimals.parquet' + string_filename = str(filename) + t = pa.Table.from_pandas(expected) + _write_table(t, string_filename) + result_table = _read_table(string_filename) + result = result_table.to_pandas() + tm.assert_frame_equal(result, expected) + + +# List types +# ----------------------------------------------------------------------------- + + +@parametrize_legacy_dataset +@pytest.mark.parametrize('dtype', [int, float]) +def test_single_pylist_column_roundtrip(tempdir, dtype, use_legacy_dataset): + filename = tempdir / 'single_{}_column.parquet'.format(dtype.__name__) + data = [pa.array(list(map(dtype, range(5))))] + table = pa.Table.from_arrays(data, names=['a']) + _write_table(table, filename) + table_read = _read_table(filename, use_legacy_dataset=use_legacy_dataset) + for i in range(table.num_columns): + col_written = table[i] + col_read = table_read[i] + assert table.field(i).name == table_read.field(i).name + assert col_read.num_chunks == 1 + data_written = col_written.chunk(0) + data_read = col_read.chunk(0) + assert data_written.equals(data_read) + + +@parametrize_legacy_dataset +def test_empty_lists_table_roundtrip(use_legacy_dataset): + # ARROW-2744: Shouldn't crash when writing an array of empty lists + arr = pa.array([[], []], type=pa.list_(pa.int32())) + table = pa.Table.from_arrays([arr], ["A"]) + _check_roundtrip(table, use_legacy_dataset=use_legacy_dataset) + + +@parametrize_legacy_dataset +def test_nested_list_nonnullable_roundtrip_bug(use_legacy_dataset): + # Reproduce failure in ARROW-5630 + typ = pa.list_(pa.field("item", pa.float32(), False)) + num_rows = 10000 + t = pa.table([ + pa.array(([[0] * ((i + 5) % 10) for i in range(0, 10)] * + (num_rows // 10)), type=typ) + ], ['a']) + _check_roundtrip( + t, data_page_size=4096, use_legacy_dataset=use_legacy_dataset) + + +@parametrize_legacy_dataset +def test_nested_list_struct_multiple_batches_roundtrip( + tempdir, use_legacy_dataset +): + # Reproduce failure in ARROW-11024 + data = [[{'x': 'abc', 'y': 'abc'}]]*100 + [[{'x': 'abc', 'y': 'gcb'}]]*100 + table = pa.table([pa.array(data)], names=['column']) + _check_roundtrip( + table, row_group_size=20, use_legacy_dataset=use_legacy_dataset) + + # Reproduce failure in ARROW-11069 (plain non-nested structs with strings) + data = pa.array( + [{'a': '1', 'b': '2'}, {'a': '3', 'b': '4'}, {'a': '5', 'b': '6'}]*10 + ) + table = pa.table({'column': data}) + _check_roundtrip( + table, row_group_size=10, use_legacy_dataset=use_legacy_dataset) + + +def test_writing_empty_lists(): + # ARROW-2591: [Python] Segmentation fault issue in pq.write_table + arr1 = pa.array([[], []], pa.list_(pa.int32())) + table = pa.Table.from_arrays([arr1], ['list(int32)']) + _check_roundtrip(table) + + +@pytest.mark.pandas +def test_column_of_arrays(tempdir): + df, schema = dataframe_with_arrays() + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df, schema=schema) + _write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms') + table_read = _read_table(filename) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +def test_column_of_lists(tempdir): + df, schema = dataframe_with_lists(parquet_compatible=True) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df, schema=schema) + _write_table(arrow_table, filename, version='2.6') + table_read = _read_table(filename) + df_read = table_read.to_pandas() + + tm.assert_frame_equal(df, df_read) + + +def test_large_list_records(): + # This was fixed in PARQUET-1100 + + list_lengths = np.random.randint(0, 500, size=50) + list_lengths[::10] = 0 + + list_values = [list(map(int, np.random.randint(0, 100, size=x))) + if i % 8 else None + for i, x in enumerate(list_lengths)] + + a1 = pa.array(list_values) + + table = pa.Table.from_arrays([a1], ['int_lists']) + _check_roundtrip(table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_nested_convenience(tempdir, use_legacy_dataset): + # ARROW-1684 + df = pd.DataFrame({ + 'a': [[1, 2, 3], None, [4, 5], []], + 'b': [[1.], None, None, [6., 7.]], + }) + + path = str(tempdir / 'nested_convenience.parquet') + + table = pa.Table.from_pandas(df, preserve_index=False) + _write_table(table, path) + + read = pq.read_table( + path, columns=['a'], use_legacy_dataset=use_legacy_dataset) + tm.assert_frame_equal(read.to_pandas(), df[['a']]) + + read = pq.read_table( + path, columns=['a', 'b'], use_legacy_dataset=use_legacy_dataset) + tm.assert_frame_equal(read.to_pandas(), df) + + +# Binary +# ----------------------------------------------------------------------------- + + +def test_fixed_size_binary(): + t0 = pa.binary(10) + data = [b'fooooooooo', None, b'barooooooo', b'quxooooooo'] + a0 = pa.array(data, type=t0) + + table = pa.Table.from_arrays([a0], + ['binary[10]']) + _check_roundtrip(table) + + +# Large types +# ----------------------------------------------------------------------------- + + +@pytest.mark.slow +@pytest.mark.large_memory +def test_large_table_int32_overflow(): + size = np.iinfo('int32').max + 1 + + arr = np.ones(size, dtype='uint8') + + parr = pa.array(arr, type=pa.uint8()) + + table = pa.Table.from_arrays([parr], names=['one']) + f = io.BytesIO() + _write_table(table, f) + + +def _simple_table_roundtrip(table, use_legacy_dataset=False, **write_kwargs): + stream = pa.BufferOutputStream() + _write_table(table, stream, **write_kwargs) + buf = stream.getvalue() + return _read_table(buf, use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.slow +@pytest.mark.large_memory +@parametrize_legacy_dataset +def test_byte_array_exactly_2gb(use_legacy_dataset): + # Test edge case reported in ARROW-3762 + val = b'x' * (1 << 10) + + base = pa.array([val] * ((1 << 21) - 1)) + cases = [ + [b'x' * 1023], # 2^31 - 1 + [b'x' * 1024], # 2^31 + [b'x' * 1025] # 2^31 + 1 + ] + for case in cases: + values = pa.chunked_array([base, pa.array(case)]) + t = pa.table([values], names=['f0']) + result = _simple_table_roundtrip( + t, use_legacy_dataset=use_legacy_dataset, use_dictionary=False) + assert t.equals(result) + + +@pytest.mark.slow +@pytest.mark.pandas +@pytest.mark.large_memory +@parametrize_legacy_dataset +def test_binary_array_overflow_to_chunked(use_legacy_dataset): + # ARROW-3762 + + # 2^31 + 1 bytes + values = [b'x'] + [ + b'x' * (1 << 20) + ] * 2 * (1 << 10) + df = pd.DataFrame({'byte_col': values}) + + tbl = pa.Table.from_pandas(df, preserve_index=False) + read_tbl = _simple_table_roundtrip( + tbl, use_legacy_dataset=use_legacy_dataset) + + col0_data = read_tbl[0] + assert isinstance(col0_data, pa.ChunkedArray) + + # Split up into 2GB chunks + assert col0_data.num_chunks == 2 + + assert tbl.equals(read_tbl) + + +@pytest.mark.slow +@pytest.mark.pandas +@pytest.mark.large_memory +@parametrize_legacy_dataset +def test_list_of_binary_large_cell(use_legacy_dataset): + # ARROW-4688 + data = [] + + # TODO(wesm): handle chunked children + # 2^31 - 1 bytes in a single cell + # data.append([b'x' * (1 << 20)] * 2047 + [b'x' * ((1 << 20) - 1)]) + + # A little under 2GB in cell each containing approximately 10MB each + data.extend([[b'x' * 1000000] * 10] * 214) + + arr = pa.array(data) + table = pa.Table.from_arrays([arr], ['chunky_cells']) + read_table = _simple_table_roundtrip( + table, use_legacy_dataset=use_legacy_dataset) + assert table.equals(read_table) + + +def test_large_binary(): + data = [b'foo', b'bar'] * 50 + for type in [pa.large_binary(), pa.large_string()]: + arr = pa.array(data, type=type) + table = pa.Table.from_arrays([arr], names=['strs']) + for use_dictionary in [False, True]: + _check_roundtrip(table, use_dictionary=use_dictionary) + + +@pytest.mark.slow +@pytest.mark.large_memory +def test_large_binary_huge(): + s = b'xy' * 997 + data = [s] * ((1 << 33) // len(s)) + for type in [pa.large_binary(), pa.large_string()]: + arr = pa.array(data, type=type) + table = pa.Table.from_arrays([arr], names=['strs']) + for use_dictionary in [False, True]: + _check_roundtrip(table, use_dictionary=use_dictionary) + del arr, table + + +@pytest.mark.large_memory +def test_large_binary_overflow(): + s = b'x' * (1 << 31) + arr = pa.array([s], type=pa.large_binary()) + table = pa.Table.from_arrays([arr], names=['strs']) + for use_dictionary in [False, True]: + writer = pa.BufferOutputStream() + with pytest.raises( + pa.ArrowInvalid, + match="Parquet cannot store strings with size 2GB or more"): + _write_table(table, writer, use_dictionary=use_dictionary) diff --git a/src/arrow/python/pyarrow/tests/parquet/test_dataset.py b/src/arrow/python/pyarrow/tests/parquet/test_dataset.py new file mode 100644 index 000000000..82f7e5814 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/test_dataset.py @@ -0,0 +1,1661 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import os + +import numpy as np +import pytest + +import pyarrow as pa +from pyarrow import fs +from pyarrow.filesystem import LocalFileSystem +from pyarrow.tests import util +from pyarrow.tests.parquet.common import ( + parametrize_legacy_dataset, parametrize_legacy_dataset_fixed, + parametrize_legacy_dataset_not_supported) +from pyarrow.util import guid +from pyarrow.vendored.version import Version + +try: + import pyarrow.parquet as pq + from pyarrow.tests.parquet.common import ( + _read_table, _test_dataframe, _write_table) +except ImportError: + pq = None + + +try: + import pandas as pd + import pandas.testing as tm + +except ImportError: + pd = tm = None + +pytestmark = pytest.mark.parquet + + +@pytest.mark.pandas +def test_parquet_piece_read(tempdir): + df = _test_dataframe(1000) + table = pa.Table.from_pandas(df) + + path = tempdir / 'parquet_piece_read.parquet' + _write_table(table, path, version='2.6') + + with pytest.warns(DeprecationWarning): + piece1 = pq.ParquetDatasetPiece(path) + + result = piece1.read() + assert result.equals(table) + + +@pytest.mark.pandas +def test_parquet_piece_open_and_get_metadata(tempdir): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df) + + path = tempdir / 'parquet_piece_read.parquet' + _write_table(table, path, version='2.6') + + with pytest.warns(DeprecationWarning): + piece = pq.ParquetDatasetPiece(path) + table1 = piece.read() + assert isinstance(table1, pa.Table) + meta1 = piece.get_metadata() + assert isinstance(meta1, pq.FileMetaData) + + assert table.equals(table1) + + +@pytest.mark.filterwarnings("ignore:ParquetDatasetPiece:DeprecationWarning") +def test_parquet_piece_basics(): + path = '/baz.parq' + + piece1 = pq.ParquetDatasetPiece(path) + piece2 = pq.ParquetDatasetPiece(path, row_group=1) + piece3 = pq.ParquetDatasetPiece( + path, row_group=1, partition_keys=[('foo', 0), ('bar', 1)]) + + assert str(piece1) == path + assert str(piece2) == '/baz.parq | row_group=1' + assert str(piece3) == 'partition[foo=0, bar=1] /baz.parq | row_group=1' + + assert piece1 == piece1 + assert piece2 == piece2 + assert piece3 == piece3 + assert piece1 != piece3 + + +def test_partition_set_dictionary_type(): + set1 = pq.PartitionSet('key1', ['foo', 'bar', 'baz']) + set2 = pq.PartitionSet('key2', [2007, 2008, 2009]) + + assert isinstance(set1.dictionary, pa.StringArray) + assert isinstance(set2.dictionary, pa.IntegerArray) + + set3 = pq.PartitionSet('key2', [datetime.datetime(2007, 1, 1)]) + with pytest.raises(TypeError): + set3.dictionary + + +@parametrize_legacy_dataset_fixed +def test_filesystem_uri(tempdir, use_legacy_dataset): + table = pa.table({"a": [1, 2, 3]}) + + directory = tempdir / "data_dir" + directory.mkdir() + path = directory / "data.parquet" + pq.write_table(table, str(path)) + + # filesystem object + result = pq.read_table( + path, filesystem=fs.LocalFileSystem(), + use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + # filesystem URI + result = pq.read_table( + "data_dir/data.parquet", filesystem=util._filesystem_uri(tempdir), + use_legacy_dataset=use_legacy_dataset) + assert result.equals(table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_read_partitioned_directory(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + _partition_test_for_filesystem(fs, tempdir, use_legacy_dataset) + + +@pytest.mark.filterwarnings("ignore:'ParquetDataset:DeprecationWarning") +@pytest.mark.pandas +def test_create_parquet_dataset_multi_threaded(tempdir): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + _partition_test_for_filesystem(fs, base_path) + + manifest = pq.ParquetManifest(base_path, filesystem=fs, + metadata_nthreads=1) + dataset = pq.ParquetDataset(base_path, filesystem=fs, metadata_nthreads=16) + assert len(dataset.pieces) > 0 + partitions = dataset.partitions + assert len(partitions.partition_names) > 0 + assert partitions.partition_names == manifest.partitions.partition_names + assert len(partitions.levels) == len(manifest.partitions.levels) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_read_partitioned_columns_selection(tempdir, use_legacy_dataset): + # ARROW-3861 - do not include partition columns in resulting table when + # `columns` keyword was passed without those columns + fs = LocalFileSystem._get_instance() + base_path = tempdir + _partition_test_for_filesystem(fs, base_path) + + dataset = pq.ParquetDataset( + base_path, use_legacy_dataset=use_legacy_dataset) + result = dataset.read(columns=["values"]) + if use_legacy_dataset: + # ParquetDataset implementation always includes the partition columns + # automatically, and we can't easily "fix" this since dask relies on + # this behaviour (ARROW-8644) + assert result.column_names == ["values", "foo", "bar"] + else: + assert result.column_names == ["values"] + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_equivalency(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1] + string_keys = ['a', 'b', 'c'] + boolean_keys = [True, False] + partition_spec = [ + ['integer', integer_keys], + ['string', string_keys], + ['boolean', boolean_keys] + ] + + df = pd.DataFrame({ + 'integer': np.array(integer_keys, dtype='i4').repeat(15), + 'string': np.tile(np.tile(np.array(string_keys, dtype=object), 5), 2), + 'boolean': np.tile(np.tile(np.array(boolean_keys, dtype='bool'), 5), + 3), + }, columns=['integer', 'string', 'boolean']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + # Old filters syntax: + # integer == 1 AND string != b AND boolean == True + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[('integer', '=', 1), ('string', '!=', 'b'), + ('boolean', '==', 'True')], + use_legacy_dataset=use_legacy_dataset, + ) + table = dataset.read() + result_df = (table.to_pandas().reset_index(drop=True)) + + assert 0 not in result_df['integer'].values + assert 'b' not in result_df['string'].values + assert False not in result_df['boolean'].values + + # filters in disjunctive normal form: + # (integer == 1 AND string != b AND boolean == True) OR + # (integer == 2 AND boolean == False) + # TODO(ARROW-3388): boolean columns are reconstructed as string + filters = [ + [ + ('integer', '=', 1), + ('string', '!=', 'b'), + ('boolean', '==', 'True') + ], + [('integer', '=', 0), ('boolean', '==', 'False')] + ] + dataset = pq.ParquetDataset( + base_path, filesystem=fs, filters=filters, + use_legacy_dataset=use_legacy_dataset) + table = dataset.read() + result_df = table.to_pandas().reset_index(drop=True) + + # Check that all rows in the DF fulfill the filter + # Pandas 0.23.x has problems with indexing constant memoryviews in + # categoricals. Thus we need to make an explicit copy here with np.array. + df_filter_1 = (np.array(result_df['integer']) == 1) \ + & (np.array(result_df['string']) != 'b') \ + & (np.array(result_df['boolean']) == 'True') + df_filter_2 = (np.array(result_df['integer']) == 0) \ + & (np.array(result_df['boolean']) == 'False') + assert df_filter_1.sum() > 0 + assert df_filter_2.sum() > 0 + assert result_df.shape[0] == (df_filter_1.sum() + df_filter_2.sum()) + + if use_legacy_dataset: + # Check for \0 in predicate values. Until they are correctly + # implemented in ARROW-3391, they would otherwise lead to weird + # results with the current code. + with pytest.raises(NotImplementedError): + filters = [[('string', '==', b'1\0a')]] + pq.ParquetDataset(base_path, filesystem=fs, filters=filters) + with pytest.raises(NotImplementedError): + filters = [[('string', '==', '1\0a')]] + pq.ParquetDataset(base_path, filesystem=fs, filters=filters) + else: + for filters in [[[('string', '==', b'1\0a')]], + [[('string', '==', '1\0a')]]]: + dataset = pq.ParquetDataset( + base_path, filesystem=fs, filters=filters, + use_legacy_dataset=False) + assert dataset.read().num_rows == 0 + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_cutoff_exclusive_integer(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1, 2, 3, 4] + partition_spec = [ + ['integers', integer_keys], + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[ + ('integers', '<', 4), + ('integers', '>', 1), + ], + use_legacy_dataset=use_legacy_dataset + ) + table = dataset.read() + result_df = (table.to_pandas() + .sort_values(by='index') + .reset_index(drop=True)) + + result_list = [x for x in map(int, result_df['integers'].values)] + assert result_list == [2, 3] + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@pytest.mark.xfail( + # different error with use_legacy_datasets because result_df is no longer + # categorical + raises=(TypeError, AssertionError), + reason='Loss of type information in creation of categoricals.' +) +def test_filters_cutoff_exclusive_datetime(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + date_keys = [ + datetime.date(2018, 4, 9), + datetime.date(2018, 4, 10), + datetime.date(2018, 4, 11), + datetime.date(2018, 4, 12), + datetime.date(2018, 4, 13) + ] + partition_spec = [ + ['dates', date_keys] + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'dates': np.array(date_keys, dtype='datetime64'), + }, columns=['index', 'dates']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[ + ('dates', '<', "2018-04-12"), + ('dates', '>', "2018-04-10") + ], + use_legacy_dataset=use_legacy_dataset + ) + table = dataset.read() + result_df = (table.to_pandas() + .sort_values(by='index') + .reset_index(drop=True)) + + expected = pd.Categorical( + np.array([datetime.date(2018, 4, 11)], dtype='datetime64'), + categories=np.array(date_keys, dtype='datetime64')) + + assert result_df['dates'].values == expected + + +@pytest.mark.pandas +@pytest.mark.dataset +def test_filters_inclusive_datetime(tempdir): + # ARROW-11480 + path = tempdir / 'timestamps.parquet' + + pd.DataFrame({ + "dates": pd.date_range("2020-01-01", periods=10, freq="D"), + "id": range(10) + }).to_parquet(path, use_deprecated_int96_timestamps=True) + + table = pq.read_table(path, filters=[ + ("dates", "<=", datetime.datetime(2020, 1, 5)) + ]) + + assert table.column('id').to_pylist() == [0, 1, 2, 3, 4] + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_inclusive_integer(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1, 2, 3, 4] + partition_spec = [ + ['integers', integer_keys], + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[ + ('integers', '<=', 3), + ('integers', '>=', 2), + ], + use_legacy_dataset=use_legacy_dataset + ) + table = dataset.read() + result_df = (table.to_pandas() + .sort_values(by='index') + .reset_index(drop=True)) + + result_list = [int(x) for x in map(int, result_df['integers'].values)] + assert result_list == [2, 3] + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_inclusive_set(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1] + string_keys = ['a', 'b', 'c'] + boolean_keys = [True, False] + partition_spec = [ + ['integer', integer_keys], + ['string', string_keys], + ['boolean', boolean_keys] + ] + + df = pd.DataFrame({ + 'integer': np.array(integer_keys, dtype='i4').repeat(15), + 'string': np.tile(np.tile(np.array(string_keys, dtype=object), 5), 2), + 'boolean': np.tile(np.tile(np.array(boolean_keys, dtype='bool'), 5), + 3), + }, columns=['integer', 'string', 'boolean']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[('string', 'in', 'ab')], + use_legacy_dataset=use_legacy_dataset + ) + table = dataset.read() + result_df = (table.to_pandas().reset_index(drop=True)) + + assert 'a' in result_df['string'].values + assert 'b' in result_df['string'].values + assert 'c' not in result_df['string'].values + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, + filters=[('integer', 'in', [1]), ('string', 'in', ('a', 'b')), + ('boolean', 'not in', {False})], + use_legacy_dataset=use_legacy_dataset + ) + table = dataset.read() + result_df = (table.to_pandas().reset_index(drop=True)) + + assert 0 not in result_df['integer'].values + assert 'c' not in result_df['string'].values + assert False not in result_df['boolean'].values + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_invalid_pred_op(tempdir, use_legacy_dataset): + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1, 2, 3, 4] + partition_spec = [ + ['integers', integer_keys], + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + with pytest.raises(TypeError): + pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', 'in', 3), ], + use_legacy_dataset=use_legacy_dataset) + + with pytest.raises(ValueError): + pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', '=<', 3), ], + use_legacy_dataset=use_legacy_dataset) + + if use_legacy_dataset: + with pytest.raises(ValueError): + pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', 'in', set()), ], + use_legacy_dataset=use_legacy_dataset) + else: + # Dataset API returns empty table instead + dataset = pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', 'in', set()), ], + use_legacy_dataset=use_legacy_dataset) + assert dataset.read().num_rows == 0 + + if use_legacy_dataset: + with pytest.raises(ValueError): + pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', '!=', {3})], + use_legacy_dataset=use_legacy_dataset) + else: + dataset = pq.ParquetDataset(base_path, + filesystem=fs, + filters=[('integers', '!=', {3})], + use_legacy_dataset=use_legacy_dataset) + with pytest.raises(NotImplementedError): + assert dataset.read().num_rows == 0 + + +@pytest.mark.pandas +@parametrize_legacy_dataset_fixed +def test_filters_invalid_column(tempdir, use_legacy_dataset): + # ARROW-5572 - raise error on invalid name in filter specification + # works with new dataset / xfail with legacy implementation + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1, 2, 3, 4] + partition_spec = [['integers', integer_keys]] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + msg = r"No match for FieldRef.Name\(non_existent_column\)" + with pytest.raises(ValueError, match=msg): + pq.ParquetDataset(base_path, filesystem=fs, + filters=[('non_existent_column', '<', 3), ], + use_legacy_dataset=use_legacy_dataset).read() + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_filters_read_table(tempdir, use_legacy_dataset): + # test that filters keyword is passed through in read_table + fs = LocalFileSystem._get_instance() + base_path = tempdir + + integer_keys = [0, 1, 2, 3, 4] + partition_spec = [ + ['integers', integer_keys], + ] + N = 5 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'integers': np.array(integer_keys, dtype='i4'), + }, columns=['index', 'integers']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + table = pq.read_table( + base_path, filesystem=fs, filters=[('integers', '<', 3)], + use_legacy_dataset=use_legacy_dataset) + assert table.num_rows == 3 + + table = pq.read_table( + base_path, filesystem=fs, filters=[[('integers', '<', 3)]], + use_legacy_dataset=use_legacy_dataset) + assert table.num_rows == 3 + + table = pq.read_pandas( + base_path, filters=[('integers', '<', 3)], + use_legacy_dataset=use_legacy_dataset) + assert table.num_rows == 3 + + +@pytest.mark.pandas +@parametrize_legacy_dataset_fixed +def test_partition_keys_with_underscores(tempdir, use_legacy_dataset): + # ARROW-5666 - partition field values with underscores preserve underscores + # xfail with legacy dataset -> they get interpreted as integers + fs = LocalFileSystem._get_instance() + base_path = tempdir + + string_keys = ["2019_2", "2019_3"] + partition_spec = [ + ['year_week', string_keys], + ] + N = 2 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'year_week': np.array(string_keys, dtype='object'), + }, columns=['index', 'year_week']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, use_legacy_dataset=use_legacy_dataset) + result = dataset.read() + assert result.column("year_week").to_pylist() == string_keys + + +@pytest.mark.s3 +@parametrize_legacy_dataset +def test_read_s3fs(s3_example_s3fs, use_legacy_dataset): + fs, path = s3_example_s3fs + path = path + "/test.parquet" + table = pa.table({"a": [1, 2, 3]}) + _write_table(table, path, filesystem=fs) + + result = _read_table( + path, filesystem=fs, use_legacy_dataset=use_legacy_dataset + ) + assert result.equals(table) + + +@pytest.mark.s3 +@parametrize_legacy_dataset +def test_read_directory_s3fs(s3_example_s3fs, use_legacy_dataset): + fs, directory = s3_example_s3fs + path = directory + "/test.parquet" + table = pa.table({"a": [1, 2, 3]}) + _write_table(table, path, filesystem=fs) + + result = _read_table( + directory, filesystem=fs, use_legacy_dataset=use_legacy_dataset + ) + assert result.equals(table) + + +@pytest.mark.pandas +@pytest.mark.s3 +@parametrize_legacy_dataset +def test_read_partitioned_directory_s3fs_wrapper( + s3_example_s3fs, use_legacy_dataset +): + import s3fs + + from pyarrow.filesystem import S3FSWrapper + + if Version(s3fs.__version__) >= Version("0.5"): + pytest.skip("S3FSWrapper no longer working for s3fs 0.5+") + + fs, path = s3_example_s3fs + with pytest.warns(FutureWarning): + wrapper = S3FSWrapper(fs) + _partition_test_for_filesystem(wrapper, path) + + # Check that we can auto-wrap + dataset = pq.ParquetDataset( + path, filesystem=fs, use_legacy_dataset=use_legacy_dataset + ) + dataset.read() + + +@pytest.mark.pandas +@pytest.mark.s3 +@parametrize_legacy_dataset +def test_read_partitioned_directory_s3fs(s3_example_s3fs, use_legacy_dataset): + fs, path = s3_example_s3fs + _partition_test_for_filesystem( + fs, path, use_legacy_dataset=use_legacy_dataset + ) + + +def _partition_test_for_filesystem(fs, base_path, use_legacy_dataset=True): + foo_keys = [0, 1] + bar_keys = ['a', 'b', 'c'] + partition_spec = [ + ['foo', foo_keys], + ['bar', bar_keys] + ] + N = 30 + + df = pd.DataFrame({ + 'index': np.arange(N), + 'foo': np.array(foo_keys, dtype='i4').repeat(15), + 'bar': np.tile(np.tile(np.array(bar_keys, dtype=object), 5), 2), + 'values': np.random.randn(N) + }, columns=['index', 'foo', 'bar', 'values']) + + _generate_partition_directories(fs, base_path, partition_spec, df) + + dataset = pq.ParquetDataset( + base_path, filesystem=fs, use_legacy_dataset=use_legacy_dataset) + table = dataset.read() + result_df = (table.to_pandas() + .sort_values(by='index') + .reset_index(drop=True)) + + expected_df = (df.sort_values(by='index') + .reset_index(drop=True) + .reindex(columns=result_df.columns)) + + expected_df['foo'] = pd.Categorical(df['foo'], categories=foo_keys) + expected_df['bar'] = pd.Categorical(df['bar'], categories=bar_keys) + + assert (result_df.columns == ['index', 'values', 'foo', 'bar']).all() + + tm.assert_frame_equal(result_df, expected_df) + + +def _generate_partition_directories(fs, base_dir, partition_spec, df): + # partition_spec : list of lists, e.g. [['foo', [0, 1, 2], + # ['bar', ['a', 'b', 'c']] + # part_table : a pyarrow.Table to write to each partition + DEPTH = len(partition_spec) + + pathsep = getattr(fs, "pathsep", getattr(fs, "sep", "/")) + + def _visit_level(base_dir, level, part_keys): + name, values = partition_spec[level] + for value in values: + this_part_keys = part_keys + [(name, value)] + + level_dir = pathsep.join([ + str(base_dir), + '{}={}'.format(name, value) + ]) + fs.mkdir(level_dir) + + if level == DEPTH - 1: + # Generate example data + file_path = pathsep.join([level_dir, guid()]) + filtered_df = _filter_partition(df, this_part_keys) + part_table = pa.Table.from_pandas(filtered_df) + with fs.open(file_path, 'wb') as f: + _write_table(part_table, f) + assert fs.exists(file_path) + + file_success = pathsep.join([level_dir, '_SUCCESS']) + with fs.open(file_success, 'wb') as f: + pass + else: + _visit_level(level_dir, level + 1, this_part_keys) + file_success = pathsep.join([level_dir, '_SUCCESS']) + with fs.open(file_success, 'wb') as f: + pass + + _visit_level(base_dir, 0, []) + + +def _test_read_common_metadata_files(fs, base_path): + import pandas as pd + + import pyarrow.parquet as pq + + N = 100 + df = pd.DataFrame({ + 'index': np.arange(N), + 'values': np.random.randn(N) + }, columns=['index', 'values']) + + base_path = str(base_path) + data_path = os.path.join(base_path, 'data.parquet') + + table = pa.Table.from_pandas(df) + + with fs.open(data_path, 'wb') as f: + _write_table(table, f) + + metadata_path = os.path.join(base_path, '_common_metadata') + with fs.open(metadata_path, 'wb') as f: + pq.write_metadata(table.schema, f) + + dataset = pq.ParquetDataset(base_path, filesystem=fs) + assert dataset.common_metadata_path == str(metadata_path) + + with fs.open(data_path) as f: + common_schema = pq.read_metadata(f).schema + assert dataset.schema.equals(common_schema) + + # handle list of one directory + dataset2 = pq.ParquetDataset([base_path], filesystem=fs) + assert dataset2.schema.equals(dataset.schema) + + +@pytest.mark.pandas +def test_read_common_metadata_files(tempdir): + fs = LocalFileSystem._get_instance() + _test_read_common_metadata_files(fs, tempdir) + + +@pytest.mark.pandas +def test_read_metadata_files(tempdir): + fs = LocalFileSystem._get_instance() + + N = 100 + df = pd.DataFrame({ + 'index': np.arange(N), + 'values': np.random.randn(N) + }, columns=['index', 'values']) + + data_path = tempdir / 'data.parquet' + + table = pa.Table.from_pandas(df) + + with fs.open(data_path, 'wb') as f: + _write_table(table, f) + + metadata_path = tempdir / '_metadata' + with fs.open(metadata_path, 'wb') as f: + pq.write_metadata(table.schema, f) + + dataset = pq.ParquetDataset(tempdir, filesystem=fs) + assert dataset.metadata_path == str(metadata_path) + + with fs.open(data_path) as f: + metadata_schema = pq.read_metadata(f).schema + assert dataset.schema.equals(metadata_schema) + + +def _filter_partition(df, part_keys): + predicate = np.ones(len(df), dtype=bool) + + to_drop = [] + for name, value in part_keys: + to_drop.append(name) + + # to avoid pandas warning + if isinstance(value, (datetime.date, datetime.datetime)): + value = pd.Timestamp(value) + + predicate &= df[name] == value + + return df[predicate].drop(to_drop, axis=1) + + +@parametrize_legacy_dataset +@pytest.mark.pandas +def test_filter_before_validate_schema(tempdir, use_legacy_dataset): + # ARROW-4076 apply filter before schema validation + # to avoid checking unneeded schemas + + # create partitioned dataset with mismatching schemas which would + # otherwise raise if first validation all schemas + dir1 = tempdir / 'A=0' + dir1.mkdir() + table1 = pa.Table.from_pandas(pd.DataFrame({'B': [1, 2, 3]})) + pq.write_table(table1, dir1 / 'data.parquet') + + dir2 = tempdir / 'A=1' + dir2.mkdir() + table2 = pa.Table.from_pandas(pd.DataFrame({'B': ['a', 'b', 'c']})) + pq.write_table(table2, dir2 / 'data.parquet') + + # read single file using filter + table = pq.read_table(tempdir, filters=[[('A', '==', 0)]], + use_legacy_dataset=use_legacy_dataset) + assert table.column('B').equals(pa.chunked_array([[1, 2, 3]])) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_read_multiple_files(tempdir, use_legacy_dataset): + nfiles = 10 + size = 5 + + dirpath = tempdir / guid() + dirpath.mkdir() + + test_data = [] + paths = [] + for i in range(nfiles): + df = _test_dataframe(size, seed=i) + + # Hack so that we don't have a dtype cast in v1 files + df['uint32'] = df['uint32'].astype(np.int64) + + path = dirpath / '{}.parquet'.format(i) + + table = pa.Table.from_pandas(df) + _write_table(table, path) + + test_data.append(table) + paths.append(path) + + # Write a _SUCCESS.crc file + (dirpath / '_SUCCESS.crc').touch() + + def read_multiple_files(paths, columns=None, use_threads=True, **kwargs): + dataset = pq.ParquetDataset( + paths, use_legacy_dataset=use_legacy_dataset, **kwargs) + return dataset.read(columns=columns, use_threads=use_threads) + + result = read_multiple_files(paths) + expected = pa.concat_tables(test_data) + + assert result.equals(expected) + + # Read with provided metadata + # TODO(dataset) specifying metadata not yet supported + metadata = pq.read_metadata(paths[0]) + if use_legacy_dataset: + result2 = read_multiple_files(paths, metadata=metadata) + assert result2.equals(expected) + + result3 = pq.ParquetDataset(dirpath, schema=metadata.schema).read() + assert result3.equals(expected) + else: + with pytest.raises(ValueError, match="no longer supported"): + pq.read_table(paths, metadata=metadata, use_legacy_dataset=False) + + # Read column subset + to_read = [0, 2, 6, result.num_columns - 1] + + col_names = [result.field(i).name for i in to_read] + out = pq.read_table( + dirpath, columns=col_names, use_legacy_dataset=use_legacy_dataset + ) + expected = pa.Table.from_arrays([result.column(i) for i in to_read], + names=col_names, + metadata=result.schema.metadata) + assert out.equals(expected) + + # Read with multiple threads + pq.read_table( + dirpath, use_threads=True, use_legacy_dataset=use_legacy_dataset + ) + + # Test failure modes with non-uniform metadata + bad_apple = _test_dataframe(size, seed=i).iloc[:, :4] + bad_apple_path = tempdir / '{}.parquet'.format(guid()) + + t = pa.Table.from_pandas(bad_apple) + _write_table(t, bad_apple_path) + + if not use_legacy_dataset: + # TODO(dataset) Dataset API skips bad files + return + + bad_meta = pq.read_metadata(bad_apple_path) + + with pytest.raises(ValueError): + read_multiple_files(paths + [bad_apple_path]) + + with pytest.raises(ValueError): + read_multiple_files(paths, metadata=bad_meta) + + mixed_paths = [bad_apple_path, paths[0]] + + with pytest.raises(ValueError): + read_multiple_files(mixed_paths, schema=bad_meta.schema) + + with pytest.raises(ValueError): + read_multiple_files(mixed_paths) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_dataset_read_pandas(tempdir, use_legacy_dataset): + nfiles = 5 + size = 5 + + dirpath = tempdir / guid() + dirpath.mkdir() + + test_data = [] + frames = [] + paths = [] + for i in range(nfiles): + df = _test_dataframe(size, seed=i) + df.index = np.arange(i * size, (i + 1) * size) + df.index.name = 'index' + + path = dirpath / '{}.parquet'.format(i) + + table = pa.Table.from_pandas(df) + _write_table(table, path) + test_data.append(table) + frames.append(df) + paths.append(path) + + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + columns = ['uint8', 'strings'] + result = dataset.read_pandas(columns=columns).to_pandas() + expected = pd.concat([x[columns] for x in frames]) + + tm.assert_frame_equal(result, expected) + + # also be able to pass the columns as a set (ARROW-12314) + result = dataset.read_pandas(columns=set(columns)).to_pandas() + assert result.shape == expected.shape + # column order can be different because of using a set + tm.assert_frame_equal(result.reindex(columns=expected.columns), expected) + + +@pytest.mark.filterwarnings("ignore:'ParquetDataset:DeprecationWarning") +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_dataset_memory_map(tempdir, use_legacy_dataset): + # ARROW-2627: Check that we can use ParquetDataset with memory-mapping + dirpath = tempdir / guid() + dirpath.mkdir() + + df = _test_dataframe(10, seed=0) + path = dirpath / '{}.parquet'.format(0) + table = pa.Table.from_pandas(df) + _write_table(table, path, version='2.6') + + dataset = pq.ParquetDataset( + dirpath, memory_map=True, use_legacy_dataset=use_legacy_dataset) + assert dataset.read().equals(table) + if use_legacy_dataset: + assert dataset.pieces[0].read().equals(table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_dataset_enable_buffered_stream(tempdir, use_legacy_dataset): + dirpath = tempdir / guid() + dirpath.mkdir() + + df = _test_dataframe(10, seed=0) + path = dirpath / '{}.parquet'.format(0) + table = pa.Table.from_pandas(df) + _write_table(table, path, version='2.6') + + with pytest.raises(ValueError): + pq.ParquetDataset( + dirpath, buffer_size=-64, + use_legacy_dataset=use_legacy_dataset) + + for buffer_size in [128, 1024]: + dataset = pq.ParquetDataset( + dirpath, buffer_size=buffer_size, + use_legacy_dataset=use_legacy_dataset) + assert dataset.read().equals(table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_dataset_enable_pre_buffer(tempdir, use_legacy_dataset): + dirpath = tempdir / guid() + dirpath.mkdir() + + df = _test_dataframe(10, seed=0) + path = dirpath / '{}.parquet'.format(0) + table = pa.Table.from_pandas(df) + _write_table(table, path, version='2.6') + + for pre_buffer in (True, False): + dataset = pq.ParquetDataset( + dirpath, pre_buffer=pre_buffer, + use_legacy_dataset=use_legacy_dataset) + assert dataset.read().equals(table) + actual = pq.read_table(dirpath, pre_buffer=pre_buffer, + use_legacy_dataset=use_legacy_dataset) + assert actual.equals(table) + + +def _make_example_multifile_dataset(base_path, nfiles=10, file_nrows=5): + test_data = [] + paths = [] + for i in range(nfiles): + df = _test_dataframe(file_nrows, seed=i) + path = base_path / '{}.parquet'.format(i) + + test_data.append(_write_table(df, path)) + paths.append(path) + return paths + + +def _assert_dataset_paths(dataset, paths, use_legacy_dataset): + if use_legacy_dataset: + assert set(map(str, paths)) == {x.path for x in dataset._pieces} + else: + paths = [str(path.as_posix()) for path in paths] + assert set(paths) == set(dataset._dataset.files) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@pytest.mark.parametrize('dir_prefix', ['_', '.']) +def test_ignore_private_directories(tempdir, dir_prefix, use_legacy_dataset): + dirpath = tempdir / guid() + dirpath.mkdir() + + paths = _make_example_multifile_dataset(dirpath, nfiles=10, + file_nrows=5) + + # private directory + (dirpath / '{}staging'.format(dir_prefix)).mkdir() + + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_ignore_hidden_files_dot(tempdir, use_legacy_dataset): + dirpath = tempdir / guid() + dirpath.mkdir() + + paths = _make_example_multifile_dataset(dirpath, nfiles=10, + file_nrows=5) + + with (dirpath / '.DS_Store').open('wb') as f: + f.write(b'gibberish') + + with (dirpath / '.private').open('wb') as f: + f.write(b'gibberish') + + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_ignore_hidden_files_underscore(tempdir, use_legacy_dataset): + dirpath = tempdir / guid() + dirpath.mkdir() + + paths = _make_example_multifile_dataset(dirpath, nfiles=10, + file_nrows=5) + + with (dirpath / '_committed_123').open('wb') as f: + f.write(b'abcd') + + with (dirpath / '_started_321').open('wb') as f: + f.write(b'abcd') + + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +@pytest.mark.parametrize('dir_prefix', ['_', '.']) +def test_ignore_no_private_directories_in_base_path( + tempdir, dir_prefix, use_legacy_dataset +): + # ARROW-8427 - don't ignore explicitly listed files if parent directory + # is a private directory + dirpath = tempdir / "{0}data".format(dir_prefix) / guid() + dirpath.mkdir(parents=True) + + paths = _make_example_multifile_dataset(dirpath, nfiles=10, + file_nrows=5) + + dataset = pq.ParquetDataset(paths, use_legacy_dataset=use_legacy_dataset) + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + # ARROW-9644 - don't ignore full directory with underscore in base path + dataset = pq.ParquetDataset(dirpath, use_legacy_dataset=use_legacy_dataset) + _assert_dataset_paths(dataset, paths, use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset_fixed +def test_ignore_custom_prefixes(tempdir, use_legacy_dataset): + # ARROW-9573 - allow override of default ignore_prefixes + part = ["xxx"] * 3 + ["yyy"] * 3 + table = pa.table([ + pa.array(range(len(part))), + pa.array(part).dictionary_encode(), + ], names=['index', '_part']) + + # TODO use_legacy_dataset ARROW-10247 + pq.write_to_dataset(table, str(tempdir), partition_cols=['_part']) + + private_duplicate = tempdir / '_private_duplicate' + private_duplicate.mkdir() + pq.write_to_dataset(table, str(private_duplicate), + partition_cols=['_part']) + + read = pq.read_table( + tempdir, use_legacy_dataset=use_legacy_dataset, + ignore_prefixes=['_private']) + + assert read.equals(table) + + +@parametrize_legacy_dataset_fixed +def test_empty_directory(tempdir, use_legacy_dataset): + # ARROW-5310 - reading empty directory + # fails with legacy implementation + empty_dir = tempdir / 'dataset' + empty_dir.mkdir() + + dataset = pq.ParquetDataset( + empty_dir, use_legacy_dataset=use_legacy_dataset) + result = dataset.read() + assert result.num_rows == 0 + assert result.num_columns == 0 + + +def _test_write_to_dataset_with_partitions(base_path, + use_legacy_dataset=True, + filesystem=None, + schema=None, + index_name=None): + import pandas as pd + import pandas.testing as tm + + import pyarrow.parquet as pq + + # ARROW-1400 + output_df = pd.DataFrame({'group1': list('aaabbbbccc'), + 'group2': list('eefeffgeee'), + 'num': list(range(10)), + 'nan': [np.nan] * 10, + 'date': np.arange('2017-01-01', '2017-01-11', + dtype='datetime64[D]')}) + cols = output_df.columns.tolist() + partition_by = ['group1', 'group2'] + output_table = pa.Table.from_pandas(output_df, schema=schema, safe=False, + preserve_index=False) + pq.write_to_dataset(output_table, base_path, partition_by, + filesystem=filesystem, + use_legacy_dataset=use_legacy_dataset) + + metadata_path = os.path.join(str(base_path), '_common_metadata') + + if filesystem is not None: + with filesystem.open(metadata_path, 'wb') as f: + pq.write_metadata(output_table.schema, f) + else: + pq.write_metadata(output_table.schema, metadata_path) + + # ARROW-2891: Ensure the output_schema is preserved when writing a + # partitioned dataset + dataset = pq.ParquetDataset(base_path, + filesystem=filesystem, + validate_schema=True, + use_legacy_dataset=use_legacy_dataset) + # ARROW-2209: Ensure the dataset schema also includes the partition columns + if use_legacy_dataset: + dataset_cols = set(dataset.schema.to_arrow_schema().names) + else: + # NB schema property is an arrow and not parquet schema + dataset_cols = set(dataset.schema.names) + + assert dataset_cols == set(output_table.schema.names) + + input_table = dataset.read() + input_df = input_table.to_pandas() + + # Read data back in and compare with original DataFrame + # Partitioned columns added to the end of the DataFrame when read + input_df_cols = input_df.columns.tolist() + assert partition_by == input_df_cols[-1 * len(partition_by):] + + input_df = input_df[cols] + # Partitioned columns become 'categorical' dtypes + for col in partition_by: + output_df[col] = output_df[col].astype('category') + tm.assert_frame_equal(output_df, input_df) + + +def _test_write_to_dataset_no_partitions(base_path, + use_legacy_dataset=True, + filesystem=None): + import pandas as pd + + import pyarrow.parquet as pq + + # ARROW-1400 + output_df = pd.DataFrame({'group1': list('aaabbbbccc'), + 'group2': list('eefeffgeee'), + 'num': list(range(10)), + 'date': np.arange('2017-01-01', '2017-01-11', + dtype='datetime64[D]')}) + cols = output_df.columns.tolist() + output_table = pa.Table.from_pandas(output_df) + + if filesystem is None: + filesystem = LocalFileSystem._get_instance() + + # Without partitions, append files to root_path + n = 5 + for i in range(n): + pq.write_to_dataset(output_table, base_path, + filesystem=filesystem) + output_files = [file for file in filesystem.ls(str(base_path)) + if file.endswith(".parquet")] + assert len(output_files) == n + + # Deduplicated incoming DataFrame should match + # original outgoing Dataframe + input_table = pq.ParquetDataset( + base_path, filesystem=filesystem, + use_legacy_dataset=use_legacy_dataset + ).read() + input_df = input_table.to_pandas() + input_df = input_df.drop_duplicates() + input_df = input_df[cols] + assert output_df.equals(input_df) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_with_partitions(tempdir, use_legacy_dataset): + _test_write_to_dataset_with_partitions(str(tempdir), use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_with_partitions_and_schema( + tempdir, use_legacy_dataset +): + schema = pa.schema([pa.field('group1', type=pa.string()), + pa.field('group2', type=pa.string()), + pa.field('num', type=pa.int64()), + pa.field('nan', type=pa.int32()), + pa.field('date', type=pa.timestamp(unit='us'))]) + _test_write_to_dataset_with_partitions( + str(tempdir), use_legacy_dataset, schema=schema) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_with_partitions_and_index_name( + tempdir, use_legacy_dataset +): + _test_write_to_dataset_with_partitions( + str(tempdir), use_legacy_dataset, index_name='index_name') + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_no_partitions(tempdir, use_legacy_dataset): + _test_write_to_dataset_no_partitions(str(tempdir), use_legacy_dataset) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_pathlib(tempdir, use_legacy_dataset): + _test_write_to_dataset_with_partitions( + tempdir / "test1", use_legacy_dataset) + _test_write_to_dataset_no_partitions( + tempdir / "test2", use_legacy_dataset) + + +@pytest.mark.pandas +@pytest.mark.s3 +@parametrize_legacy_dataset +def test_write_to_dataset_pathlib_nonlocal( + tempdir, s3_example_s3fs, use_legacy_dataset +): + # pathlib paths are only accepted for local files + fs, _ = s3_example_s3fs + + with pytest.raises(TypeError, match="path-like objects are only allowed"): + _test_write_to_dataset_with_partitions( + tempdir / "test1", use_legacy_dataset, filesystem=fs) + + with pytest.raises(TypeError, match="path-like objects are only allowed"): + _test_write_to_dataset_no_partitions( + tempdir / "test2", use_legacy_dataset, filesystem=fs) + + +@pytest.mark.pandas +@pytest.mark.s3 +@parametrize_legacy_dataset +def test_write_to_dataset_with_partitions_s3fs( + s3_example_s3fs, use_legacy_dataset +): + fs, path = s3_example_s3fs + + _test_write_to_dataset_with_partitions( + path, use_legacy_dataset, filesystem=fs) + + +@pytest.mark.pandas +@pytest.mark.s3 +@parametrize_legacy_dataset +def test_write_to_dataset_no_partitions_s3fs( + s3_example_s3fs, use_legacy_dataset +): + fs, path = s3_example_s3fs + + _test_write_to_dataset_no_partitions( + path, use_legacy_dataset, filesystem=fs) + + +@pytest.mark.filterwarnings("ignore:'ParquetDataset:DeprecationWarning") +@pytest.mark.pandas +@parametrize_legacy_dataset_not_supported +def test_write_to_dataset_with_partitions_and_custom_filenames( + tempdir, use_legacy_dataset +): + output_df = pd.DataFrame({'group1': list('aaabbbbccc'), + 'group2': list('eefeffgeee'), + 'num': list(range(10)), + 'nan': [np.nan] * 10, + 'date': np.arange('2017-01-01', '2017-01-11', + dtype='datetime64[D]')}) + partition_by = ['group1', 'group2'] + output_table = pa.Table.from_pandas(output_df) + path = str(tempdir) + + def partition_filename_callback(keys): + return "{}-{}.parquet".format(*keys) + + pq.write_to_dataset(output_table, path, + partition_by, partition_filename_callback, + use_legacy_dataset=use_legacy_dataset) + + dataset = pq.ParquetDataset(path) + + # ARROW-3538: Ensure partition filenames match the given pattern + # defined in the local function partition_filename_callback + expected_basenames = [ + 'a-e.parquet', 'a-f.parquet', + 'b-e.parquet', 'b-f.parquet', + 'b-g.parquet', 'c-e.parquet' + ] + output_basenames = [os.path.basename(p.path) for p in dataset.pieces] + + assert sorted(expected_basenames) == sorted(output_basenames) + + +@pytest.mark.dataset +@pytest.mark.pandas +def test_write_to_dataset_filesystem(tempdir): + df = pd.DataFrame({'A': [1, 2, 3]}) + table = pa.Table.from_pandas(df) + path = str(tempdir) + + pq.write_to_dataset(table, path, filesystem=fs.LocalFileSystem()) + result = pq.read_table(path) + assert result.equals(table) + + +# TODO(dataset) support pickling +def _make_dataset_for_pickling(tempdir, N=100): + path = tempdir / 'data.parquet' + fs = LocalFileSystem._get_instance() + + df = pd.DataFrame({ + 'index': np.arange(N), + 'values': np.random.randn(N) + }, columns=['index', 'values']) + table = pa.Table.from_pandas(df) + + num_groups = 3 + with pq.ParquetWriter(path, table.schema) as writer: + for i in range(num_groups): + writer.write_table(table) + + reader = pq.ParquetFile(path) + assert reader.metadata.num_row_groups == num_groups + + metadata_path = tempdir / '_metadata' + with fs.open(metadata_path, 'wb') as f: + pq.write_metadata(table.schema, f) + + dataset = pq.ParquetDataset(tempdir, filesystem=fs) + assert dataset.metadata_path == str(metadata_path) + + return dataset + + +def _assert_dataset_is_picklable(dataset, pickler): + def is_pickleable(obj): + return obj == pickler.loads(pickler.dumps(obj)) + + assert is_pickleable(dataset) + assert is_pickleable(dataset.metadata) + assert is_pickleable(dataset.metadata.schema) + assert len(dataset.metadata.schema) + for column in dataset.metadata.schema: + assert is_pickleable(column) + + for piece in dataset._pieces: + assert is_pickleable(piece) + metadata = piece.get_metadata() + assert metadata.num_row_groups + for i in range(metadata.num_row_groups): + assert is_pickleable(metadata.row_group(i)) + + +@pytest.mark.pandas +def test_builtin_pickle_dataset(tempdir, datadir): + import pickle + dataset = _make_dataset_for_pickling(tempdir) + _assert_dataset_is_picklable(dataset, pickler=pickle) + + +@pytest.mark.pandas +def test_cloudpickle_dataset(tempdir, datadir): + cp = pytest.importorskip('cloudpickle') + dataset = _make_dataset_for_pickling(tempdir) + _assert_dataset_is_picklable(dataset, pickler=cp) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_partitioned_dataset(tempdir, use_legacy_dataset): + # ARROW-3208: Segmentation fault when reading a Parquet partitioned dataset + # to a Parquet file + path = tempdir / "ARROW-3208" + df = pd.DataFrame({ + 'one': [-1, 10, 2.5, 100, 1000, 1, 29.2], + 'two': [-1, 10, 2, 100, 1000, 1, 11], + 'three': [0, 0, 0, 0, 0, 0, 0] + }) + table = pa.Table.from_pandas(df) + pq.write_to_dataset(table, root_path=str(path), + partition_cols=['one', 'two']) + table = pq.ParquetDataset( + path, use_legacy_dataset=use_legacy_dataset).read() + pq.write_table(table, path / "output.parquet") + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_dataset_read_dictionary(tempdir, use_legacy_dataset): + path = tempdir / "ARROW-3325-dataset" + t1 = pa.table([[util.rands(10) for i in range(5)] * 10], names=['f0']) + t2 = pa.table([[util.rands(10) for i in range(5)] * 10], names=['f0']) + # TODO pass use_legacy_dataset (need to fix unique names) + pq.write_to_dataset(t1, root_path=str(path)) + pq.write_to_dataset(t2, root_path=str(path)) + + result = pq.ParquetDataset( + path, read_dictionary=['f0'], + use_legacy_dataset=use_legacy_dataset).read() + + # The order of the chunks is non-deterministic + ex_chunks = [t1[0].chunk(0).dictionary_encode(), + t2[0].chunk(0).dictionary_encode()] + + assert result[0].num_chunks == 2 + c0, c1 = result[0].chunk(0), result[0].chunk(1) + if c0.equals(ex_chunks[0]): + assert c1.equals(ex_chunks[1]) + else: + assert c0.equals(ex_chunks[1]) + assert c1.equals(ex_chunks[0]) + + +@pytest.mark.dataset +def test_dataset_unsupported_keywords(): + + with pytest.raises(ValueError, match="not yet supported with the new"): + pq.ParquetDataset("", use_legacy_dataset=False, schema=pa.schema([])) + + with pytest.raises(ValueError, match="not yet supported with the new"): + pq.ParquetDataset("", use_legacy_dataset=False, metadata=pa.schema([])) + + with pytest.raises(ValueError, match="not yet supported with the new"): + pq.ParquetDataset("", use_legacy_dataset=False, validate_schema=False) + + with pytest.raises(ValueError, match="not yet supported with the new"): + pq.ParquetDataset("", use_legacy_dataset=False, split_row_groups=True) + + with pytest.raises(ValueError, match="not yet supported with the new"): + pq.ParquetDataset("", use_legacy_dataset=False, metadata_nthreads=4) + + with pytest.raises(ValueError, match="no longer supported"): + pq.read_table("", use_legacy_dataset=False, metadata=pa.schema([])) + + +@pytest.mark.dataset +def test_dataset_partitioning(tempdir): + import pyarrow.dataset as ds + + # create small dataset with directory partitioning + root_path = tempdir / "test_partitioning" + (root_path / "2012" / "10" / "01").mkdir(parents=True) + + table = pa.table({'a': [1, 2, 3]}) + pq.write_table( + table, str(root_path / "2012" / "10" / "01" / "data.parquet")) + + # This works with new dataset API + + # read_table + part = ds.partitioning(field_names=["year", "month", "day"]) + result = pq.read_table( + str(root_path), partitioning=part, use_legacy_dataset=False) + assert result.column_names == ["a", "year", "month", "day"] + + result = pq.ParquetDataset( + str(root_path), partitioning=part, use_legacy_dataset=False).read() + assert result.column_names == ["a", "year", "month", "day"] + + # This raises an error for legacy dataset + with pytest.raises(ValueError): + pq.read_table( + str(root_path), partitioning=part, use_legacy_dataset=True) + + with pytest.raises(ValueError): + pq.ParquetDataset( + str(root_path), partitioning=part, use_legacy_dataset=True) + + +@pytest.mark.dataset +def test_parquet_dataset_new_filesystem(tempdir): + # Ensure we can pass new FileSystem object to ParquetDataset + # (use new implementation automatically without specifying + # use_legacy_dataset=False) + table = pa.table({'a': [1, 2, 3]}) + pq.write_table(table, tempdir / 'data.parquet') + # don't use simple LocalFileSystem (as that gets mapped to legacy one) + filesystem = fs.SubTreeFileSystem(str(tempdir), fs.LocalFileSystem()) + dataset = pq.ParquetDataset('.', filesystem=filesystem) + result = dataset.read() + assert result.equals(table) + + +@pytest.mark.filterwarnings("ignore:'ParquetDataset:DeprecationWarning") +def test_parquet_dataset_partitions_piece_path_with_fsspec(tempdir): + # ARROW-10462 ensure that on Windows we properly use posix-style paths + # as used by fsspec + fsspec = pytest.importorskip("fsspec") + filesystem = fsspec.filesystem('file') + table = pa.table({'a': [1, 2, 3]}) + pq.write_table(table, tempdir / 'data.parquet') + + # pass a posix-style path (using "/" also on Windows) + path = str(tempdir).replace("\\", "/") + dataset = pq.ParquetDataset(path, filesystem=filesystem) + # ensure the piece path is also posix-style + expected = path + "/data.parquet" + assert dataset.pieces[0].path == expected + + +@pytest.mark.dataset +def test_parquet_dataset_deprecated_properties(tempdir): + table = pa.table({'a': [1, 2, 3]}) + path = tempdir / 'data.parquet' + pq.write_table(table, path) + dataset = pq.ParquetDataset(path) + + with pytest.warns(DeprecationWarning, match="'ParquetDataset.pieces"): + dataset.pieces + + with pytest.warns(DeprecationWarning, match="'ParquetDataset.partitions"): + dataset.partitions + + with pytest.warns(DeprecationWarning, match="'ParquetDataset.memory_map"): + dataset.memory_map + + with pytest.warns(DeprecationWarning, match="'ParquetDataset.read_dictio"): + dataset.read_dictionary + + with pytest.warns(DeprecationWarning, match="'ParquetDataset.buffer_size"): + dataset.buffer_size + + with pytest.warns(DeprecationWarning, match="'ParquetDataset.fs"): + dataset.fs + + dataset2 = pq.ParquetDataset(path, use_legacy_dataset=False) + + with pytest.warns(DeprecationWarning, match="'ParquetDataset.pieces"): + dataset2.pieces diff --git a/src/arrow/python/pyarrow/tests/parquet/test_datetime.py b/src/arrow/python/pyarrow/tests/parquet/test_datetime.py new file mode 100644 index 000000000..f39f1c762 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/test_datetime.py @@ -0,0 +1,440 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import io + +import numpy as np +import pytest + +import pyarrow as pa +from pyarrow.tests.parquet.common import ( + _check_roundtrip, parametrize_legacy_dataset) + +try: + import pyarrow.parquet as pq + from pyarrow.tests.parquet.common import _read_table, _write_table +except ImportError: + pq = None + + +try: + import pandas as pd + import pandas.testing as tm + + from pyarrow.tests.parquet.common import _roundtrip_pandas_dataframe +except ImportError: + pd = tm = None + + +pytestmark = pytest.mark.parquet + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_datetime_tz(use_legacy_dataset): + s = pd.Series([datetime.datetime(2017, 9, 6)]) + s = s.dt.tz_localize('utc') + + s.index = s + + # Both a column and an index to hit both use cases + df = pd.DataFrame({'tz_aware': s, + 'tz_eastern': s.dt.tz_convert('US/Eastern')}, + index=s) + + f = io.BytesIO() + + arrow_table = pa.Table.from_pandas(df) + + _write_table(arrow_table, f, coerce_timestamps='ms') + f.seek(0) + + table_read = pq.read_pandas(f, use_legacy_dataset=use_legacy_dataset) + + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_datetime_timezone_tzinfo(use_legacy_dataset): + value = datetime.datetime(2018, 1, 1, 1, 23, 45, + tzinfo=datetime.timezone.utc) + df = pd.DataFrame({'foo': [value]}) + + _roundtrip_pandas_dataframe( + df, write_kwargs={}, use_legacy_dataset=use_legacy_dataset) + + +@pytest.mark.pandas +def test_coerce_timestamps(tempdir): + from collections import OrderedDict + + # ARROW-622 + arrays = OrderedDict() + fields = [pa.field('datetime64', + pa.list_(pa.timestamp('ms')))] + arrays['datetime64'] = [ + np.array(['2007-07-13T01:23:34.123456789', + None, + '2010-08-13T05:46:57.437699912'], + dtype='datetime64[ms]'), + None, + None, + np.array(['2007-07-13T02', + None, + '2010-08-13T05:46:57.437699912'], + dtype='datetime64[ms]'), + ] + + df = pd.DataFrame(arrays) + schema = pa.schema(fields) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df, schema=schema) + + _write_table(arrow_table, filename, version='2.6', coerce_timestamps='us') + table_read = _read_table(filename) + df_read = table_read.to_pandas() + + df_expected = df.copy() + for i, x in enumerate(df_expected['datetime64']): + if isinstance(x, np.ndarray): + df_expected['datetime64'][i] = x.astype('M8[us]') + + tm.assert_frame_equal(df_expected, df_read) + + with pytest.raises(ValueError): + _write_table(arrow_table, filename, version='2.6', + coerce_timestamps='unknown') + + +@pytest.mark.pandas +def test_coerce_timestamps_truncated(tempdir): + """ + ARROW-2555: Test that we can truncate timestamps when coercing if + explicitly allowed. + """ + dt_us = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1, + second=1, microsecond=1) + dt_ms = datetime.datetime(year=2017, month=1, day=1, hour=1, minute=1, + second=1) + + fields_us = [pa.field('datetime64', pa.timestamp('us'))] + arrays_us = {'datetime64': [dt_us, dt_ms]} + + df_us = pd.DataFrame(arrays_us) + schema_us = pa.schema(fields_us) + + filename = tempdir / 'pandas_truncated.parquet' + table_us = pa.Table.from_pandas(df_us, schema=schema_us) + + _write_table(table_us, filename, version='2.6', coerce_timestamps='ms', + allow_truncated_timestamps=True) + table_ms = _read_table(filename) + df_ms = table_ms.to_pandas() + + arrays_expected = {'datetime64': [dt_ms, dt_ms]} + df_expected = pd.DataFrame(arrays_expected) + tm.assert_frame_equal(df_expected, df_ms) + + +@pytest.mark.pandas +def test_date_time_types(tempdir): + t1 = pa.date32() + data1 = np.array([17259, 17260, 17261], dtype='int32') + a1 = pa.array(data1, type=t1) + + t2 = pa.date64() + data2 = data1.astype('int64') * 86400000 + a2 = pa.array(data2, type=t2) + + t3 = pa.timestamp('us') + start = pd.Timestamp('2001-01-01').value / 1000 + data3 = np.array([start, start + 1, start + 2], dtype='int64') + a3 = pa.array(data3, type=t3) + + t4 = pa.time32('ms') + data4 = np.arange(3, dtype='i4') + a4 = pa.array(data4, type=t4) + + t5 = pa.time64('us') + a5 = pa.array(data4.astype('int64'), type=t5) + + t6 = pa.time32('s') + a6 = pa.array(data4, type=t6) + + ex_t6 = pa.time32('ms') + ex_a6 = pa.array(data4 * 1000, type=ex_t6) + + t7 = pa.timestamp('ns') + start = pd.Timestamp('2001-01-01').value + data7 = np.array([start, start + 1000, start + 2000], + dtype='int64') + a7 = pa.array(data7, type=t7) + + table = pa.Table.from_arrays([a1, a2, a3, a4, a5, a6, a7], + ['date32', 'date64', 'timestamp[us]', + 'time32[s]', 'time64[us]', + 'time32_from64[s]', + 'timestamp[ns]']) + + # date64 as date32 + # time32[s] to time32[ms] + expected = pa.Table.from_arrays([a1, a1, a3, a4, a5, ex_a6, a7], + ['date32', 'date64', 'timestamp[us]', + 'time32[s]', 'time64[us]', + 'time32_from64[s]', + 'timestamp[ns]']) + + _check_roundtrip(table, expected=expected, version='2.6') + + t0 = pa.timestamp('ms') + data0 = np.arange(4, dtype='int64') + a0 = pa.array(data0, type=t0) + + t1 = pa.timestamp('us') + data1 = np.arange(4, dtype='int64') + a1 = pa.array(data1, type=t1) + + t2 = pa.timestamp('ns') + data2 = np.arange(4, dtype='int64') + a2 = pa.array(data2, type=t2) + + table = pa.Table.from_arrays([a0, a1, a2], + ['ts[ms]', 'ts[us]', 'ts[ns]']) + expected = pa.Table.from_arrays([a0, a1, a2], + ['ts[ms]', 'ts[us]', 'ts[ns]']) + + # int64 for all timestamps supported by default + filename = tempdir / 'int64_timestamps.parquet' + _write_table(table, filename, version='2.6') + parquet_schema = pq.ParquetFile(filename).schema + for i in range(3): + assert parquet_schema.column(i).physical_type == 'INT64' + read_table = _read_table(filename) + assert read_table.equals(expected) + + t0_ns = pa.timestamp('ns') + data0_ns = np.array(data0 * 1000000, dtype='int64') + a0_ns = pa.array(data0_ns, type=t0_ns) + + t1_ns = pa.timestamp('ns') + data1_ns = np.array(data1 * 1000, dtype='int64') + a1_ns = pa.array(data1_ns, type=t1_ns) + + expected = pa.Table.from_arrays([a0_ns, a1_ns, a2], + ['ts[ms]', 'ts[us]', 'ts[ns]']) + + # int96 nanosecond timestamps produced upon request + filename = tempdir / 'explicit_int96_timestamps.parquet' + _write_table(table, filename, version='2.6', + use_deprecated_int96_timestamps=True) + parquet_schema = pq.ParquetFile(filename).schema + for i in range(3): + assert parquet_schema.column(i).physical_type == 'INT96' + read_table = _read_table(filename) + assert read_table.equals(expected) + + # int96 nanosecond timestamps implied by flavor 'spark' + filename = tempdir / 'spark_int96_timestamps.parquet' + _write_table(table, filename, version='2.6', + flavor='spark') + parquet_schema = pq.ParquetFile(filename).schema + for i in range(3): + assert parquet_schema.column(i).physical_type == 'INT96' + read_table = _read_table(filename) + assert read_table.equals(expected) + + +@pytest.mark.pandas +@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) +def test_coerce_int96_timestamp_unit(unit): + i_s = pd.Timestamp('2010-01-01').value / 1000000000 # := 1262304000 + + d_s = np.arange(i_s, i_s + 10, 1, dtype='int64') + d_ms = d_s * 1000 + d_us = d_ms * 1000 + d_ns = d_us * 1000 + + a_s = pa.array(d_s, type=pa.timestamp('s')) + a_ms = pa.array(d_ms, type=pa.timestamp('ms')) + a_us = pa.array(d_us, type=pa.timestamp('us')) + a_ns = pa.array(d_ns, type=pa.timestamp('ns')) + + arrays = {"s": a_s, "ms": a_ms, "us": a_us, "ns": a_ns} + names = ['ts_s', 'ts_ms', 'ts_us', 'ts_ns'] + table = pa.Table.from_arrays([a_s, a_ms, a_us, a_ns], names) + + # For either Parquet version, coercing to nanoseconds is allowed + # if Int96 storage is used + expected = pa.Table.from_arrays([arrays.get(unit)]*4, names) + read_table_kwargs = {"coerce_int96_timestamp_unit": unit} + _check_roundtrip(table, expected, + read_table_kwargs=read_table_kwargs, + use_deprecated_int96_timestamps=True) + _check_roundtrip(table, expected, version='2.6', + read_table_kwargs=read_table_kwargs, + use_deprecated_int96_timestamps=True) + + +@pytest.mark.pandas +@pytest.mark.parametrize('pq_reader_method', ['ParquetFile', 'read_table']) +def test_coerce_int96_timestamp_overflow(pq_reader_method, tempdir): + + def get_table(pq_reader_method, filename, **kwargs): + if pq_reader_method == "ParquetFile": + return pq.ParquetFile(filename, **kwargs).read() + elif pq_reader_method == "read_table": + return pq.read_table(filename, **kwargs) + + # Recreating the initial JIRA issue referrenced in ARROW-12096 + oob_dts = [ + datetime.datetime(1000, 1, 1), + datetime.datetime(2000, 1, 1), + datetime.datetime(3000, 1, 1) + ] + df = pd.DataFrame({"a": oob_dts}) + table = pa.table(df) + + filename = tempdir / "test_round_trip_overflow.parquet" + pq.write_table(table, filename, use_deprecated_int96_timestamps=True, + version="1.0") + + # with the default resolution of ns, we get wrong values for INT96 + # that are out of bounds for nanosecond range + tab_error = get_table(pq_reader_method, filename) + assert tab_error["a"].to_pylist() != oob_dts + + # avoid this overflow by specifying the resolution to use for INT96 values + tab_correct = get_table( + pq_reader_method, filename, coerce_int96_timestamp_unit="s" + ) + df_correct = tab_correct.to_pandas(timestamp_as_object=True) + tm.assert_frame_equal(df, df_correct) + + +def test_timestamp_restore_timezone(): + # ARROW-5888, restore timezone from serialized metadata + ty = pa.timestamp('ms', tz='America/New_York') + arr = pa.array([1, 2, 3], type=ty) + t = pa.table([arr], names=['f0']) + _check_roundtrip(t) + + +def test_timestamp_restore_timezone_nanosecond(): + # ARROW-9634, also restore timezone for nanosecond data that get stored + # as microseconds in the parquet file + ty = pa.timestamp('ns', tz='America/New_York') + arr = pa.array([1000, 2000, 3000], type=ty) + table = pa.table([arr], names=['f0']) + ty_us = pa.timestamp('us', tz='America/New_York') + expected = pa.table([arr.cast(ty_us)], names=['f0']) + _check_roundtrip(table, expected=expected) + + +@pytest.mark.pandas +def test_list_of_datetime_time_roundtrip(): + # ARROW-4135 + times = pd.to_datetime(['09:00', '09:30', '10:00', '10:30', '11:00', + '11:30', '12:00']) + df = pd.DataFrame({'time': [times.time]}) + _roundtrip_pandas_dataframe(df, write_kwargs={}) + + +@pytest.mark.pandas +def test_parquet_version_timestamp_differences(): + i_s = pd.Timestamp('2010-01-01').value / 1000000000 # := 1262304000 + + d_s = np.arange(i_s, i_s + 10, 1, dtype='int64') + d_ms = d_s * 1000 + d_us = d_ms * 1000 + d_ns = d_us * 1000 + + a_s = pa.array(d_s, type=pa.timestamp('s')) + a_ms = pa.array(d_ms, type=pa.timestamp('ms')) + a_us = pa.array(d_us, type=pa.timestamp('us')) + a_ns = pa.array(d_ns, type=pa.timestamp('ns')) + + names = ['ts:s', 'ts:ms', 'ts:us', 'ts:ns'] + table = pa.Table.from_arrays([a_s, a_ms, a_us, a_ns], names) + + # Using Parquet version 1.0, seconds should be coerced to milliseconds + # and nanoseconds should be coerced to microseconds by default + expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_us], names) + _check_roundtrip(table, expected) + + # Using Parquet version 2.0, seconds should be coerced to milliseconds + # and nanoseconds should be retained by default + expected = pa.Table.from_arrays([a_ms, a_ms, a_us, a_ns], names) + _check_roundtrip(table, expected, version='2.6') + + # Using Parquet version 1.0, coercing to milliseconds or microseconds + # is allowed + expected = pa.Table.from_arrays([a_ms, a_ms, a_ms, a_ms], names) + _check_roundtrip(table, expected, coerce_timestamps='ms') + + # Using Parquet version 2.0, coercing to milliseconds or microseconds + # is allowed + expected = pa.Table.from_arrays([a_us, a_us, a_us, a_us], names) + _check_roundtrip(table, expected, version='2.6', coerce_timestamps='us') + + # TODO: after pyarrow allows coerce_timestamps='ns', tests like the + # following should pass ... + + # Using Parquet version 1.0, coercing to nanoseconds is not allowed + # expected = None + # with pytest.raises(NotImplementedError): + # _roundtrip_table(table, coerce_timestamps='ns') + + # Using Parquet version 2.0, coercing to nanoseconds is allowed + # expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names) + # _check_roundtrip(table, expected, version='2.6', coerce_timestamps='ns') + + # For either Parquet version, coercing to nanoseconds is allowed + # if Int96 storage is used + expected = pa.Table.from_arrays([a_ns, a_ns, a_ns, a_ns], names) + _check_roundtrip(table, expected, + use_deprecated_int96_timestamps=True) + _check_roundtrip(table, expected, version='2.6', + use_deprecated_int96_timestamps=True) + + +@pytest.mark.pandas +def test_noncoerced_nanoseconds_written_without_exception(tempdir): + # ARROW-1957: the Parquet version 2.0 writer preserves Arrow + # nanosecond timestamps by default + n = 9 + df = pd.DataFrame({'x': range(n)}, + index=pd.date_range('2017-01-01', freq='1n', periods=n)) + tb = pa.Table.from_pandas(df) + + filename = tempdir / 'written.parquet' + try: + pq.write_table(tb, filename, version='2.6') + except Exception: + pass + assert filename.exists() + + recovered_table = pq.read_table(filename) + assert tb.equals(recovered_table) + + # Loss of data through coercion (without explicit override) still an error + filename = tempdir / 'not_written.parquet' + with pytest.raises(ValueError): + pq.write_table(tb, filename, coerce_timestamps='ms', version='2.6') diff --git a/src/arrow/python/pyarrow/tests/parquet/test_metadata.py b/src/arrow/python/pyarrow/tests/parquet/test_metadata.py new file mode 100644 index 000000000..9fa2f6394 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/test_metadata.py @@ -0,0 +1,524 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +from collections import OrderedDict + +import numpy as np +import pytest + +import pyarrow as pa +from pyarrow.tests.parquet.common import _check_roundtrip, make_sample_file + +try: + import pyarrow.parquet as pq + from pyarrow.tests.parquet.common import _write_table +except ImportError: + pq = None + + +try: + import pandas as pd + import pandas.testing as tm + + from pyarrow.tests.parquet.common import alltypes_sample +except ImportError: + pd = tm = None + + +pytestmark = pytest.mark.parquet + + +@pytest.mark.pandas +def test_parquet_metadata_api(): + df = alltypes_sample(size=10000) + df = df.reindex(columns=sorted(df.columns)) + df.index = np.random.randint(0, 1000000, size=len(df)) + + fileh = make_sample_file(df) + ncols = len(df.columns) + + # Series of sniff tests + meta = fileh.metadata + repr(meta) + assert meta.num_rows == len(df) + assert meta.num_columns == ncols + 1 # +1 for index + assert meta.num_row_groups == 1 + assert meta.format_version == '2.6' + assert 'parquet-cpp' in meta.created_by + assert isinstance(meta.serialized_size, int) + assert isinstance(meta.metadata, dict) + + # Schema + schema = fileh.schema + assert meta.schema is schema + assert len(schema) == ncols + 1 # +1 for index + repr(schema) + + col = schema[0] + repr(col) + assert col.name == df.columns[0] + assert col.max_definition_level == 1 + assert col.max_repetition_level == 0 + assert col.max_repetition_level == 0 + + assert col.physical_type == 'BOOLEAN' + assert col.converted_type == 'NONE' + + with pytest.raises(IndexError): + schema[ncols + 1] # +1 for index + + with pytest.raises(IndexError): + schema[-1] + + # Row group + for rg in range(meta.num_row_groups): + rg_meta = meta.row_group(rg) + assert isinstance(rg_meta, pq.RowGroupMetaData) + repr(rg_meta) + + for col in range(rg_meta.num_columns): + col_meta = rg_meta.column(col) + assert isinstance(col_meta, pq.ColumnChunkMetaData) + repr(col_meta) + + with pytest.raises(IndexError): + meta.row_group(-1) + + with pytest.raises(IndexError): + meta.row_group(meta.num_row_groups + 1) + + rg_meta = meta.row_group(0) + assert rg_meta.num_rows == len(df) + assert rg_meta.num_columns == ncols + 1 # +1 for index + assert rg_meta.total_byte_size > 0 + + with pytest.raises(IndexError): + col_meta = rg_meta.column(-1) + + with pytest.raises(IndexError): + col_meta = rg_meta.column(ncols + 2) + + col_meta = rg_meta.column(0) + assert col_meta.file_offset > 0 + assert col_meta.file_path == '' # created from BytesIO + assert col_meta.physical_type == 'BOOLEAN' + assert col_meta.num_values == 10000 + assert col_meta.path_in_schema == 'bool' + assert col_meta.is_stats_set is True + assert isinstance(col_meta.statistics, pq.Statistics) + assert col_meta.compression == 'SNAPPY' + assert col_meta.encodings == ('PLAIN', 'RLE') + assert col_meta.has_dictionary_page is False + assert col_meta.dictionary_page_offset is None + assert col_meta.data_page_offset > 0 + assert col_meta.total_compressed_size > 0 + assert col_meta.total_uncompressed_size > 0 + with pytest.raises(NotImplementedError): + col_meta.has_index_page + with pytest.raises(NotImplementedError): + col_meta.index_page_offset + + +def test_parquet_metadata_lifetime(tempdir): + # ARROW-6642 - ensure that chained access keeps parent objects alive + table = pa.table({'a': [1, 2, 3]}) + pq.write_table(table, tempdir / 'test_metadata_segfault.parquet') + parquet_file = pq.ParquetFile(tempdir / 'test_metadata_segfault.parquet') + parquet_file.metadata.row_group(0).column(0).statistics + + +@pytest.mark.pandas +@pytest.mark.parametrize( + ( + 'data', + 'type', + 'physical_type', + 'min_value', + 'max_value', + 'null_count', + 'num_values', + 'distinct_count' + ), + [ + ([1, 2, 2, None, 4], pa.uint8(), 'INT32', 1, 4, 1, 4, 0), + ([1, 2, 2, None, 4], pa.uint16(), 'INT32', 1, 4, 1, 4, 0), + ([1, 2, 2, None, 4], pa.uint32(), 'INT32', 1, 4, 1, 4, 0), + ([1, 2, 2, None, 4], pa.uint64(), 'INT64', 1, 4, 1, 4, 0), + ([-1, 2, 2, None, 4], pa.int8(), 'INT32', -1, 4, 1, 4, 0), + ([-1, 2, 2, None, 4], pa.int16(), 'INT32', -1, 4, 1, 4, 0), + ([-1, 2, 2, None, 4], pa.int32(), 'INT32', -1, 4, 1, 4, 0), + ([-1, 2, 2, None, 4], pa.int64(), 'INT64', -1, 4, 1, 4, 0), + ( + [-1.1, 2.2, 2.3, None, 4.4], pa.float32(), + 'FLOAT', -1.1, 4.4, 1, 4, 0 + ), + ( + [-1.1, 2.2, 2.3, None, 4.4], pa.float64(), + 'DOUBLE', -1.1, 4.4, 1, 4, 0 + ), + ( + ['', 'b', chr(1000), None, 'aaa'], pa.binary(), + 'BYTE_ARRAY', b'', chr(1000).encode('utf-8'), 1, 4, 0 + ), + ( + [True, False, False, True, True], pa.bool_(), + 'BOOLEAN', False, True, 0, 5, 0 + ), + ( + [b'\x00', b'b', b'12', None, b'aaa'], pa.binary(), + 'BYTE_ARRAY', b'\x00', b'b', 1, 4, 0 + ), + ] +) +def test_parquet_column_statistics_api(data, type, physical_type, min_value, + max_value, null_count, num_values, + distinct_count): + df = pd.DataFrame({'data': data}) + schema = pa.schema([pa.field('data', type)]) + table = pa.Table.from_pandas(df, schema=schema, safe=False) + fileh = make_sample_file(table) + + meta = fileh.metadata + + rg_meta = meta.row_group(0) + col_meta = rg_meta.column(0) + + stat = col_meta.statistics + assert stat.has_min_max + assert _close(type, stat.min, min_value) + assert _close(type, stat.max, max_value) + assert stat.null_count == null_count + assert stat.num_values == num_values + # TODO(kszucs) until parquet-cpp API doesn't expose HasDistinctCount + # method, missing distinct_count is represented as zero instead of None + assert stat.distinct_count == distinct_count + assert stat.physical_type == physical_type + + +def _close(type, left, right): + if type == pa.float32(): + return abs(left - right) < 1E-7 + elif type == pa.float64(): + return abs(left - right) < 1E-13 + else: + return left == right + + +# ARROW-6339 +@pytest.mark.pandas +def test_parquet_raise_on_unset_statistics(): + df = pd.DataFrame({"t": pd.Series([pd.NaT], dtype="datetime64[ns]")}) + meta = make_sample_file(pa.Table.from_pandas(df)).metadata + + assert not meta.row_group(0).column(0).statistics.has_min_max + assert meta.row_group(0).column(0).statistics.max is None + + +def test_statistics_convert_logical_types(tempdir): + # ARROW-5166, ARROW-4139 + + # (min, max, type) + cases = [(10, 11164359321221007157, pa.uint64()), + (10, 4294967295, pa.uint32()), + ("ähnlich", "öffentlich", pa.utf8()), + (datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000), + pa.time32('ms')), + (datetime.time(10, 30, 0, 1000), datetime.time(15, 30, 0, 1000), + pa.time64('us')), + (datetime.datetime(2019, 6, 24, 0, 0, 0, 1000), + datetime.datetime(2019, 6, 25, 0, 0, 0, 1000), + pa.timestamp('ms')), + (datetime.datetime(2019, 6, 24, 0, 0, 0, 1000), + datetime.datetime(2019, 6, 25, 0, 0, 0, 1000), + pa.timestamp('us'))] + + for i, (min_val, max_val, typ) in enumerate(cases): + t = pa.Table.from_arrays([pa.array([min_val, max_val], type=typ)], + ['col']) + path = str(tempdir / ('example{}.parquet'.format(i))) + pq.write_table(t, path, version='2.6') + pf = pq.ParquetFile(path) + stats = pf.metadata.row_group(0).column(0).statistics + assert stats.min == min_val + assert stats.max == max_val + + +def test_parquet_write_disable_statistics(tempdir): + table = pa.Table.from_pydict( + OrderedDict([ + ('a', pa.array([1, 2, 3])), + ('b', pa.array(['a', 'b', 'c'])) + ]) + ) + _write_table(table, tempdir / 'data.parquet') + meta = pq.read_metadata(tempdir / 'data.parquet') + for col in [0, 1]: + cc = meta.row_group(0).column(col) + assert cc.is_stats_set is True + assert cc.statistics is not None + + _write_table(table, tempdir / 'data2.parquet', write_statistics=False) + meta = pq.read_metadata(tempdir / 'data2.parquet') + for col in [0, 1]: + cc = meta.row_group(0).column(col) + assert cc.is_stats_set is False + assert cc.statistics is None + + _write_table(table, tempdir / 'data3.parquet', write_statistics=['a']) + meta = pq.read_metadata(tempdir / 'data3.parquet') + cc_a = meta.row_group(0).column(0) + cc_b = meta.row_group(0).column(1) + assert cc_a.is_stats_set is True + assert cc_b.is_stats_set is False + assert cc_a.statistics is not None + assert cc_b.statistics is None + + +def test_field_id_metadata(): + # ARROW-7080 + field_id = b'PARQUET:field_id' + inner = pa.field('inner', pa.int32(), metadata={field_id: b'100'}) + middle = pa.field('middle', pa.struct( + [inner]), metadata={field_id: b'101'}) + fields = [ + pa.field('basic', pa.int32(), metadata={ + b'other': b'abc', field_id: b'1'}), + pa.field( + 'list', + pa.list_(pa.field('list-inner', pa.int32(), + metadata={field_id: b'10'})), + metadata={field_id: b'11'}), + pa.field('struct', pa.struct([middle]), metadata={field_id: b'102'}), + pa.field('no-metadata', pa.int32()), + pa.field('non-integral-field-id', pa.int32(), + metadata={field_id: b'xyz'}), + pa.field('negative-field-id', pa.int32(), + metadata={field_id: b'-1000'}) + ] + arrs = [[] for _ in fields] + table = pa.table(arrs, schema=pa.schema(fields)) + + bio = pa.BufferOutputStream() + pq.write_table(table, bio) + contents = bio.getvalue() + + pf = pq.ParquetFile(pa.BufferReader(contents)) + schema = pf.schema_arrow + + assert schema[0].metadata[field_id] == b'1' + assert schema[0].metadata[b'other'] == b'abc' + + list_field = schema[1] + assert list_field.metadata[field_id] == b'11' + + list_item_field = list_field.type.value_field + assert list_item_field.metadata[field_id] == b'10' + + struct_field = schema[2] + assert struct_field.metadata[field_id] == b'102' + + struct_middle_field = struct_field.type[0] + assert struct_middle_field.metadata[field_id] == b'101' + + struct_inner_field = struct_middle_field.type[0] + assert struct_inner_field.metadata[field_id] == b'100' + + assert schema[3].metadata is None + # Invalid input is passed through (ok) but does not + # have field_id in parquet (not tested) + assert schema[4].metadata[field_id] == b'xyz' + assert schema[5].metadata[field_id] == b'-1000' + + +@pytest.mark.pandas +def test_multi_dataset_metadata(tempdir): + filenames = ["ARROW-1983-dataset.0", "ARROW-1983-dataset.1"] + metapath = str(tempdir / "_metadata") + + # create a test dataset + df = pd.DataFrame({ + 'one': [1, 2, 3], + 'two': [-1, -2, -3], + 'three': [[1, 2], [2, 3], [3, 4]], + }) + table = pa.Table.from_pandas(df) + + # write dataset twice and collect/merge metadata + _meta = None + for filename in filenames: + meta = [] + pq.write_table(table, str(tempdir / filename), + metadata_collector=meta) + meta[0].set_file_path(filename) + if _meta is None: + _meta = meta[0] + else: + _meta.append_row_groups(meta[0]) + + # Write merged metadata-only file + with open(metapath, "wb") as f: + _meta.write_metadata_file(f) + + # Read back the metadata + meta = pq.read_metadata(metapath) + md = meta.to_dict() + _md = _meta.to_dict() + for key in _md: + if key != 'serialized_size': + assert _md[key] == md[key] + assert _md['num_columns'] == 3 + assert _md['num_rows'] == 6 + assert _md['num_row_groups'] == 2 + assert _md['serialized_size'] == 0 + assert md['serialized_size'] > 0 + + +def test_write_metadata(tempdir): + path = str(tempdir / "metadata") + schema = pa.schema([("a", "int64"), ("b", "float64")]) + + # write a pyarrow schema + pq.write_metadata(schema, path) + parquet_meta = pq.read_metadata(path) + schema_as_arrow = parquet_meta.schema.to_arrow_schema() + assert schema_as_arrow.equals(schema) + + # ARROW-8980: Check that the ARROW:schema metadata key was removed + if schema_as_arrow.metadata: + assert b'ARROW:schema' not in schema_as_arrow.metadata + + # pass through writer keyword arguments + for version in ["1.0", "2.0", "2.4", "2.6"]: + pq.write_metadata(schema, path, version=version) + parquet_meta = pq.read_metadata(path) + # The version is stored as a single integer in the Parquet metadata, + # so it cannot correctly express dotted format versions + expected_version = "1.0" if version == "1.0" else "2.6" + assert parquet_meta.format_version == expected_version + + # metadata_collector: list of FileMetaData objects + table = pa.table({'a': [1, 2], 'b': [.1, .2]}, schema=schema) + pq.write_table(table, tempdir / "data.parquet") + parquet_meta = pq.read_metadata(str(tempdir / "data.parquet")) + pq.write_metadata( + schema, path, metadata_collector=[parquet_meta, parquet_meta] + ) + parquet_meta_mult = pq.read_metadata(path) + assert parquet_meta_mult.num_row_groups == 2 + + # append metadata with different schema raises an error + with pytest.raises(RuntimeError, match="requires equal schemas"): + pq.write_metadata( + pa.schema([("a", "int32"), ("b", "null")]), + path, metadata_collector=[parquet_meta, parquet_meta] + ) + + +def test_table_large_metadata(): + # ARROW-8694 + my_schema = pa.schema([pa.field('f0', 'double')], + metadata={'large': 'x' * 10000000}) + + table = pa.table([np.arange(10)], schema=my_schema) + _check_roundtrip(table) + + +@pytest.mark.pandas +def test_compare_schemas(): + df = alltypes_sample(size=10000) + + fileh = make_sample_file(df) + fileh2 = make_sample_file(df) + fileh3 = make_sample_file(df[df.columns[::2]]) + + # ParquetSchema + assert isinstance(fileh.schema, pq.ParquetSchema) + assert fileh.schema.equals(fileh.schema) + assert fileh.schema == fileh.schema + assert fileh.schema.equals(fileh2.schema) + assert fileh.schema == fileh2.schema + assert fileh.schema != 'arbitrary object' + assert not fileh.schema.equals(fileh3.schema) + assert fileh.schema != fileh3.schema + + # ColumnSchema + assert isinstance(fileh.schema[0], pq.ColumnSchema) + assert fileh.schema[0].equals(fileh.schema[0]) + assert fileh.schema[0] == fileh.schema[0] + assert not fileh.schema[0].equals(fileh.schema[1]) + assert fileh.schema[0] != fileh.schema[1] + assert fileh.schema[0] != 'arbitrary object' + + +@pytest.mark.pandas +def test_read_schema(tempdir): + N = 100 + df = pd.DataFrame({ + 'index': np.arange(N), + 'values': np.random.randn(N) + }, columns=['index', 'values']) + + data_path = tempdir / 'test.parquet' + + table = pa.Table.from_pandas(df) + _write_table(table, data_path) + + read1 = pq.read_schema(data_path) + read2 = pq.read_schema(data_path, memory_map=True) + assert table.schema.equals(read1) + assert table.schema.equals(read2) + + assert table.schema.metadata[b'pandas'] == read1.metadata[b'pandas'] + + +def test_parquet_metadata_empty_to_dict(tempdir): + # https://issues.apache.org/jira/browse/ARROW-10146 + table = pa.table({"a": pa.array([], type="int64")}) + pq.write_table(table, tempdir / "data.parquet") + metadata = pq.read_metadata(tempdir / "data.parquet") + # ensure this doesn't error / statistics set to None + metadata_dict = metadata.to_dict() + assert len(metadata_dict["row_groups"]) == 1 + assert len(metadata_dict["row_groups"][0]["columns"]) == 1 + assert metadata_dict["row_groups"][0]["columns"][0]["statistics"] is None + + +@pytest.mark.slow +@pytest.mark.large_memory +def test_metadata_exceeds_message_size(): + # ARROW-13655: Thrift may enable a defaut message size that limits + # the size of Parquet metadata that can be written. + NCOLS = 1000 + NREPEATS = 4000 + + table = pa.table({str(i): np.random.randn(10) for i in range(NCOLS)}) + + with pa.BufferOutputStream() as out: + pq.write_table(table, out) + buf = out.getvalue() + + original_metadata = pq.read_metadata(pa.BufferReader(buf)) + metadata = pq.read_metadata(pa.BufferReader(buf)) + for i in range(NREPEATS): + metadata.append_row_groups(original_metadata) + + with pa.BufferOutputStream() as out: + metadata.write_metadata_file(out) + buf = out.getvalue() + + metadata = pq.read_metadata(pa.BufferReader(buf)) diff --git a/src/arrow/python/pyarrow/tests/parquet/test_pandas.py b/src/arrow/python/pyarrow/tests/parquet/test_pandas.py new file mode 100644 index 000000000..5fc1b7060 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/test_pandas.py @@ -0,0 +1,687 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import io +import json + +import numpy as np +import pytest + +import pyarrow as pa +from pyarrow.fs import LocalFileSystem, SubTreeFileSystem +from pyarrow.tests.parquet.common import ( + parametrize_legacy_dataset, parametrize_legacy_dataset_not_supported) +from pyarrow.util import guid +from pyarrow.vendored.version import Version + +try: + import pyarrow.parquet as pq + from pyarrow.tests.parquet.common import (_read_table, _test_dataframe, + _write_table) +except ImportError: + pq = None + + +try: + import pandas as pd + import pandas.testing as tm + + from pyarrow.tests.parquet.common import (_roundtrip_pandas_dataframe, + alltypes_sample) +except ImportError: + pd = tm = None + + +pytestmark = pytest.mark.parquet + + +@pytest.mark.pandas +def test_pandas_parquet_custom_metadata(tempdir): + df = alltypes_sample(size=10000) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + assert b'pandas' in arrow_table.schema.metadata + + _write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms') + + metadata = pq.read_metadata(filename).metadata + assert b'pandas' in metadata + + js = json.loads(metadata[b'pandas'].decode('utf8')) + assert js['index_columns'] == [{'kind': 'range', + 'name': None, + 'start': 0, 'stop': 10000, + 'step': 1}] + + +@pytest.mark.pandas +def test_merging_parquet_tables_with_different_pandas_metadata(tempdir): + # ARROW-3728: Merging Parquet Files - Pandas Meta in Schema Mismatch + schema = pa.schema([ + pa.field('int', pa.int16()), + pa.field('float', pa.float32()), + pa.field('string', pa.string()) + ]) + df1 = pd.DataFrame({ + 'int': np.arange(3, dtype=np.uint8), + 'float': np.arange(3, dtype=np.float32), + 'string': ['ABBA', 'EDDA', 'ACDC'] + }) + df2 = pd.DataFrame({ + 'int': [4, 5], + 'float': [1.1, None], + 'string': [None, None] + }) + table1 = pa.Table.from_pandas(df1, schema=schema, preserve_index=False) + table2 = pa.Table.from_pandas(df2, schema=schema, preserve_index=False) + + assert not table1.schema.equals(table2.schema, check_metadata=True) + assert table1.schema.equals(table2.schema) + + writer = pq.ParquetWriter(tempdir / 'merged.parquet', schema=schema) + writer.write_table(table1) + writer.write_table(table2) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_column_multiindex(tempdir, use_legacy_dataset): + df = alltypes_sample(size=10) + df.columns = pd.MultiIndex.from_tuples( + list(zip(df.columns, df.columns[::-1])), + names=['level_1', 'level_2'] + ) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + assert arrow_table.schema.pandas_metadata is not None + + _write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms') + + table_read = pq.read_pandas( + filename, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_2_0_roundtrip_read_pandas_no_index_written( + tempdir, use_legacy_dataset +): + df = alltypes_sample(size=10000) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + js = arrow_table.schema.pandas_metadata + assert not js['index_columns'] + # ARROW-2170 + # While index_columns should be empty, columns needs to be filled still. + assert js['columns'] + + _write_table(arrow_table, filename, version='2.6', coerce_timestamps='ms') + table_read = pq.read_pandas( + filename, use_legacy_dataset=use_legacy_dataset) + + js = table_read.schema.pandas_metadata + assert not js['index_columns'] + + read_metadata = table_read.schema.metadata + assert arrow_table.schema.metadata == read_metadata + + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +# TODO(dataset) duplicate column selection actually gives duplicate columns now +@pytest.mark.pandas +@parametrize_legacy_dataset_not_supported +def test_pandas_column_selection(tempdir, use_legacy_dataset): + size = 10000 + np.random.seed(0) + df = pd.DataFrame({ + 'uint8': np.arange(size, dtype=np.uint8), + 'uint16': np.arange(size, dtype=np.uint16) + }) + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + _write_table(arrow_table, filename) + table_read = _read_table( + filename, columns=['uint8'], use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + + tm.assert_frame_equal(df[['uint8']], df_read) + + # ARROW-4267: Selection of duplicate columns still leads to these columns + # being read uniquely. + table_read = _read_table( + filename, columns=['uint8', 'uint8'], + use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + + tm.assert_frame_equal(df[['uint8']], df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_native_file_roundtrip(tempdir, use_legacy_dataset): + df = _test_dataframe(10000) + arrow_table = pa.Table.from_pandas(df) + imos = pa.BufferOutputStream() + _write_table(arrow_table, imos, version='2.6') + buf = imos.getvalue() + reader = pa.BufferReader(buf) + df_read = _read_table( + reader, use_legacy_dataset=use_legacy_dataset).to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_read_pandas_column_subset(tempdir, use_legacy_dataset): + df = _test_dataframe(10000) + arrow_table = pa.Table.from_pandas(df) + imos = pa.BufferOutputStream() + _write_table(arrow_table, imos, version='2.6') + buf = imos.getvalue() + reader = pa.BufferReader(buf) + df_read = pq.read_pandas( + reader, columns=['strings', 'uint8'], + use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(df[['strings', 'uint8']], df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_empty_roundtrip(tempdir, use_legacy_dataset): + df = _test_dataframe(0) + arrow_table = pa.Table.from_pandas(df) + imos = pa.BufferOutputStream() + _write_table(arrow_table, imos, version='2.6') + buf = imos.getvalue() + reader = pa.BufferReader(buf) + df_read = _read_table( + reader, use_legacy_dataset=use_legacy_dataset).to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +def test_pandas_can_write_nested_data(tempdir): + data = { + "agg_col": [ + {"page_type": 1}, + {"record_type": 1}, + {"non_consecutive_home": 0}, + ], + "uid_first": "1001" + } + df = pd.DataFrame(data=data) + arrow_table = pa.Table.from_pandas(df) + imos = pa.BufferOutputStream() + # This succeeds under V2 + _write_table(arrow_table, imos) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_pyfile_roundtrip(tempdir, use_legacy_dataset): + filename = tempdir / 'pandas_pyfile_roundtrip.parquet' + size = 5 + df = pd.DataFrame({ + 'int64': np.arange(size, dtype=np.int64), + 'float32': np.arange(size, dtype=np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0, + 'strings': ['foo', 'bar', None, 'baz', 'qux'] + }) + + arrow_table = pa.Table.from_pandas(df) + + with filename.open('wb') as f: + _write_table(arrow_table, f, version="1.0") + + data = io.BytesIO(filename.read_bytes()) + + table_read = _read_table(data, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_parquet_configuration_options(tempdir, use_legacy_dataset): + size = 10000 + np.random.seed(0) + df = pd.DataFrame({ + 'uint8': np.arange(size, dtype=np.uint8), + 'uint16': np.arange(size, dtype=np.uint16), + 'uint32': np.arange(size, dtype=np.uint32), + 'uint64': np.arange(size, dtype=np.uint64), + 'int8': np.arange(size, dtype=np.int16), + 'int16': np.arange(size, dtype=np.int16), + 'int32': np.arange(size, dtype=np.int32), + 'int64': np.arange(size, dtype=np.int64), + 'float32': np.arange(size, dtype=np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0 + }) + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + + for use_dictionary in [True, False]: + _write_table(arrow_table, filename, version='2.6', + use_dictionary=use_dictionary) + table_read = _read_table( + filename, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + for write_statistics in [True, False]: + _write_table(arrow_table, filename, version='2.6', + write_statistics=write_statistics) + table_read = _read_table(filename, + use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + for compression in ['NONE', 'SNAPPY', 'GZIP', 'LZ4', 'ZSTD']: + if (compression != 'NONE' and + not pa.lib.Codec.is_available(compression)): + continue + _write_table(arrow_table, filename, version='2.6', + compression=compression) + table_read = _read_table( + filename, use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df, df_read) + + +@pytest.mark.pandas +def test_spark_flavor_preserves_pandas_metadata(): + df = _test_dataframe(size=100) + df.index = np.arange(0, 10 * len(df), 10) + df.index.name = 'foo' + + result = _roundtrip_pandas_dataframe(df, {'version': '2.0', + 'flavor': 'spark'}) + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_index_column_name_duplicate(tempdir, use_legacy_dataset): + data = { + 'close': { + pd.Timestamp('2017-06-30 01:31:00'): 154.99958999999998, + pd.Timestamp('2017-06-30 01:32:00'): 154.99958999999998, + }, + 'time': { + pd.Timestamp('2017-06-30 01:31:00'): pd.Timestamp( + '2017-06-30 01:31:00' + ), + pd.Timestamp('2017-06-30 01:32:00'): pd.Timestamp( + '2017-06-30 01:32:00' + ), + } + } + path = str(tempdir / 'data.parquet') + dfx = pd.DataFrame(data).set_index('time', drop=False) + tdfx = pa.Table.from_pandas(dfx) + _write_table(tdfx, path) + arrow_table = _read_table(path, use_legacy_dataset=use_legacy_dataset) + result_df = arrow_table.to_pandas() + tm.assert_frame_equal(result_df, dfx) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_multiindex_duplicate_values(tempdir, use_legacy_dataset): + num_rows = 3 + numbers = list(range(num_rows)) + index = pd.MultiIndex.from_arrays( + [['foo', 'foo', 'bar'], numbers], + names=['foobar', 'some_numbers'], + ) + + df = pd.DataFrame({'numbers': numbers}, index=index) + table = pa.Table.from_pandas(df) + + filename = tempdir / 'dup_multi_index_levels.parquet' + + _write_table(table, filename) + result_table = _read_table(filename, use_legacy_dataset=use_legacy_dataset) + assert table.equals(result_table) + + result_df = result_table.to_pandas() + tm.assert_frame_equal(result_df, df) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_backwards_compatible_index_naming(datadir, use_legacy_dataset): + expected_string = b"""\ +carat cut color clarity depth table price x y z + 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 + 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 + 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 + 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 + 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 + 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48 + 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47 + 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53 + 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49 + 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39""" + expected = pd.read_csv(io.BytesIO(expected_string), sep=r'\s{2,}', + index_col=None, header=0, engine='python') + table = _read_table( + datadir / 'v0.7.1.parquet', use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_backwards_compatible_index_multi_level_named( + datadir, use_legacy_dataset +): + expected_string = b"""\ +carat cut color clarity depth table price x y z + 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 + 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 + 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 + 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 + 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 + 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48 + 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47 + 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53 + 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49 + 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39""" + expected = pd.read_csv( + io.BytesIO(expected_string), sep=r'\s{2,}', + index_col=['cut', 'color', 'clarity'], + header=0, engine='python' + ).sort_index() + + table = _read_table(datadir / 'v0.7.1.all-named-index.parquet', + use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_backwards_compatible_index_multi_level_some_named( + datadir, use_legacy_dataset +): + expected_string = b"""\ +carat cut color clarity depth table price x y z + 0.23 Ideal E SI2 61.5 55.0 326 3.95 3.98 2.43 + 0.21 Premium E SI1 59.8 61.0 326 3.89 3.84 2.31 + 0.23 Good E VS1 56.9 65.0 327 4.05 4.07 2.31 + 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63 + 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75 + 0.24 Very Good J VVS2 62.8 57.0 336 3.94 3.96 2.48 + 0.24 Very Good I VVS1 62.3 57.0 336 3.95 3.98 2.47 + 0.26 Very Good H SI1 61.9 55.0 337 4.07 4.11 2.53 + 0.22 Fair E VS2 65.1 61.0 337 3.87 3.78 2.49 + 0.23 Very Good H VS1 59.4 61.0 338 4.00 4.05 2.39""" + expected = pd.read_csv( + io.BytesIO(expected_string), + sep=r'\s{2,}', index_col=['cut', 'color', 'clarity'], + header=0, engine='python' + ).sort_index() + expected.index = expected.index.set_names(['cut', None, 'clarity']) + + table = _read_table(datadir / 'v0.7.1.some-named-index.parquet', + use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_backwards_compatible_column_metadata_handling( + datadir, use_legacy_dataset +): + expected = pd.DataFrame( + {'a': [1, 2, 3], 'b': [.1, .2, .3], + 'c': pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')}) + expected.index = pd.MultiIndex.from_arrays( + [['a', 'b', 'c'], + pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')], + names=['index', None]) + + path = datadir / 'v0.7.1.column-metadata-handling.parquet' + table = _read_table(path, use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected) + + table = _read_table( + path, columns=['a'], use_legacy_dataset=use_legacy_dataset) + result = table.to_pandas() + tm.assert_frame_equal(result, expected[['a']].reset_index(drop=True)) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_categorical_index_survives_roundtrip(use_legacy_dataset): + # ARROW-3652, addressed by ARROW-3246 + df = pd.DataFrame([['a', 'b'], ['c', 'd']], columns=['c1', 'c2']) + df['c1'] = df['c1'].astype('category') + df = df.set_index(['c1']) + + table = pa.Table.from_pandas(df) + bos = pa.BufferOutputStream() + pq.write_table(table, bos) + ref_df = pq.read_pandas( + bos.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas() + assert isinstance(ref_df.index, pd.CategoricalIndex) + assert ref_df.index.equals(df.index) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_categorical_order_survives_roundtrip(use_legacy_dataset): + # ARROW-6302 + df = pd.DataFrame({"a": pd.Categorical( + ["a", "b", "c", "a"], categories=["b", "c", "d"], ordered=True)}) + + table = pa.Table.from_pandas(df) + bos = pa.BufferOutputStream() + pq.write_table(table, bos) + + contents = bos.getvalue() + result = pq.read_pandas( + contents, use_legacy_dataset=use_legacy_dataset).to_pandas() + + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_categorical_na_type_row_groups(use_legacy_dataset): + # ARROW-5085 + df = pd.DataFrame({"col": [None] * 100, "int": [1.0] * 100}) + df_category = df.astype({"col": "category", "int": "category"}) + table = pa.Table.from_pandas(df) + table_cat = pa.Table.from_pandas(df_category) + buf = pa.BufferOutputStream() + + # it works + pq.write_table(table_cat, buf, version='2.6', chunk_size=10) + result = pq.read_table( + buf.getvalue(), use_legacy_dataset=use_legacy_dataset) + + # Result is non-categorical + assert result[0].equals(table[0]) + assert result[1].equals(table[1]) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_pandas_categorical_roundtrip(use_legacy_dataset): + # ARROW-5480, this was enabled by ARROW-3246 + + # Have one of the categories unobserved and include a null (-1) + codes = np.array([2, 0, 0, 2, 0, -1, 2], dtype='int32') + categories = ['foo', 'bar', 'baz'] + df = pd.DataFrame({'x': pd.Categorical.from_codes( + codes, categories=categories)}) + + buf = pa.BufferOutputStream() + pq.write_table(pa.table(df), buf) + + result = pq.read_table( + buf.getvalue(), use_legacy_dataset=use_legacy_dataset).to_pandas() + assert result.x.dtype == 'category' + assert (result.x.cat.categories == categories).all() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_pandas_preserve_extensiondtypes( + tempdir, use_legacy_dataset +): + # ARROW-8251 - preserve pandas extension dtypes in roundtrip + if Version(pd.__version__) < Version("1.0.0"): + pytest.skip("__arrow_array__ added to pandas in 1.0.0") + + df = pd.DataFrame({'part': 'a', "col": [1, 2, 3]}) + df['col'] = df['col'].astype("Int64") + table = pa.table(df) + + pq.write_to_dataset( + table, str(tempdir / "case1"), partition_cols=['part'], + use_legacy_dataset=use_legacy_dataset + ) + result = pq.read_table( + str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result[["col"]], df[["col"]]) + + pq.write_to_dataset( + table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset + ) + result = pq.read_table( + str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result[["col"]], df[["col"]]) + + pq.write_table(table, str(tempdir / "data.parquet")) + result = pq.read_table( + str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result[["col"]], df[["col"]]) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_write_to_dataset_pandas_preserve_index(tempdir, use_legacy_dataset): + # ARROW-8251 - preserve pandas index in roundtrip + + df = pd.DataFrame({'part': ['a', 'a', 'b'], "col": [1, 2, 3]}) + df.index = pd.Index(['a', 'b', 'c'], name="idx") + table = pa.table(df) + df_cat = df[["col", "part"]].copy() + df_cat["part"] = df_cat["part"].astype("category") + + pq.write_to_dataset( + table, str(tempdir / "case1"), partition_cols=['part'], + use_legacy_dataset=use_legacy_dataset + ) + result = pq.read_table( + str(tempdir / "case1"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result, df_cat) + + pq.write_to_dataset( + table, str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset + ) + result = pq.read_table( + str(tempdir / "case2"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result, df) + + pq.write_table(table, str(tempdir / "data.parquet")) + result = pq.read_table( + str(tempdir / "data.parquet"), use_legacy_dataset=use_legacy_dataset + ).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +@pytest.mark.parametrize('preserve_index', [True, False, None]) +def test_dataset_read_pandas_common_metadata(tempdir, preserve_index): + # ARROW-1103 + nfiles = 5 + size = 5 + + dirpath = tempdir / guid() + dirpath.mkdir() + + test_data = [] + frames = [] + paths = [] + for i in range(nfiles): + df = _test_dataframe(size, seed=i) + df.index = pd.Index(np.arange(i * size, (i + 1) * size), name='index') + + path = dirpath / '{}.parquet'.format(i) + + table = pa.Table.from_pandas(df, preserve_index=preserve_index) + + # Obliterate metadata + table = table.replace_schema_metadata(None) + assert table.schema.metadata is None + + _write_table(table, path) + test_data.append(table) + frames.append(df) + paths.append(path) + + # Write _metadata common file + table_for_metadata = pa.Table.from_pandas( + df, preserve_index=preserve_index + ) + pq.write_metadata(table_for_metadata.schema, dirpath / '_metadata') + + dataset = pq.ParquetDataset(dirpath) + columns = ['uint8', 'strings'] + result = dataset.read_pandas(columns=columns).to_pandas() + expected = pd.concat([x[columns] for x in frames]) + expected.index.name = ( + df.index.name if preserve_index is not False else None) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.pandas +def test_read_pandas_passthrough_keywords(tempdir): + # ARROW-11464 - previously not all keywords were passed through (such as + # the filesystem keyword) + df = pd.DataFrame({'a': [1, 2, 3]}) + + filename = tempdir / 'data.parquet' + _write_table(df, filename) + + result = pq.read_pandas( + 'data.parquet', + filesystem=SubTreeFileSystem(str(tempdir), LocalFileSystem()) + ) + assert result.equals(pa.table(df)) diff --git a/src/arrow/python/pyarrow/tests/parquet/test_parquet_file.py b/src/arrow/python/pyarrow/tests/parquet/test_parquet_file.py new file mode 100644 index 000000000..3a4d4aaa2 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/test_parquet_file.py @@ -0,0 +1,276 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import io +import os + +import pytest + +import pyarrow as pa + +try: + import pyarrow.parquet as pq + from pyarrow.tests.parquet.common import _write_table +except ImportError: + pq = None + +try: + import pandas as pd + import pandas.testing as tm + + from pyarrow.tests.parquet.common import alltypes_sample +except ImportError: + pd = tm = None + +pytestmark = pytest.mark.parquet + + +@pytest.mark.pandas +def test_pass_separate_metadata(): + # ARROW-471 + df = alltypes_sample(size=10000) + + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, compression='snappy', version='2.6') + + buf.seek(0) + metadata = pq.read_metadata(buf) + + buf.seek(0) + + fileh = pq.ParquetFile(buf, metadata=metadata) + + tm.assert_frame_equal(df, fileh.read().to_pandas()) + + +@pytest.mark.pandas +def test_read_single_row_group(): + # ARROW-471 + N, K = 10000, 4 + df = alltypes_sample(size=N) + + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.6') + + buf.seek(0) + + pf = pq.ParquetFile(buf) + + assert pf.num_row_groups == K + + row_groups = [pf.read_row_group(i) for i in range(K)] + result = pa.concat_tables(row_groups) + tm.assert_frame_equal(df, result.to_pandas()) + + +@pytest.mark.pandas +def test_read_single_row_group_with_column_subset(): + N, K = 10000, 4 + df = alltypes_sample(size=N) + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.6') + + buf.seek(0) + pf = pq.ParquetFile(buf) + + cols = list(df.columns[:2]) + row_groups = [pf.read_row_group(i, columns=cols) for i in range(K)] + result = pa.concat_tables(row_groups) + tm.assert_frame_equal(df[cols], result.to_pandas()) + + # ARROW-4267: Selection of duplicate columns still leads to these columns + # being read uniquely. + row_groups = [pf.read_row_group(i, columns=cols + cols) for i in range(K)] + result = pa.concat_tables(row_groups) + tm.assert_frame_equal(df[cols], result.to_pandas()) + + +@pytest.mark.pandas +def test_read_multiple_row_groups(): + N, K = 10000, 4 + df = alltypes_sample(size=N) + + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.6') + + buf.seek(0) + + pf = pq.ParquetFile(buf) + + assert pf.num_row_groups == K + + result = pf.read_row_groups(range(K)) + tm.assert_frame_equal(df, result.to_pandas()) + + +@pytest.mark.pandas +def test_read_multiple_row_groups_with_column_subset(): + N, K = 10000, 4 + df = alltypes_sample(size=N) + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.6') + + buf.seek(0) + pf = pq.ParquetFile(buf) + + cols = list(df.columns[:2]) + result = pf.read_row_groups(range(K), columns=cols) + tm.assert_frame_equal(df[cols], result.to_pandas()) + + # ARROW-4267: Selection of duplicate columns still leads to these columns + # being read uniquely. + result = pf.read_row_groups(range(K), columns=cols + cols) + tm.assert_frame_equal(df[cols], result.to_pandas()) + + +@pytest.mark.pandas +def test_scan_contents(): + N, K = 10000, 4 + df = alltypes_sample(size=N) + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.6') + + buf.seek(0) + pf = pq.ParquetFile(buf) + + assert pf.scan_contents() == 10000 + assert pf.scan_contents(df.columns[:4]) == 10000 + + +def test_parquet_file_pass_directory_instead_of_file(tempdir): + # ARROW-7208 + path = tempdir / 'directory' + os.mkdir(str(path)) + + with pytest.raises(IOError, match="Expected file path"): + pq.ParquetFile(path) + + +def test_read_column_invalid_index(): + table = pa.table([pa.array([4, 5]), pa.array(["foo", "bar"])], + names=['ints', 'strs']) + bio = pa.BufferOutputStream() + pq.write_table(table, bio) + f = pq.ParquetFile(bio.getvalue()) + assert f.reader.read_column(0).to_pylist() == [4, 5] + assert f.reader.read_column(1).to_pylist() == ["foo", "bar"] + for index in (-1, 2): + with pytest.raises((ValueError, IndexError)): + f.reader.read_column(index) + + +@pytest.mark.pandas +@pytest.mark.parametrize('batch_size', [300, 1000, 1300]) +def test_iter_batches_columns_reader(tempdir, batch_size): + total_size = 3000 + chunk_size = 1000 + # TODO: Add categorical support + df = alltypes_sample(size=total_size) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + _write_table(arrow_table, filename, version='2.6', + coerce_timestamps='ms', chunk_size=chunk_size) + + file_ = pq.ParquetFile(filename) + for columns in [df.columns[:10], df.columns[10:]]: + batches = file_.iter_batches(batch_size=batch_size, columns=columns) + batch_starts = range(0, total_size+batch_size, batch_size) + for batch, start in zip(batches, batch_starts): + end = min(total_size, start + batch_size) + tm.assert_frame_equal( + batch.to_pandas(), + df.iloc[start:end, :].loc[:, columns].reset_index(drop=True) + ) + + +@pytest.mark.pandas +@pytest.mark.parametrize('chunk_size', [1000]) +def test_iter_batches_reader(tempdir, chunk_size): + df = alltypes_sample(size=10000, categorical=True) + + filename = tempdir / 'pandas_roundtrip.parquet' + arrow_table = pa.Table.from_pandas(df) + assert arrow_table.schema.pandas_metadata is not None + + _write_table(arrow_table, filename, version='2.6', + coerce_timestamps='ms', chunk_size=chunk_size) + + file_ = pq.ParquetFile(filename) + + def get_all_batches(f): + for row_group in range(f.num_row_groups): + batches = f.iter_batches( + batch_size=900, + row_groups=[row_group], + ) + + for batch in batches: + yield batch + + batches = list(get_all_batches(file_)) + batch_no = 0 + + for i in range(file_.num_row_groups): + tm.assert_frame_equal( + batches[batch_no].to_pandas(), + file_.read_row_groups([i]).to_pandas().head(900) + ) + + batch_no += 1 + + tm.assert_frame_equal( + batches[batch_no].to_pandas().reset_index(drop=True), + file_.read_row_groups([i]).to_pandas().iloc[900:].reset_index( + drop=True + ) + ) + + batch_no += 1 + + +@pytest.mark.pandas +@pytest.mark.parametrize('pre_buffer', [False, True]) +def test_pre_buffer(pre_buffer): + N, K = 10000, 4 + df = alltypes_sample(size=N) + a_table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(a_table, buf, row_group_size=N / K, + compression='snappy', version='2.6') + + buf.seek(0) + pf = pq.ParquetFile(buf, pre_buffer=pre_buffer) + assert pf.read().num_rows == N diff --git a/src/arrow/python/pyarrow/tests/parquet/test_parquet_writer.py b/src/arrow/python/pyarrow/tests/parquet/test_parquet_writer.py new file mode 100644 index 000000000..9be7634c8 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/parquet/test_parquet_writer.py @@ -0,0 +1,278 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import pyarrow as pa +from pyarrow import fs +from pyarrow.filesystem import FileSystem, LocalFileSystem +from pyarrow.tests.parquet.common import parametrize_legacy_dataset + +try: + import pyarrow.parquet as pq + from pyarrow.tests.parquet.common import _read_table, _test_dataframe +except ImportError: + pq = None + + +try: + import pandas as pd + import pandas.testing as tm + +except ImportError: + pd = tm = None + +pytestmark = pytest.mark.parquet + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_incremental_file_build(tempdir, use_legacy_dataset): + df = _test_dataframe(100) + df['unique_id'] = 0 + + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + out = pa.BufferOutputStream() + + writer = pq.ParquetWriter(out, arrow_table.schema, version='2.6') + + frames = [] + for i in range(10): + df['unique_id'] = i + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + writer.write_table(arrow_table) + + frames.append(df.copy()) + + writer.close() + + buf = out.getvalue() + result = _read_table( + pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) + + expected = pd.concat(frames, ignore_index=True) + tm.assert_frame_equal(result.to_pandas(), expected) + + +def test_validate_schema_write_table(tempdir): + # ARROW-2926 + simple_fields = [ + pa.field('POS', pa.uint32()), + pa.field('desc', pa.string()) + ] + + simple_schema = pa.schema(simple_fields) + + # simple_table schema does not match simple_schema + simple_from_array = [pa.array([1]), pa.array(['bla'])] + simple_table = pa.Table.from_arrays(simple_from_array, ['POS', 'desc']) + + path = tempdir / 'simple_validate_schema.parquet' + + with pq.ParquetWriter(path, simple_schema, + version='2.6', + compression='snappy', flavor='spark') as w: + with pytest.raises(ValueError): + w.write_table(simple_table) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_writer_context_obj(tempdir, use_legacy_dataset): + df = _test_dataframe(100) + df['unique_id'] = 0 + + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + out = pa.BufferOutputStream() + + with pq.ParquetWriter(out, arrow_table.schema, version='2.6') as writer: + + frames = [] + for i in range(10): + df['unique_id'] = i + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + writer.write_table(arrow_table) + + frames.append(df.copy()) + + buf = out.getvalue() + result = _read_table( + pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) + + expected = pd.concat(frames, ignore_index=True) + tm.assert_frame_equal(result.to_pandas(), expected) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_writer_context_obj_with_exception( + tempdir, use_legacy_dataset +): + df = _test_dataframe(100) + df['unique_id'] = 0 + + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + out = pa.BufferOutputStream() + error_text = 'Artificial Error' + + try: + with pq.ParquetWriter(out, + arrow_table.schema, + version='2.6') as writer: + + frames = [] + for i in range(10): + df['unique_id'] = i + arrow_table = pa.Table.from_pandas(df, preserve_index=False) + writer.write_table(arrow_table) + frames.append(df.copy()) + if i == 5: + raise ValueError(error_text) + except Exception as e: + assert str(e) == error_text + + buf = out.getvalue() + result = _read_table( + pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) + + expected = pd.concat(frames, ignore_index=True) + tm.assert_frame_equal(result.to_pandas(), expected) + + +@pytest.mark.pandas +@pytest.mark.parametrize("filesystem", [ + None, + LocalFileSystem._get_instance(), + fs.LocalFileSystem(), +]) +def test_parquet_writer_filesystem_local(tempdir, filesystem): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + path = str(tempdir / 'data.parquet') + + with pq.ParquetWriter( + path, table.schema, filesystem=filesystem, version='2.6' + ) as writer: + writer.write_table(table) + + result = _read_table(path).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +@pytest.mark.s3 +def test_parquet_writer_filesystem_s3(s3_example_fs): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + + fs, uri, path = s3_example_fs + + with pq.ParquetWriter( + path, table.schema, filesystem=fs, version='2.6' + ) as writer: + writer.write_table(table) + + result = _read_table(uri).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +@pytest.mark.s3 +def test_parquet_writer_filesystem_s3_uri(s3_example_fs): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + + fs, uri, path = s3_example_fs + + with pq.ParquetWriter(uri, table.schema, version='2.6') as writer: + writer.write_table(table) + + result = _read_table(path, filesystem=fs).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +@pytest.mark.s3 +def test_parquet_writer_filesystem_s3fs(s3_example_s3fs): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + + fs, directory = s3_example_s3fs + path = directory + "/test.parquet" + + with pq.ParquetWriter( + path, table.schema, filesystem=fs, version='2.6' + ) as writer: + writer.write_table(table) + + result = _read_table(path, filesystem=fs).to_pandas() + tm.assert_frame_equal(result, df) + + +@pytest.mark.pandas +def test_parquet_writer_filesystem_buffer_raises(): + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + filesystem = fs.LocalFileSystem() + + # Should raise ValueError when filesystem is passed with file-like object + with pytest.raises(ValueError, match="specified path is file-like"): + pq.ParquetWriter( + pa.BufferOutputStream(), table.schema, filesystem=filesystem + ) + + +@pytest.mark.pandas +@parametrize_legacy_dataset +def test_parquet_writer_with_caller_provided_filesystem(use_legacy_dataset): + out = pa.BufferOutputStream() + + class CustomFS(FileSystem): + def __init__(self): + self.path = None + self.mode = None + + def open(self, path, mode='rb'): + self.path = path + self.mode = mode + return out + + fs = CustomFS() + fname = 'expected_fname.parquet' + df = _test_dataframe(100) + table = pa.Table.from_pandas(df, preserve_index=False) + + with pq.ParquetWriter(fname, table.schema, filesystem=fs, version='2.6') \ + as writer: + writer.write_table(table) + + assert fs.path == fname + assert fs.mode == 'wb' + assert out.closed + + buf = out.getvalue() + table_read = _read_table( + pa.BufferReader(buf), use_legacy_dataset=use_legacy_dataset) + df_read = table_read.to_pandas() + tm.assert_frame_equal(df_read, df) + + # Should raise ValueError when filesystem is passed with file-like object + with pytest.raises(ValueError) as err_info: + pq.ParquetWriter(pa.BufferOutputStream(), table.schema, filesystem=fs) + expected_msg = ("filesystem passed but where is file-like, so" + " there is nothing to open with filesystem.") + assert str(err_info) == expected_msg diff --git a/src/arrow/python/pyarrow/tests/pyarrow_cython_example.pyx b/src/arrow/python/pyarrow/tests/pyarrow_cython_example.pyx new file mode 100644 index 000000000..08f5e17a9 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/pyarrow_cython_example.pyx @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# distutils: language=c++ +# cython: language_level = 3 + +from pyarrow.lib cimport * + + +def get_array_length(obj): + # An example function accessing both the pyarrow Cython API + # and the Arrow C++ API + cdef shared_ptr[CArray] arr = pyarrow_unwrap_array(obj) + if arr.get() == NULL: + raise TypeError("not an array") + return arr.get().length() + + +def make_null_array(length): + # An example function that returns a PyArrow object without PyArrow + # being imported explicitly at the Python level. + cdef shared_ptr[CArray] null_array + null_array.reset(new CNullArray(length)) + return pyarrow_wrap_array(null_array) + + +def cast_scalar(scalar, to_type): + cdef: + shared_ptr[CScalar] c_scalar + shared_ptr[CDataType] c_type + CResult[shared_ptr[CScalar]] c_result + + c_scalar = pyarrow_unwrap_scalar(scalar) + if c_scalar.get() == NULL: + raise TypeError("not a scalar") + c_type = pyarrow_unwrap_data_type(to_type) + if c_type.get() == NULL: + raise TypeError("not a type") + c_result = c_scalar.get().CastTo(c_type) + c_scalar = GetResultValue(c_result) + return pyarrow_wrap_scalar(c_scalar) diff --git a/src/arrow/python/pyarrow/tests/strategies.py b/src/arrow/python/pyarrow/tests/strategies.py new file mode 100644 index 000000000..d314785ff --- /dev/null +++ b/src/arrow/python/pyarrow/tests/strategies.py @@ -0,0 +1,419 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime + +import pytz +import hypothesis as h +import hypothesis.strategies as st +import hypothesis.extra.numpy as npst +import hypothesis.extra.pytz as tzst +import numpy as np + +import pyarrow as pa + + +# TODO(kszucs): alphanum_text, surrogate_text +custom_text = st.text( + alphabet=st.characters( + min_codepoint=0x41, + max_codepoint=0x7E + ) +) + +null_type = st.just(pa.null()) +bool_type = st.just(pa.bool_()) + +binary_type = st.just(pa.binary()) +string_type = st.just(pa.string()) +large_binary_type = st.just(pa.large_binary()) +large_string_type = st.just(pa.large_string()) +fixed_size_binary_type = st.builds( + pa.binary, + st.integers(min_value=0, max_value=16) +) +binary_like_types = st.one_of( + binary_type, + string_type, + large_binary_type, + large_string_type, + fixed_size_binary_type +) + +signed_integer_types = st.sampled_from([ + pa.int8(), + pa.int16(), + pa.int32(), + pa.int64() +]) +unsigned_integer_types = st.sampled_from([ + pa.uint8(), + pa.uint16(), + pa.uint32(), + pa.uint64() +]) +integer_types = st.one_of(signed_integer_types, unsigned_integer_types) + +floating_types = st.sampled_from([ + pa.float16(), + pa.float32(), + pa.float64() +]) +decimal128_type = st.builds( + pa.decimal128, + precision=st.integers(min_value=1, max_value=38), + scale=st.integers(min_value=1, max_value=38) +) +decimal256_type = st.builds( + pa.decimal256, + precision=st.integers(min_value=1, max_value=76), + scale=st.integers(min_value=1, max_value=76) +) +numeric_types = st.one_of(integer_types, floating_types, + decimal128_type, decimal256_type) + +date_types = st.sampled_from([ + pa.date32(), + pa.date64() +]) +time_types = st.sampled_from([ + pa.time32('s'), + pa.time32('ms'), + pa.time64('us'), + pa.time64('ns') +]) +timestamp_types = st.builds( + pa.timestamp, + unit=st.sampled_from(['s', 'ms', 'us', 'ns']), + tz=tzst.timezones() +) +duration_types = st.builds( + pa.duration, + st.sampled_from(['s', 'ms', 'us', 'ns']) +) +interval_types = st.sampled_from( + pa.month_day_nano_interval() +) +temporal_types = st.one_of( + date_types, + time_types, + timestamp_types, + duration_types, + interval_types +) + +primitive_types = st.one_of( + null_type, + bool_type, + numeric_types, + temporal_types, + binary_like_types +) + +metadata = st.dictionaries(st.text(), st.text()) + + +@st.composite +def fields(draw, type_strategy=primitive_types): + name = draw(custom_text) + typ = draw(type_strategy) + if pa.types.is_null(typ): + nullable = True + else: + nullable = draw(st.booleans()) + meta = draw(metadata) + return pa.field(name, type=typ, nullable=nullable, metadata=meta) + + +def list_types(item_strategy=primitive_types): + return ( + st.builds(pa.list_, item_strategy) | + st.builds(pa.large_list, item_strategy) | + st.builds( + pa.list_, + item_strategy, + st.integers(min_value=0, max_value=16) + ) + ) + + +@st.composite +def struct_types(draw, item_strategy=primitive_types): + fields_strategy = st.lists(fields(item_strategy)) + fields_rendered = draw(fields_strategy) + field_names = [field.name for field in fields_rendered] + # check that field names are unique, see ARROW-9997 + h.assume(len(set(field_names)) == len(field_names)) + return pa.struct(fields_rendered) + + +def dictionary_types(key_strategy=None, value_strategy=None): + key_strategy = key_strategy or signed_integer_types + value_strategy = value_strategy or st.one_of( + bool_type, + integer_types, + st.sampled_from([pa.float32(), pa.float64()]), + binary_type, + string_type, + fixed_size_binary_type, + ) + return st.builds(pa.dictionary, key_strategy, value_strategy) + + +@st.composite +def map_types(draw, key_strategy=primitive_types, + item_strategy=primitive_types): + key_type = draw(key_strategy) + h.assume(not pa.types.is_null(key_type)) + value_type = draw(item_strategy) + return pa.map_(key_type, value_type) + + +# union type +# extension type + + +def schemas(type_strategy=primitive_types, max_fields=None): + children = st.lists(fields(type_strategy), max_size=max_fields) + return st.builds(pa.schema, children) + + +all_types = st.deferred( + lambda: ( + primitive_types | + list_types() | + struct_types() | + dictionary_types() | + map_types() | + list_types(all_types) | + struct_types(all_types) + ) +) +all_fields = fields(all_types) +all_schemas = schemas(all_types) + + +_default_array_sizes = st.integers(min_value=0, max_value=20) + + +@st.composite +def _pylist(draw, value_type, size, nullable=True): + arr = draw(arrays(value_type, size=size, nullable=False)) + return arr.to_pylist() + + +@st.composite +def _pymap(draw, key_type, value_type, size, nullable=True): + length = draw(size) + keys = draw(_pylist(key_type, size=length, nullable=False)) + values = draw(_pylist(value_type, size=length, nullable=nullable)) + return list(zip(keys, values)) + + +@st.composite +def arrays(draw, type, size=None, nullable=True): + if isinstance(type, st.SearchStrategy): + ty = draw(type) + elif isinstance(type, pa.DataType): + ty = type + else: + raise TypeError('Type must be a pyarrow DataType') + + if isinstance(size, st.SearchStrategy): + size = draw(size) + elif size is None: + size = draw(_default_array_sizes) + elif not isinstance(size, int): + raise TypeError('Size must be an integer') + + if pa.types.is_null(ty): + h.assume(nullable) + value = st.none() + elif pa.types.is_boolean(ty): + value = st.booleans() + elif pa.types.is_integer(ty): + values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,))) + return pa.array(values, type=ty) + elif pa.types.is_floating(ty): + values = draw(npst.arrays(ty.to_pandas_dtype(), shape=(size,))) + # Workaround ARROW-4952: no easy way to assert array equality + # in a NaN-tolerant way. + values[np.isnan(values)] = -42.0 + return pa.array(values, type=ty) + elif pa.types.is_decimal(ty): + # TODO(kszucs): properly limit the precision + # value = st.decimals(places=type.scale, allow_infinity=False) + h.reject() + elif pa.types.is_time(ty): + value = st.times() + elif pa.types.is_date(ty): + value = st.dates() + elif pa.types.is_timestamp(ty): + min_int64 = -(2**63) + max_int64 = 2**63 - 1 + min_datetime = datetime.datetime.fromtimestamp(min_int64 // 10**9) + max_datetime = datetime.datetime.fromtimestamp(max_int64 // 10**9) + try: + offset_hours = int(ty.tz) + tz = pytz.FixedOffset(offset_hours * 60) + except ValueError: + tz = pytz.timezone(ty.tz) + value = st.datetimes(timezones=st.just(tz), min_value=min_datetime, + max_value=max_datetime) + elif pa.types.is_duration(ty): + value = st.timedeltas() + elif pa.types.is_binary(ty) or pa.types.is_large_binary(ty): + value = st.binary() + elif pa.types.is_string(ty) or pa.types.is_large_string(ty): + value = st.text() + elif pa.types.is_fixed_size_binary(ty): + value = st.binary(min_size=ty.byte_width, max_size=ty.byte_width) + elif pa.types.is_list(ty): + value = _pylist(ty.value_type, size=size, nullable=nullable) + elif pa.types.is_large_list(ty): + value = _pylist(ty.value_type, size=size, nullable=nullable) + elif pa.types.is_fixed_size_list(ty): + value = _pylist(ty.value_type, size=ty.list_size, nullable=nullable) + elif pa.types.is_dictionary(ty): + values = _pylist(ty.value_type, size=size, nullable=nullable) + return pa.array(draw(values), type=ty) + elif pa.types.is_map(ty): + value = _pymap(ty.key_type, ty.item_type, size=_default_array_sizes, + nullable=nullable) + elif pa.types.is_struct(ty): + h.assume(len(ty) > 0) + fields, child_arrays = [], [] + for field in ty: + fields.append(field) + child_arrays.append(draw(arrays(field.type, size=size))) + return pa.StructArray.from_arrays(child_arrays, fields=fields) + else: + raise NotImplementedError(ty) + + if nullable: + value = st.one_of(st.none(), value) + values = st.lists(value, min_size=size, max_size=size) + + return pa.array(draw(values), type=ty) + + +@st.composite +def chunked_arrays(draw, type, min_chunks=0, max_chunks=None, chunk_size=None): + if isinstance(type, st.SearchStrategy): + type = draw(type) + + # TODO(kszucs): remove it, field metadata is not kept + h.assume(not pa.types.is_struct(type)) + + chunk = arrays(type, size=chunk_size) + chunks = st.lists(chunk, min_size=min_chunks, max_size=max_chunks) + + return pa.chunked_array(draw(chunks), type=type) + + +@st.composite +def record_batches(draw, type, rows=None, max_fields=None): + if isinstance(rows, st.SearchStrategy): + rows = draw(rows) + elif rows is None: + rows = draw(_default_array_sizes) + elif not isinstance(rows, int): + raise TypeError('Rows must be an integer') + + schema = draw(schemas(type, max_fields=max_fields)) + children = [draw(arrays(field.type, size=rows)) for field in schema] + # TODO(kszucs): the names and schema arguments are not consistent with + # Table.from_array's arguments + return pa.RecordBatch.from_arrays(children, names=schema) + + +@st.composite +def tables(draw, type, rows=None, max_fields=None): + if isinstance(rows, st.SearchStrategy): + rows = draw(rows) + elif rows is None: + rows = draw(_default_array_sizes) + elif not isinstance(rows, int): + raise TypeError('Rows must be an integer') + + schema = draw(schemas(type, max_fields=max_fields)) + children = [draw(arrays(field.type, size=rows)) for field in schema] + return pa.Table.from_arrays(children, schema=schema) + + +all_arrays = arrays(all_types) +all_chunked_arrays = chunked_arrays(all_types) +all_record_batches = record_batches(all_types) +all_tables = tables(all_types) + + +# Define the same rules as above for pandas tests by excluding certain types +# from the generation because of known issues. + +pandas_compatible_primitive_types = st.one_of( + null_type, + bool_type, + integer_types, + st.sampled_from([pa.float32(), pa.float64()]), + decimal128_type, + date_types, + time_types, + # Need to exclude timestamp and duration types otherwise hypothesis + # discovers ARROW-10210 + # timestamp_types, + # duration_types + interval_types, + binary_type, + string_type, + large_binary_type, + large_string_type, +) + +# Need to exclude floating point types otherwise hypothesis discovers +# ARROW-10211 +pandas_compatible_dictionary_value_types = st.one_of( + bool_type, + integer_types, + binary_type, + string_type, + fixed_size_binary_type, +) + + +def pandas_compatible_list_types( + item_strategy=pandas_compatible_primitive_types +): + # Need to exclude fixed size list type otherwise hypothesis discovers + # ARROW-10194 + return ( + st.builds(pa.list_, item_strategy) | + st.builds(pa.large_list, item_strategy) + ) + + +pandas_compatible_types = st.deferred( + lambda: st.one_of( + pandas_compatible_primitive_types, + pandas_compatible_list_types(pandas_compatible_primitive_types), + struct_types(pandas_compatible_primitive_types), + dictionary_types( + value_strategy=pandas_compatible_dictionary_value_types + ), + pandas_compatible_list_types(pandas_compatible_types), + struct_types(pandas_compatible_types) + ) +) diff --git a/src/arrow/python/pyarrow/tests/test_adhoc_memory_leak.py b/src/arrow/python/pyarrow/tests/test_adhoc_memory_leak.py new file mode 100644 index 000000000..cd381cf42 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_adhoc_memory_leak.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import numpy as np +import pyarrow as pa + +import pyarrow.tests.util as test_util + +try: + import pandas as pd +except ImportError: + pass + + +@pytest.mark.memory_leak +@pytest.mark.pandas +def test_deserialize_pandas_arrow_7956(): + df = pd.DataFrame({'a': np.arange(10000), + 'b': [test_util.rands(5) for _ in range(10000)]}) + + def action(): + df_bytes = pa.ipc.serialize_pandas(df).to_pybytes() + buf = pa.py_buffer(df_bytes) + pa.ipc.deserialize_pandas(buf) + + # Abort at 128MB threshold + test_util.memory_leak_check(action, threshold=1 << 27, iterations=100) diff --git a/src/arrow/python/pyarrow/tests/test_array.py b/src/arrow/python/pyarrow/tests/test_array.py new file mode 100644 index 000000000..9a1f41efe --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_array.py @@ -0,0 +1,3064 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections.abc import Iterable +import datetime +import decimal +import hypothesis as h +import hypothesis.strategies as st +import itertools +import pickle +import pytest +import struct +import sys +import weakref + +import numpy as np +try: + import pickle5 +except ImportError: + pickle5 = None +import pytz + +import pyarrow as pa +import pyarrow.tests.strategies as past + + +def test_total_bytes_allocated(): + assert pa.total_allocated_bytes() == 0 + + +def test_weakref(): + arr = pa.array([1, 2, 3]) + wr = weakref.ref(arr) + assert wr() is not None + del arr + assert wr() is None + + +def test_getitem_NULL(): + arr = pa.array([1, None, 2]) + assert arr[1].as_py() is None + assert arr[1].is_valid is False + assert isinstance(arr[1], pa.Int64Scalar) + + +def test_constructor_raises(): + # This could happen by wrong capitalization. + # ARROW-2638: prevent calling extension class constructors directly + with pytest.raises(TypeError): + pa.Array([1, 2]) + + +def test_list_format(): + arr = pa.array([[1], None, [2, 3, None]]) + result = arr.to_string() + expected = """\ +[ + [ + 1 + ], + null, + [ + 2, + 3, + null + ] +]""" + assert result == expected + + +def test_string_format(): + arr = pa.array(['', None, 'foo']) + result = arr.to_string() + expected = """\ +[ + "", + null, + "foo" +]""" + assert result == expected + + +def test_long_array_format(): + arr = pa.array(range(100)) + result = arr.to_string(window=2) + expected = """\ +[ + 0, + 1, + ... + 98, + 99 +]""" + assert result == expected + + +def test_binary_format(): + arr = pa.array([b'\x00', b'', None, b'\x01foo', b'\x80\xff']) + result = arr.to_string() + expected = """\ +[ + 00, + , + null, + 01666F6F, + 80FF +]""" + assert result == expected + + +def test_binary_total_values_length(): + arr = pa.array([b'0000', None, b'11111', b'222222', b'3333333'], + type='binary') + large_arr = pa.array([b'0000', None, b'11111', b'222222', b'3333333'], + type='large_binary') + + assert arr.total_values_length == 22 + assert arr.slice(1, 3).total_values_length == 11 + assert large_arr.total_values_length == 22 + assert large_arr.slice(1, 3).total_values_length == 11 + + +def test_to_numpy_zero_copy(): + arr = pa.array(range(10)) + + np_arr = arr.to_numpy() + + # check for zero copy (both arrays using same memory) + arrow_buf = arr.buffers()[1] + assert arrow_buf.address == np_arr.ctypes.data + + arr = None + import gc + gc.collect() + + # Ensure base is still valid + assert np_arr.base is not None + expected = np.arange(10) + np.testing.assert_array_equal(np_arr, expected) + + +def test_to_numpy_unsupported_types(): + # ARROW-2871: Some primitive types are not yet supported in to_numpy + bool_arr = pa.array([True, False, True]) + + with pytest.raises(ValueError): + bool_arr.to_numpy() + + result = bool_arr.to_numpy(zero_copy_only=False) + expected = np.array([True, False, True]) + np.testing.assert_array_equal(result, expected) + + null_arr = pa.array([None, None, None]) + + with pytest.raises(ValueError): + null_arr.to_numpy() + + result = null_arr.to_numpy(zero_copy_only=False) + expected = np.array([None, None, None], dtype=object) + np.testing.assert_array_equal(result, expected) + + arr = pa.array([1, 2, None]) + + with pytest.raises(ValueError, match="with 1 nulls"): + arr.to_numpy() + + +def test_to_numpy_writable(): + arr = pa.array(range(10)) + np_arr = arr.to_numpy() + + # by default not writable for zero-copy conversion + with pytest.raises(ValueError): + np_arr[0] = 10 + + np_arr2 = arr.to_numpy(zero_copy_only=False, writable=True) + np_arr2[0] = 10 + assert arr[0].as_py() == 0 + + # when asking for writable, cannot do zero-copy + with pytest.raises(ValueError): + arr.to_numpy(zero_copy_only=True, writable=True) + + +@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) +def test_to_numpy_datetime64(unit): + arr = pa.array([1, 2, 3], pa.timestamp(unit)) + expected = np.array([1, 2, 3], dtype="datetime64[{}]".format(unit)) + np_arr = arr.to_numpy() + np.testing.assert_array_equal(np_arr, expected) + + +@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) +def test_to_numpy_timedelta64(unit): + arr = pa.array([1, 2, 3], pa.duration(unit)) + expected = np.array([1, 2, 3], dtype="timedelta64[{}]".format(unit)) + np_arr = arr.to_numpy() + np.testing.assert_array_equal(np_arr, expected) + + +def test_to_numpy_dictionary(): + # ARROW-7591 + arr = pa.array(["a", "b", "a"]).dictionary_encode() + expected = np.array(["a", "b", "a"], dtype=object) + np_arr = arr.to_numpy(zero_copy_only=False) + np.testing.assert_array_equal(np_arr, expected) + + +@pytest.mark.pandas +def test_to_pandas_zero_copy(): + import gc + + arr = pa.array(range(10)) + + for i in range(10): + series = arr.to_pandas() + assert sys.getrefcount(series) == 2 + series = None # noqa + + assert sys.getrefcount(arr) == 2 + + for i in range(10): + arr = pa.array(range(10)) + series = arr.to_pandas() + arr = None + gc.collect() + + # Ensure base is still valid + + # Because of py.test's assert inspection magic, if you put getrefcount + # on the line being examined, it will be 1 higher than you expect + base_refcount = sys.getrefcount(series.values.base) + assert base_refcount == 2 + series.sum() + + +@pytest.mark.nopandas +@pytest.mark.pandas +def test_asarray(): + # ensure this is tested both when pandas is present or not (ARROW-6564) + + arr = pa.array(range(4)) + + # The iterator interface gives back an array of Int64Value's + np_arr = np.asarray([_ for _ in arr]) + assert np_arr.tolist() == [0, 1, 2, 3] + assert np_arr.dtype == np.dtype('O') + assert type(np_arr[0]) == pa.lib.Int64Value + + # Calling with the arrow array gives back an array with 'int64' dtype + np_arr = np.asarray(arr) + assert np_arr.tolist() == [0, 1, 2, 3] + assert np_arr.dtype == np.dtype('int64') + + # An optional type can be specified when calling np.asarray + np_arr = np.asarray(arr, dtype='str') + assert np_arr.tolist() == ['0', '1', '2', '3'] + + # If PyArrow array has null values, numpy type will be changed as needed + # to support nulls. + arr = pa.array([0, 1, 2, None]) + assert arr.type == pa.int64() + np_arr = np.asarray(arr) + elements = np_arr.tolist() + assert elements[:3] == [0., 1., 2.] + assert np.isnan(elements[3]) + assert np_arr.dtype == np.dtype('float64') + + # DictionaryType data will be converted to dense numpy array + arr = pa.DictionaryArray.from_arrays( + pa.array([0, 1, 2, 0, 1]), pa.array(['a', 'b', 'c'])) + np_arr = np.asarray(arr) + assert np_arr.dtype == np.dtype('object') + assert np_arr.tolist() == ['a', 'b', 'c', 'a', 'b'] + + +@pytest.mark.parametrize('ty', [ + None, + pa.null(), + pa.int8(), + pa.string() +]) +def test_nulls(ty): + arr = pa.nulls(3, type=ty) + expected = pa.array([None, None, None], type=ty) + + assert len(arr) == 3 + assert arr.equals(expected) + + if ty is None: + assert arr.type == pa.null() + else: + assert arr.type == ty + + +def test_array_from_scalar(): + today = datetime.date.today() + now = datetime.datetime.now() + now_utc = now.replace(tzinfo=pytz.utc) + now_with_tz = now_utc.astimezone(pytz.timezone('US/Eastern')) + oneday = datetime.timedelta(days=1) + + cases = [ + (None, 1, pa.array([None])), + (None, 10, pa.nulls(10)), + (-1, 3, pa.array([-1, -1, -1], type=pa.int64())), + (2.71, 2, pa.array([2.71, 2.71], type=pa.float64())), + ("string", 4, pa.array(["string"] * 4)), + ( + pa.scalar(8, type=pa.uint8()), + 17, + pa.array([8] * 17, type=pa.uint8()) + ), + (pa.scalar(None), 3, pa.array([None, None, None])), + (pa.scalar(True), 11, pa.array([True] * 11)), + (today, 2, pa.array([today] * 2)), + (now, 10, pa.array([now] * 10)), + ( + now_with_tz, + 2, + pa.array( + [now_utc] * 2, + type=pa.timestamp('us', tz=pytz.timezone('US/Eastern')) + ) + ), + (now.time(), 9, pa.array([now.time()] * 9)), + (oneday, 4, pa.array([oneday] * 4)), + (False, 9, pa.array([False] * 9)), + ([1, 2], 2, pa.array([[1, 2], [1, 2]])), + ( + pa.scalar([-1, 3], type=pa.large_list(pa.int8())), + 5, + pa.array([[-1, 3]] * 5, type=pa.large_list(pa.int8())) + ), + ({'a': 1, 'b': 2}, 3, pa.array([{'a': 1, 'b': 2}] * 3)) + ] + + for value, size, expected in cases: + arr = pa.repeat(value, size) + assert len(arr) == size + assert arr.type.equals(expected.type) + assert arr.equals(expected) + if expected.type == pa.null(): + assert arr.null_count == size + else: + assert arr.null_count == 0 + + +def test_array_from_dictionary_scalar(): + dictionary = ['foo', 'bar', 'baz'] + arr = pa.DictionaryArray.from_arrays([2, 1, 2, 0], dictionary=dictionary) + + result = pa.repeat(arr[0], 5) + expected = pa.DictionaryArray.from_arrays([2] * 5, dictionary=dictionary) + assert result.equals(expected) + + result = pa.repeat(arr[3], 5) + expected = pa.DictionaryArray.from_arrays([0] * 5, dictionary=dictionary) + assert result.equals(expected) + + +def test_array_getitem(): + arr = pa.array(range(10, 15)) + lst = arr.to_pylist() + + for idx in range(-len(arr), len(arr)): + assert arr[idx].as_py() == lst[idx] + for idx in range(-2 * len(arr), -len(arr)): + with pytest.raises(IndexError): + arr[idx] + for idx in range(len(arr), 2 * len(arr)): + with pytest.raises(IndexError): + arr[idx] + + # check that numpy scalars are supported + for idx in range(-len(arr), len(arr)): + assert arr[np.int32(idx)].as_py() == lst[idx] + + +def test_array_slice(): + arr = pa.array(range(10)) + + sliced = arr.slice(2) + expected = pa.array(range(2, 10)) + assert sliced.equals(expected) + + sliced2 = arr.slice(2, 4) + expected2 = pa.array(range(2, 6)) + assert sliced2.equals(expected2) + + # 0 offset + assert arr.slice(0).equals(arr) + + # Slice past end of array + assert len(arr.slice(len(arr))) == 0 + assert len(arr.slice(len(arr) + 2)) == 0 + assert len(arr.slice(len(arr) + 2, 100)) == 0 + + with pytest.raises(IndexError): + arr.slice(-1) + + with pytest.raises(ValueError): + arr.slice(2, -1) + + # Test slice notation + assert arr[2:].equals(arr.slice(2)) + assert arr[2:5].equals(arr.slice(2, 3)) + assert arr[-5:].equals(arr.slice(len(arr) - 5)) + + n = len(arr) + for start in range(-n * 2, n * 2): + for stop in range(-n * 2, n * 2): + res = arr[start:stop] + res.validate() + expected = arr.to_pylist()[start:stop] + assert res.to_pylist() == expected + assert res.to_numpy().tolist() == expected + + +def test_array_slice_negative_step(): + # ARROW-2714 + np_arr = np.arange(20) + arr = pa.array(np_arr) + chunked_arr = pa.chunked_array([arr]) + + cases = [ + slice(None, None, -1), + slice(None, 6, -2), + slice(10, 6, -2), + slice(8, None, -2), + slice(2, 10, -2), + slice(10, 2, -2), + slice(None, None, 2), + slice(0, 10, 2), + ] + + for case in cases: + result = arr[case] + expected = pa.array(np_arr[case]) + assert result.equals(expected) + + result = pa.record_batch([arr], names=['f0'])[case] + expected = pa.record_batch([expected], names=['f0']) + assert result.equals(expected) + + result = chunked_arr[case] + expected = pa.chunked_array([np_arr[case]]) + assert result.equals(expected) + + +def test_array_diff(): + # ARROW-6252 + arr1 = pa.array(['foo'], type=pa.utf8()) + arr2 = pa.array(['foo', 'bar', None], type=pa.utf8()) + arr3 = pa.array([1, 2, 3]) + arr4 = pa.array([[], [1], None], type=pa.list_(pa.int64())) + + assert arr1.diff(arr1) == '' + assert arr1.diff(arr2) == ''' +@@ -1, +1 @@ ++"bar" ++null +''' + assert arr1.diff(arr3).strip() == '# Array types differed: string vs int64' + assert arr1.diff(arr3).strip() == '# Array types differed: string vs int64' + assert arr1.diff(arr4).strip() == ('# Array types differed: string vs ' + 'list<item: int64>') + + +def test_array_iter(): + arr = pa.array(range(10)) + + for i, j in zip(range(10), arr): + assert i == j.as_py() + + assert isinstance(arr, Iterable) + + +def test_struct_array_slice(): + # ARROW-2311: slicing nested arrays needs special care + ty = pa.struct([pa.field('a', pa.int8()), + pa.field('b', pa.float32())]) + arr = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty) + assert arr[1:].to_pylist() == [{'a': 3, 'b': 4.5}, + {'a': 5, 'b': 6.5}] + + +def test_array_factory_invalid_type(): + + class MyObject: + pass + + arr = np.array([MyObject()]) + with pytest.raises(ValueError): + pa.array(arr) + + +def test_array_ref_to_ndarray_base(): + arr = np.array([1, 2, 3]) + + refcount = sys.getrefcount(arr) + arr2 = pa.array(arr) # noqa + assert sys.getrefcount(arr) == (refcount + 1) + + +def test_array_eq(): + # ARROW-2150 / ARROW-9445: we define the __eq__ behavior to be + # data equality (not element-wise equality) + arr1 = pa.array([1, 2, 3], type=pa.int32()) + arr2 = pa.array([1, 2, 3], type=pa.int32()) + arr3 = pa.array([1, 2, 3], type=pa.int64()) + + assert (arr1 == arr2) is True + assert (arr1 != arr2) is False + assert (arr1 == arr3) is False + assert (arr1 != arr3) is True + + assert (arr1 == 1) is False + assert (arr1 == None) is False # noqa: E711 + + +def test_array_from_buffers(): + values_buf = pa.py_buffer(np.int16([4, 5, 6, 7])) + nulls_buf = pa.py_buffer(np.uint8([0b00001101])) + arr = pa.Array.from_buffers(pa.int16(), 4, [nulls_buf, values_buf]) + assert arr.type == pa.int16() + assert arr.to_pylist() == [4, None, 6, 7] + + arr = pa.Array.from_buffers(pa.int16(), 4, [None, values_buf]) + assert arr.type == pa.int16() + assert arr.to_pylist() == [4, 5, 6, 7] + + arr = pa.Array.from_buffers(pa.int16(), 3, [nulls_buf, values_buf], + offset=1) + assert arr.type == pa.int16() + assert arr.to_pylist() == [None, 6, 7] + + with pytest.raises(TypeError): + pa.Array.from_buffers(pa.int16(), 3, ['', ''], offset=1) + + +def test_string_binary_from_buffers(): + array = pa.array(["a", None, "b", "c"]) + + buffers = array.buffers() + copied = pa.StringArray.from_buffers( + len(array), buffers[1], buffers[2], buffers[0], array.null_count, + array.offset) + assert copied.to_pylist() == ["a", None, "b", "c"] + + binary_copy = pa.Array.from_buffers(pa.binary(), len(array), + array.buffers(), array.null_count, + array.offset) + assert binary_copy.to_pylist() == [b"a", None, b"b", b"c"] + + copied = pa.StringArray.from_buffers( + len(array), buffers[1], buffers[2], buffers[0]) + assert copied.to_pylist() == ["a", None, "b", "c"] + + sliced = array[1:] + buffers = sliced.buffers() + copied = pa.StringArray.from_buffers( + len(sliced), buffers[1], buffers[2], buffers[0], -1, sliced.offset) + assert copied.to_pylist() == [None, "b", "c"] + assert copied.null_count == 1 + + # Slice but exclude all null entries so that we don't need to pass + # the null bitmap. + sliced = array[2:] + buffers = sliced.buffers() + copied = pa.StringArray.from_buffers( + len(sliced), buffers[1], buffers[2], None, -1, sliced.offset) + assert copied.to_pylist() == ["b", "c"] + assert copied.null_count == 0 + + +@pytest.mark.parametrize('list_type_factory', [pa.list_, pa.large_list]) +def test_list_from_buffers(list_type_factory): + ty = list_type_factory(pa.int16()) + array = pa.array([[0, 1, 2], None, [], [3, 4, 5]], type=ty) + assert array.type == ty + + buffers = array.buffers() + + with pytest.raises(ValueError): + # No children + pa.Array.from_buffers(ty, 4, [None, buffers[1]]) + + child = pa.Array.from_buffers(pa.int16(), 6, buffers[2:]) + copied = pa.Array.from_buffers(ty, 4, buffers[:2], children=[child]) + assert copied.equals(array) + + with pytest.raises(ValueError): + # too many children + pa.Array.from_buffers(ty, 4, [None, buffers[1]], + children=[child, child]) + + +def test_struct_from_buffers(): + ty = pa.struct([pa.field('a', pa.int16()), pa.field('b', pa.utf8())]) + array = pa.array([{'a': 0, 'b': 'foo'}, None, {'a': 5, 'b': ''}], + type=ty) + buffers = array.buffers() + + with pytest.raises(ValueError): + # No children + pa.Array.from_buffers(ty, 3, [None, buffers[1]]) + + children = [pa.Array.from_buffers(pa.int16(), 3, buffers[1:3]), + pa.Array.from_buffers(pa.utf8(), 3, buffers[3:])] + copied = pa.Array.from_buffers(ty, 3, buffers[:1], children=children) + assert copied.equals(array) + + with pytest.raises(ValueError): + # not enough many children + pa.Array.from_buffers(ty, 3, [buffers[0]], + children=children[:1]) + + +def test_struct_from_arrays(): + a = pa.array([4, 5, 6], type=pa.int64()) + b = pa.array(["bar", None, ""]) + c = pa.array([[1, 2], None, [3, None]]) + expected_list = [ + {'a': 4, 'b': 'bar', 'c': [1, 2]}, + {'a': 5, 'b': None, 'c': None}, + {'a': 6, 'b': '', 'c': [3, None]}, + ] + + # From field names + arr = pa.StructArray.from_arrays([a, b, c], ["a", "b", "c"]) + assert arr.type == pa.struct( + [("a", a.type), ("b", b.type), ("c", c.type)]) + assert arr.to_pylist() == expected_list + + with pytest.raises(ValueError): + pa.StructArray.from_arrays([a, b, c], ["a", "b"]) + + arr = pa.StructArray.from_arrays([], []) + assert arr.type == pa.struct([]) + assert arr.to_pylist() == [] + + # From fields + fa = pa.field("a", a.type, nullable=False) + fb = pa.field("b", b.type) + fc = pa.field("c", c.type) + arr = pa.StructArray.from_arrays([a, b, c], fields=[fa, fb, fc]) + assert arr.type == pa.struct([fa, fb, fc]) + assert not arr.type[0].nullable + assert arr.to_pylist() == expected_list + + with pytest.raises(ValueError): + pa.StructArray.from_arrays([a, b, c], fields=[fa, fb]) + + arr = pa.StructArray.from_arrays([], fields=[]) + assert arr.type == pa.struct([]) + assert arr.to_pylist() == [] + + # Inconsistent fields + fa2 = pa.field("a", pa.int32()) + with pytest.raises(ValueError, match="int64 vs int32"): + pa.StructArray.from_arrays([a, b, c], fields=[fa2, fb, fc]) + + arrays = [a, b, c] + fields = [fa, fb, fc] + # With mask + mask = pa.array([True, False, False]) + arr = pa.StructArray.from_arrays(arrays, fields=fields, mask=mask) + assert arr.to_pylist() == [None] + expected_list[1:] + + arr = pa.StructArray.from_arrays(arrays, names=['a', 'b', 'c'], mask=mask) + assert arr.to_pylist() == [None] + expected_list[1:] + + # Bad masks + with pytest.raises(ValueError, match='Mask must be'): + pa.StructArray.from_arrays(arrays, fields, mask=[True, False, False]) + + with pytest.raises(ValueError, match='not contain nulls'): + pa.StructArray.from_arrays( + arrays, fields, mask=pa.array([True, False, None])) + + with pytest.raises(ValueError, match='Mask must be'): + pa.StructArray.from_arrays( + arrays, fields, mask=pa.chunked_array([mask])) + + +def test_struct_array_from_chunked(): + # ARROW-11780 + # Check that we don't segfault when trying to build + # a StructArray from a chunked array. + chunked_arr = pa.chunked_array([[1, 2, 3], [4, 5, 6]]) + + with pytest.raises(TypeError, match="Expected Array"): + pa.StructArray.from_arrays([chunked_arr], ["foo"]) + + +def test_dictionary_from_numpy(): + indices = np.repeat([0, 1, 2], 2) + dictionary = np.array(['foo', 'bar', 'baz'], dtype=object) + mask = np.array([False, False, True, False, False, False]) + + d1 = pa.DictionaryArray.from_arrays(indices, dictionary) + d2 = pa.DictionaryArray.from_arrays(indices, dictionary, mask=mask) + + assert d1.indices.to_pylist() == indices.tolist() + assert d1.indices.to_pylist() == indices.tolist() + assert d1.dictionary.to_pylist() == dictionary.tolist() + assert d2.dictionary.to_pylist() == dictionary.tolist() + + for i in range(len(indices)): + assert d1[i].as_py() == dictionary[indices[i]] + + if mask[i]: + assert d2[i].as_py() is None + else: + assert d2[i].as_py() == dictionary[indices[i]] + + +def test_dictionary_to_numpy(): + expected = pa.array( + ["foo", "bar", None, "foo"] + ).to_numpy(zero_copy_only=False) + a = pa.DictionaryArray.from_arrays( + pa.array([0, 1, None, 0]), + pa.array(['foo', 'bar']) + ) + np.testing.assert_array_equal(a.to_numpy(zero_copy_only=False), + expected) + + with pytest.raises(pa.ArrowInvalid): + # If this would be changed to no longer raise in the future, + # ensure to test the actual result because, currently, to_numpy takes + # for granted that when zero_copy_only=True there will be no nulls + # (it's the decoding of the DictionaryArray that handles the nulls and + # this is only activated with zero_copy_only=False) + a.to_numpy(zero_copy_only=True) + + anonulls = pa.DictionaryArray.from_arrays( + pa.array([0, 1, 1, 0]), + pa.array(['foo', 'bar']) + ) + expected = pa.array( + ["foo", "bar", "bar", "foo"] + ).to_numpy(zero_copy_only=False) + np.testing.assert_array_equal(anonulls.to_numpy(zero_copy_only=False), + expected) + + with pytest.raises(pa.ArrowInvalid): + anonulls.to_numpy(zero_copy_only=True) + + afloat = pa.DictionaryArray.from_arrays( + pa.array([0, 1, 1, 0]), + pa.array([13.7, 11.0]) + ) + expected = pa.array([13.7, 11.0, 11.0, 13.7]).to_numpy() + np.testing.assert_array_equal(afloat.to_numpy(zero_copy_only=True), + expected) + np.testing.assert_array_equal(afloat.to_numpy(zero_copy_only=False), + expected) + + afloat2 = pa.DictionaryArray.from_arrays( + pa.array([0, 1, None, 0]), + pa.array([13.7, 11.0]) + ) + expected = pa.array( + [13.7, 11.0, None, 13.7] + ).to_numpy(zero_copy_only=False) + np.testing.assert_allclose( + afloat2.to_numpy(zero_copy_only=False), + expected, + equal_nan=True + ) + + # Testing for integers can reveal problems related to dealing + # with None values, as a numpy array of int dtype + # can't contain NaN nor None. + aints = pa.DictionaryArray.from_arrays( + pa.array([0, 1, None, 0]), + pa.array([7, 11]) + ) + expected = pa.array([7, 11, None, 7]).to_numpy(zero_copy_only=False) + np.testing.assert_allclose( + aints.to_numpy(zero_copy_only=False), + expected, + equal_nan=True + ) + + +def test_dictionary_from_boxed_arrays(): + indices = np.repeat([0, 1, 2], 2) + dictionary = np.array(['foo', 'bar', 'baz'], dtype=object) + + iarr = pa.array(indices) + darr = pa.array(dictionary) + + d1 = pa.DictionaryArray.from_arrays(iarr, darr) + + assert d1.indices.to_pylist() == indices.tolist() + assert d1.dictionary.to_pylist() == dictionary.tolist() + + for i in range(len(indices)): + assert d1[i].as_py() == dictionary[indices[i]] + + +def test_dictionary_from_arrays_boundscheck(): + indices1 = pa.array([0, 1, 2, 0, 1, 2]) + indices2 = pa.array([0, -1, 2]) + indices3 = pa.array([0, 1, 2, 3]) + + dictionary = pa.array(['foo', 'bar', 'baz']) + + # Works fine + pa.DictionaryArray.from_arrays(indices1, dictionary) + + with pytest.raises(pa.ArrowException): + pa.DictionaryArray.from_arrays(indices2, dictionary) + + with pytest.raises(pa.ArrowException): + pa.DictionaryArray.from_arrays(indices3, dictionary) + + # If we are confident that the indices are "safe" we can pass safe=False to + # disable the boundschecking + pa.DictionaryArray.from_arrays(indices2, dictionary, safe=False) + + +def test_dictionary_indices(): + # https://issues.apache.org/jira/browse/ARROW-6882 + indices = pa.array([0, 1, 2, 0, 1, 2]) + dictionary = pa.array(['foo', 'bar', 'baz']) + arr = pa.DictionaryArray.from_arrays(indices, dictionary) + arr.indices.validate(full=True) + + +@pytest.mark.parametrize(('list_array_type', 'list_type_factory'), + [(pa.ListArray, pa.list_), + (pa.LargeListArray, pa.large_list)]) +def test_list_from_arrays(list_array_type, list_type_factory): + offsets_arr = np.array([0, 2, 5, 8], dtype='i4') + offsets = pa.array(offsets_arr, type='int32') + pyvalues = [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h'] + values = pa.array(pyvalues, type='binary') + + result = list_array_type.from_arrays(offsets, values) + expected = pa.array([pyvalues[:2], pyvalues[2:5], pyvalues[5:8]], + type=list_type_factory(pa.binary())) + + assert result.equals(expected) + + # With nulls + offsets = [0, None, 2, 6] + values = [b'a', b'b', b'c', b'd', b'e', b'f'] + + result = list_array_type.from_arrays(offsets, values) + expected = pa.array([values[:2], None, values[2:]], + type=list_type_factory(pa.binary())) + + assert result.equals(expected) + + # Another edge case + offsets2 = [0, 2, None, 6] + result = list_array_type.from_arrays(offsets2, values) + expected = pa.array([values[:2], values[2:], None], + type=list_type_factory(pa.binary())) + assert result.equals(expected) + + # raise on invalid array + offsets = [1, 3, 10] + values = np.arange(5) + with pytest.raises(ValueError): + list_array_type.from_arrays(offsets, values) + + # Non-monotonic offsets + offsets = [0, 3, 2, 6] + values = list(range(6)) + result = list_array_type.from_arrays(offsets, values) + with pytest.raises(ValueError): + result.validate(full=True) + + +def test_map_from_arrays(): + offsets_arr = np.array([0, 2, 5, 8], dtype='i4') + offsets = pa.array(offsets_arr, type='int32') + pykeys = [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h'] + pyitems = list(range(len(pykeys))) + pypairs = list(zip(pykeys, pyitems)) + pyentries = [pypairs[:2], pypairs[2:5], pypairs[5:8]] + keys = pa.array(pykeys, type='binary') + items = pa.array(pyitems, type='i4') + + result = pa.MapArray.from_arrays(offsets, keys, items) + expected = pa.array(pyentries, type=pa.map_(pa.binary(), pa.int32())) + + assert result.equals(expected) + + # With nulls + offsets = [0, None, 2, 6] + pykeys = [b'a', b'b', b'c', b'd', b'e', b'f'] + pyitems = [1, 2, 3, None, 4, 5] + pypairs = list(zip(pykeys, pyitems)) + pyentries = [pypairs[:2], None, pypairs[2:]] + keys = pa.array(pykeys, type='binary') + items = pa.array(pyitems, type='i4') + + result = pa.MapArray.from_arrays(offsets, keys, items) + expected = pa.array(pyentries, type=pa.map_(pa.binary(), pa.int32())) + + assert result.equals(expected) + + # check invalid usage + + offsets = [0, 1, 3, 5] + keys = np.arange(5) + items = np.arange(5) + _ = pa.MapArray.from_arrays(offsets, keys, items) + + # raise on invalid offsets + with pytest.raises(ValueError): + pa.MapArray.from_arrays(offsets + [6], keys, items) + + # raise on length of keys != items + with pytest.raises(ValueError): + pa.MapArray.from_arrays(offsets, keys, np.concatenate([items, items])) + + # raise on keys with null + keys_with_null = list(keys)[:-1] + [None] + assert len(keys_with_null) == len(items) + with pytest.raises(ValueError): + pa.MapArray.from_arrays(offsets, keys_with_null, items) + + +def test_fixed_size_list_from_arrays(): + values = pa.array(range(12), pa.int64()) + result = pa.FixedSizeListArray.from_arrays(values, 4) + assert result.to_pylist() == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] + assert result.type.equals(pa.list_(pa.int64(), 4)) + + # raise on invalid values / list_size + with pytest.raises(ValueError): + pa.FixedSizeListArray.from_arrays(values, -4) + + with pytest.raises(ValueError): + # array with list size 0 cannot be constructed with from_arrays + pa.FixedSizeListArray.from_arrays(pa.array([], pa.int64()), 0) + + with pytest.raises(ValueError): + # length of values not multiple of 5 + pa.FixedSizeListArray.from_arrays(values, 5) + + +def test_variable_list_from_arrays(): + values = pa.array([1, 2, 3, 4], pa.int64()) + offsets = pa.array([0, 2, 4]) + result = pa.ListArray.from_arrays(offsets, values) + assert result.to_pylist() == [[1, 2], [3, 4]] + assert result.type.equals(pa.list_(pa.int64())) + + offsets = pa.array([0, None, 2, 4]) + result = pa.ListArray.from_arrays(offsets, values) + assert result.to_pylist() == [[1, 2], None, [3, 4]] + + # raise if offset out of bounds + with pytest.raises(ValueError): + pa.ListArray.from_arrays(pa.array([-1, 2, 4]), values) + + with pytest.raises(ValueError): + pa.ListArray.from_arrays(pa.array([0, 2, 5]), values) + + +def test_union_from_dense(): + binary = pa.array([b'a', b'b', b'c', b'd'], type='binary') + int64 = pa.array([1, 2, 3], type='int64') + types = pa.array([0, 1, 0, 0, 1, 1, 0], type='int8') + logical_types = pa.array([11, 13, 11, 11, 13, 13, 11], type='int8') + value_offsets = pa.array([0, 0, 1, 2, 1, 2, 3], type='int32') + py_value = [b'a', 1, b'b', b'c', 2, 3, b'd'] + + def check_result(result, expected_field_names, expected_type_codes, + expected_type_code_values): + result.validate(full=True) + actual_field_names = [result.type[i].name + for i in range(result.type.num_fields)] + assert actual_field_names == expected_field_names + assert result.type.mode == "dense" + assert result.type.type_codes == expected_type_codes + assert result.to_pylist() == py_value + assert expected_type_code_values.equals(result.type_codes) + assert value_offsets.equals(result.offsets) + assert result.field(0).equals(binary) + assert result.field(1).equals(int64) + with pytest.raises(KeyError): + result.field(-1) + with pytest.raises(KeyError): + result.field(2) + + # without field names and type codes + check_result(pa.UnionArray.from_dense(types, value_offsets, + [binary, int64]), + expected_field_names=['0', '1'], + expected_type_codes=[0, 1], + expected_type_code_values=types) + + # with field names + check_result(pa.UnionArray.from_dense(types, value_offsets, + [binary, int64], + ['bin', 'int']), + expected_field_names=['bin', 'int'], + expected_type_codes=[0, 1], + expected_type_code_values=types) + + # with type codes + check_result(pa.UnionArray.from_dense(logical_types, value_offsets, + [binary, int64], + type_codes=[11, 13]), + expected_field_names=['0', '1'], + expected_type_codes=[11, 13], + expected_type_code_values=logical_types) + + # with field names and type codes + check_result(pa.UnionArray.from_dense(logical_types, value_offsets, + [binary, int64], + ['bin', 'int'], [11, 13]), + expected_field_names=['bin', 'int'], + expected_type_codes=[11, 13], + expected_type_code_values=logical_types) + + # Bad type ids + arr = pa.UnionArray.from_dense(logical_types, value_offsets, + [binary, int64]) + with pytest.raises(pa.ArrowInvalid): + arr.validate(full=True) + arr = pa.UnionArray.from_dense(types, value_offsets, [binary, int64], + type_codes=[11, 13]) + with pytest.raises(pa.ArrowInvalid): + arr.validate(full=True) + + # Offset larger than child size + bad_offsets = pa.array([0, 0, 1, 2, 1, 2, 4], type='int32') + arr = pa.UnionArray.from_dense(types, bad_offsets, [binary, int64]) + with pytest.raises(pa.ArrowInvalid): + arr.validate(full=True) + + +def test_union_from_sparse(): + binary = pa.array([b'a', b' ', b'b', b'c', b' ', b' ', b'd'], + type='binary') + int64 = pa.array([0, 1, 0, 0, 2, 3, 0], type='int64') + types = pa.array([0, 1, 0, 0, 1, 1, 0], type='int8') + logical_types = pa.array([11, 13, 11, 11, 13, 13, 11], type='int8') + py_value = [b'a', 1, b'b', b'c', 2, 3, b'd'] + + def check_result(result, expected_field_names, expected_type_codes, + expected_type_code_values): + result.validate(full=True) + assert result.to_pylist() == py_value + actual_field_names = [result.type[i].name + for i in range(result.type.num_fields)] + assert actual_field_names == expected_field_names + assert result.type.mode == "sparse" + assert result.type.type_codes == expected_type_codes + assert expected_type_code_values.equals(result.type_codes) + assert result.field(0).equals(binary) + assert result.field(1).equals(int64) + with pytest.raises(pa.ArrowTypeError): + result.offsets + with pytest.raises(KeyError): + result.field(-1) + with pytest.raises(KeyError): + result.field(2) + + # without field names and type codes + check_result(pa.UnionArray.from_sparse(types, [binary, int64]), + expected_field_names=['0', '1'], + expected_type_codes=[0, 1], + expected_type_code_values=types) + + # with field names + check_result(pa.UnionArray.from_sparse(types, [binary, int64], + ['bin', 'int']), + expected_field_names=['bin', 'int'], + expected_type_codes=[0, 1], + expected_type_code_values=types) + + # with type codes + check_result(pa.UnionArray.from_sparse(logical_types, [binary, int64], + type_codes=[11, 13]), + expected_field_names=['0', '1'], + expected_type_codes=[11, 13], + expected_type_code_values=logical_types) + + # with field names and type codes + check_result(pa.UnionArray.from_sparse(logical_types, [binary, int64], + ['bin', 'int'], + [11, 13]), + expected_field_names=['bin', 'int'], + expected_type_codes=[11, 13], + expected_type_code_values=logical_types) + + # Bad type ids + arr = pa.UnionArray.from_sparse(logical_types, [binary, int64]) + with pytest.raises(pa.ArrowInvalid): + arr.validate(full=True) + arr = pa.UnionArray.from_sparse(types, [binary, int64], + type_codes=[11, 13]) + with pytest.raises(pa.ArrowInvalid): + arr.validate(full=True) + + # Invalid child length + with pytest.raises(pa.ArrowInvalid): + arr = pa.UnionArray.from_sparse(logical_types, [binary, int64[1:]]) + + +def test_union_array_to_pylist_with_nulls(): + # ARROW-9556 + arr = pa.UnionArray.from_sparse( + pa.array([0, 1, 0, 0, 1], type=pa.int8()), + [ + pa.array([0.0, 1.1, None, 3.3, 4.4]), + pa.array([True, None, False, True, False]), + ] + ) + assert arr.to_pylist() == [0.0, None, None, 3.3, False] + + arr = pa.UnionArray.from_dense( + pa.array([0, 1, 0, 0, 0, 1, 1], type=pa.int8()), + pa.array([0, 0, 1, 2, 3, 1, 2], type=pa.int32()), + [ + pa.array([0.0, 1.1, None, 3.3]), + pa.array([True, None, False]) + ] + ) + assert arr.to_pylist() == [0.0, True, 1.1, None, 3.3, None, False] + + +def test_union_array_slice(): + # ARROW-2314 + arr = pa.UnionArray.from_sparse(pa.array([0, 0, 1, 1], type=pa.int8()), + [pa.array(["a", "b", "c", "d"]), + pa.array([1, 2, 3, 4])]) + assert arr[1:].to_pylist() == ["b", 3, 4] + + binary = pa.array([b'a', b'b', b'c', b'd'], type='binary') + int64 = pa.array([1, 2, 3], type='int64') + types = pa.array([0, 1, 0, 0, 1, 1, 0], type='int8') + value_offsets = pa.array([0, 0, 2, 1, 1, 2, 3], type='int32') + + arr = pa.UnionArray.from_dense(types, value_offsets, [binary, int64]) + lst = arr.to_pylist() + for i in range(len(arr)): + for j in range(i, len(arr)): + assert arr[i:j].to_pylist() == lst[i:j] + + +def _check_cast_case(case, *, safe=True, check_array_construction=True): + in_data, in_type, out_data, out_type = case + if isinstance(out_data, pa.Array): + assert out_data.type == out_type + expected = out_data + else: + expected = pa.array(out_data, type=out_type) + + # check casting an already created array + if isinstance(in_data, pa.Array): + assert in_data.type == in_type + in_arr = in_data + else: + in_arr = pa.array(in_data, type=in_type) + casted = in_arr.cast(out_type, safe=safe) + casted.validate(full=True) + assert casted.equals(expected) + + # constructing an array with out type which optionally involves casting + # for more see ARROW-1949 + if check_array_construction: + in_arr = pa.array(in_data, type=out_type, safe=safe) + assert in_arr.equals(expected) + + +def test_cast_integers_safe(): + safe_cases = [ + (np.array([0, 1, 2, 3], dtype='i1'), 'int8', + np.array([0, 1, 2, 3], dtype='i4'), pa.int32()), + (np.array([0, 1, 2, 3], dtype='i1'), 'int8', + np.array([0, 1, 2, 3], dtype='u4'), pa.uint16()), + (np.array([0, 1, 2, 3], dtype='i1'), 'int8', + np.array([0, 1, 2, 3], dtype='u1'), pa.uint8()), + (np.array([0, 1, 2, 3], dtype='i1'), 'int8', + np.array([0, 1, 2, 3], dtype='f8'), pa.float64()) + ] + + for case in safe_cases: + _check_cast_case(case) + + unsafe_cases = [ + (np.array([50000], dtype='i4'), 'int32', 'int16'), + (np.array([70000], dtype='i4'), 'int32', 'uint16'), + (np.array([-1], dtype='i4'), 'int32', 'uint16'), + (np.array([50000], dtype='u2'), 'uint16', 'int16') + ] + for in_data, in_type, out_type in unsafe_cases: + in_arr = pa.array(in_data, type=in_type) + + with pytest.raises(pa.ArrowInvalid): + in_arr.cast(out_type) + + +def test_cast_none(): + # ARROW-3735: Ensure that calling cast(None) doesn't segfault. + arr = pa.array([1, 2, 3]) + + with pytest.raises(ValueError): + arr.cast(None) + + +def test_cast_list_to_primitive(): + # ARROW-8070: cast segfaults on unsupported cast from list<binary> to utf8 + arr = pa.array([[1, 2], [3, 4]]) + with pytest.raises(NotImplementedError): + arr.cast(pa.int8()) + + arr = pa.array([[b"a", b"b"], [b"c"]], pa.list_(pa.binary())) + with pytest.raises(NotImplementedError): + arr.cast(pa.binary()) + + +def test_slice_chunked_array_zero_chunks(): + # ARROW-8911 + arr = pa.chunked_array([], type='int8') + assert arr.num_chunks == 0 + + result = arr[:] + assert result.equals(arr) + + # Do not crash + arr[:5] + + +def test_cast_chunked_array(): + arrays = [pa.array([1, 2, 3]), pa.array([4, 5, 6])] + carr = pa.chunked_array(arrays) + + target = pa.float64() + casted = carr.cast(target) + expected = pa.chunked_array([x.cast(target) for x in arrays]) + assert casted.equals(expected) + + +def test_cast_chunked_array_empty(): + # ARROW-8142 + for typ1, typ2 in [(pa.dictionary(pa.int8(), pa.string()), pa.string()), + (pa.int64(), pa.int32())]: + + arr = pa.chunked_array([], type=typ1) + result = arr.cast(typ2) + expected = pa.chunked_array([], type=typ2) + assert result.equals(expected) + + +def test_chunked_array_data_warns(): + with pytest.warns(FutureWarning): + res = pa.chunked_array([[]]).data + assert isinstance(res, pa.ChunkedArray) + + +def test_cast_integers_unsafe(): + # We let NumPy do the unsafe casting + unsafe_cases = [ + (np.array([50000], dtype='i4'), 'int32', + np.array([50000], dtype='i2'), pa.int16()), + (np.array([70000], dtype='i4'), 'int32', + np.array([70000], dtype='u2'), pa.uint16()), + (np.array([-1], dtype='i4'), 'int32', + np.array([-1], dtype='u2'), pa.uint16()), + (np.array([50000], dtype='u2'), pa.uint16(), + np.array([50000], dtype='i2'), pa.int16()) + ] + + for case in unsafe_cases: + _check_cast_case(case, safe=False) + + +def test_floating_point_truncate_safe(): + safe_cases = [ + (np.array([1.0, 2.0, 3.0], dtype='float32'), 'float32', + np.array([1, 2, 3], dtype='i4'), pa.int32()), + (np.array([1.0, 2.0, 3.0], dtype='float64'), 'float64', + np.array([1, 2, 3], dtype='i4'), pa.int32()), + (np.array([-10.0, 20.0, -30.0], dtype='float64'), 'float64', + np.array([-10, 20, -30], dtype='i4'), pa.int32()), + ] + for case in safe_cases: + _check_cast_case(case, safe=True) + + +def test_floating_point_truncate_unsafe(): + unsafe_cases = [ + (np.array([1.1, 2.2, 3.3], dtype='float32'), 'float32', + np.array([1, 2, 3], dtype='i4'), pa.int32()), + (np.array([1.1, 2.2, 3.3], dtype='float64'), 'float64', + np.array([1, 2, 3], dtype='i4'), pa.int32()), + (np.array([-10.1, 20.2, -30.3], dtype='float64'), 'float64', + np.array([-10, 20, -30], dtype='i4'), pa.int32()), + ] + for case in unsafe_cases: + # test safe casting raises + with pytest.raises(pa.ArrowInvalid, match='truncated'): + _check_cast_case(case, safe=True) + + # test unsafe casting truncates + _check_cast_case(case, safe=False) + + +def test_decimal_to_int_safe(): + safe_cases = [ + ( + [decimal.Decimal("123456"), None, decimal.Decimal("-912345")], + pa.decimal128(32, 5), + [123456, None, -912345], + pa.int32() + ), + ( + [decimal.Decimal("1234"), None, decimal.Decimal("-9123")], + pa.decimal128(19, 10), + [1234, None, -9123], + pa.int16() + ), + ( + [decimal.Decimal("123"), None, decimal.Decimal("-91")], + pa.decimal128(19, 10), + [123, None, -91], + pa.int8() + ), + ] + for case in safe_cases: + _check_cast_case(case) + _check_cast_case(case, safe=True) + + +def test_decimal_to_int_value_out_of_bounds(): + out_of_bounds_cases = [ + ( + np.array([ + decimal.Decimal("1234567890123"), + None, + decimal.Decimal("-912345678901234") + ]), + pa.decimal128(32, 5), + [1912276171, None, -135950322], + pa.int32() + ), + ( + [decimal.Decimal("123456"), None, decimal.Decimal("-912345678")], + pa.decimal128(32, 5), + [-7616, None, -19022], + pa.int16() + ), + ( + [decimal.Decimal("1234"), None, decimal.Decimal("-9123")], + pa.decimal128(32, 5), + [-46, None, 93], + pa.int8() + ), + ] + + for case in out_of_bounds_cases: + # test safe casting raises + with pytest.raises(pa.ArrowInvalid, + match='Integer value out of bounds'): + _check_cast_case(case) + + # XXX `safe=False` can be ignored when constructing an array + # from a sequence of Python objects (ARROW-8567) + _check_cast_case(case, safe=False, check_array_construction=False) + + +def test_decimal_to_int_non_integer(): + non_integer_cases = [ + ( + [ + decimal.Decimal("123456.21"), + None, + decimal.Decimal("-912345.13") + ], + pa.decimal128(32, 5), + [123456, None, -912345], + pa.int32() + ), + ( + [decimal.Decimal("1234.134"), None, decimal.Decimal("-9123.1")], + pa.decimal128(19, 10), + [1234, None, -9123], + pa.int16() + ), + ( + [decimal.Decimal("123.1451"), None, decimal.Decimal("-91.21")], + pa.decimal128(19, 10), + [123, None, -91], + pa.int8() + ), + ] + + for case in non_integer_cases: + # test safe casting raises + msg_regexp = 'Rescaling Decimal128 value would cause data loss' + with pytest.raises(pa.ArrowInvalid, match=msg_regexp): + _check_cast_case(case) + + _check_cast_case(case, safe=False) + + +def test_decimal_to_decimal(): + arr = pa.array( + [decimal.Decimal("1234.12"), None], + type=pa.decimal128(19, 10) + ) + result = arr.cast(pa.decimal128(15, 6)) + expected = pa.array( + [decimal.Decimal("1234.12"), None], + type=pa.decimal128(15, 6) + ) + assert result.equals(expected) + + msg_regexp = 'Rescaling Decimal128 value would cause data loss' + with pytest.raises(pa.ArrowInvalid, match=msg_regexp): + result = arr.cast(pa.decimal128(9, 1)) + + result = arr.cast(pa.decimal128(9, 1), safe=False) + expected = pa.array( + [decimal.Decimal("1234.1"), None], + type=pa.decimal128(9, 1) + ) + assert result.equals(expected) + + with pytest.raises(pa.ArrowInvalid, + match='Decimal value does not fit in precision'): + result = arr.cast(pa.decimal128(5, 2)) + + +def test_safe_cast_nan_to_int_raises(): + arr = pa.array([np.nan, 1.]) + + with pytest.raises(pa.ArrowInvalid, match='truncated'): + arr.cast(pa.int64(), safe=True) + + +def test_cast_signed_to_unsigned(): + safe_cases = [ + (np.array([0, 1, 2, 3], dtype='i1'), pa.uint8(), + np.array([0, 1, 2, 3], dtype='u1'), pa.uint8()), + (np.array([0, 1, 2, 3], dtype='i2'), pa.uint16(), + np.array([0, 1, 2, 3], dtype='u2'), pa.uint16()) + ] + + for case in safe_cases: + _check_cast_case(case) + + +def test_cast_from_null(): + in_data = [None] * 3 + in_type = pa.null() + out_types = [ + pa.null(), + pa.uint8(), + pa.float16(), + pa.utf8(), + pa.binary(), + pa.binary(10), + pa.list_(pa.int16()), + pa.list_(pa.int32(), 4), + pa.large_list(pa.uint8()), + pa.decimal128(19, 4), + pa.timestamp('us'), + pa.timestamp('us', tz='UTC'), + pa.timestamp('us', tz='Europe/Paris'), + pa.duration('us'), + pa.month_day_nano_interval(), + pa.struct([pa.field('a', pa.int32()), + pa.field('b', pa.list_(pa.int8())), + pa.field('c', pa.string())]), + pa.dictionary(pa.int32(), pa.string()), + ] + for out_type in out_types: + _check_cast_case((in_data, in_type, in_data, out_type)) + + out_types = [ + + pa.union([pa.field('a', pa.binary(10)), + pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE), + pa.union([pa.field('a', pa.binary(10)), + pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE), + ] + in_arr = pa.array(in_data, type=pa.null()) + for out_type in out_types: + with pytest.raises(NotImplementedError): + in_arr.cast(out_type) + + +def test_cast_string_to_number_roundtrip(): + cases = [ + (pa.array(["1", "127", "-128"]), + pa.array([1, 127, -128], type=pa.int8())), + (pa.array([None, "18446744073709551615"]), + pa.array([None, 18446744073709551615], type=pa.uint64())), + ] + for in_arr, expected in cases: + casted = in_arr.cast(expected.type, safe=True) + casted.validate(full=True) + assert casted.equals(expected) + casted_back = casted.cast(in_arr.type, safe=True) + casted_back.validate(full=True) + assert casted_back.equals(in_arr) + + +def test_cast_dictionary(): + # cast to the value type + arr = pa.array( + ["foo", "bar", None], + type=pa.dictionary(pa.int64(), pa.string()) + ) + expected = pa.array(["foo", "bar", None]) + assert arr.type == pa.dictionary(pa.int64(), pa.string()) + assert arr.cast(pa.string()) == expected + + # cast to a different key type + for key_type in [pa.int8(), pa.int16(), pa.int32()]: + typ = pa.dictionary(key_type, pa.string()) + expected = pa.array( + ["foo", "bar", None], + type=pa.dictionary(key_type, pa.string()) + ) + assert arr.cast(typ) == expected + + # shouldn't crash (ARROW-7077) + with pytest.raises(pa.ArrowInvalid): + arr.cast(pa.int32()) + + +def test_view(): + # ARROW-5992 + arr = pa.array(['foo', 'bar', 'baz'], type=pa.utf8()) + expected = pa.array(['foo', 'bar', 'baz'], type=pa.binary()) + + assert arr.view(pa.binary()).equals(expected) + assert arr.view('binary').equals(expected) + + +def test_unique_simple(): + cases = [ + (pa.array([1, 2, 3, 1, 2, 3]), pa.array([1, 2, 3])), + (pa.array(['foo', None, 'bar', 'foo']), + pa.array(['foo', None, 'bar'])), + (pa.array(['foo', None, 'bar', 'foo'], pa.large_binary()), + pa.array(['foo', None, 'bar'], pa.large_binary())), + ] + for arr, expected in cases: + result = arr.unique() + assert result.equals(expected) + result = pa.chunked_array([arr]).unique() + assert result.equals(expected) + + +def test_value_counts_simple(): + cases = [ + (pa.array([1, 2, 3, 1, 2, 3]), + pa.array([1, 2, 3]), + pa.array([2, 2, 2], type=pa.int64())), + (pa.array(['foo', None, 'bar', 'foo']), + pa.array(['foo', None, 'bar']), + pa.array([2, 1, 1], type=pa.int64())), + (pa.array(['foo', None, 'bar', 'foo'], pa.large_binary()), + pa.array(['foo', None, 'bar'], pa.large_binary()), + pa.array([2, 1, 1], type=pa.int64())), + ] + for arr, expected_values, expected_counts in cases: + for arr_in in (arr, pa.chunked_array([arr])): + result = arr_in.value_counts() + assert result.type.equals( + pa.struct([pa.field("values", arr.type), + pa.field("counts", pa.int64())])) + assert result.field("values").equals(expected_values) + assert result.field("counts").equals(expected_counts) + + +def test_unique_value_counts_dictionary_type(): + indices = pa.array([3, 0, 0, 0, 1, 1, 3, 0, 1, 3, 0, 1]) + dictionary = pa.array(['foo', 'bar', 'baz', 'qux']) + + arr = pa.DictionaryArray.from_arrays(indices, dictionary) + + unique_result = arr.unique() + expected = pa.DictionaryArray.from_arrays(indices.unique(), dictionary) + assert unique_result.equals(expected) + + result = arr.value_counts() + assert result.field('values').equals(unique_result) + assert result.field('counts').equals(pa.array([3, 5, 4], type='int64')) + + arr = pa.DictionaryArray.from_arrays( + pa.array([], type='int64'), dictionary) + unique_result = arr.unique() + expected = pa.DictionaryArray.from_arrays(pa.array([], type='int64'), + pa.array([], type='utf8')) + assert unique_result.equals(expected) + + result = arr.value_counts() + assert result.field('values').equals(unique_result) + assert result.field('counts').equals(pa.array([], type='int64')) + + +def test_dictionary_encode_simple(): + cases = [ + (pa.array([1, 2, 3, None, 1, 2, 3]), + pa.DictionaryArray.from_arrays( + pa.array([0, 1, 2, None, 0, 1, 2], type='int32'), + [1, 2, 3])), + (pa.array(['foo', None, 'bar', 'foo']), + pa.DictionaryArray.from_arrays( + pa.array([0, None, 1, 0], type='int32'), + ['foo', 'bar'])), + (pa.array(['foo', None, 'bar', 'foo'], type=pa.large_binary()), + pa.DictionaryArray.from_arrays( + pa.array([0, None, 1, 0], type='int32'), + pa.array(['foo', 'bar'], type=pa.large_binary()))), + ] + for arr, expected in cases: + result = arr.dictionary_encode() + assert result.equals(expected) + result = pa.chunked_array([arr]).dictionary_encode() + assert result.num_chunks == 1 + assert result.chunk(0).equals(expected) + result = pa.chunked_array([], type=arr.type).dictionary_encode() + assert result.num_chunks == 0 + assert result.type == expected.type + + +def test_dictionary_encode_sliced(): + cases = [ + (pa.array([1, 2, 3, None, 1, 2, 3])[1:-1], + pa.DictionaryArray.from_arrays( + pa.array([0, 1, None, 2, 0], type='int32'), + [2, 3, 1])), + (pa.array([None, 'foo', 'bar', 'foo', 'xyzzy'])[1:-1], + pa.DictionaryArray.from_arrays( + pa.array([0, 1, 0], type='int32'), + ['foo', 'bar'])), + (pa.array([None, 'foo', 'bar', 'foo', 'xyzzy'], + type=pa.large_string())[1:-1], + pa.DictionaryArray.from_arrays( + pa.array([0, 1, 0], type='int32'), + pa.array(['foo', 'bar'], type=pa.large_string()))), + ] + for arr, expected in cases: + result = arr.dictionary_encode() + assert result.equals(expected) + result = pa.chunked_array([arr]).dictionary_encode() + assert result.num_chunks == 1 + assert result.type == expected.type + assert result.chunk(0).equals(expected) + result = pa.chunked_array([], type=arr.type).dictionary_encode() + assert result.num_chunks == 0 + assert result.type == expected.type + + # ARROW-9143 dictionary_encode after slice was segfaulting + array = pa.array(['foo', 'bar', 'baz']) + array.slice(1).dictionary_encode() + + +def test_dictionary_encode_zero_length(): + # User-facing experience of ARROW-7008 + arr = pa.array([], type=pa.string()) + encoded = arr.dictionary_encode() + assert len(encoded.dictionary) == 0 + encoded.validate(full=True) + + +def test_dictionary_decode(): + cases = [ + (pa.array([1, 2, 3, None, 1, 2, 3]), + pa.DictionaryArray.from_arrays( + pa.array([0, 1, 2, None, 0, 1, 2], type='int32'), + [1, 2, 3])), + (pa.array(['foo', None, 'bar', 'foo']), + pa.DictionaryArray.from_arrays( + pa.array([0, None, 1, 0], type='int32'), + ['foo', 'bar'])), + (pa.array(['foo', None, 'bar', 'foo'], type=pa.large_binary()), + pa.DictionaryArray.from_arrays( + pa.array([0, None, 1, 0], type='int32'), + pa.array(['foo', 'bar'], type=pa.large_binary()))), + ] + for expected, arr in cases: + result = arr.dictionary_decode() + assert result.equals(expected) + + +def test_cast_time32_to_int(): + arr = pa.array(np.array([0, 1, 2], dtype='int32'), + type=pa.time32('s')) + expected = pa.array([0, 1, 2], type='i4') + + result = arr.cast('i4') + assert result.equals(expected) + + +def test_cast_time64_to_int(): + arr = pa.array(np.array([0, 1, 2], dtype='int64'), + type=pa.time64('us')) + expected = pa.array([0, 1, 2], type='i8') + + result = arr.cast('i8') + assert result.equals(expected) + + +def test_cast_timestamp_to_int(): + arr = pa.array(np.array([0, 1, 2], dtype='int64'), + type=pa.timestamp('us')) + expected = pa.array([0, 1, 2], type='i8') + + result = arr.cast('i8') + assert result.equals(expected) + + +def test_cast_date32_to_int(): + arr = pa.array([0, 1, 2], type='i4') + + result1 = arr.cast('date32') + result2 = result1.cast('i4') + + expected1 = pa.array([ + datetime.date(1970, 1, 1), + datetime.date(1970, 1, 2), + datetime.date(1970, 1, 3) + ]).cast('date32') + + assert result1.equals(expected1) + assert result2.equals(arr) + + +def test_cast_duration_to_int(): + arr = pa.array(np.array([0, 1, 2], dtype='int64'), + type=pa.duration('us')) + expected = pa.array([0, 1, 2], type='i8') + + result = arr.cast('i8') + assert result.equals(expected) + + +def test_cast_binary_to_utf8(): + binary_arr = pa.array([b'foo', b'bar', b'baz'], type=pa.binary()) + utf8_arr = binary_arr.cast(pa.utf8()) + expected = pa.array(['foo', 'bar', 'baz'], type=pa.utf8()) + + assert utf8_arr.equals(expected) + + non_utf8_values = [('mañana').encode('utf-16-le')] + non_utf8_binary = pa.array(non_utf8_values) + assert non_utf8_binary.type == pa.binary() + with pytest.raises(ValueError): + non_utf8_binary.cast(pa.string()) + + non_utf8_all_null = pa.array(non_utf8_values, mask=np.array([True]), + type=pa.binary()) + # No error + casted = non_utf8_all_null.cast(pa.string()) + assert casted.null_count == 1 + + +def test_cast_date64_to_int(): + arr = pa.array(np.array([0, 1, 2], dtype='int64'), + type=pa.date64()) + expected = pa.array([0, 1, 2], type='i8') + + result = arr.cast('i8') + + assert result.equals(expected) + + +def test_date64_from_builtin_datetime(): + val1 = datetime.datetime(2000, 1, 1, 12, 34, 56, 123456) + val2 = datetime.datetime(2000, 1, 1) + result = pa.array([val1, val2], type='date64') + result2 = pa.array([val1.date(), val2.date()], type='date64') + + assert result.equals(result2) + + as_i8 = result.view('int64') + assert as_i8[0].as_py() == as_i8[1].as_py() + + +@pytest.mark.parametrize(('ty', 'values'), [ + ('bool', [True, False, True]), + ('uint8', range(0, 255)), + ('int8', range(0, 128)), + ('uint16', range(0, 10)), + ('int16', range(0, 10)), + ('uint32', range(0, 10)), + ('int32', range(0, 10)), + ('uint64', range(0, 10)), + ('int64', range(0, 10)), + ('float', [0.0, 0.1, 0.2]), + ('double', [0.0, 0.1, 0.2]), + ('string', ['a', 'b', 'c']), + ('binary', [b'a', b'b', b'c']), + (pa.binary(3), [b'abc', b'bcd', b'cde']) +]) +def test_cast_identities(ty, values): + arr = pa.array(values, type=ty) + assert arr.cast(ty).equals(arr) + + +pickle_test_parametrize = pytest.mark.parametrize( + ('data', 'typ'), + [ + ([True, False, True, True], pa.bool_()), + ([1, 2, 4, 6], pa.int64()), + ([1.0, 2.5, None], pa.float64()), + (['a', None, 'b'], pa.string()), + ([], None), + ([[1, 2], [3]], pa.list_(pa.int64())), + ([[4, 5], [6]], pa.large_list(pa.int16())), + ([['a'], None, ['b', 'c']], pa.list_(pa.string())), + ([(1, 'a'), (2, 'c'), None], + pa.struct([pa.field('a', pa.int64()), pa.field('b', pa.string())])) + ] +) + + +@pickle_test_parametrize +def test_array_pickle(data, typ): + # Allocate here so that we don't have any Arrow data allocated. + # This is needed to ensure that allocator tests can be reliable. + array = pa.array(data, type=typ) + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + result = pickle.loads(pickle.dumps(array, proto)) + assert array.equals(result) + + +def test_array_pickle_dictionary(): + # not included in the above as dictionary array cannot be created with + # the pa.array function + array = pa.DictionaryArray.from_arrays([0, 1, 2, 0, 1], ['a', 'b', 'c']) + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + result = pickle.loads(pickle.dumps(array, proto)) + assert array.equals(result) + + +@h.given( + past.arrays( + past.all_types, + size=st.integers(min_value=0, max_value=10) + ) +) +def test_pickling(arr): + data = pickle.dumps(arr) + restored = pickle.loads(data) + assert arr.equals(restored) + + +@pickle_test_parametrize +def test_array_pickle5(data, typ): + # Test zero-copy pickling with protocol 5 (PEP 574) + picklemod = pickle5 or pickle + if pickle5 is None and picklemod.HIGHEST_PROTOCOL < 5: + pytest.skip("need pickle5 package or Python 3.8+") + + array = pa.array(data, type=typ) + addresses = [buf.address if buf is not None else 0 + for buf in array.buffers()] + + for proto in range(5, pickle.HIGHEST_PROTOCOL + 1): + buffers = [] + pickled = picklemod.dumps(array, proto, buffer_callback=buffers.append) + result = picklemod.loads(pickled, buffers=buffers) + assert array.equals(result) + + result_addresses = [buf.address if buf is not None else 0 + for buf in result.buffers()] + assert result_addresses == addresses + + +@pytest.mark.parametrize( + 'narr', + [ + np.arange(10, dtype=np.int64), + np.arange(10, dtype=np.int32), + np.arange(10, dtype=np.int16), + np.arange(10, dtype=np.int8), + np.arange(10, dtype=np.uint64), + np.arange(10, dtype=np.uint32), + np.arange(10, dtype=np.uint16), + np.arange(10, dtype=np.uint8), + np.arange(10, dtype=np.float64), + np.arange(10, dtype=np.float32), + np.arange(10, dtype=np.float16), + ] +) +def test_to_numpy_roundtrip(narr): + arr = pa.array(narr) + assert narr.dtype == arr.to_numpy().dtype + np.testing.assert_array_equal(narr, arr.to_numpy()) + np.testing.assert_array_equal(narr[:6], arr[:6].to_numpy()) + np.testing.assert_array_equal(narr[2:], arr[2:].to_numpy()) + np.testing.assert_array_equal(narr[2:6], arr[2:6].to_numpy()) + + +def test_array_uint64_from_py_over_range(): + arr = pa.array([2 ** 63], type=pa.uint64()) + expected = pa.array(np.array([2 ** 63], dtype='u8')) + assert arr.equals(expected) + + +def test_array_conversions_no_sentinel_values(): + arr = np.array([1, 2, 3, 4], dtype='int8') + refcount = sys.getrefcount(arr) + arr2 = pa.array(arr) # noqa + assert sys.getrefcount(arr) == (refcount + 1) + + assert arr2.type == 'int8' + + arr3 = pa.array(np.array([1, np.nan, 2, 3, np.nan, 4], dtype='float32'), + type='float32') + assert arr3.type == 'float32' + assert arr3.null_count == 0 + + +def test_time32_time64_from_integer(): + # ARROW-4111 + result = pa.array([1, 2, None], type=pa.time32('s')) + expected = pa.array([datetime.time(second=1), + datetime.time(second=2), None], + type=pa.time32('s')) + assert result.equals(expected) + + result = pa.array([1, 2, None], type=pa.time32('ms')) + expected = pa.array([datetime.time(microsecond=1000), + datetime.time(microsecond=2000), None], + type=pa.time32('ms')) + assert result.equals(expected) + + result = pa.array([1, 2, None], type=pa.time64('us')) + expected = pa.array([datetime.time(microsecond=1), + datetime.time(microsecond=2), None], + type=pa.time64('us')) + assert result.equals(expected) + + result = pa.array([1000, 2000, None], type=pa.time64('ns')) + expected = pa.array([datetime.time(microsecond=1), + datetime.time(microsecond=2), None], + type=pa.time64('ns')) + assert result.equals(expected) + + +def test_binary_string_pandas_null_sentinels(): + # ARROW-6227 + def _check_case(ty): + arr = pa.array(['string', np.nan], type=ty, from_pandas=True) + expected = pa.array(['string', None], type=ty) + assert arr.equals(expected) + _check_case('binary') + _check_case('utf8') + + +def test_pandas_null_sentinels_raise_error(): + # ARROW-6227 + cases = [ + ([None, np.nan], 'null'), + (['string', np.nan], 'binary'), + (['string', np.nan], 'utf8'), + (['string', np.nan], 'large_binary'), + (['string', np.nan], 'large_utf8'), + ([b'string', np.nan], pa.binary(6)), + ([True, np.nan], pa.bool_()), + ([decimal.Decimal('0'), np.nan], pa.decimal128(12, 2)), + ([0, np.nan], pa.date32()), + ([0, np.nan], pa.date32()), + ([0, np.nan], pa.date64()), + ([0, np.nan], pa.time32('s')), + ([0, np.nan], pa.time64('us')), + ([0, np.nan], pa.timestamp('us')), + ([0, np.nan], pa.duration('us')), + ] + for case, ty in cases: + # Both types of exceptions are raised. May want to clean that up + with pytest.raises((ValueError, TypeError)): + pa.array(case, type=ty) + + # from_pandas option suppresses failure + result = pa.array(case, type=ty, from_pandas=True) + assert result.null_count == (1 if ty != 'null' else 2) + + +@pytest.mark.pandas +def test_pandas_null_sentinels_index(): + # ARROW-7023 - ensure that when passing a pandas Index, "from_pandas" + # semantics are used + import pandas as pd + idx = pd.Index([1, 2, np.nan], dtype=object) + result = pa.array(idx) + expected = pa.array([1, 2, np.nan], from_pandas=True) + assert result.equals(expected) + + +def test_array_from_numpy_datetimeD(): + arr = np.array([None, datetime.date(2017, 4, 4)], dtype='datetime64[D]') + + result = pa.array(arr) + expected = pa.array([None, datetime.date(2017, 4, 4)], type=pa.date32()) + assert result.equals(expected) + + +def test_array_from_naive_datetimes(): + arr = pa.array([ + None, + datetime.datetime(2017, 4, 4, 12, 11, 10), + datetime.datetime(2018, 1, 1, 0, 2, 0) + ]) + assert arr.type == pa.timestamp('us', tz=None) + + +@pytest.mark.parametrize(('dtype', 'type'), [ + ('datetime64[s]', pa.timestamp('s')), + ('datetime64[ms]', pa.timestamp('ms')), + ('datetime64[us]', pa.timestamp('us')), + ('datetime64[ns]', pa.timestamp('ns')) +]) +def test_array_from_numpy_datetime(dtype, type): + data = [ + None, + datetime.datetime(2017, 4, 4, 12, 11, 10), + datetime.datetime(2018, 1, 1, 0, 2, 0) + ] + + # from numpy array + arr = pa.array(np.array(data, dtype=dtype)) + expected = pa.array(data, type=type) + assert arr.equals(expected) + + # from list of numpy scalars + arr = pa.array(list(np.array(data, dtype=dtype))) + assert arr.equals(expected) + + +def test_array_from_different_numpy_datetime_units_raises(): + data = [ + None, + datetime.datetime(2017, 4, 4, 12, 11, 10), + datetime.datetime(2018, 1, 1, 0, 2, 0) + ] + s = np.array(data, dtype='datetime64[s]') + ms = np.array(data, dtype='datetime64[ms]') + data = list(s[:2]) + list(ms[2:]) + + with pytest.raises(pa.ArrowNotImplementedError): + pa.array(data) + + +@pytest.mark.parametrize('unit', ['ns', 'us', 'ms', 's']) +def test_array_from_list_of_timestamps(unit): + n = np.datetime64('NaT', unit) + x = np.datetime64('2017-01-01 01:01:01.111111111', unit) + y = np.datetime64('2018-11-22 12:24:48.111111111', unit) + + a1 = pa.array([n, x, y]) + a2 = pa.array([n, x, y], type=pa.timestamp(unit)) + + assert a1.type == a2.type + assert a1.type.unit == unit + assert a1[0] == a2[0] + + +def test_array_from_timestamp_with_generic_unit(): + n = np.datetime64('NaT') + x = np.datetime64('2017-01-01 01:01:01.111111111') + y = np.datetime64('2018-11-22 12:24:48.111111111') + + with pytest.raises(pa.ArrowNotImplementedError, + match='Unbound or generic datetime64 time unit'): + pa.array([n, x, y]) + + +@pytest.mark.parametrize(('dtype', 'type'), [ + ('timedelta64[s]', pa.duration('s')), + ('timedelta64[ms]', pa.duration('ms')), + ('timedelta64[us]', pa.duration('us')), + ('timedelta64[ns]', pa.duration('ns')) +]) +def test_array_from_numpy_timedelta(dtype, type): + data = [ + None, + datetime.timedelta(1), + datetime.timedelta(0, 1) + ] + + # from numpy array + np_arr = np.array(data, dtype=dtype) + arr = pa.array(np_arr) + assert isinstance(arr, pa.DurationArray) + assert arr.type == type + expected = pa.array(data, type=type) + assert arr.equals(expected) + assert arr.to_pylist() == data + + # from list of numpy scalars + arr = pa.array(list(np.array(data, dtype=dtype))) + assert arr.equals(expected) + assert arr.to_pylist() == data + + +def test_array_from_numpy_timedelta_incorrect_unit(): + # generic (no unit) + td = np.timedelta64(1) + + for data in [[td], np.array([td])]: + with pytest.raises(NotImplementedError): + pa.array(data) + + # unsupported unit + td = np.timedelta64(1, 'M') + for data in [[td], np.array([td])]: + with pytest.raises(NotImplementedError): + pa.array(data) + + +def test_array_from_numpy_ascii(): + arr = np.array(['abcde', 'abc', ''], dtype='|S5') + + arrow_arr = pa.array(arr) + assert arrow_arr.type == 'binary' + expected = pa.array(['abcde', 'abc', ''], type='binary') + assert arrow_arr.equals(expected) + + mask = np.array([False, True, False]) + arrow_arr = pa.array(arr, mask=mask) + expected = pa.array(['abcde', None, ''], type='binary') + assert arrow_arr.equals(expected) + + # Strided variant + arr = np.array(['abcde', 'abc', ''] * 5, dtype='|S5')[::2] + mask = np.array([False, True, False] * 5)[::2] + arrow_arr = pa.array(arr, mask=mask) + + expected = pa.array(['abcde', '', None, 'abcde', '', None, 'abcde', ''], + type='binary') + assert arrow_arr.equals(expected) + + # 0 itemsize + arr = np.array(['', '', ''], dtype='|S0') + arrow_arr = pa.array(arr) + expected = pa.array(['', '', ''], type='binary') + assert arrow_arr.equals(expected) + + +def test_interval_array_from_timedelta(): + data = [ + None, + datetime.timedelta(days=1, seconds=1, microseconds=1, + milliseconds=1, minutes=1, hours=1, weeks=1)] + + # From timedelta (explicit type required) + arr = pa.array(data, pa.month_day_nano_interval()) + assert isinstance(arr, pa.MonthDayNanoIntervalArray) + assert arr.type == pa.month_day_nano_interval() + expected_list = [ + None, + pa.MonthDayNano([0, 8, + (datetime.timedelta(seconds=1, microseconds=1, + milliseconds=1, minutes=1, + hours=1) // + datetime.timedelta(microseconds=1)) * 1000])] + expected = pa.array(expected_list) + assert arr.equals(expected) + assert arr.to_pylist() == expected_list + + +@pytest.mark.pandas +def test_interval_array_from_relativedelta(): + # dateutil is dependency of pandas + from dateutil.relativedelta import relativedelta + from pandas import DateOffset + data = [ + None, + relativedelta(years=1, months=1, + days=1, seconds=1, microseconds=1, + minutes=1, hours=1, weeks=1, leapdays=1)] + # Note leapdays are ignored. + + # From relativedelta + arr = pa.array(data) + assert isinstance(arr, pa.MonthDayNanoIntervalArray) + assert arr.type == pa.month_day_nano_interval() + expected_list = [ + None, + pa.MonthDayNano([13, 8, + (datetime.timedelta(seconds=1, microseconds=1, + minutes=1, hours=1) // + datetime.timedelta(microseconds=1)) * 1000])] + expected = pa.array(expected_list) + assert arr.equals(expected) + assert arr.to_pandas().tolist() == [ + None, DateOffset(months=13, days=8, + microseconds=( + datetime.timedelta(seconds=1, microseconds=1, + minutes=1, hours=1) // + datetime.timedelta(microseconds=1)), + nanoseconds=0)] + with pytest.raises(ValueError): + pa.array([DateOffset(years=((1 << 32) // 12), months=100)]) + with pytest.raises(ValueError): + pa.array([DateOffset(weeks=((1 << 32) // 7), days=100)]) + with pytest.raises(ValueError): + pa.array([DateOffset(seconds=((1 << 64) // 1000000000), + nanoseconds=1)]) + with pytest.raises(ValueError): + pa.array([DateOffset(microseconds=((1 << 64) // 100))]) + + +@pytest.mark.pandas +def test_interval_array_from_dateoffset(): + from pandas.tseries.offsets import DateOffset + data = [ + None, + DateOffset(years=1, months=1, + days=1, seconds=1, microseconds=1, + minutes=1, hours=1, weeks=1, nanoseconds=1), + DateOffset()] + + arr = pa.array(data) + assert isinstance(arr, pa.MonthDayNanoIntervalArray) + assert arr.type == pa.month_day_nano_interval() + expected_list = [ + None, + pa.MonthDayNano([13, 8, 3661000001001]), + pa.MonthDayNano([0, 0, 0])] + expected = pa.array(expected_list) + assert arr.equals(expected) + assert arr.to_pandas().tolist() == [ + None, DateOffset(months=13, days=8, + microseconds=( + datetime.timedelta(seconds=1, microseconds=1, + minutes=1, hours=1) // + datetime.timedelta(microseconds=1)), + nanoseconds=1), + DateOffset(months=0, days=0, microseconds=0, nanoseconds=0)] + + +def test_array_from_numpy_unicode(): + dtypes = ['<U5', '>U5'] + + for dtype in dtypes: + arr = np.array(['abcde', 'abc', ''], dtype=dtype) + + arrow_arr = pa.array(arr) + assert arrow_arr.type == 'utf8' + expected = pa.array(['abcde', 'abc', ''], type='utf8') + assert arrow_arr.equals(expected) + + mask = np.array([False, True, False]) + arrow_arr = pa.array(arr, mask=mask) + expected = pa.array(['abcde', None, ''], type='utf8') + assert arrow_arr.equals(expected) + + # Strided variant + arr = np.array(['abcde', 'abc', ''] * 5, dtype=dtype)[::2] + mask = np.array([False, True, False] * 5)[::2] + arrow_arr = pa.array(arr, mask=mask) + + expected = pa.array(['abcde', '', None, 'abcde', '', None, + 'abcde', ''], type='utf8') + assert arrow_arr.equals(expected) + + # 0 itemsize + arr = np.array(['', '', ''], dtype='<U0') + arrow_arr = pa.array(arr) + expected = pa.array(['', '', ''], type='utf8') + assert arrow_arr.equals(expected) + + +def test_array_string_from_non_string(): + # ARROW-5682 - when converting to string raise on non string-like dtype + with pytest.raises(TypeError): + pa.array(np.array([1, 2, 3]), type=pa.string()) + + +def test_array_string_from_all_null(): + # ARROW-5682 + vals = np.array([None, None], dtype=object) + arr = pa.array(vals, type=pa.string()) + assert arr.null_count == 2 + + vals = np.array([np.nan, np.nan], dtype='float64') + # by default raises, but accept as all-null when from_pandas=True + with pytest.raises(TypeError): + pa.array(vals, type=pa.string()) + arr = pa.array(vals, type=pa.string(), from_pandas=True) + assert arr.null_count == 2 + + +def test_array_from_masked(): + ma = np.ma.array([1, 2, 3, 4], dtype='int64', + mask=[False, False, True, False]) + result = pa.array(ma) + expected = pa.array([1, 2, None, 4], type='int64') + assert expected.equals(result) + + with pytest.raises(ValueError, match="Cannot pass a numpy masked array"): + pa.array(ma, mask=np.array([True, False, False, False])) + + +def test_array_from_shrunken_masked(): + ma = np.ma.array([0], dtype='int64') + result = pa.array(ma) + expected = pa.array([0], type='int64') + assert expected.equals(result) + + +def test_array_from_invalid_dim_raises(): + msg = "only handle 1-dimensional arrays" + arr2d = np.array([[1, 2, 3], [4, 5, 6]]) + with pytest.raises(ValueError, match=msg): + pa.array(arr2d) + + arr0d = np.array(0) + with pytest.raises(ValueError, match=msg): + pa.array(arr0d) + + +def test_array_from_strided_bool(): + # ARROW-6325 + arr = np.ones((3, 2), dtype=bool) + result = pa.array(arr[:, 0]) + expected = pa.array([True, True, True]) + assert result.equals(expected) + result = pa.array(arr[0, :]) + expected = pa.array([True, True]) + assert result.equals(expected) + + +def test_array_from_strided(): + pydata = [ + ([b"ab", b"cd", b"ef"], (pa.binary(), pa.binary(2))), + ([1, 2, 3], (pa.int8(), pa.int16(), pa.int32(), pa.int64())), + ([1.0, 2.0, 3.0], (pa.float32(), pa.float64())), + (["ab", "cd", "ef"], (pa.utf8(), )) + ] + + for values, dtypes in pydata: + nparray = np.array(values) + for patype in dtypes: + for mask in (None, np.array([False, False])): + arrow_array = pa.array(nparray[::2], patype, + mask=mask) + assert values[::2] == arrow_array.to_pylist() + + +def test_boolean_true_count_false_count(): + # ARROW-9145 + arr = pa.array([True, True, None, False, None, True] * 1000) + assert arr.true_count == 3000 + assert arr.false_count == 1000 + + +def test_buffers_primitive(): + a = pa.array([1, 2, None, 4], type=pa.int16()) + buffers = a.buffers() + assert len(buffers) == 2 + null_bitmap = buffers[0].to_pybytes() + assert 1 <= len(null_bitmap) <= 64 # XXX this is varying + assert bytearray(null_bitmap)[0] == 0b00001011 + + # Slicing does not affect the buffers but the offset + a_sliced = a[1:] + buffers = a_sliced.buffers() + a_sliced.offset == 1 + assert len(buffers) == 2 + null_bitmap = buffers[0].to_pybytes() + assert 1 <= len(null_bitmap) <= 64 # XXX this is varying + assert bytearray(null_bitmap)[0] == 0b00001011 + + assert struct.unpack('hhxxh', buffers[1].to_pybytes()) == (1, 2, 4) + + a = pa.array(np.int8([4, 5, 6])) + buffers = a.buffers() + assert len(buffers) == 2 + # No null bitmap from Numpy int array + assert buffers[0] is None + assert struct.unpack('3b', buffers[1].to_pybytes()) == (4, 5, 6) + + a = pa.array([b'foo!', None, b'bar!!']) + buffers = a.buffers() + assert len(buffers) == 3 + null_bitmap = buffers[0].to_pybytes() + assert bytearray(null_bitmap)[0] == 0b00000101 + offsets = buffers[1].to_pybytes() + assert struct.unpack('4i', offsets) == (0, 4, 4, 9) + values = buffers[2].to_pybytes() + assert values == b'foo!bar!!' + + +def test_buffers_nested(): + a = pa.array([[1, 2], None, [3, None, 4, 5]], type=pa.list_(pa.int64())) + buffers = a.buffers() + assert len(buffers) == 4 + # The parent buffers + null_bitmap = buffers[0].to_pybytes() + assert bytearray(null_bitmap)[0] == 0b00000101 + offsets = buffers[1].to_pybytes() + assert struct.unpack('4i', offsets) == (0, 2, 2, 6) + # The child buffers + null_bitmap = buffers[2].to_pybytes() + assert bytearray(null_bitmap)[0] == 0b00110111 + values = buffers[3].to_pybytes() + assert struct.unpack('qqq8xqq', values) == (1, 2, 3, 4, 5) + + a = pa.array([(42, None), None, (None, 43)], + type=pa.struct([pa.field('a', pa.int8()), + pa.field('b', pa.int16())])) + buffers = a.buffers() + assert len(buffers) == 5 + # The parent buffer + null_bitmap = buffers[0].to_pybytes() + assert bytearray(null_bitmap)[0] == 0b00000101 + # The child buffers: 'a' + null_bitmap = buffers[1].to_pybytes() + assert bytearray(null_bitmap)[0] == 0b00000011 + values = buffers[2].to_pybytes() + assert struct.unpack('bxx', values) == (42,) + # The child buffers: 'b' + null_bitmap = buffers[3].to_pybytes() + assert bytearray(null_bitmap)[0] == 0b00000110 + values = buffers[4].to_pybytes() + assert struct.unpack('4xh', values) == (43,) + + +def test_nbytes_sizeof(): + a = pa.array(np.array([4, 5, 6], dtype='int64')) + assert a.nbytes == 8 * 3 + assert sys.getsizeof(a) >= object.__sizeof__(a) + a.nbytes + a = pa.array([1, None, 3], type='int64') + assert a.nbytes == 8*3 + 1 + assert sys.getsizeof(a) >= object.__sizeof__(a) + a.nbytes + a = pa.array([[1, 2], None, [3, None, 4, 5]], type=pa.list_(pa.int64())) + assert a.nbytes == 1 + 4 * 4 + 1 + 6 * 8 + assert sys.getsizeof(a) >= object.__sizeof__(a) + a.nbytes + + +def test_invalid_tensor_constructor_repr(): + # ARROW-2638: prevent calling extension class constructors directly + with pytest.raises(TypeError): + repr(pa.Tensor([1])) + + +def test_invalid_tensor_construction(): + with pytest.raises(TypeError): + pa.Tensor() + + +@pytest.mark.parametrize(('offset_type', 'list_type_factory'), + [(pa.int32(), pa.list_), (pa.int64(), pa.large_list)]) +def test_list_array_flatten(offset_type, list_type_factory): + typ2 = list_type_factory( + list_type_factory( + pa.int64() + ) + ) + arr2 = pa.array([ + None, + [ + [1, None, 2], + None, + [3, 4] + ], + [], + [ + [], + [5, 6], + None + ], + [ + [7, 8] + ] + ], type=typ2) + offsets2 = pa.array([0, 0, 3, 3, 6, 7], type=offset_type) + + typ1 = list_type_factory(pa.int64()) + arr1 = pa.array([ + [1, None, 2], + None, + [3, 4], + [], + [5, 6], + None, + [7, 8] + ], type=typ1) + offsets1 = pa.array([0, 3, 3, 5, 5, 7, 7, 9], type=offset_type) + + arr0 = pa.array([ + 1, None, 2, + 3, 4, + 5, 6, + 7, 8 + ], type=pa.int64()) + + assert arr2.flatten().equals(arr1) + assert arr2.offsets.equals(offsets2) + assert arr2.values.equals(arr1) + assert arr1.flatten().equals(arr0) + assert arr1.offsets.equals(offsets1) + assert arr1.values.equals(arr0) + assert arr2.flatten().flatten().equals(arr0) + assert arr2.values.values.equals(arr0) + + +@pytest.mark.parametrize(('offset_type', 'list_type_factory'), + [(pa.int32(), pa.list_), (pa.int64(), pa.large_list)]) +def test_list_value_parent_indices(offset_type, list_type_factory): + arr = pa.array( + [ + [0, 1, 2], + None, + [], + [3, 4] + ], type=list_type_factory(pa.int32())) + expected = pa.array([0, 0, 0, 3, 3], type=offset_type) + assert arr.value_parent_indices().equals(expected) + + +@pytest.mark.parametrize(('offset_type', 'list_type_factory'), + [(pa.int32(), pa.list_), (pa.int64(), pa.large_list)]) +def test_list_value_lengths(offset_type, list_type_factory): + arr = pa.array( + [ + [0, 1, 2], + None, + [], + [3, 4] + ], type=list_type_factory(pa.int32())) + expected = pa.array([3, None, 0, 2], type=offset_type) + assert arr.value_lengths().equals(expected) + + +@pytest.mark.parametrize('list_type_factory', [pa.list_, pa.large_list]) +def test_list_array_flatten_non_canonical(list_type_factory): + # Non-canonical list array (null elements backed by non-empty sublists) + typ = list_type_factory(pa.int64()) + arr = pa.array([[1], [2, 3], [4, 5, 6]], type=typ) + buffers = arr.buffers()[:2] + buffers[0] = pa.py_buffer(b"\x05") # validity bitmap + arr = arr.from_buffers(arr.type, len(arr), buffers, children=[arr.values]) + assert arr.to_pylist() == [[1], None, [4, 5, 6]] + assert arr.offsets.to_pylist() == [0, 1, 3, 6] + + flattened = arr.flatten() + flattened.validate(full=True) + assert flattened.type == typ.value_type + assert flattened.to_pylist() == [1, 4, 5, 6] + + # .values is the physical values array (including masked elements) + assert arr.values.to_pylist() == [1, 2, 3, 4, 5, 6] + + +@pytest.mark.parametrize('klass', [pa.ListArray, pa.LargeListArray]) +def test_list_array_values_offsets_sliced(klass): + # ARROW-7301 + arr = klass.from_arrays(offsets=[0, 3, 4, 6], values=[1, 2, 3, 4, 5, 6]) + assert arr.values.to_pylist() == [1, 2, 3, 4, 5, 6] + assert arr.offsets.to_pylist() == [0, 3, 4, 6] + + # sliced -> values keeps referring to full values buffer, but offsets is + # sliced as well so the offsets correctly point into the full values array + # sliced -> flatten() will return the sliced value array. + arr2 = arr[1:] + assert arr2.values.to_pylist() == [1, 2, 3, 4, 5, 6] + assert arr2.offsets.to_pylist() == [3, 4, 6] + assert arr2.flatten().to_pylist() == [4, 5, 6] + i = arr2.offsets[0].as_py() + j = arr2.offsets[1].as_py() + assert arr2[0].as_py() == arr2.values[i:j].to_pylist() == [4] + + +def test_fixed_size_list_array_flatten(): + typ2 = pa.list_(pa.list_(pa.int64(), 2), 3) + arr2 = pa.array([ + [ + [1, 2], + [3, 4], + [5, 6], + ], + None, + [ + [7, None], + None, + [8, 9] + ], + ], type=typ2) + assert arr2.type.equals(typ2) + + typ1 = pa.list_(pa.int64(), 2) + arr1 = pa.array([ + [1, 2], [3, 4], [5, 6], + None, None, None, + [7, None], None, [8, 9] + ], type=typ1) + assert arr1.type.equals(typ1) + assert arr2.flatten().equals(arr1) + + typ0 = pa.int64() + arr0 = pa.array([ + 1, 2, 3, 4, 5, 6, + None, None, None, None, None, None, + 7, None, None, None, 8, 9, + ], type=typ0) + assert arr0.type.equals(typ0) + assert arr1.flatten().equals(arr0) + assert arr2.flatten().flatten().equals(arr0) + + +def test_struct_array_flatten(): + ty = pa.struct([pa.field('x', pa.int16()), + pa.field('y', pa.float32())]) + a = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty) + xs, ys = a.flatten() + assert xs.type == pa.int16() + assert ys.type == pa.float32() + assert xs.to_pylist() == [1, 3, 5] + assert ys.to_pylist() == [2.5, 4.5, 6.5] + xs, ys = a[1:].flatten() + assert xs.to_pylist() == [3, 5] + assert ys.to_pylist() == [4.5, 6.5] + + a = pa.array([(1, 2.5), None, (3, 4.5)], type=ty) + xs, ys = a.flatten() + assert xs.to_pylist() == [1, None, 3] + assert ys.to_pylist() == [2.5, None, 4.5] + xs, ys = a[1:].flatten() + assert xs.to_pylist() == [None, 3] + assert ys.to_pylist() == [None, 4.5] + + a = pa.array([(1, None), (2, 3.5), (None, 4.5)], type=ty) + xs, ys = a.flatten() + assert xs.to_pylist() == [1, 2, None] + assert ys.to_pylist() == [None, 3.5, 4.5] + xs, ys = a[1:].flatten() + assert xs.to_pylist() == [2, None] + assert ys.to_pylist() == [3.5, 4.5] + + a = pa.array([(1, None), None, (None, 2.5)], type=ty) + xs, ys = a.flatten() + assert xs.to_pylist() == [1, None, None] + assert ys.to_pylist() == [None, None, 2.5] + xs, ys = a[1:].flatten() + assert xs.to_pylist() == [None, None] + assert ys.to_pylist() == [None, 2.5] + + +def test_struct_array_field(): + ty = pa.struct([pa.field('x', pa.int16()), + pa.field('y', pa.float32())]) + a = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty) + + x0 = a.field(0) + y0 = a.field(1) + x1 = a.field(-2) + y1 = a.field(-1) + x2 = a.field('x') + y2 = a.field('y') + + assert isinstance(x0, pa.lib.Int16Array) + assert isinstance(y1, pa.lib.FloatArray) + assert x0.equals(pa.array([1, 3, 5], type=pa.int16())) + assert y0.equals(pa.array([2.5, 4.5, 6.5], type=pa.float32())) + assert x0.equals(x1) + assert x0.equals(x2) + assert y0.equals(y1) + assert y0.equals(y2) + + for invalid_index in [None, pa.int16()]: + with pytest.raises(TypeError): + a.field(invalid_index) + + for invalid_index in [3, -3]: + with pytest.raises(IndexError): + a.field(invalid_index) + + for invalid_name in ['z', '']: + with pytest.raises(KeyError): + a.field(invalid_name) + + +def test_empty_cast(): + types = [ + pa.null(), + pa.bool_(), + pa.int8(), + pa.int16(), + pa.int32(), + pa.int64(), + pa.uint8(), + pa.uint16(), + pa.uint32(), + pa.uint64(), + pa.float16(), + pa.float32(), + pa.float64(), + pa.date32(), + pa.date64(), + pa.binary(), + pa.binary(length=4), + pa.string(), + ] + + for (t1, t2) in itertools.product(types, types): + try: + # ARROW-4766: Ensure that supported types conversion don't segfault + # on empty arrays of common types + pa.array([], type=t1).cast(t2) + except (pa.lib.ArrowNotImplementedError, pa.ArrowInvalid): + continue + + +def test_nested_dictionary_array(): + dict_arr = pa.DictionaryArray.from_arrays([0, 1, 0], ['a', 'b']) + list_arr = pa.ListArray.from_arrays([0, 2, 3], dict_arr) + assert list_arr.to_pylist() == [['a', 'b'], ['a']] + + dict_arr = pa.DictionaryArray.from_arrays([0, 1, 0], ['a', 'b']) + dict_arr2 = pa.DictionaryArray.from_arrays([0, 1, 2, 1, 0], dict_arr) + assert dict_arr2.to_pylist() == ['a', 'b', 'a', 'b', 'a'] + + +def test_array_from_numpy_str_utf8(): + # ARROW-3890 -- in Python 3, NPY_UNICODE arrays are produced, but in Python + # 2 they are NPY_STRING (binary), so we must do UTF-8 validation + vec = np.array(["toto", "tata"]) + vec2 = np.array(["toto", "tata"], dtype=object) + + arr = pa.array(vec, pa.string()) + arr2 = pa.array(vec2, pa.string()) + expected = pa.array(["toto", "tata"]) + assert arr.equals(expected) + assert arr2.equals(expected) + + # with mask, separate code path + mask = np.array([False, False], dtype=bool) + arr = pa.array(vec, pa.string(), mask=mask) + assert arr.equals(expected) + + # UTF8 validation failures + vec = np.array([('mañana').encode('utf-16-le')]) + with pytest.raises(ValueError): + pa.array(vec, pa.string()) + + with pytest.raises(ValueError): + pa.array(vec, pa.string(), mask=np.array([False])) + + +@pytest.mark.slow +@pytest.mark.large_memory +def test_numpy_binary_overflow_to_chunked(): + # ARROW-3762, ARROW-5966 + + # 2^31 + 1 bytes + values = [b'x'] + unicode_values = ['x'] + + # Make 10 unique 1MB strings then repeat then 2048 times + unique_strings = { + i: b'x' * ((1 << 20) - 1) + str(i % 10).encode('utf8') + for i in range(10) + } + unicode_unique_strings = {i: x.decode('utf8') + for i, x in unique_strings.items()} + values += [unique_strings[i % 10] for i in range(1 << 11)] + unicode_values += [unicode_unique_strings[i % 10] + for i in range(1 << 11)] + + for case, ex_type in [(values, pa.binary()), + (unicode_values, pa.utf8())]: + arr = np.array(case) + arrow_arr = pa.array(arr) + arr = None + + assert isinstance(arrow_arr, pa.ChunkedArray) + assert arrow_arr.type == ex_type + + # Split up into 16MB chunks. 128 * 16 = 2048, so 129 + assert arrow_arr.num_chunks == 129 + + value_index = 0 + for i in range(arrow_arr.num_chunks): + chunk = arrow_arr.chunk(i) + for val in chunk: + assert val.as_py() == case[value_index] + value_index += 1 + + +@pytest.mark.large_memory +def test_list_child_overflow_to_chunked(): + kilobyte_string = 'x' * 1024 + two_mega = 2**21 + + vals = [[kilobyte_string]] * (two_mega - 1) + arr = pa.array(vals) + assert isinstance(arr, pa.Array) + assert len(arr) == two_mega - 1 + + vals = [[kilobyte_string]] * two_mega + arr = pa.array(vals) + assert isinstance(arr, pa.ChunkedArray) + assert len(arr) == two_mega + assert len(arr.chunk(0)) == two_mega - 1 + assert len(arr.chunk(1)) == 1 + + +def test_infer_type_masked(): + # ARROW-5208 + ty = pa.infer_type(['foo', 'bar', None, 2], + mask=[False, False, False, True]) + assert ty == pa.utf8() + + # all masked + ty = pa.infer_type(['foo', 'bar', None, 2], + mask=np.array([True, True, True, True])) + assert ty == pa.null() + + # length 0 + assert pa.infer_type([], mask=[]) == pa.null() + + +def test_array_masked(): + # ARROW-5208 + arr = pa.array([4, None, 4, 3.], + mask=np.array([False, True, False, True])) + assert arr.type == pa.int64() + + # ndarray dtype=object argument + arr = pa.array(np.array([4, None, 4, 3.], dtype="O"), + mask=np.array([False, True, False, True])) + assert arr.type == pa.int64() + + +def test_array_supported_masks(): + # ARROW-13883 + arr = pa.array([4, None, 4, 3.], + mask=np.array([False, True, False, True])) + assert arr.to_pylist() == [4, None, 4, None] + + arr = pa.array([4, None, 4, 3], + mask=pa.array([False, True, False, True])) + assert arr.to_pylist() == [4, None, 4, None] + + arr = pa.array([4, None, 4, 3], + mask=[False, True, False, True]) + assert arr.to_pylist() == [4, None, 4, None] + + arr = pa.array([4, 3, None, 3], + mask=[False, True, False, True]) + assert arr.to_pylist() == [4, None, None, None] + + # Non boolean values + with pytest.raises(pa.ArrowTypeError): + arr = pa.array([4, None, 4, 3], + mask=pa.array([1.0, 2.0, 3.0, 4.0])) + + with pytest.raises(pa.ArrowTypeError): + arr = pa.array([4, None, 4, 3], + mask=[1.0, 2.0, 3.0, 4.0]) + + with pytest.raises(pa.ArrowTypeError): + arr = pa.array([4, None, 4, 3], + mask=np.array([1.0, 2.0, 3.0, 4.0])) + + with pytest.raises(pa.ArrowTypeError): + arr = pa.array([4, None, 4, 3], + mask=pa.array([False, True, False, True], + mask=pa.array([True, True, True, True]))) + + with pytest.raises(pa.ArrowTypeError): + arr = pa.array([4, None, 4, 3], + mask=pa.array([False, None, False, True])) + + # Numpy arrays only accepts numpy masks + with pytest.raises(TypeError): + arr = pa.array(np.array([4, None, 4, 3.]), + mask=[True, False, True, False]) + + with pytest.raises(TypeError): + arr = pa.array(np.array([4, None, 4, 3.]), + mask=pa.array([True, False, True, False])) + + +def test_binary_array_masked(): + # ARROW-12431 + masked_basic = pa.array([b'\x05'], type=pa.binary(1), + mask=np.array([False])) + assert [b'\x05'] == masked_basic.to_pylist() + + # Fixed Length Binary + masked = pa.array(np.array([b'\x05']), type=pa.binary(1), + mask=np.array([False])) + assert [b'\x05'] == masked.to_pylist() + + masked_nulls = pa.array(np.array([b'\x05']), type=pa.binary(1), + mask=np.array([True])) + assert [None] == masked_nulls.to_pylist() + + # Variable Length Binary + masked = pa.array(np.array([b'\x05']), type=pa.binary(), + mask=np.array([False])) + assert [b'\x05'] == masked.to_pylist() + + masked_nulls = pa.array(np.array([b'\x05']), type=pa.binary(), + mask=np.array([True])) + assert [None] == masked_nulls.to_pylist() + + # Fixed Length Binary, copy + npa = np.array([b'aaa', b'bbb', b'ccc']*10) + arrow_array = pa.array(npa, type=pa.binary(3), + mask=np.array([False, False, False]*10)) + npa[npa == b"bbb"] = b"XXX" + assert ([b'aaa', b'bbb', b'ccc']*10) == arrow_array.to_pylist() + + +def test_binary_array_strided(): + # Masked + nparray = np.array([b"ab", b"cd", b"ef"]) + arrow_array = pa.array(nparray[::2], pa.binary(2), + mask=np.array([False, False])) + assert [b"ab", b"ef"] == arrow_array.to_pylist() + + # Unmasked + nparray = np.array([b"ab", b"cd", b"ef"]) + arrow_array = pa.array(nparray[::2], pa.binary(2)) + assert [b"ab", b"ef"] == arrow_array.to_pylist() + + +def test_array_invalid_mask_raises(): + # ARROW-10742 + cases = [ + ([1, 2], np.array([False, False], dtype="O"), + TypeError, "must be boolean dtype"), + + ([1, 2], np.array([[False], [False]]), + pa.ArrowInvalid, "must be 1D array"), + + ([1, 2, 3], np.array([False, False]), + pa.ArrowInvalid, "different length"), + + (np.array([1, 2]), np.array([False, False], dtype="O"), + TypeError, "must be boolean dtype"), + + (np.array([1, 2]), np.array([[False], [False]]), + ValueError, "must be 1D array"), + + (np.array([1, 2, 3]), np.array([False, False]), + ValueError, "different length"), + ] + for obj, mask, ex, msg in cases: + with pytest.raises(ex, match=msg): + pa.array(obj, mask=mask) + + +def test_array_from_large_pyints(): + # ARROW-5430 + with pytest.raises(OverflowError): + # too large for int64 so dtype must be explicitly provided + pa.array([int(2 ** 63)]) + + +def test_array_protocol(): + + class MyArray: + def __init__(self, data): + self.data = data + + def __arrow_array__(self, type=None): + return pa.array(self.data, type=type) + + arr = MyArray(np.array([1, 2, 3], dtype='int64')) + result = pa.array(arr) + expected = pa.array([1, 2, 3], type=pa.int64()) + assert result.equals(expected) + result = pa.array(arr, type=pa.int64()) + expected = pa.array([1, 2, 3], type=pa.int64()) + assert result.equals(expected) + result = pa.array(arr, type=pa.float64()) + expected = pa.array([1, 2, 3], type=pa.float64()) + assert result.equals(expected) + + # raise error when passing size or mask keywords + with pytest.raises(ValueError): + pa.array(arr, mask=np.array([True, False, True])) + with pytest.raises(ValueError): + pa.array(arr, size=3) + + # ensure the return value is an Array + class MyArrayInvalid: + def __init__(self, data): + self.data = data + + def __arrow_array__(self, type=None): + return np.array(self.data) + + arr = MyArrayInvalid(np.array([1, 2, 3], dtype='int64')) + with pytest.raises(TypeError): + pa.array(arr) + + # ARROW-7066 - allow ChunkedArray output + class MyArray2: + def __init__(self, data): + self.data = data + + def __arrow_array__(self, type=None): + return pa.chunked_array([self.data], type=type) + + arr = MyArray2(np.array([1, 2, 3], dtype='int64')) + result = pa.array(arr) + expected = pa.chunked_array([[1, 2, 3]], type=pa.int64()) + assert result.equals(expected) + + +def test_concat_array(): + concatenated = pa.concat_arrays( + [pa.array([1, 2]), pa.array([3, 4])]) + assert concatenated.equals(pa.array([1, 2, 3, 4])) + + +def test_concat_array_different_types(): + with pytest.raises(pa.ArrowInvalid): + pa.concat_arrays([pa.array([1]), pa.array([2.])]) + + +def test_concat_array_invalid_type(): + # ARROW-9920 - do not segfault on non-array input + + with pytest.raises(TypeError, match="should contain Array objects"): + pa.concat_arrays([None]) + + arr = pa.chunked_array([[0, 1], [3, 4]]) + with pytest.raises(TypeError, match="should contain Array objects"): + pa.concat_arrays(arr) + + +@pytest.mark.pandas +def test_to_pandas_timezone(): + # https://issues.apache.org/jira/browse/ARROW-6652 + arr = pa.array([1, 2, 3], type=pa.timestamp('s', tz='Europe/Brussels')) + s = arr.to_pandas() + assert s.dt.tz is not None + arr = pa.chunked_array([arr]) + s = arr.to_pandas() + assert s.dt.tz is not None diff --git a/src/arrow/python/pyarrow/tests/test_builder.py b/src/arrow/python/pyarrow/tests/test_builder.py new file mode 100644 index 000000000..50d801026 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_builder.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import weakref + +import numpy as np + +import pyarrow as pa +from pyarrow.lib import StringBuilder + + +def test_weakref(): + sbuilder = StringBuilder() + wr = weakref.ref(sbuilder) + assert wr() is not None + del sbuilder + assert wr() is None + + +def test_string_builder_append(): + sbuilder = StringBuilder() + sbuilder.append(b"a byte string") + sbuilder.append("a string") + sbuilder.append(np.nan) + sbuilder.append(None) + assert len(sbuilder) == 4 + assert sbuilder.null_count == 2 + arr = sbuilder.finish() + assert len(sbuilder) == 0 + assert isinstance(arr, pa.Array) + assert arr.null_count == 2 + assert arr.type == 'str' + expected = ["a byte string", "a string", None, None] + assert arr.to_pylist() == expected + + +def test_string_builder_append_values(): + sbuilder = StringBuilder() + sbuilder.append_values([np.nan, None, "text", None, "other text"]) + assert sbuilder.null_count == 3 + arr = sbuilder.finish() + assert arr.null_count == 3 + expected = [None, None, "text", None, "other text"] + assert arr.to_pylist() == expected + + +def test_string_builder_append_after_finish(): + sbuilder = StringBuilder() + sbuilder.append_values([np.nan, None, "text", None, "other text"]) + arr = sbuilder.finish() + sbuilder.append("No effect") + expected = [None, None, "text", None, "other text"] + assert arr.to_pylist() == expected diff --git a/src/arrow/python/pyarrow/tests/test_cffi.py b/src/arrow/python/pyarrow/tests/test_cffi.py new file mode 100644 index 000000000..f0ce42909 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_cffi.py @@ -0,0 +1,398 @@ +# -*- coding: utf-8 -*- +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import gc + +import pyarrow as pa +try: + from pyarrow.cffi import ffi +except ImportError: + ffi = None + +import pytest + +try: + import pandas as pd + import pandas.testing as tm +except ImportError: + pd = tm = None + + +needs_cffi = pytest.mark.skipif(ffi is None, + reason="test needs cffi package installed") + + +assert_schema_released = pytest.raises( + ValueError, match="Cannot import released ArrowSchema") + +assert_array_released = pytest.raises( + ValueError, match="Cannot import released ArrowArray") + +assert_stream_released = pytest.raises( + ValueError, match="Cannot import released ArrowArrayStream") + + +class ParamExtType(pa.PyExtensionType): + + def __init__(self, width): + self._width = width + pa.PyExtensionType.__init__(self, pa.binary(width)) + + @property + def width(self): + return self._width + + def __reduce__(self): + return ParamExtType, (self.width,) + + +def make_schema(): + return pa.schema([('ints', pa.list_(pa.int32()))], + metadata={b'key1': b'value1'}) + + +def make_extension_schema(): + return pa.schema([('ext', ParamExtType(3))], + metadata={b'key1': b'value1'}) + + +def make_batch(): + return pa.record_batch([[[1], [2, 42]]], make_schema()) + + +def make_extension_batch(): + schema = make_extension_schema() + ext_col = schema[0].type.wrap_array(pa.array([b"foo", b"bar"], + type=pa.binary(3))) + return pa.record_batch([ext_col], schema) + + +def make_batches(): + schema = make_schema() + return [ + pa.record_batch([[[1], [2, 42]]], schema), + pa.record_batch([[None, [], [5, 6]]], schema), + ] + + +def make_serialized(schema, batches): + with pa.BufferOutputStream() as sink: + with pa.ipc.new_stream(sink, schema) as out: + for batch in batches: + out.write(batch) + return sink.getvalue() + + +@needs_cffi +def test_export_import_type(): + c_schema = ffi.new("struct ArrowSchema*") + ptr_schema = int(ffi.cast("uintptr_t", c_schema)) + + gc.collect() # Make sure no Arrow data dangles in a ref cycle + old_allocated = pa.total_allocated_bytes() + + typ = pa.list_(pa.int32()) + typ._export_to_c(ptr_schema) + assert pa.total_allocated_bytes() > old_allocated + # Delete and recreate C++ object from exported pointer + del typ + assert pa.total_allocated_bytes() > old_allocated + typ_new = pa.DataType._import_from_c(ptr_schema) + assert typ_new == pa.list_(pa.int32()) + assert pa.total_allocated_bytes() == old_allocated + # Now released + with assert_schema_released: + pa.DataType._import_from_c(ptr_schema) + + # Invalid format string + pa.int32()._export_to_c(ptr_schema) + bad_format = ffi.new("char[]", b"zzz") + c_schema.format = bad_format + with pytest.raises(ValueError, + match="Invalid or unsupported format string"): + pa.DataType._import_from_c(ptr_schema) + # Now released + with assert_schema_released: + pa.DataType._import_from_c(ptr_schema) + + +@needs_cffi +def test_export_import_field(): + c_schema = ffi.new("struct ArrowSchema*") + ptr_schema = int(ffi.cast("uintptr_t", c_schema)) + + gc.collect() # Make sure no Arrow data dangles in a ref cycle + old_allocated = pa.total_allocated_bytes() + + field = pa.field("test", pa.list_(pa.int32()), nullable=True) + field._export_to_c(ptr_schema) + assert pa.total_allocated_bytes() > old_allocated + # Delete and recreate C++ object from exported pointer + del field + assert pa.total_allocated_bytes() > old_allocated + + field_new = pa.Field._import_from_c(ptr_schema) + assert field_new == pa.field("test", pa.list_(pa.int32()), nullable=True) + assert pa.total_allocated_bytes() == old_allocated + + # Now released + with assert_schema_released: + pa.Field._import_from_c(ptr_schema) + + +@needs_cffi +def test_export_import_array(): + c_schema = ffi.new("struct ArrowSchema*") + ptr_schema = int(ffi.cast("uintptr_t", c_schema)) + c_array = ffi.new("struct ArrowArray*") + ptr_array = int(ffi.cast("uintptr_t", c_array)) + + gc.collect() # Make sure no Arrow data dangles in a ref cycle + old_allocated = pa.total_allocated_bytes() + + # Type is known up front + typ = pa.list_(pa.int32()) + arr = pa.array([[1], [2, 42]], type=typ) + py_value = arr.to_pylist() + arr._export_to_c(ptr_array) + assert pa.total_allocated_bytes() > old_allocated + # Delete recreate C++ object from exported pointer + del arr + arr_new = pa.Array._import_from_c(ptr_array, typ) + assert arr_new.to_pylist() == py_value + assert arr_new.type == pa.list_(pa.int32()) + assert pa.total_allocated_bytes() > old_allocated + del arr_new, typ + assert pa.total_allocated_bytes() == old_allocated + # Now released + with assert_array_released: + pa.Array._import_from_c(ptr_array, pa.list_(pa.int32())) + + # Type is exported and imported at the same time + arr = pa.array([[1], [2, 42]], type=pa.list_(pa.int32())) + py_value = arr.to_pylist() + arr._export_to_c(ptr_array, ptr_schema) + # Delete and recreate C++ objects from exported pointers + del arr + arr_new = pa.Array._import_from_c(ptr_array, ptr_schema) + assert arr_new.to_pylist() == py_value + assert arr_new.type == pa.list_(pa.int32()) + assert pa.total_allocated_bytes() > old_allocated + del arr_new + assert pa.total_allocated_bytes() == old_allocated + # Now released + with assert_schema_released: + pa.Array._import_from_c(ptr_array, ptr_schema) + + +def check_export_import_schema(schema_factory): + c_schema = ffi.new("struct ArrowSchema*") + ptr_schema = int(ffi.cast("uintptr_t", c_schema)) + + gc.collect() # Make sure no Arrow data dangles in a ref cycle + old_allocated = pa.total_allocated_bytes() + + schema_factory()._export_to_c(ptr_schema) + assert pa.total_allocated_bytes() > old_allocated + # Delete and recreate C++ object from exported pointer + schema_new = pa.Schema._import_from_c(ptr_schema) + assert schema_new == schema_factory() + assert pa.total_allocated_bytes() == old_allocated + del schema_new + assert pa.total_allocated_bytes() == old_allocated + # Now released + with assert_schema_released: + pa.Schema._import_from_c(ptr_schema) + + # Not a struct type + pa.int32()._export_to_c(ptr_schema) + with pytest.raises(ValueError, + match="ArrowSchema describes non-struct type"): + pa.Schema._import_from_c(ptr_schema) + # Now released + with assert_schema_released: + pa.Schema._import_from_c(ptr_schema) + + +@needs_cffi +def test_export_import_schema(): + check_export_import_schema(make_schema) + + +@needs_cffi +def test_export_import_schema_with_extension(): + check_export_import_schema(make_extension_schema) + + +def check_export_import_batch(batch_factory): + c_schema = ffi.new("struct ArrowSchema*") + ptr_schema = int(ffi.cast("uintptr_t", c_schema)) + c_array = ffi.new("struct ArrowArray*") + ptr_array = int(ffi.cast("uintptr_t", c_array)) + + gc.collect() # Make sure no Arrow data dangles in a ref cycle + old_allocated = pa.total_allocated_bytes() + + # Schema is known up front + batch = batch_factory() + schema = batch.schema + py_value = batch.to_pydict() + batch._export_to_c(ptr_array) + assert pa.total_allocated_bytes() > old_allocated + # Delete and recreate C++ object from exported pointer + del batch + batch_new = pa.RecordBatch._import_from_c(ptr_array, schema) + assert batch_new.to_pydict() == py_value + assert batch_new.schema == schema + assert pa.total_allocated_bytes() > old_allocated + del batch_new, schema + assert pa.total_allocated_bytes() == old_allocated + # Now released + with assert_array_released: + pa.RecordBatch._import_from_c(ptr_array, make_schema()) + + # Type is exported and imported at the same time + batch = batch_factory() + py_value = batch.to_pydict() + batch._export_to_c(ptr_array, ptr_schema) + # Delete and recreate C++ objects from exported pointers + del batch + batch_new = pa.RecordBatch._import_from_c(ptr_array, ptr_schema) + assert batch_new.to_pydict() == py_value + assert batch_new.schema == batch_factory().schema + assert pa.total_allocated_bytes() > old_allocated + del batch_new + assert pa.total_allocated_bytes() == old_allocated + # Now released + with assert_schema_released: + pa.RecordBatch._import_from_c(ptr_array, ptr_schema) + + # Not a struct type + pa.int32()._export_to_c(ptr_schema) + batch_factory()._export_to_c(ptr_array) + with pytest.raises(ValueError, + match="ArrowSchema describes non-struct type"): + pa.RecordBatch._import_from_c(ptr_array, ptr_schema) + # Now released + with assert_schema_released: + pa.RecordBatch._import_from_c(ptr_array, ptr_schema) + + +@needs_cffi +def test_export_import_batch(): + check_export_import_batch(make_batch) + + +@needs_cffi +def test_export_import_batch_with_extension(): + check_export_import_batch(make_extension_batch) + + +def _export_import_batch_reader(ptr_stream, reader_factory): + # Prepare input + batches = make_batches() + schema = batches[0].schema + + reader = reader_factory(schema, batches) + reader._export_to_c(ptr_stream) + # Delete and recreate C++ object from exported pointer + del reader, batches + + reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream) + assert reader_new.schema == schema + got_batches = list(reader_new) + del reader_new + assert got_batches == make_batches() + + # Test read_pandas() + if pd is not None: + batches = make_batches() + schema = batches[0].schema + expected_df = pa.Table.from_batches(batches).to_pandas() + + reader = reader_factory(schema, batches) + reader._export_to_c(ptr_stream) + del reader, batches + + reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream) + got_df = reader_new.read_pandas() + del reader_new + tm.assert_frame_equal(expected_df, got_df) + + +def make_ipc_stream_reader(schema, batches): + return pa.ipc.open_stream(make_serialized(schema, batches)) + + +def make_py_record_batch_reader(schema, batches): + return pa.ipc.RecordBatchReader.from_batches(schema, batches) + + +@needs_cffi +@pytest.mark.parametrize('reader_factory', + [make_ipc_stream_reader, + make_py_record_batch_reader]) +def test_export_import_batch_reader(reader_factory): + c_stream = ffi.new("struct ArrowArrayStream*") + ptr_stream = int(ffi.cast("uintptr_t", c_stream)) + + gc.collect() # Make sure no Arrow data dangles in a ref cycle + old_allocated = pa.total_allocated_bytes() + + _export_import_batch_reader(ptr_stream, reader_factory) + + assert pa.total_allocated_bytes() == old_allocated + + # Now released + with assert_stream_released: + pa.ipc.RecordBatchReader._import_from_c(ptr_stream) + + +@needs_cffi +def test_imported_batch_reader_error(): + c_stream = ffi.new("struct ArrowArrayStream*") + ptr_stream = int(ffi.cast("uintptr_t", c_stream)) + + schema = pa.schema([('foo', pa.int32())]) + batches = [pa.record_batch([[1, 2, 3]], schema=schema), + pa.record_batch([[4, 5, 6]], schema=schema)] + buf = make_serialized(schema, batches) + + # Open a corrupt/incomplete stream and export it + reader = pa.ipc.open_stream(buf[:-16]) + reader._export_to_c(ptr_stream) + del reader + + reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream) + batch = reader_new.read_next_batch() + assert batch == batches[0] + with pytest.raises(OSError, + match="Expected to be able to read 16 bytes " + "for message body, got 8"): + reader_new.read_next_batch() + + # Again, but call read_all() + reader = pa.ipc.open_stream(buf[:-16]) + reader._export_to_c(ptr_stream) + del reader + + reader_new = pa.ipc.RecordBatchReader._import_from_c(ptr_stream) + with pytest.raises(OSError, + match="Expected to be able to read 16 bytes " + "for message body, got 8"): + reader_new.read_all() diff --git a/src/arrow/python/pyarrow/tests/test_compute.py b/src/arrow/python/pyarrow/tests/test_compute.py new file mode 100644 index 000000000..be2da31b9 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_compute.py @@ -0,0 +1,2238 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from functools import lru_cache, partial +import inspect +import pickle +import pytest +import random +import sys +import textwrap + +import numpy as np + +try: + import pandas as pd +except ImportError: + pd = None + +import pyarrow as pa +import pyarrow.compute as pc + +all_array_types = [ + ('bool', [True, False, False, True, True]), + ('uint8', np.arange(5)), + ('int8', np.arange(5)), + ('uint16', np.arange(5)), + ('int16', np.arange(5)), + ('uint32', np.arange(5)), + ('int32', np.arange(5)), + ('uint64', np.arange(5, 10)), + ('int64', np.arange(5, 10)), + ('float', np.arange(0, 0.5, 0.1)), + ('double', np.arange(0, 0.5, 0.1)), + ('string', ['a', 'b', None, 'ddd', 'ee']), + ('binary', [b'a', b'b', b'c', b'ddd', b'ee']), + (pa.binary(3), [b'abc', b'bcd', b'cde', b'def', b'efg']), + (pa.list_(pa.int8()), [[1, 2], [3, 4], [5, 6], None, [9, 16]]), + (pa.large_list(pa.int16()), [[1], [2, 3, 4], [5, 6], None, [9, 16]]), + (pa.struct([('a', pa.int8()), ('b', pa.int8())]), [ + {'a': 1, 'b': 2}, None, {'a': 3, 'b': 4}, None, {'a': 5, 'b': 6}]), +] + +exported_functions = [ + func for (name, func) in sorted(pc.__dict__.items()) + if hasattr(func, '__arrow_compute_function__')] + +exported_option_classes = [ + cls for (name, cls) in sorted(pc.__dict__.items()) + if (isinstance(cls, type) and + cls is not pc.FunctionOptions and + issubclass(cls, pc.FunctionOptions))] + +numerical_arrow_types = [ + pa.int8(), + pa.int16(), + pa.int64(), + pa.uint8(), + pa.uint16(), + pa.uint64(), + pa.float32(), + pa.float64() +] + + +def test_exported_functions(): + # Check that all exported concrete functions can be called with + # the right number of arguments. + # Note that unregistered functions (e.g. with a mismatching name) + # will raise KeyError. + functions = exported_functions + assert len(functions) >= 10 + for func in functions: + arity = func.__arrow_compute_function__['arity'] + if arity is Ellipsis: + args = [object()] * 3 + else: + args = [object()] * arity + with pytest.raises(TypeError, + match="Got unexpected argument type " + "<class 'object'> for compute function"): + func(*args) + + +def test_exported_option_classes(): + classes = exported_option_classes + assert len(classes) >= 10 + for cls in classes: + # Option classes must have an introspectable constructor signature, + # and that signature should not have any *args or **kwargs. + sig = inspect.signature(cls) + for param in sig.parameters.values(): + assert param.kind not in (param.VAR_POSITIONAL, + param.VAR_KEYWORD) + + +def test_option_class_equality(): + options = [ + pc.ArraySortOptions(), + pc.AssumeTimezoneOptions("UTC"), + pc.CastOptions.safe(pa.int8()), + pc.CountOptions(), + pc.DayOfWeekOptions(count_from_zero=False, week_start=0), + pc.DictionaryEncodeOptions(), + pc.ElementWiseAggregateOptions(skip_nulls=True), + pc.ExtractRegexOptions("pattern"), + pc.FilterOptions(), + pc.IndexOptions(pa.scalar(1)), + pc.JoinOptions(), + pc.MakeStructOptions(["field", "names"], + field_nullability=[True, True], + field_metadata=[pa.KeyValueMetadata({"a": "1"}), + pa.KeyValueMetadata({"b": "2"})]), + pc.MatchSubstringOptions("pattern"), + pc.ModeOptions(), + pc.NullOptions(), + pc.PadOptions(5), + pc.PartitionNthOptions(1, null_placement="at_start"), + pc.QuantileOptions(), + pc.ReplaceSliceOptions(0, 1, "a"), + pc.ReplaceSubstringOptions("a", "b"), + pc.RoundOptions(2, "towards_infinity"), + pc.RoundToMultipleOptions(100, "towards_infinity"), + pc.ScalarAggregateOptions(), + pc.SelectKOptions(0, sort_keys=[("b", "ascending")]), + pc.SetLookupOptions(pa.array([1])), + pc.SliceOptions(0, 1, 1), + pc.SortOptions([("dummy", "descending")], null_placement="at_start"), + pc.SplitOptions(), + pc.SplitPatternOptions("pattern"), + pc.StrftimeOptions(), + pc.StrptimeOptions("%Y", "s"), + pc.TakeOptions(), + pc.TDigestOptions(), + pc.TrimOptions(" "), + pc.VarianceOptions(), + pc.WeekOptions(week_starts_monday=True, count_from_zero=False, + first_week_is_fully_in_year=False), + ] + # TODO: We should test on windows once ARROW-13168 is resolved. + # Timezone database is not available on Windows yet + if sys.platform != 'win32': + options.append(pc.AssumeTimezoneOptions("Europe/Ljubljana")) + + classes = {type(option) for option in options} + for cls in exported_option_classes: + # Timezone database is not available on Windows yet + if cls not in classes and sys.platform != 'win32' and \ + cls != pc.AssumeTimezoneOptions: + try: + options.append(cls()) + except TypeError: + pytest.fail(f"Options class is not tested: {cls}") + for option in options: + assert option == option + assert repr(option).startswith(option.__class__.__name__) + buf = option.serialize() + deserialized = pc.FunctionOptions.deserialize(buf) + assert option == deserialized + assert repr(option) == repr(deserialized) + for option1, option2 in zip(options, options[1:]): + assert option1 != option2 + + assert repr(pc.IndexOptions(pa.scalar(1))) == "IndexOptions(value=int64:1)" + assert repr(pc.ArraySortOptions()) == \ + "ArraySortOptions(order=Ascending, null_placement=AtEnd)" + + +def test_list_functions(): + assert len(pc.list_functions()) > 10 + assert "add" in pc.list_functions() + + +def _check_get_function(name, expected_func_cls, expected_ker_cls, + min_num_kernels=1): + func = pc.get_function(name) + assert isinstance(func, expected_func_cls) + n = func.num_kernels + assert n >= min_num_kernels + assert n == len(func.kernels) + assert all(isinstance(ker, expected_ker_cls) for ker in func.kernels) + + +def test_get_function_scalar(): + _check_get_function("add", pc.ScalarFunction, pc.ScalarKernel, 8) + + +def test_get_function_vector(): + _check_get_function("unique", pc.VectorFunction, pc.VectorKernel, 8) + + +def test_get_function_scalar_aggregate(): + _check_get_function("mean", pc.ScalarAggregateFunction, + pc.ScalarAggregateKernel, 8) + + +def test_get_function_hash_aggregate(): + _check_get_function("hash_sum", pc.HashAggregateFunction, + pc.HashAggregateKernel, 1) + + +def test_call_function_with_memory_pool(): + arr = pa.array(["foo", "bar", "baz"]) + indices = np.array([2, 2, 1]) + result1 = arr.take(indices) + result2 = pc.call_function('take', [arr, indices], + memory_pool=pa.default_memory_pool()) + expected = pa.array(["baz", "baz", "bar"]) + assert result1.equals(expected) + assert result2.equals(expected) + + result3 = pc.take(arr, indices, memory_pool=pa.default_memory_pool()) + assert result3.equals(expected) + + +def test_pickle_functions(): + # Pickle registered functions + for name in pc.list_functions(): + func = pc.get_function(name) + reconstructed = pickle.loads(pickle.dumps(func)) + assert type(reconstructed) is type(func) + assert reconstructed.name == func.name + assert reconstructed.arity == func.arity + assert reconstructed.num_kernels == func.num_kernels + + +def test_pickle_global_functions(): + # Pickle global wrappers (manual or automatic) of registered functions + for name in pc.list_functions(): + func = getattr(pc, name) + reconstructed = pickle.loads(pickle.dumps(func)) + assert reconstructed is func + + +def test_function_attributes(): + # Sanity check attributes of registered functions + for name in pc.list_functions(): + func = pc.get_function(name) + assert isinstance(func, pc.Function) + assert func.name == name + kernels = func.kernels + assert func.num_kernels == len(kernels) + assert all(isinstance(ker, pc.Kernel) for ker in kernels) + if func.arity is not Ellipsis: + assert func.arity >= 1 + repr(func) + for ker in kernels: + repr(ker) + + +def test_input_type_conversion(): + # Automatic array conversion from Python + arr = pc.add([1, 2], [4, None]) + assert arr.to_pylist() == [5, None] + # Automatic scalar conversion from Python + arr = pc.add([1, 2], 4) + assert arr.to_pylist() == [5, 6] + # Other scalar type + assert pc.equal(["foo", "bar", None], + "foo").to_pylist() == [True, False, None] + + +@pytest.mark.parametrize('arrow_type', numerical_arrow_types) +def test_sum_array(arrow_type): + arr = pa.array([1, 2, 3, 4], type=arrow_type) + assert arr.sum().as_py() == 10 + assert pc.sum(arr).as_py() == 10 + + arr = pa.array([1, 2, 3, 4, None], type=arrow_type) + assert arr.sum().as_py() == 10 + assert pc.sum(arr).as_py() == 10 + + arr = pa.array([None], type=arrow_type) + assert arr.sum().as_py() is None # noqa: E711 + assert pc.sum(arr).as_py() is None # noqa: E711 + assert arr.sum(min_count=0).as_py() == 0 + assert pc.sum(arr, min_count=0).as_py() == 0 + + arr = pa.array([], type=arrow_type) + assert arr.sum().as_py() is None # noqa: E711 + assert arr.sum(min_count=0).as_py() == 0 + assert pc.sum(arr, min_count=0).as_py() == 0 + + +@pytest.mark.parametrize('arrow_type', numerical_arrow_types) +def test_sum_chunked_array(arrow_type): + arr = pa.chunked_array([pa.array([1, 2, 3, 4], type=arrow_type)]) + assert pc.sum(arr).as_py() == 10 + + arr = pa.chunked_array([ + pa.array([1, 2], type=arrow_type), pa.array([3, 4], type=arrow_type) + ]) + assert pc.sum(arr).as_py() == 10 + + arr = pa.chunked_array([ + pa.array([1, 2], type=arrow_type), + pa.array([], type=arrow_type), + pa.array([3, 4], type=arrow_type) + ]) + assert pc.sum(arr).as_py() == 10 + + arr = pa.chunked_array((), type=arrow_type) + assert arr.num_chunks == 0 + assert pc.sum(arr).as_py() is None # noqa: E711 + assert pc.sum(arr, min_count=0).as_py() == 0 + + +def test_mode_array(): + # ARROW-9917 + arr = pa.array([1, 1, 3, 4, 3, 5], type='int64') + mode = pc.mode(arr) + assert len(mode) == 1 + assert mode[0].as_py() == {"mode": 1, "count": 2} + + mode = pc.mode(arr, n=2) + assert len(mode) == 2 + assert mode[0].as_py() == {"mode": 1, "count": 2} + assert mode[1].as_py() == {"mode": 3, "count": 2} + + arr = pa.array([], type='int64') + assert len(pc.mode(arr)) == 0 + + arr = pa.array([1, 1, 3, 4, 3, None], type='int64') + mode = pc.mode(arr, skip_nulls=False) + assert len(mode) == 0 + mode = pc.mode(arr, min_count=6) + assert len(mode) == 0 + mode = pc.mode(arr, skip_nulls=False, min_count=5) + assert len(mode) == 0 + + +def test_mode_chunked_array(): + # ARROW-9917 + arr = pa.chunked_array([pa.array([1, 1, 3, 4, 3, 5], type='int64')]) + mode = pc.mode(arr) + assert len(mode) == 1 + assert mode[0].as_py() == {"mode": 1, "count": 2} + + mode = pc.mode(arr, n=2) + assert len(mode) == 2 + assert mode[0].as_py() == {"mode": 1, "count": 2} + assert mode[1].as_py() == {"mode": 3, "count": 2} + + arr = pa.chunked_array((), type='int64') + assert arr.num_chunks == 0 + assert len(pc.mode(arr)) == 0 + + +def test_variance(): + data = [1, 2, 3, 4, 5, 6, 7, 8] + assert pc.variance(data).as_py() == 5.25 + assert pc.variance(data, ddof=0).as_py() == 5.25 + assert pc.variance(data, ddof=1).as_py() == 6.0 + + +def test_count_substring(): + for (ty, offset) in [(pa.string(), pa.int32()), + (pa.large_string(), pa.int64())]: + arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None], type=ty) + + result = pc.count_substring(arr, "ab") + expected = pa.array([1, 1, 2, 0, 0, None], type=offset) + assert expected.equals(result) + + result = pc.count_substring(arr, "ab", ignore_case=True) + expected = pa.array([1, 1, 2, 0, 1, None], type=offset) + assert expected.equals(result) + + +def test_count_substring_regex(): + for (ty, offset) in [(pa.string(), pa.int32()), + (pa.large_string(), pa.int64())]: + arr = pa.array(["ab", "cab", "baAacaa", "ba", "AB", None], type=ty) + + result = pc.count_substring_regex(arr, "a+") + expected = pa.array([1, 1, 3, 1, 0, None], type=offset) + assert expected.equals(result) + + result = pc.count_substring_regex(arr, "a+", ignore_case=True) + expected = pa.array([1, 1, 2, 1, 1, None], type=offset) + assert expected.equals(result) + + +def test_find_substring(): + for ty in [pa.string(), pa.binary(), pa.large_string(), pa.large_binary()]: + arr = pa.array(["ab", "cab", "ba", None], type=ty) + result = pc.find_substring(arr, "ab") + assert result.to_pylist() == [0, 1, -1, None] + + result = pc.find_substring_regex(arr, "a?b") + assert result.to_pylist() == [0, 1, 0, None] + + arr = pa.array(["ab*", "cAB*", "ba", "aB?"], type=ty) + result = pc.find_substring(arr, "aB*", ignore_case=True) + assert result.to_pylist() == [0, 1, -1, -1] + + result = pc.find_substring_regex(arr, "a?b", ignore_case=True) + assert result.to_pylist() == [0, 1, 0, 0] + + +def test_match_like(): + arr = pa.array(["ab", "ba%", "ba", "ca%d", None]) + result = pc.match_like(arr, r"_a\%%") + expected = pa.array([False, True, False, True, None]) + assert expected.equals(result) + + arr = pa.array(["aB", "bA%", "ba", "ca%d", None]) + result = pc.match_like(arr, r"_a\%%", ignore_case=True) + expected = pa.array([False, True, False, True, None]) + assert expected.equals(result) + result = pc.match_like(arr, r"_a\%%", ignore_case=False) + expected = pa.array([False, False, False, True, None]) + assert expected.equals(result) + + +def test_match_substring(): + arr = pa.array(["ab", "abc", "ba", None]) + result = pc.match_substring(arr, "ab") + expected = pa.array([True, True, False, None]) + assert expected.equals(result) + + arr = pa.array(["áB", "Ábc", "ba", None]) + result = pc.match_substring(arr, "áb", ignore_case=True) + expected = pa.array([True, True, False, None]) + assert expected.equals(result) + result = pc.match_substring(arr, "áb", ignore_case=False) + expected = pa.array([False, False, False, None]) + assert expected.equals(result) + + +def test_match_substring_regex(): + arr = pa.array(["ab", "abc", "ba", "c", None]) + result = pc.match_substring_regex(arr, "^a?b") + expected = pa.array([True, True, True, False, None]) + assert expected.equals(result) + + arr = pa.array(["aB", "Abc", "BA", "c", None]) + result = pc.match_substring_regex(arr, "^a?b", ignore_case=True) + expected = pa.array([True, True, True, False, None]) + assert expected.equals(result) + result = pc.match_substring_regex(arr, "^a?b", ignore_case=False) + expected = pa.array([False, False, False, False, None]) + assert expected.equals(result) + + +def test_trim(): + # \u3000 is unicode whitespace + arr = pa.array([" foo", None, " \u3000foo bar \t"]) + result = pc.utf8_trim_whitespace(arr) + expected = pa.array(["foo", None, "foo bar"]) + assert expected.equals(result) + + arr = pa.array([" foo", None, " \u3000foo bar \t"]) + result = pc.ascii_trim_whitespace(arr) + expected = pa.array(["foo", None, "\u3000foo bar"]) + assert expected.equals(result) + + arr = pa.array([" foo", None, " \u3000foo bar \t"]) + result = pc.utf8_trim(arr, characters=' f\u3000') + expected = pa.array(["oo", None, "oo bar \t"]) + assert expected.equals(result) + + +def test_slice_compatibility(): + arr = pa.array(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"]) + for start in range(-6, 6): + for stop in range(-6, 6): + for step in [-3, -2, -1, 1, 2, 3]: + expected = pa.array([k.as_py()[start:stop:step] + for k in arr]) + result = pc.utf8_slice_codeunits( + arr, start=start, stop=stop, step=step) + assert expected.equals(result) + + +def test_split_pattern(): + arr = pa.array(["-foo---bar--", "---foo---b"]) + result = pc.split_pattern(arr, pattern="---") + expected = pa.array([["-foo", "bar--"], ["", "foo", "b"]]) + assert expected.equals(result) + + result = pc.split_pattern(arr, pattern="---", max_splits=1) + expected = pa.array([["-foo", "bar--"], ["", "foo---b"]]) + assert expected.equals(result) + + result = pc.split_pattern(arr, pattern="---", max_splits=1, reverse=True) + expected = pa.array([["-foo", "bar--"], ["---foo", "b"]]) + assert expected.equals(result) + + +def test_split_whitespace_utf8(): + arr = pa.array(["foo bar", " foo \u3000\tb"]) + result = pc.utf8_split_whitespace(arr) + expected = pa.array([["foo", "bar"], ["", "foo", "b"]]) + assert expected.equals(result) + + result = pc.utf8_split_whitespace(arr, max_splits=1) + expected = pa.array([["foo", "bar"], ["", "foo \u3000\tb"]]) + assert expected.equals(result) + + result = pc.utf8_split_whitespace(arr, max_splits=1, reverse=True) + expected = pa.array([["foo", "bar"], [" foo", "b"]]) + assert expected.equals(result) + + +def test_split_whitespace_ascii(): + arr = pa.array(["foo bar", " foo \u3000\tb"]) + result = pc.ascii_split_whitespace(arr) + expected = pa.array([["foo", "bar"], ["", "foo", "\u3000", "b"]]) + assert expected.equals(result) + + result = pc.ascii_split_whitespace(arr, max_splits=1) + expected = pa.array([["foo", "bar"], ["", "foo \u3000\tb"]]) + assert expected.equals(result) + + result = pc.ascii_split_whitespace(arr, max_splits=1, reverse=True) + expected = pa.array([["foo", "bar"], [" foo \u3000", "b"]]) + assert expected.equals(result) + + +def test_split_pattern_regex(): + arr = pa.array(["-foo---bar--", "---foo---b"]) + result = pc.split_pattern_regex(arr, pattern="-+") + expected = pa.array([["", "foo", "bar", ""], ["", "foo", "b"]]) + assert expected.equals(result) + + result = pc.split_pattern_regex(arr, pattern="-+", max_splits=1) + expected = pa.array([["", "foo---bar--"], ["", "foo---b"]]) + assert expected.equals(result) + + with pytest.raises(NotImplementedError, + match="Cannot split in reverse with regex"): + result = pc.split_pattern_regex( + arr, pattern="---", max_splits=1, reverse=True) + + +def test_min_max(): + # An example generated function wrapper with possible options + data = [4, 5, 6, None, 1] + s = pc.min_max(data) + assert s.as_py() == {'min': 1, 'max': 6} + s = pc.min_max(data, options=pc.ScalarAggregateOptions()) + assert s.as_py() == {'min': 1, 'max': 6} + s = pc.min_max(data, options=pc.ScalarAggregateOptions(skip_nulls=True)) + assert s.as_py() == {'min': 1, 'max': 6} + s = pc.min_max(data, options=pc.ScalarAggregateOptions(skip_nulls=False)) + assert s.as_py() == {'min': None, 'max': None} + + # Options as dict of kwargs + s = pc.min_max(data, options={'skip_nulls': False}) + assert s.as_py() == {'min': None, 'max': None} + # Options as named functions arguments + s = pc.min_max(data, skip_nulls=False) + assert s.as_py() == {'min': None, 'max': None} + + # Both options and named arguments + with pytest.raises(TypeError): + s = pc.min_max( + data, options=pc.ScalarAggregateOptions(), skip_nulls=False) + + # Wrong options type + options = pc.TakeOptions() + with pytest.raises(TypeError): + s = pc.min_max(data, options=options) + + # Missing argument + with pytest.raises(ValueError, + match="Function min_max accepts 1 argument"): + s = pc.min_max() + + +def test_any(): + # ARROW-1846 + + options = pc.ScalarAggregateOptions(skip_nulls=False, min_count=0) + + a = pa.array([], type='bool') + assert pc.any(a).as_py() is None + assert pc.any(a, min_count=0).as_py() is False + assert pc.any(a, options=options).as_py() is False + + a = pa.array([False, None, True]) + assert pc.any(a).as_py() is True + assert pc.any(a, options=options).as_py() is True + + a = pa.array([False, None, False]) + assert pc.any(a).as_py() is False + assert pc.any(a, options=options).as_py() is None + + +def test_all(): + # ARROW-10301 + + options = pc.ScalarAggregateOptions(skip_nulls=False, min_count=0) + + a = pa.array([], type='bool') + assert pc.all(a).as_py() is None + assert pc.all(a, min_count=0).as_py() is True + assert pc.all(a, options=options).as_py() is True + + a = pa.array([False, True]) + assert pc.all(a).as_py() is False + assert pc.all(a, options=options).as_py() is False + + a = pa.array([True, None]) + assert pc.all(a).as_py() is True + assert pc.all(a, options=options).as_py() is None + + a = pa.chunked_array([[True], [True, None]]) + assert pc.all(a).as_py() is True + assert pc.all(a, options=options).as_py() is None + + a = pa.chunked_array([[True], [False]]) + assert pc.all(a).as_py() is False + assert pc.all(a, options=options).as_py() is False + + +def test_is_valid(): + # An example generated function wrapper without options + data = [4, 5, None] + assert pc.is_valid(data).to_pylist() == [True, True, False] + + with pytest.raises(TypeError): + pc.is_valid(data, options=None) + + +def test_generated_docstrings(): + assert pc.min_max.__doc__ == textwrap.dedent("""\ + Compute the minimum and maximum values of a numeric array. + + Null values are ignored by default. + This can be changed through ScalarAggregateOptions. + + Parameters + ---------- + array : Array-like + Argument to compute function + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + options : pyarrow.compute.ScalarAggregateOptions, optional + Parameters altering compute function semantics. + skip_nulls : optional + Parameter for ScalarAggregateOptions constructor. Either `options` + or `skip_nulls` can be passed, but not both at the same time. + min_count : optional + Parameter for ScalarAggregateOptions constructor. Either `options` + or `min_count` can be passed, but not both at the same time. + """) + assert pc.add.__doc__ == textwrap.dedent("""\ + Add the arguments element-wise. + + Results will wrap around on integer overflow. + Use function "add_checked" if you want overflow + to return an error. + + Parameters + ---------- + x : Array-like or scalar-like + Argument to compute function + y : Array-like or scalar-like + Argument to compute function + memory_pool : pyarrow.MemoryPool, optional + If not passed, will allocate memory from the default memory pool. + """) + + +def test_generated_signatures(): + # The self-documentation provided by signatures should show acceptable + # options and their default values. + sig = inspect.signature(pc.add) + assert str(sig) == "(x, y, *, memory_pool=None)" + sig = inspect.signature(pc.min_max) + assert str(sig) == ("(array, *, memory_pool=None, " + "options=None, skip_nulls=True, min_count=1)") + sig = inspect.signature(pc.quantile) + assert str(sig) == ("(array, *, memory_pool=None, " + "options=None, q=0.5, interpolation='linear', " + "skip_nulls=True, min_count=0)") + sig = inspect.signature(pc.binary_join_element_wise) + assert str(sig) == ("(*strings, memory_pool=None, options=None, " + "null_handling='emit_null', null_replacement='')") + + +# We use isprintable to find about codepoints that Python doesn't know, but +# utf8proc does (or in a future version of Python the other way around). +# These codepoints cannot be compared between Arrow and the Python +# implementation. +@lru_cache() +def find_new_unicode_codepoints(): + new = set() + characters = [chr(c) for c in range(0x80, 0x11000) + if not (0xD800 <= c < 0xE000)] + is_printable = pc.utf8_is_printable(pa.array(characters)).to_pylist() + for i, c in enumerate(characters): + if is_printable[i] != c.isprintable(): + new.add(ord(c)) + return new + + +# Python claims there are not alpha, not sure why, they are in +# gc='Other Letter': https://graphemica.com/%E1%B3%B2 +unknown_issue_is_alpha = {0x1cf2, 0x1cf3} +# utf8proc does not know if codepoints are lower case +utf8proc_issue_is_lower = { + 0xaa, 0xba, 0x2b0, 0x2b1, 0x2b2, 0x2b3, 0x2b4, + 0x2b5, 0x2b6, 0x2b7, 0x2b8, 0x2c0, 0x2c1, 0x2e0, + 0x2e1, 0x2e2, 0x2e3, 0x2e4, 0x37a, 0x1d2c, 0x1d2d, + 0x1d2e, 0x1d2f, 0x1d30, 0x1d31, 0x1d32, 0x1d33, + 0x1d34, 0x1d35, 0x1d36, 0x1d37, 0x1d38, 0x1d39, + 0x1d3a, 0x1d3b, 0x1d3c, 0x1d3d, 0x1d3e, 0x1d3f, + 0x1d40, 0x1d41, 0x1d42, 0x1d43, 0x1d44, 0x1d45, + 0x1d46, 0x1d47, 0x1d48, 0x1d49, 0x1d4a, 0x1d4b, + 0x1d4c, 0x1d4d, 0x1d4e, 0x1d4f, 0x1d50, 0x1d51, + 0x1d52, 0x1d53, 0x1d54, 0x1d55, 0x1d56, 0x1d57, + 0x1d58, 0x1d59, 0x1d5a, 0x1d5b, 0x1d5c, 0x1d5d, + 0x1d5e, 0x1d5f, 0x1d60, 0x1d61, 0x1d62, 0x1d63, + 0x1d64, 0x1d65, 0x1d66, 0x1d67, 0x1d68, 0x1d69, + 0x1d6a, 0x1d78, 0x1d9b, 0x1d9c, 0x1d9d, 0x1d9e, + 0x1d9f, 0x1da0, 0x1da1, 0x1da2, 0x1da3, 0x1da4, + 0x1da5, 0x1da6, 0x1da7, 0x1da8, 0x1da9, 0x1daa, + 0x1dab, 0x1dac, 0x1dad, 0x1dae, 0x1daf, 0x1db0, + 0x1db1, 0x1db2, 0x1db3, 0x1db4, 0x1db5, 0x1db6, + 0x1db7, 0x1db8, 0x1db9, 0x1dba, 0x1dbb, 0x1dbc, + 0x1dbd, 0x1dbe, 0x1dbf, 0x2071, 0x207f, 0x2090, + 0x2091, 0x2092, 0x2093, 0x2094, 0x2095, 0x2096, + 0x2097, 0x2098, 0x2099, 0x209a, 0x209b, 0x209c, + 0x2c7c, 0x2c7d, 0xa69c, 0xa69d, 0xa770, 0xa7f8, + 0xa7f9, 0xab5c, 0xab5d, 0xab5e, 0xab5f, } +# utf8proc does not store if a codepoint is numeric +numeric_info_missing = { + 0x3405, 0x3483, 0x382a, 0x3b4d, 0x4e00, 0x4e03, + 0x4e07, 0x4e09, 0x4e5d, 0x4e8c, 0x4e94, 0x4e96, + 0x4ebf, 0x4ec0, 0x4edf, 0x4ee8, 0x4f0d, 0x4f70, + 0x5104, 0x5146, 0x5169, 0x516b, 0x516d, 0x5341, + 0x5343, 0x5344, 0x5345, 0x534c, 0x53c1, 0x53c2, + 0x53c3, 0x53c4, 0x56db, 0x58f1, 0x58f9, 0x5e7a, + 0x5efe, 0x5eff, 0x5f0c, 0x5f0d, 0x5f0e, 0x5f10, + 0x62fe, 0x634c, 0x67d2, 0x6f06, 0x7396, 0x767e, + 0x8086, 0x842c, 0x8cae, 0x8cb3, 0x8d30, 0x9621, + 0x9646, 0x964c, 0x9678, 0x96f6, 0xf96b, 0xf973, + 0xf978, 0xf9b2, 0xf9d1, 0xf9d3, 0xf9fd, 0x10fc5, + 0x10fc6, 0x10fc7, 0x10fc8, 0x10fc9, 0x10fca, + 0x10fcb, } +# utf8proc has no no digit/numeric information +digit_info_missing = { + 0xb2, 0xb3, 0xb9, 0x1369, 0x136a, 0x136b, 0x136c, + 0x136d, 0x136e, 0x136f, 0x1370, 0x1371, 0x19da, 0x2070, + 0x2074, 0x2075, 0x2076, 0x2077, 0x2078, 0x2079, 0x2080, + 0x2081, 0x2082, 0x2083, 0x2084, 0x2085, 0x2086, 0x2087, + 0x2088, 0x2089, 0x2460, 0x2461, 0x2462, 0x2463, 0x2464, + 0x2465, 0x2466, 0x2467, 0x2468, 0x2474, 0x2475, 0x2476, + 0x2477, 0x2478, 0x2479, 0x247a, 0x247b, 0x247c, 0x2488, + 0x2489, 0x248a, 0x248b, 0x248c, 0x248d, 0x248e, 0x248f, + 0x2490, 0x24ea, 0x24f5, 0x24f6, 0x24f7, 0x24f8, 0x24f9, + 0x24fa, 0x24fb, 0x24fc, 0x24fd, 0x24ff, 0x2776, 0x2777, + 0x2778, 0x2779, 0x277a, 0x277b, 0x277c, 0x277d, 0x277e, + 0x2780, 0x2781, 0x2782, 0x2783, 0x2784, 0x2785, 0x2786, + 0x2787, 0x2788, 0x278a, 0x278b, 0x278c, 0x278d, 0x278e, + 0x278f, 0x2790, 0x2791, 0x2792, 0x10a40, 0x10a41, + 0x10a42, 0x10a43, 0x10e60, 0x10e61, 0x10e62, 0x10e63, + 0x10e64, 0x10e65, 0x10e66, 0x10e67, 0x10e68, } +numeric_info_missing = { + 0x3405, 0x3483, 0x382a, 0x3b4d, 0x4e00, 0x4e03, + 0x4e07, 0x4e09, 0x4e5d, 0x4e8c, 0x4e94, 0x4e96, + 0x4ebf, 0x4ec0, 0x4edf, 0x4ee8, 0x4f0d, 0x4f70, + 0x5104, 0x5146, 0x5169, 0x516b, 0x516d, 0x5341, + 0x5343, 0x5344, 0x5345, 0x534c, 0x53c1, 0x53c2, + 0x53c3, 0x53c4, 0x56db, 0x58f1, 0x58f9, 0x5e7a, + 0x5efe, 0x5eff, 0x5f0c, 0x5f0d, 0x5f0e, 0x5f10, + 0x62fe, 0x634c, 0x67d2, 0x6f06, 0x7396, 0x767e, + 0x8086, 0x842c, 0x8cae, 0x8cb3, 0x8d30, 0x9621, + 0x9646, 0x964c, 0x9678, 0x96f6, 0xf96b, 0xf973, + 0xf978, 0xf9b2, 0xf9d1, 0xf9d3, 0xf9fd, } + +codepoints_ignore = { + 'is_alnum': numeric_info_missing | digit_info_missing | + unknown_issue_is_alpha, + 'is_alpha': unknown_issue_is_alpha, + 'is_digit': digit_info_missing, + 'is_numeric': numeric_info_missing, + 'is_lower': utf8proc_issue_is_lower +} + + +@pytest.mark.parametrize('function_name', ['is_alnum', 'is_alpha', + 'is_ascii', 'is_decimal', + 'is_digit', 'is_lower', + 'is_numeric', 'is_printable', + 'is_space', 'is_upper', ]) +@pytest.mark.parametrize('variant', ['ascii', 'utf8']) +def test_string_py_compat_boolean(function_name, variant): + arrow_name = variant + "_" + function_name + py_name = function_name.replace('_', '') + ignore = codepoints_ignore.get(function_name, set()) | \ + find_new_unicode_codepoints() + for i in range(128 if ascii else 0x11000): + if i in range(0xD800, 0xE000): + continue # bug? pyarrow doesn't allow utf16 surrogates + # the issues we know of, we skip + if i in ignore: + continue + # Compare results with the equivalent Python predicate + # (except "is_space" where functions are known to be incompatible) + c = chr(i) + if hasattr(pc, arrow_name) and function_name != 'is_space': + ar = pa.array([c]) + arrow_func = getattr(pc, arrow_name) + assert arrow_func(ar)[0].as_py() == getattr(c, py_name)() + + +def test_pad(): + arr = pa.array([None, 'a', 'abcd']) + assert pc.ascii_center(arr, width=3).tolist() == [None, ' a ', 'abcd'] + assert pc.ascii_lpad(arr, width=3).tolist() == [None, ' a', 'abcd'] + assert pc.ascii_rpad(arr, width=3).tolist() == [None, 'a ', 'abcd'] + + arr = pa.array([None, 'á', 'abcd']) + assert pc.utf8_center(arr, width=3).tolist() == [None, ' á ', 'abcd'] + assert pc.utf8_lpad(arr, width=3).tolist() == [None, ' á', 'abcd'] + assert pc.utf8_rpad(arr, width=3).tolist() == [None, 'á ', 'abcd'] + + +@pytest.mark.pandas +def test_replace_slice(): + offsets = range(-3, 4) + + arr = pa.array([None, '', 'a', 'ab', 'abc', 'abcd', 'abcde']) + series = arr.to_pandas() + for start in offsets: + for stop in offsets: + expected = series.str.slice_replace(start, stop, 'XX') + actual = pc.binary_replace_slice( + arr, start=start, stop=stop, replacement='XX') + assert actual.tolist() == expected.tolist() + + arr = pa.array([None, '', 'π', 'πb', 'πbθ', 'πbθd', 'πbθde']) + series = arr.to_pandas() + for start in offsets: + for stop in offsets: + expected = series.str.slice_replace(start, stop, 'XX') + actual = pc.utf8_replace_slice( + arr, start=start, stop=stop, replacement='XX') + assert actual.tolist() == expected.tolist() + + +def test_replace_plain(): + ar = pa.array(['foo', 'food', None]) + ar = pc.replace_substring(ar, pattern='foo', replacement='bar') + assert ar.tolist() == ['bar', 'bard', None] + + +def test_replace_regex(): + ar = pa.array(['foo', 'mood', None]) + ar = pc.replace_substring_regex(ar, pattern='(.)oo', replacement=r'\100') + assert ar.tolist() == ['f00', 'm00d', None] + + +def test_extract_regex(): + ar = pa.array(['a1', 'zb2z']) + struct = pc.extract_regex(ar, pattern=r'(?P<letter>[ab])(?P<digit>\d)') + assert struct.tolist() == [{'letter': 'a', 'digit': '1'}, { + 'letter': 'b', 'digit': '2'}] + + +def test_binary_join(): + ar_list = pa.array([['foo', 'bar'], None, []]) + expected = pa.array(['foo-bar', None, '']) + assert pc.binary_join(ar_list, '-').equals(expected) + + separator_array = pa.array(['1', '2'], type=pa.binary()) + expected = pa.array(['a1b', 'c2d'], type=pa.binary()) + ar_list = pa.array([['a', 'b'], ['c', 'd']], type=pa.list_(pa.binary())) + assert pc.binary_join(ar_list, separator_array).equals(expected) + + +def test_binary_join_element_wise(): + null = pa.scalar(None, type=pa.string()) + arrs = [[None, 'a', 'b'], ['c', None, 'd'], [None, '-', '--']] + assert pc.binary_join_element_wise(*arrs).to_pylist() == \ + [None, None, 'b--d'] + assert pc.binary_join_element_wise('a', 'b', '-').as_py() == 'a-b' + assert pc.binary_join_element_wise('a', null, '-').as_py() is None + assert pc.binary_join_element_wise('a', 'b', null).as_py() is None + + skip = pc.JoinOptions(null_handling='skip') + assert pc.binary_join_element_wise(*arrs, options=skip).to_pylist() == \ + [None, 'a', 'b--d'] + assert pc.binary_join_element_wise( + 'a', 'b', '-', options=skip).as_py() == 'a-b' + assert pc.binary_join_element_wise( + 'a', null, '-', options=skip).as_py() == 'a' + assert pc.binary_join_element_wise( + 'a', 'b', null, options=skip).as_py() is None + + replace = pc.JoinOptions(null_handling='replace', null_replacement='spam') + assert pc.binary_join_element_wise(*arrs, options=replace).to_pylist() == \ + [None, 'a-spam', 'b--d'] + assert pc.binary_join_element_wise( + 'a', 'b', '-', options=replace).as_py() == 'a-b' + assert pc.binary_join_element_wise( + 'a', null, '-', options=replace).as_py() == 'a-spam' + assert pc.binary_join_element_wise( + 'a', 'b', null, options=replace).as_py() is None + + +@pytest.mark.parametrize(('ty', 'values'), all_array_types) +def test_take(ty, values): + arr = pa.array(values, type=ty) + for indices_type in [pa.int8(), pa.int64()]: + indices = pa.array([0, 4, 2, None], type=indices_type) + result = arr.take(indices) + result.validate() + expected = pa.array([values[0], values[4], values[2], None], type=ty) + assert result.equals(expected) + + # empty indices + indices = pa.array([], type=indices_type) + result = arr.take(indices) + result.validate() + expected = pa.array([], type=ty) + assert result.equals(expected) + + indices = pa.array([2, 5]) + with pytest.raises(IndexError): + arr.take(indices) + + indices = pa.array([2, -1]) + with pytest.raises(IndexError): + arr.take(indices) + + +def test_take_indices_types(): + arr = pa.array(range(5)) + + for indices_type in ['uint8', 'int8', 'uint16', 'int16', + 'uint32', 'int32', 'uint64', 'int64']: + indices = pa.array([0, 4, 2, None], type=indices_type) + result = arr.take(indices) + result.validate() + expected = pa.array([0, 4, 2, None]) + assert result.equals(expected) + + for indices_type in [pa.float32(), pa.float64()]: + indices = pa.array([0, 4, 2], type=indices_type) + with pytest.raises(NotImplementedError): + arr.take(indices) + + +def test_take_on_chunked_array(): + # ARROW-9504 + arr = pa.chunked_array([ + [ + "a", + "b", + "c", + "d", + "e" + ], + [ + "f", + "g", + "h", + "i", + "j" + ] + ]) + + indices = np.array([0, 5, 1, 6, 9, 2]) + result = arr.take(indices) + expected = pa.chunked_array([["a", "f", "b", "g", "j", "c"]]) + assert result.equals(expected) + + indices = pa.chunked_array([[1], [9, 2]]) + result = arr.take(indices) + expected = pa.chunked_array([ + [ + "b" + ], + [ + "j", + "c" + ] + ]) + assert result.equals(expected) + + +@pytest.mark.parametrize('ordered', [False, True]) +def test_take_dictionary(ordered): + arr = pa.DictionaryArray.from_arrays([0, 1, 2, 0, 1, 2], ['a', 'b', 'c'], + ordered=ordered) + result = arr.take(pa.array([0, 1, 3])) + result.validate() + assert result.to_pylist() == ['a', 'b', 'a'] + assert result.dictionary.to_pylist() == ['a', 'b', 'c'] + assert result.type.ordered is ordered + + +def test_take_null_type(): + # ARROW-10027 + arr = pa.array([None] * 10) + chunked_arr = pa.chunked_array([[None] * 5] * 2) + batch = pa.record_batch([arr], names=['a']) + table = pa.table({'a': arr}) + + indices = pa.array([1, 3, 7, None]) + assert len(arr.take(indices)) == 4 + assert len(chunked_arr.take(indices)) == 4 + assert len(batch.take(indices).column(0)) == 4 + assert len(table.take(indices).column(0)) == 4 + + +@pytest.mark.parametrize(('ty', 'values'), all_array_types) +def test_drop_null(ty, values): + arr = pa.array(values, type=ty) + result = arr.drop_null() + result.validate(full=True) + indices = [i for i in range(len(arr)) if arr[i].is_valid] + expected = arr.take(pa.array(indices)) + assert result.equals(expected) + + +def test_drop_null_chunked_array(): + arr = pa.chunked_array([["a", None], ["c", "d", None], [None], []]) + expected_drop = pa.chunked_array([["a"], ["c", "d"], [], []]) + + result = arr.drop_null() + assert result.equals(expected_drop) + + +def test_drop_null_record_batch(): + batch = pa.record_batch( + [pa.array(["a", None, "c", "d", None])], names=["a'"]) + result = batch.drop_null() + expected = pa.record_batch([pa.array(["a", "c", "d"])], names=["a'"]) + assert result.equals(expected) + + batch = pa.record_batch( + [pa.array(["a", None, "c", "d", None]), + pa.array([None, None, "c", None, "e"])], names=["a'", "b'"]) + + result = batch.drop_null() + expected = pa.record_batch( + [pa.array(["c"]), pa.array(["c"])], names=["a'", "b'"]) + assert result.equals(expected) + + +def test_drop_null_table(): + table = pa.table([pa.array(["a", None, "c", "d", None])], names=["a"]) + expected = pa.table([pa.array(["a", "c", "d"])], names=["a"]) + result = table.drop_null() + assert result.equals(expected) + + table = pa.table([pa.chunked_array([["a", None], ["c", "d", None]]), + pa.chunked_array([["a", None], [None, "d", None]]), + pa.chunked_array([["a"], ["b"], [None], ["d", None]])], + names=["a", "b", "c"]) + expected = pa.table([pa.array(["a", "d"]), + pa.array(["a", "d"]), + pa.array(["a", "d"])], + names=["a", "b", "c"]) + result = table.drop_null() + assert result.equals(expected) + + table = pa.table([pa.chunked_array([["a", "b"], ["c", "d", "e"]]), + pa.chunked_array([["A"], ["B"], [None], ["D", None]]), + pa.chunked_array([["a`", None], ["c`", "d`", None]])], + names=["a", "b", "c"]) + expected = pa.table([pa.array(["a", "d"]), + pa.array(["A", "D"]), + pa.array(["a`", "d`"])], + names=["a", "b", "c"]) + result = table.drop_null() + assert result.equals(expected) + + +def test_drop_null_null_type(): + arr = pa.array([None] * 10) + chunked_arr = pa.chunked_array([[None] * 5] * 2) + batch = pa.record_batch([arr], names=['a']) + table = pa.table({'a': arr}) + + assert len(arr.drop_null()) == 0 + assert len(chunked_arr.drop_null()) == 0 + assert len(batch.drop_null().column(0)) == 0 + assert len(table.drop_null().column(0)) == 0 + + +@pytest.mark.parametrize(('ty', 'values'), all_array_types) +def test_filter(ty, values): + arr = pa.array(values, type=ty) + + mask = pa.array([True, False, False, True, None]) + result = arr.filter(mask, null_selection_behavior='drop') + result.validate() + assert result.equals(pa.array([values[0], values[3]], type=ty)) + result = arr.filter(mask, null_selection_behavior='emit_null') + result.validate() + assert result.equals(pa.array([values[0], values[3], None], type=ty)) + + # non-boolean dtype + mask = pa.array([0, 1, 0, 1, 0]) + with pytest.raises(NotImplementedError): + arr.filter(mask) + + # wrong length + mask = pa.array([True, False, True]) + with pytest.raises(ValueError, match="must all be the same length"): + arr.filter(mask) + + +def test_filter_chunked_array(): + arr = pa.chunked_array([["a", None], ["c", "d", "e"]]) + expected_drop = pa.chunked_array([["a"], ["e"]]) + expected_null = pa.chunked_array([["a"], [None, "e"]]) + + for mask in [ + # mask is array + pa.array([True, False, None, False, True]), + # mask is chunked array + pa.chunked_array([[True, False, None], [False, True]]), + # mask is python object + [True, False, None, False, True] + ]: + result = arr.filter(mask) + assert result.equals(expected_drop) + result = arr.filter(mask, null_selection_behavior="emit_null") + assert result.equals(expected_null) + + +def test_filter_record_batch(): + batch = pa.record_batch( + [pa.array(["a", None, "c", "d", "e"])], names=["a'"]) + + # mask is array + mask = pa.array([True, False, None, False, True]) + result = batch.filter(mask) + expected = pa.record_batch([pa.array(["a", "e"])], names=["a'"]) + assert result.equals(expected) + + result = batch.filter(mask, null_selection_behavior="emit_null") + expected = pa.record_batch([pa.array(["a", None, "e"])], names=["a'"]) + assert result.equals(expected) + + +def test_filter_table(): + table = pa.table([pa.array(["a", None, "c", "d", "e"])], names=["a"]) + expected_drop = pa.table([pa.array(["a", "e"])], names=["a"]) + expected_null = pa.table([pa.array(["a", None, "e"])], names=["a"]) + + for mask in [ + # mask is array + pa.array([True, False, None, False, True]), + # mask is chunked array + pa.chunked_array([[True, False], [None, False, True]]), + # mask is python object + [True, False, None, False, True] + ]: + result = table.filter(mask) + assert result.equals(expected_drop) + result = table.filter(mask, null_selection_behavior="emit_null") + assert result.equals(expected_null) + + +def test_filter_errors(): + arr = pa.chunked_array([["a", None], ["c", "d", "e"]]) + batch = pa.record_batch( + [pa.array(["a", None, "c", "d", "e"])], names=["a'"]) + table = pa.table([pa.array(["a", None, "c", "d", "e"])], names=["a"]) + + for obj in [arr, batch, table]: + # non-boolean dtype + mask = pa.array([0, 1, 0, 1, 0]) + with pytest.raises(NotImplementedError): + obj.filter(mask) + + # wrong length + mask = pa.array([True, False, True]) + with pytest.raises(pa.ArrowInvalid, + match="must all be the same length"): + obj.filter(mask) + + +def test_filter_null_type(): + # ARROW-10027 + arr = pa.array([None] * 10) + chunked_arr = pa.chunked_array([[None] * 5] * 2) + batch = pa.record_batch([arr], names=['a']) + table = pa.table({'a': arr}) + + mask = pa.array([True, False] * 5) + assert len(arr.filter(mask)) == 5 + assert len(chunked_arr.filter(mask)) == 5 + assert len(batch.filter(mask).column(0)) == 5 + assert len(table.filter(mask).column(0)) == 5 + + +@pytest.mark.parametrize("typ", ["array", "chunked_array"]) +def test_compare_array(typ): + if typ == "array": + def con(values): + return pa.array(values) + else: + def con(values): + return pa.chunked_array([values]) + + arr1 = con([1, 2, 3, 4, None]) + arr2 = con([1, 1, 4, None, 4]) + + result = pc.equal(arr1, arr2) + assert result.equals(con([True, False, False, None, None])) + + result = pc.not_equal(arr1, arr2) + assert result.equals(con([False, True, True, None, None])) + + result = pc.less(arr1, arr2) + assert result.equals(con([False, False, True, None, None])) + + result = pc.less_equal(arr1, arr2) + assert result.equals(con([True, False, True, None, None])) + + result = pc.greater(arr1, arr2) + assert result.equals(con([False, True, False, None, None])) + + result = pc.greater_equal(arr1, arr2) + assert result.equals(con([True, True, False, None, None])) + + +@pytest.mark.parametrize("typ", ["array", "chunked_array"]) +def test_compare_string_scalar(typ): + if typ == "array": + def con(values): + return pa.array(values) + else: + def con(values): + return pa.chunked_array([values]) + + arr = con(['a', 'b', 'c', None]) + scalar = pa.scalar('b') + + result = pc.equal(arr, scalar) + assert result.equals(con([False, True, False, None])) + + if typ == "array": + nascalar = pa.scalar(None, type="string") + result = pc.equal(arr, nascalar) + isnull = pc.is_null(result) + assert isnull.equals(con([True, True, True, True])) + + result = pc.not_equal(arr, scalar) + assert result.equals(con([True, False, True, None])) + + result = pc.less(arr, scalar) + assert result.equals(con([True, False, False, None])) + + result = pc.less_equal(arr, scalar) + assert result.equals(con([True, True, False, None])) + + result = pc.greater(arr, scalar) + assert result.equals(con([False, False, True, None])) + + result = pc.greater_equal(arr, scalar) + assert result.equals(con([False, True, True, None])) + + +@pytest.mark.parametrize("typ", ["array", "chunked_array"]) +def test_compare_scalar(typ): + if typ == "array": + def con(values): + return pa.array(values) + else: + def con(values): + return pa.chunked_array([values]) + + arr = con([1, 2, 3, None]) + scalar = pa.scalar(2) + + result = pc.equal(arr, scalar) + assert result.equals(con([False, True, False, None])) + + if typ == "array": + nascalar = pa.scalar(None, type="int64") + result = pc.equal(arr, nascalar) + assert result.to_pylist() == [None, None, None, None] + + result = pc.not_equal(arr, scalar) + assert result.equals(con([True, False, True, None])) + + result = pc.less(arr, scalar) + assert result.equals(con([True, False, False, None])) + + result = pc.less_equal(arr, scalar) + assert result.equals(con([True, True, False, None])) + + result = pc.greater(arr, scalar) + assert result.equals(con([False, False, True, None])) + + result = pc.greater_equal(arr, scalar) + assert result.equals(con([False, True, True, None])) + + +def test_compare_chunked_array_mixed(): + arr = pa.array([1, 2, 3, 4, None]) + arr_chunked = pa.chunked_array([[1, 2, 3], [4, None]]) + arr_chunked2 = pa.chunked_array([[1, 2], [3, 4, None]]) + + expected = pa.chunked_array([[True, True, True, True, None]]) + + for left, right in [ + (arr, arr_chunked), + (arr_chunked, arr), + (arr_chunked, arr_chunked2), + ]: + result = pc.equal(left, right) + assert result.equals(expected) + + +def test_arithmetic_add(): + left = pa.array([1, 2, 3, 4, 5]) + right = pa.array([0, -1, 1, 2, 3]) + result = pc.add(left, right) + expected = pa.array([1, 1, 4, 6, 8]) + assert result.equals(expected) + + +def test_arithmetic_subtract(): + left = pa.array([1, 2, 3, 4, 5]) + right = pa.array([0, -1, 1, 2, 3]) + result = pc.subtract(left, right) + expected = pa.array([1, 3, 2, 2, 2]) + assert result.equals(expected) + + +def test_arithmetic_multiply(): + left = pa.array([1, 2, 3, 4, 5]) + right = pa.array([0, -1, 1, 2, 3]) + result = pc.multiply(left, right) + expected = pa.array([0, -2, 3, 8, 15]) + assert result.equals(expected) + + +@pytest.mark.parametrize("ty", ["round", "round_to_multiple"]) +def test_round_to_integer(ty): + if ty == "round": + round = pc.round + RoundOptions = partial(pc.RoundOptions, ndigits=0) + elif ty == "round_to_multiple": + round = pc.round_to_multiple + RoundOptions = partial(pc.RoundToMultipleOptions, multiple=1) + + values = [3.2, 3.5, 3.7, 4.5, -3.2, -3.5, -3.7, None] + rmode_and_expected = { + "down": [3, 3, 3, 4, -4, -4, -4, None], + "up": [4, 4, 4, 5, -3, -3, -3, None], + "towards_zero": [3, 3, 3, 4, -3, -3, -3, None], + "towards_infinity": [4, 4, 4, 5, -4, -4, -4, None], + "half_down": [3, 3, 4, 4, -3, -4, -4, None], + "half_up": [3, 4, 4, 5, -3, -3, -4, None], + "half_towards_zero": [3, 3, 4, 4, -3, -3, -4, None], + "half_towards_infinity": [3, 4, 4, 5, -3, -4, -4, None], + "half_to_even": [3, 4, 4, 4, -3, -4, -4, None], + "half_to_odd": [3, 3, 4, 5, -3, -3, -4, None], + } + for round_mode, expected in rmode_and_expected.items(): + options = RoundOptions(round_mode=round_mode) + result = round(values, options=options) + np.testing.assert_array_equal(result, pa.array(expected)) + + +def test_round(): + values = [320, 3.5, 3.075, 4.5, -3.212, -35.1234, -3.045, None] + ndigits_and_expected = { + -2: [300, 0, 0, 0, -0, -0, -0, None], + -1: [320, 0, 0, 0, -0, -40, -0, None], + 0: [320, 4, 3, 5, -3, -35, -3, None], + 1: [320, 3.5, 3.1, 4.5, -3.2, -35.1, -3, None], + 2: [320, 3.5, 3.08, 4.5, -3.21, -35.12, -3.05, None], + } + for ndigits, expected in ndigits_and_expected.items(): + options = pc.RoundOptions(ndigits, "half_towards_infinity") + result = pc.round(values, options=options) + np.testing.assert_allclose(result, pa.array(expected), equal_nan=True) + + +def test_round_to_multiple(): + values = [320, 3.5, 3.075, 4.5, -3.212, -35.1234, -3.045, None] + multiple_and_expected = { + 2: [320, 4, 4, 4, -4, -36, -4, None], + 0.05: [320, 3.5, 3.1, 4.5, -3.2, -35.1, -3.05, None], + 0.1: [320, 3.5, 3.1, 4.5, -3.2, -35.1, -3, None], + 10: [320, 0, 0, 0, -0, -40, -0, None], + 100: [300, 0, 0, 0, -0, -0, -0, None], + } + for multiple, expected in multiple_and_expected.items(): + options = pc.RoundToMultipleOptions(multiple, "half_towards_infinity") + result = pc.round_to_multiple(values, options=options) + np.testing.assert_allclose(result, pa.array(expected), equal_nan=True) + + with pytest.raises(pa.ArrowInvalid, match="multiple must be positive"): + pc.round_to_multiple(values, multiple=-2) + + +def test_is_null(): + arr = pa.array([1, 2, 3, None]) + result = arr.is_null() + expected = pa.array([False, False, False, True]) + assert result.equals(expected) + assert result.equals(pc.is_null(arr)) + result = arr.is_valid() + expected = pa.array([True, True, True, False]) + assert result.equals(expected) + assert result.equals(pc.is_valid(arr)) + + arr = pa.chunked_array([[1, 2], [3, None]]) + result = arr.is_null() + expected = pa.chunked_array([[False, False], [False, True]]) + assert result.equals(expected) + result = arr.is_valid() + expected = pa.chunked_array([[True, True], [True, False]]) + assert result.equals(expected) + + arr = pa.array([1, 2, 3, None, np.nan]) + result = arr.is_null() + expected = pa.array([False, False, False, True, False]) + assert result.equals(expected) + + result = arr.is_null(nan_is_null=True) + expected = pa.array([False, False, False, True, True]) + assert result.equals(expected) + + +def test_fill_null(): + arr = pa.array([1, 2, None, 4], type=pa.int8()) + fill_value = pa.array([5], type=pa.int8()) + with pytest.raises(pa.ArrowInvalid, + match="Array arguments must all be the same length"): + arr.fill_null(fill_value) + + arr = pa.array([None, None, None, None], type=pa.null()) + fill_value = pa.scalar(None, type=pa.null()) + result = arr.fill_null(fill_value) + expected = pa.array([None, None, None, None]) + assert result.equals(expected) + + arr = pa.array(['a', 'bb', None]) + result = arr.fill_null('ccc') + expected = pa.array(['a', 'bb', 'ccc']) + assert result.equals(expected) + + arr = pa.array([b'a', b'bb', None], type=pa.large_binary()) + result = arr.fill_null('ccc') + expected = pa.array([b'a', b'bb', b'ccc'], type=pa.large_binary()) + assert result.equals(expected) + + arr = pa.array(['a', 'bb', None]) + result = arr.fill_null(None) + expected = pa.array(['a', 'bb', None]) + assert result.equals(expected) + + +@pytest.mark.parametrize('arrow_type', numerical_arrow_types) +def test_fill_null_array(arrow_type): + arr = pa.array([1, 2, None, 4], type=arrow_type) + fill_value = pa.scalar(5, type=arrow_type) + result = arr.fill_null(fill_value) + expected = pa.array([1, 2, 5, 4], type=arrow_type) + assert result.equals(expected) + + # Implicit conversions + result = arr.fill_null(5) + assert result.equals(expected) + + # ARROW-9451: Unsigned integers allow this for some reason + if not pa.types.is_unsigned_integer(arr.type): + with pytest.raises((ValueError, TypeError)): + arr.fill_null('5') + + result = arr.fill_null(pa.scalar(5, type='int8')) + assert result.equals(expected) + + +@pytest.mark.parametrize('arrow_type', numerical_arrow_types) +def test_fill_null_chunked_array(arrow_type): + fill_value = pa.scalar(5, type=arrow_type) + arr = pa.chunked_array([pa.array([None, 2, 3, 4], type=arrow_type)]) + result = arr.fill_null(fill_value) + expected = pa.chunked_array([pa.array([5, 2, 3, 4], type=arrow_type)]) + assert result.equals(expected) + + arr = pa.chunked_array([ + pa.array([1, 2], type=arrow_type), + pa.array([], type=arrow_type), + pa.array([None, 4], type=arrow_type) + ]) + expected = pa.chunked_array([ + pa.array([1, 2], type=arrow_type), + pa.array([], type=arrow_type), + pa.array([5, 4], type=arrow_type) + ]) + result = arr.fill_null(fill_value) + assert result.equals(expected) + + # Implicit conversions + result = arr.fill_null(5) + assert result.equals(expected) + + result = arr.fill_null(pa.scalar(5, type='int8')) + assert result.equals(expected) + + +def test_logical(): + a = pa.array([True, False, False, None]) + b = pa.array([True, True, False, True]) + + assert pc.and_(a, b) == pa.array([True, False, False, None]) + assert pc.and_kleene(a, b) == pa.array([True, False, False, None]) + + assert pc.or_(a, b) == pa.array([True, True, False, None]) + assert pc.or_kleene(a, b) == pa.array([True, True, False, True]) + + assert pc.xor(a, b) == pa.array([False, True, False, None]) + + assert pc.invert(a) == pa.array([False, True, True, None]) + + +def test_cast(): + arr = pa.array([2 ** 63 - 1], type='int64') + + with pytest.raises(pa.ArrowInvalid): + pc.cast(arr, 'int32') + + assert pc.cast(arr, 'int32', safe=False) == pa.array([-1], type='int32') + + arr = pa.array([datetime(2010, 1, 1), datetime(2015, 1, 1)]) + expected = pa.array([1262304000000, 1420070400000], type='timestamp[ms]') + assert pc.cast(arr, 'timestamp[ms]') == expected + + arr = pa.array([[1, 2], [3, 4, 5]], type=pa.large_list(pa.int8())) + expected = pa.array([["1", "2"], ["3", "4", "5"]], + type=pa.list_(pa.utf8())) + assert pc.cast(arr, expected.type) == expected + + +def test_strptime(): + arr = pa.array(["5/1/2020", None, "12/13/1900"]) + + got = pc.strptime(arr, format='%m/%d/%Y', unit='s') + expected = pa.array([datetime(2020, 5, 1), None, datetime(1900, 12, 13)], + type=pa.timestamp('s')) + assert got == expected + + +# TODO: We should test on windows once ARROW-13168 is resolved. +@pytest.mark.pandas +@pytest.mark.skipif(sys.platform == 'win32', + reason="Timezone database is not available on Windows yet") +def test_strftime(): + from pyarrow.vendored.version import Version + + def _fix_timestamp(s): + if Version(pd.__version__) < Version("1.0.0"): + return s.to_series().replace("NaT", pd.NaT) + else: + return s + + times = ["2018-03-10 09:00", "2038-01-31 12:23", None] + timezones = ["CET", "UTC", "Europe/Ljubljana"] + + formats = ["%a", "%A", "%w", "%d", "%b", "%B", "%m", "%y", "%Y", "%H", + "%I", "%p", "%M", "%z", "%Z", "%j", "%U", "%W", "%c", "%x", + "%X", "%%", "%G", "%V", "%u"] + + for timezone in timezones: + ts = pd.to_datetime(times).tz_localize(timezone) + for unit in ["s", "ms", "us", "ns"]: + tsa = pa.array(ts, type=pa.timestamp(unit, timezone)) + for fmt in formats: + options = pc.StrftimeOptions(fmt) + result = pc.strftime(tsa, options=options) + expected = pa.array(_fix_timestamp(ts.strftime(fmt))) + assert result.equals(expected) + + fmt = "%Y-%m-%dT%H:%M:%S" + + # Default format + tsa = pa.array(ts, type=pa.timestamp("s", timezone)) + result = pc.strftime(tsa, options=pc.StrftimeOptions()) + expected = pa.array(_fix_timestamp(ts.strftime(fmt))) + assert result.equals(expected) + + # Default format plus timezone + tsa = pa.array(ts, type=pa.timestamp("s", timezone)) + result = pc.strftime(tsa, options=pc.StrftimeOptions(fmt + "%Z")) + expected = pa.array(_fix_timestamp(ts.strftime(fmt + "%Z"))) + assert result.equals(expected) + + # Pandas %S is equivalent to %S in arrow for unit="s" + tsa = pa.array(ts, type=pa.timestamp("s", timezone)) + options = pc.StrftimeOptions("%S") + result = pc.strftime(tsa, options=options) + expected = pa.array(_fix_timestamp(ts.strftime("%S"))) + assert result.equals(expected) + + # Pandas %S.%f is equivalent to %S in arrow for unit="us" + tsa = pa.array(ts, type=pa.timestamp("us", timezone)) + options = pc.StrftimeOptions("%S") + result = pc.strftime(tsa, options=options) + expected = pa.array(_fix_timestamp(ts.strftime("%S.%f"))) + assert result.equals(expected) + + # Test setting locale + tsa = pa.array(ts, type=pa.timestamp("s", timezone)) + options = pc.StrftimeOptions(fmt, locale="C") + result = pc.strftime(tsa, options=options) + expected = pa.array(_fix_timestamp(ts.strftime(fmt))) + assert result.equals(expected) + + # Test timestamps without timezone + fmt = "%Y-%m-%dT%H:%M:%S" + ts = pd.to_datetime(times) + tsa = pa.array(ts, type=pa.timestamp("s")) + result = pc.strftime(tsa, options=pc.StrftimeOptions(fmt)) + expected = pa.array(_fix_timestamp(ts.strftime(fmt))) + + assert result.equals(expected) + with pytest.raises(pa.ArrowInvalid, + match="Timezone not present, cannot convert to string"): + pc.strftime(tsa, options=pc.StrftimeOptions(fmt + "%Z")) + with pytest.raises(pa.ArrowInvalid, + match="Timezone not present, cannot convert to string"): + pc.strftime(tsa, options=pc.StrftimeOptions(fmt + "%z")) + + +def _check_datetime_components(timestamps, timezone=None): + from pyarrow.vendored.version import Version + + ts = pd.to_datetime(timestamps).tz_localize( + "UTC").tz_convert(timezone).to_series() + tsa = pa.array(ts, pa.timestamp("ns", tz=timezone)) + + subseconds = ((ts.dt.microsecond * 10 ** 3 + + ts.dt.nanosecond) * 10 ** -9).round(9) + iso_calendar_fields = [ + pa.field('iso_year', pa.int64()), + pa.field('iso_week', pa.int64()), + pa.field('iso_day_of_week', pa.int64()) + ] + + if Version(pd.__version__) < Version("1.1.0"): + # https://github.com/pandas-dev/pandas/issues/33206 + iso_year = ts.map(lambda x: x.isocalendar()[0]).astype("int64") + iso_week = ts.map(lambda x: x.isocalendar()[1]).astype("int64") + iso_day = ts.map(lambda x: x.isocalendar()[2]).astype("int64") + else: + # Casting is required because pandas isocalendar returns int32 + # while arrow isocalendar returns int64. + iso_year = ts.dt.isocalendar()["year"].astype("int64") + iso_week = ts.dt.isocalendar()["week"].astype("int64") + iso_day = ts.dt.isocalendar()["day"].astype("int64") + + iso_calendar = pa.StructArray.from_arrays( + [iso_year, iso_week, iso_day], + fields=iso_calendar_fields) + + assert pc.year(tsa).equals(pa.array(ts.dt.year)) + assert pc.month(tsa).equals(pa.array(ts.dt.month)) + assert pc.day(tsa).equals(pa.array(ts.dt.day)) + assert pc.day_of_week(tsa).equals(pa.array(ts.dt.dayofweek)) + assert pc.day_of_year(tsa).equals(pa.array(ts.dt.dayofyear)) + assert pc.iso_year(tsa).equals(pa.array(iso_year)) + assert pc.iso_week(tsa).equals(pa.array(iso_week)) + assert pc.iso_calendar(tsa).equals(iso_calendar) + assert pc.quarter(tsa).equals(pa.array(ts.dt.quarter)) + assert pc.hour(tsa).equals(pa.array(ts.dt.hour)) + assert pc.minute(tsa).equals(pa.array(ts.dt.minute)) + assert pc.second(tsa).equals(pa.array(ts.dt.second.values)) + assert pc.millisecond(tsa).equals(pa.array(ts.dt.microsecond // 10 ** 3)) + assert pc.microsecond(tsa).equals(pa.array(ts.dt.microsecond % 10 ** 3)) + assert pc.nanosecond(tsa).equals(pa.array(ts.dt.nanosecond)) + assert pc.subsecond(tsa).equals(pa.array(subseconds)) + + day_of_week_options = pc.DayOfWeekOptions( + count_from_zero=False, week_start=1) + assert pc.day_of_week(tsa, options=day_of_week_options).equals( + pa.array(ts.dt.dayofweek + 1)) + + week_options = pc.WeekOptions( + week_starts_monday=True, count_from_zero=False, + first_week_is_fully_in_year=False) + assert pc.week(tsa, options=week_options).equals(pa.array(iso_week)) + + +@pytest.mark.pandas +def test_extract_datetime_components(): + from pyarrow.vendored.version import Version + + timestamps = ["1970-01-01T00:00:59.123456789", + "2000-02-29T23:23:23.999999999", + "2033-05-18T03:33:20.000000000", + "2020-01-01T01:05:05.001", + "2019-12-31T02:10:10.002", + "2019-12-30T03:15:15.003", + "2009-12-31T04:20:20.004132", + "2010-01-01T05:25:25.005321", + "2010-01-03T06:30:30.006163", + "2010-01-04T07:35:35", + "2006-01-01T08:40:40", + "2005-12-31T09:45:45", + "2008-12-28", + "2008-12-29", + "2012-01-01 01:02:03"] + timezones = ["UTC", "US/Central", "Asia/Kolkata", + "Etc/GMT-4", "Etc/GMT+4", "Australia/Broken_Hill"] + + # Test timezone naive timestamp array + _check_datetime_components(timestamps) + + # Test timezone aware timestamp array + if sys.platform == 'win32': + # TODO: We should test on windows once ARROW-13168 is resolved. + pytest.skip('Timezone database is not available on Windows yet') + elif Version(pd.__version__) < Version('1.0.0'): + pytest.skip('Pandas < 1.0 extracts time components incorrectly.') + else: + for timezone in timezones: + _check_datetime_components(timestamps, timezone) + + +# TODO: We should test on windows once ARROW-13168 is resolved. +@pytest.mark.pandas +@pytest.mark.skipif(sys.platform == 'win32', + reason="Timezone database is not available on Windows yet") +def test_assume_timezone(): + from pyarrow.vendored.version import Version + + ts_type = pa.timestamp("ns") + timestamps = pd.to_datetime(["1970-01-01T00:00:59.123456789", + "2000-02-29T23:23:23.999999999", + "2033-05-18T03:33:20.000000000", + "2020-01-01T01:05:05.001", + "2019-12-31T02:10:10.002", + "2019-12-30T03:15:15.003", + "2009-12-31T04:20:20.004132", + "2010-01-01T05:25:25.005321", + "2010-01-03T06:30:30.006163", + "2010-01-04T07:35:35", + "2006-01-01T08:40:40", + "2005-12-31T09:45:45", + "2008-12-28", + "2008-12-29", + "2012-01-01 01:02:03"]) + nonexistent = pd.to_datetime(["2015-03-29 02:30:00", + "2015-03-29 03:30:00"]) + ambiguous = pd.to_datetime(["2018-10-28 01:20:00", + "2018-10-28 02:36:00", + "2018-10-28 03:46:00"]) + ambiguous_array = pa.array(ambiguous, type=ts_type) + nonexistent_array = pa.array(nonexistent, type=ts_type) + + for timezone in ["UTC", "US/Central", "Asia/Kolkata"]: + options = pc.AssumeTimezoneOptions(timezone) + ta = pa.array(timestamps, type=ts_type) + expected = timestamps.tz_localize(timezone) + result = pc.assume_timezone(ta, options=options) + assert result.equals(pa.array(expected)) + + ta_zoned = pa.array(timestamps, type=pa.timestamp("ns", timezone)) + with pytest.raises(pa.ArrowInvalid, match="already have a timezone:"): + pc.assume_timezone(ta_zoned, options=options) + + invalid_options = pc.AssumeTimezoneOptions("Europe/Brusselsss") + with pytest.raises(ValueError, match="not found in timezone database"): + pc.assume_timezone(ta, options=invalid_options) + + timezone = "Europe/Brussels" + + # nonexistent parameter was introduced in Pandas 0.24.0 + if Version(pd.__version__) >= Version("0.24.0"): + options_nonexistent_raise = pc.AssumeTimezoneOptions(timezone) + options_nonexistent_earliest = pc.AssumeTimezoneOptions( + timezone, ambiguous="raise", nonexistent="earliest") + options_nonexistent_latest = pc.AssumeTimezoneOptions( + timezone, ambiguous="raise", nonexistent="latest") + + with pytest.raises(ValueError, + match="Timestamp doesn't exist in " + f"timezone '{timezone}'"): + pc.assume_timezone(nonexistent_array, + options=options_nonexistent_raise) + + expected = pa.array(nonexistent.tz_localize( + timezone, nonexistent="shift_forward")) + result = pc.assume_timezone( + nonexistent_array, options=options_nonexistent_latest) + expected.equals(result) + + expected = pa.array(nonexistent.tz_localize( + timezone, nonexistent="shift_backward")) + result = pc.assume_timezone( + nonexistent_array, options=options_nonexistent_earliest) + expected.equals(result) + + options_ambiguous_raise = pc.AssumeTimezoneOptions(timezone) + options_ambiguous_latest = pc.AssumeTimezoneOptions( + timezone, ambiguous="latest", nonexistent="raise") + options_ambiguous_earliest = pc.AssumeTimezoneOptions( + timezone, ambiguous="earliest", nonexistent="raise") + + with pytest.raises(ValueError, + match="Timestamp is ambiguous in " + f"timezone '{timezone}'"): + pc.assume_timezone(ambiguous_array, options=options_ambiguous_raise) + + expected = ambiguous.tz_localize(timezone, ambiguous=[True, True, True]) + result = pc.assume_timezone( + ambiguous_array, options=options_ambiguous_earliest) + result.equals(pa.array(expected)) + + expected = ambiguous.tz_localize(timezone, ambiguous=[False, False, False]) + result = pc.assume_timezone( + ambiguous_array, options=options_ambiguous_latest) + result.equals(pa.array(expected)) + + +def test_count(): + arr = pa.array([1, 2, 3, None, None]) + assert pc.count(arr).as_py() == 3 + assert pc.count(arr, mode='only_valid').as_py() == 3 + assert pc.count(arr, mode='only_null').as_py() == 2 + assert pc.count(arr, mode='all').as_py() == 5 + + +def test_index(): + arr = pa.array([0, 1, None, 3, 4], type=pa.int64()) + assert pc.index(arr, pa.scalar(0)).as_py() == 0 + assert pc.index(arr, pa.scalar(2, type=pa.int8())).as_py() == -1 + assert pc.index(arr, 4).as_py() == 4 + assert arr.index(3, start=2).as_py() == 3 + assert arr.index(None).as_py() == -1 + + arr = pa.chunked_array([[1, 2], [1, 3]], type=pa.int64()) + assert arr.index(1).as_py() == 0 + assert arr.index(1, start=2).as_py() == 2 + assert arr.index(1, start=1, end=2).as_py() == -1 + + +def check_partition_nth(data, indices, pivot, null_placement): + indices = indices.to_pylist() + assert len(indices) == len(data) + assert sorted(indices) == list(range(len(data))) + until_pivot = [data[indices[i]] for i in range(pivot)] + after_pivot = [data[indices[i]] for i in range(pivot, len(data))] + p = data[indices[pivot]] + if p is None: + if null_placement == "at_start": + assert all(v is None for v in until_pivot) + else: + assert all(v is None for v in after_pivot) + else: + if null_placement == "at_start": + assert all(v is None or v <= p for v in until_pivot) + assert all(v >= p for v in after_pivot) + else: + assert all(v <= p for v in until_pivot) + assert all(v is None or v >= p for v in after_pivot) + + +def test_partition_nth(): + data = list(range(100, 140)) + random.shuffle(data) + pivot = 10 + indices = pc.partition_nth_indices(data, pivot=pivot) + check_partition_nth(data, indices, pivot, "at_end") + + +def test_partition_nth_null_placement(): + data = list(range(10)) + [None] * 10 + random.shuffle(data) + + for pivot in (0, 7, 13, 19): + for null_placement in ("at_start", "at_end"): + indices = pc.partition_nth_indices(data, pivot=pivot, + null_placement=null_placement) + check_partition_nth(data, indices, pivot, null_placement) + + +def test_select_k_array(): + def validate_select_k(select_k_indices, arr, order, stable_sort=False): + sorted_indices = pc.sort_indices(arr, sort_keys=[("dummy", order)]) + head_k_indices = sorted_indices.slice(0, len(select_k_indices)) + if stable_sort: + assert select_k_indices == head_k_indices + else: + expected = pc.take(arr, head_k_indices) + actual = pc.take(arr, select_k_indices) + assert actual == expected + + arr = pa.array([1, 2, None, 0]) + for k in [0, 2, 4]: + for order in ["descending", "ascending"]: + result = pc.select_k_unstable( + arr, k=k, sort_keys=[("dummy", order)]) + validate_select_k(result, arr, order) + + result = pc.top_k_unstable(arr, k=k) + validate_select_k(result, arr, "descending") + + result = pc.bottom_k_unstable(arr, k=k) + validate_select_k(result, arr, "ascending") + + result = pc.select_k_unstable( + arr, options=pc.SelectKOptions( + k=2, sort_keys=[("dummy", "descending")]) + ) + validate_select_k(result, arr, "descending") + + result = pc.select_k_unstable( + arr, options=pc.SelectKOptions(k=2, sort_keys=[("dummy", "ascending")]) + ) + validate_select_k(result, arr, "ascending") + + +def test_select_k_table(): + def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): + sorted_indices = pc.sort_indices(tbl, sort_keys=sort_keys) + head_k_indices = sorted_indices.slice(0, len(select_k_indices)) + if stable_sort: + assert select_k_indices == head_k_indices + else: + expected = pc.take(tbl, head_k_indices) + actual = pc.take(tbl, select_k_indices) + assert actual == expected + + table = pa.table({"a": [1, 2, 0], "b": [1, 0, 1]}) + for k in [0, 2, 4]: + result = pc.select_k_unstable( + table, k=k, sort_keys=[("a", "ascending")]) + validate_select_k(result, table, sort_keys=[("a", "ascending")]) + + result = pc.select_k_unstable( + table, k=k, sort_keys=[("a", "ascending"), ("b", "ascending")]) + validate_select_k( + result, table, sort_keys=[("a", "ascending"), ("b", "ascending")]) + + result = pc.top_k_unstable(table, k=k, sort_keys=["a"]) + validate_select_k(result, table, sort_keys=[("a", "descending")]) + + result = pc.bottom_k_unstable(table, k=k, sort_keys=["a", "b"]) + validate_select_k( + result, table, sort_keys=[("a", "ascending"), ("b", "ascending")]) + + with pytest.raises(ValueError, + match="select_k_unstable requires a nonnegative `k`"): + pc.select_k_unstable(table) + + with pytest.raises(ValueError, + match="select_k_unstable requires a " + "non-empty `sort_keys`"): + pc.select_k_unstable(table, k=2, sort_keys=[]) + + with pytest.raises(ValueError, match="not a valid sort order"): + pc.select_k_unstable(table, k=k, sort_keys=[("a", "nonscending")]) + + with pytest.raises(ValueError, match="Nonexistent sort key column"): + pc.select_k_unstable(table, k=k, sort_keys=[("unknown", "ascending")]) + + +def test_array_sort_indices(): + arr = pa.array([1, 2, None, 0]) + result = pc.array_sort_indices(arr) + assert result.to_pylist() == [3, 0, 1, 2] + result = pc.array_sort_indices(arr, order="ascending") + assert result.to_pylist() == [3, 0, 1, 2] + result = pc.array_sort_indices(arr, order="descending") + assert result.to_pylist() == [1, 0, 3, 2] + result = pc.array_sort_indices(arr, order="descending", + null_placement="at_start") + assert result.to_pylist() == [2, 1, 0, 3] + + with pytest.raises(ValueError, match="not a valid sort order"): + pc.array_sort_indices(arr, order="nonscending") + + +def test_sort_indices_array(): + arr = pa.array([1, 2, None, 0]) + result = pc.sort_indices(arr) + assert result.to_pylist() == [3, 0, 1, 2] + result = pc.sort_indices(arr, sort_keys=[("dummy", "ascending")]) + assert result.to_pylist() == [3, 0, 1, 2] + result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")]) + assert result.to_pylist() == [1, 0, 3, 2] + result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")], + null_placement="at_start") + assert result.to_pylist() == [2, 1, 0, 3] + # Using SortOptions + result = pc.sort_indices( + arr, options=pc.SortOptions(sort_keys=[("dummy", "descending")]) + ) + assert result.to_pylist() == [1, 0, 3, 2] + result = pc.sort_indices( + arr, options=pc.SortOptions(sort_keys=[("dummy", "descending")], + null_placement="at_start") + ) + assert result.to_pylist() == [2, 1, 0, 3] + + +def test_sort_indices_table(): + table = pa.table({"a": [1, 1, None, 0], "b": [1, 0, 0, 1]}) + + result = pc.sort_indices(table, sort_keys=[("a", "ascending")]) + assert result.to_pylist() == [3, 0, 1, 2] + result = pc.sort_indices(table, sort_keys=[("a", "ascending")], + null_placement="at_start") + assert result.to_pylist() == [2, 3, 0, 1] + + result = pc.sort_indices( + table, sort_keys=[("a", "descending"), ("b", "ascending")] + ) + assert result.to_pylist() == [1, 0, 3, 2] + result = pc.sort_indices( + table, sort_keys=[("a", "descending"), ("b", "ascending")], + null_placement="at_start" + ) + assert result.to_pylist() == [2, 1, 0, 3] + + with pytest.raises(ValueError, match="Must specify one or more sort keys"): + pc.sort_indices(table) + + with pytest.raises(ValueError, match="Nonexistent sort key column"): + pc.sort_indices(table, sort_keys=[("unknown", "ascending")]) + + with pytest.raises(ValueError, match="not a valid sort order"): + pc.sort_indices(table, sort_keys=[("a", "nonscending")]) + + +def test_is_in(): + arr = pa.array([1, 2, None, 1, 2, 3]) + + result = pc.is_in(arr, value_set=pa.array([1, 3, None])) + assert result.to_pylist() == [True, False, True, True, False, True] + + result = pc.is_in(arr, value_set=pa.array([1, 3, None]), skip_nulls=True) + assert result.to_pylist() == [True, False, False, True, False, True] + + result = pc.is_in(arr, value_set=pa.array([1, 3])) + assert result.to_pylist() == [True, False, False, True, False, True] + + result = pc.is_in(arr, value_set=pa.array([1, 3]), skip_nulls=True) + assert result.to_pylist() == [True, False, False, True, False, True] + + +def test_index_in(): + arr = pa.array([1, 2, None, 1, 2, 3]) + + result = pc.index_in(arr, value_set=pa.array([1, 3, None])) + assert result.to_pylist() == [0, None, 2, 0, None, 1] + + result = pc.index_in(arr, value_set=pa.array([1, 3, None]), + skip_nulls=True) + assert result.to_pylist() == [0, None, None, 0, None, 1] + + result = pc.index_in(arr, value_set=pa.array([1, 3])) + assert result.to_pylist() == [0, None, None, 0, None, 1] + + result = pc.index_in(arr, value_set=pa.array([1, 3]), skip_nulls=True) + assert result.to_pylist() == [0, None, None, 0, None, 1] + + +def test_quantile(): + arr = pa.array([1, 2, 3, 4]) + + result = pc.quantile(arr) + assert result.to_pylist() == [2.5] + + result = pc.quantile(arr, interpolation='lower') + assert result.to_pylist() == [2] + result = pc.quantile(arr, interpolation='higher') + assert result.to_pylist() == [3] + result = pc.quantile(arr, interpolation='nearest') + assert result.to_pylist() == [3] + result = pc.quantile(arr, interpolation='midpoint') + assert result.to_pylist() == [2.5] + result = pc.quantile(arr, interpolation='linear') + assert result.to_pylist() == [2.5] + + arr = pa.array([1, 2]) + + result = pc.quantile(arr, q=[0.25, 0.5, 0.75]) + assert result.to_pylist() == [1.25, 1.5, 1.75] + + result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='lower') + assert result.to_pylist() == [1, 1, 1] + result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='higher') + assert result.to_pylist() == [2, 2, 2] + result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='midpoint') + assert result.to_pylist() == [1.5, 1.5, 1.5] + result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='nearest') + assert result.to_pylist() == [1, 1, 2] + result = pc.quantile(arr, q=[0.25, 0.5, 0.75], interpolation='linear') + assert result.to_pylist() == [1.25, 1.5, 1.75] + + with pytest.raises(ValueError, match="Quantile must be between 0 and 1"): + pc.quantile(arr, q=1.1) + with pytest.raises(ValueError, match="not a valid quantile interpolation"): + pc.quantile(arr, interpolation='zzz') + + +def test_tdigest(): + arr = pa.array([1, 2, 3, 4]) + result = pc.tdigest(arr) + assert result.to_pylist() == [2.5] + + arr = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]) + result = pc.tdigest(arr) + assert result.to_pylist() == [2.5] + + arr = pa.array([1, 2, 3, 4]) + result = pc.tdigest(arr, q=[0, 0.5, 1]) + assert result.to_pylist() == [1, 2.5, 4] + + arr = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]) + result = pc.tdigest(arr, q=[0, 0.5, 1]) + assert result.to_pylist() == [1, 2.5, 4] + + +def test_fill_null_segfault(): + # ARROW-12672 + arr = pa.array([None], pa.bool_()).fill_null(False) + result = arr.cast(pa.int8()) + assert result == pa.array([0], pa.int8()) + + +def test_min_max_element_wise(): + arr1 = pa.array([1, 2, 3]) + arr2 = pa.array([3, 1, 2]) + arr3 = pa.array([2, 3, None]) + + result = pc.max_element_wise(arr1, arr2) + assert result == pa.array([3, 2, 3]) + result = pc.min_element_wise(arr1, arr2) + assert result == pa.array([1, 1, 2]) + + result = pc.max_element_wise(arr1, arr2, arr3) + assert result == pa.array([3, 3, 3]) + result = pc.min_element_wise(arr1, arr2, arr3) + assert result == pa.array([1, 1, 2]) + + # with specifying the option + result = pc.max_element_wise(arr1, arr3, skip_nulls=True) + assert result == pa.array([2, 3, 3]) + result = pc.min_element_wise(arr1, arr3, skip_nulls=True) + assert result == pa.array([1, 2, 3]) + result = pc.max_element_wise( + arr1, arr3, options=pc.ElementWiseAggregateOptions()) + assert result == pa.array([2, 3, 3]) + result = pc.min_element_wise( + arr1, arr3, options=pc.ElementWiseAggregateOptions()) + assert result == pa.array([1, 2, 3]) + + # not skipping nulls + result = pc.max_element_wise(arr1, arr3, skip_nulls=False) + assert result == pa.array([2, 3, None]) + result = pc.min_element_wise(arr1, arr3, skip_nulls=False) + assert result == pa.array([1, 2, None]) + + +def test_make_struct(): + assert pc.make_struct(1, 'a').as_py() == {'0': 1, '1': 'a'} + + assert pc.make_struct(1, 'a', field_names=['i', 's']).as_py() == { + 'i': 1, 's': 'a'} + + assert pc.make_struct([1, 2, 3], + "a b c".split()) == pa.StructArray.from_arrays([ + [1, 2, 3], + "a b c".split()], names='0 1'.split()) + + with pytest.raises(ValueError, + match="Array arguments must all be the same length"): + pc.make_struct([1, 2, 3, 4], "a b c".split()) + + with pytest.raises(ValueError, match="0 arguments but 2 field names"): + pc.make_struct(field_names=['one', 'two']) + + +def test_case_when(): + assert pc.case_when(pc.make_struct([True, False, None], + [False, True, None]), + [1, 2, 3], + [11, 12, 13]) == pa.array([1, 12, None]) + + +def test_list_element(): + element_type = pa.struct([('a', pa.float64()), ('b', pa.int8())]) + list_type = pa.list_(element_type) + l1 = [{'a': .4, 'b': 2}, None, {'a': .2, 'b': 4}, None, {'a': 5.6, 'b': 6}] + l2 = [None, {'a': .52, 'b': 3}, {'a': .7, 'b': 4}, None, {'a': .6, 'b': 8}] + lists = pa.array([l1, l2], list_type) + + index = 1 + result = pa.compute.list_element(lists, index) + expected = pa.array([None, {'a': 0.52, 'b': 3}], element_type) + assert result.equals(expected) + + index = 4 + result = pa.compute.list_element(lists, index) + expected = pa.array([{'a': 5.6, 'b': 6}, {'a': .6, 'b': 8}], element_type) + assert result.equals(expected) + + +def test_count_distinct(): + seed = datetime.now() + samples = [seed.replace(year=y) for y in range(1992, 2092)] + arr = pa.array(samples, pa.timestamp("ns")) + result = pa.compute.count_distinct(arr) + expected = pa.scalar(len(samples), type=pa.int64()) + assert result.equals(expected) + + +def test_count_distinct_options(): + arr = pa.array([1, 2, 3, None, None]) + assert pc.count_distinct(arr).as_py() == 3 + assert pc.count_distinct(arr, mode='only_valid').as_py() == 3 + assert pc.count_distinct(arr, mode='only_null').as_py() == 1 + assert pc.count_distinct(arr, mode='all').as_py() == 4 diff --git a/src/arrow/python/pyarrow/tests/test_convert_builtin.py b/src/arrow/python/pyarrow/tests/test_convert_builtin.py new file mode 100644 index 000000000..7a355390a --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_convert_builtin.py @@ -0,0 +1,2309 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import collections +import datetime +import decimal +import itertools +import math +import re + +import hypothesis as h +import numpy as np +import pytz +import pytest + +from pyarrow.pandas_compat import _pandas_api # noqa +import pyarrow as pa +import pyarrow.tests.strategies as past + + +int_type_pairs = [ + (np.int8, pa.int8()), + (np.int16, pa.int16()), + (np.int32, pa.int32()), + (np.int64, pa.int64()), + (np.uint8, pa.uint8()), + (np.uint16, pa.uint16()), + (np.uint32, pa.uint32()), + (np.uint64, pa.uint64())] + + +np_int_types, pa_int_types = zip(*int_type_pairs) + + +class StrangeIterable: + def __init__(self, lst): + self.lst = lst + + def __iter__(self): + return self.lst.__iter__() + + +class MyInt: + def __init__(self, value): + self.value = value + + def __int__(self): + return self.value + + +class MyBrokenInt: + def __int__(self): + 1/0 # MARKER + + +def check_struct_type(ty, expected): + """ + Check a struct type is as expected, but not taking order into account. + """ + assert pa.types.is_struct(ty) + assert set(ty) == set(expected) + + +def test_iterable_types(): + arr1 = pa.array(StrangeIterable([0, 1, 2, 3])) + arr2 = pa.array((0, 1, 2, 3)) + + assert arr1.equals(arr2) + + +def test_empty_iterable(): + arr = pa.array(StrangeIterable([])) + assert len(arr) == 0 + assert arr.null_count == 0 + assert arr.type == pa.null() + assert arr.to_pylist() == [] + + +def test_limited_iterator_types(): + arr1 = pa.array(iter(range(3)), type=pa.int64(), size=3) + arr2 = pa.array((0, 1, 2)) + assert arr1.equals(arr2) + + +def test_limited_iterator_size_overflow(): + arr1 = pa.array(iter(range(3)), type=pa.int64(), size=2) + arr2 = pa.array((0, 1)) + assert arr1.equals(arr2) + + +def test_limited_iterator_size_underflow(): + arr1 = pa.array(iter(range(3)), type=pa.int64(), size=10) + arr2 = pa.array((0, 1, 2)) + assert arr1.equals(arr2) + + +def test_iterator_without_size(): + expected = pa.array((0, 1, 2)) + arr1 = pa.array(iter(range(3))) + assert arr1.equals(expected) + # Same with explicit type + arr1 = pa.array(iter(range(3)), type=pa.int64()) + assert arr1.equals(expected) + + +def test_infinite_iterator(): + expected = pa.array((0, 1, 2)) + arr1 = pa.array(itertools.count(0), size=3) + assert arr1.equals(expected) + # Same with explicit type + arr1 = pa.array(itertools.count(0), type=pa.int64(), size=3) + assert arr1.equals(expected) + + +def _as_list(xs): + return xs + + +def _as_tuple(xs): + return tuple(xs) + + +def _as_deque(xs): + # deque is a sequence while neither tuple nor list + return collections.deque(xs) + + +def _as_dict_values(xs): + # a dict values object is not a sequence, just a regular iterable + dct = {k: v for k, v in enumerate(xs)} + return dct.values() + + +def _as_numpy_array(xs): + arr = np.empty(len(xs), dtype=object) + arr[:] = xs + return arr + + +def _as_set(xs): + return set(xs) + + +SEQUENCE_TYPES = [_as_list, _as_tuple, _as_numpy_array] +ITERABLE_TYPES = [_as_set, _as_dict_values] + SEQUENCE_TYPES +COLLECTIONS_TYPES = [_as_deque] + ITERABLE_TYPES + +parametrize_with_iterable_types = pytest.mark.parametrize( + "seq", ITERABLE_TYPES +) + +parametrize_with_sequence_types = pytest.mark.parametrize( + "seq", SEQUENCE_TYPES +) + +parametrize_with_collections_types = pytest.mark.parametrize( + "seq", COLLECTIONS_TYPES +) + + +@parametrize_with_collections_types +def test_sequence_types(seq): + arr1 = pa.array(seq([1, 2, 3])) + arr2 = pa.array([1, 2, 3]) + + assert arr1.equals(arr2) + + +@parametrize_with_iterable_types +def test_nested_sequence_types(seq): + arr1 = pa.array([seq([1, 2, 3])]) + arr2 = pa.array([[1, 2, 3]]) + + assert arr1.equals(arr2) + + +@parametrize_with_sequence_types +def test_sequence_boolean(seq): + expected = [True, None, False, None] + arr = pa.array(seq(expected)) + assert len(arr) == 4 + assert arr.null_count == 2 + assert arr.type == pa.bool_() + assert arr.to_pylist() == expected + + +@parametrize_with_sequence_types +def test_sequence_numpy_boolean(seq): + expected = [np.bool_(True), None, np.bool_(False), None] + arr = pa.array(seq(expected)) + assert arr.type == pa.bool_() + assert arr.to_pylist() == [True, None, False, None] + + +@parametrize_with_sequence_types +def test_sequence_mixed_numpy_python_bools(seq): + values = np.array([True, False]) + arr = pa.array(seq([values[0], None, values[1], True, False])) + assert arr.type == pa.bool_() + assert arr.to_pylist() == [True, None, False, True, False] + + +@parametrize_with_collections_types +def test_empty_list(seq): + arr = pa.array(seq([])) + assert len(arr) == 0 + assert arr.null_count == 0 + assert arr.type == pa.null() + assert arr.to_pylist() == [] + + +@parametrize_with_sequence_types +def test_nested_lists(seq): + data = [[], [1, 2], None] + arr = pa.array(seq(data)) + assert len(arr) == 3 + assert arr.null_count == 1 + assert arr.type == pa.list_(pa.int64()) + assert arr.to_pylist() == data + # With explicit type + arr = pa.array(seq(data), type=pa.list_(pa.int32())) + assert len(arr) == 3 + assert arr.null_count == 1 + assert arr.type == pa.list_(pa.int32()) + assert arr.to_pylist() == data + + +@parametrize_with_sequence_types +def test_nested_large_lists(seq): + data = [[], [1, 2], None] + arr = pa.array(seq(data), type=pa.large_list(pa.int16())) + assert len(arr) == 3 + assert arr.null_count == 1 + assert arr.type == pa.large_list(pa.int16()) + assert arr.to_pylist() == data + + +@parametrize_with_collections_types +def test_list_with_non_list(seq): + # List types don't accept non-sequences + with pytest.raises(TypeError): + pa.array(seq([[], [1, 2], 3]), type=pa.list_(pa.int64())) + with pytest.raises(TypeError): + pa.array(seq([[], [1, 2], 3]), type=pa.large_list(pa.int64())) + + +@parametrize_with_sequence_types +def test_nested_arrays(seq): + arr = pa.array(seq([np.array([], dtype=np.int64), + np.array([1, 2], dtype=np.int64), None])) + assert len(arr) == 3 + assert arr.null_count == 1 + assert arr.type == pa.list_(pa.int64()) + assert arr.to_pylist() == [[], [1, 2], None] + + +@parametrize_with_sequence_types +def test_nested_fixed_size_list(seq): + # sequence of lists + data = [[1, 2], [3, None], None] + arr = pa.array(seq(data), type=pa.list_(pa.int64(), 2)) + assert len(arr) == 3 + assert arr.null_count == 1 + assert arr.type == pa.list_(pa.int64(), 2) + assert arr.to_pylist() == data + + # sequence of numpy arrays + data = [np.array([1, 2], dtype='int64'), np.array([3, 4], dtype='int64'), + None] + arr = pa.array(seq(data), type=pa.list_(pa.int64(), 2)) + assert len(arr) == 3 + assert arr.null_count == 1 + assert arr.type == pa.list_(pa.int64(), 2) + assert arr.to_pylist() == [[1, 2], [3, 4], None] + + # incorrect length of the lists or arrays + data = [[1, 2, 4], [3, None], None] + for data in [[[1, 2, 3]], [np.array([1, 2, 4], dtype='int64')]]: + with pytest.raises( + ValueError, match="Length of item not correct: expected 2"): + pa.array(seq(data), type=pa.list_(pa.int64(), 2)) + + # with list size of 0 + data = [[], [], None] + arr = pa.array(seq(data), type=pa.list_(pa.int64(), 0)) + assert len(arr) == 3 + assert arr.null_count == 1 + assert arr.type == pa.list_(pa.int64(), 0) + assert arr.to_pylist() == [[], [], None] + + +@parametrize_with_sequence_types +def test_sequence_all_none(seq): + arr = pa.array(seq([None, None])) + assert len(arr) == 2 + assert arr.null_count == 2 + assert arr.type == pa.null() + assert arr.to_pylist() == [None, None] + + +@parametrize_with_sequence_types +@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs) +def test_sequence_integer(seq, np_scalar_pa_type): + np_scalar, pa_type = np_scalar_pa_type + expected = [1, None, 3, None, + np.iinfo(np_scalar).min, np.iinfo(np_scalar).max] + arr = pa.array(seq(expected), type=pa_type) + assert len(arr) == 6 + assert arr.null_count == 2 + assert arr.type == pa_type + assert arr.to_pylist() == expected + + +@parametrize_with_collections_types +@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs) +def test_sequence_integer_np_nan(seq, np_scalar_pa_type): + # ARROW-2806: numpy.nan is a double value and thus should produce + # a double array. + _, pa_type = np_scalar_pa_type + with pytest.raises(ValueError): + pa.array(seq([np.nan]), type=pa_type, from_pandas=False) + + arr = pa.array(seq([np.nan]), type=pa_type, from_pandas=True) + expected = [None] + assert len(arr) == 1 + assert arr.null_count == 1 + assert arr.type == pa_type + assert arr.to_pylist() == expected + + +@parametrize_with_sequence_types +@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs) +def test_sequence_integer_nested_np_nan(seq, np_scalar_pa_type): + # ARROW-2806: numpy.nan is a double value and thus should produce + # a double array. + _, pa_type = np_scalar_pa_type + with pytest.raises(ValueError): + pa.array(seq([[np.nan]]), type=pa.list_(pa_type), from_pandas=False) + + arr = pa.array(seq([[np.nan]]), type=pa.list_(pa_type), from_pandas=True) + expected = [[None]] + assert len(arr) == 1 + assert arr.null_count == 0 + assert arr.type == pa.list_(pa_type) + assert arr.to_pylist() == expected + + +@parametrize_with_sequence_types +def test_sequence_integer_inferred(seq): + expected = [1, None, 3, None] + arr = pa.array(seq(expected)) + assert len(arr) == 4 + assert arr.null_count == 2 + assert arr.type == pa.int64() + assert arr.to_pylist() == expected + + +@parametrize_with_sequence_types +@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs) +def test_sequence_numpy_integer(seq, np_scalar_pa_type): + np_scalar, pa_type = np_scalar_pa_type + expected = [np_scalar(1), None, np_scalar(3), None, + np_scalar(np.iinfo(np_scalar).min), + np_scalar(np.iinfo(np_scalar).max)] + arr = pa.array(seq(expected), type=pa_type) + assert len(arr) == 6 + assert arr.null_count == 2 + assert arr.type == pa_type + assert arr.to_pylist() == expected + + +@parametrize_with_sequence_types +@pytest.mark.parametrize("np_scalar_pa_type", int_type_pairs) +def test_sequence_numpy_integer_inferred(seq, np_scalar_pa_type): + np_scalar, pa_type = np_scalar_pa_type + expected = [np_scalar(1), None, np_scalar(3), None] + expected += [np_scalar(np.iinfo(np_scalar).min), + np_scalar(np.iinfo(np_scalar).max)] + arr = pa.array(seq(expected)) + assert len(arr) == 6 + assert arr.null_count == 2 + assert arr.type == pa_type + assert arr.to_pylist() == expected + + +@parametrize_with_sequence_types +def test_sequence_custom_integers(seq): + expected = [0, 42, 2**33 + 1, -2**63] + data = list(map(MyInt, expected)) + arr = pa.array(seq(data), type=pa.int64()) + assert arr.to_pylist() == expected + + +@parametrize_with_collections_types +def test_broken_integers(seq): + data = [MyBrokenInt()] + with pytest.raises(pa.ArrowInvalid, match="tried to convert to int"): + pa.array(seq(data), type=pa.int64()) + + +def test_numpy_scalars_mixed_type(): + # ARROW-4324 + data = [np.int32(10), np.float32(0.5)] + arr = pa.array(data) + expected = pa.array([10, 0.5], type="float64") + assert arr.equals(expected) + + # ARROW-9490 + data = [np.int8(10), np.float32(0.5)] + arr = pa.array(data) + expected = pa.array([10, 0.5], type="float32") + assert arr.equals(expected) + + +@pytest.mark.xfail(reason="Type inference for uint64 not implemented", + raises=OverflowError) +def test_uint64_max_convert(): + data = [0, np.iinfo(np.uint64).max] + + arr = pa.array(data, type=pa.uint64()) + expected = pa.array(np.array(data, dtype='uint64')) + assert arr.equals(expected) + + arr_inferred = pa.array(data) + assert arr_inferred.equals(expected) + + +@pytest.mark.parametrize("bits", [8, 16, 32, 64]) +def test_signed_integer_overflow(bits): + ty = getattr(pa, "int%d" % bits)() + # XXX ideally would always raise OverflowError + with pytest.raises((OverflowError, pa.ArrowInvalid)): + pa.array([2 ** (bits - 1)], ty) + with pytest.raises((OverflowError, pa.ArrowInvalid)): + pa.array([-2 ** (bits - 1) - 1], ty) + + +@pytest.mark.parametrize("bits", [8, 16, 32, 64]) +def test_unsigned_integer_overflow(bits): + ty = getattr(pa, "uint%d" % bits)() + # XXX ideally would always raise OverflowError + with pytest.raises((OverflowError, pa.ArrowInvalid)): + pa.array([2 ** bits], ty) + with pytest.raises((OverflowError, pa.ArrowInvalid)): + pa.array([-1], ty) + + +@parametrize_with_collections_types +@pytest.mark.parametrize("typ", pa_int_types) +def test_integer_from_string_error(seq, typ): + # ARROW-9451: pa.array(['1'], type=pa.uint32()) should not succeed + with pytest.raises(pa.ArrowInvalid): + pa.array(seq(['1']), type=typ) + + +def test_convert_with_mask(): + data = [1, 2, 3, 4, 5] + mask = np.array([False, True, False, False, True]) + + result = pa.array(data, mask=mask) + expected = pa.array([1, None, 3, 4, None]) + + assert result.equals(expected) + + # Mask wrong length + with pytest.raises(ValueError): + pa.array(data, mask=mask[1:]) + + +def test_garbage_collection(): + import gc + + # Force the cyclic garbage collector to run + gc.collect() + + bytes_before = pa.total_allocated_bytes() + pa.array([1, None, 3, None]) + gc.collect() + assert pa.total_allocated_bytes() == bytes_before + + +def test_sequence_double(): + data = [1.5, 1., None, 2.5, None, None] + arr = pa.array(data) + assert len(arr) == 6 + assert arr.null_count == 3 + assert arr.type == pa.float64() + assert arr.to_pylist() == data + + +def test_double_auto_coerce_from_integer(): + # Done as part of ARROW-2814 + data = [1.5, 1., None, 2.5, None, None] + arr = pa.array(data) + + data2 = [1.5, 1, None, 2.5, None, None] + arr2 = pa.array(data2) + + assert arr.equals(arr2) + + data3 = [1, 1.5, None, 2.5, None, None] + arr3 = pa.array(data3) + + data4 = [1., 1.5, None, 2.5, None, None] + arr4 = pa.array(data4) + + assert arr3.equals(arr4) + + +def test_double_integer_coerce_representable_range(): + valid_values = [1.5, 1, 2, None, 1 << 53, -(1 << 53)] + invalid_values = [1.5, 1, 2, None, (1 << 53) + 1] + invalid_values2 = [1.5, 1, 2, None, -((1 << 53) + 1)] + + # it works + pa.array(valid_values) + + # it fails + with pytest.raises(ValueError): + pa.array(invalid_values) + + with pytest.raises(ValueError): + pa.array(invalid_values2) + + +def test_float32_integer_coerce_representable_range(): + f32 = np.float32 + valid_values = [f32(1.5), 1 << 24, -(1 << 24)] + invalid_values = [f32(1.5), (1 << 24) + 1] + invalid_values2 = [f32(1.5), -((1 << 24) + 1)] + + # it works + pa.array(valid_values, type=pa.float32()) + + # it fails + with pytest.raises(ValueError): + pa.array(invalid_values, type=pa.float32()) + + with pytest.raises(ValueError): + pa.array(invalid_values2, type=pa.float32()) + + +def test_mixed_sequence_errors(): + with pytest.raises(ValueError, match="tried to convert to boolean"): + pa.array([True, 'foo'], type=pa.bool_()) + + with pytest.raises(ValueError, match="tried to convert to float32"): + pa.array([1.5, 'foo'], type=pa.float32()) + + with pytest.raises(ValueError, match="tried to convert to double"): + pa.array([1.5, 'foo']) + + +@parametrize_with_sequence_types +@pytest.mark.parametrize("np_scalar,pa_type", [ + (np.float16, pa.float16()), + (np.float32, pa.float32()), + (np.float64, pa.float64()) +]) +@pytest.mark.parametrize("from_pandas", [True, False]) +def test_sequence_numpy_double(seq, np_scalar, pa_type, from_pandas): + data = [np_scalar(1.5), np_scalar(1), None, np_scalar(2.5), None, np.nan] + arr = pa.array(seq(data), from_pandas=from_pandas) + assert len(arr) == 6 + if from_pandas: + assert arr.null_count == 3 + else: + assert arr.null_count == 2 + if from_pandas: + # The NaN is skipped in type inference, otherwise it forces a + # float64 promotion + assert arr.type == pa_type + else: + assert arr.type == pa.float64() + + assert arr.to_pylist()[:4] == data[:4] + if from_pandas: + assert arr.to_pylist()[5] is None + else: + assert np.isnan(arr.to_pylist()[5]) + + +@pytest.mark.parametrize("from_pandas", [True, False]) +@pytest.mark.parametrize("inner_seq", [np.array, list]) +def test_ndarray_nested_numpy_double(from_pandas, inner_seq): + # ARROW-2806 + data = np.array([ + inner_seq([1., 2.]), + inner_seq([1., 2., 3.]), + inner_seq([np.nan]), + None + ], dtype=object) + arr = pa.array(data, from_pandas=from_pandas) + assert len(arr) == 4 + assert arr.null_count == 1 + assert arr.type == pa.list_(pa.float64()) + if from_pandas: + assert arr.to_pylist() == [[1.0, 2.0], [1.0, 2.0, 3.0], [None], None] + else: + np.testing.assert_equal(arr.to_pylist(), + [[1., 2.], [1., 2., 3.], [np.nan], None]) + + +def test_nested_ndarray_in_object_array(): + # ARROW-4350 + arr = np.empty(2, dtype=object) + arr[:] = [np.array([1, 2], dtype=np.int64), + np.array([2, 3], dtype=np.int64)] + + arr2 = np.empty(2, dtype=object) + arr2[0] = [3, 4] + arr2[1] = [5, 6] + + expected_type = pa.list_(pa.list_(pa.int64())) + assert pa.infer_type([arr]) == expected_type + + result = pa.array([arr, arr2]) + expected = pa.array([[[1, 2], [2, 3]], [[3, 4], [5, 6]]], + type=expected_type) + + assert result.equals(expected) + + # test case for len-1 arrays to ensure they are interpreted as + # sublists and not scalars + arr = np.empty(2, dtype=object) + arr[:] = [np.array([1]), np.array([2])] + result = pa.array([arr, arr]) + assert result.to_pylist() == [[[1], [2]], [[1], [2]]] + + +@pytest.mark.xfail(reason=("Type inference for multidimensional ndarray " + "not yet implemented"), + raises=AssertionError) +def test_multidimensional_ndarray_as_nested_list(): + # TODO(wesm): see ARROW-5645 + arr = np.array([[1, 2], [2, 3]], dtype=np.int64) + arr2 = np.array([[3, 4], [5, 6]], dtype=np.int64) + + expected_type = pa.list_(pa.list_(pa.int64())) + assert pa.infer_type([arr]) == expected_type + + result = pa.array([arr, arr2]) + expected = pa.array([[[1, 2], [2, 3]], [[3, 4], [5, 6]]], + type=expected_type) + + assert result.equals(expected) + + +@pytest.mark.parametrize(('data', 'value_type'), [ + ([True, False], pa.bool_()), + ([None, None], pa.null()), + ([1, 2, None], pa.int8()), + ([1, 2., 3., None], pa.float32()), + ([datetime.date.today(), None], pa.date32()), + ([None, datetime.date.today()], pa.date64()), + ([datetime.time(1, 1, 1), None], pa.time32('s')), + ([None, datetime.time(2, 2, 2)], pa.time64('us')), + ([datetime.datetime.now(), None], pa.timestamp('us')), + ([datetime.timedelta(seconds=10)], pa.duration('s')), + ([b"a", b"b"], pa.binary()), + ([b"aaa", b"bbb", b"ccc"], pa.binary(3)), + ([b"a", b"b", b"c"], pa.large_binary()), + (["a", "b", "c"], pa.string()), + (["a", "b", "c"], pa.large_string()), + ( + [{"a": 1, "b": 2}, None, {"a": 5, "b": None}], + pa.struct([('a', pa.int8()), ('b', pa.int16())]) + ) +]) +def test_list_array_from_object_ndarray(data, value_type): + ty = pa.list_(value_type) + ndarray = np.array(data, dtype=object) + arr = pa.array([ndarray], type=ty) + assert arr.type.equals(ty) + assert arr.to_pylist() == [data] + + +@pytest.mark.parametrize(('data', 'value_type'), [ + ([[1, 2], [3]], pa.list_(pa.int64())), + ([[1, 2], [3, 4]], pa.list_(pa.int64(), 2)), + ([[1], [2, 3]], pa.large_list(pa.int64())) +]) +def test_nested_list_array_from_object_ndarray(data, value_type): + ndarray = np.empty(len(data), dtype=object) + ndarray[:] = [np.array(item, dtype=object) for item in data] + + ty = pa.list_(value_type) + arr = pa.array([ndarray], type=ty) + assert arr.type.equals(ty) + assert arr.to_pylist() == [data] + + +def test_array_ignore_nan_from_pandas(): + # See ARROW-4324, this reverts logic that was introduced in + # ARROW-2240 + with pytest.raises(ValueError): + pa.array([np.nan, 'str']) + + arr = pa.array([np.nan, 'str'], from_pandas=True) + expected = pa.array([None, 'str']) + assert arr.equals(expected) + + +def test_nested_ndarray_different_dtypes(): + data = [ + np.array([1, 2, 3], dtype='int64'), + None, + np.array([4, 5, 6], dtype='uint32') + ] + + arr = pa.array(data) + expected = pa.array([[1, 2, 3], None, [4, 5, 6]], + type=pa.list_(pa.int64())) + assert arr.equals(expected) + + t2 = pa.list_(pa.uint32()) + arr2 = pa.array(data, type=t2) + expected2 = expected.cast(t2) + assert arr2.equals(expected2) + + +def test_sequence_unicode(): + data = ['foo', 'bar', None, 'mañana'] + arr = pa.array(data) + assert len(arr) == 4 + assert arr.null_count == 1 + assert arr.type == pa.string() + assert arr.to_pylist() == data + + +def check_array_mixed_unicode_bytes(binary_type, string_type): + values = ['qux', b'foo', bytearray(b'barz')] + b_values = [b'qux', b'foo', b'barz'] + u_values = ['qux', 'foo', 'barz'] + + arr = pa.array(values) + expected = pa.array(b_values, type=pa.binary()) + assert arr.type == pa.binary() + assert arr.equals(expected) + + arr = pa.array(values, type=binary_type) + expected = pa.array(b_values, type=binary_type) + assert arr.type == binary_type + assert arr.equals(expected) + + arr = pa.array(values, type=string_type) + expected = pa.array(u_values, type=string_type) + assert arr.type == string_type + assert arr.equals(expected) + + +def test_array_mixed_unicode_bytes(): + check_array_mixed_unicode_bytes(pa.binary(), pa.string()) + check_array_mixed_unicode_bytes(pa.large_binary(), pa.large_string()) + + +@pytest.mark.large_memory +@pytest.mark.parametrize("ty", [pa.large_binary(), pa.large_string()]) +def test_large_binary_array(ty): + # Construct a large binary array with more than 4GB of data + s = b"0123456789abcdefghijklmnopqrstuvwxyz" * 10 + nrepeats = math.ceil((2**32 + 5) / len(s)) + data = [s] * nrepeats + arr = pa.array(data, type=ty) + assert isinstance(arr, pa.Array) + assert arr.type == ty + assert len(arr) == nrepeats + + +@pytest.mark.slow +@pytest.mark.large_memory +@pytest.mark.parametrize("ty", [pa.large_binary(), pa.large_string()]) +def test_large_binary_value(ty): + # Construct a large binary array with a single value larger than 4GB + s = b"0123456789abcdefghijklmnopqrstuvwxyz" + nrepeats = math.ceil((2**32 + 5) / len(s)) + arr = pa.array([b"foo", s * nrepeats, None, b"bar"], type=ty) + assert isinstance(arr, pa.Array) + assert arr.type == ty + assert len(arr) == 4 + buf = arr[1].as_buffer() + assert len(buf) == len(s) * nrepeats + + +@pytest.mark.large_memory +@pytest.mark.parametrize("ty", [pa.binary(), pa.string()]) +def test_string_too_large(ty): + # Construct a binary array with a single value larger than 4GB + s = b"0123456789abcdefghijklmnopqrstuvwxyz" + nrepeats = math.ceil((2**32 + 5) / len(s)) + with pytest.raises(pa.ArrowCapacityError): + pa.array([b"foo", s * nrepeats, None, b"bar"], type=ty) + + +def test_sequence_bytes(): + u1 = b'ma\xc3\xb1ana' + + data = [b'foo', + memoryview(b'dada'), + memoryview(b'd-a-t-a')[::2], # non-contiguous is made contiguous + u1.decode('utf-8'), # unicode gets encoded, + bytearray(b'bar'), + None] + for ty in [None, pa.binary(), pa.large_binary()]: + arr = pa.array(data, type=ty) + assert len(arr) == 6 + assert arr.null_count == 1 + assert arr.type == ty or pa.binary() + assert arr.to_pylist() == [b'foo', b'dada', b'data', u1, b'bar', None] + + +@pytest.mark.parametrize("ty", [pa.string(), pa.large_string()]) +def test_sequence_utf8_to_unicode(ty): + # ARROW-1225 + data = [b'foo', None, b'bar'] + arr = pa.array(data, type=ty) + assert arr.type == ty + assert arr[0].as_py() == 'foo' + + # test a non-utf8 unicode string + val = ('mañana').encode('utf-16-le') + with pytest.raises(pa.ArrowInvalid): + pa.array([val], type=ty) + + +def test_sequence_fixed_size_bytes(): + data = [b'foof', None, bytearray(b'barb'), b'2346'] + arr = pa.array(data, type=pa.binary(4)) + assert len(arr) == 4 + assert arr.null_count == 1 + assert arr.type == pa.binary(4) + assert arr.to_pylist() == [b'foof', None, b'barb', b'2346'] + + +def test_fixed_size_bytes_does_not_accept_varying_lengths(): + data = [b'foo', None, b'barb', b'2346'] + with pytest.raises(pa.ArrowInvalid): + pa.array(data, type=pa.binary(4)) + + +def test_fixed_size_binary_length_check(): + # ARROW-10193 + data = [b'\x19h\r\x9e\x00\x00\x00\x00\x01\x9b\x9fA'] + assert len(data[0]) == 12 + ty = pa.binary(12) + arr = pa.array(data, type=ty) + assert arr.to_pylist() == data + + +def test_sequence_date(): + data = [datetime.date(2000, 1, 1), None, datetime.date(1970, 1, 1), + datetime.date(2040, 2, 26)] + arr = pa.array(data) + assert len(arr) == 4 + assert arr.type == pa.date32() + assert arr.null_count == 1 + assert arr[0].as_py() == datetime.date(2000, 1, 1) + assert arr[1].as_py() is None + assert arr[2].as_py() == datetime.date(1970, 1, 1) + assert arr[3].as_py() == datetime.date(2040, 2, 26) + + +@pytest.mark.parametrize('input', + [(pa.date32(), [10957, None]), + (pa.date64(), [10957 * 86400000, None])]) +def test_sequence_explicit_types(input): + t, ex_values = input + data = [datetime.date(2000, 1, 1), None] + arr = pa.array(data, type=t) + arr2 = pa.array(ex_values, type=t) + + for x in [arr, arr2]: + assert len(x) == 2 + assert x.type == t + assert x.null_count == 1 + assert x[0].as_py() == datetime.date(2000, 1, 1) + assert x[1].as_py() is None + + +def test_date32_overflow(): + # Overflow + data3 = [2**32, None] + with pytest.raises((OverflowError, pa.ArrowException)): + pa.array(data3, type=pa.date32()) + + +@pytest.mark.parametrize(('time_type', 'unit', 'int_type'), [ + (pa.time32, 's', 'int32'), + (pa.time32, 'ms', 'int32'), + (pa.time64, 'us', 'int64'), + (pa.time64, 'ns', 'int64'), +]) +def test_sequence_time_with_timezone(time_type, unit, int_type): + def expected_integer_value(t): + # only use with utc time object because it doesn't adjust with the + # offset + units = ['s', 'ms', 'us', 'ns'] + multiplier = 10**(units.index(unit) * 3) + if t is None: + return None + seconds = ( + t.hour * 3600 + + t.minute * 60 + + t.second + + t.microsecond * 10**-6 + ) + return int(seconds * multiplier) + + def expected_time_value(t): + # only use with utc time object because it doesn't adjust with the + # time objects tzdata + if unit == 's': + return t.replace(microsecond=0) + elif unit == 'ms': + return t.replace(microsecond=(t.microsecond // 1000) * 1000) + else: + return t + + # only timezone naive times are supported in arrow + data = [ + datetime.time(8, 23, 34, 123456), + datetime.time(5, 0, 0, 1000), + None, + datetime.time(1, 11, 56, 432539), + datetime.time(23, 10, 0, 437699) + ] + + ty = time_type(unit) + arr = pa.array(data, type=ty) + assert len(arr) == 5 + assert arr.type == ty + assert arr.null_count == 1 + + # test that the underlying integers are UTC values + values = arr.cast(int_type) + expected = list(map(expected_integer_value, data)) + assert values.to_pylist() == expected + + # test that the scalars are datetime.time objects with UTC timezone + assert arr[0].as_py() == expected_time_value(data[0]) + assert arr[1].as_py() == expected_time_value(data[1]) + assert arr[2].as_py() is None + assert arr[3].as_py() == expected_time_value(data[3]) + assert arr[4].as_py() == expected_time_value(data[4]) + + def tz(hours, minutes=0): + offset = datetime.timedelta(hours=hours, minutes=minutes) + return datetime.timezone(offset) + + +def test_sequence_timestamp(): + data = [ + datetime.datetime(2007, 7, 13, 1, 23, 34, 123456), + None, + datetime.datetime(2006, 1, 13, 12, 34, 56, 432539), + datetime.datetime(2010, 8, 13, 5, 46, 57, 437699) + ] + arr = pa.array(data) + assert len(arr) == 4 + assert arr.type == pa.timestamp('us') + assert arr.null_count == 1 + assert arr[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123456) + assert arr[1].as_py() is None + assert arr[2].as_py() == datetime.datetime(2006, 1, 13, 12, + 34, 56, 432539) + assert arr[3].as_py() == datetime.datetime(2010, 8, 13, 5, + 46, 57, 437699) + + +@pytest.mark.parametrize('timezone', [ + None, + 'UTC', + 'Etc/GMT-1', + 'Europe/Budapest', +]) +@pytest.mark.parametrize('unit', [ + 's', + 'ms', + 'us', + 'ns' +]) +def test_sequence_timestamp_with_timezone(timezone, unit): + def expected_integer_value(dt): + units = ['s', 'ms', 'us', 'ns'] + multiplier = 10**(units.index(unit) * 3) + if dt is None: + return None + else: + # avoid float precision issues + ts = decimal.Decimal(str(dt.timestamp())) + return int(ts * multiplier) + + def expected_datetime_value(dt): + if dt is None: + return None + + if unit == 's': + dt = dt.replace(microsecond=0) + elif unit == 'ms': + dt = dt.replace(microsecond=(dt.microsecond // 1000) * 1000) + + # adjust the timezone + if timezone is None: + # make datetime timezone unaware + return dt.replace(tzinfo=None) + else: + # convert to the expected timezone + return dt.astimezone(pytz.timezone(timezone)) + + data = [ + datetime.datetime(2007, 7, 13, 8, 23, 34, 123456), # naive + pytz.utc.localize( + datetime.datetime(2008, 1, 5, 5, 0, 0, 1000) + ), + None, + pytz.timezone('US/Eastern').localize( + datetime.datetime(2006, 1, 13, 12, 34, 56, 432539) + ), + pytz.timezone('Europe/Moscow').localize( + datetime.datetime(2010, 8, 13, 5, 0, 0, 437699) + ), + ] + utcdata = [ + pytz.utc.localize(data[0]), + data[1], + None, + data[3].astimezone(pytz.utc), + data[4].astimezone(pytz.utc), + ] + + ty = pa.timestamp(unit, tz=timezone) + arr = pa.array(data, type=ty) + assert len(arr) == 5 + assert arr.type == ty + assert arr.null_count == 1 + + # test that the underlying integers are UTC values + values = arr.cast('int64') + expected = list(map(expected_integer_value, utcdata)) + assert values.to_pylist() == expected + + # test that the scalars are datetimes with the correct timezone + for i in range(len(arr)): + assert arr[i].as_py() == expected_datetime_value(utcdata[i]) + + +@pytest.mark.parametrize('timezone', [ + None, + 'UTC', + 'Etc/GMT-1', + 'Europe/Budapest', +]) +def test_pyarrow_ignore_timezone_environment_variable(monkeypatch, timezone): + # note that any non-empty value will evaluate to true + monkeypatch.setenv("PYARROW_IGNORE_TIMEZONE", "1") + data = [ + datetime.datetime(2007, 7, 13, 8, 23, 34, 123456), # naive + pytz.utc.localize( + datetime.datetime(2008, 1, 5, 5, 0, 0, 1000) + ), + pytz.timezone('US/Eastern').localize( + datetime.datetime(2006, 1, 13, 12, 34, 56, 432539) + ), + pytz.timezone('Europe/Moscow').localize( + datetime.datetime(2010, 8, 13, 5, 0, 0, 437699) + ), + ] + + expected = [dt.replace(tzinfo=None) for dt in data] + if timezone is not None: + tzinfo = pytz.timezone(timezone) + expected = [tzinfo.fromutc(dt) for dt in expected] + + ty = pa.timestamp('us', tz=timezone) + arr = pa.array(data, type=ty) + assert arr.to_pylist() == expected + + +def test_sequence_timestamp_with_timezone_inference(): + data = [ + datetime.datetime(2007, 7, 13, 8, 23, 34, 123456), # naive + pytz.utc.localize( + datetime.datetime(2008, 1, 5, 5, 0, 0, 1000) + ), + None, + pytz.timezone('US/Eastern').localize( + datetime.datetime(2006, 1, 13, 12, 34, 56, 432539) + ), + pytz.timezone('Europe/Moscow').localize( + datetime.datetime(2010, 8, 13, 5, 0, 0, 437699) + ), + ] + expected = [ + pa.timestamp('us', tz=None), + pa.timestamp('us', tz='UTC'), + pa.timestamp('us', tz=None), + pa.timestamp('us', tz='US/Eastern'), + pa.timestamp('us', tz='Europe/Moscow') + ] + for dt, expected_type in zip(data, expected): + prepended = [dt] + data + arr = pa.array(prepended) + assert arr.type == expected_type + + +@pytest.mark.pandas +def test_sequence_timestamp_from_mixed_builtin_and_pandas_datetimes(): + import pandas as pd + + data = [ + pd.Timestamp(1184307814123456123, tz=pytz.timezone('US/Eastern'), + unit='ns'), + datetime.datetime(2007, 7, 13, 8, 23, 34, 123456), # naive + pytz.utc.localize( + datetime.datetime(2008, 1, 5, 5, 0, 0, 1000) + ), + None, + ] + utcdata = [ + data[0].astimezone(pytz.utc), + pytz.utc.localize(data[1]), + data[2].astimezone(pytz.utc), + None, + ] + + arr = pa.array(data) + assert arr.type == pa.timestamp('us', tz='US/Eastern') + + values = arr.cast('int64') + expected = [int(dt.timestamp() * 10**6) if dt else None for dt in utcdata] + assert values.to_pylist() == expected + + +def test_sequence_timestamp_out_of_bounds_nanosecond(): + # https://issues.apache.org/jira/browse/ARROW-9768 + # datetime outside of range supported for nanosecond resolution + data = [datetime.datetime(2262, 4, 12)] + with pytest.raises(ValueError, match="out of bounds"): + pa.array(data, type=pa.timestamp('ns')) + + # with microsecond resolution it works fine + arr = pa.array(data, type=pa.timestamp('us')) + assert arr.to_pylist() == data + + # case where the naive is within bounds, but converted to UTC not + tz = datetime.timezone(datetime.timedelta(hours=-1)) + data = [datetime.datetime(2262, 4, 11, 23, tzinfo=tz)] + with pytest.raises(ValueError, match="out of bounds"): + pa.array(data, type=pa.timestamp('ns')) + + arr = pa.array(data, type=pa.timestamp('us')) + assert arr.to_pylist()[0] == datetime.datetime(2262, 4, 12) + + +def test_sequence_numpy_timestamp(): + data = [ + np.datetime64(datetime.datetime(2007, 7, 13, 1, 23, 34, 123456)), + None, + np.datetime64(datetime.datetime(2006, 1, 13, 12, 34, 56, 432539)), + np.datetime64(datetime.datetime(2010, 8, 13, 5, 46, 57, 437699)) + ] + arr = pa.array(data) + assert len(arr) == 4 + assert arr.type == pa.timestamp('us') + assert arr.null_count == 1 + assert arr[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123456) + assert arr[1].as_py() is None + assert arr[2].as_py() == datetime.datetime(2006, 1, 13, 12, + 34, 56, 432539) + assert arr[3].as_py() == datetime.datetime(2010, 8, 13, 5, + 46, 57, 437699) + + +class MyDate(datetime.date): + pass + + +class MyDatetime(datetime.datetime): + pass + + +class MyTimedelta(datetime.timedelta): + pass + + +def test_datetime_subclassing(): + data = [ + MyDate(2007, 7, 13), + ] + date_type = pa.date32() + arr_date = pa.array(data, type=date_type) + assert len(arr_date) == 1 + assert arr_date.type == date_type + assert arr_date[0].as_py() == datetime.date(2007, 7, 13) + + data = [ + MyDatetime(2007, 7, 13, 1, 23, 34, 123456), + ] + + s = pa.timestamp('s') + ms = pa.timestamp('ms') + us = pa.timestamp('us') + + arr_s = pa.array(data, type=s) + assert len(arr_s) == 1 + assert arr_s.type == s + assert arr_s[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 0) + + arr_ms = pa.array(data, type=ms) + assert len(arr_ms) == 1 + assert arr_ms.type == ms + assert arr_ms[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123000) + + arr_us = pa.array(data, type=us) + assert len(arr_us) == 1 + assert arr_us.type == us + assert arr_us[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123456) + + data = [ + MyTimedelta(123, 456, 1002), + ] + + s = pa.duration('s') + ms = pa.duration('ms') + us = pa.duration('us') + + arr_s = pa.array(data) + assert len(arr_s) == 1 + assert arr_s.type == us + assert arr_s[0].as_py() == datetime.timedelta(123, 456, 1002) + + arr_s = pa.array(data, type=s) + assert len(arr_s) == 1 + assert arr_s.type == s + assert arr_s[0].as_py() == datetime.timedelta(123, 456) + + arr_ms = pa.array(data, type=ms) + assert len(arr_ms) == 1 + assert arr_ms.type == ms + assert arr_ms[0].as_py() == datetime.timedelta(123, 456, 1000) + + arr_us = pa.array(data, type=us) + assert len(arr_us) == 1 + assert arr_us.type == us + assert arr_us[0].as_py() == datetime.timedelta(123, 456, 1002) + + +@pytest.mark.xfail(not _pandas_api.have_pandas, + reason="pandas required for nanosecond conversion") +def test_sequence_timestamp_nanoseconds(): + inputs = [ + [datetime.datetime(2007, 7, 13, 1, 23, 34, 123456)], + [MyDatetime(2007, 7, 13, 1, 23, 34, 123456)] + ] + + for data in inputs: + ns = pa.timestamp('ns') + arr_ns = pa.array(data, type=ns) + assert len(arr_ns) == 1 + assert arr_ns.type == ns + assert arr_ns[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123456) + + +@pytest.mark.pandas +def test_sequence_timestamp_from_int_with_unit(): + # TODO(wesm): This test might be rewritten to assert the actual behavior + # when pandas is not installed + + data = [1] + + s = pa.timestamp('s') + ms = pa.timestamp('ms') + us = pa.timestamp('us') + ns = pa.timestamp('ns') + + arr_s = pa.array(data, type=s) + assert len(arr_s) == 1 + assert arr_s.type == s + assert repr(arr_s[0]) == ( + "<pyarrow.TimestampScalar: datetime.datetime(1970, 1, 1, 0, 0, 1)>" + ) + assert str(arr_s[0]) == "1970-01-01 00:00:01" + + arr_ms = pa.array(data, type=ms) + assert len(arr_ms) == 1 + assert arr_ms.type == ms + assert repr(arr_ms[0].as_py()) == ( + "datetime.datetime(1970, 1, 1, 0, 0, 0, 1000)" + ) + assert str(arr_ms[0]) == "1970-01-01 00:00:00.001000" + + arr_us = pa.array(data, type=us) + assert len(arr_us) == 1 + assert arr_us.type == us + assert repr(arr_us[0].as_py()) == ( + "datetime.datetime(1970, 1, 1, 0, 0, 0, 1)" + ) + assert str(arr_us[0]) == "1970-01-01 00:00:00.000001" + + arr_ns = pa.array(data, type=ns) + assert len(arr_ns) == 1 + assert arr_ns.type == ns + assert repr(arr_ns[0].as_py()) == ( + "Timestamp('1970-01-01 00:00:00.000000001')" + ) + assert str(arr_ns[0]) == "1970-01-01 00:00:00.000000001" + + expected_exc = TypeError + + class CustomClass(): + pass + + for ty in [ns, pa.date32(), pa.date64()]: + with pytest.raises(expected_exc): + pa.array([1, CustomClass()], type=ty) + + +@pytest.mark.parametrize('np_scalar', [True, False]) +def test_sequence_duration(np_scalar): + td1 = datetime.timedelta(2, 3601, 1) + td2 = datetime.timedelta(1, 100, 1000) + if np_scalar: + data = [np.timedelta64(td1), None, np.timedelta64(td2)] + else: + data = [td1, None, td2] + + arr = pa.array(data) + assert len(arr) == 3 + assert arr.type == pa.duration('us') + assert arr.null_count == 1 + assert arr[0].as_py() == td1 + assert arr[1].as_py() is None + assert arr[2].as_py() == td2 + + +@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) +def test_sequence_duration_with_unit(unit): + data = [ + datetime.timedelta(3, 22, 1001), + ] + expected = {'s': datetime.timedelta(3, 22), + 'ms': datetime.timedelta(3, 22, 1000), + 'us': datetime.timedelta(3, 22, 1001), + 'ns': datetime.timedelta(3, 22, 1001)} + + ty = pa.duration(unit) + + arr_s = pa.array(data, type=ty) + assert len(arr_s) == 1 + assert arr_s.type == ty + assert arr_s[0].as_py() == expected[unit] + + +@pytest.mark.parametrize('unit', ['s', 'ms', 'us', 'ns']) +def test_sequence_duration_from_int_with_unit(unit): + data = [5] + + ty = pa.duration(unit) + arr = pa.array(data, type=ty) + assert len(arr) == 1 + assert arr.type == ty + assert arr[0].value == 5 + + +def test_sequence_duration_nested_lists(): + td1 = datetime.timedelta(1, 1, 1000) + td2 = datetime.timedelta(1, 100) + + data = [[td1, None], [td1, td2]] + + arr = pa.array(data) + assert len(arr) == 2 + assert arr.type == pa.list_(pa.duration('us')) + assert arr.to_pylist() == data + + arr = pa.array(data, type=pa.list_(pa.duration('ms'))) + assert len(arr) == 2 + assert arr.type == pa.list_(pa.duration('ms')) + assert arr.to_pylist() == data + + +def test_sequence_duration_nested_lists_numpy(): + td1 = datetime.timedelta(1, 1, 1000) + td2 = datetime.timedelta(1, 100) + + data = [[np.timedelta64(td1), None], + [np.timedelta64(td1), np.timedelta64(td2)]] + + arr = pa.array(data) + assert len(arr) == 2 + assert arr.type == pa.list_(pa.duration('us')) + assert arr.to_pylist() == [[td1, None], [td1, td2]] + + data = [np.array([np.timedelta64(td1), None], dtype='timedelta64[us]'), + np.array([np.timedelta64(td1), np.timedelta64(td2)])] + + arr = pa.array(data) + assert len(arr) == 2 + assert arr.type == pa.list_(pa.duration('us')) + assert arr.to_pylist() == [[td1, None], [td1, td2]] + + +def test_sequence_nesting_levels(): + data = [1, 2, None] + arr = pa.array(data) + assert arr.type == pa.int64() + assert arr.to_pylist() == data + + data = [[1], [2], None] + arr = pa.array(data) + assert arr.type == pa.list_(pa.int64()) + assert arr.to_pylist() == data + + data = [[1], [2, 3, 4], [None]] + arr = pa.array(data) + assert arr.type == pa.list_(pa.int64()) + assert arr.to_pylist() == data + + data = [None, [[None, 1]], [[2, 3, 4], None], [None]] + arr = pa.array(data) + assert arr.type == pa.list_(pa.list_(pa.int64())) + assert arr.to_pylist() == data + + exceptions = (pa.ArrowInvalid, pa.ArrowTypeError) + + # Mixed nesting levels are rejected + with pytest.raises(exceptions): + pa.array([1, 2, [1]]) + + with pytest.raises(exceptions): + pa.array([1, 2, []]) + + with pytest.raises(exceptions): + pa.array([[1], [2], [None, [1]]]) + + +def test_sequence_mixed_types_fails(): + data = ['a', 1, 2.0] + with pytest.raises(pa.ArrowTypeError): + pa.array(data) + + +def test_sequence_mixed_types_with_specified_type_fails(): + data = ['-10', '-5', {'a': 1}, '0', '5', '10'] + + type = pa.string() + with pytest.raises(TypeError): + pa.array(data, type=type) + + +def test_sequence_decimal(): + data = [decimal.Decimal('1234.183'), decimal.Decimal('8094.234')] + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=7, scale=3)) + assert arr.to_pylist() == data + + +def test_sequence_decimal_different_precisions(): + data = [ + decimal.Decimal('1234234983.183'), decimal.Decimal('80943244.234') + ] + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=13, scale=3)) + assert arr.to_pylist() == data + + +def test_sequence_decimal_no_scale(): + data = [decimal.Decimal('1234234983'), decimal.Decimal('8094324')] + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=10)) + assert arr.to_pylist() == data + + +def test_sequence_decimal_negative(): + data = [decimal.Decimal('-1234.234983'), decimal.Decimal('-8.094324')] + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=10, scale=6)) + assert arr.to_pylist() == data + + +def test_sequence_decimal_no_whole_part(): + data = [decimal.Decimal('-.4234983'), decimal.Decimal('.0103943')] + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=7, scale=7)) + assert arr.to_pylist() == data + + +def test_sequence_decimal_large_integer(): + data = [decimal.Decimal('-394029506937548693.42983'), + decimal.Decimal('32358695912932.01033')] + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=23, scale=5)) + assert arr.to_pylist() == data + + +def test_sequence_decimal_from_integers(): + data = [0, 1, -39402950693754869342983] + expected = [decimal.Decimal(x) for x in data] + for type in [pa.decimal128, pa.decimal256]: + arr = pa.array(data, type=type(precision=28, scale=5)) + assert arr.to_pylist() == expected + + +def test_sequence_decimal_too_high_precision(): + # ARROW-6989 python decimal has too high precision + with pytest.raises(ValueError, match="precision out of range"): + pa.array([decimal.Decimal('1' * 80)]) + + +def test_sequence_decimal_infer(): + for data, typ in [ + # simple case + (decimal.Decimal('1.234'), pa.decimal128(4, 3)), + # trailing zeros + (decimal.Decimal('12300'), pa.decimal128(5, 0)), + (decimal.Decimal('12300.0'), pa.decimal128(6, 1)), + # scientific power notation + (decimal.Decimal('1.23E+4'), pa.decimal128(5, 0)), + (decimal.Decimal('123E+2'), pa.decimal128(5, 0)), + (decimal.Decimal('123E+4'), pa.decimal128(7, 0)), + # leading zeros + (decimal.Decimal('0.0123'), pa.decimal128(4, 4)), + (decimal.Decimal('0.01230'), pa.decimal128(5, 5)), + (decimal.Decimal('1.230E-2'), pa.decimal128(5, 5)), + ]: + assert pa.infer_type([data]) == typ + arr = pa.array([data]) + assert arr.type == typ + assert arr.to_pylist()[0] == data + + +def test_sequence_decimal_infer_mixed(): + # ARROW-12150 - ensure mixed precision gets correctly inferred to + # common type that can hold all input values + cases = [ + ([decimal.Decimal('1.234'), decimal.Decimal('3.456')], + pa.decimal128(4, 3)), + ([decimal.Decimal('1.234'), decimal.Decimal('456.7')], + pa.decimal128(6, 3)), + ([decimal.Decimal('123.4'), decimal.Decimal('4.567')], + pa.decimal128(6, 3)), + ([decimal.Decimal('123e2'), decimal.Decimal('4567e3')], + pa.decimal128(7, 0)), + ([decimal.Decimal('123e4'), decimal.Decimal('4567e2')], + pa.decimal128(7, 0)), + ([decimal.Decimal('0.123'), decimal.Decimal('0.04567')], + pa.decimal128(5, 5)), + ([decimal.Decimal('0.001'), decimal.Decimal('1.01E5')], + pa.decimal128(9, 3)), + ] + for data, typ in cases: + assert pa.infer_type(data) == typ + arr = pa.array(data) + assert arr.type == typ + assert arr.to_pylist() == data + + +def test_sequence_decimal_given_type(): + for data, typs, wrong_typs in [ + # simple case + ( + decimal.Decimal('1.234'), + [pa.decimal128(4, 3), pa.decimal128(5, 3), pa.decimal128(5, 4)], + [pa.decimal128(4, 2), pa.decimal128(4, 4)] + ), + # trailing zeros + ( + decimal.Decimal('12300'), + [pa.decimal128(5, 0), pa.decimal128(6, 0), pa.decimal128(3, -2)], + [pa.decimal128(4, 0), pa.decimal128(3, -3)] + ), + # scientific power notation + ( + decimal.Decimal('1.23E+4'), + [pa.decimal128(5, 0), pa.decimal128(6, 0), pa.decimal128(3, -2)], + [pa.decimal128(4, 0), pa.decimal128(3, -3)] + ), + ]: + for typ in typs: + arr = pa.array([data], type=typ) + assert arr.type == typ + assert arr.to_pylist()[0] == data + for typ in wrong_typs: + with pytest.raises(ValueError): + pa.array([data], type=typ) + + +def test_range_types(): + arr1 = pa.array(range(3)) + arr2 = pa.array((0, 1, 2)) + assert arr1.equals(arr2) + + +def test_empty_range(): + arr = pa.array(range(0)) + assert len(arr) == 0 + assert arr.null_count == 0 + assert arr.type == pa.null() + assert arr.to_pylist() == [] + + +def test_structarray(): + arr = pa.StructArray.from_arrays([], names=[]) + assert arr.type == pa.struct([]) + assert len(arr) == 0 + assert arr.to_pylist() == [] + + ints = pa.array([None, 2, 3], type=pa.int64()) + strs = pa.array(['a', None, 'c'], type=pa.string()) + bools = pa.array([True, False, None], type=pa.bool_()) + arr = pa.StructArray.from_arrays( + [ints, strs, bools], + ['ints', 'strs', 'bools']) + + expected = [ + {'ints': None, 'strs': 'a', 'bools': True}, + {'ints': 2, 'strs': None, 'bools': False}, + {'ints': 3, 'strs': 'c', 'bools': None}, + ] + + pylist = arr.to_pylist() + assert pylist == expected, (pylist, expected) + + # len(names) != len(arrays) + with pytest.raises(ValueError): + pa.StructArray.from_arrays([ints], ['ints', 'strs']) + + +def test_struct_from_dicts(): + ty = pa.struct([pa.field('a', pa.int32()), + pa.field('b', pa.string()), + pa.field('c', pa.bool_())]) + arr = pa.array([], type=ty) + assert arr.to_pylist() == [] + + data = [{'a': 5, 'b': 'foo', 'c': True}, + {'a': 6, 'b': 'bar', 'c': False}] + arr = pa.array(data, type=ty) + assert arr.to_pylist() == data + + # With omitted values + data = [{'a': 5, 'c': True}, + None, + {}, + {'a': None, 'b': 'bar'}] + arr = pa.array(data, type=ty) + expected = [{'a': 5, 'b': None, 'c': True}, + None, + {'a': None, 'b': None, 'c': None}, + {'a': None, 'b': 'bar', 'c': None}] + assert arr.to_pylist() == expected + + +def test_struct_from_dicts_bytes_keys(): + # ARROW-6878 + ty = pa.struct([pa.field('a', pa.int32()), + pa.field('b', pa.string()), + pa.field('c', pa.bool_())]) + arr = pa.array([], type=ty) + assert arr.to_pylist() == [] + + data = [{b'a': 5, b'b': 'foo'}, + {b'a': 6, b'c': False}] + arr = pa.array(data, type=ty) + assert arr.to_pylist() == [ + {'a': 5, 'b': 'foo', 'c': None}, + {'a': 6, 'b': None, 'c': False}, + ] + + +def test_struct_from_tuples(): + ty = pa.struct([pa.field('a', pa.int32()), + pa.field('b', pa.string()), + pa.field('c', pa.bool_())]) + + data = [(5, 'foo', True), + (6, 'bar', False)] + expected = [{'a': 5, 'b': 'foo', 'c': True}, + {'a': 6, 'b': 'bar', 'c': False}] + arr = pa.array(data, type=ty) + + data_as_ndarray = np.empty(len(data), dtype=object) + data_as_ndarray[:] = data + arr2 = pa.array(data_as_ndarray, type=ty) + assert arr.to_pylist() == expected + + assert arr.equals(arr2) + + # With omitted values + data = [(5, 'foo', None), + None, + (6, None, False)] + expected = [{'a': 5, 'b': 'foo', 'c': None}, + None, + {'a': 6, 'b': None, 'c': False}] + arr = pa.array(data, type=ty) + assert arr.to_pylist() == expected + + # Invalid tuple size + for tup in [(5, 'foo'), (), ('5', 'foo', True, None)]: + with pytest.raises(ValueError, match="(?i)tuple size"): + pa.array([tup], type=ty) + + +def test_struct_from_list_of_pairs(): + ty = pa.struct([ + pa.field('a', pa.int32()), + pa.field('b', pa.string()), + pa.field('c', pa.bool_()) + ]) + data = [ + [('a', 5), ('b', 'foo'), ('c', True)], + [('a', 6), ('b', 'bar'), ('c', False)], + None + ] + arr = pa.array(data, type=ty) + assert arr.to_pylist() == [ + {'a': 5, 'b': 'foo', 'c': True}, + {'a': 6, 'b': 'bar', 'c': False}, + None + ] + + # test with duplicated field names + ty = pa.struct([ + pa.field('a', pa.int32()), + pa.field('a', pa.string()), + pa.field('b', pa.bool_()) + ]) + data = [ + [('a', 5), ('a', 'foo'), ('b', True)], + [('a', 6), ('a', 'bar'), ('b', False)], + ] + arr = pa.array(data, type=ty) + with pytest.raises(ValueError): + # TODO(kszucs): ARROW-9997 + arr.to_pylist() + + # test with empty elements + ty = pa.struct([ + pa.field('a', pa.int32()), + pa.field('b', pa.string()), + pa.field('c', pa.bool_()) + ]) + data = [ + [], + [('a', 5), ('b', 'foo'), ('c', True)], + [('a', 2), ('b', 'baz')], + [('a', 1), ('b', 'bar'), ('c', False), ('d', 'julia')], + ] + expected = [ + {'a': None, 'b': None, 'c': None}, + {'a': 5, 'b': 'foo', 'c': True}, + {'a': 2, 'b': 'baz', 'c': None}, + {'a': 1, 'b': 'bar', 'c': False}, + ] + arr = pa.array(data, type=ty) + assert arr.to_pylist() == expected + + +def test_struct_from_list_of_pairs_errors(): + ty = pa.struct([ + pa.field('a', pa.int32()), + pa.field('b', pa.string()), + pa.field('c', pa.bool_()) + ]) + + # test that it raises if the key doesn't match the expected field name + data = [ + [], + [('a', 5), ('c', True), ('b', None)], + ] + msg = "The expected field name is `b` but `c` was given" + with pytest.raises(ValueError, match=msg): + pa.array(data, type=ty) + + # test various errors both at the first position and after because of key + # type inference + template = ( + r"Could not convert {} with type {}: was expecting tuple of " + r"(key, value) pair" + ) + cases = [ + tuple(), # empty key-value pair + tuple('a',), # missing value + tuple('unknown-key',), # not known field name + 'string', # not a tuple + ] + for key_value_pair in cases: + msg = re.escape(template.format( + repr(key_value_pair), type(key_value_pair).__name__ + )) + + with pytest.raises(TypeError, match=msg): + pa.array([ + [key_value_pair], + [('a', 5), ('b', 'foo'), ('c', None)], + ], type=ty) + + with pytest.raises(TypeError, match=msg): + pa.array([ + [('a', 5), ('b', 'foo'), ('c', None)], + [key_value_pair], + ], type=ty) + + +def test_struct_from_mixed_sequence(): + # It is forbidden to mix dicts and tuples when initializing a struct array + ty = pa.struct([pa.field('a', pa.int32()), + pa.field('b', pa.string()), + pa.field('c', pa.bool_())]) + data = [(5, 'foo', True), + {'a': 6, 'b': 'bar', 'c': False}] + with pytest.raises(TypeError): + pa.array(data, type=ty) + + +def test_struct_from_dicts_inference(): + expected_type = pa.struct([pa.field('a', pa.int64()), + pa.field('b', pa.string()), + pa.field('c', pa.bool_())]) + data = [{'a': 5, 'b': 'foo', 'c': True}, + {'a': 6, 'b': 'bar', 'c': False}] + + arr = pa.array(data) + check_struct_type(arr.type, expected_type) + assert arr.to_pylist() == data + + # With omitted values + data = [{'a': 5, 'c': True}, + None, + {}, + {'a': None, 'b': 'bar'}] + expected = [{'a': 5, 'b': None, 'c': True}, + None, + {'a': None, 'b': None, 'c': None}, + {'a': None, 'b': 'bar', 'c': None}] + + arr = pa.array(data) + data_as_ndarray = np.empty(len(data), dtype=object) + data_as_ndarray[:] = data + arr2 = pa.array(data) + + check_struct_type(arr.type, expected_type) + assert arr.to_pylist() == expected + assert arr.equals(arr2) + + # Nested + expected_type = pa.struct([ + pa.field('a', pa.struct([pa.field('aa', pa.list_(pa.int64())), + pa.field('ab', pa.bool_())])), + pa.field('b', pa.string())]) + data = [{'a': {'aa': [5, 6], 'ab': True}, 'b': 'foo'}, + {'a': {'aa': None, 'ab': False}, 'b': None}, + {'a': None, 'b': 'bar'}] + arr = pa.array(data) + + assert arr.to_pylist() == data + + # Edge cases + arr = pa.array([{}]) + assert arr.type == pa.struct([]) + assert arr.to_pylist() == [{}] + + # Mixing structs and scalars is rejected + with pytest.raises((pa.ArrowInvalid, pa.ArrowTypeError)): + pa.array([1, {'a': 2}]) + + +def test_structarray_from_arrays_coerce(): + # ARROW-1706 + ints = [None, 2, 3] + strs = ['a', None, 'c'] + bools = [True, False, None] + ints_nonnull = [1, 2, 3] + + arrays = [ints, strs, bools, ints_nonnull] + result = pa.StructArray.from_arrays(arrays, + ['ints', 'strs', 'bools', + 'int_nonnull']) + expected = pa.StructArray.from_arrays( + [pa.array(ints, type='int64'), + pa.array(strs, type='utf8'), + pa.array(bools), + pa.array(ints_nonnull, type='int64')], + ['ints', 'strs', 'bools', 'int_nonnull']) + + with pytest.raises(ValueError): + pa.StructArray.from_arrays(arrays) + + assert result.equals(expected) + + +def test_decimal_array_with_none_and_nan(): + values = [decimal.Decimal('1.234'), None, np.nan, decimal.Decimal('nan')] + + with pytest.raises(TypeError): + # ARROW-6227: Without from_pandas=True, NaN is considered a float + array = pa.array(values) + + array = pa.array(values, from_pandas=True) + assert array.type == pa.decimal128(4, 3) + assert array.to_pylist() == values[:2] + [None, None] + + array = pa.array(values, type=pa.decimal128(10, 4), from_pandas=True) + assert array.to_pylist() == [decimal.Decimal('1.2340'), None, None, None] + + +def test_map_from_dicts(): + data = [[{'key': b'a', 'value': 1}, {'key': b'b', 'value': 2}], + [{'key': b'c', 'value': 3}], + [{'key': b'd', 'value': 4}, {'key': b'e', 'value': 5}, + {'key': b'f', 'value': None}], + [{'key': b'g', 'value': 7}]] + expected = [[(d['key'], d['value']) for d in entry] for entry in data] + + arr = pa.array(expected, type=pa.map_(pa.binary(), pa.int32())) + + assert arr.to_pylist() == expected + + # With omitted values + data[1] = None + expected[1] = None + + arr = pa.array(expected, type=pa.map_(pa.binary(), pa.int32())) + + assert arr.to_pylist() == expected + + # Invalid dictionary + for entry in [[{'value': 5}], [{}], [{'k': 1, 'v': 2}]]: + with pytest.raises(ValueError, match="Invalid Map"): + pa.array([entry], type=pa.map_('i4', 'i4')) + + # Invalid dictionary types + for entry in [[{'key': '1', 'value': 5}], [{'key': {'value': 2}}]]: + with pytest.raises(pa.ArrowInvalid, match="tried to convert to int"): + pa.array([entry], type=pa.map_('i4', 'i4')) + + +def test_map_from_tuples(): + expected = [[(b'a', 1), (b'b', 2)], + [(b'c', 3)], + [(b'd', 4), (b'e', 5), (b'f', None)], + [(b'g', 7)]] + + arr = pa.array(expected, type=pa.map_(pa.binary(), pa.int32())) + + assert arr.to_pylist() == expected + + # With omitted values + expected[1] = None + + arr = pa.array(expected, type=pa.map_(pa.binary(), pa.int32())) + + assert arr.to_pylist() == expected + + # Invalid tuple size + for entry in [[(5,)], [()], [('5', 'foo', True)]]: + with pytest.raises(ValueError, match="(?i)tuple size"): + pa.array([entry], type=pa.map_('i4', 'i4')) + + +def test_dictionary_from_boolean(): + typ = pa.dictionary(pa.int8(), value_type=pa.bool_()) + a = pa.array([False, False, True, False, True], type=typ) + assert isinstance(a.type, pa.DictionaryType) + assert a.type.equals(typ) + + expected_indices = pa.array([0, 0, 1, 0, 1], type=pa.int8()) + expected_dictionary = pa.array([False, True], type=pa.bool_()) + assert a.indices.equals(expected_indices) + assert a.dictionary.equals(expected_dictionary) + + +@pytest.mark.parametrize('value_type', [ + pa.int8(), + pa.int16(), + pa.int32(), + pa.int64(), + pa.uint8(), + pa.uint16(), + pa.uint32(), + pa.uint64(), + pa.float32(), + pa.float64(), +]) +def test_dictionary_from_integers(value_type): + typ = pa.dictionary(pa.int8(), value_type=value_type) + a = pa.array([1, 2, 1, 1, 2, 3], type=typ) + assert isinstance(a.type, pa.DictionaryType) + assert a.type.equals(typ) + + expected_indices = pa.array([0, 1, 0, 0, 1, 2], type=pa.int8()) + expected_dictionary = pa.array([1, 2, 3], type=value_type) + assert a.indices.equals(expected_indices) + assert a.dictionary.equals(expected_dictionary) + + +@pytest.mark.parametrize('input_index_type', [ + pa.int8(), + pa.int16(), + pa.int32(), + pa.int64() +]) +def test_dictionary_index_type(input_index_type): + # dictionary array is constructed using adaptive index type builder, + # but the input index type is considered as the minimal width type to use + + typ = pa.dictionary(input_index_type, value_type=pa.int64()) + arr = pa.array(range(10), type=typ) + assert arr.type.equals(typ) + + +def test_dictionary_is_always_adaptive(): + # dictionary array is constructed using adaptive index type builder, + # meaning that the output index type may be wider than the given index type + # since it depends on the input data + typ = pa.dictionary(pa.int8(), value_type=pa.int64()) + + a = pa.array(range(2**7), type=typ) + expected = pa.dictionary(pa.int8(), pa.int64()) + assert a.type.equals(expected) + + a = pa.array(range(2**7 + 1), type=typ) + expected = pa.dictionary(pa.int16(), pa.int64()) + assert a.type.equals(expected) + + +def test_dictionary_from_strings(): + for value_type in [pa.binary(), pa.string()]: + typ = pa.dictionary(pa.int8(), value_type) + a = pa.array(["", "a", "bb", "a", "bb", "ccc"], type=typ) + + assert isinstance(a.type, pa.DictionaryType) + + expected_indices = pa.array([0, 1, 2, 1, 2, 3], type=pa.int8()) + expected_dictionary = pa.array(["", "a", "bb", "ccc"], type=value_type) + assert a.indices.equals(expected_indices) + assert a.dictionary.equals(expected_dictionary) + + # fixed size binary type + typ = pa.dictionary(pa.int8(), pa.binary(3)) + a = pa.array(["aaa", "aaa", "bbb", "ccc", "bbb"], type=typ) + assert isinstance(a.type, pa.DictionaryType) + + expected_indices = pa.array([0, 0, 1, 2, 1], type=pa.int8()) + expected_dictionary = pa.array(["aaa", "bbb", "ccc"], type=pa.binary(3)) + assert a.indices.equals(expected_indices) + assert a.dictionary.equals(expected_dictionary) + + +@pytest.mark.parametrize(('unit', 'expected'), [ + ('s', datetime.timedelta(seconds=-2147483000)), + ('ms', datetime.timedelta(milliseconds=-2147483000)), + ('us', datetime.timedelta(microseconds=-2147483000)), + ('ns', datetime.timedelta(microseconds=-2147483)) +]) +def test_duration_array_roundtrip_corner_cases(unit, expected): + # Corner case discovered by hypothesis: there were implicit conversions to + # unsigned values resulting wrong values with wrong signs. + ty = pa.duration(unit) + arr = pa.array([-2147483000], type=ty) + restored = pa.array(arr.to_pylist(), type=ty) + assert arr.equals(restored) + + expected_list = [expected] + if unit == 'ns': + # if pandas is available then a pandas Timedelta is returned + try: + import pandas as pd + except ImportError: + pass + else: + expected_list = [pd.Timedelta(-2147483000, unit='ns')] + + assert restored.to_pylist() == expected_list + + +@pytest.mark.pandas +def test_roundtrip_nanosecond_resolution_pandas_temporal_objects(): + # corner case discovered by hypothesis: preserving the nanoseconds on + # conversion from a list of Timedelta and Timestamp objects + import pandas as pd + + ty = pa.duration('ns') + arr = pa.array([9223371273709551616], type=ty) + data = arr.to_pylist() + assert isinstance(data[0], pd.Timedelta) + restored = pa.array(data, type=ty) + assert arr.equals(restored) + assert restored.to_pylist() == [ + pd.Timedelta(9223371273709551616, unit='ns') + ] + + ty = pa.timestamp('ns') + arr = pa.array([9223371273709551616], type=ty) + data = arr.to_pylist() + assert isinstance(data[0], pd.Timestamp) + restored = pa.array(data, type=ty) + assert arr.equals(restored) + assert restored.to_pylist() == [ + pd.Timestamp(9223371273709551616, unit='ns') + ] + + ty = pa.timestamp('ns', tz='US/Eastern') + value = 1604119893000000000 + arr = pa.array([value], type=ty) + data = arr.to_pylist() + assert isinstance(data[0], pd.Timestamp) + restored = pa.array(data, type=ty) + assert arr.equals(restored) + assert restored.to_pylist() == [ + pd.Timestamp(value, unit='ns').tz_localize( + "UTC").tz_convert('US/Eastern') + ] + + +@h.given(past.all_arrays) +def test_array_to_pylist_roundtrip(arr): + seq = arr.to_pylist() + restored = pa.array(seq, type=arr.type) + assert restored.equals(arr) + + +@pytest.mark.large_memory +def test_auto_chunking_binary_like(): + # single chunk + v1 = b'x' * 100000000 + v2 = b'x' * 147483646 + + # single chunk + one_chunk_data = [v1] * 20 + [b'', None, v2] + arr = pa.array(one_chunk_data, type=pa.binary()) + assert isinstance(arr, pa.Array) + assert len(arr) == 23 + assert arr[20].as_py() == b'' + assert arr[21].as_py() is None + assert arr[22].as_py() == v2 + + # two chunks + two_chunk_data = one_chunk_data + [b'two'] + arr = pa.array(two_chunk_data, type=pa.binary()) + assert isinstance(arr, pa.ChunkedArray) + assert arr.num_chunks == 2 + assert len(arr.chunk(0)) == 23 + assert len(arr.chunk(1)) == 1 + assert arr.chunk(0)[20].as_py() == b'' + assert arr.chunk(0)[21].as_py() is None + assert arr.chunk(0)[22].as_py() == v2 + assert arr.chunk(1).to_pylist() == [b'two'] + + # three chunks + three_chunk_data = one_chunk_data * 2 + [b'three', b'three'] + arr = pa.array(three_chunk_data, type=pa.binary()) + assert isinstance(arr, pa.ChunkedArray) + assert arr.num_chunks == 3 + assert len(arr.chunk(0)) == 23 + assert len(arr.chunk(1)) == 23 + assert len(arr.chunk(2)) == 2 + for i in range(2): + assert arr.chunk(i)[20].as_py() == b'' + assert arr.chunk(i)[21].as_py() is None + assert arr.chunk(i)[22].as_py() == v2 + assert arr.chunk(2).to_pylist() == [b'three', b'three'] + + +@pytest.mark.large_memory +def test_auto_chunking_list_of_binary(): + # ARROW-6281 + vals = [['x' * 1024]] * ((2 << 20) + 1) + arr = pa.array(vals) + assert isinstance(arr, pa.ChunkedArray) + assert arr.num_chunks == 2 + assert len(arr.chunk(0)) == 2**21 - 1 + assert len(arr.chunk(1)) == 2 + assert arr.chunk(1).to_pylist() == [['x' * 1024]] * 2 + + +@pytest.mark.large_memory +def test_auto_chunking_list_like(): + item = np.ones((2**28,), dtype='uint8') + data = [item] * (2**3 - 1) + arr = pa.array(data, type=pa.list_(pa.uint8())) + assert isinstance(arr, pa.Array) + assert len(arr) == 7 + + item = np.ones((2**28,), dtype='uint8') + data = [item] * 2**3 + arr = pa.array(data, type=pa.list_(pa.uint8())) + assert isinstance(arr, pa.ChunkedArray) + assert arr.num_chunks == 2 + assert len(arr.chunk(0)) == 7 + assert len(arr.chunk(1)) == 1 + chunk = arr.chunk(1) + scalar = chunk[0] + assert isinstance(scalar, pa.ListScalar) + expected = pa.array(item, type=pa.uint8()) + assert scalar.values == expected + + +@pytest.mark.slow +@pytest.mark.large_memory +def test_auto_chunking_map_type(): + # takes ~20 minutes locally + ty = pa.map_(pa.int8(), pa.int8()) + item = [(1, 1)] * 2**28 + data = [item] * 2**3 + arr = pa.array(data, type=ty) + assert isinstance(arr, pa.ChunkedArray) + assert len(arr.chunk(0)) == 7 + assert len(arr.chunk(1)) == 1 + + +@pytest.mark.large_memory +@pytest.mark.parametrize(('ty', 'char'), [ + (pa.string(), 'x'), + (pa.binary(), b'x'), +]) +def test_nested_auto_chunking(ty, char): + v1 = char * 100000000 + v2 = char * 147483646 + + struct_type = pa.struct([ + pa.field('bool', pa.bool_()), + pa.field('integer', pa.int64()), + pa.field('string-like', ty), + ]) + + data = [{'bool': True, 'integer': 1, 'string-like': v1}] * 20 + data.append({'bool': True, 'integer': 1, 'string-like': v2}) + arr = pa.array(data, type=struct_type) + assert isinstance(arr, pa.Array) + + data.append({'bool': True, 'integer': 1, 'string-like': char}) + arr = pa.array(data, type=struct_type) + assert isinstance(arr, pa.ChunkedArray) + assert arr.num_chunks == 2 + assert len(arr.chunk(0)) == 21 + assert len(arr.chunk(1)) == 1 + assert arr.chunk(1)[0].as_py() == { + 'bool': True, + 'integer': 1, + 'string-like': char + } + + +@pytest.mark.large_memory +def test_array_from_pylist_data_overflow(): + # Regression test for ARROW-12983 + # Data buffer overflow - should result in chunked array + items = [b'a' * 4096] * (2 ** 19) + arr = pa.array(items, type=pa.string()) + assert isinstance(arr, pa.ChunkedArray) + assert len(arr) == 2**19 + assert len(arr.chunks) > 1 + + mask = np.zeros(2**19, bool) + arr = pa.array(items, mask=mask, type=pa.string()) + assert isinstance(arr, pa.ChunkedArray) + assert len(arr) == 2**19 + assert len(arr.chunks) > 1 + + arr = pa.array(items, type=pa.binary()) + assert isinstance(arr, pa.ChunkedArray) + assert len(arr) == 2**19 + assert len(arr.chunks) > 1 + + +@pytest.mark.slow +@pytest.mark.large_memory +def test_array_from_pylist_offset_overflow(): + # Regression test for ARROW-12983 + # Offset buffer overflow - should result in chunked array + # Note this doesn't apply to primitive arrays + items = [b'a'] * (2 ** 31) + arr = pa.array(items, type=pa.string()) + assert isinstance(arr, pa.ChunkedArray) + assert len(arr) == 2**31 + assert len(arr.chunks) > 1 + + mask = np.zeros(2**31, bool) + arr = pa.array(items, mask=mask, type=pa.string()) + assert isinstance(arr, pa.ChunkedArray) + assert len(arr) == 2**31 + assert len(arr.chunks) > 1 + + arr = pa.array(items, type=pa.binary()) + assert isinstance(arr, pa.ChunkedArray) + assert len(arr) == 2**31 + assert len(arr.chunks) > 1 diff --git a/src/arrow/python/pyarrow/tests/test_csv.py b/src/arrow/python/pyarrow/tests/test_csv.py new file mode 100644 index 000000000..b6cca243b --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_csv.py @@ -0,0 +1,1824 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import abc +import bz2 +from datetime import date, datetime +from decimal import Decimal +import gc +import gzip +import io +import itertools +import os +import pickle +import shutil +import signal +import string +import tempfile +import threading +import time +import unittest +import weakref + +import pytest + +import numpy as np + +import pyarrow as pa +from pyarrow.csv import ( + open_csv, read_csv, ReadOptions, ParseOptions, ConvertOptions, ISO8601, + write_csv, WriteOptions, CSVWriter) +from pyarrow.tests import util + + +def generate_col_names(): + # 'a', 'b'... 'z', then 'aa', 'ab'... + letters = string.ascii_lowercase + yield from letters + for first in letters: + for second in letters: + yield first + second + + +def make_random_csv(num_cols=2, num_rows=10, linesep='\r\n', write_names=True): + arr = np.random.RandomState(42).randint(0, 1000, size=(num_cols, num_rows)) + csv = io.StringIO() + col_names = list(itertools.islice(generate_col_names(), num_cols)) + if write_names: + csv.write(",".join(col_names)) + csv.write(linesep) + for row in arr.T: + csv.write(",".join(map(str, row))) + csv.write(linesep) + csv = csv.getvalue().encode() + columns = [pa.array(a, type=pa.int64()) for a in arr] + expected = pa.Table.from_arrays(columns, col_names) + return csv, expected + + +def make_empty_csv(column_names): + csv = io.StringIO() + csv.write(",".join(column_names)) + csv.write("\n") + return csv.getvalue().encode() + + +def check_options_class(cls, **attr_values): + """ + Check setting and getting attributes of an *Options class. + """ + opts = cls() + + for name, values in attr_values.items(): + assert getattr(opts, name) == values[0], \ + "incorrect default value for " + name + for v in values: + setattr(opts, name, v) + assert getattr(opts, name) == v, "failed setting value" + + with pytest.raises(AttributeError): + opts.zzz_non_existent = True + + # Check constructor named arguments + non_defaults = {name: values[1] for name, values in attr_values.items()} + opts = cls(**non_defaults) + for name, value in non_defaults.items(): + assert getattr(opts, name) == value + + +# The various options classes need to be picklable for dataset +def check_options_class_pickling(cls, **attr_values): + opts = cls(**attr_values) + new_opts = pickle.loads(pickle.dumps(opts, + protocol=pickle.HIGHEST_PROTOCOL)) + for name, value in attr_values.items(): + assert getattr(new_opts, name) == value + + +def test_read_options(): + cls = ReadOptions + opts = cls() + + check_options_class(cls, use_threads=[True, False], + skip_rows=[0, 3], + column_names=[[], ["ab", "cd"]], + autogenerate_column_names=[False, True], + encoding=['utf8', 'utf16'], + skip_rows_after_names=[0, 27]) + + check_options_class_pickling(cls, use_threads=True, + skip_rows=3, + column_names=["ab", "cd"], + autogenerate_column_names=False, + encoding='utf16', + skip_rows_after_names=27) + + assert opts.block_size > 0 + opts.block_size = 12345 + assert opts.block_size == 12345 + + opts = cls(block_size=1234) + assert opts.block_size == 1234 + + opts.validate() + + match = "ReadOptions: block_size must be at least 1: 0" + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.block_size = 0 + opts.validate() + + match = "ReadOptions: skip_rows cannot be negative: -1" + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.skip_rows = -1 + opts.validate() + + match = "ReadOptions: skip_rows_after_names cannot be negative: -1" + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.skip_rows_after_names = -1 + opts.validate() + + match = "ReadOptions: autogenerate_column_names cannot be true when" \ + " column_names are provided" + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.autogenerate_column_names = True + opts.column_names = ('a', 'b') + opts.validate() + + +def test_parse_options(): + cls = ParseOptions + + check_options_class(cls, delimiter=[',', 'x'], + escape_char=[False, 'y'], + quote_char=['"', 'z', False], + double_quote=[True, False], + newlines_in_values=[False, True], + ignore_empty_lines=[True, False]) + + check_options_class_pickling(cls, delimiter='x', + escape_char='y', + quote_char=False, + double_quote=False, + newlines_in_values=True, + ignore_empty_lines=False) + + cls().validate() + opts = cls() + opts.delimiter = "\t" + opts.validate() + + match = "ParseOptions: delimiter cannot be \\\\r or \\\\n" + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.delimiter = "\n" + opts.validate() + + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.delimiter = "\r" + opts.validate() + + match = "ParseOptions: quote_char cannot be \\\\r or \\\\n" + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.quote_char = "\n" + opts.validate() + + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.quote_char = "\r" + opts.validate() + + match = "ParseOptions: escape_char cannot be \\\\r or \\\\n" + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.escape_char = "\n" + opts.validate() + + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.escape_char = "\r" + opts.validate() + + +def test_convert_options(): + cls = ConvertOptions + opts = cls() + + check_options_class( + cls, check_utf8=[True, False], + strings_can_be_null=[False, True], + quoted_strings_can_be_null=[True, False], + decimal_point=['.', ','], + include_columns=[[], ['def', 'abc']], + include_missing_columns=[False, True], + auto_dict_encode=[False, True], + timestamp_parsers=[[], [ISO8601, '%y-%m']]) + + check_options_class_pickling( + cls, check_utf8=False, + strings_can_be_null=True, + quoted_strings_can_be_null=False, + decimal_point=',', + include_columns=['def', 'abc'], + include_missing_columns=False, + auto_dict_encode=True, + timestamp_parsers=[ISO8601, '%y-%m']) + + with pytest.raises(ValueError): + opts.decimal_point = '..' + + assert opts.auto_dict_max_cardinality > 0 + opts.auto_dict_max_cardinality = 99999 + assert opts.auto_dict_max_cardinality == 99999 + + assert opts.column_types == {} + # Pass column_types as mapping + opts.column_types = {'b': pa.int16(), 'c': pa.float32()} + assert opts.column_types == {'b': pa.int16(), 'c': pa.float32()} + opts.column_types = {'v': 'int16', 'w': 'null'} + assert opts.column_types == {'v': pa.int16(), 'w': pa.null()} + # Pass column_types as schema + schema = pa.schema([('a', pa.int32()), ('b', pa.string())]) + opts.column_types = schema + assert opts.column_types == {'a': pa.int32(), 'b': pa.string()} + # Pass column_types as sequence + opts.column_types = [('x', pa.binary())] + assert opts.column_types == {'x': pa.binary()} + + with pytest.raises(TypeError, match='DataType expected'): + opts.column_types = {'a': None} + with pytest.raises(TypeError): + opts.column_types = 0 + + assert isinstance(opts.null_values, list) + assert '' in opts.null_values + assert 'N/A' in opts.null_values + opts.null_values = ['xxx', 'yyy'] + assert opts.null_values == ['xxx', 'yyy'] + + assert isinstance(opts.true_values, list) + opts.true_values = ['xxx', 'yyy'] + assert opts.true_values == ['xxx', 'yyy'] + + assert isinstance(opts.false_values, list) + opts.false_values = ['xxx', 'yyy'] + assert opts.false_values == ['xxx', 'yyy'] + + assert opts.timestamp_parsers == [] + opts.timestamp_parsers = [ISO8601] + assert opts.timestamp_parsers == [ISO8601] + + opts = cls(column_types={'a': pa.null()}, + null_values=['N', 'nn'], true_values=['T', 'tt'], + false_values=['F', 'ff'], auto_dict_max_cardinality=999, + timestamp_parsers=[ISO8601, '%Y-%m-%d']) + assert opts.column_types == {'a': pa.null()} + assert opts.null_values == ['N', 'nn'] + assert opts.false_values == ['F', 'ff'] + assert opts.true_values == ['T', 'tt'] + assert opts.auto_dict_max_cardinality == 999 + assert opts.timestamp_parsers == [ISO8601, '%Y-%m-%d'] + + +def test_write_options(): + cls = WriteOptions + opts = cls() + + check_options_class( + cls, include_header=[True, False]) + + assert opts.batch_size > 0 + opts.batch_size = 12345 + assert opts.batch_size == 12345 + + opts = cls(batch_size=9876) + assert opts.batch_size == 9876 + + opts.validate() + + match = "WriteOptions: batch_size must be at least 1: 0" + with pytest.raises(pa.ArrowInvalid, match=match): + opts = cls() + opts.batch_size = 0 + opts.validate() + + +class BaseTestCSV(abc.ABC): + """Common tests which are shared by streaming and non streaming readers""" + + @abc.abstractmethod + def read_bytes(self, b, **kwargs): + """ + :param b: bytes to be parsed + :param kwargs: arguments passed on to open the csv file + :return: b parsed as a single RecordBatch + """ + raise NotImplementedError + + @property + @abc.abstractmethod + def use_threads(self): + """Whether this test is multi-threaded""" + raise NotImplementedError + + @staticmethod + def check_names(table, names): + assert table.num_columns == len(names) + assert table.column_names == names + + def test_header_skip_rows(self): + rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n" + + opts = ReadOptions() + opts.skip_rows = 1 + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["ef", "gh"]) + assert table.to_pydict() == { + "ef": ["ij", "mn"], + "gh": ["kl", "op"], + } + + opts.skip_rows = 3 + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["mn", "op"]) + assert table.to_pydict() == { + "mn": [], + "op": [], + } + + opts.skip_rows = 4 + with pytest.raises(pa.ArrowInvalid): + # Not enough rows + table = self.read_bytes(rows, read_options=opts) + + # Can skip rows with a different number of columns + rows = b"abcd\n,,,,,\nij,kl\nmn,op\n" + opts.skip_rows = 2 + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["ij", "kl"]) + assert table.to_pydict() == { + "ij": ["mn"], + "kl": ["op"], + } + + # Can skip all rows exactly when columns are given + opts.skip_rows = 4 + opts.column_names = ['ij', 'kl'] + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["ij", "kl"]) + assert table.to_pydict() == { + "ij": [], + "kl": [], + } + + def test_skip_rows_after_names(self): + rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n" + + opts = ReadOptions() + opts.skip_rows_after_names = 1 + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["ab", "cd"]) + assert table.to_pydict() == { + "ab": ["ij", "mn"], + "cd": ["kl", "op"], + } + + # Can skip exact number of rows + opts.skip_rows_after_names = 3 + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["ab", "cd"]) + assert table.to_pydict() == { + "ab": [], + "cd": [], + } + + # Can skip beyond all rows + opts.skip_rows_after_names = 4 + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["ab", "cd"]) + assert table.to_pydict() == { + "ab": [], + "cd": [], + } + + # Can skip rows with a different number of columns + rows = b"abcd\n,,,,,\nij,kl\nmn,op\n" + opts.skip_rows_after_names = 2 + opts.column_names = ["f0", "f1"] + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["f0", "f1"]) + assert table.to_pydict() == { + "f0": ["ij", "mn"], + "f1": ["kl", "op"], + } + opts = ReadOptions() + + # Can skip rows with new lines in the value + rows = b'ab,cd\n"e\nf","g\n\nh"\n"ij","k\nl"\nmn,op' + opts.skip_rows_after_names = 2 + parse_opts = ParseOptions() + parse_opts.newlines_in_values = True + table = self.read_bytes(rows, read_options=opts, + parse_options=parse_opts) + self.check_names(table, ["ab", "cd"]) + assert table.to_pydict() == { + "ab": ["mn"], + "cd": ["op"], + } + + # Can skip rows when block ends in middle of quoted value + opts.skip_rows_after_names = 2 + opts.block_size = 26 + table = self.read_bytes(rows, read_options=opts, + parse_options=parse_opts) + self.check_names(table, ["ab", "cd"]) + assert table.to_pydict() == { + "ab": ["mn"], + "cd": ["op"], + } + opts = ReadOptions() + + # Can skip rows that are beyond the first block without lexer + rows, expected = make_random_csv(num_cols=5, num_rows=1000) + opts.skip_rows_after_names = 900 + opts.block_size = len(rows) / 11 + table = self.read_bytes(rows, read_options=opts) + assert table.schema == expected.schema + assert table.num_rows == 100 + table_dict = table.to_pydict() + for name, values in expected.to_pydict().items(): + assert values[900:] == table_dict[name] + + # Can skip rows that are beyond the first block with lexer + table = self.read_bytes(rows, read_options=opts, + parse_options=parse_opts) + assert table.schema == expected.schema + assert table.num_rows == 100 + table_dict = table.to_pydict() + for name, values in expected.to_pydict().items(): + assert values[900:] == table_dict[name] + + # Skip rows and skip rows after names + rows, expected = make_random_csv(num_cols=5, num_rows=200, + write_names=False) + opts = ReadOptions() + opts.skip_rows = 37 + opts.skip_rows_after_names = 41 + opts.column_names = expected.schema.names + table = self.read_bytes(rows, read_options=opts, + parse_options=parse_opts) + assert table.schema == expected.schema + assert (table.num_rows == + expected.num_rows - opts.skip_rows - + opts.skip_rows_after_names) + table_dict = table.to_pydict() + for name, values in expected.to_pydict().items(): + assert (values[opts.skip_rows + opts.skip_rows_after_names:] == + table_dict[name]) + + def test_row_number_offset_in_errors(self): + # Row numbers are only correctly counted in serial reads + def format_msg(msg_format, row, *args): + if self.use_threads: + row_info = "" + else: + row_info = "Row #{}: ".format(row) + return msg_format.format(row_info, *args) + + csv, _ = make_random_csv(4, 100, write_names=True) + + read_options = ReadOptions() + read_options.block_size = len(csv) / 3 + convert_options = ConvertOptions() + convert_options.column_types = {"a": pa.int32()} + + # Test without skip_rows and column names in the csv + csv_bad_columns = csv + b"1,2\r\n" + message_columns = format_msg("{}Expected 4 columns, got 2", 102) + with pytest.raises(pa.ArrowInvalid, match=message_columns): + self.read_bytes(csv_bad_columns, + read_options=read_options, + convert_options=convert_options) + + csv_bad_type = csv + b"a,b,c,d\r\n" + message_value = format_msg( + "In CSV column #0: {}" + "CSV conversion error to int32: invalid value 'a'", + 102, csv) + with pytest.raises(pa.ArrowInvalid, match=message_value): + self.read_bytes(csv_bad_type, + read_options=read_options, + convert_options=convert_options) + + long_row = (b"this is a long row" * 15) + b",3\r\n" + csv_bad_columns_long = csv + long_row + message_long = format_msg("{}Expected 4 columns, got 2: {} ...", 102, + long_row[0:96].decode("utf-8")) + with pytest.raises(pa.ArrowInvalid, match=message_long): + self.read_bytes(csv_bad_columns_long, + read_options=read_options, + convert_options=convert_options) + + # Test skipping rows after the names + read_options.skip_rows_after_names = 47 + + with pytest.raises(pa.ArrowInvalid, match=message_columns): + self.read_bytes(csv_bad_columns, + read_options=read_options, + convert_options=convert_options) + + with pytest.raises(pa.ArrowInvalid, match=message_value): + self.read_bytes(csv_bad_type, + read_options=read_options, + convert_options=convert_options) + + with pytest.raises(pa.ArrowInvalid, match=message_long): + self.read_bytes(csv_bad_columns_long, + read_options=read_options, + convert_options=convert_options) + + read_options.skip_rows_after_names = 0 + + # Test without skip_rows and column names not in the csv + csv, _ = make_random_csv(4, 100, write_names=False) + read_options.column_names = ["a", "b", "c", "d"] + csv_bad_columns = csv + b"1,2\r\n" + message_columns = format_msg("{}Expected 4 columns, got 2", 101) + with pytest.raises(pa.ArrowInvalid, match=message_columns): + self.read_bytes(csv_bad_columns, + read_options=read_options, + convert_options=convert_options) + + csv_bad_columns_long = csv + long_row + message_long = format_msg("{}Expected 4 columns, got 2: {} ...", 101, + long_row[0:96].decode("utf-8")) + with pytest.raises(pa.ArrowInvalid, match=message_long): + self.read_bytes(csv_bad_columns_long, + read_options=read_options, + convert_options=convert_options) + + csv_bad_type = csv + b"a,b,c,d\r\n" + message_value = format_msg( + "In CSV column #0: {}" + "CSV conversion error to int32: invalid value 'a'", + 101) + message_value = message_value.format(len(csv)) + with pytest.raises(pa.ArrowInvalid, match=message_value): + self.read_bytes(csv_bad_type, + read_options=read_options, + convert_options=convert_options) + + # Test with skip_rows and column names not in the csv + read_options.skip_rows = 23 + with pytest.raises(pa.ArrowInvalid, match=message_columns): + self.read_bytes(csv_bad_columns, + read_options=read_options, + convert_options=convert_options) + + with pytest.raises(pa.ArrowInvalid, match=message_value): + self.read_bytes(csv_bad_type, + read_options=read_options, + convert_options=convert_options) + + +class BaseCSVTableRead(BaseTestCSV): + + def read_csv(self, csv, *args, validate_full=True, **kwargs): + """ + Reads the CSV file into memory using pyarrow's read_csv + csv The CSV bytes + args Positional arguments to be forwarded to pyarrow's read_csv + validate_full Whether or not to fully validate the resulting table + kwargs Keyword arguments to be forwarded to pyarrow's read_csv + """ + assert isinstance(self.use_threads, bool) # sanity check + read_options = kwargs.setdefault('read_options', ReadOptions()) + read_options.use_threads = self.use_threads + table = read_csv(csv, *args, **kwargs) + table.validate(full=validate_full) + return table + + def read_bytes(self, b, **kwargs): + return self.read_csv(pa.py_buffer(b), **kwargs) + + def test_file_object(self): + data = b"a,b\n1,2\n" + expected_data = {'a': [1], 'b': [2]} + bio = io.BytesIO(data) + table = self.read_csv(bio) + assert table.to_pydict() == expected_data + # Text files not allowed + sio = io.StringIO(data.decode()) + with pytest.raises(TypeError): + self.read_csv(sio) + + def test_header(self): + rows = b"abc,def,gh\n" + table = self.read_bytes(rows) + assert isinstance(table, pa.Table) + self.check_names(table, ["abc", "def", "gh"]) + assert table.num_rows == 0 + + def test_bom(self): + rows = b"\xef\xbb\xbfa,b\n1,2\n" + expected_data = {'a': [1], 'b': [2]} + table = self.read_bytes(rows) + assert table.to_pydict() == expected_data + + def test_one_chunk(self): + # ARROW-7661: lack of newline at end of file should not produce + # an additional chunk. + rows = [b"a,b", b"1,2", b"3,4", b"56,78"] + for line_ending in [b'\n', b'\r', b'\r\n']: + for file_ending in [b'', line_ending]: + data = line_ending.join(rows) + file_ending + table = self.read_bytes(data) + assert len(table.to_batches()) == 1 + assert table.to_pydict() == { + "a": [1, 3, 56], + "b": [2, 4, 78], + } + + def test_header_column_names(self): + rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n" + + opts = ReadOptions() + opts.column_names = ["x", "y"] + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["x", "y"]) + assert table.to_pydict() == { + "x": ["ab", "ef", "ij", "mn"], + "y": ["cd", "gh", "kl", "op"], + } + + opts.skip_rows = 3 + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["x", "y"]) + assert table.to_pydict() == { + "x": ["mn"], + "y": ["op"], + } + + opts.skip_rows = 4 + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["x", "y"]) + assert table.to_pydict() == { + "x": [], + "y": [], + } + + opts.skip_rows = 5 + with pytest.raises(pa.ArrowInvalid): + # Not enough rows + table = self.read_bytes(rows, read_options=opts) + + # Unexpected number of columns + opts.skip_rows = 0 + opts.column_names = ["x", "y", "z"] + with pytest.raises(pa.ArrowInvalid, + match="Expected 3 columns, got 2"): + table = self.read_bytes(rows, read_options=opts) + + # Can skip rows with a different number of columns + rows = b"abcd\n,,,,,\nij,kl\nmn,op\n" + opts.skip_rows = 2 + opts.column_names = ["x", "y"] + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["x", "y"]) + assert table.to_pydict() == { + "x": ["ij", "mn"], + "y": ["kl", "op"], + } + + def test_header_autogenerate_column_names(self): + rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n" + + opts = ReadOptions() + opts.autogenerate_column_names = True + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["f0", "f1"]) + assert table.to_pydict() == { + "f0": ["ab", "ef", "ij", "mn"], + "f1": ["cd", "gh", "kl", "op"], + } + + opts.skip_rows = 3 + table = self.read_bytes(rows, read_options=opts) + self.check_names(table, ["f0", "f1"]) + assert table.to_pydict() == { + "f0": ["mn"], + "f1": ["op"], + } + + # Not enough rows, impossible to infer number of columns + opts.skip_rows = 4 + with pytest.raises(pa.ArrowInvalid): + table = self.read_bytes(rows, read_options=opts) + + def test_include_columns(self): + rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n" + + convert_options = ConvertOptions() + convert_options.include_columns = ['ab'] + table = self.read_bytes(rows, convert_options=convert_options) + self.check_names(table, ["ab"]) + assert table.to_pydict() == { + "ab": ["ef", "ij", "mn"], + } + + # Order of include_columns is respected, regardless of CSV order + convert_options.include_columns = ['cd', 'ab'] + table = self.read_bytes(rows, convert_options=convert_options) + schema = pa.schema([('cd', pa.string()), + ('ab', pa.string())]) + assert table.schema == schema + assert table.to_pydict() == { + "cd": ["gh", "kl", "op"], + "ab": ["ef", "ij", "mn"], + } + + # Include a column not in the CSV file => raises by default + convert_options.include_columns = ['xx', 'ab', 'yy'] + with pytest.raises(KeyError, + match="Column 'xx' in include_columns " + "does not exist in CSV file"): + self.read_bytes(rows, convert_options=convert_options) + + def test_include_missing_columns(self): + rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n" + + read_options = ReadOptions() + convert_options = ConvertOptions() + convert_options.include_columns = ['xx', 'ab', 'yy'] + convert_options.include_missing_columns = True + table = self.read_bytes(rows, read_options=read_options, + convert_options=convert_options) + schema = pa.schema([('xx', pa.null()), + ('ab', pa.string()), + ('yy', pa.null())]) + assert table.schema == schema + assert table.to_pydict() == { + "xx": [None, None, None], + "ab": ["ef", "ij", "mn"], + "yy": [None, None, None], + } + + # Combining with `column_names` + read_options.column_names = ["xx", "yy"] + convert_options.include_columns = ["yy", "cd"] + table = self.read_bytes(rows, read_options=read_options, + convert_options=convert_options) + schema = pa.schema([('yy', pa.string()), + ('cd', pa.null())]) + assert table.schema == schema + assert table.to_pydict() == { + "yy": ["cd", "gh", "kl", "op"], + "cd": [None, None, None, None], + } + + # And with `column_types` as well + convert_options.column_types = {"yy": pa.binary(), + "cd": pa.int32()} + table = self.read_bytes(rows, read_options=read_options, + convert_options=convert_options) + schema = pa.schema([('yy', pa.binary()), + ('cd', pa.int32())]) + assert table.schema == schema + assert table.to_pydict() == { + "yy": [b"cd", b"gh", b"kl", b"op"], + "cd": [None, None, None, None], + } + + def test_simple_ints(self): + # Infer integer columns + rows = b"a,b,c\n1,2,3\n4,5,6\n" + table = self.read_bytes(rows) + schema = pa.schema([('a', pa.int64()), + ('b', pa.int64()), + ('c', pa.int64())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [1, 4], + 'b': [2, 5], + 'c': [3, 6], + } + + def test_simple_varied(self): + # Infer various kinds of data + rows = b"a,b,c,d\n1,2,3,0\n4.0,-5,foo,True\n" + table = self.read_bytes(rows) + schema = pa.schema([('a', pa.float64()), + ('b', pa.int64()), + ('c', pa.string()), + ('d', pa.bool_())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [1.0, 4.0], + 'b': [2, -5], + 'c': ["3", "foo"], + 'd': [False, True], + } + + def test_simple_nulls(self): + # Infer various kinds of data, with nulls + rows = (b"a,b,c,d,e,f\n" + b"1,2,,,3,N/A\n" + b"nan,-5,foo,,nan,TRUE\n" + b"4.5,#N/A,nan,,\xff,false\n") + table = self.read_bytes(rows) + schema = pa.schema([('a', pa.float64()), + ('b', pa.int64()), + ('c', pa.string()), + ('d', pa.null()), + ('e', pa.binary()), + ('f', pa.bool_())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [1.0, None, 4.5], + 'b': [2, -5, None], + 'c': ["", "foo", "nan"], + 'd': [None, None, None], + 'e': [b"3", b"nan", b"\xff"], + 'f': [None, True, False], + } + + def test_decimal_point(self): + # Infer floats with a custom decimal point + parse_options = ParseOptions(delimiter=';') + rows = b"a;b\n1.25;2,5\nNA;-3\n-4;NA" + + table = self.read_bytes(rows, parse_options=parse_options) + schema = pa.schema([('a', pa.float64()), + ('b', pa.string())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [1.25, None, -4.0], + 'b': ["2,5", "-3", "NA"], + } + + convert_options = ConvertOptions(decimal_point=',') + table = self.read_bytes(rows, parse_options=parse_options, + convert_options=convert_options) + schema = pa.schema([('a', pa.string()), + ('b', pa.float64())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': ["1.25", "NA", "-4"], + 'b': [2.5, -3.0, None], + } + + def test_simple_timestamps(self): + # Infer a timestamp column + rows = (b"a,b,c\n" + b"1970,1970-01-01 00:00:00,1970-01-01 00:00:00.123\n" + b"1989,1989-07-14 01:00:00,1989-07-14 01:00:00.123456\n") + table = self.read_bytes(rows) + schema = pa.schema([('a', pa.int64()), + ('b', pa.timestamp('s')), + ('c', pa.timestamp('ns'))]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [1970, 1989], + 'b': [datetime(1970, 1, 1), datetime(1989, 7, 14, 1)], + 'c': [datetime(1970, 1, 1, 0, 0, 0, 123000), + datetime(1989, 7, 14, 1, 0, 0, 123456)], + } + + def test_timestamp_parsers(self): + # Infer timestamps with custom parsers + rows = b"a,b\n1970/01/01,1980-01-01 00\n1970/01/02,1980-01-02 00\n" + opts = ConvertOptions() + + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.string()), + ('b', pa.timestamp('s'))]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': ['1970/01/01', '1970/01/02'], + 'b': [datetime(1980, 1, 1), datetime(1980, 1, 2)], + } + + opts.timestamp_parsers = ['%Y/%m/%d'] + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.timestamp('s')), + ('b', pa.string())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [datetime(1970, 1, 1), datetime(1970, 1, 2)], + 'b': ['1980-01-01 00', '1980-01-02 00'], + } + + opts.timestamp_parsers = ['%Y/%m/%d', ISO8601] + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.timestamp('s')), + ('b', pa.timestamp('s'))]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [datetime(1970, 1, 1), datetime(1970, 1, 2)], + 'b': [datetime(1980, 1, 1), datetime(1980, 1, 2)], + } + + def test_dates(self): + # Dates are inferred as date32 by default + rows = b"a,b\n1970-01-01,1970-01-02\n1971-01-01,1971-01-02\n" + table = self.read_bytes(rows) + schema = pa.schema([('a', pa.date32()), + ('b', pa.date32())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [date(1970, 1, 1), date(1971, 1, 1)], + 'b': [date(1970, 1, 2), date(1971, 1, 2)], + } + + # Can ask for date types explicitly + opts = ConvertOptions() + opts.column_types = {'a': pa.date32(), 'b': pa.date64()} + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.date32()), + ('b', pa.date64())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [date(1970, 1, 1), date(1971, 1, 1)], + 'b': [date(1970, 1, 2), date(1971, 1, 2)], + } + + # Can ask for timestamp types explicitly + opts = ConvertOptions() + opts.column_types = {'a': pa.timestamp('s'), 'b': pa.timestamp('ms')} + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.timestamp('s')), + ('b', pa.timestamp('ms'))]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [datetime(1970, 1, 1), datetime(1971, 1, 1)], + 'b': [datetime(1970, 1, 2), datetime(1971, 1, 2)], + } + + def test_times(self): + # Times are inferred as time32[s] by default + from datetime import time + + rows = b"a,b\n12:34:56,12:34:56.789\n23:59:59,23:59:59.999\n" + table = self.read_bytes(rows) + # Column 'b' has subseconds, so cannot be inferred as time32[s] + schema = pa.schema([('a', pa.time32('s')), + ('b', pa.string())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [time(12, 34, 56), time(23, 59, 59)], + 'b': ["12:34:56.789", "23:59:59.999"], + } + + # Can ask for time types explicitly + opts = ConvertOptions() + opts.column_types = {'a': pa.time64('us'), 'b': pa.time32('ms')} + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.time64('us')), + ('b', pa.time32('ms'))]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [time(12, 34, 56), time(23, 59, 59)], + 'b': [time(12, 34, 56, 789000), time(23, 59, 59, 999000)], + } + + def test_auto_dict_encode(self): + opts = ConvertOptions(auto_dict_encode=True) + rows = "a,b\nab,1\ncdé,2\ncdé,3\nab,4".encode() + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.dictionary(pa.int32(), pa.string())), + ('b', pa.int64())]) + expected = { + 'a': ["ab", "cdé", "cdé", "ab"], + 'b': [1, 2, 3, 4], + } + assert table.schema == schema + assert table.to_pydict() == expected + + opts.auto_dict_max_cardinality = 2 + table = self.read_bytes(rows, convert_options=opts) + assert table.schema == schema + assert table.to_pydict() == expected + + # Cardinality above max => plain-encoded + opts.auto_dict_max_cardinality = 1 + table = self.read_bytes(rows, convert_options=opts) + assert table.schema == pa.schema([('a', pa.string()), + ('b', pa.int64())]) + assert table.to_pydict() == expected + + # With invalid UTF8, not checked + opts.auto_dict_max_cardinality = 50 + opts.check_utf8 = False + rows = b"a,b\nab,1\ncd\xff,2\nab,3" + table = self.read_bytes(rows, convert_options=opts, + validate_full=False) + assert table.schema == schema + dict_values = table['a'].chunk(0).dictionary + assert len(dict_values) == 2 + assert dict_values[0].as_py() == "ab" + assert dict_values[1].as_buffer() == b"cd\xff" + + # With invalid UTF8, checked + opts.check_utf8 = True + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.dictionary(pa.int32(), pa.binary())), + ('b', pa.int64())]) + expected = { + 'a': [b"ab", b"cd\xff", b"ab"], + 'b': [1, 2, 3], + } + assert table.schema == schema + assert table.to_pydict() == expected + + def test_custom_nulls(self): + # Infer nulls with custom values + opts = ConvertOptions(null_values=['Xxx', 'Zzz']) + rows = b"""a,b,c,d\nZzz,"Xxx",1,2\nXxx,#N/A,,Zzz\n""" + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.null()), + ('b', pa.string()), + ('c', pa.string()), + ('d', pa.int64())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [None, None], + 'b': ["Xxx", "#N/A"], + 'c': ["1", ""], + 'd': [2, None], + } + + opts = ConvertOptions(null_values=['Xxx', 'Zzz'], + strings_can_be_null=True) + table = self.read_bytes(rows, convert_options=opts) + assert table.to_pydict() == { + 'a': [None, None], + 'b': [None, "#N/A"], + 'c': ["1", ""], + 'd': [2, None], + } + opts.quoted_strings_can_be_null = False + table = self.read_bytes(rows, convert_options=opts) + assert table.to_pydict() == { + 'a': [None, None], + 'b': ["Xxx", "#N/A"], + 'c': ["1", ""], + 'd': [2, None], + } + + opts = ConvertOptions(null_values=[]) + rows = b"a,b\n#N/A,\n" + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.string()), + ('b', pa.string())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': ["#N/A"], + 'b': [""], + } + + def test_custom_bools(self): + # Infer booleans with custom values + opts = ConvertOptions(true_values=['T', 'yes'], + false_values=['F', 'no']) + rows = (b"a,b,c\n" + b"True,T,t\n" + b"False,F,f\n" + b"True,yes,yes\n" + b"False,no,no\n" + b"N/A,N/A,N/A\n") + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.string()), + ('b', pa.bool_()), + ('c', pa.string())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': ["True", "False", "True", "False", "N/A"], + 'b': [True, False, True, False, None], + 'c': ["t", "f", "yes", "no", "N/A"], + } + + def test_column_types(self): + # Ask for specific column types in ConvertOptions + opts = ConvertOptions(column_types={'b': 'float32', + 'c': 'string', + 'd': 'boolean', + 'e': pa.decimal128(11, 2), + 'zz': 'null'}) + rows = b"a,b,c,d,e\n1,2,3,true,1.0\n4,-5,6,false,0\n" + table = self.read_bytes(rows, convert_options=opts) + schema = pa.schema([('a', pa.int64()), + ('b', pa.float32()), + ('c', pa.string()), + ('d', pa.bool_()), + ('e', pa.decimal128(11, 2))]) + expected = { + 'a': [1, 4], + 'b': [2.0, -5.0], + 'c': ["3", "6"], + 'd': [True, False], + 'e': [Decimal("1.00"), Decimal("0.00")] + } + assert table.schema == schema + assert table.to_pydict() == expected + # Pass column_types as schema + opts = ConvertOptions( + column_types=pa.schema([('b', pa.float32()), + ('c', pa.string()), + ('d', pa.bool_()), + ('e', pa.decimal128(11, 2)), + ('zz', pa.bool_())])) + table = self.read_bytes(rows, convert_options=opts) + assert table.schema == schema + assert table.to_pydict() == expected + # One of the columns in column_types fails converting + rows = b"a,b,c,d,e\n1,XXX,3,true,5\n4,-5,6,false,7\n" + with pytest.raises(pa.ArrowInvalid) as exc: + self.read_bytes(rows, convert_options=opts) + err = str(exc.value) + assert "In CSV column #1: " in err + assert "CSV conversion error to float: invalid value 'XXX'" in err + + def test_column_types_dict(self): + # Ask for dict-encoded column types in ConvertOptions + column_types = [ + ('a', pa.dictionary(pa.int32(), pa.utf8())), + ('b', pa.dictionary(pa.int32(), pa.int64())), + ('c', pa.dictionary(pa.int32(), pa.decimal128(11, 2))), + ('d', pa.dictionary(pa.int32(), pa.large_utf8()))] + + opts = ConvertOptions(column_types=dict(column_types)) + rows = (b"a,b,c,d\n" + b"abc,123456,1.0,zz\n" + b"defg,123456,0.5,xx\n" + b"abc,N/A,1.0,xx\n") + table = self.read_bytes(rows, convert_options=opts) + + schema = pa.schema(column_types) + expected = { + 'a': ["abc", "defg", "abc"], + 'b': [123456, 123456, None], + 'c': [Decimal("1.00"), Decimal("0.50"), Decimal("1.00")], + 'd': ["zz", "xx", "xx"], + } + assert table.schema == schema + assert table.to_pydict() == expected + + # Unsupported index type + column_types[0] = ('a', pa.dictionary(pa.int8(), pa.utf8())) + + opts = ConvertOptions(column_types=dict(column_types)) + with pytest.raises(NotImplementedError): + table = self.read_bytes(rows, convert_options=opts) + + def test_column_types_with_column_names(self): + # When both `column_names` and `column_types` are given, names + # in `column_types` should refer to names in `column_names` + rows = b"a,b\nc,d\ne,f\n" + read_options = ReadOptions(column_names=['x', 'y']) + convert_options = ConvertOptions(column_types={'x': pa.binary()}) + table = self.read_bytes(rows, read_options=read_options, + convert_options=convert_options) + schema = pa.schema([('x', pa.binary()), + ('y', pa.string())]) + assert table.schema == schema + assert table.to_pydict() == { + 'x': [b'a', b'c', b'e'], + 'y': ['b', 'd', 'f'], + } + + def test_no_ending_newline(self): + # No \n after last line + rows = b"a,b,c\n1,2,3\n4,5,6" + table = self.read_bytes(rows) + assert table.to_pydict() == { + 'a': [1, 4], + 'b': [2, 5], + 'c': [3, 6], + } + + def test_trivial(self): + # A bit pointless, but at least it shouldn't crash + rows = b",\n\n" + table = self.read_bytes(rows) + assert table.to_pydict() == {'': []} + + def test_empty_lines(self): + rows = b"a,b\n\r1,2\r\n\r\n3,4\r\n" + table = self.read_bytes(rows) + assert table.to_pydict() == { + 'a': [1, 3], + 'b': [2, 4], + } + parse_options = ParseOptions(ignore_empty_lines=False) + table = self.read_bytes(rows, parse_options=parse_options) + assert table.to_pydict() == { + 'a': [None, 1, None, 3], + 'b': [None, 2, None, 4], + } + read_options = ReadOptions(skip_rows=2) + table = self.read_bytes(rows, parse_options=parse_options, + read_options=read_options) + assert table.to_pydict() == { + '1': [None, 3], + '2': [None, 4], + } + + def test_invalid_csv(self): + # Various CSV errors + rows = b"a,b,c\n1,2\n4,5,6\n" + with pytest.raises(pa.ArrowInvalid, match="Expected 3 columns, got 2"): + self.read_bytes(rows) + rows = b"a,b,c\n1,2,3\n4" + with pytest.raises(pa.ArrowInvalid, match="Expected 3 columns, got 1"): + self.read_bytes(rows) + for rows in [b"", b"\n", b"\r\n", b"\r", b"\n\n"]: + with pytest.raises(pa.ArrowInvalid, match="Empty CSV file"): + self.read_bytes(rows) + + def test_options_delimiter(self): + rows = b"a;b,c\nde,fg;eh\n" + table = self.read_bytes(rows) + assert table.to_pydict() == { + 'a;b': ['de'], + 'c': ['fg;eh'], + } + opts = ParseOptions(delimiter=';') + table = self.read_bytes(rows, parse_options=opts) + assert table.to_pydict() == { + 'a': ['de,fg'], + 'b,c': ['eh'], + } + + def test_small_random_csv(self): + csv, expected = make_random_csv(num_cols=2, num_rows=10) + table = self.read_bytes(csv) + assert table.schema == expected.schema + assert table.equals(expected) + assert table.to_pydict() == expected.to_pydict() + + def test_stress_block_sizes(self): + # Test a number of small block sizes to stress block stitching + csv_base, expected = make_random_csv(num_cols=2, num_rows=500) + block_sizes = [11, 12, 13, 17, 37, 111] + csvs = [csv_base, csv_base.rstrip(b'\r\n')] + for csv in csvs: + for block_size in block_sizes: + read_options = ReadOptions(block_size=block_size) + table = self.read_bytes(csv, read_options=read_options) + assert table.schema == expected.schema + if not table.equals(expected): + # Better error output + assert table.to_pydict() == expected.to_pydict() + + def test_stress_convert_options_blowup(self): + # ARROW-6481: A convert_options with a very large number of columns + # should not blow memory and CPU time. + try: + clock = time.thread_time + except AttributeError: + clock = time.time + num_columns = 10000 + col_names = ["K{}".format(i) for i in range(num_columns)] + csv = make_empty_csv(col_names) + t1 = clock() + convert_options = ConvertOptions( + column_types={k: pa.string() for k in col_names[::2]}) + table = self.read_bytes(csv, convert_options=convert_options) + dt = clock() - t1 + # Check that processing time didn't blow up. + # This is a conservative check (it takes less than 300 ms + # in debug mode on my local machine). + assert dt <= 10.0 + # Check result + assert table.num_columns == num_columns + assert table.num_rows == 0 + assert table.column_names == col_names + + def test_cancellation(self): + if (threading.current_thread().ident != + threading.main_thread().ident): + pytest.skip("test only works from main Python thread") + # Skips test if not available + raise_signal = util.get_raise_signal() + + # Make the interruptible workload large enough to not finish + # before the interrupt comes, even in release mode on fast machines. + last_duration = 0.0 + workload_size = 100_000 + + while last_duration < 1.0: + print("workload size:", workload_size) + large_csv = b"a,b,c\n" + b"1,2,3\n" * workload_size + t1 = time.time() + self.read_bytes(large_csv) + last_duration = time.time() - t1 + workload_size = workload_size * 3 + + def signal_from_thread(): + time.sleep(0.2) + raise_signal(signal.SIGINT) + + t1 = time.time() + try: + try: + t = threading.Thread(target=signal_from_thread) + with pytest.raises(KeyboardInterrupt) as exc_info: + t.start() + self.read_bytes(large_csv) + finally: + t.join() + except KeyboardInterrupt: + # In case KeyboardInterrupt didn't interrupt `self.read_bytes` + # above, at least prevent it from stopping the test suite + pytest.fail("KeyboardInterrupt didn't interrupt CSV reading") + dt = time.time() - t1 + # Interruption should have arrived timely + assert dt <= 1.0 + e = exc_info.value.__context__ + assert isinstance(e, pa.ArrowCancelled) + assert e.signum == signal.SIGINT + + def test_cancellation_disabled(self): + # ARROW-12622: reader would segfault when the cancelling signal + # handler was not enabled (e.g. if disabled, or if not on the + # main thread) + t = threading.Thread( + target=lambda: self.read_bytes(b"f64\n0.1")) + t.start() + t.join() + + +class TestSerialCSVTableRead(BaseCSVTableRead): + @property + def use_threads(self): + return False + + +class TestThreadedCSVTableRead(BaseCSVTableRead): + @property + def use_threads(self): + return True + + +class BaseStreamingCSVRead(BaseTestCSV): + + def open_csv(self, csv, *args, **kwargs): + """ + Reads the CSV file into memory using pyarrow's open_csv + csv The CSV bytes + args Positional arguments to be forwarded to pyarrow's open_csv + kwargs Keyword arguments to be forwarded to pyarrow's open_csv + """ + read_options = kwargs.setdefault('read_options', ReadOptions()) + read_options.use_threads = self.use_threads + return open_csv(csv, *args, **kwargs) + + def open_bytes(self, b, **kwargs): + return self.open_csv(pa.py_buffer(b), **kwargs) + + def check_reader(self, reader, expected_schema, expected_data): + assert reader.schema == expected_schema + batches = list(reader) + assert len(batches) == len(expected_data) + for batch, expected_batch in zip(batches, expected_data): + batch.validate(full=True) + assert batch.schema == expected_schema + assert batch.to_pydict() == expected_batch + + def read_bytes(self, b, **kwargs): + return self.open_bytes(b, **kwargs).read_all() + + def test_file_object(self): + data = b"a,b\n1,2\n3,4\n" + expected_data = {'a': [1, 3], 'b': [2, 4]} + bio = io.BytesIO(data) + reader = self.open_csv(bio) + expected_schema = pa.schema([('a', pa.int64()), + ('b', pa.int64())]) + self.check_reader(reader, expected_schema, [expected_data]) + + def test_header(self): + rows = b"abc,def,gh\n" + reader = self.open_bytes(rows) + expected_schema = pa.schema([('abc', pa.null()), + ('def', pa.null()), + ('gh', pa.null())]) + self.check_reader(reader, expected_schema, []) + + def test_inference(self): + # Inference is done on first block + rows = b"a,b\n123,456\nabc,de\xff\ngh,ij\n" + expected_schema = pa.schema([('a', pa.string()), + ('b', pa.binary())]) + + read_options = ReadOptions() + read_options.block_size = len(rows) + reader = self.open_bytes(rows, read_options=read_options) + self.check_reader(reader, expected_schema, + [{'a': ['123', 'abc', 'gh'], + 'b': [b'456', b'de\xff', b'ij']}]) + + read_options.block_size = len(rows) - 1 + reader = self.open_bytes(rows, read_options=read_options) + self.check_reader(reader, expected_schema, + [{'a': ['123', 'abc'], + 'b': [b'456', b'de\xff']}, + {'a': ['gh'], + 'b': [b'ij']}]) + + def test_inference_failure(self): + # Inference on first block, then conversion failure on second block + rows = b"a,b\n123,456\nabc,de\xff\ngh,ij\n" + read_options = ReadOptions() + read_options.block_size = len(rows) - 7 + reader = self.open_bytes(rows, read_options=read_options) + expected_schema = pa.schema([('a', pa.int64()), + ('b', pa.int64())]) + assert reader.schema == expected_schema + assert reader.read_next_batch().to_pydict() == { + 'a': [123], 'b': [456] + } + # Second block + with pytest.raises(ValueError, + match="CSV conversion error to int64"): + reader.read_next_batch() + # EOF + with pytest.raises(StopIteration): + reader.read_next_batch() + + def test_invalid_csv(self): + # CSV errors on first block + rows = b"a,b\n1,2,3\n4,5\n6,7\n" + read_options = ReadOptions() + read_options.block_size = 10 + with pytest.raises(pa.ArrowInvalid, + match="Expected 2 columns, got 3"): + reader = self.open_bytes( + rows, read_options=read_options) + + # CSV errors on second block + rows = b"a,b\n1,2\n3,4,5\n6,7\n" + read_options.block_size = 8 + reader = self.open_bytes(rows, read_options=read_options) + assert reader.read_next_batch().to_pydict() == {'a': [1], 'b': [2]} + with pytest.raises(pa.ArrowInvalid, + match="Expected 2 columns, got 3"): + reader.read_next_batch() + # Cannot continue after a parse error + with pytest.raises(StopIteration): + reader.read_next_batch() + + def test_options_delimiter(self): + rows = b"a;b,c\nde,fg;eh\n" + reader = self.open_bytes(rows) + expected_schema = pa.schema([('a;b', pa.string()), + ('c', pa.string())]) + self.check_reader(reader, expected_schema, + [{'a;b': ['de'], + 'c': ['fg;eh']}]) + + opts = ParseOptions(delimiter=';') + reader = self.open_bytes(rows, parse_options=opts) + expected_schema = pa.schema([('a', pa.string()), + ('b,c', pa.string())]) + self.check_reader(reader, expected_schema, + [{'a': ['de,fg'], + 'b,c': ['eh']}]) + + def test_no_ending_newline(self): + # No \n after last line + rows = b"a,b,c\n1,2,3\n4,5,6" + reader = self.open_bytes(rows) + expected_schema = pa.schema([('a', pa.int64()), + ('b', pa.int64()), + ('c', pa.int64())]) + self.check_reader(reader, expected_schema, + [{'a': [1, 4], + 'b': [2, 5], + 'c': [3, 6]}]) + + def test_empty_file(self): + with pytest.raises(ValueError, match="Empty CSV file"): + self.open_bytes(b"") + + def test_column_options(self): + # With column_names + rows = b"1,2,3\n4,5,6" + read_options = ReadOptions() + read_options.column_names = ['d', 'e', 'f'] + reader = self.open_bytes(rows, read_options=read_options) + expected_schema = pa.schema([('d', pa.int64()), + ('e', pa.int64()), + ('f', pa.int64())]) + self.check_reader(reader, expected_schema, + [{'d': [1, 4], + 'e': [2, 5], + 'f': [3, 6]}]) + + # With include_columns + convert_options = ConvertOptions() + convert_options.include_columns = ['f', 'e'] + reader = self.open_bytes(rows, read_options=read_options, + convert_options=convert_options) + expected_schema = pa.schema([('f', pa.int64()), + ('e', pa.int64())]) + self.check_reader(reader, expected_schema, + [{'e': [2, 5], + 'f': [3, 6]}]) + + # With column_types + convert_options.column_types = {'e': pa.string()} + reader = self.open_bytes(rows, read_options=read_options, + convert_options=convert_options) + expected_schema = pa.schema([('f', pa.int64()), + ('e', pa.string())]) + self.check_reader(reader, expected_schema, + [{'e': ["2", "5"], + 'f': [3, 6]}]) + + # Missing columns in include_columns + convert_options.include_columns = ['g', 'f', 'e'] + with pytest.raises( + KeyError, + match="Column 'g' in include_columns does not exist"): + reader = self.open_bytes(rows, read_options=read_options, + convert_options=convert_options) + + convert_options.include_missing_columns = True + reader = self.open_bytes(rows, read_options=read_options, + convert_options=convert_options) + expected_schema = pa.schema([('g', pa.null()), + ('f', pa.int64()), + ('e', pa.string())]) + self.check_reader(reader, expected_schema, + [{'g': [None, None], + 'e': ["2", "5"], + 'f': [3, 6]}]) + + convert_options.column_types = {'e': pa.string(), 'g': pa.float64()} + reader = self.open_bytes(rows, read_options=read_options, + convert_options=convert_options) + expected_schema = pa.schema([('g', pa.float64()), + ('f', pa.int64()), + ('e', pa.string())]) + self.check_reader(reader, expected_schema, + [{'g': [None, None], + 'e': ["2", "5"], + 'f': [3, 6]}]) + + def test_encoding(self): + # latin-1 (invalid utf-8) + rows = b"a,b\nun,\xe9l\xe9phant" + read_options = ReadOptions() + reader = self.open_bytes(rows, read_options=read_options) + expected_schema = pa.schema([('a', pa.string()), + ('b', pa.binary())]) + self.check_reader(reader, expected_schema, + [{'a': ["un"], + 'b': [b"\xe9l\xe9phant"]}]) + + read_options.encoding = 'latin1' + reader = self.open_bytes(rows, read_options=read_options) + expected_schema = pa.schema([('a', pa.string()), + ('b', pa.string())]) + self.check_reader(reader, expected_schema, + [{'a': ["un"], + 'b': ["éléphant"]}]) + + # utf-16 + rows = (b'\xff\xfea\x00,\x00b\x00\n\x00u\x00n\x00,' + b'\x00\xe9\x00l\x00\xe9\x00p\x00h\x00a\x00n\x00t\x00') + read_options.encoding = 'utf16' + reader = self.open_bytes(rows, read_options=read_options) + expected_schema = pa.schema([('a', pa.string()), + ('b', pa.string())]) + self.check_reader(reader, expected_schema, + [{'a': ["un"], + 'b': ["éléphant"]}]) + + def test_small_random_csv(self): + csv, expected = make_random_csv(num_cols=2, num_rows=10) + reader = self.open_bytes(csv) + table = reader.read_all() + assert table.schema == expected.schema + assert table.equals(expected) + assert table.to_pydict() == expected.to_pydict() + + def test_stress_block_sizes(self): + # Test a number of small block sizes to stress block stitching + csv_base, expected = make_random_csv(num_cols=2, num_rows=500) + block_sizes = [19, 21, 23, 26, 37, 111] + csvs = [csv_base, csv_base.rstrip(b'\r\n')] + for csv in csvs: + for block_size in block_sizes: + # Need at least two lines for type inference + assert csv[:block_size].count(b'\n') >= 2 + read_options = ReadOptions(block_size=block_size) + reader = self.open_bytes( + csv, read_options=read_options) + table = reader.read_all() + assert table.schema == expected.schema + if not table.equals(expected): + # Better error output + assert table.to_pydict() == expected.to_pydict() + + def test_batch_lifetime(self): + gc.collect() + old_allocated = pa.total_allocated_bytes() + + # Memory occupation should not grow with CSV file size + def check_one_batch(reader, expected): + batch = reader.read_next_batch() + assert batch.to_pydict() == expected + + rows = b"10,11\n12,13\n14,15\n16,17\n" + read_options = ReadOptions() + read_options.column_names = ['a', 'b'] + read_options.block_size = 6 + reader = self.open_bytes(rows, read_options=read_options) + check_one_batch(reader, {'a': [10], 'b': [11]}) + allocated_after_first_batch = pa.total_allocated_bytes() + check_one_batch(reader, {'a': [12], 'b': [13]}) + assert pa.total_allocated_bytes() <= allocated_after_first_batch + check_one_batch(reader, {'a': [14], 'b': [15]}) + assert pa.total_allocated_bytes() <= allocated_after_first_batch + check_one_batch(reader, {'a': [16], 'b': [17]}) + assert pa.total_allocated_bytes() <= allocated_after_first_batch + with pytest.raises(StopIteration): + reader.read_next_batch() + assert pa.total_allocated_bytes() == old_allocated + reader = None + assert pa.total_allocated_bytes() == old_allocated + + def test_header_skip_rows(self): + super().test_header_skip_rows() + + rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n" + + # Skipping all rows immediately results in end of iteration + opts = ReadOptions() + opts.skip_rows = 4 + opts.column_names = ['ab', 'cd'] + reader = self.open_bytes(rows, read_options=opts) + with pytest.raises(StopIteration): + assert reader.read_next_batch() + + def test_skip_rows_after_names(self): + super().test_skip_rows_after_names() + + rows = b"ab,cd\nef,gh\nij,kl\nmn,op\n" + + # Skipping all rows immediately results in end of iteration + opts = ReadOptions() + opts.skip_rows_after_names = 3 + reader = self.open_bytes(rows, read_options=opts) + with pytest.raises(StopIteration): + assert reader.read_next_batch() + + # Skipping beyond all rows immediately results in end of iteration + opts.skip_rows_after_names = 99999 + reader = self.open_bytes(rows, read_options=opts) + with pytest.raises(StopIteration): + assert reader.read_next_batch() + + +class TestSerialStreamingCSVRead(BaseStreamingCSVRead, unittest.TestCase): + @property + def use_threads(self): + return False + + +class TestThreadedStreamingCSVRead(BaseStreamingCSVRead, unittest.TestCase): + @property + def use_threads(self): + return True + + +class BaseTestCompressedCSVRead: + + def setUp(self): + self.tmpdir = tempfile.mkdtemp(prefix='arrow-csv-test-') + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def read_csv(self, csv_path): + try: + return read_csv(csv_path) + except pa.ArrowNotImplementedError as e: + pytest.skip(str(e)) + + def test_random_csv(self): + csv, expected = make_random_csv(num_cols=2, num_rows=100) + csv_path = os.path.join(self.tmpdir, self.csv_filename) + self.write_file(csv_path, csv) + table = self.read_csv(csv_path) + table.validate(full=True) + assert table.schema == expected.schema + assert table.equals(expected) + assert table.to_pydict() == expected.to_pydict() + + +class TestGZipCSVRead(BaseTestCompressedCSVRead, unittest.TestCase): + csv_filename = "compressed.csv.gz" + + def write_file(self, path, contents): + with gzip.open(path, 'wb', 3) as f: + f.write(contents) + + def test_concatenated(self): + # ARROW-5974 + csv_path = os.path.join(self.tmpdir, self.csv_filename) + with gzip.open(csv_path, 'wb', 3) as f: + f.write(b"ab,cd\nef,gh\n") + with gzip.open(csv_path, 'ab', 3) as f: + f.write(b"ij,kl\nmn,op\n") + table = self.read_csv(csv_path) + assert table.to_pydict() == { + 'ab': ['ef', 'ij', 'mn'], + 'cd': ['gh', 'kl', 'op'], + } + + +class TestBZ2CSVRead(BaseTestCompressedCSVRead, unittest.TestCase): + csv_filename = "compressed.csv.bz2" + + def write_file(self, path, contents): + with bz2.BZ2File(path, 'w') as f: + f.write(contents) + + +def test_read_csv_does_not_close_passed_file_handles(): + # ARROW-4823 + buf = io.BytesIO(b"a,b,c\n1,2,3\n4,5,6") + read_csv(buf) + assert not buf.closed + + +def test_write_read_round_trip(): + t = pa.Table.from_arrays([[1, 2, 3], ["a", "b", "c"]], ["c1", "c2"]) + record_batch = t.to_batches(max_chunksize=4)[0] + for data in [t, record_batch]: + # Test with header + buf = io.BytesIO() + write_csv(data, buf, WriteOptions(include_header=True)) + buf.seek(0) + assert t == read_csv(buf) + + # Test without header + buf = io.BytesIO() + write_csv(data, buf, WriteOptions(include_header=False)) + buf.seek(0) + + read_options = ReadOptions(column_names=t.column_names) + assert t == read_csv(buf, read_options=read_options) + + # Test with writer + for read_options, write_options in [ + (None, WriteOptions(include_header=True)), + (ReadOptions(column_names=t.column_names), + WriteOptions(include_header=False)), + ]: + buf = io.BytesIO() + with CSVWriter(buf, t.schema, write_options=write_options) as writer: + writer.write_table(t) + buf.seek(0) + assert t == read_csv(buf, read_options=read_options) + + buf = io.BytesIO() + with CSVWriter(buf, t.schema, write_options=write_options) as writer: + for batch in t.to_batches(max_chunksize=1): + writer.write_batch(batch) + buf.seek(0) + assert t == read_csv(buf, read_options=read_options) + + +def test_read_csv_reference_cycle(): + # ARROW-13187 + def inner(): + buf = io.BytesIO(b"a,b,c\n1,2,3\n4,5,6") + table = read_csv(buf) + return weakref.ref(table) + + with util.disabled_gc(): + wr = inner() + assert wr() is None diff --git a/src/arrow/python/pyarrow/tests/test_cuda.py b/src/arrow/python/pyarrow/tests/test_cuda.py new file mode 100644 index 000000000..2ba2f8267 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_cuda.py @@ -0,0 +1,792 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +UNTESTED: +read_message +""" + +import sys +import sysconfig + +import pytest + +import pyarrow as pa +import numpy as np + + +cuda = pytest.importorskip("pyarrow.cuda") + +platform = sysconfig.get_platform() +# TODO: enable ppc64 when Arrow C++ supports IPC in ppc64 systems: +has_ipc_support = platform == 'linux-x86_64' # or 'ppc64' in platform + +cuda_ipc = pytest.mark.skipif( + not has_ipc_support, + reason='CUDA IPC not supported in platform `%s`' % (platform)) + +global_context = None # for flake8 +global_context1 = None # for flake8 + + +def setup_module(module): + module.global_context = cuda.Context(0) + module.global_context1 = cuda.Context(cuda.Context.get_num_devices() - 1) + + +def teardown_module(module): + del module.global_context + + +def test_Context(): + assert cuda.Context.get_num_devices() > 0 + assert global_context.device_number == 0 + assert global_context1.device_number == cuda.Context.get_num_devices() - 1 + + with pytest.raises(ValueError, + match=("device_number argument must " + "be non-negative less than")): + cuda.Context(cuda.Context.get_num_devices()) + + +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_manage_allocate_free_host(size): + buf = cuda.new_host_buffer(size) + arr = np.frombuffer(buf, dtype=np.uint8) + arr[size//4:3*size//4] = 1 + arr_cp = arr.copy() + arr2 = np.frombuffer(buf, dtype=np.uint8) + np.testing.assert_equal(arr2, arr_cp) + assert buf.size == size + + +def test_context_allocate_del(): + bytes_allocated = global_context.bytes_allocated + cudabuf = global_context.new_buffer(128) + assert global_context.bytes_allocated == bytes_allocated + 128 + del cudabuf + assert global_context.bytes_allocated == bytes_allocated + + +def make_random_buffer(size, target='host'): + """Return a host or device buffer with random data. + """ + if target == 'host': + assert size >= 0 + buf = pa.allocate_buffer(size) + assert buf.size == size + arr = np.frombuffer(buf, dtype=np.uint8) + assert arr.size == size + arr[:] = np.random.randint(low=1, high=255, size=size, dtype=np.uint8) + assert arr.sum() > 0 or size == 0 + arr_ = np.frombuffer(buf, dtype=np.uint8) + np.testing.assert_equal(arr, arr_) + return arr, buf + elif target == 'device': + arr, buf = make_random_buffer(size, target='host') + dbuf = global_context.new_buffer(size) + assert dbuf.size == size + dbuf.copy_from_host(buf, position=0, nbytes=size) + return arr, dbuf + raise ValueError('invalid target value') + + +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_context_device_buffer(size): + # Creating device buffer from host buffer; + arr, buf = make_random_buffer(size) + cudabuf = global_context.buffer_from_data(buf) + assert cudabuf.size == size + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + # CudaBuffer does not support buffer protocol + with pytest.raises(BufferError): + memoryview(cudabuf) + + # Creating device buffer from array: + cudabuf = global_context.buffer_from_data(arr) + assert cudabuf.size == size + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + # Creating device buffer from bytes: + cudabuf = global_context.buffer_from_data(arr.tobytes()) + assert cudabuf.size == size + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + # Creating a device buffer from another device buffer, view: + cudabuf2 = cudabuf.slice(0, cudabuf.size) + assert cudabuf2.size == size + arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + if size > 1: + cudabuf2.copy_from_host(arr[size//2:]) + arr3 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(np.concatenate((arr[size//2:], arr[size//2:])), + arr3) + cudabuf2.copy_from_host(arr[:size//2]) # restoring arr + + # Creating a device buffer from another device buffer, copy: + cudabuf2 = global_context.buffer_from_data(cudabuf) + assert cudabuf2.size == size + arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + cudabuf2.copy_from_host(arr[size//2:]) + arr3 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr3) + + # Slice of a device buffer + cudabuf2 = cudabuf.slice(0, cudabuf.size+10) + assert cudabuf2.size == size + arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + cudabuf2 = cudabuf.slice(size//4, size+10) + assert cudabuf2.size == size - size//4 + arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr[size//4:], arr2) + + # Creating a device buffer from a slice of host buffer + soffset = size//4 + ssize = 2*size//4 + cudabuf = global_context.buffer_from_data(buf, offset=soffset, + size=ssize) + assert cudabuf.size == ssize + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr[soffset:soffset + ssize], arr2) + + cudabuf = global_context.buffer_from_data(buf.slice(offset=soffset, + length=ssize)) + assert cudabuf.size == ssize + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr[soffset:soffset + ssize], arr2) + + # Creating a device buffer from a slice of an array + cudabuf = global_context.buffer_from_data(arr, offset=soffset, size=ssize) + assert cudabuf.size == ssize + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr[soffset:soffset + ssize], arr2) + + cudabuf = global_context.buffer_from_data(arr[soffset:soffset+ssize]) + assert cudabuf.size == ssize + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr[soffset:soffset + ssize], arr2) + + # Creating a device buffer from a slice of bytes + cudabuf = global_context.buffer_from_data(arr.tobytes(), + offset=soffset, + size=ssize) + assert cudabuf.size == ssize + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr[soffset:soffset + ssize], arr2) + + # Creating a device buffer from size + cudabuf = global_context.new_buffer(size) + assert cudabuf.size == size + + # Creating device buffer from a slice of another device buffer: + cudabuf = global_context.buffer_from_data(arr) + cudabuf2 = cudabuf.slice(soffset, ssize) + assert cudabuf2.size == ssize + arr2 = np.frombuffer(cudabuf2.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr[soffset:soffset+ssize], arr2) + + # Creating device buffer from HostBuffer + + buf = cuda.new_host_buffer(size) + arr_ = np.frombuffer(buf, dtype=np.uint8) + arr_[:] = arr + cudabuf = global_context.buffer_from_data(buf) + assert cudabuf.size == size + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + # Creating device buffer from HostBuffer slice + + cudabuf = global_context.buffer_from_data(buf, offset=soffset, size=ssize) + assert cudabuf.size == ssize + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr[soffset:soffset+ssize], arr2) + + cudabuf = global_context.buffer_from_data( + buf.slice(offset=soffset, length=ssize)) + assert cudabuf.size == ssize + arr2 = np.frombuffer(cudabuf.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr[soffset:soffset+ssize], arr2) + + +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_context_from_object(size): + ctx = global_context + arr, cbuf = make_random_buffer(size, target='device') + dtype = arr.dtype + + # Creating device buffer from a CUDA host buffer + hbuf = cuda.new_host_buffer(size * arr.dtype.itemsize) + np.frombuffer(hbuf, dtype=dtype)[:] = arr + cbuf2 = ctx.buffer_from_object(hbuf) + assert cbuf2.size == cbuf.size + arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype) + np.testing.assert_equal(arr, arr2) + + # Creating device buffer from a device buffer + cbuf2 = ctx.buffer_from_object(cbuf2) + assert cbuf2.size == cbuf.size + arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype) + np.testing.assert_equal(arr, arr2) + + # Trying to create a device buffer from a Buffer + with pytest.raises(pa.ArrowTypeError, + match=('buffer is not backed by a CudaBuffer')): + ctx.buffer_from_object(pa.py_buffer(b"123")) + + # Trying to create a device buffer from numpy.array + with pytest.raises(pa.ArrowTypeError, + match=("cannot create device buffer view from " + ".* \'numpy.ndarray\'")): + ctx.buffer_from_object(np.array([1, 2, 3])) + + +def test_foreign_buffer(): + ctx = global_context + dtype = np.dtype(np.uint8) + size = 10 + hbuf = cuda.new_host_buffer(size * dtype.itemsize) + + # test host buffer memory reference counting + rc = sys.getrefcount(hbuf) + fbuf = ctx.foreign_buffer(hbuf.address, hbuf.size, hbuf) + assert sys.getrefcount(hbuf) == rc + 1 + del fbuf + assert sys.getrefcount(hbuf) == rc + + # test postponed deallocation of host buffer memory + fbuf = ctx.foreign_buffer(hbuf.address, hbuf.size, hbuf) + del hbuf + fbuf.copy_to_host() + + # test deallocating the host buffer memory making it inaccessible + hbuf = cuda.new_host_buffer(size * dtype.itemsize) + fbuf = ctx.foreign_buffer(hbuf.address, hbuf.size) + del hbuf + with pytest.raises(pa.ArrowIOError, + match=('Cuda error ')): + fbuf.copy_to_host() + + +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_CudaBuffer(size): + arr, buf = make_random_buffer(size) + assert arr.tobytes() == buf.to_pybytes() + cbuf = global_context.buffer_from_data(buf) + assert cbuf.size == size + assert not cbuf.is_cpu + assert arr.tobytes() == cbuf.to_pybytes() + if size > 0: + assert cbuf.address > 0 + + for i in range(size): + assert cbuf[i] == arr[i] + + for s in [ + slice(None), + slice(size//4, size//2), + ]: + assert cbuf[s].to_pybytes() == arr[s].tobytes() + + sbuf = cbuf.slice(size//4, size//2) + assert sbuf.parent == cbuf + + with pytest.raises(TypeError, + match="Do not call CudaBuffer's constructor directly"): + cuda.CudaBuffer() + + +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_HostBuffer(size): + arr, buf = make_random_buffer(size) + assert arr.tobytes() == buf.to_pybytes() + hbuf = cuda.new_host_buffer(size) + np.frombuffer(hbuf, dtype=np.uint8)[:] = arr + assert hbuf.size == size + assert hbuf.is_cpu + assert arr.tobytes() == hbuf.to_pybytes() + for i in range(size): + assert hbuf[i] == arr[i] + for s in [ + slice(None), + slice(size//4, size//2), + ]: + assert hbuf[s].to_pybytes() == arr[s].tobytes() + + sbuf = hbuf.slice(size//4, size//2) + assert sbuf.parent == hbuf + + del hbuf + + with pytest.raises(TypeError, + match="Do not call HostBuffer's constructor directly"): + cuda.HostBuffer() + + +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_copy_from_to_host(size): + + # Create a buffer in host containing range(size) + buf = pa.allocate_buffer(size, resizable=True) # in host + assert isinstance(buf, pa.Buffer) + assert not isinstance(buf, cuda.CudaBuffer) + arr = np.frombuffer(buf, dtype=np.uint8) + assert arr.size == size + arr[:] = range(size) + arr_ = np.frombuffer(buf, dtype=np.uint8) + np.testing.assert_equal(arr, arr_) + + device_buffer = global_context.new_buffer(size) + assert isinstance(device_buffer, cuda.CudaBuffer) + assert isinstance(device_buffer, pa.Buffer) + assert device_buffer.size == size + assert not device_buffer.is_cpu + + device_buffer.copy_from_host(buf, position=0, nbytes=size) + + buf2 = device_buffer.copy_to_host(position=0, nbytes=size) + arr2 = np.frombuffer(buf2, dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_copy_to_host(size): + arr, dbuf = make_random_buffer(size, target='device') + + buf = dbuf.copy_to_host() + assert buf.is_cpu + np.testing.assert_equal(arr, np.frombuffer(buf, dtype=np.uint8)) + + buf = dbuf.copy_to_host(position=size//4) + assert buf.is_cpu + np.testing.assert_equal(arr[size//4:], np.frombuffer(buf, dtype=np.uint8)) + + buf = dbuf.copy_to_host(position=size//4, nbytes=size//8) + assert buf.is_cpu + np.testing.assert_equal(arr[size//4:size//4+size//8], + np.frombuffer(buf, dtype=np.uint8)) + + buf = dbuf.copy_to_host(position=size//4, nbytes=0) + assert buf.is_cpu + assert buf.size == 0 + + for (position, nbytes) in [ + (size+2, -1), (-2, -1), (size+1, 0), (-3, 0), + ]: + with pytest.raises(ValueError, + match='position argument is out-of-range'): + dbuf.copy_to_host(position=position, nbytes=nbytes) + + for (position, nbytes) in [ + (0, size+1), (size//2, (size+1)//2+1), (size, 1) + ]: + with pytest.raises(ValueError, + match=('requested more to copy than' + ' available from device buffer')): + dbuf.copy_to_host(position=position, nbytes=nbytes) + + buf = pa.allocate_buffer(size//4) + dbuf.copy_to_host(buf=buf) + np.testing.assert_equal(arr[:size//4], np.frombuffer(buf, dtype=np.uint8)) + + if size < 12: + return + + dbuf.copy_to_host(buf=buf, position=12) + np.testing.assert_equal(arr[12:12+size//4], + np.frombuffer(buf, dtype=np.uint8)) + + dbuf.copy_to_host(buf=buf, nbytes=12) + np.testing.assert_equal(arr[:12], np.frombuffer(buf, dtype=np.uint8)[:12]) + + dbuf.copy_to_host(buf=buf, nbytes=12, position=6) + np.testing.assert_equal(arr[6:6+12], + np.frombuffer(buf, dtype=np.uint8)[:12]) + + for (position, nbytes) in [ + (0, size+10), (10, size-5), + (0, size//2), (size//4, size//4+1) + ]: + with pytest.raises(ValueError, + match=('requested copy does not ' + 'fit into host buffer')): + dbuf.copy_to_host(buf=buf, position=position, nbytes=nbytes) + + +@pytest.mark.parametrize("dest_ctx", ['same', 'another']) +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_copy_from_device(dest_ctx, size): + arr, buf = make_random_buffer(size=size, target='device') + lst = arr.tolist() + if dest_ctx == 'another': + dest_ctx = global_context1 + if buf.context.device_number == dest_ctx.device_number: + pytest.skip("not a multi-GPU system") + else: + dest_ctx = buf.context + dbuf = dest_ctx.new_buffer(size) + + def put(*args, **kwargs): + dbuf.copy_from_device(buf, *args, **kwargs) + rbuf = dbuf.copy_to_host() + return np.frombuffer(rbuf, dtype=np.uint8).tolist() + assert put() == lst + if size > 4: + assert put(position=size//4) == lst[:size//4]+lst[:-size//4] + assert put() == lst + assert put(position=1, nbytes=size//2) == \ + lst[:1] + lst[:size//2] + lst[-(size-size//2-1):] + + for (position, nbytes) in [ + (size+2, -1), (-2, -1), (size+1, 0), (-3, 0), + ]: + with pytest.raises(ValueError, + match='position argument is out-of-range'): + put(position=position, nbytes=nbytes) + + for (position, nbytes) in [ + (0, size+1), + ]: + with pytest.raises(ValueError, + match=('requested more to copy than' + ' available from device buffer')): + put(position=position, nbytes=nbytes) + + if size < 4: + return + + for (position, nbytes) in [ + (size//2, (size+1)//2+1) + ]: + with pytest.raises(ValueError, + match=('requested more to copy than' + ' available in device buffer')): + put(position=position, nbytes=nbytes) + + +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_copy_from_host(size): + arr, buf = make_random_buffer(size=size, target='host') + lst = arr.tolist() + dbuf = global_context.new_buffer(size) + + def put(*args, **kwargs): + dbuf.copy_from_host(buf, *args, **kwargs) + rbuf = dbuf.copy_to_host() + return np.frombuffer(rbuf, dtype=np.uint8).tolist() + assert put() == lst + if size > 4: + assert put(position=size//4) == lst[:size//4]+lst[:-size//4] + assert put() == lst + assert put(position=1, nbytes=size//2) == \ + lst[:1] + lst[:size//2] + lst[-(size-size//2-1):] + + for (position, nbytes) in [ + (size+2, -1), (-2, -1), (size+1, 0), (-3, 0), + ]: + with pytest.raises(ValueError, + match='position argument is out-of-range'): + put(position=position, nbytes=nbytes) + + for (position, nbytes) in [ + (0, size+1), + ]: + with pytest.raises(ValueError, + match=('requested more to copy than' + ' available from host buffer')): + put(position=position, nbytes=nbytes) + + if size < 4: + return + + for (position, nbytes) in [ + (size//2, (size+1)//2+1) + ]: + with pytest.raises(ValueError, + match=('requested more to copy than' + ' available in device buffer')): + put(position=position, nbytes=nbytes) + + +def test_BufferWriter(): + def allocate(size): + cbuf = global_context.new_buffer(size) + writer = cuda.BufferWriter(cbuf) + return cbuf, writer + + def test_writes(total_size, chunksize, buffer_size=0): + cbuf, writer = allocate(total_size) + arr, buf = make_random_buffer(size=total_size, target='host') + + if buffer_size > 0: + writer.buffer_size = buffer_size + + position = writer.tell() + assert position == 0 + writer.write(buf.slice(length=chunksize)) + assert writer.tell() == chunksize + writer.seek(0) + position = writer.tell() + assert position == 0 + + while position < total_size: + bytes_to_write = min(chunksize, total_size - position) + writer.write(buf.slice(offset=position, length=bytes_to_write)) + position += bytes_to_write + + writer.flush() + assert cbuf.size == total_size + cbuf.context.synchronize() + buf2 = cbuf.copy_to_host() + cbuf.context.synchronize() + assert buf2.size == total_size + arr2 = np.frombuffer(buf2, dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + total_size, chunk_size = 1 << 16, 1000 + test_writes(total_size, chunk_size) + test_writes(total_size, chunk_size, total_size // 16) + + cbuf, writer = allocate(100) + writer.write(np.arange(100, dtype=np.uint8)) + writer.writeat(50, np.arange(25, dtype=np.uint8)) + writer.write(np.arange(25, dtype=np.uint8)) + writer.flush() + + arr = np.frombuffer(cbuf.copy_to_host(), np.uint8) + np.testing.assert_equal(arr[:50], np.arange(50, dtype=np.uint8)) + np.testing.assert_equal(arr[50:75], np.arange(25, dtype=np.uint8)) + np.testing.assert_equal(arr[75:], np.arange(25, dtype=np.uint8)) + + +def test_BufferWriter_edge_cases(): + # edge cases, see cuda-test.cc for more information: + size = 1000 + cbuf = global_context.new_buffer(size) + writer = cuda.BufferWriter(cbuf) + arr, buf = make_random_buffer(size=size, target='host') + + assert writer.buffer_size == 0 + writer.buffer_size = 100 + assert writer.buffer_size == 100 + + writer.write(buf.slice(length=0)) + assert writer.tell() == 0 + + writer.write(buf.slice(length=10)) + writer.buffer_size = 200 + assert writer.buffer_size == 200 + assert writer.num_bytes_buffered == 0 + + writer.write(buf.slice(offset=10, length=300)) + assert writer.num_bytes_buffered == 0 + + writer.write(buf.slice(offset=310, length=200)) + assert writer.num_bytes_buffered == 0 + + writer.write(buf.slice(offset=510, length=390)) + writer.write(buf.slice(offset=900, length=100)) + + writer.flush() + + buf2 = cbuf.copy_to_host() + assert buf2.size == size + arr2 = np.frombuffer(buf2, dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + +def test_BufferReader(): + size = 1000 + arr, cbuf = make_random_buffer(size=size, target='device') + + reader = cuda.BufferReader(cbuf) + reader.seek(950) + assert reader.tell() == 950 + + data = reader.read(100) + assert len(data) == 50 + assert reader.tell() == 1000 + + reader.seek(925) + arr2 = np.zeros(100, dtype=np.uint8) + n = reader.readinto(arr2) + assert n == 75 + assert reader.tell() == 1000 + np.testing.assert_equal(arr[925:], arr2[:75]) + + reader.seek(0) + assert reader.tell() == 0 + buf2 = reader.read_buffer() + arr2 = np.frombuffer(buf2.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + +def test_BufferReader_zero_size(): + arr, cbuf = make_random_buffer(size=0, target='device') + reader = cuda.BufferReader(cbuf) + reader.seek(0) + data = reader.read() + assert len(data) == 0 + assert reader.tell() == 0 + buf2 = reader.read_buffer() + arr2 = np.frombuffer(buf2.copy_to_host(), dtype=np.uint8) + np.testing.assert_equal(arr, arr2) + + +def make_recordbatch(length): + schema = pa.schema([pa.field('f0', pa.int16()), + pa.field('f1', pa.int16())]) + a0 = pa.array(np.random.randint(0, 255, size=length, dtype=np.int16)) + a1 = pa.array(np.random.randint(0, 255, size=length, dtype=np.int16)) + batch = pa.record_batch([a0, a1], schema=schema) + return batch + + +def test_batch_serialize(): + batch = make_recordbatch(10) + hbuf = batch.serialize() + cbuf = cuda.serialize_record_batch(batch, global_context) + + # Test that read_record_batch works properly + cbatch = cuda.read_record_batch(cbuf, batch.schema) + assert isinstance(cbatch, pa.RecordBatch) + assert batch.schema == cbatch.schema + assert batch.num_columns == cbatch.num_columns + assert batch.num_rows == cbatch.num_rows + + # Deserialize CUDA-serialized batch on host + buf = cbuf.copy_to_host() + assert hbuf.equals(buf) + batch2 = pa.ipc.read_record_batch(buf, batch.schema) + assert hbuf.equals(batch2.serialize()) + + assert batch.num_columns == batch2.num_columns + assert batch.num_rows == batch2.num_rows + assert batch.column(0).equals(batch2.column(0)) + assert batch.equals(batch2) + + +def make_table(): + a0 = pa.array([0, 1, 42, None], type=pa.int16()) + a1 = pa.array([[0, 1], [2], [], None], type=pa.list_(pa.int32())) + a2 = pa.array([("ab", True), ("cde", False), (None, None), None], + type=pa.struct([("strs", pa.utf8()), + ("bools", pa.bool_())])) + # Dictionaries are validated on the IPC read path, but that can produce + # issues for GPU-located dictionaries. Check that they work fine. + a3 = pa.DictionaryArray.from_arrays( + indices=[0, 1, 1, None], + dictionary=pa.array(['foo', 'bar'])) + a4 = pa.DictionaryArray.from_arrays( + indices=[2, 1, 2, None], + dictionary=a1) + a5 = pa.DictionaryArray.from_arrays( + indices=[2, 1, 0, None], + dictionary=a2) + + arrays = [a0, a1, a2, a3, a4, a5] + schema = pa.schema([('f{}'.format(i), arr.type) + for i, arr in enumerate(arrays)]) + batch = pa.record_batch(arrays, schema=schema) + table = pa.Table.from_batches([batch]) + return table + + +def make_table_cuda(): + htable = make_table() + # Serialize the host table to bytes + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, htable.schema) as out: + out.write_table(htable) + hbuf = pa.py_buffer(sink.getvalue().to_pybytes()) + + # Copy the host bytes to a device buffer + dbuf = global_context.new_buffer(len(hbuf)) + dbuf.copy_from_host(hbuf, nbytes=len(hbuf)) + # Deserialize the device buffer into a Table + dtable = pa.ipc.open_stream(cuda.BufferReader(dbuf)).read_all() + return hbuf, htable, dbuf, dtable + + +def test_table_deserialize(): + # ARROW-9659: make sure that we can deserialize a GPU-located table + # without crashing when initializing or validating the underlying arrays. + hbuf, htable, dbuf, dtable = make_table_cuda() + # Assert basic fields the same between host and device tables + assert htable.schema == dtable.schema + assert htable.num_rows == dtable.num_rows + assert htable.num_columns == dtable.num_columns + # Assert byte-level equality + assert hbuf.equals(dbuf.copy_to_host()) + # Copy DtoH and assert the tables are still equivalent + assert htable.equals(pa.ipc.open_stream( + dbuf.copy_to_host() + ).read_all()) + + +def test_create_table_with_device_buffers(): + # ARROW-11872: make sure that we can create an Arrow Table from + # GPU-located Arrays without crashing. + hbuf, htable, dbuf, dtable = make_table_cuda() + # Construct a new Table from the device Table + dtable2 = pa.Table.from_arrays(dtable.columns, dtable.column_names) + # Assert basic fields the same between host and device tables + assert htable.schema == dtable2.schema + assert htable.num_rows == dtable2.num_rows + assert htable.num_columns == dtable2.num_columns + # Assert byte-level equality + assert hbuf.equals(dbuf.copy_to_host()) + # Copy DtoH and assert the tables are still equivalent + assert htable.equals(pa.ipc.open_stream( + dbuf.copy_to_host() + ).read_all()) + + +def other_process_for_test_IPC(handle_buffer, expected_arr): + other_context = pa.cuda.Context(0) + ipc_handle = pa.cuda.IpcMemHandle.from_buffer(handle_buffer) + ipc_buf = other_context.open_ipc_buffer(ipc_handle) + ipc_buf.context.synchronize() + buf = ipc_buf.copy_to_host() + assert buf.size == expected_arr.size, repr((buf.size, expected_arr.size)) + arr = np.frombuffer(buf, dtype=expected_arr.dtype) + np.testing.assert_equal(arr, expected_arr) + + +@cuda_ipc +@pytest.mark.parametrize("size", [0, 1, 1000]) +def test_IPC(size): + import multiprocessing + ctx = multiprocessing.get_context('spawn') + arr, cbuf = make_random_buffer(size=size, target='device') + ipc_handle = cbuf.export_for_ipc() + handle_buffer = ipc_handle.serialize() + p = ctx.Process(target=other_process_for_test_IPC, + args=(handle_buffer, arr)) + p.start() + p.join() + assert p.exitcode == 0 diff --git a/src/arrow/python/pyarrow/tests/test_cuda_numba_interop.py b/src/arrow/python/pyarrow/tests/test_cuda_numba_interop.py new file mode 100644 index 000000000..ff1722d27 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_cuda_numba_interop.py @@ -0,0 +1,235 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import pyarrow as pa +import numpy as np + +dtypes = ['uint8', 'int16', 'float32'] +cuda = pytest.importorskip("pyarrow.cuda") +nb_cuda = pytest.importorskip("numba.cuda") + +from numba.cuda.cudadrv.devicearray import DeviceNDArray # noqa: E402 + + +context_choices = None +context_choice_ids = ['pyarrow.cuda', 'numba.cuda'] + + +def setup_module(module): + np.random.seed(1234) + ctx1 = cuda.Context() + nb_ctx1 = ctx1.to_numba() + nb_ctx2 = nb_cuda.current_context() + ctx2 = cuda.Context.from_numba(nb_ctx2) + module.context_choices = [(ctx1, nb_ctx1), (ctx2, nb_ctx2)] + + +def teardown_module(module): + del module.context_choices + + +@pytest.mark.parametrize("c", range(len(context_choice_ids)), + ids=context_choice_ids) +def test_context(c): + ctx, nb_ctx = context_choices[c] + assert ctx.handle == nb_ctx.handle.value + assert ctx.handle == ctx.to_numba().handle.value + ctx2 = cuda.Context.from_numba(nb_ctx) + assert ctx.handle == ctx2.handle + size = 10 + buf = ctx.new_buffer(size) + assert ctx.handle == buf.context.handle + + +def make_random_buffer(size, target='host', dtype='uint8', ctx=None): + """Return a host or device buffer with random data. + """ + dtype = np.dtype(dtype) + if target == 'host': + assert size >= 0 + buf = pa.allocate_buffer(size*dtype.itemsize) + arr = np.frombuffer(buf, dtype=dtype) + arr[:] = np.random.randint(low=0, high=255, size=size, + dtype=np.uint8) + return arr, buf + elif target == 'device': + arr, buf = make_random_buffer(size, target='host', dtype=dtype) + dbuf = ctx.new_buffer(size * dtype.itemsize) + dbuf.copy_from_host(buf, position=0, nbytes=buf.size) + return arr, dbuf + raise ValueError('invalid target value') + + +@pytest.mark.parametrize("c", range(len(context_choice_ids)), + ids=context_choice_ids) +@pytest.mark.parametrize("dtype", dtypes, ids=dtypes) +@pytest.mark.parametrize("size", [0, 1, 8, 1000]) +def test_from_object(c, dtype, size): + ctx, nb_ctx = context_choices[c] + arr, cbuf = make_random_buffer(size, target='device', dtype=dtype, ctx=ctx) + + # Creating device buffer from numba DeviceNDArray: + darr = nb_cuda.to_device(arr) + cbuf2 = ctx.buffer_from_object(darr) + assert cbuf2.size == cbuf.size + arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype) + np.testing.assert_equal(arr, arr2) + + # Creating device buffer from a slice of numba DeviceNDArray: + if size >= 8: + # 1-D arrays + for s in [slice(size//4, None, None), + slice(size//4, -(size//4), None)]: + cbuf2 = ctx.buffer_from_object(darr[s]) + arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype) + np.testing.assert_equal(arr[s], arr2) + + # cannot test negative strides due to numba bug, see its issue 3705 + if 0: + rdarr = darr[::-1] + cbuf2 = ctx.buffer_from_object(rdarr) + assert cbuf2.size == cbuf.size + arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype) + np.testing.assert_equal(arr, arr2) + + with pytest.raises(ValueError, + match=('array data is non-contiguous')): + ctx.buffer_from_object(darr[::2]) + + # a rectangular 2-D array + s1 = size//4 + s2 = size//s1 + assert s1 * s2 == size + cbuf2 = ctx.buffer_from_object(darr.reshape(s1, s2)) + assert cbuf2.size == cbuf.size + arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype) + np.testing.assert_equal(arr, arr2) + + with pytest.raises(ValueError, + match=('array data is non-contiguous')): + ctx.buffer_from_object(darr.reshape(s1, s2)[:, ::2]) + + # a 3-D array + s1 = 4 + s2 = size//8 + s3 = size//(s1*s2) + assert s1 * s2 * s3 == size + cbuf2 = ctx.buffer_from_object(darr.reshape(s1, s2, s3)) + assert cbuf2.size == cbuf.size + arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype) + np.testing.assert_equal(arr, arr2) + + with pytest.raises(ValueError, + match=('array data is non-contiguous')): + ctx.buffer_from_object(darr.reshape(s1, s2, s3)[::2]) + + # Creating device buffer from am object implementing cuda array + # interface: + class MyObj: + def __init__(self, darr): + self.darr = darr + + @property + def __cuda_array_interface__(self): + return self.darr.__cuda_array_interface__ + + cbuf2 = ctx.buffer_from_object(MyObj(darr)) + assert cbuf2.size == cbuf.size + arr2 = np.frombuffer(cbuf2.copy_to_host(), dtype=dtype) + np.testing.assert_equal(arr, arr2) + + +@pytest.mark.parametrize("c", range(len(context_choice_ids)), + ids=context_choice_ids) +@pytest.mark.parametrize("dtype", dtypes, ids=dtypes) +def test_numba_memalloc(c, dtype): + ctx, nb_ctx = context_choices[c] + dtype = np.dtype(dtype) + # Allocate memory using numba context + # Warning: this will not be reflected in pyarrow context manager + # (e.g bytes_allocated does not change) + size = 10 + mem = nb_ctx.memalloc(size * dtype.itemsize) + darr = DeviceNDArray((size,), (dtype.itemsize,), dtype, gpu_data=mem) + darr[:5] = 99 + darr[5:] = 88 + np.testing.assert_equal(darr.copy_to_host()[:5], 99) + np.testing.assert_equal(darr.copy_to_host()[5:], 88) + + # wrap numba allocated memory with CudaBuffer + cbuf = cuda.CudaBuffer.from_numba(mem) + arr2 = np.frombuffer(cbuf.copy_to_host(), dtype=dtype) + np.testing.assert_equal(arr2, darr.copy_to_host()) + + +@pytest.mark.parametrize("c", range(len(context_choice_ids)), + ids=context_choice_ids) +@pytest.mark.parametrize("dtype", dtypes, ids=dtypes) +def test_pyarrow_memalloc(c, dtype): + ctx, nb_ctx = context_choices[c] + size = 10 + arr, cbuf = make_random_buffer(size, target='device', dtype=dtype, ctx=ctx) + + # wrap CudaBuffer with numba device array + mem = cbuf.to_numba() + darr = DeviceNDArray(arr.shape, arr.strides, arr.dtype, gpu_data=mem) + np.testing.assert_equal(darr.copy_to_host(), arr) + + +@pytest.mark.parametrize("c", range(len(context_choice_ids)), + ids=context_choice_ids) +@pytest.mark.parametrize("dtype", dtypes, ids=dtypes) +def test_numba_context(c, dtype): + ctx, nb_ctx = context_choices[c] + size = 10 + with nb_cuda.gpus[0]: + arr, cbuf = make_random_buffer(size, target='device', + dtype=dtype, ctx=ctx) + assert cbuf.context.handle == nb_ctx.handle.value + mem = cbuf.to_numba() + darr = DeviceNDArray(arr.shape, arr.strides, arr.dtype, gpu_data=mem) + np.testing.assert_equal(darr.copy_to_host(), arr) + darr[0] = 99 + cbuf.context.synchronize() + arr2 = np.frombuffer(cbuf.copy_to_host(), dtype=dtype) + assert arr2[0] == 99 + + +@pytest.mark.parametrize("c", range(len(context_choice_ids)), + ids=context_choice_ids) +@pytest.mark.parametrize("dtype", dtypes, ids=dtypes) +def test_pyarrow_jit(c, dtype): + ctx, nb_ctx = context_choices[c] + + @nb_cuda.jit + def increment_by_one(an_array): + pos = nb_cuda.grid(1) + if pos < an_array.size: + an_array[pos] += 1 + + # applying numba.cuda kernel to memory hold by CudaBuffer + size = 10 + arr, cbuf = make_random_buffer(size, target='device', dtype=dtype, ctx=ctx) + threadsperblock = 32 + blockspergrid = (arr.size + (threadsperblock - 1)) // threadsperblock + mem = cbuf.to_numba() + darr = DeviceNDArray(arr.shape, arr.strides, arr.dtype, gpu_data=mem) + increment_by_one[blockspergrid, threadsperblock](darr) + cbuf.context.synchronize() + arr1 = np.frombuffer(cbuf.copy_to_host(), dtype=arr.dtype) + np.testing.assert_equal(arr1, arr + 1) diff --git a/src/arrow/python/pyarrow/tests/test_cython.py b/src/arrow/python/pyarrow/tests/test_cython.py new file mode 100644 index 000000000..e202b417a --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_cython.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import shutil +import subprocess +import sys + +import pytest + +import pyarrow as pa +import pyarrow.tests.util as test_util + + +here = os.path.dirname(os.path.abspath(__file__)) +test_ld_path = os.environ.get('PYARROW_TEST_LD_PATH', '') +if os.name == 'posix': + compiler_opts = ['-std=c++11'] +else: + compiler_opts = [] + + +setup_template = """if 1: + from setuptools import setup + from Cython.Build import cythonize + + import numpy as np + + import pyarrow as pa + + ext_modules = cythonize({pyx_file!r}) + compiler_opts = {compiler_opts!r} + custom_ld_path = {test_ld_path!r} + + for ext in ext_modules: + # XXX required for numpy/numpyconfig.h, + # included from arrow/python/api.h + ext.include_dirs.append(np.get_include()) + ext.include_dirs.append(pa.get_include()) + ext.libraries.extend(pa.get_libraries()) + ext.library_dirs.extend(pa.get_library_dirs()) + if custom_ld_path: + ext.library_dirs.append(custom_ld_path) + ext.extra_compile_args.extend(compiler_opts) + print("Extension module:", + ext, ext.include_dirs, ext.libraries, ext.library_dirs) + + setup( + ext_modules=ext_modules, + ) +""" + + +def check_cython_example_module(mod): + arr = pa.array([1, 2, 3]) + assert mod.get_array_length(arr) == 3 + with pytest.raises(TypeError, match="not an array"): + mod.get_array_length(None) + + scal = pa.scalar(123) + cast_scal = mod.cast_scalar(scal, pa.utf8()) + assert cast_scal == pa.scalar("123") + with pytest.raises(NotImplementedError, + match="casting scalars of type int64 to type list"): + mod.cast_scalar(scal, pa.list_(pa.int64())) + + +@pytest.mark.cython +def test_cython_api(tmpdir): + """ + Basic test for the Cython API. + """ + # Fail early if cython is not found + import cython # noqa + + with tmpdir.as_cwd(): + # Set up temporary workspace + pyx_file = 'pyarrow_cython_example.pyx' + shutil.copyfile(os.path.join(here, pyx_file), + os.path.join(str(tmpdir), pyx_file)) + # Create setup.py file + setup_code = setup_template.format(pyx_file=pyx_file, + compiler_opts=compiler_opts, + test_ld_path=test_ld_path) + with open('setup.py', 'w') as f: + f.write(setup_code) + + # ARROW-2263: Make environment with this pyarrow/ package first on the + # PYTHONPATH, for local dev environments + subprocess_env = test_util.get_modified_env_with_pythonpath() + + # Compile extension module + subprocess.check_call([sys.executable, 'setup.py', + 'build_ext', '--inplace'], + env=subprocess_env) + + # Check basic functionality + orig_path = sys.path[:] + sys.path.insert(0, str(tmpdir)) + try: + mod = __import__('pyarrow_cython_example') + check_cython_example_module(mod) + finally: + sys.path = orig_path + + # Check the extension module is loadable from a subprocess without + # pyarrow imported first. + code = """if 1: + import sys + + mod = __import__({mod_name!r}) + arr = mod.make_null_array(5) + assert mod.get_array_length(arr) == 5 + assert arr.null_count == 5 + """.format(mod_name='pyarrow_cython_example') + + if sys.platform == 'win32': + delim, var = ';', 'PATH' + else: + delim, var = ':', 'LD_LIBRARY_PATH' + + subprocess_env[var] = delim.join( + pa.get_library_dirs() + [subprocess_env.get(var, '')] + ) + + subprocess.check_call([sys.executable, '-c', code], + stdout=subprocess.PIPE, + env=subprocess_env) + + +@pytest.mark.cython +def test_visit_strings(tmpdir): + with tmpdir.as_cwd(): + # Set up temporary workspace + pyx_file = 'bound_function_visit_strings.pyx' + shutil.copyfile(os.path.join(here, pyx_file), + os.path.join(str(tmpdir), pyx_file)) + # Create setup.py file + setup_code = setup_template.format(pyx_file=pyx_file, + compiler_opts=compiler_opts, + test_ld_path=test_ld_path) + with open('setup.py', 'w') as f: + f.write(setup_code) + + subprocess_env = test_util.get_modified_env_with_pythonpath() + + # Compile extension module + subprocess.check_call([sys.executable, 'setup.py', + 'build_ext', '--inplace'], + env=subprocess_env) + + sys.path.insert(0, str(tmpdir)) + mod = __import__('bound_function_visit_strings') + + strings = ['a', 'b', 'c'] + visited = [] + mod._visit_strings(strings, visited.append) + + assert visited == strings + + with pytest.raises(ValueError, match="wtf"): + def raise_on_b(s): + if s == 'b': + raise ValueError('wtf') + + mod._visit_strings(strings, raise_on_b) diff --git a/src/arrow/python/pyarrow/tests/test_dataset.py b/src/arrow/python/pyarrow/tests/test_dataset.py new file mode 100644 index 000000000..20b12316b --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_dataset.py @@ -0,0 +1,3976 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import contextlib +import os +import posixpath +import pathlib +import pickle +import textwrap +import tempfile +import threading +import time + +import numpy as np +import pytest + +import pyarrow as pa +import pyarrow.csv +import pyarrow.feather +import pyarrow.fs as fs +from pyarrow.tests.util import (change_cwd, _filesystem_uri, + FSProtocolClass, ProxyHandler) + +try: + import pandas as pd +except ImportError: + pd = None + +try: + import pyarrow.dataset as ds +except ImportError: + ds = None + +# Marks all of the tests in this module +# Ignore these with pytest ... -m 'not dataset' +pytestmark = pytest.mark.dataset + + +def _generate_data(n): + import datetime + import itertools + + day = datetime.datetime(2000, 1, 1) + interval = datetime.timedelta(days=5) + colors = itertools.cycle(['green', 'blue', 'yellow', 'red', 'orange']) + + data = [] + for i in range(n): + data.append((day, i, float(i), next(colors))) + day += interval + + return pd.DataFrame(data, columns=['date', 'index', 'value', 'color']) + + +def _table_from_pandas(df): + schema = pa.schema([ + pa.field('date', pa.date32()), + pa.field('index', pa.int64()), + pa.field('value', pa.float64()), + pa.field('color', pa.string()), + ]) + table = pa.Table.from_pandas(df, schema=schema, preserve_index=False) + return table.replace_schema_metadata() + + +@pytest.fixture +@pytest.mark.parquet +def mockfs(): + import pyarrow.parquet as pq + + mockfs = fs._MockFileSystem() + + directories = [ + 'subdir/1/xxx', + 'subdir/2/yyy', + ] + + for i, directory in enumerate(directories): + path = '{}/file{}.parquet'.format(directory, i) + mockfs.create_dir(directory) + with mockfs.open_output_stream(path) as out: + data = [ + list(range(5)), + list(map(float, range(5))), + list(map(str, range(5))), + [i] * 5 + ] + schema = pa.schema([ + pa.field('i64', pa.int64()), + pa.field('f64', pa.float64()), + pa.field('str', pa.string()), + pa.field('const', pa.int64()), + ]) + batch = pa.record_batch(data, schema=schema) + table = pa.Table.from_batches([batch]) + + pq.write_table(table, out) + + return mockfs + + +@pytest.fixture +def open_logging_fs(monkeypatch): + from pyarrow.fs import PyFileSystem, LocalFileSystem + from .test_fs import ProxyHandler + + localfs = LocalFileSystem() + + def normalized(paths): + return {localfs.normalize_path(str(p)) for p in paths} + + opened = set() + + def open_input_file(self, path): + path = localfs.normalize_path(str(path)) + opened.add(path) + return self._fs.open_input_file(path) + + # patch proxyhandler to log calls to open_input_file + monkeypatch.setattr(ProxyHandler, "open_input_file", open_input_file) + fs = PyFileSystem(ProxyHandler(localfs)) + + @contextlib.contextmanager + def assert_opens(expected_opened): + opened.clear() + try: + yield + finally: + assert normalized(opened) == normalized(expected_opened) + + return fs, assert_opens + + +@pytest.fixture(scope='module') +def multisourcefs(request): + request.config.pyarrow.requires('pandas') + request.config.pyarrow.requires('parquet') + import pyarrow.parquet as pq + + df = _generate_data(1000) + mockfs = fs._MockFileSystem() + + # simply split the dataframe into four chunks to construct a data source + # from each chunk into its own directory + df_a, df_b, df_c, df_d = np.array_split(df, 4) + + # create a directory containing a flat sequence of parquet files without + # any partitioning involved + mockfs.create_dir('plain') + for i, chunk in enumerate(np.array_split(df_a, 10)): + path = 'plain/chunk-{}.parquet'.format(i) + with mockfs.open_output_stream(path) as out: + pq.write_table(_table_from_pandas(chunk), out) + + # create one with schema partitioning by weekday and color + mockfs.create_dir('schema') + for part, chunk in df_b.groupby([df_b.date.dt.dayofweek, df_b.color]): + folder = 'schema/{}/{}'.format(*part) + path = '{}/chunk.parquet'.format(folder) + mockfs.create_dir(folder) + with mockfs.open_output_stream(path) as out: + pq.write_table(_table_from_pandas(chunk), out) + + # create one with hive partitioning by year and month + mockfs.create_dir('hive') + for part, chunk in df_c.groupby([df_c.date.dt.year, df_c.date.dt.month]): + folder = 'hive/year={}/month={}'.format(*part) + path = '{}/chunk.parquet'.format(folder) + mockfs.create_dir(folder) + with mockfs.open_output_stream(path) as out: + pq.write_table(_table_from_pandas(chunk), out) + + # create one with hive partitioning by color + mockfs.create_dir('hive_color') + for part, chunk in df_d.groupby(["color"]): + folder = 'hive_color/color={}'.format(*part) + path = '{}/chunk.parquet'.format(folder) + mockfs.create_dir(folder) + with mockfs.open_output_stream(path) as out: + pq.write_table(_table_from_pandas(chunk), out) + + return mockfs + + +@pytest.fixture +def dataset(mockfs): + format = ds.ParquetFileFormat() + selector = fs.FileSelector('subdir', recursive=True) + options = ds.FileSystemFactoryOptions('subdir') + options.partitioning = ds.DirectoryPartitioning( + pa.schema([ + pa.field('group', pa.int32()), + pa.field('key', pa.string()) + ]) + ) + factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options) + return factory.finish() + + +@pytest.fixture(params=[ + (True, True), + (True, False), + (False, True), + (False, False) +], ids=['threaded-async', 'threaded-sync', 'serial-async', 'serial-sync']) +def dataset_reader(request): + ''' + Fixture which allows dataset scanning operations to be + run with/without threads and with/without async + ''' + use_threads, use_async = request.param + + class reader: + + def __init__(self): + self.use_threads = use_threads + self.use_async = use_async + + def _patch_kwargs(self, kwargs): + if 'use_threads' in kwargs: + raise Exception( + ('Invalid use of dataset_reader, do not specify' + ' use_threads')) + if 'use_async' in kwargs: + raise Exception( + 'Invalid use of dataset_reader, do not specify use_async') + kwargs['use_threads'] = use_threads + kwargs['use_async'] = use_async + + def to_table(self, dataset, **kwargs): + self._patch_kwargs(kwargs) + return dataset.to_table(**kwargs) + + def to_batches(self, dataset, **kwargs): + self._patch_kwargs(kwargs) + return dataset.to_batches(**kwargs) + + def scanner(self, dataset, **kwargs): + self._patch_kwargs(kwargs) + return dataset.scanner(**kwargs) + + def head(self, dataset, num_rows, **kwargs): + self._patch_kwargs(kwargs) + return dataset.head(num_rows, **kwargs) + + def take(self, dataset, indices, **kwargs): + self._patch_kwargs(kwargs) + return dataset.take(indices, **kwargs) + + def count_rows(self, dataset, **kwargs): + self._patch_kwargs(kwargs) + return dataset.count_rows(**kwargs) + + return reader() + + +def test_filesystem_dataset(mockfs): + schema = pa.schema([ + pa.field('const', pa.int64()) + ]) + file_format = ds.ParquetFileFormat() + paths = ['subdir/1/xxx/file0.parquet', 'subdir/2/yyy/file1.parquet'] + partitions = [ds.field('part') == x for x in range(1, 3)] + fragments = [file_format.make_fragment(path, mockfs, part) + for path, part in zip(paths, partitions)] + root_partition = ds.field('level') == ds.scalar(1337) + + dataset_from_fragments = ds.FileSystemDataset( + fragments, schema=schema, format=file_format, + filesystem=mockfs, root_partition=root_partition, + ) + dataset_from_paths = ds.FileSystemDataset.from_paths( + paths, schema=schema, format=file_format, filesystem=mockfs, + partitions=partitions, root_partition=root_partition, + ) + + for dataset in [dataset_from_fragments, dataset_from_paths]: + assert isinstance(dataset, ds.FileSystemDataset) + assert isinstance(dataset.format, ds.ParquetFileFormat) + assert dataset.partition_expression.equals(root_partition) + assert set(dataset.files) == set(paths) + + fragments = list(dataset.get_fragments()) + for fragment, partition, path in zip(fragments, partitions, paths): + assert fragment.partition_expression.equals(partition) + assert fragment.path == path + assert isinstance(fragment.format, ds.ParquetFileFormat) + assert isinstance(fragment, ds.ParquetFileFragment) + assert fragment.row_groups == [0] + assert fragment.num_row_groups == 1 + + row_group_fragments = list(fragment.split_by_row_group()) + assert fragment.num_row_groups == len(row_group_fragments) == 1 + assert isinstance(row_group_fragments[0], ds.ParquetFileFragment) + assert row_group_fragments[0].path == path + assert row_group_fragments[0].row_groups == [0] + assert row_group_fragments[0].num_row_groups == 1 + + fragments = list(dataset.get_fragments(filter=ds.field("const") == 0)) + assert len(fragments) == 2 + + # the root_partition keyword has a default + dataset = ds.FileSystemDataset( + fragments, schema=schema, format=file_format, filesystem=mockfs + ) + assert dataset.partition_expression.equals(ds.scalar(True)) + + # from_paths partitions have defaults + dataset = ds.FileSystemDataset.from_paths( + paths, schema=schema, format=file_format, filesystem=mockfs + ) + assert dataset.partition_expression.equals(ds.scalar(True)) + for fragment in dataset.get_fragments(): + assert fragment.partition_expression.equals(ds.scalar(True)) + + # validation of required arguments + with pytest.raises(TypeError, match="incorrect type"): + ds.FileSystemDataset(fragments, file_format, schema) + # validation of root_partition + with pytest.raises(TypeError, match="incorrect type"): + ds.FileSystemDataset(fragments, schema=schema, + format=file_format, root_partition=1) + # missing required argument in from_paths + with pytest.raises(TypeError, match="incorrect type"): + ds.FileSystemDataset.from_paths(fragments, format=file_format) + + +def test_filesystem_dataset_no_filesystem_interaction(dataset_reader): + # ARROW-8283 + schema = pa.schema([ + pa.field('f1', pa.int64()) + ]) + file_format = ds.IpcFileFormat() + paths = ['nonexistingfile.arrow'] + + # creating the dataset itself doesn't raise + dataset = ds.FileSystemDataset.from_paths( + paths, schema=schema, format=file_format, + filesystem=fs.LocalFileSystem(), + ) + + # getting fragments also doesn't raise + dataset.get_fragments() + + # scanning does raise + with pytest.raises(FileNotFoundError): + dataset_reader.to_table(dataset) + + +def test_dataset(dataset, dataset_reader): + assert isinstance(dataset, ds.Dataset) + assert isinstance(dataset.schema, pa.Schema) + + # TODO(kszucs): test non-boolean Exprs for filter do raise + expected_i64 = pa.array([0, 1, 2, 3, 4], type=pa.int64()) + expected_f64 = pa.array([0, 1, 2, 3, 4], type=pa.float64()) + + for batch in dataset_reader.to_batches(dataset): + assert isinstance(batch, pa.RecordBatch) + assert batch.column(0).equals(expected_i64) + assert batch.column(1).equals(expected_f64) + + for batch in dataset_reader.scanner(dataset).scan_batches(): + assert isinstance(batch, ds.TaggedRecordBatch) + assert isinstance(batch.fragment, ds.Fragment) + + table = dataset_reader.to_table(dataset) + assert isinstance(table, pa.Table) + assert len(table) == 10 + + condition = ds.field('i64') == 1 + result = dataset.to_table(use_threads=True, filter=condition).to_pydict() + + # don't rely on the scanning order + assert result['i64'] == [1, 1] + assert result['f64'] == [1., 1.] + assert sorted(result['group']) == [1, 2] + assert sorted(result['key']) == ['xxx', 'yyy'] + + +def test_scanner(dataset, dataset_reader): + scanner = dataset_reader.scanner( + dataset, memory_pool=pa.default_memory_pool()) + assert isinstance(scanner, ds.Scanner) + + with pytest.raises(pa.ArrowInvalid): + dataset_reader.scanner(dataset, columns=['unknown']) + + scanner = dataset_reader.scanner(dataset, columns=['i64'], + memory_pool=pa.default_memory_pool()) + assert scanner.dataset_schema == dataset.schema + assert scanner.projected_schema == pa.schema([("i64", pa.int64())]) + + assert isinstance(scanner, ds.Scanner) + table = scanner.to_table() + for batch in scanner.to_batches(): + assert batch.schema == scanner.projected_schema + assert batch.num_columns == 1 + assert table == scanner.to_reader().read_all() + + assert table.schema == scanner.projected_schema + for i in range(table.num_rows): + indices = pa.array([i]) + assert table.take(indices) == scanner.take(indices) + with pytest.raises(pa.ArrowIndexError): + scanner.take(pa.array([table.num_rows])) + + assert table.num_rows == scanner.count_rows() + + +def test_head(dataset, dataset_reader): + result = dataset_reader.head(dataset, 0) + assert result == pa.Table.from_batches([], schema=dataset.schema) + + result = dataset_reader.head(dataset, 1, columns=['i64']).to_pydict() + assert result == {'i64': [0]} + + result = dataset_reader.head(dataset, 2, columns=['i64'], + filter=ds.field('i64') > 1).to_pydict() + assert result == {'i64': [2, 3]} + + result = dataset_reader.head(dataset, 1024, columns=['i64']).to_pydict() + assert result == {'i64': list(range(5)) * 2} + + fragment = next(dataset.get_fragments()) + result = fragment.head(1, columns=['i64']).to_pydict() + assert result == {'i64': [0]} + + result = fragment.head(1024, columns=['i64']).to_pydict() + assert result == {'i64': list(range(5))} + + +def test_take(dataset, dataset_reader): + fragment = next(dataset.get_fragments()) + indices = pa.array([1, 3]) + assert dataset_reader.take( + fragment, indices) == dataset_reader.to_table(fragment).take(indices) + with pytest.raises(IndexError): + dataset_reader.take(fragment, pa.array([5])) + + indices = pa.array([1, 7]) + assert dataset_reader.take( + dataset, indices) == dataset_reader.to_table(dataset).take(indices) + with pytest.raises(IndexError): + dataset_reader.take(dataset, pa.array([10])) + + +def test_count_rows(dataset, dataset_reader): + fragment = next(dataset.get_fragments()) + assert dataset_reader.count_rows(fragment) == 5 + assert dataset_reader.count_rows( + fragment, filter=ds.field("i64") == 4) == 1 + + assert dataset_reader.count_rows(dataset) == 10 + # Filter on partition key + assert dataset_reader.count_rows( + dataset, filter=ds.field("group") == 1) == 5 + # Filter on data + assert dataset_reader.count_rows(dataset, filter=ds.field("i64") >= 3) == 4 + assert dataset_reader.count_rows(dataset, filter=ds.field("i64") < 0) == 0 + + +def test_abstract_classes(): + classes = [ + ds.FileFormat, + ds.Scanner, + ds.Partitioning, + ] + for klass in classes: + with pytest.raises(TypeError): + klass() + + +def test_partitioning(): + schema = pa.schema([ + pa.field('i64', pa.int64()), + pa.field('f64', pa.float64()) + ]) + for klass in [ds.DirectoryPartitioning, ds.HivePartitioning]: + partitioning = klass(schema) + assert isinstance(partitioning, ds.Partitioning) + + partitioning = ds.DirectoryPartitioning( + pa.schema([ + pa.field('group', pa.int64()), + pa.field('key', pa.float64()) + ]) + ) + assert partitioning.dictionaries is None + expr = partitioning.parse('/3/3.14') + assert isinstance(expr, ds.Expression) + + expected = (ds.field('group') == 3) & (ds.field('key') == 3.14) + assert expr.equals(expected) + + with pytest.raises(pa.ArrowInvalid): + partitioning.parse('/prefix/3/aaa') + + expr = partitioning.parse('/3') + expected = ds.field('group') == 3 + assert expr.equals(expected) + + partitioning = ds.HivePartitioning( + pa.schema([ + pa.field('alpha', pa.int64()), + pa.field('beta', pa.int64()) + ]), + null_fallback='xyz' + ) + assert partitioning.dictionaries is None + expr = partitioning.parse('/alpha=0/beta=3') + expected = ( + (ds.field('alpha') == ds.scalar(0)) & + (ds.field('beta') == ds.scalar(3)) + ) + assert expr.equals(expected) + + expr = partitioning.parse('/alpha=xyz/beta=3') + expected = ( + (ds.field('alpha').is_null() & (ds.field('beta') == ds.scalar(3))) + ) + assert expr.equals(expected) + + for shouldfail in ['/alpha=one/beta=2', '/alpha=one', '/beta=two']: + with pytest.raises(pa.ArrowInvalid): + partitioning.parse(shouldfail) + + +def test_expression_serialization(): + a = ds.scalar(1) + b = ds.scalar(1.1) + c = ds.scalar(True) + d = ds.scalar("string") + e = ds.scalar(None) + f = ds.scalar({'a': 1}) + g = ds.scalar(pa.scalar(1)) + + all_exprs = [a, b, c, d, e, f, g, a == b, a > b, a & b, a | b, ~c, + d.is_valid(), a.cast(pa.int32(), safe=False), + a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), + ds.field('i64') > 5, ds.field('i64') == 5, + ds.field('i64') == 7, ds.field('i64').is_null()] + for expr in all_exprs: + assert isinstance(expr, ds.Expression) + restored = pickle.loads(pickle.dumps(expr)) + assert expr.equals(restored) + + +def test_expression_construction(): + zero = ds.scalar(0) + one = ds.scalar(1) + true = ds.scalar(True) + false = ds.scalar(False) + string = ds.scalar("string") + field = ds.field("field") + + zero | one == string + ~true == false + for typ in ("bool", pa.bool_()): + field.cast(typ) == true + + field.isin([1, 2]) + + with pytest.raises(TypeError): + field.isin(1) + + with pytest.raises(pa.ArrowInvalid): + field != object() + + +def test_expression_boolean_operators(): + # https://issues.apache.org/jira/browse/ARROW-11412 + true = ds.scalar(True) + false = ds.scalar(False) + + with pytest.raises(ValueError, match="cannot be evaluated to python True"): + true and false + + with pytest.raises(ValueError, match="cannot be evaluated to python True"): + true or false + + with pytest.raises(ValueError, match="cannot be evaluated to python True"): + bool(true) + + with pytest.raises(ValueError, match="cannot be evaluated to python True"): + not true + + +def test_expression_arithmetic_operators(): + dataset = ds.dataset(pa.table({'a': [1, 2, 3], 'b': [2, 2, 2]})) + a = ds.field("a") + b = ds.field("b") + result = dataset.to_table(columns={ + "a+1": a + 1, + "b-a": b - a, + "a*2": a * 2, + "a/b": a.cast("float64") / b, + }) + expected = pa.table({ + "a+1": [2, 3, 4], "b-a": [1, 0, -1], + "a*2": [2, 4, 6], "a/b": [0.5, 1.0, 1.5], + }) + assert result.equals(expected) + + +def test_partition_keys(): + a, b, c = [ds.field(f) == f for f in 'abc'] + assert ds._get_partition_keys(a) == {'a': 'a'} + assert ds._get_partition_keys(a & b & c) == {f: f for f in 'abc'} + + nope = ds.field('d') >= 3 + assert ds._get_partition_keys(nope) == {} + assert ds._get_partition_keys(a & nope) == {'a': 'a'} + + null = ds.field('a').is_null() + assert ds._get_partition_keys(null) == {'a': None} + + +def test_parquet_read_options(): + opts1 = ds.ParquetReadOptions() + opts2 = ds.ParquetReadOptions(dictionary_columns=['a', 'b']) + opts3 = ds.ParquetReadOptions(coerce_int96_timestamp_unit="ms") + + assert opts1.dictionary_columns == set() + + assert opts2.dictionary_columns == {'a', 'b'} + + assert opts1.coerce_int96_timestamp_unit == "ns" + assert opts3.coerce_int96_timestamp_unit == "ms" + + assert opts1 == opts1 + assert opts1 != opts2 + assert opts1 != opts3 + + +def test_parquet_file_format_read_options(): + pff1 = ds.ParquetFileFormat() + pff2 = ds.ParquetFileFormat(dictionary_columns={'a'}) + pff3 = ds.ParquetFileFormat(coerce_int96_timestamp_unit="s") + + assert pff1.read_options == ds.ParquetReadOptions() + assert pff2.read_options == ds.ParquetReadOptions(dictionary_columns=['a']) + assert pff3.read_options == ds.ParquetReadOptions( + coerce_int96_timestamp_unit="s") + + +def test_parquet_scan_options(): + opts1 = ds.ParquetFragmentScanOptions() + opts2 = ds.ParquetFragmentScanOptions(buffer_size=4096) + opts3 = ds.ParquetFragmentScanOptions( + buffer_size=2**13, use_buffered_stream=True) + opts4 = ds.ParquetFragmentScanOptions(buffer_size=2**13, pre_buffer=True) + + assert opts1.use_buffered_stream is False + assert opts1.buffer_size == 2**13 + assert opts1.pre_buffer is False + + assert opts2.use_buffered_stream is False + assert opts2.buffer_size == 2**12 + assert opts2.pre_buffer is False + + assert opts3.use_buffered_stream is True + assert opts3.buffer_size == 2**13 + assert opts3.pre_buffer is False + + assert opts4.use_buffered_stream is False + assert opts4.buffer_size == 2**13 + assert opts4.pre_buffer is True + + assert opts1 == opts1 + assert opts1 != opts2 + assert opts2 != opts3 + assert opts3 != opts4 + + +def test_file_format_pickling(): + formats = [ + ds.IpcFileFormat(), + ds.CsvFileFormat(), + ds.CsvFileFormat(pa.csv.ParseOptions(delimiter='\t', + ignore_empty_lines=True)), + ds.CsvFileFormat(read_options=pa.csv.ReadOptions( + skip_rows=3, column_names=['foo'])), + ds.CsvFileFormat(read_options=pa.csv.ReadOptions( + skip_rows=3, block_size=2**20)), + ds.ParquetFileFormat(), + ds.ParquetFileFormat(dictionary_columns={'a'}), + ds.ParquetFileFormat(use_buffered_stream=True), + ds.ParquetFileFormat( + use_buffered_stream=True, + buffer_size=4096, + ) + ] + try: + formats.append(ds.OrcFileFormat()) + except (ImportError, AttributeError): + # catch AttributeError for Python 3.6 + pass + + for file_format in formats: + assert pickle.loads(pickle.dumps(file_format)) == file_format + + +def test_fragment_scan_options_pickling(): + options = [ + ds.CsvFragmentScanOptions(), + ds.CsvFragmentScanOptions( + convert_options=pa.csv.ConvertOptions(strings_can_be_null=True)), + ds.CsvFragmentScanOptions( + read_options=pa.csv.ReadOptions(block_size=2**16)), + ds.ParquetFragmentScanOptions(buffer_size=4096), + ds.ParquetFragmentScanOptions(pre_buffer=True), + ] + for option in options: + assert pickle.loads(pickle.dumps(option)) == option + + +@pytest.mark.parametrize('paths_or_selector', [ + fs.FileSelector('subdir', recursive=True), + [ + 'subdir/1/xxx/file0.parquet', + 'subdir/2/yyy/file1.parquet', + ] +]) +@pytest.mark.parametrize('pre_buffer', [False, True]) +def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer): + format = ds.ParquetFileFormat( + read_options=ds.ParquetReadOptions(dictionary_columns={"str"}), + pre_buffer=pre_buffer + ) + + options = ds.FileSystemFactoryOptions('subdir') + options.partitioning = ds.DirectoryPartitioning( + pa.schema([ + pa.field('group', pa.int32()), + pa.field('key', pa.string()) + ]) + ) + assert options.partition_base_dir == 'subdir' + assert options.selector_ignore_prefixes == ['.', '_'] + assert options.exclude_invalid_files is False + + factory = ds.FileSystemDatasetFactory( + mockfs, paths_or_selector, format, options + ) + inspected_schema = factory.inspect() + + assert factory.inspect().equals(pa.schema([ + pa.field('i64', pa.int64()), + pa.field('f64', pa.float64()), + pa.field('str', pa.dictionary(pa.int32(), pa.string())), + pa.field('const', pa.int64()), + pa.field('group', pa.int32()), + pa.field('key', pa.string()), + ]), check_metadata=False) + + assert isinstance(factory.inspect_schemas(), list) + assert isinstance(factory.finish(inspected_schema), + ds.FileSystemDataset) + assert factory.root_partition.equals(ds.scalar(True)) + + dataset = factory.finish() + assert isinstance(dataset, ds.FileSystemDataset) + + scanner = dataset.scanner() + expected_i64 = pa.array([0, 1, 2, 3, 4], type=pa.int64()) + expected_f64 = pa.array([0, 1, 2, 3, 4], type=pa.float64()) + expected_str = pa.DictionaryArray.from_arrays( + pa.array([0, 1, 2, 3, 4], type=pa.int32()), + pa.array("0 1 2 3 4".split(), type=pa.string()) + ) + iterator = scanner.scan_batches() + for (batch, fragment), group, key in zip(iterator, [1, 2], ['xxx', 'yyy']): + expected_group = pa.array([group] * 5, type=pa.int32()) + expected_key = pa.array([key] * 5, type=pa.string()) + expected_const = pa.array([group - 1] * 5, type=pa.int64()) + # Can't compare or really introspect expressions from Python + assert fragment.partition_expression is not None + assert batch.num_columns == 6 + assert batch[0].equals(expected_i64) + assert batch[1].equals(expected_f64) + assert batch[2].equals(expected_str) + assert batch[3].equals(expected_const) + assert batch[4].equals(expected_group) + assert batch[5].equals(expected_key) + + table = dataset.to_table() + assert isinstance(table, pa.Table) + assert len(table) == 10 + assert table.num_columns == 6 + + +def test_make_fragment(multisourcefs): + parquet_format = ds.ParquetFileFormat() + dataset = ds.dataset('/plain', filesystem=multisourcefs, + format=parquet_format) + + for path in dataset.files: + fragment = parquet_format.make_fragment(path, multisourcefs) + assert fragment.row_groups == [0] + + row_group_fragment = parquet_format.make_fragment(path, multisourcefs, + row_groups=[0]) + for f in [fragment, row_group_fragment]: + assert isinstance(f, ds.ParquetFileFragment) + assert f.path == path + assert isinstance(f.filesystem, type(multisourcefs)) + assert row_group_fragment.row_groups == [0] + + +def test_make_csv_fragment_from_buffer(dataset_reader): + content = textwrap.dedent(""" + alpha,num,animal + a,12,dog + b,11,cat + c,10,rabbit + """) + buffer = pa.py_buffer(content.encode('utf-8')) + + csv_format = ds.CsvFileFormat() + fragment = csv_format.make_fragment(buffer) + + expected = pa.table([['a', 'b', 'c'], + [12, 11, 10], + ['dog', 'cat', 'rabbit']], + names=['alpha', 'num', 'animal']) + assert dataset_reader.to_table(fragment).equals(expected) + + pickled = pickle.loads(pickle.dumps(fragment)) + assert dataset_reader.to_table(pickled).equals(fragment.to_table()) + + +@pytest.mark.parquet +def test_make_parquet_fragment_from_buffer(dataset_reader): + import pyarrow.parquet as pq + + arrays = [ + pa.array(['a', 'b', 'c']), + pa.array([12, 11, 10]), + pa.array(['dog', 'cat', 'rabbit']) + ] + dictionary_arrays = [ + arrays[0].dictionary_encode(), + arrays[1], + arrays[2].dictionary_encode() + ] + dictionary_format = ds.ParquetFileFormat( + read_options=ds.ParquetReadOptions( + dictionary_columns=['alpha', 'animal'] + ), + use_buffered_stream=True, + buffer_size=4096, + ) + + cases = [ + (arrays, ds.ParquetFileFormat()), + (dictionary_arrays, dictionary_format) + ] + for arrays, format_ in cases: + table = pa.table(arrays, names=['alpha', 'num', 'animal']) + + out = pa.BufferOutputStream() + pq.write_table(table, out) + buffer = out.getvalue() + + fragment = format_.make_fragment(buffer) + assert dataset_reader.to_table(fragment).equals(table) + + pickled = pickle.loads(pickle.dumps(fragment)) + assert dataset_reader.to_table(pickled).equals(table) + + +def _create_dataset_for_fragments(tempdir, chunk_size=None, filesystem=None): + import pyarrow.parquet as pq + + table = pa.table( + [range(8), [1] * 8, ['a'] * 4 + ['b'] * 4], + names=['f1', 'f2', 'part'] + ) + + path = str(tempdir / "test_parquet_dataset") + + # write_to_dataset currently requires pandas + pq.write_to_dataset(table, path, + partition_cols=["part"], chunk_size=chunk_size) + dataset = ds.dataset( + path, format="parquet", partitioning="hive", filesystem=filesystem + ) + + return table, dataset + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments(tempdir, dataset_reader): + table, dataset = _create_dataset_for_fragments(tempdir) + + # list fragments + fragments = list(dataset.get_fragments()) + assert len(fragments) == 2 + f = fragments[0] + + physical_names = ['f1', 'f2'] + # file's schema does not include partition column + assert f.physical_schema.names == physical_names + assert f.format.inspect(f.path, f.filesystem) == f.physical_schema + assert f.partition_expression.equals(ds.field('part') == 'a') + + # By default, the partition column is not part of the schema. + result = dataset_reader.to_table(f) + assert result.column_names == physical_names + assert result.equals(table.remove_column(2).slice(0, 4)) + + # scanning fragment includes partition columns when given the proper + # schema. + result = dataset_reader.to_table(f, schema=dataset.schema) + assert result.column_names == ['f1', 'f2', 'part'] + assert result.equals(table.slice(0, 4)) + assert f.physical_schema == result.schema.remove(2) + + # scanning fragments follow filter predicate + result = dataset_reader.to_table( + f, schema=dataset.schema, filter=ds.field('f1') < 2) + assert result.column_names == ['f1', 'f2', 'part'] + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_implicit_cast(tempdir): + # ARROW-8693 + import pyarrow.parquet as pq + + table = pa.table([range(8), [1] * 4 + [2] * 4], names=['col', 'part']) + path = str(tempdir / "test_parquet_dataset") + pq.write_to_dataset(table, path, partition_cols=["part"]) + + part = ds.partitioning(pa.schema([('part', 'int8')]), flavor="hive") + dataset = ds.dataset(path, format="parquet", partitioning=part) + fragments = dataset.get_fragments(filter=ds.field("part") >= 2) + assert len(list(fragments)) == 1 + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_reconstruct(tempdir, dataset_reader): + table, dataset = _create_dataset_for_fragments(tempdir) + + def assert_yields_projected(fragment, row_slice, + columns=None, filter=None): + actual = fragment.to_table( + schema=table.schema, columns=columns, filter=filter) + column_names = columns if columns else table.column_names + assert actual.column_names == column_names + + expected = table.slice(*row_slice).select(column_names) + assert actual.equals(expected) + + fragment = list(dataset.get_fragments())[0] + parquet_format = fragment.format + + # test pickle roundtrip + pickled_fragment = pickle.loads(pickle.dumps(fragment)) + assert dataset_reader.to_table( + pickled_fragment) == dataset_reader.to_table(fragment) + + # manually re-construct a fragment, with explicit schema + new_fragment = parquet_format.make_fragment( + fragment.path, fragment.filesystem, + partition_expression=fragment.partition_expression) + assert dataset_reader.to_table(new_fragment).equals( + dataset_reader.to_table(fragment)) + assert_yields_projected(new_fragment, (0, 4)) + + # filter / column projection, inspected schema + new_fragment = parquet_format.make_fragment( + fragment.path, fragment.filesystem, + partition_expression=fragment.partition_expression) + assert_yields_projected(new_fragment, (0, 2), filter=ds.field('f1') < 2) + + # filter requiring cast / column projection, inspected schema + new_fragment = parquet_format.make_fragment( + fragment.path, fragment.filesystem, + partition_expression=fragment.partition_expression) + assert_yields_projected(new_fragment, (0, 2), + columns=['f1'], filter=ds.field('f1') < 2.0) + + # filter on the partition column + new_fragment = parquet_format.make_fragment( + fragment.path, fragment.filesystem, + partition_expression=fragment.partition_expression) + assert_yields_projected(new_fragment, (0, 4), + filter=ds.field('part') == 'a') + + # Fragments don't contain the partition's columns if not provided to the + # `to_table(schema=...)` method. + pattern = (r'No match for FieldRef.Name\(part\) in ' + + fragment.physical_schema.to_string(False, False, False)) + with pytest.raises(ValueError, match=pattern): + new_fragment = parquet_format.make_fragment( + fragment.path, fragment.filesystem, + partition_expression=fragment.partition_expression) + dataset_reader.to_table(new_fragment, filter=ds.field('part') == 'a') + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_parquet_row_groups(tempdir, dataset_reader): + table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2) + + fragment = list(dataset.get_fragments())[0] + + # list and scan row group fragments + row_group_fragments = list(fragment.split_by_row_group()) + assert len(row_group_fragments) == fragment.num_row_groups == 2 + result = dataset_reader.to_table( + row_group_fragments[0], schema=dataset.schema) + assert result.column_names == ['f1', 'f2', 'part'] + assert len(result) == 2 + assert result.equals(table.slice(0, 2)) + + assert row_group_fragments[0].row_groups is not None + assert row_group_fragments[0].num_row_groups == 1 + assert row_group_fragments[0].row_groups[0].statistics == { + 'f1': {'min': 0, 'max': 1}, + 'f2': {'min': 1, 'max': 1}, + } + + fragment = list(dataset.get_fragments(filter=ds.field('f1') < 1))[0] + row_group_fragments = list(fragment.split_by_row_group(ds.field('f1') < 1)) + assert len(row_group_fragments) == 1 + result = dataset_reader.to_table( + row_group_fragments[0], filter=ds.field('f1') < 1) + assert len(result) == 1 + + +@pytest.mark.parquet +def test_fragments_parquet_num_row_groups(tempdir): + import pyarrow.parquet as pq + + table = pa.table({'a': range(8)}) + pq.write_table(table, tempdir / "test.parquet", row_group_size=2) + dataset = ds.dataset(tempdir / "test.parquet", format="parquet") + original_fragment = list(dataset.get_fragments())[0] + + # create fragment with subset of row groups + fragment = original_fragment.format.make_fragment( + original_fragment.path, original_fragment.filesystem, + row_groups=[1, 3]) + assert fragment.num_row_groups == 2 + # ensure that parsing metadata preserves correct number of row groups + fragment.ensure_complete_metadata() + assert fragment.num_row_groups == 2 + assert len(fragment.row_groups) == 2 + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_parquet_row_groups_dictionary(tempdir, dataset_reader): + import pandas as pd + + df = pd.DataFrame(dict(col1=['a', 'b'], col2=[1, 2])) + df['col1'] = df['col1'].astype("category") + + import pyarrow.parquet as pq + pq.write_table(pa.table(df), tempdir / "test_filter_dictionary.parquet") + + import pyarrow.dataset as ds + dataset = ds.dataset(tempdir / 'test_filter_dictionary.parquet') + result = dataset_reader.to_table(dataset, filter=ds.field("col1") == "a") + + assert (df.iloc[0] == result.to_pandas()).all().all() + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_parquet_ensure_metadata(tempdir, open_logging_fs): + fs, assert_opens = open_logging_fs + _, dataset = _create_dataset_for_fragments( + tempdir, chunk_size=2, filesystem=fs + ) + fragment = list(dataset.get_fragments())[0] + + # with default discovery, no metadata loaded + with assert_opens([fragment.path]): + fragment.ensure_complete_metadata() + assert fragment.row_groups == [0, 1] + + # second time -> use cached / no file IO + with assert_opens([]): + fragment.ensure_complete_metadata() + + # recreate fragment with row group ids + new_fragment = fragment.format.make_fragment( + fragment.path, fragment.filesystem, row_groups=[0, 1] + ) + assert new_fragment.row_groups == fragment.row_groups + + # collect metadata + new_fragment.ensure_complete_metadata() + row_group = new_fragment.row_groups[0] + assert row_group.id == 0 + assert row_group.num_rows == 2 + assert row_group.statistics is not None + + # pickling preserves row group ids + pickled_fragment = pickle.loads(pickle.dumps(new_fragment)) + with assert_opens([fragment.path]): + assert pickled_fragment.row_groups == [0, 1] + row_group = pickled_fragment.row_groups[0] + assert row_group.id == 0 + assert row_group.statistics is not None + + +def _create_dataset_all_types(tempdir, chunk_size=None): + import pyarrow.parquet as pq + + table = pa.table( + [ + pa.array([True, None, False], pa.bool_()), + pa.array([1, 10, 42], pa.int8()), + pa.array([1, 10, 42], pa.uint8()), + pa.array([1, 10, 42], pa.int16()), + pa.array([1, 10, 42], pa.uint16()), + pa.array([1, 10, 42], pa.int32()), + pa.array([1, 10, 42], pa.uint32()), + pa.array([1, 10, 42], pa.int64()), + pa.array([1, 10, 42], pa.uint64()), + pa.array([1.0, 10.0, 42.0], pa.float32()), + pa.array([1.0, 10.0, 42.0], pa.float64()), + pa.array(['a', None, 'z'], pa.utf8()), + pa.array(['a', None, 'z'], pa.binary()), + pa.array([1, 10, 42], pa.timestamp('s')), + pa.array([1, 10, 42], pa.timestamp('ms')), + pa.array([1, 10, 42], pa.timestamp('us')), + pa.array([1, 10, 42], pa.date32()), + pa.array([1, 10, 4200000000], pa.date64()), + pa.array([1, 10, 42], pa.time32('s')), + pa.array([1, 10, 42], pa.time64('us')), + ], + names=[ + 'boolean', + 'int8', + 'uint8', + 'int16', + 'uint16', + 'int32', + 'uint32', + 'int64', + 'uint64', + 'float', + 'double', + 'utf8', + 'binary', + 'ts[s]', + 'ts[ms]', + 'ts[us]', + 'date32', + 'date64', + 'time32', + 'time64', + ] + ) + + path = str(tempdir / "test_parquet_dataset_all_types") + + # write_to_dataset currently requires pandas + pq.write_to_dataset(table, path, chunk_size=chunk_size) + + return table, ds.dataset(path, format="parquet", partitioning="hive") + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_parquet_fragment_statistics(tempdir): + table, dataset = _create_dataset_all_types(tempdir) + + fragment = list(dataset.get_fragments())[0] + + import datetime + def dt_s(x): return datetime.datetime(1970, 1, 1, 0, 0, x) + def dt_ms(x): return datetime.datetime(1970, 1, 1, 0, 0, 0, x*1000) + def dt_us(x): return datetime.datetime(1970, 1, 1, 0, 0, 0, x) + date = datetime.date + time = datetime.time + + # list and scan row group fragments + row_group_fragments = list(fragment.split_by_row_group()) + assert row_group_fragments[0].row_groups is not None + row_group = row_group_fragments[0].row_groups[0] + assert row_group.num_rows == 3 + assert row_group.total_byte_size > 1000 + assert row_group.statistics == { + 'boolean': {'min': False, 'max': True}, + 'int8': {'min': 1, 'max': 42}, + 'uint8': {'min': 1, 'max': 42}, + 'int16': {'min': 1, 'max': 42}, + 'uint16': {'min': 1, 'max': 42}, + 'int32': {'min': 1, 'max': 42}, + 'uint32': {'min': 1, 'max': 42}, + 'int64': {'min': 1, 'max': 42}, + 'uint64': {'min': 1, 'max': 42}, + 'float': {'min': 1.0, 'max': 42.0}, + 'double': {'min': 1.0, 'max': 42.0}, + 'utf8': {'min': 'a', 'max': 'z'}, + 'binary': {'min': b'a', 'max': b'z'}, + 'ts[s]': {'min': dt_s(1), 'max': dt_s(42)}, + 'ts[ms]': {'min': dt_ms(1), 'max': dt_ms(42)}, + 'ts[us]': {'min': dt_us(1), 'max': dt_us(42)}, + 'date32': {'min': date(1970, 1, 2), 'max': date(1970, 2, 12)}, + 'date64': {'min': date(1970, 1, 1), 'max': date(1970, 2, 18)}, + 'time32': {'min': time(0, 0, 1), 'max': time(0, 0, 42)}, + 'time64': {'min': time(0, 0, 0, 1), 'max': time(0, 0, 0, 42)}, + } + + +@pytest.mark.parquet +def test_parquet_fragment_statistics_nulls(tempdir): + import pyarrow.parquet as pq + + table = pa.table({'a': [0, 1, None, None], 'b': ['a', 'b', None, None]}) + pq.write_table(table, tempdir / "test.parquet", row_group_size=2) + + dataset = ds.dataset(tempdir / "test.parquet", format="parquet") + fragments = list(dataset.get_fragments())[0].split_by_row_group() + # second row group has all nulls -> no statistics + assert fragments[1].row_groups[0].statistics == {} + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_parquet_empty_row_group_statistics(tempdir): + df = pd.DataFrame({"a": ["a", "b", "b"], "b": [4, 5, 6]})[:0] + df.to_parquet(tempdir / "test.parquet", engine="pyarrow") + + dataset = ds.dataset(tempdir / "test.parquet", format="parquet") + fragments = list(dataset.get_fragments())[0].split_by_row_group() + # Only row group is empty + assert fragments[0].row_groups[0].statistics == {} + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_parquet_row_groups_predicate(tempdir): + table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2) + + fragment = list(dataset.get_fragments())[0] + assert fragment.partition_expression.equals(ds.field('part') == 'a') + + # predicate may reference a partition field not present in the + # physical_schema if an explicit schema is provided to split_by_row_group + + # filter matches partition_expression: all row groups + row_group_fragments = list( + fragment.split_by_row_group(filter=ds.field('part') == 'a', + schema=dataset.schema)) + assert len(row_group_fragments) == 2 + + # filter contradicts partition_expression: no row groups + row_group_fragments = list( + fragment.split_by_row_group(filter=ds.field('part') == 'b', + schema=dataset.schema)) + assert len(row_group_fragments) == 0 + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_parquet_row_groups_reconstruct(tempdir, dataset_reader): + table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=2) + + fragment = list(dataset.get_fragments())[0] + parquet_format = fragment.format + row_group_fragments = list(fragment.split_by_row_group()) + + # test pickle roundtrip + pickled_fragment = pickle.loads(pickle.dumps(fragment)) + assert dataset_reader.to_table( + pickled_fragment) == dataset_reader.to_table(fragment) + + # manually re-construct row group fragments + new_fragment = parquet_format.make_fragment( + fragment.path, fragment.filesystem, + partition_expression=fragment.partition_expression, + row_groups=[0]) + result = dataset_reader.to_table(new_fragment) + assert result.equals(dataset_reader.to_table(row_group_fragments[0])) + + # manually re-construct a row group fragment with filter/column projection + new_fragment = parquet_format.make_fragment( + fragment.path, fragment.filesystem, + partition_expression=fragment.partition_expression, + row_groups={1}) + result = dataset_reader.to_table( + new_fragment, schema=table.schema, columns=['f1', 'part'], + filter=ds.field('f1') < 3, ) + assert result.column_names == ['f1', 'part'] + assert len(result) == 1 + + # out of bounds row group index + new_fragment = parquet_format.make_fragment( + fragment.path, fragment.filesystem, + partition_expression=fragment.partition_expression, + row_groups={2}) + with pytest.raises(IndexError, match="references row group 2"): + dataset_reader.to_table(new_fragment) + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_parquet_subset_ids(tempdir, open_logging_fs, + dataset_reader): + fs, assert_opens = open_logging_fs + table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=1, + filesystem=fs) + fragment = list(dataset.get_fragments())[0] + + # select with row group ids + subfrag = fragment.subset(row_group_ids=[0, 3]) + with assert_opens([]): + assert subfrag.num_row_groups == 2 + assert subfrag.row_groups == [0, 3] + assert subfrag.row_groups[0].statistics is not None + + # check correct scan result of subset + result = dataset_reader.to_table(subfrag) + assert result.to_pydict() == {"f1": [0, 3], "f2": [1, 1]} + + # empty list of ids + subfrag = fragment.subset(row_group_ids=[]) + assert subfrag.num_row_groups == 0 + assert subfrag.row_groups == [] + result = dataset_reader.to_table(subfrag, schema=dataset.schema) + assert result.num_rows == 0 + assert result.equals(table[:0]) + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_parquet_subset_filter(tempdir, open_logging_fs, + dataset_reader): + fs, assert_opens = open_logging_fs + table, dataset = _create_dataset_for_fragments(tempdir, chunk_size=1, + filesystem=fs) + fragment = list(dataset.get_fragments())[0] + + # select with filter + subfrag = fragment.subset(ds.field("f1") >= 1) + with assert_opens([]): + assert subfrag.num_row_groups == 3 + assert len(subfrag.row_groups) == 3 + assert subfrag.row_groups[0].statistics is not None + + # check correct scan result of subset + result = dataset_reader.to_table(subfrag) + assert result.to_pydict() == {"f1": [1, 2, 3], "f2": [1, 1, 1]} + + # filter that results in empty selection + subfrag = fragment.subset(ds.field("f1") > 5) + assert subfrag.num_row_groups == 0 + assert subfrag.row_groups == [] + result = dataset_reader.to_table(subfrag, schema=dataset.schema) + assert result.num_rows == 0 + assert result.equals(table[:0]) + + # passing schema to ensure filter on partition expression works + subfrag = fragment.subset(ds.field("part") == "a", schema=dataset.schema) + assert subfrag.num_row_groups == 4 + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_parquet_subset_invalid(tempdir): + _, dataset = _create_dataset_for_fragments(tempdir, chunk_size=1) + fragment = list(dataset.get_fragments())[0] + + # passing none or both of filter / row_group_ids + with pytest.raises(ValueError): + fragment.subset(ds.field("f1") >= 1, row_group_ids=[1, 2]) + + with pytest.raises(ValueError): + fragment.subset() + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_fragments_repr(tempdir, dataset): + # partitioned parquet dataset + fragment = list(dataset.get_fragments())[0] + assert ( + repr(fragment) == + "<pyarrow.dataset.ParquetFileFragment path=subdir/1/xxx/file0.parquet " + "partition=[key=xxx, group=1]>" + ) + + # single-file parquet dataset (no partition information in repr) + table, path = _create_single_file(tempdir) + dataset = ds.dataset(path, format="parquet") + fragment = list(dataset.get_fragments())[0] + assert ( + repr(fragment) == + "<pyarrow.dataset.ParquetFileFragment path={}>".format( + dataset.filesystem.normalize_path(str(path))) + ) + + # non-parquet format + path = tempdir / "data.feather" + pa.feather.write_feather(table, path) + dataset = ds.dataset(path, format="feather") + fragment = list(dataset.get_fragments())[0] + assert ( + repr(fragment) == + "<pyarrow.dataset.FileFragment type=ipc path={}>".format( + dataset.filesystem.normalize_path(str(path))) + ) + + +def test_partitioning_factory(mockfs): + paths_or_selector = fs.FileSelector('subdir', recursive=True) + format = ds.ParquetFileFormat() + + options = ds.FileSystemFactoryOptions('subdir') + partitioning_factory = ds.DirectoryPartitioning.discover(['group', 'key']) + assert isinstance(partitioning_factory, ds.PartitioningFactory) + options.partitioning_factory = partitioning_factory + + factory = ds.FileSystemDatasetFactory( + mockfs, paths_or_selector, format, options + ) + inspected_schema = factory.inspect() + # i64/f64 from data, group/key from "/1/xxx" and "/2/yyy" paths + expected_schema = pa.schema([ + ("i64", pa.int64()), + ("f64", pa.float64()), + ("str", pa.string()), + ("const", pa.int64()), + ("group", pa.int32()), + ("key", pa.string()), + ]) + assert inspected_schema.equals(expected_schema) + + hive_partitioning_factory = ds.HivePartitioning.discover() + assert isinstance(hive_partitioning_factory, ds.PartitioningFactory) + + +@pytest.mark.parametrize('infer_dictionary', [False, True]) +def test_partitioning_factory_dictionary(mockfs, infer_dictionary): + paths_or_selector = fs.FileSelector('subdir', recursive=True) + format = ds.ParquetFileFormat() + options = ds.FileSystemFactoryOptions('subdir') + + options.partitioning_factory = ds.DirectoryPartitioning.discover( + ['group', 'key'], infer_dictionary=infer_dictionary) + + factory = ds.FileSystemDatasetFactory( + mockfs, paths_or_selector, format, options) + + inferred_schema = factory.inspect() + if infer_dictionary: + expected_type = pa.dictionary(pa.int32(), pa.string()) + assert inferred_schema.field('key').type == expected_type + + table = factory.finish().to_table().combine_chunks() + actual = table.column('key').chunk(0) + expected = pa.array(['xxx'] * 5 + ['yyy'] * 5).dictionary_encode() + assert actual.equals(expected) + + # ARROW-9345 ensure filtering on the partition field works + table = factory.finish().to_table(filter=ds.field('key') == 'xxx') + actual = table.column('key').chunk(0) + expected = expected.slice(0, 5) + assert actual.equals(expected) + else: + assert inferred_schema.field('key').type == pa.string() + + +def test_partitioning_factory_segment_encoding(): + mockfs = fs._MockFileSystem() + format = ds.IpcFileFormat() + schema = pa.schema([("i64", pa.int64())]) + table = pa.table([pa.array(range(10))], schema=schema) + partition_schema = pa.schema( + [("date", pa.timestamp("s")), ("string", pa.string())]) + string_partition_schema = pa.schema( + [("date", pa.string()), ("string", pa.string())]) + full_schema = pa.schema(list(schema) + list(partition_schema)) + for directory in [ + "directory/2021-05-04 00%3A00%3A00/%24", + "hive/date=2021-05-04 00%3A00%3A00/string=%24", + ]: + mockfs.create_dir(directory) + with mockfs.open_output_stream(directory + "/0.feather") as sink: + with pa.ipc.new_file(sink, schema) as writer: + writer.write_table(table) + writer.close() + + # Directory + selector = fs.FileSelector("directory", recursive=True) + options = ds.FileSystemFactoryOptions("directory") + options.partitioning_factory = ds.DirectoryPartitioning.discover( + schema=partition_schema) + factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options) + inferred_schema = factory.inspect() + assert inferred_schema == full_schema + actual = factory.finish().to_table(columns={ + "date_int": ds.field("date").cast(pa.int64()), + }) + assert actual[0][0].as_py() == 1620086400 + + options.partitioning_factory = ds.DirectoryPartitioning.discover( + ["date", "string"], segment_encoding="none") + factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options) + fragments = list(factory.finish().get_fragments()) + assert fragments[0].partition_expression.equals( + (ds.field("date") == "2021-05-04 00%3A00%3A00") & + (ds.field("string") == "%24")) + + options.partitioning = ds.DirectoryPartitioning( + string_partition_schema, segment_encoding="none") + factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options) + fragments = list(factory.finish().get_fragments()) + assert fragments[0].partition_expression.equals( + (ds.field("date") == "2021-05-04 00%3A00%3A00") & + (ds.field("string") == "%24")) + + options.partitioning_factory = ds.DirectoryPartitioning.discover( + schema=partition_schema, segment_encoding="none") + factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options) + with pytest.raises(pa.ArrowInvalid, + match="Could not cast segments for partition field"): + inferred_schema = factory.inspect() + + # Hive + selector = fs.FileSelector("hive", recursive=True) + options = ds.FileSystemFactoryOptions("hive") + options.partitioning_factory = ds.HivePartitioning.discover( + schema=partition_schema) + factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options) + inferred_schema = factory.inspect() + assert inferred_schema == full_schema + actual = factory.finish().to_table(columns={ + "date_int": ds.field("date").cast(pa.int64()), + }) + assert actual[0][0].as_py() == 1620086400 + + options.partitioning_factory = ds.HivePartitioning.discover( + segment_encoding="none") + factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options) + fragments = list(factory.finish().get_fragments()) + assert fragments[0].partition_expression.equals( + (ds.field("date") == "2021-05-04 00%3A00%3A00") & + (ds.field("string") == "%24")) + + options.partitioning = ds.HivePartitioning( + string_partition_schema, segment_encoding="none") + factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options) + fragments = list(factory.finish().get_fragments()) + assert fragments[0].partition_expression.equals( + (ds.field("date") == "2021-05-04 00%3A00%3A00") & + (ds.field("string") == "%24")) + + options.partitioning_factory = ds.HivePartitioning.discover( + schema=partition_schema, segment_encoding="none") + factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options) + with pytest.raises(pa.ArrowInvalid, + match="Could not cast segments for partition field"): + inferred_schema = factory.inspect() + + +def test_dictionary_partitioning_outer_nulls_raises(tempdir): + table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z']}) + part = ds.partitioning( + pa.schema([pa.field('a', pa.string()), pa.field('b', pa.string())])) + with pytest.raises(pa.ArrowInvalid): + ds.write_dataset(table, tempdir, format='parquet', partitioning=part) + + +def _has_subdirs(basedir): + elements = os.listdir(basedir) + return any([os.path.isdir(os.path.join(basedir, el)) for el in elements]) + + +def _do_list_all_dirs(basedir, path_so_far, result): + for f in os.listdir(basedir): + true_nested = os.path.join(basedir, f) + if os.path.isdir(true_nested): + norm_nested = posixpath.join(path_so_far, f) + if _has_subdirs(true_nested): + _do_list_all_dirs(true_nested, norm_nested, result) + else: + result.append(norm_nested) + + +def _list_all_dirs(basedir): + result = [] + _do_list_all_dirs(basedir, '', result) + return result + + +def _check_dataset_directories(tempdir, expected_directories): + actual_directories = set(_list_all_dirs(tempdir)) + assert actual_directories == set(expected_directories) + + +def test_dictionary_partitioning_inner_nulls(tempdir): + table = pa.table({'a': ['x', 'y', 'z'], 'b': ['x', 'y', None]}) + part = ds.partitioning( + pa.schema([pa.field('a', pa.string()), pa.field('b', pa.string())])) + ds.write_dataset(table, tempdir, format='parquet', partitioning=part) + _check_dataset_directories(tempdir, ['x/x', 'y/y', 'z']) + + +def test_hive_partitioning_nulls(tempdir): + table = pa.table({'a': ['x', None, 'z'], 'b': ['x', 'y', None]}) + part = ds.HivePartitioning(pa.schema( + [pa.field('a', pa.string()), pa.field('b', pa.string())]), None, 'xyz') + ds.write_dataset(table, tempdir, format='parquet', partitioning=part) + _check_dataset_directories(tempdir, ['a=x/b=x', 'a=xyz/b=y', 'a=z/b=xyz']) + + +def test_partitioning_function(): + schema = pa.schema([("year", pa.int16()), ("month", pa.int8())]) + names = ["year", "month"] + + # default DirectoryPartitioning + part = ds.partitioning(schema) + assert isinstance(part, ds.DirectoryPartitioning) + part = ds.partitioning(schema, dictionaries="infer") + assert isinstance(part, ds.PartitioningFactory) + part = ds.partitioning(field_names=names) + assert isinstance(part, ds.PartitioningFactory) + # needs schema or list of names + with pytest.raises(ValueError): + ds.partitioning() + with pytest.raises(ValueError, match="Expected list"): + ds.partitioning(field_names=schema) + with pytest.raises(ValueError, match="Cannot specify both"): + ds.partitioning(schema, field_names=schema) + + # Hive partitioning + part = ds.partitioning(schema, flavor="hive") + assert isinstance(part, ds.HivePartitioning) + part = ds.partitioning(schema, dictionaries="infer", flavor="hive") + assert isinstance(part, ds.PartitioningFactory) + part = ds.partitioning(flavor="hive") + assert isinstance(part, ds.PartitioningFactory) + # cannot pass list of names + with pytest.raises(ValueError): + ds.partitioning(names, flavor="hive") + with pytest.raises(ValueError, match="Cannot specify 'field_names'"): + ds.partitioning(field_names=names, flavor="hive") + + # unsupported flavor + with pytest.raises(ValueError): + ds.partitioning(schema, flavor="unsupported") + + +def test_directory_partitioning_dictionary_key(mockfs): + # ARROW-8088 specifying partition key as dictionary type + schema = pa.schema([ + pa.field('group', pa.dictionary(pa.int8(), pa.int32())), + pa.field('key', pa.dictionary(pa.int8(), pa.string())) + ]) + part = ds.DirectoryPartitioning.discover(schema=schema) + + dataset = ds.dataset( + "subdir", format="parquet", filesystem=mockfs, partitioning=part + ) + assert dataset.partitioning.schema == schema + table = dataset.to_table() + + assert table.column('group').type.equals(schema.types[0]) + assert table.column('group').to_pylist() == [1] * 5 + [2] * 5 + assert table.column('key').type.equals(schema.types[1]) + assert table.column('key').to_pylist() == ['xxx'] * 5 + ['yyy'] * 5 + + +def test_hive_partitioning_dictionary_key(multisourcefs): + # ARROW-8088 specifying partition key as dictionary type + schema = pa.schema([ + pa.field('year', pa.dictionary(pa.int8(), pa.int16())), + pa.field('month', pa.dictionary(pa.int8(), pa.int16())) + ]) + part = ds.HivePartitioning.discover(schema=schema) + + dataset = ds.dataset( + "hive", format="parquet", filesystem=multisourcefs, partitioning=part + ) + assert dataset.partitioning.schema == schema + table = dataset.to_table() + + year_dictionary = list(range(2006, 2011)) + month_dictionary = list(range(1, 13)) + assert table.column('year').type.equals(schema.types[0]) + for chunk in table.column('year').chunks: + actual = chunk.dictionary.to_pylist() + actual.sort() + assert actual == year_dictionary + assert table.column('month').type.equals(schema.types[1]) + for chunk in table.column('month').chunks: + actual = chunk.dictionary.to_pylist() + actual.sort() + assert actual == month_dictionary + + +def _create_single_file(base_dir, table=None, row_group_size=None): + import pyarrow.parquet as pq + if table is None: + table = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5}) + path = base_dir / "test.parquet" + pq.write_table(table, path, row_group_size=row_group_size) + return table, path + + +def _create_directory_of_files(base_dir): + import pyarrow.parquet as pq + table1 = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5}) + path1 = base_dir / "test1.parquet" + pq.write_table(table1, path1) + table2 = pa.table({'a': range(9, 18), 'b': [0.] * 4 + [1.] * 5}) + path2 = base_dir / "test2.parquet" + pq.write_table(table2, path2) + return (table1, table2), (path1, path2) + + +def _check_dataset(dataset, table, dataset_reader): + # also test that pickle roundtrip keeps the functionality + for d in [dataset, pickle.loads(pickle.dumps(dataset))]: + assert dataset.schema.equals(table.schema) + assert dataset_reader.to_table(dataset).equals(table) + + +def _check_dataset_from_path(path, table, dataset_reader, **kwargs): + # pathlib object + assert isinstance(path, pathlib.Path) + + # accept Path, str, List[Path], List[str] + for p in [path, str(path), [path], [str(path)]]: + dataset = ds.dataset(path, **kwargs) + assert isinstance(dataset, ds.FileSystemDataset) + _check_dataset(dataset, table, dataset_reader) + + # relative string path + with change_cwd(path.parent): + dataset = ds.dataset(path.name, **kwargs) + assert isinstance(dataset, ds.FileSystemDataset) + _check_dataset(dataset, table, dataset_reader) + + +@pytest.mark.parquet +def test_open_dataset_single_file(tempdir, dataset_reader): + table, path = _create_single_file(tempdir) + _check_dataset_from_path(path, table, dataset_reader) + + +@pytest.mark.parquet +def test_deterministic_row_order(tempdir, dataset_reader): + # ARROW-8447 Ensure that dataset.to_table (and Scanner::ToTable) returns a + # deterministic row ordering. This is achieved by constructing a single + # parquet file with one row per RowGroup. + table, path = _create_single_file(tempdir, row_group_size=1) + _check_dataset_from_path(path, table, dataset_reader) + + +@pytest.mark.parquet +def test_open_dataset_directory(tempdir, dataset_reader): + tables, _ = _create_directory_of_files(tempdir) + table = pa.concat_tables(tables) + _check_dataset_from_path(tempdir, table, dataset_reader) + + +@pytest.mark.parquet +def test_open_dataset_list_of_files(tempdir, dataset_reader): + tables, (path1, path2) = _create_directory_of_files(tempdir) + table = pa.concat_tables(tables) + + datasets = [ + ds.dataset([path1, path2]), + ds.dataset([str(path1), str(path2)]) + ] + datasets += [ + pickle.loads(pickle.dumps(d)) for d in datasets + ] + + for dataset in datasets: + assert dataset.schema.equals(table.schema) + result = dataset_reader.to_table(dataset) + assert result.equals(table) + + +@pytest.mark.parquet +def test_open_dataset_filesystem_fspath(tempdir): + # single file + table, path = _create_single_file(tempdir) + + fspath = FSProtocolClass(path) + + # filesystem inferred from path + dataset1 = ds.dataset(fspath) + assert dataset1.schema.equals(table.schema) + + # filesystem specified + dataset2 = ds.dataset(fspath, filesystem=fs.LocalFileSystem()) + assert dataset2.schema.equals(table.schema) + + # passing different filesystem + with pytest.raises(TypeError): + ds.dataset(fspath, filesystem=fs._MockFileSystem()) + + +def test_construct_from_single_file(tempdir, dataset_reader): + directory = tempdir / 'single-file' + directory.mkdir() + table, path = _create_single_file(directory) + relative_path = path.relative_to(directory) + + # instantiate from a single file + d1 = ds.dataset(path) + # instantiate from a single file with a filesystem object + d2 = ds.dataset(path, filesystem=fs.LocalFileSystem()) + # instantiate from a single file with prefixed filesystem URI + d3 = ds.dataset(str(relative_path), filesystem=_filesystem_uri(directory)) + # pickle roundtrip + d4 = pickle.loads(pickle.dumps(d1)) + + assert dataset_reader.to_table(d1) == dataset_reader.to_table( + d2) == dataset_reader.to_table(d3) == dataset_reader.to_table(d4) + + +def test_construct_from_single_directory(tempdir, dataset_reader): + directory = tempdir / 'single-directory' + directory.mkdir() + tables, paths = _create_directory_of_files(directory) + + d1 = ds.dataset(directory) + d2 = ds.dataset(directory, filesystem=fs.LocalFileSystem()) + d3 = ds.dataset(directory.name, filesystem=_filesystem_uri(tempdir)) + t1 = dataset_reader.to_table(d1) + t2 = dataset_reader.to_table(d2) + t3 = dataset_reader.to_table(d3) + assert t1 == t2 == t3 + + # test pickle roundtrip + for d in [d1, d2, d3]: + restored = pickle.loads(pickle.dumps(d)) + assert dataset_reader.to_table(restored) == t1 + + +def test_construct_from_list_of_files(tempdir, dataset_reader): + # instantiate from a list of files + directory = tempdir / 'list-of-files' + directory.mkdir() + tables, paths = _create_directory_of_files(directory) + + relative_paths = [p.relative_to(tempdir) for p in paths] + with change_cwd(tempdir): + d1 = ds.dataset(relative_paths) + t1 = dataset_reader.to_table(d1) + assert len(t1) == sum(map(len, tables)) + + d2 = ds.dataset(relative_paths, filesystem=_filesystem_uri(tempdir)) + t2 = dataset_reader.to_table(d2) + d3 = ds.dataset(paths) + t3 = dataset_reader.to_table(d3) + d4 = ds.dataset(paths, filesystem=fs.LocalFileSystem()) + t4 = dataset_reader.to_table(d4) + + assert t1 == t2 == t3 == t4 + + +def test_construct_from_list_of_mixed_paths_fails(mockfs): + # isntantiate from a list of mixed paths + files = [ + 'subdir/1/xxx/file0.parquet', + 'subdir/1/xxx/doesnt-exist.parquet', + ] + with pytest.raises(FileNotFoundError, match='doesnt-exist'): + ds.dataset(files, filesystem=mockfs) + + +def test_construct_from_mixed_child_datasets(mockfs): + # isntantiate from a list of mixed paths + a = ds.dataset(['subdir/1/xxx/file0.parquet', + 'subdir/2/yyy/file1.parquet'], filesystem=mockfs) + b = ds.dataset('subdir', filesystem=mockfs) + + dataset = ds.dataset([a, b]) + + assert isinstance(dataset, ds.UnionDataset) + assert len(list(dataset.get_fragments())) == 4 + + table = dataset.to_table() + assert len(table) == 20 + assert table.num_columns == 4 + + assert len(dataset.children) == 2 + for child in dataset.children: + assert child.files == ['subdir/1/xxx/file0.parquet', + 'subdir/2/yyy/file1.parquet'] + + +def test_construct_empty_dataset(): + empty = ds.dataset([]) + table = empty.to_table() + assert table.num_rows == 0 + assert table.num_columns == 0 + + +def test_construct_dataset_with_invalid_schema(): + empty = ds.dataset([], schema=pa.schema([ + ('a', pa.int64()), + ('a', pa.string()) + ])) + with pytest.raises(ValueError, match='Multiple matches for .*a.* in '): + empty.to_table() + + +def test_construct_from_invalid_sources_raise(multisourcefs): + child1 = ds.FileSystemDatasetFactory( + multisourcefs, + fs.FileSelector('/plain'), + format=ds.ParquetFileFormat() + ) + child2 = ds.FileSystemDatasetFactory( + multisourcefs, + fs.FileSelector('/schema'), + format=ds.ParquetFileFormat() + ) + batch1 = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"]) + batch2 = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["b"]) + + with pytest.raises(TypeError, match='Expected.*FileSystemDatasetFactory'): + ds.dataset([child1, child2]) + + expected = ( + "Expected a list of path-like or dataset objects, or a list " + "of batches or tables. The given list contains the following " + "types: int" + ) + with pytest.raises(TypeError, match=expected): + ds.dataset([1, 2, 3]) + + expected = ( + "Expected a path-like, list of path-likes or a list of Datasets " + "instead of the given type: NoneType" + ) + with pytest.raises(TypeError, match=expected): + ds.dataset(None) + + expected = ( + "Expected a path-like, list of path-likes or a list of Datasets " + "instead of the given type: generator" + ) + with pytest.raises(TypeError, match=expected): + ds.dataset((batch1 for _ in range(3))) + + expected = ( + "Must provide schema to construct in-memory dataset from an empty list" + ) + with pytest.raises(ValueError, match=expected): + ds.InMemoryDataset([]) + + expected = ( + "Item has schema\nb: int64\nwhich does not match expected schema\n" + "a: int64" + ) + with pytest.raises(TypeError, match=expected): + ds.dataset([batch1, batch2]) + + expected = ( + "Expected a list of path-like or dataset objects, or a list of " + "batches or tables. The given list contains the following types:" + ) + with pytest.raises(TypeError, match=expected): + ds.dataset([batch1, 0]) + + expected = ( + "Expected a list of tables or batches. The given list contains a int" + ) + with pytest.raises(TypeError, match=expected): + ds.InMemoryDataset([batch1, 0]) + + +def test_construct_in_memory(dataset_reader): + batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"]) + table = pa.Table.from_batches([batch]) + + assert ds.dataset([], schema=pa.schema([])).to_table() == pa.table([]) + + for source in (batch, table, [batch], [table]): + dataset = ds.dataset(source) + assert dataset_reader.to_table(dataset) == table + assert len(list(dataset.get_fragments())) == 1 + assert next(dataset.get_fragments()).to_table() == table + assert pa.Table.from_batches(list(dataset.to_batches())) == table + + +@pytest.mark.parametrize('use_threads,use_async', + [(False, False), (False, True), + (True, False), (True, True)]) +def test_scan_iterator(use_threads, use_async): + batch = pa.RecordBatch.from_arrays([pa.array(range(10))], names=["a"]) + table = pa.Table.from_batches([batch]) + # When constructed from readers/iterators, should be one-shot + match = "OneShotFragment was already scanned" + for factory, schema in ( + (lambda: pa.ipc.RecordBatchReader.from_batches( + batch.schema, [batch]), None), + (lambda: (batch for _ in range(1)), batch.schema), + ): + # Scanning the fragment consumes the underlying iterator + scanner = ds.Scanner.from_batches( + factory(), schema=schema, use_threads=use_threads, + use_async=use_async) + assert scanner.to_table() == table + with pytest.raises(pa.ArrowInvalid, match=match): + scanner.to_table() + + +def _create_partitioned_dataset(basedir): + import pyarrow.parquet as pq + table = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5}) + + path = basedir / "dataset-partitioned" + path.mkdir() + + for i in range(3): + part = path / "part={}".format(i) + part.mkdir() + pq.write_table(table.slice(3*i, 3), part / "test.parquet") + + full_table = table.append_column( + "part", pa.array(np.repeat([0, 1, 2], 3), type=pa.int32())) + + return full_table, path + + +@pytest.mark.parquet +def test_open_dataset_partitioned_directory(tempdir, dataset_reader): + full_table, path = _create_partitioned_dataset(tempdir) + + # no partitioning specified, just read all individual files + table = full_table.select(['a', 'b']) + _check_dataset_from_path(path, table, dataset_reader) + + # specify partition scheme with discovery + dataset = ds.dataset( + str(path), partitioning=ds.partitioning(flavor="hive")) + assert dataset.schema.equals(full_table.schema) + + # specify partition scheme with discovery and relative path + with change_cwd(tempdir): + dataset = ds.dataset("dataset-partitioned/", + partitioning=ds.partitioning(flavor="hive")) + assert dataset.schema.equals(full_table.schema) + + # specify partition scheme with string short-cut + dataset = ds.dataset(str(path), partitioning="hive") + assert dataset.schema.equals(full_table.schema) + + # specify partition scheme with explicit scheme + dataset = ds.dataset( + str(path), + partitioning=ds.partitioning( + pa.schema([("part", pa.int8())]), flavor="hive")) + expected_schema = table.schema.append(pa.field("part", pa.int8())) + assert dataset.schema.equals(expected_schema) + + result = dataset.to_table() + expected = table.append_column( + "part", pa.array(np.repeat([0, 1, 2], 3), type=pa.int8())) + assert result.equals(expected) + + +@pytest.mark.parquet +def test_open_dataset_filesystem(tempdir): + # single file + table, path = _create_single_file(tempdir) + + # filesystem inferred from path + dataset1 = ds.dataset(str(path)) + assert dataset1.schema.equals(table.schema) + + # filesystem specified + dataset2 = ds.dataset(str(path), filesystem=fs.LocalFileSystem()) + assert dataset2.schema.equals(table.schema) + + # local filesystem specified with relative path + with change_cwd(tempdir): + dataset3 = ds.dataset("test.parquet", filesystem=fs.LocalFileSystem()) + assert dataset3.schema.equals(table.schema) + + # passing different filesystem + with pytest.raises(FileNotFoundError): + ds.dataset(str(path), filesystem=fs._MockFileSystem()) + + +@pytest.mark.parquet +def test_open_dataset_unsupported_format(tempdir): + _, path = _create_single_file(tempdir) + with pytest.raises(ValueError, match="format 'blabla' is not supported"): + ds.dataset([path], format="blabla") + + +@pytest.mark.parquet +def test_open_union_dataset(tempdir, dataset_reader): + _, path = _create_single_file(tempdir) + dataset = ds.dataset(path) + + union = ds.dataset([dataset, dataset]) + assert isinstance(union, ds.UnionDataset) + + pickled = pickle.loads(pickle.dumps(union)) + assert dataset_reader.to_table(pickled) == dataset_reader.to_table(union) + + +def test_open_union_dataset_with_additional_kwargs(multisourcefs): + child = ds.dataset('/plain', filesystem=multisourcefs, format='parquet') + with pytest.raises(ValueError, match="cannot pass any additional"): + ds.dataset([child], format="parquet") + + +def test_open_dataset_non_existing_file(): + # ARROW-8213: Opening a dataset with a local incorrect path gives confusing + # error message + with pytest.raises(FileNotFoundError): + ds.dataset('i-am-not-existing.parquet', format='parquet') + + with pytest.raises(pa.ArrowInvalid, match='cannot be relative'): + ds.dataset('file:i-am-not-existing.parquet', format='parquet') + + +@pytest.mark.parquet +@pytest.mark.parametrize('partitioning', ["directory", "hive"]) +@pytest.mark.parametrize('null_fallback', ['xyz', None]) +@pytest.mark.parametrize('infer_dictionary', [False, True]) +@pytest.mark.parametrize('partition_keys', [ + (["A", "B", "C"], [1, 2, 3]), + ([1, 2, 3], ["A", "B", "C"]), + (["A", "B", "C"], ["D", "E", "F"]), + ([1, 2, 3], [4, 5, 6]), + ([1, None, 3], ["A", "B", "C"]), + ([1, 2, 3], ["A", None, "C"]), + ([None, 2, 3], [None, 2, 3]), +]) +def test_partition_discovery( + tempdir, partitioning, null_fallback, infer_dictionary, partition_keys +): + # ARROW-9288 / ARROW-9476 + import pyarrow.parquet as pq + + table = pa.table({'a': range(9), 'b': [0.0] * 4 + [1.0] * 5}) + + has_null = None in partition_keys[0] or None in partition_keys[1] + if partitioning == "directory" and has_null: + # Directory partitioning can't handle the first part being null + return + + if partitioning == "directory": + partitioning = ds.DirectoryPartitioning.discover( + ["part1", "part2"], infer_dictionary=infer_dictionary) + fmt = "{0}/{1}" + null_value = None + else: + if null_fallback: + partitioning = ds.HivePartitioning.discover( + infer_dictionary=infer_dictionary, null_fallback=null_fallback + ) + else: + partitioning = ds.HivePartitioning.discover( + infer_dictionary=infer_dictionary) + fmt = "part1={0}/part2={1}" + if null_fallback: + null_value = null_fallback + else: + null_value = "__HIVE_DEFAULT_PARTITION__" + + basepath = tempdir / "dataset" + basepath.mkdir() + + part_keys1, part_keys2 = partition_keys + for part1 in part_keys1: + for part2 in part_keys2: + path = basepath / \ + fmt.format(part1 or null_value, part2 or null_value) + path.mkdir(parents=True) + pq.write_table(table, path / "test.parquet") + + dataset = ds.dataset(str(basepath), partitioning=partitioning) + + def expected_type(key): + if infer_dictionary: + value_type = pa.string() if isinstance(key, str) else pa.int32() + return pa.dictionary(pa.int32(), value_type) + else: + return pa.string() if isinstance(key, str) else pa.int32() + expected_schema = table.schema.append( + pa.field("part1", expected_type(part_keys1[0])) + ).append( + pa.field("part2", expected_type(part_keys2[0])) + ) + assert dataset.schema.equals(expected_schema) + + +@pytest.mark.pandas +def test_dataset_partitioned_dictionary_type_reconstruct(tempdir): + # https://issues.apache.org/jira/browse/ARROW-11400 + table = pa.table({'part': np.repeat(['A', 'B'], 5), 'col': range(10)}) + part = ds.partitioning(table.select(['part']).schema, flavor="hive") + ds.write_dataset(table, tempdir, partitioning=part, format="feather") + + dataset = ds.dataset( + tempdir, format="feather", + partitioning=ds.HivePartitioning.discover(infer_dictionary=True) + ) + expected = pa.table( + {'col': table['col'], 'part': table['part'].dictionary_encode()} + ) + assert dataset.to_table().equals(expected) + fragment = list(dataset.get_fragments())[0] + assert fragment.to_table(schema=dataset.schema).equals(expected[:5]) + part_expr = fragment.partition_expression + + restored = pickle.loads(pickle.dumps(dataset)) + assert restored.to_table().equals(expected) + + restored = pickle.loads(pickle.dumps(fragment)) + assert restored.to_table(schema=dataset.schema).equals(expected[:5]) + # to_pandas call triggers computation of the actual dictionary values + assert restored.to_table(schema=dataset.schema).to_pandas().equals( + expected[:5].to_pandas() + ) + assert restored.partition_expression.equals(part_expr) + + +@pytest.fixture +def s3_example_simple(s3_server): + from pyarrow.fs import FileSystem + import pyarrow.parquet as pq + + host, port, access_key, secret_key = s3_server['connection'] + uri = ( + "s3://{}:{}@mybucket/data.parquet?scheme=http&endpoint_override={}:{}" + .format(access_key, secret_key, host, port) + ) + + fs, path = FileSystem.from_uri(uri) + + fs.create_dir("mybucket") + table = pa.table({'a': [1, 2, 3]}) + with fs.open_output_stream("mybucket/data.parquet") as out: + pq.write_table(table, out) + + return table, path, fs, uri, host, port, access_key, secret_key + + +@pytest.mark.parquet +@pytest.mark.s3 +def test_open_dataset_from_uri_s3(s3_example_simple, dataset_reader): + # open dataset from non-localfs string path + table, path, fs, uri, _, _, _, _ = s3_example_simple + + # full string URI + dataset = ds.dataset(uri, format="parquet") + assert dataset_reader.to_table(dataset).equals(table) + + # passing filesystem object + dataset = ds.dataset(path, format="parquet", filesystem=fs) + assert dataset_reader.to_table(dataset).equals(table) + + +@pytest.mark.parquet +@pytest.mark.s3 # still needed to create the data +def test_open_dataset_from_uri_s3_fsspec(s3_example_simple): + table, path, _, _, host, port, access_key, secret_key = s3_example_simple + s3fs = pytest.importorskip("s3fs") + + from pyarrow.fs import PyFileSystem, FSSpecHandler + + fs = s3fs.S3FileSystem( + key=access_key, + secret=secret_key, + client_kwargs={ + 'endpoint_url': 'http://{}:{}'.format(host, port) + } + ) + + # passing as fsspec filesystem + dataset = ds.dataset(path, format="parquet", filesystem=fs) + assert dataset.to_table().equals(table) + + # directly passing the fsspec-handler + fs = PyFileSystem(FSSpecHandler(fs)) + dataset = ds.dataset(path, format="parquet", filesystem=fs) + assert dataset.to_table().equals(table) + + +@pytest.mark.parquet +@pytest.mark.s3 +def test_open_dataset_from_s3_with_filesystem_uri(s3_server): + from pyarrow.fs import FileSystem + import pyarrow.parquet as pq + + host, port, access_key, secret_key = s3_server['connection'] + bucket = 'theirbucket' + path = 'nested/folder/data.parquet' + uri = "s3://{}:{}@{}/{}?scheme=http&endpoint_override={}:{}".format( + access_key, secret_key, bucket, path, host, port + ) + + fs, path = FileSystem.from_uri(uri) + assert path == 'theirbucket/nested/folder/data.parquet' + + fs.create_dir(bucket) + + table = pa.table({'a': [1, 2, 3]}) + with fs.open_output_stream(path) as out: + pq.write_table(table, out) + + # full string URI + dataset = ds.dataset(uri, format="parquet") + assert dataset.to_table().equals(table) + + # passing filesystem as an uri + template = ( + "s3://{}:{}@{{}}?scheme=http&endpoint_override={}:{}".format( + access_key, secret_key, host, port + ) + ) + cases = [ + ('theirbucket/nested/folder/', '/data.parquet'), + ('theirbucket/nested/folder', 'data.parquet'), + ('theirbucket/nested/', 'folder/data.parquet'), + ('theirbucket/nested', 'folder/data.parquet'), + ('theirbucket', '/nested/folder/data.parquet'), + ('theirbucket', 'nested/folder/data.parquet'), + ] + for prefix, path in cases: + uri = template.format(prefix) + dataset = ds.dataset(path, filesystem=uri, format="parquet") + assert dataset.to_table().equals(table) + + with pytest.raises(pa.ArrowInvalid, match='Missing bucket name'): + uri = template.format('/') + ds.dataset('/theirbucket/nested/folder/data.parquet', filesystem=uri) + + error = ( + "The path component of the filesystem URI must point to a directory " + "but it has a type: `{}`. The path component is `{}` and the given " + "filesystem URI is `{}`" + ) + + path = 'theirbucket/doesnt/exist' + uri = template.format(path) + with pytest.raises(ValueError) as exc: + ds.dataset('data.parquet', filesystem=uri) + assert str(exc.value) == error.format('NotFound', path, uri) + + path = 'theirbucket/nested/folder/data.parquet' + uri = template.format(path) + with pytest.raises(ValueError) as exc: + ds.dataset('data.parquet', filesystem=uri) + assert str(exc.value) == error.format('File', path, uri) + + +@pytest.mark.parquet +def test_open_dataset_from_fsspec(tempdir): + table, path = _create_single_file(tempdir) + + fsspec = pytest.importorskip("fsspec") + + localfs = fsspec.filesystem("file") + dataset = ds.dataset(path, filesystem=localfs) + assert dataset.schema.equals(table.schema) + + +@pytest.mark.pandas +def test_filter_timestamp(tempdir, dataset_reader): + # ARROW-11379 + path = tempdir / "test_partition_timestamps" + + table = pa.table({ + "dates": ['2012-01-01', '2012-01-02'] * 5, + "id": range(10)}) + + # write dataset partitioned on dates (as strings) + part = ds.partitioning(table.select(['dates']).schema, flavor="hive") + ds.write_dataset(table, path, partitioning=part, format="feather") + + # read dataset partitioned on dates (as timestamps) + part = ds.partitioning(pa.schema([("dates", pa.timestamp("s"))]), + flavor="hive") + dataset = ds.dataset(path, format="feather", partitioning=part) + + condition = ds.field("dates") > pd.Timestamp("2012-01-01") + table = dataset_reader.to_table(dataset, filter=condition) + assert table.column('id').to_pylist() == [1, 3, 5, 7, 9] + + import datetime + condition = ds.field("dates") > datetime.datetime(2012, 1, 1) + table = dataset_reader.to_table(dataset, filter=condition) + assert table.column('id').to_pylist() == [1, 3, 5, 7, 9] + + +@pytest.mark.parquet +def test_filter_implicit_cast(tempdir, dataset_reader): + # ARROW-7652 + table = pa.table({'a': pa.array([0, 1, 2, 3, 4, 5], type=pa.int8())}) + _, path = _create_single_file(tempdir, table) + dataset = ds.dataset(str(path)) + + filter_ = ds.field('a') > 2 + assert len(dataset_reader.to_table(dataset, filter=filter_)) == 3 + + +def test_dataset_union(multisourcefs): + child = ds.FileSystemDatasetFactory( + multisourcefs, fs.FileSelector('/plain'), + format=ds.ParquetFileFormat() + ) + factory = ds.UnionDatasetFactory([child]) + + # TODO(bkietz) reintroduce factory.children property + assert len(factory.inspect_schemas()) == 1 + assert all(isinstance(s, pa.Schema) for s in factory.inspect_schemas()) + assert factory.inspect_schemas()[0].equals(child.inspect()) + assert factory.inspect().equals(child.inspect()) + assert isinstance(factory.finish(), ds.Dataset) + + +def test_union_dataset_from_other_datasets(tempdir, multisourcefs): + child1 = ds.dataset('/plain', filesystem=multisourcefs, format='parquet') + child2 = ds.dataset('/schema', filesystem=multisourcefs, format='parquet', + partitioning=['week', 'color']) + child3 = ds.dataset('/hive', filesystem=multisourcefs, format='parquet', + partitioning='hive') + + assert child1.schema != child2.schema != child3.schema + + assembled = ds.dataset([child1, child2, child3]) + assert isinstance(assembled, ds.UnionDataset) + + msg = 'cannot pass any additional arguments' + with pytest.raises(ValueError, match=msg): + ds.dataset([child1, child2], filesystem=multisourcefs) + + expected_schema = pa.schema([ + ('date', pa.date32()), + ('index', pa.int64()), + ('value', pa.float64()), + ('color', pa.string()), + ('week', pa.int32()), + ('year', pa.int32()), + ('month', pa.int32()), + ]) + assert assembled.schema.equals(expected_schema) + assert assembled.to_table().schema.equals(expected_schema) + + assembled = ds.dataset([child1, child3]) + expected_schema = pa.schema([ + ('date', pa.date32()), + ('index', pa.int64()), + ('value', pa.float64()), + ('color', pa.string()), + ('year', pa.int32()), + ('month', pa.int32()), + ]) + assert assembled.schema.equals(expected_schema) + assert assembled.to_table().schema.equals(expected_schema) + + expected_schema = pa.schema([ + ('month', pa.int32()), + ('color', pa.string()), + ('date', pa.date32()), + ]) + assembled = ds.dataset([child1, child3], schema=expected_schema) + assert assembled.to_table().schema.equals(expected_schema) + + expected_schema = pa.schema([ + ('month', pa.int32()), + ('color', pa.string()), + ('unknown', pa.string()) # fill with nulls + ]) + assembled = ds.dataset([child1, child3], schema=expected_schema) + assert assembled.to_table().schema.equals(expected_schema) + + # incompatible schemas, date and index columns have conflicting types + table = pa.table([range(9), [0.] * 4 + [1.] * 5, 'abcdefghj'], + names=['date', 'value', 'index']) + _, path = _create_single_file(tempdir, table=table) + child4 = ds.dataset(path) + + with pytest.raises(pa.ArrowInvalid, match='Unable to merge'): + ds.dataset([child1, child4]) + + +def test_dataset_from_a_list_of_local_directories_raises(multisourcefs): + msg = 'points to a directory, but only file paths are supported' + with pytest.raises(IsADirectoryError, match=msg): + ds.dataset(['/plain', '/schema', '/hive'], filesystem=multisourcefs) + + +def test_union_dataset_filesystem_datasets(multisourcefs): + # without partitioning + dataset = ds.dataset([ + ds.dataset('/plain', filesystem=multisourcefs), + ds.dataset('/schema', filesystem=multisourcefs), + ds.dataset('/hive', filesystem=multisourcefs), + ]) + expected_schema = pa.schema([ + ('date', pa.date32()), + ('index', pa.int64()), + ('value', pa.float64()), + ('color', pa.string()), + ]) + assert dataset.schema.equals(expected_schema) + + # with hive partitioning for two hive sources + dataset = ds.dataset([ + ds.dataset('/plain', filesystem=multisourcefs), + ds.dataset('/schema', filesystem=multisourcefs), + ds.dataset('/hive', filesystem=multisourcefs, partitioning='hive') + ]) + expected_schema = pa.schema([ + ('date', pa.date32()), + ('index', pa.int64()), + ('value', pa.float64()), + ('color', pa.string()), + ('year', pa.int32()), + ('month', pa.int32()), + ]) + assert dataset.schema.equals(expected_schema) + + +@pytest.mark.parquet +def test_specified_schema(tempdir, dataset_reader): + import pyarrow.parquet as pq + + table = pa.table({'a': [1, 2, 3], 'b': [.1, .2, .3]}) + pq.write_table(table, tempdir / "data.parquet") + + def _check_dataset(schema, expected, expected_schema=None): + dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema) + if expected_schema is not None: + assert dataset.schema.equals(expected_schema) + else: + assert dataset.schema.equals(schema) + result = dataset_reader.to_table(dataset) + assert result.equals(expected) + + # no schema specified + schema = None + expected = table + _check_dataset(schema, expected, expected_schema=table.schema) + + # identical schema specified + schema = table.schema + expected = table + _check_dataset(schema, expected) + + # Specifying schema with change column order + schema = pa.schema([('b', 'float64'), ('a', 'int64')]) + expected = pa.table([[.1, .2, .3], [1, 2, 3]], names=['b', 'a']) + _check_dataset(schema, expected) + + # Specifying schema with missing column + schema = pa.schema([('a', 'int64')]) + expected = pa.table([[1, 2, 3]], names=['a']) + _check_dataset(schema, expected) + + # Specifying schema with additional column + schema = pa.schema([('a', 'int64'), ('c', 'int32')]) + expected = pa.table([[1, 2, 3], + pa.array([None, None, None], type='int32')], + names=['a', 'c']) + _check_dataset(schema, expected) + + # Specifying with differing field types + schema = pa.schema([('a', 'int32'), ('b', 'float64')]) + dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema) + expected = pa.table([table['a'].cast('int32'), + table['b']], + names=['a', 'b']) + _check_dataset(schema, expected) + + # Specifying with incompatible schema + schema = pa.schema([('a', pa.list_(pa.int32())), ('b', 'float64')]) + dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema) + assert dataset.schema.equals(schema) + with pytest.raises(NotImplementedError, + match='Unsupported cast from int64 to list'): + dataset_reader.to_table(dataset) + + +@pytest.mark.parquet +def test_incompatible_schema_hang(tempdir, dataset_reader): + # ARROW-13480: deadlock when reading past an errored fragment + import pyarrow.parquet as pq + + fn = tempdir / "data.parquet" + table = pa.table({'a': [1, 2, 3]}) + pq.write_table(table, fn) + + schema = pa.schema([('a', pa.null())]) + dataset = ds.dataset([str(fn)] * 100, schema=schema) + assert dataset.schema.equals(schema) + scanner = dataset_reader.scanner(dataset) + reader = scanner.to_reader() + with pytest.raises(NotImplementedError, + match='Unsupported cast from int64 to null'): + reader.read_all() + + +def test_ipc_format(tempdir, dataset_reader): + table = pa.table({'a': pa.array([1, 2, 3], type="int8"), + 'b': pa.array([.1, .2, .3], type="float64")}) + + path = str(tempdir / 'test.arrow') + with pa.output_stream(path) as sink: + writer = pa.RecordBatchFileWriter(sink, table.schema) + writer.write_batch(table.to_batches()[0]) + writer.close() + + dataset = ds.dataset(path, format=ds.IpcFileFormat()) + result = dataset_reader.to_table(dataset) + assert result.equals(table) + + for format_str in ["ipc", "arrow"]: + dataset = ds.dataset(path, format=format_str) + result = dataset_reader.to_table(dataset) + assert result.equals(table) + + +@pytest.mark.orc +def test_orc_format(tempdir, dataset_reader): + from pyarrow import orc + table = pa.table({'a': pa.array([1, 2, 3], type="int8"), + 'b': pa.array([.1, .2, .3], type="float64")}) + + path = str(tempdir / 'test.orc') + orc.write_table(table, path) + + dataset = ds.dataset(path, format=ds.OrcFileFormat()) + result = dataset_reader.to_table(dataset) + result.validate(full=True) + assert result.equals(table) + + dataset = ds.dataset(path, format="orc") + result = dataset_reader.to_table(dataset) + result.validate(full=True) + assert result.equals(table) + + result = dataset_reader.to_table(dataset, columns=["b"]) + result.validate(full=True) + assert result.equals(table.select(["b"])) + + result = dataset_reader.to_table( + dataset, columns={"b2": ds.field("b") * 2} + ) + result.validate(full=True) + assert result.equals( + pa.table({'b2': pa.array([.2, .4, .6], type="float64")}) + ) + + assert dataset_reader.count_rows(dataset) == 3 + assert dataset_reader.count_rows(dataset, filter=ds.field("a") > 2) == 1 + + +@pytest.mark.orc +def test_orc_scan_options(tempdir, dataset_reader): + from pyarrow import orc + table = pa.table({'a': pa.array([1, 2, 3], type="int8"), + 'b': pa.array([.1, .2, .3], type="float64")}) + + path = str(tempdir / 'test.orc') + orc.write_table(table, path) + + dataset = ds.dataset(path, format="orc") + result = list(dataset_reader.to_batches(dataset)) + assert len(result) == 1 + assert result[0].num_rows == 3 + assert result[0].equals(table.to_batches()[0]) + # TODO batch_size is not yet supported (ARROW-14153) + # result = list(dataset_reader.to_batches(dataset, batch_size=2)) + # assert len(result) == 2 + # assert result[0].num_rows == 2 + # assert result[0].equals(table.slice(0, 2).to_batches()[0]) + # assert result[1].num_rows == 1 + # assert result[1].equals(table.slice(2, 1).to_batches()[0]) + + +def test_orc_format_not_supported(): + try: + from pyarrow.dataset import OrcFileFormat # noqa + except (ImportError, AttributeError): + # catch AttributeError for Python 3.6 + # ORC is not available, test error message + with pytest.raises( + ValueError, match="not built with support for the ORC file" + ): + ds.dataset(".", format="orc") + + +@pytest.mark.pandas +def test_csv_format(tempdir, dataset_reader): + table = pa.table({'a': pa.array([1, 2, 3], type="int64"), + 'b': pa.array([.1, .2, .3], type="float64")}) + + path = str(tempdir / 'test.csv') + table.to_pandas().to_csv(path, index=False) + + dataset = ds.dataset(path, format=ds.CsvFileFormat()) + result = dataset_reader.to_table(dataset) + assert result.equals(table) + + dataset = ds.dataset(path, format='csv') + result = dataset_reader.to_table(dataset) + assert result.equals(table) + + +@pytest.mark.pandas +@pytest.mark.parametrize("compression", [ + "bz2", + "gzip", + "lz4", + "zstd", +]) +def test_csv_format_compressed(tempdir, compression, dataset_reader): + if not pyarrow.Codec.is_available(compression): + pytest.skip("{} support is not built".format(compression)) + table = pa.table({'a': pa.array([1, 2, 3], type="int64"), + 'b': pa.array([.1, .2, .3], type="float64")}) + filesystem = fs.LocalFileSystem() + suffix = compression if compression != 'gzip' else 'gz' + path = str(tempdir / f'test.csv.{suffix}') + with filesystem.open_output_stream(path, compression=compression) as sink: + # https://github.com/pandas-dev/pandas/issues/23854 + # With CI version of Pandas (anything < 1.2), Pandas tries to write + # str to the sink + csv_str = table.to_pandas().to_csv(index=False) + sink.write(csv_str.encode('utf-8')) + + dataset = ds.dataset(path, format=ds.CsvFileFormat()) + result = dataset_reader.to_table(dataset) + assert result.equals(table) + + +def test_csv_format_options(tempdir, dataset_reader): + path = str(tempdir / 'test.csv') + with open(path, 'w') as sink: + sink.write('skipped\ncol0\nfoo\nbar\n') + dataset = ds.dataset(path, format='csv') + result = dataset_reader.to_table(dataset) + assert result.equals( + pa.table({'skipped': pa.array(['col0', 'foo', 'bar'])})) + + dataset = ds.dataset(path, format=ds.CsvFileFormat( + read_options=pa.csv.ReadOptions(skip_rows=1))) + result = dataset_reader.to_table(dataset) + assert result.equals(pa.table({'col0': pa.array(['foo', 'bar'])})) + + dataset = ds.dataset(path, format=ds.CsvFileFormat( + read_options=pa.csv.ReadOptions(column_names=['foo']))) + result = dataset_reader.to_table(dataset) + assert result.equals( + pa.table({'foo': pa.array(['skipped', 'col0', 'foo', 'bar'])})) + + +def test_csv_fragment_options(tempdir, dataset_reader): + path = str(tempdir / 'test.csv') + with open(path, 'w') as sink: + sink.write('col0\nfoo\nspam\nMYNULL\n') + dataset = ds.dataset(path, format='csv') + convert_options = pyarrow.csv.ConvertOptions(null_values=['MYNULL'], + strings_can_be_null=True) + options = ds.CsvFragmentScanOptions( + convert_options=convert_options, + read_options=pa.csv.ReadOptions(block_size=2**16)) + result = dataset_reader.to_table(dataset, fragment_scan_options=options) + assert result.equals(pa.table({'col0': pa.array(['foo', 'spam', None])})) + + csv_format = ds.CsvFileFormat(convert_options=convert_options) + dataset = ds.dataset(path, format=csv_format) + result = dataset_reader.to_table(dataset) + assert result.equals(pa.table({'col0': pa.array(['foo', 'spam', None])})) + + options = ds.CsvFragmentScanOptions() + result = dataset_reader.to_table(dataset, fragment_scan_options=options) + assert result.equals( + pa.table({'col0': pa.array(['foo', 'spam', 'MYNULL'])})) + + +def test_feather_format(tempdir, dataset_reader): + from pyarrow.feather import write_feather + + table = pa.table({'a': pa.array([1, 2, 3], type="int8"), + 'b': pa.array([.1, .2, .3], type="float64")}) + + basedir = tempdir / "feather_dataset" + basedir.mkdir() + write_feather(table, str(basedir / "data.feather")) + + dataset = ds.dataset(basedir, format=ds.IpcFileFormat()) + result = dataset_reader.to_table(dataset) + assert result.equals(table) + + dataset = ds.dataset(basedir, format="feather") + result = dataset_reader.to_table(dataset) + assert result.equals(table) + + # ARROW-8641 - column selection order + result = dataset_reader.to_table(dataset, columns=["b", "a"]) + assert result.column_names == ["b", "a"] + result = dataset_reader.to_table(dataset, columns=["a", "a"]) + assert result.column_names == ["a", "a"] + + # error with Feather v1 files + write_feather(table, str(basedir / "data1.feather"), version=1) + with pytest.raises(ValueError): + dataset_reader.to_table(ds.dataset(basedir, format="feather")) + + +def _create_parquet_dataset_simple(root_path): + """ + Creates a simple (flat files, no nested partitioning) Parquet dataset + """ + import pyarrow.parquet as pq + + metadata_collector = [] + + for i in range(4): + table = pa.table({'f1': [i] * 10, 'f2': np.random.randn(10)}) + pq.write_to_dataset( + table, str(root_path), metadata_collector=metadata_collector + ) + + metadata_path = str(root_path / '_metadata') + # write _metadata file + pq.write_metadata( + table.schema, metadata_path, + metadata_collector=metadata_collector + ) + return metadata_path, table + + +@pytest.mark.parquet +@pytest.mark.pandas # write_to_dataset currently requires pandas +def test_parquet_dataset_factory(tempdir): + root_path = tempdir / "test_parquet_dataset" + metadata_path, table = _create_parquet_dataset_simple(root_path) + dataset = ds.parquet_dataset(metadata_path) + assert dataset.schema.equals(table.schema) + assert len(dataset.files) == 4 + result = dataset.to_table() + assert result.num_rows == 40 + + +@pytest.mark.parquet +@pytest.mark.pandas # write_to_dataset currently requires pandas +@pytest.mark.parametrize('use_legacy_dataset', [False, True]) +def test_parquet_dataset_factory_roundtrip(tempdir, use_legacy_dataset): + # Simple test to ensure we can roundtrip dataset to + # _metadata/common_metadata and back. A more complex test + # using partitioning will have to wait for ARROW-13269. The + # above test (test_parquet_dataset_factory) will not work + # when legacy is False as there is no "append" equivalent in + # the new dataset until ARROW-12358 + import pyarrow.parquet as pq + root_path = tempdir / "test_parquet_dataset" + table = pa.table({'f1': [0] * 10, 'f2': np.random.randn(10)}) + metadata_collector = [] + pq.write_to_dataset( + table, str(root_path), metadata_collector=metadata_collector, + use_legacy_dataset=use_legacy_dataset + ) + metadata_path = str(root_path / '_metadata') + # write _metadata file + pq.write_metadata( + table.schema, metadata_path, + metadata_collector=metadata_collector + ) + dataset = ds.parquet_dataset(metadata_path) + assert dataset.schema.equals(table.schema) + result = dataset.to_table() + assert result.num_rows == 10 + + +def test_parquet_dataset_factory_order(tempdir): + # The order of the fragments in the dataset should match the order of the + # row groups in the _metadata file. + import pyarrow.parquet as pq + metadatas = [] + # Create a dataset where f1 is incrementing from 0 to 100 spread across + # 10 files. Put the row groups in the correct order in _metadata + for i in range(10): + table = pa.table( + {'f1': list(range(i*10, (i+1)*10))}) + table_path = tempdir / f'{i}.parquet' + pq.write_table(table, table_path, metadata_collector=metadatas) + metadatas[-1].set_file_path(f'{i}.parquet') + metadata_path = str(tempdir / '_metadata') + pq.write_metadata(table.schema, metadata_path, metadatas) + dataset = ds.parquet_dataset(metadata_path) + # Ensure the table contains values from 0-100 in the right order + scanned_table = dataset.to_table() + scanned_col = scanned_table.column('f1').to_pylist() + assert scanned_col == list(range(0, 100)) + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_parquet_dataset_factory_invalid(tempdir): + root_path = tempdir / "test_parquet_dataset_invalid" + metadata_path, table = _create_parquet_dataset_simple(root_path) + # remove one of the files + list(root_path.glob("*.parquet"))[0].unlink() + dataset = ds.parquet_dataset(metadata_path) + assert dataset.schema.equals(table.schema) + assert len(dataset.files) == 4 + with pytest.raises(FileNotFoundError): + dataset.to_table() + + +def _create_metadata_file(root_path): + # create _metadata file from existing parquet dataset + import pyarrow.parquet as pq + + parquet_paths = list(sorted(root_path.rglob("*.parquet"))) + schema = pq.ParquetFile(parquet_paths[0]).schema.to_arrow_schema() + + metadata_collector = [] + for path in parquet_paths: + metadata = pq.ParquetFile(path).metadata + metadata.set_file_path(str(path.relative_to(root_path))) + metadata_collector.append(metadata) + + metadata_path = root_path / "_metadata" + pq.write_metadata( + schema, metadata_path, metadata_collector=metadata_collector + ) + return metadata_path + + +def _create_parquet_dataset_partitioned(root_path): + import pyarrow.parquet as pq + + table = pa.table([ + pa.array(range(20)), pa.array(np.random.randn(20)), + pa.array(np.repeat(['a', 'b'], 10))], + names=["f1", "f2", "part"] + ) + table = table.replace_schema_metadata({"key": "value"}) + pq.write_to_dataset(table, str(root_path), partition_cols=['part']) + return _create_metadata_file(root_path), table + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_parquet_dataset_factory_partitioned(tempdir): + root_path = tempdir / "test_parquet_dataset_factory_partitioned" + metadata_path, table = _create_parquet_dataset_partitioned(root_path) + + partitioning = ds.partitioning(flavor="hive") + dataset = ds.parquet_dataset(metadata_path, partitioning=partitioning) + + assert dataset.schema.equals(table.schema) + assert len(dataset.files) == 2 + result = dataset.to_table() + assert result.num_rows == 20 + + # the partitioned dataset does not preserve order + result = result.to_pandas().sort_values("f1").reset_index(drop=True) + expected = table.to_pandas() + pd.testing.assert_frame_equal(result, expected) + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_parquet_dataset_factory_metadata(tempdir): + # ensure ParquetDatasetFactory preserves metadata (ARROW-9363) + root_path = tempdir / "test_parquet_dataset_factory_metadata" + metadata_path, table = _create_parquet_dataset_partitioned(root_path) + + dataset = ds.parquet_dataset(metadata_path, partitioning="hive") + assert dataset.schema.equals(table.schema) + assert b"key" in dataset.schema.metadata + + fragments = list(dataset.get_fragments()) + assert b"key" in fragments[0].physical_schema.metadata + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_parquet_dataset_lazy_filtering(tempdir, open_logging_fs): + fs, assert_opens = open_logging_fs + + # Test to ensure that no IO happens when filtering a dataset + # created with ParquetDatasetFactory from a _metadata file + + root_path = tempdir / "test_parquet_dataset_lazy_filtering" + metadata_path, _ = _create_parquet_dataset_simple(root_path) + + # creating the dataset should only open the metadata file + with assert_opens([metadata_path]): + dataset = ds.parquet_dataset( + metadata_path, + partitioning=ds.partitioning(flavor="hive"), + filesystem=fs) + + # materializing fragments should not open any file + with assert_opens([]): + fragments = list(dataset.get_fragments()) + + # filtering fragments should not open any file + with assert_opens([]): + list(dataset.get_fragments(ds.field("f1") > 15)) + + # splitting by row group should still not open any file + with assert_opens([]): + fragments[0].split_by_row_group(ds.field("f1") > 15) + + # ensuring metadata of splitted fragment should also not open any file + with assert_opens([]): + rg_fragments = fragments[0].split_by_row_group() + rg_fragments[0].ensure_complete_metadata() + + # FIXME(bkietz) on Windows this results in FileNotFoundErrors. + # but actually scanning does open files + # with assert_opens([f.path for f in fragments]): + # dataset.to_table() + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_dataset_schema_metadata(tempdir, dataset_reader): + # ARROW-8802 + df = pd.DataFrame({'a': [1, 2, 3]}) + path = tempdir / "test.parquet" + df.to_parquet(path) + dataset = ds.dataset(path) + + schema = dataset_reader.to_table(dataset).schema + projected_schema = dataset_reader.to_table(dataset, columns=["a"]).schema + + # ensure the pandas metadata is included in the schema + assert b"pandas" in schema.metadata + # ensure it is still there in a projected schema (with column selection) + assert schema.equals(projected_schema, check_metadata=True) + + +@pytest.mark.parquet +def test_filter_mismatching_schema(tempdir, dataset_reader): + # ARROW-9146 + import pyarrow.parquet as pq + + table = pa.table({"col": pa.array([1, 2, 3, 4], type='int32')}) + pq.write_table(table, str(tempdir / "data.parquet")) + + # specifying explicit schema, but that mismatches the schema of the data + schema = pa.schema([("col", pa.int64())]) + dataset = ds.dataset( + tempdir / "data.parquet", format="parquet", schema=schema) + + # filtering on a column with such type mismatch should implicitly + # cast the column + filtered = dataset_reader.to_table(dataset, filter=ds.field("col") > 2) + assert filtered["col"].equals(table["col"].cast('int64').slice(2)) + + fragment = list(dataset.get_fragments())[0] + filtered = dataset_reader.to_table( + fragment, filter=ds.field("col") > 2, schema=schema) + assert filtered["col"].equals(table["col"].cast('int64').slice(2)) + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_dataset_project_only_partition_columns(tempdir, dataset_reader): + # ARROW-8729 + import pyarrow.parquet as pq + + table = pa.table({'part': 'a a b b'.split(), 'col': list(range(4))}) + + path = str(tempdir / 'test_dataset') + pq.write_to_dataset(table, path, partition_cols=['part']) + dataset = ds.dataset(path, partitioning='hive') + + all_cols = dataset_reader.to_table(dataset) + part_only = dataset_reader.to_table(dataset, columns=['part']) + + assert all_cols.column('part').equals(part_only.column('part')) + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_dataset_project_null_column(tempdir, dataset_reader): + import pandas as pd + df = pd.DataFrame({"col": np.array([None, None, None], dtype='object')}) + + f = tempdir / "test_dataset_project_null_column.parquet" + df.to_parquet(f, engine="pyarrow") + + dataset = ds.dataset(f, format="parquet", + schema=pa.schema([("col", pa.int64())])) + expected = pa.table({'col': pa.array([None, None, None], pa.int64())}) + assert dataset_reader.to_table(dataset).equals(expected) + + +def test_dataset_project_columns(tempdir, dataset_reader): + # basic column re-projection with expressions + from pyarrow import feather + table = pa.table({"A": [1, 2, 3], "B": [1., 2., 3.], "C": ["a", "b", "c"]}) + feather.write_feather(table, tempdir / "data.feather") + + dataset = ds.dataset(tempdir / "data.feather", format="feather") + result = dataset_reader.to_table(dataset, columns={ + 'A_renamed': ds.field('A'), + 'B_as_int': ds.field('B').cast("int32", safe=False), + 'C_is_a': ds.field('C') == 'a' + }) + expected = pa.table({ + "A_renamed": [1, 2, 3], + "B_as_int": pa.array([1, 2, 3], type="int32"), + "C_is_a": [True, False, False], + }) + assert result.equals(expected) + + # raise proper error when not passing an expression + with pytest.raises(TypeError, match="Expected an Expression"): + dataset_reader.to_table(dataset, columns={"A": "A"}) + + +@pytest.mark.pandas +@pytest.mark.parquet +def test_dataset_preserved_partitioning(tempdir): + # ARROW-8655 + + # through discovery, but without partitioning + _, path = _create_single_file(tempdir) + dataset = ds.dataset(path) + assert dataset.partitioning is None + + # through discovery, with hive partitioning but not specified + full_table, path = _create_partitioned_dataset(tempdir) + dataset = ds.dataset(path) + assert dataset.partitioning is None + + # through discovery, with hive partitioning (from a partitioning factory) + dataset = ds.dataset(path, partitioning="hive") + part = dataset.partitioning + assert part is not None + assert isinstance(part, ds.HivePartitioning) + assert part.schema == pa.schema([("part", pa.int32())]) + assert len(part.dictionaries) == 1 + assert part.dictionaries[0] == pa.array([0, 1, 2], pa.int32()) + + # through discovery, with hive partitioning (from a partitioning object) + part = ds.partitioning(pa.schema([("part", pa.int32())]), flavor="hive") + assert isinstance(part, ds.HivePartitioning) # not a factory + assert part.dictionaries is None + dataset = ds.dataset(path, partitioning=part) + part = dataset.partitioning + assert isinstance(part, ds.HivePartitioning) + assert part.schema == pa.schema([("part", pa.int32())]) + # TODO is this expected? + assert part.dictionaries is None + + # through manual creation -> not available + dataset = ds.dataset(path, partitioning="hive") + dataset2 = ds.FileSystemDataset( + list(dataset.get_fragments()), schema=dataset.schema, + format=dataset.format, filesystem=dataset.filesystem + ) + assert dataset2.partitioning is None + + # through discovery with ParquetDatasetFactory + root_path = tempdir / "data-partitioned-metadata" + metadata_path, _ = _create_parquet_dataset_partitioned(root_path) + dataset = ds.parquet_dataset(metadata_path, partitioning="hive") + part = dataset.partitioning + assert part is not None + assert isinstance(part, ds.HivePartitioning) + assert part.schema == pa.schema([("part", pa.string())]) + assert len(part.dictionaries) == 1 + # will be fixed by ARROW-13153 (order is not preserved at the moment) + # assert part.dictionaries[0] == pa.array(["a", "b"], pa.string()) + assert set(part.dictionaries[0].to_pylist()) == {"a", "b"} + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_write_to_dataset_given_null_just_works(tempdir): + import pyarrow.parquet as pq + + schema = pa.schema([ + pa.field('col', pa.int64()), + pa.field('part', pa.dictionary(pa.int32(), pa.string())) + ]) + table = pa.table({'part': [None, None, 'a', 'a'], + 'col': list(range(4))}, schema=schema) + + path = str(tempdir / 'test_dataset') + pq.write_to_dataset(table, path, partition_cols=[ + 'part'], use_legacy_dataset=False) + + actual_table = pq.read_table(tempdir / 'test_dataset') + # column.equals can handle the difference in chunking but not the fact + # that `part` will have different dictionaries for the two chunks + assert actual_table.column('part').to_pylist( + ) == table.column('part').to_pylist() + assert actual_table.column('col').equals(table.column('col')) + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_legacy_write_to_dataset_drops_null(tempdir): + import pyarrow.parquet as pq + + schema = pa.schema([ + pa.field('col', pa.int64()), + pa.field('part', pa.dictionary(pa.int32(), pa.string())) + ]) + table = pa.table({'part': ['a', 'a', None, None], + 'col': list(range(4))}, schema=schema) + expected = pa.table( + {'part': ['a', 'a'], 'col': list(range(2))}, schema=schema) + + path = str(tempdir / 'test_dataset') + pq.write_to_dataset(table, path, partition_cols=[ + 'part'], use_legacy_dataset=True) + + actual = pq.read_table(tempdir / 'test_dataset') + assert actual == expected + + +def _sort_table(tab, sort_col): + import pyarrow.compute as pc + sorted_indices = pc.sort_indices( + tab, options=pc.SortOptions([(sort_col, 'ascending')])) + return pc.take(tab, sorted_indices) + + +def _check_dataset_roundtrip(dataset, base_dir, expected_files, sort_col, + base_dir_path=None, partitioning=None): + base_dir_path = base_dir_path or base_dir + + ds.write_dataset(dataset, base_dir, format="feather", + partitioning=partitioning, use_threads=False) + + # check that all files are present + file_paths = list(base_dir_path.rglob("*")) + assert set(file_paths) == set(expected_files) + + # check that reading back in as dataset gives the same result + dataset2 = ds.dataset( + base_dir_path, format="feather", partitioning=partitioning) + + assert _sort_table(dataset2.to_table(), sort_col).equals( + _sort_table(dataset.to_table(), sort_col)) + + +@pytest.mark.parquet +def test_write_dataset(tempdir): + # manually create a written dataset and read as dataset object + directory = tempdir / 'single-file' + directory.mkdir() + _ = _create_single_file(directory) + dataset = ds.dataset(directory) + + # full string path + target = tempdir / 'single-file-target' + expected_files = [target / "part-0.feather"] + _check_dataset_roundtrip(dataset, str(target), expected_files, 'a', target) + + # pathlib path object + target = tempdir / 'single-file-target2' + expected_files = [target / "part-0.feather"] + _check_dataset_roundtrip(dataset, target, expected_files, 'a', target) + + # TODO + # # relative path + # target = tempdir / 'single-file-target3' + # expected_files = [target / "part-0.ipc"] + # _check_dataset_roundtrip( + # dataset, './single-file-target3', expected_files, target) + + # Directory of files + directory = tempdir / 'single-directory' + directory.mkdir() + _ = _create_directory_of_files(directory) + dataset = ds.dataset(directory) + + target = tempdir / 'single-directory-target' + expected_files = [target / "part-0.feather"] + _check_dataset_roundtrip(dataset, str(target), expected_files, 'a', target) + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_write_dataset_partitioned(tempdir): + directory = tempdir / "partitioned" + _ = _create_parquet_dataset_partitioned(directory) + partitioning = ds.partitioning(flavor="hive") + dataset = ds.dataset(directory, partitioning=partitioning) + + # hive partitioning + target = tempdir / 'partitioned-hive-target' + expected_paths = [ + target / "part=a", target / "part=a" / "part-0.feather", + target / "part=b", target / "part=b" / "part-0.feather" + ] + partitioning_schema = ds.partitioning( + pa.schema([("part", pa.string())]), flavor="hive") + _check_dataset_roundtrip( + dataset, str(target), expected_paths, 'f1', target, + partitioning=partitioning_schema) + + # directory partitioning + target = tempdir / 'partitioned-dir-target' + expected_paths = [ + target / "a", target / "a" / "part-0.feather", + target / "b", target / "b" / "part-0.feather" + ] + partitioning_schema = ds.partitioning( + pa.schema([("part", pa.string())])) + _check_dataset_roundtrip( + dataset, str(target), expected_paths, 'f1', target, + partitioning=partitioning_schema) + + +def test_write_dataset_with_field_names(tempdir): + table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z']}) + + ds.write_dataset(table, tempdir, format='parquet', + partitioning=["b"]) + + load_back = ds.dataset(tempdir, partitioning=["b"]) + files = load_back.files + partitioning_dirs = { + str(pathlib.Path(f).relative_to(tempdir).parent) for f in files + } + assert partitioning_dirs == {"x", "y", "z"} + + load_back_table = load_back.to_table() + assert load_back_table.equals(table) + + +def test_write_dataset_with_field_names_hive(tempdir): + table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z']}) + + ds.write_dataset(table, tempdir, format='parquet', + partitioning=["b"], partitioning_flavor="hive") + + load_back = ds.dataset(tempdir, partitioning="hive") + files = load_back.files + partitioning_dirs = { + str(pathlib.Path(f).relative_to(tempdir).parent) for f in files + } + assert partitioning_dirs == {"b=x", "b=y", "b=z"} + + load_back_table = load_back.to_table() + assert load_back_table.equals(table) + + +def test_write_dataset_with_scanner(tempdir): + table = pa.table({'a': ['x', 'y', None], 'b': ['x', 'y', 'z'], + 'c': [1, 2, 3]}) + + ds.write_dataset(table, tempdir, format='parquet', + partitioning=["b"]) + + dataset = ds.dataset(tempdir, partitioning=["b"]) + + with tempfile.TemporaryDirectory() as tempdir2: + ds.write_dataset(dataset.scanner(columns=["b", "c"], use_async=True), + tempdir2, format='parquet', partitioning=["b"]) + + load_back = ds.dataset(tempdir2, partitioning=["b"]) + load_back_table = load_back.to_table() + assert dict(load_back_table.to_pydict() + ) == table.drop(["a"]).to_pydict() + + +def test_write_dataset_with_backpressure(tempdir): + consumer_gate = threading.Event() + + # A filesystem that blocks all writes so that we can build + # up backpressure. The writes are released at the end of + # the test. + class GatingFs(ProxyHandler): + def open_output_stream(self, path, metadata): + # Block until the end of the test + consumer_gate.wait() + return self._fs.open_output_stream(path, metadata=metadata) + gating_fs = fs.PyFileSystem(GatingFs(fs.LocalFileSystem())) + + schema = pa.schema([pa.field('data', pa.int32())]) + # By default, the dataset writer will queue up 64Mi rows so + # with batches of 1M it should only fit ~67 batches + batch = pa.record_batch([pa.array(list(range(1_000_000)))], schema=schema) + batches_read = 0 + min_backpressure = 67 + end = 200 + + def counting_generator(): + nonlocal batches_read + while batches_read < end: + time.sleep(0.01) + batches_read += 1 + yield batch + + scanner = ds.Scanner.from_batches( + counting_generator(), schema=schema, use_threads=True, + use_async=True) + + write_thread = threading.Thread( + target=lambda: ds.write_dataset( + scanner, str(tempdir), format='parquet', filesystem=gating_fs)) + write_thread.start() + + try: + start = time.time() + + def duration(): + return time.time() - start + + # This test is timing dependent. There is no signal from the C++ + # when backpressure has been hit. We don't know exactly when + # backpressure will be hit because it may take some time for the + # signal to get from the sink to the scanner. + # + # The test may emit false positives on slow systems. It could + # theoretically emit a false negative if the scanner managed to read + # and emit all 200 batches before the backpressure signal had a chance + # to propagate but the 0.01s delay in the generator should make that + # scenario unlikely. + last_value = 0 + backpressure_probably_hit = False + while duration() < 10: + if batches_read > min_backpressure: + if batches_read == last_value: + backpressure_probably_hit = True + break + last_value = batches_read + time.sleep(0.5) + + assert backpressure_probably_hit + + finally: + consumer_gate.set() + write_thread.join() + assert batches_read == end + + +def test_write_dataset_with_dataset(tempdir): + table = pa.table({'b': ['x', 'y', 'z'], 'c': [1, 2, 3]}) + + ds.write_dataset(table, tempdir, format='parquet', + partitioning=["b"]) + + dataset = ds.dataset(tempdir, partitioning=["b"]) + + with tempfile.TemporaryDirectory() as tempdir2: + ds.write_dataset(dataset, tempdir2, + format='parquet', partitioning=["b"]) + + load_back = ds.dataset(tempdir2, partitioning=["b"]) + load_back_table = load_back.to_table() + assert dict(load_back_table.to_pydict()) == table.to_pydict() + + +@pytest.mark.pandas +def test_write_dataset_existing_data(tempdir): + directory = tempdir / 'ds' + table = pa.table({'b': ['x', 'y', 'z'], 'c': [1, 2, 3]}) + partitioning = ds.partitioning(schema=pa.schema( + [pa.field('c', pa.int64())]), flavor='hive') + + def compare_tables_ignoring_order(t1, t2): + df1 = t1.to_pandas().sort_values('b').reset_index(drop=True) + df2 = t2.to_pandas().sort_values('b').reset_index(drop=True) + assert df1.equals(df2) + + # First write is ok + ds.write_dataset(table, directory, partitioning=partitioning, format='ipc') + + table = pa.table({'b': ['a', 'b', 'c'], 'c': [2, 3, 4]}) + + # Second write should fail + with pytest.raises(pa.ArrowInvalid): + ds.write_dataset(table, directory, + partitioning=partitioning, format='ipc') + + extra_table = pa.table({'b': ['e']}) + extra_file = directory / 'c=2' / 'foo.arrow' + pyarrow.feather.write_feather(extra_table, extra_file) + + # Should be ok and overwrite with overwrite behavior + ds.write_dataset(table, directory, partitioning=partitioning, + format='ipc', + existing_data_behavior='overwrite_or_ignore') + + overwritten = pa.table( + {'b': ['e', 'x', 'a', 'b', 'c'], 'c': [2, 1, 2, 3, 4]}) + readback = ds.dataset(tempdir, format='ipc', + partitioning=partitioning).to_table() + compare_tables_ignoring_order(readback, overwritten) + assert extra_file.exists() + + # Should be ok and delete matching with delete_matching + ds.write_dataset(table, directory, partitioning=partitioning, + format='ipc', existing_data_behavior='delete_matching') + + overwritten = pa.table({'b': ['x', 'a', 'b', 'c'], 'c': [1, 2, 3, 4]}) + readback = ds.dataset(tempdir, format='ipc', + partitioning=partitioning).to_table() + compare_tables_ignoring_order(readback, overwritten) + assert not extra_file.exists() + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_write_dataset_partitioned_dict(tempdir): + directory = tempdir / "partitioned" + _ = _create_parquet_dataset_partitioned(directory) + + # directory partitioning, dictionary partition columns + dataset = ds.dataset( + directory, + partitioning=ds.HivePartitioning.discover(infer_dictionary=True)) + target = tempdir / 'partitioned-dir-target' + expected_paths = [ + target / "a", target / "a" / "part-0.feather", + target / "b", target / "b" / "part-0.feather" + ] + partitioning = ds.partitioning(pa.schema([ + dataset.schema.field('part')]), + dictionaries={'part': pa.array(['a', 'b'])}) + # NB: dictionaries required here since we use partitioning to parse + # directories in _check_dataset_roundtrip (not currently required for + # the formatting step) + _check_dataset_roundtrip( + dataset, str(target), expected_paths, 'f1', target, + partitioning=partitioning) + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_write_dataset_use_threads(tempdir): + directory = tempdir / "partitioned" + _ = _create_parquet_dataset_partitioned(directory) + dataset = ds.dataset(directory, partitioning="hive") + + partitioning = ds.partitioning( + pa.schema([("part", pa.string())]), flavor="hive") + + target1 = tempdir / 'partitioned1' + paths_written = [] + + def file_visitor(written_file): + paths_written.append(written_file.path) + + ds.write_dataset( + dataset, target1, format="feather", partitioning=partitioning, + use_threads=True, file_visitor=file_visitor + ) + + expected_paths = { + target1 / 'part=a' / 'part-0.feather', + target1 / 'part=b' / 'part-0.feather' + } + paths_written_set = set(map(pathlib.Path, paths_written)) + assert paths_written_set == expected_paths + + target2 = tempdir / 'partitioned2' + ds.write_dataset( + dataset, target2, format="feather", partitioning=partitioning, + use_threads=False + ) + + # check that reading in gives same result + result1 = ds.dataset(target1, format="feather", partitioning=partitioning) + result2 = ds.dataset(target2, format="feather", partitioning=partitioning) + assert result1.to_table().equals(result2.to_table()) + + +def test_write_table(tempdir): + table = pa.table([ + pa.array(range(20)), pa.array(np.random.randn(20)), + pa.array(np.repeat(['a', 'b'], 10)) + ], names=["f1", "f2", "part"]) + + base_dir = tempdir / 'single' + ds.write_dataset(table, base_dir, + basename_template='dat_{i}.arrow', format="feather") + # check that all files are present + file_paths = list(base_dir.rglob("*")) + expected_paths = [base_dir / "dat_0.arrow"] + assert set(file_paths) == set(expected_paths) + # check Table roundtrip + result = ds.dataset(base_dir, format="ipc").to_table() + assert result.equals(table) + + # with partitioning + base_dir = tempdir / 'partitioned' + expected_paths = [ + base_dir / "part=a", base_dir / "part=a" / "dat_0.arrow", + base_dir / "part=b", base_dir / "part=b" / "dat_0.arrow" + ] + + visited_paths = [] + + def file_visitor(written_file): + visited_paths.append(written_file.path) + + partitioning = ds.partitioning( + pa.schema([("part", pa.string())]), flavor="hive") + ds.write_dataset(table, base_dir, format="feather", + basename_template='dat_{i}.arrow', + partitioning=partitioning, file_visitor=file_visitor) + file_paths = list(base_dir.rglob("*")) + assert set(file_paths) == set(expected_paths) + result = ds.dataset(base_dir, format="ipc", partitioning=partitioning) + assert result.to_table().equals(table) + assert len(visited_paths) == 2 + for visited_path in visited_paths: + assert pathlib.Path(visited_path) in expected_paths + + +def test_write_table_multiple_fragments(tempdir): + table = pa.table([ + pa.array(range(10)), pa.array(np.random.randn(10)), + pa.array(np.repeat(['a', 'b'], 5)) + ], names=["f1", "f2", "part"]) + table = pa.concat_tables([table]*2) + + # Table with multiple batches written as single Fragment by default + base_dir = tempdir / 'single' + ds.write_dataset(table, base_dir, format="feather") + assert set(base_dir.rglob("*")) == set([base_dir / "part-0.feather"]) + assert ds.dataset(base_dir, format="ipc").to_table().equals(table) + + # Same for single-element list of Table + base_dir = tempdir / 'single-list' + ds.write_dataset([table], base_dir, format="feather") + assert set(base_dir.rglob("*")) == set([base_dir / "part-0.feather"]) + assert ds.dataset(base_dir, format="ipc").to_table().equals(table) + + # Provide list of batches to write multiple fragments + base_dir = tempdir / 'multiple' + ds.write_dataset(table.to_batches(), base_dir, format="feather") + assert set(base_dir.rglob("*")) == set( + [base_dir / "part-0.feather"]) + assert ds.dataset(base_dir, format="ipc").to_table().equals(table) + + # Provide list of tables to write multiple fragments + base_dir = tempdir / 'multiple-table' + ds.write_dataset([table, table], base_dir, format="feather") + assert set(base_dir.rglob("*")) == set( + [base_dir / "part-0.feather"]) + assert ds.dataset(base_dir, format="ipc").to_table().equals( + pa.concat_tables([table]*2) + ) + + +def test_write_iterable(tempdir): + table = pa.table([ + pa.array(range(20)), pa.array(np.random.randn(20)), + pa.array(np.repeat(['a', 'b'], 10)) + ], names=["f1", "f2", "part"]) + + base_dir = tempdir / 'inmemory_iterable' + ds.write_dataset((batch for batch in table.to_batches()), base_dir, + schema=table.schema, + basename_template='dat_{i}.arrow', format="feather") + result = ds.dataset(base_dir, format="ipc").to_table() + assert result.equals(table) + + base_dir = tempdir / 'inmemory_reader' + reader = pa.ipc.RecordBatchReader.from_batches(table.schema, + table.to_batches()) + ds.write_dataset(reader, base_dir, + basename_template='dat_{i}.arrow', format="feather") + result = ds.dataset(base_dir, format="ipc").to_table() + assert result.equals(table) + + +def test_write_scanner(tempdir, dataset_reader): + if not dataset_reader.use_async: + pytest.skip( + ('ARROW-13338: Write dataset with scanner does not' + ' support synchronous scan')) + + table = pa.table([ + pa.array(range(20)), pa.array(np.random.randn(20)), + pa.array(np.repeat(['a', 'b'], 10)) + ], names=["f1", "f2", "part"]) + dataset = ds.dataset(table) + + base_dir = tempdir / 'dataset_from_scanner' + ds.write_dataset(dataset_reader.scanner( + dataset), base_dir, format="feather") + result = dataset_reader.to_table(ds.dataset(base_dir, format="ipc")) + assert result.equals(table) + + # scanner with different projected_schema + base_dir = tempdir / 'dataset_from_scanner2' + ds.write_dataset(dataset_reader.scanner(dataset, columns=["f1"]), + base_dir, format="feather") + result = dataset_reader.to_table(ds.dataset(base_dir, format="ipc")) + assert result.equals(table.select(["f1"])) + + # schema not allowed when writing a scanner + with pytest.raises(ValueError, match="Cannot specify a schema"): + ds.write_dataset(dataset_reader.scanner(dataset), base_dir, + schema=table.schema, format="feather") + + +def test_write_table_partitioned_dict(tempdir): + # ensure writing table partitioned on a dictionary column works without + # specifying the dictionary values explicitly + table = pa.table([ + pa.array(range(20)), + pa.array(np.repeat(['a', 'b'], 10)).dictionary_encode(), + ], names=['col', 'part']) + + partitioning = ds.partitioning(table.select(["part"]).schema) + + base_dir = tempdir / "dataset" + ds.write_dataset( + table, base_dir, format="feather", partitioning=partitioning + ) + + # check roundtrip + partitioning_read = ds.DirectoryPartitioning.discover( + ["part"], infer_dictionary=True) + result = ds.dataset( + base_dir, format="ipc", partitioning=partitioning_read + ).to_table() + assert result.equals(table) + + +@pytest.mark.parquet +def test_write_dataset_parquet(tempdir): + import pyarrow.parquet as pq + + table = pa.table([ + pa.array(range(20)), pa.array(np.random.randn(20)), + pa.array(np.repeat(['a', 'b'], 10)) + ], names=["f1", "f2", "part"]) + + # using default "parquet" format string + + base_dir = tempdir / 'parquet_dataset' + ds.write_dataset(table, base_dir, format="parquet") + # check that all files are present + file_paths = list(base_dir.rglob("*")) + expected_paths = [base_dir / "part-0.parquet"] + assert set(file_paths) == set(expected_paths) + # check Table roundtrip + result = ds.dataset(base_dir, format="parquet").to_table() + assert result.equals(table) + + # using custom options + for version in ["1.0", "2.4", "2.6"]: + format = ds.ParquetFileFormat() + opts = format.make_write_options(version=version) + base_dir = tempdir / 'parquet_dataset_version{0}'.format(version) + ds.write_dataset(table, base_dir, format=format, file_options=opts) + meta = pq.read_metadata(base_dir / "part-0.parquet") + expected_version = "1.0" if version == "1.0" else "2.6" + assert meta.format_version == expected_version + + +def test_write_dataset_csv(tempdir): + table = pa.table([ + pa.array(range(20)), pa.array(np.random.randn(20)), + pa.array(np.repeat(['a', 'b'], 10)) + ], names=["f1", "f2", "chr1"]) + + base_dir = tempdir / 'csv_dataset' + ds.write_dataset(table, base_dir, format="csv") + # check that all files are present + file_paths = list(base_dir.rglob("*")) + expected_paths = [base_dir / "part-0.csv"] + assert set(file_paths) == set(expected_paths) + # check Table roundtrip + result = ds.dataset(base_dir, format="csv").to_table() + assert result.equals(table) + + # using custom options + format = ds.CsvFileFormat(read_options=pyarrow.csv.ReadOptions( + column_names=table.schema.names)) + opts = format.make_write_options(include_header=False) + base_dir = tempdir / 'csv_dataset_noheader' + ds.write_dataset(table, base_dir, format=format, file_options=opts) + result = ds.dataset(base_dir, format=format).to_table() + assert result.equals(table) + + +@pytest.mark.parquet +def test_write_dataset_parquet_file_visitor(tempdir): + table = pa.table([ + pa.array(range(20)), pa.array(np.random.randn(20)), + pa.array(np.repeat(['a', 'b'], 10)) + ], names=["f1", "f2", "part"]) + + visitor_called = False + + def file_visitor(written_file): + nonlocal visitor_called + if (written_file.metadata is not None and + written_file.metadata.num_columns == 3): + visitor_called = True + + base_dir = tempdir / 'parquet_dataset' + ds.write_dataset(table, base_dir, format="parquet", + file_visitor=file_visitor) + + assert visitor_called + + +def test_partition_dataset_parquet_file_visitor(tempdir): + f1_vals = [item for chunk in range(4) for item in [chunk] * 10] + f2_vals = [item*10 for chunk in range(4) for item in [chunk] * 10] + table = pa.table({'f1': f1_vals, 'f2': f2_vals, + 'part': np.repeat(['a', 'b'], 20)}) + + root_path = tempdir / 'partitioned' + partitioning = ds.partitioning( + pa.schema([("part", pa.string())]), flavor="hive") + + paths_written = [] + + sample_metadata = None + + def file_visitor(written_file): + nonlocal sample_metadata + if written_file.metadata: + sample_metadata = written_file.metadata + paths_written.append(written_file.path) + + ds.write_dataset( + table, root_path, format="parquet", partitioning=partitioning, + use_threads=True, file_visitor=file_visitor + ) + + expected_paths = { + root_path / 'part=a' / 'part-0.parquet', + root_path / 'part=b' / 'part-0.parquet' + } + paths_written_set = set(map(pathlib.Path, paths_written)) + assert paths_written_set == expected_paths + assert sample_metadata is not None + assert sample_metadata.num_columns == 2 + + +@pytest.mark.parquet +@pytest.mark.pandas +def test_write_dataset_arrow_schema_metadata(tempdir): + # ensure we serialize ARROW schema in the parquet metadata, to have a + # correct roundtrip (e.g. preserve non-UTC timezone) + import pyarrow.parquet as pq + + table = pa.table({"a": [pd.Timestamp("2012-01-01", tz="Europe/Brussels")]}) + assert table["a"].type.tz == "Europe/Brussels" + + ds.write_dataset(table, tempdir, format="parquet") + result = pq.read_table(tempdir / "part-0.parquet") + assert result["a"].type.tz == "Europe/Brussels" + + +def test_write_dataset_schema_metadata(tempdir): + # ensure that schema metadata gets written + from pyarrow import feather + + table = pa.table({'a': [1, 2, 3]}) + table = table.replace_schema_metadata({b'key': b'value'}) + ds.write_dataset(table, tempdir, format="feather") + + schema = feather.read_table(tempdir / "part-0.feather").schema + assert schema.metadata == {b'key': b'value'} + + +@pytest.mark.parquet +def test_write_dataset_schema_metadata_parquet(tempdir): + # ensure that schema metadata gets written + import pyarrow.parquet as pq + + table = pa.table({'a': [1, 2, 3]}) + table = table.replace_schema_metadata({b'key': b'value'}) + ds.write_dataset(table, tempdir, format="parquet") + + schema = pq.read_table(tempdir / "part-0.parquet").schema + assert schema.metadata == {b'key': b'value'} + + +@pytest.mark.parquet +@pytest.mark.s3 +def test_write_dataset_s3(s3_example_simple): + # write dataset with s3 filesystem + _, _, fs, _, host, port, access_key, secret_key = s3_example_simple + uri_template = ( + "s3://{}:{}@{{}}?scheme=http&endpoint_override={}:{}".format( + access_key, secret_key, host, port) + ) + + table = pa.table([ + pa.array(range(20)), pa.array(np.random.randn(20)), + pa.array(np.repeat(['a', 'b'], 10))], + names=["f1", "f2", "part"] + ) + part = ds.partitioning(pa.schema([("part", pa.string())]), flavor="hive") + + # writing with filesystem object + ds.write_dataset( + table, "mybucket/dataset", filesystem=fs, format="feather", + partitioning=part + ) + # check rountrip + result = ds.dataset( + "mybucket/dataset", filesystem=fs, format="ipc", partitioning="hive" + ).to_table() + assert result.equals(table) + + # writing with URI + uri = uri_template.format("mybucket/dataset2") + ds.write_dataset(table, uri, format="feather", partitioning=part) + # check rountrip + result = ds.dataset( + "mybucket/dataset2", filesystem=fs, format="ipc", partitioning="hive" + ).to_table() + assert result.equals(table) + + # writing with path + URI as filesystem + uri = uri_template.format("mybucket") + ds.write_dataset( + table, "dataset3", filesystem=uri, format="feather", partitioning=part + ) + # check rountrip + result = ds.dataset( + "mybucket/dataset3", filesystem=fs, format="ipc", partitioning="hive" + ).to_table() + assert result.equals(table) + + +@pytest.mark.parquet +def test_dataset_null_to_dictionary_cast(tempdir, dataset_reader): + # ARROW-12420 + import pyarrow.parquet as pq + + table = pa.table({"a": [None, None]}) + pq.write_table(table, tempdir / "test.parquet") + + schema = pa.schema([ + pa.field("a", pa.dictionary(pa.int32(), pa.string())) + ]) + fsds = ds.FileSystemDataset.from_paths( + paths=[tempdir / "test.parquet"], + schema=schema, + format=ds.ParquetFileFormat(), + filesystem=fs.LocalFileSystem(), + ) + table = dataset_reader.to_table(fsds) + assert table.schema == schema diff --git a/src/arrow/python/pyarrow/tests/test_deprecations.py b/src/arrow/python/pyarrow/tests/test_deprecations.py new file mode 100644 index 000000000..b16528937 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_deprecations.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Check that various deprecation warnings are raised + +# flake8: noqa + +import pyarrow as pa +import pytest diff --git a/src/arrow/python/pyarrow/tests/test_extension_type.py b/src/arrow/python/pyarrow/tests/test_extension_type.py new file mode 100644 index 000000000..4ea6e94e4 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_extension_type.py @@ -0,0 +1,779 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pickle +import weakref + +import numpy as np +import pyarrow as pa + +import pytest + + +class IntegerType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, pa.int64()) + + def __reduce__(self): + return IntegerType, () + + +class UuidType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, pa.binary(16)) + + def __reduce__(self): + return UuidType, () + + +class ParamExtType(pa.PyExtensionType): + + def __init__(self, width): + self._width = width + pa.PyExtensionType.__init__(self, pa.binary(width)) + + @property + def width(self): + return self._width + + def __reduce__(self): + return ParamExtType, (self.width,) + + +class MyStructType(pa.PyExtensionType): + storage_type = pa.struct([('left', pa.int64()), + ('right', pa.int64())]) + + def __init__(self): + pa.PyExtensionType.__init__(self, self.storage_type) + + def __reduce__(self): + return MyStructType, () + + +class MyListType(pa.PyExtensionType): + + def __init__(self, storage_type): + pa.PyExtensionType.__init__(self, storage_type) + + def __reduce__(self): + return MyListType, (self.storage_type,) + + +def ipc_write_batch(batch): + stream = pa.BufferOutputStream() + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + writer.close() + return stream.getvalue() + + +def ipc_read_batch(buf): + reader = pa.RecordBatchStreamReader(buf) + return reader.read_next_batch() + + +def test_ext_type_basics(): + ty = UuidType() + assert ty.extension_name == "arrow.py_extension_type" + + +def test_ext_type_str(): + ty = IntegerType() + expected = "extension<arrow.py_extension_type<IntegerType>>" + assert str(ty) == expected + assert pa.DataType.__str__(ty) == expected + + +def test_ext_type_repr(): + ty = IntegerType() + assert repr(ty) == "IntegerType(DataType(int64))" + + +def test_ext_type__lifetime(): + ty = UuidType() + wr = weakref.ref(ty) + del ty + assert wr() is None + + +def test_ext_type__storage_type(): + ty = UuidType() + assert ty.storage_type == pa.binary(16) + assert ty.__class__ is UuidType + ty = ParamExtType(5) + assert ty.storage_type == pa.binary(5) + assert ty.__class__ is ParamExtType + + +def test_uuid_type_pickle(): + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + ty = UuidType() + ser = pickle.dumps(ty, protocol=proto) + del ty + ty = pickle.loads(ser) + wr = weakref.ref(ty) + assert ty.extension_name == "arrow.py_extension_type" + del ty + assert wr() is None + + +def test_ext_type_equality(): + a = ParamExtType(5) + b = ParamExtType(6) + c = ParamExtType(6) + assert a != b + assert b == c + d = UuidType() + e = UuidType() + assert a != d + assert d == e + + +def test_ext_array_basics(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + arr.validate() + assert arr.type is ty + assert arr.storage.equals(storage) + + +def test_ext_array_lifetime(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + + refs = [weakref.ref(ty), weakref.ref(arr), weakref.ref(storage)] + del ty, storage, arr + for ref in refs: + assert ref() is None + + +def test_ext_array_to_pylist(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar", None], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + + assert arr.to_pylist() == [b"foo", b"bar", None] + + +def test_ext_array_errors(): + ty = ParamExtType(4) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + with pytest.raises(TypeError, match="Incompatible storage type"): + pa.ExtensionArray.from_storage(ty, storage) + + +def test_ext_array_equality(): + storage1 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) + storage2 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) + storage3 = pa.array([], type=pa.binary(16)) + ty1 = UuidType() + ty2 = ParamExtType(16) + + a = pa.ExtensionArray.from_storage(ty1, storage1) + b = pa.ExtensionArray.from_storage(ty1, storage2) + assert a.equals(b) + c = pa.ExtensionArray.from_storage(ty1, storage3) + assert not a.equals(c) + d = pa.ExtensionArray.from_storage(ty2, storage1) + assert not a.equals(d) + e = pa.ExtensionArray.from_storage(ty2, storage2) + assert d.equals(e) + f = pa.ExtensionArray.from_storage(ty2, storage3) + assert not d.equals(f) + + +def test_ext_array_wrap_array(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar", None], type=pa.binary(3)) + arr = ty.wrap_array(storage) + arr.validate(full=True) + assert isinstance(arr, pa.ExtensionArray) + assert arr.type == ty + assert arr.storage == storage + + storage = pa.chunked_array([[b"abc", b"def"], [b"ghi"]], + type=pa.binary(3)) + arr = ty.wrap_array(storage) + arr.validate(full=True) + assert isinstance(arr, pa.ChunkedArray) + assert arr.type == ty + assert arr.chunk(0).storage == storage.chunk(0) + assert arr.chunk(1).storage == storage.chunk(1) + + # Wrong storage type + storage = pa.array([b"foo", b"bar", None]) + with pytest.raises(TypeError, match="Incompatible storage type"): + ty.wrap_array(storage) + + # Not an array or chunked array + with pytest.raises(TypeError, match="Expected array or chunked array"): + ty.wrap_array(None) + + +def test_ext_scalar_from_array(): + data = [b"0123456789abcdef", b"0123456789abcdef", + b"zyxwvutsrqponmlk", None] + storage = pa.array(data, type=pa.binary(16)) + ty1 = UuidType() + ty2 = ParamExtType(16) + + a = pa.ExtensionArray.from_storage(ty1, storage) + b = pa.ExtensionArray.from_storage(ty2, storage) + + scalars_a = list(a) + assert len(scalars_a) == 4 + + for s, val in zip(scalars_a, data): + assert isinstance(s, pa.ExtensionScalar) + assert s.is_valid == (val is not None) + assert s.type == ty1 + if val is not None: + assert s.value == pa.scalar(val, storage.type) + else: + assert s.value is None + assert s.as_py() == val + + scalars_b = list(b) + assert len(scalars_b) == 4 + + for sa, sb in zip(scalars_a, scalars_b): + assert sa.is_valid == sb.is_valid + assert sa.as_py() == sb.as_py() + assert sa != sb + + +def test_ext_scalar_from_storage(): + ty = UuidType() + + s = pa.ExtensionScalar.from_storage(ty, None) + assert isinstance(s, pa.ExtensionScalar) + assert s.type == ty + assert s.is_valid is False + assert s.value is None + + s = pa.ExtensionScalar.from_storage(ty, b"0123456789abcdef") + assert isinstance(s, pa.ExtensionScalar) + assert s.type == ty + assert s.is_valid is True + assert s.value == pa.scalar(b"0123456789abcdef", ty.storage_type) + + s = pa.ExtensionScalar.from_storage(ty, pa.scalar(None, ty.storage_type)) + assert isinstance(s, pa.ExtensionScalar) + assert s.type == ty + assert s.is_valid is False + assert s.value is None + + s = pa.ExtensionScalar.from_storage( + ty, pa.scalar(b"0123456789abcdef", ty.storage_type)) + assert isinstance(s, pa.ExtensionScalar) + assert s.type == ty + assert s.is_valid is True + assert s.value == pa.scalar(b"0123456789abcdef", ty.storage_type) + + +def test_ext_array_pickling(): + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + ser = pickle.dumps(arr, protocol=proto) + del ty, storage, arr + arr = pickle.loads(ser) + arr.validate() + assert isinstance(arr, pa.ExtensionArray) + assert arr.type == ParamExtType(3) + assert arr.type.storage_type == pa.binary(3) + assert arr.storage.type == pa.binary(3) + assert arr.storage.to_pylist() == [b"foo", b"bar"] + + +def test_ext_array_conversion_to_numpy(): + storage1 = pa.array([1, 2, 3], type=pa.int64()) + storage2 = pa.array([b"123", b"456", b"789"], type=pa.binary(3)) + ty1 = IntegerType() + ty2 = ParamExtType(3) + + arr1 = pa.ExtensionArray.from_storage(ty1, storage1) + arr2 = pa.ExtensionArray.from_storage(ty2, storage2) + + result = arr1.to_numpy() + expected = np.array([1, 2, 3], dtype="int64") + np.testing.assert_array_equal(result, expected) + + with pytest.raises(ValueError, match="zero_copy_only was True"): + arr2.to_numpy() + result = arr2.to_numpy(zero_copy_only=False) + expected = np.array([b"123", b"456", b"789"]) + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.pandas +def test_ext_array_conversion_to_pandas(): + import pandas as pd + + storage1 = pa.array([1, 2, 3], type=pa.int64()) + storage2 = pa.array([b"123", b"456", b"789"], type=pa.binary(3)) + ty1 = IntegerType() + ty2 = ParamExtType(3) + + arr1 = pa.ExtensionArray.from_storage(ty1, storage1) + arr2 = pa.ExtensionArray.from_storage(ty2, storage2) + + result = arr1.to_pandas() + expected = pd.Series([1, 2, 3], dtype="int64") + pd.testing.assert_series_equal(result, expected) + + result = arr2.to_pandas() + expected = pd.Series([b"123", b"456", b"789"], dtype=object) + pd.testing.assert_series_equal(result, expected) + + +def test_cast_kernel_on_extension_arrays(): + # test array casting + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(IntegerType(), storage) + + # test that no allocation happens during identity cast + allocated_before_cast = pa.total_allocated_bytes() + casted = arr.cast(pa.int64()) + assert pa.total_allocated_bytes() == allocated_before_cast + + cases = [ + (pa.int64(), pa.Int64Array), + (pa.int32(), pa.Int32Array), + (pa.int16(), pa.Int16Array), + (pa.uint64(), pa.UInt64Array), + (pa.uint32(), pa.UInt32Array), + (pa.uint16(), pa.UInt16Array) + ] + for typ, klass in cases: + casted = arr.cast(typ) + assert casted.type == typ + assert isinstance(casted, klass) + + # test chunked array casting + arr = pa.chunked_array([arr, arr]) + casted = arr.cast(pa.int16()) + assert casted.type == pa.int16() + assert isinstance(casted, pa.ChunkedArray) + + +def test_casting_to_extension_type_raises(): + arr = pa.array([1, 2, 3, 4], pa.int64()) + with pytest.raises(pa.ArrowNotImplementedError): + arr.cast(IntegerType()) + + +def example_batch(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + return pa.RecordBatch.from_arrays([arr], ["exts"]) + + +def check_example_batch(batch): + arr = batch.column(0) + assert isinstance(arr, pa.ExtensionArray) + assert arr.type.storage_type == pa.binary(3) + assert arr.storage.to_pylist() == [b"foo", b"bar"] + return arr + + +def test_ipc(): + batch = example_batch() + buf = ipc_write_batch(batch) + del batch + + batch = ipc_read_batch(buf) + arr = check_example_batch(batch) + assert arr.type == ParamExtType(3) + + +def test_ipc_unknown_type(): + batch = example_batch() + buf = ipc_write_batch(batch) + del batch + + orig_type = ParamExtType + try: + # Simulate the original Python type being unavailable. + # Deserialization should not fail but return a placeholder type. + del globals()['ParamExtType'] + + batch = ipc_read_batch(buf) + arr = check_example_batch(batch) + assert isinstance(arr.type, pa.UnknownExtensionType) + + # Can be serialized again + buf2 = ipc_write_batch(batch) + del batch, arr + + batch = ipc_read_batch(buf2) + arr = check_example_batch(batch) + assert isinstance(arr.type, pa.UnknownExtensionType) + finally: + globals()['ParamExtType'] = orig_type + + # Deserialize again with the type restored + batch = ipc_read_batch(buf2) + arr = check_example_batch(batch) + assert arr.type == ParamExtType(3) + + +class PeriodArray(pa.ExtensionArray): + pass + + +class PeriodType(pa.ExtensionType): + def __init__(self, freq): + # attributes need to be set first before calling + # super init (as that calls serialize) + self._freq = freq + pa.ExtensionType.__init__(self, pa.int64(), 'test.period') + + @property + def freq(self): + return self._freq + + def __arrow_ext_serialize__(self): + return "freq={}".format(self.freq).encode() + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + serialized = serialized.decode() + assert serialized.startswith("freq=") + freq = serialized.split('=')[1] + return PeriodType(freq) + + def __eq__(self, other): + if isinstance(other, pa.BaseExtensionType): + return (type(self) == type(other) and + self.freq == other.freq) + else: + return NotImplemented + + +class PeriodTypeWithClass(PeriodType): + def __init__(self, freq): + PeriodType.__init__(self, freq) + + def __arrow_ext_class__(self): + return PeriodArray + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + freq = PeriodType.__arrow_ext_deserialize__( + storage_type, serialized).freq + return PeriodTypeWithClass(freq) + + +@pytest.fixture(params=[PeriodType('D'), PeriodTypeWithClass('D')]) +def registered_period_type(request): + # setup + period_type = request.param + period_class = period_type.__arrow_ext_class__() + pa.register_extension_type(period_type) + yield period_type, period_class + # teardown + try: + pa.unregister_extension_type('test.period') + except KeyError: + pass + + +def test_generic_ext_type(): + period_type = PeriodType('D') + assert period_type.extension_name == "test.period" + assert period_type.storage_type == pa.int64() + # default ext_class expected. + assert period_type.__arrow_ext_class__() == pa.ExtensionArray + + +def test_generic_ext_type_ipc(registered_period_type): + period_type, period_class = registered_period_type + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(period_type, storage) + batch = pa.RecordBatch.from_arrays([arr], ["ext"]) + # check the built array has exactly the expected clss + assert type(arr) == period_class + + buf = ipc_write_batch(batch) + del batch + batch = ipc_read_batch(buf) + + result = batch.column(0) + # check the deserialized array class is the expected one + assert type(result) == period_class + assert result.type.extension_name == "test.period" + assert arr.storage.to_pylist() == [1, 2, 3, 4] + + # we get back an actual PeriodType + assert isinstance(result.type, PeriodType) + assert result.type.freq == 'D' + assert result.type == period_type + + # using different parametrization as how it was registered + period_type_H = period_type.__class__('H') + assert period_type_H.extension_name == "test.period" + assert period_type_H.freq == 'H' + + arr = pa.ExtensionArray.from_storage(period_type_H, storage) + batch = pa.RecordBatch.from_arrays([arr], ["ext"]) + + buf = ipc_write_batch(batch) + del batch + batch = ipc_read_batch(buf) + result = batch.column(0) + assert isinstance(result.type, PeriodType) + assert result.type.freq == 'H' + assert type(result) == period_class + + +def test_generic_ext_type_ipc_unknown(registered_period_type): + period_type, _ = registered_period_type + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(period_type, storage) + batch = pa.RecordBatch.from_arrays([arr], ["ext"]) + + buf = ipc_write_batch(batch) + del batch + + # unregister type before loading again => reading unknown extension type + # as plain array (but metadata in schema's field are preserved) + pa.unregister_extension_type('test.period') + + batch = ipc_read_batch(buf) + result = batch.column(0) + + assert isinstance(result, pa.Int64Array) + ext_field = batch.schema.field('ext') + assert ext_field.metadata == { + b'ARROW:extension:metadata': b'freq=D', + b'ARROW:extension:name': b'test.period' + } + + +def test_generic_ext_type_equality(): + period_type = PeriodType('D') + assert period_type.extension_name == "test.period" + + period_type2 = PeriodType('D') + period_type3 = PeriodType('H') + assert period_type == period_type2 + assert not period_type == period_type3 + + +def test_generic_ext_type_register(registered_period_type): + # test that trying to register other type does not segfault + with pytest.raises(TypeError): + pa.register_extension_type(pa.string()) + + # register second time raises KeyError + period_type = PeriodType('D') + with pytest.raises(KeyError): + pa.register_extension_type(period_type) + + +@pytest.mark.parquet +def test_parquet_period(tmpdir, registered_period_type): + # Parquet support for primitive extension types + period_type, period_class = registered_period_type + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(period_type, storage) + table = pa.table([arr], names=["ext"]) + + import pyarrow.parquet as pq + + filename = tmpdir / 'period_extension_type.parquet' + pq.write_table(table, filename) + + # Stored in parquet as storage type but with extension metadata saved + # in the serialized arrow schema + meta = pq.read_metadata(filename) + assert meta.schema.column(0).physical_type == "INT64" + assert b"ARROW:schema" in meta.metadata + + import base64 + decoded_schema = base64.b64decode(meta.metadata[b"ARROW:schema"]) + schema = pa.ipc.read_schema(pa.BufferReader(decoded_schema)) + # Since the type could be reconstructed, the extension type metadata is + # absent. + assert schema.field("ext").metadata == {} + + # When reading in, properly create extension type if it is registered + result = pq.read_table(filename) + assert result.schema.field("ext").type == period_type + assert result.schema.field("ext").metadata == {} + # Get the exact array class defined by the registered type. + result_array = result.column("ext").chunk(0) + assert type(result_array) is period_class + + # When the type is not registered, read in as storage type + pa.unregister_extension_type(period_type.extension_name) + result = pq.read_table(filename) + assert result.schema.field("ext").type == pa.int64() + # The extension metadata is present for roundtripping. + assert result.schema.field("ext").metadata == { + b'ARROW:extension:metadata': b'freq=D', + b'ARROW:extension:name': b'test.period' + } + + +@pytest.mark.parquet +def test_parquet_extension_with_nested_storage(tmpdir): + # Parquet support for extension types with nested storage type + import pyarrow.parquet as pq + + struct_array = pa.StructArray.from_arrays( + [pa.array([0, 1], type="int64"), pa.array([4, 5], type="int64")], + names=["left", "right"]) + list_array = pa.array([[1, 2, 3], [4, 5]], type=pa.list_(pa.int32())) + + mystruct_array = pa.ExtensionArray.from_storage(MyStructType(), + struct_array) + mylist_array = pa.ExtensionArray.from_storage( + MyListType(list_array.type), list_array) + + orig_table = pa.table({'structs': mystruct_array, + 'lists': mylist_array}) + filename = tmpdir / 'nested_extension_storage.parquet' + pq.write_table(orig_table, filename) + + table = pq.read_table(filename) + assert table.column('structs').type == mystruct_array.type + assert table.column('lists').type == mylist_array.type + assert table == orig_table + + +@pytest.mark.parquet +def test_parquet_nested_extension(tmpdir): + # Parquet support for extension types nested in struct or list + import pyarrow.parquet as pq + + ext_type = IntegerType() + storage = pa.array([4, 5, 6, 7], type=pa.int64()) + ext_array = pa.ExtensionArray.from_storage(ext_type, storage) + + # Struct of extensions + struct_array = pa.StructArray.from_arrays( + [storage, ext_array], + names=['ints', 'exts']) + + orig_table = pa.table({'structs': struct_array}) + filename = tmpdir / 'struct_of_ext.parquet' + pq.write_table(orig_table, filename) + + table = pq.read_table(filename) + assert table.column(0).type == struct_array.type + assert table == orig_table + + # List of extensions + list_array = pa.ListArray.from_arrays([0, 1, None, 3], ext_array) + + orig_table = pa.table({'lists': list_array}) + filename = tmpdir / 'list_of_ext.parquet' + pq.write_table(orig_table, filename) + + table = pq.read_table(filename) + assert table.column(0).type == list_array.type + assert table == orig_table + + # Large list of extensions + list_array = pa.LargeListArray.from_arrays([0, 1, None, 3], ext_array) + + orig_table = pa.table({'lists': list_array}) + filename = tmpdir / 'list_of_ext.parquet' + pq.write_table(orig_table, filename) + + table = pq.read_table(filename) + assert table.column(0).type == list_array.type + assert table == orig_table + + +@pytest.mark.parquet +def test_parquet_extension_nested_in_extension(tmpdir): + # Parquet support for extension<list<extension>> + import pyarrow.parquet as pq + + inner_ext_type = IntegerType() + inner_storage = pa.array([4, 5, 6, 7], type=pa.int64()) + inner_ext_array = pa.ExtensionArray.from_storage(inner_ext_type, + inner_storage) + + list_array = pa.ListArray.from_arrays([0, 1, None, 3], inner_ext_array) + mylist_array = pa.ExtensionArray.from_storage( + MyListType(list_array.type), list_array) + + orig_table = pa.table({'lists': mylist_array}) + filename = tmpdir / 'ext_of_list_of_ext.parquet' + pq.write_table(orig_table, filename) + + table = pq.read_table(filename) + assert table.column(0).type == mylist_array.type + assert table == orig_table + + +def test_to_numpy(): + period_type = PeriodType('D') + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(period_type, storage) + + expected = storage.to_numpy() + result = arr.to_numpy() + np.testing.assert_array_equal(result, expected) + + result = np.asarray(arr) + np.testing.assert_array_equal(result, expected) + + # chunked array + a1 = pa.chunked_array([arr, arr]) + a2 = pa.chunked_array([arr, arr], type=period_type) + expected = np.hstack([expected, expected]) + + for charr in [a1, a2]: + assert charr.type == period_type + for result in [np.asarray(charr), charr.to_numpy()]: + assert result.dtype == np.int64 + np.testing.assert_array_equal(result, expected) + + # zero chunks + charr = pa.chunked_array([], type=period_type) + assert charr.type == period_type + + for result in [np.asarray(charr), charr.to_numpy()]: + assert result.dtype == np.int64 + np.testing.assert_array_equal(result, np.array([], dtype='int64')) + + +def test_empty_take(): + # https://issues.apache.org/jira/browse/ARROW-13474 + ext_type = IntegerType() + storage = pa.array([], type=pa.int64()) + empty_arr = pa.ExtensionArray.from_storage(ext_type, storage) + + result = empty_arr.filter(pa.array([], pa.bool_())) + assert len(result) == 0 + assert result.equals(empty_arr) + + result = empty_arr.take(pa.array([], pa.int32())) + assert len(result) == 0 + assert result.equals(empty_arr) diff --git a/src/arrow/python/pyarrow/tests/test_feather.py b/src/arrow/python/pyarrow/tests/test_feather.py new file mode 100644 index 000000000..3d0451ee3 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_feather.py @@ -0,0 +1,799 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import io +import os +import sys +import tempfile +import pytest +import hypothesis as h +import hypothesis.strategies as st + +import numpy as np + +import pyarrow as pa +import pyarrow.tests.strategies as past +from pyarrow.feather import (read_feather, write_feather, read_table, + FeatherDataset) + + +try: + from pandas.testing import assert_frame_equal + import pandas as pd + import pyarrow.pandas_compat +except ImportError: + pass + + +@pytest.fixture(scope='module') +def datadir(base_datadir): + return base_datadir / 'feather' + + +def random_path(prefix='feather_'): + return tempfile.mktemp(prefix=prefix) + + +@pytest.fixture(scope="module", params=[1, 2]) +def version(request): + yield request.param + + +@pytest.fixture(scope="module", params=[None, "uncompressed", "lz4", "zstd"]) +def compression(request): + if request.param in ['lz4', 'zstd'] and not pa.Codec.is_available( + request.param): + pytest.skip(f'{request.param} is not available') + yield request.param + + +TEST_FILES = None + + +def setup_module(module): + global TEST_FILES + TEST_FILES = [] + + +def teardown_module(module): + for path in TEST_FILES: + try: + os.remove(path) + except os.error: + pass + + +@pytest.mark.pandas +def test_file_not_exist(): + with pytest.raises(pa.ArrowIOError): + read_feather('test_invalid_file') + + +def _check_pandas_roundtrip(df, expected=None, path=None, + columns=None, use_threads=False, + version=None, compression=None, + compression_level=None): + if path is None: + path = random_path() + + TEST_FILES.append(path) + write_feather(df, path, compression=compression, + compression_level=compression_level, version=version) + if not os.path.exists(path): + raise Exception('file not written') + + result = read_feather(path, columns, use_threads=use_threads) + if expected is None: + expected = df + + assert_frame_equal(result, expected) + + +def _check_arrow_roundtrip(table, path=None, compression=None): + if path is None: + path = random_path() + + TEST_FILES.append(path) + write_feather(table, path, compression=compression) + if not os.path.exists(path): + raise Exception('file not written') + + result = read_table(path) + assert result.equals(table) + + +def _assert_error_on_write(df, exc, path=None, version=2): + # check that we are raising the exception + # on writing + + if path is None: + path = random_path() + + TEST_FILES.append(path) + + def f(): + write_feather(df, path, version=version) + + pytest.raises(exc, f) + + +def test_dataset(version): + num_values = (100, 100) + num_files = 5 + paths = [random_path() for i in range(num_files)] + data = { + "col_" + str(i): np.random.randn(num_values[0]) + for i in range(num_values[1]) + } + table = pa.table(data) + + TEST_FILES.extend(paths) + for index, path in enumerate(paths): + rows = ( + index * (num_values[0] // num_files), + (index + 1) * (num_values[0] // num_files), + ) + + write_feather(table[rows[0]: rows[1]], path, version=version) + + data = FeatherDataset(paths).read_table() + assert data.equals(table) + + +@pytest.mark.pandas +def test_float_no_nulls(version): + data = {} + numpy_dtypes = ['f4', 'f8'] + num_values = 100 + + for dtype in numpy_dtypes: + values = np.random.randn(num_values) + data[dtype] = values.astype(dtype) + + df = pd.DataFrame(data) + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_read_table(version): + num_values = (100, 100) + path = random_path() + + TEST_FILES.append(path) + + values = np.random.randint(0, 100, size=num_values) + columns = ['col_' + str(i) for i in range(100)] + table = pa.Table.from_arrays(values, columns) + + write_feather(table, path, version=version) + + result = read_table(path) + assert result.equals(table) + + # Test without memory mapping + result = read_table(path, memory_map=False) + assert result.equals(table) + + result = read_feather(path, memory_map=False) + assert_frame_equal(table.to_pandas(), result) + + +@pytest.mark.pandas +def test_float_nulls(version): + num_values = 100 + + path = random_path() + TEST_FILES.append(path) + + null_mask = np.random.randint(0, 10, size=num_values) < 3 + dtypes = ['f4', 'f8'] + expected_cols = [] + + arrays = [] + for name in dtypes: + values = np.random.randn(num_values).astype(name) + arrays.append(pa.array(values, mask=null_mask)) + + values[null_mask] = np.nan + + expected_cols.append(values) + + table = pa.table(arrays, names=dtypes) + _check_arrow_roundtrip(table) + + df = table.to_pandas() + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_integer_no_nulls(version): + data, arr = {}, [] + + numpy_dtypes = ['i1', 'i2', 'i4', 'i8', + 'u1', 'u2', 'u4', 'u8'] + num_values = 100 + + for dtype in numpy_dtypes: + values = np.random.randint(0, 100, size=num_values) + data[dtype] = values.astype(dtype) + arr.append(values.astype(dtype)) + + df = pd.DataFrame(data) + _check_pandas_roundtrip(df, version=version) + + table = pa.table(arr, names=numpy_dtypes) + _check_arrow_roundtrip(table) + + +@pytest.mark.pandas +def test_platform_numpy_integers(version): + data = {} + + numpy_dtypes = ['longlong'] + num_values = 100 + + for dtype in numpy_dtypes: + values = np.random.randint(0, 100, size=num_values) + data[dtype] = values.astype(dtype) + + df = pd.DataFrame(data) + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_integer_with_nulls(version): + # pandas requires upcast to float dtype + path = random_path() + TEST_FILES.append(path) + + int_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8'] + num_values = 100 + + arrays = [] + null_mask = np.random.randint(0, 10, size=num_values) < 3 + expected_cols = [] + for name in int_dtypes: + values = np.random.randint(0, 100, size=num_values) + arrays.append(pa.array(values, mask=null_mask)) + + expected = values.astype('f8') + expected[null_mask] = np.nan + + expected_cols.append(expected) + + table = pa.table(arrays, names=int_dtypes) + _check_arrow_roundtrip(table) + + df = table.to_pandas() + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_boolean_no_nulls(version): + num_values = 100 + + np.random.seed(0) + + df = pd.DataFrame({'bools': np.random.randn(num_values) > 0}) + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_boolean_nulls(version): + # pandas requires upcast to object dtype + path = random_path() + TEST_FILES.append(path) + + num_values = 100 + np.random.seed(0) + + mask = np.random.randint(0, 10, size=num_values) < 3 + values = np.random.randint(0, 10, size=num_values) < 5 + + table = pa.table([pa.array(values, mask=mask)], names=['bools']) + _check_arrow_roundtrip(table) + + df = table.to_pandas() + _check_pandas_roundtrip(df, version=version) + + +def test_buffer_bounds_error(version): + # ARROW-1676 + path = random_path() + TEST_FILES.append(path) + + for i in range(16, 256): + table = pa.Table.from_arrays( + [pa.array([None] + list(range(i)), type=pa.float64())], + names=["arr"] + ) + _check_arrow_roundtrip(table) + + +def test_boolean_object_nulls(version): + repeats = 100 + table = pa.Table.from_arrays( + [np.array([False, None, True] * repeats, dtype=object)], + names=["arr"] + ) + _check_arrow_roundtrip(table) + + +@pytest.mark.pandas +def test_delete_partial_file_on_error(version): + if sys.platform == 'win32': + pytest.skip('Windows hangs on to file handle for some reason') + + class CustomClass: + pass + + # strings will fail + df = pd.DataFrame( + { + 'numbers': range(5), + 'strings': [b'foo', None, 'bar', CustomClass(), np.nan]}, + columns=['numbers', 'strings']) + + path = random_path() + try: + write_feather(df, path, version=version) + except Exception: + pass + + assert not os.path.exists(path) + + +@pytest.mark.pandas +def test_strings(version): + repeats = 1000 + + # Mixed bytes, unicode, strings coerced to binary + values = [b'foo', None, 'bar', 'qux', np.nan] + df = pd.DataFrame({'strings': values * repeats}) + + ex_values = [b'foo', None, b'bar', b'qux', np.nan] + expected = pd.DataFrame({'strings': ex_values * repeats}) + _check_pandas_roundtrip(df, expected, version=version) + + # embedded nulls are ok + values = ['foo', None, 'bar', 'qux', None] + df = pd.DataFrame({'strings': values * repeats}) + expected = pd.DataFrame({'strings': values * repeats}) + _check_pandas_roundtrip(df, expected, version=version) + + values = ['foo', None, 'bar', 'qux', np.nan] + df = pd.DataFrame({'strings': values * repeats}) + expected = pd.DataFrame({'strings': values * repeats}) + _check_pandas_roundtrip(df, expected, version=version) + + +@pytest.mark.pandas +def test_empty_strings(version): + df = pd.DataFrame({'strings': [''] * 10}) + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_all_none(version): + df = pd.DataFrame({'all_none': [None] * 10}) + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_all_null_category(version): + # ARROW-1188 + df = pd.DataFrame({"A": (1, 2, 3), "B": (None, None, None)}) + df = df.assign(B=df.B.astype("category")) + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_multithreaded_read(version): + data = {'c{}'.format(i): [''] * 10 + for i in range(100)} + df = pd.DataFrame(data) + _check_pandas_roundtrip(df, use_threads=True, version=version) + + +@pytest.mark.pandas +def test_nan_as_null(version): + # Create a nan that is not numpy.nan + values = np.array(['foo', np.nan, np.nan * 2, 'bar'] * 10) + df = pd.DataFrame({'strings': values}) + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_category(version): + repeats = 1000 + values = ['foo', None, 'bar', 'qux', np.nan] + df = pd.DataFrame({'strings': values * repeats}) + df['strings'] = df['strings'].astype('category') + + values = ['foo', None, 'bar', 'qux', None] + expected = pd.DataFrame({'strings': pd.Categorical(values * repeats)}) + _check_pandas_roundtrip(df, expected, version=version) + + +@pytest.mark.pandas +def test_timestamp(version): + df = pd.DataFrame({'naive': pd.date_range('2016-03-28', periods=10)}) + df['with_tz'] = (df.naive.dt.tz_localize('utc') + .dt.tz_convert('America/Los_Angeles')) + + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_timestamp_with_nulls(version): + df = pd.DataFrame({'test': [pd.Timestamp(2016, 1, 1), + None, + pd.Timestamp(2016, 1, 3)]}) + df['with_tz'] = df.test.dt.tz_localize('utc') + + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +@pytest.mark.xfail(reason="not supported", raises=TypeError) +def test_timedelta_with_nulls_v1(): + df = pd.DataFrame({'test': [pd.Timedelta('1 day'), + None, + pd.Timedelta('3 day')]}) + _check_pandas_roundtrip(df, version=1) + + +@pytest.mark.pandas +def test_timedelta_with_nulls(): + df = pd.DataFrame({'test': [pd.Timedelta('1 day'), + None, + pd.Timedelta('3 day')]}) + _check_pandas_roundtrip(df, version=2) + + +@pytest.mark.pandas +def test_out_of_float64_timestamp_with_nulls(version): + df = pd.DataFrame( + {'test': pd.DatetimeIndex([1451606400000000001, + None, 14516064000030405])}) + df['with_tz'] = df.test.dt.tz_localize('utc') + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.pandas +def test_non_string_columns(version): + df = pd.DataFrame({0: [1, 2, 3, 4], + 1: [True, False, True, False]}) + + expected = df.rename(columns=str) + _check_pandas_roundtrip(df, expected, version=version) + + +@pytest.mark.pandas +@pytest.mark.skipif(not os.path.supports_unicode_filenames, + reason='unicode filenames not supported') +def test_unicode_filename(version): + # GH #209 + name = (b'Besa_Kavaj\xc3\xab.feather').decode('utf-8') + df = pd.DataFrame({'foo': [1, 2, 3, 4]}) + _check_pandas_roundtrip(df, path=random_path(prefix=name), + version=version) + + +@pytest.mark.pandas +def test_read_columns(version): + df = pd.DataFrame({ + 'foo': [1, 2, 3, 4], + 'boo': [5, 6, 7, 8], + 'woo': [1, 3, 5, 7] + }) + expected = df[['boo', 'woo']] + + _check_pandas_roundtrip(df, expected, version=version, + columns=['boo', 'woo']) + + +def test_overwritten_file(version): + path = random_path() + TEST_FILES.append(path) + + num_values = 100 + np.random.seed(0) + + values = np.random.randint(0, 10, size=num_values) + + table = pa.table({'ints': values}) + write_feather(table, path) + + table = pa.table({'more_ints': values[0:num_values//2]}) + _check_arrow_roundtrip(table, path=path) + + +@pytest.mark.pandas +def test_filelike_objects(version): + buf = io.BytesIO() + + # the copy makes it non-strided + df = pd.DataFrame(np.arange(12).reshape(4, 3), + columns=['a', 'b', 'c']).copy() + write_feather(df, buf, version=version) + + buf.seek(0) + + result = read_feather(buf) + assert_frame_equal(result, df) + + +@pytest.mark.pandas +@pytest.mark.filterwarnings("ignore:Sparse:FutureWarning") +@pytest.mark.filterwarnings("ignore:DataFrame.to_sparse:FutureWarning") +def test_sparse_dataframe(version): + if not pa.pandas_compat._pandas_api.has_sparse: + pytest.skip("version of pandas does not support SparseDataFrame") + # GH #221 + data = {'A': [0, 1, 2], + 'B': [1, 0, 1]} + df = pd.DataFrame(data).to_sparse(fill_value=1) + expected = df.to_dense() + _check_pandas_roundtrip(df, expected, version=version) + + +@pytest.mark.pandas +def test_duplicate_columns_pandas(): + + # https://github.com/wesm/feather/issues/53 + # not currently able to handle duplicate columns + df = pd.DataFrame(np.arange(12).reshape(4, 3), + columns=list('aaa')).copy() + _assert_error_on_write(df, ValueError) + + +def test_duplicate_columns(): + # only works for version 2 + table = pa.table([[1, 2, 3], [4, 5, 6], [7, 8, 9]], names=['a', 'a', 'b']) + _check_arrow_roundtrip(table) + _assert_error_on_write(table, ValueError, version=1) + + +@pytest.mark.pandas +def test_unsupported(): + # https://github.com/wesm/feather/issues/240 + # serializing actual python objects + + # custom python objects + class A: + pass + + df = pd.DataFrame({'a': [A(), A()]}) + _assert_error_on_write(df, ValueError) + + # non-strings + df = pd.DataFrame({'a': ['a', 1, 2.0]}) + _assert_error_on_write(df, TypeError) + + +@pytest.mark.pandas +def test_v2_set_chunksize(): + df = pd.DataFrame({'A': np.arange(1000)}) + table = pa.table(df) + + buf = io.BytesIO() + write_feather(table, buf, chunksize=250, version=2) + + result = buf.getvalue() + + ipc_file = pa.ipc.open_file(pa.BufferReader(result)) + assert ipc_file.num_record_batches == 4 + assert len(ipc_file.get_batch(0)) == 250 + + +@pytest.mark.pandas +@pytest.mark.lz4 +@pytest.mark.snappy +@pytest.mark.zstd +def test_v2_compression_options(): + df = pd.DataFrame({'A': np.arange(1000)}) + + cases = [ + # compression, compression_level + ('uncompressed', None), + ('lz4', None), + ('zstd', 1), + ('zstd', 10) + ] + + for compression, compression_level in cases: + _check_pandas_roundtrip(df, compression=compression, + compression_level=compression_level) + + buf = io.BytesIO() + + # LZ4 doesn't support compression_level + with pytest.raises(pa.ArrowInvalid, + match="doesn't support setting a compression level"): + write_feather(df, buf, compression='lz4', compression_level=10) + + # Trying to compress with V1 + with pytest.raises( + ValueError, + match="Feather V1 files do not support compression option"): + write_feather(df, buf, compression='lz4', version=1) + + # Trying to set chunksize with V1 + with pytest.raises( + ValueError, + match="Feather V1 files do not support chunksize option"): + write_feather(df, buf, chunksize=4096, version=1) + + # Unsupported compressor + with pytest.raises(ValueError, + match='compression="snappy" not supported'): + write_feather(df, buf, compression='snappy') + + +def test_v2_lz4_default_compression(): + # ARROW-8750: Make sure that the compression=None option selects lz4 if + # it's available + if not pa.Codec.is_available('lz4_frame'): + pytest.skip("LZ4 compression support is not built in C++") + + # some highly compressible data + t = pa.table([np.repeat(0, 100000)], names=['f0']) + + buf = io.BytesIO() + write_feather(t, buf) + default_result = buf.getvalue() + + buf = io.BytesIO() + write_feather(t, buf, compression='uncompressed') + uncompressed_result = buf.getvalue() + + assert len(default_result) < len(uncompressed_result) + + +def test_v1_unsupported_types(): + table = pa.table([pa.array([[1, 2, 3], [], None])], names=['f0']) + + buf = io.BytesIO() + with pytest.raises(TypeError, + match=("Unsupported Feather V1 type: " + "list<item: int64>. " + "Use V2 format to serialize all Arrow types.")): + write_feather(table, buf, version=1) + + +@pytest.mark.slow +@pytest.mark.pandas +def test_large_dataframe(version): + df = pd.DataFrame({'A': np.arange(400000000)}) + _check_pandas_roundtrip(df, version=version) + + +@pytest.mark.large_memory +@pytest.mark.pandas +def test_chunked_binary_error_message(): + # ARROW-3058: As Feather does not yet support chunked columns, we at least + # make sure it's clear to the user what is going on + + # 2^31 + 1 bytes + values = [b'x'] + [ + b'x' * (1 << 20) + ] * 2 * (1 << 10) + df = pd.DataFrame({'byte_col': values}) + + # Works fine with version 2 + buf = io.BytesIO() + write_feather(df, buf, version=2) + result = read_feather(pa.BufferReader(buf.getvalue())) + assert_frame_equal(result, df) + + with pytest.raises(ValueError, match="'byte_col' exceeds 2GB maximum " + "capacity of a Feather binary column. This restriction " + "may be lifted in the future"): + write_feather(df, io.BytesIO(), version=1) + + +def test_feather_without_pandas(tempdir, version): + # ARROW-8345 + table = pa.table([pa.array([1, 2, 3])], names=['f0']) + path = str(tempdir / "data.feather") + _check_arrow_roundtrip(table, path) + + +@pytest.mark.pandas +def test_read_column_selection(version): + # ARROW-8641 + df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=['a', 'b', 'c']) + + # select columns as string names or integer indices + _check_pandas_roundtrip( + df, columns=['a', 'c'], expected=df[['a', 'c']], version=version) + _check_pandas_roundtrip( + df, columns=[0, 2], expected=df[['a', 'c']], version=version) + + # different order is followed + _check_pandas_roundtrip( + df, columns=['b', 'a'], expected=df[['b', 'a']], version=version) + _check_pandas_roundtrip( + df, columns=[1, 0], expected=df[['b', 'a']], version=version) + + +def test_read_column_duplicated_selection(tempdir, version): + # duplicated columns in the column selection + table = pa.table([[1, 2, 3], [4, 5, 6], [7, 8, 9]], names=['a', 'b', 'c']) + path = str(tempdir / "data.feather") + write_feather(table, path, version=version) + + expected = pa.table([[1, 2, 3], [4, 5, 6], [1, 2, 3]], + names=['a', 'b', 'a']) + for col_selection in [['a', 'b', 'a'], [0, 1, 0]]: + result = read_table(path, columns=col_selection) + assert result.equals(expected) + + +def test_read_column_duplicated_in_file(tempdir): + # duplicated columns in feather file (only works for feather v2) + table = pa.table([[1, 2, 3], [4, 5, 6], [7, 8, 9]], names=['a', 'b', 'a']) + path = str(tempdir / "data.feather") + write_feather(table, path, version=2) + + # no selection works fine + result = read_table(path) + assert result.equals(table) + + # selection with indices works + result = read_table(path, columns=[0, 2]) + assert result.column_names == ['a', 'a'] + + # selection with column names errors + with pytest.raises(ValueError): + read_table(path, columns=['a', 'b']) + + +def test_nested_types(compression): + # https://issues.apache.org/jira/browse/ARROW-8860 + table = pa.table({'col': pa.StructArray.from_arrays( + [[0, 1, 2], [1, 2, 3]], names=["f1", "f2"])}) + _check_arrow_roundtrip(table, compression=compression) + + table = pa.table({'col': pa.array([[1, 2], [3, 4]])}) + _check_arrow_roundtrip(table, compression=compression) + + table = pa.table({'col': pa.array([[[1, 2], [3, 4]], [[5, 6], None]])}) + _check_arrow_roundtrip(table, compression=compression) + + +@h.given(past.all_tables, st.sampled_from(["uncompressed", "lz4", "zstd"])) +def test_roundtrip(table, compression): + _check_arrow_roundtrip(table, compression=compression) + + +@pytest.mark.lz4 +def test_feather_v017_experimental_compression_backward_compatibility(datadir): + # ARROW-11163 - ensure newer pyarrow versions can read the old feather + # files from version 0.17.0 with experimental compression support (before + # it was officially added to IPC format in 1.0.0) + + # file generated with: + # table = pa.table({'a': range(5)}) + # from pyarrow import feather + # feather.write_feather( + # table, "v0.17.0.version=2-compression=lz4.feather", + # compression="lz4", version=2) + expected = pa.table({'a': range(5)}) + result = read_table(datadir / "v0.17.0.version=2-compression=lz4.feather") + assert result.equals(expected) diff --git a/src/arrow/python/pyarrow/tests/test_filesystem.py b/src/arrow/python/pyarrow/tests/test_filesystem.py new file mode 100644 index 000000000..3d54f33e1 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_filesystem.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +import pyarrow as pa +from pyarrow import filesystem + +import pytest + + +def test_filesystem_deprecated(): + with pytest.warns(FutureWarning): + filesystem.LocalFileSystem() + + with pytest.warns(FutureWarning): + filesystem.LocalFileSystem.get_instance() + + +@pytest.mark.skipif(sys.version_info < (3, 7), + reason="getattr needs Python 3.7") +def test_filesystem_deprecated_toplevel(): + + with pytest.warns(FutureWarning): + pa.localfs + + with pytest.warns(FutureWarning): + pa.FileSystem + + with pytest.warns(FutureWarning): + pa.LocalFileSystem + + with pytest.warns(FutureWarning): + pa.HadoopFileSystem + + +def test_resolve_uri(): + uri = "file:///home/user/myfile.parquet" + fs, path = filesystem.resolve_filesystem_and_path(uri) + assert isinstance(fs, filesystem.LocalFileSystem) + assert path == "/home/user/myfile.parquet" + + +def test_resolve_local_path(): + for uri in ['/home/user/myfile.parquet', + 'myfile.parquet', + 'my # file ? parquet', + 'C:/Windows/myfile.parquet', + r'C:\\Windows\\myfile.parquet', + ]: + fs, path = filesystem.resolve_filesystem_and_path(uri) + assert isinstance(fs, filesystem.LocalFileSystem) + assert path == uri diff --git a/src/arrow/python/pyarrow/tests/test_flight.py b/src/arrow/python/pyarrow/tests/test_flight.py new file mode 100644 index 000000000..5c40467a5 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_flight.py @@ -0,0 +1,2047 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import ast +import base64 +import itertools +import os +import signal +import struct +import tempfile +import threading +import time +import traceback +import json + +import numpy as np +import pytest +import pyarrow as pa + +from pyarrow.lib import tobytes +from pyarrow.util import pathlib, find_free_port +from pyarrow.tests import util + +try: + from pyarrow import flight + from pyarrow.flight import ( + FlightClient, FlightServerBase, + ServerAuthHandler, ClientAuthHandler, + ServerMiddleware, ServerMiddlewareFactory, + ClientMiddleware, ClientMiddlewareFactory, + ) +except ImportError: + flight = None + FlightClient, FlightServerBase = object, object + ServerAuthHandler, ClientAuthHandler = object, object + ServerMiddleware, ServerMiddlewareFactory = object, object + ClientMiddleware, ClientMiddlewareFactory = object, object + +# Marks all of the tests in this module +# Ignore these with pytest ... -m 'not flight' +pytestmark = pytest.mark.flight + + +def test_import(): + # So we see the ImportError somewhere + import pyarrow.flight # noqa + + +def resource_root(): + """Get the path to the test resources directory.""" + if not os.environ.get("ARROW_TEST_DATA"): + raise RuntimeError("Test resources not found; set " + "ARROW_TEST_DATA to <repo root>/testing/data") + return pathlib.Path(os.environ["ARROW_TEST_DATA"]) / "flight" + + +def read_flight_resource(path): + """Get the contents of a test resource file.""" + root = resource_root() + if not root: + return None + try: + with (root / path).open("rb") as f: + return f.read() + except FileNotFoundError: + raise RuntimeError( + "Test resource {} not found; did you initialize the " + "test resource submodule?\n{}".format(root / path, + traceback.format_exc())) + + +def example_tls_certs(): + """Get the paths to test TLS certificates.""" + return { + "root_cert": read_flight_resource("root-ca.pem"), + "certificates": [ + flight.CertKeyPair( + cert=read_flight_resource("cert0.pem"), + key=read_flight_resource("cert0.key"), + ), + flight.CertKeyPair( + cert=read_flight_resource("cert1.pem"), + key=read_flight_resource("cert1.key"), + ), + ] + } + + +def simple_ints_table(): + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + return pa.Table.from_arrays(data, names=['some_ints']) + + +def simple_dicts_table(): + dict_values = pa.array(["foo", "baz", "quux"], type=pa.utf8()) + data = [ + pa.chunked_array([ + pa.DictionaryArray.from_arrays([1, 0, None], dict_values), + pa.DictionaryArray.from_arrays([2, 1], dict_values) + ]) + ] + return pa.Table.from_arrays(data, names=['some_dicts']) + + +class ConstantFlightServer(FlightServerBase): + """A Flight server that always returns the same data. + + See ARROW-4796: this server implementation will segfault if Flight + does not properly hold a reference to the Table object. + """ + + CRITERIA = b"the expected criteria" + + def __init__(self, location=None, options=None, **kwargs): + super().__init__(location, **kwargs) + # Ticket -> Table + self.table_factories = { + b'ints': simple_ints_table, + b'dicts': simple_dicts_table, + } + self.options = options + + def list_flights(self, context, criteria): + if criteria == self.CRITERIA: + yield flight.FlightInfo( + pa.schema([]), + flight.FlightDescriptor.for_path('/foo'), + [], + -1, -1 + ) + + def do_get(self, context, ticket): + # Return a fresh table, so that Flight is the only one keeping a + # reference. + table = self.table_factories[ticket.ticket]() + return flight.RecordBatchStream(table, options=self.options) + + +class MetadataFlightServer(FlightServerBase): + """A Flight server that numbers incoming/outgoing data.""" + + def __init__(self, options=None, **kwargs): + super().__init__(**kwargs) + self.options = options + + def do_get(self, context, ticket): + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + return flight.GeneratorStream( + table.schema, + self.number_batches(table), + options=self.options) + + def do_put(self, context, descriptor, reader, writer): + counter = 0 + expected_data = [-10, -5, 0, 5, 10] + while True: + try: + batch, buf = reader.read_chunk() + assert batch.equals(pa.RecordBatch.from_arrays( + [pa.array([expected_data[counter]])], + ['a'] + )) + assert buf is not None + client_counter, = struct.unpack('<i', buf.to_pybytes()) + assert counter == client_counter + writer.write(struct.pack('<i', counter)) + counter += 1 + except StopIteration: + return + + @staticmethod + def number_batches(table): + for idx, batch in enumerate(table.to_batches()): + buf = struct.pack('<i', idx) + yield batch, buf + + +class EchoFlightServer(FlightServerBase): + """A Flight server that returns the last data uploaded.""" + + def __init__(self, location=None, expected_schema=None, **kwargs): + super().__init__(location, **kwargs) + self.last_message = None + self.expected_schema = expected_schema + + def do_get(self, context, ticket): + return flight.RecordBatchStream(self.last_message) + + def do_put(self, context, descriptor, reader, writer): + if self.expected_schema: + assert self.expected_schema == reader.schema + self.last_message = reader.read_all() + + def do_exchange(self, context, descriptor, reader, writer): + for chunk in reader: + pass + + +class EchoStreamFlightServer(EchoFlightServer): + """An echo server that streams individual record batches.""" + + def do_get(self, context, ticket): + return flight.GeneratorStream( + self.last_message.schema, + self.last_message.to_batches(max_chunksize=1024)) + + def list_actions(self, context): + return [] + + def do_action(self, context, action): + if action.type == "who-am-i": + return [context.peer_identity(), context.peer().encode("utf-8")] + raise NotImplementedError + + +class GetInfoFlightServer(FlightServerBase): + """A Flight server that tests GetFlightInfo.""" + + def get_flight_info(self, context, descriptor): + return flight.FlightInfo( + pa.schema([('a', pa.int32())]), + descriptor, + [ + flight.FlightEndpoint(b'', ['grpc://test']), + flight.FlightEndpoint( + b'', + [flight.Location.for_grpc_tcp('localhost', 5005)], + ), + ], + -1, + -1, + ) + + def get_schema(self, context, descriptor): + info = self.get_flight_info(context, descriptor) + return flight.SchemaResult(info.schema) + + +class ListActionsFlightServer(FlightServerBase): + """A Flight server that tests ListActions.""" + + @classmethod + def expected_actions(cls): + return [ + ("action-1", "description"), + ("action-2", ""), + flight.ActionType("action-3", "more detail"), + ] + + def list_actions(self, context): + yield from self.expected_actions() + + +class ListActionsErrorFlightServer(FlightServerBase): + """A Flight server that tests ListActions.""" + + def list_actions(self, context): + yield ("action-1", "") + yield "foo" + + +class CheckTicketFlightServer(FlightServerBase): + """A Flight server that compares the given ticket to an expected value.""" + + def __init__(self, expected_ticket, location=None, **kwargs): + super().__init__(location, **kwargs) + self.expected_ticket = expected_ticket + + def do_get(self, context, ticket): + assert self.expected_ticket == ticket.ticket + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + table = pa.Table.from_arrays(data1, names=['a']) + return flight.RecordBatchStream(table) + + def do_put(self, context, descriptor, reader): + self.last_message = reader.read_all() + + +class InvalidStreamFlightServer(FlightServerBase): + """A Flight server that tries to return messages with differing schemas.""" + + schema = pa.schema([('a', pa.int32())]) + + def do_get(self, context, ticket): + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0], type=pa.float64())] + assert data1.type != data2.type + table1 = pa.Table.from_arrays(data1, names=['a']) + table2 = pa.Table.from_arrays(data2, names=['a']) + assert table1.schema == self.schema + + return flight.GeneratorStream(self.schema, [table1, table2]) + + +class NeverSendsDataFlightServer(FlightServerBase): + """A Flight server that never actually yields data.""" + + schema = pa.schema([('a', pa.int32())]) + + def do_get(self, context, ticket): + if ticket.ticket == b'yield_data': + # Check that the server handler will ignore empty tables + # up to a certain extent + data = [ + self.schema.empty_table(), + self.schema.empty_table(), + pa.RecordBatch.from_arrays([range(5)], schema=self.schema), + ] + return flight.GeneratorStream(self.schema, data) + return flight.GeneratorStream( + self.schema, itertools.repeat(self.schema.empty_table())) + + +class SlowFlightServer(FlightServerBase): + """A Flight server that delays its responses to test timeouts.""" + + def do_get(self, context, ticket): + return flight.GeneratorStream(pa.schema([('a', pa.int32())]), + self.slow_stream()) + + def do_action(self, context, action): + time.sleep(0.5) + return [] + + @staticmethod + def slow_stream(): + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + yield pa.Table.from_arrays(data1, names=['a']) + # The second message should never get sent; the client should + # cancel before we send this + time.sleep(10) + yield pa.Table.from_arrays(data1, names=['a']) + + +class ErrorFlightServer(FlightServerBase): + """A Flight server that uses all the Flight-specific errors.""" + + def do_action(self, context, action): + if action.type == "internal": + raise flight.FlightInternalError("foo") + elif action.type == "timedout": + raise flight.FlightTimedOutError("foo") + elif action.type == "cancel": + raise flight.FlightCancelledError("foo") + elif action.type == "unauthenticated": + raise flight.FlightUnauthenticatedError("foo") + elif action.type == "unauthorized": + raise flight.FlightUnauthorizedError("foo") + elif action.type == "protobuf": + err_msg = b'this is an error message' + raise flight.FlightUnauthorizedError("foo", err_msg) + raise NotImplementedError + + def list_flights(self, context, criteria): + yield flight.FlightInfo( + pa.schema([]), + flight.FlightDescriptor.for_path('/foo'), + [], + -1, -1 + ) + raise flight.FlightInternalError("foo") + + def do_put(self, context, descriptor, reader, writer): + if descriptor.command == b"internal": + raise flight.FlightInternalError("foo") + elif descriptor.command == b"timedout": + raise flight.FlightTimedOutError("foo") + elif descriptor.command == b"cancel": + raise flight.FlightCancelledError("foo") + elif descriptor.command == b"unauthenticated": + raise flight.FlightUnauthenticatedError("foo") + elif descriptor.command == b"unauthorized": + raise flight.FlightUnauthorizedError("foo") + elif descriptor.command == b"protobuf": + err_msg = b'this is an error message' + raise flight.FlightUnauthorizedError("foo", err_msg) + + +class ExchangeFlightServer(FlightServerBase): + """A server for testing DoExchange.""" + + def __init__(self, options=None, **kwargs): + super().__init__(**kwargs) + self.options = options + + def do_exchange(self, context, descriptor, reader, writer): + if descriptor.descriptor_type != flight.DescriptorType.CMD: + raise pa.ArrowInvalid("Must provide a command descriptor") + elif descriptor.command == b"echo": + return self.exchange_echo(context, reader, writer) + elif descriptor.command == b"get": + return self.exchange_do_get(context, reader, writer) + elif descriptor.command == b"put": + return self.exchange_do_put(context, reader, writer) + elif descriptor.command == b"transform": + return self.exchange_transform(context, reader, writer) + else: + raise pa.ArrowInvalid( + "Unknown command: {}".format(descriptor.command)) + + def exchange_do_get(self, context, reader, writer): + """Emulate DoGet with DoExchange.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + writer.begin(data.schema) + writer.write_table(data) + + def exchange_do_put(self, context, reader, writer): + """Emulate DoPut with DoExchange.""" + num_batches = 0 + for chunk in reader: + if not chunk.data: + raise pa.ArrowInvalid("All chunks must have data.") + num_batches += 1 + writer.write_metadata(str(num_batches).encode("utf-8")) + + def exchange_echo(self, context, reader, writer): + """Run a simple echo server.""" + started = False + for chunk in reader: + if not started and chunk.data: + writer.begin(chunk.data.schema, options=self.options) + started = True + if chunk.app_metadata and chunk.data: + writer.write_with_metadata(chunk.data, chunk.app_metadata) + elif chunk.app_metadata: + writer.write_metadata(chunk.app_metadata) + elif chunk.data: + writer.write_batch(chunk.data) + else: + assert False, "Should not happen" + + def exchange_transform(self, context, reader, writer): + """Sum rows in an uploaded table.""" + for field in reader.schema: + if not pa.types.is_integer(field.type): + raise pa.ArrowInvalid("Invalid field: " + repr(field)) + table = reader.read_all() + sums = [0] * table.num_rows + for column in table: + for row, value in enumerate(column): + sums[row] += value.as_py() + result = pa.Table.from_arrays([pa.array(sums)], names=["sum"]) + writer.begin(result.schema) + writer.write_table(result) + + +class HttpBasicServerAuthHandler(ServerAuthHandler): + """An example implementation of HTTP basic authentication.""" + + def __init__(self, creds): + super().__init__() + self.creds = creds + + def authenticate(self, outgoing, incoming): + buf = incoming.read() + auth = flight.BasicAuth.deserialize(buf) + if auth.username not in self.creds: + raise flight.FlightUnauthenticatedError("unknown user") + if self.creds[auth.username] != auth.password: + raise flight.FlightUnauthenticatedError("wrong password") + outgoing.write(tobytes(auth.username)) + + def is_valid(self, token): + if not token: + raise flight.FlightUnauthenticatedError("token not provided") + if token not in self.creds: + raise flight.FlightUnauthenticatedError("unknown user") + return token + + +class HttpBasicClientAuthHandler(ClientAuthHandler): + """An example implementation of HTTP basic authentication.""" + + def __init__(self, username, password): + super().__init__() + self.basic_auth = flight.BasicAuth(username, password) + self.token = None + + def authenticate(self, outgoing, incoming): + auth = self.basic_auth.serialize() + outgoing.write(auth) + self.token = incoming.read() + + def get_token(self): + return self.token + + +class TokenServerAuthHandler(ServerAuthHandler): + """An example implementation of authentication via handshake.""" + + def __init__(self, creds): + super().__init__() + self.creds = creds + + def authenticate(self, outgoing, incoming): + username = incoming.read() + password = incoming.read() + if username in self.creds and self.creds[username] == password: + outgoing.write(base64.b64encode(b'secret:' + username)) + else: + raise flight.FlightUnauthenticatedError( + "invalid username/password") + + def is_valid(self, token): + token = base64.b64decode(token) + if not token.startswith(b'secret:'): + raise flight.FlightUnauthenticatedError("invalid token") + return token[7:] + + +class TokenClientAuthHandler(ClientAuthHandler): + """An example implementation of authentication via handshake.""" + + def __init__(self, username, password): + super().__init__() + self.username = username + self.password = password + self.token = b'' + + def authenticate(self, outgoing, incoming): + outgoing.write(self.username) + outgoing.write(self.password) + self.token = incoming.read() + + def get_token(self): + return self.token + + +class NoopAuthHandler(ServerAuthHandler): + """A no-op auth handler.""" + + def authenticate(self, outgoing, incoming): + """Do nothing.""" + + def is_valid(self, token): + """ + Returning an empty string. + Returning None causes Type error. + """ + return "" + + +def case_insensitive_header_lookup(headers, lookup_key): + """Lookup the value of given key in the given headers. + The key lookup is case insensitive. + """ + for key in headers: + if key.lower() == lookup_key.lower(): + return headers.get(key) + + +class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory): + """ClientMiddlewareFactory that creates ClientAuthHeaderMiddleware.""" + + def __init__(self): + self.call_credential = [] + + def start_call(self, info): + return ClientHeaderAuthMiddleware(self) + + def set_call_credential(self, call_credential): + self.call_credential = call_credential + + +class ClientHeaderAuthMiddleware(ClientMiddleware): + """ + ClientMiddleware that extracts the authorization header + from the server. + + This is an example of a ClientMiddleware that can extract + the bearer token authorization header from a HTTP header + authentication enabled server. + + Parameters + ---------- + factory : ClientHeaderAuthMiddlewareFactory + This factory is used to set call credentials if an + authorization header is found in the headers from the server. + """ + + def __init__(self, factory): + self.factory = factory + + def received_headers(self, headers): + auth_header = case_insensitive_header_lookup(headers, 'Authorization') + self.factory.set_call_credential([ + b'authorization', + auth_header[0].encode("utf-8")]) + + +class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory): + """Validates incoming username and password.""" + + def start_call(self, info, headers): + auth_header = case_insensitive_header_lookup( + headers, + 'Authorization' + ) + values = auth_header[0].split(' ') + token = '' + error_message = 'Invalid credentials' + + if values[0] == 'Basic': + decoded = base64.b64decode(values[1]) + pair = decoded.decode("utf-8").split(':') + if not (pair[0] == 'test' and pair[1] == 'password'): + raise flight.FlightUnauthenticatedError(error_message) + token = 'token1234' + elif values[0] == 'Bearer': + token = values[1] + if not token == 'token1234': + raise flight.FlightUnauthenticatedError(error_message) + else: + raise flight.FlightUnauthenticatedError(error_message) + + return HeaderAuthServerMiddleware(token) + + +class HeaderAuthServerMiddleware(ServerMiddleware): + """A ServerMiddleware that transports incoming username and passowrd.""" + + def __init__(self, token): + self.token = token + + def sending_headers(self): + return {'authorization': 'Bearer ' + self.token} + + +class HeaderAuthFlightServer(FlightServerBase): + """A Flight server that tests with basic token authentication. """ + + def do_action(self, context, action): + middleware = context.get_middleware("auth") + if middleware: + auth_header = case_insensitive_header_lookup( + middleware.sending_headers(), 'Authorization') + values = auth_header.split(' ') + return [values[1].encode("utf-8")] + raise flight.FlightUnauthenticatedError( + 'No token auth middleware found.') + + +class ArbitraryHeadersServerMiddlewareFactory(ServerMiddlewareFactory): + """A ServerMiddlewareFactory that transports arbitrary headers.""" + + def start_call(self, info, headers): + return ArbitraryHeadersServerMiddleware(headers) + + +class ArbitraryHeadersServerMiddleware(ServerMiddleware): + """A ServerMiddleware that transports arbitrary headers.""" + + def __init__(self, incoming): + self.incoming = incoming + + def sending_headers(self): + return self.incoming + + +class ArbitraryHeadersFlightServer(FlightServerBase): + """A Flight server that tests multiple arbitrary headers.""" + + def do_action(self, context, action): + middleware = context.get_middleware("arbitrary-headers") + if middleware: + headers = middleware.sending_headers() + header_1 = case_insensitive_header_lookup( + headers, + 'test-header-1' + ) + header_2 = case_insensitive_header_lookup( + headers, + 'test-header-2' + ) + value1 = header_1[0].encode("utf-8") + value2 = header_2[0].encode("utf-8") + return [value1, value2] + raise flight.FlightServerError("No headers middleware found") + + +class HeaderServerMiddleware(ServerMiddleware): + """Expose a per-call value to the RPC method body.""" + + def __init__(self, special_value): + self.special_value = special_value + + +class HeaderServerMiddlewareFactory(ServerMiddlewareFactory): + """Expose a per-call hard-coded value to the RPC method body.""" + + def start_call(self, info, headers): + return HeaderServerMiddleware("right value") + + +class HeaderFlightServer(FlightServerBase): + """Echo back the per-call hard-coded value.""" + + def do_action(self, context, action): + middleware = context.get_middleware("test") + if middleware: + return [middleware.special_value.encode()] + return [b""] + + +class MultiHeaderFlightServer(FlightServerBase): + """Test sending/receiving multiple (binary-valued) headers.""" + + def do_action(self, context, action): + middleware = context.get_middleware("test") + headers = repr(middleware.client_headers).encode("utf-8") + return [headers] + + +class SelectiveAuthServerMiddlewareFactory(ServerMiddlewareFactory): + """Deny access to certain methods based on a header.""" + + def start_call(self, info, headers): + if info.method == flight.FlightMethod.LIST_ACTIONS: + # No auth needed + return + + token = headers.get("x-auth-token") + if not token: + raise flight.FlightUnauthenticatedError("No token") + + token = token[0] + if token != "password": + raise flight.FlightUnauthenticatedError("Invalid token") + + return HeaderServerMiddleware(token) + + +class SelectiveAuthClientMiddlewareFactory(ClientMiddlewareFactory): + def start_call(self, info): + return SelectiveAuthClientMiddleware() + + +class SelectiveAuthClientMiddleware(ClientMiddleware): + def sending_headers(self): + return { + "x-auth-token": "password", + } + + +class RecordingServerMiddlewareFactory(ServerMiddlewareFactory): + """Record what methods were called.""" + + def __init__(self): + super().__init__() + self.methods = [] + + def start_call(self, info, headers): + self.methods.append(info.method) + return None + + +class RecordingClientMiddlewareFactory(ClientMiddlewareFactory): + """Record what methods were called.""" + + def __init__(self): + super().__init__() + self.methods = [] + + def start_call(self, info): + self.methods.append(info.method) + return None + + +class MultiHeaderClientMiddlewareFactory(ClientMiddlewareFactory): + """Test sending/receiving multiple (binary-valued) headers.""" + + def __init__(self): + # Read in test_middleware_multi_header below. + # The middleware instance will update this value. + self.last_headers = {} + + def start_call(self, info): + return MultiHeaderClientMiddleware(self) + + +class MultiHeaderClientMiddleware(ClientMiddleware): + """Test sending/receiving multiple (binary-valued) headers.""" + + EXPECTED = { + "x-text": ["foo", "bar"], + "x-binary-bin": [b"\x00", b"\x01"], + } + + def __init__(self, factory): + self.factory = factory + + def sending_headers(self): + return self.EXPECTED + + def received_headers(self, headers): + # Let the test code know what the last set of headers we + # received were. + self.factory.last_headers = headers + + +class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory): + """Test sending/receiving multiple (binary-valued) headers.""" + + def start_call(self, info, headers): + return MultiHeaderServerMiddleware(headers) + + +class MultiHeaderServerMiddleware(ServerMiddleware): + """Test sending/receiving multiple (binary-valued) headers.""" + + def __init__(self, client_headers): + self.client_headers = client_headers + + def sending_headers(self): + return MultiHeaderClientMiddleware.EXPECTED + + +class LargeMetadataFlightServer(FlightServerBase): + """Regression test for ARROW-13253.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._metadata = b' ' * (2 ** 31 + 1) + + def do_get(self, context, ticket): + schema = pa.schema([('a', pa.int64())]) + return flight.GeneratorStream(schema, [ + (pa.record_batch([[1]], schema=schema), self._metadata), + ]) + + def do_exchange(self, context, descriptor, reader, writer): + writer.write_metadata(self._metadata) + + +def test_flight_server_location_argument(): + locations = [ + None, + 'grpc://localhost:0', + ('localhost', find_free_port()), + ] + for location in locations: + with FlightServerBase(location) as server: + assert isinstance(server, FlightServerBase) + + +def test_server_exit_reraises_exception(): + with pytest.raises(ValueError): + with FlightServerBase(): + raise ValueError() + + +@pytest.mark.slow +def test_client_wait_for_available(): + location = ('localhost', find_free_port()) + server = None + + def serve(): + global server + time.sleep(0.5) + server = FlightServerBase(location) + server.serve() + + client = FlightClient(location) + thread = threading.Thread(target=serve, daemon=True) + thread.start() + + started = time.time() + client.wait_for_available(timeout=5) + elapsed = time.time() - started + assert elapsed >= 0.5 + + +def test_flight_list_flights(): + """Try a simple list_flights call.""" + with ConstantFlightServer() as server: + client = flight.connect(('localhost', server.port)) + assert list(client.list_flights()) == [] + flights = client.list_flights(ConstantFlightServer.CRITERIA) + assert len(list(flights)) == 1 + + +def test_flight_do_get_ints(): + """Try a simple do_get call.""" + table = simple_ints_table() + + with ConstantFlightServer() as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + options = pa.ipc.IpcWriteOptions( + metadata_version=pa.ipc.MetadataVersion.V4) + with ConstantFlightServer(options=options) as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + # Also test via RecordBatchReader interface + data = client.do_get(flight.Ticket(b'ints')).to_reader().read_all() + assert data.equals(table) + + with pytest.raises(flight.FlightServerError, + match="expected IpcWriteOptions, got <class 'int'>"): + with ConstantFlightServer(options=42) as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'ints')).read_all() + + +@pytest.mark.pandas +def test_do_get_ints_pandas(): + """Try a simple do_get call.""" + table = simple_ints_table() + + with ConstantFlightServer() as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'ints')).read_pandas() + assert list(data['some_ints']) == table.column(0).to_pylist() + + +def test_flight_do_get_dicts(): + table = simple_dicts_table() + + with ConstantFlightServer() as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'dicts')).read_all() + assert data.equals(table) + + +def test_flight_do_get_ticket(): + """Make sure Tickets get passed to the server.""" + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + table = pa.Table.from_arrays(data1, names=['a']) + with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'the-ticket')).read_all() + assert data.equals(table) + + +def test_flight_get_info(): + """Make sure FlightEndpoint accepts string and object URIs.""" + with GetInfoFlightServer() as server: + client = FlightClient(('localhost', server.port)) + info = client.get_flight_info(flight.FlightDescriptor.for_command(b'')) + assert info.total_records == -1 + assert info.total_bytes == -1 + assert info.schema == pa.schema([('a', pa.int32())]) + assert len(info.endpoints) == 2 + assert len(info.endpoints[0].locations) == 1 + assert info.endpoints[0].locations[0] == flight.Location('grpc://test') + assert info.endpoints[1].locations[0] == \ + flight.Location.for_grpc_tcp('localhost', 5005) + + +def test_flight_get_schema(): + """Make sure GetSchema returns correct schema.""" + with GetInfoFlightServer() as server: + client = FlightClient(('localhost', server.port)) + info = client.get_schema(flight.FlightDescriptor.for_command(b'')) + assert info.schema == pa.schema([('a', pa.int32())]) + + +def test_list_actions(): + """Make sure the return type of ListActions is validated.""" + # ARROW-6392 + with ListActionsErrorFlightServer() as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises( + flight.FlightServerError, + match=("Results of list_actions must be " + "ActionType or tuple") + ): + list(client.list_actions()) + + with ListActionsFlightServer() as server: + client = FlightClient(('localhost', server.port)) + assert list(client.list_actions()) == \ + ListActionsFlightServer.expected_actions() + + +class ConvenienceServer(FlightServerBase): + """ + Server for testing various implementation conveniences (auto-boxing, etc.) + """ + + @property + def simple_action_results(self): + return [b'foo', b'bar', b'baz'] + + def do_action(self, context, action): + if action.type == 'simple-action': + return self.simple_action_results + elif action.type == 'echo': + return [action.body] + elif action.type == 'bad-action': + return ['foo'] + elif action.type == 'arrow-exception': + raise pa.ArrowMemoryError() + + +def test_do_action_result_convenience(): + with ConvenienceServer() as server: + client = FlightClient(('localhost', server.port)) + + # do_action as action type without body + results = [x.body for x in client.do_action('simple-action')] + assert results == server.simple_action_results + + # do_action with tuple of type and body + body = b'the-body' + results = [x.body for x in client.do_action(('echo', body))] + assert results == [body] + + +def test_nicer_server_exceptions(): + with ConvenienceServer() as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises(flight.FlightServerError, + match="a bytes-like object is required"): + list(client.do_action('bad-action')) + # While Flight/C++ sends across the original status code, it + # doesn't get mapped to the equivalent code here, since we + # want to be able to distinguish between client- and server- + # side errors. + with pytest.raises(flight.FlightServerError, + match="ArrowMemoryError"): + list(client.do_action('arrow-exception')) + + +def test_get_port(): + """Make sure port() works.""" + server = GetInfoFlightServer("grpc://localhost:0") + try: + assert server.port > 0 + finally: + server.shutdown() + + +@pytest.mark.skipif(os.name == 'nt', + reason="Unix sockets can't be tested on Windows") +def test_flight_domain_socket(): + """Try a simple do_get call over a Unix domain socket.""" + with tempfile.NamedTemporaryFile() as sock: + sock.close() + location = flight.Location.for_grpc_unix(sock.name) + with ConstantFlightServer(location=location): + client = FlightClient(location) + + reader = client.do_get(flight.Ticket(b'ints')) + table = simple_ints_table() + assert reader.schema.equals(table.schema) + data = reader.read_all() + assert data.equals(table) + + reader = client.do_get(flight.Ticket(b'dicts')) + table = simple_dicts_table() + assert reader.schema.equals(table.schema) + data = reader.read_all() + assert data.equals(table) + + +@pytest.mark.slow +def test_flight_large_message(): + """Try sending/receiving a large message via Flight. + + See ARROW-4421: by default, gRPC won't allow us to send messages > + 4MiB in size. + """ + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024 * 1024)) + ], names=['a']) + + with EchoFlightServer(expected_schema=data.schema) as server: + client = FlightClient(('localhost', server.port)) + writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'), + data.schema) + # Write a single giant chunk + writer.write_table(data, 10 * 1024 * 1024) + writer.close() + result = client.do_get(flight.Ticket(b'')).read_all() + assert result.equals(data) + + +def test_flight_generator_stream(): + """Try downloading a flight of RecordBatches in a GeneratorStream.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=['a']) + + with EchoStreamFlightServer() as server: + client = FlightClient(('localhost', server.port)) + writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'), + data.schema) + writer.write_table(data) + writer.close() + result = client.do_get(flight.Ticket(b'')).read_all() + assert result.equals(data) + + +def test_flight_invalid_generator_stream(): + """Try streaming data with mismatched schemas.""" + with InvalidStreamFlightServer() as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises(pa.ArrowException): + client.do_get(flight.Ticket(b'')).read_all() + + +def test_timeout_fires(): + """Make sure timeouts fire on slow requests.""" + # Do this in a separate thread so that if it fails, we don't hang + # the entire test process + with SlowFlightServer() as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("", b"") + options = flight.FlightCallOptions(timeout=0.2) + # gRPC error messages change based on version, so don't look + # for a particular error + with pytest.raises(flight.FlightTimedOutError): + list(client.do_action(action, options=options)) + + +def test_timeout_passes(): + """Make sure timeouts do not fire on fast requests.""" + with ConstantFlightServer() as server: + client = FlightClient(('localhost', server.port)) + options = flight.FlightCallOptions(timeout=5.0) + client.do_get(flight.Ticket(b'ints'), options=options).read_all() + + +basic_auth_handler = HttpBasicServerAuthHandler(creds={ + b"test": b"p4ssw0rd", +}) + +token_auth_handler = TokenServerAuthHandler(creds={ + b"test": b"p4ssw0rd", +}) + + +@pytest.mark.slow +def test_http_basic_unauth(): + """Test that auth fails when not authenticated.""" + with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("who-am-i", b"") + with pytest.raises(flight.FlightUnauthenticatedError, + match=".*unauthenticated.*"): + list(client.do_action(action)) + + +@pytest.mark.skipif(os.name == 'nt', + reason="ARROW-10013: gRPC on Windows corrupts peer()") +def test_http_basic_auth(): + """Test a Python implementation of HTTP basic authentication.""" + with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("who-am-i", b"") + client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd')) + results = client.do_action(action) + identity = next(results) + assert identity.body.to_pybytes() == b'test' + peer_address = next(results) + assert peer_address.body.to_pybytes() != b'' + + +def test_http_basic_auth_invalid_password(): + """Test that auth fails with the wrong password.""" + with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("who-am-i", b"") + with pytest.raises(flight.FlightUnauthenticatedError, + match=".*wrong password.*"): + client.authenticate(HttpBasicClientAuthHandler('test', 'wrong')) + next(client.do_action(action)) + + +def test_token_auth(): + """Test an auth mechanism that uses a handshake.""" + with EchoStreamFlightServer(auth_handler=token_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("who-am-i", b"") + client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd')) + identity = next(client.do_action(action)) + assert identity.body.to_pybytes() == b'test' + + +def test_token_auth_invalid(): + """Test an auth mechanism that uses a handshake.""" + with EchoStreamFlightServer(auth_handler=token_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises(flight.FlightUnauthenticatedError): + client.authenticate(TokenClientAuthHandler('test', 'wrong')) + + +header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory() +no_op_auth_handler = NoopAuthHandler() + + +def test_authenticate_basic_token(): + """Test authenticate_basic_token with bearer token and auth headers.""" + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + token_pair = client.authenticate_basic_token(b'test', b'password') + assert token_pair[0] == b'authorization' + assert token_pair[1] == b'Bearer token1234' + + +def test_authenticate_basic_token_invalid_password(): + """Test authenticate_basic_token with an invalid password.""" + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises(flight.FlightUnauthenticatedError): + client.authenticate_basic_token(b'test', b'badpassword') + + +def test_authenticate_basic_token_and_action(): + """Test authenticate_basic_token and doAction after authentication.""" + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + token_pair = client.authenticate_basic_token(b'test', b'password') + assert token_pair[0] == b'authorization' + assert token_pair[1] == b'Bearer token1234' + options = flight.FlightCallOptions(headers=[token_pair]) + result = list(client.do_action( + action=flight.Action('test-action', b''), options=options)) + assert result[0].body.to_pybytes() == b'token1234' + + +def test_authenticate_basic_token_with_client_middleware(): + """Test authenticate_basic_token with client middleware + to intercept authorization header returned by the + HTTP header auth enabled server. + """ + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client_auth_middleware = ClientHeaderAuthMiddlewareFactory() + client = FlightClient( + ('localhost', server.port), + middleware=[client_auth_middleware] + ) + encoded_credentials = base64.b64encode(b'test:password') + options = flight.FlightCallOptions(headers=[ + (b'authorization', b'Basic ' + encoded_credentials) + ]) + result = list(client.do_action( + action=flight.Action('test-action', b''), options=options)) + assert result[0].body.to_pybytes() == b'token1234' + assert client_auth_middleware.call_credential[0] == b'authorization' + assert client_auth_middleware.call_credential[1] == \ + b'Bearer ' + b'token1234' + result2 = list(client.do_action( + action=flight.Action('test-action', b''), options=options)) + assert result2[0].body.to_pybytes() == b'token1234' + assert client_auth_middleware.call_credential[0] == b'authorization' + assert client_auth_middleware.call_credential[1] == \ + b'Bearer ' + b'token1234' + + +def test_arbitrary_headers_in_flight_call_options(): + """Test passing multiple arbitrary headers to the middleware.""" + with ArbitraryHeadersFlightServer( + auth_handler=no_op_auth_handler, + middleware={ + "auth": HeaderAuthServerMiddlewareFactory(), + "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + token_pair = client.authenticate_basic_token(b'test', b'password') + assert token_pair[0] == b'authorization' + assert token_pair[1] == b'Bearer token1234' + options = flight.FlightCallOptions(headers=[ + token_pair, + (b'test-header-1', b'value1'), + (b'test-header-2', b'value2') + ]) + result = list(client.do_action(flight.Action( + "test-action", b""), options=options)) + assert result[0].body.to_pybytes() == b'value1' + assert result[1].body.to_pybytes() == b'value2' + + +def test_location_invalid(): + """Test constructing invalid URIs.""" + with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): + flight.connect("%") + + with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): + ConstantFlightServer("%") + + +def test_location_unknown_scheme(): + """Test creating locations for unknown schemes.""" + assert flight.Location("s3://foo").uri == b"s3://foo" + assert flight.Location("https://example.com/bar.parquet").uri == \ + b"https://example.com/bar.parquet" + + +@pytest.mark.slow +@pytest.mark.requires_testing_data +def test_tls_fails(): + """Make sure clients cannot connect when cert verification fails.""" + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + # Ensure client doesn't connect when certificate verification + # fails (this is a slow test since gRPC does retry a few times) + client = FlightClient("grpc+tls://localhost:" + str(s.port)) + + # gRPC error messages change based on version, so don't look + # for a particular error + with pytest.raises(flight.FlightUnavailableError): + client.do_get(flight.Ticket(b'ints')).read_all() + + +@pytest.mark.requires_testing_data +def test_tls_do_get(): + """Try a simple do_get call over TLS.""" + table = simple_ints_table() + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + client = FlightClient(('localhost', s.port), + tls_root_certs=certs["root_cert"]) + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + +@pytest.mark.requires_testing_data +def test_tls_disable_server_verification(): + """Try a simple do_get call over TLS with server verification disabled.""" + table = simple_ints_table() + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + try: + client = FlightClient(('localhost', s.port), + disable_server_verification=True) + except NotImplementedError: + pytest.skip('disable_server_verification feature is not available') + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + +@pytest.mark.requires_testing_data +def test_tls_override_hostname(): + """Check that incorrectly overriding the hostname fails.""" + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + client = flight.connect(('localhost', s.port), + tls_root_certs=certs["root_cert"], + override_hostname="fakehostname") + with pytest.raises(flight.FlightUnavailableError): + client.do_get(flight.Ticket(b'ints')) + + +def test_flight_do_get_metadata(): + """Try a simple do_get call with metadata.""" + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + + batches = [] + with MetadataFlightServer() as server: + client = FlightClient(('localhost', server.port)) + reader = client.do_get(flight.Ticket(b'')) + idx = 0 + while True: + try: + batch, metadata = reader.read_chunk() + batches.append(batch) + server_idx, = struct.unpack('<i', metadata.to_pybytes()) + assert idx == server_idx + idx += 1 + except StopIteration: + break + data = pa.Table.from_batches(batches) + assert data.equals(table) + + +def test_flight_do_get_metadata_v4(): + """Try a simple do_get call with V4 metadata version.""" + table = pa.Table.from_arrays( + [pa.array([-10, -5, 0, 5, 10])], names=['a']) + options = pa.ipc.IpcWriteOptions( + metadata_version=pa.ipc.MetadataVersion.V4) + with MetadataFlightServer(options=options) as server: + client = FlightClient(('localhost', server.port)) + reader = client.do_get(flight.Ticket(b'')) + data = reader.read_all() + assert data.equals(table) + + +def test_flight_do_put_metadata(): + """Try a simple do_put call with metadata.""" + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + + with MetadataFlightServer() as server: + client = FlightClient(('localhost', server.port)) + writer, metadata_reader = client.do_put( + flight.FlightDescriptor.for_path(''), + table.schema) + with writer: + for idx, batch in enumerate(table.to_batches(max_chunksize=1)): + metadata = struct.pack('<i', idx) + writer.write_with_metadata(batch, metadata) + buf = metadata_reader.read() + assert buf is not None + server_idx, = struct.unpack('<i', buf.to_pybytes()) + assert idx == server_idx + + +def test_flight_do_put_limit(): + """Try a simple do_put call with a size limit.""" + large_batch = pa.RecordBatch.from_arrays([ + pa.array(np.ones(768, dtype=np.int64())), + ], names=['a']) + + with EchoFlightServer() as server: + client = FlightClient(('localhost', server.port), + write_size_limit_bytes=4096) + writer, metadata_reader = client.do_put( + flight.FlightDescriptor.for_path(''), + large_batch.schema) + with writer: + with pytest.raises(flight.FlightWriteSizeExceededError, + match="exceeded soft limit") as excinfo: + writer.write_batch(large_batch) + assert excinfo.value.limit == 4096 + smaller_batches = [ + large_batch.slice(0, 384), + large_batch.slice(384), + ] + for batch in smaller_batches: + writer.write_batch(batch) + expected = pa.Table.from_batches([large_batch]) + actual = client.do_get(flight.Ticket(b'')).read_all() + assert expected == actual + + +@pytest.mark.slow +def test_cancel_do_get(): + """Test canceling a DoGet operation on the client side.""" + with ConstantFlightServer() as server: + client = FlightClient(('localhost', server.port)) + reader = client.do_get(flight.Ticket(b'ints')) + reader.cancel() + with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"): + reader.read_chunk() + + +@pytest.mark.slow +def test_cancel_do_get_threaded(): + """Test canceling a DoGet operation from another thread.""" + with SlowFlightServer() as server: + client = FlightClient(('localhost', server.port)) + reader = client.do_get(flight.Ticket(b'ints')) + + read_first_message = threading.Event() + stream_canceled = threading.Event() + result_lock = threading.Lock() + raised_proper_exception = threading.Event() + + def block_read(): + reader.read_chunk() + read_first_message.set() + stream_canceled.wait(timeout=5) + try: + reader.read_chunk() + except flight.FlightCancelledError: + with result_lock: + raised_proper_exception.set() + + thread = threading.Thread(target=block_read, daemon=True) + thread.start() + read_first_message.wait(timeout=5) + reader.cancel() + stream_canceled.set() + thread.join(timeout=1) + + with result_lock: + assert raised_proper_exception.is_set() + + +def test_roundtrip_types(): + """Make sure serializable types round-trip.""" + ticket = flight.Ticket("foo") + assert ticket == flight.Ticket.deserialize(ticket.serialize()) + + desc = flight.FlightDescriptor.for_command("test") + assert desc == flight.FlightDescriptor.deserialize(desc.serialize()) + + desc = flight.FlightDescriptor.for_path("a", "b", "test.arrow") + assert desc == flight.FlightDescriptor.deserialize(desc.serialize()) + + info = flight.FlightInfo( + pa.schema([('a', pa.int32())]), + desc, + [ + flight.FlightEndpoint(b'', ['grpc://test']), + flight.FlightEndpoint( + b'', + [flight.Location.for_grpc_tcp('localhost', 5005)], + ), + ], + -1, + -1, + ) + info2 = flight.FlightInfo.deserialize(info.serialize()) + assert info.schema == info2.schema + assert info.descriptor == info2.descriptor + assert info.total_bytes == info2.total_bytes + assert info.total_records == info2.total_records + assert info.endpoints == info2.endpoints + + +def test_roundtrip_errors(): + """Ensure that Flight errors propagate from server to client.""" + with ErrorFlightServer() as server: + client = FlightClient(('localhost', server.port)) + + with pytest.raises(flight.FlightInternalError, match=".*foo.*"): + list(client.do_action(flight.Action("internal", b""))) + with pytest.raises(flight.FlightTimedOutError, match=".*foo.*"): + list(client.do_action(flight.Action("timedout", b""))) + with pytest.raises(flight.FlightCancelledError, match=".*foo.*"): + list(client.do_action(flight.Action("cancel", b""))) + with pytest.raises(flight.FlightUnauthenticatedError, match=".*foo.*"): + list(client.do_action(flight.Action("unauthenticated", b""))) + with pytest.raises(flight.FlightUnauthorizedError, match=".*foo.*"): + list(client.do_action(flight.Action("unauthorized", b""))) + with pytest.raises(flight.FlightInternalError, match=".*foo.*"): + list(client.list_flights()) + + data = [pa.array([-10, -5, 0, 5, 10])] + table = pa.Table.from_arrays(data, names=['a']) + + exceptions = { + 'internal': flight.FlightInternalError, + 'timedout': flight.FlightTimedOutError, + 'cancel': flight.FlightCancelledError, + 'unauthenticated': flight.FlightUnauthenticatedError, + 'unauthorized': flight.FlightUnauthorizedError, + } + + for command, exception in exceptions.items(): + + with pytest.raises(exception, match=".*foo.*"): + writer, reader = client.do_put( + flight.FlightDescriptor.for_command(command), + table.schema) + writer.write_table(table) + writer.close() + + with pytest.raises(exception, match=".*foo.*"): + writer, reader = client.do_put( + flight.FlightDescriptor.for_command(command), + table.schema) + writer.close() + + +def test_do_put_independent_read_write(): + """Ensure that separate threads can read/write on a DoPut.""" + # ARROW-6063: previously this would cause gRPC to abort when the + # writer was closed (due to simultaneous reads), or would hang + # forever. + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + + with MetadataFlightServer() as server: + client = FlightClient(('localhost', server.port)) + writer, metadata_reader = client.do_put( + flight.FlightDescriptor.for_path(''), + table.schema) + + count = [0] + + def _reader_thread(): + while metadata_reader.read() is not None: + count[0] += 1 + + thread = threading.Thread(target=_reader_thread) + thread.start() + + batches = table.to_batches(max_chunksize=1) + with writer: + for idx, batch in enumerate(batches): + metadata = struct.pack('<i', idx) + writer.write_with_metadata(batch, metadata) + # Causes the server to stop writing and end the call + writer.done_writing() + # Thus reader thread will break out of loop + thread.join() + # writer.close() won't segfault since reader thread has + # stopped + assert count[0] == len(batches) + + +def test_server_middleware_same_thread(): + """Ensure that server middleware run on the same thread as the RPC.""" + with HeaderFlightServer(middleware={ + "test": HeaderServerMiddlewareFactory(), + }) as server: + client = FlightClient(('localhost', server.port)) + results = list(client.do_action(flight.Action(b"test", b""))) + assert len(results) == 1 + value = results[0].body.to_pybytes() + assert b"right value" == value + + +def test_middleware_reject(): + """Test rejecting an RPC with server middleware.""" + with HeaderFlightServer(middleware={ + "test": SelectiveAuthServerMiddlewareFactory(), + }) as server: + client = FlightClient(('localhost', server.port)) + # The middleware allows this through without auth. + with pytest.raises(pa.ArrowNotImplementedError): + list(client.list_actions()) + + # But not anything else. + with pytest.raises(flight.FlightUnauthenticatedError): + list(client.do_action(flight.Action(b"", b""))) + + client = FlightClient( + ('localhost', server.port), + middleware=[SelectiveAuthClientMiddlewareFactory()] + ) + response = next(client.do_action(flight.Action(b"", b""))) + assert b"password" == response.body.to_pybytes() + + +def test_middleware_mapping(): + """Test that middleware records methods correctly.""" + server_middleware = RecordingServerMiddlewareFactory() + client_middleware = RecordingClientMiddlewareFactory() + with FlightServerBase(middleware={"test": server_middleware}) as server: + client = FlightClient( + ('localhost', server.port), + middleware=[client_middleware] + ) + + descriptor = flight.FlightDescriptor.for_command(b"") + with pytest.raises(NotImplementedError): + list(client.list_flights()) + with pytest.raises(NotImplementedError): + client.get_flight_info(descriptor) + with pytest.raises(NotImplementedError): + client.get_schema(descriptor) + with pytest.raises(NotImplementedError): + client.do_get(flight.Ticket(b"")) + with pytest.raises(NotImplementedError): + writer, _ = client.do_put(descriptor, pa.schema([])) + writer.close() + with pytest.raises(NotImplementedError): + list(client.do_action(flight.Action(b"", b""))) + with pytest.raises(NotImplementedError): + list(client.list_actions()) + with pytest.raises(NotImplementedError): + writer, _ = client.do_exchange(descriptor) + writer.close() + + expected = [ + flight.FlightMethod.LIST_FLIGHTS, + flight.FlightMethod.GET_FLIGHT_INFO, + flight.FlightMethod.GET_SCHEMA, + flight.FlightMethod.DO_GET, + flight.FlightMethod.DO_PUT, + flight.FlightMethod.DO_ACTION, + flight.FlightMethod.LIST_ACTIONS, + flight.FlightMethod.DO_EXCHANGE, + ] + assert server_middleware.methods == expected + assert client_middleware.methods == expected + + +def test_extra_info(): + with ErrorFlightServer() as server: + client = FlightClient(('localhost', server.port)) + try: + list(client.do_action(flight.Action("protobuf", b""))) + assert False + except flight.FlightUnauthorizedError as e: + assert e.extra_info is not None + ei = e.extra_info + assert ei == b'this is an error message' + + +@pytest.mark.requires_testing_data +def test_mtls(): + """Test mutual TLS (mTLS) with gRPC.""" + certs = example_tls_certs() + table = simple_ints_table() + + with ConstantFlightServer( + tls_certificates=[certs["certificates"][0]], + verify_client=True, + root_certificates=certs["root_cert"]) as s: + client = FlightClient( + ('localhost', s.port), + tls_root_certs=certs["root_cert"], + cert_chain=certs["certificates"][0].cert, + private_key=certs["certificates"][0].key) + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + +def test_doexchange_get(): + """Emulate DoGet with DoExchange.""" + expected = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + + with ExchangeFlightServer() as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"get") + writer, reader = client.do_exchange(descriptor) + with writer: + table = reader.read_all() + assert expected == table + + +def test_doexchange_put(): + """Emulate DoPut with DoExchange.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + batches = data.to_batches(max_chunksize=512) + + with ExchangeFlightServer() as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"put") + writer, reader = client.do_exchange(descriptor) + with writer: + writer.begin(data.schema) + for batch in batches: + writer.write_batch(batch) + writer.done_writing() + chunk = reader.read_chunk() + assert chunk.data is None + expected_buf = str(len(batches)).encode("utf-8") + assert chunk.app_metadata == expected_buf + + +def test_doexchange_echo(): + """Try a DoExchange echo server.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + batches = data.to_batches(max_chunksize=512) + + with ExchangeFlightServer() as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"echo") + writer, reader = client.do_exchange(descriptor) + with writer: + # Read/write metadata before starting data. + for i in range(10): + buf = str(i).encode("utf-8") + writer.write_metadata(buf) + chunk = reader.read_chunk() + assert chunk.data is None + assert chunk.app_metadata == buf + + # Now write data without metadata. + writer.begin(data.schema) + for batch in batches: + writer.write_batch(batch) + assert reader.schema == data.schema + chunk = reader.read_chunk() + assert chunk.data == batch + assert chunk.app_metadata is None + + # And write data with metadata. + for i, batch in enumerate(batches): + buf = str(i).encode("utf-8") + writer.write_with_metadata(batch, buf) + chunk = reader.read_chunk() + assert chunk.data == batch + assert chunk.app_metadata == buf + + +def test_doexchange_echo_v4(): + """Try a DoExchange echo server using the V4 metadata version.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + batches = data.to_batches(max_chunksize=512) + + options = pa.ipc.IpcWriteOptions( + metadata_version=pa.ipc.MetadataVersion.V4) + with ExchangeFlightServer(options=options) as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"echo") + writer, reader = client.do_exchange(descriptor) + with writer: + # Now write data without metadata. + writer.begin(data.schema, options=options) + for batch in batches: + writer.write_batch(batch) + assert reader.schema == data.schema + chunk = reader.read_chunk() + assert chunk.data == batch + assert chunk.app_metadata is None + + +def test_doexchange_transform(): + """Transform a table with a service.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 1024)), + pa.array(range(1, 1025)), + pa.array(range(2, 1026)), + ], names=["a", "b", "c"]) + expected = pa.Table.from_arrays([ + pa.array(range(3, 1024 * 3 + 3, 3)), + ], names=["sum"]) + + with ExchangeFlightServer() as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"transform") + writer, reader = client.do_exchange(descriptor) + with writer: + writer.begin(data.schema) + writer.write_table(data) + writer.done_writing() + table = reader.read_all() + assert expected == table + + +def test_middleware_multi_header(): + """Test sending/receiving multiple (binary-valued) headers.""" + with MultiHeaderFlightServer(middleware={ + "test": MultiHeaderServerMiddlewareFactory(), + }) as server: + headers = MultiHeaderClientMiddlewareFactory() + client = FlightClient(('localhost', server.port), middleware=[headers]) + response = next(client.do_action(flight.Action(b"", b""))) + # The server echoes the headers it got back to us. + raw_headers = response.body.to_pybytes().decode("utf-8") + client_headers = ast.literal_eval(raw_headers) + # Don't directly compare; gRPC may add headers like User-Agent. + for header, values in MultiHeaderClientMiddleware.EXPECTED.items(): + assert client_headers.get(header) == values + assert headers.last_headers.get(header) == values + + +@pytest.mark.requires_testing_data +def test_generic_options(): + """Test setting generic client options.""" + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + # Try setting a string argument that will make requests fail + options = [("grpc.ssl_target_name_override", "fakehostname")] + client = flight.connect(('localhost', s.port), + tls_root_certs=certs["root_cert"], + generic_options=options) + with pytest.raises(flight.FlightUnavailableError): + client.do_get(flight.Ticket(b'ints')) + # Try setting an int argument that will make requests fail + options = [("grpc.max_receive_message_length", 32)] + client = flight.connect(('localhost', s.port), + tls_root_certs=certs["root_cert"], + generic_options=options) + with pytest.raises(pa.ArrowInvalid): + client.do_get(flight.Ticket(b'ints')) + + +class CancelFlightServer(FlightServerBase): + """A server for testing StopToken.""" + + def do_get(self, context, ticket): + schema = pa.schema([]) + rb = pa.RecordBatch.from_arrays([], schema=schema) + return flight.GeneratorStream(schema, itertools.repeat(rb)) + + def do_exchange(self, context, descriptor, reader, writer): + schema = pa.schema([]) + rb = pa.RecordBatch.from_arrays([], schema=schema) + writer.begin(schema) + while not context.is_cancelled(): + writer.write_batch(rb) + time.sleep(0.5) + + +def test_interrupt(): + if threading.current_thread().ident != threading.main_thread().ident: + pytest.skip("test only works from main Python thread") + # Skips test if not available + raise_signal = util.get_raise_signal() + + def signal_from_thread(): + time.sleep(0.5) + raise_signal(signal.SIGINT) + + exc_types = (KeyboardInterrupt, pa.ArrowCancelled) + + def test(read_all): + try: + try: + t = threading.Thread(target=signal_from_thread) + with pytest.raises(exc_types) as exc_info: + t.start() + read_all() + finally: + t.join() + except KeyboardInterrupt: + # In case KeyboardInterrupt didn't interrupt read_all + # above, at least prevent it from stopping the test suite + pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all") + e = exc_info.value.__context__ + assert isinstance(e, pa.ArrowCancelled) or \ + isinstance(e, KeyboardInterrupt) + + with CancelFlightServer() as server: + client = FlightClient(("localhost", server.port)) + + reader = client.do_get(flight.Ticket(b"")) + test(reader.read_all) + + descriptor = flight.FlightDescriptor.for_command(b"echo") + writer, reader = client.do_exchange(descriptor) + test(reader.read_all) + + +def test_never_sends_data(): + # Regression test for ARROW-12779 + match = "application server implementation error" + with NeverSendsDataFlightServer() as server: + client = flight.connect(('localhost', server.port)) + with pytest.raises(flight.FlightServerError, match=match): + client.do_get(flight.Ticket(b'')).read_all() + + # Check that the server handler will ignore empty tables + # up to a certain extent + table = client.do_get(flight.Ticket(b'yield_data')).read_all() + assert table.num_rows == 5 + + +@pytest.mark.large_memory +@pytest.mark.slow +def test_large_descriptor(): + # Regression test for ARROW-13253. Placed here with appropriate marks + # since some CI pipelines can't run the C++ equivalent + large_descriptor = flight.FlightDescriptor.for_command( + b' ' * (2 ** 31 + 1)) + with FlightServerBase() as server: + client = flight.connect(('localhost', server.port)) + with pytest.raises(OSError, + match="Failed to serialize Flight descriptor"): + writer, _ = client.do_put(large_descriptor, pa.schema([])) + writer.close() + with pytest.raises(pa.ArrowException, + match="Failed to serialize Flight descriptor"): + client.do_exchange(large_descriptor) + + +@pytest.mark.large_memory +@pytest.mark.slow +def test_large_metadata_client(): + # Regression test for ARROW-13253 + descriptor = flight.FlightDescriptor.for_command(b'') + metadata = b' ' * (2 ** 31 + 1) + with EchoFlightServer() as server: + client = flight.connect(('localhost', server.port)) + with pytest.raises(pa.ArrowCapacityError, + match="app_metadata size overflow"): + writer, _ = client.do_put(descriptor, pa.schema([])) + with writer: + writer.write_metadata(metadata) + writer.close() + with pytest.raises(pa.ArrowCapacityError, + match="app_metadata size overflow"): + writer, reader = client.do_exchange(descriptor) + with writer: + writer.write_metadata(metadata) + + del metadata + with LargeMetadataFlightServer() as server: + client = flight.connect(('localhost', server.port)) + with pytest.raises(flight.FlightServerError, + match="app_metadata size overflow"): + reader = client.do_get(flight.Ticket(b'')) + reader.read_all() + with pytest.raises(pa.ArrowException, + match="app_metadata size overflow"): + writer, reader = client.do_exchange(descriptor) + with writer: + reader.read_all() + + +class ActionNoneFlightServer(EchoFlightServer): + """A server that implements a side effect to a non iterable action.""" + VALUES = [] + + def do_action(self, context, action): + if action.type == "get_value": + return [json.dumps(self.VALUES).encode('utf-8')] + elif action.type == "append": + self.VALUES.append(True) + return None + raise NotImplementedError + + +def test_none_action_side_effect(): + """Ensure that actions are executed even when we don't consume iterator. + + See https://issues.apache.org/jira/browse/ARROW-14255 + """ + + with ActionNoneFlightServer() as server: + client = FlightClient(('localhost', server.port)) + client.do_action(flight.Action("append", b"")) + r = client.do_action(flight.Action("get_value", b"")) + assert json.loads(next(r).body.to_pybytes()) == [True] diff --git a/src/arrow/python/pyarrow/tests/test_fs.py b/src/arrow/python/pyarrow/tests/test_fs.py new file mode 100644 index 000000000..48bdae8a5 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_fs.py @@ -0,0 +1,1714 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime, timezone, timedelta +import gzip +import os +import pathlib +import pickle +import subprocess +import sys +import time + +import pytest +import weakref + +import pyarrow as pa +from pyarrow.tests.test_io import assert_file_not_found +from pyarrow.tests.util import _filesystem_uri, ProxyHandler +from pyarrow.vendored.version import Version + +from pyarrow.fs import (FileType, FileInfo, FileSelector, FileSystem, + LocalFileSystem, SubTreeFileSystem, _MockFileSystem, + FileSystemHandler, PyFileSystem, FSSpecHandler, + copy_files) + + +class DummyHandler(FileSystemHandler): + def __init__(self, value=42): + self._value = value + + def __eq__(self, other): + if isinstance(other, FileSystemHandler): + return self._value == other._value + return NotImplemented + + def __ne__(self, other): + if isinstance(other, FileSystemHandler): + return self._value != other._value + return NotImplemented + + def get_type_name(self): + return "dummy" + + def normalize_path(self, path): + return path + + def get_file_info(self, paths): + info = [] + for path in paths: + if "file" in path: + info.append(FileInfo(path, FileType.File)) + elif "dir" in path: + info.append(FileInfo(path, FileType.Directory)) + elif "notfound" in path: + info.append(FileInfo(path, FileType.NotFound)) + elif "badtype" in path: + # Will raise when converting + info.append(object()) + else: + raise IOError + return info + + def get_file_info_selector(self, selector): + if selector.base_dir != "somedir": + if selector.allow_not_found: + return [] + else: + raise FileNotFoundError(selector.base_dir) + infos = [ + FileInfo("somedir/file1", FileType.File, size=123), + FileInfo("somedir/subdir1", FileType.Directory), + ] + if selector.recursive: + infos += [ + FileInfo("somedir/subdir1/file2", FileType.File, size=456), + ] + return infos + + def create_dir(self, path, recursive): + if path == "recursive": + assert recursive is True + elif path == "non-recursive": + assert recursive is False + else: + raise IOError + + def delete_dir(self, path): + assert path == "delete_dir" + + def delete_dir_contents(self, path): + if not path.strip("/"): + raise ValueError + assert path == "delete_dir_contents" + + def delete_root_dir_contents(self): + pass + + def delete_file(self, path): + assert path == "delete_file" + + def move(self, src, dest): + assert src == "move_from" + assert dest == "move_to" + + def copy_file(self, src, dest): + assert src == "copy_file_from" + assert dest == "copy_file_to" + + def open_input_stream(self, path): + if "notfound" in path: + raise FileNotFoundError(path) + data = "{0}:input_stream".format(path).encode('utf8') + return pa.BufferReader(data) + + def open_input_file(self, path): + if "notfound" in path: + raise FileNotFoundError(path) + data = "{0}:input_file".format(path).encode('utf8') + return pa.BufferReader(data) + + def open_output_stream(self, path, metadata): + if "notfound" in path: + raise FileNotFoundError(path) + return pa.BufferOutputStream() + + def open_append_stream(self, path, metadata): + if "notfound" in path: + raise FileNotFoundError(path) + return pa.BufferOutputStream() + + +@pytest.fixture +def localfs(request, tempdir): + return dict( + fs=LocalFileSystem(), + pathfn=lambda p: (tempdir / p).as_posix(), + allow_move_dir=True, + allow_append_to_file=True, + ) + + +@pytest.fixture +def py_localfs(request, tempdir): + return dict( + fs=PyFileSystem(ProxyHandler(LocalFileSystem())), + pathfn=lambda p: (tempdir / p).as_posix(), + allow_move_dir=True, + allow_append_to_file=True, + ) + + +@pytest.fixture +def mockfs(request): + return dict( + fs=_MockFileSystem(), + pathfn=lambda p: p, + allow_move_dir=True, + allow_append_to_file=True, + ) + + +@pytest.fixture +def py_mockfs(request): + return dict( + fs=PyFileSystem(ProxyHandler(_MockFileSystem())), + pathfn=lambda p: p, + allow_move_dir=True, + allow_append_to_file=True, + ) + + +@pytest.fixture +def localfs_with_mmap(request, tempdir): + return dict( + fs=LocalFileSystem(use_mmap=True), + pathfn=lambda p: (tempdir / p).as_posix(), + allow_move_dir=True, + allow_append_to_file=True, + ) + + +@pytest.fixture +def subtree_localfs(request, tempdir, localfs): + return dict( + fs=SubTreeFileSystem(str(tempdir), localfs['fs']), + pathfn=lambda p: p, + allow_move_dir=True, + allow_append_to_file=True, + ) + + +@pytest.fixture +def s3fs(request, s3_server): + request.config.pyarrow.requires('s3') + from pyarrow.fs import S3FileSystem + + host, port, access_key, secret_key = s3_server['connection'] + bucket = 'pyarrow-filesystem/' + + fs = S3FileSystem( + access_key=access_key, + secret_key=secret_key, + endpoint_override='{}:{}'.format(host, port), + scheme='http' + ) + fs.create_dir(bucket) + + yield dict( + fs=fs, + pathfn=bucket.__add__, + allow_move_dir=False, + allow_append_to_file=False, + ) + fs.delete_dir(bucket) + + +@pytest.fixture +def subtree_s3fs(request, s3fs): + prefix = 'pyarrow-filesystem/prefix/' + return dict( + fs=SubTreeFileSystem(prefix, s3fs['fs']), + pathfn=prefix.__add__, + allow_move_dir=False, + allow_append_to_file=False, + ) + + +_minio_limited_policy = """{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:ListAllMyBuckets", + "s3:PutObject", + "s3:GetObject", + "s3:ListBucket", + "s3:PutObjectTagging", + "s3:DeleteObject", + "s3:GetObjectVersion" + ], + "Resource": [ + "arn:aws:s3:::*" + ] + } + ] +}""" + + +def _run_mc_command(mcdir, *args): + full_args = ['mc', '-C', mcdir] + list(args) + proc = subprocess.Popen(full_args, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, encoding='utf-8') + retval = proc.wait(10) + cmd_str = ' '.join(full_args) + print(f'Cmd: {cmd_str}') + print(f' Return: {retval}') + print(f' Stdout: {proc.stdout.read()}') + print(f' Stderr: {proc.stderr.read()}') + if retval != 0: + raise ChildProcessError("Could not run mc") + + +def _wait_for_minio_startup(mcdir, address, access_key, secret_key): + start = time.time() + while time.time() - start < 10: + try: + _run_mc_command(mcdir, 'alias', 'set', 'myminio', + f'http://{address}', access_key, secret_key) + return + except ChildProcessError: + time.sleep(1) + raise Exception("mc command could not connect to local minio") + + +def _configure_limited_user(tmpdir, address, access_key, secret_key): + """ + Attempts to use the mc command to configure the minio server + with a special user limited:limited123 which does not have + permission to create buckets. This mirrors some real life S3 + configurations where users are given strict permissions. + + Arrow S3 operations should still work in such a configuration + (e.g. see ARROW-13685) + """ + try: + mcdir = os.path.join(tmpdir, 'mc') + os.mkdir(mcdir) + policy_path = os.path.join(tmpdir, 'limited-buckets-policy.json') + with open(policy_path, mode='w') as policy_file: + policy_file.write(_minio_limited_policy) + # The s3_server fixture starts the minio process but + # it takes a few moments for the process to become available + _wait_for_minio_startup(mcdir, address, access_key, secret_key) + # These commands create a limited user with a specific + # policy and creates a sample bucket for that user to + # write to + _run_mc_command(mcdir, 'admin', 'policy', 'add', + 'myminio/', 'no-create-buckets', policy_path) + _run_mc_command(mcdir, 'admin', 'user', 'add', + 'myminio/', 'limited', 'limited123') + _run_mc_command(mcdir, 'admin', 'policy', 'set', + 'myminio', 'no-create-buckets', 'user=limited') + _run_mc_command(mcdir, 'mb', 'myminio/existing-bucket') + return True + except FileNotFoundError: + # If mc is not found, skip these tests + return False + + +@pytest.fixture(scope='session') +def limited_s3_user(request, s3_server): + if sys.platform == 'win32': + # Can't rely on FileNotFound check because + # there is sometimes an mc command on Windows + # which is unrelated to the minio mc + pytest.skip('The mc command is not installed on Windows') + request.config.pyarrow.requires('s3') + tempdir = s3_server['tempdir'] + host, port, access_key, secret_key = s3_server['connection'] + address = '{}:{}'.format(host, port) + if not _configure_limited_user(tempdir, address, access_key, secret_key): + pytest.skip('Could not locate mc command to configure limited user') + + +@pytest.fixture +def hdfs(request, hdfs_connection): + request.config.pyarrow.requires('hdfs') + if not pa.have_libhdfs(): + pytest.skip('Cannot locate libhdfs') + + from pyarrow.fs import HadoopFileSystem + + host, port, user = hdfs_connection + fs = HadoopFileSystem(host, port=port, user=user) + + return dict( + fs=fs, + pathfn=lambda p: p, + allow_move_dir=True, + allow_append_to_file=True, + ) + + +@pytest.fixture +def py_fsspec_localfs(request, tempdir): + fsspec = pytest.importorskip("fsspec") + fs = fsspec.filesystem('file') + return dict( + fs=PyFileSystem(FSSpecHandler(fs)), + pathfn=lambda p: (tempdir / p).as_posix(), + allow_move_dir=True, + allow_append_to_file=True, + ) + + +@pytest.fixture +def py_fsspec_memoryfs(request, tempdir): + fsspec = pytest.importorskip("fsspec", minversion="0.7.5") + if fsspec.__version__ == "0.8.5": + # see https://issues.apache.org/jira/browse/ARROW-10934 + pytest.skip("Bug in fsspec 0.8.5 for in-memory filesystem") + fs = fsspec.filesystem('memory') + return dict( + fs=PyFileSystem(FSSpecHandler(fs)), + pathfn=lambda p: p, + allow_move_dir=True, + allow_append_to_file=True, + ) + + +@pytest.fixture +def py_fsspec_s3fs(request, s3_server): + s3fs = pytest.importorskip("s3fs") + if (sys.version_info < (3, 7) and + Version(s3fs.__version__) >= Version("0.5")): + pytest.skip("s3fs>=0.5 version is async and requires Python >= 3.7") + + host, port, access_key, secret_key = s3_server['connection'] + bucket = 'pyarrow-filesystem/' + + fs = s3fs.S3FileSystem( + key=access_key, + secret=secret_key, + client_kwargs=dict(endpoint_url='http://{}:{}'.format(host, port)) + ) + fs = PyFileSystem(FSSpecHandler(fs)) + fs.create_dir(bucket) + + yield dict( + fs=fs, + pathfn=bucket.__add__, + allow_move_dir=False, + allow_append_to_file=True, + ) + fs.delete_dir(bucket) + + +@pytest.fixture(params=[ + pytest.param( + pytest.lazy_fixture('localfs'), + id='LocalFileSystem()' + ), + pytest.param( + pytest.lazy_fixture('localfs_with_mmap'), + id='LocalFileSystem(use_mmap=True)' + ), + pytest.param( + pytest.lazy_fixture('subtree_localfs'), + id='SubTreeFileSystem(LocalFileSystem())' + ), + pytest.param( + pytest.lazy_fixture('s3fs'), + id='S3FileSystem', + marks=pytest.mark.s3 + ), + pytest.param( + pytest.lazy_fixture('hdfs'), + id='HadoopFileSystem', + marks=pytest.mark.hdfs + ), + pytest.param( + pytest.lazy_fixture('mockfs'), + id='_MockFileSystem()' + ), + pytest.param( + pytest.lazy_fixture('py_localfs'), + id='PyFileSystem(ProxyHandler(LocalFileSystem()))' + ), + pytest.param( + pytest.lazy_fixture('py_mockfs'), + id='PyFileSystem(ProxyHandler(_MockFileSystem()))' + ), + pytest.param( + pytest.lazy_fixture('py_fsspec_localfs'), + id='PyFileSystem(FSSpecHandler(fsspec.LocalFileSystem()))' + ), + pytest.param( + pytest.lazy_fixture('py_fsspec_memoryfs'), + id='PyFileSystem(FSSpecHandler(fsspec.filesystem("memory")))' + ), + pytest.param( + pytest.lazy_fixture('py_fsspec_s3fs'), + id='PyFileSystem(FSSpecHandler(s3fs.S3FileSystem()))', + marks=pytest.mark.s3 + ), +]) +def filesystem_config(request): + return request.param + + +@pytest.fixture +def fs(request, filesystem_config): + return filesystem_config['fs'] + + +@pytest.fixture +def pathfn(request, filesystem_config): + return filesystem_config['pathfn'] + + +@pytest.fixture +def allow_move_dir(request, filesystem_config): + return filesystem_config['allow_move_dir'] + + +@pytest.fixture +def allow_append_to_file(request, filesystem_config): + return filesystem_config['allow_append_to_file'] + + +def check_mtime(file_info): + assert isinstance(file_info.mtime, datetime) + assert isinstance(file_info.mtime_ns, int) + assert file_info.mtime_ns >= 0 + assert file_info.mtime_ns == pytest.approx( + file_info.mtime.timestamp() * 1e9) + # It's an aware UTC datetime + tzinfo = file_info.mtime.tzinfo + assert tzinfo is not None + assert tzinfo.utcoffset(None) == timedelta(0) + + +def check_mtime_absent(file_info): + assert file_info.mtime is None + assert file_info.mtime_ns is None + + +def check_mtime_or_absent(file_info): + if file_info.mtime is None: + check_mtime_absent(file_info) + else: + check_mtime(file_info) + + +def skip_fsspec_s3fs(fs): + if fs.type_name == "py::fsspec+s3": + pytest.xfail(reason="Not working with fsspec's s3fs") + + +@pytest.mark.s3 +def test_s3fs_limited_permissions_create_bucket(s3_server, limited_s3_user): + from pyarrow.fs import S3FileSystem + + host, port, _, _ = s3_server['connection'] + + fs = S3FileSystem( + access_key='limited', + secret_key='limited123', + endpoint_override='{}:{}'.format(host, port), + scheme='http' + ) + fs.create_dir('existing-bucket/test') + + +def test_file_info_constructor(): + dt = datetime.fromtimestamp(1568799826, timezone.utc) + + info = FileInfo("foo/bar") + assert info.path == "foo/bar" + assert info.base_name == "bar" + assert info.type == FileType.Unknown + assert info.size is None + check_mtime_absent(info) + + info = FileInfo("foo/baz.txt", type=FileType.File, size=123, + mtime=1568799826.5) + assert info.path == "foo/baz.txt" + assert info.base_name == "baz.txt" + assert info.type == FileType.File + assert info.size == 123 + assert info.mtime_ns == 1568799826500000000 + check_mtime(info) + + info = FileInfo("foo", type=FileType.Directory, mtime=dt) + assert info.path == "foo" + assert info.base_name == "foo" + assert info.type == FileType.Directory + assert info.size is None + assert info.mtime == dt + assert info.mtime_ns == 1568799826000000000 + check_mtime(info) + + +def test_cannot_instantiate_base_filesystem(): + with pytest.raises(TypeError): + FileSystem() + + +def test_filesystem_equals(): + fs0 = LocalFileSystem() + fs1 = LocalFileSystem() + fs2 = _MockFileSystem() + + assert fs0.equals(fs0) + assert fs0.equals(fs1) + with pytest.raises(TypeError): + fs0.equals('string') + assert fs0 == fs0 == fs1 + assert fs0 != 4 + + assert fs2 == fs2 + assert fs2 != _MockFileSystem() + + assert SubTreeFileSystem('/base', fs0) == SubTreeFileSystem('/base', fs0) + assert SubTreeFileSystem('/base', fs0) != SubTreeFileSystem('/base', fs2) + assert SubTreeFileSystem('/base', fs0) != SubTreeFileSystem('/other', fs0) + + +def test_subtree_filesystem(): + localfs = LocalFileSystem() + + subfs = SubTreeFileSystem('/base', localfs) + assert subfs.base_path == '/base/' + assert subfs.base_fs == localfs + assert repr(subfs).startswith('SubTreeFileSystem(base_path=/base/, ' + 'base_fs=<pyarrow._fs.LocalFileSystem') + + subfs = SubTreeFileSystem('/another/base/', LocalFileSystem()) + assert subfs.base_path == '/another/base/' + assert subfs.base_fs == localfs + assert repr(subfs).startswith('SubTreeFileSystem(base_path=/another/base/,' + ' base_fs=<pyarrow._fs.LocalFileSystem') + + +def test_filesystem_pickling(fs): + if fs.type_name.split('::')[-1] == 'mock': + pytest.xfail(reason='MockFileSystem is not serializable') + + serialized = pickle.dumps(fs) + restored = pickle.loads(serialized) + assert isinstance(restored, FileSystem) + assert restored.equals(fs) + + +def test_filesystem_is_functional_after_pickling(fs, pathfn): + if fs.type_name.split('::')[-1] == 'mock': + pytest.xfail(reason='MockFileSystem is not serializable') + skip_fsspec_s3fs(fs) + + aaa = pathfn('a/aa/aaa/') + bb = pathfn('a/bb') + c = pathfn('c.txt') + + fs.create_dir(aaa) + with fs.open_output_stream(bb): + pass # touch + with fs.open_output_stream(c) as fp: + fp.write(b'test') + + restored = pickle.loads(pickle.dumps(fs)) + aaa_info, bb_info, c_info = restored.get_file_info([aaa, bb, c]) + assert aaa_info.type == FileType.Directory + assert bb_info.type == FileType.File + assert c_info.type == FileType.File + + +def test_type_name(): + fs = LocalFileSystem() + assert fs.type_name == "local" + fs = _MockFileSystem() + assert fs.type_name == "mock" + + +def test_normalize_path(fs): + # Trivial path names (without separators) should generally be + # already normalized. Just a sanity check. + assert fs.normalize_path("foo") == "foo" + + +def test_non_path_like_input_raises(fs): + class Path: + pass + + invalid_paths = [1, 1.1, Path(), tuple(), {}, [], lambda: 1, + pathlib.Path()] + for path in invalid_paths: + with pytest.raises(TypeError): + fs.create_dir(path) + + +def test_get_file_info(fs, pathfn): + aaa = pathfn('a/aa/aaa/') + bb = pathfn('a/bb') + c = pathfn('c.txt') + zzz = pathfn('zzz') + + fs.create_dir(aaa) + with fs.open_output_stream(bb): + pass # touch + with fs.open_output_stream(c) as fp: + fp.write(b'test') + + aaa_info, bb_info, c_info, zzz_info = fs.get_file_info([aaa, bb, c, zzz]) + + assert aaa_info.path == aaa + assert 'aaa' in repr(aaa_info) + assert aaa_info.extension == '' + if fs.type_name == "py::fsspec+s3": + # s3fs doesn't create empty directories + assert aaa_info.type == FileType.NotFound + else: + assert aaa_info.type == FileType.Directory + assert 'FileType.Directory' in repr(aaa_info) + assert aaa_info.size is None + check_mtime_or_absent(aaa_info) + + assert bb_info.path == str(bb) + assert bb_info.base_name == 'bb' + assert bb_info.extension == '' + assert bb_info.type == FileType.File + assert 'FileType.File' in repr(bb_info) + assert bb_info.size == 0 + if fs.type_name not in ["py::fsspec+memory", "py::fsspec+s3"]: + check_mtime(bb_info) + + assert c_info.path == str(c) + assert c_info.base_name == 'c.txt' + assert c_info.extension == 'txt' + assert c_info.type == FileType.File + assert 'FileType.File' in repr(c_info) + assert c_info.size == 4 + if fs.type_name not in ["py::fsspec+memory", "py::fsspec+s3"]: + check_mtime(c_info) + + assert zzz_info.path == str(zzz) + assert zzz_info.base_name == 'zzz' + assert zzz_info.extension == '' + assert zzz_info.type == FileType.NotFound + assert zzz_info.size is None + assert zzz_info.mtime is None + assert 'FileType.NotFound' in repr(zzz_info) + check_mtime_absent(zzz_info) + + # with single path + aaa_info2 = fs.get_file_info(aaa) + assert aaa_info.path == aaa_info2.path + assert aaa_info.type == aaa_info2.type + + +def test_get_file_info_with_selector(fs, pathfn): + base_dir = pathfn('selector-dir/') + file_a = pathfn('selector-dir/test_file_a') + file_b = pathfn('selector-dir/test_file_b') + dir_a = pathfn('selector-dir/test_dir_a') + file_c = pathfn('selector-dir/test_dir_a/test_file_c') + dir_b = pathfn('selector-dir/test_dir_b') + + try: + fs.create_dir(base_dir) + with fs.open_output_stream(file_a): + pass + with fs.open_output_stream(file_b): + pass + fs.create_dir(dir_a) + with fs.open_output_stream(file_c): + pass + fs.create_dir(dir_b) + + # recursive selector + selector = FileSelector(base_dir, allow_not_found=False, + recursive=True) + assert selector.base_dir == base_dir + + infos = fs.get_file_info(selector) + if fs.type_name == "py::fsspec+s3": + # s3fs only lists directories if they are not empty, but depending + # on the s3fs/fsspec version combo, it includes the base_dir + # (https://github.com/dask/s3fs/issues/393) + assert (len(infos) == 4) or (len(infos) == 5) + else: + assert len(infos) == 5 + + for info in infos: + if (info.path.endswith(file_a) or info.path.endswith(file_b) or + info.path.endswith(file_c)): + assert info.type == FileType.File + elif (info.path.rstrip("/").endswith(dir_a) or + info.path.rstrip("/").endswith(dir_b)): + assert info.type == FileType.Directory + elif (fs.type_name == "py::fsspec+s3" and + info.path.rstrip("/").endswith("selector-dir")): + # s3fs can include base dir, see above + assert info.type == FileType.Directory + else: + raise ValueError('unexpected path {}'.format(info.path)) + check_mtime_or_absent(info) + + # non-recursive selector -> not selecting the nested file_c + selector = FileSelector(base_dir, recursive=False) + + infos = fs.get_file_info(selector) + if fs.type_name == "py::fsspec+s3": + # s3fs only lists directories if they are not empty + # + for s3fs 0.5.2 all directories are dropped because of buggy + # side-effect of previous find() call + # (https://github.com/dask/s3fs/issues/410) + assert (len(infos) == 3) or (len(infos) == 2) + else: + assert len(infos) == 4 + + finally: + fs.delete_dir(base_dir) + + +def test_create_dir(fs, pathfn): + # s3fs fails deleting dir fails if it is empty + # (https://github.com/dask/s3fs/issues/317) + skip_fsspec_s3fs(fs) + d = pathfn('test-directory/') + + with pytest.raises(pa.ArrowIOError): + fs.delete_dir(d) + + fs.create_dir(d) + fs.delete_dir(d) + + d = pathfn('deeply/nested/test-directory/') + fs.create_dir(d, recursive=True) + fs.delete_dir(d) + + +def test_delete_dir(fs, pathfn): + skip_fsspec_s3fs(fs) + + d = pathfn('directory/') + nd = pathfn('directory/nested/') + + fs.create_dir(nd) + fs.delete_dir(d) + with pytest.raises(pa.ArrowIOError): + fs.delete_dir(nd) + with pytest.raises(pa.ArrowIOError): + fs.delete_dir(d) + + +def test_delete_dir_contents(fs, pathfn): + skip_fsspec_s3fs(fs) + + d = pathfn('directory/') + nd = pathfn('directory/nested/') + + fs.create_dir(nd) + fs.delete_dir_contents(d) + with pytest.raises(pa.ArrowIOError): + fs.delete_dir(nd) + fs.delete_dir(d) + with pytest.raises(pa.ArrowIOError): + fs.delete_dir(d) + + +def _check_root_dir_contents(config): + fs = config['fs'] + pathfn = config['pathfn'] + + d = pathfn('directory/') + nd = pathfn('directory/nested/') + + fs.create_dir(nd) + with pytest.raises(pa.ArrowInvalid): + fs.delete_dir_contents("") + with pytest.raises(pa.ArrowInvalid): + fs.delete_dir_contents("/") + with pytest.raises(pa.ArrowInvalid): + fs.delete_dir_contents("//") + + fs.delete_dir_contents("", accept_root_dir=True) + fs.delete_dir_contents("/", accept_root_dir=True) + fs.delete_dir_contents("//", accept_root_dir=True) + with pytest.raises(pa.ArrowIOError): + fs.delete_dir(d) + + +def test_delete_root_dir_contents(mockfs, py_mockfs): + _check_root_dir_contents(mockfs) + _check_root_dir_contents(py_mockfs) + + +def test_copy_file(fs, pathfn): + s = pathfn('test-copy-source-file') + t = pathfn('test-copy-target-file') + + with fs.open_output_stream(s): + pass + + fs.copy_file(s, t) + fs.delete_file(s) + fs.delete_file(t) + + +def test_move_directory(fs, pathfn, allow_move_dir): + # move directory (doesn't work with S3) + s = pathfn('source-dir/') + t = pathfn('target-dir/') + + fs.create_dir(s) + + if allow_move_dir: + fs.move(s, t) + with pytest.raises(pa.ArrowIOError): + fs.delete_dir(s) + fs.delete_dir(t) + else: + with pytest.raises(pa.ArrowIOError): + fs.move(s, t) + + +def test_move_file(fs, pathfn): + # s3fs moving a file with recursive=True on latest 0.5 version + # (https://github.com/dask/s3fs/issues/394) + skip_fsspec_s3fs(fs) + + s = pathfn('test-move-source-file') + t = pathfn('test-move-target-file') + + with fs.open_output_stream(s): + pass + + fs.move(s, t) + with pytest.raises(pa.ArrowIOError): + fs.delete_file(s) + fs.delete_file(t) + + +def test_delete_file(fs, pathfn): + p = pathfn('test-delete-target-file') + with fs.open_output_stream(p): + pass + + fs.delete_file(p) + with pytest.raises(pa.ArrowIOError): + fs.delete_file(p) + + d = pathfn('test-delete-nested') + fs.create_dir(d) + f = pathfn('test-delete-nested/target-file') + with fs.open_output_stream(f) as s: + s.write(b'data') + + fs.delete_dir(d) + + +def identity(v): + return v + + +@pytest.mark.gzip +@pytest.mark.parametrize( + ('compression', 'buffer_size', 'compressor'), + [ + (None, None, identity), + (None, 64, identity), + ('gzip', None, gzip.compress), + ('gzip', 256, gzip.compress), + ] +) +def test_open_input_stream(fs, pathfn, compression, buffer_size, compressor): + p = pathfn('open-input-stream') + + data = b'some data for reading\n' * 512 + with fs.open_output_stream(p) as s: + s.write(compressor(data)) + + with fs.open_input_stream(p, compression, buffer_size) as s: + result = s.read() + + assert result == data + + +def test_open_input_file(fs, pathfn): + p = pathfn('open-input-file') + + data = b'some data' * 1024 + with fs.open_output_stream(p) as s: + s.write(data) + + read_from = len(b'some data') * 512 + with fs.open_input_file(p) as f: + f.seek(read_from) + result = f.read() + + assert result == data[read_from:] + + +@pytest.mark.gzip +@pytest.mark.parametrize( + ('compression', 'buffer_size', 'decompressor'), + [ + (None, None, identity), + (None, 64, identity), + ('gzip', None, gzip.decompress), + ('gzip', 256, gzip.decompress), + ] +) +def test_open_output_stream(fs, pathfn, compression, buffer_size, + decompressor): + p = pathfn('open-output-stream') + + data = b'some data for writing' * 1024 + with fs.open_output_stream(p, compression, buffer_size) as f: + f.write(data) + + with fs.open_input_stream(p, compression, buffer_size) as f: + assert f.read(len(data)) == data + + +@pytest.mark.gzip +@pytest.mark.parametrize( + ('compression', 'buffer_size', 'compressor', 'decompressor'), + [ + (None, None, identity, identity), + (None, 64, identity, identity), + ('gzip', None, gzip.compress, gzip.decompress), + ('gzip', 256, gzip.compress, gzip.decompress), + ] +) +@pytest.mark.filterwarnings("ignore::FutureWarning") +def test_open_append_stream(fs, pathfn, compression, buffer_size, compressor, + decompressor, allow_append_to_file): + p = pathfn('open-append-stream') + + initial = compressor(b'already existing') + with fs.open_output_stream(p) as s: + s.write(initial) + + if allow_append_to_file: + with fs.open_append_stream(p, compression=compression, + buffer_size=buffer_size) as f: + f.write(b'\nnewly added') + + with fs.open_input_stream(p) as f: + result = f.read() + + result = decompressor(result) + assert result == b'already existing\nnewly added' + else: + with pytest.raises(pa.ArrowNotImplementedError): + fs.open_append_stream(p, compression=compression, + buffer_size=buffer_size) + + +def test_open_output_stream_metadata(fs, pathfn): + p = pathfn('open-output-stream-metadata') + metadata = {'Content-Type': 'x-pyarrow/test'} + + data = b'some data' + with fs.open_output_stream(p, metadata=metadata) as f: + f.write(data) + + with fs.open_input_stream(p) as f: + assert f.read() == data + got_metadata = f.metadata() + + if fs.type_name == 's3' or 'mock' in fs.type_name: + for k, v in metadata.items(): + assert got_metadata[k] == v.encode() + else: + assert got_metadata == {} + + +def test_localfs_options(): + # LocalFileSystem instantiation + LocalFileSystem(use_mmap=False) + + with pytest.raises(TypeError): + LocalFileSystem(xxx=False) + + +def test_localfs_errors(localfs): + # Local filesystem errors should raise the right Python exceptions + # (e.g. FileNotFoundError) + fs = localfs['fs'] + with assert_file_not_found(): + fs.open_input_stream('/non/existent/file') + with assert_file_not_found(): + fs.open_output_stream('/non/existent/file') + with assert_file_not_found(): + fs.create_dir('/non/existent/dir', recursive=False) + with assert_file_not_found(): + fs.delete_dir('/non/existent/dir') + with assert_file_not_found(): + fs.delete_file('/non/existent/dir') + with assert_file_not_found(): + fs.move('/non/existent', '/xxx') + with assert_file_not_found(): + fs.copy_file('/non/existent', '/xxx') + + +def test_localfs_file_info(localfs): + fs = localfs['fs'] + + file_path = pathlib.Path(__file__) + dir_path = file_path.parent + [file_info, dir_info] = fs.get_file_info([file_path.as_posix(), + dir_path.as_posix()]) + assert file_info.size == file_path.stat().st_size + assert file_info.mtime_ns == file_path.stat().st_mtime_ns + check_mtime(file_info) + assert dir_info.mtime_ns == dir_path.stat().st_mtime_ns + check_mtime(dir_info) + + +def test_mockfs_mtime_roundtrip(mockfs): + dt = datetime.fromtimestamp(1568799826, timezone.utc) + fs = _MockFileSystem(dt) + + with fs.open_output_stream('foo'): + pass + [info] = fs.get_file_info(['foo']) + assert info.mtime == dt + + +@pytest.mark.s3 +def test_s3_options(): + from pyarrow.fs import S3FileSystem + + fs = S3FileSystem(access_key='access', secret_key='secret', + session_token='token', region='us-east-2', + scheme='https', endpoint_override='localhost:8999') + assert isinstance(fs, S3FileSystem) + assert fs.region == 'us-east-2' + assert pickle.loads(pickle.dumps(fs)) == fs + + fs = S3FileSystem(role_arn='role', session_name='session', + external_id='id', load_frequency=100) + assert isinstance(fs, S3FileSystem) + assert pickle.loads(pickle.dumps(fs)) == fs + + fs = S3FileSystem(anonymous=True) + assert isinstance(fs, S3FileSystem) + assert pickle.loads(pickle.dumps(fs)) == fs + + fs = S3FileSystem(background_writes=True, + default_metadata={"ACL": "authenticated-read", + "Content-Type": "text/plain"}) + assert isinstance(fs, S3FileSystem) + assert pickle.loads(pickle.dumps(fs)) == fs + + with pytest.raises(ValueError): + S3FileSystem(access_key='access') + with pytest.raises(ValueError): + S3FileSystem(secret_key='secret') + with pytest.raises(ValueError): + S3FileSystem(access_key='access', session_token='token') + with pytest.raises(ValueError): + S3FileSystem(secret_key='secret', session_token='token') + with pytest.raises(ValueError): + S3FileSystem( + access_key='access', secret_key='secret', role_arn='arn' + ) + with pytest.raises(ValueError): + S3FileSystem( + access_key='access', secret_key='secret', anonymous=True + ) + with pytest.raises(ValueError): + S3FileSystem(role_arn="arn", anonymous=True) + with pytest.raises(ValueError): + S3FileSystem(default_metadata=["foo", "bar"]) + + +@pytest.mark.s3 +def test_s3_proxy_options(monkeypatch): + from pyarrow.fs import S3FileSystem + + # The following two are equivalent: + proxy_opts_1_dict = {'scheme': 'http', 'host': 'localhost', 'port': 8999} + proxy_opts_1_str = 'http://localhost:8999' + # The following two are equivalent: + proxy_opts_2_dict = {'scheme': 'https', 'host': 'localhost', 'port': 8080} + proxy_opts_2_str = 'https://localhost:8080' + + # Check dict case for 'proxy_options' + fs = S3FileSystem(proxy_options=proxy_opts_1_dict) + assert isinstance(fs, S3FileSystem) + assert pickle.loads(pickle.dumps(fs)) == fs + + fs = S3FileSystem(proxy_options=proxy_opts_2_dict) + assert isinstance(fs, S3FileSystem) + assert pickle.loads(pickle.dumps(fs)) == fs + + # Check str case for 'proxy_options' + fs = S3FileSystem(proxy_options=proxy_opts_1_str) + assert isinstance(fs, S3FileSystem) + assert pickle.loads(pickle.dumps(fs)) == fs + + fs = S3FileSystem(proxy_options=proxy_opts_2_str) + assert isinstance(fs, S3FileSystem) + assert pickle.loads(pickle.dumps(fs)) == fs + + # Check that two FSs using the same proxy_options dict are equal + fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict) + fs2 = S3FileSystem(proxy_options=proxy_opts_1_dict) + assert fs1 == fs2 + assert pickle.loads(pickle.dumps(fs1)) == fs2 + assert pickle.loads(pickle.dumps(fs2)) == fs1 + + fs1 = S3FileSystem(proxy_options=proxy_opts_2_dict) + fs2 = S3FileSystem(proxy_options=proxy_opts_2_dict) + assert fs1 == fs2 + assert pickle.loads(pickle.dumps(fs1)) == fs2 + assert pickle.loads(pickle.dumps(fs2)) == fs1 + + # Check that two FSs using the same proxy_options str are equal + fs1 = S3FileSystem(proxy_options=proxy_opts_1_str) + fs2 = S3FileSystem(proxy_options=proxy_opts_1_str) + assert fs1 == fs2 + assert pickle.loads(pickle.dumps(fs1)) == fs2 + assert pickle.loads(pickle.dumps(fs2)) == fs1 + + fs1 = S3FileSystem(proxy_options=proxy_opts_2_str) + fs2 = S3FileSystem(proxy_options=proxy_opts_2_str) + assert fs1 == fs2 + assert pickle.loads(pickle.dumps(fs1)) == fs2 + assert pickle.loads(pickle.dumps(fs2)) == fs1 + + # Check that two FSs using equivalent proxy_options + # (one dict, one str) are equal + fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict) + fs2 = S3FileSystem(proxy_options=proxy_opts_1_str) + assert fs1 == fs2 + assert pickle.loads(pickle.dumps(fs1)) == fs2 + assert pickle.loads(pickle.dumps(fs2)) == fs1 + + fs1 = S3FileSystem(proxy_options=proxy_opts_2_dict) + fs2 = S3FileSystem(proxy_options=proxy_opts_2_str) + assert fs1 == fs2 + assert pickle.loads(pickle.dumps(fs1)) == fs2 + assert pickle.loads(pickle.dumps(fs2)) == fs1 + + # Check that two FSs using nonequivalent proxy_options are not equal + fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict) + fs2 = S3FileSystem(proxy_options=proxy_opts_2_dict) + assert fs1 != fs2 + assert pickle.loads(pickle.dumps(fs1)) != fs2 + assert pickle.loads(pickle.dumps(fs2)) != fs1 + + fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict) + fs2 = S3FileSystem(proxy_options=proxy_opts_2_str) + assert fs1 != fs2 + assert pickle.loads(pickle.dumps(fs1)) != fs2 + assert pickle.loads(pickle.dumps(fs2)) != fs1 + + fs1 = S3FileSystem(proxy_options=proxy_opts_1_str) + fs2 = S3FileSystem(proxy_options=proxy_opts_2_dict) + assert fs1 != fs2 + assert pickle.loads(pickle.dumps(fs1)) != fs2 + assert pickle.loads(pickle.dumps(fs2)) != fs1 + + fs1 = S3FileSystem(proxy_options=proxy_opts_1_str) + fs2 = S3FileSystem(proxy_options=proxy_opts_2_str) + assert fs1 != fs2 + assert pickle.loads(pickle.dumps(fs1)) != fs2 + assert pickle.loads(pickle.dumps(fs2)) != fs1 + + # Check that two FSs (one using proxy_options and the other not) + # are not equal + fs1 = S3FileSystem(proxy_options=proxy_opts_1_dict) + fs2 = S3FileSystem() + assert fs1 != fs2 + assert pickle.loads(pickle.dumps(fs1)) != fs2 + assert pickle.loads(pickle.dumps(fs2)) != fs1 + + fs1 = S3FileSystem(proxy_options=proxy_opts_1_str) + fs2 = S3FileSystem() + assert fs1 != fs2 + assert pickle.loads(pickle.dumps(fs1)) != fs2 + assert pickle.loads(pickle.dumps(fs2)) != fs1 + + fs1 = S3FileSystem(proxy_options=proxy_opts_2_dict) + fs2 = S3FileSystem() + assert fs1 != fs2 + assert pickle.loads(pickle.dumps(fs1)) != fs2 + assert pickle.loads(pickle.dumps(fs2)) != fs1 + + fs1 = S3FileSystem(proxy_options=proxy_opts_2_str) + fs2 = S3FileSystem() + assert fs1 != fs2 + assert pickle.loads(pickle.dumps(fs1)) != fs2 + assert pickle.loads(pickle.dumps(fs2)) != fs1 + + # Only dict and str are supported + with pytest.raises(TypeError): + S3FileSystem(proxy_options=('http', 'localhost', 9090)) + # Missing scheme + with pytest.raises(KeyError): + S3FileSystem(proxy_options={'host': 'localhost', 'port': 9090}) + # Missing host + with pytest.raises(KeyError): + S3FileSystem(proxy_options={'scheme': 'https', 'port': 9090}) + # Missing port + with pytest.raises(KeyError): + S3FileSystem(proxy_options={'scheme': 'http', 'host': 'localhost'}) + # Invalid proxy URI (invalid scheme htttps) + with pytest.raises(pa.ArrowInvalid): + S3FileSystem(proxy_options='htttps://localhost:9000') + # Invalid proxy_options dict (invalid scheme htttps) + with pytest.raises(pa.ArrowInvalid): + S3FileSystem(proxy_options={'scheme': 'htttp', 'host': 'localhost', + 'port': 8999}) + + +@pytest.mark.hdfs +def test_hdfs_options(hdfs_connection): + from pyarrow.fs import HadoopFileSystem + if not pa.have_libhdfs(): + pytest.skip('Cannot locate libhdfs') + + host, port, user = hdfs_connection + + replication = 2 + buffer_size = 64*1024 + default_block_size = 128*1024**2 + uri = ('hdfs://{}:{}/?user={}&replication={}&buffer_size={}' + '&default_block_size={}') + + hdfs1 = HadoopFileSystem(host, port, user='libhdfs', + replication=replication, buffer_size=buffer_size, + default_block_size=default_block_size) + hdfs2 = HadoopFileSystem.from_uri(uri.format( + host, port, 'libhdfs', replication, buffer_size, default_block_size + )) + hdfs3 = HadoopFileSystem.from_uri(uri.format( + host, port, 'me', replication, buffer_size, default_block_size + )) + hdfs4 = HadoopFileSystem.from_uri(uri.format( + host, port, 'me', replication + 1, buffer_size, default_block_size + )) + hdfs5 = HadoopFileSystem(host, port) + hdfs6 = HadoopFileSystem.from_uri('hdfs://{}:{}'.format(host, port)) + hdfs7 = HadoopFileSystem(host, port, user='localuser') + hdfs8 = HadoopFileSystem(host, port, user='localuser', + kerb_ticket="cache_path") + hdfs9 = HadoopFileSystem(host, port, user='localuser', + kerb_ticket=pathlib.Path("cache_path")) + hdfs10 = HadoopFileSystem(host, port, user='localuser', + kerb_ticket="cache_path2") + hdfs11 = HadoopFileSystem(host, port, user='localuser', + kerb_ticket="cache_path", + extra_conf={'hdfs_token': 'abcd'}) + + assert hdfs1 == hdfs2 + assert hdfs5 == hdfs6 + assert hdfs6 != hdfs7 + assert hdfs2 != hdfs3 + assert hdfs3 != hdfs4 + assert hdfs7 != hdfs5 + assert hdfs2 != hdfs3 + assert hdfs3 != hdfs4 + assert hdfs7 != hdfs8 + assert hdfs8 == hdfs9 + assert hdfs10 != hdfs9 + assert hdfs11 != hdfs8 + + with pytest.raises(TypeError): + HadoopFileSystem() + with pytest.raises(TypeError): + HadoopFileSystem.from_uri(3) + + for fs in [hdfs1, hdfs2, hdfs3, hdfs4, hdfs5, hdfs6, hdfs7, hdfs8, + hdfs9, hdfs10, hdfs11]: + assert pickle.loads(pickle.dumps(fs)) == fs + + host, port, user = hdfs_connection + + hdfs = HadoopFileSystem(host, port, user=user) + assert hdfs.get_file_info(FileSelector('/')) + + hdfs = HadoopFileSystem.from_uri( + "hdfs://{}:{}/?user={}".format(host, port, user) + ) + assert hdfs.get_file_info(FileSelector('/')) + + +@pytest.mark.parametrize(('uri', 'expected_klass', 'expected_path'), [ + # leading slashes are removed intentionally, because MockFileSystem doesn't + # have a distinction between relative and absolute paths + ('mock:', _MockFileSystem, ''), + ('mock:foo/bar', _MockFileSystem, 'foo/bar'), + ('mock:/foo/bar', _MockFileSystem, 'foo/bar'), + ('mock:///foo/bar', _MockFileSystem, 'foo/bar'), + ('file:/', LocalFileSystem, '/'), + ('file:///', LocalFileSystem, '/'), + ('file:/foo/bar', LocalFileSystem, '/foo/bar'), + ('file:///foo/bar', LocalFileSystem, '/foo/bar'), + ('/', LocalFileSystem, '/'), + ('/foo/bar', LocalFileSystem, '/foo/bar'), +]) +def test_filesystem_from_uri(uri, expected_klass, expected_path): + fs, path = FileSystem.from_uri(uri) + assert isinstance(fs, expected_klass) + assert path == expected_path + + +@pytest.mark.parametrize( + 'path', + ['', '/', 'foo/bar', '/foo/bar', __file__] +) +def test_filesystem_from_path_object(path): + p = pathlib.Path(path) + fs, path = FileSystem.from_uri(p) + assert isinstance(fs, LocalFileSystem) + assert path == p.resolve().absolute().as_posix() + + +@pytest.mark.s3 +def test_filesystem_from_uri_s3(s3_server): + from pyarrow.fs import S3FileSystem + + host, port, access_key, secret_key = s3_server['connection'] + + uri = "s3://{}:{}@mybucket/foo/bar?scheme=http&endpoint_override={}:{}" \ + .format(access_key, secret_key, host, port) + + fs, path = FileSystem.from_uri(uri) + assert isinstance(fs, S3FileSystem) + assert path == "mybucket/foo/bar" + + fs.create_dir(path) + [info] = fs.get_file_info([path]) + assert info.path == path + assert info.type == FileType.Directory + + +def test_py_filesystem(): + handler = DummyHandler() + fs = PyFileSystem(handler) + assert isinstance(fs, PyFileSystem) + assert fs.type_name == "py::dummy" + assert fs.handler is handler + + with pytest.raises(TypeError): + PyFileSystem(None) + + +def test_py_filesystem_equality(): + handler1 = DummyHandler(1) + handler2 = DummyHandler(2) + handler3 = DummyHandler(2) + fs1 = PyFileSystem(handler1) + fs2 = PyFileSystem(handler1) + fs3 = PyFileSystem(handler2) + fs4 = PyFileSystem(handler3) + + assert fs2 is not fs1 + assert fs3 is not fs2 + assert fs4 is not fs3 + assert fs2 == fs1 # Same handler + assert fs3 != fs2 # Unequal handlers + assert fs4 == fs3 # Equal handlers + + assert fs1 != LocalFileSystem() + assert fs1 != object() + + +def test_py_filesystem_pickling(): + handler = DummyHandler() + fs = PyFileSystem(handler) + + serialized = pickle.dumps(fs) + restored = pickle.loads(serialized) + assert isinstance(restored, FileSystem) + assert restored == fs + assert restored.handler == handler + assert restored.type_name == "py::dummy" + + +def test_py_filesystem_lifetime(): + handler = DummyHandler() + fs = PyFileSystem(handler) + assert isinstance(fs, PyFileSystem) + wr = weakref.ref(handler) + handler = None + assert wr() is not None + fs = None + assert wr() is None + + # Taking the .handler attribute doesn't wreck reference counts + handler = DummyHandler() + fs = PyFileSystem(handler) + wr = weakref.ref(handler) + handler = None + assert wr() is fs.handler + assert wr() is not None + fs = None + assert wr() is None + + +def test_py_filesystem_get_file_info(): + handler = DummyHandler() + fs = PyFileSystem(handler) + + [info] = fs.get_file_info(['some/dir']) + assert info.path == 'some/dir' + assert info.type == FileType.Directory + + [info] = fs.get_file_info(['some/file']) + assert info.path == 'some/file' + assert info.type == FileType.File + + [info] = fs.get_file_info(['notfound']) + assert info.path == 'notfound' + assert info.type == FileType.NotFound + + with pytest.raises(TypeError): + fs.get_file_info(['badtype']) + + with pytest.raises(IOError): + fs.get_file_info(['xxx']) + + +def test_py_filesystem_get_file_info_selector(): + handler = DummyHandler() + fs = PyFileSystem(handler) + + selector = FileSelector(base_dir="somedir") + infos = fs.get_file_info(selector) + assert len(infos) == 2 + assert infos[0].path == "somedir/file1" + assert infos[0].type == FileType.File + assert infos[0].size == 123 + assert infos[1].path == "somedir/subdir1" + assert infos[1].type == FileType.Directory + assert infos[1].size is None + + selector = FileSelector(base_dir="somedir", recursive=True) + infos = fs.get_file_info(selector) + assert len(infos) == 3 + assert infos[0].path == "somedir/file1" + assert infos[1].path == "somedir/subdir1" + assert infos[2].path == "somedir/subdir1/file2" + + selector = FileSelector(base_dir="notfound") + with pytest.raises(FileNotFoundError): + fs.get_file_info(selector) + + selector = FileSelector(base_dir="notfound", allow_not_found=True) + assert fs.get_file_info(selector) == [] + + +def test_py_filesystem_ops(): + handler = DummyHandler() + fs = PyFileSystem(handler) + + fs.create_dir("recursive", recursive=True) + fs.create_dir("non-recursive", recursive=False) + with pytest.raises(IOError): + fs.create_dir("foobar") + + fs.delete_dir("delete_dir") + fs.delete_dir_contents("delete_dir_contents") + for path in ("", "/", "//"): + with pytest.raises(ValueError): + fs.delete_dir_contents(path) + fs.delete_dir_contents(path, accept_root_dir=True) + fs.delete_file("delete_file") + fs.move("move_from", "move_to") + fs.copy_file("copy_file_from", "copy_file_to") + + +def test_py_open_input_stream(): + fs = PyFileSystem(DummyHandler()) + + with fs.open_input_stream("somefile") as f: + assert f.read() == b"somefile:input_stream" + with pytest.raises(FileNotFoundError): + fs.open_input_stream("notfound") + + +def test_py_open_input_file(): + fs = PyFileSystem(DummyHandler()) + + with fs.open_input_file("somefile") as f: + assert f.read() == b"somefile:input_file" + with pytest.raises(FileNotFoundError): + fs.open_input_file("notfound") + + +def test_py_open_output_stream(): + fs = PyFileSystem(DummyHandler()) + + with fs.open_output_stream("somefile") as f: + f.write(b"data") + + +@pytest.mark.filterwarnings("ignore::FutureWarning") +def test_py_open_append_stream(): + fs = PyFileSystem(DummyHandler()) + + with fs.open_append_stream("somefile") as f: + f.write(b"data") + + +@pytest.mark.s3 +def test_s3_real_aws(): + # Exercise connection code with an AWS-backed S3 bucket. + # This is a minimal integration check for ARROW-9261 and similar issues. + from pyarrow.fs import S3FileSystem + default_region = (os.environ.get('PYARROW_TEST_S3_REGION') or + 'us-east-1') + fs = S3FileSystem(anonymous=True) + assert fs.region == default_region + + fs = S3FileSystem(anonymous=True, region='us-east-2') + entries = fs.get_file_info(FileSelector('ursa-labs-taxi-data')) + assert len(entries) > 0 + with fs.open_input_stream('ursa-labs-taxi-data/2019/06/data.parquet') as f: + md = f.metadata() + assert 'Content-Type' in md + assert md['Last-Modified'] == b'2020-01-17T16:26:28Z' + # For some reason, the header value is quoted + # (both with AWS and Minio) + assert md['ETag'] == b'"f1efd5d76cb82861e1542117bfa52b90-8"' + + +@pytest.mark.s3 +def test_s3_real_aws_region_selection(): + # Taken from a registry of open S3-hosted datasets + # at https://github.com/awslabs/open-data-registry + fs, path = FileSystem.from_uri('s3://mf-nwp-models/README.txt') + assert fs.region == 'eu-west-1' + with fs.open_input_stream(path) as f: + assert b"Meteo-France Atmospheric models on AWS" in f.read(50) + + # Passing an explicit region disables auto-selection + fs, path = FileSystem.from_uri( + 's3://mf-nwp-models/README.txt?region=us-east-2') + assert fs.region == 'us-east-2' + # Reading from the wrong region may still work for public buckets... + + # Non-existent bucket (hopefully, otherwise need to fix this test) + with pytest.raises(IOError, match="Bucket '.*' not found"): + FileSystem.from_uri('s3://x-arrow-non-existent-bucket') + fs, path = FileSystem.from_uri( + 's3://x-arrow-non-existent-bucket?region=us-east-3') + assert fs.region == 'us-east-3' + + +@pytest.mark.s3 +def test_copy_files(s3_connection, s3fs, tempdir): + fs = s3fs["fs"] + pathfn = s3fs["pathfn"] + + # create test file on S3 filesystem + path = pathfn('c.txt') + with fs.open_output_stream(path) as f: + f.write(b'test') + + # create URI for created file + host, port, access_key, secret_key = s3_connection + source_uri = ( + f"s3://{access_key}:{secret_key}@{path}" + f"?scheme=http&endpoint_override={host}:{port}" + ) + # copy from S3 URI to local file + local_path1 = str(tempdir / "c_copied1.txt") + copy_files(source_uri, local_path1) + + localfs = LocalFileSystem() + with localfs.open_input_stream(local_path1) as f: + assert f.read() == b"test" + + # copy from S3 path+filesystem to local file + local_path2 = str(tempdir / "c_copied2.txt") + copy_files(path, local_path2, source_filesystem=fs) + with localfs.open_input_stream(local_path2) as f: + assert f.read() == b"test" + + # copy to local file with URI + local_path3 = str(tempdir / "c_copied3.txt") + destination_uri = _filesystem_uri(local_path3) # file:// + copy_files(source_uri, destination_uri) + + with localfs.open_input_stream(local_path3) as f: + assert f.read() == b"test" + + # copy to local file with path+filesystem + local_path4 = str(tempdir / "c_copied4.txt") + copy_files(source_uri, local_path4, destination_filesystem=localfs) + + with localfs.open_input_stream(local_path4) as f: + assert f.read() == b"test" + + # copy with additional options + local_path5 = str(tempdir / "c_copied5.txt") + copy_files(source_uri, local_path5, chunk_size=1, use_threads=False) + + with localfs.open_input_stream(local_path5) as f: + assert f.read() == b"test" + + +def test_copy_files_directory(tempdir): + localfs = LocalFileSystem() + + # create source directory with 2 files + source_dir = tempdir / "source" + source_dir.mkdir() + with localfs.open_output_stream(str(source_dir / "file1")) as f: + f.write(b'test1') + with localfs.open_output_stream(str(source_dir / "file2")) as f: + f.write(b'test2') + + def check_copied_files(destination_dir): + with localfs.open_input_stream(str(destination_dir / "file1")) as f: + assert f.read() == b"test1" + with localfs.open_input_stream(str(destination_dir / "file2")) as f: + assert f.read() == b"test2" + + # Copy directory with local file paths + destination_dir1 = tempdir / "destination1" + # TODO need to create? + destination_dir1.mkdir() + copy_files(str(source_dir), str(destination_dir1)) + check_copied_files(destination_dir1) + + # Copy directory with path+filesystem + destination_dir2 = tempdir / "destination2" + destination_dir2.mkdir() + copy_files(str(source_dir), str(destination_dir2), + source_filesystem=localfs, destination_filesystem=localfs) + check_copied_files(destination_dir2) + + # Copy directory with URI + destination_dir3 = tempdir / "destination3" + destination_dir3.mkdir() + source_uri = _filesystem_uri(str(source_dir)) # file:// + destination_uri = _filesystem_uri(str(destination_dir3)) + copy_files(source_uri, destination_uri) + check_copied_files(destination_dir3) + + # Copy directory with Path objects + destination_dir4 = tempdir / "destination4" + destination_dir4.mkdir() + copy_files(source_dir, destination_dir4) + check_copied_files(destination_dir4) + + # copy with additional non-default options + destination_dir5 = tempdir / "destination5" + destination_dir5.mkdir() + copy_files(source_dir, destination_dir5, chunk_size=1, use_threads=False) + check_copied_files(destination_dir5) diff --git a/src/arrow/python/pyarrow/tests/test_gandiva.py b/src/arrow/python/pyarrow/tests/test_gandiva.py new file mode 100644 index 000000000..6522c233a --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_gandiva.py @@ -0,0 +1,391 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import pytest + +import pyarrow as pa + + +@pytest.mark.gandiva +def test_tree_exp_builder(): + import pyarrow.gandiva as gandiva + + builder = gandiva.TreeExprBuilder() + + field_a = pa.field('a', pa.int32()) + field_b = pa.field('b', pa.int32()) + + schema = pa.schema([field_a, field_b]) + + field_result = pa.field('res', pa.int32()) + + node_a = builder.make_field(field_a) + node_b = builder.make_field(field_b) + + assert node_a.return_type() == field_a.type + + condition = builder.make_function("greater_than", [node_a, node_b], + pa.bool_()) + if_node = builder.make_if(condition, node_a, node_b, pa.int32()) + + expr = builder.make_expression(if_node, field_result) + + assert expr.result().type == pa.int32() + + projector = gandiva.make_projector( + schema, [expr], pa.default_memory_pool()) + + # Gandiva generates compute kernel function named `@expr_X` + assert projector.llvm_ir.find("@expr_") != -1 + + a = pa.array([10, 12, -20, 5], type=pa.int32()) + b = pa.array([5, 15, 15, 17], type=pa.int32()) + e = pa.array([10, 15, 15, 17], type=pa.int32()) + input_batch = pa.RecordBatch.from_arrays([a, b], names=['a', 'b']) + + r, = projector.evaluate(input_batch) + assert r.equals(e) + + +@pytest.mark.gandiva +def test_table(): + import pyarrow.gandiva as gandiva + + table = pa.Table.from_arrays([pa.array([1.0, 2.0]), pa.array([3.0, 4.0])], + ['a', 'b']) + + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + node_b = builder.make_field(table.schema.field("b")) + + sum = builder.make_function("add", [node_a, node_b], pa.float64()) + + field_result = pa.field("c", pa.float64()) + expr = builder.make_expression(sum, field_result) + + projector = gandiva.make_projector( + table.schema, [expr], pa.default_memory_pool()) + + # TODO: Add .evaluate function which can take Tables instead of + # RecordBatches + r, = projector.evaluate(table.to_batches()[0]) + + e = pa.array([4.0, 6.0]) + assert r.equals(e) + + +@pytest.mark.gandiva +def test_filter(): + import pyarrow.gandiva as gandiva + + table = pa.Table.from_arrays([pa.array([1.0 * i for i in range(10000)])], + ['a']) + + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + thousand = builder.make_literal(1000.0, pa.float64()) + cond = builder.make_function("less_than", [node_a, thousand], pa.bool_()) + condition = builder.make_condition(cond) + + assert condition.result().type == pa.bool_() + + filter = gandiva.make_filter(table.schema, condition) + # Gandiva generates compute kernel function named `@expr_X` + assert filter.llvm_ir.find("@expr_") != -1 + + result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool()) + assert result.to_array().equals(pa.array(range(1000), type=pa.uint32())) + + +@pytest.mark.gandiva +def test_in_expr(): + import pyarrow.gandiva as gandiva + + arr = pa.array(["ga", "an", "nd", "di", "iv", "va"]) + table = pa.Table.from_arrays([arr], ["a"]) + + # string + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + cond = builder.make_in_expression(node_a, ["an", "nd"], pa.string()) + condition = builder.make_condition(cond) + filter = gandiva.make_filter(table.schema, condition) + result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool()) + assert result.to_array().equals(pa.array([1, 2], type=pa.uint32())) + + # int32 + arr = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4]) + table = pa.Table.from_arrays([arr.cast(pa.int32())], ["a"]) + node_a = builder.make_field(table.schema.field("a")) + cond = builder.make_in_expression(node_a, [1, 5], pa.int32()) + condition = builder.make_condition(cond) + filter = gandiva.make_filter(table.schema, condition) + result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool()) + assert result.to_array().equals(pa.array([1, 3, 4, 8], type=pa.uint32())) + + # int64 + arr = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4]) + table = pa.Table.from_arrays([arr], ["a"]) + node_a = builder.make_field(table.schema.field("a")) + cond = builder.make_in_expression(node_a, [1, 5], pa.int64()) + condition = builder.make_condition(cond) + filter = gandiva.make_filter(table.schema, condition) + result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool()) + assert result.to_array().equals(pa.array([1, 3, 4, 8], type=pa.uint32())) + + +@pytest.mark.skip(reason="Gandiva C++ did not have *real* binary, " + "time and date support.") +def test_in_expr_todo(): + import pyarrow.gandiva as gandiva + # TODO: Implement reasonable support for timestamp, time & date. + # Current exceptions: + # pyarrow.lib.ArrowException: ExpressionValidationError: + # Evaluation expression for IN clause returns XXXX values are of typeXXXX + + # binary + arr = pa.array([b"ga", b"an", b"nd", b"di", b"iv", b"va"]) + table = pa.Table.from_arrays([arr], ["a"]) + + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + cond = builder.make_in_expression(node_a, [b'an', b'nd'], pa.binary()) + condition = builder.make_condition(cond) + + filter = gandiva.make_filter(table.schema, condition) + result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool()) + assert result.to_array().equals(pa.array([1, 2], type=pa.uint32())) + + # timestamp + datetime_1 = datetime.datetime.utcfromtimestamp(1542238951.621877) + datetime_2 = datetime.datetime.utcfromtimestamp(1542238911.621877) + datetime_3 = datetime.datetime.utcfromtimestamp(1542238051.621877) + + arr = pa.array([datetime_1, datetime_2, datetime_3]) + table = pa.Table.from_arrays([arr], ["a"]) + + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + cond = builder.make_in_expression(node_a, [datetime_2], pa.timestamp('ms')) + condition = builder.make_condition(cond) + + filter = gandiva.make_filter(table.schema, condition) + result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool()) + assert list(result.to_array()) == [1] + + # time + time_1 = datetime_1.time() + time_2 = datetime_2.time() + time_3 = datetime_3.time() + + arr = pa.array([time_1, time_2, time_3]) + table = pa.Table.from_arrays([arr], ["a"]) + + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + cond = builder.make_in_expression(node_a, [time_2], pa.time64('ms')) + condition = builder.make_condition(cond) + + filter = gandiva.make_filter(table.schema, condition) + result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool()) + assert list(result.to_array()) == [1] + + # date + date_1 = datetime_1.date() + date_2 = datetime_2.date() + date_3 = datetime_3.date() + + arr = pa.array([date_1, date_2, date_3]) + table = pa.Table.from_arrays([arr], ["a"]) + + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + cond = builder.make_in_expression(node_a, [date_2], pa.date32()) + condition = builder.make_condition(cond) + + filter = gandiva.make_filter(table.schema, condition) + result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool()) + assert list(result.to_array()) == [1] + + +@pytest.mark.gandiva +def test_boolean(): + import pyarrow.gandiva as gandiva + + table = pa.Table.from_arrays([ + pa.array([1., 31., 46., 3., 57., 44., 22.]), + pa.array([5., 45., 36., 73., 83., 23., 76.])], + ['a', 'b']) + + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + node_b = builder.make_field(table.schema.field("b")) + fifty = builder.make_literal(50.0, pa.float64()) + eleven = builder.make_literal(11.0, pa.float64()) + + cond_1 = builder.make_function("less_than", [node_a, fifty], pa.bool_()) + cond_2 = builder.make_function("greater_than", [node_a, node_b], + pa.bool_()) + cond_3 = builder.make_function("less_than", [node_b, eleven], pa.bool_()) + cond = builder.make_or([builder.make_and([cond_1, cond_2]), cond_3]) + condition = builder.make_condition(cond) + + filter = gandiva.make_filter(table.schema, condition) + result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool()) + assert result.to_array().equals(pa.array([0, 2, 5], type=pa.uint32())) + + +@pytest.mark.gandiva +def test_literals(): + import pyarrow.gandiva as gandiva + + builder = gandiva.TreeExprBuilder() + + builder.make_literal(True, pa.bool_()) + builder.make_literal(0, pa.uint8()) + builder.make_literal(1, pa.uint16()) + builder.make_literal(2, pa.uint32()) + builder.make_literal(3, pa.uint64()) + builder.make_literal(4, pa.int8()) + builder.make_literal(5, pa.int16()) + builder.make_literal(6, pa.int32()) + builder.make_literal(7, pa.int64()) + builder.make_literal(8.0, pa.float32()) + builder.make_literal(9.0, pa.float64()) + builder.make_literal("hello", pa.string()) + builder.make_literal(b"world", pa.binary()) + + builder.make_literal(True, "bool") + builder.make_literal(0, "uint8") + builder.make_literal(1, "uint16") + builder.make_literal(2, "uint32") + builder.make_literal(3, "uint64") + builder.make_literal(4, "int8") + builder.make_literal(5, "int16") + builder.make_literal(6, "int32") + builder.make_literal(7, "int64") + builder.make_literal(8.0, "float32") + builder.make_literal(9.0, "float64") + builder.make_literal("hello", "string") + builder.make_literal(b"world", "binary") + + with pytest.raises(TypeError): + builder.make_literal("hello", pa.int64()) + with pytest.raises(TypeError): + builder.make_literal(True, None) + + +@pytest.mark.gandiva +def test_regex(): + import pyarrow.gandiva as gandiva + + elements = ["park", "sparkle", "bright spark and fire", "spark"] + data = pa.array(elements, type=pa.string()) + table = pa.Table.from_arrays([data], names=['a']) + + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + regex = builder.make_literal("%spark%", pa.string()) + like = builder.make_function("like", [node_a, regex], pa.bool_()) + + field_result = pa.field("b", pa.bool_()) + expr = builder.make_expression(like, field_result) + + projector = gandiva.make_projector( + table.schema, [expr], pa.default_memory_pool()) + + r, = projector.evaluate(table.to_batches()[0]) + b = pa.array([False, True, True, True], type=pa.bool_()) + assert r.equals(b) + + +@pytest.mark.gandiva +def test_get_registered_function_signatures(): + import pyarrow.gandiva as gandiva + signatures = gandiva.get_registered_function_signatures() + + assert type(signatures[0].return_type()) is pa.DataType + assert type(signatures[0].param_types()) is list + assert hasattr(signatures[0], "name") + + +@pytest.mark.gandiva +def test_filter_project(): + import pyarrow.gandiva as gandiva + mpool = pa.default_memory_pool() + # Create a table with some sample data + array0 = pa.array([10, 12, -20, 5, 21, 29], pa.int32()) + array1 = pa.array([5, 15, 15, 17, 12, 3], pa.int32()) + array2 = pa.array([1, 25, 11, 30, -21, None], pa.int32()) + + table = pa.Table.from_arrays([array0, array1, array2], ['a', 'b', 'c']) + + field_result = pa.field("res", pa.int32()) + + builder = gandiva.TreeExprBuilder() + node_a = builder.make_field(table.schema.field("a")) + node_b = builder.make_field(table.schema.field("b")) + node_c = builder.make_field(table.schema.field("c")) + + greater_than_function = builder.make_function("greater_than", + [node_a, node_b], pa.bool_()) + filter_condition = builder.make_condition( + greater_than_function) + + project_condition = builder.make_function("less_than", + [node_b, node_c], pa.bool_()) + if_node = builder.make_if(project_condition, + node_b, node_c, pa.int32()) + expr = builder.make_expression(if_node, field_result) + + # Build a filter for the expressions. + filter = gandiva.make_filter(table.schema, filter_condition) + + # Build a projector for the expressions. + projector = gandiva.make_projector( + table.schema, [expr], mpool, "UINT32") + + # Evaluate filter + selection_vector = filter.evaluate(table.to_batches()[0], mpool) + + # Evaluate project + r, = projector.evaluate( + table.to_batches()[0], selection_vector) + + exp = pa.array([1, -21, None], pa.int32()) + assert r.equals(exp) + + +@pytest.mark.gandiva +def test_to_string(): + import pyarrow.gandiva as gandiva + builder = gandiva.TreeExprBuilder() + + assert str(builder.make_literal(2.0, pa.float64()) + ).startswith('(const double) 2 raw(') + assert str(builder.make_literal(2, pa.int64())) == '(const int64) 2' + assert str(builder.make_field(pa.field('x', pa.float64()))) == '(double) x' + assert str(builder.make_field(pa.field('y', pa.string()))) == '(string) y' + + field_z = builder.make_field(pa.field('z', pa.bool_())) + func_node = builder.make_function('not', [field_z], pa.bool_()) + assert str(func_node) == 'bool not((bool) z)' + + field_y = builder.make_field(pa.field('y', pa.bool_())) + and_node = builder.make_and([func_node, field_y]) + assert str(and_node) == 'bool not((bool) z) && (bool) y' diff --git a/src/arrow/python/pyarrow/tests/test_hdfs.py b/src/arrow/python/pyarrow/tests/test_hdfs.py new file mode 100644 index 000000000..c71353b45 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_hdfs.py @@ -0,0 +1,447 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import pickle +import random +import unittest +from io import BytesIO +from os.path import join as pjoin + +import numpy as np +import pytest + +import pyarrow as pa +from pyarrow.pandas_compat import _pandas_api +from pyarrow.tests import util +from pyarrow.tests.parquet.common import _test_dataframe +from pyarrow.tests.parquet.test_dataset import ( + _test_read_common_metadata_files, _test_write_to_dataset_with_partitions, + _test_write_to_dataset_no_partitions +) +from pyarrow.util import guid + +# ---------------------------------------------------------------------- +# HDFS tests + + +def check_libhdfs_present(): + if not pa.have_libhdfs(): + message = 'No libhdfs available on system' + if os.environ.get('PYARROW_HDFS_TEST_LIBHDFS_REQUIRE'): + pytest.fail(message) + else: + pytest.skip(message) + + +def hdfs_test_client(): + host = os.environ.get('ARROW_HDFS_TEST_HOST', 'default') + user = os.environ.get('ARROW_HDFS_TEST_USER', None) + try: + port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0)) + except ValueError: + raise ValueError('Env variable ARROW_HDFS_TEST_PORT was not ' + 'an integer') + + with pytest.warns(FutureWarning): + return pa.hdfs.connect(host, port, user) + + +@pytest.mark.hdfs +class HdfsTestCases: + + def _make_test_file(self, hdfs, test_name, test_path, test_data): + base_path = pjoin(self.tmp_path, test_name) + hdfs.mkdir(base_path) + + full_path = pjoin(base_path, test_path) + + with hdfs.open(full_path, 'wb') as f: + f.write(test_data) + + return full_path + + @classmethod + def setUpClass(cls): + cls.check_driver() + cls.hdfs = hdfs_test_client() + cls.tmp_path = '/tmp/pyarrow-test-{}'.format(random.randint(0, 1000)) + cls.hdfs.mkdir(cls.tmp_path) + + @classmethod + def tearDownClass(cls): + cls.hdfs.delete(cls.tmp_path, recursive=True) + cls.hdfs.close() + + def test_pickle(self): + s = pickle.dumps(self.hdfs) + h2 = pickle.loads(s) + assert h2.is_open + assert h2.host == self.hdfs.host + assert h2.port == self.hdfs.port + assert h2.user == self.hdfs.user + assert h2.kerb_ticket == self.hdfs.kerb_ticket + # smoketest unpickled client works + h2.ls(self.tmp_path) + + def test_cat(self): + path = pjoin(self.tmp_path, 'cat-test') + + data = b'foobarbaz' + with self.hdfs.open(path, 'wb') as f: + f.write(data) + + contents = self.hdfs.cat(path) + assert contents == data + + def test_capacity_space(self): + capacity = self.hdfs.get_capacity() + space_used = self.hdfs.get_space_used() + disk_free = self.hdfs.df() + + assert capacity > 0 + assert capacity > space_used + assert disk_free == (capacity - space_used) + + def test_close(self): + client = hdfs_test_client() + assert client.is_open + client.close() + assert not client.is_open + + with pytest.raises(Exception): + client.ls('/') + + def test_mkdir(self): + path = pjoin(self.tmp_path, 'test-dir/test-dir') + parent_path = pjoin(self.tmp_path, 'test-dir') + + self.hdfs.mkdir(path) + assert self.hdfs.exists(path) + + self.hdfs.delete(parent_path, recursive=True) + assert not self.hdfs.exists(path) + + def test_mv_rename(self): + path = pjoin(self.tmp_path, 'mv-test') + new_path = pjoin(self.tmp_path, 'mv-new-test') + + data = b'foobarbaz' + with self.hdfs.open(path, 'wb') as f: + f.write(data) + + assert self.hdfs.exists(path) + self.hdfs.mv(path, new_path) + assert not self.hdfs.exists(path) + assert self.hdfs.exists(new_path) + + assert self.hdfs.cat(new_path) == data + + self.hdfs.rename(new_path, path) + assert self.hdfs.cat(path) == data + + def test_info(self): + path = pjoin(self.tmp_path, 'info-base') + file_path = pjoin(path, 'ex') + self.hdfs.mkdir(path) + + data = b'foobarbaz' + with self.hdfs.open(file_path, 'wb') as f: + f.write(data) + + path_info = self.hdfs.info(path) + file_path_info = self.hdfs.info(file_path) + + assert path_info['kind'] == 'directory' + + assert file_path_info['kind'] == 'file' + assert file_path_info['size'] == len(data) + + def test_exists_isdir_isfile(self): + dir_path = pjoin(self.tmp_path, 'info-base') + file_path = pjoin(dir_path, 'ex') + missing_path = pjoin(dir_path, 'this-path-is-missing') + + self.hdfs.mkdir(dir_path) + with self.hdfs.open(file_path, 'wb') as f: + f.write(b'foobarbaz') + + assert self.hdfs.exists(dir_path) + assert self.hdfs.exists(file_path) + assert not self.hdfs.exists(missing_path) + + assert self.hdfs.isdir(dir_path) + assert not self.hdfs.isdir(file_path) + assert not self.hdfs.isdir(missing_path) + + assert not self.hdfs.isfile(dir_path) + assert self.hdfs.isfile(file_path) + assert not self.hdfs.isfile(missing_path) + + def test_disk_usage(self): + path = pjoin(self.tmp_path, 'disk-usage-base') + p1 = pjoin(path, 'p1') + p2 = pjoin(path, 'p2') + + subdir = pjoin(path, 'subdir') + p3 = pjoin(subdir, 'p3') + + if self.hdfs.exists(path): + self.hdfs.delete(path, True) + + self.hdfs.mkdir(path) + self.hdfs.mkdir(subdir) + + data = b'foobarbaz' + + for file_path in [p1, p2, p3]: + with self.hdfs.open(file_path, 'wb') as f: + f.write(data) + + assert self.hdfs.disk_usage(path) == len(data) * 3 + + def test_ls(self): + base_path = pjoin(self.tmp_path, 'ls-test') + self.hdfs.mkdir(base_path) + + dir_path = pjoin(base_path, 'a-dir') + f1_path = pjoin(base_path, 'a-file-1') + + self.hdfs.mkdir(dir_path) + + f = self.hdfs.open(f1_path, 'wb') + f.write(b'a' * 10) + + contents = sorted(self.hdfs.ls(base_path, False)) + assert contents == [dir_path, f1_path] + + def test_chmod_chown(self): + path = pjoin(self.tmp_path, 'chmod-test') + with self.hdfs.open(path, 'wb') as f: + f.write(b'a' * 10) + + def test_download_upload(self): + base_path = pjoin(self.tmp_path, 'upload-test') + + data = b'foobarbaz' + buf = BytesIO(data) + buf.seek(0) + + self.hdfs.upload(base_path, buf) + + out_buf = BytesIO() + self.hdfs.download(base_path, out_buf) + out_buf.seek(0) + assert out_buf.getvalue() == data + + def test_file_context_manager(self): + path = pjoin(self.tmp_path, 'ctx-manager') + + data = b'foo' + with self.hdfs.open(path, 'wb') as f: + f.write(data) + + with self.hdfs.open(path, 'rb') as f: + assert f.size() == 3 + result = f.read(10) + assert result == data + + def test_open_not_exist(self): + path = pjoin(self.tmp_path, 'does-not-exist-123') + + with pytest.raises(FileNotFoundError): + self.hdfs.open(path) + + def test_open_write_error(self): + with pytest.raises((FileExistsError, IsADirectoryError)): + self.hdfs.open('/', 'wb') + + def test_read_whole_file(self): + path = pjoin(self.tmp_path, 'read-whole-file') + + data = b'foo' * 1000 + with self.hdfs.open(path, 'wb') as f: + f.write(data) + + with self.hdfs.open(path, 'rb') as f: + result = f.read() + + assert result == data + + def _write_multiple_hdfs_pq_files(self, tmpdir): + import pyarrow.parquet as pq + nfiles = 10 + size = 5 + test_data = [] + for i in range(nfiles): + df = _test_dataframe(size, seed=i) + + df['index'] = np.arange(i * size, (i + 1) * size) + + # Hack so that we don't have a dtype cast in v1 files + df['uint32'] = df['uint32'].astype(np.int64) + + path = pjoin(tmpdir, '{}.parquet'.format(i)) + + table = pa.Table.from_pandas(df, preserve_index=False) + with self.hdfs.open(path, 'wb') as f: + pq.write_table(table, f) + + test_data.append(table) + + expected = pa.concat_tables(test_data) + return expected + + @pytest.mark.pandas + @pytest.mark.parquet + def test_read_multiple_parquet_files(self): + + tmpdir = pjoin(self.tmp_path, 'multi-parquet-' + guid()) + + self.hdfs.mkdir(tmpdir) + + expected = self._write_multiple_hdfs_pq_files(tmpdir) + result = self.hdfs.read_parquet(tmpdir) + + _pandas_api.assert_frame_equal(result.to_pandas() + .sort_values(by='index') + .reset_index(drop=True), + expected.to_pandas()) + + @pytest.mark.pandas + @pytest.mark.parquet + def test_read_multiple_parquet_files_with_uri(self): + import pyarrow.parquet as pq + + tmpdir = pjoin(self.tmp_path, 'multi-parquet-uri-' + guid()) + + self.hdfs.mkdir(tmpdir) + + expected = self._write_multiple_hdfs_pq_files(tmpdir) + path = _get_hdfs_uri(tmpdir) + result = pq.read_table(path) + + _pandas_api.assert_frame_equal(result.to_pandas() + .sort_values(by='index') + .reset_index(drop=True), + expected.to_pandas()) + + @pytest.mark.pandas + @pytest.mark.parquet + def test_read_write_parquet_files_with_uri(self): + import pyarrow.parquet as pq + + tmpdir = pjoin(self.tmp_path, 'uri-parquet-' + guid()) + self.hdfs.mkdir(tmpdir) + path = _get_hdfs_uri(pjoin(tmpdir, 'test.parquet')) + + size = 5 + df = _test_dataframe(size, seed=0) + # Hack so that we don't have a dtype cast in v1 files + df['uint32'] = df['uint32'].astype(np.int64) + table = pa.Table.from_pandas(df, preserve_index=False) + + pq.write_table(table, path, filesystem=self.hdfs) + + result = pq.read_table( + path, filesystem=self.hdfs, use_legacy_dataset=True + ).to_pandas() + + _pandas_api.assert_frame_equal(result, df) + + @pytest.mark.parquet + @pytest.mark.pandas + def test_read_common_metadata_files(self): + tmpdir = pjoin(self.tmp_path, 'common-metadata-' + guid()) + self.hdfs.mkdir(tmpdir) + _test_read_common_metadata_files(self.hdfs, tmpdir) + + @pytest.mark.parquet + @pytest.mark.pandas + def test_write_to_dataset_with_partitions(self): + tmpdir = pjoin(self.tmp_path, 'write-partitions-' + guid()) + self.hdfs.mkdir(tmpdir) + _test_write_to_dataset_with_partitions( + tmpdir, filesystem=self.hdfs) + + @pytest.mark.parquet + @pytest.mark.pandas + def test_write_to_dataset_no_partitions(self): + tmpdir = pjoin(self.tmp_path, 'write-no_partitions-' + guid()) + self.hdfs.mkdir(tmpdir) + _test_write_to_dataset_no_partitions( + tmpdir, filesystem=self.hdfs) + + +class TestLibHdfs(HdfsTestCases, unittest.TestCase): + + @classmethod + def check_driver(cls): + check_libhdfs_present() + + def test_orphaned_file(self): + hdfs = hdfs_test_client() + file_path = self._make_test_file(hdfs, 'orphaned_file_test', 'fname', + b'foobarbaz') + + f = hdfs.open(file_path) + hdfs = None + f = None # noqa + + +def _get_hdfs_uri(path): + host = os.environ.get('ARROW_HDFS_TEST_HOST', 'localhost') + try: + port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 0)) + except ValueError: + raise ValueError('Env variable ARROW_HDFS_TEST_PORT was not ' + 'an integer') + uri = "hdfs://{}:{}{}".format(host, port, path) + + return uri + + +@pytest.mark.hdfs +@pytest.mark.pandas +@pytest.mark.parquet +@pytest.mark.fastparquet +def test_fastparquet_read_with_hdfs(): + from pandas.testing import assert_frame_equal + + check_libhdfs_present() + try: + import snappy # noqa + except ImportError: + pytest.skip('fastparquet test requires snappy') + + import pyarrow.parquet as pq + fastparquet = pytest.importorskip('fastparquet') + + fs = hdfs_test_client() + + df = util.make_dataframe() + + table = pa.Table.from_pandas(df) + + path = '/tmp/testing.parquet' + with fs.open(path, 'wb') as f: + pq.write_table(table, f) + + parquet_file = fastparquet.ParquetFile(path, open_with=fs.open) + + result = parquet_file.to_pandas() + assert_frame_equal(result, df) diff --git a/src/arrow/python/pyarrow/tests/test_io.py b/src/arrow/python/pyarrow/tests/test_io.py new file mode 100644 index 000000000..ea1e5e557 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_io.py @@ -0,0 +1,1886 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import bz2 +from contextlib import contextmanager +from io import (BytesIO, StringIO, TextIOWrapper, BufferedIOBase, IOBase) +import itertools +import gc +import gzip +import os +import pathlib +import pickle +import pytest +import sys +import tempfile +import weakref + +import numpy as np + +from pyarrow.util import guid +from pyarrow import Codec +import pyarrow as pa + + +def check_large_seeks(file_factory): + if sys.platform in ('win32', 'darwin'): + pytest.skip("need sparse file support") + try: + filename = tempfile.mktemp(prefix='test_io') + with open(filename, 'wb') as f: + f.truncate(2 ** 32 + 10) + f.seek(2 ** 32 + 5) + f.write(b'mark\n') + with file_factory(filename) as f: + assert f.seek(2 ** 32 + 5) == 2 ** 32 + 5 + assert f.tell() == 2 ** 32 + 5 + assert f.read(5) == b'mark\n' + assert f.tell() == 2 ** 32 + 10 + finally: + os.unlink(filename) + + +@contextmanager +def assert_file_not_found(): + with pytest.raises(FileNotFoundError): + yield + + +# ---------------------------------------------------------------------- +# Python file-like objects + + +def test_python_file_write(): + buf = BytesIO() + + f = pa.PythonFile(buf) + + assert f.tell() == 0 + + s1 = b'enga\xc3\xb1ado' + s2 = b'foobar' + + f.write(s1) + assert f.tell() == len(s1) + + f.write(s2) + + expected = s1 + s2 + + result = buf.getvalue() + assert result == expected + + assert not f.closed + f.close() + assert f.closed + + with pytest.raises(TypeError, match="binary file expected"): + pa.PythonFile(StringIO()) + + +def test_python_file_read(): + data = b'some sample data' + + buf = BytesIO(data) + f = pa.PythonFile(buf, mode='r') + + assert f.size() == len(data) + + assert f.tell() == 0 + + assert f.read(4) == b'some' + assert f.tell() == 4 + + f.seek(0) + assert f.tell() == 0 + + f.seek(5) + assert f.tell() == 5 + + v = f.read(50) + assert v == b'sample data' + assert len(v) == 11 + + assert f.size() == len(data) + + assert not f.closed + f.close() + assert f.closed + + with pytest.raises(TypeError, match="binary file expected"): + pa.PythonFile(StringIO(), mode='r') + + +def test_python_file_read_at(): + data = b'some sample data' + + buf = BytesIO(data) + f = pa.PythonFile(buf, mode='r') + + # test simple read at + v = f.read_at(nbytes=5, offset=3) + assert v == b'e sam' + assert len(v) == 5 + + # test reading entire file when nbytes > len(file) + w = f.read_at(nbytes=50, offset=0) + assert w == data + assert len(w) == 16 + + +def test_python_file_readall(): + data = b'some sample data' + + buf = BytesIO(data) + with pa.PythonFile(buf, mode='r') as f: + assert f.readall() == data + + +def test_python_file_readinto(): + length = 10 + data = b'some sample data longer than 10' + dst_buf = bytearray(length) + src_buf = BytesIO(data) + + with pa.PythonFile(src_buf, mode='r') as f: + assert f.readinto(dst_buf) == 10 + + assert dst_buf[:length] == data[:length] + assert len(dst_buf) == length + + +def test_python_file_read_buffer(): + length = 10 + data = b'0123456798' + dst_buf = bytearray(data) + + class DuckReader: + def close(self): + pass + + @property + def closed(self): + return False + + def read_buffer(self, nbytes): + assert nbytes == length + return memoryview(dst_buf)[:nbytes] + + duck_reader = DuckReader() + with pa.PythonFile(duck_reader, mode='r') as f: + buf = f.read_buffer(length) + assert len(buf) == length + assert memoryview(buf).tobytes() == dst_buf[:length] + # buf should point to the same memory, so modyfing it + memoryview(buf)[0] = ord(b'x') + # should modify the original + assert dst_buf[0] == ord(b'x') + + +def test_python_file_correct_abc(): + with pa.PythonFile(BytesIO(b''), mode='r') as f: + assert isinstance(f, BufferedIOBase) + assert isinstance(f, IOBase) + + +def test_python_file_iterable(): + data = b'''line1 + line2 + line3 + ''' + + buf = BytesIO(data) + buf2 = BytesIO(data) + + with pa.PythonFile(buf, mode='r') as f: + for read, expected in zip(f, buf2): + assert read == expected + + +def test_python_file_large_seeks(): + def factory(filename): + return pa.PythonFile(open(filename, 'rb')) + + check_large_seeks(factory) + + +def test_bytes_reader(): + # Like a BytesIO, but zero-copy underneath for C++ consumers + data = b'some sample data' + f = pa.BufferReader(data) + assert f.tell() == 0 + + assert f.size() == len(data) + + assert f.read(4) == b'some' + assert f.tell() == 4 + + f.seek(0) + assert f.tell() == 0 + + f.seek(0, 2) + assert f.tell() == len(data) + + f.seek(5) + assert f.tell() == 5 + + assert f.read(50) == b'sample data' + + assert not f.closed + f.close() + assert f.closed + + +def test_bytes_reader_non_bytes(): + with pytest.raises(TypeError): + pa.BufferReader('some sample data') + + +def test_bytes_reader_retains_parent_reference(): + import gc + + # ARROW-421 + def get_buffer(): + data = b'some sample data' * 1000 + reader = pa.BufferReader(data) + reader.seek(5) + return reader.read_buffer(6) + + buf = get_buffer() + gc.collect() + assert buf.to_pybytes() == b'sample' + assert buf.parent is not None + + +def test_python_file_implicit_mode(tmpdir): + path = os.path.join(str(tmpdir), 'foo.txt') + with open(path, 'wb') as f: + pf = pa.PythonFile(f) + assert pf.writable() + assert not pf.readable() + assert not pf.seekable() # PyOutputStream isn't seekable + f.write(b'foobar\n') + + with open(path, 'rb') as f: + pf = pa.PythonFile(f) + assert pf.readable() + assert not pf.writable() + assert pf.seekable() + assert pf.read() == b'foobar\n' + + bio = BytesIO() + pf = pa.PythonFile(bio) + assert pf.writable() + assert not pf.readable() + assert not pf.seekable() + pf.write(b'foobar\n') + assert bio.getvalue() == b'foobar\n' + + +def test_python_file_writelines(tmpdir): + lines = [b'line1\n', b'line2\n' b'line3'] + path = os.path.join(str(tmpdir), 'foo.txt') + with open(path, 'wb') as f: + try: + f = pa.PythonFile(f, mode='w') + assert f.writable() + f.writelines(lines) + finally: + f.close() + + with open(path, 'rb') as f: + try: + f = pa.PythonFile(f, mode='r') + assert f.readable() + assert f.read() == b''.join(lines) + finally: + f.close() + + +def test_python_file_closing(): + bio = BytesIO() + pf = pa.PythonFile(bio) + wr = weakref.ref(pf) + del pf + assert wr() is None # object was destroyed + assert not bio.closed + pf = pa.PythonFile(bio) + pf.close() + assert bio.closed + + +# ---------------------------------------------------------------------- +# Buffers + + +def test_buffer_bytes(): + val = b'some data' + + buf = pa.py_buffer(val) + assert isinstance(buf, pa.Buffer) + assert not buf.is_mutable + assert buf.is_cpu + + result = buf.to_pybytes() + + assert result == val + + # Check that buffers survive a pickle roundtrip + result_buf = pickle.loads(pickle.dumps(buf)) + result = result_buf.to_pybytes() + assert result == val + + +def test_buffer_memoryview(): + val = b'some data' + + buf = pa.py_buffer(val) + assert isinstance(buf, pa.Buffer) + assert not buf.is_mutable + assert buf.is_cpu + + result = memoryview(buf) + + assert result == val + + +def test_buffer_bytearray(): + val = bytearray(b'some data') + + buf = pa.py_buffer(val) + assert isinstance(buf, pa.Buffer) + assert buf.is_mutable + assert buf.is_cpu + + result = bytearray(buf) + + assert result == val + + +def test_buffer_invalid(): + with pytest.raises(TypeError, + match="(bytes-like object|buffer interface)"): + pa.py_buffer(None) + + +def test_buffer_weakref(): + buf = pa.py_buffer(b'some data') + wr = weakref.ref(buf) + assert wr() is not None + del buf + assert wr() is None + + +@pytest.mark.parametrize('val, expected_hex_buffer', + [(b'check', b'636865636B'), + (b'\a0', b'0730'), + (b'', b'')]) +def test_buffer_hex(val, expected_hex_buffer): + buf = pa.py_buffer(val) + assert buf.hex() == expected_hex_buffer + + +def test_buffer_to_numpy(): + # Make sure creating a numpy array from an arrow buffer works + byte_array = bytearray(20) + byte_array[0] = 42 + buf = pa.py_buffer(byte_array) + array = np.frombuffer(buf, dtype="uint8") + assert array[0] == byte_array[0] + byte_array[0] += 1 + assert array[0] == byte_array[0] + assert array.base == buf + + +def test_buffer_from_numpy(): + # C-contiguous + arr = np.arange(12, dtype=np.int8).reshape((3, 4)) + buf = pa.py_buffer(arr) + assert buf.is_cpu + assert buf.is_mutable + assert buf.to_pybytes() == arr.tobytes() + # F-contiguous; note strides information is lost + buf = pa.py_buffer(arr.T) + assert buf.is_cpu + assert buf.is_mutable + assert buf.to_pybytes() == arr.tobytes() + # Non-contiguous + with pytest.raises(ValueError, match="not contiguous"): + buf = pa.py_buffer(arr.T[::2]) + + +def test_buffer_address(): + b1 = b'some data!' + b2 = bytearray(b1) + b3 = bytearray(b1) + + buf1 = pa.py_buffer(b1) + buf2 = pa.py_buffer(b1) + buf3 = pa.py_buffer(b2) + buf4 = pa.py_buffer(b3) + + assert buf1.address > 0 + assert buf1.address == buf2.address + assert buf3.address != buf2.address + assert buf4.address != buf3.address + + arr = np.arange(5) + buf = pa.py_buffer(arr) + assert buf.address == arr.ctypes.data + + +def test_buffer_equals(): + # Buffer.equals() returns true iff the buffers have the same contents + def eq(a, b): + assert a.equals(b) + assert a == b + assert not (a != b) + + def ne(a, b): + assert not a.equals(b) + assert not (a == b) + assert a != b + + b1 = b'some data!' + b2 = bytearray(b1) + b3 = bytearray(b1) + b3[0] = 42 + buf1 = pa.py_buffer(b1) + buf2 = pa.py_buffer(b2) + buf3 = pa.py_buffer(b2) + buf4 = pa.py_buffer(b3) + buf5 = pa.py_buffer(np.frombuffer(b2, dtype=np.int16)) + eq(buf1, buf1) + eq(buf1, buf2) + eq(buf2, buf3) + ne(buf2, buf4) + # Data type is indifferent + eq(buf2, buf5) + + +def test_buffer_eq_bytes(): + buf = pa.py_buffer(b'some data') + assert buf == b'some data' + assert buf == bytearray(b'some data') + assert buf != b'some dat1' + + with pytest.raises(TypeError): + buf == 'some data' + + +def test_buffer_getitem(): + data = bytearray(b'some data!') + buf = pa.py_buffer(data) + + n = len(data) + for ix in range(-n, n - 1): + assert buf[ix] == data[ix] + + with pytest.raises(IndexError): + buf[n] + + with pytest.raises(IndexError): + buf[-n - 1] + + +def test_buffer_slicing(): + data = b'some data!' + buf = pa.py_buffer(data) + + sliced = buf.slice(2) + expected = pa.py_buffer(b'me data!') + assert sliced.equals(expected) + + sliced2 = buf.slice(2, 4) + expected2 = pa.py_buffer(b'me d') + assert sliced2.equals(expected2) + + # 0 offset + assert buf.slice(0).equals(buf) + + # Slice past end of buffer + assert len(buf.slice(len(buf))) == 0 + + with pytest.raises(IndexError): + buf.slice(-1) + + # Test slice notation + assert buf[2:].equals(buf.slice(2)) + assert buf[2:5].equals(buf.slice(2, 3)) + assert buf[-5:].equals(buf.slice(len(buf) - 5)) + with pytest.raises(IndexError): + buf[::-1] + with pytest.raises(IndexError): + buf[::2] + + n = len(buf) + for start in range(-n * 2, n * 2): + for stop in range(-n * 2, n * 2): + assert buf[start:stop].to_pybytes() == buf.to_pybytes()[start:stop] + + +def test_buffer_hashing(): + # Buffers are unhashable + with pytest.raises(TypeError, match="unhashable"): + hash(pa.py_buffer(b'123')) + + +def test_buffer_protocol_respects_immutability(): + # ARROW-3228; NumPy's frombuffer ctor determines whether a buffer-like + # object is mutable by first attempting to get a mutable buffer using + # PyObject_FromBuffer. If that fails, it assumes that the object is + # immutable + a = b'12345' + arrow_ref = pa.py_buffer(a) + numpy_ref = np.frombuffer(arrow_ref, dtype=np.uint8) + assert not numpy_ref.flags.writeable + + +def test_foreign_buffer(): + obj = np.array([1, 2], dtype=np.int32) + addr = obj.__array_interface__["data"][0] + size = obj.nbytes + buf = pa.foreign_buffer(addr, size, obj) + wr = weakref.ref(obj) + del obj + assert np.frombuffer(buf, dtype=np.int32).tolist() == [1, 2] + assert wr() is not None + del buf + assert wr() is None + + +def test_allocate_buffer(): + buf = pa.allocate_buffer(100) + assert buf.size == 100 + assert buf.is_mutable + assert buf.parent is None + + bit = b'abcde' + writer = pa.FixedSizeBufferWriter(buf) + writer.write(bit) + + assert buf.to_pybytes()[:5] == bit + + +def test_allocate_buffer_resizable(): + buf = pa.allocate_buffer(100, resizable=True) + assert isinstance(buf, pa.ResizableBuffer) + + buf.resize(200) + assert buf.size == 200 + + +@pytest.mark.parametrize("compression", [ + pytest.param( + "bz2", marks=pytest.mark.xfail(raises=pa.lib.ArrowNotImplementedError) + ), + "brotli", + "gzip", + "lz4", + "zstd", + "snappy" +]) +def test_compress_decompress(compression): + if not Codec.is_available(compression): + pytest.skip("{} support is not built".format(compression)) + + INPUT_SIZE = 10000 + test_data = (np.random.randint(0, 255, size=INPUT_SIZE) + .astype(np.uint8) + .tobytes()) + test_buf = pa.py_buffer(test_data) + + compressed_buf = pa.compress(test_buf, codec=compression) + compressed_bytes = pa.compress(test_data, codec=compression, + asbytes=True) + + assert isinstance(compressed_bytes, bytes) + + decompressed_buf = pa.decompress(compressed_buf, INPUT_SIZE, + codec=compression) + decompressed_bytes = pa.decompress(compressed_bytes, INPUT_SIZE, + codec=compression, asbytes=True) + + assert isinstance(decompressed_bytes, bytes) + + assert decompressed_buf.equals(test_buf) + assert decompressed_bytes == test_data + + with pytest.raises(ValueError): + pa.decompress(compressed_bytes, codec=compression) + + +@pytest.mark.parametrize("compression", [ + pytest.param( + "bz2", marks=pytest.mark.xfail(raises=pa.lib.ArrowNotImplementedError) + ), + "brotli", + "gzip", + "lz4", + "zstd", + "snappy" +]) +def test_compression_level(compression): + if not Codec.is_available(compression): + pytest.skip("{} support is not built".format(compression)) + + # These codecs do not support a compression level + no_level = ['snappy', 'lz4'] + if compression in no_level: + assert not Codec.supports_compression_level(compression) + with pytest.raises(ValueError): + Codec(compression, 0) + with pytest.raises(ValueError): + Codec.minimum_compression_level(compression) + with pytest.raises(ValueError): + Codec.maximum_compression_level(compression) + with pytest.raises(ValueError): + Codec.default_compression_level(compression) + return + + INPUT_SIZE = 10000 + test_data = (np.random.randint(0, 255, size=INPUT_SIZE) + .astype(np.uint8) + .tobytes()) + test_buf = pa.py_buffer(test_data) + + min_level = Codec.minimum_compression_level(compression) + max_level = Codec.maximum_compression_level(compression) + default_level = Codec.default_compression_level(compression) + + assert min_level < max_level + assert default_level >= min_level + assert default_level <= max_level + + for compression_level in range(min_level, max_level+1): + codec = Codec(compression, compression_level) + compressed_buf = codec.compress(test_buf) + compressed_bytes = codec.compress(test_data, asbytes=True) + assert isinstance(compressed_bytes, bytes) + decompressed_buf = codec.decompress(compressed_buf, INPUT_SIZE) + decompressed_bytes = codec.decompress(compressed_bytes, INPUT_SIZE, + asbytes=True) + + assert isinstance(decompressed_bytes, bytes) + + assert decompressed_buf.equals(test_buf) + assert decompressed_bytes == test_data + + with pytest.raises(ValueError): + codec.decompress(compressed_bytes) + + # The ability to set a seed this way is not present on older versions of + # numpy (currently in our python 3.6 CI build). Some inputs might just + # happen to compress the same between the two levels so using seeded + # random numbers is neccesary to help get more reliable results + # + # The goal of this part is to ensure the compression_level is being + # passed down to the C++ layer, not to verify the compression algs + # themselves + if not hasattr(np.random, 'default_rng'): + pytest.skip('Requires newer version of numpy') + rng = np.random.default_rng(seed=42) + values = rng.integers(0, 100, 1000) + arr = pa.array(values) + hard_to_compress_buffer = arr.buffers()[1] + + weak_codec = Codec(compression, min_level) + weakly_compressed_buf = weak_codec.compress(hard_to_compress_buffer) + + strong_codec = Codec(compression, max_level) + strongly_compressed_buf = strong_codec.compress(hard_to_compress_buffer) + + assert len(weakly_compressed_buf) > len(strongly_compressed_buf) + + +def test_buffer_memoryview_is_immutable(): + val = b'some data' + + buf = pa.py_buffer(val) + assert not buf.is_mutable + assert isinstance(buf, pa.Buffer) + + result = memoryview(buf) + assert result.readonly + + with pytest.raises(TypeError) as exc: + result[0] = b'h' + assert 'cannot modify read-only' in str(exc.value) + + b = bytes(buf) + with pytest.raises(TypeError) as exc: + b[0] = b'h' + assert 'cannot modify read-only' in str(exc.value) + + +def test_uninitialized_buffer(): + # ARROW-2039: calling Buffer() directly creates an uninitialized object + # ARROW-2638: prevent calling extension class constructors directly + with pytest.raises(TypeError): + pa.Buffer() + + +def test_memory_output_stream(): + # 10 bytes + val = b'dataabcdef' + f = pa.BufferOutputStream() + + K = 1000 + for i in range(K): + f.write(val) + + buf = f.getvalue() + assert len(buf) == len(val) * K + assert buf.to_pybytes() == val * K + + +def test_inmemory_write_after_closed(): + f = pa.BufferOutputStream() + f.write(b'ok') + assert not f.closed + f.getvalue() + assert f.closed + + with pytest.raises(ValueError): + f.write(b'not ok') + + +def test_buffer_protocol_ref_counting(): + def make_buffer(bytes_obj): + return bytearray(pa.py_buffer(bytes_obj)) + + buf = make_buffer(b'foo') + gc.collect() + assert buf == b'foo' + + # ARROW-1053 + val = b'foo' + refcount_before = sys.getrefcount(val) + for i in range(10): + make_buffer(val) + gc.collect() + assert refcount_before == sys.getrefcount(val) + + +def test_nativefile_write_memoryview(): + f = pa.BufferOutputStream() + data = b'ok' + + arr = np.frombuffer(data, dtype='S1') + + f.write(arr) + f.write(bytearray(data)) + f.write(pa.py_buffer(data)) + with pytest.raises(TypeError): + f.write(data.decode('utf8')) + + buf = f.getvalue() + + assert buf.to_pybytes() == data * 3 + + +# ---------------------------------------------------------------------- +# Mock output stream + + +def test_mock_output_stream(): + # Make sure that the MockOutputStream and the BufferOutputStream record the + # same size + + # 10 bytes + val = b'dataabcdef' + + f1 = pa.MockOutputStream() + f2 = pa.BufferOutputStream() + + K = 1000 + for i in range(K): + f1.write(val) + f2.write(val) + + assert f1.size() == len(f2.getvalue()) + + # Do the same test with a table + record_batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ['a']) + + f1 = pa.MockOutputStream() + f2 = pa.BufferOutputStream() + + stream_writer1 = pa.RecordBatchStreamWriter(f1, record_batch.schema) + stream_writer2 = pa.RecordBatchStreamWriter(f2, record_batch.schema) + + stream_writer1.write_batch(record_batch) + stream_writer2.write_batch(record_batch) + stream_writer1.close() + stream_writer2.close() + + assert f1.size() == len(f2.getvalue()) + + +# ---------------------------------------------------------------------- +# OS files and memory maps + + +@pytest.fixture +def sample_disk_data(request, tmpdir): + SIZE = 4096 + arr = np.random.randint(0, 256, size=SIZE).astype('u1') + data = arr.tobytes()[:SIZE] + + path = os.path.join(str(tmpdir), guid()) + + with open(path, 'wb') as f: + f.write(data) + + def teardown(): + _try_delete(path) + + request.addfinalizer(teardown) + return path, data + + +def _check_native_file_reader(FACTORY, sample_data, + allow_read_out_of_bounds=True): + path, data = sample_data + + f = FACTORY(path, mode='r') + + assert f.read(10) == data[:10] + assert f.read(0) == b'' + assert f.tell() == 10 + + assert f.read() == data[10:] + + assert f.size() == len(data) + + f.seek(0) + assert f.tell() == 0 + + # Seeking past end of file not supported in memory maps + if allow_read_out_of_bounds: + f.seek(len(data) + 1) + assert f.tell() == len(data) + 1 + assert f.read(5) == b'' + + # Test whence argument of seek, ARROW-1287 + assert f.seek(3) == 3 + assert f.seek(3, os.SEEK_CUR) == 6 + assert f.tell() == 6 + + ex_length = len(data) - 2 + assert f.seek(-2, os.SEEK_END) == ex_length + assert f.tell() == ex_length + + +def test_memory_map_reader(sample_disk_data): + _check_native_file_reader(pa.memory_map, sample_disk_data, + allow_read_out_of_bounds=False) + + +def test_memory_map_retain_buffer_reference(sample_disk_data): + path, data = sample_disk_data + + cases = [] + with pa.memory_map(path, 'rb') as f: + cases.append((f.read_buffer(100), data[:100])) + cases.append((f.read_buffer(100), data[100:200])) + cases.append((f.read_buffer(100), data[200:300])) + + # Call gc.collect() for good measure + gc.collect() + + for buf, expected in cases: + assert buf.to_pybytes() == expected + + +def test_os_file_reader(sample_disk_data): + _check_native_file_reader(pa.OSFile, sample_disk_data) + + +def test_os_file_large_seeks(): + check_large_seeks(pa.OSFile) + + +def _try_delete(path): + try: + os.remove(path) + except os.error: + pass + + +def test_memory_map_writer(tmpdir): + SIZE = 4096 + arr = np.random.randint(0, 256, size=SIZE).astype('u1') + data = arr.tobytes()[:SIZE] + + path = os.path.join(str(tmpdir), guid()) + with open(path, 'wb') as f: + f.write(data) + + f = pa.memory_map(path, mode='r+b') + + f.seek(10) + f.write(b'peekaboo') + assert f.tell() == 18 + + f.seek(10) + assert f.read(8) == b'peekaboo' + + f2 = pa.memory_map(path, mode='r+b') + + f2.seek(10) + f2.write(b'booapeak') + f2.seek(10) + + f.seek(10) + assert f.read(8) == b'booapeak' + + # Does not truncate file + f3 = pa.memory_map(path, mode='w') + f3.write(b'foo') + + with pa.memory_map(path) as f4: + assert f4.size() == SIZE + + with pytest.raises(IOError): + f3.read(5) + + f.seek(0) + assert f.read(3) == b'foo' + + +def test_memory_map_resize(tmpdir): + SIZE = 4096 + arr = np.random.randint(0, 256, size=SIZE).astype(np.uint8) + data1 = arr.tobytes()[:(SIZE // 2)] + data2 = arr.tobytes()[(SIZE // 2):] + + path = os.path.join(str(tmpdir), guid()) + + mmap = pa.create_memory_map(path, SIZE / 2) + mmap.write(data1) + + mmap.resize(SIZE) + mmap.write(data2) + + mmap.close() + + with open(path, 'rb') as f: + assert f.read() == arr.tobytes() + + +def test_memory_zero_length(tmpdir): + path = os.path.join(str(tmpdir), guid()) + f = open(path, 'wb') + f.close() + with pa.memory_map(path, mode='r+b') as memory_map: + assert memory_map.size() == 0 + + +def test_memory_map_large_seeks(): + check_large_seeks(pa.memory_map) + + +def test_memory_map_close_remove(tmpdir): + # ARROW-6740: should be able to delete closed memory-mapped file (Windows) + path = os.path.join(str(tmpdir), guid()) + mmap = pa.create_memory_map(path, 4096) + mmap.close() + assert mmap.closed + os.remove(path) # Shouldn't fail + + +def test_memory_map_deref_remove(tmpdir): + path = os.path.join(str(tmpdir), guid()) + pa.create_memory_map(path, 4096) + os.remove(path) # Shouldn't fail + + +def test_os_file_writer(tmpdir): + SIZE = 4096 + arr = np.random.randint(0, 256, size=SIZE).astype('u1') + data = arr.tobytes()[:SIZE] + + path = os.path.join(str(tmpdir), guid()) + with open(path, 'wb') as f: + f.write(data) + + # Truncates file + f2 = pa.OSFile(path, mode='w') + f2.write(b'foo') + + with pa.OSFile(path) as f3: + assert f3.size() == 3 + + with pytest.raises(IOError): + f2.read(5) + + +def test_native_file_write_reject_unicode(): + # ARROW-3227 + nf = pa.BufferOutputStream() + with pytest.raises(TypeError): + nf.write('foo') + + +def test_native_file_modes(tmpdir): + path = os.path.join(str(tmpdir), guid()) + with open(path, 'wb') as f: + f.write(b'foooo') + + with pa.OSFile(path, mode='r') as f: + assert f.mode == 'rb' + assert f.readable() + assert not f.writable() + assert f.seekable() + + with pa.OSFile(path, mode='rb') as f: + assert f.mode == 'rb' + assert f.readable() + assert not f.writable() + assert f.seekable() + + with pa.OSFile(path, mode='w') as f: + assert f.mode == 'wb' + assert not f.readable() + assert f.writable() + assert not f.seekable() + + with pa.OSFile(path, mode='wb') as f: + assert f.mode == 'wb' + assert not f.readable() + assert f.writable() + assert not f.seekable() + + with open(path, 'wb') as f: + f.write(b'foooo') + + with pa.memory_map(path, 'r') as f: + assert f.mode == 'rb' + assert f.readable() + assert not f.writable() + assert f.seekable() + + with pa.memory_map(path, 'r+') as f: + assert f.mode == 'rb+' + assert f.readable() + assert f.writable() + assert f.seekable() + + with pa.memory_map(path, 'r+b') as f: + assert f.mode == 'rb+' + assert f.readable() + assert f.writable() + assert f.seekable() + + +def test_native_file_permissions(tmpdir): + # ARROW-10124: permissions of created files should follow umask + cur_umask = os.umask(0o002) + os.umask(cur_umask) + + path = os.path.join(str(tmpdir), guid()) + with pa.OSFile(path, mode='w'): + pass + assert os.stat(path).st_mode & 0o777 == 0o666 & ~cur_umask + + path = os.path.join(str(tmpdir), guid()) + with pa.memory_map(path, 'w'): + pass + assert os.stat(path).st_mode & 0o777 == 0o666 & ~cur_umask + + +def test_native_file_raises_ValueError_after_close(tmpdir): + path = os.path.join(str(tmpdir), guid()) + with open(path, 'wb') as f: + f.write(b'foooo') + + with pa.OSFile(path, mode='rb') as os_file: + assert not os_file.closed + assert os_file.closed + + with pa.memory_map(path, mode='rb') as mmap_file: + assert not mmap_file.closed + assert mmap_file.closed + + files = [os_file, + mmap_file] + + methods = [('tell', ()), + ('seek', (0,)), + ('size', ()), + ('flush', ()), + ('readable', ()), + ('writable', ()), + ('seekable', ())] + + for f in files: + for method, args in methods: + with pytest.raises(ValueError): + getattr(f, method)(*args) + + +def test_native_file_TextIOWrapper(tmpdir): + data = ('foooo\n' + 'barrr\n' + 'bazzz\n') + + path = os.path.join(str(tmpdir), guid()) + with open(path, 'wb') as f: + f.write(data.encode('utf-8')) + + with TextIOWrapper(pa.OSFile(path, mode='rb')) as fil: + assert fil.readable() + res = fil.read() + assert res == data + assert fil.closed + + with TextIOWrapper(pa.OSFile(path, mode='rb')) as fil: + # Iteration works + lines = list(fil) + assert ''.join(lines) == data + + # Writing + path2 = os.path.join(str(tmpdir), guid()) + with TextIOWrapper(pa.OSFile(path2, mode='wb')) as fil: + assert fil.writable() + fil.write(data) + + with TextIOWrapper(pa.OSFile(path2, mode='rb')) as fil: + res = fil.read() + assert res == data + + +def test_native_file_open_error(): + with assert_file_not_found(): + pa.OSFile('non_existent_file', 'rb') + with assert_file_not_found(): + pa.memory_map('non_existent_file', 'rb') + + +# ---------------------------------------------------------------------- +# Buffered streams + +def test_buffered_input_stream(): + raw = pa.BufferReader(b"123456789") + f = pa.BufferedInputStream(raw, buffer_size=4) + assert f.read(2) == b"12" + assert raw.tell() == 4 + f.close() + assert f.closed + assert raw.closed + + +def test_buffered_input_stream_detach_seekable(): + # detach() to a seekable file (io::RandomAccessFile in C++) + f = pa.BufferedInputStream(pa.BufferReader(b"123456789"), buffer_size=4) + assert f.read(2) == b"12" + raw = f.detach() + assert f.closed + assert not raw.closed + assert raw.seekable() + assert raw.read(4) == b"5678" + raw.seek(2) + assert raw.read(4) == b"3456" + + +def test_buffered_input_stream_detach_non_seekable(): + # detach() to a non-seekable file (io::InputStream in C++) + f = pa.BufferedInputStream( + pa.BufferedInputStream(pa.BufferReader(b"123456789"), buffer_size=4), + buffer_size=4) + assert f.read(2) == b"12" + raw = f.detach() + assert f.closed + assert not raw.closed + assert not raw.seekable() + assert raw.read(4) == b"5678" + with pytest.raises(EnvironmentError): + raw.seek(2) + + +def test_buffered_output_stream(): + np_buf = np.zeros(100, dtype=np.int8) # zero-initialized buffer + buf = pa.py_buffer(np_buf) + + raw = pa.FixedSizeBufferWriter(buf) + f = pa.BufferedOutputStream(raw, buffer_size=4) + f.write(b"12") + assert np_buf[:4].tobytes() == b'\0\0\0\0' + f.flush() + assert np_buf[:4].tobytes() == b'12\0\0' + f.write(b"3456789") + f.close() + assert f.closed + assert raw.closed + assert np_buf[:10].tobytes() == b'123456789\0' + + +def test_buffered_output_stream_detach(): + np_buf = np.zeros(100, dtype=np.int8) # zero-initialized buffer + buf = pa.py_buffer(np_buf) + + f = pa.BufferedOutputStream(pa.FixedSizeBufferWriter(buf), buffer_size=4) + f.write(b"12") + assert np_buf[:4].tobytes() == b'\0\0\0\0' + raw = f.detach() + assert f.closed + assert not raw.closed + assert np_buf[:4].tobytes() == b'12\0\0' + + +# ---------------------------------------------------------------------- +# Compressed input and output streams + +def check_compressed_input(data, fn, compression): + raw = pa.OSFile(fn, mode="rb") + with pa.CompressedInputStream(raw, compression) as compressed: + assert not compressed.closed + assert compressed.readable() + assert not compressed.writable() + assert not compressed.seekable() + got = compressed.read() + assert got == data + assert compressed.closed + assert raw.closed + + # Same with read_buffer() + raw = pa.OSFile(fn, mode="rb") + with pa.CompressedInputStream(raw, compression) as compressed: + buf = compressed.read_buffer() + assert isinstance(buf, pa.Buffer) + assert buf.to_pybytes() == data + + +@pytest.mark.gzip +def test_compressed_input_gzip(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + fn = str(tmpdir / "compressed_input_test.gz") + with gzip.open(fn, "wb") as f: + f.write(data) + check_compressed_input(data, fn, "gzip") + + +def test_compressed_input_bz2(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + fn = str(tmpdir / "compressed_input_test.bz2") + with bz2.BZ2File(fn, "w") as f: + f.write(data) + try: + check_compressed_input(data, fn, "bz2") + except NotImplementedError as e: + pytest.skip(str(e)) + + +@pytest.mark.gzip +def test_compressed_input_openfile(tmpdir): + if not Codec.is_available("gzip"): + pytest.skip("gzip support is not built") + + data = b"some test data\n" * 10 + b"eof\n" + fn = str(tmpdir / "test_compressed_input_openfile.gz") + with gzip.open(fn, "wb") as f: + f.write(data) + + with pa.CompressedInputStream(fn, "gzip") as compressed: + buf = compressed.read_buffer() + assert buf.to_pybytes() == data + assert compressed.closed + + with pa.CompressedInputStream(pathlib.Path(fn), "gzip") as compressed: + buf = compressed.read_buffer() + assert buf.to_pybytes() == data + assert compressed.closed + + f = open(fn, "rb") + with pa.CompressedInputStream(f, "gzip") as compressed: + buf = compressed.read_buffer() + assert buf.to_pybytes() == data + assert f.closed + + +def check_compressed_concatenated(data, fn, compression): + raw = pa.OSFile(fn, mode="rb") + with pa.CompressedInputStream(raw, compression) as compressed: + got = compressed.read() + assert got == data + + +@pytest.mark.gzip +def test_compressed_concatenated_gzip(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + fn = str(tmpdir / "compressed_input_test2.gz") + with gzip.open(fn, "wb") as f: + f.write(data[:50]) + with gzip.open(fn, "ab") as f: + f.write(data[50:]) + check_compressed_concatenated(data, fn, "gzip") + + +@pytest.mark.gzip +def test_compressed_input_invalid(): + data = b"foo" * 10 + raw = pa.BufferReader(data) + with pytest.raises(ValueError): + pa.CompressedInputStream(raw, "unknown_compression") + with pytest.raises(TypeError): + pa.CompressedInputStream(raw, None) + + with pa.CompressedInputStream(raw, "gzip") as compressed: + with pytest.raises(IOError, match="zlib inflate failed"): + compressed.read() + + +def make_compressed_output(data, fn, compression): + raw = pa.BufferOutputStream() + with pa.CompressedOutputStream(raw, compression) as compressed: + assert not compressed.closed + assert not compressed.readable() + assert compressed.writable() + assert not compressed.seekable() + compressed.write(data) + assert compressed.closed + assert raw.closed + with open(fn, "wb") as f: + f.write(raw.getvalue()) + + +@pytest.mark.gzip +def test_compressed_output_gzip(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + fn = str(tmpdir / "compressed_output_test.gz") + make_compressed_output(data, fn, "gzip") + with gzip.open(fn, "rb") as f: + got = f.read() + assert got == data + + +def test_compressed_output_bz2(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + fn = str(tmpdir / "compressed_output_test.bz2") + try: + make_compressed_output(data, fn, "bz2") + except NotImplementedError as e: + pytest.skip(str(e)) + with bz2.BZ2File(fn, "r") as f: + got = f.read() + assert got == data + + +def test_output_stream_constructor(tmpdir): + if not Codec.is_available("gzip"): + pytest.skip("gzip support is not built") + with pa.CompressedOutputStream(tmpdir / "ctor.gz", "gzip") as stream: + stream.write(b"test") + with (tmpdir / "ctor2.gz").open("wb") as f: + with pa.CompressedOutputStream(f, "gzip") as stream: + stream.write(b"test") + + +@pytest.mark.parametrize(("path", "expected_compression"), [ + ("file.bz2", "bz2"), + ("file.lz4", "lz4"), + (pathlib.Path("file.gz"), "gzip"), + (pathlib.Path("path/to/file.zst"), "zstd"), +]) +def test_compression_detection(path, expected_compression): + if not Codec.is_available(expected_compression): + with pytest.raises(pa.lib.ArrowNotImplementedError): + Codec.detect(path) + else: + codec = Codec.detect(path) + assert isinstance(codec, Codec) + assert codec.name == expected_compression + + +def test_unknown_compression_raises(): + with pytest.raises(ValueError): + Codec.is_available('unknown') + with pytest.raises(TypeError): + Codec(None) + with pytest.raises(ValueError): + Codec('unknown') + + +@pytest.mark.parametrize("compression", [ + "bz2", + "brotli", + "gzip", + "lz4", + "zstd", + pytest.param( + "snappy", + marks=pytest.mark.xfail(raises=pa.lib.ArrowNotImplementedError) + ) +]) +def test_compressed_roundtrip(compression): + if not Codec.is_available(compression): + pytest.skip("{} support is not built".format(compression)) + + data = b"some test data\n" * 10 + b"eof\n" + raw = pa.BufferOutputStream() + with pa.CompressedOutputStream(raw, compression) as compressed: + compressed.write(data) + + cdata = raw.getvalue() + assert len(cdata) < len(data) + raw = pa.BufferReader(cdata) + with pa.CompressedInputStream(raw, compression) as compressed: + got = compressed.read() + assert got == data + + +@pytest.mark.parametrize( + "compression", + ["bz2", "brotli", "gzip", "lz4", "zstd"] +) +def test_compressed_recordbatch_stream(compression): + if not Codec.is_available(compression): + pytest.skip("{} support is not built".format(compression)) + + # ARROW-4836: roundtrip a RecordBatch through a compressed stream + table = pa.Table.from_arrays([pa.array([1, 2, 3, 4, 5])], ['a']) + raw = pa.BufferOutputStream() + stream = pa.CompressedOutputStream(raw, compression) + writer = pa.RecordBatchStreamWriter(stream, table.schema) + writer.write_table(table, max_chunksize=3) + writer.close() + stream.close() # Flush data + buf = raw.getvalue() + stream = pa.CompressedInputStream(pa.BufferReader(buf), compression) + got_table = pa.RecordBatchStreamReader(stream).read_all() + assert got_table == table + + +# ---------------------------------------------------------------------- +# Transform input streams + +unicode_transcoding_example = ( + "Dès Noël où un zéphyr haï me vêt de glaçons würmiens " + "je dîne d’exquis rôtis de bœuf au kir à l’aÿ d’âge mûr & cætera !" +) + + +def check_transcoding(data, src_encoding, dest_encoding, chunk_sizes): + chunk_sizes = iter(chunk_sizes) + stream = pa.transcoding_input_stream( + pa.BufferReader(data.encode(src_encoding)), + src_encoding, dest_encoding) + out = [] + while True: + buf = stream.read(next(chunk_sizes)) + out.append(buf) + if not buf: + break + out = b''.join(out) + assert out.decode(dest_encoding) == data + + +@pytest.mark.parametrize('src_encoding, dest_encoding', + [('utf-8', 'utf-16'), + ('utf-16', 'utf-8'), + ('utf-8', 'utf-32-le'), + ('utf-8', 'utf-32-be'), + ]) +def test_transcoding_input_stream(src_encoding, dest_encoding): + # All at once + check_transcoding(unicode_transcoding_example, + src_encoding, dest_encoding, [1000, 0]) + # Incremental + check_transcoding(unicode_transcoding_example, + src_encoding, dest_encoding, + itertools.cycle([1, 2, 3, 5])) + + +@pytest.mark.parametrize('src_encoding, dest_encoding', + [('utf-8', 'utf-8'), + ('utf-8', 'UTF8')]) +def test_transcoding_no_ops(src_encoding, dest_encoding): + # No indirection is wasted when a trivial transcoding is requested + stream = pa.BufferReader(b"abc123") + assert pa.transcoding_input_stream( + stream, src_encoding, dest_encoding) is stream + + +@pytest.mark.parametrize('src_encoding, dest_encoding', + [('utf-8', 'ascii'), + ('utf-8', 'latin-1'), + ]) +def test_transcoding_encoding_error(src_encoding, dest_encoding): + # Character \u0100 cannot be represented in the destination encoding + stream = pa.transcoding_input_stream( + pa.BufferReader("\u0100".encode(src_encoding)), + src_encoding, + dest_encoding) + with pytest.raises(UnicodeEncodeError): + stream.read(1) + + +@pytest.mark.parametrize('src_encoding, dest_encoding', + [('utf-8', 'utf-16'), + ('utf-16', 'utf-8'), + ]) +def test_transcoding_decoding_error(src_encoding, dest_encoding): + # The given bytestring is not valid in the source encoding + stream = pa.transcoding_input_stream( + pa.BufferReader(b"\xff\xff\xff\xff"), + src_encoding, + dest_encoding) + with pytest.raises(UnicodeError): + stream.read(1) + + +# ---------------------------------------------------------------------- +# High-level API + +@pytest.mark.gzip +def test_input_stream_buffer(): + data = b"some test data\n" * 10 + b"eof\n" + for arg in [pa.py_buffer(data), memoryview(data)]: + stream = pa.input_stream(arg) + assert stream.read() == data + + gz_data = gzip.compress(data) + stream = pa.input_stream(memoryview(gz_data)) + assert stream.read() == gz_data + stream = pa.input_stream(memoryview(gz_data), compression='gzip') + assert stream.read() == data + + +def test_input_stream_duck_typing(): + # Accept objects having the right file-like methods... + class DuckReader: + + def close(self): + pass + + @property + def closed(self): + return False + + def read(self, nbytes=None): + return b'hello' + + stream = pa.input_stream(DuckReader()) + assert stream.read(5) == b'hello' + + +def test_input_stream_file_path(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + file_path = tmpdir / 'input_stream' + with open(str(file_path), 'wb') as f: + f.write(data) + + stream = pa.input_stream(file_path) + assert stream.read() == data + stream = pa.input_stream(str(file_path)) + assert stream.read() == data + stream = pa.input_stream(pathlib.Path(str(file_path))) + assert stream.read() == data + + +@pytest.mark.gzip +def test_input_stream_file_path_compressed(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + gz_data = gzip.compress(data) + file_path = tmpdir / 'input_stream.gz' + with open(str(file_path), 'wb') as f: + f.write(gz_data) + + stream = pa.input_stream(file_path) + assert stream.read() == data + stream = pa.input_stream(str(file_path)) + assert stream.read() == data + stream = pa.input_stream(pathlib.Path(str(file_path))) + assert stream.read() == data + + stream = pa.input_stream(file_path, compression='gzip') + assert stream.read() == data + stream = pa.input_stream(file_path, compression=None) + assert stream.read() == gz_data + + +def test_input_stream_file_path_buffered(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + file_path = tmpdir / 'input_stream.buffered' + with open(str(file_path), 'wb') as f: + f.write(data) + + stream = pa.input_stream(file_path, buffer_size=32) + assert isinstance(stream, pa.BufferedInputStream) + assert stream.read() == data + stream = pa.input_stream(str(file_path), buffer_size=64) + assert isinstance(stream, pa.BufferedInputStream) + assert stream.read() == data + stream = pa.input_stream(pathlib.Path(str(file_path)), buffer_size=1024) + assert isinstance(stream, pa.BufferedInputStream) + assert stream.read() == data + + unbuffered_stream = pa.input_stream(file_path, buffer_size=0) + assert isinstance(unbuffered_stream, pa.OSFile) + + msg = 'Buffer size must be larger than zero' + with pytest.raises(ValueError, match=msg): + pa.input_stream(file_path, buffer_size=-1) + with pytest.raises(TypeError): + pa.input_stream(file_path, buffer_size='million') + + +@pytest.mark.gzip +def test_input_stream_file_path_compressed_and_buffered(tmpdir): + data = b"some test data\n" * 100 + b"eof\n" + gz_data = gzip.compress(data) + file_path = tmpdir / 'input_stream_compressed_and_buffered.gz' + with open(str(file_path), 'wb') as f: + f.write(gz_data) + + stream = pa.input_stream(file_path, buffer_size=32, compression='gzip') + assert stream.read() == data + stream = pa.input_stream(str(file_path), buffer_size=64) + assert stream.read() == data + stream = pa.input_stream(pathlib.Path(str(file_path)), buffer_size=1024) + assert stream.read() == data + + +@pytest.mark.gzip +def test_input_stream_python_file(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + bio = BytesIO(data) + + stream = pa.input_stream(bio) + assert stream.read() == data + + gz_data = gzip.compress(data) + bio = BytesIO(gz_data) + stream = pa.input_stream(bio) + assert stream.read() == gz_data + bio.seek(0) + stream = pa.input_stream(bio, compression='gzip') + assert stream.read() == data + + file_path = tmpdir / 'input_stream' + with open(str(file_path), 'wb') as f: + f.write(data) + with open(str(file_path), 'rb') as f: + stream = pa.input_stream(f) + assert stream.read() == data + + +@pytest.mark.gzip +def test_input_stream_native_file(): + data = b"some test data\n" * 10 + b"eof\n" + gz_data = gzip.compress(data) + reader = pa.BufferReader(gz_data) + stream = pa.input_stream(reader) + assert stream is reader + reader = pa.BufferReader(gz_data) + stream = pa.input_stream(reader, compression='gzip') + assert stream.read() == data + + +def test_input_stream_errors(tmpdir): + buf = memoryview(b"") + with pytest.raises(ValueError): + pa.input_stream(buf, compression="foo") + + for arg in [bytearray(), StringIO()]: + with pytest.raises(TypeError): + pa.input_stream(arg) + + with assert_file_not_found(): + pa.input_stream("non_existent_file") + + with open(str(tmpdir / 'new_file'), 'wb') as f: + with pytest.raises(TypeError, match="readable file expected"): + pa.input_stream(f) + + +def test_output_stream_buffer(): + data = b"some test data\n" * 10 + b"eof\n" + buf = bytearray(len(data)) + stream = pa.output_stream(pa.py_buffer(buf)) + stream.write(data) + assert buf == data + + buf = bytearray(len(data)) + stream = pa.output_stream(memoryview(buf)) + stream.write(data) + assert buf == data + + +def test_output_stream_duck_typing(): + # Accept objects having the right file-like methods... + class DuckWriter: + def __init__(self): + self.buf = pa.BufferOutputStream() + + def close(self): + pass + + @property + def closed(self): + return False + + def write(self, data): + self.buf.write(data) + + duck_writer = DuckWriter() + stream = pa.output_stream(duck_writer) + assert stream.write(b'hello') + assert duck_writer.buf.getvalue().to_pybytes() == b'hello' + + +def test_output_stream_file_path(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + file_path = tmpdir / 'output_stream' + + def check_data(file_path, data): + with pa.output_stream(file_path) as stream: + stream.write(data) + with open(str(file_path), 'rb') as f: + assert f.read() == data + + check_data(file_path, data) + check_data(str(file_path), data) + check_data(pathlib.Path(str(file_path)), data) + + +@pytest.mark.gzip +def test_output_stream_file_path_compressed(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + file_path = tmpdir / 'output_stream.gz' + + def check_data(file_path, data, **kwargs): + with pa.output_stream(file_path, **kwargs) as stream: + stream.write(data) + with open(str(file_path), 'rb') as f: + return f.read() + + assert gzip.decompress(check_data(file_path, data)) == data + assert gzip.decompress(check_data(str(file_path), data)) == data + assert gzip.decompress( + check_data(pathlib.Path(str(file_path)), data)) == data + + assert gzip.decompress( + check_data(file_path, data, compression='gzip')) == data + assert check_data(file_path, data, compression=None) == data + + with pytest.raises(ValueError, match='Invalid value for compression'): + assert check_data(file_path, data, compression='rabbit') == data + + +def test_output_stream_file_path_buffered(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + file_path = tmpdir / 'output_stream.buffered' + + def check_data(file_path, data, **kwargs): + with pa.output_stream(file_path, **kwargs) as stream: + if kwargs.get('buffer_size', 0) > 0: + assert isinstance(stream, pa.BufferedOutputStream) + stream.write(data) + with open(str(file_path), 'rb') as f: + return f.read() + + unbuffered_stream = pa.output_stream(file_path, buffer_size=0) + assert isinstance(unbuffered_stream, pa.OSFile) + + msg = 'Buffer size must be larger than zero' + with pytest.raises(ValueError, match=msg): + assert check_data(file_path, data, buffer_size=-128) == data + + assert check_data(file_path, data, buffer_size=32) == data + assert check_data(file_path, data, buffer_size=1024) == data + assert check_data(str(file_path), data, buffer_size=32) == data + + result = check_data(pathlib.Path(str(file_path)), data, buffer_size=32) + assert result == data + + +@pytest.mark.gzip +def test_output_stream_file_path_compressed_and_buffered(tmpdir): + data = b"some test data\n" * 100 + b"eof\n" + file_path = tmpdir / 'output_stream_compressed_and_buffered.gz' + + def check_data(file_path, data, **kwargs): + with pa.output_stream(file_path, **kwargs) as stream: + stream.write(data) + with open(str(file_path), 'rb') as f: + return f.read() + + result = check_data(file_path, data, buffer_size=32) + assert gzip.decompress(result) == data + + result = check_data(file_path, data, buffer_size=1024) + assert gzip.decompress(result) == data + + result = check_data(file_path, data, buffer_size=1024, compression='gzip') + assert gzip.decompress(result) == data + + +def test_output_stream_destructor(tmpdir): + # The wrapper returned by pa.output_stream() should respect Python + # file semantics, i.e. destroying it should close the underlying + # file cleanly. + data = b"some test data\n" + file_path = tmpdir / 'output_stream.buffered' + + def check_data(file_path, data, **kwargs): + stream = pa.output_stream(file_path, **kwargs) + stream.write(data) + del stream + gc.collect() + with open(str(file_path), 'rb') as f: + return f.read() + + assert check_data(file_path, data, buffer_size=0) == data + assert check_data(file_path, data, buffer_size=1024) == data + + +@pytest.mark.gzip +def test_output_stream_python_file(tmpdir): + data = b"some test data\n" * 10 + b"eof\n" + + def check_data(data, **kwargs): + # XXX cannot use BytesIO because stream.close() is necessary + # to finish writing compressed data, but it will also close the + # underlying BytesIO + fn = str(tmpdir / 'output_stream_file') + with open(fn, 'wb') as f: + with pa.output_stream(f, **kwargs) as stream: + stream.write(data) + with open(fn, 'rb') as f: + return f.read() + + assert check_data(data) == data + assert gzip.decompress(check_data(data, compression='gzip')) == data + + +def test_output_stream_errors(tmpdir): + buf = memoryview(bytearray()) + with pytest.raises(ValueError): + pa.output_stream(buf, compression="foo") + + for arg in [bytearray(), StringIO()]: + with pytest.raises(TypeError): + pa.output_stream(arg) + + fn = str(tmpdir / 'new_file') + with open(fn, 'wb') as f: + pass + with open(fn, 'rb') as f: + with pytest.raises(TypeError, match="writable file expected"): + pa.output_stream(f) diff --git a/src/arrow/python/pyarrow/tests/test_ipc.py b/src/arrow/python/pyarrow/tests/test_ipc.py new file mode 100644 index 000000000..87944bcc0 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_ipc.py @@ -0,0 +1,999 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import UserList +import io +import pathlib +import pytest +import socket +import threading +import weakref + +import numpy as np + +import pyarrow as pa +from pyarrow.tests.util import changed_environ + + +try: + from pandas.testing import assert_frame_equal, assert_series_equal + import pandas as pd +except ImportError: + pass + + +class IpcFixture: + write_stats = None + + def __init__(self, sink_factory=lambda: io.BytesIO()): + self._sink_factory = sink_factory + self.sink = self.get_sink() + + def get_sink(self): + return self._sink_factory() + + def get_source(self): + return self.sink.getvalue() + + def write_batches(self, num_batches=5, as_table=False): + nrows = 5 + schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())]) + + writer = self._get_writer(self.sink, schema) + + batches = [] + for i in range(num_batches): + batch = pa.record_batch( + [np.random.randn(nrows), + ['foo', None, 'bar', 'bazbaz', 'qux']], + schema=schema) + batches.append(batch) + + if as_table: + table = pa.Table.from_batches(batches) + writer.write_table(table) + else: + for batch in batches: + writer.write_batch(batch) + + self.write_stats = writer.stats + writer.close() + return batches + + +class FileFormatFixture(IpcFixture): + + def _get_writer(self, sink, schema): + return pa.ipc.new_file(sink, schema) + + def _check_roundtrip(self, as_table=False): + batches = self.write_batches(as_table=as_table) + file_contents = pa.BufferReader(self.get_source()) + + reader = pa.ipc.open_file(file_contents) + + assert reader.num_record_batches == len(batches) + + for i, batch in enumerate(batches): + # it works. Must convert back to DataFrame + batch = reader.get_batch(i) + assert batches[i].equals(batch) + assert reader.schema.equals(batches[0].schema) + + assert isinstance(reader.stats, pa.ipc.ReadStats) + assert isinstance(self.write_stats, pa.ipc.WriteStats) + assert tuple(reader.stats) == tuple(self.write_stats) + + +class StreamFormatFixture(IpcFixture): + + # ARROW-6474, for testing writing old IPC protocol with 4-byte prefix + use_legacy_ipc_format = False + # ARROW-9395, for testing writing old metadata version + options = None + + def _get_writer(self, sink, schema): + return pa.ipc.new_stream( + sink, + schema, + use_legacy_format=self.use_legacy_ipc_format, + options=self.options, + ) + + +class MessageFixture(IpcFixture): + + def _get_writer(self, sink, schema): + return pa.RecordBatchStreamWriter(sink, schema) + + +@pytest.fixture +def ipc_fixture(): + return IpcFixture() + + +@pytest.fixture +def file_fixture(): + return FileFormatFixture() + + +@pytest.fixture +def stream_fixture(): + return StreamFormatFixture() + + +def test_empty_file(): + buf = b'' + with pytest.raises(pa.ArrowInvalid): + pa.ipc.open_file(pa.BufferReader(buf)) + + +def test_file_simple_roundtrip(file_fixture): + file_fixture._check_roundtrip(as_table=False) + + +def test_file_write_table(file_fixture): + file_fixture._check_roundtrip(as_table=True) + + +@pytest.mark.parametrize("sink_factory", [ + lambda: io.BytesIO(), + lambda: pa.BufferOutputStream() +]) +def test_file_read_all(sink_factory): + fixture = FileFormatFixture(sink_factory) + + batches = fixture.write_batches() + file_contents = pa.BufferReader(fixture.get_source()) + + reader = pa.ipc.open_file(file_contents) + + result = reader.read_all() + expected = pa.Table.from_batches(batches) + assert result.equals(expected) + + +def test_open_file_from_buffer(file_fixture): + # ARROW-2859; APIs accept the buffer protocol + file_fixture.write_batches() + source = file_fixture.get_source() + + reader1 = pa.ipc.open_file(source) + reader2 = pa.ipc.open_file(pa.BufferReader(source)) + reader3 = pa.RecordBatchFileReader(source) + + result1 = reader1.read_all() + result2 = reader2.read_all() + result3 = reader3.read_all() + + assert result1.equals(result2) + assert result1.equals(result3) + + st1 = reader1.stats + assert st1.num_messages == 6 + assert st1.num_record_batches == 5 + assert reader2.stats == st1 + assert reader3.stats == st1 + + +@pytest.mark.pandas +def test_file_read_pandas(file_fixture): + frames = [batch.to_pandas() for batch in file_fixture.write_batches()] + + file_contents = pa.BufferReader(file_fixture.get_source()) + reader = pa.ipc.open_file(file_contents) + result = reader.read_pandas() + + expected = pd.concat(frames).reset_index(drop=True) + assert_frame_equal(result, expected) + + +def test_file_pathlib(file_fixture, tmpdir): + file_fixture.write_batches() + source = file_fixture.get_source() + + path = tmpdir.join('file.arrow').strpath + with open(path, 'wb') as f: + f.write(source) + + t1 = pa.ipc.open_file(pathlib.Path(path)).read_all() + t2 = pa.ipc.open_file(pa.OSFile(path)).read_all() + + assert t1.equals(t2) + + +def test_empty_stream(): + buf = io.BytesIO(b'') + with pytest.raises(pa.ArrowInvalid): + pa.ipc.open_stream(buf) + + +@pytest.mark.pandas +def test_stream_categorical_roundtrip(stream_fixture): + df = pd.DataFrame({ + 'one': np.random.randn(5), + 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'], + categories=['foo', 'bar'], + ordered=True) + }) + batch = pa.RecordBatch.from_pandas(df) + with stream_fixture._get_writer(stream_fixture.sink, batch.schema) as wr: + wr.write_batch(batch) + + table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source())) + .read_all()) + assert_frame_equal(table.to_pandas(), df) + + +def test_open_stream_from_buffer(stream_fixture): + # ARROW-2859 + stream_fixture.write_batches() + source = stream_fixture.get_source() + + reader1 = pa.ipc.open_stream(source) + reader2 = pa.ipc.open_stream(pa.BufferReader(source)) + reader3 = pa.RecordBatchStreamReader(source) + + result1 = reader1.read_all() + result2 = reader2.read_all() + result3 = reader3.read_all() + + assert result1.equals(result2) + assert result1.equals(result3) + + st1 = reader1.stats + assert st1.num_messages == 6 + assert st1.num_record_batches == 5 + assert reader2.stats == st1 + assert reader3.stats == st1 + + assert tuple(st1) == tuple(stream_fixture.write_stats) + + +@pytest.mark.pandas +def test_stream_write_dispatch(stream_fixture): + # ARROW-1616 + df = pd.DataFrame({ + 'one': np.random.randn(5), + 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'], + categories=['foo', 'bar'], + ordered=True) + }) + table = pa.Table.from_pandas(df, preserve_index=False) + batch = pa.RecordBatch.from_pandas(df, preserve_index=False) + with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr: + wr.write(table) + wr.write(batch) + + table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source())) + .read_all()) + assert_frame_equal(table.to_pandas(), + pd.concat([df, df], ignore_index=True)) + + +@pytest.mark.pandas +def test_stream_write_table_batches(stream_fixture): + # ARROW-504 + df = pd.DataFrame({ + 'one': np.random.randn(20), + }) + + b1 = pa.RecordBatch.from_pandas(df[:10], preserve_index=False) + b2 = pa.RecordBatch.from_pandas(df, preserve_index=False) + + table = pa.Table.from_batches([b1, b2, b1]) + + with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr: + wr.write_table(table, max_chunksize=15) + + batches = list(pa.ipc.open_stream(stream_fixture.get_source())) + + assert list(map(len, batches)) == [10, 15, 5, 10] + result_table = pa.Table.from_batches(batches) + assert_frame_equal(result_table.to_pandas(), + pd.concat([df[:10], df, df[:10]], + ignore_index=True)) + + +@pytest.mark.parametrize('use_legacy_ipc_format', [False, True]) +def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format): + stream_fixture.use_legacy_ipc_format = use_legacy_ipc_format + batches = stream_fixture.write_batches() + file_contents = pa.BufferReader(stream_fixture.get_source()) + reader = pa.ipc.open_stream(file_contents) + + assert reader.schema.equals(batches[0].schema) + + total = 0 + for i, next_batch in enumerate(reader): + assert next_batch.equals(batches[i]) + total += 1 + + assert total == len(batches) + + with pytest.raises(StopIteration): + reader.read_next_batch() + + +@pytest.mark.zstd +def test_compression_roundtrip(): + sink = io.BytesIO() + values = np.random.randint(0, 10, 10000) + table = pa.Table.from_arrays([values], names=["values"]) + + options = pa.ipc.IpcWriteOptions(compression='zstd') + with pa.ipc.RecordBatchFileWriter( + sink, table.schema, options=options) as writer: + writer.write_table(table) + len1 = len(sink.getvalue()) + + sink2 = io.BytesIO() + codec = pa.Codec('zstd', compression_level=5) + options = pa.ipc.IpcWriteOptions(compression=codec) + with pa.ipc.RecordBatchFileWriter( + sink2, table.schema, options=options) as writer: + writer.write_table(table) + len2 = len(sink2.getvalue()) + + # In theory len2 should be less than len1 but for this test we just want + # to ensure compression_level is being correctly passed down to the C++ + # layer so we don't really care if it makes it worse or better + assert len2 != len1 + + t1 = pa.ipc.open_file(sink).read_all() + t2 = pa.ipc.open_file(sink2).read_all() + + assert t1 == t2 + + +def test_write_options(): + options = pa.ipc.IpcWriteOptions() + assert options.allow_64bit is False + assert options.use_legacy_format is False + assert options.metadata_version == pa.ipc.MetadataVersion.V5 + + options.allow_64bit = True + assert options.allow_64bit is True + + options.use_legacy_format = True + assert options.use_legacy_format is True + + options.metadata_version = pa.ipc.MetadataVersion.V4 + assert options.metadata_version == pa.ipc.MetadataVersion.V4 + for value in ('V5', 42): + with pytest.raises((TypeError, ValueError)): + options.metadata_version = value + + assert options.compression is None + for value in ['lz4', 'zstd']: + if pa.Codec.is_available(value): + options.compression = value + assert options.compression == value + options.compression = value.upper() + assert options.compression == value + options.compression = None + assert options.compression is None + + with pytest.raises(TypeError): + options.compression = 0 + + assert options.use_threads is True + options.use_threads = False + assert options.use_threads is False + + if pa.Codec.is_available('lz4'): + options = pa.ipc.IpcWriteOptions( + metadata_version=pa.ipc.MetadataVersion.V4, + allow_64bit=True, + use_legacy_format=True, + compression='lz4', + use_threads=False) + assert options.metadata_version == pa.ipc.MetadataVersion.V4 + assert options.allow_64bit is True + assert options.use_legacy_format is True + assert options.compression == 'lz4' + assert options.use_threads is False + + +def test_write_options_legacy_exclusive(stream_fixture): + with pytest.raises( + ValueError, + match="provide at most one of options and use_legacy_format"): + stream_fixture.use_legacy_ipc_format = True + stream_fixture.options = pa.ipc.IpcWriteOptions() + stream_fixture.write_batches() + + +@pytest.mark.parametrize('options', [ + pa.ipc.IpcWriteOptions(), + pa.ipc.IpcWriteOptions(allow_64bit=True), + pa.ipc.IpcWriteOptions(use_legacy_format=True), + pa.ipc.IpcWriteOptions(metadata_version=pa.ipc.MetadataVersion.V4), + pa.ipc.IpcWriteOptions(use_legacy_format=True, + metadata_version=pa.ipc.MetadataVersion.V4), +]) +def test_stream_options_roundtrip(stream_fixture, options): + stream_fixture.use_legacy_ipc_format = None + stream_fixture.options = options + batches = stream_fixture.write_batches() + file_contents = pa.BufferReader(stream_fixture.get_source()) + + message = pa.ipc.read_message(stream_fixture.get_source()) + assert message.metadata_version == options.metadata_version + + reader = pa.ipc.open_stream(file_contents) + + assert reader.schema.equals(batches[0].schema) + + total = 0 + for i, next_batch in enumerate(reader): + assert next_batch.equals(batches[i]) + total += 1 + + assert total == len(batches) + + with pytest.raises(StopIteration): + reader.read_next_batch() + + +def test_dictionary_delta(stream_fixture): + ty = pa.dictionary(pa.int8(), pa.utf8()) + data = [["foo", "foo", None], + ["foo", "bar", "foo"], # potential delta + ["foo", "bar"], + ["foo", None, "bar", "quux"], # potential delta + ["bar", "quux"], # replacement + ] + batches = [ + pa.RecordBatch.from_arrays([pa.array(v, type=ty)], names=['dicts']) + for v in data] + schema = batches[0].schema + + def write_batches(): + with stream_fixture._get_writer(pa.MockOutputStream(), + schema) as writer: + for batch in batches: + writer.write_batch(batch) + return writer.stats + + st = write_batches() + assert st.num_record_batches == 5 + assert st.num_dictionary_batches == 4 + assert st.num_replaced_dictionaries == 3 + assert st.num_dictionary_deltas == 0 + + stream_fixture.use_legacy_ipc_format = None + stream_fixture.options = pa.ipc.IpcWriteOptions( + emit_dictionary_deltas=True) + st = write_batches() + assert st.num_record_batches == 5 + assert st.num_dictionary_batches == 4 + assert st.num_replaced_dictionaries == 1 + assert st.num_dictionary_deltas == 2 + + +def test_envvar_set_legacy_ipc_format(): + schema = pa.schema([pa.field('foo', pa.int32())]) + + writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V5 + writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V5 + + with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'): + writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V5 + writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V5 + + with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'): + writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V4 + writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V4 + + with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'): + with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'): + writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V4 + writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V4 + + +def test_stream_read_all(stream_fixture): + batches = stream_fixture.write_batches() + file_contents = pa.BufferReader(stream_fixture.get_source()) + reader = pa.ipc.open_stream(file_contents) + + result = reader.read_all() + expected = pa.Table.from_batches(batches) + assert result.equals(expected) + + +@pytest.mark.pandas +def test_stream_read_pandas(stream_fixture): + frames = [batch.to_pandas() for batch in stream_fixture.write_batches()] + file_contents = stream_fixture.get_source() + reader = pa.ipc.open_stream(file_contents) + result = reader.read_pandas() + + expected = pd.concat(frames).reset_index(drop=True) + assert_frame_equal(result, expected) + + +@pytest.fixture +def example_messages(stream_fixture): + batches = stream_fixture.write_batches() + file_contents = stream_fixture.get_source() + buf_reader = pa.BufferReader(file_contents) + reader = pa.MessageReader.open_stream(buf_reader) + return batches, list(reader) + + +def test_message_ctors_no_segfault(): + with pytest.raises(TypeError): + repr(pa.Message()) + + with pytest.raises(TypeError): + repr(pa.MessageReader()) + + +def test_message_reader(example_messages): + _, messages = example_messages + + assert len(messages) == 6 + assert messages[0].type == 'schema' + assert isinstance(messages[0].metadata, pa.Buffer) + assert isinstance(messages[0].body, pa.Buffer) + assert messages[0].metadata_version == pa.MetadataVersion.V5 + + for msg in messages[1:]: + assert msg.type == 'record batch' + assert isinstance(msg.metadata, pa.Buffer) + assert isinstance(msg.body, pa.Buffer) + assert msg.metadata_version == pa.MetadataVersion.V5 + + +def test_message_serialize_read_message(example_messages): + _, messages = example_messages + + msg = messages[0] + buf = msg.serialize() + reader = pa.BufferReader(buf.to_pybytes() * 2) + + restored = pa.ipc.read_message(buf) + restored2 = pa.ipc.read_message(reader) + restored3 = pa.ipc.read_message(buf.to_pybytes()) + restored4 = pa.ipc.read_message(reader) + + assert msg.equals(restored) + assert msg.equals(restored2) + assert msg.equals(restored3) + assert msg.equals(restored4) + + with pytest.raises(pa.ArrowInvalid, match="Corrupted message"): + pa.ipc.read_message(pa.BufferReader(b'ab')) + + with pytest.raises(EOFError): + pa.ipc.read_message(reader) + + +@pytest.mark.gzip +def test_message_read_from_compressed(example_messages): + # Part of ARROW-5910 + _, messages = example_messages + for message in messages: + raw_out = pa.BufferOutputStream() + with pa.output_stream(raw_out, compression='gzip') as compressed_out: + message.serialize_to(compressed_out) + + compressed_buf = raw_out.getvalue() + + result = pa.ipc.read_message(pa.input_stream(compressed_buf, + compression='gzip')) + assert result.equals(message) + + +def test_message_read_record_batch(example_messages): + batches, messages = example_messages + + for batch, message in zip(batches, messages[1:]): + read_batch = pa.ipc.read_record_batch(message, batch.schema) + assert read_batch.equals(batch) + + +def test_read_record_batch_on_stream_error_message(): + # ARROW-5374 + batch = pa.record_batch([pa.array([b"foo"], type=pa.utf8())], + names=['strs']) + stream = pa.BufferOutputStream() + with pa.ipc.new_stream(stream, batch.schema) as writer: + writer.write_batch(batch) + buf = stream.getvalue() + with pytest.raises(IOError, + match="type record batch but got schema"): + pa.ipc.read_record_batch(buf, batch.schema) + + +# ---------------------------------------------------------------------- +# Socket streaming testa + + +class StreamReaderServer(threading.Thread): + + def init(self, do_read_all): + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock.bind(('127.0.0.1', 0)) + self._sock.listen(1) + host, port = self._sock.getsockname() + self._do_read_all = do_read_all + self._schema = None + self._batches = [] + self._table = None + return port + + def run(self): + connection, client_address = self._sock.accept() + try: + source = connection.makefile(mode='rb') + reader = pa.ipc.open_stream(source) + self._schema = reader.schema + if self._do_read_all: + self._table = reader.read_all() + else: + for i, batch in enumerate(reader): + self._batches.append(batch) + finally: + connection.close() + + def get_result(self): + return(self._schema, self._table if self._do_read_all + else self._batches) + + +class SocketStreamFixture(IpcFixture): + + def __init__(self): + # XXX(wesm): test will decide when to start socket server. This should + # probably be refactored + pass + + def start_server(self, do_read_all): + self._server = StreamReaderServer() + port = self._server.init(do_read_all) + self._server.start() + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock.connect(('127.0.0.1', port)) + self.sink = self.get_sink() + + def stop_and_get_result(self): + import struct + self.sink.write(struct.pack('Q', 0)) + self.sink.flush() + self._sock.close() + self._server.join() + return self._server.get_result() + + def get_sink(self): + return self._sock.makefile(mode='wb') + + def _get_writer(self, sink, schema): + return pa.RecordBatchStreamWriter(sink, schema) + + +@pytest.fixture +def socket_fixture(): + return SocketStreamFixture() + + +def test_socket_simple_roundtrip(socket_fixture): + socket_fixture.start_server(do_read_all=False) + writer_batches = socket_fixture.write_batches() + reader_schema, reader_batches = socket_fixture.stop_and_get_result() + + assert reader_schema.equals(writer_batches[0].schema) + assert len(reader_batches) == len(writer_batches) + for i, batch in enumerate(writer_batches): + assert reader_batches[i].equals(batch) + + +def test_socket_read_all(socket_fixture): + socket_fixture.start_server(do_read_all=True) + writer_batches = socket_fixture.write_batches() + _, result = socket_fixture.stop_and_get_result() + + expected = pa.Table.from_batches(writer_batches) + assert result.equals(expected) + + +# ---------------------------------------------------------------------- +# Miscellaneous IPC tests + +@pytest.mark.pandas +def test_ipc_file_stream_has_eos(): + # ARROW-5395 + df = pd.DataFrame({'foo': [1.5]}) + batch = pa.RecordBatch.from_pandas(df) + sink = pa.BufferOutputStream() + write_file(batch, sink) + buffer = sink.getvalue() + + # skip the file magic + reader = pa.ipc.open_stream(buffer[8:]) + + # will fail if encounters footer data instead of eos + rdf = reader.read_pandas() + + assert_frame_equal(df, rdf) + + +@pytest.mark.pandas +def test_ipc_zero_copy_numpy(): + df = pd.DataFrame({'foo': [1.5]}) + + batch = pa.RecordBatch.from_pandas(df) + sink = pa.BufferOutputStream() + write_file(batch, sink) + buffer = sink.getvalue() + reader = pa.BufferReader(buffer) + + batches = read_file(reader) + + data = batches[0].to_pandas() + rdf = pd.DataFrame(data) + assert_frame_equal(df, rdf) + + +def test_ipc_stream_no_batches(): + # ARROW-2307 + table = pa.Table.from_arrays([pa.array([1, 2, 3, 4]), + pa.array(['foo', 'bar', 'baz', 'qux'])], + names=['a', 'b']) + + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, table.schema): + pass + + source = sink.getvalue() + with pa.ipc.open_stream(source) as reader: + result = reader.read_all() + + assert result.schema.equals(table.schema) + assert len(result) == 0 + + +@pytest.mark.pandas +def test_get_record_batch_size(): + N = 10 + itemsize = 8 + df = pd.DataFrame({'foo': np.random.randn(N)}) + + batch = pa.RecordBatch.from_pandas(df) + assert pa.ipc.get_record_batch_size(batch) > (N * itemsize) + + +@pytest.mark.pandas +def _check_serialize_pandas_round_trip(df, use_threads=False): + buf = pa.serialize_pandas(df, nthreads=2 if use_threads else 1) + result = pa.deserialize_pandas(buf, use_threads=use_threads) + assert_frame_equal(result, df) + + +@pytest.mark.pandas +def test_pandas_serialize_round_trip(): + index = pd.Index([1, 2, 3], name='my_index') + columns = ['foo', 'bar'] + df = pd.DataFrame( + {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')}, + index=index, columns=columns + ) + _check_serialize_pandas_round_trip(df) + + +@pytest.mark.pandas +def test_pandas_serialize_round_trip_nthreads(): + index = pd.Index([1, 2, 3], name='my_index') + columns = ['foo', 'bar'] + df = pd.DataFrame( + {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')}, + index=index, columns=columns + ) + _check_serialize_pandas_round_trip(df, use_threads=True) + + +@pytest.mark.pandas +def test_pandas_serialize_round_trip_multi_index(): + index1 = pd.Index([1, 2, 3], name='level_1') + index2 = pd.Index(list('def'), name=None) + index = pd.MultiIndex.from_arrays([index1, index2]) + + columns = ['foo', 'bar'] + df = pd.DataFrame( + {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')}, + index=index, + columns=columns, + ) + _check_serialize_pandas_round_trip(df) + + +@pytest.mark.pandas +def test_serialize_pandas_empty_dataframe(): + df = pd.DataFrame() + _check_serialize_pandas_round_trip(df) + + +@pytest.mark.pandas +def test_pandas_serialize_round_trip_not_string_columns(): + df = pd.DataFrame(list(zip([1.5, 1.6, 1.7], 'abc'))) + buf = pa.serialize_pandas(df) + result = pa.deserialize_pandas(buf) + assert_frame_equal(result, df) + + +@pytest.mark.pandas +def test_serialize_pandas_no_preserve_index(): + df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3]) + expected = pd.DataFrame({'a': [1, 2, 3]}) + + buf = pa.serialize_pandas(df, preserve_index=False) + result = pa.deserialize_pandas(buf) + assert_frame_equal(result, expected) + + buf = pa.serialize_pandas(df, preserve_index=True) + result = pa.deserialize_pandas(buf) + assert_frame_equal(result, df) + + +@pytest.mark.pandas +@pytest.mark.filterwarnings("ignore:'pyarrow:FutureWarning") +def test_serialize_with_pandas_objects(): + df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3]) + s = pd.Series([1, 2, 3, 4]) + + data = { + 'a_series': df['a'], + 'a_frame': df, + 's_series': s + } + + serialized = pa.serialize(data).to_buffer() + deserialized = pa.deserialize(serialized) + assert_frame_equal(deserialized['a_frame'], df) + + assert_series_equal(deserialized['a_series'], df['a']) + assert deserialized['a_series'].name == 'a' + + assert_series_equal(deserialized['s_series'], s) + assert deserialized['s_series'].name is None + + +@pytest.mark.pandas +def test_schema_batch_serialize_methods(): + nrows = 5 + df = pd.DataFrame({ + 'one': np.random.randn(nrows), + 'two': ['foo', np.nan, 'bar', 'bazbaz', 'qux']}) + batch = pa.RecordBatch.from_pandas(df) + + s_schema = batch.schema.serialize() + s_batch = batch.serialize() + + recons_schema = pa.ipc.read_schema(s_schema) + recons_batch = pa.ipc.read_record_batch(s_batch, recons_schema) + assert recons_batch.equals(batch) + + +def test_schema_serialization_with_metadata(): + field_metadata = {b'foo': b'bar', b'kind': b'field'} + schema_metadata = {b'foo': b'bar', b'kind': b'schema'} + + f0 = pa.field('a', pa.int8()) + f1 = pa.field('b', pa.string(), metadata=field_metadata) + + schema = pa.schema([f0, f1], metadata=schema_metadata) + + s_schema = schema.serialize() + recons_schema = pa.ipc.read_schema(s_schema) + + assert recons_schema.equals(schema) + assert recons_schema.metadata == schema_metadata + assert recons_schema[0].metadata is None + assert recons_schema[1].metadata == field_metadata + + +def test_deprecated_pyarrow_ns_apis(): + table = pa.table([pa.array([1, 2, 3, 4])], names=['a']) + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, table.schema) as writer: + writer.write(table) + + with pytest.warns(FutureWarning, + match="please use pyarrow.ipc.open_stream"): + pa.open_stream(sink.getvalue()) + + sink = pa.BufferOutputStream() + with pa.ipc.new_file(sink, table.schema) as writer: + writer.write(table) + with pytest.warns(FutureWarning, match="please use pyarrow.ipc.open_file"): + pa.open_file(sink.getvalue()) + + +def write_file(batch, sink): + with pa.ipc.new_file(sink, batch.schema) as writer: + writer.write_batch(batch) + + +def read_file(source): + with pa.ipc.open_file(source) as reader: + return [reader.get_batch(i) for i in range(reader.num_record_batches)] + + +def test_write_empty_ipc_file(): + # ARROW-3894: IPC file was not being properly initialized when no record + # batches are being written + schema = pa.schema([('field', pa.int64())]) + + sink = pa.BufferOutputStream() + with pa.ipc.new_file(sink, schema): + pass + + buf = sink.getvalue() + with pa.RecordBatchFileReader(pa.BufferReader(buf)) as reader: + table = reader.read_all() + assert len(table) == 0 + assert table.schema.equals(schema) + + +def test_py_record_batch_reader(): + def make_schema(): + return pa.schema([('field', pa.int64())]) + + def make_batches(): + schema = make_schema() + batch1 = pa.record_batch([[1, 2, 3]], schema=schema) + batch2 = pa.record_batch([[4, 5]], schema=schema) + return [batch1, batch2] + + # With iterable + batches = UserList(make_batches()) # weakrefable + wr = weakref.ref(batches) + + with pa.ipc.RecordBatchReader.from_batches(make_schema(), + batches) as reader: + batches = None + assert wr() is not None + assert list(reader) == make_batches() + assert wr() is None + + # With iterator + batches = iter(UserList(make_batches())) # weakrefable + wr = weakref.ref(batches) + + with pa.ipc.RecordBatchReader.from_batches(make_schema(), + batches) as reader: + batches = None + assert wr() is not None + assert list(reader) == make_batches() + assert wr() is None diff --git a/src/arrow/python/pyarrow/tests/test_json.py b/src/arrow/python/pyarrow/tests/test_json.py new file mode 100644 index 000000000..6ce584e51 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_json.py @@ -0,0 +1,310 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import OrderedDict +import io +import itertools +import json +import string +import unittest + +import numpy as np +import pytest + +import pyarrow as pa +from pyarrow.json import read_json, ReadOptions, ParseOptions + + +def generate_col_names(): + # 'a', 'b'... 'z', then 'aa', 'ab'... + letters = string.ascii_lowercase + yield from letters + for first in letters: + for second in letters: + yield first + second + + +def make_random_json(num_cols=2, num_rows=10, linesep='\r\n'): + arr = np.random.RandomState(42).randint(0, 1000, size=(num_cols, num_rows)) + col_names = list(itertools.islice(generate_col_names(), num_cols)) + lines = [] + for row in arr.T: + json_obj = OrderedDict([(k, int(v)) for (k, v) in zip(col_names, row)]) + lines.append(json.dumps(json_obj)) + data = linesep.join(lines).encode() + columns = [pa.array(col, type=pa.int64()) for col in arr] + expected = pa.Table.from_arrays(columns, col_names) + return data, expected + + +def test_read_options(): + cls = ReadOptions + opts = cls() + + assert opts.block_size > 0 + opts.block_size = 12345 + assert opts.block_size == 12345 + + assert opts.use_threads is True + opts.use_threads = False + assert opts.use_threads is False + + opts = cls(block_size=1234, use_threads=False) + assert opts.block_size == 1234 + assert opts.use_threads is False + + +def test_parse_options(): + cls = ParseOptions + opts = cls() + assert opts.newlines_in_values is False + assert opts.explicit_schema is None + + opts.newlines_in_values = True + assert opts.newlines_in_values is True + + schema = pa.schema([pa.field('foo', pa.int32())]) + opts.explicit_schema = schema + assert opts.explicit_schema == schema + + assert opts.unexpected_field_behavior == "infer" + for value in ["ignore", "error", "infer"]: + opts.unexpected_field_behavior = value + assert opts.unexpected_field_behavior == value + + with pytest.raises(ValueError): + opts.unexpected_field_behavior = "invalid-value" + + +class BaseTestJSONRead: + + def read_bytes(self, b, **kwargs): + return self.read_json(pa.py_buffer(b), **kwargs) + + def check_names(self, table, names): + assert table.num_columns == len(names) + assert [c.name for c in table.columns] == names + + def test_file_object(self): + data = b'{"a": 1, "b": 2}\n' + expected_data = {'a': [1], 'b': [2]} + bio = io.BytesIO(data) + table = self.read_json(bio) + assert table.to_pydict() == expected_data + # Text files not allowed + sio = io.StringIO(data.decode()) + with pytest.raises(TypeError): + self.read_json(sio) + + def test_block_sizes(self): + rows = b'{"a": 1}\n{"a": 2}\n{"a": 3}' + read_options = ReadOptions() + parse_options = ParseOptions() + + for data in [rows, rows + b'\n']: + for newlines_in_values in [False, True]: + parse_options.newlines_in_values = newlines_in_values + read_options.block_size = 4 + with pytest.raises(ValueError, + match="try to increase block size"): + self.read_bytes(data, read_options=read_options, + parse_options=parse_options) + + # Validate reader behavior with various block sizes. + # There used to be bugs in this area. + for block_size in range(9, 20): + read_options.block_size = block_size + table = self.read_bytes(data, read_options=read_options, + parse_options=parse_options) + assert table.to_pydict() == {'a': [1, 2, 3]} + + def test_no_newline_at_end(self): + rows = b'{"a": 1,"b": 2, "c": 3}\n{"a": 4,"b": 5, "c": 6}' + table = self.read_bytes(rows) + assert table.to_pydict() == { + 'a': [1, 4], + 'b': [2, 5], + 'c': [3, 6], + } + + def test_simple_ints(self): + # Infer integer columns + rows = b'{"a": 1,"b": 2, "c": 3}\n{"a": 4,"b": 5, "c": 6}\n' + table = self.read_bytes(rows) + schema = pa.schema([('a', pa.int64()), + ('b', pa.int64()), + ('c', pa.int64())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [1, 4], + 'b': [2, 5], + 'c': [3, 6], + } + + def test_simple_varied(self): + # Infer various kinds of data + rows = (b'{"a": 1,"b": 2, "c": "3", "d": false}\n' + b'{"a": 4.0, "b": -5, "c": "foo", "d": true}\n') + table = self.read_bytes(rows) + schema = pa.schema([('a', pa.float64()), + ('b', pa.int64()), + ('c', pa.string()), + ('d', pa.bool_())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [1.0, 4.0], + 'b': [2, -5], + 'c': ["3", "foo"], + 'd': [False, True], + } + + def test_simple_nulls(self): + # Infer various kinds of data, with nulls + rows = (b'{"a": 1, "b": 2, "c": null, "d": null, "e": null}\n' + b'{"a": null, "b": -5, "c": "foo", "d": null, "e": true}\n' + b'{"a": 4.5, "b": null, "c": "nan", "d": null,"e": false}\n') + table = self.read_bytes(rows) + schema = pa.schema([('a', pa.float64()), + ('b', pa.int64()), + ('c', pa.string()), + ('d', pa.null()), + ('e', pa.bool_())]) + assert table.schema == schema + assert table.to_pydict() == { + 'a': [1.0, None, 4.5], + 'b': [2, -5, None], + 'c': [None, "foo", "nan"], + 'd': [None, None, None], + 'e': [None, True, False], + } + + def test_empty_lists(self): + # ARROW-10955: Infer list(null) + rows = b'{"a": []}' + table = self.read_bytes(rows) + schema = pa.schema([('a', pa.list_(pa.null()))]) + assert table.schema == schema + assert table.to_pydict() == {'a': [[]]} + + def test_empty_rows(self): + rows = b'{}\n{}\n' + table = self.read_bytes(rows) + schema = pa.schema([]) + assert table.schema == schema + assert table.num_columns == 0 + assert table.num_rows == 2 + + def test_reconcile_accross_blocks(self): + # ARROW-12065: reconciling inferred types accross blocks + first_row = b'{ }\n' + read_options = ReadOptions(block_size=len(first_row)) + for next_rows, expected_pylist in [ + (b'{"a": 0}', [None, 0]), + (b'{"a": []}', [None, []]), + (b'{"a": []}\n{"a": [[1]]}', [None, [], [[1]]]), + (b'{"a": {}}', [None, {}]), + (b'{"a": {}}\n{"a": {"b": {"c": 1}}}', + [None, {"b": None}, {"b": {"c": 1}}]), + ]: + table = self.read_bytes(first_row + next_rows, + read_options=read_options) + expected = {"a": expected_pylist} + assert table.to_pydict() == expected + # Check that the issue was exercised + assert table.column("a").num_chunks > 1 + + def test_explicit_schema_with_unexpected_behaviour(self): + # infer by default + rows = (b'{"foo": "bar", "num": 0}\n' + b'{"foo": "baz", "num": 1}\n') + schema = pa.schema([ + ('foo', pa.binary()) + ]) + + opts = ParseOptions(explicit_schema=schema) + table = self.read_bytes(rows, parse_options=opts) + assert table.schema == pa.schema([ + ('foo', pa.binary()), + ('num', pa.int64()) + ]) + assert table.to_pydict() == { + 'foo': [b'bar', b'baz'], + 'num': [0, 1], + } + + # ignore the unexpected fields + opts = ParseOptions(explicit_schema=schema, + unexpected_field_behavior="ignore") + table = self.read_bytes(rows, parse_options=opts) + assert table.schema == pa.schema([ + ('foo', pa.binary()), + ]) + assert table.to_pydict() == { + 'foo': [b'bar', b'baz'], + } + + # raise error + opts = ParseOptions(explicit_schema=schema, + unexpected_field_behavior="error") + with pytest.raises(pa.ArrowInvalid, + match="JSON parse error: unexpected field"): + self.read_bytes(rows, parse_options=opts) + + def test_small_random_json(self): + data, expected = make_random_json(num_cols=2, num_rows=10) + table = self.read_bytes(data) + assert table.schema == expected.schema + assert table.equals(expected) + assert table.to_pydict() == expected.to_pydict() + + def test_stress_block_sizes(self): + # Test a number of small block sizes to stress block stitching + data_base, expected = make_random_json(num_cols=2, num_rows=100) + read_options = ReadOptions() + parse_options = ParseOptions() + + for data in [data_base, data_base.rstrip(b'\r\n')]: + for newlines_in_values in [False, True]: + parse_options.newlines_in_values = newlines_in_values + for block_size in [22, 23, 37]: + read_options.block_size = block_size + table = self.read_bytes(data, read_options=read_options, + parse_options=parse_options) + assert table.schema == expected.schema + if not table.equals(expected): + # Better error output + assert table.to_pydict() == expected.to_pydict() + + +class TestSerialJSONRead(BaseTestJSONRead, unittest.TestCase): + + def read_json(self, *args, **kwargs): + read_options = kwargs.setdefault('read_options', ReadOptions()) + read_options.use_threads = False + table = read_json(*args, **kwargs) + table.validate(full=True) + return table + + +class TestParallelJSONRead(BaseTestJSONRead, unittest.TestCase): + + def read_json(self, *args, **kwargs): + read_options = kwargs.setdefault('read_options', ReadOptions()) + read_options.use_threads = True + table = read_json(*args, **kwargs) + table.validate(full=True) + return table diff --git a/src/arrow/python/pyarrow/tests/test_jvm.py b/src/arrow/python/pyarrow/tests/test_jvm.py new file mode 100644 index 000000000..c5996f921 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_jvm.py @@ -0,0 +1,433 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import os +import pyarrow as pa +import pyarrow.jvm as pa_jvm +import pytest +import sys +import xml.etree.ElementTree as ET + + +jpype = pytest.importorskip("jpype") + + +@pytest.fixture(scope="session") +def root_allocator(): + # This test requires Arrow Java to be built in the same source tree + try: + arrow_dir = os.environ["ARROW_SOURCE_DIR"] + except KeyError: + arrow_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..') + pom_path = os.path.join(arrow_dir, 'java', 'pom.xml') + tree = ET.parse(pom_path) + version = tree.getroot().find( + 'POM:version', + namespaces={ + 'POM': 'http://maven.apache.org/POM/4.0.0' + }).text + jar_path = os.path.join( + arrow_dir, 'java', 'tools', 'target', + 'arrow-tools-{}-jar-with-dependencies.jar'.format(version)) + jar_path = os.getenv("ARROW_TOOLS_JAR", jar_path) + kwargs = {} + # This will be the default behaviour in jpype 0.8+ + kwargs['convertStrings'] = False + jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path, + **kwargs) + return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize) + + +def test_jvm_buffer(root_allocator): + # Create a Java buffer + jvm_buffer = root_allocator.buffer(8) + for i in range(8): + jvm_buffer.setByte(i, 8 - i) + + orig_refcnt = jvm_buffer.refCnt() + + # Convert to Python + buf = pa_jvm.jvm_buffer(jvm_buffer) + + # Check its content + assert buf.to_pybytes() == b'\x08\x07\x06\x05\x04\x03\x02\x01' + + # Check Java buffer lifetime is tied to PyArrow buffer lifetime + assert jvm_buffer.refCnt() == orig_refcnt + 1 + del buf + assert jvm_buffer.refCnt() == orig_refcnt + + +def test_jvm_buffer_released(root_allocator): + import jpype.imports # noqa + from java.lang import IllegalArgumentException + + jvm_buffer = root_allocator.buffer(8) + jvm_buffer.release() + + with pytest.raises(IllegalArgumentException): + pa_jvm.jvm_buffer(jvm_buffer) + + +def _jvm_field(jvm_spec): + om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')() + pojo_Field = jpype.JClass('org.apache.arrow.vector.types.pojo.Field') + return om.readValue(jvm_spec, pojo_Field) + + +def _jvm_schema(jvm_spec, metadata=None): + field = _jvm_field(jvm_spec) + schema_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Schema') + fields = jpype.JClass('java.util.ArrayList')() + fields.add(field) + if metadata: + dct = jpype.JClass('java.util.HashMap')() + for k, v in metadata.items(): + dct.put(k, v) + return schema_cls(fields, dct) + else: + return schema_cls(fields) + + +# In the following, we use the JSON serialization of the Field objects in Java. +# This ensures that we neither rely on the exact mechanics on how to construct +# them using Java code as well as enables us to define them as parameters +# without to invoke the JVM. +# +# The specifications were created using: +# +# om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')() +# field = … # Code to instantiate the field +# jvm_spec = om.writeValueAsString(field) +@pytest.mark.parametrize('pa_type,jvm_spec', [ + (pa.null(), '{"name":"null"}'), + (pa.bool_(), '{"name":"bool"}'), + (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'), + (pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'), + (pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'), + (pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'), + (pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'), + (pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'), + (pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'), + (pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'), + (pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'), + (pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'), + (pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'), + (pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'), + (pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'), + (pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'), + (pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'), + (pa.timestamp('s'), '{"name":"timestamp","unit":"SECOND",' + '"timezone":null}'), + (pa.timestamp('ms'), '{"name":"timestamp","unit":"MILLISECOND",' + '"timezone":null}'), + (pa.timestamp('us'), '{"name":"timestamp","unit":"MICROSECOND",' + '"timezone":null}'), + (pa.timestamp('ns'), '{"name":"timestamp","unit":"NANOSECOND",' + '"timezone":null}'), + (pa.timestamp('ns', tz='UTC'), '{"name":"timestamp","unit":"NANOSECOND"' + ',"timezone":"UTC"}'), + (pa.timestamp('ns', tz='Europe/Paris'), '{"name":"timestamp",' + '"unit":"NANOSECOND","timezone":"Europe/Paris"}'), + (pa.date32(), '{"name":"date","unit":"DAY"}'), + (pa.date64(), '{"name":"date","unit":"MILLISECOND"}'), + (pa.decimal128(19, 4), '{"name":"decimal","precision":19,"scale":4}'), + (pa.string(), '{"name":"utf8"}'), + (pa.binary(), '{"name":"binary"}'), + (pa.binary(10), '{"name":"fixedsizebinary","byteWidth":10}'), + # TODO(ARROW-2609): complex types that have children + # pa.list_(pa.int32()), + # pa.struct([pa.field('a', pa.int32()), + # pa.field('b', pa.int8()), + # pa.field('c', pa.string())]), + # pa.union([pa.field('a', pa.binary(10)), + # pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE), + # pa.union([pa.field('a', pa.binary(10)), + # pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE), + # TODO: DictionaryType requires a vector in the type + # pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])), +]) +@pytest.mark.parametrize('nullable', [True, False]) +def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable): + if pa_type == pa.null() and not nullable: + return + spec = { + 'name': 'field_name', + 'nullable': nullable, + 'type': json.loads(jvm_spec), + # TODO: This needs to be set for complex types + 'children': [] + } + jvm_field = _jvm_field(json.dumps(spec)) + result = pa_jvm.field(jvm_field) + expected_field = pa.field('field_name', pa_type, nullable=nullable) + assert result == expected_field + + jvm_schema = _jvm_schema(json.dumps(spec)) + result = pa_jvm.schema(jvm_schema) + assert result == pa.schema([expected_field]) + + # Schema with custom metadata + jvm_schema = _jvm_schema(json.dumps(spec), {'meta': 'data'}) + result = pa_jvm.schema(jvm_schema) + assert result == pa.schema([expected_field], {'meta': 'data'}) + + # Schema with custom field metadata + spec['metadata'] = [{'key': 'field meta', 'value': 'field data'}] + jvm_schema = _jvm_schema(json.dumps(spec)) + result = pa_jvm.schema(jvm_schema) + expected_field = expected_field.with_metadata( + {'field meta': 'field data'}) + assert result == pa.schema([expected_field]) + + +# These test parameters mostly use an integer range as an input as this is +# often the only type that is understood by both Python and Java +# implementations of Arrow. +@pytest.mark.parametrize('pa_type,py_data,jvm_type', [ + (pa.bool_(), [True, False, True, True], 'BitVector'), + (pa.uint8(), list(range(128)), 'UInt1Vector'), + (pa.uint16(), list(range(128)), 'UInt2Vector'), + (pa.int32(), list(range(128)), 'IntVector'), + (pa.int64(), list(range(128)), 'BigIntVector'), + (pa.float32(), list(range(128)), 'Float4Vector'), + (pa.float64(), list(range(128)), 'Float8Vector'), + (pa.timestamp('s'), list(range(128)), 'TimeStampSecVector'), + (pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector'), + (pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector'), + (pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector'), + # TODO(ARROW-2605): These types miss a conversion from pure Python objects + # * pa.time32('s') + # * pa.time32('ms') + # * pa.time64('us') + # * pa.time64('ns') + (pa.date32(), list(range(128)), 'DateDayVector'), + (pa.date64(), list(range(128)), 'DateMilliVector'), + # TODO(ARROW-2606): pa.decimal128(19, 4) +]) +def test_jvm_array(root_allocator, pa_type, py_data, jvm_type): + # Create vector + cls = "org.apache.arrow.vector.{}".format(jvm_type) + jvm_vector = jpype.JClass(cls)("vector", root_allocator) + jvm_vector.allocateNew(len(py_data)) + for i, val in enumerate(py_data): + # char and int are ambiguous overloads for these two setSafe calls + if jvm_type in {'UInt1Vector', 'UInt2Vector'}: + val = jpype.JInt(val) + jvm_vector.setSafe(i, val) + jvm_vector.setValueCount(len(py_data)) + + py_array = pa.array(py_data, type=pa_type) + jvm_array = pa_jvm.array(jvm_vector) + + assert py_array.equals(jvm_array) + + +def test_jvm_array_empty(root_allocator): + cls = "org.apache.arrow.vector.{}".format('IntVector') + jvm_vector = jpype.JClass(cls)("vector", root_allocator) + jvm_vector.allocateNew() + jvm_array = pa_jvm.array(jvm_vector) + assert len(jvm_array) == 0 + assert jvm_array.type == pa.int32() + + +# These test parameters mostly use an integer range as an input as this is +# often the only type that is understood by both Python and Java +# implementations of Arrow. +@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [ + # TODO: null + (pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'), + ( + pa.uint8(), + list(range(128)), + 'UInt1Vector', + '{"name":"int","bitWidth":8,"isSigned":false}' + ), + ( + pa.uint16(), + list(range(128)), + 'UInt2Vector', + '{"name":"int","bitWidth":16,"isSigned":false}' + ), + ( + pa.uint32(), + list(range(128)), + 'UInt4Vector', + '{"name":"int","bitWidth":32,"isSigned":false}' + ), + ( + pa.uint64(), + list(range(128)), + 'UInt8Vector', + '{"name":"int","bitWidth":64,"isSigned":false}' + ), + ( + pa.int8(), + list(range(128)), + 'TinyIntVector', + '{"name":"int","bitWidth":8,"isSigned":true}' + ), + ( + pa.int16(), + list(range(128)), + 'SmallIntVector', + '{"name":"int","bitWidth":16,"isSigned":true}' + ), + ( + pa.int32(), + list(range(128)), + 'IntVector', + '{"name":"int","bitWidth":32,"isSigned":true}' + ), + ( + pa.int64(), + list(range(128)), + 'BigIntVector', + '{"name":"int","bitWidth":64,"isSigned":true}' + ), + # TODO: float16 + ( + pa.float32(), + list(range(128)), + 'Float4Vector', + '{"name":"floatingpoint","precision":"SINGLE"}' + ), + ( + pa.float64(), + list(range(128)), + 'Float8Vector', + '{"name":"floatingpoint","precision":"DOUBLE"}' + ), + ( + pa.timestamp('s'), + list(range(128)), + 'TimeStampSecVector', + '{"name":"timestamp","unit":"SECOND","timezone":null}' + ), + ( + pa.timestamp('ms'), + list(range(128)), + 'TimeStampMilliVector', + '{"name":"timestamp","unit":"MILLISECOND","timezone":null}' + ), + ( + pa.timestamp('us'), + list(range(128)), + 'TimeStampMicroVector', + '{"name":"timestamp","unit":"MICROSECOND","timezone":null}' + ), + ( + pa.timestamp('ns'), + list(range(128)), + 'TimeStampNanoVector', + '{"name":"timestamp","unit":"NANOSECOND","timezone":null}' + ), + # TODO(ARROW-2605): These types miss a conversion from pure Python objects + # * pa.time32('s') + # * pa.time32('ms') + # * pa.time64('us') + # * pa.time64('ns') + ( + pa.date32(), + list(range(128)), + 'DateDayVector', + '{"name":"date","unit":"DAY"}' + ), + ( + pa.date64(), + list(range(128)), + 'DateMilliVector', + '{"name":"date","unit":"MILLISECOND"}' + ), + # TODO(ARROW-2606): pa.decimal128(19, 4) +]) +def test_jvm_record_batch(root_allocator, pa_type, py_data, jvm_type, + jvm_spec): + # Create vector + cls = "org.apache.arrow.vector.{}".format(jvm_type) + jvm_vector = jpype.JClass(cls)("vector", root_allocator) + jvm_vector.allocateNew(len(py_data)) + for i, val in enumerate(py_data): + if jvm_type in {'UInt1Vector', 'UInt2Vector'}: + val = jpype.JInt(val) + jvm_vector.setSafe(i, val) + jvm_vector.setValueCount(len(py_data)) + + # Create field + spec = { + 'name': 'field_name', + 'nullable': False, + 'type': json.loads(jvm_spec), + # TODO: This needs to be set for complex types + 'children': [] + } + jvm_field = _jvm_field(json.dumps(spec)) + + # Create VectorSchemaRoot + jvm_fields = jpype.JClass('java.util.ArrayList')() + jvm_fields.add(jvm_field) + jvm_vectors = jpype.JClass('java.util.ArrayList')() + jvm_vectors.add(jvm_vector) + jvm_vsr = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot') + jvm_vsr = jvm_vsr(jvm_fields, jvm_vectors, len(py_data)) + + py_record_batch = pa.RecordBatch.from_arrays( + [pa.array(py_data, type=pa_type)], + ['col'] + ) + jvm_record_batch = pa_jvm.record_batch(jvm_vsr) + + assert py_record_batch.equals(jvm_record_batch) + + +def _string_to_varchar_holder(ra, string): + nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder" + holder = jpype.JClass(nvch_cls)() + if string is None: + holder.isSet = 0 + else: + holder.isSet = 1 + value = jpype.JClass("java.lang.String")("string") + std_charsets = jpype.JClass("java.nio.charset.StandardCharsets") + bytes_ = value.getBytes(std_charsets.UTF_8) + holder.buffer = ra.buffer(len(bytes_)) + holder.buffer.setBytes(0, bytes_, 0, len(bytes_)) + holder.start = 0 + holder.end = len(bytes_) + return holder + + +# TODO(ARROW-2607) +@pytest.mark.xfail(reason="from_buffers is only supported for " + "primitive arrays yet") +def test_jvm_string_array(root_allocator): + data = ["string", None, "töst"] + cls = "org.apache.arrow.vector.VarCharVector" + jvm_vector = jpype.JClass(cls)("vector", root_allocator) + jvm_vector.allocateNew() + + for i, string in enumerate(data): + holder = _string_to_varchar_holder(root_allocator, "string") + jvm_vector.setSafe(i, holder) + jvm_vector.setValueCount(i + 1) + + py_array = pa.array(data, type=pa.string()) + jvm_array = pa_jvm.array(jvm_vector) + + assert py_array.equals(jvm_array) diff --git a/src/arrow/python/pyarrow/tests/test_memory.py b/src/arrow/python/pyarrow/tests/test_memory.py new file mode 100644 index 000000000..b8dd7344f --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_memory.py @@ -0,0 +1,161 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import contextlib +import os +import subprocess +import sys +import weakref + +import pyarrow as pa + + +possible_backends = ["system", "jemalloc", "mimalloc"] + +should_have_jemalloc = sys.platform == "linux" +should_have_mimalloc = sys.platform == "win32" + + +@contextlib.contextmanager +def allocate_bytes(pool, nbytes): + """ + Temporarily allocate *nbytes* from the given *pool*. + """ + arr = pa.array([b"x" * nbytes], type=pa.binary(), memory_pool=pool) + # Fetch the values buffer from the varbinary array and release the rest, + # to get the desired allocation amount + buf = arr.buffers()[2] + arr = None + assert len(buf) == nbytes + try: + yield + finally: + buf = None + + +def check_allocated_bytes(pool): + """ + Check allocation stats on *pool*. + """ + allocated_before = pool.bytes_allocated() + max_mem_before = pool.max_memory() + with allocate_bytes(pool, 512): + assert pool.bytes_allocated() == allocated_before + 512 + new_max_memory = pool.max_memory() + assert pool.max_memory() >= max_mem_before + assert pool.bytes_allocated() == allocated_before + assert pool.max_memory() == new_max_memory + + +def test_default_allocated_bytes(): + pool = pa.default_memory_pool() + with allocate_bytes(pool, 1024): + check_allocated_bytes(pool) + assert pool.bytes_allocated() == pa.total_allocated_bytes() + + +def test_proxy_memory_pool(): + pool = pa.proxy_memory_pool(pa.default_memory_pool()) + check_allocated_bytes(pool) + wr = weakref.ref(pool) + assert wr() is not None + del pool + assert wr() is None + + +def test_logging_memory_pool(capfd): + pool = pa.logging_memory_pool(pa.default_memory_pool()) + check_allocated_bytes(pool) + out, err = capfd.readouterr() + assert err == "" + assert out.count("Allocate:") > 0 + assert out.count("Allocate:") == out.count("Free:") + + +def test_set_memory_pool(): + old_pool = pa.default_memory_pool() + pool = pa.proxy_memory_pool(old_pool) + pa.set_memory_pool(pool) + try: + allocated_before = pool.bytes_allocated() + with allocate_bytes(None, 512): + assert pool.bytes_allocated() == allocated_before + 512 + assert pool.bytes_allocated() == allocated_before + finally: + pa.set_memory_pool(old_pool) + + +def test_default_backend_name(): + pool = pa.default_memory_pool() + assert pool.backend_name in possible_backends + + +def test_release_unused(): + pool = pa.default_memory_pool() + pool.release_unused() + + +def check_env_var(name, expected, *, expect_warning=False): + code = f"""if 1: + import pyarrow as pa + + pool = pa.default_memory_pool() + assert pool.backend_name in {expected!r}, pool.backend_name + """ + env = dict(os.environ) + env['ARROW_DEFAULT_MEMORY_POOL'] = name + res = subprocess.run([sys.executable, "-c", code], env=env, + universal_newlines=True, stderr=subprocess.PIPE) + if res.returncode != 0: + print(res.stderr, file=sys.stderr) + res.check_returncode() # fail + errlines = res.stderr.splitlines() + if expect_warning: + assert len(errlines) == 1 + assert f"Unsupported backend '{name}'" in errlines[0] + else: + assert len(errlines) == 0 + + +def test_env_var(): + check_env_var("system", ["system"]) + if should_have_jemalloc: + check_env_var("jemalloc", ["jemalloc"]) + if should_have_mimalloc: + check_env_var("mimalloc", ["mimalloc"]) + check_env_var("nonexistent", possible_backends, expect_warning=True) + + +def test_specific_memory_pools(): + specific_pools = set() + + def check(factory, name, *, can_fail=False): + if can_fail: + try: + pool = factory() + except NotImplementedError: + return + else: + pool = factory() + assert pool.backend_name == name + specific_pools.add(pool) + + check(pa.system_memory_pool, "system") + check(pa.jemalloc_memory_pool, "jemalloc", + can_fail=not should_have_jemalloc) + check(pa.mimalloc_memory_pool, "mimalloc", + can_fail=not should_have_mimalloc) diff --git a/src/arrow/python/pyarrow/tests/test_misc.py b/src/arrow/python/pyarrow/tests/test_misc.py new file mode 100644 index 000000000..012f15e16 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_misc.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import subprocess +import sys + +import pytest + +import pyarrow as pa + + +def test_get_include(): + include_dir = pa.get_include() + assert os.path.exists(os.path.join(include_dir, 'arrow', 'api.h')) + + +@pytest.mark.skipif('sys.platform != "win32"') +def test_get_library_dirs_win32(): + assert any(os.path.exists(os.path.join(directory, 'arrow.lib')) + for directory in pa.get_library_dirs()) + + +def test_cpu_count(): + n = pa.cpu_count() + assert n > 0 + try: + pa.set_cpu_count(n + 5) + assert pa.cpu_count() == n + 5 + finally: + pa.set_cpu_count(n) + + +def test_io_thread_count(): + n = pa.io_thread_count() + assert n > 0 + try: + pa.set_io_thread_count(n + 5) + assert pa.io_thread_count() == n + 5 + finally: + pa.set_io_thread_count(n) + + +def test_build_info(): + assert isinstance(pa.cpp_build_info, pa.BuildInfo) + assert isinstance(pa.cpp_version_info, pa.VersionInfo) + assert isinstance(pa.cpp_version, str) + assert isinstance(pa.__version__, str) + assert pa.cpp_build_info.version_info == pa.cpp_version_info + + # assert pa.version == pa.__version__ # XXX currently false + + +def test_runtime_info(): + info = pa.runtime_info() + assert isinstance(info, pa.RuntimeInfo) + possible_simd_levels = ('none', 'sse4_2', 'avx', 'avx2', 'avx512') + assert info.simd_level in possible_simd_levels + assert info.detected_simd_level in possible_simd_levels + + if info.simd_level != 'none': + env = os.environ.copy() + env['ARROW_USER_SIMD_LEVEL'] = 'none' + code = f"""if 1: + import pyarrow as pa + + info = pa.runtime_info() + assert info.simd_level == 'none', info.simd_level + assert info.detected_simd_level == {info.detected_simd_level!r},\ + info.detected_simd_level + """ + subprocess.check_call([sys.executable, "-c", code], env=env) + + +@pytest.mark.parametrize('klass', [ + pa.Field, + pa.Schema, + pa.ChunkedArray, + pa.RecordBatch, + pa.Table, + pa.Buffer, + pa.Array, + pa.Tensor, + pa.DataType, + pa.ListType, + pa.LargeListType, + pa.FixedSizeListType, + pa.UnionType, + pa.SparseUnionType, + pa.DenseUnionType, + pa.StructType, + pa.Time32Type, + pa.Time64Type, + pa.TimestampType, + pa.Decimal128Type, + pa.Decimal256Type, + pa.DictionaryType, + pa.FixedSizeBinaryType, + pa.NullArray, + pa.NumericArray, + pa.IntegerArray, + pa.FloatingPointArray, + pa.BooleanArray, + pa.Int8Array, + pa.Int16Array, + pa.Int32Array, + pa.Int64Array, + pa.UInt8Array, + pa.UInt16Array, + pa.UInt32Array, + pa.UInt64Array, + pa.ListArray, + pa.LargeListArray, + pa.MapArray, + pa.FixedSizeListArray, + pa.UnionArray, + pa.BinaryArray, + pa.StringArray, + pa.FixedSizeBinaryArray, + pa.DictionaryArray, + pa.Date32Array, + pa.Date64Array, + pa.TimestampArray, + pa.Time32Array, + pa.Time64Array, + pa.DurationArray, + pa.Decimal128Array, + pa.Decimal256Array, + pa.StructArray, + pa.Scalar, + pa.BooleanScalar, + pa.Int8Scalar, + pa.Int16Scalar, + pa.Int32Scalar, + pa.Int64Scalar, + pa.UInt8Scalar, + pa.UInt16Scalar, + pa.UInt32Scalar, + pa.UInt64Scalar, + pa.HalfFloatScalar, + pa.FloatScalar, + pa.DoubleScalar, + pa.Decimal128Scalar, + pa.Decimal256Scalar, + pa.Date32Scalar, + pa.Date64Scalar, + pa.Time32Scalar, + pa.Time64Scalar, + pa.TimestampScalar, + pa.DurationScalar, + pa.StringScalar, + pa.BinaryScalar, + pa.FixedSizeBinaryScalar, + pa.ListScalar, + pa.LargeListScalar, + pa.MapScalar, + pa.FixedSizeListScalar, + pa.UnionScalar, + pa.StructScalar, + pa.DictionaryScalar, + pa.ipc.Message, + pa.ipc.MessageReader, + pa.MemoryPool, + pa.LoggingMemoryPool, + pa.ProxyMemoryPool, +]) +def test_extension_type_constructor_errors(klass): + # ARROW-2638: prevent calling extension class constructors directly + msg = "Do not call {cls}'s constructor directly, use .* instead." + with pytest.raises(TypeError, match=msg.format(cls=klass.__name__)): + klass() diff --git a/src/arrow/python/pyarrow/tests/test_orc.py b/src/arrow/python/pyarrow/tests/test_orc.py new file mode 100644 index 000000000..f38121ffb --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_orc.py @@ -0,0 +1,271 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import decimal +import datetime + +import pyarrow as pa + + +# Marks all of the tests in this module +# Ignore these with pytest ... -m 'not orc' +pytestmark = pytest.mark.orc + + +try: + from pandas.testing import assert_frame_equal + import pandas as pd +except ImportError: + pass + + +@pytest.fixture(scope="module") +def datadir(base_datadir): + return base_datadir / "orc" + + +def fix_example_values(actual_cols, expected_cols): + """ + Fix type of expected values (as read from JSON) according to + actual ORC datatype. + """ + for name in expected_cols: + expected = expected_cols[name] + actual = actual_cols[name] + if (name == "map" and + [d.keys() == {'key', 'value'} for m in expected for d in m]): + # convert [{'key': k, 'value': v}, ...] to [(k, v), ...] + for i, m in enumerate(expected): + expected_cols[name][i] = [(d['key'], d['value']) for d in m] + continue + + typ = actual[0].__class__ + if issubclass(typ, datetime.datetime): + # timestamp fields are represented as strings in JSON files + expected = pd.to_datetime(expected) + elif issubclass(typ, datetime.date): + # date fields are represented as strings in JSON files + expected = expected.dt.date + elif typ is decimal.Decimal: + converted_decimals = [None] * len(expected) + # decimal fields are represented as reals in JSON files + for i, (d, v) in enumerate(zip(actual, expected)): + if not pd.isnull(v): + exp = d.as_tuple().exponent + factor = 10 ** -exp + converted_decimals[i] = ( + decimal.Decimal(round(v * factor)).scaleb(exp)) + expected = pd.Series(converted_decimals) + + expected_cols[name] = expected + + +def check_example_values(orc_df, expected_df, start=None, stop=None): + if start is not None or stop is not None: + expected_df = expected_df[start:stop].reset_index(drop=True) + assert_frame_equal(orc_df, expected_df, check_dtype=False) + + +def check_example_file(orc_path, expected_df, need_fix=False): + """ + Check a ORC file against the expected columns dictionary. + """ + from pyarrow import orc + + orc_file = orc.ORCFile(orc_path) + # Exercise ORCFile.read() + table = orc_file.read() + assert isinstance(table, pa.Table) + table.validate() + + # This workaround needed because of ARROW-3080 + orc_df = pd.DataFrame(table.to_pydict()) + + assert set(expected_df.columns) == set(orc_df.columns) + + # reorder columns if necessary + if not orc_df.columns.equals(expected_df.columns): + expected_df = expected_df.reindex(columns=orc_df.columns) + + if need_fix: + fix_example_values(orc_df, expected_df) + + check_example_values(orc_df, expected_df) + # Exercise ORCFile.read_stripe() + json_pos = 0 + for i in range(orc_file.nstripes): + batch = orc_file.read_stripe(i) + check_example_values(pd.DataFrame(batch.to_pydict()), + expected_df, + start=json_pos, + stop=json_pos + len(batch)) + json_pos += len(batch) + assert json_pos == orc_file.nrows + + +@pytest.mark.pandas +@pytest.mark.parametrize('filename', [ + 'TestOrcFile.test1.orc', + 'TestOrcFile.testDate1900.orc', + 'decimal.orc' +]) +def test_example_using_json(filename, datadir): + """ + Check a ORC file example against the equivalent JSON file, as given + in the Apache ORC repository (the JSON file has one JSON object per + line, corresponding to one row in the ORC file). + """ + # Read JSON file + path = datadir / filename + table = pd.read_json(str(path.with_suffix('.jsn.gz')), lines=True) + check_example_file(path, table, need_fix=True) + + +def test_orcfile_empty(datadir): + from pyarrow import orc + + table = orc.ORCFile(datadir / "TestOrcFile.emptyFile.orc").read() + assert table.num_rows == 0 + + expected_schema = pa.schema([ + ("boolean1", pa.bool_()), + ("byte1", pa.int8()), + ("short1", pa.int16()), + ("int1", pa.int32()), + ("long1", pa.int64()), + ("float1", pa.float32()), + ("double1", pa.float64()), + ("bytes1", pa.binary()), + ("string1", pa.string()), + ("middle", pa.struct( + [("list", pa.list_( + pa.struct([("int1", pa.int32()), + ("string1", pa.string())]))) + ])), + ("list", pa.list_( + pa.struct([("int1", pa.int32()), + ("string1", pa.string())]) + )), + ("map", pa.map_(pa.string(), + pa.struct([("int1", pa.int32()), + ("string1", pa.string())]) + )), + ]) + assert table.schema == expected_schema + + +def test_orcfile_readwrite(): + from pyarrow import orc + + buffer_output_stream = pa.BufferOutputStream() + a = pa.array([1, None, 3, None]) + b = pa.array([None, "Arrow", None, "ORC"]) + table = pa.table({"int64": a, "utf8": b}) + orc.write_table(table, buffer_output_stream) + buffer_reader = pa.BufferReader(buffer_output_stream.getvalue()) + orc_file = orc.ORCFile(buffer_reader) + output_table = orc_file.read() + assert table.equals(output_table) + + # deprecated keyword order + buffer_output_stream = pa.BufferOutputStream() + with pytest.warns(FutureWarning): + orc.write_table(buffer_output_stream, table) + buffer_reader = pa.BufferReader(buffer_output_stream.getvalue()) + orc_file = orc.ORCFile(buffer_reader) + output_table = orc_file.read() + assert table.equals(output_table) + + +def test_column_selection(tempdir): + from pyarrow import orc + + # create a table with nested types + inner = pa.field('inner', pa.int64()) + middle = pa.field('middle', pa.struct([inner])) + fields = [ + pa.field('basic', pa.int32()), + pa.field( + 'list', pa.list_(pa.field('item', pa.int32())) + ), + pa.field( + 'struct', pa.struct([middle, pa.field('inner2', pa.int64())]) + ), + pa.field( + 'list-struct', pa.list_(pa.field( + 'item', pa.struct([ + pa.field('inner1', pa.int64()), + pa.field('inner2', pa.int64()) + ]) + )) + ), + pa.field('basic2', pa.int64()), + ] + arrs = [ + [0], [[1, 2]], [{"middle": {"inner": 3}, "inner2": 4}], + [[{"inner1": 5, "inner2": 6}, {"inner1": 7, "inner2": 8}]], [9]] + table = pa.table(arrs, schema=pa.schema(fields)) + + path = str(tempdir / 'test.orc') + orc.write_table(table, path) + orc_file = orc.ORCFile(path) + + # default selecting all columns + result1 = orc_file.read() + assert result1.equals(table) + + # selecting with columns names + result2 = orc_file.read(columns=["basic", "basic2"]) + assert result2.equals(table.select(["basic", "basic2"])) + + result3 = orc_file.read(columns=["list", "struct", "basic2"]) + assert result3.equals(table.select(["list", "struct", "basic2"])) + + # using dotted paths + result4 = orc_file.read(columns=["struct.middle.inner"]) + expected4 = pa.table({"struct": [{"middle": {"inner": 3}}]}) + assert result4.equals(expected4) + + result5 = orc_file.read(columns=["struct.inner2"]) + expected5 = pa.table({"struct": [{"inner2": 4}]}) + assert result5.equals(expected5) + + result6 = orc_file.read( + columns=["list", "struct.middle.inner", "struct.inner2"] + ) + assert result6.equals(table.select(["list", "struct"])) + + result7 = orc_file.read(columns=["list-struct.inner1"]) + expected7 = pa.table({"list-struct": [[{"inner1": 5}, {"inner1": 7}]]}) + assert result7.equals(expected7) + + # selecting with (Arrow-based) field indices + result2 = orc_file.read(columns=[0, 4]) + assert result2.equals(table.select(["basic", "basic2"])) + + result3 = orc_file.read(columns=[1, 2, 3]) + assert result3.equals(table.select(["list", "struct", "list-struct"])) + + # error on non-existing name or index + with pytest.raises(IOError): + # liborc returns ParseError, which gets translated into IOError + # instead of ValueError + orc_file.read(columns=["wrong"]) + + with pytest.raises(ValueError): + orc_file.read(columns=[5]) diff --git a/src/arrow/python/pyarrow/tests/test_pandas.py b/src/arrow/python/pyarrow/tests/test_pandas.py new file mode 100644 index 000000000..112c7938e --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_pandas.py @@ -0,0 +1,4386 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import gc +import decimal +import json +import multiprocessing as mp +import sys + +from collections import OrderedDict +from datetime import date, datetime, time, timedelta, timezone + +import hypothesis as h +import hypothesis.extra.pytz as tzst +import hypothesis.strategies as st +import numpy as np +import numpy.testing as npt +import pytest +import pytz + +from pyarrow.pandas_compat import get_logical_type, _pandas_api +from pyarrow.tests.util import invoke_script, random_ascii, rands +import pyarrow.tests.strategies as past +from pyarrow.vendored.version import Version + +import pyarrow as pa +try: + from pyarrow import parquet as pq +except ImportError: + pass + +try: + import pandas as pd + import pandas.testing as tm + from .pandas_examples import dataframe_with_arrays, dataframe_with_lists +except ImportError: + pass + + +# Marks all of the tests in this module +pytestmark = pytest.mark.pandas + + +def _alltypes_example(size=100): + return pd.DataFrame({ + 'uint8': np.arange(size, dtype=np.uint8), + 'uint16': np.arange(size, dtype=np.uint16), + 'uint32': np.arange(size, dtype=np.uint32), + 'uint64': np.arange(size, dtype=np.uint64), + 'int8': np.arange(size, dtype=np.int16), + 'int16': np.arange(size, dtype=np.int16), + 'int32': np.arange(size, dtype=np.int32), + 'int64': np.arange(size, dtype=np.int64), + 'float32': np.arange(size, dtype=np.float32), + 'float64': np.arange(size, dtype=np.float64), + 'bool': np.random.randn(size) > 0, + # TODO(wesm): Pandas only support ns resolution, Arrow supports s, ms, + # us, ns + 'datetime': np.arange("2016-01-01T00:00:00.001", size, + dtype='datetime64[ms]'), + 'str': [str(x) for x in range(size)], + 'str_with_nulls': [None] + [str(x) for x in range(size - 2)] + [None], + 'empty_str': [''] * size + }) + + +def _check_pandas_roundtrip(df, expected=None, use_threads=False, + expected_schema=None, + check_dtype=True, schema=None, + preserve_index=False, + as_batch=False): + klass = pa.RecordBatch if as_batch else pa.Table + table = klass.from_pandas(df, schema=schema, + preserve_index=preserve_index, + nthreads=2 if use_threads else 1) + result = table.to_pandas(use_threads=use_threads) + + if expected_schema: + # all occurrences of _check_pandas_roundtrip passes expected_schema + # without the pandas generated key-value metadata + assert table.schema.equals(expected_schema) + + if expected is None: + expected = df + + tm.assert_frame_equal(result, expected, check_dtype=check_dtype, + check_index_type=('equiv' if preserve_index + else False)) + + +def _check_series_roundtrip(s, type_=None, expected_pa_type=None): + arr = pa.array(s, from_pandas=True, type=type_) + + if type_ is not None and expected_pa_type is None: + expected_pa_type = type_ + + if expected_pa_type is not None: + assert arr.type == expected_pa_type + + result = pd.Series(arr.to_pandas(), name=s.name) + tm.assert_series_equal(s, result) + + +def _check_array_roundtrip(values, expected=None, mask=None, + type=None): + arr = pa.array(values, from_pandas=True, mask=mask, type=type) + result = arr.to_pandas() + + values_nulls = pd.isnull(values) + if mask is None: + assert arr.null_count == values_nulls.sum() + else: + assert arr.null_count == (mask | values_nulls).sum() + + if expected is None: + if mask is None: + expected = pd.Series(values) + else: + expected = pd.Series(np.ma.masked_array(values, mask=mask)) + + tm.assert_series_equal(pd.Series(result), expected, check_names=False) + + +def _check_array_from_pandas_roundtrip(np_array, type=None): + arr = pa.array(np_array, from_pandas=True, type=type) + result = arr.to_pandas() + npt.assert_array_equal(result, np_array) + + +class TestConvertMetadata: + """ + Conversion tests for Pandas metadata & indices. + """ + + def test_non_string_columns(self): + df = pd.DataFrame({0: [1, 2, 3]}) + table = pa.Table.from_pandas(df) + assert table.field(0).name == '0' + + def test_from_pandas_with_columns(self): + df = pd.DataFrame({0: [1, 2, 3], 1: [1, 3, 3], 2: [2, 4, 5]}, + columns=[1, 0]) + + table = pa.Table.from_pandas(df, columns=[0, 1]) + expected = pa.Table.from_pandas(df[[0, 1]]) + assert expected.equals(table) + + record_batch_table = pa.RecordBatch.from_pandas(df, columns=[0, 1]) + record_batch_expected = pa.RecordBatch.from_pandas(df[[0, 1]]) + assert record_batch_expected.equals(record_batch_table) + + def test_column_index_names_are_preserved(self): + df = pd.DataFrame({'data': [1, 2, 3]}) + df.columns.names = ['a'] + _check_pandas_roundtrip(df, preserve_index=True) + + def test_range_index_shortcut(self): + # ARROW-1639 + index_name = 'foo' + df = pd.DataFrame({'a': [1, 2, 3, 4]}, + index=pd.RangeIndex(0, 8, step=2, name=index_name)) + + df2 = pd.DataFrame({'a': [4, 5, 6, 7]}, + index=pd.RangeIndex(0, 4)) + + table = pa.Table.from_pandas(df) + table_no_index_name = pa.Table.from_pandas(df2) + + # The RangeIndex is tracked in the metadata only + assert len(table.schema) == 1 + + result = table.to_pandas() + tm.assert_frame_equal(result, df) + assert isinstance(result.index, pd.RangeIndex) + assert _pandas_api.get_rangeindex_attribute(result.index, 'step') == 2 + assert result.index.name == index_name + + result2 = table_no_index_name.to_pandas() + tm.assert_frame_equal(result2, df2) + assert isinstance(result2.index, pd.RangeIndex) + assert _pandas_api.get_rangeindex_attribute(result2.index, 'step') == 1 + assert result2.index.name is None + + def test_range_index_force_serialization(self): + # ARROW-5427: preserve_index=True will force the RangeIndex to + # be serialized as a column rather than tracked more + # efficiently as metadata + df = pd.DataFrame({'a': [1, 2, 3, 4]}, + index=pd.RangeIndex(0, 8, step=2, name='foo')) + + table = pa.Table.from_pandas(df, preserve_index=True) + assert table.num_columns == 2 + assert 'foo' in table.column_names + + restored = table.to_pandas() + tm.assert_frame_equal(restored, df) + + def test_rangeindex_doesnt_warn(self): + # ARROW-5606: pandas 0.25 deprecated private _start/stop/step + # attributes -> can be removed if support < pd 0.25 is dropped + df = pd.DataFrame(np.random.randn(4, 2), columns=['a', 'b']) + + with pytest.warns(None) as record: + _check_pandas_roundtrip(df, preserve_index=True) + + assert len(record) == 0 + + def test_multiindex_columns(self): + columns = pd.MultiIndex.from_arrays([ + ['one', 'two'], ['X', 'Y'] + ]) + df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns) + _check_pandas_roundtrip(df, preserve_index=True) + + def test_multiindex_columns_with_dtypes(self): + columns = pd.MultiIndex.from_arrays( + [ + ['one', 'two'], + pd.DatetimeIndex(['2017-08-01', '2017-08-02']), + ], + names=['level_1', 'level_2'], + ) + df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns) + _check_pandas_roundtrip(df, preserve_index=True) + + def test_multiindex_with_column_dtype_object(self): + # ARROW-3651 & ARROW-9096 + # Bug when dtype of the columns is object. + + # uinderlying dtype: integer + df = pd.DataFrame([1], columns=pd.Index([1], dtype=object)) + _check_pandas_roundtrip(df, preserve_index=True) + + # underlying dtype: floating + df = pd.DataFrame([1], columns=pd.Index([1.1], dtype=object)) + _check_pandas_roundtrip(df, preserve_index=True) + + # underlying dtype: datetime + # ARROW-9096: a simple roundtrip now works + df = pd.DataFrame([1], columns=pd.Index( + [datetime(2018, 1, 1)], dtype="object")) + _check_pandas_roundtrip(df, preserve_index=True) + + def test_multiindex_columns_unicode(self): + columns = pd.MultiIndex.from_arrays([['あ', 'い'], ['X', 'Y']]) + df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns) + _check_pandas_roundtrip(df, preserve_index=True) + + def test_multiindex_doesnt_warn(self): + # ARROW-3953: pandas 0.24 rename of MultiIndex labels to codes + columns = pd.MultiIndex.from_arrays([['one', 'two'], ['X', 'Y']]) + df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')], columns=columns) + + with pytest.warns(None) as record: + _check_pandas_roundtrip(df, preserve_index=True) + + assert len(record) == 0 + + def test_integer_index_column(self): + df = pd.DataFrame([(1, 'a'), (2, 'b'), (3, 'c')]) + _check_pandas_roundtrip(df, preserve_index=True) + + def test_index_metadata_field_name(self): + # test None case, and strangely named non-index columns + df = pd.DataFrame( + [(1, 'a', 3.1), (2, 'b', 2.2), (3, 'c', 1.3)], + index=pd.MultiIndex.from_arrays( + [['c', 'b', 'a'], [3, 2, 1]], + names=[None, 'foo'] + ), + columns=['a', None, '__index_level_0__'], + ) + with pytest.warns(UserWarning): + t = pa.Table.from_pandas(df, preserve_index=True) + js = t.schema.pandas_metadata + + col1, col2, col3, idx0, foo = js['columns'] + + assert col1['name'] == 'a' + assert col1['name'] == col1['field_name'] + + assert col2['name'] is None + assert col2['field_name'] == 'None' + + assert col3['name'] == '__index_level_0__' + assert col3['name'] == col3['field_name'] + + idx0_descr, foo_descr = js['index_columns'] + assert idx0_descr == '__index_level_0__' + assert idx0['field_name'] == idx0_descr + assert idx0['name'] is None + + assert foo_descr == 'foo' + assert foo['field_name'] == foo_descr + assert foo['name'] == foo_descr + + def test_categorical_column_index(self): + df = pd.DataFrame( + [(1, 'a', 2.0), (2, 'b', 3.0), (3, 'c', 4.0)], + columns=pd.Index(list('def'), dtype='category') + ) + t = pa.Table.from_pandas(df, preserve_index=True) + js = t.schema.pandas_metadata + + column_indexes, = js['column_indexes'] + assert column_indexes['name'] is None + assert column_indexes['pandas_type'] == 'categorical' + assert column_indexes['numpy_type'] == 'int8' + + md = column_indexes['metadata'] + assert md['num_categories'] == 3 + assert md['ordered'] is False + + def test_string_column_index(self): + df = pd.DataFrame( + [(1, 'a', 2.0), (2, 'b', 3.0), (3, 'c', 4.0)], + columns=pd.Index(list('def'), name='stringz') + ) + t = pa.Table.from_pandas(df, preserve_index=True) + js = t.schema.pandas_metadata + + column_indexes, = js['column_indexes'] + assert column_indexes['name'] == 'stringz' + assert column_indexes['name'] == column_indexes['field_name'] + assert column_indexes['numpy_type'] == 'object' + assert column_indexes['pandas_type'] == 'unicode' + + md = column_indexes['metadata'] + + assert len(md) == 1 + assert md['encoding'] == 'UTF-8' + + def test_datetimetz_column_index(self): + df = pd.DataFrame( + [(1, 'a', 2.0), (2, 'b', 3.0), (3, 'c', 4.0)], + columns=pd.date_range( + start='2017-01-01', periods=3, tz='America/New_York' + ) + ) + t = pa.Table.from_pandas(df, preserve_index=True) + js = t.schema.pandas_metadata + + column_indexes, = js['column_indexes'] + assert column_indexes['name'] is None + assert column_indexes['pandas_type'] == 'datetimetz' + assert column_indexes['numpy_type'] == 'datetime64[ns]' + + md = column_indexes['metadata'] + assert md['timezone'] == 'America/New_York' + + def test_datetimetz_row_index(self): + df = pd.DataFrame({ + 'a': pd.date_range( + start='2017-01-01', periods=3, tz='America/New_York' + ) + }) + df = df.set_index('a') + + _check_pandas_roundtrip(df, preserve_index=True) + + def test_categorical_row_index(self): + df = pd.DataFrame({'a': [1, 2, 3], 'b': [1, 2, 3]}) + df['a'] = df.a.astype('category') + df = df.set_index('a') + + _check_pandas_roundtrip(df, preserve_index=True) + + def test_duplicate_column_names_does_not_crash(self): + df = pd.DataFrame([(1, 'a'), (2, 'b')], columns=list('aa')) + with pytest.raises(ValueError): + pa.Table.from_pandas(df) + + def test_dictionary_indices_boundscheck(self): + # ARROW-1658. No validation of indices leads to segfaults in pandas + indices = [[0, 1], [0, -1]] + + for inds in indices: + arr = pa.DictionaryArray.from_arrays(inds, ['a'], safe=False) + batch = pa.RecordBatch.from_arrays([arr], ['foo']) + table = pa.Table.from_batches([batch, batch, batch]) + + with pytest.raises(IndexError): + arr.to_pandas() + + with pytest.raises(IndexError): + table.to_pandas() + + def test_unicode_with_unicode_column_and_index(self): + df = pd.DataFrame({'あ': ['い']}, index=['う']) + + _check_pandas_roundtrip(df, preserve_index=True) + + def test_mixed_column_names(self): + # mixed type column names are not reconstructed exactly + df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + + for cols in [['あ', b'a'], [1, '2'], [1, 1.5]]: + df.columns = pd.Index(cols, dtype=object) + + # assert that the from_pandas raises the warning + with pytest.warns(UserWarning): + pa.Table.from_pandas(df) + + expected = df.copy() + expected.columns = df.columns.values.astype(str) + with pytest.warns(UserWarning): + _check_pandas_roundtrip(df, expected=expected, + preserve_index=True) + + def test_binary_column_name(self): + column_data = ['い'] + key = 'あ'.encode() + data = {key: column_data} + df = pd.DataFrame(data) + + # we can't use _check_pandas_roundtrip here because our metadata + # is always decoded as utf8: even if binary goes in, utf8 comes out + t = pa.Table.from_pandas(df, preserve_index=True) + df2 = t.to_pandas() + assert df.values[0] == df2.values[0] + assert df.index.values[0] == df2.index.values[0] + assert df.columns[0] == key + + def test_multiindex_duplicate_values(self): + num_rows = 3 + numbers = list(range(num_rows)) + index = pd.MultiIndex.from_arrays( + [['foo', 'foo', 'bar'], numbers], + names=['foobar', 'some_numbers'], + ) + + df = pd.DataFrame({'numbers': numbers}, index=index) + + _check_pandas_roundtrip(df, preserve_index=True) + + def test_metadata_with_mixed_types(self): + df = pd.DataFrame({'data': [b'some_bytes', 'some_unicode']}) + table = pa.Table.from_pandas(df) + js = table.schema.pandas_metadata + assert 'mixed' not in js + data_column = js['columns'][0] + assert data_column['pandas_type'] == 'bytes' + assert data_column['numpy_type'] == 'object' + + def test_ignore_metadata(self): + df = pd.DataFrame({'a': [1, 2, 3], 'b': ['foo', 'bar', 'baz']}, + index=['one', 'two', 'three']) + table = pa.Table.from_pandas(df) + + result = table.to_pandas(ignore_metadata=True) + expected = (table.cast(table.schema.remove_metadata()) + .to_pandas()) + + tm.assert_frame_equal(result, expected) + + def test_list_metadata(self): + df = pd.DataFrame({'data': [[1], [2, 3, 4], [5] * 7]}) + schema = pa.schema([pa.field('data', type=pa.list_(pa.int64()))]) + table = pa.Table.from_pandas(df, schema=schema) + js = table.schema.pandas_metadata + assert 'mixed' not in js + data_column = js['columns'][0] + assert data_column['pandas_type'] == 'list[int64]' + assert data_column['numpy_type'] == 'object' + + def test_struct_metadata(self): + df = pd.DataFrame({'dicts': [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]}) + table = pa.Table.from_pandas(df) + pandas_metadata = table.schema.pandas_metadata + assert pandas_metadata['columns'][0]['pandas_type'] == 'object' + + def test_decimal_metadata(self): + expected = pd.DataFrame({ + 'decimals': [ + decimal.Decimal('394092382910493.12341234678'), + -decimal.Decimal('314292388910493.12343437128'), + ] + }) + table = pa.Table.from_pandas(expected) + js = table.schema.pandas_metadata + assert 'mixed' not in js + data_column = js['columns'][0] + assert data_column['pandas_type'] == 'decimal' + assert data_column['numpy_type'] == 'object' + assert data_column['metadata'] == {'precision': 26, 'scale': 11} + + def test_table_column_subset_metadata(self): + # ARROW-1883 + # non-default index + for index in [ + pd.Index(['a', 'b', 'c'], name='index'), + pd.date_range("2017-01-01", periods=3, tz='Europe/Brussels')]: + df = pd.DataFrame({'a': [1, 2, 3], + 'b': [.1, .2, .3]}, index=index) + table = pa.Table.from_pandas(df) + + table_subset = table.remove_column(1) + result = table_subset.to_pandas() + expected = df[['a']] + if isinstance(df.index, pd.DatetimeIndex): + df.index.freq = None + tm.assert_frame_equal(result, expected) + + table_subset2 = table_subset.remove_column(1) + result = table_subset2.to_pandas() + tm.assert_frame_equal(result, df[['a']].reset_index(drop=True)) + + def test_to_pandas_column_subset_multiindex(self): + # ARROW-10122 + df = pd.DataFrame( + {"first": list(range(5)), + "second": list(range(5)), + "value": np.arange(5)} + ) + table = pa.Table.from_pandas(df.set_index(["first", "second"])) + + subset = table.select(["first", "value"]) + result = subset.to_pandas() + expected = df[["first", "value"]].set_index("first") + tm.assert_frame_equal(result, expected) + + def test_empty_list_metadata(self): + # Create table with array of empty lists, forced to have type + # list(string) in pyarrow + c1 = [["test"], ["a", "b"], None] + c2 = [[], [], []] + arrays = OrderedDict([ + ('c1', pa.array(c1, type=pa.list_(pa.string()))), + ('c2', pa.array(c2, type=pa.list_(pa.string()))), + ]) + rb = pa.RecordBatch.from_arrays( + list(arrays.values()), + list(arrays.keys()) + ) + tbl = pa.Table.from_batches([rb]) + + # First roundtrip changes schema, because pandas cannot preserve the + # type of empty lists + df = tbl.to_pandas() + tbl2 = pa.Table.from_pandas(df) + md2 = tbl2.schema.pandas_metadata + + # Second roundtrip + df2 = tbl2.to_pandas() + expected = pd.DataFrame(OrderedDict([('c1', c1), ('c2', c2)])) + + tm.assert_frame_equal(df2, expected) + + assert md2['columns'] == [ + { + 'name': 'c1', + 'field_name': 'c1', + 'metadata': None, + 'numpy_type': 'object', + 'pandas_type': 'list[unicode]', + }, + { + 'name': 'c2', + 'field_name': 'c2', + 'metadata': None, + 'numpy_type': 'object', + 'pandas_type': 'list[empty]', + } + ] + + def test_metadata_pandas_version(self): + df = pd.DataFrame({'a': [1, 2, 3], 'b': [1, 2, 3]}) + table = pa.Table.from_pandas(df) + assert table.schema.pandas_metadata['pandas_version'] is not None + + def test_mismatch_metadata_schema(self): + # ARROW-10511 + # It is possible that the metadata and actual schema is not fully + # matching (eg no timezone information for tz-aware column) + # -> to_pandas() conversion should not fail on that + df = pd.DataFrame({"datetime": pd.date_range("2020-01-01", periods=3)}) + + # OPTION 1: casting after conversion + table = pa.Table.from_pandas(df) + # cast the "datetime" column to be tz-aware + new_col = table["datetime"].cast(pa.timestamp('ns', tz="UTC")) + new_table1 = table.set_column( + 0, pa.field("datetime", new_col.type), new_col + ) + + # OPTION 2: specify schema during conversion + schema = pa.schema([("datetime", pa.timestamp('ns', tz="UTC"))]) + new_table2 = pa.Table.from_pandas(df, schema=schema) + + expected = df.copy() + expected["datetime"] = expected["datetime"].dt.tz_localize("UTC") + + for new_table in [new_table1, new_table2]: + # ensure the new table still has the pandas metadata + assert new_table.schema.pandas_metadata is not None + # convert to pandas + result = new_table.to_pandas() + tm.assert_frame_equal(result, expected) + + +class TestConvertPrimitiveTypes: + """ + Conversion tests for primitive (e.g. numeric) types. + """ + + def test_float_no_nulls(self): + data = {} + fields = [] + dtypes = [('f2', pa.float16()), + ('f4', pa.float32()), + ('f8', pa.float64())] + num_values = 100 + + for numpy_dtype, arrow_dtype in dtypes: + values = np.random.randn(num_values) + data[numpy_dtype] = values.astype(numpy_dtype) + fields.append(pa.field(numpy_dtype, arrow_dtype)) + + df = pd.DataFrame(data) + schema = pa.schema(fields) + _check_pandas_roundtrip(df, expected_schema=schema) + + def test_float_nulls(self): + num_values = 100 + + null_mask = np.random.randint(0, 10, size=num_values) < 3 + dtypes = [('f2', pa.float16()), + ('f4', pa.float32()), + ('f8', pa.float64())] + names = ['f2', 'f4', 'f8'] + expected_cols = [] + + arrays = [] + fields = [] + for name, arrow_dtype in dtypes: + values = np.random.randn(num_values).astype(name) + + arr = pa.array(values, from_pandas=True, mask=null_mask) + arrays.append(arr) + fields.append(pa.field(name, arrow_dtype)) + values[null_mask] = np.nan + + expected_cols.append(values) + + ex_frame = pd.DataFrame(dict(zip(names, expected_cols)), + columns=names) + + table = pa.Table.from_arrays(arrays, names) + assert table.schema.equals(pa.schema(fields)) + result = table.to_pandas() + tm.assert_frame_equal(result, ex_frame) + + def test_float_nulls_to_ints(self): + # ARROW-2135 + df = pd.DataFrame({"a": [1.0, 2.0, np.NaN]}) + schema = pa.schema([pa.field("a", pa.int16(), nullable=True)]) + table = pa.Table.from_pandas(df, schema=schema, safe=False) + assert table[0].to_pylist() == [1, 2, None] + tm.assert_frame_equal(df, table.to_pandas()) + + def test_float_nulls_to_boolean(self): + s = pd.Series([0.0, 1.0, 2.0, None, -3.0]) + expected = pd.Series([False, True, True, None, True]) + _check_array_roundtrip(s, expected=expected, type=pa.bool_()) + + def test_series_from_pandas_false_respected(self): + # Check that explicit from_pandas=False is respected + s = pd.Series([0.0, np.nan]) + arr = pa.array(s, from_pandas=False) + assert arr.null_count == 0 + assert np.isnan(arr[1].as_py()) + + def test_integer_no_nulls(self): + data = OrderedDict() + fields = [] + + numpy_dtypes = [ + ('i1', pa.int8()), ('i2', pa.int16()), + ('i4', pa.int32()), ('i8', pa.int64()), + ('u1', pa.uint8()), ('u2', pa.uint16()), + ('u4', pa.uint32()), ('u8', pa.uint64()), + ('longlong', pa.int64()), ('ulonglong', pa.uint64()) + ] + num_values = 100 + + for dtype, arrow_dtype in numpy_dtypes: + info = np.iinfo(dtype) + values = np.random.randint(max(info.min, np.iinfo(np.int_).min), + min(info.max, np.iinfo(np.int_).max), + size=num_values) + data[dtype] = values.astype(dtype) + fields.append(pa.field(dtype, arrow_dtype)) + + df = pd.DataFrame(data) + schema = pa.schema(fields) + _check_pandas_roundtrip(df, expected_schema=schema) + + def test_all_integer_types(self): + # Test all Numpy integer aliases + data = OrderedDict() + numpy_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8', + 'byte', 'ubyte', 'short', 'ushort', 'intc', 'uintc', + 'int_', 'uint', 'longlong', 'ulonglong'] + for dtype in numpy_dtypes: + data[dtype] = np.arange(12, dtype=dtype) + df = pd.DataFrame(data) + _check_pandas_roundtrip(df) + + # Do the same with pa.array() + # (for some reason, it doesn't use the same code paths at all) + for np_arr in data.values(): + arr = pa.array(np_arr) + assert arr.to_pylist() == np_arr.tolist() + + def test_integer_byteorder(self): + # Byteswapped arrays are not supported yet + int_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8'] + for dt in int_dtypes: + for order in '=<>': + data = np.array([1, 2, 42], dtype=order + dt) + for np_arr in (data, data[::2]): + if data.dtype.isnative: + arr = pa.array(data) + assert arr.to_pylist() == data.tolist() + else: + with pytest.raises(NotImplementedError): + arr = pa.array(data) + + def test_integer_with_nulls(self): + # pandas requires upcast to float dtype + + int_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8'] + num_values = 100 + + null_mask = np.random.randint(0, 10, size=num_values) < 3 + + expected_cols = [] + arrays = [] + for name in int_dtypes: + values = np.random.randint(0, 100, size=num_values) + + arr = pa.array(values, mask=null_mask) + arrays.append(arr) + + expected = values.astype('f8') + expected[null_mask] = np.nan + + expected_cols.append(expected) + + ex_frame = pd.DataFrame(dict(zip(int_dtypes, expected_cols)), + columns=int_dtypes) + + table = pa.Table.from_arrays(arrays, int_dtypes) + result = table.to_pandas() + + tm.assert_frame_equal(result, ex_frame) + + def test_array_from_pandas_type_cast(self): + arr = np.arange(10, dtype='int64') + + target_type = pa.int8() + + result = pa.array(arr, type=target_type) + expected = pa.array(arr.astype('int8')) + assert result.equals(expected) + + def test_boolean_no_nulls(self): + num_values = 100 + + np.random.seed(0) + + df = pd.DataFrame({'bools': np.random.randn(num_values) > 0}) + field = pa.field('bools', pa.bool_()) + schema = pa.schema([field]) + _check_pandas_roundtrip(df, expected_schema=schema) + + def test_boolean_nulls(self): + # pandas requires upcast to object dtype + num_values = 100 + np.random.seed(0) + + mask = np.random.randint(0, 10, size=num_values) < 3 + values = np.random.randint(0, 10, size=num_values) < 5 + + arr = pa.array(values, mask=mask) + + expected = values.astype(object) + expected[mask] = None + + field = pa.field('bools', pa.bool_()) + schema = pa.schema([field]) + ex_frame = pd.DataFrame({'bools': expected}) + + table = pa.Table.from_arrays([arr], ['bools']) + assert table.schema.equals(schema) + result = table.to_pandas() + + tm.assert_frame_equal(result, ex_frame) + + def test_boolean_to_int(self): + # test from dtype=bool + s = pd.Series([True, True, False, True, True] * 2) + expected = pd.Series([1, 1, 0, 1, 1] * 2) + _check_array_roundtrip(s, expected=expected, type=pa.int64()) + + def test_boolean_objects_to_int(self): + # test from dtype=object + s = pd.Series([True, True, False, True, True] * 2, dtype=object) + expected = pd.Series([1, 1, 0, 1, 1] * 2) + expected_msg = 'Expected integer, got bool' + with pytest.raises(pa.ArrowTypeError, match=expected_msg): + _check_array_roundtrip(s, expected=expected, type=pa.int64()) + + def test_boolean_nulls_to_float(self): + # test from dtype=object + s = pd.Series([True, True, False, None, True] * 2) + expected = pd.Series([1.0, 1.0, 0.0, None, 1.0] * 2) + _check_array_roundtrip(s, expected=expected, type=pa.float64()) + + def test_boolean_multiple_columns(self): + # ARROW-6325 (multiple columns resulting in strided conversion) + df = pd.DataFrame(np.ones((3, 2), dtype='bool'), columns=['a', 'b']) + _check_pandas_roundtrip(df) + + def test_float_object_nulls(self): + arr = np.array([None, 1.5, np.float64(3.5)] * 5, dtype=object) + df = pd.DataFrame({'floats': arr}) + expected = pd.DataFrame({'floats': pd.to_numeric(arr)}) + field = pa.field('floats', pa.float64()) + schema = pa.schema([field]) + _check_pandas_roundtrip(df, expected=expected, + expected_schema=schema) + + def test_float_with_null_as_integer(self): + # ARROW-2298 + s = pd.Series([np.nan, 1., 2., np.nan]) + + types = [pa.int8(), pa.int16(), pa.int32(), pa.int64(), + pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()] + for ty in types: + result = pa.array(s, type=ty) + expected = pa.array([None, 1, 2, None], type=ty) + assert result.equals(expected) + + df = pd.DataFrame({'has_nulls': s}) + schema = pa.schema([pa.field('has_nulls', ty)]) + result = pa.Table.from_pandas(df, schema=schema, + preserve_index=False) + assert result[0].chunk(0).equals(expected) + + def test_int_object_nulls(self): + arr = np.array([None, 1, np.int64(3)] * 5, dtype=object) + df = pd.DataFrame({'ints': arr}) + expected = pd.DataFrame({'ints': pd.to_numeric(arr)}) + field = pa.field('ints', pa.int64()) + schema = pa.schema([field]) + _check_pandas_roundtrip(df, expected=expected, + expected_schema=schema) + + def test_boolean_object_nulls(self): + arr = np.array([False, None, True] * 100, dtype=object) + df = pd.DataFrame({'bools': arr}) + field = pa.field('bools', pa.bool_()) + schema = pa.schema([field]) + _check_pandas_roundtrip(df, expected_schema=schema) + + def test_all_nulls_cast_numeric(self): + arr = np.array([None], dtype=object) + + def _check_type(t): + a2 = pa.array(arr, type=t) + assert a2.type == t + assert a2[0].as_py() is None + + _check_type(pa.int32()) + _check_type(pa.float64()) + + def test_half_floats_from_numpy(self): + arr = np.array([1.5, np.nan], dtype=np.float16) + a = pa.array(arr, type=pa.float16()) + x, y = a.to_pylist() + assert isinstance(x, np.float16) + assert x == 1.5 + assert isinstance(y, np.float16) + assert np.isnan(y) + + a = pa.array(arr, type=pa.float16(), from_pandas=True) + x, y = a.to_pylist() + assert isinstance(x, np.float16) + assert x == 1.5 + assert y is None + + +@pytest.mark.parametrize('dtype', + ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8']) +def test_array_integer_object_nulls_option(dtype): + num_values = 100 + + null_mask = np.random.randint(0, 10, size=num_values) < 3 + values = np.random.randint(0, 100, size=num_values, dtype=dtype) + + array = pa.array(values, mask=null_mask) + + if null_mask.any(): + expected = values.astype('O') + expected[null_mask] = None + else: + expected = values + + result = array.to_pandas(integer_object_nulls=True) + + np.testing.assert_equal(result, expected) + + +@pytest.mark.parametrize('dtype', + ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8']) +def test_table_integer_object_nulls_option(dtype): + num_values = 100 + + null_mask = np.random.randint(0, 10, size=num_values) < 3 + values = np.random.randint(0, 100, size=num_values, dtype=dtype) + + array = pa.array(values, mask=null_mask) + + if null_mask.any(): + expected = values.astype('O') + expected[null_mask] = None + else: + expected = values + + expected = pd.DataFrame({dtype: expected}) + + table = pa.Table.from_arrays([array], [dtype]) + result = table.to_pandas(integer_object_nulls=True) + + tm.assert_frame_equal(result, expected) + + +class TestConvertDateTimeLikeTypes: + """ + Conversion tests for datetime- and timestamp-like types (date64, etc.). + """ + + def test_timestamps_notimezone_no_nulls(self): + df = pd.DataFrame({ + 'datetime64': np.array([ + '2007-07-13T01:23:34.123456789', + '2006-01-13T12:34:56.432539784', + '2010-08-13T05:46:57.437699912'], + dtype='datetime64[ns]') + }) + field = pa.field('datetime64', pa.timestamp('ns')) + schema = pa.schema([field]) + _check_pandas_roundtrip( + df, + expected_schema=schema, + ) + + def test_timestamps_notimezone_nulls(self): + df = pd.DataFrame({ + 'datetime64': np.array([ + '2007-07-13T01:23:34.123456789', + None, + '2010-08-13T05:46:57.437699912'], + dtype='datetime64[ns]') + }) + field = pa.field('datetime64', pa.timestamp('ns')) + schema = pa.schema([field]) + _check_pandas_roundtrip( + df, + expected_schema=schema, + ) + + def test_timestamps_with_timezone(self): + df = pd.DataFrame({ + 'datetime64': np.array([ + '2007-07-13T01:23:34.123', + '2006-01-13T12:34:56.432', + '2010-08-13T05:46:57.437'], + dtype='datetime64[ms]') + }) + df['datetime64'] = df['datetime64'].dt.tz_localize('US/Eastern') + _check_pandas_roundtrip(df) + + _check_series_roundtrip(df['datetime64']) + + # drop-in a null and ns instead of ms + df = pd.DataFrame({ + 'datetime64': np.array([ + '2007-07-13T01:23:34.123456789', + None, + '2006-01-13T12:34:56.432539784', + '2010-08-13T05:46:57.437699912'], + dtype='datetime64[ns]') + }) + df['datetime64'] = df['datetime64'].dt.tz_localize('US/Eastern') + + _check_pandas_roundtrip(df) + + def test_python_datetime(self): + # ARROW-2106 + date_array = [datetime.today() + timedelta(days=x) for x in range(10)] + df = pd.DataFrame({ + 'datetime': pd.Series(date_array, dtype=object) + }) + + table = pa.Table.from_pandas(df) + assert isinstance(table[0].chunk(0), pa.TimestampArray) + + result = table.to_pandas() + expected_df = pd.DataFrame({ + 'datetime': date_array + }) + tm.assert_frame_equal(expected_df, result) + + def test_python_datetime_with_pytz_tzinfo(self): + for tz in [pytz.utc, pytz.timezone('US/Eastern'), pytz.FixedOffset(1)]: + values = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=tz)] + df = pd.DataFrame({'datetime': values}) + _check_pandas_roundtrip(df) + + @h.given(st.none() | tzst.timezones()) + def test_python_datetime_with_pytz_timezone(self, tz): + values = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=tz)] + df = pd.DataFrame({'datetime': values}) + _check_pandas_roundtrip(df) + + def test_python_datetime_with_timezone_tzinfo(self): + from datetime import timezone + + if Version(pd.__version__) > Version("0.25.0"): + # older pandas versions fail on datetime.timezone.utc (as in input) + # vs pytz.UTC (as in result) + values = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=timezone.utc)] + # also test with index to ensure both paths roundtrip (ARROW-9962) + df = pd.DataFrame({'datetime': values}, index=values) + _check_pandas_roundtrip(df, preserve_index=True) + + # datetime.timezone is going to be pytz.FixedOffset + hours = 1 + tz_timezone = timezone(timedelta(hours=hours)) + tz_pytz = pytz.FixedOffset(hours * 60) + values = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=tz_timezone)] + values_exp = [datetime(2018, 1, 1, 12, 23, 45, tzinfo=tz_pytz)] + df = pd.DataFrame({'datetime': values}, index=values) + df_exp = pd.DataFrame({'datetime': values_exp}, index=values_exp) + _check_pandas_roundtrip(df, expected=df_exp, preserve_index=True) + + def test_python_datetime_subclass(self): + + class MyDatetime(datetime): + # see https://github.com/pandas-dev/pandas/issues/21142 + nanosecond = 0.0 + + date_array = [MyDatetime(2000, 1, 1, 1, 1, 1)] + df = pd.DataFrame({"datetime": pd.Series(date_array, dtype=object)}) + + table = pa.Table.from_pandas(df) + assert isinstance(table[0].chunk(0), pa.TimestampArray) + + result = table.to_pandas() + expected_df = pd.DataFrame({"datetime": date_array}) + + # https://github.com/pandas-dev/pandas/issues/21142 + expected_df["datetime"] = pd.to_datetime(expected_df["datetime"]) + + tm.assert_frame_equal(expected_df, result) + + def test_python_date_subclass(self): + + class MyDate(date): + pass + + date_array = [MyDate(2000, 1, 1)] + df = pd.DataFrame({"date": pd.Series(date_array, dtype=object)}) + + table = pa.Table.from_pandas(df) + assert isinstance(table[0].chunk(0), pa.Date32Array) + + result = table.to_pandas() + expected_df = pd.DataFrame( + {"date": np.array([date(2000, 1, 1)], dtype=object)} + ) + tm.assert_frame_equal(expected_df, result) + + def test_datetime64_to_date32(self): + # ARROW-1718 + arr = pa.array([date(2017, 10, 23), None]) + c = pa.chunked_array([arr]) + s = c.to_pandas() + + arr2 = pa.Array.from_pandas(s, type=pa.date32()) + + assert arr2.equals(arr.cast('date32')) + + @pytest.mark.parametrize('mask', [ + None, + np.array([True, False, False, True, False, False]), + ]) + def test_pandas_datetime_to_date64(self, mask): + s = pd.to_datetime([ + '2018-05-10T00:00:00', + '2018-05-11T00:00:00', + '2018-05-12T00:00:00', + '2018-05-10T10:24:01', + '2018-05-11T10:24:01', + '2018-05-12T10:24:01', + ]) + arr = pa.Array.from_pandas(s, type=pa.date64(), mask=mask) + + data = np.array([ + date(2018, 5, 10), + date(2018, 5, 11), + date(2018, 5, 12), + date(2018, 5, 10), + date(2018, 5, 11), + date(2018, 5, 12), + ]) + expected = pa.array(data, mask=mask, type=pa.date64()) + + assert arr.equals(expected) + + def test_array_types_date_as_object(self): + data = [date(2000, 1, 1), + None, + date(1970, 1, 1), + date(2040, 2, 26)] + expected_d = np.array(['2000-01-01', None, '1970-01-01', + '2040-02-26'], dtype='datetime64[D]') + + expected_ns = np.array(['2000-01-01', None, '1970-01-01', + '2040-02-26'], dtype='datetime64[ns]') + + objects = [pa.array(data), + pa.chunked_array([data])] + + for obj in objects: + result = obj.to_pandas() + expected_obj = expected_d.astype(object) + assert result.dtype == expected_obj.dtype + npt.assert_array_equal(result, expected_obj) + + result = obj.to_pandas(date_as_object=False) + assert result.dtype == expected_ns.dtype + npt.assert_array_equal(result, expected_ns) + + def test_table_convert_date_as_object(self): + df = pd.DataFrame({ + 'date': [date(2000, 1, 1), + None, + date(1970, 1, 1), + date(2040, 2, 26)]}) + + table = pa.Table.from_pandas(df, preserve_index=False) + + df_datetime = table.to_pandas(date_as_object=False) + df_object = table.to_pandas() + + tm.assert_frame_equal(df.astype('datetime64[ns]'), df_datetime, + check_dtype=True) + tm.assert_frame_equal(df, df_object, check_dtype=True) + + def test_date_infer(self): + df = pd.DataFrame({ + 'date': [date(2000, 1, 1), + None, + date(1970, 1, 1), + date(2040, 2, 26)]}) + table = pa.Table.from_pandas(df, preserve_index=False) + field = pa.field('date', pa.date32()) + + # schema's metadata is generated by from_pandas conversion + expected_schema = pa.schema([field], metadata=table.schema.metadata) + assert table.schema.equals(expected_schema) + + result = table.to_pandas() + tm.assert_frame_equal(result, df) + + def test_date_mask(self): + arr = np.array([date(2017, 4, 3), date(2017, 4, 4)], + dtype='datetime64[D]') + mask = [True, False] + result = pa.array(arr, mask=np.array(mask)) + expected = np.array([None, date(2017, 4, 4)], dtype='datetime64[D]') + expected = pa.array(expected, from_pandas=True) + assert expected.equals(result) + + def test_date_objects_typed(self): + arr = np.array([ + date(2017, 4, 3), + None, + date(2017, 4, 4), + date(2017, 4, 5)], dtype=object) + + arr_i4 = np.array([17259, -1, 17260, 17261], dtype='int32') + arr_i8 = arr_i4.astype('int64') * 86400000 + mask = np.array([False, True, False, False]) + + t32 = pa.date32() + t64 = pa.date64() + + a32 = pa.array(arr, type=t32) + a64 = pa.array(arr, type=t64) + + a32_expected = pa.array(arr_i4, mask=mask, type=t32) + a64_expected = pa.array(arr_i8, mask=mask, type=t64) + + assert a32.equals(a32_expected) + assert a64.equals(a64_expected) + + # Test converting back to pandas + colnames = ['date32', 'date64'] + table = pa.Table.from_arrays([a32, a64], colnames) + + ex_values = (np.array(['2017-04-03', '2017-04-04', '2017-04-04', + '2017-04-05'], + dtype='datetime64[D]')) + ex_values[1] = pd.NaT.value + + ex_datetime64ns = ex_values.astype('datetime64[ns]') + expected_pandas = pd.DataFrame({'date32': ex_datetime64ns, + 'date64': ex_datetime64ns}, + columns=colnames) + table_pandas = table.to_pandas(date_as_object=False) + tm.assert_frame_equal(table_pandas, expected_pandas) + + table_pandas_objects = table.to_pandas() + ex_objects = ex_values.astype('object') + expected_pandas_objects = pd.DataFrame({'date32': ex_objects, + 'date64': ex_objects}, + columns=colnames) + tm.assert_frame_equal(table_pandas_objects, + expected_pandas_objects) + + def test_pandas_null_values(self): + # ARROW-842 + pd_NA = getattr(pd, 'NA', None) + values = np.array([datetime(2000, 1, 1), pd.NaT, pd_NA], dtype=object) + values_with_none = np.array([datetime(2000, 1, 1), None, None], + dtype=object) + result = pa.array(values, from_pandas=True) + expected = pa.array(values_with_none, from_pandas=True) + assert result.equals(expected) + assert result.null_count == 2 + + # ARROW-9407 + assert pa.array([pd.NaT], from_pandas=True).type == pa.null() + assert pa.array([pd_NA], from_pandas=True).type == pa.null() + + def test_dates_from_integers(self): + t1 = pa.date32() + t2 = pa.date64() + + arr = np.array([17259, 17260, 17261], dtype='int32') + arr2 = arr.astype('int64') * 86400000 + + a1 = pa.array(arr, type=t1) + a2 = pa.array(arr2, type=t2) + + expected = date(2017, 4, 3) + assert a1[0].as_py() == expected + assert a2[0].as_py() == expected + + def test_pytime_from_pandas(self): + pytimes = [time(1, 2, 3, 1356), + time(4, 5, 6, 1356)] + + # microseconds + t1 = pa.time64('us') + + aobjs = np.array(pytimes + [None], dtype=object) + parr = pa.array(aobjs) + assert parr.type == t1 + assert parr[0].as_py() == pytimes[0] + assert parr[1].as_py() == pytimes[1] + assert parr[2].as_py() is None + + # DataFrame + df = pd.DataFrame({'times': aobjs}) + batch = pa.RecordBatch.from_pandas(df) + assert batch[0].equals(parr) + + # Test ndarray of int64 values + arr = np.array([_pytime_to_micros(v) for v in pytimes], + dtype='int64') + + a1 = pa.array(arr, type=pa.time64('us')) + assert a1[0].as_py() == pytimes[0] + + a2 = pa.array(arr * 1000, type=pa.time64('ns')) + assert a2[0].as_py() == pytimes[0] + + a3 = pa.array((arr / 1000).astype('i4'), + type=pa.time32('ms')) + assert a3[0].as_py() == pytimes[0].replace(microsecond=1000) + + a4 = pa.array((arr / 1000000).astype('i4'), + type=pa.time32('s')) + assert a4[0].as_py() == pytimes[0].replace(microsecond=0) + + def test_arrow_time_to_pandas(self): + pytimes = [time(1, 2, 3, 1356), + time(4, 5, 6, 1356), + time(0, 0, 0)] + + expected = np.array(pytimes[:2] + [None]) + expected_ms = np.array([x.replace(microsecond=1000) + for x in pytimes[:2]] + + [None]) + expected_s = np.array([x.replace(microsecond=0) + for x in pytimes[:2]] + + [None]) + + arr = np.array([_pytime_to_micros(v) for v in pytimes], + dtype='int64') + arr = np.array([_pytime_to_micros(v) for v in pytimes], + dtype='int64') + + null_mask = np.array([False, False, True], dtype=bool) + + a1 = pa.array(arr, mask=null_mask, type=pa.time64('us')) + a2 = pa.array(arr * 1000, mask=null_mask, + type=pa.time64('ns')) + + a3 = pa.array((arr / 1000).astype('i4'), mask=null_mask, + type=pa.time32('ms')) + a4 = pa.array((arr / 1000000).astype('i4'), mask=null_mask, + type=pa.time32('s')) + + names = ['time64[us]', 'time64[ns]', 'time32[ms]', 'time32[s]'] + batch = pa.RecordBatch.from_arrays([a1, a2, a3, a4], names) + + for arr, expected_values in [(a1, expected), + (a2, expected), + (a3, expected_ms), + (a4, expected_s)]: + result_pandas = arr.to_pandas() + assert (result_pandas.values == expected_values).all() + + df = batch.to_pandas() + expected_df = pd.DataFrame({'time64[us]': expected, + 'time64[ns]': expected, + 'time32[ms]': expected_ms, + 'time32[s]': expected_s}, + columns=names) + + tm.assert_frame_equal(df, expected_df) + + def test_numpy_datetime64_columns(self): + datetime64_ns = np.array([ + '2007-07-13T01:23:34.123456789', + None, + '2006-01-13T12:34:56.432539784', + '2010-08-13T05:46:57.437699912'], + dtype='datetime64[ns]') + _check_array_from_pandas_roundtrip(datetime64_ns) + + datetime64_us = np.array([ + '2007-07-13T01:23:34.123456', + None, + '2006-01-13T12:34:56.432539', + '2010-08-13T05:46:57.437699'], + dtype='datetime64[us]') + _check_array_from_pandas_roundtrip(datetime64_us) + + datetime64_ms = np.array([ + '2007-07-13T01:23:34.123', + None, + '2006-01-13T12:34:56.432', + '2010-08-13T05:46:57.437'], + dtype='datetime64[ms]') + _check_array_from_pandas_roundtrip(datetime64_ms) + + datetime64_s = np.array([ + '2007-07-13T01:23:34', + None, + '2006-01-13T12:34:56', + '2010-08-13T05:46:57'], + dtype='datetime64[s]') + _check_array_from_pandas_roundtrip(datetime64_s) + + def test_timestamp_to_pandas_ns(self): + # non-ns timestamp gets cast to ns on conversion to pandas + arr = pa.array([1, 2, 3], pa.timestamp('ms')) + expected = pd.Series(pd.to_datetime([1, 2, 3], unit='ms')) + s = arr.to_pandas() + tm.assert_series_equal(s, expected) + arr = pa.chunked_array([arr]) + s = arr.to_pandas() + tm.assert_series_equal(s, expected) + + def test_timestamp_to_pandas_out_of_bounds(self): + # ARROW-7758 check for out of bounds timestamps for non-ns timestamps + + for unit in ['s', 'ms', 'us']: + for tz in [None, 'America/New_York']: + arr = pa.array([datetime(1, 1, 1)], pa.timestamp(unit, tz=tz)) + table = pa.table({'a': arr}) + + msg = "would result in out of bounds timestamp" + with pytest.raises(ValueError, match=msg): + arr.to_pandas() + + with pytest.raises(ValueError, match=msg): + table.to_pandas() + + with pytest.raises(ValueError, match=msg): + # chunked array + table.column('a').to_pandas() + + # just ensure those don't give an error, but do not + # check actual garbage output + arr.to_pandas(safe=False) + table.to_pandas(safe=False) + table.column('a').to_pandas(safe=False) + + def test_timestamp_to_pandas_empty_chunked(self): + # ARROW-7907 table with chunked array with 0 chunks + table = pa.table({'a': pa.chunked_array([], type=pa.timestamp('us'))}) + result = table.to_pandas() + expected = pd.DataFrame({'a': pd.Series([], dtype="datetime64[ns]")}) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize('dtype', [pa.date32(), pa.date64()]) + def test_numpy_datetime64_day_unit(self, dtype): + datetime64_d = np.array([ + '2007-07-13', + None, + '2006-01-15', + '2010-08-19'], + dtype='datetime64[D]') + _check_array_from_pandas_roundtrip(datetime64_d, type=dtype) + + def test_array_from_pandas_date_with_mask(self): + m = np.array([True, False, True]) + data = pd.Series([ + date(1990, 1, 1), + date(1991, 1, 1), + date(1992, 1, 1) + ]) + + result = pa.Array.from_pandas(data, mask=m) + + expected = pd.Series([None, date(1991, 1, 1), None]) + assert pa.Array.from_pandas(expected).equals(result) + + @pytest.mark.skipif( + Version('1.16.0') <= Version(np.__version__) < Version('1.16.1'), + reason='Until numpy/numpy#12745 is resolved') + def test_fixed_offset_timezone(self): + df = pd.DataFrame({ + 'a': [ + pd.Timestamp('2012-11-11 00:00:00+01:00'), + pd.NaT + ] + }) + _check_pandas_roundtrip(df) + _check_serialize_components_roundtrip(df) + + def test_timedeltas_no_nulls(self): + df = pd.DataFrame({ + 'timedelta64': np.array([0, 3600000000000, 7200000000000], + dtype='timedelta64[ns]') + }) + field = pa.field('timedelta64', pa.duration('ns')) + schema = pa.schema([field]) + _check_pandas_roundtrip( + df, + expected_schema=schema, + ) + + def test_timedeltas_nulls(self): + df = pd.DataFrame({ + 'timedelta64': np.array([0, None, 7200000000000], + dtype='timedelta64[ns]') + }) + field = pa.field('timedelta64', pa.duration('ns')) + schema = pa.schema([field]) + _check_pandas_roundtrip( + df, + expected_schema=schema, + ) + + def test_month_day_nano_interval(self): + from pandas.tseries.offsets import DateOffset + df = pd.DataFrame({ + 'date_offset': [None, + DateOffset(days=3600, months=3600, microseconds=3, + nanoseconds=600)] + }) + schema = pa.schema([('date_offset', pa.month_day_nano_interval())]) + _check_pandas_roundtrip( + df, + expected_schema=schema) + + +# ---------------------------------------------------------------------- +# Conversion tests for string and binary types. + + +class TestConvertStringLikeTypes: + + def test_pandas_unicode(self): + repeats = 1000 + values = ['foo', None, 'bar', 'mañana', np.nan] + df = pd.DataFrame({'strings': values * repeats}) + field = pa.field('strings', pa.string()) + schema = pa.schema([field]) + + _check_pandas_roundtrip(df, expected_schema=schema) + + def test_bytes_to_binary(self): + values = ['qux', b'foo', None, bytearray(b'barz'), 'qux', np.nan] + df = pd.DataFrame({'strings': values}) + + table = pa.Table.from_pandas(df) + assert table[0].type == pa.binary() + + values2 = [b'qux', b'foo', None, b'barz', b'qux', np.nan] + expected = pd.DataFrame({'strings': values2}) + _check_pandas_roundtrip(df, expected) + + @pytest.mark.large_memory + def test_bytes_exceed_2gb(self): + v1 = b'x' * 100000000 + v2 = b'x' * 147483646 + + # ARROW-2227, hit exactly 2GB on the nose + df = pd.DataFrame({ + 'strings': [v1] * 20 + [v2] + ['x'] * 20 + }) + arr = pa.array(df['strings']) + assert isinstance(arr, pa.ChunkedArray) + assert arr.num_chunks == 2 + arr = None + + table = pa.Table.from_pandas(df) + assert table[0].num_chunks == 2 + + @pytest.mark.large_memory + @pytest.mark.parametrize('char', ['x', b'x']) + def test_auto_chunking_pandas_series_of_strings(self, char): + # ARROW-2367 + v1 = char * 100000000 + v2 = char * 147483646 + + df = pd.DataFrame({ + 'strings': [[v1]] * 20 + [[v2]] + [[b'x']] + }) + arr = pa.array(df['strings'], from_pandas=True) + assert isinstance(arr, pa.ChunkedArray) + assert arr.num_chunks == 2 + assert len(arr.chunk(0)) == 21 + assert len(arr.chunk(1)) == 1 + + def test_fixed_size_bytes(self): + values = [b'foo', None, bytearray(b'bar'), None, None, b'hey'] + df = pd.DataFrame({'strings': values}) + schema = pa.schema([pa.field('strings', pa.binary(3))]) + table = pa.Table.from_pandas(df, schema=schema) + assert table.schema[0].type == schema[0].type + assert table.schema[0].name == schema[0].name + result = table.to_pandas() + tm.assert_frame_equal(result, df) + + def test_fixed_size_bytes_does_not_accept_varying_lengths(self): + values = [b'foo', None, b'ba', None, None, b'hey'] + df = pd.DataFrame({'strings': values}) + schema = pa.schema([pa.field('strings', pa.binary(3))]) + with pytest.raises(pa.ArrowInvalid): + pa.Table.from_pandas(df, schema=schema) + + def test_variable_size_bytes(self): + s = pd.Series([b'123', b'', b'a', None]) + _check_series_roundtrip(s, type_=pa.binary()) + + def test_binary_from_bytearray(self): + s = pd.Series([bytearray(b'123'), bytearray(b''), bytearray(b'a'), + None]) + # Explicitly set type + _check_series_roundtrip(s, type_=pa.binary()) + # Infer type from bytearrays + _check_series_roundtrip(s, expected_pa_type=pa.binary()) + + def test_large_binary(self): + s = pd.Series([b'123', b'', b'a', None]) + _check_series_roundtrip(s, type_=pa.large_binary()) + df = pd.DataFrame({'a': s}) + _check_pandas_roundtrip( + df, schema=pa.schema([('a', pa.large_binary())])) + + def test_large_string(self): + s = pd.Series(['123', '', 'a', None]) + _check_series_roundtrip(s, type_=pa.large_string()) + df = pd.DataFrame({'a': s}) + _check_pandas_roundtrip( + df, schema=pa.schema([('a', pa.large_string())])) + + def test_table_empty_str(self): + values = ['', '', '', '', ''] + df = pd.DataFrame({'strings': values}) + field = pa.field('strings', pa.string()) + schema = pa.schema([field]) + table = pa.Table.from_pandas(df, schema=schema) + + result1 = table.to_pandas(strings_to_categorical=False) + expected1 = pd.DataFrame({'strings': values}) + tm.assert_frame_equal(result1, expected1, check_dtype=True) + + result2 = table.to_pandas(strings_to_categorical=True) + expected2 = pd.DataFrame({'strings': pd.Categorical(values)}) + tm.assert_frame_equal(result2, expected2, check_dtype=True) + + def test_selective_categoricals(self): + values = ['', '', '', '', ''] + df = pd.DataFrame({'strings': values}) + field = pa.field('strings', pa.string()) + schema = pa.schema([field]) + table = pa.Table.from_pandas(df, schema=schema) + expected_str = pd.DataFrame({'strings': values}) + expected_cat = pd.DataFrame({'strings': pd.Categorical(values)}) + + result1 = table.to_pandas(categories=['strings']) + tm.assert_frame_equal(result1, expected_cat, check_dtype=True) + result2 = table.to_pandas(categories=[]) + tm.assert_frame_equal(result2, expected_str, check_dtype=True) + result3 = table.to_pandas(categories=('strings',)) + tm.assert_frame_equal(result3, expected_cat, check_dtype=True) + result4 = table.to_pandas(categories=tuple()) + tm.assert_frame_equal(result4, expected_str, check_dtype=True) + + def test_to_pandas_categorical_zero_length(self): + # ARROW-3586 + array = pa.array([], type=pa.int32()) + table = pa.Table.from_arrays(arrays=[array], names=['col']) + # This would segfault under 0.11.0 + table.to_pandas(categories=['col']) + + def test_to_pandas_categories_already_dictionary(self): + # Showed up in ARROW-6434, ARROW-6435 + array = pa.array(['foo', 'foo', 'foo', 'bar']).dictionary_encode() + table = pa.Table.from_arrays(arrays=[array], names=['col']) + result = table.to_pandas(categories=['col']) + assert table.to_pandas().equals(result) + + def test_table_str_to_categorical_without_na(self): + values = ['a', 'a', 'b', 'b', 'c'] + df = pd.DataFrame({'strings': values}) + field = pa.field('strings', pa.string()) + schema = pa.schema([field]) + table = pa.Table.from_pandas(df, schema=schema) + + result = table.to_pandas(strings_to_categorical=True) + expected = pd.DataFrame({'strings': pd.Categorical(values)}) + tm.assert_frame_equal(result, expected, check_dtype=True) + + with pytest.raises(pa.ArrowInvalid): + table.to_pandas(strings_to_categorical=True, + zero_copy_only=True) + + def test_table_str_to_categorical_with_na(self): + values = [None, 'a', 'b', np.nan] + df = pd.DataFrame({'strings': values}) + field = pa.field('strings', pa.string()) + schema = pa.schema([field]) + table = pa.Table.from_pandas(df, schema=schema) + + result = table.to_pandas(strings_to_categorical=True) + expected = pd.DataFrame({'strings': pd.Categorical(values)}) + tm.assert_frame_equal(result, expected, check_dtype=True) + + with pytest.raises(pa.ArrowInvalid): + table.to_pandas(strings_to_categorical=True, + zero_copy_only=True) + + # Regression test for ARROW-2101 + def test_array_of_bytes_to_strings(self): + converted = pa.array(np.array([b'x'], dtype=object), pa.string()) + assert converted.type == pa.string() + + # Make sure that if an ndarray of bytes is passed to the array + # constructor and the type is string, it will fail if those bytes + # cannot be converted to utf-8 + def test_array_of_bytes_to_strings_bad_data(self): + with pytest.raises( + pa.lib.ArrowInvalid, + match="was not a utf8 string"): + pa.array(np.array([b'\x80\x81'], dtype=object), pa.string()) + + def test_numpy_string_array_to_fixed_size_binary(self): + arr = np.array([b'foo', b'bar', b'baz'], dtype='|S3') + + converted = pa.array(arr, type=pa.binary(3)) + expected = pa.array(list(arr), type=pa.binary(3)) + assert converted.equals(expected) + + mask = np.array([False, True, False]) + converted = pa.array(arr, type=pa.binary(3), mask=mask) + expected = pa.array([b'foo', None, b'baz'], type=pa.binary(3)) + assert converted.equals(expected) + + with pytest.raises(pa.lib.ArrowInvalid, + match=r'Got bytestring of length 3 \(expected 4\)'): + arr = np.array([b'foo', b'bar', b'baz'], dtype='|S3') + pa.array(arr, type=pa.binary(4)) + + with pytest.raises( + pa.lib.ArrowInvalid, + match=r'Got bytestring of length 12 \(expected 3\)'): + arr = np.array([b'foo', b'bar', b'baz'], dtype='|U3') + pa.array(arr, type=pa.binary(3)) + + +class TestConvertDecimalTypes: + """ + Conversion test for decimal types. + """ + decimal32 = [ + decimal.Decimal('-1234.123'), + decimal.Decimal('1234.439') + ] + decimal64 = [ + decimal.Decimal('-129934.123331'), + decimal.Decimal('129534.123731') + ] + decimal128 = [ + decimal.Decimal('394092382910493.12341234678'), + decimal.Decimal('-314292388910493.12343437128') + ] + + @pytest.mark.parametrize(('values', 'expected_type'), [ + pytest.param(decimal32, pa.decimal128(7, 3), id='decimal32'), + pytest.param(decimal64, pa.decimal128(12, 6), id='decimal64'), + pytest.param(decimal128, pa.decimal128(26, 11), id='decimal128') + ]) + def test_decimal_from_pandas(self, values, expected_type): + expected = pd.DataFrame({'decimals': values}) + table = pa.Table.from_pandas(expected, preserve_index=False) + field = pa.field('decimals', expected_type) + + # schema's metadata is generated by from_pandas conversion + expected_schema = pa.schema([field], metadata=table.schema.metadata) + assert table.schema.equals(expected_schema) + + @pytest.mark.parametrize('values', [ + pytest.param(decimal32, id='decimal32'), + pytest.param(decimal64, id='decimal64'), + pytest.param(decimal128, id='decimal128') + ]) + def test_decimal_to_pandas(self, values): + expected = pd.DataFrame({'decimals': values}) + converted = pa.Table.from_pandas(expected) + df = converted.to_pandas() + tm.assert_frame_equal(df, expected) + + def test_decimal_fails_with_truncation(self): + data1 = [decimal.Decimal('1.234')] + type1 = pa.decimal128(10, 2) + with pytest.raises(pa.ArrowInvalid): + pa.array(data1, type=type1) + + data2 = [decimal.Decimal('1.2345')] + type2 = pa.decimal128(10, 3) + with pytest.raises(pa.ArrowInvalid): + pa.array(data2, type=type2) + + def test_decimal_with_different_precisions(self): + data = [ + decimal.Decimal('0.01'), + decimal.Decimal('0.001'), + ] + series = pd.Series(data) + array = pa.array(series) + assert array.to_pylist() == data + assert array.type == pa.decimal128(3, 3) + + array = pa.array(data, type=pa.decimal128(12, 5)) + expected = [decimal.Decimal('0.01000'), decimal.Decimal('0.00100')] + assert array.to_pylist() == expected + + def test_decimal_with_None_explicit_type(self): + series = pd.Series([decimal.Decimal('3.14'), None]) + _check_series_roundtrip(series, type_=pa.decimal128(12, 5)) + + # Test that having all None values still produces decimal array + series = pd.Series([None] * 2) + _check_series_roundtrip(series, type_=pa.decimal128(12, 5)) + + def test_decimal_with_None_infer_type(self): + series = pd.Series([decimal.Decimal('3.14'), None]) + _check_series_roundtrip(series, expected_pa_type=pa.decimal128(3, 2)) + + def test_strided_objects(self, tmpdir): + # see ARROW-3053 + data = { + 'a': {0: 'a'}, + 'b': {0: decimal.Decimal('0.0')} + } + + # This yields strided objects + df = pd.DataFrame.from_dict(data) + _check_pandas_roundtrip(df) + + +class TestConvertListTypes: + """ + Conversion tests for list<> types. + """ + + def test_column_of_arrays(self): + df, schema = dataframe_with_arrays() + _check_pandas_roundtrip(df, schema=schema, expected_schema=schema) + table = pa.Table.from_pandas(df, schema=schema, preserve_index=False) + + # schema's metadata is generated by from_pandas conversion + expected_schema = schema.with_metadata(table.schema.metadata) + assert table.schema.equals(expected_schema) + + for column in df.columns: + field = schema.field(column) + _check_array_roundtrip(df[column], type=field.type) + + def test_column_of_arrays_to_py(self): + # Test regression in ARROW-1199 not caught in above test + dtype = 'i1' + arr = np.array([ + np.arange(10, dtype=dtype), + np.arange(5, dtype=dtype), + None, + np.arange(1, dtype=dtype) + ], dtype=object) + type_ = pa.list_(pa.int8()) + parr = pa.array(arr, type=type_) + + assert parr[0].as_py() == list(range(10)) + assert parr[1].as_py() == list(range(5)) + assert parr[2].as_py() is None + assert parr[3].as_py() == [0] + + def test_column_of_boolean_list(self): + # ARROW-4370: Table to pandas conversion fails for list of bool + array = pa.array([[True, False], [True]], type=pa.list_(pa.bool_())) + table = pa.Table.from_arrays([array], names=['col1']) + df = table.to_pandas() + + expected_df = pd.DataFrame({'col1': [[True, False], [True]]}) + tm.assert_frame_equal(df, expected_df) + + s = table[0].to_pandas() + tm.assert_series_equal(pd.Series(s), df['col1'], check_names=False) + + def test_column_of_decimal_list(self): + array = pa.array([[decimal.Decimal('1'), decimal.Decimal('2')], + [decimal.Decimal('3.3')]], + type=pa.list_(pa.decimal128(2, 1))) + table = pa.Table.from_arrays([array], names=['col1']) + df = table.to_pandas() + + expected_df = pd.DataFrame( + {'col1': [[decimal.Decimal('1'), decimal.Decimal('2')], + [decimal.Decimal('3.3')]]}) + tm.assert_frame_equal(df, expected_df) + + def test_nested_types_from_ndarray_null_entries(self): + # Root cause of ARROW-6435 + s = pd.Series(np.array([np.nan, np.nan], dtype=object)) + + for ty in [pa.list_(pa.int64()), + pa.large_list(pa.int64()), + pa.struct([pa.field('f0', 'int32')])]: + result = pa.array(s, type=ty) + expected = pa.array([None, None], type=ty) + assert result.equals(expected) + + with pytest.raises(TypeError): + pa.array(s.values, type=ty) + + def test_column_of_lists(self): + df, schema = dataframe_with_lists() + _check_pandas_roundtrip(df, schema=schema, expected_schema=schema) + table = pa.Table.from_pandas(df, schema=schema, preserve_index=False) + + # schema's metadata is generated by from_pandas conversion + expected_schema = schema.with_metadata(table.schema.metadata) + assert table.schema.equals(expected_schema) + + for column in df.columns: + field = schema.field(column) + _check_array_roundtrip(df[column], type=field.type) + + def test_column_of_lists_first_empty(self): + # ARROW-2124 + num_lists = [[], [2, 3, 4], [3, 6, 7, 8], [], [2]] + series = pd.Series([np.array(s, dtype=float) for s in num_lists]) + arr = pa.array(series) + result = pd.Series(arr.to_pandas()) + tm.assert_series_equal(result, series) + + def test_column_of_lists_chunked(self): + # ARROW-1357 + df = pd.DataFrame({ + 'lists': np.array([ + [1, 2], + None, + [2, 3], + [4, 5], + [6, 7], + [8, 9] + ], dtype=object) + }) + + schema = pa.schema([ + pa.field('lists', pa.list_(pa.int64())) + ]) + + t1 = pa.Table.from_pandas(df[:2], schema=schema) + t2 = pa.Table.from_pandas(df[2:], schema=schema) + + table = pa.concat_tables([t1, t2]) + result = table.to_pandas() + + tm.assert_frame_equal(result, df) + + def test_empty_column_of_lists_chunked(self): + df = pd.DataFrame({ + 'lists': np.array([], dtype=object) + }) + + schema = pa.schema([ + pa.field('lists', pa.list_(pa.int64())) + ]) + + table = pa.Table.from_pandas(df, schema=schema) + result = table.to_pandas() + + tm.assert_frame_equal(result, df) + + def test_column_of_lists_chunked2(self): + data1 = [[0, 1], [2, 3], [4, 5], [6, 7], [10, 11], + [12, 13], [14, 15], [16, 17]] + data2 = [[8, 9], [18, 19]] + + a1 = pa.array(data1) + a2 = pa.array(data2) + + t1 = pa.Table.from_arrays([a1], names=['a']) + t2 = pa.Table.from_arrays([a2], names=['a']) + + concatenated = pa.concat_tables([t1, t2]) + + result = concatenated.to_pandas() + expected = pd.DataFrame({'a': data1 + data2}) + + tm.assert_frame_equal(result, expected) + + def test_column_of_lists_strided(self): + df, schema = dataframe_with_lists() + df = pd.concat([df] * 6, ignore_index=True) + + arr = df['int64'].values[::3] + assert arr.strides[0] != 8 + + _check_array_roundtrip(arr) + + def test_nested_lists_all_none(self): + data = np.array([[None, None], None], dtype=object) + + arr = pa.array(data) + expected = pa.array(list(data)) + assert arr.equals(expected) + assert arr.type == pa.list_(pa.null()) + + data2 = np.array([None, None, [None, None], + np.array([None, None], dtype=object)], + dtype=object) + arr = pa.array(data2) + expected = pa.array([None, None, [None, None], [None, None]]) + assert arr.equals(expected) + + def test_nested_lists_all_empty(self): + # ARROW-2128 + data = pd.Series([[], [], []]) + arr = pa.array(data) + expected = pa.array(list(data)) + assert arr.equals(expected) + assert arr.type == pa.list_(pa.null()) + + def test_nested_list_first_empty(self): + # ARROW-2711 + data = pd.Series([[], ["a"]]) + arr = pa.array(data) + expected = pa.array(list(data)) + assert arr.equals(expected) + assert arr.type == pa.list_(pa.string()) + + def test_nested_smaller_ints(self): + # ARROW-1345, ARROW-2008, there were some type inference bugs happening + # before + data = pd.Series([np.array([1, 2, 3], dtype='i1'), None]) + result = pa.array(data) + result2 = pa.array(data.values) + expected = pa.array([[1, 2, 3], None], type=pa.list_(pa.int8())) + assert result.equals(expected) + assert result2.equals(expected) + + data3 = pd.Series([np.array([1, 2, 3], dtype='f4'), None]) + result3 = pa.array(data3) + expected3 = pa.array([[1, 2, 3], None], type=pa.list_(pa.float32())) + assert result3.equals(expected3) + + def test_infer_lists(self): + data = OrderedDict([ + ('nan_ints', [[None, 1], [2, 3]]), + ('ints', [[0, 1], [2, 3]]), + ('strs', [[None, 'b'], ['c', 'd']]), + ('nested_strs', [[[None, 'b'], ['c', 'd']], None]) + ]) + df = pd.DataFrame(data) + + expected_schema = pa.schema([ + pa.field('nan_ints', pa.list_(pa.int64())), + pa.field('ints', pa.list_(pa.int64())), + pa.field('strs', pa.list_(pa.string())), + pa.field('nested_strs', pa.list_(pa.list_(pa.string()))) + ]) + + _check_pandas_roundtrip(df, expected_schema=expected_schema) + + def test_fixed_size_list(self): + # ARROW-7365 + fixed_ty = pa.list_(pa.int64(), list_size=4) + variable_ty = pa.list_(pa.int64()) + + data = [[0, 1, 2, 3], None, [4, 5, 6, 7], [8, 9, 10, 11]] + fixed_arr = pa.array(data, type=fixed_ty) + variable_arr = pa.array(data, type=variable_ty) + + result = fixed_arr.to_pandas() + expected = variable_arr.to_pandas() + + for left, right in zip(result, expected): + if left is None: + assert right is None + npt.assert_array_equal(left, right) + + def test_infer_numpy_array(self): + data = OrderedDict([ + ('ints', [ + np.array([0, 1], dtype=np.int64), + np.array([2, 3], dtype=np.int64) + ]) + ]) + df = pd.DataFrame(data) + expected_schema = pa.schema([ + pa.field('ints', pa.list_(pa.int64())) + ]) + + _check_pandas_roundtrip(df, expected_schema=expected_schema) + + def test_to_list_of_structs_pandas(self): + ints = pa.array([1, 2, 3], pa.int32()) + strings = pa.array([['a', 'b'], ['c', 'd'], ['e', 'f']], + pa.list_(pa.string())) + structs = pa.StructArray.from_arrays([ints, strings], ['f1', 'f2']) + data = pa.ListArray.from_arrays([0, 1, 3], structs) + + expected = pd.Series([ + [{'f1': 1, 'f2': ['a', 'b']}], + [{'f1': 2, 'f2': ['c', 'd']}, + {'f1': 3, 'f2': ['e', 'f']}] + ]) + + series = pd.Series(data.to_pandas()) + tm.assert_series_equal(series, expected) + + @pytest.mark.parametrize('t,data,expected', [ + ( + pa.int64, + [[1, 2], [3], None], + [None, [3], None] + ), + ( + pa.string, + [['aaa', 'bb'], ['c'], None], + [None, ['c'], None] + ), + ( + pa.null, + [[None, None], [None], None], + [None, [None], None] + ) + ]) + def test_array_from_pandas_typed_array_with_mask(self, t, data, expected): + m = np.array([True, False, True]) + + s = pd.Series(data) + result = pa.Array.from_pandas(s, mask=m, type=pa.list_(t())) + + assert pa.Array.from_pandas(expected, + type=pa.list_(t())).equals(result) + + def test_empty_list_roundtrip(self): + empty_list_array = np.empty((3,), dtype=object) + empty_list_array.fill([]) + + df = pd.DataFrame({'a': np.array(['1', '2', '3']), + 'b': empty_list_array}) + tbl = pa.Table.from_pandas(df) + + result = tbl.to_pandas() + + tm.assert_frame_equal(result, df) + + def test_array_from_nested_arrays(self): + df, schema = dataframe_with_arrays() + for field in schema: + arr = df[field.name].values + expected = pa.array(list(arr), type=field.type) + result = pa.array(arr) + assert result.type == field.type # == list<scalar> + assert result.equals(expected) + + def test_nested_large_list(self): + s = (pa.array([[[1, 2, 3], [4]], None], + type=pa.large_list(pa.large_list(pa.int64()))) + .to_pandas()) + tm.assert_series_equal( + s, pd.Series([[[1, 2, 3], [4]], None], dtype=object), + check_names=False) + + def test_large_binary_list(self): + for list_type_factory in (pa.list_, pa.large_list): + s = (pa.array([["aa", "bb"], None, ["cc"], []], + type=list_type_factory(pa.large_binary())) + .to_pandas()) + tm.assert_series_equal( + s, pd.Series([[b"aa", b"bb"], None, [b"cc"], []]), + check_names=False) + s = (pa.array([["aa", "bb"], None, ["cc"], []], + type=list_type_factory(pa.large_string())) + .to_pandas()) + tm.assert_series_equal( + s, pd.Series([["aa", "bb"], None, ["cc"], []]), + check_names=False) + + def test_list_of_dictionary(self): + child = pa.array(["foo", "bar", None, "foo"]).dictionary_encode() + arr = pa.ListArray.from_arrays([0, 1, 3, 3, 4], child) + + # Expected a Series of lists + expected = pd.Series(arr.to_pylist()) + tm.assert_series_equal(arr.to_pandas(), expected) + + # Same but with nulls + arr = arr.take([0, 1, None, 3]) + expected[2] = None + tm.assert_series_equal(arr.to_pandas(), expected) + + @pytest.mark.large_memory + def test_auto_chunking_on_list_overflow(self): + # ARROW-9976 + n = 2**21 + df = pd.DataFrame.from_dict({ + "a": list(np.zeros((n, 2**10), dtype='uint8')), + "b": range(n) + }) + table = pa.Table.from_pandas(df) + + column_a = table[0] + assert column_a.num_chunks == 2 + assert len(column_a.chunk(0)) == 2**21 - 1 + assert len(column_a.chunk(1)) == 1 + + def test_map_array_roundtrip(self): + data = [[(b'a', 1), (b'b', 2)], + [(b'c', 3)], + [(b'd', 4), (b'e', 5), (b'f', 6)], + [(b'g', 7)]] + + df = pd.DataFrame({"map": data}) + schema = pa.schema([("map", pa.map_(pa.binary(), pa.int32()))]) + + _check_pandas_roundtrip(df, schema=schema) + + def test_map_array_chunked(self): + data1 = [[(b'a', 1), (b'b', 2)], + [(b'c', 3)], + [(b'd', 4), (b'e', 5), (b'f', 6)], + [(b'g', 7)]] + data2 = [[(k, v * 2) for k, v in row] for row in data1] + + arr1 = pa.array(data1, type=pa.map_(pa.binary(), pa.int32())) + arr2 = pa.array(data2, type=pa.map_(pa.binary(), pa.int32())) + arr = pa.chunked_array([arr1, arr2]) + + expected = pd.Series(data1 + data2) + actual = arr.to_pandas() + tm.assert_series_equal(actual, expected, check_names=False) + + def test_map_array_with_nulls(self): + data = [[(b'a', 1), (b'b', 2)], + None, + [(b'd', 4), (b'e', 5), (b'f', None)], + [(b'g', 7)]] + + # None value in item array causes upcast to float + expected = [[(k, float(v) if v is not None else None) for k, v in row] + if row is not None else None for row in data] + expected = pd.Series(expected) + + arr = pa.array(data, type=pa.map_(pa.binary(), pa.int32())) + actual = arr.to_pandas() + tm.assert_series_equal(actual, expected, check_names=False) + + def test_map_array_dictionary_encoded(self): + offsets = pa.array([0, 3, 5]) + items = pa.array(['a', 'b', 'c', 'a', 'd']).dictionary_encode() + keys = pa.array(list(range(len(items)))) + arr = pa.MapArray.from_arrays(offsets, keys, items) + + # Dictionary encoded values converted to dense + expected = pd.Series( + [[(0, 'a'), (1, 'b'), (2, 'c')], [(3, 'a'), (4, 'd')]]) + + actual = arr.to_pandas() + tm.assert_series_equal(actual, expected, check_names=False) + + +class TestConvertStructTypes: + """ + Conversion tests for struct types. + """ + + def test_pandas_roundtrip(self): + df = pd.DataFrame({'dicts': [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]}) + + expected_schema = pa.schema([ + ('dicts', pa.struct([('a', pa.int64()), ('b', pa.int64())])), + ]) + + _check_pandas_roundtrip(df, expected_schema=expected_schema) + + # specifying schema explicitly in from_pandas + _check_pandas_roundtrip( + df, schema=expected_schema, expected_schema=expected_schema) + + def test_to_pandas(self): + ints = pa.array([None, 2, 3], type=pa.int64()) + strs = pa.array(['a', None, 'c'], type=pa.string()) + bools = pa.array([True, False, None], type=pa.bool_()) + arr = pa.StructArray.from_arrays( + [ints, strs, bools], + ['ints', 'strs', 'bools']) + + expected = pd.Series([ + {'ints': None, 'strs': 'a', 'bools': True}, + {'ints': 2, 'strs': None, 'bools': False}, + {'ints': 3, 'strs': 'c', 'bools': None}, + ]) + + series = pd.Series(arr.to_pandas()) + tm.assert_series_equal(series, expected) + + def test_to_pandas_multiple_chunks(self): + # ARROW-11855 + gc.collect() + bytes_start = pa.total_allocated_bytes() + ints1 = pa.array([1], type=pa.int64()) + ints2 = pa.array([2], type=pa.int64()) + arr1 = pa.StructArray.from_arrays([ints1], ['ints']) + arr2 = pa.StructArray.from_arrays([ints2], ['ints']) + arr = pa.chunked_array([arr1, arr2]) + + expected = pd.Series([ + {'ints': 1}, + {'ints': 2} + ]) + + series = pd.Series(arr.to_pandas()) + tm.assert_series_equal(series, expected) + + del series + del arr + del arr1 + del arr2 + del ints1 + del ints2 + bytes_end = pa.total_allocated_bytes() + assert bytes_end == bytes_start + + def test_from_numpy(self): + dt = np.dtype([('x', np.int32), + (('y_title', 'y'), np.bool_)]) + ty = pa.struct([pa.field('x', pa.int32()), + pa.field('y', pa.bool_())]) + + data = np.array([], dtype=dt) + arr = pa.array(data, type=ty) + assert arr.to_pylist() == [] + + data = np.array([(42, True), (43, False)], dtype=dt) + arr = pa.array(data, type=ty) + assert arr.to_pylist() == [{'x': 42, 'y': True}, + {'x': 43, 'y': False}] + + # With mask + arr = pa.array(data, mask=np.bool_([False, True]), type=ty) + assert arr.to_pylist() == [{'x': 42, 'y': True}, None] + + # Trivial struct type + dt = np.dtype([]) + ty = pa.struct([]) + + data = np.array([], dtype=dt) + arr = pa.array(data, type=ty) + assert arr.to_pylist() == [] + + data = np.array([(), ()], dtype=dt) + arr = pa.array(data, type=ty) + assert arr.to_pylist() == [{}, {}] + + def test_from_numpy_nested(self): + # Note: an object field inside a struct + dt = np.dtype([('x', np.dtype([('xx', np.int8), + ('yy', np.bool_)])), + ('y', np.int16), + ('z', np.object_)]) + # Note: itemsize is not a multiple of sizeof(object) + assert dt.itemsize == 12 + ty = pa.struct([pa.field('x', pa.struct([pa.field('xx', pa.int8()), + pa.field('yy', pa.bool_())])), + pa.field('y', pa.int16()), + pa.field('z', pa.string())]) + + data = np.array([], dtype=dt) + arr = pa.array(data, type=ty) + assert arr.to_pylist() == [] + + data = np.array([ + ((1, True), 2, 'foo'), + ((3, False), 4, 'bar')], dtype=dt) + arr = pa.array(data, type=ty) + assert arr.to_pylist() == [ + {'x': {'xx': 1, 'yy': True}, 'y': 2, 'z': 'foo'}, + {'x': {'xx': 3, 'yy': False}, 'y': 4, 'z': 'bar'}] + + @pytest.mark.slow + @pytest.mark.large_memory + def test_from_numpy_large(self): + # Exercise rechunking + nulls + target_size = 3 * 1024**3 # 4GB + dt = np.dtype([('x', np.float64), ('y', 'object')]) + bs = 65536 - dt.itemsize + block = b'.' * bs + n = target_size // (bs + dt.itemsize) + data = np.zeros(n, dtype=dt) + data['x'] = np.random.random_sample(n) + data['y'] = block + # Add implicit nulls + data['x'][data['x'] < 0.2] = np.nan + + ty = pa.struct([pa.field('x', pa.float64()), + pa.field('y', pa.binary())]) + arr = pa.array(data, type=ty, from_pandas=True) + assert arr.num_chunks == 2 + + def iter_chunked_array(arr): + for chunk in arr.iterchunks(): + yield from chunk + + def check(arr, data, mask=None): + assert len(arr) == len(data) + xs = data['x'] + ys = data['y'] + for i, obj in enumerate(iter_chunked_array(arr)): + try: + d = obj.as_py() + if mask is not None and mask[i]: + assert d is None + else: + x = xs[i] + if np.isnan(x): + assert d['x'] is None + else: + assert d['x'] == x + assert d['y'] == ys[i] + except Exception: + print("Failed at index", i) + raise + + check(arr, data) + del arr + + # Now with explicit mask + mask = np.random.random_sample(n) < 0.2 + arr = pa.array(data, type=ty, mask=mask, from_pandas=True) + assert arr.num_chunks == 2 + + check(arr, data, mask) + del arr + + def test_from_numpy_bad_input(self): + ty = pa.struct([pa.field('x', pa.int32()), + pa.field('y', pa.bool_())]) + dt = np.dtype([('x', np.int32), + ('z', np.bool_)]) + + data = np.array([], dtype=dt) + with pytest.raises(ValueError, + match="Missing field 'y'"): + pa.array(data, type=ty) + data = np.int32([]) + with pytest.raises(TypeError, + match="Expected struct array"): + pa.array(data, type=ty) + + def test_from_tuples(self): + df = pd.DataFrame({'tuples': [(1, 2), (3, 4)]}) + expected_df = pd.DataFrame( + {'tuples': [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]}) + + # conversion from tuples works when specifying expected struct type + struct_type = pa.struct([('a', pa.int64()), ('b', pa.int64())]) + + arr = np.asarray(df['tuples']) + _check_array_roundtrip( + arr, expected=expected_df['tuples'], type=struct_type) + + expected_schema = pa.schema([('tuples', struct_type)]) + _check_pandas_roundtrip( + df, expected=expected_df, schema=expected_schema, + expected_schema=expected_schema) + + def test_struct_of_dictionary(self): + names = ['ints', 'strs'] + children = [pa.array([456, 789, 456]).dictionary_encode(), + pa.array(["foo", "foo", None]).dictionary_encode()] + arr = pa.StructArray.from_arrays(children, names=names) + + # Expected a Series of {field name: field value} dicts + rows_as_tuples = zip(*(child.to_pylist() for child in children)) + rows_as_dicts = [dict(zip(names, row)) for row in rows_as_tuples] + + expected = pd.Series(rows_as_dicts) + tm.assert_series_equal(arr.to_pandas(), expected) + + # Same but with nulls + arr = arr.take([0, None, 2]) + expected[1] = None + tm.assert_series_equal(arr.to_pandas(), expected) + + +class TestZeroCopyConversion: + """ + Tests that zero-copy conversion works with some types. + """ + + def test_zero_copy_success(self): + result = pa.array([0, 1, 2]).to_pandas(zero_copy_only=True) + npt.assert_array_equal(result, [0, 1, 2]) + + def test_zero_copy_dictionaries(self): + arr = pa.DictionaryArray.from_arrays( + np.array([0, 0]), + np.array([5])) + + result = arr.to_pandas(zero_copy_only=True) + values = pd.Categorical([5, 5]) + + tm.assert_series_equal(pd.Series(result), pd.Series(values), + check_names=False) + + def test_zero_copy_timestamp(self): + arr = np.array(['2007-07-13'], dtype='datetime64[ns]') + result = pa.array(arr).to_pandas(zero_copy_only=True) + npt.assert_array_equal(result, arr) + + def test_zero_copy_duration(self): + arr = np.array([1], dtype='timedelta64[ns]') + result = pa.array(arr).to_pandas(zero_copy_only=True) + npt.assert_array_equal(result, arr) + + def check_zero_copy_failure(self, arr): + with pytest.raises(pa.ArrowInvalid): + arr.to_pandas(zero_copy_only=True) + + def test_zero_copy_failure_on_object_types(self): + self.check_zero_copy_failure(pa.array(['A', 'B', 'C'])) + + def test_zero_copy_failure_with_int_when_nulls(self): + self.check_zero_copy_failure(pa.array([0, 1, None])) + + def test_zero_copy_failure_with_float_when_nulls(self): + self.check_zero_copy_failure(pa.array([0.0, 1.0, None])) + + def test_zero_copy_failure_on_bool_types(self): + self.check_zero_copy_failure(pa.array([True, False])) + + def test_zero_copy_failure_on_list_types(self): + arr = pa.array([[1, 2], [8, 9]], type=pa.list_(pa.int64())) + self.check_zero_copy_failure(arr) + + def test_zero_copy_failure_on_timestamp_with_nulls(self): + arr = np.array([1, None], dtype='datetime64[ns]') + self.check_zero_copy_failure(pa.array(arr)) + + def test_zero_copy_failure_on_duration_with_nulls(self): + arr = np.array([1, None], dtype='timedelta64[ns]') + self.check_zero_copy_failure(pa.array(arr)) + + +def _non_threaded_conversion(): + df = _alltypes_example() + _check_pandas_roundtrip(df, use_threads=False) + _check_pandas_roundtrip(df, use_threads=False, as_batch=True) + + +def _threaded_conversion(): + df = _alltypes_example() + _check_pandas_roundtrip(df, use_threads=True) + _check_pandas_roundtrip(df, use_threads=True, as_batch=True) + + +class TestConvertMisc: + """ + Miscellaneous conversion tests. + """ + + type_pairs = [ + (np.int8, pa.int8()), + (np.int16, pa.int16()), + (np.int32, pa.int32()), + (np.int64, pa.int64()), + (np.uint8, pa.uint8()), + (np.uint16, pa.uint16()), + (np.uint32, pa.uint32()), + (np.uint64, pa.uint64()), + (np.float16, pa.float16()), + (np.float32, pa.float32()), + (np.float64, pa.float64()), + # XXX unsupported + # (np.dtype([('a', 'i2')]), pa.struct([pa.field('a', pa.int16())])), + (np.object_, pa.string()), + (np.object_, pa.binary()), + (np.object_, pa.binary(10)), + (np.object_, pa.list_(pa.int64())), + ] + + def test_all_none_objects(self): + df = pd.DataFrame({'a': [None, None, None]}) + _check_pandas_roundtrip(df) + + def test_all_none_category(self): + df = pd.DataFrame({'a': [None, None, None]}) + df['a'] = df['a'].astype('category') + _check_pandas_roundtrip(df) + + def test_empty_arrays(self): + for dtype, pa_type in self.type_pairs: + arr = np.array([], dtype=dtype) + _check_array_roundtrip(arr, type=pa_type) + + def test_non_threaded_conversion(self): + _non_threaded_conversion() + + def test_threaded_conversion_multiprocess(self): + # Parallel conversion should work from child processes too (ARROW-2963) + pool = mp.Pool(2) + try: + pool.apply(_threaded_conversion) + finally: + pool.close() + pool.join() + + def test_category(self): + repeats = 5 + v1 = ['foo', None, 'bar', 'qux', np.nan] + v2 = [4, 5, 6, 7, 8] + v3 = [b'foo', None, b'bar', b'qux', np.nan] + + arrays = { + 'cat_strings': pd.Categorical(v1 * repeats), + 'cat_strings_with_na': pd.Categorical(v1 * repeats, + categories=['foo', 'bar']), + 'cat_ints': pd.Categorical(v2 * repeats), + 'cat_binary': pd.Categorical(v3 * repeats), + 'cat_strings_ordered': pd.Categorical( + v1 * repeats, categories=['bar', 'qux', 'foo'], + ordered=True), + 'ints': v2 * repeats, + 'ints2': v2 * repeats, + 'strings': v1 * repeats, + 'strings2': v1 * repeats, + 'strings3': v3 * repeats} + df = pd.DataFrame(arrays) + _check_pandas_roundtrip(df) + + for k in arrays: + _check_array_roundtrip(arrays[k]) + + def test_category_implicit_from_pandas(self): + # ARROW-3374 + def _check(v): + arr = pa.array(v) + result = arr.to_pandas() + tm.assert_series_equal(pd.Series(result), pd.Series(v)) + + arrays = [ + pd.Categorical(['a', 'b', 'c'], categories=['a', 'b']), + pd.Categorical(['a', 'b', 'c'], categories=['a', 'b'], + ordered=True) + ] + for arr in arrays: + _check(arr) + + def test_empty_category(self): + # ARROW-2443 + df = pd.DataFrame({'cat': pd.Categorical([])}) + _check_pandas_roundtrip(df) + + def test_category_zero_chunks(self): + # ARROW-5952 + for pa_type, dtype in [(pa.string(), 'object'), (pa.int64(), 'int64')]: + a = pa.chunked_array([], pa.dictionary(pa.int8(), pa_type)) + result = a.to_pandas() + expected = pd.Categorical([], categories=np.array([], dtype=dtype)) + tm.assert_series_equal(pd.Series(result), pd.Series(expected)) + + table = pa.table({'a': a}) + result = table.to_pandas() + expected = pd.DataFrame({'a': expected}) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "data,error_type", + [ + ({"a": ["a", 1, 2.0]}, pa.ArrowTypeError), + ({"a": ["a", 1, 2.0]}, pa.ArrowTypeError), + ({"a": [1, True]}, pa.ArrowTypeError), + ({"a": [True, "a"]}, pa.ArrowInvalid), + ({"a": [1, "a"]}, pa.ArrowInvalid), + ({"a": [1.0, "a"]}, pa.ArrowInvalid), + ], + ) + def test_mixed_types_fails(self, data, error_type): + df = pd.DataFrame(data) + msg = "Conversion failed for column a with type object" + with pytest.raises(error_type, match=msg): + pa.Table.from_pandas(df) + + def test_strided_data_import(self): + cases = [] + + columns = ['a', 'b', 'c'] + N, K = 100, 3 + random_numbers = np.random.randn(N, K).copy() * 100 + + numeric_dtypes = ['i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8', + 'f4', 'f8'] + + for type_name in numeric_dtypes: + cases.append(random_numbers.astype(type_name)) + + # strings + cases.append(np.array([random_ascii(10) for i in range(N * K)], + dtype=object) + .reshape(N, K).copy()) + + # booleans + boolean_objects = (np.array([True, False, True] * N, dtype=object) + .reshape(N, K).copy()) + + # add some nulls, so dtype comes back as objects + boolean_objects[5] = None + cases.append(boolean_objects) + + cases.append(np.arange("2016-01-01T00:00:00.001", N * K, + dtype='datetime64[ms]') + .reshape(N, K).copy()) + + strided_mask = (random_numbers > 0).astype(bool)[:, 0] + + for case in cases: + df = pd.DataFrame(case, columns=columns) + col = df['a'] + + _check_pandas_roundtrip(df) + _check_array_roundtrip(col) + _check_array_roundtrip(col, mask=strided_mask) + + def test_all_nones(self): + def _check_series(s): + converted = pa.array(s) + assert isinstance(converted, pa.NullArray) + assert len(converted) == 3 + assert converted.null_count == 3 + for item in converted: + assert item is pa.NA + + _check_series(pd.Series([None] * 3, dtype=object)) + _check_series(pd.Series([np.nan] * 3, dtype=object)) + _check_series(pd.Series([None, np.nan, None], dtype=object)) + + def test_partial_schema(self): + data = OrderedDict([ + ('a', [0, 1, 2, 3, 4]), + ('b', np.array([-10, -5, 0, 5, 10], dtype=np.int32)), + ('c', [-10, -5, 0, 5, 10]) + ]) + df = pd.DataFrame(data) + + partial_schema = pa.schema([ + pa.field('c', pa.int64()), + pa.field('a', pa.int64()) + ]) + + _check_pandas_roundtrip(df, schema=partial_schema, + expected=df[['c', 'a']], + expected_schema=partial_schema) + + def test_table_batch_empty_dataframe(self): + df = pd.DataFrame({}) + _check_pandas_roundtrip(df) + _check_pandas_roundtrip(df, as_batch=True) + + df2 = pd.DataFrame({}, index=[0, 1, 2]) + _check_pandas_roundtrip(df2, preserve_index=True) + _check_pandas_roundtrip(df2, as_batch=True, preserve_index=True) + + def test_convert_empty_table(self): + arr = pa.array([], type=pa.int64()) + empty_objects = pd.Series(np.array([], dtype=object)) + tm.assert_series_equal(arr.to_pandas(), + pd.Series(np.array([], dtype=np.int64))) + arr = pa.array([], type=pa.string()) + tm.assert_series_equal(arr.to_pandas(), empty_objects) + arr = pa.array([], type=pa.list_(pa.int64())) + tm.assert_series_equal(arr.to_pandas(), empty_objects) + arr = pa.array([], type=pa.struct([pa.field('a', pa.int64())])) + tm.assert_series_equal(arr.to_pandas(), empty_objects) + + def test_non_natural_stride(self): + """ + ARROW-2172: converting from a Numpy array with a stride that's + not a multiple of itemsize. + """ + dtype = np.dtype([('x', np.int32), ('y', np.int16)]) + data = np.array([(42, -1), (-43, 2)], dtype=dtype) + assert data.strides == (6,) + arr = pa.array(data['x'], type=pa.int32()) + assert arr.to_pylist() == [42, -43] + arr = pa.array(data['y'], type=pa.int16()) + assert arr.to_pylist() == [-1, 2] + + def test_array_from_strided_numpy_array(self): + # ARROW-5651 + np_arr = np.arange(0, 10, dtype=np.float32)[1:-1:2] + pa_arr = pa.array(np_arr, type=pa.float64()) + expected = pa.array([1.0, 3.0, 5.0, 7.0], type=pa.float64()) + pa_arr.equals(expected) + + def test_safe_unsafe_casts(self): + # ARROW-2799 + df = pd.DataFrame({ + 'A': list('abc'), + 'B': np.linspace(0, 1, 3) + }) + + schema = pa.schema([ + pa.field('A', pa.string()), + pa.field('B', pa.int32()) + ]) + + with pytest.raises(ValueError): + pa.Table.from_pandas(df, schema=schema) + + table = pa.Table.from_pandas(df, schema=schema, safe=False) + assert table.column('B').type == pa.int32() + + def test_error_sparse(self): + # ARROW-2818 + try: + df = pd.DataFrame({'a': pd.arrays.SparseArray([1, np.nan, 3])}) + except AttributeError: + # pandas.arrays module introduced in pandas 0.24 + df = pd.DataFrame({'a': pd.SparseArray([1, np.nan, 3])}) + with pytest.raises(TypeError, match="Sparse pandas data"): + pa.Table.from_pandas(df) + + +def test_safe_cast_from_float_with_nans_to_int(): + # TODO(kszucs): write tests for creating Date32 and Date64 arrays, see + # ARROW-4258 and https://github.com/apache/arrow/pull/3395 + values = pd.Series([1, 2, None, 4]) + arr = pa.Array.from_pandas(values, type=pa.int32(), safe=True) + expected = pa.array([1, 2, None, 4], type=pa.int32()) + assert arr.equals(expected) + + +def _fully_loaded_dataframe_example(): + index = pd.MultiIndex.from_arrays([ + pd.date_range('2000-01-01', periods=5).repeat(2), + np.tile(np.array(['foo', 'bar'], dtype=object), 5) + ]) + + c1 = pd.date_range('2000-01-01', periods=10) + data = { + 0: c1, + 1: c1.tz_localize('utc'), + 2: c1.tz_localize('US/Eastern'), + 3: c1[::2].tz_localize('utc').repeat(2).astype('category'), + 4: ['foo', 'bar'] * 5, + 5: pd.Series(['foo', 'bar'] * 5).astype('category').values, + 6: [True, False] * 5, + 7: np.random.randn(10), + 8: np.random.randint(0, 100, size=10), + 9: pd.period_range('2013', periods=10, freq='M') + } + + if Version(pd.__version__) >= Version('0.21'): + # There is an issue with pickling IntervalIndex in pandas 0.20.x + data[10] = pd.interval_range(start=1, freq=1, periods=10) + + return pd.DataFrame(data, index=index) + + +@pytest.mark.parametrize('columns', ([b'foo'], ['foo'])) +def test_roundtrip_with_bytes_unicode(columns): + df = pd.DataFrame(columns=columns) + table1 = pa.Table.from_pandas(df) + table2 = pa.Table.from_pandas(table1.to_pandas()) + assert table1.equals(table2) + assert table1.schema.equals(table2.schema) + assert table1.schema.metadata == table2.schema.metadata + + +def _check_serialize_components_roundtrip(pd_obj): + with pytest.warns(FutureWarning): + ctx = pa.default_serialization_context() + + with pytest.warns(FutureWarning): + components = ctx.serialize(pd_obj).to_components() + with pytest.warns(FutureWarning): + deserialized = ctx.deserialize_components(components) + + if isinstance(pd_obj, pd.DataFrame): + tm.assert_frame_equal(pd_obj, deserialized) + else: + tm.assert_series_equal(pd_obj, deserialized) + + +@pytest.mark.skipif( + Version('1.16.0') <= Version(np.__version__) < Version('1.16.1'), + reason='Until numpy/numpy#12745 is resolved') +def test_serialize_deserialize_pandas(): + # ARROW-1784, serialize and deserialize DataFrame by decomposing + # BlockManager + df = _fully_loaded_dataframe_example() + _check_serialize_components_roundtrip(df) + + +def test_serialize_deserialize_empty_pandas(): + # ARROW-7996, serialize and deserialize empty pandas objects + df = pd.DataFrame({'col1': [], 'col2': [], 'col3': []}) + _check_serialize_components_roundtrip(df) + + series = pd.Series([], dtype=np.float32, name='col') + _check_serialize_components_roundtrip(series) + + +def _pytime_from_micros(val): + microseconds = val % 1000000 + val //= 1000000 + seconds = val % 60 + val //= 60 + minutes = val % 60 + hours = val // 60 + return time(hours, minutes, seconds, microseconds) + + +def _pytime_to_micros(pytime): + return (pytime.hour * 3600000000 + + pytime.minute * 60000000 + + pytime.second * 1000000 + + pytime.microsecond) + + +def test_convert_unsupported_type_error_message(): + # ARROW-1454 + + # custom python objects + class A: + pass + + df = pd.DataFrame({'a': [A(), A()]}) + + msg = 'Conversion failed for column a with type object' + with pytest.raises(ValueError, match=msg): + pa.Table.from_pandas(df) + + # period unsupported for pandas <= 0.25 + if Version(pd.__version__) <= Version('0.25'): + df = pd.DataFrame({ + 'a': pd.period_range('2000-01-01', periods=20), + }) + + msg = 'Conversion failed for column a with type (period|object)' + with pytest.raises((TypeError, ValueError), match=msg): + pa.Table.from_pandas(df) + + +# ---------------------------------------------------------------------- +# Hypothesis tests + + +@h.given(past.arrays(past.pandas_compatible_types)) +def test_array_to_pandas_roundtrip(arr): + s = arr.to_pandas() + restored = pa.array(s, type=arr.type, from_pandas=True) + assert restored.equals(arr) + + +# ---------------------------------------------------------------------- +# Test object deduplication in to_pandas + + +def _generate_dedup_example(nunique, repeats): + unique_values = [rands(10) for i in range(nunique)] + return unique_values * repeats + + +def _assert_nunique(obj, expected): + assert len({id(x) for x in obj}) == expected + + +def test_to_pandas_deduplicate_strings_array_types(): + nunique = 100 + repeats = 10 + values = _generate_dedup_example(nunique, repeats) + + for arr in [pa.array(values, type=pa.binary()), + pa.array(values, type=pa.utf8()), + pa.chunked_array([values, values])]: + _assert_nunique(arr.to_pandas(), nunique) + _assert_nunique(arr.to_pandas(deduplicate_objects=False), len(arr)) + + +def test_to_pandas_deduplicate_strings_table_types(): + nunique = 100 + repeats = 10 + values = _generate_dedup_example(nunique, repeats) + + arr = pa.array(values) + rb = pa.RecordBatch.from_arrays([arr], ['foo']) + tbl = pa.Table.from_batches([rb]) + + for obj in [rb, tbl]: + _assert_nunique(obj.to_pandas()['foo'], nunique) + _assert_nunique(obj.to_pandas(deduplicate_objects=False)['foo'], + len(obj)) + + +def test_to_pandas_deduplicate_integers_as_objects(): + nunique = 100 + repeats = 10 + + # Python automatically interns smaller integers + unique_values = list(np.random.randint(10000000, 1000000000, size=nunique)) + unique_values[nunique // 2] = None + + arr = pa.array(unique_values * repeats) + + _assert_nunique(arr.to_pandas(integer_object_nulls=True), nunique) + _assert_nunique(arr.to_pandas(integer_object_nulls=True, + deduplicate_objects=False), + # Account for None + (nunique - 1) * repeats + 1) + + +def test_to_pandas_deduplicate_date_time(): + nunique = 100 + repeats = 10 + + unique_values = list(range(nunique)) + + cases = [ + # raw type, array type, to_pandas options + ('int32', 'date32', {'date_as_object': True}), + ('int64', 'date64', {'date_as_object': True}), + ('int32', 'time32[ms]', {}), + ('int64', 'time64[us]', {}) + ] + + for raw_type, array_type, pandas_options in cases: + raw_arr = pa.array(unique_values * repeats, type=raw_type) + casted_arr = raw_arr.cast(array_type) + + _assert_nunique(casted_arr.to_pandas(**pandas_options), + nunique) + _assert_nunique(casted_arr.to_pandas(deduplicate_objects=False, + **pandas_options), + len(casted_arr)) + + +# --------------------------------------------------------------------- + +def test_table_from_pandas_checks_field_nullability(): + # ARROW-2136 + df = pd.DataFrame({'a': [1.2, 2.1, 3.1], + 'b': [np.nan, 'string', 'foo']}) + schema = pa.schema([pa.field('a', pa.float64(), nullable=False), + pa.field('b', pa.utf8(), nullable=False)]) + + with pytest.raises(ValueError): + pa.Table.from_pandas(df, schema=schema) + + +def test_table_from_pandas_keeps_column_order_of_dataframe(): + df1 = pd.DataFrame(OrderedDict([ + ('partition', [0, 0, 1, 1]), + ('arrays', [[0, 1, 2], [3, 4], None, None]), + ('floats', [None, None, 1.1, 3.3]) + ])) + df2 = df1[['floats', 'partition', 'arrays']] + + schema1 = pa.schema([ + ('partition', pa.int64()), + ('arrays', pa.list_(pa.int64())), + ('floats', pa.float64()), + ]) + schema2 = pa.schema([ + ('floats', pa.float64()), + ('partition', pa.int64()), + ('arrays', pa.list_(pa.int64())) + ]) + + table1 = pa.Table.from_pandas(df1, preserve_index=False) + table2 = pa.Table.from_pandas(df2, preserve_index=False) + + assert table1.schema.equals(schema1) + assert table2.schema.equals(schema2) + + +def test_table_from_pandas_keeps_column_order_of_schema(): + # ARROW-3766 + df = pd.DataFrame(OrderedDict([ + ('partition', [0, 0, 1, 1]), + ('arrays', [[0, 1, 2], [3, 4], None, None]), + ('floats', [None, None, 1.1, 3.3]) + ])) + + schema = pa.schema([ + ('floats', pa.float64()), + ('arrays', pa.list_(pa.int32())), + ('partition', pa.int32()) + ]) + + df1 = df[df.partition == 0] + df2 = df[df.partition == 1][['floats', 'partition', 'arrays']] + + table1 = pa.Table.from_pandas(df1, schema=schema, preserve_index=False) + table2 = pa.Table.from_pandas(df2, schema=schema, preserve_index=False) + + assert table1.schema.equals(schema) + assert table1.schema.equals(table2.schema) + + +def test_table_from_pandas_columns_argument_only_does_filtering(): + df = pd.DataFrame(OrderedDict([ + ('partition', [0, 0, 1, 1]), + ('arrays', [[0, 1, 2], [3, 4], None, None]), + ('floats', [None, None, 1.1, 3.3]) + ])) + + columns1 = ['arrays', 'floats', 'partition'] + schema1 = pa.schema([ + ('arrays', pa.list_(pa.int64())), + ('floats', pa.float64()), + ('partition', pa.int64()) + ]) + + columns2 = ['floats', 'partition'] + schema2 = pa.schema([ + ('floats', pa.float64()), + ('partition', pa.int64()) + ]) + + table1 = pa.Table.from_pandas(df, columns=columns1, preserve_index=False) + table2 = pa.Table.from_pandas(df, columns=columns2, preserve_index=False) + + assert table1.schema.equals(schema1) + assert table2.schema.equals(schema2) + + +def test_table_from_pandas_columns_and_schema_are_mutually_exclusive(): + df = pd.DataFrame(OrderedDict([ + ('partition', [0, 0, 1, 1]), + ('arrays', [[0, 1, 2], [3, 4], None, None]), + ('floats', [None, None, 1.1, 3.3]) + ])) + schema = pa.schema([ + ('partition', pa.int32()), + ('arrays', pa.list_(pa.int32())), + ('floats', pa.float64()), + ]) + columns = ['arrays', 'floats'] + + with pytest.raises(ValueError): + pa.Table.from_pandas(df, schema=schema, columns=columns) + + +def test_table_from_pandas_keeps_schema_nullability(): + # ARROW-5169 + df = pd.DataFrame({'a': [1, 2, 3, 4]}) + + schema = pa.schema([ + pa.field('a', pa.int64(), nullable=False), + ]) + + table = pa.Table.from_pandas(df) + assert table.schema.field('a').nullable is True + table = pa.Table.from_pandas(df, schema=schema) + assert table.schema.field('a').nullable is False + + +def test_table_from_pandas_schema_index_columns(): + # ARROW-5220 + df = pd.DataFrame({'a': [1, 2, 3], 'b': [0.1, 0.2, 0.3]}) + + schema = pa.schema([ + ('a', pa.int64()), + ('b', pa.float64()), + ('index', pa.int32()), + ]) + + # schema includes index with name not in dataframe + with pytest.raises(KeyError, match="name 'index' present in the"): + pa.Table.from_pandas(df, schema=schema) + + df.index.name = 'index' + + # schema includes correct index name -> roundtrip works + _check_pandas_roundtrip(df, schema=schema, preserve_index=True, + expected_schema=schema) + + # schema includes correct index name but preserve_index=False + with pytest.raises(ValueError, match="'preserve_index=False' was"): + pa.Table.from_pandas(df, schema=schema, preserve_index=False) + + # in case of preserve_index=None -> RangeIndex serialized as metadata + # clashes with the index in the schema + with pytest.raises(ValueError, match="name 'index' is present in the " + "schema, but it is a RangeIndex"): + pa.Table.from_pandas(df, schema=schema, preserve_index=None) + + df.index = pd.Index([0, 1, 2], name='index') + + # for non-RangeIndex, both preserve_index=None and True work + _check_pandas_roundtrip(df, schema=schema, preserve_index=None, + expected_schema=schema) + _check_pandas_roundtrip(df, schema=schema, preserve_index=True, + expected_schema=schema) + + # schema has different order (index column not at the end) + schema = pa.schema([ + ('index', pa.int32()), + ('a', pa.int64()), + ('b', pa.float64()), + ]) + _check_pandas_roundtrip(df, schema=schema, preserve_index=None, + expected_schema=schema) + _check_pandas_roundtrip(df, schema=schema, preserve_index=True, + expected_schema=schema) + + # schema does not include the index -> index is not included as column + # even though preserve_index=True/None + schema = pa.schema([ + ('a', pa.int64()), + ('b', pa.float64()), + ]) + expected = df.copy() + expected = expected.reset_index(drop=True) + _check_pandas_roundtrip(df, schema=schema, preserve_index=None, + expected_schema=schema, expected=expected) + _check_pandas_roundtrip(df, schema=schema, preserve_index=True, + expected_schema=schema, expected=expected) + + # dataframe with a MultiIndex + df.index = pd.MultiIndex.from_tuples([('a', 1), ('a', 2), ('b', 1)], + names=['level1', 'level2']) + schema = pa.schema([ + ('level1', pa.string()), + ('level2', pa.int64()), + ('a', pa.int64()), + ('b', pa.float64()), + ]) + _check_pandas_roundtrip(df, schema=schema, preserve_index=True, + expected_schema=schema) + _check_pandas_roundtrip(df, schema=schema, preserve_index=None, + expected_schema=schema) + + # only one of the levels of the MultiIndex is included + schema = pa.schema([ + ('level2', pa.int64()), + ('a', pa.int64()), + ('b', pa.float64()), + ]) + expected = df.copy() + expected = expected.reset_index('level1', drop=True) + _check_pandas_roundtrip(df, schema=schema, preserve_index=True, + expected_schema=schema, expected=expected) + _check_pandas_roundtrip(df, schema=schema, preserve_index=None, + expected_schema=schema, expected=expected) + + +def test_table_from_pandas_schema_index_columns__unnamed_index(): + # ARROW-6999 - unnamed indices in specified schema + df = pd.DataFrame({'a': [1, 2, 3], 'b': [0.1, 0.2, 0.3]}) + + expected_schema = pa.schema([ + ('a', pa.int64()), + ('b', pa.float64()), + ('__index_level_0__', pa.int64()), + ]) + + schema = pa.Schema.from_pandas(df, preserve_index=True) + table = pa.Table.from_pandas(df, preserve_index=True, schema=schema) + assert table.schema.remove_metadata().equals(expected_schema) + + # non-RangeIndex (preserved by default) + df = pd.DataFrame({'a': [1, 2, 3], 'b': [0.1, 0.2, 0.3]}, index=[0, 1, 2]) + schema = pa.Schema.from_pandas(df) + table = pa.Table.from_pandas(df, schema=schema) + assert table.schema.remove_metadata().equals(expected_schema) + + +def test_table_from_pandas_schema_with_custom_metadata(): + # ARROW-7087 - metadata disappear from pandas + df = pd.DataFrame() + schema = pa.Schema.from_pandas(df).with_metadata({'meta': 'True'}) + table = pa.Table.from_pandas(df, schema=schema) + assert table.schema.metadata.get(b'meta') == b'True' + + +def test_table_from_pandas_schema_field_order_metadat(): + # ARROW-10532 + # ensure that a different field order in specified schema doesn't + # mangle metadata + df = pd.DataFrame({ + "datetime": pd.date_range("2020-01-01T00:00:00Z", freq="H", periods=2), + "float": np.random.randn(2) + }) + + schema = pa.schema([ + pa.field("float", pa.float32(), nullable=True), + pa.field("datetime", pa.timestamp("s", tz="UTC"), nullable=False) + ]) + + table = pa.Table.from_pandas(df, schema=schema) + assert table.schema.equals(schema) + metadata_float = table.schema.pandas_metadata["columns"][0] + assert metadata_float["name"] == "float" + assert metadata_float["metadata"] is None + metadata_datetime = table.schema.pandas_metadata["columns"][1] + assert metadata_datetime["name"] == "datetime" + assert metadata_datetime["metadata"] == {'timezone': 'UTC'} + + result = table.to_pandas() + expected = df[["float", "datetime"]].astype({"float": "float32"}) + tm.assert_frame_equal(result, expected) + + +# ---------------------------------------------------------------------- +# RecordBatch, Table + + +def test_recordbatch_from_to_pandas(): + data = pd.DataFrame({ + 'c1': np.array([1, 2, 3, 4, 5], dtype='int64'), + 'c2': np.array([1, 2, 3, 4, 5], dtype='uint32'), + 'c3': np.random.randn(5), + 'c4': ['foo', 'bar', None, 'baz', 'qux'], + 'c5': [False, True, False, True, False] + }) + + batch = pa.RecordBatch.from_pandas(data) + result = batch.to_pandas() + tm.assert_frame_equal(data, result) + + +def test_recordbatchlist_to_pandas(): + data1 = pd.DataFrame({ + 'c1': np.array([1, 1, 2], dtype='uint32'), + 'c2': np.array([1.0, 2.0, 3.0], dtype='float64'), + 'c3': [True, None, False], + 'c4': ['foo', 'bar', None] + }) + + data2 = pd.DataFrame({ + 'c1': np.array([3, 5], dtype='uint32'), + 'c2': np.array([4.0, 5.0], dtype='float64'), + 'c3': [True, True], + 'c4': ['baz', 'qux'] + }) + + batch1 = pa.RecordBatch.from_pandas(data1) + batch2 = pa.RecordBatch.from_pandas(data2) + + table = pa.Table.from_batches([batch1, batch2]) + result = table.to_pandas() + data = pd.concat([data1, data2]).reset_index(drop=True) + tm.assert_frame_equal(data, result) + + +def test_recordbatch_table_pass_name_to_pandas(): + rb = pa.record_batch([pa.array([1, 2, 3, 4])], names=['a0']) + t = pa.table([pa.array([1, 2, 3, 4])], names=['a0']) + assert rb[0].to_pandas().name == 'a0' + assert t[0].to_pandas().name == 'a0' + + +# ---------------------------------------------------------------------- +# Metadata serialization + + +@pytest.mark.parametrize( + ('type', 'expected'), + [ + (pa.null(), 'empty'), + (pa.bool_(), 'bool'), + (pa.int8(), 'int8'), + (pa.int16(), 'int16'), + (pa.int32(), 'int32'), + (pa.int64(), 'int64'), + (pa.uint8(), 'uint8'), + (pa.uint16(), 'uint16'), + (pa.uint32(), 'uint32'), + (pa.uint64(), 'uint64'), + (pa.float16(), 'float16'), + (pa.float32(), 'float32'), + (pa.float64(), 'float64'), + (pa.date32(), 'date'), + (pa.date64(), 'date'), + (pa.binary(), 'bytes'), + (pa.binary(length=4), 'bytes'), + (pa.string(), 'unicode'), + (pa.list_(pa.list_(pa.int16())), 'list[list[int16]]'), + (pa.decimal128(18, 3), 'decimal'), + (pa.timestamp('ms'), 'datetime'), + (pa.timestamp('us', 'UTC'), 'datetimetz'), + (pa.time32('s'), 'time'), + (pa.time64('us'), 'time') + ] +) +def test_logical_type(type, expected): + assert get_logical_type(type) == expected + + +# ---------------------------------------------------------------------- +# to_pandas uses MemoryPool + +def test_array_uses_memory_pool(): + # ARROW-6570 + N = 10000 + arr = pa.array(np.arange(N, dtype=np.int64), + mask=np.random.randint(0, 2, size=N).astype(np.bool_)) + + # In the case the gc is caught loafing + gc.collect() + + prior_allocation = pa.total_allocated_bytes() + + x = arr.to_pandas() + assert pa.total_allocated_bytes() == (prior_allocation + N * 8) + x = None # noqa + gc.collect() + + assert pa.total_allocated_bytes() == prior_allocation + + # zero copy does not allocate memory + arr = pa.array(np.arange(N, dtype=np.int64)) + + prior_allocation = pa.total_allocated_bytes() + x = arr.to_pandas() # noqa + assert pa.total_allocated_bytes() == prior_allocation + + +def test_singleton_blocks_zero_copy(): + # Part of ARROW-3789 + t = pa.table([pa.array(np.arange(1000, dtype=np.int64))], ['f0']) + + # Zero copy if split_blocks=True + _check_to_pandas_memory_unchanged(t, split_blocks=True) + + prior_allocation = pa.total_allocated_bytes() + result = t.to_pandas() + assert result['f0'].values.flags.writeable + assert pa.total_allocated_bytes() > prior_allocation + + +def _check_to_pandas_memory_unchanged(obj, **kwargs): + prior_allocation = pa.total_allocated_bytes() + x = obj.to_pandas(**kwargs) # noqa + + # Memory allocation unchanged -- either zero copy or self-destructing + assert pa.total_allocated_bytes() == prior_allocation + + +def test_to_pandas_split_blocks(): + # ARROW-3789 + t = pa.table([ + pa.array([1, 2, 3, 4, 5], type='i1'), + pa.array([1, 2, 3, 4, 5], type='i4'), + pa.array([1, 2, 3, 4, 5], type='i8'), + pa.array([1, 2, 3, 4, 5], type='f4'), + pa.array([1, 2, 3, 4, 5], type='f8'), + pa.array([1, 2, 3, 4, 5], type='f8'), + pa.array([1, 2, 3, 4, 5], type='f8'), + pa.array([1, 2, 3, 4, 5], type='f8'), + ], ['f{}'.format(i) for i in range(8)]) + + _check_blocks_created(t, 8) + _check_to_pandas_memory_unchanged(t, split_blocks=True) + + +def _check_blocks_created(t, number): + x = t.to_pandas(split_blocks=True) + assert len(x._data.blocks) == number + + +def test_to_pandas_self_destruct(): + K = 50 + + def _make_table(): + return pa.table([ + # Slice to force a copy + pa.array(np.random.randn(10000)[::2]) + for i in range(K) + ], ['f{}'.format(i) for i in range(K)]) + + t = _make_table() + _check_to_pandas_memory_unchanged(t, split_blocks=True, self_destruct=True) + + # Check non-split-block behavior + t = _make_table() + _check_to_pandas_memory_unchanged(t, self_destruct=True) + + +def test_table_uses_memory_pool(): + N = 10000 + arr = pa.array(np.arange(N, dtype=np.int64)) + t = pa.table([arr, arr, arr], ['f0', 'f1', 'f2']) + + prior_allocation = pa.total_allocated_bytes() + x = t.to_pandas() + + assert pa.total_allocated_bytes() == (prior_allocation + 3 * N * 8) + + # Check successful garbage collection + x = None # noqa + gc.collect() + assert pa.total_allocated_bytes() == prior_allocation + + +def test_object_leak_in_numpy_array(): + # ARROW-6876 + arr = pa.array([{'a': 1}]) + np_arr = arr.to_pandas() + assert np_arr.dtype == np.dtype('object') + obj = np_arr[0] + refcount = sys.getrefcount(obj) + assert sys.getrefcount(obj) == refcount + del np_arr + assert sys.getrefcount(obj) == refcount - 1 + + +def test_object_leak_in_dataframe(): + # ARROW-6876 + arr = pa.array([{'a': 1}]) + table = pa.table([arr], ['f0']) + col = table.to_pandas()['f0'] + assert col.dtype == np.dtype('object') + obj = col[0] + refcount = sys.getrefcount(obj) + assert sys.getrefcount(obj) == refcount + del col + assert sys.getrefcount(obj) == refcount - 1 + + +# ---------------------------------------------------------------------- +# Some nested array tests array tests + + +def test_array_from_py_float32(): + data = [[1.2, 3.4], [9.0, 42.0]] + + t = pa.float32() + + arr1 = pa.array(data[0], type=t) + arr2 = pa.array(data, type=pa.list_(t)) + + expected1 = np.array(data[0], dtype=np.float32) + expected2 = pd.Series([np.array(data[0], dtype=np.float32), + np.array(data[1], dtype=np.float32)]) + + assert arr1.type == t + assert arr1.equals(pa.array(expected1)) + assert arr2.equals(pa.array(expected2)) + + +# ---------------------------------------------------------------------- +# Timestamp tests + + +def test_cast_timestamp_unit(): + # ARROW-1680 + val = datetime.now() + s = pd.Series([val]) + s_nyc = s.dt.tz_localize('tzlocal()').dt.tz_convert('America/New_York') + + us_with_tz = pa.timestamp('us', tz='America/New_York') + + arr = pa.Array.from_pandas(s_nyc, type=us_with_tz) + + # ARROW-1906 + assert arr.type == us_with_tz + + arr2 = pa.Array.from_pandas(s, type=pa.timestamp('us')) + + assert arr[0].as_py() == s_nyc[0].to_pydatetime() + assert arr2[0].as_py() == s[0].to_pydatetime() + + # Disallow truncation + arr = pa.array([123123], type='int64').cast(pa.timestamp('ms')) + expected = pa.array([123], type='int64').cast(pa.timestamp('s')) + + # sanity check that the cast worked right + assert arr.type == pa.timestamp('ms') + + target = pa.timestamp('s') + with pytest.raises(ValueError): + arr.cast(target) + + result = arr.cast(target, safe=False) + assert result.equals(expected) + + # ARROW-1949 + series = pd.Series([pd.Timestamp(1), pd.Timestamp(10), pd.Timestamp(1000)]) + expected = pa.array([0, 0, 1], type=pa.timestamp('us')) + + with pytest.raises(ValueError): + pa.array(series, type=pa.timestamp('us')) + + with pytest.raises(ValueError): + pa.Array.from_pandas(series, type=pa.timestamp('us')) + + result = pa.Array.from_pandas(series, type=pa.timestamp('us'), safe=False) + assert result.equals(expected) + + result = pa.array(series, type=pa.timestamp('us'), safe=False) + assert result.equals(expected) + + +def test_nested_with_timestamp_tz_round_trip(): + ts = pd.Timestamp.now() + ts_dt = ts.to_pydatetime() + arr = pa.array([ts_dt], type=pa.timestamp('us', tz='America/New_York')) + struct = pa.StructArray.from_arrays([arr, arr], ['start', 'stop']) + + result = struct.to_pandas() + restored = pa.array(result) + assert restored.equals(struct) + + +def test_nested_with_timestamp_tz(): + # ARROW-7723 + ts = pd.Timestamp.now() + ts_dt = ts.to_pydatetime() + + # XXX: Ensure that this data does not get promoted to nanoseconds (and thus + # integers) to preserve behavior in 0.15.1 + for unit in ['s', 'ms', 'us']: + if unit in ['s', 'ms']: + # This is used for verifying timezone conversion to micros are not + # important + def truncate(x): return x.replace(microsecond=0) + else: + def truncate(x): return x + arr = pa.array([ts], type=pa.timestamp(unit)) + arr2 = pa.array([ts], type=pa.timestamp(unit, tz='America/New_York')) + + arr3 = pa.StructArray.from_arrays([arr, arr], ['start', 'stop']) + arr4 = pa.StructArray.from_arrays([arr2, arr2], ['start', 'stop']) + + result = arr3.to_pandas() + assert isinstance(result[0]['start'], datetime) + assert result[0]['start'].tzinfo is None + assert isinstance(result[0]['stop'], datetime) + assert result[0]['stop'].tzinfo is None + + result = arr4.to_pandas() + assert isinstance(result[0]['start'], datetime) + assert result[0]['start'].tzinfo is not None + utc_dt = result[0]['start'].astimezone(timezone.utc) + assert truncate(utc_dt).replace(tzinfo=None) == truncate(ts_dt) + assert isinstance(result[0]['stop'], datetime) + assert result[0]['stop'].tzinfo is not None + + # same conversion for table + result = pa.table({'a': arr3}).to_pandas() + assert isinstance(result['a'][0]['start'], datetime) + assert result['a'][0]['start'].tzinfo is None + assert isinstance(result['a'][0]['stop'], datetime) + assert result['a'][0]['stop'].tzinfo is None + + result = pa.table({'a': arr4}).to_pandas() + assert isinstance(result['a'][0]['start'], datetime) + assert result['a'][0]['start'].tzinfo is not None + assert isinstance(result['a'][0]['stop'], datetime) + assert result['a'][0]['stop'].tzinfo is not None + + +# ---------------------------------------------------------------------- +# DictionaryArray tests + + +def test_dictionary_with_pandas(): + src_indices = np.repeat([0, 1, 2], 2) + dictionary = np.array(['foo', 'bar', 'baz'], dtype=object) + mask = np.array([False, False, True, False, False, False]) + + for index_type in ['uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', + 'uint64', 'int64']: + indices = src_indices.astype(index_type) + d1 = pa.DictionaryArray.from_arrays(indices, dictionary) + d2 = pa.DictionaryArray.from_arrays(indices, dictionary, mask=mask) + + if index_type[0] == 'u': + # TODO: unsigned dictionary indices to pandas + with pytest.raises(TypeError): + d1.to_pandas() + continue + + pandas1 = d1.to_pandas() + ex_pandas1 = pd.Categorical.from_codes(indices, categories=dictionary) + + tm.assert_series_equal(pd.Series(pandas1), pd.Series(ex_pandas1)) + + pandas2 = d2.to_pandas() + assert pandas2.isnull().sum() == 1 + + # Unsigned integers converted to signed + signed_indices = indices + if index_type[0] == 'u': + signed_indices = indices.astype(index_type[1:]) + ex_pandas2 = pd.Categorical.from_codes(np.where(mask, -1, + signed_indices), + categories=dictionary) + + tm.assert_series_equal(pd.Series(pandas2), pd.Series(ex_pandas2)) + + +def random_strings(n, item_size, pct_null=0, dictionary=None): + if dictionary is not None: + result = dictionary[np.random.randint(0, len(dictionary), size=n)] + else: + result = np.array([random_ascii(item_size) for i in range(n)], + dtype=object) + + if pct_null > 0: + result[np.random.rand(n) < pct_null] = None + + return result + + +def test_variable_dictionary_to_pandas(): + np.random.seed(12345) + + d1 = pa.array(random_strings(100, 32), type='string') + d2 = pa.array(random_strings(100, 16), type='string') + d3 = pa.array(random_strings(10000, 10), type='string') + + a1 = pa.DictionaryArray.from_arrays( + np.random.randint(0, len(d1), size=1000, dtype='i4'), + d1 + ) + a2 = pa.DictionaryArray.from_arrays( + np.random.randint(0, len(d2), size=1000, dtype='i4'), + d2 + ) + + # With some nulls + a3 = pa.DictionaryArray.from_arrays( + np.random.randint(0, len(d3), size=1000, dtype='i4'), d3) + + i4 = pa.array( + np.random.randint(0, len(d3), size=1000, dtype='i4'), + mask=np.random.rand(1000) < 0.1 + ) + a4 = pa.DictionaryArray.from_arrays(i4, d3) + + expected_dict = pa.concat_arrays([d1, d2, d3]) + + a = pa.chunked_array([a1, a2, a3, a4]) + a_dense = pa.chunked_array([a1.cast('string'), + a2.cast('string'), + a3.cast('string'), + a4.cast('string')]) + + result = a.to_pandas() + result_dense = a_dense.to_pandas() + + assert (result.cat.categories == expected_dict.to_pandas()).all() + + expected_dense = result.astype('str') + expected_dense[result_dense.isnull()] = None + tm.assert_series_equal(result_dense, expected_dense) + + +def test_dictionary_encoded_nested_to_pandas(): + # ARROW-6899 + child = pa.array(['a', 'a', 'a', 'b', 'b']).dictionary_encode() + + arr = pa.ListArray.from_arrays([0, 3, 5], child) + + result = arr.to_pandas() + expected = pd.Series([np.array(['a', 'a', 'a'], dtype=object), + np.array(['b', 'b'], dtype=object)]) + + tm.assert_series_equal(result, expected) + + +def test_dictionary_from_pandas(): + cat = pd.Categorical(['a', 'b', 'a']) + expected_type = pa.dictionary(pa.int8(), pa.string()) + + result = pa.array(cat) + assert result.to_pylist() == ['a', 'b', 'a'] + assert result.type.equals(expected_type) + + # with missing values in categorical + cat = pd.Categorical(['a', 'b', None, 'a']) + + result = pa.array(cat) + assert result.to_pylist() == ['a', 'b', None, 'a'] + assert result.type.equals(expected_type) + + # with additional mask + result = pa.array(cat, mask=np.array([False, False, False, True])) + assert result.to_pylist() == ['a', 'b', None, None] + assert result.type.equals(expected_type) + + +def test_dictionary_from_pandas_specified_type(): + # ARROW-7168 - ensure specified type is always respected + + # the same as cat = pd.Categorical(['a', 'b']) but explicit about dtypes + cat = pd.Categorical.from_codes( + np.array([0, 1], dtype='int8'), np.array(['a', 'b'], dtype=object)) + + # different index type -> allow this + # (the type of the 'codes' in pandas is not part of the data type) + typ = pa.dictionary(index_type=pa.int16(), value_type=pa.string()) + result = pa.array(cat, type=typ) + assert result.type.equals(typ) + assert result.to_pylist() == ['a', 'b'] + + # mismatching values type -> raise error + typ = pa.dictionary(index_type=pa.int8(), value_type=pa.int64()) + with pytest.raises(pa.ArrowInvalid): + result = pa.array(cat, type=typ) + + # mismatching order -> raise error (for now a deprecation warning) + typ = pa.dictionary( + index_type=pa.int8(), value_type=pa.string(), ordered=True) + with pytest.warns(FutureWarning, match="The 'ordered' flag of the passed"): + result = pa.array(cat, type=typ) + assert result.to_pylist() == ['a', 'b'] + + # with mask + typ = pa.dictionary(index_type=pa.int16(), value_type=pa.string()) + result = pa.array(cat, type=typ, mask=np.array([False, True])) + assert result.type.equals(typ) + assert result.to_pylist() == ['a', None] + + # empty categorical -> be flexible in values type to allow + cat = pd.Categorical([]) + + typ = pa.dictionary(index_type=pa.int8(), value_type=pa.string()) + result = pa.array(cat, type=typ) + assert result.type.equals(typ) + assert result.to_pylist() == [] + typ = pa.dictionary(index_type=pa.int8(), value_type=pa.int64()) + result = pa.array(cat, type=typ) + assert result.type.equals(typ) + assert result.to_pylist() == [] + + # passing non-dictionary type + cat = pd.Categorical(['a', 'b']) + result = pa.array(cat, type=pa.string()) + expected = pa.array(['a', 'b'], type=pa.string()) + assert result.equals(expected) + assert result.to_pylist() == ['a', 'b'] + + +# ---------------------------------------------------------------------- +# Array protocol in pandas conversions tests + + +def test_array_protocol(): + if Version(pd.__version__) < Version('0.24.0'): + pytest.skip('IntegerArray only introduced in 0.24') + + df = pd.DataFrame({'a': pd.Series([1, 2, None], dtype='Int64')}) + + if Version(pd.__version__) < Version('0.26.0.dev'): + # with pandas<=0.25, trying to convert nullable integer errors + with pytest.raises(TypeError): + pa.table(df) + else: + # __arrow_array__ added to pandas IntegerArray in 0.26.0.dev + + # default conversion + result = pa.table(df) + expected = pa.array([1, 2, None], pa.int64()) + assert result[0].chunk(0).equals(expected) + + # with specifying schema + schema = pa.schema([('a', pa.float64())]) + result = pa.table(df, schema=schema) + expected2 = pa.array([1, 2, None], pa.float64()) + assert result[0].chunk(0).equals(expected2) + + # pass Series to pa.array + result = pa.array(df['a']) + assert result.equals(expected) + result = pa.array(df['a'], type=pa.float64()) + assert result.equals(expected2) + + # pass actual ExtensionArray to pa.array + result = pa.array(df['a'].values) + assert result.equals(expected) + result = pa.array(df['a'].values, type=pa.float64()) + assert result.equals(expected2) + + +class DummyExtensionType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, pa.int64()) + + def __reduce__(self): + return DummyExtensionType, () + + +def PandasArray__arrow_array__(self, type=None): + # hardcode dummy return regardless of self - we only want to check that + # this method is correctly called + storage = pa.array([1, 2, 3], type=pa.int64()) + return pa.ExtensionArray.from_storage(DummyExtensionType(), storage) + + +def test_array_protocol_pandas_extension_types(monkeypatch): + # ARROW-7022 - ensure protocol works for Period / Interval extension dtypes + + if Version(pd.__version__) < Version('0.24.0'): + pytest.skip('Period/IntervalArray only introduced in 0.24') + + storage = pa.array([1, 2, 3], type=pa.int64()) + expected = pa.ExtensionArray.from_storage(DummyExtensionType(), storage) + + monkeypatch.setattr(pd.arrays.PeriodArray, "__arrow_array__", + PandasArray__arrow_array__, raising=False) + monkeypatch.setattr(pd.arrays.IntervalArray, "__arrow_array__", + PandasArray__arrow_array__, raising=False) + for arr in [pd.period_range("2012-01-01", periods=3, freq="D").array, + pd.interval_range(1, 4).array]: + result = pa.array(arr) + assert result.equals(expected) + result = pa.array(pd.Series(arr)) + assert result.equals(expected) + result = pa.array(pd.Index(arr)) + assert result.equals(expected) + result = pa.table(pd.DataFrame({'a': arr})).column('a').chunk(0) + assert result.equals(expected) + + +# ---------------------------------------------------------------------- +# Pandas ExtensionArray support + + +def _Int64Dtype__from_arrow__(self, array): + # for test only deal with single chunk for now + # TODO: do we require handling of chunked arrays in the protocol? + if isinstance(array, pa.Array): + arr = array + else: + # ChunkedArray - here only deal with a single chunk for the test + arr = array.chunk(0) + buflist = arr.buffers() + data = np.frombuffer(buflist[-1], dtype='int64')[ + arr.offset:arr.offset + len(arr)] + bitmask = buflist[0] + if bitmask is not None: + mask = pa.BooleanArray.from_buffers( + pa.bool_(), len(arr), [None, bitmask]) + mask = np.asarray(mask) + else: + mask = np.ones(len(arr), dtype=bool) + int_arr = pd.arrays.IntegerArray(data.copy(), ~mask, copy=False) + return int_arr + + +def test_convert_to_extension_array(monkeypatch): + if Version(pd.__version__) < Version("0.26.0.dev"): + pytest.skip("Conversion from IntegerArray to arrow not yet supported") + + import pandas.core.internals as _int + + # table converted from dataframe with extension types (so pandas_metadata + # has this information) + df = pd.DataFrame( + {'a': [1, 2, 3], 'b': pd.array([2, 3, 4], dtype='Int64'), + 'c': [4, 5, 6]}) + table = pa.table(df) + + # Int64Dtype is recognized -> convert to extension block by default + # for a proper roundtrip + result = table.to_pandas() + assert not isinstance(result._data.blocks[0], _int.ExtensionBlock) + assert result._data.blocks[0].values.dtype == np.dtype("int64") + assert isinstance(result._data.blocks[1], _int.ExtensionBlock) + tm.assert_frame_equal(result, df) + + # test with missing values + df2 = pd.DataFrame({'a': pd.array([1, 2, None], dtype='Int64')}) + table2 = pa.table(df2) + result = table2.to_pandas() + assert isinstance(result._data.blocks[0], _int.ExtensionBlock) + tm.assert_frame_equal(result, df2) + + # monkeypatch pandas Int64Dtype to *not* have the protocol method + if Version(pd.__version__) < Version("1.3.0.dev"): + monkeypatch.delattr( + pd.core.arrays.integer._IntegerDtype, "__from_arrow__") + else: + monkeypatch.delattr( + pd.core.arrays.integer.NumericDtype, "__from_arrow__") + # Int64Dtype has no __from_arrow__ -> use normal conversion + result = table.to_pandas() + assert len(result._data.blocks) == 1 + assert not isinstance(result._data.blocks[0], _int.ExtensionBlock) + + +class MyCustomIntegerType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, pa.int64()) + + def __reduce__(self): + return MyCustomIntegerType, () + + def to_pandas_dtype(self): + return pd.Int64Dtype() + + +def test_conversion_extensiontype_to_extensionarray(monkeypatch): + # converting extension type to linked pandas ExtensionDtype/Array + import pandas.core.internals as _int + + if Version(pd.__version__) < Version("0.24.0"): + pytest.skip("ExtensionDtype introduced in pandas 0.24") + + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(MyCustomIntegerType(), storage) + table = pa.table({'a': arr}) + + if Version(pd.__version__) < Version("0.26.0.dev"): + # ensure pandas Int64Dtype has the protocol method (for older pandas) + monkeypatch.setattr( + pd.Int64Dtype, '__from_arrow__', _Int64Dtype__from_arrow__, + raising=False) + + # extension type points to Int64Dtype, which knows how to create a + # pandas ExtensionArray + result = arr.to_pandas() + assert isinstance(result._data.blocks[0], _int.ExtensionBlock) + expected = pd.Series([1, 2, 3, 4], dtype='Int64') + tm.assert_series_equal(result, expected) + + result = table.to_pandas() + assert isinstance(result._data.blocks[0], _int.ExtensionBlock) + expected = pd.DataFrame({'a': pd.array([1, 2, 3, 4], dtype='Int64')}) + tm.assert_frame_equal(result, expected) + + # monkeypatch pandas Int64Dtype to *not* have the protocol method + # (remove the version added above and the actual version for recent pandas) + if Version(pd.__version__) < Version("0.26.0.dev"): + monkeypatch.delattr(pd.Int64Dtype, "__from_arrow__") + elif Version(pd.__version__) < Version("1.3.0.dev"): + monkeypatch.delattr( + pd.core.arrays.integer._IntegerDtype, "__from_arrow__") + else: + monkeypatch.delattr( + pd.core.arrays.integer.NumericDtype, "__from_arrow__") + + result = arr.to_pandas() + assert not isinstance(result._data.blocks[0], _int.ExtensionBlock) + expected = pd.Series([1, 2, 3, 4]) + tm.assert_series_equal(result, expected) + + with pytest.raises(ValueError): + table.to_pandas() + + +def test_to_pandas_extension_dtypes_mapping(): + if Version(pd.__version__) < Version("0.26.0.dev"): + pytest.skip("Conversion to pandas IntegerArray not yet supported") + + table = pa.table({'a': pa.array([1, 2, 3], pa.int64())}) + + # default use numpy dtype + result = table.to_pandas() + assert result['a'].dtype == np.dtype('int64') + + # specify to override the default + result = table.to_pandas(types_mapper={pa.int64(): pd.Int64Dtype()}.get) + assert isinstance(result['a'].dtype, pd.Int64Dtype) + + # types that return None in function get normal conversion + table = pa.table({'a': pa.array([1, 2, 3], pa.int32())}) + result = table.to_pandas(types_mapper={pa.int64(): pd.Int64Dtype()}.get) + assert result['a'].dtype == np.dtype('int32') + + # `types_mapper` overrules the pandas metadata + table = pa.table(pd.DataFrame({'a': pd.array([1, 2, 3], dtype="Int64")})) + result = table.to_pandas() + assert isinstance(result['a'].dtype, pd.Int64Dtype) + result = table.to_pandas( + types_mapper={pa.int64(): pd.PeriodDtype('D')}.get) + assert isinstance(result['a'].dtype, pd.PeriodDtype) + + +def test_array_to_pandas(): + if Version(pd.__version__) < Version("1.1"): + pytest.skip("ExtensionDtype to_pandas method missing") + + for arr in [pd.period_range("2012-01-01", periods=3, freq="D").array, + pd.interval_range(1, 4).array]: + result = pa.array(arr).to_pandas() + expected = pd.Series(arr) + tm.assert_series_equal(result, expected) + + # TODO implement proper conversion for chunked array + # result = pa.table({"col": arr})["col"].to_pandas() + # expected = pd.Series(arr, name="col") + # tm.assert_series_equal(result, expected) + + +# ---------------------------------------------------------------------- +# Legacy metadata compatibility tests + + +def test_metadata_compat_range_index_pre_0_12(): + # Forward compatibility for metadata created from pandas.RangeIndex + # prior to pyarrow 0.13.0 + a_values = ['foo', 'bar', None, 'baz'] + b_values = ['a', 'a', 'b', 'b'] + a_arrow = pa.array(a_values, type='utf8') + b_arrow = pa.array(b_values, type='utf8') + + rng_index_arrow = pa.array([0, 2, 4, 6], type='int64') + + gen_name_0 = '__index_level_0__' + gen_name_1 = '__index_level_1__' + + # Case 1: named RangeIndex + e1 = pd.DataFrame({ + 'a': a_values + }, index=pd.RangeIndex(0, 8, step=2, name='qux')) + t1 = pa.Table.from_arrays([a_arrow, rng_index_arrow], + names=['a', 'qux']) + t1 = t1.replace_schema_metadata({ + b'pandas': json.dumps( + {'index_columns': ['qux'], + 'column_indexes': [{'name': None, + 'field_name': None, + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': {'encoding': 'UTF-8'}}], + 'columns': [{'name': 'a', + 'field_name': 'a', + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': None}, + {'name': 'qux', + 'field_name': 'qux', + 'pandas_type': 'int64', + 'numpy_type': 'int64', + 'metadata': None}], + 'pandas_version': '0.23.4'} + )}) + r1 = t1.to_pandas() + tm.assert_frame_equal(r1, e1) + + # Case 2: named RangeIndex, but conflicts with an actual column + e2 = pd.DataFrame({ + 'qux': a_values + }, index=pd.RangeIndex(0, 8, step=2, name='qux')) + t2 = pa.Table.from_arrays([a_arrow, rng_index_arrow], + names=['qux', gen_name_0]) + t2 = t2.replace_schema_metadata({ + b'pandas': json.dumps( + {'index_columns': [gen_name_0], + 'column_indexes': [{'name': None, + 'field_name': None, + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': {'encoding': 'UTF-8'}}], + 'columns': [{'name': 'a', + 'field_name': 'a', + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': None}, + {'name': 'qux', + 'field_name': gen_name_0, + 'pandas_type': 'int64', + 'numpy_type': 'int64', + 'metadata': None}], + 'pandas_version': '0.23.4'} + )}) + r2 = t2.to_pandas() + tm.assert_frame_equal(r2, e2) + + # Case 3: unnamed RangeIndex + e3 = pd.DataFrame({ + 'a': a_values + }, index=pd.RangeIndex(0, 8, step=2, name=None)) + t3 = pa.Table.from_arrays([a_arrow, rng_index_arrow], + names=['a', gen_name_0]) + t3 = t3.replace_schema_metadata({ + b'pandas': json.dumps( + {'index_columns': [gen_name_0], + 'column_indexes': [{'name': None, + 'field_name': None, + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': {'encoding': 'UTF-8'}}], + 'columns': [{'name': 'a', + 'field_name': 'a', + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': None}, + {'name': None, + 'field_name': gen_name_0, + 'pandas_type': 'int64', + 'numpy_type': 'int64', + 'metadata': None}], + 'pandas_version': '0.23.4'} + )}) + r3 = t3.to_pandas() + tm.assert_frame_equal(r3, e3) + + # Case 4: MultiIndex with named RangeIndex + e4 = pd.DataFrame({ + 'a': a_values + }, index=[pd.RangeIndex(0, 8, step=2, name='qux'), b_values]) + t4 = pa.Table.from_arrays([a_arrow, rng_index_arrow, b_arrow], + names=['a', 'qux', gen_name_1]) + t4 = t4.replace_schema_metadata({ + b'pandas': json.dumps( + {'index_columns': ['qux', gen_name_1], + 'column_indexes': [{'name': None, + 'field_name': None, + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': {'encoding': 'UTF-8'}}], + 'columns': [{'name': 'a', + 'field_name': 'a', + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': None}, + {'name': 'qux', + 'field_name': 'qux', + 'pandas_type': 'int64', + 'numpy_type': 'int64', + 'metadata': None}, + {'name': None, + 'field_name': gen_name_1, + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': None}], + 'pandas_version': '0.23.4'} + )}) + r4 = t4.to_pandas() + tm.assert_frame_equal(r4, e4) + + # Case 4: MultiIndex with unnamed RangeIndex + e5 = pd.DataFrame({ + 'a': a_values + }, index=[pd.RangeIndex(0, 8, step=2, name=None), b_values]) + t5 = pa.Table.from_arrays([a_arrow, rng_index_arrow, b_arrow], + names=['a', gen_name_0, gen_name_1]) + t5 = t5.replace_schema_metadata({ + b'pandas': json.dumps( + {'index_columns': [gen_name_0, gen_name_1], + 'column_indexes': [{'name': None, + 'field_name': None, + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': {'encoding': 'UTF-8'}}], + 'columns': [{'name': 'a', + 'field_name': 'a', + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': None}, + {'name': None, + 'field_name': gen_name_0, + 'pandas_type': 'int64', + 'numpy_type': 'int64', + 'metadata': None}, + {'name': None, + 'field_name': gen_name_1, + 'pandas_type': 'unicode', + 'numpy_type': 'object', + 'metadata': None}], + 'pandas_version': '0.23.4'} + )}) + r5 = t5.to_pandas() + tm.assert_frame_equal(r5, e5) + + +def test_metadata_compat_missing_field_name(): + # Combination of missing field name but with index column as metadata. + # This combo occurs in the latest versions of fastparquet (0.3.2), but not + # in pyarrow itself (since field_name was added in 0.8, index as metadata + # only added later) + + a_values = [1, 2, 3, 4] + b_values = ['a', 'b', 'c', 'd'] + a_arrow = pa.array(a_values, type='int64') + b_arrow = pa.array(b_values, type='utf8') + + expected = pd.DataFrame({ + 'a': a_values, + 'b': b_values, + }, index=pd.RangeIndex(0, 8, step=2, name='qux')) + table = pa.table({'a': a_arrow, 'b': b_arrow}) + + # metadata generated by fastparquet 0.3.2 with missing field_names + table = table.replace_schema_metadata({ + b'pandas': json.dumps({ + 'column_indexes': [ + {'field_name': None, + 'metadata': None, + 'name': None, + 'numpy_type': 'object', + 'pandas_type': 'mixed-integer'} + ], + 'columns': [ + {'metadata': None, + 'name': 'a', + 'numpy_type': 'int64', + 'pandas_type': 'int64'}, + {'metadata': None, + 'name': 'b', + 'numpy_type': 'object', + 'pandas_type': 'unicode'} + ], + 'index_columns': [ + {'kind': 'range', + 'name': 'qux', + 'start': 0, + 'step': 2, + 'stop': 8} + ], + 'pandas_version': '0.25.0'} + + )}) + result = table.to_pandas() + tm.assert_frame_equal(result, expected) + + +def test_metadata_index_name_not_json_serializable(): + name = np.int64(6) # not json serializable by default + table = pa.table(pd.DataFrame(index=pd.RangeIndex(0, 4, name=name))) + metadata = table.schema.pandas_metadata + assert metadata['index_columns'][0]['name'] == '6' + + +def test_metadata_index_name_is_json_serializable(): + name = 6 # json serializable by default + table = pa.table(pd.DataFrame(index=pd.RangeIndex(0, 4, name=name))) + metadata = table.schema.pandas_metadata + assert metadata['index_columns'][0]['name'] == 6 + + +def make_df_with_timestamps(): + # Some of the milliseconds timestamps deliberately don't fit in the range + # that is possible with nanosecond timestamps. + df = pd.DataFrame({ + 'dateTimeMs': [ + np.datetime64('0001-01-01 00:00', 'ms'), + np.datetime64('2012-05-02 12:35', 'ms'), + np.datetime64('2012-05-03 15:42', 'ms'), + np.datetime64('3000-05-03 15:42', 'ms'), + ], + 'dateTimeNs': [ + np.datetime64('1991-01-01 00:00', 'ns'), + np.datetime64('2012-05-02 12:35', 'ns'), + np.datetime64('2012-05-03 15:42', 'ns'), + np.datetime64('2050-05-03 15:42', 'ns'), + ], + }) + # Not part of what we're testing, just ensuring that the inputs are what we + # expect. + assert (df.dateTimeMs.dtype, df.dateTimeNs.dtype) == ( + # O == object, <M8[ns] == timestamp64[ns] + np.dtype("O"), np.dtype("<M8[ns]") + ) + return df + + +@pytest.mark.parquet +def test_timestamp_as_object_parquet(tempdir): + # Timestamps can be stored as Parquet and reloaded into Pandas with no loss + # of information if the timestamp_as_object option is True. + df = make_df_with_timestamps() + table = pa.Table.from_pandas(df) + filename = tempdir / "timestamps_from_pandas.parquet" + pq.write_table(table, filename, version="2.0") + result = pq.read_table(filename) + df2 = result.to_pandas(timestamp_as_object=True) + tm.assert_frame_equal(df, df2) + + +def test_timestamp_as_object_out_of_range(): + # Out of range timestamps can be converted Arrow and reloaded into Pandas + # with no loss of information if the timestamp_as_object option is True. + df = make_df_with_timestamps() + table = pa.Table.from_pandas(df) + df2 = table.to_pandas(timestamp_as_object=True) + tm.assert_frame_equal(df, df2) + + +@pytest.mark.parametrize("resolution", ["s", "ms", "us"]) +@pytest.mark.parametrize("tz", [None, "America/New_York"]) +# One datetime outside nanosecond range, one inside nanosecond range: +@pytest.mark.parametrize("dt", [datetime(1553, 1, 1), datetime(2020, 1, 1)]) +def test_timestamp_as_object_non_nanosecond(resolution, tz, dt): + # Timestamps can be converted Arrow and reloaded into Pandas with no loss + # of information if the timestamp_as_object option is True. + arr = pa.array([dt], type=pa.timestamp(resolution, tz=tz)) + table = pa.table({'a': arr}) + + for result in [ + arr.to_pandas(timestamp_as_object=True), + table.to_pandas(timestamp_as_object=True)['a'] + ]: + assert result.dtype == object + assert isinstance(result[0], datetime) + if tz: + assert result[0].tzinfo is not None + expected = result[0].tzinfo.fromutc(dt) + else: + assert result[0].tzinfo is None + expected = dt + assert result[0] == expected + + +def test_threaded_pandas_import(): + invoke_script("pandas_threaded_import.py") diff --git a/src/arrow/python/pyarrow/tests/test_plasma.py b/src/arrow/python/pyarrow/tests/test_plasma.py new file mode 100644 index 000000000..ed08a6872 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_plasma.py @@ -0,0 +1,1073 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import multiprocessing +import os +import pytest +import random +import signal +import struct +import subprocess +import sys +import time + +import numpy as np +import pyarrow as pa + + +DEFAULT_PLASMA_STORE_MEMORY = 10 ** 8 +USE_VALGRIND = os.getenv("PLASMA_VALGRIND") == "1" +EXTERNAL_STORE = "hashtable://test" +SMALL_OBJECT_SIZE = 9000 + + +def random_name(): + return str(random.randint(0, 99999999)) + + +def random_object_id(): + import pyarrow.plasma as plasma + return plasma.ObjectID(np.random.bytes(20)) + + +def generate_metadata(length): + metadata = bytearray(length) + if length > 0: + metadata[0] = random.randint(0, 255) + metadata[-1] = random.randint(0, 255) + for _ in range(100): + metadata[random.randint(0, length - 1)] = random.randint(0, 255) + return metadata + + +def write_to_data_buffer(buff, length): + array = np.frombuffer(buff, dtype="uint8") + if length > 0: + array[0] = random.randint(0, 255) + array[-1] = random.randint(0, 255) + for _ in range(100): + array[random.randint(0, length - 1)] = random.randint(0, 255) + + +def create_object_with_id(client, object_id, data_size, metadata_size, + seal=True): + metadata = generate_metadata(metadata_size) + memory_buffer = client.create(object_id, data_size, metadata) + write_to_data_buffer(memory_buffer, data_size) + if seal: + client.seal(object_id) + return memory_buffer, metadata + + +def create_object(client, data_size, metadata_size=0, seal=True): + object_id = random_object_id() + memory_buffer, metadata = create_object_with_id(client, object_id, + data_size, metadata_size, + seal=seal) + return object_id, memory_buffer, metadata + + +@pytest.mark.plasma +class TestPlasmaClient: + + def setup_method(self, test_method): + import pyarrow.plasma as plasma + # Start Plasma store. + self.plasma_store_ctx = plasma.start_plasma_store( + plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY, + use_valgrind=USE_VALGRIND) + self.plasma_store_name, self.p = self.plasma_store_ctx.__enter__() + # Connect to Plasma. + self.plasma_client = plasma.connect(self.plasma_store_name) + self.plasma_client2 = plasma.connect(self.plasma_store_name) + + def teardown_method(self, test_method): + try: + # Check that the Plasma store is still alive. + assert self.p.poll() is None + # Ensure Valgrind and/or coverage have a clean exit + # Valgrind misses SIGTERM if it is delivered before the + # event loop is ready; this race condition is mitigated + # but not solved by time.sleep(). + if USE_VALGRIND: + time.sleep(1.0) + self.p.send_signal(signal.SIGTERM) + self.p.wait(timeout=5) + assert self.p.returncode == 0 + finally: + self.plasma_store_ctx.__exit__(None, None, None) + + def test_connection_failure_raises_exception(self): + import pyarrow.plasma as plasma + # ARROW-1264 + with pytest.raises(IOError): + plasma.connect('unknown-store-name', num_retries=1) + + def test_create(self): + # Create an object id string. + object_id = random_object_id() + # Create a new buffer and write to it. + length = 50 + memory_buffer = np.frombuffer(self.plasma_client.create(object_id, + length), + dtype="uint8") + for i in range(length): + memory_buffer[i] = i % 256 + # Seal the object. + self.plasma_client.seal(object_id) + # Get the object. + memory_buffer = np.frombuffer( + self.plasma_client.get_buffers([object_id])[0], dtype="uint8") + for i in range(length): + assert memory_buffer[i] == i % 256 + + def test_create_with_metadata(self): + for length in range(0, 1000, 3): + # Create an object id string. + object_id = random_object_id() + # Create a random metadata string. + metadata = generate_metadata(length) + # Create a new buffer and write to it. + memory_buffer = np.frombuffer(self.plasma_client.create(object_id, + length, + metadata), + dtype="uint8") + for i in range(length): + memory_buffer[i] = i % 256 + # Seal the object. + self.plasma_client.seal(object_id) + # Get the object. + memory_buffer = np.frombuffer( + self.plasma_client.get_buffers([object_id])[0], dtype="uint8") + for i in range(length): + assert memory_buffer[i] == i % 256 + # Get the metadata. + metadata_buffer = np.frombuffer( + self.plasma_client.get_metadata([object_id])[0], dtype="uint8") + assert len(metadata) == len(metadata_buffer) + for i in range(len(metadata)): + assert metadata[i] == metadata_buffer[i] + + def test_create_existing(self): + # This test is partially used to test the code path in which we create + # an object with an ID that already exists + length = 100 + for _ in range(1000): + object_id = random_object_id() + self.plasma_client.create(object_id, length, + generate_metadata(length)) + try: + self.plasma_client.create(object_id, length, + generate_metadata(length)) + # TODO(pcm): Introduce a more specific error type here. + except pa.lib.ArrowException: + pass + else: + assert False + + def test_create_and_seal(self): + + # Create a bunch of objects. + object_ids = [] + for i in range(1000): + object_id = random_object_id() + object_ids.append(object_id) + self.plasma_client.create_and_seal(object_id, i * b'a', i * b'b') + + for i in range(1000): + [data_tuple] = self.plasma_client.get_buffers([object_ids[i]], + with_meta=True) + assert data_tuple[1].to_pybytes() == i * b'a' + assert (self.plasma_client.get_metadata( + [object_ids[i]])[0].to_pybytes() == + i * b'b') + + # Make sure that creating the same object twice raises an exception. + object_id = random_object_id() + self.plasma_client.create_and_seal(object_id, b'a', b'b') + with pytest.raises(pa.plasma.PlasmaObjectExists): + self.plasma_client.create_and_seal(object_id, b'a', b'b') + + # Make sure that these objects can be evicted. + big_object = DEFAULT_PLASMA_STORE_MEMORY // 10 * b'a' + object_ids = [] + for _ in range(20): + object_id = random_object_id() + object_ids.append(object_id) + self.plasma_client.create_and_seal(random_object_id(), big_object, + big_object) + for i in range(10): + assert not self.plasma_client.contains(object_ids[i]) + + def test_get(self): + num_object_ids = 60 + # Test timing out of get with various timeouts. + for timeout in [0, 10, 100, 1000]: + object_ids = [random_object_id() for _ in range(num_object_ids)] + results = self.plasma_client.get_buffers(object_ids, + timeout_ms=timeout) + assert results == num_object_ids * [None] + + data_buffers = [] + metadata_buffers = [] + for i in range(num_object_ids): + if i % 2 == 0: + data_buffer, metadata_buffer = create_object_with_id( + self.plasma_client, object_ids[i], 2000, 2000) + data_buffers.append(data_buffer) + metadata_buffers.append(metadata_buffer) + + # Test timing out from some but not all get calls with various + # timeouts. + for timeout in [0, 10, 100, 1000]: + data_results = self.plasma_client.get_buffers(object_ids, + timeout_ms=timeout) + # metadata_results = self.plasma_client.get_metadata( + # object_ids, timeout_ms=timeout) + for i in range(num_object_ids): + if i % 2 == 0: + array1 = np.frombuffer(data_buffers[i // 2], dtype="uint8") + array2 = np.frombuffer(data_results[i], dtype="uint8") + np.testing.assert_equal(array1, array2) + # TODO(rkn): We should compare the metadata as well. But + # currently the types are different (e.g., memoryview + # versus bytearray). + # assert plasma.buffers_equal( + # metadata_buffers[i // 2], metadata_results[i]) + else: + assert results[i] is None + + # Test trying to get an object that was created by the same client but + # not sealed. + object_id = random_object_id() + self.plasma_client.create(object_id, 10, b"metadata") + assert self.plasma_client.get_buffers( + [object_id], timeout_ms=0, with_meta=True)[0][1] is None + assert self.plasma_client.get_buffers( + [object_id], timeout_ms=1, with_meta=True)[0][1] is None + self.plasma_client.seal(object_id) + assert self.plasma_client.get_buffers( + [object_id], timeout_ms=0, with_meta=True)[0][1] is not None + + def test_buffer_lifetime(self): + # ARROW-2195 + arr = pa.array([1, 12, 23, 3, 34], pa.int32()) + batch = pa.RecordBatch.from_arrays([arr], ['field1']) + + # Serialize RecordBatch into Plasma store + sink = pa.MockOutputStream() + writer = pa.RecordBatchStreamWriter(sink, batch.schema) + writer.write_batch(batch) + writer.close() + + object_id = random_object_id() + data_buffer = self.plasma_client.create(object_id, sink.size()) + stream = pa.FixedSizeBufferWriter(data_buffer) + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + writer.close() + self.plasma_client.seal(object_id) + del data_buffer + + # Unserialize RecordBatch from Plasma store + [data_buffer] = self.plasma_client2.get_buffers([object_id]) + reader = pa.RecordBatchStreamReader(data_buffer) + read_batch = reader.read_next_batch() + # Lose reference to returned buffer. The RecordBatch must still + # be backed by valid memory. + del data_buffer, reader + + assert read_batch.equals(batch) + + def test_put_and_get(self): + for value in [["hello", "world", 3, 1.0], None, "hello"]: + object_id = self.plasma_client.put(value) + [result] = self.plasma_client.get([object_id]) + assert result == value + + result = self.plasma_client.get(object_id) + assert result == value + + object_id = random_object_id() + [result] = self.plasma_client.get([object_id], timeout_ms=0) + assert result == pa.plasma.ObjectNotAvailable + + @pytest.mark.filterwarnings( + "ignore:'pyarrow.deserialize':FutureWarning") + def test_put_and_get_raw_buffer(self): + temp_id = random_object_id() + use_meta = b"RAW" + + def deserialize_or_output(data_tuple): + if data_tuple[0] == use_meta: + return data_tuple[1].to_pybytes() + else: + if data_tuple[1] is None: + return pa.plasma.ObjectNotAvailable + else: + return pa.deserialize(data_tuple[1]) + + for value in [b"Bytes Test", temp_id.binary(), 10 * b"\x00", 123]: + if isinstance(value, bytes): + object_id = self.plasma_client.put_raw_buffer( + value, metadata=use_meta) + else: + object_id = self.plasma_client.put(value) + [result] = self.plasma_client.get_buffers([object_id], + with_meta=True) + result = deserialize_or_output(result) + assert result == value + + object_id = random_object_id() + [result] = self.plasma_client.get_buffers([object_id], + timeout_ms=0, + with_meta=True) + result = deserialize_or_output(result) + assert result == pa.plasma.ObjectNotAvailable + + @pytest.mark.filterwarnings( + "ignore:'serialization_context':FutureWarning") + def test_put_and_get_serialization_context(self): + + class CustomType: + def __init__(self, val): + self.val = val + + val = CustomType(42) + + with pytest.raises(pa.ArrowSerializationError): + self.plasma_client.put(val) + + serialization_context = pa.lib.SerializationContext() + serialization_context.register_type(CustomType, 20*"\x00") + + object_id = self.plasma_client.put( + val, None, serialization_context=serialization_context) + + with pytest.raises(pa.ArrowSerializationError): + result = self.plasma_client.get(object_id) + + result = self.plasma_client.get( + object_id, -1, serialization_context=serialization_context) + assert result.val == val.val + + def test_store_arrow_objects(self): + data = np.random.randn(10, 4) + # Write an arrow object. + object_id = random_object_id() + tensor = pa.Tensor.from_numpy(data) + data_size = pa.ipc.get_tensor_size(tensor) + buf = self.plasma_client.create(object_id, data_size) + stream = pa.FixedSizeBufferWriter(buf) + pa.ipc.write_tensor(tensor, stream) + self.plasma_client.seal(object_id) + # Read the arrow object. + [tensor] = self.plasma_client.get_buffers([object_id]) + reader = pa.BufferReader(tensor) + array = pa.ipc.read_tensor(reader).to_numpy() + # Assert that they are equal. + np.testing.assert_equal(data, array) + + @pytest.mark.pandas + def test_store_pandas_dataframe(self): + import pandas as pd + import pyarrow.plasma as plasma + d = {'one': pd.Series([1., 2., 3.], index=['a', 'b', 'c']), + 'two': pd.Series([1., 2., 3., 4.], index=['a', 'b', 'c', 'd'])} + df = pd.DataFrame(d) + + # Write the DataFrame. + record_batch = pa.RecordBatch.from_pandas(df) + # Determine the size. + s = pa.MockOutputStream() + stream_writer = pa.RecordBatchStreamWriter(s, record_batch.schema) + stream_writer.write_batch(record_batch) + data_size = s.size() + object_id = plasma.ObjectID(np.random.bytes(20)) + + buf = self.plasma_client.create(object_id, data_size) + stream = pa.FixedSizeBufferWriter(buf) + stream_writer = pa.RecordBatchStreamWriter(stream, record_batch.schema) + stream_writer.write_batch(record_batch) + + self.plasma_client.seal(object_id) + + # Read the DataFrame. + [data] = self.plasma_client.get_buffers([object_id]) + reader = pa.RecordBatchStreamReader(pa.BufferReader(data)) + result = reader.read_next_batch().to_pandas() + + pd.testing.assert_frame_equal(df, result) + + def test_pickle_object_ids(self): + # This can be used for sharing object IDs between processes. + import pickle + object_id = random_object_id() + data = pickle.dumps(object_id) + object_id2 = pickle.loads(data) + assert object_id == object_id2 + + def test_store_full(self): + # The store is started with 1GB, so make sure that create throws an + # exception when it is full. + def assert_create_raises_plasma_full(unit_test, size): + partial_size = np.random.randint(size) + try: + _, memory_buffer, _ = create_object(unit_test.plasma_client, + partial_size, + size - partial_size) + # TODO(pcm): More specific error here. + except pa.lib.ArrowException: + pass + else: + # For some reason the above didn't throw an exception, so fail. + assert False + + PERCENT = DEFAULT_PLASMA_STORE_MEMORY // 100 + + # Create a list to keep some of the buffers in scope. + memory_buffers = [] + _, memory_buffer, _ = create_object(self.plasma_client, 50 * PERCENT) + memory_buffers.append(memory_buffer) + # Remaining space is 50%. Make sure that we can't create an + # object of size 50% + 1, but we can create one of size 20%. + assert_create_raises_plasma_full( + self, 50 * PERCENT + SMALL_OBJECT_SIZE) + _, memory_buffer, _ = create_object(self.plasma_client, 20 * PERCENT) + del memory_buffer + _, memory_buffer, _ = create_object(self.plasma_client, 20 * PERCENT) + del memory_buffer + assert_create_raises_plasma_full( + self, 50 * PERCENT + SMALL_OBJECT_SIZE) + + _, memory_buffer, _ = create_object(self.plasma_client, 20 * PERCENT) + memory_buffers.append(memory_buffer) + # Remaining space is 30%. + assert_create_raises_plasma_full( + self, 30 * PERCENT + SMALL_OBJECT_SIZE) + + _, memory_buffer, _ = create_object(self.plasma_client, 10 * PERCENT) + memory_buffers.append(memory_buffer) + # Remaining space is 20%. + assert_create_raises_plasma_full( + self, 20 * PERCENT + SMALL_OBJECT_SIZE) + + def test_contains(self): + fake_object_ids = [random_object_id() for _ in range(100)] + real_object_ids = [random_object_id() for _ in range(100)] + for object_id in real_object_ids: + assert self.plasma_client.contains(object_id) is False + self.plasma_client.create(object_id, 100) + self.plasma_client.seal(object_id) + assert self.plasma_client.contains(object_id) + for object_id in fake_object_ids: + assert not self.plasma_client.contains(object_id) + for object_id in real_object_ids: + assert self.plasma_client.contains(object_id) + + def test_hash(self): + # Check the hash of an object that doesn't exist. + object_id1 = random_object_id() + try: + self.plasma_client.hash(object_id1) + # TODO(pcm): Introduce a more specific error type here + except pa.lib.ArrowException: + pass + else: + assert False + + length = 1000 + # Create a random object, and check that the hash function always + # returns the same value. + metadata = generate_metadata(length) + memory_buffer = np.frombuffer(self.plasma_client.create(object_id1, + length, + metadata), + dtype="uint8") + for i in range(length): + memory_buffer[i] = i % 256 + self.plasma_client.seal(object_id1) + assert (self.plasma_client.hash(object_id1) == + self.plasma_client.hash(object_id1)) + + # Create a second object with the same value as the first, and check + # that their hashes are equal. + object_id2 = random_object_id() + memory_buffer = np.frombuffer(self.plasma_client.create(object_id2, + length, + metadata), + dtype="uint8") + for i in range(length): + memory_buffer[i] = i % 256 + self.plasma_client.seal(object_id2) + assert (self.plasma_client.hash(object_id1) == + self.plasma_client.hash(object_id2)) + + # Create a third object with a different value from the first two, and + # check that its hash is different. + object_id3 = random_object_id() + metadata = generate_metadata(length) + memory_buffer = np.frombuffer(self.plasma_client.create(object_id3, + length, + metadata), + dtype="uint8") + for i in range(length): + memory_buffer[i] = (i + 1) % 256 + self.plasma_client.seal(object_id3) + assert (self.plasma_client.hash(object_id1) != + self.plasma_client.hash(object_id3)) + + # Create a fourth object with the same value as the third, but + # different metadata. Check that its hash is different from any of the + # previous three. + object_id4 = random_object_id() + metadata4 = generate_metadata(length) + memory_buffer = np.frombuffer(self.plasma_client.create(object_id4, + length, + metadata4), + dtype="uint8") + for i in range(length): + memory_buffer[i] = (i + 1) % 256 + self.plasma_client.seal(object_id4) + assert (self.plasma_client.hash(object_id1) != + self.plasma_client.hash(object_id4)) + assert (self.plasma_client.hash(object_id3) != + self.plasma_client.hash(object_id4)) + + def test_many_hashes(self): + hashes = [] + length = 2 ** 10 + + for i in range(256): + object_id = random_object_id() + memory_buffer = np.frombuffer(self.plasma_client.create(object_id, + length), + dtype="uint8") + for j in range(length): + memory_buffer[j] = i + self.plasma_client.seal(object_id) + hashes.append(self.plasma_client.hash(object_id)) + + # Create objects of varying length. Each pair has two bits different. + for i in range(length): + object_id = random_object_id() + memory_buffer = np.frombuffer(self.plasma_client.create(object_id, + length), + dtype="uint8") + for j in range(length): + memory_buffer[j] = 0 + memory_buffer[i] = 1 + self.plasma_client.seal(object_id) + hashes.append(self.plasma_client.hash(object_id)) + + # Create objects of varying length, all with value 0. + for i in range(length): + object_id = random_object_id() + memory_buffer = np.frombuffer(self.plasma_client.create(object_id, + i), + dtype="uint8") + for j in range(i): + memory_buffer[j] = 0 + self.plasma_client.seal(object_id) + hashes.append(self.plasma_client.hash(object_id)) + + # Check that all hashes were unique. + assert len(set(hashes)) == 256 + length + length + + # def test_individual_delete(self): + # length = 100 + # # Create an object id string. + # object_id = random_object_id() + # # Create a random metadata string. + # metadata = generate_metadata(100) + # # Create a new buffer and write to it. + # memory_buffer = self.plasma_client.create(object_id, length, + # metadata) + # for i in range(length): + # memory_buffer[i] = chr(i % 256) + # # Seal the object. + # self.plasma_client.seal(object_id) + # # Check that the object is present. + # assert self.plasma_client.contains(object_id) + # # Delete the object. + # self.plasma_client.delete(object_id) + # # Make sure the object is no longer present. + # self.assertFalse(self.plasma_client.contains(object_id)) + # + # def test_delete(self): + # # Create some objects. + # object_ids = [random_object_id() for _ in range(100)] + # for object_id in object_ids: + # length = 100 + # # Create a random metadata string. + # metadata = generate_metadata(100) + # # Create a new buffer and write to it. + # memory_buffer = self.plasma_client.create(object_id, length, + # metadata) + # for i in range(length): + # memory_buffer[i] = chr(i % 256) + # # Seal the object. + # self.plasma_client.seal(object_id) + # # Check that the object is present. + # assert self.plasma_client.contains(object_id) + # + # # Delete the objects and make sure they are no longer present. + # for object_id in object_ids: + # # Delete the object. + # self.plasma_client.delete(object_id) + # # Make sure the object is no longer present. + # self.assertFalse(self.plasma_client.contains(object_id)) + + def test_illegal_functionality(self): + # Create an object id string. + object_id = random_object_id() + # Create a new buffer and write to it. + length = 1000 + memory_buffer = self.plasma_client.create(object_id, length) + # Make sure we cannot access memory out of bounds. + with pytest.raises(Exception): + memory_buffer[length] + # Seal the object. + self.plasma_client.seal(object_id) + # This test is commented out because it currently fails. + # # Make sure the object is ready only now. + # def illegal_assignment(): + # memory_buffer[0] = chr(0) + # with pytest.raises(Exception): + # illegal_assignment() + # Get the object. + memory_buffer = self.plasma_client.get_buffers([object_id])[0] + + # Make sure the object is read only. + def illegal_assignment(): + memory_buffer[0] = chr(0) + with pytest.raises(Exception): + illegal_assignment() + + def test_evict(self): + client = self.plasma_client2 + object_id1 = random_object_id() + b1 = client.create(object_id1, 1000) + client.seal(object_id1) + del b1 + assert client.evict(1) == 1000 + + object_id2 = random_object_id() + object_id3 = random_object_id() + b2 = client.create(object_id2, 999) + b3 = client.create(object_id3, 998) + client.seal(object_id3) + del b3 + assert client.evict(1000) == 998 + + object_id4 = random_object_id() + b4 = client.create(object_id4, 997) + client.seal(object_id4) + del b4 + client.seal(object_id2) + del b2 + assert client.evict(1) == 997 + assert client.evict(1) == 999 + + object_id5 = random_object_id() + object_id6 = random_object_id() + object_id7 = random_object_id() + b5 = client.create(object_id5, 996) + b6 = client.create(object_id6, 995) + b7 = client.create(object_id7, 994) + client.seal(object_id5) + client.seal(object_id6) + client.seal(object_id7) + del b5 + del b6 + del b7 + assert client.evict(2000) == 996 + 995 + 994 + + # Mitigate valgrind-induced slowness + SUBSCRIBE_TEST_SIZES = ([1, 10, 100, 1000] if USE_VALGRIND + else [1, 10, 100, 1000, 10000]) + + def test_subscribe(self): + # Subscribe to notifications from the Plasma Store. + self.plasma_client.subscribe() + for i in self.SUBSCRIBE_TEST_SIZES: + object_ids = [random_object_id() for _ in range(i)] + metadata_sizes = [np.random.randint(1000) for _ in range(i)] + data_sizes = [np.random.randint(1000) for _ in range(i)] + for j in range(i): + self.plasma_client.create( + object_ids[j], data_sizes[j], + metadata=bytearray(np.random.bytes(metadata_sizes[j]))) + self.plasma_client.seal(object_ids[j]) + # Check that we received notifications for all of the objects. + for j in range(i): + notification_info = self.plasma_client.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + assert object_ids[j] == recv_objid + assert data_sizes[j] == recv_dsize + assert metadata_sizes[j] == recv_msize + + def test_subscribe_socket(self): + # Subscribe to notifications from the Plasma Store. + self.plasma_client.subscribe() + rsock = self.plasma_client.get_notification_socket() + for i in self.SUBSCRIBE_TEST_SIZES: + # Get notification from socket. + object_ids = [random_object_id() for _ in range(i)] + metadata_sizes = [np.random.randint(1000) for _ in range(i)] + data_sizes = [np.random.randint(1000) for _ in range(i)] + + for j in range(i): + self.plasma_client.create( + object_ids[j], data_sizes[j], + metadata=bytearray(np.random.bytes(metadata_sizes[j]))) + self.plasma_client.seal(object_ids[j]) + + # Check that we received notifications for all of the objects. + for j in range(i): + # Assume the plasma store will not be full, + # so we always get the data size instead of -1. + msg_len, = struct.unpack('L', rsock.recv(8)) + content = rsock.recv(msg_len) + recv_objids, recv_dsizes, recv_msizes = ( + self.plasma_client.decode_notifications(content)) + assert object_ids[j] == recv_objids[0] + assert data_sizes[j] == recv_dsizes[0] + assert metadata_sizes[j] == recv_msizes[0] + + def test_subscribe_deletions(self): + # Subscribe to notifications from the Plasma Store. We use + # plasma_client2 to make sure that all used objects will get evicted + # properly. + self.plasma_client2.subscribe() + for i in self.SUBSCRIBE_TEST_SIZES: + object_ids = [random_object_id() for _ in range(i)] + # Add 1 to the sizes to make sure we have nonzero object sizes. + metadata_sizes = [np.random.randint(1000) + 1 for _ in range(i)] + data_sizes = [np.random.randint(1000) + 1 for _ in range(i)] + for j in range(i): + x = self.plasma_client2.create( + object_ids[j], data_sizes[j], + metadata=bytearray(np.random.bytes(metadata_sizes[j]))) + self.plasma_client2.seal(object_ids[j]) + del x + # Check that we received notifications for creating all of the + # objects. + for j in range(i): + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + assert object_ids[j] == recv_objid + assert data_sizes[j] == recv_dsize + assert metadata_sizes[j] == recv_msize + + # Check that we receive notifications for deleting all objects, as + # we evict them. + for j in range(i): + assert (self.plasma_client2.evict(1) == + data_sizes[j] + metadata_sizes[j]) + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + assert object_ids[j] == recv_objid + assert -1 == recv_dsize + assert -1 == recv_msize + + # Test multiple deletion notifications. The first 9 object IDs have + # size 0, and the last has a nonzero size. When Plasma evicts 1 byte, + # it will evict all objects, so we should receive deletion + # notifications for each. + num_object_ids = 10 + object_ids = [random_object_id() for _ in range(num_object_ids)] + metadata_sizes = [0] * (num_object_ids - 1) + data_sizes = [0] * (num_object_ids - 1) + metadata_sizes.append(np.random.randint(1000)) + data_sizes.append(np.random.randint(1000)) + for i in range(num_object_ids): + x = self.plasma_client2.create( + object_ids[i], data_sizes[i], + metadata=bytearray(np.random.bytes(metadata_sizes[i]))) + self.plasma_client2.seal(object_ids[i]) + del x + for i in range(num_object_ids): + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + assert object_ids[i] == recv_objid + assert data_sizes[i] == recv_dsize + assert metadata_sizes[i] == recv_msize + assert (self.plasma_client2.evict(1) == + data_sizes[-1] + metadata_sizes[-1]) + for i in range(num_object_ids): + notification_info = self.plasma_client2.get_next_notification() + recv_objid, recv_dsize, recv_msize = notification_info + assert object_ids[i] == recv_objid + assert -1 == recv_dsize + assert -1 == recv_msize + + def test_use_full_memory(self): + # Fill the object store up with a large number of small objects and let + # them go out of scope. + for _ in range(100): + create_object( + self.plasma_client2, + np.random.randint(1, DEFAULT_PLASMA_STORE_MEMORY // 20), 0) + # Create large objects that require the full object store size, and + # verify that they fit. + for _ in range(2): + create_object(self.plasma_client2, DEFAULT_PLASMA_STORE_MEMORY, 0) + # Verify that an object that is too large does not fit. + # Also verifies that the right error is thrown, and does not + # create the object ID prematurely. + object_id = random_object_id() + for i in range(3): + with pytest.raises(pa.plasma.PlasmaStoreFull): + self.plasma_client2.create( + object_id, DEFAULT_PLASMA_STORE_MEMORY + SMALL_OBJECT_SIZE) + + @staticmethod + def _client_blocked_in_get(plasma_store_name, object_id): + import pyarrow.plasma as plasma + client = plasma.connect(plasma_store_name) + # Try to get an object ID that doesn't exist. This should block. + client.get([object_id]) + + def test_client_death_during_get(self): + object_id = random_object_id() + + p = multiprocessing.Process(target=self._client_blocked_in_get, + args=(self.plasma_store_name, object_id)) + p.start() + # Make sure the process is running. + time.sleep(0.2) + assert p.is_alive() + + # Kill the client process. + p.terminate() + # Wait a little for the store to process the disconnect event. + time.sleep(0.1) + + # Create the object. + self.plasma_client.put(1, object_id=object_id) + + # Check that the store is still alive. This will raise an exception if + # the store is dead. + self.plasma_client.contains(random_object_id()) + + @staticmethod + def _client_get_multiple(plasma_store_name, object_ids): + import pyarrow.plasma as plasma + client = plasma.connect(plasma_store_name) + # Try to get an object ID that doesn't exist. This should block. + client.get(object_ids) + + def test_client_getting_multiple_objects(self): + object_ids = [random_object_id() for _ in range(10)] + + p = multiprocessing.Process(target=self._client_get_multiple, + args=(self.plasma_store_name, object_ids)) + p.start() + # Make sure the process is running. + time.sleep(0.2) + assert p.is_alive() + + # Create the objects one by one. + for object_id in object_ids: + self.plasma_client.put(1, object_id=object_id) + + # Check that the store is still alive. This will raise an exception if + # the store is dead. + self.plasma_client.contains(random_object_id()) + + # Make sure that the blocked client finishes. + start_time = time.time() + while True: + if time.time() - start_time > 5: + raise Exception("Timing out while waiting for blocked client " + "to finish.") + if not p.is_alive(): + break + + +@pytest.mark.plasma +class TestEvictionToExternalStore: + + def setup_method(self, test_method): + import pyarrow.plasma as plasma + # Start Plasma store. + self.plasma_store_ctx = plasma.start_plasma_store( + plasma_store_memory=1000 * 1024, + use_valgrind=USE_VALGRIND, + external_store=EXTERNAL_STORE) + self.plasma_store_name, self.p = self.plasma_store_ctx.__enter__() + # Connect to Plasma. + self.plasma_client = plasma.connect(self.plasma_store_name) + + def teardown_method(self, test_method): + try: + # Check that the Plasma store is still alive. + assert self.p.poll() is None + self.p.send_signal(signal.SIGTERM) + self.p.wait(timeout=5) + finally: + self.plasma_store_ctx.__exit__(None, None, None) + + def test_eviction(self): + client = self.plasma_client + + object_ids = [random_object_id() for _ in range(0, 20)] + data = b'x' * 100 * 1024 + metadata = b'' + + for i in range(0, 20): + # Test for object non-existence. + assert not client.contains(object_ids[i]) + + # Create and seal the object. + client.create_and_seal(object_ids[i], data, metadata) + + # Test that the client can get the object. + assert client.contains(object_ids[i]) + + for i in range(0, 20): + # Since we are accessing objects sequentially, every object we + # access would be a cache "miss" owing to LRU eviction. + # Try and access the object from the plasma store first, and then + # try external store on failure. This should succeed to fetch the + # object. However, it may evict the next few objects. + [result] = client.get_buffers([object_ids[i]]) + assert result.to_pybytes() == data + + # Make sure we still cannot fetch objects that do not exist + [result] = client.get_buffers([random_object_id()], timeout_ms=100) + assert result is None + + +@pytest.mark.plasma +def test_object_id_size(): + import pyarrow.plasma as plasma + with pytest.raises(ValueError): + plasma.ObjectID("hello") + plasma.ObjectID(20 * b"0") + + +@pytest.mark.plasma +def test_object_id_equality_operators(): + import pyarrow.plasma as plasma + + oid1 = plasma.ObjectID(20 * b'0') + oid2 = plasma.ObjectID(20 * b'0') + oid3 = plasma.ObjectID(19 * b'0' + b'1') + + assert oid1 == oid2 + assert oid2 != oid3 + assert oid1 != 'foo' + + +@pytest.mark.xfail(reason="often fails on travis") +@pytest.mark.skipif(not os.path.exists("/mnt/hugepages"), + reason="requires hugepage support") +def test_use_huge_pages(): + import pyarrow.plasma as plasma + with plasma.start_plasma_store( + plasma_store_memory=2*10**9, + plasma_directory="/mnt/hugepages", + use_hugepages=True) as (plasma_store_name, p): + plasma_client = plasma.connect(plasma_store_name) + create_object(plasma_client, 10**8) + + +# This is checking to make sure plasma_clients cannot be destroyed +# before all the PlasmaBuffers that have handles to them are +# destroyed, see ARROW-2448. +@pytest.mark.plasma +def test_plasma_client_sharing(): + import pyarrow.plasma as plasma + + with plasma.start_plasma_store( + plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY) \ + as (plasma_store_name, p): + plasma_client = plasma.connect(plasma_store_name) + object_id = plasma_client.put(np.zeros(3)) + buf = plasma_client.get(object_id) + del plasma_client + assert (buf == np.zeros(3)).all() + del buf # This segfaulted pre ARROW-2448. + + +@pytest.mark.plasma +def test_plasma_list(): + import pyarrow.plasma as plasma + + with plasma.start_plasma_store( + plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY) \ + as (plasma_store_name, p): + plasma_client = plasma.connect(plasma_store_name) + + # Test sizes + u, _, _ = create_object(plasma_client, 11, metadata_size=7, seal=False) + l1 = plasma_client.list() + assert l1[u]["data_size"] == 11 + assert l1[u]["metadata_size"] == 7 + + # Test ref_count + v = plasma_client.put(np.zeros(3)) + # Ref count has already been released + # XXX flaky test, disabled (ARROW-3344) + # l2 = plasma_client.list() + # assert l2[v]["ref_count"] == 0 + a = plasma_client.get(v) + l3 = plasma_client.list() + assert l3[v]["ref_count"] == 1 + del a + + # Test state + w, _, _ = create_object(plasma_client, 3, metadata_size=0, seal=False) + l4 = plasma_client.list() + assert l4[w]["state"] == "created" + plasma_client.seal(w) + l5 = plasma_client.list() + assert l5[w]["state"] == "sealed" + + # Test timestamps + slack = 1.5 # seconds + t1 = time.time() + x, _, _ = create_object(plasma_client, 3, metadata_size=0, seal=False) + t2 = time.time() + l6 = plasma_client.list() + assert t1 - slack <= l6[x]["create_time"] <= t2 + slack + time.sleep(2.0) + t3 = time.time() + plasma_client.seal(x) + t4 = time.time() + l7 = plasma_client.list() + assert t3 - t2 - slack <= l7[x]["construct_duration"] + assert l7[x]["construct_duration"] <= t4 - t1 + slack + + +@pytest.mark.plasma +def test_object_id_randomness(): + cmd = "from pyarrow import plasma; print(plasma.ObjectID.from_random())" + first_object_id = subprocess.check_output([sys.executable, "-c", cmd]) + second_object_id = subprocess.check_output([sys.executable, "-c", cmd]) + assert first_object_id != second_object_id + + +@pytest.mark.plasma +def test_store_capacity(): + import pyarrow.plasma as plasma + with plasma.start_plasma_store(plasma_store_memory=10000) as (name, p): + plasma_client = plasma.connect(name) + assert plasma_client.store_capacity() == 10000 diff --git a/src/arrow/python/pyarrow/tests/test_plasma_tf_op.py b/src/arrow/python/pyarrow/tests/test_plasma_tf_op.py new file mode 100644 index 000000000..53ecae217 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_plasma_tf_op.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import pytest + + +def run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name, + client, use_gpu, dtype): + FORCE_DEVICE = '/gpu' if use_gpu else '/cpu' + + object_id = np.random.bytes(20) + + data = np.random.randn(3, 244, 244).astype(dtype) + ones = np.ones((3, 244, 244)).astype(dtype) + + sess = tf.Session(config=tf.ConfigProto( + allow_soft_placement=True, log_device_placement=True)) + + def ToPlasma(): + data_tensor = tf.constant(data) + ones_tensor = tf.constant(ones) + return plasma.tf_plasma_op.tensor_to_plasma( + [data_tensor, ones_tensor], + object_id, + plasma_store_socket_name=plasma_store_name) + + def FromPlasma(): + return plasma.tf_plasma_op.plasma_to_tensor( + object_id, + dtype=tf.as_dtype(dtype), + plasma_store_socket_name=plasma_store_name) + + with tf.device(FORCE_DEVICE): + to_plasma = ToPlasma() + from_plasma = FromPlasma() + + z = from_plasma + 1 + + sess.run(to_plasma) + # NOTE(zongheng): currently it returns a flat 1D tensor. + # So reshape manually. + out = sess.run(from_plasma) + + out = np.split(out, 2) + out0 = out[0].reshape(3, 244, 244) + out1 = out[1].reshape(3, 244, 244) + + sess.run(z) + + assert np.array_equal(data, out0), "Data not equal!" + assert np.array_equal(ones, out1), "Data not equal!" + + # Try getting the data from Python + plasma_object_id = plasma.ObjectID(object_id) + obj = client.get(plasma_object_id) + + # Deserialized Tensor should be 64-byte aligned. + assert obj.ctypes.data % 64 == 0 + + result = np.split(obj, 2) + result0 = result[0].reshape(3, 244, 244) + result1 = result[1].reshape(3, 244, 244) + + assert np.array_equal(data, result0), "Data not equal!" + assert np.array_equal(ones, result1), "Data not equal!" + + +@pytest.mark.plasma +@pytest.mark.tensorflow +@pytest.mark.skip(reason='Until ARROW-4259 is resolved') +def test_plasma_tf_op(use_gpu=False): + import pyarrow.plasma as plasma + import tensorflow as tf + + plasma.build_plasma_tensorflow_op() + + if plasma.tf_plasma_op is None: + pytest.skip("TensorFlow Op not found") + + with plasma.start_plasma_store(10**8) as (plasma_store_name, p): + client = plasma.connect(plasma_store_name) + for dtype in [np.float32, np.float64, + np.int8, np.int16, np.int32, np.int64]: + run_tensorflow_test_with_dtype(tf, plasma, plasma_store_name, + client, use_gpu, dtype) + + # Make sure the objects have been released. + for _, info in client.list().items(): + assert info['ref_count'] == 0 diff --git a/src/arrow/python/pyarrow/tests/test_scalars.py b/src/arrow/python/pyarrow/tests/test_scalars.py new file mode 100644 index 000000000..778ce1066 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_scalars.py @@ -0,0 +1,687 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import datetime +import decimal +import pickle +import pytest +import weakref + +import numpy as np + +import pyarrow as pa + + +@pytest.mark.parametrize(['value', 'ty', 'klass', 'deprecated'], [ + (False, None, pa.BooleanScalar, pa.BooleanValue), + (True, None, pa.BooleanScalar, pa.BooleanValue), + (1, None, pa.Int64Scalar, pa.Int64Value), + (-1, None, pa.Int64Scalar, pa.Int64Value), + (1, pa.int8(), pa.Int8Scalar, pa.Int8Value), + (1, pa.uint8(), pa.UInt8Scalar, pa.UInt8Value), + (1, pa.int16(), pa.Int16Scalar, pa.Int16Value), + (1, pa.uint16(), pa.UInt16Scalar, pa.UInt16Value), + (1, pa.int32(), pa.Int32Scalar, pa.Int32Value), + (1, pa.uint32(), pa.UInt32Scalar, pa.UInt32Value), + (1, pa.int64(), pa.Int64Scalar, pa.Int64Value), + (1, pa.uint64(), pa.UInt64Scalar, pa.UInt64Value), + (1.0, None, pa.DoubleScalar, pa.DoubleValue), + (np.float16(1.0), pa.float16(), pa.HalfFloatScalar, pa.HalfFloatValue), + (1.0, pa.float32(), pa.FloatScalar, pa.FloatValue), + (decimal.Decimal("1.123"), None, pa.Decimal128Scalar, pa.Decimal128Value), + (decimal.Decimal("1.1234567890123456789012345678901234567890"), + None, pa.Decimal256Scalar, pa.Decimal256Value), + ("string", None, pa.StringScalar, pa.StringValue), + (b"bytes", None, pa.BinaryScalar, pa.BinaryValue), + ("largestring", pa.large_string(), pa.LargeStringScalar, + pa.LargeStringValue), + (b"largebytes", pa.large_binary(), pa.LargeBinaryScalar, + pa.LargeBinaryValue), + (b"abc", pa.binary(3), pa.FixedSizeBinaryScalar, pa.FixedSizeBinaryValue), + ([1, 2, 3], None, pa.ListScalar, pa.ListValue), + ([1, 2, 3, 4], pa.large_list(pa.int8()), pa.LargeListScalar, + pa.LargeListValue), + ([1, 2, 3, 4, 5], pa.list_(pa.int8(), 5), pa.FixedSizeListScalar, + pa.FixedSizeListValue), + (datetime.date.today(), None, pa.Date32Scalar, pa.Date32Value), + (datetime.date.today(), pa.date64(), pa.Date64Scalar, pa.Date64Value), + (datetime.datetime.now(), None, pa.TimestampScalar, pa.TimestampValue), + (datetime.datetime.now().time().replace(microsecond=0), pa.time32('s'), + pa.Time32Scalar, pa.Time32Value), + (datetime.datetime.now().time(), None, pa.Time64Scalar, pa.Time64Value), + (datetime.timedelta(days=1), None, pa.DurationScalar, pa.DurationValue), + (pa.MonthDayNano([1, -1, -10100]), None, + pa.MonthDayNanoIntervalScalar, None), + ({'a': 1, 'b': [1, 2]}, None, pa.StructScalar, pa.StructValue), + ([('a', 1), ('b', 2)], pa.map_(pa.string(), pa.int8()), pa.MapScalar, + pa.MapValue), +]) +def test_basics(value, ty, klass, deprecated): + s = pa.scalar(value, type=ty) + assert isinstance(s, klass) + assert s.as_py() == value + assert s == pa.scalar(value, type=ty) + assert s != value + assert s != "else" + assert hash(s) == hash(s) + assert s.is_valid is True + assert s != None # noqa: E711 + if deprecated is not None: + with pytest.warns(FutureWarning): + assert isinstance(s, deprecated) + + s = pa.scalar(None, type=s.type) + assert s.is_valid is False + assert s.as_py() is None + assert s != pa.scalar(value, type=ty) + + # test pickle roundtrip + restored = pickle.loads(pickle.dumps(s)) + assert s.equals(restored) + + # test that scalars are weak-referenceable + wr = weakref.ref(s) + assert wr() is not None + del s + assert wr() is None + + +def test_null_singleton(): + with pytest.raises(RuntimeError): + pa.NullScalar() + + +def test_nulls(): + null = pa.scalar(None) + assert null is pa.NA + assert null.as_py() is None + assert null != "something" + assert (null == pa.scalar(None)) is True + assert (null == 0) is False + assert pa.NA == pa.NA + assert pa.NA not in [5] + + arr = pa.array([None, None]) + for v in arr: + assert v is pa.NA + assert v.as_py() is None + + # test pickle roundtrip + restored = pickle.loads(pickle.dumps(null)) + assert restored.equals(null) + + # test that scalars are weak-referenceable + wr = weakref.ref(null) + assert wr() is not None + del null + assert wr() is not None # singleton + + +def test_hashing(): + # ARROW-640 + values = list(range(500)) + arr = pa.array(values + values) + set_from_array = set(arr) + assert isinstance(set_from_array, set) + assert len(set_from_array) == 500 + + +def test_bool(): + false = pa.scalar(False) + true = pa.scalar(True) + + assert isinstance(false, pa.BooleanScalar) + assert isinstance(true, pa.BooleanScalar) + + assert repr(true) == "<pyarrow.BooleanScalar: True>" + assert str(true) == "True" + assert repr(false) == "<pyarrow.BooleanScalar: False>" + assert str(false) == "False" + + assert true.as_py() is True + assert false.as_py() is False + + +def test_numerics(): + # int64 + s = pa.scalar(1) + assert isinstance(s, pa.Int64Scalar) + assert repr(s) == "<pyarrow.Int64Scalar: 1>" + assert str(s) == "1" + assert s.as_py() == 1 + + with pytest.raises(OverflowError): + pa.scalar(-1, type='uint8') + + # float64 + s = pa.scalar(1.5) + assert isinstance(s, pa.DoubleScalar) + assert repr(s) == "<pyarrow.DoubleScalar: 1.5>" + assert str(s) == "1.5" + assert s.as_py() == 1.5 + + # float16 + s = pa.scalar(np.float16(0.5), type='float16') + assert isinstance(s, pa.HalfFloatScalar) + assert repr(s) == "<pyarrow.HalfFloatScalar: 0.5>" + assert str(s) == "0.5" + assert s.as_py() == 0.5 + + +def test_decimal128(): + v = decimal.Decimal("1.123") + s = pa.scalar(v) + assert isinstance(s, pa.Decimal128Scalar) + assert s.as_py() == v + assert s.type == pa.decimal128(4, 3) + + v = decimal.Decimal("1.1234") + with pytest.raises(pa.ArrowInvalid): + pa.scalar(v, type=pa.decimal128(4, scale=3)) + with pytest.raises(pa.ArrowInvalid): + pa.scalar(v, type=pa.decimal128(5, scale=3)) + + s = pa.scalar(v, type=pa.decimal128(5, scale=4)) + assert isinstance(s, pa.Decimal128Scalar) + assert s.as_py() == v + + +def test_decimal256(): + v = decimal.Decimal("1234567890123456789012345678901234567890.123") + s = pa.scalar(v) + assert isinstance(s, pa.Decimal256Scalar) + assert s.as_py() == v + assert s.type == pa.decimal256(43, 3) + + v = decimal.Decimal("1.1234") + with pytest.raises(pa.ArrowInvalid): + pa.scalar(v, type=pa.decimal256(4, scale=3)) + with pytest.raises(pa.ArrowInvalid): + pa.scalar(v, type=pa.decimal256(5, scale=3)) + + s = pa.scalar(v, type=pa.decimal256(5, scale=4)) + assert isinstance(s, pa.Decimal256Scalar) + assert s.as_py() == v + + +def test_date(): + # ARROW-5125 + d1 = datetime.date(3200, 1, 1) + d2 = datetime.date(1960, 1, 1) + + for ty in [pa.date32(), pa.date64()]: + for d in [d1, d2]: + s = pa.scalar(d, type=ty) + assert s.as_py() == d + + +def test_date_cast(): + # ARROW-10472 - casting fo scalars doesn't segfault + scalar = pa.scalar(datetime.datetime(2012, 1, 1), type=pa.timestamp("us")) + expected = datetime.date(2012, 1, 1) + for ty in [pa.date32(), pa.date64()]: + result = scalar.cast(ty) + assert result.as_py() == expected + + +def test_time(): + t1 = datetime.time(18, 0) + t2 = datetime.time(21, 0) + + types = [pa.time32('s'), pa.time32('ms'), pa.time64('us'), pa.time64('ns')] + for ty in types: + for t in [t1, t2]: + s = pa.scalar(t, type=ty) + assert s.as_py() == t + + +def test_cast(): + val = pa.scalar(5, type='int8') + assert val.cast('int64') == pa.scalar(5, type='int64') + assert val.cast('uint32') == pa.scalar(5, type='uint32') + assert val.cast('string') == pa.scalar('5', type='string') + with pytest.raises(ValueError): + pa.scalar('foo').cast('int32') + + +@pytest.mark.pandas +def test_timestamp(): + import pandas as pd + arr = pd.date_range('2000-01-01 12:34:56', periods=10).values + + units = ['ns', 'us', 'ms', 's'] + + for i, unit in enumerate(units): + dtype = 'datetime64[{}]'.format(unit) + arrow_arr = pa.Array.from_pandas(arr.astype(dtype)) + expected = pd.Timestamp('2000-01-01 12:34:56') + + assert arrow_arr[0].as_py() == expected + assert arrow_arr[0].value * 1000**i == expected.value + + tz = 'America/New_York' + arrow_type = pa.timestamp(unit, tz=tz) + + dtype = 'datetime64[{}]'.format(unit) + arrow_arr = pa.Array.from_pandas(arr.astype(dtype), type=arrow_type) + expected = (pd.Timestamp('2000-01-01 12:34:56') + .tz_localize('utc') + .tz_convert(tz)) + + assert arrow_arr[0].as_py() == expected + assert arrow_arr[0].value * 1000**i == expected.value + + +@pytest.mark.nopandas +def test_timestamp_nanos_nopandas(): + # ARROW-5450 + import pytz + tz = 'America/New_York' + ty = pa.timestamp('ns', tz=tz) + + # 2000-01-01 00:00:00 + 1 microsecond + s = pa.scalar(946684800000000000 + 1000, type=ty) + + tzinfo = pytz.timezone(tz) + expected = datetime.datetime(2000, 1, 1, microsecond=1, tzinfo=tzinfo) + expected = tzinfo.fromutc(expected) + result = s.as_py() + assert result == expected + assert result.year == 1999 + assert result.hour == 19 + + # Non-zero nanos yields ValueError + s = pa.scalar(946684800000000001, type=ty) + with pytest.raises(ValueError): + s.as_py() + + +def test_timestamp_no_overflow(): + # ARROW-5450 + import pytz + + timestamps = [ + datetime.datetime(1, 1, 1, 0, 0, 0, tzinfo=pytz.utc), + datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, tzinfo=pytz.utc), + datetime.datetime(1970, 1, 1, 0, 0, 0, tzinfo=pytz.utc), + ] + for ts in timestamps: + s = pa.scalar(ts, type=pa.timestamp("us", tz="UTC")) + assert s.as_py() == ts + + +def test_duration(): + arr = np.array([0, 3600000000000], dtype='timedelta64[ns]') + + units = ['us', 'ms', 's'] + + for i, unit in enumerate(units): + dtype = 'timedelta64[{}]'.format(unit) + arrow_arr = pa.array(arr.astype(dtype)) + expected = datetime.timedelta(seconds=60*60) + assert isinstance(arrow_arr[1].as_py(), datetime.timedelta) + assert arrow_arr[1].as_py() == expected + assert (arrow_arr[1].value * 1000**(i+1) == + expected.total_seconds() * 1e9) + + +@pytest.mark.pandas +def test_duration_nanos_pandas(): + import pandas as pd + arr = pa.array([0, 3600000000000], type=pa.duration('ns')) + expected = pd.Timedelta('1 hour') + assert isinstance(arr[1].as_py(), pd.Timedelta) + assert arr[1].as_py() == expected + assert arr[1].value == expected.value + + # Non-zero nanos work fine + arr = pa.array([946684800000000001], type=pa.duration('ns')) + assert arr[0].as_py() == pd.Timedelta(946684800000000001, unit='ns') + + +@pytest.mark.nopandas +def test_duration_nanos_nopandas(): + arr = pa.array([0, 3600000000000], pa.duration('ns')) + expected = datetime.timedelta(seconds=60*60) + assert isinstance(arr[1].as_py(), datetime.timedelta) + assert arr[1].as_py() == expected + assert arr[1].value == expected.total_seconds() * 1e9 + + # Non-zero nanos yields ValueError + arr = pa.array([946684800000000001], type=pa.duration('ns')) + with pytest.raises(ValueError): + arr[0].as_py() + + +def test_month_day_nano_interval(): + triple = pa.MonthDayNano([-3600, 1800, -50]) + arr = pa.array([triple]) + assert isinstance(arr[0].as_py(), pa.MonthDayNano) + assert arr[0].as_py() == triple + assert arr[0].value == triple + + +@pytest.mark.parametrize('value', ['foo', 'mañana']) +@pytest.mark.parametrize(('ty', 'scalar_typ'), [ + (pa.string(), pa.StringScalar), + (pa.large_string(), pa.LargeStringScalar) +]) +def test_string(value, ty, scalar_typ): + s = pa.scalar(value, type=ty) + assert isinstance(s, scalar_typ) + assert s.as_py() == value + assert s.as_py() != 'something' + assert repr(value) in repr(s) + assert str(s) == str(value) + + buf = s.as_buffer() + assert isinstance(buf, pa.Buffer) + assert buf.to_pybytes() == value.encode() + + +@pytest.mark.parametrize('value', [b'foo', b'bar']) +@pytest.mark.parametrize(('ty', 'scalar_typ'), [ + (pa.binary(), pa.BinaryScalar), + (pa.large_binary(), pa.LargeBinaryScalar) +]) +def test_binary(value, ty, scalar_typ): + s = pa.scalar(value, type=ty) + assert isinstance(s, scalar_typ) + assert s.as_py() == value + assert str(s) == str(value) + assert repr(value) in repr(s) + assert s.as_py() == value + assert s != b'xxxxx' + + buf = s.as_buffer() + assert isinstance(buf, pa.Buffer) + assert buf.to_pybytes() == value + + +def test_fixed_size_binary(): + s = pa.scalar(b'foof', type=pa.binary(4)) + assert isinstance(s, pa.FixedSizeBinaryScalar) + assert s.as_py() == b'foof' + + with pytest.raises(pa.ArrowInvalid): + pa.scalar(b'foof5', type=pa.binary(4)) + + +@pytest.mark.parametrize(('ty', 'klass'), [ + (pa.list_(pa.string()), pa.ListScalar), + (pa.large_list(pa.string()), pa.LargeListScalar) +]) +def test_list(ty, klass): + v = ['foo', None] + s = pa.scalar(v, type=ty) + assert s.type == ty + assert len(s) == 2 + assert isinstance(s.values, pa.Array) + assert s.values.to_pylist() == v + assert isinstance(s, klass) + assert repr(v) in repr(s) + assert s.as_py() == v + assert s[0].as_py() == 'foo' + assert s[1].as_py() is None + assert s[-1] == s[1] + assert s[-2] == s[0] + with pytest.raises(IndexError): + s[-3] + with pytest.raises(IndexError): + s[2] + + +def test_list_from_numpy(): + s = pa.scalar(np.array([1, 2, 3], dtype=np.int64())) + assert s.type == pa.list_(pa.int64()) + assert s.as_py() == [1, 2, 3] + + +@pytest.mark.pandas +def test_list_from_pandas(): + import pandas as pd + + s = pa.scalar(pd.Series([1, 2, 3])) + assert s.as_py() == [1, 2, 3] + + cases = [ + (np.nan, 'null'), + (['string', np.nan], pa.list_(pa.binary())), + (['string', np.nan], pa.list_(pa.utf8())), + ([b'string', np.nan], pa.list_(pa.binary(6))), + ([True, np.nan], pa.list_(pa.bool_())), + ([decimal.Decimal('0'), np.nan], pa.list_(pa.decimal128(12, 2))), + ] + for case, ty in cases: + # Both types of exceptions are raised. May want to clean that up + with pytest.raises((ValueError, TypeError)): + pa.scalar(case, type=ty) + + # from_pandas option suppresses failure + s = pa.scalar(case, type=ty, from_pandas=True) + + +def test_fixed_size_list(): + s = pa.scalar([1, None, 3], type=pa.list_(pa.int64(), 3)) + + assert len(s) == 3 + assert isinstance(s, pa.FixedSizeListScalar) + assert repr(s) == "<pyarrow.FixedSizeListScalar: [1, None, 3]>" + assert s.as_py() == [1, None, 3] + assert s[0].as_py() == 1 + assert s[1].as_py() is None + assert s[-1] == s[2] + with pytest.raises(IndexError): + s[-4] + with pytest.raises(IndexError): + s[3] + + +def test_struct(): + ty = pa.struct([ + pa.field('x', pa.int16()), + pa.field('y', pa.float32()) + ]) + + v = {'x': 2, 'y': 3.5} + s = pa.scalar(v, type=ty) + assert list(s) == list(s.keys()) == ['x', 'y'] + assert list(s.values()) == [ + pa.scalar(2, type=pa.int16()), + pa.scalar(3.5, type=pa.float32()) + ] + assert list(s.items()) == [ + ('x', pa.scalar(2, type=pa.int16())), + ('y', pa.scalar(3.5, type=pa.float32())) + ] + assert 'x' in s + assert 'y' in s + assert 'z' not in s + assert 0 not in s + + assert s.as_py() == v + assert repr(s) != repr(v) + assert repr(s.as_py()) == repr(v) + assert len(s) == 2 + assert isinstance(s['x'], pa.Int16Scalar) + assert isinstance(s['y'], pa.FloatScalar) + assert s['x'].as_py() == 2 + assert s['y'].as_py() == 3.5 + + with pytest.raises(KeyError): + s['non-existent'] + + s = pa.scalar(None, type=ty) + assert list(s) == list(s.keys()) == ['x', 'y'] + assert s.as_py() is None + assert 'x' in s + assert 'y' in s + assert isinstance(s['x'], pa.Int16Scalar) + assert isinstance(s['y'], pa.FloatScalar) + assert s['x'].is_valid is False + assert s['y'].is_valid is False + assert s['x'].as_py() is None + assert s['y'].as_py() is None + + +def test_struct_duplicate_fields(): + ty = pa.struct([ + pa.field('x', pa.int16()), + pa.field('y', pa.float32()), + pa.field('x', pa.int64()), + ]) + s = pa.scalar([('x', 1), ('y', 2.0), ('x', 3)], type=ty) + + assert list(s) == list(s.keys()) == ['x', 'y', 'x'] + assert len(s) == 3 + assert s == s + assert list(s.items()) == [ + ('x', pa.scalar(1, pa.int16())), + ('y', pa.scalar(2.0, pa.float32())), + ('x', pa.scalar(3, pa.int64())) + ] + + assert 'x' in s + assert 'y' in s + assert 'z' not in s + assert 0 not in s + + # getitem with field names fails for duplicate fields, works for others + with pytest.raises(KeyError): + s['x'] + + assert isinstance(s['y'], pa.FloatScalar) + assert s['y'].as_py() == 2.0 + + # getitem with integer index works for all fields + assert isinstance(s[0], pa.Int16Scalar) + assert s[0].as_py() == 1 + assert isinstance(s[1], pa.FloatScalar) + assert s[1].as_py() == 2.0 + assert isinstance(s[2], pa.Int64Scalar) + assert s[2].as_py() == 3 + + assert "pyarrow.StructScalar" in repr(s) + + with pytest.raises(ValueError, match="duplicate field names"): + s.as_py() + + +def test_map(): + ty = pa.map_(pa.string(), pa.int8()) + v = [('a', 1), ('b', 2)] + s = pa.scalar(v, type=ty) + + assert len(s) == 2 + assert isinstance(s, pa.MapScalar) + assert isinstance(s.values, pa.Array) + assert repr(s) == "<pyarrow.MapScalar: [('a', 1), ('b', 2)]>" + assert s.values.to_pylist() == [ + {'key': 'a', 'value': 1}, + {'key': 'b', 'value': 2} + ] + + # test iteration + for i, j in zip(s, v): + assert i == j + + assert s.as_py() == v + assert s[1] == ( + pa.scalar('b', type=pa.string()), + pa.scalar(2, type=pa.int8()) + ) + assert s[-1] == s[1] + assert s[-2] == s[0] + with pytest.raises(IndexError): + s[-3] + with pytest.raises(IndexError): + s[2] + + restored = pickle.loads(pickle.dumps(s)) + assert restored.equals(s) + + +def test_dictionary(): + indices = pa.array([2, None, 1, 2, 0, None]) + dictionary = pa.array(['foo', 'bar', 'baz']) + + arr = pa.DictionaryArray.from_arrays(indices, dictionary) + expected = ['baz', None, 'bar', 'baz', 'foo', None] + assert arr.to_pylist() == expected + + for j, (i, v) in enumerate(zip(indices, expected)): + s = arr[j] + + assert s.as_py() == v + assert s.value.as_py() == v + assert s.index.equals(i) + assert s.dictionary.equals(dictionary) + + with pytest.warns(FutureWarning): + assert s.index_value.equals(i) + with pytest.warns(FutureWarning): + assert s.dictionary_value.as_py() == v + + restored = pickle.loads(pickle.dumps(s)) + assert restored.equals(s) + + +def test_union(): + # sparse + arr = pa.UnionArray.from_sparse( + pa.array([0, 0, 1, 1], type=pa.int8()), + [ + pa.array(["a", "b", "c", "d"]), + pa.array([1, 2, 3, 4]) + ] + ) + for s in arr: + assert isinstance(s, pa.UnionScalar) + assert s.type.equals(arr.type) + assert s.is_valid is True + with pytest.raises(pa.ArrowNotImplementedError): + pickle.loads(pickle.dumps(s)) + + assert arr[0].type_code == 0 + assert arr[0].as_py() == "a" + assert arr[1].type_code == 0 + assert arr[1].as_py() == "b" + assert arr[2].type_code == 1 + assert arr[2].as_py() == 3 + assert arr[3].type_code == 1 + assert arr[3].as_py() == 4 + + # dense + arr = pa.UnionArray.from_dense( + types=pa.array([0, 1, 0, 0, 1, 1, 0], type='int8'), + value_offsets=pa.array([0, 0, 2, 1, 1, 2, 3], type='int32'), + children=[ + pa.array([b'a', b'b', b'c', b'd'], type='binary'), + pa.array([1, 2, 3], type='int64') + ] + ) + for s in arr: + assert isinstance(s, pa.UnionScalar) + assert s.type.equals(arr.type) + assert s.is_valid is True + with pytest.raises(pa.ArrowNotImplementedError): + pickle.loads(pickle.dumps(s)) + + assert arr[0].type_code == 0 + assert arr[0].as_py() == b'a' + assert arr[5].type_code == 1 + assert arr[5].as_py() == 3 diff --git a/src/arrow/python/pyarrow/tests/test_schema.py b/src/arrow/python/pyarrow/tests/test_schema.py new file mode 100644 index 000000000..f26eaaf5f --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_schema.py @@ -0,0 +1,730 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import OrderedDict +import pickle +import sys +import weakref + +import pytest +import numpy as np +import pyarrow as pa + +import pyarrow.tests.util as test_util +from pyarrow.vendored.version import Version + + +def test_schema_constructor_errors(): + msg = ("Do not call Schema's constructor directly, use `pyarrow.schema` " + "instead") + with pytest.raises(TypeError, match=msg): + pa.Schema() + + +def test_type_integers(): + dtypes = ['int8', 'int16', 'int32', 'int64', + 'uint8', 'uint16', 'uint32', 'uint64'] + + for name in dtypes: + factory = getattr(pa, name) + t = factory() + assert str(t) == name + + +def test_type_to_pandas_dtype(): + M8_ns = np.dtype('datetime64[ns]') + cases = [ + (pa.null(), np.object_), + (pa.bool_(), np.bool_), + (pa.int8(), np.int8), + (pa.int16(), np.int16), + (pa.int32(), np.int32), + (pa.int64(), np.int64), + (pa.uint8(), np.uint8), + (pa.uint16(), np.uint16), + (pa.uint32(), np.uint32), + (pa.uint64(), np.uint64), + (pa.float16(), np.float16), + (pa.float32(), np.float32), + (pa.float64(), np.float64), + (pa.date32(), M8_ns), + (pa.date64(), M8_ns), + (pa.timestamp('ms'), M8_ns), + (pa.binary(), np.object_), + (pa.binary(12), np.object_), + (pa.string(), np.object_), + (pa.list_(pa.int8()), np.object_), + # (pa.list_(pa.int8(), 2), np.object_), # TODO needs pandas conversion + (pa.map_(pa.int64(), pa.float64()), np.object_), + ] + for arrow_type, numpy_type in cases: + assert arrow_type.to_pandas_dtype() == numpy_type + + +@pytest.mark.pandas +def test_type_to_pandas_dtype_check_import(): + # ARROW-7980 + test_util.invoke_script('arrow_7980.py') + + +def test_type_list(): + value_type = pa.int32() + list_type = pa.list_(value_type) + assert str(list_type) == 'list<item: int32>' + + field = pa.field('my_item', pa.string()) + l2 = pa.list_(field) + assert str(l2) == 'list<my_item: string>' + + +def test_type_comparisons(): + val = pa.int32() + assert val == pa.int32() + assert val == 'int32' + assert val != 5 + + +def test_type_for_alias(): + cases = [ + ('i1', pa.int8()), + ('int8', pa.int8()), + ('i2', pa.int16()), + ('int16', pa.int16()), + ('i4', pa.int32()), + ('int32', pa.int32()), + ('i8', pa.int64()), + ('int64', pa.int64()), + ('u1', pa.uint8()), + ('uint8', pa.uint8()), + ('u2', pa.uint16()), + ('uint16', pa.uint16()), + ('u4', pa.uint32()), + ('uint32', pa.uint32()), + ('u8', pa.uint64()), + ('uint64', pa.uint64()), + ('f4', pa.float32()), + ('float32', pa.float32()), + ('f8', pa.float64()), + ('float64', pa.float64()), + ('date32', pa.date32()), + ('date64', pa.date64()), + ('string', pa.string()), + ('str', pa.string()), + ('binary', pa.binary()), + ('time32[s]', pa.time32('s')), + ('time32[ms]', pa.time32('ms')), + ('time64[us]', pa.time64('us')), + ('time64[ns]', pa.time64('ns')), + ('timestamp[s]', pa.timestamp('s')), + ('timestamp[ms]', pa.timestamp('ms')), + ('timestamp[us]', pa.timestamp('us')), + ('timestamp[ns]', pa.timestamp('ns')), + ('duration[s]', pa.duration('s')), + ('duration[ms]', pa.duration('ms')), + ('duration[us]', pa.duration('us')), + ('duration[ns]', pa.duration('ns')), + ('month_day_nano_interval', pa.month_day_nano_interval()), + ] + + for val, expected in cases: + assert pa.type_for_alias(val) == expected + + +def test_type_string(): + t = pa.string() + assert str(t) == 'string' + + +def test_type_timestamp_with_tz(): + tz = 'America/Los_Angeles' + t = pa.timestamp('ns', tz=tz) + assert t.unit == 'ns' + assert t.tz == tz + + +def test_time_types(): + t1 = pa.time32('s') + t2 = pa.time32('ms') + t3 = pa.time64('us') + t4 = pa.time64('ns') + + assert t1.unit == 's' + assert t2.unit == 'ms' + assert t3.unit == 'us' + assert t4.unit == 'ns' + + assert str(t1) == 'time32[s]' + assert str(t4) == 'time64[ns]' + + with pytest.raises(ValueError): + pa.time32('us') + + with pytest.raises(ValueError): + pa.time64('s') + + +def test_from_numpy_dtype(): + cases = [ + (np.dtype('bool'), pa.bool_()), + (np.dtype('int8'), pa.int8()), + (np.dtype('int16'), pa.int16()), + (np.dtype('int32'), pa.int32()), + (np.dtype('int64'), pa.int64()), + (np.dtype('uint8'), pa.uint8()), + (np.dtype('uint16'), pa.uint16()), + (np.dtype('uint32'), pa.uint32()), + (np.dtype('float16'), pa.float16()), + (np.dtype('float32'), pa.float32()), + (np.dtype('float64'), pa.float64()), + (np.dtype('U'), pa.string()), + (np.dtype('S'), pa.binary()), + (np.dtype('datetime64[s]'), pa.timestamp('s')), + (np.dtype('datetime64[ms]'), pa.timestamp('ms')), + (np.dtype('datetime64[us]'), pa.timestamp('us')), + (np.dtype('datetime64[ns]'), pa.timestamp('ns')), + (np.dtype('timedelta64[s]'), pa.duration('s')), + (np.dtype('timedelta64[ms]'), pa.duration('ms')), + (np.dtype('timedelta64[us]'), pa.duration('us')), + (np.dtype('timedelta64[ns]'), pa.duration('ns')), + ] + + for dt, pt in cases: + result = pa.from_numpy_dtype(dt) + assert result == pt + + # Things convertible to numpy dtypes work + assert pa.from_numpy_dtype('U') == pa.string() + assert pa.from_numpy_dtype(np.str_) == pa.string() + assert pa.from_numpy_dtype('int32') == pa.int32() + assert pa.from_numpy_dtype(bool) == pa.bool_() + + with pytest.raises(NotImplementedError): + pa.from_numpy_dtype(np.dtype('O')) + + with pytest.raises(TypeError): + pa.from_numpy_dtype('not_convertible_to_dtype') + + +def test_schema(): + fields = [ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('baz', pa.list_(pa.int8())) + ] + sch = pa.schema(fields) + + assert sch.names == ['foo', 'bar', 'baz'] + assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())] + + assert len(sch) == 3 + assert sch[0].name == 'foo' + assert sch[0].type == fields[0].type + assert sch.field('foo').name == 'foo' + assert sch.field('foo').type == fields[0].type + + assert repr(sch) == """\ +foo: int32 +bar: string +baz: list<item: int8> + child 0, item: int8""" + + with pytest.raises(TypeError): + pa.schema([None]) + + +def test_schema_weakref(): + fields = [ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('baz', pa.list_(pa.int8())) + ] + schema = pa.schema(fields) + wr = weakref.ref(schema) + assert wr() is not None + del schema + assert wr() is None + + +def test_schema_to_string_with_metadata(): + lorem = """\ +Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nulla accumsan vel +turpis et mollis. Aliquam tincidunt arcu id tortor blandit blandit. Donec +eget leo quis lectus scelerisque varius. Class aptent taciti sociosqu ad +litora torquent per conubia nostra, per inceptos himenaeos. Praesent +faucibus, diam eu volutpat iaculis, tellus est porta ligula, a efficitur +turpis nulla facilisis quam. Aliquam vitae lorem erat. Proin a dolor ac libero +dignissim mollis vitae eu mauris. Quisque posuere tellus vitae massa +pellentesque sagittis. Aenean feugiat, diam ac dignissim fermentum, lorem +sapien commodo massa, vel volutpat orci nisi eu justo. Nulla non blandit +sapien. Quisque pretium vestibulum urna eu vehicula.""" + # ARROW-7063 + my_schema = pa.schema([pa.field("foo", "int32", False, + metadata={"key1": "value1"}), + pa.field("bar", "string", True, + metadata={"key3": "value3"})], + metadata={"lorem": lorem}) + + assert my_schema.to_string() == """\ +foo: int32 not null + -- field metadata -- + key1: 'value1' +bar: string + -- field metadata -- + key3: 'value3' +-- schema metadata -- +lorem: '""" + lorem[:65] + "' + " + str(len(lorem) - 65) + + # Metadata that exactly fits + result = pa.schema([('f0', 'int32')], + metadata={'key': 'value' + 'x' * 62}).to_string() + assert result == """\ +f0: int32 +-- schema metadata -- +key: 'valuexxxxxxxxxxxxxxxxxxxxxxxxxxxxx\ +xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'""" + + assert my_schema.to_string(truncate_metadata=False) == """\ +foo: int32 not null + -- field metadata -- + key1: 'value1' +bar: string + -- field metadata -- + key3: 'value3' +-- schema metadata -- +lorem: '{}'""".format(lorem) + + assert my_schema.to_string(truncate_metadata=False, + show_field_metadata=False) == """\ +foo: int32 not null +bar: string +-- schema metadata -- +lorem: '{}'""".format(lorem) + + assert my_schema.to_string(truncate_metadata=False, + show_schema_metadata=False) == """\ +foo: int32 not null + -- field metadata -- + key1: 'value1' +bar: string + -- field metadata -- + key3: 'value3'""" + + assert my_schema.to_string(truncate_metadata=False, + show_field_metadata=False, + show_schema_metadata=False) == """\ +foo: int32 not null +bar: string""" + + +def test_schema_from_tuples(): + fields = [ + ('foo', pa.int32()), + ('bar', pa.string()), + ('baz', pa.list_(pa.int8())), + ] + sch = pa.schema(fields) + assert sch.names == ['foo', 'bar', 'baz'] + assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())] + assert len(sch) == 3 + assert repr(sch) == """\ +foo: int32 +bar: string +baz: list<item: int8> + child 0, item: int8""" + + with pytest.raises(TypeError): + pa.schema([('foo', None)]) + + +def test_schema_from_mapping(): + fields = OrderedDict([ + ('foo', pa.int32()), + ('bar', pa.string()), + ('baz', pa.list_(pa.int8())), + ]) + sch = pa.schema(fields) + assert sch.names == ['foo', 'bar', 'baz'] + assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())] + assert len(sch) == 3 + assert repr(sch) == """\ +foo: int32 +bar: string +baz: list<item: int8> + child 0, item: int8""" + + fields = OrderedDict([('foo', None)]) + with pytest.raises(TypeError): + pa.schema(fields) + + +def test_schema_duplicate_fields(): + fields = [ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('foo', pa.list_(pa.int8())), + ] + sch = pa.schema(fields) + assert sch.names == ['foo', 'bar', 'foo'] + assert sch.types == [pa.int32(), pa.string(), pa.list_(pa.int8())] + assert len(sch) == 3 + assert repr(sch) == """\ +foo: int32 +bar: string +foo: list<item: int8> + child 0, item: int8""" + + assert sch[0].name == 'foo' + assert sch[0].type == fields[0].type + with pytest.warns(FutureWarning): + assert sch.field_by_name('bar') == fields[1] + with pytest.warns(FutureWarning): + assert sch.field_by_name('xxx') is None + with pytest.warns((UserWarning, FutureWarning)): + assert sch.field_by_name('foo') is None + + # Schema::GetFieldIndex + assert sch.get_field_index('foo') == -1 + + # Schema::GetAllFieldIndices + assert sch.get_all_field_indices('foo') == [0, 2] + + +def test_field_flatten(): + f0 = pa.field('foo', pa.int32()).with_metadata({b'foo': b'bar'}) + assert f0.flatten() == [f0] + + f1 = pa.field('bar', pa.float64(), nullable=False) + ff = pa.field('ff', pa.struct([f0, f1]), nullable=False) + assert ff.flatten() == [ + pa.field('ff.foo', pa.int32()).with_metadata({b'foo': b'bar'}), + pa.field('ff.bar', pa.float64(), nullable=False)] # XXX + + # Nullable parent makes flattened child nullable + ff = pa.field('ff', pa.struct([f0, f1])) + assert ff.flatten() == [ + pa.field('ff.foo', pa.int32()).with_metadata({b'foo': b'bar'}), + pa.field('ff.bar', pa.float64())] + + fff = pa.field('fff', pa.struct([ff])) + assert fff.flatten() == [pa.field('fff.ff', pa.struct([f0, f1]))] + + +def test_schema_add_remove_metadata(): + fields = [ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('baz', pa.list_(pa.int8())) + ] + + s1 = pa.schema(fields) + + assert s1.metadata is None + + metadata = {b'foo': b'bar', b'pandas': b'badger'} + + s2 = s1.with_metadata(metadata) + assert s2.metadata == metadata + + s3 = s2.remove_metadata() + assert s3.metadata is None + + # idempotent + s4 = s3.remove_metadata() + assert s4.metadata is None + + +def test_schema_equals(): + fields = [ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('baz', pa.list_(pa.int8())) + ] + metadata = {b'foo': b'bar', b'pandas': b'badger'} + + sch1 = pa.schema(fields) + sch2 = pa.schema(fields) + sch3 = pa.schema(fields, metadata=metadata) + sch4 = pa.schema(fields, metadata=metadata) + + assert sch1.equals(sch2, check_metadata=True) + assert sch3.equals(sch4, check_metadata=True) + assert sch1.equals(sch3) + assert not sch1.equals(sch3, check_metadata=True) + assert not sch1.equals(sch3, check_metadata=True) + + del fields[-1] + sch3 = pa.schema(fields) + assert not sch1.equals(sch3) + + +def test_schema_equals_propagates_check_metadata(): + # ARROW-4088 + schema1 = pa.schema([ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()) + ]) + schema2 = pa.schema([ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string(), metadata={'a': 'alpha'}), + ]) + assert not schema1.equals(schema2, check_metadata=True) + assert schema1.equals(schema2) + + +def test_schema_equals_invalid_type(): + # ARROW-5873 + schema = pa.schema([pa.field("a", pa.int64())]) + + for val in [None, 'string', pa.array([1, 2])]: + with pytest.raises(TypeError): + schema.equals(val) + + +def test_schema_equality_operators(): + fields = [ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('baz', pa.list_(pa.int8())) + ] + metadata = {b'foo': b'bar', b'pandas': b'badger'} + + sch1 = pa.schema(fields) + sch2 = pa.schema(fields) + sch3 = pa.schema(fields, metadata=metadata) + sch4 = pa.schema(fields, metadata=metadata) + + assert sch1 == sch2 + assert sch3 == sch4 + + # __eq__ and __ne__ do not check metadata + assert sch1 == sch3 + assert not sch1 != sch3 + + assert sch2 == sch4 + + # comparison with other types doesn't raise + assert sch1 != [] + assert sch3 != 'foo' + + +def test_schema_get_fields(): + fields = [ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('baz', pa.list_(pa.int8())) + ] + + schema = pa.schema(fields) + + assert schema.field('foo').name == 'foo' + assert schema.field(0).name == 'foo' + assert schema.field(-1).name == 'baz' + + with pytest.raises(KeyError): + schema.field('other') + with pytest.raises(TypeError): + schema.field(0.0) + with pytest.raises(IndexError): + schema.field(4) + + +def test_schema_negative_indexing(): + fields = [ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('baz', pa.list_(pa.int8())) + ] + + schema = pa.schema(fields) + + assert schema[-1].equals(schema[2]) + assert schema[-2].equals(schema[1]) + assert schema[-3].equals(schema[0]) + + with pytest.raises(IndexError): + schema[-4] + + with pytest.raises(IndexError): + schema[3] + + +def test_schema_repr_with_dictionaries(): + fields = [ + pa.field('one', pa.dictionary(pa.int16(), pa.string())), + pa.field('two', pa.int32()) + ] + sch = pa.schema(fields) + + expected = ( + """\ +one: dictionary<values=string, indices=int16, ordered=0> +two: int32""") + + assert repr(sch) == expected + + +def test_type_schema_pickling(): + cases = [ + pa.int8(), + pa.string(), + pa.binary(), + pa.binary(10), + pa.list_(pa.string()), + pa.map_(pa.string(), pa.int8()), + pa.struct([ + pa.field('a', 'int8'), + pa.field('b', 'string') + ]), + pa.union([ + pa.field('a', pa.int8()), + pa.field('b', pa.int16()) + ], pa.lib.UnionMode_SPARSE), + pa.union([ + pa.field('a', pa.int8()), + pa.field('b', pa.int16()) + ], pa.lib.UnionMode_DENSE), + pa.time32('s'), + pa.time64('us'), + pa.date32(), + pa.date64(), + pa.timestamp('ms'), + pa.timestamp('ns'), + pa.decimal128(12, 2), + pa.decimal256(76, 38), + pa.field('a', 'string', metadata={b'foo': b'bar'}), + pa.list_(pa.field("element", pa.int64())), + pa.large_list(pa.field("element", pa.int64())), + pa.map_(pa.field("key", pa.string(), nullable=False), + pa.field("value", pa.int8())) + ] + + for val in cases: + roundtripped = pickle.loads(pickle.dumps(val)) + assert val == roundtripped + + fields = [] + for i, f in enumerate(cases): + if isinstance(f, pa.Field): + fields.append(f) + else: + fields.append(pa.field('_f{}'.format(i), f)) + + schema = pa.schema(fields, metadata={b'foo': b'bar'}) + roundtripped = pickle.loads(pickle.dumps(schema)) + assert schema == roundtripped + + +def test_empty_table(): + schema1 = pa.schema([ + pa.field('f0', pa.int64()), + pa.field('f1', pa.dictionary(pa.int32(), pa.string())), + pa.field('f2', pa.list_(pa.list_(pa.int64()))), + ]) + # test it preserves field nullability + schema2 = pa.schema([ + pa.field('a', pa.int64(), nullable=False), + pa.field('b', pa.int64()) + ]) + + for schema in [schema1, schema2]: + table = schema.empty_table() + assert isinstance(table, pa.Table) + assert table.num_rows == 0 + assert table.schema == schema + + +@pytest.mark.pandas +def test_schema_from_pandas(): + import pandas as pd + inputs = [ + list(range(10)), + pd.Categorical(list(range(10))), + ['foo', 'bar', None, 'baz', 'qux'], + np.array([ + '2007-07-13T01:23:34.123456789', + '2006-01-13T12:34:56.432539784', + '2010-08-13T05:46:57.437699912' + ], dtype='datetime64[ns]'), + ] + if Version(pd.__version__) >= Version('1.0.0'): + inputs.append(pd.array([1, 2, None], dtype=pd.Int32Dtype())) + for data in inputs: + df = pd.DataFrame({'a': data}) + schema = pa.Schema.from_pandas(df) + expected = pa.Table.from_pandas(df).schema + assert schema == expected + + +def test_schema_sizeof(): + schema = pa.schema([ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + ]) + + assert sys.getsizeof(schema) > 30 + + schema2 = schema.with_metadata({"key": "some metadata"}) + assert sys.getsizeof(schema2) > sys.getsizeof(schema) + schema3 = schema.with_metadata({"key": "some more metadata"}) + assert sys.getsizeof(schema3) > sys.getsizeof(schema2) + + +def test_schema_merge(): + a = pa.schema([ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('baz', pa.list_(pa.int8())) + ]) + b = pa.schema([ + pa.field('foo', pa.int32()), + pa.field('qux', pa.bool_()) + ]) + c = pa.schema([ + pa.field('quux', pa.dictionary(pa.int32(), pa.string())) + ]) + d = pa.schema([ + pa.field('foo', pa.int64()), + pa.field('qux', pa.bool_()) + ]) + + result = pa.unify_schemas([a, b, c]) + expected = pa.schema([ + pa.field('foo', pa.int32()), + pa.field('bar', pa.string()), + pa.field('baz', pa.list_(pa.int8())), + pa.field('qux', pa.bool_()), + pa.field('quux', pa.dictionary(pa.int32(), pa.string())) + ]) + assert result.equals(expected) + + with pytest.raises(pa.ArrowInvalid): + pa.unify_schemas([b, d]) + + # ARROW-14002: Try with tuple instead of list + result = pa.unify_schemas((a, b, c)) + assert result.equals(expected) + + +def test_undecodable_metadata(): + # ARROW-10214: undecodable metadata shouldn't fail repr() + data1 = b'abcdef\xff\x00' + data2 = b'ghijkl\xff\x00' + schema = pa.schema( + [pa.field('ints', pa.int16(), metadata={'key': data1})], + metadata={'key': data2}) + assert 'abcdef' in str(schema) + assert 'ghijkl' in str(schema) diff --git a/src/arrow/python/pyarrow/tests/test_serialization.py b/src/arrow/python/pyarrow/tests/test_serialization.py new file mode 100644 index 000000000..750827a50 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_serialization.py @@ -0,0 +1,1233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import pytest + +import collections +import datetime +import os +import pathlib +import pickle +import subprocess +import string +import sys + +import pyarrow as pa +import numpy as np + +import pyarrow.tests.util as test_util + +try: + import torch +except ImportError: + torch = None + # Blacklist the module in case `import torch` is costly before + # failing (ARROW-2071) + sys.modules['torch'] = None + +try: + from scipy.sparse import coo_matrix, csr_matrix, csc_matrix +except ImportError: + coo_matrix = None + csr_matrix = None + csc_matrix = None + +try: + import sparse +except ImportError: + sparse = None + + +# ignore all serialization deprecation warnings in this file, we test that the +# warnings are actually raised in test_serialization_deprecated.py +pytestmark = pytest.mark.filterwarnings("ignore:'pyarrow:FutureWarning") + + +def assert_equal(obj1, obj2): + if torch is not None and torch.is_tensor(obj1) and torch.is_tensor(obj2): + if obj1.is_sparse: + obj1 = obj1.to_dense() + if obj2.is_sparse: + obj2 = obj2.to_dense() + assert torch.equal(obj1, obj2) + return + module_numpy = (type(obj1).__module__ == np.__name__ or + type(obj2).__module__ == np.__name__) + if module_numpy: + empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or + (hasattr(obj2, "shape") and obj2.shape == ())) + if empty_shape: + # This is a special case because currently np.testing.assert_equal + # fails because we do not properly handle different numerical + # types. + assert obj1 == obj2, ("Objects {} and {} are " + "different.".format(obj1, obj2)) + else: + np.testing.assert_equal(obj1, obj2) + elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"): + special_keys = ["_pytype_"] + assert (set(list(obj1.__dict__.keys()) + special_keys) == + set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} " + "and {} are " + "different." + .format( + obj1, + obj2)) + if obj1.__dict__ == {}: + print("WARNING: Empty dict in ", obj1) + for key in obj1.__dict__.keys(): + if key not in special_keys: + assert_equal(obj1.__dict__[key], obj2.__dict__[key]) + elif type(obj1) is dict or type(obj2) is dict: + assert_equal(obj1.keys(), obj2.keys()) + for key in obj1.keys(): + assert_equal(obj1[key], obj2[key]) + elif type(obj1) is list or type(obj2) is list: + assert len(obj1) == len(obj2), ("Objects {} and {} are lists with " + "different lengths." + .format(obj1, obj2)) + for i in range(len(obj1)): + assert_equal(obj1[i], obj2[i]) + elif type(obj1) is tuple or type(obj2) is tuple: + assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with " + "different lengths." + .format(obj1, obj2)) + for i in range(len(obj1)): + assert_equal(obj1[i], obj2[i]) + elif (pa.lib.is_named_tuple(type(obj1)) or + pa.lib.is_named_tuple(type(obj2))): + assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples " + "with different lengths." + .format(obj1, obj2)) + for i in range(len(obj1)): + assert_equal(obj1[i], obj2[i]) + elif isinstance(obj1, pa.Array) and isinstance(obj2, pa.Array): + assert obj1.equals(obj2) + elif isinstance(obj1, pa.Tensor) and isinstance(obj2, pa.Tensor): + assert obj1.equals(obj2) + elif isinstance(obj1, pa.Tensor) and isinstance(obj2, pa.Tensor): + assert obj1.equals(obj2) + elif isinstance(obj1, pa.SparseCOOTensor) and \ + isinstance(obj2, pa.SparseCOOTensor): + assert obj1.equals(obj2) + elif isinstance(obj1, pa.SparseCSRMatrix) and \ + isinstance(obj2, pa.SparseCSRMatrix): + assert obj1.equals(obj2) + elif isinstance(obj1, pa.SparseCSCMatrix) and \ + isinstance(obj2, pa.SparseCSCMatrix): + assert obj1.equals(obj2) + elif isinstance(obj1, pa.SparseCSFTensor) and \ + isinstance(obj2, pa.SparseCSFTensor): + assert obj1.equals(obj2) + elif isinstance(obj1, pa.RecordBatch) and isinstance(obj2, pa.RecordBatch): + assert obj1.equals(obj2) + elif isinstance(obj1, pa.Table) and isinstance(obj2, pa.Table): + assert obj1.equals(obj2) + else: + assert type(obj1) == type(obj2) and obj1 == obj2, \ + "Objects {} and {} are different.".format(obj1, obj2) + + +PRIMITIVE_OBJECTS = [ + 0, 0.0, 0.9, 1 << 62, 1 << 999, + [1 << 100, [1 << 100]], "a", string.printable, "\u262F", + "hello world", "hello world", "\xff\xfe\x9c\x001\x000\x00", + None, True, False, [], (), {}, {(1, 2): 1}, {(): 2}, + [1, "hello", 3.0], "\u262F", 42.0, (1.0, "hi"), + [1, 2, 3, None], [(None,), 3, 1.0], ["h", "e", "l", "l", "o", None], + (None, None), ("hello", None), (True, False), + {True: "hello", False: "world"}, {"hello": "world", 1: 42, 2.5: 45}, + {"hello": {2, 3}, "world": {42.0}, "this": None}, + np.int8(3), np.int32(4), np.int64(5), + np.uint8(3), np.uint32(4), np.uint64(5), + np.float16(1.9), np.float32(1.9), + np.float64(1.9), np.zeros([8, 20]), + np.random.normal(size=[17, 10]), np.array(["hi", 3]), + np.array(["hi", 3], dtype=object), + np.random.normal(size=[15, 13]).T +] + + +index_types = ('i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8') +tensor_types = ('i1', 'i2', 'i4', 'i8', 'u1', 'u2', 'u4', 'u8', + 'f2', 'f4', 'f8') + +PRIMITIVE_OBJECTS += [0, np.array([["hi", "hi"], [1.3, 1]])] + + +COMPLEX_OBJECTS = [ + [[[[[[[[[[[[]]]]]]]]]]]], + {"obj{}".format(i): np.random.normal(size=[4, 4]) for i in range(5)}, + # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): { + # (): {(): {}}}}}}}}}}}}}, + ((((((((((),),),),),),),),),), + {"a": {"b": {"c": {"d": {}}}}}, +] + + +class Foo: + def __init__(self, value=0): + self.value = value + + def __hash__(self): + return hash(self.value) + + def __eq__(self, other): + return other.value == self.value + + +class Bar: + def __init__(self): + for i, val in enumerate(COMPLEX_OBJECTS): + setattr(self, "field{}".format(i), val) + + +class Baz: + def __init__(self): + self.foo = Foo() + self.bar = Bar() + + def method(self, arg): + pass + + +class Qux: + def __init__(self): + self.objs = [Foo(1), Foo(42)] + + +class SubQux(Qux): + def __init__(self): + Qux.__init__(self) + + +class SubQuxPickle(Qux): + def __init__(self): + Qux.__init__(self) + + +class CustomError(Exception): + pass + + +Point = collections.namedtuple("Point", ["x", "y"]) +NamedTupleExample = collections.namedtuple( + "Example", "field1, field2, field3, field4, field5") + + +CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22), + Foo(), Bar(), Baz(), Qux(), SubQux(), SubQuxPickle(), + NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]), + collections.OrderedDict([("hello", 1), ("world", 2)]), + collections.deque([1, 2, 3, "a", "b", "c", 3.5]), + collections.Counter([1, 1, 1, 2, 2, 3, "a", "b"])] + + +def make_serialization_context(): + with pytest.warns(FutureWarning): + context = pa.default_serialization_context() + + context.register_type(Foo, "Foo") + context.register_type(Bar, "Bar") + context.register_type(Baz, "Baz") + context.register_type(Qux, "Quz") + context.register_type(SubQux, "SubQux") + context.register_type(SubQuxPickle, "SubQuxPickle", pickle=True) + context.register_type(Exception, "Exception") + context.register_type(CustomError, "CustomError") + context.register_type(Point, "Point") + context.register_type(NamedTupleExample, "NamedTupleExample") + + return context + + +global_serialization_context = make_serialization_context() + + +def serialization_roundtrip(value, scratch_buffer, + context=global_serialization_context): + writer = pa.FixedSizeBufferWriter(scratch_buffer) + pa.serialize_to(value, writer, context=context) + + reader = pa.BufferReader(scratch_buffer) + result = pa.deserialize_from(reader, None, context=context) + assert_equal(value, result) + + _check_component_roundtrip(value, context=context) + + +def _check_component_roundtrip(value, context=global_serialization_context): + # Test to/from components + serialized = pa.serialize(value, context=context) + components = serialized.to_components() + from_comp = pa.SerializedPyObject.from_components(components) + recons = from_comp.deserialize(context=context) + assert_equal(value, recons) + + +@pytest.fixture(scope='session') +def large_buffer(size=32*1024*1024): + yield pa.allocate_buffer(size) + + +def large_memory_map(tmpdir_factory, size=100*1024*1024): + path = (tmpdir_factory.mktemp('data') + .join('pyarrow-serialization-tmp-file').strpath) + + # Create a large memory mapped file + with open(path, 'wb') as f: + f.write(np.random.randint(0, 256, size=size) + .astype('u1') + .tobytes() + [:size]) + return path + + +def test_clone(): + context = pa.SerializationContext() + + class Foo: + pass + + def custom_serializer(obj): + return 0 + + def custom_deserializer(serialized_obj): + return (serialized_obj, 'a') + + context.register_type(Foo, 'Foo', custom_serializer=custom_serializer, + custom_deserializer=custom_deserializer) + + new_context = context.clone() + + f = Foo() + serialized = pa.serialize(f, context=context) + deserialized = serialized.deserialize(context=context) + assert deserialized == (0, 'a') + + serialized = pa.serialize(f, context=new_context) + deserialized = serialized.deserialize(context=new_context) + assert deserialized == (0, 'a') + + +def test_primitive_serialization_notbroken(large_buffer): + serialization_roundtrip({(1, 2): 2}, large_buffer) + + +def test_primitive_serialization_broken(large_buffer): + serialization_roundtrip({(): 2}, large_buffer) + + +def test_primitive_serialization(large_buffer): + for obj in PRIMITIVE_OBJECTS: + serialization_roundtrip(obj, large_buffer) + + +def test_integer_limits(large_buffer): + # Check that Numpy scalars can be represented up to their limit values + # (except np.uint64 which is limited to 2**63 - 1) + for dt in [np.int8, np.int64, np.int32, np.int64, + np.uint8, np.uint64, np.uint32, np.uint64]: + scal = dt(np.iinfo(dt).min) + serialization_roundtrip(scal, large_buffer) + if dt is not np.uint64: + scal = dt(np.iinfo(dt).max) + serialization_roundtrip(scal, large_buffer) + else: + scal = dt(2**63 - 1) + serialization_roundtrip(scal, large_buffer) + for v in (2**63, 2**64 - 1): + scal = dt(v) + with pytest.raises(pa.ArrowInvalid): + pa.serialize(scal) + + +def test_serialize_to_buffer(): + for nthreads in [1, 4]: + for value in COMPLEX_OBJECTS: + buf = pa.serialize(value).to_buffer(nthreads=nthreads) + result = pa.deserialize(buf) + assert_equal(value, result) + + +def test_complex_serialization(large_buffer): + for obj in COMPLEX_OBJECTS: + serialization_roundtrip(obj, large_buffer) + + +def test_custom_serialization(large_buffer): + for obj in CUSTOM_OBJECTS: + serialization_roundtrip(obj, large_buffer) + + +def test_default_dict_serialization(large_buffer): + pytest.importorskip("cloudpickle") + + obj = collections.defaultdict(lambda: 0, [("hello", 1), ("world", 2)]) + serialization_roundtrip(obj, large_buffer) + + +def test_numpy_serialization(large_buffer): + for t in ["bool", "int8", "uint8", "int16", "uint16", "int32", + "uint32", "float16", "float32", "float64", "<U1", "<U2", "<U3", + "<U4", "|S1", "|S2", "|S3", "|S4", "|O", + np.dtype([('a', 'int64'), ('b', 'float')]), + np.dtype([('x', 'uint32'), ('y', '<U8')])]: + obj = np.random.randint(0, 10, size=(100, 100)).astype(t) + serialization_roundtrip(obj, large_buffer) + obj = obj[1:99, 10:90] + serialization_roundtrip(obj, large_buffer) + + +def test_datetime_serialization(large_buffer): + data = [ + # Principia Mathematica published + datetime.datetime(year=1687, month=7, day=5), + + # Some random date + datetime.datetime(year=1911, month=6, day=3, hour=4, + minute=55, second=44), + # End of WWI + datetime.datetime(year=1918, month=11, day=11), + + # Beginning of UNIX time + datetime.datetime(year=1970, month=1, day=1), + + # The Berlin wall falls + datetime.datetime(year=1989, month=11, day=9), + + # Another random date + datetime.datetime(year=2011, month=6, day=3, hour=4, + minute=0, second=3), + # Another random date + datetime.datetime(year=1970, month=1, day=3, hour=4, + minute=0, second=0) + ] + for d in data: + serialization_roundtrip(d, large_buffer) + + +def test_torch_serialization(large_buffer): + pytest.importorskip("torch") + + serialization_context = pa.default_serialization_context() + pa.register_torch_serialization_handlers(serialization_context) + + # Dense tensors: + + # These are the only types that are supported for the + # PyTorch to NumPy conversion + for t in ["float32", "float64", + "uint8", "int16", "int32", "int64"]: + obj = torch.from_numpy(np.random.randn(1000).astype(t)) + serialization_roundtrip(obj, large_buffer, + context=serialization_context) + + tensor_requiring_grad = torch.randn(10, 10, requires_grad=True) + serialization_roundtrip(tensor_requiring_grad, large_buffer, + context=serialization_context) + + # Sparse tensors: + + # These are the only types that are supported for the + # PyTorch to NumPy conversion + for t in ["float32", "float64", + "uint8", "int16", "int32", "int64"]: + i = torch.LongTensor([[0, 2], [1, 0], [1, 2]]) + v = torch.from_numpy(np.array([3, 4, 5]).astype(t)) + obj = torch.sparse_coo_tensor(i.t(), v, torch.Size([2, 3])) + serialization_roundtrip(obj, large_buffer, + context=serialization_context) + + +@pytest.mark.skipif(not torch or not torch.cuda.is_available(), + reason="requires pytorch with CUDA") +def test_torch_cuda(): + # ARROW-2920: This used to segfault if torch is not imported + # before pyarrow + # Note that this test will only catch the issue if it is run + # with a pyarrow that has been built in the manylinux1 environment + torch.nn.Conv2d(64, 2, kernel_size=3, stride=1, + padding=1, bias=False).cuda() + + +def test_numpy_immutable(large_buffer): + obj = np.zeros([10]) + + writer = pa.FixedSizeBufferWriter(large_buffer) + pa.serialize_to(obj, writer, global_serialization_context) + + reader = pa.BufferReader(large_buffer) + result = pa.deserialize_from(reader, None, global_serialization_context) + with pytest.raises(ValueError): + result[0] = 1.0 + + +def test_numpy_base_object(tmpdir): + # ARROW-2040: deserialized Numpy array should keep a reference to the + # owner of its memory + path = os.path.join(str(tmpdir), 'zzz.bin') + data = np.arange(12, dtype=np.int32) + + with open(path, 'wb') as f: + f.write(pa.serialize(data).to_buffer()) + + serialized = pa.read_serialized(pa.OSFile(path)) + result = serialized.deserialize() + assert_equal(result, data) + serialized = None + assert_equal(result, data) + assert result.base is not None + + +# see https://issues.apache.org/jira/browse/ARROW-1695 +def test_serialization_callback_numpy(): + + class DummyClass: + pass + + def serialize_dummy_class(obj): + x = np.zeros(4) + return x + + def deserialize_dummy_class(serialized_obj): + return serialized_obj + + context = pa.default_serialization_context() + context.register_type(DummyClass, "DummyClass", + custom_serializer=serialize_dummy_class, + custom_deserializer=deserialize_dummy_class) + + pa.serialize(DummyClass(), context=context) + + +def test_numpy_subclass_serialization(): + # Check that we can properly serialize subclasses of np.ndarray. + class CustomNDArray(np.ndarray): + def __new__(cls, input_array): + array = np.asarray(input_array).view(cls) + return array + + def serializer(obj): + return {'numpy': obj.view(np.ndarray)} + + def deserializer(data): + array = data['numpy'].view(CustomNDArray) + return array + + context = pa.default_serialization_context() + + context.register_type(CustomNDArray, 'CustomNDArray', + custom_serializer=serializer, + custom_deserializer=deserializer) + + x = CustomNDArray(np.zeros(3)) + serialized = pa.serialize(x, context=context).to_buffer() + new_x = pa.deserialize(serialized, context=context) + assert type(new_x) == CustomNDArray + assert np.alltrue(new_x.view(np.ndarray) == np.zeros(3)) + + +@pytest.mark.parametrize('tensor_type', tensor_types) +@pytest.mark.parametrize('index_type', index_types) +def test_sparse_coo_tensor_serialization(index_type, tensor_type): + tensor_dtype = np.dtype(tensor_type) + index_dtype = np.dtype(index_type) + data = np.array([[1, 2, 3, 4, 5, 6]]).T.astype(tensor_dtype) + coords = np.array([ + [0, 0, 2, 3, 1, 3], + [0, 2, 0, 4, 5, 5], + ]).T.astype(index_dtype) + shape = (4, 6) + dim_names = ('x', 'y') + + sparse_tensor = pa.SparseCOOTensor.from_numpy(data, coords, + shape, dim_names) + + context = pa.default_serialization_context() + serialized = pa.serialize(sparse_tensor, context=context).to_buffer() + result = pa.deserialize(serialized) + assert_equal(result, sparse_tensor) + assert isinstance(result, pa.SparseCOOTensor) + + data_result, coords_result = result.to_numpy() + assert np.array_equal(data_result, data) + assert np.array_equal(coords_result, coords) + assert result.dim_names == dim_names + + +@pytest.mark.parametrize('tensor_type', tensor_types) +@pytest.mark.parametrize('index_type', index_types) +def test_sparse_coo_tensor_components_serialization(large_buffer, + index_type, tensor_type): + tensor_dtype = np.dtype(tensor_type) + index_dtype = np.dtype(index_type) + data = np.array([[1, 2, 3, 4, 5, 6]]).T.astype(tensor_dtype) + coords = np.array([ + [0, 0, 2, 3, 1, 3], + [0, 2, 0, 4, 5, 5], + ]).T.astype(index_dtype) + shape = (4, 6) + dim_names = ('x', 'y') + + sparse_tensor = pa.SparseCOOTensor.from_numpy(data, coords, + shape, dim_names) + serialization_roundtrip(sparse_tensor, large_buffer) + + +@pytest.mark.skipif(not coo_matrix, reason="requires scipy") +def test_scipy_sparse_coo_tensor_serialization(): + data = np.array([1, 2, 3, 4, 5, 6]) + row = np.array([0, 0, 2, 3, 1, 3]) + col = np.array([0, 2, 0, 4, 5, 5]) + shape = (4, 6) + + sparse_array = coo_matrix((data, (row, col)), shape=shape) + serialized = pa.serialize(sparse_array) + result = serialized.deserialize() + + assert np.array_equal(sparse_array.toarray(), result.toarray()) + + +@pytest.mark.skipif(not sparse, reason="requires pydata/sparse") +def test_pydata_sparse_sparse_coo_tensor_serialization(): + data = np.array([1, 2, 3, 4, 5, 6]) + coords = np.array([ + [0, 0, 2, 3, 1, 3], + [0, 2, 0, 4, 5, 5], + ]) + shape = (4, 6) + + sparse_array = sparse.COO(data=data, coords=coords, shape=shape) + serialized = pa.serialize(sparse_array) + result = serialized.deserialize() + + assert np.array_equal(sparse_array.todense(), result.todense()) + + +@pytest.mark.parametrize('tensor_type', tensor_types) +@pytest.mark.parametrize('index_type', index_types) +def test_sparse_csr_matrix_serialization(index_type, tensor_type): + tensor_dtype = np.dtype(tensor_type) + index_dtype = np.dtype(index_type) + data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(tensor_dtype) + indptr = np.array([0, 2, 3, 4, 6]).astype(index_dtype) + indices = np.array([0, 2, 5, 0, 4, 5]).astype(index_dtype) + shape = (4, 6) + dim_names = ('x', 'y') + + sparse_tensor = pa.SparseCSRMatrix.from_numpy(data, indptr, indices, + shape, dim_names) + + context = pa.default_serialization_context() + serialized = pa.serialize(sparse_tensor, context=context).to_buffer() + result = pa.deserialize(serialized) + assert_equal(result, sparse_tensor) + assert isinstance(result, pa.SparseCSRMatrix) + + data_result, indptr_result, indices_result = result.to_numpy() + assert np.array_equal(data_result, data) + assert np.array_equal(indptr_result, indptr) + assert np.array_equal(indices_result, indices) + assert result.dim_names == dim_names + + +@pytest.mark.parametrize('tensor_type', tensor_types) +@pytest.mark.parametrize('index_type', index_types) +def test_sparse_csr_matrix_components_serialization(large_buffer, + index_type, tensor_type): + tensor_dtype = np.dtype(tensor_type) + index_dtype = np.dtype(index_type) + data = np.array([8, 2, 5, 3, 4, 6]).astype(tensor_dtype) + indptr = np.array([0, 2, 3, 4, 6]).astype(index_dtype) + indices = np.array([0, 2, 5, 0, 4, 5]).astype(index_dtype) + shape = (4, 6) + dim_names = ('x', 'y') + + sparse_tensor = pa.SparseCSRMatrix.from_numpy(data, indptr, indices, + shape, dim_names) + serialization_roundtrip(sparse_tensor, large_buffer) + + +@pytest.mark.skipif(not csr_matrix, reason="requires scipy") +def test_scipy_sparse_csr_matrix_serialization(): + data = np.array([8, 2, 5, 3, 4, 6]) + indptr = np.array([0, 2, 3, 4, 6]) + indices = np.array([0, 2, 5, 0, 4, 5]) + shape = (4, 6) + + sparse_array = csr_matrix((data, indices, indptr), shape=shape) + serialized = pa.serialize(sparse_array) + result = serialized.deserialize() + + assert np.array_equal(sparse_array.toarray(), result.toarray()) + + +@pytest.mark.parametrize('tensor_type', tensor_types) +@pytest.mark.parametrize('index_type', index_types) +def test_sparse_csc_matrix_serialization(index_type, tensor_type): + tensor_dtype = np.dtype(tensor_type) + index_dtype = np.dtype(index_type) + data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(tensor_dtype) + indptr = np.array([0, 2, 3, 4, 6]).astype(index_dtype) + indices = np.array([0, 2, 5, 0, 4, 5]).astype(index_dtype) + shape = (6, 4) + dim_names = ('x', 'y') + + sparse_tensor = pa.SparseCSCMatrix.from_numpy(data, indptr, indices, + shape, dim_names) + + context = pa.default_serialization_context() + serialized = pa.serialize(sparse_tensor, context=context).to_buffer() + result = pa.deserialize(serialized) + assert_equal(result, sparse_tensor) + assert isinstance(result, pa.SparseCSCMatrix) + + data_result, indptr_result, indices_result = result.to_numpy() + assert np.array_equal(data_result, data) + assert np.array_equal(indptr_result, indptr) + assert np.array_equal(indices_result, indices) + assert result.dim_names == dim_names + + +@pytest.mark.parametrize('tensor_type', tensor_types) +@pytest.mark.parametrize('index_type', index_types) +def test_sparse_csc_matrix_components_serialization(large_buffer, + index_type, tensor_type): + tensor_dtype = np.dtype(tensor_type) + index_dtype = np.dtype(index_type) + data = np.array([8, 2, 5, 3, 4, 6]).astype(tensor_dtype) + indptr = np.array([0, 2, 3, 6]).astype(index_dtype) + indices = np.array([0, 2, 2, 0, 1, 2]).astype(index_dtype) + shape = (3, 3) + dim_names = ('x', 'y') + + sparse_tensor = pa.SparseCSCMatrix.from_numpy(data, indptr, indices, + shape, dim_names) + serialization_roundtrip(sparse_tensor, large_buffer) + + +@pytest.mark.skipif(not csc_matrix, reason="requires scipy") +def test_scipy_sparse_csc_matrix_serialization(): + data = np.array([8, 2, 5, 3, 4, 6]) + indptr = np.array([0, 2, 3, 4, 6]) + indices = np.array([0, 2, 5, 0, 4, 5]) + shape = (6, 4) + + sparse_array = csc_matrix((data, indices, indptr), shape=shape) + serialized = pa.serialize(sparse_array) + result = serialized.deserialize() + + assert np.array_equal(sparse_array.toarray(), result.toarray()) + + +@pytest.mark.parametrize('tensor_type', tensor_types) +@pytest.mark.parametrize('index_type', index_types) +def test_sparse_csf_tensor_serialization(index_type, tensor_type): + tensor_dtype = np.dtype(tensor_type) + index_dtype = np.dtype(index_type) + data = np.array([[1, 2, 3, 4, 5, 6, 7, 8]]).T.astype(tensor_dtype) + indptr = [ + np.array([0, 2, 3]), + np.array([0, 1, 3, 4]), + np.array([0, 2, 4, 5, 8]), + ] + indices = [ + np.array([0, 1]), + np.array([0, 1, 1]), + np.array([0, 0, 1, 1]), + np.array([1, 2, 0, 2, 0, 0, 1, 2]), + ] + indptr = [x.astype(index_dtype) for x in indptr] + indices = [x.astype(index_dtype) for x in indices] + shape = (2, 3, 4, 5) + axis_order = (0, 1, 2, 3) + dim_names = ("a", "b", "c", "d") + + for ndim in [2, 3, 4]: + sparse_tensor = pa.SparseCSFTensor.from_numpy(data, indptr[:ndim - 1], + indices[:ndim], + shape[:ndim], + axis_order[:ndim], + dim_names[:ndim]) + + context = pa.default_serialization_context() + serialized = pa.serialize(sparse_tensor, context=context).to_buffer() + result = pa.deserialize(serialized) + assert_equal(result, sparse_tensor) + assert isinstance(result, pa.SparseCSFTensor) + + +@pytest.mark.parametrize('tensor_type', tensor_types) +@pytest.mark.parametrize('index_type', index_types) +def test_sparse_csf_tensor_components_serialization(large_buffer, + index_type, tensor_type): + tensor_dtype = np.dtype(tensor_type) + index_dtype = np.dtype(index_type) + data = np.array([[1, 2, 3, 4, 5, 6, 7, 8]]).T.astype(tensor_dtype) + indptr = [ + np.array([0, 2, 3]), + np.array([0, 1, 3, 4]), + np.array([0, 2, 4, 5, 8]), + ] + indices = [ + np.array([0, 1]), + np.array([0, 1, 1]), + np.array([0, 0, 1, 1]), + np.array([1, 2, 0, 2, 0, 0, 1, 2]), + ] + indptr = [x.astype(index_dtype) for x in indptr] + indices = [x.astype(index_dtype) for x in indices] + shape = (2, 3, 4, 5) + axis_order = (0, 1, 2, 3) + dim_names = ("a", "b", "c", "d") + + for ndim in [2, 3, 4]: + sparse_tensor = pa.SparseCSFTensor.from_numpy(data, indptr[:ndim - 1], + indices[:ndim], + shape[:ndim], + axis_order[:ndim], + dim_names[:ndim]) + + serialization_roundtrip(sparse_tensor, large_buffer) + + +@pytest.mark.filterwarnings( + "ignore:the matrix subclass:PendingDeprecationWarning") +def test_numpy_matrix_serialization(tmpdir): + class CustomType: + def __init__(self, val): + self.val = val + + rec_type = np.dtype([('x', 'int64'), ('y', 'double'), ('z', '<U4')]) + + path = os.path.join(str(tmpdir), 'pyarrow_npmatrix_serialization_test.bin') + array = np.random.randint(low=-1, high=1, size=(2, 2)) + + for data_type in [str, int, float, rec_type, CustomType]: + matrix = np.matrix(array.astype(data_type)) + + with open(path, 'wb') as f: + f.write(pa.serialize(matrix).to_buffer()) + + serialized = pa.read_serialized(pa.OSFile(path)) + result = serialized.deserialize() + assert_equal(result, matrix) + assert_equal(result.dtype, matrix.dtype) + serialized = None + assert_equal(result, matrix) + assert result.base is not None + + +def test_pyarrow_objects_serialization(large_buffer): + # NOTE: We have to put these objects inside, + # or it will affect 'test_total_bytes_allocated'. + pyarrow_objects = [ + pa.array([1, 2, 3, 4]), pa.array(['1', 'never U+1F631', '', + "233 * U+1F600"]), + pa.array([1, None, 2, 3]), + pa.Tensor.from_numpy(np.random.rand(2, 3, 4)), + pa.RecordBatch.from_arrays( + [pa.array([1, None, 2, 3]), + pa.array(['1', 'never U+1F631', '', "233 * u1F600"])], + ['a', 'b']), + pa.Table.from_arrays([pa.array([1, None, 2, 3]), + pa.array(['1', 'never U+1F631', '', + "233 * u1F600"])], + ['a', 'b']) + ] + for obj in pyarrow_objects: + serialization_roundtrip(obj, large_buffer) + + +def test_buffer_serialization(): + + class BufferClass: + pass + + def serialize_buffer_class(obj): + return pa.py_buffer(b"hello") + + def deserialize_buffer_class(serialized_obj): + return serialized_obj + + context = pa.default_serialization_context() + context.register_type( + BufferClass, "BufferClass", + custom_serializer=serialize_buffer_class, + custom_deserializer=deserialize_buffer_class) + + b = pa.serialize(BufferClass(), context=context).to_buffer() + assert pa.deserialize(b, context=context).to_pybytes() == b"hello" + + +@pytest.mark.skip(reason="extensive memory requirements") +def test_arrow_limits(self): + def huge_memory_map(temp_dir): + return large_memory_map(temp_dir, 100 * 1024 * 1024 * 1024) + + with pa.memory_map(huge_memory_map, mode="r+") as mmap: + # Test that objects that are too large for Arrow throw a Python + # exception. These tests give out of memory errors on Travis and need + # to be run on a machine with lots of RAM. + x = 2 ** 29 * [1.0] + serialization_roundtrip(x, mmap) + del x + x = 2 ** 29 * ["s"] + serialization_roundtrip(x, mmap) + del x + x = 2 ** 29 * [["1"], 2, 3, [{"s": 4}]] + serialization_roundtrip(x, mmap) + del x + x = 2 ** 29 * [{"s": 1}] + 2 ** 29 * [1.0] + serialization_roundtrip(x, mmap) + del x + x = np.zeros(2 ** 25) + serialization_roundtrip(x, mmap) + del x + x = [np.zeros(2 ** 18) for _ in range(2 ** 7)] + serialization_roundtrip(x, mmap) + del x + + +def test_serialization_callback_error(): + + class TempClass: + pass + + # Pass a SerializationContext into serialize, but TempClass + # is not registered + serialization_context = pa.SerializationContext() + val = TempClass() + with pytest.raises(pa.SerializationCallbackError) as err: + serialized_object = pa.serialize(val, serialization_context) + assert err.value.example_object == val + + serialization_context.register_type(TempClass, "TempClass") + serialized_object = pa.serialize(TempClass(), serialization_context) + deserialization_context = pa.SerializationContext() + + # Pass a Serialization Context into deserialize, but TempClass + # is not registered + with pytest.raises(pa.DeserializationCallbackError) as err: + serialized_object.deserialize(deserialization_context) + assert err.value.type_id == "TempClass" + + class TempClass2: + pass + + # Make sure that we receive an error when we use an inappropriate value for + # the type_id argument. + with pytest.raises(TypeError): + serialization_context.register_type(TempClass2, 1) + + +def test_fallback_to_subclasses(): + + class SubFoo(Foo): + def __init__(self): + Foo.__init__(self) + + # should be able to serialize/deserialize an instance + # if a base class has been registered + serialization_context = pa.SerializationContext() + serialization_context.register_type(Foo, "Foo") + + subfoo = SubFoo() + # should fallbact to Foo serializer + serialized_object = pa.serialize(subfoo, serialization_context) + + reconstructed_object = serialized_object.deserialize( + serialization_context + ) + assert type(reconstructed_object) == Foo + + +class Serializable: + pass + + +def serialize_serializable(obj): + return {"type": type(obj), "data": obj.__dict__} + + +def deserialize_serializable(obj): + val = obj["type"].__new__(obj["type"]) + val.__dict__.update(obj["data"]) + return val + + +class SerializableClass(Serializable): + def __init__(self): + self.value = 3 + + +def test_serialize_subclasses(): + + # This test shows how subclasses can be handled in an idiomatic way + # by having only a serializer for the base class + + # This technique should however be used with care, since pickling + # type(obj) with couldpickle will include the full class definition + # in the serialized representation. + # This means the class definition is part of every instance of the + # object, which in general is not desirable; registering all subclasses + # with register_type will result in faster and more memory + # efficient serialization. + + context = pa.default_serialization_context() + context.register_type( + Serializable, "Serializable", + custom_serializer=serialize_serializable, + custom_deserializer=deserialize_serializable) + + a = SerializableClass() + serialized = pa.serialize(a, context=context) + + deserialized = serialized.deserialize(context=context) + assert type(deserialized).__name__ == SerializableClass.__name__ + assert deserialized.value == 3 + + +def test_serialize_to_components_invalid_cases(): + buf = pa.py_buffer(b'hello') + + components = { + 'num_tensors': 0, + 'num_sparse_tensors': { + 'coo': 0, 'csr': 0, 'csc': 0, 'csf': 0, 'ndim_csf': 0 + }, + 'num_ndarrays': 0, + 'num_buffers': 1, + 'data': [buf] + } + + with pytest.raises(pa.ArrowInvalid): + pa.deserialize_components(components) + + components = { + 'num_tensors': 0, + 'num_sparse_tensors': { + 'coo': 0, 'csr': 0, 'csc': 0, 'csf': 0, 'ndim_csf': 0 + }, + 'num_ndarrays': 1, + 'num_buffers': 0, + 'data': [buf, buf] + } + + with pytest.raises(pa.ArrowInvalid): + pa.deserialize_components(components) + + +def test_deserialize_components_in_different_process(): + arr = pa.array([1, 2, 5, 6], type=pa.int8()) + ser = pa.serialize(arr) + data = pickle.dumps(ser.to_components(), protocol=-1) + + code = """if 1: + import pickle + + import pyarrow as pa + + data = {!r} + components = pickle.loads(data) + arr = pa.deserialize_components(components) + + assert arr.to_pylist() == [1, 2, 5, 6], arr + """.format(data) + + subprocess_env = test_util.get_modified_env_with_pythonpath() + print("** sys.path =", sys.path) + print("** setting PYTHONPATH to:", subprocess_env['PYTHONPATH']) + subprocess.check_call([sys.executable, "-c", code], env=subprocess_env) + + +def test_serialize_read_concatenated_records(): + # ARROW-1996 -- see stream alignment work in ARROW-2840, ARROW-3212 + f = pa.BufferOutputStream() + pa.serialize_to(12, f) + pa.serialize_to(23, f) + buf = f.getvalue() + + f = pa.BufferReader(buf) + pa.read_serialized(f).deserialize() + pa.read_serialized(f).deserialize() + + +def deserialize_regex(serialized, q): + import pyarrow as pa + q.put(pa.deserialize(serialized)) + + +def test_deserialize_in_different_process(): + from multiprocessing import Process, Queue + import re + + regex = re.compile(r"\d+\.\d*") + + serialization_context = pa.SerializationContext() + serialization_context.register_type(type(regex), "Regex", pickle=True) + + serialized = pa.serialize(regex, serialization_context) + serialized_bytes = serialized.to_buffer().to_pybytes() + + q = Queue() + p = Process(target=deserialize_regex, args=(serialized_bytes, q)) + p.start() + assert q.get().pattern == regex.pattern + p.join() + + +def test_deserialize_buffer_in_different_process(): + import tempfile + + f = tempfile.NamedTemporaryFile(delete=False) + b = pa.serialize(pa.py_buffer(b'hello')).to_buffer() + f.write(b.to_pybytes()) + f.close() + + test_util.invoke_script('deserialize_buffer.py', f.name) + + +def test_set_pickle(): + # Use a custom type to trigger pickling. + class Foo: + pass + + context = pa.SerializationContext() + context.register_type(Foo, 'Foo', pickle=True) + + test_object = Foo() + + # Define a custom serializer and deserializer to use in place of pickle. + + def dumps1(obj): + return b'custom' + + def loads1(serialized_obj): + return serialized_obj + b' serialization 1' + + # Test that setting a custom pickler changes the behavior. + context.set_pickle(dumps1, loads1) + serialized = pa.serialize(test_object, context=context).to_buffer() + deserialized = pa.deserialize(serialized.to_pybytes(), context=context) + assert deserialized == b'custom serialization 1' + + # Define another custom serializer and deserializer. + + def dumps2(obj): + return b'custom' + + def loads2(serialized_obj): + return serialized_obj + b' serialization 2' + + # Test that setting another custom pickler changes the behavior again. + context.set_pickle(dumps2, loads2) + serialized = pa.serialize(test_object, context=context).to_buffer() + deserialized = pa.deserialize(serialized.to_pybytes(), context=context) + assert deserialized == b'custom serialization 2' + + +def test_path_objects(tmpdir): + # Test compatibility with PEP 519 path-like objects + p = pathlib.Path(tmpdir) / 'zzz.bin' + obj = 1234 + pa.serialize_to(obj, p) + res = pa.deserialize_from(p, None) + assert res == obj + + +def test_tensor_alignment(): + # Deserialized numpy arrays should be 64-byte aligned. + x = np.random.normal(size=(10, 20, 30)) + y = pa.deserialize(pa.serialize(x).to_buffer()) + assert y.ctypes.data % 64 == 0 + + xs = [np.random.normal(size=i) for i in range(100)] + ys = pa.deserialize(pa.serialize(xs).to_buffer()) + for y in ys: + assert y.ctypes.data % 64 == 0 + + xs = [np.random.normal(size=i * (1,)) for i in range(20)] + ys = pa.deserialize(pa.serialize(xs).to_buffer()) + for y in ys: + assert y.ctypes.data % 64 == 0 + + xs = [np.random.normal(size=i * (5,)) for i in range(1, 8)] + xs = [xs[i][(i + 1) * (slice(1, 3),)] for i in range(len(xs))] + ys = pa.deserialize(pa.serialize(xs).to_buffer()) + for y in ys: + assert y.ctypes.data % 64 == 0 + + +def test_empty_tensor(): + # ARROW-8122, serialize and deserialize empty tensors + x = np.array([], dtype=np.float64) + y = pa.deserialize(pa.serialize(x).to_buffer()) + np.testing.assert_array_equal(x, y) + + x = np.array([[], [], []], dtype=np.float64) + y = pa.deserialize(pa.serialize(x).to_buffer()) + np.testing.assert_array_equal(x, y) + + x = np.array([[], [], []], dtype=np.float64).T + y = pa.deserialize(pa.serialize(x).to_buffer()) + np.testing.assert_array_equal(x, y) + + +def test_serialization_determinism(): + for obj in COMPLEX_OBJECTS: + buf1 = pa.serialize(obj).to_buffer() + buf2 = pa.serialize(obj).to_buffer() + assert buf1.to_pybytes() == buf2.to_pybytes() + + +def test_serialize_recursive_objects(): + class ClassA: + pass + + # Make a list that contains itself. + lst = [] + lst.append(lst) + + # Make an object that contains itself as a field. + a1 = ClassA() + a1.field = a1 + + # Make two objects that contain each other as fields. + a2 = ClassA() + a3 = ClassA() + a2.field = a3 + a3.field = a2 + + # Make a dictionary that contains itself. + d1 = {} + d1["key"] = d1 + + # Make a numpy array that contains itself. + arr = np.array([None], dtype=object) + arr[0] = arr + + # Create a list of recursive objects. + recursive_objects = [lst, a1, a2, a3, d1, arr] + + # Check that exceptions are thrown when we serialize the recursive + # objects. + for obj in recursive_objects: + with pytest.raises(Exception): + pa.serialize(obj).deserialize() diff --git a/src/arrow/python/pyarrow/tests/test_serialization_deprecated.py b/src/arrow/python/pyarrow/tests/test_serialization_deprecated.py new file mode 100644 index 000000000..cd4b3ed78 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_serialization_deprecated.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +import pytest + +import pyarrow as pa + + +def test_serialization_deprecated(): + with pytest.warns(FutureWarning): + ser = pa.serialize(1) + + with pytest.warns(FutureWarning): + pa.deserialize(ser.to_buffer()) + + f = pa.BufferOutputStream() + with pytest.warns(FutureWarning): + pa.serialize_to(12, f) + + buf = f.getvalue() + f = pa.BufferReader(buf) + with pytest.warns(FutureWarning): + pa.read_serialized(f).deserialize() + + with pytest.warns(FutureWarning): + pa.default_serialization_context() + + context = pa.lib.SerializationContext() + with pytest.warns(FutureWarning): + pa.register_default_serialization_handlers(context) + + +@pytest.mark.skipif(sys.version_info < (3, 7), + reason="getattr needs Python 3.7") +def test_serialization_deprecated_toplevel(): + with pytest.warns(FutureWarning): + pa.SerializedPyObject() + + with pytest.warns(FutureWarning): + pa.SerializationContext() diff --git a/src/arrow/python/pyarrow/tests/test_sparse_tensor.py b/src/arrow/python/pyarrow/tests/test_sparse_tensor.py new file mode 100644 index 000000000..aa7da0a74 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_sparse_tensor.py @@ -0,0 +1,491 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import sys +import weakref + +import numpy as np +import pyarrow as pa + +try: + from scipy.sparse import csr_matrix, coo_matrix +except ImportError: + coo_matrix = None + csr_matrix = None + +try: + import sparse +except ImportError: + sparse = None + + +tensor_type_pairs = [ + ('i1', pa.int8()), + ('i2', pa.int16()), + ('i4', pa.int32()), + ('i8', pa.int64()), + ('u1', pa.uint8()), + ('u2', pa.uint16()), + ('u4', pa.uint32()), + ('u8', pa.uint64()), + ('f2', pa.float16()), + ('f4', pa.float32()), + ('f8', pa.float64()) +] + + +@pytest.mark.parametrize('sparse_tensor_type', [ + pa.SparseCSRMatrix, + pa.SparseCSCMatrix, + pa.SparseCOOTensor, + pa.SparseCSFTensor, +]) +def test_sparse_tensor_attrs(sparse_tensor_type): + data = np.array([ + [8, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 5], + [3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 6], + ]) + dim_names = ('x', 'y') + sparse_tensor = sparse_tensor_type.from_dense_numpy(data, dim_names) + + assert sparse_tensor.ndim == 2 + assert sparse_tensor.size == 24 + assert sparse_tensor.shape == data.shape + assert sparse_tensor.is_mutable + assert sparse_tensor.dim_name(0) == dim_names[0] + assert sparse_tensor.dim_names == dim_names + assert sparse_tensor.non_zero_length == 6 + + wr = weakref.ref(sparse_tensor) + assert wr() is not None + del sparse_tensor + assert wr() is None + + +def test_sparse_coo_tensor_base_object(): + expected_data = np.array([[8, 2, 5, 3, 4, 6]]).T + expected_coords = np.array([ + [0, 0, 1, 2, 3, 3], + [0, 2, 5, 0, 4, 5], + ]).T + array = np.array([ + [8, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 5], + [3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 6], + ]) + sparse_tensor = pa.SparseCOOTensor.from_dense_numpy(array) + n = sys.getrefcount(sparse_tensor) + result_data, result_coords = sparse_tensor.to_numpy() + assert sparse_tensor.has_canonical_format + assert sys.getrefcount(sparse_tensor) == n + 2 + + sparse_tensor = None + assert np.array_equal(expected_data, result_data) + assert np.array_equal(expected_coords, result_coords) + assert result_coords.flags.c_contiguous # row-major + + +def test_sparse_csr_matrix_base_object(): + data = np.array([[8, 2, 5, 3, 4, 6]]).T + indptr = np.array([0, 2, 3, 4, 6]) + indices = np.array([0, 2, 5, 0, 4, 5]) + array = np.array([ + [8, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 5], + [3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 6], + ]) + sparse_tensor = pa.SparseCSRMatrix.from_dense_numpy(array) + n = sys.getrefcount(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert sys.getrefcount(sparse_tensor) == n + 3 + + sparse_tensor = None + assert np.array_equal(data, result_data) + assert np.array_equal(indptr, result_indptr) + assert np.array_equal(indices, result_indices) + + +def test_sparse_csf_tensor_base_object(): + data = np.array([[8, 2, 5, 3, 4, 6]]).T + indptr = [np.array([0, 2, 3, 4, 6])] + indices = [ + np.array([0, 1, 2, 3]), + np.array([0, 2, 5, 0, 4, 5]) + ] + array = np.array([ + [8, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 5], + [3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 6], + ]) + sparse_tensor = pa.SparseCSFTensor.from_dense_numpy(array) + n = sys.getrefcount(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert sys.getrefcount(sparse_tensor) == n + 4 + + sparse_tensor = None + assert np.array_equal(data, result_data) + assert np.array_equal(indptr[0], result_indptr[0]) + assert np.array_equal(indices[0], result_indices[0]) + assert np.array_equal(indices[1], result_indices[1]) + + +@pytest.mark.parametrize('sparse_tensor_type', [ + pa.SparseCSRMatrix, + pa.SparseCSCMatrix, + pa.SparseCOOTensor, + pa.SparseCSFTensor, +]) +def test_sparse_tensor_equals(sparse_tensor_type): + def eq(a, b): + assert a.equals(b) + assert a == b + assert not (a != b) + + def ne(a, b): + assert not a.equals(b) + assert not (a == b) + assert a != b + + data = np.random.randn(10, 6)[::, ::2] + sparse_tensor1 = sparse_tensor_type.from_dense_numpy(data) + sparse_tensor2 = sparse_tensor_type.from_dense_numpy( + np.ascontiguousarray(data)) + eq(sparse_tensor1, sparse_tensor2) + data = data.copy() + data[9, 0] = 1.0 + sparse_tensor2 = sparse_tensor_type.from_dense_numpy( + np.ascontiguousarray(data)) + ne(sparse_tensor1, sparse_tensor2) + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_coo_tensor_from_dense(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + expected_data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype) + expected_coords = np.array([ + [0, 0, 1, 2, 3, 3], + [0, 2, 5, 0, 4, 5], + ]).T + array = np.array([ + [8, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 5], + [3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 6], + ]).astype(dtype) + tensor = pa.Tensor.from_numpy(array) + + # Test from numpy array + sparse_tensor = pa.SparseCOOTensor.from_dense_numpy(array) + repr(sparse_tensor) + result_data, result_coords = sparse_tensor.to_numpy() + assert sparse_tensor.type == arrow_type + assert np.array_equal(expected_data, result_data) + assert np.array_equal(expected_coords, result_coords) + + # Test from Tensor + sparse_tensor = pa.SparseCOOTensor.from_tensor(tensor) + repr(sparse_tensor) + result_data, result_coords = sparse_tensor.to_numpy() + assert sparse_tensor.type == arrow_type + assert np.array_equal(expected_data, result_data) + assert np.array_equal(expected_coords, result_coords) + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_csr_matrix_from_dense(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype) + indptr = np.array([0, 2, 3, 4, 6]) + indices = np.array([0, 2, 5, 0, 4, 5]) + array = np.array([ + [8, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 5], + [3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 6], + ]).astype(dtype) + tensor = pa.Tensor.from_numpy(array) + + # Test from numpy array + sparse_tensor = pa.SparseCSRMatrix.from_dense_numpy(array) + repr(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert sparse_tensor.type == arrow_type + assert np.array_equal(data, result_data) + assert np.array_equal(indptr, result_indptr) + assert np.array_equal(indices, result_indices) + + # Test from Tensor + sparse_tensor = pa.SparseCSRMatrix.from_tensor(tensor) + repr(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert sparse_tensor.type == arrow_type + assert np.array_equal(data, result_data) + assert np.array_equal(indptr, result_indptr) + assert np.array_equal(indices, result_indices) + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_csf_tensor_from_dense_numpy(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype) + indptr = [np.array([0, 2, 3, 4, 6])] + indices = [ + np.array([0, 1, 2, 3]), + np.array([0, 2, 5, 0, 4, 5]) + ] + array = np.array([ + [8, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 5], + [3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 6], + ]).astype(dtype) + + # Test from numpy array + sparse_tensor = pa.SparseCSFTensor.from_dense_numpy(array) + repr(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert sparse_tensor.type == arrow_type + assert np.array_equal(data, result_data) + assert np.array_equal(indptr[0], result_indptr[0]) + assert np.array_equal(indices[0], result_indices[0]) + assert np.array_equal(indices[1], result_indices[1]) + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_csf_tensor_from_dense_tensor(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype) + indptr = [np.array([0, 2, 3, 4, 6])] + indices = [ + np.array([0, 1, 2, 3]), + np.array([0, 2, 5, 0, 4, 5]) + ] + array = np.array([ + [8, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 5], + [3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 4, 6], + ]).astype(dtype) + tensor = pa.Tensor.from_numpy(array) + + # Test from Tensor + sparse_tensor = pa.SparseCSFTensor.from_tensor(tensor) + repr(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert sparse_tensor.type == arrow_type + assert np.array_equal(data, result_data) + assert np.array_equal(indptr[0], result_indptr[0]) + assert np.array_equal(indices[0], result_indices[0]) + assert np.array_equal(indices[1], result_indices[1]) + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_coo_tensor_numpy_roundtrip(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([[1, 2, 3, 4, 5, 6]]).T.astype(dtype) + coords = np.array([ + [0, 0, 2, 3, 1, 3], + [0, 2, 0, 4, 5, 5], + ]).T + shape = (4, 6) + dim_names = ('x', 'y') + + sparse_tensor = pa.SparseCOOTensor.from_numpy(data, coords, shape, + dim_names) + repr(sparse_tensor) + result_data, result_coords = sparse_tensor.to_numpy() + assert sparse_tensor.type == arrow_type + assert np.array_equal(data, result_data) + assert np.array_equal(coords, result_coords) + assert sparse_tensor.dim_names == dim_names + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_csr_matrix_numpy_roundtrip(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype) + indptr = np.array([0, 2, 3, 4, 6]) + indices = np.array([0, 2, 5, 0, 4, 5]) + shape = (4, 6) + dim_names = ('x', 'y') + + sparse_tensor = pa.SparseCSRMatrix.from_numpy(data, indptr, indices, + shape, dim_names) + repr(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert sparse_tensor.type == arrow_type + assert np.array_equal(data, result_data) + assert np.array_equal(indptr, result_indptr) + assert np.array_equal(indices, result_indices) + assert sparse_tensor.dim_names == dim_names + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_csf_tensor_numpy_roundtrip(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([[8, 2, 5, 3, 4, 6]]).T.astype(dtype) + indptr = [np.array([0, 2, 3, 4, 6])] + indices = [ + np.array([0, 1, 2, 3]), + np.array([0, 2, 5, 0, 4, 5]) + ] + axis_order = (0, 1) + shape = (4, 6) + dim_names = ('x', 'y') + + sparse_tensor = pa.SparseCSFTensor.from_numpy(data, indptr, indices, + shape, axis_order, + dim_names) + repr(sparse_tensor) + result_data, result_indptr, result_indices = sparse_tensor.to_numpy() + assert sparse_tensor.type == arrow_type + assert np.array_equal(data, result_data) + assert np.array_equal(indptr[0], result_indptr[0]) + assert np.array_equal(indices[0], result_indices[0]) + assert np.array_equal(indices[1], result_indices[1]) + assert sparse_tensor.dim_names == dim_names + + +@pytest.mark.parametrize('sparse_tensor_type', [ + pa.SparseCSRMatrix, + pa.SparseCSCMatrix, + pa.SparseCOOTensor, + pa.SparseCSFTensor, +]) +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_dense_to_sparse_tensor(dtype_str, arrow_type, sparse_tensor_type): + dtype = np.dtype(dtype_str) + array = np.array([[4, 0, 9, 0], + [0, 7, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 5]]).astype(dtype) + dim_names = ('x', 'y') + + sparse_tensor = sparse_tensor_type.from_dense_numpy(array, dim_names) + tensor = sparse_tensor.to_tensor() + result_array = tensor.to_numpy() + + assert sparse_tensor.type == arrow_type + assert tensor.type == arrow_type + assert sparse_tensor.dim_names == dim_names + assert np.array_equal(array, result_array) + + +@pytest.mark.skipif(not coo_matrix, reason="requires scipy") +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_coo_tensor_scipy_roundtrip(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([1, 2, 3, 4, 5, 6]).astype(dtype) + row = np.array([0, 0, 2, 3, 1, 3]) + col = np.array([0, 2, 0, 4, 5, 5]) + shape = (4, 6) + dim_names = ('x', 'y') + + # non-canonical sparse coo matrix + scipy_matrix = coo_matrix((data, (row, col)), shape=shape) + sparse_tensor = pa.SparseCOOTensor.from_scipy(scipy_matrix, + dim_names=dim_names) + out_scipy_matrix = sparse_tensor.to_scipy() + + assert not scipy_matrix.has_canonical_format + assert not sparse_tensor.has_canonical_format + assert not out_scipy_matrix.has_canonical_format + assert sparse_tensor.type == arrow_type + assert sparse_tensor.dim_names == dim_names + assert scipy_matrix.dtype == out_scipy_matrix.dtype + assert np.array_equal(scipy_matrix.data, out_scipy_matrix.data) + assert np.array_equal(scipy_matrix.row, out_scipy_matrix.row) + assert np.array_equal(scipy_matrix.col, out_scipy_matrix.col) + + if dtype_str == 'f2': + dense_array = \ + scipy_matrix.astype(np.float32).toarray().astype(np.float16) + else: + dense_array = scipy_matrix.toarray() + assert np.array_equal(dense_array, sparse_tensor.to_tensor().to_numpy()) + + # canonical sparse coo matrix + scipy_matrix.sum_duplicates() + sparse_tensor = pa.SparseCOOTensor.from_scipy(scipy_matrix, + dim_names=dim_names) + out_scipy_matrix = sparse_tensor.to_scipy() + + assert scipy_matrix.has_canonical_format + assert sparse_tensor.has_canonical_format + assert out_scipy_matrix.has_canonical_format + + +@pytest.mark.skipif(not csr_matrix, reason="requires scipy") +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_sparse_csr_matrix_scipy_roundtrip(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([8, 2, 5, 3, 4, 6]).astype(dtype) + indptr = np.array([0, 2, 3, 4, 6]) + indices = np.array([0, 2, 5, 0, 4, 5]) + shape = (4, 6) + dim_names = ('x', 'y') + + sparse_array = csr_matrix((data, indices, indptr), shape=shape) + sparse_tensor = pa.SparseCSRMatrix.from_scipy(sparse_array, + dim_names=dim_names) + out_sparse_array = sparse_tensor.to_scipy() + + assert sparse_tensor.type == arrow_type + assert sparse_tensor.dim_names == dim_names + assert sparse_array.dtype == out_sparse_array.dtype + assert np.array_equal(sparse_array.data, out_sparse_array.data) + assert np.array_equal(sparse_array.indptr, out_sparse_array.indptr) + assert np.array_equal(sparse_array.indices, out_sparse_array.indices) + + if dtype_str == 'f2': + dense_array = \ + sparse_array.astype(np.float32).toarray().astype(np.float16) + else: + dense_array = sparse_array.toarray() + assert np.array_equal(dense_array, sparse_tensor.to_tensor().to_numpy()) + + +@pytest.mark.skipif(not sparse, reason="requires pydata/sparse") +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_pydata_sparse_sparse_coo_tensor_roundtrip(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = np.array([1, 2, 3, 4, 5, 6]).astype(dtype) + coords = np.array([ + [0, 0, 2, 3, 1, 3], + [0, 2, 0, 4, 5, 5], + ]) + shape = (4, 6) + dim_names = ("x", "y") + + sparse_array = sparse.COO(data=data, coords=coords, shape=shape) + sparse_tensor = pa.SparseCOOTensor.from_pydata_sparse(sparse_array, + dim_names=dim_names) + out_sparse_array = sparse_tensor.to_pydata_sparse() + + assert sparse_tensor.type == arrow_type + assert sparse_tensor.dim_names == dim_names + assert sparse_array.dtype == out_sparse_array.dtype + assert np.array_equal(sparse_array.data, out_sparse_array.data) + assert np.array_equal(sparse_array.coords, out_sparse_array.coords) + assert np.array_equal(sparse_array.todense(), + sparse_tensor.to_tensor().to_numpy()) diff --git a/src/arrow/python/pyarrow/tests/test_strategies.py b/src/arrow/python/pyarrow/tests/test_strategies.py new file mode 100644 index 000000000..14fc94992 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_strategies.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import hypothesis as h + +import pyarrow as pa +import pyarrow.tests.strategies as past + + +@h.given(past.all_types) +def test_types(ty): + assert isinstance(ty, pa.lib.DataType) + + +@h.given(past.all_fields) +def test_fields(field): + assert isinstance(field, pa.lib.Field) + + +@h.given(past.all_schemas) +def test_schemas(schema): + assert isinstance(schema, pa.lib.Schema) + + +@h.given(past.all_arrays) +def test_arrays(array): + assert isinstance(array, pa.lib.Array) + + +@h.given(past.arrays(past.primitive_types, nullable=False)) +def test_array_nullability(array): + assert array.null_count == 0 + + +@h.given(past.all_chunked_arrays) +def test_chunked_arrays(chunked_array): + assert isinstance(chunked_array, pa.lib.ChunkedArray) + + +@h.given(past.all_record_batches) +def test_record_batches(record_bath): + assert isinstance(record_bath, pa.lib.RecordBatch) + + +@h.given(past.all_tables) +def test_tables(table): + assert isinstance(table, pa.lib.Table) diff --git a/src/arrow/python/pyarrow/tests/test_table.py b/src/arrow/python/pyarrow/tests/test_table.py new file mode 100644 index 000000000..ef41a733d --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_table.py @@ -0,0 +1,1748 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import OrderedDict +from collections.abc import Iterable +import pickle +import sys +import weakref + +import numpy as np +import pytest +import pyarrow as pa + + +def test_chunked_array_basics(): + data = pa.chunked_array([], type=pa.string()) + assert data.type == pa.string() + assert data.to_pylist() == [] + data.validate() + + data2 = pa.chunked_array([], type='binary') + assert data2.type == pa.binary() + + with pytest.raises(ValueError): + pa.chunked_array([]) + + data = pa.chunked_array([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9] + ]) + assert isinstance(data.chunks, list) + assert all(isinstance(c, pa.lib.Int64Array) for c in data.chunks) + assert all(isinstance(c, pa.lib.Int64Array) for c in data.iterchunks()) + assert len(data.chunks) == 3 + assert data.nbytes == sum(c.nbytes for c in data.iterchunks()) + assert sys.getsizeof(data) >= object.__sizeof__(data) + data.nbytes + data.validate() + + wr = weakref.ref(data) + assert wr() is not None + del data + assert wr() is None + + +def test_chunked_array_construction(): + arr = pa.chunked_array([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ]) + assert arr.type == pa.int64() + assert len(arr) == 9 + assert len(arr.chunks) == 3 + + arr = pa.chunked_array([ + [1, 2, 3], + [4., 5., 6.], + [7, 8, 9], + ]) + assert arr.type == pa.int64() + assert len(arr) == 9 + assert len(arr.chunks) == 3 + + arr = pa.chunked_array([ + [1, 2, 3], + [4., 5., 6.], + [7, 8, 9], + ], type=pa.int8()) + assert arr.type == pa.int8() + assert len(arr) == 9 + assert len(arr.chunks) == 3 + + arr = pa.chunked_array([ + [1, 2, 3], + [] + ]) + assert arr.type == pa.int64() + assert len(arr) == 3 + assert len(arr.chunks) == 2 + + msg = ( + "When passing an empty collection of arrays you must also pass the " + "data type" + ) + with pytest.raises(ValueError, match=msg): + assert pa.chunked_array([]) + + assert pa.chunked_array([], type=pa.string()).type == pa.string() + assert pa.chunked_array([[]]).type == pa.null() + assert pa.chunked_array([[]], type=pa.string()).type == pa.string() + + +def test_combine_chunks(): + # ARROW-77363 + arr = pa.array([1, 2]) + chunked_arr = pa.chunked_array([arr, arr]) + res = chunked_arr.combine_chunks() + expected = pa.array([1, 2, 1, 2]) + assert res.equals(expected) + + +def test_chunked_array_to_numpy(): + data = pa.chunked_array([ + [1, 2, 3], + [4, 5, 6], + [] + ]) + arr1 = np.asarray(data) + arr2 = data.to_numpy() + + assert isinstance(arr2, np.ndarray) + assert arr2.shape == (6,) + assert np.array_equal(arr1, arr2) + + +def test_chunked_array_mismatch_types(): + with pytest.raises(TypeError): + # Given array types are different + pa.chunked_array([ + pa.array([1, 2, 3]), + pa.array([1., 2., 3.]) + ]) + + with pytest.raises(TypeError): + # Given array type is different from explicit type argument + pa.chunked_array([pa.array([1, 2, 3])], type=pa.float64()) + + +def test_chunked_array_str(): + data = [ + pa.array([1, 2, 3]), + pa.array([4, 5, 6]) + ] + data = pa.chunked_array(data) + assert str(data) == """[ + [ + 1, + 2, + 3 + ], + [ + 4, + 5, + 6 + ] +]""" + + +def test_chunked_array_getitem(): + data = [ + pa.array([1, 2, 3]), + pa.array([4, 5, 6]) + ] + data = pa.chunked_array(data) + assert data[1].as_py() == 2 + assert data[-1].as_py() == 6 + assert data[-6].as_py() == 1 + with pytest.raises(IndexError): + data[6] + with pytest.raises(IndexError): + data[-7] + # Ensure this works with numpy scalars + assert data[np.int32(1)].as_py() == 2 + + data_slice = data[2:4] + assert data_slice.to_pylist() == [3, 4] + + data_slice = data[4:-1] + assert data_slice.to_pylist() == [5] + + data_slice = data[99:99] + assert data_slice.type == data.type + assert data_slice.to_pylist() == [] + + +def test_chunked_array_slice(): + data = [ + pa.array([1, 2, 3]), + pa.array([4, 5, 6]) + ] + data = pa.chunked_array(data) + + data_slice = data.slice(len(data)) + assert data_slice.type == data.type + assert data_slice.to_pylist() == [] + + data_slice = data.slice(len(data) + 10) + assert data_slice.type == data.type + assert data_slice.to_pylist() == [] + + table = pa.Table.from_arrays([data], names=["a"]) + table_slice = table.slice(len(table)) + assert len(table_slice) == 0 + + table = pa.Table.from_arrays([data], names=["a"]) + table_slice = table.slice(len(table) + 10) + assert len(table_slice) == 0 + + +def test_chunked_array_iter(): + data = [ + pa.array([0]), + pa.array([1, 2, 3]), + pa.array([4, 5, 6]), + pa.array([7, 8, 9]) + ] + arr = pa.chunked_array(data) + + for i, j in zip(range(10), arr): + assert i == j.as_py() + + assert isinstance(arr, Iterable) + + +def test_chunked_array_equals(): + def eq(xarrs, yarrs): + if isinstance(xarrs, pa.ChunkedArray): + x = xarrs + else: + x = pa.chunked_array(xarrs) + if isinstance(yarrs, pa.ChunkedArray): + y = yarrs + else: + y = pa.chunked_array(yarrs) + assert x.equals(y) + assert y.equals(x) + assert x == y + assert x != str(y) + + def ne(xarrs, yarrs): + if isinstance(xarrs, pa.ChunkedArray): + x = xarrs + else: + x = pa.chunked_array(xarrs) + if isinstance(yarrs, pa.ChunkedArray): + y = yarrs + else: + y = pa.chunked_array(yarrs) + assert not x.equals(y) + assert not y.equals(x) + assert x != y + + eq(pa.chunked_array([], type=pa.int32()), + pa.chunked_array([], type=pa.int32())) + ne(pa.chunked_array([], type=pa.int32()), + pa.chunked_array([], type=pa.int64())) + + a = pa.array([0, 2], type=pa.int32()) + b = pa.array([0, 2], type=pa.int64()) + c = pa.array([0, 3], type=pa.int32()) + d = pa.array([0, 2, 0, 3], type=pa.int32()) + + eq([a], [a]) + ne([a], [b]) + eq([a, c], [a, c]) + eq([a, c], [d]) + ne([c, a], [a, c]) + + # ARROW-4822 + assert not pa.chunked_array([], type=pa.int32()).equals(None) + + +@pytest.mark.parametrize( + ('data', 'typ'), + [ + ([True, False, True, True], pa.bool_()), + ([1, 2, 4, 6], pa.int64()), + ([1.0, 2.5, None], pa.float64()), + (['a', None, 'b'], pa.string()), + ([], pa.list_(pa.uint8())), + ([[1, 2], [3]], pa.list_(pa.int64())), + ([['a'], None, ['b', 'c']], pa.list_(pa.string())), + ([(1, 'a'), (2, 'c'), None], + pa.struct([pa.field('a', pa.int64()), pa.field('b', pa.string())])) + ] +) +def test_chunked_array_pickle(data, typ): + arrays = [] + while data: + arrays.append(pa.array(data[:2], type=typ)) + data = data[2:] + array = pa.chunked_array(arrays, type=typ) + array.validate() + result = pickle.loads(pickle.dumps(array)) + result.validate() + assert result.equals(array) + + +@pytest.mark.pandas +def test_chunked_array_to_pandas(): + import pandas as pd + + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.table(data, names=['a']) + col = table.column(0) + assert isinstance(col, pa.ChunkedArray) + series = col.to_pandas() + assert isinstance(series, pd.Series) + assert series.shape == (5,) + assert series[0] == -10 + assert series.name == 'a' + + +@pytest.mark.pandas +def test_chunked_array_to_pandas_preserve_name(): + # https://issues.apache.org/jira/browse/ARROW-7709 + import pandas as pd + import pandas.testing as tm + + for data in [ + pa.array([1, 2, 3]), + pa.array(pd.Categorical(["a", "b", "a"])), + pa.array(pd.date_range("2012", periods=3)), + pa.array(pd.date_range("2012", periods=3, tz="Europe/Brussels")), + pa.array([1, 2, 3], pa.timestamp("ms")), + pa.array([1, 2, 3], pa.timestamp("ms", "Europe/Brussels"))]: + table = pa.table({"name": data}) + result = table.column("name").to_pandas() + assert result.name == "name" + expected = pd.Series(data.to_pandas(), name="name") + tm.assert_series_equal(result, expected) + + +@pytest.mark.pandas +@pytest.mark.nopandas +def test_chunked_array_asarray(): + # ensure this is tested both when pandas is present or not (ARROW-6564) + + data = [ + pa.array([0]), + pa.array([1, 2, 3]) + ] + chunked_arr = pa.chunked_array(data) + + np_arr = np.asarray(chunked_arr) + assert np_arr.tolist() == [0, 1, 2, 3] + assert np_arr.dtype == np.dtype('int64') + + # An optional type can be specified when calling np.asarray + np_arr = np.asarray(chunked_arr, dtype='str') + assert np_arr.tolist() == ['0', '1', '2', '3'] + + # Types are modified when there are nulls + data = [ + pa.array([1, None]), + pa.array([1, 2, 3]) + ] + chunked_arr = pa.chunked_array(data) + + np_arr = np.asarray(chunked_arr) + elements = np_arr.tolist() + assert elements[0] == 1. + assert np.isnan(elements[1]) + assert elements[2:] == [1., 2., 3.] + assert np_arr.dtype == np.dtype('float64') + + # DictionaryType data will be converted to dense numpy array + arr = pa.DictionaryArray.from_arrays( + pa.array([0, 1, 2, 0, 1]), pa.array(['a', 'b', 'c'])) + chunked_arr = pa.chunked_array([arr, arr]) + np_arr = np.asarray(chunked_arr) + assert np_arr.dtype == np.dtype('object') + assert np_arr.tolist() == ['a', 'b', 'c', 'a', 'b'] * 2 + + +def test_chunked_array_flatten(): + ty = pa.struct([pa.field('x', pa.int16()), + pa.field('y', pa.float32())]) + a = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty) + carr = pa.chunked_array(a) + x, y = carr.flatten() + assert x.equals(pa.chunked_array(pa.array([1, 3, 5], type=pa.int16()))) + assert y.equals(pa.chunked_array(pa.array([2.5, 4.5, 6.5], + type=pa.float32()))) + + # Empty column + a = pa.array([], type=ty) + carr = pa.chunked_array(a) + x, y = carr.flatten() + assert x.equals(pa.chunked_array(pa.array([], type=pa.int16()))) + assert y.equals(pa.chunked_array(pa.array([], type=pa.float32()))) + + +def test_chunked_array_unify_dictionaries(): + arr = pa.chunked_array([ + pa.array(["foo", "bar", None, "foo"]).dictionary_encode(), + pa.array(["quux", None, "foo"]).dictionary_encode(), + ]) + assert arr.chunk(0).dictionary.equals(pa.array(["foo", "bar"])) + assert arr.chunk(1).dictionary.equals(pa.array(["quux", "foo"])) + arr = arr.unify_dictionaries() + expected_dict = pa.array(["foo", "bar", "quux"]) + assert arr.chunk(0).dictionary.equals(expected_dict) + assert arr.chunk(1).dictionary.equals(expected_dict) + assert arr.to_pylist() == ["foo", "bar", None, "foo", "quux", None, "foo"] + + +def test_recordbatch_basics(): + data = [ + pa.array(range(5), type='int16'), + pa.array([-10, -5, 0, None, 10], type='int32') + ] + + batch = pa.record_batch(data, ['c0', 'c1']) + assert not batch.schema.metadata + + assert len(batch) == 5 + assert batch.num_rows == 5 + assert batch.num_columns == len(data) + # (only the second array has a null bitmap) + assert batch.nbytes == (5 * 2) + (5 * 4 + 1) + assert sys.getsizeof(batch) >= object.__sizeof__(batch) + batch.nbytes + pydict = batch.to_pydict() + assert pydict == OrderedDict([ + ('c0', [0, 1, 2, 3, 4]), + ('c1', [-10, -5, 0, None, 10]) + ]) + if sys.version_info >= (3, 7): + assert type(pydict) == dict + else: + assert type(pydict) == OrderedDict + + with pytest.raises(IndexError): + # bounds checking + batch[2] + + # Schema passed explicitly + schema = pa.schema([pa.field('c0', pa.int16(), + metadata={'key': 'value'}), + pa.field('c1', pa.int32())], + metadata={b'foo': b'bar'}) + batch = pa.record_batch(data, schema=schema) + assert batch.schema == schema + # schema as first positional argument + batch = pa.record_batch(data, schema) + assert batch.schema == schema + assert str(batch) == """pyarrow.RecordBatch +c0: int16 +c1: int32""" + + assert batch.to_string(show_metadata=True) == """\ +pyarrow.RecordBatch +c0: int16 + -- field metadata -- + key: 'value' +c1: int32 +-- schema metadata -- +foo: 'bar'""" + + wr = weakref.ref(batch) + assert wr() is not None + del batch + assert wr() is None + + +def test_recordbatch_equals(): + data1 = [ + pa.array(range(5), type='int16'), + pa.array([-10, -5, 0, None, 10], type='int32') + ] + data2 = [ + pa.array(['a', 'b', 'c']), + pa.array([['d'], ['e'], ['f']]), + ] + column_names = ['c0', 'c1'] + + batch = pa.record_batch(data1, column_names) + assert batch == pa.record_batch(data1, column_names) + assert batch.equals(pa.record_batch(data1, column_names)) + + assert batch != pa.record_batch(data2, column_names) + assert not batch.equals(pa.record_batch(data2, column_names)) + + batch_meta = pa.record_batch(data1, names=column_names, + metadata={'key': 'value'}) + assert batch_meta.equals(batch) + assert not batch_meta.equals(batch, check_metadata=True) + + # ARROW-8889 + assert not batch.equals(None) + assert batch != "foo" + + +def test_recordbatch_take(): + batch = pa.record_batch( + [pa.array([1, 2, 3, None, 5]), + pa.array(['a', 'b', 'c', 'd', 'e'])], + ['f1', 'f2']) + assert batch.take(pa.array([2, 3])).equals(batch.slice(2, 2)) + assert batch.take(pa.array([2, None])).equals( + pa.record_batch([pa.array([3, None]), pa.array(['c', None])], + ['f1', 'f2'])) + + +def test_recordbatch_column_sets_private_name(): + # ARROW-6429 + rb = pa.record_batch([pa.array([1, 2, 3, 4])], names=['a0']) + assert rb[0]._name == 'a0' + + +def test_recordbatch_from_arrays_validate_schema(): + # ARROW-6263 + arr = pa.array([1, 2]) + schema = pa.schema([pa.field('f0', pa.list_(pa.utf8()))]) + with pytest.raises(NotImplementedError): + pa.record_batch([arr], schema=schema) + + +def test_recordbatch_from_arrays_validate_lengths(): + # ARROW-2820 + data = [pa.array([1]), pa.array(["tokyo", "like", "happy"]), + pa.array(["derek"])] + + with pytest.raises(ValueError): + pa.record_batch(data, ['id', 'tags', 'name']) + + +def test_recordbatch_no_fields(): + batch = pa.record_batch([], []) + + assert len(batch) == 0 + assert batch.num_rows == 0 + assert batch.num_columns == 0 + + +def test_recordbatch_from_arrays_invalid_names(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]) + ] + with pytest.raises(ValueError): + pa.record_batch(data, names=['a', 'b', 'c']) + + with pytest.raises(ValueError): + pa.record_batch(data, names=['a']) + + +def test_recordbatch_empty_metadata(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]) + ] + + batch = pa.record_batch(data, ['c0', 'c1']) + assert batch.schema.metadata is None + + +def test_recordbatch_pickle(): + data = [ + pa.array(range(5), type='int8'), + pa.array([-10, -5, 0, 5, 10], type='float32') + ] + fields = [ + pa.field('ints', pa.int8()), + pa.field('floats', pa.float32()), + ] + schema = pa.schema(fields, metadata={b'foo': b'bar'}) + batch = pa.record_batch(data, schema=schema) + + result = pickle.loads(pickle.dumps(batch)) + assert result.equals(batch) + assert result.schema == schema + + +def test_recordbatch_get_field(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array(range(5, 10)) + ] + batch = pa.RecordBatch.from_arrays(data, names=('a', 'b', 'c')) + + assert batch.field('a').equals(batch.schema.field('a')) + assert batch.field(0).equals(batch.schema.field('a')) + + with pytest.raises(KeyError): + batch.field('d') + + with pytest.raises(TypeError): + batch.field(None) + + with pytest.raises(IndexError): + batch.field(4) + + +def test_recordbatch_select_column(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array(range(5, 10)) + ] + batch = pa.RecordBatch.from_arrays(data, names=('a', 'b', 'c')) + + assert batch.column('a').equals(batch.column(0)) + + with pytest.raises( + KeyError, match='Field "d" does not exist in record batch schema'): + batch.column('d') + + with pytest.raises(TypeError): + batch.column(None) + + with pytest.raises(IndexError): + batch.column(4) + + +def test_recordbatch_from_struct_array_invalid(): + with pytest.raises(TypeError): + pa.RecordBatch.from_struct_array(pa.array(range(5))) + + +def test_recordbatch_from_struct_array(): + struct_array = pa.array( + [{"ints": 1}, {"floats": 1.0}], + type=pa.struct([("ints", pa.int32()), ("floats", pa.float32())]), + ) + result = pa.RecordBatch.from_struct_array(struct_array) + assert result.equals(pa.RecordBatch.from_arrays( + [ + pa.array([1, None], type=pa.int32()), + pa.array([None, 1.0], type=pa.float32()), + ], ["ints", "floats"] + )) + + +def _table_like_slice_tests(factory): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]) + ] + names = ['c0', 'c1'] + + obj = factory(data, names=names) + + sliced = obj.slice(2) + assert sliced.num_rows == 3 + + expected = factory([x.slice(2) for x in data], names=names) + assert sliced.equals(expected) + + sliced2 = obj.slice(2, 2) + expected2 = factory([x.slice(2, 2) for x in data], names=names) + assert sliced2.equals(expected2) + + # 0 offset + assert obj.slice(0).equals(obj) + + # Slice past end of array + assert len(obj.slice(len(obj))) == 0 + + with pytest.raises(IndexError): + obj.slice(-1) + + # Check __getitem__-based slicing + assert obj.slice(0, 0).equals(obj[:0]) + assert obj.slice(0, 2).equals(obj[:2]) + assert obj.slice(2, 2).equals(obj[2:4]) + assert obj.slice(2, len(obj) - 2).equals(obj[2:]) + assert obj.slice(len(obj) - 2, 2).equals(obj[-2:]) + assert obj.slice(len(obj) - 4, 2).equals(obj[-4:-2]) + + +def test_recordbatch_slice_getitem(): + return _table_like_slice_tests(pa.RecordBatch.from_arrays) + + +def test_table_slice_getitem(): + return _table_like_slice_tests(pa.table) + + +@pytest.mark.pandas +def test_slice_zero_length_table(): + # ARROW-7907: a segfault on this code was fixed after 0.16.0 + table = pa.table({'a': pa.array([], type=pa.timestamp('us'))}) + table_slice = table.slice(0, 0) + table_slice.to_pandas() + + table = pa.table({'a': pa.chunked_array([], type=pa.string())}) + table.to_pandas() + + +def test_recordbatchlist_schema_equals(): + a1 = np.array([1], dtype='uint32') + a2 = np.array([4.0, 5.0], dtype='float64') + batch1 = pa.record_batch([pa.array(a1)], ['c1']) + batch2 = pa.record_batch([pa.array(a2)], ['c1']) + + with pytest.raises(pa.ArrowInvalid): + pa.Table.from_batches([batch1, batch2]) + + +def test_table_column_sets_private_name(): + # ARROW-6429 + t = pa.table([pa.array([1, 2, 3, 4])], names=['a0']) + assert t[0]._name == 'a0' + + +def test_table_equals(): + table = pa.Table.from_arrays([], names=[]) + assert table.equals(table) + + # ARROW-4822 + assert not table.equals(None) + + other = pa.Table.from_arrays([], names=[], metadata={'key': 'value'}) + assert not table.equals(other, check_metadata=True) + assert table.equals(other) + + +def test_table_from_batches_and_schema(): + schema = pa.schema([ + pa.field('a', pa.int64()), + pa.field('b', pa.float64()), + ]) + batch = pa.record_batch([pa.array([1]), pa.array([3.14])], + names=['a', 'b']) + table = pa.Table.from_batches([batch], schema) + assert table.schema.equals(schema) + assert table.column(0) == pa.chunked_array([[1]]) + assert table.column(1) == pa.chunked_array([[3.14]]) + + incompatible_schema = pa.schema([pa.field('a', pa.int64())]) + with pytest.raises(pa.ArrowInvalid): + pa.Table.from_batches([batch], incompatible_schema) + + incompatible_batch = pa.record_batch([pa.array([1])], ['a']) + with pytest.raises(pa.ArrowInvalid): + pa.Table.from_batches([incompatible_batch], schema) + + +@pytest.mark.pandas +def test_table_to_batches(): + from pandas.testing import assert_frame_equal + import pandas as pd + + df1 = pd.DataFrame({'a': list(range(10))}) + df2 = pd.DataFrame({'a': list(range(10, 30))}) + + batch1 = pa.RecordBatch.from_pandas(df1, preserve_index=False) + batch2 = pa.RecordBatch.from_pandas(df2, preserve_index=False) + + table = pa.Table.from_batches([batch1, batch2, batch1]) + + expected_df = pd.concat([df1, df2, df1], ignore_index=True) + + batches = table.to_batches() + assert len(batches) == 3 + + assert_frame_equal(pa.Table.from_batches(batches).to_pandas(), + expected_df) + + batches = table.to_batches(max_chunksize=15) + assert list(map(len, batches)) == [10, 15, 5, 10] + + assert_frame_equal(table.to_pandas(), expected_df) + assert_frame_equal(pa.Table.from_batches(batches).to_pandas(), + expected_df) + + table_from_iter = pa.Table.from_batches(iter([batch1, batch2, batch1])) + assert table.equals(table_from_iter) + + +def test_table_basics(): + data = [ + pa.array(range(5), type='int64'), + pa.array([-10, -5, 0, 5, 10], type='int64') + ] + table = pa.table(data, names=('a', 'b')) + table.validate() + assert len(table) == 5 + assert table.num_rows == 5 + assert table.num_columns == 2 + assert table.shape == (5, 2) + assert table.nbytes == 2 * (5 * 8) + assert sys.getsizeof(table) >= object.__sizeof__(table) + table.nbytes + pydict = table.to_pydict() + assert pydict == OrderedDict([ + ('a', [0, 1, 2, 3, 4]), + ('b', [-10, -5, 0, 5, 10]) + ]) + if sys.version_info >= (3, 7): + assert type(pydict) == dict + else: + assert type(pydict) == OrderedDict + + columns = [] + for col in table.itercolumns(): + columns.append(col) + for chunk in col.iterchunks(): + assert chunk is not None + + with pytest.raises(IndexError): + col.chunk(-1) + + with pytest.raises(IndexError): + col.chunk(col.num_chunks) + + assert table.columns == columns + assert table == pa.table(columns, names=table.column_names) + assert table != pa.table(columns[1:], names=table.column_names[1:]) + assert table != columns + + wr = weakref.ref(table) + assert wr() is not None + del table + assert wr() is None + + +def test_table_from_arrays_preserves_column_metadata(): + # Added to test https://issues.apache.org/jira/browse/ARROW-3866 + arr0 = pa.array([1, 2]) + arr1 = pa.array([3, 4]) + field0 = pa.field('field1', pa.int64(), metadata=dict(a="A", b="B")) + field1 = pa.field('field2', pa.int64(), nullable=False) + table = pa.Table.from_arrays([arr0, arr1], + schema=pa.schema([field0, field1])) + assert b"a" in table.field(0).metadata + assert table.field(1).nullable is False + + +def test_table_from_arrays_invalid_names(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]) + ] + with pytest.raises(ValueError): + pa.Table.from_arrays(data, names=['a', 'b', 'c']) + + with pytest.raises(ValueError): + pa.Table.from_arrays(data, names=['a']) + + +def test_table_from_lists(): + data = [ + list(range(5)), + [-10, -5, 0, 5, 10] + ] + + result = pa.table(data, names=['a', 'b']) + expected = pa.Table.from_arrays(data, names=['a', 'b']) + assert result.equals(expected) + + schema = pa.schema([ + pa.field('a', pa.uint16()), + pa.field('b', pa.int64()) + ]) + result = pa.table(data, schema=schema) + expected = pa.Table.from_arrays(data, schema=schema) + assert result.equals(expected) + + +def test_table_pickle(): + data = [ + pa.chunked_array([[1, 2], [3, 4]], type=pa.uint32()), + pa.chunked_array([["some", "strings", None, ""]], type=pa.string()), + ] + schema = pa.schema([pa.field('ints', pa.uint32()), + pa.field('strs', pa.string())], + metadata={b'foo': b'bar'}) + table = pa.Table.from_arrays(data, schema=schema) + + result = pickle.loads(pickle.dumps(table)) + result.validate() + assert result.equals(table) + + +def test_table_get_field(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array(range(5, 10)) + ] + table = pa.Table.from_arrays(data, names=('a', 'b', 'c')) + + assert table.field('a').equals(table.schema.field('a')) + assert table.field(0).equals(table.schema.field('a')) + + with pytest.raises(KeyError): + table.field('d') + + with pytest.raises(TypeError): + table.field(None) + + with pytest.raises(IndexError): + table.field(4) + + +def test_table_select_column(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array(range(5, 10)) + ] + table = pa.Table.from_arrays(data, names=('a', 'b', 'c')) + + assert table.column('a').equals(table.column(0)) + + with pytest.raises(KeyError, + match='Field "d" does not exist in table schema'): + table.column('d') + + with pytest.raises(TypeError): + table.column(None) + + with pytest.raises(IndexError): + table.column(4) + + +def test_table_column_with_duplicates(): + # ARROW-8209 + table = pa.table([pa.array([1, 2, 3]), + pa.array([4, 5, 6]), + pa.array([7, 8, 9])], names=['a', 'b', 'a']) + + with pytest.raises(KeyError, + match='Field "a" exists 2 times in table schema'): + table.column('a') + + +def test_table_add_column(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array(range(5, 10)) + ] + table = pa.Table.from_arrays(data, names=('a', 'b', 'c')) + + new_field = pa.field('d', data[1].type) + t2 = table.add_column(3, new_field, data[1]) + t3 = table.append_column(new_field, data[1]) + + expected = pa.Table.from_arrays(data + [data[1]], + names=('a', 'b', 'c', 'd')) + assert t2.equals(expected) + assert t3.equals(expected) + + t4 = table.add_column(0, new_field, data[1]) + expected = pa.Table.from_arrays([data[1]] + data, + names=('d', 'a', 'b', 'c')) + assert t4.equals(expected) + + +def test_table_set_column(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array(range(5, 10)) + ] + table = pa.Table.from_arrays(data, names=('a', 'b', 'c')) + + new_field = pa.field('d', data[1].type) + t2 = table.set_column(0, new_field, data[1]) + + expected_data = list(data) + expected_data[0] = data[1] + expected = pa.Table.from_arrays(expected_data, + names=('d', 'b', 'c')) + assert t2.equals(expected) + + +def test_table_drop(): + """ drop one or more columns given labels""" + a = pa.array(range(5)) + b = pa.array([-10, -5, 0, 5, 10]) + c = pa.array(range(5, 10)) + + table = pa.Table.from_arrays([a, b, c], names=('a', 'b', 'c')) + t2 = table.drop(['a', 'b']) + + exp = pa.Table.from_arrays([c], names=('c',)) + assert exp.equals(t2) + + # -- raise KeyError if column not in Table + with pytest.raises(KeyError, match="Column 'd' not found"): + table.drop(['d']) + + +def test_table_remove_column(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array(range(5, 10)) + ] + table = pa.Table.from_arrays(data, names=('a', 'b', 'c')) + + t2 = table.remove_column(0) + t2.validate() + expected = pa.Table.from_arrays(data[1:], names=('b', 'c')) + assert t2.equals(expected) + + +def test_table_remove_column_empty(): + # ARROW-1865 + data = [ + pa.array(range(5)), + ] + table = pa.Table.from_arrays(data, names=['a']) + + t2 = table.remove_column(0) + t2.validate() + assert len(t2) == len(table) + + t3 = t2.add_column(0, table.field(0), table[0]) + t3.validate() + assert t3.equals(table) + + +def test_empty_table_with_names(): + # ARROW-13784 + data = [] + names = ["a", "b"] + message = ( + 'Length of names [(]2[)] does not match length of arrays [(]0[)]') + with pytest.raises(ValueError, match=message): + pa.Table.from_arrays(data, names=names) + + +def test_empty_table(): + table = pa.table([]) + + assert table.column_names == [] + assert table.equals(pa.Table.from_arrays([], [])) + + +def test_table_rename_columns(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array(range(5, 10)) + ] + table = pa.Table.from_arrays(data, names=['a', 'b', 'c']) + assert table.column_names == ['a', 'b', 'c'] + + t2 = table.rename_columns(['eh', 'bee', 'sea']) + t2.validate() + assert t2.column_names == ['eh', 'bee', 'sea'] + + expected = pa.Table.from_arrays(data, names=['eh', 'bee', 'sea']) + assert t2.equals(expected) + + +def test_table_flatten(): + ty1 = pa.struct([pa.field('x', pa.int16()), + pa.field('y', pa.float32())]) + ty2 = pa.struct([pa.field('nest', ty1)]) + a = pa.array([(1, 2.5), (3, 4.5)], type=ty1) + b = pa.array([((11, 12.5),), ((13, 14.5),)], type=ty2) + c = pa.array([False, True], type=pa.bool_()) + + table = pa.Table.from_arrays([a, b, c], names=['a', 'b', 'c']) + t2 = table.flatten() + t2.validate() + expected = pa.Table.from_arrays([ + pa.array([1, 3], type=pa.int16()), + pa.array([2.5, 4.5], type=pa.float32()), + pa.array([(11, 12.5), (13, 14.5)], type=ty1), + c], + names=['a.x', 'a.y', 'b.nest', 'c']) + assert t2.equals(expected) + + +def test_table_combine_chunks(): + batch1 = pa.record_batch([pa.array([1]), pa.array(["a"])], + names=['f1', 'f2']) + batch2 = pa.record_batch([pa.array([2]), pa.array(["b"])], + names=['f1', 'f2']) + table = pa.Table.from_batches([batch1, batch2]) + combined = table.combine_chunks() + combined.validate() + assert combined.equals(table) + for c in combined.columns: + assert c.num_chunks == 1 + + +def test_table_unify_dictionaries(): + batch1 = pa.record_batch([ + pa.array(["foo", "bar", None, "foo"]).dictionary_encode(), + pa.array([123, 456, 456, 789]).dictionary_encode(), + pa.array([True, False, None, None])], names=['a', 'b', 'c']) + batch2 = pa.record_batch([ + pa.array(["quux", "foo", None, "quux"]).dictionary_encode(), + pa.array([456, 789, 789, None]).dictionary_encode(), + pa.array([False, None, None, True])], names=['a', 'b', 'c']) + + table = pa.Table.from_batches([batch1, batch2]) + table = table.replace_schema_metadata({b"key1": b"value1"}) + assert table.column(0).chunk(0).dictionary.equals( + pa.array(["foo", "bar"])) + assert table.column(0).chunk(1).dictionary.equals( + pa.array(["quux", "foo"])) + assert table.column(1).chunk(0).dictionary.equals( + pa.array([123, 456, 789])) + assert table.column(1).chunk(1).dictionary.equals( + pa.array([456, 789])) + + table = table.unify_dictionaries(pa.default_memory_pool()) + expected_dict_0 = pa.array(["foo", "bar", "quux"]) + expected_dict_1 = pa.array([123, 456, 789]) + assert table.column(0).chunk(0).dictionary.equals(expected_dict_0) + assert table.column(0).chunk(1).dictionary.equals(expected_dict_0) + assert table.column(1).chunk(0).dictionary.equals(expected_dict_1) + assert table.column(1).chunk(1).dictionary.equals(expected_dict_1) + + assert table.to_pydict() == { + 'a': ["foo", "bar", None, "foo", "quux", "foo", None, "quux"], + 'b': [123, 456, 456, 789, 456, 789, 789, None], + 'c': [True, False, None, None, False, None, None, True], + } + assert table.schema.metadata == {b"key1": b"value1"} + + +def test_concat_tables(): + data = [ + list(range(5)), + [-10., -5., 0., 5., 10.] + ] + data2 = [ + list(range(5, 10)), + [1., 2., 3., 4., 5.] + ] + + t1 = pa.Table.from_arrays([pa.array(x) for x in data], + names=('a', 'b')) + t2 = pa.Table.from_arrays([pa.array(x) for x in data2], + names=('a', 'b')) + + result = pa.concat_tables([t1, t2]) + result.validate() + assert len(result) == 10 + + expected = pa.Table.from_arrays([pa.array(x + y) + for x, y in zip(data, data2)], + names=('a', 'b')) + + assert result.equals(expected) + + +def test_concat_tables_none_table(): + # ARROW-11997 + with pytest.raises(AttributeError): + pa.concat_tables([None]) + + +@pytest.mark.pandas +def test_concat_tables_with_different_schema_metadata(): + import pandas as pd + + schema = pa.schema([ + pa.field('a', pa.string()), + pa.field('b', pa.string()), + ]) + + values = list('abcdefgh') + df1 = pd.DataFrame({'a': values, 'b': values}) + df2 = pd.DataFrame({'a': [np.nan] * 8, 'b': values}) + + table1 = pa.Table.from_pandas(df1, schema=schema, preserve_index=False) + table2 = pa.Table.from_pandas(df2, schema=schema, preserve_index=False) + assert table1.schema.equals(table2.schema) + assert not table1.schema.equals(table2.schema, check_metadata=True) + + table3 = pa.concat_tables([table1, table2]) + assert table1.schema.equals(table3.schema, check_metadata=True) + assert table2.schema.equals(table3.schema) + + +def test_concat_tables_with_promotion(): + t1 = pa.Table.from_arrays( + [pa.array([1, 2], type=pa.int64())], ["int64_field"]) + t2 = pa.Table.from_arrays( + [pa.array([1.0, 2.0], type=pa.float32())], ["float_field"]) + + result = pa.concat_tables([t1, t2], promote=True) + + assert result.equals(pa.Table.from_arrays([ + pa.array([1, 2, None, None], type=pa.int64()), + pa.array([None, None, 1.0, 2.0], type=pa.float32()), + ], ["int64_field", "float_field"])) + + +def test_concat_tables_with_promotion_error(): + t1 = pa.Table.from_arrays( + [pa.array([1, 2], type=pa.int64())], ["f"]) + t2 = pa.Table.from_arrays( + [pa.array([1, 2], type=pa.float32())], ["f"]) + + with pytest.raises(pa.ArrowInvalid): + pa.concat_tables([t1, t2], promote=True) + + +def test_table_negative_indexing(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + pa.array([1.0, 2.0, 3.0, 4.0, 5.0]), + pa.array(['ab', 'bc', 'cd', 'de', 'ef']), + ] + table = pa.Table.from_arrays(data, names=tuple('abcd')) + + assert table[-1].equals(table[3]) + assert table[-2].equals(table[2]) + assert table[-3].equals(table[1]) + assert table[-4].equals(table[0]) + + with pytest.raises(IndexError): + table[-5] + + with pytest.raises(IndexError): + table[4] + + +def test_table_cast_to_incompatible_schema(): + data = [ + pa.array(range(5)), + pa.array([-10, -5, 0, 5, 10]), + ] + table = pa.Table.from_arrays(data, names=tuple('ab')) + + target_schema1 = pa.schema([ + pa.field('A', pa.int32()), + pa.field('b', pa.int16()), + ]) + target_schema2 = pa.schema([ + pa.field('a', pa.int32()), + ]) + message = ("Target schema's field names are not matching the table's " + "field names:.*") + with pytest.raises(ValueError, match=message): + table.cast(target_schema1) + with pytest.raises(ValueError, match=message): + table.cast(target_schema2) + + +def test_table_safe_casting(): + data = [ + pa.array(range(5), type=pa.int64()), + pa.array([-10, -5, 0, 5, 10], type=pa.int32()), + pa.array([1.0, 2.0, 3.0, 4.0, 5.0], type=pa.float64()), + pa.array(['ab', 'bc', 'cd', 'de', 'ef'], type=pa.string()) + ] + table = pa.Table.from_arrays(data, names=tuple('abcd')) + + expected_data = [ + pa.array(range(5), type=pa.int32()), + pa.array([-10, -5, 0, 5, 10], type=pa.int16()), + pa.array([1, 2, 3, 4, 5], type=pa.int64()), + pa.array(['ab', 'bc', 'cd', 'de', 'ef'], type=pa.string()) + ] + expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd')) + + target_schema = pa.schema([ + pa.field('a', pa.int32()), + pa.field('b', pa.int16()), + pa.field('c', pa.int64()), + pa.field('d', pa.string()) + ]) + casted_table = table.cast(target_schema) + + assert casted_table.equals(expected_table) + + +def test_table_unsafe_casting(): + data = [ + pa.array(range(5), type=pa.int64()), + pa.array([-10, -5, 0, 5, 10], type=pa.int32()), + pa.array([1.1, 2.2, 3.3, 4.4, 5.5], type=pa.float64()), + pa.array(['ab', 'bc', 'cd', 'de', 'ef'], type=pa.string()) + ] + table = pa.Table.from_arrays(data, names=tuple('abcd')) + + expected_data = [ + pa.array(range(5), type=pa.int32()), + pa.array([-10, -5, 0, 5, 10], type=pa.int16()), + pa.array([1, 2, 3, 4, 5], type=pa.int64()), + pa.array(['ab', 'bc', 'cd', 'de', 'ef'], type=pa.string()) + ] + expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd')) + + target_schema = pa.schema([ + pa.field('a', pa.int32()), + pa.field('b', pa.int16()), + pa.field('c', pa.int64()), + pa.field('d', pa.string()) + ]) + + with pytest.raises(pa.ArrowInvalid, match='truncated'): + table.cast(target_schema) + + casted_table = table.cast(target_schema, safe=False) + assert casted_table.equals(expected_table) + + +def test_invalid_table_construct(): + array = np.array([0, 1], dtype=np.uint8) + u8 = pa.uint8() + arrays = [pa.array(array, type=u8), pa.array(array[1:], type=u8)] + + with pytest.raises(pa.lib.ArrowInvalid): + pa.Table.from_arrays(arrays, names=["a1", "a2"]) + + +@pytest.mark.parametrize('data, klass', [ + ((['', 'foo', 'bar'], [4.5, 5, None]), list), + ((['', 'foo', 'bar'], [4.5, 5, None]), pa.array), + (([[''], ['foo', 'bar']], [[4.5], [5., None]]), pa.chunked_array), +]) +def test_from_arrays_schema(data, klass): + data = [klass(data[0]), klass(data[1])] + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float32())]) + + table = pa.Table.from_arrays(data, schema=schema) + assert table.num_columns == 2 + assert table.num_rows == 3 + assert table.schema == schema + + # length of data and schema not matching + schema = pa.schema([('strs', pa.utf8())]) + with pytest.raises(ValueError): + pa.Table.from_arrays(data, schema=schema) + + # with different but compatible schema + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float32())]) + table = pa.Table.from_arrays(data, schema=schema) + assert pa.types.is_float32(table.column('floats').type) + assert table.num_columns == 2 + assert table.num_rows == 3 + assert table.schema == schema + + # with different and incompatible schema + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.timestamp('s'))]) + with pytest.raises((NotImplementedError, TypeError)): + pa.Table.from_pydict(data, schema=schema) + + # Cannot pass both schema and metadata / names + with pytest.raises(ValueError): + pa.Table.from_arrays(data, schema=schema, names=['strs', 'floats']) + + with pytest.raises(ValueError): + pa.Table.from_arrays(data, schema=schema, metadata={b'foo': b'bar'}) + + +@pytest.mark.parametrize( + ('cls'), + [ + (pa.Table), + (pa.RecordBatch) + ] +) +def test_table_from_pydict(cls): + table = cls.from_pydict({}) + assert table.num_columns == 0 + assert table.num_rows == 0 + assert table.schema == pa.schema([]) + assert table.to_pydict() == {} + + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64())]) + + # With lists as values + data = OrderedDict([('strs', ['', 'foo', 'bar']), + ('floats', [4.5, 5, None])]) + table = cls.from_pydict(data) + assert table.num_columns == 2 + assert table.num_rows == 3 + assert table.schema == schema + assert table.to_pydict() == data + + # With metadata and inferred schema + metadata = {b'foo': b'bar'} + schema = schema.with_metadata(metadata) + table = cls.from_pydict(data, metadata=metadata) + assert table.schema == schema + assert table.schema.metadata == metadata + assert table.to_pydict() == data + + # With explicit schema + table = cls.from_pydict(data, schema=schema) + assert table.schema == schema + assert table.schema.metadata == metadata + assert table.to_pydict() == data + + # Cannot pass both schema and metadata + with pytest.raises(ValueError): + cls.from_pydict(data, schema=schema, metadata=metadata) + + # Non-convertible values given schema + with pytest.raises(TypeError): + cls.from_pydict({'c0': [0, 1, 2]}, + schema=pa.schema([("c0", pa.string())])) + + # Missing schema fields from the passed mapping + with pytest.raises(KeyError, match="doesn\'t contain.* c, d"): + cls.from_pydict( + {'a': [1, 2, 3], 'b': [3, 4, 5]}, + schema=pa.schema([ + ('a', pa.int64()), + ('c', pa.int32()), + ('d', pa.int16()) + ]) + ) + + # Passed wrong schema type + with pytest.raises(TypeError): + cls.from_pydict({'a': [1, 2, 3]}, schema={}) + + +@pytest.mark.parametrize('data, klass', [ + ((['', 'foo', 'bar'], [4.5, 5, None]), pa.array), + (([[''], ['foo', 'bar']], [[4.5], [5., None]]), pa.chunked_array), +]) +def test_table_from_pydict_arrow_arrays(data, klass): + data = OrderedDict([('strs', klass(data[0])), ('floats', klass(data[1]))]) + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64())]) + + # With arrays as values + table = pa.Table.from_pydict(data) + assert table.num_columns == 2 + assert table.num_rows == 3 + assert table.schema == schema + + # With explicit (matching) schema + table = pa.Table.from_pydict(data, schema=schema) + assert table.num_columns == 2 + assert table.num_rows == 3 + assert table.schema == schema + + # with different but compatible schema + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float32())]) + table = pa.Table.from_pydict(data, schema=schema) + assert pa.types.is_float32(table.column('floats').type) + assert table.num_columns == 2 + assert table.num_rows == 3 + assert table.schema == schema + + # with different and incompatible schema + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.timestamp('s'))]) + with pytest.raises((NotImplementedError, TypeError)): + pa.Table.from_pydict(data, schema=schema) + + +@pytest.mark.parametrize('data, klass', [ + ((['', 'foo', 'bar'], [4.5, 5, None]), list), + ((['', 'foo', 'bar'], [4.5, 5, None]), pa.array), + (([[''], ['foo', 'bar']], [[4.5], [5., None]]), pa.chunked_array), +]) +def test_table_from_pydict_schema(data, klass): + # passed schema is source of truth for the columns + + data = OrderedDict([('strs', klass(data[0])), ('floats', klass(data[1]))]) + + # schema has columns not present in data -> error + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64()), + ('ints', pa.int64())]) + with pytest.raises(KeyError, match='ints'): + pa.Table.from_pydict(data, schema=schema) + + # data has columns not present in schema -> ignored + schema = pa.schema([('strs', pa.utf8())]) + table = pa.Table.from_pydict(data, schema=schema) + assert table.num_columns == 1 + assert table.schema == schema + assert table.column_names == ['strs'] + + +@pytest.mark.pandas +def test_table_from_pandas_schema(): + # passed schema is source of truth for the columns + import pandas as pd + + df = pd.DataFrame(OrderedDict([('strs', ['', 'foo', 'bar']), + ('floats', [4.5, 5, None])])) + + # with different but compatible schema + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float32())]) + table = pa.Table.from_pandas(df, schema=schema) + assert pa.types.is_float32(table.column('floats').type) + assert table.schema.remove_metadata() == schema + + # with different and incompatible schema + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.timestamp('s'))]) + with pytest.raises((NotImplementedError, TypeError)): + pa.Table.from_pandas(df, schema=schema) + + # schema has columns not present in data -> error + schema = pa.schema([('strs', pa.utf8()), ('floats', pa.float64()), + ('ints', pa.int64())]) + with pytest.raises(KeyError, match='ints'): + pa.Table.from_pandas(df, schema=schema) + + # data has columns not present in schema -> ignored + schema = pa.schema([('strs', pa.utf8())]) + table = pa.Table.from_pandas(df, schema=schema) + assert table.num_columns == 1 + assert table.schema.remove_metadata() == schema + assert table.column_names == ['strs'] + + +@pytest.mark.pandas +def test_table_factory_function(): + import pandas as pd + + # Put in wrong order to make sure that lines up with schema + d = OrderedDict([('b', ['a', 'b', 'c']), ('a', [1, 2, 3])]) + + d_explicit = {'b': pa.array(['a', 'b', 'c'], type='string'), + 'a': pa.array([1, 2, 3], type='int32')} + + schema = pa.schema([('a', pa.int32()), ('b', pa.string())]) + + df = pd.DataFrame(d) + table1 = pa.table(df) + table2 = pa.Table.from_pandas(df) + assert table1.equals(table2) + table1 = pa.table(df, schema=schema) + table2 = pa.Table.from_pandas(df, schema=schema) + assert table1.equals(table2) + + table1 = pa.table(d_explicit) + table2 = pa.Table.from_pydict(d_explicit) + assert table1.equals(table2) + + # schema coerces type + table1 = pa.table(d, schema=schema) + table2 = pa.Table.from_pydict(d, schema=schema) + assert table1.equals(table2) + + +def test_table_factory_function_args(): + # from_pydict not accepting names: + with pytest.raises(ValueError): + pa.table({'a': [1, 2, 3]}, names=['a']) + + # backwards compatibility for schema as first positional argument + schema = pa.schema([('a', pa.int32())]) + table = pa.table({'a': pa.array([1, 2, 3], type=pa.int64())}, schema) + assert table.column('a').type == pa.int32() + + # from_arrays: accept both names and schema as positional first argument + data = [pa.array([1, 2, 3], type='int64')] + names = ['a'] + table = pa.table(data, names) + assert table.column_names == names + schema = pa.schema([('a', pa.int64())]) + table = pa.table(data, schema) + assert table.column_names == names + + +@pytest.mark.pandas +def test_table_factory_function_args_pandas(): + import pandas as pd + + # from_pandas not accepting names or metadata: + with pytest.raises(ValueError): + pa.table(pd.DataFrame({'a': [1, 2, 3]}), names=['a']) + + with pytest.raises(ValueError): + pa.table(pd.DataFrame({'a': [1, 2, 3]}), metadata={b'foo': b'bar'}) + + # backwards compatibility for schema as first positional argument + schema = pa.schema([('a', pa.int32())]) + table = pa.table(pd.DataFrame({'a': [1, 2, 3]}), schema) + assert table.column('a').type == pa.int32() + + +def test_factory_functions_invalid_input(): + with pytest.raises(TypeError, match="Expected pandas DataFrame, python"): + pa.table("invalid input") + + with pytest.raises(TypeError, match="Expected pandas DataFrame"): + pa.record_batch("invalid input") + + +def test_table_repr_to_string(): + # Schema passed explicitly + schema = pa.schema([pa.field('c0', pa.int16(), + metadata={'key': 'value'}), + pa.field('c1', pa.int32())], + metadata={b'foo': b'bar'}) + + tab = pa.table([pa.array([1, 2, 3, 4], type='int16'), + pa.array([10, 20, 30, 40], type='int32')], schema=schema) + assert str(tab) == """pyarrow.Table +c0: int16 +c1: int32 +---- +c0: [[1,2,3,4]] +c1: [[10,20,30,40]]""" + + assert tab.to_string(show_metadata=True) == """\ +pyarrow.Table +c0: int16 + -- field metadata -- + key: 'value' +c1: int32 +-- schema metadata -- +foo: 'bar'""" + + assert tab.to_string(preview_cols=5) == """\ +pyarrow.Table +c0: int16 +c1: int32 +---- +c0: [[1,2,3,4]] +c1: [[10,20,30,40]]""" + + assert tab.to_string(preview_cols=1) == """\ +pyarrow.Table +c0: int16 +c1: int32 +---- +c0: [[1,2,3,4]] +...""" + + +def test_table_repr_to_string_ellipsis(): + # Schema passed explicitly + schema = pa.schema([pa.field('c0', pa.int16(), + metadata={'key': 'value'}), + pa.field('c1', pa.int32())], + metadata={b'foo': b'bar'}) + + tab = pa.table([pa.array([1, 2, 3, 4]*10, type='int16'), + pa.array([10, 20, 30, 40]*10, type='int32')], + schema=schema) + assert str(tab) == """pyarrow.Table +c0: int16 +c1: int32 +---- +c0: [[1,2,3,4,1,2,3,4,1,2,...,3,4,1,2,3,4,1,2,3,4]] +c1: [[10,20,30,40,10,20,30,40,10,20,...,30,40,10,20,30,40,10,20,30,40]]""" + + +def test_table_function_unicode_schema(): + col_a = "äääh" + col_b = "öööf" + + # Put in wrong order to make sure that lines up with schema + d = OrderedDict([(col_b, ['a', 'b', 'c']), (col_a, [1, 2, 3])]) + + schema = pa.schema([(col_a, pa.int32()), (col_b, pa.string())]) + + result = pa.table(d, schema=schema) + assert result[0].chunk(0).equals(pa.array([1, 2, 3], type='int32')) + assert result[1].chunk(0).equals(pa.array(['a', 'b', 'c'], type='string')) + + +def test_table_take_vanilla_functionality(): + table = pa.table( + [pa.array([1, 2, 3, None, 5]), + pa.array(['a', 'b', 'c', 'd', 'e'])], + ['f1', 'f2']) + + assert table.take(pa.array([2, 3])).equals(table.slice(2, 2)) + + +def test_table_take_null_index(): + table = pa.table( + [pa.array([1, 2, 3, None, 5]), + pa.array(['a', 'b', 'c', 'd', 'e'])], + ['f1', 'f2']) + + result_with_null_index = pa.table( + [pa.array([1, None]), + pa.array(['a', None])], + ['f1', 'f2']) + + assert table.take(pa.array([0, None])).equals(result_with_null_index) + + +def test_table_take_non_consecutive(): + table = pa.table( + [pa.array([1, 2, 3, None, 5]), + pa.array(['a', 'b', 'c', 'd', 'e'])], + ['f1', 'f2']) + + result_non_consecutive = pa.table( + [pa.array([2, None]), + pa.array(['b', 'd'])], + ['f1', 'f2']) + + assert table.take(pa.array([1, 3])).equals(result_non_consecutive) + + +def test_table_select(): + a1 = pa.array([1, 2, 3, None, 5]) + a2 = pa.array(['a', 'b', 'c', 'd', 'e']) + a3 = pa.array([[1, 2], [3, 4], [5, 6], None, [9, 10]]) + table = pa.table([a1, a2, a3], ['f1', 'f2', 'f3']) + + # selecting with string names + result = table.select(['f1']) + expected = pa.table([a1], ['f1']) + assert result.equals(expected) + + result = table.select(['f3', 'f2']) + expected = pa.table([a3, a2], ['f3', 'f2']) + assert result.equals(expected) + + # selecting with integer indices + result = table.select([0]) + expected = pa.table([a1], ['f1']) + assert result.equals(expected) + + result = table.select([2, 1]) + expected = pa.table([a3, a2], ['f3', 'f2']) + assert result.equals(expected) + + # preserve metadata + table2 = table.replace_schema_metadata({"a": "test"}) + result = table2.select(["f1", "f2"]) + assert b"a" in result.schema.metadata + + # selecting non-existing column raises + with pytest.raises(KeyError, match='Field "f5" does not exist'): + table.select(['f5']) + + with pytest.raises(IndexError, match="index out of bounds"): + table.select([5]) + + # duplicate selection gives duplicated names in resulting table + result = table.select(['f2', 'f2']) + expected = pa.table([a2, a2], ['f2', 'f2']) + assert result.equals(expected) + + # selection duplicated column raises + table = pa.table([a1, a2, a3], ['f1', 'f2', 'f1']) + with pytest.raises(KeyError, match='Field "f1" exists 2 times'): + table.select(['f1']) + + result = table.select(['f2']) + expected = pa.table([a2], ['f2']) + assert result.equals(expected) diff --git a/src/arrow/python/pyarrow/tests/test_tensor.py b/src/arrow/python/pyarrow/tests/test_tensor.py new file mode 100644 index 000000000..aee46bc93 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_tensor.py @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys +import pytest +import weakref + +import numpy as np +import pyarrow as pa + + +tensor_type_pairs = [ + ('i1', pa.int8()), + ('i2', pa.int16()), + ('i4', pa.int32()), + ('i8', pa.int64()), + ('u1', pa.uint8()), + ('u2', pa.uint16()), + ('u4', pa.uint32()), + ('u8', pa.uint64()), + ('f2', pa.float16()), + ('f4', pa.float32()), + ('f8', pa.float64()) +] + + +def test_tensor_attrs(): + data = np.random.randn(10, 4) + + tensor = pa.Tensor.from_numpy(data) + + assert tensor.ndim == 2 + assert tensor.dim_names == [] + assert tensor.size == 40 + assert tensor.shape == data.shape + assert tensor.strides == data.strides + + assert tensor.is_contiguous + assert tensor.is_mutable + + # not writeable + data2 = data.copy() + data2.flags.writeable = False + tensor = pa.Tensor.from_numpy(data2) + assert not tensor.is_mutable + + # With dim_names + tensor = pa.Tensor.from_numpy(data, dim_names=('x', 'y')) + assert tensor.ndim == 2 + assert tensor.dim_names == ['x', 'y'] + assert tensor.dim_name(0) == 'x' + assert tensor.dim_name(1) == 'y' + + wr = weakref.ref(tensor) + assert wr() is not None + del tensor + assert wr() is None + + +def test_tensor_base_object(): + tensor = pa.Tensor.from_numpy(np.random.randn(10, 4)) + n = sys.getrefcount(tensor) + array = tensor.to_numpy() # noqa + assert sys.getrefcount(tensor) == n + 1 + + +@pytest.mark.parametrize('dtype_str,arrow_type', tensor_type_pairs) +def test_tensor_numpy_roundtrip(dtype_str, arrow_type): + dtype = np.dtype(dtype_str) + data = (100 * np.random.randn(10, 4)).astype(dtype) + + tensor = pa.Tensor.from_numpy(data) + assert tensor.type == arrow_type + + repr(tensor) + + result = tensor.to_numpy() + assert (data == result).all() + + +def test_tensor_ipc_roundtrip(tmpdir): + data = np.random.randn(10, 4) + tensor = pa.Tensor.from_numpy(data) + + path = os.path.join(str(tmpdir), 'pyarrow-tensor-ipc-roundtrip') + mmap = pa.create_memory_map(path, 1024) + + pa.ipc.write_tensor(tensor, mmap) + + mmap.seek(0) + result = pa.ipc.read_tensor(mmap) + + assert result.equals(tensor) + + +@pytest.mark.gzip +def test_tensor_ipc_read_from_compressed(tempdir): + # ARROW-5910 + data = np.random.randn(10, 4) + tensor = pa.Tensor.from_numpy(data) + + path = tempdir / 'tensor-compressed-file' + + out_stream = pa.output_stream(path, compression='gzip') + pa.ipc.write_tensor(tensor, out_stream) + out_stream.close() + + result = pa.ipc.read_tensor(pa.input_stream(path, compression='gzip')) + assert result.equals(tensor) + + +def test_tensor_ipc_strided(tmpdir): + data1 = np.random.randn(10, 4) + tensor1 = pa.Tensor.from_numpy(data1[::2]) + + data2 = np.random.randn(10, 6, 4) + tensor2 = pa.Tensor.from_numpy(data2[::, ::2, ::]) + + path = os.path.join(str(tmpdir), 'pyarrow-tensor-ipc-strided') + mmap = pa.create_memory_map(path, 2048) + + for tensor in [tensor1, tensor2]: + mmap.seek(0) + pa.ipc.write_tensor(tensor, mmap) + + mmap.seek(0) + result = pa.ipc.read_tensor(mmap) + + assert result.equals(tensor) + + +def test_tensor_equals(): + def eq(a, b): + assert a.equals(b) + assert a == b + assert not (a != b) + + def ne(a, b): + assert not a.equals(b) + assert not (a == b) + assert a != b + + data = np.random.randn(10, 6, 4)[::, ::2, ::] + tensor1 = pa.Tensor.from_numpy(data) + tensor2 = pa.Tensor.from_numpy(np.ascontiguousarray(data)) + eq(tensor1, tensor2) + data = data.copy() + data[9, 0, 0] = 1.0 + tensor2 = pa.Tensor.from_numpy(np.ascontiguousarray(data)) + ne(tensor1, tensor2) + + +def test_tensor_hashing(): + # Tensors are unhashable + with pytest.raises(TypeError, match="unhashable"): + hash(pa.Tensor.from_numpy(np.arange(10))) + + +def test_tensor_size(): + data = np.random.randn(10, 4) + tensor = pa.Tensor.from_numpy(data) + assert pa.ipc.get_tensor_size(tensor) > (data.size * 8) + + +def test_read_tensor(tmpdir): + # Create and write tensor tensor + data = np.random.randn(10, 4) + tensor = pa.Tensor.from_numpy(data) + data_size = pa.ipc.get_tensor_size(tensor) + path = os.path.join(str(tmpdir), 'pyarrow-tensor-ipc-read-tensor') + write_mmap = pa.create_memory_map(path, data_size) + pa.ipc.write_tensor(tensor, write_mmap) + # Try to read tensor + read_mmap = pa.memory_map(path, mode='r') + array = pa.ipc.read_tensor(read_mmap).to_numpy() + np.testing.assert_equal(data, array) + + +def test_tensor_memoryview(): + # Tensors support the PEP 3118 buffer protocol + for dtype, expected_format in [(np.int8, '=b'), + (np.int64, '=q'), + (np.uint64, '=Q'), + (np.float16, 'e'), + (np.float64, 'd'), + ]: + data = np.arange(10, dtype=dtype) + dtype = data.dtype + lst = data.tolist() + tensor = pa.Tensor.from_numpy(data) + m = memoryview(tensor) + assert m.format == expected_format + assert m.shape == data.shape + assert m.strides == data.strides + assert m.ndim == 1 + assert m.nbytes == data.nbytes + assert m.itemsize == data.itemsize + assert m.itemsize * 8 == tensor.type.bit_width + assert np.frombuffer(m, dtype).tolist() == lst + del tensor, data + assert np.frombuffer(m, dtype).tolist() == lst diff --git a/src/arrow/python/pyarrow/tests/test_types.py b/src/arrow/python/pyarrow/tests/test_types.py new file mode 100644 index 000000000..07715b985 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_types.py @@ -0,0 +1,1067 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import OrderedDict +from collections.abc import Iterator +from functools import partial +import datetime +import sys + +import pickle +import pytest +import pytz +import hypothesis as h +import hypothesis.strategies as st +import hypothesis.extra.pytz as tzst +import weakref + +import numpy as np +import pyarrow as pa +import pyarrow.types as types +import pyarrow.tests.strategies as past + + +def get_many_types(): + # returning them from a function is required because of pa.dictionary + # type holds a pyarrow array and test_array.py::test_toal_bytes_allocated + # checks that the default memory pool has zero allocated bytes + return ( + pa.null(), + pa.bool_(), + pa.int32(), + pa.time32('s'), + pa.time64('us'), + pa.date32(), + pa.timestamp('us'), + pa.timestamp('us', tz='UTC'), + pa.timestamp('us', tz='Europe/Paris'), + pa.duration('s'), + pa.float16(), + pa.float32(), + pa.float64(), + pa.decimal128(19, 4), + pa.decimal256(76, 38), + pa.string(), + pa.binary(), + pa.binary(10), + pa.large_string(), + pa.large_binary(), + pa.list_(pa.int32()), + pa.list_(pa.int32(), 2), + pa.large_list(pa.uint16()), + pa.map_(pa.string(), pa.int32()), + pa.map_(pa.field('key', pa.int32(), nullable=False), + pa.field('value', pa.int32())), + pa.struct([pa.field('a', pa.int32()), + pa.field('b', pa.int8()), + pa.field('c', pa.string())]), + pa.struct([pa.field('a', pa.int32(), nullable=False), + pa.field('b', pa.int8(), nullable=False), + pa.field('c', pa.string())]), + pa.union([pa.field('a', pa.binary(10)), + pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE), + pa.union([pa.field('a', pa.binary(10)), + pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE, + type_codes=[4, 8]), + pa.union([pa.field('a', pa.binary(10)), + pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE), + pa.union([pa.field('a', pa.binary(10), nullable=False), + pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE), + pa.dictionary(pa.int32(), pa.string()) + ) + + +def test_is_boolean(): + assert types.is_boolean(pa.bool_()) + assert not types.is_boolean(pa.int8()) + + +def test_is_integer(): + signed_ints = [pa.int8(), pa.int16(), pa.int32(), pa.int64()] + unsigned_ints = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()] + + for t in signed_ints + unsigned_ints: + assert types.is_integer(t) + + for t in signed_ints: + assert types.is_signed_integer(t) + assert not types.is_unsigned_integer(t) + + for t in unsigned_ints: + assert types.is_unsigned_integer(t) + assert not types.is_signed_integer(t) + + assert not types.is_integer(pa.float32()) + assert not types.is_signed_integer(pa.float32()) + + +def test_is_floating(): + for t in [pa.float16(), pa.float32(), pa.float64()]: + assert types.is_floating(t) + + assert not types.is_floating(pa.int32()) + + +def test_is_null(): + assert types.is_null(pa.null()) + assert not types.is_null(pa.list_(pa.int32())) + + +def test_null_field_may_not_be_non_nullable(): + # ARROW-7273 + with pytest.raises(ValueError): + pa.field('f0', pa.null(), nullable=False) + + +def test_is_decimal(): + decimal128 = pa.decimal128(19, 4) + decimal256 = pa.decimal256(76, 38) + int32 = pa.int32() + + assert types.is_decimal(decimal128) + assert types.is_decimal(decimal256) + assert not types.is_decimal(int32) + + assert types.is_decimal128(decimal128) + assert not types.is_decimal128(decimal256) + assert not types.is_decimal128(int32) + + assert not types.is_decimal256(decimal128) + assert types.is_decimal256(decimal256) + assert not types.is_decimal256(int32) + + +def test_is_list(): + a = pa.list_(pa.int32()) + b = pa.large_list(pa.int32()) + c = pa.list_(pa.int32(), 3) + + assert types.is_list(a) + assert not types.is_large_list(a) + assert not types.is_fixed_size_list(a) + assert types.is_large_list(b) + assert not types.is_list(b) + assert not types.is_fixed_size_list(b) + assert types.is_fixed_size_list(c) + assert not types.is_list(c) + assert not types.is_large_list(c) + + assert not types.is_list(pa.int32()) + + +def test_is_map(): + m = pa.map_(pa.utf8(), pa.int32()) + + assert types.is_map(m) + assert not types.is_map(pa.int32()) + + fields = pa.map_(pa.field('key_name', pa.utf8(), nullable=False), + pa.field('value_name', pa.int32())) + assert types.is_map(fields) + + entries_type = pa.struct([pa.field('key', pa.int8()), + pa.field('value', pa.int8())]) + list_type = pa.list_(entries_type) + assert not types.is_map(list_type) + + +def test_is_dictionary(): + assert types.is_dictionary(pa.dictionary(pa.int32(), pa.string())) + assert not types.is_dictionary(pa.int32()) + + +def test_is_nested_or_struct(): + struct_ex = pa.struct([pa.field('a', pa.int32()), + pa.field('b', pa.int8()), + pa.field('c', pa.string())]) + + assert types.is_struct(struct_ex) + assert not types.is_struct(pa.list_(pa.int32())) + + assert types.is_nested(struct_ex) + assert types.is_nested(pa.list_(pa.int32())) + assert types.is_nested(pa.large_list(pa.int32())) + assert not types.is_nested(pa.int32()) + + +def test_is_union(): + for mode in [pa.lib.UnionMode_SPARSE, pa.lib.UnionMode_DENSE]: + assert types.is_union(pa.union([pa.field('a', pa.int32()), + pa.field('b', pa.int8()), + pa.field('c', pa.string())], + mode=mode)) + assert not types.is_union(pa.list_(pa.int32())) + + +# TODO(wesm): is_map, once implemented + + +def test_is_binary_string(): + assert types.is_binary(pa.binary()) + assert not types.is_binary(pa.string()) + assert not types.is_binary(pa.large_binary()) + assert not types.is_binary(pa.large_string()) + + assert types.is_string(pa.string()) + assert types.is_unicode(pa.string()) + assert not types.is_string(pa.binary()) + assert not types.is_string(pa.large_string()) + assert not types.is_string(pa.large_binary()) + + assert types.is_large_binary(pa.large_binary()) + assert not types.is_large_binary(pa.large_string()) + assert not types.is_large_binary(pa.binary()) + assert not types.is_large_binary(pa.string()) + + assert types.is_large_string(pa.large_string()) + assert not types.is_large_string(pa.large_binary()) + assert not types.is_large_string(pa.string()) + assert not types.is_large_string(pa.binary()) + + assert types.is_fixed_size_binary(pa.binary(5)) + assert not types.is_fixed_size_binary(pa.binary()) + + +def test_is_temporal_date_time_timestamp(): + date_types = [pa.date32(), pa.date64()] + time_types = [pa.time32('s'), pa.time64('ns')] + timestamp_types = [pa.timestamp('ms')] + duration_types = [pa.duration('ms')] + interval_types = [pa.month_day_nano_interval()] + + for case in (date_types + time_types + timestamp_types + duration_types + + interval_types): + assert types.is_temporal(case) + + for case in date_types: + assert types.is_date(case) + assert not types.is_time(case) + assert not types.is_timestamp(case) + assert not types.is_duration(case) + assert not types.is_interval(case) + + for case in time_types: + assert types.is_time(case) + assert not types.is_date(case) + assert not types.is_timestamp(case) + assert not types.is_duration(case) + assert not types.is_interval(case) + + for case in timestamp_types: + assert types.is_timestamp(case) + assert not types.is_date(case) + assert not types.is_time(case) + assert not types.is_duration(case) + assert not types.is_interval(case) + + for case in duration_types: + assert types.is_duration(case) + assert not types.is_date(case) + assert not types.is_time(case) + assert not types.is_timestamp(case) + assert not types.is_interval(case) + + for case in interval_types: + assert types.is_interval(case) + assert not types.is_date(case) + assert not types.is_time(case) + assert not types.is_timestamp(case) + + assert not types.is_temporal(pa.int32()) + + +def test_is_primitive(): + assert types.is_primitive(pa.int32()) + assert not types.is_primitive(pa.list_(pa.int32())) + + +@pytest.mark.parametrize(('tz', 'expected'), [ + (pytz.utc, 'UTC'), + (pytz.timezone('Europe/Paris'), 'Europe/Paris'), + # StaticTzInfo.tzname returns with '-09' so we need to infer the timezone's + # name from the tzinfo.zone attribute + (pytz.timezone('Etc/GMT-9'), 'Etc/GMT-9'), + (pytz.FixedOffset(180), '+03:00'), + (datetime.timezone.utc, 'UTC'), + (datetime.timezone(datetime.timedelta(hours=1, minutes=30)), '+01:30') +]) +def test_tzinfo_to_string(tz, expected): + assert pa.lib.tzinfo_to_string(tz) == expected + + +def test_tzinfo_to_string_errors(): + msg = "Not an instance of datetime.tzinfo" + with pytest.raises(TypeError): + pa.lib.tzinfo_to_string("Europe/Budapest") + + if sys.version_info >= (3, 8): + # before 3.8 it was only possible to create timezone objects with whole + # number of minutes + tz = datetime.timezone(datetime.timedelta(hours=1, seconds=30)) + msg = "Offset must represent whole number of minutes" + with pytest.raises(ValueError, match=msg): + pa.lib.tzinfo_to_string(tz) + + +@h.given(tzst.timezones()) +def test_pytz_timezone_roundtrip(tz): + timezone_string = pa.lib.tzinfo_to_string(tz) + timezone_tzinfo = pa.lib.string_to_tzinfo(timezone_string) + assert timezone_tzinfo == tz + + +def test_convert_custom_tzinfo_objects_to_string(): + class CorrectTimezone1(datetime.tzinfo): + """ + Conversion is using utcoffset() + """ + + def tzname(self, dt): + return None + + def utcoffset(self, dt): + return datetime.timedelta(hours=-3, minutes=30) + + class CorrectTimezone2(datetime.tzinfo): + """ + Conversion is using tzname() + """ + + def tzname(self, dt): + return "+03:00" + + def utcoffset(self, dt): + return datetime.timedelta(hours=3) + + class BuggyTimezone1(datetime.tzinfo): + """ + Unable to infer name or offset + """ + + def tzname(self, dt): + return None + + def utcoffset(self, dt): + return None + + class BuggyTimezone2(datetime.tzinfo): + """ + Wrong offset type + """ + + def tzname(self, dt): + return None + + def utcoffset(self, dt): + return "one hour" + + class BuggyTimezone3(datetime.tzinfo): + """ + Wrong timezone name type + """ + + def tzname(self, dt): + return 240 + + def utcoffset(self, dt): + return None + + assert pa.lib.tzinfo_to_string(CorrectTimezone1()) == "-02:30" + assert pa.lib.tzinfo_to_string(CorrectTimezone2()) == "+03:00" + + msg = (r"Object returned by tzinfo.utcoffset\(None\) is not an instance " + r"of datetime.timedelta") + for wrong in [BuggyTimezone1(), BuggyTimezone2(), BuggyTimezone3()]: + with pytest.raises(ValueError, match=msg): + pa.lib.tzinfo_to_string(wrong) + + +@pytest.mark.parametrize(('string', 'expected'), [ + ('UTC', pytz.utc), + ('Europe/Paris', pytz.timezone('Europe/Paris')), + ('+03:00', pytz.FixedOffset(180)), + ('+01:30', pytz.FixedOffset(90)), + ('-02:00', pytz.FixedOffset(-120)) +]) +def test_string_to_tzinfo(string, expected): + result = pa.lib.string_to_tzinfo(string) + assert result == expected + + +@pytest.mark.parametrize('tz,name', [ + (pytz.FixedOffset(90), '+01:30'), + (pytz.FixedOffset(-90), '-01:30'), + (pytz.utc, 'UTC'), + (pytz.timezone('America/New_York'), 'America/New_York') +]) +def test_timezone_string_roundtrip(tz, name): + assert pa.lib.tzinfo_to_string(tz) == name + assert pa.lib.string_to_tzinfo(name) == tz + + +def test_timestamp(): + for unit in ('s', 'ms', 'us', 'ns'): + for tz in (None, 'UTC', 'Europe/Paris'): + ty = pa.timestamp(unit, tz=tz) + assert ty.unit == unit + assert ty.tz == tz + + for invalid_unit in ('m', 'arbit', 'rary'): + with pytest.raises(ValueError, match='Invalid time unit'): + pa.timestamp(invalid_unit) + + +def test_time32_units(): + for valid_unit in ('s', 'ms'): + ty = pa.time32(valid_unit) + assert ty.unit == valid_unit + + for invalid_unit in ('m', 'us', 'ns'): + error_msg = 'Invalid time unit for time32: {!r}'.format(invalid_unit) + with pytest.raises(ValueError, match=error_msg): + pa.time32(invalid_unit) + + +def test_time64_units(): + for valid_unit in ('us', 'ns'): + ty = pa.time64(valid_unit) + assert ty.unit == valid_unit + + for invalid_unit in ('m', 's', 'ms'): + error_msg = 'Invalid time unit for time64: {!r}'.format(invalid_unit) + with pytest.raises(ValueError, match=error_msg): + pa.time64(invalid_unit) + + +def test_duration(): + for unit in ('s', 'ms', 'us', 'ns'): + ty = pa.duration(unit) + assert ty.unit == unit + + for invalid_unit in ('m', 'arbit', 'rary'): + with pytest.raises(ValueError, match='Invalid time unit'): + pa.duration(invalid_unit) + + +def test_list_type(): + ty = pa.list_(pa.int64()) + assert isinstance(ty, pa.ListType) + assert ty.value_type == pa.int64() + assert ty.value_field == pa.field("item", pa.int64(), nullable=True) + + with pytest.raises(TypeError): + pa.list_(None) + + +def test_large_list_type(): + ty = pa.large_list(pa.utf8()) + assert isinstance(ty, pa.LargeListType) + assert ty.value_type == pa.utf8() + assert ty.value_field == pa.field("item", pa.utf8(), nullable=True) + + with pytest.raises(TypeError): + pa.large_list(None) + + +def test_map_type(): + ty = pa.map_(pa.utf8(), pa.int32()) + assert isinstance(ty, pa.MapType) + assert ty.key_type == pa.utf8() + assert ty.key_field == pa.field("key", pa.utf8(), nullable=False) + assert ty.item_type == pa.int32() + assert ty.item_field == pa.field("value", pa.int32(), nullable=True) + + with pytest.raises(TypeError): + pa.map_(None) + with pytest.raises(TypeError): + pa.map_(pa.int32(), None) + with pytest.raises(TypeError): + pa.map_(pa.field("name", pa.string(), nullable=True), pa.int64()) + + +def test_fixed_size_list_type(): + ty = pa.list_(pa.float64(), 2) + assert isinstance(ty, pa.FixedSizeListType) + assert ty.value_type == pa.float64() + assert ty.value_field == pa.field("item", pa.float64(), nullable=True) + assert ty.list_size == 2 + + with pytest.raises(ValueError): + pa.list_(pa.float64(), -2) + + +def test_struct_type(): + fields = [ + # Duplicate field name on purpose + pa.field('a', pa.int64()), + pa.field('a', pa.int32()), + pa.field('b', pa.int32()) + ] + ty = pa.struct(fields) + + assert len(ty) == ty.num_fields == 3 + assert list(ty) == fields + assert ty[0].name == 'a' + assert ty[2].type == pa.int32() + with pytest.raises(IndexError): + assert ty[3] + + assert ty['b'] == ty[2] + + # Not found + with pytest.raises(KeyError): + ty['c'] + + # Neither integer nor string + with pytest.raises(TypeError): + ty[None] + + for a, b in zip(ty, fields): + a == b + + # Construct from list of tuples + ty = pa.struct([('a', pa.int64()), + ('a', pa.int32()), + ('b', pa.int32())]) + assert list(ty) == fields + for a, b in zip(ty, fields): + a == b + + # Construct from mapping + fields = [pa.field('a', pa.int64()), + pa.field('b', pa.int32())] + ty = pa.struct(OrderedDict([('a', pa.int64()), + ('b', pa.int32())])) + assert list(ty) == fields + for a, b in zip(ty, fields): + a == b + + # Invalid args + with pytest.raises(TypeError): + pa.struct([('a', None)]) + + +def test_struct_duplicate_field_names(): + fields = [ + pa.field('a', pa.int64()), + pa.field('b', pa.int32()), + pa.field('a', pa.int32()) + ] + ty = pa.struct(fields) + + # Duplicate + with pytest.warns(UserWarning): + with pytest.raises(KeyError): + ty['a'] + + # StructType::GetFieldIndex + assert ty.get_field_index('a') == -1 + + # StructType::GetAllFieldIndices + assert ty.get_all_field_indices('a') == [0, 2] + + +def test_union_type(): + def check_fields(ty, fields): + assert ty.num_fields == len(fields) + assert [ty[i] for i in range(ty.num_fields)] == fields + + fields = [pa.field('x', pa.list_(pa.int32())), + pa.field('y', pa.binary())] + type_codes = [5, 9] + + sparse_factories = [ + partial(pa.union, mode='sparse'), + partial(pa.union, mode=pa.lib.UnionMode_SPARSE), + pa.sparse_union, + ] + + dense_factories = [ + partial(pa.union, mode='dense'), + partial(pa.union, mode=pa.lib.UnionMode_DENSE), + pa.dense_union, + ] + + for factory in sparse_factories: + ty = factory(fields) + assert isinstance(ty, pa.SparseUnionType) + assert ty.mode == 'sparse' + check_fields(ty, fields) + assert ty.type_codes == [0, 1] + ty = factory(fields, type_codes=type_codes) + assert ty.mode == 'sparse' + check_fields(ty, fields) + assert ty.type_codes == type_codes + # Invalid number of type codes + with pytest.raises(ValueError): + factory(fields, type_codes=type_codes[1:]) + + for factory in dense_factories: + ty = factory(fields) + assert isinstance(ty, pa.DenseUnionType) + assert ty.mode == 'dense' + check_fields(ty, fields) + assert ty.type_codes == [0, 1] + ty = factory(fields, type_codes=type_codes) + assert ty.mode == 'dense' + check_fields(ty, fields) + assert ty.type_codes == type_codes + # Invalid number of type codes + with pytest.raises(ValueError): + factory(fields, type_codes=type_codes[1:]) + + for mode in ('unknown', 2): + with pytest.raises(ValueError, match='Invalid union mode'): + pa.union(fields, mode=mode) + + +def test_dictionary_type(): + ty0 = pa.dictionary(pa.int32(), pa.string()) + assert ty0.index_type == pa.int32() + assert ty0.value_type == pa.string() + assert ty0.ordered is False + + ty1 = pa.dictionary(pa.int8(), pa.float64(), ordered=True) + assert ty1.index_type == pa.int8() + assert ty1.value_type == pa.float64() + assert ty1.ordered is True + + # construct from non-arrow objects + ty2 = pa.dictionary('int8', 'string') + assert ty2.index_type == pa.int8() + assert ty2.value_type == pa.string() + assert ty2.ordered is False + + # allow unsigned integers for index type + ty3 = pa.dictionary(pa.uint32(), pa.string()) + assert ty3.index_type == pa.uint32() + assert ty3.value_type == pa.string() + assert ty3.ordered is False + + # invalid index type raises + with pytest.raises(TypeError): + pa.dictionary(pa.string(), pa.int64()) + + +def test_dictionary_ordered_equals(): + # Python side checking of ARROW-6345 + d1 = pa.dictionary('int32', 'binary', ordered=True) + d2 = pa.dictionary('int32', 'binary', ordered=False) + d3 = pa.dictionary('int8', 'binary', ordered=True) + d4 = pa.dictionary('int32', 'binary', ordered=True) + + assert not d1.equals(d2) + assert not d1.equals(d3) + assert d1.equals(d4) + + +def test_types_hashable(): + many_types = get_many_types() + in_dict = {} + for i, type_ in enumerate(many_types): + assert hash(type_) == hash(type_) + in_dict[type_] = i + assert len(in_dict) == len(many_types) + for i, type_ in enumerate(many_types): + assert in_dict[type_] == i + + +def test_types_picklable(): + for ty in get_many_types(): + data = pickle.dumps(ty) + assert pickle.loads(data) == ty + + +def test_types_weakref(): + for ty in get_many_types(): + wr = weakref.ref(ty) + assert wr() is not None + # Note that ty may be a singleton and therefore outlive this loop + + wr = weakref.ref(pa.int32()) + assert wr() is not None # singleton + wr = weakref.ref(pa.list_(pa.int32())) + assert wr() is None # not a singleton + + +def test_fields_hashable(): + in_dict = {} + fields = [pa.field('a', pa.int32()), + pa.field('a', pa.int64()), + pa.field('a', pa.int64(), nullable=False), + pa.field('b', pa.int32()), + pa.field('b', pa.int32(), nullable=False)] + for i, field in enumerate(fields): + in_dict[field] = i + assert len(in_dict) == len(fields) + for i, field in enumerate(fields): + assert in_dict[field] == i + + +def test_fields_weakrefable(): + field = pa.field('a', pa.int32()) + wr = weakref.ref(field) + assert wr() is not None + del field + assert wr() is None + + +@pytest.mark.parametrize('t,check_func', [ + (pa.date32(), types.is_date32), + (pa.date64(), types.is_date64), + (pa.time32('s'), types.is_time32), + (pa.time64('ns'), types.is_time64), + (pa.int8(), types.is_int8), + (pa.int16(), types.is_int16), + (pa.int32(), types.is_int32), + (pa.int64(), types.is_int64), + (pa.uint8(), types.is_uint8), + (pa.uint16(), types.is_uint16), + (pa.uint32(), types.is_uint32), + (pa.uint64(), types.is_uint64), + (pa.float16(), types.is_float16), + (pa.float32(), types.is_float32), + (pa.float64(), types.is_float64) +]) +def test_exact_primitive_types(t, check_func): + assert check_func(t) + + +def test_type_id(): + # enum values are not exposed publicly + for ty in get_many_types(): + assert isinstance(ty.id, int) + + +def test_bit_width(): + for ty, expected in [(pa.bool_(), 1), + (pa.int8(), 8), + (pa.uint32(), 32), + (pa.float16(), 16), + (pa.decimal128(19, 4), 128), + (pa.decimal256(76, 38), 256), + (pa.binary(42), 42 * 8)]: + assert ty.bit_width == expected + for ty in [pa.binary(), pa.string(), pa.list_(pa.int16())]: + with pytest.raises(ValueError, match="fixed width"): + ty.bit_width + + +def test_fixed_size_binary_byte_width(): + ty = pa.binary(5) + assert ty.byte_width == 5 + + +def test_decimal_properties(): + ty = pa.decimal128(19, 4) + assert ty.byte_width == 16 + assert ty.precision == 19 + assert ty.scale == 4 + ty = pa.decimal256(76, 38) + assert ty.byte_width == 32 + assert ty.precision == 76 + assert ty.scale == 38 + + +def test_decimal_overflow(): + pa.decimal128(1, 0) + pa.decimal128(38, 0) + for i in (0, -1, 39): + with pytest.raises(ValueError): + pa.decimal128(i, 0) + + pa.decimal256(1, 0) + pa.decimal256(76, 0) + for i in (0, -1, 77): + with pytest.raises(ValueError): + pa.decimal256(i, 0) + + +def test_type_equality_operators(): + many_types = get_many_types() + non_pyarrow = ('foo', 16, {'s', 'e', 't'}) + + for index, ty in enumerate(many_types): + # could use two parametrization levels, + # but that'd bloat pytest's output + for i, other in enumerate(many_types + non_pyarrow): + if i == index: + assert ty == other + else: + assert ty != other + + +def test_key_value_metadata(): + m = pa.KeyValueMetadata({'a': 'A', 'b': 'B'}) + assert len(m) == 2 + assert m['a'] == b'A' + assert m[b'a'] == b'A' + assert m['b'] == b'B' + assert 'a' in m + assert b'a' in m + assert 'c' not in m + + m1 = pa.KeyValueMetadata({'a': 'A', 'b': 'B'}) + m2 = pa.KeyValueMetadata(a='A', b='B') + m3 = pa.KeyValueMetadata([('a', 'A'), ('b', 'B')]) + + assert m1 != 2 + assert m1 == m2 + assert m2 == m3 + assert m1 == {'a': 'A', 'b': 'B'} + assert m1 != {'a': 'A', 'b': 'C'} + + with pytest.raises(TypeError): + pa.KeyValueMetadata({'a': 1}) + with pytest.raises(TypeError): + pa.KeyValueMetadata({1: 'a'}) + with pytest.raises(TypeError): + pa.KeyValueMetadata(a=1) + + expected = [(b'a', b'A'), (b'b', b'B')] + result = [(k, v) for k, v in m3.items()] + assert result == expected + assert list(m3.items()) == expected + assert list(m3.keys()) == [b'a', b'b'] + assert list(m3.values()) == [b'A', b'B'] + assert len(m3) == 2 + + # test duplicate key support + md = pa.KeyValueMetadata([ + ('a', 'alpha'), + ('b', 'beta'), + ('a', 'Alpha'), + ('a', 'ALPHA'), + ]) + + expected = [ + (b'a', b'alpha'), + (b'b', b'beta'), + (b'a', b'Alpha'), + (b'a', b'ALPHA') + ] + assert len(md) == 4 + assert isinstance(md.keys(), Iterator) + assert isinstance(md.values(), Iterator) + assert isinstance(md.items(), Iterator) + assert list(md.items()) == expected + assert list(md.keys()) == [k for k, _ in expected] + assert list(md.values()) == [v for _, v in expected] + + # first occurrence + assert md['a'] == b'alpha' + assert md['b'] == b'beta' + assert md.get_all('a') == [b'alpha', b'Alpha', b'ALPHA'] + assert md.get_all('b') == [b'beta'] + assert md.get_all('unkown') == [] + + with pytest.raises(KeyError): + md = pa.KeyValueMetadata([ + ('a', 'alpha'), + ('b', 'beta'), + ('a', 'Alpha'), + ('a', 'ALPHA'), + ], b='BETA') + + +def test_key_value_metadata_duplicates(): + meta = pa.KeyValueMetadata({'a': '1', 'b': '2'}) + + with pytest.raises(KeyError): + pa.KeyValueMetadata(meta, a='3') + + +def test_field_basic(): + t = pa.string() + f = pa.field('foo', t) + + assert f.name == 'foo' + assert f.nullable + assert f.type is t + assert repr(f) == "pyarrow.Field<foo: string>" + + f = pa.field('foo', t, False) + assert not f.nullable + + with pytest.raises(TypeError): + pa.field('foo', None) + + +def test_field_equals(): + meta1 = {b'foo': b'bar'} + meta2 = {b'bizz': b'bazz'} + + f1 = pa.field('a', pa.int8(), nullable=True) + f2 = pa.field('a', pa.int8(), nullable=True) + f3 = pa.field('a', pa.int8(), nullable=False) + f4 = pa.field('a', pa.int16(), nullable=False) + f5 = pa.field('b', pa.int16(), nullable=False) + f6 = pa.field('a', pa.int8(), nullable=True, metadata=meta1) + f7 = pa.field('a', pa.int8(), nullable=True, metadata=meta1) + f8 = pa.field('a', pa.int8(), nullable=True, metadata=meta2) + + assert f1.equals(f2) + assert f6.equals(f7) + assert not f1.equals(f3) + assert not f1.equals(f4) + assert not f3.equals(f4) + assert not f4.equals(f5) + + # No metadata in f1, but metadata in f6 + assert f1.equals(f6) + assert not f1.equals(f6, check_metadata=True) + + # Different metadata + assert f6.equals(f7) + assert f7.equals(f8) + assert not f7.equals(f8, check_metadata=True) + + +def test_field_equality_operators(): + f1 = pa.field('a', pa.int8(), nullable=True) + f2 = pa.field('a', pa.int8(), nullable=True) + f3 = pa.field('b', pa.int8(), nullable=True) + f4 = pa.field('b', pa.int8(), nullable=False) + + assert f1 == f2 + assert f1 != f3 + assert f3 != f4 + assert f1 != 'foo' + + +def test_field_metadata(): + f1 = pa.field('a', pa.int8()) + f2 = pa.field('a', pa.int8(), metadata={}) + f3 = pa.field('a', pa.int8(), metadata={b'bizz': b'bazz'}) + + assert f1.metadata is None + assert f2.metadata == {} + assert f3.metadata[b'bizz'] == b'bazz' + + +def test_field_add_remove_metadata(): + import collections + + f0 = pa.field('foo', pa.int32()) + + assert f0.metadata is None + + metadata = {b'foo': b'bar', b'pandas': b'badger'} + metadata2 = collections.OrderedDict([ + (b'a', b'alpha'), + (b'b', b'beta') + ]) + + f1 = f0.with_metadata(metadata) + assert f1.metadata == metadata + + f2 = f0.with_metadata(metadata2) + assert f2.metadata == metadata2 + + with pytest.raises(TypeError): + f0.with_metadata([1, 2, 3]) + + f3 = f1.remove_metadata() + assert f3.metadata is None + + # idempotent + f4 = f3.remove_metadata() + assert f4.metadata is None + + f5 = pa.field('foo', pa.int32(), True, metadata) + f6 = f0.with_metadata(metadata) + assert f5.equals(f6) + + +def test_field_modified_copies(): + f0 = pa.field('foo', pa.int32(), True) + f0_ = pa.field('foo', pa.int32(), True) + assert f0.equals(f0_) + + f1 = pa.field('foo', pa.int64(), True) + f1_ = f0.with_type(pa.int64()) + assert f1.equals(f1_) + # Original instance is unmodified + assert f0.equals(f0_) + + f2 = pa.field('foo', pa.int32(), False) + f2_ = f0.with_nullable(False) + assert f2.equals(f2_) + # Original instance is unmodified + assert f0.equals(f0_) + + f3 = pa.field('bar', pa.int32(), True) + f3_ = f0.with_name('bar') + assert f3.equals(f3_) + # Original instance is unmodified + assert f0.equals(f0_) + + +def test_is_integer_value(): + assert pa.types.is_integer_value(1) + assert pa.types.is_integer_value(np.int64(1)) + assert not pa.types.is_integer_value('1') + + +def test_is_float_value(): + assert not pa.types.is_float_value(1) + assert pa.types.is_float_value(1.) + assert pa.types.is_float_value(np.float64(1)) + assert not pa.types.is_float_value('1.0') + + +def test_is_boolean_value(): + assert not pa.types.is_boolean_value(1) + assert pa.types.is_boolean_value(True) + assert pa.types.is_boolean_value(False) + assert pa.types.is_boolean_value(np.bool_(True)) + assert pa.types.is_boolean_value(np.bool_(False)) + + +@h.given( + past.all_types | + past.all_fields | + past.all_schemas +) +@h.example( + pa.field(name='', type=pa.null(), metadata={'0': '', '': ''}) +) +def test_pickling(field): + data = pickle.dumps(field) + assert pickle.loads(data) == field + + +@h.given( + st.lists(past.all_types) | + st.lists(past.all_fields) | + st.lists(past.all_schemas) +) +def test_hashing(items): + h.assume( + # well, this is still O(n^2), but makes the input unique + all(not a.equals(b) for i, a in enumerate(items) for b in items[:i]) + ) + + container = {} + for i, item in enumerate(items): + assert hash(item) == hash(item) + container[item] = i + + assert len(container) == len(items) + + for i, item in enumerate(items): + assert container[item] == i diff --git a/src/arrow/python/pyarrow/tests/test_util.py b/src/arrow/python/pyarrow/tests/test_util.py new file mode 100644 index 000000000..2b351a534 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_util.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import gc +import signal +import sys +import weakref + +import pytest + +from pyarrow import util +from pyarrow.tests.util import disabled_gc + + +def exhibit_signal_refcycle(): + # Put an object in the frame locals and return a weakref to it. + # If `signal.getsignal` has a bug where it creates a reference cycle + # keeping alive the current execution frames, `obj` will not be + # destroyed immediately when this function returns. + obj = set() + signal.getsignal(signal.SIGINT) + return weakref.ref(obj) + + +def test_signal_refcycle(): + # Test possible workaround for https://bugs.python.org/issue42248 + with disabled_gc(): + wr = exhibit_signal_refcycle() + if wr() is None: + pytest.skip( + "Python version does not have the bug we're testing for") + + gc.collect() + with disabled_gc(): + wr = exhibit_signal_refcycle() + assert wr() is not None + util._break_traceback_cycle_from_frame(sys._getframe(0)) + assert wr() is None diff --git a/src/arrow/python/pyarrow/tests/util.py b/src/arrow/python/pyarrow/tests/util.py new file mode 100644 index 000000000..281de69e3 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/util.py @@ -0,0 +1,331 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Utility functions for testing +""" + +import contextlib +import decimal +import gc +import numpy as np +import os +import random +import signal +import string +import subprocess +import sys + +import pytest + +import pyarrow as pa +import pyarrow.fs + + +def randsign(): + """Randomly choose either 1 or -1. + + Returns + ------- + sign : int + """ + return random.choice((-1, 1)) + + +@contextlib.contextmanager +def random_seed(seed): + """Set the random seed inside of a context manager. + + Parameters + ---------- + seed : int + The seed to set + + Notes + ----- + This function is useful when you want to set a random seed but not affect + the random state of other functions using the random module. + """ + original_state = random.getstate() + random.seed(seed) + try: + yield + finally: + random.setstate(original_state) + + +def randdecimal(precision, scale): + """Generate a random decimal value with specified precision and scale. + + Parameters + ---------- + precision : int + The maximum number of digits to generate. Must be an integer between 1 + and 38 inclusive. + scale : int + The maximum number of digits following the decimal point. Must be an + integer greater than or equal to 0. + + Returns + ------- + decimal_value : decimal.Decimal + A random decimal.Decimal object with the specified precision and scale. + """ + assert 1 <= precision <= 38, 'precision must be between 1 and 38 inclusive' + if scale < 0: + raise ValueError( + 'randdecimal does not yet support generating decimals with ' + 'negative scale' + ) + max_whole_value = 10 ** (precision - scale) - 1 + whole = random.randint(-max_whole_value, max_whole_value) + + if not scale: + return decimal.Decimal(whole) + + max_fractional_value = 10 ** scale - 1 + fractional = random.randint(0, max_fractional_value) + + return decimal.Decimal( + '{}.{}'.format(whole, str(fractional).rjust(scale, '0')) + ) + + +def random_ascii(length): + return bytes(np.random.randint(65, 123, size=length, dtype='i1')) + + +def rands(nchars): + """ + Generate one random string. + """ + RANDS_CHARS = np.array( + list(string.ascii_letters + string.digits), dtype=(np.str_, 1)) + return "".join(np.random.choice(RANDS_CHARS, nchars)) + + +def make_dataframe(): + import pandas as pd + + N = 30 + df = pd.DataFrame( + {col: np.random.randn(N) for col in string.ascii_uppercase[:4]}, + index=pd.Index([rands(10) for _ in range(N)]) + ) + return df + + +def memory_leak_check(f, metric='rss', threshold=1 << 17, iterations=10, + check_interval=1): + """ + Execute the function and try to detect a clear memory leak either internal + to Arrow or caused by a reference counting problem in the Python binding + implementation. Raises exception if a leak detected + + Parameters + ---------- + f : callable + Function to invoke on each iteration + metric : {'rss', 'vms', 'shared'}, default 'rss' + Attribute of psutil.Process.memory_info to use for determining current + memory use + threshold : int, default 128K + Threshold in number of bytes to consider a leak + iterations : int, default 10 + Total number of invocations of f + check_interval : int, default 1 + Number of invocations of f in between each memory use check + """ + import psutil + proc = psutil.Process() + + def _get_use(): + gc.collect() + return getattr(proc.memory_info(), metric) + + baseline_use = _get_use() + + def _leak_check(): + current_use = _get_use() + if current_use - baseline_use > threshold: + raise Exception("Memory leak detected. " + "Departure from baseline {} after {} iterations" + .format(current_use - baseline_use, i)) + + for i in range(iterations): + f() + if i % check_interval == 0: + _leak_check() + + +def get_modified_env_with_pythonpath(): + # Prepend pyarrow root directory to PYTHONPATH + env = os.environ.copy() + existing_pythonpath = env.get('PYTHONPATH', '') + + module_path = os.path.abspath( + os.path.dirname(os.path.dirname(pa.__file__))) + + if existing_pythonpath: + new_pythonpath = os.pathsep.join((module_path, existing_pythonpath)) + else: + new_pythonpath = module_path + env['PYTHONPATH'] = new_pythonpath + return env + + +def invoke_script(script_name, *args): + subprocess_env = get_modified_env_with_pythonpath() + + dir_path = os.path.dirname(os.path.realpath(__file__)) + python_file = os.path.join(dir_path, script_name) + + cmd = [sys.executable, python_file] + cmd.extend(args) + + subprocess.check_call(cmd, env=subprocess_env) + + +@contextlib.contextmanager +def changed_environ(name, value): + """ + Temporarily set environment variable *name* to *value*. + """ + orig_value = os.environ.get(name) + os.environ[name] = value + try: + yield + finally: + if orig_value is None: + del os.environ[name] + else: + os.environ[name] = orig_value + + +@contextlib.contextmanager +def change_cwd(path): + curdir = os.getcwd() + os.chdir(str(path)) + try: + yield + finally: + os.chdir(curdir) + + +@contextlib.contextmanager +def disabled_gc(): + gc.disable() + try: + yield + finally: + gc.enable() + + +def _filesystem_uri(path): + # URIs on Windows must follow 'file:///C:...' or 'file:/C:...' patterns. + if os.name == 'nt': + uri = 'file:///{}'.format(path) + else: + uri = 'file://{}'.format(path) + return uri + + +class FSProtocolClass: + def __init__(self, path): + self._path = path + + def __fspath__(self): + return str(self._path) + + +class ProxyHandler(pyarrow.fs.FileSystemHandler): + """ + A dataset handler that proxies to an underlying filesystem. Useful + to partially wrap an existing filesystem with partial changes. + """ + + def __init__(self, fs): + self._fs = fs + + def __eq__(self, other): + if isinstance(other, ProxyHandler): + return self._fs == other._fs + return NotImplemented + + def __ne__(self, other): + if isinstance(other, ProxyHandler): + return self._fs != other._fs + return NotImplemented + + def get_type_name(self): + return "proxy::" + self._fs.type_name + + def normalize_path(self, path): + return self._fs.normalize_path(path) + + def get_file_info(self, paths): + return self._fs.get_file_info(paths) + + def get_file_info_selector(self, selector): + return self._fs.get_file_info(selector) + + def create_dir(self, path, recursive): + return self._fs.create_dir(path, recursive=recursive) + + def delete_dir(self, path): + return self._fs.delete_dir(path) + + def delete_dir_contents(self, path): + return self._fs.delete_dir_contents(path) + + def delete_root_dir_contents(self): + return self._fs.delete_dir_contents("", accept_root_dir=True) + + def delete_file(self, path): + return self._fs.delete_file(path) + + def move(self, src, dest): + return self._fs.move(src, dest) + + def copy_file(self, src, dest): + return self._fs.copy_file(src, dest) + + def open_input_stream(self, path): + return self._fs.open_input_stream(path) + + def open_input_file(self, path): + return self._fs.open_input_file(path) + + def open_output_stream(self, path, metadata): + return self._fs.open_output_stream(path, metadata=metadata) + + def open_append_stream(self, path, metadata): + return self._fs.open_append_stream(path, metadata=metadata) + + +def get_raise_signal(): + if sys.version_info >= (3, 8): + return signal.raise_signal + elif os.name == 'nt': + # On Windows, os.kill() doesn't actually send a signal, + # it just terminates the process with the given exit code. + pytest.skip("test requires Python 3.8+ on Windows") + else: + # On Unix, emulate raise_signal() with os.kill(). + def raise_signal(signum): + os.kill(os.getpid(), signum) + return raise_signal diff --git a/src/arrow/python/pyarrow/types.pxi b/src/arrow/python/pyarrow/types.pxi new file mode 100644 index 000000000..8795e4d3a --- /dev/null +++ b/src/arrow/python/pyarrow/types.pxi @@ -0,0 +1,2930 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import atexit +from collections.abc import Mapping +import re +import sys +import warnings + + +# These are imprecise because the type (in pandas 0.x) depends on the presence +# of nulls +cdef dict _pandas_type_map = { + _Type_NA: np.object_, # NaNs + _Type_BOOL: np.bool_, + _Type_INT8: np.int8, + _Type_INT16: np.int16, + _Type_INT32: np.int32, + _Type_INT64: np.int64, + _Type_UINT8: np.uint8, + _Type_UINT16: np.uint16, + _Type_UINT32: np.uint32, + _Type_UINT64: np.uint64, + _Type_HALF_FLOAT: np.float16, + _Type_FLOAT: np.float32, + _Type_DOUBLE: np.float64, + _Type_DATE32: np.dtype('datetime64[ns]'), + _Type_DATE64: np.dtype('datetime64[ns]'), + _Type_TIMESTAMP: np.dtype('datetime64[ns]'), + _Type_DURATION: np.dtype('timedelta64[ns]'), + _Type_BINARY: np.object_, + _Type_FIXED_SIZE_BINARY: np.object_, + _Type_STRING: np.object_, + _Type_LIST: np.object_, + _Type_MAP: np.object_, + _Type_DECIMAL128: np.object_, +} + +cdef dict _pep3118_type_map = { + _Type_INT8: b'b', + _Type_INT16: b'h', + _Type_INT32: b'i', + _Type_INT64: b'q', + _Type_UINT8: b'B', + _Type_UINT16: b'H', + _Type_UINT32: b'I', + _Type_UINT64: b'Q', + _Type_HALF_FLOAT: b'e', + _Type_FLOAT: b'f', + _Type_DOUBLE: b'd', +} + + +cdef bytes _datatype_to_pep3118(CDataType* type): + """ + Construct a PEP 3118 format string describing the given datatype. + None is returned for unsupported types. + """ + try: + char = _pep3118_type_map[type.id()] + except KeyError: + return None + else: + if char in b'bBhHiIqQ': + # Use "standard" int widths, not native + return b'=' + char + else: + return char + + +def _is_primitive(Type type): + # This is simply a redirect, the official API is in pyarrow.types. + return is_primitive(type) + + +# Workaround for Cython parsing bug +# https://github.com/cython/cython/issues/2143 +ctypedef CFixedWidthType* _CFixedWidthTypePtr + + +cdef class DataType(_Weakrefable): + """ + Base class of all Arrow data types. + + Each data type is an *instance* of this class. + """ + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call {}'s constructor directly, use public " + "functions like pyarrow.int64, pyarrow.list_, etc. " + "instead.".format(self.__class__.__name__)) + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + assert type != nullptr + self.sp_type = type + self.type = type.get() + self.pep3118_format = _datatype_to_pep3118(self.type) + + cdef Field field(self, int i): + cdef int index = <int> _normalize_index(i, self.type.num_fields()) + return pyarrow_wrap_field(self.type.field(index)) + + @property + def id(self): + return self.type.id() + + @property + def bit_width(self): + cdef _CFixedWidthTypePtr ty + ty = dynamic_cast[_CFixedWidthTypePtr](self.type) + if ty == nullptr: + raise ValueError("Non-fixed width type") + return ty.bit_width() + + @property + def num_children(self): + """ + The number of child fields. + """ + import warnings + warnings.warn("num_children is deprecated, use num_fields", + FutureWarning) + return self.num_fields + + @property + def num_fields(self): + """ + The number of child fields. + """ + return self.type.num_fields() + + @property + def num_buffers(self): + """ + Number of data buffers required to construct Array type + excluding children. + """ + return self.type.layout().buffers.size() + + def __str__(self): + return frombytes(self.type.ToString(), safe=True) + + def __hash__(self): + return hash(str(self)) + + def __reduce__(self): + return type_for_alias, (str(self),) + + def __repr__(self): + return '{0.__class__.__name__}({0})'.format(self) + + def __eq__(self, other): + try: + return self.equals(other) + except (TypeError, ValueError): + return NotImplemented + + def equals(self, other): + """ + Return true if type is equivalent to passed value. + + Parameters + ---------- + other : DataType or string convertible to DataType + + Returns + ------- + is_equal : bool + """ + cdef DataType other_type + + other_type = ensure_type(other) + return self.type.Equals(deref(other_type.type)) + + def to_pandas_dtype(self): + """ + Return the equivalent NumPy / Pandas dtype. + """ + cdef Type type_id = self.type.id() + if type_id in _pandas_type_map: + return _pandas_type_map[type_id] + else: + raise NotImplementedError(str(self)) + + def _export_to_c(self, uintptr_t out_ptr): + """ + Export to a C ArrowSchema struct, given its pointer. + + Be careful: if you don't pass the ArrowSchema struct to a consumer, + its memory will leak. This is a low-level function intended for + expert users. + """ + check_status(ExportType(deref(self.type), <ArrowSchema*> out_ptr)) + + @staticmethod + def _import_from_c(uintptr_t in_ptr): + """ + Import DataType from a C ArrowSchema struct, given its pointer. + + This is a low-level function intended for expert users. + """ + result = GetResultValue(ImportType(<ArrowSchema*> in_ptr)) + return pyarrow_wrap_data_type(result) + + +cdef class DictionaryMemo(_Weakrefable): + """ + Tracking container for dictionary-encoded fields. + """ + + def __cinit__(self): + self.sp_memo.reset(new CDictionaryMemo()) + self.memo = self.sp_memo.get() + + +cdef class DictionaryType(DataType): + """ + Concrete class for dictionary data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.dict_type = <const CDictionaryType*> type.get() + + def __reduce__(self): + return dictionary, (self.index_type, self.value_type, self.ordered) + + @property + def ordered(self): + """ + Whether the dictionary is ordered, i.e. whether the ordering of values + in the dictionary is important. + """ + return self.dict_type.ordered() + + @property + def index_type(self): + """ + The data type of dictionary indices (a signed integer type). + """ + return pyarrow_wrap_data_type(self.dict_type.index_type()) + + @property + def value_type(self): + """ + The dictionary value type. + + The dictionary values are found in an instance of DictionaryArray. + """ + return pyarrow_wrap_data_type(self.dict_type.value_type()) + + +cdef class ListType(DataType): + """ + Concrete class for list data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.list_type = <const CListType*> type.get() + + def __reduce__(self): + return list_, (self.value_field,) + + @property + def value_field(self): + return pyarrow_wrap_field(self.list_type.value_field()) + + @property + def value_type(self): + """ + The data type of list values. + """ + return pyarrow_wrap_data_type(self.list_type.value_type()) + + +cdef class LargeListType(DataType): + """ + Concrete class for large list data types + (like ListType, but with 64-bit offsets). + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.list_type = <const CLargeListType*> type.get() + + def __reduce__(self): + return large_list, (self.value_field,) + + @property + def value_field(self): + return pyarrow_wrap_field(self.list_type.value_field()) + + @property + def value_type(self): + """ + The data type of large list values. + """ + return pyarrow_wrap_data_type(self.list_type.value_type()) + + +cdef class MapType(DataType): + """ + Concrete class for map data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.map_type = <const CMapType*> type.get() + + def __reduce__(self): + return map_, (self.key_field, self.item_field) + + @property + def key_field(self): + """ + The field for keys in the map entries. + """ + return pyarrow_wrap_field(self.map_type.key_field()) + + @property + def key_type(self): + """ + The data type of keys in the map entries. + """ + return pyarrow_wrap_data_type(self.map_type.key_type()) + + @property + def item_field(self): + """ + The field for items in the map entries. + """ + return pyarrow_wrap_field(self.map_type.item_field()) + + @property + def item_type(self): + """ + The data type of items in the map entries. + """ + return pyarrow_wrap_data_type(self.map_type.item_type()) + + +cdef class FixedSizeListType(DataType): + """ + Concrete class for fixed size list data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.list_type = <const CFixedSizeListType*> type.get() + + def __reduce__(self): + return list_, (self.value_type, self.list_size) + + @property + def value_field(self): + return pyarrow_wrap_field(self.list_type.value_field()) + + @property + def value_type(self): + """ + The data type of large list values. + """ + return pyarrow_wrap_data_type(self.list_type.value_type()) + + @property + def list_size(self): + """ + The size of the fixed size lists. + """ + return self.list_type.list_size() + + +cdef class StructType(DataType): + """ + Concrete class for struct data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.struct_type = <const CStructType*> type.get() + + cdef Field field_by_name(self, name): + """ + Return a child field by its name rather than its index. + """ + cdef vector[shared_ptr[CField]] fields + + fields = self.struct_type.GetAllFieldsByName(tobytes(name)) + if fields.size() == 0: + raise KeyError(name) + elif fields.size() > 1: + warnings.warn("Struct field name corresponds to more " + "than one field", UserWarning) + raise KeyError(name) + else: + return pyarrow_wrap_field(fields[0]) + + def get_field_index(self, name): + """ + Return index of field with given unique name. Returns -1 if not found + or if duplicated + """ + return self.struct_type.GetFieldIndex(tobytes(name)) + + def get_all_field_indices(self, name): + """ + Return sorted list of indices for fields with the given name + """ + return self.struct_type.GetAllFieldIndices(tobytes(name)) + + def __len__(self): + """ + Like num_fields(). + """ + return self.type.num_fields() + + def __iter__(self): + """ + Iterate over struct fields, in order. + """ + for i in range(len(self)): + yield self[i] + + def __getitem__(self, i): + """ + Return the struct field with the given index or name. + """ + if isinstance(i, (bytes, str)): + return self.field_by_name(i) + elif isinstance(i, int): + return self.field(i) + else: + raise TypeError('Expected integer or string index') + + def __reduce__(self): + return struct, (list(self),) + + +cdef class UnionType(DataType): + """ + Base class for union data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + + @property + def mode(self): + """ + The mode of the union ("dense" or "sparse"). + """ + cdef CUnionType* type = <CUnionType*> self.sp_type.get() + cdef int mode = type.mode() + if mode == _UnionMode_DENSE: + return 'dense' + if mode == _UnionMode_SPARSE: + return 'sparse' + assert 0 + + @property + def type_codes(self): + """ + The type code to indicate each data type in this union. + """ + cdef CUnionType* type = <CUnionType*> self.sp_type.get() + return type.type_codes() + + def __len__(self): + """ + Like num_fields(). + """ + return self.type.num_fields() + + def __iter__(self): + """ + Iterate over union members, in order. + """ + for i in range(len(self)): + yield self[i] + + def __getitem__(self, i): + """ + Return a child field by its index. + """ + return self.field(i) + + def __reduce__(self): + return union, (list(self), self.mode, self.type_codes) + + +cdef class SparseUnionType(UnionType): + """ + Concrete class for sparse union types. + """ + + +cdef class DenseUnionType(UnionType): + """ + Concrete class for dense union types. + """ + + +cdef class TimestampType(DataType): + """ + Concrete class for timestamp data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.ts_type = <const CTimestampType*> type.get() + + @property + def unit(self): + """ + The timestamp unit ('s', 'ms', 'us' or 'ns'). + """ + return timeunit_to_string(self.ts_type.unit()) + + @property + def tz(self): + """ + The timestamp time zone, if any, or None. + """ + if self.ts_type.timezone().size() > 0: + return frombytes(self.ts_type.timezone()) + else: + return None + + def to_pandas_dtype(self): + """ + Return the equivalent NumPy / Pandas dtype. + """ + if self.tz is None: + return _pandas_type_map[_Type_TIMESTAMP] + else: + # Return DatetimeTZ + from pyarrow.pandas_compat import make_datetimetz + return make_datetimetz(self.tz) + + def __reduce__(self): + return timestamp, (self.unit, self.tz) + + +cdef class Time32Type(DataType): + """ + Concrete class for time32 data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.time_type = <const CTime32Type*> type.get() + + @property + def unit(self): + """ + The time unit ('s', 'ms', 'us' or 'ns'). + """ + return timeunit_to_string(self.time_type.unit()) + + +cdef class Time64Type(DataType): + """ + Concrete class for time64 data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.time_type = <const CTime64Type*> type.get() + + @property + def unit(self): + """ + The time unit ('s', 'ms', 'us' or 'ns'). + """ + return timeunit_to_string(self.time_type.unit()) + + +cdef class DurationType(DataType): + """ + Concrete class for duration data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.duration_type = <const CDurationType*> type.get() + + @property + def unit(self): + """ + The duration unit ('s', 'ms', 'us' or 'ns'). + """ + return timeunit_to_string(self.duration_type.unit()) + + +cdef class FixedSizeBinaryType(DataType): + """ + Concrete class for fixed-size binary data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.fixed_size_binary_type = ( + <const CFixedSizeBinaryType*> type.get()) + + def __reduce__(self): + return binary, (self.byte_width,) + + @property + def byte_width(self): + """ + The binary size in bytes. + """ + return self.fixed_size_binary_type.byte_width() + + +cdef class Decimal128Type(FixedSizeBinaryType): + """ + Concrete class for decimal128 data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + FixedSizeBinaryType.init(self, type) + self.decimal128_type = <const CDecimal128Type*> type.get() + + def __reduce__(self): + return decimal128, (self.precision, self.scale) + + @property + def precision(self): + """ + The decimal precision, in number of decimal digits (an integer). + """ + return self.decimal128_type.precision() + + @property + def scale(self): + """ + The decimal scale (an integer). + """ + return self.decimal128_type.scale() + + +cdef class Decimal256Type(FixedSizeBinaryType): + """ + Concrete class for Decimal256 data types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + FixedSizeBinaryType.init(self, type) + self.decimal256_type = <const CDecimal256Type*> type.get() + + def __reduce__(self): + return decimal256, (self.precision, self.scale) + + @property + def precision(self): + """ + The decimal precision, in number of decimal digits (an integer). + """ + return self.decimal256_type.precision() + + @property + def scale(self): + """ + The decimal scale (an integer). + """ + return self.decimal256_type.scale() + + +cdef class BaseExtensionType(DataType): + """ + Concrete base class for extension types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.ext_type = <const CExtensionType*> type.get() + + @property + def extension_name(self): + """ + The extension type name. + """ + return frombytes(self.ext_type.extension_name()) + + @property + def storage_type(self): + """ + The underlying storage type. + """ + return pyarrow_wrap_data_type(self.ext_type.storage_type()) + + def wrap_array(self, storage): + """ + Wrap the given storage array as an extension array. + + Parameters + ---------- + storage : Array or ChunkedArray + + Returns + ------- + array : Array or ChunkedArray + Extension array wrapping the storage array + """ + cdef: + shared_ptr[CDataType] c_storage_type + + if isinstance(storage, Array): + c_storage_type = (<Array> storage).ap.type() + elif isinstance(storage, ChunkedArray): + c_storage_type = (<ChunkedArray> storage).chunked_array.type() + else: + raise TypeError( + f"Expected array or chunked array, got {storage.__class__}") + + if not c_storage_type.get().Equals(deref(self.ext_type) + .storage_type()): + raise TypeError( + f"Incompatible storage type for {self}: " + f"expected {self.storage_type}, got {storage.type}") + + if isinstance(storage, Array): + return pyarrow_wrap_array( + self.ext_type.WrapArray( + self.sp_type, (<Array> storage).sp_array)) + else: + return pyarrow_wrap_chunked_array( + self.ext_type.WrapArray( + self.sp_type, (<ChunkedArray> storage).sp_chunked_array)) + + +cdef class ExtensionType(BaseExtensionType): + """ + Concrete base class for Python-defined extension types. + + Parameters + ---------- + storage_type : DataType + extension_name : str + """ + + def __cinit__(self): + if type(self) is ExtensionType: + raise TypeError("Can only instantiate subclasses of " + "ExtensionType") + + def __init__(self, DataType storage_type, extension_name): + """ + Initialize an extension type instance. + + This should be called at the end of the subclass' + ``__init__`` method. + """ + cdef: + shared_ptr[CExtensionType] cpy_ext_type + c_string c_extension_name + + c_extension_name = tobytes(extension_name) + + assert storage_type is not None + check_status(CPyExtensionType.FromClass( + storage_type.sp_type, c_extension_name, type(self), + &cpy_ext_type)) + self.init(<shared_ptr[CDataType]> cpy_ext_type) + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.cpy_ext_type = <const CPyExtensionType*> type.get() + # Store weakref and serialized version of self on C++ type instance + check_status(self.cpy_ext_type.SetInstance(self)) + + def __eq__(self, other): + # Default implementation to avoid infinite recursion through + # DataType.__eq__ -> ExtensionType::ExtensionEquals -> DataType.__eq__ + if isinstance(other, ExtensionType): + return (type(self) == type(other) and + self.extension_name == other.extension_name and + self.storage_type == other.storage_type) + else: + return NotImplemented + + def __repr__(self): + fmt = '{0.__class__.__name__}({1})' + return fmt.format(self, repr(self.storage_type)) + + def __arrow_ext_serialize__(self): + """ + Serialized representation of metadata to reconstruct the type object. + + This method should return a bytes object, and those serialized bytes + are stored in the custom metadata of the Field holding an extension + type in an IPC message. + The bytes are passed to ``__arrow_ext_deserialize`` and should hold + sufficient information to reconstruct the data type instance. + """ + return NotImplementedError + + @classmethod + def __arrow_ext_deserialize__(self, storage_type, serialized): + """ + Return an extension type instance from the storage type and serialized + metadata. + + This method should return an instance of the ExtensionType subclass + that matches the passed storage type and serialized metadata (the + return value of ``__arrow_ext_serialize__``). + """ + return NotImplementedError + + def __arrow_ext_class__(self): + """Return an extension array class to be used for building or + deserializing arrays with this extension type. + + This method should return a subclass of the ExtensionArray class. By + default, if not specialized in the extension implementation, an + extension type array will be a built-in ExtensionArray instance. + """ + return ExtensionArray + + +cdef class PyExtensionType(ExtensionType): + """ + Concrete base class for Python-defined extension types based on pickle + for (de)serialization. + + Parameters + ---------- + storage_type : DataType + The storage type for which the extension is built. + """ + + def __cinit__(self): + if type(self) is PyExtensionType: + raise TypeError("Can only instantiate subclasses of " + "PyExtensionType") + + def __init__(self, DataType storage_type): + ExtensionType.__init__(self, storage_type, "arrow.py_extension_type") + + def __reduce__(self): + raise NotImplementedError("Please implement {0}.__reduce__" + .format(type(self).__name__)) + + def __arrow_ext_serialize__(self): + return builtin_pickle.dumps(self) + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + try: + ty = builtin_pickle.loads(serialized) + except Exception: + # For some reason, it's impossible to deserialize the + # ExtensionType instance. Perhaps the serialized data is + # corrupt, or more likely the type is being deserialized + # in an environment where the original Python class or module + # is not available. Fall back on a generic BaseExtensionType. + return UnknownExtensionType(storage_type, serialized) + + if ty.storage_type != storage_type: + raise TypeError("Expected storage type {0} but got {1}" + .format(ty.storage_type, storage_type)) + return ty + + +cdef class UnknownExtensionType(PyExtensionType): + """ + A concrete class for Python-defined extension types that refer to + an unknown Python implementation. + + Parameters + ---------- + storage_type : DataType + The storage type for which the extension is built. + serialized : bytes + The serialised output. + """ + + cdef: + bytes serialized + + def __init__(self, DataType storage_type, serialized): + self.serialized = serialized + PyExtensionType.__init__(self, storage_type) + + def __arrow_ext_serialize__(self): + return self.serialized + + +_python_extension_types_registry = [] + + +def register_extension_type(ext_type): + """ + Register a Python extension type. + + Registration is based on the extension name (so different registered types + need unique extension names). Registration needs an extension type + instance, but then works for any instance of the same subclass regardless + of parametrization of the type. + + Parameters + ---------- + ext_type : BaseExtensionType instance + The ExtensionType subclass to register. + + """ + cdef: + DataType _type = ensure_type(ext_type, allow_none=False) + + if not isinstance(_type, BaseExtensionType): + raise TypeError("Only extension types can be registered") + + # register on the C++ side + check_status( + RegisterPyExtensionType(<shared_ptr[CDataType]> _type.sp_type)) + + # register on the python side + _python_extension_types_registry.append(_type) + + +def unregister_extension_type(type_name): + """ + Unregister a Python extension type. + + Parameters + ---------- + type_name : str + The name of the ExtensionType subclass to unregister. + + """ + cdef: + c_string c_type_name = tobytes(type_name) + check_status(UnregisterPyExtensionType(c_type_name)) + + +cdef class KeyValueMetadata(_Metadata, Mapping): + """ + KeyValueMetadata + + Parameters + ---------- + __arg0__ : dict + A dict of the key-value metadata + **kwargs : optional + additional key-value metadata + """ + + def __init__(self, __arg0__=None, **kwargs): + cdef: + vector[c_string] keys, values + shared_ptr[const CKeyValueMetadata] result + + items = [] + if __arg0__ is not None: + other = (__arg0__.items() if isinstance(__arg0__, Mapping) + else __arg0__) + items.extend((tobytes(k), v) for k, v in other) + + prior_keys = {k for k, v in items} + for k, v in kwargs.items(): + k = tobytes(k) + if k in prior_keys: + raise KeyError("Duplicate key {}, " + "use pass all items as list of tuples if you " + "intend to have duplicate keys") + items.append((k, v)) + + keys.reserve(len(items)) + for key, value in items: + keys.push_back(tobytes(key)) + values.push_back(tobytes(value)) + result.reset(new CKeyValueMetadata(move(keys), move(values))) + self.init(result) + + cdef void init(self, const shared_ptr[const CKeyValueMetadata]& wrapped): + self.wrapped = wrapped + self.metadata = wrapped.get() + + @staticmethod + cdef wrap(const shared_ptr[const CKeyValueMetadata]& sp): + cdef KeyValueMetadata self = KeyValueMetadata.__new__(KeyValueMetadata) + self.init(sp) + return self + + cdef inline shared_ptr[const CKeyValueMetadata] unwrap(self) nogil: + return self.wrapped + + def equals(self, KeyValueMetadata other): + return self.metadata.Equals(deref(other.wrapped)) + + def __repr__(self): + return str(self) + + def __str__(self): + return frombytes(self.metadata.ToString(), safe=True) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + pass + + if isinstance(other, Mapping): + try: + other = KeyValueMetadata(other) + return self.equals(other) + except TypeError: + pass + + return NotImplemented + + def __len__(self): + return self.metadata.size() + + def __contains__(self, key): + return self.metadata.Contains(tobytes(key)) + + def __getitem__(self, key): + return GetResultValue(self.metadata.Get(tobytes(key))) + + def __iter__(self): + return self.keys() + + def __reduce__(self): + return KeyValueMetadata, (list(self.items()),) + + def key(self, i): + return self.metadata.key(i) + + def value(self, i): + return self.metadata.value(i) + + def keys(self): + for i in range(self.metadata.size()): + yield self.metadata.key(i) + + def values(self): + for i in range(self.metadata.size()): + yield self.metadata.value(i) + + def items(self): + for i in range(self.metadata.size()): + yield (self.metadata.key(i), self.metadata.value(i)) + + def get_all(self, key): + key = tobytes(key) + return [v for k, v in self.items() if k == key] + + def to_dict(self): + """ + Convert KeyValueMetadata to dict. If a key occurs twice, the value for + the first one is returned + """ + cdef object key # to force coercion to Python + result = ordered_dict() + for i in range(self.metadata.size()): + key = self.metadata.key(i) + if key not in result: + result[key] = self.metadata.value(i) + return result + + +cdef KeyValueMetadata ensure_metadata(object meta, c_bool allow_none=False): + if allow_none and meta is None: + return None + elif isinstance(meta, KeyValueMetadata): + return meta + else: + return KeyValueMetadata(meta) + + +cdef class Field(_Weakrefable): + """ + A named field, with a data type, nullability, and optional metadata. + + Notes + ----- + Do not use this class's constructor directly; use pyarrow.field + """ + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call Field's constructor directly, use " + "`pyarrow.field` instead.") + + cdef void init(self, const shared_ptr[CField]& field): + self.sp_field = field + self.field = field.get() + self.type = pyarrow_wrap_data_type(field.get().type()) + + def equals(self, Field other, bint check_metadata=False): + """ + Test if this field is equal to the other + + Parameters + ---------- + other : pyarrow.Field + check_metadata : bool, default False + Whether Field metadata equality should be checked as well. + + Returns + ------- + is_equal : bool + """ + return self.field.Equals(deref(other.field), check_metadata) + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def __reduce__(self): + return field, (self.name, self.type, self.nullable, self.metadata) + + def __str__(self): + return 'pyarrow.Field<{0}>'.format( + frombytes(self.field.ToString(), safe=True)) + + def __repr__(self): + return self.__str__() + + def __hash__(self): + return hash((self.field.name(), self.type, self.field.nullable())) + + @property + def nullable(self): + return self.field.nullable() + + @property + def name(self): + return frombytes(self.field.name()) + + @property + def metadata(self): + wrapped = pyarrow_wrap_metadata(self.field.metadata()) + if wrapped is not None: + return wrapped.to_dict() + else: + return wrapped + + def add_metadata(self, metadata): + warnings.warn("The 'add_metadata' method is deprecated, use " + "'with_metadata' instead", FutureWarning, stacklevel=2) + return self.with_metadata(metadata) + + def with_metadata(self, metadata): + """ + Add metadata as dict of string keys and values to Field + + Parameters + ---------- + metadata : dict + Keys and values must be string-like / coercible to bytes + + Returns + ------- + field : pyarrow.Field + """ + cdef shared_ptr[CField] c_field + + meta = ensure_metadata(metadata, allow_none=False) + with nogil: + c_field = self.field.WithMetadata(meta.unwrap()) + + return pyarrow_wrap_field(c_field) + + def remove_metadata(self): + """ + Create new field without metadata, if any + + Returns + ------- + field : pyarrow.Field + """ + cdef shared_ptr[CField] new_field + with nogil: + new_field = self.field.RemoveMetadata() + return pyarrow_wrap_field(new_field) + + def with_type(self, DataType new_type): + """ + A copy of this field with the replaced type + + Parameters + ---------- + new_type : pyarrow.DataType + + Returns + ------- + field : pyarrow.Field + """ + cdef: + shared_ptr[CField] c_field + shared_ptr[CDataType] c_datatype + + c_datatype = pyarrow_unwrap_data_type(new_type) + with nogil: + c_field = self.field.WithType(c_datatype) + + return pyarrow_wrap_field(c_field) + + def with_name(self, name): + """ + A copy of this field with the replaced name + + Parameters + ---------- + name : str + + Returns + ------- + field : pyarrow.Field + """ + cdef: + shared_ptr[CField] c_field + + c_field = self.field.WithName(tobytes(name)) + + return pyarrow_wrap_field(c_field) + + def with_nullable(self, nullable): + """ + A copy of this field with the replaced nullability + + Parameters + ---------- + nullable : bool + + Returns + ------- + field: pyarrow.Field + """ + cdef: + shared_ptr[CField] field + c_bool c_nullable + + c_nullable = bool(nullable) + with nogil: + c_field = self.field.WithNullable(c_nullable) + + return pyarrow_wrap_field(c_field) + + def flatten(self): + """ + Flatten this field. If a struct field, individual child fields + will be returned with their names prefixed by the parent's name. + + Returns + ------- + fields : List[pyarrow.Field] + """ + cdef vector[shared_ptr[CField]] flattened + with nogil: + flattened = self.field.Flatten() + return [pyarrow_wrap_field(f) for f in flattened] + + def _export_to_c(self, uintptr_t out_ptr): + """ + Export to a C ArrowSchema struct, given its pointer. + + Be careful: if you don't pass the ArrowSchema struct to a consumer, + its memory will leak. This is a low-level function intended for + expert users. + """ + check_status(ExportField(deref(self.field), <ArrowSchema*> out_ptr)) + + @staticmethod + def _import_from_c(uintptr_t in_ptr): + """ + Import Field from a C ArrowSchema struct, given its pointer. + + This is a low-level function intended for expert users. + """ + with nogil: + result = GetResultValue(ImportField(<ArrowSchema*> in_ptr)) + return pyarrow_wrap_field(result) + + +cdef class Schema(_Weakrefable): + + def __cinit__(self): + pass + + def __init__(self): + raise TypeError("Do not call Schema's constructor directly, use " + "`pyarrow.schema` instead.") + + def __len__(self): + return self.schema.num_fields() + + def __getitem__(self, key): + # access by integer index + return self._field(key) + + def __iter__(self): + for i in range(len(self)): + yield self[i] + + cdef void init(self, const vector[shared_ptr[CField]]& fields): + self.schema = new CSchema(fields) + self.sp_schema.reset(self.schema) + + cdef void init_schema(self, const shared_ptr[CSchema]& schema): + self.schema = schema.get() + self.sp_schema = schema + + def __reduce__(self): + return schema, (list(self), self.metadata) + + def __hash__(self): + return hash((tuple(self), self.metadata)) + + def __sizeof__(self): + size = 0 + if self.metadata: + for key, value in self.metadata.items(): + size += sys.getsizeof(key) + size += sys.getsizeof(value) + + return size + super(Schema, self).__sizeof__() + + @property + def pandas_metadata(self): + """ + Return deserialized-from-JSON pandas metadata field (if it exists) + """ + metadata = self.metadata + key = b'pandas' + if metadata is None or key not in metadata: + return None + + import json + return json.loads(metadata[key].decode('utf8')) + + @property + def names(self): + """ + The schema's field names. + + Returns + ------- + list of str + """ + cdef int i + result = [] + for i in range(self.schema.num_fields()): + name = frombytes(self.schema.field(i).get().name()) + result.append(name) + return result + + @property + def types(self): + """ + The schema's field types. + + Returns + ------- + list of DataType + """ + return [field.type for field in self] + + @property + def metadata(self): + wrapped = pyarrow_wrap_metadata(self.schema.metadata()) + if wrapped is not None: + return wrapped.to_dict() + else: + return wrapped + + def __eq__(self, other): + try: + return self.equals(other) + except TypeError: + return NotImplemented + + def empty_table(self): + """ + Provide an empty table according to the schema. + + Returns + ------- + table: pyarrow.Table + """ + arrays = [_empty_array(field.type) for field in self] + return Table.from_arrays(arrays, schema=self) + + def equals(self, Schema other not None, bint check_metadata=False): + """ + Test if this schema is equal to the other + + Parameters + ---------- + other : pyarrow.Schema + check_metadata : bool, default False + Key/value metadata must be equal too + + Returns + ------- + is_equal : bool + """ + return self.sp_schema.get().Equals(deref(other.schema), + check_metadata) + + @classmethod + def from_pandas(cls, df, preserve_index=None): + """ + Returns implied schema from dataframe + + Parameters + ---------- + df : pandas.DataFrame + preserve_index : bool, default True + Whether to store the index as an additional column (or columns, for + MultiIndex) in the resulting `Table`. + The default of None will store the index as a column, except for + RangeIndex which is stored as metadata only. Use + ``preserve_index=True`` to force it to be stored as a column. + + Returns + ------- + pyarrow.Schema + + Examples + -------- + + >>> import pandas as pd + >>> import pyarrow as pa + >>> df = pd.DataFrame({ + ... 'int': [1, 2], + ... 'str': ['a', 'b'] + ... }) + >>> pa.Schema.from_pandas(df) + int: int64 + str: string + __index_level_0__: int64 + """ + from pyarrow.pandas_compat import dataframe_to_types + names, types, metadata = dataframe_to_types( + df, + preserve_index=preserve_index + ) + fields = [] + for name, type_ in zip(names, types): + fields.append(field(name, type_)) + return schema(fields, metadata) + + def field(self, i): + """ + Select a field by its column name or numeric index. + + Parameters + ---------- + i : int or string + + Returns + ------- + pyarrow.Field + """ + if isinstance(i, (bytes, str)): + field_index = self.get_field_index(i) + if field_index < 0: + raise KeyError("Column {} does not exist in schema".format(i)) + else: + return self._field(field_index) + elif isinstance(i, int): + return self._field(i) + else: + raise TypeError("Index must either be string or integer") + + def _field(self, int i): + """Select a field by its numeric index.""" + cdef int index = <int> _normalize_index(i, self.schema.num_fields()) + return pyarrow_wrap_field(self.schema.field(index)) + + def field_by_name(self, name): + """ + Access a field by its name rather than the column index. + + Parameters + ---------- + name: str + + Returns + ------- + field: pyarrow.Field + """ + cdef: + vector[shared_ptr[CField]] results + + warnings.warn( + "The 'field_by_name' method is deprecated, use 'field' instead", + FutureWarning, stacklevel=2) + + results = self.schema.GetAllFieldsByName(tobytes(name)) + if results.size() == 0: + return None + elif results.size() > 1: + warnings.warn("Schema field name corresponds to more " + "than one field", UserWarning) + return None + else: + return pyarrow_wrap_field(results[0]) + + def get_field_index(self, name): + """ + Return index of field with given unique name. Returns -1 if not found + or if duplicated + """ + return self.schema.GetFieldIndex(tobytes(name)) + + def get_all_field_indices(self, name): + """ + Return sorted list of indices for fields with the given name + """ + return self.schema.GetAllFieldIndices(tobytes(name)) + + def append(self, Field field): + """ + Append a field at the end of the schema. + + In contrast to Python's ``list.append()`` it does return a new + object, leaving the original Schema unmodified. + + Parameters + ---------- + field: Field + + Returns + ------- + schema: Schema + New object with appended field. + """ + return self.insert(self.schema.num_fields(), field) + + def insert(self, int i, Field field): + """ + Add a field at position i to the schema. + + Parameters + ---------- + i: int + field: Field + + Returns + ------- + schema: Schema + """ + cdef: + shared_ptr[CSchema] new_schema + shared_ptr[CField] c_field + + c_field = field.sp_field + + with nogil: + new_schema = GetResultValue(self.schema.AddField(i, c_field)) + + return pyarrow_wrap_schema(new_schema) + + def remove(self, int i): + """ + Remove the field at index i from the schema. + + Parameters + ---------- + i: int + + Returns + ------- + schema: Schema + """ + cdef shared_ptr[CSchema] new_schema + + with nogil: + new_schema = GetResultValue(self.schema.RemoveField(i)) + + return pyarrow_wrap_schema(new_schema) + + def set(self, int i, Field field): + """ + Replace a field at position i in the schema. + + Parameters + ---------- + i: int + field: Field + + Returns + ------- + schema: Schema + """ + cdef: + shared_ptr[CSchema] new_schema + shared_ptr[CField] c_field + + c_field = field.sp_field + + with nogil: + new_schema = GetResultValue(self.schema.SetField(i, c_field)) + + return pyarrow_wrap_schema(new_schema) + + def add_metadata(self, metadata): + warnings.warn("The 'add_metadata' method is deprecated, use " + "'with_metadata' instead", FutureWarning, stacklevel=2) + return self.with_metadata(metadata) + + def with_metadata(self, metadata): + """ + Add metadata as dict of string keys and values to Schema + + Parameters + ---------- + metadata : dict + Keys and values must be string-like / coercible to bytes + + Returns + ------- + schema : pyarrow.Schema + """ + cdef shared_ptr[CSchema] c_schema + + meta = ensure_metadata(metadata, allow_none=False) + with nogil: + c_schema = self.schema.WithMetadata(meta.unwrap()) + + return pyarrow_wrap_schema(c_schema) + + def serialize(self, memory_pool=None): + """ + Write Schema to Buffer as encapsulated IPC message + + Parameters + ---------- + memory_pool : MemoryPool, default None + Uses default memory pool if not specified + + Returns + ------- + serialized : Buffer + """ + cdef: + shared_ptr[CBuffer] buffer + CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) + + with nogil: + buffer = GetResultValue(SerializeSchema(deref(self.schema), + pool)) + return pyarrow_wrap_buffer(buffer) + + def remove_metadata(self): + """ + Create new schema without metadata, if any + + Returns + ------- + schema : pyarrow.Schema + """ + cdef shared_ptr[CSchema] new_schema + with nogil: + new_schema = self.schema.RemoveMetadata() + return pyarrow_wrap_schema(new_schema) + + def to_string(self, truncate_metadata=True, show_field_metadata=True, + show_schema_metadata=True): + """ + Return human-readable representation of Schema + + Parameters + ---------- + truncate_metadata : boolean, default True + Limit metadata key/value display to a single line of ~80 characters + or less + show_field_metadata : boolean, default True + Display Field-level KeyValueMetadata + show_schema_metadata : boolean, default True + Display Schema-level KeyValueMetadata + + Returns + ------- + str : the formatted output + """ + cdef: + c_string result + PrettyPrintOptions options = PrettyPrintOptions.Defaults() + + options.indent = 0 + options.truncate_metadata = truncate_metadata + options.show_field_metadata = show_field_metadata + options.show_schema_metadata = show_schema_metadata + + with nogil: + check_status( + PrettyPrint( + deref(self.schema), + options, + &result + ) + ) + + return frombytes(result, safe=True) + + def _export_to_c(self, uintptr_t out_ptr): + """ + Export to a C ArrowSchema struct, given its pointer. + + Be careful: if you don't pass the ArrowSchema struct to a consumer, + its memory will leak. This is a low-level function intended for + expert users. + """ + check_status(ExportSchema(deref(self.schema), <ArrowSchema*> out_ptr)) + + @staticmethod + def _import_from_c(uintptr_t in_ptr): + """ + Import Schema from a C ArrowSchema struct, given its pointer. + + This is a low-level function intended for expert users. + """ + with nogil: + result = GetResultValue(ImportSchema(<ArrowSchema*> in_ptr)) + return pyarrow_wrap_schema(result) + + def __str__(self): + return self.to_string() + + def __repr__(self): + return self.__str__() + + +def unify_schemas(schemas): + """ + Unify schemas by merging fields by name. + + The resulting schema will contain the union of fields from all schemas. + Fields with the same name will be merged. Note that two fields with + different types will fail merging. + + - The unified field will inherit the metadata from the schema where + that field is first defined. + - The first N fields in the schema will be ordered the same as the + N fields in the first schema. + + The resulting schema will inherit its metadata from the first input + schema. + + Parameters + ---------- + schemas : list of Schema + Schemas to merge into a single one. + + Returns + ------- + Schema + + Raises + ------ + ArrowInvalid : + If any input schema contains fields with duplicate names. + If Fields of the same name are not mergeable. + """ + cdef: + Schema schema + vector[shared_ptr[CSchema]] c_schemas + for schema in schemas: + c_schemas.push_back(pyarrow_unwrap_schema(schema)) + return pyarrow_wrap_schema(GetResultValue(UnifySchemas(c_schemas))) + + +cdef dict _type_cache = {} + + +cdef DataType primitive_type(Type type): + if type in _type_cache: + return _type_cache[type] + + cdef DataType out = DataType.__new__(DataType) + out.init(GetPrimitiveType(type)) + + _type_cache[type] = out + return out + + +# ----------------------------------------------------------- +# Type factory functions + + +def field(name, type, bint nullable=True, metadata=None): + """ + Create a pyarrow.Field instance. + + Parameters + ---------- + name : str or bytes + Name of the field. + type : pyarrow.DataType + Arrow datatype of the field. + nullable : bool, default True + Whether the field's values are nullable. + metadata : dict, default None + Optional field metadata, the keys and values must be coercible to + bytes. + + Returns + ------- + field : pyarrow.Field + """ + cdef: + Field result = Field.__new__(Field) + DataType _type = ensure_type(type, allow_none=False) + shared_ptr[const CKeyValueMetadata] c_meta + + metadata = ensure_metadata(metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + + if _type.type.id() == _Type_NA and not nullable: + raise ValueError("A null type field may not be non-nullable") + + result.sp_field.reset( + new CField(tobytes(name), _type.sp_type, nullable, c_meta) + ) + result.field = result.sp_field.get() + result.type = _type + + return result + + +cdef set PRIMITIVE_TYPES = set([ + _Type_NA, _Type_BOOL, + _Type_UINT8, _Type_INT8, + _Type_UINT16, _Type_INT16, + _Type_UINT32, _Type_INT32, + _Type_UINT64, _Type_INT64, + _Type_TIMESTAMP, _Type_DATE32, + _Type_TIME32, _Type_TIME64, + _Type_DATE64, + _Type_HALF_FLOAT, + _Type_FLOAT, + _Type_DOUBLE]) + + +def null(): + """ + Create instance of null type. + """ + return primitive_type(_Type_NA) + + +def bool_(): + """ + Create instance of boolean type. + """ + return primitive_type(_Type_BOOL) + + +def uint8(): + """ + Create instance of unsigned int8 type. + """ + return primitive_type(_Type_UINT8) + + +def int8(): + """ + Create instance of signed int8 type. + """ + return primitive_type(_Type_INT8) + + +def uint16(): + """ + Create instance of unsigned uint16 type. + """ + return primitive_type(_Type_UINT16) + + +def int16(): + """ + Create instance of signed int16 type. + """ + return primitive_type(_Type_INT16) + + +def uint32(): + """ + Create instance of unsigned uint32 type. + """ + return primitive_type(_Type_UINT32) + + +def int32(): + """ + Create instance of signed int32 type. + """ + return primitive_type(_Type_INT32) + + +def uint64(): + """ + Create instance of unsigned uint64 type. + """ + return primitive_type(_Type_UINT64) + + +def int64(): + """ + Create instance of signed int64 type. + """ + return primitive_type(_Type_INT64) + + +cdef dict _timestamp_type_cache = {} +cdef dict _time_type_cache = {} +cdef dict _duration_type_cache = {} + + +cdef timeunit_to_string(TimeUnit unit): + if unit == TimeUnit_SECOND: + return 's' + elif unit == TimeUnit_MILLI: + return 'ms' + elif unit == TimeUnit_MICRO: + return 'us' + elif unit == TimeUnit_NANO: + return 'ns' + + +cdef TimeUnit string_to_timeunit(unit) except *: + if unit == 's': + return TimeUnit_SECOND + elif unit == 'ms': + return TimeUnit_MILLI + elif unit == 'us': + return TimeUnit_MICRO + elif unit == 'ns': + return TimeUnit_NANO + else: + raise ValueError(f"Invalid time unit: {unit!r}") + + +def tzinfo_to_string(tz): + """ + Converts a time zone object into a string indicating the name of a time + zone, one of: + * As used in the Olson time zone database (the "tz database" or + "tzdata"), such as "America/New_York" + * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 + + Parameters + ---------- + tz : datetime.tzinfo + Time zone object + + Returns + ------- + name : str + Time zone name + """ + return frombytes(GetResultValue(TzinfoToString(<PyObject*>tz))) + + +def string_to_tzinfo(name): + """ + Convert a time zone name into a time zone object. + + Supported input strings are: + * As used in the Olson time zone database (the "tz database" or + "tzdata"), such as "America/New_York" + * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 + + Parameters + ---------- + name: str + Time zone name. + + Returns + ------- + tz : datetime.tzinfo + Time zone object + """ + cdef PyObject* tz = GetResultValue(StringToTzinfo(name.encode('utf-8'))) + return PyObject_to_object(tz) + + +def timestamp(unit, tz=None): + """ + Create instance of timestamp type with resolution and optional time zone. + + Parameters + ---------- + unit : str + one of 's' [second], 'ms' [millisecond], 'us' [microsecond], or 'ns' + [nanosecond] + tz : str, default None + Time zone name. None indicates time zone naive + + Examples + -------- + >>> import pyarrow as pa + >>> pa.timestamp('us') + TimestampType(timestamp[us]) + >>> pa.timestamp('s', tz='America/New_York') + TimestampType(timestamp[s, tz=America/New_York]) + >>> pa.timestamp('s', tz='+07:30') + TimestampType(timestamp[s, tz=+07:30]) + + Returns + ------- + timestamp_type : TimestampType + """ + cdef: + TimeUnit unit_code + c_string c_timezone + + unit_code = string_to_timeunit(unit) + + cdef TimestampType out = TimestampType.__new__(TimestampType) + + if tz is None: + out.init(ctimestamp(unit_code)) + if unit_code in _timestamp_type_cache: + return _timestamp_type_cache[unit_code] + _timestamp_type_cache[unit_code] = out + else: + if not isinstance(tz, (bytes, str)): + tz = tzinfo_to_string(tz) + + c_timezone = tobytes(tz) + out.init(ctimestamp(unit_code, c_timezone)) + + return out + + +def time32(unit): + """ + Create instance of 32-bit time (time of day) type with unit resolution. + + Parameters + ---------- + unit : str + one of 's' [second], or 'ms' [millisecond] + + Returns + ------- + type : pyarrow.Time32Type + + Examples + -------- + >>> import pyarrow as pa + >>> pa.time32('s') + Time32Type(time32[s]) + >>> pa.time32('ms') + Time32Type(time32[ms]) + """ + cdef: + TimeUnit unit_code + c_string c_timezone + + if unit == 's': + unit_code = TimeUnit_SECOND + elif unit == 'ms': + unit_code = TimeUnit_MILLI + else: + raise ValueError(f"Invalid time unit for time32: {unit!r}") + + if unit_code in _time_type_cache: + return _time_type_cache[unit_code] + + cdef Time32Type out = Time32Type.__new__(Time32Type) + + out.init(ctime32(unit_code)) + _time_type_cache[unit_code] = out + + return out + + +def time64(unit): + """ + Create instance of 64-bit time (time of day) type with unit resolution. + + Parameters + ---------- + unit : str + One of 'us' [microsecond], or 'ns' [nanosecond]. + + Returns + ------- + type : pyarrow.Time64Type + + Examples + -------- + >>> import pyarrow as pa + >>> pa.time64('us') + Time64Type(time64[us]) + >>> pa.time64('ns') + Time64Type(time64[ns]) + """ + cdef: + TimeUnit unit_code + c_string c_timezone + + if unit == 'us': + unit_code = TimeUnit_MICRO + elif unit == 'ns': + unit_code = TimeUnit_NANO + else: + raise ValueError(f"Invalid time unit for time64: {unit!r}") + + if unit_code in _time_type_cache: + return _time_type_cache[unit_code] + + cdef Time64Type out = Time64Type.__new__(Time64Type) + + out.init(ctime64(unit_code)) + _time_type_cache[unit_code] = out + + return out + + +def duration(unit): + """ + Create instance of a duration type with unit resolution. + + Parameters + ---------- + unit : str + One of 's' [second], 'ms' [millisecond], 'us' [microsecond], or + 'ns' [nanosecond]. + + Returns + ------- + type : pyarrow.DurationType + + Examples + -------- + >>> import pyarrow as pa + >>> pa.duration('us') + DurationType(duration[us]) + >>> pa.duration('s') + DurationType(duration[s]) + """ + cdef: + TimeUnit unit_code + + unit_code = string_to_timeunit(unit) + + if unit_code in _duration_type_cache: + return _duration_type_cache[unit_code] + + cdef DurationType out = DurationType.__new__(DurationType) + + out.init(cduration(unit_code)) + _duration_type_cache[unit_code] = out + + return out + + +def month_day_nano_interval(): + """ + Create instance of an interval type representing months, days and + nanoseconds between two dates. + """ + return primitive_type(_Type_INTERVAL_MONTH_DAY_NANO) + + +def date32(): + """ + Create instance of 32-bit date (days since UNIX epoch 1970-01-01). + """ + return primitive_type(_Type_DATE32) + + +def date64(): + """ + Create instance of 64-bit date (milliseconds since UNIX epoch 1970-01-01). + """ + return primitive_type(_Type_DATE64) + + +def float16(): + """ + Create half-precision floating point type. + """ + return primitive_type(_Type_HALF_FLOAT) + + +def float32(): + """ + Create single-precision floating point type. + """ + return primitive_type(_Type_FLOAT) + + +def float64(): + """ + Create double-precision floating point type. + """ + return primitive_type(_Type_DOUBLE) + + +cpdef DataType decimal128(int precision, int scale=0): + """ + Create decimal type with precision and scale and 128-bit width. + + Arrow decimals are fixed-point decimal numbers encoded as a scaled + integer. The precision is the number of significant digits that the + decimal type can represent; the scale is the number of digits after + the decimal point (note the scale can be negative). + + As an example, ``decimal128(7, 3)`` can exactly represent the numbers + 1234.567 and -1234.567 (encoded internally as the 128-bit integers + 1234567 and -1234567, respectively), but neither 12345.67 nor 123.4567. + + ``decimal128(5, -3)`` can exactly represent the number 12345000 + (encoded internally as the 128-bit integer 12345), but neither + 123450000 nor 1234500. + + If you need a precision higher than 38 significant digits, consider + using ``decimal256``. + + Parameters + ---------- + precision : int + Must be between 1 and 38 + scale : int + + Returns + ------- + decimal_type : Decimal128Type + """ + cdef shared_ptr[CDataType] decimal_type + if precision < 1 or precision > 38: + raise ValueError("precision should be between 1 and 38") + decimal_type.reset(new CDecimal128Type(precision, scale)) + return pyarrow_wrap_data_type(decimal_type) + + +cpdef DataType decimal256(int precision, int scale=0): + """ + Create decimal type with precision and scale and 256-bit width. + + Arrow decimals are fixed-point decimal numbers encoded as a scaled + integer. The precision is the number of significant digits that the + decimal type can represent; the scale is the number of digits after + the decimal point (note the scale can be negative). + + For most use cases, the maximum precision offered by ``decimal128`` + is sufficient, and it will result in a more compact and more efficient + encoding. ``decimal256`` is useful if you need a precision higher + than 38 significant digits. + + Parameters + ---------- + precision : int + Must be between 1 and 76 + scale : int + + Returns + ------- + decimal_type : Decimal256Type + """ + cdef shared_ptr[CDataType] decimal_type + if precision < 1 or precision > 76: + raise ValueError("precision should be between 1 and 76") + decimal_type.reset(new CDecimal256Type(precision, scale)) + return pyarrow_wrap_data_type(decimal_type) + + +def string(): + """ + Create UTF8 variable-length string type. + """ + return primitive_type(_Type_STRING) + + +def utf8(): + """ + Alias for string(). + """ + return string() + + +def binary(int length=-1): + """ + Create variable-length binary type. + + Parameters + ---------- + length : int, optional, default -1 + If length == -1 then return a variable length binary type. If length is + greater than or equal to 0 then return a fixed size binary type of + width `length`. + """ + if length == -1: + return primitive_type(_Type_BINARY) + + cdef shared_ptr[CDataType] fixed_size_binary_type + fixed_size_binary_type.reset(new CFixedSizeBinaryType(length)) + return pyarrow_wrap_data_type(fixed_size_binary_type) + + +def large_binary(): + """ + Create large variable-length binary type. + + This data type may not be supported by all Arrow implementations. Unless + you need to represent data larger than 2GB, you should prefer binary(). + """ + return primitive_type(_Type_LARGE_BINARY) + + +def large_string(): + """ + Create large UTF8 variable-length string type. + + This data type may not be supported by all Arrow implementations. Unless + you need to represent data larger than 2GB, you should prefer string(). + """ + return primitive_type(_Type_LARGE_STRING) + + +def large_utf8(): + """ + Alias for large_string(). + """ + return large_string() + + +def list_(value_type, int list_size=-1): + """ + Create ListType instance from child data type or field. + + Parameters + ---------- + value_type : DataType or Field + list_size : int, optional, default -1 + If length == -1 then return a variable length list type. If length is + greater than or equal to 0 then return a fixed size list type. + + Returns + ------- + list_type : DataType + """ + cdef: + Field _field + shared_ptr[CDataType] list_type + + if isinstance(value_type, DataType): + _field = field('item', value_type) + elif isinstance(value_type, Field): + _field = value_type + else: + raise TypeError('List requires DataType or Field') + + if list_size == -1: + list_type.reset(new CListType(_field.sp_field)) + else: + if list_size < 0: + raise ValueError("list_size should be a positive integer") + list_type.reset(new CFixedSizeListType(_field.sp_field, list_size)) + + return pyarrow_wrap_data_type(list_type) + + +cpdef LargeListType large_list(value_type): + """ + Create LargeListType instance from child data type or field. + + This data type may not be supported by all Arrow implementations. + Unless you need to represent data larger than 2**31 elements, you should + prefer list_(). + + Parameters + ---------- + value_type : DataType or Field + + Returns + ------- + list_type : DataType + """ + cdef: + DataType data_type + Field _field + shared_ptr[CDataType] list_type + LargeListType out = LargeListType.__new__(LargeListType) + + if isinstance(value_type, DataType): + _field = field('item', value_type) + elif isinstance(value_type, Field): + _field = value_type + else: + raise TypeError('List requires DataType or Field') + + list_type.reset(new CLargeListType(_field.sp_field)) + out.init(list_type) + return out + + +cpdef MapType map_(key_type, item_type, keys_sorted=False): + """ + Create MapType instance from key and item data types or fields. + + Parameters + ---------- + key_type : DataType + item_type : DataType + keys_sorted : bool + + Returns + ------- + map_type : DataType + """ + cdef: + Field _key_field + Field _item_field + shared_ptr[CDataType] map_type + MapType out = MapType.__new__(MapType) + + if isinstance(key_type, Field): + if key_type.nullable: + raise TypeError('Map key field should be non-nullable') + _key_field = key_type + else: + _key_field = field('key', ensure_type(key_type, allow_none=False), + nullable=False) + + if isinstance(item_type, Field): + _item_field = item_type + else: + _item_field = field('value', ensure_type(item_type, allow_none=False)) + + map_type.reset(new CMapType(_key_field.sp_field, _item_field.sp_field, + keys_sorted)) + out.init(map_type) + return out + + +cpdef DictionaryType dictionary(index_type, value_type, bint ordered=False): + """ + Dictionary (categorical, or simply encoded) type. + + Parameters + ---------- + index_type : DataType + value_type : DataType + ordered : bool + + Returns + ------- + type : DictionaryType + """ + cdef: + DataType _index_type = ensure_type(index_type, allow_none=False) + DataType _value_type = ensure_type(value_type, allow_none=False) + DictionaryType out = DictionaryType.__new__(DictionaryType) + shared_ptr[CDataType] dict_type + + if _index_type.id not in { + Type_INT8, Type_INT16, Type_INT32, Type_INT64, + Type_UINT8, Type_UINT16, Type_UINT32, Type_UINT64, + }: + raise TypeError("The dictionary index type should be integer.") + + dict_type.reset(new CDictionaryType(_index_type.sp_type, + _value_type.sp_type, ordered == 1)) + out.init(dict_type) + return out + + +def struct(fields): + """ + Create StructType instance from fields. + + A struct is a nested type parameterized by an ordered sequence of types + (which can all be distinct), called its fields. + + Parameters + ---------- + fields : iterable of Fields or tuples, or mapping of strings to DataTypes + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. + + Examples + -------- + >>> import pyarrow as pa + >>> fields = [ + ... ('f1', pa.int32()), + ... ('f2', pa.string()), + ... ] + >>> struct_type = pa.struct(fields) + >>> struct_type + StructType(struct<f1: int32, f2: string>) + >>> fields = [ + ... pa.field('f1', pa.int32()), + ... pa.field('f2', pa.string(), nullable=False), + ... ] + >>> pa.struct(fields) + StructType(struct<f1: int32, f2: string not null>) + + Returns + ------- + type : DataType + """ + cdef: + Field py_field + vector[shared_ptr[CField]] c_fields + cdef shared_ptr[CDataType] struct_type + + if isinstance(fields, Mapping): + fields = fields.items() + + for item in fields: + if isinstance(item, tuple): + py_field = field(*item) + else: + py_field = item + c_fields.push_back(py_field.sp_field) + + struct_type.reset(new CStructType(c_fields)) + return pyarrow_wrap_data_type(struct_type) + + +cdef _extract_union_params(child_fields, type_codes, + vector[shared_ptr[CField]]* c_fields, + vector[int8_t]* c_type_codes): + cdef: + Field child_field + + for child_field in child_fields: + c_fields[0].push_back(child_field.sp_field) + + if type_codes is not None: + if len(type_codes) != <Py_ssize_t>(c_fields.size()): + raise ValueError("type_codes should have the same length " + "as fields") + for code in type_codes: + c_type_codes[0].push_back(code) + else: + c_type_codes[0] = range(c_fields.size()) + + +def sparse_union(child_fields, type_codes=None): + """ + Create SparseUnionType from child fields. + + A sparse union is a nested type where each logical value is taken from + a single child. A buffer of 8-bit type ids indicates which child + a given logical value is to be taken from. + + In a sparse union, each child array should have the same length as the + union array, regardless of the actual number of union values that + refer to it. + + Parameters + ---------- + child_fields : sequence of Field values + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. + type_codes : list of integers, default None + + Returns + ------- + type : SparseUnionType + """ + cdef: + vector[shared_ptr[CField]] c_fields + vector[int8_t] c_type_codes + + _extract_union_params(child_fields, type_codes, + &c_fields, &c_type_codes) + + return pyarrow_wrap_data_type( + CMakeSparseUnionType(move(c_fields), move(c_type_codes))) + + +def dense_union(child_fields, type_codes=None): + """ + Create DenseUnionType from child fields. + + A dense union is a nested type where each logical value is taken from + a single child, at a specific offset. A buffer of 8-bit type ids + indicates which child a given logical value is to be taken from, + and a buffer of 32-bit offsets indicates at which physical position + in the given child array the logical value is to be taken from. + + Unlike a sparse union, a dense union allows encoding only the child array + values which are actually referred to by the union array. This is + counterbalanced by the additional footprint of the offsets buffer, and + the additional indirection cost when looking up values. + + Parameters + ---------- + child_fields : sequence of Field values + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. + type_codes : list of integers, default None + + Returns + ------- + type : DenseUnionType + """ + cdef: + vector[shared_ptr[CField]] c_fields + vector[int8_t] c_type_codes + + _extract_union_params(child_fields, type_codes, + &c_fields, &c_type_codes) + + return pyarrow_wrap_data_type( + CMakeDenseUnionType(move(c_fields), move(c_type_codes))) + + +def union(child_fields, mode, type_codes=None): + """ + Create UnionType from child fields. + + A union is a nested type where each logical value is taken from a + single child. A buffer of 8-bit type ids indicates which child + a given logical value is to be taken from. + + Unions come in two flavors: sparse and dense + (see also `pyarrow.sparse_union` and `pyarrow.dense_union`). + + Parameters + ---------- + child_fields : sequence of Field values + Each field must have a UTF8-encoded name, and these field names are + part of the type metadata. + mode : str + Must be 'sparse' or 'dense' + type_codes : list of integers, default None + + Returns + ------- + type : UnionType + """ + cdef: + Field child_field + vector[shared_ptr[CField]] c_fields + vector[int8_t] c_type_codes + shared_ptr[CDataType] union_type + int i + + if isinstance(mode, int): + if mode not in (_UnionMode_SPARSE, _UnionMode_DENSE): + raise ValueError("Invalid union mode {0!r}".format(mode)) + else: + if mode == 'sparse': + mode = _UnionMode_SPARSE + elif mode == 'dense': + mode = _UnionMode_DENSE + else: + raise ValueError("Invalid union mode {0!r}".format(mode)) + + if mode == _UnionMode_SPARSE: + return sparse_union(child_fields, type_codes) + else: + return dense_union(child_fields, type_codes) + + +cdef dict _type_aliases = { + 'null': null, + 'bool': bool_, + 'boolean': bool_, + 'i1': int8, + 'int8': int8, + 'i2': int16, + 'int16': int16, + 'i4': int32, + 'int32': int32, + 'i8': int64, + 'int64': int64, + 'u1': uint8, + 'uint8': uint8, + 'u2': uint16, + 'uint16': uint16, + 'u4': uint32, + 'uint32': uint32, + 'u8': uint64, + 'uint64': uint64, + 'f2': float16, + 'halffloat': float16, + 'float16': float16, + 'f4': float32, + 'float': float32, + 'float32': float32, + 'f8': float64, + 'double': float64, + 'float64': float64, + 'string': string, + 'str': string, + 'utf8': string, + 'binary': binary, + 'large_string': large_string, + 'large_str': large_string, + 'large_utf8': large_string, + 'large_binary': large_binary, + 'date32': date32, + 'date64': date64, + 'date32[day]': date32, + 'date64[ms]': date64, + 'time32[s]': time32('s'), + 'time32[ms]': time32('ms'), + 'time64[us]': time64('us'), + 'time64[ns]': time64('ns'), + 'timestamp[s]': timestamp('s'), + 'timestamp[ms]': timestamp('ms'), + 'timestamp[us]': timestamp('us'), + 'timestamp[ns]': timestamp('ns'), + 'duration[s]': duration('s'), + 'duration[ms]': duration('ms'), + 'duration[us]': duration('us'), + 'duration[ns]': duration('ns'), + 'month_day_nano_interval': month_day_nano_interval(), +} + + +def type_for_alias(name): + """ + Return DataType given a string alias if one exists. + + Parameters + ---------- + name : str + The alias of the DataType that should be retrieved. + + Returns + ------- + type : DataType + """ + name = name.lower() + try: + alias = _type_aliases[name] + except KeyError: + raise ValueError('No type alias for {0}'.format(name)) + + if isinstance(alias, DataType): + return alias + return alias() + + +cpdef DataType ensure_type(object ty, bint allow_none=False): + if allow_none and ty is None: + return None + elif isinstance(ty, DataType): + return ty + elif isinstance(ty, str): + return type_for_alias(ty) + else: + raise TypeError('DataType expected, got {!r}'.format(type(ty))) + + +def schema(fields, metadata=None): + """ + Construct pyarrow.Schema from collection of fields. + + Parameters + ---------- + fields : iterable of Fields or tuples, or mapping of strings to DataTypes + metadata : dict, default None + Keys and values must be coercible to bytes. + + Examples + -------- + >>> import pyarrow as pa + >>> pa.schema([ + ... ('some_int', pa.int32()), + ... ('some_string', pa.string()) + ... ]) + some_int: int32 + some_string: string + >>> pa.schema([ + ... pa.field('some_int', pa.int32()), + ... pa.field('some_string', pa.string()) + ... ]) + some_int: int32 + some_string: string + + Returns + ------- + schema : pyarrow.Schema + """ + cdef: + shared_ptr[const CKeyValueMetadata] c_meta + shared_ptr[CSchema] c_schema + Schema result + Field py_field + vector[shared_ptr[CField]] c_fields + + if isinstance(fields, Mapping): + fields = fields.items() + + for item in fields: + if isinstance(item, tuple): + py_field = field(*item) + else: + py_field = item + if py_field is None: + raise TypeError("field or tuple expected, got None") + c_fields.push_back(py_field.sp_field) + + metadata = ensure_metadata(metadata, allow_none=True) + c_meta = pyarrow_unwrap_metadata(metadata) + + c_schema.reset(new CSchema(c_fields, c_meta)) + result = Schema.__new__(Schema) + result.init_schema(c_schema) + + return result + + +def from_numpy_dtype(object dtype): + """ + Convert NumPy dtype to pyarrow.DataType. + + Parameters + ---------- + dtype : the numpy dtype to convert + """ + cdef shared_ptr[CDataType] c_type + dtype = np.dtype(dtype) + with nogil: + check_status(NumPyDtypeToArrow(dtype, &c_type)) + + return pyarrow_wrap_data_type(c_type) + + +def is_boolean_value(object obj): + """ + Check if the object is a boolean. + + Parameters + ---------- + obj : object + The object to check + """ + return IsPyBool(obj) + + +def is_integer_value(object obj): + """ + Check if the object is an integer. + + Parameters + ---------- + obj : object + The object to check + """ + return IsPyInt(obj) + + +def is_float_value(object obj): + """ + Check if the object is a float. + + Parameters + ---------- + obj : object + The object to check + """ + return IsPyFloat(obj) + + +cdef class _ExtensionRegistryNanny(_Weakrefable): + # Keep the registry alive until we have unregistered PyExtensionType + cdef: + shared_ptr[CExtensionTypeRegistry] registry + + def __cinit__(self): + self.registry = CExtensionTypeRegistry.GetGlobalRegistry() + + def release_registry(self): + self.registry.reset() + + +_registry_nanny = _ExtensionRegistryNanny() + + +def _register_py_extension_type(): + cdef: + DataType storage_type + shared_ptr[CExtensionType] cpy_ext_type + c_string c_extension_name = tobytes("arrow.py_extension_type") + + # Make a dummy C++ ExtensionType + storage_type = null() + check_status(CPyExtensionType.FromClass( + storage_type.sp_type, c_extension_name, PyExtensionType, + &cpy_ext_type)) + check_status( + RegisterPyExtensionType(<shared_ptr[CDataType]> cpy_ext_type)) + + +def _unregister_py_extension_types(): + # This needs to be done explicitly before the Python interpreter is + # finalized. If the C++ type is destroyed later in the process + # teardown stage, it will invoke CPython APIs such as Py_DECREF + # with a destroyed interpreter. + unregister_extension_type("arrow.py_extension_type") + for ext_type in _python_extension_types_registry: + try: + unregister_extension_type(ext_type.extension_name) + except KeyError: + pass + _registry_nanny.release_registry() + + +_register_py_extension_type() +atexit.register(_unregister_py_extension_types) diff --git a/src/arrow/python/pyarrow/types.py b/src/arrow/python/pyarrow/types.py new file mode 100644 index 000000000..f239c883b --- /dev/null +++ b/src/arrow/python/pyarrow/types.py @@ -0,0 +1,550 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Tools for dealing with Arrow type metadata in Python + + +from pyarrow.lib import (is_boolean_value, # noqa + is_integer_value, + is_float_value) + +import pyarrow.lib as lib + + +_SIGNED_INTEGER_TYPES = {lib.Type_INT8, lib.Type_INT16, lib.Type_INT32, + lib.Type_INT64} +_UNSIGNED_INTEGER_TYPES = {lib.Type_UINT8, lib.Type_UINT16, lib.Type_UINT32, + lib.Type_UINT64} +_INTEGER_TYPES = _SIGNED_INTEGER_TYPES | _UNSIGNED_INTEGER_TYPES +_FLOATING_TYPES = {lib.Type_HALF_FLOAT, lib.Type_FLOAT, lib.Type_DOUBLE} +_DECIMAL_TYPES = {lib.Type_DECIMAL128, lib.Type_DECIMAL256} +_DATE_TYPES = {lib.Type_DATE32, lib.Type_DATE64} +_TIME_TYPES = {lib.Type_TIME32, lib.Type_TIME64} +_INTERVAL_TYPES = {lib.Type_INTERVAL_MONTH_DAY_NANO} +_TEMPORAL_TYPES = ({lib.Type_TIMESTAMP, + lib.Type_DURATION} | _TIME_TYPES | _DATE_TYPES | + _INTERVAL_TYPES) +_UNION_TYPES = {lib.Type_SPARSE_UNION, lib.Type_DENSE_UNION} +_NESTED_TYPES = {lib.Type_LIST, lib.Type_LARGE_LIST, lib.Type_STRUCT, + lib.Type_MAP} | _UNION_TYPES + + +def is_null(t): + """ + Return True if value is an instance of a null type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_NA + + +def is_boolean(t): + """ + Return True if value is an instance of a boolean type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_BOOL + + +def is_integer(t): + """ + Return True if value is an instance of any integer type. + + Parameters + ---------- + t : DataType + """ + return t.id in _INTEGER_TYPES + + +def is_signed_integer(t): + """ + Return True if value is an instance of any signed integer type. + + Parameters + ---------- + t : DataType + """ + return t.id in _SIGNED_INTEGER_TYPES + + +def is_unsigned_integer(t): + """ + Return True if value is an instance of any unsigned integer type. + + Parameters + ---------- + t : DataType + """ + return t.id in _UNSIGNED_INTEGER_TYPES + + +def is_int8(t): + """ + Return True if value is an instance of an int8 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_INT8 + + +def is_int16(t): + """ + Return True if value is an instance of an int16 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_INT16 + + +def is_int32(t): + """ + Return True if value is an instance of an int32 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_INT32 + + +def is_int64(t): + """ + Return True if value is an instance of an int64 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_INT64 + + +def is_uint8(t): + """ + Return True if value is an instance of an uint8 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_UINT8 + + +def is_uint16(t): + """ + Return True if value is an instance of an uint16 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_UINT16 + + +def is_uint32(t): + """ + Return True if value is an instance of an uint32 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_UINT32 + + +def is_uint64(t): + """ + Return True if value is an instance of an uint64 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_UINT64 + + +def is_floating(t): + """ + Return True if value is an instance of a floating point numeric type. + + Parameters + ---------- + t : DataType + """ + return t.id in _FLOATING_TYPES + + +def is_float16(t): + """ + Return True if value is an instance of a float16 (half-precision) type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_HALF_FLOAT + + +def is_float32(t): + """ + Return True if value is an instance of a float32 (single precision) type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_FLOAT + + +def is_float64(t): + """ + Return True if value is an instance of a float64 (double precision) type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_DOUBLE + + +def is_list(t): + """ + Return True if value is an instance of a list type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_LIST + + +def is_large_list(t): + """ + Return True if value is an instance of a large list type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_LARGE_LIST + + +def is_fixed_size_list(t): + """ + Return True if value is an instance of a fixed size list type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_FIXED_SIZE_LIST + + +def is_struct(t): + """ + Return True if value is an instance of a struct type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_STRUCT + + +def is_union(t): + """ + Return True if value is an instance of a union type. + + Parameters + ---------- + t : DataType + """ + return t.id in _UNION_TYPES + + +def is_nested(t): + """ + Return True if value is an instance of a nested type. + + Parameters + ---------- + t : DataType + """ + return t.id in _NESTED_TYPES + + +def is_temporal(t): + """ + Return True if value is an instance of date, time, timestamp or duration. + + Parameters + ---------- + t : DataType + """ + return t.id in _TEMPORAL_TYPES + + +def is_timestamp(t): + """ + Return True if value is an instance of a timestamp type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_TIMESTAMP + + +def is_duration(t): + """ + Return True if value is an instance of a duration type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_DURATION + + +def is_time(t): + """ + Return True if value is an instance of a time type. + + Parameters + ---------- + t : DataType + """ + return t.id in _TIME_TYPES + + +def is_time32(t): + """ + Return True if value is an instance of a time32 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_TIME32 + + +def is_time64(t): + """ + Return True if value is an instance of a time64 type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_TIME64 + + +def is_binary(t): + """ + Return True if value is an instance of a variable-length binary type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_BINARY + + +def is_large_binary(t): + """ + Return True if value is an instance of a large variable-length + binary type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_LARGE_BINARY + + +def is_unicode(t): + """ + Alias for is_string. + + Parameters + ---------- + t : DataType + """ + return is_string(t) + + +def is_string(t): + """ + Return True if value is an instance of string (utf8 unicode) type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_STRING + + +def is_large_unicode(t): + """ + Alias for is_large_string. + + Parameters + ---------- + t : DataType + """ + return is_large_string(t) + + +def is_large_string(t): + """ + Return True if value is an instance of large string (utf8 unicode) type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_LARGE_STRING + + +def is_fixed_size_binary(t): + """ + Return True if value is an instance of a fixed size binary type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_FIXED_SIZE_BINARY + + +def is_date(t): + """ + Return True if value is an instance of a date type. + + Parameters + ---------- + t : DataType + """ + return t.id in _DATE_TYPES + + +def is_date32(t): + """ + Return True if value is an instance of a date32 (days) type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_DATE32 + + +def is_date64(t): + """ + Return True if value is an instance of a date64 (milliseconds) type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_DATE64 + + +def is_map(t): + """ + Return True if value is an instance of a map logical type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_MAP + + +def is_decimal(t): + """ + Return True if value is an instance of a decimal type. + + Parameters + ---------- + t : DataType + """ + return t.id in _DECIMAL_TYPES + + +def is_decimal128(t): + """ + Return True if value is an instance of a decimal type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_DECIMAL128 + + +def is_decimal256(t): + """ + Return True if value is an instance of a decimal type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_DECIMAL256 + + +def is_dictionary(t): + """ + Return True if value is an instance of a dictionary-encoded type. + + Parameters + ---------- + t : DataType + """ + return t.id == lib.Type_DICTIONARY + + +def is_interval(t): + """ + Return True if the value is an instance of an interval type. + + Parameters + ---------- + t : DateType + """ + return t.id == lib.Type_INTERVAL_MONTH_DAY_NANO + + +def is_primitive(t): + """ + Return True if the value is an instance of a primitive type. + + Parameters + ---------- + t : DataType + """ + return lib._is_primitive(t.id) diff --git a/src/arrow/python/pyarrow/util.py b/src/arrow/python/pyarrow/util.py new file mode 100644 index 000000000..69bde250c --- /dev/null +++ b/src/arrow/python/pyarrow/util.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Miscellaneous utility code + +import contextlib +import functools +import gc +import pathlib +import socket +import sys +import types +import warnings + + +_DEPR_MSG = ( + "pyarrow.{} is deprecated as of {}, please use pyarrow.{} instead." +) + + +def implements(f): + def decorator(g): + g.__doc__ = f.__doc__ + return g + return decorator + + +def _deprecate_api(old_name, new_name, api, next_version): + msg = _DEPR_MSG.format(old_name, next_version, new_name) + + def wrapper(*args, **kwargs): + warnings.warn(msg, FutureWarning) + return api(*args, **kwargs) + return wrapper + + +def _deprecate_class(old_name, new_class, next_version, + instancecheck=True): + """ + Raise warning if a deprecated class is used in an isinstance check. + """ + class _DeprecatedMeta(type): + def __instancecheck__(self, other): + warnings.warn( + _DEPR_MSG.format(old_name, next_version, new_class.__name__), + FutureWarning, + stacklevel=2 + ) + return isinstance(other, new_class) + + return _DeprecatedMeta(old_name, (new_class,), {}) + + +def _is_iterable(obj): + try: + iter(obj) + return True + except TypeError: + return False + + +def _is_path_like(path): + # PEP519 filesystem path protocol is available from python 3.6, so pathlib + # doesn't implement __fspath__ for earlier versions + return (isinstance(path, str) or + hasattr(path, '__fspath__') or + isinstance(path, pathlib.Path)) + + +def _stringify_path(path): + """ + Convert *path* to a string or unicode path if possible. + """ + if isinstance(path, str): + return path + + # checking whether path implements the filesystem protocol + try: + return path.__fspath__() # new in python 3.6 + except AttributeError: + # fallback pathlib ckeck for earlier python versions than 3.6 + if isinstance(path, pathlib.Path): + return str(path) + + raise TypeError("not a path-like object") + + +def product(seq): + """ + Return a product of sequence items. + """ + return functools.reduce(lambda a, b: a*b, seq, 1) + + +def get_contiguous_span(shape, strides, itemsize): + """ + Return a contiguous span of N-D array data. + + Parameters + ---------- + shape : tuple + strides : tuple + itemsize : int + Specify array shape data + + Returns + ------- + start, end : int + The span end points. + """ + if not strides: + start = 0 + end = itemsize * product(shape) + else: + start = 0 + end = itemsize + for i, dim in enumerate(shape): + if dim == 0: + start = end = 0 + break + stride = strides[i] + if stride > 0: + end += stride * (dim - 1) + elif stride < 0: + start += stride * (dim - 1) + if end - start != itemsize * product(shape): + raise ValueError('array data is non-contiguous') + return start, end + + +def find_free_port(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with contextlib.closing(sock) as sock: + sock.bind(('', 0)) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return sock.getsockname()[1] + + +def guid(): + from uuid import uuid4 + return uuid4().hex + + +def _break_traceback_cycle_from_frame(frame): + # Clear local variables in all inner frames, so as to break the + # reference cycle. + this_frame = sys._getframe(0) + refs = gc.get_referrers(frame) + while refs: + for frame in refs: + if frame is not this_frame and isinstance(frame, types.FrameType): + break + else: + # No frame found in referrers (finished?) + break + refs = None + # Clear the frame locals, to try and break the cycle (it is + # somewhere along the chain of execution frames). + frame.clear() + # To visit the inner frame, we need to find it among the + # referers of this frame (while `frame.f_back` would let + # us visit the outer frame). + refs = gc.get_referrers(frame) + refs = frame = this_frame = None diff --git a/src/arrow/python/pyarrow/vendored/__init__.py b/src/arrow/python/pyarrow/vendored/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/src/arrow/python/pyarrow/vendored/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/src/arrow/python/pyarrow/vendored/version.py b/src/arrow/python/pyarrow/vendored/version.py new file mode 100644 index 000000000..b74f1da97 --- /dev/null +++ b/src/arrow/python/pyarrow/vendored/version.py @@ -0,0 +1,545 @@ +# Vendored from https://github.com/pypa/packaging, +# changeset b5878c977206f60302536db969a8cef420853ade + +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of the +# `packaging` repository for complete details. + +import collections +import itertools +import re +import warnings + +__all__ = ["parse", "Version", "LegacyVersion", + "InvalidVersion", "VERSION_PATTERN"] + + +class InfinityType: + def __repr__(self): + return "Infinity" + + def __hash__(self): + return hash(repr(self)) + + def __lt__(self, other): + return False + + def __le__(self, other): + return False + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __ne__(self, other): + return not isinstance(other, self.__class__) + + def __gt__(self, other): + return True + + def __ge__(self, other): + return True + + def __neg__(self): + return NegativeInfinity + + +Infinity = InfinityType() + + +class NegativeInfinityType: + def __repr__(self): + return "-Infinity" + + def __hash__(self): + return hash(repr(self)) + + def __lt__(self, other): + return True + + def __le__(self, other): + return True + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __ne__(self, other): + return not isinstance(other, self.__class__) + + def __gt__(self, other): + return False + + def __ge__(self, other): + return False + + def __neg__(self): + return Infinity + + +NegativeInfinity = NegativeInfinityType() + + +_Version = collections.namedtuple( + "_Version", ["epoch", "release", "dev", "pre", "post", "local"] +) + + +def parse(version): + """ + Parse the given version string and return either a :class:`Version` object + or a :class:`LegacyVersion` object depending on if the given version is + a valid PEP 440 version or a legacy version. + """ + try: + return Version(version) + except InvalidVersion: + return LegacyVersion(version) + + +class InvalidVersion(ValueError): + """ + An invalid version was found, users should refer to PEP 440. + """ + + +class _BaseVersion: + + def __hash__(self): + return hash(self._key) + + # Please keep the duplicated `isinstance` check + # in the six comparisons hereunder + # unless you find a way to avoid adding overhead function calls. + def __lt__(self, other): + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key < other._key + + def __le__(self, other): + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key <= other._key + + def __eq__(self, other): + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key == other._key + + def __ge__(self, other): + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key >= other._key + + def __gt__(self, other): + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key > other._key + + def __ne__(self, other): + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key != other._key + + +class LegacyVersion(_BaseVersion): + def __init__(self, version): + self._version = str(version) + self._key = _legacy_cmpkey(self._version) + + warnings.warn( + "Creating a LegacyVersion has been deprecated and will be " + "removed in the next major release", + DeprecationWarning, + ) + + def __str__(self): + return self._version + + def __repr__(self): + return f"<LegacyVersion('{self}')>" + + @property + def public(self): + return self._version + + @property + def base_version(self): + return self._version + + @property + def epoch(self): + return -1 + + @property + def release(self): + return None + + @property + def pre(self): + return None + + @property + def post(self): + return None + + @property + def dev(self): + return None + + @property + def local(self): + return None + + @property + def is_prerelease(self): + return False + + @property + def is_postrelease(self): + return False + + @property + def is_devrelease(self): + return False + + +_legacy_version_component_re = re.compile( + r"(\d+ | [a-z]+ | \.| -)", re.VERBOSE) + +_legacy_version_replacement_map = { + "pre": "c", + "preview": "c", + "-": "final-", + "rc": "c", + "dev": "@", +} + + +def _parse_version_parts(s): + for part in _legacy_version_component_re.split(s): + part = _legacy_version_replacement_map.get(part, part) + + if not part or part == ".": + continue + + if part[:1] in "0123456789": + # pad for numeric comparison + yield part.zfill(8) + else: + yield "*" + part + + # ensure that alpha/beta/candidate are before final + yield "*final" + + +def _legacy_cmpkey(version): + + # We hardcode an epoch of -1 here. A PEP 440 version can only have a epoch + # greater than or equal to 0. This will effectively put the LegacyVersion, + # which uses the defacto standard originally implemented by setuptools, + # as before all PEP 440 versions. + epoch = -1 + + # This scheme is taken from pkg_resources.parse_version setuptools prior to + # it's adoption of the packaging library. + parts = [] + for part in _parse_version_parts(version.lower()): + if part.startswith("*"): + # remove "-" before a prerelease tag + if part < "*final": + while parts and parts[-1] == "*final-": + parts.pop() + + # remove trailing zeros from each series of numeric parts + while parts and parts[-1] == "00000000": + parts.pop() + + parts.append(part) + + return epoch, tuple(parts) + + +# Deliberately not anchored to the start and end of the string, to make it +# easier for 3rd party code to reuse +VERSION_PATTERN = r""" + v? + (?: + (?:(?P<epoch>[0-9]+)!)? # epoch + (?P<release>[0-9]+(?:\.[0-9]+)*) # release segment + (?P<pre> # pre-release + [-_\.]? + (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview)) + [-_\.]? + (?P<pre_n>[0-9]+)? + )? + (?P<post> # post release + (?:-(?P<post_n1>[0-9]+)) + | + (?: + [-_\.]? + (?P<post_l>post|rev|r) + [-_\.]? + (?P<post_n2>[0-9]+)? + ) + )? + (?P<dev> # dev release + [-_\.]? + (?P<dev_l>dev) + [-_\.]? + (?P<dev_n>[0-9]+)? + )? + ) + (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version +""" + + +class Version(_BaseVersion): + + _regex = re.compile(r"^\s*" + VERSION_PATTERN + + r"\s*$", re.VERBOSE | re.IGNORECASE) + + def __init__(self, version): + + # Validate the version and parse it into pieces + match = self._regex.search(version) + if not match: + raise InvalidVersion(f"Invalid version: '{version}'") + + # Store the parsed out pieces of the version + self._version = _Version( + epoch=int(match.group("epoch")) if match.group("epoch") else 0, + release=tuple(int(i) for i in match.group("release").split(".")), + pre=_parse_letter_version( + match.group("pre_l"), match.group("pre_n")), + post=_parse_letter_version( + match.group("post_l"), match.group( + "post_n1") or match.group("post_n2") + ), + dev=_parse_letter_version( + match.group("dev_l"), match.group("dev_n")), + local=_parse_local_version(match.group("local")), + ) + + # Generate a key which will be used for sorting + self._key = _cmpkey( + self._version.epoch, + self._version.release, + self._version.pre, + self._version.post, + self._version.dev, + self._version.local, + ) + + def __repr__(self): + return f"<Version('{self}')>" + + def __str__(self): + parts = [] + + # Epoch + if self.epoch != 0: + parts.append(f"{self.epoch}!") + + # Release segment + parts.append(".".join(str(x) for x in self.release)) + + # Pre-release + if self.pre is not None: + parts.append("".join(str(x) for x in self.pre)) + + # Post-release + if self.post is not None: + parts.append(f".post{self.post}") + + # Development release + if self.dev is not None: + parts.append(f".dev{self.dev}") + + # Local version segment + if self.local is not None: + parts.append(f"+{self.local}") + + return "".join(parts) + + @property + def epoch(self): + _epoch = self._version.epoch + return _epoch + + @property + def release(self): + _release = self._version.release + return _release + + @property + def pre(self): + _pre = self._version.pre + return _pre + + @property + def post(self): + return self._version.post[1] if self._version.post else None + + @property + def dev(self): + return self._version.dev[1] if self._version.dev else None + + @property + def local(self): + if self._version.local: + return ".".join(str(x) for x in self._version.local) + else: + return None + + @property + def public(self): + return str(self).split("+", 1)[0] + + @property + def base_version(self): + parts = [] + + # Epoch + if self.epoch != 0: + parts.append(f"{self.epoch}!") + + # Release segment + parts.append(".".join(str(x) for x in self.release)) + + return "".join(parts) + + @property + def is_prerelease(self): + return self.dev is not None or self.pre is not None + + @property + def is_postrelease(self): + return self.post is not None + + @property + def is_devrelease(self): + return self.dev is not None + + @property + def major(self): + return self.release[0] if len(self.release) >= 1 else 0 + + @property + def minor(self): + return self.release[1] if len(self.release) >= 2 else 0 + + @property + def micro(self): + return self.release[2] if len(self.release) >= 3 else 0 + + +def _parse_letter_version(letter, number): + + if letter: + # We consider there to be an implicit 0 in a pre-release if there is + # not a numeral associated with it. + if number is None: + number = 0 + + # We normalize any letters to their lower case form + letter = letter.lower() + + # We consider some words to be alternate spellings of other words and + # in those cases we want to normalize the spellings to our preferred + # spelling. + if letter == "alpha": + letter = "a" + elif letter == "beta": + letter = "b" + elif letter in ["c", "pre", "preview"]: + letter = "rc" + elif letter in ["rev", "r"]: + letter = "post" + + return letter, int(number) + if not letter and number: + # We assume if we are given a number, but we are not given a letter + # then this is using the implicit post release syntax (e.g. 1.0-1) + letter = "post" + + return letter, int(number) + + return None + + +_local_version_separators = re.compile(r"[\._-]") + + +def _parse_local_version(local): + """ + Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve"). + """ + if local is not None: + return tuple( + part.lower() if not part.isdigit() else int(part) + for part in _local_version_separators.split(local) + ) + return None + + +def _cmpkey(epoch, release, pre, post, dev, local): + + # When we compare a release version, we want to compare it with all of the + # trailing zeros removed. So we'll use a reverse the list, drop all the now + # leading zeros until we come to something non zero, then take the rest + # re-reverse it back into the correct order and make it a tuple and use + # that for our sorting key. + _release = tuple( + reversed(list(itertools.dropwhile(lambda x: x == 0, + reversed(release)))) + ) + + # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0. + # We'll do this by abusing the pre segment, but we _only_ want to do this + # if there is not a pre or a post segment. If we have one of those then + # the normal sorting rules will handle this case correctly. + if pre is None and post is None and dev is not None: + _pre = NegativeInfinity + # Versions without a pre-release (except as noted above) should sort after + # those with one. + elif pre is None: + _pre = Infinity + else: + _pre = pre + + # Versions without a post segment should sort before those with one. + if post is None: + _post = NegativeInfinity + + else: + _post = post + + # Versions without a development segment should sort after those with one. + if dev is None: + _dev = Infinity + + else: + _dev = dev + + if local is None: + # Versions without a local segment should sort before those with one. + _local = NegativeInfinity + else: + # Versions with a local segment need that segment parsed to implement + # the sorting rules in PEP440. + # - Alpha numeric segments sort before numeric segments + # - Alpha numeric segments sort lexicographically + # - Numeric segments sort numerically + # - Shorter versions sort before longer versions when the prefixes + # match exactly + _local = tuple( + (i, "") if isinstance(i, int) else (NegativeInfinity, i) + for i in local + ) + + return epoch, _release, _pre, _post, _dev, _local |