# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import sys
import xml.etree.ElementTree as ET

import pyarrow as pa
import pyarrow.jvm as pa_jvm
import pytest

jpype = pytest.importorskip("jpype")


@pytest.fixture(scope="session")
def root_allocator():
    """Start a JVM with the Arrow tools jar on the classpath and return a
    fresh ``org.apache.arrow.memory.RootAllocator``.

    The Arrow Java version is read from ``java/pom.xml`` so the matching
    ``arrow-tools-*-jar-with-dependencies.jar`` can be located.  Both the
    source tree root and the jar path can be overridden through the
    ``ARROW_SOURCE_DIR`` and ``ARROW_TOOLS_JAR`` environment variables.
    """
    # This test requires Arrow Java to be built in the same source tree;
    # default to the repository root relative to this file.
    arrow_dir = os.environ.get(
        "ARROW_SOURCE_DIR",
        os.path.join(os.path.dirname(__file__), '..', '..', '..'))
    pom_path = os.path.join(arrow_dir, 'java', 'pom.xml')
    tree = ET.parse(pom_path)
    version = tree.getroot().find(
        'POM:version',
        namespaces={
            'POM': 'http://maven.apache.org/POM/4.0.0'
        }).text
    jar_path = os.path.join(
        arrow_dir, 'java', 'tools', 'target',
        'arrow-tools-{}-jar-with-dependencies.jar'.format(version))
    jar_path = os.getenv("ARROW_TOOLS_JAR", jar_path)
    kwargs = {}
    # This will be the default behaviour in jpype 0.8+
    kwargs['convertStrings'] = False
    jpype.startJVM(jpype.getDefaultJVMPath(),
                   "-Djava.class.path=" + jar_path, **kwargs)
    return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize)


def test_jvm_buffer(root_allocator):
    # Create a Java buffer
    jvm_buffer = root_allocator.buffer(8)
    for i in range(8):
        jvm_buffer.setByte(i, 8 - i)
    orig_refcnt = jvm_buffer.refCnt()

    # Convert to Python
    buf = pa_jvm.jvm_buffer(jvm_buffer)

    # Check its content
    assert buf.to_pybytes() == b'\x08\x07\x06\x05\x04\x03\x02\x01'

    # Check Java buffer lifetime is tied to PyArrow buffer lifetime
    assert jvm_buffer.refCnt() == orig_refcnt + 1
    del buf
    assert jvm_buffer.refCnt() == orig_refcnt


def test_jvm_buffer_released(root_allocator):
    import jpype.imports  # noqa
    from java.lang import IllegalArgumentException

    jvm_buffer = root_allocator.buffer(8)
    jvm_buffer.release()

    # Wrapping an already-released Java buffer must fail on the Java side.
    with pytest.raises(IllegalArgumentException):
        pa_jvm.jvm_buffer(jvm_buffer)


def _jvm_field(jvm_spec):
    """Deserialize a JSON Field spec into a Java ``pojo.Field`` via Jackson."""
    om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
    pojo_Field = jpype.JClass('org.apache.arrow.vector.types.pojo.Field')
    return om.readValue(jvm_spec, pojo_Field)


def _jvm_schema(jvm_spec, metadata=None):
    """Build a Java ``pojo.Schema`` holding the single field described by
    *jvm_spec*, optionally attaching *metadata* as custom key/value pairs."""
    field = _jvm_field(jvm_spec)
    schema_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Schema')
    fields = jpype.JClass('java.util.ArrayList')()
    fields.add(field)
    if metadata:
        dct = jpype.JClass('java.util.HashMap')()
        for k, v in metadata.items():
            dct.put(k, v)
        return schema_cls(fields, dct)
    else:
        return schema_cls(fields)


# In the following, we use the JSON serialization of the Field objects in
# Java.  This ensures that we neither rely on the exact mechanics of how to
# construct them using Java code, and it also enables us to define them as
# parameters without having to invoke the JVM.
#
# The specifications were created using:
#
#   om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
#
#   field = …  # Code to instantiate the field
#   jvm_spec = om.writeValueAsString(field)
@pytest.mark.parametrize('pa_type,jvm_spec', [
    (pa.null(), '{"name":"null"}'),
    (pa.bool_(), '{"name":"bool"}'),
    (pa.int8(), '{"name":"int","bitWidth":8,"isSigned":true}'),
    (pa.int16(), '{"name":"int","bitWidth":16,"isSigned":true}'),
    (pa.int32(), '{"name":"int","bitWidth":32,"isSigned":true}'),
    (pa.int64(), '{"name":"int","bitWidth":64,"isSigned":true}'),
    (pa.uint8(), '{"name":"int","bitWidth":8,"isSigned":false}'),
    (pa.uint16(), '{"name":"int","bitWidth":16,"isSigned":false}'),
    (pa.uint32(), '{"name":"int","bitWidth":32,"isSigned":false}'),
    (pa.uint64(), '{"name":"int","bitWidth":64,"isSigned":false}'),
    (pa.float16(), '{"name":"floatingpoint","precision":"HALF"}'),
    (pa.float32(), '{"name":"floatingpoint","precision":"SINGLE"}'),
    (pa.float64(), '{"name":"floatingpoint","precision":"DOUBLE"}'),
    (pa.time32('s'), '{"name":"time","unit":"SECOND","bitWidth":32}'),
    (pa.time32('ms'), '{"name":"time","unit":"MILLISECOND","bitWidth":32}'),
    (pa.time64('us'), '{"name":"time","unit":"MICROSECOND","bitWidth":64}'),
    (pa.time64('ns'), '{"name":"time","unit":"NANOSECOND","bitWidth":64}'),
    (pa.timestamp('s'), '{"name":"timestamp","unit":"SECOND",'
        '"timezone":null}'),
    (pa.timestamp('ms'), '{"name":"timestamp","unit":"MILLISECOND",'
        '"timezone":null}'),
    (pa.timestamp('us'), '{"name":"timestamp","unit":"MICROSECOND",'
        '"timezone":null}'),
    (pa.timestamp('ns'), '{"name":"timestamp","unit":"NANOSECOND",'
        '"timezone":null}'),
    (pa.timestamp('ns', tz='UTC'), '{"name":"timestamp","unit":"NANOSECOND"'
        ',"timezone":"UTC"}'),
    (pa.timestamp('ns', tz='Europe/Paris'), '{"name":"timestamp",'
        '"unit":"NANOSECOND","timezone":"Europe/Paris"}'),
    (pa.date32(), '{"name":"date","unit":"DAY"}'),
    (pa.date64(), '{"name":"date","unit":"MILLISECOND"}'),
    (pa.decimal128(19, 4), '{"name":"decimal","precision":19,"scale":4}'),
    (pa.string(), '{"name":"utf8"}'),
    (pa.binary(), '{"name":"binary"}'),
    (pa.binary(10), '{"name":"fixedsizebinary","byteWidth":10}'),
    # TODO(ARROW-2609): complex types that have children
    # pa.list_(pa.int32()),
    # pa.struct([pa.field('a', pa.int32()),
    #            pa.field('b', pa.int8()),
    #            pa.field('c', pa.string())]),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE),
    # pa.union([pa.field('a', pa.binary(10)),
    #           pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE),
    # TODO: DictionaryType requires a vector in the type
    # pa.dictionary(pa.int32(), pa.array(['a', 'b', 'c'])),
])
@pytest.mark.parametrize('nullable', [True, False])
def test_jvm_types(root_allocator, pa_type, jvm_spec, nullable):
    """Round-trip a Java Field/Schema built from *jvm_spec* through
    ``pyarrow.jvm`` and compare against the expected pyarrow types."""
    if pa_type == pa.null() and not nullable:
        # A non-nullable null field is not a meaningful combination; skip it
        # explicitly instead of silently reporting a pass.
        pytest.skip("null type is always nullable")
    spec = {
        'name': 'field_name',
        'nullable': nullable,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))
    result = pa_jvm.field(jvm_field)
    expected_field = pa.field('field_name', pa_type, nullable=nullable)
    assert result == expected_field

    jvm_schema = _jvm_schema(json.dumps(spec))
    result = pa_jvm.schema(jvm_schema)
    assert result == pa.schema([expected_field])

    # Schema with custom metadata
    jvm_schema = _jvm_schema(json.dumps(spec), {'meta': 'data'})
    result = pa_jvm.schema(jvm_schema)
    assert result == pa.schema([expected_field], {'meta': 'data'})

    # Schema with custom field metadata
    spec['metadata'] = [{'key': 'field meta', 'value': 'field data'}]
    jvm_schema = _jvm_schema(json.dumps(spec))
    result = pa_jvm.schema(jvm_schema)
    expected_field = expected_field.with_metadata(
        {'field meta': 'field data'})
    assert result == pa.schema([expected_field])


# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type', [
    (pa.bool_(), [True, False, True, True], 'BitVector'),
    (pa.uint8(), list(range(128)), 'UInt1Vector'),
    (pa.uint16(), list(range(128)), 'UInt2Vector'),
    (pa.int32(), list(range(128)), 'IntVector'),
    (pa.int64(), list(range(128)), 'BigIntVector'),
    (pa.float32(), list(range(128)), 'Float4Vector'),
    (pa.float64(), list(range(128)), 'Float8Vector'),
    (pa.timestamp('s'), list(range(128)), 'TimeStampSecVector'),
    (pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector'),
    (pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector'),
    (pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector'),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    #  * pa.time32('s')
    #  * pa.time32('ms')
    #  * pa.time64('us')
    #  * pa.time64('ns')
    (pa.date32(), list(range(128)), 'DateDayVector'),
    (pa.date64(), list(range(128)), 'DateMilliVector'),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_array(root_allocator, pa_type, py_data, jvm_type):
    """Fill a Java Arrow vector with *py_data* and check that converting it
    via ``pyarrow.jvm.array`` yields the same array as ``pa.array``."""
    # Instantiate the Java vector class by its fully qualified name.
    vector_cls = jpype.JClass(
        "org.apache.arrow.vector.{}".format(jvm_type))
    vector = vector_cls("vector", root_allocator)
    vector.allocateNew(len(py_data))
    # char and int are ambiguous overloads of setSafe on the unsigned
    # small-integer vectors, so wrap values in an explicit JInt there.
    if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
        coerce = jpype.JInt
    else:
        def coerce(value):
            return value
    for index, item in enumerate(py_data):
        vector.setSafe(index, coerce(item))
    vector.setValueCount(len(py_data))

    expected = pa.array(py_data, type=pa_type)
    converted = pa_jvm.array(vector)
    assert expected.equals(converted)


def test_jvm_array_empty(root_allocator):
    """A freshly allocated, unfilled IntVector converts to an empty int32
    array."""
    vector = jpype.JClass(
        "org.apache.arrow.vector.IntVector")("vector", root_allocator)
    vector.allocateNew()
    converted = pa_jvm.array(vector)
    assert len(converted) == 0
    assert converted.type == pa.int32()


# These test parameters mostly use an integer range as an input as this is
# often the only type that is understood by both Python and Java
# implementations of Arrow.
@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_spec', [
    # TODO: null
    (pa.bool_(), [True, False, True, True], 'BitVector', '{"name":"bool"}'),
    (
        pa.uint8(), list(range(128)), 'UInt1Vector',
        '{"name":"int","bitWidth":8,"isSigned":false}'
    ),
    (
        pa.uint16(), list(range(128)), 'UInt2Vector',
        '{"name":"int","bitWidth":16,"isSigned":false}'
    ),
    (
        pa.uint32(), list(range(128)), 'UInt4Vector',
        '{"name":"int","bitWidth":32,"isSigned":false}'
    ),
    (
        pa.uint64(), list(range(128)), 'UInt8Vector',
        '{"name":"int","bitWidth":64,"isSigned":false}'
    ),
    (
        pa.int8(), list(range(128)), 'TinyIntVector',
        '{"name":"int","bitWidth":8,"isSigned":true}'
    ),
    (
        pa.int16(), list(range(128)), 'SmallIntVector',
        '{"name":"int","bitWidth":16,"isSigned":true}'
    ),
    (
        pa.int32(), list(range(128)), 'IntVector',
        '{"name":"int","bitWidth":32,"isSigned":true}'
    ),
    (
        pa.int64(), list(range(128)), 'BigIntVector',
        '{"name":"int","bitWidth":64,"isSigned":true}'
    ),
    # TODO: float16
    (
        pa.float32(), list(range(128)), 'Float4Vector',
        '{"name":"floatingpoint","precision":"SINGLE"}'
    ),
    (
        pa.float64(), list(range(128)), 'Float8Vector',
        '{"name":"floatingpoint","precision":"DOUBLE"}'
    ),
    (
        pa.timestamp('s'), list(range(128)), 'TimeStampSecVector',
        '{"name":"timestamp","unit":"SECOND","timezone":null}'
    ),
    (
        pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector',
        '{"name":"timestamp","unit":"MILLISECOND","timezone":null}'
    ),
    (
        pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector',
        '{"name":"timestamp","unit":"MICROSECOND","timezone":null}'
    ),
    (
        pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector',
        '{"name":"timestamp","unit":"NANOSECOND","timezone":null}'
    ),
    # TODO(ARROW-2605): These types miss a conversion from pure Python objects
    #  * pa.time32('s')
    #  * pa.time32('ms')
    #  * pa.time64('us')
    #  * pa.time64('ns')
    (
        pa.date32(), list(range(128)), 'DateDayVector',
        '{"name":"date","unit":"DAY"}'
    ),
    (
        pa.date64(), list(range(128)), 'DateMilliVector',
        '{"name":"date","unit":"MILLISECOND"}'
    ),
    # TODO(ARROW-2606): pa.decimal128(19, 4)
])
def test_jvm_record_batch(root_allocator, pa_type, py_data, jvm_type,
                          jvm_spec):
    """Build a Java VectorSchemaRoot around a filled vector and check that
    ``pyarrow.jvm.record_batch`` reproduces the expected RecordBatch."""
    # Create vector
    cls = "org.apache.arrow.vector.{}".format(jvm_type)
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew(len(py_data))
    for i, val in enumerate(py_data):
        # char and int are ambiguous overloads for these two setSafe calls
        if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
            val = jpype.JInt(val)
        jvm_vector.setSafe(i, val)
    jvm_vector.setValueCount(len(py_data))

    # Create field
    spec = {
        'name': 'field_name',
        'nullable': False,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))

    # Create VectorSchemaRoot
    jvm_fields = jpype.JClass('java.util.ArrayList')()
    jvm_fields.add(jvm_field)
    jvm_vectors = jpype.JClass('java.util.ArrayList')()
    jvm_vectors.add(jvm_vector)
    vsr_cls = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')
    jvm_vsr = vsr_cls(jvm_fields, jvm_vectors, len(py_data))

    py_record_batch = pa.RecordBatch.from_arrays(
        [pa.array(py_data, type=pa_type)],
        ['col']
    )
    jvm_record_batch = pa_jvm.record_batch(jvm_vsr)
    assert py_record_batch.equals(jvm_record_batch)


def _string_to_varchar_holder(ra, string):
    """Wrap *string* (or None) in a Java ``NullableVarCharHolder`` whose
    backing buffer is allocated from the allocator *ra*."""
    nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
    holder = jpype.JClass(nvch_cls)()
    if string is None:
        holder.isSet = 0
    else:
        holder.isSet = 1
        # Fixed: encode the actual argument; previously the literal "string"
        # was used here, so every holder contained the same bytes.
        value = jpype.JClass("java.lang.String")(string)
        std_charsets = jpype.JClass("java.nio.charset.StandardCharsets")
        bytes_ = value.getBytes(std_charsets.UTF_8)
        holder.buffer = ra.buffer(len(bytes_))
        holder.buffer.setBytes(0, bytes_, 0, len(bytes_))
        holder.start = 0
        holder.end = len(bytes_)
    return holder


# TODO(ARROW-2607)
@pytest.mark.xfail(reason="from_buffers is only supported for "
                          "primitive arrays yet")
def test_jvm_string_array(root_allocator):
    data = ["string", None, "töst"]
    cls = "org.apache.arrow.vector.VarCharVector"
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew()

    for i, string in enumerate(data):
        # Fixed: pass the loop value; previously the literal "string" was
        # passed, so the None and non-ASCII entries were never exercised.
        holder = _string_to_varchar_holder(root_allocator, string)
        jvm_vector.setSafe(i, holder)
        jvm_vector.setValueCount(i + 1)

    py_array = pa.array(data, type=pa.string())
    jvm_array = pa_jvm.array(jvm_vector)
    assert py_array.equals(jvm_array)