summary | refs | log | tree | commit | diff | stats
path: root/src/arrow/python/pyarrow/tests/test_gandiva.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/python/pyarrow/tests/test_gandiva.py')
-rw-r--r-- src/arrow/python/pyarrow/tests/test_gandiva.py 391
1 files changed, 391 insertions, 0 deletions
diff --git a/src/arrow/python/pyarrow/tests/test_gandiva.py b/src/arrow/python/pyarrow/tests/test_gandiva.py
new file mode 100644
index 000000000..6522c233a
--- /dev/null
+++ b/src/arrow/python/pyarrow/tests/test_gandiva.py
@@ -0,0 +1,391 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import pytest
+
+import pyarrow as pa
+
+
+@pytest.mark.gandiva
+def test_tree_exp_builder():
+    """Project ``a if a > b else b`` over two int32 columns.
+
+    Builds the if/else expression with ``TreeExprBuilder``, checks the
+    field node's return type and the expression's result type, checks
+    that the projector's LLVM IR contains an ``@expr_`` symbol, and
+    compares the evaluated column with the expected values.
+    """
+    import pyarrow.gandiva as gandiva
+
+    tree = gandiva.TreeExprBuilder()
+
+    lhs_field = pa.field('a', pa.int32())
+    rhs_field = pa.field('b', pa.int32())
+    input_schema = pa.schema([lhs_field, rhs_field])
+    out_field = pa.field('res', pa.int32())
+
+    lhs = tree.make_field(lhs_field)
+    rhs = tree.make_field(rhs_field)
+    assert lhs.return_type() == lhs_field.type
+
+    # res = a if a > b else b
+    a_gt_b = tree.make_function("greater_than", [lhs, rhs], pa.bool_())
+    pick_larger = tree.make_if(a_gt_b, lhs, rhs, pa.int32())
+    expression = tree.make_expression(pick_larger, out_field)
+    assert expression.result().type == pa.int32()
+
+    projector = gandiva.make_projector(
+        input_schema, [expression], pa.default_memory_pool())
+
+    # Gandiva names the generated compute kernel functions `@expr_X`.
+    assert "@expr_" in projector.llvm_ir
+
+    col_a = pa.array([10, 12, -20, 5], type=pa.int32())
+    col_b = pa.array([5, 15, 15, 17], type=pa.int32())
+    expected = pa.array([10, 15, 15, 17], type=pa.int32())
+    batch = pa.RecordBatch.from_arrays([col_a, col_b], names=['a', 'b'])
+
+    projected, = projector.evaluate(batch)
+    assert projected.equals(expected)
+
+
+@pytest.mark.gandiva
+def test_table():
+    """Project ``a + b`` over a table built from two lists of floats.
+
+    The projector is evaluated on the table's first record batch and the
+    single output column is compared with the expected sums.
+    """
+    import pyarrow.gandiva as gandiva
+
+    table = pa.Table.from_arrays(
+        [pa.array([1.0, 2.0]), pa.array([3.0, 4.0])], ['a', 'b'])
+
+    builder = gandiva.TreeExprBuilder()
+    node_a = builder.make_field(table.schema.field("a"))
+    node_b = builder.make_field(table.schema.field("b"))
+
+    # Named `add_node` (not `sum`) so the builtin is not shadowed.
+    add_node = builder.make_function("add", [node_a, node_b], pa.float64())
+
+    field_result = pa.field("c", pa.float64())
+    expr = builder.make_expression(add_node, field_result)
+
+    projector = gandiva.make_projector(
+        table.schema, [expr], pa.default_memory_pool())
+
+    # TODO: Add .evaluate function which can take Tables instead of
+    # RecordBatches
+    result, = projector.evaluate(table.to_batches()[0])
+
+    expected = pa.array([4.0, 6.0])
+    assert result.equals(expected)
+
+
+@pytest.mark.gandiva
+def test_filter():
+    """Filter the floats 0.0 .. 9999.0 with the condition ``a < 1000.0``.
+
+    Checks the condition's result type, that the filter's LLVM IR
+    contains an ``@expr_`` symbol, and that the selection vector equals
+    the uint32 row indices 0 .. 999.
+    """
+    import pyarrow.gandiva as gandiva
+
+    table = pa.Table.from_arrays(
+        [pa.array([1.0 * i for i in range(10000)])], ['a'])
+
+    builder = gandiva.TreeExprBuilder()
+    node_a = builder.make_field(table.schema.field("a"))
+    thousand = builder.make_literal(1000.0, pa.float64())
+    less_than = builder.make_function(
+        "less_than", [node_a, thousand], pa.bool_())
+    condition = builder.make_condition(less_than)
+
+    assert condition.result().type == pa.bool_()
+
+    # Named `row_filter` (not `filter`) so the builtin is not shadowed.
+    row_filter = gandiva.make_filter(table.schema, condition)
+    # Gandiva names the generated compute kernel functions `@expr_X`.
+    assert "@expr_" in row_filter.llvm_ir
+
+    selection = row_filter.evaluate(
+        table.to_batches()[0], pa.default_memory_pool())
+    expected = pa.array(range(1000), type=pa.uint32())
+    assert selection.to_array().equals(expected)
+
+
+@pytest.mark.gandiva
+def test_in_expr():
+    """Check IN-expression filters for string, int32 and int64 columns.
+
+    Each case builds a single-column table ``a``, filters it with
+    ``a IN values`` and compares the selection vector with the expected
+    uint32 row indices.
+    """
+    import pyarrow.gandiva as gandiva
+
+    # One builder is shared by all three cases.
+    builder = gandiva.TreeExprBuilder()
+
+    def selected_rows(column, values, dtype):
+        """Return the selection array for ``a IN values`` over `column`."""
+        table = pa.Table.from_arrays([column], ["a"])
+        node_a = builder.make_field(table.schema.field("a"))
+        in_node = builder.make_in_expression(node_a, values, dtype)
+        condition = builder.make_condition(in_node)
+        # Named `row_filter` (not `filter`) to avoid shadowing the builtin.
+        row_filter = gandiva.make_filter(table.schema, condition)
+        selection = row_filter.evaluate(
+            table.to_batches()[0], pa.default_memory_pool())
+        return selection.to_array()
+
+    # string
+    words = pa.array(["ga", "an", "nd", "di", "iv", "va"])
+    rows = selected_rows(words, ["an", "nd"], pa.string())
+    assert rows.equals(pa.array([1, 2], type=pa.uint32()))
+
+    # int32
+    digits = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
+    rows = selected_rows(digits.cast(pa.int32()), [1, 5], pa.int32())
+    assert rows.equals(pa.array([1, 3, 4, 8], type=pa.uint32()))
+
+    # int64
+    digits = pa.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 4])
+    rows = selected_rows(digits, [1, 5], pa.int64())
+    assert rows.equals(pa.array([1, 3, 4, 8], type=pa.uint32()))
+
+
+@pytest.mark.skip(reason="Gandiva C++ did not have *real* binary, "
+                  "time and date support.")
+def test_in_expr_todo():
+    """Placeholder IN-expression checks for binary and temporal columns.
+
+    Always skipped (see the decorator). The TODO below records the error
+    that was seen when these cases were written.
+    """
+    import pyarrow.gandiva as gandiva
+    # TODO: Implement reasonable support for timestamp, time & date.
+    # Current exceptions:
+    # pyarrow.lib.ArrowException: ExpressionValidationError:
+    # Evaluation expression for IN clause returns XXXX values are of typeXXXX
+
+    def naive_utc(epoch_seconds):
+        """Return the naive UTC datetime for a POSIX timestamp."""
+        # datetime.utcfromtimestamp() is deprecated since Python 3.12;
+        # converting in UTC and dropping tzinfo gives the same naive value.
+        aware = datetime.datetime.fromtimestamp(
+            epoch_seconds, tz=datetime.timezone.utc)
+        return aware.replace(tzinfo=None)
+
+    def selected_rows(column, values, dtype):
+        """Return the selection array for ``a IN values`` over `column`."""
+        table = pa.Table.from_arrays([column], ["a"])
+        # A fresh builder per case, as in the original test.
+        builder = gandiva.TreeExprBuilder()
+        node_a = builder.make_field(table.schema.field("a"))
+        in_node = builder.make_in_expression(node_a, values, dtype)
+        condition = builder.make_condition(in_node)
+        # Named `row_filter` (not `filter`) to avoid shadowing the builtin.
+        row_filter = gandiva.make_filter(table.schema, condition)
+        selection = row_filter.evaluate(
+            table.to_batches()[0], pa.default_memory_pool())
+        return selection.to_array()
+
+    # binary
+    arr = pa.array([b"ga", b"an", b"nd", b"di", b"iv", b"va"])
+    rows = selected_rows(arr, [b'an', b'nd'], pa.binary())
+    assert rows.equals(pa.array([1, 2], type=pa.uint32()))
+
+    # timestamp
+    datetime_1 = naive_utc(1542238951.621877)
+    datetime_2 = naive_utc(1542238911.621877)
+    datetime_3 = naive_utc(1542238051.621877)
+
+    arr = pa.array([datetime_1, datetime_2, datetime_3])
+    # NOTE(review): the values carry microseconds while the IN type is
+    # timestamp('ms') -- confirm the intended unit when un-skipping.
+    rows = selected_rows(arr, [datetime_2], pa.timestamp('ms'))
+    assert list(rows) == [1]
+
+    # time
+    time_1 = datetime_1.time()
+    time_2 = datetime_2.time()
+    time_3 = datetime_3.time()
+
+    arr = pa.array([time_1, time_2, time_3])
+    # NOTE(review): pyarrow documents time64 units as 'us'/'ns' only, so
+    # pa.time64('ms') presumably raises -- confirm when un-skipping.
+    rows = selected_rows(arr, [time_2], pa.time64('ms'))
+    assert list(rows) == [1]
+
+    # date
+    date_1 = datetime_1.date()
+    date_2 = datetime_2.date()
+    date_3 = datetime_3.date()
+
+    arr = pa.array([date_1, date_2, date_3])
+    rows = selected_rows(arr, [date_2], pa.date32())
+    assert list(rows) == [1]
+
+
+@pytest.mark.gandiva
+def test_boolean():
+    """Filter with ``(a < 50 AND a > b) OR b < 11`` over two float columns.
+
+    The selection vector is compared with the expected uint32 row
+    indices ``[0, 2, 5]``.
+    """
+    import pyarrow.gandiva as gandiva
+
+    table = pa.Table.from_arrays([
+        pa.array([1., 31., 46., 3., 57., 44., 22.]),
+        pa.array([5., 45., 36., 73., 83., 23., 76.])],
+        ['a', 'b'])
+
+    builder = gandiva.TreeExprBuilder()
+    node_a = builder.make_field(table.schema.field("a"))
+    node_b = builder.make_field(table.schema.field("b"))
+    fifty = builder.make_literal(50.0, pa.float64())
+    eleven = builder.make_literal(11.0, pa.float64())
+
+    a_lt_50 = builder.make_function(
+        "less_than", [node_a, fifty], pa.bool_())
+    a_gt_b = builder.make_function(
+        "greater_than", [node_a, node_b], pa.bool_())
+    b_lt_11 = builder.make_function(
+        "less_than", [node_b, eleven], pa.bool_())
+
+    # (a < 50 AND a > b) OR b < 11
+    combined = builder.make_or(
+        [builder.make_and([a_lt_50, a_gt_b]), b_lt_11])
+    condition = builder.make_condition(combined)
+
+    # Named `row_filter` (not `filter`) so the builtin is not shadowed.
+    row_filter = gandiva.make_filter(table.schema, condition)
+    selection = row_filter.evaluate(
+        table.to_batches()[0], pa.default_memory_pool())
+    assert selection.to_array().equals(
+        pa.array([0, 2, 5], type=pa.uint32()))
+
+
+@pytest.mark.gandiva
+def test_literals():
+    """Exercise ``make_literal`` for each listed value/type pair.
+
+    Every literal is created once with a ``pa`` type object and once
+    with the type's string name. A string value with an int64 type, and
+    a ``None`` type, are each expected to raise ``TypeError``.
+    """
+    import pyarrow.gandiva as gandiva
+
+    builder = gandiva.TreeExprBuilder()
+
+    # (value, arrow type object, type name)
+    cases = [
+        (True, pa.bool_(), "bool"),
+        (0, pa.uint8(), "uint8"),
+        (1, pa.uint16(), "uint16"),
+        (2, pa.uint32(), "uint32"),
+        (3, pa.uint64(), "uint64"),
+        (4, pa.int8(), "int8"),
+        (5, pa.int16(), "int16"),
+        (6, pa.int32(), "int32"),
+        (7, pa.int64(), "int64"),
+        (8.0, pa.float32(), "float32"),
+        (9.0, pa.float64(), "float64"),
+        ("hello", pa.string(), "string"),
+        (b"world", pa.binary(), "binary"),
+    ]
+
+    # First pass: types given as DataType objects.
+    for value, arrow_type, _ in cases:
+        builder.make_literal(value, arrow_type)
+
+    # Second pass: the same literals with types given by name.
+    for value, _, type_name in cases:
+        builder.make_literal(value, type_name)
+
+    with pytest.raises(TypeError):
+        builder.make_literal("hello", pa.int64())
+    with pytest.raises(TypeError):
+        builder.make_literal(True, None)
+
+
+@pytest.mark.gandiva
+def test_regex():
+    """Project ``like(a, "%spark%")`` over a string column.
+
+    The boolean output column is compared with the expected per-row
+    match flags.
+    """
+    import pyarrow.gandiva as gandiva
+
+    words = pa.array(
+        ["park", "sparkle", "bright spark and fire", "spark"],
+        type=pa.string())
+    table = pa.Table.from_arrays([words], names=['a'])
+
+    builder = gandiva.TreeExprBuilder()
+    column = builder.make_field(table.schema.field("a"))
+    pattern = builder.make_literal("%spark%", pa.string())
+    like_node = builder.make_function("like", [column, pattern], pa.bool_())
+
+    out_field = pa.field("b", pa.bool_())
+    expr = builder.make_expression(like_node, out_field)
+
+    projector = gandiva.make_projector(
+        table.schema, [expr], pa.default_memory_pool())
+
+    matches, = projector.evaluate(table.to_batches()[0])
+    expected = pa.array([False, True, True, True], type=pa.bool_())
+    assert matches.equals(expected)
+
+
+@pytest.mark.gandiva
+def test_get_registered_function_signatures():
+    """Check the shape of the first registered function signature.
+
+    The signature must expose a ``name`` attribute, a ``return_type()``
+    that is a ``pa.DataType`` and a ``param_types()`` that is a list.
+    """
+    import pyarrow.gandiva as gandiva
+    signatures = gandiva.get_registered_function_signatures()
+
+    first = signatures[0]
+    # isinstance() instead of `type(...) is ...`: an exact-type check
+    # would reject a DataType subclass instance, so the outcome would
+    # depend on which signature happens to come first.
+    assert isinstance(first.return_type(), pa.DataType)
+    assert isinstance(first.param_types(), list)
+    assert hasattr(first, "name")
+
+
+@pytest.mark.gandiva
+def test_filter_project():
+    """Apply a filter's selection vector to a projector.
+
+    The filter keeps rows where ``a > b``; the projector computes
+    ``b if b < c else c`` and is evaluated only on the selected rows.
+    """
+    import pyarrow.gandiva as gandiva
+    mpool = pa.default_memory_pool()
+    # Create a table with some sample data
+    array0 = pa.array([10, 12, -20, 5, 21, 29], pa.int32())
+    array1 = pa.array([5, 15, 15, 17, 12, 3], pa.int32())
+    array2 = pa.array([1, 25, 11, 30, -21, None], pa.int32())
+
+    table = pa.Table.from_arrays([array0, array1, array2], ['a', 'b', 'c'])
+
+    field_result = pa.field("res", pa.int32())
+
+    builder = gandiva.TreeExprBuilder()
+    node_a = builder.make_field(table.schema.field("a"))
+    node_b = builder.make_field(table.schema.field("b"))
+    node_c = builder.make_field(table.schema.field("c"))
+
+    # Filter condition: a > b
+    a_gt_b = builder.make_function(
+        "greater_than", [node_a, node_b], pa.bool_())
+    filter_condition = builder.make_condition(a_gt_b)
+
+    # Projection: b if b < c else c
+    b_lt_c = builder.make_function(
+        "less_than", [node_b, node_c], pa.bool_())
+    if_node = builder.make_if(b_lt_c, node_b, node_c, pa.int32())
+    expr = builder.make_expression(if_node, field_result)
+
+    # Build a filter for the expressions.
+    # Named `row_filter` (not `filter`) so the builtin is not shadowed.
+    row_filter = gandiva.make_filter(table.schema, filter_condition)
+
+    # Build a projector for the expressions.
+    projector = gandiva.make_projector(
+        table.schema, [expr], mpool, "UINT32")
+
+    # Evaluate filter
+    selection_vector = row_filter.evaluate(table.to_batches()[0], mpool)
+
+    # Evaluate project
+    projected, = projector.evaluate(
+        table.to_batches()[0], selection_vector)
+
+    expected = pa.array([1, -21, None], pa.int32())
+    assert projected.equals(expected)
+
+
+@pytest.mark.gandiva
+def test_to_string():
+    """Check ``str()`` of literal, field, function and AND nodes."""
+    import pyarrow.gandiva as gandiva
+    builder = gandiva.TreeExprBuilder()
+
+    # Literals
+    double_literal = builder.make_literal(2.0, pa.float64())
+    assert str(double_literal).startswith('(const double) 2 raw(')
+    int_literal = builder.make_literal(2, pa.int64())
+    assert str(int_literal) == '(const int64) 2'
+
+    # Fields
+    double_field = builder.make_field(pa.field('x', pa.float64()))
+    assert str(double_field) == '(double) x'
+    string_field = builder.make_field(pa.field('y', pa.string()))
+    assert str(string_field) == '(string) y'
+
+    # Function node
+    bool_z = builder.make_field(pa.field('z', pa.bool_()))
+    not_z = builder.make_function('not', [bool_z], pa.bool_())
+    assert str(not_z) == 'bool not((bool) z)'
+
+    # Boolean AND of a function node and a field
+    bool_y = builder.make_field(pa.field('y', pa.bool_()))
+    not_z_and_y = builder.make_and([not_z, bool_y])
+    assert str(not_z_and_y) == 'bool not((bool) z) && (bool) y'