diff options
Diffstat (limited to 'src/arrow/java/c')
30 files changed, 4477 insertions, 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

#
# arrow_cdata_java
#
# Builds the JNI shared library (arrow_cdata_jni) used by the Java C Data
# Interface module, together with the JNI headers generated from the Java
# sources listed in add_jar() below.
#

cmake_minimum_required(VERSION 3.11)
message(STATUS "Building using CMake version: ${CMAKE_VERSION}")
project(arrow_cdata_java)

# Find java/jni
include(UseJava)

# Defines CMAKE_INSTALL_LIBDIR, which the install() rule at the bottom uses.
# Without this module the variable is empty unless the user passes it on the
# command line; a user-supplied value is still honoured.
include(GNUInstallDirs)

find_package(Java REQUIRED)
find_package(JNI REQUIRED)

# Directory that receives the JNI headers generated by add_jar().
set(JNI_HEADERS_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")

include_directories(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
                    ${JNI_INCLUDE_DIRS} ${JNI_HEADERS_DIR})

# Compile the Java classes that declare native methods and emit their JNI
# headers into JNI_HEADERS_DIR.
add_jar(${PROJECT_NAME}
        src/main/java/org/apache/arrow/c/jni/JniLoader.java
        src/main/java/org/apache/arrow/c/jni/JniWrapper.java
        src/main/java/org/apache/arrow/c/jni/PrivateData.java
        GENERATE_NATIVE_HEADERS
        arrow_cdata_java-native
        DESTINATION
        ${JNI_HEADERS_DIR})

set(SOURCES src/main/cpp/jni_wrapper.cc)
add_library(arrow_cdata_jni SHARED ${SOURCES})
target_link_libraries(arrow_cdata_jni ${JAVA_JVM_LIBRARY})
# The native sources include the generated headers, so the jar target (which
# produces them) has to be built first.
add_dependencies(arrow_cdata_jni ${PROJECT_NAME})

install(TARGETS arrow_cdata_jni DESTINATION ${CMAKE_INSTALL_LIBDIR})
a/src/arrow/java/c/README.md b/src/arrow/java/c/README.md new file mode 100644 index 000000000..ce73f531c --- /dev/null +++ b/src/arrow/java/c/README.md @@ -0,0 +1,54 @@ +<!--- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +# C Interfaces for Arrow Java + +## Setup Build Environment + +install: + - Java 8 or later + - Maven 3.3 or later + - A C++11-enabled compiler + - CMake 3.11 or later + - Make or ninja build utilities + +## Building JNI wrapper shared library + +``` +mkdir -p build +pushd build +cmake .. +cmake --build . +popd +``` + +## Building and running tests + +Run tests with + +``` +mvn test +``` + +To install Apache Arrow (Java) with this module enabled run the following from the project root directory: + +``` +cd java +mvn -Parrow-c-data install +``` diff --git a/src/arrow/java/c/pom.xml b/src/arrow/java/c/pom.xml new file mode 100644 index 000000000..901b084fd --- /dev/null +++ b/src/arrow/java/c/pom.xml @@ -0,0 +1,77 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor + license agreements. See the NOTICE file distributed with this work for additional + information regarding copyright ownership. 
The ASF licenses this file to + You under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of + the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required + by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS + OF ANY KIND, either express or implied. See the License for the specific + language governing permissions and limitations under the License. --> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <parent> + <artifactId>arrow-java-root</artifactId> + <groupId>org.apache.arrow</groupId> + <version>6.0.1</version> + </parent> + <modelVersion>4.0.0</modelVersion> + + <artifactId>arrow-c-data</artifactId> + <name>Arrow Java C Data Interface</name> + <description>Java implementation of C Data Interface</description> + <packaging>jar</packaging> + <properties> + <arrow.c.jni.dist.dir>./build</arrow.c.jni.dist.dir> + </properties> + + <dependencies> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-vector</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + <classifier>${arrow.vector.classifier}</classifier> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-vector</artifactId> + <version>${project.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-memory-core</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + 
<artifactId>arrow-memory-unsafe</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <version>${dep.guava.version}</version> + <scope>test</scope> + </dependency> + </dependencies> + <build> + <resources> + <resource> + <directory>${arrow.c.jni.dist.dir}</directory> + <includes> + <include>**/*arrow_cdata_jni.*</include> + </includes> + </resource> + </resources> + </build> + +</project> diff --git a/src/arrow/java/c/src/main/cpp/abi.h b/src/arrow/java/c/src/main/cpp/abi.h new file mode 100644 index 000000000..a78170dbd --- /dev/null +++ b/src/arrow/java/c/src/main/cpp/abi.h @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
// Plain-C declarations of the ArrowSchema, ArrowArray and ArrowArrayStream
// structs. The declarations are wrapped in `extern "C"` when this header is
// compiled as C++.
//
// NOTE(review): the Java wrapper org.apache.arrow.c.ArrowArray hard-codes a
// struct size of 80 bytes and a release-callback offset of 64 for ArrowArray,
// so the order and width of the fields below must not change.

#pragma once

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// Flag values; each one is a distinct bit.
// NOTE(review): presumably OR-ed together into ArrowSchema::flags -- confirm
// against the Arrow C data interface specification.
#define ARROW_FLAG_DICTIONARY_ORDERED 1
#define ARROW_FLAG_NULLABLE 2
#define ARROW_FLAG_MAP_KEYS_SORTED 4

struct ArrowSchema {
  // Array type description
  const char* format;
  const char* name;
  const char* metadata;
  int64_t flags;
  int64_t n_children;
  struct ArrowSchema** children;
  struct ArrowSchema* dictionary;

  // Release callback
  void (*release)(struct ArrowSchema*);
  // Opaque producer-specific data
  void* private_data;
};

struct ArrowArray {
  // Array data description
  int64_t length;
  int64_t null_count;
  int64_t offset;
  int64_t n_buffers;
  int64_t n_children;
  const void** buffers;
  struct ArrowArray** children;
  struct ArrowArray* dictionary;

  // Release callback
  void (*release)(struct ArrowArray*);
  // Opaque producer-specific data
  void* private_data;
};

// EXPERIMENTAL: C stream interface

struct ArrowArrayStream {
  // Callback to get the stream type
  // (will be the same for all arrays in the stream).
  //
  // Return value: 0 if successful, an `errno`-compatible error code otherwise.
  //
  // If successful, the ArrowSchema must be released independently from the stream.
  int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out);

  // Callback to get the next array
  // (if no error and the array is released, the stream has ended)
  //
  // Return value: 0 if successful, an `errno`-compatible error code otherwise.
  //
  // If successful, the ArrowArray must be released independently from the stream.
  int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out);

  // Callback to get optional detailed error information.
  // This must only be called if the last stream operation failed
  // with a non-0 return code.
  //
  // Return value: pointer to a null-terminated character array describing
  // the last error, or NULL if no description is available.
  //
  // The returned pointer is only valid until the next operation on this stream
  // (including release).
  const char* (*get_last_error)(struct ArrowArrayStream*);

  // Release callback: release the stream's own resources.
  // Note that arrays returned by `get_next` must be individually released.
  void (*release)(struct ArrowArrayStream*);

  // Opaque producer-specific data
  void* private_data;
};

#ifdef __cplusplus
}
#endif
+ +#include <jni.h> + +#include <cassert> +#include <memory> +#include <stdexcept> +#include <string> + +#include "./abi.h" +#include "org_apache_arrow_c_jni_JniWrapper.h" + +namespace { + +jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name) { + jclass local_class = env->FindClass(class_name); + jclass global_class = (jclass)env->NewGlobalRef(local_class); + env->DeleteLocalRef(local_class); + return global_class; +} + +jclass illegal_access_exception_class; +jclass illegal_argument_exception_class; +jclass runtime_exception_class; +jclass private_data_class; + +jmethodID private_data_close_method; + +jint JNI_VERSION = JNI_VERSION_1_6; + +class JniPendingException : public std::runtime_error { + public: + explicit JniPendingException(const std::string& arg) : std::runtime_error(arg) {} +}; + +void ThrowPendingException(const std::string& message) { + throw JniPendingException(message); +} + +void JniThrow(std::string message) { ThrowPendingException(message); } + +jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig) { + jmethodID ret = env->GetMethodID(this_class, name, sig); + if (ret == nullptr) { + std::string error_message = "Unable to find method " + std::string(name) + + " within signature " + std::string(sig); + ThrowPendingException(error_message); + } + return ret; +} + +class InnerPrivateData { + public: + InnerPrivateData(JavaVM* vm, jobject private_data) + : vm_(vm), j_private_data_(private_data) {} + + JavaVM* vm_; + jobject j_private_data_; +}; + +class JNIEnvGuard { + public: + explicit JNIEnvGuard(JavaVM* vm) : vm_(vm), should_detach_(false) { + JNIEnv* env; + jint code = vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION); + if (code == JNI_EDETACHED) { + JavaVMAttachArgs args; + args.version = JNI_VERSION; + args.name = NULL; + args.group = NULL; + code = vm->AttachCurrentThread(reinterpret_cast<void**>(&env), &args); + should_detach_ = (code == JNI_OK); + } + if (code != JNI_OK) { + 
ThrowPendingException("Failed to attach the current thread to a Java VM"); + } + env_ = env; + } + + JNIEnv* env() { return env_; } + + ~JNIEnvGuard() { + if (should_detach_) { + vm_->DetachCurrentThread(); + should_detach_ = false; + } + } + + private: + bool should_detach_; + JavaVM* vm_; + JNIEnv* env_; +}; + +template <typename T> +void release_exported(T* base) { + // This should not be called on already released structure + assert(base->release != nullptr); + + // Release children + for (int64_t i = 0; i < base->n_children; ++i) { + T* child = base->children[i]; + if (child->release != nullptr) { + child->release(child); + assert(child->release == nullptr); + } + } + + // Release dictionary + T* dict = base->dictionary; + if (dict != nullptr && dict->release != nullptr) { + dict->release(dict); + assert(dict->release == nullptr); + } + + // Release all data directly owned by the struct + InnerPrivateData* private_data = + reinterpret_cast<InnerPrivateData*>(base->private_data); + + JNIEnvGuard guard(private_data->vm_); + JNIEnv* env = guard.env(); + + env->CallObjectMethod(private_data->j_private_data_, private_data_close_method); + if (env->ExceptionCheck()) { + env->ExceptionDescribe(); + env->ExceptionClear(); + ThrowPendingException("Error calling close of private data"); + } + env->DeleteGlobalRef(private_data->j_private_data_); + delete private_data; + base->private_data = nullptr; + + // Mark released + base->release = nullptr; +} +} // namespace + +#define JNI_METHOD_START try { +// macro ended + +#define JNI_METHOD_END(fallback_expr) \ + } \ + catch (JniPendingException & e) { \ + env->ThrowNew(runtime_exception_class, e.what()); \ + return fallback_expr; \ + } +// macro ended + +jint JNI_OnLoad(JavaVM* vm, void* reserved) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + JNI_METHOD_START + illegal_access_exception_class = + CreateGlobalClassReference(env, 
"Ljava/lang/IllegalAccessException;"); + illegal_argument_exception_class = + CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;"); + runtime_exception_class = + CreateGlobalClassReference(env, "Ljava/lang/RuntimeException;"); + private_data_class = + CreateGlobalClassReference(env, "Lorg/apache/arrow/c/jni/PrivateData;"); + + private_data_close_method = GetMethodID(env, private_data_class, "close", "()V"); + + return JNI_VERSION; + JNI_METHOD_END(JNI_ERR) +} + +void JNI_OnUnload(JavaVM* vm, void* reserved) { + JNIEnv* env; + vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION); + env->DeleteGlobalRef(illegal_access_exception_class); + env->DeleteGlobalRef(illegal_argument_exception_class); + env->DeleteGlobalRef(runtime_exception_class); +} + +/* + * Class: org_apache_arrow_c_jni_JniWrapper + * Method: releaseSchema + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_apache_arrow_c_jni_JniWrapper_releaseSchema( + JNIEnv* env, jobject, jlong address) { + JNI_METHOD_START + ArrowSchema* schema = reinterpret_cast<ArrowSchema*>(address); + if (schema->release != nullptr) { + schema->release(schema); + } + JNI_METHOD_END() +} + +/* + * Class: org_apache_arrow_c_jni_JniWrapper + * Method: releaseArray + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_apache_arrow_c_jni_JniWrapper_releaseArray(JNIEnv* env, jobject, jlong address) { + JNI_METHOD_START + ArrowArray* array = reinterpret_cast<ArrowArray*>(address); + if (array->release != nullptr) { + array->release(array); + } + JNI_METHOD_END() +} + +/* + * Class: org_apache_arrow_c_jni_JniWrapper + * Method: exportSchema + * Signature: (JLorg/apache/arrow/c/jni/PrivateData;)V + */ +JNIEXPORT void JNICALL Java_org_apache_arrow_c_jni_JniWrapper_exportSchema( + JNIEnv* env, jobject, jlong address, jobject private_data) { + JNI_METHOD_START + ArrowSchema* schema = reinterpret_cast<ArrowSchema*>(address); + + JavaVM* vm; + if (env->GetJavaVM(&vm) != JNI_OK) { + JniThrow("Unable to get 
JavaVM instance"); + } + jobject private_data_ref = env->NewGlobalRef(private_data); + + schema->private_data = new InnerPrivateData(vm, private_data_ref); + schema->release = &release_exported<ArrowSchema>; + JNI_METHOD_END() +} + +/* + * Class: org_apache_arrow_c_jni_JniWrapper + * Method: exportArray + * Signature: (JLorg/apache/arrow/c/jni/PrivateData;)V + */ +JNIEXPORT void JNICALL Java_org_apache_arrow_c_jni_JniWrapper_exportArray( + JNIEnv* env, jobject, jlong address, jobject private_data) { + JNI_METHOD_START + ArrowArray* array = reinterpret_cast<ArrowArray*>(address); + + JavaVM* vm; + if (env->GetJavaVM(&vm) != JNI_OK) { + JniThrow("Unable to get JavaVM instance"); + } + jobject private_data_ref = env->NewGlobalRef(private_data); + + array->private_data = new InnerPrivateData(vm, private_data_ref); + array->release = &release_exported<ArrowArray>; + JNI_METHOD_END() +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/ArrayExporter.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/ArrayExporter.java new file mode 100644 index 000000000..d6479a3ba --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/ArrayExporter.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
package org.apache.arrow.c;

import static org.apache.arrow.c.NativeUtil.NULL;
import static org.apache.arrow.c.NativeUtil.addressOrNull;
import static org.apache.arrow.util.Preconditions.checkNotNull;

import java.util.ArrayList;
import java.util.List;

import org.apache.arrow.c.jni.JniWrapper;
import org.apache.arrow.c.jni.PrivateData;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.pojo.DictionaryEncoding;

/**
 * Writes a {@link FieldVector} (and, recursively, its children and dictionary)
 * into {@link ArrowArray} structures.
 */
final class ArrayExporter {
  private final BufferAllocator allocator;

  public ArrayExporter(BufferAllocator allocator) {
    this.allocator = allocator;
  }

  /**
   * Resources kept alive on behalf of one exported array; all of them are
   * closed together in {@link #close()}.
   */
  static class ExportedArrayPrivateData implements PrivateData {
    ArrowBuf buffers_ptrs;
    List<ArrowBuf> buffers;
    ArrowBuf children_ptrs;
    List<ArrowArray> children;
    ArrowArray dictionary;

    @Override
    public void close() {
      NativeUtil.closeBuffer(buffers_ptrs);
      if (buffers != null) {
        buffers.forEach(NativeUtil::closeBuffer);
      }

      NativeUtil.closeBuffer(children_ptrs);
      if (children != null) {
        children.forEach(ArrowArray::close);
      }

      if (dictionary != null) {
        dictionary.close();
      }
    }
  }

  /**
   * Fill {@code array} from {@code vector}.
   *
   * @param array              destination structure
   * @param vector             vector whose buffers, children and dictionary are exported
   * @param dictionaryProvider used to look up the dictionary of a dictionary-encoded vector
   */
  void export(ArrowArray array, FieldVector vector, DictionaryProvider dictionaryProvider) {
    final List<FieldVector> childVectors = vector.getChildrenFromFields();
    final List<ArrowBuf> fieldBuffers = vector.getFieldBuffers();
    final int length = vector.getValueCount();
    final int nulls = vector.getNullCount();
    final DictionaryEncoding encoding = vector.getField().getDictionary();

    final ExportedArrayPrivateData privateData = new ExportedArrayPrivateData();
    try {
      // One freshly allocated ArrowArray per child, plus the pointer table
      // that refers to them.
      if (childVectors != null) {
        privateData.children = new ArrayList<>(childVectors.size());
        privateData.children_ptrs = allocator.buffer((long) childVectors.size() * Long.BYTES);
        for (int remaining = childVectors.size(); remaining > 0; remaining--) {
          ArrowArray childArray = ArrowArray.allocateNew(allocator);
          privateData.children.add(childArray);
          privateData.children_ptrs.writeLong(childArray.memoryAddress());
        }
      }

      // Retain every non-null buffer and record its address (or NULL).
      if (fieldBuffers != null) {
        privateData.buffers = new ArrayList<>(fieldBuffers.size());
        privateData.buffers_ptrs = allocator.buffer((long) fieldBuffers.size() * Long.BYTES);
        for (ArrowBuf fieldBuffer : fieldBuffers) {
          if (fieldBuffer == null) {
            privateData.buffers_ptrs.writeLong(NULL);
          } else {
            fieldBuffer.getReferenceManager().retain();
            privateData.buffers_ptrs.writeLong(fieldBuffer.memoryAddress());
          }
          privateData.buffers.add(fieldBuffer);
        }
      }

      // A dictionary-encoded vector also exports its dictionary values.
      if (encoding != null) {
        Dictionary dictionary = dictionaryProvider.lookup(encoding.getId());
        checkNotNull(dictionary, "Dictionary lookup failed on export of dictionary encoded array");

        privateData.dictionary = ArrowArray.allocateNew(allocator);
        export(privateData.dictionary, dictionary.getVector(), dictionaryProvider);
      }

      ArrowArray.Snapshot fields = new ArrowArray.Snapshot();
      fields.length = length;
      fields.null_count = nulls;
      fields.offset = 0;
      fields.n_buffers = (privateData.buffers != null) ? privateData.buffers.size() : 0;
      fields.n_children = (privateData.children != null) ? privateData.children.size() : 0;
      fields.buffers = addressOrNull(privateData.buffers_ptrs);
      fields.children = addressOrNull(privateData.children_ptrs);
      fields.dictionary = addressOrNull(privateData.dictionary);
      fields.release = NULL;
      array.save(fields);

      // sets release and private data
      JniWrapper.get().exportArray(array.memoryAddress(), privateData);
    } catch (Exception e) {
      privateData.close();
      throw e;
    }

    // Export children
    if (childVectors != null) {
      int position = 0;
      for (FieldVector childVector : childVectors) {
        export(privateData.children.get(position), childVector, dictionaryProvider);
        position++;
      }
    }
  }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import static org.apache.arrow.c.NativeUtil.NULL; +import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; +import static org.apache.arrow.util.Preconditions.checkNotNull; +import static org.apache.arrow.util.Preconditions.checkState; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.TypeLayout; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; + +/** + * Importer for {@link ArrowArray}. 
+ */ +final class ArrayImporter { + private static final int MAX_IMPORT_RECURSION_LEVEL = 64; + + private final BufferAllocator allocator; + private final FieldVector vector; + private final DictionaryProvider dictionaryProvider; + + private CDataReferenceManager referenceManager; + private int recursionLevel; + + ArrayImporter(BufferAllocator allocator, FieldVector vector, DictionaryProvider dictionaryProvider) { + this.allocator = allocator; + this.vector = vector; + this.dictionaryProvider = dictionaryProvider; + } + + void importArray(ArrowArray src) { + ArrowArray.Snapshot snapshot = src.snapshot(); + checkState(snapshot.release != NULL, "Cannot import released ArrowArray"); + + // Move imported array + ArrowArray ownedArray = ArrowArray.allocateNew(allocator); + ownedArray.save(snapshot); + src.markReleased(); + src.close(); + + recursionLevel = 0; + + // This keeps the array alive as long as there are any buffers that need it + referenceManager = new CDataReferenceManager(ownedArray); + try { + referenceManager.increment(); + doImport(snapshot); + } finally { + referenceManager.release(); + } + } + + private void importChild(ArrayImporter parent, ArrowArray src) { + ArrowArray.Snapshot snapshot = src.snapshot(); + checkState(snapshot.release != NULL, "Cannot import released ArrowArray"); + recursionLevel = parent.recursionLevel + 1; + checkState(recursionLevel <= MAX_IMPORT_RECURSION_LEVEL, "Recursion level in ArrowArray struct exceeded"); + // Child buffers will keep the entire parent import alive. + // Perhaps we can move the child structs on import, + // but that is another level of complication. 
+ referenceManager = parent.referenceManager; + doImport(snapshot); + } + + private void doImport(ArrowArray.Snapshot snapshot) { + // First import children (required for reconstituting parent array data) + long[] children = NativeUtil.toJavaArray(snapshot.children, checkedCastToInt(snapshot.n_children)); + if (children != null && children.length > 0) { + List<FieldVector> childVectors = vector.getChildrenFromFields(); + checkState(children.length == childVectors.size(), "ArrowArray struct has %s children (expected %s)", + children.length, childVectors.size()); + for (int i = 0; i < children.length; i++) { + checkState(children[i] != NULL, "ArrowArray struct has NULL child at position %s", i); + ArrayImporter childImporter = new ArrayImporter(allocator, childVectors.get(i), dictionaryProvider); + childImporter.importChild(this, ArrowArray.wrap(children[i])); + } + } + + // Handle import of a dictionary encoded vector + if (snapshot.dictionary != NULL) { + DictionaryEncoding encoding = vector.getField().getDictionary(); + checkNotNull(encoding, "Missing encoding on import of ArrowArray with dictionary"); + + Dictionary dictionary = dictionaryProvider.lookup(encoding.getId()); + checkNotNull(dictionary, "Dictionary lookup failed on import of ArrowArray with dictionary"); + + // reset the dictionary vector to the initial state + dictionary.getVector().clear(); + + ArrayImporter dictionaryImporter = new ArrayImporter(allocator, dictionary.getVector(), dictionaryProvider); + dictionaryImporter.importChild(this, ArrowArray.wrap(snapshot.dictionary)); + } + + // Import main data + ArrowFieldNode fieldNode = new ArrowFieldNode(snapshot.length, snapshot.null_count); + List<ArrowBuf> buffers = importBuffers(snapshot); + try { + vector.loadFieldBuffers(fieldNode, buffers); + } catch (RuntimeException e) { + throw new IllegalArgumentException( + "Could not load buffers for field " + vector.getField() + ". 
error message: " + e.getMessage(), e); + } + } + + private List<ArrowBuf> importBuffers(ArrowArray.Snapshot snapshot) { + long[] buffers = NativeUtil.toJavaArray(snapshot.buffers, checkedCastToInt(snapshot.n_buffers)); + if (buffers == null || buffers.length == 0) { + return new ArrayList<>(); + } + + int buffersCount = TypeLayout.getTypeBufferCount(vector.getField().getType()); + checkState(buffers.length == buffersCount, "Expected %s buffers for imported type %s, ArrowArray struct has %s", + buffersCount, vector.getField().getType().getTypeID(), buffers.length); + + List<ArrowBuf> result = new ArrayList<>(buffersCount); + for (long bufferPtr : buffers) { + ArrowBuf buffer = null; + if (bufferPtr != NULL) { + // TODO(roee88): an API for getting the size for each buffer is not yet + // available + buffer = new ArrowBuf(referenceManager, null, Integer.MAX_VALUE, bufferPtr); + } + result.add(buffer); + } + return result; + } +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/ArrowArray.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/ArrowArray.java new file mode 100644 index 000000000..99fe0432c --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/ArrowArray.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
package org.apache.arrow.c;

import static org.apache.arrow.c.NativeUtil.NULL;
import static org.apache.arrow.util.Preconditions.checkNotNull;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import org.apache.arrow.c.jni.JniWrapper;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.ReferenceManager;
import org.apache.arrow.memory.util.MemoryUtil;

/**
 * Java view of a C Data Interface ArrowArray.
 * <p>
 * Wraps a block of native memory laid out as the following C structure:
 *
 * <pre>
 * struct ArrowArray {
 *     // Array data description
 *     int64_t length;
 *     int64_t null_count;
 *     int64_t offset;
 *     int64_t n_buffers;
 *     int64_t n_children;
 *     const void** buffers;
 *     struct ArrowArray** children;
 *     struct ArrowArray* dictionary;
 *
 *     // Release callback
 *     void (*release)(struct ArrowArray*);
 *     // Opaque producer-specific data
 *     void* private_data;
 * };
 * </pre>
 */
public class ArrowArray implements BaseStruct {
  private static final int SIZE_OF = 80;
  private static final int INDEX_RELEASE_CALLBACK = 64;

  private ArrowBuf data;

  /**
   * Plain copy of the raw field values of an ArrowArray.
   */
  public static class Snapshot {
    public long length = NULL;
    public long null_count = NULL;
    public long offset = NULL;
    public long n_buffers = NULL;
    public long n_children = NULL;
    public long buffers = NULL;
    public long children = NULL;
    public long dictionary = NULL;
    public long release = NULL;
    public long private_data = NULL;

    /**
     * Create a snapshot with every field set to NULL.
     */
    public Snapshot() {
    }
  }

  /**
   * Wrap an existing memory address as an ArrowArray.
   * <p>
   * The returned instance does not own the memory.
   *
   * @param memoryAddress Memory address to wrap
   * @return A new ArrowArray instance
   */
  public static ArrowArray wrap(long memoryAddress) {
    ArrowBuf view = new ArrowBuf(ReferenceManager.NO_OP, null, ArrowArray.SIZE_OF, memoryAddress);
    return new ArrowArray(view);
  }

  /**
   * Allocate memory for a new ArrowArray and mark it as released.
   * <p>
   * The returned instance owns the memory.
   *
   * @param allocator Allocator for memory allocations
   * @return A new ArrowArray instance
   */
  public static ArrowArray allocateNew(BufferAllocator allocator) {
    ArrowBuf storage = allocator.buffer(ArrowArray.SIZE_OF);
    ArrowArray created = new ArrowArray(storage);
    created.markReleased();
    return created;
  }

  ArrowArray(ArrowBuf data) {
    checkNotNull(data, "ArrowArray initialized with a null buffer");
    this.data = data;
  }

  /**
   * Overwrite the release callback slot with NULL.
   */
  public void markReleased() {
    ByteBuffer raw = directBuffer();
    raw.putLong(INDEX_RELEASE_CALLBACK, NULL);
  }

  @Override
  public long memoryAddress() {
    checkNotNull(data, "ArrowArray is already closed");
    return data.memoryAddress();
  }

  @Override
  public void release() {
    JniWrapper.get().releaseArray(memoryAddress());
  }

  @Override
  public void close() {
    if (data == null) {
      return;
    }
    data.close();
    data = null;
  }

  /** Native-order view over the memory of this structure. */
  private ByteBuffer directBuffer() {
    ByteBuffer raw = MemoryUtil.directBuffer(memoryAddress(), ArrowArray.SIZE_OF);
    return raw.order(ByteOrder.nativeOrder());
  }

  /**
   * Read all fields of the structure, in declaration order.
   *
   * @return snapshot
   */
  public Snapshot snapshot() {
    final ByteBuffer raw = directBuffer();
    final Snapshot result = new Snapshot();
    result.length = raw.getLong();
    result.null_count = raw.getLong();
    result.offset = raw.getLong();
    result.n_buffers = raw.getLong();
    result.n_children = raw.getLong();
    result.buffers = raw.getLong();
    result.children = raw.getLong();
    result.dictionary = raw.getLong();
    result.release = raw.getLong();
    result.private_data = raw.getLong();
    return result;
  }

  /**
   * Write all fields of {@code snapshot} to the structure, in declaration order.
   */
  public void save(Snapshot snapshot) {
    final ByteBuffer raw = directBuffer();
    raw.putLong(snapshot.length);
    raw.putLong(snapshot.null_count);
    raw.putLong(snapshot.offset);
    raw.putLong(snapshot.n_buffers);
    raw.putLong(snapshot.n_children);
    raw.putLong(snapshot.buffers);
    raw.putLong(snapshot.children);
    raw.putLong(snapshot.dictionary);
    raw.putLong(snapshot.release);
    raw.putLong(snapshot.private_data);
  }
}
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import static org.apache.arrow.c.NativeUtil.NULL; +import static org.apache.arrow.util.Preconditions.checkNotNull; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import org.apache.arrow.c.jni.JniWrapper; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ReferenceManager; +import org.apache.arrow.memory.util.MemoryUtil; + +/** + * C Data Interface ArrowSchema. + * <p> + * Represents a wrapper for the following C structure: + * + * <pre> + * struct ArrowSchema { + * // Array type description + * const char* format; + * const char* name; + * const char* metadata; + * int64_t flags; + * int64_t n_children; + * struct ArrowSchema** children; + * struct ArrowSchema* dictionary; + * + * // Release callback + * void (*release)(struct ArrowSchema*); + * // Opaque producer-specific data + * void* private_data; + * }; + * </pre> + */ +public class ArrowSchema implements BaseStruct { + private static final int SIZE_OF = 72; + + private ArrowBuf data; + + /** + * Snapshot of the ArrowSchema raw data. + */ + public static class Snapshot { + public long format; + public long name; + public long metadata; + public long flags; + public long n_children; + public long children; + public long dictionary; + public long release; + public long private_data; + + /** + * Initialize empty ArrowSchema snapshot. + */ + public Snapshot() { + format = NULL; + name = NULL; + metadata = NULL; + flags = NULL; + n_children = NULL; + children = NULL; + dictionary = NULL; + release = NULL; + private_data = NULL; + } + } + + /** + * Create ArrowSchema from an existing memory address. + * <p> + * The resulting ArrowSchema does not own the memory. 
+ * + * @param memoryAddress Memory address to wrap + * @return A new ArrowSchema instance + */ + public static ArrowSchema wrap(long memoryAddress) { + return new ArrowSchema(new ArrowBuf(ReferenceManager.NO_OP, null, ArrowSchema.SIZE_OF, memoryAddress)); + } + + /** + * Create ArrowSchema by allocating memory. + * <p> + * The resulting ArrowSchema owns the memory. + * + * @param allocator Allocator for memory allocations + * @return A new ArrowSchema instance + */ + public static ArrowSchema allocateNew(BufferAllocator allocator) { + return new ArrowSchema(allocator.buffer(ArrowSchema.SIZE_OF)); + } + + ArrowSchema(ArrowBuf data) { + checkNotNull(data, "ArrowSchema initialized with a null buffer"); + this.data = data; + } + + @Override + public long memoryAddress() { + checkNotNull(data, "ArrowSchema is already closed"); + return data.memoryAddress(); + } + + @Override + public void release() { + long address = memoryAddress(); + JniWrapper.get().releaseSchema(address); + } + + @Override + public void close() { + if (data != null) { + data.close(); + data = null; + } + } + + private ByteBuffer directBuffer() { + return MemoryUtil.directBuffer(memoryAddress(), ArrowSchema.SIZE_OF).order(ByteOrder.nativeOrder()); + } + + /** + * Take a snapshot of the ArrowSchema raw values. + * + * @return snapshot + */ + public Snapshot snapshot() { + ByteBuffer data = directBuffer(); + Snapshot snapshot = new Snapshot(); + snapshot.format = data.getLong(); + snapshot.name = data.getLong(); + snapshot.metadata = data.getLong(); + snapshot.flags = data.getLong(); + snapshot.n_children = data.getLong(); + snapshot.children = data.getLong(); + snapshot.dictionary = data.getLong(); + snapshot.release = data.getLong(); + snapshot.private_data = data.getLong(); + return snapshot; + } + + /** + * Write values from Snapshot to the underlying ArrowSchema memory buffer. 
+ */ + public void save(Snapshot snapshot) { + directBuffer().putLong(snapshot.format).putLong(snapshot.name).putLong(snapshot.metadata).putLong(snapshot.flags) + .putLong(snapshot.n_children).putLong(snapshot.children).putLong(snapshot.dictionary).putLong(snapshot.release) + .putLong(snapshot.private_data); + } +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/BaseStruct.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/BaseStruct.java new file mode 100644 index 000000000..d90fe8175 --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/BaseStruct.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +/** + * Base interface for C Data Interface structures. + */ +public interface BaseStruct extends AutoCloseable { + /** + * Get memory address. + * + * @return Memory address + */ + long memoryAddress(); + + /** + * Call the release callback of an ArrowArray. + * <p> + * This function must not be called for child arrays. + */ + void release(); + + /** + * Close to release the main buffer. 
+ */ + @Override + void close(); +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java new file mode 100644 index 000000000..43bcda276 --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/CDataDictionaryProvider.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; + +/** + * A DictionaryProvider that is used in C Data Interface for imports. + * <p> + * CDataDictionaryProvider is similar to + * {@link DictionaryProvider.MapDictionaryProvider} with a key difference that + * the dictionaries are owned by the provider so it must eventually be closed. + * <p> + * The typical usage is to create the CDataDictionaryProvider and pass it to + * {@link Data#importField} or {@link Data#importSchema} to allocate empty + * dictionaries based on the information in {@link ArrowSchema}. 
Then you can + * re-use the same dictionary provider in any function that imports an + * {@link ArrowArray} that has the same schema. + */ +public class CDataDictionaryProvider implements DictionaryProvider, AutoCloseable { + + private final Map<Long, Dictionary> map; + + public CDataDictionaryProvider() { + this.map = new HashMap<>(); + } + + void put(Dictionary dictionary) { + Dictionary previous = map.put(dictionary.getEncoding().getId(), dictionary); + if (previous != null) { + previous.getVector().close(); + } + } + + public final Set<Long> getDictionaryIds() { + return map.keySet(); + } + + @Override + public Dictionary lookup(long id) { + return map.get(id); + } + + @Override + public void close() { + for (Dictionary dictionary : map.values()) { + dictionary.getVector().close(); + } + map.clear(); + } + +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/CDataReferenceManager.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/CDataReferenceManager.java new file mode 100644 index 000000000..c5c2f9779 --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/CDataReferenceManager.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.c; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.OwnershipTransferResult; +import org.apache.arrow.memory.ReferenceManager; +import org.apache.arrow.util.Preconditions; + +/** + * A ReferenceManager implementation that holds a + * {@link org.apache.arrow.c.BaseStruct}. + * <p> + * A reference count is maintained and once it reaches zero the struct is + * released (as per the C data interface specification) and closed. + */ +final class CDataReferenceManager implements ReferenceManager { + private final AtomicInteger bufRefCnt = new AtomicInteger(0); + + private final BaseStruct struct; + + CDataReferenceManager(BaseStruct struct) { + this.struct = struct; + } + + @Override + public int getRefCount() { + return bufRefCnt.get(); + } + + @Override + public boolean release() { + return release(1); + } + + /** + * Increment the reference count without any safety checks. 
+ */ + void increment() { + bufRefCnt.incrementAndGet(); + } + + @Override + public boolean release(int decrement) { + Preconditions.checkState(decrement >= 1, "ref count decrement should be greater than or equal to 1"); + // decrement the ref count + final int refCnt = bufRefCnt.addAndGet(-decrement); + // the new ref count should be >= 0 + Preconditions.checkState(refCnt >= 0, "ref count has gone negative"); + if (refCnt == 0) { + // refcount of this reference manager has dropped to 0 + // release the underlying memory + struct.release(); + struct.close(); + } + return refCnt == 0; + } + + @Override + public void retain() { + retain(1); + } + + @Override + public void retain(int increment) { + Preconditions.checkArgument(increment > 0, "retain(%s) argument is not positive", increment); + final int originalReferenceCount = bufRefCnt.getAndAdd(increment); + Preconditions.checkState(originalReferenceCount > 0, "retain called but memory was already released"); + } + + @Override + public ArrowBuf retain(ArrowBuf srcBuffer, BufferAllocator targetAllocator) { + retain(); + + ArrowBuf targetArrowBuf = this.deriveBuffer(srcBuffer, 0, srcBuffer.capacity()); + targetArrowBuf.readerIndex(srcBuffer.readerIndex()); + targetArrowBuf.writerIndex(srcBuffer.writerIndex()); + return targetArrowBuf; + } + + @Override + public ArrowBuf deriveBuffer(ArrowBuf sourceBuffer, long index, long length) { + final long derivedBufferAddress = sourceBuffer.memoryAddress() + index; + return new ArrowBuf(this, null, length, derivedBufferAddress); + } + + @Override + public OwnershipTransferResult transferOwnership(ArrowBuf sourceBuffer, BufferAllocator targetAllocator) { + throw new UnsupportedOperationException(); + } + + @Override + public BufferAllocator getAllocator() { + return null; + } + + @Override + public long getSize() { + return 0L; + } + + @Override + public long getAccountedSize() { + return 0L; + } +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/Data.java 
b/src/arrow/java/c/src/main/java/org/apache/arrow/c/Data.java new file mode 100644 index 000000000..27b0ce4bf --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/Data.java @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.StructVectorLoader; +import org.apache.arrow.vector.StructVectorUnloader; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Functions for working with the C data interface. + * <p> + * This API is EXPERIMENTAL. 
Note that currently only 64bit systems are + * supported. + */ +public final class Data { + + private Data() { + } + + /** + * Export Java Field using the C data interface format. + * + * @param allocator Buffer allocator for allocating C data interface fields + * @param field Field object to export + * @param provider Dictionary provider for dictionary encoded fields (optional) + * @param out C struct where to export the field + */ + public static void exportField(BufferAllocator allocator, Field field, DictionaryProvider provider, ArrowSchema out) { + SchemaExporter exporter = new SchemaExporter(allocator); + exporter.export(out, field, provider); + } + + /** + * Export Java Schema using the C data interface format. + * + * @param allocator Buffer allocator for allocating C data interface fields + * @param schema Schema object to export + * @param provider Dictionary provider for dictionary encoded fields (optional) + * @param out C struct where to export the field + */ + public static void exportSchema(BufferAllocator allocator, Schema schema, DictionaryProvider provider, + ArrowSchema out) { + // Convert to a struct field equivalent to the input schema + FieldType fieldType = new FieldType(false, new ArrowType.Struct(), null, schema.getCustomMetadata()); + Field field = new Field("", fieldType, schema.getFields()); + exportField(allocator, field, provider, out); + } + + /** + * Export Java FieldVector using the C data interface format. + * <p> + * The resulting ArrowArray struct keeps the array data and buffers alive until + * its release callback is called by the consumer. 
+ * + * @param allocator Buffer allocator for allocating C data interface fields + * @param vector Vector object to export + * @param provider Dictionary provider for dictionary encoded vectors + * (optional) + * @param out C struct where to export the array + */ + public static void exportVector(BufferAllocator allocator, FieldVector vector, DictionaryProvider provider, + ArrowArray out) { + exportVector(allocator, vector, provider, out, null); + } + + /** + * Export Java FieldVector using the C data interface format. + * <p> + * The resulting ArrowArray struct keeps the array data and buffers alive until + * its release callback is called by the consumer. + * + * @param allocator Buffer allocator for allocating C data interface fields + * @param vector Vector object to export + * @param provider Dictionary provider for dictionary encoded vectors + * (optional) + * @param out C struct where to export the array + * @param outSchema C struct where to export the array type (optional) + */ + public static void exportVector(BufferAllocator allocator, FieldVector vector, DictionaryProvider provider, + ArrowArray out, ArrowSchema outSchema) { + if (outSchema != null) { + exportField(allocator, vector.getField(), provider, outSchema); + } + + ArrayExporter exporter = new ArrayExporter(allocator); + exporter.export(out, vector, provider); + } + + /** + * Export the current contents of a Java VectorSchemaRoot using the C data + * interface format. + * <p> + * The vector schema root is exported as if it were a struct array. The + * resulting ArrowArray struct keeps the record batch data and buffers alive + * until its release callback is called by the consumer. 
+ * + * @param allocator Buffer allocator for allocating C data interface fields + * @param vsr Vector schema root to export + * @param provider Dictionary provider for dictionary encoded vectors + * (optional) + * @param out C struct where to export the record batch + */ + public static void exportVectorSchemaRoot(BufferAllocator allocator, VectorSchemaRoot vsr, + DictionaryProvider provider, ArrowArray out) { + exportVectorSchemaRoot(allocator, vsr, provider, out, null); + } + + /** + * Export the current contents of a Java VectorSchemaRoot using the C data + * interface format. + * <p> + * The vector schema root is exported as if it were a struct array. The + * resulting ArrowArray struct keeps the record batch data and buffers alive + * until its release callback is called by the consumer. + * + * @param allocator Buffer allocator for allocating C data interface fields + * @param vsr Vector schema root to export + * @param provider Dictionary provider for dictionary encoded vectors + * (optional) + * @param out C struct where to export the record batch + * @param outSchema C struct where to export the record batch schema (optional) + */ + public static void exportVectorSchemaRoot(BufferAllocator allocator, VectorSchemaRoot vsr, + DictionaryProvider provider, ArrowArray out, ArrowSchema outSchema) { + if (outSchema != null) { + exportSchema(allocator, vsr.getSchema(), provider, outSchema); + } + + VectorUnloader unloader = new VectorUnloader(vsr); + try (ArrowRecordBatch recordBatch = unloader.getRecordBatch()) { + StructVectorLoader loader = new StructVectorLoader(vsr.getSchema()); + try (StructVector vector = loader.load(allocator, recordBatch)) { + exportVector(allocator, vector, provider, out); + } + } + } + + /** + * Import Java Field from the C data interface. + * <p> + * The given ArrowSchema struct is released (as per the C data interface + * specification), even if this function fails. 
+ * + * @param allocator Buffer allocator for allocating dictionary vectors + * @param schema C data interface struct representing the field [inout] + * @param provider A dictionary provider will be initialized with empty + * dictionary vectors (optional) + * @return Imported field object + */ + public static Field importField(BufferAllocator allocator, ArrowSchema schema, CDataDictionaryProvider provider) { + try { + SchemaImporter importer = new SchemaImporter(allocator); + return importer.importField(schema, provider); + } finally { + schema.release(); + schema.close(); + } + } + + /** + * Import Java Schema from the C data interface. + * <p> + * The given ArrowSchema struct is released (as per the C data interface + * specification), even if this function fails. + * + * @param allocator Buffer allocator for allocating dictionary vectors + * @param schema C data interface struct representing the field + * @param provider A dictionary provider will be initialized with empty + * dictionary vectors (optional) + * @return Imported schema object + */ + public static Schema importSchema(BufferAllocator allocator, ArrowSchema schema, CDataDictionaryProvider provider) { + Field structField = importField(allocator, schema, provider); + if (structField.getType().getTypeID() != ArrowTypeID.Struct) { + throw new IllegalArgumentException("Cannot import schema: ArrowSchema describes non-struct type"); + } + return new Schema(structField.getChildren(), structField.getMetadata()); + } + + /** + * Import Java vector from the C data interface. + * <p> + * The ArrowArray struct has its contents moved (as per the C data interface + * specification) to a private object held alive by the resulting array. 
+ * + * @param allocator Buffer allocator + * @param array C data interface struct holding the array data + * @param vector Imported vector object [out] + * @param provider Dictionary provider to load dictionary vectors to (optional) + */ + public static void importIntoVector(BufferAllocator allocator, ArrowArray array, FieldVector vector, + DictionaryProvider provider) { + ArrayImporter importer = new ArrayImporter(allocator, vector, provider); + importer.importArray(array); + } + + /** + * Import Java vector and its type from the C data interface. + * <p> + * The ArrowArray struct has its contents moved (as per the C data interface + * specification) to a private object held alive by the resulting vector. The + * ArrowSchema struct is released, even if this function fails. + * + * @param allocator Buffer allocator for allocating the output FieldVector + * @param array C data interface struct holding the array data + * @param schema C data interface struct holding the array type + * @param provider Dictionary provider to load dictionary vectors to (optional) + * @return Imported vector object + */ + public static FieldVector importVector(BufferAllocator allocator, ArrowArray array, ArrowSchema schema, + CDataDictionaryProvider provider) { + Field field = importField(allocator, schema, provider); + FieldVector vector = field.createVector(allocator); + importIntoVector(allocator, array, vector, provider); + return vector; + } + + /** + * Import record batch from the C data interface into vector schema root. + * + * The ArrowArray struct has its contents moved (as per the C data interface + * specification) to a private object held alive by the resulting vector schema + * root. + * + * The schema of the vector schema root must match the input array (undefined + * behavior otherwise). 
+ * + * @param allocator Buffer allocator + * @param array C data interface struct holding the record batch data + * @param root vector schema root to load into + * @param provider Dictionary provider to load dictionary vectors to (optional) + */ + public static void importIntoVectorSchemaRoot(BufferAllocator allocator, ArrowArray array, VectorSchemaRoot root, + DictionaryProvider provider) { + try (StructVector structVector = StructVector.empty("", allocator)) { + structVector.initializeChildrenFromFields(root.getSchema().getFields()); + importIntoVector(allocator, array, structVector, provider); + StructVectorUnloader unloader = new StructVectorUnloader(structVector); + VectorLoader loader = new VectorLoader(root); + try (ArrowRecordBatch recordBatch = unloader.getRecordBatch()) { + loader.load(recordBatch); + } + } + } + + /** + * Import Java vector schema root from a C data interface Schema. + * + * The type represented by the ArrowSchema struct must be a struct type array. + * + * The ArrowSchema struct is released, even if this function fails. + * + * @param allocator Buffer allocator for allocating the output VectorSchemaRoot + * @param schema C data interface struct holding the record batch schema + * @param provider Dictionary provider to load dictionary vectors to (optional) + * @return Imported vector schema root + */ + public static VectorSchemaRoot importVectorSchemaRoot(BufferAllocator allocator, ArrowSchema schema, + CDataDictionaryProvider provider) { + return importVectorSchemaRoot(allocator, null, schema, provider); + } + + /** + * Import Java vector schema root from the C data interface. + * + * The type represented by the ArrowSchema struct must be a struct type array. + * + * The ArrowArray struct has its contents moved (as per the C data interface + * specification) to a private object held alive by the resulting record batch. + * The ArrowSchema struct is released, even if this function fails. 
+ * + * Prefer {@link #importIntoVectorSchemaRoot} for loading array data while + * reusing the same vector schema root. + * + * @param allocator Buffer allocator for allocating the output VectorSchemaRoot + * @param array C data interface struct holding the record batch data + * (optional) + * @param schema C data interface struct holding the record batch schema + * @param provider Dictionary provider to load dictionary vectors to (optional) + * @return Imported vector schema root + */ + public static VectorSchemaRoot importVectorSchemaRoot(BufferAllocator allocator, ArrowArray array, ArrowSchema schema, + CDataDictionaryProvider provider) { + VectorSchemaRoot vsr = VectorSchemaRoot.create(importSchema(allocator, schema, provider), allocator); + if (array != null) { + importIntoVectorSchemaRoot(allocator, array, vsr, provider); + } + return vsr; + } +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/Flags.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/Flags.java new file mode 100644 index 000000000..744b4695a --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/Flags.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.c; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; +import org.apache.arrow.vector.types.pojo.Field; + +/** + * Flags as defined in the C data interface specification. + */ +final class Flags { + static final int ARROW_FLAG_DICTIONARY_ORDERED = 1; + static final int ARROW_FLAG_NULLABLE = 2; + static final int ARROW_FLAG_MAP_KEYS_SORTED = 4; + + private Flags() { + } + + static long forField(Field field) { + long flags = 0L; + if (field.isNullable()) { + flags |= ARROW_FLAG_NULLABLE; + } + if (field.getDictionary() != null && field.getDictionary().isOrdered()) { + flags |= ARROW_FLAG_DICTIONARY_ORDERED; + } + if (field.getType().getTypeID() == ArrowTypeID.Map) { + ArrowType.Map map = (ArrowType.Map) field.getType(); + if (map.getKeysSorted()) { + flags |= ARROW_FLAG_MAP_KEYS_SORTED; + } + } + return flags; + } +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/Format.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/Format.java new file mode 100644 index 000000000..315d3caad --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/Format.java @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import java.util.Arrays; +import java.util.stream.Collectors; + +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.UnionMode; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; + +/** + * Conversion between {@link ArrowType} and string formats, as per C data + * interface specification. + */ +final class Format { + + private Format() { + } + + static String asString(ArrowType arrowType) { + if (arrowType instanceof ExtensionType) { + ArrowType innerType = ((ExtensionType) arrowType).storageType(); + return asString(innerType); + } + + switch (arrowType.getTypeID()) { + case Binary: + return "z"; + case Bool: + return "b"; + case Date: { + ArrowType.Date type = (ArrowType.Date) arrowType; + switch (type.getUnit()) { + case DAY: + return "tdD"; + case MILLISECOND: + return "tdm"; + default: + throw new UnsupportedOperationException( + String.format("Date type with unit %s is unsupported", type.getUnit())); + } + } + case Decimal: { + ArrowType.Decimal type = (ArrowType.Decimal) arrowType; + if (type.getBitWidth() == 128) { + return String.format("d:%d,%d", type.getPrecision(), type.getScale()); + } + return String.format("d:%d,%d,%d", type.getPrecision(), type.getScale(), type.getBitWidth()); + } + case Duration: { + ArrowType.Duration type = (ArrowType.Duration) arrowType; + switch (type.getUnit()) { + case SECOND: + return "tDs"; + case MILLISECOND: + return "tDm"; + case MICROSECOND: + return "tDu"; + case NANOSECOND: + return "tDn"; + default: + throw new UnsupportedOperationException( + String.format("Duration 
type with unit %s is unsupported", type.getUnit())); + } + } + case FixedSizeBinary: { + ArrowType.FixedSizeBinary type = (ArrowType.FixedSizeBinary) arrowType; + return String.format("w:%d", type.getByteWidth()); + } + case FixedSizeList: { + ArrowType.FixedSizeList type = (ArrowType.FixedSizeList) arrowType; + return String.format("+w:%d", type.getListSize()); + } + case FloatingPoint: { + ArrowType.FloatingPoint type = (ArrowType.FloatingPoint) arrowType; + switch (type.getPrecision()) { + case HALF: + return "e"; + case SINGLE: + return "f"; + case DOUBLE: + return "g"; + default: + throw new UnsupportedOperationException( + String.format("FloatingPoint type with precision %s is unsupported", type.getPrecision())); + } + } + case Int: { + String format; + ArrowType.Int type = (ArrowType.Int) arrowType; + switch (type.getBitWidth()) { + case Byte.SIZE: + format = "C"; + break; + case Short.SIZE: + format = "S"; + break; + case Integer.SIZE: + format = "I"; + break; + case Long.SIZE: + format = "L"; + break; + default: + throw new UnsupportedOperationException( + String.format("Int type with bitwidth %d is unsupported", type.getBitWidth())); + } + if (type.getIsSigned()) { + format = format.toLowerCase(); + } + return format; + } + case Interval: { + ArrowType.Interval type = (ArrowType.Interval) arrowType; + switch (type.getUnit()) { + case DAY_TIME: + return "tiD"; + case YEAR_MONTH: + return "tiM"; + default: + throw new UnsupportedOperationException( + String.format("Interval type with unit %s is unsupported", type.getUnit())); + } + } + case LargeBinary: + return "Z"; + case LargeList: + return "+L"; + case LargeUtf8: + return "U"; + case List: + return "+l"; + case Map: + return "+m"; + case Null: + return "n"; + case Struct: + return "+s"; + case Time: { + ArrowType.Time type = (ArrowType.Time) arrowType; + if (type.getUnit() == TimeUnit.SECOND && type.getBitWidth() == 32) { + return "tts"; + } else if (type.getUnit() == TimeUnit.MILLISECOND && 
type.getBitWidth() == 32) { + return "ttm"; + } else if (type.getUnit() == TimeUnit.MICROSECOND && type.getBitWidth() == 64) { + return "ttu"; + } else if (type.getUnit() == TimeUnit.NANOSECOND && type.getBitWidth() == 64) { + return "ttn"; + } else { + throw new UnsupportedOperationException(String.format("Time type with unit %s and bitwidth %d is unsupported", + type.getUnit(), type.getBitWidth())); + } + } + case Timestamp: { + String format; + ArrowType.Timestamp type = (ArrowType.Timestamp) arrowType; + switch (type.getUnit()) { + case SECOND: + format = "tss"; + break; + case MILLISECOND: + format = "tsm"; + break; + case MICROSECOND: + format = "tsu"; + break; + case NANOSECOND: + format = "tsn"; + break; + default: + throw new UnsupportedOperationException( + String.format("Timestamp type with unit %s is unsupported", type.getUnit())); + } + String timezone = type.getTimezone(); + return String.format("%s:%s", format, timezone == null ? "" : timezone); + } + case Union: + ArrowType.Union type = (ArrowType.Union) arrowType; + String typeIDs = Arrays.stream(type.getTypeIds()).mapToObj(String::valueOf).collect(Collectors.joining(",")); + switch (type.getMode()) { + case Dense: + return String.format("+ud:%s", typeIDs); + case Sparse: + return String.format("+us:%s", typeIDs); + default: + throw new UnsupportedOperationException( + String.format("Union type with mode %s is unsupported", type.getMode())); + } + case Utf8: + return "u"; + case NONE: + throw new IllegalArgumentException("Arrow type ID is NONE"); + default: + throw new UnsupportedOperationException(String.format("Unknown type id %s", arrowType.getTypeID())); + } + } + + static ArrowType asType(String format, long flags) + throws NumberFormatException, UnsupportedOperationException, IllegalStateException { + switch (format) { + case "n": + return new ArrowType.Null(); + case "b": + return new ArrowType.Bool(); + case "c": + return new ArrowType.Int(8, true); + case "C": + return new ArrowType.Int(8, 
false); + case "s": + return new ArrowType.Int(16, true); + case "S": + return new ArrowType.Int(16, false); + case "i": + return new ArrowType.Int(32, true); + case "I": + return new ArrowType.Int(32, false); + case "l": + return new ArrowType.Int(64, true); + case "L": + return new ArrowType.Int(64, false); + case "e": + return new ArrowType.FloatingPoint(FloatingPointPrecision.HALF); + case "f": + return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + case "g": + return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case "z": + return new ArrowType.Binary(); + case "Z": + return new ArrowType.LargeBinary(); + case "u": + return new ArrowType.Utf8(); + case "U": + return new ArrowType.LargeUtf8(); + case "tdD": + return new ArrowType.Date(DateUnit.DAY); + case "tdm": + return new ArrowType.Date(DateUnit.MILLISECOND); + case "tts": + return new ArrowType.Time(TimeUnit.SECOND, 32); + case "ttm": + return new ArrowType.Time(TimeUnit.MILLISECOND, 32); + case "ttu": + return new ArrowType.Time(TimeUnit.MICROSECOND, 64); + case "ttn": + return new ArrowType.Time(TimeUnit.NANOSECOND, 64); + case "tDs": + return new ArrowType.Duration(TimeUnit.SECOND); + case "tDm": + return new ArrowType.Duration(TimeUnit.MILLISECOND); + case "tDu": + return new ArrowType.Duration(TimeUnit.MICROSECOND); + case "tDn": + return new ArrowType.Duration(TimeUnit.NANOSECOND); + case "tiM": + return new ArrowType.Interval(IntervalUnit.YEAR_MONTH); + case "tiD": + return new ArrowType.Interval(IntervalUnit.DAY_TIME); + case "+l": + return new ArrowType.List(); + case "+L": + return new ArrowType.LargeList(); + case "+s": + return new ArrowType.Struct(); + case "+m": + boolean keysSorted = (flags & Flags.ARROW_FLAG_MAP_KEYS_SORTED) != 0; + return new ArrowType.Map(keysSorted); + default: + String[] parts = format.split(":", 2); + if (parts.length == 2) { + return parseComplexFormat(parts[0], parts[1]); + } + throw new 
UnsupportedOperationException(String.format("Format %s is not supported", format)); + } + } + + private static ArrowType parseComplexFormat(String format, String payload) + throws NumberFormatException, UnsupportedOperationException, IllegalStateException { + switch (format) { + case "d": { + int[] parts = payloadToIntArray(payload); + Preconditions.checkState(parts.length == 2 || parts.length == 3, "Format %s:%s is illegal", format, payload); + int precision = parts[0]; + int scale = parts[1]; + Integer bitWidth = (parts.length == 3) ? parts[2] : null; + return ArrowType.Decimal.createDecimal(precision, scale, bitWidth); + } + case "w": + return new ArrowType.FixedSizeBinary(Integer.parseInt(payload)); + case "+w": + return new ArrowType.FixedSizeList(Integer.parseInt(payload)); + case "+ud": + return new ArrowType.Union(UnionMode.Dense, payloadToIntArray(payload)); + case "+us": + return new ArrowType.Union(UnionMode.Sparse, payloadToIntArray(payload)); + case "tss": + return new ArrowType.Timestamp(TimeUnit.SECOND, payloadToTimezone(payload)); + case "tsm": + return new ArrowType.Timestamp(TimeUnit.MILLISECOND, payloadToTimezone(payload)); + case "tsu": + return new ArrowType.Timestamp(TimeUnit.MICROSECOND, payloadToTimezone(payload)); + case "tsn": + return new ArrowType.Timestamp(TimeUnit.NANOSECOND, payloadToTimezone(payload)); + default: + throw new UnsupportedOperationException(String.format("Format %s:%s is not supported", format, payload)); + } + } + + private static int[] payloadToIntArray(String payload) throws NumberFormatException { + return Arrays.stream(payload.split(",")).mapToInt(Integer::parseInt).toArray(); + } + + private static String payloadToTimezone(String payload) { + if (payload.isEmpty()) { + return null; + } + return payload; + } +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/Metadata.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/Metadata.java new file mode 100644 index 000000000..b81b24fe4 --- /dev/null 
+++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/Metadata.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import static org.apache.arrow.c.NativeUtil.NULL; +import static org.apache.arrow.util.Preconditions.checkState; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.util.MemoryUtil; + +/** + * Encode and decode metadata. 
+ */ +final class Metadata { + + private Metadata() { + } + + static ArrowBuf encode(BufferAllocator allocator, Map<String, String> metadata) { + if (metadata == null || metadata.size() == 0) { + return null; + } + + List<byte[]> buffers = new ArrayList<>(metadata.size() * 2); + int totalSize = 4 + metadata.size() * 8; // number of key/value pairs + buffer length fields + for (Map.Entry<String, String> entry : metadata.entrySet()) { + byte[] keyBuffer = entry.getKey().getBytes(StandardCharsets.UTF_8); + byte[] valueBuffer = entry.getValue().getBytes(StandardCharsets.UTF_8); + totalSize += keyBuffer.length; + totalSize += valueBuffer.length; + buffers.add(keyBuffer); + buffers.add(valueBuffer); + } + + ArrowBuf result = allocator.buffer(totalSize); + ByteBuffer writer = MemoryUtil.directBuffer(result.memoryAddress(), totalSize).order(ByteOrder.nativeOrder()); + writer.putInt(metadata.size()); + for (byte[] buffer : buffers) { + writer.putInt(buffer.length); + writer.put(buffer); + } + return result.slice(0, totalSize); + } + + static Map<String, String> decode(long bufferAddress) { + if (bufferAddress == NULL) { + return null; + } + + ByteBuffer reader = MemoryUtil.directBuffer(bufferAddress, Integer.MAX_VALUE).order(ByteOrder.nativeOrder()); + + int size = reader.getInt(); + checkState(size >= 0, "Metadata size must not be negative"); + if (size == 0) { + return null; + } + + Map<String, String> result = new HashMap<>(size); + for (int i = 0; i < size; i++) { + String key = readString(reader); + String value = readString(reader); + result.put(key, value); + } + return result; + } + + private static String readString(ByteBuffer reader) { + int length = reader.getInt(); + checkState(length >= 0, "Metadata item length must not be negative"); + String result = ""; + if (length > 0) { + byte[] dst = new byte[length]; + reader.get(dst); + result = new String(dst, StandardCharsets.UTF_8); + } + return result; + } +} diff --git 
a/src/arrow/java/c/src/main/java/org/apache/arrow/c/NativeUtil.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/NativeUtil.java new file mode 100644 index 000000000..e2feda1e5 --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/NativeUtil.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.util.MemoryUtil; + +/** + * Utility functions for working with native memory. + */ +public final class NativeUtil { + public static final byte NULL = 0; + static final int MAX_STRING_LENGTH = Short.MAX_VALUE; + + private NativeUtil() { + } + + /** + * Convert a pointer to a null terminated string into a Java String. 
+ * + * @param cstringPtr pointer to C string + * @return Converted string + */ + public static String toJavaString(long cstringPtr) { + if (cstringPtr == NULL) { + return null; + } + ByteBuffer reader = MemoryUtil.directBuffer(cstringPtr, MAX_STRING_LENGTH).order(ByteOrder.nativeOrder()); + + int length = 0; + while (reader.get() != NULL) { + length++; + } + byte[] bytes = new byte[length]; + ((ByteBuffer) reader.rewind()).get(bytes); + return new String(bytes, 0, length, StandardCharsets.UTF_8); + } + + /** + * Convert a native array pointer (void**) to Java array of pointers. + * + * @param arrayPtr Array pointer + * @param size Array size + * @return Array of pointer values as longs + */ + public static long[] toJavaArray(long arrayPtr, int size) { + if (arrayPtr == NULL) { + return null; + } + if (size < 0) { + throw new IllegalArgumentException("Invalid native array size"); + } + + long[] result = new long[size]; + ByteBuffer reader = MemoryUtil.directBuffer(arrayPtr, Long.BYTES * size).order(ByteOrder.nativeOrder()); + for (int i = 0; i < size; i++) { + result[i] = reader.getLong(); + } + return result; + } + + /** + * Convert Java string to a null terminated string. + * + * @param allocator Buffer allocator for allocating the native string + * @param string Input String to convert + * @return Buffer with a null terminated string or null if the input is null + */ + public static ArrowBuf toNativeString(BufferAllocator allocator, String string) { + if (string == null) { + return null; + } + + byte[] bytes = string.getBytes(StandardCharsets.UTF_8); + ArrowBuf buffer = allocator.buffer(bytes.length + 1); + buffer.writeBytes(bytes); + buffer.writeByte(NULL); + return buffer; + } + + /** + * Close a buffer if it's not null. + * + * @param buf Buffer to close + */ + public static void closeBuffer(ArrowBuf buf) { + if (buf != null) { + buf.close(); + } + } + + /** + * Get the address of a buffer or {@value #NULL} if the input buffer is null. 
+ * + * @param buf Buffer to get the address of + * @return Memory addresss or {@value #NULL} + */ + public static long addressOrNull(ArrowBuf buf) { + if (buf == null) { + return NULL; + } + return buf.memoryAddress(); + } + + /** + * Get the address of a C Data Interface struct or {@value #NULL} if the input + * struct is null. + * + * @param struct C Data Interface struct to get the address of + * @return Memory addresss or {@value #NULL} + */ + public static long addressOrNull(BaseStruct struct) { + if (struct == null) { + return NULL; + } + return struct.memoryAddress(); + } + +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/SchemaExporter.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/SchemaExporter.java new file mode 100644 index 000000000..04d41a4e4 --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/SchemaExporter.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.c; + +import static org.apache.arrow.c.NativeUtil.NULL; +import static org.apache.arrow.c.NativeUtil.addressOrNull; +import static org.apache.arrow.util.Preconditions.checkNotNull; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.c.jni.JniWrapper; +import org.apache.arrow.c.jni.PrivateData; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; + +/** + * Exporter for {@link ArrowSchema}. + */ +final class SchemaExporter { + private final BufferAllocator allocator; + + public SchemaExporter(BufferAllocator allocator) { + this.allocator = allocator; + } + + /** + * Private data structure for exported schemas. + */ + static class ExportedSchemaPrivateData implements PrivateData { + ArrowBuf format; + ArrowBuf name; + ArrowBuf metadata; + ArrowBuf children_ptrs; + ArrowSchema dictionary; + List<ArrowSchema> children; + + @Override + public void close() { + NativeUtil.closeBuffer(format); + NativeUtil.closeBuffer(name); + NativeUtil.closeBuffer(metadata); + NativeUtil.closeBuffer(children_ptrs); + if (dictionary != null) { + dictionary.close(); + } + if (children != null) { + for (ArrowSchema child : children) { + child.close(); + } + } + } + } + + void export(ArrowSchema schema, Field field, DictionaryProvider dictionaryProvider) { + String name = field.getName(); + String format = Format.asString(field.getType()); + long flags = Flags.forField(field); + List<Field> children = field.getChildren(); + DictionaryEncoding dictionaryEncoding = field.getDictionary(); + + ExportedSchemaPrivateData data = new ExportedSchemaPrivateData(); + try { + data.format = NativeUtil.toNativeString(allocator, format); + data.name = 
NativeUtil.toNativeString(allocator, name); + data.metadata = Metadata.encode(allocator, field.getMetadata()); + + if (children != null) { + data.children = new ArrayList<>(children.size()); + data.children_ptrs = allocator.buffer((long) children.size() * Long.BYTES); + for (int i = 0; i < children.size(); i++) { + ArrowSchema child = ArrowSchema.allocateNew(allocator); + data.children.add(child); + data.children_ptrs.writeLong(child.memoryAddress()); + } + } + + if (dictionaryEncoding != null) { + Dictionary dictionary = dictionaryProvider.lookup(dictionaryEncoding.getId()); + checkNotNull(dictionary, "Dictionary lookup failed on export of field with dictionary"); + + data.dictionary = ArrowSchema.allocateNew(allocator); + export(data.dictionary, dictionary.getVector().getField(), dictionaryProvider); + } + + ArrowSchema.Snapshot snapshot = new ArrowSchema.Snapshot(); + snapshot.format = data.format.memoryAddress(); + snapshot.name = addressOrNull(data.name); + snapshot.metadata = addressOrNull(data.metadata); + snapshot.flags = flags; + snapshot.n_children = (data.children != null) ? 
data.children.size() : 0; + snapshot.children = addressOrNull(data.children_ptrs); + snapshot.dictionary = addressOrNull(data.dictionary); + snapshot.release = NULL; + schema.save(snapshot); + + // sets release and private data + JniWrapper.get().exportSchema(schema.memoryAddress(), data); + } catch (Exception e) { + data.close(); + throw e; + } + + // Export children + if (children != null) { + for (int i = 0; i < children.size(); i++) { + Field childField = children.get(i); + ArrowSchema child = data.children.get(i); + export(child, childField, dictionaryProvider); + } + } + } +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/SchemaImporter.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/SchemaImporter.java new file mode 100644 index 000000000..21d88f6cd --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/SchemaImporter.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.c; + +import static org.apache.arrow.c.NativeUtil.NULL; +import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; +import static org.apache.arrow.util.Preconditions.checkNotNull; +import static org.apache.arrow.util.Preconditions.checkState; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.ExtensionTypeRegistry; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Importer for {@link ArrowSchema}. + */ +final class SchemaImporter { + private static final Logger logger = LoggerFactory.getLogger(SchemaImporter.class); + + private static final int MAX_IMPORT_RECURSION_LEVEL = 64; + private long nextDictionaryID = 1L; + + private final BufferAllocator allocator; + + public SchemaImporter(BufferAllocator allocator) { + this.allocator = allocator; + } + + Field importField(ArrowSchema schema, CDataDictionaryProvider provider) { + return importField(schema, provider, 0); + } + + private Field importField(ArrowSchema schema, CDataDictionaryProvider provider, int recursionLevel) { + checkState(recursionLevel <= MAX_IMPORT_RECURSION_LEVEL, "Recursion level in ArrowSchema struct exceeded"); + + ArrowSchema.Snapshot snapshot = schema.snapshot(); + checkState(snapshot.release != NULL, "Cannot import released ArrowSchema"); + + String name = NativeUtil.toJavaString(snapshot.name); + String format = NativeUtil.toJavaString(snapshot.format); + checkNotNull(format, "format field must not be null"); + ArrowType arrowType = Format.asType(format, 
snapshot.flags); + boolean nullable = (snapshot.flags & Flags.ARROW_FLAG_NULLABLE) != 0; + Map<String, String> metadata = Metadata.decode(snapshot.metadata); + + if (metadata != null && metadata.containsKey(ExtensionType.EXTENSION_METADATA_KEY_NAME)) { + final String extensionName = metadata.get(ExtensionType.EXTENSION_METADATA_KEY_NAME); + final String extensionMetadata = metadata.getOrDefault(ExtensionType.EXTENSION_METADATA_KEY_METADATA, ""); + ExtensionType extensionType = ExtensionTypeRegistry.lookup(extensionName); + if (extensionType != null) { + arrowType = extensionType.deserialize(arrowType, extensionMetadata); + } else { + // Otherwise, we haven't registered the type + logger.info("Unrecognized extension type: {}", extensionName); + } + } + + // Handle dictionary encoded vectors + DictionaryEncoding dictionaryEncoding = null; + if (snapshot.dictionary != NULL && provider != null) { + boolean ordered = (snapshot.flags & Flags.ARROW_FLAG_DICTIONARY_ORDERED) != 0; + ArrowType.Int indexType = (ArrowType.Int) arrowType; + dictionaryEncoding = new DictionaryEncoding(nextDictionaryID++, ordered, indexType); + + ArrowSchema dictionarySchema = ArrowSchema.wrap(snapshot.dictionary); + Field dictionaryField = importField(dictionarySchema, provider, recursionLevel + 1); + provider.put(new Dictionary(dictionaryField.createVector(allocator), dictionaryEncoding)); + } + + FieldType fieldType = new FieldType(nullable, arrowType, dictionaryEncoding, metadata); + + List<Field> children = null; + long[] childrenIds = NativeUtil.toJavaArray(snapshot.children, checkedCastToInt(snapshot.n_children)); + if (childrenIds != null && childrenIds.length > 0) { + children = new ArrayList<>(childrenIds.length); + for (long childAddress : childrenIds) { + ArrowSchema childSchema = ArrowSchema.wrap(childAddress); + Field field = importField(childSchema, provider, recursionLevel + 1); + children.add(field); + } + } + return new Field(name, fieldType, children); + } +} diff --git 
a/src/arrow/java/c/src/main/java/org/apache/arrow/c/jni/JniLoader.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/jni/JniLoader.java new file mode 100644 index 000000000..bd2008f05 --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/jni/JniLoader.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c.jni; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * The JniLoader for C Data Interface API's native implementation. 
+ */ +public class JniLoader { + private static final JniLoader INSTANCE = new JniLoader(Collections.singletonList("arrow_cdata_jni")); + + public static JniLoader get() { + return INSTANCE; + } + + private final Set<String> librariesToLoad; + + private JniLoader(List<String> libraryNames) { + librariesToLoad = new HashSet<>(libraryNames); + } + + private boolean finished() { + return librariesToLoad.isEmpty(); + } + + /** + * If required JNI libraries are not loaded, then load them. + */ + public void ensureLoaded() { + if (finished()) { + return; + } + loadRemaining(); + } + + private synchronized void loadRemaining() { + // The method is protected by a mutex via synchronized, if more than one thread + // race to call + // loadRemaining, at same time only one will do the actual loading and the + // others will wait for + // the mutex to be acquired then check on the remaining list: if there are + // libraries that were not + // successfully loaded then the mutex owner will try to load them again. 
+ if (finished()) { + return; + } + List<String> libs = new ArrayList<>(librariesToLoad); + for (String lib : libs) { + load(lib); + librariesToLoad.remove(lib); + } + } + + private void load(String name) { + final String libraryToLoad = System.mapLibraryName(name); + try { + File temp = File.createTempFile("jnilib-", ".tmp", new File(System.getProperty("java.io.tmpdir"))); + try (final InputStream is = JniWrapper.class.getClassLoader().getResourceAsStream(libraryToLoad)) { + if (is == null) { + throw new FileNotFoundException(libraryToLoad); + } + Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING); + System.load(temp.getAbsolutePath()); + } + } catch (IOException e) { + throw new IllegalStateException("error loading native libraries: " + e); + } + } +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/jni/JniWrapper.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/jni/JniWrapper.java new file mode 100644 index 000000000..04a143a7a --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/jni/JniWrapper.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.c.jni; + +/** + * JniWrapper for C Data Interface API implementation. 
+ */ +public class JniWrapper { + private static final JniWrapper INSTANCE = new JniWrapper(); + + public static JniWrapper get() { + return INSTANCE; + } + + private JniWrapper() { + // A best effort to error on 32-bit systems + String dataModel = System.getProperty("sun.arch.data.model"); + if (dataModel != null && dataModel.equals("32")) { + throw new UnsupportedOperationException( + "The Java C Data Interface implementation is currently only supported on 64-bit systems"); + } + JniLoader.get().ensureLoaded(); + } + + public native void releaseSchema(long memoryAddress); + + public native void releaseArray(long memoryAddress); + + public native void exportSchema(long memoryAddress, PrivateData privateData); + + public native void exportArray(long memoryAddress, PrivateData data); +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/c/jni/PrivateData.java b/src/arrow/java/c/src/main/java/org/apache/arrow/c/jni/PrivateData.java new file mode 100644 index 000000000..e6336cc64 --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/c/jni/PrivateData.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.c.jni;
+
+import java.io.Closeable;
+
+/**
+ * Interface for Java objects stored in C data interface private data.
+ * <p>
+ * This interface is used for exported structures.
+ */
+public interface PrivateData extends Closeable {
+  /**
+   * Redeclares {@link Closeable#close()} without a {@code throws} clause, so callers do not have
+   * to handle a checked {@code IOException} when closing private data.
+   */
+  @Override
+  void close();
+}
diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/vector/StructVectorLoader.java b/src/arrow/java/c/src/main/java/org/apache/arrow/vector/StructVectorLoader.java
new file mode 100644
index 000000000..eab7e491f
--- /dev/null
+++ b/src/arrow/java/c/src/main/java/org/apache/arrow/vector/StructVectorLoader.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.vector; + +import static org.apache.arrow.util.Preconditions.checkArgument; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Collections2; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.compression.CompressionCodec; +import org.apache.arrow.vector.compression.CompressionUtil; +import org.apache.arrow.vector.compression.NoCompressionCodec; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Loads buffers into {@link StructVector}. + */ +public class StructVectorLoader { + + private final Schema schema; + private final CompressionCodec.Factory factory; + + /** + * A flag indicating if decompression is needed. This will affect the behavior + * of releasing buffers. + */ + private boolean decompressionNeeded; + + /** + * Construct with a schema. + * + * @param schema buffers are added based on schema. + */ + public StructVectorLoader(Schema schema) { + this(schema, NoCompressionCodec.Factory.INSTANCE); + } + + /** + * Construct with a schema and a compression codec factory. + * + * @param schema buffers are added based on schema. + * @param factory the factory to create codec. + */ + public StructVectorLoader(Schema schema, CompressionCodec.Factory factory) { + this.schema = schema; + this.factory = factory; + } + + /** + * Loads the record batch into the struct vector. + * + * <p> + * This will not close the record batch. 
+ * + * @param recordBatch the batch to load + */ + public StructVector load(BufferAllocator allocator, ArrowRecordBatch recordBatch) { + StructVector result = StructVector.empty("", allocator); + result.initializeChildrenFromFields(this.schema.getFields()); + + Iterator<ArrowBuf> buffers = recordBatch.getBuffers().iterator(); + Iterator<ArrowFieldNode> nodes = recordBatch.getNodes().iterator(); + CompressionUtil.CodecType codecType = CompressionUtil.CodecType + .fromCompressionType(recordBatch.getBodyCompression().getCodec()); + decompressionNeeded = codecType != CompressionUtil.CodecType.NO_COMPRESSION; + CompressionCodec codec = decompressionNeeded ? factory.createCodec(codecType) : NoCompressionCodec.INSTANCE; + for (FieldVector fieldVector : result.getChildrenFromFields()) { + loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec); + } + result.loadFieldBuffers(new ArrowFieldNode(recordBatch.getLength(), 0), Collections.singletonList(null)); + if (nodes.hasNext() || buffers.hasNext()) { + throw new IllegalArgumentException("not all nodes and buffers were consumed. nodes: " + + Collections2.toList(nodes).toString() + " buffers: " + Collections2.toList(buffers).toString()); + } + return result; + } + + private void loadBuffers(FieldVector vector, Field field, Iterator<ArrowBuf> buffers, Iterator<ArrowFieldNode> nodes, + CompressionCodec codec) { + checkArgument(nodes.hasNext(), "no more field nodes for for field %s and vector %s", field, vector); + ArrowFieldNode fieldNode = nodes.next(); + int bufferLayoutCount = TypeLayout.getTypeBufferCount(field.getType()); + List<ArrowBuf> ownBuffers = new ArrayList<>(bufferLayoutCount); + for (int j = 0; j < bufferLayoutCount; j++) { + ArrowBuf nextBuf = buffers.next(); + // for vectors without nulls, the buffer is empty, so there is no need to + // decompress it. + ArrowBuf bufferToAdd = nextBuf.writerIndex() > 0 ? 
codec.decompress(vector.getAllocator(), nextBuf) : nextBuf; + ownBuffers.add(bufferToAdd); + if (decompressionNeeded) { + // decompression performed + nextBuf.getReferenceManager().retain(); + } + } + try { + vector.loadFieldBuffers(fieldNode, ownBuffers); + if (decompressionNeeded) { + for (ArrowBuf buf : ownBuffers) { + buf.close(); + } + } + } catch (RuntimeException e) { + throw new IllegalArgumentException( + "Could not load buffers for field " + field + ". error message: " + e.getMessage(), e); + } + List<Field> children = field.getChildren(); + if (children.size() > 0) { + List<FieldVector> childrenFromFields = vector.getChildrenFromFields(); + checkArgument(children.size() == childrenFromFields.size(), + "should have as many children as in the schema: found %s expected %s", childrenFromFields.size(), + children.size()); + for (int i = 0; i < childrenFromFields.size(); i++) { + Field child = children.get(i); + FieldVector fieldVector = childrenFromFields.get(i); + loadBuffers(fieldVector, child, buffers, nodes, codec); + } + } + } +} diff --git a/src/arrow/java/c/src/main/java/org/apache/arrow/vector/StructVectorUnloader.java b/src/arrow/java/c/src/main/java/org/apache/arrow/vector/StructVectorUnloader.java new file mode 100644 index 000000000..e75156cf2 --- /dev/null +++ b/src/arrow/java/c/src/main/java/org/apache/arrow/vector/StructVectorUnloader.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.compression.CompressionCodec; +import org.apache.arrow.vector.compression.CompressionUtil; +import org.apache.arrow.vector.compression.NoCompressionCodec; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; + +/** + * Helper class that handles converting a {@link StructVector} to a + * {@link ArrowRecordBatch}. + */ +public class StructVectorUnloader { + + private final StructVector root; + private final boolean includeNullCount; + private final CompressionCodec codec; + private final boolean alignBuffers; + + /** + * Constructs a new instance of the given struct vector. + */ + public StructVectorUnloader(StructVector root) { + this(root, true, NoCompressionCodec.INSTANCE, true); + } + + /** + * Constructs a new instance. + * + * @param root The struct vector to serialize to an + * {@link ArrowRecordBatch}. + * @param includeNullCount Controls whether null count is copied to the + * {@link ArrowRecordBatch} + * @param alignBuffers Controls if buffers get aligned to 8-byte boundaries. + */ + public StructVectorUnloader(StructVector root, boolean includeNullCount, boolean alignBuffers) { + this(root, includeNullCount, NoCompressionCodec.INSTANCE, alignBuffers); + } + + /** + * Constructs a new instance. 
+ * + * @param root The struct vector to serialize to an + * {@link ArrowRecordBatch}. + * @param includeNullCount Controls whether null count is copied to the + * {@link ArrowRecordBatch} + * @param codec the codec for compressing data. If it is null, then + * no compression is needed. + * @param alignBuffers Controls if buffers get aligned to 8-byte boundaries. + */ + public StructVectorUnloader(StructVector root, boolean includeNullCount, CompressionCodec codec, + boolean alignBuffers) { + this.root = root; + this.includeNullCount = includeNullCount; + this.codec = codec; + this.alignBuffers = alignBuffers; + } + + /** + * Performs the depth first traversal of the Vectors to create an + * {@link ArrowRecordBatch} suitable for serialization. + */ + public ArrowRecordBatch getRecordBatch() { + List<ArrowFieldNode> nodes = new ArrayList<>(); + List<ArrowBuf> buffers = new ArrayList<>(); + for (FieldVector vector : root.getChildrenFromFields()) { + appendNodes(vector, nodes, buffers); + } + return new ArrowRecordBatch(root.getValueCount(), nodes, buffers, CompressionUtil.createBodyCompression(codec), + alignBuffers); + } + + private void appendNodes(FieldVector vector, List<ArrowFieldNode> nodes, List<ArrowBuf> buffers) { + nodes.add(new ArrowFieldNode(vector.getValueCount(), includeNullCount ? vector.getNullCount() : -1)); + List<ArrowBuf> fieldBuffers = vector.getFieldBuffers(); + int expectedBufferCount = TypeLayout.getTypeBufferCount(vector.getField().getType()); + if (fieldBuffers.size() != expectedBufferCount) { + throw new IllegalArgumentException(String.format("wrong number of buffers for field %s in vector %s. 
found: %s", + vector.getField(), vector.getClass().getSimpleName(), fieldBuffers)); + } + for (ArrowBuf buf : fieldBuffers) { + buffers.add(codec.compress(vector.getAllocator(), buf)); + } + for (FieldVector child : vector.getChildrenFromFields()) { + appendNodes(child, nodes, buffers); + } + } +} diff --git a/src/arrow/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java b/src/arrow/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java new file mode 100644 index 000000000..3f793f836 --- /dev/null +++ b/src/arrow/java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.c; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.util.Collections; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.CDataDictionaryProvider; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.compare.VectorEqualsVisitor; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class DictionaryTest { + private RootAllocator allocator = null; + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + + void roundtrip(FieldVector vector, DictionaryProvider provider, Class<?> clazz) { + // Consumer allocates empty structures + try (ArrowSchema consumerArrowSchema = ArrowSchema.allocateNew(allocator); + ArrowArray consumerArrowArray = ArrowArray.allocateNew(allocator)) { + + // Producer creates structures from existing memory pointers + try (ArrowSchema arrowSchema = 
ArrowSchema.wrap(consumerArrowSchema.memoryAddress()); + ArrowArray arrowArray = ArrowArray.wrap(consumerArrowArray.memoryAddress())) { + // Producer exports vector into the C Data Interface structures + Data.exportVector(allocator, vector, provider, arrowArray, arrowSchema); + } + + // Consumer imports vector + try (CDataDictionaryProvider cDictionaryProvider = new CDataDictionaryProvider(); + FieldVector imported = Data.importVector(allocator, consumerArrowArray, consumerArrowSchema, + cDictionaryProvider);) { + assertTrue(clazz.isInstance(imported), String.format("expected %s but was %s", clazz, imported.getClass())); + assertTrue(VectorEqualsVisitor.vectorEquals(vector, imported), "vectors are not equivalent"); + for (long id : cDictionaryProvider.getDictionaryIds()) { + ValueVector exportedDictionaryVector = provider.lookup(id).getVector(); + ValueVector importedDictionaryVector = cDictionaryProvider.lookup(id).getVector(); + assertTrue(VectorEqualsVisitor.vectorEquals(exportedDictionaryVector, importedDictionaryVector), + String.format("Dictionary vectors for ID %d are not equivalent", id)); + } + } + } + } + + @Test + public void testWithDictionary() throws Exception { + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + // create dictionary and provider + final VarCharVector dictVector = new VarCharVector("dict", allocator); + dictVector.allocateNewSafe(); + dictVector.setSafe(0, "aa".getBytes()); + dictVector.setSafe(1, "bb".getBytes()); + dictVector.setSafe(2, "cc".getBytes()); + dictVector.setValueCount(3); + + Dictionary dictionary = new Dictionary(dictVector, new DictionaryEncoding(1L, false, /* indexType= */null)); + provider.put(dictionary); + + // create vector and encode it + final VarCharVector vector = new VarCharVector("vector", allocator); + vector.allocateNewSafe(); + vector.setSafe(0, "bb".getBytes()); + vector.setSafe(1, "bb".getBytes()); + vector.setSafe(2, "cc".getBytes()); + 
vector.setSafe(3, "aa".getBytes()); + vector.setValueCount(4); + + // get the encoded vector + IntVector encodedVector = (IntVector) DictionaryEncoder.encode(vector, dictionary); + + // Perform roundtrip using C Data Interface + roundtrip(encodedVector, provider, IntVector.class); + + // Close all + AutoCloseables.close((AutoCloseable) vector, encodedVector, dictVector); + } + + @Test + public void testRoundtripMultipleBatches() throws IOException { + try (ArrowStreamReader reader = createMultiBatchReader(); + ArrowSchema consumerArrowSchema = ArrowSchema.allocateNew(allocator)) { + // Load first batch + reader.loadNextBatch(); + // Producer fills consumer schema stucture + Data.exportSchema(allocator, reader.getVectorSchemaRoot().getSchema(), reader, consumerArrowSchema); + // Consumer loads it as an empty vector schema root + try (CDataDictionaryProvider consumerDictionaryProvider = new CDataDictionaryProvider(); + VectorSchemaRoot consumerRoot = Data.importVectorSchemaRoot(allocator, consumerArrowSchema, + consumerDictionaryProvider)) { + do { + try (ArrowArray consumerArray = ArrowArray.allocateNew(allocator)) { + // Producer exports next data + Data.exportVectorSchemaRoot(allocator, reader.getVectorSchemaRoot(), reader, consumerArray); + // Consumer loads next data + Data.importIntoVectorSchemaRoot(allocator, consumerArray, consumerRoot, consumerDictionaryProvider); + + // Roundtrip validation + assertTrue(consumerRoot.equals(reader.getVectorSchemaRoot()), "vector schema roots are not equivalent"); + for (long id : consumerDictionaryProvider.getDictionaryIds()) { + ValueVector exportedDictionaryVector = reader.lookup(id).getVector(); + ValueVector importedDictionaryVector = consumerDictionaryProvider.lookup(id).getVector(); + assertTrue(VectorEqualsVisitor.vectorEquals(exportedDictionaryVector, importedDictionaryVector), + String.format("Dictionary vectors for ID %d are not equivalent", id)); + } + } + } + while (reader.loadNextBatch()); + } + } + } + + 
private ArrowStreamReader createMultiBatchReader() throws IOException { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + try (final VarCharVector dictVector = new VarCharVector("dict", allocator); + IntVector vector = new IntVector("foo", allocator)) { + // create dictionary and provider + DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + dictVector.allocateNewSafe(); + dictVector.setSafe(0, "aa".getBytes()); + dictVector.setSafe(1, "bb".getBytes()); + dictVector.setSafe(2, "cc".getBytes()); + dictVector.setSafe(3, "dd".getBytes()); + dictVector.setSafe(4, "ee".getBytes()); + dictVector.setValueCount(5); + Dictionary dictionary = new Dictionary(dictVector, new DictionaryEncoding(1L, false, /* indexType= */null)); + provider.put(dictionary); + + Schema schema = new Schema(Collections.singletonList(vector.getField())); + try ( + VectorSchemaRoot root = new VectorSchemaRoot(schema, Collections.singletonList(vector), + vector.getValueCount()); + ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, Channels.newChannel(os));) { + + writer.start(); + + // Batch 1 + vector.setNull(0); + vector.setSafe(1, 1); + vector.setSafe(2, 2); + vector.setNull(3); + vector.setSafe(4, 1); + vector.setValueCount(5); + root.setRowCount(5); + writer.writeBatch(); + + // Batch 2 + vector.setNull(0); + vector.setSafe(1, 1); + vector.setSafe(2, 2); + vector.setValueCount(3); + root.setRowCount(3); + writer.writeBatch(); + + // Batch 3 + vector.setSafe(0, 0); + vector.setSafe(1, 1); + vector.setSafe(2, 2); + vector.setSafe(3, 3); + vector.setSafe(4, 4); + vector.setValueCount(5); + root.setRowCount(5); + writer.writeBatch(); + + writer.end(); + } + } + + ByteArrayInputStream in = new ByteArrayInputStream(os.toByteArray()); + return new ArrowStreamReader(in, allocator); + } + +} diff --git a/src/arrow/java/c/src/test/java/org/apache/arrow/c/FlagsTest.java 
b/src/arrow/java/c/src/test/java/org/apache/arrow/c/FlagsTest.java new file mode 100644 index 000000000..35f836f71 --- /dev/null +++ b/src/arrow/java/c/src/test/java/org/apache/arrow/c/FlagsTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.c; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; + +import org.apache.arrow.c.Flags; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.jupiter.api.Test; + +public class FlagsTest { + @Test + public void testForFieldNullableOrderedDict() { + FieldType fieldType = new FieldType(true, ArrowType.Binary.INSTANCE, + new DictionaryEncoding(123L, true, new ArrowType.Int(8, true))); + + assertEquals(Flags.ARROW_FLAG_DICTIONARY_ORDERED | Flags.ARROW_FLAG_NULLABLE, + Flags.forField(new Field("Name", fieldType, new ArrayList<>()))); + } + + @Test + public void testForFieldOrderedDict() { + FieldType fieldType = new FieldType(false, ArrowType.Binary.INSTANCE, + new DictionaryEncoding(123L, true, new ArrowType.Int(8, true))); + assertEquals(Flags.ARROW_FLAG_DICTIONARY_ORDERED, Flags.forField(new Field("Name", fieldType, new ArrayList<>()))); + } + + @Test + public void testForFieldNullableDict() { + FieldType fieldType = new FieldType(true, ArrowType.Binary.INSTANCE, + new DictionaryEncoding(123L, false, new ArrowType.Int(8, true))); + assertEquals(Flags.ARROW_FLAG_NULLABLE, Flags.forField(new Field("Name", fieldType, new ArrayList<>()))); + } + + @Test + public void testForFieldNullable() { + FieldType fieldType = new FieldType(true, ArrowType.Binary.INSTANCE, null); + assertEquals(Flags.ARROW_FLAG_NULLABLE, Flags.forField(new Field("Name", fieldType, new ArrayList<>()))); + } + + @Test + public void testForFieldNullableOrderedSortedMap() { + ArrowType.Map type = new ArrowType.Map(true); + FieldType fieldType = new FieldType(true, type, new DictionaryEncoding(123L, true, new ArrowType.Int(8, true))); + assertEquals(Flags.ARROW_FLAG_DICTIONARY_ORDERED | Flags.ARROW_FLAG_NULLABLE | 
Flags.ARROW_FLAG_MAP_KEYS_SORTED, + Flags.forField(new Field("Name", fieldType, new ArrayList<>()))); + } + + @Test + public void testForFieldNullableOrderedMap() { + ArrowType.Map type = new ArrowType.Map(false); + FieldType fieldType = new FieldType(true, type, new DictionaryEncoding(123L, true, new ArrowType.Int(8, true))); + assertEquals(Flags.ARROW_FLAG_DICTIONARY_ORDERED | Flags.ARROW_FLAG_NULLABLE, + Flags.forField(new Field("Name", fieldType, new ArrayList<>()))); + } +} diff --git a/src/arrow/java/c/src/test/java/org/apache/arrow/c/FormatTest.java b/src/arrow/java/c/src/test/java/org/apache/arrow/c/FormatTest.java new file mode 100644 index 000000000..1f7f86b36 --- /dev/null +++ b/src/arrow/java/c/src/test/java/org/apache/arrow/c/FormatTest.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.c; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.arrow.c.Flags; +import org.apache.arrow.c.Format; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.UnionMode; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.junit.jupiter.api.Test; + +public class FormatTest { + @Test + public void testAsString() { + assertEquals("z", Format.asString(new ArrowType.Binary())); + assertEquals("b", Format.asString(new ArrowType.Bool())); + assertEquals("tdD", Format.asString(new ArrowType.Date(DateUnit.DAY))); + assertEquals("tdm", Format.asString(new ArrowType.Date(DateUnit.MILLISECOND))); + assertEquals("d:1,1", Format.asString(new ArrowType.Decimal(1, 1, 128))); + assertEquals("d:1,1,1", Format.asString(new ArrowType.Decimal(1, 1, 1))); + assertEquals("d:9,1,1", Format.asString(new ArrowType.Decimal(9, 1, 1))); + assertEquals("tDs", Format.asString(new ArrowType.Duration(TimeUnit.SECOND))); + assertEquals("tDm", Format.asString(new ArrowType.Duration(TimeUnit.MILLISECOND))); + assertEquals("tDu", Format.asString(new ArrowType.Duration(TimeUnit.MICROSECOND))); + assertEquals("tDn", Format.asString(new ArrowType.Duration(TimeUnit.NANOSECOND))); + assertEquals("w:1", Format.asString(new ArrowType.FixedSizeBinary(1))); + assertEquals("+w:3", Format.asString(new ArrowType.FixedSizeList(3))); + assertEquals("e", Format.asString(new ArrowType.FloatingPoint(FloatingPointPrecision.HALF))); + assertEquals("f", Format.asString(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))); + assertEquals("g", Format.asString(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))); + 
assertEquals("c", Format.asString(new ArrowType.Int(Byte.SIZE, true))); + assertEquals("C", Format.asString(new ArrowType.Int(Byte.SIZE, false))); + assertEquals("s", Format.asString(new ArrowType.Int(Short.SIZE, true))); + assertEquals("S", Format.asString(new ArrowType.Int(Short.SIZE, false))); + assertEquals("i", Format.asString(new ArrowType.Int(Integer.SIZE, true))); + assertEquals("I", Format.asString(new ArrowType.Int(Integer.SIZE, false))); + assertEquals("l", Format.asString(new ArrowType.Int(Long.SIZE, true))); + assertEquals("L", Format.asString(new ArrowType.Int(Long.SIZE, false))); + assertEquals("tiD", Format.asString(new ArrowType.Interval(IntervalUnit.DAY_TIME))); + assertEquals("tiM", Format.asString(new ArrowType.Interval(IntervalUnit.YEAR_MONTH))); + assertEquals("Z", Format.asString(new ArrowType.LargeBinary())); + assertEquals("+L", Format.asString(new ArrowType.LargeList())); + assertEquals("U", Format.asString(new ArrowType.LargeUtf8())); + assertEquals("+l", Format.asString(new ArrowType.List())); + assertEquals("+m", Format.asString(new ArrowType.Map(true))); + assertEquals("n", Format.asString(new ArrowType.Null())); + assertEquals("+s", Format.asString(new ArrowType.Struct())); + assertEquals("tts", Format.asString(new ArrowType.Time(TimeUnit.SECOND, 32))); + assertEquals("ttm", Format.asString(new ArrowType.Time(TimeUnit.MILLISECOND, 32))); + assertEquals("ttu", Format.asString(new ArrowType.Time(TimeUnit.MICROSECOND, 64))); + assertEquals("ttn", Format.asString(new ArrowType.Time(TimeUnit.NANOSECOND, 64))); + assertEquals("tss:Timezone", Format.asString(new ArrowType.Timestamp(TimeUnit.SECOND, "Timezone"))); + assertEquals("tsm:Timezone", Format.asString(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "Timezone"))); + assertEquals("tsu:Timezone", Format.asString(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "Timezone"))); + assertEquals("tsn:Timezone", Format.asString(new ArrowType.Timestamp(TimeUnit.NANOSECOND, "Timezone"))); + 
assertEquals("+us:1,1,1", Format.asString(new ArrowType.Union(UnionMode.Sparse, new int[] { 1, 1, 1 }))); + assertEquals("+ud:1,1,1", Format.asString(new ArrowType.Union(UnionMode.Dense, new int[] { 1, 1, 1 }))); + assertEquals("u", Format.asString(new ArrowType.Utf8())); + + assertThrows(UnsupportedOperationException.class, () -> Format.asString(new ArrowType.Int(1, true))); + assertThrows(UnsupportedOperationException.class, () -> Format.asString(new ArrowType.Time(TimeUnit.SECOND, 1))); + assertThrows(UnsupportedOperationException.class, + () -> Format.asString(new ArrowType.Time(TimeUnit.MILLISECOND, 64))); + } + + @Test + public void testAsType() throws IllegalStateException, NumberFormatException, UnsupportedOperationException { + assertTrue(Format.asType("n", 0L) instanceof ArrowType.Null); + assertTrue(Format.asType("b", 0L) instanceof ArrowType.Bool); + assertEquals(new ArrowType.Int(Byte.SIZE, true), Format.asType("c", 0L)); + assertEquals(new ArrowType.Int(Byte.SIZE, false), Format.asType("C", 0L)); + assertEquals(new ArrowType.Int(Short.SIZE, true), Format.asType("s", 0L)); + assertEquals(new ArrowType.Int(Short.SIZE, false), Format.asType("S", 0L)); + assertEquals(new ArrowType.Int(Integer.SIZE, true), Format.asType("i", 0L)); + assertEquals(new ArrowType.Int(Integer.SIZE, false), Format.asType("I", 0L)); + assertEquals(new ArrowType.Int(Long.SIZE, true), Format.asType("l", 0L)); + assertEquals(new ArrowType.Int(Long.SIZE, false), Format.asType("L", 0L)); + assertEquals(new ArrowType.FloatingPoint(FloatingPointPrecision.HALF), Format.asType("e", 0L)); + assertEquals(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), Format.asType("f", 0L)); + assertEquals(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), Format.asType("g", 0L)); + assertTrue(Format.asType("z", 0L) instanceof ArrowType.Binary); + assertTrue(Format.asType("Z", 0L) instanceof ArrowType.LargeBinary); + assertTrue(Format.asType("u", 0L) instanceof ArrowType.Utf8); + 
assertTrue(Format.asType("U", 0L) instanceof ArrowType.LargeUtf8); + assertEquals(new ArrowType.Date(DateUnit.DAY), Format.asType("tdD", 0L)); + assertEquals(new ArrowType.Date(DateUnit.MILLISECOND), Format.asType("tdm", 0L)); + assertEquals(new ArrowType.Time(TimeUnit.SECOND, Integer.SIZE), Format.asType("tts", 0L)); + assertEquals(new ArrowType.Time(TimeUnit.MILLISECOND, Integer.SIZE), Format.asType("ttm", 0L)); + assertEquals(new ArrowType.Time(TimeUnit.MICROSECOND, Long.SIZE), Format.asType("ttu", 0L)); + assertEquals(new ArrowType.Time(TimeUnit.NANOSECOND, Long.SIZE), Format.asType("ttn", 0L)); + assertEquals(new ArrowType.Duration(TimeUnit.SECOND), Format.asType("tDs", 0L)); + assertEquals(new ArrowType.Duration(TimeUnit.MILLISECOND), Format.asType("tDm", 0L)); + assertEquals(new ArrowType.Duration(TimeUnit.MICROSECOND), Format.asType("tDu", 0L)); + assertEquals(new ArrowType.Duration(TimeUnit.NANOSECOND), Format.asType("tDn", 0L)); + assertEquals(new ArrowType.Interval(IntervalUnit.YEAR_MONTH), Format.asType("tiM", 0L)); + assertEquals(new ArrowType.Interval(IntervalUnit.DAY_TIME), Format.asType("tiD", 0L)); + assertTrue(Format.asType("+l", 0L) instanceof ArrowType.List); + assertTrue(Format.asType("+L", 0L) instanceof ArrowType.LargeList); + assertTrue(Format.asType("+s", 0L) instanceof ArrowType.Struct); + assertEquals(new ArrowType.Map(false), Format.asType("+m", 0L)); + assertEquals(new ArrowType.Map(true), Format.asType("+m", Flags.ARROW_FLAG_MAP_KEYS_SORTED)); + assertEquals(new ArrowType.Decimal(1, 1, 128), Format.asType("d:1,1", 0L)); + assertEquals(new ArrowType.Decimal(1, 1, 1), Format.asType("d:1,1,1", 0L)); + assertEquals(new ArrowType.Decimal(9, 1, 1), Format.asType("d:9,1,1", 0L)); + assertEquals(new ArrowType.FixedSizeBinary(1), Format.asType("w:1", 0L)); + assertEquals(new ArrowType.FixedSizeList(3), Format.asType("+w:3", 0L)); + assertEquals(new ArrowType.Union(UnionMode.Dense, new int[] { 1, 1, 1 }), Format.asType("+ud:1,1,1", 0L)); + 
assertEquals(new ArrowType.Union(UnionMode.Sparse, new int[] { 1, 1, 1 }), Format.asType("+us:1,1,1", 0L)); + assertEquals(new ArrowType.Timestamp(TimeUnit.SECOND, "Timezone"), Format.asType("tss:Timezone", 0L)); + assertEquals(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "Timezone"), Format.asType("tsm:Timezone", 0L)); + assertEquals(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "Timezone"), Format.asType("tsu:Timezone", 0L)); + assertEquals(new ArrowType.Timestamp(TimeUnit.NANOSECOND, "Timezone"), Format.asType("tsn:Timezone", 0L)); + + assertThrows(UnsupportedOperationException.class, () -> Format.asType("Format", 0L)); + assertThrows(UnsupportedOperationException.class, () -> Format.asType(":", 0L)); + assertThrows(NumberFormatException.class, () -> Format.asType("w:1,2,3", 0L)); + } +} diff --git a/src/arrow/java/c/src/test/java/org/apache/arrow/c/MetadataTest.java b/src/arrow/java/c/src/test/java/org/apache/arrow/c/MetadataTest.java new file mode 100644 index 000000000..1d9703b1a --- /dev/null +++ b/src/arrow/java/c/src/test/java/org/apache/arrow/c/MetadataTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.c; + +import static org.junit.jupiter.api.Assertions.*; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.HashMap; +import java.util.Map; + +import org.apache.arrow.c.Metadata; +import org.apache.arrow.c.NativeUtil; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.memory.util.LargeMemoryUtil; +import org.apache.arrow.memory.util.MemoryUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class MetadataTest { + private RootAllocator allocator = null; + + private static Map<String, String> metadata; + private static byte[] encoded; + + @BeforeAll + static void beforeAll() { + metadata = new HashMap<>(); + metadata.put("key1", ""); + metadata.put("key2", "bar"); + + if (ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN) { + encoded = new byte[] { 2, 0, 0, 0, 4, 0, 0, 0, 'k', 'e', 'y', '1', 0, 0, 0, 0, 4, 0, 0, 0, 'k', 'e', 'y', '2', 3, + 0, 0, 0, 'b', 'a', 'r' }; + } else { + encoded = new byte[] { 0, 0, 0, 2, 0, 0, 0, 4, 'k', 'e', 'y', '1', 0, 0, 0, 0, 0, 0, 0, 4, 'k', 'e', 'y', '2', 0, + 0, 0, 3, 'b', 'a', 'r' }; + } + } + + @BeforeEach + public void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + public void tearDown() { + allocator.close(); + } + + @Test + public void testEncode() { + try (ArrowBuf buffer = Metadata.encode(allocator, metadata)) { + int totalSize = LargeMemoryUtil.checkedCastToInt(buffer.readableBytes()); + ByteBuffer reader = MemoryUtil.directBuffer(buffer.memoryAddress(), totalSize).order(ByteOrder.nativeOrder()); + byte[] result = new byte[totalSize]; + reader.get(result); + assertArrayEquals(encoded, result); + } + } + + @Test + public void testDecode() { + try (ArrowBuf buffer = allocator.buffer(31)) { + buffer.setBytes(0, encoded); + Map<String, String> decoded = 
Metadata.decode(buffer.memoryAddress()); + assertNotNull(decoded); + assertEquals(metadata, decoded); + } + } + + @Test + public void testEncodeEmpty() { + Map<String, String> metadata = new HashMap<>(); + try (ArrowBuf encoded = Metadata.encode(allocator, metadata)) { + assertNull(encoded); + } + } + + @Test + public void testDecodeEmpty() { + Map<String, String> decoded = Metadata.decode(NativeUtil.NULL); + assertNull(decoded); + } + +} diff --git a/src/arrow/java/c/src/test/java/org/apache/arrow/c/NativeUtilTest.java b/src/arrow/java/c/src/test/java/org/apache/arrow/c/NativeUtilTest.java new file mode 100644 index 000000000..f46a0128c --- /dev/null +++ b/src/arrow/java/c/src/test/java/org/apache/arrow/c/NativeUtilTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.arrow.c;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import org.apache.arrow.c.NativeUtil;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.memory.util.LargeMemoryUtil;
import org.apache.arrow.memory.util.MemoryUtil;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests for the {@link NativeUtil} helpers that convert between Java values
 * and their in-memory native representation.
 */
public class NativeUtilTest {

  private RootAllocator allocator = null;

  @BeforeEach
  public void setUp() {
    allocator = new RootAllocator(Long.MAX_VALUE);
  }

  @AfterEach
  public void tearDown() {
    allocator.close();
  }

  /**
   * A Java string converted with {@code toNativeString} must be laid out as its
   * bytes followed by a terminating zero, and {@code toJavaString} must read the
   * same string back from that address.
   */
  @Test
  public void testString() {
    String javaString = "abc";
    byte[] nativeString = new byte[] { 97, 98, 99, 0 };
    try (ArrowBuf buffer = NativeUtil.toNativeString(allocator, javaString)) {
      int totalSize = LargeMemoryUtil.checkedCastToInt(buffer.readableBytes());
      ByteBuffer reader = MemoryUtil.directBuffer(buffer.memoryAddress(), totalSize).order(ByteOrder.nativeOrder());
      byte[] result = new byte[totalSize];
      reader.get(result);
      assertArrayEquals(nativeString, result);

      assertEquals(javaString, NativeUtil.toJavaString(buffer.memoryAddress()));
    }
  }

  /** Longs written to a buffer must be read back unchanged by {@code toJavaArray}. */
  @Test
  public void testToJavaArray() {
    long[] nativeArray = new long[] { 1, 2, 3 };
    try (ArrowBuf buffer = allocator.buffer(Long.BYTES * nativeArray.length, null)) {
      for (long value : nativeArray) {
        buffer.writeLong(value);
      }
      long[] actual = NativeUtil.toJavaArray(buffer.memoryAddress(), nativeArray.length);
      assertArrayEquals(nativeArray, actual);
    }
  }

  /**
   * A zero-length request must return an empty array even when the address is
   * not a real allocation.
   */
  @Test
  public void testToZeroJavaArray() {
    // The L suffix keeps the sentinel at 0x00000000DEADBEEF. Without it the
    // literal is a negative int that sign-extends when widened to long.
    long[] actual = NativeUtil.toJavaArray(0xDEADBEEFL, 0);
    assertEquals(0, actual.length);
  }

}
diff --git 
a/src/arrow/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java b/src/arrow/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java new file mode 100644 index 000000000..059ca3284 --- /dev/null +++ b/src/arrow/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java @@ -0,0 +1,795 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.arrow.c;

import static org.apache.arrow.vector.testing.ValueVectorDataPopulator.setVector;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.memory.util.hash.ArrowBufHasher;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DateMilliVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.DurationVector;
import org.apache.arrow.vector.ExtensionTypeVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.IntervalDayVector;
import org.apache.arrow.vector.IntervalYearVector;
import org.apache.arrow.vector.LargeVarBinaryVector;
import org.apache.arrow.vector.LargeVarCharVector;
import org.apache.arrow.vector.NullVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.TimeNanoVector;
import org.apache.arrow.vector.TimeSecVector;
import org.apache.arrow.vector.TimeStampMicroTZVector;
import org.apache.arrow.vector.TimeStampMicroVector;
import org.apache.arrow.vector.TimeStampMilliTZVector;
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.TimeStampNanoTZVector;
import org.apache.arrow.vector.TimeStampNanoVector;
import org.apache.arrow.vector.TimeStampSecTZVector;
import org.apache.arrow.vector.TimeStampSecVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.UInt1Vector;
import org.apache.arrow.vector.UInt2Vector;
import org.apache.arrow.vector.UInt4Vector;
import org.apache.arrow.vector.UInt8Vector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ZeroVector;
import org.apache.arrow.vector.compare.VectorEqualsVisitor;
import org.apache.arrow.vector.complex.FixedSizeListVector;
import org.apache.arrow.vector.complex.LargeListVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.UnionVector;
import org.apache.arrow.vector.complex.impl.UnionMapWriter;
import org.apache.arrow.vector.holders.IntervalDayHolder;
import org.apache.arrow.vector.holders.NullableLargeVarBinaryHolder;
import org.apache.arrow.vector.holders.NullableUInt4Holder;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
import org.apache.arrow.vector.types.pojo.ExtensionTypeRegistry;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Export/import round-trip tests for the C Data Interface bindings.
 *
 * <p>Each test builds a vector (or a {@link VectorSchemaRoot} / {@link Schema}),
 * exports it through {@link Data} into {@link ArrowArray} / {@link ArrowSchema}
 * structures, imports it back, and checks that the imported object has the
 * expected class and equals the original.
 */
public class RoundtripTest {
  private static final String EMPTY_SCHEMA_PATH = "";
  private RootAllocator allocator = null;

  @BeforeEach
  public void setUp() {
    allocator = new RootAllocator(Long.MAX_VALUE);
  }

  @AfterEach
  public void tearDown() {
    allocator.close();
  }

  /**
   * Exports {@code vector} and imports it back.
   *
   * @param vector the vector to export; it is not closed by this method
   * @return the imported vector, which the caller must close
   */
  FieldVector vectorRoundtrip(FieldVector vector) {
    // Consumer allocates empty structures
    try (ArrowSchema consumerArrowSchema = ArrowSchema.allocateNew(allocator);
        ArrowArray consumerArrowArray = ArrowArray.allocateNew(allocator)) {

      // Producer creates structures from existing memory pointers
      try (ArrowSchema arrowSchema = ArrowSchema.wrap(consumerArrowSchema.memoryAddress());
          ArrowArray arrowArray = ArrowArray.wrap(consumerArrowArray.memoryAddress())) {
        // Producer exports vector into the C Data Interface structures
        Data.exportVector(allocator, vector, null, arrowArray, arrowSchema);
      }

      // Consumer imports vector
      return Data.importVector(allocator, consumerArrowArray, consumerArrowSchema, null);
    }
  }

  /**
   * Exports {@code root} and imports it back.
   *
   * @param root the root to export; it is not closed by this method
   * @return the imported root, which the caller must close
   */
  VectorSchemaRoot vectorSchemaRootRoundtrip(VectorSchemaRoot root) {
    // Consumer allocates empty structures
    try (ArrowSchema consumerArrowSchema = ArrowSchema.allocateNew(allocator);
        ArrowArray consumerArrowArray = ArrowArray.allocateNew(allocator)) {

      // Producer creates structures from existing memory pointers
      try (ArrowSchema arrowSchema = ArrowSchema.wrap(consumerArrowSchema.memoryAddress());
          ArrowArray arrowArray = ArrowArray.wrap(consumerArrowArray.memoryAddress())) {
        // Producer exports vector into the C Data Interface structures
        Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchema);
      }

      // Consumer imports vector
      return Data.importVectorSchemaRoot(allocator, consumerArrowArray, consumerArrowSchema, null);
    }
  }

  /**
   * Round-trips {@code vector}, asserts the imported vector is an instance of
   * {@code clazz}, and reports whether it equals the original. The imported
   * vector is closed before returning.
   *
   * @param vector the vector to round-trip
   * @param clazz the class the imported vector is expected to be an instance of
   * @return {@code true} if {@link VectorEqualsVisitor} considers both vectors equal
   */
  boolean roundtrip(FieldVector vector, Class<?> clazz) {
    try (ValueVector imported = vectorRoundtrip(vector)) {
      assertTrue(clazz.isInstance(imported), String.format("expected %s but was %s", clazz, imported.getClass()));
      return VectorEqualsVisitor.vectorEquals(vector, imported);
    }
  }

  @Test
  public void testBitVector() {
    BitVector imported;

    try (final BitVector vector = new BitVector(EMPTY_SCHEMA_PATH, allocator)) {
      vector.allocateNew(1024);
      vector.setValueCount(1024);

      // Put and set a few values
      vector.set(0, 1);
      vector.set(1, 0);
      vector.set(100, 0);
      vector.set(1022, 1);

      vector.setValueCount(1024);

      imported = (BitVector) vectorRoundtrip(vector);
      assertTrue(VectorEqualsVisitor.vectorEquals(vector, imported));
    }

    // The source vector is closed at this point; the imported one must still be readable.
    assertEquals(1, imported.get(0));
    assertEquals(0, imported.get(1));
    assertEquals(0, imported.get(100));
    assertEquals(1, imported.get(1022));
    assertEquals(1020, imported.getNullCount());
    imported.close();
  }

  @Test
  public void testIntVector() {
    IntVector imported;
    try (final IntVector vector = new IntVector("v", allocator)) {
      setVector(vector, 1, 2, 3, null);
      imported = (IntVector) vectorRoundtrip(vector);
      assertTrue(VectorEqualsVisitor.vectorEquals(vector, imported));
    }
    // The source vector is closed at this point; the imported one must still be readable.
    assertEquals(1, imported.get(0));
    assertEquals(2, imported.get(1));
    assertEquals(3, imported.get(2));
    assertEquals(4, imported.getValueCount());
    assertEquals(1, imported.getNullCount());
    imported.close();
  }

  @Test
  public void testBigIntVector() {
    BigIntVector imported;
    try (final BigIntVector vector = new BigIntVector("v", allocator)) {
      setVector(vector, 1L, 2L, 3L, null);
      imported = (BigIntVector) vectorRoundtrip(vector);
      assertTrue(VectorEqualsVisitor.vectorEquals(vector, imported));
    }
    // The source vector is closed at this point; the imported one must still be readable.
    assertEquals(1, imported.get(0));
    assertEquals(2, imported.get(1));
    assertEquals(3, imported.get(2));
    assertEquals(4, imported.getValueCount());
    assertEquals(1, imported.getNullCount());
    imported.close();
  }

  @Test
  public void testDateDayVector() {
    DateDayVector imported;
    try (final DateDayVector vector = new DateDayVector("v", allocator)) {
      setVector(vector, 1, 2, 3, null);
      imported = (DateDayVector) vectorRoundtrip(vector);
      assertTrue(VectorEqualsVisitor.vectorEquals(vector, imported));
    }
    // The source vector is closed at this point; the imported one must still be readable.
    assertEquals(1, imported.get(0));
    assertEquals(2, imported.get(1));
    assertEquals(3, imported.get(2));
    assertEquals(4, imported.getValueCount());
    assertEquals(1, imported.getNullCount());
    imported.close();
  }

  @Test
  public void testDateMilliVector() {
    DateMilliVector imported;
    try (final DateMilliVector vector = new DateMilliVector("v", allocator)) {
      setVector(vector, 1L, 2L, 3L, null);
      imported = (DateMilliVector) vectorRoundtrip(vector);
      assertTrue(VectorEqualsVisitor.vectorEquals(vector, imported));
    }
    // The source vector is closed at this point; the imported one must still be readable.
    assertEquals(1, imported.get(0));
    assertEquals(2, imported.get(1));
    assertEquals(3, imported.get(2));
    assertEquals(4, imported.getValueCount());
    assertEquals(1, imported.getNullCount());
    imported.close();
  }

  @Test
  public void testDecimalVector() {
    try (final DecimalVector vector = new DecimalVector("v", allocator, 1, 1)) {
      setVector(vector, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, DecimalVector.class));
    }
  }

  @Test
  public void testDurationVector() {
    // Exercise every time unit the Duration type supports.
    for (TimeUnit unit : TimeUnit.values()) {
      final FieldType fieldType = FieldType.nullable(new ArrowType.Duration(unit));
      try (final DurationVector vector = new DurationVector("v", fieldType, allocator)) {
        setVector(vector, 1L, 2L, 3L, null);
        assertTrue(roundtrip(vector, DurationVector.class));
      }
    }
  }

  @Test
  public void testZeroVectorEquals() {
    try (final ZeroVector vector = new ZeroVector()) {
      // A ZeroVector is imported as a NullVector
      assertTrue(roundtrip(vector, NullVector.class));
    }
  }

  @Test
  public void testFixedSizeBinaryVector() {
    try (final FixedSizeBinaryVector vector = new FixedSizeBinaryVector("v", allocator, 2)) {
      setVector(vector, new byte[] { 0b0000, 0b0001 }, new byte[] { 0b0010, 0b0011 });
      assertTrue(roundtrip(vector, FixedSizeBinaryVector.class));
    }
  }

  @Test
  public void testFloat4Vector() {
    try (final Float4Vector vector = new Float4Vector("v", allocator)) {
      setVector(vector, 0.1f, 0.2f, 0.3f, null);
      assertTrue(roundtrip(vector, Float4Vector.class));
    }
  }

  @Test
  public void testFloat8Vector() {
    try (final Float8Vector vector = new Float8Vector("v", allocator)) {
      setVector(vector, 0.1d, 0.2d, 0.3d, null);
      assertTrue(roundtrip(vector, Float8Vector.class));
    }
  }

  @Test
  public void testIntervalDayVector() {
    try (final IntervalDayVector vector = new IntervalDayVector("v", allocator)) {
      IntervalDayHolder value = new IntervalDayHolder();
      value.days = 5;
      value.milliseconds = 100;
      setVector(vector, value, null);
      assertTrue(roundtrip(vector, IntervalDayVector.class));
    }
  }

  @Test
  public void testIntervalYearVector() {
    try (final IntervalYearVector vector = new IntervalYearVector("v", allocator)) {
      setVector(vector, 1990, 2000, 2010, 2020, null);
      assertTrue(roundtrip(vector, IntervalYearVector.class));
    }
  }

  @Test
  public void testSmallIntVector() {
    try (final SmallIntVector vector = new SmallIntVector("v", allocator)) {
      setVector(vector, (short) 0, (short) 256, null);
      assertTrue(roundtrip(vector, SmallIntVector.class));
    }
  }

  @Test
  public void testTimeMicroVector() {
    try (final TimeMicroVector vector = new TimeMicroVector("v", allocator)) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeMicroVector.class));
    }
  }

  @Test
  public void testTimeMilliVector() {
    try (final TimeMilliVector vector = new TimeMilliVector("v", allocator)) {
      setVector(vector, 0, 1, 2, 3, null);
      assertTrue(roundtrip(vector, TimeMilliVector.class));
    }
  }

  @Test
  public void testTimeNanoVector() {
    try (final TimeNanoVector vector = new TimeNanoVector("v", allocator)) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeNanoVector.class));
    }
  }

  @Test
  public void testTimeSecVector() {
    try (final TimeSecVector vector = new TimeSecVector("v", allocator)) {
      setVector(vector, 0, 1, 2, 3, null);
      assertTrue(roundtrip(vector, TimeSecVector.class));
    }
  }

  @Test
  public void testTimeStampMicroTZVector() {
    try (final TimeStampMicroTZVector vector = new TimeStampMicroTZVector("v", allocator, "UTC")) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeStampMicroTZVector.class));
    }
  }

  @Test
  public void testTimeStampMicroVector() {
    try (final TimeStampMicroVector vector = new TimeStampMicroVector("v", allocator)) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeStampMicroVector.class));
    }
  }

  @Test
  public void testTimeStampMilliTZVector() {
    try (final TimeStampMilliTZVector vector = new TimeStampMilliTZVector("v", allocator, "UTC")) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeStampMilliTZVector.class));
    }
  }

  @Test
  public void testTimeStampMilliVector() {
    try (final TimeStampMilliVector vector = new TimeStampMilliVector("v", allocator)) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeStampMilliVector.class));
    }
  }

  @Test
  public void testTimeTimeStampNanoTZVector() {
    try (final TimeStampNanoTZVector vector = new TimeStampNanoTZVector("v", allocator, "UTC")) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeStampNanoTZVector.class));
    }
  }

  @Test
  public void testTimeStampNanoVector() {
    try (final TimeStampNanoVector vector = new TimeStampNanoVector("v", allocator)) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeStampNanoVector.class));
    }
  }

  @Test
  public void testTimeStampSecTZVector() {
    try (final TimeStampSecTZVector vector = new TimeStampSecTZVector("v", allocator, "UTC")) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeStampSecTZVector.class));
    }
  }

  @Test
  public void testTimeStampSecVector() {
    try (final TimeStampSecVector vector = new TimeStampSecVector("v", allocator)) {
      setVector(vector, 0L, 1L, 2L, 3L, null);
      assertTrue(roundtrip(vector, TimeStampSecVector.class));
    }
  }

  @Test
  public void testTinyIntVector() {
    try (final TinyIntVector vector = new TinyIntVector("v", allocator)) {
      setVector(vector, (byte) 0, (byte) 1, null);
      assertTrue(roundtrip(vector, TinyIntVector.class));
    }
  }

  @Test
  public void testUInt1Vector() {
    try (final UInt1Vector vector = new UInt1Vector("v", allocator)) {
      setVector(vector, (byte) 0, (byte) 1, null);
      assertTrue(roundtrip(vector, UInt1Vector.class));
    }
  }

  @Test
  public void testUInt2Vector() {
    try (final UInt2Vector vector = new UInt2Vector("v", allocator)) {
      setVector(vector, '0', '1', null);
      assertTrue(roundtrip(vector, UInt2Vector.class));
    }
  }

  @Test
  public void testUInt4Vector() {
    try (final UInt4Vector vector = new UInt4Vector("v", allocator)) {
      setVector(vector, 0, 1, null);
      assertTrue(roundtrip(vector, UInt4Vector.class));
    }
  }

  @Test
  public void testUInt8Vector() {
    try (final UInt8Vector vector = new UInt8Vector("v", allocator)) {
      setVector(vector, 0L, 1L, null);
      assertTrue(roundtrip(vector, UInt8Vector.class));
    }
  }

  @Test
  public void testVarBinaryVector() {
    try (final VarBinaryVector vector = new VarBinaryVector("v", allocator)) {
      // Explicit charset so the bytes do not depend on the platform default.
      setVector(vector, "abc".getBytes(StandardCharsets.UTF_8), "def".getBytes(StandardCharsets.UTF_8), null);
      assertTrue(roundtrip(vector, VarBinaryVector.class));
    }
  }

  @Test
  public void testVarCharVector() {
    try (final VarCharVector vector = new VarCharVector("v", allocator)) {
      setVector(vector, "abc", "def", null);
      assertTrue(roundtrip(vector, VarCharVector.class));
    }
  }

  @Test
  public void testLargeVarBinaryVector() {
    try (final LargeVarBinaryVector vector = new LargeVarBinaryVector("", allocator)) {
      vector.allocateNew(5, 1);

      NullableLargeVarBinaryHolder nullHolder = new NullableLargeVarBinaryHolder();
      nullHolder.isSet = 0;

      NullableLargeVarBinaryHolder binHolder = new NullableLargeVarBinaryHolder();
      binHolder.isSet = 1;

      // Use the encoded bytes for both the payload and its length, so the two
      // always agree and do not depend on the platform default charset.
      final byte[] payload = "hello world".getBytes(StandardCharsets.UTF_8);
      try (ArrowBuf buf = allocator.buffer(16)) {
        buf.setBytes(0, payload);
        binHolder.start = 0;
        binHolder.end = payload.length;
        binHolder.buffer = buf;
        vector.setSafe(0, binHolder);
        vector.setSafe(1, nullHolder);
        // Without a value count the two slots written above are not part of the
        // vector, and the round trip would only compare empty vectors.
        vector.setValueCount(2);

        assertTrue(roundtrip(vector, LargeVarBinaryVector.class));
      }
    }
  }

  @Test
  public void testLargeVarCharVector() {
    try (final LargeVarCharVector vector = new LargeVarCharVector("v", allocator)) {
      setVector(vector, "abc", "def", null);
      assertTrue(roundtrip(vector, LargeVarCharVector.class));
    }
  }

  @Test
  public void testListVector() {
    try (final ListVector vector = ListVector.empty("v", allocator)) {
      setVector(vector, Arrays.stream(new int[] { 1, 2 }).boxed().collect(Collectors.toList()),
          Arrays.stream(new int[] { 3, 4 }).boxed().collect(Collectors.toList()), new ArrayList<Integer>());
      assertTrue(roundtrip(vector, ListVector.class));
    }
  }

  @Test
  public void testLargeListVector() {
    try (final LargeListVector vector = LargeListVector.empty("v", allocator)) {
      setVector(vector, Arrays.stream(new int[] { 1, 2 }).boxed().collect(Collectors.toList()),
          Arrays.stream(new int[] { 3, 4 }).boxed().collect(Collectors.toList()), new ArrayList<Integer>());
      assertTrue(roundtrip(vector, LargeListVector.class));
    }
  }

  @Test
  public void testFixedSizeListVector() {
    try (final FixedSizeListVector vector = FixedSizeListVector.empty("v", 2, allocator)) {
      setVector(vector, Arrays.stream(new int[] { 1, 2 }).boxed().collect(Collectors.toList()),
          Arrays.stream(new int[] { 3, 4 }).boxed().collect(Collectors.toList()));
      assertTrue(roundtrip(vector, FixedSizeListVector.class));
    }
  }

  @Test
  public void testMapVector() {
    int count = 5;
    try (final MapVector vector = MapVector.empty("v", allocator, false)) {
      vector.allocateNew();
      UnionMapWriter mapWriter = vector.getWriter();
      // Map i holds i + 1 entries, each mapping j -> j.
      for (int i = 0; i < count; i++) {
        mapWriter.startMap();
        for (int j = 0; j < i + 1; j++) {
          mapWriter.startEntry();
          mapWriter.key().bigInt().writeBigInt(j);
          mapWriter.value().integer().writeInt(j);
          mapWriter.endEntry();
        }
        mapWriter.endMap();
      }
      mapWriter.setValueCount(count);

      assertTrue(roundtrip(vector, MapVector.class));
    }
  }

  @Test
  public void testUnionVector() {
    final NullableUInt4Holder uInt4Holder = new NullableUInt4Holder();
    uInt4Holder.value = 100;
    uInt4Holder.isSet = 1;

    try (UnionVector vector = UnionVector.empty("v", allocator)) {
      vector.allocateNew();

      // write some data
      vector.setType(0, MinorType.UINT4);
      vector.setSafe(0, uInt4Holder);
      vector.setType(2, MinorType.UINT4);
      vector.setSafe(2, uInt4Holder);
      vector.setValueCount(4);

      assertTrue(roundtrip(vector, UnionVector.class));
    }
  }

  @Test
  public void testStructVector() {
    try (final StructVector vector = StructVector.empty("v", allocator)) {
      Map<String, List<Integer>> data = new HashMap<>();
      data.put("col_1", Arrays.stream(new int[] { 1, 2 }).boxed().collect(Collectors.toList()));
      data.put("col_2", Arrays.stream(new int[] { 3, 4 }).boxed().collect(Collectors.toList()));
      setVector(vector, data);
      assertTrue(roundtrip(vector, StructVector.class));
    }
  }

  @Test
  public void testExtensionTypeVector() {
    ExtensionTypeRegistry.register(new UuidType());
    final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new UuidType())));
    try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
      // Fill with data
      UUID u1 = UUID.randomUUID();
      UUID u2 = UUID.randomUUID();
      UuidVector vector = (UuidVector) root.getVector("a");
      vector.setValueCount(2);
      vector.set(0, u1);
      vector.set(1, u2);
      root.setRowCount(2);

      // Roundtrip (export + import). try-with-resources releases the imported
      // root even when one of the assertions below fails.
      try (VectorSchemaRoot importedRoot = vectorSchemaRootRoundtrip(root)) {
        // Verify correctness
        assertEquals(root.getSchema(), importedRoot.getSchema());

        final Field field = importedRoot.getSchema().getFields().get(0);
        final UuidType expectedType = new UuidType();
        // Expected value first, so a failure message labels the two values correctly.
        assertEquals(expectedType.extensionName(),
            field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_NAME));
        assertEquals(expectedType.serialize(),
            field.getMetadata().get(ExtensionType.EXTENSION_METADATA_KEY_METADATA));

        final UuidVector deserialized = (UuidVector) importedRoot.getFieldVectors().get(0);
        assertEquals(vector.getValueCount(), deserialized.getValueCount());
        for (int i = 0; i < vector.getValueCount(); i++) {
          assertEquals(vector.isNull(i), deserialized.isNull(i));
          if (!vector.isNull(i)) {
            assertEquals(vector.getObject(i), deserialized.getObject(i));
          }
        }
      }
    }
  }

  @Test
  public void testVectorSchemaRoot() {
    VectorSchemaRoot imported;

    // Consumer allocates empty structures
    try (ArrowSchema consumerArrowSchema = ArrowSchema.allocateNew(allocator);
        ArrowArray consumerArrowArray = ArrowArray.allocateNew(allocator)) {
      try (VectorSchemaRoot vsr = createTestVSR()) {
        // Producer creates structures from existing memory pointers
        try (ArrowSchema arrowSchema = ArrowSchema.wrap(consumerArrowSchema.memoryAddress());
            ArrowArray arrowArray = ArrowArray.wrap(consumerArrowArray.memoryAddress())) {
          // Producer exports vector into the C Data Interface structures
          Data.exportVectorSchemaRoot(allocator, vsr, null, arrowArray, arrowSchema);
        }
      }
      // Consumer imports vector
      imported = Data.importVectorSchemaRoot(allocator, consumerArrowArray, consumerArrowSchema, null);
    }

    // Ensure that imported VectorSchemaRoot is valid even after C Data Interface
    // structures are closed
    try (VectorSchemaRoot original = createTestVSR()) {
      assertTrue(imported.equals(original));
    }
    imported.close();
  }

  @Test
  public void testSchema() {
    Field decimalField = new Field("inner1", FieldType.nullable(new ArrowType.Decimal(19, 4, 128)), null);
    Field strField = new Field("inner2", FieldType.nullable(new ArrowType.Utf8()), null);
    Field itemField = new Field("col1", FieldType.nullable(new ArrowType.Struct()),
        Arrays.asList(decimalField, strField));
    Field intField = new Field("col2", FieldType.nullable(new ArrowType.Int(32, true)), null);
    Schema schema = new Schema(Arrays.asList(itemField, intField));
    // Consumer allocates empty ArrowSchema
    try (ArrowSchema consumerArrowSchema = ArrowSchema.allocateNew(allocator)) {
      // Producer fills the schema with data
      try (ArrowSchema arrowSchema = ArrowSchema.wrap(consumerArrowSchema.memoryAddress())) {
        Data.exportSchema(allocator, schema, null, arrowSchema);
      }
      // Consumer imports schema
      Schema importedSchema = Data.importSchema(allocator, consumerArrowSchema, null);
      assertEquals(schema.toJson(), importedSchema.toJson());
    }
  }

  @Test
  public void testImportReleasedArray() {
    // Consumer allocates empty structures
    try (ArrowSchema consumerArrowSchema = ArrowSchema.allocateNew(allocator);
        ArrowArray consumerArrowArray = ArrowArray.allocateNew(allocator)) {
      // Producer creates structures from existing memory pointers
      try (ArrowSchema arrowSchema = ArrowSchema.wrap(consumerArrowSchema.memoryAddress());
          ArrowArray arrowArray = ArrowArray.wrap(consumerArrowArray.memoryAddress())) {
        // Producer exports vector into the C Data Interface structures
        try (final NullVector vector = new NullVector()) {
          Data.exportVector(allocator, vector, null, arrowArray, arrowSchema);
        }
      }

      // Release array structure
      consumerArrowArray.markReleased();

      // Consumer tries to import the vector, which must fail
      Exception e = assertThrows(IllegalStateException.class, () -> {
        Data.importVector(allocator, consumerArrowArray, consumerArrowSchema, null);
      });

      assertEquals("Cannot import released ArrowArray", e.getMessage());
    }
  }

  /**
   * Builds a two-column root with ten rows: a bit column alternating 0/1 and a
   * varchar column ("test0" .. "test9") whose field carries one metadata entry.
   * The caller owns the returned root and must close it.
   */
  private VectorSchemaRoot createTestVSR() {
    BitVector bitVector = new BitVector("boolean", allocator);

    Map<String, String> metadata = new HashMap<>();
    metadata.put("key", "value");
    FieldType fieldType = new FieldType(true, ArrowType.Utf8.INSTANCE, null, metadata);
    VarCharVector varCharVector = new VarCharVector("varchar", fieldType, allocator);

    bitVector.allocateNew();
    varCharVector.allocateNew();
    for (int i = 0; i < 10; i++) {
      bitVector.setSafe(i, i % 2 == 0 ? 0 : 1);
      varCharVector.setSafe(i, ("test" + i).getBytes(StandardCharsets.UTF_8));
    }
    bitVector.setValueCount(10);
    varCharVector.setValueCount(10);

    List<Field> fields = Arrays.asList(bitVector.getField(), varCharVector.getField());
    List<FieldVector> vectors = Arrays.asList(bitVector, varCharVector);

    return new VectorSchemaRoot(fields, vectors);
  }

  /** Extension type named "uuid", stored as a 16-byte fixed-size binary with empty serialized metadata. */
  static class UuidType extends ExtensionType {

    @Override
    public ArrowType storageType() {
      return new ArrowType.FixedSizeBinary(16);
    }

    @Override
    public String extensionName() {
      return "uuid";
    }

    @Override
    public boolean extensionEquals(ExtensionType other) {
      return other instanceof UuidType;
    }

    @Override
    public ArrowType deserialize(ArrowType storageType, String serializedData) {
      if (!storageType.equals(storageType())) {
        throw new UnsupportedOperationException("Cannot construct UuidType from underlying type " + storageType);
      }
      return new UuidType();
    }

    @Override
    public String serialize() {
      return "";
    }

    @Override
    public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocator allocator) {
      return new UuidVector(name, allocator, new FixedSizeBinaryVector(name, allocator, 16));
    }
  }

  /** Extension vector exposing 16-byte fixed-size binary values as {@link UUID} objects. */
  static class UuidVector extends ExtensionTypeVector<FixedSizeBinaryVector> {

    public UuidVector(String name, BufferAllocator allocator, FixedSizeBinaryVector underlyingVector) {
      super(name, allocator, underlyingVector);
    }

    @Override
    public UUID getObject(int index) {
      // First 8 bytes are the most significant bits, next 8 the least significant (mirrors set()).
      final ByteBuffer bb = ByteBuffer.wrap(getUnderlyingVector().getObject(index));
      return new UUID(bb.getLong(), bb.getLong());
    }

    @Override
    public int hashCode(int index) {
      return hashCode(index, null);
    }

    @Override
    public int hashCode(int index, ArrowBufHasher hasher) {
      return getUnderlyingVector().hashCode(index, hasher);
    }

    public void set(int index, UUID uuid) {
      ByteBuffer bb = ByteBuffer.allocate(16);
      bb.putLong(uuid.getMostSignificantBits());
      bb.putLong(uuid.getLeastSignificantBits());
      getUnderlyingVector().set(index, bb.array());
    }
  }
}
|