summaryrefslogtreecommitdiffstats
path: root/src/arrow/java/gandiva
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/java/gandiva')
-rw-r--r--src/arrow/java/gandiva/CMakeLists.txt55
-rw-r--r--src/arrow/java/gandiva/README.md32
-rw-r--r--src/arrow/java/gandiva/pom.xml153
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java72
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java94
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java220
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java29
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java199
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java93
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java170
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java120
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java364
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVector.java87
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt16.java49
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt32.java48
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java69
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/EvaluatorClosedException.java25
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java35
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/UnsupportedTypeException.java27
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/AndNode.java47
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java350
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BinaryNode.java45
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BooleanNode.java43
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/Condition.java42
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java49
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DoubleNode.java43
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ExpressionTree.java46
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FieldNode.java43
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FloatNode.java43
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java54
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IfNode.java52
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java176
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IntNode.java43
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/LongNode.java43
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/NullNode.java41
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/OrNode.java47
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/StringNode.java48
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java230
-rw-r--r--src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeNode.java34
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java404
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java89
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryTest.java65
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterProjectTest.java102
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java315
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java151
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java797
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java2470
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/TestJniLoader.java53
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/ArrowTypeHelperTest.java105
-rw-r--r--src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/TreeBuilderTest.java350
-rw-r--r--src/arrow/java/gandiva/src/test/resources/logback.xml28
51 files changed, 8389 insertions, 0 deletions
diff --git a/src/arrow/java/gandiva/CMakeLists.txt b/src/arrow/java/gandiva/CMakeLists.txt
new file mode 100644
index 000000000..5010daf79
--- /dev/null
+++ b/src/arrow/java/gandiva/CMakeLists.txt
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Builds the gandiva_java jar from the JNI-facing Java sources and generates
+# the matching JNI headers into ${JNI_HEADERS_DIR}/jni.
+project(gandiva_java)
+
+# Find java/jni
+include(FindJava)
+include(UseJava)
+include(FindJNI)
+
+message("generating headers to ${JNI_HEADERS_DIR}/jni")
+
+# GENERATE_NATIVE_HEADERS is only available from Java 8 onwards.
+# CentOS 5 does not have Java 8 images, so Java 7 is supported as well.
+# Unfortunately create_javah does not work correctly with Java 8, hence the
+# two separate code paths below.
+if(ARROW_GANDIVA_JAVA7)
+ # Java 7 path: build the jar first, then derive the headers from the
+ # compiled classes with create_javah (which depends on the jar target).
+ add_jar(gandiva_java
+ src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java
+ src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
+ src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java
+ src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java)
+
+ create_javah(TARGET gandiva_jni_headers
+ CLASSES org.apache.arrow.gandiva.evaluator.ConfigurationBuilder
+ org.apache.arrow.gandiva.evaluator.JniWrapper
+ org.apache.arrow.gandiva.evaluator.ExpressionRegistryJniHelper
+ org.apache.arrow.gandiva.exceptions.GandivaException
+ DEPENDS gandiva_java
+ CLASSPATH gandiva_java
+ OUTPUT_DIR ${JNI_HEADERS_DIR}/jni)
+else()
+ # Default path: add_jar emits the headers itself through
+ # GENERATE_NATIVE_HEADERS, using the same sources and output directory.
+ add_jar(gandiva_java
+ src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java
+ src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
+ src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java
+ src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java
+ GENERATE_NATIVE_HEADERS
+ gandiva_jni_headers
+ DESTINATION
+ ${JNI_HEADERS_DIR}/jni)
+endif()
diff --git a/src/arrow/java/gandiva/README.md b/src/arrow/java/gandiva/README.md
new file mode 100644
index 000000000..22a292eaf
--- /dev/null
+++ b/src/arrow/java/gandiva/README.md
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+# Gandiva Java
+
+## Setup Build Environment
+
+install:
+ - java 8 or later
+ - maven 3.3 or later
+
+## Building and running tests
+
+```
+cd java
+mvn install -Dgandiva.cpp.build.dir=<path_to_cpp_artifact_directory>
+```
diff --git a/src/arrow/java/gandiva/pom.xml b/src/arrow/java/gandiva/pom.xml
new file mode 100644
index 000000000..81caf12f5
--- /dev/null
+++ b/src/arrow/java/gandiva/pom.xml
@@ -0,0 +1,153 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
+ license agreements. See the NOTICE file distributed with this work for additional
+ information regarding copyright ownership. The ASF licenses this file to
+ You under the Apache License, Version 2.0 (the "License"); you may not use
+ this file except in compliance with the License. You may obtain a copy of
+ the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
+ by applicable law or agreed to in writing, software distributed under the
+ License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
+ OF ANY KIND, either express or implied. See the License for the specific
+ language governing permissions and limitations under the License. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-java-root</artifactId>
+ <version>6.0.1</version>
+ </parent>
+
+ <groupId>org.apache.arrow.gandiva</groupId>
+ <artifactId>arrow-gandiva</artifactId>
+ <packaging>jar</packaging>
+ <name>Arrow Gandiva</name>
+ <description>Java wrappers around the native Gandiva SQL expression compiler.</description>
+ <properties>
+ <maven.compiler.source>1.8</maven.compiler.source>
+ <maven.compiler.target>1.8</maven.compiler.target>
+ <protobuf.version>2.5.0</protobuf.version>
+ <checkstyle.failOnViolation>true</checkstyle.failOnViolation>
+ <arrow.cpp.build.dir>../../../cpp/release-build</arrow.cpp.build.dir>
+ </properties>
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-memory-core</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-memory-netty</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-vector</artifactId>
+ <version>${project.version}</version>
+ <classifier>${arrow.vector.classifier}</classifier>
+ </dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+ </dependencies>
+ <profiles>
+ <profile>
+ <id>release</id>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <version>2.2.1</version>
+ <executions>
+ <execution>
+ <id>attach-sources</id>
+ <goals>
+ <goal>jar-no-fork</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ <version>2.9.1</version>
+ <executions>
+ <execution>
+ <id>attach-javadocs</id>
+ <goals>
+ <goal>jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.5</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+ <build>
+ <resources>
+ <resource>
+ <directory>${arrow.cpp.build.dir}</directory>
+ <includes>
+ <include>**/gandiva_jni.*</include>
+ <include>**/libgandiva_jni.*</include>
+ </includes>
+ </resource>
+ </resources>
+
+ <extensions>
+ <extension>
+ <groupId>kr.motd.maven</groupId>
+ <artifactId>os-maven-plugin</artifactId>
+ <version>1.5.0.Final</version>
+ </extension>
+ </extensions>
+ <plugins>
+ <plugin>
+ <groupId>org.xolstice.maven.plugins</groupId>
+ <artifactId>protobuf-maven-plugin</artifactId>
+ <version>0.5.1</version>
+ <configuration>
+ <protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}
+ </protocArtifact>
+ <protoSourceRoot>../../cpp/src/gandiva/proto</protoSourceRoot>
+ </configuration>
+ <executions>
+ <execution>
+ <goals>
+ <goal>compile</goal>
+ <goal>test-compile</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+
+</project>
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java
new file mode 100644
index 000000000..e903b4e87
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.util.Objects;
+
+/**
+ * Creates and releases native gandiva configuration objects through JNI.
+ */
+public class ConfigurationBuilder {
+
+  /**
+   * Builds a native configuration from the supplied options.
+   *
+   * @param configOptions the options to forward to the native builder
+   * @return the value returned by the native builder for the new configuration
+   */
+  public long buildConfigInstance(ConfigOptions configOptions) {
+    final boolean optimizeFlag = configOptions.optimize;
+    final boolean targetCpuFlag = configOptions.targetCPU;
+    return buildConfigInstance(optimizeFlag, targetCpuFlag);
+  }
+
+  private native long buildConfigInstance(boolean optimize, boolean detectHostCPU);
+
+  public native void releaseConfigInstance(long configId);
+
+  /**
+   * Mutable holder for the parameters handed to gandiva when a configuration is built.
+   * Both flags start out as {@code true}.
+   */
+  public static class ConfigOptions {
+    private boolean optimize = true;
+    private boolean targetCPU = true;
+
+    public ConfigOptions() {}
+
+    /**
+     * Returns a fresh instance carrying the default flag values.
+     */
+    public static ConfigOptions getDefault() {
+      return new ConfigOptions();
+    }
+
+    /**
+     * Sets the optimize flag and returns this instance for chaining.
+     */
+    public ConfigOptions withOptimize(boolean optimize) {
+      this.optimize = optimize;
+      return this;
+    }
+
+    /**
+     * Sets the target-CPU flag and returns this instance for chaining.
+     */
+    public ConfigOptions withTargetCPU(boolean targetCPU) {
+      this.targetCPU = targetCPU;
+      return this;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj instanceof ConfigOptions) {
+        final ConfigOptions other = (ConfigOptions) obj;
+        return optimize == other.optimize && targetCPU == other.targetCPU;
+      }
+      return false;
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(optimize, targetCPU);
+    }
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java
new file mode 100644
index 000000000..e0c072cfb
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.apache.arrow.vector.types.pojo.ArrowType.Decimal;
+
+/**
+ * Static helpers for deriving {@link Decimal} result types.
+ */
+public class DecimalTypeUtil {
+  private DecimalTypeUtil() {}
+
+  /**
+   * Arithmetic operations for which a result type can be derived.
+   */
+  public enum OperationType {
+    ADD,
+    SUBTRACT,
+    MULTIPLY,
+    DIVIDE,
+    MOD
+  }
+
+  private static final int MIN_ADJUSTED_SCALE = 6;
+  /// The maximum precision representable by a 16-byte decimal
+  private static final int MAX_PRECISION = 38;
+
+  /**
+   * Computes the precision and scale of the value produced by applying
+   * {@code operation} to the two operands.
+   *
+   * @param operation the arithmetic operation being applied
+   * @param operand1 type of the left operand
+   * @param operand2 type of the right operand
+   * @return a 128-bit decimal type; precision is capped at {@code MAX_PRECISION} and the
+   *         scale is reduced accordingly when the cap is hit
+   */
+  public static Decimal getResultTypeForOperation(OperationType operation, Decimal operand1,
+      Decimal operand2) {
+    final int leftScale = operand1.getScale();
+    final int rightScale = operand2.getScale();
+    final int leftPrecision = operand1.getPrecision();
+    final int rightPrecision = operand2.getPrecision();
+    // Number of digits in front of the decimal point for each operand.
+    final int leftWhole = leftPrecision - leftScale;
+    final int rightWhole = rightPrecision - rightScale;
+
+    final int resultScale;
+    final int resultPrecision;
+    switch (operation) {
+      case ADD:
+      case SUBTRACT:
+        resultScale = Math.max(leftScale, rightScale);
+        resultPrecision = resultScale + Math.max(leftWhole, rightWhole) + 1;
+        break;
+      case MULTIPLY:
+        resultScale = leftScale + rightScale;
+        resultPrecision = leftPrecision + rightPrecision + 1;
+        break;
+      case DIVIDE:
+        resultScale = Math.max(MIN_ADJUSTED_SCALE, leftScale + rightPrecision + 1);
+        resultPrecision = leftWhole + rightScale + resultScale;
+        break;
+      case MOD:
+        resultScale = Math.max(leftScale, rightScale);
+        resultPrecision = Math.min(leftWhole, rightWhole) + resultScale;
+        break;
+      default:
+        throw new RuntimeException("Needs support");
+    }
+    return adjustScaleIfNeeded(resultPrecision, resultScale);
+  }
+
+  /**
+   * Caps the precision at {@code MAX_PRECISION}. When the cap applies, the scale is reduced
+   * by the excess precision but never below {@code min(scale, MIN_ADJUSTED_SCALE)}.
+   */
+  private static Decimal adjustScaleIfNeeded(int precision, int scale) {
+    if (precision <= MAX_PRECISION) {
+      return new Decimal(precision, scale, 128);
+    }
+    final int excess = precision - MAX_PRECISION;
+    final int scaleFloor = Math.min(scale, MIN_ADJUSTED_SCALE);
+    return new Decimal(MAX_PRECISION, Math.max(scale - excess, scaleFloor), 128);
+  }
+
+}
+
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java
new file mode 100644
index 000000000..0155af082
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.util.List;
+import java.util.Set;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.gandiva.ipc.GandivaTypes.ExtGandivaType;
+import org.apache.arrow.gandiva.ipc.GandivaTypes.GandivaDataTypes;
+import org.apache.arrow.gandiva.ipc.GandivaTypes.GandivaFunctions;
+import org.apache.arrow.gandiva.ipc.GandivaTypes.GandivaType;
+import org.apache.arrow.vector.types.DateUnit;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.IntervalUnit;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import com.google.protobuf.InvalidProtocolBufferException;
+
+/**
+ * Used to get the functions and data types supported by
+ * Gandiva.
+ * All types are in Arrow namespace.
+ *
+ * <p>The registry is a lazily created singleton; the sets it hands out are
+ * unmodifiable views so that callers cannot alter the shared instance.
+ */
+public class ExpressionRegistry {
+
+  private static final int BIT_WIDTH_8 = 8;
+  private static final int BIT_WIDTH_16 = 16;
+  private static final int BIT_WIDTH_32 = 32;
+  private static final int BIT_WIDTH_64 = 64;
+  private static final boolean IS_SIGNED_FALSE = false;
+  private static final boolean IS_SIGNED_TRUE = true;
+
+  private final Set<ArrowType> supportedTypes;
+  private final Set<FunctionSignature> functionSignatures;
+
+  private static volatile ExpressionRegistry INSTANCE;
+
+  private ExpressionRegistry(Set<ArrowType> supportedTypes,
+                             Set<FunctionSignature> functionSignatures) {
+    // The instance is shared by every caller of getInstance(), so expose the
+    // sets read-only. Fully qualified to avoid touching the import block.
+    this.supportedTypes = java.util.Collections.unmodifiableSet(supportedTypes);
+    this.functionSignatures = java.util.Collections.unmodifiableSet(functionSignatures);
+  }
+
+  /**
+   * Returns a singleton instance of the class.
+   * The instance is built on first use, under the class lock, after making
+   * sure the native library has been loaded.
+   *
+   * @return singleton instance
+   * @throws GandivaException if error in Gandiva Library integration.
+   */
+  public static ExpressionRegistry getInstance() throws GandivaException {
+    if (INSTANCE == null) {
+      synchronized (ExpressionRegistry.class) {
+        if (INSTANCE == null) {
+          // ensure library is setup.
+          JniLoader.getInstance();
+          Set<ArrowType> typesFromGandiva = getSupportedTypesFromGandiva();
+          Set<FunctionSignature> functionsFromGandiva = getSupportedFunctionsFromGandiva();
+          INSTANCE = new ExpressionRegistry(typesFromGandiva, functionsFromGandiva);
+        }
+      }
+    }
+    return INSTANCE;
+  }
+
+  /**
+   * Returns the function signatures reported by Gandiva.
+   *
+   * @return an unmodifiable view of the supported function signatures
+   */
+  public Set<FunctionSignature> getSupportedFunctions() {
+    return functionSignatures;
+  }
+
+  /**
+   * Returns the data types reported by Gandiva.
+   *
+   * @return an unmodifiable view of the supported Arrow types
+   */
+  public Set<ArrowType> getSupportedTypes() {
+    return supportedTypes;
+  }
+
+  /**
+   * Fetches the serialized type list over JNI, parses it as a
+   * {@link GandivaDataTypes} message and maps each entry to an Arrow type.
+   *
+   * @throws GandivaException if the bytes returned by the native side cannot be parsed
+   */
+  private static Set<ArrowType> getSupportedTypesFromGandiva() throws GandivaException {
+    Set<ArrowType> supportedTypes = Sets.newHashSet();
+    try {
+      byte[] gandivaSupportedDataTypes = new ExpressionRegistryJniHelper()
+          .getGandivaSupportedDataTypes();
+      GandivaDataTypes gandivaDataTypes = GandivaDataTypes.parseFrom(gandivaSupportedDataTypes);
+      for (ExtGandivaType type : gandivaDataTypes.getDataTypeList()) {
+        supportedTypes.add(getArrowType(type));
+      }
+    } catch (InvalidProtocolBufferException invalidProtException) {
+      throw new GandivaException("Could not get supported types.", invalidProtException);
+    }
+    return supportedTypes;
+  }
+
+  /**
+   * Fetches the serialized function list over JNI, parses it as a
+   * {@link GandivaFunctions} message and converts each entry into a
+   * {@link FunctionSignature} expressed in Arrow types.
+   *
+   * @throws GandivaException if the bytes returned by the native side cannot be parsed
+   */
+  private static Set<FunctionSignature> getSupportedFunctionsFromGandiva() throws
+      GandivaException {
+    Set<FunctionSignature> supportedFunctions = Sets.newHashSet();
+    try {
+      byte[] gandivaSupportedFunctions = new ExpressionRegistryJniHelper()
+          .getGandivaSupportedFunctions();
+      GandivaFunctions gandivaFunctions = GandivaFunctions.parseFrom(gandivaSupportedFunctions);
+      for (GandivaTypes.FunctionSignature protoFunctionSignature
+          : gandivaFunctions.getFunctionList()) {
+
+        String functionName = protoFunctionSignature.getName();
+        ArrowType returnType = getArrowType(protoFunctionSignature.getReturnType());
+        List<ArrowType> paramTypes = Lists.newArrayList();
+        for (ExtGandivaType type : protoFunctionSignature.getParamTypesList()) {
+          paramTypes.add(getArrowType(type));
+        }
+        supportedFunctions.add(new FunctionSignature(functionName, returnType, paramTypes));
+      }
+    } catch (InvalidProtocolBufferException invalidProtException) {
+      throw new GandivaException("Could not get supported functions.", invalidProtException);
+    }
+    return supportedFunctions;
+  }
+
+  /**
+   * Maps a Gandiva protobuf type to the corresponding Arrow type.
+   *
+   * <p>Types without a mapping trip {@code assert false} (effective only when
+   * assertions are enabled) and otherwise yield {@code null}.
+   * Decimal is always mapped to {@code Decimal(0, 0, 128)}.
+   */
+  private static ArrowType getArrowType(ExtGandivaType type) {
+    switch (type.getType().getNumber()) {
+      case GandivaType.BOOL_VALUE:
+        return ArrowType.Bool.INSTANCE;
+      case GandivaType.UINT8_VALUE:
+        return new ArrowType.Int(BIT_WIDTH_8, IS_SIGNED_FALSE);
+      case GandivaType.INT8_VALUE:
+        return new ArrowType.Int(BIT_WIDTH_8, IS_SIGNED_TRUE);
+      case GandivaType.UINT16_VALUE:
+        return new ArrowType.Int(BIT_WIDTH_16, IS_SIGNED_FALSE);
+      case GandivaType.INT16_VALUE:
+        return new ArrowType.Int(BIT_WIDTH_16, IS_SIGNED_TRUE);
+      case GandivaType.UINT32_VALUE:
+        return new ArrowType.Int(BIT_WIDTH_32, IS_SIGNED_FALSE);
+      case GandivaType.INT32_VALUE:
+        return new ArrowType.Int(BIT_WIDTH_32, IS_SIGNED_TRUE);
+      case GandivaType.UINT64_VALUE:
+        return new ArrowType.Int(BIT_WIDTH_64, IS_SIGNED_FALSE);
+      case GandivaType.INT64_VALUE:
+        return new ArrowType.Int(BIT_WIDTH_64, IS_SIGNED_TRUE);
+      case GandivaType.HALF_FLOAT_VALUE:
+        return new ArrowType.FloatingPoint(FloatingPointPrecision.HALF);
+      case GandivaType.FLOAT_VALUE:
+        return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
+      case GandivaType.DOUBLE_VALUE:
+        return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
+      case GandivaType.UTF8_VALUE:
+        return new ArrowType.Utf8();
+      case GandivaType.BINARY_VALUE:
+        return new ArrowType.Binary();
+      case GandivaType.DATE32_VALUE:
+        return new ArrowType.Date(DateUnit.DAY);
+      case GandivaType.DATE64_VALUE:
+        return new ArrowType.Date(DateUnit.MILLISECOND);
+      case GandivaType.TIMESTAMP_VALUE:
+        return new ArrowType.Timestamp(mapArrowTimeUnit(type.getTimeUnit()), null);
+      case GandivaType.TIME32_VALUE:
+        return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()),
+            BIT_WIDTH_32);
+      case GandivaType.TIME64_VALUE:
+        return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()),
+            BIT_WIDTH_64);
+      case GandivaType.NONE_VALUE:
+        return new ArrowType.Null();
+      case GandivaType.DECIMAL_VALUE:
+        return new ArrowType.Decimal(0, 0, 128);
+      case GandivaType.INTERVAL_VALUE:
+        return new ArrowType.Interval(mapArrowIntervalUnit(type.getIntervalType()));
+      case GandivaType.FIXED_SIZE_BINARY_VALUE:
+      case GandivaType.MAP_VALUE:
+      case GandivaType.DICTIONARY_VALUE:
+      case GandivaType.LIST_VALUE:
+      case GandivaType.STRUCT_VALUE:
+      case GandivaType.UNION_VALUE:
+      default:
+        assert false;
+    }
+    return null;
+  }
+
+  /**
+   * Maps a Gandiva time unit to the Arrow time unit; returns {@code null} for
+   * an unrecognized value.
+   */
+  private static TimeUnit mapArrowTimeUnit(GandivaTypes.TimeUnit timeUnit) {
+    switch (timeUnit.getNumber()) {
+      case GandivaTypes.TimeUnit.MICROSEC_VALUE:
+        return TimeUnit.MICROSECOND;
+      case GandivaTypes.TimeUnit.MILLISEC_VALUE:
+        return TimeUnit.MILLISECOND;
+      case GandivaTypes.TimeUnit.NANOSEC_VALUE:
+        return TimeUnit.NANOSECOND;
+      case GandivaTypes.TimeUnit.SEC_VALUE:
+        return TimeUnit.SECOND;
+      default:
+        return null;
+    }
+  }
+
+  /**
+   * Maps a Gandiva interval type to the Arrow interval unit; returns
+   * {@code null} for an unrecognized value.
+   */
+  private static IntervalUnit mapArrowIntervalUnit(GandivaTypes.IntervalType intervalType) {
+    switch (intervalType.getNumber()) {
+      case GandivaTypes.IntervalType.YEAR_MONTH_VALUE:
+        return IntervalUnit.YEAR_MONTH;
+      case GandivaTypes.IntervalType.DAY_TIME_VALUE:
+        return IntervalUnit.DAY_TIME;
+      default:
+        return null;
+    }
+  }
+
+}
+
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java
new file mode 100644
index 000000000..86c1eaaed
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+/**
+ * JNI Adapter used to get supported types and functions
+ * from Gandiva.
+ * Both methods are implemented natively; the class holds no Java state.
+ */
+class ExpressionRegistryJniHelper {
+
+ /**
+ * Returns the supported data types as raw bytes.
+ * ExpressionRegistry parses the result as a GandivaDataTypes protobuf message.
+ */
+ native byte[] getGandivaSupportedDataTypes();
+
+ /**
+ * Returns the supported functions as raw bytes.
+ * ExpressionRegistry parses the result as a GandivaFunctions protobuf message.
+ */
+ native byte[] getGandivaSupportedFunctions();
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java
new file mode 100644
index 000000000..010d644d1
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Filter.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.EvaluatorClosedException;
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.expression.ArrowTypeHelper;
+import org.apache.arrow.gandiva.expression.Condition;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.vector.ipc.message.ArrowBuffer;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class provides a mechanism to filter a RecordBatch by evaluating a condition expression.
+ *
+ * <p>Follow these steps to use this class:
+ * 1) Use the static method make() to create an instance of this class that evaluates a condition.
+ * 2) Invoke the method evaluate() to evaluate the filter against a RecordBatch.
+ * 3) Invoke close() to release resources.
+ */
+public class Filter {
+
+  private static final Logger logger = LoggerFactory.getLogger(Filter.class);
+
+  // JNI entry points used to evaluate and close the native filter.
+  private final JniWrapper wrapper;
+  // Identifier returned by JniWrapper.buildFilter() for the generated module.
+  private final long moduleId;
+  private final Schema schema;
+  // Set by close(); evaluate() refuses to run once this is true.
+  private boolean closed;
+
+  private Filter(JniWrapper wrapper, long moduleId, Schema schema) {
+    this.wrapper = wrapper;
+    this.moduleId = moduleId;
+    this.schema = schema;
+    this.closed = false;
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the condition expression, using the
+   * default configuration. Invoke evaluate() against a RecordBatch to evaluate the filter on
+   * this record batch.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used to
+   *        create the TreeNodes
+   * @param condition condition to be evaluated against data
+   * @return A native filter object that can be used to invoke on a RecordBatch
+   * @throws GandivaException if the configuration or the native filter cannot be created
+   */
+  public static Filter make(Schema schema, Condition condition) throws GandivaException {
+    return make(schema, condition, JniLoader.getDefaultConfiguration());
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the condition expression, using the
+   * configuration described by {@code configOptions}. Invoke evaluate() against a RecordBatch
+   * to evaluate the filter on this record batch.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used to
+   *        create the TreeNodes
+   * @param condition condition to be evaluated against data
+   * @param configOptions ConfigOptions parameter
+   * @return A native filter object that can be used to invoke on a RecordBatch
+   * @throws GandivaException if the configuration or the native filter cannot be created
+   */
+  public static Filter make(Schema schema, Condition condition, ConfigurationBuilder.ConfigOptions configOptions)
+      throws GandivaException {
+    return make(schema, condition, JniLoader.getConfiguration(configOptions));
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the condition expression. Invoke
+   * evaluate() against a RecordBatch to evaluate the filter on this record batch.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used to
+   *        create the TreeNodes
+   * @param condition condition to be evaluated against data
+   * @param optimize Flag to choose if the generated llvm code is to be optimized
+   * @return A native filter object that can be used to invoke on a RecordBatch
+   * @throws GandivaException if the configuration or the native filter cannot be created
+   * @deprecated use {@link #make(Schema, Condition, ConfigurationBuilder.ConfigOptions)} with
+   *             options built through {@code withOptimize(optimize)}, which is what this
+   *             method does internally.
+   */
+  @Deprecated
+  public static Filter make(Schema schema, Condition condition, boolean optimize) throws GandivaException {
+    return make(schema, condition, JniLoader.getConfiguration((new ConfigurationBuilder.ConfigOptions())
+        .withOptimize(optimize)));
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the condition expression. Invoke
+   * evaluate() against a RecordBatch to evaluate the filter on this record batch.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used to
+   *        create the TreeNodes
+   * @param condition condition to be evaluated against data
+   * @param configurationId Custom configuration created through config builder.
+   * @return A native filter object that can be used to invoke on a RecordBatch
+   * @throws GandivaException if the JNI library cannot be loaded or the native filter cannot be built
+   */
+  public static Filter make(Schema schema, Condition condition, long configurationId)
+      throws GandivaException {
+    // Invoke the JNI layer to create the LLVM module representing the filter.
+    GandivaTypes.Condition conditionBuf = condition.toProtobuf();
+    GandivaTypes.Schema schemaBuf = ArrowTypeHelper.arrowSchemaToProtobuf(schema);
+    JniWrapper wrapper = JniLoader.getInstance().getWrapper();
+    long moduleId = wrapper.buildFilter(schemaBuf.toByteArray(),
+        conditionBuf.toByteArray(), configurationId);
+    logger.debug("Created module for the filter with id {}", moduleId);
+    return new Filter(wrapper, moduleId, schema);
+  }
+
+  /**
+   * Invoke this function to evaluate a filter against a recordBatch.
+   *
+   * @param recordBatch Record batch including the data
+   * @param selectionVector Result of applying the filter on the data
+   * @throws GandivaException if this filter is closed, the selection vector is too small,
+   *         or the native evaluation fails
+   */
+  public void evaluate(ArrowRecordBatch recordBatch, SelectionVector selectionVector)
+      throws GandivaException {
+    evaluate(recordBatch.getLength(), recordBatch.getBuffers(), recordBatch.getBuffersLayout(),
+        selectionVector);
+  }
+
+  /**
+   * Invoke this function to evaluate filter against a set of arrow buffers. (this is an optimised
+   * version that skips taking references).
+   *
+   * @param numRows number of rows.
+   * @param buffers List of input arrow buffers
+   * @param selectionVector Result of applying the filter on the data
+   * @throws GandivaException if this filter is closed, the selection vector is too small,
+   *         or the native evaluation fails
+   */
+  public void evaluate(int numRows, List<ArrowBuf> buffers,
+      SelectionVector selectionVector) throws GandivaException {
+    // Synthesize a layout in which the buffers are laid out back to back, each sized by
+    // its readable bytes.
+    List<ArrowBuffer> buffersLayout = new ArrayList<>();
+    long offset = 0;
+    for (ArrowBuf arrowBuf : buffers) {
+      long size = arrowBuf.readableBytes();
+      buffersLayout.add(new ArrowBuffer(offset, size));
+      offset += size;
+    }
+    evaluate(numRows, buffers, buffersLayout, selectionVector);
+  }
+
+  /**
+   * Common evaluation path: validates state, flattens the buffers into address/size arrays
+   * and invokes the native filter.
+   *
+   * @param numRows number of rows to evaluate
+   * @param buffers input buffers; their memory addresses are passed to native code
+   * @param buffersLayout layout entries supplying the size of each buffer
+   *        (assumed to contain one entry per element of {@code buffers})
+   * @param selectionVector receives the result; must have capacity for at least {@code numRows}
+   */
+  private void evaluate(int numRows, List<ArrowBuf> buffers, List<ArrowBuffer> buffersLayout,
+      SelectionVector selectionVector) throws GandivaException {
+    if (this.closed) {
+      throw new EvaluatorClosedException();
+    }
+    if (selectionVector.getMaxRecords() < numRows) {
+      // Parameterized logging: same message text, no manual string concatenation.
+      logger.error("selectionVector has capacity for {} rows, minimum required {}",
+          selectionVector.getMaxRecords(), numRows);
+      throw new GandivaException("SelectionVector too small");
+    }
+
+    long[] bufAddrs = new long[buffers.size()];
+    long[] bufSizes = new long[buffers.size()];
+
+    int idx = 0;
+    for (ArrowBuf buf : buffers) {
+      bufAddrs[idx++] = buf.memoryAddress();
+    }
+
+    idx = 0;
+    for (ArrowBuffer bufLayout : buffersLayout) {
+      bufSizes[idx++] = bufLayout.getSize();
+    }
+
+    int numRecords = wrapper.evaluateFilter(this.moduleId, numRows,
+        bufAddrs, bufSizes,
+        selectionVector.getType().getNumber(),
+        selectionVector.getBuffer().memoryAddress(), selectionVector.getBuffer().capacity());
+    // A negative return value leaves the selection vector's record count untouched.
+    if (numRecords >= 0) {
+      selectionVector.setRecordCount(numRecords);
+    }
+  }
+
+  /**
+   * Closes the LLVM module representing this filter. Calling this method on an already
+   * closed filter is a no-op.
+   */
+  public void close() throws GandivaException {
+    if (this.closed) {
+      return;
+    }
+
+    wrapper.closeFilter(this.moduleId);
+    this.closed = true;
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java
new file mode 100644
index 000000000..d01881843
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.util.List;
+
+import org.apache.arrow.vector.types.pojo.ArrowType;
+
+import com.google.common.base.MoreObjects;
+import com.google.common.base.Objects;
+
+/**
+ * POJO to define a function signature: a name, a return type and the parameter types.
+ *
+ * <p>Function names are compared case-insensitively by {@link #equals(Object)}; return and
+ * parameter types are compared with their own {@code equals}.
+ */
+public class FunctionSignature {
+  private final String name;
+  private final ArrowType returnType;
+  private final List<ArrowType> paramTypes;
+
+  public ArrowType getReturnType() {
+    return returnType;
+  }
+
+  public List<ArrowType> getParamTypes() {
+    return paramTypes;
+  }
+
+  public String getName() {
+    return name;
+  }
+
+  /**
+   * Ctor.
+   * @param name - name of the function.
+   * @param returnType - data type of return
+   * @param paramTypes - data type of input args.
+   */
+  public FunctionSignature(String name, ArrowType returnType, List<ArrowType> paramTypes) {
+    this.name = name;
+    this.returnType = returnType;
+    this.paramTypes = paramTypes;
+  }
+
+  /**
+   * Override equals.
+   * @param signature - signature to compare
+   * @return true if equal and false if not.
+   */
+  @Override
+  public boolean equals(Object signature) {
+    if (this == signature) {
+      return true;
+    }
+    if (signature == null) {
+      return false;
+    }
+    if (getClass() != signature.getClass()) {
+      return false;
+    }
+    final FunctionSignature other = (FunctionSignature) signature;
+    return this.name.equalsIgnoreCase(other.name) &&
+        Objects.equal(this.returnType, other.returnType) &&
+        Objects.equal(this.paramTypes, other.paramTypes);
+  }
+
+  /**
+   * Hash code consistent with the case-insensitive name comparison in {@link #equals(Object)}.
+   *
+   * <p>The name is lower-cased with {@code Locale.ROOT} so the result does not depend on the
+   * JVM's default locale (with a Turkish default locale, "I" would otherwise lower-case to a
+   * dotless i and hash differently from "i" although the two names compare equal).
+   */
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(this.name.toLowerCase(java.util.Locale.ROOT), this.returnType, this.paramTypes);
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this)
+        .add("name ", name)
+        .add("return type ", returnType)
+        .add("param types ", paramTypes)
+        .toString();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java
new file mode 100644
index 000000000..676956a34
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import static java.util.UUID.randomUUID;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.StandardCopyOption;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+
+/**
+ * This class handles loading of the jni library, and acts as a bridge for the native functions.
+ *
+ * <p>It also caches the native configuration ids created for each distinct
+ * {@link ConfigurationBuilder.ConfigOptions}. All mutations of that cache are performed while
+ * holding the {@code ConfigurationBuilder.class} monitor.
+ */
+class JniLoader {
+  private static final String LIBRARY_NAME = "gandiva_jni";
+
+  private static volatile JniLoader INSTANCE;
+  // 0 means "default configuration not created yet".
+  private static volatile long defaultConfiguration = 0L;
+  private static final ConcurrentMap<ConfigurationBuilder.ConfigOptions, Long> configurationMap
+      = new ConcurrentHashMap<>();
+
+  private final JniWrapper wrapper;
+
+  private JniLoader() {
+    this.wrapper = new JniWrapper();
+  }
+
+  /**
+   * Returns the singleton loader, extracting and loading the native library on first use.
+   *
+   * @throws GandivaException if the native library cannot be extracted from the jar
+   */
+  static JniLoader getInstance() throws GandivaException {
+    if (INSTANCE == null) {
+      synchronized (JniLoader.class) {
+        if (INSTANCE == null) {
+          INSTANCE = setupInstance();
+        }
+      }
+    }
+    return INSTANCE;
+  }
+
+  private static JniLoader setupInstance() throws GandivaException {
+    try {
+      String tempDir = System.getProperty("java.io.tmpdir");
+      loadGandivaLibraryFromJar(tempDir);
+      return new JniLoader();
+    } catch (IOException ioException) {
+      throw new GandivaException("unable to create native instance", ioException);
+    }
+  }
+
+  /** Copies the platform-specific library out of the jar into {@code tmpDir} and loads it. */
+  private static void loadGandivaLibraryFromJar(final String tmpDir)
+      throws IOException, GandivaException {
+    final String libraryToLoad = System.mapLibraryName(LIBRARY_NAME);
+    final File libraryFile = moveFileFromJarToTemp(tmpDir, libraryToLoad);
+    System.load(libraryFile.getAbsolutePath());
+  }
+
+  /** Copies the classpath resource {@code libraryToLoad} into a fresh file under {@code tmpDir}. */
+  private static File moveFileFromJarToTemp(final String tmpDir, String libraryToLoad)
+      throws IOException, GandivaException {
+    final File temp = setupFile(tmpDir, libraryToLoad);
+    try (final InputStream is = JniLoader.class.getClassLoader()
+        .getResourceAsStream(libraryToLoad)) {
+      if (is == null) {
+        throw new GandivaException(libraryToLoad + " was not found inside JAR.");
+      } else {
+        Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING);
+      }
+    }
+    return temp;
+  }
+
+  /** Creates an empty, uniquely named file under {@code tmpDir} that is deleted on JVM exit. */
+  private static File setupFile(String tmpDir, String libraryToLoad)
+      throws IOException, GandivaException {
+    // accommodate multiple processes running with gandiva jar.
+    // length should be ok since uuid is only 36 characters.
+    final String randomizeFileName = libraryToLoad + randomUUID();
+    final File temp = new File(tmpDir, randomizeFileName);
+    if (temp.exists() && !temp.delete()) {
+      throw new GandivaException("File: " + temp.getAbsolutePath() +
+          " already exists and cannot be removed.");
+    }
+    if (!temp.createNewFile()) {
+      throw new GandivaException("File: " + temp.getAbsolutePath() +
+          " could not be created.");
+    }
+    temp.deleteOnExit();
+    return temp;
+  }
+
+  /**
+   * Returns the jni wrapper.
+   */
+  JniWrapper getWrapper() throws GandivaException {
+    return wrapper;
+  }
+
+  /**
+   * Returns the native configuration id for {@code configOptions}, creating and caching it on
+   * first use.
+   *
+   * @param configOptions options describing the configuration
+   * @return the native configuration id
+   * @throws GandivaException if the native library cannot be loaded
+   */
+  static long getConfiguration(ConfigurationBuilder.ConfigOptions configOptions) throws GandivaException {
+    // Read the map once and null-check the result: a containsKey() followed by a separate
+    // get() can observe a concurrent removeConfiguration() in between and unbox null.
+    Long cached = configurationMap.get(configOptions);
+    if (cached != null) {
+      return cached;
+    }
+    synchronized (ConfigurationBuilder.class) {
+      cached = configurationMap.get(configOptions);
+      if (cached != null) {
+        return cached;
+      }
+      JniLoader.getInstance(); // setup
+      long configInstance = new ConfigurationBuilder()
+          .buildConfigInstance(configOptions);
+      configurationMap.put(configOptions, configInstance);
+      if (ConfigurationBuilder.ConfigOptions.getDefault().equals(configOptions)) {
+        defaultConfiguration = configInstance;
+      }
+      return configInstance;
+    }
+  }
+
+  /**
+   * Get the default configuration to invoke gandiva.
+   * @return default configuration
+   * @throws GandivaException if unable to get native builder instance.
+   */
+  static long getDefaultConfiguration() throws GandivaException {
+    if (defaultConfiguration == 0L) {
+      synchronized (ConfigurationBuilder.class) {
+        if (defaultConfiguration == 0L) {
+          JniLoader.getInstance(); // setup
+          ConfigurationBuilder.ConfigOptions defaultConfigOptions = ConfigurationBuilder.ConfigOptions.getDefault();
+          defaultConfiguration = new ConfigurationBuilder()
+              .buildConfigInstance(defaultConfigOptions);
+          configurationMap.put(defaultConfigOptions, defaultConfiguration);
+        }
+      }
+    }
+    return defaultConfiguration;
+  }
+
+  /**
+   * Remove the configuration and release its native instance. Does nothing if no configuration
+   * is cached for {@code configOptions}.
+   */
+  static void removeConfiguration(ConfigurationBuilder.ConfigOptions configOptions) {
+    if (!configurationMap.containsKey(configOptions)) {
+      return;
+    }
+    synchronized (ConfigurationBuilder.class) {
+      // remove() doubles as the re-check under the lock: null means another thread got here first.
+      Long configInstance = configurationMap.remove(configOptions);
+      if (configInstance == null) {
+        return;
+      }
+      (new ConfigurationBuilder()).releaseConfigInstance(configInstance);
+      if (configOptions.equals(ConfigurationBuilder.ConfigOptions.getDefault())) {
+        defaultConfiguration = 0;
+      }
+    }
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
new file mode 100644
index 000000000..520ef5f44
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+
+/**
+ * This class is implemented in JNI. This provides the Java interface
+ * to invoke functions in JNI.
+ * This file is used to generate the .h files required for jni. Avoid all
+ * external dependencies in this file.
+ */
+public class JniWrapper {
+
+ /**
+ * Generates the projector module to evaluate the expressions with
+ * custom configuration.
+ *
+ * @param schemaBuf The schema serialized as a protobuf. See Types.proto
+ * to see the protobuf specification
+ * @param exprListBuf The serialized protobuf of the expression vector. Each
+ * expression is created using TreeBuilder::MakeExpression.
+ * @param selectionVectorType type of selection vector
+ * @param configId Configuration to gandiva.
+ * @return A moduleId that is passed to the evaluateProjector() and closeProjector() methods
+ * @throws GandivaException declared by the native method; raised in case of errors
+ *
+ */
+ native long buildProjector(byte[] schemaBuf, byte[] exprListBuf,
+ int selectionVectorType,
+ long configId) throws GandivaException;
+
+ /**
+ * Evaluate the expressions represented by the moduleId on a record batch
+ * and store the output in ValueVectors. Throws an exception in case of errors
+ *
+ * @param expander VectorExpander object. Used for callbacks from cpp.
+ * @param moduleId moduleId representing expressions. Created using a call to
+ * buildProjector
+ * @param numRows Number of rows in the record batch
+ * @param bufAddrs An array of memory addresses. Each memory address points to
+ * a validity vector or a data vector (will add support for offset
+ * vectors later).
+ * @param bufSizes An array of buffer sizes. For each memory address in bufAddrs,
+ * the size of the buffer is present in bufSizes
+ * @param selectionVectorType type of selection vector
+ * @param selectionVectorSize NOTE(review): presumably the number of entries in the selection
+ * vector; Projector passes the row count when no selection vector
+ * is used - confirm against the native implementation
+ * @param selectionVectorBufferAddr memory address of the selection vector buffer
+ * (Projector passes 0 when no selection vector is used)
+ * @param selectionVectorBufferSize size of the selection vector buffer
+ * (Projector passes 0 when no selection vector is used)
+ * @param outAddrs An array of output buffers, including the validity and data
+ * addresses.
+ * @param outSizes The allocated size of the output buffers. On successful evaluation,
+ * the result is stored in the output buffers
+ * @throws GandivaException declared by the native method; raised in case of errors
+ */
+ native void evaluateProjector(Object expander, long moduleId, int numRows,
+ long[] bufAddrs, long[] bufSizes,
+ int selectionVectorType, int selectionVectorSize,
+ long selectionVectorBufferAddr, long selectionVectorBufferSize,
+ long[] outAddrs, long[] outSizes) throws GandivaException;
+
+ /**
+ * Closes the projector referenced by moduleId.
+ *
+ * @param moduleId moduleId that needs to be closed
+ */
+ native void closeProjector(long moduleId);
+
+ /**
+ * Generates the filter module to evaluate the condition expression with
+ * custom configuration.
+ *
+ * @param schemaBuf The schema serialized as a protobuf. See Types.proto
+ * to see the protobuf specification
+ * @param conditionBuf The serialized protobuf of the condition expression. Each
+ * expression is created using TreeBuilder::MakeCondition
+ * @param configId Configuration to gandiva.
+ * @return A moduleId that is passed to the evaluateFilter() and closeFilter() methods
+ * @throws GandivaException declared by the native method; raised in case of errors
+ *
+ */
+ native long buildFilter(byte[] schemaBuf, byte[] conditionBuf,
+ long configId) throws GandivaException;
+
+ /**
+ * Evaluate the filter represented by the moduleId on a record batch
+ * and store the output in buffer 'outAddr'. Throws an exception in case of errors
+ *
+ * @param moduleId moduleId representing expressions. Created using a call to
+ * buildFilter
+ * @param numRows Number of rows in the record batch
+ * @param bufAddrs An array of memory addresses. Each memory address points to
+ * a validity vector or a data vector (will add support for offset
+ * vectors later).
+ * @param bufSizes An array of buffer sizes. For each memory address in bufAddrs,
+ * the size of the buffer is present in bufSizes
+ * @param selectionVectorType type of selection vector
+ * @param outAddr output buffer, whose type is represented by selectionVectorType
+ * @param outSize The allocated size of the output buffer. On successful evaluation,
+ * the result is stored in the output buffer
+ * @return an int that the caller (Filter) uses as the selection vector's record count
+ * when it is non-negative
+ * @throws GandivaException declared by the native method; raised in case of errors
+ */
+ native int evaluateFilter(long moduleId, int numRows, long[] bufAddrs, long[] bufSizes,
+ int selectionVectorType,
+ long outAddr, long outSize) throws GandivaException;
+
+ /**
+ * Closes the filter referenced by moduleId.
+ *
+ * @param moduleId moduleId that needs to be closed
+ */
+ native void closeFilter(long moduleId);
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
new file mode 100644
index 000000000..471ddbced
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.EvaluatorClosedException;
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.exceptions.UnsupportedTypeException;
+import org.apache.arrow.gandiva.expression.ArrowTypeHelper;
+import org.apache.arrow.gandiva.expression.ExpressionTree;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.vector.BaseVariableWidthVector;
+import org.apache.arrow.vector.FixedWidthVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VariableWidthVector;
+import org.apache.arrow.vector.ipc.message.ArrowBuffer;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+/**
+ * This class provides a mechanism to evaluate a set of expressions against a RecordBatch.
+ * Follow these steps to use this class:
+ * 1) Use the static method make() to create an instance of this class that evaluates a
+ * set of expressions
+ * 2) Invoke the method evaluate() to evaluate these expressions against a RecordBatch
+ * 3) Invoke close() to release resources
+ */
+public class Projector {
+ private static final org.slf4j.Logger logger =
+ org.slf4j.LoggerFactory.getLogger(Projector.class);
+
+ // JNI wrapper obtained from JniLoader when the projector is built.
+ private JniWrapper wrapper;
+ // Identifier returned by JniWrapper.buildProjector() for the generated module.
+ private final long moduleId;
+ // Schema the projector was built with.
+ private final Schema schema;
+ // Number of expressions the projector was built with (exprs.size() at make() time).
+ private final int numExprs;
+ // Initialised to false by the constructor.
+ private boolean closed;
+
+ /**
+  * Creates an open projector bound to an already-built native module.
+  *
+  * @param wrapper JNI wrapper used for the native calls
+  * @param moduleId identifier of the native module
+  * @param schema schema the module was built with
+  * @param numExprs number of expressions in the module
+  */
+ private Projector(JniWrapper wrapper, long moduleId, Schema schema, int numExprs) {
+   // A new projector always starts out open.
+   this.closed = false;
+   this.numExprs = numExprs;
+   this.schema = schema;
+   this.moduleId = moduleId;
+   this.wrapper = wrapper;
+ }
+
+ /**
+  * Generates LLVM code to evaluate the list of project expressions, using the default
+  * configuration and no selection vector ({@code SV_NONE}).
+  * Invoke evaluate() against a RecordBatch to evaluate the record batch
+  * against these projections.
+  *
+  * @param schema Table schema. The field names in the schema should match the fields used
+  *               to create the TreeNodes
+  * @param exprs List of expressions to be evaluated against data
+  *
+  * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+  */
+ public static Projector make(Schema schema, List<ExpressionTree> exprs)
+     throws GandivaException {
+   final long defaultConfigId = JniLoader.getDefaultConfiguration();
+   return make(schema, exprs, SelectionVectorType.SV_NONE, defaultConfigId);
+ }
+
+ /**
+  * Generates LLVM code to evaluate the list of project expressions, using the configuration
+  * described by {@code configOptions} and no selection vector ({@code SV_NONE}).
+  * Invoke evaluate() against a RecordBatch to evaluate the record batch
+  * against these projections.
+  *
+  * @param schema Table schema. The field names in the schema should match the fields used
+  *               to create the TreeNodes
+  * @param exprs List of expressions to be evaluated against data
+  * @param configOptions ConfigOptions parameter
+  *
+  * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+  */
+ public static Projector make(Schema schema, List<ExpressionTree> exprs,
+     ConfigurationBuilder.ConfigOptions configOptions) throws GandivaException {
+   final long configId = JniLoader.getConfiguration(configOptions);
+   return make(schema, exprs, SelectionVectorType.SV_NONE, configId);
+ }
+
+ /**
+  * Generates LLVM code to evaluate the list of project expressions with no selection vector
+  * ({@code SV_NONE}), choosing only whether the generated code is optimized.
+  * Invoke evaluate() against a RecordBatch to evaluate the record batch
+  * against these projections.
+  *
+  * @param schema Table schema. The field names in the schema should match the fields used
+  *               to create the TreeNodes
+  * @param exprs List of expressions to be evaluated against data
+  * @param optimize Flag to choose if the generated llvm code is to be optimized
+  *
+  * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+  * @deprecated use the overload taking {@code ConfigurationBuilder.ConfigOptions}, built with
+  *             {@code withOptimize(optimize)}, which is what this method does internally.
+  */
+ @Deprecated
+ public static Projector make(Schema schema, List<ExpressionTree> exprs, boolean optimize)
+     throws GandivaException {
+   final ConfigurationBuilder.ConfigOptions options =
+       (new ConfigurationBuilder.ConfigOptions()).withOptimize(optimize);
+   final long configId = JniLoader.getConfiguration(options);
+   return make(schema, exprs, SelectionVectorType.SV_NONE, configId);
+ }
+
+ /**
+  * Generates LLVM code to evaluate the list of project expressions for the given selection
+  * vector type, using the default configuration.
+  * Invoke evaluate() against a RecordBatch to evaluate the record batch
+  * against these projections.
+  *
+  * @param schema Table schema. The field names in the schema should match the fields used
+  *               to create the TreeNodes
+  * @param exprs List of expressions to be evaluated against data
+  * @param selectionVectorType type of selection vector
+  *
+  * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+  */
+ public static Projector make(Schema schema, List<ExpressionTree> exprs,
+     SelectionVectorType selectionVectorType)
+     throws GandivaException {
+   final long defaultConfigId = JniLoader.getDefaultConfiguration();
+   return make(schema, exprs, selectionVectorType, defaultConfigId);
+ }
+
+ /**
+  * Generates LLVM code to evaluate the list of project expressions for the given selection
+  * vector type, using the configuration described by {@code configOptions}.
+  * Invoke evaluate() against a RecordBatch to evaluate the record batch
+  * against these projections.
+  *
+  * @param schema Table schema. The field names in the schema should match the fields used
+  *               to create the TreeNodes
+  * @param exprs List of expressions to be evaluated against data
+  * @param selectionVectorType type of selection vector
+  * @param configOptions ConfigOptions parameter
+  *
+  * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+  */
+ public static Projector make(Schema schema, List<ExpressionTree> exprs, SelectionVectorType selectionVectorType,
+     ConfigurationBuilder.ConfigOptions configOptions) throws GandivaException {
+   final long configId = JniLoader.getConfiguration(configOptions);
+   return make(schema, exprs, selectionVectorType, configId);
+ }
+
+ /**
+  * Generates LLVM code to evaluate the list of project expressions for the given selection
+  * vector type, choosing only whether the generated code is optimized.
+  * Invoke evaluate() against a RecordBatch to evaluate the record batch
+  * against these projections.
+  *
+  * @param schema Table schema. The field names in the schema should match the fields used
+  *               to create the TreeNodes
+  * @param exprs List of expressions to be evaluated against data
+  * @param selectionVectorType type of selection vector
+  * @param optimize Flag to choose if the generated llvm code is to be optimized
+  *
+  * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+  * @deprecated use the overload taking {@code ConfigurationBuilder.ConfigOptions}, built with
+  *             {@code withOptimize(optimize)}, which is what this method does internally.
+  */
+ @Deprecated
+ public static Projector make(Schema schema, List<ExpressionTree> exprs,
+     SelectionVectorType selectionVectorType, boolean optimize)
+     throws GandivaException {
+   final ConfigurationBuilder.ConfigOptions options =
+       (new ConfigurationBuilder.ConfigOptions()).withOptimize(optimize);
+   final long configId = JniLoader.getConfiguration(options);
+   return make(schema, exprs, selectionVectorType, configId);
+ }
+
+ /**
+  * Generates LLVM code to evaluate the list of project expressions. This is the overload all
+  * other {@code make()} variants delegate to.
+  * Invoke evaluate() against a RecordBatch to evaluate the record batch
+  * against these projections.
+  *
+  * @param schema Table schema. The field names in the schema should match the fields used
+  *               to create the TreeNodes
+  * @param exprs List of expressions to be evaluated against data
+  * @param selectionVectorType type of selection vector
+  * @param configurationId Custom configuration created through config builder.
+  *
+  * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+  */
+ public static Projector make(Schema schema, List<ExpressionTree> exprs,
+     SelectionVectorType selectionVectorType,
+     long configurationId) throws GandivaException {
+   // Collect every expression into a single protobuf expression list.
+   GandivaTypes.ExpressionList.Builder exprListBuilder = GandivaTypes.ExpressionList.newBuilder();
+   for (ExpressionTree tree : exprs) {
+     exprListBuilder.addExprs(tree.toProtobuf());
+   }
+
+   // Serialize the schema, then hand both buffers to the JNI layer, which creates the
+   // LLVM module representing the expressions.
+   GandivaTypes.Schema protoSchema = ArrowTypeHelper.arrowSchemaToProtobuf(schema);
+   JniWrapper jni = JniLoader.getInstance().getWrapper();
+   byte[] schemaBytes = protoSchema.toByteArray();
+   byte[] exprListBytes = exprListBuilder.build().toByteArray();
+   long builtModuleId = jni.buildProjector(schemaBytes, exprListBytes,
+       selectionVectorType.getNumber(), configurationId);
+   logger.debug("Created module for the projector with id {}", builtModuleId);
+   return new Projector(jni, builtModuleId, schema, exprs.size());
+ }
+
+ /**
+  * Evaluates the projector's expressions against every row of a record batch.
+  *
+  * @param recordBatch Record batch including the data
+  * @param outColumns Result of applying the project on the data
+  * @throws GandivaException if the projector is closed or the output vectors are not usable
+  */
+ public void evaluate(ArrowRecordBatch recordBatch, List<ValueVector> outColumns)
+     throws GandivaException {
+   // No selection vector is used, so the number of output rows equals the batch length
+   // and the selection vector address and size are passed as 0.
+   final int rowCount = recordBatch.getLength();
+   evaluate(rowCount, recordBatch.getBuffers(), recordBatch.getBuffersLayout(),
+       SelectionVectorType.SV_NONE.getNumber(), rowCount, 0, 0, outColumns);
+ }
+
+ /**
+  * Evaluates the projector's expressions against a set of arrow buffers.
+  * (this is an optimised version that skips taking references).
+  *
+  * @param numRows number of rows.
+  * @param buffers List of input arrow buffers
+  * @param outColumns Result of applying the project on the data
+  * @throws GandivaException if the projector is closed or the output vectors are not usable
+  */
+ public void evaluate(int numRows, List<ArrowBuf> buffers,
+     List<ValueVector> outColumns) throws GandivaException {
+   // Derive a layout from the buffers themselves: each entry starts where the previous one
+   // ended and spans the buffer's readable bytes.
+   final List<ArrowBuffer> layout = new ArrayList<>();
+   long nextOffset = 0;
+   for (ArrowBuf input : buffers) {
+     final long length = input.readableBytes();
+     layout.add(new ArrowBuffer(nextOffset, length));
+     nextOffset += length;
+   }
+
+   // No selection vector: all numRows rows are produced.
+   final int noSelection = SelectionVectorType.SV_NONE.getNumber();
+   evaluate(numRows, buffers, layout, noSelection, numRows, 0, 0, outColumns);
+ }
+
+ /**
+  * Evaluates the projector's expressions against the rows of an {@link ArrowRecordBatch}
+  * that are picked by a selection vector.
+  *
+  * @param recordBatch The data to evaluate against.
+  * @param selectionVector Selection vector which stores the selected rows.
+  * @param outColumns Result of applying the project on the data
+  * @throws GandivaException if the projector is closed or the output vectors are not usable
+  */
+ public void evaluate(ArrowRecordBatch recordBatch,
+     SelectionVector selectionVector, List<ValueVector> outColumns)
+     throws GandivaException {
+   final int rowCount = recordBatch.getLength();
+   final List<ArrowBuf> inputBuffers = recordBatch.getBuffers();
+   final List<ArrowBuffer> inputLayout = recordBatch.getBuffersLayout();
+
+   // The selection vector supplies its type, the number of selected rows, and the address
+   // and capacity of its backing buffer.
+   final int svType = selectionVector.getType().getNumber();
+   final int selectedRows = selectionVector.getRecordCount();
+   final long svAddress = selectionVector.getBuffer().memoryAddress();
+   final long svCapacity = selectionVector.getBuffer().capacity();
+
+   evaluate(rowCount, inputBuffers, inputLayout,
+       svType, selectedRows, svAddress, svCapacity, outColumns);
+ }
+
+ /**
+  * Evaluates the projector's expressions against a set of arrow buffers, restricted to the
+  * positions picked by a selection vector.
+  * (this is an optimised version that skips taking references).
+  *
+  * @param numRows number of rows.
+  * @param buffers List of input arrow buffers
+  * @param selectionVector Selection vector which stores the selected rows.
+  * @param outColumns Result of applying the project on the data
+  * @throws GandivaException if the projector is closed or the output vectors are not usable
+  */
+ public void evaluate(int numRows, List<ArrowBuf> buffers,
+     SelectionVector selectionVector,
+     List<ValueVector> outColumns) throws GandivaException {
+   // Derive a layout from the buffers themselves: each entry starts where the previous one
+   // ended and spans the buffer's readable bytes.
+   final List<ArrowBuffer> layout = new ArrayList<>();
+   long nextOffset = 0;
+   for (ArrowBuf input : buffers) {
+     final long length = input.readableBytes();
+     layout.add(new ArrowBuffer(nextOffset, length));
+     nextOffset += length;
+   }
+
+   // The selection vector supplies its type, the number of selected rows, and the address
+   // and capacity of its backing buffer.
+   final int svType = selectionVector.getType().getNumber();
+   final int selectedRows = selectionVector.getRecordCount();
+   final long svAddress = selectionVector.getBuffer().memoryAddress();
+   final long svCapacity = selectionVector.getBuffer().capacity();
+
+   evaluate(numRows, buffers, layout, svType, selectedRows, svAddress, svCapacity, outColumns);
+ }
+
+ /**
+  * Shared implementation behind the public {@code evaluate} overloads. Flattens the input and
+  * output buffers into address/size arrays and hands them to the native projector.
+  *
+  * @param numRows number of rows in the input buffers
+  * @param buffers input arrow buffers; their memory addresses are passed to the native code
+  * @param buffersLayout layout entries that supply the size recorded for each input buffer
+  * @param selectionVectorType numeric value of the selection vector type
+  * @param selectionVectorRecordCount number of rows to produce; set as the value count of
+  *     every output vector
+  * @param selectionVectorAddr memory address of the selection vector buffer (callers pass 0
+  *     when no selection vector is used)
+  * @param selectionVectorSize capacity of the selection vector buffer (callers pass 0 when no
+  *     selection vector is used)
+  * @param outColumns output vectors, one per expression
+  * @throws EvaluatorClosedException if this projector has already been closed
+  * @throws UnsupportedTypeException if an output vector is neither fixed nor variable width,
+  *     or is a variable width vector that cannot be resized through {@link VectorExpander}
+  * @throws GandivaException if the number of output vectors differs from the number of
+  *     expressions
+  */
+ private void evaluate(int numRows, List<ArrowBuf> buffers, List<ArrowBuffer> buffersLayout,
+     int selectionVectorType, int selectionVectorRecordCount,
+     long selectionVectorAddr, long selectionVectorSize,
+     List<ValueVector> outColumns) throws GandivaException {
+   if (this.closed) {
+     throw new EvaluatorClosedException();
+   }
+
+   if (numExprs != outColumns.size()) {
+     // Parameterized logging, consistent with the other log statements in this class.
+     logger.info("Expected {} columns, got {}", numExprs, outColumns.size());
+     throw new GandivaException("Incorrect number of columns for the output vector");
+   }
+
+   long[] bufAddrs = new long[buffers.size()];
+   long[] bufSizes = new long[buffers.size()];
+
+   int idx = 0;
+   for (ArrowBuf buf : buffers) {
+     bufAddrs[idx++] = buf.memoryAddress();
+   }
+
+   idx = 0;
+   for (ArrowBuffer bufLayout : buffersLayout) {
+     bufSizes[idx++] = bufLayout.getSize();
+   }
+
+   boolean hasVariableWidthColumns = false;
+   BaseVariableWidthVector[] resizableVectors = new BaseVariableWidthVector[outColumns.size()];
+   // Up to three buffers per output vector: validity, offsets (variable width only) and data.
+   long[] outAddrs = new long[3 * outColumns.size()];
+   long[] outSizes = new long[3 * outColumns.size()];
+   idx = 0;
+   int outColumnIdx = 0;
+   for (ValueVector valueVector : outColumns) {
+     boolean isFixedWidth = valueVector instanceof FixedWidthVector;
+     boolean isVarWidth = valueVector instanceof VariableWidthVector;
+     if (!isFixedWidth && !isVarWidth) {
+       throw new UnsupportedTypeException(
+           "Unsupported value vector type " + valueVector.getField().getFieldType());
+     }
+     // Only BaseVariableWidthVector can be resized by VectorExpander; reject any other
+     // variable width implementation explicitly instead of failing with a ClassCastException.
+     if (isVarWidth && !(valueVector instanceof BaseVariableWidthVector)) {
+       throw new UnsupportedTypeException(
+           "Unsupported variable width vector type " + valueVector.getField().getFieldType());
+     }
+
+     // Set the value count BEFORE reading buffer addresses: setValueCount may grow the
+     // vector's buffers, which would leave previously captured addresses stale.
+     valueVector.setValueCount(selectionVectorRecordCount);
+
+     outAddrs[idx] = valueVector.getValidityBuffer().memoryAddress();
+     outSizes[idx++] = valueVector.getValidityBuffer().capacity();
+     if (isVarWidth) {
+       outAddrs[idx] = valueVector.getOffsetBuffer().memoryAddress();
+       outSizes[idx++] = valueVector.getOffsetBuffer().capacity();
+       hasVariableWidthColumns = true;
+
+       // save vector to allow for resizing.
+       resizableVectors[outColumnIdx] = (BaseVariableWidthVector) valueVector;
+     }
+     outAddrs[idx] = valueVector.getDataBuffer().memoryAddress();
+     outSizes[idx++] = valueVector.getDataBuffer().capacity();
+
+     outColumnIdx++;
+   }
+
+   // The expander is only needed when at least one output data buffer may have to grow.
+   wrapper.evaluateProjector(
+       hasVariableWidthColumns ? new VectorExpander(resizableVectors) : null,
+       this.moduleId, numRows, bufAddrs, bufSizes,
+       selectionVectorType, selectionVectorRecordCount,
+       selectionVectorAddr, selectionVectorSize,
+       outAddrs, outSizes);
+ }
+
+ /**
+  * Closes the LLVM module representing this evaluator. Calling this method on an already
+  * closed projector does nothing.
+  *
+  * @throws GandivaException declared for failures reported while closing the native projector
+  */
+ public void close() throws GandivaException {
+   if (!this.closed) {
+     // Mark as closed only after the native projector was released successfully.
+     wrapper.closeProjector(this.moduleId);
+     this.closed = true;
+   }
+ }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVector.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVector.java
new file mode 100644
index 000000000..2af88b526
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVector.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import static org.apache.arrow.memory.util.LargeMemoryUtil.capAtMaxInt;
+
+import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType;
+import org.apache.arrow.memory.ArrowBuf;
+
+/**
+ * A selection vector contains the indexes of "selected" records in a row batch. It is backed by an
+ * arrow buffer.
+ * Client manages the lifecycle of the arrow buffer - to release the reference.
+ */
+public abstract class SelectionVector {
+  private int recordCount;
+  private final ArrowBuf buffer;
+
+  /**
+   * Creates a selection vector on top of the given buffer. The record count starts at 0.
+   *
+   * @param buffer arrow buffer that stores the selected indexes
+   */
+  public SelectionVector(ArrowBuf buffer) {
+    this.buffer = buffer;
+  }
+
+  /**
+   * Returns the arrow buffer backing this selection vector.
+   */
+  public final ArrowBuf getBuffer() {
+    return this.buffer;
+  }
+
+  /**
+   * The maximum number of records that the selection vector can hold, capped at
+   * the maximum int value.
+   */
+  public final int getMaxRecords() {
+    return capAtMaxInt(buffer.capacity() / getRecordSize());
+  }
+
+  /**
+   * The number of records held by the selection vector.
+   */
+  public final int getRecordCount() {
+    return this.recordCount;
+  }
+
+  /**
+   * Set the number of records in the selection vector.
+   *
+   * @param recordCount number of records; must be non-negative and fit in the buffer
+   * @throws IllegalArgumentException if the count is negative or exceeds the buffer capacity
+   */
+  final void setRecordCount(int recordCount) {
+    if (recordCount < 0) {
+      throw new IllegalArgumentException("recordCount " + recordCount + " is negative");
+    }
+    // Multiply as long: an int product could overflow and wrongly pass the capacity check.
+    final long requiredBytes = (long) recordCount * getRecordSize();
+    if (requiredBytes > buffer.capacity()) {
+      throw new IllegalArgumentException("recordCount " + recordCount +
+          " of size " + getRecordSize() +
+          " exceeds buffer capacity " + buffer.capacity());
+    }
+
+    this.recordCount = recordCount;
+  }
+
+  /**
+   * Get the value at specified index.
+   */
+  public abstract int getIndex(int index);
+
+  /**
+   * Get the record size of the selection vector itself.
+   */
+  abstract int getRecordSize();
+
+  /**
+   * Get the selection vector type that matches the record size of this vector.
+   */
+  abstract SelectionVectorType getType();
+
+  /**
+   * Verifies that the index addresses one of the records currently held.
+   *
+   * @param index index about to be read
+   * @throws IllegalArgumentException if the index is negative or not below the record count
+   */
+  final void checkReadBounds(int index) {
+    if (index < 0) {
+      throw new IllegalArgumentException("index " + index + " is negative");
+    }
+    if (index >= this.recordCount) {
+      throw new IllegalArgumentException("index " + index + " is >= recordCount " + recordCount);
+    }
+  }
+
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt16.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt16.java
new file mode 100644
index 000000000..84c795b67
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt16.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType;
+import org.apache.arrow.memory.ArrowBuf;
+
+/**
+ * Selection vector with records of arrow type INT16.
+ */
+public class SelectionVectorInt16 extends SelectionVector {
+
+  /**
+   * Creates a selection vector with 2-byte records on top of the given buffer.
+   *
+   * @param buffer arrow buffer that stores the selected indexes
+   */
+  public SelectionVectorInt16(ArrowBuf buffer) {
+    super(buffer);
+  }
+
+  /** Each record occupies two bytes. */
+  @Override
+  public int getRecordSize() {
+    return 2;
+  }
+
+  @Override
+  public SelectionVectorType getType() {
+    return SelectionVectorType.SV_INT16;
+  }
+
+  /**
+   * Returns the selected row index stored at the given position.
+   *
+   * @param index position within the selection vector; must be below the record count
+   * @return the stored 16-bit entry, read as a char and therefore never negative
+   */
+  @Override
+  public int getIndex(int index) {
+    checkReadBounds(index);
+
+    // Widen before multiplying so that the byte offset cannot overflow int.
+    final long byteOffset = (long) index * getRecordSize();
+    return getBuffer().getChar(byteOffset);
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt32.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt32.java
new file mode 100644
index 000000000..c938f6691
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt32.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType;
+import org.apache.arrow.memory.ArrowBuf;
+
+/**
+ * Selection vector with records of arrow type INT32.
+ */
+public class SelectionVectorInt32 extends SelectionVector {
+
+  /**
+   * Creates a selection vector with 4-byte records on top of the given buffer.
+   *
+   * @param buffer arrow buffer that stores the selected indexes
+   */
+  public SelectionVectorInt32(ArrowBuf buffer) {
+    super(buffer);
+  }
+
+  /** Each record occupies four bytes. */
+  @Override
+  public int getRecordSize() {
+    return 4;
+  }
+
+  @Override
+  public SelectionVectorType getType() {
+    return SelectionVectorType.SV_INT32;
+  }
+
+  /**
+   * Returns the selected row index stored at the given position.
+   *
+   * @param index position within the selection vector; must be below the record count
+   * @return the 32-bit entry stored at that position
+   */
+  @Override
+  public int getIndex(int index) {
+    checkReadBounds(index);
+
+    // Widen before multiplying so that the byte offset cannot overflow int.
+    final long byteOffset = (long) index * getRecordSize();
+    return getBuffer().getInt(byteOffset);
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java
new file mode 100644
index 000000000..f22ebbd37
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.apache.arrow.vector.BaseVariableWidthVector;
+
+/**
+ * This class provides the functionality to expand output vectors using a callback mechanism from
+ * gandiva.
+ */
+public class VectorExpander {
+  private final BaseVariableWidthVector[] vectors;
+
+  /**
+   * Creates an expander over the given vectors. Entries may be null for positions that
+   * cannot be expanded; the array is used as-is, without copying.
+   *
+   * @param vectors vectors that may be expanded, addressed by position
+   */
+  public VectorExpander(BaseVariableWidthVector[] vectors) {
+    this.vectors = vectors;
+  }
+
+  /**
+   * Result of vector expansion.
+   */
+  public static class ExpandResult {
+    public long address;
+    public long capacity;
+
+    public ExpandResult(long address, long capacity) {
+      this.address = address;
+      this.capacity = capacity;
+    }
+  }
+
+  /**
+   * Expand vector at specified index. This is used as a back call from jni, and is only
+   * relevant for variable width vectors.
+   *
+   * @param index index of buffer in the list passed to jni.
+   * @param toCapacity the size to which the buffer should be expanded to.
+   *
+   * @return address and size of the buffer after expansion.
+   * @throws IllegalArgumentException if the index is out of range or no vector is registered
+   *     at that position
+   */
+  public ExpandResult expandOutputVectorAtIndex(int index, long toCapacity) {
+    // Reject negative indexes too, so that every invalid index is reported the same way
+    // instead of surfacing as an ArrayIndexOutOfBoundsException.
+    if (index < 0 || index >= vectors.length || vectors[index] == null) {
+      throw new IllegalArgumentException("invalid index " + index);
+    }
+
+    BaseVariableWidthVector vector = vectors[index];
+    // Keep reallocating until the data buffer is at least as large as requested.
+    while (vector.getDataBuffer().capacity() < toCapacity) {
+      vector.reallocDataBuffer();
+    }
+    return new ExpandResult(
+        vector.getDataBuffer().memoryAddress(),
+        vector.getDataBuffer().capacity());
+  }
+
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/EvaluatorClosedException.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/EvaluatorClosedException.java
new file mode 100644
index 000000000..d3fb8b60d
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/EvaluatorClosedException.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.exceptions;
+
+/**
+ * Signals that a method was invoked on an evaluator after it had been closed.
+ */
+public class EvaluatorClosedException extends GandivaException {
+  // Fixed message reported by every instance.
+  private static final String MESSAGE = "Cannot invoke methods on evaluator after closing it";
+
+  /** Creates the exception with its fixed message. */
+  public EvaluatorClosedException() {
+    super(MESSAGE);
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java
new file mode 100644
index 000000000..e7fce58a3
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.exceptions;
+
+/**
+ * Common parent of the checked exceptions raised by this package.
+ */
+public class GandivaException extends Exception {
+
+  /**
+   * Creates an exception without a cause.
+   *
+   * @param msg description of the failure
+   */
+  public GandivaException(String msg) {
+    super(msg);
+  }
+
+  /**
+   * Creates an exception that wraps another one.
+   *
+   * @param msg description of the failure
+   * @param cause the exception that triggered this one
+   */
+  public GandivaException(String msg, Exception cause) {
+    super(msg, cause);
+  }
+
+  /**
+   * Returns only the message, without the class name prefix that the default
+   * implementation would add.
+   */
+  @Override
+  public String toString() {
+    final String message = this.getMessage();
+    return message;
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/UnsupportedTypeException.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/UnsupportedTypeException.java
new file mode 100644
index 000000000..90e06e80e
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/UnsupportedTypeException.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.exceptions;
+
+/**
+ * Raised when a type is encountered that is not supported.
+ */
+public class UnsupportedTypeException extends GandivaException {
+
+  /**
+   * Creates the exception.
+   *
+   * @param msg description of the unsupported type
+   */
+  public UnsupportedTypeException(String msg) {
+    super(msg);
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/AndNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/AndNode.java
new file mode 100644
index 000000000..ecc577fa7
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/AndNode.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Expression tree node for a logical And over a list of child nodes.
+ */
+class AndNode implements TreeNode {
+  private final List<TreeNode> children;
+
+  /**
+   * Creates the node. The list is stored as given, without copying.
+   *
+   * @param children operands of the And expression
+   */
+  AndNode(List<TreeNode> children) {
+    this.children = children;
+  }
+
+  /**
+   * Serializes this node and, in order, all of its children.
+   *
+   * @return protobuf tree node wrapping the And node
+   * @throws GandivaException if a child cannot be converted
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.AndNode.Builder conjunction = GandivaTypes.AndNode.newBuilder();
+    for (TreeNode child : children) {
+      conjunction.addArgs(child.toProtobuf());
+    }
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setAndNode(conjunction.build())
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java
new file mode 100644
index 000000000..90f8684b4
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.flatbuf.DateUnit;
+import org.apache.arrow.flatbuf.IntervalUnit;
+import org.apache.arrow.flatbuf.TimeUnit;
+import org.apache.arrow.flatbuf.Type;
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.exceptions.UnsupportedTypeException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+/**
+ * Static helpers that translate Arrow types, fields and schemas into their Gandiva
+ * protobuf counterparts.
+ */
+public class ArrowTypeHelper {
+  private ArrowTypeHelper() {}
+
+  static final int WIDTH_8 = 8;
+  static final int WIDTH_16 = 16;
+  static final int WIDTH_32 = 32;
+  static final int WIDTH_64 = 64;
+
+  /**
+   * Sets the Gandiva integer type that matches the bit width and signedness of the arrow type.
+   *
+   * @throws UnsupportedTypeException if the bit width is not 8, 16, 32 or 64
+   */
+  private static void initArrowTypeInt(
+      ArrowType.Int intType, GandivaTypes.ExtGandivaType.Builder builder) throws GandivaException {
+    final int bitWidth = intType.getBitWidth();
+    final boolean signed = intType.getIsSigned();
+
+    final GandivaTypes.GandivaType gandivaType;
+    switch (bitWidth) {
+      case WIDTH_8:
+        gandivaType = signed ? GandivaTypes.GandivaType.INT8 : GandivaTypes.GandivaType.UINT8;
+        break;
+      case WIDTH_16:
+        gandivaType = signed ? GandivaTypes.GandivaType.INT16 : GandivaTypes.GandivaType.UINT16;
+        break;
+      case WIDTH_32:
+        gandivaType = signed ? GandivaTypes.GandivaType.INT32 : GandivaTypes.GandivaType.UINT32;
+        break;
+      case WIDTH_64:
+        gandivaType = signed ? GandivaTypes.GandivaType.INT64 : GandivaTypes.GandivaType.UINT64;
+        break;
+      default:
+        throw new UnsupportedTypeException("Unsupported width for integer type");
+    }
+    builder.setType(gandivaType);
+  }
+
+  /**
+   * Sets the Gandiva floating point type that matches the precision of the arrow type.
+   *
+   * @throws UnsupportedTypeException if the precision is not half, single or double
+   */
+  private static void initArrowTypeFloat(
+      ArrowType.FloatingPoint floatType, GandivaTypes.ExtGandivaType.Builder builder)
+      throws GandivaException {
+    final GandivaTypes.GandivaType gandivaType;
+    switch (floatType.getPrecision()) {
+      case HALF:
+        gandivaType = GandivaTypes.GandivaType.HALF_FLOAT;
+        break;
+      case SINGLE:
+        gandivaType = GandivaTypes.GandivaType.FLOAT;
+        break;
+      case DOUBLE:
+        gandivaType = GandivaTypes.GandivaType.DOUBLE;
+        break;
+      default:
+        throw new UnsupportedTypeException("Floating point type with unknown precision");
+    }
+    builder.setType(gandivaType);
+  }
+
+  /**
+   * Copies precision and scale into the builder and marks the type as decimal. The precision
+   * must be in the range 1 to 38; the check is delegated to {@code Preconditions}.
+   */
+  private static void initArrowTypeDecimal(ArrowType.Decimal decimalType,
+      GandivaTypes.ExtGandivaType.Builder builder) {
+    final int precision = decimalType.getPrecision();
+    Preconditions.checkArgument(precision > 0 && precision <= 38,
+        "Gandiva only supports decimals of upto 38 " +
+        "precision. Input precision : " + precision);
+    builder.setPrecision(precision);
+    builder.setScale(decimalType.getScale());
+    builder.setType(GandivaTypes.GandivaType.DECIMAL);
+  }
+
+  /**
+   * Maps day and millisecond date units to DATE32 and DATE64. For any other unit the builder
+   * is left without a type.
+   */
+  private static void initArrowTypeDate(ArrowType.Date dateType,
+      GandivaTypes.ExtGandivaType.Builder builder) {
+    final short unitId = dateType.getUnit().getFlatbufID();
+    if (unitId == DateUnit.DAY) {
+      builder.setType(GandivaTypes.GandivaType.DATE32);
+    } else if (unitId == DateUnit.MILLISECOND) {
+      builder.setType(GandivaTypes.GandivaType.DATE64);
+    }
+    // any other unit is not supported: the type stays unset
+  }
+
+  /**
+   * Translates an arrow flatbuf time unit id into the Gandiva time unit.
+   *
+   * @return the matching Gandiva unit, or null when the unit is not supported
+   */
+  private static GandivaTypes.TimeUnit toGandivaTimeUnit(short arrowTimeUnit) {
+    switch (arrowTimeUnit) {
+      case TimeUnit.SECOND:
+        return GandivaTypes.TimeUnit.SEC;
+      case TimeUnit.MILLISECOND:
+        return GandivaTypes.TimeUnit.MILLISEC;
+      case TimeUnit.MICROSECOND:
+        return GandivaTypes.TimeUnit.MICROSEC;
+      case TimeUnit.NANOSECOND:
+        return GandivaTypes.TimeUnit.NANOSEC;
+      default:
+        return null;
+    }
+  }
+
+  /**
+   * Maps second and millisecond time units to TIME32, and microsecond and nanosecond units
+   * to TIME64, together with the matching time unit. For any other unit the builder is left
+   * without a type.
+   */
+  private static void initArrowTypeTime(ArrowType.Time timeType,
+      GandivaTypes.ExtGandivaType.Builder builder) {
+    final GandivaTypes.TimeUnit unit = toGandivaTimeUnit(timeType.getUnit().getFlatbufID());
+    if (unit == null) {
+      // not supported
+      return;
+    }
+    final boolean fitsIn32Bits =
+        unit == GandivaTypes.TimeUnit.SEC || unit == GandivaTypes.TimeUnit.MILLISEC;
+    builder.setType(
+        fitsIn32Bits ? GandivaTypes.GandivaType.TIME32 : GandivaTypes.GandivaType.TIME64);
+    builder.setTimeUnit(unit);
+  }
+
+  /**
+   * Marks the type as TIMESTAMP and records its time unit. For an unsupported unit the builder
+   * is left without a type.
+   */
+  private static void initArrowTypeTimestamp(ArrowType.Timestamp timestampType,
+      GandivaTypes.ExtGandivaType.Builder builder) {
+    final GandivaTypes.TimeUnit unit =
+        toGandivaTimeUnit(timestampType.getUnit().getFlatbufID());
+    if (unit == null) {
+      // not supported
+      return;
+    }
+    builder.setType(GandivaTypes.GandivaType.TIMESTAMP);
+    builder.setTimeUnit(unit);
+  }
+
+  /**
+   * Marks the type as INTERVAL and records whether it is a year-month or a day-time interval.
+   * For any other unit the builder is left without a type.
+   */
+  private static void initArrowTypeInterval(ArrowType.Interval interval,
+      GandivaTypes.ExtGandivaType.Builder builder) {
+    final short unitId = interval.getUnit().getFlatbufID();
+    if (unitId == IntervalUnit.YEAR_MONTH) {
+      builder.setType(GandivaTypes.GandivaType.INTERVAL);
+      builder.setIntervalType(GandivaTypes.IntervalType.YEAR_MONTH);
+    } else if (unitId == IntervalUnit.DAY_TIME) {
+      builder.setType(GandivaTypes.GandivaType.INTERVAL);
+      builder.setIntervalType(GandivaTypes.IntervalType.DAY_TIME);
+    }
+    // any other unit is not supported: the type stays unset
+  }
+
+  /**
+   * Converts an arrow type into a protobuf.
+   *
+   * @param arrowType Arrow type to be converted
+   * @return Protobuf representing the arrow type
+   * @throws UnsupportedTypeException if no Gandiva type could be determined for the arrow type
+   */
+  public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType)
+      throws GandivaException {
+    final GandivaTypes.ExtGandivaType.Builder builder = GandivaTypes.ExtGandivaType.newBuilder();
+
+    final byte typeId = arrowType.getTypeID().getFlatbufID();
+    switch (typeId) {
+      case Type.NONE: // 0
+        builder.setType(GandivaTypes.GandivaType.NONE);
+        break;
+      case Type.Int: // 2
+        initArrowTypeInt((ArrowType.Int) arrowType, builder);
+        break;
+      case Type.FloatingPoint: // 3
+        initArrowTypeFloat((ArrowType.FloatingPoint) arrowType, builder);
+        break;
+      case Type.Binary: // 4
+        builder.setType(GandivaTypes.GandivaType.BINARY);
+        break;
+      case Type.Utf8: // 5
+        builder.setType(GandivaTypes.GandivaType.UTF8);
+        break;
+      case Type.Bool: // 6
+        builder.setType(GandivaTypes.GandivaType.BOOL);
+        break;
+      case Type.Decimal: // 7
+        initArrowTypeDecimal((ArrowType.Decimal) arrowType, builder);
+        break;
+      case Type.Date: // 8
+        initArrowTypeDate((ArrowType.Date) arrowType, builder);
+        break;
+      case Type.Time: // 9
+        initArrowTypeTime((ArrowType.Time) arrowType, builder);
+        break;
+      case Type.Timestamp: // 10
+        initArrowTypeTimestamp((ArrowType.Timestamp) arrowType, builder);
+        break;
+      case Type.Interval: // 11
+        initArrowTypeInterval((ArrowType.Interval) arrowType, builder);
+        break;
+      case Type.Null: // 1 - TODO: Need to handle this later
+      case Type.List: // 12
+      case Type.Struct_: // 13
+      case Type.Union: // 14
+      case Type.FixedSizeBinary: // 15
+      case Type.FixedSizeList: // 16
+      case Type.Map: // 17
+      default:
+        // No Gandiva mapping: leave the type unset so that the check below rejects it.
+        break;
+    }
+
+    // A builder without a type means none of the branches above recognized the input.
+    if (!builder.hasType()) {
+      throw new UnsupportedTypeException("Unsupported type " + arrowType.toString());
+    }
+
+    return builder.build();
+  }
+
+  /**
+   * Converts an arrow field object to a protobuf, including all of its child fields.
+   *
+   * @param field Arrow field to be converted
+   * @return Protobuf representing the arrow field
+   * @throws GandivaException if the type of the field or of one of its children is unsupported
+   */
+  public static GandivaTypes.Field arrowFieldToProtobuf(Field field) throws GandivaException {
+    final GandivaTypes.Field.Builder fieldBuilder = GandivaTypes.Field.newBuilder();
+    fieldBuilder.setName(field.getName());
+    fieldBuilder.setType(arrowTypeToProtobuf(field.getType()));
+    fieldBuilder.setNullable(field.isNullable());
+
+    // Children are converted recursively, in declaration order.
+    for (Field childField : field.getChildren()) {
+      fieldBuilder.addChildren(arrowFieldToProtobuf(childField));
+    }
+
+    return fieldBuilder.build();
+  }
+
+  /**
+   * Converts a schema object to a protobuf, one column per schema field.
+   *
+   * @param schema Schema object to be converted
+   * @return Protobuf representing a schema object
+   * @throws GandivaException if one of the fields cannot be converted
+   */
+  public static GandivaTypes.Schema arrowSchemaToProtobuf(Schema schema) throws GandivaException {
+    final GandivaTypes.Schema.Builder schemaBuilder = GandivaTypes.Schema.newBuilder();
+
+    for (Field column : schema.getFields()) {
+      schemaBuilder.addColumns(arrowFieldToProtobuf(column));
+    }
+
+    return schemaBuilder.build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BinaryNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BinaryNode.java
new file mode 100644
index 000000000..8455f29b2
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BinaryNode.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+import com.google.protobuf.ByteString;
+
+/**
+ * Expression tree node that holds a binary (byte array) constant.
+ */
+class BinaryNode implements TreeNode {
+  // The caller's array is kept by reference; its bytes are copied only when serialized.
+  private final byte[] value;
+
+  public BinaryNode(byte[] value) {
+    this.value = value;
+  }
+
+  /**
+   * Serializes this constant as a protobuf tree node wrapping a binary node.
+   *
+   * @return tree node whose binary node carries a copy of the stored bytes
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.BinaryNode.Builder binaryBuilder = GandivaTypes.BinaryNode.newBuilder();
+    binaryBuilder.setValue(ByteString.copyFrom(value));
+
+    GandivaTypes.TreeNode.Builder nodeBuilder = GandivaTypes.TreeNode.newBuilder();
+    nodeBuilder.setBinaryNode(binaryBuilder.build());
+    return nodeBuilder.build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BooleanNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BooleanNode.java
new file mode 100644
index 000000000..505f01919
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BooleanNode.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Expression tree node that holds a boolean constant.
+ * Used while creating expressions like if (!x).
+ */
+class BooleanNode implements TreeNode {
+  private final Boolean value;
+
+  BooleanNode(Boolean value) {
+    this.value = value;
+  }
+
+  /**
+   * Serializes this constant as a protobuf tree node wrapping a boolean node.
+   *
+   * @return tree node carrying the stored boolean
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.BooleanNode booleanProto = GandivaTypes.BooleanNode.newBuilder()
+        .setValue(value.booleanValue())
+        .build();
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setBooleanNode(booleanProto)
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/Condition.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/Condition.java
new file mode 100644
index 000000000..1d584d673
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/Condition.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Opaque class representing a filter condition, held as the root of an expression tree.
+ */
+public class Condition {
+  private final TreeNode root;
+
+  Condition(TreeNode root) {
+    this.root = root;
+  }
+
+  /**
+   * Converts a condition expression into a protobuf.
+   *
+   * @return A protobuf representing the condition expression tree
+   * @throws GandivaException if the root node fails to serialize
+   */
+  public GandivaTypes.Condition toProtobuf() throws GandivaException {
+    return GandivaTypes.Condition.newBuilder()
+        .setRoot(root.toProtobuf())
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java
new file mode 100644
index 000000000..bf17aa0aa
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Expression tree node that holds a decimal constant, given as a string together
+ * with its precision and scale.
+ * Used in the expression (x + 5.0)
+ */
+class DecimalNode implements TreeNode {
+  private final String value;
+  private final int precision;
+  private final int scale;
+
+  DecimalNode(String value, int precision, int scale) {
+    this.value = value;
+    this.precision = precision;
+    this.scale = scale;
+  }
+
+  /**
+   * Serializes this constant as a protobuf tree node wrapping a decimal node.
+   *
+   * @return tree node carrying the stored value, precision and scale
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.DecimalNode decimalProto = GandivaTypes.DecimalNode.newBuilder()
+        .setValue(value)
+        .setPrecision(precision)
+        .setScale(scale)
+        .build();
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setDecimalNode(decimalProto)
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DoubleNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DoubleNode.java
new file mode 100644
index 000000000..f7a9436f1
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DoubleNode.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Expression tree node that holds a double constant.
+ * Used in the expression (x + 5.0)
+ */
+class DoubleNode implements TreeNode {
+  private final Double value;
+
+  DoubleNode(Double value) {
+    this.value = value;
+  }
+
+  /**
+   * Serializes this constant as a protobuf tree node wrapping a double node.
+   *
+   * @return tree node carrying the stored double
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.DoubleNode doubleProto = GandivaTypes.DoubleNode.newBuilder()
+        .setValue(value.doubleValue())
+        .build();
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setDoubleNode(doubleProto)
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ExpressionTree.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ExpressionTree.java
new file mode 100644
index 000000000..353c8d12b
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ExpressionTree.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.vector.types.pojo.Field;
+
+/**
+ * Opaque class representing an expression: the root of an expression tree paired
+ * with the field that describes its result.
+ */
+public class ExpressionTree {
+  private final TreeNode root;
+  private final Field resultField;
+
+  ExpressionTree(TreeNode root, Field resultField) {
+    this.root = root;
+    this.resultField = resultField;
+  }
+
+  /**
+   * Converts an expression tree into a protobuf.
+   *
+   * @return A protobuf representing the expression tree
+   * @throws GandivaException if the root node or the result field fails to convert
+   */
+  public GandivaTypes.ExpressionRoot toProtobuf() throws GandivaException {
+    // The tree is serialized before the result field, so a failure in the tree surfaces first.
+    GandivaTypes.TreeNode rootProto = root.toProtobuf();
+    GandivaTypes.Field resultProto = ArrowTypeHelper.arrowFieldToProtobuf(resultField);
+
+    return GandivaTypes.ExpressionRoot.newBuilder()
+        .setRoot(rootProto)
+        .setResultType(resultProto)
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FieldNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FieldNode.java
new file mode 100644
index 000000000..893bf7191
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FieldNode.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.vector.types.pojo.Field;
+
+/**
+ * Expression tree node that refers to an Arrow field.
+ */
+class FieldNode implements TreeNode {
+  private final Field field;
+
+  FieldNode(Field field) {
+    this.field = field;
+  }
+
+  /**
+   * Serializes this node as a protobuf tree node wrapping a field node.
+   *
+   * @return tree node carrying the protobuf form of the referenced field
+   * @throws GandivaException if the field fails to convert
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.FieldNode fieldProto = GandivaTypes.FieldNode.newBuilder()
+        .setField(ArrowTypeHelper.arrowFieldToProtobuf(field))
+        .build();
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setFieldNode(fieldProto)
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FloatNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FloatNode.java
new file mode 100644
index 000000000..6afe96bfe
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FloatNode.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Expression tree node that holds a float constant.
+ * Used in the expression (x + 5.0)
+ */
+class FloatNode implements TreeNode {
+  private final Float value;
+
+  public FloatNode(Float value) {
+    this.value = value;
+  }
+
+  /**
+   * Serializes this constant as a protobuf tree node wrapping a float node.
+   *
+   * @return tree node carrying the stored float
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.FloatNode floatProto = GandivaTypes.FloatNode.newBuilder()
+        .setValue(value.floatValue())
+        .build();
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setFloatNode(floatProto)
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java
new file mode 100644
index 000000000..ead1e146d
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+
+/**
+ * Expression tree node for a call to a named function with a list of argument
+ * nodes and a declared return type.
+ */
+class FunctionNode implements TreeNode {
+  private final String function;
+  private final List<TreeNode> children;
+  private final ArrowType retType;
+
+  FunctionNode(String function, List<TreeNode> children, ArrowType retType) {
+    this.function = function;
+    this.children = children;
+    this.retType = retType;
+  }
+
+  /**
+   * Serializes this call as a protobuf tree node wrapping a function node.
+   *
+   * @return tree node carrying the function name, return type and serialized arguments
+   * @throws GandivaException if the return type or any argument fails to convert
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.FunctionNode.Builder call = GandivaTypes.FunctionNode.newBuilder()
+        .setFunctionName(function)
+        .setReturnType(ArrowTypeHelper.arrowTypeToProtobuf(retType));
+
+    // Arguments are serialized and appended in list order.
+    for (TreeNode argument : children) {
+      call.addInArgs(argument.toProtobuf());
+    }
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setFnNode(call.build())
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IfNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IfNode.java
new file mode 100644
index 000000000..19f9095fb
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IfNode.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+
+/**
+ * Expression tree node for an if-then-else block: a condition node, a node for
+ * each branch, and the declared return type.
+ */
+class IfNode implements TreeNode {
+  private final TreeNode condition;
+  private final TreeNode thenNode;
+  private final TreeNode elseNode;
+  private final ArrowType retType;
+
+  IfNode(TreeNode condition, TreeNode thenNode, TreeNode elseNode, ArrowType retType) {
+    this.condition = condition;
+    this.thenNode = thenNode;
+    this.elseNode = elseNode;
+    this.retType = retType;
+  }
+
+  /**
+   * Serializes this block as a protobuf tree node wrapping an if node.
+   *
+   * @return tree node carrying the serialized condition, both branches and the return type
+   * @throws GandivaException if any child node or the return type fails to convert
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    // Children are serialized in the order condition, then-branch, else-branch, return type.
+    GandivaTypes.IfNode ifProto = GandivaTypes.IfNode.newBuilder()
+        .setCond(condition.toProtobuf())
+        .setThenNode(thenNode.toProtobuf())
+        .setElseNode(elseNode.toProtobuf())
+        .setReturnType(ArrowTypeHelper.arrowTypeToProtobuf(retType))
+        .build();
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setIfNode(ifProto)
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
new file mode 100644
index 000000000..0f8de9628
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import java.math.BigDecimal;
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.util.Set;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+import com.google.protobuf.ByteString;
+
+/**
+ * In Node representation in java.
+ */
+public class InNode implements TreeNode {
+ private static final Charset charset = Charset.forName("UTF-8");
+
+ private final Set<Integer> intValues;
+ private final Set<Long> longValues;
+ private final Set<Float> floatValues;
+ private final Set<Double> doubleValues;
+ private final Set<BigDecimal> decimalValues;
+ private final Set<String> stringValues;
+ private final Set<byte[]> binaryValues;
+ private final TreeNode input;
+
+ private final Integer precision;
+ private final Integer scale;
+
+ private InNode(Set<Integer> values, Set<Long> longValues, Set<String> stringValues, Set<byte[]>
+ binaryValues, Set<BigDecimal> decimalValues, Integer precision, Integer scale,
+ Set<Float> floatValues, Set<Double> doubleValues, TreeNode node) {
+ this.intValues = values;
+ this.longValues = longValues;
+ this.decimalValues = decimalValues;
+ this.precision = precision;
+ this.scale = scale;
+ this.stringValues = stringValues;
+ this.binaryValues = binaryValues;
+ this.floatValues = floatValues;
+ this.doubleValues = doubleValues;
+ this.input = node;
+ }
+
+ /**
+ * Makes an IN node for int values.
+ *
+ * @param node Node with the 'IN' clause.
+ * @param intValues Int values to build the IN node.
+ * @retur InNode referring to tree node.
+ */
+ public static InNode makeIntInExpr(TreeNode node, Set<Integer> intValues) {
+ return new InNode(intValues,
+ null, null, null, null, null, null, null,
+ null, node);
+ }
+
+ /**
+ * Makes an IN node for long values.
+ *
+ * @param node Node with the 'IN' clause.
+ * @param longValues Long values to build the IN node.
+ * @retur InNode referring to tree node.
+ */
+ public static InNode makeLongInExpr(TreeNode node, Set<Long> longValues) {
+ return new InNode(null, longValues,
+ null, null, null, null, null, null,
+ null, node);
+ }
+
+ /**
+ * Makes an IN node for float values.
+ *
+ * @param node Node with the 'IN' clause.
+ * @param floatValues Float values to build the IN node.
+ * @retur InNode referring to tree node.
+ */
+ public static InNode makeFloatInExpr(TreeNode node, Set<Float> floatValues) {
+ return new InNode(null, null, null, null, null, null,
+ null, floatValues, null, node);
+ }
+
+ /**
+ * Makes an IN node for double values.
+ *
+ * @param node Node with the 'IN' clause.
+ * @param doubleValues Double values to build the IN node.
+ * @retur InNode referring to tree node.
+ */
+ public static InNode makeDoubleInExpr(TreeNode node, Set<Double> doubleValues) {
+ return new InNode(null, null, null, null, null,
+ null, null, null, doubleValues, node);
+ }
+
+ public static InNode makeDecimalInExpr(TreeNode node, Set<BigDecimal> decimalValues,
+ Integer precision, Integer scale) {
+ return new InNode(null, null, null, null,
+ decimalValues, precision, scale, null, null, node);
+ }
+
+ public static InNode makeStringInExpr(TreeNode node, Set<String> stringValues) {
+ return new InNode(null, null, stringValues, null,
+ null, null, null, null, null, node);
+ }
+
+ public static InNode makeBinaryInExpr(TreeNode node, Set<byte[]> binaryValues) {
+ return new InNode(null, null, null, binaryValues,
+ null, null, null, null, null, node);
+ }
+
+ @Override
+ public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+ GandivaTypes.InNode.Builder inNode = GandivaTypes.InNode.newBuilder();
+
+ inNode.setNode(input.toProtobuf());
+
+ if (intValues != null) {
+ GandivaTypes.IntConstants.Builder intConstants = GandivaTypes.IntConstants.newBuilder();
+ intValues.stream().forEach(val -> intConstants.addIntValues(GandivaTypes.IntNode.newBuilder()
+ .setValue(val).build()));
+ inNode.setIntValues(intConstants.build());
+ } else if (longValues != null) {
+ GandivaTypes.LongConstants.Builder longConstants = GandivaTypes.LongConstants.newBuilder();
+ longValues.stream().forEach(val -> longConstants.addLongValues(GandivaTypes.LongNode.newBuilder()
+ .setValue(val).build()));
+ inNode.setLongValues(longConstants.build());
+ } else if (floatValues != null) {
+ GandivaTypes.FloatConstants.Builder floatConstants = GandivaTypes.FloatConstants.newBuilder();
+ floatValues.stream().forEach(val -> floatConstants.addFloatValues(GandivaTypes.FloatNode.newBuilder()
+ .setValue(val).build()));
+ inNode.setFloatValues(floatConstants.build());
+ } else if (doubleValues != null) {
+ GandivaTypes.DoubleConstants.Builder doubleConstants = GandivaTypes.DoubleConstants.newBuilder();
+ doubleValues.stream().forEach(val -> doubleConstants.addDoubleValues(GandivaTypes.DoubleNode.newBuilder()
+ .setValue(val).build()));
+ inNode.setDoubleValues(doubleConstants.build());
+ } else if (decimalValues != null) {
+ GandivaTypes.DecimalConstants.Builder decimalConstants = GandivaTypes.DecimalConstants.newBuilder();
+ decimalValues.stream().forEach(val -> decimalConstants.addDecimalValues(GandivaTypes.DecimalNode.newBuilder()
+ .setValue(val.toPlainString()).setPrecision(precision).setScale(scale).build()));
+ inNode.setDecimalValues(decimalConstants.build());
+ } else if (stringValues != null) {
+ GandivaTypes.StringConstants.Builder stringConstants = GandivaTypes.StringConstants
+ .newBuilder();
+ stringValues.stream().forEach(val -> stringConstants.addStringValues(GandivaTypes.StringNode
+ .newBuilder().setValue(ByteString.copyFrom(val.getBytes(charset))).build()));
+ inNode.setStringValues(stringConstants.build());
+ } else if (binaryValues != null) {
+ GandivaTypes.BinaryConstants.Builder binaryConstants = GandivaTypes.BinaryConstants
+ .newBuilder();
+ binaryValues.stream().forEach(val -> binaryConstants.addBinaryValues(GandivaTypes.BinaryNode
+ .newBuilder().setValue(ByteString.copyFrom(val)).build()));
+ inNode.setBinaryValues(binaryConstants.build());
+ }
+ GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder();
+ builder.setInNode(inNode.build());
+ return builder.build();
+ }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IntNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IntNode.java
new file mode 100644
index 000000000..c3858ef7e
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IntNode.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Expression tree node that holds an int constant.
+ * Used in the expression (x + 5)
+ */
+class IntNode implements TreeNode {
+  private final Integer value;
+
+  IntNode(Integer value) {
+    this.value = value;
+  }
+
+  /**
+   * Serializes this constant as a protobuf tree node wrapping an int node.
+   *
+   * @return tree node carrying the stored int
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.IntNode intProto = GandivaTypes.IntNode.newBuilder()
+        .setValue(value.intValue())
+        .build();
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setIntNode(intProto)
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/LongNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/LongNode.java
new file mode 100644
index 000000000..311c5d94d
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/LongNode.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Expression tree node that holds a long constant.
+ * Used in the expression (x + 5L)
+ */
+class LongNode implements TreeNode {
+  private final Long value;
+
+  LongNode(Long value) {
+    this.value = value;
+  }
+
+  /**
+   * Serializes this constant as a protobuf tree node wrapping a long node.
+   *
+   * @return tree node carrying the stored long
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.LongNode longProto = GandivaTypes.LongNode.newBuilder()
+        .setValue(value.longValue())
+        .build();
+
+    return GandivaTypes.TreeNode.newBuilder()
+        .setLongNode(longProto)
+        .build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/NullNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/NullNode.java
new file mode 100644
index 000000000..a8e7d6f82
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/NullNode.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+
+/** An expression indicating a null value. */
+class NullNode implements TreeNode {
+  /** Arrow type attached to the null value. */
+  private final ArrowType type;
+
+  /**
+   * Creates a null node of the given type.
+   *
+   * @param type Arrow type of the null value
+   */
+  NullNode(ArrowType type) {
+    this.type = type;
+  }
+
+  /**
+   * Serializes this node as a protobuf tree node whose null-node member carries the
+   * protobuf form of the Arrow type.
+   *
+   * @return the protobuf form of this node
+   * @throws GandivaException declared by the {@code TreeNode} contract
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.NullNode payload = GandivaTypes.NullNode.newBuilder()
+        .setType(ArrowTypeHelper.arrowTypeToProtobuf(type))
+        .build();
+    return GandivaTypes.TreeNode.newBuilder().setNullNode(payload).build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/OrNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/OrNode.java
new file mode 100644
index 000000000..2dbdfed7c
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/OrNode.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Represents a logical OR Node.
+ */
+class OrNode implements TreeNode {
+  /** Operands of the OR, kept in the order supplied by the caller. */
+  private final List<TreeNode> children;
+
+  /**
+   * Creates an OR node over the given operands.
+   *
+   * @param children operand nodes; the list is stored as-is, not copied
+   */
+  OrNode(List<TreeNode> children) {
+    this.children = children;
+  }
+
+  /**
+   * Serializes every operand in list order and wraps them in a protobuf OR node.
+   *
+   * @return the protobuf form of this node
+   * @throws GandivaException if serializing any operand fails
+   */
+  @Override
+  public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+    GandivaTypes.OrNode.Builder disjunction = GandivaTypes.OrNode.newBuilder();
+    for (TreeNode operand : children) {
+      GandivaTypes.TreeNode serialized = operand.toProtobuf();
+      disjunction.addArgs(serialized);
+    }
+    return GandivaTypes.TreeNode.newBuilder().setOrNode(disjunction.build()).build();
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/StringNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/StringNode.java
new file mode 100644
index 000000000..a44329739
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/StringNode.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import java.nio.charset.Charset;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+import com.google.protobuf.ByteString;
+
+/**
+ * Used to represent expression tree nodes representing utf8 constants.
+ */
+class StringNode implements TreeNode {
+ private static final Charset charset = Charset.forName("UTF-8");
+ private final String value;
+
+ public StringNode(String value) {
+ this.value = value;
+ }
+
+ @Override
+ public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
+ GandivaTypes.StringNode stringNode = GandivaTypes.StringNode.newBuilder()
+ .setValue(ByteString.copyFrom(value.getBytes(charset)))
+ .build();
+
+ return GandivaTypes.TreeNode.newBuilder()
+ .setStringNode(stringNode)
+ .build();
+ }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
new file mode 100644
index 000000000..8656e886a
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+
+/**
+ * Contains helper functions for constructing expression trees.
+ */
+public class TreeBuilder {
+  // Static factory methods only; never instantiated.
+  private TreeBuilder() {}
+
+  /**
+   * Creates a node for a boolean literal.
+   *
+   * @param booleanConstant the literal value
+   * @return Node representing the literal
+   */
+  public static TreeNode makeLiteral(Boolean booleanConstant) {
+    return new BooleanNode(booleanConstant);
+  }
+
+  /**
+   * Creates a node for a float literal.
+   *
+   * @param floatConstant the literal value
+   * @return Node representing the literal
+   */
+  public static TreeNode makeLiteral(Float floatConstant) {
+    return new FloatNode(floatConstant);
+  }
+
+  /**
+   * Creates a node for a double literal.
+   *
+   * @param doubleConstant the literal value
+   * @return Node representing the literal
+   */
+  public static TreeNode makeLiteral(Double doubleConstant) {
+    return new DoubleNode(doubleConstant);
+  }
+
+  /**
+   * Creates a node for an int literal.
+   *
+   * @param integerConstant the literal value
+   * @return Node representing the literal
+   */
+  public static TreeNode makeLiteral(Integer integerConstant) {
+    return new IntNode(integerConstant);
+  }
+
+  /**
+   * Creates a node for a long literal.
+   *
+   * @param longConstant the literal value
+   * @return Node representing the literal
+   */
+  public static TreeNode makeLiteral(Long longConstant) {
+    return new LongNode(longConstant);
+  }
+
+  /**
+   * Creates a node for a string literal.
+   *
+   * @param stringConstant the literal value
+   * @return Node representing the literal
+   */
+  public static TreeNode makeStringLiteral(String stringConstant) {
+    return new StringNode(stringConstant);
+  }
+
+  /**
+   * Creates a node for a binary literal.
+   *
+   * @param binaryConstant the literal bytes
+   * @return Node representing the literal
+   */
+  public static TreeNode makeBinaryLiteral(byte[] binaryConstant) {
+    return new BinaryNode(binaryConstant);
+  }
+
+  /**
+   * Creates a node for a decimal literal.
+   *
+   * @param decimalConstant the literal value, given as a string
+   * @param precision precision passed on to the decimal node
+   * @param scale scale passed on to the decimal node
+   * @return Node representing the literal
+   */
+  public static TreeNode makeDecimalLiteral(String decimalConstant, int precision, int scale) {
+    return new DecimalNode(decimalConstant, precision, scale);
+  }
+
+  /**
+   * Creates a null literal of the given type.
+   *
+   * @param type Arrow type attached to the null value
+   * @return Node representing the null literal
+   */
+  public static TreeNode makeNull(ArrowType type) {
+    return new NullNode(type);
+  }
+
+  /**
+   * Invoke this function to create a node representing a field, e.g. a column name.
+   *
+   * @param field represents the input argument - includes the name and type of the field
+   * @return Node representing a field
+   */
+  public static TreeNode makeField(Field field) {
+    return new FieldNode(field);
+  }
+
+  /**
+   * Invoke this function to create a node representing a function.
+   *
+   * @param function Name of the function, e.g. add
+   * @param children The arguments to the function
+   * @param retType The type of the return value of the operator
+   * @return Node representing a function
+   */
+  public static TreeNode makeFunction(String function,
+                                      List<TreeNode> children,
+                                      ArrowType retType) {
+    return new FunctionNode(function, children, retType);
+  }
+
+  /**
+   * Invoke this function to create a node representing an if-clause.
+   *
+   * @param condition Node representing the condition
+   * @param thenNode Node representing the if-block
+   * @param elseNode Node representing the else-block
+   * @param retType Return type of the node
+   * @return Node representing an if-clause
+   */
+  public static TreeNode makeIf(TreeNode condition,
+                                TreeNode thenNode,
+                                TreeNode elseNode,
+                                ArrowType retType) {
+    return new IfNode(condition, thenNode, elseNode, retType);
+  }
+
+  /**
+   * Invoke this function to create a node representing an and-clause.
+   *
+   * @param nodes Nodes in the 'and' clause.
+   * @return Node representing an and-clause
+   */
+  public static TreeNode makeAnd(List<TreeNode> nodes) {
+    return new AndNode(nodes);
+  }
+
+  /**
+   * Invoke this function to create a node representing an or-clause.
+   *
+   * @param nodes Nodes in the 'or' clause.
+   * @return Node representing an or-clause
+   */
+  public static TreeNode makeOr(List<TreeNode> nodes) {
+    return new OrNode(nodes);
+  }
+
+  /**
+   * Invoke this function to create an expression tree.
+   *
+   * @param root is returned by a call to makeField, makeFunction, or makeIf
+   * @param resultField represents the return value of the expression
+   * @return ExpressionTree referring to the root of an expression tree
+   */
+  public static ExpressionTree makeExpression(TreeNode root,
+                                              Field resultField) {
+    return new ExpressionTree(root, resultField);
+  }
+
+  /**
+   * Short cut to create an expression tree consisting of a single function applied
+   * directly to input fields. The function node's return type is taken from
+   * {@code resultField}.
+   *
+   * @param function Name of the function, e.g. add
+   * @param inFields Input fields, passed to the function in list order
+   * @param resultField represents the return value of the expression
+   * @return ExpressionTree referring to the root of an expression tree
+   */
+  public static ExpressionTree makeExpression(String function,
+                                              List<Field> inFields,
+                                              Field resultField) {
+    TreeNode root = makeFunction(function, fieldNodesOf(inFields), resultField.getType());
+    return makeExpression(root, resultField);
+  }
+
+  /**
+   * Invoke this function to create a condition.
+   *
+   * @param root is returned by a call to makeField, makeFunction, makeIf, ..
+   * @return condition referring to the root of an expression tree
+   */
+  public static Condition makeCondition(TreeNode root) {
+    return new Condition(root);
+  }
+
+  /**
+   * Short cut to create a condition consisting of a single function applied directly
+   * to input fields. The function node's return type is always {@code ArrowType.Bool}.
+   *
+   * @param function Name of the function
+   * @param inFields Input fields, passed to the function in list order
+   * @return condition referring to the root of an expression tree
+   */
+  public static Condition makeCondition(String function,
+                                        List<Field> inFields) {
+    return makeCondition(makeFunction(function, fieldNodesOf(inFields), new ArrowType.Bool()));
+  }
+
+  /**
+   * Creates an in-expression node over int values by delegating to
+   * {@code InNode.makeIntInExpr}.
+   *
+   * @param resultNode node handed to the in-expression
+   * @param intValues values handed to the in-expression
+   * @return Node representing the in-expression
+   */
+  public static TreeNode makeInExpressionInt32(TreeNode resultNode,
+                                               Set<Integer> intValues) {
+    return InNode.makeIntInExpr(resultNode, intValues);
+  }
+
+  /**
+   * Creates an in-expression node over long values by delegating to
+   * {@code InNode.makeLongInExpr}.
+   *
+   * @param resultNode node handed to the in-expression
+   * @param longValues values handed to the in-expression
+   * @return Node representing the in-expression
+   */
+  public static TreeNode makeInExpressionBigInt(TreeNode resultNode,
+                                                Set<Long> longValues) {
+    return InNode.makeLongInExpr(resultNode, longValues);
+  }
+
+  /**
+   * Creates an in-expression node over decimal values by delegating to
+   * {@code InNode.makeDecimalInExpr}.
+   *
+   * @param resultNode node handed to the in-expression
+   * @param decimalValues values handed to the in-expression
+   * @param precision precision handed to the in-expression
+   * @param scale scale handed to the in-expression
+   * @return Node representing the in-expression
+   */
+  public static TreeNode makeInExpressionDecimal(TreeNode resultNode,
+      Set<BigDecimal> decimalValues, Integer precision, Integer scale) {
+    return InNode.makeDecimalInExpr(resultNode, decimalValues, precision, scale);
+  }
+
+  /**
+   * Creates an in-expression node over float values by delegating to
+   * {@code InNode.makeFloatInExpr}.
+   *
+   * @param resultNode node handed to the in-expression
+   * @param floatValues values handed to the in-expression
+   * @return Node representing the in-expression
+   */
+  public static TreeNode makeInExpressionFloat(TreeNode resultNode,
+                                               Set<Float> floatValues) {
+    return InNode.makeFloatInExpr(resultNode, floatValues);
+  }
+
+  /**
+   * Creates an in-expression node over double values by delegating to
+   * {@code InNode.makeDoubleInExpr}.
+   *
+   * @param resultNode node handed to the in-expression
+   * @param doubleValues values handed to the in-expression
+   * @return Node representing the in-expression
+   */
+  public static TreeNode makeInExpressionDouble(TreeNode resultNode,
+                                                Set<Double> doubleValues) {
+    return InNode.makeDoubleInExpr(resultNode, doubleValues);
+  }
+
+  /**
+   * Creates an in-expression node over string values by delegating to
+   * {@code InNode.makeStringInExpr}.
+   *
+   * @param resultNode node handed to the in-expression
+   * @param stringValues values handed to the in-expression
+   * @return Node representing the in-expression
+   */
+  public static TreeNode makeInExpressionString(TreeNode resultNode,
+                                                Set<String> stringValues) {
+    return InNode.makeStringInExpr(resultNode, stringValues);
+  }
+
+  /**
+   * Creates an in-expression node over binary values by delegating to
+   * {@code InNode.makeBinaryInExpr}.
+   *
+   * @param resultNode node handed to the in-expression
+   * @param binaryValues values handed to the in-expression
+   * @return Node representing the in-expression
+   */
+  public static TreeNode makeInExpressionBinary(TreeNode resultNode,
+                                                Set<byte[]> binaryValues) {
+    return InNode.makeBinaryInExpr(resultNode, binaryValues);
+  }
+
+  /** Wraps each field in a field node, keeping the order of the input list. */
+  private static List<TreeNode> fieldNodesOf(List<Field> fields) {
+    List<TreeNode> nodes = new ArrayList<TreeNode>(fields.size());
+    for (Field field : fields) {
+      nodes.add(makeField(field));
+    }
+    return nodes;
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeNode.java
new file mode 100644
index 000000000..b8d70d6e7
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeNode.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+
+/**
+ * Defines an internal node in the expression tree.
+ */
+public interface TreeNode {
+ /**
+ * Converts a TreeNode into a protobuf.
+ *
+ * <p>Implementations in this package build a {@code GandivaTypes.TreeNode} that wraps the
+ * node-specific message; composite nodes such as {@code OrNode} invoke this method on each
+ * of their children.
+ *
+ * @return A treenode protobuf
+ * @throws GandivaException in case the TreeNode cannot be processed
+ */
+ GandivaTypes.TreeNode toProtobuf() throws GandivaException;
+}
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
new file mode 100644
index 000000000..4a36c0405
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.math.BigDecimal;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.expression.Condition;
+import org.apache.arrow.gandiva.expression.ExpressionTree;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.After;
+import org.junit.Before;
+
+class BaseEvaluatorTest {
+
+ interface BaseEvaluator {
+
+ void evaluate(ArrowRecordBatch recordBatch, BufferAllocator allocator) throws GandivaException;
+
+ long getElapsedMillis();
+ }
+
+ class ProjectEvaluator implements BaseEvaluator {
+
+ private Projector projector;
+ private DataAndVectorGenerator generator;
+ private int numExprs;
+ private int maxRowsInBatch;
+ private long elapsedTime = 0;
+ private List<ValueVector> outputVectors = new ArrayList<>();
+
+ public ProjectEvaluator(Projector projector,
+ DataAndVectorGenerator generator,
+ int numExprs,
+ int maxRowsInBatch) {
+ this.projector = projector;
+ this.generator = generator;
+ this.numExprs = numExprs;
+ this.maxRowsInBatch = maxRowsInBatch;
+ }
+
+ @Override
+ public void evaluate(ArrowRecordBatch recordBatch,
+ BufferAllocator allocator) throws GandivaException {
+ // set up output vectors
+ // for each expression, generate the output vector
+ for (int i = 0; i < numExprs; i++) {
+ ValueVector valueVector = generator.generateOutputVector(maxRowsInBatch);
+ outputVectors.add(valueVector);
+ }
+
+ try {
+ long start = System.nanoTime();
+ projector.evaluate(recordBatch, outputVectors);
+ long finish = System.nanoTime();
+ elapsedTime += (finish - start);
+ } finally {
+ for (ValueVector valueVector : outputVectors) {
+ valueVector.close();
+ }
+ }
+ outputVectors.clear();
+ }
+
+ @Override
+ public long getElapsedMillis() {
+ return TimeUnit.NANOSECONDS.toMillis(elapsedTime);
+ }
+ }
+
+ class FilterEvaluator implements BaseEvaluator {
+
+ private Filter filter;
+ private long elapsedTime = 0;
+
+ public FilterEvaluator(Filter filter) {
+ this.filter = filter;
+ }
+
+ @Override
+ public void evaluate(ArrowRecordBatch recordBatch,
+ BufferAllocator allocator) throws GandivaException {
+ ArrowBuf selectionBuffer = allocator.buffer(recordBatch.getLength() * 2);
+ SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
+
+ try {
+ long start = System.nanoTime();
+ filter.evaluate(recordBatch, selectionVector);
+ long finish = System.nanoTime();
+ elapsedTime += (finish - start);
+ } finally {
+ selectionBuffer.close();
+ }
+ }
+
+ @Override
+ public long getElapsedMillis() {
+ return TimeUnit.NANOSECONDS.toMillis(elapsedTime);
+ }
+ }
+
+ interface DataAndVectorGenerator {
+
+ void writeData(ArrowBuf buffer);
+
+ ValueVector generateOutputVector(int numRowsInBatch);
+ }
+
+ class Int32DataAndVectorGenerator implements DataAndVectorGenerator {
+
+ protected final BufferAllocator allocator;
+ protected final Random rand;
+
+ Int32DataAndVectorGenerator(BufferAllocator allocator) {
+ this.allocator = allocator;
+ this.rand = new Random();
+ }
+
+ @Override
+ public void writeData(ArrowBuf buffer) {
+ buffer.writeInt(rand.nextInt());
+ }
+
+ @Override
+ public ValueVector generateOutputVector(int numRowsInBatch) {
+ IntVector intVector = new IntVector(BaseEvaluatorTest.EMPTY_SCHEMA_PATH, allocator);
+ intVector.allocateNew(numRowsInBatch);
+ return intVector;
+ }
+ }
+
+ class BoundedInt32DataAndVectorGenerator extends Int32DataAndVectorGenerator {
+
+ private final int upperBound;
+
+ BoundedInt32DataAndVectorGenerator(BufferAllocator allocator, int upperBound) {
+ super(allocator);
+ this.upperBound = upperBound;
+ }
+
+ @Override
+ public void writeData(ArrowBuf buffer) {
+ buffer.writeInt(rand.nextInt(upperBound));
+ }
+ }
+
+ protected static final int THOUSAND = 1000;
+ protected static final int MILLION = THOUSAND * THOUSAND;
+
+ protected static final String EMPTY_SCHEMA_PATH = "";
+
+ protected BufferAllocator allocator;
+ protected ArrowType boolType;
+ protected ArrowType int8;
+ protected ArrowType int32;
+ protected ArrowType int64;
+ protected ArrowType float64;
+
+ @Before
+ public void init() {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ boolType = new ArrowType.Bool();
+ int8 = new ArrowType.Int(8, true);
+ int32 = new ArrowType.Int(32, true);
+ int64 = new ArrowType.Int(64, true);
+ float64 = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
+ }
+
+ @After
+ public void tearDown() {
+ allocator.close();
+ }
+
+ ArrowBuf buf(int length) {
+ ArrowBuf buffer = allocator.buffer(length);
+ return buffer;
+ }
+
+ ArrowBuf buf(byte[] bytes) {
+ ArrowBuf buffer = allocator.buffer(bytes.length);
+ buffer.writeBytes(bytes);
+ return buffer;
+ }
+
+ ArrowBuf arrowBufWithAllValid(int size) {
+ int bufLen = (size + 7) / 8;
+ ArrowBuf buffer = allocator.buffer(bufLen);
+ for (int i = 0; i < bufLen; i++) {
+ buffer.writeByte(255);
+ }
+
+ return buffer;
+ }
+
+ ArrowBuf intBuf(int[] ints) {
+ ArrowBuf buffer = allocator.buffer(ints.length * 4);
+ for (int i = 0; i < ints.length; i++) {
+ buffer.writeInt(ints[i]);
+ }
+ return buffer;
+ }
+
+ DecimalVector decimalVector(String[] values, int precision, int scale) {
+ DecimalVector vector = new DecimalVector("decimal" + Math.random(), allocator, precision, scale);
+ vector.allocateNew();
+ for (int i = 0; i < values.length; i++) {
+ BigDecimal decimal = new BigDecimal(values[i]).setScale(scale);
+ vector.setSafe(i, decimal);
+ }
+
+ vector.setValueCount(values.length);
+ return vector;
+ }
+
+ Set decimalSet(String[] values, Integer scale) {
+ Set<BigDecimal> decimalSet = new HashSet<>();
+ for (int i = 0; i < values.length; i++) {
+ decimalSet.add(new BigDecimal(values[i]).setScale(scale));
+ }
+
+ return decimalSet;
+ }
+
+ VarCharVector varcharVector(String[] values) {
+ VarCharVector vector = new VarCharVector("VarCharVector" + Math.random(), allocator);
+ vector.allocateNew();
+ for (int i = 0; i < values.length; i++) {
+ vector.setSafe(i, values[i].getBytes(), 0, values[i].length());
+ }
+
+ vector.setValueCount(values.length);
+ return vector;
+ }
+
+ ArrowBuf longBuf(long[] longs) {
+ ArrowBuf buffer = allocator.buffer(longs.length * 8);
+ for (int i = 0; i < longs.length; i++) {
+ buffer.writeLong(longs[i]);
+ }
+ return buffer;
+ }
+
+ ArrowBuf doubleBuf(double[] data) {
+ ArrowBuf buffer = allocator.buffer(data.length * 8);
+ for (int i = 0; i < data.length; i++) {
+ buffer.writeDouble(data[i]);
+ }
+
+ return buffer;
+ }
+
+ ArrowBuf stringToMillis(String[] dates) {
+ ArrowBuf buffer = allocator.buffer(dates.length * 8);
+ for (int i = 0; i < dates.length; i++) {
+ Instant instant = Instant.parse(dates[i]);
+ buffer.writeLong(instant.toEpochMilli());
+ }
+
+ return buffer;
+ }
+
+ ArrowBuf stringToDayInterval(String[] values) {
+ ArrowBuf buffer = allocator.buffer(values.length * 8);
+ for (int i = 0; i < values.length; i++) {
+ buffer.writeInt(Integer.parseInt(values[i].split(" ")[0])); // days
+ buffer.writeInt(Integer.parseInt(values[i].split(" ")[1])); // millis
+ }
+ return buffer;
+ }
+
+ void releaseRecordBatch(ArrowRecordBatch recordBatch) {
+ // There are 2 references to the buffers
+ // One in the recordBatch - release that by calling close()
+ // One in the allocator - release that explicitly
+ List<ArrowBuf> buffers = recordBatch.getBuffers();
+ recordBatch.close();
+ for (ArrowBuf buf : buffers) {
+ buf.getReferenceManager().release();
+ }
+ }
+
+ void releaseValueVectors(List<ValueVector> valueVectors) {
+ for (ValueVector valueVector : valueVectors) {
+ valueVector.close();
+ }
+ }
+
+ void generateData(DataAndVectorGenerator generator, int numRecords, ArrowBuf buffer) {
+ for (int i = 0; i < numRecords; i++) {
+ generator.writeData(buffer);
+ }
+ }
+
+ private void generateDataAndEvaluate(DataAndVectorGenerator generator,
+ BaseEvaluator evaluator,
+ int numFields,
+ int numRows, int maxRowsInBatch,
+ int inputFieldSize)
+ throws GandivaException, Exception {
+ int numRemaining = numRows;
+ List<ArrowBuf> inputData = new ArrayList<ArrowBuf>();
+ List<ArrowFieldNode> fieldNodes = new ArrayList<ArrowFieldNode>();
+
+ // set the bitmap
+ while (numRemaining > 0) {
+ int numRowsInBatch = maxRowsInBatch;
+ if (numRowsInBatch > numRemaining) {
+ numRowsInBatch = numRemaining;
+ }
+
+ // generate data
+ for (int i = 0; i < numFields; i++) {
+ ArrowBuf buf = allocator.buffer(numRowsInBatch * inputFieldSize);
+ ArrowBuf validity = arrowBufWithAllValid(maxRowsInBatch);
+ generateData(generator, numRowsInBatch, buf);
+
+ fieldNodes.add(new ArrowFieldNode(numRowsInBatch, 0));
+ inputData.add(validity);
+ inputData.add(buf);
+ }
+
+ // create record batch
+ ArrowRecordBatch recordBatch = new ArrowRecordBatch(numRowsInBatch, fieldNodes, inputData);
+
+ evaluator.evaluate(recordBatch, allocator);
+
+ // fix numRemaining
+ numRemaining -= numRowsInBatch;
+
+ // release refs
+ releaseRecordBatch(recordBatch);
+
+ inputData.clear();
+ fieldNodes.clear();
+ }
+ }
+
+ long timedProject(DataAndVectorGenerator generator,
+ Schema schema, List<ExpressionTree> exprs,
+ int numRows, int maxRowsInBatch,
+ int inputFieldSize)
+ throws GandivaException, Exception {
+ Projector projector = Projector.make(schema, exprs);
+ try {
+ ProjectEvaluator evaluator =
+ new ProjectEvaluator(projector, generator, exprs.size(), maxRowsInBatch);
+ generateDataAndEvaluate(generator, evaluator,
+ schema.getFields().size(), numRows, maxRowsInBatch, inputFieldSize);
+ return evaluator.getElapsedMillis();
+ } finally {
+ projector.close();
+ }
+ }
+
+ long timedFilter(DataAndVectorGenerator generator,
+ Schema schema, Condition condition,
+ int numRows, int maxRowsInBatch,
+ int inputFieldSize)
+ throws GandivaException, Exception {
+
+ Filter filter = Filter.make(schema, condition);
+ try {
+ FilterEvaluator evaluator = new FilterEvaluator(filter);
+ generateDataAndEvaluate(generator, evaluator,
+ schema.getFields().size(), numRows, maxRowsInBatch, inputFieldSize);
+ return evaluator.getElapsedMillis();
+ } finally {
+ filter.close();
+ }
+ }
+}
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java
new file mode 100644
index 000000000..fe51c09e3
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class DecimalTypeUtilTest {
+
+  /** Checks the result types computed for decimal addition. */
+  @Test
+  public void testOutputTypesForAdd() {
+    assertResultType(DecimalTypeUtil.OperationType.ADD,
+        getDecimal(30, 10), getDecimal(30, 10), getDecimal(31, 10));
+
+    assertResultType(DecimalTypeUtil.OperationType.ADD,
+        getDecimal(30, 6), getDecimal(30, 5), getDecimal(32, 6));
+
+    assertResultType(DecimalTypeUtil.OperationType.ADD,
+        getDecimal(30, 10), getDecimal(38, 10), getDecimal(38, 9));
+
+    assertResultType(DecimalTypeUtil.OperationType.ADD,
+        getDecimal(38, 10), getDecimal(38, 38), getDecimal(38, 9));
+
+    assertResultType(DecimalTypeUtil.OperationType.ADD,
+        getDecimal(38, 10), getDecimal(38, 2), getDecimal(38, 6));
+  }
+
+  /** Checks the result types computed for decimal multiplication. */
+  @Test
+  public void testOutputTypesForMultiply() {
+    assertResultType(DecimalTypeUtil.OperationType.MULTIPLY,
+        getDecimal(30, 10), getDecimal(30, 10), getDecimal(38, 6));
+
+    assertResultType(DecimalTypeUtil.OperationType.MULTIPLY,
+        getDecimal(38, 10), getDecimal(9, 2), getDecimal(38, 6));
+  }
+
+  /** Checks the result type computed for decimal modulo. */
+  @Test
+  public void testOutputTypesForMod() {
+    assertResultType(DecimalTypeUtil.OperationType.MOD,
+        getDecimal(30, 10), getDecimal(28, 7), getDecimal(30, 10));
+  }
+
+  /**
+   * Asserts that applying {@code operation} to the two operand types yields {@code expected}.
+   * Uses assertEquals so a failure reports both the expected and the actual type.
+   */
+  private static void assertResultType(DecimalTypeUtil.OperationType operation,
+                                       ArrowType.Decimal operand1,
+                                       ArrowType.Decimal operand2,
+                                       ArrowType.Decimal expected) {
+    ArrowType.Decimal resultType =
+        DecimalTypeUtil.getResultTypeForOperation(operation, operand1, operand2);
+    Assert.assertEquals(expected, resultType);
+  }
+
+  /** Creates a 128-bit decimal type with the given precision and scale. */
+  private static ArrowType.Decimal getDecimal(int precision, int scale) {
+    return new ArrowType.Decimal(precision, scale, 128);
+  }
+
+}
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryTest.java
new file mode 100644
index 000000000..a51ac09ba
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryTest.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.util.Set;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Checks that the expression registry reports the types and function signatures
+ * these tests rely on.
+ */
+public class ExpressionRegistryTest {
+
+  /** The unsigned 8-bit integer type must be listed among the supported types. */
+  @Test
+  public void testTypes() throws GandivaException {
+    final ArrowType.Int unsignedByte = new ArrowType.Int(8, false);
+    final Set<ArrowType> supportedTypes = ExpressionRegistry.getInstance().getSupportedTypes();
+    Assert.assertTrue(supportedTypes.contains(unsignedByte));
+  }
+
+  /** {@code add(uint8, uint8) -> uint8} must be a supported function. */
+  @Test
+  public void testFunctions() throws GandivaException {
+    final ArrowType.Int unsignedByte = new ArrowType.Int(8, false);
+    assertSupported(
+        new FunctionSignature(
+            "add", unsignedByte, Lists.newArrayList(unsignedByte, unsignedByte)));
+  }
+
+  /** {@code modulo(int64, int64) -> int64} must be a supported function. */
+  @Test
+  public void testFunctionAliases() throws GandivaException {
+    final ArrowType.Int signedLong = new ArrowType.Int(64, true);
+    assertSupported(
+        new FunctionSignature("modulo", signedLong, Lists.newArrayList(signedLong, signedLong)));
+  }
+
+  /** The all-lower-case name {@code castvarchar(utf8, int64) -> utf8} must be found. */
+  @Test
+  public void testCaseInsensitiveFunctionName() throws GandivaException {
+    final ArrowType.Utf8 text = new ArrowType.Utf8();
+    final ArrowType.Int signedLong = new ArrowType.Int(64, true);
+    assertSupported(
+        new FunctionSignature("castvarchar", text, Lists.newArrayList(text, signedLong)));
+  }
+
+  /** Asserts that the registry's supported-function set contains {@code signature}. */
+  private void assertSupported(FunctionSignature signature) throws GandivaException {
+    final Set<FunctionSignature> supported =
+        ExpressionRegistry.getInstance().getSupportedFunctions();
+    Assert.assertTrue(supported.contains(signature));
+  }
+}
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterProjectTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterProjectTest.java
new file mode 100644
index 000000000..51fc1c291
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterProjectTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.expression.Condition;
+import org.apache.arrow.gandiva.expression.ExpressionTree;
+import org.apache.arrow.gandiva.expression.TreeBuilder;
+import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Exercises a filter and a projector together: the filter fills a 16-bit selection
+ * vector, and the projector then evaluates only the selected rows.
+ */
+public class FilterProjectTest extends BaseEvaluatorTest {
+
+  /**
+   * Filters with {@code a < b} and projects {@code a + b} over the selected rows,
+   * using a projector built for {@code SV_INT16} selection vectors.
+   */
+  @Test
+  public void testSimpleSV16() throws GandivaException, Exception {
+    Field a = Field.nullable("a", int32);
+    Field b = Field.nullable("b", int32);
+    Field c = Field.nullable("c", int32);
+    List<Field> args = Lists.newArrayList(a, b);
+
+    Condition condition = TreeBuilder.makeCondition("less_than", args);
+
+    Schema schema = new Schema(args);
+    Filter filter = Filter.make(schema, condition);
+
+    ExpressionTree expression = TreeBuilder.makeExpression("add", Lists.newArrayList(a, b), c);
+    Projector projector =
+        Projector.make(schema, Lists.newArrayList(expression), SelectionVectorType.SV_INT16);
+
+    int numRows = 16;
+    // Validity bits: first byte all set, second byte all clear.
+    byte[] validity = new byte[]{(byte) 255, 0};
+    // second half is "undefined"
+    int[] aValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+    int[] bValues = new int[]{2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
+    int[] expected = {3, 7, 11, 15};
+
+    verifyTestCaseFor16(filter, projector, numRows, validity, aValues, bValues, expected);
+  }
+
+  /**
+   * Runs {@code filter} and then {@code projector} over a two-column int batch and checks
+   * the projected values. Closes the filter, the projector and every buffer it allocates.
+   *
+   * @param filter filter used to populate the selection vector
+   * @param projector projector evaluated over the selected rows
+   * @param numRows number of rows in the input batch
+   * @param validity validity bitmap shared by both input columns
+   * @param aValues values of the first column
+   * @param bValues values of the second column
+   * @param expected projected value expected for each selected row, in order
+   */
+  private void verifyTestCaseFor16(Filter filter, Projector projector, int numRows, byte[] validity,
+      int[] aValues, int[] bValues, int[] expected) throws GandivaException {
+    ArrowBuf validitya = buf(validity);
+    ArrowBuf valuesa = intBuf(aValues);
+    ArrowBuf validityb = buf(validity);
+    ArrowBuf valuesb = intBuf(bValues);
+    ArrowRecordBatch batch = new ArrowRecordBatch(
+        numRows,
+        Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
+        Lists.newArrayList(validitya, valuesa, validityb, valuesb));
+
+    // Two bytes per entry for a 16-bit selection vector.
+    ArrowBuf selectionBuffer = buf(numRows * 2);
+    SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
+
+    List<ValueVector> output = new ArrayList<ValueVector>();
+    try {
+      filter.evaluate(batch, selectionVector);
+
+      // Without this check the loop below would pass vacuously if the filter selected nothing.
+      assertEquals(expected.length, selectionVector.getRecordCount());
+
+      IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+      // Registered before allocation so the finally block always releases it.
+      output.add(intVector);
+      intVector.allocateNew(selectionVector.getRecordCount());
+
+      projector.evaluate(batch, selectionVector, output);
+      for (int i = 0; i < selectionVector.getRecordCount(); i++) {
+        assertFalse(intVector.isNull(i));
+        assertEquals(expected[i], intVector.get(i));
+      }
+    } finally {
+      // free buffers, even when an assertion above fails
+      releaseRecordBatch(batch);
+      releaseValueVectors(output);
+      selectionBuffer.close();
+      filter.close();
+      projector.close();
+    }
+  }
+}
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java
new file mode 100644
index 000000000..ed6e43cd6
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.nio.charset.Charset;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.IntStream;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.expression.Condition;
+import org.apache.arrow.gandiva.expression.TreeBuilder;
+import org.apache.arrow.gandiva.expression.TreeNode;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+/**
+ * Tests for {@link Filter}: IN-expressions over strings and ints, and int32 comparisons
+ * evaluated into 16-bit and 32-bit selection vectors.
+ */
+public class FilterTest extends BaseEvaluatorTest {
+
+  private Charset utf8Charset = Charset.forName("UTF-8");
+  private Charset utf16Charset = Charset.forName("UTF-16");
+
+  /** Copies the indices held by {@code vector} into a plain array. */
+  private int[] selectionVectorToArray(SelectionVector vector) {
+    int[] actual = new int[vector.getRecordCount()];
+    for (int i = 0; i < vector.getRecordCount(); ++i) {
+      actual[i] = vector.getIndex(i);
+    }
+    return actual;
+  }
+
+  /**
+   * Encodes {@code strings} into an offsets buffer and a data buffer.
+   *
+   * @param strings values to encode, in order
+   * @param charset charset used to turn each string into bytes
+   * @return a two-element list: the offsets buffer (one int per string plus a final end
+   *     offset) followed by the data buffer holding the concatenated bytes
+   */
+  List<ArrowBuf> varBufs(String[] strings, Charset charset) {
+    ArrowBuf offsetsBuffer = allocator.buffer((strings.length + 1) * 4);
+    ArrowBuf dataBuffer = allocator.buffer(strings.length * 8);
+
+    int startOffset = 0;
+    for (int i = 0; i < strings.length; i++) {
+      offsetsBuffer.writeInt(startOffset);
+
+      final byte[] bytes = strings[i].getBytes(charset);
+      // setBytes() does not advance the writer index, so the required capacity has to be
+      // derived from the running offset; sizing from writerIndex() would never grow the
+      // buffer beyond the length of a single string.
+      dataBuffer = dataBuffer.reallocIfNeeded(startOffset + bytes.length);
+      dataBuffer.setBytes(startOffset, bytes, 0, bytes.length);
+      startOffset += bytes.length;
+    }
+    offsetsBuffer.writeInt(startOffset); // offset for the last element
+
+    return Arrays.asList(offsetsBuffer, dataBuffer);
+  }
+
+  /** Encodes {@code strings} as UTF-8 via {@link #varBufs(String[], Charset)}. */
+  List<ArrowBuf> stringBufs(String[] strings) {
+    return varBufs(strings, utf8Charset);
+  }
+
+  /** Builds the nullable int32 fields "a" and "b" shared by the comparison tests. */
+  private List<Field> int32PairFields() {
+    return Lists.newArrayList(Field.nullable("a", int32), Field.nullable("b", int32));
+  }
+
+  /** {@code substr(c1, 1, 3) IN ("one", "two", "thr", "fou")} over 16 rows, first 8 valid. */
+  @Test
+  public void testSimpleInString() throws GandivaException, Exception {
+    Field c1 = Field.nullable("c1", new ArrowType.Utf8());
+    TreeNode l1 = TreeBuilder.makeLiteral(1L);
+    TreeNode l2 = TreeBuilder.makeLiteral(3L);
+
+    List<Field> argsSchema = Lists.newArrayList(c1);
+    List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2);
+    TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8());
+    TreeNode inExpr =
+        TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou"));
+
+    Condition condition = TreeBuilder.makeCondition(inExpr);
+
+    Schema schema = new Schema(argsSchema);
+    Filter filter = Filter.make(schema, condition);
+
+    int numRows = 16;
+    byte[] validity = new byte[] {(byte) 255, 0};
+    // second half is "undefined"
+    String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven",
+        "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
+        "sixteen"};
+    int[] expected = {0, 1, 2, 3};
+    ArrowBuf c1Validity = buf(validity);
+    ArrowBuf c2Validity = buf(validity);
+    List<ArrowBuf> dataBufsX = stringBufs(c1Values);
+
+    ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+    ArrowRecordBatch batch =
+        new ArrowRecordBatch(
+            numRows,
+            Lists.newArrayList(fieldNode),
+            Lists.newArrayList(c1Validity, dataBufsX.get(0), dataBufsX.get(1), c2Validity));
+
+    ArrowBuf selectionBuffer = buf(numRows * 2);
+    evaluateAndVerify(
+        filter, batch, selectionBuffer, new SelectionVectorInt16(selectionBuffer), expected);
+  }
+
+  /** {@code c1 IN (1, 2, 3, 4)} over 16 int rows, first 8 valid. */
+  @Test
+  public void testSimpleInInt() throws GandivaException, Exception {
+    Field c1 = Field.nullable("c1", int32);
+
+    List<Field> argsSchema = Lists.newArrayList(c1);
+    TreeNode inExpr =
+        TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4));
+
+    Condition condition = TreeBuilder.makeCondition(inExpr);
+
+    Schema schema = new Schema(argsSchema);
+    Filter filter = Filter.make(schema, condition);
+
+    int numRows = 16;
+    byte[] validity = new byte[] {(byte) 255, 0};
+    // second half is "undefined"
+    int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+    int[] expected = {0, 1, 2, 3};
+
+    ArrowBuf validitya = buf(validity);
+    ArrowBuf validityb = buf(validity);
+    ArrowBuf valuesa = intBuf(aValues);
+
+    ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+    ArrowRecordBatch batch =
+        new ArrowRecordBatch(
+            numRows,
+            Lists.newArrayList(fieldNode),
+            Lists.newArrayList(validitya, valuesa, validityb));
+
+    ArrowBuf selectionBuffer = buf(numRows * 2);
+    evaluateAndVerify(
+        filter, batch, selectionBuffer, new SelectionVectorInt16(selectionBuffer), expected);
+  }
+
+  /** {@code a < b} over 16 rows (first 8 valid) with a 16-bit selection vector. */
+  @Test
+  public void testSimpleSV16() throws GandivaException, Exception {
+    List<Field> args = int32PairFields();
+    Filter filter = Filter.make(new Schema(args), TreeBuilder.makeCondition("less_than", args));
+
+    int numRows = 16;
+    byte[] validity = new byte[] {(byte) 255, 0};
+    // second half is "undefined"
+    int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+    int[] bValues = new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
+    int[] expected = {0, 2, 4, 6};
+
+    verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
+  }
+
+  /** {@code a < b} where every one of the 32 rows is valid and matches. */
+  @Test
+  public void testSimpleSV16_AllMatched() throws GandivaException, Exception {
+    List<Field> args = int32PairFields();
+    Filter filter = Filter.make(new Schema(args), TreeBuilder.makeCondition("less_than", args));
+
+    int numRows = 32;
+
+    byte[] validity = new byte[numRows / 8];
+    Arrays.fill(validity, (byte) 255);
+
+    int[] aValues = IntStream.range(0, numRows).toArray();       // a[i] = i
+    int[] bValues = IntStream.rangeClosed(1, numRows).toArray(); // b[i] = i + 1
+    int[] expected = IntStream.range(0, numRows).toArray();      // every row selected
+
+    verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
+  }
+
+  /** {@code a > b} over 1000 valid rows where only row 0 matches. */
+  @Test
+  public void testSimpleSV16_GreaterThan64Recs() throws GandivaException, Exception {
+    List<Field> args = int32PairFields();
+    Filter filter =
+        Filter.make(new Schema(args), TreeBuilder.makeCondition("greater_than", args));
+
+    int numRows = 1000;
+
+    byte[] validity = new byte[numRows / 8];
+    Arrays.fill(validity, (byte) 255);
+
+    int[] aValues = IntStream.range(0, numRows).toArray();       // a[i] = i
+    int[] bValues = IntStream.rangeClosed(1, numRows).toArray(); // b[i] = i + 1
+
+    // Make row 0 the single row for which a > b holds.
+    aValues[0] = 5;
+    bValues[0] = 0;
+
+    int[] expected = {0};
+
+    verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
+  }
+
+  /** {@code a < b} over 16 rows (first 8 valid) with a 32-bit selection vector. */
+  @Test
+  public void testSimpleSV32() throws GandivaException, Exception {
+    List<Field> args = int32PairFields();
+    Filter filter = Filter.make(new Schema(args), TreeBuilder.makeCondition("less_than", args));
+
+    int numRows = 16;
+    byte[] validity = new byte[] {(byte) 255, 0};
+    // second half is "undefined"
+    int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+    int[] bValues = new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
+    int[] expected = {0, 2, 4, 6};
+
+    // Four bytes per entry: this test must exercise the 32-bit selection vector, not the
+    // 16-bit one used by the default verifyTestCase overload.
+    ArrowBuf selectionBuffer = buf(numRows * 4);
+    verifyTestCase(filter, numRows, validity, aValues, bValues, expected,
+        selectionBuffer, new SelectionVectorInt32(selectionBuffer));
+  }
+
+  /** Same scenario as {@link #testSimpleSV16()} with the filter built with {@code false}. */
+  @Test
+  public void testSimpleFilterWithNoOptimisation() throws GandivaException, Exception {
+    List<Field> args = int32PairFields();
+    Filter filter =
+        Filter.make(new Schema(args), TreeBuilder.makeCondition("less_than", args), false);
+
+    int numRows = 16;
+    byte[] validity = new byte[] {(byte) 255, 0};
+    // second half is "undefined"
+    int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+    int[] bValues = new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15};
+    int[] expected = {0, 2, 4, 6};
+
+    verifyTestCase(filter, numRows, validity, aValues, bValues, expected);
+  }
+
+  /**
+   * Evaluates {@code filter} over a two-column int batch using a 16-bit selection vector
+   * and checks the selected row indices.
+   */
+  private void verifyTestCase(
+      Filter filter, int numRows, byte[] validity, int[] aValues, int[] bValues, int[] expected)
+      throws GandivaException {
+    // Two bytes per entry for a 16-bit selection vector.
+    ArrowBuf selectionBuffer = buf(numRows * 2);
+    verifyTestCase(filter, numRows, validity, aValues, bValues, expected,
+        selectionBuffer, new SelectionVectorInt16(selectionBuffer));
+  }
+
+  /**
+   * Evaluates {@code filter} over a two-column int batch into the supplied selection vector
+   * and checks the selected row indices.
+   *
+   * @param selectionBuffer buffer backing {@code selectionVector}; closed by this method
+   * @param selectionVector selection vector the filter writes into
+   */
+  private void verifyTestCase(
+      Filter filter, int numRows, byte[] validity, int[] aValues, int[] bValues, int[] expected,
+      ArrowBuf selectionBuffer, SelectionVector selectionVector)
+      throws GandivaException {
+    ArrowBuf validitya = buf(validity);
+    ArrowBuf valuesa = intBuf(aValues);
+    ArrowBuf validityb = buf(validity);
+    ArrowBuf valuesb = intBuf(bValues);
+    ArrowRecordBatch batch =
+        new ArrowRecordBatch(
+            numRows,
+            Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
+            Lists.newArrayList(validitya, valuesa, validityb, valuesb));
+
+    evaluateAndVerify(filter, batch, selectionBuffer, selectionVector, expected);
+  }
+
+  /**
+   * Runs {@code filter} on {@code batch}, releases the batch, the selection buffer and the
+   * filter, and then compares the selected indices with {@code expected}.
+   */
+  private void evaluateAndVerify(
+      Filter filter, ArrowRecordBatch batch, ArrowBuf selectionBuffer,
+      SelectionVector selectionVector, int[] expected)
+      throws GandivaException {
+    filter.evaluate(batch, selectionVector);
+
+    // Copy the indices out first: the selection buffer is closed just below.
+    int[] actual = selectionVectorToArray(selectionVector);
+
+    // free buffers
+    releaseRecordBatch(batch);
+    selectionBuffer.close();
+    filter.close();
+
+    Assert.assertArrayEquals(expected, actual);
+  }
+}
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java
new file mode 100644
index 000000000..6934c3f9e
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.util.List;
+
+import org.apache.arrow.gandiva.expression.Condition;
+import org.apache.arrow.gandiva.expression.ExpressionTree;
+import org.apache.arrow.gandiva.expression.TreeBuilder;
+import org.apache.arrow.gandiva.expression.TreeNode;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Timing checks for simple projections and filters. The whole class is marked
+ * {@code @Ignore}. Each test asserts that the measured time stays within a per-test
+ * baseline multiplied by {@link #toleranceRatio}.
+ */
+@Ignore
+public class MicroBenchmarkTest extends BaseEvaluatorTest {
+
+  // Multiplier applied to each test's baseline time before asserting.
+  private double toleranceRatio = 4.0;
+
+  /** Times the projection {@code x + n2x + n3x}. */
+  @Test
+  public void testAdd3() throws Exception {
+    Field x = Field.nullable("x", int32);
+    Field n2x = Field.nullable("n2x", int32);
+    Field n3x = Field.nullable("n3x", int32);
+
+    // x + n2x + n3x
+    TreeNode add1 =
+        TreeBuilder.makeFunction(
+            "add", Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeField(n2x)), int32);
+    TreeNode add =
+        TreeBuilder.makeFunction(
+            "add", Lists.newArrayList(add1, TreeBuilder.makeField(n3x)), int32);
+    ExpressionTree expr = TreeBuilder.makeExpression(add, x);
+
+    List<Field> cols = Lists.newArrayList(x, n2x, n3x);
+    Schema schema = new Schema(cols);
+
+    long timeTaken = timedProject(new Int32DataAndVectorGenerator(allocator),
+        schema,
+        Lists.newArrayList(expr),
+        1 * MILLION, 16 * THOUSAND,
+        4);
+    System.out.println("Time taken for projecting 1m records of add3 is " + timeTaken + "ms");
+    Assert.assertTrue(timeTaken <= 13 * toleranceRatio);
+  }
+
+  /** Times the projection of a 20-level nested if/else chain over {@code x}. */
+  @Test
+  public void testIf() throws Exception {
+    /*
+     * when x < 10 then 0
+     * when x < 20 then 1
+     * when x < 30 then 2
+     * when x < 40 then 3
+     * when x < 50 then 4
+     * when x < 60 then 5
+     * when x < 70 then 6
+     * when x < 80 then 7
+     * when x < 90 then 8
+     * when x < 100 then 9
+     * when x < 110 then 10
+     * when x < 120 then 11
+     * when x < 130 then 12
+     * when x < 140 then 13
+     * when x < 150 then 14
+     * when x < 160 then 15
+     * when x < 170 then 16
+     * when x < 180 then 17
+     * when x < 190 then 18
+     * when x < 200 then 19
+     * else 20
+     */
+    Field x = Field.nullable("x", int32);
+    TreeNode xNode = TreeBuilder.makeField(x);
+
+    // Build the chain from the innermost else outwards: start with "else 20", then wrap it
+    // in "if (x < 200) then 19", "if (x < 190) then 18", ... down to "if (x < 10) then 0".
+    int returnValue = 20;
+    TreeNode topNode = TreeBuilder.makeLiteral(returnValue);
+    int compareWith = 200;
+    while (compareWith >= 10) {
+      // Decrement before use so that the branch for "x < 200" yields 19 (not 20) and the
+      // branch for "x < 10" yields 0, matching the table above.
+      returnValue--;
+      // cond (x < compareWith)
+      TreeNode condNode =
+          TreeBuilder.makeFunction(
+              "less_than",
+              Lists.newArrayList(xNode, TreeBuilder.makeLiteral(compareWith)),
+              boolType);
+      topNode =
+          TreeBuilder.makeIf(
+              condNode, // cond (x < compareWith)
+              TreeBuilder.makeLiteral(returnValue), // then returnValue
+              topNode, // else topNode
+              int32);
+      compareWith -= 10;
+    }
+
+    ExpressionTree expr = TreeBuilder.makeExpression(topNode, x);
+    Schema schema = new Schema(Lists.newArrayList(x));
+
+    long timeTaken = timedProject(new BoundedInt32DataAndVectorGenerator(allocator, 250),
+        schema,
+        Lists.newArrayList(expr),
+        1 * MILLION, 16 * THOUSAND,
+        4);
+    System.out.println("Time taken for projecting 1m records of nestedIf is " + timeTaken + "ms");
+    Assert.assertTrue(timeTaken <= 15 * toleranceRatio);
+  }
+
+  /** Times the filter {@code x + n2x < n3x}. */
+  @Test
+  public void testFilterAdd2() throws Exception {
+    Field x = Field.nullable("x", int32);
+    Field n2x = Field.nullable("n2x", int32);
+    Field n3x = Field.nullable("n3x", int32);
+
+    // x + n2x < n3x
+    TreeNode add = TreeBuilder.makeFunction("add",
+        Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeField(n2x)), int32);
+    TreeNode lessThan = TreeBuilder
+        .makeFunction("less_than", Lists.newArrayList(add, TreeBuilder.makeField(n3x)), boolType);
+    Condition condition = TreeBuilder.makeCondition(lessThan);
+
+    List<Field> cols = Lists.newArrayList(x, n2x, n3x);
+    Schema schema = new Schema(cols);
+
+    long timeTaken = timedFilter(new Int32DataAndVectorGenerator(allocator),
+        schema,
+        condition,
+        1 * MILLION, 16 * THOUSAND,
+        4);
+    System.out.println("Time taken for filtering 1m records of a+b<c is " + timeTaken + "ms");
+    Assert.assertTrue(timeTaken <= 12 * toleranceRatio);
+  }
+}
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
new file mode 100644
index 000000000..28a57c9f8
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
@@ -0,0 +1,797 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.expression.ExpressionTree;
+import org.apache.arrow.gandiva.expression.TreeBuilder;
+import org.apache.arrow.gandiva.expression.TreeNode;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.ArrowType.Decimal;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import com.google.common.collect.Lists;
+
+public class ProjectorDecimalTest extends org.apache.arrow.gandiva.evaluator.BaseEvaluatorTest {
+ // JUnit rule for declaring exceptions a test expects; created with none() so that no
+ // exception is expected unless a test configures one.
+ @Rule
+ public ExpectedException exception = ExpectedException.none();
+
+  /**
+   * Adds two decimal(38, 8) columns and checks the four results against values that carry
+   * seven fractional digits (e.g. 1.12345678 + 2.12345678 is expected as 3.2469136).
+   */
+  @Test
+  public void test_add() throws GandivaException {
+    int precision = 38;
+    int scale = 8;
+    ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128);
+    Field a = Field.nullable("a", decimal);
+    Field b = Field.nullable("b", decimal);
+    List<Field> args = Lists.newArrayList(a, b);
+
+    ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(
+        DecimalTypeUtil.OperationType.ADD, decimal, decimal);
+    Field retType = Field.nullable("c", outputType);
+    ExpressionTree root = TreeBuilder.makeExpression("add", args, retType);
+
+    List<ExpressionTree> exprs = Lists.newArrayList(root);
+
+    Schema schema = new Schema(args);
+    Projector eval = Projector.make(schema, exprs);
+
+    int numRows = 4;
+    String[] aValues = new String[]{"1.12345678", "2.12345678", "3.12345678", "4.12345678"};
+    String[] bValues = new String[]{"2.12345678", "3.12345678", "4.12345678", "5.12345678"};
+
+    // The validity buffers come from the decimal vectors themselves.
+    DecimalVector valuesa = decimalVector(aValues, precision, scale);
+    DecimalVector valuesb = decimalVector(bValues, precision, scale);
+    ArrowRecordBatch batch =
+        new ArrowRecordBatch(
+            numRows,
+            Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
+            Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(),
+                valuesb.getValidityBuffer(), valuesb.getDataBuffer()));
+
+    DecimalVector outVector = new DecimalVector("decimal_output", allocator,
+        outputType.getPrecision(), outputType.getScale());
+    outVector.allocateNew(numRows);
+
+    List<ValueVector> output = new ArrayList<ValueVector>();
+    output.add(outVector);
+    eval.evaluate(batch, output);
+
+    // should have scaled down.
+    BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(3.2469136),
+        BigDecimal.valueOf(5.2469136),
+        BigDecimal.valueOf(7.2469136),
+        BigDecimal.valueOf(9.2469136)};
+
+    for (int i = 0; i < numRows; i++) {
+      assertFalse(outVector.isNull(i));
+      // compareTo ignores differences in BigDecimal scale, unlike equals.
+      assertTrue("index : " + i + " failed compare",
+          expOutput[i].compareTo(outVector.getObject(i)) == 0);
+    }
+
+    // free buffers
+    releaseRecordBatch(batch);
+    releaseValueVectors(output);
+    eval.close();
+  }
+
+  /**
+   * Adds a decimal(2, 1) literal built from "6" to a decimal(2, 0) column holding 1..4 and
+   * expects 1.6, 2.6, 3.6 and 4.6.
+   */
+  @Test
+  public void test_add_literal() throws GandivaException {
+    final int columnPrecision = 2;
+    final int columnScale = 0;
+    final ArrowType.Decimal columnType =
+        new ArrowType.Decimal(columnPrecision, columnScale, 128);
+    final ArrowType.Decimal literalType = new ArrowType.Decimal(2, 1, 128);
+    final Field columnField = Field.nullable("a", columnType);
+
+    final ArrowType.Decimal sumType = DecimalTypeUtil.getResultTypeForOperation(
+        DecimalTypeUtil.OperationType.ADD, columnType, literalType);
+    final Field resultField = Field.nullable("c", sumType);
+
+    // Expression: add(a, <literal>) written to "c".
+    final TreeNode columnNode = TreeBuilder.makeField(columnField);
+    final TreeNode literalNode = TreeBuilder.makeDecimalLiteral("6", 2, 1);
+    final List<TreeNode> addArgs = Lists.newArrayList(columnNode, literalNode);
+    final TreeNode addNode = TreeBuilder.makeFunction("add", addArgs, sumType);
+    final List<ExpressionTree> expressions =
+        Lists.newArrayList(TreeBuilder.makeExpression(addNode, resultField));
+
+    final Projector projector =
+        Projector.make(new Schema(Lists.newArrayList(columnField)), expressions);
+
+    final int rowCount = 4;
+    final String[] columnValues = new String[]{"1", "2", "3", "4"};
+
+    final DecimalVector input = decimalVector(columnValues, columnPrecision, columnScale);
+    final ArrowRecordBatch recordBatch =
+        new ArrowRecordBatch(
+            rowCount,
+            Lists.newArrayList(new ArrowFieldNode(rowCount, 0)),
+            Lists.newArrayList(input.getValidityBuffer(), input.getDataBuffer()));
+
+    final DecimalVector result = new DecimalVector("decimal_output", allocator,
+        sumType.getPrecision(), sumType.getScale());
+    result.allocateNew(rowCount);
+
+    final List<ValueVector> outputs = new ArrayList<ValueVector>();
+    outputs.add(result);
+    projector.evaluate(recordBatch, outputs);
+
+    final BigDecimal[] expected = new BigDecimal[]{
+        BigDecimal.valueOf(1.6), BigDecimal.valueOf(2.6),
+        BigDecimal.valueOf(3.6), BigDecimal.valueOf(4.6)};
+
+    for (int row = 0; row < expected.length; row++) {
+      assertFalse(result.isNull(row));
+      final int comparison = expected[row].compareTo(result.getObject(row));
+      assertTrue(comparison == 0);
+    }
+
+    // release everything allocated for this test
+    releaseRecordBatch(recordBatch);
+    releaseValueVectors(outputs);
+    projector.close();
+  }
+
+  /**
+   * Multiplies two decimal(38, 8) columns and checks the four products against values
+   * that carry six fractional digits (e.g. 1.12345678 * 2.12345678 is expected as 2.385612).
+   */
+  @Test
+  public void test_multiply() throws GandivaException {
+    int precision = 38;
+    int scale = 8;
+    ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128);
+    Field a = Field.nullable("a", decimal);
+    Field b = Field.nullable("b", decimal);
+    List<Field> args = Lists.newArrayList(a, b);
+
+    ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(
+        DecimalTypeUtil.OperationType.MULTIPLY, decimal, decimal);
+    Field retType = Field.nullable("c", outputType);
+    ExpressionTree root = TreeBuilder.makeExpression("multiply", args, retType);
+
+    List<ExpressionTree> exprs = Lists.newArrayList(root);
+
+    Schema schema = new Schema(args);
+    Projector eval = Projector.make(schema, exprs);
+
+    int numRows = 4;
+    String[] aValues =
+        new String[]{"1.12345678", "2.12345678", "3.12345678", "999999999999.99999999"};
+    String[] bValues =
+        new String[]{"2.12345678", "3.12345678", "4.12345678", "999999999999.99999999"};
+
+    // The validity buffers come from the decimal vectors themselves.
+    DecimalVector valuesa = decimalVector(aValues, precision, scale);
+    DecimalVector valuesb = decimalVector(bValues, precision, scale);
+    ArrowRecordBatch batch =
+        new ArrowRecordBatch(
+            numRows,
+            Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
+            Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(),
+                valuesb.getValidityBuffer(), valuesb.getDataBuffer()));
+
+    DecimalVector outVector = new DecimalVector("decimal_output", allocator,
+        outputType.getPrecision(), outputType.getScale());
+    outVector.allocateNew(numRows);
+
+    List<ValueVector> output = new ArrayList<ValueVector>();
+    output.add(outVector);
+    eval.evaluate(batch, output);
+
+    // should have scaled down.
+    BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(2.385612),
+        BigDecimal.valueOf(6.632525),
+        BigDecimal.valueOf(12.879439),
+        new BigDecimal("999999999999999999980000.000000")};
+
+    for (int i = 0; i < numRows; i++) {
+      assertFalse(outVector.isNull(i));
+      // compareTo ignores differences in BigDecimal scale, unlike equals.
+      assertTrue("index : " + i + " failed compare",
+          expOutput[i].compareTo(outVector.getObject(i)) == 0);
+    }
+
+    // free buffers
+    releaseRecordBatch(batch);
+    releaseValueVectors(output);
+    eval.close();
+  }
+
+ @Test
+ public void testCompare() throws GandivaException {
+ Decimal aType = new Decimal(38, 3, 128);
+ Decimal bType = new Decimal(38, 2, 128);
+ Field a = Field.nullable("a", aType);
+ Field b = Field.nullable("b", bType);
+ List<Field> args = Lists.newArrayList(a, b);
+
+ List<ExpressionTree> exprs = new ArrayList<>(
+ Arrays.asList(
+ TreeBuilder.makeExpression("equal", args, Field.nullable("eq", boolType)),
+ TreeBuilder.makeExpression("not_equal", args, Field.nullable("ne", boolType)),
+ TreeBuilder.makeExpression("less_than", args, Field.nullable("lt", boolType)),
+ TreeBuilder.makeExpression("less_than_or_equal_to", args, Field.nullable("le", boolType)),
+ TreeBuilder.makeExpression("greater_than", args, Field.nullable("gt", boolType)),
+ TreeBuilder.makeExpression("greater_than_or_equal_to", args, Field.nullable("ge", boolType))
+ )
+ );
+
+ Schema schema = new Schema(args);
+ Projector eval = Projector.make(schema, exprs);
+
+ List<ValueVector> output = null;
+ ArrowRecordBatch batch = null;
+ try {
+ int numRows = 4;
+ String[] aValues = new String[]{"7.620", "2.380", "3.860", "-18.160"};
+ String[] bValues = new String[]{"7.62", "3.50", "1.90", "-1.45"};
+
+ DecimalVector valuesa = decimalVector(aValues, aType.getPrecision(), aType.getScale());
+ DecimalVector valuesb = decimalVector(bValues, bType.getPrecision(), bType.getScale());
+ batch =
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
+ Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(),
+ valuesb.getValidityBuffer(), valuesb.getDataBuffer()));
+
+ // expected results.
+ boolean[][] expected = {
+ {true, false, false, false}, // eq
+ {false, true, true, true}, // ne
+ {false, true, false, true}, // lt
+ {true, true, false, true}, // le
+ {false, false, true, false}, // gt
+ {true, false, true, false}, // ge
+ };
+
+ // Allocate output vectors.
+ output = new ArrayList<>(
+ Arrays.asList(
+ new BitVector("eq", allocator),
+ new BitVector("ne", allocator),
+ new BitVector("lt", allocator),
+ new BitVector("le", allocator),
+ new BitVector("gt", allocator),
+ new BitVector("ge", allocator)
+ )
+ );
+ for (ValueVector v : output) {
+ v.allocateNew();
+ }
+
+ // evaluate expressions.
+ eval.evaluate(batch, output);
+
+ // compare the outputs.
+ for (int idx = 0; idx < output.size(); ++idx) {
+ boolean[] expectedArray = expected[idx];
+ BitVector resultVector = (BitVector) output.get(idx);
+
+ for (int i = 0; i < numRows; i++) {
+ assertFalse(resultVector.isNull(i));
+ assertEquals("mismatch in result for expr at idx " + idx + " for row " + i,
+ expectedArray[i], resultVector.getObject(i).booleanValue());
+ }
+ }
+ } finally {
+ // free buffers
+ if (batch != null) {
+ releaseRecordBatch(batch);
+ }
+ if (output != null) {
+ releaseValueVectors(output);
+ }
+ eval.close();
+ }
+ }
+
+ /**
+  * Checks abs, ceil, floor, round and truncate on a decimal(38, 2) column, plus the
+  * two-argument round and truncate with a target scale of 1.
+  *
+  * @throws GandivaException if the projector cannot be built or evaluated
+  */
+ @Test
+ public void testRound() throws GandivaException {
+   Decimal aType = new Decimal(38, 2, 128);
+   Decimal aWithScaleZero = new Decimal(38, 0, 128);
+   Decimal aWithScaleOne = new Decimal(38, 1, 128);
+   Field a = Field.nullable("a", aType);
+   List<Field> args = Lists.newArrayList(a);
+
+   List<ExpressionTree> exprs = new ArrayList<>();
+   exprs.add(TreeBuilder.makeExpression("abs", args, Field.nullable("abs", aType)));
+   exprs.add(TreeBuilder.makeExpression("ceil", args, Field.nullable("ceil", aWithScaleZero)));
+   exprs.add(TreeBuilder.makeExpression("floor", args, Field.nullable("floor", aWithScaleZero)));
+   exprs.add(TreeBuilder.makeExpression("round", args, Field.nullable("round", aWithScaleZero)));
+   exprs.add(TreeBuilder.makeExpression("truncate", args, Field.nullable("truncate", aWithScaleZero)));
+
+   // two-argument forms: the literal 1 is passed as the second function argument
+   TreeNode roundToOne = TreeBuilder.makeFunction("round",
+       Lists.newArrayList(TreeBuilder.makeField(a), TreeBuilder.makeLiteral(1)),
+       aWithScaleOne);
+   exprs.add(TreeBuilder.makeExpression(roundToOne, Field.nullable("round_scale_1", aWithScaleOne)));
+   TreeNode truncateToOne = TreeBuilder.makeFunction("truncate",
+       Lists.newArrayList(TreeBuilder.makeField(a), TreeBuilder.makeLiteral(1)),
+       aWithScaleOne);
+   exprs.add(TreeBuilder.makeExpression(truncateToOne, Field.nullable("truncate_scale_1", aWithScaleOne)));
+
+   Projector eval = Projector.make(new Schema(args), exprs);
+
+   List<ValueVector> output = null;
+   ArrowRecordBatch batch = null;
+   try {
+     final int numRows = 4;
+     DecimalVector valuesa = decimalVector(
+         new String[]{"1.23", "1.58", "-1.23", "-1.58"}, aType.getPrecision(), aType.getScale());
+     batch = new ArrowRecordBatch(
+         numRows,
+         Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+         Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer()));
+
+     // one row of expectations per expression, in the order the expressions were added
+     BigDecimal[] absExpected = {BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58),
+         BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58)};
+     BigDecimal[] ceilExpected = {BigDecimal.valueOf(2), BigDecimal.valueOf(2),
+         BigDecimal.valueOf(-1), BigDecimal.valueOf(-1)};
+     BigDecimal[] floorExpected = {BigDecimal.valueOf(1), BigDecimal.valueOf(1),
+         BigDecimal.valueOf(-2), BigDecimal.valueOf(-2)};
+     BigDecimal[] roundExpected = {BigDecimal.valueOf(1), BigDecimal.valueOf(2),
+         BigDecimal.valueOf(-1), BigDecimal.valueOf(-2)};
+     BigDecimal[] truncateExpected = {BigDecimal.valueOf(1), BigDecimal.valueOf(1),
+         BigDecimal.valueOf(-1), BigDecimal.valueOf(-1)};
+     BigDecimal[] roundToOneExpected = {BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.6),
+         BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.6)};
+     BigDecimal[] truncateToOneExpected = {BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.5),
+         BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.5)};
+     BigDecimal[][] expected = {absExpected, ceilExpected, floorExpected, roundExpected,
+         truncateExpected, roundToOneExpected, truncateToOneExpected};
+
+     // create every output vector first, then publish the list and allocate
+     final int outPrecision = aType.getPrecision();
+     List<ValueVector> vectors = new ArrayList<>();
+     vectors.add(new DecimalVector("abs", allocator, outPrecision, aType.getScale()));
+     vectors.add(new DecimalVector("ceil", allocator, outPrecision, 0));
+     vectors.add(new DecimalVector("floor", allocator, outPrecision, 0));
+     vectors.add(new DecimalVector("round", allocator, outPrecision, 0));
+     vectors.add(new DecimalVector("truncate", allocator, outPrecision, 0));
+     vectors.add(new DecimalVector("round_to_scale_1", allocator, outPrecision, 1));
+     vectors.add(new DecimalVector("truncate_to_scale_1", allocator, outPrecision, 1));
+     output = vectors;
+     for (ValueVector vector : output) {
+       vector.allocateNew();
+     }
+
+     eval.evaluate(batch, output);
+
+     int idx = 0;
+     for (ValueVector vector : output) {
+       DecimalVector resultVector = (DecimalVector) vector;
+       BigDecimal[] expectedArray = expected[idx++];
+       for (int row = 0; row < numRows; row++) {
+         assertFalse(resultVector.isNull(row));
+         BigDecimal actual = resultVector.getObject(row);
+         String message = "mismatch in result for field " + resultVector.getField().getName() +
+             " for row " + row +
+             " expected " + expectedArray[row] +
+             ", got " + actual;
+         // compareTo treats values that differ only in scale as equal
+         assertTrue(message, expectedArray[row].compareTo(actual) == 0);
+       }
+     }
+   } finally {
+     // release whatever was created, even when an assertion failed
+     if (batch != null) {
+       releaseRecordBatch(batch);
+     }
+     if (output != null) {
+       releaseValueVectors(output);
+     }
+     eval.close();
+   }
+ }
+
+ /**
+  * Casts an int64 column, a float64 column and a decimal(38, 2) column with castDECIMAL and
+  * checks the three decimal results.
+  *
+  * @throws GandivaException if the projector cannot be built or evaluated
+  */
+ @Test
+ public void testCastToDecimal() throws GandivaException {
+   Decimal decimalType = new Decimal(38, 2, 128);
+   Decimal decimalWithScaleOne = new Decimal(38, 1, 128);
+   Field dec = Field.nullable("dec", decimalType);
+   Field int64f = Field.nullable("int64", int64);
+   Field doublef = Field.nullable("float64", float64);
+
+   List<ExpressionTree> exprs = new ArrayList<>();
+   exprs.add(TreeBuilder.makeExpression("castDECIMAL",
+       Lists.newArrayList(int64f), Field.nullable("int64_to_dec", decimalType)));
+   exprs.add(TreeBuilder.makeExpression("castDECIMAL",
+       Lists.newArrayList(doublef), Field.nullable("float64_to_dec", decimalType)));
+   exprs.add(TreeBuilder.makeExpression("castDECIMAL",
+       Lists.newArrayList(dec), Field.nullable("dec_to_dec", decimalWithScaleOne)));
+
+   Schema schema = new Schema(Lists.newArrayList(int64f, doublef, dec));
+   Projector eval = Projector.make(schema, exprs);
+
+   List<ValueVector> output = null;
+   ArrowRecordBatch batch = null;
+   try {
+     final int numRows = 4;
+     DecimalVector valuesa = decimalVector(new String[]{"1.23", "1.58", "-1.23", "-1.58"},
+         decimalType.getPrecision(), decimalType.getScale());
+
+     // buffers are listed column by column, in schema order: int64, float64, dec
+     List<ArrowFieldNode> nodes = Lists.newArrayList(
+         new ArrowFieldNode(numRows, 0),
+         new ArrowFieldNode(numRows, 0),
+         new ArrowFieldNode(numRows, 0));
+     batch = new ArrowRecordBatch(
+         numRows,
+         nodes,
+         Lists.newArrayList(
+             arrowBufWithAllValid(numRows),
+             longBuf(new long[]{123, 158, -123, -158}),
+             arrowBufWithAllValid(numRows),
+             doubleBuf(new double[]{1.23, 1.58, -1.23, -1.58}),
+             valuesa.getValidityBuffer(),
+             valuesa.getDataBuffer()));
+
+     // create every output vector first, then publish the list and allocate
+     List<ValueVector> vectors = new ArrayList<>();
+     vectors.add(new DecimalVector("int64_to_dec", allocator,
+         decimalType.getPrecision(), decimalType.getScale()));
+     vectors.add(new DecimalVector("float64_to_dec", allocator,
+         decimalType.getPrecision(), decimalType.getScale()));
+     vectors.add(new DecimalVector("dec_to_dec", allocator,
+         decimalWithScaleOne.getPrecision(), decimalWithScaleOne.getScale()));
+     output = vectors;
+     for (ValueVector vector : output) {
+       vector.allocateNew();
+     }
+
+     eval.evaluate(batch, output);
+
+     // expected[k] belongs to the k-th expression
+     BigDecimal[][] expected = {
+         {BigDecimal.valueOf(123), BigDecimal.valueOf(158),
+             BigDecimal.valueOf(-123), BigDecimal.valueOf(-158)},
+         {BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58),
+             BigDecimal.valueOf(-1.23), BigDecimal.valueOf(-1.58)},
+         {BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.6),
+             BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.6)}
+     };
+     int idx = 0;
+     for (ValueVector vector : output) {
+       DecimalVector resultVector = (DecimalVector) vector;
+       BigDecimal[] expectedArray = expected[idx++];
+       for (int row = 0; row < numRows; row++) {
+         assertFalse(resultVector.isNull(row));
+         BigDecimal actual = resultVector.getObject(row);
+         String message = "mismatch in result for field " + resultVector.getField().getName() +
+             " for row " + row +
+             " expected " + expectedArray[row] +
+             ", got " + actual;
+         assertTrue(message, expectedArray[row].compareTo(actual) == 0);
+       }
+     }
+   } finally {
+     // release whatever was created, even when an assertion failed
+     if (batch != null) {
+       releaseRecordBatch(batch);
+     }
+     if (output != null) {
+       releaseValueVectors(output);
+     }
+     eval.close();
+   }
+ }
+
+ /**
+  * Casts a decimal(38, 2) column to int64 with castBIGINT; the expected values are the
+  * inputs rounded to the nearest whole number (1.50 is expected to become 2).
+  *
+  * @throws GandivaException if the projector cannot be built or evaluated
+  */
+ @Test
+ public void testCastToLong() throws GandivaException {
+   Decimal decimalType = new Decimal(38, 2, 128);
+   Field dec = Field.nullable("dec", decimalType);
+
+   ExpressionTree castExpr = TreeBuilder.makeExpression("castBIGINT",
+       Lists.newArrayList(dec), Field.nullable("dec_to_int64", int64));
+   Schema schema = new Schema(Lists.newArrayList(dec));
+   Projector eval = Projector.make(schema, Lists.newArrayList(castExpr));
+
+   List<ValueVector> output = null;
+   ArrowRecordBatch batch = null;
+   try {
+     final int numRows = 5;
+     DecimalVector valuesa = decimalVector(
+         new String[]{"1.23", "1.50", "98765.78", "-1.23", "-1.58"},
+         decimalType.getPrecision(), decimalType.getScale());
+     batch = new ArrowRecordBatch(
+         numRows,
+         Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+         Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer()));
+
+     BigIntVector resultVector = new BigIntVector("dec_to_int64", allocator);
+     resultVector.allocateNew();
+     output = new ArrayList<>(Arrays.asList(resultVector));
+
+     eval.evaluate(batch, output);
+
+     long[] expected = {1, 2, 98766, -1, -2};
+     int row = 0;
+     for (long want : expected) {
+       assertFalse(resultVector.isNull(row));
+       assertEquals(want, resultVector.get(row));
+       row++;
+     }
+   } finally {
+     // release whatever was created, even when an assertion failed
+     if (batch != null) {
+       releaseRecordBatch(batch);
+     }
+     if (output != null) {
+       releaseValueVectors(output);
+     }
+     eval.close();
+   }
+ }
+
+ /**
+  * Casts a decimal(38, 2) column to float64 with castFLOAT8 and expects the exact double
+  * written for each input (the comparison uses a delta of 0).
+  *
+  * @throws GandivaException if the projector cannot be built or evaluated
+  */
+ @Test
+ public void testCastToDouble() throws GandivaException {
+   Decimal decimalType = new Decimal(38, 2, 128);
+   Field dec = Field.nullable("dec", decimalType);
+
+   ExpressionTree castExpr = TreeBuilder.makeExpression("castFLOAT8",
+       Lists.newArrayList(dec), Field.nullable("dec_to_float64", float64));
+   Schema schema = new Schema(Lists.newArrayList(dec));
+   Projector eval = Projector.make(schema, Lists.newArrayList(castExpr));
+
+   List<ValueVector> output = null;
+   ArrowRecordBatch batch = null;
+   try {
+     final int numRows = 4;
+     DecimalVector valuesa = decimalVector(new String[]{"1.23", "1.58", "-1.23", "-1.58"},
+         decimalType.getPrecision(), decimalType.getScale());
+     batch = new ArrowRecordBatch(
+         numRows,
+         Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+         Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer()));
+
+     Float8Vector resultVector = new Float8Vector("dec_to_float64", allocator);
+     resultVector.allocateNew();
+     output = new ArrayList<>(Arrays.asList(resultVector));
+
+     eval.evaluate(batch, output);
+
+     double[] expected = {1.23, 1.58, -1.23, -1.58};
+     int row = 0;
+     for (double want : expected) {
+       assertFalse(resultVector.isNull(row));
+       assertEquals(want, resultVector.get(row), 0);
+       row++;
+     }
+   } finally {
+     // release whatever was created, even when an assertion failed
+     if (batch != null) {
+       releaseRecordBatch(batch);
+     }
+     if (output != null) {
+       releaseValueVectors(output);
+     }
+     eval.close();
+   }
+ }
+
+ /**
+  * Evaluates {@code equal(castVARCHAR(dec, 5), str)} and expects true for every row, where
+  * the "str" column holds the expected text of the cast.
+  *
+  * <p>The record batch now carries one {@link ArrowFieldNode} per schema column ("dec" and
+  * "str"); it previously declared a single node while supplying buffers for both columns.
+  *
+  * @throws GandivaException if the projector cannot be built or evaluated
+  */
+ @Test
+ public void testCastToString() throws GandivaException {
+   Decimal decimalType = new Decimal(38, 2, 128);
+   Field dec = Field.nullable("dec", decimalType);
+   Field str = Field.nullable("str", new ArrowType.Utf8());
+   TreeNode field = TreeBuilder.makeField(dec);
+   // second castVARCHAR argument; every expected string below is at most 5 characters long
+   TreeNode literal = TreeBuilder.makeLiteral(5L);
+   List<TreeNode> args = Lists.newArrayList(field, literal);
+   TreeNode cast = TreeBuilder.makeFunction("castVARCHAR", args, new ArrowType.Utf8());
+   TreeNode root = TreeBuilder.makeFunction("equal",
+       Lists.newArrayList(cast, TreeBuilder.makeField(str)), new ArrowType.Bool());
+   ExpressionTree tree = TreeBuilder.makeExpression(root, Field.nullable("are_equal", new ArrowType.Bool()));
+
+   Schema schema = new Schema(Lists.newArrayList(dec, str));
+   Projector eval = Projector.make(schema, Lists.newArrayList(tree));
+
+   List<ValueVector> output = null;
+   ArrowRecordBatch batch = null;
+   try {
+     int numRows = 4;
+     String[] aValues = new String[]{"10.51", "100.23", "-1000.23", "-0000.10"};
+     String[] expected = {"10.51", "100.2", "-1000", "-0.10"};
+     DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale());
+     VarCharVector result = varcharVector(expected);
+     batch = new ArrowRecordBatch(
+         numRows,
+         // one field node per schema column: "dec" and "str"
+         Lists.newArrayList(
+             new ArrowFieldNode(numRows, 0),
+             new ArrowFieldNode(numRows, 0)
+         ),
+         Lists.newArrayList(
+             valuesa.getValidityBuffer(),
+             valuesa.getDataBuffer(),
+             result.getValidityBuffer(),
+             result.getOffsetBuffer(),
+             result.getDataBuffer()
+         )
+     );
+
+     BitVector resultVector = new BitVector("res", allocator);
+     resultVector.allocateNew();
+     output = new ArrayList<>(Arrays.asList(resultVector));
+
+     // evaluate expressions.
+     eval.evaluate(batch, output);
+
+     // compare the outputs.
+     for (int i = 0; i < numRows; i++) {
+       // a null result should fail the assertion rather than throw on unboxing
+       assertFalse(resultVector.isNull(i));
+       assertTrue("cast result differs from \"" + expected[i] + "\" for row " + i,
+           resultVector.getObject(i).booleanValue());
+     }
+   } finally {
+     // free buffers
+     if (batch != null) {
+       releaseRecordBatch(batch);
+     }
+     if (output != null) {
+       releaseValueVectors(output);
+     }
+     eval.close();
+   }
+ }
+
+ /**
+  * Casts a utf8 column to decimal(4, 2) with castDECIMAL and compares each result with the
+  * expected value; the last input, "-1000", is expected to yield 0.00.
+  *
+  * @throws GandivaException if the projector cannot be built or evaluated
+  */
+ @Test
+ public void testCastStringToDecimal() throws GandivaException {
+   Decimal decimalType = new Decimal(4, 2, 128);
+   Field dec = Field.nullable("dec", decimalType);
+
+   Field str = Field.nullable("str", new ArrowType.Utf8());
+   List<TreeNode> castArgs = Lists.newArrayList(TreeBuilder.makeField(str));
+   TreeNode cast = TreeBuilder.makeFunction("castDECIMAL", castArgs, decimalType);
+   ExpressionTree tree = TreeBuilder.makeExpression(cast, Field.nullable("dec_str", decimalType));
+
+   Schema schema = new Schema(Lists.newArrayList(str));
+   Projector eval = Projector.make(schema, Lists.newArrayList(tree));
+
+   List<ValueVector> output = null;
+   ArrowRecordBatch batch = null;
+   try {
+     final int numRows = 4;
+     VarCharVector valuesa = varcharVector(new String[]{"10.5134", "-0.1", "10.516", "-1000"});
+     batch = new ArrowRecordBatch(
+         numRows,
+         Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+         Lists.newArrayList(
+             valuesa.getValidityBuffer(),
+             valuesa.getOffsetBuffer(),
+             valuesa.getDataBuffer()));
+
+     DecimalVector resultVector = new DecimalVector("res", allocator,
+         decimalType.getPrecision(), decimalType.getScale());
+     resultVector.allocateNew();
+     output = new ArrayList<>(Arrays.asList(resultVector));
+
+     BigDecimal[] expected = {BigDecimal.valueOf(10.51), BigDecimal.valueOf(-0.10),
+         BigDecimal.valueOf(10.52), BigDecimal.valueOf(0.00)};
+
+     eval.evaluate(batch, output);
+
+     for (int row = 0; row < numRows; row++) {
+       BigDecimal actual = resultVector.getObject(row);
+       String message = "mismatch in result for field " + resultVector.getField().getName() +
+           " for row " + row +
+           " expected " + expected[row] +
+           ", got " + actual;
+       // compareTo treats values that differ only in scale as equal
+       assertTrue(message, expected[row].compareTo(actual) == 0);
+     }
+   } finally {
+     // release whatever was created, even when an assertion failed
+     if (batch != null) {
+       releaseRecordBatch(batch);
+     }
+     if (output != null) {
+       releaseValueVectors(output);
+     }
+     eval.close();
+   }
+ }
+
+ /**
+  * Expects an {@link IllegalArgumentException} naming precision 0 when a projector is built
+  * for a castDECIMAL whose result type is decimal(0, 0).
+  *
+  * @throws GandivaException declared by {@code Projector.make}
+  */
+ @Test
+ public void testInvalidDecimal() throws GandivaException {
+   exception.expect(IllegalArgumentException.class);
+   exception.expectMessage("Gandiva only supports decimals of upto 38 precision. Input precision" +
+       " : 0");
+
+   Decimal zeroPrecision = new Decimal(0, 0, 128);
+   Field input = Field.nullable("int64", int64);
+   Schema schema = new Schema(Lists.newArrayList(input));
+
+   ExpressionTree cast = TreeBuilder.makeExpression("castDECIMAL",
+       Lists.newArrayList(input),
+       Field.nullable("invalid_dec", zeroPrecision));
+   Projector.make(schema, Lists.newArrayList(cast));
+ }
+
+ /**
+  * Expects an {@link IllegalArgumentException} naming precision 42 when a projector is built
+  * for a castDECIMAL whose result type is decimal(42, 0).
+  *
+  * @throws GandivaException declared by {@code Projector.make}
+  */
+ @Test
+ public void testInvalidDecimalGt38() throws GandivaException {
+   exception.expect(IllegalArgumentException.class);
+   exception.expectMessage("Gandiva only supports decimals of upto 38 precision. Input precision" +
+       " : 42");
+
+   Decimal tooPrecise = new Decimal(42, 0, 128);
+   Field input = Field.nullable("int64", int64);
+   Schema schema = new Schema(Lists.newArrayList(input));
+
+   ExpressionTree cast = TreeBuilder.makeExpression("castDECIMAL",
+       Lists.newArrayList(input),
+       Field.nullable("invalid_dec", tooPrecise));
+   Projector.make(schema, Lists.newArrayList(cast));
+ }
+}
+
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
new file mode 100644
index 000000000..03c9377b0
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
@@ -0,0 +1,2470 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.math.BigDecimal;
+import java.nio.charset.Charset;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.IntStream;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.expression.ExpressionTree;
+import org.apache.arrow.gandiva.expression.TreeBuilder;
+import org.apache.arrow.gandiva.expression.TreeNode;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.DateUnit;
+import org.apache.arrow.vector.types.IntervalUnit;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+public class ProjectorTest extends BaseEvaluatorTest {
+
+ private Charset utf8Charset = Charset.forName("UTF-8"); // charset stringBufs passes to varBufs
+ private Charset utf16Charset = Charset.forName("UTF-16"); // charset binaryBufs passes to varBufs
+
+ @Rule
+ public ExpectedException thrown = ExpectedException.none(); // JUnit rule; starts out expecting no exception
+
+ /**
+  * Encodes the strings with the given charset into a pair of freshly allocated buffers.
+  *
+  * @param strings values to encode, in order
+  * @param charset charset used to turn each string into bytes
+  * @return a two-element list: the buffer of {@code strings.length + 1} int offsets, then the
+  *     buffer holding the concatenated encoded bytes
+  */
+ List<ArrowBuf> varBufs(String[] strings, Charset charset) {
+   // one 4-byte offset per string plus the closing offset
+   ArrowBuf offsets = allocator.buffer((strings.length + 1) * 4);
+
+   long totalBytes = 0L;
+   for (int i = 0; i < strings.length; i++) {
+     totalBytes += strings[i].getBytes(charset).length;
+   }
+   ArrowBuf data = allocator.buffer(totalBytes);
+
+   int cursor = 0;
+   for (String value : strings) {
+     offsets.writeInt(cursor);
+
+     final byte[] encoded = value.getBytes(charset);
+     data = data.reallocIfNeeded(data.writerIndex() + encoded.length);
+     data.setBytes(cursor, encoded, 0, encoded.length);
+     cursor += encoded.length;
+   }
+   // closing offset marks the end of the last value
+   offsets.writeInt(cursor);
+
+   return Arrays.asList(offsets, data);
+ }
+
+ List<ArrowBuf> stringBufs(String[] strings) {
+ return varBufs(strings, utf8Charset);
+ }
+
+ List<ArrowBuf> binaryBufs(String[] strings) {
+ return varBufs(strings, utf16Charset);
+ }
+
+ /**
+  * Builds and closes 1000 projectors on a 16-thread pool, each against a randomly chosen
+  * schema, and fails if any build or close fails or the pool does not drain in time.
+  *
+  * <p>Worker failures used to be printed and otherwise ignored, so this helper could never
+  * fail a test; they are now counted and asserted on.
+  *
+  * @param configOptions options passed to {@code Projector.make}, or {@code null} to use the
+  *     two-argument overload
+  * @throws InterruptedException if interrupted while waiting for the pool to terminate
+  */
+ private void testMakeProjectorParallel(ConfigurationBuilder.ConfigOptions configOptions) throws InterruptedException {
+   List<Schema> schemas = Lists.newArrayList();
+   Field a = Field.nullable("a", int64);
+   Field b = Field.nullable("b", int64);
+   IntStream.range(0, 1000)
+       .forEach(
+           i -> {
+             Field c = Field.nullable("" + i, int64);
+             List<Field> cols = Lists.newArrayList(a, b, c);
+             schemas.add(new Schema(cols));
+           });
+
+   TreeNode aNode = TreeBuilder.makeField(a);
+   TreeNode bNode = TreeBuilder.makeField(b);
+   List<TreeNode> args = Lists.newArrayList(aNode, bNode);
+
+   TreeNode cond = TreeBuilder.makeFunction("greater_than", args, boolType);
+   TreeNode ifNode = TreeBuilder.makeIf(cond, aNode, bNode, int64);
+
+   ExpressionTree expr = TreeBuilder.makeExpression(ifNode, Field.nullable("c", int64));
+   List<ExpressionTree> exprs = Lists.newArrayList(expr);
+
+   // build projectors in parallel choosing schema at random
+   // this should hit the same cache entry thus exposing
+   // any threading issues.
+   // (only the first 100 of the 1000 schemas are ever sampled)
+   ExecutorService executors = Executors.newFixedThreadPool(16);
+   AtomicInteger failures = new AtomicInteger(0);
+
+   IntStream.range(0, 1000)
+       .forEach(
+           i -> {
+             executors.submit(
+                 () -> {
+                   try {
+                     Projector evaluator = configOptions == null ?
+                         Projector.make(schemas.get((int) (Math.random() * 100)), exprs) :
+                         Projector.make(schemas.get((int) (Math.random() * 100)), exprs, configOptions);
+                     evaluator.close();
+                   } catch (GandivaException | RuntimeException e) {
+                     // record the failure so the calling test fails; the trace is kept for diagnosis
+                     failures.incrementAndGet();
+                     e.printStackTrace();
+                   }
+                 });
+           });
+   executors.shutdown();
+   boolean finished = executors.awaitTermination(100, java.util.concurrent.TimeUnit.SECONDS);
+   if (!finished) {
+     // do not leave worker threads running behind a failed test
+     executors.shutdownNow();
+   }
+   Assert.assertTrue("projector builds did not finish within 100 seconds", finished);
+   Assert.assertEquals("number of failed parallel projector builds", 0, failures.get());
+ }
+
+ @Test
+ public void testMakeProjectorParallel() throws Exception {
+ testMakeProjectorParallel(null);
+ testMakeProjectorParallel(new ConfigurationBuilder.ConfigOptions().withTargetCPU(false));
+ testMakeProjectorParallel(new ConfigurationBuilder.ConfigOptions().withTargetCPU(false).withOptimize(false));
+ }
+
+ // Will be fixed by https://issues.apache.org/jira/browse/ARROW-4371
+ // Builds the same projector three times and requires the third build to take under 5 ms.
+ @Ignore
+ @Test
+ public void testMakeProjector() throws GandivaException {
+   Field a = Field.nullable("a", int64);
+   Field b = Field.nullable("b", int64);
+   TreeNode aNode = TreeBuilder.makeField(a);
+   TreeNode bNode = TreeBuilder.makeField(b);
+
+   Schema schema = new Schema(Lists.newArrayList(a, b));
+
+   // if (a > b) a else b
+   TreeNode cond = TreeBuilder.makeFunction("greater_than", Lists.newArrayList(aNode, bNode), boolType);
+   TreeNode ifNode = TreeBuilder.makeIf(cond, aNode, bNode, int64);
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(TreeBuilder.makeExpression(ifNode, Field.nullable("c", int64)));
+
+   long startedAt = System.currentTimeMillis();
+   Projector first = Projector.make(schema, exprs);
+   long firstElapsed = System.currentTimeMillis() - startedAt;
+   System.out.println(
+       "Projector build: iteration 1 took " + firstElapsed + " ms");
+
+   startedAt = System.currentTimeMillis();
+   Projector second = Projector.make(schema, exprs);
+   long secondElapsed = System.currentTimeMillis() - startedAt;
+   System.out.println(
+       "Projector build: iteration 2 took " + secondElapsed + " ms");
+
+   startedAt = System.currentTimeMillis();
+   Projector third = Projector.make(schema, exprs);
+   long timeToMakeProjector = System.currentTimeMillis() - startedAt;
+   // the third build is expected to come from the cache;
+   // 5ms leaves room for varying system load.
+   Assert.assertTrue(timeToMakeProjector < 5L);
+
+   first.close();
+   second.close();
+   third.close();
+ }
+
+ /**
+  * Verifies that building a projector for an unknown function throws
+  * {@link GandivaException}, and does so again on a second attempt one second later.
+  *
+  * @throws InterruptedException if interrupted during the pause between the attempts
+  */
+ @Test
+ public void testMakeProjectorValidationError() throws InterruptedException {
+   Field a = Field.nullable("a", int64);
+   List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(a));
+   Schema schema = new Schema(Lists.newArrayList(a));
+
+   TreeNode unknownCall = TreeBuilder.makeFunction("non_existent_fn", args, boolType);
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(TreeBuilder.makeExpression(unknownCall, Field.nullable("c", int64)));
+
+   for (int attempt = 0; attempt < 2; attempt++) {
+     if (attempt == 1) {
+       // allow GC to collect any temp resources before trying again.
+       Thread.sleep(1000);
+     }
+
+     boolean rejected = false;
+     try {
+       Projector.make(schema, exprs);
+     } catch (GandivaException e) {
+       rejected = true;
+     }
+     Assert.assertTrue(rejected);
+   }
+ }
+
+ /**
+  * Adds two int32 columns of 16 rows in which only the first 8 rows are valid; expects 17
+  * for each valid row and null for the rest.
+  *
+  * @throws GandivaException if the projector cannot be built or evaluated
+  * @throws Exception declared by the original signature
+  */
+ @Test
+ public void testEvaluate() throws GandivaException, Exception {
+   List<Field> args = Lists.newArrayList(Field.nullable("a", int32), Field.nullable("b", int32));
+   ExpressionTree sum = TreeBuilder.makeExpression("add", args, Field.nullable("c", int32));
+
+   Projector eval = Projector.make(new Schema(args), Lists.newArrayList(sum));
+
+   final int numRows = 16;
+   final int validRows = 8;
+   // first validity byte marks rows 0-7 valid, second marks rows 8-15 null
+   byte[] validity = new byte[]{(byte) 255, 0};
+   int[] aValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+   int[] bValues = new int[]{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
+
+   ArrowBuf aValidity = buf(validity);
+   ArrowBuf aData = intBuf(aValues);
+   ArrowBuf bValidity = buf(validity);
+   ArrowBuf bData = intBuf(bValues);
+   List<ArrowFieldNode> nodes =
+       Lists.newArrayList(new ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8));
+   ArrowRecordBatch batch = new ArrowRecordBatch(
+       numRows,
+       nodes,
+       Lists.newArrayList(aValidity, aData, bValidity, bData));
+
+   IntVector sums = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+   sums.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(sums);
+   eval.evaluate(batch, output);
+
+   for (int row = 0; row < numRows; row++) {
+     if (row < validRows) {
+       assertFalse(sums.isNull(row));
+       assertEquals(17, sums.get(row));
+     } else {
+       assertTrue(sums.isNull(row));
+     }
+   }
+
+   // free buffers
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Divides two int32 columns whose second row has a zero divisor and expects
+  * {@code evaluate} to throw a {@link GandivaException} mentioning "divide by zero".
+  *
+  * @throws GandivaException if the projector cannot be built or closed
+  * @throws Exception declared by the original signature
+  */
+ @Test
+ public void testEvaluateDivZero() throws GandivaException, Exception {
+   List<Field> args = Lists.newArrayList(Field.nullable("a", int32), Field.nullable("b", int32));
+   ExpressionTree quotient = TreeBuilder.makeExpression("divide", args, Field.nullable("c", int32));
+
+   Projector eval = Projector.make(new Schema(args), Lists.newArrayList(quotient));
+
+   final int numRows = 2;
+   byte[] validity = new byte[]{(byte) 255};
+   // both rows are valid; the second row divides by zero
+   int[] aValues = new int[]{2, 2};
+   int[] bValues = new int[]{1, 0};
+
+   ArrowBuf aValidity = buf(validity);
+   ArrowBuf aData = intBuf(aValues);
+   ArrowBuf bValidity = buf(validity);
+   ArrowBuf bData = intBuf(bValues);
+   List<ArrowFieldNode> nodes =
+       Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0));
+   ArrowRecordBatch batch = new ArrowRecordBatch(
+       numRows,
+       nodes,
+       Lists.newArrayList(aValidity, aData, bValidity, bData));
+
+   IntVector quotients = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+   quotients.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(quotients);
+
+   GandivaException caught = null;
+   try {
+     eval.evaluate(batch, output);
+   } catch (GandivaException e) {
+     caught = e;
+   }
+   if (caught != null) {
+     Assert.assertTrue(caught.getMessage().contains("divide by zero"));
+   }
+   Assert.assertTrue(caught != null);
+
+   // free buffers
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Runs 1000 divide evaluations on a 16-thread pool; every even-numbered task divides by
+  * zero. Checks that exactly those tasks reported a {@link GandivaException} from
+  * {@code evaluate}.
+  *
+  * <p>Failures outside {@code evaluate} (building or closing the projector, or a runtime
+  * exception) used to be swallowed. That skipped both counters, so the final equality could
+  * pass vacuously. They are now counted and asserted to be zero, and the pool must terminate
+  * within the timeout.
+  *
+  * @throws GandivaException if the shared projector cannot be built or closed
+  * @throws InterruptedException if interrupted while waiting for the pool to terminate
+  */
+ @Test
+ public void testDivZeroParallel() throws GandivaException, InterruptedException {
+   Field a = Field.nullable("a", int32);
+   Field b = Field.nullable("b", int32);
+   Field c = Field.nullable("c", int32);
+   List<Field> cols = Lists.newArrayList(a, b);
+   Schema s = new Schema(cols);
+
+   List<Field> args = Lists.newArrayList(a, b);
+
+   ExpressionTree expr = TreeBuilder.makeExpression("divide", args, c);
+   List<ExpressionTree> exprs = Lists.newArrayList(expr);
+
+   ExecutorService executors = Executors.newFixedThreadPool(16);
+
+   AtomicInteger errorCount = new AtomicInteger(0);
+   AtomicInteger errorCountExp = new AtomicInteger(0);
+   AtomicInteger unexpectedFailures = new AtomicInteger(0);
+   // pre-build the projector so that same projector is used for all executions.
+   Projector test = Projector.make(s, exprs);
+
+   IntStream.range(0, 1000).forEach(i -> {
+     executors.submit(() -> {
+       try {
+         Projector evaluator = Projector.make(s, exprs);
+         try {
+           int numRows = 2;
+           byte[] validity = new byte[]{(byte) 255};
+           int[] aValues = new int[]{2, 2};
+           int[] bValues;
+           if (i % 2 == 0) {
+             // even tasks divide by zero in the second row
+             errorCountExp.incrementAndGet();
+             bValues = new int[]{1, 0};
+           } else {
+             bValues = new int[]{1, 1};
+           }
+
+           ArrowBuf validitya = buf(validity);
+           ArrowBuf valuesa = intBuf(aValues);
+           ArrowBuf validityb = buf(validity);
+           ArrowBuf valuesb = intBuf(bValues);
+           ArrowRecordBatch batch = new ArrowRecordBatch(
+               numRows,
+               Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows,
+                   0)),
+               Lists.newArrayList(validitya, valuesa, validityb, valuesb));
+
+           IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+           intVector.allocateNew(numRows);
+
+           List<ValueVector> output = new ArrayList<ValueVector>();
+           output.add(intVector);
+           try {
+             evaluator.evaluate(batch, output);
+           } catch (GandivaException e) {
+             errorCount.incrementAndGet();
+           }
+           // free buffers
+           releaseRecordBatch(batch);
+           releaseValueVectors(output);
+         } finally {
+           evaluator.close();
+         }
+       } catch (GandivaException | RuntimeException e) {
+         // anything other than the evaluate failure counted above is a test failure
+         unexpectedFailures.incrementAndGet();
+         e.printStackTrace();
+       }
+     });
+   });
+   executors.shutdown();
+   boolean finished = executors.awaitTermination(100, java.util.concurrent.TimeUnit.SECONDS);
+   if (!finished) {
+     // do not leave worker threads running behind a failed test
+     executors.shutdownNow();
+   }
+   test.close();
+   Assert.assertTrue("evaluations did not finish within 100 seconds", finished);
+   Assert.assertEquals("number of tasks that failed outside evaluate", 0, unexpectedFailures.intValue());
+   Assert.assertEquals(errorCountExp.intValue(), errorCount.intValue());
+ }
+
+ /**
+  * Projects {@code x + n2x + n3x} over 16 rows. The shared validity bitmap {255, 0} marks only
+  * the first 8 rows as valid, so the test checks the sums for those rows and expects the
+  * remaining rows to be null.
+  */
+ @Test
+ public void testAdd3() throws GandivaException, Exception {
+   Field x = Field.nullable("x", int32);
+   Field n2x = Field.nullable("n2x", int32);
+   Field n3x = Field.nullable("n3x", int32);
+
+   // x + n2x + n3x
+   TreeNode add1 =
+       TreeBuilder.makeFunction(
+           "add", Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeField(n2x)), int32);
+   TreeNode add =
+       TreeBuilder.makeFunction(
+           "add", Lists.newArrayList(add1, TreeBuilder.makeField(n3x)), int32);
+   ExpressionTree expr = TreeBuilder.makeExpression(add, x);
+
+   List<Field> cols = Lists.newArrayList(x, n2x, n3x);
+   Schema schema = new Schema(cols);
+
+   Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+   int numRows = 16;
+   // rows [0, numValid) have their validity bit set; the rest are null
+   int numValid = 8;
+   byte[] validity = new byte[]{(byte) 255, 0};
+   // second half is "undefined"
+   int[] xValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+   int[] n2xValues = new int[]{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
+   int[] n3xValues = new int[]{1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8, 8, 7, 6, 5};
+
+   int[] expected = new int[]{18, 19, 20, 21, 21, 20, 19, 18, 18, 19, 20, 21, 21, 20, 19, 18};
+
+   ArrowBuf xValidity = buf(validity);
+   ArrowBuf xData = intBuf(xValues);
+   ArrowBuf n2xValidity = buf(validity);
+   ArrowBuf n2xData = intBuf(n2xValues);
+   ArrowBuf n3xValidity = buf(validity);
+   ArrowBuf n3xData = intBuf(n3xValues);
+
+   ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 8);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(fieldNode, fieldNode, fieldNode),
+           Lists.newArrayList(xValidity, xData, n2xValidity, n2xData, n3xValidity, n3xData));
+
+   IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+   intVector.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(intVector);
+   eval.evaluate(batch, output);
+
+   for (int i = 0; i < numValid; i++) {
+     assertFalse(intVector.isNull(i));
+     assertEquals(expected[i], intVector.get(i));
+   }
+   for (int i = numValid; i < numRows; i++) {
+     assertTrue(intVector.isNull(i));
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Projects an if/else over a string column and two int columns (schema order a, x, b):
+  * rows whose x compares less than "hello" yield octet_length(x) + a, all others yield
+  * octet_length(x) + b. All five rows are expected to be non-null.
+  */
+ @Test
+ public void testStringFields() throws GandivaException {
+   /*
+    * when x < "hello" then octet_length(x) + a
+    * else octet_length(x) + b
+    */
+
+   Field x = Field.nullable("x", new ArrowType.Utf8());
+   Field a = Field.nullable("a", new ArrowType.Int(32, true));
+   Field b = Field.nullable("b", new ArrowType.Int(32, true));
+
+   ArrowType retType = new ArrowType.Int(32, true);
+
+   TreeNode cond =
+       TreeBuilder.makeFunction(
+           "less_than",
+           Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeStringLiteral("hello")),
+           boolType);
+   TreeNode octetLenFuncNode =
+       TreeBuilder.makeFunction(
+           "octet_length", Lists.newArrayList(TreeBuilder.makeField(x)), retType);
+   TreeNode octetLenPlusANode =
+       TreeBuilder.makeFunction(
+           "add", Lists.newArrayList(TreeBuilder.makeField(a), octetLenFuncNode), retType);
+   TreeNode octetLenPlusBNode =
+       TreeBuilder.makeFunction(
+           "add", Lists.newArrayList(TreeBuilder.makeField(b), octetLenFuncNode), retType);
+
+   TreeNode ifHello = TreeBuilder.makeIf(cond, octetLenPlusANode, octetLenPlusBNode, retType);
+
+   ExpressionTree expr = TreeBuilder.makeExpression(ifHello, Field.nullable("res", retType));
+   Schema schema = new Schema(Lists.newArrayList(a, x, b));
+   Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+   int numRows = 5;
+   byte[] validity = new byte[]{(byte) 255, 0};
+   // "A função" means "The function" in Portuguese
+   String[] valuesX = new String[]{"hell", "abc", "hellox", "ijk", "A função"};
+   int[] valuesA = new int[]{10, 20, 30, 40, 50};
+   int[] valuesB = new int[]{110, 120, 130, 140, 150};
+   int[] expected = new int[]{14, 23, 136, 143, 60};
+
+   ArrowBuf validityX = buf(validity);
+   List<ArrowBuf> dataBufsX = stringBufs(valuesX);
+   ArrowBuf validityA = buf(validity);
+   ArrowBuf dataA = intBuf(valuesA);
+   ArrowBuf validityB = buf(validity);
+   ArrowBuf dataB = intBuf(valuesB);
+
+   // one field node per schema column (a, x, b); buffers follow the same column order
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(
+               new ArrowFieldNode(numRows, 0),
+               new ArrowFieldNode(numRows, 0),
+               new ArrowFieldNode(numRows, 0)),
+           Lists.newArrayList(
+               validityA, dataA, validityX, dataBufsX.get(0), dataBufsX.get(1), validityB, dataB));
+
+   IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+   intVector.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(intVector);
+   eval.evaluate(batch, output);
+
+   for (int i = 0; i < numRows; i++) {
+     assertFalse(intVector.isNull(i));
+     assertEquals(expected[i], intVector.get(i));
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Projects {@code if (x >= 0) "hi" else "bye"} into a VarCharVector and checks the four
+  * results, then re-runs the evaluation with the output vector re-allocated to a 4-byte data
+  * buffer. The second run is only required to complete; its output is not inspected.
+  */
+ @Test
+ public void testStringOutput() throws GandivaException {
+   /*
+    * if (x >= 0) "hi" else "bye"
+    */
+
+   Field x = Field.nullable("x", new ArrowType.Int(32, true));
+
+   ArrowType retType = new ArrowType.Utf8();
+
+   TreeNode ifHiBye = TreeBuilder.makeIf(
+       TreeBuilder.makeFunction(
+           "greater_than_or_equal_to",
+           Lists.newArrayList(
+               TreeBuilder.makeField(x),
+               TreeBuilder.makeLiteral(0)
+           ),
+           boolType),
+       TreeBuilder.makeStringLiteral("hi"),
+       TreeBuilder.makeStringLiteral("bye"),
+       retType);
+
+   ExpressionTree expr = TreeBuilder.makeExpression(ifHiBye, Field.nullable("res", retType));
+   Schema schema = new Schema(Lists.newArrayList(x));
+   Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+   // fill up input record batch
+   int numRows = 4;
+   byte[] validity = new byte[]{(byte) 255, 0};
+   int[] xValues = new int[]{10, -10, 20, -20};
+   String[] expected = new String[]{"hi", "bye", "hi", "bye"};
+   ArrowBuf validityX = buf(validity);
+   ArrowBuf dataX = intBuf(xValues);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+           Lists.newArrayList(validityX, dataX));
+
+   // allocate data for output vector.
+   VarCharVector outVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator);
+   outVector.allocateNew(64, numRows);
+
+   // evaluate expression
+   List<ValueVector> output = new ArrayList<>();
+   output.add(outVector);
+   eval.evaluate(batch, output);
+
+   // match expected output.
+   for (int i = 0; i < numRows; i++) {
+     assertFalse(outVector.isNull(i));
+     // decode explicitly as UTF-8 instead of relying on the platform default charset
+     assertEquals(
+         expected[i], new String(outVector.get(i), java.nio.charset.StandardCharsets.UTF_8));
+   }
+
+   // test with insufficient data buffer.
+   try {
+     outVector.allocateNew(4, numRows);
+     eval.evaluate(batch, output);
+   } finally {
+     releaseRecordBatch(batch);
+     releaseValueVectors(output);
+     eval.close();
+   }
+ }
+
+ /**
+  * Evaluates {@code like(x, "%map%")} over five strings and compares every output bit with
+  * the expected match result; no output row may be null.
+  */
+ @Test
+ public void testRegex() throws GandivaException {
+   /*
+    * like "%map%"
+    */
+
+   Field x = Field.nullable("x", new ArrowType.Utf8());
+
+   // build the projector for like(x, "%map%")
+   TreeNode likeNode =
+       TreeBuilder.makeFunction(
+           "like",
+           Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeStringLiteral("%map%")),
+           boolType);
+   Field resultField = Field.nullable("res", boolType);
+   ExpressionTree likeExpr = TreeBuilder.makeExpression(likeNode, resultField);
+   Projector eval = Projector.make(new Schema(Lists.newArrayList(x)), Lists.newArrayList(likeExpr));
+
+   // input data and the expected match flags
+   int numRows = 5;
+   String[] inputs = new String[]{"mapD", "maps", "google maps", "map", "MapR"};
+   boolean[] expected = new boolean[]{true, true, true, true, false};
+
+   ArrowBuf validityX = buf(new byte[]{(byte) 255, 0});
+   List<ArrowBuf> stringBuffers = stringBufs(inputs);
+
+   ArrowFieldNode node = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(node),
+           Lists.newArrayList(validityX, stringBuffers.get(0), stringBuffers.get(1)));
+
+   BitVector matches = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   matches.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(matches);
+   eval.evaluate(batch, output);
+
+   int row = 0;
+   for (boolean want : expected) {
+     assertFalse(matches.isNull(row));
+     assertEquals(want, matches.getObject(row).booleanValue());
+     row++;
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Evaluates {@code regexp_replace(x, "ana", replaceString)} over five rows. The validity
+  * byte 15 marks rows 0-3 as valid and row 4 as null, so the first four outputs are compared
+  * with the expected strings and the last output must be null.
+  */
+ @Test
+ public void testRegexpReplace() throws GandivaException {
+
+   Field x = Field.nullable("x", new ArrowType.Utf8());
+   Field replaceString = Field.nullable("replaceString", new ArrowType.Utf8());
+
+   Field retType = Field.nullable("c", new ArrowType.Utf8());
+
+   TreeNode cond =
+       TreeBuilder.makeFunction(
+           "regexp_replace",
+           Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeStringLiteral("ana"),
+               TreeBuilder.makeField(replaceString)),
+           new ArrowType.Utf8());
+   ExpressionTree expr = TreeBuilder.makeExpression(cond, retType);
+   Schema schema = new Schema(Lists.newArrayList(x, replaceString));
+   Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+   int numRows = 5;
+   byte[] validity = new byte[]{(byte) 15, 0};
+   String[] valuesX = new String[]{"banana", "bananaana", "bananana", "anaana", "anaana"};
+   String[] valuesReplace = new String[]{"ue", "", "", "c", ""};
+   String[] expected = new String[]{"buena", "bna", "bn", "cc", null};
+
+   ArrowBuf validityX = buf(validity);
+   ArrowBuf validityReplace = buf(validity);
+   List<ArrowBuf> dataBufsX = stringBufs(valuesX);
+   List<ArrowBuf> dataBufsReplace = stringBufs(valuesReplace);
+
+   ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(fieldNode, fieldNode),
+           Lists.newArrayList(validityX, dataBufsX.get(0), dataBufsX.get(1), validityReplace,
+               dataBufsReplace.get(0), dataBufsReplace.get(1)));
+
+   // allocate data for output vector.
+   VarCharVector outVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator);
+   outVector.allocateNew(numRows * 15, numRows);
+
+   // evaluate expression
+   List<ValueVector> output = new ArrayList<>();
+   output.add(outVector);
+   eval.evaluate(batch, output);
+   eval.close();
+
+   // match expected output.
+   for (int i = 0; i < numRows - 1; i++) {
+     assertFalse("Expect none value equals null", outVector.isNull(i));
+     // decode explicitly as UTF-8 instead of relying on the platform default charset
+     assertEquals(
+         expected[i], new String(outVector.get(i), java.nio.charset.StandardCharsets.UTF_8));
+   }
+
+   assertTrue("Last value must be null", outVector.isNull(numRows - 1));
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Evaluates {@code rand(12)} and compares the five outputs with fixed expected values, then
+  * evaluates the unseeded {@code rand()} into the same vector and checks that its first two
+  * outputs differ.
+  */
+ @Test
+ public void testRand() throws GandivaException {
+
+   final double tolerance = 0.000000001;
+
+   // seeded and unseeded variants of rand, each wrapped in its own expression
+   TreeNode seededNode =
+       TreeBuilder.makeFunction("rand", Lists.newArrayList(TreeBuilder.makeLiteral(12)), float64);
+   TreeNode unseededNode = TreeBuilder.makeFunction("rand", Lists.newArrayList(), float64);
+   ExpressionTree seededExpr =
+       TreeBuilder.makeExpression(seededNode, Field.nullable("res", float64));
+   ExpressionTree unseededExpr =
+       TreeBuilder.makeExpression(unseededNode, Field.nullable("res2", float64));
+
+   Field x = Field.nullable("x", new ArrowType.Utf8());
+   Schema schema = new Schema(Lists.newArrayList(x));
+   Projector evalWithSeed = Projector.make(schema, Lists.newArrayList(seededExpr));
+   Projector eval = Projector.make(schema, Lists.newArrayList(unseededExpr));
+
+   int numRows = 5;
+   String[] inputs = new String[] {"mapD", "maps", "google maps", "map", "MapR"};
+   double[] expected = new double[] {
+       0.1597116001879662D,
+       0.7347813877263527D,
+       0.6069965050584282D,
+       0.7240285696335824D,
+       0.09975540272957834D};
+
+   ArrowBuf validityX = buf(new byte[] {(byte) 255, 0});
+   List<ArrowBuf> stringBuffers = stringBufs(inputs);
+
+   ArrowFieldNode node = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(node),
+           Lists.newArrayList(validityX, stringBuffers.get(0), stringBuffers.get(1)));
+
+   Float8Vector randoms = new Float8Vector(EMPTY_SCHEMA_PATH, allocator);
+   randoms.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(randoms);
+
+   // seeded run: values must match the expected sequence
+   evalWithSeed.evaluate(batch, output);
+   int row = 0;
+   for (double want : expected) {
+     assertFalse(randoms.isNull(row));
+     assertEquals(want, randoms.getObject(row), tolerance);
+     row++;
+   }
+
+   // without seed
+   eval.evaluate(batch, output);
+   assertNotEquals(randoms.getObject(0), randoms.getObject(1), tolerance);
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+   evalWithSeed.close();
+ }
+
+ /**
+  * Evaluates {@code equal(a, b)} over two binary columns of five rows each and compares every
+  * output bit with the expected equality result; no output row may be null.
+  */
+ @Test
+ public void testBinaryFields() throws GandivaException {
+   Field a = Field.nullable("a", new ArrowType.Binary());
+   Field b = Field.nullable("b", new ArrowType.Binary());
+   List<Field> args = Lists.newArrayList(a, b);
+
+   ArrowType retType = new ArrowType.Bool();
+   Field resultField = Field.nullable("res", retType);
+   ExpressionTree equalExpr = TreeBuilder.makeExpression("equal", args, resultField);
+
+   Projector eval =
+       Projector.make(new Schema(Lists.newArrayList(args)), Lists.newArrayList(equalExpr));
+
+   int numRows = 5;
+   byte[] validity = new byte[]{(byte) 255, 0};
+   String[] leftValues = new String[]{"a", "aa", "aaa", "aaaa", "A função"};
+   String[] rightValues = new String[]{"a", "bb", "aaa", "bbbbb", "A função"};
+   boolean[] expected = new boolean[]{true, false, true, false, true};
+
+   ArrowBuf validitya = buf(validity);
+   ArrowBuf validityb = buf(validity);
+   List<ArrowBuf> leftBufs = binaryBufs(leftValues);
+   List<ArrowBuf> rightBufs = binaryBufs(rightValues);
+
+   // buffers are listed per column: validity, then the two buffers returned by binaryBufs
+   List<ArrowFieldNode> nodes =
+       Lists.newArrayList(new ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8));
+   List<ArrowBuf> buffers =
+       Lists.newArrayList(
+           validitya, leftBufs.get(0), leftBufs.get(1),
+           validityb, rightBufs.get(0), rightBufs.get(1));
+   ArrowRecordBatch batch = new ArrowRecordBatch(numRows, nodes, buffers);
+
+   BitVector equalBits = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   equalBits.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(equalBits);
+   eval.evaluate(batch, output);
+
+   int row = 0;
+   for (boolean want : expected) {
+     assertFalse(equalBits.isNull(row));
+     assertEquals(want, equalBits.getObject(row).booleanValue());
+     row++;
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /** Builds the boolean node {@code less_than(arg, value)} with {@code value} as a literal. */
+ private TreeNode makeLongLessThanCond(TreeNode arg, long value) {
+   TreeNode bound = TreeBuilder.makeLiteral(value);
+   return TreeBuilder.makeFunction("less_than", Lists.newArrayList(arg, bound), boolType);
+ }
+
+ /** Builds the boolean node {@code greater_than(arg, value)} with {@code value} as a literal. */
+ private TreeNode makeLongGreaterThanCond(TreeNode arg, long value) {
+   TreeNode bound = TreeBuilder.makeLiteral(value);
+   return TreeBuilder.makeFunction("greater_than", Lists.newArrayList(arg, bound), boolType);
+ }
+
+ /**
+  * Builds {@code if (arg < value) thenValue else elseNode} as an if-node of the given type,
+  * with {@code thenValue} turned into a literal node.
+  */
+ private TreeNode ifLongLessThanElse(
+     TreeNode arg, long value, long thenValue, TreeNode elseNode, ArrowType type) {
+   TreeNode condition = makeLongLessThanCond(arg, value);
+   TreeNode thenNode = TreeBuilder.makeLiteral(thenValue);
+   return TreeBuilder.makeIf(condition, thenNode, elseNode, type);
+ }
+
+ /**
+  * Projects a ten-level nested if/else that maps an int64 value to its decade bucket
+  * (x < 10 gives 0, x < 20 gives 1, ..., x < 100 gives 9, otherwise 10) and checks the
+  * result for 16 rows.
+  */
+ @Test
+ public void testIf() throws GandivaException, Exception {
+   /*
+    * when x < 10 then 0
+    * when x < 20 then 1
+    * when x < 30 then 2
+    * when x < 40 then 3
+    * when x < 50 then 4
+    * when x < 60 then 5
+    * when x < 70 then 6
+    * when x < 80 then 7
+    * when x < 90 then 8
+    * when x < 100 then 9
+    * else 10
+    */
+   Field x = Field.nullable("x", int64);
+   TreeNode xNode = TreeBuilder.makeField(x);
+
+   // Start from the final else (literal 10) and wrap it from the innermost test (x < 100)
+   // outwards to the outermost test (x < 10).
+   TreeNode chain = TreeBuilder.makeLiteral(10L);
+   for (long bucket = 9L; bucket >= 0L; bucket--) {
+     long upperBound = (bucket + 1L) * 10L;
+     // if (x < upperBound) then bucket else <chain built so far>
+     chain = ifLongLessThanElse(xNode, upperBound, bucket, chain, int64);
+   }
+
+   ExpressionTree bucketExpr = TreeBuilder.makeExpression(chain, x);
+   Projector eval =
+       Projector.make(new Schema(Lists.newArrayList(x)), Lists.newArrayList(bucketExpr));
+
+   int numRows = 16;
+   long[] inputs = new long[]{9, 15, 21, 32, 43, 54, 65, 76, 87, 98, 109, 200, -10, 60, 77, 80};
+   long[] expected = new long[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 0, 6, 7, 8};
+
+   ArrowBuf bufValidity = buf(new byte[]{(byte) 255, (byte) 255});
+   ArrowBuf xData = longBuf(inputs);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+           Lists.newArrayList(bufValidity, xData));
+
+   BigIntVector buckets = new BigIntVector(EMPTY_SCHEMA_PATH, allocator);
+   buckets.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(buckets);
+   eval.evaluate(batch, output);
+
+   int row = 0;
+   for (long want : expected) {
+     assertFalse(buckets.isNull(row));
+     assertEquals(want, buckets.get(row));
+     row++;
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Evaluates {@code x > 10 AND x < 20} over four int64 rows and compares every output bit
+  * with the expected result; no output row may be null.
+  */
+ @Test
+ public void testAnd() throws GandivaException, Exception {
+   /*
+    * x > 10 AND x < 20
+    */
+   ArrowType longType = new ArrowType.Int(64, true);
+
+   Field x = Field.nullable("x", longType);
+   TreeNode xNode = TreeBuilder.makeField(x);
+   TreeNode conjunction =
+       TreeBuilder.makeAnd(
+           Lists.newArrayList(
+               makeLongGreaterThanCond(xNode, 10),
+               makeLongLessThanCond(xNode, 20)));
+
+   ExpressionTree andExpr =
+       TreeBuilder.makeExpression(conjunction, Field.nullable("res", boolType));
+   Projector eval =
+       Projector.make(new Schema(Lists.newArrayList(x)), Lists.newArrayList(andExpr));
+
+   int numRows = 4;
+   long[] inputs = new long[]{9, 15, 17, 25};
+   boolean[] expected = new boolean[]{false, true, true, false};
+
+   ArrowBuf bufValidity = buf(new byte[]{(byte) 255});
+   ArrowBuf xData = longBuf(inputs);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+           Lists.newArrayList(bufValidity, xData));
+
+   BitVector resultBits = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   resultBits.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(resultBits);
+   eval.evaluate(batch, output);
+
+   int row = 0;
+   for (boolean want : expected) {
+     assertFalse(resultBits.isNull(row));
+     assertEquals(want, resultBits.getObject(row).booleanValue());
+     row++;
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Evaluates {@code x > 10 OR x < 5} over four int64 rows and compares every output bit with
+  * the expected result; no output row may be null.
+  */
+ @Test
+ public void testOr() throws GandivaException, Exception {
+   /*
+    * x > 10 OR x < 5
+    */
+   ArrowType longType = new ArrowType.Int(64, true);
+
+   Field x = Field.nullable("x", longType);
+   TreeNode xNode = TreeBuilder.makeField(x);
+   TreeNode disjunction =
+       TreeBuilder.makeOr(
+           Lists.newArrayList(
+               makeLongGreaterThanCond(xNode, 10),
+               makeLongLessThanCond(xNode, 5)));
+
+   ExpressionTree orExpr =
+       TreeBuilder.makeExpression(disjunction, Field.nullable("res", boolType));
+   Projector eval =
+       Projector.make(new Schema(Lists.newArrayList(x)), Lists.newArrayList(orExpr));
+
+   int numRows = 4;
+   long[] inputs = new long[]{4, 9, 15, 17};
+   boolean[] expected = new boolean[]{true, false, true, true};
+
+   ArrowBuf bufValidity = buf(new byte[]{(byte) 255});
+   ArrowBuf xData = longBuf(inputs);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+           Lists.newArrayList(bufValidity, xData));
+
+   BitVector resultBits = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   resultBits.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(resultBits);
+   eval.evaluate(batch, output);
+
+   int row = 0;
+   for (boolean want : expected) {
+     assertFalse(resultBits.isNull(row));
+     assertEquals(want, resultBits.getObject(row).booleanValue());
+     row++;
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Projects {@code if (x < 10) then 1 else null} over two int64 rows: the first row (5) must
+  * produce 1 and the second row (32) must produce null.
+  */
+ @Test
+ public void testNull() throws GandivaException, Exception {
+   /*
+    * when x < 10 then 1
+    * else null
+    */
+   ArrowType longType = new ArrowType.Int(64, true);
+
+   Field x = Field.nullable("x", longType);
+   TreeNode xNode = TreeBuilder.makeField(x);
+
+   // if (x < 10) then 1 else null
+   TreeNode nullBranch = TreeBuilder.makeNull(longType);
+   TreeNode oneOrNull = ifLongLessThanElse(xNode, 10L, 1L, nullBranch, longType);
+
+   ExpressionTree nullableExpr = TreeBuilder.makeExpression(oneOrNull, x);
+   Projector eval =
+       Projector.make(new Schema(Lists.newArrayList(x)), Lists.newArrayList(nullableExpr));
+
+   int numRows = 2;
+   long[] inputs = new long[]{5, 32};
+   long[] expected = new long[]{1, 0};
+
+   ArrowBuf bufValidity = buf(new byte[]{(byte) 255});
+   ArrowBuf xData = longBuf(inputs);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+           Lists.newArrayList(bufValidity, xData));
+
+   BigIntVector results = new BigIntVector(EMPTY_SCHEMA_PATH, allocator);
+   results.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(results);
+   eval.evaluate(batch, output);
+
+   // row 0 took the "then" branch and holds 1
+   assertFalse(results.isNull(0));
+   assertEquals(expected[0], results.get(0));
+
+   // row 1 took the "else" branch and is null
+   assertTrue(results.isNull(1));
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Projects a typed null literal of type time64 (microseconds) and expects both output rows
+  * to be null. The input column is declared as 64-bit time, so its data buffer is built from
+  * 8-byte values.
+  */
+ @Test
+ public void testTimeNull() throws GandivaException, Exception {
+
+   ArrowType time64 = new ArrowType.Time(TimeUnit.MICROSECOND, 64);
+
+   Field x = Field.nullable("x", time64);
+   TreeNode xNode = TreeBuilder.makeNull(time64);
+
+   ExpressionTree expr = TreeBuilder.makeExpression(xNode, x);
+   Schema schema = new Schema(Lists.newArrayList(x));
+   Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+   int numRows = 2;
+   byte[] validity = new byte[]{(byte) 255};
+   // long values so the data buffer matches the 64-bit width of the time64 column
+   long[] xValues = new long[]{5, 32};
+
+   ArrowBuf bufValidity = buf(validity);
+   ArrowBuf xData = longBuf(xValues);
+
+   ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows, Lists.newArrayList(fieldNode), Lists.newArrayList(bufValidity, xData));
+
+   BigIntVector bigIntVector = new BigIntVector(EMPTY_SCHEMA_PATH, allocator);
+   bigIntVector.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(bigIntVector);
+   eval.evaluate(batch, output);
+
+   assertTrue(bigIntVector.isNull(0));
+   assertTrue(bigIntVector.isNull(1));
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Projects {@code if isnotnull(x) then x else y} over two time32 (millisecond) columns.
+  * Row 0 of x is valid, so the output is x[0] = 5; row 1 of x is null while y is valid, so
+  * the output is y[1] = 2. Both output rows must be non-null.
+  */
+ @Test
+ public void testTimeEquals() throws GandivaException, Exception {
+   /*
+    * when isnotnull(x) then x
+    * else y
+    */
+   ArrowType time32 = new ArrowType.Time(TimeUnit.MILLISECOND, 32);
+
+   Field x = Field.nullable("x", time32);
+   TreeNode xNode = TreeBuilder.makeField(x);
+
+   Field y = Field.nullable("y", time32);
+   TreeNode yNode = TreeBuilder.makeField(y);
+
+   // if isnotnull(x) then x else y
+   TreeNode condition = TreeBuilder.makeFunction("isnotnull", Lists.newArrayList(xNode),
+       boolType);
+   TreeNode ifCoalesce = TreeBuilder.makeIf(condition, xNode, yNode, time32);
+
+   ExpressionTree expr = TreeBuilder.makeExpression(ifCoalesce, x);
+   Schema schema = new Schema(Lists.newArrayList(x, y));
+   Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+   int numRows = 2;
+   // x: only row 0 is valid; y: rows 0 and 1 are valid
+   byte[] validity = new byte[]{(byte) 1};
+   byte[] yValidity = new byte[]{(byte) 3};
+   int[] xValues = new int[]{5, 1};
+   int[] yValues = new int[]{10, 2};
+   int[] expected = new int[]{5, 2};
+
+   ArrowBuf bufValidity = buf(validity);
+   ArrowBuf xData = intBuf(xValues);
+
+   ArrowBuf yBufValidity = buf(yValidity);
+   ArrowBuf yData = intBuf(yValues);
+
+   // one field node per schema column (x, y)
+   ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch = new ArrowRecordBatch(
+       numRows,
+       Lists.newArrayList(fieldNode, fieldNode),
+       Lists.newArrayList(bufValidity, xData, yBufValidity, yData));
+
+   IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+   intVector.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(intVector);
+   eval.evaluate(batch, output);
+
+   // output should be 5 and 2
+   assertFalse(intVector.isNull(0));
+   assertEquals(expected[0], intVector.get(0));
+   assertFalse(intVector.isNull(1));
+   assertEquals(expected[1], intVector.get(1));
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Evaluates {@code isnull(x)} over 16 float64 rows whose validity bitmap {255, 0} marks only
+  * the first 8 rows as valid: the result must be false for rows 0-7 and true for rows 8-15.
+  */
+ @Test
+ public void testIsNull() throws GandivaException, Exception {
+   Field x = Field.nullable("x", float64);
+
+   TreeNode isNullNode =
+       TreeBuilder.makeFunction(
+           "isnull", Lists.newArrayList(TreeBuilder.makeField(x)), boolType);
+   ExpressionTree isNullExpr =
+       TreeBuilder.makeExpression(isNullNode, Field.nullable("result", boolType));
+   Projector eval =
+       Projector.make(new Schema(Lists.newArrayList(x)), Lists.newArrayList(isNullExpr));
+
+   int numRows = 16;
+   double[] inputs =
+       new double[]{
+         1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0
+       };
+
+   ArrowBuf bufValidity = buf(new byte[]{(byte) 255, 0});
+   ArrowBuf xData = doubleBuf(inputs);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+           Lists.newArrayList(bufValidity, xData));
+
+   BitVector nullFlags = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   nullFlags.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(nullFlags);
+   eval.evaluate(batch, output);
+
+   int row = 0;
+   // valid inputs: isnull is false
+   while (row < 8) {
+     assertFalse(nullFlags.getObject(row).booleanValue());
+     row++;
+   }
+   // null inputs: isnull is true
+   while (row < numRows) {
+     assertTrue(nullFlags.getObject(row).booleanValue());
+     row++;
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Evaluates {@code equal(c1, c2)} over 16 int32 rows. Rows 0-3 hold equal values, rows 4-7
+  * hold different values, and rows 8-15 are marked null by the validity bitmap {255, 0};
+  * the result must be true, false and null respectively.
+  */
+ @Test
+ public void testEquals() throws GandivaException, Exception {
+   Field c1 = Field.nullable("c1", int32);
+   Field c2 = Field.nullable("c2", int32);
+
+   TreeNode equalNode =
+       TreeBuilder.makeFunction(
+           "equal",
+           Lists.newArrayList(TreeBuilder.makeField(c1), TreeBuilder.makeField(c2)),
+           boolType);
+   ExpressionTree equalExpr =
+       TreeBuilder.makeExpression(equalNode, Field.nullable("result", boolType));
+   Projector eval =
+       Projector.make(new Schema(Lists.newArrayList(c1, c2)), Lists.newArrayList(equalExpr));
+
+   int numRows = 16;
+   byte[] validity = new byte[]{(byte) 255, 0};
+   int[] leftValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+   int[] rightValues = new int[]{1, 2, 3, 4, 8, 7, 6, 5, 16, 15, 14, 13, 12, 11, 10, 9};
+
+   ArrowBuf c1Validity = buf(validity);
+   ArrowBuf c1Data = intBuf(leftValues);
+   ArrowBuf c2Validity = buf(validity);
+   ArrowBuf c2Data = intBuf(rightValues);
+
+   ArrowFieldNode node = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(node, node),
+           Lists.newArrayList(c1Validity, c1Data, c2Validity, c2Data));
+
+   BitVector equalBits = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   equalBits.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(equalBits);
+   eval.evaluate(batch, output);
+
+   int row = 0;
+   // equal values
+   while (row < 4) {
+     assertTrue(equalBits.getObject(row).booleanValue());
+     row++;
+   }
+   // different values
+   while (row < 8) {
+     assertFalse(equalBits.getObject(row).booleanValue());
+     row++;
+   }
+   // null inputs
+   while (row < 16) {
+     assertTrue(equalBits.isNull(row));
+     row++;
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Evaluates {@code c1 IN (1, 2, 3, 4, 5, 15, 16)} over 16 int32 rows. Only the first 8 rows
+  * are marked valid by the validity bitmap {255, 0}. The test expects true for rows 0-4 and
+  * false for every other row, including rows 14 and 15 whose values (15, 16) are in the set
+  * but whose validity bit is cleared.
+  */
+ @Test
+ public void testInExpr() throws GandivaException, Exception {
+   Field c1 = Field.nullable("c1", int32);
+
+   TreeNode inExpr =
+       TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4, 5, 15, 16));
+   ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
+   Schema schema = new Schema(Lists.newArrayList(c1));
+   Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+   int numRows = 16;
+   byte[] validity = new byte[]{(byte) 255, 0};
+   int[] c1Values = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+
+   ArrowBuf c1Validity = buf(validity);
+   ArrowBuf c1Data = intBuf(c1Values);
+
+   // the schema has a single column, so the batch carries one field node and its two buffers
+   ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(fieldNode),
+           Lists.newArrayList(c1Validity, c1Data));
+
+   BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   bitVector.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(bitVector);
+   eval.evaluate(batch, output);
+
+   for (int i = 0; i < 5; i++) {
+     assertTrue(bitVector.getObject(i).booleanValue());
+   }
+   for (int i = 5; i < 16; i++) {
+     assertFalse(bitVector.getObject(i).booleanValue());
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Evaluates a decimal(26, 5) IN-expression whose set holds 1, 2, 3, 4, -0.0, Long.MAX_VALUE
+  * and Long.MIN_VALUE over 16 rows. Only the first 8 rows are marked valid by the validity
+  * bitmap {255, 0}; the test expects true for rows 0-4 and false for all remaining rows.
+  */
+ @Test
+ public void testInExprDecimal() throws GandivaException, Exception {
+   Integer precision = 26;
+   Integer scale = 5;
+   ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128);
+   Field c1 = Field.nullable("c1", decimal);
+
+   // values to search for
+   Set<BigDecimal> lookup = decimalSet(new String[]{"1", "2", "3", "4"}, scale);
+   lookup.add(new BigDecimal(-0.0));
+   lookup.add(new BigDecimal(Long.MAX_VALUE));
+   lookup.add(new BigDecimal(Long.MIN_VALUE));
+
+   TreeNode inNode =
+       TreeBuilder.makeInExpressionDecimal(TreeBuilder.makeField(c1), lookup, precision, scale);
+   ExpressionTree inTree =
+       TreeBuilder.makeExpression(inNode, Field.nullable("result", boolType));
+   Projector eval =
+       Projector.make(new Schema(Lists.newArrayList(c1)), Lists.newArrayList(inTree));
+
+   int numRows = 16;
+   String[] c1Values =
+       new String[]{
+         "1", "2", "3", "4", "-0.0", "6", "7", "8", "9", "10", "11", "12", "13", "14",
+         String.valueOf(Long.MAX_VALUE),
+         String.valueOf(Long.MIN_VALUE)
+       };
+
+   DecimalVector c1Data = decimalVector(c1Values, precision, scale);
+   ArrowBuf c1Validity = buf(new byte[]{(byte) 255, 0});
+
+   // buffers: the hand-built validity bitmap, the vector's data buffer, and the vector's own
+   // validity buffer (kept in the batch exactly as before)
+   ArrowFieldNode node = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(node, node),
+           Lists.newArrayList(c1Validity, c1Data.getDataBuffer(), c1Data.getValidityBuffer()));
+
+   BitVector found = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   found.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(found);
+   eval.evaluate(batch, output);
+
+   int row = 0;
+   while (row < 5) {
+     assertTrue(found.getObject(row).booleanValue());
+     row++;
+   }
+   while (row < 16) {
+     assertFalse(found.getObject(row).booleanValue());
+     row++;
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ /**
+  * Projects a double IN expression whose set includes signed zero, NaN and both infinities,
+  * then checks which rows of a 16-row batch are reported as members.
+  */
+ @Test
+ public void testInExprDouble() throws GandivaException, Exception {
+   Field c1 = Field.nullable("c1", float64);
+
+   TreeNode inNode =
+       TreeBuilder.makeInExpressionDouble(
+           TreeBuilder.makeField(c1),
+           Sets.newHashSet(1.0, -0.0, 3.0, 4.0, Double.NaN,
+               Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY));
+   ExpressionTree tree = TreeBuilder.makeExpression(inNode, Field.nullable("result", boolType));
+   Schema schema = new Schema(Lists.newArrayList(c1));
+   Projector projector = Projector.make(schema, Lists.newArrayList(tree));
+
+   // Sample batch to probe; the validity bytes {0xFF, 0x00} mark only rows 0-7 as valid.
+   int rowCount = 16;
+   byte[] validityBits = new byte[]{(byte) 255, 0};
+   double[] inputs =
+       new double[]{1, -0.0, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.NaN,
+           6, 7, 8, 9, 10, 11, 12, 13, 14, 4, 3};
+
+   ArrowBuf c1Validity = buf(validityBits);
+   ArrowBuf c1Data = doubleBuf(inputs);
+   ArrowBuf c2Validity = buf(validityBits);
+
+   ArrowFieldNode node = new ArrowFieldNode(rowCount, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           rowCount,
+           Lists.newArrayList(node, node),
+           Lists.newArrayList(c1Validity, c1Data, c2Validity));
+
+   BitVector result = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   result.allocateNew(rowCount);
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(result);
+
+   projector.evaluate(batch, output);
+
+   // Rows 0-3 must match the expression; row 4 (the NaN input) and all later rows must not.
+   for (int row = 0; row < 16; row++) {
+     boolean matched = result.getObject(row).booleanValue();
+     if (row < 4) {
+       assertTrue(matched);
+     } else {
+       assertFalse(matched);
+     }
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   projector.close();
+ }
+
+ /**
+  * Projects {@code substr(c1, 1, 3) IN ("one", "two", "thr", "fou")} over a 16-row batch.
+  *
+  * <p>Rows 0-3 ("one", "two", "three", "four") are expected to match. Every other row is
+  * expected not to match, including row 4 ("five"), which the previous version of this test
+  * skipped because the negative loop started at index 5.
+  */
+ @Test
+ public void testInExprStrings() throws GandivaException, Exception {
+   Field c1 = Field.nullable("c1", new ArrowType.Utf8());
+
+   TreeNode l1 = TreeBuilder.makeLiteral(1L);
+   TreeNode l2 = TreeBuilder.makeLiteral(3L);
+   List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2);
+   TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8());
+   TreeNode inExpr =
+       TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou"));
+   ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
+   Schema schema = new Schema(Lists.newArrayList(c1));
+   Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+   int numRows = 16;
+   // Validity bitmap bytes {0xFF, 0x00}: bits are set for rows 0-7 only.
+   byte[] validity = new byte[]{(byte) 255, 0};
+   String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven",
+       "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
+       "sixteen"};
+
+   ArrowBuf c1Validity = buf(validity);
+   List<ArrowBuf> dataBufsX = stringBufs(c1Values);
+   ArrowBuf c2Validity = buf(validity);
+
+   ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(fieldNode, fieldNode),
+           Lists.newArrayList(c1Validity, dataBufsX.get(0), dataBufsX.get(1), c2Validity));
+
+   BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator);
+   bitVector.allocateNew(numRows);
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   output.add(bitVector);
+   eval.evaluate(batch, output);
+
+   for (int i = 0; i < 4; i++) {
+     assertTrue(bitVector.getObject(i).booleanValue());
+   }
+   // Start at 4 (not 5) so that no row is left unchecked: "five" -> "fiv" is not in the set.
+   for (int i = 4; i < numRows; i++) {
+     assertFalse(bitVector.getObject(i).booleanValue());
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+   eval.close();
+ }
+
+ @Test
+ public void testSmallOutputVectors() throws GandivaException, Exception {
+ Field a = Field.nullable("a", int32);
+ Field b = Field.nullable("b", int32);
+ List<Field> args = Lists.newArrayList(a, b);
+
+ Field retType = Field.nullable("c", int32);
+ ExpressionTree root = TreeBuilder.makeExpression("add", args, retType);
+
+ List<ExpressionTree> exprs = Lists.newArrayList(root);
+
+ Schema schema = new Schema(args);
+ Projector eval = Projector.make(schema, exprs);
+
+ int numRows = 16;
+ byte[] validity = new byte[]{(byte) 255, 0};
+ // second half is "undefined"
+ int[] aValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+ int[] bValues = new int[]{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
+
+ ArrowBuf aValidity = buf(validity);
+ ArrowBuf aData = intBuf(aValues);
+ ArrowBuf bValidity = buf(validity);
+ ArrowBuf b2Validity = buf(validity);
+ ArrowBuf bData = intBuf(bValues);
+ ArrowRecordBatch batch =
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(new ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8)),
+ Lists.newArrayList(aValidity, aData, bValidity, bData, b2Validity));
+
+ IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+
+ List<ValueVector> output = new ArrayList<ValueVector>();
+ output.add(intVector);
+ try {
+ eval.evaluate(batch, output);
+ } catch (Throwable t) {
+ intVector.allocateNew(numRows);
+ eval.evaluate(batch, output);
+ }
+
+ for (int i = 0; i < 8; i++) {
+ assertFalse(intVector.isNull(i));
+ assertEquals(17, intVector.get(i));
+ }
+ for (int i = 8; i < 16; i++) {
+ assertTrue(intVector.isNull(i));
+ }
+
+ releaseRecordBatch(batch);
+ releaseValueVectors(output);
+ eval.close();
+ }
+
+ /**
+  * Applies the five extract functions (year, month, day, hour, minute) to a date64 column
+  * and to a timestamp column holding the same instants, and compares each output vector
+  * with the expected component values.
+  */
+ @Test
+ public void testDateTime() throws GandivaException, Exception {
+   Field dateField = Field.nullable("date", new ArrowType.Date(DateUnit.MILLISECOND));
+   Field tsField =
+       Field.nullable("timestamp", new ArrowType.Timestamp(TimeUnit.MILLISECOND, "TZ"));
+   Field resultField = Field.nullable("result", int64);
+
+   // One expression per (input column, extract function); date expressions come first.
+   String[] extractFns =
+       new String[]{"extractYear", "extractMonth", "extractDay", "extractHour", "extractMinute"};
+   List<ExpressionTree> exprs = new ArrayList<>();
+   for (Field input : new Field[]{dateField, tsField}) {
+     List<TreeNode> fnArgs = Lists.newArrayList(TreeBuilder.makeField(input));
+     for (String fn : extractFns) {
+       exprs.add(
+           TreeBuilder.makeExpression(TreeBuilder.makeFunction(fn, fnArgs, int64), resultField));
+     }
+   }
+
+   Schema schema = new Schema(Lists.newArrayList(dateField, tsField));
+   Projector projector = Projector.make(schema, exprs);
+
+   int rowCount = 8;
+   byte[] validityBits = new byte[]{(byte) 255};
+   String[] instants =
+       new String[]{
+           "2007-01-01T01:00:00.00Z",
+           "2007-03-05T03:40:00.00Z",
+           "2008-05-31T13:55:00.00Z",
+           "2000-06-30T23:20:00.00Z",
+           "2000-07-10T20:30:00.00Z",
+           "2000-08-20T00:14:00.00Z",
+           "2000-09-30T02:29:00.00Z",
+           "2000-10-31T05:33:00.00Z"
+       };
+
+   // Expected values, indexed in the same order as extractFns.
+   long[][] expectedByFn =
+       new long[][]{
+           {2007, 2007, 2008, 2000, 2000, 2000, 2000, 2000},
+           {1, 3, 5, 6, 7, 8, 9, 10},
+           {1, 5, 31, 30, 10, 20, 30, 31},
+           {1, 3, 13, 23, 20, 0, 2, 5},
+           {0, 40, 55, 20, 30, 14, 29, 33}
+       };
+
+   ArrowBuf dateValidity = buf(validityBits);
+   ArrowBuf dateMillis = stringToMillis(instants);
+   ArrowBuf tsValidity = buf(validityBits);
+   ArrowBuf tsMillis = stringToMillis(instants);
+
+   ArrowFieldNode node = new ArrowFieldNode(rowCount, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           rowCount,
+           Lists.newArrayList(node, node),
+           Lists.newArrayList(dateValidity, dateMillis, tsValidity, tsMillis));
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   for (int i = 0; i < exprs.size(); i++) {
+     BigIntVector column = new BigIntVector(EMPTY_SCHEMA_PATH, allocator);
+     column.allocateNew(rowCount);
+     output.add(column);
+   }
+   projector.evaluate(batch, output);
+   projector.close();
+
+   for (int col = 0; col < output.size(); col++) {
+     // Outputs 0-4 come from the date column and 5-9 from the timestamp column.
+     long[] expected = expectedByFn[col % 5];
+     BigIntVector column = (BigIntVector) output.get(col);
+
+     for (int row = 0; row < rowCount; row++) {
+       assertFalse(column.isNull(row));
+       assertEquals(expected[row], column.get(row));
+     }
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Applies date_trunc_Year and date_trunc_Month to a date64 column and compares the
+  * results, as epoch milliseconds, with the expected truncated instants.
+  */
+ @Test
+ public void testDateTrunc() throws Exception {
+   ArrowType date64 = new ArrowType.Date(DateUnit.MILLISECOND);
+   Field dateField = Field.nullable("date", date64);
+   Field resultField = Field.nullable("result", date64);
+
+   List<TreeNode> dateArgs = Lists.newArrayList(TreeBuilder.makeField(dateField));
+   List<ExpressionTree> exprs = new ArrayList<>();
+   for (String fn : new String[]{"date_trunc_Year", "date_trunc_Month"}) {
+     exprs.add(
+         TreeBuilder.makeExpression(TreeBuilder.makeFunction(fn, dateArgs, date64), resultField));
+   }
+
+   Projector projector = Projector.make(new Schema(Lists.newArrayList(dateField)), exprs);
+
+   int rowCount = 4;
+   String[] instants = new String[]{
+       "2007-01-01T01:00:00.00Z",
+       "2007-03-05T03:40:00.00Z",
+       "2008-05-31T13:55:00.00Z",
+       "2000-06-30T23:20:00.00Z",
+   };
+
+   // Expected outputs, in the same order as the expressions: year first, then month.
+   String[][] expectedByExpr = new String[][]{
+       {
+           "2007-01-01T00:00:00.00Z",
+           "2007-01-01T00:00:00.00Z",
+           "2008-01-01T00:00:00.00Z",
+           "2000-01-01T00:00:00.00Z",
+       },
+       {
+           "2007-01-01T00:00:00.00Z",
+           "2007-03-01T00:00:00.00Z",
+           "2008-05-01T00:00:00.00Z",
+           "2000-06-01T00:00:00.00Z",
+       }
+   };
+
+   ArrowBuf validityBuf = buf(new byte[]{(byte) 255});
+   ArrowBuf millisBuf = stringToMillis(instants);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           rowCount,
+           Lists.newArrayList(new ArrowFieldNode(rowCount, 0)),
+           Lists.newArrayList(validityBuf, millisBuf));
+
+   List<ValueVector> output = new ArrayList<ValueVector>();
+   for (int i = 0; i < exprs.size(); i++) {
+     BigIntVector column = new BigIntVector(EMPTY_SCHEMA_PATH, allocator);
+     column.allocateNew(rowCount);
+     output.add(column);
+   }
+   projector.evaluate(batch, output);
+   projector.close();
+
+   for (int col = 0; col < output.size(); col++) {
+     String[] expected = expectedByExpr[col];
+     BigIntVector column = (BigIntVector) output.get(col);
+
+     for (int row = 0; row < rowCount; row++) {
+       assertFalse(column.isNull(row));
+       assertEquals(Instant.parse(expected[row]).toEpochMilli(), column.get(row));
+     }
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Verifies that building a projector for a function name that is not expected to exist
+  * ("xxx_yyy") fails with a {@link GandivaException}.
+  */
+ @Test
+ public void testUnknownFunction() {
+   Field c1 = Field.nullable("c1", int8);
+   Field c2 = Field.nullable("c2", int8);
+
+   TreeNode unknownFn =
+       TreeBuilder.makeFunction(
+           "xxx_yyy",
+           Lists.newArrayList(TreeBuilder.makeField(c1), TreeBuilder.makeField(c2)),
+           int8);
+   ExpressionTree expr = TreeBuilder.makeExpression(unknownFn, Field.nullable("result", int8));
+   Schema schema = new Schema(Lists.newArrayList(c1, c2));
+
+   boolean rejected;
+   try {
+     Projector.make(schema, Lists.newArrayList(expr));
+     rejected = false;
+   } catch (GandivaException ge) {
+     rejected = true;
+   }
+
+   assertTrue(rejected);
+ }
+
+ /**
+  * Casts timestamps to strings with castVARCHAR(timestamp, outLength) and checks that each
+  * result equals the expected text for the requested length.
+  */
+ @Test
+ public void testCastTimestampToString() throws Exception {
+   Field tsField =
+       Field.nullable("timestamp", new ArrowType.Timestamp(TimeUnit.MILLISECOND, "TZ"));
+   Field lenField = Field.nullable("outLength", int64);
+   Field resultField = Field.nullable("result", new ArrowType.Utf8());
+
+   TreeNode castNode =
+       TreeBuilder.makeFunction(
+           "castVARCHAR",
+           Lists.newArrayList(TreeBuilder.makeField(tsField), TreeBuilder.makeField(lenField)),
+           new ArrowType.Utf8());
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(TreeBuilder.makeExpression(castNode, resultField));
+
+   Projector projector =
+       Projector.make(new Schema(Lists.newArrayList(tsField, lenField)), exprs);
+
+   int rowCount = 5;
+   byte[] validityBits = new byte[] {(byte) 255};
+   String[] instants =
+       new String[] {
+           "0007-01-01T01:00:00Z",
+           "2007-03-05T03:40:00Z",
+           "2008-05-31T13:55:00Z",
+           "2000-06-30T23:20:00Z",
+           "2000-07-10T20:30:00Z",
+       };
+   // Requested output length per row.
+   long[] requestedLengths = new long[] {23L, 24L, 22L, 0L, 4L};
+   String[] expected =
+       new String[] {
+           "0007-01-01 01:00:00.000",
+           "2007-03-05 03:40:00.000",
+           "2008-05-31 13:55:00.00",
+           "",
+           "2000",
+       };
+
+   ArrowBuf tsValidity = buf(validityBits);
+   ArrowBuf tsMillis = stringToMillis(instants);
+   ArrowBuf lenValidity = buf(validityBits);
+   ArrowBuf lenData = longBuf(requestedLengths);
+
+   ArrowFieldNode node = new ArrowFieldNode(rowCount, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           rowCount,
+           Lists.newArrayList(node, node),
+           Lists.newArrayList(tsValidity, tsMillis, lenValidity, lenData));
+
+   List<ValueVector> output = new ArrayList<>();
+   for (int i = 0; i < exprs.size(); i++) {
+     VarCharVector strings = new VarCharVector(EMPTY_SCHEMA_PATH, allocator);
+     strings.allocateNew(rowCount * 23, rowCount);
+     output.add(strings);
+   }
+   projector.evaluate(batch, output);
+   projector.close();
+
+   for (ValueVector vector : output) {
+     VarCharVector strings = (VarCharVector) vector;
+     for (int row = 0; row < rowCount; row++) {
+       assertFalse(strings.isNull(row));
+       assertEquals(expected[row], new String(strings.get(row)));
+     }
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Casts day-time intervals to BIGINT with castBIGINT and checks each result against
+  * {@code days * 86400000 + millis}.
+  */
+ @Test
+ public void testCastDayIntervalToBigInt() throws Exception {
+   Field dayIntervalField =
+       Field.nullable("dayInterval", new ArrowType.Interval(IntervalUnit.DAY_TIME));
+   Field resultField = Field.nullable("result", int64);
+
+   TreeNode castNode =
+       TreeBuilder.makeFunction(
+           "castBIGINT", Lists.newArrayList(TreeBuilder.makeField(dayIntervalField)), int64);
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(TreeBuilder.makeExpression(castNode, resultField));
+
+   Projector projector =
+       Projector.make(new Schema(Lists.newArrayList(dayIntervalField)), exprs);
+
+   int rowCount = 5;
+   // Each entry is "days millis".
+   String[] intervals =
+       new String[]{
+           "1 0",
+           "2 0",
+           "1 1",
+           "10 5000",
+           "11 86400001",
+       };
+   Long[] expected =
+       new Long[]{
+           86400000L,
+           2 * 86400000L,
+           86400000L + 1L,
+           10 * 86400000L + 5000L,
+           11 * 86400000L + 86400001L
+       };
+
+   ArrowBuf validityBuf = buf(new byte[]{(byte) 255});
+   ArrowBuf intervalBuf = stringToDayInterval(intervals);
+
+   ArrowFieldNode node = new ArrowFieldNode(rowCount, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           rowCount,
+           Lists.newArrayList(node, node),
+           Lists.newArrayList(validityBuf, intervalBuf));
+
+   List<ValueVector> output = new ArrayList<>();
+   for (int i = 0; i < exprs.size(); i++) {
+     BigIntVector column = new BigIntVector(EMPTY_SCHEMA_PATH, allocator);
+     column.allocateNew(rowCount);
+     output.add(column);
+   }
+   projector.evaluate(batch, output);
+   projector.close();
+
+   for (ValueVector vector : output) {
+     BigIntVector column = (BigIntVector) vector;
+     for (int row = 0; row < rowCount; row++) {
+       assertFalse(column.isNull(row));
+       assertEquals(expected[row], Long.valueOf(column.get(row)));
+     }
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Builds a projector using the all-lowercase function name "extractday" and checks that
+  * it yields the day of month for each timestamp.
+  */
+ @Test
+ public void testCaseInsensitiveFunctions() throws Exception {
+   Field tsField =
+       Field.nullable("timestamp", new ArrowType.Timestamp(TimeUnit.MILLISECOND, "TZ"));
+
+   TreeNode extractDayNode =
+       TreeBuilder.makeFunction(
+           "extractday", Lists.newArrayList(TreeBuilder.makeField(tsField)), int64);
+   ExpressionTree tree =
+       TreeBuilder.makeExpression(extractDayNode, Field.nullable("result", int64));
+   Projector projector =
+       Projector.make(new Schema(Lists.newArrayList(tsField)), Lists.newArrayList(tree));
+
+   int rowCount = 5;
+   String[] instants =
+       new String[] {
+           "0007-01-01T01:00:00Z",
+           "2007-03-05T03:40:00Z",
+           "2008-05-31T13:55:00Z",
+           "2000-06-30T23:20:00Z",
+           "2000-07-10T20:30:00Z",
+       };
+   long[] expectedDays = new long[] {1, 5, 31, 30, 10};
+
+   ArrowBuf validityBuf = buf(new byte[] {(byte) 255});
+   ArrowBuf millisBuf = stringToMillis(instants);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           rowCount,
+           Lists.newArrayList(new ArrowFieldNode(rowCount, 0)),
+           Lists.newArrayList(validityBuf, millisBuf));
+
+   BigIntVector days = new BigIntVector(EMPTY_SCHEMA_PATH, allocator);
+   days.allocateNew(rowCount);
+   List<ValueVector> output = new ArrayList<>();
+   output.add(days);
+
+   projector.evaluate(batch, output);
+   projector.close();
+
+   for (ValueVector vector : output) {
+     BigIntVector column = (BigIntVector) vector;
+     for (int row = 0; row < rowCount; row++) {
+       assertFalse(column.isNull(row));
+       assertEquals(expectedDays[row], column.get(row));
+     }
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Casts UTF-8 strings to INT with castINT and checks every parsed value.
+  *
+  * <p>Values are compared with {@code assertEquals} so that a failure reports the expected
+  * and actual numbers instead of a bare "assertion failed".
+  */
+ @Test
+ public void testCastInt() throws Exception {
+   Field inField = Field.nullable("input", new ArrowType.Utf8());
+   TreeNode inNode = TreeBuilder.makeField(inField);
+   TreeNode castINTFn = TreeBuilder.makeFunction("castINT", Lists.newArrayList(inNode),
+       int32);
+   Field resultField = Field.nullable("result", int32);
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(
+           TreeBuilder.makeExpression(castINTFn, resultField));
+   Schema schema = new Schema(Lists.newArrayList(inField));
+   Projector eval = Projector.make(schema, exprs);
+   int numRows = 5;
+   byte[] validity = new byte[] {(byte) 255};
+   String[] values =
+       new String[] {
+           "0", "123", "-123", "-1", "1"
+       };
+   int[] expValues =
+       new int[] {
+           0, 123, -123, -1, 1
+       };
+   ArrowBuf bufValidity = buf(validity);
+   // stringBufs returns a list; its elements 0 and 1 are passed as the column's remaining buffers.
+   List<ArrowBuf> bufData = stringBufs(values);
+   ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(fieldNode),
+           Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1)));
+   List<ValueVector> output = new ArrayList<>();
+   for (int i = 0; i < exprs.size(); i++) {
+     IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+     intVector.allocateNew(numRows);
+     output.add(intVector);
+   }
+   eval.evaluate(batch, output);
+   eval.close();
+   for (ValueVector valueVector : output) {
+     IntVector intVector = (IntVector) valueVector;
+     for (int j = 0; j < numRows; j++) {
+       assertFalse(intVector.isNull(j));
+       assertEquals(expValues[j], intVector.get(j));
+     }
+   }
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Expects castINT on the non-numeric string "abc" to surface a {@link GandivaException};
+  * the projector, batch and output vectors are released in a finally block either way.
+  */
+ @Test(expected = GandivaException.class)
+ public void testCastIntInvalidValue() throws Exception {
+   Field inField = Field.nullable("input", new ArrowType.Utf8());
+   Field resultField = Field.nullable("result", int32);
+
+   TreeNode castNode =
+       TreeBuilder.makeFunction(
+           "castINT", Lists.newArrayList(TreeBuilder.makeField(inField)), int32);
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(TreeBuilder.makeExpression(castNode, resultField));
+   Projector projector = Projector.make(new Schema(Lists.newArrayList(inField)), exprs);
+
+   int rowCount = 1;
+   ArrowBuf validityBuf = buf(new byte[] {(byte) 255});
+   List<ArrowBuf> stringData = stringBufs(new String[] {"abc"});
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           rowCount,
+           Lists.newArrayList(new ArrowFieldNode(rowCount, 0)),
+           Lists.newArrayList(validityBuf, stringData.get(0), stringData.get(1)));
+
+   List<ValueVector> output = new ArrayList<>();
+   for (int i = 0; i < exprs.size(); i++) {
+     IntVector column = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+     column.allocateNew(rowCount);
+     output.add(column);
+   }
+
+   try {
+     projector.evaluate(batch, output);
+   } finally {
+     projector.close();
+     releaseRecordBatch(batch);
+     releaseValueVectors(output);
+   }
+ }
+
+ /**
+  * Casts UTF-8 strings to FLOAT8 with castFLOAT8 and checks every parsed value.
+  *
+  * <p>Values are compared with {@code assertEquals} and a delta of 0.0: the comparison is
+  * still exact, but a failure now reports the expected and actual numbers.
+  */
+ @Test
+ public void testCastFloat() throws Exception {
+   Field inField = Field.nullable("input", new ArrowType.Utf8());
+   TreeNode inNode = TreeBuilder.makeField(inField);
+   TreeNode castFLOAT8Fn = TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode),
+       float64);
+   Field resultField = Field.nullable("result", float64);
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(
+           TreeBuilder.makeExpression(castFLOAT8Fn, resultField));
+   Schema schema = new Schema(Lists.newArrayList(inField));
+   Projector eval = Projector.make(schema, exprs);
+   int numRows = 5;
+   byte[] validity = new byte[] {(byte) 255};
+   String[] values =
+       new String[] {
+           "2.3",
+           "-11.11",
+           "0",
+           "111",
+           "12345.67"
+       };
+   double[] expValues =
+       new double[] {
+           2.3, -11.11, 0, 111, 12345.67
+       };
+   ArrowBuf bufValidity = buf(validity);
+   List<ArrowBuf> bufData = stringBufs(values);
+   ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(fieldNode),
+           Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1)));
+   List<ValueVector> output = new ArrayList<>();
+   for (int i = 0; i < exprs.size(); i++) {
+     Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator);
+     float8Vector.allocateNew(numRows);
+     output.add(float8Vector);
+   }
+   eval.evaluate(batch, output);
+   eval.close();
+   for (ValueVector valueVector : output) {
+     Float8Vector float8Vector = (Float8Vector) valueVector;
+     for (int j = 0; j < numRows; j++) {
+       assertFalse(float8Vector.isNull(j));
+       // Delta 0.0 keeps the comparison exact, as the previous == check was.
+       assertEquals(expValues[j], float8Vector.get(j), 0.0);
+     }
+   }
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Casts binary (varbinary) inputs holding decimal text to FLOAT8 with castFLOAT8 and
+  * checks every parsed value.
+  *
+  * <p>Values are compared with {@code assertEquals} and a delta of 0.0: the comparison is
+  * still exact, but a failure now reports the expected and actual numbers.
+  */
+ @Test
+ public void testCastFloatVarbinary() throws Exception {
+   Field inField = Field.nullable("input", new ArrowType.Binary());
+   TreeNode inNode = TreeBuilder.makeField(inField);
+   TreeNode castFLOAT8Fn = TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode),
+       float64);
+   Field resultField = Field.nullable("result", float64);
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(
+           TreeBuilder.makeExpression(castFLOAT8Fn, resultField));
+   Schema schema = new Schema(Lists.newArrayList(inField));
+   Projector eval = Projector.make(schema, exprs);
+   int numRows = 5;
+   byte[] validity = new byte[] {(byte) 255};
+   String[] values =
+       new String[] {
+           "2.3",
+           "-11.11",
+           "0",
+           "111",
+           "12345.67"
+       };
+   double[] expValues =
+       new double[] {
+           2.3, -11.11, 0, 111, 12345.67
+       };
+   ArrowBuf bufValidity = buf(validity);
+   List<ArrowBuf> bufData = stringBufs(values);
+   ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(fieldNode),
+           Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1)));
+   List<ValueVector> output = new ArrayList<>();
+   for (int i = 0; i < exprs.size(); i++) {
+     Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator);
+     float8Vector.allocateNew(numRows);
+     output.add(float8Vector);
+   }
+   eval.evaluate(batch, output);
+   eval.close();
+   for (ValueVector valueVector : output) {
+     Float8Vector float8Vector = (Float8Vector) valueVector;
+     for (int j = 0; j < numRows; j++) {
+       assertFalse(float8Vector.isNull(j));
+       // Delta 0.0 keeps the comparison exact, as the previous == check was.
+       assertEquals(expValues[j], float8Vector.get(j), 0.0);
+     }
+   }
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Expects castFLOAT8 to surface a {@link GandivaException} when one of the input strings
+  * ("abc") is not numeric; all resources are released in a finally block either way.
+  */
+ @Test(expected = GandivaException.class)
+ public void testCastFloatInvalidValue() throws Exception {
+   Field inField = Field.nullable("input", new ArrowType.Utf8());
+   Field resultField = Field.nullable("result", float64);
+
+   TreeNode castNode =
+       TreeBuilder.makeFunction(
+           "castFLOAT8", Lists.newArrayList(TreeBuilder.makeField(inField)), float64);
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(TreeBuilder.makeExpression(castNode, resultField));
+   Projector projector = Projector.make(new Schema(Lists.newArrayList(inField)), exprs);
+
+   int rowCount = 5;
+   String[] inputs =
+       new String[] {
+           "2.3",
+           "-11.11",
+           "abc",
+           "111",
+           "12345.67"
+       };
+   ArrowBuf validityBuf = buf(new byte[] {(byte) 255});
+   List<ArrowBuf> stringData = stringBufs(inputs);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           rowCount,
+           Lists.newArrayList(new ArrowFieldNode(rowCount, 0)),
+           Lists.newArrayList(validityBuf, stringData.get(0), stringData.get(1)));
+
+   List<ValueVector> output = new ArrayList<>();
+   for (int i = 0; i < exprs.size(); i++) {
+     Float8Vector column = new Float8Vector(EMPTY_SCHEMA_PATH, allocator);
+     column.allocateNew(rowCount);
+     output.add(column);
+   }
+
+   try {
+     projector.evaluate(batch, output);
+   } finally {
+     projector.close();
+     releaseRecordBatch(batch);
+     releaseValueVectors(output);
+   }
+ }
+
+ @Test
+ public void testEvaluateWithUnsetTargetHostCPU() throws Exception {
+ Field a = Field.nullable("a", int32);
+ Field b = Field.nullable("b", int32);
+ List<Field> args = Lists.newArrayList(a, b);
+
+ Field retType = Field.nullable("c", int32);
+ ExpressionTree root = TreeBuilder.makeExpression("add", args, retType);
+
+ List<ExpressionTree> exprs = Lists.newArrayList(root);
+
+ Schema schema = new Schema(args);
+ Projector eval = Projector.make(schema, exprs, new ConfigurationBuilder.ConfigOptions().withTargetCPU(false ));
+
+ int numRows = 16;
+ byte[] validity = new byte[]{(byte) 255, 0};
+ // second half is "undefined"
+ int[] aValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+ int[] bValues = new int[]{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
+
+ ArrowBuf validitya = buf(validity);
+ ArrowBuf valuesa = intBuf(aValues);
+ ArrowBuf validityb = buf(validity);
+ ArrowBuf valuesb = intBuf(bValues);
+ ArrowRecordBatch batch =
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(new ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8)),
+ Lists.newArrayList(validitya, valuesa, validityb, valuesb));
+
+ IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);
+ intVector.allocateNew(numRows);
+
+ List<ValueVector> output = new ArrayList<ValueVector>();
+ output.add(intVector);
+ eval.evaluate(batch, output);
+
+ for (int i = 0; i < 8; i++) {
+ assertFalse(intVector.isNull(i));
+ assertEquals(17, intVector.get(i));
+ }
+ for (int i = 8; i < 16; i++) {
+ assertTrue(intVector.isNull(i));
+ }
+
+ // free buffers
+ releaseRecordBatch(batch);
+ releaseValueVectors(output);
+ eval.close();
+ }
+
+ /**
+  * Casts int32 values to strings with castVARCHAR(input, outLength) and checks each result
+  * against the decimal text of the input, cut to the requested length where that is shorter.
+  */
+ @Test
+ public void testCastVarcharFromInteger() throws Exception {
+   Field inField = Field.nullable("input", int32);
+   Field lenField = Field.nullable("outLength", int64);
+   Field resultField = Field.nullable("result", new ArrowType.Utf8());
+
+   TreeNode castNode =
+       TreeBuilder.makeFunction(
+           "castVARCHAR",
+           Lists.newArrayList(TreeBuilder.makeField(inField), TreeBuilder.makeField(lenField)),
+           new ArrowType.Utf8());
+   List<ExpressionTree> exprs =
+       Lists.newArrayList(TreeBuilder.makeExpression(castNode, resultField));
+
+   Projector projector =
+       Projector.make(new Schema(Lists.newArrayList(inField, lenField)), exprs);
+
+   int rowCount = 5;
+   byte[] validityBits = new byte[] {(byte) 255};
+   int[] inputs =
+       new int[] {
+           2345,
+           2345,
+           2345,
+           2345,
+           -2345,
+       };
+   // Requested output length per row.
+   long[] requestedLengths = new long[] {0L, 4L, 2L, 6L, 5L};
+   String[] expected =
+       new String[] {
+           "",
+           Integer.toString(2345).substring(0, 4),
+           Integer.toString(2345).substring(0, 2),
+           Integer.toString(2345),
+           Integer.toString(-2345)
+       };
+
+   ArrowBuf inValidity = buf(validityBits);
+   ArrowBuf inData = intBuf(inputs);
+   ArrowBuf lenValidity = buf(validityBits);
+   ArrowBuf lenData = longBuf(requestedLengths);
+
+   ArrowFieldNode node = new ArrowFieldNode(rowCount, 0);
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           rowCount,
+           Lists.newArrayList(node, node),
+           Lists.newArrayList(inValidity, inData, lenValidity, lenData));
+
+   List<ValueVector> output = new ArrayList<>();
+   for (int i = 0; i < exprs.size(); i++) {
+     VarCharVector strings = new VarCharVector(EMPTY_SCHEMA_PATH, allocator);
+     strings.allocateNew(rowCount * 5, rowCount);
+     output.add(strings);
+   }
+   projector.evaluate(batch, output);
+   projector.close();
+
+   for (ValueVector vector : output) {
+     VarCharVector strings = (VarCharVector) vector;
+     for (int row = 0; row < rowCount; row++) {
+       assertFalse(strings.isNull(row));
+       assertEquals(expected[row], new String(strings.get(row)));
+     }
+   }
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+
+ /**
+  * Casts float64 values to strings with castVARCHAR(input, outLength) and compares each
+  * result with what {@link Double#toString(double)} produces for the same value.
+  *
+  * <p>NOTE(review): {@code numRows} is 5 while the fixture arrays below hold 15 entries, so
+  * only the first five entries (0.0, -0.0, 1.0, 0.001, 0.0009) are evaluated and asserted.
+  * Confirm whether that is intentional; covering all entries would also need a validity
+  * bitmap of two bytes.
+  */
+ @Test
+ public void testCastVarcharFromFloat() throws Exception {
+ Field inField = Field.nullable("input", float64);
+ Field lenField = Field.nullable("outLength", int64);
+
+ TreeNode inNode = TreeBuilder.makeField(inField);
+ TreeNode lenNode = TreeBuilder.makeField(lenField);
+
+ TreeNode tsToString = TreeBuilder.makeFunction("castVARCHAR", Lists.newArrayList(inNode, lenNode),
+ new ArrowType.Utf8());
+
+ Field resultField = Field.nullable("result", new ArrowType.Utf8());
+ List<ExpressionTree> exprs =
+ Lists.newArrayList(
+ TreeBuilder.makeExpression(tsToString, resultField));
+
+ Schema schema = new Schema(Lists.newArrayList(inField, lenField));
+ Projector eval = Projector.make(schema, exprs);
+
+ // Only the first numRows entries of the arrays below are read by the batch and the checks.
+ int numRows = 5;
+ byte[] validity = new byte[] {(byte) 255};
+ double[] values =
+ new double[] {
+ 0.0,
+ -0.0,
+ 1.0,
+ 0.001,
+ 0.0009,
+ 0.00099893,
+ 999999.9999,
+ 10000000.0,
+ 23943410000000.343434,
+ Double.POSITIVE_INFINITY,
+ Double.NEGATIVE_INFINITY,
+ Double.NaN,
+ 23.45,
+ 23.45,
+ -23.45,
+ };
+ // Requested output length for each input value.
+ long[] lenValues =
+ new long[] {
+ 6L, 6L, 6L, 6L, 10L, 15L, 15L, 15L, 30L,
+ 15L, 15L, 15L, 0L, 6L, 6L
+ };
+
+ /* Java's Double.toString renders real numbers in one of two notations, and Gandiva is
+ * expected to follow the same rules:
+ * - If the magnitude is greater than or equal to 10^7, or less than 10^(-3), the number
+ * is written in scientific notation, e.g.:
+ * - 0.000012 -> 1.2E-5
+ * - 10000002.3 -> 1.00000023E7
+ * - If the magnitude lies inside that interval, the number is written in plain decimal
+ * notation.
+ *
+ * The test checks that the Gandiva cast produces the same text as Java does.
+ * */
+ String[] expValues =
+ new String[] {
+ Double.toString(0.0), // must be cast to -> "0.0"
+ Double.toString(-0.0), // must be cast to -> "-0.0"
+ Double.toString(1.0), // must be cast to -> "1.0"
+ Double.toString(0.001), // must be cast to -> "0.001"
+ Double.toString(0.0009), // must be cast to -> "9.0E-4"
+ Double.toString(0.00099893), // must be cast to -> "9.9893E-4"
+ Double.toString(999999.9999), // must be cast to -> "999999.9999"
+ Double.toString(10000000.0), // must be cast to -> "1.0E7"
+ Double.toString(23943410000000.343434),
+ Double.toString(Double.POSITIVE_INFINITY),
+ Double.toString(Double.NEGATIVE_INFINITY),
+ Double.toString(Double.NaN),
+ "",
+ Double.toString(23.45),
+ Double.toString(-23.45)
+ };
+
+ ArrowBuf bufValidity = buf(validity);
+ ArrowBuf bufData = doubleBuf(values);
+ ArrowBuf lenValidity = buf(validity);
+ ArrowBuf lenData = longBuf(lenValues);
+
+ ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+ ArrowRecordBatch batch =
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(fieldNode, fieldNode),
+ Lists.newArrayList(bufValidity, bufData, lenValidity, lenData));
+
+ List<ValueVector> output = new ArrayList<>();
+ for (int i = 0; i < exprs.size(); i++) {
+ VarCharVector charVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator);
+
+ charVector.allocateNew(numRows * 5, numRows);
+ output.add(charVector);
+ }
+ eval.evaluate(batch, output);
+ // The projector is closed before the assertions; they only read the output vectors.
+ eval.close();
+
+ for (ValueVector valueVector : output) {
+ VarCharVector charVector = (VarCharVector) valueVector;
+
+ for (int j = 0; j < numRows; j++) {
+ assertFalse(charVector.isNull(j));
+ assertEquals(expValues[j], new String(charVector.get(j)));
+ }
+ }
+
+ releaseRecordBatch(batch);
+ releaseValueVectors(output);
+ }
+
+ /**
+  * Checks {@code initcap} on UTF-8 input that contains non-ASCII letters; the exact casing rules
+  * are pinned by the expected values below, and a null input row must stay null.
+  */
+ @Test
+ public void testInitCap() throws Exception {
+   Field x = Field.nullable("x", new ArrowType.Utf8());
+   Field retType = Field.nullable("c", new ArrowType.Utf8());
+
+   TreeNode cond =
+       TreeBuilder.makeFunction(
+           "initcap", Lists.newArrayList(TreeBuilder.makeField(x)), new ArrowType.Utf8());
+   ExpressionTree expr = TreeBuilder.makeExpression(cond, retType);
+   Schema schema = new Schema(Lists.newArrayList(x));
+   Projector eval = Projector.make(schema, Lists.newArrayList(expr));
+
+   int numRows = 5;
+   byte[] validity = new byte[]{(byte) 15, 0}; // 0b00001111: rows 0-3 valid, row 4 null
+   String[] valuesX = new String[]{
+     " øhpqršvñ \n\n",
+     "möbelträger1füße \nmöbelträge'rfüße",
+     "ÂbĆDËFgh\néll",
+     "citroën CaR",
+     "kjk"
+   };
+
+   String[] expected = new String[]{
+     " Øhpqršvñ \n\n",
+     "Möbelträger1füße \nMöbelträge'Rfüße",
+     "Âbćdëfgh\nÉll",
+     "Citroën Car",
+     null
+   };
+
+   ArrowBuf validityX = buf(validity);
+   List<ArrowBuf> dataBufsX = stringBufs(valuesX);
+
+   ArrowRecordBatch batch =
+       new ArrowRecordBatch(
+           numRows,
+           Lists.newArrayList(new ArrowFieldNode(numRows, 0)),
+           Lists.newArrayList(validityX, dataBufsX.get(0), dataBufsX.get(1)));
+
+   // allocate data for output vector.
+   VarCharVector outVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator);
+   outVector.allocateNew(numRows * 100, numRows);
+
+   // evaluate expression
+   List<ValueVector> output = new ArrayList<>();
+   output.add(outVector);
+   eval.evaluate(batch, output);
+   eval.close();
+
+   // match expected output; decode as UTF-8 explicitly because the expected text is non-ASCII.
+   for (int i = 0; i < numRows - 1; i++) {
+     assertFalse("Expect none value equals null", outVector.isNull(i));
+     String actual = new String(outVector.get(i), java.nio.charset.StandardCharsets.UTF_8);
+     assertEquals(expected[i], actual);
+   }
+   assertTrue("Last value must be null", outVector.isNull(numRows - 1));
+
+   releaseRecordBatch(batch);
+   releaseValueVectors(output);
+ }
+}
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/TestJniLoader.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/TestJniLoader.java
new file mode 100644
index 000000000..116f0dd9e
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/TestJniLoader.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+ public class TestJniLoader {
+   /** Ids are cached per option set; removing a configuration makes the next lookup allocate a new id. */
+   @Test
+   public void testDefaultConfiguration() throws Exception {
+     long configId = JniLoader.getConfiguration(ConfigurationBuilder.ConfigOptions.getDefault());
+     Assert.assertEquals(configId, JniLoader.getDefaultConfiguration());
+     Assert.assertEquals(configId, JniLoader.getConfiguration(ConfigurationBuilder.ConfigOptions.getDefault()));
+
+     long configId2 = JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions().withOptimize(false));
+     long configId3 = JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions().withTargetCPU(false));
+     long configId4 = JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions().withOptimize(false)
+         .withTargetCPU(false));
+
+     Assert.assertEquals("configuration ids must be pairwise distinct", 4,
+         new java.util.HashSet<>(java.util.Arrays.asList(configId, configId2, configId3, configId4)).size());
+
+     Assert.assertEquals(configId2, JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions()
+         .withOptimize(false)));
+     Assert.assertEquals(configId3, JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions()
+         .withTargetCPU(false)));
+     Assert.assertEquals(configId4, JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions()
+         .withOptimize(false).withTargetCPU(false)));
+
+     JniLoader.removeConfiguration(new ConfigurationBuilder.ConfigOptions().withOptimize(false));
+     // configids are monotonically updated. after a config is removed, new one is assigned with higher id
+     Assert.assertNotEquals(configId2, JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions()
+         .withOptimize(false)));
+     JniLoader.removeConfiguration(new ConfigurationBuilder.ConfigOptions());
+     Assert.assertNotEquals(configId, JniLoader.getConfiguration(ConfigurationBuilder.ConfigOptions.getDefault()));
+   }
+ }
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/ArrowTypeHelperTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/ArrowTypeHelperTest.java
new file mode 100644
index 000000000..7ddd602bf
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/ArrowTypeHelperTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import static org.junit.Assert.assertEquals;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Test;
+
+ public class ArrowTypeHelperTest {
+
+   /** Asserts that an Arrow integer type maps to the Gandiva type number {@code expected}. */
+   private void testInt(int width, boolean isSigned, int expected) throws GandivaException {
+     testBasic(new ArrowType.Int(width, isSigned), expected);
+   }
+
+   /** Covers every signed and unsigned integer width from 8 to 64 bits. */
+   @Test
+   public void testAllInts() throws GandivaException {
+     testInt(8, false, GandivaTypes.GandivaType.UINT8_VALUE);
+     testInt(8, true, GandivaTypes.GandivaType.INT8_VALUE);
+     testInt(16, false, GandivaTypes.GandivaType.UINT16_VALUE);
+     testInt(16, true, GandivaTypes.GandivaType.INT16_VALUE);
+     testInt(32, false, GandivaTypes.GandivaType.UINT32_VALUE);
+     testInt(32, true, GandivaTypes.GandivaType.INT32_VALUE);
+     testInt(64, false, GandivaTypes.GandivaType.UINT64_VALUE);
+     testInt(64, true, GandivaTypes.GandivaType.INT64_VALUE);
+   }
+
+   /** Asserts that an Arrow floating-point type maps to the Gandiva type number {@code expected}. */
+   private void testFloat(FloatingPointPrecision precision, int expected) throws GandivaException {
+     testBasic(new ArrowType.FloatingPoint(precision), expected);
+   }
+
+   /** Covers the half, single and double floating-point precisions. */
+   @Test
+   public void testAllFloats() throws GandivaException {
+     testFloat(FloatingPointPrecision.HALF, GandivaTypes.GandivaType.HALF_FLOAT_VALUE);
+     testFloat(FloatingPointPrecision.SINGLE, GandivaTypes.GandivaType.FLOAT_VALUE);
+     testFloat(FloatingPointPrecision.DOUBLE, GandivaTypes.GandivaType.DOUBLE_VALUE);
+   }
+
+   /** Asserts that {@code arrowType} converts to the Gandiva type number {@code expected}. */
+   private void testBasic(ArrowType arrowType, int expected) throws GandivaException {
+     GandivaTypes.ExtGandivaType gandivaType = ArrowTypeHelper.arrowTypeToProtobuf(arrowType);
+     assertEquals(expected, gandivaType.getType().getNumber());
+   }
+
+   /** Covers the parameterless bool, binary and utf8 types. */
+   @Test
+   public void testSimpleTypes() throws GandivaException {
+     testBasic(new ArrowType.Bool(), GandivaTypes.GandivaType.BOOL_VALUE);
+     testBasic(new ArrowType.Binary(), GandivaTypes.GandivaType.BINARY_VALUE);
+     testBasic(new ArrowType.Utf8(), GandivaTypes.GandivaType.UTF8_VALUE);
+   }
+
+   /** A converted field must keep its name, nullability and type. */
+   @Test
+   public void testField() throws GandivaException {
+     Field field = Field.nullable("col1", new ArrowType.Bool());
+     GandivaTypes.Field f = ArrowTypeHelper.arrowFieldToProtobuf(field);
+     assertEquals(field.getName(), f.getName());
+     assertEquals(true, f.getNullable());
+     assertEquals(GandivaTypes.GandivaType.BOOL_VALUE, f.getType().getType().getNumber());
+   }
+
+   /** A converted schema must contain every column, in the original order. */
+   @Test
+   public void testSchema() throws GandivaException {
+     List<Field> fields = new ArrayList<Field>();
+     fields.add(Field.nullable("a", new ArrowType.Int(16, false)));
+     fields.add(Field.nullable("b", new ArrowType.Int(32, true)));
+     fields.add(Field.nullable("c", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)));
+
+     GandivaTypes.Schema schema = ArrowTypeHelper.arrowSchemaToProtobuf(new Schema(fields));
+     // Check the count first: iterating an empty column list would pass vacuously.
+     assertEquals(fields.size(), schema.getColumnsList().size());
+     for (int idx = 0; idx < fields.size(); idx++) {
+       assertEquals(fields.get(idx).getName(), schema.getColumnsList().get(idx).getName());
+     }
+   }
+ }
diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/TreeBuilderTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/TreeBuilderTest.java
new file mode 100644
index 000000000..90373cf79
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/TreeBuilderTest.java
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.expression;
+
+import static org.junit.Assert.*;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.junit.Test;
+
+ /**
+  * Checks that the trees, expressions and conditions built by {@link TreeBuilder} serialize to
+  * the expected Gandiva protobuf messages.
+  */
+ public class TreeBuilderTest {
+
+   /** Returns the name of the field that a field-reference node points at. */
+   private static String fieldName(GandivaTypes.TreeNode node) {
+     return node.getFieldNode().getField().getName();
+   }
+
+   /** Asserts that {@code condition} serializes to {@code greater_than(a, b)}. */
+   private static void assertGreaterThanOfAAndB(Condition condition) throws GandivaException {
+     GandivaTypes.TreeNode root = condition.toProtobuf().getRoot();
+     assertTrue(root.hasFnNode());
+     assertEquals("greater_than", root.getFnNode().getFunctionName());
+     assertEquals("a", fieldName(root.getFnNode().getInArgsList().get(0)));
+     assertEquals("b", fieldName(root.getFnNode().getInArgsList().get(1)));
+   }
+
+   /**
+    * Every literal factory must emit the protobuf node kind that matches the Java value.
+    *
+    * <p>Boxed inputs come from {@code valueOf} because the boxed-type constructors have been
+    * deprecated since Java 9; byte comparisons use an explicit UTF-8 encoding rather than the
+    * platform default charset.
+    */
+   @Test
+   public void testMakeLiteral() throws GandivaException {
+     TreeNode n = TreeBuilder.makeLiteral(Boolean.TRUE);
+     GandivaTypes.TreeNode node = n.toProtobuf();
+
+     assertTrue(node.getBooleanNode().getValue());
+
+     n = TreeBuilder.makeLiteral(Integer.valueOf(10));
+     node = n.toProtobuf();
+     assertEquals(10, node.getIntNode().getValue());
+
+     n = TreeBuilder.makeLiteral(Long.valueOf(50));
+     node = n.toProtobuf();
+     assertEquals(50, node.getLongNode().getValue());
+
+     Float f = Float.valueOf(2.5f);
+     n = TreeBuilder.makeLiteral(f);
+     node = n.toProtobuf();
+     assertEquals(f.floatValue(), node.getFloatNode().getValue(), 0.1);
+
+     Double d = Double.valueOf(3.3);
+     n = TreeBuilder.makeLiteral(d);
+     node = n.toProtobuf();
+     assertEquals(d.doubleValue(), node.getDoubleNode().getValue(), 0.1);
+
+     byte[] helloUtf8 = "hello".getBytes(java.nio.charset.StandardCharsets.UTF_8);
+     n = TreeBuilder.makeStringLiteral("hello");
+     node = n.toProtobuf();
+     assertArrayEquals(helloUtf8, node.getStringNode().getValue().toByteArray());
+
+     n = TreeBuilder.makeBinaryLiteral(helloUtf8);
+     node = n.toProtobuf();
+     assertArrayEquals(helloUtf8, node.getBinaryNode().getValue().toByteArray());
+   }
+
+   /**
+    * A null literal must carry the Gandiva type that corresponds to the requested Arrow type.
+    */
+   @Test
+   public void testMakeNull() throws GandivaException {
+     TreeNode n = TreeBuilder.makeNull(new ArrowType.Bool());
+     GandivaTypes.TreeNode node = n.toProtobuf();
+     assertEquals(
+         GandivaTypes.GandivaType.BOOL_VALUE, node.getNullNode().getType().getType().getNumber());
+
+     n = TreeBuilder.makeNull(new ArrowType.Int(32, true));
+     node = n.toProtobuf();
+     assertEquals(
+         GandivaTypes.GandivaType.INT32_VALUE, node.getNullNode().getType().getType().getNumber());
+
+     n = TreeBuilder.makeNull(new ArrowType.Int(64, false));
+     node = n.toProtobuf();
+     assertEquals(
+         GandivaTypes.GandivaType.UINT64_VALUE, node.getNullNode().getType().getType().getNumber());
+
+     n = TreeBuilder.makeNull(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE));
+     node = n.toProtobuf();
+     assertEquals(
+         GandivaTypes.GandivaType.FLOAT_VALUE, node.getNullNode().getType().getType().getNumber());
+   }
+
+   /**
+    * A field node must expose the name and the type of the field it references.
+    */
+   @Test
+   public void testMakeField() throws GandivaException {
+     TreeNode n = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Int(32, true)));
+     GandivaTypes.TreeNode node = n.toProtobuf();
+
+     assertEquals("a", fieldName(node));
+     assertEquals(
+         GandivaTypes.GandivaType.INT32_VALUE,
+         node.getFieldNode().getField().getType().getType().getNumber());
+   }
+
+   /**
+    * A function node must keep its name, its argument order and its declared return type.
+    */
+   @Test
+   public void testMakeFunction() throws GandivaException {
+     TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Int(64, false)));
+     TreeNode b = TreeBuilder.makeField(Field.nullable("b", new ArrowType.Int(64, false)));
+     List<TreeNode> args = new ArrayList<>(2);
+     args.add(a);
+     args.add(b);
+
+     TreeNode addNode = TreeBuilder.makeFunction("add", args, new ArrowType.Int(64, false));
+     GandivaTypes.TreeNode node = addNode.toProtobuf();
+
+     assertTrue(node.hasFnNode());
+     assertEquals("add", node.getFnNode().getFunctionName());
+     assertEquals("a", fieldName(node.getFnNode().getInArgsList().get(0)));
+     assertEquals("b", fieldName(node.getFnNode().getInArgsList().get(1)));
+     assertEquals(
+         GandivaTypes.GandivaType.UINT64_VALUE,
+         node.getFnNode().getReturnType().getType().getNumber());
+   }
+
+   /**
+    * An if node must keep its condition, both branches and its declared return type.
+    */
+   @Test
+   public void testMakeIf() throws GandivaException {
+     Field a = Field.nullable("a", new ArrowType.Int(64, false));
+     Field b = Field.nullable("b", new ArrowType.Int(64, false));
+     TreeNode aNode = TreeBuilder.makeField(a);
+     TreeNode bNode = TreeBuilder.makeField(b);
+     List<TreeNode> args = new ArrayList<>(2);
+     args.add(aNode);
+     args.add(bNode);
+
+     ArrowType retType = new ArrowType.Bool();
+     TreeNode cond = TreeBuilder.makeFunction("greater_than", args, retType);
+     TreeNode ifNode = TreeBuilder.makeIf(cond, aNode, bNode, retType);
+
+     GandivaTypes.TreeNode node = ifNode.toProtobuf();
+
+     assertTrue(node.hasIfNode());
+     assertEquals("greater_than", node.getIfNode().getCond().getFnNode().getFunctionName());
+     assertEquals(a.getName(), fieldName(node.getIfNode().getThenNode()));
+     assertEquals(b.getName(), fieldName(node.getIfNode().getElseNode()));
+     assertEquals(
+         GandivaTypes.GandivaType.BOOL_VALUE,
+         node.getIfNode().getReturnType().getType().getNumber());
+   }
+
+   /** An and node must keep all of its arguments, in order. */
+   @Test
+   public void testMakeAnd() throws GandivaException {
+     TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Bool()));
+     TreeNode b = TreeBuilder.makeField(Field.nullable("b", new ArrowType.Bool()));
+     List<TreeNode> args = new ArrayList<>(2);
+     args.add(a);
+     args.add(b);
+
+     TreeNode andNode = TreeBuilder.makeAnd(args);
+     GandivaTypes.TreeNode node = andNode.toProtobuf();
+
+     assertTrue(node.hasAndNode());
+     assertEquals(2, node.getAndNode().getArgsList().size());
+     assertEquals("a", fieldName(node.getAndNode().getArgsList().get(0)));
+     assertEquals("b", fieldName(node.getAndNode().getArgsList().get(1)));
+   }
+
+   /** An or node must keep all of its arguments, in order. */
+   @Test
+   public void testMakeOr() throws GandivaException {
+     TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Bool()));
+     TreeNode b = TreeBuilder.makeField(Field.nullable("b", new ArrowType.Bool()));
+     List<TreeNode> args = new ArrayList<>(2);
+     args.add(a);
+     args.add(b);
+
+     TreeNode orNode = TreeBuilder.makeOr(args);
+     GandivaTypes.TreeNode node = orNode.toProtobuf();
+
+     assertTrue(node.hasOrNode());
+     assertEquals(2, node.getOrNode().getArgsList().size());
+     assertEquals("a", fieldName(node.getOrNode().getArgsList().get(0)));
+     assertEquals("b", fieldName(node.getOrNode().getArgsList().get(1)));
+   }
+
+   /**
+    * An expression built from an explicit tree must serialize both the tree and the result field.
+    */
+   @Test
+   public void testExpression() throws GandivaException {
+     Field a = Field.nullable("a", new ArrowType.Int(64, false));
+     Field b = Field.nullable("b", new ArrowType.Int(64, false));
+     TreeNode aNode = TreeBuilder.makeField(a);
+     TreeNode bNode = TreeBuilder.makeField(b);
+     List<TreeNode> args = new ArrayList<>(2);
+     args.add(aNode);
+     args.add(bNode);
+
+     ArrowType retType = new ArrowType.Bool();
+     TreeNode cond = TreeBuilder.makeFunction("greater_than", args, retType);
+     TreeNode ifNode = TreeBuilder.makeIf(cond, aNode, bNode, retType);
+
+     ExpressionTree expr = TreeBuilder.makeExpression(ifNode, Field.nullable("c", retType));
+
+     GandivaTypes.ExpressionRoot root = expr.toProtobuf();
+
+     assertTrue(root.getRoot().hasIfNode());
+     assertEquals(
+         "greater_than", root.getRoot().getIfNode().getCond().getFnNode().getFunctionName());
+     assertEquals("c", root.getResultType().getName());
+     assertEquals(
+         GandivaTypes.GandivaType.BOOL_VALUE, root.getResultType().getType().getType().getNumber());
+   }
+
+   /**
+    * The function-name shortcut must build a function node over the given input fields.
+    */
+   @Test
+   public void testExpression2() throws GandivaException {
+     Field a = Field.nullable("a", new ArrowType.Int(64, false));
+     Field b = Field.nullable("b", new ArrowType.Int(64, false));
+     List<Field> args = new ArrayList<>(2);
+     args.add(a);
+     args.add(b);
+
+     Field c = Field.nullable("c", new ArrowType.Int(64, false));
+     ExpressionTree expr = TreeBuilder.makeExpression("add", args, c);
+     GandivaTypes.ExpressionRoot root = expr.toProtobuf();
+
+     GandivaTypes.TreeNode node = root.getRoot();
+
+     assertEquals("c", root.getResultType().getName());
+     assertTrue(node.hasFnNode());
+     assertEquals("add", node.getFnNode().getFunctionName());
+     assertEquals("a", fieldName(node.getFnNode().getInArgsList().get(0)));
+     assertEquals("b", fieldName(node.getFnNode().getInArgsList().get(1)));
+     assertEquals(
+         GandivaTypes.GandivaType.UINT64_VALUE,
+         node.getFnNode().getReturnType().getType().getNumber());
+   }
+
+   /** An expression rooted at an and node must serialize the arguments and the result field. */
+   @Test
+   public void testExpressionWithAnd() throws GandivaException {
+     TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Bool()));
+     TreeNode b = TreeBuilder.makeField(Field.nullable("b", new ArrowType.Bool()));
+     List<TreeNode> args = new ArrayList<>(2);
+     args.add(a);
+     args.add(b);
+
+     TreeNode andNode = TreeBuilder.makeAnd(args);
+     ExpressionTree expr =
+         TreeBuilder.makeExpression(andNode, Field.nullable("c", new ArrowType.Bool()));
+     GandivaTypes.ExpressionRoot root = expr.toProtobuf();
+
+     assertTrue(root.getRoot().hasAndNode());
+     assertEquals("a", fieldName(root.getRoot().getAndNode().getArgsList().get(0)));
+     assertEquals("b", fieldName(root.getRoot().getAndNode().getArgsList().get(1)));
+     assertEquals("c", root.getResultType().getName());
+     assertEquals(
+         GandivaTypes.GandivaType.BOOL_VALUE, root.getResultType().getType().getType().getNumber());
+   }
+
+   /** An expression rooted at an or node must serialize the arguments and the result field. */
+   @Test
+   public void testExpressionWithOr() throws GandivaException {
+     TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Bool()));
+     TreeNode b = TreeBuilder.makeField(Field.nullable("b", new ArrowType.Bool()));
+     List<TreeNode> args = new ArrayList<>(2);
+     args.add(a);
+     args.add(b);
+
+     TreeNode orNode = TreeBuilder.makeOr(args);
+     ExpressionTree expr =
+         TreeBuilder.makeExpression(orNode, Field.nullable("c", new ArrowType.Bool()));
+     GandivaTypes.ExpressionRoot root = expr.toProtobuf();
+
+     assertTrue(root.getRoot().hasOrNode());
+     assertEquals("a", fieldName(root.getRoot().getOrNode().getArgsList().get(0)));
+     assertEquals("b", fieldName(root.getRoot().getOrNode().getArgsList().get(1)));
+     assertEquals("c", root.getResultType().getName());
+     assertEquals(
+         GandivaTypes.GandivaType.BOOL_VALUE, root.getResultType().getType().getType().getNumber());
+   }
+
+   /**
+    * A condition built from an explicit tree must serialize the root function and its arguments.
+    */
+   @Test
+   public void testCondition() throws GandivaException {
+     Field a = Field.nullable("a", new ArrowType.Int(64, false));
+     Field b = Field.nullable("b", new ArrowType.Int(64, false));
+
+     TreeNode aNode = TreeBuilder.makeField(a);
+     TreeNode bNode = TreeBuilder.makeField(b);
+     List<TreeNode> args = new ArrayList<>(2);
+     args.add(aNode);
+     args.add(bNode);
+
+     TreeNode root = TreeBuilder.makeFunction("greater_than", args, new ArrowType.Bool());
+     assertGreaterThanOfAAndB(TreeBuilder.makeCondition(root));
+   }
+
+   /**
+    * The function-name shortcut must produce the same condition as the explicit tree form.
+    */
+   @Test
+   public void testCondition2() throws GandivaException {
+     Field a = Field.nullable("a", new ArrowType.Int(64, false));
+     Field b = Field.nullable("b", new ArrowType.Int(64, false));
+
+     assertGreaterThanOfAAndB(TreeBuilder.makeCondition("greater_than", Arrays.asList(a, b)));
+   }
+ }
diff --git a/src/arrow/java/gandiva/src/test/resources/logback.xml b/src/arrow/java/gandiva/src/test/resources/logback.xml
new file mode 100644
index 000000000..f9e449fa6
--- /dev/null
+++ b/src/arrow/java/gandiva/src/test/resources/logback.xml
@@ -0,0 +1,28 @@
+<?xml version="1.0" encoding="UTF-8" ?>
+<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
+ license agreements. See the NOTICE file distributed with this work for additional
+ information regarding copyright ownership. The ASF licenses this file to
+ You under the Apache License, Version 2.0 (the "License"); you may not use
+ this file except in compliance with the License. You may obtain a copy of
+ the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
+ by applicable law or agreed to in writing, software distributed under the
+ License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
+ OF ANY KIND, either express or implied. See the License for the specific
+ language governing permissions and limitations under the License. -->
+
+<configuration>
+ <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
+ <!-- encoders are assigned the type
+ ch.qos.logback.classic.encoder.PatternLayoutEncoder by default -->
+ <encoder>
+ <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
+ </encoder>
+ </appender>
+ <!-- NopStatusListener drops logback's own status output; org.apache.arrow logs at INFO to STDOUT only -->
+ <statusListener class="ch.qos.logback.core.status.NopStatusListener"/>
+ <logger name="org.apache.arrow" additivity="false">
+ <level value="info" />
+ <appender-ref ref="STDOUT" />
+ </logger>
+
+</configuration>