diff options
Diffstat (limited to 'src/arrow/cpp/src/gandiva')
190 files changed, 40339 insertions, 0 deletions
diff --git a/src/arrow/cpp/src/gandiva/CMakeLists.txt b/src/arrow/cpp/src/gandiva/CMakeLists.txt new file mode 100644 index 000000000..654a4a40b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/CMakeLists.txt @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
set(GANDIVA_VERSION "${ARROW_VERSION}")

# Umbrella targets: "make gandiva-all" builds the library, its tests and its
# benchmarks; the three other targets can also be built individually.
add_custom_target(gandiva-all)
add_custom_target(gandiva)
add_custom_target(gandiva-tests)
add_custom_target(gandiva-benchmarks)

add_dependencies(gandiva-all gandiva gandiva-tests gandiva-benchmarks)

find_package(LLVMAlt REQUIRED)

# Gandiva follows the project-wide C++ standard, except that LLVM 10 or later
# requires at least C++ 14.
set(GANDIVA_CXX_STANDARD ${CMAKE_CXX_STANDARD})
if((NOT LLVM_VERSION_MAJOR LESS "10") AND (CMAKE_CXX_STANDARD LESS 14))
  set(GANDIVA_CXX_STANDARD 14)
endif()

# Directory-scoped on purpose: everything configured from this directory
# (including the subdirectories added later) sees GANDIVA_LLVM_VERSION.
add_definitions(-DGANDIVA_LLVM_VERSION=${LLVM_VERSION_MAJOR})

find_package(OpenSSLAlt REQUIRED)

# Locations of the generated bitcode file and of the C++ source produced from
# it, see precompiled/CMakeLists.txt
set(GANDIVA_PRECOMPILED_BC_PATH "${CMAKE_CURRENT_BINARY_DIR}/irhelpers.bc")
set(GANDIVA_PRECOMPILED_CC_PATH "${CMAKE_CURRENT_BINARY_DIR}/precompiled_bitcode.cc")
set(GANDIVA_PRECOMPILED_CC_IN_PATH
    "${CMAKE_CURRENT_SOURCE_DIR}/precompiled_bitcode.cc.in")

# add_arrow_lib will look for this not yet existing file, so flag as generated
set_source_files_properties(${GANDIVA_PRECOMPILED_CC_PATH} PROPERTIES GENERATED TRUE)

set(SRC_FILES
    annotator.cc
    bitmap_accumulator.cc
    cache.cc
    cast_time.cc
    configuration.cc
    context_helper.cc
    decimal_ir.cc
    decimal_type_util.cc
    decimal_xlarge.cc
    engine.cc
    date_utils.cc
    expr_decomposer.cc
    expr_validator.cc
    expression.cc
    expression_registry.cc
    exported_funcs_registry.cc
    filter.cc
    function_ir_builder.cc
    function_registry.cc
    function_registry_arithmetic.cc
    function_registry_datetime.cc
    function_registry_hash.cc
    function_registry_math_ops.cc
    function_registry_string.cc
    function_registry_timestamp_arithmetic.cc
    function_signature.cc
    gdv_function_stubs.cc
    hash_utils.cc
    llvm_generator.cc
    llvm_types.cc
    like_holder.cc
    literal_holder.cc
    projector.cc
    regex_util.cc
    replace_holder.cc
    selection_vector.cc
    tree_expr_builder.cc
    to_date_holder.cc
    random_generator_holder.cc
    ${GANDIVA_PRECOMPILED_CC_PATH})

set(GANDIVA_SHARED_PRIVATE_LINK_LIBS arrow_shared LLVM::LLVM_INTERFACE
                                     ${GANDIVA_OPENSSL_LIBS})

set(GANDIVA_STATIC_LINK_LIBS arrow_static LLVM::LLVM_INTERFACE ${GANDIVA_OPENSSL_LIBS})

# Optionally link the GNU C++/C runtimes statically into the static-link set.
if(ARROW_GANDIVA_STATIC_LIBSTDCPP AND (CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX))
  list(APPEND GANDIVA_STATIC_LINK_LIBS -static-libstdc++ -static-libgcc)
endif()

# if (MSVC)
#   # Symbols that need to be made public in gandiva.dll for LLVM IR
#   # compilation
#   set(MSVC_SYMBOL_EXPORTS _Init_thread_header)
#   foreach(SYMBOL ${MSVC_SYMBOL_EXPORTS})
#     set(GANDIVA_SHARED_LINK_FLAGS "${GANDIVA_SHARED_LINK_FLAGS} /EXPORT:${SYMBOL}")
#   endforeach()
# endif()

# When the linker supports version scripts, pass symbols.map to the shared
# library link.
if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT)
  set(GANDIVA_VERSION_SCRIPT_FLAGS
      "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/symbols.map")
  string(APPEND GANDIVA_SHARED_LINK_FLAGS " ${GANDIVA_VERSION_SCRIPT_FLAGS}")
endif()

# The names of the library targets that get created are returned in
# GANDIVA_LIBRARIES (the OUTPUTS argument).
add_arrow_lib(gandiva
              CMAKE_PACKAGE_NAME
              Gandiva
              PKG_CONFIG_NAME
              gandiva
              SOURCES
              ${SRC_FILES}
              PRECOMPILED_HEADERS
              "$<$<COMPILE_LANGUAGE:CXX>:gandiva/pch.h>"
              OUTPUTS
              GANDIVA_LIBRARIES
              DEPENDENCIES
              arrow_dependencies
              precompiled
              EXTRA_INCLUDES
              $<TARGET_PROPERTY:LLVM::LLVM_INTERFACE,INTERFACE_INCLUDE_DIRECTORIES>
              ${GANDIVA_OPENSSL_INCLUDE_DIR}
              ${UTF8PROC_INCLUDE_DIR}
              SHARED_LINK_FLAGS
              ${GANDIVA_SHARED_LINK_FLAGS}
              SHARED_LINK_LIBS
              arrow_shared
              SHARED_PRIVATE_LINK_LIBS
              ${GANDIVA_SHARED_PRIVATE_LINK_LIBS}
              STATIC_LINK_LIBS
              ${GANDIVA_STATIC_LINK_LIBS})

# Every Gandiva library target is built with GANDIVA_EXPORTING defined and with
# the C++ standard selected above.
foreach(LIB_TARGET ${GANDIVA_LIBRARIES})
  target_compile_definitions(${LIB_TARGET} PRIVATE GANDIVA_EXPORTING)
  set_target_properties(${LIB_TARGET} PROPERTIES CXX_STANDARD ${GANDIVA_CXX_STANDARD})
endforeach()

# On Windows, the static library and everything linking to it get
# GANDIVA_STATIC defined.
if(ARROW_BUILD_STATIC AND WIN32)
  target_compile_definitions(gandiva_static PUBLIC GANDIVA_STATIC)
endif()
add_dependencies(gandiva ${GANDIVA_LIBRARIES})

arrow_install_all_headers("gandiva")

set(GANDIVA_STATIC_TEST_LINK_LIBS gandiva_static ${ARROW_TEST_LINK_LIBS})

set(GANDIVA_SHARED_TEST_LINK_LIBS gandiva_shared ${ARROW_TEST_LINK_LIBS})

# Register a Gandiva unit test through add_test_case().
#
#   REL_TEST_NAME       test name; the resulting target is "gandiva-<name>"
#                       with every "_" replaced by "-".
#   USE_STATIC_LINKING  (option) link against gandiva_static even when
#                       ARROW_TEST_LINKAGE is not "static".
#
# Every other argument (SOURCES, EXTRA_INCLUDES, ...) is forwarded unchanged to
# add_test_case(). Nothing is registered when NO_TESTS is set.
#
# Example:
#   add_gandiva_test(projector_test_static SOURCES projector_test.cc USE_STATIC_LINKING)
function(ADD_GANDIVA_TEST REL_TEST_NAME)
  set(options USE_STATIC_LINKING)
  set(one_value_args)
  set(multi_value_args)
  cmake_parse_arguments(ARG
                        "${options}"
                        "${one_value_args}"
                        "${multi_value_args}"
                        ${ARGN})

  if(NO_TESTS)
    return()
  endif()

  # Tests link against the shared Gandiva library by default, which is faster
  # and uses less disk space, but in some cases we need to force static linking
  # (USE_STATIC_LINKING, or ARROW_TEST_LINKAGE set to "static").
  if(ARG_USE_STATIC_LINKING OR ARROW_TEST_LINKAGE STREQUAL "static")
    set(TEST_LINK_LIBS ${GANDIVA_STATIC_TEST_LINK_LIBS})
  else()
    set(TEST_LINK_LIBS ${GANDIVA_SHARED_TEST_LINK_LIBS})
  endif()

  # The caller's extra arguments are forwarded exactly once; the previous code
  # appended them a second time after the link libraries. The link list is
  # passed through the STATIC_LINK_LIBS keyword in both linkage modes, as before.
  add_test_case(${REL_TEST_NAME}
                ENABLED
                PREFIX
                "gandiva"
                LABELS
                "gandiva-tests"
                ${ARG_UNPARSED_ARGUMENTS}
                STATIC_LINK_LIBS
                ${TEST_LINK_LIBS})

  set(TEST_NAME gandiva-${REL_TEST_NAME})
  string(REPLACE "_" "-" TEST_NAME "${TEST_NAME}")
  set_target_properties(${TEST_NAME} PROPERTIES CXX_STANDARD ${GANDIVA_CXX_STANDARD})
endfunction()

# Extra keyword arguments for internals-test; only populated on Windows.
set(GANDIVA_INTERNALS_TEST_ARGUMENTS)
if(WIN32)
  list(APPEND
       GANDIVA_INTERNALS_TEST_ARGUMENTS
       EXTRA_LINK_LIBS
       LLVM::LLVM_INTERFACE
       ${GANDIVA_OPENSSL_LIBS})
endif()
add_gandiva_test(internals-test
                 SOURCES
                 bitmap_accumulator_test.cc
                 engine_llvm_test.cc
                 function_registry_test.cc
                 function_signature_test.cc
                 llvm_types_test.cc
                 llvm_generator_test.cc
                 annotator_test.cc
                 tree_expr_test.cc
                 expr_decomposer_test.cc
                 expression_registry_test.cc
                 selection_vector_test.cc
                 greedy_dual_size_cache_test.cc
                 to_date_holder_test.cc
                 simple_arena_test.cc
                 like_holder_test.cc
                 replace_holder_test.cc
                 decimal_type_util_test.cc
                 random_generator_holder_test.cc
                 hash_utils_test.cc
                 gdv_function_stubs_test.cc
                 EXTRA_DEPENDENCIES
                 LLVM::LLVM_INTERFACE
                 ${GANDIVA_OPENSSL_LIBS}
                 EXTRA_INCLUDES
                 $<TARGET_PROPERTY:LLVM::LLVM_INTERFACE,INTERFACE_INCLUDE_DIRECTORIES>
                 ${GANDIVA_OPENSSL_INCLUDE_DIR}
                 ${UTF8PROC_INCLUDE_DIR}
                 # Must come last: on Windows this expands to an EXTRA_LINK_LIBS
                 # keyword group, which would otherwise swallow the include
                 # directories listed after it.
                 ${GANDIVA_INTERNALS_TEST_ARGUMENTS})

if(ARROW_GANDIVA_JAVA)
  add_subdirectory(jni)
endif()

add_subdirectory(precompiled)
add_subdirectory(tests)
diff --git a/src/arrow/cpp/src/gandiva/GandivaConfig.cmake.in b/src/arrow/cpp/src/gandiva/GandivaConfig.cmake.in new file mode 100644 index 000000000..09bc33901 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/GandivaConfig.cmake.in @@ -0,0 +1,36 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# This config sets the following variables in your project::
#
#   Gandiva_FOUND - true if Gandiva was found on the system
#
# This config sets the following targets in your project::
#
#   gandiva_shared - for linking as a shared library, if the shared library was built
#   gandiva_static - for linking as a static library, if the static library was built

@PACKAGE_INIT@

include(CMakeFindDependencyMacro)
find_dependency(Arrow)

# Load targets only once. If we load targets multiple times, CMake reports
# an "already existing target" error.
+if(NOT (TARGET gandiva_shared OR TARGET gandiva_static)) + include("${CMAKE_CURRENT_LIST_DIR}/GandivaTargets.cmake") +endif() diff --git a/src/arrow/cpp/src/gandiva/annotator.cc b/src/arrow/cpp/src/gandiva/annotator.cc new file mode 100644 index 000000000..f6acaff18 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/annotator.cc @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/annotator.h" + +#include <memory> +#include <string> + +#include "gandiva/field_descriptor.h" + +namespace gandiva { + +FieldDescriptorPtr Annotator::CheckAndAddInputFieldDescriptor(FieldPtr field) { + // If the field is already in the map, return the entry. 
+ auto found = in_name_to_desc_.find(field->name()); + if (found != in_name_to_desc_.end()) { + return found->second; + } + + auto desc = MakeDesc(field, false /*is_output*/); + in_name_to_desc_[field->name()] = desc; + return desc; +} + +FieldDescriptorPtr Annotator::AddOutputFieldDescriptor(FieldPtr field) { + auto desc = MakeDesc(field, true /*is_output*/); + out_descs_.push_back(desc); + return desc; +} + +FieldDescriptorPtr Annotator::MakeDesc(FieldPtr field, bool is_output) { + int data_idx = buffer_count_++; + int validity_idx = buffer_count_++; + int offsets_idx = FieldDescriptor::kInvalidIdx; + if (arrow::is_binary_like(field->type()->id())) { + offsets_idx = buffer_count_++; + } + int data_buffer_ptr_idx = FieldDescriptor::kInvalidIdx; + if (is_output) { + data_buffer_ptr_idx = buffer_count_++; + } + return std::make_shared<FieldDescriptor>(field, data_idx, validity_idx, offsets_idx, + data_buffer_ptr_idx); +} + +void Annotator::PrepareBuffersForField(const FieldDescriptor& desc, + const arrow::ArrayData& array_data, + EvalBatch* eval_batch, bool is_output) { + int buffer_idx = 0; + + // The validity buffer is optional. Use nullptr if it does not have one. + if (array_data.buffers[buffer_idx]) { + uint8_t* validity_buf = const_cast<uint8_t*>(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.validity_idx(), validity_buf, array_data.offset); + } else { + eval_batch->SetBuffer(desc.validity_idx(), nullptr, array_data.offset); + } + ++buffer_idx; + + if (desc.HasOffsetsIdx()) { + uint8_t* offsets_buf = const_cast<uint8_t*>(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.offsets_idx(), offsets_buf, array_data.offset); + ++buffer_idx; + } + + uint8_t* data_buf = const_cast<uint8_t*>(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset); + if (is_output) { + // pass in the Buffer object for output data buffers. Can be used for resizing. 
+ uint8_t* data_buf_ptr = + reinterpret_cast<uint8_t*>(array_data.buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset); + } +} + +EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch, + const ArrayDataVector& out_vector) { + EvalBatchPtr eval_batch = std::make_shared<EvalBatch>( + record_batch.num_rows(), buffer_count_, local_bitmap_count_); + + // Fill in the entries for the input fields. + for (int i = 0; i < record_batch.num_columns(); ++i) { + const std::string& name = record_batch.column_name(i); + auto found = in_name_to_desc_.find(name); + if (found == in_name_to_desc_.end()) { + // skip columns not involved in the expression. + continue; + } + + PrepareBuffersForField(*(found->second), *(record_batch.column(i))->data(), + eval_batch.get(), false /*is_output*/); + } + + // Fill in the entries for the output fields. + int idx = 0; + for (auto& arraydata : out_vector) { + const FieldDescriptorPtr& desc = out_descs_.at(idx); + PrepareBuffersForField(*desc, *arraydata, eval_batch.get(), true /*is_output*/); + ++idx; + } + return eval_batch; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/annotator.h b/src/arrow/cpp/src/gandiva/annotator.h new file mode 100644 index 000000000..5f185d183 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/annotator.h @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <list> +#include <string> +#include <unordered_map> +#include <vector> + +#include "arrow/util/logging.h" +#include "gandiva/arrow.h" +#include "gandiva/eval_batch.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// \brief annotate the arrow fields in an expression, and use that +/// to convert the incoming arrow-format row batch to an EvalBatch. +class GANDIVA_EXPORT Annotator { + public: + Annotator() : buffer_count_(0), local_bitmap_count_(0) {} + + /// Add an annotated field descriptor for a field in an input schema. + /// If the field is already annotated, returns that instead. + FieldDescriptorPtr CheckAndAddInputFieldDescriptor(FieldPtr field); + + /// Add an annotated field descriptor for an output field. + FieldDescriptorPtr AddOutputFieldDescriptor(FieldPtr field); + + /// Add a local bitmap (for saving validity bits of an intermediate node). + /// Returns the index of the bitmap in the list of local bitmaps. + int AddLocalBitMap() { return local_bitmap_count_++; } + + /// Prepare an eval batch for the incoming record batch. + EvalBatchPtr PrepareEvalBatch(const arrow::RecordBatch& record_batch, + const ArrayDataVector& out_vector); + + int buffer_count() { return buffer_count_; } + + private: + /// Annotate a field and return the descriptor. + FieldDescriptorPtr MakeDesc(FieldPtr field, bool is_output); + + /// Populate eval_batch by extracting the raw buffers from the arrow array, whose + /// contents are represent by the annotated descriptor 'desc'. 
+ void PrepareBuffersForField(const FieldDescriptor& desc, + const arrow::ArrayData& array_data, EvalBatch* eval_batch, + bool is_output); + + /// The list of input/output buffers (includes bitmap buffers, value buffers and + /// offset buffers). + int buffer_count_; + + /// The number of local bitmaps. These are used to save the validity bits for + /// intermediate nodes in the expression tree. + int local_bitmap_count_; + + /// map between field name and annotated input field descriptor. + std::unordered_map<std::string, FieldDescriptorPtr> in_name_to_desc_; + + /// vector of annotated output field descriptors. + std::vector<FieldDescriptorPtr> out_descs_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/annotator_test.cc b/src/arrow/cpp/src/gandiva/annotator_test.cc new file mode 100644 index 000000000..e537943d9 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/annotator_test.cc @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/annotator.h" + +#include <memory> +#include <utility> + +#include <arrow/memory_pool.h> +#include <gtest/gtest.h> +#include "gandiva/field_descriptor.h" + +namespace gandiva { + +class TestAnnotator : public ::testing::Test { + protected: + ArrayPtr MakeInt32Array(int length); +}; + +ArrayPtr TestAnnotator::MakeInt32Array(int length) { + arrow::Status status; + + auto validity = *arrow::AllocateBuffer((length + 63) / 8); + + auto values = *arrow::AllocateBuffer(length * sizeof(int32_t)); + + auto array_data = arrow::ArrayData::Make(arrow::int32(), length, + {std::move(validity), std::move(values)}); + return arrow::MakeArray(array_data); +} + +TEST_F(TestAnnotator, TestAdd) { + Annotator annotator; + + auto field_a = arrow::field("a", arrow::int32()); + auto field_b = arrow::field("b", arrow::int32()); + auto in_schema = arrow::schema({field_a, field_b}); + auto field_sum = arrow::field("sum", arrow::int32()); + + FieldDescriptorPtr desc_a = annotator.CheckAndAddInputFieldDescriptor(field_a); + EXPECT_EQ(desc_a->field(), field_a); + EXPECT_EQ(desc_a->data_idx(), 0); + EXPECT_EQ(desc_a->validity_idx(), 1); + + // duplicate add shouldn't cause a new descriptor. 
+ FieldDescriptorPtr dup = annotator.CheckAndAddInputFieldDescriptor(field_a); + EXPECT_EQ(dup, desc_a); + EXPECT_EQ(dup->validity_idx(), desc_a->validity_idx()); + + FieldDescriptorPtr desc_b = annotator.CheckAndAddInputFieldDescriptor(field_b); + EXPECT_EQ(desc_b->field(), field_b); + EXPECT_EQ(desc_b->data_idx(), 2); + EXPECT_EQ(desc_b->validity_idx(), 3); + + FieldDescriptorPtr desc_sum = annotator.AddOutputFieldDescriptor(field_sum); + EXPECT_EQ(desc_sum->field(), field_sum); + EXPECT_EQ(desc_sum->data_idx(), 4); + EXPECT_EQ(desc_sum->validity_idx(), 5); + EXPECT_EQ(desc_sum->data_buffer_ptr_idx(), 6); + + // prepare record batch + int num_records = 100; + auto arrow_v0 = MakeInt32Array(num_records); + auto arrow_v1 = MakeInt32Array(num_records); + + // prepare input record batch + auto record_batch = + arrow::RecordBatch::Make(in_schema, num_records, {arrow_v0, arrow_v1}); + + auto arrow_sum = MakeInt32Array(num_records); + EvalBatchPtr batch = annotator.PrepareEvalBatch(*record_batch, {arrow_sum->data()}); + EXPECT_EQ(batch->GetNumBuffers(), 7); + + auto buffers = batch->GetBufferArray(); + EXPECT_EQ(buffers[desc_a->validity_idx()], arrow_v0->data()->buffers.at(0)->data()); + EXPECT_EQ(buffers[desc_a->data_idx()], arrow_v0->data()->buffers.at(1)->data()); + EXPECT_EQ(buffers[desc_b->validity_idx()], arrow_v1->data()->buffers.at(0)->data()); + EXPECT_EQ(buffers[desc_b->data_idx()], arrow_v1->data()->buffers.at(1)->data()); + EXPECT_EQ(buffers[desc_sum->validity_idx()], arrow_sum->data()->buffers.at(0)->data()); + EXPECT_EQ(buffers[desc_sum->data_idx()], arrow_sum->data()->buffers.at(1)->data()); + EXPECT_EQ(buffers[desc_sum->data_buffer_ptr_idx()], + reinterpret_cast<uint8_t*>(arrow_sum->data()->buffers.at(1).get())); + + auto bitmaps = batch->GetLocalBitMapArray(); + EXPECT_EQ(bitmaps, nullptr); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/arrow.h b/src/arrow/cpp/src/gandiva/arrow.h new file mode 100644 index 000000000..e6d40cb18 --- 
/dev/null +++ b/src/arrow/cpp/src/gandiva/arrow.h @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <vector> + +#include "arrow/array.h" // IWYU pragma: export +#include "arrow/builder.h" // IWYU pragma: export +#include "arrow/pretty_print.h" // IWYU pragma: export +#include "arrow/record_batch.h" // IWYU pragma: export +#include "arrow/status.h" // IWYU pragma: export +#include "arrow/type.h" // IWYU pragma: export + +namespace gandiva { + +using arrow::ArrayDataVector; +using arrow::DataTypeVector; +using arrow::FieldVector; +using arrow::Result; +using arrow::Status; +using arrow::StatusCode; + +using ArrayPtr = std::shared_ptr<arrow::Array>; +using ArrayDataPtr = std::shared_ptr<arrow::ArrayData>; +using DataTypePtr = std::shared_ptr<arrow::DataType>; +using FieldPtr = std::shared_ptr<arrow::Field>; +using RecordBatchPtr = std::shared_ptr<arrow::RecordBatch>; +using SchemaPtr = std::shared_ptr<arrow::Schema>; + +using Decimal128TypePtr = std::shared_ptr<arrow::Decimal128Type>; +using Decimal128TypeVector = std::vector<Decimal128TypePtr>; + +static inline bool is_decimal_128(DataTypePtr type) { + if (type->id() == arrow::Type::DECIMAL) { + auto 
decimal_type = arrow::internal::checked_cast<arrow::DecimalType*>(type.get()); + return decimal_type->byte_width() == 16; + } else { + return false; + } +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/basic_decimal_scalar.h b/src/arrow/cpp/src/gandiva/basic_decimal_scalar.h new file mode 100644 index 000000000..b2f0da506 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/basic_decimal_scalar.h @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> + +#include "arrow/util/basic_decimal.h" +#include "arrow/util/decimal.h" + +namespace gandiva { + +using arrow::BasicDecimal128; + +/// Represents a 128-bit decimal value along with its precision and scale. 
+class BasicDecimalScalar128 { + public: + constexpr BasicDecimalScalar128(int64_t high_bits, uint64_t low_bits, int32_t precision, + int32_t scale) + : value_(high_bits, low_bits), precision_(precision), scale_(scale) {} + + constexpr BasicDecimalScalar128(const BasicDecimal128& value, int32_t precision, + int32_t scale) + : value_(value), precision_(precision), scale_(scale) {} + + constexpr BasicDecimalScalar128(int32_t precision, int32_t scale) + : precision_(precision), scale_(scale) {} + + int32_t scale() const { return scale_; } + + int32_t precision() const { return precision_; } + + const BasicDecimal128& value() const { return value_; } + + private: + BasicDecimal128 value_; + int32_t precision_; + int32_t scale_; +}; + +inline bool operator==(const BasicDecimalScalar128& left, + const BasicDecimalScalar128& right) { + return left.value() == right.value() && left.precision() == right.precision() && + left.scale() == right.scale(); +} + +inline BasicDecimalScalar128 operator-(const BasicDecimalScalar128& operand) { + return BasicDecimalScalar128{-operand.value(), operand.precision(), operand.scale()}; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/bitmap_accumulator.cc b/src/arrow/cpp/src/gandiva/bitmap_accumulator.cc new file mode 100644 index 000000000..8fc66b389 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/bitmap_accumulator.cc @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/bitmap_accumulator.h" + +#include <vector> + +#include "arrow/util/bitmap_ops.h" + +namespace gandiva { + +void BitMapAccumulator::ComputeResult(uint8_t* dst_bitmap) { + int64_t num_records = eval_batch_.num_records(); + + if (all_invalid_) { + // set all bits to 0. + memset(dst_bitmap, 0, arrow::BitUtil::BytesForBits(num_records)); + } else { + IntersectBitMaps(dst_bitmap, src_maps_, src_map_offsets_, num_records); + } +} + +/// Compute the intersection of multiple bitmaps. +void BitMapAccumulator::IntersectBitMaps(uint8_t* dst_map, + const std::vector<uint8_t*>& src_maps, + const std::vector<int64_t>& src_map_offsets, + int64_t num_records) { + int64_t num_words = (num_records + 63) / 64; // aligned to 8-byte. + int64_t num_bytes = num_words * 8; + int64_t nmaps = src_maps.size(); + + switch (nmaps) { + case 0: { + // no src_maps_ bitmap. simply set all bits + memset(dst_map, 0xff, num_bytes); + break; + } + + case 1: { + // one src_maps_ bitmap. 
copy to dst_map + arrow::internal::CopyBitmap(src_maps[0], src_map_offsets[0], num_records, dst_map, + 0); + break; + } + + default: { + // src_maps bitmaps ANDs + arrow::internal::BitmapAnd(src_maps[0], src_map_offsets[0], src_maps[1], + src_map_offsets[1], num_records, /*offset=*/0, dst_map); + for (int64_t m = 2; m < nmaps; ++m) { + arrow::internal::BitmapAnd(dst_map, 0, src_maps[m], src_map_offsets[m], + num_records, + /*offset=*/0, dst_map); + } + + break; + } + } +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/bitmap_accumulator.h b/src/arrow/cpp/src/gandiva/bitmap_accumulator.h new file mode 100644 index 000000000..0b297a98f --- /dev/null +++ b/src/arrow/cpp/src/gandiva/bitmap_accumulator.h @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <vector> + +#include "arrow/util/macros.h" +#include "gandiva/dex.h" +#include "gandiva/dex_visitor.h" +#include "gandiva/eval_batch.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// \brief Extract bitmap buffer from either the input/buffer vectors or the +/// local validity bitmap, and accumulates them to do the final computation. 
+class GANDIVA_EXPORT BitMapAccumulator : public DexDefaultVisitor { + public: + explicit BitMapAccumulator(const EvalBatch& eval_batch) + : eval_batch_(eval_batch), all_invalid_(false) {} + + void Visit(const VectorReadValidityDex& dex) { + int idx = dex.ValidityIdx(); + auto bitmap = eval_batch_.GetBuffer(idx); + // The bitmap could be null. Ignore it in this case. + if (bitmap != NULLPTR) { + src_maps_.push_back(bitmap); + src_map_offsets_.push_back(eval_batch_.GetBufferOffset(idx)); + } + } + + void Visit(const LocalBitMapValidityDex& dex) { + int idx = dex.local_bitmap_idx(); + auto bitmap = eval_batch_.GetLocalBitMap(idx); + src_maps_.push_back(bitmap); + src_map_offsets_.push_back(0); // local bitmap has offset 0 + } + + void Visit(const TrueDex& dex) { + // bitwise-and with 1 is always 1. so, ignore. + } + + void Visit(const FalseDex& dex) { + // The final result is "all 0s". + all_invalid_ = true; + } + + /// Compute the dst_bmap based on the contents and type of the accumulated bitmap dex. + void ComputeResult(uint8_t* dst_bitmap); + + /// Compute the intersection of the accumulated bitmaps (with offsets) and save the + /// result in dst_bmap. + static void IntersectBitMaps(uint8_t* dst_map, const std::vector<uint8_t*>& src_maps, + const std::vector<int64_t>& src_maps_offsets, + int64_t num_records); + + private: + const EvalBatch& eval_batch_; + std::vector<uint8_t*> src_maps_; + std::vector<int64_t> src_map_offsets_; + bool all_invalid_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/bitmap_accumulator_test.cc b/src/arrow/cpp/src/gandiva/bitmap_accumulator_test.cc new file mode 100644 index 000000000..ccffab3e9 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/bitmap_accumulator_test.cc @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/bitmap_accumulator.h" + +#include <memory> +#include <vector> + +#include <gtest/gtest.h> + +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/util.h" +#include "arrow/util/bitmap_ops.h" + +#include "gandiva/dex.h" + +namespace gandiva { + +class TestBitMapAccumulator : public ::testing::Test { + protected: + void FillBitMap(uint8_t* bmap, uint32_t seed, int nrecords); + void ByteWiseIntersectBitMaps(uint8_t* dst, const std::vector<uint8_t*>& srcs, + const std::vector<int64_t>& srcOffsets, int nrecords); +}; + +void TestBitMapAccumulator::FillBitMap(uint8_t* bmap, uint32_t seed, int nbytes) { + ::arrow::random_bytes(nbytes, seed, bmap); +} + +void TestBitMapAccumulator::ByteWiseIntersectBitMaps( + uint8_t* dst, const std::vector<uint8_t*>& srcs, + const std::vector<int64_t>& srcOffsets, int nrecords) { + if (srcs.empty()) { + arrow::BitUtil::SetBitsTo(dst, 0, nrecords, true); + return; + } + + arrow::internal::CopyBitmap(srcs[0], srcOffsets[0], nrecords, dst, 0); + for (uint32_t j = 1; j < srcs.size(); ++j) { + arrow::internal::BitmapAnd(dst, 0, srcs[j], srcOffsets[j], nrecords, 0, dst); + } +} + +TEST_F(TestBitMapAccumulator, TestIntersectBitMaps) { + const int length = 128; + const int nrecords = length * 8; + uint8_t src_bitmaps[4][length]; + uint8_t dst_bitmap[length]; + uint8_t expected_bitmap[length]; + + for (int i = 0; i < 4; i++) { + 
FillBitMap(src_bitmaps[i], i, length); + } + + for (int i = 0; i < 4; i++) { + std::vector<uint8_t*> src_bitmap_ptrs; + std::vector<int64_t> src_bitmap_offsets(i, 0); + for (int j = 0; j < i; ++j) { + src_bitmap_ptrs.push_back(src_bitmaps[j]); + } + + BitMapAccumulator::IntersectBitMaps(dst_bitmap, src_bitmap_ptrs, src_bitmap_offsets, + nrecords); + ByteWiseIntersectBitMaps(expected_bitmap, src_bitmap_ptrs, src_bitmap_offsets, + nrecords); + EXPECT_EQ(memcmp(dst_bitmap, expected_bitmap, length), 0); + } +} + +TEST_F(TestBitMapAccumulator, TestIntersectBitMapsWithOffset) { + const int length = 128; + uint8_t src_bitmaps[4][length]; + uint8_t dst_bitmap[length]; + uint8_t expected_bitmap[length]; + + for (int i = 0; i < 4; i++) { + FillBitMap(src_bitmaps[i], i, length); + } + + for (int i = 0; i < 4; i++) { + std::vector<uint8_t*> src_bitmap_ptrs; + std::vector<int64_t> src_bitmap_offsets; + for (int j = 0; j < i; ++j) { + src_bitmap_ptrs.push_back(src_bitmaps[j]); + src_bitmap_offsets.push_back(j); // offset j + } + const int nrecords = (i == 0) ? length * 8 : length * 8 - i + 1; + + BitMapAccumulator::IntersectBitMaps(dst_bitmap, src_bitmap_ptrs, src_bitmap_offsets, + nrecords); + ByteWiseIntersectBitMaps(expected_bitmap, src_bitmap_ptrs, src_bitmap_offsets, + nrecords); + EXPECT_TRUE( + arrow::internal::BitmapEquals(dst_bitmap, 0, expected_bitmap, 0, nrecords)); + } +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/cache.cc b/src/arrow/cpp/src/gandiva/cache.cc new file mode 100644 index 000000000..d823a676b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/cache.cc @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/cache.h" +#include "arrow/util/logging.h" + +namespace gandiva { + +static const int DEFAULT_CACHE_SIZE = 500; + +int GetCapacity() { + int capacity; + const char* env_cache_size = std::getenv("GANDIVA_CACHE_SIZE"); + if (env_cache_size != nullptr) { + capacity = std::atoi(env_cache_size); + if (capacity <= 0) { + ARROW_LOG(WARNING) << "Invalid cache size provided. Using default cache size: " + << DEFAULT_CACHE_SIZE; + capacity = DEFAULT_CACHE_SIZE; + } + } else { + capacity = DEFAULT_CACHE_SIZE; + } + return capacity; +} + +void LogCacheSize(size_t capacity) { + ARROW_LOG(INFO) << "Creating gandiva cache with capacity: " << capacity; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/cache.h b/src/arrow/cpp/src/gandiva/cache.h new file mode 100644 index 000000000..8d0f75ce3 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/cache.h @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdlib> +#include <memory> +#include <mutex> + +#include "gandiva/greedy_dual_size_cache.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +GANDIVA_EXPORT +int GetCapacity(); + +GANDIVA_EXPORT +void LogCacheSize(size_t capacity); + +template <class KeyType, typename ValueType> +class Cache { + public: + explicit Cache(size_t capacity) : cache_(capacity) { LogCacheSize(capacity); } + + Cache() : Cache(GetCapacity()) {} + + ValueType GetModule(KeyType cache_key) { + arrow::util::optional<ValueCacheObject<ValueType>> result; + mtx_.lock(); + result = cache_.get(cache_key); + mtx_.unlock(); + return result != arrow::util::nullopt ? (*result).module : nullptr; + } + + void PutModule(KeyType cache_key, ValueCacheObject<ValueType> valueCacheObject) { + mtx_.lock(); + cache_.insert(cache_key, valueCacheObject); + mtx_.unlock(); + } + + private: + GreedyDualSizeCache<KeyType, ValueType> cache_; + std::mutex mtx_; +}; +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/cast_time.cc b/src/arrow/cpp/src/gandiva/cast_time.cc new file mode 100644 index 000000000..843ce01f8 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/cast_time.cc @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <cstdint> + +#include "arrow/vendored/datetime.h" + +#include "gandiva/precompiled/time_fields.h" + +#ifndef GANDIVA_UNIT_TEST +#include "gandiva/exported_funcs.h" +#include "gandiva/gdv_function_stubs.h" + +#include "gandiva/engine.h" + +namespace gandiva { + +void ExportedTimeFunctions::AddMappings(Engine* engine) const { + std::vector<llvm::Type*> args; + auto types = engine->types(); + + // gdv_fn_time_with_zone + args = {types->ptr_type(types->i32_type()), // time fields + types->i8_ptr_type(), // const char* zone + types->i32_type(), // int data_len + types->i64_type()}; // timestamp *ret_time + + engine->AddGlobalMappingForFunc("gdv_fn_time_with_zone", + types->i32_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_time_with_zone)); +} + +} // namespace gandiva +#endif // !GANDIVA_UNIT_TEST + +extern "C" { + +// TODO : Do input validation or make sure the callers do that ? 
+int gdv_fn_time_with_zone(int* time_fields, const char* zone, int zone_len, + int64_t* ret_time) { + using arrow_vendored::date::day; + using arrow_vendored::date::local_days; + using arrow_vendored::date::locate_zone; + using arrow_vendored::date::month; + using arrow_vendored::date::time_zone; + using arrow_vendored::date::year; + using std::chrono::hours; + using std::chrono::milliseconds; + using std::chrono::minutes; + using std::chrono::seconds; + + using gandiva::TimeFields; + try { + const time_zone* tz = locate_zone(std::string(zone, zone_len)); + *ret_time = tz->to_sys(local_days(year(time_fields[TimeFields::kYear]) / + month(time_fields[TimeFields::kMonth]) / + day(time_fields[TimeFields::kDay])) + + hours(time_fields[TimeFields::kHours]) + + minutes(time_fields[TimeFields::kMinutes]) + + seconds(time_fields[TimeFields::kSeconds]) + + milliseconds(time_fields[TimeFields::kSubSeconds])) + .time_since_epoch() + .count(); + } catch (...) { + return EINVAL; + } + + return 0; +} + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/compiled_expr.h b/src/arrow/cpp/src/gandiva/compiled_expr.h new file mode 100644 index 000000000..ba0ca3437 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/compiled_expr.h @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <vector> +#include "gandiva/llvm_includes.h" +#include "gandiva/selection_vector.h" +#include "gandiva/value_validity_pair.h" + +namespace gandiva { + +using EvalFunc = int (*)(uint8_t** buffers, int64_t* offsets, uint8_t** local_bitmaps, + const uint8_t* selection_buffer, int64_t execution_ctx_ptr, + int64_t record_count); + +/// \brief Tracks the compiled state for one expression. +class CompiledExpr { + public: + CompiledExpr(ValueValidityPairPtr value_validity, FieldDescriptorPtr output) + : value_validity_(value_validity), output_(output) {} + + ValueValidityPairPtr value_validity() const { return value_validity_; } + + FieldDescriptorPtr output() const { return output_; } + + void SetIRFunction(SelectionVector::Mode mode, llvm::Function* ir_function) { + ir_functions_[static_cast<int>(mode)] = ir_function; + } + + llvm::Function* GetIRFunction(SelectionVector::Mode mode) const { + return ir_functions_[static_cast<int>(mode)]; + } + + void SetJITFunction(SelectionVector::Mode mode, EvalFunc jit_function) { + jit_functions_[static_cast<int>(mode)] = jit_function; + } + + EvalFunc GetJITFunction(SelectionVector::Mode mode) const { + return jit_functions_[static_cast<int>(mode)]; + } + + private: + // value & validities for the expression tree (root) + ValueValidityPairPtr value_validity_; + + // output field + FieldDescriptorPtr output_; + + // IR functions for various modes in the generated code + std::array<llvm::Function*, SelectionVector::kNumModes> ir_functions_; + + // JIT functions in the generated code (set after the module is optimised and finalized) + std::array<EvalFunc, SelectionVector::kNumModes> jit_functions_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/condition.h b/src/arrow/cpp/src/gandiva/condition.h new file mode 100644 index 000000000..a3e8f9d1f --- /dev/null +++ 
b/src/arrow/cpp/src/gandiva/condition.h @@ -0,0 +1,37 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <memory>

#include "gandiva/arrow.h"
#include "gandiva/expression.h"
#include "gandiva/gandiva_aliases.h"

namespace gandiva {

/// \brief A condition expression.
///
/// An Expression whose result field is fixed to a boolean field named "cond".
class Condition : public Expression {
 public:
  /// \param root root node of the expression tree that forms the condition
  explicit Condition(const NodePtr root)
      : Expression(root, std::make_shared<arrow::Field>("cond", arrow::boolean())) {}

  virtual ~Condition() = default;
};

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/configuration.cc b/src/arrow/cpp/src/gandiva/configuration.cc new file mode 100644 index 000000000..1e26c5c70 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/configuration.cc @@ -0,0 +1,43 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/configuration.h" + +#include "arrow/util/hash_util.h" + +namespace gandiva { + +const std::shared_ptr<Configuration> ConfigurationBuilder::default_configuration_ = + InitDefaultConfig(); + +std::size_t Configuration::Hash() const { + static constexpr size_t kHashSeed = 0; + size_t result = kHashSeed; + arrow::internal::hash_combine(result, static_cast<size_t>(optimize_)); + arrow::internal::hash_combine(result, static_cast<size_t>(target_host_cpu_)); + return result; +} + +bool Configuration::operator==(const Configuration& other) const { + return optimize_ == other.optimize_ && target_host_cpu_ == other.target_host_cpu_; +} + +bool Configuration::operator!=(const Configuration& other) const { + return !(*this == other); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/configuration.h b/src/arrow/cpp/src/gandiva/configuration.h new file mode 100644 index 000000000..9cd301524 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/configuration.h @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <string> + +#include "arrow/status.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +class ConfigurationBuilder; +/// \brief runtime config for gandiva +/// +/// It contains elements to customize gandiva execution +/// at run time. +class GANDIVA_EXPORT Configuration { + public: + friend class ConfigurationBuilder; + + Configuration() : optimize_(true), target_host_cpu_(true) {} + explicit Configuration(bool optimize) : optimize_(optimize), target_host_cpu_(true) {} + + std::size_t Hash() const; + bool operator==(const Configuration& other) const; + bool operator!=(const Configuration& other) const; + + bool optimize() const { return optimize_; } + bool target_host_cpu() const { return target_host_cpu_; } + + void set_optimize(bool optimize) { optimize_ = optimize; } + void target_host_cpu(bool target_host_cpu) { target_host_cpu_ = target_host_cpu; } + + private: + bool optimize_; /* optimise the generated llvm IR */ + bool target_host_cpu_; /* set the mcpu flag to host cpu while compiling llvm ir */ +}; + +/// \brief configuration builder for gandiva +/// +/// Provides a default configuration and convenience methods +/// to override specific values and build a custom instance +class GANDIVA_EXPORT ConfigurationBuilder { + public: + std::shared_ptr<Configuration> build() { + std::shared_ptr<Configuration> configuration(new Configuration()); + return configuration; + } + + std::shared_ptr<Configuration> build(bool optimize) { + std::shared_ptr<Configuration> configuration(new 
Configuration(optimize)); + return configuration; + } + + static std::shared_ptr<Configuration> DefaultConfiguration() { + return default_configuration_; + } + + private: + static std::shared_ptr<Configuration> InitDefaultConfig() { + std::shared_ptr<Configuration> configuration(new Configuration()); + return configuration; + } + + static const std::shared_ptr<Configuration> default_configuration_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/context_helper.cc b/src/arrow/cpp/src/gandiva/context_helper.cc new file mode 100644 index 000000000..224bfd8f5 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/context_helper.cc @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This file is also used in the pre-compiled unit tests, which do include +// llvm/engine/.. 
+#ifndef GANDIVA_UNIT_TEST +#include "gandiva/exported_funcs.h" +#include "gandiva/gdv_function_stubs.h" + +#include "gandiva/engine.h" + +namespace gandiva { + +void ExportedContextFunctions::AddMappings(Engine* engine) const { + std::vector<llvm::Type*> args; + auto types = engine->types(); + + // gdv_fn_context_set_error_msg + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type()}; // char const* err_msg + + engine->AddGlobalMappingForFunc("gdv_fn_context_set_error_msg", types->void_type(), + args, + reinterpret_cast<void*>(gdv_fn_context_set_error_msg)); + + // gdv_fn_context_arena_malloc + args = {types->i64_type(), // int64_t context_ptr + types->i32_type()}; // int32_t size + + engine->AddGlobalMappingForFunc("gdv_fn_context_arena_malloc", types->i8_ptr_type(), + args, + reinterpret_cast<void*>(gdv_fn_context_arena_malloc)); + + // gdv_fn_context_arena_reset + args = {types->i64_type()}; // int64_t context_ptr + + engine->AddGlobalMappingForFunc("gdv_fn_context_arena_reset", types->void_type(), args, + reinterpret_cast<void*>(gdv_fn_context_arena_reset)); +} + +} // namespace gandiva +#endif // !GANDIVA_UNIT_TEST + +#include "gandiva/execution_context.h" + +extern "C" { + +void gdv_fn_context_set_error_msg(int64_t context_ptr, char const* err_msg) { + auto context = reinterpret_cast<gandiva::ExecutionContext*>(context_ptr); + context->set_error_msg(err_msg); +} + +uint8_t* gdv_fn_context_arena_malloc(int64_t context_ptr, int32_t size) { + auto context = reinterpret_cast<gandiva::ExecutionContext*>(context_ptr); + return context->arena()->Allocate(size); +} + +void gdv_fn_context_arena_reset(int64_t context_ptr) { + auto context = reinterpret_cast<gandiva::ExecutionContext*>(context_ptr); + return context->arena()->Reset(); +} +} diff --git a/src/arrow/cpp/src/gandiva/date_utils.cc b/src/arrow/cpp/src/gandiva/date_utils.cc new file mode 100644 index 000000000..f0a80d3c9 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/date_utils.cc @@ -0,0 
+1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <algorithm> +#include <cstdint> +#include <memory> +#include <sstream> +#include <vector> + +#include "gandiva/date_utils.h" + +namespace gandiva { + +std::vector<std::string> DateUtils::GetMatches(std::string pattern, bool exactMatch) { + // we are case insensitive + std::transform(pattern.begin(), pattern.end(), pattern.begin(), ::tolower); + std::vector<std::string> matches; + + for (const auto& it : sql_date_format_to_boost_map_) { + if (it.first.find(pattern) != std::string::npos && + (!exactMatch || (it.first.length() == pattern.length()))) { + matches.push_back(it.first); + } + } + + return matches; +} + +std::vector<std::string> DateUtils::GetPotentialMatches(const std::string& pattern) { + return GetMatches(pattern, false); +} + +std::vector<std::string> DateUtils::GetExactMatches(const std::string& pattern) { + return GetMatches(pattern, true); +} + +/** + * Validates and converts format to the strptime equivalent + * + */ +Status DateUtils::ToInternalFormat(const std::string& format, + std::shared_ptr<std::string>* internal_format) { + std::stringstream builder; + std::stringstream buffer; + bool is_in_quoted_text = false; + + for 
(size_t i = 0; i < format.size(); i++) { + char currentChar = format[i]; + + // logic before we append to the buffer + if (currentChar == '"') { + if (is_in_quoted_text) { + // we are done with a quoted block + is_in_quoted_text = false; + + // use ' for quoting + builder << '\''; + builder << buffer.str(); + builder << '\''; + + // clear buffer + buffer.str(""); + continue; + } else { + ARROW_RETURN_IF(buffer.str().length() > 0, + Status::Invalid("Invalid date format string '", format, "'")); + + is_in_quoted_text = true; + continue; + } + } + + // handle special characters we want to simply pass through, but only if not in quoted + // and the buffer is empty + std::string special_characters = "*-/,.;: "; + if (!is_in_quoted_text && buffer.str().length() == 0 && + (special_characters.find_first_of(currentChar) != std::string::npos)) { + builder << currentChar; + continue; + } + + // append to the buffer + buffer << currentChar; + + // nothing else to do if we are in quoted text + if (is_in_quoted_text) { + continue; + } + + // check how many matches we have for our buffer + std::vector<std::string> potentialList = GetPotentialMatches(buffer.str()); + int64_t potentialCount = potentialList.size(); + + if (potentialCount >= 1) { + // one potential and the length match + if (potentialCount == 1 && potentialList[0].length() == buffer.str().length()) { + // we have a match! + builder << sql_date_format_to_boost_map_[potentialList[0]]; + buffer.str(""); + } else { + // Some patterns (like MON, MONTH) can cause ambiguity, such as "MON:". "MON" + // will have two potential matches, but "MON:" will match nothing, so we want to + // look ahead when we match "MON" and check if adding the next char leads to 0 + // potentials. 
If it does, we go ahead and treat the buffer as matched (if a + // potential match exists that matches the buffer) + if (format.length() - 1 > i) { + std::string lookAheadPattern = (buffer.str() + format.at(i + 1)); + std::transform(lookAheadPattern.begin(), lookAheadPattern.end(), + lookAheadPattern.begin(), ::tolower); + bool lookAheadMatched = false; + + // we can query potentialList to see if it has anything that matches the + // lookahead pattern + for (std::string potential : potentialList) { + if (potential.find(lookAheadPattern) != std::string::npos) { + lookAheadMatched = true; + break; + } + } + + if (!lookAheadMatched) { + // check if any of the potential matches are the same length as our buffer, we + // do not want to match "MO:" + bool matched = false; + for (std::string potential : potentialList) { + if (potential.length() == buffer.str().length()) { + matched = true; + break; + } + } + + if (matched) { + std::string match = buffer.str(); + std::transform(match.begin(), match.end(), match.begin(), ::tolower); + builder << sql_date_format_to_boost_map_[match]; + buffer.str(""); + continue; + } + } + } + } + } else { + return Status::Invalid("Invalid date format string '", format, "'"); + } + } + + if (buffer.str().length() > 0) { + // Some patterns (like MON, MONTH) can cause us to reach this point with a valid + // buffer value as MON has 2 valid potential matches, so double check here + std::vector<std::string> exactMatches = GetExactMatches(buffer.str()); + if (exactMatches.size() == 1 && exactMatches[0].length() == buffer.str().length()) { + builder << sql_date_format_to_boost_map_[exactMatches[0]]; + } else { + // Format partially parsed + int64_t pos = format.length() - buffer.str().length(); + return Status::Invalid("Invalid date format string '", format, "' at position ", + pos); + } + } + std::string final_pattern = builder.str(); + internal_format->reset(new std::string(final_pattern)); + return Status::OK(); +} + 
+DateUtils::date_format_converter DateUtils::sql_date_format_to_boost_map_ = InitMap(); + +DateUtils::date_format_converter DateUtils::InitMap() { + date_format_converter map; + + // Era + map["ad"] = "%EC"; + map["bc"] = "%EC"; + // Meridian + map["am"] = "%p"; + map["pm"] = "%p"; + // Century + map["cc"] = "%C"; + // Week of year + map["ww"] = "%W"; + // Day of week + map["d"] = "%u"; + // Day name of week + map["dy"] = "%a"; + map["day"] = "%a"; + // Year + map["yyyy"] = "%Y"; + map["yy"] = "%y"; + // Day of year + map["ddd"] = "%j"; + // Month + map["mm"] = "%m"; + map["mon"] = "%b"; + map["month"] = "%b"; + // Day of month + map["dd"] = "%d"; + // Hour of day + map["hh"] = "%I"; + map["hh12"] = "%I"; + map["hh24"] = "%H"; + // Minutes + map["mi"] = "%M"; + // Seconds + map["ss"] = "%S"; + // Milliseconds + map["f"] = "S"; + map["ff"] = "SS"; + map["fff"] = "SSS"; + /* + // Timezone not tested/supported yet fully. + map["tzd"] = "%Z"; + map["tzo"] = "%z"; + map["tzh:tzm"] = "%z"; + */ + + return map; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/date_utils.h b/src/arrow/cpp/src/gandiva/date_utils.h new file mode 100644 index 000000000..0d39a5f29 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/date_utils.h @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
// See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "arrow/util/macros.h"

#include "gandiva/arrow.h"
#include "gandiva/visibility.h"

namespace gandiva {

/// \brief Utility class for converting sql date patterns to internal date patterns.
class GANDIVA_EXPORT DateUtils {
 public:
  /// \brief Validate `format` and convert it to the internal pattern syntax.
  ///
  /// \param[in] format sql date pattern to convert
  /// \param[out] internal_format receives the converted pattern on success
  /// \return Status::Invalid when `format` cannot be converted
  static Status ToInternalFormat(const std::string& format,
                                 std::shared_ptr<std::string>* internal_format);

 private:
  using date_format_converter = std::unordered_map<std::string, std::string>;

  // Lookup table from lower-case sql pattern name to its internal pattern text.
  static date_format_converter sql_date_format_to_boost_map_;

  // Builds the contents of sql_date_format_to_boost_map_.
  static date_format_converter InitMap();

  // Returns the known pattern names that contain `pattern` (case-insensitive); with
  // `exactMatch`, only names of the same length as `pattern`.
  static std::vector<std::string> GetMatches(std::string pattern, bool exactMatch);

  // GetMatches(pattern, false)
  static std::vector<std::string> GetPotentialMatches(const std::string& pattern);

  // GetMatches(pattern, true)
  static std::vector<std::string> GetExactMatches(const std::string& pattern);
};

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/decimal_ir.cc b/src/arrow/cpp/src/gandiva/decimal_ir.cc new file mode 100644 index 000000000..5d5d30b4a --- /dev/null +++ b/src/arrow/cpp/src/gandiva/decimal_ir.cc @@ -0,0 +1,559 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <sstream> +#include <unordered_set> +#include <utility> + +#include "arrow/status.h" +#include "gandiva/decimal_ir.h" +#include "gandiva/decimal_type_util.h" + +// Algorithms adapted from Apache Impala + +namespace gandiva { + +#define ADD_TRACE_32(msg, value) \ + if (enable_ir_traces_) { \ + AddTrace32(msg, value); \ + } +#define ADD_TRACE_128(msg, value) \ + if (enable_ir_traces_) { \ + AddTrace128(msg, value); \ + } + +// These are the functions defined in this file. The rest are in precompiled folder, +// and the i128 needs to be dis-assembled for those. +static const char* kAddFunction = "add_decimal128_decimal128"; +static const char* kSubtractFunction = "subtract_decimal128_decimal128"; +static const char* kEQFunction = "equal_decimal128_decimal128"; +static const char* kNEFunction = "not_equal_decimal128_decimal128"; +static const char* kLTFunction = "less_than_decimal128_decimal128"; +static const char* kLEFunction = "less_than_or_equal_to_decimal128_decimal128"; +static const char* kGTFunction = "greater_than_decimal128_decimal128"; +static const char* kGEFunction = "greater_than_or_equal_to_decimal128_decimal128"; + +static const std::unordered_set<std::string> kDecimalIRBuilderFunctions{ + kAddFunction, kSubtractFunction, kEQFunction, kNEFunction, + kLTFunction, kLEFunction, kGTFunction, kGEFunction}; + +const char* DecimalIR::kScaleMultipliersName = "gandivaScaleMultipliers"; + +/// Populate globals required by decimal IR. +/// TODO: can this be done just once ? +void DecimalIR::AddGlobals(Engine* engine) { + auto types = engine->types(); + + // populate vector : [ 1, 10, 100, 1000, ..] 
+ std::string value = "1"; + std::vector<llvm::Constant*> scale_multipliers; + for (int i = 0; i < DecimalTypeUtil::kMaxPrecision + 1; ++i) { + auto multiplier = + llvm::ConstantInt::get(llvm::Type::getInt128Ty(*engine->context()), value, 10); + scale_multipliers.push_back(multiplier); + value.append("0"); + } + + auto array_type = + llvm::ArrayType::get(types->i128_type(), DecimalTypeUtil::kMaxPrecision + 1); + auto initializer = llvm::ConstantArray::get( + array_type, llvm::ArrayRef<llvm::Constant*>(scale_multipliers)); + + auto globalScaleMultipliers = new llvm::GlobalVariable( + *engine->module(), array_type, true /*constant*/, + llvm::GlobalValue::LinkOnceAnyLinkage, initializer, kScaleMultipliersName); + globalScaleMultipliers->setAlignment(LLVM_ALIGN(16)); +} + +// Lookup intrinsic functions +void DecimalIR::InitializeIntrinsics() { + sadd_with_overflow_fn_ = llvm::Intrinsic::getDeclaration( + module(), llvm::Intrinsic::sadd_with_overflow, types()->i128_type()); + DCHECK_NE(sadd_with_overflow_fn_, nullptr); + + smul_with_overflow_fn_ = llvm::Intrinsic::getDeclaration( + module(), llvm::Intrinsic::smul_with_overflow, types()->i128_type()); + DCHECK_NE(smul_with_overflow_fn_, nullptr); + + i128_with_overflow_struct_type_ = + sadd_with_overflow_fn_->getFunctionType()->getReturnType(); +} + +// CPP: return kScaleMultipliers[scale] +llvm::Value* DecimalIR::GetScaleMultiplier(llvm::Value* scale) { + auto const_array = module()->getGlobalVariable(kScaleMultipliersName); + auto ptr = CreateGEP(ir_builder(), const_array, {types()->i32_constant(0), scale}); + return CreateLoad(ir_builder(), ptr); +} + +// CPP: x <= y ? y : x +llvm::Value* DecimalIR::GetHigherScale(llvm::Value* x_scale, llvm::Value* y_scale) { + llvm::Value* le = ir_builder()->CreateICmpSLE(x_scale, y_scale); + return ir_builder()->CreateSelect(le, y_scale, x_scale); +} + +// CPP: return (increase_scale_by <= 0) ? 
+// in_value : in_value * GetScaleMultiplier(increase_scale_by) +llvm::Value* DecimalIR::IncreaseScale(llvm::Value* in_value, + llvm::Value* increase_scale_by) { + llvm::Value* le_zero = + ir_builder()->CreateICmpSLE(increase_scale_by, types()->i32_constant(0)); + // then block + auto then_lambda = [&] { return in_value; }; + + // else block + auto else_lambda = [&] { + llvm::Value* multiplier = GetScaleMultiplier(increase_scale_by); + return ir_builder()->CreateMul(in_value, multiplier); + }; + + return BuildIfElse(le_zero, types()->i128_type(), then_lambda, else_lambda); +} + +// CPP: return (increase_scale_by <= 0) ? +// {in_value,false} : {in_value * GetScaleMultiplier(increase_scale_by),true} +// +// The return value also indicates if there was an overflow while increasing the scale. +DecimalIR::ValueWithOverflow DecimalIR::IncreaseScaleWithOverflowCheck( + llvm::Value* in_value, llvm::Value* increase_scale_by) { + llvm::Value* le_zero = + ir_builder()->CreateICmpSLE(increase_scale_by, types()->i32_constant(0)); + + // then block + auto then_lambda = [&] { + ValueWithOverflow ret{in_value, types()->false_constant()}; + return ret.AsStruct(this); + }; + + // else block + auto else_lambda = [&] { + llvm::Value* multiplier = GetScaleMultiplier(increase_scale_by); + return ir_builder()->CreateCall(smul_with_overflow_fn_, {in_value, multiplier}); + }; + + auto ir_struct = + BuildIfElse(le_zero, i128_with_overflow_struct_type_, then_lambda, else_lambda); + return ValueWithOverflow::MakeFromStruct(this, ir_struct); +} + +// CPP: return (reduce_scale_by <= 0) ? +// in_value : in_value / GetScaleMultiplier(reduce_scale_by) +// +// ReduceScale cannot cause an overflow. 
+llvm::Value* DecimalIR::ReduceScale(llvm::Value* in_value, llvm::Value* reduce_scale_by) { + auto le_zero = ir_builder()->CreateICmpSLE(reduce_scale_by, types()->i32_constant(0)); + // then block + auto then_lambda = [&] { return in_value; }; + + // else block + auto else_lambda = [&] { + // TODO : handle rounding. + llvm::Value* multiplier = GetScaleMultiplier(reduce_scale_by); + return ir_builder()->CreateSDiv(in_value, multiplier); + }; + + return BuildIfElse(le_zero, types()->i128_type(), then_lambda, else_lambda); +} + +/// @brief Fast-path for add +/// Adjust x and y to the same scale, and add them. +llvm::Value* DecimalIR::AddFastPath(const ValueFull& x, const ValueFull& y) { + auto higher_scale = GetHigherScale(x.scale(), y.scale()); + ADD_TRACE_32("AddFastPath : higher_scale", higher_scale); + + // CPP : x_scaled = IncreaseScale(x_value, higher_scale - x_scale) + auto x_delta = ir_builder()->CreateSub(higher_scale, x.scale()); + auto x_scaled = IncreaseScale(x.value(), x_delta); + ADD_TRACE_128("AddFastPath : x_scaled", x_scaled); + + // CPP : y_scaled = IncreaseScale(y_value, higher_scale - y_scale) + auto y_delta = ir_builder()->CreateSub(higher_scale, y.scale()); + auto y_scaled = IncreaseScale(y.value(), y_delta); + ADD_TRACE_128("AddFastPath : y_scaled", y_scaled); + + auto sum = ir_builder()->CreateAdd(x_scaled, y_scaled); + ADD_TRACE_128("AddFastPath : sum", sum); + return sum; +} + +// @brief Add with overflow check. +/// Adjust x and y to the same scale, add them, and reduce sum to output scale. +/// If there is an overflow, the sum is set to 0. 
+DecimalIR::ValueWithOverflow DecimalIR::AddWithOverflowCheck(const ValueFull& x, + const ValueFull& y, + const ValueFull& out) { + auto higher_scale = GetHigherScale(x.scale(), y.scale()); + ADD_TRACE_32("AddWithOverflowCheck : higher_scale", higher_scale); + + // CPP : x_scaled = IncreaseScale(x_value, higher_scale - x.scale()) + auto x_delta = ir_builder()->CreateSub(higher_scale, x.scale()); + auto x_scaled = IncreaseScaleWithOverflowCheck(x.value(), x_delta); + ADD_TRACE_128("AddWithOverflowCheck : x_scaled", x_scaled.value()); + + // CPP : y_scaled = IncreaseScale(y_value, higher_scale - y_scale) + auto y_delta = ir_builder()->CreateSub(higher_scale, y.scale()); + auto y_scaled = IncreaseScaleWithOverflowCheck(y.value(), y_delta); + ADD_TRACE_128("AddWithOverflowCheck : y_scaled", y_scaled.value()); + + // CPP : sum = x_scaled + y_scaled + auto sum_ir_struct = ir_builder()->CreateCall(sadd_with_overflow_fn_, + {x_scaled.value(), y_scaled.value()}); + auto sum = ValueWithOverflow::MakeFromStruct(this, sum_ir_struct); + ADD_TRACE_128("AddWithOverflowCheck : sum", sum.value()); + + // CPP : overflow ? 0 : sum / GetScaleMultiplier(max_scale - out_scale) + auto overflow = GetCombinedOverflow({x_scaled, y_scaled, sum}); + ADD_TRACE_32("AddWithOverflowCheck : overflow", overflow); + auto then_lambda = [&] { + // if there is an overflow, the value returned won't be used. so, save the division. + return types()->i128_constant(0); + }; + auto else_lambda = [&] { + auto reduce_scale_by = ir_builder()->CreateSub(higher_scale, out.scale()); + return ReduceScale(sum.value(), reduce_scale_by); + }; + auto sum_descaled = + BuildIfElse(overflow, types()->i128_type(), then_lambda, else_lambda); + return ValueWithOverflow(sum_descaled, overflow); +} + +// This is pretty complex, so use CPP fns. 
+llvm::Value* DecimalIR::AddLarge(const ValueFull& x, const ValueFull& y, + const ValueFull& out) { + auto block = ir_builder()->GetInsertBlock(); + auto out_high_ptr = new llvm::AllocaInst(types()->i64_type(), 0, "out_hi", block); + auto out_low_ptr = new llvm::AllocaInst(types()->i64_type(), 0, "out_low", block); + auto x_split = ValueSplit::MakeFromInt128(this, x.value()); + auto y_split = ValueSplit::MakeFromInt128(this, y.value()); + + std::vector<llvm::Value*> args = { + x_split.high(), x_split.low(), x.precision(), x.scale(), + y_split.high(), y_split.low(), y.precision(), y.scale(), + out.precision(), out.scale(), out_high_ptr, out_low_ptr, + }; + ir_builder()->CreateCall(module()->getFunction("add_large_decimal128_decimal128"), + args); + + auto out_high = CreateLoad(ir_builder(), out_high_ptr); + auto out_low = CreateLoad(ir_builder(), out_low_ptr); + auto sum = ValueSplit(out_high, out_low).AsInt128(this); + ADD_TRACE_128("AddLarge : sum", sum); + return sum; +} + +/// The output scale/precision cannot be arbitrary values. The algo here depends on them +/// to be the same as computed in DecimalTypeSql. +/// TODO: enforce this. 
+Status DecimalIR::BuildAdd() { + // Create fn prototype : + // int128_t + // add_decimal128_decimal128(int128_t x_value, int32_t x_precision, int32_t x_scale, + // int128_t y_value, int32_t y_precision, int32_t y_scale + // int32_t out_precision, int32_t out_scale) + auto i32 = types()->i32_type(); + auto i128 = types()->i128_type(); + auto function = BuildFunction(kAddFunction, i128, + { + {"x_value", i128}, + {"x_precision", i32}, + {"x_scale", i32}, + {"y_value", i128}, + {"y_precision", i32}, + {"y_scale", i32}, + {"out_precision", i32}, + {"out_scale", i32}, + }); + + auto arg_iter = function->arg_begin(); + ValueFull x(&arg_iter[0], &arg_iter[1], &arg_iter[2]); + ValueFull y(&arg_iter[3], &arg_iter[4], &arg_iter[5]); + ValueFull out(nullptr, &arg_iter[6], &arg_iter[7]); + + auto entry = llvm::BasicBlock::Create(*context(), "entry", function); + ir_builder()->SetInsertPoint(entry); + + // CPP : + // if (out_precision < 38) { + // return AddFastPath(x, y) + // } else { + // ret = AddWithOverflowCheck(x, y) + // if (ret.overflow) + // return AddLarge(x, y) + // else + // return ret.value; + // } + llvm::Value* lt_max_precision = ir_builder()->CreateICmpSLT( + out.precision(), types()->i32_constant(DecimalTypeUtil::kMaxPrecision)); + auto then_lambda = [&] { + // fast-path add + return AddFastPath(x, y); + }; + auto else_lambda = [&] { + if (kUseOverflowIntrinsics) { + // do the add and check if there was overflow + auto ret = AddWithOverflowCheck(x, y, out); + + // if there is an overflow, switch to the AddLarge codepath. 
+ return BuildIfElse( + ret.overflow(), types()->i128_type(), [&] { return AddLarge(x, y, out); }, + [&] { return ret.value(); }); + } else { + return AddLarge(x, y, out); + } + }; + auto value = + BuildIfElse(lt_max_precision, types()->i128_type(), then_lambda, else_lambda); + + // store result to out + ir_builder()->CreateRet(value); + return Status::OK(); +} + +Status DecimalIR::BuildSubtract() { + // Create fn prototype : + // int128_t + // subtract_decimal128_decimal128(int128_t x_value, int32_t x_precision, int32_t + // x_scale, + // int128_t y_value, int32_t y_precision, int32_t y_scale + // int32_t out_precision, int32_t out_scale) + auto i32 = types()->i32_type(); + auto i128 = types()->i128_type(); + auto function = BuildFunction(kSubtractFunction, i128, + { + {"x_value", i128}, + {"x_precision", i32}, + {"x_scale", i32}, + {"y_value", i128}, + {"y_precision", i32}, + {"y_scale", i32}, + {"out_precision", i32}, + {"out_scale", i32}, + }); + + auto entry = llvm::BasicBlock::Create(*context(), "entry", function); + ir_builder()->SetInsertPoint(entry); + + // reuse add function after negating y_value. 
i.e + // add(x_value, x_precision, x_scale, -y_value, y_precision, y_scale, + // out_precision, out_scale) + std::vector<llvm::Value*> args; + int i = 0; + for (auto& in_arg : function->args()) { + if (i == 3) { + auto y_neg_value = ir_builder()->CreateNeg(&in_arg); + args.push_back(y_neg_value); + } else { + args.push_back(&in_arg); + } + ++i; + } + auto value = ir_builder()->CreateCall(module()->getFunction(kAddFunction), args); + + // store result to out + ir_builder()->CreateRet(value); + return Status::OK(); +} + +Status DecimalIR::BuildCompare(const std::string& function_name, + llvm::ICmpInst::Predicate cmp_instruction) { + // Create fn prototype : + // bool + // function_name(int128_t x_value, int32_t x_precision, int32_t x_scale, + // int128_t y_value, int32_t y_precision, int32_t y_scale) + + auto i32 = types()->i32_type(); + auto i128 = types()->i128_type(); + auto function = BuildFunction(function_name, types()->i1_type(), + { + {"x_value", i128}, + {"x_precision", i32}, + {"x_scale", i32}, + {"y_value", i128}, + {"y_precision", i32}, + {"y_scale", i32}, + }); + + auto arg_iter = function->arg_begin(); + ValueFull x(&arg_iter[0], &arg_iter[1], &arg_iter[2]); + ValueFull y(&arg_iter[3], &arg_iter[4], &arg_iter[5]); + + auto entry = llvm::BasicBlock::Create(*context(), "entry", function); + ir_builder()->SetInsertPoint(entry); + + // Make call to pre-compiled IR function. 
+ auto x_split = ValueSplit::MakeFromInt128(this, x.value()); + auto y_split = ValueSplit::MakeFromInt128(this, y.value()); + + std::vector<llvm::Value*> args = { + x_split.high(), x_split.low(), x.precision(), x.scale(), + y_split.high(), y_split.low(), y.precision(), y.scale(), + }; + auto cmp_value = ir_builder()->CreateCall( + module()->getFunction("compare_decimal128_decimal128_internal"), args); + auto result = + ir_builder()->CreateICmp(cmp_instruction, cmp_value, types()->i32_constant(0)); + ir_builder()->CreateRet(result); + return Status::OK(); +} + +llvm::Value* DecimalIR::CallDecimalFunction(const std::string& function_name, + llvm::Type* return_type, + const std::vector<llvm::Value*>& params) { + if (kDecimalIRBuilderFunctions.count(function_name) != 0) { + // this is fn built with the irbuilder. + return ir_builder()->CreateCall(module()->getFunction(function_name), params); + } + + // ppre-compiler fn : disassemble i128 to two i64s and re-assemble. + auto i128 = types()->i128_type(); + auto i64 = types()->i64_type(); + std::vector<llvm::Value*> dis_assembled_args; + for (auto& arg : params) { + if (arg->getType() == i128) { + // split i128 arg into two int64s. + auto split = ValueSplit::MakeFromInt128(this, arg); + dis_assembled_args.push_back(split.high()); + dis_assembled_args.push_back(split.low()); + } else { + dis_assembled_args.push_back(arg); + } + } + + llvm::Value* result = nullptr; + if (return_type == i128) { + // for i128 ret, replace with two int64* args, and join them. + auto block = ir_builder()->GetInsertBlock(); + auto out_high_ptr = new llvm::AllocaInst(i64, 0, "out_hi", block); + auto out_low_ptr = new llvm::AllocaInst(i64, 0, "out_low", block); + dis_assembled_args.push_back(out_high_ptr); + dis_assembled_args.push_back(out_low_ptr); + + // Make call to pre-compiled IR function. 
+ ir_builder()->CreateCall(module()->getFunction(function_name), dis_assembled_args); + + auto out_high = CreateLoad(ir_builder(), out_high_ptr); + auto out_low = CreateLoad(ir_builder(), out_low_ptr); + result = ValueSplit(out_high, out_low).AsInt128(this); + } else { + DCHECK_NE(return_type, types()->void_type()); + + // Make call to pre-compiled IR function. + result = ir_builder()->CreateCall(module()->getFunction(function_name), + dis_assembled_args); + } + return result; +} + +Status DecimalIR::AddFunctions(Engine* engine) { + auto decimal_ir = std::make_shared<DecimalIR>(engine); + + // Populate global variables used by decimal operations. + decimal_ir->AddGlobals(engine); + + // Lookup intrinsic functions + decimal_ir->InitializeIntrinsics(); + + ARROW_RETURN_NOT_OK(decimal_ir->BuildAdd()); + ARROW_RETURN_NOT_OK(decimal_ir->BuildSubtract()); + ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kEQFunction, llvm::ICmpInst::ICMP_EQ)); + ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kNEFunction, llvm::ICmpInst::ICMP_NE)); + ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kLTFunction, llvm::ICmpInst::ICMP_SLT)); + ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kLEFunction, llvm::ICmpInst::ICMP_SLE)); + ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kGTFunction, llvm::ICmpInst::ICMP_SGT)); + ARROW_RETURN_NOT_OK(decimal_ir->BuildCompare(kGEFunction, llvm::ICmpInst::ICMP_SGE)); + return Status::OK(); +} + +// Do an bitwise-or of all the overflow bits. 
+llvm::Value* DecimalIR::GetCombinedOverflow( + std::vector<DecimalIR::ValueWithOverflow> vec) { + llvm::Value* res = types()->false_constant(); + for (auto& val : vec) { + res = ir_builder()->CreateOr(res, val.overflow()); + } + return res; +} + +DecimalIR::ValueSplit DecimalIR::ValueSplit::MakeFromInt128(DecimalIR* decimal_ir, + llvm::Value* in) { + auto builder = decimal_ir->ir_builder(); + auto types = decimal_ir->types(); + + auto high = builder->CreateLShr(in, types->i128_constant(64)); + high = builder->CreateTrunc(high, types->i64_type()); + auto low = builder->CreateTrunc(in, types->i64_type()); + return ValueSplit(high, low); +} + +/// Convert IR struct {%i64, %i64} to cpp class ValueSplit +DecimalIR::ValueSplit DecimalIR::ValueSplit::MakeFromStruct(DecimalIR* decimal_ir, + llvm::Value* dstruct) { + auto builder = decimal_ir->ir_builder(); + auto high = builder->CreateExtractValue(dstruct, 0); + auto low = builder->CreateExtractValue(dstruct, 1); + return DecimalIR::ValueSplit(high, low); +} + +llvm::Value* DecimalIR::ValueSplit::AsInt128(DecimalIR* decimal_ir) const { + auto builder = decimal_ir->ir_builder(); + auto types = decimal_ir->types(); + + auto value = builder->CreateSExt(high_, types->i128_type()); + value = builder->CreateShl(value, types->i128_constant(64)); + value = builder->CreateAdd(value, builder->CreateZExt(low_, types->i128_type())); + return value; +} + +/// Convert IR struct {%i128, %i1} to cpp class ValueWithOverflow +DecimalIR::ValueWithOverflow DecimalIR::ValueWithOverflow::MakeFromStruct( + DecimalIR* decimal_ir, llvm::Value* dstruct) { + auto builder = decimal_ir->ir_builder(); + auto value = builder->CreateExtractValue(dstruct, 0); + auto overflow = builder->CreateExtractValue(dstruct, 1); + return DecimalIR::ValueWithOverflow(value, overflow); +} + +/// Convert to IR struct {%i128, %i1} +llvm::Value* DecimalIR::ValueWithOverflow::AsStruct(DecimalIR* decimal_ir) const { + auto builder = decimal_ir->ir_builder(); + + auto undef 
= llvm::UndefValue::get(decimal_ir->i128_with_overflow_struct_type_); + auto struct_val = builder->CreateInsertValue(undef, value(), 0); + return builder->CreateInsertValue(struct_val, overflow(), 1); +} + +/// debug traces +void DecimalIR::AddTrace(const std::string& fmt, std::vector<llvm::Value*> args) { + DCHECK(enable_ir_traces_); + + auto ir_str = ir_builder()->CreateGlobalStringPtr(fmt); + args.insert(args.begin(), ir_str); + ir_builder()->CreateCall(module()->getFunction("printf"), args, "trace"); +} + +void DecimalIR::AddTrace32(const std::string& msg, llvm::Value* value) { + AddTrace("DECIMAL_IR_TRACE:: " + msg + " %d\n", {value}); +} + +void DecimalIR::AddTrace128(const std::string& msg, llvm::Value* value) { + // convert i128 into two i64s for printing + auto split = ValueSplit::MakeFromInt128(this, value); + AddTrace("DECIMAL_IR_TRACE:: " + msg + " %llx:%llx (%lld:%llu)\n", + {split.high(), split.low(), split.high(), split.low()}); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/decimal_ir.h b/src/arrow/cpp/src/gandiva/decimal_ir.h new file mode 100644 index 000000000..b11730f1e --- /dev/null +++ b/src/arrow/cpp/src/gandiva/decimal_ir.h @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <string> +#include <vector> + +#include "gandiva/function_ir_builder.h" + +namespace gandiva { + +/// @brief Decimal IR functions +class DecimalIR : public FunctionIRBuilder { + public: + explicit DecimalIR(Engine* engine) + : FunctionIRBuilder(engine), enable_ir_traces_(false) {} + + /// Build decimal IR functions and add them to the engine. + static Status AddFunctions(Engine* engine); + + void EnableTraces() { enable_ir_traces_ = true; } + + llvm::Value* CallDecimalFunction(const std::string& function_name, + llvm::Type* return_type, + const std::vector<llvm::Value*>& args); + + private: + /// The intrinsic fn for divide with small divisors is about 10x slower, so not + /// using these. + static const bool kUseOverflowIntrinsics = false; + + // Holder for an i128 value, along with its with scale and precision. + class ValueFull { + public: + ValueFull(llvm::Value* value, llvm::Value* precision, llvm::Value* scale) + : value_(value), precision_(precision), scale_(scale) {} + + llvm::Value* value() const { return value_; } + llvm::Value* precision() const { return precision_; } + llvm::Value* scale() const { return scale_; } + + private: + llvm::Value* value_; + llvm::Value* precision_; + llvm::Value* scale_; + }; + + // Holder for an i128 value, and a boolean indicating overflow. 
+ class ValueWithOverflow { + public: + ValueWithOverflow(llvm::Value* value, llvm::Value* overflow) + : value_(value), overflow_(overflow) {} + + // Make from IR struct + static ValueWithOverflow MakeFromStruct(DecimalIR* decimal_ir, llvm::Value* dstruct); + + // Build a corresponding IR struct + llvm::Value* AsStruct(DecimalIR* decimal_ir) const; + + llvm::Value* value() const { return value_; } + llvm::Value* overflow() const { return overflow_; } + + private: + llvm::Value* value_; + llvm::Value* overflow_; + }; + + // Holder for an i128 value that is split into two i64s + class ValueSplit { + public: + ValueSplit(llvm::Value* high, llvm::Value* low) : high_(high), low_(low) {} + + // Make from i128 value + static ValueSplit MakeFromInt128(DecimalIR* decimal_ir, llvm::Value* in); + + // Make from IR struct + static ValueSplit MakeFromStruct(DecimalIR* decimal_ir, llvm::Value* dstruct); + + // Combine the two parts into an i128 + llvm::Value* AsInt128(DecimalIR* decimal_ir) const; + + llvm::Value* high() const { return high_; } + llvm::Value* low() const { return low_; } + + private: + llvm::Value* high_; + llvm::Value* low_; + }; + + // Add global variables to the module. + static void AddGlobals(Engine* engine); + + // Initialize intrinsic functions that are used by decimal operations. + void InitializeIntrinsics(); + + // Create IR builder for decimal add function. + static Status MakeAdd(Engine* engine, std::shared_ptr<FunctionIRBuilder>* out); + + // Get the multiplier for specified scale (i.e 10^scale) + llvm::Value* GetScaleMultiplier(llvm::Value* scale); + + // Get the higher of the two scales + llvm::Value* GetHigherScale(llvm::Value* x_scale, llvm::Value* y_scale); + + // Increase scale of 'in_value' by 'increase_scale_by'. + // - If 'increase_scale_by' is <= 0, does nothing. + llvm::Value* IncreaseScale(llvm::Value* in_value, llvm::Value* increase_scale_by); + + // Similar to IncreaseScale. but, also check if there is overflow. 
+ ValueWithOverflow IncreaseScaleWithOverflowCheck(llvm::Value* in_value, + llvm::Value* increase_scale_by); + + // Reduce scale of 'in_value' by 'reduce_scale_by'. + // - If 'reduce_scale_by' is <= 0, does nothing. + llvm::Value* ReduceScale(llvm::Value* in_value, llvm::Value* reduce_scale_by); + + // Fast path of add: guaranteed no overflow + llvm::Value* AddFastPath(const ValueFull& x, const ValueFull& y); + + // Similar to AddFastPath, but check if there's an overflow. + ValueWithOverflow AddWithOverflowCheck(const ValueFull& x, const ValueFull& y, + const ValueFull& out); + + // Do addition of large integers (both positive and negative). + llvm::Value* AddLarge(const ValueFull& x, const ValueFull& y, const ValueFull& out); + + // Get the combined overflow (logical or). + llvm::Value* GetCombinedOverflow(std::vector<ValueWithOverflow> values); + + // Build the function for adding decimals. + Status BuildAdd(); + + // Build the function for decimal subtraction. + Status BuildSubtract(); + + // Build the function for decimal multiplication. + Status BuildMultiply(); + + // Build the function for decimal division/mod. + Status BuildDivideOrMod(const std::string& function_name, + const std::string& internal_name); + + Status BuildCompare(const std::string& function_name, + llvm::ICmpInst::Predicate cmp_instruction); + + Status BuildDecimalFunction(const std::string& function_name, llvm::Type* return_type, + std::vector<NamedArg> in_types); + + // Add a trace in IR code. + void AddTrace(const std::string& fmt, std::vector<llvm::Value*> args); + + // Add a trace msg along with a 32-bit integer. + void AddTrace32(const std::string& msg, llvm::Value* value); + + // Add a trace msg along with a 128-bit integer. + void AddTrace128(const std::string& msg, llvm::Value* value); + + // name of the global variable having the array of scale multipliers. 
+ static const char* kScaleMultipliersName; + + // Intrinsic functions + llvm::Function* sadd_with_overflow_fn_; + llvm::Function* smul_with_overflow_fn_; + + // struct { i128: value, i1: overflow} + llvm::Type* i128_with_overflow_struct_type_; + + // if set to true, ir traces are enabled. Useful for debugging. + bool enable_ir_traces_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/decimal_scalar.h b/src/arrow/cpp/src/gandiva/decimal_scalar.h new file mode 100644 index 000000000..a03807b35 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/decimal_scalar.h @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License + +#pragma once + +#include <cstdint> +#include <string> +#include "arrow/util/decimal.h" +#include "arrow/util/hash_util.h" +#include "gandiva/basic_decimal_scalar.h" + +namespace gandiva { + +using Decimal128 = arrow::Decimal128; + +/// Represents a 128-bit decimal value along with its precision and scale. +/// +/// BasicDecimalScalar128 can be safely compiled to IR without references to libstdc++. +/// This class has additional functionality on top of BasicDecimalScalar128 to deal with +/// strings and streams. 
+class DecimalScalar128 : public BasicDecimalScalar128 { + public: + using BasicDecimalScalar128::BasicDecimalScalar128; + + DecimalScalar128(const std::string& value, int32_t precision, int32_t scale) + : BasicDecimalScalar128(Decimal128(value), precision, scale) {} + + /// \brief constructor creates a DecimalScalar128 from a BasicDecimalScalar128. + constexpr DecimalScalar128(const BasicDecimalScalar128& scalar) noexcept + : BasicDecimalScalar128(scalar) {} + + inline std::string ToString() const { + Decimal128 dvalue(value()); + return dvalue.ToString(0) + "," + std::to_string(precision()) + "," + + std::to_string(scale()); + } + + friend std::ostream& operator<<(std::ostream& os, const DecimalScalar128& dec) { + os << dec.ToString(); + return os; + } +}; + +} // namespace gandiva + +namespace std { +template <> +struct hash<gandiva::DecimalScalar128> { + std::size_t operator()(gandiva::DecimalScalar128 const& s) const noexcept { + arrow::BasicDecimal128 dvalue(s.value()); + + static const int kSeedValue = 4; + size_t result = kSeedValue; + + arrow::internal::hash_combine(result, dvalue.high_bits()); + arrow::internal::hash_combine(result, dvalue.low_bits()); + arrow::internal::hash_combine(result, s.precision()); + arrow::internal::hash_combine(result, s.scale()); + return result; + } +}; +} // namespace std diff --git a/src/arrow/cpp/src/gandiva/decimal_type_util.cc b/src/arrow/cpp/src/gandiva/decimal_type_util.cc new file mode 100644 index 000000000..2abc5a21e --- /dev/null +++ b/src/arrow/cpp/src/gandiva/decimal_type_util.cc @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/decimal_type_util.h" +#include "arrow/util/logging.h" + +namespace gandiva { + +constexpr int32_t DecimalTypeUtil::kMinAdjustedScale; + +#define DCHECK_TYPE(type) \ + { \ + DCHECK_GE(type->scale(), 0); \ + DCHECK_LE(type->precision(), kMaxPrecision); \ + } + +// Implementation of decimal rules. +Status DecimalTypeUtil::GetResultType(Op op, const Decimal128TypeVector& in_types, + Decimal128TypePtr* out_type) { + DCHECK_EQ(in_types.size(), 2); + + *out_type = nullptr; + auto t1 = in_types[0]; + auto t2 = in_types[1]; + DCHECK_TYPE(t1); + DCHECK_TYPE(t2); + + int32_t s1 = t1->scale(); + int32_t s2 = t2->scale(); + int32_t p1 = t1->precision(); + int32_t p2 = t2->precision(); + int32_t result_scale = 0; + int32_t result_precision = 0; + + switch (op) { + case kOpAdd: + case kOpSubtract: + result_scale = std::max(s1, s2); + result_precision = std::max(p1 - s1, p2 - s2) + result_scale + 1; + break; + + case kOpMultiply: + result_scale = s1 + s2; + result_precision = p1 + p2 + 1; + break; + + case kOpDivide: + result_scale = std::max(kMinAdjustedScale, s1 + p2 + 1); + result_precision = p1 - s1 + s2 + result_scale; + break; + + case kOpMod: + result_scale = std::max(s1, s2); + result_precision = std::min(p1 - s1, p2 - s2) + result_scale; + break; + } + *out_type = MakeAdjustedType(result_precision, result_scale); + return Status::OK(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/decimal_type_util.h b/src/arrow/cpp/src/gandiva/decimal_type_util.h new file mode 100644 index 000000000..2b496f6cb --- /dev/null +++ 
b/src/arrow/cpp/src/gandiva/decimal_type_util.h @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Adapted from Apache Impala + +#pragma once + +#include <algorithm> +#include <memory> + +#include "gandiva/arrow.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// @brief Handles conversion of scale/precision for operations on decimal types. +/// TODO : do validations for all of these. +class GANDIVA_EXPORT DecimalTypeUtil { + public: + enum Op { + kOpAdd, + kOpSubtract, + kOpMultiply, + kOpDivide, + kOpMod, + }; + + /// The maximum precision representable by a 4-byte decimal + static constexpr int32_t kMaxDecimal32Precision = 9; + + /// The maximum precision representable by a 8-byte decimal + static constexpr int32_t kMaxDecimal64Precision = 18; + + /// The maximum precision representable by a 16-byte decimal + static constexpr int32_t kMaxPrecision = 38; + + // The maximum scale representable. + static constexpr int32_t kMaxScale = kMaxPrecision; + + // When operating on decimal inputs, the integer part of the output can exceed the + // max precision. In such cases, the scale can be reduced, up to a minimum of + // kMinAdjustedScale. 
+ // * There is no strong reason for 6, but both SQLServer and Impala use 6 too. + static constexpr int32_t kMinAdjustedScale = 6; + + // For specified operation and input scale/precision, determine the output + // scale/precision. + static Status GetResultType(Op op, const Decimal128TypeVector& in_types, + Decimal128TypePtr* out_type); + + static Decimal128TypePtr MakeType(int32_t precision, int32_t scale) { + return std::dynamic_pointer_cast<arrow::Decimal128Type>( + arrow::decimal(precision, scale)); + } + + private: + // Reduce the scale if possible so that precision stays <= kMaxPrecision + static Decimal128TypePtr MakeAdjustedType(int32_t precision, int32_t scale) { + if (precision > kMaxPrecision) { + int32_t min_scale = std::min(scale, kMinAdjustedScale); + int32_t delta = precision - kMaxPrecision; + precision = kMaxPrecision; + scale = std::max(scale - delta, min_scale); + } + return MakeType(precision, scale); + } +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/decimal_type_util_test.cc b/src/arrow/cpp/src/gandiva/decimal_type_util_test.cc new file mode 100644 index 000000000..98ea0bb16 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/decimal_type_util_test.cc @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +// Adapted from Apache Impala + +#include <gtest/gtest.h> + +#include "gandiva/decimal_type_util.h" +#include "tests/test_util.h" + +namespace gandiva { + +#define DECIMAL_TYPE(p, s) DecimalTypeUtil::MakeType(p, s) + +Decimal128TypePtr DoOp(DecimalTypeUtil::Op op, Decimal128TypePtr d1, + Decimal128TypePtr d2) { + Decimal128TypePtr ret_type; + ARROW_EXPECT_OK(DecimalTypeUtil::GetResultType(op, {d1, d2}, &ret_type)); + return ret_type; +} + +TEST(DecimalResultTypes, Basic) { + EXPECT_ARROW_TYPE_EQUALS( + DECIMAL_TYPE(31, 10), + DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(30, 10), DECIMAL_TYPE(30, 10))); + + EXPECT_ARROW_TYPE_EQUALS( + DECIMAL_TYPE(32, 6), + DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(30, 6), DECIMAL_TYPE(30, 5))); + + EXPECT_ARROW_TYPE_EQUALS( + DECIMAL_TYPE(38, 9), + DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(30, 10), DECIMAL_TYPE(38, 10))); + + EXPECT_ARROW_TYPE_EQUALS( + DECIMAL_TYPE(38, 9), + DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(38, 10), DECIMAL_TYPE(38, 38))); + + EXPECT_ARROW_TYPE_EQUALS( + DECIMAL_TYPE(38, 6), + DoOp(DecimalTypeUtil::kOpAdd, DECIMAL_TYPE(38, 10), DECIMAL_TYPE(38, 2))); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/decimal_xlarge.cc b/src/arrow/cpp/src/gandiva/decimal_xlarge.cc new file mode 100644 index 000000000..caebd8b09 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/decimal_xlarge.cc @@ -0,0 +1,284 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Operations that can deal with very large values (256-bit). +// +// The intermediate results with decimal can be larger than what can fit into 128-bit, +// but the final results can fit in 128-bit after scaling down. These functions deal +// with operations on the intermediate values. +// + +#include "gandiva/decimal_xlarge.h" + +#include <boost/multiprecision/cpp_int.hpp> +#include <limits> +#include <vector> + +#include "arrow/util/basic_decimal.h" +#include "arrow/util/logging.h" +#include "gandiva/decimal_type_util.h" + +#ifndef GANDIVA_UNIT_TEST +#include "gandiva/engine.h" +#include "gandiva/exported_funcs.h" + +namespace gandiva { + +void ExportedDecimalFunctions::AddMappings(Engine* engine) const { + std::vector<llvm::Type*> args; + auto types = engine->types(); + + // gdv_multiply_and_scale_down + args = {types->i64_type(), // int64_t x_high + types->i64_type(), // uint64_t x_low + types->i64_type(), // int64_t y_high + types->i64_type(), // uint64_t x_low + types->i32_type(), // int32_t reduce_scale_by + types->i64_ptr_type(), // int64_t* out_high + types->i64_ptr_type(), // uint64_t* out_low + types->i8_ptr_type()}; // bool* overflow + + engine->AddGlobalMappingForFunc( + "gdv_xlarge_multiply_and_scale_down", types->void_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_xlarge_multiply_and_scale_down)); + + // gdv_xlarge_scale_up_and_divide + args = {types->i64_type(), // int64_t x_high + types->i64_type(), // uint64_t x_low + types->i64_type(), // int64_t y_high + types->i64_type(), // uint64_t y_low + types->i32_type(), 
// int32_t increase_scale_by + types->i64_ptr_type(), // int64_t* out_high + types->i64_ptr_type(), // uint64_t* out_low + types->i8_ptr_type()}; // bool* overflow + + engine->AddGlobalMappingForFunc( + "gdv_xlarge_scale_up_and_divide", types->void_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_xlarge_scale_up_and_divide)); + + // gdv_xlarge_mod + args = {types->i64_type(), // int64_t x_high + types->i64_type(), // uint64_t x_low + types->i32_type(), // int32_t x_scale + types->i64_type(), // int64_t y_high + types->i64_type(), // uint64_t y_low + types->i32_type(), // int32_t y_scale + types->i64_ptr_type(), // int64_t* out_high + types->i64_ptr_type()}; // uint64_t* out_low + + engine->AddGlobalMappingForFunc("gdv_xlarge_mod", types->void_type() /*return_type*/, + args, reinterpret_cast<void*>(gdv_xlarge_mod)); + + // gdv_xlarge_compare + args = {types->i64_type(), // int64_t x_high + types->i64_type(), // uint64_t x_low + types->i32_type(), // int32_t x_scale + types->i64_type(), // int64_t y_high + types->i64_type(), // uint64_t y_low + types->i32_type()}; // int32_t y_scale + + engine->AddGlobalMappingForFunc("gdv_xlarge_compare", types->i32_type() /*return_type*/, + args, reinterpret_cast<void*>(gdv_xlarge_compare)); +} + +} // namespace gandiva + +#endif // !GANDIVA_UNIT_TEST + +using arrow::BasicDecimal128; +using boost::multiprecision::int256_t; + +namespace gandiva { +namespace internal { + +// Convert to 256-bit integer from 128-bit decimal. +static int256_t ConvertToInt256(BasicDecimal128 in) { + int256_t v = in.high_bits(); + v <<= 64; + v |= in.low_bits(); + return v; +} + +// Convert to 128-bit decimal from 256-bit integer. +// If there is an overflow, the output is undefined. 
+static BasicDecimal128 ConvertToDecimal128(int256_t in, bool* overflow) { + BasicDecimal128 result; + constexpr int256_t UINT64_MASK = std::numeric_limits<uint64_t>::max(); + + int256_t in_abs = abs(in); + bool is_negative = in < 0; + + uint64_t low = (in_abs & UINT64_MASK).convert_to<uint64_t>(); + in_abs >>= 64; + uint64_t high = (in_abs & UINT64_MASK).convert_to<uint64_t>(); + in_abs >>= 64; + + if (in_abs > 0) { + // we've shifted in by 128-bit, so nothing should be left. + *overflow = true; + } else if (high > INT64_MAX) { + // the high-bit must not be set (signed 128-bit). + *overflow = true; + } else { + result = BasicDecimal128(static_cast<int64_t>(high), low); + if (result > BasicDecimal128::GetMaxValue()) { + *overflow = true; + } + } + return is_negative ? -result : result; +} + +static constexpr int32_t kMaxLargeScale = 2 * DecimalTypeUtil::kMaxPrecision; + +// Compute the scale multipliers once. +static std::array<int256_t, kMaxLargeScale + 1> kLargeScaleMultipliers = + ([]() -> std::array<int256_t, kMaxLargeScale + 1> { + std::array<int256_t, kMaxLargeScale + 1> values; + values[0] = 1; + for (int32_t idx = 1; idx <= kMaxLargeScale; idx++) { + values[idx] = values[idx - 1] * 10; + } + return values; + })(); + +static int256_t GetScaleMultiplier(int scale) { + DCHECK_GE(scale, 0); + DCHECK_LE(scale, kMaxLargeScale); + + return kLargeScaleMultipliers[scale]; +} + +// divide input by 10^reduce_by, and round up the fractional part. +static int256_t ReduceScaleBy(int256_t in, int32_t reduce_by) { + if (reduce_by == 0) { + // nothing to do. + return in; + } + + int256_t divisor = GetScaleMultiplier(reduce_by); + DCHECK_GT(divisor, 0); + DCHECK_EQ(divisor % 2, 0); // multiple of 10. + auto result = in / divisor; + auto remainder = in % divisor; + // round up (same as BasicDecimal128::ReduceScaleBy) + if (abs(remainder) >= (divisor >> 1)) { + result += (in > 0 ? 1 : -1); + } + return result; +} + +// multiply input by 10^increase_by. 
+static int256_t IncreaseScaleBy(int256_t in, int32_t increase_by) { + DCHECK_GE(increase_by, 0); + DCHECK_LE(increase_by, 2 * DecimalTypeUtil::kMaxPrecision); + + return in * GetScaleMultiplier(increase_by); +} + +} // namespace internal +} // namespace gandiva + +extern "C" { + +void gdv_xlarge_multiply_and_scale_down(int64_t x_high, uint64_t x_low, int64_t y_high, + uint64_t y_low, int32_t reduce_scale_by, + int64_t* out_high, uint64_t* out_low, + bool* overflow) { + BasicDecimal128 x{x_high, x_low}; + BasicDecimal128 y{y_high, y_low}; + auto intermediate_result = + gandiva::internal::ConvertToInt256(x) * gandiva::internal::ConvertToInt256(y); + intermediate_result = + gandiva::internal::ReduceScaleBy(intermediate_result, reduce_scale_by); + auto result = gandiva::internal::ConvertToDecimal128(intermediate_result, overflow); + *out_high = result.high_bits(); + *out_low = result.low_bits(); +} + +void gdv_xlarge_scale_up_and_divide(int64_t x_high, uint64_t x_low, int64_t y_high, + uint64_t y_low, int32_t increase_scale_by, + int64_t* out_high, uint64_t* out_low, + bool* overflow) { + BasicDecimal128 x{x_high, x_low}; + BasicDecimal128 y{y_high, y_low}; + + int256_t x_large = gandiva::internal::ConvertToInt256(x); + int256_t x_large_scaled_up = + gandiva::internal::IncreaseScaleBy(x_large, increase_scale_by); + int256_t y_large = gandiva::internal::ConvertToInt256(y); + int256_t result_large = x_large_scaled_up / y_large; + int256_t remainder_large = x_large_scaled_up % y_large; + + // Since we are scaling up and then, scaling down, round-up the result (+1 for +ve, + // -1 for -ve), if the remainder is >= 2 * divisor. 
+ if (abs(2 * remainder_large) >= abs(y_large)) { + // x +ve and y +ve, result is +ve => (1 ^ 1) + 1 = 0 + 1 = +1 + // x +ve and y -ve, result is -ve => (-1 ^ 1) + 1 = -2 + 1 = -1 + // x +ve and y -ve, result is -ve => (1 ^ -1) + 1 = -2 + 1 = -1 + // x -ve and y -ve, result is +ve => (-1 ^ -1) + 1 = 0 + 1 = +1 + result_large += (x.Sign() ^ y.Sign()) + 1; + } + auto result = gandiva::internal::ConvertToDecimal128(result_large, overflow); + *out_high = result.high_bits(); + *out_low = result.low_bits(); +} + +void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t y_high, + uint64_t y_low, int32_t y_scale, int64_t* out_high, + uint64_t* out_low) { + BasicDecimal128 x{x_high, x_low}; + BasicDecimal128 y{y_high, y_low}; + + int256_t x_large = gandiva::internal::ConvertToInt256(x); + int256_t y_large = gandiva::internal::ConvertToInt256(y); + if (x_scale < y_scale) { + x_large = gandiva::internal::IncreaseScaleBy(x_large, y_scale - x_scale); + } else { + y_large = gandiva::internal::IncreaseScaleBy(y_large, x_scale - y_scale); + } + auto intermediate_result = x_large % y_large; + bool overflow = false; + auto result = gandiva::internal::ConvertToDecimal128(intermediate_result, &overflow); + DCHECK_EQ(overflow, false); + + *out_high = result.high_bits(); + *out_low = result.low_bits(); +} + +int32_t gdv_xlarge_compare(int64_t x_high, uint64_t x_low, int32_t x_scale, + int64_t y_high, uint64_t y_low, int32_t y_scale) { + BasicDecimal128 x{x_high, x_low}; + BasicDecimal128 y{y_high, y_low}; + + int256_t x_large = gandiva::internal::ConvertToInt256(x); + int256_t y_large = gandiva::internal::ConvertToInt256(y); + if (x_scale < y_scale) { + x_large = gandiva::internal::IncreaseScaleBy(x_large, y_scale - x_scale); + } else { + y_large = gandiva::internal::IncreaseScaleBy(y_large, x_scale - y_scale); + } + + if (x_large == y_large) { + return 0; + } else if (x_large < y_large) { + return -1; + } else { + return 1; + } +} + +} // extern "C" diff --git 
a/src/arrow/cpp/src/gandiva/decimal_xlarge.h b/src/arrow/cpp/src/gandiva/decimal_xlarge.h new file mode 100644 index 000000000..264329775 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/decimal_xlarge.h @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> + +/// Stub functions to deal with extra large decimals that can be accessed from LLVM-IR +/// code. 
+extern "C" { + +void gdv_xlarge_multiply_and_scale_down(int64_t x_high, uint64_t x_low, int64_t y_high, + uint64_t y_low, int32_t reduce_scale_by, + int64_t* out_high, uint64_t* out_low, + bool* overflow); + +void gdv_xlarge_scale_up_and_divide(int64_t x_high, uint64_t x_low, int64_t y_high, + uint64_t y_low, int32_t increase_scale_by, + int64_t* out_high, uint64_t* out_low, bool* overflow); + +void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t y_high, + uint64_t y_low, int32_t y_scale, int64_t* out_high, + uint64_t* out_low); + +int32_t gdv_xlarge_compare(int64_t x_high, uint64_t x_low, int32_t x_scale, + int64_t y_high, uint64_t y_low, int32_t y_scale); +} diff --git a/src/arrow/cpp/src/gandiva/dex.h b/src/arrow/cpp/src/gandiva/dex.h new file mode 100644 index 000000000..d1115c051 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/dex.h @@ -0,0 +1,396 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <memory> +#include <string> +#include <unordered_set> +#include <vector> + +#include "gandiva/dex_visitor.h" +#include "gandiva/field_descriptor.h" +#include "gandiva/func_descriptor.h" +#include "gandiva/function_holder.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/in_holder.h" +#include "gandiva/literal_holder.h" +#include "gandiva/native_function.h" +#include "gandiva/value_validity_pair.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// \brief Decomposed expression : the validity and value are separated. +class GANDIVA_EXPORT Dex { + public: + /// Derived classes should simply invoke the Visit api of the visitor. + virtual void Accept(DexVisitor& visitor) = 0; + virtual ~Dex() = default; +}; + +/// Base class for other Vector related Dex. +class GANDIVA_EXPORT VectorReadBaseDex : public Dex { + public: + explicit VectorReadBaseDex(FieldDescriptorPtr field_desc) : field_desc_(field_desc) {} + + const std::string& FieldName() const { return field_desc_->Name(); } + + DataTypePtr FieldType() const { return field_desc_->Type(); } + + FieldPtr Field() const { return field_desc_->field(); } + + protected: + FieldDescriptorPtr field_desc_; +}; + +/// validity component of a ValueVector +class GANDIVA_EXPORT VectorReadValidityDex : public VectorReadBaseDex { + public: + explicit VectorReadValidityDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int ValidityIdx() const { return field_desc_->validity_idx(); } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + +/// value component of a fixed-len ValueVector +class GANDIVA_EXPORT VectorReadFixedLenValueDex : public VectorReadBaseDex { + public: + explicit VectorReadFixedLenValueDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + +/// value component of a 
variable-len ValueVector +class GANDIVA_EXPORT VectorReadVarLenValueDex : public VectorReadBaseDex { + public: + explicit VectorReadVarLenValueDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + int OffsetsIdx() const { return field_desc_->offsets_idx(); } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + +/// validity based on a local bitmap. +class GANDIVA_EXPORT LocalBitMapValidityDex : public Dex { + public: + explicit LocalBitMapValidityDex(int local_bitmap_idx) + : local_bitmap_idx_(local_bitmap_idx) {} + + int local_bitmap_idx() const { return local_bitmap_idx_; } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } + + private: + int local_bitmap_idx_; +}; + +/// base function expression +class GANDIVA_EXPORT FuncDex : public Dex { + public: + FuncDex(FuncDescriptorPtr func_descriptor, const NativeFunction* native_function, + FunctionHolderPtr function_holder, const ValueValidityPairVector& args) + : func_descriptor_(func_descriptor), + native_function_(native_function), + function_holder_(function_holder), + args_(args) {} + + FuncDescriptorPtr func_descriptor() const { return func_descriptor_; } + + const NativeFunction* native_function() const { return native_function_; } + + FunctionHolderPtr function_holder() const { return function_holder_; } + + const ValueValidityPairVector& args() const { return args_; } + + private: + FuncDescriptorPtr func_descriptor_; + const NativeFunction* native_function_; + FunctionHolderPtr function_holder_; + ValueValidityPairVector args_; +}; + +/// A function expression that only deals with non-null inputs, and generates non-null +/// outputs. 
+class GANDIVA_EXPORT NonNullableFuncDex : public FuncDex { + public: + NonNullableFuncDex(FuncDescriptorPtr func_descriptor, + const NativeFunction* native_function, + FunctionHolderPtr function_holder, + const ValueValidityPairVector& args) + : FuncDex(func_descriptor, native_function, function_holder, args) {} + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + +/// A function expression that deals with nullable inputs, but generates non-null +/// outputs. +class GANDIVA_EXPORT NullableNeverFuncDex : public FuncDex { + public: + NullableNeverFuncDex(FuncDescriptorPtr func_descriptor, + const NativeFunction* native_function, + FunctionHolderPtr function_holder, + const ValueValidityPairVector& args) + : FuncDex(func_descriptor, native_function, function_holder, args) {} + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + +/// A function expression that deals with nullable inputs, and +/// nullable outputs. +class GANDIVA_EXPORT NullableInternalFuncDex : public FuncDex { + public: + NullableInternalFuncDex(FuncDescriptorPtr func_descriptor, + const NativeFunction* native_function, + FunctionHolderPtr function_holder, + const ValueValidityPairVector& args, int local_bitmap_idx) + : FuncDex(func_descriptor, native_function, function_holder, args), + local_bitmap_idx_(local_bitmap_idx) {} + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } + + /// The validity of the function result is saved in this bitmap. + int local_bitmap_idx() const { return local_bitmap_idx_; } + + private: + int local_bitmap_idx_; +}; + +/// special validity type that always returns true. +class GANDIVA_EXPORT TrueDex : public Dex { + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + +/// special validity type that always returns false. +class GANDIVA_EXPORT FalseDex : public Dex { + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + +/// decomposed expression for a literal. 
+class GANDIVA_EXPORT LiteralDex : public Dex { + public: + LiteralDex(DataTypePtr type, const LiteralHolder& holder) + : type_(type), holder_(holder) {} + + const DataTypePtr& type() const { return type_; } + + const LiteralHolder& holder() const { return holder_; } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } + + private: + DataTypePtr type_; + LiteralHolder holder_; +}; + +/// decomposed if-else expression. +class GANDIVA_EXPORT IfDex : public Dex { + public: + IfDex(ValueValidityPairPtr condition_vv, ValueValidityPairPtr then_vv, + ValueValidityPairPtr else_vv, DataTypePtr result_type, int local_bitmap_idx, + bool is_terminal_else) + : condition_vv_(condition_vv), + then_vv_(then_vv), + else_vv_(else_vv), + result_type_(result_type), + local_bitmap_idx_(local_bitmap_idx), + is_terminal_else_(is_terminal_else) {} + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } + + const ValueValidityPair& condition_vv() const { return *condition_vv_; } + const ValueValidityPair& then_vv() const { return *then_vv_; } + const ValueValidityPair& else_vv() const { return *else_vv_; } + + /// The validity of the result is saved in this bitmap. + int local_bitmap_idx() const { return local_bitmap_idx_; } + + /// is this a terminal else ? i.e no nested if-else underneath. + bool is_terminal_else() const { return is_terminal_else_; } + + const DataTypePtr& result_type() const { return result_type_; } + + private: + ValueValidityPairPtr condition_vv_; + ValueValidityPairPtr then_vv_; + ValueValidityPairPtr else_vv_; + DataTypePtr result_type_; + int local_bitmap_idx_; + bool is_terminal_else_; +}; + +// decomposed boolean expression. 
+class GANDIVA_EXPORT BooleanDex : public Dex { + public: + BooleanDex(const ValueValidityPairVector& args, int local_bitmap_idx) + : args_(args), local_bitmap_idx_(local_bitmap_idx) {} + + const ValueValidityPairVector& args() const { return args_; } + + /// The validity of the result is saved in this bitmap. + int local_bitmap_idx() const { return local_bitmap_idx_; } + + private: + ValueValidityPairVector args_; + int local_bitmap_idx_; +}; + +/// Boolean-AND expression +class GANDIVA_EXPORT BooleanAndDex : public BooleanDex { + public: + BooleanAndDex(const ValueValidityPairVector& args, int local_bitmap_idx) + : BooleanDex(args, local_bitmap_idx) {} + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + +/// Boolean-OR expression +class GANDIVA_EXPORT BooleanOrDex : public BooleanDex { + public: + BooleanOrDex(const ValueValidityPairVector& args, int local_bitmap_idx) + : BooleanDex(args, local_bitmap_idx) {} + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + +// decomposed in expression. 
+template <typename Type> +class InExprDex; + +template <typename Type> +class InExprDexBase : public Dex { + public: + InExprDexBase(const ValueValidityPairVector& args, + const std::unordered_set<Type>& values) + : args_(args) { + in_holder_.reset(new InHolder<Type>(values)); + } + + const ValueValidityPairVector& args() const { return args_; } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } + + const std::string& runtime_function() const { return runtime_function_; } + + const std::shared_ptr<InHolder<Type>>& in_holder() const { return in_holder_; } + + protected: + ValueValidityPairVector args_; + std::string runtime_function_; + std::shared_ptr<InHolder<Type>> in_holder_; +}; + +template <> +class InExprDexBase<gandiva::DecimalScalar128> : public Dex { + public: + InExprDexBase(const ValueValidityPairVector& args, + const std::unordered_set<gandiva::DecimalScalar128>& values, + int32_t precision, int32_t scale) + : args_(args), precision_(precision), scale_(scale) { + in_holder_.reset(new InHolder<gandiva::DecimalScalar128>(values)); + } + + int32_t get_precision() const { return precision_; } + + int32_t get_scale() const { return scale_; } + + const ValueValidityPairVector& args() const { return args_; } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } + + const std::string& runtime_function() const { return runtime_function_; } + + const std::shared_ptr<InHolder<gandiva::DecimalScalar128>>& in_holder() const { + return in_holder_; + } + + protected: + ValueValidityPairVector args_; + std::string runtime_function_; + std::shared_ptr<InHolder<gandiva::DecimalScalar128>> in_holder_; + int32_t precision_, scale_; +}; + +template <> +class InExprDex<int32_t> : public InExprDexBase<int32_t> { + public: + InExprDex(const ValueValidityPairVector& args, + const std::unordered_set<int32_t>& values) + : InExprDexBase(args, values) { + runtime_function_ = "gdv_fn_in_expr_lookup_int32"; + } +}; + +template <> +class 
InExprDex<int64_t> : public InExprDexBase<int64_t> { + public: + InExprDex(const ValueValidityPairVector& args, + const std::unordered_set<int64_t>& values) + : InExprDexBase(args, values) { + runtime_function_ = "gdv_fn_in_expr_lookup_int64"; + } +}; + +template <> +class InExprDex<float> : public InExprDexBase<float> { + public: + InExprDex(const ValueValidityPairVector& args, const std::unordered_set<float>& values) + : InExprDexBase(args, values) { + runtime_function_ = "gdv_fn_in_expr_lookup_float"; + } +}; + +template <> +class InExprDex<double> : public InExprDexBase<double> { + public: + InExprDex(const ValueValidityPairVector& args, const std::unordered_set<double>& values) + : InExprDexBase(args, values) { + runtime_function_ = "gdv_fn_in_expr_lookup_double"; + } +}; + +template <> +class InExprDex<gandiva::DecimalScalar128> + : public InExprDexBase<gandiva::DecimalScalar128> { + public: + InExprDex(const ValueValidityPairVector& args, + const std::unordered_set<gandiva::DecimalScalar128>& values, + int32_t precision, int32_t scale) + : InExprDexBase<gandiva::DecimalScalar128>(args, values, precision, scale) { + runtime_function_ = "gdv_fn_in_expr_lookup_decimal"; + } +}; + +template <> +class InExprDex<std::string> : public InExprDexBase<std::string> { + public: + InExprDex(const ValueValidityPairVector& args, + const std::unordered_set<std::string>& values) + : InExprDexBase(args, values) { + runtime_function_ = "gdv_fn_in_expr_lookup_utf8"; + } +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/dex_visitor.h b/src/arrow/cpp/src/gandiva/dex_visitor.h new file mode 100644 index 000000000..5d160bb22 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/dex_visitor.h @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cmath> +#include <string> + +#include "arrow/util/logging.h" +#include "gandiva/decimal_scalar.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +class VectorReadValidityDex; +class VectorReadFixedLenValueDex; +class VectorReadVarLenValueDex; +class LocalBitMapValidityDex; +class LiteralDex; +class TrueDex; +class FalseDex; +class NonNullableFuncDex; +class NullableNeverFuncDex; +class NullableInternalFuncDex; +class IfDex; +class BooleanAndDex; +class BooleanOrDex; +template <typename Type> +class InExprDexBase; + +/// \brief Visitor for decomposed expression. 
+class GANDIVA_EXPORT DexVisitor { + public: + virtual ~DexVisitor() = default; + + virtual void Visit(const VectorReadValidityDex& dex) = 0; + virtual void Visit(const VectorReadFixedLenValueDex& dex) = 0; + virtual void Visit(const VectorReadVarLenValueDex& dex) = 0; + virtual void Visit(const LocalBitMapValidityDex& dex) = 0; + virtual void Visit(const TrueDex& dex) = 0; + virtual void Visit(const FalseDex& dex) = 0; + virtual void Visit(const LiteralDex& dex) = 0; + virtual void Visit(const NonNullableFuncDex& dex) = 0; + virtual void Visit(const NullableNeverFuncDex& dex) = 0; + virtual void Visit(const NullableInternalFuncDex& dex) = 0; + virtual void Visit(const IfDex& dex) = 0; + virtual void Visit(const BooleanAndDex& dex) = 0; + virtual void Visit(const BooleanOrDex& dex) = 0; + virtual void Visit(const InExprDexBase<int32_t>& dex) = 0; + virtual void Visit(const InExprDexBase<int64_t>& dex) = 0; + virtual void Visit(const InExprDexBase<float>& dex) = 0; + virtual void Visit(const InExprDexBase<double>& dex) = 0; + virtual void Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) = 0; + virtual void Visit(const InExprDexBase<std::string>& dex) = 0; +}; + +/// Default implementation with only DCHECK(). 
+#define VISIT_DCHECK(DEX_CLASS) \ + void Visit(const DEX_CLASS& dex) override { DCHECK(0); } + +class GANDIVA_EXPORT DexDefaultVisitor : public DexVisitor { + VISIT_DCHECK(VectorReadValidityDex) + VISIT_DCHECK(VectorReadFixedLenValueDex) + VISIT_DCHECK(VectorReadVarLenValueDex) + VISIT_DCHECK(LocalBitMapValidityDex) + VISIT_DCHECK(TrueDex) + VISIT_DCHECK(FalseDex) + VISIT_DCHECK(LiteralDex) + VISIT_DCHECK(NonNullableFuncDex) + VISIT_DCHECK(NullableNeverFuncDex) + VISIT_DCHECK(NullableInternalFuncDex) + VISIT_DCHECK(IfDex) + VISIT_DCHECK(BooleanAndDex) + VISIT_DCHECK(BooleanOrDex) + VISIT_DCHECK(InExprDexBase<int32_t>) + VISIT_DCHECK(InExprDexBase<int64_t>) + VISIT_DCHECK(InExprDexBase<float>) + VISIT_DCHECK(InExprDexBase<double>) + VISIT_DCHECK(InExprDexBase<gandiva::DecimalScalar128>) + VISIT_DCHECK(InExprDexBase<std::string>) +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/engine.cc b/src/arrow/cpp/src/gandiva/engine.cc new file mode 100644 index 000000000..f0b768f5f --- /dev/null +++ b/src/arrow/cpp/src/gandiva/engine.cc @@ -0,0 +1,338 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
// TODO(wesm): LLVM 7 produces pesky C4244 that disable pragmas around the LLVM
// includes seem to not fix as with LLVM 6
#if defined(_MSC_VER)
#pragma warning(disable : 4244)
#endif

#include "gandiva/engine.h"

#include <iostream>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>

#include "arrow/util/logging.h"

// Suppress a set of MSVC warnings for the duration of the LLVM includes only;
// the matching pop follows the include list.
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4141)
#pragma warning(disable : 4146)
#pragma warning(disable : 4244)
#pragma warning(disable : 4267)
#pragma warning(disable : 4624)
#endif

#include <llvm/Analysis/Passes.h>
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/Bitcode/BitcodeReader.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/DataLayout.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Linker/Linker.h>
#include <llvm/MC/SubtargetFeature.h>
#include <llvm/Support/DynamicLibrary.h>
#include <llvm/Support/Host.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/Utils.h>
#include <llvm/Transforms/Vectorize.h>

#if defined(_MSC_VER)
#pragma warning(pop)
#endif

#include "arrow/util/make_unique.h"
#include "gandiva/configuration.h"
#include "gandiva/decimal_ir.h"
#include "gandiva/exported_funcs_registry.h"

namespace gandiva {

// Pre-compiled bitcode blob and its size; only declared here, the definitions
// live in another translation unit.
extern const unsigned char kPrecompiledBitcode[];
extern const size_t kPrecompiledBitcodeSize;

// Guards the one-time initialisation performed by Engine::InitOnce().
std::once_flag llvm_init_once_flag;
static bool llvm_init = false;
// Host CPU name and "+feature"/"-feature" attribute strings, filled in by
// Engine::InitOnce() and consumed by Engine::Make().
static llvm::StringRef cpu_name;
static llvm::SmallVector<std::string, 10> cpu_attrs;

/// One-time process-wide setup: initialises the native LLVM target and records
/// the host CPU name and feature list in the file-level statics above.
/// Invoked through std::call_once from Engine::Make().
void Engine::InitOnce() {
  DCHECK_EQ(llvm_init, false);

  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  llvm::InitializeNativeTargetAsmParser();
  llvm::InitializeNativeTargetDisassembler();
  llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);

  cpu_name = llvm::sys::getHostCPUName();
  llvm::StringMap<bool> host_features;
  std::string cpu_attrs_str;
  if (llvm::sys::getHostCPUFeatures(host_features)) {
    for (auto& f : host_features) {
      // Enabled features are prefixed with '+', disabled ones with '-'.
      std::string attr = f.second ? std::string("+") + f.first().str()
                                  : std::string("-") + f.first().str();
      cpu_attrs.push_back(attr);
      cpu_attrs_str += " " + attr;
    }
  }
  ARROW_LOG(INFO) << "Detected CPU Name : " << cpu_name.str();
  ARROW_LOG(INFO) << "Detected CPU Features:" << cpu_attrs_str;
  llvm_init = true;
}

/// Takes ownership of the LLVM context and execution engine. `module` is a
/// non-owning pointer to the module that Make() moved into the engine builder.
/// The IR builder and the type helper are both created on top of `ctx`.
Engine::Engine(const std::shared_ptr<Configuration>& conf,
               std::unique_ptr<llvm::LLVMContext> ctx,
               std::unique_ptr<llvm::ExecutionEngine> engine, llvm::Module* module)
    : context_(std::move(ctx)),
      execution_engine_(std::move(engine)),
      ir_builder_(arrow::internal::make_unique<llvm::IRBuilder<>>(*context_)),
      module_(module),
      types_(*context_),
      optimize_(conf->optimize()) {}

/// Post-construction initialisation: registers the exported C++ functions,
/// links in the pre-compiled IR and adds the decimal IR functions.
/// Returns the first non-OK status encountered.
Status Engine::Init() {
  // Add mappings for functions that can be accessed from LLVM/IR module.
  AddGlobalMappings();

  ARROW_RETURN_NOT_OK(LoadPreCompiledIR());
  ARROW_RETURN_NOT_OK(DecimalIR::AddFunctions(this));

  return Status::OK();
}

/// factory method to construct the engine.
///
/// \param[in] conf engine configuration (optimize flag, host-CPU targeting)
/// \param[out] out the fully initialised engine; only written on success
/// \return CodeGenError if the llvm::ExecutionEngine cannot be created, or any
///         error reported by Init()
Status Engine::Make(const std::shared_ptr<Configuration>& conf,
                    std::unique_ptr<Engine>* out) {
  std::call_once(llvm_init_once_flag, InitOnce);

  auto ctx = arrow::internal::make_unique<llvm::LLVMContext>();
  auto module = arrow::internal::make_unique<llvm::Module>("codegen", *ctx);

  // Capture before moving, ExecutionEngine does not allow retrieving the
  // original Module.
  auto module_ptr = module.get();

  auto opt_level =
      conf->optimize() ? llvm::CodeGenOpt::Aggressive : llvm::CodeGenOpt::None;

  // Note that the lifetime of the error string is not captured by the
  // ExecutionEngine but only for the lifetime of the builder. Found by
  // inspecting LLVM sources.
  std::string builder_error;

  llvm::EngineBuilder engine_builder(std::move(module));

  engine_builder.setEngineKind(llvm::EngineKind::JIT)
      .setOptLevel(opt_level)
      .setErrorStr(&builder_error);

  // Only target the detected host CPU/features when the configuration asks for it.
  if (conf->target_host_cpu()) {
    engine_builder.setMCPU(cpu_name);
    engine_builder.setMAttrs(cpu_attrs);
  }
  std::unique_ptr<llvm::ExecutionEngine> exec_engine{engine_builder.create()};

  if (exec_engine == nullptr) {
    return Status::CodeGenError("Could not instantiate llvm::ExecutionEngine: ",
                                builder_error);
  }

  // The constructor is private, so make_unique cannot be used here.
  std::unique_ptr<Engine> engine{
      new Engine(conf, std::move(ctx), std::move(exec_engine), module_ptr)};
  ARROW_RETURN_NOT_OK(engine->Init());
  *out = std::move(engine);
  return Status::OK();
}

// This method was modified from its original version for a part of MLIR
// Original source from
// https://github.com/llvm/llvm-project/blob/9f2ce5b915a505a5488a5cf91bb0a8efa9ddfff7/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
// The original copyright notice follows.

// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +static void SetDataLayout(llvm::Module* module) { + auto target_triple = llvm::sys::getDefaultTargetTriple(); + std::string error_message; + auto target = llvm::TargetRegistry::lookupTarget(target_triple, error_message); + if (!target) { + return; + } + + std::string cpu(llvm::sys::getHostCPUName()); + llvm::SubtargetFeatures features; + llvm::StringMap<bool> host_features; + + if (llvm::sys::getHostCPUFeatures(host_features)) { + for (auto& f : host_features) { + features.AddFeature(f.first(), f.second); + } + } + + std::unique_ptr<llvm::TargetMachine> machine( + target->createTargetMachine(target_triple, cpu, features.getString(), {}, {})); + + module->setDataLayout(machine->createDataLayout()); +} +// end of the mofified method from MLIR + +// Handling for pre-compiled IR libraries. +Status Engine::LoadPreCompiledIR() { + auto bitcode = llvm::StringRef(reinterpret_cast<const char*>(kPrecompiledBitcode), + kPrecompiledBitcodeSize); + + /// Read from file into memory buffer. + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer_or_error = + llvm::MemoryBuffer::getMemBuffer(bitcode, "precompiled", false); + + ARROW_RETURN_IF(!buffer_or_error, + Status::CodeGenError("Could not load module from IR: ", + buffer_or_error.getError().message())); + + std::unique_ptr<llvm::MemoryBuffer> buffer = move(buffer_or_error.get()); + + /// Parse the IR module. 
+ llvm::Expected<std::unique_ptr<llvm::Module>> module_or_error = + llvm::getOwningLazyBitcodeModule(move(buffer), *context()); + if (!module_or_error) { + // NOTE: llvm::handleAllErrors() fails linking with RTTI-disabled LLVM builds + // (ARROW-5148) + std::string str; + llvm::raw_string_ostream stream(str); + stream << module_or_error.takeError(); + return Status::CodeGenError(stream.str()); + } + std::unique_ptr<llvm::Module> ir_module = move(module_or_error.get()); + + // set dataLayout + SetDataLayout(ir_module.get()); + + ARROW_RETURN_IF(llvm::verifyModule(*ir_module, &llvm::errs()), + Status::CodeGenError("verify of IR Module failed")); + ARROW_RETURN_IF(llvm::Linker::linkModules(*module_, move(ir_module)), + Status::CodeGenError("failed to link IR Modules")); + + return Status::OK(); +} + +// Get rid of all functions that don't need to be compiled. +// This helps in reducing the overall compilation time. This pass is trivial, +// and is always done since the number of functions in gandiva is very high. +// (Adapted from Apache Impala) +// +// Done by marking all the unused functions as internal, and then, running +// a pass for dead code elimination. +Status Engine::RemoveUnusedFunctions() { + // Setup an optimiser pipeline + std::unique_ptr<llvm::legacy::PassManager> pass_manager( + new llvm::legacy::PassManager()); + + std::unordered_set<std::string> used_functions; + used_functions.insert(functions_to_compile_.begin(), functions_to_compile_.end()); + + pass_manager->add( + llvm::createInternalizePass([&used_functions](const llvm::GlobalValue& func) { + return (used_functions.find(func.getName().str()) != used_functions.end()); + })); + pass_manager->add(llvm::createGlobalDCEPass()); + pass_manager->run(*module_); + return Status::OK(); +} + +// Optimise and compile the module. +Status Engine::FinalizeModule() { + ARROW_RETURN_NOT_OK(RemoveUnusedFunctions()); + + if (optimize_) { + // misc passes to allow for inlining, vectorization, .. 
+ std::unique_ptr<llvm::legacy::PassManager> pass_manager( + new llvm::legacy::PassManager()); + + llvm::TargetIRAnalysis target_analysis = + execution_engine_->getTargetMachine()->getTargetIRAnalysis(); + pass_manager->add(llvm::createTargetTransformInfoWrapperPass(target_analysis)); + pass_manager->add(llvm::createFunctionInliningPass()); + pass_manager->add(llvm::createInstructionCombiningPass()); + pass_manager->add(llvm::createPromoteMemoryToRegisterPass()); + pass_manager->add(llvm::createGVNPass()); + pass_manager->add(llvm::createNewGVNPass()); + pass_manager->add(llvm::createCFGSimplificationPass()); + pass_manager->add(llvm::createLoopVectorizePass()); + pass_manager->add(llvm::createSLPVectorizerPass()); + pass_manager->add(llvm::createGlobalOptimizerPass()); + + // run the optimiser + llvm::PassManagerBuilder pass_builder; + pass_builder.OptLevel = 3; + pass_builder.populateModulePassManager(*pass_manager); + pass_manager->run(*module_); + } + + ARROW_RETURN_IF(llvm::verifyModule(*module_, &llvm::errs()), + Status::CodeGenError("Module verification failed after optimizer")); + + // do the compilation + execution_engine_->finalizeObject(); + module_finalized_ = true; + + return Status::OK(); +} + +void* Engine::CompiledFunction(llvm::Function* irFunction) { + DCHECK(module_finalized_); + return execution_engine_->getPointerToFunction(irFunction); +} + +void Engine::AddGlobalMappingForFunc(const std::string& name, llvm::Type* ret_type, + const std::vector<llvm::Type*>& args, + void* function_ptr) { + constexpr bool is_var_arg = false; + auto prototype = llvm::FunctionType::get(ret_type, args, is_var_arg); + constexpr auto linkage = llvm::GlobalValue::ExternalLinkage; + auto fn = llvm::Function::Create(prototype, linkage, name, module()); + execution_engine_->addGlobalMapping(fn, function_ptr); +} + +void Engine::AddGlobalMappings() { ExportedFuncsRegistry::AddMappings(this); } + +std::string Engine::DumpIR() { + std::string ir; + llvm::raw_string_ostream 
stream(ir); + module_->print(stream, nullptr); + return ir; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/engine.h b/src/arrow/cpp/src/gandiva/engine.h new file mode 100644 index 000000000..d26b8aa0e --- /dev/null +++ b/src/arrow/cpp/src/gandiva/engine.h @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <set> +#include <string> +#include <vector> + +#include "arrow/util/macros.h" + +#include "arrow/util/logging.h" +#include "gandiva/configuration.h" +#include "gandiva/llvm_includes.h" +#include "gandiva/llvm_types.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// \brief LLVM Execution engine wrapper. +class GANDIVA_EXPORT Engine { + public: + llvm::LLVMContext* context() { return context_.get(); } + llvm::IRBuilder<>* ir_builder() { return ir_builder_.get(); } + LLVMTypes* types() { return &types_; } + llvm::Module* module() { return module_; } + + /// Factory method to create and initialize the engine object. 
+ /// + /// \param[in] config the engine configuration + /// \param[out] engine the created engine + static Status Make(const std::shared_ptr<Configuration>& config, + std::unique_ptr<Engine>* engine); + + /// Add the function to the list of IR functions that need to be compiled. + /// Compiling only the functions that are used by the module saves time. + void AddFunctionToCompile(const std::string& fname) { + DCHECK(!module_finalized_); + functions_to_compile_.push_back(fname); + } + + /// Optimise and compile the module. + Status FinalizeModule(); + + /// Get the compiled function corresponding to the irfunction. + void* CompiledFunction(llvm::Function* irFunction); + + // Create and add a mapping for the cpp function to make it accessible from LLVM. + void AddGlobalMappingForFunc(const std::string& name, llvm::Type* ret_type, + const std::vector<llvm::Type*>& args, void* func); + + /// Return the generated IR for the module. + std::string DumpIR(); + + private: + Engine(const std::shared_ptr<Configuration>& conf, + std::unique_ptr<llvm::LLVMContext> ctx, + std::unique_ptr<llvm::ExecutionEngine> engine, llvm::Module* module); + + // Post construction init. This _must_ be called after the constructor. + Status Init(); + + static void InitOnce(); + + llvm::ExecutionEngine& execution_engine() { return *execution_engine_; } + + /// load pre-compiled IR modules from precompiled_bitcode.cc and merge them into + /// the main module. + Status LoadPreCompiledIR(); + + // Create and add mappings for cpp functions that can be accessed from LLVM. + void AddGlobalMappings(); + + // Remove unused functions to reduce compile time. 
+ Status RemoveUnusedFunctions(); + + std::unique_ptr<llvm::LLVMContext> context_; + std::unique_ptr<llvm::ExecutionEngine> execution_engine_; + std::unique_ptr<llvm::IRBuilder<>> ir_builder_; + llvm::Module* module_; + LLVMTypes types_; + + std::vector<std::string> functions_to_compile_; + + bool optimize_ = true; + bool module_finalized_ = false; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/engine_llvm_test.cc b/src/arrow/cpp/src/gandiva/engine_llvm_test.cc new file mode 100644 index 000000000..ef2275b34 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/engine_llvm_test.cc @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/engine.h" + +#include <gtest/gtest.h> +#include <functional> +#include "gandiva/llvm_types.h" +#include "gandiva/tests/test_util.h" + +namespace gandiva { + +typedef int64_t (*add_vector_func_t)(int64_t* data, int n); + +class TestEngine : public ::testing::Test { + protected: + llvm::Function* BuildVecAdd(Engine* engine) { + auto types = engine->types(); + llvm::IRBuilder<>* builder = engine->ir_builder(); + llvm::LLVMContext* context = engine->context(); + + // Create fn prototype : + // int64_t add_longs(int64_t *elements, int32_t nelements) + std::vector<llvm::Type*> arguments; + arguments.push_back(types->i64_ptr_type()); + arguments.push_back(types->i32_type()); + llvm::FunctionType* prototype = + llvm::FunctionType::get(types->i64_type(), arguments, false /*isVarArg*/); + + // Create fn + std::string func_name = "add_longs"; + engine->AddFunctionToCompile(func_name); + llvm::Function* fn = llvm::Function::Create( + prototype, llvm::GlobalValue::ExternalLinkage, func_name, engine->module()); + assert(fn != nullptr); + + // Name the arguments + llvm::Function::arg_iterator args = fn->arg_begin(); + llvm::Value* arg_elements = &*args; + arg_elements->setName("elements"); + ++args; + llvm::Value* arg_nelements = &*args; + arg_nelements->setName("nelements"); + ++args; + + llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(*context, "entry", fn); + llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(*context, "loop", fn); + llvm::BasicBlock* loop_exit = llvm::BasicBlock::Create(*context, "exit", fn); + + // Loop entry + builder->SetInsertPoint(loop_entry); + builder->CreateBr(loop_body); + + // Loop body + builder->SetInsertPoint(loop_body); + + llvm::PHINode* loop_var = builder->CreatePHI(types->i32_type(), 2, "loop_var"); + llvm::PHINode* sum = builder->CreatePHI(types->i64_type(), 2, "sum"); + + loop_var->addIncoming(types->i32_constant(0), loop_entry); + sum->addIncoming(types->i64_constant(0), loop_entry); + + // setup loop PHI + 
llvm::Value* loop_update = + builder->CreateAdd(loop_var, types->i32_constant(1), "loop_var+1"); + loop_var->addIncoming(loop_update, loop_body); + + // get the current value + llvm::Value* offset = CreateGEP(builder, arg_elements, loop_var, "offset"); + llvm::Value* current_value = CreateLoad(builder, offset, "value"); + + // setup sum PHI + llvm::Value* sum_update = builder->CreateAdd(sum, current_value, "sum+ith"); + sum->addIncoming(sum_update, loop_body); + + // check loop_var + llvm::Value* loop_var_check = + builder->CreateICmpSLT(loop_update, arg_nelements, "loop_var < nrec"); + builder->CreateCondBr(loop_var_check, loop_body, loop_exit); + + // Loop exit + builder->SetInsertPoint(loop_exit); + builder->CreateRet(sum_update); + return fn; + } + + void BuildEngine() { ASSERT_OK(Engine::Make(TestConfiguration(), &engine)); } + + std::unique_ptr<Engine> engine; + std::shared_ptr<Configuration> configuration = TestConfiguration(); +}; + +TEST_F(TestEngine, TestAddUnoptimised) { + configuration->set_optimize(false); + BuildEngine(); + + llvm::Function* ir_func = BuildVecAdd(engine.get()); + ASSERT_OK(engine->FinalizeModule()); + auto add_func = reinterpret_cast<add_vector_func_t>(engine->CompiledFunction(ir_func)); + + int64_t my_array[] = {1, 3, -5, 8, 10}; + EXPECT_EQ(add_func(my_array, 5), 17); +} + +TEST_F(TestEngine, TestAddOptimised) { + configuration->set_optimize(true); + BuildEngine(); + + llvm::Function* ir_func = BuildVecAdd(engine.get()); + ASSERT_OK(engine->FinalizeModule()); + auto add_func = reinterpret_cast<add_vector_func_t>(engine->CompiledFunction(ir_func)); + + int64_t my_array[] = {1, 3, -5, 8, 10}; + EXPECT_EQ(add_func(my_array, 5), 17); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/eval_batch.h b/src/arrow/cpp/src/gandiva/eval_batch.h new file mode 100644 index 000000000..25d9ab1d9 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/eval_batch.h @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under 
one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> + +#include "arrow/util/logging.h" + +#include "gandiva/arrow.h" +#include "gandiva/execution_context.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/local_bitmaps_holder.h" + +namespace gandiva { + +/// \brief The buffers corresponding to one batch of records, used for +/// expression evaluation. 
+class EvalBatch { + public: + explicit EvalBatch(int64_t num_records, int num_buffers, int num_local_bitmaps) + : num_records_(num_records), num_buffers_(num_buffers) { + if (num_buffers > 0) { + buffers_array_.reset(new uint8_t*[num_buffers]); + buffer_offsets_array_.reset(new int64_t[num_buffers]); + } + local_bitmaps_holder_.reset(new LocalBitMapsHolder(num_records, num_local_bitmaps)); + execution_context_.reset(new ExecutionContext()); + } + + int64_t num_records() const { return num_records_; } + + uint8_t** GetBufferArray() const { return buffers_array_.get(); } + + int64_t* GetBufferOffsetArray() const { return buffer_offsets_array_.get(); } + + int GetNumBuffers() const { return num_buffers_; } + + uint8_t* GetBuffer(int idx) const { + DCHECK(idx <= num_buffers_); + return (buffers_array_.get())[idx]; + } + + int64_t GetBufferOffset(int idx) const { + DCHECK(idx <= num_buffers_); + return (buffer_offsets_array_.get())[idx]; + } + + void SetBuffer(int idx, uint8_t* buffer, int64_t offset) { + DCHECK(idx <= num_buffers_); + (buffers_array_.get())[idx] = buffer; + (buffer_offsets_array_.get())[idx] = offset; + } + + int GetNumLocalBitMaps() const { return local_bitmaps_holder_->GetNumLocalBitMaps(); } + + int64_t GetLocalBitmapSize() const { + return local_bitmaps_holder_->GetLocalBitMapSize(); + } + + uint8_t* GetLocalBitMap(int idx) const { + DCHECK(idx <= GetNumLocalBitMaps()); + return local_bitmaps_holder_->GetLocalBitMap(idx); + } + + uint8_t** GetLocalBitMapArray() const { + return local_bitmaps_holder_->GetLocalBitMapArray(); + } + + ExecutionContext* GetExecutionContext() const { return execution_context_.get(); } + + private: + /// number of records in the current batch. + int64_t num_records_; + + // number of buffers. + int num_buffers_; + + /// An array of 'num_buffers_', each containing a buffer. The buffer + /// sizes depends on the data type, but all of them have the same + /// number of slots (equal to num_records_). 
+ std::unique_ptr<uint8_t*[]> buffers_array_; + + /// An array of 'num_buffers_', each containing the offset for + /// corresponding buffer. + std::unique_ptr<int64_t[]> buffer_offsets_array_; + + std::unique_ptr<LocalBitMapsHolder> local_bitmaps_holder_; + + std::unique_ptr<ExecutionContext> execution_context_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/execution_context.h b/src/arrow/cpp/src/gandiva/execution_context.h new file mode 100644 index 000000000..efa546874 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/execution_context.h @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <string> +#include "gandiva/simple_arena.h" + +namespace gandiva { + +/// Execution context during llvm evaluation +class ExecutionContext { + public: + explicit ExecutionContext(arrow::MemoryPool* pool = arrow::default_memory_pool()) + : arena_(pool) {} + std::string get_error() const { return error_msg_; } + + void set_error_msg(const char* error_msg) { + // Remember the first error only. 
+ if (error_msg_.empty()) { + error_msg_ = std::string(error_msg); + } + } + + bool has_error() const { return !error_msg_.empty(); } + + SimpleArena* arena() { return &arena_; } + + void Reset() { + error_msg_.clear(); + arena_.Reset(); + } + + private: + std::string error_msg_; + SimpleArena arena_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/exported_funcs.h b/src/arrow/cpp/src/gandiva/exported_funcs.h new file mode 100644 index 000000000..582052660 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/exported_funcs.h @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <gandiva/exported_funcs_registry.h> +#include <vector> + +namespace gandiva { + +class Engine; + +// Base-class type for exporting functions that can be accessed from LLVM/IR. 
+class ExportedFuncsBase { + public: + virtual ~ExportedFuncsBase() = default; + + virtual void AddMappings(Engine* engine) const = 0; +}; + +// Class for exporting Stub functions +class ExportedStubFunctions : public ExportedFuncsBase { + void AddMappings(Engine* engine) const override; +}; +REGISTER_EXPORTED_FUNCS(ExportedStubFunctions); + +// Class for exporting Context functions +class ExportedContextFunctions : public ExportedFuncsBase { + void AddMappings(Engine* engine) const override; +}; +REGISTER_EXPORTED_FUNCS(ExportedContextFunctions); + +// Class for exporting Time functions +class ExportedTimeFunctions : public ExportedFuncsBase { + void AddMappings(Engine* engine) const override; +}; +REGISTER_EXPORTED_FUNCS(ExportedTimeFunctions); + +// Class for exporting Decimal functions +class ExportedDecimalFunctions : public ExportedFuncsBase { + void AddMappings(Engine* engine) const override; +}; +REGISTER_EXPORTED_FUNCS(ExportedDecimalFunctions); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/exported_funcs_registry.cc b/src/arrow/cpp/src/gandiva/exported_funcs_registry.cc new file mode 100644 index 000000000..4c87c4d40 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/exported_funcs_registry.cc @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/exported_funcs_registry.h" + +#include "gandiva/exported_funcs.h" + +namespace gandiva { + +void ExportedFuncsRegistry::AddMappings(Engine* engine) { + for (auto entry : registered()) { + entry->AddMappings(engine); + } +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/exported_funcs_registry.h b/src/arrow/cpp/src/gandiva/exported_funcs_registry.h new file mode 100644 index 000000000..1504f2130 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/exported_funcs_registry.h @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <vector> + +#include <gandiva/engine.h> + +namespace gandiva { + +class ExportedFuncsBase; + +/// Registry for classes that export functions which can be accessed by +/// LLVM/IR code. +class ExportedFuncsRegistry { + public: + using list_type = std::vector<std::shared_ptr<ExportedFuncsBase>>; + + // Add functions from all the registered classes to the engine. 
+ static void AddMappings(Engine* engine); + + static bool Register(std::shared_ptr<ExportedFuncsBase> entry) { + registered().push_back(entry); + return true; + } + + private: + static list_type& registered() { + static list_type registered_list; + return registered_list; + } +}; + +#define REGISTER_EXPORTED_FUNCS(classname) \ + static bool _registered_##classname = \ + ExportedFuncsRegistry::Register(std::make_shared<classname>()) + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/expr_decomposer.cc b/src/arrow/cpp/src/gandiva/expr_decomposer.cc new file mode 100644 index 000000000..1c09d28f5 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/expr_decomposer.cc @@ -0,0 +1,310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/expr_decomposer.h" + +#include <memory> +#include <stack> +#include <string> +#include <unordered_set> +#include <vector> + +#include "gandiva/annotator.h" +#include "gandiva/dex.h" +#include "gandiva/function_holder_registry.h" +#include "gandiva/function_registry.h" +#include "gandiva/function_signature.h" +#include "gandiva/in_holder.h" +#include "gandiva/node.h" + +namespace gandiva { + +// Decompose a field node - simply separate out validity & value arrays. 
+Status ExprDecomposer::Visit(const FieldNode& node) { + auto desc = annotator_.CheckAndAddInputFieldDescriptor(node.field()); + + DexPtr validity_dex = std::make_shared<VectorReadValidityDex>(desc); + DexPtr value_dex; + if (desc->HasOffsetsIdx()) { + value_dex = std::make_shared<VectorReadVarLenValueDex>(desc); + } else { + value_dex = std::make_shared<VectorReadFixedLenValueDex>(desc); + } + result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex); + return Status::OK(); +} + +// Try and optimize a function node, by substituting with cheaper alternatives. +// eg. replacing 'like' with 'starts_with' can save function calls at evaluation +// time. +const FunctionNode ExprDecomposer::TryOptimize(const FunctionNode& node) { + if (node.descriptor()->name() == "like") { + return LikeHolder::TryOptimize(node); + } else { + return node; + } +} + +// Decompose a field node - wherever possible, merge the validity vectors of the +// child nodes. +Status ExprDecomposer::Visit(const FunctionNode& in_node) { + auto node = TryOptimize(in_node); + auto desc = node.descriptor(); + FunctionSignature signature(desc->name(), desc->params(), desc->return_type()); + const NativeFunction* native_function = registry_.LookupSignature(signature); + DCHECK(native_function) << "Missing Signature " << signature.ToString(); + + // decompose the children. + std::vector<ValueValidityPairPtr> args; + for (auto& child : node.children()) { + auto status = child->Accept(*this); + ARROW_RETURN_NOT_OK(status); + + args.push_back(result()); + } + + // Make a function holder, if required. + std::shared_ptr<FunctionHolder> holder; + if (native_function->NeedsFunctionHolder()) { + auto status = FunctionHolderRegistry::Make(desc->name(), node, &holder); + ARROW_RETURN_NOT_OK(status); + } + + if (native_function->result_nullable_type() == kResultNullIfNull) { + // These functions are decomposable, merge the validity bits of the children. 
+ + std::vector<DexPtr> merged_validity; + for (auto& decomposed : args) { + // Merge the validity_expressions of the children to build a combined validity + // expression. + merged_validity.insert(merged_validity.end(), decomposed->validity_exprs().begin(), + decomposed->validity_exprs().end()); + } + + auto value_dex = + std::make_shared<NonNullableFuncDex>(desc, native_function, holder, args); + result_ = std::make_shared<ValueValidityPair>(merged_validity, value_dex); + } else if (native_function->result_nullable_type() == kResultNullNever) { + // These functions always output valid results. So, no validity dex. + auto value_dex = + std::make_shared<NullableNeverFuncDex>(desc, native_function, holder, args); + result_ = std::make_shared<ValueValidityPair>(value_dex); + } else { + DCHECK(native_function->result_nullable_type() == kResultNullInternal); + + // Add a local bitmap to track the output validity. + int local_bitmap_idx = annotator_.AddLocalBitMap(); + auto validity_dex = std::make_shared<LocalBitMapValidityDex>(local_bitmap_idx); + + auto value_dex = std::make_shared<NullableInternalFuncDex>( + desc, native_function, holder, args, local_bitmap_idx); + result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex); + } + return Status::OK(); +} + +// Decompose an IfNode +Status ExprDecomposer::Visit(const IfNode& node) { + // nested_if_else_ might get overwritten when visiting the condition-node, so + // saving the value to a local variable and resetting nested_if_else_ to false + bool svd_nested_if_else = nested_if_else_; + nested_if_else_ = false; + + PushConditionEntry(node); + auto status = node.condition()->Accept(*this); + ARROW_RETURN_NOT_OK(status); + auto condition_vv = result(); + PopConditionEntry(node); + + // Add a local bitmap to track the output validity. 
+ int local_bitmap_idx = PushThenEntry(node, svd_nested_if_else); + status = node.then_node()->Accept(*this); + ARROW_RETURN_NOT_OK(status); + auto then_vv = result(); + PopThenEntry(node); + + PushElseEntry(node, local_bitmap_idx); + nested_if_else_ = (dynamic_cast<IfNode*>(node.else_node().get()) != nullptr); + + status = node.else_node()->Accept(*this); + ARROW_RETURN_NOT_OK(status); + auto else_vv = result(); + bool is_terminal_else = PopElseEntry(node); + + auto validity_dex = std::make_shared<LocalBitMapValidityDex>(local_bitmap_idx); + auto value_dex = + std::make_shared<IfDex>(condition_vv, then_vv, else_vv, node.return_type(), + local_bitmap_idx, is_terminal_else); + + result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex); + return Status::OK(); +} + +// Decompose a BooleanNode +Status ExprDecomposer::Visit(const BooleanNode& node) { + // decompose the children. + std::vector<ValueValidityPairPtr> args; + for (auto& child : node.children()) { + auto status = child->Accept(*this); + ARROW_RETURN_NOT_OK(status); + + args.push_back(result()); + } + + // Add a local bitmap to track the output validity. + int local_bitmap_idx = annotator_.AddLocalBitMap(); + auto validity_dex = std::make_shared<LocalBitMapValidityDex>(local_bitmap_idx); + + std::shared_ptr<BooleanDex> value_dex; + switch (node.expr_type()) { + case BooleanNode::AND: + value_dex = std::make_shared<BooleanAndDex>(args, local_bitmap_idx); + break; + case BooleanNode::OR: + value_dex = std::make_shared<BooleanOrDex>(args, local_bitmap_idx); + break; + } + result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex); + return Status::OK(); +} +Status ExprDecomposer::Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) { + /* decompose the children. 
*/ + std::vector<ValueValidityPairPtr> args; + auto status = node.eval_expr()->Accept(*this); + ARROW_RETURN_NOT_OK(status); + args.push_back(result()); + /* In always outputs valid results, so no validity dex */ + auto value_dex = std::make_shared<InExprDex<gandiva::DecimalScalar128>>( + args, node.values(), node.get_precision(), node.get_scale()); + result_ = std::make_shared<ValueValidityPair>(value_dex); + return Status::OK(); +} + +#define MAKE_VISIT_IN(ctype) \ + Status ExprDecomposer::Visit(const InExpressionNode<ctype>& node) { \ + /* decompose the children. */ \ + std::vector<ValueValidityPairPtr> args; \ + auto status = node.eval_expr()->Accept(*this); \ + ARROW_RETURN_NOT_OK(status); \ + args.push_back(result()); \ + /* In always outputs valid results, so no validity dex */ \ + auto value_dex = std::make_shared<InExprDex<ctype>>(args, node.values()); \ + result_ = std::make_shared<ValueValidityPair>(value_dex); \ + return Status::OK(); \ + } + +MAKE_VISIT_IN(int32_t); +MAKE_VISIT_IN(int64_t); +MAKE_VISIT_IN(float); +MAKE_VISIT_IN(double); +MAKE_VISIT_IN(std::string); + +Status ExprDecomposer::Visit(const LiteralNode& node) { + auto value_dex = std::make_shared<LiteralDex>(node.return_type(), node.holder()); + DexPtr validity_dex; + if (node.is_null()) { + validity_dex = std::make_shared<FalseDex>(); + } else { + validity_dex = std::make_shared<TrueDex>(); + } + result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex); + return Status::OK(); +} + +// The bolow functions use a stack to detect : +// a. nested if-else expressions. +// In such cases, the local bitmap can be re-used. +// b. detect terminal else expressions +// The non-terminal else expressions do not need to track validity (the if statement +// that has a match will do it). +// Both of the above optimisations save CPU cycles during expression evaluation. 
+
+// Pushes a 'then' entry for `node` and returns the local bitmap index it uses.
+// With reuse_bitmap set, the index is taken from the 'else' entry on top of the stack
+// and that entry is marked non-terminal; otherwise a new bitmap is allocated.
+int ExprDecomposer::PushThenEntry(const IfNode& node, bool reuse_bitmap) {
+  int local_bitmap_idx;
+
+  if (reuse_bitmap) {
+    // we also need stack in addition to reuse_bitmap flag since we
+    // can also enter other if-else nodes when we visit the condition-node
+    // (which themselves might be nested) before we visit then-node
+    DCHECK_EQ(if_entries_stack_.empty(), false) << "PushThenEntry: stack is empty";
+    DCHECK_EQ(if_entries_stack_.top()->entry_type_, kStackEntryElse)
+        << "PushThenEntry: top of stack is not of type entry_else";
+    auto top = if_entries_stack_.top().get();
+
+    // inside a nested else statement (i.e if-else-if). use the parent's bitmap.
+    local_bitmap_idx = top->local_bitmap_idx_;
+
+    // clear the is_terminal bit in the current top entry (else).
+    top->is_terminal_else_ = false;
+  } else {
+    // alloc a new bitmap.
+    local_bitmap_idx = annotator_.AddLocalBitMap();
+  }
+
+  // push new entry to the stack.
+  std::unique_ptr<IfStackEntry> entry(new IfStackEntry(
+      node, kStackEntryThen, false /*is_terminal_else*/, local_bitmap_idx));
+  if_entries_stack_.emplace(std::move(entry));
+  return local_bitmap_idx;
+}
+
+// Pops the top entry, which is expected to be the 'then' entry of `node`.
+void ExprDecomposer::PopThenEntry(const IfNode& node) {
+  DCHECK_EQ(if_entries_stack_.empty(), false) << "PopThenEntry: found empty stack";
+
+  auto top = if_entries_stack_.top().get();
+  DCHECK_EQ(top->entry_type_, kStackEntryThen)
+      << "PopThenEntry: found " << top->entry_type_ << " expected then";
+  DCHECK_EQ(&top->if_node_, &node) << "PopThenEntry: found mismatched node";
+
+  if_entries_stack_.pop();
+}
+
+// Pushes an 'else' entry for `node`. The entry starts out as terminal; a nested
+// PushThenEntry with reuse_bitmap set clears that flag.
+void ExprDecomposer::PushElseEntry(const IfNode& node, int local_bitmap_idx) {
+  std::unique_ptr<IfStackEntry> entry(new IfStackEntry(
+      node, kStackEntryElse, true /*is_terminal_else*/, local_bitmap_idx));
+  if_entries_stack_.emplace(std::move(entry));
+}
+
+// Pops the top entry, which is expected to be the 'else' entry of `node`, and returns
+// its is_terminal_else_ flag as it stood at the time of the pop.
+bool ExprDecomposer::PopElseEntry(const IfNode& node) {
+  DCHECK_EQ(if_entries_stack_.empty(), false) << "PopElseEntry: found empty stack";
+
+  auto top = if_entries_stack_.top().get();
+  DCHECK_EQ(top->entry_type_, kStackEntryElse)
+      << "PopElseEntry: found " << top->entry_type_ << " expected else";
+  DCHECK_EQ(&top->if_node_, &node) << "PopElseEntry: found mismatched node";
+  bool is_terminal_else = top->is_terminal_else_;
+
+  if_entries_stack_.pop();
+  return is_terminal_else;
+}
+
+// Pushes a 'condition' entry for `node` (default is_terminal_else / bitmap index).
+void ExprDecomposer::PushConditionEntry(const IfNode& node) {
+  std::unique_ptr<IfStackEntry> entry(new IfStackEntry(node, kStackEntryCondition));
+  if_entries_stack_.emplace(std::move(entry));
+}
+
+// Pops the top entry, which is expected to be the 'condition' entry of `node`.
+void ExprDecomposer::PopConditionEntry(const IfNode& node) {
+  DCHECK_EQ(if_entries_stack_.empty(), false) << "PopConditionEntry: found empty stack";
+
+  auto top = if_entries_stack_.top().get();
+  DCHECK_EQ(top->entry_type_, kStackEntryCondition)
+      << "PopConditionEntry: found " << top->entry_type_ << " expected condition";
+  DCHECK_EQ(&top->if_node_, &node) << "PopConditionEntry: found mismatched node";
+  if_entries_stack_.pop();
+}
+
+}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expr_decomposer.h b/src/arrow/cpp/src/gandiva/expr_decomposer.h
new file mode 100644
index 000000000..f68b8a8fc
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expr_decomposer.h
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cmath>
+#include <memory>
+#include <stack>
+#include <string>
+#include <utility>
+
+#include "gandiva/arrow.h"
+#include "gandiva/expression.h"
+#include "gandiva/node.h"
+#include "gandiva/node_visitor.h"
+#include "gandiva/visibility.h"
+
+namespace gandiva {
+
+class FunctionRegistry;
+class Annotator;
+
+/// \brief Decomposes an expression tree to separate out the validity and
+/// value expressions.
+class GANDIVA_EXPORT ExprDecomposer : public NodeVisitor {
+ public:
+  /// Both arguments are stored as references, so they must outlive the decomposer.
+  explicit ExprDecomposer(const FunctionRegistry& registry, Annotator& annotator)
+      : registry_(registry), annotator_(annotator), nested_if_else_(false) {}
+
+  /// \brief Decompose the tree rooted at `root`.
+  /// On success the resulting value/validity pair is moved into `*out`; on failure
+  /// `*out` is left untouched and the failing status is returned.
+  Status Decompose(const Node& root, ValueValidityPairPtr* out) {
+    auto status = root.Accept(*this);
+    if (status.ok()) {
+      *out = std::move(result_);
+    }
+    return status;
+  }
+
+ private:
+  ARROW_DISALLOW_COPY_AND_ASSIGN(ExprDecomposer);
+
+  // The unit tests drive the private push/pop helpers and inspect the stack directly.
+  FRIEND_TEST(TestExprDecomposer, TestStackSimple);
+  FRIEND_TEST(TestExprDecomposer, TestNested);
+  FRIEND_TEST(TestExprDecomposer, TestInternalIf);
+  FRIEND_TEST(TestExprDecomposer, TestParallelIf);
+  FRIEND_TEST(TestExprDecomposer, TestIfInCondition);
+  FRIEND_TEST(TestExprDecomposer, TestFunctionBetweenNestedIf);
+  FRIEND_TEST(TestExprDecomposer, TestComplexIfCondition);
+
+  Status Visit(const FieldNode& node) override;
+  Status Visit(const FunctionNode& node) override;
+  Status Visit(const IfNode& node) override;
+  Status Visit(const LiteralNode& node) override;
+  Status Visit(const BooleanNode& node) override;
+  Status Visit(const InExpressionNode<int32_t>& node) override;
+  Status Visit(const InExpressionNode<int64_t>& node) override;
+  Status Visit(const InExpressionNode<float>& node) override;
+  Status Visit(const InExpressionNode<double>& node) override;
+  Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) override;
+  Status Visit(const InExpressionNode<std::string>& node) override;
+
+  // Optimize a function node, if possible.
+  const FunctionNode TryOptimize(const FunctionNode& node);
+
+  // Which part of an if-node a stack entry stands for.
+  enum StackEntryType { kStackEntryCondition, kStackEntryThen, kStackEntryElse };
+
+  // stack of if nodes.
+  // Each entry records the if-node, the part of it being visited, whether its else is
+  // still terminal, and the local bitmap index in use.
+  class IfStackEntry {
+   public:
+    IfStackEntry(const IfNode& if_node, StackEntryType entry_type,
+                 bool is_terminal_else = false, int local_bitmap_idx = 0)
+        : if_node_(if_node),
+          entry_type_(entry_type),
+          is_terminal_else_(is_terminal_else),
+          local_bitmap_idx_(local_bitmap_idx) {}
+
+    const IfNode& if_node_;
+    StackEntryType entry_type_;
+    bool is_terminal_else_;
+    int local_bitmap_idx_;
+
+   private:
+    ARROW_DISALLOW_COPY_AND_ASSIGN(IfStackEntry);
+  };
+
+  // push 'condition entry' into stack.
+  void PushConditionEntry(const IfNode& node);
+
+  // pop 'condition entry' from stack.
+  void PopConditionEntry(const IfNode& node);
+
+  // push 'then entry' to stack. returns either a new local bitmap or the parent's
+  // bitmap (in case of nested if-else).
+  int PushThenEntry(const IfNode& node, bool reuse_bitmap);
+
+  // pop 'then entry' from stack.
+  void PopThenEntry(const IfNode& node);
+
+  // push 'else entry' into stack.
+  void PushElseEntry(const IfNode& node, int local_bitmap_idx);
+
+  // pop 'else entry' from stack. returns 'true' if this is a terminal else condition
+  // i.e no nested if condition below this node.
+  bool PopElseEntry(const IfNode& node);
+
+  // Returns the most recent visit result by moving it out of result_.
+  ValueValidityPairPtr result() { return std::move(result_); }
+
+  const FunctionRegistry& registry_;
+  Annotator& annotator_;
+  std::stack<std::unique_ptr<IfStackEntry>> if_entries_stack_;
+  ValueValidityPairPtr result_;
+  // Set by Visit(IfNode) when the else-branch about to be visited is itself an IfNode.
+  bool nested_if_else_;
+};
+
+}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expr_decomposer_test.cc b/src/arrow/cpp/src/gandiva/expr_decomposer_test.cc
new file mode 100644
index 000000000..638ceebcb
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expr_decomposer_test.cc
@@ -0,0 +1,409 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/expr_decomposer.h"
+
+#include <gtest/gtest.h>
+
+#include "gandiva/annotator.h"
+#include "gandiva/dex.h"
+#include "gandiva/function_registry.h"
+#include "gandiva/gandiva_aliases.h"
+#include "gandiva/node.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::int32;
+
+// Fixture that supplies a default-constructed FunctionRegistry to every test.
+// The tests below call the decomposer's push/pop helpers by hand, in the order a
+// visit of the sketched expression would, and check bitmap indices and terminal flags.
+class TestExprDecomposer : public ::testing::Test {
+ protected:
+  FunctionRegistry registry_;
+};
+
+// A single if/else: gets bitmap 0 and its else is terminal.
+TEST_F(TestExprDecomposer, TestStackSimple) {
+  Annotator annotator;
+  ExprDecomposer decomposer(registry_, annotator);
+
+  // if (a) _
+  // else _
+  IfNode node_a(nullptr, nullptr, nullptr, int32());
+
+  decomposer.PushConditionEntry(node_a);
+  decomposer.PopConditionEntry(node_a);
+
+  int idx_a = decomposer.PushThenEntry(node_a, false);
+  EXPECT_EQ(idx_a, 0);
+  decomposer.PopThenEntry(node_a);
+
+  decomposer.PushElseEntry(node_a, idx_a);
+  bool is_terminal_a = decomposer.PopElseEntry(node_a);
+  EXPECT_EQ(is_terminal_a, true);
+  EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+// An if nested directly in an else: the inner if reuses the outer bitmap and the
+// outer else is no longer terminal.
+TEST_F(TestExprDecomposer, TestNested) {
+  Annotator annotator;
+  ExprDecomposer decomposer(registry_, annotator);
+
+  // if (a) _
+  // else _
+  //   if (b) _
+  //   else _
+  IfNode node_a(nullptr, nullptr, nullptr, int32());
+  IfNode node_b(nullptr, nullptr, nullptr, int32());
+
+  decomposer.PushConditionEntry(node_a);
+  decomposer.PopConditionEntry(node_a);
+
+  int idx_a = decomposer.PushThenEntry(node_a, false);
+  EXPECT_EQ(idx_a, 0);
+  decomposer.PopThenEntry(node_a);
+
+  decomposer.PushElseEntry(node_a, idx_a);
+
+  {  // start b
+    decomposer.PushConditionEntry(node_b);
+    decomposer.PopConditionEntry(node_b);
+
+    int idx_b = decomposer.PushThenEntry(node_b, true);
+    EXPECT_EQ(idx_b, 0);  // must reuse bitmap.
+    decomposer.PopThenEntry(node_b);
+
+    decomposer.PushElseEntry(node_b, idx_b);
+    bool is_terminal_b = decomposer.PopElseEntry(node_b);
+    EXPECT_EQ(is_terminal_b, true);
+  }  // end b
+
+  bool is_terminal_a = decomposer.PopElseEntry(node_a);
+  EXPECT_EQ(is_terminal_a, false);  // there was a nested if.
+
+  EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+// An if nested in a then-branch: the inner if gets its own bitmap and both elses
+// stay terminal.
+TEST_F(TestExprDecomposer, TestInternalIf) {
+  Annotator annotator;
+  ExprDecomposer decomposer(registry_, annotator);
+
+  // if (a) _
+  //   if (b) _
+  //   else _
+  // else _
+  IfNode node_a(nullptr, nullptr, nullptr, int32());
+  IfNode node_b(nullptr, nullptr, nullptr, int32());
+
+  decomposer.PushConditionEntry(node_a);
+  decomposer.PopConditionEntry(node_a);
+
+  int idx_a = decomposer.PushThenEntry(node_a, false);
+  EXPECT_EQ(idx_a, 0);
+
+  {  // start b
+    decomposer.PushConditionEntry(node_b);
+    decomposer.PopConditionEntry(node_b);
+
+    int idx_b = decomposer.PushThenEntry(node_b, false);
+    EXPECT_EQ(idx_b, 1);  // must not reuse bitmap.
+    decomposer.PopThenEntry(node_b);
+
+    decomposer.PushElseEntry(node_b, idx_b);
+    bool is_terminal_b = decomposer.PopElseEntry(node_b);
+    EXPECT_EQ(is_terminal_b, true);
+  }  // end b
+
+  decomposer.PopThenEntry(node_a);
+  decomposer.PushElseEntry(node_a, idx_a);
+
+  bool is_terminal_a = decomposer.PopElseEntry(node_a);
+  EXPECT_EQ(is_terminal_a, true);  // there was no nested if.
+
+  EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+// Two sibling ifs: each gets its own bitmap and each else is terminal.
+TEST_F(TestExprDecomposer, TestParallelIf) {
+  Annotator annotator;
+  ExprDecomposer decomposer(registry_, annotator);
+
+  // if (a) _
+  // else _
+  // if (b) _
+  // else _
+  IfNode node_a(nullptr, nullptr, nullptr, int32());
+  IfNode node_b(nullptr, nullptr, nullptr, int32());
+
+  decomposer.PushConditionEntry(node_a);
+  decomposer.PopConditionEntry(node_a);
+
+  int idx_a = decomposer.PushThenEntry(node_a, false);
+  EXPECT_EQ(idx_a, 0);
+
+  decomposer.PopThenEntry(node_a);
+  decomposer.PushElseEntry(node_a, idx_a);
+
+  bool is_terminal_a = decomposer.PopElseEntry(node_a);
+  EXPECT_EQ(is_terminal_a, true);  // there was no nested if.
+
+  // start b
+  decomposer.PushConditionEntry(node_b);
+  decomposer.PopConditionEntry(node_b);
+
+  int idx_b = decomposer.PushThenEntry(node_b, false);
+  EXPECT_EQ(idx_b, 1);  // must not reuse bitmap.
+  decomposer.PopThenEntry(node_b);
+
+  decomposer.PushElseEntry(node_b, idx_b);
+  bool is_terminal_b = decomposer.PopElseEntry(node_b);
+  EXPECT_EQ(is_terminal_b, true);
+
+  EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+// Ifs inside conditions get their own bitmaps, while the if nested in the outer else
+// still reuses the outer bitmap.
+TEST_F(TestExprDecomposer, TestIfInCondition) {
+  Annotator annotator;
+  ExprDecomposer decomposer(registry_, annotator);
+
+  // if (if _ else _)   : a
+  //   -
+  // else
+  //   if (if _ else _)  : b
+  //    -
+  //   else
+  //    -
+  IfNode node_a(nullptr, nullptr, nullptr, int32());
+  IfNode node_b(nullptr, nullptr, nullptr, int32());
+  IfNode cond_node_a(nullptr, nullptr, nullptr, int32());
+  IfNode cond_node_b(nullptr, nullptr, nullptr, int32());
+
+  // start a
+  decomposer.PushConditionEntry(node_a);
+  {
+    // start cond_node_a
+    decomposer.PushConditionEntry(cond_node_a);
+    decomposer.PopConditionEntry(cond_node_a);
+
+    int idx_cond_a = decomposer.PushThenEntry(cond_node_a, false);
+    EXPECT_EQ(idx_cond_a, 0);
+    decomposer.PopThenEntry(cond_node_a);
+
+    decomposer.PushElseEntry(cond_node_a, idx_cond_a);
+    bool is_terminal = decomposer.PopElseEntry(cond_node_a);
+    EXPECT_EQ(is_terminal, true);  // there was no nested if.
+  }
+  decomposer.PopConditionEntry(node_a);
+
+  int idx_a = decomposer.PushThenEntry(node_a, false);
+  EXPECT_EQ(idx_a, 1);  // no re-use
+  decomposer.PopThenEntry(node_a);
+
+  decomposer.PushElseEntry(node_a, idx_a);
+
+  {  // start b
+    decomposer.PushConditionEntry(node_b);
+    {
+      // start cond_node_b
+      decomposer.PushConditionEntry(cond_node_b);
+      decomposer.PopConditionEntry(cond_node_b);
+
+      int idx_cond_b = decomposer.PushThenEntry(cond_node_b, false);
+      EXPECT_EQ(idx_cond_b, 2);  // no re-use
+      decomposer.PopThenEntry(cond_node_b);
+
+      decomposer.PushElseEntry(cond_node_b, idx_cond_b);
+      bool is_terminal = decomposer.PopElseEntry(cond_node_b);
+      EXPECT_EQ(is_terminal, true);  // there was no nested if.
+    }
+    decomposer.PopConditionEntry(node_b);
+
+    int idx_b = decomposer.PushThenEntry(node_b, true);
+    EXPECT_EQ(idx_b, 1);  // must reuse bitmap.
+    decomposer.PopThenEntry(node_b);
+
+    decomposer.PushElseEntry(node_b, idx_b);
+    bool is_terminal = decomposer.PopElseEntry(node_b);
+    EXPECT_EQ(is_terminal, true);
+  }  // end b
+
+  bool is_terminal_a = decomposer.PopElseEntry(node_a);
+  EXPECT_EQ(is_terminal_a, false);  // there was a nested if.
+
+  EXPECT_EQ(decomposer.if_entries_stack_.empty(), true);
+}
+
+// An if that sits under a function inside the else-branch is not a direct else-if:
+// it gets a new bitmap and the outer else stays terminal.
+TEST_F(TestExprDecomposer, TestFunctionBetweenNestedIf) {
+  Annotator annotator;
+  ExprDecomposer decomposer(registry_, annotator);
+
+  // if (a) _
+  // else
+  //   function(
+  //     if (b) _
+  //     else _
+  //   )
+
+  IfNode node_a(nullptr, nullptr, nullptr, int32());
+  IfNode node_b(nullptr, nullptr, nullptr, int32());
+
+  // start outer if
+  decomposer.PushConditionEntry(node_a);
+  decomposer.PopConditionEntry(node_a);
+
+  int idx_a = decomposer.PushThenEntry(node_a, false);
+  EXPECT_EQ(idx_a, 0);
+  decomposer.PopThenEntry(node_a);
+
+  decomposer.PushElseEntry(node_a, idx_a);
+  {  // start b
+    decomposer.PushConditionEntry(node_b);
+    decomposer.PopConditionEntry(node_b);
+
+    int idx_b = decomposer.PushThenEntry(node_b, false);  // not else node of parent if
+    EXPECT_EQ(idx_b, 1);                                  // can't reuse bitmap.
+    decomposer.PopThenEntry(node_b);
+
+    decomposer.PushElseEntry(node_b, idx_b);
+    bool is_terminal_b = decomposer.PopElseEntry(node_b);
+    EXPECT_EQ(is_terminal_b, true);
+  }
+  bool is_terminal_a = decomposer.PopElseEntry(node_a);
+  EXPECT_EQ(is_terminal_a, true);  // a else is also terminal
+
+  EXPECT_TRUE(decomposer.if_entries_stack_.empty());
+}
+
+// if-else-if chains placed in the condition, the then-branch and the else-branch of
+// one outer if: each chain shares a single bitmap, and the chain in the outer else
+// shares the outer if's bitmap.
+TEST_F(TestExprDecomposer, TestComplexIfCondition) {
+  Annotator annotator;
+  ExprDecomposer decomposer(registry_, annotator);
+
+  // if (if _
+  //     else
+  //       if _
+  //       else _
+  //    )
+  // then
+  //    if _
+  //    else
+  //      if _
+  //      else _
+  //
+  // else
+  //    if _
+  //    else
+  //      if _
+  //      else _
+
+  IfNode node_a(nullptr, nullptr, nullptr, int32());
+
+  IfNode cond_node_a(nullptr, nullptr, nullptr, int32());
+  IfNode cond_node_a_inner_if(nullptr, nullptr, nullptr, int32());
+
+  IfNode then_node_a(nullptr, nullptr, nullptr, int32());
+  IfNode then_node_a_inner_if(nullptr, nullptr, nullptr, int32());
+
+  IfNode else_node_a(nullptr, nullptr, nullptr, int32());
+  IfNode else_node_a_inner_if(nullptr, nullptr, nullptr, int32());
+
+  // start outer if
+  decomposer.PushConditionEntry(node_a);
+  {
+    // start the nested if inside the condition of a
+    decomposer.PushConditionEntry(cond_node_a);
+    decomposer.PopConditionEntry(cond_node_a);
+
+    int idx_cond_a = decomposer.PushThenEntry(cond_node_a, false);
+    EXPECT_EQ(idx_cond_a, 0);
+    decomposer.PopThenEntry(cond_node_a);
+
+    decomposer.PushElseEntry(cond_node_a, idx_cond_a);
+    {
+      decomposer.PushConditionEntry(cond_node_a_inner_if);
+      decomposer.PopConditionEntry(cond_node_a_inner_if);
+
+      int idx_cond_a_inner_if = decomposer.PushThenEntry(cond_node_a_inner_if, true);
+      EXPECT_EQ(idx_cond_a_inner_if,
+                0);  // expect bitmap to be reused since nested if else
+      decomposer.PopThenEntry(cond_node_a_inner_if);
+
+      decomposer.PushElseEntry(cond_node_a_inner_if, idx_cond_a_inner_if);
+      bool is_terminal = decomposer.PopElseEntry(cond_node_a_inner_if);
+      EXPECT_TRUE(is_terminal);
+    }
+    EXPECT_FALSE(decomposer.PopElseEntry(cond_node_a));
+  }
+  decomposer.PopConditionEntry(node_a);
+
+  int idx_a = decomposer.PushThenEntry(node_a, false);
+  EXPECT_EQ(idx_a, 1);
+
+  {
+    // start the nested if inside the then node of a
+    decomposer.PushConditionEntry(then_node_a);
+    decomposer.PopConditionEntry(then_node_a);
+
+    int idx_then_a = decomposer.PushThenEntry(then_node_a, false);
+    EXPECT_EQ(idx_then_a, 2);
+    decomposer.PopThenEntry(then_node_a);
+
+    decomposer.PushElseEntry(then_node_a, idx_then_a);
+    {
+      decomposer.PushConditionEntry(then_node_a_inner_if);
+      decomposer.PopConditionEntry(then_node_a_inner_if);
+
+      int idx_then_a_inner_if = decomposer.PushThenEntry(then_node_a_inner_if, true);
+      EXPECT_EQ(idx_then_a_inner_if,
+                2);  // expect bitmap to be reused since nested if else
+      decomposer.PopThenEntry(then_node_a_inner_if);
+
+      decomposer.PushElseEntry(then_node_a_inner_if, idx_then_a_inner_if);
+      bool is_terminal = decomposer.PopElseEntry(then_node_a_inner_if);
+      EXPECT_TRUE(is_terminal);
+    }
+    EXPECT_FALSE(decomposer.PopElseEntry(then_node_a));
+  }
+  decomposer.PopThenEntry(node_a);
+
+  decomposer.PushElseEntry(node_a, idx_a);
+  {
+    // start the nested if inside the else node of a
+    decomposer.PushConditionEntry(else_node_a);
+    decomposer.PopConditionEntry(else_node_a);
+
+    int idx_else_a =
+        decomposer.PushThenEntry(else_node_a, true);  // else node is another if-node
+    EXPECT_EQ(idx_else_a, 1);  // reuse the outer if node bitmap since nested if-else
+    decomposer.PopThenEntry(else_node_a);
+
+    decomposer.PushElseEntry(else_node_a, idx_else_a);
+    {
+      decomposer.PushConditionEntry(else_node_a_inner_if);
+      decomposer.PopConditionEntry(else_node_a_inner_if);
+
+      int idx_else_a_inner_if = decomposer.PushThenEntry(else_node_a_inner_if, true);
+      EXPECT_EQ(idx_else_a_inner_if,
+                1);  // expect bitmap to be reused since nested if else
+      decomposer.PopThenEntry(else_node_a_inner_if);
+
+      decomposer.PushElseEntry(else_node_a_inner_if, idx_else_a_inner_if);
+      bool is_terminal = decomposer.PopElseEntry(else_node_a_inner_if);
+      EXPECT_TRUE(is_terminal);
+    }
+    EXPECT_FALSE(decomposer.PopElseEntry(else_node_a));
+  }
+  EXPECT_FALSE(decomposer.PopElseEntry(node_a));
+  EXPECT_TRUE(decomposer.if_entries_stack_.empty());
+}
+
+}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expr_validator.cc b/src/arrow/cpp/src/gandiva/expr_validator.cc
new file mode 100644
index 000000000..c3c784c95
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expr_validator.cc
@@ -0,0 +1,193 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <sstream> +#include <string> +#include <vector> + +#include "gandiva/expr_validator.h" + +namespace gandiva { + +Status ExprValidator::Validate(const ExpressionPtr& expr) { + ARROW_RETURN_IF(expr == nullptr, + Status::ExpressionValidationError("Expression cannot be null")); + + Node& root = *expr->root(); + ARROW_RETURN_NOT_OK(root.Accept(*this)); + + // Ensure root's return type match the expression return type. Type + // support validation is not required because root type is already supported. + ARROW_RETURN_IF(!root.return_type()->Equals(*expr->result()->type()), + Status::ExpressionValidationError("Return type of root node ", + root.return_type()->ToString(), + " does not match that of expression ", + expr->result()->type()->ToString())); + + return Status::OK(); +} + +Status ExprValidator::Visit(const FieldNode& node) { + auto llvm_type = types_->IRType(node.return_type()->id()); + ARROW_RETURN_IF(llvm_type == nullptr, + Status::ExpressionValidationError("Field ", node.field()->name(), + " has unsupported data type ", + node.return_type()->name())); + + // Ensure that field is found in schema + auto field_in_schema_entry = field_map_.find(node.field()->name()); + ARROW_RETURN_IF(field_in_schema_entry == field_map_.end(), + Status::ExpressionValidationError("Field ", node.field()->name(), + " not in schema.")); + + // Ensure that that the found field match. 
+ FieldPtr field_in_schema = field_in_schema_entry->second; + ARROW_RETURN_IF(!field_in_schema->Equals(node.field()), + Status::ExpressionValidationError( + "Field definition in schema ", field_in_schema->ToString(), + " different from field in expression ", node.field()->ToString())); + + return Status::OK(); +} + +Status ExprValidator::Visit(const FunctionNode& node) { + auto desc = node.descriptor(); + FunctionSignature signature(desc->name(), desc->params(), desc->return_type()); + + const NativeFunction* native_function = registry_.LookupSignature(signature); + ARROW_RETURN_IF(native_function == nullptr, + Status::ExpressionValidationError("Function ", signature.ToString(), + " not supported yet. ")); + + for (auto& child : node.children()) { + ARROW_RETURN_NOT_OK(child->Accept(*this)); + } + + return Status::OK(); +} + +Status ExprValidator::Visit(const IfNode& node) { + ARROW_RETURN_NOT_OK(node.condition()->Accept(*this)); + ARROW_RETURN_NOT_OK(node.then_node()->Accept(*this)); + ARROW_RETURN_NOT_OK(node.else_node()->Accept(*this)); + + auto if_node_ret_type = node.return_type(); + auto then_node_ret_type = node.then_node()->return_type(); + auto else_node_ret_type = node.else_node()->return_type(); + + // condition must be of boolean type. + ARROW_RETURN_IF( + !node.condition()->return_type()->Equals(arrow::boolean()), + Status::ExpressionValidationError("condition must be of boolean type, found type ", + node.condition()->return_type()->ToString())); + + // Then-branch return type must match. + ARROW_RETURN_IF(!if_node_ret_type->Equals(*then_node_ret_type), + Status::ExpressionValidationError( + "Return type of if ", if_node_ret_type->ToString(), " and then ", + then_node_ret_type->ToString(), " not matching.")); + + // Else-branch return type must match. 
+ ARROW_RETURN_IF(!if_node_ret_type->Equals(*else_node_ret_type), + Status::ExpressionValidationError( + "Return type of if ", if_node_ret_type->ToString(), " and else ", + else_node_ret_type->ToString(), " not matching.")); + + return Status::OK(); +} + +Status ExprValidator::Visit(const LiteralNode& node) { + auto llvm_type = types_->IRType(node.return_type()->id()); + ARROW_RETURN_IF(llvm_type == nullptr, + Status::ExpressionValidationError("Value ", ToString(node.holder()), + " has unsupported data type ", + node.return_type()->name())); + + return Status::OK(); +} + +Status ExprValidator::Visit(const BooleanNode& node) { + ARROW_RETURN_IF( + node.children().size() < 2, + Status::ExpressionValidationError("Boolean expression has ", node.children().size(), + " children, expected at least two")); + + for (auto& child : node.children()) { + const auto bool_type = arrow::boolean(); + const auto ret_type = child->return_type(); + + ARROW_RETURN_IF(!ret_type->Equals(bool_type), + Status::ExpressionValidationError( + "Boolean expression has a child with return type ", + ret_type->ToString(), ", expected return type boolean")); + + ARROW_RETURN_NOT_OK(child->Accept(*this)); + } + + return Status::OK(); +} + +/* + * Validate the following + * + * 1. Non empty list of constants to search in. + * 2. Expression returns of the same type as the constants. 
+ */ +Status ExprValidator::Visit(const InExpressionNode<int32_t>& node) { + return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(), + arrow::int32()); +} + +Status ExprValidator::Visit(const InExpressionNode<int64_t>& node) { + return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(), + arrow::int64()); +} +Status ExprValidator::Visit(const InExpressionNode<float>& node) { + return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(), + arrow::float32()); +} +Status ExprValidator::Visit(const InExpressionNode<double>& node) { + return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(), + arrow::float64()); +} + +Status ExprValidator::Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) { + return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(), + arrow::decimal(node.get_precision(), node.get_scale())); +} + +Status ExprValidator::Visit(const InExpressionNode<std::string>& node) { + return ValidateInExpression(node.values().size(), node.eval_expr()->return_type(), + arrow::utf8()); +} + +Status ExprValidator::ValidateInExpression(size_t number_of_values, + DataTypePtr in_expr_return_type, + DataTypePtr type_of_values) { + ARROW_RETURN_IF(number_of_values == 0, + Status::ExpressionValidationError( + "IN Expression needs a non-empty constant list to match.")); + ARROW_RETURN_IF( + !in_expr_return_type->Equals(type_of_values), + Status::ExpressionValidationError( + "Evaluation expression for IN clause returns ", in_expr_return_type->ToString(), + " values are of type", type_of_values->ToString())); + + return Status::OK(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/expr_validator.h b/src/arrow/cpp/src/gandiva/expr_validator.h new file mode 100644 index 000000000..daaf50897 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/expr_validator.h @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under 
one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+
+#include "arrow/status.h"
+
+#include "gandiva/arrow.h"
+#include "gandiva/expression.h"
+#include "gandiva/function_registry.h"
+#include "gandiva/llvm_types.h"
+#include "gandiva/node.h"
+#include "gandiva/node_visitor.h"
+
+namespace gandiva {
+
+class FunctionRegistry;
+
+/// \brief Validates the entire expression tree including
+/// data types, signatures and return types
+class ExprValidator : public NodeVisitor {
+ public:
+  /// Builds the name -> field lookup from `schema`. If two schema fields share a
+  /// name, the later one replaces the earlier one in the lookup.
+  explicit ExprValidator(LLVMTypes* types, SchemaPtr schema)
+      : types_(types), schema_(schema) {
+    for (auto& field : schema_->fields()) {
+      field_map_[field->name()] = field;
+    }
+  }
+
+  /// \brief Validates the root node
+  /// of an expression.
+  /// 1. Data type of fields and literals.
+  /// 2. Function signature is supported.
+  /// 3. For if nodes that return types match
+  ///    for if, then and else nodes.
+  Status Validate(const ExpressionPtr& expr);
+
+ private:
+  Status Visit(const FieldNode& node) override;
+  Status Visit(const FunctionNode& node) override;
+  Status Visit(const IfNode& node) override;
+  Status Visit(const LiteralNode& node) override;
+  Status Visit(const BooleanNode& node) override;
+  Status Visit(const InExpressionNode<int32_t>& node) override;
+  Status Visit(const InExpressionNode<int64_t>& node) override;
+  Status Visit(const InExpressionNode<float>& node) override;
+  Status Visit(const InExpressionNode<double>& node) override;
+  Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) override;
+  Status Visit(const InExpressionNode<std::string>& node) override;
+  // Common check used by all the InExpressionNode overloads above.
+  Status ValidateInExpression(size_t number_of_values, DataTypePtr in_expr_return_type,
+                              DataTypePtr type_of_values);
+
+  // Default-constructed registry used to look up function signatures.
+  FunctionRegistry registry_;
+
+  // Not owned; used to test whether an arrow type has an IR type.
+  LLVMTypes* types_;
+
+  SchemaPtr schema_;
+
+  // Schema fields keyed by field name.
+  using FieldMap = std::unordered_map<std::string, FieldPtr>;
+  FieldMap field_map_;
+};
+
+}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/expression.cc b/src/arrow/cpp/src/gandiva/expression.cc
new file mode 100644
index 000000000..06aada27b
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/expression.cc
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/expression.h" +#include "gandiva/node.h" + +namespace gandiva { + +std::string Expression::ToString() { return root()->ToString(); } + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/expression.h b/src/arrow/cpp/src/gandiva/expression.h new file mode 100644 index 000000000..cdda2512b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/expression.h @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <string> + +#include "gandiva/arrow.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// \brief An expression tree with a root node, and a result field. 
+class GANDIVA_EXPORT Expression { + public: + Expression(const NodePtr root, const FieldPtr result) : root_(root), result_(result) {} + + virtual ~Expression() = default; + + const NodePtr& root() const { return root_; } + + const FieldPtr& result() const { return result_; } + + std::string ToString(); + + private: + const NodePtr root_; + const FieldPtr result_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/expression_registry.cc b/src/arrow/cpp/src/gandiva/expression_registry.cc new file mode 100644 index 000000000..c3a08fd3a --- /dev/null +++ b/src/arrow/cpp/src/gandiva/expression_registry.cc @@ -0,0 +1,187 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/expression_registry.h" + +#include "gandiva/function_registry.h" +#include "gandiva/llvm_types.h" + +namespace gandiva { + +ExpressionRegistry::ExpressionRegistry() { + function_registry_.reset(new FunctionRegistry()); +} + +ExpressionRegistry::~ExpressionRegistry() {} + +// to be used only to create function_signature_start +ExpressionRegistry::FunctionSignatureIterator::FunctionSignatureIterator( + native_func_iterator_type nf_it, native_func_iterator_type nf_it_end) + : native_func_it_{nf_it}, + native_func_it_end_{nf_it_end}, + func_sig_it_{&(nf_it->signatures().front())} {} + +// to be used only to create function_signature_end +ExpressionRegistry::FunctionSignatureIterator::FunctionSignatureIterator( + func_sig_iterator_type fs_it) + : native_func_it_{nullptr}, native_func_it_end_{nullptr}, func_sig_it_{fs_it} {} + +const ExpressionRegistry::FunctionSignatureIterator +ExpressionRegistry::function_signature_begin() { + return FunctionSignatureIterator(function_registry_->begin(), + function_registry_->end()); +} + +const ExpressionRegistry::FunctionSignatureIterator +ExpressionRegistry::function_signature_end() const { + return FunctionSignatureIterator(&(*(function_registry_->back()->signatures().end()))); +} + +bool ExpressionRegistry::FunctionSignatureIterator::operator!=( + const FunctionSignatureIterator& func_sign_it) { + return func_sign_it.func_sig_it_ != this->func_sig_it_; +} + +FunctionSignature ExpressionRegistry::FunctionSignatureIterator::operator*() { + return *func_sig_it_; +} + +ExpressionRegistry::func_sig_iterator_type ExpressionRegistry::FunctionSignatureIterator:: +operator++(int increment) { + ++func_sig_it_; + // point func_sig_it_ to first signature of next nativefunction if func_sig_it_ is + // pointing to end + if (func_sig_it_ == &(*native_func_it_->signatures().end())) { + ++native_func_it_; + if (native_func_it_ == native_func_it_end_) { // last native function + return func_sig_it_; + } + func_sig_it_ = 
&(native_func_it_->signatures().front()); + } + return func_sig_it_; +} + +static void AddArrowTypesToVector(arrow::Type::type type, DataTypeVector& vector); + +static DataTypeVector InitSupportedTypes() { + DataTypeVector data_type_vector; + llvm::LLVMContext llvm_context; + LLVMTypes llvm_types(llvm_context); + auto supported_arrow_types = llvm_types.GetSupportedArrowTypes(); + for (auto& type_id : supported_arrow_types) { + AddArrowTypesToVector(type_id, data_type_vector); + } + return data_type_vector; +} + +DataTypeVector ExpressionRegistry::supported_types_ = InitSupportedTypes(); + +static void AddArrowTypesToVector(arrow::Type::type type, DataTypeVector& vector) { + switch (type) { + case arrow::Type::type::BOOL: + vector.push_back(arrow::boolean()); + break; + case arrow::Type::type::UINT8: + vector.push_back(arrow::uint8()); + break; + case arrow::Type::type::INT8: + vector.push_back(arrow::int8()); + break; + case arrow::Type::type::UINT16: + vector.push_back(arrow::uint16()); + break; + case arrow::Type::type::INT16: + vector.push_back(arrow::int16()); + break; + case arrow::Type::type::UINT32: + vector.push_back(arrow::uint32()); + break; + case arrow::Type::type::INT32: + vector.push_back(arrow::int32()); + break; + case arrow::Type::type::UINT64: + vector.push_back(arrow::uint64()); + break; + case arrow::Type::type::INT64: + vector.push_back(arrow::int64()); + break; + case arrow::Type::type::HALF_FLOAT: + vector.push_back(arrow::float16()); + break; + case arrow::Type::type::FLOAT: + vector.push_back(arrow::float32()); + break; + case arrow::Type::type::DOUBLE: + vector.push_back(arrow::float64()); + break; + case arrow::Type::type::STRING: + vector.push_back(arrow::utf8()); + break; + case arrow::Type::type::BINARY: + vector.push_back(arrow::binary()); + break; + case arrow::Type::type::DATE32: + vector.push_back(arrow::date32()); + break; + case arrow::Type::type::DATE64: + vector.push_back(arrow::date64()); + break; + case 
arrow::Type::type::TIMESTAMP: + vector.push_back(arrow::timestamp(arrow::TimeUnit::SECOND)); + vector.push_back(arrow::timestamp(arrow::TimeUnit::MILLI)); + vector.push_back(arrow::timestamp(arrow::TimeUnit::NANO)); + vector.push_back(arrow::timestamp(arrow::TimeUnit::MICRO)); + break; + case arrow::Type::type::TIME32: + vector.push_back(arrow::time32(arrow::TimeUnit::SECOND)); + vector.push_back(arrow::time32(arrow::TimeUnit::MILLI)); + break; + case arrow::Type::type::TIME64: + vector.push_back(arrow::time64(arrow::TimeUnit::MICRO)); + vector.push_back(arrow::time64(arrow::TimeUnit::NANO)); + break; + case arrow::Type::type::NA: + vector.push_back(arrow::null()); + break; + case arrow::Type::type::DECIMAL: + vector.push_back(arrow::decimal(38, 0)); + break; + case arrow::Type::type::INTERVAL_MONTHS: + vector.push_back(arrow::month_interval()); + break; + case arrow::Type::type::INTERVAL_DAY_TIME: + vector.push_back(arrow::day_time_interval()); + break; + default: + // Unsupported types. test ensures that + // when one of these are added build breaks. + DCHECK(false); + } +} + +std::vector<std::shared_ptr<FunctionSignature>> GetRegisteredFunctionSignatures() { + ExpressionRegistry registry; + std::vector<std::shared_ptr<FunctionSignature>> signatures; + for (auto iter = registry.function_signature_begin(); + iter != registry.function_signature_end(); iter++) { + signatures.push_back(std::make_shared<FunctionSignature>( + (*iter).base_name(), (*iter).param_types(), (*iter).ret_type())); + } + return signatures; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/expression_registry.h b/src/arrow/cpp/src/gandiva/expression_registry.h new file mode 100644 index 000000000..fb4f177ba --- /dev/null +++ b/src/arrow/cpp/src/gandiva/expression_registry.h @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <vector> + +#include "gandiva/arrow.h" +#include "gandiva/function_signature.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +class NativeFunction; +class FunctionRegistry; +/// \brief Exports types supported by Gandiva for processing. +/// +/// Has helper methods for clients to programmatically discover +/// data types and functions supported by Gandiva. 
+class GANDIVA_EXPORT ExpressionRegistry { + public: + using native_func_iterator_type = const NativeFunction*; + using func_sig_iterator_type = const FunctionSignature*; + ExpressionRegistry(); + ~ExpressionRegistry(); + static DataTypeVector supported_types() { return supported_types_; } + class GANDIVA_EXPORT FunctionSignatureIterator { + public: + explicit FunctionSignatureIterator(native_func_iterator_type nf_it, + native_func_iterator_type nf_it_end_); + explicit FunctionSignatureIterator(func_sig_iterator_type fs_it); + + bool operator!=(const FunctionSignatureIterator& func_sign_it); + + FunctionSignature operator*(); + + func_sig_iterator_type operator++(int); + + private: + native_func_iterator_type native_func_it_; + const native_func_iterator_type native_func_it_end_; + func_sig_iterator_type func_sig_it_; + }; + const FunctionSignatureIterator function_signature_begin(); + const FunctionSignatureIterator function_signature_end() const; + + private: + static DataTypeVector supported_types_; + std::unique_ptr<FunctionRegistry> function_registry_; +}; + +GANDIVA_EXPORT +std::vector<std::shared_ptr<FunctionSignature>> GetRegisteredFunctionSignatures(); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/expression_registry_test.cc b/src/arrow/cpp/src/gandiva/expression_registry_test.cc new file mode 100644 index 000000000..c254ff4f3 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/expression_registry_test.cc @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/expression_registry.h" + +#include <algorithm> +#include <vector> + +#include <gtest/gtest.h> +#include "gandiva/function_registry.h" +#include "gandiva/function_signature.h" +#include "gandiva/llvm_types.h" + +namespace gandiva { + +typedef int64_t (*add_vector_func_t)(int64_t* elements, int nelements); + +class TestExpressionRegistry : public ::testing::Test { + protected: + FunctionRegistry registry_; +}; + +// Verify all functions in registry are exported. +TEST_F(TestExpressionRegistry, VerifySupportedFunctions) { + std::vector<FunctionSignature> functions; + ExpressionRegistry expr_registry; + for (auto iter = expr_registry.function_signature_begin(); + iter != expr_registry.function_signature_end(); iter++) { + functions.push_back((*iter)); + } + for (auto& iter : registry_) { + for (auto& func_iter : iter.signatures()) { + auto element = std::find(functions.begin(), functions.end(), func_iter); + EXPECT_NE(element, functions.end()) << "function signature " << func_iter.ToString() + << " missing in supported functions.\n"; + } + } +} + +// Verify all types are supported. 
+TEST_F(TestExpressionRegistry, VerifyDataTypes) { + DataTypeVector data_types = ExpressionRegistry::supported_types(); + llvm::LLVMContext llvm_context; + LLVMTypes llvm_types(llvm_context); + auto supported_arrow_types = llvm_types.GetSupportedArrowTypes(); + for (auto& type_id : supported_arrow_types) { + auto element = + std::find(supported_arrow_types.begin(), supported_arrow_types.end(), type_id); + EXPECT_NE(element, supported_arrow_types.end()) + << "data type " << type_id << " missing in supported data types.\n"; + } +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/field_descriptor.h b/src/arrow/cpp/src/gandiva/field_descriptor.h new file mode 100644 index 000000000..0fe6fe37f --- /dev/null +++ b/src/arrow/cpp/src/gandiva/field_descriptor.h @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <string> + +#include "gandiva/arrow.h" + +namespace gandiva { + +/// \brief Descriptor for an arrow field. Holds indexes into the flattened array of +/// buffers that is passed to LLVM generated functions. 
+class FieldDescriptor { + public: + static const int kInvalidIdx = -1; + + FieldDescriptor(FieldPtr field, int data_idx, int validity_idx = kInvalidIdx, + int offsets_idx = kInvalidIdx, int data_buffer_ptr_idx = kInvalidIdx) + : field_(field), + data_idx_(data_idx), + validity_idx_(validity_idx), + offsets_idx_(offsets_idx), + data_buffer_ptr_idx_(data_buffer_ptr_idx) {} + + /// Index of validity array in the array-of-buffers + int validity_idx() const { return validity_idx_; } + + /// Index of data array in the array-of-buffers + int data_idx() const { return data_idx_; } + + /// Index of offsets array in the array-of-buffers + int offsets_idx() const { return offsets_idx_; } + + /// Index of data buffer pointer in the array-of-buffers + int data_buffer_ptr_idx() const { return data_buffer_ptr_idx_; } + + FieldPtr field() const { return field_; } + + const std::string& Name() const { return field_->name(); } + DataTypePtr Type() const { return field_->type(); } + + bool HasOffsetsIdx() const { return offsets_idx_ != kInvalidIdx; } + + bool HasDataBufferPtrIdx() const { return data_buffer_ptr_idx_ != kInvalidIdx; } + + private: + FieldPtr field_; + int data_idx_; + int validity_idx_; + int offsets_idx_; + int data_buffer_ptr_idx_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/filter.cc b/src/arrow/cpp/src/gandiva/filter.cc new file mode 100644 index 000000000..875cc5447 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/filter.cc @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/filter.h" + +#include <memory> +#include <thread> +#include <utility> +#include <vector> + +#include "arrow/util/hash_util.h" + +#include "gandiva/bitmap_accumulator.h" +#include "gandiva/cache.h" +#include "gandiva/condition.h" +#include "gandiva/expr_validator.h" +#include "gandiva/llvm_generator.h" +#include "gandiva/selection_vector_impl.h" + +namespace gandiva { + +FilterCacheKey::FilterCacheKey(SchemaPtr schema, + std::shared_ptr<Configuration> configuration, + Expression& expression) + : schema_(schema), configuration_(configuration), uniqifier_(0) { + static const int kSeedValue = 4; + size_t result = kSeedValue; + expression_as_string_ = expression.ToString(); + UpdateUniqifier(expression_as_string_); + arrow::internal::hash_combine(result, expression_as_string_); + arrow::internal::hash_combine(result, configuration); + arrow::internal::hash_combine(result, schema_->ToString()); + arrow::internal::hash_combine(result, uniqifier_); + hash_code_ = result; +} + +bool FilterCacheKey::operator==(const FilterCacheKey& other) const { + // arrow schema does not overload equality operators. 
+ if (!(schema_->Equals(*other.schema().get(), true))) { + return false; + } + + if (configuration_ != other.configuration_) { + return false; + } + + if (expression_as_string_ != other.expression_as_string_) { + return false; + } + + if (uniqifier_ != other.uniqifier_) { + return false; + } + return true; +} + +std::string FilterCacheKey::ToString() const { + std::stringstream ss; + // indent, window, indent_size, null_rep and skip new lines. + arrow::PrettyPrintOptions options{0, 10, 2, "null", true}; + DCHECK_OK(PrettyPrint(*schema_.get(), options, &ss)); + + ss << "Condition: [" << expression_as_string_ << "]"; + return ss.str(); +} + +void FilterCacheKey::UpdateUniqifier(const std::string& expr) { + // caching of expressions with re2 patterns causes lock contention. So, use + // multiple instances to reduce contention. + if (expr.find(" like(") != std::string::npos) { + uniqifier_ = std::hash<std::thread::id>()(std::this_thread::get_id()) % 16; + } +} + +Filter::Filter(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema, + std::shared_ptr<Configuration> configuration) + : llvm_generator_(std::move(llvm_generator)), + schema_(schema), + configuration_(configuration) {} + +Filter::~Filter() {} + +Status Filter::Make(SchemaPtr schema, ConditionPtr condition, + std::shared_ptr<Configuration> configuration, + std::shared_ptr<Filter>* filter) { + ARROW_RETURN_IF(schema == nullptr, Status::Invalid("Schema cannot be null")); + ARROW_RETURN_IF(condition == nullptr, Status::Invalid("Condition cannot be null")); + ARROW_RETURN_IF(configuration == nullptr, + Status::Invalid("Configuration cannot be null")); + + static Cache<FilterCacheKey, std::shared_ptr<Filter>> cache; + FilterCacheKey cache_key(schema, configuration, *(condition.get())); + auto cachedFilter = cache.GetModule(cache_key); + if (cachedFilter != nullptr) { + *filter = cachedFilter; + return Status::OK(); + } + + // Build LLVM generator, and generate code for the specified expression + 
std::unique_ptr<LLVMGenerator> llvm_gen; + ARROW_RETURN_NOT_OK(LLVMGenerator::Make(configuration, &llvm_gen)); + + // Run the validation on the expression. + // Return if the expression is invalid since we will not be able to process further. + ExprValidator expr_validator(llvm_gen->types(), schema); + ARROW_RETURN_NOT_OK(expr_validator.Validate(condition)); + + // Start measuring build time + auto begin = std::chrono::high_resolution_clock::now(); + ARROW_RETURN_NOT_OK(llvm_gen->Build({condition}, SelectionVector::Mode::MODE_NONE)); + // Stop measuring time and calculate the elapsed time + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count(); + + // Instantiate the filter with the completely built llvm generator + *filter = std::make_shared<Filter>(std::move(llvm_gen), schema, configuration); + ValueCacheObject<std::shared_ptr<Filter>> value_cache(*filter, elapsed); + cache.PutModule(cache_key, value_cache); + + return Status::OK(); +} + +Status Filter::Evaluate(const arrow::RecordBatch& batch, + std::shared_ptr<SelectionVector> out_selection) { + const auto num_rows = batch.num_rows(); + ARROW_RETURN_IF(!batch.schema()->Equals(*schema_), + Status::Invalid("RecordBatch schema must expected filter schema")); + ARROW_RETURN_IF(num_rows == 0, Status::Invalid("RecordBatch must be non-empty.")); + ARROW_RETURN_IF(out_selection == nullptr, + Status::Invalid("out_selection must be non-null.")); + ARROW_RETURN_IF(out_selection->GetMaxSlots() < num_rows, + Status::Invalid("Output selection vector capacity too small")); + + // Allocate three local_bitmaps (one for output, one for validity, one to compute the + // intersection). 
+ LocalBitMapsHolder bitmaps(num_rows, 3 /*local_bitmaps*/); + int64_t bitmap_size = bitmaps.GetLocalBitMapSize(); + + auto validity = std::make_shared<arrow::Buffer>(bitmaps.GetLocalBitMap(0), bitmap_size); + auto value = std::make_shared<arrow::Buffer>(bitmaps.GetLocalBitMap(1), bitmap_size); + auto array_data = arrow::ArrayData::Make(arrow::boolean(), num_rows, {validity, value}); + + // Execute the expression(s). + ARROW_RETURN_NOT_OK(llvm_generator_->Execute(batch, {array_data})); + + // Compute the intersection of the value and validity. + auto result = bitmaps.GetLocalBitMap(2); + BitMapAccumulator::IntersectBitMaps( + result, {bitmaps.GetLocalBitMap(0), bitmaps.GetLocalBitMap((1))}, {0, 0}, num_rows); + + return out_selection->PopulateFromBitMap(result, bitmap_size, num_rows - 1); +} + +std::string Filter::DumpIR() { return llvm_generator_->DumpIR(); } + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/filter.h b/src/arrow/cpp/src/gandiva/filter.h new file mode 100644 index 000000000..70ccd7cf0 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/filter.h @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "arrow/status.h" + +#include "gandiva/arrow.h" +#include "gandiva/condition.h" +#include "gandiva/configuration.h" +#include "gandiva/selection_vector.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +class LLVMGenerator; + +class FilterCacheKey { + public: + FilterCacheKey(SchemaPtr schema, std::shared_ptr<Configuration> configuration, + Expression& expression); + + std::size_t Hash() const { return hash_code_; } + + bool operator==(const FilterCacheKey& other) const; + + bool operator!=(const FilterCacheKey& other) const { return !(*this == other); } + + SchemaPtr schema() const { return schema_; } + + std::string ToString() const; + + private: + void UpdateUniqifier(const std::string& expr); + + const SchemaPtr schema_; + const std::shared_ptr<Configuration> configuration_; + std::string expression_as_string_; + size_t hash_code_; + uint32_t uniqifier_; +}; + +/// \brief filter records based on a condition. +/// +/// A filter is built for a specific schema and condition. Once the filter is built, it +/// can be used to evaluate many row batches. +class GANDIVA_EXPORT Filter { + public: + Filter(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema, + std::shared_ptr<Configuration> config); + + // Inline dtor will attempt to resolve the destructor for + // LLVMGenerator on MSVC, so we compile the dtor in the object code + ~Filter(); + + /// Build a filter for the given schema and condition, with the default configuration. + /// + /// \param[in] schema schema for the record batches, and the condition. + /// \param[in] condition filter condition. 
+ /// \param[out] filter the returned filter object + static Status Make(SchemaPtr schema, ConditionPtr condition, + std::shared_ptr<Filter>* filter) { + return Make(schema, condition, ConfigurationBuilder::DefaultConfiguration(), filter); + } + + /// \brief Build a filter for the given schema and condition. + /// Customize the filter with runtime configuration. + /// + /// \param[in] schema schema for the record batches, and the condition. + /// \param[in] condition filter conditions. + /// \param[in] config run time configuration. + /// \param[out] filter the returned filter object + static Status Make(SchemaPtr schema, ConditionPtr condition, + std::shared_ptr<Configuration> config, + std::shared_ptr<Filter>* filter); + + /// Evaluate the specified record batch, and populate output selection vector. + /// + /// \param[in] batch the record batch. schema should be the same as the one in 'Make' + /// \param[in,out] out_selection the selection array with indices of rows that match + /// the condition. + Status Evaluate(const arrow::RecordBatch& batch, + std::shared_ptr<SelectionVector> out_selection); + + std::string DumpIR(); + + private: + std::unique_ptr<LLVMGenerator> llvm_generator_; + SchemaPtr schema_; + std::shared_ptr<Configuration> configuration_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/formatting_utils.h b/src/arrow/cpp/src/gandiva/formatting_utils.h new file mode 100644 index 000000000..7bc6a4969 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/formatting_utils.h @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/type.h" +#include "arrow/util/formatting.h" +#include "arrow/vendored/double-conversion/double-conversion.h" + +namespace gandiva { + +/// \brief The entry point for conversion to strings. +template <typename ARROW_TYPE, typename Enable = void> +class GdvStringFormatter; + +using double_conversion::DoubleToStringConverter; + +template <typename ARROW_TYPE> +class FloatToStringGdvMixin + : public arrow::internal::FloatToStringFormatterMixin<ARROW_TYPE> { + public: + using arrow::internal::FloatToStringFormatterMixin< + ARROW_TYPE>::FloatToStringFormatterMixin; + + // The mixin is a modified version of the existent FloatToStringFormatterMixin, but + // it defines some specific parameters in the FloatToStringFormatterMixin to cast + // the float numbers to string using the same patterns like Java. + // + // The Java real numbers are represented in two ways following these rules: + //- If the number is greater or equals than 10^7 and less than 10^(-3) + // it will be represented using scientific notation, e.g: + // - 0.000012 -> 1.2E-5 + // - 10000002.3 -> 1.00000023E7 + //- If the numbers are between that interval above, they are showed as is. 
+ explicit FloatToStringGdvMixin(const std::shared_ptr<arrow::DataType>& = NULLPTR) + : arrow::internal::FloatToStringFormatterMixin<ARROW_TYPE>( + DoubleToStringConverter::EMIT_TRAILING_ZERO_AFTER_POINT | + DoubleToStringConverter::EMIT_TRAILING_DECIMAL_POINT, + "Infinity", "NaN", 'E', -3, 7, 3, 1) {} +}; + +template <> +class GdvStringFormatter<arrow::FloatType> + : public FloatToStringGdvMixin<arrow::FloatType> { + public: + using FloatToStringGdvMixin::FloatToStringGdvMixin; +}; + +template <> +class GdvStringFormatter<arrow::DoubleType> + : public FloatToStringGdvMixin<arrow::DoubleType> { + public: + using FloatToStringGdvMixin::FloatToStringGdvMixin; +}; +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/func_descriptor.h b/src/arrow/cpp/src/gandiva/func_descriptor.h new file mode 100644 index 000000000..a2bf3a16b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/func_descriptor.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <string> +#include <vector> + +#include "gandiva/arrow.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// Descriptor for a function in the expression. 
///
/// Plain value type that stores a function name, the data types of its
/// parameters and the data type it returns, and exposes them read-only.
class GANDIVA_EXPORT FuncDescriptor {
 public:
  /// \param name base function name
  /// \param params data types of the input parameters, in order
  /// \param return_type data type produced by the function
  ///
  /// All three arguments are copied into the descriptor.
  FuncDescriptor(const std::string& name, const DataTypeVector& params,
                 DataTypePtr return_type)
      : name_(name), params_(params), return_type_(return_type) {}

  /// base function name.
  const std::string& name() const { return name_; }

  /// Data types of the input params.
  const DataTypeVector& params() const { return params_; }

  /// Data type of the return parameter (returned by value).
  DataTypePtr return_type() const { return return_type_; }

 private:
  std::string name_;
  DataTypeVector params_;
  DataTypePtr return_type_;
};

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_holder.h b/src/arrow/cpp/src/gandiva/function_holder.h
new file mode 100644
index 000000000..e3576f09c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_holder.h
@@ -0,0 +1,34 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <memory>

#include "gandiva/visibility.h"

namespace gandiva {

/// Holder for a function that can be invoked from LLVM.
///
/// Empty polymorphic base: it only provides a virtual destructor so that
/// concrete holders can be owned and destroyed through FunctionHolderPtr.
class GANDIVA_EXPORT FunctionHolder {
 public:
  virtual ~FunctionHolder() = default;
};

/// Shared-ownership handle to a FunctionHolder.
using FunctionHolderPtr = std::shared_ptr<FunctionHolder>;

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_holder_registry.h b/src/arrow/cpp/src/gandiva/function_holder_registry.h
new file mode 100644
index 000000000..ced153891
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_holder_registry.h
@@ -0,0 +1,76 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

#include "arrow/status.h"

#include "gandiva/function_holder.h"
#include "gandiva/like_holder.h"
#include "gandiva/node.h"
#include "gandiva/random_generator_holder.h"
#include "gandiva/replace_holder.h"
#include "gandiva/to_date_holder.h"

namespace gandiva {

// Expands to a lambda that adapts `derived::Make(node, &std::shared_ptr<derived>)`
// to the type-erased signature (const FunctionNode&, FunctionHolderPtr*).
// `*holder` is assigned only when Make() reports success; the status returned by
// Make() is passed through unchanged in both cases.
#define LAMBDA_MAKER(derived)                               \
  [](const FunctionNode& node, FunctionHolderPtr* holder) { \
    std::shared_ptr<derived> derived_instance;              \
    auto status = derived::Make(node, &derived_instance);   \
    if (status.ok()) {                                      \
      *holder = derived_instance;                           \
    }                                                       \
    return status;                                          \
  }

/// Static registry of function holders.
+class FunctionHolderRegistry { + public: + using maker_type = std::function<Status(const FunctionNode&, FunctionHolderPtr*)>; + using map_type = std::unordered_map<std::string, maker_type>; + + static Status Make(const std::string& name, const FunctionNode& node, + FunctionHolderPtr* holder) { + auto found = makers().find(name); + if (found == makers().end()) { + return Status::Invalid("function holder not registered for function " + name); + } + + return found->second(node, holder); + } + + private: + static map_type& makers() { + static map_type maker_map = { + {"like", LAMBDA_MAKER(LikeHolder)}, + {"ilike", LAMBDA_MAKER(LikeHolder)}, + {"to_date", LAMBDA_MAKER(ToDateHolder)}, + {"random", LAMBDA_MAKER(RandomGeneratorHolder)}, + {"rand", LAMBDA_MAKER(RandomGeneratorHolder)}, + {"regexp_replace", LAMBDA_MAKER(ReplaceHolder)}, + }; + return maker_map; + } +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_ir_builder.cc b/src/arrow/cpp/src/gandiva/function_ir_builder.cc new file mode 100644 index 000000000..194273933 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_ir_builder.cc @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/function_ir_builder.h" + +namespace gandiva { + +llvm::Value* FunctionIRBuilder::BuildIfElse(llvm::Value* condition, + llvm::Type* return_type, + std::function<llvm::Value*()> then_func, + std::function<llvm::Value*()> else_func) { + llvm::IRBuilder<>* builder = ir_builder(); + llvm::Function* function = builder->GetInsertBlock()->getParent(); + DCHECK_NE(function, nullptr); + + // Create blocks for the then, else and merge cases. + llvm::BasicBlock* then_bb = llvm::BasicBlock::Create(*context(), "then", function); + llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context(), "else", function); + llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context(), "merge", function); + + builder->CreateCondBr(condition, then_bb, else_bb); + + // Emit the then block. + builder->SetInsertPoint(then_bb); + auto then_value = then_func(); + builder->CreateBr(merge_bb); + + // refresh then_bb for phi (could have changed due to code generation of then_value). + then_bb = builder->GetInsertBlock(); + + // Emit the else block. + builder->SetInsertPoint(else_bb); + auto else_value = else_func(); + builder->CreateBr(merge_bb); + + // refresh else_bb for phi (could have changed due to code generation of else_value). + else_bb = builder->GetInsertBlock(); + + // Emit the merge block. 
+ builder->SetInsertPoint(merge_bb); + llvm::PHINode* result_value = builder->CreatePHI(return_type, 2, "res_value"); + result_value->addIncoming(then_value, then_bb); + result_value->addIncoming(else_value, else_bb); + return result_value; +} + +llvm::Function* FunctionIRBuilder::BuildFunction(const std::string& function_name, + llvm::Type* return_type, + std::vector<NamedArg> in_args) { + std::vector<llvm::Type*> arg_types; + for (auto& arg : in_args) { + arg_types.push_back(arg.type); + } + auto prototype = llvm::FunctionType::get(return_type, arg_types, false /*isVarArg*/); + auto function = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage, + function_name, module()); + + uint32_t i = 0; + for (auto& fn_arg : function->args()) { + DCHECK_LT(i, in_args.size()); + fn_arg.setName(in_args[i].name); + ++i; + } + return function; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_ir_builder.h b/src/arrow/cpp/src/gandiva/function_ir_builder.h new file mode 100644 index 000000000..388f55840 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_ir_builder.h @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <cstdint> +#include <memory> +#include <string> +#include <vector> + +#include "gandiva/engine.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/llvm_types.h" + +namespace gandiva { + +/// @brief Base class for building IR functions. +class FunctionIRBuilder { + public: + explicit FunctionIRBuilder(Engine* engine) : engine_(engine) {} + virtual ~FunctionIRBuilder() = default; + + protected: + LLVMTypes* types() { return engine_->types(); } + llvm::Module* module() { return engine_->module(); } + llvm::LLVMContext* context() { return engine_->context(); } + llvm::IRBuilder<>* ir_builder() { return engine_->ir_builder(); } + + /// Build an if-else block. + llvm::Value* BuildIfElse(llvm::Value* condition, llvm::Type* return_type, + std::function<llvm::Value*()> then_func, + std::function<llvm::Value*()> else_func); + + struct NamedArg { + std::string name; + llvm::Type* type; + }; + + /// Build llvm fn. + llvm::Function* BuildFunction(const std::string& function_name, llvm::Type* return_type, + std::vector<NamedArg> in_args); + + private: + Engine* engine_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry.cc b/src/arrow/cpp/src/gandiva/function_registry.cc new file mode 100644 index 000000000..d5d015c10 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry.cc @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/function_registry.h" +#include "gandiva/function_registry_arithmetic.h" +#include "gandiva/function_registry_datetime.h" +#include "gandiva/function_registry_hash.h" +#include "gandiva/function_registry_math_ops.h" +#include "gandiva/function_registry_string.h" +#include "gandiva/function_registry_timestamp_arithmetic.h" + +#include <iterator> +#include <utility> +#include <vector> + +namespace gandiva { + +FunctionRegistry::iterator FunctionRegistry::begin() const { + return &(*pc_registry_.begin()); +} + +FunctionRegistry::iterator FunctionRegistry::end() const { + return &(*pc_registry_.end()); +} + +FunctionRegistry::iterator FunctionRegistry::back() const { + return &(pc_registry_.back()); +} + +std::vector<NativeFunction> FunctionRegistry::pc_registry_; + +SignatureMap FunctionRegistry::pc_registry_map_ = InitPCMap(); + +SignatureMap FunctionRegistry::InitPCMap() { + SignatureMap map; + + auto v1 = GetArithmeticFunctionRegistry(); + pc_registry_.insert(std::end(pc_registry_), v1.begin(), v1.end()); + auto v2 = GetDateTimeFunctionRegistry(); + pc_registry_.insert(std::end(pc_registry_), v2.begin(), v2.end()); + + auto v3 = GetHashFunctionRegistry(); + pc_registry_.insert(std::end(pc_registry_), v3.begin(), v3.end()); + + auto v4 = GetMathOpsFunctionRegistry(); + pc_registry_.insert(std::end(pc_registry_), v4.begin(), v4.end()); + + auto v5 = GetStringFunctionRegistry(); + pc_registry_.insert(std::end(pc_registry_), v5.begin(), v5.end()); + + auto v6 = GetDateTimeArithmeticFunctionRegistry(); + 
pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end()); + + for (auto& elem : pc_registry_) { + for (auto& func_signature : elem.signatures()) { + map.insert(std::make_pair(&(func_signature), &elem)); + } + } + + return map; +} + +const NativeFunction* FunctionRegistry::LookupSignature( + const FunctionSignature& signature) const { + auto got = pc_registry_map_.find(&signature); + return got == pc_registry_map_.end() ? nullptr : got->second; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry.h b/src/arrow/cpp/src/gandiva/function_registry.h new file mode 100644 index 000000000..d92563260 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry.h @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <vector> +#include "gandiva/function_registry_common.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/native_function.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +///\brief Registry of pre-compiled IR functions. +class GANDIVA_EXPORT FunctionRegistry { + public: + using iterator = const NativeFunction*; + + /// Lookup a pre-compiled function by its signature. 
+ const NativeFunction* LookupSignature(const FunctionSignature& signature) const; + + iterator begin() const; + iterator end() const; + iterator back() const; + + private: + static SignatureMap InitPCMap(); + + static std::vector<NativeFunction> pc_registry_; + static SignatureMap pc_registry_map_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_arithmetic.cc b/src/arrow/cpp/src/gandiva/function_registry_arithmetic.cc new file mode 100644 index 000000000..f34289f37 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_arithmetic.cc @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 

#include "gandiva/function_registry_arithmetic.h"
#include "gandiva/function_registry_common.h"

namespace gandiva {

// Same-typed binary function (result type == input type) for every numeric type,
// decimal128 included.
#define BINARY_SYMMETRIC_FN(name, ALIASES) \
  NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, name, ALIASES)

// Comparison returning boolean, for numeric, date/time AND boolean inputs.
#define BINARY_RELATIONAL_BOOL_FN(name, ALIASES) \
  NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name, ALIASES)

// Comparison returning boolean, for numeric and date/time inputs only
// (no overload for boolean inputs, despite the BOOL in the name).
#define BINARY_RELATIONAL_BOOL_DATE_FN(name, ALIASES) \
  NUMERIC_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name, ALIASES)

// Shorthands for the unary casts from `type` to the named target type.
#define UNARY_CAST_TO_FLOAT64(type) UNARY_SAFE_NULL_IF_NULL(castFLOAT8, {}, type, float64)

#define UNARY_CAST_TO_FLOAT32(type) UNARY_SAFE_NULL_IF_NULL(castFLOAT4, {}, type, float32)

#define UNARY_CAST_TO_INT32(type) UNARY_SAFE_NULL_IF_NULL(castINT, {}, type, int32)

#define UNARY_CAST_TO_INT64(type) UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, type, int64)

// Returns the arithmetic, cast, bitwise, rounding and comparison functions.
//
// The list is built once (function-local static) and a copy is returned on
// every call. The order of the entries is the order in which they end up in
// the caller's registry, so do not reorder them casually.
std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
  static std::vector<NativeFunction> arithmetic_fn_registry_ = {
      UNARY_SAFE_NULL_IF_NULL(not, {}, boolean, boolean),
      UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, int32, int64),
      UNARY_SAFE_NULL_IF_NULL(castINT, {}, int64, int32),
      UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, decimal128, int64),

      // cast to float32
      UNARY_CAST_TO_FLOAT32(int32), UNARY_CAST_TO_FLOAT32(int64),
      UNARY_CAST_TO_FLOAT32(float64),

      // cast to int32
      UNARY_CAST_TO_INT32(float32), UNARY_CAST_TO_INT32(float64),

      // cast to int64
      UNARY_CAST_TO_INT64(float32), UNARY_CAST_TO_INT64(float64),

      // cast to float64
      UNARY_CAST_TO_FLOAT64(int32), UNARY_CAST_TO_FLOAT64(int64),
      UNARY_CAST_TO_FLOAT64(float32), UNARY_CAST_TO_FLOAT64(decimal128),

      // cast to decimal
      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, int32, decimal128),
      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, int64, decimal128),
      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, float32, decimal128),
      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, float64, decimal128),
      UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, decimal128, decimal128),
      UNARY_UNSAFE_NULL_IF_NULL(castDECIMAL, {}, utf8, decimal128),

      // Registered by hand (not via a macro) because it uses kResultNullInternal
      // and an explicit pre-compiled function name.
      NativeFunction("castDECIMALNullOnOverflow", {}, DataTypeVector{decimal128()},
                     decimal128(), kResultNullInternal,
                     "castDECIMALNullOnOverflow_decimal128"),

      // cast to date
      UNARY_SAFE_NULL_IF_NULL(castDATE, {}, int64, date64),
      UNARY_SAFE_NULL_IF_NULL(castDATE, {}, int32, date32),
      UNARY_SAFE_NULL_IF_NULL(castDATE, {}, date32, date64),

      // add/sub/multiply/divide/mod
      BINARY_SYMMETRIC_FN(add, {}), BINARY_SYMMETRIC_FN(subtract, {}),
      BINARY_SYMMETRIC_FN(multiply, {}),
      NUMERIC_TYPES(BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL, divide, {}),
      BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int64, int32, int32),
      BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int64, int64, int64),
      BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, decimal128),
      BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, float64),
      BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int32),
      BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int64),
      BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, float32),
      BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, float64),

      // bitwise operators
      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_and, {}, int32),
      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_and, {}, int64),
      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, int32),
      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_or, {}, int64),
      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, int32),
      BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(bitwise_xor, {}, int64),
      UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, int32, int32),
      UNARY_SAFE_NULL_IF_NULL(bitwise_not, {}, int64, int64),

      // round functions
      UNARY_SAFE_NULL_IF_NULL(round, {}, float32, float32),
      UNARY_SAFE_NULL_IF_NULL(round, {}, float64, float64),
      BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, float32, int32, float32),
      BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, float64, int32, float64),
      UNARY_SAFE_NULL_IF_NULL(round, {}, int32, int32),
      UNARY_SAFE_NULL_IF_NULL(round, {}, int64, int64),
      BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, int32, int32, int32),
      BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, int64, int32, int64),

      // compare functions
      BINARY_RELATIONAL_BOOL_FN(equal, ({"eq", "same"})),
      BINARY_RELATIONAL_BOOL_FN(not_equal, {}),
      BINARY_RELATIONAL_BOOL_DATE_FN(less_than, {}),
      BINARY_RELATIONAL_BOOL_DATE_FN(less_than_or_equal_to, {}),
      BINARY_RELATIONAL_BOOL_DATE_FN(greater_than, {}),
      BINARY_RELATIONAL_BOOL_DATE_FN(greater_than_or_equal_to, {}),

      // binary representation of integer values
      UNARY_UNSAFE_NULL_IF_NULL(bin, {}, int32, utf8),
      UNARY_UNSAFE_NULL_IF_NULL(bin, {}, int64, utf8)};

  return arithmetic_fn_registry_;
}

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_arithmetic.h b/src/arrow/cpp/src/gandiva/function_registry_arithmetic.h
new file mode 100644
index 000000000..693d3b95e
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_arithmetic.h
@@ -0,0 +1,27 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <vector>
#include "gandiva/native_function.h"

namespace gandiva {

// Returns (by value) the list of arithmetic native functions.
std::vector<NativeFunction> GetArithmeticFunctionRegistry();

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_common.h b/src/arrow/cpp/src/gandiva/function_registry_common.h
new file mode 100644
index 000000000..66f945150
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_common.h
@@ -0,0 +1,268 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "gandiva/arrow.h"
#include "gandiva/function_signature.h"
#include "gandiva/gandiva_aliases.h"
#include "gandiva/native_function.h"

/* This is a private file, intended for internal use by gandiva & must not be included
 * directly.
 */
namespace gandiva {

// Arrow type factories, pulled into this namespace so the registration macros
// below can call them by bare name (e.g. TYPE() with TYPE = int32).
using arrow::binary;
using arrow::boolean;
using arrow::date32;
using arrow::date64;
using arrow::day_time_interval;
using arrow::float32;
using arrow::float64;
using arrow::int16;
using arrow::int32;
using arrow::int64;
using arrow::int8;
using arrow::month_interval;
using arrow::uint16;
using arrow::uint32;
using arrow::uint64;
using arrow::uint8;
using arrow::utf8;

// Zero-argument wrappers around the parameterized Arrow factories, so these
// types can be used with the TYPE() convention of the macros too. Each one
// pins the unit (or precision/scale) shown in its body.
inline DataTypePtr time32() { return arrow::time32(arrow::TimeUnit::MILLI); }

inline DataTypePtr time64() { return arrow::time64(arrow::TimeUnit::MICRO); }

inline DataTypePtr timestamp() { return arrow::timestamp(arrow::TimeUnit::MILLI); }
inline DataTypePtr decimal128() { return arrow::decimal(38, 0); }

// Hashes a signature through its pointer, using the signature's own Hash().
struct KeyHash {
  std::size_t operator()(const FunctionSignature* k) const { return k->Hash(); }
};

// Compares two signatures by value (not by pointer identity).
struct KeyEquals {
  bool operator()(const FunctionSignature* s1, const FunctionSignature* s2) const {
    return *s1 == *s2;
  }
};

// Signature -> native function lookup table. Keys and values are non-owning
// pointers, so the pointed-to objects must outlive the map.
typedef std::unordered_map<const FunctionSignature*, const NativeFunction*, KeyHash,
                           KeyEquals>
    SignatureMap;

// Binary functions that :
// - have the same input type for both params
// - output type is same as the input type
// - NULL handling is of type NULL_IF_NULL
//
// The pre-compiled fn name includes the base name & input type names. eg. add_int32_int32
#define BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE)             \
  NativeFunction(#NAME, std::vector<std::string> ALIASES,                   \
                 DataTypeVector{TYPE(), TYPE()}, TYPE(), kResultNullIfNull, \
                 ARROW_STRINGIFY(NAME##_##TYPE##_##TYPE))

// Binary functions that :
// - have the same input type for both params
// - NULL handling is of type NULL_IF_NULL
// - need the execution context and can return errors.
//
// The pre-compiled fn name includes the base name & input type names. eg. add_int32_int32
#define BINARY_UNSAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE, OUT_TYPE)                    \
  NativeFunction(#NAME, std::vector<std::string> ALIASES,                               \
                 DataTypeVector{IN_TYPE(), IN_TYPE()}, OUT_TYPE(), kResultNullIfNull,   \
                 ARROW_STRINGIFY(NAME##_##IN_TYPE##_##IN_TYPE),                         \
                 NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)

// Same as BINARY_UNSAFE_NULL_IF_NULL, with the output type equal to the input type.
#define BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
  BINARY_UNSAFE_NULL_IF_NULL(NAME, ALIASES, TYPE, TYPE)

// Binary functions that :
// - have different input types, or output type
// - NULL handling is of type NULL_IF_NULL
//
// The pre-compiled fn name includes the base name & input type names. eg. mod_int64_int32
#define BINARY_GENERIC_SAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE1, IN_TYPE2, OUT_TYPE)   \
  NativeFunction(#NAME, std::vector<std::string> ALIASES,                               \
                 DataTypeVector{IN_TYPE1(), IN_TYPE2()}, OUT_TYPE(), kResultNullIfNull, \
                 ARROW_STRINGIFY(NAME##_##IN_TYPE1##_##IN_TYPE2))

// Binary functions that :
// - have the same input type
// - output type is boolean
// - NULL handling is of type NULL_IF_NULL
//
// The pre-compiled fn name includes the base name & input type names.
// eg. equal_int32_int32
#define BINARY_RELATIONAL_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE)               \
  NativeFunction(#NAME, std::vector<std::string> ALIASES,                      \
                 DataTypeVector{TYPE(), TYPE()}, boolean(), kResultNullIfNull, \
                 ARROW_STRINGIFY(NAME##_##TYPE##_##TYPE))

// Unary functions that :
// - NULL handling is of type NULL_IF_NULL
//
// The pre-compiled fn name includes the base name & input type name. eg. castFloat_int32
#define UNARY_SAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE, OUT_TYPE)                    \
  NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{IN_TYPE()}, \
                 OUT_TYPE(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##IN_TYPE))

// Unary functions that :
// - output type is boolean
// - NULL handling is of type NULL_NEVER
//
// The pre-compiled fn name includes the base name & input type name. eg. isnull_int32
#define UNARY_SAFE_NULL_NEVER_BOOL(NAME, ALIASES, TYPE)                           \
  NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
                 boolean(), kResultNullNever, ARROW_STRINGIFY(NAME##_##TYPE))

// Unary functions that :
// - NULL handling is of type NULL_IF_NULL
// - need the execution context and can return errors.
//
// The pre-compiled fn name includes the base name & input type name. eg. castFloat_int32
#define UNARY_UNSAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE, OUT_TYPE)                  \
  NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{IN_TYPE()}, \
                 OUT_TYPE(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##IN_TYPE),   \
                 NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)

// Binary functions that :
// - have the same input type for both params
// - output type is boolean
// - NULL handling is of type NULL_NEVER
//
// The pre-compiled fn name includes the base name & input type names,
// eg. is_distinct_from_int32_int32
#define BINARY_SAFE_NULL_NEVER_BOOL(NAME, ALIASES, TYPE)                      \
  NativeFunction(#NAME, std::vector<std::string> ALIASES,                     \
                 DataTypeVector{TYPE(), TYPE()}, boolean(), kResultNullNever, \
                 ARROW_STRINGIFY(NAME##_##TYPE##_##TYPE))

// Extract functions (used with date/time types) that :
// - output type is int64
// - NULL handling is of type NULL_IF_NULL
//
// The pre-compiled fn name includes the base name & input type name. eg. extractYear_date
#define EXTRACT_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE)                            \
  NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
                 int64(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##TYPE))

// Truncate functions that :
// - output type is the same as the input type
// - NULL handling is of type NULL_IF_NULL
//
// The pre-compiled fn name includes the base name & input type name.
#define TRUNCATE_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE)                           \
  NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
                 TYPE(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##TYPE))

// Last day functions (used with date/time types) that :
// - output type is date64
// - NULL handling is of type NULL_IF_NULL
//
// The pre-compiled fn name includes the base name & input type name. eg:
// - last_day_from_date64
#define LAST_DAY_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE)                           \
  NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
                 date64(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_from_##TYPE))

// Hash32 functions that :
// - output type is int32
// - NULL handling is of type NULL_NEVER
//
// The pre-compiled fn name includes the base name & input type name. eg. hash32_int8
#define HASH32_SAFE_NULL_NEVER(NAME, ALIASES, TYPE)                               \
  NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
                 int32(), kResultNullNever, ARROW_STRINGIFY(NAME##_##TYPE))

// Hash64 functions that :
// - output type is int64
// - NULL handling is of type NULL_NEVER
//
// The pre-compiled fn name includes the base name & input type name. eg. hash64_int8
#define HASH64_SAFE_NULL_NEVER(NAME, ALIASES, TYPE)                               \
  NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \
                 int64(), kResultNullNever, ARROW_STRINGIFY(NAME##_##TYPE))

// Hash32 functions with seed that :
// - take an extra int32 seed argument and return int32
// - NULL handling is of type NULL_NEVER
//
// The pre-compiled fn name includes the base name & input type name.
// eg. hash32WithSeed_int8
#define HASH32_SEED_SAFE_NULL_NEVER(NAME, ALIASES, TYPE)                     \
  NativeFunction(#NAME, std::vector<std::string> ALIASES,                    \
                 DataTypeVector{TYPE(), int32()}, int32(), kResultNullNever, \
                 ARROW_STRINGIFY(NAME##WithSeed_##TYPE))

// Hash64 functions with seed that :
// - take an extra int64 seed argument and return int64
// - NULL handling is of type NULL_NEVER
//
// The pre-compiled fn name includes the base name & input type name.
// eg. hash64WithSeed_int8
#define HASH64_SEED_SAFE_NULL_NEVER(NAME, ALIASES, TYPE)                     \
  NativeFunction(#NAME, std::vector<std::string> ALIASES,                    \
                 DataTypeVector{TYPE(), int64()}, int64(), kResultNullNever, \
                 ARROW_STRINGIFY(NAME##WithSeed_##TYPE))

// HashSHA1 functions that :
// - output type is utf8
// - NULL handling is of type NULL_NEVER
// - can return errors
//
// NOTE: the ALIASES argument is ignored; the aliases are always {"sha", "sha1"}.
//
// The function name is built from the input type name only. eg. gdv_fn_sha1_float64
#define HASH_SHA1_NULL_NEVER(NAME, ALIASES, TYPE)                              \
  NativeFunction(#NAME, {"sha", "sha1"}, DataTypeVector{TYPE()}, utf8(),       \
                 kResultNullNever, ARROW_STRINGIFY(gdv_fn_sha1_##TYPE),        \
                 NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)

// HashSHA256 functions that :
// - output type is utf8
// - NULL handling is of type NULL_NEVER
// - can return errors
//
// NOTE: the ALIASES argument is ignored; the alias is always {"sha256"}.
//
// The function name is built from the input type name only. eg. gdv_fn_sha256_float64
#define HASH_SHA256_NULL_NEVER(NAME, ALIASES, TYPE)                                   \
  NativeFunction(#NAME, {"sha256"}, DataTypeVector{TYPE()}, utf8(), kResultNullNever, \
                 ARROW_STRINGIFY(gdv_fn_sha256_##TYPE),                               \
                 NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)

// Iterate the inner macro over all base numeric types
// (signed/unsigned integers of 8..64 bits, float32 and float64)
#define BASE_NUMERIC_TYPES(INNER, NAME, ALIASES)                                          \
  INNER(NAME, ALIASES, int8), INNER(NAME, ALIASES, int16), INNER(NAME, ALIASES, int32),   \
      INNER(NAME, ALIASES, int64), INNER(NAME, ALIASES, uint8),                           \
      INNER(NAME, ALIASES, uint16), INNER(NAME, ALIASES, uint32),                         \
      INNER(NAME, ALIASES, uint64), INNER(NAME, ALIASES, float32),                        \
      INNER(NAME, ALIASES, float64)

// Iterate the inner macro over all numeric types (base numeric types + decimal128)
#define NUMERIC_TYPES(INNER, NAME, ALIASES) \
  BASE_NUMERIC_TYPES(INNER, NAME, ALIASES), INNER(NAME, ALIASES, decimal128)

// Iterate the inner macro over numeric and date/time types
// (numeric types, DATE_TYPES, TIME_TYPES and date32). DATE_TYPES and TIME_TYPES
// are defined further down; that is fine because macros are expanded at the
// point of use.
#define NUMERIC_DATE_TYPES(INNER, NAME, ALIASES)                         \
  NUMERIC_TYPES(INNER, NAME, ALIASES), DATE_TYPES(INNER, NAME, ALIASES), \
      TIME_TYPES(INNER, NAME, ALIASES), INNER(NAME, ALIASES, date32)

// Iterate the inner macro over the date types (date64 and timestamp)
#define DATE_TYPES(INNER, NAME, ALIASES) \
  INNER(NAME, ALIASES, date64), INNER(NAME, ALIASES, timestamp)

// Iterate the inner macro over the time types (currently only time32)
#define TIME_TYPES(INNER, NAME, ALIASES) INNER(NAME, ALIASES, time32)

// Iterate the inner macro over the variable-length types (utf8 and binary)
#define VAR_LEN_TYPES(INNER, NAME, ALIASES) \
  INNER(NAME, ALIASES, utf8), INNER(NAME, ALIASES, binary)

// Iterate the inner macro over all numeric types, date types and bool type
#define NUMERIC_BOOL_DATE_TYPES(INNER, NAME, ALIASES) \
  NUMERIC_DATE_TYPES(INNER, NAME, ALIASES), INNER(NAME, ALIASES, boolean)

// Iterate the inner macro over all numeric types, date types, bool and varlen types
#define NUMERIC_BOOL_DATE_VAR_LEN_TYPES(INNER, NAME, ALIASES) \
  NUMERIC_BOOL_DATE_TYPES(INNER, NAME, ALIASES), VAR_LEN_TYPES(INNER, NAME, ALIASES)

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/function_registry_datetime.cc b/src/arrow/cpp/src/gandiva/function_registry_datetime.cc
new file mode 100644
index 000000000..b8d2e7b6c
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/function_registry_datetime.cc
@@ -0,0 +1,132 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
+ +#include "gandiva/function_registry_datetime.h" + +#include "gandiva/function_registry_common.h" + +namespace gandiva { +// Apply INNER over DATE_TYPES per date part; multi-alias lists are parenthesized to hide commas +#define DATE_EXTRACTION_TRUNCATION_FNS(INNER, name) \ + DATE_TYPES(INNER, name##Millennium, {}), DATE_TYPES(INNER, name##Century, {}), \ + DATE_TYPES(INNER, name##Decade, {}), DATE_TYPES(INNER, name##Year, {"year"}), \ + DATE_TYPES(INNER, name##Quarter, {}), DATE_TYPES(INNER, name##Month, {"month"}), \ + DATE_TYPES(INNER, name##Week, ({"weekofyear", "yearweek"})), \ + DATE_TYPES(INNER, name##Day, ({"day", "dayofmonth"})), \ + DATE_TYPES(INNER, name##Hour, {"hour"}), \ + DATE_TYPES(INNER, name##Minute, {"minute"}), \ + DATE_TYPES(INNER, name##Second, {"second"}) +// One-argument function returning timestamp(); the pre-compiled fn name is NAME_TYPE +#define TO_TIMESTAMP_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \ + NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \ + timestamp(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##TYPE)) +// One-argument function returning time32(); the pre-compiled fn name is NAME_TYPE +#define TO_TIME_SAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \ + NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{TYPE()}, \ + time32(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##TYPE)) +// Hour/Minute/Second extraction registered over TIME_TYPES +#define TIME_EXTRACTION_FNS(name) \ + TIME_TYPES(EXTRACT_SAFE_NULL_IF_NULL, name##Hour, {"hour"}), \ + TIME_TYPES(EXTRACT_SAFE_NULL_IF_NULL, name##Minute, {"minute"}), \ + TIME_TYPES(EXTRACT_SAFE_NULL_IF_NULL, name##Second, {"second"}) +// Returns the date/time NativeFunction list; it is built once as a function-local static +std::vector<NativeFunction> GetDateTimeFunctionRegistry() { + static std::vector<NativeFunction> date_time_fn_registry_ = { + DATE_EXTRACTION_TRUNCATION_FNS(EXTRACT_SAFE_NULL_IF_NULL, extract), + DATE_EXTRACTION_TRUNCATION_FNS(TRUNCATE_SAFE_NULL_IF_NULL, date_trunc_), + + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractDoy, {}), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractDow, {}), + DATE_TYPES(EXTRACT_SAFE_NULL_IF_NULL, extractEpoch, {}), + + TIME_EXTRACTION_FNS(extract), + + NativeFunction("castDATE", {}, DataTypeVector{utf8()}, date64(), kResultNullIfNull, + "castDATE_utf8", + NativeFunction::kNeedsContext |
NativeFunction::kCanReturnErrors), + + NativeFunction("castTIMESTAMP", {}, DataTypeVector{utf8()}, timestamp(), + kResultNullIfNull, "castTIMESTAMP_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARCHAR", {}, DataTypeVector{timestamp(), int64()}, utf8(), + kResultNullIfNull, "castVARCHAR_timestamp_int64", + NativeFunction::kNeedsContext), + + NativeFunction("to_date", {}, DataTypeVector{utf8(), utf8()}, date64(), + kResultNullInternal, "gdv_fn_to_date_utf8_utf8", + NativeFunction::kNeedsContext | + NativeFunction::kNeedsFunctionHolder | + NativeFunction::kCanReturnErrors), + + NativeFunction("to_date", {}, DataTypeVector{utf8(), utf8(), int32()}, date64(), + kResultNullInternal, "gdv_fn_to_date_utf8_utf8_int32", + NativeFunction::kNeedsContext | + NativeFunction::kNeedsFunctionHolder | + NativeFunction::kCanReturnErrors), + NativeFunction("castTIMESTAMP", {}, DataTypeVector{date64()}, timestamp(), + kResultNullIfNull, "castTIMESTAMP_date64"), + + NativeFunction("castTIMESTAMP", {}, DataTypeVector{int64()}, timestamp(), + kResultNullIfNull, "castTIMESTAMP_int64"), + + NativeFunction("castDATE", {"to_date"}, DataTypeVector{timestamp()}, date64(), + kResultNullIfNull, "castDATE_timestamp"), + + NativeFunction("castTIME", {}, DataTypeVector{timestamp()}, time32(), + kResultNullIfNull, "castTIME_timestamp"), + + NativeFunction("castBIGINT", {}, DataTypeVector{day_time_interval()}, int64(), + kResultNullIfNull, "castBIGINT_daytimeinterval"), + // month_interval() inputs are bound to the *_year_interval pre-compiled functions + NativeFunction("castINT", {"castNULLABLEINT"}, DataTypeVector{month_interval()}, + int32(), kResultNullIfNull, "castINT_year_interval", + NativeFunction::kCanReturnErrors), + + NativeFunction("castBIGINT", {"castNULLABLEBIGINT"}, + DataTypeVector{month_interval()}, int64(), kResultNullIfNull, + "castBIGINT_year_interval", NativeFunction::kCanReturnErrors), + + NativeFunction("castNULLABLEINTERVALYEAR", {"castINTERVALYEAR"}, + DataTypeVector{int32()}, month_interval(),
kResultNullIfNull, + "castNULLABLEINTERVALYEAR_int32", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castNULLABLEINTERVALYEAR", {"castINTERVALYEAR"}, + DataTypeVector{int64()}, month_interval(), kResultNullIfNull, + "castNULLABLEINTERVALYEAR_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castNULLABLEINTERVALDAY", {"castINTERVALDAY"}, + DataTypeVector{int32()}, day_time_interval(), kResultNullIfNull, + "castNULLABLEINTERVALDAY_int32"), + + NativeFunction("castNULLABLEINTERVALDAY", {"castINTERVALDAY"}, + DataTypeVector{int64()}, day_time_interval(), kResultNullIfNull, + "castNULLABLEINTERVALDAY_int64"), + + NativeFunction("extractDay", {}, DataTypeVector{day_time_interval()}, int64(), + kResultNullIfNull, "extractDay_daytimeinterval"), + + DATE_TYPES(LAST_DAY_SAFE_NULL_IF_NULL, last_day, {}), + BASE_NUMERIC_TYPES(TO_TIME_SAFE_NULL_IF_NULL, to_time, {}), + BASE_NUMERIC_TYPES(TO_TIMESTAMP_SAFE_NULL_IF_NULL, to_timestamp, {})}; + // Returned by value: every call hands back a copy of the static vector + return date_time_fn_registry_; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_datetime.h b/src/arrow/cpp/src/gandiva/function_registry_datetime.h new file mode 100644 index 000000000..46172ec62 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_datetime.h @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <vector> +#include "gandiva/native_function.h" + +namespace gandiva { + +std::vector<NativeFunction> GetDateTimeFunctionRegistry(); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_hash.cc b/src/arrow/cpp/src/gandiva/function_registry_hash.cc new file mode 100644 index 000000000..7fad9321e --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_hash.cc @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/function_registry_hash.h" +#include "gandiva/function_registry_common.h" + +namespace gandiva { +// Each *_FN helper applies one hash macro to all numeric, bool, date/time and var-len types +#define HASH32_SAFE_NULL_NEVER_FN(name, ALIASES) \ + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH32_SAFE_NULL_NEVER, name, ALIASES) + +#define HASH32_SEED_SAFE_NULL_NEVER_FN(name, ALIASES) \ + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH32_SEED_SAFE_NULL_NEVER, name, ALIASES) + +#define HASH64_SAFE_NULL_NEVER_FN(name, ALIASES) \ + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH64_SAFE_NULL_NEVER, name, ALIASES) + +#define HASH64_SEED_SAFE_NULL_NEVER_FN(name, ALIASES) \ + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH64_SEED_SAFE_NULL_NEVER, name, ALIASES) + +#define HASH_SHA1_NULL_NEVER_FN(name, ALIASES) \ + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH_SHA1_NULL_NEVER, name, ALIASES) + +#define HASH_SHA256_NULL_NEVER_FN(name, ALIASES) \ + NUMERIC_BOOL_DATE_VAR_LEN_TYPES(HASH_SHA256_NULL_NEVER, name, ALIASES) +// Returns the hash NativeFunction list; it is built once as a function-local static +std::vector<NativeFunction> GetHashFunctionRegistry() { + static std::vector<NativeFunction> hash_fn_registry_ = { + HASH32_SAFE_NULL_NEVER_FN(hash, {}), + HASH32_SAFE_NULL_NEVER_FN(hash32, {}), + HASH32_SAFE_NULL_NEVER_FN(hash32AsDouble, {}), + + HASH32_SEED_SAFE_NULL_NEVER_FN(hash32, {}), + HASH32_SEED_SAFE_NULL_NEVER_FN(hash32AsDouble, {}), + + HASH64_SAFE_NULL_NEVER_FN(hash64, {}), + HASH64_SAFE_NULL_NEVER_FN(hash64AsDouble, {}), + + HASH64_SEED_SAFE_NULL_NEVER_FN(hash64, {}), + HASH64_SEED_SAFE_NULL_NEVER_FN(hash64AsDouble, {}), + // The {} alias argument is ignored by the SHA macros, which hard-code their alias lists + HASH_SHA1_NULL_NEVER_FN(hashSHA1, {}), + + HASH_SHA256_NULL_NEVER_FN(hashSHA256, {})}; + // Returned by value: every call hands back a copy of the static vector + return hash_fn_registry_; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_hash.h b/src/arrow/cpp/src/gandiva/function_registry_hash.h new file mode 100644 index 000000000..4f96d30cf --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_hash.h @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <vector> +#include "gandiva/native_function.h" + +namespace gandiva { + +std::vector<NativeFunction> GetHashFunctionRegistry(); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_math_ops.cc b/src/arrow/cpp/src/gandiva/function_registry_math_ops.cc new file mode 100644 index 000000000..49afd4003 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_math_ops.cc @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/function_registry_math_ops.h" +#include "gandiva/function_registry_common.h" + +namespace gandiva { + +#define MATH_UNARY_OPS(name, ALIASES) \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int32, float64), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int64, float64), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint32, float64), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint64, float64), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float64), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64) + +#define MATH_BINARY_UNSAFE(name, ALIASES) \ + BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int32, float64), \ + BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int64, float64), \ + BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, uint32, float64), \ + BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, uint64, float64), \ + BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, float32, float64), \ + BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, float64, float64) + +#define MATH_BINARY_SAFE(name, ALIASES) \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, int32, float64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, int64, float64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, uint32, uint32, float64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, uint64, uint64, float64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, float32, float32, float64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64, float64) + +#define UNARY_SAFE_NULL_NEVER_BOOL_FN(name, ALIASES) \ + NUMERIC_BOOL_DATE_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, name, ALIASES) + +#define BINARY_SAFE_NULL_NEVER_BOOL_FN(name, ALIASES) \ + NUMERIC_BOOL_DATE_TYPES(BINARY_SAFE_NULL_NEVER_BOOL, name, ALIASES) + +std::vector<NativeFunction> GetMathOpsFunctionRegistry() { + static std::vector<NativeFunction> math_fn_registry_ = { + MATH_UNARY_OPS(cbrt, {}), MATH_UNARY_OPS(exp, {}), MATH_UNARY_OPS(log, {}), + MATH_UNARY_OPS(log10, {}), + + MATH_BINARY_UNSAFE(log, {}), + + 
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(power, {"pow"}, float64), + + UNARY_SAFE_NULL_NEVER_BOOL_FN(isnull, {}), + UNARY_SAFE_NULL_NEVER_BOOL_FN(isnotnull, {}), + + NUMERIC_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, isnumeric, {}), + + BINARY_SAFE_NULL_NEVER_BOOL_FN(is_distinct_from, {}), + BINARY_SAFE_NULL_NEVER_BOOL_FN(is_not_distinct_from, {}), + + // trigonometry functions + MATH_UNARY_OPS(sin, {}), MATH_UNARY_OPS(cos, {}), MATH_UNARY_OPS(asin, {}), + MATH_UNARY_OPS(acos, {}), MATH_UNARY_OPS(tan, {}), MATH_UNARY_OPS(atan, {}), + MATH_UNARY_OPS(sinh, {}), MATH_UNARY_OPS(cosh, {}), MATH_UNARY_OPS(tanh, {}), + MATH_UNARY_OPS(cot, {}), MATH_UNARY_OPS(radians, {}), MATH_UNARY_OPS(degrees, {}), + MATH_BINARY_SAFE(atan2, {}), + + // decimal functions + UNARY_SAFE_NULL_IF_NULL(abs, {}, decimal128, decimal128), + UNARY_SAFE_NULL_IF_NULL(ceil, {}, decimal128, decimal128), + UNARY_SAFE_NULL_IF_NULL(floor, {}, decimal128, decimal128), + UNARY_SAFE_NULL_IF_NULL(round, {}, decimal128, decimal128), + UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, decimal128, decimal128), + BINARY_GENERIC_SAFE_NULL_IF_NULL(round, {}, decimal128, int32, decimal128), + BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, decimal128, int32, + decimal128), + + NativeFunction("truncate", {"trunc"}, DataTypeVector{int64(), int32()}, int64(), + kResultNullIfNull, "truncate_int64_int32"), + NativeFunction("random", {"rand"}, DataTypeVector{}, float64(), kResultNullNever, + "gdv_fn_random", NativeFunction::kNeedsFunctionHolder), + NativeFunction("random", {"rand"}, DataTypeVector{int32()}, float64(), + kResultNullNever, "gdv_fn_random_with_seed", + NativeFunction::kNeedsFunctionHolder)}; + + return math_fn_registry_; +} + +#undef MATH_UNARY_OPS + +#undef MATH_BINARY_UNSAFE + +#undef UNARY_SAFE_NULL_NEVER_BOOL_FN + +#undef BINARY_SAFE_NULL_NEVER_BOOL_FN + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_math_ops.h b/src/arrow/cpp/src/gandiva/function_registry_math_ops.h new file mode 
100644 index 000000000..2c8a40d53 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_math_ops.h @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <vector> +#include "gandiva/native_function.h" + +namespace gandiva { + +std::vector<NativeFunction> GetMathOpsFunctionRegistry(); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_string.cc b/src/arrow/cpp/src/gandiva/function_registry_string.cc new file mode 100644 index 000000000..3ea426c85 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_string.cc @@ -0,0 +1,422 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/function_registry_string.h" + +#include "gandiva/function_registry_common.h" + +namespace gandiva { + +#define BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(name, ALIASES) \ + VAR_LEN_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name, ALIASES) + +#define BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN(name, ALIASES) \ + BINARY_RELATIONAL_SAFE_NULL_IF_NULL(name, ALIASES, utf8) + +#define UNARY_OCTET_LEN_FN(name, ALIASES) \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, utf8, int32), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, binary, int32) + +#define UNARY_SAFE_NULL_NEVER_BOOL_FN(name, ALIASES) \ + VAR_LEN_TYPES(UNARY_SAFE_NULL_NEVER_BOOL, name, ALIASES) + +std::vector<NativeFunction> GetStringFunctionRegistry() { + static std::vector<NativeFunction> string_fn_registry_ = { + BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(equal, {}), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(not_equal, {}), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(less_than, {}), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(less_than_or_equal_to, {}), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(greater_than, {}), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN(greater_than_or_equal_to, {}), + + BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN(starts_with, {}), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN(ends_with, {}), + BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN(is_substr, {}), + + BINARY_UNSAFE_NULL_IF_NULL(locate, {"position"}, utf8, int32), + BINARY_UNSAFE_NULL_IF_NULL(strpos, {}, utf8, int32), + + UNARY_OCTET_LEN_FN(octet_length, {}), UNARY_OCTET_LEN_FN(bit_length, {}), + + 
UNARY_UNSAFE_NULL_IF_NULL(char_length, {}, utf8, int32), + UNARY_UNSAFE_NULL_IF_NULL(length, {}, utf8, int32), + UNARY_UNSAFE_NULL_IF_NULL(lengthUtf8, {}, binary, int32), + UNARY_UNSAFE_NULL_IF_NULL(reverse, {}, utf8, utf8), + UNARY_UNSAFE_NULL_IF_NULL(ltrim, {}, utf8, utf8), + UNARY_UNSAFE_NULL_IF_NULL(rtrim, {}, utf8, utf8), + UNARY_UNSAFE_NULL_IF_NULL(btrim, {}, utf8, utf8), + UNARY_UNSAFE_NULL_IF_NULL(space, {}, int32, utf8), + UNARY_UNSAFE_NULL_IF_NULL(space, {}, int64, utf8), + + UNARY_SAFE_NULL_NEVER_BOOL_FN(isnull, {}), + UNARY_SAFE_NULL_NEVER_BOOL_FN(isnotnull, {}), + + NativeFunction("ascii", {}, DataTypeVector{utf8()}, int32(), kResultNullIfNull, + "ascii_utf8"), + + NativeFunction("base64", {}, DataTypeVector{binary()}, utf8(), kResultNullIfNull, + "gdv_fn_base64_encode_binary", NativeFunction::kNeedsContext), + + NativeFunction("unbase64", {}, DataTypeVector{utf8()}, binary(), kResultNullIfNull, + "gdv_fn_base64_decode_utf8", NativeFunction::kNeedsContext), + + NativeFunction("repeat", {}, DataTypeVector{utf8(), int32()}, utf8(), + kResultNullIfNull, "repeat_utf8_int32", + NativeFunction::kNeedsContext), + + NativeFunction("upper", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull, + "gdv_fn_upper_utf8", NativeFunction::kNeedsContext), + + NativeFunction("lower", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull, + "gdv_fn_lower_utf8", NativeFunction::kNeedsContext), + + NativeFunction("initcap", {}, DataTypeVector{utf8()}, utf8(), kResultNullIfNull, + "gdv_fn_initcap_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castBIT", {"castBOOLEAN"}, DataTypeVector{utf8()}, boolean(), + kResultNullIfNull, "castBIT_utf8", NativeFunction::kNeedsContext), + + NativeFunction("castINT", {}, DataTypeVector{utf8()}, int32(), kResultNullIfNull, + "gdv_fn_castINT_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castBIGINT", {}, DataTypeVector{utf8()}, int64(), 
kResultNullIfNull, + "gdv_fn_castBIGINT_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castFLOAT4", {}, DataTypeVector{utf8()}, float32(), + kResultNullIfNull, "gdv_fn_castFLOAT4_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castFLOAT8", {}, DataTypeVector{utf8()}, float64(), + kResultNullIfNull, "gdv_fn_castFLOAT8_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castINT", {}, DataTypeVector{binary()}, int32(), kResultNullIfNull, + "gdv_fn_castINT_varbinary", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castBIGINT", {}, DataTypeVector{binary()}, int64(), + kResultNullIfNull, "gdv_fn_castBIGINT_varbinary", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castFLOAT4", {}, DataTypeVector{binary()}, float32(), + kResultNullIfNull, "gdv_fn_castFLOAT4_varbinary", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castFLOAT8", {}, DataTypeVector{binary()}, float64(), + kResultNullIfNull, "gdv_fn_castFLOAT8_varbinary", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARCHAR", {}, DataTypeVector{boolean(), int64()}, utf8(), + kResultNullIfNull, "castVARCHAR_bool_int64", + NativeFunction::kNeedsContext), + + NativeFunction("castVARCHAR", {}, DataTypeVector{utf8(), int64()}, utf8(), + kResultNullIfNull, "castVARCHAR_utf8_int64", + NativeFunction::kNeedsContext), + + NativeFunction("castVARCHAR", {}, DataTypeVector{binary(), int64()}, utf8(), + kResultNullIfNull, "castVARCHAR_binary_int64", + NativeFunction::kNeedsContext), + + NativeFunction("castVARCHAR", {}, DataTypeVector{int32(), int64()}, utf8(), + kResultNullIfNull, "gdv_fn_castVARCHAR_int32_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARCHAR", {}, 
DataTypeVector{int64(), int64()}, utf8(), + kResultNullIfNull, "gdv_fn_castVARCHAR_int64_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARCHAR", {}, DataTypeVector{float32(), int64()}, utf8(), + kResultNullIfNull, "gdv_fn_castVARCHAR_float32_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARCHAR", {}, DataTypeVector{float64(), int64()}, utf8(), + kResultNullIfNull, "gdv_fn_castVARCHAR_float64_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARCHAR", {}, DataTypeVector{decimal128(), int64()}, utf8(), + kResultNullIfNull, "castVARCHAR_decimal128_int64", + NativeFunction::kNeedsContext), + + NativeFunction("like", {}, DataTypeVector{utf8(), utf8()}, boolean(), + kResultNullIfNull, "gdv_fn_like_utf8_utf8", + NativeFunction::kNeedsFunctionHolder), + + NativeFunction("like", {}, DataTypeVector{utf8(), utf8(), utf8()}, boolean(), + kResultNullIfNull, "gdv_fn_like_utf8_utf8_utf8", + NativeFunction::kNeedsFunctionHolder), + + NativeFunction("ilike", {}, DataTypeVector{utf8(), utf8()}, boolean(), + kResultNullIfNull, "gdv_fn_ilike_utf8_utf8", + NativeFunction::kNeedsFunctionHolder), + + NativeFunction("ltrim", {}, DataTypeVector{utf8(), utf8()}, utf8(), + kResultNullIfNull, "ltrim_utf8_utf8", NativeFunction::kNeedsContext), + + NativeFunction("rtrim", {}, DataTypeVector{utf8(), utf8()}, utf8(), + kResultNullIfNull, "rtrim_utf8_utf8", NativeFunction::kNeedsContext), + + NativeFunction("btrim", {}, DataTypeVector{utf8(), utf8()}, utf8(), + kResultNullIfNull, "btrim_utf8_utf8", NativeFunction::kNeedsContext), + + NativeFunction("substr", {"substring"}, + DataTypeVector{utf8(), int64() /*offset*/, int64() /*length*/}, + utf8(), kResultNullIfNull, "substr_utf8_int64_int64", + NativeFunction::kNeedsContext), + + NativeFunction("substr", {"substring"}, DataTypeVector{utf8(), int64() /*offset*/}, + utf8(), 
kResultNullIfNull, "substr_utf8_int64", + NativeFunction::kNeedsContext), + + NativeFunction("lpad", {}, DataTypeVector{utf8(), int32(), utf8()}, utf8(), + kResultNullIfNull, "lpad_utf8_int32_utf8", + NativeFunction::kNeedsContext), + + NativeFunction("lpad", {}, DataTypeVector{utf8(), int32()}, utf8(), + kResultNullIfNull, "lpad_utf8_int32", NativeFunction::kNeedsContext), + + NativeFunction("rpad", {}, DataTypeVector{utf8(), int32(), utf8()}, utf8(), + kResultNullIfNull, "rpad_utf8_int32_utf8", + NativeFunction::kNeedsContext), + + NativeFunction("rpad", {}, DataTypeVector{utf8(), int32()}, utf8(), + kResultNullIfNull, "rpad_utf8_int32", NativeFunction::kNeedsContext), + + NativeFunction("regexp_replace", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(), + kResultNullIfNull, "gdv_fn_regexp_replace_utf8_utf8", + NativeFunction::kNeedsContext | + NativeFunction::kNeedsFunctionHolder | + NativeFunction::kCanReturnErrors), + + NativeFunction("concatOperator", {}, DataTypeVector{utf8(), utf8()}, utf8(), + kResultNullIfNull, "concatOperator_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concatOperator", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(), + kResultNullIfNull, "concatOperator_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concatOperator", {}, DataTypeVector{utf8(), utf8(), utf8(), utf8()}, + utf8(), kResultNullIfNull, "concatOperator_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concatOperator", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8()}, utf8(), + kResultNullIfNull, "concatOperator_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concatOperator", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8()}, + utf8(), kResultNullIfNull, + "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction( + "concatOperator", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), 
utf8()}, utf8(), + kResultNullIfNull, "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction( + "concatOperator", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8()}, + utf8(), kResultNullIfNull, + "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concatOperator", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), + utf8(), utf8(), utf8()}, + utf8(), kResultNullIfNull, + "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concatOperator", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), + utf8(), utf8(), utf8(), utf8()}, + utf8(), kResultNullIfNull, + "concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + + // concat treats null inputs as empty strings whereas concatOperator returns null if + // one of the inputs is null + NativeFunction("concat", {}, DataTypeVector{utf8(), utf8()}, utf8(), + kResultNullNever, "concat_utf8_utf8", NativeFunction::kNeedsContext), + NativeFunction("concat", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(), + kResultNullNever, "concat_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concat", {}, DataTypeVector{utf8(), utf8(), utf8(), utf8()}, utf8(), + kResultNullNever, "concat_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concat", {}, DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8()}, + utf8(), kResultNullNever, "concat_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concat", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8()}, + utf8(), kResultNullNever, "concat_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction( + "concat", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8()}, utf8(), + 
kResultNullNever, "concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction( + "concat", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), utf8()}, + utf8(), kResultNullNever, "concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concat", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), + utf8(), utf8(), utf8()}, + utf8(), kResultNullNever, + "concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + NativeFunction("concat", {}, + DataTypeVector{utf8(), utf8(), utf8(), utf8(), utf8(), utf8(), + utf8(), utf8(), utf8(), utf8()}, + utf8(), kResultNullNever, + "concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8", + NativeFunction::kNeedsContext), + + NativeFunction("byte_substr", {"bytesubstring"}, + DataTypeVector{binary(), int32(), int32()}, binary(), + kResultNullIfNull, "byte_substr_binary_int32_int32", + NativeFunction::kNeedsContext), + + NativeFunction("convert_fromUTF8", {"convert_fromutf8"}, DataTypeVector{binary()}, + utf8(), kResultNullIfNull, "convert_fromUTF8_binary", + NativeFunction::kNeedsContext), + + NativeFunction("convert_replaceUTF8", {"convert_replaceutf8"}, + DataTypeVector{binary(), utf8()}, utf8(), kResultNullIfNull, + "convert_replace_invalid_fromUTF8_binary", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toDOUBLE", {}, DataTypeVector{float64()}, binary(), + kResultNullIfNull, "convert_toDOUBLE", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toDOUBLE_be", {}, DataTypeVector{float64()}, binary(), + kResultNullIfNull, "convert_toDOUBLE_be", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toFLOAT", {}, DataTypeVector{float32()}, binary(), + kResultNullIfNull, "convert_toFLOAT", NativeFunction::kNeedsContext), + + NativeFunction("convert_toFLOAT_be", {}, DataTypeVector{float32()}, binary(), + kResultNullIfNull, "convert_toFLOAT_be", + 
NativeFunction::kNeedsContext), + + NativeFunction("convert_toINT", {}, DataTypeVector{int32()}, binary(), + kResultNullIfNull, "convert_toINT", NativeFunction::kNeedsContext), + + NativeFunction("convert_toINT_be", {}, DataTypeVector{int32()}, binary(), + kResultNullIfNull, "convert_toINT_be", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toBIGINT", {}, DataTypeVector{int64()}, binary(), + kResultNullIfNull, "convert_toBIGINT", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toBIGINT_be", {}, DataTypeVector{int64()}, binary(), + kResultNullIfNull, "convert_toBIGINT_be", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toBOOLEAN_BYTE", {}, DataTypeVector{boolean()}, binary(), + kResultNullIfNull, "convert_toBOOLEAN", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toTIME_EPOCH", {}, DataTypeVector{time32()}, binary(), + kResultNullIfNull, "convert_toTIME_EPOCH", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toTIME_EPOCH_be", {}, DataTypeVector{time32()}, binary(), + kResultNullIfNull, "convert_toTIME_EPOCH_be", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toTIMESTAMP_EPOCH", {}, DataTypeVector{timestamp()}, + binary(), kResultNullIfNull, "convert_toTIMESTAMP_EPOCH", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toTIMESTAMP_EPOCH_be", {}, DataTypeVector{timestamp()}, + binary(), kResultNullIfNull, "convert_toTIMESTAMP_EPOCH_be", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toDATE_EPOCH", {}, DataTypeVector{date64()}, binary(), + kResultNullIfNull, "convert_toDATE_EPOCH", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toDATE_EPOCH_be", {}, DataTypeVector{date64()}, binary(), + kResultNullIfNull, "convert_toDATE_EPOCH_be", + NativeFunction::kNeedsContext), + + NativeFunction("convert_toUTF8", {}, DataTypeVector{utf8()}, binary(), + kResultNullIfNull, "convert_toUTF8", NativeFunction::kNeedsContext), + + NativeFunction("locate", 
{"position"}, DataTypeVector{utf8(), utf8(), int32()}, + int32(), kResultNullIfNull, "locate_utf8_utf8_int32", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("replace", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(), + kResultNullIfNull, "replace_utf8_utf8_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("binary_string", {}, DataTypeVector{utf8()}, binary(), + kResultNullIfNull, "binary_string", NativeFunction::kNeedsContext), + + NativeFunction("left", {}, DataTypeVector{utf8(), int32()}, utf8(), + kResultNullIfNull, "left_utf8_int32", NativeFunction::kNeedsContext), + + NativeFunction("right", {}, DataTypeVector{utf8(), int32()}, utf8(), + kResultNullIfNull, "right_utf8_int32", + NativeFunction::kNeedsContext), + + NativeFunction("castVARBINARY", {}, DataTypeVector{binary(), int64()}, binary(), + kResultNullIfNull, "castVARBINARY_binary_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARBINARY", {}, DataTypeVector{utf8(), int64()}, binary(), + kResultNullIfNull, "castVARBINARY_utf8_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARBINARY", {}, DataTypeVector{int32(), int64()}, binary(), + kResultNullIfNull, "gdv_fn_castVARBINARY_int32_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARBINARY", {}, DataTypeVector{int64(), int64()}, binary(), + kResultNullIfNull, "gdv_fn_castVARBINARY_int64_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARBINARY", {}, DataTypeVector{float32(), int64()}, binary(), + kResultNullIfNull, "gdv_fn_castVARBINARY_float32_int64", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castVARBINARY", {}, DataTypeVector{float64(), int64()}, binary(), + kResultNullIfNull, "gdv_fn_castVARBINARY_float64_int64", 
+ NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("split_part", {}, DataTypeVector{utf8(), utf8(), int32()}, utf8(), + kResultNullIfNull, "split_part", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)}; + + return string_fn_registry_; +} + +#undef BINARY_RELATIONAL_SAFE_NULL_IF_NULL_FN + +#undef BINARY_RELATIONAL_SAFE_NULL_IF_NULL_UTF8_FN + +#undef UNARY_OCTET_LEN_FN + +#undef UNARY_SAFE_NULL_NEVER_BOOL_FN + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_string.h b/src/arrow/cpp/src/gandiva/function_registry_string.h new file mode 100644 index 000000000..f14c95a81 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_string.h @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <vector> +#include "gandiva/native_function.h" + +namespace gandiva { + +std::vector<NativeFunction> GetStringFunctionRegistry(); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_test.cc b/src/arrow/cpp/src/gandiva/function_registry_test.cc new file mode 100644 index 000000000..e3c1e85f7 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_test.cc @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/function_registry.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include <algorithm> +#include <string> +#include <unordered_set> + +namespace gandiva { + +class TestFunctionRegistry : public ::testing::Test { + protected: + FunctionRegistry registry_; +}; + +TEST_F(TestFunctionRegistry, TestFound) { + FunctionSignature add_i32_i32("add", {arrow::int32(), arrow::int32()}, arrow::int32()); + + const NativeFunction* function = registry_.LookupSignature(add_i32_i32); + EXPECT_NE(function, nullptr); + EXPECT_THAT(function->signatures(), testing::Contains(add_i32_i32)); + EXPECT_EQ(function->pc_name(), "add_int32_int32"); +} + +TEST_F(TestFunctionRegistry, TestNotFound) { + FunctionSignature addX_i32_i32("addX", {arrow::int32(), arrow::int32()}, + arrow::int32()); + EXPECT_EQ(registry_.LookupSignature(addX_i32_i32), nullptr); + + FunctionSignature add_i32_i32_ret64("add", {arrow::int32(), arrow::int32()}, + arrow::int64()); + EXPECT_EQ(registry_.LookupSignature(add_i32_i32_ret64), nullptr); +} + +// one nativefunction object per precompiled function +TEST_F(TestFunctionRegistry, TestNoDuplicates) { + std::unordered_set<std::string> pc_func_sigs; + std::unordered_set<std::string> native_func_duplicates; + std::unordered_set<std::string> func_sigs; + std::unordered_set<std::string> func_sig_duplicates; + for (auto native_func_it = registry_.begin(); native_func_it != registry_.end(); + ++native_func_it) { + auto& first_sig = native_func_it->signatures().front(); + auto pc_func_sig = FunctionSignature(native_func_it->pc_name(), + first_sig.param_types(), first_sig.ret_type()) + .ToString(); + if (pc_func_sigs.count(pc_func_sig) == 0) { + pc_func_sigs.insert(pc_func_sig); + } else { + native_func_duplicates.insert(pc_func_sig); + } + + for (auto& sig : native_func_it->signatures()) { + auto sig_str = sig.ToString(); + if (func_sigs.count(sig_str) == 0) { + func_sigs.insert(sig_str); + } else { + func_sig_duplicates.insert(sig_str); + } + } + 
} + std::ostringstream stream; + std::copy(native_func_duplicates.begin(), native_func_duplicates.end(), + std::ostream_iterator<std::string>(stream, "\n")); + std::string result = stream.str(); + EXPECT_TRUE(native_func_duplicates.empty()) + << "Registry has duplicates.\nMultiple NativeFunction objects refer to the " + "following precompiled functions:\n" + << result; + + stream.clear(); + std::copy(func_sig_duplicates.begin(), func_sig_duplicates.end(), + std::ostream_iterator<std::string>(stream, "\n")); + EXPECT_TRUE(func_sig_duplicates.empty()) + << "The following signatures are defined more than once possibly pointing to " + "different precompiled functions:\n" + << stream.str(); +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.cc b/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.cc new file mode 100644 index 000000000..c277dab72 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.cc @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/function_registry_timestamp_arithmetic.h" + +#include "gandiva/function_registry_common.h" + +namespace gandiva { + +#define TIMESTAMP_ADD_FNS(name, ALIASES) \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, timestamp, timestamp), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, date64, date64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, timestamp, timestamp), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, date64, date64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int32, timestamp), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int32, date64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int64, timestamp), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int64, date64) + +#define TIMESTAMP_DIFF_FN(name, ALIASES) \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, timestamp, int32) + +#define DATE_ADD_FNS(name, ALIASES) \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int32, date64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int32, timestamp), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int64, date64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int64, timestamp), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, date64, date64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int32, timestamp, timestamp), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, date64, date64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, int64, timestamp, timestamp) + +#define DATE_DIFF_FNS(name, ALIASES) \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int32, date64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int32, timestamp), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, date64, int64, date64), \ + BINARY_GENERIC_SAFE_NULL_IF_NULL(name, ALIASES, timestamp, int64, timestamp) + +std::vector<NativeFunction> 
GetDateTimeArithmeticFunctionRegistry() { + static std::vector<NativeFunction> datetime_fn_registry_ = { + BINARY_GENERIC_SAFE_NULL_IF_NULL(months_between, {}, date64, date64, float64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(months_between, {}, timestamp, timestamp, float64), + + TIMESTAMP_DIFF_FN(timestampdiffSecond, {}), + TIMESTAMP_DIFF_FN(timestampdiffMinute, {}), + TIMESTAMP_DIFF_FN(timestampdiffHour, {}), + TIMESTAMP_DIFF_FN(timestampdiffDay, {"datediff"}), + TIMESTAMP_DIFF_FN(timestampdiffWeek, {}), + TIMESTAMP_DIFF_FN(timestampdiffMonth, {}), + TIMESTAMP_DIFF_FN(timestampdiffQuarter, {}), + TIMESTAMP_DIFF_FN(timestampdiffYear, {}), + + TIMESTAMP_ADD_FNS(timestampaddSecond, {}), + TIMESTAMP_ADD_FNS(timestampaddMinute, {}), + TIMESTAMP_ADD_FNS(timestampaddHour, {}), + TIMESTAMP_ADD_FNS(timestampaddDay, {}), + TIMESTAMP_ADD_FNS(timestampaddWeek, {}), + TIMESTAMP_ADD_FNS(timestampaddMonth, {"add_months"}), + TIMESTAMP_ADD_FNS(timestampaddQuarter, {}), + TIMESTAMP_ADD_FNS(timestampaddYear, {}), + + DATE_ADD_FNS(date_add, {}), + DATE_ADD_FNS(add, {}), + + NativeFunction("add", {}, DataTypeVector{date64(), int64()}, timestamp(), + kResultNullIfNull, "add_date64_int64"), + + DATE_DIFF_FNS(date_sub, {}), + DATE_DIFF_FNS(subtract, {}), + DATE_DIFF_FNS(date_diff, {})}; + + return datetime_fn_registry_; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.h b/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.h new file mode 100644 index 000000000..9ac3ab2ec --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_registry_timestamp_arithmetic.h @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <vector> +#include "gandiva/native_function.h" + +namespace gandiva { + +std::vector<NativeFunction> GetDateTimeArithmeticFunctionRegistry(); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_signature.cc b/src/arrow/cpp/src/gandiva/function_signature.cc new file mode 100644 index 000000000..6dc641617 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_signature.cc @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/function_signature.h" + +#include <cstddef> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include "arrow/util/checked_cast.h" +#include "arrow/util/hash_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/string.h" + +using arrow::internal::AsciiEqualsCaseInsensitive; +using arrow::internal::AsciiToLower; +using arrow::internal::checked_cast; +using arrow::internal::hash_combine; + +namespace gandiva { + +bool DataTypeEquals(const DataTypePtr& left, const DataTypePtr& right) { + if (left->id() == right->id()) { + switch (left->id()) { + case arrow::Type::DECIMAL: { + // For decimal types, the precision/scale isn't part of the signature. + auto dleft = checked_cast<arrow::DecimalType*>(left.get()); + auto dright = checked_cast<arrow::DecimalType*>(right.get()); + return (dleft != NULL) && (dright != NULL) && + (dleft->byte_width() == dright->byte_width()); + } + default: + return left->Equals(right); + } + } else { + return false; + } +} + +FunctionSignature::FunctionSignature(std::string base_name, DataTypeVector param_types, + DataTypePtr ret_type) + : base_name_(std::move(base_name)), + param_types_(std::move(param_types)), + ret_type_(std::move(ret_type)) { + DCHECK_GT(base_name_.length(), 0); + for (auto it = param_types_.begin(); it != param_types_.end(); it++) { + DCHECK(*it); + } + DCHECK(ret_type_); +} + +bool FunctionSignature::operator==(const FunctionSignature& other) const { + if (param_types_.size() != other.param_types_.size() || + !DataTypeEquals(ret_type_, other.ret_type_) || + !AsciiEqualsCaseInsensitive(base_name_, other.base_name_)) { + return false; + } + + for (size_t idx = 0; idx < param_types_.size(); idx++) { + if (!DataTypeEquals(param_types_[idx], other.param_types_[idx])) { + return false; + } + } + return true; +} + +/// calculated based on name, datatype id of parameters and datatype id +/// of return type. 
+std::size_t FunctionSignature::Hash() const { + static const size_t kSeedValue = 17; + size_t result = kSeedValue; + hash_combine(result, AsciiToLower(base_name_)); + hash_combine(result, static_cast<size_t>(ret_type_->id())); + // not using hash_range since we only want to include the id from the data type + for (auto& param_type : param_types_) { + hash_combine(result, static_cast<size_t>(param_type->id())); + } + return result; +} + +std::string FunctionSignature::ToString() const { + std::stringstream s; + + s << ret_type_->ToString() << " " << base_name_ << "("; + for (uint32_t i = 0; i < param_types_.size(); i++) { + if (i > 0) { + s << ", "; + } + + s << param_types_[i]->ToString(); + } + + s << ")"; + return s.str(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_signature.h b/src/arrow/cpp/src/gandiva/function_signature.h new file mode 100644 index 000000000..c3e363949 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_signature.h @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <string> +#include <vector> + +#include "gandiva/arrow.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// \brief Signature for a function : includes the base name, input param types and +/// output types. +class GANDIVA_EXPORT FunctionSignature { + public: + FunctionSignature(std::string base_name, DataTypeVector param_types, + DataTypePtr ret_type); + + bool operator==(const FunctionSignature& other) const; + + /// calculated based on name, datatype id of parameters and datatype id + /// of return type. + std::size_t Hash() const; + + DataTypePtr ret_type() const { return ret_type_; } + + const std::string& base_name() const { return base_name_; } + + DataTypeVector param_types() const { return param_types_; } + + std::string ToString() const; + + private: + std::string base_name_; + DataTypeVector param_types_; + DataTypePtr ret_type_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/function_signature_test.cc b/src/arrow/cpp/src/gandiva/function_signature_test.cc new file mode 100644 index 000000000..0eb62d4e7 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/function_signature_test.cc @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/function_signature.h" + +#include <memory> + +#include <gtest/gtest.h> + +namespace gandiva { + +class TestFunctionSignature : public ::testing::Test { + protected: + virtual void SetUp() { + // Use make_shared so these are distinct from the static instances returned + // by e.g. arrow::int32() + local_i32_type_ = std::make_shared<arrow::Int32Type>(); + local_i64_type_ = std::make_shared<arrow::Int64Type>(); + local_date32_type_ = std::make_shared<arrow::Date32Type>(); + } + + virtual void TearDown() { + local_i32_type_.reset(); + local_i64_type_.reset(); + local_date32_type_.reset(); + } + + // virtual void TearDown() {} + DataTypePtr local_i32_type_; + DataTypePtr local_i64_type_; + DataTypePtr local_date32_type_; +}; + +TEST_F(TestFunctionSignature, TestToString) { + EXPECT_EQ( + FunctionSignature("myfunc", {arrow::int32(), arrow::float32()}, arrow::float64()) + .ToString(), + "double myfunc(int32, float)"); +} + +TEST_F(TestFunctionSignature, TestEqualsName) { + EXPECT_EQ(FunctionSignature("add", {arrow::int32()}, arrow::int32()), + FunctionSignature("add", {arrow::int32()}, arrow::int32())); + + EXPECT_EQ(FunctionSignature("add", {arrow::int32()}, arrow::int64()), + FunctionSignature("add", {local_i32_type_}, local_i64_type_)); + + EXPECT_FALSE(FunctionSignature("add", {arrow::int32()}, arrow::int32()) == + FunctionSignature("sub", {arrow::int32()}, arrow::int32())); + + EXPECT_EQ(FunctionSignature("extractDay", {arrow::int64()}, arrow::int64()), + FunctionSignature("extractday", {arrow::int64()}, arrow::int64())); + + EXPECT_EQ( + FunctionSignature("castVARCHAR", {arrow::utf8(), arrow::int64()}, arrow::utf8()), + FunctionSignature("castvarchar", {arrow::utf8(), arrow::int64()}, arrow::utf8())); +} + +TEST_F(TestFunctionSignature, TestEqualsParamCount) { + EXPECT_FALSE( + FunctionSignature("add", {arrow::int32(), arrow::int32()}, arrow::int32()) == + FunctionSignature("add", {arrow::int32()}, arrow::int32())); +} + 
+TEST_F(TestFunctionSignature, TestEqualsParamValue) { + EXPECT_FALSE(FunctionSignature("add", {arrow::int32()}, arrow::int32()) == + FunctionSignature("add", {arrow::int64()}, arrow::int32())); + + EXPECT_FALSE( + FunctionSignature("add", {arrow::int32()}, arrow::int32()) == + FunctionSignature("add", {arrow::float32(), arrow::float32()}, arrow::int32())); + + EXPECT_FALSE( + FunctionSignature("add", {arrow::int32(), arrow::int64()}, arrow::int32()) == + FunctionSignature("add", {arrow::int64(), arrow::int32()}, arrow::int32())); + + EXPECT_EQ(FunctionSignature("extract_month", {arrow::date32()}, arrow::int64()), + FunctionSignature("extract_month", {local_date32_type_}, local_i64_type_)); + + EXPECT_FALSE(FunctionSignature("extract_month", {arrow::date32()}, arrow::int64()) == + FunctionSignature("extract_month", {arrow::date64()}, arrow::date32())); +} + +TEST_F(TestFunctionSignature, TestEqualsReturn) { + EXPECT_FALSE(FunctionSignature("add", {arrow::int32()}, arrow::int64()) == + FunctionSignature("add", {arrow::int32()}, arrow::int32())); +} + +TEST_F(TestFunctionSignature, TestHash) { + FunctionSignature f1("add", {arrow::int32(), arrow::int32()}, arrow::int64()); + FunctionSignature f2("add", {local_i32_type_, local_i32_type_}, local_i64_type_); + EXPECT_EQ(f1.Hash(), f2.Hash()); + + FunctionSignature f3("extractDay", {arrow::int64()}, arrow::int64()); + FunctionSignature f4("extractday", {arrow::int64()}, arrow::int64()); + EXPECT_EQ(f3.Hash(), f4.Hash()); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/gandiva.pc.in b/src/arrow/cpp/src/gandiva/gandiva.pc.in new file mode 100644 index 000000000..22ff11a4f --- /dev/null +++ b/src/arrow/cpp/src/gandiva/gandiva.pc.in @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +prefix=@CMAKE_INSTALL_PREFIX@ +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/@CMAKE_INSTALL_INCLUDEDIR@ + +Name: Gandiva +Description: Gandiva is a toolset for compiling and evaluating expressions on Arrow data. +Version: @GANDIVA_VERSION@ +Requires: arrow +Libs: -L${libdir} -lgandiva +Cflags: -I${includedir} diff --git a/src/arrow/cpp/src/gandiva/gandiva_aliases.h b/src/arrow/cpp/src/gandiva/gandiva_aliases.h new file mode 100644 index 000000000..6cbb671ff --- /dev/null +++ b/src/arrow/cpp/src/gandiva/gandiva_aliases.h @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <memory> +#include <string> +#include <unordered_set> +#include <vector> + +namespace gandiva { + +class Dex; +using DexPtr = std::shared_ptr<Dex>; +using DexVector = std::vector<std::shared_ptr<Dex>>; + +class ValueValidityPair; +using ValueValidityPairPtr = std::shared_ptr<ValueValidityPair>; +using ValueValidityPairVector = std::vector<ValueValidityPairPtr>; + +class FieldDescriptor; +using FieldDescriptorPtr = std::shared_ptr<FieldDescriptor>; + +class FuncDescriptor; +using FuncDescriptorPtr = std::shared_ptr<FuncDescriptor>; + +class LValue; +using LValuePtr = std::shared_ptr<LValue>; + +class Expression; +using ExpressionPtr = std::shared_ptr<Expression>; +using ExpressionVector = std::vector<ExpressionPtr>; + +class Condition; +using ConditionPtr = std::shared_ptr<Condition>; + +class Node; +using NodePtr = std::shared_ptr<Node>; +using NodeVector = std::vector<std::shared_ptr<Node>>; + +class EvalBatch; +using EvalBatchPtr = std::shared_ptr<EvalBatch>; + +class FunctionSignature; +using FuncSignaturePtr = std::shared_ptr<FunctionSignature>; +using FuncSignatureVector = std::vector<FuncSignaturePtr>; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/gdv_function_stubs.cc b/src/arrow/cpp/src/gandiva/gdv_function_stubs.cc new file mode 100644 index 000000000..ed34eef4a --- /dev/null +++ b/src/arrow/cpp/src/gandiva/gdv_function_stubs.cc @@ -0,0 +1,1603 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/gdv_function_stubs.h" + +#include <utf8proc.h> + +#include <string> +#include <vector> + +#include "arrow/util/base64.h" +#include "arrow/util/double_conversion.h" +#include "arrow/util/formatting.h" +#include "arrow/util/string_view.h" +#include "arrow/util/utf8.h" +#include "arrow/util/value_parsing.h" +#include "gandiva/engine.h" +#include "gandiva/exported_funcs.h" +#include "gandiva/formatting_utils.h" +#include "gandiva/hash_utils.h" +#include "gandiva/in_holder.h" +#include "gandiva/like_holder.h" +#include "gandiva/precompiled/types.h" +#include "gandiva/random_generator_holder.h" +#include "gandiva/replace_holder.h" +#include "gandiva/to_date_holder.h" + +/// Stub functions that can be accessed from LLVM or the pre-compiled library. 
+ +extern "C" { + +bool gdv_fn_like_utf8_utf8(int64_t ptr, const char* data, int data_len, + const char* pattern, int pattern_len) { + gandiva::LikeHolder* holder = reinterpret_cast<gandiva::LikeHolder*>(ptr); + return (*holder)(std::string(data, data_len)); +} + +bool gdv_fn_like_utf8_utf8_utf8(int64_t ptr, const char* data, int data_len, + const char* pattern, int pattern_len, + const char* escape_char, int escape_char_len) { + gandiva::LikeHolder* holder = reinterpret_cast<gandiva::LikeHolder*>(ptr); + return (*holder)(std::string(data, data_len)); +} + +bool gdv_fn_ilike_utf8_utf8(int64_t ptr, const char* data, int data_len, + const char* pattern, int pattern_len) { + gandiva::LikeHolder* holder = reinterpret_cast<gandiva::LikeHolder*>(ptr); + return (*holder)(std::string(data, data_len)); +} + +const char* gdv_fn_regexp_replace_utf8_utf8( + int64_t ptr, int64_t holder_ptr, const char* data, int32_t data_len, + const char* /*pattern*/, int32_t /*pattern_len*/, const char* replace_string, + int32_t replace_string_len, int32_t* out_length) { + gandiva::ExecutionContext* context = reinterpret_cast<gandiva::ExecutionContext*>(ptr); + + gandiva::ReplaceHolder* holder = reinterpret_cast<gandiva::ReplaceHolder*>(holder_ptr); + + return (*holder)(context, data, data_len, replace_string, replace_string_len, + out_length); +} + +double gdv_fn_random(int64_t ptr) { + gandiva::RandomGeneratorHolder* holder = + reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr); + return (*holder)(); +} + +double gdv_fn_random_with_seed(int64_t ptr, int32_t seed, bool seed_validity) { + gandiva::RandomGeneratorHolder* holder = + reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr); + return (*holder)(); +} + +int64_t gdv_fn_to_date_utf8_utf8(int64_t context_ptr, int64_t holder_ptr, + const char* data, int data_len, bool in1_validity, + const char* pattern, int pattern_len, bool in2_validity, + bool* out_valid) { + gandiva::ExecutionContext* context = + 
reinterpret_cast<gandiva::ExecutionContext*>(context_ptr); + gandiva::ToDateHolder* holder = reinterpret_cast<gandiva::ToDateHolder*>(holder_ptr); + return (*holder)(context, data, data_len, in1_validity, out_valid); +} + +int64_t gdv_fn_to_date_utf8_utf8_int32(int64_t context_ptr, int64_t holder_ptr, + const char* data, int data_len, bool in1_validity, + const char* pattern, int pattern_len, + bool in2_validity, int32_t suppress_errors, + bool in3_validity, bool* out_valid) { + gandiva::ExecutionContext* context = + reinterpret_cast<gandiva::ExecutionContext*>(context_ptr); + gandiva::ToDateHolder* holder = reinterpret_cast<gandiva::ToDateHolder*>(holder_ptr); + return (*holder)(context, data, data_len, in1_validity, out_valid); +} + +bool gdv_fn_in_expr_lookup_int32(int64_t ptr, int32_t value, bool in_validity) { + if (!in_validity) { + return false; + } + gandiva::InHolder<int32_t>* holder = reinterpret_cast<gandiva::InHolder<int32_t>*>(ptr); + return holder->HasValue(value); +} + +bool gdv_fn_in_expr_lookup_int64(int64_t ptr, int64_t value, bool in_validity) { + if (!in_validity) { + return false; + } + gandiva::InHolder<int64_t>* holder = reinterpret_cast<gandiva::InHolder<int64_t>*>(ptr); + return holder->HasValue(value); +} + +bool gdv_fn_in_expr_lookup_decimal(int64_t ptr, int64_t value_high, int64_t value_low, + int32_t precision, int32_t scale, bool in_validity) { + if (!in_validity) { + return false; + } + gandiva::DecimalScalar128 value(value_high, value_low, precision, scale); + gandiva::InHolder<gandiva::DecimalScalar128>* holder = + reinterpret_cast<gandiva::InHolder<gandiva::DecimalScalar128>*>(ptr); + return holder->HasValue(value); +} + +bool gdv_fn_in_expr_lookup_float(int64_t ptr, float value, bool in_validity) { + if (!in_validity) { + return false; + } + gandiva::InHolder<float>* holder = reinterpret_cast<gandiva::InHolder<float>*>(ptr); + return holder->HasValue(value); +} + +bool gdv_fn_in_expr_lookup_double(int64_t ptr, double value, bool 
in_validity) { + if (!in_validity) { + return false; + } + gandiva::InHolder<double>* holder = reinterpret_cast<gandiva::InHolder<double>*>(ptr); + return holder->HasValue(value); +} + +bool gdv_fn_in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len, + bool in_validity) { + if (!in_validity) { + return false; + } + gandiva::InHolder<std::string>* holder = + reinterpret_cast<gandiva::InHolder<std::string>*>(ptr); + return holder->HasValue(arrow::util::string_view(data, data_len)); +} + +int32_t gdv_fn_populate_varlen_vector(int64_t context_ptr, int8_t* data_ptr, + int32_t* offsets, int64_t slot, + const char* entry_buf, int32_t entry_len) { + auto buffer = reinterpret_cast<arrow::ResizableBuffer*>(data_ptr); + int32_t offset = static_cast<int32_t>(buffer->size()); + + // This also sets the size in the buffer. + auto status = buffer->Resize(offset + entry_len, false /*shrink*/); + if (!status.ok()) { + gandiva::ExecutionContext* context = + reinterpret_cast<gandiva::ExecutionContext*>(context_ptr); + + context->set_error_msg(status.message().c_str()); + return -1; + } + + // append the new entry. + memcpy(buffer->mutable_data() + offset, entry_buf, entry_len); + + // update offsets buffer. 
+ offsets[slot] = offset; + offsets[slot + 1] = offset + entry_len; + return 0; +} + +#define SHA1_HASH_FUNCTION(TYPE) \ + GANDIVA_EXPORT \ + const char* gdv_fn_sha1_##TYPE(int64_t context, gdv_##TYPE value, bool validity, \ + int32_t* out_length) { \ + if (!validity) { \ + return gandiva::gdv_hash_using_sha1(context, NULLPTR, 0, out_length); \ + } \ + auto value_as_long = gandiva::gdv_double_to_long((double)value); \ + const char* result = gandiva::gdv_hash_using_sha1( \ + context, &value_as_long, sizeof(value_as_long), out_length); \ + \ + return result; \ + } + +#define SHA1_HASH_FUNCTION_BUF(TYPE) \ + GANDIVA_EXPORT \ + const char* gdv_fn_sha1_##TYPE(int64_t context, gdv_##TYPE value, \ + int32_t value_length, bool value_validity, \ + int32_t* out_length) { \ + if (!value_validity) { \ + return gandiva::gdv_hash_using_sha1(context, NULLPTR, 0, out_length); \ + } \ + return gandiva::gdv_hash_using_sha1(context, value, value_length, out_length); \ + } + +#define SHA256_HASH_FUNCTION(TYPE) \ + GANDIVA_EXPORT \ + const char* gdv_fn_sha256_##TYPE(int64_t context, gdv_##TYPE value, bool validity, \ + int32_t* out_length) { \ + if (!validity) { \ + return gandiva::gdv_hash_using_sha256(context, NULLPTR, 0, out_length); \ + } \ + auto value_as_long = gandiva::gdv_double_to_long((double)value); \ + const char* result = gandiva::gdv_hash_using_sha256( \ + context, &value_as_long, sizeof(value_as_long), out_length); \ + return result; \ + } + +#define SHA256_HASH_FUNCTION_BUF(TYPE) \ + GANDIVA_EXPORT \ + const char* gdv_fn_sha256_##TYPE(int64_t context, gdv_##TYPE value, \ + int32_t value_length, bool value_validity, \ + int32_t* out_length) { \ + if (!value_validity) { \ + return gandiva::gdv_hash_using_sha256(context, NULLPTR, 0, out_length); \ + } \ + \ + return gandiva::gdv_hash_using_sha256(context, value, value_length, out_length); \ + } + +// Expand inner macro for all numeric types. 
+#define SHA_NUMERIC_BOOL_DATE_PARAMS(INNER) \ + INNER(int8) \ + INNER(int16) \ + INNER(int32) \ + INNER(int64) \ + INNER(uint8) \ + INNER(uint16) \ + INNER(uint32) \ + INNER(uint64) \ + INNER(float32) \ + INNER(float64) \ + INNER(boolean) \ + INNER(date64) \ + INNER(date32) \ + INNER(time32) \ + INNER(timestamp) + +// Expand inner macro for all numeric types. +#define SHA_VAR_LEN_PARAMS(INNER) \ + INNER(utf8) \ + INNER(binary) + +SHA_NUMERIC_BOOL_DATE_PARAMS(SHA256_HASH_FUNCTION) +SHA_VAR_LEN_PARAMS(SHA256_HASH_FUNCTION_BUF) + +SHA_NUMERIC_BOOL_DATE_PARAMS(SHA1_HASH_FUNCTION) +SHA_VAR_LEN_PARAMS(SHA1_HASH_FUNCTION_BUF) + +#undef SHA_NUMERIC_BOOL_DATE_PARAMS +#undef SHA_VAR_LEN_PARAMS + +// Add functions for decimal128 +GANDIVA_EXPORT +const char* gdv_fn_sha256_decimal128(int64_t context, int64_t x_high, uint64_t x_low, + int32_t /*x_precision*/, int32_t /*x_scale*/, + gdv_boolean x_isvalid, int32_t* out_length) { + if (!x_isvalid) { + return gandiva::gdv_hash_using_sha256(context, NULLPTR, 0, out_length); + } + + const gandiva::BasicDecimal128 decimal_128(x_high, x_low); + return gandiva::gdv_hash_using_sha256(context, decimal_128.ToBytes().data(), 16, + out_length); +} + +GANDIVA_EXPORT +const char* gdv_fn_sha1_decimal128(int64_t context, int64_t x_high, uint64_t x_low, + int32_t /*x_precision*/, int32_t /*x_scale*/, + gdv_boolean x_isvalid, int32_t* out_length) { + if (!x_isvalid) { + return gandiva::gdv_hash_using_sha1(context, NULLPTR, 0, out_length); + } + + const gandiva::BasicDecimal128 decimal_128(x_high, x_low); + return gandiva::gdv_hash_using_sha1(context, decimal_128.ToBytes().data(), 16, + out_length); +} + +int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_length, + int32_t* precision_from_str, int32_t* scale_from_str, + int64_t* dec_high_from_str, uint64_t* dec_low_from_str) { + arrow::Decimal128 dec; + auto status = arrow::Decimal128::FromString(std::string(in, in_length), &dec, + precision_from_str, scale_from_str); + if 
(!status.ok()) { + gdv_fn_context_set_error_msg(context, status.message().data()); + return -1; + } + *dec_high_from_str = dec.high_bits(); + *dec_low_from_str = dec.low_bits(); + return 0; +} + +char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low, + int32_t x_scale, int32_t* dec_str_len) { + arrow::Decimal128 dec(arrow::BasicDecimal128(x_high, x_low)); + std::string dec_str = dec.ToString(x_scale); + *dec_str_len = static_cast<int32_t>(dec_str.length()); + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *dec_str_len)); + if (ret == nullptr) { + std::string err_msg = "Could not allocate memory for string: " + dec_str; + gdv_fn_context_set_error_msg(context, err_msg.data()); + return nullptr; + } + memcpy(ret, dec_str.data(), *dec_str_len); + return ret; +} + +GANDIVA_EXPORT +const char* gdv_fn_base64_encode_binary(int64_t context, const char* in, int32_t in_len, + int32_t* out_len) { + if (in_len < 0) { + gdv_fn_context_set_error_msg(context, "Buffer length can not be negative"); + *out_len = 0; + return ""; + } + if (in_len == 0) { + *out_len = 0; + return ""; + } + // use arrow method to encode base64 string + std::string encoded_str = + arrow::util::base64_encode(arrow::util::string_view(in, in_len)); + *out_len = static_cast<int32_t>(encoded_str.length()); + // allocate memory for response + char* ret = reinterpret_cast<char*>( + gdv_fn_context_arena_malloc(context, static_cast<int32_t>(*out_len))); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory"); + *out_len = 0; + return ""; + } + memcpy(ret, encoded_str.data(), *out_len); + return ret; +} + +GANDIVA_EXPORT +const char* gdv_fn_base64_decode_utf8(int64_t context, const char* in, int32_t in_len, + int32_t* out_len) { + if (in_len < 0) { + gdv_fn_context_set_error_msg(context, "Buffer length can not be negative"); + *out_len = 0; + return ""; + } + if (in_len == 0) { + *out_len = 0; + return ""; + } + // use arrow method 
to decode base64 string + std::string decoded_str = + arrow::util::base64_decode(arrow::util::string_view(in, in_len)); + *out_len = static_cast<int32_t>(decoded_str.length()); + // allocate memory for response + char* ret = reinterpret_cast<char*>( + gdv_fn_context_arena_malloc(context, static_cast<int32_t>(*out_len))); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory"); + *out_len = 0; + return ""; + } + memcpy(ret, decoded_str.data(), *out_len); + return ret; +} + +#define CAST_NUMERIC_FROM_VARLEN_TYPES(OUT_TYPE, ARROW_TYPE, TYPE_NAME, INNER_TYPE) \ + GANDIVA_EXPORT \ + OUT_TYPE gdv_fn_cast##TYPE_NAME##_##INNER_TYPE(int64_t context, const char* data, \ + int32_t len) { \ + OUT_TYPE val = 0; \ + /* trim leading and trailing spaces */ \ + int32_t trimmed_len; \ + int32_t start = 0, end = len - 1; \ + while (start <= end && data[start] == ' ') { \ + ++start; \ + } \ + while (end >= start && data[end] == ' ') { \ + --end; \ + } \ + trimmed_len = end - start + 1; \ + const char* trimmed_data = data + start; \ + if (!arrow::internal::ParseValue<ARROW_TYPE>(trimmed_data, trimmed_len, &val)) { \ + std::string err = \ + "Failed to cast the string " + std::string(data, len) + " to " #OUT_TYPE; \ + gdv_fn_context_set_error_msg(context, err.c_str()); \ + } \ + return val; \ + } + +#define CAST_NUMERIC_FROM_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \ + CAST_NUMERIC_FROM_VARLEN_TYPES(OUT_TYPE, ARROW_TYPE, TYPE_NAME, utf8) + +CAST_NUMERIC_FROM_STRING(int32_t, arrow::Int32Type, INT) +CAST_NUMERIC_FROM_STRING(int64_t, arrow::Int64Type, BIGINT) +CAST_NUMERIC_FROM_STRING(float, arrow::FloatType, FLOAT4) +CAST_NUMERIC_FROM_STRING(double, arrow::DoubleType, FLOAT8) + +#undef CAST_NUMERIC_FROM_STRING + +#define CAST_NUMERIC_FROM_VARBINARY(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \ + CAST_NUMERIC_FROM_VARLEN_TYPES(OUT_TYPE, ARROW_TYPE, TYPE_NAME, varbinary) + +CAST_NUMERIC_FROM_VARBINARY(int32_t, arrow::Int32Type, INT) 
+CAST_NUMERIC_FROM_VARBINARY(int64_t, arrow::Int64Type, BIGINT) +CAST_NUMERIC_FROM_VARBINARY(float, arrow::FloatType, FLOAT4) +CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8) + +#undef CAST_NUMERIC_STRING + +#define GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER(IN_TYPE, CAST_NAME, ARROW_TYPE) \ + GANDIVA_EXPORT \ + const char* gdv_fn_cast##CAST_NAME##_##IN_TYPE##_int64( \ + int64_t context, gdv_##IN_TYPE value, int64_t len, int32_t * out_len) { \ + if (len < 0) { \ + gdv_fn_context_set_error_msg(context, "Buffer length can not be negative"); \ + *out_len = 0; \ + return ""; \ + } \ + if (len == 0) { \ + *out_len = 0; \ + return ""; \ + } \ + arrow::internal::StringFormatter<arrow::ARROW_TYPE> formatter; \ + char* ret = reinterpret_cast<char*>( \ + gdv_fn_context_arena_malloc(context, static_cast<int32_t>(len))); \ + if (ret == nullptr) { \ + gdv_fn_context_set_error_msg(context, "Could not allocate memory"); \ + *out_len = 0; \ + return ""; \ + } \ + arrow::Status status = formatter(value, [&](arrow::util::string_view v) { \ + int64_t size = static_cast<int64_t>(v.size()); \ + *out_len = static_cast<int32_t>(len < size ? 
len : size); \ + memcpy(ret, v.data(), *out_len); \ + return arrow::Status::OK(); \ + }); \ + if (!status.ok()) { \ + std::string err = "Could not cast " + std::to_string(value) + " to string"; \ + gdv_fn_context_set_error_msg(context, err.c_str()); \ + *out_len = 0; \ + return ""; \ + } \ + return ret; \ + } + +#define GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(IN_TYPE, CAST_NAME, ARROW_TYPE) \ + GANDIVA_EXPORT \ + const char* gdv_fn_cast##CAST_NAME##_##IN_TYPE##_int64( \ + int64_t context, gdv_##IN_TYPE value, int64_t len, int32_t * out_len) { \ + if (len < 0) { \ + gdv_fn_context_set_error_msg(context, "Buffer length can not be negative"); \ + *out_len = 0; \ + return ""; \ + } \ + if (len == 0) { \ + *out_len = 0; \ + return ""; \ + } \ + gandiva::GdvStringFormatter<arrow::ARROW_TYPE> formatter; \ + char* ret = reinterpret_cast<char*>( \ + gdv_fn_context_arena_malloc(context, static_cast<int32_t>(len))); \ + if (ret == nullptr) { \ + gdv_fn_context_set_error_msg(context, "Could not allocate memory"); \ + *out_len = 0; \ + return ""; \ + } \ + arrow::Status status = formatter(value, [&](arrow::util::string_view v) { \ + int64_t size = static_cast<int64_t>(v.size()); \ + *out_len = static_cast<int32_t>(len < size ? 
len : size); \ + memcpy(ret, v.data(), *out_len); \ + return arrow::Status::OK(); \ + }); \ + if (!status.ok()) { \ + std::string err = "Could not cast " + std::to_string(value) + " to string"; \ + gdv_fn_context_set_error_msg(context, err.c_str()); \ + *out_len = 0; \ + return ""; \ + } \ + return ret; \ + } + +#define CAST_VARLEN_TYPE_FROM_NUMERIC(VARLEN_TYPE) \ + GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER(int32, VARLEN_TYPE, Int32Type) \ + GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER(int64, VARLEN_TYPE, Int64Type) \ + GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(float32, VARLEN_TYPE, FloatType) \ + GDV_FN_CAST_VARLEN_TYPE_FROM_REAL(float64, VARLEN_TYPE, DoubleType) + +CAST_VARLEN_TYPE_FROM_NUMERIC(VARCHAR) +CAST_VARLEN_TYPE_FROM_NUMERIC(VARBINARY) + +#undef CAST_VARLEN_TYPE_FROM_NUMERIC +#undef GDV_FN_CAST_VARLEN_TYPE_FROM_INTEGER +#undef GDV_FN_CAST_VARLEN_TYPE_FROM_REAL +#undef GDV_FN_CAST_VARCHAR_INTEGER +#undef GDV_FN_CAST_VARCHAR_REAL + +GDV_FORCE_INLINE +int32_t gdv_fn_utf8_char_length(char c) { + if ((signed char)c >= 0) { // 1-byte char (0x00 ~ 0x7F) + return 1; + } else if ((c & 0xE0) == 0xC0) { // 2-byte char + return 2; + } else if ((c & 0xF0) == 0xE0) { // 3-byte char + return 3; + } else if ((c & 0xF8) == 0xF0) { // 4-byte char + return 4; + } + // invalid char + return 0; +} + +GDV_FORCE_INLINE +void gdv_fn_set_error_for_invalid_utf8(int64_t execution_context, char val) { + char const* fmt = "unexpected byte \\%02hhx encountered while decoding utf8 string"; + int size = static_cast<int>(strlen(fmt)) + 64; + char* error = reinterpret_cast<char*>(malloc(size)); + snprintf(error, size, fmt, (unsigned char)val); + gdv_fn_context_set_error_msg(execution_context, error); + free(error); +} + +// Convert an utf8 string to its corresponding uppercase string +GANDIVA_EXPORT +const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_len, + int32_t* out_len) { + if (data_len == 0) { + *out_len = 0; + return ""; + } + + // If it is a single-byte character (ASCII), 
corresponding uppercase is always 1-byte + // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of + // the output can be at most twice the length of the input + char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len)); + if (out == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + + int32_t char_len, out_char_len, out_idx = 0; + uint32_t char_codepoint; + + for (int32_t i = 0; i < data_len; i += char_len) { + char_len = gdv_fn_utf8_char_length(data[i]); + // For single byte characters: + // If it is a lowercase ASCII character, set the output to its corresponding uppercase + // character; else, set the output to the read character + if (char_len == 1) { + char cur = data[i]; + // 'A' - 'Z' : 0x41 - 0x5a + // 'a' - 'z' : 0x61 - 0x7a + if (cur >= 0x61 && cur <= 0x7a) { + out[out_idx++] = static_cast<char>(cur - 0x20); + } else { + out[out_idx++] = cur; + } + continue; + } + + // Control reaches here when we encounter a multibyte character + const auto* in_char = (const uint8_t*)(data + i); + + // Decode the multibyte character + bool is_valid_utf8_char = + arrow::util::UTF8Decode((const uint8_t**)&in_char, &char_codepoint); + + // If it is an invalid utf8 character, UTF8Decode evaluates to false + if (!is_valid_utf8_char) { + gdv_fn_set_error_for_invalid_utf8(context, data[i]); + *out_len = 0; + return ""; + } + + // Convert the encoded codepoint to its uppercase codepoint + int32_t upper_codepoint = utf8proc_toupper(char_codepoint); + + // UTF8Encode advances the pointer by the number of bytes present in the uppercase + // character + auto* out_char = (uint8_t*)(out + out_idx); + uint8_t* out_char_start = out_char; + + // Encode the uppercase character + out_char = arrow::util::UTF8Encode(out_char, upper_codepoint); + + out_char_len = static_cast<int32_t>(out_char - out_char_start); + out_idx += out_char_len; + } + + 
*out_len = out_idx; + return out; +} + +// Convert an utf8 string to its corresponding lowercase string +GANDIVA_EXPORT +const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len, + int32_t* out_len) { + if (data_len == 0) { + *out_len = 0; + return ""; + } + + // If it is a single-byte character (ASCII), corresponding lowercase is always 1-byte + // long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so length of + // the output can be at most twice the length of the input + char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len)); + if (out == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + + int32_t char_len, out_char_len, out_idx = 0; + uint32_t char_codepoint; + + for (int32_t i = 0; i < data_len; i += char_len) { + char_len = gdv_fn_utf8_char_length(data[i]); + // For single byte characters: + // If it is an uppercase ASCII character, set the output to its corresponding + // lowercase character; else, set the output to the read character + if (char_len == 1) { + char cur = data[i]; + // 'A' - 'Z' : 0x41 - 0x5a + // 'a' - 'z' : 0x61 - 0x7a + if (cur >= 0x41 && cur <= 0x5a) { + out[out_idx++] = static_cast<char>(cur + 0x20); + } else { + out[out_idx++] = cur; + } + continue; + } + + // Control reaches here when we encounter a multibyte character + const auto* in_char = (const uint8_t*)(data + i); + + // Decode the multibyte character + bool is_valid_utf8_char = + arrow::util::UTF8Decode((const uint8_t**)&in_char, &char_codepoint); + + // If it is an invalid utf8 character, UTF8Decode evaluates to false + if (!is_valid_utf8_char) { + gdv_fn_set_error_for_invalid_utf8(context, data[i]); + *out_len = 0; + return ""; + } + + // Convert the encoded codepoint to its lowercase codepoint + int32_t lower_codepoint = utf8proc_tolower(char_codepoint); + + // UTF8Encode advances the pointer by the number of bytes 
present in the lowercase + // character + auto* out_char = (uint8_t*)(out + out_idx); + uint8_t* out_char_start = out_char; + + // Encode the lowercase character + out_char = arrow::util::UTF8Encode(out_char, lower_codepoint); + + out_char_len = static_cast<int32_t>(out_char - out_char_start); + out_idx += out_char_len; + } + + *out_len = out_idx; + return out; +} + +// Any codepoint, except the ones for lowercase letters, uppercase letters, +// titlecase letters, decimal digits and letter numbers categories will be +// considered as word separators. +// +// The Unicode characters also are divided between categories. This link +// https://www.compart.com/en/unicode/category shows +// more information about characters categories. +GDV_FORCE_INLINE +bool gdv_fn_is_codepoint_for_space(uint32_t val) { + auto category = utf8proc_category(val); + + return category != utf8proc_category_t::UTF8PROC_CATEGORY_LU && + category != utf8proc_category_t::UTF8PROC_CATEGORY_LL && + category != utf8proc_category_t::UTF8PROC_CATEGORY_LT && + category != utf8proc_category_t::UTF8PROC_CATEGORY_NL && + category != utf8proc_category_t ::UTF8PROC_CATEGORY_ND; +} + +// For a given text, initialize the first letter after a word-separator and lowercase +// the others e.g: +// - "IT is a tEXt str" -> "It Is A Text Str" +GANDIVA_EXPORT +const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_len, + int32_t* out_len) { + if (data_len == 0) { + *out_len = data_len; + return ""; + } + + // If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte + // long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of + // the output can be at most twice the length of the input + char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len)); + if (out == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + + int32_t char_len = 
0; + int32_t out_char_len = 0; + int32_t out_idx = 0; + uint32_t char_codepoint; + + // Any character is considered as space, except if it is alphanumeric + bool last_char_was_space = true; + + for (int32_t i = 0; i < data_len; i += char_len) { + // An optimization for single byte characters: + if (static_cast<signed char>(data[i]) >= 0) { // 1-byte char (0x00 ~ 0x7F) + char_len = 1; + char cur = data[i]; + + if (cur >= 0x61 && cur <= 0x7a && last_char_was_space) { + // Check if the character is the first one of the word and it is + // lowercase -> 'a' - 'z' : 0x61 - 0x7a. + // Then turn it into uppercase -> 'A' - 'Z' : 0x41 - 0x5a + out[out_idx++] = static_cast<char>(cur - 0x20); + last_char_was_space = false; + } else if (cur >= 0x41 && cur <= 0x5a && !last_char_was_space) { + out[out_idx++] = static_cast<char>(cur + 0x20); + } else { + // Check if the ASCII character is not an alphanumeric character: + // '0' - '9': 0x30 - 0x39 + // 'a' - 'z' : 0x61 - 0x7a + // 'A' - 'Z' : 0x41 - 0x5a + last_char_was_space = (cur < 0x30) || (cur > 0x39 && cur < 0x41) || + (cur > 0x5a && cur < 0x61) || (cur > 0x7a); + out[out_idx++] = cur; + } + continue; + } + + char_len = gdv_fn_utf8_char_length(data[i]); + + // Control reaches here when we encounter a multibyte character + const auto* in_char = (const uint8_t*)(data + i); + + // Decode the multibyte character + bool is_valid_utf8_char = + arrow::util::UTF8Decode((const uint8_t**)&in_char, &char_codepoint); + + // If it is an invalid utf8 character, UTF8Decode evaluates to false + if (!is_valid_utf8_char) { + gdv_fn_set_error_for_invalid_utf8(context, data[i]); + *out_len = 0; + return ""; + } + + bool is_char_space = gdv_fn_is_codepoint_for_space(char_codepoint); + + int32_t formatted_codepoint; + if (last_char_was_space && !is_char_space) { + formatted_codepoint = utf8proc_toupper(char_codepoint); + } else { + formatted_codepoint = utf8proc_tolower(char_codepoint); + } + + // UTF8Encode advances the pointer by the number of 
bytes present in the character + auto* out_char = (uint8_t*)(out + out_idx); + uint8_t* out_char_start = out_char; + + // Encode the character + out_char = arrow::util::UTF8Encode(out_char, formatted_codepoint); + + out_char_len = static_cast<int32_t>(out_char - out_char_start); + out_idx += out_char_len; + + last_char_was_space = is_char_space; + } + + *out_len = out_idx; + return out; +} +} + +namespace gandiva { + +void ExportedStubFunctions::AddMappings(Engine* engine) const { + std::vector<llvm::Type*> args; + auto types = engine->types(); + + // gdv_fn_castVARBINARY_int32 + args = { + types->i64_type(), // context + types->i32_type(), // int32_t value + types->i64_type(), // int64_t out value length + types->i32_ptr_type() // int32_t out_length + }; + + engine->AddGlobalMappingForFunc( + "gdv_fn_castVARBINARY_int32_int64", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_castVARBINARY_int32_int64)); + + // gdv_fn_castVARBINARY_int64 + args = { + types->i64_type(), // context + types->i64_type(), // int64_t value + types->i64_type(), // int64_t out value length + types->i32_ptr_type() // int32_t out_length + }; + + engine->AddGlobalMappingForFunc( + "gdv_fn_castVARBINARY_int64_int64", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_castVARBINARY_int64_int64)); + + // gdv_fn_castVARBINARY_float32 + args = { + types->i64_type(), // context + types->float_type(), // float value + types->i64_type(), // int64_t out value length + types->i64_ptr_type() // int32_t out_length + }; + + engine->AddGlobalMappingForFunc( + "gdv_fn_castVARBINARY_float32_int64", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_castVARBINARY_float32_int64)); + + // gdv_fn_castVARBINARY_float64 + args = { + types->i64_type(), // context + types->i64_type(), // double value + types->i64_type(), // int64_t out value length + types->i32_ptr_type() // int32_t out_length + }; + + engine->AddGlobalMappingForFunc( + 
"gdv_fn_castVARBINARY_float64_int64", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_castVARBINARY_float64_int64)); + + // gdv_fn_dec_from_string + args = { + types->i64_type(), // context + types->i8_ptr_type(), // const char* in + types->i32_type(), // int32_t in_length + types->i32_ptr_type(), // int32_t* precision_from_str + types->i32_ptr_type(), // int32_t* scale_from_str + types->i64_ptr_type(), // int64_t* dec_high_from_str + types->i64_ptr_type(), // int64_t* dec_low_from_str + }; + + engine->AddGlobalMappingForFunc("gdv_fn_dec_from_string", + types->i32_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_dec_from_string)); + + // gdv_fn_dec_to_string + args = { + types->i64_type(), // context + types->i64_type(), // int64_t x_high + types->i64_type(), // int64_t x_low + types->i32_type(), // int32_t x_scale + types->i64_ptr_type(), // int64_t* dec_str_len + }; + + engine->AddGlobalMappingForFunc("gdv_fn_dec_to_string", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_dec_to_string)); + + // gdv_fn_like_utf8_utf8 + args = {types->i64_type(), // int64_t ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int data_len + types->i8_ptr_type(), // const char* pattern + types->i32_type()}; // int pattern_len + + engine->AddGlobalMappingForFunc("gdv_fn_like_utf8_utf8", + types->i1_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_like_utf8_utf8)); + + // gdv_fn_like_utf8_utf8_utf8 + args = {types->i64_type(), // int64_t ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int data_len + types->i8_ptr_type(), // const char* pattern + types->i32_type(), // int pattern_len + types->i8_ptr_type(), // const char* escape_char + types->i32_type()}; // int escape_char_len + + engine->AddGlobalMappingForFunc("gdv_fn_like_utf8_utf8_utf8", + types->i1_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_like_utf8_utf8_utf8)); + + // 
gdv_fn_ilike_utf8_utf8 + args = {types->i64_type(), // int64_t ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int data_len + types->i8_ptr_type(), // const char* pattern + types->i32_type()}; // int pattern_len + + engine->AddGlobalMappingForFunc("gdv_fn_ilike_utf8_utf8", + types->i1_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_ilike_utf8_utf8)); + + // gdv_fn_regexp_replace_utf8_utf8 + args = {types->i64_type(), // int64_t ptr + types->i64_type(), // int64_t holder_ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int data_len + types->i8_ptr_type(), // const char* pattern + types->i32_type(), // int pattern_len + types->i8_ptr_type(), // const char* replace_string + types->i32_type(), // int32_t replace_string_len + types->i32_ptr_type()}; // int32_t* out_length + + engine->AddGlobalMappingForFunc( + "gdv_fn_regexp_replace_utf8_utf8", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_regexp_replace_utf8_utf8)); + + // gdv_fn_to_date_utf8_utf8 + args = {types->i64_type(), // int64_t execution_context + types->i64_type(), // int64_t holder_ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int data_len + types->i1_type(), // bool in1_validity + types->i8_ptr_type(), // const char* pattern + types->i32_type(), // int pattern_len + types->i1_type(), // bool in2_validity + types->ptr_type(types->i8_type())}; // bool* out_valid + + engine->AddGlobalMappingForFunc("gdv_fn_to_date_utf8_utf8", + types->i64_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_to_date_utf8_utf8)); + + // gdv_fn_to_date_utf8_utf8_int32 + args = {types->i64_type(), // int64_t execution_context + types->i64_type(), // int64_t holder_ptr + types->i8_ptr_type(), // const char* data + types->i32_type(), // int data_len + types->i1_type(), // bool in1_validity + types->i8_ptr_type(), // const char* pattern + types->i32_type(), // int pattern_len + types->i1_type(), // bool 
in2_validity + types->i32_type(), // int32_t suppress_errors + types->i1_type(), // bool in3_validity + types->ptr_type(types->i8_type())}; // bool* out_valid + + engine->AddGlobalMappingForFunc( + "gdv_fn_to_date_utf8_utf8_int32", types->i64_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_to_date_utf8_utf8_int32)); + + // gdv_fn_in_expr_lookup_int32 + args = {types->i64_type(), // int64_t in holder ptr + types->i32_type(), // int32 value + types->i1_type()}; // bool in_validity + + engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_int32", + types->i1_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_in_expr_lookup_int32)); + + // gdv_fn_in_expr_lookup_int64 + args = {types->i64_type(), // int64_t in holder ptr + types->i64_type(), // int64 value + types->i1_type()}; // bool in_validity + + engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_int64", + types->i1_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_in_expr_lookup_int64)); + + // gdv_fn_in_expr_lookup_decimal + args = {types->i64_type(), // int64_t in holder ptr + types->i64_type(), // high decimal value + types->i64_type(), // low decimal value + types->i32_type(), // decimal precision value + types->i32_type(), // decimal scale value + types->i1_type()}; // bool in_validity + + engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_decimal", + types->i1_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_in_expr_lookup_decimal)); + + // gdv_fn_in_expr_lookup_utf8 + args = {types->i64_type(), // int64_t in holder ptr + types->i8_ptr_type(), // const char* value + types->i32_type(), // int value_len + types->i1_type()}; // bool in_validity + + engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_utf8", + types->i1_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_in_expr_lookup_utf8)); + // gdv_fn_in_expr_lookup_float + args = {types->i64_type(), // int64_t in holder ptr + types->float_type(), // float value + types->i1_type()}; 
// bool in_validity + + engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_float", + types->i1_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_in_expr_lookup_float)); + // gdv_fn_in_expr_lookup_double + args = {types->i64_type(), // int64_t in holder ptr + types->double_type(), // double value + types->i1_type()}; // bool in_validity + + engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_double", + types->i1_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_in_expr_lookup_double)); + // gdv_fn_populate_varlen_vector + args = {types->i64_type(), // int64_t execution_context + types->i8_ptr_type(), // int8_t* data ptr + types->i32_ptr_type(), // int32_t* offsets ptr + types->i64_type(), // int64_t slot + types->i8_ptr_type(), // const char* entry_buf + types->i32_type()}; // int32_t entry__len + + engine->AddGlobalMappingForFunc("gdv_fn_populate_varlen_vector", + types->i32_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_populate_varlen_vector)); + + // gdv_fn_random + args = {types->i64_type()}; + engine->AddGlobalMappingForFunc("gdv_fn_random", types->double_type(), args, + reinterpret_cast<void*>(gdv_fn_random)); + + args = {types->i64_type(), types->i32_type(), types->i1_type()}; + engine->AddGlobalMappingForFunc("gdv_fn_random_with_seed", types->double_type(), args, + reinterpret_cast<void*>(gdv_fn_random_with_seed)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castINT_utf8", types->i32_type(), args, + reinterpret_cast<void*>(gdv_fn_castINT_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_utf8", types->i64_type(), args, + reinterpret_cast<void*>(gdv_fn_castBIGINT_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + 
types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_utf8", types->float_type(), args, + reinterpret_cast<void*>(gdv_fn_castFLOAT4_utf8)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_utf8", types->double_type(), args, + reinterpret_cast<void*>(gdv_fn_castFLOAT8_utf8)); + + // gdv_fn_castVARCHAR_int32_int64 + args = {types->i64_type(), // int64_t execution_context + types->i32_type(), // int32_t value + types->i64_type(), // int64_t len + types->i32_ptr_type()}; // int32_t* out_len + engine->AddGlobalMappingForFunc( + "gdv_fn_castVARCHAR_int32_int64", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_castVARCHAR_int32_int64)); + + // gdv_fn_castVARCHAR_int64_int64 + args = {types->i64_type(), // int64_t execution_context + types->i64_type(), // int64_t value + types->i64_type(), // int64_t len + types->i32_ptr_type()}; // int32_t* out_len + engine->AddGlobalMappingForFunc( + "gdv_fn_castVARCHAR_int64_int64", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_castVARCHAR_int64_int64)); + + // gdv_fn_castVARCHAR_float32_int64 + args = {types->i64_type(), // int64_t execution_context + types->float_type(), // float value + types->i64_type(), // int64_t len + types->i32_ptr_type()}; // int32_t* out_len + engine->AddGlobalMappingForFunc( + "gdv_fn_castVARCHAR_float32_int64", types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_castVARCHAR_float32_int64)); + + // gdv_fn_castVARCHAR_float64_int64 + args = {types->i64_type(), // int64_t execution_context + types->double_type(), // double value + types->i64_type(), // int64_t len + types->i32_ptr_type()}; // int32_t* out_len + engine->AddGlobalMappingForFunc( + "gdv_fn_castVARCHAR_float64_int64", types->i8_ptr_type() 
/*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_castVARCHAR_float64_int64)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castINT_varbinary", types->i32_type(), args, + reinterpret_cast<void*>(gdv_fn_castINT_varbinary)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_varbinary", types->i64_type(), args, + reinterpret_cast<void*>(gdv_fn_castBIGINT_varbinary)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_varbinary", types->float_type(), + args, + reinterpret_cast<void*>(gdv_fn_castFLOAT4_varbinary)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t lenr + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_varbinary", types->double_type(), + args, + reinterpret_cast<void*>(gdv_fn_castFLOAT8_varbinary)); + + // gdv_fn_sha1_int8 + args = { + types->i64_type(), // context + types->i8_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_int8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_int8)); + + // gdv_fn_sha1_int16 + args = { + types->i64_type(), // context + types->i16_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_int16", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_int16)); + + // gdv_fn_sha1_int32 + args = { + types->i64_type(), // context + types->i32_type(), // value + types->i1_type(), // validity + 
types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_int32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_int32)); + + // gdv_fn_sha1_int64 + args = { + types->i64_type(), // context + types->i64_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_int64", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_int64)); + + // gdv_fn_sha1_uint8 + args = { + types->i64_type(), // context + types->i8_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_uint8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_uint8)); + + // gdv_fn_sha1_uint16 + args = { + types->i64_type(), // context + types->i16_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_uint16", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_uint16)); + + // gdv_fn_sha1_uint32 + args = { + types->i64_type(), // context + types->i32_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_uint32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_uint32)); + + // gdv_fn_sha1_uint64 + args = { + types->i64_type(), // context + types->i64_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_uint64", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_uint64)); + + // gdv_fn_sha1_float32 + args = { + types->i64_type(), // context + types->float_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + 
engine->AddGlobalMappingForFunc("gdv_fn_sha1_float32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_float32)); + + // gdv_fn_sha1_float64 + args = { + types->i64_type(), // context + types->double_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_float64", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_float64)); + + // gdv_fn_sha1_boolean + args = { + types->i64_type(), // context + types->i1_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_boolean", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_boolean)); + + // gdv_fn_sha1_date64 + args = { + types->i64_type(), // context + types->i64_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_date64", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_date64)); + + // gdv_fn_sha1_date32 + args = { + types->i64_type(), // context + types->i32_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_date32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_date32)); + + // gdv_fn_sha1_time32 + args = { + types->i64_type(), // context + types->i32_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha1_time32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_time32)); + + // gdv_fn_sha1_timestamp + args = { + types->i64_type(), // context + types->i64_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + 
engine->AddGlobalMappingForFunc("gdv_fn_sha1_timestamp", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_timestamp)); + + // gdv_fn_sha1_utf8 + args = { + types->i64_type(), // context + types->i8_ptr_type(), // const char* + types->i32_type(), // value_length + types->i1_type(), // validity + types->i32_ptr_type() // out + }; + + engine->AddGlobalMappingForFunc("gdv_fn_sha1_utf8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_utf8)); + + // gdv_fn_sha1_binary + args = { + types->i64_type(), // context + types->i8_ptr_type(), // const char* + types->i32_type(), // value_length + types->i1_type(), // validity + types->i32_ptr_type() // out + }; + + engine->AddGlobalMappingForFunc("gdv_fn_sha1_binary", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_binary)); + + // gdv_fn_sha256_int8 + args = { + types->i64_type(), // context + types->i8_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_int8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_int8)); + + // gdv_fn_sha256_int16 + args = { + types->i64_type(), // context + types->i16_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_int16", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_int16)); + + // gdv_fn_sha256_int32 + args = { + types->i64_type(), // context + types->i32_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_int32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_int32)); + + // gdv_fn_sha256_int64 + args = { + types->i64_type(), // context + types->i64_type(), // value + types->i1_type(), // validity + 
types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_int64", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_int64)); + + // gdv_fn_sha256_uint8 + args = { + types->i64_type(), // context + types->i8_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_uint8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_uint8)); + + // gdv_fn_sha256_uint16 + args = { + types->i64_type(), // context + types->i16_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_uint16", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_uint16)); + + // gdv_fn_sha256_uint32 + args = { + types->i64_type(), // context + types->i32_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_uint32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_uint32)); + + // gdv_fn_sha256_uint64 + args = { + types->i64_type(), // context + types->i64_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_uint64", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_uint64)); + + // gdv_fn_sha256_float32 + args = { + types->i64_type(), // context + types->float_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_float32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_float32)); + + // gdv_fn_sha256_float64 + args = { + types->i64_type(), // context + types->double_type(), // value + types->i1_type(), // validity + 
types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_float64", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_float64)); + + // gdv_fn_sha256_boolean + args = { + types->i64_type(), // context + types->i1_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_boolean", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_boolean)); + + // gdv_fn_sha256_date64 + args = { + types->i64_type(), // context + types->i64_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_date64", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_date64)); + + // gdv_fn_sha256_date32 + args = { + types->i64_type(), // context + types->i32_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_date32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_date32)); + + // gdv_fn_sha256_time32 + args = { + types->i64_type(), // context + types->i32_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_time32", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_time32)); + + // gdv_fn_sha256_timestamp + args = { + types->i64_type(), // context + types->i64_type(), // value + types->i1_type(), // validity + types->i32_ptr_type() // out_length + }; + engine->AddGlobalMappingForFunc("gdv_fn_sha256_timestamp", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_timestamp)); + + // gdv_fn_sha256_utf8 + args = { + types->i64_type(), // context + types->i8_ptr_type(), // const char* + types->i32_type(), // 
value_length + types->i1_type(), // validity + types->i32_ptr_type() // out + }; + + engine->AddGlobalMappingForFunc("gdv_fn_sha256_utf8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_utf8)); + + // gdv_fn_sha256_binary + args = { + types->i64_type(), // context + types->i8_ptr_type(), // const char* + types->i32_type(), // value_length + types->i1_type(), // validity + types->i32_ptr_type() // out + }; + + engine->AddGlobalMappingForFunc("gdv_fn_sha256_binary", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_binary)); + + // gdv_fn_sha1_decimal128 + args = { + types->i64_type(), // context + types->i64_type(), // high_bits + types->i64_type(), // low_bits + types->i32_type(), // precision + types->i32_type(), // scale + types->i1_type(), // validity + types->i32_ptr_type() // out length + }; + + engine->AddGlobalMappingForFunc("gdv_fn_sha1_decimal128", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha1_decimal128)); + // gdv_fn_sha256_decimal128 + args = { + types->i64_type(), // context + types->i64_type(), // high_bits + types->i64_type(), // low_bits + types->i32_type(), // precision + types->i32_type(), // scale + types->i1_type(), // validity + types->i32_ptr_type() // out length + }; + + engine->AddGlobalMappingForFunc("gdv_fn_sha256_decimal128", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_sha256_decimal128)); + + // gdv_fn_base64_encode_binary + args = { + types->i64_type(), // context + types->i8_ptr_type(), // in + types->i32_type(), // in_len + types->i32_ptr_type(), // out_len + }; + + engine->AddGlobalMappingForFunc("gdv_fn_base64_encode_binary", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_base64_encode_binary)); + + // gdv_fn_base64_decode_utf8 + args = { + types->i64_type(), // context + types->i8_ptr_type(), // in + types->i32_type(), // in_len + 
types->i32_ptr_type(), // out_len + }; + + engine->AddGlobalMappingForFunc("gdv_fn_base64_decode_utf8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_base64_decode_utf8)); + + // gdv_fn_upper_utf8 + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_len + types->i32_ptr_type(), // out_len + }; + + engine->AddGlobalMappingForFunc("gdv_fn_upper_utf8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_upper_utf8)); + // gdv_fn_lower_utf8 + args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_len + types->i32_ptr_type(), // out_len + }; + + engine->AddGlobalMappingForFunc("gdv_fn_lower_utf8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_lower_utf8)); + + // gdv_fn_initcap_utf8 + args = { + types->i64_type(), // context + types->i8_ptr_type(), // const char* + types->i32_type(), // value_length + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc("gdv_fn_initcap_utf8", + types->i8_ptr_type() /*return_type*/, args, + reinterpret_cast<void*>(gdv_fn_initcap_utf8)); +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/gdv_function_stubs.h b/src/arrow/cpp/src/gandiva/gdv_function_stubs.h new file mode 100644 index 000000000..670ac94df --- /dev/null +++ b/src/arrow/cpp/src/gandiva/gdv_function_stubs.h @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> + +#include "gandiva/visibility.h" + +/// Stub functions that can be accessed from LLVM. +extern "C" { + +using gdv_boolean = bool; +using gdv_int8 = int8_t; +using gdv_int16 = int16_t; +using gdv_int32 = int32_t; +using gdv_int64 = int64_t; +using gdv_uint8 = uint8_t; +using gdv_uint16 = uint16_t; +using gdv_uint32 = uint32_t; +using gdv_uint64 = uint64_t; +using gdv_float32 = float; +using gdv_float64 = double; +using gdv_date64 = int64_t; +using gdv_date32 = int32_t; +using gdv_time32 = int32_t; +using gdv_timestamp = int64_t; +using gdv_utf8 = char*; +using gdv_binary = char*; +using gdv_day_time_interval = int64_t; +using gdv_month_interval = int32_t; + +#ifdef GANDIVA_UNIT_TEST +// unit tests may be compiled without O2, so inlining may not happen. 
+#define GDV_FORCE_INLINE +#else +#ifdef _MSC_VER +#define GDV_FORCE_INLINE __forceinline +#else +#define GDV_FORCE_INLINE inline __attribute__((always_inline)) +#endif +#endif + +bool gdv_fn_like_utf8_utf8(int64_t ptr, const char* data, int data_len, + const char* pattern, int pattern_len); + +bool gdv_fn_like_utf8_utf8_utf8(int64_t ptr, const char* data, int data_len, + const char* pattern, int pattern_len, + const char* escape_char, int escape_char_len); + +bool gdv_fn_ilike_utf8_utf8(int64_t ptr, const char* data, int data_len, + const char* pattern, int pattern_len); + +int64_t gdv_fn_to_date_utf8_utf8_int32(int64_t context, int64_t ptr, const char* data, + int data_len, bool in1_validity, + const char* pattern, int pattern_len, + bool in2_validity, int32_t suppress_errors, + bool in3_validity, bool* out_valid); + +void gdv_fn_context_set_error_msg(int64_t context_ptr, const char* err_msg); + +uint8_t* gdv_fn_context_arena_malloc(int64_t context_ptr, int32_t data_len); + +void gdv_fn_context_arena_reset(int64_t context_ptr); + +bool in_expr_lookup_int32(int64_t ptr, int32_t value, bool in_validity); + +bool in_expr_lookup_int64(int64_t ptr, int64_t value, bool in_validity); + +bool in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len, bool in_validity); + +int gdv_fn_time_with_zone(int* time_fields, const char* zone, int zone_len, + int64_t* ret_time); + +GANDIVA_EXPORT +const char* gdv_fn_base64_encode_binary(int64_t context, const char* in, int32_t in_len, + int32_t* out_len); + +GANDIVA_EXPORT +const char* gdv_fn_base64_decode_utf8(int64_t context, const char* in, int32_t in_len, + int32_t* out_len); + +GANDIVA_EXPORT +const char* gdv_fn_castVARBINARY_int32_int64(int64_t context, gdv_int32 value, + int64_t out_len, int32_t* out_length); + +GANDIVA_EXPORT +const char* gdv_fn_castVARBINARY_int64_int64(int64_t context, gdv_int64 value, + int64_t out_len, int32_t* out_length); + +GANDIVA_EXPORT +const char* gdv_fn_sha256_decimal128(int64_t context, 
int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, + gdv_boolean x_isvalid, int32_t* out_length); + +GANDIVA_EXPORT +const char* gdv_fn_sha1_decimal128(int64_t context, int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, + gdv_boolean x_isvalid, int32_t* out_length); + +int32_t gdv_fn_dec_from_string(int64_t context, const char* in, int32_t in_length, + int32_t* precision_from_str, int32_t* scale_from_str, + int64_t* dec_high_from_str, uint64_t* dec_low_from_str); + +char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low, + int32_t x_scale, int32_t* dec_str_len); + +GANDIVA_EXPORT +int32_t gdv_fn_castINT_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +int64_t gdv_fn_castBIGINT_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +float gdv_fn_castFLOAT4_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +double gdv_fn_castFLOAT8_utf8(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +const char* gdv_fn_castVARCHAR_int32_int64(int64_t context, int32_t value, int64_t len, + int32_t* out_len); +GANDIVA_EXPORT +const char* gdv_fn_castVARCHAR_int64_int64(int64_t context, int64_t value, int64_t len, + int32_t* out_len); +GANDIVA_EXPORT +const char* gdv_fn_castVARCHAR_float32_int64(int64_t context, float value, int64_t len, + int32_t* out_len); +GANDIVA_EXPORT +const char* gdv_fn_castVARCHAR_float64_int64(int64_t context, double value, int64_t len, + int32_t* out_len); + +GANDIVA_EXPORT +int32_t gdv_fn_utf8_char_length(char c); + +GANDIVA_EXPORT +const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_len, + int32_t* out_len); + +GANDIVA_EXPORT +const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len, + int32_t* out_len); + +GANDIVA_EXPORT +const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_len, + int32_t* out_len); + +GANDIVA_EXPORT 
+int32_t gdv_fn_castINT_varbinary(gdv_int64 context, const char* in, int32_t in_len); + +GANDIVA_EXPORT +int64_t gdv_fn_castBIGINT_varbinary(gdv_int64 context, const char* in, int32_t in_len); + +GANDIVA_EXPORT +float gdv_fn_castFLOAT4_varbinary(gdv_int64 context, const char* in, int32_t in_len); + +GANDIVA_EXPORT +double gdv_fn_castFLOAT8_varbinary(gdv_int64 context, const char* in, int32_t in_len); +} diff --git a/src/arrow/cpp/src/gandiva/gdv_function_stubs_test.cc b/src/arrow/cpp/src/gandiva/gdv_function_stubs_test.cc new file mode 100644 index 000000000..f7c21981c --- /dev/null +++ b/src/arrow/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -0,0 +1,769 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/gdv_function_stubs.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "gandiva/execution_context.h" + +namespace gandiva { + +TEST(TestGdvFnStubs, TestCastVarbinaryNumeric) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + int32_t out_len = 0; + + // tests for integer values as input + const char* out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, -46, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "-46"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 2147483647, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "2147483647"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, -2147483647 - 1, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "-2147483648"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 0, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "0"); + EXPECT_FALSE(ctx.has_error()); + + // test with required length less than actual buffer length + out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 34567, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "345"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 347, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + gdv_fn_castVARBINARY_int32_int64(ctx_ptr, 347, -1, &out_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length can not be negative")); + ctx.Reset(); + + // tests for big integer values as input + out_str = + gdv_fn_castVARBINARY_int64_int64(ctx_ptr, 9223372036854775807LL, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "9223372036854775807"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARBINARY_int64_int64(ctx_ptr, -9223372036854775807LL - 1, 100, + &out_len); + EXPECT_EQ(std::string(out_str, out_len), 
"-9223372036854775808"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARBINARY_int64_int64(ctx_ptr, 0, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "0"); + EXPECT_FALSE(ctx.has_error()); + + // test with required length less than actual buffer length + out_str = gdv_fn_castVARBINARY_int64_int64(ctx_ptr, 12345, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123"); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestGdvFnStubs, TestBase64Encode) { + gandiva::ExecutionContext ctx; + + auto ctx_ptr = reinterpret_cast<int64_t>(&ctx); + int32_t out_len = 0; + + auto value = gdv_fn_base64_encode_binary(ctx_ptr, "hello", 5, &out_len); + std::string out_value = std::string(value, out_len); + EXPECT_EQ(out_value, "aGVsbG8="); + + value = gdv_fn_base64_encode_binary(ctx_ptr, "test", 4, &out_len); + out_value = std::string(value, out_len); + EXPECT_EQ(out_value, "dGVzdA=="); + + value = gdv_fn_base64_encode_binary(ctx_ptr, "hive", 4, &out_len); + out_value = std::string(value, out_len); + EXPECT_EQ(out_value, "aGl2ZQ=="); + + value = gdv_fn_base64_encode_binary(ctx_ptr, "", 0, &out_len); + out_value = std::string(value, out_len); + EXPECT_EQ(out_value, ""); + + value = gdv_fn_base64_encode_binary(ctx_ptr, "test", -5, &out_len); + out_value = std::string(value, out_len); + EXPECT_EQ(out_value, ""); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length can not be negative")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestBase64Decode) { + gandiva::ExecutionContext ctx; + + auto ctx_ptr = reinterpret_cast<int64_t>(&ctx); + int32_t out_len = 0; + + auto value = gdv_fn_base64_decode_utf8(ctx_ptr, "aGVsbG8=", 8, &out_len); + std::string out_value = std::string(value, out_len); + EXPECT_EQ(out_value, "hello"); + + value = gdv_fn_base64_decode_utf8(ctx_ptr, "dGVzdA==", 8, &out_len); + out_value = std::string(value, out_len); + EXPECT_EQ(out_value, "test"); + + value = gdv_fn_base64_decode_utf8(ctx_ptr, "aGl2ZQ==", 8, &out_len); + 
out_value = std::string(value, out_len); + EXPECT_EQ(out_value, "hive"); + + value = gdv_fn_base64_decode_utf8(ctx_ptr, "", 0, &out_len); + out_value = std::string(value, out_len); + EXPECT_EQ(out_value, ""); + + value = gdv_fn_base64_decode_utf8(ctx_ptr, "test", -5, &out_len); + out_value = std::string(value, out_len); + EXPECT_EQ(out_value, ""); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length can not be negative")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastINT) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "2147483647", 10), 2147483647); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "02147483647", 11), 2147483647); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-2147483648", 11), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, "-02147483648", 12), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_utf8(ctx_ptr, " 12 ", 4), 12); + + gdv_fn_castINT_utf8(ctx_ptr, "2147483648", 10); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 2147483648 to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "-2147483649", 11); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -2147483649 to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "abc", 3); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int32")); + ctx.Reset(); + + gdv_fn_castINT_utf8(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - 
to int32")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastBIGINT) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "9223372036854775807", 19), + 9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "09223372036854775807", 20), + 9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-9223372036854775808", 20), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, "-009223372036854775808", 22), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_utf8(ctx_ptr, " 12 ", 4), 12); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "9223372036854775808", 19); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 9223372036854775808 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "-9223372036854775809", 20); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -9223372036854775809 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "abc", 3); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_utf8(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - to int64")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastFloat4) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "-45.34", 6), -45.34f); + 
EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "0", 1), 0.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, "5", 1), 5.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_utf8(ctx_ptr, " 3.4 ", 5), 3.4f); + + gdv_fn_castFLOAT4_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to float")); + ctx.Reset(); + + gdv_fn_castFLOAT4_utf8(ctx_ptr, "e", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to float")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastFloat8) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "-45.34", 6), -45.34); + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "0", 1), 0.0); + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, "5", 1), 5.0); + EXPECT_EQ(gdv_fn_castFLOAT8_utf8(ctx_ptr, " 3.4 ", 5), 3.4); + + gdv_fn_castFLOAT8_utf8(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to double")); + ctx.Reset(); + + gdv_fn_castFLOAT8_utf8(ctx_ptr, "e", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to double")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVARCHARFromInt32) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + int32_t out_len = 0; + + const char* out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, -46, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "-46"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 2147483647, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "2147483647"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, -2147483647 - 1, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "-2147483648"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 0, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "0"); + 
EXPECT_FALSE(ctx.has_error()); + + // test with required length less than actual buffer length + out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 34567, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "345"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 347, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_int32_int64(ctx_ptr, 347, -1, &out_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer length can not be negative")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVARCHARFromInt64) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + int32_t out_len = 0; + + const char* out_str = + gdv_fn_castVARCHAR_int64_int64(ctx_ptr, 9223372036854775807LL, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "9223372036854775807"); + EXPECT_FALSE(ctx.has_error()); + + out_str = + gdv_fn_castVARCHAR_int64_int64(ctx_ptr, -9223372036854775807LL - 1, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "-9223372036854775808"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_int64_int64(ctx_ptr, 0, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "0"); + EXPECT_FALSE(ctx.has_error()); + + // test with required length less than actual buffer length + out_str = gdv_fn_castVARCHAR_int64_int64(ctx_ptr, 12345, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123"); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestGdvFnStubs, TestCastVARCHARFromFloat) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + int32_t out_len = 0; + + const char* out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 4.567f, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "4.567"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, -3.4567f, 100, &out_len); + EXPECT_EQ(std::string(out_str, 
out_len), "-3.4567"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 0.00001f, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1.0E-5"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 0.00099999f, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "9.9999E-4"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 0.0f, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "0.0"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 10.00000f, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "10.0"); + EXPECT_FALSE(ctx.has_error()); + + // test with required length less than actual buffer length + out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 1.2345f, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1.2"); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + int32_t out_len = 0; + + const char* out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 4.567, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "4.567"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, -3.4567, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "-3.4567"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 0.00001, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1.0E-5"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float32_int64(ctx_ptr, 0.00099999f, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "9.9999E-4"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 0.0, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "0.0"); + EXPECT_FALSE(ctx.has_error()); + + out_str = 
gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 10.0000000000, 100, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "10.0"); + EXPECT_FALSE(ctx.has_error()); + + // test with required length less than actual buffer length + out_str = gdv_fn_castVARCHAR_float64_int64(ctx_ptr, 1.2345, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1.2"); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestGdvFnStubs, TestUpper) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str = gdv_fn_upper_utf8(ctx_ptr, "AbcDEfGh", 8, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ABCDEFGH"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_upper_utf8(ctx_ptr, "asdfj", 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ASDFJ"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_upper_utf8(ctx_ptr, "s;dcGS,jO!l", 11, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "S;DCGS,JO!L"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_upper_utf8(ctx_ptr, "münchen", 8, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "MÜNCHEN"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_upper_utf8(ctx_ptr, "CITROËN", 8, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "CITROËN"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_upper_utf8(ctx_ptr, "âBćDëFGH", 11, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ÂBĆDËFGH"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_upper_utf8(ctx_ptr, "øhpqRšvñ", 11, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ØHPQRŠVÑ"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_upper_utf8(ctx_ptr, "Möbelträgerfüße", 19, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "MÖBELTRÄGERFÜẞE"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_upper_utf8(ctx_ptr, "{õhp,PQŚv}ń+", 15, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "{ÕHP,PQŚV}Ń+"); + EXPECT_FALSE(ctx.has_error()); + + out_str = 
gdv_fn_upper_utf8(ctx_ptr, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + std::string d("AbOJjÜoß\xc3"); + out_str = gdv_fn_upper_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\c3 encountered while decoding utf8 string")); + ctx.Reset(); + + std::string e( + "åbÑg\xe0\xa0" + "åBUå"); + out_str = gdv_fn_upper_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\e0 encountered while decoding utf8 string")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestLower) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str = gdv_fn_lower_utf8(ctx_ptr, "AbcDEfGh", 8, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcdefgh"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_lower_utf8(ctx_ptr, "asdfj", 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdfj"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_lower_utf8(ctx_ptr, "S;DCgs,Jo!L", 11, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "s;dcgs,jo!l"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_lower_utf8(ctx_ptr, "MÜNCHEN", 8, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "münchen"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_lower_utf8(ctx_ptr, "citroën", 8, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "citroën"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_lower_utf8(ctx_ptr, "ÂbĆDËFgh", 11, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "âbćdëfgh"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_lower_utf8(ctx_ptr, "ØHPQrŠvÑ", 11, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "øhpqršvñ"); + 
EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_lower_utf8(ctx_ptr, "MÖBELTRÄGERFÜẞE", 20, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "möbelträgerfüße"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_lower_utf8(ctx_ptr, "{ÕHP,pqśv}Ń+", 15, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "{õhp,pqśv}ń+"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_lower_utf8(ctx_ptr, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + std::string d("AbOJjÜoß\xc3"); + out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\c3 encountered while decoding utf8 string")); + ctx.Reset(); + + std::string e( + "åbÑg\xe0\xa0" + "åBUå"); + out_str = gdv_fn_lower_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\e0 encountered while decoding utf8 string")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestInitCap) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str = gdv_fn_initcap_utf8(ctx_ptr, "test string", 11, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test String"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "asdfj\nhlqf", 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Asdfj\nHlqf"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "s;DCgs,Jo!l", 11, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "S;Dcgs,Jo!L"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, " mÜNCHEN", 9, &out_len); + EXPECT_EQ(std::string(out_str, out_len), " München"); + EXPECT_FALSE(ctx.has_error()); + + out_str = 
gdv_fn_initcap_utf8(ctx_ptr, "citroën CaR", 12, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Citroën Car"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "ÂbĆDËFgh\néll", 16, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Âbćdëfgh\nÉll"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, " øhpqršvñ \n\n", 17, &out_len); + EXPECT_EQ(std::string(out_str, out_len), " Øhpqršvñ \n\n"); + EXPECT_FALSE(ctx.has_error()); + + out_str = + gdv_fn_initcap_utf8(ctx_ptr, "möbelträgerfüße \nmöbelträgerfüße", 42, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Möbelträgerfüße \nMöbelträgerfüße"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "{ÕHP,pqśv}Ń+", 15, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "{Õhp,Pqśv}Ń+"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "sɦasasdsɦsd\"sdsdɦ", 19, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Sɦasasdsɦsd\"Sdsdɦ"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "mysuperscipt@number²isfine", 27, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Mysuperscipt@Number²Isfine"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "Ő<tŵas̓老ƕɱ¢vIYwށ", 25, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Ő<Tŵas̓老Ƕɱ¢Viywށ"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "ↆcheckↆnumberisspace", 24, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ↆcheckↆnumberisspace"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "testing ᾌTitleᾌcase", 23, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Testing ᾌtitleᾄcase"); + EXPECT_FALSE(ctx.has_error()); + + out_str = gdv_fn_initcap_utf8(ctx_ptr, "ʳTesting mʳodified", 20, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ʳTesting MʳOdified"); + EXPECT_FALSE(ctx.has_error()); + + out_str = 
gdv_fn_initcap_utf8(ctx_ptr, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + std::string d("AbOJjÜoß\xc3"); + out_str = + gdv_fn_initcap_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\c3 encountered while decoding utf8 string")); + ctx.Reset(); + + std::string e( + "åbÑg\xe0\xa0" + "åBUå"); + out_str = + gdv_fn_initcap_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\e0 encountered while decoding utf8 string")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVarbinaryINT) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "2147483647", 10), 2147483647); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "\x32\x33", 2), 23); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "02147483647", 11), 2147483647); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-2147483648", 11), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-02147483648", 12), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, " 12 ", 4), 12); + + gdv_fn_castINT_varbinary(ctx_ptr, "2147483648", 10); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 2147483648 to int32")); + ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "-2147483649", 11); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -2147483649 to int32")); + ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int32")); + 
ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "abc", 3); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int32")); + ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int32")); + ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - to int32")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVarbinaryBIGINT) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "9223372036854775807", 19), + 9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "09223372036854775807", 20), + 9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-9223372036854775808", 20), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-009223372036854775808", 22), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, " 12 ", 4), 12); + + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, + "\x39\x39\x39\x39\x39\x39\x39\x39\x39\x39", 10), + 9999999999LL); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "9223372036854775808", 19); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 9223372036854775808 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "-9223372036854775809", 20); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -9223372036854775809 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "abc", 3); + 
EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - to int64")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVarbinaryFloat4) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-45.34", 6), -45.34f); + EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "0", 1), 0.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "5", 1), 5.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, " 3.4 ", 5), 3.4f); + EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, " \x33\x2E\x34 ", 5), 3.4f); + + gdv_fn_castFLOAT4_varbinary(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to float")); + ctx.Reset(); + + gdv_fn_castFLOAT4_varbinary(ctx_ptr, "e", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to float")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVarbinaryFloat8) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "-45.34", 6), -45.34); + EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "0", 1), 0.0); + EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "5", 1), 5.0); + EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, " \x33\x2E\x34 ", 5), 3.4); + + gdv_fn_castFLOAT8_varbinary(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to double")); + ctx.Reset(); + + gdv_fn_castFLOAT8_varbinary(ctx_ptr, "e", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to double")); + ctx.Reset(); +} + +} // 
namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/greedy_dual_size_cache.h b/src/arrow/cpp/src/gandiva/greedy_dual_size_cache.h new file mode 100644 index 000000000..cb5c38e07 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/greedy_dual_size_cache.h @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <list> +#include <queue> +#include <set> +#include <unordered_map> +#include <utility> + +#include "arrow/util/optional.h" + +// modified cache to support evict policy using the GreedyDual-Size algorithm. +namespace gandiva { +// Defines a base value object supported on the cache that may contain properties +template <typename ValueType> +class ValueCacheObject { + public: + ValueCacheObject(ValueType module, uint64_t cost) : module(module), cost(cost) {} + ValueType module; + uint64_t cost; + bool operator<(const ValueCacheObject& other) const { return cost < other.cost; } +}; + +// A particular cache based on the GreedyDual-Size cache which is a generalization of LRU +// which defines costs for each cache values. +// The algorithm associates a cost, C, with each cache value. 
Initially, when the value +// is brought into cache, C is set to be the cost related to the value (the cost is +// always non-negative). When a replacement needs to be made, the value with the lowest C +// cost is replaced, and then all values reduce their C costs by the minimum value of C +// over all the values already in the cache. +// If a value is accessed, its C value is restored to its initial cost. Thus, the C costs +// of recently accessed values retain a larger portion of the original cost than those of +// values that have not been accessed for a long time. The C costs are reduced as time +// goes and are restored when accessed. + +template <class Key, class Value> +class GreedyDualSizeCache { + // inner class to define the priority item + class PriorityItem { + public: + PriorityItem(uint64_t actual_priority, uint64_t original_priority, Key key) + : actual_priority(actual_priority), + original_priority(original_priority), + cache_key(key) {} + // this ensure that the items with low priority stays in the beginning of the queue, + // so it can be the one removed by evict operation + bool operator<(const PriorityItem& other) const { + return actual_priority < other.actual_priority; + } + uint64_t actual_priority; + uint64_t original_priority; + Key cache_key; + }; + + public: + struct hasher { + template <typename I> + std::size_t operator()(const I& i) const { + return i.Hash(); + } + }; + // a map from 'key' to a pair of Value and a pointer to the priority value + using map_type = std::unordered_map< + Key, std::pair<ValueCacheObject<Value>, typename std::set<PriorityItem>::iterator>, + hasher>; + + explicit GreedyDualSizeCache(size_t capacity) : inflation_(0), capacity_(capacity) {} + + ~GreedyDualSizeCache() = default; + + size_t size() const { return map_.size(); } + + size_t capacity() const { return capacity_; } + + bool empty() const { return map_.empty(); } + + bool contains(const Key& key) { return map_.find(key) != map_.end(); } + + void 
insert(const Key& key, const ValueCacheObject<Value>& value) { + typename map_type::iterator i = map_.find(key); + // check if element is not in the cache to add it + if (i == map_.end()) { + // insert item into the cache, but first check if it is full, to evict an item + // if it is necessary + if (size() >= capacity_) { + evict(); + } + + // insert the new item + auto item = + priority_set_.insert(PriorityItem(value.cost + inflation_, value.cost, key)); + // save on map the value and the priority item iterator position + map_.emplace(key, std::make_pair(value, item.first)); + } + } + + arrow::util::optional<ValueCacheObject<Value>> get(const Key& key) { + // lookup value in the cache + typename map_type::iterator value_for_key = map_.find(key); + if (value_for_key == map_.end()) { + // value not in cache + return arrow::util::nullopt; + } + PriorityItem item = *value_for_key->second.second; + // if the value was found on the cache, update its cost (original + inflation) + if (item.actual_priority != item.original_priority + inflation_) { + priority_set_.erase(value_for_key->second.second); + auto iter = priority_set_.insert(PriorityItem( + item.original_priority + inflation_, item.original_priority, item.cache_key)); + value_for_key->second.second = iter.first; + } + return value_for_key->second.first; + } + + void clear() { + map_.clear(); + priority_set_.clear(); + } + + private: + void evict() { + // TODO: inflation overflow is unlikely to happen but needs to be handled + // for correctness. + // evict item from the beginning of the set. This set is ordered from the + // lower priority value to the higher priority value. 
+ typename std::set<PriorityItem>::iterator i = priority_set_.begin(); + // update the inflation cost related to the evicted item + inflation_ = (*i).actual_priority; + map_.erase((*i).cache_key); + priority_set_.erase(i); + } + + map_type map_; + std::set<PriorityItem> priority_set_; + uint64_t inflation_; + size_t capacity_; +}; +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/greedy_dual_size_cache_test.cc b/src/arrow/cpp/src/gandiva/greedy_dual_size_cache_test.cc new file mode 100644 index 000000000..3c72eef70 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/greedy_dual_size_cache_test.cc @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 

#include "gandiva/greedy_dual_size_cache.h"

#include <string>
#include <typeinfo>

#include <gtest/gtest.h>

namespace gandiva {

// Minimal cache key for the tests: wraps an int and provides the two operations the
// cache's map needs, Hash() (called by GreedyDualSizeCache::hasher) and operator==.
class GreedyDualSizeCacheKey {
 public:
  explicit GreedyDualSizeCacheKey(int tmp) : tmp_(tmp) {}
  // the wrapped int is used directly as the hash value
  std::size_t Hash() const { return tmp_; }
  bool operator==(const GreedyDualSizeCacheKey& other) const {
    return tmp_ == other.tmp_;
  }

 private:
  int tmp_;
};

// Fixture providing a cache with capacity 2, so the third insertion of a new key
// always triggers an eviction.
class TestGreedyDualSizeCache : public ::testing::Test {
 public:
  TestGreedyDualSizeCache() : cache_(2) {}

 protected:
  GreedyDualSizeCache<GreedyDualSizeCacheKey, std::string> cache_;
};

TEST_F(TestGreedyDualSizeCache, TestEvict) {
  // check if the cache is evicting the items with low priority on cache
  // expected evictions, in order: key 1 (when 3 is inserted), key 2 (when 4 is
  // inserted) and key 3 (when 1 is inserted again), leaving keys 4 and 1
  cache_.insert(GreedyDualSizeCacheKey(1), ValueCacheObject<std::string>("1", 1));
  cache_.insert(GreedyDualSizeCacheKey(2), ValueCacheObject<std::string>("2", 10));
  cache_.insert(GreedyDualSizeCacheKey(3), ValueCacheObject<std::string>("3", 20));
  cache_.insert(GreedyDualSizeCacheKey(4), ValueCacheObject<std::string>("4", 15));
  // key 1 is no longer cached at this point, so this stores the new value "5"
  cache_.insert(GreedyDualSizeCacheKey(1), ValueCacheObject<std::string>("5", 1));
  ASSERT_EQ(2, cache_.size());
  // we check initially the values that won't be on the cache, since the get operation
  // may affect the entity costs, which is not the purpose of this test
  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(2)), arrow::util::nullopt);
  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(3)), arrow::util::nullopt);
  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(1))->module, "5");
  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(4))->module, "4");
}

TEST_F(TestGreedyDualSizeCache, TestGreedyDualSizeBehavior) {
  // insert 1 and 3 evicting 2 (this eviction will increase the inflation cost by 20)
  // key 3 is inserted after that eviction, so it is stored with priority 30 + 20 = 50
  cache_.insert(GreedyDualSizeCacheKey(1), ValueCacheObject<std::string>("1", 40));
  cache_.insert(GreedyDualSizeCacheKey(2), ValueCacheObject<std::string>("2", 20));
  cache_.insert(GreedyDualSizeCacheKey(3), ValueCacheObject<std::string>("3", 30));

  // when accessing key 3, its actual cost will be increased by the inflation, so in the
  // next eviction, the key 1 will be evicted, since the key 1 actual cost (original(40))
  // is smaller than key 3 actual increased cost (original(30) + inflation(20))
  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(3))->module, "3");

  // try to insert key 2 and expect the eviction of key 1
  // (the inflation becomes 40, so key 2 is stored with priority 20 + 40 = 60)
  cache_.insert(GreedyDualSizeCacheKey(2), ValueCacheObject<std::string>("2", 20));
  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(1)), arrow::util::nullopt);

  // when accessing key 2, its original cost should be increased by inflation, so when
  // inserting the key 1 again, now the key 3 should be evicted
  // (key 3 has priority 50, key 2 has priority 60)
  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(2))->module, "2");
  cache_.insert(GreedyDualSizeCacheKey(1), ValueCacheObject<std::string>("1", 20));

  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(1))->module, "1");
  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(2))->module, "2");
  ASSERT_EQ(cache_.get(GreedyDualSizeCacheKey(3)), arrow::util::nullopt);
  ASSERT_EQ(2, cache_.size());
}
}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/hash_utils.cc b/src/arrow/cpp/src/gandiva/hash_utils.cc new file mode 100644 index 000000000..8ebf60a9b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/hash_utils.cc @@ -0,0 +1,134 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/hash_utils.h" +#include <cstring> +#include "arrow/util/logging.h" +#include "gandiva/gdv_function_stubs.h" +#include "openssl/evp.h" + +namespace gandiva { +/// Hashes a generic message using the SHA256 algorithm +GANDIVA_EXPORT +const char* gdv_hash_using_sha256(int64_t context, const void* message, + size_t message_length, int32_t* out_length) { + constexpr int sha256_result_length = 64; + return gdv_hash_using_sha(context, message, message_length, EVP_sha256(), + sha256_result_length, out_length); +} + +/// Hashes a generic message using the SHA1 algorithm +GANDIVA_EXPORT +const char* gdv_hash_using_sha1(int64_t context, const void* message, + size_t message_length, int32_t* out_length) { + constexpr int sha1_result_length = 40; + return gdv_hash_using_sha(context, message, message_length, EVP_sha1(), + sha1_result_length, out_length); +} + +/// \brief Hashes a generic message using SHA algorithm. +/// +/// It uses the EVP API in the OpenSSL library to generate +/// the hash. The type of the hash is defined by the +/// \b hash_type \b parameter. 
+GANDIVA_EXPORT +const char* gdv_hash_using_sha(int64_t context, const void* message, + size_t message_length, const EVP_MD* hash_type, + uint32_t result_buf_size, int32_t* out_length) { + EVP_MD_CTX* md_ctx = EVP_MD_CTX_new(); + + if (md_ctx == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not create the context for SHA processing."); + *out_length = 0; + return ""; + } + + int evp_success_status = 1; + + if (EVP_DigestInit_ex(md_ctx, hash_type, nullptr) != evp_success_status || + EVP_DigestUpdate(md_ctx, message, message_length) != evp_success_status) { + gdv_fn_context_set_error_msg(context, + "Could not obtain the hash for the defined value."); + EVP_MD_CTX_free(md_ctx); + + *out_length = 0; + return ""; + } + + // Create the temporary buffer used by the EVP to generate the hash + unsigned int hash_digest_size = EVP_MD_size(hash_type); + auto* result = static_cast<unsigned char*>(OPENSSL_malloc(hash_digest_size)); + + if (result == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for SHA processing"); + EVP_MD_CTX_free(md_ctx); + *out_length = 0; + return ""; + } + + unsigned int result_length; + EVP_DigestFinal_ex(md_ctx, result, &result_length); + + if (result_length != hash_digest_size && result_buf_size != (2 * hash_digest_size)) { + gdv_fn_context_set_error_msg(context, + "Could not obtain the hash for the defined value"); + EVP_MD_CTX_free(md_ctx); + OPENSSL_free(result); + + *out_length = 0; + return ""; + } + + auto result_buffer = + reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, result_buf_size)); + + if (result_buffer == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for the result buffer"); + // Free the resources used by the EVP + EVP_MD_CTX_free(md_ctx); + OPENSSL_free(result); + + *out_length = 0; + return ""; + } + + unsigned int result_buff_index = 0; + for (unsigned int j = 0; j < result_length; j++) { + DCHECK(result_buff_index >= 0 && result_buff_index < 
result_buf_size); + + unsigned char hex_number = result[j]; + result_buff_index += + snprintf(result_buffer + result_buff_index, result_buf_size, "%02x", hex_number); + } + + // Free the resources used by the EVP to avoid memory leaks + EVP_MD_CTX_free(md_ctx); + OPENSSL_free(result); + + *out_length = result_buf_size; + return result_buffer; +} + +GANDIVA_EXPORT +uint64_t gdv_double_to_long(double value) { + uint64_t result; + memcpy(&result, &value, sizeof(result)); + return result; +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/hash_utils.h b/src/arrow/cpp/src/gandiva/hash_utils.h new file mode 100644 index 000000000..483993f30 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/hash_utils.h @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+#ifndef ARROW_SRC_HASH_UTILS_H_
+#define ARROW_SRC_HASH_UTILS_H_
+
+#include <cstdint>
+#include <cstdlib>
+#include "gandiva/visibility.h"
+#include "openssl/evp.h"
+
+namespace gandiva {
+/// Hashes \p message with SHA256; see gdv_hash_using_sha() for the meaning of the
+/// arguments and of the returned buffer.
+GANDIVA_EXPORT
+const char* gdv_hash_using_sha256(int64_t context, const void* message,
+                                  size_t message_length, int32_t* out_length);
+
+/// Hashes \p message with SHA1; see gdv_hash_using_sha() for the meaning of the
+/// arguments and of the returned buffer.
+GANDIVA_EXPORT
+const char* gdv_hash_using_sha1(int64_t context, const void* message,
+                                size_t message_length, int32_t* out_length);
+
+/// Hashes \p message_length bytes of \p message with the OpenSSL digest \p hash_type
+/// and returns the digest as hexadecimal text.
+///
+/// \p result_buf_size is the size of the hexadecimal output (two characters per
+/// digest byte) and is stored in \p out_length on success. On failure an error
+/// message is set on \p context, \p out_length is set to 0 and "" is returned.
+GANDIVA_EXPORT
+const char* gdv_hash_using_sha(int64_t context, const void* message,
+                               size_t message_length, const EVP_MD* hash_type,
+                               uint32_t result_buf_size, int32_t* out_length);
+
+/// Returns the bytes of \p value copied into a 64-bit unsigned integer.
+GANDIVA_EXPORT
+uint64_t gdv_double_to_long(double value);
+}  // namespace gandiva
+
+#endif  // ARROW_SRC_HASH_UTILS_H_
diff --git a/src/arrow/cpp/src/gandiva/hash_utils_test.cc b/src/arrow/cpp/src/gandiva/hash_utils_test.cc
new file mode 100644
index 000000000..a8f55e1ed
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/hash_utils_test.cc
@@ -0,0 +1,164 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#include <gtest/gtest.h> +#include <unordered_set> + +#include "gandiva/execution_context.h" +#include "gandiva/hash_utils.h" + +TEST(TestShaHashUtils, TestSha1Numeric) { + gandiva::ExecutionContext ctx; + + auto ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + std::vector<uint64_t> values_to_be_hashed; + + // Generate a list of values to obtains the SHA1 hash + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.0)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.1)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.2)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.10000001)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.0000001)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(1.000000)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.0000002)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.999999)); + + // Checks if the hash value is different for each one of the values + std::unordered_set<std::string> sha_values; + + int sha1_size = 40; + + for (auto value : values_to_be_hashed) { + int out_length; + const char* sha_1 = + gandiva::gdv_hash_using_sha1(ctx_ptr, &value, sizeof(value), &out_length); + std::string sha1_as_str(sha_1, out_length); + EXPECT_EQ(sha1_as_str.size(), sha1_size); + + // The value can not exists inside the set with the hash results + EXPECT_EQ(sha_values.find(sha1_as_str), sha_values.end()); + sha_values.insert(sha1_as_str); + } +} + +TEST(TestShaHashUtils, TestSha256Numeric) { + gandiva::ExecutionContext ctx; + + auto ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + std::vector<uint64_t> values_to_be_hashed; + + // Generate a list of values to obtains the SHA1 hash + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.0)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.1)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.2)); + 
values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.10000001)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.0000001)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(1.000000)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(-0.0000002)); + values_to_be_hashed.push_back(gandiva::gdv_double_to_long(0.999999)); + + // Checks if the hash value is different for each one of the values + std::unordered_set<std::string> sha_values; + + int sha256_size = 64; + + for (auto value : values_to_be_hashed) { + int out_length; + const char* sha_256 = + gandiva::gdv_hash_using_sha256(ctx_ptr, &value, sizeof(value), &out_length); + std::string sha256_as_str(sha_256, out_length); + EXPECT_EQ(sha256_as_str.size(), sha256_size); + + // The value can not exists inside the set with the hash results + EXPECT_EQ(sha_values.find(sha256_as_str), sha_values.end()); + sha_values.insert(sha256_as_str); + } +} + +TEST(TestShaHashUtils, TestSha1Varlen) { + gandiva::ExecutionContext ctx; + + auto ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + std::string first_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ]"; + + std::string second_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ"; + + // The strings expected hashes are obtained from shell executing the following command: + // echo -n <output-string> | openssl dgst sha1 + std::string expected_first_result = "160fcdbc2fa694d884868f5fae7a4bae82706185"; + std::string expected_second_result = "a456b3e0f88669d2482170a42fade226a815bee1"; + + // Generate the hashes and compare with expected outputs + const int sha1_size = 40; + int out_length; + + const char* sha_1 = gandiva::gdv_hash_using_sha1(ctx_ptr, first_string.c_str(), + first_string.size(), &out_length); + std::string sha1_as_str(sha_1, out_length); + EXPECT_EQ(sha1_as_str.size(), sha1_size); + EXPECT_EQ(sha1_as_str, expected_first_result); + + 
const char* sha_2 = gandiva::gdv_hash_using_sha1(ctx_ptr, second_string.c_str(), + second_string.size(), &out_length); + std::string sha2_as_str(sha_2, out_length); + EXPECT_EQ(sha2_as_str.size(), sha1_size); + EXPECT_EQ(sha2_as_str, expected_second_result); +} + +TEST(TestShaHashUtils, TestSha256Varlen) { + gandiva::ExecutionContext ctx; + + auto ctx_ptr = reinterpret_cast<int64_t>(&ctx); + + std::string first_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ]"; + + std::string second_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ"; + + // The strings expected hashes are obtained from shell executing the following command: + // echo -n <output-string> | openssl dgst sha1 + std::string expected_first_result = + "55aeb2e789871dbd289edae94d4c1c82a1c25ca0bcd5a873924da2fefdd57acb"; + std::string expected_second_result = + "86b29c13d0d0e26ea8f85bfa649dc9b8622ae59a4da2409d7d9b463e86e796f2"; + + // Generate the hashes and compare with expected outputs + const int sha256_size = 64; + int out_length; + + const char* sha_1 = gandiva::gdv_hash_using_sha256(ctx_ptr, first_string.c_str(), + first_string.size(), &out_length); + std::string sha1_as_str(sha_1, out_length); + EXPECT_EQ(sha1_as_str.size(), sha256_size); + EXPECT_EQ(sha1_as_str, expected_first_result); + + const char* sha_2 = gandiva::gdv_hash_using_sha256(ctx_ptr, second_string.c_str(), + second_string.size(), &out_length); + std::string sha2_as_str(sha_2, out_length); + EXPECT_EQ(sha2_as_str.size(), sha256_size); + EXPECT_EQ(sha2_as_str, expected_second_result); +} diff --git a/src/arrow/cpp/src/gandiva/in_holder.h b/src/arrow/cpp/src/gandiva/in_holder.h new file mode 100644 index 000000000..d55ab5ec5 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/in_holder.h @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <string> +#include <unordered_set> + +#include "arrow/util/hashing.h" +#include "gandiva/arrow.h" +#include "gandiva/decimal_scalar.h" +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +/// Function Holder for IN Expressions +template <typename Type> +class InHolder { + public: + explicit InHolder(const std::unordered_set<Type>& values) { + values_.max_load_factor(0.25f); + for (auto& value : values) { + values_.insert(value); + } + } + + bool HasValue(Type value) const { return values_.count(value) == 1; } + + private: + std::unordered_set<Type> values_; +}; + +template <> +class InHolder<gandiva::DecimalScalar128> { + public: + explicit InHolder(const std::unordered_set<gandiva::DecimalScalar128>& values) { + values_.max_load_factor(0.25f); + for (auto& value : values) { + values_.insert(value); + } + } + + bool HasValue(gandiva::DecimalScalar128 value) const { + return values_.count(value) == 1; + } + + private: + std::unordered_set<gandiva::DecimalScalar128> values_; +}; + +template <> +class InHolder<std::string> { + public: + explicit InHolder(std::unordered_set<std::string> values) : values_(std::move(values)) { + values_lookup_.max_load_factor(0.25f); + for (const std::string& value : values_) { + 
values_lookup_.emplace(value); + } + } + + bool HasValue(arrow::util::string_view value) const { + return values_lookup_.count(value) == 1; + } + + private: + struct string_view_hash { + public: + std::size_t operator()(arrow::util::string_view v) const { + return arrow::internal::ComputeStringHash<0>(v.data(), v.length()); + } + }; + + std::unordered_set<arrow::util::string_view, string_view_hash> values_lookup_; + const std::unordered_set<std::string> values_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/jni/CMakeLists.txt b/src/arrow/cpp/src/gandiva/jni/CMakeLists.txt new file mode 100644 index 000000000..046934141 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/jni/CMakeLists.txt @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+if(CMAKE_VERSION VERSION_LESS 3.11)
+  message(FATAL_ERROR "Building the Gandiva JNI bindings requires CMake version >= 3.11")
+endif()
+
+if(MSVC)
+  add_definitions(-DPROTOBUF_USE_DLLS)
+endif()
+
+# Find JNI
+find_package(JNI REQUIRED)
+
+# Sources generated by protoc from Types.proto; they are written to the binary dir.
+set(PROTO_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR})
+set(PROTO_SRCS "${PROTO_OUTPUT_DIR}/Types.pb.cc")
+set(PROTO_HDRS "${PROTO_OUTPUT_DIR}/Types.pb.h")
+set(PROTO_OUTPUT_FILES "${PROTO_SRCS}" "${PROTO_HDRS}")
+
+# The files do not exist at configure time, so flag them as generated
+set_source_files_properties(${PROTO_OUTPUT_FILES} PROPERTIES GENERATED TRUE)
+
+get_filename_component(ABS_GANDIVA_PROTO
+                       ${CMAKE_SOURCE_DIR}/src/gandiva/proto/Types.proto ABSOLUTE)
+
+add_custom_command(OUTPUT ${PROTO_OUTPUT_FILES}
+                   COMMAND ${ARROW_PROTOBUF_PROTOC} --proto_path
+                           ${CMAKE_SOURCE_DIR}/src/gandiva/proto --cpp_out
+                           ${PROTO_OUTPUT_DIR}
+                           ${CMAKE_SOURCE_DIR}/src/gandiva/proto/Types.proto
+                   DEPENDS ${ABS_GANDIVA_PROTO} ${ARROW_PROTOBUF_LIBPROTOBUF}
+                   COMMENT "Running PROTO compiler on Types.proto"
+                   VERBATIM)
+
+add_custom_target(gandiva_jni_proto ALL DEPENDS ${PROTO_OUTPUT_FILES})
+
+# Create the jni header file (from the java class).
+set(JNI_HEADERS_DIR "${CMAKE_CURRENT_BINARY_DIR}/java")
+add_subdirectory(../../../../java/gandiva ./java/gandiva)
+
+# Link against protobuf plus the static or shared flavor of the gandiva library
+set(GANDIVA_LINK_LIBS ${ARROW_PROTOBUF_LIBPROTOBUF})
+if(ARROW_BUILD_STATIC)
+  list(APPEND GANDIVA_LINK_LIBS gandiva_static)
+else()
+  list(APPEND GANDIVA_LINK_LIBS gandiva_shared)
+endif()
+
+set(GANDIVA_JNI_SOURCES
+    config_builder.cc
+    config_holder.cc
+    expression_registry_helper.cc
+    jni_common.cc
+    ${PROTO_SRCS})
+
+# For users of gandiva_jni library (including integ tests), include-dir is :
+#   /usr/**/include dir after install,
+#   cpp/include during build
+# For building gandiva_jni library itself, include-dir (in addition to above) is :
+#   cpp/src
+add_arrow_lib(gandiva_jni
+              SOURCES
+              ${GANDIVA_JNI_SOURCES}
+              OUTPUTS
+              GANDIVA_JNI_LIBRARIES
+              SHARED_PRIVATE_LINK_LIBS
+              ${GANDIVA_LINK_LIBS}
+              STATIC_LINK_LIBS
+              ${GANDIVA_LINK_LIBS}
+              DEPENDENCIES
+              ${GANDIVA_LINK_LIBS}
+              gandiva_java
+              gandiva_jni_headers
+              gandiva_jni_proto
+              EXTRA_INCLUDES
+              $<INSTALL_INTERFACE:include>
+              $<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/include>
+              $<BUILD_INTERFACE:${JNI_HEADERS_DIR}>
+              PRIVATE_INCLUDES
+              ${JNI_INCLUDE_DIRS}
+              ${CMAKE_CURRENT_BINARY_DIR})
+
+add_dependencies(gandiva ${GANDIVA_JNI_LIBRARIES})
+
+# filter out everything that is not needed for the jni bridge
+# statically linked stdc++ has conflicts with stdc++ loaded by other libraries.
+if(ARROW_BUILD_SHARED AND CXX_LINKER_SUPPORTS_VERSION_SCRIPT)
+  set_target_properties(gandiva_jni_shared
+                        PROPERTIES LINK_FLAGS
+                                   "-Wl,--version-script=${CMAKE_SOURCE_DIR}/src/gandiva/jni/symbols.map"
+  )
+endif()
diff --git a/src/arrow/cpp/src/gandiva/jni/config_builder.cc b/src/arrow/cpp/src/gandiva/jni/config_builder.cc
new file mode 100644
index 000000000..b115210ce
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/config_builder.cc
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <string> + +#include "gandiva/configuration.h" +#include "gandiva/jni/config_holder.h" +#include "gandiva/jni/env_helper.h" +#include "jni/org_apache_arrow_gandiva_evaluator_ConfigurationBuilder.h" + +using gandiva::ConfigHolder; +using gandiva::Configuration; +using gandiva::ConfigurationBuilder; + +/* + * Class: org_apache_arrow_gandiva_evaluator_ConfigBuilder + * Method: buildConfigInstance + * Signature: (ZZ)J + */ +JNIEXPORT jlong JNICALL +Java_org_apache_arrow_gandiva_evaluator_ConfigurationBuilder_buildConfigInstance( + JNIEnv* env, jobject configuration, jboolean optimize, jboolean target_host_cpu) { + ConfigurationBuilder configuration_builder; + std::shared_ptr<Configuration> config = configuration_builder.build(); + config->set_optimize(optimize); + config->target_host_cpu(target_host_cpu); + return ConfigHolder::MapInsert(config); +} + +/* + * Class: org_apache_arrow_gandiva_evaluator_ConfigBuilder + * Method: releaseConfigInstance + * Signature: (J)V + */ +JNIEXPORT void JNICALL +Java_org_apache_arrow_gandiva_evaluator_ConfigurationBuilder_releaseConfigInstance( + JNIEnv* env, jobject configuration, jlong config_id) { + ConfigHolder::MapErase(config_id); +} diff --git a/src/arrow/cpp/src/gandiva/jni/config_holder.cc 
b/src/arrow/cpp/src/gandiva/jni/config_holder.cc
new file mode 100644
index 000000000..11d305c81
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/config_holder.cc
@@ -0,0 +1,30 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/jni/config_holder.h"
+
+#include <cstdint>
+
+// Out-of-class definitions of the static data members of ConfigHolder.
+namespace gandiva {
+// id counter, starting at 1
+int64_t ConfigHolder::config_id_ = 1;
+
+// map of configuration objects created so far
+std::unordered_map<int64_t, std::shared_ptr<Configuration>>
+    ConfigHolder::configuration_map_;
+
+// mutex locked by ConfigHolder::MapInsert and ConfigHolder::MapErase
+std::mutex ConfigHolder::g_mtx_;
+}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/jni/config_holder.h b/src/arrow/cpp/src/gandiva/jni/config_holder.h
new file mode 100644
index 000000000..3fdb7a01d
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/config_holder.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <mutex> +#include <unordered_map> +#include <utility> + +#include "gandiva/configuration.h" + +namespace gandiva { + +class ConfigHolder { + public: + static int64_t MapInsert(std::shared_ptr<Configuration> config) { + g_mtx_.lock(); + + int64_t result = config_id_++; + configuration_map_.insert( + std::pair<int64_t, std::shared_ptr<Configuration>>(result, config)); + + g_mtx_.unlock(); + return result; + } + + static void MapErase(int64_t config_id_) { + g_mtx_.lock(); + configuration_map_.erase(config_id_); + g_mtx_.unlock(); + } + + static std::shared_ptr<Configuration> MapLookup(int64_t config_id_) { + std::shared_ptr<Configuration> result = nullptr; + + try { + result = configuration_map_.at(config_id_); + } catch (const std::out_of_range&) { + } + + return result; + } + + private: + // map of configuration objects created so far + static std::unordered_map<int64_t, std::shared_ptr<Configuration>> configuration_map_; + + static std::mutex g_mtx_; + + // atomic counter for projector module ids + static int64_t config_id_; +}; +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/jni/env_helper.h b/src/arrow/cpp/src/gandiva/jni/env_helper.h new file mode 100644 index 000000000..5ae13c807 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/jni/env_helper.h @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <jni.h>
+
+// class references
+// (declaration only; the definition is not in this header)
+extern jclass configuration_builder_class_;
diff --git a/src/arrow/cpp/src/gandiva/jni/expression_registry_helper.cc b/src/arrow/cpp/src/gandiva/jni/expression_registry_helper.cc
new file mode 100644
index 000000000..0d1f74ba6
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/jni/expression_registry_helper.cc
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#include "jni/org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper.h" + +#include <memory> + +#include "Types.pb.h" +#include "arrow/util/logging.h" +#include "gandiva/arrow.h" +#include "gandiva/expression_registry.h" + +using gandiva::DataTypePtr; +using gandiva::ExpressionRegistry; + +types::TimeUnit MapTimeUnit(arrow::TimeUnit::type& unit) { + switch (unit) { + case arrow::TimeUnit::MILLI: + return types::TimeUnit::MILLISEC; + case arrow::TimeUnit::SECOND: + return types::TimeUnit::SEC; + case arrow::TimeUnit::MICRO: + return types::TimeUnit::MICROSEC; + case arrow::TimeUnit::NANO: + return types::TimeUnit::NANOSEC; + } + // satisfy gcc. should be unreachable. + return types::TimeUnit::SEC; +} + +void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type) { + switch (type->id()) { + case arrow::Type::BOOL: + gandiva_data_type->set_type(types::GandivaType::BOOL); + break; + case arrow::Type::UINT8: + gandiva_data_type->set_type(types::GandivaType::UINT8); + break; + case arrow::Type::INT8: + gandiva_data_type->set_type(types::GandivaType::INT8); + break; + case arrow::Type::UINT16: + gandiva_data_type->set_type(types::GandivaType::UINT16); + break; + case arrow::Type::INT16: + gandiva_data_type->set_type(types::GandivaType::INT16); + break; + case arrow::Type::UINT32: + gandiva_data_type->set_type(types::GandivaType::UINT32); + break; + case arrow::Type::INT32: + gandiva_data_type->set_type(types::GandivaType::INT32); + break; + case arrow::Type::UINT64: + gandiva_data_type->set_type(types::GandivaType::UINT64); + break; + case arrow::Type::INT64: + gandiva_data_type->set_type(types::GandivaType::INT64); + break; + case arrow::Type::HALF_FLOAT: + gandiva_data_type->set_type(types::GandivaType::HALF_FLOAT); + break; + case arrow::Type::FLOAT: + gandiva_data_type->set_type(types::GandivaType::FLOAT); + break; + case arrow::Type::DOUBLE: + gandiva_data_type->set_type(types::GandivaType::DOUBLE); + break; + case arrow::Type::STRING: 
+ gandiva_data_type->set_type(types::GandivaType::UTF8); + break; + case arrow::Type::BINARY: + gandiva_data_type->set_type(types::GandivaType::BINARY); + break; + case arrow::Type::DATE32: + gandiva_data_type->set_type(types::GandivaType::DATE32); + break; + case arrow::Type::DATE64: + gandiva_data_type->set_type(types::GandivaType::DATE64); + break; + case arrow::Type::TIMESTAMP: { + gandiva_data_type->set_type(types::GandivaType::TIMESTAMP); + std::shared_ptr<arrow::TimestampType> cast_time_stamp_type = + std::dynamic_pointer_cast<arrow::TimestampType>(type); + arrow::TimeUnit::type unit = cast_time_stamp_type->unit(); + types::TimeUnit time_unit = MapTimeUnit(unit); + gandiva_data_type->set_timeunit(time_unit); + break; + } + case arrow::Type::TIME32: { + gandiva_data_type->set_type(types::GandivaType::TIME32); + std::shared_ptr<arrow::Time32Type> cast_time_32_type = + std::dynamic_pointer_cast<arrow::Time32Type>(type); + arrow::TimeUnit::type unit = cast_time_32_type->unit(); + types::TimeUnit time_unit = MapTimeUnit(unit); + gandiva_data_type->set_timeunit(time_unit); + break; + } + case arrow::Type::TIME64: { + gandiva_data_type->set_type(types::GandivaType::TIME32); + std::shared_ptr<arrow::Time64Type> cast_time_64_type = + std::dynamic_pointer_cast<arrow::Time64Type>(type); + arrow::TimeUnit::type unit = cast_time_64_type->unit(); + types::TimeUnit time_unit = MapTimeUnit(unit); + gandiva_data_type->set_timeunit(time_unit); + break; + } + case arrow::Type::NA: + gandiva_data_type->set_type(types::GandivaType::NONE); + break; + case arrow::Type::DECIMAL: { + gandiva_data_type->set_type(types::GandivaType::DECIMAL); + gandiva_data_type->set_precision(0); + gandiva_data_type->set_scale(0); + break; + } + case arrow::Type::INTERVAL_MONTHS: + gandiva_data_type->set_type(types::GandivaType::INTERVAL); + gandiva_data_type->set_intervaltype(types::IntervalType::YEAR_MONTH); + break; + case arrow::Type::INTERVAL_DAY_TIME: + 
gandiva_data_type->set_type(types::GandivaType::INTERVAL); + gandiva_data_type->set_intervaltype(types::IntervalType::DAY_TIME); + break; + default: + // un-supported types. test ensures that + // when one of these are added build breaks. + DCHECK(false); + } +} + +JNIEXPORT jbyteArray JNICALL +Java_org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper_getGandivaSupportedDataTypes( // NOLINT + JNIEnv* env, jobject types_helper) { + types::GandivaDataTypes gandiva_data_types; + auto supported_types = ExpressionRegistry::supported_types(); + for (auto const& type : supported_types) { + types::ExtGandivaType* gandiva_data_type = gandiva_data_types.add_datatype(); + ArrowToProtobuf(type, gandiva_data_type); + } + auto size = gandiva_data_types.ByteSizeLong(); + std::unique_ptr<jbyte[]> buffer{new jbyte[size]}; + gandiva_data_types.SerializeToArray(reinterpret_cast<void*>(buffer.get()), size); + jbyteArray ret = env->NewByteArray(size); + env->SetByteArrayRegion(ret, 0, size, buffer.get()); + return ret; +} + +/* + * Class: org_apache_arrow_gandiva_types_ExpressionRegistryJniHelper + * Method: getGandivaSupportedFunctions + * Signature: ()[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper_getGandivaSupportedFunctions( // NOLINT + JNIEnv* env, jobject types_helper) { + ExpressionRegistry expr_registry; + types::GandivaFunctions gandiva_functions; + for (auto function = expr_registry.function_signature_begin(); + function != expr_registry.function_signature_end(); function++) { + types::FunctionSignature* function_signature = gandiva_functions.add_function(); + function_signature->set_name((*function).base_name()); + types::ExtGandivaType* return_type = function_signature->mutable_returntype(); + ArrowToProtobuf((*function).ret_type(), return_type); + for (auto& param_type : (*function).param_types()) { + types::ExtGandivaType* proto_param_type = function_signature->add_paramtypes(); + 
ArrowToProtobuf(param_type, proto_param_type); + } + } + auto size = gandiva_functions.ByteSizeLong(); + std::unique_ptr<jbyte[]> buffer{new jbyte[size]}; + gandiva_functions.SerializeToArray(reinterpret_cast<void*>(buffer.get()), size); + jbyteArray ret = env->NewByteArray(size); + env->SetByteArrayRegion(ret, 0, size, buffer.get()); + return ret; +} diff --git a/src/arrow/cpp/src/gandiva/jni/id_to_module_map.h b/src/arrow/cpp/src/gandiva/jni/id_to_module_map.h new file mode 100644 index 000000000..98100955b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/jni/id_to_module_map.h @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <cstdint> +#include <memory> +#include <unordered_map> +#include <utility> + +namespace gandiva { + +template <typename HOLDER> +class IdToModuleMap { + public: + IdToModuleMap() : module_id_(kInitModuleId) {} + + jlong Insert(HOLDER holder) { + mtx_.lock(); + jlong result = module_id_++; + map_.insert(std::pair<jlong, HOLDER>(result, holder)); + mtx_.unlock(); + return result; + } + + void Erase(jlong module_id) { + mtx_.lock(); + map_.erase(module_id); + mtx_.unlock(); + } + + HOLDER Lookup(jlong module_id) { + HOLDER result = nullptr; + mtx_.lock(); + try { + result = map_.at(module_id); + } catch (const std::out_of_range&) { + } + mtx_.unlock(); + return result; + } + + private: + static const int kInitModuleId = 4; + + int64_t module_id_; + std::mutex mtx_; + // map from module ids returned to Java and module pointers + std::unordered_map<jlong, HOLDER> map_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/jni/jni_common.cc b/src/arrow/cpp/src/gandiva/jni/jni_common.cc new file mode 100644 index 000000000..5a4cbb031 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/jni/jni_common.cc @@ -0,0 +1,1055 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <google/protobuf/io/coded_stream.h> + +#include <map> +#include <memory> +#include <mutex> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include <arrow/builder.h> +#include <arrow/record_batch.h> +#include <arrow/type.h> + +#include "Types.pb.h" +#include "gandiva/configuration.h" +#include "gandiva/decimal_scalar.h" +#include "gandiva/filter.h" +#include "gandiva/jni/config_holder.h" +#include "gandiva/jni/env_helper.h" +#include "gandiva/jni/id_to_module_map.h" +#include "gandiva/jni/module_holder.h" +#include "gandiva/projector.h" +#include "gandiva/selection_vector.h" +#include "gandiva/tree_expr_builder.h" +#include "jni/org_apache_arrow_gandiva_evaluator_JniWrapper.h" + +using gandiva::ConditionPtr; +using gandiva::DataTypePtr; +using gandiva::ExpressionPtr; +using gandiva::ExpressionVector; +using gandiva::FieldPtr; +using gandiva::FieldVector; +using gandiva::Filter; +using gandiva::NodePtr; +using gandiva::NodeVector; +using gandiva::Projector; +using gandiva::SchemaPtr; +using gandiva::Status; +using gandiva::TreeExprBuilder; + +using gandiva::ArrayDataVector; +using gandiva::ConfigHolder; +using gandiva::Configuration; +using gandiva::ConfigurationBuilder; +using gandiva::FilterHolder; +using gandiva::ProjectorHolder; + +// forward declarations +NodePtr ProtoTypeToNode(const types::TreeNode& node); + +static jint JNI_VERSION = JNI_VERSION_1_6; + +// extern refs - initialized for other modules. +jclass configuration_builder_class_; + +// refs for self. 
+static jclass gandiva_exception_; +static jclass vector_expander_class_; +static jclass vector_expander_ret_class_; +static jmethodID vector_expander_method_; +static jfieldID vector_expander_ret_address_; +static jfieldID vector_expander_ret_capacity_; + +// module maps +gandiva::IdToModuleMap<std::shared_ptr<ProjectorHolder>> projector_modules_; +gandiva::IdToModuleMap<std::shared_ptr<FilterHolder>> filter_modules_; + +jint JNI_OnLoad(JavaVM* vm, void* reserved) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + jclass local_configuration_builder_class_ = + env->FindClass("org/apache/arrow/gandiva/evaluator/ConfigurationBuilder"); + configuration_builder_class_ = + (jclass)env->NewGlobalRef(local_configuration_builder_class_); + env->DeleteLocalRef(local_configuration_builder_class_); + + jclass localExceptionClass = + env->FindClass("org/apache/arrow/gandiva/exceptions/GandivaException"); + gandiva_exception_ = (jclass)env->NewGlobalRef(localExceptionClass); + env->ExceptionDescribe(); + env->DeleteLocalRef(localExceptionClass); + + jclass local_expander_class = + env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander"); + vector_expander_class_ = (jclass)env->NewGlobalRef(local_expander_class); + env->DeleteLocalRef(local_expander_class); + + vector_expander_method_ = env->GetMethodID( + vector_expander_class_, "expandOutputVectorAtIndex", + "(IJ)Lorg/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult;"); + + jclass local_expander_ret_class = + env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult"); + vector_expander_ret_class_ = (jclass)env->NewGlobalRef(local_expander_ret_class); + env->DeleteLocalRef(local_expander_ret_class); + + vector_expander_ret_address_ = + env->GetFieldID(vector_expander_ret_class_, "address", "J"); + vector_expander_ret_capacity_ = + env->GetFieldID(vector_expander_ret_class_, "capacity", "J"); + return JNI_VERSION; +} + 
+void JNI_OnUnload(JavaVM* vm, void* reserved) { + JNIEnv* env; + vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION); + env->DeleteGlobalRef(configuration_builder_class_); + env->DeleteGlobalRef(gandiva_exception_); + env->DeleteGlobalRef(vector_expander_class_); + env->DeleteGlobalRef(vector_expander_ret_class_); +} + +DataTypePtr ProtoTypeToTime32(const types::ExtGandivaType& ext_type) { + switch (ext_type.timeunit()) { + case types::SEC: + return arrow::time32(arrow::TimeUnit::SECOND); + case types::MILLISEC: + return arrow::time32(arrow::TimeUnit::MILLI); + default: + std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time32\n"; + return nullptr; + } +} + +DataTypePtr ProtoTypeToTime64(const types::ExtGandivaType& ext_type) { + switch (ext_type.timeunit()) { + case types::MICROSEC: + return arrow::time64(arrow::TimeUnit::MICRO); + case types::NANOSEC: + return arrow::time64(arrow::TimeUnit::NANO); + default: + std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for time64\n"; + return nullptr; + } +} + +DataTypePtr ProtoTypeToTimestamp(const types::ExtGandivaType& ext_type) { + switch (ext_type.timeunit()) { + case types::SEC: + return arrow::timestamp(arrow::TimeUnit::SECOND); + case types::MILLISEC: + return arrow::timestamp(arrow::TimeUnit::MILLI); + case types::MICROSEC: + return arrow::timestamp(arrow::TimeUnit::MICRO); + case types::NANOSEC: + return arrow::timestamp(arrow::TimeUnit::NANO); + default: + std::cerr << "Unknown time unit: " << ext_type.timeunit() << " for timestamp\n"; + return nullptr; + } +} + +DataTypePtr ProtoTypeToInterval(const types::ExtGandivaType& ext_type) { + switch (ext_type.intervaltype()) { + case types::YEAR_MONTH: + return arrow::month_interval(); + case types::DAY_TIME: + return arrow::day_time_interval(); + default: + std::cerr << "Unknown interval type: " << ext_type.intervaltype() << "\n"; + return nullptr; + } +} + +DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { + 
switch (ext_type.type()) { + case types::NONE: + return arrow::null(); + case types::BOOL: + return arrow::boolean(); + case types::UINT8: + return arrow::uint8(); + case types::INT8: + return arrow::int8(); + case types::UINT16: + return arrow::uint16(); + case types::INT16: + return arrow::int16(); + case types::UINT32: + return arrow::uint32(); + case types::INT32: + return arrow::int32(); + case types::UINT64: + return arrow::uint64(); + case types::INT64: + return arrow::int64(); + case types::HALF_FLOAT: + return arrow::float16(); + case types::FLOAT: + return arrow::float32(); + case types::DOUBLE: + return arrow::float64(); + case types::UTF8: + return arrow::utf8(); + case types::BINARY: + return arrow::binary(); + case types::DATE32: + return arrow::date32(); + case types::DATE64: + return arrow::date64(); + case types::DECIMAL: + // TODO: error handling + return arrow::decimal(ext_type.precision(), ext_type.scale()); + case types::TIME32: + return ProtoTypeToTime32(ext_type); + case types::TIME64: + return ProtoTypeToTime64(ext_type); + case types::TIMESTAMP: + return ProtoTypeToTimestamp(ext_type); + case types::INTERVAL: + return ProtoTypeToInterval(ext_type); + case types::FIXED_SIZE_BINARY: + case types::LIST: + case types::STRUCT: + case types::UNION: + case types::DICTIONARY: + case types::MAP: + std::cerr << "Unhandled data type: " << ext_type.type() << "\n"; + return nullptr; + + default: + std::cerr << "Unknown data type: " << ext_type.type() << "\n"; + return nullptr; + } +} + +FieldPtr ProtoTypeToField(const types::Field& f) { + const std::string& name = f.name(); + DataTypePtr type = ProtoTypeToDataType(f.type()); + bool nullable = true; + if (f.has_nullable()) { + nullable = f.nullable(); + } + + return field(name, type, nullable); +} + +NodePtr ProtoTypeToFieldNode(const types::FieldNode& node) { + FieldPtr field_ptr = ProtoTypeToField(node.field()); + if (field_ptr == nullptr) { + std::cerr << "Unable to create field node from protobuf\n"; 
+ return nullptr; + } + + return TreeExprBuilder::MakeField(field_ptr); +} + +NodePtr ProtoTypeToFnNode(const types::FunctionNode& node) { + const std::string& name = node.functionname(); + NodeVector children; + + for (int i = 0; i < node.inargs_size(); i++) { + const types::TreeNode& arg = node.inargs(i); + + NodePtr n = ProtoTypeToNode(arg); + if (n == nullptr) { + std::cerr << "Unable to create argument for function: " << name << "\n"; + return nullptr; + } + + children.push_back(n); + } + + DataTypePtr return_type = ProtoTypeToDataType(node.returntype()); + if (return_type == nullptr) { + std::cerr << "Unknown return type for function: " << name << "\n"; + return nullptr; + } + + return TreeExprBuilder::MakeFunction(name, children, return_type); +} + +NodePtr ProtoTypeToIfNode(const types::IfNode& node) { + NodePtr cond = ProtoTypeToNode(node.cond()); + if (cond == nullptr) { + std::cerr << "Unable to create cond node for if node\n"; + return nullptr; + } + + NodePtr then_node = ProtoTypeToNode(node.thennode()); + if (then_node == nullptr) { + std::cerr << "Unable to create then node for if node\n"; + return nullptr; + } + + NodePtr else_node = ProtoTypeToNode(node.elsenode()); + if (else_node == nullptr) { + std::cerr << "Unable to create else node for if node\n"; + return nullptr; + } + + DataTypePtr return_type = ProtoTypeToDataType(node.returntype()); + if (return_type == nullptr) { + std::cerr << "Unknown return type for if node\n"; + return nullptr; + } + + return TreeExprBuilder::MakeIf(cond, then_node, else_node, return_type); +} + +NodePtr ProtoTypeToAndNode(const types::AndNode& node) { + NodeVector children; + + for (int i = 0; i < node.args_size(); i++) { + const types::TreeNode& arg = node.args(i); + + NodePtr n = ProtoTypeToNode(arg); + if (n == nullptr) { + std::cerr << "Unable to create argument for boolean and\n"; + return nullptr; + } + children.push_back(n); + } + return TreeExprBuilder::MakeAnd(children); +} + +NodePtr 
ProtoTypeToOrNode(const types::OrNode& node) { + NodeVector children; + + for (int i = 0; i < node.args_size(); i++) { + const types::TreeNode& arg = node.args(i); + + NodePtr n = ProtoTypeToNode(arg); + if (n == nullptr) { + std::cerr << "Unable to create argument for boolean or\n"; + return nullptr; + } + children.push_back(n); + } + return TreeExprBuilder::MakeOr(children); +} + +NodePtr ProtoTypeToInNode(const types::InNode& node) { + NodePtr field = ProtoTypeToNode(node.node()); + + if (node.has_intvalues()) { + std::unordered_set<int32_t> int_values; + for (int i = 0; i < node.intvalues().intvalues_size(); i++) { + int_values.insert(node.intvalues().intvalues(i).value()); + } + return TreeExprBuilder::MakeInExpressionInt32(field, int_values); + } + + if (node.has_longvalues()) { + std::unordered_set<int64_t> long_values; + for (int i = 0; i < node.longvalues().longvalues_size(); i++) { + long_values.insert(node.longvalues().longvalues(i).value()); + } + return TreeExprBuilder::MakeInExpressionInt64(field, long_values); + } + + if (node.has_decimalvalues()) { + std::unordered_set<gandiva::DecimalScalar128> decimal_values; + for (int i = 0; i < node.decimalvalues().decimalvalues_size(); i++) { + decimal_values.insert( + gandiva::DecimalScalar128(node.decimalvalues().decimalvalues(i).value(), + node.decimalvalues().decimalvalues(i).precision(), + node.decimalvalues().decimalvalues(i).scale())); + } + return TreeExprBuilder::MakeInExpressionDecimal(field, decimal_values); + } + + if (node.has_floatvalues()) { + std::unordered_set<float> float_values; + for (int i = 0; i < node.floatvalues().floatvalues_size(); i++) { + float_values.insert(node.floatvalues().floatvalues(i).value()); + } + return TreeExprBuilder::MakeInExpressionFloat(field, float_values); + } + + if (node.has_doublevalues()) { + std::unordered_set<double> double_values; + for (int i = 0; i < node.doublevalues().doublevalues_size(); i++) { + 
double_values.insert(node.doublevalues().doublevalues(i).value()); + } + return TreeExprBuilder::MakeInExpressionDouble(field, double_values); + } + + if (node.has_stringvalues()) { + std::unordered_set<std::string> stringvalues; + for (int i = 0; i < node.stringvalues().stringvalues_size(); i++) { + stringvalues.insert(node.stringvalues().stringvalues(i).value()); + } + return TreeExprBuilder::MakeInExpressionString(field, stringvalues); + } + + if (node.has_binaryvalues()) { + std::unordered_set<std::string> stringvalues; + for (int i = 0; i < node.binaryvalues().binaryvalues_size(); i++) { + stringvalues.insert(node.binaryvalues().binaryvalues(i).value()); + } + return TreeExprBuilder::MakeInExpressionBinary(field, stringvalues); + } + // not supported yet. + std::cerr << "Unknown constant type for in expression.\n"; + return nullptr; +} + +NodePtr ProtoTypeToNullNode(const types::NullNode& node) { + DataTypePtr data_type = ProtoTypeToDataType(node.type()); + if (data_type == nullptr) { + std::cerr << "Unknown type " << data_type->ToString() << " for null node\n"; + return nullptr; + } + + return TreeExprBuilder::MakeNull(data_type); +} + +NodePtr ProtoTypeToNode(const types::TreeNode& node) { + if (node.has_fieldnode()) { + return ProtoTypeToFieldNode(node.fieldnode()); + } + + if (node.has_fnnode()) { + return ProtoTypeToFnNode(node.fnnode()); + } + + if (node.has_ifnode()) { + return ProtoTypeToIfNode(node.ifnode()); + } + + if (node.has_andnode()) { + return ProtoTypeToAndNode(node.andnode()); + } + + if (node.has_ornode()) { + return ProtoTypeToOrNode(node.ornode()); + } + + if (node.has_innode()) { + return ProtoTypeToInNode(node.innode()); + } + + if (node.has_nullnode()) { + return ProtoTypeToNullNode(node.nullnode()); + } + + if (node.has_intnode()) { + return TreeExprBuilder::MakeLiteral(node.intnode().value()); + } + + if (node.has_floatnode()) { + return TreeExprBuilder::MakeLiteral(node.floatnode().value()); + } + + if (node.has_longnode()) { + 
return TreeExprBuilder::MakeLiteral(node.longnode().value()); + } + + if (node.has_booleannode()) { + return TreeExprBuilder::MakeLiteral(node.booleannode().value()); + } + + if (node.has_doublenode()) { + return TreeExprBuilder::MakeLiteral(node.doublenode().value()); + } + + if (node.has_stringnode()) { + return TreeExprBuilder::MakeStringLiteral(node.stringnode().value()); + } + + if (node.has_binarynode()) { + return TreeExprBuilder::MakeBinaryLiteral(node.binarynode().value()); + } + + if (node.has_decimalnode()) { + std::string value = node.decimalnode().value(); + gandiva::DecimalScalar128 literal(value, node.decimalnode().precision(), + node.decimalnode().scale()); + return TreeExprBuilder::MakeDecimalLiteral(literal); + } + std::cerr << "Unknown node type in protobuf\n"; + return nullptr; +} + +ExpressionPtr ProtoTypeToExpression(const types::ExpressionRoot& root) { + NodePtr root_node = ProtoTypeToNode(root.root()); + if (root_node == nullptr) { + std::cerr << "Unable to create expression node from expression protobuf\n"; + return nullptr; + } + + FieldPtr field = ProtoTypeToField(root.resulttype()); + if (field == nullptr) { + std::cerr << "Unable to extra return field from expression protobuf\n"; + return nullptr; + } + + return TreeExprBuilder::MakeExpression(root_node, field); +} + +ConditionPtr ProtoTypeToCondition(const types::Condition& condition) { + NodePtr root_node = ProtoTypeToNode(condition.root()); + if (root_node == nullptr) { + return nullptr; + } + + return TreeExprBuilder::MakeCondition(root_node); +} + +SchemaPtr ProtoTypeToSchema(const types::Schema& schema) { + std::vector<FieldPtr> fields; + + for (int i = 0; i < schema.columns_size(); i++) { + FieldPtr field = ProtoTypeToField(schema.columns(i)); + if (field == nullptr) { + std::cerr << "Unable to extract arrow field from schema\n"; + return nullptr; + } + + fields.push_back(field); + } + + return arrow::schema(fields); +} + +// Common for both projector and filters. 
+ +bool ParseProtobuf(uint8_t* buf, int bufLen, google::protobuf::Message* msg) { + google::protobuf::io::CodedInputStream cis(buf, bufLen); + cis.SetRecursionLimit(1000); + return msg->ParseFromCodedStream(&cis); +} + +Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows, + jlong* in_buf_addrs, jlong* in_buf_sizes, + int in_bufs_len, + std::shared_ptr<arrow::RecordBatch>* batch) { + std::vector<std::shared_ptr<arrow::ArrayData>> columns; + auto num_fields = schema->num_fields(); + int buf_idx = 0; + int sz_idx = 0; + + for (int i = 0; i < num_fields; i++) { + auto field = schema->field(i); + std::vector<std::shared_ptr<arrow::Buffer>> buffers; + + if (buf_idx >= in_bufs_len) { + return Status::Invalid("insufficient number of in_buf_addrs"); + } + jlong validity_addr = in_buf_addrs[buf_idx++]; + jlong validity_size = in_buf_sizes[sz_idx++]; + auto validity = std::shared_ptr<arrow::Buffer>( + new arrow::Buffer(reinterpret_cast<uint8_t*>(validity_addr), validity_size)); + buffers.push_back(validity); + + if (buf_idx >= in_bufs_len) { + return Status::Invalid("insufficient number of in_buf_addrs"); + } + jlong value_addr = in_buf_addrs[buf_idx++]; + jlong value_size = in_buf_sizes[sz_idx++]; + auto data = std::shared_ptr<arrow::Buffer>( + new arrow::Buffer(reinterpret_cast<uint8_t*>(value_addr), value_size)); + buffers.push_back(data); + + if (arrow::is_binary_like(field->type()->id())) { + if (buf_idx >= in_bufs_len) { + return Status::Invalid("insufficient number of in_buf_addrs"); + } + + // add offsets buffer for variable-len fields. 
+ jlong offsets_addr = in_buf_addrs[buf_idx++]; + jlong offsets_size = in_buf_sizes[sz_idx++]; + auto offsets = std::shared_ptr<arrow::Buffer>( + new arrow::Buffer(reinterpret_cast<uint8_t*>(offsets_addr), offsets_size)); + buffers.push_back(offsets); + } + + auto array_data = arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers)); + columns.push_back(array_data); + } + *batch = arrow::RecordBatch::Make(schema, num_rows, columns); + return Status::OK(); +} + +// projector related functions. +void releaseProjectorInput(jbyteArray schema_arr, jbyte* schema_bytes, + jbyteArray exprs_arr, jbyte* exprs_bytes, JNIEnv* env) { + env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT); + env->ReleaseByteArrayElements(exprs_arr, exprs_bytes, JNI_ABORT); +} + +JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildProjector( + JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray exprs_arr, + jint selection_vector_type, jlong configuration_id) { + jlong module_id = 0LL; + std::shared_ptr<Projector> projector; + std::shared_ptr<ProjectorHolder> holder; + + types::Schema schema; + jsize schema_len = env->GetArrayLength(schema_arr); + jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0); + + types::ExpressionList exprs; + jsize exprs_len = env->GetArrayLength(exprs_arr); + jbyte* exprs_bytes = env->GetByteArrayElements(exprs_arr, 0); + + ExpressionVector expr_vector; + SchemaPtr schema_ptr; + FieldVector ret_types; + gandiva::Status status; + auto mode = gandiva::SelectionVector::MODE_NONE; + + std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id); + std::stringstream ss; + + if (config == nullptr) { + ss << "configuration is mandatory."; + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) { + ss << "Unable to parse schema protobuf\n"; + 
releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + if (!ParseProtobuf(reinterpret_cast<uint8_t*>(exprs_bytes), exprs_len, &exprs)) { + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + ss << "Unable to parse expressions protobuf\n"; + goto err_out; + } + + // convert types::Schema to arrow::Schema + schema_ptr = ProtoTypeToSchema(schema); + if (schema_ptr == nullptr) { + ss << "Unable to construct arrow schema object from schema protobuf\n"; + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + // create Expression out of the list of exprs + for (int i = 0; i < exprs.exprs_size(); i++) { + ExpressionPtr root = ProtoTypeToExpression(exprs.exprs(i)); + + if (root == nullptr) { + ss << "Unable to construct expression object from expression protobuf\n"; + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + expr_vector.push_back(root); + ret_types.push_back(root->result()); + } + + switch (selection_vector_type) { + case types::SV_NONE: + mode = gandiva::SelectionVector::MODE_NONE; + break; + case types::SV_INT16: + mode = gandiva::SelectionVector::MODE_UINT16; + break; + case types::SV_INT32: + mode = gandiva::SelectionVector::MODE_UINT32; + break; + } + // good to invoke the evaluator now + status = Projector::Make(schema_ptr, expr_vector, mode, config, &projector); + + if (!status.ok()) { + ss << "Failed to make LLVM module due to " << status.message() << "\n"; + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + goto err_out; + } + + // store the result in a map + holder = std::shared_ptr<ProjectorHolder>( + new ProjectorHolder(schema_ptr, ret_types, std::move(projector))); + module_id = projector_modules_.Insert(holder); + releaseProjectorInput(schema_arr, schema_bytes, exprs_arr, exprs_bytes, env); + return module_id; + +err_out: + 
env->ThrowNew(gandiva_exception_, ss.str().c_str()); + return module_id; +} + +/// +/// \brief Resizable buffer which resizes by doing a callback into java. +/// +class JavaResizableBuffer : public arrow::ResizableBuffer { + public: + JavaResizableBuffer(JNIEnv* env, jobject jexpander, int32_t vector_idx, uint8_t* buffer, + int32_t len) + : ResizableBuffer(buffer, len), + env_(env), + jexpander_(jexpander), + vector_idx_(vector_idx) { + size_ = 0; + } + + Status Resize(const int64_t new_size, bool shrink_to_fit) override; + + Status Reserve(const int64_t new_capacity) override { + return Status::NotImplemented("reserve not implemented"); + } + + private: + JNIEnv* env_; + jobject jexpander_; + int32_t vector_idx_; +}; + +Status JavaResizableBuffer::Resize(const int64_t new_size, bool shrink_to_fit) { + if (shrink_to_fit == true) { + return Status::NotImplemented("shrink not implemented"); + } + + if (ARROW_PREDICT_TRUE(new_size < capacity())) { + // no need to expand. + size_ = new_size; + return Status::OK(); + } + + // callback into java to expand the buffer + jobject ret = + env_->CallObjectMethod(jexpander_, vector_expander_method_, vector_idx_, new_size); + if (env_->ExceptionCheck()) { + env_->ExceptionDescribe(); + env_->ExceptionClear(); + return Status::OutOfMemory("buffer expand failed in java"); + } + + jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_); + jlong ret_capacity = env_->GetLongField(ret, vector_expander_ret_capacity_); + DCHECK_GE(ret_capacity, new_size); + + data_ = reinterpret_cast<uint8_t*>(ret_address); + size_ = new_size; + capacity_ = ret_capacity; + return Status::OK(); +} + +#define CHECK_OUT_BUFFER_IDX_AND_BREAK(idx, len) \ + if (idx >= len) { \ + status = gandiva::Status::Invalid("insufficient number of out_buf_addrs"); \ + break; \ + } + +JNIEXPORT void JNICALL +Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( + JNIEnv* env, jobject object, jobject jexpander, jlong module_id, jint 
num_rows, + jlongArray buf_addrs, jlongArray buf_sizes, jint sel_vec_type, jint sel_vec_rows, + jlong sel_vec_addr, jlong sel_vec_size, jlongArray out_buf_addrs, + jlongArray out_buf_sizes) { + Status status; + std::shared_ptr<ProjectorHolder> holder = projector_modules_.Lookup(module_id); + if (holder == nullptr) { + std::stringstream ss; + ss << "Unknown module id " << module_id; + env->ThrowNew(gandiva_exception_, ss.str().c_str()); + return; + } + + int in_bufs_len = env->GetArrayLength(buf_addrs); + if (in_bufs_len != env->GetArrayLength(buf_sizes)) { + env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes"); + return; + } + + int out_bufs_len = env->GetArrayLength(out_buf_addrs); + if (out_bufs_len != env->GetArrayLength(out_buf_sizes)) { + env->ThrowNew(gandiva_exception_, + "mismatch in arraylen of out_buf_addrs and out_buf_sizes"); + return; + } + + jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0); + jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0); + + jlong* out_bufs = env->GetLongArrayElements(out_buf_addrs, 0); + jlong* out_sizes = env->GetLongArrayElements(out_buf_sizes, 0); + + do { + std::shared_ptr<arrow::RecordBatch> in_batch; + status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs, + in_buf_sizes, in_bufs_len, &in_batch); + if (!status.ok()) { + break; + } + + std::shared_ptr<gandiva::SelectionVector> selection_vector; + auto selection_buffer = std::make_shared<arrow::Buffer>( + reinterpret_cast<uint8_t*>(sel_vec_addr), sel_vec_size); + int output_row_count = 0; + switch (sel_vec_type) { + case types::SV_NONE: { + output_row_count = num_rows; + break; + } + case types::SV_INT16: { + status = gandiva::SelectionVector::MakeImmutableInt16( + sel_vec_rows, selection_buffer, &selection_vector); + output_row_count = sel_vec_rows; + break; + } + case types::SV_INT32: { + status = gandiva::SelectionVector::MakeImmutableInt32( + sel_vec_rows, selection_buffer, 
&selection_vector); + output_row_count = sel_vec_rows; + break; + } + } + if (!status.ok()) { + break; + } + + auto ret_types = holder->rettypes(); + ArrayDataVector output; + int buf_idx = 0; + int sz_idx = 0; + int output_vector_idx = 0; + for (FieldPtr field : ret_types) { + std::vector<std::shared_ptr<arrow::Buffer>> buffers; + + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t* validity_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]); + jlong bitmap_sz = out_sizes[sz_idx++]; + buffers.push_back(std::make_shared<arrow::MutableBuffer>(validity_buf, bitmap_sz)); + + if (arrow::is_binary_like(field->type()->id())) { + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t* offsets_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]); + jlong offsets_sz = out_sizes[sz_idx++]; + buffers.push_back( + std::make_shared<arrow::MutableBuffer>(offsets_buf, offsets_sz)); + } + + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t* value_buf = reinterpret_cast<uint8_t*>(out_bufs[buf_idx++]); + jlong data_sz = out_sizes[sz_idx++]; + if (arrow::is_binary_like(field->type()->id())) { + if (jexpander == nullptr) { + status = Status::Invalid( + "expression has variable len output columns, but the expander object is " + "null"); + break; + } + buffers.push_back(std::make_shared<JavaResizableBuffer>( + env, jexpander, output_vector_idx, value_buf, data_sz)); + } else { + buffers.push_back(std::make_shared<arrow::MutableBuffer>(value_buf, data_sz)); + } + + auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers); + output.push_back(array_data); + ++output_vector_idx; + } + if (!status.ok()) { + break; + } + status = holder->projector()->Evaluate(*in_batch, selection_vector.get(), output); + } while (0); + + env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT); + env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT); + env->ReleaseLongArrayElements(out_buf_addrs, out_bufs, JNI_ABORT); + 
env->ReleaseLongArrayElements(out_buf_sizes, out_sizes, JNI_ABORT); + + if (!status.ok()) { + std::stringstream ss; + ss << "Evaluate returned " << status.message() << "\n"; + env->ThrowNew(gandiva_exception_, status.message().c_str()); + return; + } +} + +JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeProjector( + JNIEnv* env, jobject cls, jlong module_id) { + projector_modules_.Erase(module_id); +} + +// filter related functions. +void releaseFilterInput(jbyteArray schema_arr, jbyte* schema_bytes, + jbyteArray condition_arr, jbyte* condition_bytes, JNIEnv* env) { + env->ReleaseByteArrayElements(schema_arr, schema_bytes, JNI_ABORT); + env->ReleaseByteArrayElements(condition_arr, condition_bytes, JNI_ABORT); +} + +JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_buildFilter( + JNIEnv* env, jobject obj, jbyteArray schema_arr, jbyteArray condition_arr, + jlong configuration_id) { + jlong module_id = 0LL; + std::shared_ptr<Filter> filter; + std::shared_ptr<FilterHolder> holder; + + types::Schema schema; + jsize schema_len = env->GetArrayLength(schema_arr); + jbyte* schema_bytes = env->GetByteArrayElements(schema_arr, 0); + + types::Condition condition; + jsize condition_len = env->GetArrayLength(condition_arr); + jbyte* condition_bytes = env->GetByteArrayElements(condition_arr, 0); + + ConditionPtr condition_ptr; + SchemaPtr schema_ptr; + gandiva::Status status; + + std::shared_ptr<Configuration> config = ConfigHolder::MapLookup(configuration_id); + std::stringstream ss; + + if (config == nullptr) { + ss << "configuration is mandatory."; + releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env); + goto err_out; + } + + if (!ParseProtobuf(reinterpret_cast<uint8_t*>(schema_bytes), schema_len, &schema)) { + ss << "Unable to parse schema protobuf\n"; + releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env); + goto err_out; + } + + if 
(!ParseProtobuf(reinterpret_cast<uint8_t*>(condition_bytes), condition_len, + &condition)) { + ss << "Unable to parse condition protobuf\n"; + releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env); + goto err_out; + } + + // convert types::Schema to arrow::Schema + schema_ptr = ProtoTypeToSchema(schema); + if (schema_ptr == nullptr) { + ss << "Unable to construct arrow schema object from schema protobuf\n"; + releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env); + goto err_out; + } + + condition_ptr = ProtoTypeToCondition(condition); + if (condition_ptr == nullptr) { + ss << "Unable to construct condition object from condition protobuf\n"; + releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env); + goto err_out; + } + + // good to invoke the filter builder now + status = Filter::Make(schema_ptr, condition_ptr, config, &filter); + if (!status.ok()) { + ss << "Failed to make LLVM module due to " << status.message() << "\n"; + releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env); + goto err_out; + } + + // store the result in a map + holder = std::shared_ptr<FilterHolder>(new FilterHolder(schema_ptr, std::move(filter))); + module_id = filter_modules_.Insert(holder); + releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env); + return module_id; + +err_out: + env->ThrowNew(gandiva_exception_, ss.str().c_str()); + return module_id; +} + +JNIEXPORT jint JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateFilter( + JNIEnv* env, jobject cls, jlong module_id, jint num_rows, jlongArray buf_addrs, + jlongArray buf_sizes, jint jselection_vector_type, jlong out_buf_addr, + jlong out_buf_size) { + gandiva::Status status; + std::shared_ptr<FilterHolder> holder = filter_modules_.Lookup(module_id); + if (holder == nullptr) { + env->ThrowNew(gandiva_exception_, "Unknown module id\n"); + return -1; + } + + int in_bufs_len = 
env->GetArrayLength(buf_addrs); + if (in_bufs_len != env->GetArrayLength(buf_sizes)) { + env->ThrowNew(gandiva_exception_, "mismatch in arraylen of buf_addrs and buf_sizes"); + return -1; + } + + jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, 0); + jlong* in_buf_sizes = env->GetLongArrayElements(buf_sizes, 0); + std::shared_ptr<gandiva::SelectionVector> selection_vector; + + do { + std::shared_ptr<arrow::RecordBatch> in_batch; + + status = make_record_batch_with_buf_addrs(holder->schema(), num_rows, in_buf_addrs, + in_buf_sizes, in_bufs_len, &in_batch); + if (!status.ok()) { + break; + } + + auto selection_vector_type = + static_cast<types::SelectionVectorType>(jselection_vector_type); + auto out_buffer = std::make_shared<arrow::MutableBuffer>( + reinterpret_cast<uint8_t*>(out_buf_addr), out_buf_size); + switch (selection_vector_type) { + case types::SV_INT16: + status = + gandiva::SelectionVector::MakeInt16(num_rows, out_buffer, &selection_vector); + break; + case types::SV_INT32: + status = + gandiva::SelectionVector::MakeInt32(num_rows, out_buffer, &selection_vector); + break; + default: + status = gandiva::Status::Invalid("unknown selection vector type"); + } + if (!status.ok()) { + break; + } + + status = holder->filter()->Evaluate(*in_batch, selection_vector); + } while (0); + + env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT); + env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT); + + if (!status.ok()) { + std::stringstream ss; + ss << "Evaluate returned " << status.message() << "\n"; + env->ThrowNew(gandiva_exception_, status.message().c_str()); + return -1; + } else { + int64_t num_slots = selection_vector->GetNumSlots(); + // Check integer overflow + if (num_slots > INT_MAX) { + std::stringstream ss; + ss << "The selection vector has " << num_slots + << " slots, which is larger than the " << INT_MAX << " limit.\n"; + const std::string message = ss.str(); + env->ThrowNew(gandiva_exception_, message.c_str()); + 
return -1; + } + return static_cast<int>(num_slots); + } +} + +JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_closeFilter( + JNIEnv* env, jobject cls, jlong module_id) { + filter_modules_.Erase(module_id); +} diff --git a/src/arrow/cpp/src/gandiva/jni/module_holder.h b/src/arrow/cpp/src/gandiva/jni/module_holder.h new file mode 100644 index 000000000..929c64231 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/jni/module_holder.h @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <memory> +#include <utility> + +#include "gandiva/arrow.h" + +namespace gandiva { + +class Projector; +class Filter; + +class ProjectorHolder { + public: + ProjectorHolder(SchemaPtr schema, FieldVector ret_types, + std::shared_ptr<Projector> projector) + : schema_(schema), ret_types_(ret_types), projector_(std::move(projector)) {} + + SchemaPtr schema() { return schema_; } + FieldVector rettypes() { return ret_types_; } + std::shared_ptr<Projector> projector() { return projector_; } + + private: + SchemaPtr schema_; + FieldVector ret_types_; + std::shared_ptr<Projector> projector_; +}; + +class FilterHolder { + public: + FilterHolder(SchemaPtr schema, std::shared_ptr<Filter> filter) + : schema_(schema), filter_(std::move(filter)) {} + + SchemaPtr schema() { return schema_; } + std::shared_ptr<Filter> filter() { return filter_; } + + private: + SchemaPtr schema_; + std::shared_ptr<Filter> filter_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/jni/symbols.map b/src/arrow/cpp/src/gandiva/jni/symbols.map new file mode 100644 index 000000000..e0f5def41 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/jni/symbols.map @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# Anonymous version node: decides which symbols stay visible outside the
# library this script is linked with.
{
  # Kept global: C++ symbols whose demangled name matches gandiva*, and any
  # symbol matching Java* or JNI*.
  global: extern "C++" { gandiva*; }; Java*; JNI*;
  # Every symbol not matched above is made local.
  local: *;
};
diff --git a/src/arrow/cpp/src/gandiva/like_holder.cc b/src/arrow/cpp/src/gandiva/like_holder.cc new file mode 100644 index 000000000..af9ac67d6 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/like_holder.cc @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +#include "gandiva/like_holder.h" + +#include <regex> +#include "gandiva/node.h" +#include "gandiva/regex_util.h" + +namespace gandiva { + +RE2 LikeHolder::starts_with_regex_(R"((\w|\s)*\.\*)"); +RE2 LikeHolder::ends_with_regex_(R"(\.\*(\w|\s)*)"); +RE2 LikeHolder::is_substr_regex_(R"(\.\*(\w|\s)*\.\*)"); + +// Short-circuit pattern matches for the following common sub cases : +// - starts_with, ends_with and is_substr +const FunctionNode LikeHolder::TryOptimize(const FunctionNode& node) { + std::shared_ptr<LikeHolder> holder; + auto status = Make(node, &holder); + if (status.ok()) { + std::string& pattern = holder->pattern_; + auto literal_type = node.children().at(1)->return_type(); + + if (RE2::FullMatch(pattern, starts_with_regex_)) { + auto prefix = pattern.substr(0, pattern.length() - 2); // trim .* + auto prefix_node = + std::make_shared<LiteralNode>(literal_type, LiteralHolder(prefix), false); + return FunctionNode("starts_with", {node.children().at(0), prefix_node}, + node.return_type()); + } else if (RE2::FullMatch(pattern, ends_with_regex_)) { + auto suffix = pattern.substr(2); // skip .* + auto suffix_node = + std::make_shared<LiteralNode>(literal_type, LiteralHolder(suffix), false); + return FunctionNode("ends_with", {node.children().at(0), suffix_node}, + node.return_type()); + } else if (RE2::FullMatch(pattern, is_substr_regex_)) { + auto substr = + pattern.substr(2, pattern.length() - 4); // trim starting and ending .* + auto substr_node = + std::make_shared<LiteralNode>(literal_type, LiteralHolder(substr), false); + return FunctionNode("is_substr", {node.children().at(0), substr_node}, + node.return_type()); + } + } + + // Could not optimize, return original node. 
+ return node; +} + +static bool IsArrowStringLiteral(arrow::Type::type type) { + return type == arrow::Type::STRING || type == arrow::Type::BINARY; +} + +Status LikeHolder::Make(const FunctionNode& node, std::shared_ptr<LikeHolder>* holder) { + ARROW_RETURN_IF(node.children().size() != 2 && node.children().size() != 3, + Status::Invalid("'like' function requires two or three parameters")); + + auto literal = dynamic_cast<LiteralNode*>(node.children().at(1).get()); + ARROW_RETURN_IF( + literal == nullptr, + Status::Invalid("'like' function requires a literal as the second parameter")); + + auto literal_type = literal->return_type()->id(); + ARROW_RETURN_IF( + !IsArrowStringLiteral(literal_type), + Status::Invalid( + "'like' function requires a string literal as the second parameter")); + + RE2::Options regex_op; + if (node.descriptor()->name() == "ilike") { + regex_op.set_case_sensitive(false); // set case-insensitive for ilike function. + + return Make(arrow::util::get<std::string>(literal->holder()), holder, regex_op); + } + if (node.children().size() == 2) { + return Make(arrow::util::get<std::string>(literal->holder()), holder); + } else { + auto escape_char = dynamic_cast<LiteralNode*>(node.children().at(2).get()); + ARROW_RETURN_IF( + escape_char == nullptr, + Status::Invalid("'like' function requires a literal as the third parameter")); + + auto escape_char_type = escape_char->return_type()->id(); + ARROW_RETURN_IF( + !IsArrowStringLiteral(escape_char_type), + Status::Invalid( + "'like' function requires a string literal as the third parameter")); + return Make(arrow::util::get<std::string>(literal->holder()), + arrow::util::get<std::string>(escape_char->holder()), holder); + } +} + +Status LikeHolder::Make(const std::string& sql_pattern, + std::shared_ptr<LikeHolder>* holder) { + std::string pcre_pattern; + ARROW_RETURN_NOT_OK(RegexUtil::SqlLikePatternToPcre(sql_pattern, pcre_pattern)); + + auto lholder = std::shared_ptr<LikeHolder>(new 
LikeHolder(pcre_pattern)); + ARROW_RETURN_IF(!lholder->regex_.ok(), + Status::Invalid("Building RE2 pattern '", pcre_pattern, "' failed")); + + *holder = lholder; + return Status::OK(); +} + +Status LikeHolder::Make(const std::string& sql_pattern, const std::string& escape_char, + std::shared_ptr<LikeHolder>* holder) { + ARROW_RETURN_IF(escape_char.length() > 1, + Status::Invalid("The length of escape char ", escape_char, + " in 'like' function is greater than 1")); + std::string pcre_pattern; + if (escape_char.length() == 1) { + ARROW_RETURN_NOT_OK( + RegexUtil::SqlLikePatternToPcre(sql_pattern, escape_char.at(0), pcre_pattern)); + } else { + ARROW_RETURN_NOT_OK(RegexUtil::SqlLikePatternToPcre(sql_pattern, pcre_pattern)); + } + + auto lholder = std::shared_ptr<LikeHolder>(new LikeHolder(pcre_pattern)); + ARROW_RETURN_IF(!lholder->regex_.ok(), + Status::Invalid("Building RE2 pattern '", pcre_pattern, "' failed")); + + *holder = lholder; + return Status::OK(); +} + +Status LikeHolder::Make(const std::string& sql_pattern, + std::shared_ptr<LikeHolder>* holder, RE2::Options regex_op) { + std::string pcre_pattern; + ARROW_RETURN_NOT_OK(RegexUtil::SqlLikePatternToPcre(sql_pattern, pcre_pattern)); + + std::shared_ptr<LikeHolder> lholder; + lholder = std::shared_ptr<LikeHolder>(new LikeHolder(pcre_pattern, regex_op)); + + ARROW_RETURN_IF(!lholder->regex_.ok(), + Status::Invalid("Building RE2 pattern '", pcre_pattern, "' failed")); + + *holder = lholder; + return Status::OK(); +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/like_holder.h b/src/arrow/cpp/src/gandiva/like_holder.h new file mode 100644 index 000000000..73e58017d --- /dev/null +++ b/src/arrow/cpp/src/gandiva/like_holder.h @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <string> + +#include <re2/re2.h> + +#include "arrow/status.h" + +#include "gandiva/function_holder.h" +#include "gandiva/node.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// Function Holder for SQL 'like' +class GANDIVA_EXPORT LikeHolder : public FunctionHolder { + public: + ~LikeHolder() override = default; + + static Status Make(const FunctionNode& node, std::shared_ptr<LikeHolder>* holder); + + static Status Make(const std::string& sql_pattern, std::shared_ptr<LikeHolder>* holder); + + static Status Make(const std::string& sql_pattern, const std::string& escape_char, + std::shared_ptr<LikeHolder>* holder); + + static Status Make(const std::string& sql_pattern, std::shared_ptr<LikeHolder>* holder, + RE2::Options regex_op); + + // Try and optimise a function node with a "like" pattern. + static const FunctionNode TryOptimize(const FunctionNode& node); + + /// Return true if the data matches the pattern. 
+ bool operator()(const std::string& data) { return RE2::FullMatch(data, regex_); } + + private: + explicit LikeHolder(const std::string& pattern) : pattern_(pattern), regex_(pattern) {} + + LikeHolder(const std::string& pattern, RE2::Options regex_op) + : pattern_(pattern), regex_(pattern, regex_op) {} + + std::string pattern_; // posix pattern string, to help debugging + RE2 regex_; // compiled regex for the pattern + + static RE2 starts_with_regex_; // pre-compiled pattern for matching starts_with + static RE2 ends_with_regex_; // pre-compiled pattern for matching ends_with + static RE2 is_substr_regex_; // pre-compiled pattern for matching is_substr +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/like_holder_test.cc b/src/arrow/cpp/src/gandiva/like_holder_test.cc new file mode 100644 index 000000000..a52533a11 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/like_holder_test.cc @@ -0,0 +1,281 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/like_holder.h" +#include "gandiva/regex_util.h" + +#include <memory> +#include <vector> + +#include <gtest/gtest.h> + +namespace gandiva { + +class TestLikeHolder : public ::testing::Test { + public: + RE2::Options regex_op; + FunctionNode BuildLike(std::string pattern) { + auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8())); + auto pattern_node = + std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false); + return FunctionNode("like", {field, pattern_node}, arrow::boolean()); + } + + FunctionNode BuildLike(std::string pattern, char escape_char) { + auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8())); + auto pattern_node = + std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false); + auto escape_char_node = std::make_shared<LiteralNode>( + arrow::int8(), LiteralHolder((int8_t)escape_char), false); + return FunctionNode("like", {field, pattern_node, escape_char_node}, + arrow::boolean()); + } +}; + +TEST_F(TestLikeHolder, TestMatchAny) { + std::shared_ptr<LikeHolder> like_holder; + + auto status = LikeHolder::Make("ab%", &like_holder, regex_op); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + EXPECT_TRUE(like("ab")); + EXPECT_TRUE(like("abc")); + EXPECT_TRUE(like("abcd")); + + EXPECT_FALSE(like("a")); + EXPECT_FALSE(like("cab")); +} + +TEST_F(TestLikeHolder, TestMatchOne) { + std::shared_ptr<LikeHolder> like_holder; + + auto status = LikeHolder::Make("ab_", &like_holder, regex_op); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + EXPECT_TRUE(like("abc")); + EXPECT_TRUE(like("abd")); + + EXPECT_FALSE(like("a")); + EXPECT_FALSE(like("abcd")); + EXPECT_FALSE(like("dabc")); +} + +TEST_F(TestLikeHolder, TestPcreSpecial) { + std::shared_ptr<LikeHolder> like_holder; + + auto status = LikeHolder::Make(".*ab_", &like_holder, regex_op); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& 
like = *like_holder; + EXPECT_TRUE(like(".*abc")); // . and * aren't special in sql regex + EXPECT_FALSE(like("xxabc")); +} + +TEST_F(TestLikeHolder, TestRegexEscape) { + std::string res; + auto status = RegexUtil::SqlLikePatternToPcre("#%hello#_abc_def##", '#', res); + EXPECT_TRUE(status.ok()) << status.message(); + + EXPECT_EQ(res, "%hello_abc.def#"); +} + +TEST_F(TestLikeHolder, TestDot) { + std::shared_ptr<LikeHolder> like_holder; + + auto status = LikeHolder::Make("abc.", &like_holder, regex_op); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + EXPECT_FALSE(like("abcd")); +} + +TEST_F(TestLikeHolder, TestOptimise) { + // optimise for 'starts_with' + auto fnode = LikeHolder::TryOptimize(BuildLike("xy 123z%")); + EXPECT_EQ(fnode.descriptor()->name(), "starts_with"); + EXPECT_EQ(fnode.ToString(), "bool starts_with((string) in, (const string) xy 123z)"); + + // optimise for 'ends_with' + fnode = LikeHolder::TryOptimize(BuildLike("%xyz")); + EXPECT_EQ(fnode.descriptor()->name(), "ends_with"); + EXPECT_EQ(fnode.ToString(), "bool ends_with((string) in, (const string) xyz)"); + + // optimise for 'is_substr' + fnode = LikeHolder::TryOptimize(BuildLike("%abc%")); + EXPECT_EQ(fnode.descriptor()->name(), "is_substr"); + EXPECT_EQ(fnode.ToString(), "bool is_substr((string) in, (const string) abc)"); + + // no optimisation for others. + fnode = LikeHolder::TryOptimize(BuildLike("xyz_")); + EXPECT_EQ(fnode.descriptor()->name(), "like"); + + fnode = LikeHolder::TryOptimize(BuildLike("_xyz")); + EXPECT_EQ(fnode.descriptor()->name(), "like"); + + fnode = LikeHolder::TryOptimize(BuildLike("_xyz_")); + EXPECT_EQ(fnode.descriptor()->name(), "like"); + + fnode = LikeHolder::TryOptimize(BuildLike("%xyz_")); + EXPECT_EQ(fnode.descriptor()->name(), "like"); + + fnode = LikeHolder::TryOptimize(BuildLike("x_yz%")); + EXPECT_EQ(fnode.descriptor()->name(), "like"); + + // no optimisation for escaped pattern. 
+ fnode = LikeHolder::TryOptimize(BuildLike("\\%xyz", '\\')); + EXPECT_EQ(fnode.descriptor()->name(), "like"); + EXPECT_EQ(fnode.ToString(), + "bool like((string) in, (const string) \\%xyz, (const int8) \\)"); +} + +TEST_F(TestLikeHolder, TestMatchOneEscape) { + std::shared_ptr<LikeHolder> like_holder; + + auto status = LikeHolder::Make("ab\\_", "\\", &like_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + + EXPECT_TRUE(like("ab_")); + + EXPECT_FALSE(like("abc")); + EXPECT_FALSE(like("abd")); + EXPECT_FALSE(like("a")); + EXPECT_FALSE(like("abcd")); + EXPECT_FALSE(like("dabc")); +} + +TEST_F(TestLikeHolder, TestMatchManyEscape) { + std::shared_ptr<LikeHolder> like_holder; + + auto status = LikeHolder::Make("ab\\%", "\\", &like_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + + EXPECT_TRUE(like("ab%")); + + EXPECT_FALSE(like("abc")); + EXPECT_FALSE(like("abd")); + EXPECT_FALSE(like("a")); + EXPECT_FALSE(like("abcd")); + EXPECT_FALSE(like("dabc")); +} + +TEST_F(TestLikeHolder, TestMatchEscape) { + std::shared_ptr<LikeHolder> like_holder; + + auto status = LikeHolder::Make("ab\\\\", "\\", &like_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + + EXPECT_TRUE(like("ab\\")); + + EXPECT_FALSE(like("abc")); +} + +TEST_F(TestLikeHolder, TestEmptyEscapeChar) { + std::shared_ptr<LikeHolder> like_holder; + + auto status = LikeHolder::Make("ab\\_", "", &like_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + + EXPECT_TRUE(like("ab\\c")); + EXPECT_TRUE(like("ab\\_")); + + EXPECT_FALSE(like("ab\\_d")); + EXPECT_FALSE(like("ab__")); +} + +TEST_F(TestLikeHolder, TestMultipleEscapeChar) { + std::shared_ptr<LikeHolder> like_holder; + + auto status = LikeHolder::Make("ab\\_", "\\\\", &like_holder); + EXPECT_EQ(status.ok(), false) << status.message(); +} +class TestILikeHolder : public ::testing::Test { + public: 
+ RE2::Options regex_op; + FunctionNode BuildILike(std::string pattern) { + auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8())); + auto pattern_node = + std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false); + return FunctionNode("ilike", {field, pattern_node}, arrow::boolean()); + } +}; + +TEST_F(TestILikeHolder, TestMatchAny) { + std::shared_ptr<LikeHolder> like_holder; + + regex_op.set_case_sensitive(false); + auto status = LikeHolder::Make("ab%", &like_holder, regex_op); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + EXPECT_TRUE(like("ab")); + EXPECT_TRUE(like("aBc")); + EXPECT_TRUE(like("ABCD")); + + EXPECT_FALSE(like("a")); + EXPECT_FALSE(like("cab")); +} + +TEST_F(TestILikeHolder, TestMatchOne) { + std::shared_ptr<LikeHolder> like_holder; + + regex_op.set_case_sensitive(false); + auto status = LikeHolder::Make("Ab_", &like_holder, regex_op); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + EXPECT_TRUE(like("abc")); + EXPECT_TRUE(like("aBd")); + + EXPECT_FALSE(like("A")); + EXPECT_FALSE(like("Abcd")); + EXPECT_FALSE(like("DaBc")); +} + +TEST_F(TestILikeHolder, TestPcreSpecial) { + std::shared_ptr<LikeHolder> like_holder; + + regex_op.set_case_sensitive(false); + auto status = LikeHolder::Make(".*aB_", &like_holder, regex_op); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + EXPECT_TRUE(like(".*Abc")); // . 
and * aren't special in sql regex + EXPECT_FALSE(like("xxAbc")); +} + +TEST_F(TestILikeHolder, TestDot) { + std::shared_ptr<LikeHolder> like_holder; + + regex_op.set_case_sensitive(false); + auto status = LikeHolder::Make("aBc.", &like_holder, regex_op); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& like = *like_holder; + EXPECT_FALSE(like("abcd")); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/literal_holder.cc b/src/arrow/cpp/src/gandiva/literal_holder.cc new file mode 100644 index 000000000..beed8119c --- /dev/null +++ b/src/arrow/cpp/src/gandiva/literal_holder.cc @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <sstream> + +#include "gandiva/literal_holder.h" + +namespace gandiva { + +namespace { + +template <typename OStream> +struct LiteralToStream { + OStream& ostream_; + + template <typename Value> + void operator()(const Value& v) { + ostream_ << v; + } +}; + +} // namespace + +std::string ToString(const LiteralHolder& holder) { + std::stringstream ss; + LiteralToStream<std::stringstream> visitor{ss}; + ::arrow::util::visit(visitor, holder); + return ss.str(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/literal_holder.h b/src/arrow/cpp/src/gandiva/literal_holder.h new file mode 100644 index 000000000..c4712aafc --- /dev/null +++ b/src/arrow/cpp/src/gandiva/literal_holder.h @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <string> + +#include <arrow/util/variant.h> + +#include <arrow/type.h> +#include "gandiva/decimal_scalar.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +using LiteralHolder = + arrow::util::Variant<bool, float, double, int8_t, int16_t, int32_t, int64_t, uint8_t, + uint16_t, uint32_t, uint64_t, std::string, DecimalScalar128>; + +GANDIVA_EXPORT std::string ToString(const LiteralHolder& holder); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/llvm_generator.cc b/src/arrow/cpp/src/gandiva/llvm_generator.cc new file mode 100644 index 000000000..0129e5278 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/llvm_generator.cc @@ -0,0 +1,1400 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/llvm_generator.h" + +#include <fstream> +#include <iostream> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include "gandiva/bitmap_accumulator.h" +#include "gandiva/decimal_ir.h" +#include "gandiva/dex.h" +#include "gandiva/expr_decomposer.h" +#include "gandiva/expression.h" +#include "gandiva/lvalue.h" + +namespace gandiva { + +#define ADD_TRACE(...) 
\ + if (enable_ir_traces_) { \ + AddTrace(__VA_ARGS__); \ + } + +LLVMGenerator::LLVMGenerator() : enable_ir_traces_(false) {} + +Status LLVMGenerator::Make(std::shared_ptr<Configuration> config, + std::unique_ptr<LLVMGenerator>* llvm_generator) { + std::unique_ptr<LLVMGenerator> llvmgen_obj(new LLVMGenerator()); + + ARROW_RETURN_NOT_OK(Engine::Make(config, &(llvmgen_obj->engine_))); + *llvm_generator = std::move(llvmgen_obj); + + return Status::OK(); +} + +Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr output) { + int idx = static_cast<int>(compiled_exprs_.size()); + // decompose the expression to separate out value and validities. + ExprDecomposer decomposer(function_registry_, annotator_); + ValueValidityPairPtr value_validity; + ARROW_RETURN_NOT_OK(decomposer.Decompose(*expr->root(), &value_validity)); + // Generate the IR function for the decomposed expression. + std::unique_ptr<CompiledExpr> compiled_expr(new CompiledExpr(value_validity, output)); + llvm::Function* ir_function = nullptr; + ARROW_RETURN_NOT_OK(CodeGenExprValue(value_validity->value_expr(), + annotator_.buffer_count(), output, idx, + &ir_function, selection_vector_mode_)); + compiled_expr->SetIRFunction(selection_vector_mode_, ir_function); + + compiled_exprs_.push_back(std::move(compiled_expr)); + return Status::OK(); +} + +/// Build and optimise module for projection expression. +Status LLVMGenerator::Build(const ExpressionVector& exprs, SelectionVector::Mode mode) { + selection_vector_mode_ = mode; + for (auto& expr : exprs) { + auto output = annotator_.AddOutputFieldDescriptor(expr->result()); + ARROW_RETURN_NOT_OK(Add(expr, output)); + } + + // Compile and inject into the process' memory the generated function. + ARROW_RETURN_NOT_OK(engine_->FinalizeModule()); + + // setup the jit functions for each expression. 
+ for (auto& compiled_expr : compiled_exprs_) { + auto ir_fn = compiled_expr->GetIRFunction(mode); + auto jit_fn = reinterpret_cast<EvalFunc>(engine_->CompiledFunction(ir_fn)); + compiled_expr->SetJITFunction(selection_vector_mode_, jit_fn); + } + + return Status::OK(); +} + +/// Execute the compiled module against the provided vectors. +Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch, + const ArrayDataVector& output_vector) { + return Execute(record_batch, nullptr, output_vector); +} + +/// Execute the compiled module against the provided vectors based on the type of +/// selection vector. +Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch, + const SelectionVector* selection_vector, + const ArrayDataVector& output_vector) { + DCHECK_GT(record_batch.num_rows(), 0); + + auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector); + DCHECK_GT(eval_batch->GetNumBuffers(), 0); + + auto mode = SelectionVector::MODE_NONE; + if (selection_vector != nullptr) { + mode = selection_vector->GetMode(); + } + if (mode != selection_vector_mode_) { + return Status::Invalid("llvm expression built for selection vector mode ", + selection_vector_mode_, " received vector with mode ", mode); + } + + for (auto& compiled_expr : compiled_exprs_) { + // generate data/offset vectors. 
+ const uint8_t* selection_buffer = nullptr; + auto num_output_rows = record_batch.num_rows(); + if (selection_vector != nullptr) { + selection_buffer = selection_vector->GetBuffer().data(); + num_output_rows = selection_vector->GetNumSlots(); + } + + EvalFunc jit_function = compiled_expr->GetJITFunction(mode); + jit_function(eval_batch->GetBufferArray(), eval_batch->GetBufferOffsetArray(), + eval_batch->GetLocalBitMapArray(), selection_buffer, + (int64_t)eval_batch->GetExecutionContext(), num_output_rows); + + // check for execution errors + ARROW_RETURN_IF( + eval_batch->GetExecutionContext()->has_error(), + Status::ExecutionError(eval_batch->GetExecutionContext()->get_error())); + + // generate validity vectors. + ComputeBitMapsForExpr(*compiled_expr, *eval_batch, selection_vector); + } + + return Status::OK(); +} + +llvm::Value* LLVMGenerator::LoadVectorAtIndex(llvm::Value* arg_addrs, int idx, + const std::string& name) { + auto* idx_val = types()->i32_constant(idx); + auto* offset = CreateGEP(ir_builder(), arg_addrs, idx_val, name + "_mem_addr"); + return CreateLoad(ir_builder(), offset, name + "_mem"); +} + +/// Get reference to validity array at specified index in the args list. +llvm::Value* LLVMGenerator::GetValidityReference(llvm::Value* arg_addrs, int idx, + FieldPtr field) { + const std::string& name = field->name(); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(), name + "_varray"); +} + +/// Get reference to data array at specified index in the args list. +llvm::Value* LLVMGenerator::GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx, + FieldPtr field) { + const std::string& name = field->name(); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + return ir_builder()->CreateIntToPtr(load, types()->i8_ptr_type(), name + "_buf_ptr"); +} + +/// Get reference to data array at specified index in the args list. 
+llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx, + FieldPtr field) { + const std::string& name = field->name(); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + llvm::Type* base_type = types()->DataVecType(field->type()); + llvm::Value* ret; + if (base_type->isPointerTy()) { + ret = ir_builder()->CreateIntToPtr(load, base_type, name + "_darray"); + } else { + llvm::Type* pointer_type = types()->ptr_type(base_type); + ret = ir_builder()->CreateIntToPtr(load, pointer_type, name + "_darray"); + } + return ret; +} + +/// Get reference to offsets array at specified index in the args list. +llvm::Value* LLVMGenerator::GetOffsetsReference(llvm::Value* arg_addrs, int idx, + FieldPtr field) { + const std::string& name = field->name(); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_oarray"); +} + +/// Get reference to local bitmap array at specified index in the args list. +llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx) { + llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, idx, ""); + return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(), + std::to_string(idx) + "_lbmap"); +} + +/// \brief Generate code for one expression. 
+ +// Sample IR code for "c1:int + c2:int" +// +// The C-code equivalent is : +// ------------------------------ +// int expr_0(int64_t *addrs, int64_t *local_bitmaps, +// int64_t execution_context_ptr, int64_t nrecords) { +// int *outVec = (int *) addrs[5]; +// int *c0Vec = (int *) addrs[1]; +// int *c1Vec = (int *) addrs[3]; +// for (int loop_var = 0; loop_var < nrecords; ++loop_var) { +// int c0 = c0Vec[loop_var]; +// int c1 = c1Vec[loop_var]; +// int out = c0 + c1; +// outVec[loop_var] = out; +// } +// } +// +// IR Code +// -------- +// +// define i32 @expr_0(i64* %args, i64* %local_bitmaps, i64 %execution_context_ptr, , i64 +// %nrecords) { entry: +// %outmemAddr = getelementptr i64, i64* %args, i32 5 +// %outmem = load i64, i64* %outmemAddr +// %outVec = inttoptr i64 %outmem to i32* +// %c0memAddr = getelementptr i64, i64* %args, i32 1 +// %c0mem = load i64, i64* %c0memAddr +// %c0Vec = inttoptr i64 %c0mem to i32* +// %c1memAddr = getelementptr i64, i64* %args, i32 3 +// %c1mem = load i64, i64* %c1memAddr +// %c1Vec = inttoptr i64 %c1mem to i32* +// br label %loop +// loop: ; preds = %loop, %entry +// %loop_var = phi i64 [ 0, %entry ], [ %"loop_var+1", %loop ] +// %"loop_var+1" = add i64 %loop_var, 1 +// %0 = getelementptr i32, i32* %c0Vec, i32 %loop_var +// %c0 = load i32, i32* %0 +// %1 = getelementptr i32, i32* %c1Vec, i32 %loop_var +// %c1 = load i32, i32* %1 +// %add_int_int = call i32 @add_int_int(i32 %c0, i32 %c1) +// %2 = getelementptr i32, i32* %outVec, i32 %loop_var +// store i32 %add_int_int, i32* %2 +// %"loop_var < nrec" = icmp slt i64 %"loop_var+1", %nrecords +// br i1 %"loop_var < nrec", label %loop, label %exit +// exit: ; preds = %loop +// ret i32 0 +// } +Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, + FieldDescriptorPtr output, int suffix_idx, + llvm::Function** fn, + SelectionVector::Mode selection_vector_mode) { + llvm::IRBuilder<>* builder = ir_builder(); + // Create fn prototype : + // int expr_1 (long 
**addrs, long *offsets, long **bitmaps, + // long *context_ptr, long nrec) + std::vector<llvm::Type*> arguments; + arguments.push_back(types()->i64_ptr_type()); // addrs + arguments.push_back(types()->i64_ptr_type()); // offsets + arguments.push_back(types()->i64_ptr_type()); // bitmaps + switch (selection_vector_mode) { + case SelectionVector::MODE_NONE: + case SelectionVector::MODE_UINT16: + arguments.push_back(types()->ptr_type(types()->i16_type())); + break; + case SelectionVector::MODE_UINT32: + arguments.push_back(types()->i32_ptr_type()); + break; + case SelectionVector::MODE_UINT64: + arguments.push_back(types()->i64_ptr_type()); + } + arguments.push_back(types()->i64_type()); // ctx_ptr + arguments.push_back(types()->i64_type()); // nrec + llvm::FunctionType* prototype = + llvm::FunctionType::get(types()->i32_type(), arguments, false /*isVarArg*/); + + // Create fn + std::string func_name = "expr_" + std::to_string(suffix_idx) + "_" + + std::to_string(static_cast<int>(selection_vector_mode)); + engine_->AddFunctionToCompile(func_name); + *fn = llvm::Function::Create(prototype, llvm::GlobalValue::ExternalLinkage, func_name, + module()); + ARROW_RETURN_IF((*fn == nullptr), Status::CodeGenError("Error creating function.")); + + // Name the arguments + llvm::Function::arg_iterator args = (*fn)->arg_begin(); + llvm::Value* arg_addrs = &*args; + arg_addrs->setName("inputs_addr"); + ++args; + llvm::Value* arg_addr_offsets = &*args; + arg_addr_offsets->setName("inputs_addr_offsets"); + ++args; + llvm::Value* arg_local_bitmaps = &*args; + arg_local_bitmaps->setName("local_bitmaps"); + ++args; + llvm::Value* arg_selection_vector = &*args; + arg_selection_vector->setName("selection_vector"); + ++args; + llvm::Value* arg_context_ptr = &*args; + arg_context_ptr->setName("context_ptr"); + ++args; + llvm::Value* arg_nrecords = &*args; + arg_nrecords->setName("nrecords"); + + llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(*context(), "entry", *fn); + 
llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(*context(), "loop", *fn); + llvm::BasicBlock* loop_exit = llvm::BasicBlock::Create(*context(), "exit", *fn); + + // Add reference to output vector (in entry block) + builder->SetInsertPoint(loop_entry); + llvm::Value* output_ref = + GetDataReference(arg_addrs, output->data_idx(), output->field()); + llvm::Value* output_buffer_ptr_ref = GetDataBufferPtrReference( + arg_addrs, output->data_buffer_ptr_idx(), output->field()); + llvm::Value* output_offset_ref = + GetOffsetsReference(arg_addrs, output->offsets_idx(), output->field()); + + std::vector<llvm::Value*> slice_offsets; + for (int idx = 0; idx < buffer_count; idx++) { + auto offsetAddr = CreateGEP(builder, arg_addr_offsets, types()->i32_constant(idx)); + auto offset = CreateLoad(builder, offsetAddr); + slice_offsets.push_back(offset); + } + + // Loop body + builder->SetInsertPoint(loop_body); + + // define loop_var : start with 0, +1 after each iter + llvm::PHINode* loop_var = builder->CreatePHI(types()->i64_type(), 2, "loop_var"); + + llvm::Value* position_var = loop_var; + if (selection_vector_mode != SelectionVector::MODE_NONE) { + position_var = builder->CreateIntCast( + CreateLoad(builder, CreateGEP(builder, arg_selection_vector, loop_var), + "uncasted_position_var"), + types()->i64_type(), true, "position_var"); + } + + // The visitor can add code to both the entry/loop blocks. + Visitor visitor(this, *fn, loop_entry, arg_addrs, arg_local_bitmaps, slice_offsets, + arg_context_ptr, position_var); + value_expr->Accept(visitor); + LValuePtr output_value = visitor.result(); + + // The "current" block may have changed due to code generation in the visitor. + llvm::BasicBlock* loop_body_tail = builder->GetInsertBlock(); + + // add jump to "loop block" at the end of the "setup block". + builder->SetInsertPoint(loop_entry); + builder->CreateBr(loop_body); + + // save the value in the output vector. 
+ builder->SetInsertPoint(loop_body_tail); + + auto output_type_id = output->Type()->id(); + if (output_type_id == arrow::Type::BOOL) { + SetPackedBitValue(output_ref, loop_var, output_value->data()); + } else if (arrow::is_primitive(output_type_id) || + output_type_id == arrow::Type::DECIMAL) { + llvm::Value* slot_offset = CreateGEP(builder, output_ref, loop_var); + builder->CreateStore(output_value->data(), slot_offset); + } else if (arrow::is_binary_like(output_type_id)) { + // Var-len output. Make a function call to populate the data. + // if there is an error, the fn sets it in the context. And, will be returned at the + // end of this row batch. + AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var, + output_value->data(), output_value->length()}); + } else { + return Status::NotImplemented("output type ", output->Type()->ToString(), + " not supported"); + } + ADD_TRACE("saving result " + output->Name() + " value %T", output_value->data()); + + if (visitor.has_arena_allocs()) { + // Reset allocations to avoid excessive memory usage. Once the result is copied to + // the output vector (store instruction above), any memory allocations in this + // iteration of the loop are no longer needed. 
+ std::vector<llvm::Value*> reset_args; + reset_args.push_back(arg_context_ptr); + AddFunctionCall("gdv_fn_context_arena_reset", types()->void_type(), reset_args); + } + + // check loop_var + loop_var->addIncoming(types()->i64_constant(0), loop_entry); + llvm::Value* loop_update = + builder->CreateAdd(loop_var, types()->i64_constant(1), "loop_var+1"); + loop_var->addIncoming(loop_update, loop_body_tail); + + llvm::Value* loop_var_check = + builder->CreateICmpSLT(loop_update, arg_nrecords, "loop_var < nrec"); + builder->CreateCondBr(loop_var_check, loop_body, loop_exit); + + // Loop exit + builder->SetInsertPoint(loop_exit); + builder->CreateRet(types()->i32_constant(0)); + return Status::OK(); +} + +/// Return value of a bit in bitMap. +llvm::Value* LLVMGenerator::GetPackedBitValue(llvm::Value* bitmap, + llvm::Value* position) { + ADD_TRACE("fetch bit at position %T", position); + + llvm::Value* bitmap8 = ir_builder()->CreateBitCast( + bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast"); + return AddFunctionCall("bitMapGetBit", types()->i1_type(), {bitmap8, position}); +} + +/// Set the value of a bit in bitMap. +void LLVMGenerator::SetPackedBitValue(llvm::Value* bitmap, llvm::Value* position, + llvm::Value* value) { + ADD_TRACE("set bit at position %T", position); + ADD_TRACE(" to value %T ", value); + + llvm::Value* bitmap8 = ir_builder()->CreateBitCast( + bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast"); + AddFunctionCall("bitMapSetBit", types()->void_type(), {bitmap8, position, value}); +} + +/// Return value of a bit in validity bitMap (handles null bitmaps too). 
+llvm::Value* LLVMGenerator::GetPackedValidityBitValue(llvm::Value* bitmap, + llvm::Value* position) { + ADD_TRACE("fetch validity bit at position %T", position); + + llvm::Value* bitmap8 = ir_builder()->CreateBitCast( + bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast"); + return AddFunctionCall("bitMapValidityGetBit", types()->i1_type(), {bitmap8, position}); +} + +/// Clear the bit in bitMap if value = false. +void LLVMGenerator::ClearPackedBitValueIfFalse(llvm::Value* bitmap, llvm::Value* position, + llvm::Value* value) { + ADD_TRACE("ClearIfFalse bit at position %T", position); + ADD_TRACE(" value %T ", value); + + llvm::Value* bitmap8 = ir_builder()->CreateBitCast( + bitmap, types()->ptr_type(types()->i8_type()), "bitMapCast"); + AddFunctionCall("bitMapClearBitIfFalse", types()->void_type(), + {bitmap8, position, value}); +} + +/// Extract the bitmap addresses, and do an intersection. +void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr& compiled_expr, + const EvalBatch& eval_batch, + const SelectionVector* selection_vector) { + auto validities = compiled_expr.value_validity()->validity_exprs(); + + // Extract all the source bitmap addresses. + BitMapAccumulator accumulator(eval_batch); + for (auto& validity_dex : validities) { + validity_dex->Accept(accumulator); + } + + // Extract the destination bitmap address. + int out_idx = compiled_expr.output()->validity_idx(); + uint8_t* dst_bitmap = eval_batch.GetBuffer(out_idx); + // Compute the destination bitmap. + if (selection_vector == nullptr) { + accumulator.ComputeResult(dst_bitmap); + } else { + /// The output bitmap is an intersection of some input/local bitmaps. However, with a + /// selection vector, only the bits corresponding to the indices in the selection + /// vector need to set in the output bitmap. This is done in two steps : + /// + /// 1. Do the intersection of input/local bitmaps to generate a temporary bitmap. + /// 2. 
copy just the relevant bits from the temporary bitmap to the output bitmap. + LocalBitMapsHolder bit_map_holder(eval_batch.num_records(), 1); + uint8_t* temp_bitmap = bit_map_holder.GetLocalBitMap(0); + accumulator.ComputeResult(temp_bitmap); + + auto num_out_records = selection_vector->GetNumSlots(); + // the memset isn't required, doing it just for valgrind. + memset(dst_bitmap, 0, arrow::BitUtil::BytesForBits(num_out_records)); + for (auto i = 0; i < num_out_records; ++i) { + auto bit = arrow::BitUtil::GetBit(temp_bitmap, selection_vector->GetIndex(i)); + arrow::BitUtil::SetBitTo(dst_bitmap, i, bit); + } + } +} + +llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name, + llvm::Type* ret_type, + const std::vector<llvm::Value*>& args) { + // find the llvm function. + llvm::Function* fn = module()->getFunction(full_name); + DCHECK_NE(fn, nullptr) << "missing function " << full_name; + + if (enable_ir_traces_ && !full_name.compare("printf") && + !full_name.compare("printff")) { + // Trace for debugging + ADD_TRACE("invoke native fn " + full_name); + } + + // build a call to the llvm function. + llvm::Value* value; + if (ret_type->isVoidTy()) { + // void functions can't have a name for the call. + value = ir_builder()->CreateCall(fn, args); + } else { + value = ir_builder()->CreateCall(fn, args, full_name); + DCHECK(value->getType() == ret_type); + } + + return value; +} + +std::shared_ptr<DecimalLValue> LLVMGenerator::BuildDecimalLValue(llvm::Value* value, + DataTypePtr arrow_type) { + // only decimals of size 128-bit supported. + DCHECK(is_decimal_128(arrow_type)); + auto decimal_type = + arrow::internal::checked_cast<arrow::DecimalType*>(arrow_type.get()); + return std::make_shared<DecimalLValue>(value, nullptr, + types()->i32_constant(decimal_type->precision()), + types()->i32_constant(decimal_type->scale())); +} + +#define ADD_VISITOR_TRACE(...) 
\ + if (generator_->enable_ir_traces_) { \ + generator_->AddTrace(__VA_ARGS__); \ + } + +// Visitor for generating the code for a decomposed expression. +LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* function, + llvm::BasicBlock* entry_block, llvm::Value* arg_addrs, + llvm::Value* arg_local_bitmaps, + std::vector<llvm::Value*> slice_offsets, + llvm::Value* arg_context_ptr, llvm::Value* loop_var) + : generator_(generator), + function_(function), + entry_block_(entry_block), + arg_addrs_(arg_addrs), + arg_local_bitmaps_(arg_local_bitmaps), + slice_offsets_(slice_offsets), + arg_context_ptr_(arg_context_ptr), + loop_var_(loop_var), + has_arena_allocs_(false) { + ADD_VISITOR_TRACE("Iteration %T", loop_var); +} + +void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { + llvm::IRBuilder<>* builder = ir_builder(); + llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + llvm::Value* slot_index = builder->CreateAdd(loop_var_, GetSliceOffset(dex.DataIdx())); + llvm::Value* slot_value; + std::shared_ptr<LValue> lvalue; + + switch (dex.FieldType()->id()) { + case arrow::Type::BOOL: + slot_value = generator_->GetPackedBitValue(slot_ref, slot_index); + lvalue = std::make_shared<LValue>(slot_value); + break; + + case arrow::Type::DECIMAL: { + auto slot_offset = CreateGEP(builder, slot_ref, slot_index); + slot_value = CreateLoad(builder, slot_offset, dex.FieldName()); + lvalue = generator_->BuildDecimalLValue(slot_value, dex.FieldType()); + break; + } + + default: { + auto slot_offset = CreateGEP(builder, slot_ref, slot_index); + slot_value = CreateLoad(builder, slot_offset, dex.FieldName()); + lvalue = std::make_shared<LValue>(slot_value); + break; + } + } + ADD_VISITOR_TRACE("visit fixed-len data vector " + dex.FieldName() + " value %T", + slot_value); + result_ = lvalue; +} + +void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { + llvm::IRBuilder<>* builder = ir_builder(); 
+ llvm::Value* slot; + + // compute len from the offsets array. + llvm::Value* offsets_slot_ref = + GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); + llvm::Value* offsets_slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx())); + + // => offset_start = offsets[loop_var] + slot = CreateGEP(builder, offsets_slot_ref, offsets_slot_index); + llvm::Value* offset_start = CreateLoad(builder, slot, "offset_start"); + + // => offset_end = offsets[loop_var + 1] + llvm::Value* offsets_slot_index_next = builder->CreateAdd( + offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1"); + slot = CreateGEP(builder, offsets_slot_ref, offsets_slot_index_next); + llvm::Value* offset_end = CreateLoad(builder, slot, "offset_end"); + + // => len_value = offset_end - offset_start + llvm::Value* len_value = + builder->CreateSub(offset_end, offset_start, dex.FieldName() + "Len"); + + // get the data from the data array, at offset 'offset_start'. + llvm::Value* data_slot_ref = + GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + llvm::Value* data_value = CreateGEP(builder, data_slot_ref, offset_start); + ADD_VISITOR_TRACE("visit var-len data vector " + dex.FieldName() + " len %T", + len_value); + result_.reset(new LValue(data_value, len_value)); +} + +void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) { + llvm::IRBuilder<>* builder = ir_builder(); + llvm::Value* slot_ref = + GetBufferReference(dex.ValidityIdx(), kBufferTypeValidity, dex.Field()); + llvm::Value* slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.ValidityIdx())); + llvm::Value* validity = generator_->GetPackedValidityBitValue(slot_ref, slot_index); + + ADD_VISITOR_TRACE("visit validity vector " + dex.FieldName() + " value %T", validity); + result_.reset(new LValue(validity)); +} + +void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex& dex) { + llvm::Value* slot_ref = 
GetLocalBitMapReference(dex.local_bitmap_idx()); + llvm::Value* validity = generator_->GetPackedBitValue(slot_ref, loop_var_); + + ADD_VISITOR_TRACE( + "visit local bitmap " + std::to_string(dex.local_bitmap_idx()) + " value %T", + validity); + result_.reset(new LValue(validity)); +} + +void LLVMGenerator::Visitor::Visit(const TrueDex& dex) { + result_.reset(new LValue(generator_->types()->true_constant())); +} + +void LLVMGenerator::Visitor::Visit(const FalseDex& dex) { + result_.reset(new LValue(generator_->types()->false_constant())); +} + +void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) { + LLVMTypes* types = generator_->types(); + llvm::Value* value = nullptr; + llvm::Value* len = nullptr; + + switch (dex.type()->id()) { + case arrow::Type::BOOL: + value = types->i1_constant(arrow::util::get<bool>(dex.holder())); + break; + + case arrow::Type::UINT8: + value = types->i8_constant(arrow::util::get<uint8_t>(dex.holder())); + break; + + case arrow::Type::UINT16: + value = types->i16_constant(arrow::util::get<uint16_t>(dex.holder())); + break; + + case arrow::Type::UINT32: + value = types->i32_constant(arrow::util::get<uint32_t>(dex.holder())); + break; + + case arrow::Type::UINT64: + value = types->i64_constant(arrow::util::get<uint64_t>(dex.holder())); + break; + + case arrow::Type::INT8: + value = types->i8_constant(arrow::util::get<int8_t>(dex.holder())); + break; + + case arrow::Type::INT16: + value = types->i16_constant(arrow::util::get<int16_t>(dex.holder())); + break; + + case arrow::Type::FLOAT: + value = types->float_constant(arrow::util::get<float>(dex.holder())); + break; + + case arrow::Type::DOUBLE: + value = types->double_constant(arrow::util::get<double>(dex.holder())); + break; + + case arrow::Type::STRING: + case arrow::Type::BINARY: { + const std::string& str = arrow::util::get<std::string>(dex.holder()); + + llvm::Constant* str_int_cast = types->i64_constant((int64_t)str.c_str()); + value = llvm::ConstantExpr::getIntToPtr(str_int_cast, 
types->i8_ptr_type()); + len = types->i32_constant(static_cast<int32_t>(str.length())); + break; + } + + case arrow::Type::INT32: + case arrow::Type::DATE32: + case arrow::Type::TIME32: + case arrow::Type::INTERVAL_MONTHS: + value = types->i32_constant(arrow::util::get<int32_t>(dex.holder())); + break; + + case arrow::Type::INT64: + case arrow::Type::DATE64: + case arrow::Type::TIME64: + case arrow::Type::TIMESTAMP: + case arrow::Type::INTERVAL_DAY_TIME: + value = types->i64_constant(arrow::util::get<int64_t>(dex.holder())); + break; + + case arrow::Type::DECIMAL: { + // build code for struct + auto scalar = arrow::util::get<DecimalScalar128>(dex.holder()); + // ConstantInt doesn't have a get method that takes int128 or a pair of int64. so, + // passing the string representation instead. + auto int128_value = + llvm::ConstantInt::get(llvm::Type::getInt128Ty(*generator_->context()), + Decimal128(scalar.value()).ToIntegerString(), 10); + auto type = arrow::decimal(scalar.precision(), scalar.scale()); + auto lvalue = generator_->BuildDecimalLValue(int128_value, type); + // set it as the l-value and return. + result_ = lvalue; + return; + } + + default: + DCHECK(0); + } + ADD_VISITOR_TRACE("visit Literal %T", value); + result_.reset(new LValue(value, len)); +} + +void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) { + const std::string& function_name = dex.func_descriptor()->name(); + ADD_VISITOR_TRACE("visit NonNullableFunc base function " + function_name); + + const NativeFunction* native_function = dex.native_function(); + + // build the function params (ignore validity). + auto params = BuildParams(dex.function_holder().get(), dex.args(), false, + native_function->NeedsContext()); + + auto arrow_return_type = dex.func_descriptor()->return_type(); + if (native_function->CanReturnErrors()) { + // slow path : if a function can return errors, skip invoking the function + // unless all of the input args are valid. Otherwise, it can cause spurious errors. 
+ + llvm::IRBuilder<>* builder = ir_builder(); + LLVMTypes* types = generator_->types(); + auto arrow_type_id = arrow_return_type->id(); + auto result_type = types->IRType(arrow_type_id); + + // Build combined validity of the args. + llvm::Value* is_valid = types->true_constant(); + for (auto& pair : dex.args()) { + auto arg_validity = BuildCombinedValidity(pair->validity_exprs()); + is_valid = builder->CreateAnd(is_valid, arg_validity, "validityBitAnd"); + } + + // then block + auto then_lambda = [&] { + ADD_VISITOR_TRACE("fn " + function_name + + " can return errors : all args valid, invoke fn"); + return BuildFunctionCall(native_function, arrow_return_type, ¶ms); + }; + + // else block + auto else_lambda = [&] { + ADD_VISITOR_TRACE("fn " + function_name + + " can return errors : not all args valid, return dummy value"); + llvm::Value* else_value = types->NullConstant(result_type); + llvm::Value* else_value_len = nullptr; + if (arrow::is_binary_like(arrow_type_id)) { + else_value_len = types->i32_constant(0); + } + return std::make_shared<LValue>(else_value, else_value_len); + }; + + result_ = BuildIfElse(is_valid, then_lambda, else_lambda, arrow_return_type); + } else { + // fast path : invoke function without computing validities. + result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms); + } +} + +void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex& dex) { + ADD_VISITOR_TRACE("visit NullableNever base function " + dex.func_descriptor()->name()); + const NativeFunction* native_function = dex.native_function(); + + // build function params along with validity. 
+ auto params = BuildParams(dex.function_holder().get(), dex.args(), true, + native_function->NeedsContext()); + + auto arrow_return_type = dex.func_descriptor()->return_type(); + result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms); +} + +void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) { + ADD_VISITOR_TRACE("visit NullableInternal base function " + + dex.func_descriptor()->name()); + llvm::IRBuilder<>* builder = ir_builder(); + LLVMTypes* types = generator_->types(); + + const NativeFunction* native_function = dex.native_function(); + + // build function params along with validity. + auto params = BuildParams(dex.function_holder().get(), dex.args(), true, + native_function->NeedsContext()); + + // add an extra arg for validity (allocated on stack). + llvm::AllocaInst* result_valid_ptr = + new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_); + params.push_back(result_valid_ptr); + + auto arrow_return_type = dex.func_descriptor()->return_type(); + result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms); + + // load the result validity and truncate to i1. + llvm::Value* result_valid_i8 = CreateLoad(builder, result_valid_ptr); + llvm::Value* result_valid = builder->CreateTrunc(result_valid_i8, types->i1_type()); + + // set validity bit in the local bitmap. + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), result_valid); +} + +void LLVMGenerator::Visitor::Visit(const IfDex& dex) { + ADD_VISITOR_TRACE("visit IfExpression"); + llvm::IRBuilder<>* builder = ir_builder(); + + // Evaluate condition. + LValuePtr if_condition = BuildValueAndValidity(dex.condition_vv()); + + // Check if the result is valid, and there is match. 
+ llvm::Value* validAndMatched = + builder->CreateAnd(if_condition->data(), if_condition->validity(), "validAndMatch"); + + // then block + auto then_lambda = [&] { + ADD_VISITOR_TRACE("branch to then block"); + LValuePtr then_lvalue = BuildValueAndValidity(dex.then_vv()); + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), then_lvalue->validity()); + ADD_VISITOR_TRACE("IfExpression result validity %T in matching then", + then_lvalue->validity()); + return then_lvalue; + }; + + // else block + auto else_lambda = [&] { + LValuePtr else_lvalue; + if (dex.is_terminal_else()) { + ADD_VISITOR_TRACE("branch to terminal else block"); + + else_lvalue = BuildValueAndValidity(dex.else_vv()); + // update the local bitmap with the validity. + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), else_lvalue->validity()); + ADD_VISITOR_TRACE("IfExpression result validity %T in terminal else", + else_lvalue->validity()); + } else { + ADD_VISITOR_TRACE("branch to non-terminal else block"); + + // this is a non-terminal else. let the child (nested if/else) handle validity. + auto value_expr = dex.else_vv().value_expr(); + value_expr->Accept(*this); + else_lvalue = result(); + } + return else_lvalue; + }; + + // build the if-else condition. 
+ result_ = BuildIfElse(validAndMatched, then_lambda, else_lambda, dex.result_type()); + if (arrow::is_binary_like(dex.result_type()->id())) { + ADD_VISITOR_TRACE("IfElse result length %T", result_->length()); + } + ADD_VISITOR_TRACE("IfElse result value %T", result_->data()); +} + +// Boolean AND +// if any arg is valid and false, +// short-circuit and return FALSE (value=false, valid=true) +// else if all args are valid and true +// return TRUE (value=true, valid=true) +// else +// return NULL (value=true, valid=false) + +void LLVMGenerator::Visitor::Visit(const BooleanAndDex& dex) { + ADD_VISITOR_TRACE("visit BooleanAndExpression"); + llvm::IRBuilder<>* builder = ir_builder(); + LLVMTypes* types = generator_->types(); + llvm::LLVMContext* context = generator_->context(); + + // Create blocks for short-circuit. + llvm::BasicBlock* short_circuit_bb = + llvm::BasicBlock::Create(*context, "short_circuit", function_); + llvm::BasicBlock* non_short_circuit_bb = + llvm::BasicBlock::Create(*context, "non_short_circuit", function_); + llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_); + + llvm::Value* all_exprs_valid = types->true_constant(); + for (auto& pair : dex.args()) { + LValuePtr current = BuildValueAndValidity(*pair); + + ADD_VISITOR_TRACE("BooleanAndExpression arg value %T", current->data()); + ADD_VISITOR_TRACE("BooleanAndExpression arg validity %T", current->validity()); + + // short-circuit if valid and false + llvm::Value* is_false = builder->CreateNot(current->data()); + llvm::Value* valid_and_false = + builder->CreateAnd(is_false, current->validity(), "valid_and_false"); + + llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_); + builder->CreateCondBr(valid_and_false, short_circuit_bb, else_bb); + + // Emit the else block. + builder->SetInsertPoint(else_bb); + // remember if any nulls were encountered. 
+ all_exprs_valid = + builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd"); + // continue to evaluate the next pair in list. + } + builder->CreateBr(non_short_circuit_bb); + + // Short-circuit case (at least one of the expressions is valid and false). + // No need to set validity bit (valid by default). + builder->SetInsertPoint(short_circuit_bb); + ADD_VISITOR_TRACE("BooleanAndExpression result value false"); + ADD_VISITOR_TRACE("BooleanAndExpression result validity true"); + builder->CreateBr(merge_bb); + + // non short-circuit case (All expressions are either true or null). + // result valid if all of the exprs are non-null. + builder->SetInsertPoint(non_short_circuit_bb); + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid); + ADD_VISITOR_TRACE("BooleanAndExpression result value true"); + ADD_VISITOR_TRACE("BooleanAndExpression result validity %T", all_exprs_valid); + builder->CreateBr(merge_bb); + + builder->SetInsertPoint(merge_bb); + llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value"); + result_value->addIncoming(types->false_constant(), short_circuit_bb); + result_value->addIncoming(types->true_constant(), non_short_circuit_bb); + result_.reset(new LValue(result_value)); +} + +// Boolean OR +// if any arg is valid and true, +// short-circuit and return TRUE (value=true, valid=true) +// else if all args are valid and false +// return FALSE (value=false, valid=true) +// else +// return NULL (value=false, valid=false) + +void LLVMGenerator::Visitor::Visit(const BooleanOrDex& dex) { + ADD_VISITOR_TRACE("visit BooleanOrExpression"); + llvm::IRBuilder<>* builder = ir_builder(); + LLVMTypes* types = generator_->types(); + llvm::LLVMContext* context = generator_->context(); + + // Create blocks for short-circuit. 
+ llvm::BasicBlock* short_circuit_bb = + llvm::BasicBlock::Create(*context, "short_circuit", function_); + llvm::BasicBlock* non_short_circuit_bb = + llvm::BasicBlock::Create(*context, "non_short_circuit", function_); + llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_); + + llvm::Value* all_exprs_valid = types->true_constant(); + for (auto& pair : dex.args()) { + LValuePtr current = BuildValueAndValidity(*pair); + + ADD_VISITOR_TRACE("BooleanOrExpression arg value %T", current->data()); + ADD_VISITOR_TRACE("BooleanOrExpression arg validity %T", current->validity()); + + // short-circuit if valid and true. + llvm::Value* valid_and_true = + builder->CreateAnd(current->data(), current->validity(), "valid_and_true"); + + llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_); + builder->CreateCondBr(valid_and_true, short_circuit_bb, else_bb); + + // Emit the else block. + builder->SetInsertPoint(else_bb); + // remember if any nulls were encountered. + all_exprs_valid = + builder->CreateAnd(all_exprs_valid, current->validity(), "validityBitAnd"); + // continue to evaluate the next pair in list. + } + builder->CreateBr(non_short_circuit_bb); + + // Short-circuit case (at least one of the expressions is valid and true). + // No need to set validity bit (valid by default). + builder->SetInsertPoint(short_circuit_bb); + ADD_VISITOR_TRACE("BooleanOrExpression result value true"); + ADD_VISITOR_TRACE("BooleanOrExpression result validity true"); + builder->CreateBr(merge_bb); + + // non short-circuit case (All expressions are either false or null). + // result valid if all of the exprs are non-null. 
+ builder->SetInsertPoint(non_short_circuit_bb); + ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), all_exprs_valid); + ADD_VISITOR_TRACE("BooleanOrExpression result value false"); + ADD_VISITOR_TRACE("BooleanOrExpression result validity %T", all_exprs_valid); + builder->CreateBr(merge_bb); + + builder->SetInsertPoint(merge_bb); + llvm::PHINode* result_value = builder->CreatePHI(types->i1_type(), 2, "res_value"); + result_value->addIncoming(types->true_constant(), short_circuit_bb); + result_value->addIncoming(types->false_constant(), non_short_circuit_bb); + result_.reset(new LValue(result_value)); +} + +template <typename Type> +void LLVMGenerator::Visitor::VisitInExpression(const InExprDexBase<Type>& dex) { + ADD_VISITOR_TRACE("visit In Expression"); + LLVMTypes* types = generator_->types(); + std::vector<llvm::Value*> params; + + const InExprDex<Type>& dex_instance = dynamic_cast<const InExprDex<Type>&>(dex); + /* add the holder at the beginning */ + llvm::Constant* ptr_int_cast = + types->i64_constant((int64_t)(dex_instance.in_holder().get())); + params.push_back(ptr_int_cast); + + /* eval expr result */ + for (auto& pair : dex.args()) { + DexPtr value_expr = pair->value_expr(); + value_expr->Accept(*this); + LValue& result_ref = *result(); + params.push_back(result_ref.data()); + + /* length if the result is a string */ + if (result_ref.length() != nullptr) { + params.push_back(result_ref.length()); + } + + /* push the validity of eval expr result */ + llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs()); + params.push_back(validity_expr); + } + + llvm::Type* ret_type = types->IRType(arrow::Type::type::BOOL); + + llvm::Value* value; + + value = generator_->AddFunctionCall(dex.runtime_function(), ret_type, params); + + result_.reset(new LValue(value)); +} + +template <> +void LLVMGenerator::Visitor::VisitInExpression<gandiva::DecimalScalar128>( + const InExprDexBase<gandiva::DecimalScalar128>& dex) { + ADD_VISITOR_TRACE("visit In 
Expression"); + LLVMTypes* types = generator_->types(); + std::vector<llvm::Value*> params; + DecimalIR decimalIR(generator_->engine_.get()); + + const InExprDex<gandiva::DecimalScalar128>& dex_instance = + dynamic_cast<const InExprDex<gandiva::DecimalScalar128>&>(dex); + /* add the holder at the beginning */ + llvm::Constant* ptr_int_cast = + types->i64_constant((int64_t)(dex_instance.in_holder().get())); + params.push_back(ptr_int_cast); + + /* eval expr result */ + for (auto& pair : dex.args()) { + DexPtr value_expr = pair->value_expr(); + value_expr->Accept(*this); + LValue& result_ref = *result(); + params.push_back(result_ref.data()); + + llvm::Constant* precision = types->i32_constant(dex.get_precision()); + llvm::Constant* scale = types->i32_constant(dex.get_scale()); + params.push_back(precision); + params.push_back(scale); + + /* push the validity of eval expr result */ + llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs()); + params.push_back(validity_expr); + } + + llvm::Type* ret_type = types->IRType(arrow::Type::type::BOOL); + + llvm::Value* value; + + value = decimalIR.CallDecimalFunction(dex.runtime_function(), ret_type, params); + + result_.reset(new LValue(value)); +} + +void LLVMGenerator::Visitor::Visit(const InExprDexBase<int32_t>& dex) { + VisitInExpression<int32_t>(dex); +} + +void LLVMGenerator::Visitor::Visit(const InExprDexBase<int64_t>& dex) { + VisitInExpression<int64_t>(dex); +} + +void LLVMGenerator::Visitor::Visit(const InExprDexBase<float>& dex) { + VisitInExpression<float>(dex); +} +void LLVMGenerator::Visitor::Visit(const InExprDexBase<double>& dex) { + VisitInExpression<double>(dex); +} + +void LLVMGenerator::Visitor::Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) { + VisitInExpression<gandiva::DecimalScalar128>(dex); +} + +void LLVMGenerator::Visitor::Visit(const InExprDexBase<std::string>& dex) { + VisitInExpression<std::string>(dex); +} + +LValuePtr 
LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition, + std::function<LValuePtr()> then_func, + std::function<LValuePtr()> else_func, + DataTypePtr result_type) { + llvm::IRBuilder<>* builder = ir_builder(); + llvm::LLVMContext* context = generator_->context(); + LLVMTypes* types = generator_->types(); + + // Create blocks for the then, else and merge cases. + llvm::BasicBlock* then_bb = llvm::BasicBlock::Create(*context, "then", function_); + llvm::BasicBlock* else_bb = llvm::BasicBlock::Create(*context, "else", function_); + llvm::BasicBlock* merge_bb = llvm::BasicBlock::Create(*context, "merge", function_); + + builder->CreateCondBr(condition, then_bb, else_bb); + + // Emit the then block. + builder->SetInsertPoint(then_bb); + LValuePtr then_lvalue = then_func(); + builder->CreateBr(merge_bb); + + // refresh then_bb for phi (could have changed due to code generation of then_vv). + then_bb = builder->GetInsertBlock(); + + // Emit the else block. + builder->SetInsertPoint(else_bb); + LValuePtr else_lvalue = else_func(); + builder->CreateBr(merge_bb); + + // refresh else_bb for phi (could have changed due to code generation of else_vv). + else_bb = builder->GetInsertBlock(); + + // Emit the merge block. 
+ builder->SetInsertPoint(merge_bb); + auto llvm_type = types->IRType(result_type->id()); + llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value"); + result_value->addIncoming(then_lvalue->data(), then_bb); + result_value->addIncoming(else_lvalue->data(), else_bb); + + LValuePtr ret; + switch (result_type->id()) { + case arrow::Type::STRING: + case arrow::Type::BINARY: { + llvm::PHINode* result_length; + result_length = builder->CreatePHI(types->i32_type(), 2, "res_length"); + result_length->addIncoming(then_lvalue->length(), then_bb); + result_length->addIncoming(else_lvalue->length(), else_bb); + ret = std::make_shared<LValue>(result_value, result_length); + break; + } + + case arrow::Type::DECIMAL: + ret = generator_->BuildDecimalLValue(result_value, result_type); + break; + + default: + ret = std::make_shared<LValue>(result_value); + break; + } + return ret; +} + +LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair& pair) { + // generate code for value + auto value_expr = pair.value_expr(); + value_expr->Accept(*this); + auto value = result()->data(); + auto length = result()->length(); + + // generate code for validity + auto validity = BuildCombinedValidity(pair.validity_exprs()); + + return std::make_shared<LValue>(value, length, validity); +} + +LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, + DataTypePtr arrow_return_type, + std::vector<llvm::Value*>* params) { + auto types = generator_->types(); + auto arrow_return_type_id = arrow_return_type->id(); + auto llvm_return_type = types->IRType(arrow_return_type_id); + DecimalIR decimalIR(generator_->engine_.get()); + + if (arrow_return_type_id == arrow::Type::DECIMAL) { + // For decimal fns, the output precision/scale are passed along as parameters. 
+ // + // convert from this : + // out = add_decimal(v1, p1, s1, v2, p2, s2) + // to: + // out = add_decimal(v1, p1, s1, v2, p2, s2, out_p, out_s) + + // Append the out_precision and out_scale + auto ret_lvalue = generator_->BuildDecimalLValue(nullptr, arrow_return_type); + params->push_back(ret_lvalue->precision()); + params->push_back(ret_lvalue->scale()); + + // Make the function call + auto out = decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params); + ret_lvalue->set_data(out); + return std::move(ret_lvalue); + } else { + bool isDecimalFunction = false; + for (auto& arg : *params) { + if (arg->getType() == types->i128_type()) { + isDecimalFunction = true; + } + } + // add extra arg for return length for variable len return types (allocated on stack). + llvm::AllocaInst* result_len_ptr = nullptr; + if (arrow::is_binary_like(arrow_return_type_id)) { + result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0, + "result_len", entry_block_); + params->push_back(result_len_ptr); + has_arena_allocs_ = true; + } + + // Make the function call + llvm::IRBuilder<>* builder = ir_builder(); + auto value = + isDecimalFunction + ? decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params) + : generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params); + auto value_len = + (result_len_ptr == nullptr) ? nullptr : CreateLoad(builder, result_len_ptr); + return std::make_shared<LValue>(value, value_len); + } +} + +std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams( + FunctionHolder* holder, const ValueValidityPairVector& args, bool with_validity, + bool with_context) { + LLVMTypes* types = generator_->types(); + std::vector<llvm::Value*> params; + + // add context if required. + if (with_context) { + params.push_back(arg_context_ptr_); + } + + // if the function has holder, add the holder pointer. 
+ if (holder != nullptr) { + auto ptr = types->i64_constant((int64_t)holder); + params.push_back(ptr); + } + + // build the function params, along with the validities. + for (auto& pair : args) { + // build value. + DexPtr value_expr = pair->value_expr(); + value_expr->Accept(*this); + LValue& result_ref = *result(); + + // append all the parameters corresponding to this LValue (the vector is passed + // by address so that the LValue can push its values onto it). + result_ref.AppendFunctionParams(&params); + + // build validity. + if (with_validity) { + llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs()); + params.push_back(validity_expr); + } + } + + return params; +} + +// Bitwise-AND of a vector of bits to get the combined validity. +llvm::Value* LLVMGenerator::Visitor::BuildCombinedValidity(const DexVector& validities) { + llvm::IRBuilder<>* builder = ir_builder(); + LLVMTypes* types = generator_->types(); + + llvm::Value* isValid = types->true_constant(); + for (auto& dex : validities) { + dex->Accept(*this); + isValid = builder->CreateAnd(isValid, result()->data(), "validityBitAnd"); + } + ADD_VISITOR_TRACE("combined validity is %T", isValid); + return isValid; +} + +llvm::Value* LLVMGenerator::Visitor::GetBufferReference(int idx, BufferType buffer_type, + FieldPtr field) { + llvm::IRBuilder<>* builder = ir_builder(); + + // Switch to the entry block to create a reference. + llvm::BasicBlock* saved_block = builder->GetInsertBlock(); + builder->SetInsertPoint(entry_block_); + + llvm::Value* slot_ref = nullptr; + switch (buffer_type) { + case kBufferTypeValidity: + slot_ref = generator_->GetValidityReference(arg_addrs_, idx, field); + break; + + case kBufferTypeData: + slot_ref = generator_->GetDataReference(arg_addrs_, idx, field); + break; + + case kBufferTypeOffsets: + slot_ref = generator_->GetOffsetsReference(arg_addrs_, idx, field); + break; + } + + // Revert to the saved block.
+ builder->SetInsertPoint(saved_block); + return slot_ref; +} + +llvm::Value* LLVMGenerator::Visitor::GetSliceOffset(int idx) { + return slice_offsets_[idx]; +} + +llvm::Value* LLVMGenerator::Visitor::GetLocalBitMapReference(int idx) { + llvm::IRBuilder<>* builder = ir_builder(); + + // Switch to the entry block to create a reference. + llvm::BasicBlock* saved_block = builder->GetInsertBlock(); + builder->SetInsertPoint(entry_block_); + + llvm::Value* slot_ref = generator_->GetLocalBitMapReference(arg_local_bitmaps_, idx); + + // Revert to the saved block. + builder->SetInsertPoint(saved_block); + return slot_ref; +} + +/// The local bitmap is pre-filled with 1s. Clear only if invalid. +void LLVMGenerator::Visitor::ClearLocalBitMapIfNotValid(int local_bitmap_idx, + llvm::Value* is_valid) { + llvm::Value* slot_ref = GetLocalBitMapReference(local_bitmap_idx); + generator_->ClearPackedBitValueIfFalse(slot_ref, loop_var_, is_valid); +} + +// Hooks for tracing/printfs. +// +// replace %T with the type-specific format specifier. +// For some reason, float/double literals are getting lost when printing with the generic +// printf. so, use a wrapper instead. 
+std::string LLVMGenerator::ReplaceFormatInTrace(const std::string& in_msg, + llvm::Value* value, + std::string* print_fn) { + std::string msg = in_msg; + std::size_t pos = msg.find("%T"); + if (pos == std::string::npos) { + DCHECK(0); + return msg; + } + + llvm::Type* type = value->getType(); + const char* fmt = ""; + if (type->isIntegerTy(1) || type->isIntegerTy(8) || type->isIntegerTy(16) || + type->isIntegerTy(32)) { + fmt = "%d"; + } else if (type->isIntegerTy(64)) { + // bigint + fmt = "%lld"; + } else if (type->isFloatTy()) { + // float (printed through the print_float wrapper) + fmt = "%f"; + *print_fn = "print_float"; + } else if (type->isDoubleTy()) { + // double (printed through the print_double wrapper) + fmt = "%lf"; + *print_fn = "print_double"; + } else if (type->isPointerTy()) { + // pointer (printed as a string) + fmt = "%s"; + } else { + DCHECK(0); + } + // only the first %T (2 characters) is replaced. + msg.replace(pos, 2, fmt); + return msg; +} + +// Emit a call to the print function (printf unless ReplaceFormatInTrace picked a +// float/double wrapper) with the trace message and the optional value. +// Does nothing unless enable_ir_traces_ is set. +void LLVMGenerator::AddTrace(const std::string& msg, llvm::Value* value) { + if (!enable_ir_traces_) { + return; + } + + std::string dmsg = "IR_TRACE:: " + msg + "\n"; + std::string print_fn_name = "printf"; + if (value != nullptr) { + dmsg = ReplaceFormatInTrace(dmsg, value, &print_fn_name); + } + trace_strings_.push_back(dmsg); + + // embed the address of the message stored in trace_strings_ as an i8* constant. + const char* str = trace_strings_.back().c_str(); + llvm::Constant* str_int_cast = types()->i64_constant((int64_t)str); + llvm::Constant* str_ptr_cast = + llvm::ConstantExpr::getIntToPtr(str_int_cast, types()->i8_ptr_type()); + + std::vector<llvm::Value*> args; + args.push_back(str_ptr_cast); + if (value != nullptr) { + args.push_back(value); + } + AddFunctionCall(print_fn_name, types()->i32_type(), args); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/llvm_generator.h b/src/arrow/cpp/src/gandiva/llvm_generator.h new file mode 100644 index 000000000..ff6d84602 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/llvm_generator.h @@ -0,0 +1,253 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> +#include <memory> +#include <string> +#include <vector> + +#include "arrow/util/macros.h" + +#include "gandiva/annotator.h" +#include "gandiva/compiled_expr.h" +#include "gandiva/configuration.h" +#include "gandiva/dex_visitor.h" +#include "gandiva/engine.h" +#include "gandiva/execution_context.h" +#include "gandiva/function_registry.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/llvm_types.h" +#include "gandiva/lvalue.h" +#include "gandiva/selection_vector.h" +#include "gandiva/value_validity_pair.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +class FunctionHolder; + +/// Builds an LLVM module and generates code for the specified set of expressions. +class GANDIVA_EXPORT LLVMGenerator { + public: + /// \brief Factory method to initialize the generator. + static Status Make(std::shared_ptr<Configuration> config, + std::unique_ptr<LLVMGenerator>* llvm_generator); + + /// \brief Build the code for the expression trees for the given selection vector + /// mode. Each element in the vector represents an expression tree + Status Build(const ExpressionVector& exprs, SelectionVector::Mode mode); + + /// \brief Build the code for the expression trees for default mode.
Each + /// element in the vector represents an expression tree. Forwards to the overload + /// above with SelectionVector::Mode::MODE_NONE. + Status Build(const ExpressionVector& exprs) { + return Build(exprs, SelectionVector::Mode::MODE_NONE); + } + + /// \brief Execute the built expression against the provided arguments for + /// default mode. + Status Execute(const arrow::RecordBatch& record_batch, + const ArrayDataVector& output_vector); + + /// \brief Execute the built expression against the provided arguments for + /// all modes. Only works on the records specified in the selection_vector. + Status Execute(const arrow::RecordBatch& record_batch, + const SelectionVector* selection_vector, + const ArrayDataVector& output_vector); + + SelectionVector::Mode selection_vector_mode() { return selection_vector_mode_; } + LLVMTypes* types() { return engine_->types(); } + llvm::Module* module() { return engine_->module(); } + std::string DumpIR() { return engine_->DumpIR(); } + + private: + LLVMGenerator(); + + // NOTE(review): llvm_generator_test.cc in this change defines VerifyPCFunctions and + // TestAdd only; confirm the TestNullInternal friend declaration is still needed. + FRIEND_TEST(TestLLVMGenerator, VerifyPCFunctions); + FRIEND_TEST(TestLLVMGenerator, TestAdd); + FRIEND_TEST(TestLLVMGenerator, TestNullInternal); + + llvm::LLVMContext* context() { return engine_->context(); } + llvm::IRBuilder<>* ir_builder() { return engine_->ir_builder(); } + + /// Visitor to generate the code for a decomposed expression.
+ class Visitor : public DexVisitor { + public: + Visitor(LLVMGenerator* generator, llvm::Function* function, + llvm::BasicBlock* entry_block, llvm::Value* arg_addrs, + llvm::Value* arg_local_bitmaps, std::vector<llvm::Value*> slice_offsets, + llvm::Value* arg_context_ptr, llvm::Value* loop_var); + + void Visit(const VectorReadValidityDex& dex) override; + void Visit(const VectorReadFixedLenValueDex& dex) override; + void Visit(const VectorReadVarLenValueDex& dex) override; + void Visit(const LocalBitMapValidityDex& dex) override; + void Visit(const TrueDex& dex) override; + void Visit(const FalseDex& dex) override; + void Visit(const LiteralDex& dex) override; + void Visit(const NonNullableFuncDex& dex) override; + void Visit(const NullableNeverFuncDex& dex) override; + void Visit(const NullableInternalFuncDex& dex) override; + void Visit(const IfDex& dex) override; + void Visit(const BooleanAndDex& dex) override; + void Visit(const BooleanOrDex& dex) override; + void Visit(const InExprDexBase<int32_t>& dex) override; + void Visit(const InExprDexBase<int64_t>& dex) override; + void Visit(const InExprDexBase<float>& dex) override; + void Visit(const InExprDexBase<double>& dex) override; + void Visit(const InExprDexBase<gandiva::DecimalScalar128>& dex) override; + void Visit(const InExprDexBase<std::string>& dex) override; + template <typename Type> + void VisitInExpression(const InExprDexBase<Type>& dex); + + LValuePtr result() { return result_; } + + bool has_arena_allocs() { return has_arena_allocs_; } + + private: + enum BufferType { kBufferTypeValidity = 0, kBufferTypeData, kBufferTypeOffsets }; + + llvm::IRBuilder<>* ir_builder() { return generator_->ir_builder(); } + llvm::Module* module() { return generator_->module(); } + + // Generate the code to build the combined validity (bitwise and) from the + // vector of validities.
+ llvm::Value* BuildCombinedValidity(const DexVector& validities); + + // Generate the code to build the validity and the value for the given pair. + LValuePtr BuildValueAndValidity(const ValueValidityPair& pair); + + // Generate code to build the params: the context pointer (if with_context), the + // holder pointer (if holder is non-null), then the values of each argument, each + // followed by its combined validity when with_validity is set. + std::vector<llvm::Value*> BuildParams(FunctionHolder* holder, + const ValueValidityPairVector& args, + bool with_validity, bool with_context); + + // Generate code to invoke a function call. + LValuePtr BuildFunctionCall(const NativeFunction* func, DataTypePtr arrow_return_type, + std::vector<llvm::Value*>* params); + + // Generate code for an if-else condition. + // NOTE(review): std::function is used here but <functional> is not included + // directly by this header; presumably it arrives transitively - confirm. + LValuePtr BuildIfElse(llvm::Value* condition, std::function<LValuePtr()> then_func, + std::function<LValuePtr()> else_func, + DataTypePtr arrow_return_type); + + // Switch to the entry_block and get reference of the validity/value/offsets buffer + llvm::Value* GetBufferReference(int idx, BufferType buffer_type, FieldPtr field); + + // Get the slice offset of the validity/value/offsets buffer + llvm::Value* GetSliceOffset(int idx); + + // Switch to the entry_block and get reference to the local bitmap. 
+ llvm::Value* GetLocalBitMapReference(int idx); + + // Clear the bit in the local bitmap, if is_valid is 'false' + void ClearLocalBitMapIfNotValid(int local_bitmap_idx, llvm::Value* is_valid); + + LLVMGenerator* generator_; + LValuePtr result_; + llvm::Function* function_; + llvm::BasicBlock* entry_block_; + llvm::Value* arg_addrs_; + llvm::Value* arg_local_bitmaps_; + std::vector<llvm::Value*> slice_offsets_; + llvm::Value* arg_context_ptr_; + llvm::Value* loop_var_; + bool has_arena_allocs_; + }; + + // Generate the code for one expression for default mode, with the output of + // the expression going to 'output'. + Status Add(const ExpressionPtr expr, const FieldDescriptorPtr output); + + /// Generate code to load the vector at specified index in the 'arg_addrs' array.
+ llvm::Value* LoadVectorAtIndex(llvm::Value* arg_addrs, int idx, + const std::string& name); + + /// Generate code to load the vector at specified index and cast it as bitmap. + llvm::Value* GetValidityReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + + /// Generate code to load the vector at specified index and cast it as data array. + llvm::Value* GetDataReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + + /// Generate code to load the vector at specified index and cast it as offsets array. + llvm::Value* GetOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + + /// Generate code to load the vector at specified index and cast it as buffer pointer. + llvm::Value* GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + + /// Generate code for the value array of one expression. + Status CodeGenExprValue(DexPtr value_expr, int num_buffers, FieldDescriptorPtr output, + int suffix_idx, llvm::Function** fn, + SelectionVector::Mode selection_vector_mode); + + /// Generate code to load the local bitmap at the specified index and cast it as bitmap. + llvm::Value* GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx); + + /// Generate code to get the bit value at 'position' in the bitmap. + llvm::Value* GetPackedBitValue(llvm::Value* bitmap, llvm::Value* position); + + /// Generate code to get the bit value at 'position' in the validity bitmap. + llvm::Value* GetPackedValidityBitValue(llvm::Value* bitmap, llvm::Value* position); + + /// Generate code to set the bit value at 'position' in the bitmap to 'value'. + void SetPackedBitValue(llvm::Value* bitmap, llvm::Value* position, llvm::Value* value); + + /// Generate code to clear the bit value at 'position' in the bitmap if 'value' + /// is false. + void ClearPackedBitValueIfFalse(llvm::Value* bitmap, llvm::Value* position, + llvm::Value* value); + + // Generate code to build a DecimalLValue with specified value/precision/scale.
+ std::shared_ptr<DecimalLValue> BuildDecimalLValue(llvm::Value* value, + DataTypePtr arrow_type); + + /// Generate code to make a function call (to a pre-compiled IR function) which takes + /// 'args' and has a return type 'ret_type'. + llvm::Value* AddFunctionCall(const std::string& full_name, llvm::Type* ret_type, + const std::vector<llvm::Value*>& args); + + /// Compute the result bitmap for the expression. + /// + /// \param[in] compiled_expr the compiled expression (includes the bitmap indices to be + /// used for computing the validity bitmap of the result). + /// \param[in] eval_batch (includes input/output buffer addresses) + /// \param[in] selection_vector the list of selected positions + void ComputeBitMapsForExpr(const CompiledExpr& compiled_expr, + const EvalBatch& eval_batch, + const SelectionVector* selection_vector); + + /// Replace the %T in the trace msg with the format specifier matching the type of + /// 'value', eg. %d for int32, %lld for int64, .. For float/double values 'print_fn' + /// is set to the name of a print wrapper function. + std::string ReplaceFormatInTrace(const std::string& msg, llvm::Value* value, + std::string* print_fn); + + /// Generate the code to print a trace msg with one optional argument (%T) + void AddTrace(const std::string& msg, llvm::Value* value = NULLPTR); + + std::unique_ptr<Engine> engine_; + std::vector<std::unique_ptr<CompiledExpr>> compiled_exprs_; + FunctionRegistry function_registry_; + Annotator annotator_; + SelectionVector::Mode selection_vector_mode_; + + // used for debug + bool enable_ir_traces_; + std::vector<std::string> trace_strings_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/llvm_generator_test.cc b/src/arrow/cpp/src/gandiva/llvm_generator_test.cc new file mode 100644 index 000000000..bdc3b0051 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/llvm_generator_test.cc @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/llvm_generator.h" + +// NOTE(review): TestAdd uses std::array but <array> is not included directly here; +// presumably it arrives transitively - confirm. +#include <memory> +#include <vector> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "gandiva/configuration.h" +#include "gandiva/dex.h" +#include "gandiva/expression.h" +#include "gandiva/func_descriptor.h" +#include "gandiva/function_registry.h" +#include "gandiva/tests/test_util.h" + +namespace gandiva { + +// NOTE(review): add_vector_func_t is not referenced by any test in this file; +// confirm whether it can be removed. +typedef int64_t (*add_vector_func_t)(int64_t* elements, int nelements); + +// Test fixture; registry_ is the function registry iterated by VerifyPCFunctions. +class TestLLVMGenerator : public ::testing::Test { + protected: + FunctionRegistry registry_; +}; + +// Verify that a valid pc function exists for every function in the registry.
+TEST_F(TestLLVMGenerator, VerifyPCFunctions) { + std::unique_ptr<LLVMGenerator> generator; + ASSERT_OK(LLVMGenerator::Make(TestConfiguration(), &generator)); + + llvm::Module* module = generator->module(); + for (auto& iter : registry_) { + EXPECT_NE(module->getFunction(iter.pc_name()), nullptr); + } +} + +TEST_F(TestLLVMGenerator, TestAdd) { + // Setup LLVM generator to do an arithmetic add of two int32 vectors (f0 and f1) + std::unique_ptr<LLVMGenerator> generator; + ASSERT_OK(LLVMGenerator::Make(TestConfiguration(), &generator)); + Annotator annotator; + + // first input: validity and fixed-length value readers for f0. + auto field0 = std::make_shared<arrow::Field>("f0", arrow::int32()); + auto desc0 = annotator.CheckAndAddInputFieldDescriptor(field0); + auto validity_dex0 = std::make_shared<VectorReadValidityDex>(desc0); + auto value_dex0 = std::make_shared<VectorReadFixedLenValueDex>(desc0); + auto pair0 = std::make_shared<ValueValidityPair>(validity_dex0, value_dex0); + + // second input: validity and fixed-length value readers for f1. + auto field1 = std::make_shared<arrow::Field>("f1", arrow::int32()); + auto desc1 = annotator.CheckAndAddInputFieldDescriptor(field1); + auto validity_dex1 = std::make_shared<VectorReadValidityDex>(desc1); + auto value_dex1 = std::make_shared<VectorReadFixedLenValueDex>(desc1); + auto pair1 = std::make_shared<ValueValidityPair>(validity_dex1, value_dex1); + + // look up the native function for add(int32, int32) -> int32 in the generator's registry. + DataTypeVector params{arrow::int32(), arrow::int32()}; + auto func_desc = std::make_shared<FuncDescriptor>("add", params, arrow::int32()); + FunctionSignature signature(func_desc->name(), func_desc->params(), + func_desc->return_type()); + const NativeFunction* native_func = + generator->function_registry_.LookupSignature(signature); + + std::vector<ValueValidityPairPtr> pairs{pair0, pair1}; + auto func_dex = std::make_shared<NonNullableFuncDex>(func_desc, native_func, + FunctionHolderPtr(nullptr), pairs); + + // descriptor for the output field 'out'. + auto field_sum = std::make_shared<arrow::Field>("out", arrow::int32()); + auto desc_sum = annotator.CheckAndAddInputFieldDescriptor(field_sum); + + llvm::Function* ir_func = nullptr; + +
ASSERT_OK(generator->CodeGenExprValue(func_dex, 4, desc_sum, 0, &ir_func, + SelectionVector::MODE_NONE)); + + ASSERT_OK(generator->engine_->FinalizeModule()); + auto ir = generator->engine_->DumpIR(); + EXPECT_THAT(ir, testing::HasSubstr("vector.body")); + + EvalFunc eval_func = (EvalFunc)generator->engine_->CompiledFunction(ir_func); + + constexpr size_t kNumRecords = 4; + std::array<uint32_t, kNumRecords> a0{1, 2, 3, 4}; + std::array<uint32_t, kNumRecords> a1{5, 6, 7, 8}; + uint64_t in_bitmap = 0xffffffffffffffffull; + + std::array<uint32_t, kNumRecords> out{0, 0, 0, 0}; + uint64_t out_bitmap = 0; + + std::array<uint8_t*, 6> addrs{ + reinterpret_cast<uint8_t*>(a0.data()), reinterpret_cast<uint8_t*>(&in_bitmap), + reinterpret_cast<uint8_t*>(a1.data()), reinterpret_cast<uint8_t*>(&in_bitmap), + reinterpret_cast<uint8_t*>(out.data()), reinterpret_cast<uint8_t*>(&out_bitmap), + }; + std::array<int64_t, 6> addr_offsets{0, 0, 0, 0, 0, 0}; + eval_func(addrs.data(), addr_offsets.data(), nullptr, nullptr, + 0 /* dummy context ptr */, kNumRecords); + + EXPECT_THAT(out, testing::ElementsAre(6, 8, 10, 12)); + EXPECT_EQ(out_bitmap, 0ULL); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/llvm_includes.h b/src/arrow/cpp/src/gandiva/llvm_includes.h new file mode 100644 index 000000000..37f915eb5 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/llvm_includes.h @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +// Disable the listed MSVC warnings while the LLVM headers are included; the +// matching pop follows the LLVM_ALIGN definition below. +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4141) +#pragma warning(disable : 4146) +#pragma warning(disable : 4244) +#pragma warning(disable : 4267) +#pragma warning(disable : 4291) +#pragma warning(disable : 4624) +#endif + +#include <llvm/ExecutionEngine/ExecutionEngine.h> +#include <llvm/IR/IRBuilder.h> +#include <llvm/IR/LLVMContext.h> +#include <llvm/IR/Module.h> + +// LLVM_ALIGN wraps the alignment in llvm::Align when LLVM_VERSION_MAJOR >= 10 and +// passes the value through unchanged otherwise. +#if LLVM_VERSION_MAJOR >= 10 +#define LLVM_ALIGN(alignment) (llvm::Align((alignment))) +#else +#define LLVM_ALIGN(alignment) (alignment) +#endif + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +// Workaround for deprecated builder methods as of LLVM 13: ARROW-14363 +// Both helpers derive the element type from the type of 'Ptr' and pass it +// explicitly to the IRBuilder overloads that take a type argument. +// NOTE(review): this relies on getPointerElementType(); confirm it is available +// in every LLVM version this file is built against. +inline llvm::Value* CreateGEP(llvm::IRBuilder<>* builder, llvm::Value* Ptr, + llvm::ArrayRef<llvm::Value*> IdxList, + const llvm::Twine& Name = "") { + return builder->CreateGEP(Ptr->getType()->getScalarType()->getPointerElementType(), Ptr, + IdxList, Name); +} + +// Load through 'Ptr', using the pointee type of 'Ptr' as the loaded type. +inline llvm::LoadInst* CreateLoad(llvm::IRBuilder<>* builder, llvm::Value* Ptr, + const llvm::Twine& Name = "") { + return builder->CreateLoad(Ptr->getType()->getPointerElementType(), Ptr, Name); +} diff --git a/src/arrow/cpp/src/gandiva/llvm_types.cc b/src/arrow/cpp/src/gandiva/llvm_types.cc new file mode 100644 index 000000000..de322a8c0 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/llvm_types.cc @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements.
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/llvm_types.h" + +namespace gandiva { + +// LLVM doesn't distinguish between signed and unsigned types. + +LLVMTypes::LLVMTypes(llvm::LLVMContext& context) : context_(context) { + arrow_id_to_llvm_type_map_ = {{arrow::Type::type::BOOL, i1_type()}, + {arrow::Type::type::INT8, i8_type()}, + {arrow::Type::type::INT16, i16_type()}, + {arrow::Type::type::INT32, i32_type()}, + {arrow::Type::type::INT64, i64_type()}, + {arrow::Type::type::UINT8, i8_type()}, + {arrow::Type::type::UINT16, i16_type()}, + {arrow::Type::type::UINT32, i32_type()}, + {arrow::Type::type::UINT64, i64_type()}, + {arrow::Type::type::FLOAT, float_type()}, + {arrow::Type::type::DOUBLE, double_type()}, + {arrow::Type::type::DATE32, i32_type()}, + {arrow::Type::type::DATE64, i64_type()}, + {arrow::Type::type::TIME32, i32_type()}, + {arrow::Type::type::TIME64, i64_type()}, + {arrow::Type::type::TIMESTAMP, i64_type()}, + {arrow::Type::type::STRING, i8_ptr_type()}, + {arrow::Type::type::BINARY, i8_ptr_type()}, + {arrow::Type::type::DECIMAL, i128_type()}, + {arrow::Type::type::INTERVAL_MONTHS, i32_type()}, + {arrow::Type::type::INTERVAL_DAY_TIME, i64_type()}}; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/llvm_types.h b/src/arrow/cpp/src/gandiva/llvm_types.h new file mode 
100644 index 000000000..d6f095271 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/llvm_types.h @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <map> +#include <vector> + +#include "arrow/util/logging.h" +#include "gandiva/arrow.h" +#include "gandiva/llvm_includes.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// \brief Holder for llvm types, and mappings between arrow types and llvm types. 
+class GANDIVA_EXPORT LLVMTypes { + public: + explicit LLVMTypes(llvm::LLVMContext& context); + + llvm::Type* void_type() { return llvm::Type::getVoidTy(context_); } + + llvm::Type* i1_type() { return llvm::Type::getInt1Ty(context_); } + + llvm::Type* i8_type() { return llvm::Type::getInt8Ty(context_); } + + llvm::Type* i16_type() { return llvm::Type::getInt16Ty(context_); } + + llvm::Type* i32_type() { return llvm::Type::getInt32Ty(context_); } + + llvm::Type* i64_type() { return llvm::Type::getInt64Ty(context_); } + + llvm::Type* i128_type() { return llvm::Type::getInt128Ty(context_); } + + llvm::StructType* i128_split_type() { + // struct with high/low bits (see decimal_ops.cc:DecimalSplit) + return llvm::StructType::get(context_, {i64_type(), i64_type()}, false); + } + + llvm::Type* float_type() { return llvm::Type::getFloatTy(context_); } + + llvm::Type* double_type() { return llvm::Type::getDoubleTy(context_); } + + llvm::PointerType* ptr_type(llvm::Type* type) { return type->getPointerTo(); } + + llvm::PointerType* i8_ptr_type() { return ptr_type(i8_type()); } + + llvm::PointerType* i32_ptr_type() { return ptr_type(i32_type()); } + + llvm::PointerType* i64_ptr_type() { return ptr_type(i64_type()); } + + llvm::PointerType* i128_ptr_type() { return ptr_type(i128_type()); } + + template <typename ctype, size_t N = (sizeof(ctype) * CHAR_BIT)> + llvm::Constant* int_constant(ctype val) { + return llvm::ConstantInt::get(context_, llvm::APInt(N, val)); + } + + llvm::Constant* i1_constant(bool val) { return int_constant<bool, 1>(val); } + llvm::Constant* i8_constant(int8_t val) { return int_constant(val); } + llvm::Constant* i16_constant(int16_t val) { return int_constant(val); } + llvm::Constant* i32_constant(int32_t val) { return int_constant(val); } + llvm::Constant* i64_constant(int64_t val) { return int_constant(val); } + llvm::Constant* i128_constant(int64_t val) { return int_constant<int64_t, 128>(val); } + + llvm::Constant* true_constant() { return 
i1_constant(true); } + llvm::Constant* false_constant() { return i1_constant(false); } + + llvm::Constant* i128_zero() { return i128_constant(0); } + llvm::Constant* i128_one() { return i128_constant(1); } + + llvm::Constant* float_constant(float val) { + return llvm::ConstantFP::get(float_type(), val); + } + + llvm::Constant* double_constant(double val) { + return llvm::ConstantFP::get(double_type(), val); + } + + llvm::Constant* NullConstant(llvm::Type* type) { + if (type->isIntegerTy()) { + return llvm::ConstantInt::get(type, 0); + } else if (type->isFloatingPointTy()) { + return llvm::ConstantFP::get(type, 0); + } else { + DCHECK(type->isPointerTy()); + return llvm::ConstantPointerNull::getNullValue(type); + } + } + + /// For a given data type, find the ir type used for the data vector slot. + llvm::Type* DataVecType(const DataTypePtr& data_type) { + return IRType(data_type->id()); + } + + /// For a given minor type, find the corresponding ir type. + llvm::Type* IRType(arrow::Type::type arrow_type) { + auto found = arrow_id_to_llvm_type_map_.find(arrow_type); + return (found == arrow_id_to_llvm_type_map_.end()) ? NULL : found->second; + } + + std::vector<arrow::Type::type> GetSupportedArrowTypes() { + std::vector<arrow::Type::type> retval; + for (auto const& element : arrow_id_to_llvm_type_map_) { + retval.push_back(element.first); + } + return retval; + } + + private: + std::map<arrow::Type::type, llvm::Type*> arrow_id_to_llvm_type_map_; + + llvm::LLVMContext& context_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/llvm_types_test.cc b/src/arrow/cpp/src/gandiva/llvm_types_test.cc new file mode 100644 index 000000000..666968306 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/llvm_types_test.cc @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/llvm_types.h" + +#include <gtest/gtest.h> + +namespace gandiva { + +class TestLLVMTypes : public ::testing::Test { + protected: + virtual void SetUp() { types_ = new LLVMTypes(context_); } + virtual void TearDown() { delete types_; } + + llvm::LLVMContext context_; + LLVMTypes* types_; +}; + +TEST_F(TestLLVMTypes, TestFound) { + EXPECT_EQ(types_->IRType(arrow::Type::BOOL), types_->i1_type()); + EXPECT_EQ(types_->IRType(arrow::Type::INT32), types_->i32_type()); + EXPECT_EQ(types_->IRType(arrow::Type::INT64), types_->i64_type()); + EXPECT_EQ(types_->IRType(arrow::Type::FLOAT), types_->float_type()); + EXPECT_EQ(types_->IRType(arrow::Type::DOUBLE), types_->double_type()); + EXPECT_EQ(types_->IRType(arrow::Type::DATE64), types_->i64_type()); + EXPECT_EQ(types_->IRType(arrow::Type::TIME64), types_->i64_type()); + EXPECT_EQ(types_->IRType(arrow::Type::TIMESTAMP), types_->i64_type()); + + EXPECT_EQ(types_->DataVecType(arrow::boolean()), types_->i1_type()); + EXPECT_EQ(types_->DataVecType(arrow::int32()), types_->i32_type()); + EXPECT_EQ(types_->DataVecType(arrow::int64()), types_->i64_type()); + EXPECT_EQ(types_->DataVecType(arrow::float32()), types_->float_type()); + EXPECT_EQ(types_->DataVecType(arrow::float64()), types_->double_type()); + EXPECT_EQ(types_->DataVecType(arrow::date64()), types_->i64_type()); + 
EXPECT_EQ(types_->DataVecType(arrow::time64(arrow::TimeUnit::MICRO)), + types_->i64_type()); + EXPECT_EQ(types_->DataVecType(arrow::timestamp(arrow::TimeUnit::MILLI)), + types_->i64_type()); +} + +TEST_F(TestLLVMTypes, TestNotFound) { + EXPECT_EQ(types_->IRType(arrow::Type::SPARSE_UNION), nullptr); + EXPECT_EQ(types_->IRType(arrow::Type::DENSE_UNION), nullptr); + EXPECT_EQ(types_->DataVecType(arrow::null()), nullptr); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/local_bitmaps_holder.h b/src/arrow/cpp/src/gandiva/local_bitmaps_holder.h new file mode 100644 index 000000000..a172fb973 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/local_bitmaps_holder.h @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <utility> +#include <vector> + +#include <arrow/util/logging.h> +#include "gandiva/arrow.h" +#include "gandiva/gandiva_aliases.h" + +namespace gandiva { + +/// \brief The buffers corresponding to one batch of records, used for +/// expression evaluation. 
+class LocalBitMapsHolder {
+ public:
+  /// \param num_records number of records in the batch; determines the size of
+  ///        each bitmap.
+  /// \param num_local_bitmaps number of bitmaps to allocate.
+  LocalBitMapsHolder(int64_t num_records, int num_local_bitmaps);
+
+  int GetNumLocalBitMaps() const { return static_cast<int>(local_bitmaps_vec_.size()); }
+
+  /// Size, in bytes, of each local bitmap.
+  int64_t GetLocalBitMapSize() const { return local_bitmap_size_; }
+
+  /// Array of raw pointers to the bitmaps (null if no bitmaps were requested).
+  uint8_t** GetLocalBitMapArray() const { return local_bitmaps_array_.get(); }
+
+  /// Pointer to the bitmap at position 'idx', which must lie in
+  /// [0, GetNumLocalBitMaps()).
+  uint8_t* GetLocalBitMap(int idx) const {
+    // The pointer array holds exactly GetNumLocalBitMaps() entries, so
+    // idx == GetNumLocalBitMaps() is already out of bounds.
+    DCHECK(idx >= 0 && idx < GetNumLocalBitMaps());
+    return local_bitmaps_array_.get()[idx];
+  }
+
+ private:
+  /// number of records in the current batch.
+  int64_t num_records_;
+
+  /// A container of 'local_bitmaps_', each sized to accommodate 'num_records'.
+  std::vector<std::unique_ptr<uint8_t[]>> local_bitmaps_vec_;
+
+  /// An array of the local bitmaps.
+  std::unique_ptr<uint8_t*[]> local_bitmaps_array_;
+
+  /// Size of each bitmap in bytes (a multiple of 8).
+  int64_t local_bitmap_size_;
+};
+
+inline LocalBitMapsHolder::LocalBitMapsHolder(int64_t num_records, int num_local_bitmaps)
+    : num_records_(num_records) {
+  // alloc an array for the pointers to the bitmaps.
+  if (num_local_bitmaps > 0) {
+    local_bitmaps_array_.reset(new uint8_t*[num_local_bitmaps]);
+  }
+
+  // 64-bit aligned bitmaps: one 8-byte word per 64 records, rounded up.
+  int64_t roundUp64Multiple = (num_records_ + 63) >> 6;
+  local_bitmap_size_ = roundUp64Multiple * 8;
+
+  // Alloc 'num_local_bitmaps_' number of bitmaps, each of capacity 'num_records_'.
+  for (int i = 0; i < num_local_bitmaps; ++i) {
+    // TODO : round-up to a slab friendly multiple.
+    std::unique_ptr<uint8_t[]> bitmap(new uint8_t[local_bitmap_size_]);
+
+    // keep pointer to the bitmap in the array.
+    (local_bitmaps_array_.get())[i] = bitmap.get();
+
+    // pre-fill with 1s (assuming that the probability of is_valid is higher).
+ memset(bitmap.get(), 0xff, local_bitmap_size_); + local_bitmaps_vec_.push_back(std::move(bitmap)); + } +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/lvalue.h b/src/arrow/cpp/src/gandiva/lvalue.h new file mode 100644 index 000000000..df292855b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/lvalue.h @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <vector> + +#include "arrow/util/macros.h" + +#include "arrow/util/logging.h" +#include "gandiva/llvm_includes.h" + +namespace gandiva { + +/// \brief Tracks validity/value builders in LLVM. +class GANDIVA_EXPORT LValue { + public: + explicit LValue(llvm::Value* data, llvm::Value* length = NULLPTR, + llvm::Value* validity = NULLPTR) + : data_(data), length_(length), validity_(validity) {} + virtual ~LValue() = default; + + llvm::Value* data() { return data_; } + llvm::Value* length() { return length_; } + llvm::Value* validity() { return validity_; } + + void set_data(llvm::Value* data) { data_ = data; } + + // Append the params required when passing this as a function parameter. 
+ virtual void AppendFunctionParams(std::vector<llvm::Value*>* params) { + params->push_back(data_); + if (length_ != NULLPTR) { + params->push_back(length_); + } + } + + private: + llvm::Value* data_; + llvm::Value* length_; + llvm::Value* validity_; +}; + +class GANDIVA_EXPORT DecimalLValue : public LValue { + public: + DecimalLValue(llvm::Value* data, llvm::Value* validity, llvm::Value* precision, + llvm::Value* scale) + : LValue(data, NULLPTR, validity), precision_(precision), scale_(scale) {} + + llvm::Value* precision() { return precision_; } + llvm::Value* scale() { return scale_; } + + void AppendFunctionParams(std::vector<llvm::Value*>* params) override { + LValue::AppendFunctionParams(params); + params->push_back(precision_); + params->push_back(scale_); + } + + private: + llvm::Value* precision_; + llvm::Value* scale_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/make_precompiled_bitcode.py b/src/arrow/cpp/src/gandiva/make_precompiled_bitcode.py new file mode 100644 index 000000000..97d96f8a8 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/make_precompiled_bitcode.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import sys + +marker = b"<DATA_CHARS>" + +def expand(data): + """ + Expand *data* as a initializer list of hexadecimal char escapes. + """ + expanded_data = ", ".join([hex(c) for c in bytearray(data)]) + return expanded_data.encode('ascii') + + +def apply_template(template, data): + if template.count(marker) != 1: + raise ValueError("Invalid template") + return template.replace(marker, expand(data)) + + +if __name__ == "__main__": + if len(sys.argv) != 4: + raise ValueError("Usage: {0} <template file> <data file> " + "<output file>".format(sys.argv[0])) + with open(sys.argv[1], "rb") as f: + template = f.read() + with open(sys.argv[2], "rb") as f: + data = f.read() + + expanded_data = apply_template(template, data) + with open(sys.argv[3], "wb") as f: + f.write(expanded_data) diff --git a/src/arrow/cpp/src/gandiva/native_function.h b/src/arrow/cpp/src/gandiva/native_function.h new file mode 100644 index 000000000..1268a2567 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/native_function.h @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <memory> +#include <string> +#include <vector> + +#include "gandiva/arrow.h" +#include "gandiva/function_signature.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +enum ResultNullableType { + /// result validity is an intersection of the validity of the children. + kResultNullIfNull, + /// result is always valid. + kResultNullNever, + /// result validity depends on some internal logic. + kResultNullInternal, +}; + +/// \brief Holder for the mapping from a function in an expression to a +/// precompiled function. +class GANDIVA_EXPORT NativeFunction { + public: + // function attributes. + static constexpr int32_t kNeedsContext = (1 << 1); + static constexpr int32_t kNeedsFunctionHolder = (1 << 2); + static constexpr int32_t kCanReturnErrors = (1 << 3); + + const std::vector<FunctionSignature>& signatures() const { return signatures_; } + std::string pc_name() const { return pc_name_; } + ResultNullableType result_nullable_type() const { return result_nullable_type_; } + + bool NeedsContext() const { return (flags_ & kNeedsContext) != 0; } + bool NeedsFunctionHolder() const { return (flags_ & kNeedsFunctionHolder) != 0; } + bool CanReturnErrors() const { return (flags_ & kCanReturnErrors) != 0; } + + NativeFunction(const std::string& base_name, const std::vector<std::string>& aliases, + const DataTypeVector& param_types, DataTypePtr ret_type, + const ResultNullableType& result_nullable_type, + const std::string& pc_name, int32_t flags = 0) + : signatures_(), + flags_(flags), + result_nullable_type_(result_nullable_type), + pc_name_(pc_name) { + signatures_.push_back(FunctionSignature(base_name, param_types, ret_type)); + for (auto& func_name : aliases) { + signatures_.push_back(FunctionSignature(func_name, param_types, ret_type)); + } + } + + private: + std::vector<FunctionSignature> signatures_; + + /// attributes + int32_t flags_; + ResultNullableType result_nullable_type_; + + /// pre-compiled function name. 
+ std::string pc_name_; +}; + +} // end namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/node.h b/src/arrow/cpp/src/gandiva/node.h new file mode 100644 index 000000000..20807d4a0 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/node.h @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <sstream> +#include <string> +#include <unordered_set> +#include <vector> + +#include "arrow/status.h" + +#include "gandiva/arrow.h" +#include "gandiva/func_descriptor.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/literal_holder.h" +#include "gandiva/node_visitor.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// \brief Represents a node in the expression tree. Validity and value are +/// in a joined state. +class GANDIVA_EXPORT Node { + public: + explicit Node(DataTypePtr return_type) : return_type_(return_type) {} + + virtual ~Node() = default; + + const DataTypePtr& return_type() const { return return_type_; } + + /// Derived classes should simply invoke the Visit api of the visitor. 
+ virtual Status Accept(NodeVisitor& visitor) const = 0; + + virtual std::string ToString() const = 0; + + protected: + DataTypePtr return_type_; +}; + +/// \brief Node in the expression tree, representing a literal. +class GANDIVA_EXPORT LiteralNode : public Node { + public: + LiteralNode(DataTypePtr type, const LiteralHolder& holder, bool is_null) + : Node(type), holder_(holder), is_null_(is_null) {} + + Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + + const LiteralHolder& holder() const { return holder_; } + + bool is_null() const { return is_null_; } + + /// Renders "(const <type>) " followed by "null" for a null literal, or by the + /// value otherwise; FLOAT and DOUBLE values are also followed by their raw + /// bit pattern in hex. + std::string ToString() const override { + std::stringstream ss; + ss << "(const " << return_type()->ToString() << ") "; + if (is_null()) { + ss << std::string("null"); + return ss.str(); + } + + ss << gandiva::ToString(holder_); + // The default formatter prints in decimal, which can lose precision, so also + // print the raw bits in hex. Can't use hexfloat since gcc 4.9 doesn't support it. + if (return_type()->id() == arrow::Type::DOUBLE) { + double dvalue = arrow::util::get<double>(holder_); + uint64_t bits; + memcpy(&bits, &dvalue, sizeof(bits)); + ss << " raw(" << std::hex << bits << ")"; + } else if (return_type()->id() == arrow::Type::FLOAT) { + float fvalue = arrow::util::get<float>(holder_); + uint32_t bits; + memcpy(&bits, &fvalue, sizeof(bits)); + ss << " raw(" << std::hex << bits << ")"; + } + return ss.str(); + } + + private: + LiteralHolder holder_; + bool is_null_; +}; + +/// \brief Node in the expression tree, representing an arrow field. 
+class GANDIVA_EXPORT FieldNode : public Node { + public: + explicit FieldNode(FieldPtr field) : Node(field->type()), field_(field) {} + + Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + + const FieldPtr& field() const { return field_; } + + std::string ToString() const override { + return "(" + field()->type()->ToString() + ") " + field()->name(); + } + + private: + FieldPtr field_; +}; + +/// \brief Node in the expression tree, representing a function. +class GANDIVA_EXPORT FunctionNode : public Node { + public: + FunctionNode(const std::string& name, const NodeVector& children, DataTypePtr retType); + + Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + + const FuncDescriptorPtr& descriptor() const { return descriptor_; } + const NodeVector& children() const { return children_; } + + std::string ToString() const override { + std::stringstream ss; + ss << descriptor()->return_type()->ToString() << " " << descriptor()->name() << "("; + bool skip_comma = true; + for (auto& child : children()) { + if (skip_comma) { + ss << child->ToString(); + skip_comma = false; + } else { + ss << ", " << child->ToString(); + } + } + ss << ")"; + return ss.str(); + } + + private: + FuncDescriptorPtr descriptor_; + NodeVector children_; +}; + +inline FunctionNode::FunctionNode(const std::string& name, const NodeVector& children, + DataTypePtr return_type) + : Node(return_type), children_(children) { + DataTypeVector param_types; + for (auto& child : children) { + param_types.push_back(child->return_type()); + } + + descriptor_ = FuncDescriptorPtr(new FuncDescriptor(name, param_types, return_type)); +} + +/// \brief Node in the expression tree, representing an if-else expression. 
+class GANDIVA_EXPORT IfNode : public Node { + public: + IfNode(NodePtr condition, NodePtr then_node, NodePtr else_node, DataTypePtr result_type) + : Node(result_type), + condition_(condition), + then_node_(then_node), + else_node_(else_node) {} + + Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + + const NodePtr& condition() const { return condition_; } + const NodePtr& then_node() const { return then_node_; } + const NodePtr& else_node() const { return else_node_; } + + std::string ToString() const override { + std::stringstream ss; + ss << "if (" << condition()->ToString() << ") { "; + ss << then_node()->ToString() << " } else { "; + ss << else_node()->ToString() << " }"; + return ss.str(); + } + + private: + NodePtr condition_; + NodePtr then_node_; + NodePtr else_node_; +}; + +/// \brief Node in the expression tree, representing an and/or boolean expression. +class GANDIVA_EXPORT BooleanNode : public Node { + public: + enum ExprType : char { AND, OR }; + + BooleanNode(ExprType expr_type, const NodeVector& children) + : Node(arrow::boolean()), expr_type_(expr_type), children_(children) {} + + Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); } + + ExprType expr_type() const { return expr_type_; } + + const NodeVector& children() const { return children_; } + + std::string ToString() const override { + std::stringstream ss; + bool first = true; + for (auto& child : children_) { + if (!first) { + if (expr_type() == BooleanNode::AND) { + ss << " && "; + } else { + ss << " || "; + } + } + ss << child->ToString(); + first = false; + } + return ss.str(); + } + + private: + ExprType expr_type_; + NodeVector children_; +}; + +/// \brief Node in expression tree, representing an in expression. 
+/// \tparam Type C++ type of the constants on the right-hand side of the IN.
+template <typename Type>
+class InExpressionNode : public Node {
+ public:
+  InExpressionNode(NodePtr eval_expr, const std::unordered_set<Type>& values)
+      : Node(arrow::boolean()), eval_expr_(eval_expr), values_(values) {}
+
+  const NodePtr& eval_expr() const { return eval_expr_; }
+
+  const std::unordered_set<Type>& values() const { return values_; }
+
+  Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); }
+
+  /// Renders "<expr> IN (v1, v2, ...)". The values come from an unordered set,
+  /// so their order in the output is unspecified.
+  std::string ToString() const override {
+    std::stringstream ss;
+    ss << eval_expr_->ToString() << " IN (";
+    bool add_comma = false;
+    for (auto& value : values_) {
+      if (add_comma) {
+        ss << ", ";
+      }
+      // only the value itself is printed; no type prefix is added.
+      ss << value;
+      add_comma = true;
+    }
+    ss << ")";
+    return ss.str();
+  }
+
+ private:
+  NodePtr eval_expr_;
+  std::unordered_set<Type> values_;
+};
+
+/// \brief Specialization for decimal constants, which additionally carries the
+/// precision and scale given at construction.
+template <>
+class InExpressionNode<gandiva::DecimalScalar128> : public Node {
+ public:
+  /// 'values' is taken by value and moved into the node: lvalue arguments are
+  /// copied (the caller's set is left untouched), temporaries are moved.
+  InExpressionNode(NodePtr eval_expr,
+                   std::unordered_set<gandiva::DecimalScalar128> values,
+                   int32_t precision, int32_t scale)
+      : Node(arrow::boolean()),
+        eval_expr_(std::move(eval_expr)),
+        values_(std::move(values)),
+        precision_(precision),
+        scale_(scale) {}
+
+  int32_t get_precision() const { return precision_; }
+
+  int32_t get_scale() const { return scale_; }
+
+  const NodePtr& eval_expr() const { return eval_expr_; }
+
+  const std::unordered_set<gandiva::DecimalScalar128>& values() const { return values_; }
+
+  Status Accept(NodeVisitor& visitor) const override { return visitor.Visit(*this); }
+
+  /// Renders "<expr> IN (v1, v2, ...)". The values come from an unordered set,
+  /// so their order in the output is unspecified.
+  std::string ToString() const override {
+    std::stringstream ss;
+    ss << eval_expr_->ToString() << " IN (";
+    bool add_comma = false;
+    for (auto& value : values_) {
+      if (add_comma) {
+        ss << ", ";
+      }
+      // only the value itself is printed; no type prefix is added.
+      ss << value;
+      add_comma = true;
+    }
+    ss << ")";
+    return ss.str();
+  }
+
+ private:
+  NodePtr eval_expr_;
+  std::unordered_set<gandiva::DecimalScalar128> values_;
+  int32_t precision_, scale_;
+}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/node_visitor.h b/src/arrow/cpp/src/gandiva/node_visitor.h new file mode 100644 index 000000000..8f233f5b7 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/node_visitor.h @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cmath> +#include <string> + +#include "arrow/status.h" + +#include "arrow/util/logging.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +class FieldNode; +class FunctionNode; +class IfNode; +class LiteralNode; +class BooleanNode; +template <typename Type> +class InExpressionNode; + +/// \brief Visitor for nodes in the expression tree. 
+class GANDIVA_EXPORT NodeVisitor { + public: + virtual ~NodeVisitor() = default; + + virtual Status Visit(const FieldNode& node) = 0; + virtual Status Visit(const FunctionNode& node) = 0; + virtual Status Visit(const IfNode& node) = 0; + virtual Status Visit(const LiteralNode& node) = 0; + virtual Status Visit(const BooleanNode& node) = 0; + virtual Status Visit(const InExpressionNode<int32_t>& node) = 0; + virtual Status Visit(const InExpressionNode<int64_t>& node) = 0; + virtual Status Visit(const InExpressionNode<float>& node) = 0; + virtual Status Visit(const InExpressionNode<double>& node) = 0; + virtual Status Visit(const InExpressionNode<gandiva::DecimalScalar128>& node) = 0; + virtual Status Visit(const InExpressionNode<std::string>& node) = 0; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/pch.h b/src/arrow/cpp/src/gandiva/pch.h new file mode 100644 index 000000000..f3d9b2fad --- /dev/null +++ b/src/arrow/cpp/src/gandiva/pch.h @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Often-used headers, for precompiling. +// If updating this header, please make sure you check compilation speed +// before checking in. 
Adding headers which are not used extremely often +// may incur a slowdown, since it makes the precompiled header heavier to load. + +#include "arrow/pch.h" +#include "gandiva/llvm_types.h" diff --git a/src/arrow/cpp/src/gandiva/precompiled/CMakeLists.txt b/src/arrow/cpp/src/gandiva/precompiled/CMakeLists.txt new file mode 100644 index 000000000..650b80f6b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/CMakeLists.txt @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +project(gandiva) + +set(PRECOMPILED_SRCS + arithmetic_ops.cc + bitmap.cc + decimal_ops.cc + decimal_wrapper.cc + extended_math_ops.cc + hash.cc + print.cc + string_ops.cc + time.cc + timestamp_arithmetic.cc + ../../arrow/util/basic_decimal.cc) + +if(MSVC) + # clang pretends to be a particular version of MSVC. 
# 191[0-9] is Visual Studio 2017, and the standard library uses C++14 features,
# so we have to use that -std version to get the IR compilation to work
  if(MSVC_VERSION MATCHES "^191[0-9]$")
    set(FMS_COMPATIBILITY 19.10)
  else()
    message(FATAL_ERROR "Unsupported MSVC_VERSION=${MSVC_VERSION}")
  endif()
  set(PLATFORM_CLANG_OPTIONS -std=c++14 -fms-compatibility
                             -fms-compatibility-version=${FMS_COMPATIBILITY})
else()
  set(PLATFORM_CLANG_OPTIONS -std=c++11)
endif()

# Create bitcode for each of the source files.
#
# BC_FILES collects every generated .bc path for the link step below. Reset it
# first (the same way PRECOMPILE_COMMAND is reset on every iteration) so that a
# value inherited from an enclosing scope can never leak extra inputs into the
# llvm-link command.
set(BC_FILES "")
foreach(SRC_FILE ${PRECOMPILED_SRCS})
  # SRC_BASE names the output file; ABSOLUTE_SRC is what clang actually compiles.
  get_filename_component(SRC_BASE ${SRC_FILE} NAME_WE)
  get_filename_component(ABSOLUTE_SRC ${SRC_FILE} ABSOLUTE)
  set(BC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${SRC_BASE}.bc)
  set(PRECOMPILE_COMMAND)
  if(CMAKE_OSX_SYSROOT)
    # Run clang through "cmake -E env" so SDKROOT is set for the invocation.
    list(APPEND
         PRECOMPILE_COMMAND
         ${CMAKE_COMMAND}
         -E
         env
         SDKROOT=${CMAKE_OSX_SYSROOT})
  endif()
  list(APPEND
       PRECOMPILE_COMMAND
       ${CLANG_EXECUTABLE}
       ${PLATFORM_CLANG_OPTIONS}
       -DGANDIVA_IR
       -DNDEBUG # DCHECK macros not implemented in precompiled code
       -DARROW_STATIC # Do not set __declspec(dllimport) on MSVC on Arrow symbols
       -DGANDIVA_STATIC # Do not set __declspec(dllimport) on MSVC on Gandiva symbols
       -fno-use-cxa-atexit # Workaround for unresolved __dso_handle
       -emit-llvm
       -O3
       -c
       ${ABSOLUTE_SRC}
       -o
       ${BC_FILE}
       ${ARROW_GANDIVA_PC_CXX_FLAGS}
       -I${CMAKE_SOURCE_DIR}/src
       -I${ARROW_BINARY_DIR}/src)

  if(NOT ARROW_USE_NATIVE_INT128)
    list(APPEND PRECOMPILE_COMMAND -I${Boost_INCLUDE_DIR})
  endif()
  # Depend on the absolute path that is passed to clang: an absolute DEPENDS
  # entry is always a file-level dependency, whereas a bare relative name is
  # first matched against target names.
  add_custom_command(OUTPUT ${BC_FILE}
                     COMMAND ${PRECOMPILE_COMMAND}
                     DEPENDS ${ABSOLUTE_SRC})
  list(APPEND BC_FILES ${BC_FILE})
endforeach()

# link all of the bitcode files into a single bitcode file.
add_custom_command(OUTPUT ${GANDIVA_PRECOMPILED_BC_PATH}
                   COMMAND ${LLVM_LINK_EXECUTABLE} -o ${GANDIVA_PRECOMPILED_BC_PATH}
                           ${BC_FILES}
                   DEPENDS ${BC_FILES})

# turn the bitcode file into a C++ static data variable.
+add_custom_command(OUTPUT ${GANDIVA_PRECOMPILED_CC_PATH} + COMMAND ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_SOURCE_DIR}/../make_precompiled_bitcode.py" + ${GANDIVA_PRECOMPILED_CC_IN_PATH} + ${GANDIVA_PRECOMPILED_BC_PATH} ${GANDIVA_PRECOMPILED_CC_PATH} + DEPENDS ${GANDIVA_PRECOMPILED_CC_IN_PATH} + ${GANDIVA_PRECOMPILED_BC_PATH}) + +add_custom_target(precompiled ALL DEPENDS ${GANDIVA_PRECOMPILED_BC_PATH} + ${GANDIVA_PRECOMPILED_CC_PATH}) + +# testing +if(ARROW_BUILD_TESTS) + add_executable(gandiva-precompiled-test + ../context_helper.cc + bitmap_test.cc + bitmap.cc + epoch_time_point_test.cc + time_test.cc + time.cc + timestamp_arithmetic.cc + ../cast_time.cc + ../../arrow/vendored/datetime/tz.cpp + hash_test.cc + hash.cc + string_ops_test.cc + string_ops.cc + arithmetic_ops_test.cc + arithmetic_ops.cc + extended_math_ops_test.cc + extended_math_ops.cc + decimal_ops_test.cc + decimal_ops.cc + ../decimal_type_util.cc + ../decimal_xlarge.cc) + target_include_directories(gandiva-precompiled-test PRIVATE ${CMAKE_SOURCE_DIR}/src) + target_link_libraries(gandiva-precompiled-test PRIVATE ${ARROW_TEST_LINK_LIBS}) + target_compile_definitions(gandiva-precompiled-test PRIVATE GANDIVA_UNIT_TEST=1 + ARROW_STATIC GANDIVA_STATIC) + set(TEST_PATH "${EXECUTABLE_OUTPUT_PATH}/gandiva-precompiled-test") + add_test(gandiva-precompiled-test ${TEST_PATH}) + set_property(TEST gandiva-precompiled-test + APPEND + PROPERTY LABELS "unittest;gandiva-tests") + add_dependencies(gandiva-tests gandiva-precompiled-test) +endif() diff --git a/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops.cc b/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops.cc new file mode 100644 index 000000000..c736c38d3 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops.cc @@ -0,0 +1,274 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern "C" { + +#include <math.h> +#include "./types.h" + +// Expand inner macro for all numeric types. +#define NUMERIC_TYPES(INNER, NAME, OP) \ + INNER(NAME, int8, OP) \ + INNER(NAME, int16, OP) \ + INNER(NAME, int32, OP) \ + INNER(NAME, int64, OP) \ + INNER(NAME, uint8, OP) \ + INNER(NAME, uint16, OP) \ + INNER(NAME, uint32, OP) \ + INNER(NAME, uint64, OP) \ + INNER(NAME, float32, OP) \ + INNER(NAME, float64, OP) + +// Expand inner macros for all date/time types. +#define DATE_TYPES(INNER, NAME, OP) \ + INNER(NAME, date64, OP) \ + INNER(NAME, date32, OP) \ + INNER(NAME, timestamp, OP) \ + INNER(NAME, time32, OP) + +#define NUMERIC_DATE_TYPES(INNER, NAME, OP) \ + NUMERIC_TYPES(INNER, NAME, OP) \ + DATE_TYPES(INNER, NAME, OP) + +#define NUMERIC_BOOL_DATE_TYPES(INNER, NAME, OP) \ + NUMERIC_TYPES(INNER, NAME, OP) \ + DATE_TYPES(INNER, NAME, OP) \ + INNER(NAME, boolean, OP) + +#define MOD_OP(NAME, IN_TYPE1, IN_TYPE2, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE NAME##_##IN_TYPE1##_##IN_TYPE2(gdv_##IN_TYPE1 left, \ + gdv_##IN_TYPE2 right) { \ + return (right == 0 ? static_cast<gdv_##OUT_TYPE>(left) \ + : static_cast<gdv_##OUT_TYPE>(left % right)); \ + } + +// Symmetric binary fns : left, right params and return type are same. 
+#define BINARY_SYMMETRIC(NAME, TYPE, OP) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_##TYPE##_##TYPE(gdv_##TYPE left, gdv_##TYPE right) { \ + return static_cast<gdv_##TYPE>(left OP right); \ + } + +NUMERIC_TYPES(BINARY_SYMMETRIC, add, +) +NUMERIC_TYPES(BINARY_SYMMETRIC, subtract, -) +NUMERIC_TYPES(BINARY_SYMMETRIC, multiply, *) +BINARY_SYMMETRIC(bitwise_and, int32, &) +BINARY_SYMMETRIC(bitwise_and, int64, &) +BINARY_SYMMETRIC(bitwise_or, int32, |) +BINARY_SYMMETRIC(bitwise_or, int64, |) +BINARY_SYMMETRIC(bitwise_xor, int32, ^) +BINARY_SYMMETRIC(bitwise_xor, int64, ^) + +#undef BINARY_SYMMETRIC + +MOD_OP(mod, int64, int32, int32) +MOD_OP(mod, int64, int64, int64) + +#undef MOD_OP + +gdv_float64 mod_float64_float64(int64_t context, gdv_float64 x, gdv_float64 y) { + if (y == 0.0) { + char const* err_msg = "divide by zero error"; + gdv_fn_context_set_error_msg(context, err_msg); + return 0.0; + } + return fmod(x, y); +} + +// Relational binary fns : left, right params are same, return is bool. +#define BINARY_RELATIONAL(NAME, TYPE, OP) \ + FORCE_INLINE \ + bool NAME##_##TYPE##_##TYPE(gdv_##TYPE left, gdv_##TYPE right) { return left OP right; } + +NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL, equal, ==) +NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL, not_equal, !=) +NUMERIC_DATE_TYPES(BINARY_RELATIONAL, less_than, <) +NUMERIC_DATE_TYPES(BINARY_RELATIONAL, less_than_or_equal_to, <=) +NUMERIC_DATE_TYPES(BINARY_RELATIONAL, greater_than, >) +NUMERIC_DATE_TYPES(BINARY_RELATIONAL, greater_than_or_equal_to, >=) + +#undef BINARY_RELATIONAL + +// cast fns : takes one param type, returns another type. 
+#define CAST_UNARY(NAME, IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE NAME##_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(in); \ + } + +CAST_UNARY(castBIGINT, int32, int64) +CAST_UNARY(castINT, int64, int32) +CAST_UNARY(castFLOAT4, int32, float32) +CAST_UNARY(castFLOAT4, int64, float32) +CAST_UNARY(castFLOAT8, int32, float64) +CAST_UNARY(castFLOAT8, int64, float64) +CAST_UNARY(castFLOAT8, float32, float64) +CAST_UNARY(castFLOAT4, float64, float32) + +#undef CAST_UNARY + +// cast float types to int types. +#define CAST_INT_FLOAT(NAME, IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE NAME##_##IN_TYPE(gdv_##IN_TYPE in) { \ + gdv_##OUT_TYPE out = static_cast<gdv_##OUT_TYPE>(round(in)); \ + return out; \ + } + +CAST_INT_FLOAT(castBIGINT, float32, int64) +CAST_INT_FLOAT(castBIGINT, float64, int64) +CAST_INT_FLOAT(castINT, float32, int32) +CAST_INT_FLOAT(castINT, float64, int32) + +#undef CAST_INT_FLOAT + +// simple nullable functions, result value = fn(input validity) +#define VALIDITY_OP(NAME, TYPE, OP) \ + FORCE_INLINE \ + bool NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid) { return OP is_valid; } + +NUMERIC_BOOL_DATE_TYPES(VALIDITY_OP, isnull, !) 
+NUMERIC_BOOL_DATE_TYPES(VALIDITY_OP, isnotnull, +) +NUMERIC_TYPES(VALIDITY_OP, isnumeric, +) + +#undef VALIDITY_OP + +#define NUMERIC_FUNCTION(INNER) \ + INNER(int8) \ + INNER(int16) \ + INNER(int32) \ + INNER(int64) \ + INNER(uint8) \ + INNER(uint16) \ + INNER(uint32) \ + INNER(uint64) \ + INNER(float32) \ + INNER(float64) + +#define DATE_FUNCTION(INNER) \ + INNER(date32) \ + INNER(date64) \ + INNER(timestamp) \ + INNER(time32) + +#define NUMERIC_BOOL_DATE_FUNCTION(INNER) \ + NUMERIC_FUNCTION(INNER) \ + DATE_FUNCTION(INNER) \ + INNER(boolean) + +FORCE_INLINE +gdv_boolean not_boolean(gdv_boolean in) { return !in; } + +// is_distinct_from +#define IS_DISTINCT_FROM(TYPE) \ + FORCE_INLINE \ + bool is_distinct_from_##TYPE##_##TYPE(gdv_##TYPE in1, gdv_boolean is_valid1, \ + gdv_##TYPE in2, gdv_boolean is_valid2) { \ + if (is_valid1 != is_valid2) { \ + return true; \ + } \ + if (!is_valid1) { \ + return false; \ + } \ + return in1 != in2; \ + } + +// is_not_distinct_from +#define IS_NOT_DISTINCT_FROM(TYPE) \ + FORCE_INLINE \ + bool is_not_distinct_from_##TYPE##_##TYPE(gdv_##TYPE in1, gdv_boolean is_valid1, \ + gdv_##TYPE in2, gdv_boolean is_valid2) { \ + if (is_valid1 != is_valid2) { \ + return false; \ + } \ + if (!is_valid1) { \ + return true; \ + } \ + return in1 == in2; \ + } + +NUMERIC_BOOL_DATE_FUNCTION(IS_DISTINCT_FROM) +NUMERIC_BOOL_DATE_FUNCTION(IS_NOT_DISTINCT_FROM) + +#undef IS_DISTINCT_FROM +#undef IS_NOT_DISTINCT_FROM + +#define DIVIDE(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE divide_##TYPE##_##TYPE(gdv_int64 context, gdv_##TYPE in1, gdv_##TYPE in2) { \ + if (in2 == 0) { \ + char const* err_msg = "divide by zero error"; \ + gdv_fn_context_set_error_msg(context, err_msg); \ + return 0; \ + } \ + return static_cast<gdv_##TYPE>(in1 / in2); \ + } + +NUMERIC_FUNCTION(DIVIDE) + +#undef DIVIDE + +#define DIV(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE div_##TYPE##_##TYPE(gdv_int64 context, gdv_##TYPE in1, gdv_##TYPE in2) { \ + if (in2 == 0) { \ + char const* err_msg = 
"divide by zero error"; \ + gdv_fn_context_set_error_msg(context, err_msg); \ + return 0; \ + } \ + return static_cast<gdv_##TYPE>(in1 / in2); \ + } + +DIV(int32) +DIV(int64) + +#undef DIV + +#define DIV_FLOAT(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE div_##TYPE##_##TYPE(gdv_int64 context, gdv_##TYPE in1, gdv_##TYPE in2) { \ + if (in2 == 0) { \ + char const* err_msg = "divide by zero error"; \ + gdv_fn_context_set_error_msg(context, err_msg); \ + return 0; \ + } \ + return static_cast<gdv_##TYPE>(::trunc(in1 / in2)); \ + } + +DIV_FLOAT(float32) +DIV_FLOAT(float64) + +#undef DIV_FLOAT + +#define BITWISE_NOT(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE bitwise_not_##TYPE(gdv_##TYPE in) { return static_cast<gdv_##TYPE>(~in); } + +BITWISE_NOT(int32) +BITWISE_NOT(int64) + +#undef BITWISE_NOT + +#undef DATE_FUNCTION +#undef DATE_TYPES +#undef NUMERIC_BOOL_DATE_TYPES +#undef NUMERIC_DATE_TYPES +#undef NUMERIC_FUNCTION +#undef NUMERIC_TYPES + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc b/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc new file mode 100644 index 000000000..36b50bcfd --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "../execution_context.h" +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +TEST(TestArithmeticOps, TestIsDistinctFrom) { + EXPECT_EQ(is_distinct_from_timestamp_timestamp(1000, true, 1000, false), true); + EXPECT_EQ(is_distinct_from_timestamp_timestamp(1000, false, 1000, true), true); + EXPECT_EQ(is_distinct_from_timestamp_timestamp(1000, false, 1000, false), false); + EXPECT_EQ(is_distinct_from_timestamp_timestamp(1000, true, 1000, true), false); + + EXPECT_EQ(is_not_distinct_from_int32_int32(1000, true, 1000, false), false); + EXPECT_EQ(is_not_distinct_from_int32_int32(1000, false, 1000, true), false); + EXPECT_EQ(is_not_distinct_from_int32_int32(1000, false, 1000, false), true); + EXPECT_EQ(is_not_distinct_from_int32_int32(1000, true, 1000, true), true); +} + +TEST(TestArithmeticOps, TestMod) { + gandiva::ExecutionContext context; + EXPECT_EQ(mod_int64_int32(10, 0), 10); + + const double acceptable_abs_error = 0.00000000001; // 1e-10 + + EXPECT_DOUBLE_EQ(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 2.5, 0.0), + 0.0); + EXPECT_TRUE(context.has_error()); + EXPECT_EQ(context.get_error(), "divide by zero error"); + + context.Reset(); + EXPECT_NEAR(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 2.5, 1.2), 0.1, + acceptable_abs_error); + EXPECT_FALSE(context.has_error()); + + context.Reset(); + EXPECT_DOUBLE_EQ(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 2.5, 2.5), + 0.0); + EXPECT_FALSE(context.has_error()); + + context.Reset(); + EXPECT_NEAR(mod_float64_float64(reinterpret_cast<gdv_int64>(&context), 9.2, 3.7), 1.8, + acceptable_abs_error); + EXPECT_FALSE(context.has_error()); +} + +TEST(TestArithmeticOps, TestDivide) { + gandiva::ExecutionContext context; + EXPECT_EQ(divide_int64_int64(reinterpret_cast<gdv_int64>(&context), 10, 0), 0); + 
EXPECT_EQ(context.has_error(), true); + EXPECT_EQ(context.get_error(), "divide by zero error"); + + context.Reset(); + EXPECT_EQ(divide_int64_int64(reinterpret_cast<gdv_int64>(&context), 10, 2), 5); + EXPECT_EQ(context.has_error(), false); +} + +TEST(TestArithmeticOps, TestDiv) { + gandiva::ExecutionContext context; + EXPECT_EQ(div_int64_int64(reinterpret_cast<gdv_int64>(&context), 101, 0), 0); + EXPECT_EQ(context.has_error(), true); + EXPECT_EQ(context.get_error(), "divide by zero error"); + context.Reset(); + + EXPECT_EQ(div_int64_int64(reinterpret_cast<gdv_int64>(&context), 101, 111), 0); + EXPECT_EQ(context.has_error(), false); + context.Reset(); + + EXPECT_EQ(div_float64_float64(reinterpret_cast<gdv_int64>(&context), 1010.1010, 2.1), + 481.0); + EXPECT_EQ(context.has_error(), false); + context.Reset(); + + EXPECT_EQ( + div_float64_float64(reinterpret_cast<gdv_int64>(&context), 1010.1010, 0.00000), + 0.0); + EXPECT_EQ(context.has_error(), true); + EXPECT_EQ(context.get_error(), "divide by zero error"); + context.Reset(); + + EXPECT_EQ(div_float32_float32(reinterpret_cast<gdv_int64>(&context), 1010.1010f, 2.1f), + 481.0f); + EXPECT_EQ(context.has_error(), false); + context.Reset(); +} + +TEST(TestArithmeticOps, TestBitwiseOps) { + // bitwise AND + EXPECT_EQ(bitwise_and_int32_int32(0x0147D, 0x17159), 0x01059); + EXPECT_EQ(bitwise_and_int32_int32(0xFFFFFFCC, 0x00000297), 0x00000284); + EXPECT_EQ(bitwise_and_int32_int32(0x000, 0x285), 0x000); + EXPECT_EQ(bitwise_and_int64_int64(0x563672F83, 0x0D9FCF85B), 0x041642803); + EXPECT_EQ(bitwise_and_int64_int64(0xFFFFFFFFFFDA8F6A, 0xFFFFFFFFFFFF791C), + 0xFFFFFFFFFFDA0908); + EXPECT_EQ(bitwise_and_int64_int64(0x6A5B1, 0x00000), 0x00000); + + // bitwise OR + EXPECT_EQ(bitwise_or_int32_int32(0x0147D, 0x17159), 0x1757D); + EXPECT_EQ(bitwise_or_int32_int32(0xFFFFFFCC, 0x00000297), 0xFFFFFFDF); + EXPECT_EQ(bitwise_or_int32_int32(0x000, 0x285), 0x285); + EXPECT_EQ(bitwise_or_int64_int64(0x563672F83, 0x0D9FCF85B), 0x5FBFFFFDB); + 
EXPECT_EQ(bitwise_or_int64_int64(0xFFFFFFFFFFDA8F6A, 0xFFFFFFFFFFFF791C), + 0xFFFFFFFFFFFFFF7E); + EXPECT_EQ(bitwise_or_int64_int64(0x6A5B1, 0x00000), 0x6A5B1); + + // bitwise XOR + EXPECT_EQ(bitwise_xor_int32_int32(0x0147D, 0x17159), 0x16524); + EXPECT_EQ(bitwise_xor_int32_int32(0xFFFFFFCC, 0x00000297), 0XFFFFFD5B); + EXPECT_EQ(bitwise_xor_int32_int32(0x000, 0x285), 0x285); + EXPECT_EQ(bitwise_xor_int64_int64(0x563672F83, 0x0D9FCF85B), 0x5BA9BD7D8); + EXPECT_EQ(bitwise_xor_int64_int64(0xFFFFFFFFFFDA8F6A, 0xFFFFFFFFFFFF791C), 0X25F676); + EXPECT_EQ(bitwise_xor_int64_int64(0x6A5B1, 0x00000), 0x6A5B1); + EXPECT_EQ(bitwise_xor_int64_int64(0x6A5B1, 0x6A5B1), 0x00000); + + // bitwise NOT + EXPECT_EQ(bitwise_not_int32(0x00017159), 0xFFFE8EA6); + EXPECT_EQ(bitwise_not_int32(0xFFFFF226), 0x00000DD9); + EXPECT_EQ(bitwise_not_int64(0x000000008BCAE9B4), 0xFFFFFFFF7435164B); + EXPECT_EQ(bitwise_not_int64(0xFFFFFF966C8D7997), 0x0000006993728668); + EXPECT_EQ(bitwise_not_int64(0x0000000000000000), 0xFFFFFFFFFFFFFFFF); +} + +TEST(TestArithmeticOps, TestIntCastFloatDouble) { + // castINT from floats + EXPECT_EQ(castINT_float32(6.6f), 7); + EXPECT_EQ(castINT_float32(-6.6f), -7); + EXPECT_EQ(castINT_float32(-6.3f), -6); + EXPECT_EQ(castINT_float32(0.0f), 0); + EXPECT_EQ(castINT_float32(-0), 0); + + // castINT from doubles + EXPECT_EQ(castINT_float64(6.6), 7); + EXPECT_EQ(castINT_float64(-6.6), -7); + EXPECT_EQ(castINT_float64(-6.3), -6); + EXPECT_EQ(castINT_float64(0.0), 0); + EXPECT_EQ(castINT_float64(-0), 0); + EXPECT_EQ(castINT_float64(999999.99999999999999999999999), 1000000); + EXPECT_EQ(castINT_float64(-999999.99999999999999999999999), -1000000); + EXPECT_EQ(castINT_float64(INT32_MAX), 2147483647); + EXPECT_EQ(castINT_float64(-2147483647), -2147483647); +} + +TEST(TestArithmeticOps, TestBigIntCastFloatDouble) { + // castINT from floats + EXPECT_EQ(castBIGINT_float32(6.6f), 7); + EXPECT_EQ(castBIGINT_float32(-6.6f), -7); + EXPECT_EQ(castBIGINT_float32(-6.3f), -6); + 
EXPECT_EQ(castBIGINT_float32(0.0f), 0); + EXPECT_EQ(castBIGINT_float32(-0), 0); + + // castINT from doubles + EXPECT_EQ(castBIGINT_float64(6.6), 7); + EXPECT_EQ(castBIGINT_float64(-6.6), -7); + EXPECT_EQ(castBIGINT_float64(-6.3), -6); + EXPECT_EQ(castBIGINT_float64(0.0), 0); + EXPECT_EQ(castBIGINT_float64(-0), 0); + EXPECT_EQ(castBIGINT_float64(999999.99999999999999999999999), 1000000); + EXPECT_EQ(castBIGINT_float64(-999999.99999999999999999999999), -1000000); + EXPECT_EQ(castBIGINT_float64(INT32_MAX), 2147483647); + EXPECT_EQ(castBIGINT_float64(-2147483647), -2147483647); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/bitmap.cc b/src/arrow/cpp/src/gandiva/precompiled/bitmap.cc new file mode 100644 index 000000000..332f08dbe --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/bitmap.cc @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// BitMap functions + +#include "arrow/util/bit_util.h" + +extern "C" { + +#include "./types.h" + +#define BITS_TO_BYTES(x) ((x + 7) / 8) +#define BITS_TO_WORDS(x) ((x + 63) / 64) + +#define POS_TO_BYTE_INDEX(p) (p / 8) +#define POS_TO_BIT_INDEX(p) (p % 8) + +FORCE_INLINE +bool bitMapGetBit(const uint8_t* bmap, int64_t position) { + return arrow::BitUtil::GetBit(bmap, position); +} + +FORCE_INLINE +bool bitMapValidityGetBit(const uint8_t* bmap, int64_t position) { + if (bmap == nullptr) { + // if validity bitmap is null, all entries are valid. + return true; + } else { + return bitMapGetBit(bmap, position); + } +} + +FORCE_INLINE +void bitMapSetBit(uint8_t* bmap, int64_t position, bool value) { + arrow::BitUtil::SetBitTo(bmap, position, value); +} + +// Clear the bit if value = false. Does nothing if value = true. +FORCE_INLINE +void bitMapClearBitIfFalse(uint8_t* bmap, int64_t position, bool value) { + if (!value) { + arrow::BitUtil::ClearBit(bmap, position); + } +} + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled/bitmap_test.cc b/src/arrow/cpp/src/gandiva/precompiled/bitmap_test.cc new file mode 100644 index 000000000..ac3084ade --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/bitmap_test.cc @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gtest/gtest.h> +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +TEST(TestBitMap, TestSimple) { + static const int kNumBytes = 16; + uint8_t bit_map[kNumBytes]; + memset(bit_map, 0, kNumBytes); + + EXPECT_EQ(bitMapGetBit(bit_map, 100), false); + + // set 100th bit and verify + bitMapSetBit(bit_map, 100, true); + EXPECT_EQ(bitMapGetBit(bit_map, 100), true); + + // clear 100th bit and verify + bitMapSetBit(bit_map, 100, false); + EXPECT_EQ(bitMapGetBit(bit_map, 100), false); +} + +TEST(TestBitMap, TestClearIfFalse) { + static const int kNumBytes = 32; + uint8_t bit_map[kNumBytes]; + memset(bit_map, 0, kNumBytes); + + bitMapSetBit(bit_map, 24, true); + + // bit should remain unchanged. + bitMapClearBitIfFalse(bit_map, 24, true); + EXPECT_EQ(bitMapGetBit(bit_map, 24), true); + + // bit should be cleared. + bitMapClearBitIfFalse(bit_map, 24, false); + EXPECT_EQ(bitMapGetBit(bit_map, 24), false); + + // this function should have no impact if the bit is already clear. + bitMapClearBitIfFalse(bit_map, 24, true); + EXPECT_EQ(bitMapGetBit(bit_map, 24), false); + + bitMapClearBitIfFalse(bit_map, 24, false); + EXPECT_EQ(bitMapGetBit(bit_map, 24), false); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.cc b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.cc new file mode 100644 index 000000000..61cac6062 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.cc @@ -0,0 +1,723 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Algorithms adapted from Apache Impala + +#include "gandiva/precompiled/decimal_ops.h" + +#include <algorithm> +#include <cmath> +#include <limits> + +#include "arrow/util/logging.h" +#include "gandiva/decimal_type_util.h" +#include "gandiva/decimal_xlarge.h" +#include "gandiva/gdv_function_stubs.h" + +// Several operations (multiply, divide, mod, ..) require converting to 256-bit, and we +// use the boost library for doing 256-bit operations. To avoid references to boost from +// the precompiled-to-ir code (this causes issues with symbol resolution at runtime), we +// use a wrapper exported from the CPP code. The wrapper functions are named gdv_xlarge_xx + +namespace gandiva { +namespace decimalops { + +using arrow::BasicDecimal128; + +static BasicDecimal128 CheckAndIncreaseScale(const BasicDecimal128& in, int32_t delta) { + return (delta <= 0) ? in : in.IncreaseScaleBy(delta); +} + +static BasicDecimal128 CheckAndReduceScale(const BasicDecimal128& in, int32_t delta) { + return (delta <= 0) ? in : in.ReduceScaleBy(delta); +} + +/// Adjust x and y to the same scale, and add them. +static BasicDecimal128 AddFastPath(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_scale) { + auto higher_scale = std::max(x.scale(), y.scale()); + + auto x_scaled = CheckAndIncreaseScale(x.value(), higher_scale - x.scale()); + auto y_scaled = CheckAndIncreaseScale(y.value(), higher_scale - y.scale()); + return x_scaled + y_scaled; +} + +/// Add x and y, caller has ensured there can be no overflow. 
+static BasicDecimal128 AddNoOverflow(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_scale) { + auto higher_scale = std::max(x.scale(), y.scale()); + auto sum = AddFastPath(x, y, out_scale); + return CheckAndReduceScale(sum, higher_scale - out_scale); +} + +/// Both x_value and y_value must be >= 0 +static BasicDecimal128 AddLargePositive(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, + int32_t out_scale) { + DCHECK_GE(x.value(), 0); + DCHECK_GE(y.value(), 0); + + // separate out whole/fractions. + BasicDecimal128 x_left, x_right, y_left, y_right; + x.value().GetWholeAndFraction(x.scale(), &x_left, &x_right); + y.value().GetWholeAndFraction(y.scale(), &y_left, &y_right); + + // Adjust fractional parts to higher scale. + auto higher_scale = std::max(x.scale(), y.scale()); + auto x_right_scaled = CheckAndIncreaseScale(x_right, higher_scale - x.scale()); + auto y_right_scaled = CheckAndIncreaseScale(y_right, higher_scale - y.scale()); + + BasicDecimal128 right; + BasicDecimal128 carry_to_left; + auto multiplier = BasicDecimal128::GetScaleMultiplier(higher_scale); + if (x_right_scaled >= multiplier - y_right_scaled) { + right = x_right_scaled - (multiplier - y_right_scaled); + carry_to_left = 1; + } else { + right = x_right_scaled + y_right_scaled; + carry_to_left = 0; + } + right = CheckAndReduceScale(right, higher_scale - out_scale); + + auto left = x_left + y_left + carry_to_left; + return (left * BasicDecimal128::GetScaleMultiplier(out_scale)) + right; +} + +/// x_value and y_value cannot be 0, and one must be positive and the other negative. +static BasicDecimal128 AddLargeNegative(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, + int32_t out_scale) { + DCHECK_NE(x.value(), 0); + DCHECK_NE(y.value(), 0); + DCHECK((x.value() < 0 && y.value() > 0) || (x.value() > 0 && y.value() < 0)); + + // separate out whole/fractions. 
+ BasicDecimal128 x_left, x_right, y_left, y_right; + x.value().GetWholeAndFraction(x.scale(), &x_left, &x_right); + y.value().GetWholeAndFraction(y.scale(), &y_left, &y_right); + + // Adjust fractional parts to higher scale. + auto higher_scale = std::max(x.scale(), y.scale()); + x_right = CheckAndIncreaseScale(x_right, higher_scale - x.scale()); + y_right = CheckAndIncreaseScale(y_right, higher_scale - y.scale()); + + // Overflow not possible because one is +ve and the other is -ve. + auto left = x_left + y_left; + auto right = x_right + y_right; + + // If the whole and fractional parts have different signs, then we need to make the + // fractional part have the same sign as the whole part. If either left or right is + // zero, then nothing needs to be done. + if (left < 0 && right > 0) { + left += 1; + right -= BasicDecimal128::GetScaleMultiplier(higher_scale); + } else if (left > 0 && right < 0) { + left -= 1; + right += BasicDecimal128::GetScaleMultiplier(higher_scale); + } + right = CheckAndReduceScale(right, higher_scale - out_scale); + return (left * BasicDecimal128::GetScaleMultiplier(out_scale)) + right; +} + +static BasicDecimal128 AddLarge(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_scale) { + if (x.value() >= 0 && y.value() >= 0) { + // both positive or 0 + return AddLargePositive(x, y, out_scale); + } else if (x.value() <= 0 && y.value() <= 0) { + // both negative or 0 + BasicDecimalScalar128 x_neg(-x.value(), x.precision(), x.scale()); + BasicDecimalScalar128 y_neg(-y.value(), y.precision(), y.scale()); + return -AddLargePositive(x_neg, y_neg, out_scale); + } else { + // one positive and the other negative + return AddLargeNegative(x, y, out_scale); + } +} + +// Suppose we have a number that requires x bits to be represented and we scale it up by +// 10^scale_by. Let's say now y bits are required to represent it. This function returns +// the maximum possible y - x for a given 'scale_by'. 
+inline int32_t MaxBitsRequiredIncreaseAfterScaling(int32_t scale_by) { + // We rely on the following formula: + // bits_required(x * 10^y) <= bits_required(x) + floor(log2(10^y)) + 1 + // We precompute floor(log2(10^x)) + 1 for x = 0, 1, 2...75, 76 + DCHECK_GE(scale_by, 0); + DCHECK_LE(scale_by, 76); + static const int32_t floor_log2_plus_one[] = { + 0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40, 44, 47, 50, + 54, 57, 60, 64, 67, 70, 74, 77, 80, 84, 87, 90, 94, 97, 100, 103, + 107, 110, 113, 117, 120, 123, 127, 130, 133, 137, 140, 143, 147, 150, 153, 157, + 160, 163, 167, 170, 173, 177, 180, 183, 187, 190, 193, 196, 200, 203, 206, 210, + 213, 216, 220, 223, 226, 230, 233, 236, 240, 243, 246, 250, 253}; + return floor_log2_plus_one[scale_by]; +} + +// If we have a number with 'num_lz' leading zeros, and we scale it up by 10^scale_by, +// this function returns the minimum number of leading zeros the result can have. +inline int32_t MinLeadingZerosAfterScaling(int32_t num_lz, int32_t scale_by) { + DCHECK_GE(scale_by, 0); + DCHECK_LE(scale_by, 76); + int32_t result = num_lz - MaxBitsRequiredIncreaseAfterScaling(scale_by); + return result; +} + +// Returns the maximum possible number of bits required to represent num * 10^scale_by. +inline int32_t MaxBitsRequiredAfterScaling(const BasicDecimalScalar128& num, + int32_t scale_by) { + auto value = num.value(); + auto value_abs = value.Abs(); + + int32_t num_occupied = 128 - value_abs.CountLeadingBinaryZeros(); + DCHECK_GE(scale_by, 0); + DCHECK_LE(scale_by, 76); + return num_occupied + MaxBitsRequiredIncreaseAfterScaling(scale_by); +} + +// Returns the minimum number of leading zero x or y would have after one of them gets +// scaled up to match the scale of the other one. 
+inline int32_t MinLeadingZeros(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y) { + auto x_value = x.value(); + auto x_value_abs = x_value.Abs(); + + auto y_value = y.value(); + auto y_value_abs = y_value.Abs(); + + int32_t x_lz = x_value_abs.CountLeadingBinaryZeros(); + int32_t y_lz = y_value_abs.CountLeadingBinaryZeros(); + if (x.scale() < y.scale()) { + x_lz = MinLeadingZerosAfterScaling(x_lz, y.scale() - x.scale()); + } else if (x.scale() > y.scale()) { + y_lz = MinLeadingZerosAfterScaling(y_lz, x.scale() - y.scale()); + } + return std::min(x_lz, y_lz); +} + +BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, + int32_t out_precision, int32_t out_scale) { + if (out_precision < DecimalTypeUtil::kMaxPrecision) { + // fast-path add + return AddFastPath(x, y, out_scale); + } else { + int32_t min_lz = MinLeadingZeros(x, y); + if (min_lz >= 3) { + // If both numbers have at least MIN_LZ leading zeros, we can add them directly + // without the risk of overflow. + // We want the result to have at least 2 leading zeros, which ensures that it fits + // into the maximum decimal because 2^126 - 1 < 10^38 - 1. If both x and y have at + // least 3 leading zeros, then we are guaranteed that the result will have at lest 2 + // leading zeros. + return AddNoOverflow(x, y, out_scale); + } else { + // slower-version : add whole/fraction parts separately, and then, combine. + return AddLarge(x, y, out_scale); + } + } +} + +BasicDecimal128 Subtract(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, + int32_t out_precision, int32_t out_scale) { + return Add(x, {-y.value(), y.precision(), y.scale()}, out_precision, out_scale); +} + +// Multiply when the out_precision is 38, and there is no trimming of the scale i.e +// the intermediate value is the same as the final value. 
// Sets *overflow to true (and returns a default-constructed value) when |x| * |y|
// would exceed BasicDecimal128::GetMaxValue(); otherwise sets it to false and
// returns the product.
// The overflow test divides by |y|, so y must be non-zero; Multiply() handles zero
// operands before reaching this function.
static BasicDecimal128 MultiplyMaxPrecisionNoScaleDown(const BasicDecimalScalar128& x,
                                                       const BasicDecimalScalar128& y,
                                                       int32_t out_scale,
                                                       bool* overflow) {
  DCHECK_EQ(x.scale() + y.scale(), out_scale);

  BasicDecimal128 result;
  auto x_abs = BasicDecimal128::Abs(x.value());
  auto y_abs = BasicDecimal128::Abs(y.value());

  // |x| * |y| > max  <=>  |x| > max / |y|  (integer division), checked without
  // forming the possibly-overflowing product.
  if (x_abs > BasicDecimal128::GetMaxValue() / y_abs) {
    *overflow = true;
  } else {
    // We've verified that the result will fit into 128 bits.
    *overflow = false;
    result = x.value() * y.value();
  }
  return result;
}

// Multiply when the out_precision is 38, and there is trimming of the scale i.e
// the intermediate value could be larger than the final value.
//
// delta_scale = x.scale() + y.scale() - out_scale is the number of fractional digits
// dropped from the raw product. *overflow starts as false here and is only changed
// by the 256-bit helper, which receives the pointer.
static BasicDecimal128 MultiplyMaxPrecisionAndScaleDown(const BasicDecimalScalar128& x,
                                                        const BasicDecimalScalar128& y,
                                                        int32_t out_scale,
                                                        bool* overflow) {
  auto delta_scale = x.scale() + y.scale() - out_scale;
  DCHECK_GT(delta_scale, 0);

  *overflow = false;
  BasicDecimal128 result;
  auto x_abs = BasicDecimal128::Abs(x.value());
  auto y_abs = BasicDecimal128::Abs(y.value());

  // It's possible that the intermediate value does not fit in 128-bits, but the
  // final value will (after scaling down).
  bool needs_int256 = false;
  int32_t total_leading_zeros =
      x_abs.CountLeadingBinaryZeros() + y_abs.CountLeadingBinaryZeros();
  // This check is quick, but conservative. In some cases it will indicate that
  // converting to 256 bits is necessary, when it's not actually the case.
  needs_int256 = total_leading_zeros <= 128;
  if (ARROW_PREDICT_FALSE(needs_int256)) {
    int64_t result_high;
    uint64_t result_low;

    // This requires converting to 256-bit, and we use the boost library for that. To
    // avoid references to boost from the precompiled-to-ir code (this causes issues
    // with symbol resolution at runtime), we use a wrapper exported from the CPP code.
    gdv_xlarge_multiply_and_scale_down(x.value().high_bits(), x.value().low_bits(),
                                       y.value().high_bits(), y.value().low_bits(),
                                       delta_scale, &result_high, &result_low, overflow);
    result = BasicDecimal128(result_high, result_low);
  } else {
    if (ARROW_PREDICT_TRUE(delta_scale <= 38)) {
      // The largest value that result can have here is (2^64 - 1) * (2^63 - 1), which is
      // greater than BasicDecimal128::kMaxValue.
      result = x.value() * y.value();
      // Since delta_scale is greater than zero, result can now be at most
      // ((2^64 - 1) * (2^63 - 1)) / 10, which is less than BasicDecimal128::kMaxValue, so
      // there cannot be any overflow.
      result = result.ReduceScaleBy(delta_scale);
    } else {
      // We are multiplying decimal(38, 38) by decimal(38, 38). The result should be a
      // decimal(38, 37), so delta scale = 38 + 38 - 37 = 39. Since we are not in the
      // 256 bit intermediate value case and we are scaling down by 39, then we are
      // guaranteed that the result is 0 (even if we try to round). The largest possible
      // intermediate result is 38 "9"s. If we scale down by 39, the leftmost 9 is now
      // two digits to the right of the rightmost "visible" one. The reason why we have
      // to handle this case separately is because a scale multiplier with a delta_scale
      // 39 does not fit into 128 bit.
      DCHECK_EQ(delta_scale, 39);
      result = 0;
    }
  }
  return result;
}

// Multiply when the out_precision is 38.
+static BasicDecimal128 MultiplyMaxPrecision(const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, + int32_t out_scale, bool* overflow) { + auto delta_scale = x.scale() + y.scale() - out_scale; + DCHECK_GE(delta_scale, 0); + if (delta_scale == 0) { + return MultiplyMaxPrecisionNoScaleDown(x, y, out_scale, overflow); + } else { + return MultiplyMaxPrecisionAndScaleDown(x, y, out_scale, overflow); + } +} + +BasicDecimal128 Multiply(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y, + int32_t out_precision, int32_t out_scale, bool* overflow) { + BasicDecimal128 result; + *overflow = false; + if (out_precision < DecimalTypeUtil::kMaxPrecision) { + // fast-path multiply + result = x.value() * y.value(); + DCHECK_EQ(x.scale() + y.scale(), out_scale); + DCHECK_LE(BasicDecimal128::Abs(result), BasicDecimal128::GetMaxValue()); + } else if (x.value() == 0 || y.value() == 0) { + // Handle this separately to avoid divide-by-zero errors. + result = BasicDecimal128(0, 0); + } else { + result = MultiplyMaxPrecision(x, y, out_scale, overflow); + } + DCHECK(*overflow || BasicDecimal128::Abs(result) <= BasicDecimal128::GetMaxValue()); + return result; +} + +BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_precision, + int32_t out_scale, bool* overflow) { + if (y.value() == 0) { + char const* err_msg = "divide by zero error"; + gdv_fn_context_set_error_msg(context, err_msg); + return 0; + } + + // scale up to the output scale, and do an integer division. + int32_t delta_scale = out_scale + y.scale() - x.scale(); + DCHECK_GE(delta_scale, 0); + + BasicDecimal128 result; + auto num_bits_required_after_scaling = MaxBitsRequiredAfterScaling(x, delta_scale); + if (num_bits_required_after_scaling <= 127) { + // fast-path. The dividend fits in 128-bit after scaling too. + *overflow = false; + + // do the division. 
+ auto x_scaled = CheckAndIncreaseScale(x.value(), delta_scale); + BasicDecimal128 remainder; + auto status = x_scaled.Divide(y.value(), &result, &remainder); + DCHECK_EQ(status, arrow::DecimalStatus::kSuccess); + + // round-up + if (BasicDecimal128::Abs(2 * remainder) >= BasicDecimal128::Abs(y.value())) { + result += (x.value().Sign() ^ y.value().Sign()) + 1; + } + } else { + // convert to 256-bit and do the divide. + *overflow = delta_scale > 38 && num_bits_required_after_scaling > 255; + if (!*overflow) { + int64_t result_high; + uint64_t result_low; + + gdv_xlarge_scale_up_and_divide(x.value().high_bits(), x.value().low_bits(), + y.value().high_bits(), y.value().low_bits(), + delta_scale, &result_high, &result_low, overflow); + result = BasicDecimal128(result_high, result_low); + } + } + return result; +} + +BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x, + const BasicDecimalScalar128& y, int32_t out_precision, + int32_t out_scale, bool* overflow) { + if (y.value() == 0) { + char const* err_msg = "divide by zero error"; + gdv_fn_context_set_error_msg(context, err_msg); + return 0; + } + + // Adsjust x and y to the same scale (higher one), and then, do a integer mod. 
+ *overflow = false; + BasicDecimal128 result; + int32_t min_lz = MinLeadingZeros(x, y); + if (min_lz >= 2) { + auto higher_scale = std::max(x.scale(), y.scale()); + auto x_scaled = CheckAndIncreaseScale(x.value(), higher_scale - x.scale()); + auto y_scaled = CheckAndIncreaseScale(y.value(), higher_scale - y.scale()); + result = x_scaled % y_scaled; + DCHECK_LE(BasicDecimal128::Abs(result), BasicDecimal128::GetMaxValue()); + } else { + int64_t result_high; + uint64_t result_low; + + gdv_xlarge_mod(x.value().high_bits(), x.value().low_bits(), x.scale(), + y.value().high_bits(), y.value().low_bits(), y.scale(), &result_high, + &result_low); + result = BasicDecimal128(result_high, result_low); + } + DCHECK(BasicDecimal128::Abs(result) <= BasicDecimal128::Abs(x.value()) || + BasicDecimal128::Abs(result) <= BasicDecimal128::Abs(y.value())); + return result; +} + +int32_t CompareSameScale(const BasicDecimal128& x, const BasicDecimal128& y) { + if (x == y) { + return 0; + } else if (x < y) { + return -1; + } else { + return 1; + } +} + +int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y) { + int32_t delta_scale = x.scale() - y.scale(); + + // fast-path : both are of the same scale. + if (delta_scale == 0) { + return CompareSameScale(x.value(), y.value()); + } + + // Check if we'll need more than 256-bits after adjusting the scale. 
+ bool need256 = + (delta_scale < 0 && x.precision() - delta_scale > DecimalTypeUtil::kMaxPrecision) || + (y.precision() + delta_scale > DecimalTypeUtil::kMaxPrecision); + if (need256) { + return gdv_xlarge_compare(x.value().high_bits(), x.value().low_bits(), x.scale(), + y.value().high_bits(), y.value().low_bits(), y.scale()); + } else { + BasicDecimal128 x_scaled; + BasicDecimal128 y_scaled; + + if (delta_scale < 0) { + x_scaled = x.value().IncreaseScaleBy(-delta_scale); + y_scaled = y.value(); + } else { + x_scaled = x.value(); + y_scaled = y.value().IncreaseScaleBy(delta_scale); + } + return CompareSameScale(x_scaled, y_scaled); + } +} + +#define DECIMAL_OVERFLOW_IF(condition, overflow) \ + do { \ + if (*overflow || (condition)) { \ + *overflow = true; \ + return 0; \ + } \ + } while (0) + +static BasicDecimal128 GetMaxValue(int32_t precision) { + return BasicDecimal128::GetScaleMultiplier(precision) - 1; +} + +// Compute the double scale multipliers once. +static std::array<double, DecimalTypeUtil::kMaxPrecision + 1> kDoubleScaleMultipliers = + ([]() -> std::array<double, DecimalTypeUtil::kMaxPrecision + 1> { + std::array<double, DecimalTypeUtil::kMaxPrecision + 1> values; + values[0] = 1.0; + for (int32_t idx = 1; idx <= DecimalTypeUtil::kMaxPrecision; idx++) { + values[idx] = values[idx - 1] * 10; + } + return values; + })(); + +BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow) { + // Multiply decimal with the scale + auto unscaled = in * kDoubleScaleMultipliers[scale]; + DECIMAL_OVERFLOW_IF(std::isnan(unscaled), overflow); + + unscaled = std::round(unscaled); + + // convert scaled double to int128 + int32_t sign = unscaled < 0 ? 
-1 : 1; + auto unscaled_abs = std::abs(unscaled); + + // overflow if > 2^127 - 1 + DECIMAL_OVERFLOW_IF(unscaled_abs > std::ldexp(static_cast<double>(1), 127) - 1, + overflow); + + uint64_t high_bits = static_cast<uint64_t>(std::ldexp(unscaled_abs, -64)); + uint64_t low_bits = static_cast<uint64_t>( + unscaled_abs - std::ldexp(static_cast<double>(high_bits), 64)); + + auto result = BasicDecimal128(static_cast<int64_t>(high_bits), low_bits); + + // overflow if > max value based on precision + DECIMAL_OVERFLOW_IF(result > GetMaxValue(precision), overflow); + return result * sign; +} + +double ToDouble(const BasicDecimalScalar128& in, bool* overflow) { + // convert int128 to double + int64_t sign = in.value().Sign(); + auto value_abs = BasicDecimal128::Abs(in.value()); + double unscaled = static_cast<double>(value_abs.low_bits()) + + std::ldexp(static_cast<double>(value_abs.high_bits()), 64); + + // scale double. + return (unscaled * sign) / kDoubleScaleMultipliers[in.scale()]; +} + +BasicDecimal128 FromInt64(int64_t in, int32_t precision, int32_t scale, bool* overflow) { + // check if multiplying by scale will cause an overflow. + DECIMAL_OVERFLOW_IF(std::abs(in) > GetMaxValue(precision - scale), overflow); + return in * BasicDecimal128::GetScaleMultiplier(scale); +} + +// Helper function to modify the scale and/or precision of a decimal value. +static BasicDecimal128 ModifyScaleAndPrecision(const BasicDecimalScalar128& x, + int32_t out_precision, int32_t out_scale, + bool* overflow) { + int32_t delta_scale = out_scale - x.scale(); + if (delta_scale >= 0) { + // check if multiplying by delta_scale will cause an overflow. + DECIMAL_OVERFLOW_IF( + BasicDecimal128::Abs(x.value()) > GetMaxValue(out_precision - delta_scale), + overflow); + return x.value().IncreaseScaleBy(delta_scale); + } else { + // Do not do any rounding, that is handled by the caller. 
+ auto result = x.value().ReduceScaleBy(-delta_scale, false); + DECIMAL_OVERFLOW_IF(BasicDecimal128::Abs(result) > GetMaxValue(out_precision), + overflow); + return result; + } +} + +enum RoundType { + kRoundTypeCeil, // +1 if +ve and trailing value is > 0, else no rounding. + kRoundTypeFloor, // -1 if -ve and trailing value is < 0, else no rounding. + kRoundTypeTrunc, // no rounding, truncate the trailing digits. + kRoundTypeHalfRoundUp, // if +ve and trailing value is >= half of base, +1. + // else if -ve and trailing value is >= half of base, -1. +}; + +// Compute the rounding delta for the givven rounding type. +static int32_t ComputeRoundingDelta(const BasicDecimal128& x, int32_t x_scale, + int32_t out_scale, RoundType type) { + if (type == kRoundTypeTrunc || // no rounding for this type. + out_scale >= x_scale) { // no digits dropped, so no rounding. + return 0; + } + + int32_t result = 0; + switch (type) { + case kRoundTypeHalfRoundUp: { + auto base = BasicDecimal128::GetScaleMultiplier(x_scale - out_scale); + auto trailing = x % base; + if (trailing == 0) { + result = 0; + } else if (trailing.Abs() < base / 2) { + result = 0; + } else { + result = (x < 0) ? -1 : 1; + } + break; + } + + case kRoundTypeCeil: + if (x < 0) { + // no rounding for -ve + result = 0; + } else { + auto base = BasicDecimal128::GetScaleMultiplier(x_scale - out_scale); + auto trailing = x % base; + result = (trailing == 0) ? 0 : 1; + } + break; + + case kRoundTypeFloor: + if (x > 0) { + // no rounding for +ve + result = 0; + } else { + auto base = BasicDecimal128::GetScaleMultiplier(x_scale - out_scale); + auto trailing = x % base; + result = (trailing == 0) ? 0 : -1; + } + break; + + case kRoundTypeTrunc: + break; + } + return result; +} + +// Modify the scale and round. 
+static BasicDecimal128 RoundWithPositiveScale(const BasicDecimalScalar128& x, + int32_t out_precision, int32_t out_scale, + RoundType round_type, bool* overflow) { + DCHECK_GE(out_scale, 0); + + auto scaled = ModifyScaleAndPrecision(x, out_precision, out_scale, overflow); + if (*overflow) { + return 0; + } + + auto delta = ComputeRoundingDelta(x.value(), x.scale(), out_scale, round_type); + if (delta == 0) { + return scaled; + } + + // If there is a rounding delta, the output scale must be less than the input scale. + // That means at least one digit is dropped after the decimal. The delta add can add + // utmost one digit before the decimal. So, overflow will occur only if the output + // precision has changed. + DCHECK_GT(x.scale(), out_scale); + auto result = scaled + delta; + DECIMAL_OVERFLOW_IF(out_precision < x.precision() && + BasicDecimal128::Abs(result) > GetMaxValue(out_precision), + overflow); + return result; +} + +// Modify scale to drop all digits to the right of the decimal and round. +// Then, zero out 'rounding_scale' number of digits to the left of the decimal point. +static BasicDecimal128 RoundWithNegativeScale(const BasicDecimalScalar128& x, + int32_t out_precision, + int32_t rounding_scale, + RoundType round_type, bool* overflow) { + DCHECK_LT(rounding_scale, 0); + + // get rid of the fractional part. 
+ auto scaled = ModifyScaleAndPrecision(x, out_precision, 0, overflow); + auto rounding_delta = ComputeRoundingDelta(scaled, 0, -rounding_scale, round_type); + + auto base = BasicDecimal128::GetScaleMultiplier(-rounding_scale); + auto delta = rounding_delta * base - (scaled % base); + DECIMAL_OVERFLOW_IF(BasicDecimal128::Abs(scaled) > + GetMaxValue(out_precision) - BasicDecimal128::Abs(delta), + overflow); + return scaled + delta; +} + +BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_precision, + int32_t out_scale, int32_t rounding_scale, bool* overflow) { + // no-op if target scale is same as arg scale + if (x.scale() == out_scale && rounding_scale >= 0) { + return x.value(); + } + + if (rounding_scale < 0) { + return RoundWithNegativeScale(x, out_precision, rounding_scale, + RoundType::kRoundTypeHalfRoundUp, overflow); + } else { + return RoundWithPositiveScale(x, out_precision, rounding_scale, + RoundType::kRoundTypeHalfRoundUp, overflow); + } +} + +BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_precision, + int32_t out_scale, int32_t rounding_scale, bool* overflow) { + // no-op if target scale is same as arg scale + if (x.scale() == out_scale && rounding_scale >= 0) { + return x.value(); + } + + if (rounding_scale < 0) { + return RoundWithNegativeScale(x, out_precision, rounding_scale, + RoundType::kRoundTypeTrunc, overflow); + } else { + return RoundWithPositiveScale(x, out_precision, rounding_scale, + RoundType::kRoundTypeTrunc, overflow); + } +} + +BasicDecimal128 Ceil(const BasicDecimalScalar128& x, bool* overflow) { + return RoundWithPositiveScale(x, x.precision(), 0, RoundType::kRoundTypeCeil, overflow); +} + +BasicDecimal128 Floor(const BasicDecimalScalar128& x, bool* overflow) { + return RoundWithPositiveScale(x, x.precision(), 0, RoundType::kRoundTypeFloor, + overflow); +} + +BasicDecimal128 Convert(const BasicDecimalScalar128& x, int32_t out_precision, + int32_t out_scale, bool* overflow) { + 
DCHECK_GE(out_scale, 0); + DCHECK_LE(out_scale, DecimalTypeUtil::kMaxScale); + DCHECK_GT(out_precision, 0); + DCHECK_LE(out_precision, DecimalTypeUtil::kMaxScale); + + return RoundWithPositiveScale(x, out_precision, out_scale, + RoundType::kRoundTypeHalfRoundUp, overflow); +} + +int64_t ToInt64(const BasicDecimalScalar128& in, bool* overflow) { + auto rounded = RoundWithPositiveScale(in, in.precision(), 0 /*scale*/, + RoundType::kRoundTypeHalfRoundUp, overflow); + DECIMAL_OVERFLOW_IF((rounded > std::numeric_limits<int64_t>::max()) || + (rounded < std::numeric_limits<int64_t>::min()), + overflow); + return static_cast<int64_t>(rounded.low_bits()); +} + +} // namespace decimalops +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.h b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.h new file mode 100644 index 000000000..292dce220 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops.h @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> +#include <string> +#include "gandiva/basic_decimal_scalar.h" + +namespace gandiva { +namespace decimalops { + +/// Return the sum of 'x' and 'y'. 
/// out_precision and out_scale are passed along for efficiency, they must match
/// the rules in DecimalTypeSql::GetResultType.
arrow::BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
                           int32_t out_precision, int32_t out_scale);

/// Subtract 'y' from 'x', and return the result.
arrow::BasicDecimal128 Subtract(const BasicDecimalScalar128& x,
                                const BasicDecimalScalar128& y, int32_t out_precision,
                                int32_t out_scale);

/// Multiply 'x' by 'y', and return the result.
/// '*overflow' is set to true if the product does not fit, false otherwise.
arrow::BasicDecimal128 Multiply(const BasicDecimalScalar128& x,
                                const BasicDecimalScalar128& y, int32_t out_precision,
                                int32_t out_scale, bool* overflow);

/// Divide 'x' by 'y', and return the result.
/// If 'y' is zero, an error message is set on 'context' and 0 is returned.
arrow::BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
                              const BasicDecimalScalar128& y, int32_t out_precision,
                              int32_t out_scale, bool* overflow);

/// Divide 'x' by 'y', and return the remainder.
/// If 'y' is zero, an error message is set on 'context' and 0 is returned.
arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
                           const BasicDecimalScalar128& y, int32_t out_precision,
                           int32_t out_scale, bool* overflow);

/// Compare two decimals. Returns :
///  0 if x == y
///  1 if x > y
/// -1 if x < y
int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y);

/// Convert to decimal from double.
BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow);

/// Convert from decimal to double.
double ToDouble(const BasicDecimalScalar128& in, bool* overflow);

/// Convert to decimal from gdv_int64.
BasicDecimal128 FromInt64(int64_t in, int32_t precision, int32_t scale, bool* overflow);

/// Convert from decimal to gdv_int64
int64_t ToInt64(const BasicDecimalScalar128& in, bool* overflow);

/// Convert from one decimal scale/precision to another.
BasicDecimal128 Convert(const BasicDecimalScalar128& x, int32_t out_precision,
                        int32_t out_scale, bool* overflow);

/// round decimal.
BasicDecimal128 Round(const BasicDecimalScalar128& x, int32_t out_precision,
                      int32_t out_scale, int32_t rounding_scale, bool* overflow);

/// truncate decimal.
BasicDecimal128 Truncate(const BasicDecimalScalar128& x, int32_t out_precision,
                         int32_t out_scale, int32_t rounding_scale, bool* overflow);

/// ceil decimal
BasicDecimal128 Ceil(const BasicDecimalScalar128& x, bool* overflow);

/// floor decimal
BasicDecimal128 Floor(const BasicDecimalScalar128& x, bool* overflow);

}  // namespace decimalops
}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/decimal_ops_test.cc b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops_test.cc new file mode 100644 index 000000000..be8a1fe8a --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/decimal_ops_test.cc @@ -0,0 +1,1095 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include <gtest/gtest.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <tuple>
#include <vector>

#include "arrow/testing/gtest_util.h"
#include "gandiva/decimal_scalar.h"
#include "gandiva/decimal_type_util.h"
#include "gandiva/execution_context.h"
#include "gandiva/precompiled/decimal_ops.h"
#include "gandiva/precompiled/types.h"

namespace gandiva {

// Decimal values made of 35, 36 and 38 nines, used as large operands in the tests.
const arrow::Decimal128 kThirtyFive9s(std::string(35, '9'));
const arrow::Decimal128 kThirtySix9s(std::string(36, '9'));
const arrow::Decimal128 kThirtyEight9s(std::string(38, '9'));

// Test fixture with helpers that run one decimal operation and compare both the
// result value and the overflow flag against expectations.
class TestDecimalSql : public ::testing::Test {
 protected:
  // Runs 'op' on (x, y) and checks the outcome. The output type is derived from the
  // operand types with DecimalTypeUtil::GetResultType.
  static void Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
                     const DecimalScalar128& y, const DecimalScalar128& expected_result,
                     bool expected_overflow);

  // Runs Verify for all four sign combinations of (left, right), adjusting the sign
  // of the expected output accordingly. Supports multiply, divide and mod only.
  static void VerifyAllSign(DecimalTypeUtil::Op op, const DecimalScalar128& left,
                            const DecimalScalar128& right,
                            const DecimalScalar128& expected_output,
                            bool expected_overflow);

  void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
                    const DecimalScalar128& expected_result) {
    // TODO: overflow checks
    return Verify(DecimalTypeUtil::kOpAdd, x, y, expected_result, false);
  }

  void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
                         const DecimalScalar128& expected_result) {
    // TODO: overflow checks
    return Verify(DecimalTypeUtil::kOpSubtract, x, y, expected_result, false);
  }

  void MultiplyAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
                         const DecimalScalar128& expected_result,
                         bool expected_overflow) {
    return Verify(DecimalTypeUtil::kOpMultiply, x, y, expected_result, expected_overflow);
  }

  void MultiplyAndVerifyAllSign(const DecimalScalar128& x, const DecimalScalar128& y,
                                const DecimalScalar128& expected_result,
                                bool expected_overflow) {
    return VerifyAllSign(DecimalTypeUtil::kOpMultiply, x, y, expected_result,
                         expected_overflow);
  }

  void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
                       const DecimalScalar128& expected_result, bool expected_overflow) {
    return Verify(DecimalTypeUtil::kOpDivide, x, y, expected_result, expected_overflow);
  }

  void DivideAndVerifyAllSign(const DecimalScalar128& x, const DecimalScalar128& y,
                              const DecimalScalar128& expected_result,
                              bool expected_overflow) {
    return VerifyAllSign(DecimalTypeUtil::kOpDivide, x, y, expected_result,
                         expected_overflow);
  }

  void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
                    const DecimalScalar128& expected_result, bool expected_overflow) {
    return Verify(DecimalTypeUtil::kOpMod, x, y, expected_result, expected_overflow);
  }

  void ModAndVerifyAllSign(const DecimalScalar128& x, const DecimalScalar128& y,
                           const DecimalScalar128& expected_result,
                           bool expected_overflow) {
    return VerifyAllSign(DecimalTypeUtil::kOpMod, x, y, expected_result,
                         expected_overflow);
  }
};

// Checks that the overflow flag matches the expectation; the result value is only
// compared when no overflow is expected. Failure messages include the operation
// name and both operands.
#define EXPECT_DECIMAL_EQ(op, x, y, expected_result, expected_overflow, actual_result, \
                          actual_overflow)                                             \
  {                                                                                    \
    EXPECT_TRUE(expected_overflow == actual_overflow)                                  \
        << op << "(" << (x).ToString() << " and " << (y).ToString() << ")"             \
        << " expected overflow : " << expected_overflow                                \
        << " actual overflow : " << actual_overflow;                                   \
    if (!expected_overflow) {                                                          \
      EXPECT_TRUE(expected_result == actual_result)                                    \
          << op << "(" << (x).ToString() << " and " << (y).ToString() << ")"           \
          << " expected : " << expected_result.ToString()                              \
          << " actual : " << actual_result.ToString();                                 \
    }                                                                                  \
  }

void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
                            const DecimalScalar128& y,
                            const DecimalScalar128& expected_result,
                            bool expected_overflow) {
  auto t1 = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
  auto t2 = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
  // Add/Subtract take no overflow out-parameter, so for them this stays false.
  bool overflow = false;
  // Passed to Divide/Mod as the execution context handle.
  int64_t context = 0;

  // Derive the output precision/scale for this operation from the operand types.
  Decimal128TypePtr out_type;
  ARROW_EXPECT_OK(DecimalTypeUtil::GetResultType(op, {t1, t2}, &out_type));

  arrow::BasicDecimal128 out_value;
  std::string op_name;
  switch (op) {
    case DecimalTypeUtil::kOpAdd:
      op_name = "add";
      out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale());
      break;

    case DecimalTypeUtil::kOpSubtract:
      op_name = "subtract";
      out_value = decimalops::Subtract(x, y, out_type->precision(), out_type->scale());
      break;

    case DecimalTypeUtil::kOpMultiply:
      op_name = "multiply";
      out_value =
          decimalops::Multiply(x, y, out_type->precision(), out_type->scale(), &overflow);
      break;

    case DecimalTypeUtil::kOpDivide:
      op_name = "divide";
      out_value = decimalops::Divide(context, x, y, out_type->precision(),
                                     out_type->scale(), &overflow);
      break;

    case DecimalTypeUtil::kOpMod:
      op_name = "mod";
      out_value = decimalops::Mod(context, x, y, out_type->precision(), out_type->scale(),
                                  &overflow);
      break;

    default:
      // not implemented.
      ASSERT_FALSE(true);
  }
  EXPECT_DECIMAL_EQ(op_name, x, y, expected_result, expected_overflow,
                    DecimalScalar128(out_value, out_type->precision(), out_type->scale()),
                    overflow);
}

void TestDecimalSql::VerifyAllSign(DecimalTypeUtil::Op op, const DecimalScalar128& left,
                                   const DecimalScalar128& right,
                                   const DecimalScalar128& expected_output,
                                   bool expected_overflow) {
  // both +ve
  Verify(op, left, right, expected_output, expected_overflow);

  // left -ve
  Verify(op, -left, right, -expected_output, expected_overflow);

  if (op == DecimalTypeUtil::kOpMod) {
    // For mod, the expected sign follows the left operand only.
    // right -ve
    Verify(op, left, -right, expected_output, expected_overflow);

    // both -ve
    Verify(op, -left, -right, -expected_output, expected_overflow);
  } else {
    // For multiply/divide, the expected result is negative iff exactly one operand
    // is negative.
    ASSERT_TRUE(op == DecimalTypeUtil::kOpMultiply || op == DecimalTypeUtil::kOpDivide);

    // right -ve
    Verify(op, left, -right, -expected_output, expected_overflow);

    // both -ve
    Verify(op, -left, -right, expected_output, expected_overflow);
  }
}

+TEST_F(TestDecimalSql, Add) { + // fast-path + AddAndVerify(DecimalScalar128{"201", 30, 3}, // x + DecimalScalar128{"301", 30, 3}, // y + DecimalScalar128{"502", 31, 3}); // expected + + // max precision + AddAndVerify(DecimalScalar128{"09999999999999999999999999999999000000", 38, 5}, // x + DecimalScalar128{"100", 38, 7}, // y + DecimalScalar128{"99999999999999999999999999999990000010", 38, 6}); + + // Both -ve + AddAndVerify(DecimalScalar128{"-201", 30, 3}, // x + DecimalScalar128{"-301", 30, 2}, // y + DecimalScalar128{"-3211", 32, 3}); // expected + + // -ve and max precision + AddAndVerify(DecimalScalar128{"-09999999999999999999999999999999000000", 38, 5}, // x + DecimalScalar128{"-100", 38, 7}, // y + DecimalScalar128{"-99999999999999999999999999999990000010", 38, 6}); +} + +TEST_F(TestDecimalSql, Subtract) { + // fast-path + SubtractAndVerify(DecimalScalar128{"201", 30, 3}, // x + DecimalScalar128{"301", 30, 3}, // y + DecimalScalar128{"-100", 31, 3}); // expected + + // max precision + SubtractAndVerify( + DecimalScalar128{"09999999999999999999999999999999000000", 38, 5}, // x + DecimalScalar128{"100", 38, 7}, // y + DecimalScalar128{"99999999999999999999999999999989999990", 38, 6}); + + // Both -ve + SubtractAndVerify(DecimalScalar128{"-201", 30, 3}, // x + DecimalScalar128{"-301", 30, 2}, // y + DecimalScalar128{"2809", 32, 3}); // expected + + // -ve and max precision + SubtractAndVerify( + DecimalScalar128{"-09999999999999999999999999999999000000", 38, 5}, // x + DecimalScalar128{"-100", 38, 7}, // y + DecimalScalar128{"-99999999999999999999999999999989999990", 38, 6}); +} + +TEST_F(TestDecimalSql, Multiply) { + // fast-path : out_precision < 38 + MultiplyAndVerifyAllSign(DecimalScalar128{"201", 10, 3}, // x + DecimalScalar128{"301", 10, 2}, // y + DecimalScalar128{"60501", 21, 5}, // expected + false); // overflow + + // right 0 + MultiplyAndVerify(DecimalScalar128{"201", 20, 3}, // x + DecimalScalar128{"0", 20, 2}, // y + DecimalScalar128{"0", 38, 
5}, // expected + false); // overflow + + // left 0 + MultiplyAndVerify(DecimalScalar128{"0", 20, 3}, // x + DecimalScalar128{"301", 20, 2}, // y + DecimalScalar128{"0", 38, 5}, // expected + false); // overflow + + // out_precision == 38, small input values, no trimming of scale (scale <= 6 doesn't + // get trimmed). + MultiplyAndVerify(DecimalScalar128{"201", 20, 3}, // x + DecimalScalar128{"301", 20, 2}, // y + DecimalScalar128{"60501", 38, 5}, // expected + false); // overflow + + // out_precision == 38, large values, no trimming of scale (scale <= 6 doesn't + // get trimmed). + MultiplyAndVerifyAllSign( + DecimalScalar128{"201", 20, 3}, // x + DecimalScalar128{kThirtyFive9s, 35, 2}, // y + DecimalScalar128{"20099999999999999999999999999999999799", 38, 5}, // expected + false); // overflow + + // out_precision == 38, very large values, no trimming of scale (scale <= 6 doesn't + // get trimmed). overflow expected. + MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x + DecimalScalar128{kThirtySix9s, 35, 2}, // y + DecimalScalar128{"0", 38, 5}, // expected + true); // overflow + + MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x + DecimalScalar128{kThirtyEight9s, 35, 2}, // y + DecimalScalar128{"0", 38, 5}, // expected + true); // overflow + + // out_precision == 38, small input values, trimming of scale. + MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 5}, // x + DecimalScalar128{"301", 20, 5}, // y + DecimalScalar128{"61", 38, 7}, // expected + false); // overflow + + // out_precision == 38, large values, trimming of scale. + MultiplyAndVerifyAllSign( + DecimalScalar128{"201", 20, 5}, // x + DecimalScalar128{kThirtyFive9s, 35, 5}, // y + DecimalScalar128{"2010000000000000000000000000000000", 38, 6}, // expected + false); // overflow + + // out_precision == 38, very large values, trimming of scale (requires convert to 256). 
+ MultiplyAndVerifyAllSign( + DecimalScalar128{kThirtyFive9s, 38, 20}, // x + DecimalScalar128{kThirtySix9s, 38, 20}, // y + DecimalScalar128{"9999999999999999999999999999999999890", 38, 6}, // expected + false); // overflow + + // out_precision == 38, very large values, trimming of scale (requires convert to 256). + // should cause overflow. + MultiplyAndVerifyAllSign(DecimalScalar128{kThirtyFive9s, 38, 4}, // x + DecimalScalar128{kThirtySix9s, 38, 4}, // y + DecimalScalar128{"0", 38, 6}, // expected + true); // overflow + + // corner cases. + MultiplyAndVerifyAllSign( + DecimalScalar128{0, UINT64_MAX, 38, 4}, // x + DecimalScalar128{0, UINT64_MAX, 38, 4}, // y + DecimalScalar128{"3402823669209384634264811192843491082", 38, 6}, // expected + false); // overflow + + MultiplyAndVerifyAllSign( + DecimalScalar128{0, UINT64_MAX, 38, 4}, // x + DecimalScalar128{0, INT64_MAX, 38, 4}, // y + DecimalScalar128{"1701411834604692317040171876053197783", 38, 6}, // expected + false); // overflow + + MultiplyAndVerifyAllSign(DecimalScalar128{"201", 38, 38}, // x + DecimalScalar128{"301", 38, 38}, // y + DecimalScalar128{"0", 38, 37}, // expected + false); // overflow + + MultiplyAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 38}, // x + DecimalScalar128{0, UINT64_MAX, 38, 38}, // y + DecimalScalar128{"0", 38, 37}, // expected + false); // overflow + + MultiplyAndVerifyAllSign( + DecimalScalar128{kThirtyFive9s, 38, 38}, // x + DecimalScalar128{kThirtySix9s, 38, 38}, // y + DecimalScalar128{"100000000000000000000000000000000", 38, 37}, // expected + false); // overflow +} + +TEST_F(TestDecimalSql, Divide) { + DivideAndVerifyAllSign(DecimalScalar128{"201", 10, 3}, // x + DecimalScalar128{"301", 10, 2}, // y + DecimalScalar128{"6677740863787", 23, 14}, // expected + false); // overflow + + DivideAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x + DecimalScalar128{"301", 20, 2}, // y + DecimalScalar128{"667774086378737542", 38, 19}, // expected + false); // overflow + + 
DivideAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x + DecimalScalar128{kThirtyFive9s, 35, 2}, // y + DecimalScalar128{"0", 38, 19}, // expected + false); // overflow + + DivideAndVerifyAllSign( + DecimalScalar128{kThirtyFive9s, 35, 6}, // x + DecimalScalar128{"201", 20, 3}, // y + DecimalScalar128{"497512437810945273631840796019900493", 38, 6}, // expected + false); // overflow + + DivideAndVerifyAllSign(DecimalScalar128{kThirtyEight9s, 38, 20}, // x + DecimalScalar128{kThirtyFive9s, 38, 20}, // y + DecimalScalar128{"1000000000", 38, 6}, // expected + false); // overflow + + DivideAndVerifyAllSign(DecimalScalar128{"31939128063561476055", 38, 8}, // x + DecimalScalar128{"10000", 20, 0}, // y + DecimalScalar128{"3193912806356148", 38, 8}, // expected + false); + + // Corner cases + DivideAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x + DecimalScalar128{0, UINT64_MAX, 38, 4}, // y + DecimalScalar128{"1000000", 38, 6}, // expected + false); // overflow + + DivideAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x + DecimalScalar128{0, INT64_MAX, 38, 4}, // y + DecimalScalar128{"2000000", 38, 6}, // expected + false); // overflow + + DivideAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 19, 5}, // x + DecimalScalar128{0, INT64_MAX, 19, 5}, // y + DecimalScalar128{"20000000000000000001", 38, 19}, // expected + false); // overflow + + DivideAndVerifyAllSign(DecimalScalar128{kThirtyFive9s, 38, 37}, // x + DecimalScalar128{kThirtyFive9s, 38, 38}, // y + DecimalScalar128{"10000000", 38, 6}, // expected + false); // overflow + + // overflow + DivideAndVerifyAllSign(DecimalScalar128{kThirtyEight9s, 38, 6}, // x + DecimalScalar128{"201", 20, 3}, // y + DecimalScalar128{"0", 38, 6}, // expected + true); +} + +TEST_F(TestDecimalSql, Mod) { + ModAndVerifyAllSign(DecimalScalar128{"201", 10, 3}, // x + DecimalScalar128{"301", 10, 2}, // y + DecimalScalar128{"201", 10, 3}, // expected + false); // overflow + + ModAndVerify(DecimalScalar128{"201", 20, 
2}, // x + DecimalScalar128{"301", 20, 3}, // y + DecimalScalar128{"204", 20, 3}, // expected + false); // overflow + + ModAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x + DecimalScalar128{kThirtyFive9s, 35, 2}, // y + DecimalScalar128{"201", 20, 3}, // expected + false); // overflow + + ModAndVerifyAllSign(DecimalScalar128{kThirtyFive9s, 35, 6}, // x + DecimalScalar128{"201", 20, 3}, // y + DecimalScalar128{"180999", 23, 6}, // expected + false); // overflow + + ModAndVerifyAllSign(DecimalScalar128{kThirtyEight9s, 38, 20}, // x + DecimalScalar128{kThirtyFive9s, 38, 21}, // y + DecimalScalar128{"9990", 38, 21}, // expected + false); // overflow + + ModAndVerifyAllSign(DecimalScalar128{"31939128063561476055", 38, 8}, // x + DecimalScalar128{"10000", 20, 0}, // y + DecimalScalar128{"63561476055", 28, 8}, // expected + false); + + ModAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x + DecimalScalar128{0, UINT64_MAX, 38, 4}, // y + DecimalScalar128{"0", 38, 4}, // expected + false); // overflow + + ModAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x + DecimalScalar128{0, INT64_MAX, 38, 4}, // y + DecimalScalar128{"1", 38, 4}, // expected + false); // overflow +} + +TEST_F(TestDecimalSql, DivideByZero) { + gandiva::ExecutionContext context; + int32_t result_precision; + int32_t result_scale; + bool overflow; + + // divide-by-zero should cause an error. + context.Reset(); + result_precision = 38; + result_scale = 19; + decimalops::Divide(reinterpret_cast<gdv_int64>(&context), + DecimalScalar128{"201", 20, 3}, DecimalScalar128{"0", 20, 2}, + result_precision, result_scale, &overflow); + EXPECT_TRUE(context.has_error()); + EXPECT_EQ(context.get_error(), "divide by zero error"); + + // divide-by-nonzero should not cause an error. 
+ context.Reset(); + decimalops::Divide(reinterpret_cast<gdv_int64>(&context), + DecimalScalar128{"201", 20, 3}, DecimalScalar128{"1", 20, 2}, + result_precision, result_scale, &overflow); + EXPECT_FALSE(context.has_error()); + + // mod-by-zero should cause an error. + context.Reset(); + result_precision = 20; + result_scale = 3; + decimalops::Mod(reinterpret_cast<gdv_int64>(&context), DecimalScalar128{"201", 20, 3}, + DecimalScalar128{"0", 20, 2}, result_precision, result_scale, + &overflow); + EXPECT_TRUE(context.has_error()); + EXPECT_EQ(context.get_error(), "divide by zero error"); + + // mod-by-nonzero should not cause an error. + context.Reset(); + decimalops::Mod(reinterpret_cast<gdv_int64>(&context), DecimalScalar128{"201", 20, 3}, + DecimalScalar128{"1", 20, 2}, result_precision, result_scale, + &overflow); + EXPECT_FALSE(context.has_error()); +} + +TEST_F(TestDecimalSql, Compare) { + // x.scale == y.scale + EXPECT_EQ( + 0, decimalops::Compare(DecimalScalar128{100, 38, 6}, DecimalScalar128{100, 38, 6})); + EXPECT_EQ( + 1, decimalops::Compare(DecimalScalar128{200, 38, 6}, DecimalScalar128{100, 38, 6})); + EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{100, 38, 6}, + DecimalScalar128{200, 38, 6})); + + // x.scale == y.scale, with -ve. 
+ EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-100, 38, 6}, + DecimalScalar128{-100, 38, 6})); + EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-200, 38, 6}, + DecimalScalar128{-100, 38, 6})); + EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-100, 38, 6}, + DecimalScalar128{-200, 38, 6})); + EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{100, 38, 6}, + DecimalScalar128{-200, 38, 6})); + + for (int32_t precision : {16, 36, 38}) { + // x_scale > y_scale + EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{10000, precision, 6}, + DecimalScalar128{100, precision, 4})); + EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{20000, precision, 6}, + DecimalScalar128{100, precision, 4})); + EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{10000, precision, 6}, + DecimalScalar128{200, precision, 4})); + + // x.scale > y.scale, with -ve + EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-10000, precision, 6}, + DecimalScalar128{-100, precision, 4})); + EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-20000, precision, 6}, + DecimalScalar128{-100, precision, 4})); + EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{-10000, precision, 6}, + DecimalScalar128{-200, precision, 4})); + EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{10000, precision, 6}, + DecimalScalar128{-200, precision, 4})); + + // x.scale < y.scale + EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{100, precision, 4}, + DecimalScalar128{10000, precision, 6})); + EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{200, precision, 4}, + DecimalScalar128{10000, precision, 6})); + EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{100, precision, 4}, + DecimalScalar128{20000, precision, 6})); + + // x.scale < y.scale, with -ve + EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{-100, precision, 4}, + DecimalScalar128{-10000, precision, 6})); + EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{-200, precision, 4}, + DecimalScalar128{-10000, precision, 6})); + EXPECT_EQ(1, 
decimalops::Compare(DecimalScalar128{-100, precision, 4}, + DecimalScalar128{-20000, precision, 6})); + EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{100, precision, 4}, + DecimalScalar128{-200, precision, 6})); + } + + // large cases. + EXPECT_EQ(0, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6}, + DecimalScalar128{kThirtyEight9s, 38, 6})); + + EXPECT_EQ(1, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6}, + DecimalScalar128{kThirtySix9s, 38, 4})); + + EXPECT_EQ(-1, decimalops::Compare(DecimalScalar128{kThirtyEight9s, 38, 6}, + DecimalScalar128{kThirtyEight9s, 38, 4})); +} + +TEST_F(TestDecimalSql, Round) { + // expected, input, rounding_scale, overflow + using TupleType = std::tuple<DecimalScalar128, DecimalScalar128, int32_t, bool>; + std::vector<TupleType> test_values = { + // examples from + // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_round + std::make_tuple(DecimalScalar128{-1, 36, 0}, DecimalScalar128{-123, 38, 2}, 0, + false), + std::make_tuple(DecimalScalar128{-2, 36, 0}, DecimalScalar128{-158, 38, 2}, 0, + false), + std::make_tuple(DecimalScalar128{2, 36, 0}, DecimalScalar128{158, 38, 2}, 0, false), + std::make_tuple(DecimalScalar128{-13, 36, 1}, DecimalScalar128{-1298, 38, 3}, 1, + false), + std::make_tuple(DecimalScalar128{-1, 35, 0}, DecimalScalar128{-1298, 38, 3}, 0, + false), + std::make_tuple(DecimalScalar128{20, 35, 0}, DecimalScalar128{23298, 38, 3}, -1, + false), + std::make_tuple(DecimalScalar128{100, 38, 0}, DecimalScalar128{122, 38, 0}, -2, + false), + std::make_tuple(DecimalScalar128{3, 37, 0}, DecimalScalar128{25, 38, 1}, 0, false), + + // border cases + std::make_tuple(DecimalScalar128{INT64_MIN / 100, 36, 0}, + DecimalScalar128{INT64_MIN, 38, 2}, 0, false), + + std::make_tuple(DecimalScalar128{INT64_MIN, 38, 0}, + DecimalScalar128{INT64_MIN, 38, 0}, 0, false), + std::make_tuple(DecimalScalar128{0, 0, 36, 0}, DecimalScalar128{0, 0, 38, 2}, 0, + false), + 
std::make_tuple(DecimalScalar128{INT64_MAX, 38, 0}, + DecimalScalar128{INT64_MAX, 38, 0}, 0, false), + + std::make_tuple(DecimalScalar128{INT64_MAX / 100, 36, 0}, + DecimalScalar128{INT64_MAX, 38, 2}, 0, false), + + // large scales + std::make_tuple(DecimalScalar128{0, 0, 22, 0}, DecimalScalar128{12345, 38, 16}, 0, + false), + + std::make_tuple( + DecimalScalar128{BasicDecimal128{124}, 22, 0}, + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, 0, false), + std::make_tuple( + DecimalScalar128{BasicDecimal128{-124}, 22, 0}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, 0, + false), + std::make_tuple( + DecimalScalar128{BasicDecimal128{124}, 6, 0}, + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, 0, false), + std::make_tuple( + DecimalScalar128{BasicDecimal128{-124}, 6, 0}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, 0, + false), + + // scale bigger than arg + std::make_tuple( + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, 35, + false), + std::make_tuple( + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32}, 35, + false), + + // overflow + std::make_tuple(DecimalScalar128{0, 0, 1, 0}, DecimalScalar128{99, 2, 1}, 0, true), + }; + + for (auto iter : test_values) { + auto expected = std::get<0>(iter); + auto input = std::get<1>(iter); + auto rounding_scale = std::get<2>(iter); + auto expected_overflow = std::get<3>(iter); + bool overflow = false; + + EXPECT_EQ(expected.value(), + decimalops::Round(input, expected.precision(), expected.scale(), + rounding_scale, &overflow)) + << " failed on input " << input << " rounding scale " << rounding_scale; + if (expected_overflow) { + ASSERT_TRUE(overflow) << "overflow expected for input " << input; + } else { + ASSERT_FALSE(overflow) << "overflow 
not expected for input " << input; + } + } +} + +TEST_F(TestDecimalSql, Truncate) { + // expected, input, rounding_scale, overflow + using TupleType = std::tuple<DecimalScalar128, DecimalScalar128, int32_t, bool>; + std::vector<TupleType> test_values = { + // examples from + // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_truncate + std::make_tuple(DecimalScalar128{12, 36, 1}, DecimalScalar128{1223, 38, 3}, 1, + false), + std::make_tuple(DecimalScalar128{19, 36, 1}, DecimalScalar128{1999, 38, 3}, 1, + false), + std::make_tuple(DecimalScalar128{1, 35, 0}, DecimalScalar128{1999, 38, 3}, 0, + false), + std::make_tuple(DecimalScalar128{-19, 36, 1}, DecimalScalar128{-1999, 38, 3}, 1, + false), + std::make_tuple(DecimalScalar128{100, 38, 0}, DecimalScalar128{122, 38, 0}, -2, + false), + std::make_tuple(DecimalScalar128{1028, 38, 0}, DecimalScalar128{1028, 38, 0}, 0, + false), + + // border cases + std::make_tuple(DecimalScalar128{BasicDecimal128{INT64_MIN / 100}, 36, 0}, + DecimalScalar128{INT64_MIN, 38, 2}, 0, false), + + std::make_tuple(DecimalScalar128{INT64_MIN, 38, 0}, + DecimalScalar128{INT64_MIN, 38, 0}, 0, false), + std::make_tuple(DecimalScalar128{0, 0, 38, 0}, DecimalScalar128{0, 0, 38, 2}, 0, + false), + std::make_tuple(DecimalScalar128{INT64_MAX, 38, 0}, + DecimalScalar128{INT64_MAX, 38, 0}, 0, false), + + std::make_tuple(DecimalScalar128{BasicDecimal128(INT64_MAX / 100), 36, 0}, + DecimalScalar128{INT64_MAX, 38, 2}, 0, false), + + // large scales + std::make_tuple(DecimalScalar128{BasicDecimal128{0, 0}, 22, 0}, + DecimalScalar128{12345, 38, 16}, 0, false), + std::make_tuple( + DecimalScalar128{BasicDecimal128{123}, 22, 0}, + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, 0, false), + std::make_tuple( + DecimalScalar128{BasicDecimal128{-123}, 22, 0}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, 0, + false), + std::make_tuple( + DecimalScalar128{BasicDecimal128{123}, 6, 0}, + 
DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, 0, false), + std::make_tuple( + DecimalScalar128{BasicDecimal128{-123}, 6, 0}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, 0, + false), + + // scale bigger than arg: expected value equals the input and no overflow is expected + std::make_tuple( + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(32), 38, 32}, 35, + false), + std::make_tuple( + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(32), 38, 32}, 35, + false), + }; + + // Each case passes the expected scalar's precision/scale as the requested output type. + for (auto iter : test_values) { + auto expected = std::get<0>(iter); + auto input = std::get<1>(iter); + auto rounding_scale = std::get<2>(iter); + auto expected_overflow = std::get<3>(iter); + bool overflow = false; + + EXPECT_EQ(expected.value(), + decimalops::Truncate(input, expected.precision(), expected.scale(), + rounding_scale, &overflow)) + << " failed on input " << input << " rounding scale " << rounding_scale; + if (expected_overflow) { + ASSERT_TRUE(overflow) << "overflow expected for input " << input; + } else { + ASSERT_FALSE(overflow) << "overflow not expected for input " << input; + } + } +} + +// Table-driven check of decimalops::Ceil and the overflow flag it reports. +TEST_F(TestDecimalSql, Ceil) { + // expected, input, overflow + std::vector<std::tuple<BasicDecimal128, DecimalScalar128, bool>> test_values = { + // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_ceil + std::make_tuple(2, DecimalScalar128{123, 38, 2}, false), + std::make_tuple(-1, DecimalScalar128{-123, 38, 2}, false), + + // border cases + std::make_tuple(BasicDecimal128{INT64_MIN / 100}, + DecimalScalar128{INT64_MIN, 38, 2}, false), + + std::make_tuple(INT64_MIN, DecimalScalar128{INT64_MIN, 38, 0}, false), + std::make_tuple(BasicDecimal128{0, 0}, DecimalScalar128{0, 0, 38, 2}, false), + std::make_tuple(INT64_MAX, DecimalScalar128{INT64_MAX, 38, 0}, false), + + std::make_tuple(BasicDecimal128(INT64_MAX / 100 + 1), +
DecimalScalar128{INT64_MAX, 38, 2}, false), + + // large scales + std::make_tuple(BasicDecimal128{0, 1}, DecimalScalar128{12345, 38, 16}, false), + std::make_tuple( + BasicDecimal128{124}, + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, false), + std::make_tuple( + BasicDecimal128{-123}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, false), + std::make_tuple( + BasicDecimal128{124}, + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, false), + std::make_tuple( + BasicDecimal128{-123}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, false), + }; + + for (auto iter : test_values) { + auto expected = std::get<0>(iter); + auto input = std::get<1>(iter); + auto expected_overflow = std::get<2>(iter); + bool overflow = false; + + EXPECT_EQ(expected, decimalops::Ceil(input, &overflow)) + << " failed on input " << input; + if (expected_overflow) { + ASSERT_TRUE(overflow) << "overflow expected for input " << input; + } else { + ASSERT_FALSE(overflow) << "overflow not expected for input " << input; + } + } +} + +TEST_F(TestDecimalSql, Floor) { + // expected, input, overflow + std::vector<std::tuple<BasicDecimal128, DecimalScalar128, bool>> test_values = { + // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_floor + std::make_tuple(1, DecimalScalar128{123, 38, 2}, false), + std::make_tuple(-2, DecimalScalar128{-123, 38, 2}, false), + + // border cases + std::make_tuple(BasicDecimal128{INT64_MIN / 100 - 1}, + DecimalScalar128{INT64_MIN, 38, 2}, false), + + std::make_tuple(INT64_MIN, DecimalScalar128{INT64_MIN, 38, 0}, false), + std::make_tuple(BasicDecimal128{0, 0}, DecimalScalar128{0, 0, 38, 2}, false), + std::make_tuple(INT64_MAX, DecimalScalar128{INT64_MAX, 38, 0}, false), + + std::make_tuple(BasicDecimal128{INT64_MAX / 100}, + DecimalScalar128{INT64_MAX, 38, 2}, false), + + // large scales + std::make_tuple(BasicDecimal128{0, 0}, 
DecimalScalar128{12345, 38, 16}, false), + std::make_tuple( + BasicDecimal128{123}, + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(14), 38, 16}, false), + std::make_tuple( + BasicDecimal128{-124}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(14), 38, 16}, false), + std::make_tuple( + BasicDecimal128{123}, + DecimalScalar128{BasicDecimal128{12389}.IncreaseScaleBy(30), 38, 32}, false), + std::make_tuple( + BasicDecimal128{-124}, + DecimalScalar128{BasicDecimal128{-12389}.IncreaseScaleBy(30), 38, 32}, false), + }; + + for (auto iter : test_values) { + auto expected = std::get<0>(iter); + auto input = std::get<1>(iter); + auto expected_overflow = std::get<2>(iter); + bool overflow = false; + + EXPECT_EQ(expected, decimalops::Floor(input, &overflow)) + << " failed on input " << input; + if (expected_overflow) { + ASSERT_TRUE(overflow) << "overflow expected for input " << input; + } else { + ASSERT_FALSE(overflow) << "overflow not expected for input " << input; + } + } +} + +TEST_F(TestDecimalSql, Convert) { + // expected, input, overflow + std::vector<std::tuple<DecimalScalar128, DecimalScalar128, bool>> test_values = { + // simple cases + std::make_tuple(DecimalScalar128{12, 38, 1}, DecimalScalar128{123, 38, 2}, false), + std::make_tuple(DecimalScalar128{1230, 38, 3}, DecimalScalar128{123, 38, 2}, false), + std::make_tuple(DecimalScalar128{123, 38, 2}, DecimalScalar128{123, 38, 2}, false), + + std::make_tuple(DecimalScalar128{-12, 38, 1}, DecimalScalar128{-123, 38, 2}, false), + std::make_tuple(DecimalScalar128{-1230, 38, 3}, DecimalScalar128{-123, 38, 2}, + false), + std::make_tuple(DecimalScalar128{-123, 38, 2}, DecimalScalar128{-123, 38, 2}, + false), + + // border cases + std::make_tuple( + DecimalScalar128{BasicDecimal128(INT64_MIN).ReduceScaleBy(1), 38, 1}, + DecimalScalar128{INT64_MIN, 38, 2}, false), + std::make_tuple( + DecimalScalar128{BasicDecimal128(INT64_MIN).IncreaseScaleBy(1), 38, 3}, + DecimalScalar128{INT64_MIN, 38, 2}, false), 
+ std::make_tuple(DecimalScalar128{-3, 38, 1}, DecimalScalar128{-32, 38, 2}, false), + std::make_tuple(DecimalScalar128{0, 0, 38, 1}, DecimalScalar128{0, 0, 38, 2}, + false), + std::make_tuple(DecimalScalar128{3, 38, 1}, DecimalScalar128{32, 38, 2}, false), + std::make_tuple( + DecimalScalar128{BasicDecimal128(INT64_MAX).ReduceScaleBy(1), 38, 1}, + DecimalScalar128{INT64_MAX, 38, 2}, false), + std::make_tuple( + DecimalScalar128{BasicDecimal128(INT64_MAX).IncreaseScaleBy(1), 38, 3}, + DecimalScalar128{INT64_MAX, 38, 2}, false), + + // large scales + std::make_tuple(DecimalScalar128{BasicDecimal128(123).IncreaseScaleBy(16), 38, 18}, + DecimalScalar128{123, 38, 2}, false), + std::make_tuple(DecimalScalar128{BasicDecimal128(-123).IncreaseScaleBy(16), 38, 18}, + DecimalScalar128{-123, 38, 2}, false), + std::make_tuple(DecimalScalar128{BasicDecimal128(123).IncreaseScaleBy(30), 38, 32}, + DecimalScalar128{123, 38, 2}, false), + std::make_tuple(DecimalScalar128{BasicDecimal128(-123).IncreaseScaleBy(30), 38, 32}, + DecimalScalar128{-123, 38, 2}, false), + + // overflow due to scaling up. + std::make_tuple(DecimalScalar128{0, 0, 38, 36}, DecimalScalar128{12345, 38, 2}, + true), + std::make_tuple(DecimalScalar128{0, 0, 38, 36}, DecimalScalar128{-12345, 38, 2}, + true), + + // overflow due to precision. 
+ std::make_tuple(DecimalScalar128{0, 0, 5, 3}, DecimalScalar128{12345, 5, 2}, true), + }; + + for (auto iter : test_values) { + auto expected = std::get<0>(iter); + auto input = std::get<1>(iter); + auto expected_overflow = std::get<2>(iter); + bool overflow = false; + + EXPECT_EQ(expected.value(), decimalops::Convert(input, expected.precision(), + expected.scale(), &overflow)) + << " failed on input " << input; + + if (expected_overflow) { + ASSERT_TRUE(overflow) << "overflow expected for input " << input; + } else { + ASSERT_FALSE(overflow) << "overflow not expected for input " << input; + } + } +} + +// double can store up to this integer value without losing precision +static const int64_t kMaxDoubleInt = 1ull << 53; + +TEST_F(TestDecimalSql, FromDouble) { + // expected, input, overflow + std::vector<std::tuple<DecimalScalar128, double, bool>> test_values = { + // simple cases + std::make_tuple(DecimalScalar128{-16285, 38, 3}, -16.285, false), + std::make_tuple(DecimalScalar128{-162850, 38, 4}, -16.285, false), + std::make_tuple(DecimalScalar128{-1629, 38, 2}, -16.285, false), + + std::make_tuple(DecimalScalar128{16285, 38, 3}, 16.285, false), + std::make_tuple(DecimalScalar128{162850, 38, 4}, 16.285, false), + std::make_tuple(DecimalScalar128{1629, 38, 2}, 16.285, false), + + // round up + std::make_tuple(DecimalScalar128{1, 18, 0}, 1.15470053838, false), + std::make_tuple(DecimalScalar128{-1, 18, 0}, -1.15470053838, false), + std::make_tuple(DecimalScalar128{2, 18, 0}, 1.55470053838, false), + std::make_tuple(DecimalScalar128{-2, 18, 0}, -1.55470053838, false), + + // border cases + std::make_tuple(DecimalScalar128{-kMaxDoubleInt, 38, 0}, + static_cast<double>(-kMaxDoubleInt), false), + std::make_tuple(DecimalScalar128{-32, 38, 0}, -32, false), + std::make_tuple(DecimalScalar128{0, 0, 38, 0}, 0, false), + std::make_tuple(DecimalScalar128{32, 38, 0}, 32, false), + std::make_tuple(DecimalScalar128{kMaxDoubleInt, 38, 0}, + static_cast<double>(kMaxDoubleInt), 
false), + + // large scales + std::make_tuple(DecimalScalar128{123, 38, 16}, 1.23E-14, false), + std::make_tuple(DecimalScalar128{123, 38, 32}, 1.23E-30, false), + std::make_tuple(DecimalScalar128{1230, 38, 33}, 1.23E-30, false), + std::make_tuple(DecimalScalar128{123, 38, 38}, 1.23E-36, false), + + // very small doubles + std::make_tuple(DecimalScalar128{0, 0, 38, 0}, std::numeric_limits<double>::min(), + false), + std::make_tuple(DecimalScalar128{0, 0, 38, 0}, -std::numeric_limits<double>::min(), + false), + + // overflow due to large -ve double + std::make_tuple(DecimalScalar128{0, 0, 38, 0}, -std::numeric_limits<double>::max(), + true), + // overflow due to large +ve double + std::make_tuple(DecimalScalar128{0, 0, 38, 0}, std::numeric_limits<double>::max(), + true), + // overflow due to scaling up. + std::make_tuple(DecimalScalar128{0, 0, 38, 36}, 123.45, true), + // overflow due to precision. + std::make_tuple(DecimalScalar128{0, 0, 4, 2}, 12345.67, true), + }; + + for (auto iter : test_values) { + auto dscalar = std::get<0>(iter); + auto input = std::get<1>(iter); + auto expected_overflow = std::get<2>(iter); + bool overflow = false; + + EXPECT_EQ(dscalar.value(), decimalops::FromDouble(input, dscalar.precision(), + dscalar.scale(), &overflow)) + << " failed on input " << input; + + if (expected_overflow) { + ASSERT_TRUE(overflow) << "overflow expected for input " << input; + } else { + ASSERT_FALSE(overflow) << "overflow not expected for input " << input; + } + } +} + +// Two-sided absolute comparison: passes only when |x - y| <= 0.00001. +// (The previous one-sided form `x - y <= 0.00001` accepted any y larger than x.) +// The tolerance is absolute, so expectations much smaller than 0.00001 are only loosely checked. +#define EXPECT_FUZZY_EQ(x, y) \ + EXPECT_NEAR(x, y, 0.00001) << "expected " << x << ", got " << y + +TEST_F(TestDecimalSql, ToDouble) { + // expected, input + std::vector<std::tuple<double, DecimalScalar128>> test_values = { + // simple ones + std::make_tuple(-16.285, DecimalScalar128{-16285, 38, 3}), + std::make_tuple(-162.85, DecimalScalar128{-16285, 38, 2}), + std::make_tuple(-1.6285, DecimalScalar128{-16285, 38, 4}), + + // large scales + std::make_tuple(1.23E-14,
DecimalScalar128{123, 38, 16}), + std::make_tuple(1.23E-30, DecimalScalar128{123, 38, 32}), + std::make_tuple(1.23E-36, DecimalScalar128{123, 38, 38}), + + // border cases + std::make_tuple(static_cast<double>(-kMaxDoubleInt), + DecimalScalar128{-kMaxDoubleInt, 38, 0}), + std::make_tuple(-32, DecimalScalar128{-32, 38, 0}), + std::make_tuple(0, DecimalScalar128{0, 0, 38, 0}), + std::make_tuple(32, DecimalScalar128{32, 38, 0}), + std::make_tuple(static_cast<double>(kMaxDoubleInt), + DecimalScalar128{kMaxDoubleInt, 38, 0}), + }; + for (auto iter : test_values) { + auto input = std::get<1>(iter); + bool overflow = false; + + EXPECT_FUZZY_EQ(std::get<0>(iter), decimalops::ToDouble(input, &overflow)); + ASSERT_FALSE(overflow) << "overflow not expected for input " << input; + } +} + +TEST_F(TestDecimalSql, FromInt64) { + // expected, input, overflow + std::vector<std::tuple<DecimalScalar128, int64_t, bool>> test_values = { + // simple cases + std::make_tuple(DecimalScalar128{-16000, 38, 3}, -16, false), + std::make_tuple(DecimalScalar128{-160000, 38, 4}, -16, false), + std::make_tuple(DecimalScalar128{-1600, 38, 2}, -16, false), + + std::make_tuple(DecimalScalar128{16000, 38, 3}, 16, false), + std::make_tuple(DecimalScalar128{160000, 38, 4}, 16, false), + std::make_tuple(DecimalScalar128{1600, 38, 2}, 16, false), + + // border cases + std::make_tuple(DecimalScalar128{INT64_MIN, 38, 0}, INT64_MIN, false), + std::make_tuple(DecimalScalar128{-32, 38, 0}, -32, false), + std::make_tuple(DecimalScalar128{0, 0, 38, 0}, 0, false), + std::make_tuple(DecimalScalar128{32, 38, 0}, 32, false), + std::make_tuple(DecimalScalar128{INT64_MAX, 38, 0}, INT64_MAX, false), + + // large scales + std::make_tuple(DecimalScalar128{BasicDecimal128(123).IncreaseScaleBy(16), 38, 16}, + 123, false), + std::make_tuple(DecimalScalar128{BasicDecimal128(123).IncreaseScaleBy(32), 38, 32}, + 123, false), + std::make_tuple(DecimalScalar128{BasicDecimal128(-123).IncreaseScaleBy(16), 38, 16}, + -123, false), 
+ std::make_tuple(DecimalScalar128{BasicDecimal128(-123).IncreaseScaleBy(32), 38, 32}, + -123, false), + + // overflow due to scaling up. + std::make_tuple(DecimalScalar128{0, 0, 38, 36}, 123, true), + // overflow due to precision. + std::make_tuple(DecimalScalar128{0, 0, 4, 2}, 12345, true), + }; + + for (auto iter : test_values) { + auto dscalar = std::get<0>(iter); + auto input = std::get<1>(iter); + auto expected_overflow = std::get<2>(iter); + bool overflow = false; + + EXPECT_EQ(dscalar.value(), decimalops::FromInt64(input, dscalar.precision(), + dscalar.scale(), &overflow)) + << " failed on input " << input; + + if (expected_overflow) { + ASSERT_TRUE(overflow) << "overflow expected for input " << input; + } else { + ASSERT_FALSE(overflow) << "overflow not expected for input " << input; + } + } +} + +TEST_F(TestDecimalSql, ToInt64) { + // expected, input, overflow + std::vector<std::tuple<int64_t, DecimalScalar128, bool>> test_values = { + // simple ones + std::make_tuple(-16, DecimalScalar128{-16285, 38, 3}, false), + std::make_tuple(-163, DecimalScalar128{-16285, 38, 2}, false), + std::make_tuple(-2, DecimalScalar128{-16285, 38, 4}, false), + + // border cases + std::make_tuple(INT64_MIN, DecimalScalar128{INT64_MIN, 38, 0}, false), + std::make_tuple(-32, DecimalScalar128{-32, 38, 0}, false), + std::make_tuple(0, DecimalScalar128{0, 0, 38, 0}, false), + std::make_tuple(32, DecimalScalar128{32, 38, 0}, false), + std::make_tuple(INT64_MAX, DecimalScalar128{INT64_MAX, 38, 0}, false), + + // large scales + std::make_tuple(0, DecimalScalar128{123, 38, 16}, false), + std::make_tuple(0, DecimalScalar128{123, 38, 32}, false), + std::make_tuple(0, DecimalScalar128{123, 38, 38}, false), + + // overflow test cases + // very large + std::make_tuple(0, DecimalScalar128{32768, 16, 38, 2}, true), + std::make_tuple(0, DecimalScalar128{INT64_MAX, UINT64_MAX, 38, 10}, true), + // very small + std::make_tuple(0, -DecimalScalar128{32768, 16, 38, 2}, true), + std::make_tuple(0, 
-DecimalScalar128{INT64_MAX, UINT64_MAX, 38, 10}, true), + }; + + for (auto iter : test_values) { + auto expected_value = std::get<0>(iter); + auto input = std::get<1>(iter); + auto expected_overflow = std::get<2>(iter); + bool overflow = false; + + EXPECT_EQ(expected_value, decimalops::ToInt64(input, &overflow)) + << " failed on input " << input; + if (expected_overflow) { + ASSERT_TRUE(overflow) << "overflow expected for input " << input; + } else { + ASSERT_FALSE(overflow) << "overflow not expected for input " << input; + } + } +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/decimal_wrapper.cc b/src/arrow/cpp/src/gandiva/precompiled/decimal_wrapper.cc new file mode 100644 index 000000000..082d5832d --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/decimal_wrapper.cc @@ -0,0 +1,433 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/precompiled/decimal_ops.h" +#include "gandiva/precompiled/types.h" + +extern "C" { + +FORCE_INLINE +void add_large_decimal128_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, int64_t y_high, uint64_t y_low, + int32_t y_precision, int32_t y_scale, + int32_t out_precision, int32_t out_scale, + int64_t* out_high, uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale); + gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale); + + arrow::BasicDecimal128 out = gandiva::decimalops::Add(x, y, out_precision, out_scale); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +void multiply_decimal128_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, int64_t y_high, uint64_t y_low, + int32_t y_precision, int32_t y_scale, + int32_t out_precision, int32_t out_scale, + int64_t* out_high, uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale); + gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale); + bool overflow; + + // TODO ravindra: generate error on overflows (ARROW-4570). + arrow::BasicDecimal128 out = + gandiva::decimalops::Multiply(x, y, out_precision, out_scale, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +void divide_decimal128_decimal128(int64_t context, int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, int64_t y_high, + uint64_t y_low, int32_t y_precision, int32_t y_scale, + int32_t out_precision, int32_t out_scale, + int64_t* out_high, uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale); + gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale); + bool overflow; + + // TODO ravindra: generate error on overflows (ARROW-4570). 
+ arrow::BasicDecimal128 out = + gandiva::decimalops::Divide(context, x, y, out_precision, out_scale, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +void mod_decimal128_decimal128(int64_t context, int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, int64_t y_high, + uint64_t y_low, int32_t y_precision, int32_t y_scale, + int32_t out_precision, int32_t out_scale, + int64_t* out_high, uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale); + gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale); + bool overflow; + + // TODO ravindra: generate error on overflows (ARROW-4570). + arrow::BasicDecimal128 out = + gandiva::decimalops::Mod(context, x, y, out_precision, out_scale, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +int32_t compare_decimal128_decimal128_internal(int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, + int64_t y_high, uint64_t y_low, + int32_t y_precision, int32_t y_scale) { + gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale); + gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale); + + return gandiva::decimalops::Compare(x, y); +} + +FORCE_INLINE +void abs_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, int32_t x_scale, + int32_t out_precision, int32_t out_scale, int64_t* out_high, + uint64_t* out_low) { + gandiva::BasicDecimal128 x(x_high, x_low); + x.Abs(); + *out_high = x.high_bits(); + *out_low = x.low_bits(); +} + +FORCE_INLINE +void ceil_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, int32_t x_scale, + int32_t out_precision, int32_t out_scale, int64_t* out_high, + uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale); + + bool overflow = false; + auto out = gandiva::decimalops::Ceil(x, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); 
+} + +FORCE_INLINE +void floor_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, int32_t out_precision, int32_t out_scale, + int64_t* out_high, uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale); + + bool overflow = false; + auto out = gandiva::decimalops::Floor(x, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +void round_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, int32_t out_precision, int32_t out_scale, + int64_t* out_high, uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale); + + bool overflow = false; + auto out = gandiva::decimalops::Round(x, out_precision, 0, 0, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +void round_decimal128_int32(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, int32_t rounding_scale, + int32_t out_precision, int32_t out_scale, int64_t* out_high, + uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale); + + bool overflow = false; + auto out = + gandiva::decimalops::Round(x, out_precision, out_scale, rounding_scale, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +void truncate_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, int32_t out_precision, int32_t out_scale, + int64_t* out_high, uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale); + + bool overflow = false; + auto out = gandiva::decimalops::Truncate(x, out_precision, 0, 0, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +void truncate_decimal128_int32(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, int32_t rounding_scale, + int32_t out_precision, int32_t out_scale, + int64_t* out_high, 
uint64_t* out_low) { + gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale); + + bool overflow = false; + auto out = gandiva::decimalops::Truncate(x, out_precision, out_scale, rounding_scale, + &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +double castFLOAT8_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale) { + gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale); + + bool overflow = false; + return gandiva::decimalops::ToDouble(x, &overflow); +} + +FORCE_INLINE +int64_t castBIGINT_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale) { + gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale); + + bool overflow = false; + return gandiva::decimalops::ToInt64(x, &overflow); +} + +FORCE_INLINE +void castDECIMAL_int64(int64_t in, int32_t x_precision, int32_t x_scale, + int64_t* out_high, uint64_t* out_low) { + bool overflow = false; + auto out = gandiva::decimalops::FromInt64(in, x_precision, x_scale, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +void castDECIMAL_int32(int32_t in, int32_t x_precision, int32_t x_scale, + int64_t* out_high, uint64_t* out_low) { + castDECIMAL_int64(in, x_precision, x_scale, out_high, out_low); +} + +FORCE_INLINE +void castDECIMAL_float64(double in, int32_t x_precision, int32_t x_scale, + int64_t* out_high, uint64_t* out_low) { + bool overflow = false; + auto out = gandiva::decimalops::FromDouble(in, x_precision, x_scale, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +void castDECIMAL_float32(float in, int32_t x_precision, int32_t x_scale, + int64_t* out_high, uint64_t* out_low) { + castDECIMAL_float64(in, x_precision, x_scale, out_high, out_low); +} + +FORCE_INLINE +bool castDecimal_internal(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, int32_t out_precision, 
int32_t out_scale, + int64_t* out_high, int64_t* out_low) { + gandiva::BasicDecimalScalar128 x({x_high, x_low}, x_precision, x_scale); + bool overflow = false; + auto out = gandiva::decimalops::Convert(x, out_precision, out_scale, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); + return overflow; +} + +FORCE_INLINE +void castDECIMAL_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, int32_t out_precision, int32_t out_scale, + int64_t* out_high, int64_t* out_low) { + castDecimal_internal(x_high, x_low, x_precision, x_scale, out_precision, out_scale, + out_high, out_low); +} + +FORCE_INLINE +void castDECIMALNullOnOverflow_decimal128(int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, + bool x_isvalid, bool* out_valid, + int32_t out_precision, int32_t out_scale, + int64_t* out_high, int64_t* out_low) { + *out_valid = true; + + if (!x_isvalid) { + *out_valid = false; + return; + } + + if (castDecimal_internal(x_high, x_low, x_precision, x_scale, out_precision, out_scale, + out_high, out_low)) { + *out_valid = false; + } +} + +FORCE_INLINE +int32_t hash32_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid) { + return x_isvalid + ? hash32_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, 0) + : 0; +} + +FORCE_INLINE +int32_t hash_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid) { + return hash32_decimal128(x_high, x_low, x_precision, x_scale, x_isvalid); +} + +FORCE_INLINE +int64_t hash64_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid) { + return x_isvalid + ? 
hash64_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, 0) + : 0; +} + +FORCE_INLINE +int32_t hash32WithSeed_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid, int32_t seed, + gdv_boolean seed_isvalid) { + if (!x_isvalid) { + return seed; + } + return hash32_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, seed); +} + +FORCE_INLINE +int64_t hash64WithSeed_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid, int64_t seed, + gdv_boolean seed_isvalid) { + if (!x_isvalid) { + return seed; + } + return hash64_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, seed); +} + +FORCE_INLINE +int32_t hash32AsDouble_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid) { + return x_isvalid + ? hash32_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, 0) + : 0; +} + +FORCE_INLINE +int64_t hash64AsDouble_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid) { + return x_isvalid + ? 
hash64_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, 0) + : 0; +} + +FORCE_INLINE +int32_t hash32AsDoubleWithSeed_decimal128(int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, + gdv_boolean x_isvalid, int32_t seed, + gdv_boolean seed_isvalid) { + if (!x_isvalid) { + return seed; + } + return hash32_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, seed); +} + +FORCE_INLINE +int64_t hash64AsDoubleWithSeed_decimal128(int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, + gdv_boolean x_isvalid, int64_t seed, + gdv_boolean seed_isvalid) { + if (!x_isvalid) { + return seed; + } + return hash64_buf(gandiva::BasicDecimal128(x_high, x_low).ToBytes().data(), 16, seed); +} + +FORCE_INLINE +gdv_boolean isnull_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid) { + return !x_isvalid; +} + +FORCE_INLINE +gdv_boolean isnotnull_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid) { + return x_isvalid; +} + +FORCE_INLINE +gdv_boolean isnumeric_decimal128(int64_t x_high, uint64_t x_low, int32_t x_precision, + int32_t x_scale, gdv_boolean x_isvalid) { + return x_isvalid; +} + +FORCE_INLINE +gdv_boolean is_not_distinct_from_decimal128_decimal128( + int64_t x_high, uint64_t x_low, int32_t x_precision, int32_t x_scale, + gdv_boolean x_isvalid, int64_t y_high, uint64_t y_low, int32_t y_precision, + int32_t y_scale, gdv_boolean y_isvalid) { + if (x_isvalid != y_isvalid) { + return false; + } + if (!x_isvalid) { + return true; + } + return 0 == compare_decimal128_decimal128_internal(x_high, x_low, x_precision, x_scale, + y_high, y_low, y_precision, y_scale); +} + +FORCE_INLINE +gdv_boolean is_distinct_from_decimal128_decimal128(int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, + gdv_boolean x_isvalid, int64_t y_high, + uint64_t y_low, int32_t y_precision, + int32_t y_scale, + gdv_boolean 
y_isvalid) { + return !is_not_distinct_from_decimal128_decimal128(x_high, x_low, x_precision, x_scale, + x_isvalid, y_high, y_low, + y_precision, y_scale, y_isvalid); +} + +FORCE_INLINE +void castDECIMAL_utf8(int64_t context, const char* in, int32_t in_length, + int32_t out_precision, int32_t out_scale, int64_t* out_high, + uint64_t* out_low) { + int64_t dec_high_from_str; + uint64_t dec_low_from_str; + int32_t precision_from_str; + int32_t scale_from_str; + int32_t status = + gdv_fn_dec_from_string(context, in, in_length, &precision_from_str, &scale_from_str, + &dec_high_from_str, &dec_low_from_str); + if (status != 0) { + return; + } + + gandiva::BasicDecimalScalar128 x({dec_high_from_str, dec_low_from_str}, + precision_from_str, scale_from_str); + bool overflow = false; + auto out = gandiva::decimalops::Convert(x, out_precision, out_scale, &overflow); + *out_high = out.high_bits(); + *out_low = out.low_bits(); +} + +FORCE_INLINE +char* castVARCHAR_decimal128_int64(int64_t context, int64_t x_high, uint64_t x_low, + int32_t x_precision, int32_t x_scale, + int64_t out_len_param, int32_t* out_length) { + int32_t full_dec_str_len; + char* dec_str = + gdv_fn_dec_to_string(context, x_high, x_low, x_scale, &full_dec_str_len); + int32_t trunc_dec_str_len = + out_len_param < full_dec_str_len ? out_len_param : full_dec_str_len; + *out_length = trunc_dec_str_len; + return dec_str; +} + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point.h b/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point.h new file mode 100644 index 000000000..45cfb28ca --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point.h @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +// TODO(wesm): IR compilation does not have any include directories set +#include "../../arrow/vendored/datetime/date.h" + +bool is_leap_year(int yy); +bool did_days_overflow(arrow_vendored::date::year_month_day ymd); +int last_possible_day_in_month(int month, int year); + +// A point of time measured in millis since epoch. +class EpochTimePoint { + public: + explicit EpochTimePoint(std::chrono::milliseconds millis_since_epoch) + : tp_(millis_since_epoch) {} + + explicit EpochTimePoint(int64_t millis_since_epoch) + : EpochTimePoint(std::chrono::milliseconds(millis_since_epoch)) {} + + int TmYear() const { return static_cast<int>(YearMonthDay().year()) - 1900; } + + int TmMon() const { return static_cast<unsigned int>(YearMonthDay().month()) - 1; } + + int TmYday() const { + auto to_days = arrow_vendored::date::floor<arrow_vendored::date::days>(tp_); + auto first_day_in_year = arrow_vendored::date::sys_days{ + YearMonthDay().year() / arrow_vendored::date::jan / 1}; + return (to_days - first_day_in_year).count(); + } + + int TmMday() const { return static_cast<unsigned int>(YearMonthDay().day()); } + + int TmWday() const { + auto to_days = arrow_vendored::date::floor<arrow_vendored::date::days>(tp_); + return (arrow_vendored::date::weekday{to_days} - // NOLINT + arrow_vendored::date::Sunday) + .count(); + } + + int TmHour() const { return 
static_cast<int>(TimeOfDay().hours().count()); } + + int TmMin() const { return static_cast<int>(TimeOfDay().minutes().count()); } + + int TmSec() const { + // TODO(wesm): UNIX y2k issue on int=gdv_int32 platforms + return static_cast<int>(TimeOfDay().seconds().count()); + } + + EpochTimePoint AddYears(int num_years) const { + auto ymd = YearMonthDay() + arrow_vendored::date::years(num_years); + return EpochTimePoint((arrow_vendored::date::sys_days{ymd} + // NOLINT + TimeOfDay().to_duration()) + .time_since_epoch()); + } + + EpochTimePoint AddMonths(int num_months) const { + auto ymd = YearMonthDay() + arrow_vendored::date::months(num_months); + + EpochTimePoint tp = EpochTimePoint((arrow_vendored::date::sys_days{ymd} + // NOLINT + TimeOfDay().to_duration()) + .time_since_epoch()); + + if (did_days_overflow(ymd)) { + int days_to_offset = + last_possible_day_in_month(static_cast<int>(ymd.year()), + static_cast<unsigned int>(ymd.month())) - + static_cast<unsigned int>(ymd.day()); + tp = tp.AddDays(days_to_offset); + } + return tp; + } + + EpochTimePoint AddDays(int num_days) const { + auto days_since_epoch = arrow_vendored::date::sys_days{YearMonthDay()} + // NOLINT + arrow_vendored::date::days(num_days); + return EpochTimePoint( + (days_since_epoch + TimeOfDay().to_duration()).time_since_epoch()); + } + + EpochTimePoint ClearTimeOfDay() const { + return EpochTimePoint((tp_ - TimeOfDay().to_duration()).time_since_epoch()); + } + + bool operator==(const EpochTimePoint& other) const { return tp_ == other.tp_; } + + int64_t MillisSinceEpoch() const { return tp_.time_since_epoch().count(); } + + arrow_vendored::date::time_of_day<std::chrono::milliseconds> TimeOfDay() const { + auto millis_since_midnight = + tp_ - arrow_vendored::date::floor<arrow_vendored::date::days>(tp_); + return arrow_vendored::date::time_of_day<std::chrono::milliseconds>( + millis_since_midnight); + } + + private: + arrow_vendored::date::year_month_day YearMonthDay() const { + return 
arrow_vendored::date::year_month_day{ + arrow_vendored::date::floor<arrow_vendored::date::days>(tp_)}; // NOLINT + } + + std::chrono::time_point<std::chrono::system_clock, std::chrono::milliseconds> tp_; +}; diff --git a/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point_test.cc b/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point_test.cc new file mode 100644 index 000000000..9180aac07 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/epoch_time_point_test.cc @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <ctime> + +#include <gtest/gtest.h> +#include "./epoch_time_point.h" +#include "gandiva/precompiled/testing.h" +#include "gandiva/precompiled/types.h" + +#include "gandiva/date_utils.h" + +namespace gandiva { + +TEST(TestEpochTimePoint, TestTm) { + auto ts = StringToTimestamp("2015-05-07 10:20:34"); + EpochTimePoint tp(ts); + + struct tm* tm_ptr; +#if defined(_WIN32) + __time64_t tsec = ts / 1000; + tm_ptr = _gmtime64(&tsec); +#else + struct tm tm; + time_t tsec = ts / 1000; + tm_ptr = gmtime_r(&tsec, &tm); +#endif + + EXPECT_EQ(tp.TmYear(), tm_ptr->tm_year); + EXPECT_EQ(tp.TmMon(), tm_ptr->tm_mon); + EXPECT_EQ(tp.TmYday(), tm_ptr->tm_yday); + EXPECT_EQ(tp.TmMday(), tm_ptr->tm_mday); + EXPECT_EQ(tp.TmWday(), tm_ptr->tm_wday); + EXPECT_EQ(tp.TmHour(), tm_ptr->tm_hour); + EXPECT_EQ(tp.TmMin(), tm_ptr->tm_min); + EXPECT_EQ(tp.TmSec(), tm_ptr->tm_sec); +} + +TEST(TestEpochTimePoint, TestAddYears) { + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddYears(2), + EpochTimePoint(StringToTimestamp("2017-05-05 10:20:34"))); + + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddYears(0), + EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34"))); + + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddYears(-1), + EpochTimePoint(StringToTimestamp("2014-05-05 10:20:34"))); +} + +TEST(TestEpochTimePoint, TestAddMonths) { + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(2), + EpochTimePoint(StringToTimestamp("2015-07-05 10:20:34"))); + + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(11), + EpochTimePoint(StringToTimestamp("2016-04-05 10:20:34"))); + + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(0), + EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34"))); + + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(-1), + EpochTimePoint(StringToTimestamp("2015-04-05 10:20:34"))); + + 
EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddMonths(-10), + EpochTimePoint(StringToTimestamp("2014-07-05 10:20:34"))); +} + +TEST(TestEpochTimePoint, TestAddDays) { + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(2), + EpochTimePoint(StringToTimestamp("2015-05-07 10:20:34"))); + + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(11), + EpochTimePoint(StringToTimestamp("2015-05-16 10:20:34"))); + + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(0), + EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34"))); + + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(-1), + EpochTimePoint(StringToTimestamp("2015-05-04 10:20:34"))); + + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).AddDays(-10), + EpochTimePoint(StringToTimestamp("2015-04-25 10:20:34"))); +} + +TEST(TestEpochTimePoint, TestClearTimeOfDay) { + EXPECT_EQ(EpochTimePoint(StringToTimestamp("2015-05-05 10:20:34")).ClearTimeOfDay(), + EpochTimePoint(StringToTimestamp("2015-05-05 00:00:00"))); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops.cc b/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops.cc new file mode 100644 index 000000000..365b08a6d --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops.cc @@ -0,0 +1,410 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +#include "arrow/util/logging.h" +#include "gandiva/precompiled/decimal_ops.h" + +extern "C" { + +#include <math.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include "./types.h" + +// Expand the inner fn for types that support extended math. +#define ENUMERIC_TYPES_UNARY(INNER, OUT_TYPE) \ + INNER(int32, OUT_TYPE) \ + INNER(uint32, OUT_TYPE) \ + INNER(int64, OUT_TYPE) \ + INNER(uint64, OUT_TYPE) \ + INNER(float32, OUT_TYPE) \ + INNER(float64, OUT_TYPE) + +// Cubic root +#define CBRT(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE cbrt_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_float64>(cbrtl(static_cast<long double>(in))); \ + } + +ENUMERIC_TYPES_UNARY(CBRT, float64) + +// Exponent +#define EXP(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE exp_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_float64>(expl(static_cast<long double>(in))); \ + } + +ENUMERIC_TYPES_UNARY(EXP, float64) + +// log +#define LOG(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE log_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_float64>(logl(static_cast<long double>(in))); \ + } + +ENUMERIC_TYPES_UNARY(LOG, float64) + +// log base 10 +#define LOG10(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE log10_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_float64>(log10l(static_cast<long double>(in))); \ + } + +#define LOGL(VALUE) static_cast<gdv_float64>(logl(static_cast<long double>(VALUE))) + +ENUMERIC_TYPES_UNARY(LOG10, 
float64) + +FORCE_INLINE +void set_error_for_logbase(int64_t execution_context, double base) { + char const* prefix = "divide by zero error with log of base"; + int size = static_cast<int>(strlen(prefix)) + 64; + char* error = reinterpret_cast<char*>(malloc(size)); + snprintf(error, size, "%s %f", prefix, base); + gdv_fn_context_set_error_msg(execution_context, error); + free(static_cast<char*>(error)); +} + +// log with base +#define LOG_WITH_BASE(IN_TYPE1, IN_TYPE2, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE log_##IN_TYPE1##_##IN_TYPE2(gdv_int64 context, gdv_##IN_TYPE1 base, \ + gdv_##IN_TYPE2 value) { \ + gdv_##OUT_TYPE log_of_base = LOGL(base); \ + if (log_of_base == 0) { \ + set_error_for_logbase(context, static_cast<gdv_float64>(base)); \ + return 0; \ + } \ + return LOGL(value) / LOGL(base); \ + } + +LOG_WITH_BASE(int32, int32, float64) +LOG_WITH_BASE(uint32, uint32, float64) +LOG_WITH_BASE(int64, int64, float64) +LOG_WITH_BASE(uint64, uint64, float64) +LOG_WITH_BASE(float32, float32, float64) +LOG_WITH_BASE(float64, float64, float64) + +// Sin +#define SIN(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE sin_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(sin(static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(SIN, float64) + +// Asin +#define ASIN(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE asin_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(asin(static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(ASIN, float64) + +// Cos +#define COS(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE cos_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(cos(static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(COS, float64) + +// Acos +#define ACOS(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE acos_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(acos(static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(ACOS, float64) + +// Tan 
+#define TAN(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE tan_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(tan(static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(TAN, float64) + +// Atan +#define ATAN(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE atan_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(atan(static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(ATAN, float64) + +// Sinh +#define SINH(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE sinh_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(sinh(static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(SINH, float64) + +// Cosh +#define COSH(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE cosh_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(cosh(static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(COSH, float64) + +// Tanh +#define TANH(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE tanh_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(tanh(static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(TANH, float64) + +// Atan2 +#define ATAN2(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE atan2_##IN_TYPE##_##IN_TYPE(gdv_##IN_TYPE in1, gdv_##IN_TYPE in2) { \ + return static_cast<gdv_##OUT_TYPE>( \ + atan2(static_cast<long double>(in1), static_cast<long double>(in2))); \ + } +ENUMERIC_TYPES_UNARY(ATAN2, float64) + +// Cot +#define COT(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE cot_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(tan(M_PI / 2 - static_cast<long double>(in))); \ + } +ENUMERIC_TYPES_UNARY(COT, float64) + +// Radians +#define RADIANS(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE radians_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(static_cast<long double>(in) * M_PI / 180.0); \ + } +ENUMERIC_TYPES_UNARY(RADIANS, float64) + +// Degrees +#define DEGREES(IN_TYPE, 
OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE degrees_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast<gdv_##OUT_TYPE>(static_cast<long double>(in) * 180.0 / M_PI); \ + } +ENUMERIC_TYPES_UNARY(DEGREES, float64) + +// power +#define POWER(IN_TYPE1, IN_TYPE2, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE power_##IN_TYPE1##_##IN_TYPE2(gdv_##IN_TYPE1 in1, gdv_##IN_TYPE2 in2) { \ + return static_cast<gdv_float64>(powl(in1, in2)); \ + } +POWER(float64, float64, float64) + +FORCE_INLINE +gdv_int32 round_int32(gdv_int32 num) { return num; } + +FORCE_INLINE +gdv_int64 round_int64(gdv_int64 num) { return num; } + +// rounds the number to the nearest integer +#define ROUND_DECIMAL(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE round_##TYPE(gdv_##TYPE num) { \ + return static_cast<gdv_##TYPE>(trunc(num + ((num >= 0) ? 0.5 : -0.5))); \ + } + +ROUND_DECIMAL(float32) +ROUND_DECIMAL(float64) + +// rounds the number to the given scale +#define ROUND_DECIMAL_TO_SCALE(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE round_##TYPE##_int32(gdv_##TYPE number, gdv_int32 out_scale) { \ + gdv_float64 scale_multiplier = get_scale_multiplier(out_scale); \ + return static_cast<gdv_##TYPE>( \ + trunc(number * scale_multiplier + ((number >= 0) ? 0.5 : -0.5)) / \ + scale_multiplier); \ + } + +ROUND_DECIMAL_TO_SCALE(float32) +ROUND_DECIMAL_TO_SCALE(float64) + +FORCE_INLINE +gdv_int32 round_int32_int32(gdv_int32 number, gdv_int32 precision) { + // for integers, there is nothing following the decimal point, + // so round() always returns the same number if precision >= 0 + if (precision >= 0) { + return number; + } + gdv_int32 abs_precision = -precision; + // This is to ensure that there is no overflow while calculating 10^precision, 9 is + // the smallest N for which 10^N does not fit into 32 bits, so we can safely return 0 + if (abs_precision > 9) { + return 0; + } + gdv_int32 num_sign = (number > 0) ? 
1 : -1; + gdv_int32 abs_number = number * num_sign; + gdv_int32 power_of_10 = static_cast<gdv_int32>(get_power_of_10(abs_precision)); + gdv_int32 remainder = abs_number % power_of_10; + abs_number -= remainder; + // if the fractional part of the quotient >= 0.5, round to next higher integer + if (remainder >= power_of_10 / 2) { + abs_number += power_of_10; + } + return abs_number * num_sign; +} + +FORCE_INLINE +gdv_int64 round_int64_int32(gdv_int64 number, gdv_int32 precision) { + // for long integers, there is nothing following the decimal point, + // so round() always returns the same number if precision >= 0 + if (precision >= 0) { + return number; + } + gdv_int32 abs_precision = -precision; + // This is to ensure that there is no overflow while calculating 10^precision, 19 is + // the smallest N for which 10^N does not fit into 64 bits, so we can safely return 0 + if (abs_precision > 18) { + return 0; + } + gdv_int32 num_sign = (number > 0) ? 1 : -1; + gdv_int64 abs_number = number * num_sign; + gdv_int64 power_of_10 = get_power_of_10(abs_precision); + gdv_int64 remainder = abs_number % power_of_10; + abs_number -= remainder; + // if the fractional part of the quotient >= 0.5, round to next higher integer + if (remainder >= power_of_10 / 2) { + abs_number += power_of_10; + } + return abs_number * num_sign; +} + +FORCE_INLINE +gdv_int64 get_power_of_10(gdv_int32 exp) { + DCHECK_GE(exp, 0); + DCHECK_LE(exp, 18); + static const gdv_int64 power_of_10[] = {1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000}; + return power_of_10[exp]; +} + +FORCE_INLINE +gdv_int64 truncate_int64_int32(gdv_int64 in, gdv_int32 out_scale) { + bool overflow = false; + arrow::BasicDecimal128 decimal = gandiva::decimalops::FromInt64(in, 38, 0, &overflow); + arrow::BasicDecimal128 
decimal_with_outscale = + gandiva::decimalops::Truncate(gandiva::BasicDecimalScalar128(decimal, 38, 0), 38, + out_scale, out_scale, &overflow); + if (out_scale < 0) { + out_scale = 0; + } + return gandiva::decimalops::ToInt64( + gandiva::BasicDecimalScalar128(decimal_with_outscale, 38, out_scale), &overflow); +} + +FORCE_INLINE +gdv_float64 get_scale_multiplier(gdv_int32 scale) { + static const gdv_float64 values[] = {1.0, + 10.0, + 100.0, + 1000.0, + 10000.0, + 100000.0, + 1000000.0, + 10000000.0, + 100000000.0, + 1000000000.0, + 10000000000.0, + 100000000000.0, + 1000000000000.0, + 10000000000000.0, + 100000000000000.0, + 1000000000000000.0, + 10000000000000000.0, + 100000000000000000.0, + 1000000000000000000.0, + 10000000000000000000.0}; + if (scale >= 0 && scale < 20) { + return values[scale]; + } + return power_float64_float64(10.0, scale); +} + +// returns the binary representation of a given integer (e.g. 928 -> 1110100000) +#define BIN_INTEGER(IN_TYPE) \ + FORCE_INLINE \ + const char* bin_##IN_TYPE(int64_t context, gdv_##IN_TYPE value, int32_t* out_len) { \ + *out_len = 0; \ + int32_t len = 8 * sizeof(value); \ + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, len)); \ + if (ret == nullptr) { \ + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output"); \ + return ""; \ + } \ + /* handle case when value is zero */ \ + if (value == 0) { \ + *out_len = 1; \ + ret[0] = '0'; \ + return ret; \ + } \ + /* generate binary representation iteratively */ \ + gdv_u##IN_TYPE i; \ + int8_t count = 0; \ + bool first = false; /* flag for not printing left zeros in positive numbers */ \ + for (i = static_cast<gdv_u##IN_TYPE>(1) << (len - 1); i > 0; i = i / 2) { \ + if ((value & i) != 0) { \ + ret[count] = '1'; \ + if (!first) first = true; \ + } else { \ + if (!first) continue; \ + ret[count] = '0'; \ + } \ + count += 1; \ + } \ + *out_len = count; \ + return ret; \ + } + +BIN_INTEGER(int32) +BIN_INTEGER(int64) + +#undef 
BIN_INTEGER + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops_test.cc b/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops_test.cc new file mode 100644 index 000000000..147b4035c --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/extended_math_ops_test.cc @@ -0,0 +1,349 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +#include <gtest/gtest.h> +#include <cmath> +#include "gandiva/execution_context.h" +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +static const double MAX_ERROR = 0.00005; + +void VerifyFuzzyEquals(double actual, double expected, double max_error = MAX_ERROR) { + EXPECT_TRUE(fabs(actual - expected) < max_error) << actual << " != " << expected; +} + +TEST(TestExtendedMathOps, TestCbrt) { + VerifyFuzzyEquals(cbrt_int32(27), 3); + VerifyFuzzyEquals(cbrt_int64(27), 3); + VerifyFuzzyEquals(cbrt_float32(27), 3); + VerifyFuzzyEquals(cbrt_float64(27), 3); + VerifyFuzzyEquals(cbrt_float64(-27), -3); + + VerifyFuzzyEquals(cbrt_float32(15.625), 2.5); + VerifyFuzzyEquals(cbrt_float64(15.625), 2.5); +} + +TEST(TestExtendedMathOps, TestExp) { + double val = 20.085536923187668; + + VerifyFuzzyEquals(exp_int32(3), val); + VerifyFuzzyEquals(exp_int64(3), val); + VerifyFuzzyEquals(exp_float32(3), val); + VerifyFuzzyEquals(exp_float64(3), val); +} + +TEST(TestExtendedMathOps, TestLog) { + double val = 4.1588830833596715; + + VerifyFuzzyEquals(log_int32(64), val); + VerifyFuzzyEquals(log_int64(64), val); + VerifyFuzzyEquals(log_float32(64), val); + VerifyFuzzyEquals(log_float64(64), val); + + EXPECT_EQ(log_int32(0), -std::numeric_limits<double>::infinity()); +} + +TEST(TestExtendedMathOps, TestLog10) { + VerifyFuzzyEquals(log10_int32(100), 2); + VerifyFuzzyEquals(log10_int64(100), 2); + VerifyFuzzyEquals(log10_float32(100), 2); + VerifyFuzzyEquals(log10_float64(100), 2); +} + +TEST(TestExtendedMathOps, TestPower) { + VerifyFuzzyEquals(power_float64_float64(2, 5.4), 42.22425314473263); + VerifyFuzzyEquals(power_float64_float64(5.4, 2), 29.160000000000004); +} + +TEST(TestExtendedMathOps, TestLogWithBase) { + gandiva::ExecutionContext context; + gdv_float64 out = + log_int32_int32(reinterpret_cast<gdv_int64>(&context), 1 /*base*/, 10 /*value*/); + VerifyFuzzyEquals(out, 0); + EXPECT_EQ(context.has_error(), 
true); + EXPECT_TRUE(context.get_error().find("divide by zero error") != std::string::npos) + << context.get_error(); + + gandiva::ExecutionContext context1; + out = log_int32_int32(reinterpret_cast<gdv_int64>(&context), 2 /*base*/, 64 /*value*/); + VerifyFuzzyEquals(out, 6); + EXPECT_EQ(context1.has_error(), false); +} + +TEST(TestExtendedMathOps, TestRoundDecimal) { + EXPECT_FLOAT_EQ(round_float32(1234.245f), 1234); + EXPECT_FLOAT_EQ(round_float32(-11.7892f), -12); + EXPECT_FLOAT_EQ(round_float32(1.4999999f), 1); + EXPECT_EQ(std::signbit(round_float32(0)), 0); + EXPECT_FLOAT_EQ(round_float32_int32(1234.789f, 2), 1234.79f); + EXPECT_FLOAT_EQ(round_float32_int32(1234.12345f, -3), 1000); + EXPECT_FLOAT_EQ(round_float32_int32(-1234.4567f, 3), -1234.457f); + EXPECT_FLOAT_EQ(round_float32_int32(-1234.4567f, -3), -1000); + EXPECT_FLOAT_EQ(round_float32_int32(1234.4567f, 0), 1234); + EXPECT_FLOAT_EQ(round_float32_int32(1.5499999523162842f, 1), 1.5f); + EXPECT_EQ(std::signbit(round_float32_int32(0, 5)), 0); + EXPECT_FLOAT_EQ(round_float32_int32(static_cast<float>(1.55), 1), 1.5f); + EXPECT_FLOAT_EQ(round_float32_int32(static_cast<float>(9.134123), 2), 9.13f); + EXPECT_FLOAT_EQ(round_float32_int32(static_cast<float>(-1.923), 1), -1.9f); + + VerifyFuzzyEquals(round_float64(1234.245), 1234); + VerifyFuzzyEquals(round_float64(-11.7892), -12); + VerifyFuzzyEquals(round_float64(1.4999999), 1); + EXPECT_EQ(std::signbit(round_float64(0)), 0); + VerifyFuzzyEquals(round_float64_int32(1234.789, 2), 1234.79); + VerifyFuzzyEquals(round_float64_int32(1234.12345, -3), 1000); + VerifyFuzzyEquals(round_float64_int32(-1234.4567, 3), -1234.457); + VerifyFuzzyEquals(round_float64_int32(-1234.4567, -3), -1000); + VerifyFuzzyEquals(round_float64_int32(1234.4567, 0), 1234); + EXPECT_EQ(std::signbit(round_float64_int32(0, -2)), 0); + VerifyFuzzyEquals(round_float64_int32((double)INT_MAX + 1, 0), (double)INT_MAX + 1); + VerifyFuzzyEquals(round_float64_int32((double)INT_MIN - 1, 0), 
(double)INT_MIN - 1); +} + +TEST(TestExtendedMathOps, TestRound) { + EXPECT_EQ(round_int32(21134), 21134); + EXPECT_EQ(round_int32(-132422), -132422); + EXPECT_EQ(round_int32_int32(7589, -1), 7590); + EXPECT_EQ(round_int32_int32(8532, -2), 8500); + EXPECT_EQ(round_int32_int32(-8579, -1), -8580); + EXPECT_EQ(round_int32_int32(-8612, -2), -8600); + EXPECT_EQ(round_int32_int32(758, 2), 758); + EXPECT_EQ(round_int32_int32(8612, -5), 0); + + EXPECT_EQ(round_int64(3453562312), 3453562312); + EXPECT_EQ(round_int64(-23453462343), -23453462343); + EXPECT_EQ(round_int64_int32(3453562312, -2), 3453562300); + EXPECT_EQ(round_int64_int32(3453562343, -5), 3453600000); + EXPECT_EQ(round_int64_int32(345353425343, 12), 345353425343); + EXPECT_EQ(round_int64_int32(-23453462343, -4), -23453460000); + EXPECT_EQ(round_int64_int32(-23453462343, -5), -23453500000); + EXPECT_EQ(round_int64_int32(345353425343, -12), 0); +} + +TEST(TestExtendedMathOps, TestTruncate) { + EXPECT_EQ(truncate_int64_int32(1234, 4), 1234); + EXPECT_EQ(truncate_int64_int32(-1234, 4), -1234); + EXPECT_EQ(truncate_int64_int32(1234, -4), 0); + EXPECT_EQ(truncate_int64_int32(-1234, -2), -1200); + EXPECT_EQ(truncate_int64_int32(8124674407369523212, 0), 8124674407369523212); + EXPECT_EQ(truncate_int64_int32(8124674407369523212, -2), 8124674407369523200); +} + +TEST(TestExtendedMathOps, TestTrigonometricFunctions) { + auto pi_float = static_cast<float>(M_PI); + // Sin functions + VerifyFuzzyEquals(sin_float32(0), sin(0)); + VerifyFuzzyEquals(sin_float32(0), sin(0)); + VerifyFuzzyEquals(sin_float32(pi_float / 2), sin(M_PI / 2)); + VerifyFuzzyEquals(sin_float32(pi_float), sin(M_PI)); + VerifyFuzzyEquals(sin_float32(-pi_float / 2), sin(-M_PI / 2)); + VerifyFuzzyEquals(sin_float64(0), sin(0)); + VerifyFuzzyEquals(sin_float64(M_PI / 2), sin(M_PI / 2)); + VerifyFuzzyEquals(sin_float64(M_PI), sin(M_PI)); + VerifyFuzzyEquals(sin_float64(-M_PI / 2), sin(-M_PI / 2)); + VerifyFuzzyEquals(sin_int32(0), sin(0)); + 
VerifyFuzzyEquals(sin_int64(0), sin(0)); + + // Cos functions + VerifyFuzzyEquals(cos_float32(0), cos(0)); + VerifyFuzzyEquals(cos_float32(pi_float / 2), cos(M_PI / 2)); + VerifyFuzzyEquals(cos_float32(pi_float), cos(M_PI)); + VerifyFuzzyEquals(cos_float32(-pi_float / 2), cos(-M_PI / 2)); + VerifyFuzzyEquals(cos_float64(0), cos(0)); + VerifyFuzzyEquals(cos_float64(M_PI / 2), cos(M_PI / 2)); + VerifyFuzzyEquals(cos_float64(M_PI), cos(M_PI)); + VerifyFuzzyEquals(cos_float64(-M_PI / 2), cos(-M_PI / 2)); + VerifyFuzzyEquals(cos_int32(0), cos(0)); + VerifyFuzzyEquals(cos_int64(0), cos(0)); + + // Asin functions + VerifyFuzzyEquals(asin_float32(-1.0), asin(-1.0)); + VerifyFuzzyEquals(asin_float32(1.0), asin(1.0)); + VerifyFuzzyEquals(asin_float64(-1.0), asin(-1.0)); + VerifyFuzzyEquals(asin_float64(1.0), asin(1.0)); + VerifyFuzzyEquals(asin_int32(0), asin(0)); + VerifyFuzzyEquals(asin_int64(0), asin(0)); + + // Acos functions + VerifyFuzzyEquals(acos_float32(-1.0), acos(-1.0)); + VerifyFuzzyEquals(acos_float32(1.0), acos(1.0)); + VerifyFuzzyEquals(acos_float64(-1.0), acos(-1.0)); + VerifyFuzzyEquals(acos_float64(1.0), acos(1.0)); + VerifyFuzzyEquals(acos_int32(0), acos(0)); + VerifyFuzzyEquals(acos_int64(0), acos(0)); + + // Tan + VerifyFuzzyEquals(tan_float32(pi_float), tan(M_PI)); + VerifyFuzzyEquals(tan_float32(-pi_float), tan(-M_PI)); + VerifyFuzzyEquals(tan_float64(M_PI), tan(M_PI)); + VerifyFuzzyEquals(tan_float64(-M_PI), tan(-M_PI)); + VerifyFuzzyEquals(tan_int32(0), tan(0)); + VerifyFuzzyEquals(tan_int64(0), tan(0)); + + // Atan + VerifyFuzzyEquals(atan_float32(pi_float), atan(M_PI)); + VerifyFuzzyEquals(atan_float32(-pi_float), atan(-M_PI)); + VerifyFuzzyEquals(atan_float64(M_PI), atan(M_PI)); + VerifyFuzzyEquals(atan_float64(-M_PI), atan(-M_PI)); + VerifyFuzzyEquals(atan_int32(0), atan(0)); + VerifyFuzzyEquals(atan_int64(0), atan(0)); + + // Sinh functions + VerifyFuzzyEquals(sinh_float32(0), sinh(0)); + VerifyFuzzyEquals(sinh_float32(pi_float / 2), sinh(M_PI / 
2)); + VerifyFuzzyEquals(sinh_float32(pi_float), sinh(M_PI)); + VerifyFuzzyEquals(sinh_float32(-pi_float / 2), sinh(-M_PI / 2)); + VerifyFuzzyEquals(sinh_float64(0), sinh(0)); + VerifyFuzzyEquals(sinh_float64(M_PI / 2), sinh(M_PI / 2)); + VerifyFuzzyEquals(sinh_float64(M_PI), sinh(M_PI)); + VerifyFuzzyEquals(sinh_float64(-M_PI / 2), sinh(-M_PI / 2)); + VerifyFuzzyEquals(sinh_int32(0), sinh(0)); + VerifyFuzzyEquals(sinh_int64(0), sinh(0)); + + // Cosh functions + VerifyFuzzyEquals(cosh_float32(0), cosh(0)); + VerifyFuzzyEquals(cosh_float32(pi_float / 2), cosh(M_PI / 2)); + VerifyFuzzyEquals(cosh_float32(pi_float), cosh(M_PI)); + VerifyFuzzyEquals(cosh_float32(-pi_float / 2), cosh(-M_PI / 2)); + VerifyFuzzyEquals(cosh_float64(0), cosh(0)); + VerifyFuzzyEquals(cosh_float64(M_PI / 2), cosh(M_PI / 2)); + VerifyFuzzyEquals(cosh_float64(M_PI), cosh(M_PI)); + VerifyFuzzyEquals(cosh_float64(-M_PI / 2), cosh(-M_PI / 2)); + VerifyFuzzyEquals(cosh_int32(0), cosh(0)); + VerifyFuzzyEquals(cosh_int64(0), cosh(0)); + + // Tanh + VerifyFuzzyEquals(tanh_float32(pi_float), tanh(M_PI)); + VerifyFuzzyEquals(tanh_float32(-pi_float), tanh(-M_PI)); + VerifyFuzzyEquals(tanh_float64(M_PI), tanh(M_PI)); + VerifyFuzzyEquals(tanh_float64(-M_PI), tanh(-M_PI)); + VerifyFuzzyEquals(tanh_int32(0), tanh(0)); + VerifyFuzzyEquals(tanh_int64(0), tanh(0)); + + // Atan2 + VerifyFuzzyEquals(atan2_float32_float32(1, 0), atan2(1, 0)); + VerifyFuzzyEquals(atan2_float32_float32(-1.0, 0), atan2(-1, 0)); + VerifyFuzzyEquals(atan2_float64_float64(1.0, 0.0), atan2(1, 0)); + VerifyFuzzyEquals(atan2_float64_float64(-1, 0), atan2(-1, 0)); + VerifyFuzzyEquals(atan2_int32_int32(1, 0), atan2(1, 0)); + VerifyFuzzyEquals(atan2_int64_int64(-1, 0), atan2(-1, 0)); + + // Radians + VerifyFuzzyEquals(radians_float32(0), 0); + VerifyFuzzyEquals(radians_float32(180.0), M_PI); + VerifyFuzzyEquals(radians_float32(90.0), M_PI / 2); + VerifyFuzzyEquals(radians_float64(0), 0); + VerifyFuzzyEquals(radians_float64(180.0), M_PI); + 
VerifyFuzzyEquals(radians_float64(90.0), M_PI / 2); + VerifyFuzzyEquals(radians_int32(180), M_PI); + VerifyFuzzyEquals(radians_int64(90), M_PI / 2); + + // Degrees + VerifyFuzzyEquals(degrees_float32(0), 0.0); + VerifyFuzzyEquals(degrees_float32(pi_float), 180.0); + VerifyFuzzyEquals(degrees_float32(pi_float / 2), 90.0); + VerifyFuzzyEquals(degrees_float64(0), 0.0); + VerifyFuzzyEquals(degrees_float64(M_PI), 180.0); + VerifyFuzzyEquals(degrees_float64(M_PI / 2), 90.0); + VerifyFuzzyEquals(degrees_int32(1), 57.2958); + VerifyFuzzyEquals(degrees_int64(1), 57.2958); + + // Cot + VerifyFuzzyEquals(cot_float32(pi_float / 2), tan(M_PI / 2 - M_PI / 2)); + VerifyFuzzyEquals(cot_float64(M_PI / 2), tan(M_PI / 2 - M_PI / 2)); +} + +TEST(TestExtendedMathOps, TestBinRepresentation) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str = bin_int32(ctx_ptr, 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "111"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int32(ctx_ptr, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "0"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int32(ctx_ptr, 28550, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "110111110000110"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int32(ctx_ptr, -28550, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "11111111111111111001000001111010"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int32(ctx_ptr, 58117, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1110001100000101"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int32(ctx_ptr, -58117, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "11111111111111110001110011111011"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int32(ctx_ptr, INT32_MAX, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1111111111111111111111111111111"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int32(ctx_ptr, 
INT32_MIN, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "10000000000000000000000000000000"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int64(ctx_ptr, 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "111"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int64(ctx_ptr, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "0"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int64(ctx_ptr, 28550, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "110111110000110"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int64(ctx_ptr, -28550, &out_len); + EXPECT_EQ(std::string(out_str, out_len), + "1111111111111111111111111111111111111111111111111001000001111010"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int64(ctx_ptr, 58117, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1110001100000101"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int64(ctx_ptr, -58117, &out_len); + EXPECT_EQ(std::string(out_str, out_len), + "1111111111111111111111111111111111111111111111110001110011111011"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int64(ctx_ptr, INT64_MAX, &out_len); + EXPECT_EQ(std::string(out_str, out_len), + "111111111111111111111111111111111111111111111111111111111111111"); + EXPECT_FALSE(ctx.has_error()); + + out_str = bin_int64(ctx_ptr, INT64_MIN, &out_len); + EXPECT_EQ(std::string(out_str, out_len), + "1000000000000000000000000000000000000000000000000000000000000000"); + EXPECT_FALSE(ctx.has_error()); +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/hash.cc b/src/arrow/cpp/src/gandiva/precompiled/hash.cc new file mode 100644 index 000000000..eacf36230 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/hash.cc @@ -0,0 +1,407 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern "C" { + +#include <string.h> + +#include "./types.h" + +static inline gdv_uint64 rotate_left(gdv_uint64 val, int distance) { + return (val << distance) | (val >> (64 - distance)); +} + +// +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. +// See http://smhasher.googlecode.com/svn/trunk/MurmurHash3.cpp +// MurmurHash3_x64_128 +// +static inline gdv_uint64 fmix64(gdv_uint64 k) { + k ^= k >> 33; + k *= 0xff51afd7ed558ccduLL; + k ^= k >> 33; + k *= 0xc4ceb9fe1a85ec53uLL; + k ^= k >> 33; + return k; +} + +static inline gdv_uint64 murmur3_64(gdv_uint64 val, gdv_int32 seed) { + gdv_uint64 h1 = seed; + gdv_uint64 h2 = seed; + + gdv_uint64 c1 = 0x87c37b91114253d5ull; + gdv_uint64 c2 = 0x4cf5ad432745937full; + + int length = 8; + gdv_uint64 k1 = 0; + + k1 = val; + k1 *= c1; + k1 = rotate_left(k1, 31); + k1 *= c2; + h1 ^= k1; + + h1 ^= length; + h2 ^= length; + + h1 += h2; + h2 += h1; + + h1 = fmix64(h1); + h2 = fmix64(h2); + + h1 += h2; + + // h2 += h1; + // murmur3_128 should return 128 bit (h1,h2), now we return only 64bits, + return h1; +} + +static inline gdv_uint32 murmur3_32(gdv_uint64 val, gdv_int32 seed) { + gdv_uint64 c1 = 0xcc9e2d51ull; + gdv_uint64 c2 = 0x1b873593ull; + int length = 8; + static gdv_uint64 UINT_MASK = 0xffffffffull; + gdv_uint64 lh1 = seed & UINT_MASK; + for (int i = 0; i < 2; i++) { + gdv_uint64 lk1 = ((val >> i * 32) & 
UINT_MASK); + lk1 *= c1; + lk1 &= UINT_MASK; + + lk1 = ((lk1 << 15) & UINT_MASK) | (lk1 >> 17); + + lk1 *= c2; + lk1 &= UINT_MASK; + + lh1 ^= lk1; + lh1 = ((lh1 << 13) & UINT_MASK) | (lh1 >> 19); + + lh1 = lh1 * 5 + 0xe6546b64L; + lh1 = UINT_MASK & lh1; + } + lh1 ^= length; + + lh1 ^= lh1 >> 16; + lh1 *= 0x85ebca6bull; + lh1 = UINT_MASK & lh1; + lh1 ^= lh1 >> 13; + lh1 *= 0xc2b2ae35ull; + lh1 = UINT_MASK & lh1; + lh1 ^= lh1 >> 16; + + return static_cast<gdv_uint32>(lh1); +} + +static inline gdv_uint64 double_to_long_bits(double value) { + gdv_uint64 result; + memcpy(&result, &value, sizeof(result)); + return result; +} + +FORCE_INLINE gdv_int64 hash64(double val, gdv_int64 seed) { + return murmur3_64(double_to_long_bits(val), static_cast<gdv_int32>(seed)); +} + +FORCE_INLINE gdv_int32 hash32(double val, gdv_int32 seed) { + return murmur3_32(double_to_long_bits(val), seed); +} + +// Wrappers for all the numeric/data/time arrow types + +#define HASH64_WITH_SEED_OP(NAME, TYPE) \ + FORCE_INLINE \ + gdv_int64 NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid, gdv_int64 seed, \ + gdv_boolean seed_isvalid) { \ + if (!is_valid) { \ + return seed; \ + } \ + return hash64(static_cast<double>(in), seed); \ + } + +#define HASH32_WITH_SEED_OP(NAME, TYPE) \ + FORCE_INLINE \ + gdv_int32 NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid, gdv_int32 seed, \ + gdv_boolean seed_isvalid) { \ + if (!is_valid) { \ + return seed; \ + } \ + return hash32(static_cast<double>(in), seed); \ + } + +#define HASH64_OP(NAME, TYPE) \ + FORCE_INLINE \ + gdv_int64 NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid) { \ + return is_valid ? hash64(static_cast<double>(in), 0) : 0; \ + } + +#define HASH32_OP(NAME, TYPE) \ + FORCE_INLINE \ + gdv_int32 NAME##_##TYPE(gdv_##TYPE in, gdv_boolean is_valid) { \ + return is_valid ? hash32(static_cast<double>(in), 0) : 0; \ + } + +// Expand inner macro for all numeric types. 
+#define NUMERIC_BOOL_DATE_TYPES(INNER, NAME) \ + INNER(NAME, int8) \ + INNER(NAME, int16) \ + INNER(NAME, int32) \ + INNER(NAME, int64) \ + INNER(NAME, uint8) \ + INNER(NAME, uint16) \ + INNER(NAME, uint32) \ + INNER(NAME, uint64) \ + INNER(NAME, float32) \ + INNER(NAME, float64) \ + INNER(NAME, boolean) \ + INNER(NAME, date64) \ + INNER(NAME, date32) \ + INNER(NAME, time32) \ + INNER(NAME, timestamp) + +NUMERIC_BOOL_DATE_TYPES(HASH32_OP, hash) +NUMERIC_BOOL_DATE_TYPES(HASH32_OP, hash32) +NUMERIC_BOOL_DATE_TYPES(HASH32_OP, hash32AsDouble) +NUMERIC_BOOL_DATE_TYPES(HASH32_WITH_SEED_OP, hash32WithSeed) +NUMERIC_BOOL_DATE_TYPES(HASH32_WITH_SEED_OP, hash32AsDoubleWithSeed) + +NUMERIC_BOOL_DATE_TYPES(HASH64_OP, hash64) +NUMERIC_BOOL_DATE_TYPES(HASH64_OP, hash64AsDouble) +NUMERIC_BOOL_DATE_TYPES(HASH64_WITH_SEED_OP, hash64WithSeed) +NUMERIC_BOOL_DATE_TYPES(HASH64_WITH_SEED_OP, hash64AsDoubleWithSeed) + +#undef NUMERIC_BOOL_DATE_TYPES + +static inline gdv_uint64 murmur3_64_buf(const gdv_uint8* key, gdv_int32 len, + gdv_int32 seed) { + gdv_uint64 h1 = seed; + gdv_uint64 h2 = seed; + gdv_uint64 c1 = 0x87c37b91114253d5ull; + gdv_uint64 c2 = 0x4cf5ad432745937full; + + const gdv_uint64* blocks = reinterpret_cast<const gdv_uint64*>(key); + int nblocks = len / 16; + for (int i = 0; i < nblocks; i++) { + gdv_uint64 k1 = blocks[i * 2 + 0]; + gdv_uint64 k2 = blocks[i * 2 + 1]; + + k1 *= c1; + k1 = rotate_left(k1, 31); + k1 *= c2; + h1 ^= k1; + h1 = rotate_left(h1, 27); + h1 += h2; + h1 = h1 * 5 + 0x52dce729; + k2 *= c2; + k2 = rotate_left(k2, 33); + k2 *= c1; + h2 ^= k2; + h2 = rotate_left(h2, 31); + h2 += h1; + h2 = h2 * 5 + 0x38495ab5; + } + + // tail + gdv_uint64 k1 = 0; + gdv_uint64 k2 = 0; + + const gdv_uint8* tail = reinterpret_cast<const gdv_uint8*>(key + nblocks * 16); + switch (len & 15) { + case 15: + k2 = static_cast<gdv_uint64>(tail[14]) << 48; + case 14: + k2 ^= static_cast<gdv_uint64>(tail[13]) << 40; + case 13: + k2 ^= static_cast<gdv_uint64>(tail[12]) << 32; + case 
12: + k2 ^= static_cast<gdv_uint64>(tail[11]) << 24; + case 11: + k2 ^= static_cast<gdv_uint64>(tail[10]) << 16; + case 10: + k2 ^= static_cast<gdv_uint64>(tail[9]) << 8; + case 9: + k2 ^= static_cast<gdv_uint64>(tail[8]); + k2 *= c2; + k2 = rotate_left(k2, 33); + k2 *= c1; + h2 ^= k2; + case 8: + k1 ^= static_cast<gdv_uint64>(tail[7]) << 56; + case 7: + k1 ^= static_cast<gdv_uint64>(tail[6]) << 48; + case 6: + k1 ^= static_cast<gdv_uint64>(tail[5]) << 40; + case 5: + k1 ^= static_cast<gdv_uint64>(tail[4]) << 32; + case 4: + k1 ^= static_cast<gdv_uint64>(tail[3]) << 24; + case 3: + k1 ^= static_cast<gdv_uint64>(tail[2]) << 16; + case 2: + k1 ^= static_cast<gdv_uint64>(tail[1]) << 8; + case 1: + k1 ^= static_cast<gdv_uint64>(tail[0]) << 0; + k1 *= c1; + k1 = rotate_left(k1, 31); + k1 *= c2; + h1 ^= k1; + } + + h1 ^= len; + h2 ^= len; + + h1 += h2; + h2 += h1; + + h1 = fmix64(h1); + h2 = fmix64(h2); + + h1 += h2; + // h2 += h1; + // returning 64-bits of the 128-bit hash. + return h1; +} + +static gdv_uint32 murmur3_32_buf(const gdv_uint8* key, gdv_int32 len, gdv_int32 seed) { + static const gdv_uint64 c1 = 0xcc9e2d51ull; + static const gdv_uint64 c2 = 0x1b873593ull; + static const gdv_uint64 UINT_MASK = 0xffffffffull; + gdv_uint64 lh1 = seed; + const gdv_uint32* blocks = reinterpret_cast<const gdv_uint32*>(key); + int nblocks = len / 4; + const gdv_uint8* tail = reinterpret_cast<const gdv_uint8*>(key + nblocks * 4); + for (int i = 0; i < nblocks; i++) { + gdv_uint64 lk1 = static_cast<gdv_uint64>(blocks[i]); + + // k1 *= c1; + lk1 *= c1; + lk1 &= UINT_MASK; + + lk1 = ((lk1 << 15) & UINT_MASK) | (lk1 >> 17); + + lk1 *= c2; + lk1 = lk1 & UINT_MASK; + lh1 ^= lk1; + lh1 = ((lh1 << 13) & UINT_MASK) | (lh1 >> 19); + + lh1 = lh1 * 5 + 0xe6546b64ull; + lh1 = UINT_MASK & lh1; + } + + // tail + gdv_uint64 lk1 = 0; + + switch (len & 3) { + case 3: + lk1 = (tail[2] & 0xff) << 16; + case 2: + lk1 |= (tail[1] & 0xff) << 8; + case 1: + lk1 |= (tail[0] & 0xff); + lk1 *= c1; + lk1 = 
UINT_MASK & lk1; + lk1 = ((lk1 << 15) & UINT_MASK) | (lk1 >> 17); + + lk1 *= c2; + lk1 = lk1 & UINT_MASK; + + lh1 ^= lk1; + } + + // finalization + lh1 ^= len; + + lh1 ^= lh1 >> 16; + lh1 *= 0x85ebca6b; + lh1 = UINT_MASK & lh1; + lh1 ^= lh1 >> 13; + + lh1 *= 0xc2b2ae35; + lh1 = UINT_MASK & lh1; + lh1 ^= lh1 >> 16; + + return static_cast<gdv_uint32>(lh1 & UINT_MASK); +} + +FORCE_INLINE gdv_int64 hash64_buf(const gdv_uint8* buf, int len, gdv_int64 seed) { + return murmur3_64_buf(buf, len, static_cast<gdv_int32>(seed)); +} + +FORCE_INLINE gdv_int32 hash32_buf(const gdv_uint8* buf, int len, gdv_int32 seed) { + return murmur3_32_buf(buf, len, seed); +} + +// Wrappers for the varlen types + +#define HASH64_BUF_WITH_SEED_OP(NAME, TYPE) \ + FORCE_INLINE \ + gdv_int64 NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid, \ + gdv_int64 seed, gdv_boolean seed_isvalid) { \ + if (!is_valid) { \ + return seed; \ + } \ + return hash64_buf(reinterpret_cast<const uint8_t*>(in), len, seed); \ + } + +#define HASH32_BUF_WITH_SEED_OP(NAME, TYPE) \ + FORCE_INLINE \ + gdv_int32 NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid, \ + gdv_int32 seed, gdv_boolean seed_isvalid) { \ + if (!is_valid) { \ + return seed; \ + } \ + return hash32_buf(reinterpret_cast<const uint8_t*>(in), len, seed); \ + } + +#define HASH64_BUF_OP(NAME, TYPE) \ + FORCE_INLINE \ + gdv_int64 NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid) { \ + return is_valid ? hash64_buf(reinterpret_cast<const uint8_t*>(in), len, 0) : 0; \ + } + +#define HASH32_BUF_OP(NAME, TYPE) \ + FORCE_INLINE \ + gdv_int32 NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid) { \ + return is_valid ? hash32_buf(reinterpret_cast<const uint8_t*>(in), len, 0) : 0; \ + } + +// Expand inner macro for all non-numeric types. 
+#define VAR_LEN_TYPES(INNER, NAME) \ + INNER(NAME, utf8) \ + INNER(NAME, binary) + +VAR_LEN_TYPES(HASH32_BUF_OP, hash) +VAR_LEN_TYPES(HASH32_BUF_OP, hash32) +VAR_LEN_TYPES(HASH32_BUF_OP, hash32AsDouble) +VAR_LEN_TYPES(HASH32_BUF_WITH_SEED_OP, hash32WithSeed) +VAR_LEN_TYPES(HASH32_BUF_WITH_SEED_OP, hash32AsDoubleWithSeed) + +VAR_LEN_TYPES(HASH64_BUF_OP, hash64) +VAR_LEN_TYPES(HASH64_BUF_OP, hash64AsDouble) +VAR_LEN_TYPES(HASH64_BUF_WITH_SEED_OP, hash64WithSeed) +VAR_LEN_TYPES(HASH64_BUF_WITH_SEED_OP, hash64AsDoubleWithSeed) + +#undef HASH32_BUF_OP +#undef HASH32_BUF_WITH_SEED_OP +#undef HASH32_OP +#undef HASH32_WITH_SEED_OP +#undef HASH64_BUF_OP +#undef HASH64_BUF_WITH_SEED_OP +#undef HASH64_OP +#undef HASH64_WITH_SEED_OP + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled/hash_test.cc b/src/arrow/cpp/src/gandiva/precompiled/hash_test.cc new file mode 100644 index 000000000..0a51dced2 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/hash_test.cc @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <time.h> + +#include <gtest/gtest.h> +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +TEST(TestHash, TestHash32) { + gdv_int8 s8 = 0; + gdv_uint8 u8 = 0; + gdv_int16 s16 = 0; + gdv_uint16 u16 = 0; + gdv_int32 s32 = 0; + gdv_uint32 u32 = 0; + gdv_int64 s64 = 0; + gdv_uint64 u64 = 0; + gdv_float32 f32 = 0; + gdv_float64 f64 = 0; + + // hash of 0 should be non-zero (zero is the hash value for nulls). + gdv_int32 zero_hash = hash32(s8, 0); + EXPECT_NE(zero_hash, 0); + + // for a given value, all numeric types must have the same hash. + EXPECT_EQ(hash32(u8, 0), zero_hash); + EXPECT_EQ(hash32(s16, 0), zero_hash); + EXPECT_EQ(hash32(u16, 0), zero_hash); + EXPECT_EQ(hash32(s32, 0), zero_hash); + EXPECT_EQ(hash32(u32, 0), zero_hash); + EXPECT_EQ(hash32(static_cast<double>(s64), 0), zero_hash); + EXPECT_EQ(hash32(static_cast<double>(u64), 0), zero_hash); + EXPECT_EQ(hash32(f32, 0), zero_hash); + EXPECT_EQ(hash32(f64, 0), zero_hash); + + // hash must change with a change in seed. + EXPECT_NE(hash32(s8, 1), zero_hash); + + // for a given value and seed, all numeric types must have the same hash. + EXPECT_EQ(hash32(s8, 1), hash32(s16, 1)); + EXPECT_EQ(hash32(s8, 1), hash32(u32, 1)); + EXPECT_EQ(hash32(s8, 1), hash32(f32, 1)); + EXPECT_EQ(hash32(s8, 1), hash32(f64, 1)); +} + +TEST(TestHash, TestHash64) { + gdv_int8 s8 = 0; + gdv_uint8 u8 = 0; + gdv_int16 s16 = 0; + gdv_uint16 u16 = 0; + gdv_int32 s32 = 0; + gdv_uint32 u32 = 0; + gdv_int64 s64 = 0; + gdv_uint64 u64 = 0; + gdv_float32 f32 = 0; + gdv_float64 f64 = 0; + + // hash of 0 should be non-zero (zero is the hash value for nulls). + gdv_int64 zero_hash = hash64(s8, 0); + EXPECT_NE(zero_hash, 0); + EXPECT_NE(hash64(u8, 0), hash32(u8, 0)); + + // for a given value, all numeric types must have the same hash. 
+ EXPECT_EQ(hash64(u8, 0), zero_hash); + EXPECT_EQ(hash64(s16, 0), zero_hash); + EXPECT_EQ(hash64(u16, 0), zero_hash); + EXPECT_EQ(hash64(s32, 0), zero_hash); + EXPECT_EQ(hash64(u32, 0), zero_hash); + EXPECT_EQ(hash64(static_cast<double>(s64), 0), zero_hash); + EXPECT_EQ(hash64(static_cast<double>(u64), 0), zero_hash); + EXPECT_EQ(hash64(f32, 0), zero_hash); + EXPECT_EQ(hash64(f64, 0), zero_hash); + + // hash must change with a change in seed. + EXPECT_NE(hash64(s8, 1), zero_hash); + + // for a given value and seed, all numeric types must have the same hash. + EXPECT_EQ(hash64(s8, 1), hash64(s16, 1)); + EXPECT_EQ(hash64(s8, 1), hash64(u32, 1)); + EXPECT_EQ(hash64(s8, 1), hash64(f32, 1)); +} + +TEST(TestHash, TestHashBuf) { + const char* buf = "hello"; + int buf_len = 5; + + // hash should be non-zero (zero is the hash value for nulls). + EXPECT_NE(hash32_buf((const gdv_uint8*)buf, buf_len, 0), 0); + EXPECT_NE(hash64_buf((const gdv_uint8*)buf, buf_len, 0), 0); + + // hash must change if the string is changed. + EXPECT_NE(hash32_buf((const gdv_uint8*)buf, buf_len, 0), + hash32_buf((const gdv_uint8*)buf, buf_len - 1, 0)); + + EXPECT_NE(hash64_buf((const gdv_uint8*)buf, buf_len, 0), + hash64_buf((const gdv_uint8*)buf, buf_len - 1, 0)); + + // hash must change if the seed is changed. + EXPECT_NE(hash32_buf((const gdv_uint8*)buf, buf_len, 0), + hash32_buf((const gdv_uint8*)buf, buf_len, 1)); + + EXPECT_NE(hash64_buf((const gdv_uint8*)buf, buf_len, 0), + hash64_buf((const gdv_uint8*)buf, buf_len, 1)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/print.cc b/src/arrow/cpp/src/gandiva/precompiled/print.cc new file mode 100644 index 000000000..ecb90e1a3 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/print.cc @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern "C" { + +#include <stdio.h> + +#include "./types.h" + +int print_double(char* msg, double val) { return printf(msg, val); } + +int print_float(char* msg, float val) { return printf(msg, val); } + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled/string_ops.cc b/src/arrow/cpp/src/gandiva/precompiled/string_ops.cc new file mode 100644 index 000000000..48c24b862 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/string_ops.cc @@ -0,0 +1,2198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// String functions +#include "arrow/util/value_parsing.h" + +extern "C" { + +#include <algorithm> +#include <climits> +#include <cstdio> +#include <cstdlib> +#include <cstring> + +#include "./types.h" + +FORCE_INLINE +gdv_int32 octet_length_utf8(const gdv_utf8 input, gdv_int32 length) { return length; } + +FORCE_INLINE +gdv_int32 bit_length_utf8(const gdv_utf8 input, gdv_int32 length) { return length * 8; } + +FORCE_INLINE +gdv_int32 octet_length_binary(const gdv_binary input, gdv_int32 length) { return length; } + +FORCE_INLINE +gdv_int32 bit_length_binary(const gdv_binary input, gdv_int32 length) { + return length * 8; +} + +FORCE_INLINE +int match_string(const char* input, gdv_int32 input_len, gdv_int32 start_pos, + const char* delim, gdv_int32 delim_len) { + for (int i = start_pos; i < input_len; i++) { + int left_chars = input_len - i; + if ((left_chars >= delim_len) && memcmp(input + i, delim, delim_len) == 0) { + return i + delim_len; + } + } + + return -1; +} + +FORCE_INLINE +gdv_int32 mem_compare(const char* left, gdv_int32 left_len, const char* right, + gdv_int32 right_len) { + int min = left_len; + if (right_len < min) { + min = right_len; + } + + int cmp_ret = memcmp(left, right, min); + if (cmp_ret != 0) { + return cmp_ret; + } else { + return left_len - right_len; + } +} + +// Expand inner macro for all varlen types. +#define VAR_LEN_OP_TYPES(INNER, NAME, OP) \ + INNER(NAME, utf8, OP) \ + INNER(NAME, binary, OP) + +// Relational binary fns : left, right params are same, return is bool. 
+#define BINARY_RELATIONAL(NAME, TYPE, OP) \ + FORCE_INLINE \ + bool NAME##_##TYPE##_##TYPE(const gdv_##TYPE left, gdv_int32 left_len, \ + const gdv_##TYPE right, gdv_int32 right_len) { \ + return mem_compare(left, left_len, right, right_len) OP 0; \ + } + +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, equal, ==) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, not_equal, !=) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, less_than, <) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, less_than_or_equal_to, <=) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, greater_than, >) +VAR_LEN_OP_TYPES(BINARY_RELATIONAL, greater_than_or_equal_to, >=) + +#undef BINARY_RELATIONAL +#undef VAR_LEN_OP_TYPES + +// Expand inner macro for all varlen types. +#define VAR_LEN_TYPES(INNER, NAME) \ + INNER(NAME, utf8) \ + INNER(NAME, binary) + +FORCE_INLINE +int to_binary_from_hex(char ch) { + if (ch >= 'A' && ch <= 'F') { + return 10 + (ch - 'A'); + } else if (ch >= 'a' && ch <= 'f') { + return 10 + (ch - 'a'); + } + return ch - '0'; +} + +FORCE_INLINE +bool starts_with_utf8_utf8(const char* data, gdv_int32 data_len, const char* prefix, + gdv_int32 prefix_len) { + return ((data_len >= prefix_len) && (memcmp(data, prefix, prefix_len) == 0)); +} + +FORCE_INLINE +bool ends_with_utf8_utf8(const char* data, gdv_int32 data_len, const char* suffix, + gdv_int32 suffix_len) { + return ((data_len >= suffix_len) && + (memcmp(data + data_len - suffix_len, suffix, suffix_len) == 0)); +} + +FORCE_INLINE +bool is_substr_utf8_utf8(const char* data, int32_t data_len, const char* substr, + int32_t substr_len) { + for (int32_t i = 0; i <= data_len - substr_len; ++i) { + if (memcmp(data + i, substr, substr_len) == 0) { + return true; + } + } + return false; +} + +FORCE_INLINE +gdv_int32 utf8_char_length(char c) { + if ((signed char)c >= 0) { // 1-byte char (0x00 ~ 0x7F) + return 1; + } else if ((c & 0xE0) == 0xC0) { // 2-byte char + return 2; + } else if ((c & 0xF0) == 0xE0) { // 3-byte char + return 3; + } else if ((c & 0xF8) == 0xF0) { // 4-byte char + return 
4; + } + // invalid char + return 0; +} + +FORCE_INLINE +void set_error_for_invalid_utf(int64_t execution_context, char val) { + char const* fmt = "unexpected byte \\%02hhx encountered while decoding utf8 string"; + int size = static_cast<int>(strlen(fmt)) + 64; + char* error = reinterpret_cast<char*>(malloc(size)); + snprintf(error, size, fmt, (unsigned char)val); + gdv_fn_context_set_error_msg(execution_context, error); + free(error); +} + +FORCE_INLINE +bool validate_utf8_following_bytes(const char* data, int32_t data_len, + int32_t char_index) { + for (int j = 1; j < data_len; ++j) { + if ((data[char_index + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph + return false; + } + } + return true; +} + +// Count the number of utf8 characters +// return 0 for invalid/incomplete input byte sequences +FORCE_INLINE +gdv_int32 utf8_length(gdv_int64 context, const char* data, gdv_int32 data_len) { + int char_len = 0; + int count = 0; + for (int i = 0; i < data_len; i += char_len) { + char_len = utf8_char_length(data[i]); + if (char_len == 0 || i + char_len > data_len) { // invalid byte or incomplete glyph + set_error_for_invalid_utf(context, data[i]); + return 0; + } + for (int j = 1; j < char_len; ++j) { + if ((data[i + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph + set_error_for_invalid_utf(context, data[i + j]); + return 0; + } + } + ++count; + } + return count; +} + +// Count the number of utf8 characters, ignoring invalid char, considering size 1 +FORCE_INLINE +gdv_int32 utf8_length_ignore_invalid(const char* data, gdv_int32 data_len) { + int char_len = 0; + int count = 0; + for (int i = 0; i < data_len; i += char_len) { + char_len = utf8_char_length(data[i]); + if (char_len == 0 || i + char_len > data_len) { // invalid byte or incomplete glyph + // if invalid byte or incomplete glyph, ignore it + char_len = 1; + } + for (int j = 1; j < char_len; ++j) { + if ((data[i + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph + char_len 
+= 1; + } + } + ++count; + } + return count; +} + +// Get the byte position corresponding to a character position for a non-empty utf8 +// sequence +FORCE_INLINE +gdv_int32 utf8_byte_pos(gdv_int64 context, const char* str, gdv_int32 str_len, + gdv_int32 char_pos) { + int char_len = 0; + int byte_index = 0; + for (gdv_int32 char_index = 0; char_index < char_pos && byte_index < str_len; + char_index++) { + char_len = utf8_char_length(str[byte_index]); + if (char_len == 0 || + byte_index + char_len > str_len) { // invalid byte or incomplete glyph + set_error_for_invalid_utf(context, str[byte_index]); + return -1; + } + byte_index += char_len; + } + return byte_index; +} + +#define UTF8_LENGTH(NAME, TYPE) \ + FORCE_INLINE \ + gdv_int32 NAME##_##TYPE(gdv_int64 context, gdv_##TYPE in, gdv_int32 in_len) { \ + return utf8_length(context, in, in_len); \ + } + +UTF8_LENGTH(char_length, utf8) +UTF8_LENGTH(length, utf8) +UTF8_LENGTH(lengthUtf8, binary) + +// Returns a string of 'n' spaces. +#define SPACE_STR(IN_TYPE) \ + GANDIVA_EXPORT \ + const char* space_##IN_TYPE(gdv_int64 ctx, gdv_##IN_TYPE n, int32_t* out_len) { \ + gdv_int32 n_times = static_cast<gdv_int32>(n); \ + if (n_times <= 0) { \ + *out_len = 0; \ + return ""; \ + } \ + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(ctx, n_times)); \ + if (ret == nullptr) { \ + gdv_fn_context_set_error_msg(ctx, "Could not allocate memory for output string"); \ + *out_len = 0; \ + return ""; \ + } \ + for (int i = 0; i < n_times; i++) { \ + ret[i] = ' '; \ + } \ + *out_len = n_times; \ + return ret; \ + } + +SPACE_STR(int32) +SPACE_STR(int64) + +// Reverse a utf8 sequence +FORCE_INLINE +const char* reverse_utf8(gdv_int64 context, const char* data, gdv_int32 data_len, + int32_t* out_len) { + if (data_len == 0) { + *out_len = 0; + return ""; + } + + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, data_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate 
memory for output string"); + *out_len = 0; + return ""; + } + + gdv_int32 char_len; + for (gdv_int32 i = 0; i < data_len; i += char_len) { + char_len = utf8_char_length(data[i]); + + if (char_len == 0 || i + char_len > data_len) { // invalid byte or incomplete glyph + set_error_for_invalid_utf(context, data[i]); + *out_len = 0; + return ""; + } + + for (gdv_int32 j = 0; j < char_len; ++j) { + if (j > 0 && (data[i + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph + set_error_for_invalid_utf(context, data[i + j]); + *out_len = 0; + return ""; + } + ret[data_len - i - char_len + j] = data[i + j]; + } + } + *out_len = data_len; + return ret; +} + +// Trims whitespaces from the left end of the input utf8 sequence +FORCE_INLINE +const char* ltrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len, + int32_t* out_len) { + if (data_len == 0) { + *out_len = 0; + return ""; + } + + gdv_int32 start = 0; + // start denotes the first position of non-space characters in the input string + while (start < data_len && data[start] == ' ') { + ++start; + } + + *out_len = data_len - start; + return data + start; +} + +// Trims whitespaces from the right end of the input utf8 sequence +FORCE_INLINE +const char* rtrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len, + int32_t* out_len) { + if (data_len == 0) { + *out_len = 0; + return ""; + } + + gdv_int32 end = data_len - 1; + // end denotes the last position of non-space characters in the input string + while (end >= 0 && data[end] == ' ') { + --end; + } + + *out_len = end + 1; + return data; +} + +// Trims whitespaces from both the ends of the input utf8 sequence +FORCE_INLINE +const char* btrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len, + int32_t* out_len) { + if (data_len == 0) { + *out_len = 0; + return ""; + } + + gdv_int32 start = 0, end = data_len - 1; + // start and end denote the first and last positions of non-space + // characters in the input string respectively + 
while (start <= end && data[start] == ' ') { + ++start; + } + while (end >= start && data[end] == ' ') { + --end; + } + + // string has some leading/trailing spaces and some non-space characters + *out_len = end - start + 1; + return data + start; +} + +// Trims characters present in the trim text from the left end of the base text +FORCE_INLINE +const char* ltrim_utf8_utf8(gdv_int64 context, const char* basetext, + gdv_int32 basetext_len, const char* trimtext, + gdv_int32 trimtext_len, int32_t* out_len) { + if (basetext_len == 0) { + *out_len = 0; + return ""; + } else if (trimtext_len == 0) { + *out_len = basetext_len; + return basetext; + } + + gdv_int32 start_ptr, char_len; + // scan the base text from left to right and increment the start pointer till + // there is a character which is not present in the trim text + for (start_ptr = 0; start_ptr < basetext_len; start_ptr += char_len) { + char_len = utf8_char_length(basetext[start_ptr]); + if (char_len == 0 || start_ptr + char_len > basetext_len) { + // invalid byte or incomplete glyph + set_error_for_invalid_utf(context, basetext[start_ptr]); + *out_len = 0; + return ""; + } + if (!is_substr_utf8_utf8(trimtext, trimtext_len, basetext + start_ptr, char_len)) { + break; + } + } + + *out_len = basetext_len - start_ptr; + return basetext + start_ptr; +} + +// Trims characters present in the trim text from the right end of the base text +FORCE_INLINE +const char* rtrim_utf8_utf8(gdv_int64 context, const char* basetext, + gdv_int32 basetext_len, const char* trimtext, + gdv_int32 trimtext_len, int32_t* out_len) { + if (basetext_len == 0) { + *out_len = 0; + return ""; + } else if (trimtext_len == 0) { + *out_len = basetext_len; + return basetext; + } + + gdv_int32 char_len, end_ptr, byte_cnt = 1; + // scan the base text from right to left and decrement the end pointer till + // there is a character which is not present in the trim text + for (end_ptr = basetext_len - 1; end_ptr >= 0; --end_ptr) { + char_len = 
utf8_char_length(basetext[end_ptr]); + if (char_len == 0) { // trailing bytes of multibyte character + ++byte_cnt; + continue; + } + // this is the first byte of a character, hence check if char_len = char_cnt + if (byte_cnt != char_len) { // invalid byte or incomplete glyph + set_error_for_invalid_utf(context, basetext[end_ptr]); + *out_len = 0; + return ""; + } + byte_cnt = 1; // reset the counter*/ + if (!is_substr_utf8_utf8(trimtext, trimtext_len, basetext + end_ptr, char_len)) { + break; + } + } + + // when all characters in the basetext are part of the trimtext + if (end_ptr == -1) { + *out_len = 0; + return ""; + } + + end_ptr += utf8_char_length(basetext[end_ptr]); // point to the next character + *out_len = end_ptr; + return basetext; +} + +// Trims characters present in the trim text from both ends of the base text +FORCE_INLINE +const char* btrim_utf8_utf8(gdv_int64 context, const char* basetext, + gdv_int32 basetext_len, const char* trimtext, + gdv_int32 trimtext_len, int32_t* out_len) { + if (basetext_len == 0) { + *out_len = 0; + return ""; + } else if (trimtext_len == 0) { + *out_len = basetext_len; + return basetext; + } + + gdv_int32 start_ptr, end_ptr, char_len, byte_cnt = 1; + // scan the base text from left to right and increment the start and decrement the + // end pointers till there are characters which are not present in the trim text + for (start_ptr = 0; start_ptr < basetext_len; start_ptr += char_len) { + char_len = utf8_char_length(basetext[start_ptr]); + if (char_len == 0 || start_ptr + char_len > basetext_len) { + // invalid byte or incomplete glyph + set_error_for_invalid_utf(context, basetext[start_ptr]); + *out_len = 0; + return ""; + } + if (!is_substr_utf8_utf8(trimtext, trimtext_len, basetext + start_ptr, char_len)) { + break; + } + } + for (end_ptr = basetext_len - 1; end_ptr >= start_ptr; --end_ptr) { + char_len = utf8_char_length(basetext[end_ptr]); + if (char_len == 0) { // trailing byte in multibyte character + ++byte_cnt; + 
continue; + } + // this is the first byte of a character, hence check if char_len = char_cnt + if (byte_cnt != char_len) { // invalid byte or incomplete glyph + set_error_for_invalid_utf(context, basetext[end_ptr]); + *out_len = 0; + return ""; + } + byte_cnt = 1; // reset the counter*/ + if (!is_substr_utf8_utf8(trimtext, trimtext_len, basetext + end_ptr, char_len)) { + break; + } + } + + // when all characters are trimmed, start_ptr has been incremented to basetext_len and + // end_ptr still points to basetext_len - 1, hence we need to handle this case + if (start_ptr > end_ptr) { + *out_len = 0; + return ""; + } + + end_ptr += utf8_char_length(basetext[end_ptr]); // point to the next character + *out_len = end_ptr - start_ptr; + return basetext + start_ptr; +} + +FORCE_INLINE +gdv_boolean compare_lower_strings(const char* base_str, gdv_int32 base_str_len, + const char* str, gdv_int32 str_len) { + if (base_str_len != str_len) { + return false; + } + for (int i = 0; i < str_len; i++) { + // convert char to lower + char cur = str[i]; + // 'A' - 'Z' : 0x41 - 0x5a + // 'a' - 'z' : 0x61 - 0x7a + if (cur >= 0x41 && cur <= 0x5a) { + cur = static_cast<char>(cur + 0x20); + } + // if the character does not match, break the flow + if (cur != base_str[i]) break; + // if the character matches and it is the last iteration, return true + if (i == str_len - 1) return true; + } + return false; +} + +// Try to cast the received string ('0', '1', 'true', 'false'), ignoring leading +// and trailing spaces, also ignoring lower and upper case. 
+FORCE_INLINE +gdv_boolean castBIT_utf8(gdv_int64 context, const char* data, gdv_int32 data_len) { + if (data_len <= 0) { + gdv_fn_context_set_error_msg(context, "Invalid value for boolean."); + return false; + } + + // trim leading and trailing spaces + int32_t trimmed_len; + int32_t start = 0, end = data_len - 1; + while (start <= end && data[start] == ' ') { + ++start; + } + while (end >= start && data[end] == ' ') { + --end; + } + trimmed_len = end - start + 1; + const char* trimmed_data = data + start; + + // compare received string with the valid bool string values '1', '0', 'true', 'false' + if (trimmed_len == 1) { + // case for '0' and '1' value + if (trimmed_data[0] == '1') return true; + if (trimmed_data[0] == '0') return false; + } else if (trimmed_len == 4) { + // case for matching 'true' + if (compare_lower_strings("true", 4, trimmed_data, trimmed_len)) return true; + } else if (trimmed_len == 5) { + // case for matching 'false' + if (compare_lower_strings("false", 5, trimmed_data, trimmed_len)) return false; + } + // if no 'true', 'false', '0' or '1' value is found, set an error + gdv_fn_context_set_error_msg(context, "Invalid value for boolean."); + return false; +} + +FORCE_INLINE +const char* castVARCHAR_bool_int64(gdv_int64 context, gdv_boolean value, + gdv_int64 out_len, gdv_int32* out_length) { + gdv_int32 len = static_cast<gdv_int32>(out_len); + if (len < 0) { + gdv_fn_context_set_error_msg(context, "Output buffer length can't be negative"); + *out_length = 0; + return ""; + } + const char* out = + reinterpret_cast<const char*>(gdv_fn_context_arena_malloc(context, 5)); + out = value ? "true" : "false"; + *out_length = value ? ((len > 4) ? 4 : len) : ((len > 5) ? 
5 : len); + return out; +} + +// Truncates the string to given length +#define CAST_VARCHAR_FROM_VARLEN_TYPE(TYPE) \ + FORCE_INLINE \ + const char* castVARCHAR_##TYPE##_int64(gdv_int64 context, const char* data, \ + gdv_int32 data_len, int64_t out_len, \ + int32_t* out_length) { \ + int32_t len = static_cast<int32_t>(out_len); \ + \ + if (len < 0) { \ + gdv_fn_context_set_error_msg(context, "Output buffer length can't be negative"); \ + *out_length = 0; \ + return ""; \ + } \ + \ + if (len >= data_len || len == 0) { \ + *out_length = data_len; \ + return data; \ + } \ + \ + int32_t remaining = len; \ + int32_t index = 0; \ + bool is_multibyte = false; \ + do { \ + /* In utf8, MSB of a single byte unicode char is always 0, \ + * whereas for a multibyte character the MSB of each byte is 1. \ + * So for a single byte char, a bitwise-and with x80 (10000000) will be 0 \ + * and it won't be 0 for bytes of a multibyte char. \ + */ \ + char* data_ptr = const_cast<char*>(data); \ + \ + /* advance byte by byte till the 8-byte boundary then advance 8 bytes */ \ + auto num_bytes = reinterpret_cast<uintptr_t>(data_ptr) & 0x07; \ + num_bytes = (8 - num_bytes) & 0x07; \ + while (num_bytes > 0) { \ + uint8_t* ptr = reinterpret_cast<uint8_t*>(data_ptr + index); \ + if ((*ptr & 0x80) != 0) { \ + is_multibyte = true; \ + break; \ + } \ + index++; \ + remaining--; \ + num_bytes--; \ + } \ + if (is_multibyte) break; \ + while (remaining >= 8) { \ + uint64_t* ptr = reinterpret_cast<uint64_t*>(data_ptr + index); \ + if ((*ptr & 0x8080808080808080) != 0) { \ + is_multibyte = true; \ + break; \ + } \ + index += 8; \ + remaining -= 8; \ + } \ + if (is_multibyte) break; \ + if (remaining >= 4) { \ + uint32_t* ptr = reinterpret_cast<uint32_t*>(data_ptr + index); \ + if ((*ptr & 0x80808080) != 0) break; \ + index += 4; \ + remaining -= 4; \ + } \ + while (remaining > 0) { \ + uint8_t* ptr = reinterpret_cast<uint8_t*>(data_ptr + index); \ + if ((*ptr & 0x80) != 0) { \ + is_multibyte = true; \ + 
break; \ + } \ + index++; \ + remaining--; \ + } \ + if (is_multibyte) break; \ + /* reached here; all are single byte characters */ \ + *out_length = len; \ + return data; \ + } while (false); \ + \ + /* detected multibyte utf8 characters; slow path */ \ + int32_t byte_pos = \ + utf8_byte_pos(context, data + index, data_len - index, len - index); \ + if (byte_pos < 0) { \ + *out_length = 0; \ + return ""; \ + } \ + \ + *out_length = index + byte_pos; \ + return data; \ + } + +CAST_VARCHAR_FROM_VARLEN_TYPE(utf8) +CAST_VARCHAR_FROM_VARLEN_TYPE(binary) + +#undef CAST_VARCHAR_FROM_VARLEN_TYPE + +// Add functions for castVARBINARY +#define CAST_VARBINARY_FROM_STRING_AND_BINARY(TYPE) \ + GANDIVA_EXPORT \ + const char* castVARBINARY_##TYPE##_int64(gdv_int64 context, const char* data, \ + gdv_int32 data_len, int64_t out_len, \ + int32_t* out_length) { \ + int32_t len = static_cast<int32_t>(out_len); \ + if (len < 0) { \ + gdv_fn_context_set_error_msg(context, "Output buffer length can't be negative"); \ + *out_length = 0; \ + return ""; \ + } \ + \ + if (len >= data_len || len == 0) { \ + *out_length = data_len; \ + } else { \ + *out_length = len; \ + } \ + return data; \ + } + +CAST_VARBINARY_FROM_STRING_AND_BINARY(utf8) +CAST_VARBINARY_FROM_STRING_AND_BINARY(binary) + +#undef CAST_VARBINARY_FROM_STRING_AND_BINARY + +#define IS_NULL(NAME, TYPE) \ + FORCE_INLINE \ + bool NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid) { \ + return !is_valid; \ + } + +VAR_LEN_TYPES(IS_NULL, isnull) + +#undef IS_NULL + +#define IS_NOT_NULL(NAME, TYPE) \ + FORCE_INLINE \ + bool NAME##_##TYPE(gdv_##TYPE in, gdv_int32 len, gdv_boolean is_valid) { \ + return is_valid; \ + } + +VAR_LEN_TYPES(IS_NOT_NULL, isnotnull) + +#undef IS_NOT_NULL +#undef VAR_LEN_TYPES + +/* + We follow Oracle semantics for offset: + - If position is positive, then the first glyph in the substring is determined by + counting that many glyphs forward from the beginning of the input. 
(i.e., for position == + 1 the first glyph in the substring will be identical to the first glyph in the input) + + - If position is negative, then the first glyph in the substring is determined by + counting that many glyphs backward from the end of the input. (i.e., for position == -1 + the first glyph in the substring will be identical to the last glyph in the input) + + - If position is 0 then it is treated as 1. + */ +FORCE_INLINE +const char* substr_utf8_int64_int64(gdv_int64 context, const char* input, + gdv_int32 in_data_len, gdv_int64 position, + gdv_int64 substring_length, gdv_int32* out_data_len) { + if (substring_length <= 0 || input == nullptr || in_data_len <= 0) { + *out_data_len = 0; + return ""; + } + + gdv_int64 in_glyphs_count = + static_cast<gdv_int64>(utf8_length(context, input, in_data_len)); + + // in_glyphs_count is zero if input has invalid glyphs + if (in_glyphs_count == 0) { + *out_data_len = 0; + return ""; + } + + gdv_int64 from_glyph; // from_glyph==0 indicates the first glyph of the input + if (position > 0) { + from_glyph = position - 1; + } else if (position < 0) { + from_glyph = in_glyphs_count + position; + } else { + from_glyph = 0; + } + + if (from_glyph < 0 || from_glyph >= in_glyphs_count) { + *out_data_len = 0; + return ""; + } + + gdv_int64 out_glyphs_count = substring_length; + if (substring_length > in_glyphs_count - from_glyph) { + out_glyphs_count = in_glyphs_count - from_glyph; + } + + gdv_int64 in_data_len64 = static_cast<gdv_int64>(in_data_len); + gdv_int64 start_pos = 0; + gdv_int64 end_pos = in_data_len64; + + gdv_int64 current_glyph = 0; + gdv_int64 pos = 0; + while (pos < in_data_len64) { + if (current_glyph == from_glyph) { + start_pos = pos; + } + pos += static_cast<gdv_int64>(utf8_char_length(input[pos])); + if (current_glyph - from_glyph + 1 == out_glyphs_count) { + end_pos = pos; + } + current_glyph++; + } + + if (end_pos > in_data_len64 || end_pos > INT_MAX) { + end_pos = in_data_len64; + } + + *out_data_len 
= static_cast<gdv_int32>(end_pos - start_pos); + char* ret = + reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_data_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_data_len = 0; + return ""; + } + memcpy(ret, input + start_pos, *out_data_len); + return ret; +} + +FORCE_INLINE +const char* substr_utf8_int64(gdv_int64 context, const char* input, gdv_int32 in_len, + gdv_int64 offset64, gdv_int32* out_len) { + return substr_utf8_int64_int64(context, input, in_len, offset64, in_len, out_len); +} + +FORCE_INLINE +const char* repeat_utf8_int32(gdv_int64 context, const char* in, gdv_int32 in_len, + gdv_int32 repeat_number, gdv_int32* out_len) { + // if the repeat number is zero, then return empty string + if (repeat_number == 0 || in_len <= 0) { + *out_len = 0; + return ""; + } + // if the repeat number is a negative number, an error is set on context + if (repeat_number < 0) { + gdv_fn_context_set_error_msg(context, "Repeat number can't be negative"); + *out_len = 0; + return ""; + } + *out_len = repeat_number * in_len; + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + for (int i = 0; i < repeat_number; ++i) { + memcpy(ret + (i * in_len), in, in_len); + } + return ret; +} + +FORCE_INLINE +const char* concat_utf8_utf8(gdv_int64 context, const char* left, gdv_int32 left_len, + bool left_validity, const char* right, gdv_int32 right_len, + bool right_validity, gdv_int32* out_len) { + if (!left_validity) { + left_len = 0; + } + if (!right_validity) { + right_len = 0; + } + return concatOperator_utf8_utf8(context, left, left_len, right, right_len, out_len); +} + +FORCE_INLINE +const char* concatOperator_utf8_utf8(gdv_int64 context, const char* left, + gdv_int32 left_len, const char* right, + gdv_int32 
right_len, gdv_int32* out_len) { + *out_len = left_len + right_len; + if (*out_len <= 0) { + *out_len = 0; + return ""; + } + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, left, left_len); + memcpy(ret + left_len, right, right_len); + return ret; +} + +FORCE_INLINE +const char* concat_utf8_utf8_utf8(gdv_int64 context, const char* in1, gdv_int32 in1_len, + bool in1_validity, const char* in2, gdv_int32 in2_len, + bool in2_validity, const char* in3, gdv_int32 in3_len, + bool in3_validity, gdv_int32* out_len) { + if (!in1_validity) { + in1_len = 0; + } + if (!in2_validity) { + in2_len = 0; + } + if (!in3_validity) { + in3_len = 0; + } + return concatOperator_utf8_utf8_utf8(context, in1, in1_len, in2, in2_len, in3, in3_len, + out_len); +} + +FORCE_INLINE +const char* concatOperator_utf8_utf8_utf8(gdv_int64 context, const char* in1, + gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, + gdv_int32 in3_len, gdv_int32* out_len) { + *out_len = in1_len + in2_len + in3_len; + if (*out_len <= 0) { + *out_len = 0; + return ""; + } + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, in1, in1_len); + memcpy(ret + in1_len, in2, in2_len); + memcpy(ret + in1_len + in2_len, in3, in3_len); + return ret; +} + +FORCE_INLINE +const char* concat_utf8_utf8_utf8_utf8(gdv_int64 context, const char* in1, + gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, + bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, + const char* in4, gdv_int32 in4_len, + bool in4_validity, gdv_int32* out_len) { + if (!in1_validity) { + in1_len = 0; + } + if 
(!in2_validity) { + in2_len = 0; + } + if (!in3_validity) { + in3_len = 0; + } + if (!in4_validity) { + in4_len = 0; + } + return concatOperator_utf8_utf8_utf8_utf8(context, in1, in1_len, in2, in2_len, in3, + in3_len, in4, in4_len, out_len); +} + +FORCE_INLINE +const char* concatOperator_utf8_utf8_utf8_utf8(gdv_int64 context, const char* in1, + gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, + gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, gdv_int32* out_len) { + *out_len = in1_len + in2_len + in3_len + in4_len; + if (*out_len <= 0) { + *out_len = 0; + return ""; + } + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, in1, in1_len); + memcpy(ret + in1_len, in2, in2_len); + memcpy(ret + in1_len + in2_len, in3, in3_len); + memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len); + return ret; +} + +FORCE_INLINE +const char* concat_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + gdv_int32* out_len) { + if (!in1_validity) { + in1_len = 0; + } + if (!in2_validity) { + in2_len = 0; + } + if (!in3_validity) { + in3_len = 0; + } + if (!in4_validity) { + in4_len = 0; + } + if (!in5_validity) { + in5_len = 0; + } + return concatOperator_utf8_utf8_utf8_utf8_utf8(context, in1, in1_len, in2, in2_len, in3, + in3_len, in4, in4_len, in5, in5_len, + out_len); +} + +FORCE_INLINE +const char* concatOperator_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + 
gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, gdv_int32* out_len) { + *out_len = in1_len + in2_len + in3_len + in4_len + in5_len; + if (*out_len <= 0) { + *out_len = 0; + return ""; + } + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, in1, in1_len); + memcpy(ret + in1_len, in2, in2_len); + memcpy(ret + in1_len + in2_len, in3, in3_len); + memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len); + return ret; +} + +FORCE_INLINE +const char* concat_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, gdv_int32* out_len) { + if (!in1_validity) { + in1_len = 0; + } + if (!in2_validity) { + in2_len = 0; + } + if (!in3_validity) { + in3_len = 0; + } + if (!in4_validity) { + in4_len = 0; + } + if (!in5_validity) { + in5_len = 0; + } + if (!in6_validity) { + in6_len = 0; + } + return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8(context, in1, in1_len, in2, in2_len, + in3, in3_len, in4, in4_len, in5, + in5_len, in6, in6_len, out_len); +} + +FORCE_INLINE +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, gdv_int32* out_len) { + *out_len = in1_len + in2_len + in3_len + in4_len + in5_len + in6_len; + if (*out_len <= 0) { + *out_len 
= 0; + return ""; + } + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, in1, in1_len); + memcpy(ret + in1_len, in2, in2_len); + memcpy(ret + in1_len + in2_len, in3, in3_len); + memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len); + return ret; +} + +FORCE_INLINE +const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7, + gdv_int32 in7_len, bool in7_validity, gdv_int32* out_len) { + if (!in1_validity) { + in1_len = 0; + } + if (!in2_validity) { + in2_len = 0; + } + if (!in3_validity) { + in3_len = 0; + } + if (!in4_validity) { + in4_len = 0; + } + if (!in5_validity) { + in5_len = 0; + } + if (!in6_validity) { + in6_len = 0; + } + if (!in7_validity) { + in7_len = 0; + } + return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + context, in1, in1_len, in2, in2_len, in3, in3_len, in4, in4_len, in5, in5_len, in6, + in6_len, in7, in7_len, out_len); +} + +FORCE_INLINE +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, gdv_int32* out_len) { + *out_len = in1_len + in2_len + 
in3_len + in4_len + in5_len + in6_len + in7_len; + if (*out_len <= 0) { + *out_len = 0; + return ""; + } + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, in1, in1_len); + memcpy(ret + in1_len, in2, in2_len); + memcpy(ret + in1_len + in2_len, in3, in3_len); + memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len, in7, in7_len); + return ret; +} + +FORCE_INLINE +const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7, + gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len, + bool in8_validity, gdv_int32* out_len) { + if (!in1_validity) { + in1_len = 0; + } + if (!in2_validity) { + in2_len = 0; + } + if (!in3_validity) { + in3_len = 0; + } + if (!in4_validity) { + in4_len = 0; + } + if (!in5_validity) { + in5_len = 0; + } + if (!in6_validity) { + in6_len = 0; + } + if (!in7_validity) { + in7_len = 0; + } + if (!in8_validity) { + in8_len = 0; + } + return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + context, in1, in1_len, in2, in2_len, in3, in3_len, in4, in4_len, in5, in5_len, in6, + in6_len, in7, in7_len, in8, in8_len, out_len); +} + +FORCE_INLINE +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* 
in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8, + gdv_int32 in8_len, gdv_int32* out_len) { + *out_len = + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len + in8_len; + if (*out_len <= 0) { + *out_len = 0; + return ""; + } + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, in1, in1_len); + memcpy(ret + in1_len, in2, in2_len); + memcpy(ret + in1_len + in2_len, in3, in3_len); + memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len, in7, in7_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len, in8, + in8_len); + return ret; +} + +FORCE_INLINE +const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7, + gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len, + bool in8_validity, const char* in9, gdv_int32 in9_len, bool in9_validity, + gdv_int32* out_len) { + if (!in1_validity) { + in1_len = 0; + } + if (!in2_validity) { + in2_len = 0; + } + if (!in3_validity) { + in3_len = 0; + } + if (!in4_validity) 
{ + in4_len = 0; + } + if (!in5_validity) { + in5_len = 0; + } + if (!in6_validity) { + in6_len = 0; + } + if (!in7_validity) { + in7_len = 0; + } + if (!in8_validity) { + in8_len = 0; + } + if (!in9_validity) { + in9_len = 0; + } + return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + context, in1, in1_len, in2, in2_len, in3, in3_len, in4, in4_len, in5, in5_len, in6, + in6_len, in7, in7_len, in8, in8_len, in9, in9_len, out_len); +} + +FORCE_INLINE +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8, + gdv_int32 in8_len, const char* in9, gdv_int32 in9_len, gdv_int32* out_len) { + *out_len = in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len + + in8_len + in9_len; + if (*out_len <= 0) { + *out_len = 0; + return ""; + } + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, in1, in1_len); + memcpy(ret + in1_len, in2, in2_len); + memcpy(ret + in1_len + in2_len, in3, in3_len); + memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len, in7, in7_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len, in8, + in8_len); + memcpy( + ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len + in8_len, + in9, in9_len); + return ret; +} + +FORCE_INLINE +const char* 
concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7, + gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len, + bool in8_validity, const char* in9, gdv_int32 in9_len, bool in9_validity, + const char* in10, gdv_int32 in10_len, bool in10_validity, gdv_int32* out_len) { + if (!in1_validity) { + in1_len = 0; + } + if (!in2_validity) { + in2_len = 0; + } + if (!in3_validity) { + in3_len = 0; + } + if (!in4_validity) { + in4_len = 0; + } + if (!in5_validity) { + in5_len = 0; + } + if (!in6_validity) { + in6_len = 0; + } + if (!in7_validity) { + in7_len = 0; + } + if (!in8_validity) { + in8_len = 0; + } + if (!in9_validity) { + in9_len = 0; + } + if (!in10_validity) { + in10_len = 0; + } + return concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + context, in1, in1_len, in2, in2_len, in3, in3_len, in4, in4_len, in5, in5_len, in6, + in6_len, in7, in7_len, in8, in8_len, in9, in9_len, in10, in10_len, out_len); +} + +FORCE_INLINE +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8, + gdv_int32 in8_len, const char* in9, gdv_int32 in9_len, const char* in10, + gdv_int32 in10_len, gdv_int32* out_len) { + *out_len = in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len + + in8_len + in9_len + in10_len; + if (*out_len <= 0) { + *out_len = 0; + 
return ""; + } + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, in1, in1_len); + memcpy(ret + in1_len, in2, in2_len); + memcpy(ret + in1_len + in2_len, in3, in3_len); + memcpy(ret + in1_len + in2_len + in3_len, in4, in4_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len, in5, in5_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len, in6, in6_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len, in7, in7_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len, in8, + in8_len); + memcpy( + ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len + in8_len, + in9, in9_len); + memcpy(ret + in1_len + in2_len + in3_len + in4_len + in5_len + in6_len + in7_len + + in8_len + in9_len, + in10, in10_len); + return ret; +} + +// Returns the numeric value of the first character of str. 
+GANDIVA_EXPORT +gdv_int32 ascii_utf8(const char* data, gdv_int32 data_len) { + if (data_len == 0) { + return 0; + } + return static_cast<gdv_int32>(data[0]); +} + +FORCE_INLINE +const char* convert_fromUTF8_binary(gdv_int64 context, const char* bin_in, gdv_int32 len, + gdv_int32* out_len) { + *out_len = len; + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, bin_in, *out_len); + return ret; +} + +FORCE_INLINE +const char* convert_replace_invalid_fromUTF8_binary(int64_t context, const char* text_in, + int32_t text_len, + const char* char_to_replace, + int32_t char_to_replace_len, + int32_t* out_len) { + if (char_to_replace_len > 1) { + gdv_fn_context_set_error_msg(context, "Replacement of multiple bytes not supported"); + *out_len = 0; + return ""; + } + // actually the convert_replace function replaces invalid chars with an ASCII + // character so the output length will be the same as the input length + *out_len = text_len; + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + int32_t valid_bytes_to_cpy = 0; + int32_t out_byte_counter = 0; + int32_t in_byte_counter = 0; + int32_t char_len; + // scan the base text from left to right and increment the start pointer till + // looking for invalid chars to substitute + for (int text_index = 0; text_index < text_len; text_index += char_len) { + char_len = utf8_char_length(text_in[text_index]); + // only memory copy the bytes when detect invalid char + if (char_len == 0 || text_index + char_len > text_len || + !validate_utf8_following_bytes(text_in, char_len, text_index)) { + // define char_len = 1 to increase text_index by 1 (as ASCII char fits in 1 
byte) + char_len = 1; + // first copy the valid bytes until now and then replace the invalid character + memcpy(ret + out_byte_counter, text_in + in_byte_counter, valid_bytes_to_cpy); + // if the replacement char is empty, the invalid char should be ignored + if (char_to_replace_len == 0) { + out_byte_counter += valid_bytes_to_cpy; + } else { + ret[out_byte_counter + valid_bytes_to_cpy] = char_to_replace[0]; + out_byte_counter += valid_bytes_to_cpy + char_len; + } + in_byte_counter += valid_bytes_to_cpy + char_len; + valid_bytes_to_cpy = 0; + continue; + } + valid_bytes_to_cpy += char_len; + } + // if invalid chars were not found, return the original string + if (out_byte_counter == 0 && in_byte_counter == 0) return text_in; + // if there are still valid bytes to copy, do it + if (valid_bytes_to_cpy != 0) { + memcpy(ret + out_byte_counter, text_in + in_byte_counter, valid_bytes_to_cpy); + } + // the out length will be the out bytes copied + the missing end bytes copied + *out_len = valid_bytes_to_cpy + out_byte_counter; + return ret; +} + +// The function reverse a char array in-place +static inline void reverse_char_buf(char* buf, int32_t len) { + char temp; + + for (int32_t i = 0; i < len / 2; i++) { + int32_t pos_swp = len - (1 + i); + temp = buf[pos_swp]; + buf[pos_swp] = buf[i]; + buf[i] = temp; + } +} + +// Converts a double variable to binary +FORCE_INLINE +const char* convert_toDOUBLE(int64_t context, double value, int32_t* out_len) { + *out_len = sizeof(value); + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for the output string"); + + *out_len = 0; + return ""; + } + + memcpy(ret, &value, *out_len); + + return ret; +} + +FORCE_INLINE +const char* convert_toDOUBLE_be(int64_t context, double value, int32_t* out_len) { + // The function behaves like convert_toDOUBLE, but always return the result + // in big endian format + 
char* ret = const_cast<char*>(convert_toDOUBLE(context, value, out_len)); + +#if ARROW_LITTLE_ENDIAN + reverse_char_buf(ret, *out_len); +#endif + + return ret; +} + +// Converts a float variable to binary +FORCE_INLINE +const char* convert_toFLOAT(int64_t context, float value, int32_t* out_len) { + *out_len = sizeof(value); + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for the output string"); + + *out_len = 0; + return ""; + } + + memcpy(ret, &value, *out_len); + + return ret; +} + +FORCE_INLINE +const char* convert_toFLOAT_be(int64_t context, float value, int32_t* out_len) { + // The function behaves like convert_toFLOAT, but always return the result + // in big endian format + char* ret = const_cast<char*>(convert_toFLOAT(context, value, out_len)); + +#if ARROW_LITTLE_ENDIAN + reverse_char_buf(ret, *out_len); +#endif + + return ret; +} + +// Converts a bigint(int with 64 bits) variable to binary +FORCE_INLINE +const char* convert_toBIGINT(int64_t context, int64_t value, int32_t* out_len) { + *out_len = sizeof(value); + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for the output string"); + + *out_len = 0; + return ""; + } + + memcpy(ret, &value, *out_len); + + return ret; +} + +FORCE_INLINE +const char* convert_toBIGINT_be(int64_t context, int64_t value, int32_t* out_len) { + // The function behaves like convert_toBIGINT, but always return the result + // in big endian format + char* ret = const_cast<char*>(convert_toBIGINT(context, value, out_len)); + +#if ARROW_LITTLE_ENDIAN + reverse_char_buf(ret, *out_len); +#endif + + return ret; +} + +// Converts an integer(with 32 bits) variable to binary +FORCE_INLINE +const char* convert_toINT(int64_t context, int32_t value, int32_t* out_len) { + 
*out_len = sizeof(value); + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for the output string"); + + *out_len = 0; + return ""; + } + + memcpy(ret, &value, *out_len); + + return ret; +} + +FORCE_INLINE +const char* convert_toINT_be(int64_t context, int32_t value, int32_t* out_len) { + // The function behaves like convert_toINT, but always return the result + // in big endian format + char* ret = const_cast<char*>(convert_toINT(context, value, out_len)); + +#if ARROW_LITTLE_ENDIAN + reverse_char_buf(ret, *out_len); +#endif + + return ret; +} + +// Converts a boolean variable to binary +FORCE_INLINE +const char* convert_toBOOLEAN(int64_t context, bool value, int32_t* out_len) { + *out_len = sizeof(value); + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for the output string"); + + *out_len = 0; + return ""; + } + + memcpy(ret, &value, *out_len); + + return ret; +} + +// Converts a time variable to binary +FORCE_INLINE +const char* convert_toTIME_EPOCH(int64_t context, int32_t value, int32_t* out_len) { + return convert_toINT(context, value, out_len); +} + +FORCE_INLINE +const char* convert_toTIME_EPOCH_be(int64_t context, int32_t value, int32_t* out_len) { + // The function behaves as convert_toTIME_EPOCH, but + // returns the bytes in big endian format + return convert_toINT_be(context, value, out_len); +} + +// Converts a timestamp variable to binary +FORCE_INLINE +const char* convert_toTIMESTAMP_EPOCH(int64_t context, int64_t timestamp, + int32_t* out_len) { + return convert_toBIGINT(context, timestamp, out_len); +} + +FORCE_INLINE +const char* convert_toTIMESTAMP_EPOCH_be(int64_t context, int64_t timestamp, + int32_t* out_len) { + // The function behaves as convert_toTIMESTAMP_EPOCH, but + // 
returns the bytes in big endian format + return convert_toBIGINT_be(context, timestamp, out_len); +} + +// Converts a date variable to binary +FORCE_INLINE +const char* convert_toDATE_EPOCH(int64_t context, int64_t date, int32_t* out_len) { + return convert_toBIGINT(context, date, out_len); +} + +FORCE_INLINE +const char* convert_toDATE_EPOCH_be(int64_t context, int64_t date, int32_t* out_len) { + // The function behaves as convert_toDATE_EPOCH, but + // returns the bytes in big endian format + return convert_toBIGINT_be(context, date, out_len); +} + +// Converts a string variable to binary +FORCE_INLINE +const char* convert_toUTF8(int64_t context, const char* value, int32_t value_len, + int32_t* out_len) { + *out_len = value_len; + return value; +} + +// Search for a string within another string +// Same as "locate(substr, str)", except for the reverse order of the arguments. +FORCE_INLINE +gdv_int32 strpos_utf8_utf8(gdv_int64 context, const char* str, gdv_int32 str_len, + const char* sub_str, gdv_int32 sub_str_len) { + return locate_utf8_utf8_int32(context, sub_str, sub_str_len, str, str_len, 1); +} + +// Search for a string within another string +FORCE_INLINE +gdv_int32 locate_utf8_utf8(gdv_int64 context, const char* sub_str, gdv_int32 sub_str_len, + const char* str, gdv_int32 str_len) { + return locate_utf8_utf8_int32(context, sub_str, sub_str_len, str, str_len, 1); +} + +// Search for a string within another string starting at position start-pos (1-indexed) +FORCE_INLINE +gdv_int32 locate_utf8_utf8_int32(gdv_int64 context, const char* sub_str, + gdv_int32 sub_str_len, const char* str, + gdv_int32 str_len, gdv_int32 start_pos) { + if (start_pos < 1) { + gdv_fn_context_set_error_msg(context, "Start position must be greater than 0"); + return 0; + } + + if (str_len == 0 || sub_str_len == 0) { + return 0; + } + + gdv_int32 byte_pos = utf8_byte_pos(context, str, str_len, start_pos - 1); + if (byte_pos < 0 || byte_pos >= str_len) { + return 0; + } + for (gdv_int32 i 
= byte_pos; i <= str_len - sub_str_len; ++i) { + if (memcmp(str + i, sub_str, sub_str_len) == 0) { + return utf8_length(context, str, i) + 1; + } + } + return 0; +} + +FORCE_INLINE +const char* replace_with_max_len_utf8_utf8_utf8(gdv_int64 context, const char* text, + gdv_int32 text_len, const char* from_str, + gdv_int32 from_str_len, + const char* to_str, gdv_int32 to_str_len, + gdv_int32 max_length, + gdv_int32* out_len) { + // if from_str is empty or its length exceeds that of original string, + // return the original string + if (from_str_len <= 0 || from_str_len > text_len) { + *out_len = text_len; + return text; + } + + bool found = false; + gdv_int32 text_index = 0; + char* out; + gdv_int32 out_index = 0; + gdv_int32 last_match_index = + 0; // defer copying string from last_match_index till next match is found + + for (; text_index <= text_len - from_str_len;) { + if (memcmp(text + text_index, from_str, from_str_len) == 0) { + if (out_index + text_index - last_match_index + to_str_len > max_length) { + gdv_fn_context_set_error_msg(context, "Buffer overflow for output string"); + *out_len = 0; + return ""; + } + if (!found) { + // found match for first time + out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, max_length)); + if (out == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + found = true; + } + // first copy the part deferred till now + memcpy(out + out_index, text + last_match_index, (text_index - last_match_index)); + out_index += text_index - last_match_index; + // then copy the target string + memcpy(out + out_index, to_str, to_str_len); + out_index += to_str_len; + + text_index += from_str_len; + last_match_index = text_index; + } else { + text_index++; + } + } + + if (!found) { + *out_len = text_len; + return text; + } + + if (out_index + text_len - last_match_index > max_length) { + gdv_fn_context_set_error_msg(context, "Buffer overflow for 
output string"); + *out_len = 0; + return ""; + } + memcpy(out + out_index, text + last_match_index, text_len - last_match_index); + out_index += text_len - last_match_index; + *out_len = out_index; + return out; +} + +FORCE_INLINE +const char* replace_utf8_utf8_utf8(gdv_int64 context, const char* text, + gdv_int32 text_len, const char* from_str, + gdv_int32 from_str_len, const char* to_str, + gdv_int32 to_str_len, gdv_int32* out_len) { + return replace_with_max_len_utf8_utf8_utf8(context, text, text_len, from_str, + from_str_len, to_str, to_str_len, 65535, + out_len); +} + +FORCE_INLINE +const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 return_length, const char* fill_text, + gdv_int32 fill_text_len, gdv_int32* out_len) { + // if the text length or the defined return length (number of characters to return) + // is <=0, then return an empty string. + if (text_len == 0 || return_length <= 0) { + *out_len = 0; + return ""; + } + + // count the number of utf8 characters on text, ignoring invalid bytes + int text_char_count = utf8_length_ignore_invalid(text, text_len); + + if (return_length == text_char_count || + (return_length > text_char_count && fill_text_len == 0)) { + // case where the return length is same as the text's length, or if it need to + // fill into text but "fill_text" is empty, then return text directly. + *out_len = text_len; + return text; + } else if (return_length < text_char_count) { + // case where it truncates the result on return length. + *out_len = utf8_byte_pos(context, text, text_len, return_length); + return text; + } else { + // case (return_length > text_char_count) + // case where it needs to copy "fill_text" on the string left. 
The total number + // of chars to copy is given by (return_length - text_char_count) + char* ret = + reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, return_length)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + // try to fulfill the return string with the "fill_text" continuously + int32_t copied_chars_count = 0; + int32_t copied_chars_position = 0; + while (copied_chars_count < return_length - text_char_count) { + int32_t char_len; + int32_t fill_index; + // for each char, evaluate its length to consider it when mem copying + for (fill_index = 0; fill_index < fill_text_len; fill_index += char_len) { + if (copied_chars_count >= return_length - text_char_count) { + break; + } + char_len = utf8_char_length(fill_text[fill_index]); + // ignore invalid char on the fill text, considering it as size 1 + if (char_len == 0) char_len += 1; + copied_chars_count++; + } + memcpy(ret + copied_chars_position, fill_text, fill_index); + copied_chars_position += fill_index; + } + // after fulfilling the text, copy the main string + memcpy(ret + copied_chars_position, text, text_len); + *out_len = copied_chars_position + text_len; + return ret; + } +} + +FORCE_INLINE +const char* rpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 return_length, const char* fill_text, + gdv_int32 fill_text_len, gdv_int32* out_len) { + // if the text length or the defined return length (number of characters to return) + // is <=0, then return an empty string. 
+ if (text_len == 0 || return_length <= 0) { + *out_len = 0; + return ""; + } + + // count the number of utf8 characters on text, ignoring invalid bytes + int text_char_count = utf8_length_ignore_invalid(text, text_len); + + if (return_length == text_char_count || + (return_length > text_char_count && fill_text_len == 0)) { + // case where the return length is same as the text's length, or if it need to + // fill into text but "fill_text" is empty, then return text directly. + *out_len = text_len; + return text; + } else if (return_length < text_char_count) { + // case where it truncates the result on return length. + *out_len = utf8_byte_pos(context, text, text_len, return_length); + return text; + } else { + // case (return_length > text_char_count) + // case where it needs to copy "fill_text" on the string right + char* ret = + reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, return_length)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + // fulfill the initial text copying the main input string + memcpy(ret, text, text_len); + // try to fulfill the return string with the "fill_text" continuously + int32_t copied_chars_count = 0; + int32_t copied_chars_position = 0; + while (text_char_count + copied_chars_count < return_length) { + int32_t char_len; + int32_t fill_length; + // for each char, evaluate its length to consider it when mem copying + for (fill_length = 0; fill_length < fill_text_len; fill_length += char_len) { + if (text_char_count + copied_chars_count >= return_length) { + break; + } + char_len = utf8_char_length(fill_text[fill_length]); + // ignore invalid char on the fill text, considering it as size 1 + if (char_len == 0) char_len += 1; + copied_chars_count++; + } + memcpy(ret + text_len + copied_chars_position, fill_text, fill_length); + copied_chars_position += fill_length; + } + *out_len = copied_chars_position + text_len; + 
return ret; + } +} + +FORCE_INLINE +const char* lpad_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 return_length, gdv_int32* out_len) { + return lpad_utf8_int32_utf8(context, text, text_len, return_length, " ", 1, out_len); +} + +FORCE_INLINE +const char* rpad_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 return_length, gdv_int32* out_len) { + return rpad_utf8_int32_utf8(context, text, text_len, return_length, " ", 1, out_len); +} + +FORCE_INLINE +const char* split_part(gdv_int64 context, const char* text, gdv_int32 text_len, + const char* delimiter, gdv_int32 delim_len, gdv_int32 index, + gdv_int32* out_len) { + *out_len = 0; + if (index < 1) { + char error_message[100]; + snprintf(error_message, sizeof(error_message), + "Index in split_part must be positive, value provided was %d", index); + gdv_fn_context_set_error_msg(context, error_message); + return ""; + } + + if (delim_len == 0 || text_len == 0) { + // output will just be text if no delimiter is provided + *out_len = text_len; + return text; + } + + int i = 0, match_no = 1; + + while (i < text_len) { + // find the position where delimiter matched for the first time + int match_pos = match_string(text, text_len, i, delimiter, delim_len); + if (match_pos == -1 && match_no != index) { + // reached the end without finding a match. + return ""; + } else { + // Found a match. 
If the match number is index then return this match + if (match_no == index) { + int end_pos = match_pos - delim_len; + + if (match_pos == -1) { + // end position should be last position of the string as we have the last + // delimiter + end_pos = text_len; + } + + *out_len = end_pos - i; + char* out_str = + reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (out_str == nullptr) { + gdv_fn_context_set_error_msg(context, + "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(out_str, text + i, *out_len); + return out_str; + } else { + i = match_pos; + match_no++; + } + } + } + + return ""; +} + +// Returns the x leftmost characters of a given string. Cases: +// LEFT("TestString", 10) => "TestString" +// LEFT("TestString", 3) => "Tes" +// LEFT("TestString", -3) => "TestStr" +FORCE_INLINE +const char* left_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 number, gdv_int32* out_len) { + // returns the 'number' left most characters of a given text + if (text_len == 0 || number == 0) { + *out_len = 0; + return ""; + } + + // iterate over the utf8 string validating each character + int char_len; + int char_count = 0; + int byte_index = 0; + for (int i = 0; i < text_len; i += char_len) { + char_len = utf8_char_length(text[i]); + if (char_len == 0 || i + char_len > text_len) { // invalid byte or incomplete glyph + set_error_for_invalid_utf(context, text[i]); + *out_len = 0; + return ""; + } + for (int j = 1; j < char_len; ++j) { + if ((text[i + j] & 0xC0) != 0x80) { // bytes following head-byte of glyph + set_error_for_invalid_utf(context, text[i + j]); + *out_len = 0; + return ""; + } + } + byte_index += char_len; + ++char_count; + // Define the rules to stop the iteration over the string + // case where left('abc', 5) -> 'abc' + if (number > 0 && char_count == number) break; + // case where left('abc', -5) ==> '' + if (number < 0 && char_count == number + text_len) break; + } + 
+ *out_len = byte_index; + return text; +} + +// Returns the x rightmost characters of a given string. Cases: +// RIGHT("TestString", 10) => "TestString" +// RIGHT("TestString", 3) => "ing" +// RIGHT("TestString", -3) => "tString" +FORCE_INLINE +const char* right_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 number, gdv_int32* out_len) { + // returns the 'number' left most characters of a given text + if (text_len == 0 || number == 0) { + *out_len = 0; + return ""; + } + + // initially counts the number of utf8 characters in the defined text + int32_t char_count = utf8_length(context, text, text_len); + // char_count is zero if input has invalid utf8 char + if (char_count == 0) { + *out_len = 0; + return ""; + } + + int32_t start_char_pos; // the char result start position (inclusive) + int32_t end_char_len; // the char result end position (inclusive) + if (number > 0) { + // case where right('abc', 5) ==> 'abc' start_char_pos=1. + start_char_pos = (char_count > number) ? 
char_count - number : 0; + end_char_len = char_count - start_char_pos; + } else { + start_char_pos = number * -1; + end_char_len = char_count - start_char_pos; + } + + // calculate the start byte position and the output length + int32_t start_byte_pos = utf8_byte_pos(context, text, text_len, start_char_pos); + *out_len = utf8_byte_pos(context, text, text_len, end_char_len); + + // try to allocate memory for the response + char* ret = + reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(ret, text + start_byte_pos, *out_len); + return ret; +} + +FORCE_INLINE +const char* binary_string(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32* out_len) { + gdv_binary ret = + reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, text_len)); + + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + + if (text_len == 0) { + *out_len = 0; + return ""; + } + + // converting hex encoded string to normal string + int j = 0; + for (int i = 0; i < text_len; i++, j++) { + if (text[i] == '\\' && i + 3 < text_len && + (text[i + 1] == 'x' || text[i + 1] == 'X')) { + char hd1 = text[i + 2]; + char hd2 = text[i + 3]; + if (isxdigit(hd1) && isxdigit(hd2)) { + // [a-fA-F0-9] + ret[j] = to_binary_from_hex(hd1) * 16 + to_binary_from_hex(hd2); + i += 3; + } else { + ret[j] = text[i]; + } + } else { + ret[j] = text[i]; + } + } + *out_len = j; + return ret; +} + +#define CAST_INT_BIGINT_VARBINARY(OUT_TYPE, TYPE_NAME) \ + FORCE_INLINE \ + OUT_TYPE \ + cast##TYPE_NAME##_varbinary(gdv_int64 context, const char* in, int32_t in_len) { \ + if (in_len == 0) { \ + gdv_fn_context_set_error_msg(context, "Can't cast an empty string."); \ + return -1; \ + } \ + char sign = in[0]; \ + \ + bool negative = false; \ 
+ if (sign == '-') { \ + negative = true; \ + /* Ignores the sign char in the hexadecimal string */ \ + in++; \ + in_len--; \ + } \ + \ + if (negative && in_len == 0) { \ + gdv_fn_context_set_error_msg(context, \ + "Can't cast hexadecimal with only a minus sign."); \ + return -1; \ + } \ + \ + OUT_TYPE result = 0; \ + int digit; \ + \ + int read_index = 0; \ + while (read_index < in_len) { \ + char c1 = in[read_index]; \ + if (isxdigit(c1)) { \ + digit = to_binary_from_hex(c1); \ + \ + OUT_TYPE next = result * 16 - digit; \ + \ + if (next > result) { \ + gdv_fn_context_set_error_msg(context, "Integer overflow."); \ + return -1; \ + } \ + result = next; \ + read_index++; \ + } else { \ + gdv_fn_context_set_error_msg(context, \ + "The hexadecimal given has invalid characters."); \ + return -1; \ + } \ + } \ + if (!negative) { \ + result *= -1; \ + \ + if (result < 0) { \ + gdv_fn_context_set_error_msg(context, "Integer overflow."); \ + return -1; \ + } \ + } \ + return result; \ + } + +CAST_INT_BIGINT_VARBINARY(int32_t, INT) +CAST_INT_BIGINT_VARBINARY(int64_t, BIGINT) + +#undef CAST_INT_BIGINT_VARBINARY + +// Produces the binary representation of a string y characters long derived by starting +// at offset 'x' and considering the defined length 'y'. Notice that the offset index +// may be a negative number (starting from the end of the string), or a positive number +// starting on index 1. 
Cases: +// BYTE_SUBSTR("TestString", 1, 10) => "TestString" +// BYTE_SUBSTR("TestString", 5, 10) => "String" +// BYTE_SUBSTR("TestString", -6, 10) => "String" +// BYTE_SUBSTR("TestString", -600, 10) => "TestString" +FORCE_INLINE +const char* byte_substr_binary_int32_int32(gdv_int64 context, const char* text, + gdv_int32 text_len, gdv_int32 offset, + gdv_int32 length, gdv_int32* out_len) { + // the first offset position for a string is 1, so not consider offset == 0 + // also, the length should be always a positive number + if (text_len == 0 || offset == 0 || length <= 0) { + *out_len = 0; + return ""; + } + + char* ret = + reinterpret_cast<gdv_binary>(gdv_fn_context_arena_malloc(context, text_len)); + + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + + int32_t startPos = 0; + if (offset >= 0) { + startPos = offset - 1; + } else if (text_len + offset >= 0) { + startPos = text_len + offset; + } + + // calculate end position from length and truncate to upper value bounds + if (startPos + length > text_len) { + *out_len = text_len - startPos; + } else { + *out_len = length; + } + + memcpy(ret, text + startPos, *out_len); + return ret; +} +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled/string_ops_test.cc b/src/arrow/cpp/src/gandiva/precompiled/string_ops_test.cc new file mode 100644 index 000000000..6221dffb3 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -0,0 +1,1758 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include <limits> + +#include "gandiva/execution_context.h" +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +TEST(TestStringOps, TestCompare) { + const char* left = "abcd789"; + const char* right = "abcd123"; + + // 0 for equal + EXPECT_EQ(mem_compare(left, 4, right, 4), 0); + + // compare lengths if the prefixes match + EXPECT_GT(mem_compare(left, 5, right, 4), 0); + EXPECT_LT(mem_compare(left, 4, right, 5), 0); + + // compare bytes if the prefixes don't match + EXPECT_GT(mem_compare(left, 5, right, 5), 0); + EXPECT_GT(mem_compare(left, 5, right, 7), 0); + EXPECT_GT(mem_compare(left, 7, right, 5), 0); +} + +TEST(TestStringOps, TestAscii) { + // ASCII + EXPECT_EQ(ascii_utf8("ABC", 3), 65); + EXPECT_EQ(ascii_utf8("abc", 3), 97); + EXPECT_EQ(ascii_utf8("Hello World!", 12), 72); + EXPECT_EQ(ascii_utf8("This is us", 10), 84); + EXPECT_EQ(ascii_utf8("", 0), 0); + EXPECT_EQ(ascii_utf8("123", 3), 49); + EXPECT_EQ(ascii_utf8("999", 3), 57); +} + +TEST(TestStringOps, TestBeginsEnds) { + // starts_with + EXPECT_TRUE(starts_with_utf8_utf8("hello sir", 9, "hello", 5)); + EXPECT_TRUE(starts_with_utf8_utf8("hellos", 6, "hello", 5)); + EXPECT_TRUE(starts_with_utf8_utf8("hello", 5, "hello", 5)); + EXPECT_FALSE(starts_with_utf8_utf8("hell", 4, "hello", 5)); + EXPECT_FALSE(starts_with_utf8_utf8("world hello", 11, "hello", 5)); + + // ends_with + EXPECT_TRUE(ends_with_utf8_utf8("hello sir", 9, "sir", 3)); + EXPECT_TRUE(ends_with_utf8_utf8("ssir", 4, "sir", 3)); + 
EXPECT_TRUE(ends_with_utf8_utf8("sir", 3, "sir", 3)); + EXPECT_FALSE(ends_with_utf8_utf8("ir", 2, "sir", 3)); + EXPECT_FALSE(ends_with_utf8_utf8("hello", 5, "sir", 3)); +} + +TEST(TestStringOps, TestSpace) { + // Space - returns a string with 'n' spaces + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + int32_t out_len = 0; + + auto out = space_int32(ctx_ptr, 1, &out_len); + EXPECT_EQ(std::string(out, out_len), " "); + out = space_int32(ctx_ptr, 10, &out_len); + EXPECT_EQ(std::string(out, out_len), " "); + out = space_int32(ctx_ptr, 5, &out_len); + EXPECT_EQ(std::string(out, out_len), " "); + out = space_int32(ctx_ptr, -5, &out_len); + EXPECT_EQ(std::string(out, out_len), ""); + + out = space_int64(ctx_ptr, 2, &out_len); + EXPECT_EQ(std::string(out, out_len), " "); + out = space_int64(ctx_ptr, 9, &out_len); + EXPECT_EQ(std::string(out, out_len), " "); + out = space_int64(ctx_ptr, 4, &out_len); + EXPECT_EQ(std::string(out, out_len), " "); + out = space_int64(ctx_ptr, -5, &out_len); + EXPECT_EQ(std::string(out, out_len), ""); +} + +TEST(TestStringOps, TestIsSubstr) { + EXPECT_TRUE(is_substr_utf8_utf8("hello world", 11, "world", 5)); + EXPECT_TRUE(is_substr_utf8_utf8("hello world", 11, "lo wo", 5)); + EXPECT_FALSE(is_substr_utf8_utf8("hello world", 11, "adsed", 5)); + EXPECT_FALSE(is_substr_utf8_utf8("hel", 3, "hello", 5)); + EXPECT_TRUE(is_substr_utf8_utf8("hello", 5, "hello", 5)); + EXPECT_TRUE(is_substr_utf8_utf8("hello world", 11, "", 0)); +} + +TEST(TestStringOps, TestCharLength) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + + EXPECT_EQ(utf8_length(ctx_ptr, "hello sir", 9), 9); + + std::string a("âpple"); + EXPECT_EQ(utf8_length(ctx_ptr, a.data(), static_cast<int>(a.length())), 5); + + std::string b("मदन"); + EXPECT_EQ(utf8_length(ctx_ptr, b.data(), static_cast<int>(b.length())), 3); + + // invalid utf8 + std::string c("\xf8\x28"); + EXPECT_EQ(utf8_length(ctx_ptr, c.data(), 
static_cast<int>(c.length())), 0); + EXPECT_TRUE(ctx.get_error().find( + "unexpected byte \\f8 encountered while decoding utf8 string") != + std::string::npos) + << ctx.get_error(); + ctx.Reset(); + + std::string d("aa\xc3"); + EXPECT_EQ(utf8_length(ctx_ptr, d.data(), static_cast<int>(d.length())), 0); + EXPECT_TRUE(ctx.get_error().find( + "unexpected byte \\c3 encountered while decoding utf8 string") != + std::string::npos) + << ctx.get_error(); + ctx.Reset(); + + std::string e( + "a\xc3" + "a"); + EXPECT_EQ(utf8_length(ctx_ptr, e.data(), static_cast<int>(e.length())), 0); + EXPECT_TRUE(ctx.get_error().find( + "unexpected byte \\61 encountered while decoding utf8 string") != + std::string::npos) + << ctx.get_error(); + ctx.Reset(); + + std::string f( + "a\xc3\xe3" + "a"); + EXPECT_EQ(utf8_length(ctx_ptr, f.data(), static_cast<int>(f.length())), 0); + EXPECT_TRUE(ctx.get_error().find( + "unexpected byte \\e3 encountered while decoding utf8 string") != + std::string::npos) + << ctx.get_error(); + ctx.Reset(); +} + +TEST(TestStringOps, TestConvertReplaceInvalidUtf8Char) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + + // invalid utf8 (xf8 is invalid but x28 is not - x28 = '(') + std::string a( + "ok-\xf8\x28" + "-a"); + auto a_in_out_len = static_cast<int>(a.length()); + const char* a_str = convert_replace_invalid_fromUTF8_binary( + ctx_ptr, a.data(), a_in_out_len, "a", 1, &a_in_out_len); + EXPECT_EQ(std::string(a_str, a_in_out_len), "ok-a(-a"); + EXPECT_FALSE(ctx.has_error()); + + // invalid utf8 (xa0 and xa1 are invalid) + std::string b("ok-\xa0\xa1-valid"); + auto b_in_out_len = static_cast<int>(b.length()); + const char* b_str = convert_replace_invalid_fromUTF8_binary( + ctx_ptr, b.data(), b_in_out_len, "b", 1, &b_in_out_len); + EXPECT_EQ(std::string(b_str, b_in_out_len), "ok-bb-valid"); + EXPECT_FALSE(ctx.has_error()); + + // full valid utf8 + std::string c("all-valid"); + auto c_in_out_len = 
static_cast<int>(c.length()); + const char* c_str = convert_replace_invalid_fromUTF8_binary( + ctx_ptr, c.data(), c_in_out_len, "c", 1, &c_in_out_len); + EXPECT_EQ(std::string(c_str, c_in_out_len), "all-valid"); + EXPECT_FALSE(ctx.has_error()); + + // valid utf8 (महसुस is 4-char string, each char of which is likely a multibyte char) + std::string d("ok-महसुस-valid-new"); + auto d_in_out_len = static_cast<int>(d.length()); + const char* d_str = convert_replace_invalid_fromUTF8_binary( + ctx_ptr, d.data(), d_in_out_len, "d", 1, &d_in_out_len); + EXPECT_EQ(std::string(d_str, d_in_out_len), "ok-महसुस-valid-new"); + EXPECT_FALSE(ctx.has_error()); + + // full valid utf8, but invalid replacement char length + std::string e("all-valid"); + auto e_in_out_len = static_cast<int>(e.length()); + const char* e_str = convert_replace_invalid_fromUTF8_binary( + ctx_ptr, e.data(), e_in_out_len, "ee", 2, &e_in_out_len); + EXPECT_EQ(std::string(e_str, e_in_out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + // invalid utf8 (xa0 and xa1 are invalid) with empty replacement char length + std::string f("ok-\xa0\xa1-valid"); + auto f_in_out_len = static_cast<int>(f.length()); + const char* f_str = convert_replace_invalid_fromUTF8_binary( + ctx_ptr, f.data(), f_in_out_len, "", 0, &f_in_out_len); + EXPECT_EQ(std::string(f_str, f_in_out_len), "ok--valid"); + EXPECT_FALSE(ctx.has_error()); + ctx.Reset(); + + // invalid utf8 (xa0 and xa1 are invalid) with empty replacement char length + std::string g("\xa0\xa1-ok-\xa0\xa1-valid-\xa0\xa1"); + auto g_in_out_len = static_cast<int>(g.length()); + const char* g_str = convert_replace_invalid_fromUTF8_binary( + ctx_ptr, g.data(), g_in_out_len, "", 0, &g_in_out_len); + EXPECT_EQ(std::string(g_str, g_in_out_len), "-ok--valid-"); + EXPECT_FALSE(ctx.has_error()); + ctx.Reset(); + + std::string h("\xa0\xa1-valid"); + auto h_in_out_len = static_cast<int>(h.length()); + const char* h_str = convert_replace_invalid_fromUTF8_binary( + ctx_ptr, 
h.data(), h_in_out_len, "", 0, &h_in_out_len); + EXPECT_EQ(std::string(h_str, h_in_out_len), "-valid"); + EXPECT_FALSE(ctx.has_error()); + ctx.Reset(); + + std::string i("\xa0\xa1-valid-\xa0\xa1-valid-\xa0\xa1"); + auto i_in_out_len = static_cast<int>(i.length()); + const char* i_str = convert_replace_invalid_fromUTF8_binary( + ctx_ptr, i.data(), i_in_out_len, "", 0, &i_in_out_len); + EXPECT_EQ(std::string(i_str, i_in_out_len), "-valid--valid-"); + EXPECT_FALSE(ctx.has_error()); + ctx.Reset(); +} + +TEST(TestStringOps, TestRepeat) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str = repeat_utf8_int32(ctx_ptr, "abc", 3, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcabc"); + EXPECT_FALSE(ctx.has_error()); + + out_str = repeat_utf8_int32(ctx_ptr, "a", 1, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "aaaaa"); + EXPECT_FALSE(ctx.has_error()); + + out_str = repeat_utf8_int32(ctx_ptr, "", 0, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = repeat_utf8_int32(ctx_ptr, "", -20, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = repeat_utf8_int32(ctx_ptr, "a", 1, -10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Repeat number can't be negative")); + ctx.Reset(); +} + +TEST(TestStringOps, TestCastBoolToVarchar) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str = castVARCHAR_bool_int64(ctx_ptr, true, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "tr"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_bool_int64(ctx_ptr, true, 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "true"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_bool_int64(ctx_ptr, 
false, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "fals"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_bool_int64(ctx_ptr, false, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "false"); + EXPECT_FALSE(ctx.has_error()); + + castVARCHAR_bool_int64(ctx_ptr, true, -3, &out_len); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Output buffer length can't be negative")); + ctx.Reset(); +} + +TEST(TestStringOps, TestCastVarcharToBool) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "true", 4), true); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, " true ", 14), true); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "true ", 9), true); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, " true", 9), true); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "TRUE", 4), true); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "TrUe", 4), true); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "1", 1), true); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, " 1", 3), true); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "false", 5), false); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "false ", 10), false); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, " false", 10), false); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "0", 1), false); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "0 ", 4), false); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "FALSE", 5), false); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "FaLsE", 5), false); + EXPECT_FALSE(ctx.has_error()); + + EXPECT_EQ(castBIT_utf8(ctx_ptr, "test", 4), false); + EXPECT_TRUE(ctx.has_error()); + 
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid value for boolean")); + ctx.Reset(); +} + +TEST(TestStringOps, TestCastVarchar) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + // BINARY TESTS + const char* out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "a"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 6, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + // do not truncate if output length is 0 + out_str = castVARCHAR_binary_int64(ctx_ptr, "asdf", 4, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "", 0, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 10, &out_len); + 
EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†", 9, 6, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "abc", 3, -1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Output buffer length can't be negative")); + ctx.Reset(); + + std::string z("aa\xc3"); + out_str = castVARCHAR_binary_int64(ctx_ptr, z.data(), static_cast<int>(z.length()), 2, + &out_len); + EXPECT_EQ(std::string(out_str, out_len), "aa"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 16, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1234567812341234"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 15, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123456781234123"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 12, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123456781234"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 8, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "12345678"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1234567"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1234"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812341234", 16, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "1234567812çåå†123456", 
25, 16, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1234567812çåå†12"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "123456781234çåå†1234", 25, 15, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123456781234çåå"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "12çåå†34567812123456", 25, 16, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "12çåå†3456781212"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†1234567812123456", 25, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "çåå†1234567812123456", 25, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_binary_int64(ctx_ptr, "123456781234çåå†", 21, 40, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123456781234çåå†"); + EXPECT_FALSE(ctx.has_error()); + + std::string f("123456781234çåå\xc3"); + out_str = castVARCHAR_binary_int64(ctx_ptr, f.data(), static_cast<int32_t>(f.length()), + 16, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\c3 encountered while decoding utf8 string")); + ctx.Reset(); + + // UTF8 TESTS + out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "a"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 6, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + 
out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + // do not truncate if output length is 0 + out_str = castVARCHAR_utf8_int64(ctx_ptr, "asdf", 4, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "", 0, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†", 9, 6, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "abc", 3, -1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Output buffer length can't be negative")); + ctx.Reset(); + + std::string d("aa\xc3"); + out_str = castVARCHAR_utf8_int64(ctx_ptr, d.data(), static_cast<int>(d.length()), 2, + &out_len); + EXPECT_EQ(std::string(out_str, out_len), "aa"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 16, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1234567812341234"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 15, &out_len); + 
EXPECT_EQ(std::string(out_str, out_len), "123456781234123"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 12, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123456781234"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 8, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "12345678"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1234567"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1234"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812341234", 16, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "1234567812çåå†123456", 25, 16, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "1234567812çåå†12"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "123456781234çåå†1234", 25, 15, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123456781234çåå"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "12çåå†34567812123456", 25, 16, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "12çåå†3456781212"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†1234567812123456", 25, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "çåå†1234567812123456", 25, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çåå"); + EXPECT_FALSE(ctx.has_error()); + + out_str = castVARCHAR_utf8_int64(ctx_ptr, "123456781234çåå†", 21, 40, &out_len); + EXPECT_EQ(std::string(out_str, out_len), 
"123456781234çåå†"); + EXPECT_FALSE(ctx.has_error()); + + std::string y("123456781234çåå\xc3"); + out_str = castVARCHAR_utf8_int64(ctx_ptr, y.data(), static_cast<int32_t>(y.length()), + 16, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\c3 encountered while decoding utf8 string")); + ctx.Reset(); +} + +TEST(TestStringOps, TestSubstring) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, 1, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, 1, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "as"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, 1, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, 0, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, -2, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "df"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "asdf", 4, -5, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "अपाचे एरो", 25, 1, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "अपाचे"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "अपाचे एरो", 25, 7, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "एरो"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "çåå†", 9, 4, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "†"); + EXPECT_FALSE(ctx.has_error()); + + 
out_str = substr_utf8_int64_int64(ctx_ptr, "çåå†", 9, 2, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "åå"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "çåå†", 9, 0, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "çå"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "afg", 4, 0, -5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64_int64(ctx_ptr, "", 0, 5, 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64(ctx_ptr, "abcd", 4, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "bcd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64(ctx_ptr, "abcd", 4, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = substr_utf8_int64(ctx_ptr, "çåå†", 9, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "åå†"); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestStringOps, TestSubstringInvalidInputs) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + char bytes[] = {'\xA7', 'a'}; + const char* out_str = substr_utf8_int64_int64(ctx_ptr, bytes, 2, 1, 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + char midbytes[] = {'c', '\xA7', 'a'}; + out_str = substr_utf8_int64_int64(ctx_ptr, midbytes, 3, 1, 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + char midbytes2[] = {'\xC3', 'a', 'a'}; + out_str = substr_utf8_int64_int64(ctx_ptr, midbytes2, 3, 1, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + char endbytes[] = {'a', 'a', '\xA7'}; + out_str = substr_utf8_int64_int64(ctx_ptr, endbytes, 3, 1, 1, &out_len); + 
EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + char endbytes2[] = {'a', 'a', '\xC3'}; + out_str = substr_utf8_int64_int64(ctx_ptr, endbytes2, 3, 1, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + out_str = substr_utf8_int64_int64(ctx_ptr, "çåå†", 9, 2147483656, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestGdvFnStubs, TestCastVarbinaryUtf8) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + int32_t out_len = 0; + const char* input = "abc"; + const char* out; + + out = castVARBINARY_utf8_int64(ctx_ptr, input, 3, 0, &out_len); + EXPECT_EQ(std::string(out, out_len), input); + + out = castVARBINARY_utf8_int64(ctx_ptr, input, 3, 1, &out_len); + EXPECT_EQ(std::string(out, out_len), "a"); + + out = castVARBINARY_utf8_int64(ctx_ptr, input, 3, 500, &out_len); + EXPECT_EQ(std::string(out, out_len), input); + + out = castVARBINARY_utf8_int64(ctx_ptr, input, 3, -10, &out_len); + EXPECT_EQ(std::string(out, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Output buffer length can't be negative")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVarbinaryBinary) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx); + int32_t out_len = 0; + const char* input = "\\x41\\x42\\x43"; + const char* out; + + out = castVARBINARY_binary_int64(ctx_ptr, input, 12, 0, &out_len); + EXPECT_EQ(std::string(out, out_len), input); + + out = castVARBINARY_binary_int64(ctx_ptr, input, 8, 8, &out_len); + EXPECT_EQ(std::string(out, out_len), "\\x41\\x42"); + + out = castVARBINARY_binary_int64(ctx_ptr, input, 12, 500, &out_len); + EXPECT_EQ(std::string(out, out_len), input); + + out = castVARBINARY_binary_int64(ctx_ptr, input, 12, -10, &out_len); + EXPECT_EQ(std::string(out, out_len), ""); + EXPECT_THAT(ctx.get_error(), + 
::testing::HasSubstr("Output buffer length can't be negative")); + ctx.Reset(); +} + +TEST(TestStringOps, TestConcat) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str = + concat_utf8_utf8(ctx_ptr, "abcd", 4, true, "\npq", 3, false, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8(ctx_ptr, "asdf", 4, "jkl", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdfjkl"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8(ctx_ptr, "asdf", 4, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "asdf"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8(ctx_ptr, "", 0, "jkl", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "jkl"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8(ctx_ptr, "", 0, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8(ctx_ptr, "abcd\n", 5, "a", 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcd\na"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concat_utf8_utf8_utf8(ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, + true, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "\npqard"); + EXPECT_FALSE(ctx.has_error()); + + out_str = + concatOperator_utf8_utf8_utf8(ctx_ptr, "abcd\n", 5, "a", 1, "bcd", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcd\nabcd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8_utf8(ctx_ptr, "abcd", 4, "a", 1, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abcda"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8_utf8(ctx_ptr, "", 0, "a", 1, "pqrs", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "apqrs"); + EXPECT_FALSE(ctx.has_error()); + + out_str = 
concat_utf8_utf8_utf8_utf8(ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", + 3, true, "uvw", 3, false, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "\npqard"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8_utf8_utf8(ctx_ptr, "pqrs", 4, "", 0, "\nabc", 4, "y", + 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "pqrs\nabcy"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concat_utf8_utf8_utf8_utf8_utf8(ctx_ptr, "abcd", 4, false, "\npq", 3, true, + "ard", 3, true, "uvw", 3, false, "abc\n", 4, + true, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\n"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8_utf8_utf8_utf8(ctx_ptr, "pqrs", 4, "", 0, "\nabc", 4, + "y", 1, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "pqrs\nabcy"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8( + ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false, + "abc\n", 4, true, "sdfgs", 5, true, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgs"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8( + ctx_ptr, "pqrs", 4, "", 0, "\nabc", 4, "y", 1, "", 0, "\nbcd", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "pqrs\nabcy\nbcd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false, + "abc\n", 4, true, "sdfgs", 5, true, "wfw", 3, false, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgs"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + ctx_ptr, "", 0, "pqrs", 4, "abc\n", 4, "y", 1, "", 0, "asdf", 4, "jkl", 3, + &out_len); + EXPECT_EQ(std::string(out_str, out_len), "pqrsabc\nyasdfjkl"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + 
ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false, + "abc\n", 4, true, "sdfgs", 5, true, "wfw", 3, false, "", 0, true, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgs"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + ctx_ptr, "", 0, "pqrs", 4, "abc\n", 4, "y", 1, "", 0, "asdf", 4, "jkl", 3, "", 0, + &out_len); + EXPECT_EQ(std::string(out_str, out_len), "pqrsabc\nyasdfjkl"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false, + "abc\n", 4, true, "sdfgs", 5, true, "wfw", 3, false, "", 0, true, "qwert|n", 7, + true, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgsqwert|n"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + ctx_ptr, "", 0, "pqrs", 4, "abc\n", 4, "y", 1, "", 0, "asdf", 4, "jkl", 3, "", 0, + "sfl\n", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "pqrsabc\nyasdfjklsfl\n"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + ctx_ptr, "abcd", 4, false, "\npq", 3, true, "ard", 3, true, "uvw", 3, false, + "abc\n", 4, true, "sdfgs", 5, true, "wfw", 3, false, "", 0, true, "qwert|n", 7, + true, "ewfwe", 5, false, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "\npqardabc\nsdfgsqwert|n"); + EXPECT_FALSE(ctx.has_error()); + + out_str = concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + ctx_ptr, "", 0, "pqrs", 4, "abc\n", 4, "y", 1, "", 0, "asdf", 4, "", 0, "jkl", 3, + "sfl\n", 4, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "pqrsabc\nyasdfjklsfl\n"); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestStringOps, TestReverse) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const 
char* out_str; + out_str = reverse_utf8(ctx_ptr, "TestString", 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "gnirtStseT"); + EXPECT_FALSE(ctx.has_error()); + + out_str = reverse_utf8(ctx_ptr, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = reverse_utf8(ctx_ptr, "çåå†", 9, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "†ååç"); + EXPECT_FALSE(ctx.has_error()); + + std::string d("aa\xc3"); + out_str = reverse_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\c3 encountered while decoding utf8 string")); + ctx.Reset(); +} + +TEST(TestStringOps, TestLtrim) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + out_str = ltrim_utf8(ctx_ptr, "TestString ", 12, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString "); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8(ctx_ptr, " TestString ", 18, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString "); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8(ctx_ptr, " Test çåå†bD", 18, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test çåå†bD"); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8(ctx_ptr, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8(ctx_ptr, " ", 6, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8_utf8(ctx_ptr, "", 0, "TestString", 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8_utf8(ctx_ptr, "TestString", 10, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = 
ltrim_utf8_utf8(ctx_ptr, "abcbbaccabbcdef", 15, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "def"); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8_utf8(ctx_ptr, "abcbbaccabbcdef", 15, "ababbac", 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "def"); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8_utf8(ctx_ptr, "ååçåå†eç†Dd", 21, "çåå†", 9, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "eç†Dd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8_utf8(ctx_ptr, "ç†ååçåå†", 18, "çåå†", 9, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + std::string d( + "aa\xc3" + "bcd"); + out_str = + ltrim_utf8_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), "a", 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), + "\xc3" + "bcd"); + EXPECT_FALSE(ctx.has_error()); + + std::string e( + "åå\xe0\xa0" + "bcd"); + out_str = + ltrim_utf8_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), "å", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), + "\xE0\xa0" + "bcd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8_utf8(ctx_ptr, "TestString", 10, "abcd", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = ltrim_utf8_utf8(ctx_ptr, "acbabbcabb", 10, "abcbd", 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestStringOps, TestLpadString) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + // LPAD function tests - with defined fill pad text + out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 4, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 10, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + + out_str = 
lpad_utf8_int32_utf8(ctx_ptr, "TestString", 0, 10, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 0, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, -500, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 500, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 18, "Fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "FillFillTestString"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 15, "Fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "FillFTestString"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 20, "Fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "FillFillFiTestString"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "абвгд", 10, 7, "д", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ддабвгд"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "абвгд", 10, 20, "абвгд", 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "абвгдабвгдабвгдабвгд"); + + out_str = lpad_utf8_int32_utf8(ctx_ptr, "hello", 5, 6, "д", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "дhello"); + + // LPAD function tests - with NO pad text + out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test"); + + out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + + out_str = lpad_utf8_int32(ctx_ptr, "TestString", 0, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = lpad_utf8_int32(ctx_ptr, 
"TestString", 10, -500, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 18, &out_len); + EXPECT_EQ(std::string(out_str, out_len), " TestString"); + + out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 15, &out_len); + EXPECT_EQ(std::string(out_str, out_len), " TestString"); + + out_str = lpad_utf8_int32(ctx_ptr, "абвгд", 10, 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), " абвгд"); +} + +TEST(TestStringOps, TestRpadString) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + // RPAD function tests - with defined fill pad text + out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 4, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 10, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 0, 10, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 0, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, -500, "fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 500, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 18, "Fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestStringFillFill"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 15, "Fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestStringFillF"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "TestString", 10, 20, "Fill", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestStringFillFillFi"); + 
+ out_str = rpad_utf8_int32_utf8(ctx_ptr, "абвгд", 10, 7, "д", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "абвгддд"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "абвгд", 10, 20, "абвгд", 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "абвгдабвгдабвгдабвгд"); + + out_str = rpad_utf8_int32_utf8(ctx_ptr, "hello", 5, 6, "д", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "helloд"); + + // RPAD function tests - with NO pad text + out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test"); + + out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + + out_str = rpad_utf8_int32(ctx_ptr, "TestString", 0, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, -500, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 18, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString "); + + out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 15, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString "); + + out_str = rpad_utf8_int32(ctx_ptr, "абвгд", 10, 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "абвгд "); +} + +TEST(TestStringOps, TestRtrim) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + out_str = rtrim_utf8(ctx_ptr, " TestString", 12, &out_len); + EXPECT_EQ(std::string(out_str, out_len), " TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8(ctx_ptr, " TestString ", 18, &out_len); + EXPECT_EQ(std::string(out_str, out_len), " TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8(ctx_ptr, "Test çåå†bD ", 20, 
&out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test çåå†bD"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8(ctx_ptr, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8(ctx_ptr, " ", 6, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8_utf8(ctx_ptr, "", 0, "TestString", 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8_utf8(ctx_ptr, "TestString", 10, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8_utf8(ctx_ptr, "TestString", 10, "ring", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestSt"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8_utf8(ctx_ptr, "defabcbbaccabbc", 15, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "def"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8_utf8(ctx_ptr, "defabcbbaccabbc", 15, "ababbac", 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "def"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8_utf8(ctx_ptr, "eDdç†ååçåå†", 21, "çåå†", 9, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "eDd"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8_utf8(ctx_ptr, "ç†ååçåå†", 18, "çåå†", 9, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + std::string d( + "\xc3" + "aaa"); + out_str = + rtrim_utf8_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), "a", 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + std::string e( + "\xe0\xa0" + "åå"); + out_str = + rtrim_utf8_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), "å", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + out_str = 
rtrim_utf8_utf8(ctx_ptr, "åeçå", 7, "çå", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "åe"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8_utf8(ctx_ptr, "TestString", 10, "abcd", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = rtrim_utf8_utf8(ctx_ptr, "acbabbcabb", 10, "abcbd", 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestStringOps, TestBtrim) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + out_str = btrim_utf8(ctx_ptr, "TestString", 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8(ctx_ptr, " TestString ", 18, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8(ctx_ptr, " Test çåå†bD ", 21, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test çåå†bD"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8(ctx_ptr, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8(ctx_ptr, " ", 6, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, "", 0, "TestString", 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, "TestString", 10, "Test", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "String"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, "TestString", 10, "String", 6, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Tes"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, "TestString", 10, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), 
"TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, "abcbbadefccabbc", 15, "abc", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "def"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, "abcbbadefccabbc", 15, "ababbac", 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "def"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, "ååçåå†Ddeç†", 21, "çåå†", 9, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Dde"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, "ç†ååçåå†", 18, "çåå†", 9, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + ctx.Reset(); + + std::string d( + "acd\xc3" + "aaa"); + out_str = + btrim_utf8_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), "a", 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + std::string e( + "åbc\xe0\xa0" + "åå"); + out_str = + btrim_utf8_utf8(ctx_ptr, e.data(), static_cast<int>(e.length()), "å", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_TRUE(ctx.has_error()); + ctx.Reset(); + + std::string f( + "aa\xc3" + "bcd"); + out_str = + btrim_utf8_utf8(ctx_ptr, f.data(), static_cast<int>(f.length()), "a", 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), + "\xc3" + "bcd"); + EXPECT_FALSE(ctx.has_error()); + + std::string g( + "åå\xe0\xa0" + "bcå"); + out_str = + btrim_utf8_utf8(ctx_ptr, g.data(), static_cast<int>(g.length()), "å", 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), + "\xe0\xa0" + "bc"); + + out_str = btrim_utf8_utf8(ctx_ptr, "åe†çå", 10, "çå", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "e†"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, "TestString", 10, "abcd", 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = btrim_utf8_utf8(ctx_ptr, 
"acbabbcabb", 10, "abcbd", 5, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestStringOps, TestLocate) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + + int pos; + + pos = locate_utf8_utf8(ctx_ptr, "String", 6, "TestString", 10); + EXPECT_EQ(pos, 5); + EXPECT_FALSE(ctx.has_error()); + + pos = locate_utf8_utf8_int32(ctx_ptr, "String", 6, "TestString", 10, 1); + EXPECT_EQ(pos, 5); + EXPECT_FALSE(ctx.has_error()); + + pos = locate_utf8_utf8_int32(ctx_ptr, "abc", 3, "abcabc", 6, 2); + EXPECT_EQ(pos, 4); + EXPECT_FALSE(ctx.has_error()); + + pos = locate_utf8_utf8(ctx_ptr, "çåå", 6, "s†å†emçåå†d", 21); + EXPECT_EQ(pos, 7); + EXPECT_FALSE(ctx.has_error()); + + pos = locate_utf8_utf8_int32(ctx_ptr, "bar", 3, "†barbar", 9, 3); + EXPECT_EQ(pos, 5); + EXPECT_FALSE(ctx.has_error()); + + pos = locate_utf8_utf8_int32(ctx_ptr, "sub", 3, "", 0, 1); + EXPECT_EQ(pos, 0); + EXPECT_FALSE(ctx.has_error()); + + pos = locate_utf8_utf8_int32(ctx_ptr, "", 0, "str", 3, 1); + EXPECT_EQ(pos, 0); + EXPECT_FALSE(ctx.has_error()); + + pos = locate_utf8_utf8_int32(ctx_ptr, "bar", 3, "barbar", 6, 0); + EXPECT_EQ(pos, 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Start position must be greater than 0")); + ctx.Reset(); + + pos = locate_utf8_utf8_int32(ctx_ptr, "bar", 3, "barbar", 6, 7); + EXPECT_EQ(pos, 0); + EXPECT_FALSE(ctx.has_error()); + + std::string d( + "a\xff" + "c"); + pos = + locate_utf8_utf8_int32(ctx_ptr, "c", 1, d.data(), static_cast<int>(d.length()), 3); + EXPECT_EQ(pos, 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\ff encountered while decoding utf8 string")); + ctx.Reset(); +} + +TEST(TestStringOps, TestByteSubstr) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str; + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 5, 10, 
&out_len); + EXPECT_EQ(std::string(out_str, out_len), "String"); + EXPECT_FALSE(ctx.has_error()); + + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, -6, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "String"); + EXPECT_FALSE(ctx.has_error()); + + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 0, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 0, -500, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 1, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 1, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test"); + EXPECT_FALSE(ctx.has_error()); + + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 1, 1000, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 5, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Str"); + EXPECT_FALSE(ctx.has_error()); + + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, 5, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "String"); + EXPECT_FALSE(ctx.has_error()); + + out_str = byte_substr_binary_int32_int32(ctx_ptr, "TestString", 10, -100, 10, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); +} + +TEST(TestStringOps, TestStrPos) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + + int pos; + + pos = strpos_utf8_utf8(ctx_ptr, "TestString", 10, "String", 6); + EXPECT_EQ(pos, 5); + EXPECT_FALSE(ctx.has_error()); + + pos = strpos_utf8_utf8(ctx_ptr, 
"TestString", 10, "String", 6); + EXPECT_EQ(pos, 5); + EXPECT_FALSE(ctx.has_error()); + + pos = strpos_utf8_utf8(ctx_ptr, "abcabc", 6, "abc", 3); + EXPECT_EQ(pos, 1); + EXPECT_FALSE(ctx.has_error()); + + pos = strpos_utf8_utf8(ctx_ptr, "s†å†emçåå†d", 21, "çåå", 6); + EXPECT_EQ(pos, 7); + EXPECT_FALSE(ctx.has_error()); + + pos = strpos_utf8_utf8(ctx_ptr, "†barbar", 9, "bar", 3); + EXPECT_EQ(pos, 2); + EXPECT_FALSE(ctx.has_error()); + + pos = strpos_utf8_utf8(ctx_ptr, "", 0, "sub", 3); + EXPECT_EQ(pos, 0); + EXPECT_FALSE(ctx.has_error()); + + pos = strpos_utf8_utf8(ctx_ptr, "str", 3, "", 0); + EXPECT_EQ(pos, 0); + EXPECT_FALSE(ctx.has_error()); + + std::string d( + "a\xff" + "c"); + pos = strpos_utf8_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), "c", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr( + "unexpected byte \\ff encountered while decoding utf8 string")); + ctx.Reset(); +} + +TEST(TestStringOps, TestReplace) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + + const char* out_str; + out_str = replace_utf8_utf8_utf8(ctx_ptr, "TestString1String2", 18, "String", 6, + "Replace", 7, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestReplace1Replace2"); + EXPECT_FALSE(ctx.has_error()); + + out_str = + replace_utf8_utf8_utf8(ctx_ptr, "TestString1", 11, "String", 6, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test1"); + EXPECT_FALSE(ctx.has_error()); + + out_str = replace_utf8_utf8_utf8(ctx_ptr, "", 0, "test", 4, "rep", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = replace_utf8_utf8_utf8(ctx_ptr, "Ç††çåå†", 17, "†", 3, "t", 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Çttçååt"); + EXPECT_FALSE(ctx.has_error()); + + out_str = replace_utf8_utf8_utf8(ctx_ptr, "TestString", 10, "", 0, "rep", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + 
EXPECT_FALSE(ctx.has_error()); + + out_str = + replace_utf8_utf8_utf8(ctx_ptr, "Test", 4, "TestString", 10, "rep", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "Test"); + EXPECT_FALSE(ctx.has_error()); + + out_str = replace_utf8_utf8_utf8(ctx_ptr, "Test", 4, "Test", 4, "", 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_FALSE(ctx.has_error()); + + out_str = + replace_utf8_utf8_utf8(ctx_ptr, "TestString", 10, "abc", 3, "xyz", 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString"); + EXPECT_FALSE(ctx.has_error()); + + replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "Hell", 4, "ell", 3, "ollow", 5, 5, + &out_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string")); + ctx.Reset(); + + replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "eeee", 4, "e", 1, "aaaa", 4, 14, + &out_len); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string")); + ctx.Reset(); +} + +TEST(TestStringOps, TestLeftString) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + out_str = left_utf8_int32(ctx_ptr, "TestString", 10, 10, &out_len); + std::string output = std::string(out_str, out_len); + EXPECT_EQ(output, "TestString"); + + out_str = left_utf8_int32(ctx_ptr, "", 0, 0, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, ""); + + out_str = left_utf8_int32(ctx_ptr, "", 0, 500, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, ""); + + out_str = left_utf8_int32(ctx_ptr, "TestString", 10, 3, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "Tes"); + + out_str = left_utf8_int32(ctx_ptr, "TestString", 10, -3, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "TestStr"); + + // the text length for this string is 10 (each utf8 char is represented by two bytes) + out_str = left_utf8_int32(ctx_ptr, "абвгд", 10, 
3, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "абв"); +} + +TEST(TestStringOps, TestRightString) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + out_str = right_utf8_int32(ctx_ptr, "TestString", 10, 10, &out_len); + std::string output = std::string(out_str, out_len); + EXPECT_EQ(output, "TestString"); + + out_str = right_utf8_int32(ctx_ptr, "", 0, 0, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, ""); + + out_str = right_utf8_int32(ctx_ptr, "", 0, 500, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, ""); + + out_str = right_utf8_int32(ctx_ptr, "TestString", 10, 3, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "ing"); + + out_str = right_utf8_int32(ctx_ptr, "TestString", 10, -3, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "tString"); + + // the text length for this string is 10 (each utf8 char is represented by two bytes) + out_str = right_utf8_int32(ctx_ptr, "абвгд", 10, 3, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "вгд"); +} + +TEST(TestStringOps, TestBinaryString) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + out_str = binary_string(ctx_ptr, "TestString", 10, &out_len); + std::string output = std::string(out_str, out_len); + EXPECT_EQ(output, "TestString"); + + out_str = binary_string(ctx_ptr, "", 0, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, ""); + + out_str = binary_string(ctx_ptr, "T", 1, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "T"); + + out_str = binary_string(ctx_ptr, "\\x41\\x42\\x43", 12, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "ABC"); + + out_str = binary_string(ctx_ptr, "\\x41", 4, &out_len); + output = 
std::string(out_str, out_len); + EXPECT_EQ(output, "A"); + + out_str = binary_string(ctx_ptr, "\\x6d\\x6D", 8, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "mm"); + + out_str = binary_string(ctx_ptr, "\\x6f\\x6d", 8, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "om"); + + out_str = binary_string(ctx_ptr, "\\x4f\\x4D", 8, &out_len); + output = std::string(out_str, out_len); + EXPECT_EQ(output, "OM"); +} + +TEST(TestStringOps, TestSplitPart) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + out_str = split_part(ctx_ptr, "A,B,C", 5, ",", 1, 0, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Index in split_part must be positive, value provided was 0")); + + out_str = split_part(ctx_ptr, "A,B,C", 5, ",", 1, 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "A"); + + out_str = split_part(ctx_ptr, "A,B,C", 5, ",", 1, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "B"); + + out_str = split_part(ctx_ptr, "A,B,C", 5, ",", 1, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "C"); + + out_str = split_part(ctx_ptr, "abc~@~def~@~ghi", 15, "~@~", 3, 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "abc"); + + out_str = split_part(ctx_ptr, "abc~@~def~@~ghi", 15, "~@~", 3, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "def"); + + out_str = split_part(ctx_ptr, "abc~@~def~@~ghi", 15, "~@~", 3, 3, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ghi"); + + // Result must be empty when the index is > no of elements + out_str = split_part(ctx_ptr, "123|456|789", 11, "|", 1, 4, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = split_part(ctx_ptr, "123|", 4, "|", 1, 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "123"); + + out_str = split_part(ctx_ptr, "|123", 4, "|", 1, 1, &out_len); + 
EXPECT_EQ(std::string(out_str, out_len), ""); + + out_str = split_part(ctx_ptr, "ç†ååçåå†", 18, "å", 2, 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ç†"); + + out_str = split_part(ctx_ptr, "ç†ååçåå†", 18, "†åå", 6, 1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ç"); + + out_str = split_part(ctx_ptr, "ç†ååçåå†", 18, "†", 3, 2, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "ååçåå"); +} + +TEST(TestStringOps, TestConvertTo) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + gdv_int32 out_len = 0; + const char* out_str; + + const int32_t ALL_BYTES_MATCH = 0; + + int32_t integer_value = std::numeric_limits<int32_t>::max(); + out_str = convert_toINT(ctx_ptr, integer_value, &out_len); + EXPECT_EQ(out_len, sizeof(integer_value)); + EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, &integer_value, out_len)); + + int64_t big_integer_value = std::numeric_limits<int64_t>::max(); + out_str = convert_toBIGINT(ctx_ptr, big_integer_value, &out_len); + EXPECT_EQ(out_len, sizeof(big_integer_value)); + EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, &big_integer_value, out_len)); + + float float_value = std::numeric_limits<float>::max(); + out_str = convert_toFLOAT(ctx_ptr, float_value, &out_len); + EXPECT_EQ(out_len, sizeof(float_value)); + EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, &float_value, out_len)); + + double double_value = std::numeric_limits<double>::max(); + out_str = convert_toDOUBLE(ctx_ptr, double_value, &out_len); + EXPECT_EQ(out_len, sizeof(double_value)); + EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, &double_value, out_len)); + + const char* test_string = "test string"; + int32_t str_len = 11; + out_str = convert_toUTF8(ctx_ptr, test_string, str_len, &out_len); + EXPECT_EQ(out_len, str_len); + EXPECT_EQ(ALL_BYTES_MATCH, memcmp(out_str, test_string, out_len)); +} + +TEST(TestStringOps, TestConvertToBigEndian) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx); + 
gdv_int32 out_len = 0; + gdv_int32 out_len_big_endian = 0; + const char* out_str; + const char* out_str_big_endian; + + int64_t big_integer_value = std::numeric_limits<int64_t>::max(); + out_str = convert_toBIGINT(ctx_ptr, big_integer_value, &out_len); + out_str_big_endian = + convert_toBIGINT_be(ctx_ptr, big_integer_value, &out_len_big_endian); + EXPECT_EQ(out_len_big_endian, sizeof(big_integer_value)); + EXPECT_EQ(out_len_big_endian, out_len); + +#if ARROW_LITTLE_ENDIAN + // Checks that bytes are in reverse order + for (auto i = 0; i < out_len; i++) { + EXPECT_EQ(out_str[i], out_str_big_endian[out_len - (i + 1)]); + } +#else + for (auto i = 0; i < out_len; i++) { + EXPECT_EQ(out_str[i], out_str_big_endian[i]); + } +#endif + + double double_value = std::numeric_limits<double>::max(); + out_str = convert_toDOUBLE(ctx_ptr, double_value, &out_len); + out_str_big_endian = convert_toDOUBLE_be(ctx_ptr, double_value, &out_len_big_endian); + EXPECT_EQ(out_len_big_endian, sizeof(double_value)); + EXPECT_EQ(out_len_big_endian, out_len); + +#if ARROW_LITTLE_ENDIAN + // Checks that bytes are in reverse order + for (auto i = 0; i < out_len; i++) { + EXPECT_EQ(out_str[i], out_str_big_endian[out_len - (i + 1)]); + } +#else + for (auto i = 0; i < out_len; i++) { + EXPECT_EQ(out_str[i], out_str_big_endian[i]); + } +#endif +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/testing.h b/src/arrow/cpp/src/gandiva/precompiled/testing.h new file mode 100644 index 000000000..c41bc5471 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/testing.h @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <ctime> +#include <string> + +#include <gtest/gtest.h> + +#include "arrow/util/logging.h" +#include "arrow/util/value_parsing.h" + +#include "gandiva/date_utils.h" +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +static inline gdv_timestamp StringToTimestamp(const std::string& s) { + int64_t out = 0; + bool success = ::arrow::internal::ParseTimestampStrptime( + s.c_str(), s.length(), "%Y-%m-%d %H:%M:%S", /*ignore_time_in_day=*/false, + /*allow_trailing_chars=*/false, ::arrow::TimeUnit::SECOND, &out); + DCHECK(success); + ARROW_UNUSED(success); + return out * 1000; +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/time.cc b/src/arrow/cpp/src/gandiva/precompiled/time.cc new file mode 100644 index 000000000..336f69226 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/time.cc @@ -0,0 +1,894 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "./epoch_time_point.h" + +extern "C" { + +#define __STDC_FORMAT_MACROS +#include <inttypes.h> +#include <stdlib.h> +#include <string.h> +#include <time.h> + +#include "./time_constants.h" +#include "./time_fields.h" +#include "./types.h" + +#define MINS_IN_HOUR 60 +#define SECONDS_IN_MINUTE 60 +#define SECONDS_IN_HOUR (SECONDS_IN_MINUTE) * (MINS_IN_HOUR) + +#define HOURS_IN_DAY 24 + +// Expand inner macro for all date types. +#define DATE_TYPES(INNER) \ + INNER(date64) \ + INNER(timestamp) + +// Expand inner macro for all base numeric types. +#define NUMERIC_TYPES(INNER) \ + INNER(int8) \ + INNER(int16) \ + INNER(int32) \ + INNER(int64) \ + INNER(uint8) \ + INNER(uint16) \ + INNER(uint32) \ + INNER(uint64) \ + INNER(float32) \ + INNER(float64) + +// Extract millennium +#define EXTRACT_MILLENNIUM(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractMillennium##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return (1900 + tp.TmYear() - 1) / 1000 + 1; \ + } + +DATE_TYPES(EXTRACT_MILLENNIUM) + +// Extract century +#define EXTRACT_CENTURY(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractCentury##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return (1900 + tp.TmYear() - 1) / 100 + 1; \ + } + +DATE_TYPES(EXTRACT_CENTURY) + +// Extract decade +#define EXTRACT_DECADE(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractDecade##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return (1900 + tp.TmYear()) / 10; \ + } + +DATE_TYPES(EXTRACT_DECADE) + +// Extract year. 
+#define EXTRACT_YEAR(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractYear##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return 1900 + tp.TmYear(); \ + } + +DATE_TYPES(EXTRACT_YEAR) + +#define EXTRACT_DOY(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractDoy##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return 1 + tp.TmYday(); \ + } + +DATE_TYPES(EXTRACT_DOY) + +#define EXTRACT_QUARTER(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractQuarter##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return tp.TmMon() / 3 + 1; \ + } + +DATE_TYPES(EXTRACT_QUARTER) + +#define EXTRACT_MONTH(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractMonth##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return 1 + tp.TmMon(); \ + } + +DATE_TYPES(EXTRACT_MONTH) + +#define JAN1_WDAY(tp) ((tp.TmWday() - (tp.TmYday() % 7) + 7) % 7) + +bool IsLeapYear(int yy) { + if ((yy % 4) != 0) { + // not divisible by 4 + return false; + } + + // yy = 4x + if ((yy % 400) == 0) { + // yy = 400x + return true; + } + + // yy = 4x, return true if yy != 100x + return ((yy % 100) != 0); +} + +// Day belongs to current year +// Note that TmYday is 0 for Jan 1 (subtract 1 from day in the below examples) +// +// If Jan 1 is Mon, (TmYday) / 7 + 1 (Jan 1->WK1, Jan 8->WK2, etc) +// If Jan 1 is Tues, (TmYday + 1) / 7 + 1 (Jan 1->WK1, Jan 7->WK2, etc) +// If Jan 1 is Wed, (TmYday + 2) / 7 + 1 +// If Jan 1 is Thu, (TmYday + 3) / 7 + 1 +// +// If Jan 1 is Fri, Sat or Sun, the first few days belong to the previous year +// If Jan 1 is Fri, (TmYday - 3) / 7 + 1 (Jan 4->WK1, Jan 11->WK2) +// If Jan 1 is Sat, (TmYday - 2) / 7 + 1 (Jan 3->WK1, Jan 10->WK2) +// If Jan 1 is Sun, (TmYday - 1) / 7 + 1 (Jan 2->WK1, Jan 9->WK2) +int weekOfCurrentYear(const EpochTimePoint& tp) { + int jan1_wday = JAN1_WDAY(tp); + switch (jan1_wday) { + // Monday + case 1: + // Tuesday + case 2: + // Wednesday + case 3: + // Thursday + case 4: { + return (tp.TmYday() + jan1_wday - 1) / 7 + 1; + } + // 
Friday + case 5: + // Saturday + case 6: { + return (tp.TmYday() - (8 - jan1_wday)) / 7 + 1; + } + // Sunday + case 0: { + return (tp.TmYday() - 1) / 7 + 1; + } + } + + // cannot reach here + // keep compiler happy + return 0; +} + +// Jan 1-3 +// If Jan 1 is one of Mon, Tue, Wed, Thu - belongs to week of current year +// If Jan 1 is Fri/Sat/Sun - belongs to previous year +int getJanWeekOfYear(const EpochTimePoint& tp) { + int jan1_wday = JAN1_WDAY(tp); + + if ((jan1_wday >= 1) && (jan1_wday <= 4)) { + // Jan 1-3 with the week belonging to this year + return 1; + } + + if (jan1_wday == 5) { + // Jan 1 is a Fri + // Jan 1-3 belong to previous year. Dec 31 of previous year same week # as Jan 1-3 + // previous year is a leap year: + // Prev Jan 1 is a Wed. Jan 6th is Mon + // Dec 31 - Jan 6 = 366 - 5 = 361 + // week from Jan 6 = (361 - 1) / 7 + 1 = 52 + // week # in previous year = 52 + 1 = 53 + // + // previous year is not a leap year. Jan 1 is Thu. Jan 5th is Mon + // Dec 31 - Jan 5 = 365 - 4 = 361 + // week from Jan 5 = (361 - 1) / 7 + 1 = 52 + // week # in previous year = 52 + 1 = 53 + return 53; + } + + if (jan1_wday == 0) { + // Jan 1 is a Sun + if (tp.TmMday() > 1) { + // Jan 2 and 3 belong to current year + return 1; + } + + // day belongs to previous year. 
Same as Dec 31 + // Same as the case where Jan 1 is a Fri, except that previous year + // does not have an extra week + // Hence, return 52 + return 52; + } + + // Jan 1 is a Sat + // Jan 1-2 belong to previous year + if (tp.TmMday() == 3) { + // Jan 3, return 1 + return 1; + } + + // prev Jan 1 is leap year + // prev Jan 1 is a Thu + // return 53 (extra week) + if (IsLeapYear(1900 + tp.TmYear() - 1)) { + return 53; + } + + // prev Jan 1 is not a leap year + // prev Jan 1 is a Fri + // return 52 (no extra week) + return 52; +} + +// Dec 29-31 +int getDecWeekOfYear(const EpochTimePoint& tp) { + int next_jan1_wday = (tp.TmWday() + (31 - tp.TmMday()) + 1) % 7; + + if (next_jan1_wday == 4) { + // next Jan 1 is a Thu + // day belongs to week 1 of next year + return 1; + } + + if (next_jan1_wday == 3) { + // next Jan 1 is a Wed + // Dec 31 and 30 belong to next year - return 1 + if (tp.TmMday() != 29) { + return 1; + } + + // Dec 29 belongs to current year + return weekOfCurrentYear(tp); + } + + if (next_jan1_wday == 2) { + // next Jan 1 is a Tue + // Dec 31 belongs to next year - return 1 + if (tp.TmMday() == 31) { + return 1; + } + + // Dec 29 and 30 belong to current year + return weekOfCurrentYear(tp); + } + + // next Jan 1 is a Fri/Sat/Sun. No day from this year belongs to that week + // next Jan 1 is a Mon. 
No day from this year belongs to that week + return weekOfCurrentYear(tp); +} + +// Week of year is determined by ISO 8601 standard +// Take a look at: https://en.wikipedia.org/wiki/ISO_week_date +// +// Important points to note: +// Week starts with a Monday and ends with a Sunday +// A week can have some days in this year and some days in the previous/next year +// This is true for the first and last weeks +// +// The first week of the year should have at-least 4 days in the current year +// The last week of the year should have at-least 4 days in the current year +// +// A given day might belong to the first week of the next year - e.g Dec 29, 30 and 31 +// A given day might belong to the last week of the previous year - e.g. Jan 1, 2 and 3 +// +// Algorithm: +// If day belongs to week in current year, weekOfCurrentYear +// +// If day is Jan 1-3, see getJanWeekOfYear +// If day is Dec 29-21, see getDecWeekOfYear +// +gdv_int64 weekOfYear(const EpochTimePoint& tp) { + if (tp.TmYday() < 3) { + // Jan 1-3 + return getJanWeekOfYear(tp); + } + + if ((tp.TmMon() == 11) && (tp.TmMday() >= 29)) { + // Dec 29-31 + return getDecWeekOfYear(tp); + } + + return weekOfCurrentYear(tp); +} + +#define EXTRACT_WEEK(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractWeek##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return weekOfYear(tp); \ + } + +DATE_TYPES(EXTRACT_WEEK) + +#define EXTRACT_DOW(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractDow##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return 1 + tp.TmWday(); \ + } + +DATE_TYPES(EXTRACT_DOW) + +#define EXTRACT_DAY(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractDay##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return tp.TmMday(); \ + } + +DATE_TYPES(EXTRACT_DAY) + +#define EXTRACT_HOUR(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractHour##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return tp.TmHour(); \ + } + +DATE_TYPES(EXTRACT_HOUR) + +#define EXTRACT_MINUTE(TYPE) \ 
+ FORCE_INLINE \ + gdv_int64 extractMinute##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return tp.TmMin(); \ + } + +DATE_TYPES(EXTRACT_MINUTE) + +#define EXTRACT_SECOND(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractSecond##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return tp.TmSec(); \ + } + +DATE_TYPES(EXTRACT_SECOND) + +#define EXTRACT_EPOCH(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractEpoch##_##TYPE(gdv_##TYPE millis) { return MILLIS_TO_SEC(millis); } + +DATE_TYPES(EXTRACT_EPOCH) + +// Functions that work on millis in a day +#define EXTRACT_SECOND_TIME(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractSecond##_##TYPE(gdv_##TYPE millis) { \ + gdv_int64 seconds_of_day = MILLIS_TO_SEC(millis); \ + gdv_int64 sec = seconds_of_day % SECONDS_IN_MINUTE; \ + return sec; \ + } + +EXTRACT_SECOND_TIME(time32) + +#define EXTRACT_MINUTE_TIME(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractMinute##_##TYPE(gdv_##TYPE millis) { \ + gdv_##TYPE mins = MILLIS_TO_MINS(millis); \ + return (mins % (MINS_IN_HOUR)); \ + } + +EXTRACT_MINUTE_TIME(time32) + +#define EXTRACT_HOUR_TIME(TYPE) \ + FORCE_INLINE \ + gdv_int64 extractHour##_##TYPE(gdv_##TYPE millis) { return MILLIS_TO_HOUR(millis); } + +EXTRACT_HOUR_TIME(time32) + +#define DATE_TRUNC_FIXED_UNIT(NAME, TYPE, NMILLIS_IN_UNIT) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_##TYPE(gdv_##TYPE millis) { \ + return ((millis / NMILLIS_IN_UNIT) * NMILLIS_IN_UNIT); \ + } + +#define DATE_TRUNC_WEEK(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE date_trunc_Week_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + int ndays_to_trunc = 0; \ + if (tp.TmWday() == 0) { \ + /* Sunday */ \ + ndays_to_trunc = 6; \ + } else { \ + /* All other days */ \ + ndays_to_trunc = tp.TmWday() - 1; \ + } \ + return tp.AddDays(-ndays_to_trunc).ClearTimeOfDay().MillisSinceEpoch(); \ + } + +#define DATE_TRUNC_MONTH_UNITS(NAME, TYPE, NMONTHS_IN_UNIT) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint 
tp(millis); \ + int ndays_to_trunc = tp.TmMday() - 1; \ + int nmonths_to_trunc = \ + tp.TmMon() - ((tp.TmMon() / NMONTHS_IN_UNIT) * NMONTHS_IN_UNIT); \ + return tp.AddDays(-ndays_to_trunc) \ + .AddMonths(-nmonths_to_trunc) \ + .ClearTimeOfDay() \ + .MillisSinceEpoch(); \ + } + +#define DATE_TRUNC_YEAR_UNITS(NAME, TYPE, NYEARS_IN_UNIT, OFF_BY) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_##TYPE(gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + int ndays_to_trunc = tp.TmMday() - 1; \ + int nmonths_to_trunc = tp.TmMon(); \ + int year = 1900 + tp.TmYear(); \ + year = ((year - OFF_BY) / NYEARS_IN_UNIT) * NYEARS_IN_UNIT + OFF_BY; \ + int nyears_to_trunc = tp.TmYear() - (year - 1900); \ + return tp.AddDays(-ndays_to_trunc) \ + .AddMonths(-nmonths_to_trunc) \ + .AddYears(-nyears_to_trunc) \ + .ClearTimeOfDay() \ + .MillisSinceEpoch(); \ + } + +#define DATE_TRUNC_FUNCTIONS(TYPE) \ + DATE_TRUNC_FIXED_UNIT(date_trunc_Second, TYPE, MILLIS_IN_SEC) \ + DATE_TRUNC_FIXED_UNIT(date_trunc_Minute, TYPE, MILLIS_IN_MIN) \ + DATE_TRUNC_FIXED_UNIT(date_trunc_Hour, TYPE, MILLIS_IN_HOUR) \ + DATE_TRUNC_FIXED_UNIT(date_trunc_Day, TYPE, MILLIS_IN_DAY) \ + DATE_TRUNC_WEEK(TYPE) \ + DATE_TRUNC_MONTH_UNITS(date_trunc_Month, TYPE, 1) \ + DATE_TRUNC_MONTH_UNITS(date_trunc_Quarter, TYPE, 3) \ + DATE_TRUNC_MONTH_UNITS(date_trunc_Year, TYPE, 12) \ + DATE_TRUNC_YEAR_UNITS(date_trunc_Decade, TYPE, 10, 0) \ + DATE_TRUNC_YEAR_UNITS(date_trunc_Century, TYPE, 100, 1) \ + DATE_TRUNC_YEAR_UNITS(date_trunc_Millennium, TYPE, 1000, 1) + +DATE_TRUNC_FUNCTIONS(date64) +DATE_TRUNC_FUNCTIONS(timestamp) + +#define LAST_DAY_FUNC(TYPE) \ + FORCE_INLINE \ + gdv_date64 last_day_from_##TYPE(gdv_date64 millis) { \ + EpochTimePoint received_day(millis); \ + const auto& day_without_hours_and_sec = received_day.ClearTimeOfDay(); \ + \ + int received_day_in_month = day_without_hours_and_sec.TmMday(); \ + const auto& first_day_in_month = \ + day_without_hours_and_sec.AddDays(1 - received_day_in_month); \ + \ + const auto& 
month_last_day = first_day_in_month.AddMonths(1).AddDays(-1); \ + \ + return month_last_day.MillisSinceEpoch(); \ + } + +DATE_TYPES(LAST_DAY_FUNC) + +FORCE_INLINE +gdv_date64 castDATE_int64(gdv_int64 in) { return in; } + +FORCE_INLINE +gdv_date32 castDATE_int32(gdv_int32 in) { return in; } + +FORCE_INLINE +gdv_date64 castDATE_date32(gdv_date32 days) { + return days * static_cast<gdv_date64>(MILLIS_IN_DAY); +} + +static int days_in_month[] = {31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; + +bool IsLastDayOfMonth(const EpochTimePoint& tp) { + if (tp.TmMon() != 1) { + // not February. Don't worry about leap year + return (tp.TmMday() == days_in_month[tp.TmMon()]); + } + + // this is February, check if the day is 28 or 29 + if (tp.TmMday() < 28) { + return false; + } + + if (tp.TmMday() == 29) { + // Feb 29th + return true; + } + + // check if year is non-leap year + return !IsLeapYear(tp.TmYear()); +} + +FORCE_INLINE +bool is_valid_time(const int hours, const int minutes, const int seconds) { + return hours >= 0 && hours < 24 && minutes >= 0 && minutes < 60 && seconds >= 0 && + seconds < 60; +} + +// MONTHS_BETWEEN returns number of months between dates date1 and date2. +// If date1 is later than date2, then the result is positive. +// If date1 is earlier than date2, then the result is negative. +// If date1 and date2 are either the same days of the month or both last days of months, +// then the result is always an integer. 
Otherwise Oracle Database calculates the +// fractional portion of the result based on a 31-day month and considers the difference +// in time components date1 and date2 +#define MONTHS_BETWEEN(TYPE) \ + FORCE_INLINE \ + double months_between##_##TYPE##_##TYPE(uint64_t endEpoch, uint64_t startEpoch) { \ + EpochTimePoint endTime(endEpoch); \ + EpochTimePoint startTime(startEpoch); \ + int endYear = endTime.TmYear(); \ + int endMonth = endTime.TmMon(); \ + int startYear = startTime.TmYear(); \ + int startMonth = startTime.TmMon(); \ + int monthsDiff = (endYear - startYear) * 12 + (endMonth - startMonth); \ + if ((endTime.TmMday() == startTime.TmMday()) || \ + (IsLastDayOfMonth(endTime) && IsLastDayOfMonth(startTime))) { \ + return static_cast<double>(monthsDiff); \ + } \ + double diffDays = static_cast<double>(endTime.TmMday() - startTime.TmMday()) / \ + static_cast<double>(31); \ + double diffHours = static_cast<double>(endTime.TmHour() - startTime.TmHour()) + \ + static_cast<double>(endTime.TmMin() - startTime.TmMin()) / \ + static_cast<double>(MINS_IN_HOUR) + \ + static_cast<double>(endTime.TmSec() - startTime.TmSec()) / \ + static_cast<double>(SECONDS_IN_HOUR); \ + return static_cast<double>(monthsDiff) + diffDays + \ + diffHours / static_cast<double>(HOURS_IN_DAY * 31); \ + } + +DATE_TYPES(MONTHS_BETWEEN) + +FORCE_INLINE +void set_error_for_date(gdv_int32 length, const char* input, const char* msg, + int64_t execution_context) { + int size = length + static_cast<int>(strlen(msg)) + 1; + char* error = reinterpret_cast<char*>(malloc(size)); + snprintf(error, size, "%s%s", msg, input); + gdv_fn_context_set_error_msg(execution_context, error); + free(error); +} + +gdv_date64 castDATE_utf8(int64_t context, const char* input, gdv_int32 length) { + using arrow_vendored::date::day; + using arrow_vendored::date::month; + using arrow_vendored::date::sys_days; + using arrow_vendored::date::year; + using arrow_vendored::date::year_month_day; + using gandiva::TimeFields; + 
// format : 0 is year, 1 is month and 2 is day. + int dateFields[3]; + int dateIndex = 0, index = 0, value = 0; + int year_str_len = 0; + while (dateIndex < 3 && index < length) { + if (!isdigit(input[index])) { + dateFields[dateIndex++] = value; + value = 0; + } else { + value = (value * 10) + (input[index] - '0'); + if (dateIndex == TimeFields::kYear) { + year_str_len++; + } + } + index++; + } + + if (dateIndex < 3) { + // If we reached the end of input, we would have not encountered a separator + // store the last value + dateFields[dateIndex++] = value; + } + const char* msg = "Not a valid date value "; + if (dateIndex != 3) { + set_error_for_date(length, input, msg, context); + return 0; + } + + /* Handle two digit years + * If range of two digits is between 70 - 99 then year = 1970 - 1999 + * Else if two digits is between 00 - 69 = 2000 - 2069 + */ + if (dateFields[TimeFields::kYear] < 100 && year_str_len < 4) { + if (dateFields[TimeFields::kYear] < 70) { + dateFields[TimeFields::kYear] += 2000; + } else { + dateFields[TimeFields::kYear] += 1900; + } + } + year_month_day date = year(dateFields[TimeFields::kYear]) / + month(dateFields[TimeFields::kMonth]) / + day(dateFields[TimeFields::kDay]); + if (!date.ok()) { + set_error_for_date(length, input, msg, context); + return 0; + } + return std::chrono::time_point_cast<std::chrono::milliseconds>(sys_days(date)) + .time_since_epoch() + .count(); +} + +/* + * Input consists of mandatory and optional fields. + * Mandatory fields are year, month and day. + * Optional fields are time, displacement and zone. 
+ * Format is <year-month-day>[ hours:minutes:seconds][.millis][ displacement|zone] + */ +gdv_timestamp castTIMESTAMP_utf8(int64_t context, const char* input, gdv_int32 length) { + using arrow_vendored::date::day; + using arrow_vendored::date::month; + using arrow_vendored::date::sys_days; + using arrow_vendored::date::year; + using arrow_vendored::date::year_month_day; + using gandiva::TimeFields; + using std::chrono::hours; + using std::chrono::milliseconds; + using std::chrono::minutes; + using std::chrono::seconds; + + int ts_fields[9] = {0, 0, 0, 0, 0, 0, 0, 0, 0}; + gdv_boolean add_displacement = true; + gdv_boolean encountered_zone = false; + int year_str_len = 0, sub_seconds_len = 0; + int ts_field_index = TimeFields::kYear, index = 0, value = 0; + while (ts_field_index < TimeFields::kMax && index < length) { + if (isdigit(input[index])) { + value = (value * 10) + (input[index] - '0'); + if (ts_field_index == TimeFields::kYear) { + year_str_len++; + } + if (ts_field_index == TimeFields::kSubSeconds) { + sub_seconds_len++; + } + } else { + ts_fields[ts_field_index] = value; + value = 0; + + switch (input[index]) { + case '.': + case ':': + case ' ': + ts_field_index++; + break; + case '+': + // +08:00, means time zone is 8 hours ahead. Need to subtract. + add_displacement = false; + ts_field_index = TimeFields::kDisplacementHours; + break; + case '-': + // Overloaded as date separator and negative displacement. + ts_field_index = (ts_field_index < 3) ? 
(ts_field_index + 1) + : TimeFields::kDisplacementHours; + break; + default: + encountered_zone = true; + break; + } + } + if (encountered_zone) { + break; + } + index++; + } + + // Store the last value + if (ts_field_index < TimeFields::kMax) { + ts_fields[ts_field_index++] = value; + } + + // adjust the year + if (ts_fields[TimeFields::kYear] < 100 && year_str_len < 4) { + if (ts_fields[TimeFields::kYear] < 70) { + ts_fields[TimeFields::kYear] += 2000; + } else { + ts_fields[TimeFields::kYear] += 1900; + } + } + + // adjust the milliseconds + if (sub_seconds_len > 0) { + if (sub_seconds_len > 3) { + const char* msg = "Invalid millis for timestamp value "; + set_error_for_date(length, input, msg, context); + return 0; + } + while (sub_seconds_len < 3) { + ts_fields[TimeFields::kSubSeconds] *= 10; + sub_seconds_len++; + } + } + // handle timezone + if (encountered_zone) { + int err = 0; + gdv_timestamp ret_time = 0; + err = gdv_fn_time_with_zone(&ts_fields[0], (input + index), (length - index), + &ret_time); + if (err) { + const char* msg = "Invalid timestamp or unknown zone for timestamp value "; + set_error_for_date(length, input, msg, context); + return 0; + } + return ret_time; + } + + year_month_day date = year(ts_fields[TimeFields::kYear]) / + month(ts_fields[TimeFields::kMonth]) / + day(ts_fields[TimeFields::kDay]); + if (!date.ok()) { + const char* msg = "Not a valid day for timestamp value "; + set_error_for_date(length, input, msg, context); + return 0; + } + + if (!is_valid_time(ts_fields[TimeFields::kHours], ts_fields[TimeFields::kMinutes], + ts_fields[TimeFields::kSeconds])) { + const char* msg = "Not a valid time for timestamp value "; + set_error_for_date(length, input, msg, context); + return 0; + } + + auto date_time = sys_days(date) + hours(ts_fields[TimeFields::kHours]) + + minutes(ts_fields[TimeFields::kMinutes]) + + seconds(ts_fields[TimeFields::kSeconds]) + + milliseconds(ts_fields[TimeFields::kSubSeconds]); + if 
(ts_fields[TimeFields::kDisplacementHours] || + ts_fields[TimeFields::kDisplacementMinutes]) { + auto displacement_time = hours(ts_fields[TimeFields::kDisplacementHours]) + + minutes(ts_fields[TimeFields::kDisplacementMinutes]); + date_time = (add_displacement) ? (date_time + displacement_time) + : (date_time - displacement_time); + } + return std::chrono::time_point_cast<milliseconds>(date_time).time_since_epoch().count(); +} + +gdv_timestamp castTIMESTAMP_date64(gdv_date64 date_in_millis) { return date_in_millis; } + +gdv_timestamp castTIMESTAMP_int64(gdv_int64 in) { return in; } + +gdv_date64 castDATE_timestamp(gdv_timestamp timestamp_in_millis) { + EpochTimePoint tp(timestamp_in_millis); + return tp.ClearTimeOfDay().MillisSinceEpoch(); +} + +gdv_time32 castTIME_timestamp(gdv_timestamp timestamp_in_millis) { + // Retrieves a timestamp and returns the number of milliseconds since the midnight + EpochTimePoint tp(timestamp_in_millis); + auto tp_at_midnight = tp.ClearTimeOfDay(); + + int64_t millis_since_midnight = + tp.MillisSinceEpoch() - tp_at_midnight.MillisSinceEpoch(); + + return static_cast<int32_t>(millis_since_midnight); +} + +const char* castVARCHAR_timestamp_int64(gdv_int64 context, gdv_timestamp in, + gdv_int64 length, gdv_int32* out_len) { + gdv_int64 year = extractYear_timestamp(in); + gdv_int64 month = extractMonth_timestamp(in); + gdv_int64 day = extractDay_timestamp(in); + gdv_int64 hour = extractHour_timestamp(in); + gdv_int64 minute = extractMinute_timestamp(in); + gdv_int64 second = extractSecond_timestamp(in); + gdv_int64 millis = in % MILLIS_IN_SEC; + + static const int kTimeStampStringLen = 23; + const int char_buffer_length = kTimeStampStringLen + 1; // snprintf adds \0 + char char_buffer[char_buffer_length]; + + // yyyy-MM-dd hh:mm:ss.sss + int res = snprintf(char_buffer, char_buffer_length, + "%04" PRId64 "-%02" PRId64 "-%02" PRId64 " %02" PRId64 ":%02" PRId64 + ":%02" PRId64 ".%03" PRId64, + year, month, day, hour, minute, second, 
millis); + if (res < 0) { + gdv_fn_context_set_error_msg(context, "Could not format the timestamp"); + return ""; + } + + *out_len = static_cast<gdv_int32>(length); + if (*out_len > kTimeStampStringLen) { + *out_len = kTimeStampStringLen; + } + + if (*out_len <= 0) { + if (*out_len < 0) { + gdv_fn_context_set_error_msg(context, "Length of output string cannot be negative"); + } + *out_len = 0; + return ""; + } + + char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len)); + if (ret == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + + memcpy(ret, char_buffer, *out_len); + return ret; +} + +FORCE_INLINE +gdv_int64 extractDay_daytimeinterval(gdv_day_time_interval in) { + gdv_int32 days = static_cast<gdv_int32>(in & 0x00000000FFFFFFFF); + return static_cast<gdv_int64>(days); +} + +FORCE_INLINE +gdv_int64 extractMillis_daytimeinterval(gdv_day_time_interval in) { + gdv_int32 millis = static_cast<gdv_int32>((in & 0xFFFFFFFF00000000) >> 32); + return static_cast<gdv_int64>(millis); +} + +FORCE_INLINE +gdv_int64 castBIGINT_daytimeinterval(gdv_day_time_interval in) { + return extractMillis_daytimeinterval(in) + + extractDay_daytimeinterval(in) * MILLIS_IN_DAY; +} + +// Convert the seconds since epoch argument to timestamp +#define TO_TIMESTAMP(TYPE) \ + FORCE_INLINE \ + gdv_timestamp to_timestamp##_##TYPE(gdv_##TYPE seconds) { \ + return static_cast<gdv_timestamp>(seconds * MILLIS_IN_SEC); \ + } + +NUMERIC_TYPES(TO_TIMESTAMP) + +// Convert the seconds since epoch argument to time +#define TO_TIME(TYPE) \ + FORCE_INLINE \ + gdv_time32 to_time##_##TYPE(gdv_##TYPE seconds) { \ + EpochTimePoint tp(static_cast<int64_t>(seconds * MILLIS_IN_SEC)); \ + return static_cast<gdv_time32>(tp.TimeOfDay().to_duration().count()); \ + } + +NUMERIC_TYPES(TO_TIME) + +#define CAST_INT_YEAR_INTERVAL(TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE 
TYPE##_year_interval(gdv_month_interval in) { \ + return static_cast<gdv_##OUT_TYPE>(in / 12.0); \ + } + +CAST_INT_YEAR_INTERVAL(castBIGINT, int64) +CAST_INT_YEAR_INTERVAL(castINT, int32) + +#define CAST_NULLABLE_INTERVAL_DAY(TYPE) \ + FORCE_INLINE \ + gdv_day_time_interval castNULLABLEINTERVALDAY_##TYPE(gdv_##TYPE in) { \ + return static_cast<gdv_day_time_interval>(in); \ + } + +CAST_NULLABLE_INTERVAL_DAY(int32) +CAST_NULLABLE_INTERVAL_DAY(int64) + +#define CAST_NULLABLE_INTERVAL_YEAR(TYPE) \ + FORCE_INLINE \ + gdv_month_interval castNULLABLEINTERVALYEAR_##TYPE(int64_t context, gdv_##TYPE in) { \ + gdv_month_interval value = static_cast<gdv_month_interval>(in); \ + if (value != in) { \ + gdv_fn_context_set_error_msg(context, "Integer overflow"); \ + } \ + return value; \ + } + +CAST_NULLABLE_INTERVAL_YEAR(int32) +CAST_NULLABLE_INTERVAL_YEAR(int64) + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled/time_constants.h b/src/arrow/cpp/src/gandiva/precompiled/time_constants.h new file mode 100644 index 000000000..015ef4bf9 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/time_constants.h @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 

#pragma once

// Number of milliseconds in each fixed-length time unit. Each constant is
// built from the previous one and is wrapped in parentheses so it can be
// used safely inside larger expressions.
#define MILLIS_IN_SEC (1000)
#define MILLIS_IN_MIN (60 * MILLIS_IN_SEC)
#define MILLIS_IN_HOUR (60 * MILLIS_IN_MIN)
#define MILLIS_IN_DAY (24 * MILLIS_IN_HOUR)
#define MILLIS_IN_WEEK (7 * MILLIS_IN_DAY)

// Convert a millisecond count to whole units using '/' (for integer
// arguments the result is truncated toward zero).
#define MILLIS_TO_SEC(millis) ((millis) / MILLIS_IN_SEC)
#define MILLIS_TO_MINS(millis) ((millis) / MILLIS_IN_MIN)
#define MILLIS_TO_HOUR(millis) ((millis) / MILLIS_IN_HOUR)
#define MILLIS_TO_DAY(millis) ((millis) / MILLIS_IN_DAY)
#define MILLIS_TO_WEEK(millis) ((millis) / MILLIS_IN_WEEK)
diff --git a/src/arrow/cpp/src/gandiva/precompiled/time_fields.h b/src/arrow/cpp/src/gandiva/precompiled/time_fields.h new file mode 100644 index 000000000..d5277e743 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/time_fields.h @@ -0,0 +1,35 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

namespace gandiva {

// Names for the individual components of a date/time value.
// The enumerators have no explicit initialisers, so they take consecutive
// values starting at 0 and `kMax` equals the number of real fields (9).
// NOTE(review): presumably these index an array of parsed components, with
// kMax as its size — confirm against the parser that uses this enum.
enum TimeFields {
  kYear,
  kMonth,
  kDay,
  kHours,
  kMinutes,
  kSeconds,
  kSubSeconds,
  // NOTE(review): the two "displacement" fields look like the hour and minute
  // parts of a UTC offset such as "+08:00" — confirm.
  kDisplacementHours,
  kDisplacementMinutes,
  // Sentinel: one past the last real field; keep it last.
  kMax
};

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/precompiled/time_test.cc b/src/arrow/cpp/src/gandiva/precompiled/time_test.cc new file mode 100644 index 000000000..332ffa332 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/time_test.cc @@ -0,0 +1,953 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
+ +#include <gtest/gtest.h> +#include <time.h> + +#include "../execution_context.h" +#include "gandiva/precompiled/testing.h" +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +TEST(TestTime, TestCastDate) { + ExecutionContext context; + int64_t context_ptr = reinterpret_cast<int64_t>(&context); + + EXPECT_EQ(castDATE_utf8(context_ptr, "1967-12-1", 9), -65836800000); + EXPECT_EQ(castDATE_utf8(context_ptr, "2067-12-1", 9), 3089923200000); + + EXPECT_EQ(castDATE_utf8(context_ptr, "7-12-1", 6), 1196467200000); + EXPECT_EQ(castDATE_utf8(context_ptr, "67-12-1", 7), 3089923200000); + EXPECT_EQ(castDATE_utf8(context_ptr, "067-12-1", 8), 3089923200000); + EXPECT_EQ(castDATE_utf8(context_ptr, "0067-12-1", 9), -60023980800000); + EXPECT_EQ(castDATE_utf8(context_ptr, "00067-12-1", 10), -60023980800000); + EXPECT_EQ(castDATE_utf8(context_ptr, "167-12-1", 8), -56868307200000); + + EXPECT_EQ(castDATE_utf8(context_ptr, "1972-12-1", 9), 92016000000); + EXPECT_EQ(castDATE_utf8(context_ptr, "72-12-1", 7), 92016000000); + + EXPECT_EQ(castDATE_utf8(context_ptr, "1972222222", 10), 0); + EXPECT_EQ(context.get_error(), "Not a valid date value 1972222222"); + context.Reset(); + + EXPECT_EQ(castDATE_utf8(context_ptr, "blahblah", 8), 0); + EXPECT_EQ(castDATE_utf8(context_ptr, "1967-12-1bb", 11), -65836800000); + + EXPECT_EQ(castDATE_utf8(context_ptr, "67-12-1", 7), 3089923200000); + EXPECT_EQ(castDATE_utf8(context_ptr, "67-1-1", 6), 3061065600000); + EXPECT_EQ(castDATE_utf8(context_ptr, "71-1-1", 6), 31536000000); + EXPECT_EQ(castDATE_utf8(context_ptr, "71-45-1", 7), 0); + EXPECT_EQ(castDATE_utf8(context_ptr, "71-12-XX", 8), 0); + + EXPECT_EQ(castDATE_date32(1), 86400000); +} + +TEST(TestTime, TestCastTimestamp) { + ExecutionContext context; + int64_t context_ptr = reinterpret_cast<int64_t>(&context); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1967-12-1", 9), -65836800000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2067-12-1", 9), 3089923200000); + + 
EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "7-12-1", 6), 1196467200000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "67-12-1", 7), 3089923200000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "067-12-1", 8), 3089923200000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "0067-12-1", 9), -60023980800000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "00067-12-1", 10), -60023980800000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "167-12-1", 8), -56868307200000); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1972-12-1", 9), 92016000000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "72-12-1", 7), 92016000000); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1972-12-1", 9), 92016000000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "67-12-1", 7), 3089923200000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "67-1-1", 6), 3061065600000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "71-1-1", 6), 31536000000); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30", 18), 969702330000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920", 22), 969702330920); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920 +08:00", 29), + 969673530920); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920 -11:45", 29), + 969744630920); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "65-03-04 00:20:40.920 +00:30", 28), + 3003349840920); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1932-05-18 11:30:00.920 +11:30", 30), + -1187308799080); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1857-02-11 20:31:40.920 -05:30", 30), + -3562264699080); + EXPECT_EQ(castTIMESTAMP_date64( + castDATE_utf8(context_ptr, "2000-09-23 9:45:30.920 +08:00", 29)), + castTIMESTAMP_utf8(context_ptr, "2000-09-23 0:00:00.000 +00:00", 29)); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.1", 20), + castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30", 18) + 100); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.10", 20), + 
castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30", 18) + 100); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.100", 20), + castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30", 18) + 100); + + // error cases + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 24:00:00", 19), 0); + EXPECT_EQ(context.get_error(), + "Not a valid time for timestamp value 2000-01-01 24:00:00"); + context.Reset(); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 00:60:00", 19), 0); + EXPECT_EQ(context.get_error(), + "Not a valid time for timestamp value 2000-01-01 00:60:00"); + context.Reset(); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 00:00:100", 20), 0); + EXPECT_EQ(context.get_error(), + "Not a valid time for timestamp value 2000-01-01 00:00:100"); + context.Reset(); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 00:00:00.0001", 24), 0); + EXPECT_EQ(context.get_error(), + "Invalid millis for timestamp value 2000-01-01 00:00:00.0001"); + context.Reset(); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-01-01 00:00:00.1000", 24), 0); + EXPECT_EQ(context.get_error(), + "Invalid millis for timestamp value 2000-01-01 00:00:00.1000"); + context.Reset(); +} + +#ifndef _WIN32 + +// TODO(wesm): ARROW-4495. 
Need to address TZ database issues on Windows + +TEST(TestTime, TestCastTimestampWithTZ) { + ExecutionContext context; + int64_t context_ptr = reinterpret_cast<int64_t>(&context); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920 Canada/Pacific", 37), + 969727530920); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2012-02-28 23:30:59 Asia/Kolkata", 32), + 1330452059000); + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "1923-10-07 03:03:03 America/New_York", 36), + -1459094217000); +} + +TEST(TestTime, TestCastTimestampErrors) { + ExecutionContext context; + int64_t context_ptr = reinterpret_cast<int64_t>(&context); + + // error cases + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "20000923", 8), 0); + EXPECT_EQ(context.get_error(), "Not a valid day for timestamp value 20000923"); + context.Reset(); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-2b", 10), 0); + EXPECT_EQ(context.get_error(), + "Invalid timestamp or unknown zone for timestamp value 2000-09-2b"); + context.Reset(); + + EXPECT_EQ(castTIMESTAMP_utf8(context_ptr, "2000-09-23 9:45:30.920 Unknown/Zone", 35), + 0); + EXPECT_EQ(context.get_error(), + "Invalid timestamp or unknown zone for timestamp value 2000-09-23 " + "9:45:30.920 Unknown/Zone"); + context.Reset(); +} + +#endif + +TEST(TestTime, TestExtractTime) { + // 10:20:33 + gdv_int32 time_as_millis_in_day = 37233000; + + EXPECT_EQ(extractHour_time32(time_as_millis_in_day), 10); + EXPECT_EQ(extractMinute_time32(time_as_millis_in_day), 20); + EXPECT_EQ(extractSecond_time32(time_as_millis_in_day), 33); +} + +TEST(TestTime, TestTimestampDiffMonth) { + gdv_timestamp ts1 = StringToTimestamp("2019-06-30 00:00:00"); + gdv_timestamp ts2 = StringToTimestamp("2019-05-31 00:00:00"); + EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), -1); + + ts1 = StringToTimestamp("2019-06-30 00:00:00"); + ts2 = StringToTimestamp("2019-02-28 00:00:00"); + EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), -4); + + ts1 = 
StringToTimestamp("2019-06-30 00:00:00"); + ts2 = StringToTimestamp("2019-03-31 00:00:00"); + EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), -3); + + ts1 = StringToTimestamp("2019-06-30 00:00:00"); + ts2 = StringToTimestamp("2019-06-30 00:00:00"); + EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), 0); + + ts1 = StringToTimestamp("2019-06-30 00:00:00"); + ts2 = StringToTimestamp("2019-07-31 00:00:00"); + EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), 1); + + ts1 = StringToTimestamp("2019-06-30 00:00:00"); + ts2 = StringToTimestamp("2019-07-30 00:00:00"); + EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), 1); + + ts1 = StringToTimestamp("2019-06-30 00:00:00"); + ts2 = StringToTimestamp("2019-07-29 00:00:00"); + EXPECT_EQ(timestampdiffMonth_timestamp_timestamp(ts1, ts2), 0); +} + +TEST(TestTime, TestExtractTimestamp) { + gdv_timestamp ts = StringToTimestamp("1970-05-02 10:20:33"); + + EXPECT_EQ(extractMillennium_timestamp(ts), 2); + EXPECT_EQ(extractCentury_timestamp(ts), 20); + EXPECT_EQ(extractDecade_timestamp(ts), 197); + EXPECT_EQ(extractYear_timestamp(ts), 1970); + EXPECT_EQ(extractDoy_timestamp(ts), 122); + EXPECT_EQ(extractMonth_timestamp(ts), 5); + EXPECT_EQ(extractDow_timestamp(ts), 7); + EXPECT_EQ(extractDay_timestamp(ts), 2); + EXPECT_EQ(extractHour_timestamp(ts), 10); + EXPECT_EQ(extractMinute_timestamp(ts), 20); + EXPECT_EQ(extractSecond_timestamp(ts), 33); +} + +TEST(TestTime, TimeStampTrunc) { + EXPECT_EQ(date_trunc_Second_date64(StringToTimestamp("2015-05-05 10:20:34")), + StringToTimestamp("2015-05-05 10:20:34")); + EXPECT_EQ(date_trunc_Minute_date64(StringToTimestamp("2015-05-05 10:20:34")), + StringToTimestamp("2015-05-05 10:20:00")); + EXPECT_EQ(date_trunc_Hour_date64(StringToTimestamp("2015-05-05 10:20:34")), + StringToTimestamp("2015-05-05 10:00:00")); + EXPECT_EQ(date_trunc_Day_date64(StringToTimestamp("2015-05-05 10:20:34")), + StringToTimestamp("2015-05-05 00:00:00")); + 
EXPECT_EQ(date_trunc_Month_date64(StringToTimestamp("2015-05-05 10:20:34")), + StringToTimestamp("2015-05-01 00:00:00")); + EXPECT_EQ(date_trunc_Quarter_date64(StringToTimestamp("2015-05-05 10:20:34")), + StringToTimestamp("2015-04-01 00:00:00")); + EXPECT_EQ(date_trunc_Year_date64(StringToTimestamp("2015-05-05 10:20:34")), + StringToTimestamp("2015-01-01 00:00:00")); + EXPECT_EQ(date_trunc_Decade_date64(StringToTimestamp("2015-05-05 10:20:34")), + StringToTimestamp("2010-01-01 00:00:00")); + EXPECT_EQ(date_trunc_Century_date64(StringToTimestamp("2115-05-05 10:20:34")), + StringToTimestamp("2101-01-01 00:00:00")); + EXPECT_EQ(date_trunc_Millennium_date64(StringToTimestamp("2115-05-05 10:20:34")), + StringToTimestamp("2001-01-01 00:00:00")); + + // truncate week going to previous year + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-01 10:10:10")), + StringToTimestamp("2010-12-27 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-02 10:10:10")), + StringToTimestamp("2010-12-27 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-03 10:10:10")), + StringToTimestamp("2011-01-03 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-04 10:10:10")), + StringToTimestamp("2011-01-03 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-05 10:10:10")), + StringToTimestamp("2011-01-03 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-06 10:10:10")), + StringToTimestamp("2011-01-03 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-07 10:10:10")), + StringToTimestamp("2011-01-03 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-08 10:10:10")), + StringToTimestamp("2011-01-03 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2011-01-09 10:10:10")), + StringToTimestamp("2011-01-03 00:00:00")); + + // truncate week for Feb in a leap year + 
EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-02-28 10:10:10")), + StringToTimestamp("2000-02-28 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-02-29 10:10:10")), + StringToTimestamp("2000-02-28 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-01 10:10:10")), + StringToTimestamp("2000-02-28 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-02 10:10:10")), + StringToTimestamp("2000-02-28 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-03 10:10:10")), + StringToTimestamp("2000-02-28 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-04 10:10:10")), + StringToTimestamp("2000-02-28 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-05 10:10:10")), + StringToTimestamp("2000-02-28 00:00:00")); + EXPECT_EQ(date_trunc_Week_timestamp(StringToTimestamp("2000-03-06 10:10:10")), + StringToTimestamp("2000-03-06 00:00:00")); +} + +TEST(TestTime, TimeStampAdd) { + EXPECT_EQ( + timestampaddSecond_int32_timestamp(30, StringToTimestamp("2000-05-01 10:20:34")), + StringToTimestamp("2000-05-01 10:21:04")); + + EXPECT_EQ( + timestampaddSecond_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 30), + StringToTimestamp("2000-05-01 10:21:04")); + + EXPECT_EQ( + timestampaddMinute_int64_timestamp(-30, StringToTimestamp("2000-05-01 10:20:34")), + StringToTimestamp("2000-05-01 09:50:34")); + + EXPECT_EQ( + timestampaddMinute_timestamp_int64(StringToTimestamp("2000-05-01 10:20:34"), -30), + StringToTimestamp("2000-05-01 09:50:34")); + + EXPECT_EQ( + timestampaddHour_int32_timestamp(20, StringToTimestamp("2000-05-01 10:20:34")), + StringToTimestamp("2000-05-02 06:20:34")); + + EXPECT_EQ( + timestampaddHour_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 20), + StringToTimestamp("2000-05-02 06:20:34")); + + EXPECT_EQ( + timestampaddDay_int64_timestamp(-35, StringToTimestamp("2000-05-01 10:20:34")), 
+ StringToTimestamp("2000-03-27 10:20:34")); + + EXPECT_EQ( + timestampaddDay_timestamp_int64(StringToTimestamp("2000-05-01 10:20:34"), -35), + StringToTimestamp("2000-03-27 10:20:34")); + + EXPECT_EQ(timestampaddWeek_int32_timestamp(4, StringToTimestamp("2000-05-01 10:20:34")), + StringToTimestamp("2000-05-29 10:20:34")); + + EXPECT_EQ(timestampaddWeek_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 4), + StringToTimestamp("2000-05-29 10:20:34")); + + EXPECT_EQ(timestampaddWeek_timestamp_int32(StringToTimestamp("2000-05-01 10:20:34"), 4), + StringToTimestamp("2000-05-29 10:20:34")); + + EXPECT_EQ( + timestampaddMonth_int64_timestamp(10, StringToTimestamp("2000-05-01 10:20:34")), + StringToTimestamp("2001-03-01 10:20:34")); + + EXPECT_EQ( + timestampaddMonth_int64_timestamp(1, StringToTimestamp("2000-01-31 10:20:34")), + StringToTimestamp("2000-2-29 10:20:34")); + EXPECT_EQ( + timestampaddMonth_int64_timestamp(13, StringToTimestamp("2001-01-31 10:20:34")), + StringToTimestamp("2002-02-28 10:20:34")); + + EXPECT_EQ( + timestampaddMonth_int64_timestamp(11, StringToTimestamp("2000-05-31 10:20:34")), + StringToTimestamp("2001-04-30 10:20:34")); + + EXPECT_EQ( + timestampaddMonth_timestamp_int64(StringToTimestamp("2000-05-31 10:20:34"), 11), + StringToTimestamp("2001-04-30 10:20:34")); + + EXPECT_EQ( + timestampaddQuarter_int32_timestamp(-2, StringToTimestamp("2000-05-01 10:20:34")), + StringToTimestamp("1999-11-01 10:20:34")); + + EXPECT_EQ(timestampaddYear_int64_timestamp(2, StringToTimestamp("2000-05-01 10:20:34")), + StringToTimestamp("2002-05-01 10:20:34")); + + EXPECT_EQ( + timestampaddQuarter_int32_timestamp(-5, StringToTimestamp("2000-05-01 10:20:34")), + StringToTimestamp("1999-02-01 10:20:34")); + EXPECT_EQ( + timestampaddQuarter_int32_timestamp(-6, StringToTimestamp("2000-05-01 10:20:34")), + StringToTimestamp("1998-11-01 10:20:34")); + + // date_add + EXPECT_EQ(date_add_int32_timestamp(7, StringToTimestamp("2000-05-01 00:00:00")), + 
StringToTimestamp("2000-05-08 00:00:00")); + + EXPECT_EQ(add_int32_timestamp(4, StringToTimestamp("2000-05-01 00:00:00")), + StringToTimestamp("2000-05-05 00:00:00")); + + EXPECT_EQ(add_int64_timestamp(7, StringToTimestamp("2000-05-01 00:00:00")), + StringToTimestamp("2000-05-08 00:00:00")); + + EXPECT_EQ(date_add_int64_timestamp(4, StringToTimestamp("2000-05-01 00:00:00")), + StringToTimestamp("2000-05-05 00:00:00")); + + EXPECT_EQ(date_add_int64_timestamp(4, StringToTimestamp("2000-02-27 00:00:00")), + StringToTimestamp("2000-03-02 00:00:00")); + + EXPECT_EQ(add_date64_int64(StringToTimestamp("2000-02-27 00:00:00"), 4), + StringToTimestamp("2000-03-02 00:00:00")); + + // date_sub + EXPECT_EQ(date_sub_timestamp_int32(StringToTimestamp("2000-05-01 00:00:00"), 7), + StringToTimestamp("2000-04-24 00:00:00")); + + EXPECT_EQ(subtract_timestamp_int32(StringToTimestamp("2000-05-01 00:00:00"), -7), + StringToTimestamp("2000-05-08 00:00:00")); + + EXPECT_EQ(date_diff_timestamp_int64(StringToTimestamp("2000-05-01 00:00:00"), 365), + StringToTimestamp("1999-05-02 00:00:00")); + + EXPECT_EQ(date_diff_timestamp_int64(StringToTimestamp("2000-03-01 00:00:00"), 1), + StringToTimestamp("2000-02-29 00:00:00")); + + EXPECT_EQ(date_diff_timestamp_int64(StringToTimestamp("2000-02-29 00:00:00"), 365), + StringToTimestamp("1999-03-01 00:00:00")); +} + +// test cases from http://www.staff.science.uu.nl/~gent0113/calendar/isocalendar.htm +TEST(TestTime, TestExtractWeek) { + std::vector<std::string> data; + + // A type + // Jan 1, 2 and 3 + data.push_back("2006-01-01 10:10:10"); + data.push_back("52"); + data.push_back("2006-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2006-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2006-04-24 10:10:10"); + data.push_back("17"); + data.push_back("2006-04-30 10:10:10"); + data.push_back("17"); + // Dec 29-31 + data.push_back("2006-12-29 10:10:10"); + data.push_back("52"); + 
data.push_back("2006-12-30 10:10:10"); + data.push_back("52"); + data.push_back("2006-12-31 10:10:10"); + data.push_back("52"); + // B(C) type + // Jan 1, 2 and 3 + data.push_back("2011-01-01 10:10:10"); + data.push_back("52"); + data.push_back("2011-01-02 10:10:10"); + data.push_back("52"); + data.push_back("2011-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2011-07-18 10:10:10"); + data.push_back("29"); + data.push_back("2011-07-24 10:10:10"); + data.push_back("29"); + // Dec 29-31 + data.push_back("2011-12-29 10:10:10"); + data.push_back("52"); + data.push_back("2011-12-30 10:10:10"); + data.push_back("52"); + data.push_back("2011-12-31 10:10:10"); + data.push_back("52"); + // B(DC) type + // Jan 1, 2 and 3 + data.push_back("2005-01-01 10:10:10"); + data.push_back("53"); + data.push_back("2005-01-02 10:10:10"); + data.push_back("53"); + data.push_back("2005-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2005-11-07 10:10:10"); + data.push_back("45"); + data.push_back("2005-11-13 10:10:10"); + data.push_back("45"); + // Dec 29-31 + data.push_back("2005-12-29 10:10:10"); + data.push_back("52"); + data.push_back("2005-12-30 10:10:10"); + data.push_back("52"); + data.push_back("2005-12-31 10:10:10"); + data.push_back("52"); + // C type + // Jan 1, 2 and 3 + data.push_back("2010-01-01 10:10:10"); + data.push_back("53"); + data.push_back("2010-01-02 10:10:10"); + data.push_back("53"); + data.push_back("2010-01-03 10:10:10"); + data.push_back("53"); + // middle, Monday and Sunday + data.push_back("2010-09-13 10:10:10"); + data.push_back("37"); + data.push_back("2010-09-19 10:10:10"); + data.push_back("37"); + // Dec 29-31 + data.push_back("2010-12-29 10:10:10"); + data.push_back("52"); + data.push_back("2010-12-30 10:10:10"); + data.push_back("52"); + data.push_back("2010-12-31 10:10:10"); + data.push_back("52"); + // D type + // Jan 1, 2 and 3 + data.push_back("2037-01-01 10:10:10"); 
+ data.push_back("1"); + data.push_back("2037-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2037-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2037-08-17 10:10:10"); + data.push_back("34"); + data.push_back("2037-08-23 10:10:10"); + data.push_back("34"); + // Dec 29-31 + data.push_back("2037-12-29 10:10:10"); + data.push_back("53"); + data.push_back("2037-12-30 10:10:10"); + data.push_back("53"); + data.push_back("2037-12-31 10:10:10"); + data.push_back("53"); + // E type + // Jan 1, 2 and 3 + data.push_back("2014-01-01 10:10:10"); + data.push_back("1"); + data.push_back("2014-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2014-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2014-01-13 10:10:10"); + data.push_back("3"); + data.push_back("2014-01-19 10:10:10"); + data.push_back("3"); + // Dec 29-31 + data.push_back("2014-12-29 10:10:10"); + data.push_back("1"); + data.push_back("2014-12-30 10:10:10"); + data.push_back("1"); + data.push_back("2014-12-31 10:10:10"); + data.push_back("1"); + // F type + // Jan 1, 2 and 3 + data.push_back("2019-01-01 10:10:10"); + data.push_back("1"); + data.push_back("2019-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2019-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2019-02-11 10:10:10"); + data.push_back("7"); + data.push_back("2019-02-17 10:10:10"); + data.push_back("7"); + // Dec 29-31 + data.push_back("2019-12-29 10:10:10"); + data.push_back("52"); + data.push_back("2019-12-30 10:10:10"); + data.push_back("1"); + data.push_back("2019-12-31 10:10:10"); + data.push_back("1"); + // G type + // Jan 1, 2 and 3 + data.push_back("2001-01-01 10:10:10"); + data.push_back("1"); + data.push_back("2001-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2001-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2001-03-19 10:10:10"); 
+ data.push_back("12"); + data.push_back("2001-03-25 10:10:10"); + data.push_back("12"); + // Dec 29-31 + data.push_back("2001-12-29 10:10:10"); + data.push_back("52"); + data.push_back("2001-12-30 10:10:10"); + data.push_back("52"); + data.push_back("2001-12-31 10:10:10"); + data.push_back("1"); + // AG type + // Jan 1, 2 and 3 + data.push_back("2012-01-01 10:10:10"); + data.push_back("52"); + data.push_back("2012-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2012-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2012-04-02 10:10:10"); + data.push_back("14"); + data.push_back("2012-04-08 10:10:10"); + data.push_back("14"); + // Dec 29-31 + data.push_back("2012-12-29 10:10:10"); + data.push_back("52"); + data.push_back("2012-12-30 10:10:10"); + data.push_back("52"); + data.push_back("2012-12-31 10:10:10"); + data.push_back("1"); + // BA type + // Jan 1, 2 and 3 + data.push_back("2000-01-01 10:10:10"); + data.push_back("52"); + data.push_back("2000-01-02 10:10:10"); + data.push_back("52"); + data.push_back("2000-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2000-05-22 10:10:10"); + data.push_back("21"); + data.push_back("2000-05-28 10:10:10"); + data.push_back("21"); + // Dec 29-31 + data.push_back("2000-12-29 10:10:10"); + data.push_back("52"); + data.push_back("2000-12-30 10:10:10"); + data.push_back("52"); + data.push_back("2000-12-31 10:10:10"); + data.push_back("52"); + // CB type + // Jan 1, 2 and 3 + data.push_back("2016-01-01 10:10:10"); + data.push_back("53"); + data.push_back("2016-01-02 10:10:10"); + data.push_back("53"); + data.push_back("2016-01-03 10:10:10"); + data.push_back("53"); + // middle, Monday and Sunday + data.push_back("2016-06-20 10:10:10"); + data.push_back("25"); + data.push_back("2016-06-26 10:10:10"); + data.push_back("25"); + // Dec 29-31 + data.push_back("2016-12-29 10:10:10"); + data.push_back("52"); + data.push_back("2016-12-30 
10:10:10"); + data.push_back("52"); + data.push_back("2016-12-31 10:10:10"); + data.push_back("52"); + // DC type + // Jan 1, 2 and 3 + data.push_back("2004-01-01 10:10:10"); + data.push_back("1"); + data.push_back("2004-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2004-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2004-07-19 10:10:10"); + data.push_back("30"); + data.push_back("2004-07-25 10:10:10"); + data.push_back("30"); + // Dec 29-31 + data.push_back("2004-12-29 10:10:10"); + data.push_back("53"); + data.push_back("2004-12-30 10:10:10"); + data.push_back("53"); + data.push_back("2004-12-31 10:10:10"); + data.push_back("53"); + // ED type + // Jan 1, 2 and 3 + data.push_back("2020-01-01 10:10:10"); + data.push_back("1"); + data.push_back("2020-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2020-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2020-08-17 10:10:10"); + data.push_back("34"); + data.push_back("2020-08-23 10:10:10"); + data.push_back("34"); + // Dec 29-31 + data.push_back("2020-12-29 10:10:10"); + data.push_back("53"); + data.push_back("2020-12-30 10:10:10"); + data.push_back("53"); + data.push_back("2020-12-31 10:10:10"); + data.push_back("53"); + // FE type + // Jan 1, 2 and 3 + data.push_back("2008-01-01 10:10:10"); + data.push_back("1"); + data.push_back("2008-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2008-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2008-09-15 10:10:10"); + data.push_back("38"); + data.push_back("2008-09-21 10:10:10"); + data.push_back("38"); + // Dec 29-31 + data.push_back("2008-12-29 10:10:10"); + data.push_back("1"); + data.push_back("2008-12-30 10:10:10"); + data.push_back("1"); + data.push_back("2008-12-31 10:10:10"); + data.push_back("1"); + // GF type + // Jan 1, 2 and 3 + data.push_back("2024-01-01 10:10:10"); + data.push_back("1"); + 
data.push_back("2024-01-02 10:10:10"); + data.push_back("1"); + data.push_back("2024-01-03 10:10:10"); + data.push_back("1"); + // middle, Monday and Sunday + data.push_back("2024-10-07 10:10:10"); + data.push_back("41"); + data.push_back("2024-10-13 10:10:10"); + data.push_back("41"); + // Dec 29-31 + data.push_back("2024-12-29 10:10:10"); + data.push_back("52"); + data.push_back("2024-12-30 10:10:10"); + data.push_back("1"); + data.push_back("2024-12-31 10:10:10"); + data.push_back("1"); + + for (uint32_t i = 0; i < data.size(); i += 2) { + gdv_timestamp ts = StringToTimestamp(data.at(i).c_str()); + gdv_int64 exp = atol(data.at(i + 1).c_str()); + EXPECT_EQ(extractWeek_timestamp(ts), exp); + } +} + +TEST(TestTime, TestMonthsBetween) { + std::vector<std::string> testStrings = { + "1995-03-02 00:00:00", "1995-02-02 00:00:00", "1.0", + "1995-02-02 00:00:00", "1995-03-02 00:00:00", "-1.0", + "1995-03-31 00:00:00", "1995-02-28 00:00:00", "1.0", + "1996-03-31 00:00:00", "1996-02-28 00:00:00", "1.09677418", + "1996-03-31 00:00:00", "1996-02-29 00:00:00", "1.0", + "1996-05-31 00:00:00", "1996-04-30 00:00:00", "1.0", + "1996-05-31 00:00:00", "1996-03-31 00:00:00", "2.0", + "1996-05-31 00:00:00", "1996-03-30 00:00:00", "2.03225806", + "1996-03-15 00:00:00", "1996-02-14 00:00:00", "1.03225806", + "1995-02-02 00:00:00", "1995-01-01 00:00:00", "1.03225806", + "1995-02-02 10:00:00", "1995-01-01 11:00:00", "1.03091397"}; + + for (uint32_t i = 0; i < testStrings.size();) { + gdv_timestamp endTs = StringToTimestamp(testStrings[i++].c_str()); + gdv_timestamp startTs = StringToTimestamp(testStrings[i++].c_str()); + + double expectedResult = atof(testStrings[i++].c_str()); + double actualResult = months_between_timestamp_timestamp(endTs, startTs); + + double diff = actualResult - expectedResult; + if (diff < 0) { + diff = expectedResult - actualResult; + } + + EXPECT_TRUE(diff < 0.001); + } +} + +TEST(TestTime, castVarcharTimestamp) { + ExecutionContext context; + int64_t context_ptr 
= reinterpret_cast<int64_t>(&context); + gdv_int32 out_len; + gdv_timestamp ts = StringToTimestamp("2000-05-01 10:20:34"); + const char* out = castVARCHAR_timestamp_int64(context_ptr, ts, 30L, &out_len); + EXPECT_EQ(std::string(out, out_len), "2000-05-01 10:20:34.000"); + + out = castVARCHAR_timestamp_int64(context_ptr, ts, 19L, &out_len); + EXPECT_EQ(std::string(out, out_len), "2000-05-01 10:20:34"); + + out = castVARCHAR_timestamp_int64(context_ptr, ts, 0L, &out_len); + EXPECT_EQ(std::string(out, out_len), ""); + + ts = StringToTimestamp("2-5-1 00:00:04"); + out = castVARCHAR_timestamp_int64(context_ptr, ts, 24L, &out_len); + EXPECT_EQ(std::string(out, out_len), "0002-05-01 00:00:04.000"); +} + +TEST(TestTime, TestCastTimestampToDate) { + gdv_timestamp ts = StringToTimestamp("2000-05-01 10:20:34"); + auto out = castDATE_timestamp(ts); + EXPECT_EQ(StringToTimestamp("2000-05-01 00:00:00"), out); +} + +TEST(TestTime, TestCastTimestampToTime) { + gdv_timestamp ts = StringToTimestamp("2000-05-01 10:20:34"); + auto expected_response = + static_cast<int32_t>(ts - StringToTimestamp("2000-05-01 00:00:00")); + auto out = castTIME_timestamp(ts); + EXPECT_EQ(expected_response, out); + + // Test when the defined value is midnight, so the returned value must 0 + ts = StringToTimestamp("1998-12-01 00:00:00"); + expected_response = 0; + out = castTIME_timestamp(ts); + EXPECT_EQ(expected_response, out); + + ts = StringToTimestamp("2015-09-16 23:59:59"); + expected_response = static_cast<int32_t>(ts - StringToTimestamp("2015-09-16 00:00:00")); + out = castTIME_timestamp(ts); + EXPECT_EQ(expected_response, out); +} + +TEST(TestTime, TestLastDay) { + // leap year test + gdv_timestamp ts = StringToTimestamp("2016-02-11 03:20:34"); + auto out = last_day_from_timestamp(ts); + EXPECT_EQ(StringToTimestamp("2016-02-29 00:00:00"), out); + + ts = StringToTimestamp("2016-02-29 23:59:59"); + out = last_day_from_timestamp(ts); + EXPECT_EQ(StringToTimestamp("2016-02-29 00:00:00"), out); + + ts 
= StringToTimestamp("2016-01-30 23:59:00"); + out = last_day_from_timestamp(ts); + EXPECT_EQ(StringToTimestamp("2016-01-31 00:00:00"), out); + + // normal year + ts = StringToTimestamp("2017-02-03 23:59:59"); + out = last_day_from_timestamp(ts); + EXPECT_EQ(StringToTimestamp("2017-02-28 00:00:00"), out); + + // december + ts = StringToTimestamp("2015-12-03 03:12:59"); + out = last_day_from_timestamp(ts); + EXPECT_EQ(StringToTimestamp("2015-12-31 00:00:00"), out); +} + +TEST(TestTime, TestToTimestamp) { + auto ts = StringToTimestamp("1970-01-01 00:00:00"); + EXPECT_EQ(ts, to_timestamp_int32(0)); + EXPECT_EQ(ts, to_timestamp_int64(0)); + EXPECT_EQ(ts, to_timestamp_float32(0)); + EXPECT_EQ(ts, to_timestamp_float64(0)); + + ts = StringToTimestamp("1970-01-01 00:00:01"); + EXPECT_EQ(ts, to_timestamp_int32(1)); + EXPECT_EQ(ts, to_timestamp_int64(1)); + EXPECT_EQ(ts, to_timestamp_float32(1)); + EXPECT_EQ(ts, to_timestamp_float64(1)); + + ts = StringToTimestamp("1970-01-01 00:01:00"); + EXPECT_EQ(ts, to_timestamp_int32(60)); + EXPECT_EQ(ts, to_timestamp_int64(60)); + EXPECT_EQ(ts, to_timestamp_float32(60)); + EXPECT_EQ(ts, to_timestamp_float64(60)); + + ts = StringToTimestamp("1970-01-01 01:00:00"); + EXPECT_EQ(ts, to_timestamp_int32(3600)); + EXPECT_EQ(ts, to_timestamp_int64(3600)); + EXPECT_EQ(ts, to_timestamp_float32(3600)); + EXPECT_EQ(ts, to_timestamp_float64(3600)); + + ts = StringToTimestamp("1970-01-02 00:00:00"); + EXPECT_EQ(ts, to_timestamp_int32(86400)); + EXPECT_EQ(ts, to_timestamp_int64(86400)); + EXPECT_EQ(ts, to_timestamp_float32(86400)); + EXPECT_EQ(ts, to_timestamp_float64(86400)); + + // tests with fractional part + ts = StringToTimestamp("1970-01-01 00:00:01") + 500; + EXPECT_EQ(ts, to_timestamp_float32(1.500f)); + EXPECT_EQ(ts, to_timestamp_float64(1.500)); + + ts = StringToTimestamp("1970-01-01 00:01:01") + 600; + EXPECT_EQ(ts, to_timestamp_float32(61.600f)); + EXPECT_EQ(ts, to_timestamp_float64(61.600)); + + ts = StringToTimestamp("1970-01-01 
01:00:01") + 400; + EXPECT_EQ(ts, to_timestamp_float32(3601.400f)); + EXPECT_EQ(ts, to_timestamp_float64(3601.400)); +} + +TEST(TestTime, TestToTimeNumeric) { + // input timestamp in seconds: 1970-01-01 00:00:00 + int64_t expected_output = 0; // 0 milliseconds + EXPECT_EQ(expected_output, to_time_int32(0)); + EXPECT_EQ(expected_output, to_time_int64(0)); + EXPECT_EQ(expected_output, to_time_float32(0.000f)); + EXPECT_EQ(expected_output, to_time_float64(0.000)); + + // input timestamp in seconds: 1970-01-01 00:00:01 + expected_output = 1000; // 1 second + EXPECT_EQ(expected_output, to_time_int32(1)); + EXPECT_EQ(expected_output, to_time_int64(1)); + EXPECT_EQ(expected_output, to_time_float32(1.000f)); + EXPECT_EQ(expected_output, to_time_float64(1.000)); + + // input timestamp in seconds: 1970-01-01 01:00:00 + expected_output = 3600000; // 3600 seconds + EXPECT_EQ(expected_output, to_time_int32(3600)); + EXPECT_EQ(expected_output, to_time_int64(3600)); + EXPECT_EQ(expected_output, to_time_float32(3600.000f)); + EXPECT_EQ(expected_output, to_time_float64(3600.000)); + + // input timestamp in seconds: 1970-01-01 23:59:59 + expected_output = 86399000; // 86399 seconds + EXPECT_EQ(expected_output, to_time_int32(86399)); + EXPECT_EQ(expected_output, to_time_int64(86399)); + EXPECT_EQ(expected_output, to_time_float32(86399.000f)); + EXPECT_EQ(expected_output, to_time_float64(86399.000)); + + // input timestamp in seconds: 2020-01-01 00:00:01 + expected_output = 1000; // 1 second + EXPECT_EQ(expected_output, to_time_int64(1577836801)); + EXPECT_EQ(expected_output, to_time_float64(1577836801.000)); + + // tests with fractional part + // input timestamp in seconds: 1970-01-01 00:00:01.500 + expected_output = 1500; // 1.5 seconds + EXPECT_EQ(expected_output, to_time_float32(1.500f)); + EXPECT_EQ(expected_output, to_time_float64(1.500)); + + // input timestamp in seconds: 1970-01-01 00:01:01.500 + expected_output = 61500; // 61.5 seconds + EXPECT_EQ(expected_output,
to_time_float32(61.500f)); + EXPECT_EQ(expected_output, to_time_float64(61.500)); + + // input timestamp in seconds: 1970-01-01 01:00:01.500 + expected_output = 3601500; // 3601.5 seconds + EXPECT_EQ(expected_output, to_time_float32(3601.500f)); + EXPECT_EQ(expected_output, to_time_float64(3601.500)); +} + +TEST(TestTime, TestCastIntDayInterval) { + EXPECT_EQ(castBIGINT_daytimeinterval(10), 864000000); + EXPECT_EQ(castBIGINT_daytimeinterval(-100), -8640000001); + EXPECT_EQ(castBIGINT_daytimeinterval(-0), 0); +} + +TEST(TestTime, TestCastIntYearInterval) { + EXPECT_EQ(castINT_year_interval(24), 2); + EXPECT_EQ(castINT_year_interval(-24), -2); + EXPECT_EQ(castINT_year_interval(-23), -1); + + EXPECT_EQ(castBIGINT_year_interval(24), 2); + EXPECT_EQ(castBIGINT_year_interval(-24), -2); + EXPECT_EQ(castBIGINT_year_interval(-23), -1); +} + +TEST(TestTime, TestCastNullableInterval) { + ExecutionContext context; + auto context_ptr = reinterpret_cast<int64_t>(&context); + // Test castNULLABLEINTERVALDAY for int and bigint + EXPECT_EQ(castNULLABLEINTERVALDAY_int32(1), 1); + EXPECT_EQ(castNULLABLEINTERVALDAY_int32(12), 12); + EXPECT_EQ(castNULLABLEINTERVALDAY_int32(-55), -55); + EXPECT_EQ(castNULLABLEINTERVALDAY_int32(-1201), -1201); + EXPECT_EQ(castNULLABLEINTERVALDAY_int64(1), 1); + EXPECT_EQ(castNULLABLEINTERVALDAY_int64(12), 12); + EXPECT_EQ(castNULLABLEINTERVALDAY_int64(-55), -55); + EXPECT_EQ(castNULLABLEINTERVALDAY_int64(-1201), -1201); + + // Test castNULLABLEINTERVALYEAR for int and bigint + EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 1), 1); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 12), 12); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 55), 55); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int32(context_ptr, 1201), 1201); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 1), 1); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 12), 12); + EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 55), 55); + 
EXPECT_EQ(castNULLABLEINTERVALYEAR_int64(context_ptr, 1201), 1201); + // validate overflow error when using bigint as input + castNULLABLEINTERVALYEAR_int64(context_ptr, INT64_MAX); + EXPECT_EQ(context.get_error(), "Integer overflow"); + context.Reset(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc b/src/arrow/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc new file mode 100644 index 000000000..695605b3c --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/timestamp_arithmetic.cc @@ -0,0 +1,283 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "./epoch_time_point.h" + +// The first row is for non-leap years +static int days_in_a_month[2][12] = {{31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}, + {31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}}; + +bool is_leap_year(int yy) { + if ((yy % 4) != 0) { + // not divisible by 4 + return false; + } + // yy = 4x + if ((yy % 400) == 0) { + // yy = 400x + return true; + } + // yy = 4x, return true if yy != 100x + return ((yy % 100) != 0); +} + +bool is_last_day_of_month(const EpochTimePoint& tp) { + int matrix_index = is_leap_year(tp.TmYear()) ? 
1 : 0; + + return (tp.TmMday() == days_in_a_month[matrix_index][tp.TmMon()]); +} + +bool did_days_overflow(arrow_vendored::date::year_month_day ymd) { + int year = static_cast<int>(ymd.year()); + int month = static_cast<unsigned int>(ymd.month()); + int days = static_cast<unsigned int>(ymd.day()); + + int matrix_index = is_leap_year(year) ? 1 : 0; + + return days > days_in_a_month[matrix_index][month - 1]; +} + +int last_possible_day_in_month(int year, int month) { + int matrix_index = is_leap_year(year) ? 1 : 0; + + return days_in_a_month[matrix_index][month - 1]; +} + +extern "C" { + +#include <time.h> + +#include "./time_constants.h" +#include "./types.h" + +#define TIMESTAMP_DIFF_FIXED_UNITS(TYPE, NAME, FROM_MILLIS) \ + FORCE_INLINE \ + gdv_int32 NAME##_##TYPE##_##TYPE(gdv_##TYPE start_millis, gdv_##TYPE end_millis) { \ + return static_cast<int32_t>(FROM_MILLIS(end_millis - start_millis)); \ + } + +#define SIGN_ADJUST_DIFF(is_positive, diff) ((is_positive) ? (diff) : -(diff)) +#define MONTHS_TO_TIMEUNIT(diff, num_months) (diff) / (num_months) + +// Assuming end_millis > start_millis, the algorithm to find the diff in months is: +// diff_in_months = year_diff * 12 + month_diff +// This is approximately correct, except when the last month has not fully elapsed +// +// a) If end_day > start_day, return diff_in_months e.g. diff(2015-09-10, 2017-03-31) +// b) If end_day < start_day, return diff_in_months - 1 e.g. diff(2015-09-30, 2017-03-10) +// c) If end_day = start_day, check for millis e.g. 
diff(2017-03-10, 2015-03-10) +// Need to check if end_millis_in_day > start_millis_in_day +// c1) If end_millis_in_day >= start_millis_in_day, return diff_in_months +// c2) else return diff_in_months - 1 +#define TIMESTAMP_DIFF_MONTH_UNITS(TYPE, NAME, N_MONTHS) \ + FORCE_INLINE \ + gdv_int32 NAME##_##TYPE##_##TYPE(gdv_##TYPE start_millis, gdv_##TYPE end_millis) { \ + gdv_int32 diff; \ + bool is_positive = (end_millis > start_millis); \ + if (!is_positive) { \ + /* if end_millis < start_millis, swap and multiply by -1 at the end */ \ + gdv_##TYPE tmp = start_millis; \ + start_millis = end_millis; \ + end_millis = tmp; \ + } \ + EpochTimePoint start_tm(start_millis); \ + EpochTimePoint end_tm(end_millis); \ + gdv_int32 months_diff; \ + months_diff = static_cast<gdv_int32>(12 * (end_tm.TmYear() - start_tm.TmYear()) + \ + (end_tm.TmMon() - start_tm.TmMon())); \ + if (end_tm.TmMday() > start_tm.TmMday()) { \ + /* case a */ \ + diff = MONTHS_TO_TIMEUNIT(months_diff, N_MONTHS); \ + return SIGN_ADJUST_DIFF(is_positive, diff); \ + } \ + if (end_tm.TmMday() < start_tm.TmMday()) { \ + /* case b */ \ + months_diff += (is_last_day_of_month(end_tm) ? 
1 : 0); \ + diff = MONTHS_TO_TIMEUNIT(months_diff - 1, N_MONTHS); \ + return SIGN_ADJUST_DIFF(is_positive, diff); \ + } \ + gdv_int32 end_day_millis = \ + static_cast<gdv_int32>(end_tm.TmHour() * MILLIS_IN_HOUR + \ + end_tm.TmMin() * MILLIS_IN_MIN + end_tm.TmSec()); \ + gdv_int32 start_day_millis = \ + static_cast<gdv_int32>(start_tm.TmHour() * MILLIS_IN_HOUR + \ + start_tm.TmMin() * MILLIS_IN_MIN + start_tm.TmSec()); \ + if (end_day_millis >= start_day_millis) { \ + /* case c1 */ \ + diff = MONTHS_TO_TIMEUNIT(months_diff, N_MONTHS); \ + return SIGN_ADJUST_DIFF(is_positive, diff); \ + } \ + /* case c2 */ \ + diff = MONTHS_TO_TIMEUNIT(months_diff - 1, N_MONTHS); \ + return SIGN_ADJUST_DIFF(is_positive, diff); \ + } + +#define TIMESTAMP_DIFF(TYPE) \ + TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffSecond, MILLIS_TO_SEC) \ + TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffMinute, MILLIS_TO_MINS) \ + TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffHour, MILLIS_TO_HOUR) \ + TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffDay, MILLIS_TO_DAY) \ + TIMESTAMP_DIFF_FIXED_UNITS(TYPE, timestampdiffWeek, MILLIS_TO_WEEK) \ + TIMESTAMP_DIFF_MONTH_UNITS(TYPE, timestampdiffMonth, 1) \ + TIMESTAMP_DIFF_MONTH_UNITS(TYPE, timestampdiffQuarter, 3) \ + TIMESTAMP_DIFF_MONTH_UNITS(TYPE, timestampdiffYear, 12) + +TIMESTAMP_DIFF(timestamp) + +#define ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_int32_##TYPE(gdv_int32 count, gdv_##TYPE millis) { \ + return millis + TO_MILLIS * static_cast<gdv_##TYPE>(count); \ + } + +// Documentation of mktime suggests that it handles +// TmMon() being negative, and also TmMon() being >= 12 by +// adjusting TmYear() accordingly +// +// Using gmtime_r() and timegm() instead of localtime_r() and mktime() +// since the input millis are since epoch +#define ADD_INT32_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, N_MONTHS) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_int32_##TYPE(gdv_int32 count, gdv_##TYPE millis) { \ + 
EpochTimePoint tp(millis); \ + return tp.AddMonths(static_cast<int>(count * N_MONTHS)).MillisSinceEpoch(); \ + } + +// TODO: Handle overflow while converting gdv_int64 to millis +#define ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_int64_##TYPE(gdv_int64 count, gdv_##TYPE millis) { \ + return millis + TO_MILLIS * static_cast<gdv_##TYPE>(count); \ + } + +#define ADD_INT64_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, N_MONTHS) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_int64_##TYPE(gdv_int64 count, gdv_##TYPE millis) { \ + EpochTimePoint tp(millis); \ + return tp.AddMonths(static_cast<int>(count * N_MONTHS)).MillisSinceEpoch(); \ + } + +#define ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_##TYPE##_int32(gdv_##TYPE millis, gdv_int32 count) { \ + return millis + TO_MILLIS * static_cast<gdv_##TYPE>(count); \ + } + +#define ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_##TYPE##_int64(gdv_##TYPE millis, gdv_int64 count) { \ + return millis + TO_MILLIS * static_cast<gdv_##TYPE>(count); \ + } + +#define ADD_TIMESTAMP_TO_INT32_MONTH_UNITS(TYPE, NAME, N_MONTHS) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_##TYPE##_int32(gdv_##TYPE millis, gdv_int32 count) { \ + EpochTimePoint tp(millis); \ + return tp.AddMonths(static_cast<int>(count * N_MONTHS)).MillisSinceEpoch(); \ + } + +#define ADD_TIMESTAMP_TO_INT64_MONTH_UNITS(TYPE, NAME, N_MONTHS) \ + FORCE_INLINE \ + gdv_##TYPE NAME##_##TYPE##_int64(gdv_##TYPE millis, gdv_int64 count) { \ + EpochTimePoint tp(millis); \ + return tp.AddMonths(static_cast<int>(count * N_MONTHS)).MillisSinceEpoch(); \ + } + +#define ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, NAME, TO_MILLIS) \ + ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \ + ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(TYPE, NAME, TO_MILLIS) + +#define ADD_TIMESTAMP_INT32_MONTHUNITS(TYPE, NAME, N_MONTHS) \ + ADD_INT32_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, 
N_MONTHS) \ + ADD_TIMESTAMP_TO_INT32_MONTH_UNITS(TYPE, NAME, N_MONTHS) + +#define TIMESTAMP_ADD_INT32(TYPE) \ + ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddSecond, MILLIS_IN_SEC) \ + ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddMinute, MILLIS_IN_MIN) \ + ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddHour, MILLIS_IN_HOUR) \ + ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddDay, MILLIS_IN_DAY) \ + ADD_TIMESTAMP_INT32_FIXEDUNITS(TYPE, timestampaddWeek, MILLIS_IN_WEEK) \ + ADD_TIMESTAMP_INT32_MONTHUNITS(TYPE, timestampaddMonth, 1) \ + ADD_TIMESTAMP_INT32_MONTHUNITS(TYPE, timestampaddQuarter, 3) \ + ADD_TIMESTAMP_INT32_MONTHUNITS(TYPE, timestampaddYear, 12) + +#define ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, NAME, TO_MILLIS) \ + ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(TYPE, NAME, TO_MILLIS) \ + ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(TYPE, NAME, TO_MILLIS) + +#define ADD_TIMESTAMP_INT64_MONTHUNITS(TYPE, NAME, N_MONTHS) \ + ADD_INT64_TO_TIMESTAMP_MONTH_UNITS(TYPE, NAME, N_MONTHS) \ + ADD_TIMESTAMP_TO_INT64_MONTH_UNITS(TYPE, NAME, N_MONTHS) + +#define TIMESTAMP_ADD_INT64(TYPE) \ + ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddSecond, MILLIS_IN_SEC) \ + ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddMinute, MILLIS_IN_MIN) \ + ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddHour, MILLIS_IN_HOUR) \ + ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddDay, MILLIS_IN_DAY) \ + ADD_TIMESTAMP_INT64_FIXEDUNITS(TYPE, timestampaddWeek, MILLIS_IN_WEEK) \ + ADD_TIMESTAMP_INT64_MONTHUNITS(TYPE, timestampaddMonth, 1) \ + ADD_TIMESTAMP_INT64_MONTHUNITS(TYPE, timestampaddQuarter, 3) \ + ADD_TIMESTAMP_INT64_MONTHUNITS(TYPE, timestampaddYear, 12) + +#define TIMESTAMP_ADD_INT(TYPE) \ + TIMESTAMP_ADD_INT32(TYPE) \ + TIMESTAMP_ADD_INT64(TYPE) + +TIMESTAMP_ADD_INT(date64) +TIMESTAMP_ADD_INT(timestamp) + +// add gdv_int32 to timestamp +ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY) +ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(date64, add, MILLIS_IN_DAY) 
+ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY) +ADD_INT32_TO_TIMESTAMP_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY) + +// add gdv_int64 to timestamp +ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY) +ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(date64, add, MILLIS_IN_DAY) +ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY) +ADD_INT64_TO_TIMESTAMP_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY) + +// date_sub, subtract, date_diff on gdv_int32 +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, date_sub, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, subtract, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, date_diff, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, date_sub, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, subtract, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, date_diff, -1 * MILLIS_IN_DAY) + +// date_sub, subtract, date_diff on gdv_int64 +ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, date_sub, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, subtract, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, date_diff, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, date_sub, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, subtract, -1 * MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, date_diff, -1 * MILLIS_IN_DAY) + +// add timestamp to gdv_int32 +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(date64, add, MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT32_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY) + +// add timestamp to gdv_int64 +ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, date_add, MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(date64, add, MILLIS_IN_DAY) +ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, date_add, MILLIS_IN_DAY) 
+ADD_TIMESTAMP_TO_INT64_FIXED_UNITS(timestamp, add, MILLIS_IN_DAY) + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled/types.h b/src/arrow/cpp/src/gandiva/precompiled/types.h new file mode 100644 index 000000000..987ee2c6d --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled/types.h @@ -0,0 +1,592 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cstdint> + +#include "gandiva/gdv_function_stubs.h" + +// Use the same names as in arrow data types. Makes it easy to write pre-processor macros. +using gdv_boolean = bool; +using gdv_int8 = int8_t; +using gdv_int16 = int16_t; +using gdv_int32 = int32_t; +using gdv_int64 = int64_t; +using gdv_uint8 = uint8_t; +using gdv_uint16 = uint16_t; +using gdv_uint32 = uint32_t; +using gdv_uint64 = uint64_t; +using gdv_float32 = float; +using gdv_float64 = double; +using gdv_date64 = int64_t; +using gdv_date32 = int32_t; +using gdv_time32 = int32_t; +using gdv_timestamp = int64_t; +using gdv_utf8 = char*; +using gdv_binary = char*; +using gdv_day_time_interval = int64_t; + +#ifdef GANDIVA_UNIT_TEST +// unit tests may be compiled without O2, so inlining may not happen. 
+#define FORCE_INLINE +#else +#define FORCE_INLINE __attribute__((always_inline)) +#endif + +extern "C" { + +bool bitMapGetBit(const unsigned char* bmap, int64_t position); +void bitMapSetBit(unsigned char* bmap, int64_t position, bool value); +void bitMapClearBitIfFalse(unsigned char* bmap, int64_t position, bool value); + +gdv_int64 extractMillennium_timestamp(gdv_timestamp millis); +gdv_int64 extractCentury_timestamp(gdv_timestamp millis); +gdv_int64 extractDecade_timestamp(gdv_timestamp millis); +gdv_int64 extractYear_timestamp(gdv_timestamp millis); +gdv_int64 extractDoy_timestamp(gdv_timestamp millis); +gdv_int64 extractQuarter_timestamp(gdv_timestamp millis); +gdv_int64 extractMonth_timestamp(gdv_timestamp millis); +gdv_int64 extractWeek_timestamp(gdv_timestamp millis); +gdv_int64 extractDow_timestamp(gdv_timestamp millis); +gdv_int64 extractDay_timestamp(gdv_timestamp millis); +gdv_int64 extractHour_timestamp(gdv_timestamp millis); +gdv_int64 extractMinute_timestamp(gdv_timestamp millis); +gdv_int64 extractSecond_timestamp(gdv_timestamp millis); +gdv_int64 extractHour_time32(gdv_int32 millis_in_day); +gdv_int64 extractMinute_time32(gdv_int32 millis_in_day); +gdv_int64 extractSecond_time32(gdv_int32 millis_in_day); + +gdv_int32 hash32(double val, gdv_int32 seed); +gdv_int32 hash32_buf(const gdv_uint8* buf, int len, gdv_int32 seed); +gdv_int64 hash64(double val, gdv_int64 seed); +gdv_int64 hash64_buf(const gdv_uint8* buf, int len, gdv_int64 seed); + +gdv_int32 timestampdiffMonth_timestamp_timestamp(gdv_timestamp, gdv_timestamp); + +gdv_int64 timestampaddSecond_int32_timestamp(gdv_int32, gdv_timestamp); +gdv_int64 timestampaddMinute_int32_timestamp(gdv_int32, gdv_timestamp); +gdv_int64 timestampaddHour_int32_timestamp(gdv_int32, gdv_timestamp); +gdv_int64 timestampaddDay_int32_timestamp(gdv_int32, gdv_timestamp); +gdv_int64 timestampaddWeek_int32_timestamp(gdv_int32, gdv_timestamp); +gdv_int64 timestampaddMonth_int32_timestamp(gdv_int32, gdv_timestamp); 
+gdv_int64 timestampaddQuarter_int32_timestamp(gdv_int32, gdv_timestamp); +gdv_int64 timestampaddYear_int32_timestamp(gdv_int32, gdv_timestamp); + +gdv_int64 timestampaddSecond_timestamp_int32(gdv_timestamp, gdv_int32); +gdv_int64 timestampaddMinute_timestamp_int32(gdv_timestamp, gdv_int32); +gdv_int64 timestampaddHour_timestamp_int32(gdv_timestamp, gdv_int32); +gdv_int64 timestampaddDay_timestamp_int32(gdv_timestamp, gdv_int32); +gdv_int64 timestampaddWeek_timestamp_int32(gdv_timestamp, gdv_int32); +gdv_int64 timestampaddMonth_timestamp_int32(gdv_timestamp, gdv_int32); +gdv_int64 timestampaddQuarter_timestamp_int32(gdv_timestamp, gdv_int32); +gdv_int64 timestampaddYear_timestamp_int32(gdv_timestamp, gdv_int32); + +gdv_int64 timestampaddSecond_int64_timestamp(gdv_int64, gdv_timestamp); +gdv_int64 timestampaddMinute_int64_timestamp(gdv_int64, gdv_timestamp); +gdv_int64 timestampaddHour_int64_timestamp(gdv_int64, gdv_timestamp); +gdv_int64 timestampaddDay_int64_timestamp(gdv_int64, gdv_timestamp); +gdv_int64 timestampaddWeek_int64_timestamp(gdv_int64, gdv_timestamp); +gdv_int64 timestampaddMonth_int64_timestamp(gdv_int64, gdv_timestamp); +gdv_int64 timestampaddQuarter_int64_timestamp(gdv_int64, gdv_timestamp); +gdv_int64 timestampaddYear_int64_timestamp(gdv_int64, gdv_timestamp); + +gdv_int64 timestampaddSecond_timestamp_int64(gdv_timestamp, gdv_int64); +gdv_int64 timestampaddMinute_timestamp_int64(gdv_timestamp, gdv_int64); +gdv_int64 timestampaddHour_timestamp_int64(gdv_timestamp, gdv_int64); +gdv_int64 timestampaddDay_timestamp_int64(gdv_timestamp, gdv_int64); +gdv_int64 timestampaddWeek_timestamp_int64(gdv_timestamp, gdv_int64); +gdv_int64 timestampaddMonth_timestamp_int64(gdv_timestamp, gdv_int64); +gdv_int64 timestampaddQuarter_timestamp_int64(gdv_timestamp, gdv_int64); +gdv_int64 timestampaddYear_timestamp_int64(gdv_timestamp, gdv_int64); + +gdv_int64 date_add_int32_timestamp(gdv_int32, gdv_timestamp); +gdv_int64 add_int64_timestamp(gdv_int64, gdv_timestamp); 
+gdv_int64 add_int32_timestamp(gdv_int32, gdv_timestamp); +gdv_int64 date_add_int64_timestamp(gdv_int64, gdv_timestamp); +gdv_timestamp add_date64_int64(gdv_date64, gdv_int64); + +gdv_timestamp to_timestamp_int32(gdv_int32); +gdv_timestamp to_timestamp_int64(gdv_int64); +gdv_timestamp to_timestamp_float32(gdv_float32); +gdv_timestamp to_timestamp_float64(gdv_float64); + +gdv_time32 to_time_int32(gdv_int32); +gdv_time32 to_time_int64(gdv_int64); +gdv_time32 to_time_float32(gdv_float32); +gdv_time32 to_time_float64(gdv_float64); + +gdv_int64 date_sub_timestamp_int32(gdv_timestamp, gdv_int32); +gdv_int64 subtract_timestamp_int32(gdv_timestamp, gdv_int32); +gdv_int64 date_diff_timestamp_int64(gdv_timestamp, gdv_int64); + +gdv_boolean castBIT_utf8(gdv_int64 context, const char* data, gdv_int32 data_len); + +bool is_distinct_from_timestamp_timestamp(gdv_int64, bool, gdv_int64, bool); +bool is_not_distinct_from_int32_int32(gdv_int32, bool, gdv_int32, bool); + +gdv_int64 date_trunc_Second_date64(gdv_date64); +gdv_int64 date_trunc_Minute_date64(gdv_date64); +gdv_int64 date_trunc_Hour_date64(gdv_date64); +gdv_int64 date_trunc_Day_date64(gdv_date64); +gdv_int64 date_trunc_Month_date64(gdv_date64); +gdv_int64 date_trunc_Quarter_date64(gdv_date64); +gdv_int64 date_trunc_Year_date64(gdv_date64); +gdv_int64 date_trunc_Decade_date64(gdv_date64); +gdv_int64 date_trunc_Century_date64(gdv_date64); +gdv_int64 date_trunc_Millennium_date64(gdv_date64); + +gdv_int64 date_trunc_Week_timestamp(gdv_timestamp); +double months_between_timestamp_timestamp(gdv_uint64, gdv_uint64); + +gdv_int32 mem_compare(const char* left, gdv_int32 left_len, const char* right, + gdv_int32 right_len); + +gdv_int32 mod_int64_int32(gdv_int64 left, gdv_int32 right); +gdv_float64 mod_float64_float64(gdv_int64 context, gdv_float64 left, gdv_float64 right); + +gdv_int64 divide_int64_int64(gdv_int64 context, gdv_int64 in1, gdv_int64 in2); + +gdv_int64 div_int64_int64(gdv_int64 context, gdv_int64 in1, gdv_int64 in2); 
+gdv_float32 div_float32_float32(gdv_int64 context, gdv_float32 in1, gdv_float32 in2); +gdv_float64 div_float64_float64(gdv_int64 context, gdv_float64 in1, gdv_float64 in2); + +gdv_float32 round_float32(gdv_float32); +gdv_float64 round_float64(gdv_float64); +gdv_float32 round_float32_int32(gdv_float32 number, gdv_int32 out_scale); +gdv_float64 round_float64_int32(gdv_float64 number, gdv_int32 out_scale); +gdv_float64 get_scale_multiplier(gdv_int32); +gdv_int32 round_int32_int32(gdv_int32 number, gdv_int32 precision); +gdv_int64 round_int64_int32(gdv_int64 number, gdv_int32 precision); +gdv_int32 round_int32(gdv_int32); +gdv_int64 round_int64(gdv_int64); +gdv_int64 get_power_of_10(gdv_int32); + +const char* bin_int32(int64_t context, gdv_int32 value, int32_t* out_len); +const char* bin_int64(int64_t context, gdv_int64 value, int32_t* out_len); + +gdv_float64 cbrt_int32(gdv_int32); +gdv_float64 cbrt_int64(gdv_int64); +gdv_float64 cbrt_float32(gdv_float32); +gdv_float64 cbrt_float64(gdv_float64); + +gdv_float64 exp_int32(gdv_int32); +gdv_float64 exp_int64(gdv_int64); +gdv_float64 exp_float32(gdv_float32); +gdv_float64 exp_float64(gdv_float64); + +gdv_float64 log_int32(gdv_int32); +gdv_float64 log_int64(gdv_int64); +gdv_float64 log_float32(gdv_float32); +gdv_float64 log_float64(gdv_float64); + +gdv_float64 log10_int32(gdv_int32); +gdv_float64 log10_int64(gdv_int64); +gdv_float64 log10_float32(gdv_float32); +gdv_float64 log10_float64(gdv_float64); + +gdv_float64 sin_int32(gdv_int32); +gdv_float64 sin_int64(gdv_int64); +gdv_float64 sin_float32(gdv_float32); +gdv_float64 sin_float64(gdv_float64); +gdv_float64 cos_int32(gdv_int32); +gdv_float64 cos_int64(gdv_int64); +gdv_float64 cos_float32(gdv_float32); +gdv_float64 cos_float64(gdv_float64); +gdv_float64 asin_int32(gdv_int32); +gdv_float64 asin_int64(gdv_int64); +gdv_float64 asin_float32(gdv_float32); +gdv_float64 asin_float64(gdv_float64); +gdv_float64 acos_int32(gdv_int32); +gdv_float64 acos_int64(gdv_int64); 
+gdv_float64 acos_float32(gdv_float32); +gdv_float64 acos_float64(gdv_float64); +gdv_float64 tan_int32(gdv_int32); +gdv_float64 tan_int64(gdv_int64); +gdv_float64 tan_float32(gdv_float32); +gdv_float64 tan_float64(gdv_float64); +gdv_float64 atan_int32(gdv_int32); +gdv_float64 atan_int64(gdv_int64); +gdv_float64 atan_float32(gdv_float32); +gdv_float64 atan_float64(gdv_float64); +gdv_float64 sinh_int32(gdv_int32); +gdv_float64 sinh_int64(gdv_int64); +gdv_float64 sinh_float32(gdv_float32); +gdv_float64 sinh_float64(gdv_float64); +gdv_float64 cosh_int32(gdv_int32); +gdv_float64 cosh_int64(gdv_int64); +gdv_float64 cosh_float32(gdv_float32); +gdv_float64 cosh_float64(gdv_float64); +gdv_float64 tanh_int32(gdv_int32); +gdv_float64 tanh_int64(gdv_int64); +gdv_float64 tanh_float32(gdv_float32); +gdv_float64 tanh_float64(gdv_float64); +gdv_float64 atan2_int32_int32(gdv_int32 in1, gdv_int32 in2); +gdv_float64 atan2_int64_int64(gdv_int64 in1, gdv_int64 in2); +gdv_float64 atan2_float32_float32(gdv_float32 in1, gdv_float32 in2); +gdv_float64 atan2_float64_float64(gdv_float64 in1, gdv_float64 in2); +gdv_float64 cot_float32(gdv_float32); +gdv_float64 cot_float64(gdv_float64); +gdv_float64 radians_int32(gdv_int32); +gdv_float64 radians_int64(gdv_int64); +gdv_float64 radians_float32(gdv_float32); +gdv_float64 radians_float64(gdv_float64); +gdv_float64 degrees_int32(gdv_int32); +gdv_float64 degrees_int64(gdv_int64); +gdv_float64 degrees_float32(gdv_float32); +gdv_float64 degrees_float64(gdv_float64); + +gdv_int32 bitwise_and_int32_int32(gdv_int32 in1, gdv_int32 in2); +gdv_int64 bitwise_and_int64_int64(gdv_int64 in1, gdv_int64 in2); +gdv_int32 bitwise_or_int32_int32(gdv_int32 in1, gdv_int32 in2); +gdv_int64 bitwise_or_int64_int64(gdv_int64 in1, gdv_int64 in2); +gdv_int32 bitwise_xor_int32_int32(gdv_int32 in1, gdv_int32 in2); +gdv_int64 bitwise_xor_int64_int64(gdv_int64 in1, gdv_int64 in2); +gdv_int32 bitwise_not_int32(gdv_int32); +gdv_int64 bitwise_not_int64(gdv_int64); + +gdv_float64 
power_float64_float64(gdv_float64, gdv_float64); + +gdv_float64 log_int32_int32(gdv_int64 context, gdv_int32 base, gdv_int32 value); + +bool starts_with_utf8_utf8(const char* data, gdv_int32 data_len, const char* prefix, + gdv_int32 prefix_len); +bool ends_with_utf8_utf8(const char* data, gdv_int32 data_len, const char* suffix, + gdv_int32 suffix_len); +bool is_substr_utf8_utf8(const char* data, gdv_int32 data_len, const char* substr, + gdv_int32 substr_len); + +gdv_int32 utf8_length(gdv_int64 context, const char* data, gdv_int32 data_len); + +gdv_int32 utf8_last_char_pos(gdv_int64 context, const char* data, gdv_int32 data_len); + +gdv_date64 castDATE_utf8(int64_t execution_context, const char* input, gdv_int32 length); + +gdv_date64 castDATE_int64(gdv_int64 date); + +gdv_date64 castDATE_date32(gdv_date32 date); + +gdv_date32 castDATE_int32(gdv_int32 date); + +gdv_timestamp castTIMESTAMP_utf8(int64_t execution_context, const char* input, + gdv_int32 length); +gdv_timestamp castTIMESTAMP_date64(gdv_date64); +gdv_timestamp castTIMESTAMP_int64(gdv_int64); +gdv_date64 castDATE_timestamp(gdv_timestamp); +gdv_time32 castTIME_timestamp(gdv_timestamp timestamp_in_millis); +const char* castVARCHAR_timestamp_int64(int64_t, gdv_timestamp, gdv_int64, gdv_int32*); +gdv_date64 last_day_from_timestamp(gdv_date64 millis); + +gdv_int64 truncate_int64_int32(gdv_int64 in, gdv_int32 out_scale); + +const char* repeat_utf8_int32(gdv_int64 context, const char* in, gdv_int32 in_len, + gdv_int32 repeat_times, gdv_int32* out_len); + +const char* substr_utf8_int64_int64(gdv_int64 context, const char* input, + gdv_int32 in_len, gdv_int64 offset64, + gdv_int64 length, gdv_int32* out_len); +const char* substr_utf8_int64(gdv_int64 context, const char* input, gdv_int32 in_len, + gdv_int64 offset64, gdv_int32* out_len); + +const char* concat_utf8_utf8(gdv_int64 context, const char* left, gdv_int32 left_len, + bool left_validity, const char* right, gdv_int32 right_len, + bool right_validity, 
gdv_int32* out_len); +const char* concat_utf8_utf8_utf8(gdv_int64 context, const char* in1, gdv_int32 in1_len, + bool in1_validity, const char* in2, gdv_int32 in2_len, + bool in2_validity, const char* in3, gdv_int32 in3_len, + bool in3_validity, gdv_int32* out_len); +const char* concat_utf8_utf8_utf8_utf8(gdv_int64 context, const char* in1, + gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, + bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, + const char* in4, gdv_int32 in4_len, + bool in4_validity, gdv_int32* out_len); +const char* space_int32(gdv_int64 ctx, gdv_int32 n, int32_t* out_len); +const char* space_int64(gdv_int64 ctx, gdv_int64 n, int32_t* out_len); +const char* concat_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + gdv_int32* out_len); +const char* concat_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, gdv_int32* out_len); +const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7, + gdv_int32 in7_len, bool in7_validity, 
gdv_int32* out_len); +const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7, + gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len, + bool in8_validity, gdv_int32* out_len); +const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7, + gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len, + bool in8_validity, const char* in9, gdv_int32 in9_len, bool in9_validity, + gdv_int32* out_len); +const char* concat_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, bool in1_validity, + const char* in2, gdv_int32 in2_len, bool in2_validity, const char* in3, + gdv_int32 in3_len, bool in3_validity, const char* in4, gdv_int32 in4_len, + bool in4_validity, const char* in5, gdv_int32 in5_len, bool in5_validity, + const char* in6, gdv_int32 in6_len, bool in6_validity, const char* in7, + gdv_int32 in7_len, bool in7_validity, const char* in8, gdv_int32 in8_len, + bool in8_validity, const char* in9, gdv_int32 in9_len, bool in9_validity, + const char* in10, gdv_int32 in10_len, bool in10_validity, gdv_int32* out_len); + +const char* concatOperator_utf8_utf8(gdv_int64 context, const char* left, + gdv_int32 left_len, const char* right, + 
gdv_int32 right_len, gdv_int32* out_len); +const char* concatOperator_utf8_utf8_utf8(gdv_int64 context, const char* in1, + gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, + gdv_int32 in3_len, gdv_int32* out_len); +const char* concatOperator_utf8_utf8_utf8_utf8(gdv_int64 context, const char* in1, + gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, + gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, gdv_int32* out_len); +const char* concatOperator_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, gdv_int32* out_len); +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, gdv_int32* out_len); +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, gdv_int32* out_len); +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8, + gdv_int32 in8_len, gdv_int32* out_len); +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* 
in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8, + gdv_int32 in8_len, const char* in9, gdv_int32 in9_len, gdv_int32* out_len); +const char* concatOperator_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8_utf8( + gdv_int64 context, const char* in1, gdv_int32 in1_len, const char* in2, + gdv_int32 in2_len, const char* in3, gdv_int32 in3_len, const char* in4, + gdv_int32 in4_len, const char* in5, gdv_int32 in5_len, const char* in6, + gdv_int32 in6_len, const char* in7, gdv_int32 in7_len, const char* in8, + gdv_int32 in8_len, const char* in9, gdv_int32 in9_len, const char* in10, + gdv_int32 in10_len, gdv_int32* out_len); + +const char* castVARCHAR_binary_int64(gdv_int64 context, const char* data, + gdv_int32 data_len, int64_t out_len, + int32_t* out_length); + +const char* castVARCHAR_utf8_int64(gdv_int64 context, const char* data, + gdv_int32 data_len, int64_t out_len, + int32_t* out_length); + +const char* castVARBINARY_utf8_int64(gdv_int64 context, const char* data, + gdv_int32 data_len, int64_t out_len, + int32_t* out_length); + +const char* castVARBINARY_binary_int64(gdv_int64 context, const char* data, + gdv_int32 data_len, int64_t out_len, + int32_t* out_length); + +const char* reverse_utf8(gdv_int64 context, const char* data, gdv_int32 data_len, + int32_t* out_len); + +const char* ltrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len, + int32_t* out_len); + +const char* rtrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len, + int32_t* out_len); + +const char* btrim_utf8(gdv_int64 context, const char* data, gdv_int32 data_len, + int32_t* out_len); + +const char* ltrim_utf8_utf8(gdv_int64 context, const char* basetext, + gdv_int32 basetext_len, const char* trimtext, + gdv_int32 trimtext_len, int32_t* out_len); + +const char* rtrim_utf8_utf8(gdv_int64 context, const 
char* basetext, + gdv_int32 basetext_len, const char* trimtext, + gdv_int32 trimtext_len, int32_t* out_len); + +const char* btrim_utf8_utf8(gdv_int64 context, const char* basetext, + gdv_int32 basetext_len, const char* trimtext, + gdv_int32 trimtext_len, int32_t* out_len); + +gdv_int32 ascii_utf8(const char* data, gdv_int32 data_len); + +gdv_int32 locate_utf8_utf8(gdv_int64 context, const char* sub_str, gdv_int32 sub_str_len, + const char* str, gdv_int32 str_len); + +gdv_int32 strpos_utf8_utf8(gdv_int64 context, const char* str, gdv_int32 str_len, + const char* sub_str, gdv_int32 sub_str_len); + +gdv_int32 locate_utf8_utf8_int32(gdv_int64 context, const char* sub_str, + gdv_int32 sub_str_len, const char* str, + gdv_int32 str_len, gdv_int32 start_pos); + +const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 return_length, const char* fill_text, + gdv_int32 fill_text_len, gdv_int32* out_len); + +const char* rpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 return_length, const char* fill_text, + gdv_int32 fill_text_len, gdv_int32* out_len); + +const char* lpad_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 return_length, gdv_int32* out_len); + +const char* rpad_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 return_length, gdv_int32* out_len); + +const char* replace_with_max_len_utf8_utf8_utf8(gdv_int64 context, const char* text, + gdv_int32 text_len, const char* from_str, + gdv_int32 from_str_len, + const char* to_str, gdv_int32 to_str_len, + gdv_int32 max_length, gdv_int32* out_len); + +const char* replace_utf8_utf8_utf8(gdv_int64 context, const char* text, + gdv_int32 text_len, const char* from_str, + gdv_int32 from_str_len, const char* to_str, + gdv_int32 to_str_len, gdv_int32* out_len); + +const char* convert_replace_invalid_fromUTF8_binary(int64_t context, const char* text_in, + int32_t text_len, + const char* 
char_to_replace, + int32_t char_to_replace_len, + int32_t* out_len); + +const char* convert_toDOUBLE(int64_t context, double value, int32_t* out_len); + +const char* convert_toDOUBLE_be(int64_t context, double value, int32_t* out_len); + +const char* convert_toFLOAT(int64_t context, float value, int32_t* out_len); + +const char* convert_toFLOAT_be(int64_t context, float value, int32_t* out_len); + +const char* convert_toBIGINT(int64_t context, int64_t value, int32_t* out_len); + +const char* convert_toBIGINT_be(int64_t context, int64_t value, int32_t* out_len); + +const char* convert_toINT(int64_t context, int32_t value, int32_t* out_len); + +const char* convert_toINT_be(int64_t context, int32_t value, int32_t* out_len); + +const char* convert_toBOOLEAN(int64_t context, bool value, int32_t* out_len); + +const char* convert_toTIME_EPOCH(int64_t context, int32_t value, int32_t* out_len); + +const char* convert_toTIME_EPOCH_be(int64_t context, int32_t value, int32_t* out_len); + +const char* convert_toTIMESTAMP_EPOCH(int64_t context, int64_t timestamp, + int32_t* out_len); +const char* convert_toTIMESTAMP_EPOCH_be(int64_t context, int64_t timestamp, + int32_t* out_len); + +const char* convert_toDATE_EPOCH(int64_t context, int64_t date, int32_t* out_len); + +const char* convert_toDATE_EPOCH_be(int64_t context, int64_t date, int32_t* out_len); + +const char* convert_toUTF8(int64_t context, const char* value, int32_t value_len, + int32_t* out_len); + +const char* split_part(gdv_int64 context, const char* text, gdv_int32 text_len, + const char* splitter, gdv_int32 split_len, gdv_int32 index, + gdv_int32* out_len); + +const char* byte_substr_binary_int32_int32(gdv_int64 context, const char* text, + gdv_int32 text_len, gdv_int32 offset, + gdv_int32 length, gdv_int32* out_len); + +const char* castVARCHAR_bool_int64(gdv_int64 context, gdv_boolean value, + gdv_int64 out_len, gdv_int32* out_length); + +const char* castVARCHAR_int32_int64(int64_t context, int32_t value, int64_t 
len, + int32_t* out_len); + +const char* castVARCHAR_int64_int64(int64_t context, int64_t value, int64_t len, + int32_t* out_len); + +const char* castVARCHAR_float32_int64(int64_t context, float value, int64_t len, + int32_t* out_len); + +const char* castVARCHAR_float64_int64(int64_t context, double value, int64_t len, + int32_t* out_len); + +const char* left_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 number, gdv_int32* out_len); + +const char* right_utf8_int32(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32 number, gdv_int32* out_len); + +const char* binary_string(gdv_int64 context, const char* text, gdv_int32 text_len, + gdv_int32* out_len); + +int32_t castINT_utf8(int64_t context, const char* data, int32_t len); + +int64_t castBIGINT_utf8(int64_t context, const char* data, int32_t len); + +float castFLOAT4_utf8(int64_t context, const char* data, int32_t len); + +double castFLOAT8_utf8(int64_t context, const char* data, int32_t len); + +int32_t castINT_float32(gdv_float32 value); + +int32_t castINT_float64(gdv_float64 value); + +int64_t castBIGINT_float32(gdv_float32 value); + +int64_t castBIGINT_float64(gdv_float64 value); + +int64_t castBIGINT_daytimeinterval(gdv_day_time_interval in); + +int32_t castINT_year_interval(gdv_month_interval in); + +int64_t castBIGINT_year_interval(gdv_month_interval in); + +gdv_day_time_interval castNULLABLEINTERVALDAY_int32(gdv_int32 in); + +gdv_day_time_interval castNULLABLEINTERVALDAY_int64(gdv_int64 in); + +gdv_month_interval castNULLABLEINTERVALYEAR_int32(int64_t context, gdv_int32 in); + +gdv_month_interval castNULLABLEINTERVALYEAR_int64(int64_t context, gdv_int64 in); + +} // extern "C" diff --git a/src/arrow/cpp/src/gandiva/precompiled_bitcode.cc.in b/src/arrow/cpp/src/gandiva/precompiled_bitcode.cc.in new file mode 100644 index 000000000..9c382961d --- /dev/null +++ b/src/arrow/cpp/src/gandiva/precompiled_bitcode.cc.in @@ -0,0 +1,26 @@ +// Licensed to the Apache 
Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <string> + +namespace gandiva { + +// Content of precompiled bitcode file. +extern const unsigned char kPrecompiledBitcode[] = { <DATA_CHARS> }; +extern const size_t kPrecompiledBitcodeSize = sizeof(kPrecompiledBitcode); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/projector.cc b/src/arrow/cpp/src/gandiva/projector.cc new file mode 100644 index 000000000..ff167538f --- /dev/null +++ b/src/arrow/cpp/src/gandiva/projector.cc @@ -0,0 +1,369 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/projector.h" + +#include <memory> +#include <thread> +#include <utility> +#include <vector> + +#include "arrow/util/hash_util.h" +#include "arrow/util/logging.h" + +#include "gandiva/cache.h" +#include "gandiva/expr_validator.h" +#include "gandiva/llvm_generator.h" + +namespace gandiva { + +class ProjectorCacheKey { + public: + ProjectorCacheKey(SchemaPtr schema, std::shared_ptr<Configuration> configuration, + ExpressionVector expression_vector, SelectionVector::Mode mode) + : schema_(schema), configuration_(configuration), mode_(mode), uniqifier_(0) { + static const int kSeedValue = 4; + size_t result = kSeedValue; + for (auto& expr : expression_vector) { + std::string expr_as_string = expr->ToString(); + expressions_as_strings_.push_back(expr_as_string); + arrow::internal::hash_combine(result, expr_as_string); + UpdateUniqifier(expr_as_string); + } + arrow::internal::hash_combine(result, static_cast<size_t>(mode)); + arrow::internal::hash_combine(result, configuration->Hash()); + arrow::internal::hash_combine(result, schema_->ToString()); + arrow::internal::hash_combine(result, uniqifier_); + hash_code_ = result; + } + + std::size_t Hash() const { return hash_code_; } + + bool operator==(const ProjectorCacheKey& other) const { + // arrow schema does not overload equality operators. 
+ if (!(schema_->Equals(*other.schema().get(), true))) { + return false; + } + + if (*configuration_ != *other.configuration_) { + return false; + } + + if (expressions_as_strings_ != other.expressions_as_strings_) { + return false; + } + + if (mode_ != other.mode_) { + return false; + } + + if (uniqifier_ != other.uniqifier_) { + return false; + } + return true; + } + + bool operator!=(const ProjectorCacheKey& other) const { return !(*this == other); } + + SchemaPtr schema() const { return schema_; } + + std::string ToString() const { + std::stringstream ss; + // indent, window, indent_size, null_rep and skip new lines. + arrow::PrettyPrintOptions options{0, 10, 2, "null", true}; + DCHECK_OK(PrettyPrint(*schema_.get(), options, &ss)); + + ss << "Expressions: ["; + bool first = true; + for (auto& expr : expressions_as_strings_) { + if (first) { + first = false; + } else { + ss << ", "; + } + + ss << expr; + } + ss << "]"; + return ss.str(); + } + + private: + void UpdateUniqifier(const std::string& expr) { + if (uniqifier_ == 0) { + // caching of expressions with re2 patterns causes lock contention. So, use + // multiple instances to reduce contention. 
+ if (expr.find(" like(") != std::string::npos) { + uniqifier_ = std::hash<std::thread::id>()(std::this_thread::get_id()) % 16; + } + } + } + + const SchemaPtr schema_; + const std::shared_ptr<Configuration> configuration_; + SelectionVector::Mode mode_; + std::vector<std::string> expressions_as_strings_; + size_t hash_code_; + uint32_t uniqifier_; +}; + +Projector::Projector(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema, + const FieldVector& output_fields, + std::shared_ptr<Configuration> configuration) + : llvm_generator_(std::move(llvm_generator)), + schema_(schema), + output_fields_(output_fields), + configuration_(configuration) {} + +Projector::~Projector() {} + +Status Projector::Make(SchemaPtr schema, const ExpressionVector& exprs, + std::shared_ptr<Projector>* projector) { + return Projector::Make(schema, exprs, SelectionVector::Mode::MODE_NONE, + ConfigurationBuilder::DefaultConfiguration(), projector); +} + +Status Projector::Make(SchemaPtr schema, const ExpressionVector& exprs, + std::shared_ptr<Configuration> configuration, + std::shared_ptr<Projector>* projector) { + return Projector::Make(schema, exprs, SelectionVector::Mode::MODE_NONE, configuration, + projector); +} + +Status Projector::Make(SchemaPtr schema, const ExpressionVector& exprs, + SelectionVector::Mode selection_vector_mode, + std::shared_ptr<Configuration> configuration, + std::shared_ptr<Projector>* projector) { + ARROW_RETURN_IF(schema == nullptr, Status::Invalid("Schema cannot be null")); + ARROW_RETURN_IF(exprs.empty(), Status::Invalid("Expressions cannot be empty")); + ARROW_RETURN_IF(configuration == nullptr, + Status::Invalid("Configuration cannot be null")); + + // see if equivalent projector was already built + static Cache<ProjectorCacheKey, std::shared_ptr<Projector>> cache; + ProjectorCacheKey cache_key(schema, configuration, exprs, selection_vector_mode); + std::shared_ptr<Projector> cached_projector = cache.GetModule(cache_key); + if (cached_projector != 
nullptr) { + *projector = cached_projector; + return Status::OK(); + } + + // Build LLVM generator, and generate code for the specified expressions + std::unique_ptr<LLVMGenerator> llvm_gen; + ARROW_RETURN_NOT_OK(LLVMGenerator::Make(configuration, &llvm_gen)); + + // Run the validation on the expressions. + // Return if any of the expression is invalid since + // we will not be able to process further. + ExprValidator expr_validator(llvm_gen->types(), schema); + for (auto& expr : exprs) { + ARROW_RETURN_NOT_OK(expr_validator.Validate(expr)); + } + + // Start measuring build time + auto begin = std::chrono::high_resolution_clock::now(); + ARROW_RETURN_NOT_OK(llvm_gen->Build(exprs, selection_vector_mode)); + // Stop measuring time and calculate the elapsed time + auto end = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count(); + + // save the output field types. Used for validation at Evaluate() time. + std::vector<FieldPtr> output_fields; + output_fields.reserve(exprs.size()); + for (auto& expr : exprs) { + output_fields.push_back(expr->result()); + } + + // Instantiate the projector with the completely built llvm generator + *projector = std::shared_ptr<Projector>( + new Projector(std::move(llvm_gen), schema, output_fields, configuration)); + ValueCacheObject<std::shared_ptr<Projector>> value_cache(*projector, elapsed); + cache.PutModule(cache_key, value_cache); + + return Status::OK(); +} + +Status Projector::Evaluate(const arrow::RecordBatch& batch, + const ArrayDataVector& output_data_vecs) { + return Evaluate(batch, nullptr, output_data_vecs); +} + +Status Projector::Evaluate(const arrow::RecordBatch& batch, + const SelectionVector* selection_vector, + const ArrayDataVector& output_data_vecs) { + ARROW_RETURN_NOT_OK(ValidateEvaluateArgsCommon(batch)); + + if (output_data_vecs.size() != output_fields_.size()) { + std::stringstream ss; + ss << "number of buffers for 
output_data_vecs is " << output_data_vecs.size() + << ", expected " << output_fields_.size(); + return Status::Invalid(ss.str()); + } + + // Check every caller-provided output array before running the generated code. + int idx = 0; + for (auto& array_data : output_data_vecs) { + if (array_data == nullptr) { + std::stringstream ss; + ss << "array for output field " << output_fields_[idx]->name() << " is null."; + return Status::Invalid(ss.str()); + } + + // With a selection vector only the selected slots are written. + auto num_rows = + selection_vector == nullptr ? batch.num_rows() : selection_vector->GetNumSlots(); + + ARROW_RETURN_NOT_OK( + ValidateArrayDataCapacity(*array_data, *(output_fields_[idx]), num_rows)); + ++idx; + } + return llvm_generator_->Execute(batch, selection_vector, output_data_vecs); +} + +// Overload without a selection vector: forwards with a null selection vector. +Status Projector::Evaluate(const arrow::RecordBatch& batch, arrow::MemoryPool* pool, + arrow::ArrayVector* output) { + return Evaluate(batch, nullptr, pool, output); +} + +Status Projector::Evaluate(const arrow::RecordBatch& batch, + const SelectionVector* selection_vector, + arrow::MemoryPool* pool, arrow::ArrayVector* output) { + ARROW_RETURN_NOT_OK(ValidateEvaluateArgsCommon(batch)); + ARROW_RETURN_IF(output == nullptr, Status::Invalid("Output must be non-null.")); + ARROW_RETURN_IF(pool == nullptr, Status::Invalid("Memory pool must be non-null.")); + + auto num_rows = + selection_vector == nullptr ? batch.num_rows() : selection_vector->GetNumSlots(); + // Allocate the output data vecs. + ArrayDataVector output_data_vecs; + for (auto& field : output_fields_) { + ArrayDataPtr output_data; + + ARROW_RETURN_NOT_OK(AllocArrayData(field->type(), num_rows, pool, &output_data)); + output_data_vecs.push_back(output_data); + } + + // Execute the expression(s). + ARROW_RETURN_NOT_OK( + llvm_generator_->Execute(batch, selection_vector, output_data_vecs)); + + // Wrap the populated array data in arrow arrays and return them via 'output'. + output->clear(); + for (auto& array_data : output_data_vecs) { + output->push_back(arrow::MakeArray(array_data)); + } + return Status::OK(); +} + +// TODO : handle complex vectors (list/map/..) 
+Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, + arrow::MemoryPool* pool, ArrayDataPtr* array_data) { + arrow::Status astatus; + std::vector<std::shared_ptr<arrow::Buffer>> buffers; + + // The output vector always has a null bitmap. + int64_t size = arrow::BitUtil::BytesForBits(num_records); + ARROW_ASSIGN_OR_RAISE(auto bitmap_buffer, arrow::AllocateBuffer(size, pool)); + buffers.push_back(std::move(bitmap_buffer)); + + // String/Binary vectors have an offsets array. + auto type_id = type->id(); + if (arrow::is_binary_like(type_id)) { + auto offsets_len = arrow::BitUtil::BytesForBits((num_records + 1) * 32); + + ARROW_ASSIGN_OR_RAISE(auto offsets_buffer, arrow::AllocateBuffer(offsets_len, pool)); + buffers.push_back(std::move(offsets_buffer)); + } + + // The output vector always has a data array. + int64_t data_len; + if (arrow::is_primitive(type_id) || type_id == arrow::Type::DECIMAL) { + const auto& fw_type = dynamic_cast<const arrow::FixedWidthType&>(*type); + data_len = arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width()); + } else if (arrow::is_binary_like(type_id)) { + // we don't know the expected size for varlen output vectors. + data_len = 0; + } else { + return Status::Invalid("Unsupported output data type " + type->ToString()); + } + ARROW_ASSIGN_OR_RAISE(auto data_buffer, arrow::AllocateResizableBuffer(data_len, pool)); + + // This is not strictly required but valgrind gets confused and detects this + // as uninitialized memory access. See arrow::util::SetBitTo(). 
+ if (type->id() == arrow::Type::BOOL) { + memset(data_buffer->mutable_data(), 0, data_len); + } + buffers.push_back(std::move(data_buffer)); + + *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers)); + return Status::OK(); +} + +// Reject batches whose schema differs from the projector's schema, and empty batches. +Status Projector::ValidateEvaluateArgsCommon(const arrow::RecordBatch& batch) { + ARROW_RETURN_IF(!batch.schema()->Equals(*schema_), + Status::Invalid("Schema in RecordBatch must match schema in Make()")); + ARROW_RETURN_IF(batch.num_rows() == 0, + Status::Invalid("RecordBatch must be non-empty.")); + + return Status::OK(); +} + +// Check that a caller-allocated ArrayData can hold 'num_records' values of 'field'. +// Returns Status::Invalid when a buffer is missing, too small, or (for varlen +// output) not resizable, and for unsupported output types. +Status Projector::ValidateArrayDataCapacity(const arrow::ArrayData& array_data, + const arrow::Field& field, + int64_t num_records) { + ARROW_RETURN_IF(array_data.buffers.size() < 2, + Status::Invalid("ArrayData must have at least 2 buffers")); + + int64_t min_bitmap_len = arrow::BitUtil::BytesForBits(num_records); + int64_t bitmap_len = array_data.buffers[0]->capacity(); + ARROW_RETURN_IF( + bitmap_len < min_bitmap_len, + Status::Invalid("Bitmap buffer too small for ", field.name(), " expected minimum ", + min_bitmap_len, " actual size ", bitmap_len)); + + auto type_id = field.type()->id(); + if (arrow::is_binary_like(type_id)) { + // varlen vectors need three buffers (bitmap, offsets, data); the check above only + // guarantees two, so verify the data buffer exists before it is indexed below. + ARROW_RETURN_IF(array_data.buffers.size() < 3, + Status::Invalid("ArrayData for varlen output field ", field.name(), + " must have 3 buffers")); + + // validate size of offsets buffer. + int64_t min_offsets_len = arrow::BitUtil::BytesForBits((num_records + 1) * 32); + int64_t offsets_len = array_data.buffers[1]->capacity(); + ARROW_RETURN_IF( + offsets_len < min_offsets_len, + Status::Invalid("offsets buffer too small for ", field.name(), + " minimum required ", min_offsets_len, " actual ", offsets_len)); + + // check that it's resizable. + auto resizable = dynamic_cast<arrow::ResizableBuffer*>(array_data.buffers[2].get()); + ARROW_RETURN_IF( + resizable == nullptr, + Status::Invalid("data buffer for varlen output vectors must be resizable")); + } else if (arrow::is_primitive(type_id) || type_id == arrow::Type::DECIMAL) { + // verify size of data buffer. 
+ const auto& fw_type = dynamic_cast<const arrow::FixedWidthType&>(*field.type()); + int64_t min_data_len = + arrow::BitUtil::BytesForBits(num_records * fw_type.bit_width()); + int64_t data_len = array_data.buffers[1]->capacity(); + ARROW_RETURN_IF(data_len < min_data_len, + Status::Invalid("Data buffer too small for ", field.name())); + } else { + return Status::Invalid("Unsupported output data type " + field.type()->ToString()); + } + + return Status::OK(); +} + +std::string Projector::DumpIR() { return llvm_generator_->DumpIR(); } + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/projector.h b/src/arrow/cpp/src/gandiva/projector.h new file mode 100644 index 000000000..20b36c9d8 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/projector.h @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "arrow/status.h" + +#include "gandiva/arrow.h" +#include "gandiva/configuration.h" +#include "gandiva/expression.h" +#include "gandiva/selection_vector.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +class LLVMGenerator; + +/// \brief projection using expressions. 
+/// +/// A projector is built for a specific schema and vector of expressions. +/// Once the projector is built, it can be used to evaluate many row batches. +class GANDIVA_EXPORT Projector { + public: + // Inline dtor will attempt to resolve the destructor for + // LLVMGenerator on MSVC, so we compile the dtor in the object code + ~Projector(); + + /// Build a default projector for the given schema to evaluate + /// the vector of expressions. + /// + /// \param[in] schema schema for the record batches, and the expressions. + /// \param[in] exprs vector of expressions. + /// \param[out] projector the returned projector object + static Status Make(SchemaPtr schema, const ExpressionVector& exprs, + std::shared_ptr<Projector>* projector); + + /// Build a projector for the given schema to evaluate the vector of expressions. + /// Customize the projector with runtime configuration. + /// + /// \param[in] schema schema for the record batches, and the expressions. + /// \param[in] exprs vector of expressions. + /// \param[in] configuration run time configuration. + /// \param[out] projector the returned projector object + static Status Make(SchemaPtr schema, const ExpressionVector& exprs, + std::shared_ptr<Configuration> configuration, + std::shared_ptr<Projector>* projector); + + /// Build a projector for the given schema to evaluate the vector of expressions. + /// Customize the projector with runtime configuration. + /// + /// \param[in] schema schema for the record batches, and the expressions. + /// \param[in] exprs vector of expressions. + /// \param[in] selection_vector_mode mode of selection vector + /// \param[in] configuration run time configuration. 
+ /// \param[out] projector the returned projector object + static Status Make(SchemaPtr schema, const ExpressionVector& exprs, + SelectionVector::Mode selection_vector_mode, + std::shared_ptr<Configuration> configuration, + std::shared_ptr<Projector>* projector); + + /// Evaluate the specified record batch, and return the allocated and populated output + /// arrays. The output arrays will be allocated from the memory pool 'pool', and added + /// to the vector 'output'. + /// + /// \param[in] batch the record batch. schema should be the same as the one in 'Make' + /// \param[in] pool memory pool used to allocate output arrays (if required). + /// \param[out] output the vector of allocated/populated arrays. + Status Evaluate(const arrow::RecordBatch& batch, arrow::MemoryPool* pool, + arrow::ArrayVector* output); + + /// Evaluate the specified record batch, and populate the output arrays. The output + /// arrays of sufficient capacity must be allocated by the caller. + /// + /// \param[in] batch the record batch. schema should be the same as the one in 'Make' + /// \param[in,out] output vector of arrays, the arrays are allocated by the caller and + /// populated by Evaluate. + Status Evaluate(const arrow::RecordBatch& batch, const ArrayDataVector& output); + + /// Evaluate the specified record batch, and return the allocated and populated output + /// arrays. The output arrays will be allocated from the memory pool 'pool', and added + /// to the vector 'output'. + /// + /// \param[in] batch the record batch. schema should be the same as the one in 'Make' + /// \param[in] selection_vector selection vector which has filtered row positions. + /// \param[in] pool memory pool used to allocate output arrays (if required). + /// \param[out] output the vector of allocated/populated arrays. 
+ Status Evaluate(const arrow::RecordBatch& batch, + const SelectionVector* selection_vector, arrow::MemoryPool* pool, + arrow::ArrayVector* output); + + /// Evaluate the specified record batch, and populate the output arrays at the filtered + /// positions. The output arrays of sufficient capacity must be allocated by the caller. + /// + /// \param[in] batch the record batch. schema should be the same as the one in 'Make' + /// \param[in] selection_vector selection vector which has the filtered row positions + /// \param[in,out] output vector of arrays, the arrays are allocated by the caller and + /// populated by Evaluate. + Status Evaluate(const arrow::RecordBatch& batch, + const SelectionVector* selection_vector, const ArrayDataVector& output); + + /// Return the IR dump produced by the underlying LLVM generator. + std::string DumpIR(); + + private: + Projector(std::unique_ptr<LLVMGenerator> llvm_generator, SchemaPtr schema, + const FieldVector& output_fields, std::shared_ptr<Configuration>); + + /// Allocate an ArrayData of the given type for 'num_records' records, using 'pool'. + Status AllocArrayData(const DataTypePtr& type, int64_t num_records, + arrow::MemoryPool* pool, ArrayDataPtr* array_data); + + /// Validate that the ArrayData has sufficient capacity to accommodate 'num_records'. + Status ValidateArrayDataCapacity(const arrow::ArrayData& array_data, + const arrow::Field& field, int64_t num_records); + + /// Validate the common args for Evaluate() APIs. + Status ValidateEvaluateArgsCommon(const arrow::RecordBatch& batch); + + std::unique_ptr<LLVMGenerator> llvm_generator_; + SchemaPtr schema_; + FieldVector output_fields_; + std::shared_ptr<Configuration> configuration_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/proto/Types.proto b/src/arrow/cpp/src/gandiva/proto/Types.proto new file mode 100644 index 000000000..eb0d996b9 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/proto/Types.proto @@ -0,0 +1,255 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
// See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

syntax = "proto2";
package types;

option java_package = "org.apache.arrow.gandiva.ipc";
option java_outer_classname = "GandivaTypes";
option optimize_for = SPEED;

// ---------------------------------------------------------------------------
// Enumerations
// ---------------------------------------------------------------------------

// Data type tags; the values represent the arrow::Type fields declared in
// src/arrow/type.h.
enum GandivaType {
  NONE              = 0;   // arrow::Type::NA
  BOOL              = 1;   // arrow::Type::BOOL
  UINT8             = 2;   // arrow::Type::UINT8
  INT8              = 3;   // arrow::Type::INT8
  UINT16            = 4;   // represents arrow::Type fields in src/arrow/type.h
  INT16             = 5;
  UINT32            = 6;
  INT32             = 7;
  UINT64            = 8;
  INT64             = 9;
  HALF_FLOAT        = 10;
  FLOAT             = 11;
  DOUBLE            = 12;
  UTF8              = 13;
  BINARY            = 14;
  FIXED_SIZE_BINARY = 15;
  DATE32            = 16;
  DATE64            = 17;
  TIMESTAMP         = 18;
  TIME32            = 19;
  TIME64            = 20;
  INTERVAL          = 21;
  DECIMAL           = 22;
  LIST              = 23;
  STRUCT            = 24;
  UNION             = 25;
  DICTIONARY        = 26;
  MAP               = 27;
}

// Unit qualifier carried by ExtGandivaType.dateUnit.
enum DateUnit {
  DAY   = 0;
  MILLI = 1;
}

// Unit qualifier carried by ExtGandivaType.timeUnit.
enum TimeUnit {
  SEC      = 0;
  MILLISEC = 1;
  MICROSEC = 2;
  NANOSEC  = 3;
}

// Interval flavour carried by ExtGandivaType.intervalType.
enum IntervalType {
  YEAR_MONTH = 0;
  DAY_TIME   = 1;
}

// Index width of a selection vector (SV_NONE = 0 is the first/default value).
enum SelectionVectorType {
  SV_NONE  = 0;
  SV_INT16 = 1;
  SV_INT32 = 2;
}

// ---------------------------------------------------------------------------
// Types, fields and schemas
// ---------------------------------------------------------------------------

// A GandivaType tag together with the extra parameters some types need.
message ExtGandivaType {
  optional GandivaType  type         = 1;
  optional uint32       width        = 2;  // used by FIXED_SIZE_BINARY
  optional int32        precision    = 3;  // used by DECIMAL
  optional int32        scale        = 4;  // used by DECIMAL
  optional DateUnit     dateUnit     = 5;  // used by DATE32/DATE64
  optional TimeUnit     timeUnit     = 6;  // used by TIME32/TIME64
  optional string       timeZone     = 7;  // used by TIMESTAMP
  optional IntervalType intervalType = 8;  // used by INTERVAL
}

// A named, typed column; recursive through `children`.
message Field {
  // name of the field
  optional string         name     = 1;
  optional ExtGandivaType type     = 2;
  optional bool           nullable = 3;
  // for complex data types like structs, unions
  repeated Field          children = 4;
}

// ---------------------------------------------------------------------------
// Expression tree nodes
// ---------------------------------------------------------------------------

// Reference to an input field.
message FieldNode {
  optional Field field = 1;
}

// Call of a named function with argument sub-trees and a declared result type.
message FunctionNode {
  optional string         functionName = 1;
  repeated TreeNode       inArgs       = 2;
  optional ExtGandivaType returnType   = 3;
}

// if (cond) thenNode else elseNode, with a declared result type.
message IfNode {
  optional TreeNode       cond       = 1;
  optional TreeNode       thenNode   = 2;
  optional TreeNode       elseNode   = 3;
  optional ExtGandivaType returnType = 4;
}

// Boolean AND over any number of operand sub-trees.
message AndNode {
  repeated TreeNode args = 1;
}

// Boolean OR over any number of operand sub-trees.
message OrNode {
  repeated TreeNode args = 1;
}

// Typed null literal.
message NullNode {
  optional ExtGandivaType type = 1;
}

// 32-bit integer literal.
message IntNode {
  optional int32 value = 1;
}

// Single-precision floating point literal.
message FloatNode {
  optional float value = 1;
}

// Double-precision floating point literal.
message DoubleNode {
  optional double value = 1;
}

// Boolean literal.
message BooleanNode {
  optional bool value = 1;
}

// 64-bit integer literal.
message LongNode {
  optional int64 value = 1;
}

// String literal, transported as raw bytes.
message StringNode {
  optional bytes value = 1;
}

// Binary literal.
message BinaryNode {
  optional bytes value = 1;
}

// Decimal literal: textual value plus precision and scale.
message DecimalNode {
  optional string value     = 1;
  optional int32  precision = 2;
  optional int32  scale     = 3;
}

// One node of an expression tree. Every member is declared `optional`.
// NOTE(review): presumably exactly one member is populated per node - confirm
// against the code that builds and reads these messages.
message TreeNode {
  optional FieldNode    fieldNode = 1;
  optional FunctionNode fnNode    = 2;

  // control expressions
  optional IfNode  ifNode  = 6;
  optional AndNode andNode = 7;
  optional OrNode  orNode  = 8;

  // literals
  optional NullNode    nullNode    = 11;
  optional IntNode     intNode     = 12;
  optional FloatNode   floatNode   = 13;
  optional LongNode    longNode    = 14;
  optional BooleanNode booleanNode = 15;
  optional DoubleNode  doubleNode  = 16;
  optional StringNode  stringNode  = 17;
  optional BinaryNode  binaryNode  = 18;
  optional DecimalNode decimalNode = 19;

  // in expr
  optional InNode inNode = 21;
}

// ---------------------------------------------------------------------------
// Top-level containers
// ---------------------------------------------------------------------------

// An expression tree plus the field describing its result.
message ExpressionRoot {
  optional TreeNode root       = 1;
  optional Field    resultType = 2;
}

// A list of expressions (note: the only field uses tag 2).
message ExpressionList {
  repeated ExpressionRoot exprs = 2;
}

// A condition, given by the root of its expression tree.
message Condition {
  optional TreeNode root = 1;
}

// A list of columns.
message Schema {
  repeated Field columns = 1;
}

// A list of data types.
message GandivaDataTypes {
  repeated ExtGandivaType dataType = 1;
}

// A list of function signatures.
message GandivaFunctions {
  repeated FunctionSignature function = 1;
}

// Name, return type and parameter types of one function.
message FunctionSignature {
  optional string         name       = 1;
  optional ExtGandivaType returnType = 2;
  repeated ExtGandivaType paramTypes = 3;
}

// ---------------------------------------------------------------------------
// IN expression
// ---------------------------------------------------------------------------

// IN expression: the sub-tree in `node` plus one optional constant list per
// literal kind.
message InNode {
  optional TreeNode         node          = 1;
  optional IntConstants     intValues     = 2;
  optional LongConstants    longValues    = 3;
  optional StringConstants  stringValues  = 4;
  optional BinaryConstants  binaryValues  = 5;
  optional DecimalConstants decimalValues = 6;
  optional FloatConstants   floatValues   = 7;
  optional DoubleConstants  doubleValues  = 8;
}

message IntConstants {
  repeated IntNode intValues = 1;
}

message LongConstants {
  repeated LongNode longValues = 1;
}

message DecimalConstants {
  repeated DecimalNode decimalValues = 1;
}

message FloatConstants {
  repeated FloatNode floatValues = 1;
}

message DoubleConstants {
  repeated DoubleNode doubleValues = 1;
}

message StringConstants {
  repeated StringNode stringValues = 1;
}

message BinaryConstants {
  repeated BinaryNode binaryValues = 1;
}

// diff --git a/src/arrow/cpp/src/gandiva/random_generator_holder.cc b/src/arrow/cpp/src/gandiva/random_generator_holder.cc
// new file mode 100644
// index 000000000..3471c87d9
// --- /dev/null
// +++ b/src/arrow/cpp/src/gandiva/random_generator_holder.cc
// @@ -0,0 +1,45 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/random_generator_holder.h" +#include "gandiva/node.h" + +namespace gandiva { +Status RandomGeneratorHolder::Make(const FunctionNode& node, + std::shared_ptr<RandomGeneratorHolder>* holder) { + ARROW_RETURN_IF(node.children().size() > 1, + Status::Invalid("'random' function requires at most one parameter")); + + if (node.children().size() == 0) { + *holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder()); + return Status::OK(); + } + + auto literal = dynamic_cast<LiteralNode*>(node.children().at(0).get()); + ARROW_RETURN_IF(literal == nullptr, + Status::Invalid("'random' function requires a literal as parameter")); + + auto literal_type = literal->return_type()->id(); + ARROW_RETURN_IF( + literal_type != arrow::Type::INT32, + Status::Invalid("'random' function requires an int32 literal as parameter")); + + *holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder( + literal->is_null() ? 
0 : arrow::util::get<int32_t>(literal->holder()))); + return Status::OK(); +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/random_generator_holder.h b/src/arrow/cpp/src/gandiva/random_generator_holder.h new file mode 100644 index 000000000..65b6607e8 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/random_generator_holder.h @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <memory> +#include <random> + +#include "arrow/status.h" +#include "arrow/util/io_util.h" + +#include "gandiva/function_holder.h" +#include "gandiva/node.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// Function Holder for 'random' +class GANDIVA_EXPORT RandomGeneratorHolder : public FunctionHolder { + public: + ~RandomGeneratorHolder() override = default; + + static Status Make(const FunctionNode& node, + std::shared_ptr<RandomGeneratorHolder>* holder); + + double operator()() { return distribution_(generator_); } + + private: + explicit RandomGeneratorHolder(int seed) : distribution_(0, 1) { + int64_t seed64 = static_cast<int64_t>(seed); + seed64 = (seed64 ^ 0x00000005DEECE66D) & 0x0000ffffffffffff; + generator_.seed(static_cast<uint64_t>(seed64)); + } + + RandomGeneratorHolder() : distribution_(0, 1) { + generator_.seed(::arrow::internal::GetRandomSeed()); + } + + std::mt19937_64 generator_; + std::uniform_real_distribution<> distribution_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/random_generator_holder_test.cc b/src/arrow/cpp/src/gandiva/random_generator_holder_test.cc new file mode 100644 index 000000000..4b16c1b7d --- /dev/null +++ b/src/arrow/cpp/src/gandiva/random_generator_holder_test.cc @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/random_generator_holder.h" + +#include <memory> + +#include <gtest/gtest.h> + +namespace gandiva { + +class TestRandGenHolder : public ::testing::Test { + public: + FunctionNode BuildRandFunc() { return FunctionNode("random", {}, arrow::float64()); } + + FunctionNode BuildRandWithSeedFunc(int32_t seed, bool seed_is_null) { + auto seed_node = + std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(seed), seed_is_null); + return FunctionNode("rand", {seed_node}, arrow::float64()); + } +}; + +TEST_F(TestRandGenHolder, NoSeed) { + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder; + FunctionNode rand_func = BuildRandFunc(); + auto status = RandomGeneratorHolder::Make(rand_func, &rand_gen_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& random = *rand_gen_holder; + EXPECT_NE(random(), random()); +} + +TEST_F(TestRandGenHolder, WithValidEqualSeeds) { + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1; + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2; + FunctionNode rand_func_1 = BuildRandWithSeedFunc(12, false); + FunctionNode rand_func_2 = BuildRandWithSeedFunc(12, false); + auto status = RandomGeneratorHolder::Make(rand_func_1, &rand_gen_holder_1); + EXPECT_EQ(status.ok(), true) << status.message(); + status = RandomGeneratorHolder::Make(rand_func_2, &rand_gen_holder_2); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& random_1 = *rand_gen_holder_1; + auto& random_2 = *rand_gen_holder_2; + EXPECT_EQ(random_1(), random_2()); + EXPECT_EQ(random_1(), random_2()); + 
EXPECT_GT(random_1(), 0); + EXPECT_NE(random_1(), random_2()); + EXPECT_LT(random_2(), 1); + EXPECT_EQ(random_1(), random_2()); +} + +TEST_F(TestRandGenHolder, WithValidSeeds) { + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1; + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2; + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_3; + FunctionNode rand_func_1 = BuildRandWithSeedFunc(11, false); + FunctionNode rand_func_2 = BuildRandWithSeedFunc(12, false); + FunctionNode rand_func_3 = BuildRandWithSeedFunc(-12, false); + auto status = RandomGeneratorHolder::Make(rand_func_1, &rand_gen_holder_1); + EXPECT_EQ(status.ok(), true) << status.message(); + status = RandomGeneratorHolder::Make(rand_func_2, &rand_gen_holder_2); + EXPECT_EQ(status.ok(), true) << status.message(); + status = RandomGeneratorHolder::Make(rand_func_3, &rand_gen_holder_3); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& random_1 = *rand_gen_holder_1; + auto& random_2 = *rand_gen_holder_2; + auto& random_3 = *rand_gen_holder_3; + EXPECT_NE(random_2(), random_3()); + EXPECT_NE(random_1(), random_2()); +} + +TEST_F(TestRandGenHolder, WithInValidSeed) { + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1; + std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2; + FunctionNode rand_func_1 = BuildRandWithSeedFunc(12, true); + FunctionNode rand_func_2 = BuildRandWithSeedFunc(0, false); + auto status = RandomGeneratorHolder::Make(rand_func_1, &rand_gen_holder_1); + EXPECT_EQ(status.ok(), true) << status.message(); + status = RandomGeneratorHolder::Make(rand_func_2, &rand_gen_holder_2); + EXPECT_EQ(status.ok(), true) << status.message(); + + auto& random_1 = *rand_gen_holder_1; + auto& random_2 = *rand_gen_holder_2; + EXPECT_EQ(random_1(), random_2()); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/regex_util.cc b/src/arrow/cpp/src/gandiva/regex_util.cc new file mode 100644 index 000000000..abdd579d1 --- /dev/null +++ 
b/src/arrow/cpp/src/gandiva/regex_util.cc @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/regex_util.h" + +namespace gandiva { + +const std::set<char> RegexUtil::pcre_regex_specials_ = { + '[', ']', '(', ')', '|', '^', '-', '+', '*', '?', '{', '}', '$', '\\', '.'}; + +Status RegexUtil::SqlLikePatternToPcre(const std::string& sql_pattern, char escape_char, + std::string& pcre_pattern) { + /// Characters that are considered special by pcre regex. These needs to be + /// escaped with '\\'. + pcre_pattern.clear(); + for (size_t idx = 0; idx < sql_pattern.size(); ++idx) { + auto cur = sql_pattern.at(idx); + + // Escape any char that is special for pcre regex + if (pcre_regex_specials_.find(cur) != pcre_regex_specials_.end()) { + pcre_pattern += "\\"; + } + + if (cur == escape_char) { + // escape char must be followed by '_', '%' or the escape char itself. 
+ ++idx; + ARROW_RETURN_IF( + idx == sql_pattern.size(), + Status::Invalid("Unexpected escape char at the end of pattern ", sql_pattern)); + + cur = sql_pattern.at(idx); + if (cur == '_' || cur == '%' || cur == escape_char) { + pcre_pattern += cur; + } else { + return Status::Invalid("Invalid escape sequence in pattern ", sql_pattern, + " at offset ", idx); + } + } else if (cur == '_') { + pcre_pattern += '.'; + } else if (cur == '%') { + pcre_pattern += ".*"; + } else { + pcre_pattern += cur; + } + } + return Status::OK(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/regex_util.h b/src/arrow/cpp/src/gandiva/regex_util.h new file mode 100644 index 000000000..cf0002b8c --- /dev/null +++ b/src/arrow/cpp/src/gandiva/regex_util.h @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <set> +#include <sstream> +#include <string> + +#include "gandiva/arrow.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// \brief Utility class for converting sql patterns to pcre patterns. 
+class GANDIVA_EXPORT RegexUtil { + public: + // Convert an sql pattern to a pcre pattern + static Status SqlLikePatternToPcre(const std::string& like_pattern, char escape_char, + std::string& pcre_pattern); + + static Status SqlLikePatternToPcre(const std::string& like_pattern, + std::string& pcre_pattern) { + return SqlLikePatternToPcre(like_pattern, 0 /*escape_char*/, pcre_pattern); + } + + private: + static const std::set<char> pcre_regex_specials_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/replace_holder.cc b/src/arrow/cpp/src/gandiva/replace_holder.cc new file mode 100644 index 000000000..8b42b585f --- /dev/null +++ b/src/arrow/cpp/src/gandiva/replace_holder.cc @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/replace_holder.h" + +#include "gandiva/node.h" +#include "gandiva/regex_util.h" + +namespace gandiva { + +static bool IsArrowStringLiteral(arrow::Type::type type) { + return type == arrow::Type::STRING || type == arrow::Type::BINARY; +} + +Status ReplaceHolder::Make(const FunctionNode& node, + std::shared_ptr<ReplaceHolder>* holder) { + ARROW_RETURN_IF(node.children().size() != 3, + Status::Invalid("'replace' function requires three parameters")); + + auto literal = dynamic_cast<LiteralNode*>(node.children().at(1).get()); + ARROW_RETURN_IF( + literal == nullptr, + Status::Invalid("'replace' function requires a literal as the second parameter")); + + auto literal_type = literal->return_type()->id(); + ARROW_RETURN_IF( + !IsArrowStringLiteral(literal_type), + Status::Invalid( + "'replace' function requires a string literal as the second parameter")); + + return Make(arrow::util::get<std::string>(literal->holder()), holder); +} + +Status ReplaceHolder::Make(const std::string& sql_pattern, + std::shared_ptr<ReplaceHolder>* holder) { + auto lholder = std::shared_ptr<ReplaceHolder>(new ReplaceHolder(sql_pattern)); + ARROW_RETURN_IF(!lholder->regex_.ok(), + Status::Invalid("Building RE2 pattern '", sql_pattern, "' failed")); + + *holder = lholder; + return Status::OK(); +} + +void ReplaceHolder::return_error(ExecutionContext* context, std::string& data, + std::string& replace_string) { + std::string err_msg = "Error replacing '" + replace_string + "' on the given string '" + + data + "' for the given pattern: " + pattern_; + context->set_error_msg(err_msg.c_str()); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/replace_holder.h b/src/arrow/cpp/src/gandiva/replace_holder.h new file mode 100644 index 000000000..79150d7aa --- /dev/null +++ b/src/arrow/cpp/src/gandiva/replace_holder.h @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <re2/re2.h> + +#include <memory> +#include <string> + +#include "arrow/status.h" +#include "gandiva/execution_context.h" +#include "gandiva/function_holder.h" +#include "gandiva/node.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// Function Holder for 'replace' +class GANDIVA_EXPORT ReplaceHolder : public FunctionHolder { + public: + ~ReplaceHolder() override = default; + + static Status Make(const FunctionNode& node, std::shared_ptr<ReplaceHolder>* holder); + + static Status Make(const std::string& sql_pattern, + std::shared_ptr<ReplaceHolder>* holder); + + /// Return a new string with the pattern that matched the regex replaced for + /// the replace_input parameter. 
+ const char* operator()(ExecutionContext* ctx, const char* user_input, + int32_t user_input_len, const char* replace_input, + int32_t replace_input_len, int32_t* out_length) { + std::string user_input_as_str(user_input, user_input_len); + std::string replace_input_as_str(replace_input, replace_input_len); + + int32_t total_replaces = + RE2::GlobalReplace(&user_input_as_str, regex_, replace_input_as_str); + + if (total_replaces < 0) { + return_error(ctx, user_input_as_str, replace_input_as_str); + *out_length = 0; + return ""; + } + + if (total_replaces == 0) { + *out_length = user_input_len; + return user_input; + } + + *out_length = static_cast<int32_t>(user_input_as_str.size()); + + // This condition treats the case where the whole string is replaced by an empty + // string + if (*out_length == 0) { + return ""; + } + + char* result_buffer = reinterpret_cast<char*>(ctx->arena()->Allocate(*out_length)); + + if (result_buffer == NULLPTR) { + ctx->set_error_msg("Could not allocate memory for result"); + *out_length = 0; + return ""; + } + + memcpy(result_buffer, user_input_as_str.data(), *out_length); + + return result_buffer; + } + + private: + explicit ReplaceHolder(const std::string& pattern) + : pattern_(pattern), regex_(pattern) {} + + void return_error(ExecutionContext* context, std::string& data, + std::string& replace_string); + + std::string pattern_; // posix pattern string, to help debugging + RE2 regex_; // compiled regex for the pattern +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/replace_holder_test.cc b/src/arrow/cpp/src/gandiva/replace_holder_test.cc new file mode 100644 index 000000000..b0830d4f0 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/replace_holder_test.cc @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/replace_holder.h" + +#include <gtest/gtest.h> + +#include <memory> +#include <vector> + +namespace gandiva { + +class TestReplaceHolder : public ::testing::Test { + protected: + ExecutionContext execution_context_; +}; + +TEST_F(TestReplaceHolder, TestMultipleReplace) { + std::shared_ptr<ReplaceHolder> replace_holder; + + auto status = ReplaceHolder::Make("ana", &replace_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + std::string input_string = "banana"; + std::string replace_string; + int32_t out_length = 0; + + auto& replace = *replace_holder; + const char* ret = + replace(&execution_context_, input_string.c_str(), + static_cast<int32_t>(input_string.length()), replace_string.c_str(), + static_cast<int32_t>(replace_string.length()), &out_length); + std::string ret_as_str(ret, out_length); + EXPECT_EQ(out_length, 3); + EXPECT_EQ(ret_as_str, "bna"); + + input_string = "bananaana"; + + ret = replace(&execution_context_, input_string.c_str(), + static_cast<int32_t>(input_string.length()), replace_string.c_str(), + static_cast<int32_t>(replace_string.length()), &out_length); + ret_as_str = std::string(ret, out_length); + EXPECT_EQ(out_length, 3); + EXPECT_EQ(ret_as_str, "bna"); + + input_string = "bananana"; + + ret = replace(&execution_context_, input_string.c_str(), + static_cast<int32_t>(input_string.length()), replace_string.c_str(), + 
static_cast<int32_t>(replace_string.length()), &out_length); + ret_as_str = std::string(ret, out_length); + EXPECT_EQ(out_length, 2); + EXPECT_EQ(ret_as_str, "bn"); + + input_string = "anaana"; + + ret = replace(&execution_context_, input_string.c_str(), + static_cast<int32_t>(input_string.length()), replace_string.c_str(), + static_cast<int32_t>(replace_string.length()), &out_length); + ret_as_str = std::string(ret, out_length); + EXPECT_EQ(out_length, 0); + EXPECT_FALSE(execution_context_.has_error()); + EXPECT_EQ(ret_as_str, ""); +} + +TEST_F(TestReplaceHolder, TestNoMatchPattern) { + std::shared_ptr<ReplaceHolder> replace_holder; + + auto status = ReplaceHolder::Make("ana", &replace_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + std::string input_string = "apple"; + std::string replace_string; + int32_t out_length = 0; + + auto& replace = *replace_holder; + const char* ret = + replace(&execution_context_, input_string.c_str(), + static_cast<int32_t>(input_string.length()), replace_string.c_str(), + static_cast<int32_t>(replace_string.length()), &out_length); + std::string ret_as_string(ret, out_length); + EXPECT_EQ(out_length, 5); + EXPECT_EQ(ret_as_string, "apple"); +} + +TEST_F(TestReplaceHolder, TestReplaceSameSize) { + std::shared_ptr<ReplaceHolder> replace_holder; + + auto status = ReplaceHolder::Make("a", &replace_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + + std::string input_string = "ananindeua"; + std::string replace_string = "b"; + int32_t out_length = 0; + + auto& replace = *replace_holder; + const char* ret = + replace(&execution_context_, input_string.c_str(), + static_cast<int32_t>(input_string.length()), replace_string.c_str(), + static_cast<int32_t>(replace_string.length()), &out_length); + std::string ret_as_string(ret, out_length); + EXPECT_EQ(out_length, 10); + EXPECT_EQ(ret_as_string, "bnbnindeub"); +} + +TEST_F(TestReplaceHolder, TestReplaceInvalidPattern) { + std::shared_ptr<ReplaceHolder> 
replace_holder; + + auto status = ReplaceHolder::Make("+", &replace_holder); + EXPECT_EQ(status.ok(), false) << status.message(); + + execution_context_.Reset(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/selection_vector.cc b/src/arrow/cpp/src/gandiva/selection_vector.cc new file mode 100644 index 000000000..a30bba686 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/selection_vector.cc @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/selection_vector.h" + +#include <memory> +#include <sstream> +#include <utility> +#include <vector> + +#include "arrow/util/bit_util.h" +#include "arrow/util/endian.h" + +#include "gandiva/selection_vector_impl.h" + +namespace gandiva { + +constexpr SelectionVector::Mode SelectionVector::kAllModes[kNumModes]; + +Status SelectionVector::PopulateFromBitMap(const uint8_t* bitmap, int64_t bitmap_size, + int64_t max_bitmap_index) { + const uint64_t max_idx = static_cast<uint64_t>(max_bitmap_index); + ARROW_RETURN_IF(bitmap_size % 8, Status::Invalid("Bitmap size ", bitmap_size, + " must be aligned to 64-bit size")); + ARROW_RETURN_IF(max_bitmap_index < 0, + Status::Invalid("Max bitmap index must be positive")); + ARROW_RETURN_IF( + max_idx > GetMaxSupportedValue(), + Status::Invalid("max_bitmap_index ", max_idx, " must be <= maxSupportedValue ", + GetMaxSupportedValue(), " in selection vector")); + + int64_t max_slots = GetMaxSlots(); + + // jump 8-bytes at a time, add the index corresponding to each valid bit to the + // the selection vector. + int64_t selection_idx = 0; + const uint64_t* bitmap_64 = reinterpret_cast<const uint64_t*>(bitmap); + for (int64_t bitmap_idx = 0; bitmap_idx < bitmap_size / 8; ++bitmap_idx) { + uint64_t current_word = arrow::BitUtil::ToLittleEndian(bitmap_64[bitmap_idx]); + + while (current_word != 0) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4146) +#endif + // MSVC warns about negating an unsigned type. We suppress it for now + uint64_t highest_only = current_word & -current_word; + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + + int pos_in_word = arrow::BitUtil::CountTrailingZeros(highest_only); + + int64_t pos_in_bitmap = bitmap_idx * 64 + pos_in_word; + if (pos_in_bitmap > max_bitmap_index) { + // the bitmap may be slightly larger for alignment/padding. 
+ break; + } + + ARROW_RETURN_IF(selection_idx >= max_slots, + Status::Invalid("selection vector has no remaining slots")); + + SetIndex(selection_idx, pos_in_bitmap); + ++selection_idx; + + current_word ^= highest_only; + } + } + + SetNumSlots(selection_idx); + return Status::OK(); +}

// Wrap a caller-supplied buffer as a selection vector of uint16_t indices.
// Fails with Status::Invalid if 'buffer' does not pass ValidateBuffer().
Status SelectionVector::MakeInt16(int64_t max_slots,
                                  std::shared_ptr<arrow::Buffer> buffer,
                                  std::shared_ptr<SelectionVector>* selection_vector) {
  ARROW_RETURN_NOT_OK(SelectionVectorInt16::ValidateBuffer(max_slots, buffer));
  *selection_vector = std::make_shared<SelectionVectorInt16>(max_slots, buffer);
  return Status::OK();
}

// Allocate a buffer for 'max_slots' uint16_t indices from 'pool' and wrap it.
Status SelectionVector::MakeInt16(int64_t max_slots, arrow::MemoryPool* pool,
                                  std::shared_ptr<SelectionVector>* selection_vector) {
  std::shared_ptr<arrow::Buffer> buffer;
  ARROW_RETURN_NOT_OK(SelectionVectorInt16::AllocateBuffer(max_slots, pool, &buffer));
  *selection_vector = std::make_shared<SelectionVectorInt16>(max_slots, buffer);
  return Status::OK();
}

// Wrap an already-populated buffer: both the capacity and the size of the
// resulting vector are 'num_slots'. The buffer is not validated here.
Status SelectionVector::MakeImmutableInt16(
    int64_t num_slots, std::shared_ptr<arrow::Buffer> buffer,
    std::shared_ptr<SelectionVector>* selection_vector) {
  *selection_vector =
      std::make_shared<SelectionVectorInt16>(num_slots, num_slots, buffer);
  return Status::OK();
}

// Wrap a caller-supplied buffer as a selection vector of uint32_t indices.
// Fails with Status::Invalid if 'buffer' does not pass ValidateBuffer().
Status SelectionVector::MakeInt32(int64_t max_slots,
                                  std::shared_ptr<arrow::Buffer> buffer,
                                  std::shared_ptr<SelectionVector>* selection_vector) {
  ARROW_RETURN_NOT_OK(SelectionVectorInt32::ValidateBuffer(max_slots, buffer));
  *selection_vector = std::make_shared<SelectionVectorInt32>(max_slots, buffer);

  return Status::OK();
}

// Allocate a buffer for 'max_slots' uint32_t indices from 'pool' and wrap it.
Status SelectionVector::MakeInt32(int64_t max_slots, arrow::MemoryPool* pool,
                                  std::shared_ptr<SelectionVector>* selection_vector) {
  std::shared_ptr<arrow::Buffer> buffer;
  ARROW_RETURN_NOT_OK(SelectionVectorInt32::AllocateBuffer(max_slots, pool, &buffer));
  *selection_vector = std::make_shared<SelectionVectorInt32>(max_slots, buffer);

  return Status::OK();
}

// Wrap an already-populated buffer: both the capacity and the size of the
// resulting vector are 'num_slots'. The buffer is not validated here.
Status SelectionVector::MakeImmutableInt32(
    int64_t num_slots, std::shared_ptr<arrow::Buffer> buffer,
    std::shared_ptr<SelectionVector>* selection_vector) {
  *selection_vector =
      std::make_shared<SelectionVectorInt32>(num_slots, num_slots, buffer);
  return Status::OK();
}

// Wrap a caller-supplied buffer as a selection vector of uint64_t indices.
// Fails with Status::Invalid if 'buffer' does not pass ValidateBuffer().
Status SelectionVector::MakeInt64(int64_t max_slots,
                                  std::shared_ptr<arrow::Buffer> buffer,
                                  std::shared_ptr<SelectionVector>* selection_vector) {
  ARROW_RETURN_NOT_OK(SelectionVectorInt64::ValidateBuffer(max_slots, buffer));
  *selection_vector = std::make_shared<SelectionVectorInt64>(max_slots, buffer);

  return Status::OK();
}

// Allocate a buffer for 'max_slots' uint64_t indices from 'pool' and wrap it.
Status SelectionVector::MakeInt64(int64_t max_slots, arrow::MemoryPool* pool,
                                  std::shared_ptr<SelectionVector>* selection_vector) {
  std::shared_ptr<arrow::Buffer> buffer;
  ARROW_RETURN_NOT_OK(SelectionVectorInt64::AllocateBuffer(max_slots, pool, &buffer));
  *selection_vector = std::make_shared<SelectionVectorInt64>(max_slots, buffer);

  return Status::OK();
}

// Allocate 'max_slots * sizeof(C_TYPE)' bytes from 'pool' and store the result
// in '*buffer'. Any allocation error is returned to the caller.
template <typename C_TYPE, typename A_TYPE, SelectionVector::Mode mode>
Status SelectionVectorImpl<C_TYPE, A_TYPE, mode>::AllocateBuffer(
    int64_t max_slots, arrow::MemoryPool* pool, std::shared_ptr<arrow::Buffer>* buffer) {
  // Keep the arithmetic signed so the value handed to arrow::AllocateBuffer is
  // exactly what was computed (no round-trip through an unsigned type).
  const int64_t buffer_len = max_slots * static_cast<int64_t>(sizeof(C_TYPE));
  ARROW_ASSIGN_OR_RAISE(*buffer, arrow::AllocateBuffer(buffer_len, pool));

  return Status::OK();
}

// Check that 'buffer' can back a selection vector with 'max_slots' entries of
// C_TYPE: it must be non-null, mutable, and at least
// 'max_slots * sizeof(C_TYPE)' bytes long. Returns Status::Invalid otherwise.
template <typename C_TYPE, typename A_TYPE, SelectionVector::Mode mode>
Status SelectionVectorImpl<C_TYPE, A_TYPE, mode>::ValidateBuffer(
    int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer) {
  // A negative capacity would make the size check below trivially pass.
  ARROW_RETURN_IF(max_slots < 0,
                  Status::Invalid("max_slots for selection vector must not be negative"));

  // Reject a null buffer up front instead of dereferencing it below.
  ARROW_RETURN_IF(buffer == nullptr,
                  Status::Invalid("buffer for selection vector must not be null"));

  ARROW_RETURN_IF(!buffer->is_mutable(),
                  Status::Invalid("buffer for selection vector must be mutable"));

  const int64_t min_len = max_slots * static_cast<int64_t>(sizeof(C_TYPE));
  ARROW_RETURN_IF(buffer->size() < min_len,
                  Status::Invalid("Buffer for selection vector is too small"));

  return Status::OK();
}

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/selection_vector.h
b/src/arrow/cpp/src/gandiva/selection_vector.h new file mode 100644 index 000000000..1c0fef1c5 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/selection_vector.h @@ -0,0 +1,151 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <memory>

#include "arrow/status.h"

#include "arrow/util/logging.h"
#include "gandiva/arrow.h"
#include "gandiva/visibility.h"

namespace gandiva {

/// \brief Selection Vector : vector of indices in a row-batch for a selection,
/// backed by an arrow-array.
///
/// Concrete vectors are created through the static Make* factories below. The
/// "Int16"/"Int32"/"Int64" in the factory names refers to the width of each
/// stored index; the values themselves are stored unsigned.
class GANDIVA_EXPORT SelectionVector {
 public:
  virtual ~SelectionVector() = default;

  /// Width of the indices held by a selection vector.
  enum Mode : int {
    MODE_NONE,
    MODE_UINT16,
    MODE_UINT32,
    MODE_UINT64,
    MODE_MAX = MODE_UINT64,  // dummy
  };
  /// Number of distinct modes, including MODE_NONE.
  static constexpr int kNumModes = static_cast<int>(MODE_MAX) + 1;
  /// Every mode, in declaration order.
  static constexpr Mode kAllModes[kNumModes] = {MODE_NONE, MODE_UINT16, MODE_UINT32,
                                                MODE_UINT64};

  /// Get the value at a given index.
  virtual uint64_t GetIndex(int64_t index) const = 0;

  /// Set the value at a given index.
  virtual void SetIndex(int64_t index, uint64_t value) = 0;

  /// Get the max supported value in the selection vector.
  virtual uint64_t GetMaxSupportedValue() const = 0;

  /// The maximum slots (capacity) of the selection vector.
  virtual int64_t GetMaxSlots() const = 0;

  /// The number of slots (size) of the selection vector.
  virtual int64_t GetNumSlots() const = 0;

  /// Set the number of slots in the selection vector.
  virtual void SetNumSlots(int64_t num_slots) = 0;

  /// Convert to arrow-array.
  virtual ArrayPtr ToArray() const = 0;

  /// Get the underlying arrow buffer.
  virtual arrow::Buffer& GetBuffer() const = 0;

  /// Mode of SelectionVector
  virtual Mode GetMode() const = 0;

  /// \brief populate selection vector for all the set bits in the bitmap.
  ///
  /// On success the number of slots is set to the number of indices written.
  /// Returns Status::Invalid if the bitmap has more set bits than the vector
  /// has slots.
  ///
  /// \param[in] bitmap the bitmap
  /// \param[in] bitmap_size size of the bitmap in bytes
  /// \param[in] max_bitmap_index max valid index in bitmap (can be less than
  /// capacity in the bitmap, due to alignment/padding).
  Status PopulateFromBitMap(const uint8_t* bitmap, int64_t bitmap_size,
                            int64_t max_bitmap_index);

  /// \brief make selection vector with int16 type records.
  ///
  /// \param[in] max_slots max number of slots
  /// \param[in] buffer buffer sized to accommodate max_slots
  /// \param[out] selection_vector selection vector backed by 'buffer'
  static Status MakeInt16(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer,
                          std::shared_ptr<SelectionVector>* selection_vector);

  /// \brief make selection vector with int16 type records.
  ///
  /// \param[in] max_slots max number of slots
  /// \param[in] pool memory pool to allocate buffer
  /// \param[out] selection_vector selection vector backed by a buffer allocated from the
  /// pool.
  static Status MakeInt16(int64_t max_slots, arrow::MemoryPool* pool,
                          std::shared_ptr<SelectionVector>* selection_vector);

  /// \brief creates a selection vector with pre populated buffer.
  ///
  /// Unlike MakeInt16, the buffer is not validated; both the capacity and the
  /// size of the resulting vector are num_slots.
  ///
  /// \param[in] num_slots size of the selection vector
  /// \param[in] buffer pre-populated buffer
  /// \param[out] selection_vector selection vector backed by 'buffer'
  static Status MakeImmutableInt16(int64_t num_slots,
                                   std::shared_ptr<arrow::Buffer> buffer,
                                   std::shared_ptr<SelectionVector>* selection_vector);

  /// \brief make selection vector with int32 type records.
  ///
  /// \param[in] max_slots max number of slots
  /// \param[in] buffer buffer sized to accommodate max_slots
  /// \param[out] selection_vector selection vector backed by 'buffer'
  static Status MakeInt32(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer,
                          std::shared_ptr<SelectionVector>* selection_vector);

  /// \brief make selection vector with int32 type records.
  ///
  /// \param[in] max_slots max number of slots
  /// \param[in] pool memory pool to allocate buffer
  /// \param[out] selection_vector selection vector backed by a buffer allocated from the
  /// pool.
  static Status MakeInt32(int64_t max_slots, arrow::MemoryPool* pool,
                          std::shared_ptr<SelectionVector>* selection_vector);

  /// \brief creates a selection vector with pre populated buffer.
  ///
  /// Unlike MakeInt32, the buffer is not validated; both the capacity and the
  /// size of the resulting vector are num_slots.
  ///
  /// \param[in] num_slots size of the selection vector
  /// \param[in] buffer pre-populated buffer
  /// \param[out] selection_vector selection vector backed by 'buffer'
  static Status MakeImmutableInt32(int64_t num_slots,
                                   std::shared_ptr<arrow::Buffer> buffer,
                                   std::shared_ptr<SelectionVector>* selection_vector);

  /// \brief make selection vector with int64 type records.
  ///
  /// \param[in] max_slots max number of slots
  /// \param[in] buffer buffer sized to accommodate max_slots
  /// \param[out] selection_vector selection vector backed by 'buffer'
  static Status MakeInt64(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer,
                          std::shared_ptr<SelectionVector>* selection_vector);

  /// \brief make selection vector with int64 type records.
  ///
  /// \param[in] max_slots max number of slots
  /// \param[in] pool memory pool to allocate buffer
  /// \param[out] selection_vector selection vector backed by a buffer allocated from the
  /// pool.
  static Status MakeInt64(int64_t max_slots, arrow::MemoryPool* pool,
                          std::shared_ptr<SelectionVector>* selection_vector);
};

}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/selection_vector_impl.h b/src/arrow/cpp/src/gandiva/selection_vector_impl.h new file mode 100644 index 000000000..dc9724ca8 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/selection_vector_impl.h @@ -0,0 +1,108 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <limits>
#include <memory>

#include "arrow/status.h"
#include "arrow/util/macros.h"

#include "arrow/util/logging.h"
#include "gandiva/arrow.h"
#include "gandiva/selection_vector.h"

namespace gandiva {

/// \brief template implementation of selection vector with a specific ctype and arrow
/// type.
+template <typename C_TYPE, typename A_TYPE, SelectionVector::Mode mode> +class SelectionVectorImpl : public SelectionVector { + public: + SelectionVectorImpl(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer) + : max_slots_(max_slots), num_slots_(0), buffer_(buffer), mode_(mode) { + raw_data_ = reinterpret_cast<C_TYPE*>(buffer->mutable_data()); + } + + SelectionVectorImpl(int64_t max_slots, int64_t num_slots, + std::shared_ptr<arrow::Buffer> buffer) + : max_slots_(max_slots), num_slots_(num_slots), buffer_(buffer), mode_(mode) { + if (buffer) { + raw_data_ = const_cast<C_TYPE*>(reinterpret_cast<const C_TYPE*>(buffer->data())); + } + } + + uint64_t GetIndex(int64_t index) const override { return raw_data_[index]; } + + void SetIndex(int64_t index, uint64_t value) override { + raw_data_[index] = static_cast<C_TYPE>(value); + } + + ArrayPtr ToArray() const override; + + int64_t GetMaxSlots() const override { return max_slots_; } + + int64_t GetNumSlots() const override { return num_slots_; } + + void SetNumSlots(int64_t num_slots) override { + DCHECK_LE(num_slots, max_slots_); + num_slots_ = num_slots; + } + + uint64_t GetMaxSupportedValue() const override { + return std::numeric_limits<C_TYPE>::max(); + } + + Mode GetMode() const override { return mode_; } + + arrow::Buffer& GetBuffer() const override { return *buffer_; } + + static Status AllocateBuffer(int64_t max_slots, arrow::MemoryPool* pool, + std::shared_ptr<arrow::Buffer>* buffer); + + static Status ValidateBuffer(int64_t max_slots, std::shared_ptr<arrow::Buffer> buffer); + + protected: + /// maximum slots in the vector + int64_t max_slots_; + + /// number of slots in the vector + int64_t num_slots_; + + std::shared_ptr<arrow::Buffer> buffer_; + C_TYPE* raw_data_; + + /// SelectionVector mode + Mode mode_; +}; + +template <typename C_TYPE, typename A_TYPE, SelectionVector::Mode mode> +ArrayPtr SelectionVectorImpl<C_TYPE, A_TYPE, mode>::ToArray() const { + auto data_type = 
arrow::TypeTraits<A_TYPE>::type_singleton(); + auto array_data = arrow::ArrayData::Make(data_type, num_slots_, {NULLPTR, buffer_}); + return arrow::MakeArray(array_data); +} + +using SelectionVectorInt16 = + SelectionVectorImpl<uint16_t, arrow::UInt16Type, SelectionVector::MODE_UINT16>; +using SelectionVectorInt32 = + SelectionVectorImpl<uint32_t, arrow::UInt32Type, SelectionVector::MODE_UINT32>; +using SelectionVectorInt64 = + SelectionVectorImpl<uint64_t, arrow::UInt64Type, SelectionVector::MODE_UINT64>; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/selection_vector_test.cc b/src/arrow/cpp/src/gandiva/selection_vector_test.cc new file mode 100644 index 000000000..686892901 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/selection_vector_test.cc @@ -0,0 +1,270 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/selection_vector.h" + +#include <memory> +#include <utility> +#include <vector> + +#include <gtest/gtest.h> + +#include "arrow/testing/gtest_util.h" + +namespace gandiva { + +class TestSelectionVector : public ::testing::Test { + protected: + virtual void SetUp() { pool_ = arrow::default_memory_pool(); } + + arrow::MemoryPool* pool_; +}; + +static inline uint32_t RoundUpNumi64(uint32_t value) { return (value + 63) >> 6; } + +TEST_F(TestSelectionVector, TestInt16Make) { + int max_slots = 10; + + // Test with pool allocation + std::shared_ptr<SelectionVector> selection; + auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + EXPECT_EQ(selection->GetMaxSlots(), max_slots); + EXPECT_EQ(selection->GetNumSlots(), 0); + + // Test with pre-alloced buffer + std::shared_ptr<SelectionVector> selection2; + auto buffer_len = max_slots * sizeof(int16_t); + ASSERT_OK_AND_ASSIGN(auto buffer, arrow::AllocateBuffer(buffer_len, pool_)); + + status = SelectionVector::MakeInt16(max_slots, std::move(buffer), &selection2); + EXPECT_EQ(status.ok(), true) << status.message(); + EXPECT_EQ(selection2->GetMaxSlots(), max_slots); + EXPECT_EQ(selection2->GetNumSlots(), 0); +} + +TEST_F(TestSelectionVector, TestInt16MakeNegative) { + int max_slots = 10; + + std::shared_ptr<SelectionVector> selection; + auto buffer_len = max_slots * sizeof(int16_t); + + // alloc a buffer that's insufficient. 
+ ASSERT_OK_AND_ASSIGN(auto buffer, arrow::AllocateBuffer(buffer_len - 16, pool_)); + + auto status = SelectionVector::MakeInt16(max_slots, std::move(buffer), &selection); + EXPECT_EQ(status.IsInvalid(), true); +} + +TEST_F(TestSelectionVector, TestInt16Set) { + int max_slots = 10; + + std::shared_ptr<SelectionVector> selection; + auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + selection->SetIndex(0, 100); + EXPECT_EQ(selection->GetIndex(0), 100); + + selection->SetIndex(1, 200); + EXPECT_EQ(selection->GetIndex(1), 200); + + selection->SetNumSlots(2); + EXPECT_EQ(selection->GetNumSlots(), 2); + + // TopArray() should return an array with 100,200 + auto array_raw = selection->ToArray(); + const auto& array = dynamic_cast<const arrow::UInt16Array&>(*array_raw); + EXPECT_EQ(array.length(), 2) << array_raw->ToString(); + EXPECT_EQ(array.Value(0), 100) << array_raw->ToString(); + EXPECT_EQ(array.Value(1), 200) << array_raw->ToString(); +} + +TEST_F(TestSelectionVector, TestInt16PopulateFromBitMap) { + int max_slots = 200; + + std::shared_ptr<SelectionVector> selection; + auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + int bitmap_size = RoundUpNumi64(max_slots) * 8; + std::vector<uint8_t> bitmap(bitmap_size); + + arrow::BitUtil::SetBit(&bitmap[0], 0); + arrow::BitUtil::SetBit(&bitmap[0], 5); + arrow::BitUtil::SetBit(&bitmap[0], 121); + arrow::BitUtil::SetBit(&bitmap[0], 220); + + status = selection->PopulateFromBitMap(&bitmap[0], bitmap_size, max_slots - 1); + EXPECT_EQ(status.ok(), true) << status.message(); + + EXPECT_EQ(selection->GetNumSlots(), 3); + EXPECT_EQ(selection->GetIndex(0), 0); + EXPECT_EQ(selection->GetIndex(1), 5); + EXPECT_EQ(selection->GetIndex(2), 121); +} + +TEST_F(TestSelectionVector, TestInt16PopulateFromBitMapNegative) { + int max_slots = 2; + + std::shared_ptr<SelectionVector> selection; + 
auto status = SelectionVector::MakeInt16(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + int bitmap_size = 16; + std::vector<uint8_t> bitmap(bitmap_size); + + arrow::BitUtil::SetBit(&bitmap[0], 0); + arrow::BitUtil::SetBit(&bitmap[0], 1); + arrow::BitUtil::SetBit(&bitmap[0], 2); + + // The bitmap has three set bits, whereas the selection vector has capacity for only 2. + status = selection->PopulateFromBitMap(&bitmap[0], bitmap_size, 2); + EXPECT_EQ(status.IsInvalid(), true); +} + +TEST_F(TestSelectionVector, TestInt32Set) { + int max_slots = 10; + + std::shared_ptr<SelectionVector> selection; + auto status = SelectionVector::MakeInt32(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + selection->SetIndex(0, 100); + EXPECT_EQ(selection->GetIndex(0), 100); + + selection->SetIndex(1, 200); + EXPECT_EQ(selection->GetIndex(1), 200); + + selection->SetIndex(2, 100000); + EXPECT_EQ(selection->GetIndex(2), 100000); + + selection->SetNumSlots(3); + EXPECT_EQ(selection->GetNumSlots(), 3); + + // TopArray() should return an array with 100,200,100000 + auto array_raw = selection->ToArray(); + const auto& array = dynamic_cast<const arrow::UInt32Array&>(*array_raw); + EXPECT_EQ(array.length(), 3) << array_raw->ToString(); + EXPECT_EQ(array.Value(0), 100) << array_raw->ToString(); + EXPECT_EQ(array.Value(1), 200) << array_raw->ToString(); + EXPECT_EQ(array.Value(2), 100000) << array_raw->ToString(); +} + +TEST_F(TestSelectionVector, TestInt32PopulateFromBitMap) { + int max_slots = 200; + + std::shared_ptr<SelectionVector> selection; + auto status = SelectionVector::MakeInt32(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + int bitmap_size = RoundUpNumi64(max_slots) * 8; + std::vector<uint8_t> bitmap(bitmap_size); + + arrow::BitUtil::SetBit(&bitmap[0], 0); + arrow::BitUtil::SetBit(&bitmap[0], 5); + arrow::BitUtil::SetBit(&bitmap[0], 121); + 
arrow::BitUtil::SetBit(&bitmap[0], 220); + + status = selection->PopulateFromBitMap(&bitmap[0], bitmap_size, max_slots - 1); + EXPECT_EQ(status.ok(), true) << status.message(); + + EXPECT_EQ(selection->GetNumSlots(), 3); + EXPECT_EQ(selection->GetIndex(0), 0); + EXPECT_EQ(selection->GetIndex(1), 5); + EXPECT_EQ(selection->GetIndex(2), 121); +} + +TEST_F(TestSelectionVector, TestInt32MakeNegative) { + int max_slots = 10; + + std::shared_ptr<SelectionVector> selection; + auto buffer_len = max_slots * sizeof(int32_t); + + // alloc a buffer that's insufficient. + ASSERT_OK_AND_ASSIGN(auto buffer, arrow::AllocateBuffer(buffer_len - 1, pool_)); + + auto status = SelectionVector::MakeInt32(max_slots, std::move(buffer), &selection); + EXPECT_EQ(status.IsInvalid(), true); +} + +TEST_F(TestSelectionVector, TestInt64Set) { + int max_slots = 10; + + std::shared_ptr<SelectionVector> selection; + auto status = SelectionVector::MakeInt64(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); + + selection->SetIndex(0, 100); + EXPECT_EQ(selection->GetIndex(0), 100); + + selection->SetIndex(1, 200); + EXPECT_EQ(selection->GetIndex(1), 200); + + selection->SetIndex(2, 100000); + EXPECT_EQ(selection->GetIndex(2), 100000); + + selection->SetNumSlots(3); + EXPECT_EQ(selection->GetNumSlots(), 3); + + // TopArray() should return an array with 100,200,100000 + auto array_raw = selection->ToArray(); + const auto& array = dynamic_cast<const arrow::UInt64Array&>(*array_raw); + EXPECT_EQ(array.length(), 3) << array_raw->ToString(); + EXPECT_EQ(array.Value(0), 100) << array_raw->ToString(); + EXPECT_EQ(array.Value(1), 200) << array_raw->ToString(); + EXPECT_EQ(array.Value(2), 100000) << array_raw->ToString(); +} + +TEST_F(TestSelectionVector, TestInt64PopulateFromBitMap) { + int max_slots = 200; + + std::shared_ptr<SelectionVector> selection; + auto status = SelectionVector::MakeInt64(max_slots, pool_, &selection); + EXPECT_EQ(status.ok(), true) << status.message(); 
+ + int bitmap_size = RoundUpNumi64(max_slots) * 8; + std::vector<uint8_t> bitmap(bitmap_size); + + arrow::BitUtil::SetBit(&bitmap[0], 0); + arrow::BitUtil::SetBit(&bitmap[0], 5); + arrow::BitUtil::SetBit(&bitmap[0], 121); + arrow::BitUtil::SetBit(&bitmap[0], 220); + + status = selection->PopulateFromBitMap(&bitmap[0], bitmap_size, max_slots - 1); + EXPECT_EQ(status.ok(), true) << status.message(); + + EXPECT_EQ(selection->GetNumSlots(), 3); + EXPECT_EQ(selection->GetIndex(0), 0); + EXPECT_EQ(selection->GetIndex(1), 5); + EXPECT_EQ(selection->GetIndex(2), 121); +} + +TEST_F(TestSelectionVector, TestInt64MakeNegative) { + int max_slots = 10; + + std::shared_ptr<SelectionVector> selection; + auto buffer_len = max_slots * sizeof(int64_t); + + // alloc a buffer that's insufficient. + ASSERT_OK_AND_ASSIGN(auto buffer, arrow::AllocateBuffer(buffer_len - 1, pool_)); + + auto status = SelectionVector::MakeInt64(max_slots, std::move(buffer), &selection); + EXPECT_EQ(status.IsInvalid(), true); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/simple_arena.h b/src/arrow/cpp/src/gandiva/simple_arena.h new file mode 100644 index 000000000..da00b3397 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/simple_arena.h @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <algorithm> +#include <memory> +#include <utility> +#include <vector> + +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "gandiva/arrow.h" + +namespace gandiva { + +/// \brief Simple arena allocator. +/// +/// Memory is allocated from system in units of chunk-size, and dished out in the +/// requested sizes. If the requested size > chunk-size, allocate directly from the +/// system. +/// +/// The allocated memory gets released only when the arena is destroyed, or on +/// Reset. +/// +/// This code is not multi-thread safe, and avoids all locking for efficiency. +/// +class SimpleArena { + public: + explicit SimpleArena(arrow::MemoryPool* pool, int64_t min_chunk_size = 4096); + + ~SimpleArena(); + + // Allocate buffer of requested size. + uint8_t* Allocate(int64_t size); + + // Reset arena state. + void Reset(); + + // total bytes allocated from system. + int64_t total_bytes() { return total_bytes_; } + + // total bytes available for allocations. + int64_t avail_bytes() { return avail_bytes_; } + + private: + struct Chunk { + Chunk(uint8_t* buf, int64_t size) : buf_(buf), size_(size) {} + + uint8_t* buf_; + int64_t size_; + }; + + // Allocate new chunk. + arrow::Status AllocateChunk(int64_t size); + + // release memory from buffers. + void ReleaseChunks(bool retain_first); + + // Memory pool used for allocs. + arrow::MemoryPool* pool_; + + // The chunk-size used for allocations from system. + int64_t min_chunk_size_; + + // Total bytes allocated from system. + int64_t total_bytes_; + + // Bytes available from allocated chunk. + int64_t avail_bytes_; + + // buffer from current chunk. + uint8_t* avail_buf_; + + // List of allocated chunks. 
+ std::vector<Chunk> chunks_; +}; + +inline SimpleArena::SimpleArena(arrow::MemoryPool* pool, int64_t min_chunk_size) + : pool_(pool), + min_chunk_size_(min_chunk_size), + total_bytes_(0), + avail_bytes_(0), + avail_buf_(NULL) {} + +inline SimpleArena::~SimpleArena() { ReleaseChunks(false /*retain_first*/); } + +inline uint8_t* SimpleArena::Allocate(int64_t size) { + if (avail_bytes_ < size) { + auto status = AllocateChunk(std::max(size, min_chunk_size_)); + if (!status.ok()) { + return NULL; + } + } + + uint8_t* ret = avail_buf_; + avail_buf_ += size; + avail_bytes_ -= size; + return ret; +} + +inline arrow::Status SimpleArena::AllocateChunk(int64_t size) { + uint8_t* out; + + auto status = pool_->Allocate(size, &out); + ARROW_RETURN_NOT_OK(status); + + chunks_.emplace_back(out, size); + avail_buf_ = out; + avail_bytes_ = size; // left-over bytes in the previous chunk cannot be used anymore. + total_bytes_ += size; + return arrow::Status::OK(); +} + +// In the most common case, a chunk will be allocated when processing the first record. +// And, the same chunk can be used for processing the remaining records in the batch. +// By retaining the first chunk, the number of malloc calls are reduced to one per batch, +// instead of one per record. +inline void SimpleArena::Reset() { + if (chunks_.size() == 0) { + // if there are no chunks, nothing to do. + return; + } + + // Release all but the first chunk. + if (chunks_.size() > 1) { + ReleaseChunks(true); + chunks_.erase(chunks_.begin() + 1, chunks_.end()); + } + + avail_buf_ = chunks_.at(0).buf_; + avail_bytes_ = total_bytes_ = chunks_.at(0).size_; +} + +inline void SimpleArena::ReleaseChunks(bool retain_first) { + for (auto& chunk : chunks_) { + if (retain_first) { + // skip freeing first chunk. 
+ retain_first = false; + continue; + } + pool_->Free(chunk.buf_, chunk.size_); + } +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/simple_arena_test.cc b/src/arrow/cpp/src/gandiva/simple_arena_test.cc new file mode 100644 index 000000000..60831280c --- /dev/null +++ b/src/arrow/cpp/src/gandiva/simple_arena_test.cc @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/simple_arena.h" + +#include <gtest/gtest.h> + +#include "arrow/memory_pool.h" + +namespace gandiva { + +class TestSimpleArena : public ::testing::Test {}; + +TEST_F(TestSimpleArena, TestAlloc) { + int64_t chunk_size = 4096; + SimpleArena arena(arrow::default_memory_pool(), chunk_size); + + // Small allocations should come from the same chunk. 
+ int64_t small_size = 100; + for (int64_t i = 0; i < 20; ++i) { + auto p = arena.Allocate(small_size); + EXPECT_NE(p, nullptr); + + EXPECT_EQ(arena.total_bytes(), chunk_size); + EXPECT_EQ(arena.avail_bytes(), chunk_size - (i + 1) * small_size); + } + + // large allocations require separate chunks + int64_t large_size = 100 * chunk_size; + auto p = arena.Allocate(large_size); + EXPECT_NE(p, nullptr); + EXPECT_EQ(arena.total_bytes(), chunk_size + large_size); + EXPECT_EQ(arena.avail_bytes(), 0); +} + +// small followed by big, then reset +TEST_F(TestSimpleArena, TestReset1) { + int64_t chunk_size = 4096; + SimpleArena arena(arrow::default_memory_pool(), chunk_size); + + int64_t small_size = 100; + auto p = arena.Allocate(small_size); + EXPECT_NE(p, nullptr); + + int64_t large_size = 100 * chunk_size; + p = arena.Allocate(large_size); + EXPECT_NE(p, nullptr); + + EXPECT_EQ(arena.total_bytes(), chunk_size + large_size); + EXPECT_EQ(arena.avail_bytes(), 0); + arena.Reset(); + EXPECT_EQ(arena.total_bytes(), chunk_size); + EXPECT_EQ(arena.avail_bytes(), chunk_size); + + // should re-use buffer after reset. + p = arena.Allocate(small_size); + EXPECT_NE(p, nullptr); + EXPECT_EQ(arena.total_bytes(), chunk_size); + EXPECT_EQ(arena.avail_bytes(), chunk_size - small_size); +} + +// big followed by small, then reset +TEST_F(TestSimpleArena, TestReset2) { + int64_t chunk_size = 4096; + SimpleArena arena(arrow::default_memory_pool(), chunk_size); + + int64_t large_size = 100 * chunk_size; + auto p = arena.Allocate(large_size); + EXPECT_NE(p, nullptr); + + int64_t small_size = 100; + p = arena.Allocate(small_size); + EXPECT_NE(p, nullptr); + + EXPECT_EQ(arena.total_bytes(), chunk_size + large_size); + EXPECT_EQ(arena.avail_bytes(), chunk_size - small_size); + arena.Reset(); + EXPECT_EQ(arena.total_bytes(), large_size); + EXPECT_EQ(arena.avail_bytes(), large_size); + + // should re-use buffer after reset. 
+ p = arena.Allocate(small_size); + EXPECT_NE(p, nullptr); + EXPECT_EQ(arena.total_bytes(), large_size); + EXPECT_EQ(arena.avail_bytes(), large_size - small_size); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/symbols.map b/src/arrow/cpp/src/gandiva/symbols.map new file mode 100644 index 000000000..77f000106 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/symbols.map @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +{ + # Symbols marked as 'local' are not exported by the DSO and thus may not + # be used by client applications. + local: + # devtoolset / static-libstdc++ symbols + __cxa_*; + __once_proxy; + + extern "C++" { + # devtoolset or -static-libstdc++ - the Red Hat devtoolset statically + # links c++11 symbols into binaries so that the result may be executed on + # a system with an older libstdc++ which doesn't include the necessary + # c++11 symbols. 
+ std::*; + *std::__once_call*; + }; +}; + diff --git a/src/arrow/cpp/src/gandiva/tests/CMakeLists.txt b/src/arrow/cpp/src/gandiva/tests/CMakeLists.txt new file mode 100644 index 000000000..5fa2da16c --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/CMakeLists.txt @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+# Register the Gandiva test targets; order matches the original explicit list.
+foreach(gandiva_test_name
+        filter_test
+        projector_test
+        projector_build_validation_test
+        if_expr_test
+        literal_test
+        boolean_expr_test
+        binary_test
+        date_time_test
+        to_string_test
+        utf8_test
+        hash_test
+        in_expr_test
+        null_validity_test
+        decimal_test
+        decimal_single_test
+        filter_project_test)
+  add_gandiva_test(${gandiva_test_name})
+endforeach()
+
+# These two targets link against gandiva_static / use static linking, so they
+# are only registered when static libraries are being built.
+if(ARROW_BUILD_STATIC)
+  add_gandiva_test(projector_test_static SOURCES projector_test.cc USE_STATIC_LINKING)
+  add_arrow_benchmark(micro_benchmarks
+                      PREFIX
+                      "gandiva"
+                      EXTRA_LINK_LIBS
+                      gandiva_static)
+endif()
diff --git a/src/arrow/cpp/src/gandiva/tests/binary_test.cc b/src/arrow/cpp/src/gandiva/tests/binary_test.cc
new file mode 100644
index 000000000..591c5befc
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/binary_test.cc
@@ -0,0 +1,136 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#include <gtest/gtest.h> + +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "gandiva/node.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::binary; +using arrow::boolean; +using arrow::int32; + +class TestBinary : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestBinary, TestSimple) { + // schema for input fields + auto field_a = field("a", binary()); + auto field_b = field("b", binary()); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("res", int32()); + + // build expressions. + // a > b ? octet_length(a) : octet_length(b) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_b = TreeExprBuilder::MakeField(field_b); + auto octet_len_a = TreeExprBuilder::MakeFunction("octet_length", {node_a}, int32()); + auto octet_len_b = TreeExprBuilder::MakeFunction("octet_length", {node_b}, int32()); + + auto is_greater = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto if_greater = + TreeExprBuilder::MakeIf(is_greater, octet_len_a, octet_len_b, int32()); + auto expr = TreeExprBuilder::MakeExpression(if_greater, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayBinary({"foo", "hello", "hi", "bye"}, {true, true, true, false}); + auto array_b = + MakeArrowArrayBinary({"fo", "hellos", "hi", "bye"}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({3, 6, 2, 3}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBinary, TestIfElse) { + // schema for input fields + auto field0 = field("f0", arrow::binary()); + auto field1 = field("f1", arrow::binary()); + + auto schema = arrow::schema({field0, field1}); + + auto f0 = TreeExprBuilder::MakeField(field0); + auto f1 = TreeExprBuilder::MakeField(field1); + + // output fields + auto field_result = field("out", arrow::binary()); + + // Build expression + auto cond = TreeExprBuilder::MakeFunction("isnotnull", {f0}, arrow::boolean()); + auto ifexpr = TreeExprBuilder::MakeIf(cond, f0, f1, arrow::binary()); + auto expr = TreeExprBuilder::MakeExpression(ifexpr, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_f0 = + MakeArrowArrayBinary({"foo", "hello", "hi", "bye"}, {true, true, true, false}); + auto array_f1 = + MakeArrowArrayBinary({"fe", "fi", "fo", "fum"}, {true, true, true, true}); + + // expected output + auto exp = + MakeArrowArrayBinary({"foo", "hello", "hi", "fum"}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_f0, array_f1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/boolean_expr_test.cc b/src/arrow/cpp/src/gandiva/tests/boolean_expr_test.cc new file mode 100644 index 000000000..9226f3571 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/boolean_expr_test.cc @@ -0,0 +1,388 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "arrow/status.h" + +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::int32; + +class TestBooleanExpr : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestBooleanExpr, SimpleAnd) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. + // (a > 0) && (b > 0) + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0); + auto a_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean()); + auto b_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean()); + + auto node_and = TreeExprBuilder::MakeAnd({a_gt_0, b_gt_0}); + auto expr = TreeExprBuilder::MakeExpression(node_and, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // FALSE_VALID && ? 
=> FALSE_VALID + int num_records = 4; + auto arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {true, true, true, true}); + auto arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + auto exp = MakeArrowArrayBool({false, false, false, false}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // FALSE_INVALID && ? + num_records = 4; + arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {false, false, false, false}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, false, false}, {true, false, false, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // TRUE_VALID && ? + num_records = 4; + arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {true, true, true, true}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, true, false}, {true, false, true, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // TRUE_INVALID && ? 
+ num_records = 4; + arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {false, false, false, false}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, false, false}, {true, false, false, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, SimpleOr) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. + // (a > 0) || (b > 0) + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0); + auto a_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean()); + auto b_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean()); + + auto node_or = TreeExprBuilder::MakeOr({a_gt_0, b_gt_0}); + auto expr = TreeExprBuilder::MakeExpression(node_or, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // TRUE_VALID && ? 
=> TRUE_VALID + int num_records = 4; + auto arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {true, true, true, true}); + auto arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + auto exp = MakeArrowArrayBool({true, true, true, true}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // TRUE_INVALID && ? + num_records = 4; + arraya = MakeArrowArrayInt32({2, 2, 2, 2}, {false, false, false, false}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, true, false}, {false, false, true, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // FALSE_VALID && ? + num_records = 4; + arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {true, true, true, true}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, true, false}, {true, false, true, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); + + // FALSE_INVALID && ? 
+ num_records = 4; + arraya = MakeArrowArrayInt32({-2, -2, -2, -2}, {false, false, false, false}); + arrayb = MakeArrowArrayInt32({-2, -2, 2, 2}, {true, false, true, false}); + exp = MakeArrowArrayBool({false, false, true, false}, {false, false, true, false}); + in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + outputs.clear(); + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, AndThree) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", int32()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. + // (a > 0) && (b > 0) && (c > 0) + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0); + auto a_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean()); + auto b_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean()); + auto c_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_c, literal_0}, boolean()); + + auto node_and = TreeExprBuilder::MakeAnd({a_gt_0, b_gt_0, c_gt_0}); + auto expr = TreeExprBuilder::MakeExpression(node_and, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + int num_records = 8; + std::vector<bool> validity({true, true, true, true, true, true, true, true}); + auto arraya = MakeArrowArrayInt32({2, 2, 2, 0, 2, 0, 0, 0}, validity); + auto arrayb = MakeArrowArrayInt32({2, 2, 0, 2, 0, 2, 0, 0}, validity); + auto arrayc = MakeArrowArrayInt32({2, 0, 2, 2, 0, 0, 2, 0}, validity); + auto exp = MakeArrowArrayBool({true, false, false, false, false, false, false, false}, + validity); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb, arrayc}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, OrThree) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", int32()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. + // (a > 0) || (b > 0) || (c > 0) + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto literal_0 = TreeExprBuilder::MakeLiteral((int32_t)0); + auto a_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_0}, boolean()); + auto b_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_0}, boolean()); + auto c_gt_0 = + TreeExprBuilder::MakeFunction("greater_than", {node_c, literal_0}, boolean()); + + auto node_or = TreeExprBuilder::MakeOr({a_gt_0, b_gt_0, c_gt_0}); + auto expr = TreeExprBuilder::MakeExpression(node_or, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + int num_records = 8; + std::vector<bool> validity({true, true, true, true, true, true, true, true}); + auto arraya = MakeArrowArrayInt32({2, 2, 2, 0, 2, 0, 0, 0}, validity); + auto arrayb = MakeArrowArrayInt32({2, 2, 0, 2, 0, 2, 0, 0}, validity); + auto arrayc = MakeArrowArrayInt32({2, 0, 2, 2, 0, 0, 2, 0}, validity); + auto exp = + MakeArrowArrayBool({true, true, true, true, true, true, true, false}, validity); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb, arrayc}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, BooleanAndInsideIf) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. 
+ // if (a > 2 && b > 2) + // a > 3 && b > 3 + // else + // a > 1 && b > 1 + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_1 = TreeExprBuilder::MakeLiteral((int32_t)1); + auto literal_2 = TreeExprBuilder::MakeLiteral((int32_t)2); + auto literal_3 = TreeExprBuilder::MakeLiteral((int32_t)3); + auto a_gt_1 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_1}, boolean()); + auto a_gt_2 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_2}, boolean()); + auto a_gt_3 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_3}, boolean()); + auto b_gt_1 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_1}, boolean()); + auto b_gt_2 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_2}, boolean()); + auto b_gt_3 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_3}, boolean()); + + auto and_1 = TreeExprBuilder::MakeAnd({a_gt_1, b_gt_1}); + auto and_2 = TreeExprBuilder::MakeAnd({a_gt_2, b_gt_2}); + auto and_3 = TreeExprBuilder::MakeAnd({a_gt_3, b_gt_3}); + + auto node_if = TreeExprBuilder::MakeIf(and_2, and_3, and_1, arrow::boolean()); + auto expr = TreeExprBuilder::MakeExpression(node_if, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + int num_records = 4; + std::vector<bool> validity({true, true, true, true}); + auto arraya = MakeArrowArrayInt32({4, 4, 2, 1}, validity); + auto arrayb = MakeArrowArrayInt32({5, 3, 3, 1}, validity); + auto exp = MakeArrowArrayBool({true, false, true, false}, validity); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestBooleanExpr, IfInsideBooleanAnd) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. + // (if (a > b) a > 3 else b > 3) && (if (a > b) a > 2 else b > 2) + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_2 = TreeExprBuilder::MakeLiteral((int32_t)2); + auto literal_3 = TreeExprBuilder::MakeLiteral((int32_t)3); + auto a_gt_b = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto a_gt_2 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_2}, boolean()); + auto a_gt_3 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_3}, boolean()); + auto b_gt_2 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_2}, boolean()); + auto b_gt_3 = + TreeExprBuilder::MakeFunction("greater_than", {node_b, literal_3}, boolean()); + + auto if_3 = TreeExprBuilder::MakeIf(a_gt_b, a_gt_3, b_gt_3, arrow::boolean()); + auto if_2 = TreeExprBuilder::MakeIf(a_gt_b, a_gt_2, b_gt_2, arrow::boolean()); + auto node_and = TreeExprBuilder::MakeAnd({if_3, if_2}); + auto expr = 
TreeExprBuilder::MakeExpression(node_and, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + int num_records = 4; + std::vector<bool> validity({true, true, true, true}); + auto arraya = MakeArrowArrayInt32({4, 3, 3, 2}, validity); + auto arrayb = MakeArrowArrayInt32({3, 4, 2, 3}, validity); + auto exp = MakeArrowArrayBool({true, true, false, false}, validity); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arraya, arrayb}); + + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/date_time_test.cc b/src/arrow/cpp/src/gandiva/tests/date_time_test.cc new file mode 100644 index 000000000..77139125f --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/date_time_test.cc @@ -0,0 +1,602 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <gtest/gtest.h> +#include <math.h> +#include <time.h> + +#include "arrow/memory_pool.h" +#include "gandiva/precompiled/time_constants.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::date32; +using arrow::date64; +using arrow::float32; +using arrow::int32; +using arrow::int64; +using arrow::timestamp; + +class TestProjector : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +time_t Epoch() { + // HACK: MSVC mktime() fails on UTC times before 1970-01-01 00:00:00. + // But it first converts its argument from local time to UTC time, + // so we ask for 1970-01-02 to avoid failing in timezones ahead of UTC. + struct tm y1970; + memset(&y1970, 0, sizeof(struct tm)); + y1970.tm_year = 70; + y1970.tm_mon = 0; + y1970.tm_mday = 2; + y1970.tm_hour = 0; + y1970.tm_min = 0; + y1970.tm_sec = 0; + time_t epoch = mktime(&y1970); + if (epoch == static_cast<time_t>(-1)) { + ARROW_LOG(FATAL) << "mktime() failed"; + } + // Adjust for the 24h offset above. + return epoch - 24 * 3600; +} + +int32_t MillisInDay(int32_t hh, int32_t mm, int32_t ss, int32_t millis) { + int32_t mins = hh * 60 + mm; + int32_t secs = mins * 60 + ss; + + return secs * 1000 + millis; +} + +int64_t MillisSince(time_t base_line, int32_t yy, int32_t mm, int32_t dd, int32_t hr, + int32_t min, int32_t sec, int32_t millis) { + struct tm given_ts; + memset(&given_ts, 0, sizeof(struct tm)); + given_ts.tm_year = (yy - 1900); + given_ts.tm_mon = (mm - 1); + given_ts.tm_mday = dd; + given_ts.tm_hour = hr; + given_ts.tm_min = min; + given_ts.tm_sec = sec; + + time_t ts = mktime(&given_ts); + if (ts == static_cast<time_t>(-1)) { + ARROW_LOG(FATAL) << "mktime() failed"; + } + // time_t is an arithmetic type on both POSIX and Windows, we can simply + // subtract to get a duration in seconds. 
+ return static_cast<int64_t>(ts - base_line) * 1000 + millis; +} + +int32_t DaysSince(time_t base_line, int32_t yy, int32_t mm, int32_t dd, int32_t hr, + int32_t min, int32_t sec, int32_t millis) { + struct tm given_ts; + memset(&given_ts, 0, sizeof(struct tm)); + given_ts.tm_year = (yy - 1900); + given_ts.tm_mon = (mm - 1); + given_ts.tm_mday = dd; + given_ts.tm_hour = hr; + given_ts.tm_min = min; + given_ts.tm_sec = sec; + + time_t ts = mktime(&given_ts); + if (ts == static_cast<time_t>(-1)) { + ARROW_LOG(FATAL) << "mktime() failed"; + } + // time_t is an arithmetic type on both POSIX and Windows, we can simply + // subtract to get a duration in seconds. + return static_cast<int32_t>(((ts - base_line) * 1000 + millis) / MILLIS_IN_DAY); +} + +TEST_F(TestProjector, TestIsNull) { + auto d0 = field("d0", date64()); + auto t0 = field("t0", time32(arrow::TimeUnit::MILLI)); + auto schema = arrow::schema({d0, t0}); + + // output fields + auto b0 = field("isnull", boolean()); + + // isnull and isnotnull + auto isnull_expr = TreeExprBuilder::MakeExpression("isnull", {d0}, b0); + auto isnotnull_expr = TreeExprBuilder::MakeExpression("isnotnull", {t0}, b0); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {isnull_expr, isnotnull_expr}, + TestConfiguration(), &projector); + ASSERT_TRUE(status.ok()); + + int num_records = 4; + std::vector<int64_t> d0_data = {0, 100, 0, 1000}; + auto t0_data = {0, 100, 0, 1000}; + auto validity = {false, true, false, true}; + auto d0_array = + MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), d0_data, validity); + auto t0_array = MakeArrowTypeArray<arrow::Time32Type, int32_t>( + time32(arrow::TimeUnit::MILLI), t0_data, validity); + + // expected output + auto exp_isnull = + MakeArrowArrayBool({true, false, true, false}, {true, true, true, true}); + auto exp_isnotnull = MakeArrowArrayBool(validity, {true, true, true, true}); + + // prepare input record batch + auto in_batch = 
arrow::RecordBatch::Make(schema, num_records, {d0_array, t0_array}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_isnull, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_isnotnull, outputs.at(1)); +} + +TEST_F(TestProjector, TestDate32IsNull) { + auto d0 = field("d0", date32()); + auto schema = arrow::schema({d0}); + + // output fields + auto b0 = field("isnull", boolean()); + + // isnull and isnotnull + auto isnull_expr = TreeExprBuilder::MakeExpression("isnull", {d0}, b0); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {isnull_expr}, TestConfiguration(), &projector); + ASSERT_TRUE(status.ok()); + + int num_records = 4; + std::vector<int32_t> d0_data = {0, 100, 0, 1000}; + auto validity = {false, true, false, true}; + auto d0_array = + MakeArrowTypeArray<arrow::Date32Type, int32_t>(date32(), d0_data, validity); + + // expected output + auto exp_isnull = + MakeArrowArrayBool({true, false, true, false}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {d0_array}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_isnull, outputs.at(0)); +} + +TEST_F(TestProjector, TestDateTime) { + auto field0 = field("f0", date64()); + auto field1 = field("f1", date32()); + auto field2 = field("f2", timestamp(arrow::TimeUnit::MILLI)); + auto schema = arrow::schema({field0, field1, field2}); + + // output fields + auto field_year = field("yy", int64()); + auto field_month = field("mm", int64()); + auto field_day = field("dd", int64()); + auto field_hour = field("hh", int64()); + auto field_date64 = field("date64", date64()); + + // extract year and month from date + auto 
date2year_expr = + TreeExprBuilder::MakeExpression("extractYear", {field0}, field_year); + auto date2month_expr = + TreeExprBuilder::MakeExpression("extractMonth", {field0}, field_month); + + // extract year and month from date32, cast to date64 first + auto node_f1 = TreeExprBuilder::MakeField(field1); + auto date32_to_date64_func = + TreeExprBuilder::MakeFunction("castDATE", {node_f1}, date64()); + + auto date64_2year_func = + TreeExprBuilder::MakeFunction("extractYear", {date32_to_date64_func}, int64()); + auto date64_2year_expr = TreeExprBuilder::MakeExpression(date64_2year_func, field_year); + + auto date64_2month_func = + TreeExprBuilder::MakeFunction("extractMonth", {date32_to_date64_func}, int64()); + auto date64_2month_expr = + TreeExprBuilder::MakeExpression(date64_2month_func, field_month); + + // extract month and day from timestamp + auto ts2month_expr = + TreeExprBuilder::MakeExpression("extractMonth", {field2}, field_month); + auto ts2day_expr = TreeExprBuilder::MakeExpression("extractDay", {field2}, field_day); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, + {date2year_expr, date2month_expr, date64_2year_expr, + date64_2month_expr, ts2month_expr, ts2day_expr}, + TestConfiguration(), &projector); + ASSERT_TRUE(status.ok()); + + // Create a row-batch with some sample data + time_t epoch = Epoch(); + int num_records = 4; + auto validity = {true, true, true, true}; + std::vector<int64_t> field0_data = {MillisSince(epoch, 2000, 1, 1, 5, 0, 0, 0), + MillisSince(epoch, 1999, 12, 31, 5, 0, 0, 0), + MillisSince(epoch, 2015, 6, 30, 20, 0, 0, 0), + MillisSince(epoch, 2015, 7, 1, 20, 0, 0, 0)}; + auto array0 = + MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), field0_data, validity); + + std::vector<int32_t> field1_data = {DaysSince(epoch, 2000, 1, 1, 5, 0, 0, 0), + DaysSince(epoch, 1999, 12, 31, 5, 0, 0, 0), + DaysSince(epoch, 2015, 6, 30, 20, 0, 0, 0), + DaysSince(epoch, 2015, 7, 1, 20, 0, 0, 0)}; + auto array1 = + 
MakeArrowTypeArray<arrow::Date32Type, int32_t>(date32(), field1_data, validity); + + std::vector<int64_t> field2_data = {MillisSince(epoch, 1999, 12, 31, 5, 0, 0, 0), + MillisSince(epoch, 2000, 1, 2, 5, 0, 0, 0), + MillisSince(epoch, 2015, 7, 1, 1, 0, 0, 0), + MillisSince(epoch, 2015, 6, 29, 23, 0, 0, 0)}; + + auto array2 = MakeArrowTypeArray<arrow::TimestampType, int64_t>( + arrow::timestamp(arrow::TimeUnit::MILLI), field2_data, validity); + + // expected output + // date 2 year and date 2 month for date64 + auto exp_yy_from_date64 = MakeArrowArrayInt64({2000, 1999, 2015, 2015}, validity); + auto exp_mm_from_date64 = MakeArrowArrayInt64({1, 12, 6, 7}, validity); + + // date 2 year and date 2 month for date32 + auto exp_yy_from_date32 = MakeArrowArrayInt64({2000, 1999, 2015, 2015}, validity); + auto exp_mm_from_date32 = MakeArrowArrayInt64({1, 12, 6, 7}, validity); + + // ts 2 month and ts 2 day + auto exp_mm_from_ts = MakeArrowArrayInt64({12, 1, 7, 6}, validity); + auto exp_dd_from_ts = MakeArrowArrayInt64({31, 2, 1, 29}, validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_yy_from_date64, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_mm_from_date64, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(exp_yy_from_date32, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(exp_mm_from_date32, outputs.at(3)); + EXPECT_ARROW_ARRAY_EQUALS(exp_mm_from_ts, outputs.at(4)); + EXPECT_ARROW_ARRAY_EQUALS(exp_dd_from_ts, outputs.at(5)); +} + +TEST_F(TestProjector, TestTime) { + auto field0 = field("f0", time32(arrow::TimeUnit::MILLI)); + auto schema = arrow::schema({field0}); + + auto field_min = field("mm", int64()); + auto field_hour = field("hh", int64()); + + // extract minute and hour from time32 + auto 
time2min_expr = + TreeExprBuilder::MakeExpression("extractMinute", {field0}, field_min); + auto time2hour_expr = + TreeExprBuilder::MakeExpression("extractHour", {field0}, field_hour); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {time2min_expr, time2hour_expr}, + TestConfiguration(), &projector); + ASSERT_TRUE(status.ok()); + + // create input data + int num_records = 4; + auto validity = {true, true, true, true}; + std::vector<int32_t> field_data = { + MillisInDay(5, 35, 25, 0), // 5:35:25 + MillisInDay(0, 59, 0, 0), // 0:59:0 + MillisInDay(12, 30, 0, 0), // 12:30:0 + MillisInDay(23, 0, 0, 0) // 23:0:0 + }; + auto array = MakeArrowTypeArray<arrow::Time32Type, int32_t>( + time32(arrow::TimeUnit::MILLI), field_data, validity); + + // expected output + auto exp_min = MakeArrowArrayInt64({35, 59, 30, 0}, validity); + auto exp_hour = MakeArrowArrayInt64({5, 0, 12, 23}, validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_min, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_hour, outputs.at(1)); +} + +TEST_F(TestProjector, TestTimestampDiff) { + auto f0 = field("f0", timestamp(arrow::TimeUnit::MILLI)); + auto f1 = field("f1", timestamp(arrow::TimeUnit::MILLI)); + auto schema = arrow::schema({f0, f1}); + + // output fields + auto diff_seconds = field("ss", int32()); + + // get diff + auto diff_secs_expr = + TreeExprBuilder::MakeExpression("timestampdiffSecond", {f0, f1}, diff_seconds); + + auto diff_mins_expr = + TreeExprBuilder::MakeExpression("timestampdiffMinute", {f0, f1}, diff_seconds); + + auto diff_hours_expr = + TreeExprBuilder::MakeExpression("timestampdiffHour", {f0, f1}, diff_seconds); + + auto diff_days_expr = + 
TreeExprBuilder::MakeExpression("timestampdiffDay", {f0, f1}, diff_seconds); + + auto diff_days_expr_with_datediff_fn = + TreeExprBuilder::MakeExpression("datediff", {f0, f1}, diff_seconds); + + auto diff_weeks_expr = + TreeExprBuilder::MakeExpression("timestampdiffWeek", {f0, f1}, diff_seconds); + + auto diff_months_expr = + TreeExprBuilder::MakeExpression("timestampdiffMonth", {f0, f1}, diff_seconds); + + auto diff_quarters_expr = + TreeExprBuilder::MakeExpression("timestampdiffQuarter", {f0, f1}, diff_seconds); + + auto diff_years_expr = + TreeExprBuilder::MakeExpression("timestampdiffYear", {f0, f1}, diff_seconds); + + std::shared_ptr<Projector> projector; + auto exprs = {diff_secs_expr, + diff_mins_expr, + diff_hours_expr, + diff_days_expr, + diff_days_expr_with_datediff_fn, + diff_weeks_expr, + diff_months_expr, + diff_quarters_expr, + diff_years_expr}; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + ASSERT_TRUE(status.ok()); + + time_t epoch = Epoch(); + + // 2015-09-10T20:49:42.000 + auto start_millis = MillisSince(epoch, 2015, 9, 10, 20, 49, 42, 0); + // 2017-03-30T22:50:59.050 + auto end_millis = MillisSince(epoch, 2017, 3, 30, 22, 50, 59, 50); + std::vector<int64_t> f0_data = {start_millis, end_millis, + // 2015-09-10T20:49:42.999 + start_millis + 999, + // 2015-09-10T20:49:42.999 + MillisSince(epoch, 2015, 9, 10, 20, 49, 42, 999)}; + std::vector<int64_t> f1_data = {end_millis, start_millis, + // 2015-09-10T20:49:42.999 + start_millis + 999, + // 2015-09-9T21:49:42.999 (23 hours behind) + MillisSince(epoch, 2015, 9, 9, 21, 49, 42, 999)}; + + int64_t num_records = f0_data.size(); + std::vector<bool> validity(num_records, true); + auto array0 = MakeArrowTypeArray<arrow::TimestampType, int64_t>( + arrow::timestamp(arrow::TimeUnit::MILLI), f0_data, validity); + auto array1 = MakeArrowTypeArray<arrow::TimestampType, int64_t>( + arrow::timestamp(arrow::TimeUnit::MILLI), f1_data, validity); + + // expected output + 
std::vector<ArrayPtr> exp_output; + exp_output.push_back( + MakeArrowArrayInt32({48996077, -48996077, 0, -23 * 3600}, validity)); + exp_output.push_back(MakeArrowArrayInt32({816601, -816601, 0, -23 * 60}, validity)); + exp_output.push_back(MakeArrowArrayInt32({13610, -13610, 0, -23}, validity)); + exp_output.push_back(MakeArrowArrayInt32({567, -567, 0, 0}, validity)); + exp_output.push_back(MakeArrowArrayInt32({567, -567, 0, 0}, validity)); + exp_output.push_back(MakeArrowArrayInt32({81, -81, 0, 0}, validity)); + exp_output.push_back(MakeArrowArrayInt32({18, -18, 0, 0}, validity)); + exp_output.push_back(MakeArrowArrayInt32({6, -6, 0, 0}, validity)); + exp_output.push_back(MakeArrowArrayInt32({1, -1, 0, 0}, validity)); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + for (uint32_t i = 0; i < exp_output.size(); i++) { + EXPECT_ARROW_ARRAY_EQUALS(exp_output.at(i), outputs.at(i)); + } +} + +TEST_F(TestProjector, TestTimestampDiffMonth) { + auto f0 = field("f0", timestamp(arrow::TimeUnit::MILLI)); + auto f1 = field("f1", timestamp(arrow::TimeUnit::MILLI)); + auto schema = arrow::schema({f0, f1}); + + // output fields + auto diff_seconds = field("ss", int32()); + + auto diff_months_expr = + TreeExprBuilder::MakeExpression("timestampdiffMonth", {f0, f1}, diff_seconds); + + std::shared_ptr<Projector> projector; + auto status = + Projector::Make(schema, {diff_months_expr}, TestConfiguration(), &projector); + std::cout << status.message(); + ASSERT_TRUE(status.ok()); + + time_t epoch = Epoch(); + + // Create a row-batch with some sample data + std::vector<int64_t> f0_data = {MillisSince(epoch, 2019, 1, 31, 0, 0, 0, 0), + MillisSince(epoch, 2020, 1, 31, 0, 0, 0, 0), + MillisSince(epoch, 2020, 1, 31, 0, 0, 0, 0), + MillisSince(epoch, 2019, 
3, 31, 0, 0, 0, 0), + MillisSince(epoch, 2020, 3, 30, 0, 0, 0, 0), + MillisSince(epoch, 2020, 5, 31, 0, 0, 0, 0)}; + std::vector<int64_t> f1_data = {MillisSince(epoch, 2019, 2, 28, 0, 0, 0, 0), + MillisSince(epoch, 2020, 2, 28, 0, 0, 0, 0), + MillisSince(epoch, 2020, 2, 29, 0, 0, 0, 0), + MillisSince(epoch, 2019, 4, 30, 0, 0, 0, 0), + MillisSince(epoch, 2020, 2, 29, 0, 0, 0, 0), + MillisSince(epoch, 2020, 9, 30, 0, 0, 0, 0)}; + int64_t num_records = f0_data.size(); + std::vector<bool> validity(num_records, true); + + auto array0 = MakeArrowTypeArray<arrow::TimestampType, int64_t>( + arrow::timestamp(arrow::TimeUnit::MILLI), f0_data, validity); + auto array1 = MakeArrowTypeArray<arrow::TimestampType, int64_t>( + arrow::timestamp(arrow::TimeUnit::MILLI), f1_data, validity); + + // expected output + std::vector<ArrayPtr> exp_output; + exp_output.push_back(MakeArrowArrayInt32({1, 0, 1, 1, -1, 4}, validity)); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + for (uint32_t i = 0; i < exp_output.size(); i++) { + EXPECT_ARROW_ARRAY_EQUALS(exp_output.at(i), outputs.at(i)); + } +} + +TEST_F(TestProjector, TestMonthsBetween) { + auto f0 = field("f0", arrow::date64()); + auto f1 = field("f1", arrow::date64()); + auto schema = arrow::schema({f0, f1}); + + // output fields + auto output = field("out", arrow::float64()); + + auto months_between_expr = + TreeExprBuilder::MakeExpression("months_between", {f0, f1}, output); + + std::shared_ptr<Projector> projector; + auto status = + Projector::Make(schema, {months_between_expr}, TestConfiguration(), &projector); + std::cout << status.message(); + ASSERT_TRUE(status.ok()); + + time_t epoch = Epoch(); + + // Create a row-batch with some sample data + int num_records = 4; + auto validity = {true, true, 
true, true}; + std::vector<int64_t> f0_data = {MillisSince(epoch, 1995, 3, 2, 0, 0, 0, 0), + MillisSince(epoch, 1995, 2, 2, 0, 0, 0, 0), + MillisSince(epoch, 1995, 3, 31, 0, 0, 0, 0), + MillisSince(epoch, 1996, 3, 31, 0, 0, 0, 0)}; + + auto array0 = + MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), f0_data, validity); + + std::vector<int64_t> f1_data = {MillisSince(epoch, 1995, 2, 2, 0, 0, 0, 0), + MillisSince(epoch, 1995, 3, 2, 0, 0, 0, 0), + MillisSince(epoch, 1995, 2, 28, 0, 0, 0, 0), + MillisSince(epoch, 1996, 2, 29, 0, 0, 0, 0)}; + + auto array1 = + MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), f1_data, validity); + + // expected output + auto exp_output = MakeArrowArrayFloat64({1.0, -1.0, 1.0, 1.0}, validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_output, outputs.at(0)); +} + +TEST_F(TestProjector, TestLastDay) { + auto f0 = field("f0", arrow::date64()); + auto schema = arrow::schema({f0}); + + // output fields + auto output = field("out", arrow::date64()); + + auto last_day_expr = TreeExprBuilder::MakeExpression("last_day", {f0}, output); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {last_day_expr}, TestConfiguration(), &projector); + std::cout << status.message(); + ASSERT_TRUE(status.ok()); + + time_t epoch = Epoch(); + + // Create a row-batch with some sample data + // Used a leap year as example. 
+ int num_records = 5; + auto validity = {true, true, true, true, true}; + std::vector<int64_t> f0_data = {MillisSince(epoch, 2016, 2, 3, 8, 20, 10, 34), + MillisSince(epoch, 2016, 2, 29, 23, 59, 59, 59), + MillisSince(epoch, 2016, 1, 30, 1, 15, 20, 0), + MillisSince(epoch, 2017, 2, 3, 23, 15, 20, 0), + MillisSince(epoch, 2015, 12, 30, 22, 50, 11, 0)}; + + auto array0 = + MakeArrowTypeArray<arrow::Date64Type, int64_t>(date64(), f0_data, validity); + + std::vector<int64_t> f0_output_data = {MillisSince(epoch, 2016, 2, 29, 0, 0, 0, 0), + MillisSince(epoch, 2016, 2, 29, 0, 0, 0, 0), + MillisSince(epoch, 2016, 1, 31, 0, 0, 0, 0), + MillisSince(epoch, 2017, 2, 28, 0, 0, 0, 0), + MillisSince(epoch, 2015, 12, 31, 0, 0, 0, 0)}; + + // expected output + auto exp_output = MakeArrowArrayDate64(f0_output_data, validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_output, outputs.at(0)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/decimal_single_test.cc b/src/arrow/cpp/src/gandiva/tests/decimal_single_test.cc new file mode 100644 index 000000000..666ee4a68 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/decimal_single_test.cc @@ -0,0 +1,305 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <sstream> + +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "arrow/status.h" + +#include "gandiva/decimal_scalar.h" +#include "gandiva/decimal_type_util.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +using arrow::Decimal128; + +namespace gandiva { + +#define EXPECT_DECIMAL_RESULT(op, x, y, expected, actual) \ + EXPECT_EQ(expected, actual) << op << " (" << (x).ToString() << "),(" << (y).ToString() \ + << ")" \ + << " expected : " << (expected).ToString() \ + << " actual : " << (actual).ToString(); + +DecimalScalar128 decimal_literal(const char* value, int precision, int scale) { + std::string value_string = std::string(value); + return DecimalScalar128(value_string, precision, scale); +} + +class TestDecimalOps : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + ArrayPtr MakeDecimalVector(const DecimalScalar128& in); + + void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x, + const DecimalScalar128& y, const DecimalScalar128& expected); + + void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, + const DecimalScalar128& expected) { + Verify(DecimalTypeUtil::kOpAdd, "add", x, y, expected); + } + + void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, + const DecimalScalar128& expected) { + Verify(DecimalTypeUtil::kOpSubtract, "subtract", x, y, expected); + } + + void MultiplyAndVerify(const DecimalScalar128& x, const 
DecimalScalar128& y, + const DecimalScalar128& expected) { + Verify(DecimalTypeUtil::kOpMultiply, "multiply", x, y, expected); + } + + void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, + const DecimalScalar128& expected) { + Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected); + } + + void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y, + const DecimalScalar128& expected) { + Verify(DecimalTypeUtil::kOpMod, "mod", x, y, expected); + } + + protected: + arrow::MemoryPool* pool_; +}; + +ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) { + std::vector<arrow::Decimal128> ret; + + Decimal128 decimal_value = in.value(); + + auto decimal_type = std::make_shared<arrow::Decimal128Type>(in.precision(), in.scale()); + return MakeArrowArrayDecimal(decimal_type, {decimal_value}, {true}); +} + +void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function, + const DecimalScalar128& x, const DecimalScalar128& y, + const DecimalScalar128& expected) { + auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale()); + auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale()); + auto field_x = field("x", x_type); + auto field_y = field("y", y_type); + auto schema = arrow::schema({field_x, field_y}); + + Decimal128TypePtr output_type; + auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type); + ARROW_EXPECT_OK(status); + + // output fields + auto res = field("res", output_type); + + // build expression : x op y + auto expr = TreeExprBuilder::MakeExpression(function, {field_x, field_y}, res); + + // Build a projector for the expression. 
+ std::shared_ptr<Projector> projector; + status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + ARROW_EXPECT_OK(status); + + // Create a row-batch with some sample data + auto array_a = MakeDecimalVector(x); + auto array_b = MakeDecimalVector(y); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, 1 /*num_records*/, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + ARROW_EXPECT_OK(status); + + // Validate results + auto out_array = dynamic_cast<arrow::Decimal128Array*>(outputs[0].get()); + const Decimal128 out_value(out_array->GetValue(0)); + + auto dtype = dynamic_cast<arrow::Decimal128Type*>(out_array->type().get()); + std::string value_string = out_value.ToString(0); + DecimalScalar128 actual{value_string, dtype->precision(), dtype->scale()}; + + EXPECT_DECIMAL_RESULT(function, x, y, expected, actual); +} + +TEST_F(TestDecimalOps, TestAdd) { + // fast-path + AddAndVerify(decimal_literal("201", 30, 3), // x + decimal_literal("301", 30, 3), // y + decimal_literal("502", 31, 3)); // expected + + AddAndVerify(decimal_literal("201", 30, 3), // x + decimal_literal("301", 30, 2), // y + decimal_literal("3211", 32, 3)); // expected + + AddAndVerify(decimal_literal("201", 30, 3), // x + decimal_literal("301", 30, 4), // y + decimal_literal("2311", 32, 4)); // expected + + // max precision, but no overflow + AddAndVerify(decimal_literal("201", 38, 3), // x + decimal_literal("301", 38, 3), // y + decimal_literal("502", 38, 3)); // expected + + AddAndVerify(decimal_literal("201", 38, 3), // x + decimal_literal("301", 38, 2), // y + decimal_literal("3211", 38, 3)); // expected + + AddAndVerify(decimal_literal("201", 38, 3), // x + decimal_literal("301", 38, 4), // y + decimal_literal("2311", 38, 4)); // expected + + AddAndVerify(decimal_literal("201", 38, 3), // x + decimal_literal("301", 38, 7), // y + decimal_literal("201030", 
38, 6)); // expected + + AddAndVerify(decimal_literal("1201", 38, 3), // x + decimal_literal("1801", 38, 3), // y + decimal_literal("3002", 38, 3)); // carry-over from fractional + + // max precision + AddAndVerify(decimal_literal("09999999999999999999999999999999000000", 38, 5), // x + decimal_literal("100", 38, 7), // y + decimal_literal("99999999999999999999999999999990000010", 38, 6)); + + AddAndVerify(decimal_literal("-09999999999999999999999999999999000000", 38, 5), // x + decimal_literal("100", 38, 7), // y + decimal_literal("-99999999999999999999999999999989999990", 38, 6)); + + AddAndVerify(decimal_literal("09999999999999999999999999999999000000", 38, 5), // x + decimal_literal("-100", 38, 7), // y + decimal_literal("99999999999999999999999999999989999990", 38, 6)); + + AddAndVerify(decimal_literal("-09999999999999999999999999999999000000", 38, 5), // x + decimal_literal("-100", 38, 7), // y + decimal_literal("-99999999999999999999999999999990000010", 38, 6)); + + AddAndVerify(decimal_literal("09999999999999999999999999999999999999", 38, 6), // x + decimal_literal("89999999999999999999999999999999999999", 38, 7), // y + decimal_literal("18999999999999999999999999999999999999", 38, 6)); + + // Both -ve + AddAndVerify(decimal_literal("-201", 30, 3), // x + decimal_literal("-301", 30, 2), // y + decimal_literal("-3211", 32, 3)); // expected + + AddAndVerify(decimal_literal("-201", 38, 3), // x + decimal_literal("-301", 38, 4), // y + decimal_literal("-2311", 38, 4)); // expected + + // Mix of +ve and -ve + AddAndVerify(decimal_literal("-201", 30, 3), // x + decimal_literal("301", 30, 2), // y + decimal_literal("2809", 32, 3)); // expected + + AddAndVerify(decimal_literal("-201", 38, 3), // x + decimal_literal("301", 38, 4), // y + decimal_literal("-1709", 38, 4)); // expected + + AddAndVerify(decimal_literal("201", 38, 3), // x + decimal_literal("-301", 38, 7), // y + decimal_literal("200970", 38, 6)); // expected + + AddAndVerify(decimal_literal("-1901", 38, 
4), // x + decimal_literal("1801", 38, 4), // y + decimal_literal("-100", 38, 4)); // expected + + AddAndVerify(decimal_literal("1801", 38, 4), // x + decimal_literal("-1901", 38, 4), // y + decimal_literal("-100", 38, 4)); // expected + + // rounding +ve + AddAndVerify(decimal_literal("1000999", 38, 6), // x + decimal_literal("10000999", 38, 7), // y + decimal_literal("2001099", 38, 6)); + + AddAndVerify(decimal_literal("1000999", 38, 6), // x + decimal_literal("10000995", 38, 7), // y + decimal_literal("2001099", 38, 6)); + + AddAndVerify(decimal_literal("1000999", 38, 6), // x + decimal_literal("10000992", 38, 7), // y + decimal_literal("2001098", 38, 6)); + + // rounding -ve + AddAndVerify(decimal_literal("-1000999", 38, 6), // x + decimal_literal("-10000999", 38, 7), // y + decimal_literal("-2001099", 38, 6)); + + AddAndVerify(decimal_literal("-1000999", 38, 6), // x + decimal_literal("-10000995", 38, 7), // y + decimal_literal("-2001099", 38, 6)); + + AddAndVerify(decimal_literal("-1000999", 38, 6), // x + decimal_literal("-10000992", 38, 7), // y + decimal_literal("-2001098", 38, 6)); +} + +// subtract is a wrapper over add. so, minimal tests are sufficient. +TEST_F(TestDecimalOps, TestSubtract) { + // fast-path + SubtractAndVerify(decimal_literal("201", 30, 3), // x + decimal_literal("301", 30, 3), // y + decimal_literal("-100", 31, 3)); // expected + + // max precision + SubtractAndVerify( + decimal_literal("09999999999999999999999999999999000000", 38, 5), // x + decimal_literal("100", 38, 7), // y + decimal_literal("99999999999999999999999999999989999990", 38, 6)); + + // Mix of +ve and -ve + SubtractAndVerify(decimal_literal("-201", 30, 3), // x + decimal_literal("301", 30, 2), // y + decimal_literal("-3211", 32, 3)); // expected +} + +// Lots of unit tests for multiply/divide/mod in decimal_ops_test.cc. So, keeping these +// basic. 
+TEST_F(TestDecimalOps, TestMultiply) { + // fast-path + MultiplyAndVerify(decimal_literal("201", 10, 3), // x + decimal_literal("301", 10, 2), // y + decimal_literal("60501", 21, 5)); // expected + + // max precision + MultiplyAndVerify(DecimalScalar128(std::string(35, '9'), 38, 20), // x + DecimalScalar128(std::string(36, '9'), 38, 20), // y + DecimalScalar128("9999999999999999999999999999999999890", 38, 6)); +} + +TEST_F(TestDecimalOps, TestDivide) { + DivideAndVerify(decimal_literal("201", 10, 3), // x + decimal_literal("301", 10, 2), // y + decimal_literal("6677740863787", 23, 14)); // expected + + DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x + DecimalScalar128(std::string(35, '9'), 38, 20), // y + DecimalScalar128("1000000000", 38, 6)); +} + +TEST_F(TestDecimalOps, TestMod) { + ModAndVerify(decimal_literal("201", 20, 2), // x + decimal_literal("301", 20, 3), // y + decimal_literal("204", 20, 3)); // expected + + ModAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x + DecimalScalar128(std::string(35, '9'), 38, 21), // y + DecimalScalar128("9990", 38, 21)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/decimal_test.cc b/src/arrow/cpp/src/gandiva/tests/decimal_test.cc new file mode 100644 index 000000000..31f2dedf5 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/decimal_test.cc @@ -0,0 +1,1194 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <sstream> + +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "arrow/util/decimal.h" + +#include "gandiva/decimal_type_util.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +using arrow::boolean; +using arrow::Decimal128; +using arrow::utf8; + +namespace gandiva { + +class TestDecimal : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + std::vector<Decimal128> MakeDecimalVector(std::vector<std::string> values, + int32_t scale); + + protected: + arrow::MemoryPool* pool_; +}; + +std::vector<Decimal128> TestDecimal::MakeDecimalVector(std::vector<std::string> values, + int32_t scale) { + std::vector<arrow::Decimal128> ret; + for (auto str : values) { + Decimal128 str_value; + int32_t str_precision; + int32_t str_scale; + + DCHECK_OK(Decimal128::FromString(str, &str_value, &str_precision, &str_scale)); + + Decimal128 scaled_value; + if (str_scale == scale) { + scaled_value = str_value; + } else { + scaled_value = str_value.Rescale(str_scale, scale).ValueOrDie(); + } + ret.push_back(scaled_value); + } + return ret; +} + +TEST_F(TestDecimal, TestSimple) { + // schema for input fields + constexpr int32_t precision = 36; + constexpr int32_t scale = 18; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_a = field("a", decimal_type); + auto field_b = field("b", decimal_type); + auto field_c = field("c", decimal_type); + auto schema = 
arrow::schema({field_a, field_b, field_c}); + + Decimal128TypePtr add2_type; + auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, + {decimal_type, decimal_type}, &add2_type); + + Decimal128TypePtr output_type; + status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, + {add2_type, decimal_type}, &output_type); + + // output fields + auto res = field("res0", output_type); + + // build expression : a + b + c + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_b = TreeExprBuilder::MakeField(field_b); + auto node_c = TreeExprBuilder::MakeField(field_c); + auto add2 = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, add2_type); + auto add3 = TreeExprBuilder::MakeFunction("add", {add2, node_c}, output_type); + auto expr = TreeExprBuilder::MakeExpression(add3, res); + + // Build a projector for the expression. + std::shared_ptr<Projector> projector; + status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "4"}, scale), + {false, true, true, true}); + auto array_b = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"2", "3", "4", "5"}, scale), + {false, true, true, true}); + auto array_c = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"3", "4", "5", "6"}, scale), + {true, true, true, true}); + + // prepare input record batch + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c}); + + auto expected = + MakeArrowArrayDecimal(output_type, MakeDecimalVector({"6", "9", "12", "15"}, scale), + {false, true, true, true}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(expected, outputs[0]); +} + +TEST_F(TestDecimal, TestLiteral) { + // 
schema for input fields + constexpr int32_t precision = 36; + constexpr int32_t scale = 18; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_a = field("a", decimal_type); + auto schema = arrow::schema({ + field_a, + }); + + Decimal128TypePtr add2_type; + auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, + {decimal_type, decimal_type}, &add2_type); + + // output fields + auto res = field("res0", add2_type); + + // build expression : a + 0.6 (decimal literal) + auto node_a = TreeExprBuilder::MakeField(field_a); + static std::string decimal_point_six = "6"; + DecimalScalar128 literal(decimal_point_six, 2, 1); + auto node_b = TreeExprBuilder::MakeDecimalLiteral(literal); + auto add2 = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, add2_type); + auto expr = TreeExprBuilder::MakeExpression(add2, res); + + // Build a projector for the expression. + std::shared_ptr<Projector> projector; + status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "4"}, scale), + {false, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + auto expected = MakeArrowArrayDecimal( + add2_type, MakeDecimalVector({"1.6", "2.6", "3.6", "4.6"}, scale), + {false, true, true, true}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(expected, outputs[0]); +} + +TEST_F(TestDecimal, TestIfElse) { + // schema for input fields + constexpr int32_t precision = 36; + constexpr int32_t scale = 18; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_a = field("a", decimal_type); + auto field_b = 
field("b", decimal_type); + auto field_c = field("c", arrow::boolean()); + auto schema = arrow::schema({field_a, field_b, field_c}); + + // output fields + auto field_result = field("res", decimal_type); + + // build expression. + // if (c) + // a + // else + // b + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_b = TreeExprBuilder::MakeField(field_b); + auto node_c = TreeExprBuilder::MakeField(field_c); + auto if_node = TreeExprBuilder::MakeIf(node_c, node_a, node_b, decimal_type); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + Status status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "4"}, scale), + {false, true, true, true}); + auto array_b = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"2", "3", "4", "5"}, scale), + {true, true, true, true}); + + auto array_c = MakeArrowArrayBool({true, false, true, false}, {true, true, true, true}); + + // expected output + auto exp = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"0", "3", "3", "5"}, scale), + {false, true, true, true}); + + // prepare input record batch + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestDecimal, TestCompare) { + // schema for input fields + constexpr int32_t precision = 36; + constexpr int32_t scale = 18; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_a = field("a", decimal_type); + auto field_b = field("b", decimal_type); + 
auto schema = arrow::schema({field_a, field_b}); + + // build expressions + auto exprs = std::vector<ExpressionPtr>{ + TreeExprBuilder::MakeExpression("equal", {field_a, field_b}, + field("res_eq", boolean())), + TreeExprBuilder::MakeExpression("not_equal", {field_a, field_b}, + field("res_ne", boolean())), + TreeExprBuilder::MakeExpression("less_than", {field_a, field_b}, + field("res_lt", boolean())), + TreeExprBuilder::MakeExpression("less_than_or_equal_to", {field_a, field_b}, + field("res_le", boolean())), + TreeExprBuilder::MakeExpression("greater_than", {field_a, field_b}, + field("res_gt", boolean())), + TreeExprBuilder::MakeExpression("greater_than_or_equal_to", {field_a, field_b}, + field("res_ge", boolean())), + }; + + // Build a projector for the expression. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "2", "3", "-4"}, scale), + {true, true, true, true}); + auto array_b = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"1", "3", "2", "-3"}, scale), + {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, false, false}), + outputs[0]); // equal + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, true, true, true}), + outputs[1]); // not_equal + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, true, false, true}), + outputs[2]); // less_than + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, true, false, true}), + outputs[3]); // less_than_or_equal_to + 
EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, false, true, false}), + outputs[4]); // greater_than + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, true, false}), + outputs[5]); // greater_than_or_equal_to +} + +// ARROW-9092: This test is conditionally disabled when building with LLVM 9 +// because it hangs. +#if GANDIVA_LLVM_VERSION != 9 + +TEST_F(TestDecimal, TestRoundFunctions) { + // schema for input fields + constexpr int32_t precision = 38; + constexpr int32_t scale = 2; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_a = field("a", decimal_type); + auto schema = arrow::schema({field_a}); + + auto scale_1 = TreeExprBuilder::MakeLiteral(1); + + // build expressions + auto exprs = std::vector<ExpressionPtr>{ + TreeExprBuilder::MakeExpression("abs", {field_a}, field("res_abs", decimal_type)), + TreeExprBuilder::MakeExpression("ceil", {field_a}, + field("res_ceil", arrow::decimal(precision, 0))), + TreeExprBuilder::MakeExpression("floor", {field_a}, + field("res_floor", arrow::decimal(precision, 0))), + TreeExprBuilder::MakeExpression("round", {field_a}, + field("res_round", arrow::decimal(precision, 0))), + TreeExprBuilder::MakeExpression( + "truncate", {field_a}, field("res_truncate", arrow::decimal(precision, 0))), + + TreeExprBuilder::MakeExpression( + TreeExprBuilder::MakeFunction("round", + {TreeExprBuilder::MakeField(field_a), scale_1}, + arrow::decimal(precision, 1)), + field("res_round_3", arrow::decimal(precision, 1))), + + TreeExprBuilder::MakeExpression( + TreeExprBuilder::MakeFunction("truncate", + {TreeExprBuilder::MakeField(field_a), scale_1}, + arrow::decimal(precision, 1)), + field("res_truncate_3", arrow::decimal(precision, 1))), + }; + + // Build a projector for the expression. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + auto validity = {true, true, true, true}; + auto array_a = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"1.23", "1.58", "-1.23", "-1.58"}, scale), + validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + + // abs(x) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(decimal_type, + MakeDecimalVector({"1.23", "1.58", "1.23", "1.58"}, scale), + validity), + outputs[0]); + + // ceil(x) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(arrow::decimal(precision, 0), + MakeDecimalVector({"2", "2", "-1", "-1"}, 0), validity), + outputs[1]); + + // floor(x) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(arrow::decimal(precision, 0), + MakeDecimalVector({"1", "1", "-2", "-2"}, 0), validity), + outputs[2]); + + // round(x) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(arrow::decimal(precision, 0), + MakeDecimalVector({"1", "2", "-1", "-2"}, 0), validity), + outputs[3]); + + // truncate(x) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(arrow::decimal(precision, 0), + MakeDecimalVector({"1", "1", "-1", "-1"}, 0), validity), + outputs[4]); + + // round(x, 1) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(arrow::decimal(precision, 1), + MakeDecimalVector({"1.2", "1.6", "-1.2", "-1.6"}, 1), + validity), + outputs[5]); + + // truncate(x, 1) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(arrow::decimal(precision, 1), + MakeDecimalVector({"1.2", "1.5", "-1.2", "-1.5"}, 1), + validity), + outputs[6]); +} + +#endif // GANDIVA_LLVM_VERSION != 9 + +TEST_F(TestDecimal, TestCastFunctions) { + // schema 
for input fields + constexpr int32_t precision = 38; + constexpr int32_t scale = 2; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto decimal_type_scale_1 = std::make_shared<arrow::Decimal128Type>(precision, 1); + auto field_int32 = field("int32", arrow::int32()); + auto field_int64 = field("int64", arrow::int64()); + auto field_float32 = field("float32", arrow::float32()); + auto field_float64 = field("float64", arrow::float64()); + auto field_dec = field("dec", decimal_type); + auto schema = + arrow::schema({field_int32, field_int64, field_float32, field_float64, field_dec}); + + // build expressions + auto exprs = std::vector<ExpressionPtr>{ + TreeExprBuilder::MakeExpression("castDECIMAL", {field_int32}, + field("int32_to_dec", decimal_type)), + TreeExprBuilder::MakeExpression("castDECIMAL", {field_int64}, + field("int64_to_dec", decimal_type)), + TreeExprBuilder::MakeExpression("castDECIMAL", {field_float32}, + field("float32_to_dec", decimal_type)), + TreeExprBuilder::MakeExpression("castDECIMAL", {field_float64}, + field("float64_to_dec", decimal_type)), + TreeExprBuilder::MakeExpression("castDECIMAL", {field_dec}, + field("dec_to_dec", decimal_type_scale_1)), + TreeExprBuilder::MakeExpression("castBIGINT", {field_dec}, + field("dec_to_int64", arrow::int64())), + TreeExprBuilder::MakeExpression("castFLOAT8", {field_dec}, + field("dec_to_float64", arrow::float64())), + }; + + // Build a projector for the expression. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + auto validity = {true, true, true, true}; + + auto array_int32 = MakeArrowArrayInt32({123, 158, -123, -158}); + auto array_int64 = MakeArrowArrayInt64({123, 158, -123, -158}); + auto array_float32 = MakeArrowArrayFloat32({1.23f, 1.58f, -1.23f, -1.58f}); + auto array_float64 = MakeArrowArrayFloat64({1.23, 1.58, -1.23, -1.58}); + auto array_dec = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"1.23", "1.58", "-1.23", "-1.58"}, scale), + validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make( + schema, num_records, + {array_int32, array_int64, array_float32, array_float64, array_dec}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + auto expected_int_dec = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"123", "158", "-123", "-158"}, scale), validity); + + // castDECIMAL(int32) + EXPECT_ARROW_ARRAY_EQUALS(expected_int_dec, outputs[0]); + + // castDECIMAL(int64) + EXPECT_ARROW_ARRAY_EQUALS(expected_int_dec, outputs[1]); + + // castDECIMAL(float32) + EXPECT_ARROW_ARRAY_EQUALS(array_dec, outputs[2]); + + // castDECIMAL(float64) + EXPECT_ARROW_ARRAY_EQUALS(array_dec, outputs[3]); + + // castDECIMAL(decimal) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(arrow::decimal(precision, 1), + MakeDecimalVector({"1.2", "1.6", "-1.2", "-1.6"}, 1), + validity), + outputs[4]); + + // castBIGINT(decimal) + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayInt64({1, 2, -1, -2}), outputs[5]); + + // castFLOAT8(decimal) + EXPECT_ARROW_ARRAY_EQUALS(array_float64, outputs[6]); +} + +// isnull, isnumeric +TEST_F(TestDecimal, TestIsNullNumericFunctions) { + // schema for input fields + constexpr int32_t precision = 
38; + constexpr int32_t scale = 2; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_dec = field("dec", decimal_type); + auto schema = arrow::schema({field_dec}); + + // build expressions + auto exprs = std::vector<ExpressionPtr>{ + TreeExprBuilder::MakeExpression("isnull", {field_dec}, + field("isnull", arrow::boolean())), + + TreeExprBuilder::MakeExpression("isnotnull", {field_dec}, + field("isnotnull", arrow::boolean())), + TreeExprBuilder::MakeExpression("isnumeric", {field_dec}, + field("isnumeric", arrow::boolean()))}; + + // Build a projector for the expression. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 5; + auto validity = {false, true, true, true, false}; + + auto array_dec = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale), + validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + auto is_null = outputs.at(0); + auto is_not_null = outputs.at(1); + auto is_numeric = outputs.at(2); + + // isnull + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, false, false, true}), + outputs[0]); + + // isnotnull + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool(validity), outputs[1]); + + // isnumeric + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool(validity), outputs[2]); +} + +TEST_F(TestDecimal, TestIsDistinct) { + // schema for input fields + constexpr int32_t precision = 38; + constexpr int32_t scale_1 = 2; + auto decimal_type_1 = std::make_shared<arrow::Decimal128Type>(precision, scale_1); + auto field_dec_1 = field("dec_1", decimal_type_1); + constexpr int32_t 
scale_2 = 1; + auto decimal_type_2 = std::make_shared<arrow::Decimal128Type>(precision, scale_2); + auto field_dec_2 = field("dec_2", decimal_type_2); + + auto schema = arrow::schema({field_dec_1, field_dec_2}); + + // build expressions + auto exprs = std::vector<ExpressionPtr>{ + TreeExprBuilder::MakeExpression("is_distinct_from", {field_dec_1, field_dec_2}, + field("isdistinct", arrow::boolean())), + + TreeExprBuilder::MakeExpression("is_not_distinct_from", {field_dec_1, field_dec_2}, + field("isnotdistinct", arrow::boolean()))}; + + // Build a projector for the expression. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + + auto validity_1 = {true, false, true, true}; + auto array_dec_1 = MakeArrowArrayDecimal( + decimal_type_1, MakeDecimalVector({"1.51", "1.23", "1.20", "-1.20"}, scale_1), + validity_1); + + auto validity_2 = {true, false, false, true}; + auto array_dec_2 = MakeArrowArrayDecimal( + decimal_type_2, MakeDecimalVector({"1.5", "1.2", "1.2", "-1.2"}, scale_2), + validity_2); + + // prepare input record batch + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array_dec_1, array_dec_2}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + auto is_distinct = std::dynamic_pointer_cast<arrow::BooleanArray>(outputs.at(0)); + auto is_not_distinct = std::dynamic_pointer_cast<arrow::BooleanArray>(outputs.at(1)); + + // isdistinct + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, false, true, false}), outputs[0]); + + // isnotdistinct + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, true, false, true}), outputs[1]); +} + +// decimal hashes without seed +TEST_F(TestDecimal, TestHashFunctions) { + // schema for input fields + constexpr int32_t precision = 38; + 
constexpr int32_t scale = 2; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_dec = field("dec", decimal_type); + auto literal_seed32 = TreeExprBuilder::MakeLiteral((int32_t)10); + auto literal_seed64 = TreeExprBuilder::MakeLiteral((int64_t)10); + auto schema = arrow::schema({field_dec}); + + // build expressions + auto exprs = std::vector<ExpressionPtr>{ + TreeExprBuilder::MakeExpression("hash", {field_dec}, + field("hash_of_dec", arrow::int32())), + + TreeExprBuilder::MakeExpression("hash64", {field_dec}, + field("hash64_of_dec", arrow::int64())), + + TreeExprBuilder::MakeExpression("hash32AsDouble", {field_dec}, + field("hash32_as_double", arrow::int32())), + + TreeExprBuilder::MakeExpression("hash64AsDouble", {field_dec}, + field("hash64_as_double", arrow::int64()))}; + + // Build a projector for the expression. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 5; + auto validity = {false, true, true, true, true}; + + auto array_dec = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale), + validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + auto int32_arr = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(0)); + EXPECT_EQ(int32_arr->null_count(), 0); + EXPECT_EQ(int32_arr->Value(0), 0); + EXPECT_EQ(int32_arr->Value(1), int32_arr->Value(2)); + EXPECT_NE(int32_arr->Value(2), int32_arr->Value(3)); + EXPECT_NE(int32_arr->Value(3), int32_arr->Value(4)); + + auto int64_arr = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(1)); + EXPECT_EQ(int64_arr->null_count(), 0); 
+ EXPECT_EQ(int64_arr->Value(0), 0); + EXPECT_EQ(int64_arr->Value(1), int64_arr->Value(2)); + EXPECT_NE(int64_arr->Value(2), int64_arr->Value(3)); + EXPECT_NE(int64_arr->Value(3), int64_arr->Value(4)); +} + +TEST_F(TestDecimal, TestHash32WithSeed) { + constexpr int32_t precision = 38; + constexpr int32_t scale = 2; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_dec_1 = field("dec1", decimal_type); + auto field_dec_2 = field("dec2", decimal_type); + auto schema = arrow::schema({field_dec_1, field_dec_2}); + + auto res = field("hash32_with_seed", arrow::int32()); + + auto field_1_nodePtr = TreeExprBuilder::MakeField(field_dec_1); + auto field_2_nodePtr = TreeExprBuilder::MakeField(field_dec_2); + + auto hash32 = + TreeExprBuilder::MakeFunction("hash32", {field_2_nodePtr}, arrow::int32()); + auto hash32_with_seed = + TreeExprBuilder::MakeFunction("hash32", {field_1_nodePtr, hash32}, arrow::int32()); + auto expr = TreeExprBuilder::MakeExpression(hash32, field("hash32", arrow::int32())); + auto exprWS = TreeExprBuilder::MakeExpression(hash32_with_seed, res); + + auto exprs = std::vector<ExpressionPtr>{expr, exprWS}; + + // Build a projector for the expression. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 5; + auto validity_1 = {false, false, true, true, true}; + + auto array_dec_1 = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale), + validity_1); + + auto validity_2 = {false, true, false, true, true}; + + auto array_dec_2 = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale), + validity_2); + + // prepare input record batch + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array_dec_1, array_dec_2}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + auto int32_arr = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(0)); + auto int32_arr_WS = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(1)); + EXPECT_EQ(int32_arr->null_count(), 0); + // seed 0, null decimal + EXPECT_EQ(int32_arr_WS->Value(0), 0); + // null decimal => hash = seed + EXPECT_EQ(int32_arr_WS->Value(1), int32_arr->Value(1)); + // seed = 0 => hash = hash without seed + EXPECT_EQ(int32_arr_WS->Value(2), int32_arr->Value(1)); + // different inputs => different outputs + EXPECT_NE(int32_arr_WS->Value(3), int32_arr_WS->Value(4)); + // hash with, without seed are not equal + EXPECT_NE(int32_arr_WS->Value(4), int32_arr->Value(4)); +} + +TEST_F(TestDecimal, TestHash64WithSeed) { + constexpr int32_t precision = 38; + constexpr int32_t scale = 2; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_dec_1 = field("dec1", decimal_type); + auto field_dec_2 = field("dec2", decimal_type); + auto schema = arrow::schema({field_dec_1, field_dec_2}); + + auto res = field("hash64_with_seed", arrow::int64()); + + auto 
field_1_nodePtr = TreeExprBuilder::MakeField(field_dec_1); + auto field_2_nodePtr = TreeExprBuilder::MakeField(field_dec_2); + + auto hash64 = + TreeExprBuilder::MakeFunction("hash64", {field_2_nodePtr}, arrow::int64()); + auto hash64_with_seed = + TreeExprBuilder::MakeFunction("hash64", {field_1_nodePtr, hash64}, arrow::int64()); + auto expr = TreeExprBuilder::MakeExpression(hash64, field("hash64", arrow::int64())); + auto exprWS = TreeExprBuilder::MakeExpression(hash64_with_seed, res); + + auto exprs = std::vector<ExpressionPtr>{expr, exprWS}; + + // Build a projector for the expression. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 5; + auto validity_1 = {false, false, true, true, true}; + + auto array_dec_1 = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale), + validity_1); + + auto validity_2 = {false, true, false, true, true}; + + auto array_dec_2 = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"1.51", "1.23", "1.23", "-1.23", "-1.24"}, scale), + validity_2); + + // prepare input record batch + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array_dec_1, array_dec_2}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + auto int64_arr = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(0)); + auto int64_arr_WS = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(1)); + EXPECT_EQ(int64_arr->null_count(), 0); + // seed 0, null decimal + EXPECT_EQ(int64_arr_WS->Value(0), 0); + // null decimal => hash = seed + EXPECT_EQ(int64_arr_WS->Value(1), int64_arr->Value(1)); + // seed = 0 => hash = hash without seed + EXPECT_EQ(int64_arr_WS->Value(2), int64_arr->Value(1)); + // different inputs => 
different outputs + EXPECT_NE(int64_arr_WS->Value(3), int64_arr_WS->Value(4)); + // hash with, without seed are not equal + EXPECT_NE(int64_arr_WS->Value(4), int64_arr->Value(4)); +} + +TEST_F(TestDecimal, TestNullDecimalConstant) { + // schema for input fields + constexpr int32_t precision = 36; + constexpr int32_t scale = 18; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field_b = field("b", decimal_type); + auto field_c = field("c", arrow::boolean()); + auto schema = arrow::schema({field_b, field_c}); + + // output fields + auto field_result = field("res", decimal_type); + + // build expression. + // if (c) + // null + // else + // b + auto node_a = TreeExprBuilder::MakeNull(decimal_type); + auto node_b = TreeExprBuilder::MakeField(field_b); + auto node_c = TreeExprBuilder::MakeField(field_c); + auto if_node = TreeExprBuilder::MakeIf(node_c, node_a, node_b, decimal_type); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + Status status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + + auto array_b = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"2", "3", "4", "5"}, scale), + {true, true, true, true}); + + auto array_c = MakeArrowArrayBool({true, false, true, false}, {true, true, true, true}); + + // expected output + auto exp = + MakeArrowArrayDecimal(decimal_type, MakeDecimalVector({"0", "3", "3", "5"}, scale), + {false, true, false, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_b, array_c}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestDecimal, TestCastVarCharDecimal) { + // schema for input fields + constexpr int32_t precision = 38; + constexpr int32_t scale = 2; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + + auto field_dec = field("dec", decimal_type); + auto field_res_str = field("res_str", utf8()); + auto field_res_str_1 = field("res_str_1", utf8()); + auto schema = arrow::schema({field_dec, field_res_str, field_res_str_1}); + + // output fields + auto res_str = field("res_str", utf8()); + auto equals_res_bool = field("equals_res", boolean()); + + // build expressions. 
+ auto node_dec = TreeExprBuilder::MakeField(field_dec); + auto node_res_str = TreeExprBuilder::MakeField(field_res_str); + auto node_res_str_1 = TreeExprBuilder::MakeField(field_res_str_1); + // limits decimal string to input length + auto str_len_limit = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(5)); + auto str_len_limit_1 = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(1)); + auto cast_varchar = + TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit}, utf8()); + auto cast_varchar_1 = + TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit_1}, utf8()); + auto equals = + TreeExprBuilder::MakeFunction("equal", {cast_varchar, node_res_str}, boolean()); + auto equals_1 = + TreeExprBuilder::MakeFunction("equal", {cast_varchar_1, node_res_str_1}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(equals, equals_res_bool); + auto expr_1 = TreeExprBuilder::MakeExpression(equals_1, equals_res_bool); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + + auto status = Projector::Make(schema, {expr, expr_1}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + auto array_dec = MakeArrowArrayDecimal( + decimal_type, + MakeDecimalVector({"10.51", "1.23", "100.23", "-1000.23", "-0000.10"}, scale), + {true, false, true, true, true}); + auto array_str_res = MakeArrowArrayUtf8({"10.51", "-null-", "100.2", "-1000", "-0.10"}, + {true, false, true, true, true}); + auto array_str_res_1 = + MakeArrowArrayUtf8({"1", "-null-", "1", "-", "-"}, {true, false, true, true, true}); + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, + {array_dec, array_str_res, array_str_res_1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + auto exp = MakeArrowArrayBool({true, false, true, true, true}, + {true, false, true, true, true}); + auto exp_1 = MakeArrowArrayBool({true, false, true, true, true}, + {true, false, true, true, true}); + // Validate results: each output is compared against its own expected array + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]); + EXPECT_ARROW_ARRAY_EQUALS(exp_1, outputs[1]); +} + +TEST_F(TestDecimal, TestCastDecimalVarChar) { + // schema for input fields + constexpr int32_t precision = 4; + constexpr int32_t scale = 2; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + + auto field_str = field("in_str", utf8()); + auto schema = arrow::schema({field_str}); + + // output fields + auto res_dec = field("res_dec", decimal_type); + + // build expressions. + auto node_str = TreeExprBuilder::MakeField(field_str); + auto cast_decimal = + TreeExprBuilder::MakeFunction("castDECIMAL", {node_str}, decimal_type); + auto expr = TreeExprBuilder::MakeExpression(cast_decimal, res_dec); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + + auto array_str = MakeArrowArrayUtf8({"10.5134", "-0.0", "-0.1", "10.516", "-1000"}, + {true, false, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_str}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + auto array_dec = MakeArrowArrayDecimal( + decimal_type, MakeDecimalVector({"10.51", "1.23", "-0.10", "10.52", "0.00"}, scale), + {true, false, true, true, true}); + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(array_dec, outputs[0]); +} + +TEST_F(TestDecimal, TestCastDecimalVarCharInvalidInput) { + // schema for input fields + constexpr int32_t precision = 38; + constexpr int32_t scale = 0; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + + auto field_str = field("in_str", utf8()); + auto schema = arrow::schema({field_str}); + + // output fields + auto res_dec = field("res_dec", decimal_type); + + // build expressions. + auto node_str = TreeExprBuilder::MakeField(field_str); + auto cast_decimal = + TreeExprBuilder::MakeFunction("castDECIMAL", {node_str}, decimal_type); + auto expr = TreeExprBuilder::MakeExpression(cast_decimal, res_dec); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + + // invalid input + auto invalid_in = MakeArrowArrayUtf8({"a10.5134", "-0.0", "-0.1", "10.516", "-1000"}, + {true, false, true, true, true}); + + // prepare input record batch + auto in_batch_1 = arrow::RecordBatch::Make(schema, num_records, {invalid_in}); + + // Evaluate expression + arrow::ArrayVector outputs_1; + status = projector->Evaluate(*in_batch_1, pool_, &outputs_1); + EXPECT_FALSE(status.ok()) << status.message(); + EXPECT_NE(status.message().find("not a valid decimal128 number"), std::string::npos); +} + +TEST_F(TestDecimal, TestVarCharDecimalNestedCast) { + // schema for input fields + constexpr int32_t precision = 38; + constexpr int32_t scale = 2; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + + auto field_dec = field("dec", decimal_type); + auto schema = arrow::schema({field_dec}); + + // output fields + auto field_dec_res = field("dec_res", decimal_type); + + // build expressions. + auto node_dec = TreeExprBuilder::MakeField(field_dec); + + // limits decimal string to input length + auto str_len_limit = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(5)); + auto cast_varchar = + TreeExprBuilder::MakeFunction("castVARCHAR", {node_dec, str_len_limit}, utf8()); + auto cast_decimal = + TreeExprBuilder::MakeFunction("castDECIMAL", {cast_varchar}, decimal_type); + + auto expr = TreeExprBuilder::MakeExpression(cast_decimal, field_dec_res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + auto array_dec = MakeArrowArrayDecimal( + decimal_type, + MakeDecimalVector({"10.51", "1.23", "100.23", "-1000.23", "-0000.10"}, scale), + {true, false, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + auto array_dec_res = MakeArrowArrayDecimal( + decimal_type, + MakeDecimalVector({"10.51", "1.23", "100.20", "-1000.00", "-0.10"}, scale), + {true, false, true, true, true}); + EXPECT_ARROW_ARRAY_EQUALS(array_dec_res, outputs[0]); +} + +TEST_F(TestDecimal, TestCastDecimalOverflow) { + // schema for input fields + constexpr int32_t precision_in = 5; + constexpr int32_t scale_in = 2; + constexpr int32_t precision_out = 3; + constexpr int32_t scale_out = 1; + auto decimal_5_2 = std::make_shared<arrow::Decimal128Type>(precision_in, scale_in); + auto decimal_3_1 = std::make_shared<arrow::Decimal128Type>(precision_out, scale_out); + + auto field_dec = field("dec", decimal_5_2); + auto schema = arrow::schema({field_dec}); + + // build expressions + auto exprs = std::vector<ExpressionPtr>{ + TreeExprBuilder::MakeExpression("castDECIMAL", {field_dec}, + field("dec_to_dec", decimal_3_1)), + TreeExprBuilder::MakeExpression("castDECIMALNullOnOverflow", {field_dec}, + field("dec_to_dec_null_overflow", decimal_3_1)), + }; + + // Build a projector for the expression. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + auto validity = {true, true, true, true}; + + auto array_dec = MakeArrowArrayDecimal( + decimal_5_2, MakeDecimalVector({"1.23", "671.58", "-1.23", "-1.58"}, scale_in), + validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + DCHECK_OK(status); + + // Validate results + // castDECIMAL(decimal) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(arrow::decimal(precision_out, 1), + MakeDecimalVector({"1.2", "0.0", "-1.2", "-1.6"}, 1), + validity), + outputs[0]); + + // castDECIMALNullOnOverflow(decimal) + EXPECT_ARROW_ARRAY_EQUALS( + MakeArrowArrayDecimal(arrow::decimal(precision_out, 1), + MakeDecimalVector({"1.2", "1.6", "-1.2", "-1.6"}, 1), + {true, false, true, true}), + outputs[1]); +} + +TEST_F(TestDecimal, TestSha) { + // schema for input fields + const std::shared_ptr<arrow::DataType>& decimal_5_2 = arrow::decimal128(5, 2); + auto field_a = field("a", decimal_5_2); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_0 = field("res0", utf8()); + auto res_1 = field("res1", utf8()); + + // build expressions. + // hashSHA1(a) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto hashSha1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_a}, utf8()); + auto expr_0 = TreeExprBuilder::MakeExpression(hashSha1, res_0); + + auto hashSha256 = TreeExprBuilder::MakeFunction("hashSHA256", {node_a}, utf8()); + auto expr_1 = TreeExprBuilder::MakeExpression(hashSha256, res_1); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = + Projector::Make(schema, {expr_0, expr_1}, TestConfiguration(), &projector); + ASSERT_OK(status) << status.message(); + + // Create a row-batch with some sample data + int num_records = 3; + auto validity_array = {false, true, true}; + + auto array_dec = MakeArrowArrayDecimal( + decimal_5_2, MakeDecimalVector({"3.45", "0", "0.01"}, 2), validity_array); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_dec}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + ASSERT_OK(status); + + auto response = outputs.at(0); + EXPECT_EQ(response->null_count(), 0); + EXPECT_NE(response->GetScalar(0).ValueOrDie()->ToString(), ""); + + // Checks if the hash size in response is correct + const int sha1_hash_size = 40; + for (int i = 1; i < num_records; ++i) { + const auto& value_at_position = response->GetScalar(i).ValueOrDie()->ToString(); + + EXPECT_EQ(value_at_position.size(), sha1_hash_size); + EXPECT_NE(value_at_position, response->GetScalar(i - 1).ValueOrDie()->ToString()); + } + + response = outputs.at(1); + EXPECT_EQ(response->null_count(), 0); + EXPECT_NE(response->GetScalar(0).ValueOrDie()->ToString(), ""); + + // Checks if the hash size in response is correct + const int sha256_hash_size = 64; + for (int i = 1; i < num_records; ++i) { + const auto& value_at_position = response->GetScalar(i).ValueOrDie()->ToString(); + + EXPECT_EQ(value_at_position.size(), sha256_hash_size); + EXPECT_NE(value_at_position, response->GetScalar(i - 1).ValueOrDie()->ToString()); + } +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/filter_project_test.cc b/src/arrow/cpp/src/gandiva/tests/filter_project_test.cc new file mode 100644 index 000000000..0607feaef --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/filter_project_test.cc @@ -0,0 +1,276 @@ +// Licensed to the Apache Software Foundation 
(ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "gandiva/filter.h"
+#include "gandiva/projector.h"
+#include "gandiva/selection_vector.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::int32;
+
+// Fixture for tests that run a Filter and then feed the resulting
+// SelectionVector into a Projector. Every setup step uses ASSERT_OK because
+// the following statement dereferences the object that step produced.
+class TestFilterProject : public ::testing::Test {
+ public:
+  void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+  arrow::MemoryPool* pool_;
+};
+
+// Filter on f0 < f1 with a 16-bit selection vector, then project f1 + f2 over
+// the selection. The expected output has three rows: 5 + 1, 9 + 2, and a null
+// (f2 is null in the last input row).
+TEST_F(TestFilterProject, TestSimple16) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto field2 = field("f2", int32());
+  auto resultField = field("result", int32());
+  auto schema = arrow::schema({field0, field1, field2});
+
+  // Build condition f0 < f1
+  auto node_f0 = TreeExprBuilder::MakeField(field0);
+  auto node_f1 = TreeExprBuilder::MakeField(field1);
+  auto less_than_function =
+      TreeExprBuilder::MakeFunction("less_than", {node_f0, node_f1}, arrow::boolean());
+  auto condition = TreeExprBuilder::MakeCondition(less_than_function);
+
+  // Build projection f1 + f2
+  auto sum_expr = TreeExprBuilder::MakeExpression("add", {field1, field2}, resultField);
+
+  auto configuration = TestConfiguration();
+
+  std::shared_ptr<Filter> filter;
+  std::shared_ptr<Projector> projector;
+
+  auto status = Filter::Make(schema, condition, configuration, &filter);
+  ASSERT_OK(status);
+
+  status = Projector::Make(schema, {sum_expr}, SelectionVector::MODE_UINT16,
+                           configuration, &projector);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array0 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, true});
+  auto array1 = MakeArrowArrayInt32({5, 9, 3, 17, 6}, {true, true, true, true, true});
+  auto array2 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, false});
+  // expected output
+  auto result = MakeArrowArrayInt32({6, 11, 0}, {true, true, false});
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate filter
+  status = filter->Evaluate(*in_batch, selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate projection over the selected rows
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(result, outputs.at(0));
+}
+
+// Same scenario as TestSimple16, using a 32-bit selection vector.
+TEST_F(TestFilterProject, TestSimple32) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto field2 = field("f2", int32());
+  auto resultField = field("result", int32());
+  auto schema = arrow::schema({field0, field1, field2});
+
+  // Build condition f0 < f1
+  auto node_f0 = TreeExprBuilder::MakeField(field0);
+  auto node_f1 = TreeExprBuilder::MakeField(field1);
+  auto less_than_function =
+      TreeExprBuilder::MakeFunction("less_than", {node_f0, node_f1}, arrow::boolean());
+  auto condition = TreeExprBuilder::MakeCondition(less_than_function);
+
+  // Build projection f1 + f2
+  auto sum_expr = TreeExprBuilder::MakeExpression("add", {field1, field2}, resultField);
+
+  auto configuration = TestConfiguration();
+
+  std::shared_ptr<Filter> filter;
+  std::shared_ptr<Projector> projector;
+
+  auto status = Filter::Make(schema, condition, configuration, &filter);
+  ASSERT_OK(status);
+
+  status = Projector::Make(schema, {sum_expr}, SelectionVector::MODE_UINT32,
+                           configuration, &projector);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array0 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, true});
+  auto array1 = MakeArrowArrayInt32({5, 9, 3, 17, 6}, {true, true, true, true, true});
+  auto array2 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, false});
+  // expected output
+  auto result = MakeArrowArrayInt32({6, 11, 0}, {true, true, false});
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt32(num_records, pool_, &selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate filter
+  status = filter->Evaluate(*in_batch, selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate projection over the selected rows
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(result, outputs.at(0));
+}
+
+// Same scenario as TestSimple16, using a 64-bit selection vector.
+TEST_F(TestFilterProject, TestSimple64) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto field2 = field("f2", int32());
+  auto resultField = field("result", int32());
+  auto schema = arrow::schema({field0, field1, field2});
+
+  // Build condition f0 < f1
+  auto node_f0 = TreeExprBuilder::MakeField(field0);
+  auto node_f1 = TreeExprBuilder::MakeField(field1);
+  auto less_than_function =
+      TreeExprBuilder::MakeFunction("less_than", {node_f0, node_f1}, arrow::boolean());
+  auto condition = TreeExprBuilder::MakeCondition(less_than_function);
+
+  // Build projection f1 + f2
+  auto sum_expr = TreeExprBuilder::MakeExpression("add", {field1, field2}, resultField);
+
+  auto configuration = TestConfiguration();
+
+  std::shared_ptr<Filter> filter;
+  std::shared_ptr<Projector> projector;
+
+  auto status = Filter::Make(schema, condition, configuration, &filter);
+  ASSERT_OK(status);
+
+  status = Projector::Make(schema, {sum_expr}, SelectionVector::MODE_UINT64,
+                           configuration, &projector);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array0 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, true});
+  auto array1 = MakeArrowArrayInt32({5, 9, 3, 17, 6}, {true, true, true, true, true});
+  auto array2 = MakeArrowArrayInt32({1, 2, 6, 40, 3}, {true, true, true, true, false});
+  // expected output
+  auto result = MakeArrowArrayInt32({6, 11, 0}, {true, true, false});
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt64(num_records, pool_, &selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate filter
+  status = filter->Evaluate(*in_batch, selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate projection over the selected rows
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(result, outputs.at(0));
+}
+
+// Filter on a > b, then project "if (b < c) b else c" over the selection.
+// The expected output has three rows: 1, -21, and a null (c is null in the
+// last input row).
+TEST_F(TestFilterProject, TestSimpleIf) {
+  // schema for input fields
+  auto fielda = field("a", int32());
+  auto fieldb = field("b", int32());
+  auto fieldc = field("c", int32());
+  auto schema = arrow::schema({fielda, fieldb, fieldc});
+
+  // output fields
+  auto field_result = field("res", int32());
+
+  auto node_a = TreeExprBuilder::MakeField(fielda);
+  auto node_b = TreeExprBuilder::MakeField(fieldb);
+  auto node_c = TreeExprBuilder::MakeField(fieldc);
+
+  // Filter condition: a > b
+  auto greater_than_function =
+      TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean());
+  auto filter_condition = TreeExprBuilder::MakeCondition(greater_than_function);
+
+  // Projection: if (b < c) b else c
+  auto project_condition =
+      TreeExprBuilder::MakeFunction("less_than", {node_b, node_c}, boolean());
+  auto if_node = TreeExprBuilder::MakeIf(project_condition, node_b, node_c, int32());
+
+  auto expr = TreeExprBuilder::MakeExpression(if_node, field_result);
+  auto configuration = TestConfiguration();
+
+  // Build a filter for the expressions.
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, filter_condition, configuration, &filter);
+  ASSERT_OK(status);
+
+  // Build a projector for the expressions.
+  std::shared_ptr<Projector> projector;
+  status = Projector::Make(schema, {expr}, SelectionVector::MODE_UINT32, configuration,
+                           &projector);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 6;
+  auto array0 =
+      MakeArrowArrayInt32({10, 12, -20, 5, 21, 29}, {true, true, true, true, true, true});
+  auto array1 =
+      MakeArrowArrayInt32({5, 15, 15, 17, 12, 3}, {true, true, true, true, true, true});
+  auto array2 = MakeArrowArrayInt32({1, 25, 11, 30, -21, 30},
+                                    {true, true, true, true, true, false});
+
+  // Create a selection vector
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt32(num_records, pool_, &selection_vector);
+  ASSERT_OK(status);
+
+  // expected output
+  auto exp = MakeArrowArrayInt32({1, -21, 0}, {true, true, false});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2});
+
+  // Evaluate filter
+  status = filter->Evaluate(*in_batch, selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate project
+  arrow::ArrayVector outputs;
+  status = projector->Evaluate(*in_batch, selection_vector.get(), pool_, &outputs);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
+}
+}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/filter_test.cc b/src/arrow/cpp/src/gandiva/tests/filter_test.cc
new file mode 100644
index 000000000..d4433f11e
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/filter_test.cc
@@ -0,0 +1,340 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "gandiva/filter.h"
+#include <gtest/gtest.h>
+#include "arrow/memory_pool.h"
+#include "gandiva/tests/test_util.h"
+#include "gandiva/tree_expr_builder.h"
+
+namespace gandiva {
+
+using arrow::boolean;
+using arrow::float32;
+using arrow::int32;
+
+// Fixture for Filter tests. Setup steps whose result is dereferenced by the
+// next statement use ASSERT_OK so a failure stops the test instead of
+// crashing it.
+// NOTE(review): ASSERT_OK is assumed to be provided through
+// gandiva/tests/test_util.h - confirm.
+class TestFilter : public ::testing::Test {
+ public:
+  void SetUp() { pool_ = arrow::default_memory_pool(); }
+
+ protected:
+  arrow::MemoryPool* pool_;
+};
+
+// Filter::Make must hand back the same Filter object for an identical
+// (schema, condition, configuration) triple and a different object when either
+// the schema or the condition changes.
+TEST_F(TestFilter, TestFilterCache) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto schema = arrow::schema({field0, field1});
+
+  // Build condition f0 + f1 < 10
+  auto node_f0 = TreeExprBuilder::MakeField(field0);
+  auto node_f1 = TreeExprBuilder::MakeField(field1);
+  auto sum_func =
+      TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+  auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+  auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
+                                                    arrow::boolean());
+  auto condition = TreeExprBuilder::MakeCondition(less_than_10);
+  auto configuration = TestConfiguration();
+
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, condition, configuration, &filter);
+  EXPECT_TRUE(status.ok());
+
+  // same schema and condition, should return the same filter as above.
+  std::shared_ptr<Filter> cached_filter;
+  status = Filter::Make(schema, condition, configuration, &cached_filter);
+  EXPECT_TRUE(status.ok());
+  EXPECT_TRUE(cached_filter.get() == filter.get());
+
+  // schema is different should return a new filter.
+  auto field2 = field("f2", int32());
+  auto different_schema = arrow::schema({field0, field1, field2});
+  std::shared_ptr<Filter> should_be_new_filter;
+  status =
+      Filter::Make(different_schema, condition, configuration, &should_be_new_filter);
+  EXPECT_TRUE(status.ok());
+  EXPECT_TRUE(cached_filter.get() != should_be_new_filter.get());
+
+  // condition is different, should return a new filter.
+  auto greater_than_10 = TreeExprBuilder::MakeFunction(
+      "greater_than", {sum_func, literal_10}, arrow::boolean());
+  auto new_condition = TreeExprBuilder::MakeCondition(greater_than_10);
+  std::shared_ptr<Filter> should_be_new_filter1;
+  status = Filter::Make(schema, new_condition, configuration, &should_be_new_filter1);
+  EXPECT_TRUE(status.ok());
+  EXPECT_TRUE(cached_filter.get() != should_be_new_filter1.get());
+}
+
+// Filter on f0 + f1 < 10 with a 16-bit selection vector. Only rows 0 and 4
+// have both inputs valid and a sum below 10.
+TEST_F(TestFilter, TestSimple) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto schema = arrow::schema({field0, field1});
+
+  // Build condition f0 + f1 < 10
+  auto node_f0 = TreeExprBuilder::MakeField(field0);
+  auto node_f1 = TreeExprBuilder::MakeField(field1);
+  auto sum_func =
+      TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+  auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+  auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
+                                                    arrow::boolean());
+  auto condition = TreeExprBuilder::MakeCondition(less_than_10);
+
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 6}, {true, true, true, false, true});
+  auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true});
+  // expected output (indices for which condition matches)
+  auto exp = MakeArrowArrayUint16({0, 4});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate expression
+  status = filter->Evaluate(*in_batch, selection_vector);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+// Filter on f0 != f1 using a Configuration built explicitly through
+// ConfigurationBuilder rather than the shared TestConfiguration().
+TEST_F(TestFilter, TestSimpleCustomConfig) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto schema = arrow::schema({field0, field1});
+
+  // Build condition f0 != f1
+  auto condition = TreeExprBuilder::MakeCondition("not_equal", {field0, field1});
+
+  ConfigurationBuilder config_builder;
+  std::shared_ptr<Configuration> config = config_builder.build();
+
+  // Pass the custom configuration built above; it used to be ignored in
+  // favour of TestConfiguration().
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, condition, config, &filter);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false});
+  auto array1 = MakeArrowArrayInt32({11, 2, 3, 17}, {true, true, false, true});
+  // expected output
+  auto exp = MakeArrowArrayUint16({0});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate expression
+  status = filter->Evaluate(*in_batch, selection_vector);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+// Filter on isnotnull(f0) writing into a selection vector that wraps a
+// caller-owned buffer instead of one allocated from the memory pool.
+TEST_F(TestFilter, TestZeroCopy) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto schema = arrow::schema({field0});
+
+  // Build condition
+  auto condition = TreeExprBuilder::MakeCondition("isnotnull", {field0});
+
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false});
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+  // expected output
+  auto exp = MakeArrowArrayUint16({0, 1, 2});
+
+  // allocate selection buffers
+  int64_t data_sz = sizeof(int16_t) * num_records;
+  std::unique_ptr<uint8_t[]> data(new uint8_t[data_sz]);
+  std::shared_ptr<arrow::MutableBuffer> data_buf =
+      std::make_shared<arrow::MutableBuffer>(data.get(), data_sz);
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt16(num_records, data_buf, &selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate expression
+  status = filter->Evaluate(*in_batch, selection_vector);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+// Filter::Evaluate must return StatusCode::Invalid for an empty batch, a null
+// selection vector, and a selection vector smaller than the batch.
+TEST_F(TestFilter, TestZeroCopyNegative) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto schema = arrow::schema({field0});
+
+  // Build expression
+  auto condition = TreeExprBuilder::MakeCondition("isnotnull", {field0});
+
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 4;
+  auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false});
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
+
+  // allocate output buffers
+  int64_t data_sz = sizeof(int16_t) * num_records;
+  std::unique_ptr<uint8_t[]> data(new uint8_t[data_sz]);
+  std::shared_ptr<arrow::MutableBuffer> data_buf =
+      std::make_shared<arrow::MutableBuffer>(data.get(), data_sz);
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt16(num_records, data_buf, &selection_vector);
+  EXPECT_TRUE(status.ok());
+
+  // the batch can't be empty.
+  auto bad_batch = arrow::RecordBatch::Make(schema, 0 /*num_records*/, {array0});
+  status = filter->Evaluate(*bad_batch, selection_vector);
+  EXPECT_EQ(status.code(), StatusCode::Invalid);
+
+  // the selection_vector can't be null.
+  std::shared_ptr<SelectionVector> null_selection;
+  status = filter->Evaluate(*in_batch, null_selection);
+  EXPECT_EQ(status.code(), StatusCode::Invalid);
+
+  // the selection vector must be suitably sized.
+  std::shared_ptr<SelectionVector> bad_selection;
+  status = SelectionVector::MakeInt16(num_records - 1, data_buf, &bad_selection);
+  EXPECT_TRUE(status.ok());
+
+  status = filter->Evaluate(*in_batch, bad_selection);
+  EXPECT_EQ(status.code(), StatusCode::Invalid);
+}
+
+// Same scenario as TestSimple, using a 32-bit selection vector.
+TEST_F(TestFilter, TestSimpleSVInt32) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto schema = arrow::schema({field0, field1});
+
+  // Build condition f0 + f1 < 10
+  auto node_f0 = TreeExprBuilder::MakeField(field0);
+  auto node_f1 = TreeExprBuilder::MakeField(field1);
+  auto sum_func =
+      TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+  auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+  auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
+                                                    arrow::boolean());
+  auto condition = TreeExprBuilder::MakeCondition(less_than_10);
+
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 6}, {true, true, true, false, true});
+  auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true});
+  // expected output (indices for which condition matches)
+  auto exp = MakeArrowArrayUint32({0, 4});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt32(num_records, pool_, &selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate expression
+  status = filter->Evaluate(*in_batch, selection_vector);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+// Filter on f0 + f1 < 10 where both the first column and the record batch are
+// sliced by one element, so the inputs start at a non-zero offset. The
+// expected index (3) is relative to the sliced batch.
+TEST_F(TestFilter, TestOffset) {
+  // schema for input fields
+  auto field0 = field("f0", int32());
+  auto field1 = field("f1", int32());
+  auto schema = arrow::schema({field0, field1});
+
+  // Build condition f0 + f1 < 10
+  auto node_f0 = TreeExprBuilder::MakeField(field0);
+  auto node_f1 = TreeExprBuilder::MakeField(field1);
+  auto sum_func =
+      TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32());
+  auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10);
+  auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10},
+                                                    arrow::boolean());
+  auto condition = TreeExprBuilder::MakeCondition(less_than_10);
+
+  std::shared_ptr<Filter> filter;
+  auto status = Filter::Make(schema, condition, TestConfiguration(), &filter);
+  ASSERT_OK(status);
+
+  // Create a row-batch with some sample data
+  int num_records = 5;
+  auto array0 =
+      MakeArrowArrayInt32({0, 1, 2, 3, 4, 6}, {true, true, true, true, false, true});
+  // Drop the first element so the column has num_records entries at offset 1.
+  array0 = array0->Slice(1);
+  auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true});
+  // expected output (indices for which condition matches)
+  auto exp = MakeArrowArrayUint16({3});
+
+  // prepare input record batch
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1});
+  in_batch = in_batch->Slice(1);
+
+  std::shared_ptr<SelectionVector> selection_vector;
+  status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector);
+  ASSERT_OK(status);
+
+  // Evaluate expression
+  status = filter->Evaluate(*in_batch, selection_vector);
+  ASSERT_OK(status);
+
+  // Validate results
+  EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray());
+}
+
+}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/generate_data.h b/src/arrow/cpp/src/gandiva/tests/generate_data.h
new file mode 100644
index 000000000..9fb0e4eae
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/generate_data.h
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <stdlib.h>
+
+#include <cstdint>
+#include <limits>
+#include <random>
+#include <string>
+
+#include "arrow/util/decimal.h"
+#include "arrow/util/io_util.h"
+
+namespace gandiva {
+
+// Interface for test-data generators: each call to GenerateData() yields one
+// value of type C_TYPE.
+template <typename C_TYPE>
+class DataGenerator {
+ public:
+  virtual ~DataGenerator() = default;
+
+  virtual C_TYPE GenerateData() = 0;
+};
+
+// Thin wrapper around std::default_random_engine. The default constructor
+// seeds from arrow::internal::GetRandomSeed(); pass an explicit seed for
+// reproducible sequences.
+class Random {
+ public:
+  Random() : gen_(::arrow::internal::GetRandomSeed()) {}
+  explicit Random(uint64_t seed) : gen_(seed) {}
+
+  // Returns the engine's next output converted to int32_t.
+  // NOTE(review): the engine's result type is unsigned and its range is
+  // implementation-defined, so whether negative values can appear depends on
+  // the standard library in use - confirm before relying on the sign.
+  int32_t next() { return gen_(); }
+
+ private:
+  std::default_random_engine gen_;
+};
+
+// Generates unbounded int32 values straight from Random::next().
+class Int32DataGenerator : public DataGenerator<int32_t> {
+ public:
+  Int32DataGenerator() {}
+
+  int32_t GenerateData() override { return random_.next(); }
+
+ protected:
+  Random random_;
+};
+
+// Generates int32 values reduced modulo `upperBound`.
+class BoundedInt32DataGenerator : public Int32DataGenerator {
+ public:
+  explicit BoundedInt32DataGenerator(uint32_t upperBound)
+      : Int32DataGenerator(), upperBound_(upperBound) {}
+
+  int32_t GenerateData() override {
+    // A zero bound would make the modulo below undefined behaviour; the only
+    // sensible value for an empty range is 0.
+    if (upperBound_ == 0) {
+      return 0;
+    }
+    // The draw is converted to unsigned before the modulo (usual arithmetic
+    // conversions), so the remainder lies in [0, upperBound_).
+    int32_t value = (random_.next() % upperBound_);
+    return value;
+  }
+
+ protected:
+  uint32_t upperBound_;
+};
+
+// Generates int64 values. Each value is a single Random::next() draw widened
+// to int64_t, so it never exceeds the int32_t range.
+class Int64DataGenerator : public DataGenerator<int64_t> {
+ public:
+  Int64DataGenerator() {}
+
+  int64_t GenerateData() override { return random_.next(); }
+
+ protected:
+  Random random_;
+};
+
+// Generates Decimal128 values from two Random::next() draws (high and low
+// words). When `large` is set, 2^62 is added to the high word to push the
+// magnitude up.
+class Decimal128DataGenerator : public DataGenerator<arrow::Decimal128> {
+ public:
+  explicit Decimal128DataGenerator(bool large) : large_(large) {}
+
+  arrow::Decimal128 GenerateData() override {
+    uint64_t low = random_.next();
+    int64_t high = random_.next();
+    if (large_) {
+      high += (1ull << 62);
+    }
+    return arrow::Decimal128(high, low);
+  }
+
+ protected:
+  bool large_;
+  Random random_;
+};
+
+// Generates lowercase ASCII strings whose length is a random draw reduced
+// modulo `max_len`. Characters cycle through 'a'..'z' across calls instead of
+// being drawn randomly, which keeps generation cheap.
+class FastUtf8DataGenerator : public DataGenerator<std::string> {
+ public:
+  explicit FastUtf8DataGenerator(int max_len) : max_len_(max_len), cur_char_('a') {}
+
+  std::string GenerateData() override {
+    std::string generated_str;
+
+    // A zero maximum length would make the modulo below undefined behaviour;
+    // the only string that fits is the empty one.
+    if (max_len_ == 0) {
+      return generated_str;
+    }
+
+    int slen = random_.next() % max_len_;
+    for (int i = 0; i < slen; ++i) {
+      generated_str += generate_next_char();
+    }
+    return generated_str;
+  }
+
+ private:
+  // Advances the cursor first and then returns it, wrapping from 'z' to 'a'.
+  char generate_next_char() {
+    ++cur_char_;
+    if (cur_char_ > 'z') {
+      cur_char_ = 'a';
+    }
+    return cur_char_;
+  }
+
+  Random random_;
+  unsigned int max_len_;
+  char cur_char_;
+};
+
+// Generates the decimal string form of a Random::next() draw.
+class Utf8IntDataGenerator : public DataGenerator<std::string> {
+ public:
+  Utf8IntDataGenerator() {}
+
+  std::string GenerateData() override { return std::to_string(random_.next()); }
+
+ private:
+  Random random_;
+};
+
+// Generates the string form of a float obtained by scaling a Random::next()
+// draw down to roughly the 0.0 - 100.0 range.
+class Utf8FloatDataGenerator : public DataGenerator<std::string> {
+ public:
+  Utf8FloatDataGenerator() {}
+
+  std::string GenerateData() override {
+    // Scale by the range of the value actually drawn (an int32_t) rather than
+    // by RAND_MAX, which describes rand() and may be as small as 32767. Where
+    // RAND_MAX equals the int32_t maximum the result is unchanged.
+    constexpr float kScale =
+        static_cast<float>(std::numeric_limits<int32_t>::max() / 100);
+    return std::to_string(static_cast<float>(random_.next()) / kScale);
+  }
+
+ private:
+  Random random_;
+};
+
+}  // namespace gandiva
diff --git a/src/arrow/cpp/src/gandiva/tests/hash_test.cc b/src/arrow/cpp/src/gandiva/tests/hash_test.cc
new file mode 100644
index 000000000..40ebc50a2
--- /dev/null
+++ b/src/arrow/cpp/src/gandiva/tests/hash_test.cc
@@ -0,0 +1,615 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#include <gtest/gtest.h> + +#include <sstream> + +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::float64; +using arrow::int32; +using arrow::int64; +using arrow::utf8; + +class TestHash : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestHash, TestSimple) { + // schema for input fields + auto field_a = field("a", int32()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_0 = field("res0", int32()); + auto res_1 = field("res1", int64()); + + // build expression. + // hash32(a, 10) + // hash64(a) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10); + auto hash32 = TreeExprBuilder::MakeFunction("hash32", {node_a, literal_10}, int32()); + auto hash64 = TreeExprBuilder::MakeFunction("hash64", {node_a}, int64()); + auto expr_0 = TreeExprBuilder::MakeExpression(hash32, res_0); + auto expr_1 = TreeExprBuilder::MakeExpression(hash64, res_1); + + // Build a projector for the expression. 
+ std::shared_ptr<Projector> projector; + auto status = + Projector::Make(schema, {expr_0, expr_1}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayInt32({1, 2, 3, 4}, {false, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + auto int32_arr = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(0)); + EXPECT_EQ(int32_arr->null_count(), 0); + EXPECT_EQ(int32_arr->Value(0), 10); + for (int i = 1; i < num_records; ++i) { + EXPECT_NE(int32_arr->Value(i), int32_arr->Value(i - 1)); + } + + auto int64_arr = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(1)); + EXPECT_EQ(int64_arr->null_count(), 0); + EXPECT_EQ(int64_arr->Value(0), 0); + for (int i = 1; i < num_records; ++i) { + EXPECT_NE(int64_arr->Value(i), int64_arr->Value(i - 1)); + } +} + +TEST_F(TestHash, TestBuf) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_0 = field("res0", int32()); + auto res_1 = field("res1", int64()); + + // build expressions. + // hash32(a) + // hash64(a, 10) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_10 = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(10)); + auto hash32 = TreeExprBuilder::MakeFunction("hash32", {node_a}, int32()); + auto hash64 = TreeExprBuilder::MakeFunction("hash64", {node_a, literal_10}, int64()); + auto expr_0 = TreeExprBuilder::MakeExpression(hash32, res_0); + auto expr_1 = TreeExprBuilder::MakeExpression(hash64, res_1); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = + Projector::Make(schema, {expr_0, expr_1}, TestConfiguration(), &projector); + ASSERT_OK(status) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {false, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + ASSERT_OK(status); + + // Validate results + auto int32_arr = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(0)); + EXPECT_EQ(int32_arr->null_count(), 0); + EXPECT_EQ(int32_arr->Value(0), 0); + for (int i = 1; i < num_records; ++i) { + EXPECT_NE(int32_arr->Value(i), int32_arr->Value(i - 1)); + } + + auto int64_arr = std::dynamic_pointer_cast<arrow::Int64Array>(outputs.at(1)); + EXPECT_EQ(int64_arr->null_count(), 0); + EXPECT_EQ(int64_arr->Value(0), 10); + for (int i = 1; i < num_records; ++i) { + EXPECT_NE(int64_arr->Value(i), int64_arr->Value(i - 1)); + } +} + +TEST_F(TestHash, TestSha256Simple) { + // schema for input fields + auto field_a = field("a", int32()); + auto field_b = field("b", int64()); + auto field_c = field("c", float32()); + auto field_d = field("d", float64()); + auto schema = arrow::schema({field_a, field_b, field_c, field_d}); + + // output fields + auto res_0 = field("res0", utf8()); + auto res_1 = field("res1", utf8()); + auto res_2 = field("res2", utf8()); + auto res_3 = field("res3", utf8()); + + // build expressions. 
+ // hashSHA256(a) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto hashSha256_1 = TreeExprBuilder::MakeFunction("hashSHA256", {node_a}, utf8()); + auto expr_0 = TreeExprBuilder::MakeExpression(hashSha256_1, res_0); + + auto node_b = TreeExprBuilder::MakeField(field_b); + auto hashSha256_2 = TreeExprBuilder::MakeFunction("hashSHA256", {node_b}, utf8()); + auto expr_1 = TreeExprBuilder::MakeExpression(hashSha256_2, res_1); + + auto node_c = TreeExprBuilder::MakeField(field_c); + auto hashSha256_3 = TreeExprBuilder::MakeFunction("hashSHA256", {node_c}, utf8()); + auto expr_2 = TreeExprBuilder::MakeExpression(hashSha256_3, res_2); + + auto node_d = TreeExprBuilder::MakeField(field_d); + auto hashSha256_4 = TreeExprBuilder::MakeFunction("hashSHA256", {node_d}, utf8()); + auto expr_3 = TreeExprBuilder::MakeExpression(hashSha256_4, res_3); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr_0, expr_1, expr_2, expr_3}, + TestConfiguration(), &projector); + ASSERT_OK(status) << status.message(); + + // Create a row-batch with some sample data + int num_records = 2; + auto validity_array = {false, true}; + + auto array_int32 = MakeArrowArrayInt32({1, 0}, validity_array); + + auto array_int64 = MakeArrowArrayInt64({1, 0}, validity_array); + + auto array_float32 = MakeArrowArrayFloat32({1.0, 0.0}, validity_array); + + auto array_float64 = MakeArrowArrayFloat64({1.0, 0.0}, validity_array); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make( + schema, num_records, {array_int32, array_int64, array_float32, array_float64}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + ASSERT_OK(status); + + auto response_int32 = outputs.at(0); + auto response_int64 = outputs.at(1); + auto response_float32 = outputs.at(2); + auto response_float64 = outputs.at(3); + + // Checks if the null and zero 
representation for numeric values + // are consistent between the types + EXPECT_ARROW_ARRAY_EQUALS(response_int32, response_int64); + EXPECT_ARROW_ARRAY_EQUALS(response_int64, response_float32); + EXPECT_ARROW_ARRAY_EQUALS(response_float32, response_float64); + + const int sha256_hash_size = 64; + + // Checks if the hash size in response is correct + for (int i = 1; i < num_records; ++i) { + const auto& value_at_position = response_int32->GetScalar(i).ValueOrDie()->ToString(); + + EXPECT_EQ(value_at_position.size(), sha256_hash_size); + EXPECT_NE(value_at_position, + response_int32->GetScalar(i - 1).ValueOrDie()->ToString()); + } +} + +TEST_F(TestHash, TestSha256Varlen) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_0 = field("res0", utf8()); + + // build expressions. + // hashSHA256(a) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto hashSha256 = TreeExprBuilder::MakeFunction("hashSHA256", {node_a}, utf8()); + auto expr_0 = TreeExprBuilder::MakeExpression(hashSha256, res_0); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr_0}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 3; + + std::string first_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY " + "[ˈʏpsilɔn], Yen [jɛn], Yoga [ˈjoːgɑ]"; + std::string second_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY " + "[ˈʏpsilɔn], Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ"; + + auto array_a = + MakeArrowArrayUtf8({"foo", first_string, second_string}, {false, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + ASSERT_OK(status); + + auto response = outputs.at(0); + const int sha256_hash_size = 64; + + EXPECT_EQ(response->null_count(), 0); + + // Checks that the null value was hashed + EXPECT_NE(response->GetScalar(0).ValueOrDie()->ToString(), ""); + EXPECT_EQ(response->GetScalar(0).ValueOrDie()->ToString().size(), sha256_hash_size); + + // Check that all generated hashes were different + for (int i = 1; i < num_records; ++i) { + const auto& value_at_position = response->GetScalar(i).ValueOrDie()->ToString(); + + EXPECT_EQ(value_at_position.size(), sha256_hash_size); + EXPECT_NE(value_at_position, response->GetScalar(i - 1).ValueOrDie()->ToString()); + } +} + +TEST_F(TestHash, TestSha1Simple) { + // schema for input fields + auto field_a = field("a", int32()); + auto field_b = field("b", int64()); + auto field_c = field("c", float32()); + auto field_d = field("d", float64()); + auto schema = arrow::schema({field_a, field_b, field_c, field_d}); + + // output fields + auto res_0 = field("res0", utf8()); + auto res_1 = field("res1", utf8()); + auto res_2 = field("res2", utf8()); + auto res_3 = field("res3", utf8()); + + // build expressions. 
+ // hashSHA1(a) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto hashSha1_1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_a}, utf8()); + auto expr_0 = TreeExprBuilder::MakeExpression(hashSha1_1, res_0); + + auto node_b = TreeExprBuilder::MakeField(field_b); + auto hashSha1_2 = TreeExprBuilder::MakeFunction("hashSHA1", {node_b}, utf8()); + auto expr_1 = TreeExprBuilder::MakeExpression(hashSha1_2, res_1); + + auto node_c = TreeExprBuilder::MakeField(field_c); + auto hashSha1_3 = TreeExprBuilder::MakeFunction("hashSHA1", {node_c}, utf8()); + auto expr_2 = TreeExprBuilder::MakeExpression(hashSha1_3, res_2); + + auto node_d = TreeExprBuilder::MakeField(field_d); + auto hashSha1_4 = TreeExprBuilder::MakeFunction("hashSHA1", {node_d}, utf8()); + auto expr_3 = TreeExprBuilder::MakeExpression(hashSha1_4, res_3); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr_0, expr_1, expr_2, expr_3}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 2; + auto validity_array = {false, true}; + + auto array_int32 = MakeArrowArrayInt32({1, 0}, validity_array); + + auto array_int64 = MakeArrowArrayInt64({1, 0}, validity_array); + + auto array_float32 = MakeArrowArrayFloat32({1.0, 0.0}, validity_array); + + auto array_float64 = MakeArrowArrayFloat64({1.0, 0.0}, validity_array); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make( + schema, num_records, {array_int32, array_int64, array_float32, array_float64}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + ASSERT_OK(status); + + auto response_int32 = outputs.at(0); + auto response_int64 = outputs.at(1); + auto response_float32 = outputs.at(2); + auto response_float64 = outputs.at(3); + + // Checks if the null and zero representation for numeric 
values + // are consistent between the types + EXPECT_ARROW_ARRAY_EQUALS(response_int32, response_int64); + EXPECT_ARROW_ARRAY_EQUALS(response_int64, response_float32); + EXPECT_ARROW_ARRAY_EQUALS(response_float32, response_float64); + + const int sha1_hash_size = 40; + + // Checks if the hash size in response is correct + for (int i = 1; i < num_records; ++i) { + const auto& value_at_position = response_int32->GetScalar(i).ValueOrDie()->ToString(); + + EXPECT_EQ(value_at_position.size(), sha1_hash_size); + EXPECT_NE(value_at_position, + response_int32->GetScalar(i - 1).ValueOrDie()->ToString()); + } +} + +TEST_F(TestHash, TestSha1Varlen) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_0 = field("res0", utf8()); + + // build expressions. + // hashSHA1(a) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto hashSha1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_a}, utf8()); + auto expr_0 = TreeExprBuilder::MakeExpression(hashSha1, res_0); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr_0}, TestConfiguration(), &projector); + ASSERT_OK(status) << status.message(); + + // Create a row-batch with some sample data + int num_records = 3; + + std::string first_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ]"; + std::string second_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ"; + + auto array_a = + MakeArrowArrayUtf8({"", first_string, second_string}, {false, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + ASSERT_OK(status); + + auto response = outputs.at(0); + const int sha1_hash_size = 40; + + EXPECT_EQ(response->null_count(), 0); + + // Checks that the null value was hashed + EXPECT_NE(response->GetScalar(0).ValueOrDie()->ToString(), ""); + EXPECT_EQ(response->GetScalar(0).ValueOrDie()->ToString().size(), sha1_hash_size); + + // Check that all generated hashes were different + for (int i = 1; i < num_records; ++i) { + const auto& value_at_position = response->GetScalar(i).ValueOrDie()->ToString(); + + EXPECT_EQ(value_at_position.size(), sha1_hash_size); + EXPECT_NE(value_at_position, response->GetScalar(i - 1).ValueOrDie()->ToString()); + } +} + +TEST_F(TestHash, TestSha1FunctionsAlias) { + // schema for input fields + auto field_a = field("a", utf8()); + auto field_b = field("c", int64()); + auto field_c = field("e", float64()); + auto schema = arrow::schema({field_a, field_b, field_c}); + + // output fields + auto res_0 = field("res0", utf8()); + auto res_0_sha1 = field("res0sha1", utf8()); + auto res_0_sha = field("res0sha", utf8()); + + auto res_1 = field("res1", utf8()); + auto res_1_sha1 = field("res1sha1", utf8()); + auto res_1_sha = field("res1sha", utf8()); + + auto 
res_2 = field("res2", utf8()); + auto res_2_sha1 = field("res2_sha1", utf8()); + auto res_2_sha = field("res2_sha", utf8()); + + // build expressions. + // hashSHA1(a) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto hashSha1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_a}, utf8()); + auto expr_0 = TreeExprBuilder::MakeExpression(hashSha1, res_0); + auto sha1 = TreeExprBuilder::MakeFunction("sha1", {node_a}, utf8()); + auto expr_0_sha1 = TreeExprBuilder::MakeExpression(sha1, res_0_sha1); + auto sha = TreeExprBuilder::MakeFunction("sha", {node_a}, utf8()); + auto expr_0_sha = TreeExprBuilder::MakeExpression(sha, res_0_sha); + + auto node_b = TreeExprBuilder::MakeField(field_b); + auto hashSha1_1 = TreeExprBuilder::MakeFunction("hashSHA1", {node_b}, utf8()); + auto expr_1 = TreeExprBuilder::MakeExpression(hashSha1_1, res_1); + auto sha1_1 = TreeExprBuilder::MakeFunction("sha1", {node_b}, utf8()); + auto expr_1_sha1 = TreeExprBuilder::MakeExpression(sha1_1, res_1_sha1); + auto sha_1 = TreeExprBuilder::MakeFunction("sha", {node_b}, utf8()); + auto expr_1_sha = TreeExprBuilder::MakeExpression(sha_1, res_1_sha); + + auto node_c = TreeExprBuilder::MakeField(field_c); + auto hashSha1_2 = TreeExprBuilder::MakeFunction("hashSHA1", {node_c}, utf8()); + auto expr_2 = TreeExprBuilder::MakeExpression(hashSha1_2, res_2); + auto sha1_2 = TreeExprBuilder::MakeFunction("sha1", {node_c}, utf8()); + auto expr_2_sha1 = TreeExprBuilder::MakeExpression(sha1_2, res_2_sha1); + auto sha_2 = TreeExprBuilder::MakeFunction("sha", {node_c}, utf8()); + auto expr_2_sha = TreeExprBuilder::MakeExpression(sha_2, res_2_sha); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, + {expr_0, expr_0_sha, expr_0_sha1, expr_1, expr_1_sha, + expr_1_sha1, expr_2, expr_2_sha, expr_2_sha1}, + TestConfiguration(), &projector); + ASSERT_OK(status) << status.message(); + + // Create a row-batch with some sample data + int32_t num_records = 3; + + std::string first_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ]"; + std::string second_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ"; + + auto array_utf8 = + MakeArrowArrayUtf8({"", first_string, second_string}, {false, true, true}); + + auto validity_array = {false, true, true}; + + auto array_int64 = MakeArrowArrayInt64({1, 0, 32423}, validity_array); + + auto array_float64 = MakeArrowArrayFloat64({1.0, 0.0, 324893.3849}, validity_array); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, + {array_utf8, array_int64, array_float64}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + ASSERT_OK(status); + + // Checks that the response for the hashSHA1, sha and sha1 are equals for the first + // field of utf8 type + EXPECT_ARROW_ARRAY_EQUALS(outputs.at(0), outputs.at(1)); // hashSha1 and sha + EXPECT_ARROW_ARRAY_EQUALS(outputs.at(1), outputs.at(2)); // sha and sha1 + + // Checks that the response for the hashSHA1, sha and sha1 are equals for the second + // field of int64 type + EXPECT_ARROW_ARRAY_EQUALS(outputs.at(3), outputs.at(4)); // hashSha1 and sha + EXPECT_ARROW_ARRAY_EQUALS(outputs.at(4), outputs.at(5)); // sha and sha1 + + // Checks that the response for the hashSHA1, sha and sha1 are equals for the first + // field of float64 type + EXPECT_ARROW_ARRAY_EQUALS(outputs.at(6), outputs.at(7)); // hashSha1 and sha responses + EXPECT_ARROW_ARRAY_EQUALS(outputs.at(7), outputs.at(8)); // sha and sha1 responses +} + 
+TEST_F(TestHash, TestSha256FunctionsAlias) { + // schema for input fields + auto field_a = field("a", utf8()); + auto field_b = field("c", int64()); + auto field_c = field("e", float64()); + auto schema = arrow::schema({field_a, field_b, field_c}); + + // output fields + auto res_0 = field("res0", utf8()); + auto res_0_sha256 = field("res0sha256", utf8()); + + auto res_1 = field("res1", utf8()); + auto res_1_sha256 = field("res1sha256", utf8()); + + auto res_2 = field("res2", utf8()); + auto res_2_sha256 = field("res2_sha256", utf8()); + + // build expressions. + // hashSHA1(a) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto hashSha2 = TreeExprBuilder::MakeFunction("hashSHA256", {node_a}, utf8()); + auto expr_0 = TreeExprBuilder::MakeExpression(hashSha2, res_0); + auto sha256 = TreeExprBuilder::MakeFunction("sha256", {node_a}, utf8()); + auto expr_0_sha256 = TreeExprBuilder::MakeExpression(sha256, res_0_sha256); + + auto node_b = TreeExprBuilder::MakeField(field_b); + auto hashSha2_1 = TreeExprBuilder::MakeFunction("hashSHA256", {node_b}, utf8()); + auto expr_1 = TreeExprBuilder::MakeExpression(hashSha2_1, res_1); + auto sha256_1 = TreeExprBuilder::MakeFunction("sha256", {node_b}, utf8()); + auto expr_1_sha256 = TreeExprBuilder::MakeExpression(sha256_1, res_1_sha256); + + auto node_c = TreeExprBuilder::MakeField(field_c); + auto hashSha2_2 = TreeExprBuilder::MakeFunction("hashSHA256", {node_c}, utf8()); + auto expr_2 = TreeExprBuilder::MakeExpression(hashSha2_2, res_2); + auto sha256_2 = TreeExprBuilder::MakeFunction("sha256", {node_c}, utf8()); + auto expr_2_sha256 = TreeExprBuilder::MakeExpression(sha256_2, res_2_sha256); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make( + schema, {expr_0, expr_0_sha256, expr_1, expr_1_sha256, expr_2, expr_2_sha256}, + TestConfiguration(), &projector); + ASSERT_OK(status) << status.message(); + + // Create a row-batch with some sample data + int32_t num_records = 3; + + std::string first_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeıʃn\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ]"; + std::string second_string = + "ði ıntəˈnæʃənəl fəˈnɛtık əsoʊsiˈeın\nY [ˈʏpsilɔn], " + "Yen [jɛn], Yoga [ˈjoːgɑ] コンニチハ"; + + auto array_utf8 = + MakeArrowArrayUtf8({"", first_string, second_string}, {false, true, true}); + + auto validity_array = {false, true, true}; + + auto array_int64 = MakeArrowArrayInt64({1, 0, 32423}, validity_array); + + auto array_float64 = MakeArrowArrayFloat64({1.0, 0.0, 324893.3849}, validity_array); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, + {array_utf8, array_int64, array_float64}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + ASSERT_OK(status); + + // Checks that the response for the hashSHA2, sha256 and sha2 are equals for the first + // field of utf8 type + EXPECT_ARROW_ARRAY_EQUALS(outputs.at(0), outputs.at(1)); // hashSha2 and sha256 + + // Checks that the response for the hashSHA2, sha256 and sha2 are equals for the second + // field of int64 type + EXPECT_ARROW_ARRAY_EQUALS(outputs.at(2), outputs.at(3)); // hashSha2 and sha256 + + // Checks that the response for the hashSHA2, sha256 and sha2 are equals for the first + // field of float64 type + EXPECT_ARROW_ARRAY_EQUALS(outputs.at(4), + outputs.at(5)); // hashSha2 and sha256 responses +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/huge_table_test.cc b/src/arrow/cpp/src/gandiva/tests/huge_table_test.cc new file mode 100644 index 000000000..46f814b47 --- /dev/null +++ 
b/src/arrow/cpp/src/gandiva/tests/huge_table_test.cc @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "gandiva/filter.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::int32; + +class LARGE_MEMORY_TEST(TestHugeProjector) : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +class LARGE_MEMORY_TEST(TestHugeFilter) : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(LARGE_MEMORY_TEST(TestHugeProjector), SimpleTestSumHuge) { + auto atype = arrow::TypeTraits<arrow::Int32Type>::type_singleton(); + + // schema for input fields + auto field0 = field("f0", atype); + auto field1 = field("f1", atype); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_sum = field("add", atype); + + // Build expression + auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, 
field_sum); + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {sum_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + // Cause an overflow in int32_t + int64_t num_records = static_cast<int64_t>(INT32_MAX) + 3; + std::vector<int32_t> input0 = {2, 29, 5, 37, 11, 59, 17, 19}; + std::vector<int32_t> input1 = {23, 3, 31, 7, 41, 47, 13}; + std::vector<bool> validity; + + std::vector<int32_t> arr1; + std::vector<int32_t> arr2; + // expected output + std::vector<int32_t> sum1; + + for (int64_t i = 0; i < num_records; i++) { + arr1.push_back(input0[i % 8]); + arr2.push_back(input1[i % 7]); + sum1.push_back(input0[i % 8] + input1[i % 7]); + validity.push_back(true); + } + + auto exp_sum = MakeArrowArray<arrow::Int32Type, int32_t>(sum1, validity); + auto array0 = MakeArrowArray<arrow::Int32Type, int32_t>(arr1, validity); + auto array1 = MakeArrowArray<arrow::Int32Type, int32_t>(arr2, validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0)); +} + +TEST_F(LARGE_MEMORY_TEST(TestHugeFilter), TestSimpleHugeFilter) { + // Create a row-batch with some sample data + // Cause an overflow in int32_t + int64_t num_records = static_cast<int64_t>(INT32_MAX) + 3; + std::vector<int32_t> input0 = {2, 29, 5, 37, 11, 59, 17, 19}; + std::vector<int32_t> input1 = {23, 3, 31, 7, 41, 47, 13}; + std::vector<bool> validity; + + std::vector<int32_t> arr1; + std::vector<int32_t> arr2; + // expected output + std::vector<uint64_t> sel; + + for (int64_t i = 0; i < num_records; i++) { + arr1.push_back(input0[i % 8]); + arr2.push_back(input1[i % 7]); + if (input0[i % 8] + input1[i % 7] > 50) { + sel.push_back(i); + } + 
validity.push_back(true); + } + + auto exp = MakeArrowArrayUint64(sel); + + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto schema = arrow::schema({field0, field1}); + + // Build condition f0 + f1 < 50 + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto node_f1 = TreeExprBuilder::MakeField(field1); + auto sum_func = + TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32()); + auto literal_50 = TreeExprBuilder::MakeLiteral((int32_t)50); + auto less_than_50 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_50}, + arrow::boolean()); + auto condition = TreeExprBuilder::MakeCondition(less_than_50); + + std::shared_ptr<Filter> filter; + auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); + EXPECT_TRUE(status.ok()); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arr1, arr2}); + + std::shared_ptr<SelectionVector> selection_vector; + status = SelectionVector::MakeInt64(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/if_expr_test.cc b/src/arrow/cpp/src/gandiva/tests/if_expr_test.cc new file mode 100644 index 000000000..54b6d43b4 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/if_expr_test.cc @@ -0,0 +1,378 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "arrow/status.h" + +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::int32; + +class TestIfExpr : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestIfExpr, TestSimple) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + // else + // b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestSimpleArithmetic) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + b + // else + // a - b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32()); + auto if_node = TreeExprBuilder::MakeIf(condition, sum, sub, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({15, -3, -35, 0}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestNested) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. 
+ // if (a > b) + // a + b + // else if (a < b) + // a - b + // else + // a * b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition_gt = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto condition_lt = + TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32()); + auto mult = TreeExprBuilder::MakeFunction("multiply", {node_a, node_b}, int32()); + auto else_node = TreeExprBuilder::MakeIf(condition_lt, sub, mult, int32()); + auto if_node = TreeExprBuilder::MakeIf(condition_gt, sum, else_node, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 12, 15, 5}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({15, -3, 225, 0}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestNestedInIf) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", int32()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto 
field_result = field("res", int32()); + + // build expression. + // if (a > 10) + // if (a < 20) + // a + b + // else + // b + c + // else + // a + c + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + + auto literal_10 = TreeExprBuilder::MakeLiteral(10); + auto literal_20 = TreeExprBuilder::MakeLiteral(20); + + auto gt_10 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_10}, boolean()); + auto lt_20 = + TreeExprBuilder::MakeFunction("less_than", {node_a, literal_20}, boolean()); + auto sum_ab = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto sum_bc = TreeExprBuilder::MakeFunction("add", {node_b, node_c}, int32()); + auto sum_ac = TreeExprBuilder::MakeFunction("add", {node_a, node_c}, int32()); + + auto if_lt_20 = TreeExprBuilder::MakeIf(lt_20, sum_ab, sum_bc, int32()); + auto if_gt_10 = TreeExprBuilder::MakeIf(gt_10, if_lt_20, sum_ac, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_gt_10, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 6; + auto array_a = + MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true}); + auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19}, + {true, true, true, false, false, false}); + auto array_c = MakeArrowArrayInt32({35, 45, 55, 35, 45, 55}, + {true, true, true, false, false, false}); + + // expected output + auto exp = + MakeArrowArrayInt32({55, 33, 60, 0, 0, 0}, {true, true, true, false, false, false}); + + // prepare input record batch + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestNestedInCondition) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. 
+ // if (if (a > b) then true else if (a < b) false else null) + // 1 + // else if !(if (a > b) then true else if (a < b) false else null) + // 2 + // else + // 3 + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_1 = TreeExprBuilder::MakeLiteral(1); + auto literal_2 = TreeExprBuilder::MakeLiteral(2); + auto literal_3 = TreeExprBuilder::MakeLiteral(3); + auto literal_true = TreeExprBuilder::MakeLiteral(true); + auto literal_false = TreeExprBuilder::MakeLiteral(false); + auto literal_null = TreeExprBuilder::MakeNull(boolean()); + + auto a_gt_b = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto a_lt_b = TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto cond_else = + TreeExprBuilder::MakeIf(a_lt_b, literal_false, literal_null, boolean()); + auto cond_if = TreeExprBuilder::MakeIf(a_gt_b, literal_true, cond_else, boolean()); + auto not_cond_if = TreeExprBuilder::MakeFunction("not", {cond_if}, boolean()); + + auto outer_else = TreeExprBuilder::MakeIf(not_cond_if, literal_2, literal_3, int32()); + auto outer_if = TreeExprBuilder::MakeIf(cond_if, literal_1, outer_else, int32()); + auto expr = TreeExprBuilder::MakeExpression(outer_if, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 6; + auto array_a = + MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true}); + auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19}, + {true, true, true, false, false, false}); + // expected output + auto exp = + MakeArrowArrayInt32({1, 2, 2, 3, 3, 3}, {true, true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestBigNested) { + // schema for input fields + auto fielda = field("a", int32()); + auto schema = arrow::schema({fielda}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a < 10) + // 10 + // else if (a < 20) + // 20 + // .. + // .. + // else if (a < 190) + // 190 + // else + // 200 + auto node_a = TreeExprBuilder::MakeField(fielda); + auto top_node = TreeExprBuilder::MakeLiteral(200); + for (int thresh = 190; thresh > 0; thresh -= 10) { + auto literal = TreeExprBuilder::MakeLiteral(thresh); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32()); + top_node = if_node; + } + auto expr = TreeExprBuilder::MakeExpression(top_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 102, 158, 302}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({20, 110, 160, 200}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/in_expr_test.cc b/src/arrow/cpp/src/gandiva/tests/in_expr_test.cc new file mode 100644 index 000000000..fc1a8a71b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/in_expr_test.cc @@ -0,0 +1,278 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <gtest/gtest.h> +#include <cmath> + +#include "arrow/memory_pool.h" +#include "gandiva/filter.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::float64; +using arrow::int32; + +class TestIn : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; +std::vector<Decimal128> MakeDecimalVector(std::vector<std::string> values) { + std::vector<arrow::Decimal128> ret; + for (auto str : values) { + Decimal128 decimal_value; + int32_t decimal_precision; + int32_t decimal_scale; + + DCHECK_OK( + Decimal128::FromString(str, &decimal_value, &decimal_precision, &decimal_scale)); + + ret.push_back(decimal_value); + } + return ret; +} + +TEST_F(TestIn, TestInSimple) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto schema = arrow::schema({field0, field1}); + + // Build In f0 + f1 in (6, 11) + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto node_f1 = TreeExprBuilder::MakeField(field1); + auto sum_func = + TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32()); + std::unordered_set<int32_t> in_constants({6, 11}); + auto in_expr = TreeExprBuilder::MakeInExpressionInt32(sum_func, in_constants); + auto condition = TreeExprBuilder::MakeCondition(in_expr); + + std::shared_ptr<Filter> filter; + auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 5; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 6}, {true, true, true, false, true}); + auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 5}, {true, true, false, true, false}); + // expected output (indices for which condition matches) + auto exp = MakeArrowArrayUint16({0, 1}); + + // prepare input record batch + auto in_batch = 
arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + std::shared_ptr<SelectionVector> selection_vector; + status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +TEST_F(TestIn, TestInFloat) { + // schema for input fields + auto field0 = field("f0", float32()); + auto schema = arrow::schema({field0}); + + // Build In f0 + f1 in (6, 11) + auto node_f0 = TreeExprBuilder::MakeField(field0); + + std::unordered_set<float> in_constants({6.5f, 12.0f, 11.5f}); + auto in_expr = TreeExprBuilder::MakeInExpressionFloat(node_f0, in_constants); + auto condition = TreeExprBuilder::MakeCondition(in_expr); + + std::shared_ptr<Filter> filter; + auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 5; + auto array0 = + MakeArrowArrayFloat32({6.5f, 11.5f, 4, 3.15f, 6}, {true, true, false, true, true}); + // expected output (indices for which condition matches) + auto exp = MakeArrowArrayUint16({0, 1}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + std::shared_ptr<SelectionVector> selection_vector; + status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +TEST_F(TestIn, TestInDouble) { + // schema for input fields + auto field0 = field("double0", float64()); + auto field1 = field("double1", float64()); + auto schema = arrow::schema({field0, field1}); + + auto node_f0 = 
TreeExprBuilder::MakeField(field0); + auto node_f1 = TreeExprBuilder::MakeField(field1); + auto sum_func = + TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::float64()); + std::unordered_set<double> in_constants({3.14159265359, 15.5555555}); + auto in_expr = TreeExprBuilder::MakeInExpressionDouble(sum_func, in_constants); + auto condition = TreeExprBuilder::MakeCondition(in_expr); + + std::shared_ptr<Filter> filter; + auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 5; + auto array0 = MakeArrowArrayFloat64({1, 2, 3, 4, 11}, {true, true, true, false, false}); + auto array1 = MakeArrowArrayFloat64({5, 9, 0.14159265359, 17, 4.5555555}, + {true, true, true, true, true}); + + // expected output (indices for which condition matches) + auto exp = MakeArrowArrayUint16({2}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + std::shared_ptr<SelectionVector> selection_vector; + status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +TEST_F(TestIn, TestInDecimal) { + int32_t precision = 38; + int32_t scale = 5; + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + + // schema for input fields + auto field0 = field("f0", arrow::decimal(precision, scale)); + auto schema = arrow::schema({field0}); + + // Build In f0 + f1 in (6, 11) + auto node_f0 = TreeExprBuilder::MakeField(field0); + + gandiva::DecimalScalar128 d0("6", precision, scale); + gandiva::DecimalScalar128 d1("12", precision, scale); + gandiva::DecimalScalar128 d2("11", precision, scale); + std::unordered_set<gandiva::DecimalScalar128> 
in_constants({d0, d1, d2}); + auto in_expr = TreeExprBuilder::MakeInExpressionDecimal(node_f0, in_constants); + auto condition = TreeExprBuilder::MakeCondition(in_expr); + + std::shared_ptr<Filter> filter; + auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 5; + auto values0 = MakeDecimalVector({"1", "2", "0", "-6", "6"}); + auto array0 = + MakeArrowArrayDecimal(decimal_type, values0, {true, true, true, false, true}); + // expected output (indices for which condition matches) + auto exp = MakeArrowArrayUint16({4}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + std::shared_ptr<SelectionVector> selection_vector; + status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +TEST_F(TestIn, TestInString) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto schema = arrow::schema({field0}); + + // Build f0 in ("test" ,"me") + auto node_f0 = TreeExprBuilder::MakeField(field0); + std::unordered_set<std::string> in_constants({"test", "me"}); + auto in_expr = TreeExprBuilder::MakeInExpressionString(node_f0, in_constants); + + auto condition = TreeExprBuilder::MakeCondition(in_expr); + + std::shared_ptr<Filter> filter; + auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 5; + auto array_a = MakeArrowArrayUtf8({"test", "lol", "me", "arrow", "test"}, + {true, true, true, true, false}); + // expected output (indices for which condition matches) + auto exp = MakeArrowArrayUint16({0, 2}); + + // 
prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + std::shared_ptr<SelectionVector> selection_vector; + status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +TEST_F(TestIn, TestInStringValidationError) { + // schema for input fields + auto field0 = field("f0", arrow::int32()); + auto schema = arrow::schema({field0}); + + // Build f0 in ("test" ,"me") + auto node_f0 = TreeExprBuilder::MakeField(field0); + std::unordered_set<std::string> in_constants({"test", "me"}); + auto in_expr = TreeExprBuilder::MakeInExpressionString(node_f0, in_constants); + auto condition = TreeExprBuilder::MakeCondition(in_expr); + + std::shared_ptr<Filter> filter; + auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); + + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Evaluation expression for IN clause returns "; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/literal_test.cc b/src/arrow/cpp/src/gandiva/tests/literal_test.cc new file mode 100644 index 000000000..b5ffff031 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/literal_test.cc @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "arrow/status.h" + +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::float64; +using arrow::int32; +using arrow::int64; + +class TestLiteral : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestLiteral, TestSimpleArithmetic) { + // schema for input fields + auto field_a = field("a", boolean()); + auto field_b = field("b", int32()); + auto field_c = field("c", int64()); + auto field_d = field("d", float32()); + auto field_e = field("e", float64()); + auto schema = arrow::schema({field_a, field_b, field_c, field_d, field_e}); + + // output fields + auto res_a = field("a+1", boolean()); + auto res_b = field("b+1", int32()); + auto res_c = field("c+1", int64()); + auto res_d = field("d+1", float32()); + auto res_e = field("e+1", float64()); + + // build expressions. 
+ // a == true + // b + 1 + // c + 1 + // d + 1 + // e + 1 + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_a = TreeExprBuilder::MakeLiteral(true); + auto func_a = TreeExprBuilder::MakeFunction("equal", {node_a, literal_a}, boolean()); + auto expr_a = TreeExprBuilder::MakeExpression(func_a, res_a); + + auto node_b = TreeExprBuilder::MakeField(field_b); + auto literal_b = TreeExprBuilder::MakeLiteral((int32_t)1); + auto func_b = TreeExprBuilder::MakeFunction("add", {node_b, literal_b}, int32()); + auto expr_b = TreeExprBuilder::MakeExpression(func_b, res_b); + + auto node_c = TreeExprBuilder::MakeField(field_c); + auto literal_c = TreeExprBuilder::MakeLiteral((int64_t)1); + auto func_c = TreeExprBuilder::MakeFunction("add", {node_c, literal_c}, int64()); + auto expr_c = TreeExprBuilder::MakeExpression(func_c, res_c); + + auto node_d = TreeExprBuilder::MakeField(field_d); + auto literal_d = TreeExprBuilder::MakeLiteral(static_cast<float>(1)); + auto func_d = TreeExprBuilder::MakeFunction("add", {node_d, literal_d}, float32()); + auto expr_d = TreeExprBuilder::MakeExpression(func_d, res_d); + + auto node_e = TreeExprBuilder::MakeField(field_e); + auto literal_e = TreeExprBuilder::MakeLiteral(static_cast<double>(1)); + auto func_e = TreeExprBuilder::MakeFunction("add", {node_e, literal_e}, float64()); + auto expr_e = TreeExprBuilder::MakeExpression(func_e, res_e); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr_a, expr_b, expr_c, expr_d, expr_e}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayBool({true, true, false, true}, {true, true, true, false}); + auto array_b = MakeArrowArrayInt32({5, 15, -15, 17}, {true, true, true, false}); + auto array_c = MakeArrowArrayInt64({5, 15, -15, 17}, {true, true, true, false}); + auto array_d = MakeArrowArrayFloat32({5.2f, 15, -15.6f, 17}, {true, true, true, false}); + auto array_e = MakeArrowArrayFloat64({5.6f, 15, -15.9f, 17}, {true, true, true, false}); + + // expected output + auto exp_a = MakeArrowArrayBool({true, true, false, false}, {true, true, true, false}); + auto exp_b = MakeArrowArrayInt32({6, 16, -14, 0}, {true, true, true, false}); + auto exp_c = MakeArrowArrayInt64({6, 16, -14, 0}, {true, true, true, false}); + auto exp_d = MakeArrowArrayFloat32({6.2f, 16, -14.6f, 0}, {true, true, true, false}); + auto exp_e = MakeArrowArrayFloat64({6.6f, 16, -14.9f, 0}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, + {array_a, array_b, array_c, array_d, array_e}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_a, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_b, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(exp_c, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(exp_d, outputs.at(3)); + EXPECT_ARROW_ARRAY_EQUALS(exp_e, outputs.at(4)); +} + +TEST_F(TestLiteral, TestLiteralHash) { + auto schema = arrow::schema({}); + // output fields + auto res = field("a", int32()); + auto int_literal = TreeExprBuilder::MakeLiteral((int32_t)2); + auto expr = TreeExprBuilder::MakeExpression(int_literal, res); + + // Build a 
projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + auto res1 = field("a", int64()); + auto int_literal1 = TreeExprBuilder::MakeLiteral((int64_t)2); + auto expr1 = TreeExprBuilder::MakeExpression(int_literal1, res1); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector1; + status = Projector::Make(schema, {expr1}, TestConfiguration(), &projector1); + EXPECT_TRUE(status.ok()) << status.message(); + EXPECT_TRUE(projector.get() != projector1.get()); +} + +TEST_F(TestLiteral, TestNullLiteral) { + // schema for input fields + auto field_a = field("a", int32()); + auto field_b = field("b", int32()); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("a+b+null", int32()); + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_b = TreeExprBuilder::MakeField(field_b); + auto literal_c = TreeExprBuilder::MakeNull(arrow::int32()); + auto add_a_b = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto add_a_b_c = TreeExprBuilder::MakeFunction("add", {add_a_b, literal_c}, int32()); + auto expr = TreeExprBuilder::MakeExpression(add_a_b_c, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayInt32({5, 15, -15, 17}, {true, true, true, false}); + auto array_b = MakeArrowArrayInt32({5, 15, -15, 17}, {true, true, true, false}); + + // expected output + auto exp = MakeArrowArrayInt32({0, 0, 0, 0}, {false, false, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestLiteral, TestNullLiteralInIf) { + // schema for input fields + auto field_a = field("a", float64()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", float64()); + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_5 = TreeExprBuilder::MakeLiteral(5.0); + auto a_gt_5 = TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_5}, + arrow::boolean()); + auto literal_null = TreeExprBuilder::MakeNull(arrow::float64()); + auto if_node = + TreeExprBuilder::MakeIf(a_gt_5, literal_5, literal_null, arrow::float64()); + auto expr = TreeExprBuilder::MakeExpression(if_node, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayFloat64({6, 15, -15, 17}, {true, true, true, false}); + + // expected output + auto exp = MakeArrowArrayFloat64({5, 5, 0, 0}, {true, true, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/micro_benchmarks.cc b/src/arrow/cpp/src/gandiva/tests/micro_benchmarks.cc new file mode 100644 index 000000000..35c77e3dd --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/micro_benchmarks.cc @@ -0,0 +1,456 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <stdlib.h> +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "benchmark/benchmark.h" +#include "gandiva/decimal_type_util.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tests/timed_evaluate.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::int32; +using arrow::int64; +using arrow::utf8; + +static void TimedTestAdd3(benchmark::State& state) { + // schema for input fields + auto field0 = field("f0", int64()); + auto field1 = field("f1", int64()); + auto field2 = field("f2", int64()); + auto schema = arrow::schema({field0, field1, field2}); + auto pool_ = arrow::default_memory_pool(); + + // output field + auto field_sum = field("add", int64()); + + // Build expression + auto part_sum = TreeExprBuilder::MakeFunction( + "add", {TreeExprBuilder::MakeField(field1), TreeExprBuilder::MakeField(field2)}, + int64()); + auto sum = TreeExprBuilder::MakeFunction( + "add", {TreeExprBuilder::MakeField(field0), part_sum}, int64()); + + auto sum_expr = TreeExprBuilder::MakeExpression(sum, field_sum); + + std::shared_ptr<Projector> projector; + ASSERT_OK(Projector::Make(schema, {sum_expr}, TestConfiguration(), &projector)); + + Int64DataGenerator data_generator; + ProjectEvaluator evaluator(projector); + + Status status = TimedEvaluate<arrow::Int64Type, int64_t>( + schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state); + ASSERT_OK(status); +} + +static void TimedTestBigNested(benchmark::State& state) { + // schema for input fields + auto fielda = field("a", int32()); + auto schema = arrow::schema({fielda}); + auto pool_ = arrow::default_memory_pool(); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a < 10) + // 10 + // else if (a < 20) + // 20 + // .. + // .. 
+ // else if (a < 190) + // 190 + // else + // 200 + auto node_a = TreeExprBuilder::MakeField(fielda); + auto top_node = TreeExprBuilder::MakeLiteral(200); + for (int thresh = 190; thresh > 0; thresh -= 10) { + auto literal = TreeExprBuilder::MakeLiteral(thresh); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32()); + top_node = if_node; + } + auto expr = TreeExprBuilder::MakeExpression(top_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + BoundedInt32DataGenerator data_generator(250); + ProjectEvaluator evaluator(projector); + + Status status = TimedEvaluate<arrow::Int32Type, int32_t>( + schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state); + ASSERT_TRUE(status.ok()); +} + +static void TimedTestExtractYear(benchmark::State& state) { + // schema for input fields + auto field0 = field("f0", arrow::date64()); + auto schema = arrow::schema({field0}); + auto pool_ = arrow::default_memory_pool(); + + // output field + auto field_res = field("res", int64()); + + // Build expression + auto expr = TreeExprBuilder::MakeExpression("extractYear", {field0}, field_res); + + std::shared_ptr<Projector> projector; + ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + Int64DataGenerator data_generator; + ProjectEvaluator evaluator(projector); + + Status status = TimedEvaluate<arrow::Date64Type, int64_t>( + schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state); + ASSERT_TRUE(status.ok()); +} + +static void TimedTestFilterAdd2(benchmark::State& state) { + // schema for input fields + auto field0 = field("f0", int64()); + auto field1 = field("f1", int64()); + auto field2 = field("f2", int64()); + auto schema = arrow::schema({field0, field1, field2}); + 
auto pool_ = arrow::default_memory_pool(); + + // Build expression + auto sum = TreeExprBuilder::MakeFunction( + "add", {TreeExprBuilder::MakeField(field1), TreeExprBuilder::MakeField(field0)}, + int64()); + auto less_than = TreeExprBuilder::MakeFunction( + "less_than", {sum, TreeExprBuilder::MakeField(field2)}, boolean()); + auto condition = TreeExprBuilder::MakeCondition(less_than); + + std::shared_ptr<Filter> filter; + ASSERT_OK(Filter::Make(schema, condition, TestConfiguration(), &filter)); + + Int64DataGenerator data_generator; + FilterEvaluator evaluator(filter); + + Status status = TimedEvaluate<arrow::Int64Type, int64_t>( + schema, evaluator, data_generator, pool_, MILLION, 16 * THOUSAND, state); + ASSERT_TRUE(status.ok()); +} + +static void TimedTestFilterLike(benchmark::State& state) { + // schema for input fields + auto fielda = field("a", utf8()); + auto schema = arrow::schema({fielda}); + auto pool_ = arrow::default_memory_pool(); + + // build expression. + auto node_a = TreeExprBuilder::MakeField(fielda); + auto pattern_node = TreeExprBuilder::MakeStringLiteral("%yellow%"); + auto like_yellow = + TreeExprBuilder::MakeFunction("like", {node_a, pattern_node}, arrow::boolean()); + auto condition = TreeExprBuilder::MakeCondition(like_yellow); + + std::shared_ptr<Filter> filter; + ASSERT_OK(Filter::Make(schema, condition, TestConfiguration(), &filter)); + + FastUtf8DataGenerator data_generator(32); + FilterEvaluator evaluator(filter); + + Status status = TimedEvaluate<arrow::StringType, std::string>( + schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state); + ASSERT_TRUE(status.ok()); +} + +static void TimedTestCastFloatFromString(benchmark::State& state) { + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + auto pool = arrow::default_memory_pool(); + + auto field_result = field("res", arrow::float64()); + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto fn = 
TreeExprBuilder::MakeFunction("castFLOAT8", {node_a}, arrow::float64()); + auto expr = TreeExprBuilder::MakeExpression(fn, field_result); + + std::shared_ptr<Projector> projector; + ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + Utf8FloatDataGenerator data_generator; + ProjectEvaluator evaluator(projector); + + Status status = TimedEvaluate<arrow::StringType, std::string>( + schema, evaluator, data_generator, pool, 1 * MILLION, 16 * THOUSAND, state); + ASSERT_TRUE(status.ok()); +} + +static void TimedTestCastIntFromString(benchmark::State& state) { + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + auto pool = arrow::default_memory_pool(); + + auto field_result = field("res", int32()); + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto fn = TreeExprBuilder::MakeFunction("castINT", {node_a}, int32()); + auto expr = TreeExprBuilder::MakeExpression(fn, field_result); + + std::shared_ptr<Projector> projector; + ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + Utf8IntDataGenerator data_generator; + ProjectEvaluator evaluator(projector); + + Status status = TimedEvaluate<arrow::StringType, std::string>( + schema, evaluator, data_generator, pool, 1 * MILLION, 16 * THOUSAND, state); + ASSERT_TRUE(status.ok()); +} + +static void TimedTestAllocs(benchmark::State& state) { + // schema for input fields + auto field_a = field("a", arrow::utf8()); + auto schema = arrow::schema({field_a}); + auto pool_ = arrow::default_memory_pool(); + + // output field + auto field_res = field("res", int32()); + + // Build expression + auto node_a = TreeExprBuilder::MakeField(field_a); + auto upper = TreeExprBuilder::MakeFunction("upper", {node_a}, utf8()); + auto length = TreeExprBuilder::MakeFunction("octet_length", {upper}, int32()); + auto expr = TreeExprBuilder::MakeExpression(length, field_res); + + std::shared_ptr<Projector> projector; + ASSERT_OK(Projector::Make(schema, {expr}, 
TestConfiguration(), &projector)); + + FastUtf8DataGenerator data_generator(64); + ProjectEvaluator evaluator(projector); + + Status status = TimedEvaluate<arrow::StringType, std::string>( + schema, evaluator, data_generator, pool_, 1 * MILLION, 16 * THOUSAND, state); + ASSERT_TRUE(status.ok()); +} +// following two tests are for benchmark optimization of +// in expr. will be used in follow-up PRs to optimize in expr. + +static void TimedTestMultiOr(benchmark::State& state) { + // schema for input fields + auto fielda = field("a", utf8()); + auto schema = arrow::schema({fielda}); + auto pool_ = arrow::default_memory_pool(); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. + // booleanOr(a = string1, a = string2, ..) + auto node_a = TreeExprBuilder::MakeField(fielda); + + NodeVector boolean_functions; + FastUtf8DataGenerator data_generator1(250); + for (int thresh = 1; thresh <= 32; thresh++) { + auto literal = TreeExprBuilder::MakeStringLiteral(data_generator1.GenerateData()); + auto condition = TreeExprBuilder::MakeFunction("equal", {node_a, literal}, boolean()); + boolean_functions.push_back(condition); + } + + auto boolean_or = TreeExprBuilder::MakeOr(boolean_functions); + auto expr = TreeExprBuilder::MakeExpression(boolean_or, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + FastUtf8DataGenerator data_generator(250); + ProjectEvaluator evaluator(projector); + Status status = TimedEvaluate<arrow::StringType, std::string>( + schema, evaluator, data_generator, pool_, 100 * THOUSAND, 16 * THOUSAND, state); + ASSERT_OK(status); +} + +static void TimedTestInExpr(benchmark::State& state) { + // schema for input fields + auto fielda = field("a", utf8()); + auto schema = arrow::schema({fielda}); + auto pool_ = arrow::default_memory_pool(); + + // output fields + auto field_result = field("res", boolean()); + + // build expression. + // a in (string1, string2, ..) + auto node_a = TreeExprBuilder::MakeField(fielda); + + std::unordered_set<std::string> values; + FastUtf8DataGenerator data_generator1(250); + for (int i = 1; i <= 32; i++) { + values.insert(data_generator1.GenerateData()); + } + auto boolean_or = TreeExprBuilder::MakeInExpressionString(node_a, values); + auto expr = TreeExprBuilder::MakeExpression(boolean_or, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + FastUtf8DataGenerator data_generator(250); + ProjectEvaluator evaluator(projector); + + Status status = TimedEvaluate<arrow::StringType, std::string>( + schema, evaluator, data_generator, pool_, 100 * THOUSAND, 16 * THOUSAND, state); + + ASSERT_OK(status); +} + +static void DoDecimalAdd3(benchmark::State& state, int32_t precision, int32_t scale, + bool large = false) { + // schema for input fields + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field0 = field("f0", decimal_type); + auto field1 = field("f1", decimal_type); + auto field2 = field("f2", decimal_type); + auto schema = arrow::schema({field0, field1, field2}); + + Decimal128TypePtr add2_type; + auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, + {decimal_type, decimal_type}, &add2_type); + + Decimal128TypePtr output_type; + status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, + {add2_type, decimal_type}, &output_type); + + // output field + auto field_sum = field("add", output_type); + + // Build expression + auto part_sum = TreeExprBuilder::MakeFunction( + "add", {TreeExprBuilder::MakeField(field1), TreeExprBuilder::MakeField(field2)}, + add2_type); + auto sum = TreeExprBuilder::MakeFunction( + "add", {TreeExprBuilder::MakeField(field0), part_sum}, output_type); + + auto sum_expr = TreeExprBuilder::MakeExpression(sum, field_sum); + + std::shared_ptr<Projector> projector; + status = Projector::Make(schema, {sum_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + Decimal128DataGenerator data_generator(large); + ProjectEvaluator evaluator(projector); + + status = TimedEvaluate<arrow::Decimal128Type, arrow::Decimal128>( + schema, evaluator, data_generator, arrow::default_memory_pool(), 1 * MILLION, + 16 * THOUSAND, state); + ASSERT_OK(status); +} + +static void DoDecimalAdd2(benchmark::State& 
state, int32_t precision, int32_t scale, + bool large = false) { + // schema for input fields + auto decimal_type = std::make_shared<arrow::Decimal128Type>(precision, scale); + auto field0 = field("f0", decimal_type); + auto field1 = field("f1", decimal_type); + auto schema = arrow::schema({field0, field1}); + + Decimal128TypePtr output_type; + auto status = DecimalTypeUtil::GetResultType( + DecimalTypeUtil::kOpAdd, {decimal_type, decimal_type}, &output_type); + + // output field + auto field_sum = field("add", output_type); + + // Build expression + auto sum = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum); + + std::shared_ptr<Projector> projector; + status = Projector::Make(schema, {sum}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + Decimal128DataGenerator data_generator(large); + ProjectEvaluator evaluator(projector); + + status = TimedEvaluate<arrow::Decimal128Type, arrow::Decimal128>( + schema, evaluator, data_generator, arrow::default_memory_pool(), 1 * MILLION, + 16 * THOUSAND, state); + ASSERT_OK(status); +} + +static void DecimalAdd2Fast(benchmark::State& state) { + // use lesser precision to test the fast-path + DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision - 6, 18); +} + +static void DecimalAdd2LeadingZeroes(benchmark::State& state) { + // use max precision to test the large-integer-path + DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision, 6); +} + +static void DecimalAdd2LeadingZeroesWithDiv(benchmark::State& state) { + // use max precision to test the large-integer-path + DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision, 18); +} + +static void DecimalAdd2Large(benchmark::State& state) { + // use max precision to test the large-integer-path + DoDecimalAdd2(state, DecimalTypeUtil::kMaxPrecision, 18, true); +} + +static void DecimalAdd3Fast(benchmark::State& state) { + // use lesser precision to test the fast-path + DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision - 6, 18); +} + +static void 
DecimalAdd3LeadingZeroes(benchmark::State& state) {
+  // use max precision to test the large-integer-path
+  // (three-operand add at kMaxPrecision with scale 6)
+  DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision, 6);
+}
+
+static void DecimalAdd3LeadingZeroesWithDiv(benchmark::State& state) {
+  // use max precision to test the large-integer-path
+  // (three-operand add at kMaxPrecision with scale 18)
+  DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision, 18);
+}
+
+static void DecimalAdd3Large(benchmark::State& state) {
+  // use max precision to test the large-integer-path
+  // (same as above, but passes large = true on to DoDecimalAdd3)
+  DoDecimalAdd3(state, DecimalTypeUtil::kMaxPrecision, 18, true);
+}
+
+// Register every benchmark with Google Benchmark. Each registration requests a
+// minimum running time of 1.0 (MinTime) and reports in microseconds (Unit).
+BENCHMARK(TimedTestAdd3)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestBigNested)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestExtractYear)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestFilterAdd2)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestFilterLike)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestCastFloatFromString)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestCastIntFromString)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestAllocs)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestMultiOr)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(TimedTestInExpr)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2Fast)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2LeadingZeroes)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2LeadingZeroesWithDiv)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd2Large)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3Fast)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3LeadingZeroes)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3LeadingZeroesWithDiv)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+BENCHMARK(DecimalAdd3Large)->MinTime(1.0)->Unit(benchmark::kMicrosecond);
+
+}  // namespace gandiva
diff
--git a/src/arrow/cpp/src/gandiva/tests/null_validity_test.cc b/src/arrow/cpp/src/gandiva/tests/null_validity_test.cc new file mode 100644 index 000000000..0374b68d4 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/null_validity_test.cc @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "gandiva/filter.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::int32; +using arrow::utf8; + +class TestNullValidity : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +// Create an array without a validity buffer. 
+ArrayPtr MakeArrowArrayInt32WithNullValidity(std::vector<int32_t> in_data) { + auto array = MakeArrowArrayInt32(in_data); + return std::make_shared<arrow::Int32Array>(in_data.size(), array->data()->buffers[1], + nullptr, 0); +} + +TEST_F(TestNullValidity, TestFunc) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto schema = arrow::schema({field0, field1}); + + // Build condition f0 + f1 < 10 + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto node_f1 = TreeExprBuilder::MakeField(field1); + auto sum_func = + TreeExprBuilder::MakeFunction("add", {node_f0, node_f1}, arrow::int32()); + auto literal_10 = TreeExprBuilder::MakeLiteral((int32_t)10); + auto less_than_10 = TreeExprBuilder::MakeFunction("less_than", {sum_func, literal_10}, + arrow::boolean()); + auto condition = TreeExprBuilder::MakeCondition(less_than_10); + + std::shared_ptr<Filter> filter; + auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 5; + + // Create an array without a validity buffer. 
+ auto array0 = MakeArrowArrayInt32WithNullValidity({1, 2, 3, 4, 6}); + auto array1 = MakeArrowArrayInt32({5, 9, 6, 17, 3}, {true, true, false, true, true}); + // expected output (indices for which condition matches) + auto exp = MakeArrowArrayUint16({0, 4}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + std::shared_ptr<SelectionVector> selection_vector; + status = SelectionVector::MakeInt16(num_records, pool_, &selection_vector); + EXPECT_TRUE(status.ok()); + + // Evaluate expression + status = filter->Evaluate(*in_batch, selection_vector); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, selection_vector->ToArray()); +} + +TEST_F(TestNullValidity, TestIfElse) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + // else + // b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32WithNullValidity({10, 12, -20, 5}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}); + + // expected output + auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestNullValidity, TestUtf8) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res1", int32()); + + // build expressions. + // length(a) + auto expr = TreeExprBuilder::MakeExpression("length", {field_a}, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + auto array_v = MakeArrowArrayUtf8({"foo", "hello", "bye", "hi", "मदन"}); + auto array_a = std::make_shared<arrow::StringArray>( + num_records, array_v->data()->buffers[1], array_v->data()->buffers[2]); + + // expected output + auto exp = MakeArrowArrayInt32({3, 5, 3, 2, 3}, {true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/projector_build_validation_test.cc b/src/arrow/cpp/src/gandiva/tests/projector_build_validation_test.cc new file mode 100644 index 000000000..5b86844f9 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/projector_build_validation_test.cc @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::int32; + +class TestProjector : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestProjector, TestNonexistentFunction) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = TreeExprBuilder::MakeExpression("nonexistent_function", {field0, field1}, + field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Function bool nonexistent_function(float, float) not supported yet."; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestNotMatchingDataType) { + // schema for input fields + auto field0 = field("f0", float32()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto lt_expr = TreeExprBuilder::MakeExpression(node_f0, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Return type of root node float does not match that of expression bool"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestNotSupportedDataType) { + // schema for input fields + auto field0 = field("f0", list(int32())); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", list(int32())); + + // Build expression + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto lt_expr = TreeExprBuilder::MakeExpression(node_f0, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Field f0 has unsupported data type list"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestIncorrectSchemaMissingField) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto schema = arrow::schema({field0, field0}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = + TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Field f2 not in schema"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestIncorrectSchemaTypeNotMatching) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto field2 = field("f2", int32()); + auto schema = arrow::schema({field0, field2}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = + TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Field definition in schema f2: int32 different from field in expression f2: float"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestIfNotSupportedFunction) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + // else + // b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("nonexistent_function", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); +} + +TEST_F(TestProjector, TestIfNotMatchingReturnType) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, boolean()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); +} + +TEST_F(TestProjector, TestElseNotMatchingReturnType) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", boolean()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto field_result = field("res", int32()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_c, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); +} + +TEST_F(TestProjector, TestElseNotSupportedType) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", list(int32())); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_c, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + EXPECT_EQ(status.code(), StatusCode::ExpressionValidationError); +} + +TEST_F(TestProjector, TestAndMinChildren) { + // schema for input fields + auto fielda = field("a", boolean()); + auto schema = arrow::schema({fielda}); + + // output fields + auto field_result = field("res", boolean()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto and_node = TreeExprBuilder::MakeAnd({node_a}); + + auto expr = TreeExprBuilder::MakeExpression(and_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); +} + +TEST_F(TestProjector, TestAndBooleanArgType) { + // schema for input fields + auto fielda = field("a", boolean()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto and_node = TreeExprBuilder::MakeAnd({node_a, node_b}); + + auto expr = TreeExprBuilder::MakeExpression(and_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/projector_test.cc b/src/arrow/cpp/src/gandiva/tests/projector_test.cc new file mode 100644 index 000000000..120207773 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/projector_test.cc @@ -0,0 +1,1609 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +#include "gandiva/projector.h" + +#include <gtest/gtest.h> + +#include <cmath> + +#include "arrow/memory_pool.h" +#include "gandiva/literal_holder.h" +#include "gandiva/node.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::int32; + +class TestProjector : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestProjector, TestProjectCache) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f2", int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_sum = field("add", int32()); + auto field_sub = field("subtract", int32()); + + // Build expression + auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum); + auto sub_expr = + TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub); + + auto configuration = TestConfiguration(); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {sum_expr, sub_expr}, configuration, &projector); + ASSERT_OK(status); + + // everything is same, should return the same projector. + auto schema_same = arrow::schema({field0, field1}); + std::shared_ptr<Projector> cached_projector; + status = Projector::Make(schema_same, {sum_expr, sub_expr}, configuration, + &cached_projector); + ASSERT_OK(status); + EXPECT_EQ(cached_projector, projector); + + // schema is different should return a new projector. 
+ auto field2 = field("f2", int32()); + auto different_schema = arrow::schema({field0, field1, field2}); + std::shared_ptr<Projector> should_be_new_projector; + status = Projector::Make(different_schema, {sum_expr, sub_expr}, configuration, + &should_be_new_projector); + ASSERT_OK(status); + EXPECT_NE(cached_projector, should_be_new_projector); + + // expression list is different should return a new projector. + std::shared_ptr<Projector> should_be_new_projector1; + status = Projector::Make(schema, {sum_expr}, configuration, &should_be_new_projector1); + ASSERT_OK(status); + EXPECT_NE(cached_projector, should_be_new_projector1); + + // another instance of the same configuration, should return the same projector. + status = Projector::Make(schema, {sum_expr, sub_expr}, TestConfiguration(), + &cached_projector); + ASSERT_OK(status); + EXPECT_EQ(cached_projector, projector); +} + +TEST_F(TestProjector, TestProjectCacheFieldNames) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto field2 = field("f2", int32()); + auto schema = arrow::schema({field0, field1, field2}); + + // output fields + auto sum_01 = field("sum_01", int32()); + auto sum_12 = field("sum_12", int32()); + + auto sum_expr_01 = TreeExprBuilder::MakeExpression("add", {field0, field1}, sum_01); + std::shared_ptr<Projector> projector_01; + auto status = + Projector::Make(schema, {sum_expr_01}, TestConfiguration(), &projector_01); + EXPECT_TRUE(status.ok()); + + auto sum_expr_12 = TreeExprBuilder::MakeExpression("add", {field1, field2}, sum_12); + std::shared_ptr<Projector> projector_12; + status = Projector::Make(schema, {sum_expr_12}, TestConfiguration(), &projector_12); + EXPECT_TRUE(status.ok()); + + // add(f0, f1) != add(f1, f2) + EXPECT_TRUE(projector_01.get() != projector_12.get()); +} + +TEST_F(TestProjector, TestProjectCacheDouble) { + auto schema = arrow::schema({}); + auto res = field("result", arrow::float64()); + + double d0 = 
1.23456788912345677E18; + double d1 = 1.23456789012345677E18; + + auto literal0 = TreeExprBuilder::MakeLiteral(d0); + auto expr0 = TreeExprBuilder::MakeExpression(literal0, res); + auto configuration = TestConfiguration(); + + std::shared_ptr<Projector> projector0; + auto status = Projector::Make(schema, {expr0}, configuration, &projector0); + EXPECT_TRUE(status.ok()) << status.message(); + + auto literal1 = TreeExprBuilder::MakeLiteral(d1); + auto expr1 = TreeExprBuilder::MakeExpression(literal1, res); + std::shared_ptr<Projector> projector1; + status = Projector::Make(schema, {expr1}, configuration, &projector1); + EXPECT_TRUE(status.ok()) << status.message(); + + EXPECT_TRUE(projector0.get() != projector1.get()); +} + +TEST_F(TestProjector, TestProjectCacheFloat) { + auto schema = arrow::schema({}); + auto res = field("result", arrow::float32()); + + float f0 = static_cast<float>(12345678891.000000); + float f1 = f0 - 1000; + + auto literal0 = TreeExprBuilder::MakeLiteral(f0); + auto expr0 = TreeExprBuilder::MakeExpression(literal0, res); + std::shared_ptr<Projector> projector0; + auto status = Projector::Make(schema, {expr0}, TestConfiguration(), &projector0); + EXPECT_TRUE(status.ok()) << status.message(); + + auto literal1 = TreeExprBuilder::MakeLiteral(f1); + auto expr1 = TreeExprBuilder::MakeExpression(literal1, res); + std::shared_ptr<Projector> projector1; + status = Projector::Make(schema, {expr1}, TestConfiguration(), &projector1); + EXPECT_TRUE(status.ok()) << status.message(); + + EXPECT_TRUE(projector0.get() != projector1.get()); +} + +TEST_F(TestProjector, TestProjectCacheLiteral) { + auto schema = arrow::schema({}); + auto res = field("result", arrow::decimal(38, 5)); + + DecimalScalar128 d0("12345678", 38, 5); + DecimalScalar128 d1("98756432", 38, 5); + + auto literal0 = TreeExprBuilder::MakeDecimalLiteral(d0); + auto expr0 = TreeExprBuilder::MakeExpression(literal0, res); + std::shared_ptr<Projector> projector0; + 
ASSERT_OK(Projector::Make(schema, {expr0}, TestConfiguration(), &projector0)); + + auto literal1 = TreeExprBuilder::MakeDecimalLiteral(d1); + auto expr1 = TreeExprBuilder::MakeExpression(literal1, res); + std::shared_ptr<Projector> projector1; + ASSERT_OK(Projector::Make(schema, {expr1}, TestConfiguration(), &projector1)); + + EXPECT_NE(projector0.get(), projector1.get()); +} + +TEST_F(TestProjector, TestProjectCacheDecimalCast) { + auto field_float64 = field("float64", arrow::float64()); + auto schema = arrow::schema({field_float64}); + + auto res_31_13 = field("result", arrow::decimal(31, 13)); + auto expr0 = TreeExprBuilder::MakeExpression("castDECIMAL", {field_float64}, res_31_13); + std::shared_ptr<Projector> projector0; + ASSERT_OK(Projector::Make(schema, {expr0}, TestConfiguration(), &projector0)); + + // if the output scale is different, the cache can't be used. + auto res_31_14 = field("result", arrow::decimal(31, 14)); + auto expr1 = TreeExprBuilder::MakeExpression("castDECIMAL", {field_float64}, res_31_14); + std::shared_ptr<Projector> projector1; + ASSERT_OK(Projector::Make(schema, {expr1}, TestConfiguration(), &projector1)); + EXPECT_NE(projector0.get(), projector1.get()); + + // if the output scale/precision are same, should get a cache hit. 
+ auto res_31_13_alt = field("result", arrow::decimal(31, 13)); + auto expr2 = + TreeExprBuilder::MakeExpression("castDECIMAL", {field_float64}, res_31_13_alt); + std::shared_ptr<Projector> projector2; + ASSERT_OK(Projector::Make(schema, {expr2}, TestConfiguration(), &projector2)); + EXPECT_EQ(projector0.get(), projector2.get()); +} + +TEST_F(TestProjector, TestIntSumSub) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f2", int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_sum = field("add", int32()); + auto field_sub = field("subtract", int32()); + + // Build expression + auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum); + auto sub_expr = + TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub); + + std::shared_ptr<Projector> projector; + auto status = + Projector::Make(schema, {sum_expr, sub_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({11, 13, 15, 17}, {true, true, false, true}); + // expected output + auto exp_sum = MakeArrowArrayInt32({12, 15, 0, 0}, {true, true, false, false}); + auto exp_sub = MakeArrowArrayInt32({-10, -11, 0, 0}, {true, true, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_sub, outputs.at(1)); +} + +template <typename TYPE, typename C_TYPE> +static void TestArithmeticOpsForType(arrow::MemoryPool* pool) { + auto atype = arrow::TypeTraits<TYPE>::type_singleton(); + + 
// schema for input fields + auto field0 = field("f0", atype); + auto field1 = field("f1", atype); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_sum = field("add", atype); + auto field_sub = field("subtract", atype); + auto field_mul = field("multiply", atype); + auto field_div = field("divide", atype); + auto field_eq = field("equal", arrow::boolean()); + auto field_lt = field("less_than", arrow::boolean()); + + // Build expression + auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum); + auto sub_expr = + TreeExprBuilder::MakeExpression("subtract", {field0, field1}, field_sub); + auto mul_expr = + TreeExprBuilder::MakeExpression("multiply", {field0, field1}, field_mul); + auto div_expr = TreeExprBuilder::MakeExpression("divide", {field0, field1}, field_div); + auto eq_expr = TreeExprBuilder::MakeExpression("equal", {field0, field1}, field_eq); + auto lt_expr = TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_lt); + + std::shared_ptr<Projector> projector; + auto status = + Projector::Make(schema, {sum_expr, sub_expr, mul_expr, div_expr, eq_expr, lt_expr}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 12; + std::vector<C_TYPE> input0 = {1, 2, 53, 84, 5, 15, 0, 1, 52, 83, 4, 120}; + std::vector<C_TYPE> input1 = {10, 15, 23, 84, 4, 51, 68, 9, 16, 18, 19, 37}; + std::vector<bool> validity = {true, true, true, true, true, true, + true, true, true, true, true, true}; + + auto array0 = MakeArrowArray<TYPE, C_TYPE>(input0, validity); + auto array1 = MakeArrowArray<TYPE, C_TYPE>(input1, validity); + + // expected output + std::vector<C_TYPE> sum; + std::vector<C_TYPE> sub; + std::vector<C_TYPE> mul; + std::vector<C_TYPE> div; + std::vector<bool> eq; + std::vector<bool> lt; + for (int i = 0; i < num_records; i++) { + sum.push_back(static_cast<C_TYPE>(input0[i] + input1[i])); + 
sub.push_back(static_cast<C_TYPE>(input0[i] - input1[i])); + mul.push_back(static_cast<C_TYPE>(input0[i] * input1[i])); + div.push_back(static_cast<C_TYPE>(input0[i] / input1[i])); + eq.push_back(input0[i] == input1[i]); + lt.push_back(input0[i] < input1[i]); + } + auto exp_sum = MakeArrowArray<TYPE, C_TYPE>(sum, validity); + auto exp_sub = MakeArrowArray<TYPE, C_TYPE>(sub, validity); + auto exp_mul = MakeArrowArray<TYPE, C_TYPE>(mul, validity); + auto exp_div = MakeArrowArray<TYPE, C_TYPE>(div, validity); + auto exp_eq = MakeArrowArray<arrow::BooleanType, bool>(eq, validity); + auto exp_lt = MakeArrowArray<arrow::BooleanType, bool>(lt, validity); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_sub, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(exp_mul, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(exp_div, outputs.at(3)); + EXPECT_ARROW_ARRAY_EQUALS(exp_eq, outputs.at(4)); + EXPECT_ARROW_ARRAY_EQUALS(exp_lt, outputs.at(5)); +} + +TEST_F(TestProjector, TestAllIntTypes) { + TestArithmeticOpsForType<arrow::UInt8Type, uint8_t>(pool_); + TestArithmeticOpsForType<arrow::UInt16Type, uint16_t>(pool_); + TestArithmeticOpsForType<arrow::UInt32Type, uint32_t>(pool_); + TestArithmeticOpsForType<arrow::UInt64Type, uint64_t>(pool_); + TestArithmeticOpsForType<arrow::Int8Type, int8_t>(pool_); + TestArithmeticOpsForType<arrow::Int16Type, int16_t>(pool_); + TestArithmeticOpsForType<arrow::Int32Type, int32_t>(pool_); + TestArithmeticOpsForType<arrow::Int64Type, int64_t>(pool_); +} + +TEST_F(TestProjector, TestExtendedMath) { + // schema for input fields + auto field0 = arrow::field("f0", arrow::float64()); + auto field1 = arrow::field("f1", arrow::float64()); + auto 
schema = arrow::schema({field0, field1}); + + // output fields + auto field_cbrt = arrow::field("cbrt", arrow::float64()); + auto field_exp = arrow::field("exp", arrow::float64()); + auto field_log = arrow::field("log", arrow::float64()); + auto field_log10 = arrow::field("log10", arrow::float64()); + auto field_logb = arrow::field("logb", arrow::float64()); + auto field_power = arrow::field("power", arrow::float64()); + auto field_sin = arrow::field("sin", arrow::float64()); + auto field_cos = arrow::field("cos", arrow::float64()); + auto field_asin = arrow::field("asin", arrow::float64()); + auto field_acos = arrow::field("acos", arrow::float64()); + auto field_tan = arrow::field("tan", arrow::float64()); + auto field_atan = arrow::field("atan", arrow::float64()); + auto field_sinh = arrow::field("sinh", arrow::float64()); + auto field_cosh = arrow::field("cosh", arrow::float64()); + auto field_tanh = arrow::field("tanh", arrow::float64()); + auto field_atan2 = arrow::field("atan2", arrow::float64()); + auto field_cot = arrow::field("cot", arrow::float64()); + auto field_radians = arrow::field("radians", arrow::float64()); + auto field_degrees = arrow::field("degrees", arrow::float64()); + + // Build expression + auto cbrt_expr = TreeExprBuilder::MakeExpression("cbrt", {field0}, field_cbrt); + auto exp_expr = TreeExprBuilder::MakeExpression("exp", {field0}, field_exp); + auto log_expr = TreeExprBuilder::MakeExpression("log", {field0}, field_log); + auto log10_expr = TreeExprBuilder::MakeExpression("log10", {field0}, field_log10); + auto logb_expr = TreeExprBuilder::MakeExpression("log", {field0, field1}, field_logb); + auto power_expr = + TreeExprBuilder::MakeExpression("power", {field0, field1}, field_power); + auto sin_expr = TreeExprBuilder::MakeExpression("sin", {field0}, field_sin); + auto cos_expr = TreeExprBuilder::MakeExpression("cos", {field0}, field_cos); + auto asin_expr = TreeExprBuilder::MakeExpression("asin", {field0}, field_asin); + auto acos_expr 
= TreeExprBuilder::MakeExpression("acos", {field0}, field_acos); + auto tan_expr = TreeExprBuilder::MakeExpression("tan", {field0}, field_tan); + auto atan_expr = TreeExprBuilder::MakeExpression("atan", {field0}, field_atan); + auto sinh_expr = TreeExprBuilder::MakeExpression("sinh", {field0}, field_sinh); + auto cosh_expr = TreeExprBuilder::MakeExpression("cosh", {field0}, field_cosh); + auto tanh_expr = TreeExprBuilder::MakeExpression("tanh", {field0}, field_tanh); + auto atan2_expr = + TreeExprBuilder::MakeExpression("atan2", {field0, field1}, field_atan2); + auto cot_expr = TreeExprBuilder::MakeExpression("cot", {field0}, field_cot); + auto radians_expr = TreeExprBuilder::MakeExpression("radians", {field0}, field_radians); + auto degrees_expr = TreeExprBuilder::MakeExpression("degrees", {field0}, field_degrees); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make( + schema, + {cbrt_expr, exp_expr, log_expr, log10_expr, logb_expr, power_expr, sin_expr, + cos_expr, asin_expr, acos_expr, tan_expr, atan_expr, sinh_expr, cosh_expr, + tanh_expr, atan2_expr, cot_expr, radians_expr, degrees_expr}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + std::vector<double> input0 = {16, 10, -14, 8.3}; + std::vector<double> input1 = {2, 3, 5, 7}; + std::vector<bool> validity = {true, true, true, true}; + + auto array0 = MakeArrowArray<arrow::DoubleType, double>(input0, validity); + auto array1 = MakeArrowArray<arrow::DoubleType, double>(input1, validity); + + // expected output + std::vector<double> cbrt_vals; + std::vector<double> exp_vals; + std::vector<double> log_vals; + std::vector<double> log10_vals; + std::vector<double> logb_vals; + std::vector<double> power_vals; + std::vector<double> sin_vals; + std::vector<double> cos_vals; + std::vector<double> asin_vals; + std::vector<double> acos_vals; + std::vector<double> tan_vals; + std::vector<double> atan_vals; + 
std::vector<double> sinh_vals; + std::vector<double> cosh_vals; + std::vector<double> tanh_vals; + std::vector<double> atan2_vals; + std::vector<double> cot_vals; + std::vector<double> radians_vals; + std::vector<double> degrees_vals; + for (int i = 0; i < num_records; i++) { + cbrt_vals.push_back(static_cast<double>(cbrtl(input0[i]))); + exp_vals.push_back(static_cast<double>(expl(input0[i]))); + log_vals.push_back(static_cast<double>(logl(input0[i]))); + log10_vals.push_back(static_cast<double>(log10l(input0[i]))); + logb_vals.push_back(static_cast<double>(logl(input1[i]) / logl(input0[i]))); + power_vals.push_back(static_cast<double>(powl(input0[i], input1[i]))); + sin_vals.push_back(static_cast<double>(sin(input0[i]))); + cos_vals.push_back(static_cast<double>(cos(input0[i]))); + asin_vals.push_back(static_cast<double>(asin(input0[i]))); + acos_vals.push_back(static_cast<double>(acos(input0[i]))); + tan_vals.push_back(static_cast<double>(tan(input0[i]))); + atan_vals.push_back(static_cast<double>(atan(input0[i]))); + sinh_vals.push_back(static_cast<double>(sinh(input0[i]))); + cosh_vals.push_back(static_cast<double>(cosh(input0[i]))); + tanh_vals.push_back(static_cast<double>(tanh(input0[i]))); + atan2_vals.push_back(static_cast<double>(atan2(input0[i], input1[i]))); + cot_vals.push_back(static_cast<double>(tan(M_PI / 2 - input0[i]))); + radians_vals.push_back(static_cast<double>(input0[i] * M_PI / 180.0)); + degrees_vals.push_back(static_cast<double>(input0[i] * 180.0 / M_PI)); + } + auto expected_cbrt = MakeArrowArray<arrow::DoubleType, double>(cbrt_vals, validity); + auto expected_exp = MakeArrowArray<arrow::DoubleType, double>(exp_vals, validity); + auto expected_log = MakeArrowArray<arrow::DoubleType, double>(log_vals, validity); + auto expected_log10 = MakeArrowArray<arrow::DoubleType, double>(log10_vals, validity); + auto expected_logb = MakeArrowArray<arrow::DoubleType, double>(logb_vals, validity); + auto expected_power = 
MakeArrowArray<arrow::DoubleType, double>(power_vals, validity); + auto expected_sin = MakeArrowArray<arrow::DoubleType, double>(sin_vals, validity); + auto expected_cos = MakeArrowArray<arrow::DoubleType, double>(cos_vals, validity); + auto expected_asin = MakeArrowArray<arrow::DoubleType, double>(asin_vals, validity); + auto expected_acos = MakeArrowArray<arrow::DoubleType, double>(acos_vals, validity); + auto expected_tan = MakeArrowArray<arrow::DoubleType, double>(tan_vals, validity); + auto expected_atan = MakeArrowArray<arrow::DoubleType, double>(atan_vals, validity); + auto expected_sinh = MakeArrowArray<arrow::DoubleType, double>(sinh_vals, validity); + auto expected_cosh = MakeArrowArray<arrow::DoubleType, double>(cosh_vals, validity); + auto expected_tanh = MakeArrowArray<arrow::DoubleType, double>(tanh_vals, validity); + auto expected_atan2 = MakeArrowArray<arrow::DoubleType, double>(atan2_vals, validity); + auto expected_cot = MakeArrowArray<arrow::DoubleType, double>(cot_vals, validity); + auto expected_radians = + MakeArrowArray<arrow::DoubleType, double>(radians_vals, validity); + auto expected_degrees = + MakeArrowArray<arrow::DoubleType, double>(degrees_vals, validity); + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + double epsilon = 1E-13; + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_cbrt, outputs.at(0), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_exp, outputs.at(1), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_log, outputs.at(2), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_log10, outputs.at(3), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_logb, outputs.at(4), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_power, outputs.at(5), epsilon); + 
EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_sin, outputs.at(6), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_cos, outputs.at(7), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_asin, outputs.at(8), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_acos, outputs.at(9), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_tan, outputs.at(10), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_atan, outputs.at(11), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_sinh, outputs.at(12), 1E-08); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_cosh, outputs.at(13), 1E-08); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_tanh, outputs.at(14), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_atan2, outputs.at(15), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_cot, outputs.at(16), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_radians, outputs.at(17), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_degrees, outputs.at(18), epsilon); +} + +TEST_F(TestProjector, TestFloatLessThan) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = + TreeExprBuilder::MakeExpression("less_than", {field0, field1}, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 3; + auto array0 = MakeArrowArrayFloat32({1.0f, 8.9f, 3.0f}, {true, true, false}); + auto array1 = MakeArrowArrayFloat32({4.0f, 3.4f, 6.8f}, {true, true, true}); + // expected output + auto exp = MakeArrowArrayBool({true, false, false}, {true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestProjector, TestIsNotNull) { + // schema for input fields + auto field0 = field("f0", float32()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto myexpr = TreeExprBuilder::MakeExpression("isnotnull", {field0}, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {myexpr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 3; + auto array0 = MakeArrowArrayFloat32({1.0f, 8.9f, 3.0f}, {true, true, false}); + // expected output + auto exp = MakeArrowArrayBool({true, true, false}, {true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestProjector, TestZeroCopy) { + // schema for input fields + auto field0 = field("f0", int32()); + auto schema = arrow::schema({field0}); + + // output fields + auto res = field("res", float32()); + + // Build expression + auto cast_expr = TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {cast_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // expected output + auto exp = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false}); + + // allocate output buffers + int64_t bitmap_sz = arrow::BitUtil::BytesForBits(num_records); + int64_t bitmap_capacity = arrow::BitUtil::RoundUpToMultipleOf64(bitmap_sz); + std::vector<uint8_t> bitmap(bitmap_capacity); + std::shared_ptr<arrow::MutableBuffer> bitmap_buf = + std::make_shared<arrow::MutableBuffer>(&bitmap[0], bitmap_capacity); + + int64_t data_sz = sizeof(float) * num_records; + std::vector<uint8_t> data(bitmap_capacity); + 
std::shared_ptr<arrow::MutableBuffer> data_buf = + std::make_shared<arrow::MutableBuffer>(&data[0], data_sz); + + auto array_data = + arrow::ArrayData::Make(float32(), num_records, {bitmap_buf, data_buf}); + + // Evaluate expression + status = projector->Evaluate(*in_batch, {array_data}); + EXPECT_TRUE(status.ok()); + + // Validate results + auto output = arrow::MakeArray(array_data); + EXPECT_ARROW_ARRAY_EQUALS(exp, output); +} + +TEST_F(TestProjector, TestZeroCopyNegative) { + // schema for input fields + auto field0 = field("f0", int32()); + auto schema = arrow::schema({field0}); + + // output fields + auto res = field("res", float32()); + + // Build expression + auto cast_expr = TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {cast_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // expected output + auto exp = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false}); + + // allocate output buffers + int64_t bitmap_sz = arrow::BitUtil::BytesForBits(num_records); + std::unique_ptr<uint8_t[]> bitmap(new uint8_t[bitmap_sz]); + std::shared_ptr<arrow::MutableBuffer> bitmap_buf = + std::make_shared<arrow::MutableBuffer>(bitmap.get(), bitmap_sz); + + int64_t data_sz = sizeof(float) * num_records; + std::unique_ptr<uint8_t[]> data(new uint8_t[data_sz]); + std::shared_ptr<arrow::MutableBuffer> data_buf = + std::make_shared<arrow::MutableBuffer>(data.get(), data_sz); + + auto array_data = + arrow::ArrayData::Make(float32(), num_records, {bitmap_buf, data_buf}); + + // the batch can't be empty. 
+ auto bad_batch = arrow::RecordBatch::Make(schema, 0 /*num_records*/, {array0}); + status = projector->Evaluate(*bad_batch, {array_data}); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the output array can't be null. + std::shared_ptr<arrow::ArrayData> null_array_data; + status = projector->Evaluate(*in_batch, {null_array_data}); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the output array must have at least two buffers. + auto bad_array_data = arrow::ArrayData::Make(float32(), num_records, {bitmap_buf}); + status = projector->Evaluate(*in_batch, {bad_array_data}); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the output buffers must have sufficiently sized data_buf. + std::shared_ptr<arrow::MutableBuffer> bad_data_buf = + std::make_shared<arrow::MutableBuffer>(data.get(), data_sz - 1); + auto bad_array_data2 = + arrow::ArrayData::Make(float32(), num_records, {bitmap_buf, bad_data_buf}); + status = projector->Evaluate(*in_batch, {bad_array_data2}); + EXPECT_EQ(status.code(), StatusCode::Invalid); + + // the output buffers must have sufficiently sized bitmap_buf. 
+ std::shared_ptr<arrow::MutableBuffer> bad_bitmap_buf = + std::make_shared<arrow::MutableBuffer>(bitmap.get(), bitmap_sz - 1); + auto bad_array_data3 = + arrow::ArrayData::Make(float32(), num_records, {bad_bitmap_buf, data_buf}); + status = projector->Evaluate(*in_batch, {bad_array_data3}); + EXPECT_EQ(status.code(), StatusCode::Invalid); +} + +TEST_F(TestProjector, TestDivideZero) { + // schema for input fields + auto field0 = field("f0", int32()); + auto field1 = field("f2", int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_div = field("divide", int32()); + + // Build expression + auto div_expr = TreeExprBuilder::MakeExpression("divide", {field0, field1}, field_div); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {div_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + auto array0 = MakeArrowArrayInt32({2, 3, 4, 5, 6}, {true, true, true, true, true}); + auto array1 = MakeArrowArrayInt32({1, 2, 2, 0, 0}, {true, true, false, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_EQ(status.code(), StatusCode::ExecutionError); + std::string expected_error = "divide by zero error"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); + + // Testing for second batch that has no error should succeed. 
+ num_records = 5; + array0 = MakeArrowArrayInt32({2, 3, 4, 5, 6}, {true, true, true, true, true}); + array1 = MakeArrowArrayInt32({1, 2, 2, 1, 1}, {true, true, false, true, true}); + + // prepare input record batch + in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + // expected output + auto exp = MakeArrowArrayInt32({2, 1, 2, 5, 6}, {true, true, false, true, true}); + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestProjector, TestModZero) { + // schema for input fields + auto field0 = field("f0", arrow::int64()); + auto field1 = field("f2", int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_div = field("mod", int32()); + + // Build expression + auto mod_expr = TreeExprBuilder::MakeExpression("mod", {field0, field1}, field_div); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {mod_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt64({2, 3, 4, 5}, {true, true, true, true}); + auto array1 = MakeArrowArrayInt32({1, 2, 2, 0}, {true, true, false, true}); + // expected output + auto exp_mod = MakeArrowArrayInt32({0, 1, 0, 5}, {true, true, false, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_mod, outputs.at(0)); +} + +TEST_F(TestProjector, TestConcat) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto field1 = field("f1", arrow::utf8()); + auto schema 
= arrow::schema({field0, field1}); + + // output fields + auto field_concat = field("concat", arrow::utf8()); + + // Build expression + auto concat_expr = + TreeExprBuilder::MakeExpression("concat", {field0, field1}, field_concat); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {concat_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 6; + auto array0 = MakeArrowArrayUtf8({"ab", "", "ab", "invalid", "valid", "invalid"}, + {true, true, true, false, true, false}); + auto array1 = MakeArrowArrayUtf8({"cd", "cd", "", "valid", "invalid", "invalid"}, + {true, true, true, true, false, false}); + // expected output + auto exp_concat = MakeArrowArrayUtf8({"abcd", "cd", "ab", "valid", "valid", ""}, + {true, true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_concat, outputs.at(0)); +} + +TEST_F(TestProjector, TestBase64) { + // schema for input fields + auto field0 = field("f0", arrow::binary()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_base = field("base64", arrow::utf8()); + + // Build expression + auto base_expr = TreeExprBuilder::MakeExpression("base64", {field0}, field_base); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {base_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = + MakeArrowArrayBinary({"hello", "", "test", "hive"}, {true, true, true, true}); + // expected output + auto exp_base = MakeArrowArrayUtf8({"aGVsbG8=", 
"", "dGVzdA==", "aGl2ZQ=="}, + {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_base, outputs.at(0)); +} + +TEST_F(TestProjector, TestUnbase64) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_base = field("base64", arrow::binary()); + + // Build expression + auto base_expr = TreeExprBuilder::MakeExpression("unbase64", {field0}, field_base); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {base_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayUtf8({"aGVsbG8=", "", "dGVzdA==", "aGl2ZQ=="}, + {true, true, true, true}); + // expected output + auto exp_unbase = + MakeArrowArrayBinary({"hello", "", "test", "hive"}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_unbase, outputs.at(0)); +} + +TEST_F(TestProjector, TestLeftString) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto field1 = field("f1", arrow::int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_concat = field("left", arrow::utf8()); + + // Build expression + auto concat_expr = + TreeExprBuilder::MakeExpression("left", {field0, field1}, field_concat); + + 
std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {concat_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 6; + auto array0 = MakeArrowArrayUtf8({"ab", "", "ab", "invalid", "valid", "invalid"}, + {true, true, true, true, true, true}); + auto array1 = + MakeArrowArrayInt32({1, 500, 2, -5, 5, 0}, {true, true, true, true, true, true}); + // expected output + auto exp_left = MakeArrowArrayUtf8({"a", "", "ab", "in", "valid", ""}, + {true, true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_left, outputs.at(0)); +} + +TEST_F(TestProjector, TestRightString) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto field1 = field("f1", arrow::int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_concat = field("right", arrow::utf8()); + + // Build expression + auto concat_expr = + TreeExprBuilder::MakeExpression("right", {field0, field1}, field_concat); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {concat_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 6; + auto array0 = MakeArrowArrayUtf8({"ab", "", "ab", "invalid", "valid", "invalid"}, + {true, true, true, true, true, true}); + auto array1 = + MakeArrowArrayInt32({1, 500, 2, -5, 5, 0}, {true, true, true, true, true, true}); + // expected output + auto exp_left = MakeArrowArrayUtf8({"b", "", "ab", "id", "valid", ""}, + {true, true, true, true, true, true}); + + // prepare 
input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_left, outputs.at(0)); +} + +TEST_F(TestProjector, TestOffset) { + // schema for input fields + auto field0 = field("f0", arrow::int32()); + auto field1 = field("f1", arrow::int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_sum = field("sum", arrow::int32()); + + // Build expression + auto sum_expr = TreeExprBuilder::MakeExpression("add", {field0, field1}, field_sum); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {sum_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({1, 2, 3, 4, 5}, {true, true, true, true, false}); + array0 = array0->Slice(1); + auto array1 = MakeArrowArrayInt32({5, 6, 7, 8}, {true, false, true, true}); + // expected output + auto exp_sum = MakeArrowArrayInt32({9, 11, 13}, {false, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + in_batch = in_batch->Slice(1); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0)); +} + +TEST_F(TestProjector, TestByteSubString) { + // schema for input fields + auto field0 = field("f0", arrow::binary()); + auto field1 = field("f1", arrow::int32()); + auto field2 = field("f2", arrow::int32()); + auto schema = arrow::schema({field0, field1, field2}); + + // output fields + auto field_byte_substr = 
field("bytesubstring", arrow::binary()); + + // Build expression + auto byte_substr_expr = TreeExprBuilder::MakeExpression( + "bytesubstring", {field0, field1, field2}, field_byte_substr); + + std::shared_ptr<Projector> projector; + auto status = + Projector::Make(schema, {byte_substr_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 6; + auto array0 = MakeArrowArrayBinary({"ab", "", "ab", "invalid", "valid", "invalid"}, + {true, true, true, true, true, true}); + auto array1 = + MakeArrowArrayInt32({0, 1, 1, 1, 3, 3}, {true, true, true, true, true, true}); + auto array2 = + MakeArrowArrayInt32({0, 1, 1, 2, 3, 3}, {true, true, true, true, true, true}); + // expected output + auto exp_byte_substr = MakeArrowArrayBinary({"", "", "a", "in", "lid", "val"}, + {true, true, true, true, true, true}); + + // prepare input record batch + auto in = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_byte_substr, outputs.at(0)); +} + +// Test to ensure behaviour of cast functions when the validity is false for an input. The +// function should not run for that input. 
+TEST_F(TestProjector, TestCastFunction) { + auto field0 = field("f0", arrow::utf8()); + auto schema = arrow::schema({field0}); + + // output fields + auto res_float4 = field("res_float4", arrow::float32()); + auto res_float8 = field("res_float8", arrow::float64()); + auto res_int4 = field("castINT", arrow::int32()); + auto res_int8 = field("castBIGINT", arrow::int64()); + + // Build expression + auto cast_expr_float4 = + TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res_float4); + auto cast_expr_float8 = + TreeExprBuilder::MakeExpression("castFLOAT8", {field0}, res_float8); + auto cast_expr_int4 = TreeExprBuilder::MakeExpression("castINT", {field0}, res_int4); + auto cast_expr_int8 = TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int8); + + std::shared_ptr<Projector> projector; + + // {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8} + auto status = Projector::Make( + schema, {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + + // Last validity is false and the cast functions throw error when input is empty. 
Should + // not be evaluated due to addition of NativeFunction::kCanReturnErrors + auto array0 = MakeArrowArrayUtf8({"1", "2", "3", ""}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + auto out_float4 = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false}); + auto out_float8 = MakeArrowArrayFloat64({1, 2, 3, 0}, {true, true, true, false}); + auto out_int4 = MakeArrowArrayInt32({1, 2, 3, 0}, {true, true, true, false}); + auto out_int8 = MakeArrowArrayInt64({1, 2, 3, 0}, {true, true, true, false}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(out_int4, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(out_int8, outputs.at(3)); +} + +TEST_F(TestProjector, TestCastBitFunction) { + auto field0 = field("f0", arrow::utf8()); + auto schema = arrow::schema({field0}); + + // output fields + auto res_bit = field("res_bit", arrow::boolean()); + + // Build expression + auto cast_bit = TreeExprBuilder::MakeExpression("castBIT", {field0}, res_bit); + + std::shared_ptr<Projector> projector; + + auto status = Projector::Make(schema, {cast_bit}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto arr = MakeArrowArrayUtf8({"1", "true", "false", "0"}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arr}); + + auto out = MakeArrowArrayBool({true, true, false, false}, {true, true, true, true}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out, outputs.at(0)); +} + +// Test to ensure behaviour of cast functions when the validity 
is false for an input. The +// function should not run for that input. +TEST_F(TestProjector, TestCastVarbinaryFunction) { + auto field0 = field("f0", arrow::binary()); + auto schema = arrow::schema({field0}); + + // output fields + auto res_int4 = field("res_int4", arrow::int32()); + auto res_int8 = field("res_int8", arrow::int64()); + auto res_float4 = field("res_float4", arrow::float32()); + auto res_float8 = field("res_float8", arrow::float64()); + + // Build expression + auto cast_expr_int4 = TreeExprBuilder::MakeExpression("castINT", {field0}, res_int4); + auto cast_expr_int8 = TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int8); + auto cast_expr_float4 = + TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res_float4); + auto cast_expr_float8 = + TreeExprBuilder::MakeExpression("castFLOAT8", {field0}, res_float8); + + std::shared_ptr<Projector> projector; + + // {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8} + auto status = Projector::Make( + schema, {cast_expr_int4, cast_expr_int8, cast_expr_float4, cast_expr_float8}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + + // Last validity is false and the cast functions throw error when input is empty. 
Should + // not be evaluated due to addition of NativeFunction::kCanReturnErrors + auto array0 = + MakeArrowArrayBinary({"37", "-99999", "99999", "4"}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + auto out_int4 = MakeArrowArrayInt32({37, -99999, 99999, 0}, {true, true, true, false}); + auto out_int8 = MakeArrowArrayInt64({37, -99999, 99999, 0}, {true, true, true, false}); + auto out_float4 = + MakeArrowArrayFloat32({37, -99999, 99999, 0}, {true, true, true, false}); + auto out_float8 = + MakeArrowArrayFloat64({37, -99999, 99999, 0}, {true, true, true, false}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out_int4, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(out_int8, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(3)); +} + +TEST_F(TestProjector, TestToDate) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto field_node = std::make_shared<FieldNode>(field0); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", arrow::date64()); + + auto pattern_node = std::make_shared<LiteralNode>( + arrow::utf8(), LiteralHolder(std::string("YYYY-MM-DD")), false); + + // Build expression + auto fn_node = TreeExprBuilder::MakeFunction("to_date", {field_node, pattern_node}, + arrow::date64()); + auto expr = TreeExprBuilder::MakeExpression(fn_node, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 3; + auto array0 = + MakeArrowArrayUtf8({"1986-12-01", "2012-12-01", "invalid"}, {true, true, false}); + // expected output + auto exp = MakeArrowArrayDate64({533779200000, 1354320000000, 0}, {true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +// ARROW-11617 +TEST_F(TestProjector, TestIfElseOpt) { + // schema for input + auto field0 = field("f0", int32()); + auto field1 = field("f1", int32()); + auto field2 = field("f2", int32()); + auto schema = arrow::schema({field0, field1, field2}); + + auto f0 = std::make_shared<FieldNode>(field0); + auto f1 = std::make_shared<FieldNode>(field1); + auto f2 = std::make_shared<FieldNode>(field2); + + // output fields + auto field_result = field("out", int32()); + + // Expr - (f0, f1 - null; f2 non null) + // + // if (is not null(f0)) + // then f0 + // else add(( + // if (is not null (f1)) + // then f1 + // else f2 + // ), f1) + + auto cond_node_inner = TreeExprBuilder::MakeFunction("isnotnull", {f1}, boolean()); + auto if_node_inner = TreeExprBuilder::MakeIf(cond_node_inner, f1, f2, int32()); + + auto cond_node_outer = TreeExprBuilder::MakeFunction("isnotnull", {f0}, boolean()); + auto else_node_outer = + TreeExprBuilder::MakeFunction("add", {if_node_inner, f1}, int32()); + + auto if_node_outer = + TreeExprBuilder::MakeIf(cond_node_outer, f1, else_node_outer, int32()); + auto expr = TreeExprBuilder::MakeExpression(if_node_outer, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 1; + auto array0 = MakeArrowArrayInt32({0}, {false}); + auto array1 = MakeArrowArrayInt32({0}, {false}); + auto array2 = MakeArrowArrayInt32({99}, {true}); + // expected output + auto exp = MakeArrowArrayInt32({0}, {false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestProjector, TestRepeat) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto field1 = field("f1", arrow::int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_repeat = field("repeat", arrow::utf8()); + + // Build expression + auto repeat_expr = + TreeExprBuilder::MakeExpression("repeat", {field0, field1}, field_repeat); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {repeat_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + auto array0 = + MakeArrowArrayUtf8({"ab", "a", "car", "valid", ""}, {true, true, true, true, true}); + auto array1 = MakeArrowArrayInt32({2, 1, 3, 2, 10}, {true, true, true, true, true}); + // expected output + auto exp_repeat = MakeArrowArrayUtf8({"abab", "a", "carcarcar", "validvalid", ""}, + {true, true, true, true, true}); + + // prepare input record batch + auto in = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in, pool_, &outputs); + 
EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_repeat, outputs.at(0)); +} + +TEST_F(TestProjector, TestLpad) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto field1 = field("f1", arrow::int32()); + auto field2 = field("f2", arrow::utf8()); + auto schema = arrow::schema({field0, field1, field2}); + + // output fields + auto field_lpad = field("lpad", arrow::utf8()); + + // Build expression + auto lpad_expr = + TreeExprBuilder::MakeExpression("lpad", {field0, field1, field2}, field_lpad); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {lpad_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 7; + auto array0 = MakeArrowArrayUtf8({"ab", "a", "ab", "invalid", "valid", "invalid", ""}, + {true, true, true, true, true, true, true}); + auto array1 = MakeArrowArrayInt32({1, 5, 3, 12, 0, 2, 10}, + {true, true, true, true, true, true, true}); + auto array2 = MakeArrowArrayUtf8({"z", "z", "c", "valid", "invalid", "invalid", ""}, + {true, true, true, true, true, true, true}); + // expected output + auto exp_lpad = MakeArrowArrayUtf8({"a", "zzzza", "cab", "validinvalid", "", "in", ""}, + {true, true, true, true, true, true, true}); + + // prepare input record batch + auto in = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_lpad, outputs.at(0)); +} + +TEST_F(TestProjector, TestRpad) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto field1 = field("f1", arrow::int32()); + auto field2 = field("f2", arrow::utf8()); + auto schema = arrow::schema({field0, field1, field2}); + + // output fields 
+ auto field_rpad = field("rpad", arrow::utf8()); + + // Build expression + auto rpad_expr = + TreeExprBuilder::MakeExpression("rpad", {field0, field1, field2}, field_rpad); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {rpad_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 7; + auto array0 = MakeArrowArrayUtf8({"ab", "a", "ab", "invalid", "valid", "invalid", ""}, + {true, true, true, true, true, true, true}); + auto array1 = MakeArrowArrayInt32({1, 5, 3, 12, 0, 2, 10}, + {true, true, true, true, true, true, true}); + auto array2 = MakeArrowArrayUtf8({"z", "z", "c", "valid", "invalid", "invalid", ""}, + {true, true, true, true, true, true, true}); + // expected output + auto exp_rpad = MakeArrowArrayUtf8({"a", "azzzz", "abc", "invalidvalid", "", "in", ""}, + {true, true, true, true, true, true, true}); + + // prepare input record batch + auto in = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_rpad, outputs.at(0)); +} + +TEST_F(TestProjector, TestBinRepresentation) { + // schema for input fields + auto field0 = field("f0", arrow::int64()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("bin", arrow::utf8()); + + // Build expression + auto myexpr = TreeExprBuilder::MakeExpression("bin", {field0}, field_result); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {myexpr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 3; + auto array0 = MakeArrowArrayInt64({7, -28550, 58117}, {true, true, true}); + // expected output + auto exp = MakeArrowArrayUtf8( + {"111", "1111111111111111111111111111111111111111111111111001000001111010", + "1110001100000101"}, + {true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestProjector, TestBigIntCastFunction) { + // input fields + auto field0 = field("f0", arrow::float32()); + auto field1 = field("f1", arrow::float64()); + auto field2 = field("f2", arrow::day_time_interval()); + auto field3 = field("f3", arrow::month_interval()); + auto schema = arrow::schema({field0, field1, field2, field3}); + + // output fields + auto res_int64 = field("res", arrow::int64()); + + // Build expression + auto cast_expr_float4 = + TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int64); + auto cast_expr_float8 = + TreeExprBuilder::MakeExpression("castBIGINT", {field1}, res_int64); + auto cast_expr_day_interval = + TreeExprBuilder::MakeExpression("castBIGINT", {field2}, res_int64); + auto cast_expr_year_interval = + TreeExprBuilder::MakeExpression("castBIGINT", {field3}, res_int64); + + std::shared_ptr<Projector> projector; + + // {cast_expr_float4, cast_expr_float8, cast_expr_day_interval, + // cast_expr_year_interval} + auto status = Projector::Make(schema, + {cast_expr_float4, cast_expr_float8, + cast_expr_day_interval, cast_expr_year_interval}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch 
with some sample data + int num_records = 4; + + // Last validity is false and the cast functions throw error when input is empty. Should + // not be evaluated due to addition of NativeFunction::kCanReturnErrors + auto array0 = + MakeArrowArrayFloat32({6.6f, -6.6f, 9.999999f, 0}, {true, true, true, false}); + auto array1 = + MakeArrowArrayFloat64({6.6, -6.6, 9.99999999999, 0}, {true, true, true, false}); + auto array2 = MakeArrowArrayInt64({100, 25, -0, 0}, {true, true, true, false}); + auto array3 = MakeArrowArrayInt32({25, -25, -0, 0}, {true, true, true, false}); + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2, array3}); + + auto out_float4 = MakeArrowArrayInt64({7, -7, 10, 0}, {true, true, true, false}); + auto out_float8 = MakeArrowArrayInt64({7, -7, 10, 0}, {true, true, true, false}); + auto out_days_interval = + MakeArrowArrayInt64({8640000000, 2160000000, 0, 0}, {true, true, true, false}); + auto out_year_interval = MakeArrowArrayInt64({2, -2, 0, 0}, {true, true, true, false}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(out_days_interval, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(out_year_interval, outputs.at(3)); +} + +TEST_F(TestProjector, TestIntCastFunction) { + // input fields + auto field0 = field("f0", arrow::float32()); + auto field1 = field("f1", arrow::float64()); + auto field2 = field("f2", arrow::month_interval()); + auto schema = arrow::schema({field0, field1, field2}); + + // output fields + auto res_int32 = field("res", arrow::int32()); + + // Build expression + auto cast_expr_float4 = TreeExprBuilder::MakeExpression("castINT", {field0}, res_int32); + auto cast_expr_float8 = TreeExprBuilder::MakeExpression("castINT", {field1}, res_int32); + auto 
cast_expr_year_interval = + TreeExprBuilder::MakeExpression("castINT", {field2}, res_int32); + + std::shared_ptr<Projector> projector; + + // {cast_expr_float4, cast_expr_float8, cast_expr_day_interval, + // cast_expr_year_interval} + auto status = Projector::Make( + schema, {cast_expr_float4, cast_expr_float8, cast_expr_year_interval}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + + // Last validity is false and the cast functions throw error when input is empty. Should + // not be evaluated due to addition of NativeFunction::kCanReturnErrors + auto array0 = + MakeArrowArrayFloat32({6.6f, -6.6f, 9.999999f, 0}, {true, true, true, false}); + auto array1 = + MakeArrowArrayFloat64({6.6, -6.6, 9.99999999999, 0}, {true, true, true, false}); + auto array2 = MakeArrowArrayInt32({25, -25, -0, 0}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1, array2}); + + auto out_float4 = MakeArrowArrayInt32({7, -7, 10, 0}, {true, true, true, false}); + auto out_float8 = MakeArrowArrayInt32({7, -7, 10, 0}, {true, true, true, false}); + auto out_year_interval = MakeArrowArrayInt32({2, -2, 0, 0}, {true, true, true, false}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(out_year_interval, outputs.at(2)); +} + +TEST_F(TestProjector, TestCastNullableIntYearInterval) { + // input fields + auto field1 = field("f1", arrow::month_interval()); + auto schema = arrow::schema({field1}); + + // output fields + auto res_int32 = field("res", arrow::int32()); + auto res_int64 = field("res", arrow::int64()); + + // Build expression + auto cast_expr_int32 = + TreeExprBuilder::MakeExpression("castNULLABLEINT", 
{field1}, res_int32); + auto cast_expr_int64 = + TreeExprBuilder::MakeExpression("castNULLABLEBIGINT", {field1}, res_int64); + + std::shared_ptr<Projector> projector; + + // {cast_expr_int32, cast_expr_int64, cast_expr_day_interval, + // cast_expr_year_interval} + auto status = Projector::Make(schema, {cast_expr_int32, cast_expr_int64}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + + // Last validity is false and the cast functions throw error when input is empty. Should + // not be evaluated due to addition of NativeFunction::kCanReturnErrors + auto array0 = MakeArrowArrayInt32({12, -24, -0, 0}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + auto out_int32 = MakeArrowArrayInt32({1, -2, -0, 0}, {true, true, true, false}); + auto out_int64 = MakeArrowArrayInt64({1, -2, -0, 0}, {true, true, true, false}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out_int32, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(out_int64, outputs.at(1)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/test_util.h b/src/arrow/cpp/src/gandiva/tests/test_util.h new file mode 100644 index 000000000..54270436c --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/test_util.h @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <chrono> +#include <memory> +#include <utility> +#include <vector> + +#include "arrow/testing/gtest_util.h" +#include "gandiva/arrow.h" +#include "gandiva/configuration.h" + +#pragma once + +namespace gandiva { + +// Helper function to create an arrow-array of type ARROWTYPE +// from primitive vectors of data & validity. +// +// arrow/testing/gtest_util.h has good utility classes for this purpose. +// Using those +template <typename TYPE, typename C_TYPE> +static inline ArrayPtr MakeArrowArray(std::vector<C_TYPE> values, + std::vector<bool> validity) { + ArrayPtr out; + arrow::ArrayFromVector<TYPE, C_TYPE>(validity, values, &out); + return out; +} + +template <typename TYPE, typename C_TYPE> +static inline ArrayPtr MakeArrowArray(std::vector<C_TYPE> values) { + ArrayPtr out; + arrow::ArrayFromVector<TYPE, C_TYPE>(values, &out); + return out; +} + +template <typename TYPE, typename C_TYPE> +static inline ArrayPtr MakeArrowArray(const std::shared_ptr<arrow::DataType>& type, + std::vector<C_TYPE> values, + std::vector<bool> validity) { + ArrayPtr out; + arrow::ArrayFromVector<TYPE, C_TYPE>(type, validity, values, &out); + return out; +} + +template <typename TYPE, typename C_TYPE> +static inline ArrayPtr MakeArrowTypeArray(const std::shared_ptr<arrow::DataType>& type, + const std::vector<C_TYPE>& values, + const std::vector<bool>& validity) { + ArrayPtr out; + arrow::ArrayFromVector<TYPE, C_TYPE>(type, validity, values, &out); + return out; +} + +#define MakeArrowArrayBool MakeArrowArray<arrow::BooleanType, bool> +#define 
MakeArrowArrayInt8 MakeArrowArray<arrow::Int8Type, int8_t> +#define MakeArrowArrayInt16 MakeArrowArray<arrow::Int16Type, int16_t> +#define MakeArrowArrayInt32 MakeArrowArray<arrow::Int32Type, int32_t> +#define MakeArrowArrayInt64 MakeArrowArray<arrow::Int64Type, int64_t> +#define MakeArrowArrayUint8 MakeArrowArray<arrow::UInt8Type, uint8_t> +#define MakeArrowArrayUint16 MakeArrowArray<arrow::UInt16Type, uint16_t> +#define MakeArrowArrayUint32 MakeArrowArray<arrow::UInt32Type, uint32_t> +#define MakeArrowArrayUint64 MakeArrowArray<arrow::UInt64Type, uint64_t> +#define MakeArrowArrayFloat32 MakeArrowArray<arrow::FloatType, float> +#define MakeArrowArrayFloat64 MakeArrowArray<arrow::DoubleType, double> +#define MakeArrowArrayDate64 MakeArrowArray<arrow::Date64Type, int64_t> +#define MakeArrowArrayUtf8 MakeArrowArray<arrow::StringType, std::string> +#define MakeArrowArrayBinary MakeArrowArray<arrow::BinaryType, std::string> +#define MakeArrowArrayDecimal MakeArrowArray<arrow::Decimal128Type, arrow::Decimal128> + +#define EXPECT_ARROW_ARRAY_EQUALS(a, b) \ + EXPECT_TRUE((a)->Equals(b, arrow::EqualOptions().nans_equal(true))) \ + << "expected array: " << (a)->ToString() << " actual array: " << (b)->ToString() + +#define EXPECT_ARROW_ARRAY_APPROX_EQUALS(a, b, epsilon) \ + EXPECT_TRUE( \ + (a)->ApproxEquals(b, arrow::EqualOptions().atol(epsilon).nans_equal(true))) \ + << "expected array: " << (a)->ToString() << " actual array: " << (b)->ToString() + +#define EXPECT_ARROW_TYPE_EQUALS(a, b) \ + EXPECT_TRUE((a)->Equals(b)) << "expected type: " << (a)->ToString() \ + << " actual type: " << (b)->ToString() + +static inline std::shared_ptr<Configuration> TestConfiguration() { + auto builder = ConfigurationBuilder(); + return builder.DefaultConfiguration(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/timed_evaluate.h b/src/arrow/cpp/src/gandiva/tests/timed_evaluate.h new file mode 100644 index 000000000..eba0f5eb9 --- /dev/null +++ 
b/src/arrow/cpp/src/gandiva/tests/timed_evaluate.h @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <memory> +#include <vector> +#include "benchmark/benchmark.h" +#include "gandiva/arrow.h" +#include "gandiva/filter.h" +#include "gandiva/projector.h" +#include "gandiva/tests/generate_data.h" + +#pragma once + +#define THOUSAND (1024) +#define MILLION (1024 * 1024) +#define NUM_BATCHES 16 + +namespace gandiva { + +template <typename C_TYPE> +std::vector<C_TYPE> GenerateData(int num_records, DataGenerator<C_TYPE>& data_generator) { + std::vector<C_TYPE> data; + + for (int i = 0; i < num_records; i++) { + data.push_back(data_generator.GenerateData()); + } + + return data; +} + +class BaseEvaluator { + public: + virtual ~BaseEvaluator() = default; + + virtual Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) = 0; +}; + +class ProjectEvaluator : public BaseEvaluator { + public: + explicit ProjectEvaluator(std::shared_ptr<Projector> projector) + : projector_(projector) {} + + Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) override { + arrow::ArrayVector outputs; + return projector_->Evaluate(batch, pool, &outputs); + } + + private: + 
std::shared_ptr<Projector> projector_; +}; + +class FilterEvaluator : public BaseEvaluator { + public: + explicit FilterEvaluator(std::shared_ptr<Filter> filter) : filter_(filter) {} + + Status Evaluate(arrow::RecordBatch& batch, arrow::MemoryPool* pool) override { + if (selection_ == nullptr || selection_->GetMaxSlots() < batch.num_rows()) { + auto status = SelectionVector::MakeInt16(batch.num_rows(), pool, &selection_); + if (!status.ok()) { + return status; + } + } + return filter_->Evaluate(batch, selection_); + } + + private: + std::shared_ptr<Filter> filter_; + std::shared_ptr<SelectionVector> selection_; +}; + +template <typename TYPE, typename C_TYPE> +Status TimedEvaluate(SchemaPtr schema, BaseEvaluator& evaluator, + DataGenerator<C_TYPE>& data_generator, arrow::MemoryPool* pool, + int num_records, int batch_size, benchmark::State& state) { + int num_remaining = num_records; + int num_fields = schema->num_fields(); + int num_calls = 0; + Status status; + + // Generate batches of data + std::shared_ptr<arrow::RecordBatch> batches[NUM_BATCHES]; + for (int i = 0; i < NUM_BATCHES; i++) { + // generate data for all columns in the schema + std::vector<ArrayPtr> columns; + for (int col = 0; col < num_fields; col++) { + std::vector<C_TYPE> data = GenerateData<C_TYPE>(batch_size, data_generator); + std::vector<bool> validity(batch_size, true); + ArrayPtr col_data = + MakeArrowArray<TYPE, C_TYPE>(schema->field(col)->type(), data, validity); + + columns.push_back(col_data); + } + + // make the record batch + std::shared_ptr<arrow::RecordBatch> batch = + arrow::RecordBatch::Make(schema, batch_size, columns); + batches[i] = batch; + } + + for (auto _ : state) { + int num_in_batch = batch_size; + num_remaining = num_records; + while (num_remaining > 0) { + if (batch_size > num_remaining) { + num_in_batch = num_remaining; + } + + status = evaluator.Evaluate(*(batches[num_calls % NUM_BATCHES]), pool); + if (!status.ok()) { + state.SkipWithError("Evaluation of the batch 
failed"); + return status; + } + + num_calls++; + num_remaining -= num_in_batch; + } + } + + return Status::OK(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/to_string_test.cc b/src/arrow/cpp/src/gandiva/tests/to_string_test.cc new file mode 100644 index 000000000..55db6e92b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/to_string_test.cc @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <gtest/gtest.h> +#include <math.h> +#include <time.h> +#include "arrow/memory_pool.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float64; +using arrow::int32; +using arrow::int64; + +class TestToString : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +#define CHECK_EXPR_TO_STRING(e, str) EXPECT_STREQ(e->ToString().c_str(), str) + +TEST_F(TestToString, TestAll) { + auto literal_node = TreeExprBuilder::MakeLiteral((uint64_t)100); + auto literal_expr = + TreeExprBuilder::MakeExpression(literal_node, arrow::field("r", int64())); + CHECK_EXPR_TO_STRING(literal_expr, "(const uint64) 100"); + + auto f0 = arrow::field("f0", float64()); + auto f0_node = TreeExprBuilder::MakeField(f0); + auto f0_expr = TreeExprBuilder::MakeExpression(f0_node, f0); + CHECK_EXPR_TO_STRING(f0_expr, "(double) f0"); + + auto f1 = arrow::field("f1", int64()); + auto f2 = arrow::field("f2", int64()); + auto f1_node = TreeExprBuilder::MakeField(f1); + auto f2_node = TreeExprBuilder::MakeField(f2); + auto add_node = TreeExprBuilder::MakeFunction("add", {f1_node, f2_node}, int64()); + auto add_expr = TreeExprBuilder::MakeExpression(add_node, f1); + CHECK_EXPR_TO_STRING(add_expr, "int64 add((int64) f1, (int64) f2)"); + + auto cond_node = TreeExprBuilder::MakeFunction( + "lesser_than", {f0_node, TreeExprBuilder::MakeLiteral(static_cast<float>(0))}, + boolean()); + auto then_node = TreeExprBuilder::MakeField(f1); + auto else_node = TreeExprBuilder::MakeField(f2); + + auto if_node = TreeExprBuilder::MakeIf(cond_node, then_node, else_node, int64()); + auto if_expr = TreeExprBuilder::MakeExpression(if_node, f1); + + CHECK_EXPR_TO_STRING(if_expr, + "if (bool lesser_than((double) f0, (const float) 0 raw(0))) { " + "(int64) f1 } else { (int64) f2 }"); + + auto f1_gt_100 = + 
TreeExprBuilder::MakeFunction("greater_than", {f1_node, literal_node}, boolean()); + auto f2_equals_100 = + TreeExprBuilder::MakeFunction("equals", {f2_node, literal_node}, boolean()); + auto and_node = TreeExprBuilder::MakeAnd({f1_gt_100, f2_equals_100}); + auto and_expr = + TreeExprBuilder::MakeExpression(and_node, arrow::field("f0", boolean())); + + CHECK_EXPR_TO_STRING(and_expr, + "bool greater_than((int64) f1, (const uint64) 100) && bool " + "equals((int64) f2, (const uint64) 100)"); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tests/utf8_test.cc b/src/arrow/cpp/src/gandiva/tests/utf8_test.cc new file mode 100644 index 000000000..e19d6712d --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/utf8_test.cc @@ -0,0 +1,751 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "arrow/status.h" + +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::date64; +using arrow::int32; +using arrow::int64; +using arrow::utf8; + +class TestUtf8 : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestUtf8, TestSimple) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_1 = field("res1", int32()); + auto res_2 = field("res2", boolean()); + auto res_3 = field("res3", int32()); + + // build expressions. + // octet_length(a) + // octet_length(a) == bit_length(a) / 8 + // length(a) + auto expr_a = TreeExprBuilder::MakeExpression("octet_length", {field_a}, res_1); + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto octet_length = TreeExprBuilder::MakeFunction("octet_length", {node_a}, int32()); + auto literal_8 = TreeExprBuilder::MakeLiteral((int32_t)8); + auto bit_length = TreeExprBuilder::MakeFunction("bit_length", {node_a}, int32()); + auto div_8 = TreeExprBuilder::MakeFunction("divide", {bit_length, literal_8}, int32()); + auto is_equal = + TreeExprBuilder::MakeFunction("equal", {octet_length, div_8}, boolean()); + auto expr_b = TreeExprBuilder::MakeExpression(is_equal, res_2); + auto expr_c = TreeExprBuilder::MakeExpression("length", {field_a}, res_3); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = + Projector::Make(schema, {expr_a, expr_b, expr_c}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + auto array_a = MakeArrowArrayUtf8({"foo", "hello", "bye", "hi", "मदन"}, + {true, true, false, true, true}); + + // expected output + auto exp_1 = MakeArrowArrayInt32({3, 5, 0, 2, 9}, {true, true, false, true, true}); + auto exp_2 = MakeArrowArrayBool({true, true, false, true, true}, + {true, true, false, true, true}); + auto exp_3 = MakeArrowArrayInt32({3, 5, 0, 2, 3}, {true, true, false, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_1, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_2, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(exp_3, outputs.at(2)); +} + +TEST_F(TestUtf8, TestLiteral) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + // a == literal(s) + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_s = TreeExprBuilder::MakeStringLiteral("hello"); + auto is_equal = TreeExprBuilder::MakeFunction("equal", {node_a, literal_s}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(is_equal, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {true, true, true, false}); + + // expected output + auto exp = MakeArrowArrayBool({false, true, false, false}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestUtf8, TestNullLiteral) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + // a == literal(null) + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_null = TreeExprBuilder::MakeNull(arrow::utf8()); + auto is_equal = + TreeExprBuilder::MakeFunction("equal", {node_a, literal_null}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(is_equal, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayUtf8({"foo", "hello", "bye", "hi"}, {true, true, true, false}); + + // expected output + auto exp = + MakeArrowArrayBool({false, false, false, false}, {false, false, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestUtf8, TestLike) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + // like(literal(s), a) + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_s = TreeExprBuilder::MakeStringLiteral("%spark%"); + auto is_like = TreeExprBuilder::MakeFunction("like", {node_a, literal_s}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(is_like, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayUtf8({"park", "sparkle", "bright spark and fire", "spark"}, + {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayBool({false, true, true, true}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestUtf8, TestLikeWithEscape) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + // like(literal(s), a, '\') + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_s = TreeExprBuilder::MakeStringLiteral("%pa\\%rk%"); + auto escape_char = TreeExprBuilder::MakeStringLiteral("\\"); + auto is_like = + TreeExprBuilder::MakeFunction("like", {node_a, literal_s, escape_char}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(is_like, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayUtf8( + {"park", "spa%rkle", "bright spa%rk and fire", "spark"}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayBool({false, true, true, false}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestUtf8, TestBeginsEnds) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res1 = field("res1", boolean()); + auto res2 = field("res2", boolean()); + + // build expressions. + // like(literal("spark%"), a) + // like(literal("%spark"), a) + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto literal_begin = TreeExprBuilder::MakeStringLiteral("spark%"); + auto is_like1 = + TreeExprBuilder::MakeFunction("like", {node_a, literal_begin}, boolean()); + auto expr1 = TreeExprBuilder::MakeExpression(is_like1, res1); + + auto literal_end = TreeExprBuilder::MakeStringLiteral("%spark"); + auto is_like2 = TreeExprBuilder::MakeFunction("like", {node_a, literal_end}, boolean()); + auto expr2 = TreeExprBuilder::MakeExpression(is_like2, res2); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr1, expr2}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = + MakeArrowArrayUtf8({"park", "sparkle", "bright spark and fire", "fiery spark"}, + {true, true, true, true}); + + // expected output + auto exp1 = MakeArrowArrayBool({false, true, false, false}, {true, true, true, true}); + auto exp2 = MakeArrowArrayBool({false, false, false, true}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp1, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp2, outputs.at(1)); +} + +TEST_F(TestUtf8, TestInternalAllocs) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + // like(upper(a), literal("%SPARK%")) + + auto node_a = TreeExprBuilder::MakeField(field_a); + auto upper_a = TreeExprBuilder::MakeFunction("upper", {node_a}, utf8()); + auto literal_spark = TreeExprBuilder::MakeStringLiteral("%SPARK%"); + auto is_like = + TreeExprBuilder::MakeFunction("like", {upper_a, literal_spark}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(is_like, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + auto array_a = MakeArrowArrayUtf8( + {"park", "Sparkle", "bright spark and fire", "fiery SPARK", "मदन"}, + {true, true, false, true, true}); + + // expected output + auto exp = MakeArrowArrayBool({false, true, false, true, false}, + {true, true, false, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestUtf8, TestCastDate) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_1 = field("res1", int64()); + + // build expressions. + // extractYear(castDATE(a)) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto cast_function = TreeExprBuilder::MakeFunction("castDATE", {node_a}, date64()); + auto extract_year = + TreeExprBuilder::MakeFunction("extractYear", {cast_function}, int64()); + auto expr = TreeExprBuilder::MakeExpression(extract_year, res_1); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayUtf8({"1967-12-1", "67-12-01", "incorrect", "67-45-11"}, + {true, true, false, true}); + + // expected output + auto exp_1 = MakeArrowArrayInt64({1967, 2067, 0, 0}, {true, true, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_EQ(status.code(), StatusCode::ExecutionError); + std::string expected_error = "Not a valid date value "; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); + + auto array_a_2 = MakeArrowArrayUtf8({"1967-12-1", "67-12-01", "67-1-1", "91-1-1"}, + {true, true, true, true}); + auto exp_2 = MakeArrowArrayInt64({1967, 2067, 2067, 1991}, {true, true, true, true}); + auto in_batch_2 = arrow::RecordBatch::Make(schema, num_records, {array_a_2}); + arrow::ArrayVector outputs2; + status = projector->Evaluate(*in_batch_2, pool_, &outputs2); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_2, outputs2.at(0)); +} + +TEST_F(TestUtf8, TestToDateNoError) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // output fields + auto res_1 = field("res1", int64()); + + // build expressions. 
+ // extractYear(castDATE(a)) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_b = TreeExprBuilder::MakeStringLiteral("YYYY-MM-DD"); + auto node_c = TreeExprBuilder::MakeLiteral(1); + + auto cast_function = + TreeExprBuilder::MakeFunction("to_date", {node_a, node_b, node_c}, date64()); + auto extract_year = + TreeExprBuilder::MakeFunction("extractYear", {cast_function}, int64()); + auto expr = TreeExprBuilder::MakeExpression(extract_year, res_1); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayUtf8({"1967-12-1", "67-12-01", "incorrect", "67-45-11"}, + {true, true, false, true}); + + // expected output + auto exp_1 = MakeArrowArrayInt64({1967, 67, 0, 0}, {true, true, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + EXPECT_ARROW_ARRAY_EQUALS(exp_1, outputs.at(0)); + + // Create a row-batch with some sample data + auto array_a_2 = MakeArrowArrayUtf8( + {"1967-12-1", "1967-12-01", "1967-11-11", "1991-11-11"}, {true, true, true, true}); + auto exp_2 = MakeArrowArrayInt64({1967, 1967, 1967, 1991}, {true, true, true, true}); + auto in_batch_2 = arrow::RecordBatch::Make(schema, num_records, {array_a_2}); + arrow::ArrayVector outputs2; + status = projector->Evaluate(*in_batch_2, pool_, &outputs2); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_2, outputs2.at(0)); +} + +TEST_F(TestUtf8, TestToDateError) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = 
arrow::schema({field_a}); + + // output fields + auto res_1 = field("res1", int64()); + + // build expressions. + // extractYear(castDATE(a)) + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_b = TreeExprBuilder::MakeStringLiteral("YYYY-MM-DD"); + auto node_c = TreeExprBuilder::MakeLiteral(0); + + auto cast_function = + TreeExprBuilder::MakeFunction("to_date", {node_a, node_b, node_c}, date64()); + auto extract_year = + TreeExprBuilder::MakeFunction("extractYear", {cast_function}, int64()); + auto expr = TreeExprBuilder::MakeExpression(extract_year, res_1); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayUtf8({"1967-12-1", "67-12-01", "incorrect", "67-45-11"}, + {true, true, false, true}); + + // expected output + auto exp_1 = MakeArrowArrayInt64({1967, 67, 0, 0}, {true, true, false, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_EQ(status.code(), StatusCode::ExecutionError); + std::string expected_error = "Error parsing value 67-45-11 for given format"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos) + << status.message(); +} + +TEST_F(TestUtf8, TestIsNull) { + // schema for input fields + auto field_a = field("a", utf8()); + auto schema = arrow::schema({field_a}); + + // build expressions + auto exprs = std::vector<ExpressionPtr>{ + TreeExprBuilder::MakeExpression("isnull", {field_a}, field("is_null", boolean())), + TreeExprBuilder::MakeExpression("isnotnull", {field_a}, + field("is_not_null", boolean())), + }; + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, exprs, TestConfiguration(), &projector); + DCHECK_OK(status); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_a = MakeArrowArrayUtf8({"hello", "world", "incorrect", "universe"}, + {true, true, false, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + + // validate results + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({false, false, true, false}), + outputs[0]); // isnull + EXPECT_ARROW_ARRAY_EQUALS(MakeArrowArrayBool({true, true, false, true}), + outputs[1]); // isnotnull +} + +TEST_F(TestUtf8, TestVarlenOutput) { + // schema for input fields + auto field_a = field("a", boolean()); + auto schema = arrow::schema({field_a}); + + // build expressions. + // if (a) literal_hi else literal_bye + auto if_node = TreeExprBuilder::MakeIf( + TreeExprBuilder::MakeField(field_a), TreeExprBuilder::MakeStringLiteral("hi"), + TreeExprBuilder::MakeStringLiteral("bye"), utf8()); + auto expr = TreeExprBuilder::MakeExpression(if_node, field("res", utf8())); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + + // assert that it fails gracefully. 
+ ASSERT_OK(Projector::Make(schema, {expr}, TestConfiguration(), &projector)); + + // Create a row-batch with some sample data + int num_records = 4; + auto array_in = + MakeArrowArrayBool({true, false, false, false}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_in}); + + // Evaluate expression + arrow::ArrayVector outputs; + ASSERT_OK(projector->Evaluate(*in_batch, pool_, &outputs)); + + // expected output + auto exp = MakeArrowArrayUtf8({"hi", "bye", "bye", "bye"}, {true, true, true, true}); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestUtf8, TestConvertUtf8) { + // schema for input fields + auto field_a = field("a", arrow::binary()); + auto field_c = field("c", utf8()); + auto schema = arrow::schema({field_a, field_c}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_c = TreeExprBuilder::MakeField(field_c); + + // define char to replace + auto node_b = TreeExprBuilder::MakeStringLiteral("z"); + + auto convert_replace_utf8 = + TreeExprBuilder::MakeFunction("convert_replaceUTF8", {node_a, node_b}, utf8()); + auto equals = + TreeExprBuilder::MakeFunction("equal", {convert_replace_utf8, node_c}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(equals, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 3; + auto array_a = MakeArrowArrayUtf8({"ok-\xf8\x28" + "-a", + "all-valid", "ok-\xa0\xa1-valid"}, + {true, true, true}); + + auto array_b = + MakeArrowArrayUtf8({"ok-z(-a", "all-valid", "ok-zz-valid"}, {true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + auto exp = MakeArrowArrayBool({true, true, true}, {true, true, true}); + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]); +} + +TEST_F(TestUtf8, TestCastVarChar) { + // schema for input fields + auto field_a = field("a", utf8()); + auto field_c = field("c", utf8()); + auto schema = arrow::schema({field_a, field_c}); + + // output fields + auto res = field("res", boolean()); + + // build expressions. + auto node_a = TreeExprBuilder::MakeField(field_a); + auto node_c = TreeExprBuilder::MakeField(field_c); + // truncates the string to input length + auto node_b = TreeExprBuilder::MakeLiteral(static_cast<int64_t>(10)); + auto cast_varchar = + TreeExprBuilder::MakeFunction("castVARCHAR", {node_a, node_b}, utf8()); + auto equals = TreeExprBuilder::MakeFunction("equal", {cast_varchar, node_c}, boolean()); + auto expr = TreeExprBuilder::MakeExpression(equals, res); + + // Build a projector for the expressions. 
+ std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 5; + auto array_a = MakeArrowArrayUtf8( + {"park", "Sparkle", "bright spark and fire", "fiery SPARK", "मदन"}, + {true, true, false, true, true}); + + auto array_b = + MakeArrowArrayUtf8({"park", "Sparkle", "bright spar", "fiery SPAR", "मदन"}, + {true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + auto exp = MakeArrowArrayBool({true, true, false, true, true}, + {true, true, false, true, true}); + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs[0]); +} + +TEST_F(TestUtf8, TestAscii) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_asc = field("ascii", arrow::int32()); + + // Build expression + auto asc_expr = TreeExprBuilder::MakeExpression("ascii", {field0}, field_asc); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {asc_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 6; + auto array0 = MakeArrowArrayUtf8({"ABC", "", "abc", "Hello World", "123", "999"}, + {true, true, true, true, true, true}); + // expected output + auto exp_asc = + MakeArrowArrayInt32({65, 0, 97, 72, 49, 57}, {true, true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = 
projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_asc, outputs.at(0)); +} + +TEST_F(TestUtf8, TestSpace) { + // schema for input fields + auto field0 = field("f0", arrow::int64()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_space = field("space", arrow::utf8()); + + // Build expression + auto space_expr = TreeExprBuilder::MakeExpression("space", {field0}, field_space); + + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {space_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt64({1, 0, -5, 2}, {true, true, true, true}); + // expected output + auto exp_space = MakeArrowArrayUtf8({" ", "", "", " "}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_space, outputs.at(0)); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/to_date_holder.cc b/src/arrow/cpp/src/gandiva/to_date_holder.cc new file mode 100644 index 000000000..1b7e2864f --- /dev/null +++ b/src/arrow/cpp/src/gandiva/to_date_holder.cc @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/to_date_holder.h" + +#include <algorithm> +#include <string> + +#include "arrow/util/value_parsing.h" +#include "arrow/vendored/datetime.h" +#include "gandiva/date_utils.h" +#include "gandiva/execution_context.h" +#include "gandiva/node.h" + +namespace gandiva { + +Status ToDateHolder::Make(const FunctionNode& node, + std::shared_ptr<ToDateHolder>* holder) { + if (node.children().size() != 2 && node.children().size() != 3) { + return Status::Invalid("'to_date' function requires two or three parameters"); + } + + auto literal_pattern = dynamic_cast<LiteralNode*>(node.children().at(1).get()); + if (literal_pattern == nullptr) { + return Status::Invalid( + "'to_date' function requires a literal as the second parameter"); + } + + auto literal_type = literal_pattern->return_type()->id(); + if (literal_type != arrow::Type::STRING && literal_type != arrow::Type::BINARY) { + return Status::Invalid( + "'to_date' function requires a string literal as the second parameter"); + } + auto pattern = arrow::util::get<std::string>(literal_pattern->holder()); + + int suppress_errors = 0; + if (node.children().size() == 3) { + auto literal_suppress_errors = + dynamic_cast<LiteralNode*>(node.children().at(2).get()); + if (literal_pattern == nullptr) { + return Status::Invalid( + "The (optional) third parameter to 'to_date' function needs to an integer " + "literal to indicate whether to suppress the error"); + } + + literal_type = literal_suppress_errors->return_type()->id(); + if (literal_type != arrow::Type::INT32) { + return Status::Invalid( + 
"The (optional) third parameter to 'to_date' function needs to an integer " + "literal to indicate whether to suppress the error"); + } + suppress_errors = arrow::util::get<int>(literal_suppress_errors->holder()); + } + + return Make(pattern, suppress_errors, holder); +} + +Status ToDateHolder::Make(const std::string& sql_pattern, int32_t suppress_errors, + std::shared_ptr<ToDateHolder>* holder) { + std::shared_ptr<std::string> transformed_pattern; + ARROW_RETURN_NOT_OK(DateUtils::ToInternalFormat(sql_pattern, &transformed_pattern)); + auto lholder = std::shared_ptr<ToDateHolder>( + new ToDateHolder(*(transformed_pattern.get()), suppress_errors)); + *holder = lholder; + return Status::OK(); +} + +int64_t ToDateHolder::operator()(ExecutionContext* context, const char* data, + int data_len, bool in_valid, bool* out_valid) { + *out_valid = false; + if (!in_valid) { + return 0; + } + + // Issues + // 1. processes date that do not match the format. + // 2. does not process time in format +08:00 (or) id. 
+ int64_t seconds_since_epoch = 0; + if (!::arrow::internal::ParseTimestampStrptime( + data, data_len, pattern_.c_str(), + /*ignore_time_in_day=*/true, /*allow_trailing_chars=*/true, + ::arrow::TimeUnit::SECOND, &seconds_since_epoch)) { + return_error(context, data, data_len); + return 0; + } + + *out_valid = true; + return seconds_since_epoch * 1000; +} + +void ToDateHolder::return_error(ExecutionContext* context, const char* data, + int data_len) { + if (suppress_errors_ == 1) { + return; + } + + std::string err_msg = + "Error parsing value " + std::string(data, data_len) + " for given format."; + context->set_error_msg(err_msg.c_str()); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/to_date_holder.h b/src/arrow/cpp/src/gandiva/to_date_holder.h new file mode 100644 index 000000000..1211b6a30 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/to_date_holder.h @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
#pragma once

#include <memory>
#include <string>
#include <unordered_map>

#include "arrow/status.h"

#include "gandiva/execution_context.h"
#include "gandiva/function_holder.h"
#include "gandiva/node.h"
#include "gandiva/visibility.h"

namespace gandiva {

/// Function Holder for SQL 'to_date'
///
/// Holds the date pattern and the suppress-errors flag of one 'to_date' call
/// and parses input strings with them.
class GANDIVA_EXPORT ToDateHolder : public FunctionHolder {
 public:
  ~ToDateHolder() override = default;

  /// Build a holder from a 'to_date' function node. The node needs two or three
  /// children: the second must be a string/binary literal (the pattern), the
  /// optional third an int32 literal (the suppress-errors flag).
  static Status Make(const FunctionNode& node, std::shared_ptr<ToDateHolder>* holder);

  /// Build a holder directly from an SQL date pattern and a suppress-errors
  /// flag. Fails if the pattern cannot be converted to the internal format.
  static Status Make(const std::string& sql_pattern, int32_t suppress_errors,
                     std::shared_ptr<ToDateHolder>* holder);

  /// Parse `data` (`data_len` bytes) with the holder's pattern.
  ///
  /// Returns the parsed value as seconds since the epoch multiplied by 1000.
  /// Returns 0 with `*out_valid` set to false when `in_valid` is false or when
  /// parsing fails; on a parse failure an error message is set on `context`
  /// unless errors are suppressed.
  int64_t operator()(ExecutionContext* context, const char* data, int data_len,
                     bool in_valid, bool* out_valid);

 private:
  ToDateHolder(const std::string& pattern, int32_t suppress_errors)
      : pattern_(pattern), suppress_errors_(suppress_errors) {}

  /// Set a parse-error message on `context`, unless suppress_errors_ is 1.
  void return_error(ExecutionContext* context, const char* data, int data_len);

  std::string pattern_;  // date format string (already in internal format)

  int32_t suppress_errors_;  // 1: parse failures set no error on the context
};

}  // namespace gandiva
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <memory> +#include <vector> + +#include "arrow/testing/gtest_util.h" + +#include "gandiva/execution_context.h" +#include "gandiva/to_date_holder.h" +#include "precompiled/epoch_time_point.h" + +#include <gtest/gtest.h> + +namespace gandiva { + +class TestToDateHolder : public ::testing::Test { + public: + FunctionNode BuildToDate(std::string pattern) { + auto field = std::make_shared<FieldNode>(arrow::field("in", arrow::utf8())); + auto pattern_node = + std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(pattern), false); + auto suppress_error_node = + std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(0), false); + return FunctionNode("to_date_utf8_utf8_int32", + {field, pattern_node, suppress_error_node}, arrow::int64()); + } + + protected: + ExecutionContext execution_context_; +}; + +TEST_F(TestToDateHolder, TestSimpleDateTime) { + std::shared_ptr<ToDateHolder> to_date_holder; + ASSERT_OK(ToDateHolder::Make("YYYY-MM-DD HH:MI:SS", 1, &to_date_holder)); + + auto& to_date = *to_date_holder; + bool out_valid; + std::string s("1986-12-01 01:01:01"); + int64_t millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(millis_since_epoch, 533779200000); + + s = std::string("1986-12-01 01:01:01.11"); + millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(millis_since_epoch, 533779200000); + + s = std::string("1986-12-01 01:01:01 +0800"); + millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), 
true, &out_valid); + EXPECT_EQ(millis_since_epoch, 533779200000); + +#if 0 + // TODO : this fails parsing with date::parse and strptime on linux + s = std::string("1886-12-01 00:00:00"); + millis_since_epoch = + to_date(&execution_context_, s.data(), (int) s.length(), true, &out_valid); + EXPECT_EQ(out_valid, true); + EXPECT_EQ(millis_since_epoch, -2621894400000); +#endif + + s = std::string("1886-12-01 01:01:01"); + millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(millis_since_epoch, -2621894400000); + + s = std::string("1986-12-11 01:30:00"); + millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(millis_since_epoch, 534643200000); +} + +TEST_F(TestToDateHolder, TestSimpleDate) { + std::shared_ptr<ToDateHolder> to_date_holder; + ASSERT_OK(ToDateHolder::Make("YYYY-MM-DD", 1, &to_date_holder)); + + auto& to_date = *to_date_holder; + bool out_valid; + std::string s("1986-12-01"); + int64_t millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(millis_since_epoch, 533779200000); + + s = std::string("1986-12-01"); + millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(millis_since_epoch, 533779200000); + + s = std::string("1886-12-1"); + millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(millis_since_epoch, -2621894400000); + + s = std::string("2012-12-1"); + millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(millis_since_epoch, 1354320000000); + + // wrong month. should return 0 since we are suppressing errors. 
+ s = std::string("1986-21-01 01:01:01 +0800"); + millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(millis_since_epoch, 0); +} + +TEST_F(TestToDateHolder, TestSimpleDateTimeError) { + std::shared_ptr<ToDateHolder> to_date_holder; + + auto status = ToDateHolder::Make("YYYY-MM-DD HH:MI:SS", 0, &to_date_holder); + EXPECT_EQ(status.ok(), true) << status.message(); + auto& to_date = *to_date_holder; + bool out_valid; + + std::string s("1986-01-40 01:01:01 +0800"); + int64_t millis_since_epoch = + to_date(&execution_context_, s.data(), (int)s.length(), true, &out_valid); + EXPECT_EQ(0, millis_since_epoch); + std::string expected_error = + "Error parsing value 1986-01-40 01:01:01 +0800 for given format"; + EXPECT_TRUE(execution_context_.get_error().find(expected_error) != std::string::npos) + << status.message(); + + // not valid should not return error + execution_context_.Reset(); + millis_since_epoch = to_date(&execution_context_, "nullptr", 7, false, &out_valid); + EXPECT_EQ(millis_since_epoch, 0); + EXPECT_TRUE(execution_context_.has_error() == false); +} + +TEST_F(TestToDateHolder, TestSimpleDateTimeMakeError) { + std::shared_ptr<ToDateHolder> to_date_holder; + // reject time stamps for now. + auto status = ToDateHolder::Make("YYYY-MM-DD HH:MI:SS tzo", 0, &to_date_holder); + EXPECT_EQ(status.IsInvalid(), true) << status.message(); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tree_expr_builder.cc b/src/arrow/cpp/src/gandiva/tree_expr_builder.cc new file mode 100644 index 000000000..de8e3445a --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tree_expr_builder.cc @@ -0,0 +1,223 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/tree_expr_builder.h" + +#include <iostream> +#include <utility> + +#include "gandiva/decimal_type_util.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/node.h" + +namespace gandiva { + +#define MAKE_LITERAL(atype, ctype) \ + NodePtr TreeExprBuilder::MakeLiteral(ctype value) { \ + return std::make_shared<LiteralNode>(atype, LiteralHolder(value), false); \ + } + +MAKE_LITERAL(arrow::boolean(), bool) +MAKE_LITERAL(arrow::int8(), int8_t) +MAKE_LITERAL(arrow::int16(), int16_t) +MAKE_LITERAL(arrow::int32(), int32_t) +MAKE_LITERAL(arrow::int64(), int64_t) +MAKE_LITERAL(arrow::uint8(), uint8_t) +MAKE_LITERAL(arrow::uint16(), uint16_t) +MAKE_LITERAL(arrow::uint32(), uint32_t) +MAKE_LITERAL(arrow::uint64(), uint64_t) +MAKE_LITERAL(arrow::float32(), float) +MAKE_LITERAL(arrow::float64(), double) + +NodePtr TreeExprBuilder::MakeStringLiteral(const std::string& value) { + return std::make_shared<LiteralNode>(arrow::utf8(), LiteralHolder(value), false); +} + +NodePtr TreeExprBuilder::MakeBinaryLiteral(const std::string& value) { + return std::make_shared<LiteralNode>(arrow::binary(), LiteralHolder(value), false); +} + +NodePtr TreeExprBuilder::MakeDecimalLiteral(const DecimalScalar128& value) { + return std::make_shared<LiteralNode>(arrow::decimal(value.precision(), value.scale()), + LiteralHolder(value), false); +} + +NodePtr TreeExprBuilder::MakeNull(DataTypePtr 
data_type) { + static const std::string empty; + + if (data_type == nullptr) { + return nullptr; + } + + switch (data_type->id()) { + case arrow::Type::BOOL: + return std::make_shared<LiteralNode>(data_type, LiteralHolder(false), true); + case arrow::Type::INT8: + return std::make_shared<LiteralNode>(data_type, LiteralHolder((int8_t)0), true); + case arrow::Type::INT16: + return std::make_shared<LiteralNode>(data_type, LiteralHolder((int16_t)0), true); + return std::make_shared<LiteralNode>(data_type, LiteralHolder((int32_t)0), true); + case arrow::Type::UINT8: + return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint8_t)0), true); + case arrow::Type::UINT16: + return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint16_t)0), true); + case arrow::Type::UINT32: + return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint32_t)0), true); + case arrow::Type::UINT64: + return std::make_shared<LiteralNode>(data_type, LiteralHolder((uint64_t)0), true); + case arrow::Type::FLOAT: + return std::make_shared<LiteralNode>(data_type, + LiteralHolder(static_cast<float>(0)), true); + case arrow::Type::DOUBLE: + return std::make_shared<LiteralNode>(data_type, + LiteralHolder(static_cast<double>(0)), true); + case arrow::Type::STRING: + case arrow::Type::BINARY: + return std::make_shared<LiteralNode>(data_type, LiteralHolder(empty), true); + case arrow::Type::INT32: + case arrow::Type::DATE32: + case arrow::Type::TIME32: + case arrow::Type::INTERVAL_MONTHS: + return std::make_shared<LiteralNode>(data_type, LiteralHolder((int32_t)0), true); + case arrow::Type::INT64: + case arrow::Type::DATE64: + case arrow::Type::TIME64: + case arrow::Type::TIMESTAMP: + case arrow::Type::INTERVAL_DAY_TIME: + return std::make_shared<LiteralNode>(data_type, LiteralHolder((int64_t)0), true); + case arrow::Type::DECIMAL: { + std::shared_ptr<arrow::DecimalType> decimal_type = + arrow::internal::checked_pointer_cast<arrow::DecimalType>(data_type); + DecimalScalar128 
literal(decimal_type->precision(), decimal_type->scale()); + return std::make_shared<LiteralNode>(data_type, LiteralHolder(literal), true); + } + default: + return nullptr; + } +} + +NodePtr TreeExprBuilder::MakeField(FieldPtr field) { + return NodePtr(new FieldNode(field)); +} + +NodePtr TreeExprBuilder::MakeFunction(const std::string& name, const NodeVector& params, + DataTypePtr result_type) { + if (result_type == nullptr) { + return nullptr; + } + return std::make_shared<FunctionNode>(name, params, result_type); +} + +NodePtr TreeExprBuilder::MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node, + DataTypePtr result_type) { + if (condition == nullptr || then_node == nullptr || else_node == nullptr || + result_type == nullptr) { + return nullptr; + } + return std::make_shared<IfNode>(condition, then_node, else_node, result_type); +} + +NodePtr TreeExprBuilder::MakeAnd(const NodeVector& children) { + return std::make_shared<BooleanNode>(BooleanNode::AND, children); +} + +NodePtr TreeExprBuilder::MakeOr(const NodeVector& children) { + return std::make_shared<BooleanNode>(BooleanNode::OR, children); +} + +// set this to true to print expressions for debugging purposes +static bool print_expr = false; + +ExpressionPtr TreeExprBuilder::MakeExpression(NodePtr root_node, FieldPtr result_field) { + if (result_field == nullptr) { + return nullptr; + } + if (print_expr) { + std::cout << "Expression: " << root_node->ToString() << "\n"; + } + return ExpressionPtr(new Expression(root_node, result_field)); +} + +ExpressionPtr TreeExprBuilder::MakeExpression(const std::string& function, + const FieldVector& in_fields, + FieldPtr out_field) { + if (out_field == nullptr) { + return nullptr; + } + std::vector<NodePtr> field_nodes; + for (auto& field : in_fields) { + auto node = MakeField(field); + field_nodes.push_back(node); + } + auto func_node = MakeFunction(function, field_nodes, out_field->type()); + return MakeExpression(func_node, out_field); +} + +ConditionPtr 
TreeExprBuilder::MakeCondition(NodePtr root_node) { + if (root_node == nullptr) { + return nullptr; + } + if (print_expr) { + std::cout << "Condition: " << root_node->ToString() << "\n"; + } + + return ConditionPtr(new Condition(root_node)); +} + +ConditionPtr TreeExprBuilder::MakeCondition(const std::string& function, + const FieldVector& in_fields) { + std::vector<NodePtr> field_nodes; + for (auto& field : in_fields) { + auto node = MakeField(field); + field_nodes.push_back(node); + } + + auto func_node = MakeFunction(function, field_nodes, arrow::boolean()); + return ConditionPtr(new Condition(func_node)); +} + +NodePtr TreeExprBuilder::MakeInExpressionDecimal( + NodePtr node, std::unordered_set<gandiva::DecimalScalar128>& constants) { + int32_t precision = 0; + int32_t scale = 0; + if (!constants.empty()) { + precision = constants.begin()->precision(); + scale = constants.begin()->scale(); + } + return std::make_shared<InExpressionNode<gandiva::DecimalScalar128>>(node, constants, + precision, scale); +} + +#define MAKE_IN(NAME, ctype) \ + NodePtr TreeExprBuilder::MakeInExpression##NAME( \ + NodePtr node, const std::unordered_set<ctype>& values) { \ + return std::make_shared<InExpressionNode<ctype>>(node, values); \ + } + +MAKE_IN(Int32, int32_t); +MAKE_IN(Int64, int64_t); +MAKE_IN(Date32, int32_t); +MAKE_IN(Date64, int64_t); +MAKE_IN(TimeStamp, int64_t); +MAKE_IN(Time32, int32_t); +MAKE_IN(Time64, int64_t); +MAKE_IN(Float, float); +MAKE_IN(Double, double); +MAKE_IN(String, std::string); +MAKE_IN(Binary, std::string); + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/tree_expr_builder.h b/src/arrow/cpp/src/gandiva/tree_expr_builder.h new file mode 100644 index 000000000..94a4a1793 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tree_expr_builder.h @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
#pragma once

#include <cmath>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "arrow/type.h"
#include "gandiva/condition.h"
#include "gandiva/decimal_scalar.h"
#include "gandiva/expression.h"
#include "gandiva/visibility.h"

namespace gandiva {

/// \brief Tree Builder for a nested expression.
///
/// All members are static factories; nodes are combined bottom-up (fields and
/// literals first, then functions / if / boolean nodes) and finally wrapped in
/// an Expression or a Condition.
class GANDIVA_EXPORT TreeExprBuilder {
 public:
  /// \brief create a node on a literal.
  /// One overload per supported C type; the resulting literal is non-null.
  static NodePtr MakeLiteral(bool value);
  static NodePtr MakeLiteral(uint8_t value);
  static NodePtr MakeLiteral(uint16_t value);
  static NodePtr MakeLiteral(uint32_t value);
  static NodePtr MakeLiteral(uint64_t value);
  static NodePtr MakeLiteral(int8_t value);
  static NodePtr MakeLiteral(int16_t value);
  static NodePtr MakeLiteral(int32_t value);
  static NodePtr MakeLiteral(int64_t value);
  static NodePtr MakeLiteral(float value);
  static NodePtr MakeLiteral(double value);
  /// \brief create a utf8 literal node.
  static NodePtr MakeStringLiteral(const std::string& value);
  /// \brief create a binary literal node.
  static NodePtr MakeBinaryLiteral(const std::string& value);
  /// \brief create a decimal literal node; the type uses the precision and
  /// scale of the value.
  static NodePtr MakeDecimalLiteral(const DecimalScalar128& value);

  /// \brief create a node on a null literal.
  /// returns null if data_type is null or if it's not a supported datatype.
  static NodePtr MakeNull(DataTypePtr data_type);

  /// \brief create a node on arrow field.
  /// returns null if input is null.
  static NodePtr MakeField(FieldPtr field);

  /// \brief create a node with a function.
  /// returns null if return_type is null
  static NodePtr MakeFunction(const std::string& name, const NodeVector& params,
                              DataTypePtr return_type);

  /// \brief create a node with an if-else expression.
  /// returns null if any of the inputs is null.
  static NodePtr MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node,
                        DataTypePtr result_type);

  /// \brief create a node with a boolean AND expression.
  static NodePtr MakeAnd(const NodeVector& children);

  /// \brief create a node with a boolean OR expression.
  static NodePtr MakeOr(const NodeVector& children);

  /// \brief create an expression with the specified root_node, and the
  /// result written to result_field.
  /// returns null if the result_field is null.
  static ExpressionPtr MakeExpression(NodePtr root_node, FieldPtr result_field);

  /// \brief convenience function for simple function expressions.
  /// Each input field becomes a field node passed, in order, to `function`;
  /// the function's return type is the type of out_field.
  /// returns null if the out_field is null.
  static ExpressionPtr MakeExpression(const std::string& function,
                                      const FieldVector& in_fields, FieldPtr out_field);

  /// \brief create a condition with the specified root_node
  /// returns null if root_node is null.
  static ConditionPtr MakeCondition(NodePtr root_node);

  /// \brief convenience function for simple function conditions.
  /// Each input field becomes a field node passed, in order, to `function`,
  /// which is given a boolean return type.
  static ConditionPtr MakeCondition(const std::string& function,
                                    const FieldVector& in_fields);

  /// \brief creates an in expression
  static NodePtr MakeInExpressionInt32(NodePtr node,
                                       const std::unordered_set<int32_t>& constants);

  /// \brief creates an in expression for int64
  static NodePtr MakeInExpressionInt64(NodePtr node,
                                       const std::unordered_set<int64_t>& constants);

  /// \brief creates an in expression for decimals.
  /// Precision and scale are taken from the first element of `constants`
  /// (both 0 when the set is empty). Note the set is taken by non-const
  /// reference, unlike the other MakeInExpression* functions.
  static NodePtr MakeInExpressionDecimal(
      NodePtr node, std::unordered_set<gandiva::DecimalScalar128>& constants);

  /// \brief creates an in expression for utf8 strings
  static NodePtr MakeInExpressionString(NodePtr node,
                                        const std::unordered_set<std::string>& constants);

  /// \brief creates an in expression for binary values
  static NodePtr MakeInExpressionBinary(NodePtr node,
                                        const std::unordered_set<std::string>& constants);

  /// \brief creates an in expression for float
  static NodePtr MakeInExpressionFloat(NodePtr node,
                                       const std::unordered_set<float>& constants);

  /// \brief creates an in expression for double
  static NodePtr MakeInExpressionDouble(NodePtr node,
                                        const std::unordered_set<double>& constants);

  /// \brief Date as s/millis since epoch.
  /// NOTE(review): Arrow's date32 is normally days since the epoch; the unit
  /// stated above looks wrong -- confirm against the Arrow type docs.
  static NodePtr MakeInExpressionDate32(NodePtr node,
                                        const std::unordered_set<int32_t>& constants);

  /// \brief Date as millis/us/ns since epoch.
  /// NOTE(review): Arrow's date64 is normally milliseconds since the epoch
  /// only -- confirm.
  static NodePtr MakeInExpressionDate64(NodePtr node,
                                        const std::unordered_set<int64_t>& constants);

  /// \brief Time as s/millis of day
  static NodePtr MakeInExpressionTime32(NodePtr node,
                                        const std::unordered_set<int32_t>& constants);

  /// \brief Time as millis/us/ns of day
  /// NOTE(review): Arrow's time64 is normally micro/nanoseconds of day -- confirm.
  static NodePtr MakeInExpressionTime64(NodePtr node,
                                        const std::unordered_set<int64_t>& constants);

  /// \brief Timestamp as millis since epoch.
  static NodePtr MakeInExpressionTimeStamp(NodePtr node,
                                           const std::unordered_set<int64_t>& constants);
};

}  // namespace gandiva
+ +#include "gandiva/tree_expr_builder.h" + +#include <gtest/gtest.h> +#include "gandiva/annotator.h" +#include "gandiva/dex.h" +#include "gandiva/expr_decomposer.h" +#include "gandiva/function_registry.h" +#include "gandiva/function_signature.h" +#include "gandiva/gandiva_aliases.h" +#include "gandiva/node.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::int32; + +class TestExprTree : public ::testing::Test { + public: + void SetUp() { + i0_ = field("i0", int32()); + i1_ = field("i1", int32()); + + b0_ = field("b0", boolean()); + } + + protected: + FieldPtr i0_; // int32 + FieldPtr i1_; // int32 + + FieldPtr b0_; // bool + FunctionRegistry registry_; +}; + +TEST_F(TestExprTree, TestField) { + Annotator annotator; + + auto n0 = TreeExprBuilder::MakeField(i0_); + EXPECT_EQ(n0->return_type(), int32()); + + auto n1 = TreeExprBuilder::MakeField(b0_); + EXPECT_EQ(n1->return_type(), boolean()); + + ExprDecomposer decomposer(registry_, annotator); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*n1, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + + auto value = pair->value_expr(); + auto value_dex = std::dynamic_pointer_cast<VectorReadFixedLenValueDex>(value); + EXPECT_EQ(value_dex->FieldType(), boolean()); + + EXPECT_EQ(pair->validity_exprs().size(), 1); + auto validity = pair->validity_exprs().at(0); + auto validity_dex = std::dynamic_pointer_cast<VectorReadValidityDex>(validity); + EXPECT_NE(validity_dex->ValidityIdx(), value_dex->DataIdx()); +} + +TEST_F(TestExprTree, TestBinary) { + Annotator annotator; + + auto left = TreeExprBuilder::MakeField(i0_); + auto right = TreeExprBuilder::MakeField(i1_); + + auto n = TreeExprBuilder::MakeFunction("add", {left, right}, int32()); + auto add = std::dynamic_pointer_cast<FunctionNode>(n); + + auto func_desc = add->descriptor(); + FunctionSignature sign(func_desc->name(), func_desc->params(), + func_desc->return_type()); + + EXPECT_EQ(add->return_type(), int32()); + 
EXPECT_TRUE(sign == FunctionSignature("add", {int32(), int32()}, int32())); + + ExprDecomposer decomposer(registry_, annotator); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*n, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + + auto value = pair->value_expr(); + auto null_if_null = std::dynamic_pointer_cast<NonNullableFuncDex>(value); + + FunctionSignature signature("add", {int32(), int32()}, int32()); + const NativeFunction* fn = registry_.LookupSignature(signature); + EXPECT_EQ(null_if_null->native_function(), fn); +} + +TEST_F(TestExprTree, TestUnary) { + Annotator annotator; + + auto arg = TreeExprBuilder::MakeField(i0_); + auto n = TreeExprBuilder::MakeFunction("isnumeric", {arg}, boolean()); + + auto unaryFn = std::dynamic_pointer_cast<FunctionNode>(n); + auto func_desc = unaryFn->descriptor(); + FunctionSignature sign(func_desc->name(), func_desc->params(), + func_desc->return_type()); + EXPECT_EQ(unaryFn->return_type(), boolean()); + EXPECT_TRUE(sign == FunctionSignature("isnumeric", {int32()}, boolean())); + + ExprDecomposer decomposer(registry_, annotator); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*n, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + + auto value = pair->value_expr(); + auto never_null = std::dynamic_pointer_cast<NullableNeverFuncDex>(value); + + FunctionSignature signature("isnumeric", {int32()}, boolean()); + const NativeFunction* fn = registry_.LookupSignature(signature); + EXPECT_EQ(never_null->native_function(), fn); +} + +TEST_F(TestExprTree, TestExpression) { + Annotator annotator; + auto left = TreeExprBuilder::MakeField(i0_); + auto right = TreeExprBuilder::MakeField(i1_); + + auto n = TreeExprBuilder::MakeFunction("add", {left, right}, int32()); + auto e = TreeExprBuilder::MakeExpression(n, field("r", int32())); + auto root_node = e->root(); + EXPECT_EQ(root_node->return_type(), int32()); + + auto add_node = 
std::dynamic_pointer_cast<FunctionNode>(root_node); + auto func_desc = add_node->descriptor(); + FunctionSignature sign(func_desc->name(), func_desc->params(), + func_desc->return_type()); + EXPECT_TRUE(sign == FunctionSignature("add", {int32(), int32()}, int32())); + + ExprDecomposer decomposer(registry_, annotator); + ValueValidityPairPtr pair; + auto status = decomposer.Decompose(*root_node, &pair); + DCHECK_EQ(status.ok(), true) << status.message(); + + auto value = pair->value_expr(); + auto null_if_null = std::dynamic_pointer_cast<NonNullableFuncDex>(value); + + FunctionSignature signature("add", {int32(), int32()}, int32()); + const NativeFunction* fn = registry_.LookupSignature(signature); + EXPECT_EQ(null_if_null->native_function(), fn); +} + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/value_validity_pair.h b/src/arrow/cpp/src/gandiva/value_validity_pair.h new file mode 100644 index 000000000..e5943b230 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/value_validity_pair.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <vector> + +#include "gandiva/gandiva_aliases.h" +#include "gandiva/visibility.h" + +namespace gandiva { + +/// Pair of vector/validities generated after decomposing an expression tree/subtree. +class GANDIVA_EXPORT ValueValidityPair { + public: + ValueValidityPair(const DexVector& validity_exprs, DexPtr value_expr) + : validity_exprs_(validity_exprs), value_expr_(value_expr) {} + + ValueValidityPair(DexPtr validity_expr, DexPtr value_expr) : value_expr_(value_expr) { + validity_exprs_.push_back(validity_expr); + } + + explicit ValueValidityPair(DexPtr value_expr) : value_expr_(value_expr) {} + + const DexVector& validity_exprs() const { return validity_exprs_; } + + const DexPtr& value_expr() const { return value_expr_; } + + private: + DexVector validity_exprs_; + DexPtr value_expr_; +}; + +} // namespace gandiva diff --git a/src/arrow/cpp/src/gandiva/visibility.h b/src/arrow/cpp/src/gandiva/visibility.h new file mode 100644 index 000000000..450b3056b --- /dev/null +++ b/src/arrow/cpp/src/gandiva/visibility.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 

#pragma once

// Defines the symbol-visibility macros used on Gandiva declarations:
//   GANDIVA_EXPORT    - marks a symbol as part of the library's public interface.
//   GANDIVA_NO_EXPORT - marks a symbol as internal to the library.

#if defined(_WIN32) || defined(__CYGWIN__)
#if defined(_MSC_VER)
// MSVC: save the current warning state, then silence warning 4251.
// NOTE(review): the matching pop is at the end of this header, so the
// suppression only spans this file - confirm that is the intended scope.
#pragma warning(push)
#pragma warning(disable : 4251)
#else
// Other Windows toolchains: ignore -Wattributes diagnostics.
#pragma GCC diagnostic ignored "-Wattributes"
#endif

// GANDIVA_STATIC: no decoration. GANDIVA_EXPORTING: export the symbol
// (dllexport). Neither defined: import the symbol (dllimport).
#ifdef GANDIVA_STATIC
#define GANDIVA_EXPORT
#elif defined(GANDIVA_EXPORTING)
#define GANDIVA_EXPORT __declspec(dllexport)
#else
#define GANDIVA_EXPORT __declspec(dllimport)
#endif

// Expands to nothing on Windows.
#define GANDIVA_NO_EXPORT
#else  // Not Windows
// Both macros are only defined if not already defined, so a definition
// supplied earlier (e.g. on the command line) takes precedence.
#ifndef GANDIVA_EXPORT
#define GANDIVA_EXPORT __attribute__((visibility("default")))
#endif
#ifndef GANDIVA_NO_EXPORT
#define GANDIVA_NO_EXPORT __attribute__((visibility("hidden")))
#endif
#endif  // Non-Windows

// Restore the warning state saved by the push in the MSVC branch above.
#if defined(_MSC_VER)
#pragma warning(pop)
#endif