diff options
Diffstat (limited to 'src/arrow/cpp/src/gandiva/tests/if_expr_test.cc')
-rw-r--r-- | src/arrow/cpp/src/gandiva/tests/if_expr_test.cc | 378 |
1 files changed, 378 insertions, 0 deletions
diff --git a/src/arrow/cpp/src/gandiva/tests/if_expr_test.cc b/src/arrow/cpp/src/gandiva/tests/if_expr_test.cc new file mode 100644 index 000000000..54b6d43b4 --- /dev/null +++ b/src/arrow/cpp/src/gandiva/tests/if_expr_test.cc @@ -0,0 +1,378 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gtest/gtest.h> +#include "arrow/memory_pool.h" +#include "arrow/status.h" + +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::int32; + +class TestIfExpr : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestIfExpr, TestSimple) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + // else + // b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({10, 15, 15, 17}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestSimpleArithmetic) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + b + // else + // a - b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32()); + auto if_node = TreeExprBuilder::MakeIf(condition, sum, sub, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 12, -20, 5}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({15, -3, -35, 0}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestNested) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > b) + // a + b + // else if (a < b) + // a - b + // else + // a * b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition_gt = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto condition_lt = + TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto sum = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto sub = TreeExprBuilder::MakeFunction("subtract", {node_a, node_b}, int32()); + auto mult = TreeExprBuilder::MakeFunction("multiply", {node_a, node_b}, int32()); + auto else_node = TreeExprBuilder::MakeIf(condition_lt, sub, mult, int32()); + auto if_node = TreeExprBuilder::MakeIf(condition_gt, sum, else_node, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 12, 15, 5}, {true, true, true, false}); + auto array1 = MakeArrowArrayInt32({5, 15, 15, 17}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({15, -3, 225, 0}, {true, true, true, false}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestNestedInIf) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", int32()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a > 10) + // if (a < 20) + // a + b + // else + // b + c + // else + // a + c + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + + auto literal_10 = TreeExprBuilder::MakeLiteral(10); + auto literal_20 = TreeExprBuilder::MakeLiteral(20); + + auto gt_10 = + TreeExprBuilder::MakeFunction("greater_than", {node_a, literal_10}, boolean()); + auto lt_20 = + TreeExprBuilder::MakeFunction("less_than", {node_a, literal_20}, boolean()); + auto sum_ab = TreeExprBuilder::MakeFunction("add", {node_a, node_b}, int32()); + auto sum_bc = TreeExprBuilder::MakeFunction("add", {node_b, node_c}, int32()); + auto sum_ac = TreeExprBuilder::MakeFunction("add", {node_a, node_c}, int32()); + + auto if_lt_20 = TreeExprBuilder::MakeIf(lt_20, sum_ab, sum_bc, int32()); + auto if_gt_10 = TreeExprBuilder::MakeIf(gt_10, if_lt_20, sum_ac, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_gt_10, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 6; + auto array_a = + MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true}); + auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19}, + {true, true, true, false, false, false}); + auto array_c = MakeArrowArrayInt32({35, 45, 55, 35, 45, 55}, + {true, true, true, false, false, false}); + + // expected output + auto exp = + MakeArrowArrayInt32({55, 33, 60, 0, 0, 0}, {true, true, true, false, false, false}); + + // prepare input record batch + auto in_batch = + arrow::RecordBatch::Make(schema, num_records, {array_a, array_b, array_c}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestNestedInCondition) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (if (a > b) then true else if (a < b) false else null) + // 1 + // else if !(if (a > b) then true else if (a < b) false else null) + // 2 + // else + // 3 + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto literal_1 = TreeExprBuilder::MakeLiteral(1); + auto literal_2 = TreeExprBuilder::MakeLiteral(2); + auto literal_3 = TreeExprBuilder::MakeLiteral(3); + auto literal_true = TreeExprBuilder::MakeLiteral(true); + auto literal_false = TreeExprBuilder::MakeLiteral(false); + auto literal_null = TreeExprBuilder::MakeNull(boolean()); + + auto a_gt_b = + TreeExprBuilder::MakeFunction("greater_than", {node_a, node_b}, boolean()); + auto a_lt_b = TreeExprBuilder::MakeFunction("less_than", {node_a, node_b}, boolean()); + auto cond_else = + TreeExprBuilder::MakeIf(a_lt_b, literal_false, literal_null, boolean()); + auto cond_if = TreeExprBuilder::MakeIf(a_gt_b, literal_true, cond_else, boolean()); + auto not_cond_if = TreeExprBuilder::MakeFunction("not", {cond_if}, boolean()); + + auto outer_else = TreeExprBuilder::MakeIf(not_cond_if, literal_2, literal_3, int32()); + auto outer_if = TreeExprBuilder::MakeIf(cond_if, literal_1, outer_else, int32()); + auto expr = TreeExprBuilder::MakeExpression(outer_if, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 6; + auto array_a = + MakeArrowArrayInt32({21, 15, 5, 22, 15, 5}, {true, true, true, true, true, true}); + auto array_b = MakeArrowArrayInt32({20, 18, 19, 20, 18, 19}, + {true, true, true, false, false, false}); + // expected output + auto exp = + MakeArrowArrayInt32({1, 2, 2, 3, 3, 3}, {true, true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestIfExpr, TestBigNested) { + // schema for input fields + auto fielda = field("a", int32()); + auto schema = arrow::schema({fielda}); + + // output fields + auto field_result = field("res", int32()); + + // build expression. + // if (a < 10) + // 10 + // else if (a < 20) + // 20 + // .. + // .. + // else if (a < 190) + // 190 + // else + // 200 + auto node_a = TreeExprBuilder::MakeField(fielda); + auto top_node = TreeExprBuilder::MakeLiteral(200); + for (int thresh = 190; thresh > 0; thresh -= 10) { + auto literal = TreeExprBuilder::MakeLiteral(thresh); + auto condition = + TreeExprBuilder::MakeFunction("less_than", {node_a, literal}, boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, literal, top_node, int32()); + top_node = if_node; + } + auto expr = TreeExprBuilder::MakeExpression(top_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr<Projector> projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayInt32({10, 102, 158, 302}, {true, true, true, true}); + + // expected output + auto exp = MakeArrowArrayInt32({20, 110, 160, 200}, {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +} // namespace gandiva |