summaryrefslogtreecommitdiffstats
path: root/src/arrow/cpp/src/gandiva/jni/jni_common.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/cpp/src/gandiva/jni/jni_common.cc')
-rw-r--r--src/arrow/cpp/src/gandiva/jni/jni_common.cc1055
1 files changed, 1055 insertions, 0 deletions
diff --git a/src/arrow/cpp/src/gandiva/jni/jni_common.cc b/src/arrow/cpp/src/gandiva/jni/jni_common.cc
new file mode 100644
index 000000000..5a4cbb031
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/jni_common.cc
@@ -0,0 +1,1055 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <google/protobuf/io/coded_stream.h>
+
+#include <map>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <arrow/builder.h>
+#include <arrow/record_batch.h>
+#include <arrow/type.h>
+
+#include "Types.pb.h"
+#include "gandiva/configuration.h"
+#include "gandiva/decimal_scalar.h"
+#include "gandiva/filter.h"
+#include "gandiva/jni/config_holder.h"
+#include "gandiva/jni/env_helper.h"
+#include "gandiva/jni/id_to_module_map.h"
+#include "gandiva/jni/module_holder.h"
+#include "gandiva/projector.h"
+#include "gandiva/selection_vector.h"
+#include "gandiva/tree_expr_builder.h"
+#include "jni/org_apache_arrow_gandiva_evaluator_JniWrapper.h"
+
+using gandiva::ConditionPtr;
+using gandiva::DataTypePtr;
+using gandiva::ExpressionPtr;
+using gandiva::ExpressionVector;
+using gandiva::FieldPtr;
+using gandiva::FieldVector;
+using gandiva::Filter;
+using gandiva::NodePtr;
+using gandiva::NodeVector;
+using gandiva::Projector;
+using gandiva::SchemaPtr;
+using gandiva::Status;
+using gandiva::TreeExprBuilder;
+
+using gandiva::ArrayDataVector;
+using gandiva::ConfigHolder;
+using gandiva::Configuration;
+using gandiva::ConfigurationBuilder;
+using gandiva::FilterHolder;
+using gandiva::ProjectorHolder;
+
+// forward declarations
+NodePtr ProtoTypeToNode(const types::TreeNode& node);
+
+static jint JNI_VERSION = JNI_VERSION_1_6;
+
+// extern refs - initialized for other modules.
+jclass configuration_builder_class_;
+
+// refs for self.
+static jclass gandiva_exception_;
+static jclass vector_expander_class_;
+static jclass vector_expander_ret_class_;
+static jmethodID vector_expander_method_;
+static jfieldID vector_expander_ret_address_;
+static jfieldID vector_expander_ret_capacity_;
+
+// module maps
+gandiva::IdToModuleMap<std::shared_ptr<ProjectorHolder>> projector_modules_;
+gandiva::IdToModuleMap<std::shared_ptr<FilterHolder>> filter_modules_;
+
+jint JNI_OnLoad(JavaVM* vm, void* reserved) {
+ JNIEnv* env;
+ if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) {
+ return JNI_ERR;
+ }
+ jclass local_configuration_builder_class_ =
+ env->FindClass("org/apache/arrow/gandiva/evaluator/ConfigurationBuilder");
+ configuration_builder_class_ =
+ (jclass)env->NewGlobalRef(local_configuration_builder_class_);
+ env->DeleteLocalRef(local_configuration_builder_class_);
+
+ jclass localExceptionClass =
+ env->FindClass("org/apache/arrow/gandiva/exceptions/GandivaException");
+ gandiva_exception_ = (jclass)env->NewGlobalRef(localExceptionClass);
+ env->ExceptionDescribe();
+ env->DeleteLocalRef(localExceptionClass);
+
+ jclass local_expander_class =
+ env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander");
+ vector_expander_class_ = (jclass)env->NewGlobalRef(local_expander_class);
+ env->DeleteLocalRef(local_expander_class);
+
+ vector_expander_method_ = env->GetMethodID(
+ vector_expander_class_, "expandOutputVectorAtIndex",
+ "(IJ)Lorg/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult;");
+
+ jclass local_expander_ret_class =
+ env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult");
+ vector_expander_ret_class_ = (jclass)env->NewGlobalRef(local_expander_ret_class);
+ env->DeleteLocalRef(local_expander_ret_class);
+
+ vector_expander_ret_address_ =
+ env->GetFieldID(vector_expander_ret_class_, "address", "J");
+ vector_expander_ret_capacity_ =
+ env->GetFieldID(vector_expander_ret_class_, "capacity", "J");
+ return JNI_VERSION;
+}
+
+void JNI_OnUnload(JavaVM* vm, void* reserved) {
+ JNIEnv* env;
+ vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION);
+ env->DeleteGlobalRef(configuration_builder_class_);
+ env->DeleteGlobalRef(gandiva_exception_);
+ env->DeleteGlobalRef(vector_expander_class_);
+ env->DeleteGlobalRef(vector_expander_ret_class_);
+}
+
+DataTypePtr ProtoTypeToTime32(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.timeunit()) {
+ case types::SEC:
+ return arrow::time32(arrow::TimeUnit::SECOND);
+ case types::MILLISEC:
+ return arrow::time32(arrow::TimeUnit::MILLI);
+ default:
+ std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time32\n";
+ return nullptr;
+ }
+}
+
+DataTypePtr ProtoTypeToTime64(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.timeunit()) {
+ case types::MICROSEC:
+ return arrow::time64(arrow::TimeUnit::MICRO);
+ case types::NANOSEC:
+ return arrow::time64(arrow::TimeUnit::NANO);
+ default:
+ std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time64\n";
+ return nullptr;
+ }
+}
+
+DataTypePtr ProtoTypeToTimestamp(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.timeunit()) {
+ case types::SEC:
+ return arrow::timestamp(arrow::TimeUnit::SECOND);
+ case types::MILLISEC:
+ return arrow::timestamp(arrow::TimeUnit::MILLI);
+ case types::MICROSEC:
+ return arrow::timestamp(arrow::TimeUnit::MICRO);
+ case types::NANOSEC:
+ return arrow::timestamp(arrow::TimeUnit::NANO);
+ default:
+ std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for timestamp\n";
+ return nullptr;
+ }
+}
+
+DataTypePtr ProtoTypeToInterval(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.intervaltype()) {
+ case types::YEAR_MONTH:
+ return arrow::month_interval();
+ case types::DAY_TIME:
+ return arrow::day_time_interval();
+ default:
+ std::cerr << "Unknown interval type: " << ext_type.intervaltype() << "\n";
+ return nullptr;
+ }
+}
+
+DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) {
+ switch (ext_type.type()) {
+ case types::NONE:
+ return arrow::null();
+ case types::BOOL:
+ return arrow::boolean();
+ case types::UINT8:
+ return arrow::uint8();
+ case types::INT8:
+ return arrow::int8();
+ case types::UINT16:
+ return arrow::uint16();
+ case types::INT16:
+ return arrow::int16();
+ case types::UINT32:
+ return arrow::uint32();
+ case types::INT32:
+ return arrow::int32();
+ case types::UINT64:
+ return arrow::uint64();
+ case types::INT64:
+ return arrow::int64();
+ case types::HALF_FLOAT:
+ return arrow::float16();
+ case types::FLOAT:
+ return arrow::float32();
+ case types::DOUBLE:
+ return arrow::float64();
+ case types::UTF8:
+ return arrow::utf8();
+ case types::BINARY:
+ return arrow::binary();
+ case types::DATE32:
+ return arrow::date32();
+ case types::DATE64:
+ return arrow::date64();
+ case types::DECIMAL:
+ // TODO: error handling
+ return arrow::decimal(ext_type.precision(), ext_type.scale());
+ case types::TIME32:
+ return ProtoTypeToTime32(ext_type);
+ case types::TIME64:
+ return ProtoTypeToTime64(ext_type);
+ case types::TIMESTAMP:
+ return ProtoTypeToTimestamp(ext_type);
+ case types::INTERVAL:
+ return ProtoTypeToInterval(ext_type);
+ case types::FIXED_SIZE_BINARY:
+ case types::LIST:
+ case types::STRUCT:
+ case types::UNION:
+ case types::DICTIONARY:
+ case types::MAP:
+ std::cerr << "Unhandled data type: " << ext_type.type() << "\n";
+ return nullptr;
+
+ default:
+ std::cerr << "Unknown data type: " << ext_type.type() << "\n";
+ return nullptr;
+ }
+}
+
+FieldPtr ProtoTypeToField(const types::Field& f) {
+ const std::string& name = f.name();
+ DataTypePtr type = ProtoTypeToDataType(f.type());
+ bool nullable = true;
+ if (f.has_nullable()) {
+ nullable = f.nullable();
+ }
+
+ return field(name, type, nullable);
+}
+
+NodePtr ProtoTypeToFieldNode(const types::FieldNode& node) {
+ FieldPtr field_ptr = ProtoTypeToField(node.field());
+ if (field_ptr == nullptr) {
+ std::cerr << "Unable to create field node from protobuf\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeField(field_ptr);
+}
+
+NodePtr ProtoTypeToFnNode(const types::FunctionNode& node) {
+ const std::string& name = node.functionname();
+ NodeVector children;
+
+ for (int i = 0; i < node.inargs_size(); i++) {
+ const types::TreeNode& arg = node.inargs(i);
+
+ NodePtr n = ProtoTypeToNode(arg);
+ if (n == nullptr) {
+ std::cerr << "Unable to create argument for function: " << name << "\n";
+ return nullptr;
+ }
+
+ children.push_back(n);
+ }
+
+ DataTypePtr return_type = ProtoTypeToDataType(node.returntype());
+ if (return_type == nullptr) {
+ std::cerr << "Unknown return type for function: " << name << "\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeFunction(name, children, return_type);
+}
+
+NodePtr ProtoTypeToIfNode(const types::IfNode& node) {
+ NodePtr cond = ProtoTypeToNode(node.cond());
+ if (cond == nullptr) {
+ std::cerr << "Unable to create cond node for if node\n";
+ return nullptr;
+ }
+
+ NodePtr then_node = ProtoTypeToNode(node.thennode());
+ if (then_node == nullptr) {
+ std::cerr << "Unable to create then node for if node\n";
+ return nullptr;
+ }
+
+ NodePtr else_node = ProtoTypeToNode(node.elsenode());
+ if (else_node == nullptr) {
+ std::cerr << "Unable to create else node for if node\n";
+ return nullptr;
+ }
+
+ DataTypePtr return_type = ProtoTypeToDataType(node.returntype());
+ if (return_type == nullptr) {
+ std::cerr << "Unknown return type for if node\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeIf(cond, then_node, else_node, return_type);
+}
+
+NodePtr ProtoTypeToAndNode(const types::AndNode& node) {
+ NodeVector children;
+
+ for (int i = 0; i < node.args_size(); i++) {
+ const types::TreeNode& arg = node.args(i);
+
+ NodePtr n = ProtoTypeToNode(arg);
+ if (n == nullptr) {
+ std::cerr << "Unable to create argument for boolean and\n";
+ return nullptr;
+ }
+ children.push_back(n);
+ }
+ return TreeExprBuilder::MakeAnd(children);
+}
+
+NodePtr ProtoTypeToOrNode(const types::OrNode& node) {
+ NodeVector children;
+
+ for (int i = 0; i < node.args_size(); i++) {
+ const types::TreeNode& arg = node.args(i);
+
+ NodePtr n = ProtoTypeToNode(arg);
+ if (n == nullptr) {
+ std::cerr << "Unable to create argument for boolean or\n";
+ return nullptr;
+ }
+ children.push_back(n);
+ }
+ return TreeExprBuilder::MakeOr(children);
+}
+
+NodePtr ProtoTypeToInNode(const types::InNode& node) {
+ NodePtr field = ProtoTypeToNode(node.node());
+
+ if (node.has_intvalues()) {
+ std::unordered_set<int32_t> int_values;
+ for (int i = 0; i < node.intvalues().intvalues_size(); i++) {
+ int_values.insert(node.intvalues().intvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionInt32(field, int_values);
+ }
+
+ if (node.has_longvalues()) {
+ std::unordered_set<int64_t> long_values;
+ for (int i = 0; i < node.longvalues().longvalues_size(); i++) {
+ long_values.insert(node.longvalues().longvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionInt64(field, long_values);
+ }
+
+ if (node.has_decimalvalues()) {
+ std::unordered_set<gandiva::DecimalScalar128> decimal_values;
+ for (int i = 0; i < node.decimalvalues().decimalvalues_size(); i++) {
+ decimal_values.insert(
+ gandiva::DecimalScalar128(node.decimalvalues().decimalvalues(i).value(),
+ node.decimalvalues().decimalvalues(i).precision(),
+ node.decimalvalues().decimalvalues(i).scale()));
+ }
+ return TreeExprBuilder::MakeInExpressionDecimal(field, decimal_values);
+ }
+
+ if (node.has_floatvalues()) {
+ std::unordered_set<float> float_values;
+ for (int i = 0; i < node.floatvalues().floatvalues_size(); i++) {
+ float_values.insert(node.floatvalues().floatvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionFloat(field, float_values);
+ }
+
+ if (node.has_doublevalues()) {
+ std::unordered_set<double> double_values;
+ for (int i = 0; i < node.doublevalues().doublevalues_size(); i++) {
+ double_values.insert(node.doublevalues().doublevalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionDouble(field, double_values);
+ }
+
+ if (node.has_stringvalues()) {
+ std::unordered_set<std::string> stringvalues;
+ for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) {
+ stringvalues.insert(node.stringvalues().stringvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionString(field, stringvalues);
+ }
+
+ if (node.has_binaryvalues()) {
+ std::unordered_set<std::string> stringvalues;
+ for (int i = 0; i < node.binaryvalues().binaryvalues_size(); i++) {
+ stringvalues.insert(node.binaryvalues().binaryvalues(i).value());
+ }
+ return TreeExprBuilder::MakeInExpressionBinary(field, stringvalues);
+ }
+ // not supported yet.
+ std::cerr << "Unknown constant type for in expression.\n";
+ return nullptr;
+}
+
+NodePtr ProtoTypeToNullNode(const types::NullNode& node) {
+ DataTypePtr data_type = ProtoTypeToDataType(node.type());
+ if (data_type == nullptr) {
+ std::cerr << "Unknown type " << data_type->ToString() << " for null node\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeNull(data_type);
+}
+
+NodePtr ProtoTypeToNode(const types::TreeNode& node) {
+ if (node.has_fieldnode()) {
+ return ProtoTypeToFieldNode(node.fieldnode());
+ }
+
+ if (node.has_fnnode()) {
+ return ProtoTypeToFnNode(node.fnnode());
+ }
+
+ if (node.has_ifnode()) {
+ return ProtoTypeToIfNode(node.ifnode());
+ }
+
+ if (node.has_andnode()) {
+ return ProtoTypeToAndNode(node.andnode());
+ }
+
+ if (node.has_ornode()) {
+ return ProtoTypeToOrNode(node.ornode());
+ }
+
+ if (node.has_innode()) {
+ return ProtoTypeToInNode(node.innode());
+ }
+
+ if (node.has_nullnode()) {
+ return ProtoTypeToNullNode(node.nullnode());
+ }
+
+ if (node.has_intnode()) {
+ return TreeExprBuilder::MakeLiteral(node.intnode().value());
+ }
+
+ if (node.has_floatnode()) {
+ return TreeExprBuilder::MakeLiteral(node.floatnode().value());
+ }
+
+ if (node.has_longnode()) {
+ return TreeExprBuilder::MakeLiteral(node.longnode().value());
+ }
+
+ if (node.has_booleannode()) {
+ return TreeExprBuilder::MakeLiteral(node.booleannode().value());
+ }
+
+ if (node.has_doublenode()) {
+ return TreeExprBuilder::MakeLiteral(node.doublenode().value());
+ }
+
+ if (node.has_stringnode()) {
+ return TreeExprBuilder::MakeStringLiteral(node.stringnode().value());
+ }
+
+ if (node.has_binarynode()) {
+ return TreeExprBuilder::MakeBinaryLiteral(node.binarynode().value());
+ }
+
+ if (node.has_decimalnode()) {
+ std::string value = node.decimalnode().value();
+ gandiva::DecimalScalar128 literal(value, node.decimalnode().precision(),
+ node.decimalnode().scale());
+ return TreeExprBuilder::MakeDecimalLiteral(literal);
+ }
+ std::cerr << "Unknown node type in protobuf\n";
+ return nullptr;
+}
+
+ExpressionPtr ProtoTypeToExpression(const types::ExpressionRoot& root) {
+ NodePtr root_node = ProtoTypeToNode(root.root());
+ if (root_node == nullptr) {
+ std::cerr << "Unable to create expression node from expression protobuf\n";
+ return nullptr;
+ }
+
+ FieldPtr field = ProtoTypeToField(root.resulttype());
+ if (field == nullptr) {
+ std::cerr << "Unable to extra return field from expression protobuf\n";
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeExpression(root_node, field);
+}
+
+ConditionPtr ProtoTypeToCondition(const types::Condition& condition) {
+ NodePtr root_node = ProtoTypeToNode(condition.root());
+ if (root_node == nullptr) {
+ return nullptr;
+ }
+
+ return TreeExprBuilder::MakeCondition(root_node);
+}
+
+SchemaPtr ProtoTypeToSchema(const types::Schema& schema) {
+ std::vector<FieldPtr> fields;
+
+ for (int i = 0; i < schema.columns_size(); i++) {
+ FieldPtr field = ProtoTypeToField(schema.columns(i));
+ if (field == nullptr) {
+ std::cerr << "Unable to extract arrow field from schema\n";
+ return nullptr;
+ }
+
+ fields.push_back(field);
+ }
+
+ return arrow::schema(fields);
+}
+
+// Common for both projector and filters.
+
+bool ParseProtobuf(uint8_t* buf, int bufLen, google::protobuf::Message* msg) {
+ google::protobuf::io::CodedInputStream cis(buf, bufLen);
+ cis.SetRecursionLimit(1000);
+ return msg->ParseFromCodedStream(&cis);
+}
+
+Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows,
+ jlong* in_buf_addrs, jlong* in_buf_sizes,
+ int in_bufs_len,
+ std::shared_ptr<arrow::RecordBatch>* batch) {
+ std::vector<std::shared_ptr<arrow::ArrayData>> columns;
+ auto num_fields = schema->num_fields();
+ int buf_idx = 0;
+ int sz_idx = 0;
+
+ for (int i = 0; i < num_fields; i++) {
+ auto field = schema->field(i);
+ std::vector<std::shared_ptr<arrow::Buffer>> buffers;
+
+ if (buf_idx >= in_bufs_len) {
+ return Status::Invalid("insufficient number of in_buf_addrs");
+ }
+ jlong validity_addr = in_buf_addrs[buf_idx++];
+ jlong validity_size = in_buf_sizes[sz_idx++];
+ auto validity = std::shared_ptr<arrow::Buffer>(
+ new arrow::Buffer(reinterpret_cast<uint8_t*>(validity_addr), validity_size));
+ buffers.push_back(validity);
+
+ if (buf_idx >= in_bufs_len) {
+ return Status::Invalid("insufficient number of in_buf_addrs");
+ }
+ jlong value_addr = in_buf_addrs[buf_idx++];
+ jlong value_size = in_buf_sizes[sz_idx++];
+ auto data = std::shared_ptr<arrow::Buffer>(
+ new arrow::Buffer(reinterpret_cast<uint8_t*>(value_addr), value_size));
+ buffers.push_back(data);
+
+ if (arrow::is_binary_like(field->type()->id())) {
+ if (buf_idx >= in_bufs_len) {
+ return Status::Invalid("insufficient number of in_buf_addrs");
+ }
+
+ // add offsets buffer for variable-len fields.
+ jlong offsets_addr = in_buf_addrs[buf_idx++];
+ jlong offsets_size = in_buf_sizes[sz_idx++];
+ auto offsets = std::shared_ptr<arrow::Buffer>(
+ new arrow::Buffer(reinterpret_cast<uint8_t*>(offsets_addr), offsets_size));
+ buffers.push_back(offsets);
+ }
+
+ auto array_data = arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers));
+ columns.push_back(array_data);
+ }
+ *batch = arrow::RecordBatch::Make(schema, num_rows, columns);
+ return Status::OK();
+}
+
+// projector related functions.
+void releaseProjectorInput(jbyteArray schema_arr, jbyte* schema_bytes,
+ jbyteArray exprs_arr, jbyte* exprs_bytes, JNIEnv* env) {
+ env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
+ env->ReleaseByteArrayElements(exprs_arr, exprs_bytes, JNI_ABORT);
+}
+
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildProjector(
+ JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray exprs_arr,
+ jint selection_vector_type, jlong configuration_id) {
+ jlong module_id = 0LL;
+ std::shared_ptr<Projector> projector;
+ std::shared_ptr<ProjectorHolder> holder;
+
+ types::Schema schema;
+ jsize schema_len = env->GetArrayLength(schema_arr);
+ jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0);
+
+ types::ExpressionList exprs;
+ jsize exprs_len = env->GetArrayLength(exprs_arr);
+ jbyte* exprs_bytes = env->GetByteArrayElements(exprs_arr, 0);
+
+ ExpressionVector expr_vector;
+ SchemaPtr schema_ptr;
+ FieldVector ret_types;
+ gandiva::Status status;
+ auto mode = gandiva::SelectionVector::MODE_NONE;
+
+ std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id);
+ std::stringstream ss;
+
+ if (config == nullptr) {
+ ss << "configuration is mandatory.";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) {
+ ss << "Unable to parse schema protobuf\n";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ if (!ParseProtobuf(reinterpret_cast<uint8_t*>(exprs_bytes), exprs_len, &exprs)) {
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ ss << "Unable to parse expressions protobuf\n";
+ goto err_out;
+ }
+
+ // convert types::Schema to arrow::Schema
+ schema_ptr = ProtoTypeToSchema(schema);
+ if (schema_ptr == nullptr) {
+ ss << "Unable to construct arrow schema object from schema protobuf\n";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ // create Expression out of the list of exprs
+ for (int i = 0; i < exprs.exprs_size(); i++) {
+ ExpressionPtr root = ProtoTypeToExpression(exprs.exprs(i));
+
+ if (root == nullptr) {
+ ss << "Unable to construct expression object from expression protobuf\n";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ expr_vector.push_back(root);
+ ret_types.push_back(root->result());
+ }
+
+ switch (selection_vector_type) {
+ case types::SV_NONE:
+ mode = gandiva::SelectionVector::MODE_NONE;
+ break;
+ case types::SV_INT16:
+ mode = gandiva::SelectionVector::MODE_UINT16;
+ break;
+ case types::SV_INT32:
+ mode = gandiva::SelectionVector::MODE_UINT32;
+ break;
+ }
+ // good to invoke the evaluator now
+ status = Projector::Make(schema_ptr, expr_vector, mode, config, &projector);
+
+ if (!status.ok()) {
+ ss << "Failed to make LLVM module due to " << status.message() << "\n";
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ goto err_out;
+ }
+
+ // store the result in a map
+ holder = std::shared_ptr<ProjectorHolder>(
+ new ProjectorHolder(schema_ptr, ret_types, std::move(projector)));
+ module_id = projector_modules_.Insert(holder);
+ releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env);
+ return module_id;
+
+err_out:
+ env->ThrowNew(gandiva_exception_, ss.str().c_str());
+ return module_id;
+}
+
+///
+/// \brief Resizable buffer which resizes by doing a callback into java.
+///
+class JavaResizableBuffer : public arrow::ResizableBuffer {
+ public:
+ JavaResizableBuffer(JNIEnv* env, jobject jexpander, int32_t vector_idx, uint8_t* buffer,
+ int32_t len)
+ : ResizableBuffer(buffer, len),
+ env_(env),
+ jexpander_(jexpander),
+ vector_idx_(vector_idx) {
+ size_ = 0;
+ }
+
+ Status Resize(const int64_t new_size, bool shrink_to_fit) override;
+
+ Status Reserve(const int64_t new_capacity) override {
+ return Status::NotImplemented("reserve not implemented");
+ }
+
+ private:
+ JNIEnv* env_;
+ jobject jexpander_;
+ int32_t vector_idx_;
+};
+
+Status JavaResizableBuffer::Resize(const int64_t new_size, bool shrink_to_fit) {
+ if (shrink_to_fit == true) {
+ return Status::NotImplemented("shrink not implemented");
+ }
+
+ if (ARROW_PREDICT_TRUE(new_size < capacity())) {
+ // no need to expand.
+ size_ = new_size;
+ return Status::OK();
+ }
+
+ // callback into java to expand the buffer
+ jobject ret =
+ env_->CallObjectMethod(jexpander_, vector_expander_method_, vector_idx_, new_size);
+ if (env_->ExceptionCheck()) {
+ env_->ExceptionDescribe();
+ env_->ExceptionClear();
+ return Status::OutOfMemory("buffer expand failed in java");
+ }
+
+ jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_);
+ jlong ret_capacity = env_->GetLongField(ret, vector_expander_ret_capacity_);
+ DCHECK_GE(ret_capacity, new_size);
+
+ data_ = reinterpret_cast<uint8_t*>(ret_address);
+ size_ = new_size;
+ capacity_ = ret_capacity;
+ return Status::OK();
+}
+
+#define CHECK_OUT_BUFFER_IDX_AND_BREAK(idx, len) \
+ if (idx >= len) { \
+ status = gandiva::Status::Invalid("insufficient number of out_buf_addrs"); \
+ break; \
+ }
+
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector(
+ JNIEnv* env, jobject object, jobject jexpander, jlong module_id, jint num_rows,
+ jlongArray buf_addrs, jlongArray buf_sizes, jint sel_vec_type, jint sel_vec_rows,
+ jlong sel_vec_addr, jlong sel_vec_size, jlongArray out_buf_addrs,
+ jlongArray out_buf_sizes) {
+ Status status;
+ std::shared_ptr<ProjectorHolder> holder = projector_modules_.Lookup(module_id);
+ if (holder == nullptr) {
+ std::stringstream ss;
+ ss << "Unknown module id " << module_id;
+ env->ThrowNew(gandiva_exception_, ss.str().c_str());
+ return;
+ }
+
+ int in_bufs_len = env->GetArrayLength(buf_addrs);
+ if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
+ env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes");
+ return;
+ }
+
+ int out_bufs_len = env->GetArrayLength(out_buf_addrs);
+ if (out_bufs_len != env->GetArrayLength(out_buf_sizes)) {
+ env->ThrowNew(gandiva_exception_,
+ "mismatch in arraylen of out_buf_addrs and out_buf_sizes");
+ return;
+ }
+
+ jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0);
+ jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0);
+
+ jlong* out_bufs = env->GetLongArrayElements(out_buf_addrs, 0);
+ jlong* out_sizes = env->GetLongArrayElements(out_buf_sizes, 0);
+
+ do {
+ std::shared_ptr<arrow::RecordBatch> in_batch;
+ status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs,
+ in_buf_sizes, in_bufs_len, &in_batch);
+ if (!status.ok()) {
+ break;
+ }
+
+ std::shared_ptr<gandiva::SelectionVector> selection_vector;
+ auto selection_buffer = std::make_shared<arrow::Buffer>(
+ reinterpret_cast<uint8_t*>(sel_vec_addr), sel_vec_size);
+ int output_row_count = 0;
+ switch (sel_vec_type) {
+ case types::SV_NONE: {
+ output_row_count = num_rows;
+ break;
+ }
+ case types::SV_INT16: {
+ status = gandiva::SelectionVector::MakeImmutableInt16(
+ sel_vec_rows, selection_buffer, &selection_vector);
+ output_row_count = sel_vec_rows;
+ break;
+ }
+ case types::SV_INT32: {
+ status = gandiva::SelectionVector::MakeImmutableInt32(
+ sel_vec_rows, selection_buffer, &selection_vector);
+ output_row_count = sel_vec_rows;
+ break;
+ }
+ }
+ if (!status.ok()) {
+ break;
+ }
+
+ auto ret_types = holder->rettypes();
+ ArrayDataVector output;
+ int buf_idx = 0;
+ int sz_idx = 0;
+ int output_vector_idx = 0;
+ for (FieldPtr field : ret_types) {
+ std::vector<std::shared_ptr<arrow::Buffer>> buffers;
+
+ CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
+ uint8_t* validity_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
+ jlong bitmap_sz = out_sizes[sz_idx++];
+ buffers.push_back(std::make_shared<arrow::MutableBuffer>(validity_buf, bitmap_sz));
+
+ if (arrow::is_binary_like(field->type()->id())) {
+ CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
+ uint8_t* offsets_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
+ jlong offsets_sz = out_sizes[sz_idx++];
+ buffers.push_back(
+ std::make_shared<arrow::MutableBuffer>(offsets_buf, offsets_sz));
+ }
+
+ CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len);
+ uint8_t* value_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]);
+ jlong data_sz = out_sizes[sz_idx++];
+ if (arrow::is_binary_like(field->type()->id())) {
+ if (jexpander == nullptr) {
+ status = Status::Invalid(
+ "expression has variable len output columns, but the expander object is "
+ "null");
+ break;
+ }
+ buffers.push_back(std::make_shared<JavaResizableBuffer>(
+ env, jexpander, output_vector_idx, value_buf, data_sz));
+ } else {
+ buffers.push_back(std::make_shared<arrow::MutableBuffer>(value_buf, data_sz));
+ }
+
+ auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers);
+ output.push_back(array_data);
+ ++output_vector_idx;
+ }
+ if (!status.ok()) {
+ break;
+ }
+ status = holder->projector()->Evaluate(*in_batch, selection_vector.get(), output);
+ } while (0);
+
+ env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
+ env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
+ env->ReleaseLongArrayElements(out_buf_addrs, out_bufs, JNI_ABORT);
+ env->ReleaseLongArrayElements(out_buf_sizes, out_sizes, JNI_ABORT);
+
+ if (!status.ok()) {
+ std::stringstream ss;
+ ss << "Evaluate returned " << status.message() << "\n";
+ env->ThrowNew(gandiva_exception_, status.message().c_str());
+ return;
+ }
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeProjector(
+ JNIEnv* env, jobject cls, jlong module_id) {
+ projector_modules_.Erase(module_id);
+}
+
+// filter related functions.
+void releaseFilterInput(jbyteArray schema_arr, jbyte* schema_bytes,
+ jbyteArray condition_arr, jbyte* condition_bytes, JNIEnv* env) {
+ env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT);
+ env->ReleaseByteArrayElements(condition_arr, condition_bytes, JNI_ABORT);
+}
+
+JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildFilter(
+ JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray condition_arr,
+ jlong configuration_id) {
+ jlong module_id = 0LL;
+ std::shared_ptr<Filter> filter;
+ std::shared_ptr<FilterHolder> holder;
+
+ types::Schema schema;
+ jsize schema_len = env->GetArrayLength(schema_arr);
+ jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0);
+
+ types::Condition condition;
+ jsize condition_len = env->GetArrayLength(condition_arr);
+ jbyte* condition_bytes = env->GetByteArrayElements(condition_arr, 0);
+
+ ConditionPtr condition_ptr;
+ SchemaPtr schema_ptr;
+ gandiva::Status status;
+
+ std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id);
+ std::stringstream ss;
+
+ if (config == nullptr) {
+ ss << "configuration is mandatory.";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) {
+ ss << "Unable to parse schema protobuf\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ if (!ParseProtobuf(reinterpret_cast<uint8_t*>(condition_bytes), condition_len,
+ &condition)) {
+ ss << "Unable to parse condition protobuf\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ // convert types::Schema to arrow::Schema
+ schema_ptr = ProtoTypeToSchema(schema);
+ if (schema_ptr == nullptr) {
+ ss << "Unable to construct arrow schema object from schema protobuf\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ condition_ptr = ProtoTypeToCondition(condition);
+ if (condition_ptr == nullptr) {
+ ss << "Unable to construct condition object from condition protobuf\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ // good to invoke the filter builder now
+ status = Filter::Make(schema_ptr, condition_ptr, config, &filter);
+ if (!status.ok()) {
+ ss << "Failed to make LLVM module due to " << status.message() << "\n";
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ goto err_out;
+ }
+
+ // store the result in a map
+ holder = std::shared_ptr<FilterHolder>(new FilterHolder(schema_ptr, std::move(filter)));
+ module_id = filter_modules_.Insert(holder);
+ releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env);
+ return module_id;
+
+err_out:
+ env->ThrowNew(gandiva_exception_, ss.str().c_str());
+ return module_id;
+}
+
+JNIEXPORT jint JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateFilter(
+ JNIEnv* env, jobject cls, jlong module_id, jint num_rows, jlongArray buf_addrs,
+ jlongArray buf_sizes, jint jselection_vector_type, jlong out_buf_addr,
+ jlong out_buf_size) {
+ gandiva::Status status;
+ std::shared_ptr<FilterHolder> holder = filter_modules_.Lookup(module_id);
+ if (holder == nullptr) {
+ env->ThrowNew(gandiva_exception_, "Unknown module id\n");
+ return -1;
+ }
+
+ int in_bufs_len = env->GetArrayLength(buf_addrs);
+ if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
+ env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes");
+ return -1;
+ }
+
+ jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0);
+ jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0);
+ std::shared_ptr<gandiva::SelectionVector> selection_vector;
+
+ do {
+ std::shared_ptr<arrow::RecordBatch> in_batch;
+
+ status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs,
+ in_buf_sizes, in_bufs_len, &in_batch);
+ if (!status.ok()) {
+ break;
+ }
+
+ auto selection_vector_type =
+ static_cast<types::SelectionVectorType>(jselection_vector_type);
+ auto out_buffer = std::make_shared<arrow::MutableBuffer>(
+ reinterpret_cast<uint8_t*>(out_buf_addr), out_buf_size);
+ switch (selection_vector_type) {
+ case types::SV_INT16:
+ status =
+ gandiva::SelectionVector::MakeInt16(num_rows, out_buffer, &selection_vector);
+ break;
+ case types::SV_INT32:
+ status =
+ gandiva::SelectionVector::MakeInt32(num_rows, out_buffer, &selection_vector);
+ break;
+ default:
+ status = gandiva::Status::Invalid("unknown selection vector type");
+ }
+ if (!status.ok()) {
+ break;
+ }
+
+ status = holder->filter()->Evaluate(*in_batch, selection_vector);
+ } while (0);
+
+ env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT);
+ env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT);
+
+ if (!status.ok()) {
+ std::stringstream ss;
+ ss << "Evaluate returned " << status.message() << "\n";
+ env->ThrowNew(gandiva_exception_, status.message().c_str());
+ return -1;
+ } else {
+ int64_t num_slots = selection_vector->GetNumSlots();
+ // Check integer overflow
+ if (num_slots > INT_MAX) {
+ std::stringstream ss;
+ ss << "The selection vector has " << num_slots
+ << " slots, which is larger than the " << INT_MAX << " limit.\n";
+ const std::string message = ss.str();
+ env->ThrowNew(gandiva_exception_, message.c_str());
+ return -1;
+ }
+ return static_cast<int>(num_slots);
+ }
+}
+
+JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeFilter(
+ JNIEnv* env, jobject cls, jlong module_id) {
+ filter_modules_.Erase(module_id);
+}