summaryrefslogtreecommitdiffstats
path: root/src/arrow/python/pyarrow/tests/test_ipc.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/python/pyarrow/tests/test_ipc.py')
-rw-r--r--src/arrow/python/pyarrow/tests/test_ipc.py999
1 files changed, 999 insertions, 0 deletions
diff --git a/src/arrow/python/pyarrow/tests/test_ipc.py b/src/arrow/python/pyarrow/tests/test_ipc.py
new file mode 100644
index 000000000..87944bcc0
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_ipc.py
@@ -0,0 +1,999 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from collections import UserList
+import io
+import pathlib
+import pytest
+import socket
+import threading
+import weakref
+
+import numpy as np
+
+import pyarrow as pa
+from pyarrow.tests.util import changed_environ
+
+
+try:
+ from pandas.testing import assert_frame_equal, assert_series_equal
+ import pandas as pd
+except ImportError:
+ pass
+
+
+class IpcFixture:
+ write_stats = None
+
+ def __init__(self, sink_factory=lambda: io.BytesIO()):
+ self._sink_factory = sink_factory
+ self.sink = self.get_sink()
+
+ def get_sink(self):
+ return self._sink_factory()
+
+ def get_source(self):
+ return self.sink.getvalue()
+
+ def write_batches(self, num_batches=5, as_table=False):
+ nrows = 5
+ schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())])
+
+ writer = self._get_writer(self.sink, schema)
+
+ batches = []
+ for i in range(num_batches):
+ batch = pa.record_batch(
+ [np.random.randn(nrows),
+ ['foo', None, 'bar', 'bazbaz', 'qux']],
+ schema=schema)
+ batches.append(batch)
+
+ if as_table:
+ table = pa.Table.from_batches(batches)
+ writer.write_table(table)
+ else:
+ for batch in batches:
+ writer.write_batch(batch)
+
+ self.write_stats = writer.stats
+ writer.close()
+ return batches
+
+
+class FileFormatFixture(IpcFixture):
+
+ def _get_writer(self, sink, schema):
+ return pa.ipc.new_file(sink, schema)
+
+ def _check_roundtrip(self, as_table=False):
+ batches = self.write_batches(as_table=as_table)
+ file_contents = pa.BufferReader(self.get_source())
+
+ reader = pa.ipc.open_file(file_contents)
+
+ assert reader.num_record_batches == len(batches)
+
+ for i, batch in enumerate(batches):
+ # it works. Must convert back to DataFrame
+ batch = reader.get_batch(i)
+ assert batches[i].equals(batch)
+ assert reader.schema.equals(batches[0].schema)
+
+ assert isinstance(reader.stats, pa.ipc.ReadStats)
+ assert isinstance(self.write_stats, pa.ipc.WriteStats)
+ assert tuple(reader.stats) == tuple(self.write_stats)
+
+
+class StreamFormatFixture(IpcFixture):
+
+ # ARROW-6474, for testing writing old IPC protocol with 4-byte prefix
+ use_legacy_ipc_format = False
+ # ARROW-9395, for testing writing old metadata version
+ options = None
+
+ def _get_writer(self, sink, schema):
+ return pa.ipc.new_stream(
+ sink,
+ schema,
+ use_legacy_format=self.use_legacy_ipc_format,
+ options=self.options,
+ )
+
+
+class MessageFixture(IpcFixture):
+
+ def _get_writer(self, sink, schema):
+ return pa.RecordBatchStreamWriter(sink, schema)
+
+
+@pytest.fixture
+def ipc_fixture():
+ return IpcFixture()
+
+
+@pytest.fixture
+def file_fixture():
+ return FileFormatFixture()
+
+
+@pytest.fixture
+def stream_fixture():
+ return StreamFormatFixture()
+
+
+def test_empty_file():
+ buf = b''
+ with pytest.raises(pa.ArrowInvalid):
+ pa.ipc.open_file(pa.BufferReader(buf))
+
+
+def test_file_simple_roundtrip(file_fixture):
+ file_fixture._check_roundtrip(as_table=False)
+
+
+def test_file_write_table(file_fixture):
+ file_fixture._check_roundtrip(as_table=True)
+
+
+@pytest.mark.parametrize("sink_factory", [
+ lambda: io.BytesIO(),
+ lambda: pa.BufferOutputStream()
+])
+def test_file_read_all(sink_factory):
+ fixture = FileFormatFixture(sink_factory)
+
+ batches = fixture.write_batches()
+ file_contents = pa.BufferReader(fixture.get_source())
+
+ reader = pa.ipc.open_file(file_contents)
+
+ result = reader.read_all()
+ expected = pa.Table.from_batches(batches)
+ assert result.equals(expected)
+
+
+def test_open_file_from_buffer(file_fixture):
+ # ARROW-2859; APIs accept the buffer protocol
+ file_fixture.write_batches()
+ source = file_fixture.get_source()
+
+ reader1 = pa.ipc.open_file(source)
+ reader2 = pa.ipc.open_file(pa.BufferReader(source))
+ reader3 = pa.RecordBatchFileReader(source)
+
+ result1 = reader1.read_all()
+ result2 = reader2.read_all()
+ result3 = reader3.read_all()
+
+ assert result1.equals(result2)
+ assert result1.equals(result3)
+
+ st1 = reader1.stats
+ assert st1.num_messages == 6
+ assert st1.num_record_batches == 5
+ assert reader2.stats == st1
+ assert reader3.stats == st1
+
+
+@pytest.mark.pandas
+def test_file_read_pandas(file_fixture):
+ frames = [batch.to_pandas() for batch in file_fixture.write_batches()]
+
+ file_contents = pa.BufferReader(file_fixture.get_source())
+ reader = pa.ipc.open_file(file_contents)
+ result = reader.read_pandas()
+
+ expected = pd.concat(frames).reset_index(drop=True)
+ assert_frame_equal(result, expected)
+
+
+def test_file_pathlib(file_fixture, tmpdir):
+ file_fixture.write_batches()
+ source = file_fixture.get_source()
+
+ path = tmpdir.join('file.arrow').strpath
+ with open(path, 'wb') as f:
+ f.write(source)
+
+ t1 = pa.ipc.open_file(pathlib.Path(path)).read_all()
+ t2 = pa.ipc.open_file(pa.OSFile(path)).read_all()
+
+ assert t1.equals(t2)
+
+
+def test_empty_stream():
+ buf = io.BytesIO(b'')
+ with pytest.raises(pa.ArrowInvalid):
+ pa.ipc.open_stream(buf)
+
+
+@pytest.mark.pandas
+def test_stream_categorical_roundtrip(stream_fixture):
+ df = pd.DataFrame({
+ 'one': np.random.randn(5),
+ 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
+ categories=['foo', 'bar'],
+ ordered=True)
+ })
+ batch = pa.RecordBatch.from_pandas(df)
+ with stream_fixture._get_writer(stream_fixture.sink, batch.schema) as wr:
+ wr.write_batch(batch)
+
+ table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
+ .read_all())
+ assert_frame_equal(table.to_pandas(), df)
+
+
+def test_open_stream_from_buffer(stream_fixture):
+ # ARROW-2859
+ stream_fixture.write_batches()
+ source = stream_fixture.get_source()
+
+ reader1 = pa.ipc.open_stream(source)
+ reader2 = pa.ipc.open_stream(pa.BufferReader(source))
+ reader3 = pa.RecordBatchStreamReader(source)
+
+ result1 = reader1.read_all()
+ result2 = reader2.read_all()
+ result3 = reader3.read_all()
+
+ assert result1.equals(result2)
+ assert result1.equals(result3)
+
+ st1 = reader1.stats
+ assert st1.num_messages == 6
+ assert st1.num_record_batches == 5
+ assert reader2.stats == st1
+ assert reader3.stats == st1
+
+ assert tuple(st1) == tuple(stream_fixture.write_stats)
+
+
+@pytest.mark.pandas
+def test_stream_write_dispatch(stream_fixture):
+ # ARROW-1616
+ df = pd.DataFrame({
+ 'one': np.random.randn(5),
+ 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
+ categories=['foo', 'bar'],
+ ordered=True)
+ })
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ batch = pa.RecordBatch.from_pandas(df, preserve_index=False)
+ with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr:
+ wr.write(table)
+ wr.write(batch)
+
+ table = (pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
+ .read_all())
+ assert_frame_equal(table.to_pandas(),
+ pd.concat([df, df], ignore_index=True))
+
+
+@pytest.mark.pandas
+def test_stream_write_table_batches(stream_fixture):
+ # ARROW-504
+ df = pd.DataFrame({
+ 'one': np.random.randn(20),
+ })
+
+ b1 = pa.RecordBatch.from_pandas(df[:10], preserve_index=False)
+ b2 = pa.RecordBatch.from_pandas(df, preserve_index=False)
+
+ table = pa.Table.from_batches([b1, b2, b1])
+
+ with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr:
+ wr.write_table(table, max_chunksize=15)
+
+ batches = list(pa.ipc.open_stream(stream_fixture.get_source()))
+
+ assert list(map(len, batches)) == [10, 15, 5, 10]
+ result_table = pa.Table.from_batches(batches)
+ assert_frame_equal(result_table.to_pandas(),
+ pd.concat([df[:10], df, df[:10]],
+ ignore_index=True))
+
+
+@pytest.mark.parametrize('use_legacy_ipc_format', [False, True])
+def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format):
+ stream_fixture.use_legacy_ipc_format = use_legacy_ipc_format
+ batches = stream_fixture.write_batches()
+ file_contents = pa.BufferReader(stream_fixture.get_source())
+ reader = pa.ipc.open_stream(file_contents)
+
+ assert reader.schema.equals(batches[0].schema)
+
+ total = 0
+ for i, next_batch in enumerate(reader):
+ assert next_batch.equals(batches[i])
+ total += 1
+
+ assert total == len(batches)
+
+ with pytest.raises(StopIteration):
+ reader.read_next_batch()
+
+
+@pytest.mark.zstd
+def test_compression_roundtrip():
+ sink = io.BytesIO()
+ values = np.random.randint(0, 10, 10000)
+ table = pa.Table.from_arrays([values], names=["values"])
+
+ options = pa.ipc.IpcWriteOptions(compression='zstd')
+ with pa.ipc.RecordBatchFileWriter(
+ sink, table.schema, options=options) as writer:
+ writer.write_table(table)
+ len1 = len(sink.getvalue())
+
+ sink2 = io.BytesIO()
+ codec = pa.Codec('zstd', compression_level=5)
+ options = pa.ipc.IpcWriteOptions(compression=codec)
+ with pa.ipc.RecordBatchFileWriter(
+ sink2, table.schema, options=options) as writer:
+ writer.write_table(table)
+ len2 = len(sink2.getvalue())
+
+ # In theory len2 should be less than len1 but for this test we just want
+ # to ensure compression_level is being correctly passed down to the C++
+ # layer so we don't really care if it makes it worse or better
+ assert len2 != len1
+
+ t1 = pa.ipc.open_file(sink).read_all()
+ t2 = pa.ipc.open_file(sink2).read_all()
+
+ assert t1 == t2
+
+
+def test_write_options():
+ options = pa.ipc.IpcWriteOptions()
+ assert options.allow_64bit is False
+ assert options.use_legacy_format is False
+ assert options.metadata_version == pa.ipc.MetadataVersion.V5
+
+ options.allow_64bit = True
+ assert options.allow_64bit is True
+
+ options.use_legacy_format = True
+ assert options.use_legacy_format is True
+
+ options.metadata_version = pa.ipc.MetadataVersion.V4
+ assert options.metadata_version == pa.ipc.MetadataVersion.V4
+ for value in ('V5', 42):
+ with pytest.raises((TypeError, ValueError)):
+ options.metadata_version = value
+
+ assert options.compression is None
+ for value in ['lz4', 'zstd']:
+ if pa.Codec.is_available(value):
+ options.compression = value
+ assert options.compression == value
+ options.compression = value.upper()
+ assert options.compression == value
+ options.compression = None
+ assert options.compression is None
+
+ with pytest.raises(TypeError):
+ options.compression = 0
+
+ assert options.use_threads is True
+ options.use_threads = False
+ assert options.use_threads is False
+
+ if pa.Codec.is_available('lz4'):
+ options = pa.ipc.IpcWriteOptions(
+ metadata_version=pa.ipc.MetadataVersion.V4,
+ allow_64bit=True,
+ use_legacy_format=True,
+ compression='lz4',
+ use_threads=False)
+ assert options.metadata_version == pa.ipc.MetadataVersion.V4
+ assert options.allow_64bit is True
+ assert options.use_legacy_format is True
+ assert options.compression == 'lz4'
+ assert options.use_threads is False
+
+
+def test_write_options_legacy_exclusive(stream_fixture):
+ with pytest.raises(
+ ValueError,
+ match="provide at most one of options and use_legacy_format"):
+ stream_fixture.use_legacy_ipc_format = True
+ stream_fixture.options = pa.ipc.IpcWriteOptions()
+ stream_fixture.write_batches()
+
+
+@pytest.mark.parametrize('options', [
+ pa.ipc.IpcWriteOptions(),
+ pa.ipc.IpcWriteOptions(allow_64bit=True),
+ pa.ipc.IpcWriteOptions(use_legacy_format=True),
+ pa.ipc.IpcWriteOptions(metadata_version=pa.ipc.MetadataVersion.V4),
+ pa.ipc.IpcWriteOptions(use_legacy_format=True,
+ metadata_version=pa.ipc.MetadataVersion.V4),
+])
+def test_stream_options_roundtrip(stream_fixture, options):
+ stream_fixture.use_legacy_ipc_format = None
+ stream_fixture.options = options
+ batches = stream_fixture.write_batches()
+ file_contents = pa.BufferReader(stream_fixture.get_source())
+
+ message = pa.ipc.read_message(stream_fixture.get_source())
+ assert message.metadata_version == options.metadata_version
+
+ reader = pa.ipc.open_stream(file_contents)
+
+ assert reader.schema.equals(batches[0].schema)
+
+ total = 0
+ for i, next_batch in enumerate(reader):
+ assert next_batch.equals(batches[i])
+ total += 1
+
+ assert total == len(batches)
+
+ with pytest.raises(StopIteration):
+ reader.read_next_batch()
+
+
+def test_dictionary_delta(stream_fixture):
+ ty = pa.dictionary(pa.int8(), pa.utf8())
+ data = [["foo", "foo", None],
+ ["foo", "bar", "foo"], # potential delta
+ ["foo", "bar"],
+ ["foo", None, "bar", "quux"], # potential delta
+ ["bar", "quux"], # replacement
+ ]
+ batches = [
+ pa.RecordBatch.from_arrays([pa.array(v, type=ty)], names=['dicts'])
+ for v in data]
+ schema = batches[0].schema
+
+ def write_batches():
+ with stream_fixture._get_writer(pa.MockOutputStream(),
+ schema) as writer:
+ for batch in batches:
+ writer.write_batch(batch)
+ return writer.stats
+
+ st = write_batches()
+ assert st.num_record_batches == 5
+ assert st.num_dictionary_batches == 4
+ assert st.num_replaced_dictionaries == 3
+ assert st.num_dictionary_deltas == 0
+
+ stream_fixture.use_legacy_ipc_format = None
+ stream_fixture.options = pa.ipc.IpcWriteOptions(
+ emit_dictionary_deltas=True)
+ st = write_batches()
+ assert st.num_record_batches == 5
+ assert st.num_dictionary_batches == 4
+ assert st.num_replaced_dictionaries == 1
+ assert st.num_dictionary_deltas == 2
+
+
+def test_envvar_set_legacy_ipc_format():
+ schema = pa.schema([pa.field('foo', pa.int32())])
+
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+
+ with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V5
+
+ with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert not writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+
+ with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
+ with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
+ writer = pa.ipc.new_stream(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+ writer = pa.ipc.new_file(pa.BufferOutputStream(), schema)
+ assert writer._use_legacy_format
+ assert writer._metadata_version == pa.ipc.MetadataVersion.V4
+
+
+def test_stream_read_all(stream_fixture):
+ batches = stream_fixture.write_batches()
+ file_contents = pa.BufferReader(stream_fixture.get_source())
+ reader = pa.ipc.open_stream(file_contents)
+
+ result = reader.read_all()
+ expected = pa.Table.from_batches(batches)
+ assert result.equals(expected)
+
+
+@pytest.mark.pandas
+def test_stream_read_pandas(stream_fixture):
+ frames = [batch.to_pandas() for batch in stream_fixture.write_batches()]
+ file_contents = stream_fixture.get_source()
+ reader = pa.ipc.open_stream(file_contents)
+ result = reader.read_pandas()
+
+ expected = pd.concat(frames).reset_index(drop=True)
+ assert_frame_equal(result, expected)
+
+
+@pytest.fixture
+def example_messages(stream_fixture):
+ batches = stream_fixture.write_batches()
+ file_contents = stream_fixture.get_source()
+ buf_reader = pa.BufferReader(file_contents)
+ reader = pa.MessageReader.open_stream(buf_reader)
+ return batches, list(reader)
+
+
+def test_message_ctors_no_segfault():
+ with pytest.raises(TypeError):
+ repr(pa.Message())
+
+ with pytest.raises(TypeError):
+ repr(pa.MessageReader())
+
+
+def test_message_reader(example_messages):
+ _, messages = example_messages
+
+ assert len(messages) == 6
+ assert messages[0].type == 'schema'
+ assert isinstance(messages[0].metadata, pa.Buffer)
+ assert isinstance(messages[0].body, pa.Buffer)
+ assert messages[0].metadata_version == pa.MetadataVersion.V5
+
+ for msg in messages[1:]:
+ assert msg.type == 'record batch'
+ assert isinstance(msg.metadata, pa.Buffer)
+ assert isinstance(msg.body, pa.Buffer)
+ assert msg.metadata_version == pa.MetadataVersion.V5
+
+
+def test_message_serialize_read_message(example_messages):
+ _, messages = example_messages
+
+ msg = messages[0]
+ buf = msg.serialize()
+ reader = pa.BufferReader(buf.to_pybytes() * 2)
+
+ restored = pa.ipc.read_message(buf)
+ restored2 = pa.ipc.read_message(reader)
+ restored3 = pa.ipc.read_message(buf.to_pybytes())
+ restored4 = pa.ipc.read_message(reader)
+
+ assert msg.equals(restored)
+ assert msg.equals(restored2)
+ assert msg.equals(restored3)
+ assert msg.equals(restored4)
+
+ with pytest.raises(pa.ArrowInvalid, match="Corrupted message"):
+ pa.ipc.read_message(pa.BufferReader(b'ab'))
+
+ with pytest.raises(EOFError):
+ pa.ipc.read_message(reader)
+
+
+@pytest.mark.gzip
+def test_message_read_from_compressed(example_messages):
+ # Part of ARROW-5910
+ _, messages = example_messages
+ for message in messages:
+ raw_out = pa.BufferOutputStream()
+ with pa.output_stream(raw_out, compression='gzip') as compressed_out:
+ message.serialize_to(compressed_out)
+
+ compressed_buf = raw_out.getvalue()
+
+ result = pa.ipc.read_message(pa.input_stream(compressed_buf,
+ compression='gzip'))
+ assert result.equals(message)
+
+
+def test_message_read_record_batch(example_messages):
+ batches, messages = example_messages
+
+ for batch, message in zip(batches, messages[1:]):
+ read_batch = pa.ipc.read_record_batch(message, batch.schema)
+ assert read_batch.equals(batch)
+
+
+def test_read_record_batch_on_stream_error_message():
+ # ARROW-5374
+ batch = pa.record_batch([pa.array([b"foo"], type=pa.utf8())],
+ names=['strs'])
+ stream = pa.BufferOutputStream()
+ with pa.ipc.new_stream(stream, batch.schema) as writer:
+ writer.write_batch(batch)
+ buf = stream.getvalue()
+ with pytest.raises(IOError,
+ match="type record batch but got schema"):
+ pa.ipc.read_record_batch(buf, batch.schema)
+
+
+# ----------------------------------------------------------------------
+# Socket streaming testa
+
+
+class StreamReaderServer(threading.Thread):
+
+ def init(self, do_read_all):
+ self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._sock.bind(('127.0.0.1', 0))
+ self._sock.listen(1)
+ host, port = self._sock.getsockname()
+ self._do_read_all = do_read_all
+ self._schema = None
+ self._batches = []
+ self._table = None
+ return port
+
+ def run(self):
+ connection, client_address = self._sock.accept()
+ try:
+ source = connection.makefile(mode='rb')
+ reader = pa.ipc.open_stream(source)
+ self._schema = reader.schema
+ if self._do_read_all:
+ self._table = reader.read_all()
+ else:
+ for i, batch in enumerate(reader):
+ self._batches.append(batch)
+ finally:
+ connection.close()
+
+ def get_result(self):
+ return(self._schema, self._table if self._do_read_all
+ else self._batches)
+
+
+class SocketStreamFixture(IpcFixture):
+
+ def __init__(self):
+ # XXX(wesm): test will decide when to start socket server. This should
+ # probably be refactored
+ pass
+
+ def start_server(self, do_read_all):
+ self._server = StreamReaderServer()
+ port = self._server.init(do_read_all)
+ self._server.start()
+ self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._sock.connect(('127.0.0.1', port))
+ self.sink = self.get_sink()
+
+ def stop_and_get_result(self):
+ import struct
+ self.sink.write(struct.pack('Q', 0))
+ self.sink.flush()
+ self._sock.close()
+ self._server.join()
+ return self._server.get_result()
+
+ def get_sink(self):
+ return self._sock.makefile(mode='wb')
+
+ def _get_writer(self, sink, schema):
+ return pa.RecordBatchStreamWriter(sink, schema)
+
+
+@pytest.fixture
+def socket_fixture():
+ return SocketStreamFixture()
+
+
+def test_socket_simple_roundtrip(socket_fixture):
+ socket_fixture.start_server(do_read_all=False)
+ writer_batches = socket_fixture.write_batches()
+ reader_schema, reader_batches = socket_fixture.stop_and_get_result()
+
+ assert reader_schema.equals(writer_batches[0].schema)
+ assert len(reader_batches) == len(writer_batches)
+ for i, batch in enumerate(writer_batches):
+ assert reader_batches[i].equals(batch)
+
+
+def test_socket_read_all(socket_fixture):
+ socket_fixture.start_server(do_read_all=True)
+ writer_batches = socket_fixture.write_batches()
+ _, result = socket_fixture.stop_and_get_result()
+
+ expected = pa.Table.from_batches(writer_batches)
+ assert result.equals(expected)
+
+
+# ----------------------------------------------------------------------
+# Miscellaneous IPC tests
+
+@pytest.mark.pandas
+def test_ipc_file_stream_has_eos():
+ # ARROW-5395
+ df = pd.DataFrame({'foo': [1.5]})
+ batch = pa.RecordBatch.from_pandas(df)
+ sink = pa.BufferOutputStream()
+ write_file(batch, sink)
+ buffer = sink.getvalue()
+
+ # skip the file magic
+ reader = pa.ipc.open_stream(buffer[8:])
+
+ # will fail if encounters footer data instead of eos
+ rdf = reader.read_pandas()
+
+ assert_frame_equal(df, rdf)
+
+
+@pytest.mark.pandas
+def test_ipc_zero_copy_numpy():
+ df = pd.DataFrame({'foo': [1.5]})
+
+ batch = pa.RecordBatch.from_pandas(df)
+ sink = pa.BufferOutputStream()
+ write_file(batch, sink)
+ buffer = sink.getvalue()
+ reader = pa.BufferReader(buffer)
+
+ batches = read_file(reader)
+
+ data = batches[0].to_pandas()
+ rdf = pd.DataFrame(data)
+ assert_frame_equal(df, rdf)
+
+
+def test_ipc_stream_no_batches():
+ # ARROW-2307
+ table = pa.Table.from_arrays([pa.array([1, 2, 3, 4]),
+ pa.array(['foo', 'bar', 'baz', 'qux'])],
+ names=['a', 'b'])
+
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, table.schema):
+ pass
+
+ source = sink.getvalue()
+ with pa.ipc.open_stream(source) as reader:
+ result = reader.read_all()
+
+ assert result.schema.equals(table.schema)
+ assert len(result) == 0
+
+
+@pytest.mark.pandas
+def test_get_record_batch_size():
+ N = 10
+ itemsize = 8
+ df = pd.DataFrame({'foo': np.random.randn(N)})
+
+ batch = pa.RecordBatch.from_pandas(df)
+ assert pa.ipc.get_record_batch_size(batch) > (N * itemsize)
+
+
+@pytest.mark.pandas
+def _check_serialize_pandas_round_trip(df, use_threads=False):
+ buf = pa.serialize_pandas(df, nthreads=2 if use_threads else 1)
+ result = pa.deserialize_pandas(buf, use_threads=use_threads)
+ assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+def test_pandas_serialize_round_trip():
+ index = pd.Index([1, 2, 3], name='my_index')
+ columns = ['foo', 'bar']
+ df = pd.DataFrame(
+ {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')},
+ index=index, columns=columns
+ )
+ _check_serialize_pandas_round_trip(df)
+
+
+@pytest.mark.pandas
+def test_pandas_serialize_round_trip_nthreads():
+ index = pd.Index([1, 2, 3], name='my_index')
+ columns = ['foo', 'bar']
+ df = pd.DataFrame(
+ {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')},
+ index=index, columns=columns
+ )
+ _check_serialize_pandas_round_trip(df, use_threads=True)
+
+
+@pytest.mark.pandas
+def test_pandas_serialize_round_trip_multi_index():
+ index1 = pd.Index([1, 2, 3], name='level_1')
+ index2 = pd.Index(list('def'), name=None)
+ index = pd.MultiIndex.from_arrays([index1, index2])
+
+ columns = ['foo', 'bar']
+ df = pd.DataFrame(
+ {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')},
+ index=index,
+ columns=columns,
+ )
+ _check_serialize_pandas_round_trip(df)
+
+
+@pytest.mark.pandas
+def test_serialize_pandas_empty_dataframe():
+ df = pd.DataFrame()
+ _check_serialize_pandas_round_trip(df)
+
+
+@pytest.mark.pandas
+def test_pandas_serialize_round_trip_not_string_columns():
+ df = pd.DataFrame(list(zip([1.5, 1.6, 1.7], 'abc')))
+ buf = pa.serialize_pandas(df)
+ result = pa.deserialize_pandas(buf)
+ assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+def test_serialize_pandas_no_preserve_index():
+ df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3])
+ expected = pd.DataFrame({'a': [1, 2, 3]})
+
+ buf = pa.serialize_pandas(df, preserve_index=False)
+ result = pa.deserialize_pandas(buf)
+ assert_frame_equal(result, expected)
+
+ buf = pa.serialize_pandas(df, preserve_index=True)
+ result = pa.deserialize_pandas(buf)
+ assert_frame_equal(result, df)
+
+
+@pytest.mark.pandas
+@pytest.mark.filterwarnings("ignore:'pyarrow:FutureWarning")
+def test_serialize_with_pandas_objects():
+ df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3])
+ s = pd.Series([1, 2, 3, 4])
+
+ data = {
+ 'a_series': df['a'],
+ 'a_frame': df,
+ 's_series': s
+ }
+
+ serialized = pa.serialize(data).to_buffer()
+ deserialized = pa.deserialize(serialized)
+ assert_frame_equal(deserialized['a_frame'], df)
+
+ assert_series_equal(deserialized['a_series'], df['a'])
+ assert deserialized['a_series'].name == 'a'
+
+ assert_series_equal(deserialized['s_series'], s)
+ assert deserialized['s_series'].name is None
+
+
+@pytest.mark.pandas
+def test_schema_batch_serialize_methods():
+ nrows = 5
+ df = pd.DataFrame({
+ 'one': np.random.randn(nrows),
+ 'two': ['foo', np.nan, 'bar', 'bazbaz', 'qux']})
+ batch = pa.RecordBatch.from_pandas(df)
+
+ s_schema = batch.schema.serialize()
+ s_batch = batch.serialize()
+
+ recons_schema = pa.ipc.read_schema(s_schema)
+ recons_batch = pa.ipc.read_record_batch(s_batch, recons_schema)
+ assert recons_batch.equals(batch)
+
+
+def test_schema_serialization_with_metadata():
+ field_metadata = {b'foo': b'bar', b'kind': b'field'}
+ schema_metadata = {b'foo': b'bar', b'kind': b'schema'}
+
+ f0 = pa.field('a', pa.int8())
+ f1 = pa.field('b', pa.string(), metadata=field_metadata)
+
+ schema = pa.schema([f0, f1], metadata=schema_metadata)
+
+ s_schema = schema.serialize()
+ recons_schema = pa.ipc.read_schema(s_schema)
+
+ assert recons_schema.equals(schema)
+ assert recons_schema.metadata == schema_metadata
+ assert recons_schema[0].metadata is None
+ assert recons_schema[1].metadata == field_metadata
+
+
+def test_deprecated_pyarrow_ns_apis():
+ table = pa.table([pa.array([1, 2, 3, 4])], names=['a'])
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, table.schema) as writer:
+ writer.write(table)
+
+ with pytest.warns(FutureWarning,
+ match="please use pyarrow.ipc.open_stream"):
+ pa.open_stream(sink.getvalue())
+
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_file(sink, table.schema) as writer:
+ writer.write(table)
+ with pytest.warns(FutureWarning, match="please use pyarrow.ipc.open_file"):
+ pa.open_file(sink.getvalue())
+
+
+def write_file(batch, sink):
+ with pa.ipc.new_file(sink, batch.schema) as writer:
+ writer.write_batch(batch)
+
+
+def read_file(source):
+ with pa.ipc.open_file(source) as reader:
+ return [reader.get_batch(i) for i in range(reader.num_record_batches)]
+
+
+def test_write_empty_ipc_file():
+ # ARROW-3894: IPC file was not being properly initialized when no record
+ # batches are being written
+ schema = pa.schema([('field', pa.int64())])
+
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_file(sink, schema):
+ pass
+
+ buf = sink.getvalue()
+ with pa.RecordBatchFileReader(pa.BufferReader(buf)) as reader:
+ table = reader.read_all()
+ assert len(table) == 0
+ assert table.schema.equals(schema)
+
+
+def test_py_record_batch_reader():
+ def make_schema():
+ return pa.schema([('field', pa.int64())])
+
+ def make_batches():
+ schema = make_schema()
+ batch1 = pa.record_batch([[1, 2, 3]], schema=schema)
+ batch2 = pa.record_batch([[4, 5]], schema=schema)
+ return [batch1, batch2]
+
+ # With iterable
+ batches = UserList(make_batches()) # weakrefable
+ wr = weakref.ref(batches)
+
+ with pa.ipc.RecordBatchReader.from_batches(make_schema(),
+ batches) as reader:
+ batches = None
+ assert wr() is not None
+ assert list(reader) == make_batches()
+ assert wr() is None
+
+ # With iterator
+ batches = iter(UserList(make_batches())) # weakrefable
+ wr = weakref.ref(batches)
+
+ with pa.ipc.RecordBatchReader.from_batches(make_schema(),
+ batches) as reader:
+ batches = None
+ assert wr() is not None
+ assert list(reader) == make_batches()
+ assert wr() is None