diff options
Diffstat (limited to 'src/arrow/python/pyarrow/tests/test_ipc.py')
-rw-r--r-- | src/arrow/python/pyarrow/tests/test_ipc.py | 999 |
1 files changed, 999 insertions, 0 deletions
diff --git a/src/arrow/python/pyarrow/tests/test_ipc.py b/src/arrow/python/pyarrow/tests/test_ipc.py new file mode 100644 index 000000000..87944bcc0 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_ipc.py @@ -0,0 +1,999 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from collections import UserList +import io +import pathlib +import pytest +import socket +import threading +import weakref + +import numpy as np + +import pyarrow as pa +from pyarrow.tests.util import changed_environ + + +try: + from pandas.testing import assert_frame_equal, assert_series_equal + import pandas as pd +except ImportError: + pass + + +class IpcFixture: + write_stats = None + + def __init__(self, sink_factory=lambda: io.BytesIO()): + self._sink_factory = sink_factory + self.sink = self.get_sink() + + def get_sink(self): + return self._sink_factory() + + def get_source(self): + return self.sink.getvalue() + + def write_batches(self, num_batches=5, as_table=False): + nrows = 5 + schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())]) + + writer = self._get_writer(self.sink, schema) + + batches = [] + for i in range(num_batches): + batch = pa.record_batch( + [np.random.randn(nrows), + ['foo', None, 'bar', 'bazbaz', 'qux']], + schema=schema) + batches.append(batch) + + if as_table: + table = pa.Table.from_batches(batches) + writer.write_table(table) + else: + for batch in batches: + writer.write_batch(batch) + + self.write_stats = writer.stats + writer.close() + return batches + + +class FileFormatFixture(IpcFixture): + + def _get_writer(self, sink, schema): + return pa.ipc.new_file(sink, schema) + + def _check_roundtrip(self, as_table=False): + batches = self.write_batches(as_table=as_table) + file_contents = pa.BufferReader(self.get_source()) + + reader = pa.ipc.open_file(file_contents) + + assert reader.num_record_batches == len(batches) + + for i, batch in enumerate(batches): + # it works. Must convert back to DataFrame + batch = reader.get_batch(i) + assert batches[i].equals(batch) + assert reader.schema.equals(batches[0].schema) + + assert isinstance(reader.stats, pa.ipc.ReadStats) + assert isinstance(self.write_stats, pa.ipc.WriteStats) + assert tuple(reader.stats) == tuple(self.write_stats) + + +class StreamFormatFixture(IpcFixture): + + # ARROW-6474, for testing writing old IPC protocol with 4-byte prefix + use_legacy_ipc_format = False + # ARROW-9395, for testing writing old metadata version + options = None + + def _get_writer(self, sink, schema): + return pa.ipc.new_stream( + sink, + schema, + use_legacy_format=self.use_legacy_ipc_format, + options=self.options, + ) + + +class MessageFixture(IpcFixture): + + def _get_writer(self, sink, schema): + return pa.RecordBatchStreamWriter(sink, schema) + + +@pytest.fixture +def ipc_fixture(): + return IpcFixture() + + +@pytest.fixture +def file_fixture(): + return FileFormatFixture() + + +@pytest.fixture +def stream_fixture(): + return StreamFormatFixture() + + +def test_empty_file(): + buf = b'' + with pytest.raises(pa.ArrowInvalid): + pa.ipc.open_file(pa.BufferReader(buf)) + + +def test_file_simple_roundtrip(file_fixture): + file_fixture._check_roundtrip(as_table=False) + + +def test_file_write_table(file_fixture): + file_fixture._check_roundtrip(as_table=True) + + +@pytest.mark.parametrize("sink_factory", [ + lambda: io.BytesIO(), + lambda: pa.BufferOutputStream() +]) +def test_file_read_all(sink_factory): + fixture = FileFormatFixture(sink_factory) + + batches = fixture.write_batches() + file_contents = pa.BufferReader(fixture.get_source()) + + reader = pa.ipc.open_file(file_contents) + + result = reader.read_all() + expected = pa.Table.from_batches(batches) + assert result.equals(expected) + + +def test_open_file_from_buffer(file_fixture): + # ARROW-2859; APIs accept the buffer protocol + file_fixture.write_batches() + source = file_fixture.get_source() + + reader1 = pa.ipc.open_file(source) + reader2 = pa.ipc.open_file(pa.BufferReader(source)) + reader3 = pa.RecordBatchFileReader(source) + + result1 = reader1.read_all() + result2 = reader2.read_all() + result3 = reader3.read_all() + + assert result1.equals(result2) + assert result1.equals(result3) + + st1 = reader1.stats + assert st1.num_messages == 6 + assert st1.num_record_batches == 5 + assert reader2.stats == st1 + assert reader3.stats == st1 + + +@pytest.mark.pandas +def test_file_read_pandas(file_fixture): + frames = [batch.to_pandas() for batch in file_fixture.write_batches()] + + file_contents = pa.BufferReader(file_fixture.get_source()) + reader = pa.ipc.open_file(file_contents) + result = reader.read_pandas() + + expected = pd.concat(frames).reset_index(drop=True) + assert_frame_equal(result, expected) + + +def test_file_pathlib(file_fixture, tmpdir): + file_fixture.write_batches() + source = file_fixture.get_source() + + path = tmpdir.join('file.arrow').strpath + with open(path, 'wb') as f: + f.write(source) + + t1 = pa.ipc.open_file(pathlib.Path(path)).read_all() + t2 = pa.ipc.open_file(pa.OSFile(path)).read_all() + + assert t1.equals(t2) + + +def test_empty_stream(): + buf = io.BytesIO(b'') + with pytest.raises(pa.ArrowInvalid): + pa.ipc.open_stream(buf) + + +@pytest.mark.pandas +def test_stream_categorical_roundtrip(stream_fixture): + df = pd.DataFrame({ + 'one': np.random.randn(5), + 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'], + categories=['foo', 'bar'], + ordered=True) + }) + batch = pa.RecordBatch.from_pandas(df) + with stream_fixture._get_writer(stream_fixture.sink, batch.schema) as wr: + wr.write_batch(batch) + + table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source())) + .read_all()) + assert_frame_equal(table.to_pandas(), df) + + +def test_open_stream_from_buffer(stream_fixture): + # ARROW-2859 + stream_fixture.write_batches() + source = stream_fixture.get_source() + + reader1 = pa.ipc.open_stream(source) + reader2 = pa.ipc.open_stream(pa.BufferReader(source)) + reader3 = pa.RecordBatchStreamReader(source) + + result1 = reader1.read_all() + result2 = reader2.read_all() + result3 = reader3.read_all() + + assert result1.equals(result2) + assert result1.equals(result3) + + st1 = reader1.stats + assert st1.num_messages == 6 + assert st1.num_record_batches == 5 + assert reader2.stats == st1 + assert reader3.stats == st1 + + assert tuple(st1) == tuple(stream_fixture.write_stats) + + +@pytest.mark.pandas +def test_stream_write_dispatch(stream_fixture): + # ARROW-1616 + df = pd.DataFrame({ + 'one': np.random.randn(5), + 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'], + categories=['foo', 'bar'], + ordered=True) + }) + table = pa.Table.from_pandas(df, preserve_index=False) + batch = pa.RecordBatch.from_pandas(df, preserve_index=False) + with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr: + wr.write(table) + wr.write(batch) + + table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source())) + .read_all()) + assert_frame_equal(table.to_pandas(), + pd.concat([df, df], ignore_index=True)) + + +@pytest.mark.pandas +def test_stream_write_table_batches(stream_fixture): + # ARROW-504 + df = pd.DataFrame({ + 'one': np.random.randn(20), + }) + + b1 = pa.RecordBatch.from_pandas(df[:10], preserve_index=False) + b2 = pa.RecordBatch.from_pandas(df, preserve_index=False) + + table = pa.Table.from_batches([b1, b2, b1]) + + with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr: + wr.write_table(table, max_chunksize=15) + + batches = list(pa.ipc.open_stream(stream_fixture.get_source())) + + assert list(map(len, batches)) == [10, 15, 5, 10] + result_table = pa.Table.from_batches(batches) + assert_frame_equal(result_table.to_pandas(), + pd.concat([df[:10], df, df[:10]], + ignore_index=True)) + + +@pytest.mark.parametrize('use_legacy_ipc_format', [False, True]) +def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format): + stream_fixture.use_legacy_ipc_format = use_legacy_ipc_format + batches = stream_fixture.write_batches() + file_contents = pa.BufferReader(stream_fixture.get_source()) + reader = pa.ipc.open_stream(file_contents) + + assert reader.schema.equals(batches[0].schema) + + total = 0 + for i, next_batch in enumerate(reader): + assert next_batch.equals(batches[i]) + total += 1 + + assert total == len(batches) + + with pytest.raises(StopIteration): + reader.read_next_batch() + + +@pytest.mark.zstd +def test_compression_roundtrip(): + sink = io.BytesIO() + values = np.random.randint(0, 10, 10000) + table = pa.Table.from_arrays([values], names=["values"]) + + options = pa.ipc.IpcWriteOptions(compression='zstd') + with pa.ipc.RecordBatchFileWriter( + sink, table.schema, options=options) as writer: + writer.write_table(table) + len1 = len(sink.getvalue()) + + sink2 = io.BytesIO() + codec = pa.Codec('zstd', compression_level=5) + options = pa.ipc.IpcWriteOptions(compression=codec) + with pa.ipc.RecordBatchFileWriter( + sink2, table.schema, options=options) as writer: + writer.write_table(table) + len2 = len(sink2.getvalue()) + + # In theory len2 should be less than len1 but for this test we just want + # to ensure compression_level is being correctly passed down to the C++ + # layer so we don't really care if it makes it worse or better + assert len2 != len1 + + t1 = pa.ipc.open_file(sink).read_all() + t2 = pa.ipc.open_file(sink2).read_all() + + assert t1 == t2 + + +def test_write_options(): + options = pa.ipc.IpcWriteOptions() + assert options.allow_64bit is False + assert options.use_legacy_format is False + assert options.metadata_version == pa.ipc.MetadataVersion.V5 + + options.allow_64bit = True + assert options.allow_64bit is True + + options.use_legacy_format = True + assert options.use_legacy_format is True + + options.metadata_version = pa.ipc.MetadataVersion.V4 + assert options.metadata_version == pa.ipc.MetadataVersion.V4 + for value in ('V5', 42): + with pytest.raises((TypeError, ValueError)): + options.metadata_version = value + + assert options.compression is None + for value in ['lz4', 'zstd']: + if pa.Codec.is_available(value): + options.compression = value + assert options.compression == value + options.compression = value.upper() + assert options.compression == value + options.compression = None + assert options.compression is None + + with pytest.raises(TypeError): + options.compression = 0 + + assert options.use_threads is True + options.use_threads = False + assert options.use_threads is False + + if pa.Codec.is_available('lz4'): + options = pa.ipc.IpcWriteOptions( + metadata_version=pa.ipc.MetadataVersion.V4, + allow_64bit=True, + use_legacy_format=True, + compression='lz4', + use_threads=False) + assert options.metadata_version == pa.ipc.MetadataVersion.V4 + assert options.allow_64bit is True + assert options.use_legacy_format is True + assert options.compression == 'lz4' + assert options.use_threads is False + + +def test_write_options_legacy_exclusive(stream_fixture): + with pytest.raises( + ValueError, + match="provide at most one of options and use_legacy_format"): + stream_fixture.use_legacy_ipc_format = True + stream_fixture.options = pa.ipc.IpcWriteOptions() + stream_fixture.write_batches() + + +@pytest.mark.parametrize('options', [ + pa.ipc.IpcWriteOptions(), + pa.ipc.IpcWriteOptions(allow_64bit=True), + pa.ipc.IpcWriteOptions(use_legacy_format=True), + pa.ipc.IpcWriteOptions(metadata_version=pa.ipc.MetadataVersion.V4), + pa.ipc.IpcWriteOptions(use_legacy_format=True, + metadata_version=pa.ipc.MetadataVersion.V4), +]) +def test_stream_options_roundtrip(stream_fixture, options): + stream_fixture.use_legacy_ipc_format = None + stream_fixture.options = options + batches = stream_fixture.write_batches() + file_contents = pa.BufferReader(stream_fixture.get_source()) + + message = pa.ipc.read_message(stream_fixture.get_source()) + assert message.metadata_version == options.metadata_version + + reader = pa.ipc.open_stream(file_contents) + + assert reader.schema.equals(batches[0].schema) + + total = 0 + for i, next_batch in enumerate(reader): + assert next_batch.equals(batches[i]) + total += 1 + + assert total == len(batches) + + with pytest.raises(StopIteration): + reader.read_next_batch() + + +def test_dictionary_delta(stream_fixture): + ty = pa.dictionary(pa.int8(), pa.utf8()) + data = [["foo", "foo", None], + ["foo", "bar", "foo"], # potential delta + ["foo", "bar"], + ["foo", None, "bar", "quux"], # potential delta + ["bar", "quux"], # replacement + ] + batches = [ + pa.RecordBatch.from_arrays([pa.array(v, type=ty)], names=['dicts']) + for v in data] + schema = batches[0].schema + + def write_batches(): + with stream_fixture._get_writer(pa.MockOutputStream(), + schema) as writer: + for batch in batches: + writer.write_batch(batch) + return writer.stats + + st = write_batches() + assert st.num_record_batches == 5 + assert st.num_dictionary_batches == 4 + assert st.num_replaced_dictionaries == 3 + assert st.num_dictionary_deltas == 0 + + stream_fixture.use_legacy_ipc_format = None + stream_fixture.options = pa.ipc.IpcWriteOptions( + emit_dictionary_deltas=True) + st = write_batches() + assert st.num_record_batches == 5 + assert st.num_dictionary_batches == 4 + assert st.num_replaced_dictionaries == 1 + assert st.num_dictionary_deltas == 2 + + +def test_envvar_set_legacy_ipc_format(): + schema = pa.schema([pa.field('foo', pa.int32())]) + + writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V5 + writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V5 + + with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'): + writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V5 + writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V5 + + with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'): + writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V4 + writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) + assert not writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V4 + + with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'): + with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'): + writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V4 + writer = pa.ipc.new_file(pa.BufferOutputStream(), schema) + assert writer._use_legacy_format + assert writer._metadata_version == pa.ipc.MetadataVersion.V4 + + +def test_stream_read_all(stream_fixture): + batches = stream_fixture.write_batches() + file_contents = pa.BufferReader(stream_fixture.get_source()) + reader = pa.ipc.open_stream(file_contents) + + result = reader.read_all() + expected = pa.Table.from_batches(batches) + assert result.equals(expected) + + +@pytest.mark.pandas +def test_stream_read_pandas(stream_fixture): + frames = [batch.to_pandas() for batch in stream_fixture.write_batches()] + file_contents = stream_fixture.get_source() + reader = pa.ipc.open_stream(file_contents) + result = reader.read_pandas() + + expected = pd.concat(frames).reset_index(drop=True) + assert_frame_equal(result, expected) + + +@pytest.fixture +def example_messages(stream_fixture): + batches = stream_fixture.write_batches() + file_contents = stream_fixture.get_source() + buf_reader = pa.BufferReader(file_contents) + reader = pa.MessageReader.open_stream(buf_reader) + return batches, list(reader) + + +def test_message_ctors_no_segfault(): + with pytest.raises(TypeError): + repr(pa.Message()) + + with pytest.raises(TypeError): + repr(pa.MessageReader()) + + +def test_message_reader(example_messages): + _, messages = example_messages + + assert len(messages) == 6 + assert messages[0].type == 'schema' + assert isinstance(messages[0].metadata, pa.Buffer) + assert isinstance(messages[0].body, pa.Buffer) + assert messages[0].metadata_version == pa.MetadataVersion.V5 + + for msg in messages[1:]: + assert msg.type == 'record batch' + assert isinstance(msg.metadata, pa.Buffer) + assert isinstance(msg.body, pa.Buffer) + assert msg.metadata_version == pa.MetadataVersion.V5 + + +def test_message_serialize_read_message(example_messages): + _, messages = example_messages + + msg = messages[0] + buf = msg.serialize() + reader = pa.BufferReader(buf.to_pybytes() * 2) + + restored = pa.ipc.read_message(buf) + restored2 = pa.ipc.read_message(reader) + restored3 = pa.ipc.read_message(buf.to_pybytes()) + restored4 = pa.ipc.read_message(reader) + + assert msg.equals(restored) + assert msg.equals(restored2) + assert msg.equals(restored3) + assert msg.equals(restored4) + + with pytest.raises(pa.ArrowInvalid, match="Corrupted message"): + pa.ipc.read_message(pa.BufferReader(b'ab')) + + with pytest.raises(EOFError): + pa.ipc.read_message(reader) + + +@pytest.mark.gzip +def test_message_read_from_compressed(example_messages): + # Part of ARROW-5910 + _, messages = example_messages + for message in messages: + raw_out = pa.BufferOutputStream() + with pa.output_stream(raw_out, compression='gzip') as compressed_out: + message.serialize_to(compressed_out) + + compressed_buf = raw_out.getvalue() + + result = pa.ipc.read_message(pa.input_stream(compressed_buf, + compression='gzip')) + assert result.equals(message) + + +def test_message_read_record_batch(example_messages): + batches, messages = example_messages + + for batch, message in zip(batches, messages[1:]): + read_batch = pa.ipc.read_record_batch(message, batch.schema) + assert read_batch.equals(batch) + + +def test_read_record_batch_on_stream_error_message(): + # ARROW-5374 + batch = pa.record_batch([pa.array([b"foo"], type=pa.utf8())], + names=['strs']) + stream = pa.BufferOutputStream() + with pa.ipc.new_stream(stream, batch.schema) as writer: + writer.write_batch(batch) + buf = stream.getvalue() + with pytest.raises(IOError, + match="type record batch but got schema"): + pa.ipc.read_record_batch(buf, batch.schema) + + +# ---------------------------------------------------------------------- +# Socket streaming testa + + +class StreamReaderServer(threading.Thread): + + def init(self, do_read_all): + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock.bind(('127.0.0.1', 0)) + self._sock.listen(1) + host, port = self._sock.getsockname() + self._do_read_all = do_read_all + self._schema = None + self._batches = [] + self._table = None + return port + + def run(self): + connection, client_address = self._sock.accept() + try: + source = connection.makefile(mode='rb') + reader = pa.ipc.open_stream(source) + self._schema = reader.schema + if self._do_read_all: + self._table = reader.read_all() + else: + for i, batch in enumerate(reader): + self._batches.append(batch) + finally: + connection.close() + + def get_result(self): + return(self._schema, self._table if self._do_read_all + else self._batches) + + +class SocketStreamFixture(IpcFixture): + + def __init__(self): + # XXX(wesm): test will decide when to start socket server. This should + # probably be refactored + pass + + def start_server(self, do_read_all): + self._server = StreamReaderServer() + port = self._server.init(do_read_all) + self._server.start() + self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._sock.connect(('127.0.0.1', port)) + self.sink = self.get_sink() + + def stop_and_get_result(self): + import struct + self.sink.write(struct.pack('Q', 0)) + self.sink.flush() + self._sock.close() + self._server.join() + return self._server.get_result() + + def get_sink(self): + return self._sock.makefile(mode='wb') + + def _get_writer(self, sink, schema): + return pa.RecordBatchStreamWriter(sink, schema) + + +@pytest.fixture +def socket_fixture(): + return SocketStreamFixture() + + +def test_socket_simple_roundtrip(socket_fixture): + socket_fixture.start_server(do_read_all=False) + writer_batches = socket_fixture.write_batches() + reader_schema, reader_batches = socket_fixture.stop_and_get_result() + + assert reader_schema.equals(writer_batches[0].schema) + assert len(reader_batches) == len(writer_batches) + for i, batch in enumerate(writer_batches): + assert reader_batches[i].equals(batch) + + +def test_socket_read_all(socket_fixture): + socket_fixture.start_server(do_read_all=True) + writer_batches = socket_fixture.write_batches() + _, result = socket_fixture.stop_and_get_result() + + expected = pa.Table.from_batches(writer_batches) + assert result.equals(expected) + + +# ---------------------------------------------------------------------- +# Miscellaneous IPC tests + +@pytest.mark.pandas +def test_ipc_file_stream_has_eos(): + # ARROW-5395 + df = pd.DataFrame({'foo': [1.5]}) + batch = pa.RecordBatch.from_pandas(df) + sink = pa.BufferOutputStream() + write_file(batch, sink) + buffer = sink.getvalue() + + # skip the file magic + reader = pa.ipc.open_stream(buffer[8:]) + + # will fail if encounters footer data instead of eos + rdf = reader.read_pandas() + + assert_frame_equal(df, rdf) + + +@pytest.mark.pandas +def test_ipc_zero_copy_numpy(): + df = pd.DataFrame({'foo': [1.5]}) + + batch = pa.RecordBatch.from_pandas(df) + sink = pa.BufferOutputStream() + write_file(batch, sink) + buffer = sink.getvalue() + reader = pa.BufferReader(buffer) + + batches = read_file(reader) + + data = batches[0].to_pandas() + rdf = pd.DataFrame(data) + assert_frame_equal(df, rdf) + + +def test_ipc_stream_no_batches(): + # ARROW-2307 + table = pa.Table.from_arrays([pa.array([1, 2, 3, 4]), + pa.array(['foo', 'bar', 'baz', 'qux'])], + names=['a', 'b']) + + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, table.schema): + pass + + source = sink.getvalue() + with pa.ipc.open_stream(source) as reader: + result = reader.read_all() + + assert result.schema.equals(table.schema) + assert len(result) == 0 + + +@pytest.mark.pandas +def test_get_record_batch_size(): + N = 10 + itemsize = 8 + df = pd.DataFrame({'foo': np.random.randn(N)}) + + batch = pa.RecordBatch.from_pandas(df) + assert pa.ipc.get_record_batch_size(batch) > (N * itemsize) + + +@pytest.mark.pandas +def _check_serialize_pandas_round_trip(df, use_threads=False): + buf = pa.serialize_pandas(df, nthreads=2 if use_threads else 1) + result = pa.deserialize_pandas(buf, use_threads=use_threads) + assert_frame_equal(result, df) + + +@pytest.mark.pandas +def test_pandas_serialize_round_trip(): + index = pd.Index([1, 2, 3], name='my_index') + columns = ['foo', 'bar'] + df = pd.DataFrame( + {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')}, + index=index, columns=columns + ) + _check_serialize_pandas_round_trip(df) + + +@pytest.mark.pandas +def test_pandas_serialize_round_trip_nthreads(): + index = pd.Index([1, 2, 3], name='my_index') + columns = ['foo', 'bar'] + df = pd.DataFrame( + {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')}, + index=index, columns=columns + ) + _check_serialize_pandas_round_trip(df, use_threads=True) + + +@pytest.mark.pandas +def test_pandas_serialize_round_trip_multi_index(): + index1 = pd.Index([1, 2, 3], name='level_1') + index2 = pd.Index(list('def'), name=None) + index = pd.MultiIndex.from_arrays([index1, index2]) + + columns = ['foo', 'bar'] + df = pd.DataFrame( + {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')}, + index=index, + columns=columns, + ) + _check_serialize_pandas_round_trip(df) + + +@pytest.mark.pandas +def test_serialize_pandas_empty_dataframe(): + df = pd.DataFrame() + _check_serialize_pandas_round_trip(df) + + +@pytest.mark.pandas +def test_pandas_serialize_round_trip_not_string_columns(): + df = pd.DataFrame(list(zip([1.5, 1.6, 1.7], 'abc'))) + buf = pa.serialize_pandas(df) + result = pa.deserialize_pandas(buf) + assert_frame_equal(result, df) + + +@pytest.mark.pandas +def test_serialize_pandas_no_preserve_index(): + df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3]) + expected = pd.DataFrame({'a': [1, 2, 3]}) + + buf = pa.serialize_pandas(df, preserve_index=False) + result = pa.deserialize_pandas(buf) + assert_frame_equal(result, expected) + + buf = pa.serialize_pandas(df, preserve_index=True) + result = pa.deserialize_pandas(buf) + assert_frame_equal(result, df) + + +@pytest.mark.pandas +@pytest.mark.filterwarnings("ignore:'pyarrow:FutureWarning") +def test_serialize_with_pandas_objects(): + df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3]) + s = pd.Series([1, 2, 3, 4]) + + data = { + 'a_series': df['a'], + 'a_frame': df, + 's_series': s + } + + serialized = pa.serialize(data).to_buffer() + deserialized = pa.deserialize(serialized) + assert_frame_equal(deserialized['a_frame'], df) + + assert_series_equal(deserialized['a_series'], df['a']) + assert deserialized['a_series'].name == 'a' + + assert_series_equal(deserialized['s_series'], s) + assert deserialized['s_series'].name is None + + +@pytest.mark.pandas +def test_schema_batch_serialize_methods(): + nrows = 5 + df = pd.DataFrame({ + 'one': np.random.randn(nrows), + 'two': ['foo', np.nan, 'bar', 'bazbaz', 'qux']}) + batch = pa.RecordBatch.from_pandas(df) + + s_schema = batch.schema.serialize() + s_batch = batch.serialize() + + recons_schema = pa.ipc.read_schema(s_schema) + recons_batch = pa.ipc.read_record_batch(s_batch, recons_schema) + assert recons_batch.equals(batch) + + +def test_schema_serialization_with_metadata(): + field_metadata = {b'foo': b'bar', b'kind': b'field'} + schema_metadata = {b'foo': b'bar', b'kind': b'schema'} + + f0 = pa.field('a', pa.int8()) + f1 = pa.field('b', pa.string(), metadata=field_metadata) + + schema = pa.schema([f0, f1], metadata=schema_metadata) + + s_schema = schema.serialize() + recons_schema = pa.ipc.read_schema(s_schema) + + assert recons_schema.equals(schema) + assert recons_schema.metadata == schema_metadata + assert recons_schema[0].metadata is None + assert recons_schema[1].metadata == field_metadata + + +def test_deprecated_pyarrow_ns_apis(): + table = pa.table([pa.array([1, 2, 3, 4])], names=['a']) + sink = pa.BufferOutputStream() + with pa.ipc.new_stream(sink, table.schema) as writer: + writer.write(table) + + with pytest.warns(FutureWarning, + match="please use pyarrow.ipc.open_stream"): + pa.open_stream(sink.getvalue()) + + sink = pa.BufferOutputStream() + with pa.ipc.new_file(sink, table.schema) as writer: + writer.write(table) + with pytest.warns(FutureWarning, match="please use pyarrow.ipc.open_file"): + pa.open_file(sink.getvalue()) + + +def write_file(batch, sink): + with pa.ipc.new_file(sink, batch.schema) as writer: + writer.write_batch(batch) + + +def read_file(source): + with pa.ipc.open_file(source) as reader: + return [reader.get_batch(i) for i in range(reader.num_record_batches)] + + +def test_write_empty_ipc_file(): + # ARROW-3894: IPC file was not being properly initialized when no record + # batches are being written + schema = pa.schema([('field', pa.int64())]) + + sink = pa.BufferOutputStream() + with pa.ipc.new_file(sink, schema): + pass + + buf = sink.getvalue() + with pa.RecordBatchFileReader(pa.BufferReader(buf)) as reader: + table = reader.read_all() + assert len(table) == 0 + assert table.schema.equals(schema) + + +def test_py_record_batch_reader(): + def make_schema(): + return pa.schema([('field', pa.int64())]) + + def make_batches(): + schema = make_schema() + batch1 = pa.record_batch([[1, 2, 3]], schema=schema) + batch2 = pa.record_batch([[4, 5]], schema=schema) + return [batch1, batch2] + + # With iterable + batches = UserList(make_batches()) # weakrefable + wr = weakref.ref(batches) + + with pa.ipc.RecordBatchReader.from_batches(make_schema(), + batches) as reader: + batches = None + assert wr() is not None + assert list(reader) == make_batches() + assert wr() is None + + # With iterator + batches = iter(UserList(make_batches())) # weakrefable + wr = weakref.ref(batches) + + with pa.ipc.RecordBatchReader.from_batches(make_schema(), + batches) as reader: + batches = None + assert wr() is not None + assert list(reader) == make_batches() + assert wr() is None |