From e6918187568dbd01842d8d1d2c808ce16a894239 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 21 Apr 2024 13:54:28 +0200 Subject: Adding upstream version 18.2.2. Signed-off-by: Daniel Baumann --- src/arrow/python/pyarrow/tests/test_jvm.py | 433 +++++++++++++++++++++++++++++ 1 file changed, 433 insertions(+) create mode 100644 src/arrow/python/pyarrow/tests/test_jvm.py (limited to 'src/arrow/python/pyarrow/tests/test_jvm.py') diff --git a/src/arrow/python/pyarrow/tests/test_jvm.py b/src/arrow/python/pyarrow/tests/test_jvm.py new file mode 100644 index 000000000..c5996f921 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_jvm.py @@ -0,0 +1,433 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
import json
import os
import sys
import xml.etree.ElementTree as ET

import pyarrow as pa
import pyarrow.jvm as pa_jvm
import pytest


jpype = pytest.importorskip("jpype")


def _default_tools_jar_path():
    """Return the arrow-tools jar path derived from the Arrow source tree.

    The source tree is taken from the ``ARROW_SOURCE_DIR`` environment
    variable, falling back to three directories above this file. The jar
    version is read from the ``version`` element of ``java/pom.xml``.
    """
    try:
        arrow_dir = os.environ["ARROW_SOURCE_DIR"]
    except KeyError:
        arrow_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..')
    pom_path = os.path.join(arrow_dir, 'java', 'pom.xml')
    tree = ET.parse(pom_path)
    version = tree.getroot().find(
        'POM:version',
        namespaces={
            'POM': 'http://maven.apache.org/POM/4.0.0'
        }).text
    return os.path.join(
        arrow_dir, 'java', 'tools', 'target',
        'arrow-tools-{}-jar-with-dependencies.jar'.format(version))


@pytest.fixture(scope="session")
def root_allocator():
    """Start the JVM once per session and return an Arrow ``RootAllocator``.

    The class path is the value of ``ARROW_TOOLS_JAR`` when that variable is
    set. Otherwise this test requires Arrow Java to be built in the same
    source tree, and the jar location is derived from ``java/pom.xml``.
    """
    # Honour the explicit override first so that a missing or unreadable
    # pom.xml does not break a run that never needed it.
    jar_path = os.environ.get("ARROW_TOOLS_JAR")
    if jar_path is None:
        jar_path = _default_tools_jar_path()
    # convertStrings=False will be the default behaviour in jpype 0.8+
    jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path,
                   convertStrings=False)
    return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize)


def test_jvm_buffer(root_allocator):
    """A Java buffer converts to a PyArrow buffer that holds a reference."""
    # Create a Java buffer
    jvm_buffer = root_allocator.buffer(8)
    for i in range(8):
        jvm_buffer.setByte(i, 8 - i)

    orig_refcnt = jvm_buffer.refCnt()

    # Convert to Python
    buf = pa_jvm.jvm_buffer(jvm_buffer)

    # Check its content
    assert buf.to_pybytes() == b'\x08\x07\x06\x05\x04\x03\x02\x01'

    # Check Java buffer lifetime is tied to PyArrow buffer lifetime
    assert jvm_buffer.refCnt() == orig_refcnt + 1
    del buf
    assert jvm_buffer.refCnt() == orig_refcnt


def test_jvm_buffer_released(root_allocator):
    """Converting an already released Java buffer raises on the Java side."""
    import jpype.imports  # noqa
    from java.lang import IllegalArgumentException

    jvm_buffer = root_allocator.buffer(8)
    jvm_buffer.release()

    with pytest.raises(IllegalArgumentException):
        pa_jvm.jvm_buffer(jvm_buffer)


def _jvm_field(jvm_spec):
    """Deserialize the JSON string ``jvm_spec`` into a Java ``Field``."""
    om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
    pojo_Field = jpype.JClass('org.apache.arrow.vector.types.pojo.Field')
    return om.readValue(jvm_spec, pojo_Field)


def _jvm_schema(jvm_spec, metadata=None):
    """Build a single-field Java ``Schema`` from a JSON field specification.

    When ``metadata`` is a non-empty mapping, its items are copied into a
    ``java.util.HashMap`` that is passed as the schema's custom metadata.
    """
    field = _jvm_field(jvm_spec)
    schema_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Schema')
    fields = jpype.JClass('java.util.ArrayList')()
    fields.add(field)
    if metadata:
        dct = jpype.JClass('java.util.HashMap')()
        for k, v in metadata.items():
            dct.put(k, v)
        return schema_cls(fields, dct)
    else:
        return schema_cls(fields)


# In the following, we use the JSON serialization of the Field objects in Java.
# This means we do not rely on the exact mechanics of constructing them in
# Java code, and it lets us define them as test parameters without having to
# invoke the JVM.
#
# The specifications were created using:
#
#   om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
#   field = …  # Code to instantiate the field
#   jvm_spec = om.writeValueAsString(field)
@pytest.mark.parametrize('pa_type,jvm_spec', [
    (pa.null(), '{"name":"null"}'),
    (pa.bool_(), '{"name":"bool"}'),
    (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'),
    (pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'),
    (pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'),
    (pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'),
    (pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'),
    (pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'),
    (pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'),
    (pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'),
    (pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'),
    (pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'),
    (pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'),
    (pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'),
    (pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'),
    (pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'),
    (pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'),
    (pa.timestamp('s'), '{"name":"timestamp","unit":"SECOND",'
     '"timezone":null}'),
    (pa.timestamp('ms'), '{"name":"timestamp","unit":"MILLISECOND",'
     '"timezone":null}'),
    (pa.timestamp('us'), '{"name":"timestamp","unit":"MICROSECOND",'
     '"timezone":null}'),
    (pa.timestamp('ns'), '{"name":"timestamp","unit":"NANOSECOND",'
     '"timezone":null}'),
    (pa.timestamp('ns', tz='UTC'), '{"name":"timestamp","unit":"NANOSECOND"'
     ',"timezone":"UTC"}'),
    (pa.timestamp('ns', tz='Europe/Paris'), '{"name":"timestamp",'
     '"unit":"NANOSECOND","timezone":"Europe/Paris"}'),
    (pa.date32(), '{"name":"date","unit":"DAY"}'),
    (pa.date64(), '{"name":"date","unit":"MILLISECOND"}'),
    (pa.decimal128(19, 4), '{"name":"decimal","precision":19,"scale":4}'),
    (pa.string(), '{"name":"utf8"}'),
    (pa.binary(), '{"name":"binary"}'),
    (pa.binary(10), '{"name":"fixedsizebinary","byteWidth":10}'),
    # TODO(ARROW-2609): complex types that have children
    # pa.list_(pa.int32()),
    # pa.struct([pa.field('a', pa.int32()),
    #            pa.field('b', pa.int8()),
    #            pa.field('c', pa.string())]),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
    # TODO: DictionaryType requires a vector in the type
    # pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])),
])
@pytest.mark.parametrize('nullable', [True, False])
def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable):
    """Java fields and schemas convert to the equivalent PyArrow objects."""
    # The (null type, non-nullable) combination is not exercised.
    if pa_type == pa.null() and not nullable:
        return
    spec = {
        'name': 'field_name',
        'nullable': nullable,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))
    result = pa_jvm.field(jvm_field)
    expected_field = pa.field('field_name', pa_type, nullable=nullable)
    assert result == expected_field

    jvm_schema = _jvm_schema(json.dumps(spec))
    result = pa_jvm.schema(jvm_schema)
    assert result == pa.schema([expected_field])

    # Schema with custom metadata
    jvm_schema = _jvm_schema(json.dumps(spec), {'meta': 'data'})
    result = pa_jvm.schema(jvm_schema)
    assert result == pa.schema([expected_field], {'meta': 'data'})

    # Schema with custom field metadata
    spec['metadata'] = [{'key': 'field meta', 'value': 'field data'}]
    jvm_schema = _jvm_schema(json.dumps(spec))
    result = pa_jvm.schema(jvm_schema)
    expected_field = expected_field.with_metadata(
        {'field meta': 'field data'})
    assert result == pa.schema([expected_field])


# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type', [
    (pa.bool_(), [True, False, True, True], 'BitVector'),
    (pa.uint8(), list(range(128)), 'UInt1Vector'),
    (pa.uint16(), list(range(128)), 'UInt2Vector'),
    (pa.int32(), list(range(128)), 'IntVector'),
    (pa.int64(), list(range(128)), 'BigIntVector'),
    (pa.float32(), list(range(128)), 'Float4Vector'),
    (pa.float64(), list(range(128)), 'Float8Vector'),
    (pa.timestamp('s'), list(range(128)), 'TimeStampSecVector'),
    (pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector'),
    (pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector'),
    (pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector'),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    #  * pa.time32('s')
    #  * pa.time32('ms')
    #  * pa.time64('us')
    #  * pa.time64('ns')
    (pa.date32(), list(range(128)), 'DateDayVector'),
    (pa.date64(), list(range(128)), 'DateMilliVector'),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_array(root_allocator, pa_type, py_data, jvm_type):
    """A populated Java vector converts to an equal PyArrow array."""
    # Create vector
    cls = "org.apache.arrow.vector.{}".format(jvm_type)
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew(len(py_data))
    for i, val in enumerate(py_data):
        # char and int are ambiguous overloads for these two setSafe calls
        if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
            val = jpype.JInt(val)
        jvm_vector.setSafe(i, val)
    jvm_vector.setValueCount(len(py_data))

    py_array = pa.array(py_data, type=pa_type)
    jvm_array = pa_jvm.array(jvm_vector)

    assert py_array.equals(jvm_array)


def test_jvm_array_empty(root_allocator):
    """An allocated but unfilled ``IntVector`` converts to an empty array."""
    cls = "org.apache.arrow.vector.{}".format('IntVector')
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew()
    jvm_array = pa_jvm.array(jvm_vector)
    assert len(jvm_array) == 0
    assert jvm_array.type == pa.int32()


# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [
    # TODO: null
    (pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'),
    (
        pa.uint8(),
        list(range(128)),
        'UInt1Vector',
        '{"name":"int","bitWidth":8,"isSigned":false}'
    ),
    (
        pa.uint16(),
        list(range(128)),
        'UInt2Vector',
        '{"name":"int","bitWidth":16,"isSigned":false}'
    ),
    (
        pa.uint32(),
        list(range(128)),
        'UInt4Vector',
        '{"name":"int","bitWidth":32,"isSigned":false}'
    ),
    (
        pa.uint64(),
        list(range(128)),
        'UInt8Vector',
        '{"name":"int","bitWidth":64,"isSigned":false}'
    ),
    (
        pa.int8(),
        list(range(128)),
        'TinyIntVector',
        '{"name":"int","bitWidth":8,"isSigned":true}'
    ),
    (
        pa.int16(),
        list(range(128)),
        'SmallIntVector',
        '{"name":"int","bitWidth":16,"isSigned":true}'
    ),
    (
        pa.int32(),
        list(range(128)),
        'IntVector',
        '{"name":"int","bitWidth":32,"isSigned":true}'
    ),
    (
        pa.int64(),
        list(range(128)),
        'BigIntVector',
        '{"name":"int","bitWidth":64,"isSigned":true}'
    ),
    # TODO: float16
    (
        pa.float32(),
        list(range(128)),
        'Float4Vector',
        '{"name":"floatingpoint","precision":"SINGLE"}'
    ),
    (
        pa.float64(),
        list(range(128)),
        'Float8Vector',
        '{"name":"floatingpoint","precision":"DOUBLE"}'
    ),
    (
        pa.timestamp('s'),
        list(range(128)),
        'TimeStampSecVector',
        '{"name":"timestamp","unit":"SECOND","timezone":null}'
    ),
    (
        pa.timestamp('ms'),
        list(range(128)),
        'TimeStampMilliVector',
        '{"name":"timestamp","unit":"MILLISECOND","timezone":null}'
    ),
    (
        pa.timestamp('us'),
        list(range(128)),
        'TimeStampMicroVector',
        '{"name":"timestamp","unit":"MICROSECOND","timezone":null}'
    ),
    (
        pa.timestamp('ns'),
        list(range(128)),
        'TimeStampNanoVector',
        '{"name":"timestamp","unit":"NANOSECOND","timezone":null}'
    ),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    #  * pa.time32('s')
    #  * pa.time32('ms')
    #  * pa.time64('us')
    #  * pa.time64('ns')
    (
        pa.date32(),
        list(range(128)),
        'DateDayVector',
        '{"name":"date","unit":"DAY"}'
    ),
    (
        pa.date64(),
        list(range(128)),
        'DateMilliVector',
        '{"name":"date","unit":"MILLISECOND"}'
    ),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_record_batch(root_allocator, pa_type, py_data, jvm_type,
                          jvm_spec):
    """A Java ``VectorSchemaRoot`` converts to an equal PyArrow RecordBatch."""
    # Create vector
    cls = "org.apache.arrow.vector.{}".format(jvm_type)
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew(len(py_data))
    for i, val in enumerate(py_data):
        # char and int are ambiguous overloads for these two setSafe calls
        if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
            val = jpype.JInt(val)
        jvm_vector.setSafe(i, val)
    jvm_vector.setValueCount(len(py_data))

    # Create field
    spec = {
        'name': 'field_name',
        'nullable': False,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))

    # Create VectorSchemaRoot
    jvm_fields = jpype.JClass('java.util.ArrayList')()
    jvm_fields.add(jvm_field)
    jvm_vectors = jpype.JClass('java.util.ArrayList')()
    jvm_vectors.add(jvm_vector)
    vsr_cls = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')
    jvm_vsr = vsr_cls(jvm_fields, jvm_vectors, len(py_data))

    py_record_batch = pa.RecordBatch.from_arrays(
        [pa.array(py_data, type=pa_type)],
        ['col']
    )
    jvm_record_batch = pa_jvm.record_batch(jvm_vsr)

    assert py_record_batch.equals(jvm_record_batch)


def _string_to_varchar_holder(ra, string):
    """Wrap a Python string (or ``None``) in a ``NullableVarCharHolder``.

    For ``None`` only ``isSet`` is cleared. Otherwise the UTF-8 bytes of
    ``string`` are copied into a buffer allocated from ``ra`` and the holder's
    ``start``/``end`` are set to span those bytes.
    """
    nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
    holder = jpype.JClass(nvch_cls)()
    if string is None:
        holder.isSet = 0
    else:
        holder.isSet = 1
        # Encode the caller's value; the previous code encoded the literal
        # "string" regardless of the argument.
        value = jpype.JClass("java.lang.String")(string)
        std_charsets = jpype.JClass("java.nio.charset.StandardCharsets")
        bytes_ = value.getBytes(std_charsets.UTF_8)
        holder.buffer = ra.buffer(len(bytes_))
        holder.buffer.setBytes(0, bytes_, 0, len(bytes_))
        holder.start = 0
        holder.end = len(bytes_)
    return holder


# TODO(ARROW-2607)
@pytest.mark.xfail(reason="from_buffers is only supported for "
                          "primitive arrays yet")
def test_jvm_string_array(root_allocator):
    """A Java ``VarCharVector`` should convert to an equal string array.

    Marked as an expected failure until the conversion supports
    non-primitive arrays.
    """
    data = ["string", None, "töst"]
    cls = "org.apache.arrow.vector.VarCharVector"
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew()

    for i, string in enumerate(data):
        # Pass the element itself (including None), not the literal "string",
        # so the JVM vector actually mirrors ``data``.
        holder = _string_to_varchar_holder(root_allocator, string)
        jvm_vector.setSafe(i, holder)
    jvm_vector.setValueCount(len(data))

    py_array = pa.array(data, type=pa.string())
    jvm_array = pa_jvm.array(jvm_vector)

    assert py_array.equals(jvm_array)

# Patch trailer from the original cgit export: -- cgit v1.2.3