diff options
Diffstat (limited to 'src/arrow/java/gandiva')
51 files changed, 8389 insertions, 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

project(gandiva_java)

# Find java/jni
include(FindJava)
include(UseJava)
include(FindJNI)

message("generating headers to ${JNI_HEADERS_DIR}/jni")

# Java sources compiled into the gandiva_java jar that the JNI headers are
# generated from. Declared once so both branches below always use the same list.
set(GANDIVA_JNI_JAVA_SOURCES
    src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java
    src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
    src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java
    src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java)

# GENERATE_NATIVE_HEADERS is available only from java8.
# centos5 does not have java8 images, so java 7 is supported too through
# create_javah. Unfortunately create_javah does not work correctly with java8,
# which is why both code paths are kept.
if(ARROW_GANDIVA_JAVA7)
  add_jar(gandiva_java ${GANDIVA_JNI_JAVA_SOURCES})

  create_javah(TARGET gandiva_jni_headers
               CLASSES org.apache.arrow.gandiva.evaluator.ConfigurationBuilder
                       org.apache.arrow.gandiva.evaluator.JniWrapper
                       org.apache.arrow.gandiva.evaluator.ExpressionRegistryJniHelper
                       org.apache.arrow.gandiva.exceptions.GandivaException
               DEPENDS gandiva_java
               CLASSPATH gandiva_java
               OUTPUT_DIR ${JNI_HEADERS_DIR}/jni)
else()
  add_jar(gandiva_java
          ${GANDIVA_JNI_JAVA_SOURCES}
          GENERATE_NATIVE_HEADERS
          gandiva_jni_headers
          DESTINATION
          ${JNI_HEADERS_DIR}/jni)
endif()
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +# Gandiva Java + +## Setup Build Environment + +install: + - java 8 or later + - maven 3.3 or later + +## Building and running tests + +``` +cd java +mvn install -Darrow.cpp.build.dir=<path_to_cpp_artifact_directory> +``` diff --git a/src/arrow/java/gandiva/pom.xml b/src/arrow/java/gandiva/pom.xml new file mode 100644 index 000000000..81caf12f5 --- /dev/null +++ b/src/arrow/java/gandiva/pom.xml @@ -0,0 +1,153 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor + license agreements. See the NOTICE file distributed with this work for additional + information regarding copyright ownership. The ASF licenses this file to + You under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of + the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required + by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS + OF ANY KIND, either express or implied. See the License for the specific + language governing permissions and limitations under the License. 
--> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-java-root</artifactId> + <version>6.0.1</version> + </parent> + + <groupId>org.apache.arrow.gandiva</groupId> + <artifactId>arrow-gandiva</artifactId> + <packaging>jar</packaging> + <name>Arrow Gandiva</name> + <description>Java wrappers around the native Gandiva SQL expression compiler.</description> + <properties> + <maven.compiler.source>1.8</maven.compiler.source> + <maven.compiler.target>1.8</maven.compiler.target> + <protobuf.version>2.5.0</protobuf.version> + <checkstyle.failOnViolation>true</checkstyle.failOnViolation> + <arrow.cpp.build.dir>../../../cpp/release-build</arrow.cpp.build.dir> + </properties> + <dependencies> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-memory-core</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-memory-netty</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-vector</artifactId> + <version>${project.version}</version> + <classifier>${arrow.vector.classifier}</classifier> + </dependency> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${protobuf.version}</version> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + </dependency> + </dependencies> + <profiles> + <profile> + <id>release</id> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + 
<artifactId>maven-source-plugin</artifactId> + <version>2.2.1</version> + <executions> + <execution> + <id>attach-sources</id> + <goals> + <goal>jar-no-fork</goal> + </goals> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-javadoc-plugin</artifactId> + <version>2.9.1</version> + <executions> + <execution> + <id>attach-javadocs</id> + <goals> + <goal>jar</goal> + </goals> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-gpg-plugin</artifactId> + <version>1.5</version> + <executions> + <execution> + <id>sign-artifacts</id> + <phase>verify</phase> + <goals> + <goal>sign</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + </profiles> + <build> + <resources> + <resource> + <directory>${arrow.cpp.build.dir}</directory> + <includes> + <include>**/gandiva_jni.*</include> + <include>**/libgandiva_jni.*</include> + </includes> + </resource> + </resources> + + <extensions> + <extension> + <groupId>kr.motd.maven</groupId> + <artifactId>os-maven-plugin</artifactId> + <version>1.5.0.Final</version> + </extension> + </extensions> + <plugins> + <plugin> + <groupId>org.xolstice.maven.plugins</groupId> + <artifactId>protobuf-maven-plugin</artifactId> + <version>0.5.1</version> + <configuration> + <protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} + </protocArtifact> + <protoSourceRoot>../../cpp/src/gandiva/proto</protoSourceRoot> + </configuration> + <executions> + <execution> + <goals> + <goal>compile</goal> + <goal>test-compile</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + +</project> diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java new file mode 100644 
index 000000000..e903b4e87 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import java.util.Objects; + +/** + * Used to construct gandiva configuration objects. + */ +public class ConfigurationBuilder { + + public long buildConfigInstance(ConfigOptions configOptions) { + return buildConfigInstance(configOptions.optimize, configOptions.targetCPU); + } + + private native long buildConfigInstance(boolean optimize, boolean detectHostCPU); + + public native void releaseConfigInstance(long configId); + + /** + * ConfigOptions contains the configuration parameters to provide to gandiva. 
+ */ + public static class ConfigOptions { + private boolean optimize = true; + private boolean targetCPU = true; + + public static ConfigOptions getDefault() { + return new ConfigOptions(); + } + + public ConfigOptions() {} + + public ConfigOptions withOptimize(boolean optimize) { + this.optimize = optimize; + return this; + } + + public ConfigOptions withTargetCPU(boolean targetCPU) { + this.targetCPU = targetCPU; + return this; + } + + @Override + public int hashCode() { + return Objects.hash(optimize, targetCPU); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof ConfigOptions)) { + return false; + } + return this.optimize == ((ConfigOptions) obj).optimize && + this.targetCPU == ((ConfigOptions) obj).targetCPU; + } + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java new file mode 100644 index 000000000..e0c072cfb --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtil.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import org.apache.arrow.vector.types.pojo.ArrowType.Decimal; + +/** + * Utility methods for working with {@link Decimal} values. + */ +public class DecimalTypeUtil { + private DecimalTypeUtil() {} + + /** + * Enum for supported mathematical operations. + */ + public enum OperationType { + ADD, + SUBTRACT, + MULTIPLY, + DIVIDE, + MOD + } + + private static final int MIN_ADJUSTED_SCALE = 6; + /// The maximum precision representable by a 16-byte decimal + private static final int MAX_PRECISION = 38; + + /** + * Determines the scale and precision of applying the given operation to the operands. + */ + public static Decimal getResultTypeForOperation(OperationType operation, Decimal operand1, Decimal + operand2) { + int s1 = operand1.getScale(); + int s2 = operand2.getScale(); + int p1 = operand1.getPrecision(); + int p2 = operand2.getPrecision(); + int resultScale = 0; + int resultPrecision = 0; + switch (operation) { + case ADD: + case SUBTRACT: + resultScale = Math.max(operand1.getScale(), operand2.getScale()); + resultPrecision = resultScale + Math.max(operand1.getPrecision() - operand1.getScale(), + operand2.getPrecision() - operand2.getScale()) + 1; + break; + case MULTIPLY: + resultScale = s1 + s2; + resultPrecision = p1 + p2 + 1; + break; + case DIVIDE: + resultScale = + Math.max(MIN_ADJUSTED_SCALE, operand1.getScale() + operand2.getPrecision() + 1); + resultPrecision = + operand1.getPrecision() - operand1.getScale() + operand2.getScale() + resultScale; + break; + case MOD: + resultScale = Math.max(operand1.getScale(), operand2.getScale()); + resultPrecision = Math.min(operand1.getPrecision() - operand1.getScale(), + operand2.getPrecision() - operand2.getScale()) + + resultScale; + break; + default: + throw new RuntimeException("Needs support"); + } + return adjustScaleIfNeeded(resultPrecision, resultScale); + } + + private static Decimal adjustScaleIfNeeded(int precision, int scale) { + if (precision > 
MAX_PRECISION) { + int minScale = Math.min(scale, MIN_ADJUSTED_SCALE); + int delta = precision - MAX_PRECISION; + precision = MAX_PRECISION; + scale = Math.max(scale - delta, minScale); + } + return new Decimal(precision, scale, 128); + } + +} + diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java new file mode 100644 index 000000000..0155af082 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import java.util.List; +import java.util.Set; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.gandiva.ipc.GandivaTypes.ExtGandivaType; +import org.apache.arrow.gandiva.ipc.GandivaTypes.GandivaDataTypes; +import org.apache.arrow.gandiva.ipc.GandivaTypes.GandivaFunctions; +import org.apache.arrow.gandiva.ipc.GandivaTypes.GandivaType; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import com.google.protobuf.InvalidProtocolBufferException; + +/** + * Used to get the functions and data types supported by + * Gandiva. + * All types are in Arrow namespace. + */ +public class ExpressionRegistry { + + private static final int BIT_WIDTH8 = 8; + private static final int BIT_WIDTH_16 = 16; + private static final int BIT_WIDTH_32 = 32; + private static final int BIT_WIDTH_64 = 64; + private static final boolean IS_SIGNED_FALSE = false; + private static final boolean IS_SIGNED_TRUE = true; + + private final Set<ArrowType> supportedTypes; + private final Set<FunctionSignature> functionSignatures; + + private static volatile ExpressionRegistry INSTANCE; + + private ExpressionRegistry(Set<ArrowType> supportedTypes, + Set<FunctionSignature> functionSignatures) { + this.supportedTypes = supportedTypes; + this.functionSignatures = functionSignatures; + } + + /** + * Returns a singleton instance of the class. + * @return singleton instance + * @throws GandivaException if error in Gandiva Library integration. 
+ */ + public static ExpressionRegistry getInstance() throws GandivaException { + if (INSTANCE == null) { + synchronized (ExpressionRegistry.class) { + if (INSTANCE == null) { + // ensure library is setup. + JniLoader.getInstance(); + Set<ArrowType> typesFromGandiva = getSupportedTypesFromGandiva(); + Set<FunctionSignature> functionsFromGandiva = getSupportedFunctionsFromGandiva(); + INSTANCE = new ExpressionRegistry(typesFromGandiva, functionsFromGandiva); + } + } + } + return INSTANCE; + } + + public Set<FunctionSignature> getSupportedFunctions() { + return functionSignatures; + } + + public Set<ArrowType> getSupportedTypes() { + return supportedTypes; + } + + private static Set<ArrowType> getSupportedTypesFromGandiva() throws GandivaException { + Set<ArrowType> supportedTypes = Sets.newHashSet(); + try { + byte[] gandivaSupportedDataTypes = new ExpressionRegistryJniHelper() + .getGandivaSupportedDataTypes(); + GandivaDataTypes gandivaDataTypes = GandivaDataTypes.parseFrom(gandivaSupportedDataTypes); + for (ExtGandivaType type : gandivaDataTypes.getDataTypeList()) { + supportedTypes.add(getArrowType(type)); + } + } catch (InvalidProtocolBufferException invalidProtException) { + throw new GandivaException("Could not get supported types.", invalidProtException); + } + return supportedTypes; + } + + private static Set<FunctionSignature> getSupportedFunctionsFromGandiva() throws + GandivaException { + Set<FunctionSignature> supportedTypes = Sets.newHashSet(); + try { + byte[] gandivaSupportedFunctions = new ExpressionRegistryJniHelper() + .getGandivaSupportedFunctions(); + GandivaFunctions gandivaFunctions = GandivaFunctions.parseFrom(gandivaSupportedFunctions); + for (GandivaTypes.FunctionSignature protoFunctionSignature + : gandivaFunctions.getFunctionList()) { + + String functionName = protoFunctionSignature.getName(); + ArrowType returnType = getArrowType(protoFunctionSignature.getReturnType()); + List<ArrowType> paramTypes = Lists.newArrayList(); + for 
(ExtGandivaType type : protoFunctionSignature.getParamTypesList()) { + paramTypes.add(getArrowType(type)); + } + FunctionSignature functionSignature = new FunctionSignature(functionName, + returnType, paramTypes); + supportedTypes.add(functionSignature); + } + } catch (InvalidProtocolBufferException invalidProtException) { + throw new GandivaException("Could not get supported functions.", invalidProtException); + } + return supportedTypes; + } + + private static ArrowType getArrowType(ExtGandivaType type) { + switch (type.getType().getNumber()) { + case GandivaType.BOOL_VALUE: + return ArrowType.Bool.INSTANCE; + case GandivaType.UINT8_VALUE: + return new ArrowType.Int(BIT_WIDTH8, IS_SIGNED_FALSE); + case GandivaType.INT8_VALUE: + return new ArrowType.Int(BIT_WIDTH8, IS_SIGNED_TRUE); + case GandivaType.UINT16_VALUE: + return new ArrowType.Int(BIT_WIDTH_16, IS_SIGNED_FALSE); + case GandivaType.INT16_VALUE: + return new ArrowType.Int(BIT_WIDTH_16, IS_SIGNED_TRUE); + case GandivaType.UINT32_VALUE: + return new ArrowType.Int(BIT_WIDTH_32, IS_SIGNED_FALSE); + case GandivaType.INT32_VALUE: + return new ArrowType.Int(BIT_WIDTH_32, IS_SIGNED_TRUE); + case GandivaType.UINT64_VALUE: + return new ArrowType.Int(BIT_WIDTH_64, IS_SIGNED_FALSE); + case GandivaType.INT64_VALUE: + return new ArrowType.Int(BIT_WIDTH_64, IS_SIGNED_TRUE); + case GandivaType.HALF_FLOAT_VALUE: + return new ArrowType.FloatingPoint(FloatingPointPrecision.HALF); + case GandivaType.FLOAT_VALUE: + return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + case GandivaType.DOUBLE_VALUE: + return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + case GandivaType.UTF8_VALUE: + return new ArrowType.Utf8(); + case GandivaType.BINARY_VALUE: + return new ArrowType.Binary(); + case GandivaType.DATE32_VALUE: + return new ArrowType.Date(DateUnit.DAY); + case GandivaType.DATE64_VALUE: + return new ArrowType.Date(DateUnit.MILLISECOND); + case GandivaType.TIMESTAMP_VALUE: + return new 
ArrowType.Timestamp(mapArrowTimeUnit(type.getTimeUnit()), null); + case GandivaType.TIME32_VALUE: + return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), + BIT_WIDTH_32); + case GandivaType.TIME64_VALUE: + return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), + BIT_WIDTH_64); + case GandivaType.NONE_VALUE: + return new ArrowType.Null(); + case GandivaType.DECIMAL_VALUE: + return new ArrowType.Decimal(0, 0, 128); + case GandivaType.INTERVAL_VALUE: + return new ArrowType.Interval(mapArrowIntervalUnit(type.getIntervalType())); + case GandivaType.FIXED_SIZE_BINARY_VALUE: + case GandivaType.MAP_VALUE: + case GandivaType.DICTIONARY_VALUE: + case GandivaType.LIST_VALUE: + case GandivaType.STRUCT_VALUE: + case GandivaType.UNION_VALUE: + default: + assert false; + } + return null; + } + + private static TimeUnit mapArrowTimeUnit(GandivaTypes.TimeUnit timeUnit) { + switch (timeUnit.getNumber()) { + case GandivaTypes.TimeUnit.MICROSEC_VALUE: + return TimeUnit.MICROSECOND; + case GandivaTypes.TimeUnit.MILLISEC_VALUE: + return TimeUnit.MILLISECOND; + case GandivaTypes.TimeUnit.NANOSEC_VALUE: + return TimeUnit.NANOSECOND; + case GandivaTypes.TimeUnit.SEC_VALUE: + return TimeUnit.SECOND; + default: + return null; + } + } + + private static IntervalUnit mapArrowIntervalUnit(GandivaTypes.IntervalType intervalType) { + switch (intervalType.getNumber()) { + case GandivaTypes.IntervalType.YEAR_MONTH_VALUE: + return IntervalUnit.YEAR_MONTH; + case GandivaTypes.IntervalType.DAY_TIME_VALUE: + return IntervalUnit.DAY_TIME; + default: + return null; + } + } + +} + diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java new file mode 100644 index 000000000..86c1eaaed --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java @@ -0,0 +1,29 
/**
 * JNI adapter through which the supported data types and functions are fetched from Gandiva.
 *
 * <p>Both methods are native and return raw bytes; {@code ExpressionRegistry} parses them as
 * protobuf messages.
 */
class ExpressionRegistryJniHelper {

  /**
   * Fetches the data types supported by Gandiva.
   *
   * @return bytes that the caller parses as a {@code GandivaDataTypes} protobuf message
   */
  native byte[] getGandivaSupportedDataTypes();

  /**
   * Fetches the functions supported by Gandiva.
   *
   * @return bytes that the caller parses as a {@code GandivaFunctions} protobuf message
   */
  native byte[] getGandivaSupportedFunctions();
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.gandiva.exceptions.EvaluatorClosedException; +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.expression.ArrowTypeHelper; +import org.apache.arrow.gandiva.expression.Condition; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.ipc.message.ArrowBuffer; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class provides a mechanism to filter a RecordBatch by evaluating a condition expression. + * Follow these steps to use this class: 1) Use the static method make() to create an instance of + * this class that evaluates a condition. 2) Invoke the method evaluate() to evaluate the filter + * against a RecordBatch 3) Invoke close() to release resources + */ +public class Filter { + + private static final Logger logger = LoggerFactory.getLogger(Filter.class); + + private final JniWrapper wrapper; + private final long moduleId; + private final Schema schema; + private boolean closed; + + private Filter(JniWrapper wrapper, long moduleId, Schema schema) { + this.wrapper = wrapper; + this.moduleId = moduleId; + this.schema = schema; + this.closed = false; + } + + /** + * Invoke this function to generate LLVM code to evaluate the condition expression. 
Invoke + * Filter::Evaluate() against a RecordBatch to evaluate the filter on this record batch + * + * @param schema Table schema. The field names in the schema should match the fields used to + * create the TreeNodes + * @param condition condition to be evaluated against data + * @return A native filter object that can be used to invoke on a RecordBatch + */ + public static Filter make(Schema schema, Condition condition) throws GandivaException { + return make(schema, condition, JniLoader.getDefaultConfiguration()); + } + + /** + * Invoke this function to generate LLVM code to evaluate the condition expression. Invoke + * Filter::Evaluate() against a RecordBatch to evaluate the filter on this record batch + * + * @param schema Table schema. The field names in the schema should match the fields used to + * create the TreeNodes + * @param condition condition to be evaluated against data + * @param configOptions ConfigOptions parameter + * @return A native filter object that can be used to invoke on a RecordBatch + */ + public static Filter make(Schema schema, Condition condition, ConfigurationBuilder.ConfigOptions configOptions) + throws GandivaException { + return make(schema, condition, JniLoader.getConfiguration(configOptions)); + } + + /** + * Invoke this function to generate LLVM code to evaluate the condition expression. Invoke + * Filter::Evaluate() against a RecordBatch to evaluate the filter on this record batch + * + * @param schema Table schema. 
The field names in the schema should match the fields used to
+   *                  create the TreeNodes
+   * @param condition condition to be evaluated against data
+   * @param optimize Flag to choose if the generated llvm code is to be optimized
+   * @return A native filter object that can be used to invoke on a RecordBatch
+   */
+  @Deprecated
+  public static Filter make(Schema schema, Condition condition, boolean optimize) throws GandivaException {
+    return make(schema, condition, JniLoader.getConfiguration((new ConfigurationBuilder.ConfigOptions())
+        .withOptimize(optimize)));
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the condition expression. Invoke
+   * Filter::Evaluate() against a RecordBatch to evaluate the filter on this record batch
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used to
+   *                  create the TreeNodes
+   * @param condition condition to be evaluated against data
+   * @param configurationId Custom configuration created through config builder.
+   * @return A native filter object that can be used to invoke on a RecordBatch
+   * @throws GandivaException if the JNI layer fails to build the filter module
+   */
+  public static Filter make(Schema schema, Condition condition, long configurationId)
+      throws GandivaException {
+    // Invoke the JNI layer to create the LLVM module representing the filter.
+    GandivaTypes.Condition conditionBuf = condition.toProtobuf();
+    GandivaTypes.Schema schemaBuf = ArrowTypeHelper.arrowSchemaToProtobuf(schema);
+    JniWrapper wrapper = JniLoader.getInstance().getWrapper();
+    long moduleId = wrapper.buildFilter(schemaBuf.toByteArray(),
+        conditionBuf.toByteArray(), configurationId);
+    logger.debug("Created module for the filter with id {}", moduleId);
+    return new Filter(wrapper, moduleId, schema);
+  }
+
+  /**
+   * Invoke this function to evaluate a filter against a recordBatch.
+   *
+   * @param recordBatch Record batch including the data
+   * @param selectionVector Result of applying the filter on the data
+   * @throws GandivaException if the filter is closed, the selection vector is too small,
+   *                          or the native evaluation fails
+   */
+  public void evaluate(ArrowRecordBatch recordBatch, SelectionVector selectionVector)
+      throws GandivaException {
+    evaluate(recordBatch.getLength(), recordBatch.getBuffers(), recordBatch.getBuffersLayout(),
+        selectionVector);
+  }
+
+  /**
+   * Invoke this function to evaluate filter against a set of arrow buffers. (this is an optimised
+   * version that skips taking references).
+   *
+   * @param numRows number of rows.
+   * @param buffers List of input arrow buffers
+   * @param selectionVector Result of applying the filter on the data
+   * @throws GandivaException if the filter is closed, the selection vector is too small,
+   *                          or the native evaluation fails
+   */
+  public void evaluate(int numRows, List<ArrowBuf> buffers,
+      SelectionVector selectionVector) throws GandivaException {
+    // Synthesize a layout: each buffer is described by a running offset and its readable size.
+    List<ArrowBuffer> buffersLayout = new ArrayList<>();
+    long offset = 0;
+    for (ArrowBuf arrowBuf : buffers) {
+      long size = arrowBuf.readableBytes();
+      buffersLayout.add(new ArrowBuffer(offset, size));
+      offset += size;
+    }
+    evaluate(numRows, buffers, buffersLayout, selectionVector);
+  }
+
+  /**
+   * Shared implementation behind the public evaluate() overloads: checks that the filter is open
+   * and the selection vector is large enough, flattens the buffers into address/size arrays and
+   * invokes the native filter.
+   */
+  private void evaluate(int numRows, List<ArrowBuf> buffers, List<ArrowBuffer> buffersLayout,
+      SelectionVector selectionVector) throws GandivaException {
+    if (this.closed) {
+      throw new EvaluatorClosedException();
+    }
+    if (selectionVector.getMaxRecords() < numRows) {
+      // Parameterized message, same style as the logger.debug call in make().
+      logger.error("selectionVector has capacity for {} rows, minimum required {}",
+          selectionVector.getMaxRecords(), numRows);
+      throw new GandivaException("SelectionVector too small");
+    }
+
+    long[] bufAddrs = new long[buffers.size()];
+    long[] bufSizes = new long[buffers.size()];
+
+    int idx = 0;
+    for (ArrowBuf buf : buffers) {
+      bufAddrs[idx++] = buf.memoryAddress();
+    }
+
+    idx = 0;
+    for (ArrowBuffer bufLayout : buffersLayout) {
+      bufSizes[idx++] = bufLayout.getSize();
+    }
+
+    int numRecords = wrapper.evaluateFilter(this.moduleId, numRows,
+        bufAddrs, bufSizes,
+        selectionVector.getType().getNumber(),
+        selectionVector.getBuffer().memoryAddress(), selectionVector.getBuffer().capacity());
+    // A negative result leaves the selection vector's record count untouched.
+    if (numRecords >= 0) {
+      selectionVector.setRecordCount(numRecords);
+    }
+  }
+
+  /**
+   * Closes the LLVM module representing this filter. Calling it again is a no-op.
+   */
+  public void close() throws GandivaException {
+    if (this.closed) {
+      return;
+    }
+
+    wrapper.closeFilter(this.moduleId);
+    this.closed = true;
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java
new file mode 100644
index 000000000..d01881843
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.util.List;
+
+import org.apache.arrow.vector.types.pojo.ArrowType;
+
+import com.google.common.base.MoreObjects;
+import com.google.common.base.Objects;
+
+/**
+ * POJO to define a function signature.
+ */
+public class FunctionSignature {
+  private final String name;
+  private final ArrowType returnType;
+  private final List<ArrowType> paramTypes;
+
+  public ArrowType getReturnType() {
+    return returnType;
+  }
+
+  public List<ArrowType> getParamTypes() {
+    return paramTypes;
+  }
+
+  public String getName() {
+    return name;
+  }
+
+  /**
+   * Ctor.
+   * @param name - name of the function.
+   * @param returnType - data type of return
+   * @param paramTypes - data type of input args.
+   */
+  public FunctionSignature(String name, ArrowType returnType, List<ArrowType> paramTypes) {
+    this.name = name;
+    this.returnType = returnType;
+    this.paramTypes = paramTypes;
+  }
+
+  /**
+   * Compares two signatures; the function name is compared case-insensitively.
+   * @param signature - signature to compare
+   * @return true if equal and false if not.
+   */
+  @Override
+  public boolean equals(Object signature) {
+    if (this == signature) {
+      return true;
+    }
+    if (signature == null || getClass() != signature.getClass()) {
+      return false;
+    }
+    final FunctionSignature other = (FunctionSignature) signature;
+    return this.name.equalsIgnoreCase(other.name) &&
+        Objects.equal(this.returnType, other.returnType) &&
+        Objects.equal(this.paramTypes, other.paramTypes);
+  }
+
+  /**
+   * Hash consistent with {@link #equals(Object)}: the name is folded to lower case first.
+   */
+  @Override
+  public int hashCode() {
+    // Locale.ROOT: with the default locale (e.g. Turkish) "IF" and "if" would lower-case
+    // differently and hash apart even though equalsIgnoreCase treats them as equal.
+    return Objects.hashCode(this.name.toLowerCase(java.util.Locale.ROOT), this.returnType,
+        this.paramTypes);
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this)
+        .add("name ", name)
+        .add("return type ", returnType)
+        .add("param types ", paramTypes)
+        .toString();
+  }
+
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java
new file mode 100644
index 000000000..676956a34
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniLoader.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import static java.util.UUID.randomUUID; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import org.apache.arrow.gandiva.exceptions.GandivaException; + +/** + * This class handles loading of the jni library, and acts as a bridge for the native functions. 
+ */
+class JniLoader {
+  private static final String LIBRARY_NAME = "gandiva_jni";
+
+  private static volatile JniLoader INSTANCE;
+  private static volatile long defaultConfiguration = 0L;
+  private static final ConcurrentMap<ConfigurationBuilder.ConfigOptions, Long> configurationMap
+      = new ConcurrentHashMap<>();
+
+  private final JniWrapper wrapper;
+
+  private JniLoader() {
+    this.wrapper = new JniWrapper();
+  }
+
+  /**
+   * Returns the shared loader, extracting and loading the native library on first use.
+   */
+  static JniLoader getInstance() throws GandivaException {
+    if (INSTANCE == null) {
+      synchronized (JniLoader.class) {
+        if (INSTANCE == null) {
+          INSTANCE = setupInstance();
+        }
+      }
+    }
+    return INSTANCE;
+  }
+
+  private static JniLoader setupInstance() throws GandivaException {
+    try {
+      String tempDir = System.getProperty("java.io.tmpdir");
+      loadGandivaLibraryFromJar(tempDir);
+      return new JniLoader();
+    } catch (IOException ioException) {
+      throw new GandivaException("unable to create native instance", ioException);
+    }
+  }
+
+  private static void loadGandivaLibraryFromJar(final String tmpDir)
+      throws IOException, GandivaException {
+    final String libraryToLoad = System.mapLibraryName(LIBRARY_NAME);
+    final File libraryFile = moveFileFromJarToTemp(tmpDir, libraryToLoad);
+    System.load(libraryFile.getAbsolutePath());
+  }
+
+
+  private static File moveFileFromJarToTemp(final String tmpDir, String libraryToLoad)
+      throws IOException, GandivaException {
+    final File temp = setupFile(tmpDir, libraryToLoad);
+    try (final InputStream is = JniLoader.class.getClassLoader()
+        .getResourceAsStream(libraryToLoad)) {
+      if (is == null) {
+        throw new GandivaException(libraryToLoad + " was not found inside JAR.");
+      } else {
+        Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING);
+      }
+    }
+    return temp;
+  }
+
+  private static File setupFile(String tmpDir, String libraryToLoad)
+      throws IOException, GandivaException {
+    // accommodate multiple processes running with gandiva jar.
+    // length should be ok since uuid is only 36 characters.
+    final String randomizeFileName = libraryToLoad + randomUUID();
+    final File temp = new File(tmpDir, randomizeFileName);
+    if (temp.exists() && !temp.delete()) {
+      throw new GandivaException("File: " + temp.getAbsolutePath() +
+          " already exists and cannot be removed.");
+    }
+    if (!temp.createNewFile()) {
+      throw new GandivaException("File: " + temp.getAbsolutePath() +
+          " could not be created.");
+    }
+    temp.deleteOnExit();
+    return temp;
+  }
+
+  /**
+   * Returns the jni wrapper.
+   */
+  JniWrapper getWrapper() throws GandivaException {
+    return wrapper;
+  }
+
+  /**
+   * Returns the native configuration id for the given options, building and caching it on
+   * first use.
+   *
+   * @param configOptions options identifying the configuration
+   * @return id of the native configuration instance
+   * @throws GandivaException if the native library cannot be set up
+   */
+  static long getConfiguration(ConfigurationBuilder.ConfigOptions configOptions) throws GandivaException {
+    // Read the map once per check: a containsKey()/get() pair can interleave with
+    // removeConfiguration(), making get() return null and the unboxing to long throw.
+    Long cached = configurationMap.get(configOptions);
+    if (cached == null) {
+      synchronized (ConfigurationBuilder.class) {
+        cached = configurationMap.get(configOptions);
+        if (cached == null) {
+          JniLoader.getInstance(); // setup
+          long configInstance = new ConfigurationBuilder()
+              .buildConfigInstance(configOptions);
+          configurationMap.put(configOptions, configInstance);
+          if (ConfigurationBuilder.ConfigOptions.getDefault().equals(configOptions)) {
+            defaultConfiguration = configInstance;
+          }
+          return configInstance;
+        }
+      }
+    }
+    return cached;
+  }
+
+  /**
+   * Get the default configuration to invoke gandiva.
+   * @return default configuration
+   * @throws GandivaException if unable to get native builder instance.
+   */
+  static long getDefaultConfiguration() throws GandivaException {
+    if (defaultConfiguration == 0L) {
+      synchronized (ConfigurationBuilder.class) {
+        if (defaultConfiguration == 0L) {
+          JniLoader.getInstance(); // setup
+          ConfigurationBuilder.ConfigOptions defaultConfigOptions = ConfigurationBuilder.ConfigOptions.getDefault();
+          defaultConfiguration = new ConfigurationBuilder()
+              .buildConfigInstance(defaultConfigOptions);
+          configurationMap.put(defaultConfigOptions, defaultConfiguration);
+        }
+      }
+    }
+    return defaultConfiguration;
+  }
+
+  /**
+   * Remove the configuration.
+   */
+  static void removeConfiguration(ConfigurationBuilder.ConfigOptions configOptions) {
+    if (configurationMap.containsKey(configOptions)) {
+      synchronized (ConfigurationBuilder.class) {
+        if (configurationMap.containsKey(configOptions)) {
+          (new ConfigurationBuilder()).releaseConfigInstance(configurationMap.remove(configOptions));
+          // Reset the cached default id so getDefaultConfiguration() rebuilds it on next use.
+          if (configOptions.equals(ConfigurationBuilder.ConfigOptions.getDefault())) {
+            defaultConfiguration = 0;
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
new file mode 100644
index 000000000..520ef5f44
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+
+/**
+ * This class is implemented in JNI. This provides the Java interface
+ * to invoke functions in JNI.
+ * This file is used to generate the .h files required for jni. Avoid all
+ * external dependencies in this file.
+ */
+public class JniWrapper {
+
+  /**
+   * Generates the projector module to evaluate the expressions with
+   * custom configuration.
+   *
+   * @param schemaBuf   The schema serialized as a protobuf. See Types.proto
+   *                    to see the protobuf specification
+   * @param exprListBuf The serialized protobuf of the expression vector. Each
+   *                    expression is created using TreeBuilder::MakeExpression.
+   * @param selectionVectorType type of selection vector
+   * @param configId    Configuration to gandiva.
+   * @return A moduleId that is passed to the evaluateProjector() and closeProjector() methods
+   *
+   */
+  native long buildProjector(byte[] schemaBuf, byte[] exprListBuf,
+                             int selectionVectorType,
+                             long configId) throws GandivaException;
+
+  /**
+   * Evaluate the expressions represented by the moduleId on a record batch
+   * and store the output in ValueVectors. Throws an exception in case of errors
+   *
+   * @param expander VectorExpander object. Used for callbacks from cpp. Projector passes
+   *                 null when none of the output columns is variable width.
+   * @param moduleId moduleId representing expressions. Created using a call to
+   *                 buildProjector
+   * @param numRows Number of rows in the record batch
+   * @param bufAddrs An array of memory addresses. Each memory address points to
+   *                 a validity vector or a data vector (will add support for offset
+   *                 vectors later).
+   * @param bufSizes An array of buffer sizes. For each memory address in bufAddrs,
+   *                 the size of the buffer is present in bufSizes
+   * @param selectionVectorType type of selection vector, as the protobuf enum number
+   * @param selectionVectorSize number of records in the selection vector; Projector passes
+   *                 the row count when no selection vector is used
+   * @param selectionVectorBufferAddr memory address of the selection vector buffer;
+   *                 Projector passes 0 when no selection vector is used
+   * @param selectionVectorBufferSize capacity of the selection vector buffer;
+   *                 Projector passes 0 when no selection vector is used
+   * @param outAddrs An array of output buffers, including the validity and data
+   *                 addresses.
+   * @param outSizes The allocated size of the output buffers. On successful evaluation,
+   *                 the result is stored in the output buffers
+   */
+  native void evaluateProjector(Object expander, long moduleId, int numRows,
+                                long[] bufAddrs, long[] bufSizes,
+                                int selectionVectorType, int selectionVectorSize,
+                                long selectionVectorBufferAddr, long selectionVectorBufferSize,
+                                long[] outAddrs, long[] outSizes) throws GandivaException;
+
+  /**
+   * Closes the projector referenced by moduleId.
+   *
+   * @param moduleId moduleId that needs to be closed
+   */
+  native void closeProjector(long moduleId);
+
+  /**
+   * Generates the filter module to evaluate the condition expression with
+   * custom configuration.
+   *
+   * @param schemaBuf    The schema serialized as a protobuf. See Types.proto
+   *                     to see the protobuf specification
+   * @param conditionBuf The serialized protobuf of the condition expression. Each
+   *                     expression is created using TreeBuilder::MakeCondition
+   * @param configId     Configuration to gandiva.
+   * @return A moduleId that is passed to the evaluateFilter() and closeFilter() methods
+   *
+   */
+  native long buildFilter(byte[] schemaBuf, byte[] conditionBuf,
+                          long configId) throws GandivaException;
+
+  /**
+   * Evaluate the filter represented by the moduleId on a record batch
+   * and store the output in buffer 'outAddr'. Throws an exception in case of errors
+   *
+   * @param moduleId moduleId representing expressions. Created using a call to
+   *                 buildFilter
+   * @param numRows Number of rows in the record batch
+   * @param bufAddrs An array of memory addresses. Each memory address points to
+   *                 a validity vector or a data vector (will add support for offset
+   *                 vectors later).
+   * @param bufSizes An array of buffer sizes. For each memory address in bufAddrs,
+   *                 the size of the buffer is present in bufSizes
+   * @param selectionVectorType type of selection vector
+   * @param outAddr  output buffer, whose type is represented by selectionVectorType
+   * @param outSize The allocated size of the output buffer. On successful evaluation,
+   *                the result is stored in the output buffer
+   * @return presumably the number of selected records (Filter stores a non-negative
+   *         result as the selection vector's record count) - confirm against the
+   *         native implementation
+   */
+  native int evaluateFilter(long moduleId, int numRows, long[] bufAddrs, long[] bufSizes,
+                            int selectionVectorType,
+                            long outAddr, long outSize) throws GandivaException;
+
+  /**
+   * Closes the filter referenced by moduleId.
+   *
+   * @param moduleId moduleId that needs to be closed
+   */
+  native void closeFilter(long moduleId);
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
new file mode 100644
index 000000000..471ddbced
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.gandiva.exceptions.EvaluatorClosedException;
+import org.apache.arrow.gandiva.exceptions.GandivaException;
+import org.apache.arrow.gandiva.exceptions.UnsupportedTypeException;
+import org.apache.arrow.gandiva.expression.ArrowTypeHelper;
+import org.apache.arrow.gandiva.expression.ExpressionTree;
+import org.apache.arrow.gandiva.ipc.GandivaTypes;
+import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.vector.BaseVariableWidthVector;
+import org.apache.arrow.vector.FixedWidthVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VariableWidthVector;
+import org.apache.arrow.vector.ipc.message.ArrowBuffer;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+/**
+ * This class provides a mechanism to evaluate a set of expressions against a RecordBatch.
+ * Follow these steps to use this class:
+ * 1) Use the static method make() to create an instance of this class that evaluates a
+ *    set of expressions
+ * 2) Invoke the method evaluate() to evaluate these expressions against a RecordBatch
+ * 3) Invoke close() to release resources
+ */
+public class Projector {
+  private static final org.slf4j.Logger logger =
+      org.slf4j.LoggerFactory.getLogger(Projector.class);
+
+  private final JniWrapper wrapper;
+  private final long moduleId;
+  private final Schema schema;
+  private final int numExprs;
+  private boolean closed;
+
+  private Projector(JniWrapper wrapper, long moduleId, Schema schema, int numExprs) {
+    this.wrapper = wrapper;
+    this.moduleId = moduleId;
+    this.schema = schema;
+    this.numExprs = numExprs;
+    this.closed = false;
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the list of project expressions.
+   * Invoke Projector::Evaluate() against a RecordBatch to evaluate the record batch
+   * against these projections.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used
+   *               to create the TreeNodes
+   * @param exprs  List of expressions to be evaluated against data
+   *
+   * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+   */
+  public static Projector make(Schema schema, List<ExpressionTree> exprs)
+      throws GandivaException {
+    return make(schema, exprs, SelectionVectorType.SV_NONE, JniLoader.getDefaultConfiguration());
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the list of project expressions.
+   * Invoke Projector::Evaluate() against a RecordBatch to evaluate the record batch
+   * against these projections.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used
+   *               to create the TreeNodes
+   * @param exprs  List of expressions to be evaluated against data
+   * @param configOptions ConfigOptions parameter
+   *
+   * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+   */
+  public static Projector make(Schema schema, List<ExpressionTree> exprs,
+      ConfigurationBuilder.ConfigOptions configOptions) throws GandivaException {
+    return make(schema, exprs, SelectionVectorType.SV_NONE, JniLoader.getConfiguration(configOptions));
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the list of project expressions.
+   * Invoke Projector::Evaluate() against a RecordBatch to evaluate the record batch
+   * against these projections.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used
+   *               to create the TreeNodes
+   * @param exprs  List of expressions to be evaluated against data
+   * @param optimize Flag to choose if the generated llvm code is to be optimized
+   *
+   * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+   */
+  @Deprecated
+  public static Projector make(Schema schema, List<ExpressionTree> exprs, boolean optimize)
+      throws GandivaException {
+    return make(schema, exprs, SelectionVectorType.SV_NONE,
+        JniLoader.getConfiguration((new ConfigurationBuilder.ConfigOptions()).withOptimize(optimize)));
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the list of project expressions.
+   * Invoke Projector::Evaluate() against a RecordBatch to evaluate the record batch
+   * against these projections.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used
+   *               to create the TreeNodes
+   * @param exprs  List of expressions to be evaluated against data
+   * @param selectionVectorType type of selection vector
+   *
+   * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+   */
+  public static Projector make(Schema schema, List<ExpressionTree> exprs,
+      SelectionVectorType selectionVectorType)
+      throws GandivaException {
+    return make(schema, exprs, selectionVectorType, JniLoader.getDefaultConfiguration());
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the list of project expressions.
+   * Invoke Projector::Evaluate() against a RecordBatch to evaluate the record batch
+   * against these projections.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used
+   *               to create the TreeNodes
+   * @param exprs  List of expressions to be evaluated against data
+   * @param selectionVectorType type of selection vector
+   * @param configOptions ConfigOptions parameter
+   *
+   * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+   */
+  public static Projector make(Schema schema, List<ExpressionTree> exprs, SelectionVectorType selectionVectorType,
+      ConfigurationBuilder.ConfigOptions configOptions) throws GandivaException {
+    return make(schema, exprs, selectionVectorType, JniLoader.getConfiguration(configOptions));
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the list of project expressions.
+   * Invoke Projector::Evaluate() against a RecordBatch to evaluate the record batch
+   * against these projections.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used
+   *               to create the TreeNodes
+   * @param exprs  List of expressions to be evaluated against data
+   * @param selectionVectorType type of selection vector
+   * @param optimize Flag to choose if the generated llvm code is to be optimized
+   *
+   * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+   */
+  @Deprecated
+  public static Projector make(Schema schema, List<ExpressionTree> exprs,
+      SelectionVectorType selectionVectorType, boolean optimize)
+      throws GandivaException {
+    return make(schema, exprs, selectionVectorType,
+        JniLoader.getConfiguration((new ConfigurationBuilder.ConfigOptions()).withOptimize(optimize)));
+  }
+
+  /**
+   * Invoke this function to generate LLVM code to evaluate the list of project expressions.
+   * Invoke Projector::Evaluate() against a RecordBatch to evaluate the record batch
+   * against these projections.
+   *
+   * @param schema Table schema. The field names in the schema should match the fields used
+   *               to create the TreeNodes
+   * @param exprs  List of expressions to be evaluated against data
+   * @param selectionVectorType type of selection vector
+   * @param configurationId Custom configuration created through config builder.
+   *
+   * @return A native evaluator object that can be used to invoke these projections on a RecordBatch
+   */
+  public static Projector make(Schema schema, List<ExpressionTree> exprs,
+      SelectionVectorType selectionVectorType,
+      long configurationId) throws GandivaException {
+    // serialize the schema and the list of expressions as a protobuf
+    GandivaTypes.ExpressionList.Builder builder = GandivaTypes.ExpressionList.newBuilder();
+    for (ExpressionTree expr : exprs) {
+      builder.addExprs(expr.toProtobuf());
+    }
+
+    // Invoke the JNI layer to create the LLVM module representing the expressions
+    GandivaTypes.Schema schemaBuf = ArrowTypeHelper.arrowSchemaToProtobuf(schema);
+    JniWrapper wrapper = JniLoader.getInstance().getWrapper();
+    long moduleId = wrapper.buildProjector(schemaBuf.toByteArray(),
+        builder.build().toByteArray(), selectionVectorType.getNumber(), configurationId);
+    logger.debug("Created module for the projector with id {}", moduleId);
+    return new Projector(wrapper, moduleId, schema, exprs.size());
+  }
+
+  /**
+   * Invoke this function to evaluate a set of expressions against a recordBatch.
+   *
+   * @param recordBatch Record batch including the data
+   * @param outColumns Result of applying the project on the data
+   */
+  public void evaluate(ArrowRecordBatch recordBatch, List<ValueVector> outColumns)
+      throws GandivaException {
+    evaluate(recordBatch.getLength(), recordBatch.getBuffers(),
+        recordBatch.getBuffersLayout(),
+        SelectionVectorType.SV_NONE.getNumber(), recordBatch.getLength(),
+        0, 0, outColumns);
+  }
+
+  /**
+   * Invoke this function to evaluate a set of expressions against a set of arrow buffers.
+   * (this is an optimised version that skips taking references).
+   *
+   * @param numRows number of rows.
+   * @param buffers List of input arrow buffers
+   * @param outColumns Result of applying the project on the data
+   */
+  public void evaluate(int numRows, List<ArrowBuf> buffers,
+      List<ValueVector> outColumns) throws GandivaException {
+    evaluate(numRows, buffers, layoutOf(buffers),
+        SelectionVectorType.SV_NONE.getNumber(),
+        numRows, 0, 0, outColumns);
+  }
+
+  /**
+   * Invoke this function to evaluate a set of expressions against a {@link ArrowRecordBatch}.
+   *
+   * @param recordBatch The data to evaluate against.
+   * @param selectionVector Selection vector which stores the selected rows.
+   * @param outColumns Result of applying the project on the data
+   */
+  public void evaluate(ArrowRecordBatch recordBatch,
+      SelectionVector selectionVector, List<ValueVector> outColumns)
+      throws GandivaException {
+    evaluate(recordBatch.getLength(), recordBatch.getBuffers(),
+        recordBatch.getBuffersLayout(),
+        selectionVector.getType().getNumber(),
+        selectionVector.getRecordCount(),
+        selectionVector.getBuffer().memoryAddress(),
+        selectionVector.getBuffer().capacity(),
+        outColumns);
+  }
+
+  /**
+   * Invoke this function to evaluate a set of expressions against a set of arrow buffers
+   * on the selected positions.
+   * (this is an optimised version that skips taking references).
+   *
+   * @param numRows number of rows.
+   * @param buffers List of input arrow buffers
+   * @param selectionVector Selection vector which stores the selected rows.
+   * @param outColumns Result of applying the project on the data
+   */
+  public void evaluate(int numRows, List<ArrowBuf> buffers,
+      SelectionVector selectionVector,
+      List<ValueVector> outColumns) throws GandivaException {
+    evaluate(numRows, buffers, layoutOf(buffers),
+        selectionVector.getType().getNumber(),
+        selectionVector.getRecordCount(),
+        selectionVector.getBuffer().memoryAddress(),
+        selectionVector.getBuffer().capacity(),
+        outColumns);
+  }
+
+  /**
+   * Builds a layout for raw buffers: each entry carries a running offset and the buffer's
+   * readable size. Shared by the evaluate() overloads that take buffers without a layout.
+   */
+  private static List<ArrowBuffer> layoutOf(List<ArrowBuf> buffers) {
+    List<ArrowBuffer> buffersLayout = new ArrayList<>(buffers.size());
+    long offset = 0;
+    for (ArrowBuf arrowBuf : buffers) {
+      long size = arrowBuf.readableBytes();
+      buffersLayout.add(new ArrowBuffer(offset, size));
+      offset += size;
+    }
+    return buffersLayout;
+  }
+
+  /**
+   * Shared implementation behind the public evaluate() overloads: checks state and output
+   * column count, flattens input and output buffers into address/size arrays and invokes the
+   * native projector.
+   */
+  private void evaluate(int numRows, List<ArrowBuf> buffers, List<ArrowBuffer> buffersLayout,
+      int selectionVectorType, int selectionVectorRecordCount,
+      long selectionVectorAddr, long selectionVectorSize,
+      List<ValueVector> outColumns) throws GandivaException {
+    if (this.closed) {
+      throw new EvaluatorClosedException();
+    }
+
+    if (numExprs != outColumns.size()) {
+      // Logged at error level with placeholders, like the capacity check in Filter.
+      logger.error("Expected {} columns, got {}", numExprs, outColumns.size());
+      throw new GandivaException("Incorrect number of columns for the output vector");
+    }
+
+    long[] bufAddrs = new long[buffers.size()];
+    long[] bufSizes = new long[buffers.size()];
+
+    int idx = 0;
+    for (ArrowBuf buf : buffers) {
+      bufAddrs[idx++] = buf.memoryAddress();
+    }
+
+    idx = 0;
+    for (ArrowBuffer bufLayout : buffersLayout) {
+      bufSizes[idx++] = bufLayout.getSize();
+    }
+
+    boolean hasVariableWidthColumns = false;
+    BaseVariableWidthVector[] resizableVectors = new BaseVariableWidthVector[outColumns.size()];
+    // Up to three buffers per output column: validity, offsets (variable width only), data.
+    long[] outAddrs = new long[3 * outColumns.size()];
+    long[] outSizes = new long[3 * outColumns.size()];
+    idx = 0;
+    int outColumnIdx = 0;
+    for (ValueVector valueVector : outColumns) {
+      boolean isFixedWidth = valueVector instanceof FixedWidthVector;
+      boolean isVarWidth = valueVector instanceof VariableWidthVector;
+      if (!isFixedWidth && !isVarWidth) {
+        throw new UnsupportedTypeException(
+            "Unsupported value vector type " + valueVector.getField().getFieldType());
+      }
+
+      outAddrs[idx] = valueVector.getValidityBuffer().memoryAddress();
+      outSizes[idx++] = valueVector.getValidityBuffer().capacity();
+      if (isVarWidth) {
+        outAddrs[idx] = valueVector.getOffsetBuffer().memoryAddress();
+        outSizes[idx++] = valueVector.getOffsetBuffer().capacity();
+        hasVariableWidthColumns = true;
+
+        // save vector to allow for resizing.
+        resizableVectors[outColumnIdx] = (BaseVariableWidthVector) valueVector;
+      }
+      outAddrs[idx] = valueVector.getDataBuffer().memoryAddress();
+      outSizes[idx++] = valueVector.getDataBuffer().capacity();
+
+      valueVector.setValueCount(selectionVectorRecordCount);
+      outColumnIdx++;
+    }
+
+    wrapper.evaluateProjector(
+        hasVariableWidthColumns ? new VectorExpander(resizableVectors) : null,
+        this.moduleId, numRows, bufAddrs, bufSizes,
+        selectionVectorType, selectionVectorRecordCount,
+        selectionVectorAddr, selectionVectorSize,
+        outAddrs, outSizes);
+  }
+
+  /**
+   * Closes the LLVM module representing this evaluator. Calling it again is a no-op.
+   */
+  public void close() throws GandivaException {
+    if (this.closed) {
+      return;
+    }
+
+    wrapper.closeProjector(this.moduleId);
+    this.closed = true;
+  }
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVector.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVector.java
new file mode 100644
index 000000000..2af88b526
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVector.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.gandiva.evaluator;
+
+import static org.apache.arrow.memory.util.LargeMemoryUtil.capAtMaxInt;
+
+import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType;
+import org.apache.arrow.memory.ArrowBuf;
+
+/**
+ * A selection vector contains the indexes of "selected" records in a row batch. It is backed by an
+ * arrow buffer.
+ * Client manages the lifecycle of the arrow buffer - to release the reference.
+ */
+public abstract class SelectionVector {
+  private int recordCount;
+  private final ArrowBuf buffer;
+
+  public SelectionVector(ArrowBuf buffer) {
+    this.buffer = buffer;
+  }
+
+  public final ArrowBuf getBuffer() {
+    return this.buffer;
+  }
+
+  /**
+   * The maximum number of records that the selection vector can hold.
+   */
+  public final int getMaxRecords() {
+    return capAtMaxInt(buffer.capacity() / getRecordSize());
+  }
+
+  /**
+   * The number of records held by the selection vector.
+   */
+  public final int getRecordCount() {
+    return this.recordCount;
+  }
+
+  /**
+   * Set the number of records in the selection vector.
+   *
+   * @throws IllegalArgumentException if the records would not fit in the backing buffer
+   */
+  final void setRecordCount(int recordCount) {
+    // Widen before multiplying: an int product can wrap around and let a count through
+    // that does not actually fit in the buffer.
+    if ((long) recordCount * getRecordSize() > buffer.capacity()) {
+      throw new IllegalArgumentException("recordCount " + recordCount +
+          " of size " + getRecordSize() +
+          " exceeds buffer capacity " + buffer.capacity());
+    }
+
+    this.recordCount = recordCount;
+  }
+
+  /**
+   * Get the value at specified index.
+   */
+  public abstract int getIndex(int index);
+
+  /**
+   * Get the record size of the selection vector itself.
+   */
+  abstract int getRecordSize();
+
+  abstract SelectionVectorType getType();
+
+  /**
+   * Rejects indexes outside [0, recordCount) before a subclass reads from the buffer.
+   *
+   * @throws IllegalArgumentException if the index is negative or not below the record count
+   */
+  final void checkReadBounds(int index) {
+    if (index < 0) {
+      throw new IllegalArgumentException("index " + index + " is negative");
+    }
+    if (index >= this.recordCount) {
+      throw new IllegalArgumentException("index " + index + " is >= recordCount " + recordCount);
+    }
+  }
+
+}
diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt16.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt16.java
new file mode 100644
index 000000000..84c795b67
--- /dev/null
+++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt16.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType; +import org.apache.arrow.memory.ArrowBuf; + +/** + * Selection vector with records of arrow type INT16. + */ +public class SelectionVectorInt16 extends SelectionVector { + + public SelectionVectorInt16(ArrowBuf buffer) { + super(buffer); + } + + @Override + public int getRecordSize() { + return 2; + } + + @Override + public SelectionVectorType getType() { + return SelectionVectorType.SV_INT16; + } + + @Override + public int getIndex(int index) { + checkReadBounds(index); + + char value = getBuffer().getChar(index * getRecordSize()); + return (int) value; + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt32.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt32.java new file mode 100644 index 000000000..c938f6691 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/SelectionVectorInt32.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType; +import org.apache.arrow.memory.ArrowBuf; + +/** + * Selection vector with records of arrow type INT32. + */ +public class SelectionVectorInt32 extends SelectionVector { + + public SelectionVectorInt32(ArrowBuf buffer) { + super(buffer); + } + + @Override + public int getRecordSize() { + return 4; + } + + @Override + public SelectionVectorType getType() { + return SelectionVectorType.SV_INT32; + } + + @Override + public int getIndex(int index) { + checkReadBounds(index); + + return getBuffer().getInt(index * getRecordSize()); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java new file mode 100644 index 000000000..f22ebbd37 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/VectorExpander.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import org.apache.arrow.vector.BaseVariableWidthVector; + +/** + * This class provides the functionality to expand output vectors using a callback mechanism from + * gandiva. + */ +public class VectorExpander { + private final BaseVariableWidthVector[] vectors; + + public VectorExpander(BaseVariableWidthVector[] vectors) { + this.vectors = vectors; + } + + /** + * Result of vector expansion. + */ + public static class ExpandResult { + public long address; + public long capacity; + + public ExpandResult(long address, long capacity) { + this.address = address; + this.capacity = capacity; + } + } + + /** + * Expand vector at specified index. This is used as a back call from jni, and is only + * relevant for variable width vectors. + * + * @param index index of buffer in the list passed to jni. + * @param toCapacity the size to which the buffer should be expanded to. + * + * @return address and size of the buffer after expansion. + */ + public ExpandResult expandOutputVectorAtIndex(int index, long toCapacity) { + if (index >= vectors.length || vectors[index] == null) { + throw new IllegalArgumentException("invalid index " + index); + } + + BaseVariableWidthVector vector = vectors[index]; + while (vector.getDataBuffer().capacity() < toCapacity) { + vector.reallocDataBuffer(); + } + return new ExpandResult( + vector.getDataBuffer().memoryAddress(), + vector.getDataBuffer().capacity()); + } + +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/EvaluatorClosedException.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/EvaluatorClosedException.java new file mode 100644 index 000000000..d3fb8b60d --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/EvaluatorClosedException.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.exceptions; + +/** Indicates an attempted call to methods on a closed evaluator. */ +public class EvaluatorClosedException extends GandivaException { + public EvaluatorClosedException() { + super("Cannot invoke methods on evaluator after closing it"); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java new file mode 100644 index 000000000..e7fce58a3 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/**
 * Root of the exception hierarchy used by this package; the specialized exceptions extend it.
 */
public class GandivaException extends Exception {

  /**
   * Creates an exception that carries only a message.
   *
   * @param msg description of the failure
   */
  public GandivaException(final String msg) {
    super(msg);
  }

  /**
   * Creates an exception that carries a message and the exception that triggered it.
   *
   * @param msg description of the failure
   * @param cause the underlying exception
   */
  public GandivaException(final String msg, final Exception cause) {
    super(msg, cause);
  }

  /**
   * Returns just the message, without the class-name prefix that {@link Throwable#toString()}
   * would add.
   */
  @Override
  public String toString() {
    final String message = getMessage();
    return message;
  }
}
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.exceptions; + +/** + * Represents an exception thrown while dealing with unsupported types. + */ +public class UnsupportedTypeException extends GandivaException { + public UnsupportedTypeException(String msg) { + super(msg); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/AndNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/AndNode.java new file mode 100644 index 000000000..ecc577fa7 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/AndNode.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import java.util.List; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Node representing a logical And expression. 
+ */ +class AndNode implements TreeNode { + private final List<TreeNode> children; + + AndNode(List<TreeNode> children) { + this.children = children; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.AndNode.Builder andNode = GandivaTypes.AndNode.newBuilder(); + + for (TreeNode arg : children) { + andNode.addArgs(arg.toProtobuf()); + } + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setAndNode(andNode.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java new file mode 100644 index 000000000..90f8684b4 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.flatbuf.DateUnit; +import org.apache.arrow.flatbuf.IntervalUnit; +import org.apache.arrow.flatbuf.TimeUnit; +import org.apache.arrow.flatbuf.Type; +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.exceptions.UnsupportedTypeException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Utility methods to convert between Arrow and Gandiva types. + */ +public class ArrowTypeHelper { + private ArrowTypeHelper() {} + + static final int WIDTH_8 = 8; + static final int WIDTH_16 = 16; + static final int WIDTH_32 = 32; + static final int WIDTH_64 = 64; + + private static void initArrowTypeInt( + ArrowType.Int intType, GandivaTypes.ExtGandivaType.Builder builder) throws GandivaException { + int width = intType.getBitWidth(); + + if (intType.getIsSigned()) { + switch (width) { + case WIDTH_8: { + builder.setType(GandivaTypes.GandivaType.INT8); + return; + } + case WIDTH_16: { + builder.setType(GandivaTypes.GandivaType.INT16); + return; + } + case WIDTH_32: { + builder.setType(GandivaTypes.GandivaType.INT32); + return; + } + case WIDTH_64: { + builder.setType(GandivaTypes.GandivaType.INT64); + return; + } + default: { + throw new UnsupportedTypeException("Unsupported width for integer type"); + } + } + } + + // unsigned int + switch (width) { + case WIDTH_8: { + builder.setType(GandivaTypes.GandivaType.UINT8); + return; + } + case WIDTH_16: { + builder.setType(GandivaTypes.GandivaType.UINT16); + return; + } + case WIDTH_32: { + builder.setType(GandivaTypes.GandivaType.UINT32); + return; + } + case WIDTH_64: { + builder.setType(GandivaTypes.GandivaType.UINT64); + return; + } + default: { + throw new UnsupportedTypeException("Unsupported width for 
integer type"); + } + } + } + + private static void initArrowTypeFloat( + ArrowType.FloatingPoint floatType, GandivaTypes.ExtGandivaType.Builder builder) + throws GandivaException { + switch (floatType.getPrecision()) { + case HALF: { + builder.setType(GandivaTypes.GandivaType.HALF_FLOAT); + break; + } + case SINGLE: { + builder.setType(GandivaTypes.GandivaType.FLOAT); + break; + } + case DOUBLE: { + builder.setType(GandivaTypes.GandivaType.DOUBLE); + break; + } + default: { + throw new UnsupportedTypeException("Floating point type with unknown precision"); + } + } + } + + private static void initArrowTypeDecimal(ArrowType.Decimal decimalType, + GandivaTypes.ExtGandivaType.Builder builder) { + Preconditions.checkArgument(decimalType.getPrecision() > 0 && + decimalType.getPrecision() <= 38, "Gandiva only supports decimals of upto 38 " + + "precision. Input precision : " + decimalType.getPrecision()); + builder.setPrecision(decimalType.getPrecision()); + builder.setScale(decimalType.getScale()); + builder.setType(GandivaTypes.GandivaType.DECIMAL); + } + + private static void initArrowTypeDate(ArrowType.Date dateType, + GandivaTypes.ExtGandivaType.Builder builder) { + short dateUnit = dateType.getUnit().getFlatbufID(); + switch (dateUnit) { + case DateUnit.DAY: { + builder.setType(GandivaTypes.GandivaType.DATE32); + break; + } + case DateUnit.MILLISECOND: { + builder.setType(GandivaTypes.GandivaType.DATE64); + break; + } + default: { + // not supported + break; + } + } + } + + private static void initArrowTypeTime(ArrowType.Time timeType, + GandivaTypes.ExtGandivaType.Builder builder) { + short timeUnit = timeType.getUnit().getFlatbufID(); + switch (timeUnit) { + case TimeUnit.SECOND: { + builder.setType(GandivaTypes.GandivaType.TIME32); + builder.setTimeUnit(GandivaTypes.TimeUnit.SEC); + break; + } + case TimeUnit.MILLISECOND: { + builder.setType(GandivaTypes.GandivaType.TIME32); + builder.setTimeUnit(GandivaTypes.TimeUnit.MILLISEC); + break; + } + case 
TimeUnit.MICROSECOND: { + builder.setType(GandivaTypes.GandivaType.TIME64); + builder.setTimeUnit(GandivaTypes.TimeUnit.MICROSEC); + break; + } + case TimeUnit.NANOSECOND: { + builder.setType(GandivaTypes.GandivaType.TIME64); + builder.setTimeUnit(GandivaTypes.TimeUnit.NANOSEC); + break; + } + default: { + // not supported + } + } + } + + private static void initArrowTypeTimestamp(ArrowType.Timestamp timestampType, + GandivaTypes.ExtGandivaType.Builder builder) { + short timeUnit = timestampType.getUnit().getFlatbufID(); + switch (timeUnit) { + case TimeUnit.SECOND: { + builder.setType(GandivaTypes.GandivaType.TIMESTAMP); + builder.setTimeUnit(GandivaTypes.TimeUnit.SEC); + break; + } + case TimeUnit.MILLISECOND: { + builder.setType(GandivaTypes.GandivaType.TIMESTAMP); + builder.setTimeUnit(GandivaTypes.TimeUnit.MILLISEC); + break; + } + case TimeUnit.MICROSECOND: { + builder.setType(GandivaTypes.GandivaType.TIMESTAMP); + builder.setTimeUnit(GandivaTypes.TimeUnit.MICROSEC); + break; + } + case TimeUnit.NANOSECOND: { + builder.setType(GandivaTypes.GandivaType.TIMESTAMP); + builder.setTimeUnit(GandivaTypes.TimeUnit.NANOSEC); + break; + } + default: { + // not supported + } + } + } + + private static void initArrowTypeInterval(ArrowType.Interval interval, + GandivaTypes.ExtGandivaType.Builder builder) { + short intervalUnit = interval.getUnit().getFlatbufID(); + switch (intervalUnit) { + case IntervalUnit.YEAR_MONTH: { + builder.setType(GandivaTypes.GandivaType.INTERVAL); + builder.setIntervalType(GandivaTypes.IntervalType.YEAR_MONTH); + break; + } + case IntervalUnit.DAY_TIME: { + builder.setType(GandivaTypes.GandivaType.INTERVAL); + builder.setIntervalType(GandivaTypes.IntervalType.DAY_TIME); + break; + } + default: { + // not supported + } + } + } + + /** + * Converts an arrow type into a protobuf. 
+ * + * @param arrowType Arrow type to be converted + * @return Protobuf representing the arrow type + */ + public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType) + throws GandivaException { + GandivaTypes.ExtGandivaType.Builder builder = GandivaTypes.ExtGandivaType.newBuilder(); + + byte typeId = arrowType.getTypeID().getFlatbufID(); + switch (typeId) { + case Type.NONE: { // 0 + builder.setType(GandivaTypes.GandivaType.NONE); + break; + } + case Type.Null: { // 1 + // TODO: Need to handle this later + break; + } + case Type.Int: { // 2 + ArrowTypeHelper.initArrowTypeInt((ArrowType.Int) arrowType, builder); + break; + } + case Type.FloatingPoint: { // 3 + ArrowTypeHelper.initArrowTypeFloat((ArrowType.FloatingPoint) arrowType, builder); + break; + } + case Type.Binary: { // 4 + builder.setType(GandivaTypes.GandivaType.BINARY); + break; + } + case Type.Utf8: { // 5 + builder.setType(GandivaTypes.GandivaType.UTF8); + break; + } + case Type.Bool: { // 6 + builder.setType(GandivaTypes.GandivaType.BOOL); + break; + } + case Type.Decimal: { // 7 + ArrowTypeHelper.initArrowTypeDecimal((ArrowType.Decimal) arrowType, builder); + break; + } + case Type.Date: { // 8 + ArrowTypeHelper.initArrowTypeDate((ArrowType.Date) arrowType, builder); + break; + } + case Type.Time: { // 9 + ArrowTypeHelper.initArrowTypeTime((ArrowType.Time) arrowType, builder); + break; + } + case Type.Timestamp: { // 10 + ArrowTypeHelper.initArrowTypeTimestamp((ArrowType.Timestamp) arrowType, builder); + break; + } + case Type.Interval: { // 11 + ArrowTypeHelper.initArrowTypeInterval((ArrowType.Interval) arrowType, builder); + break; + } + case Type.List: { // 12 + break; + } + case Type.Struct_: { // 13 + break; + } + case Type.Union: { // 14 + break; + } + case Type.FixedSizeBinary: { // 15 + break; + } + case Type.FixedSizeList: { // 16 + break; + } + case Type.Map: { // 17 + break; + } + default: { + break; + } + } + + if (!builder.hasType()) { + // type has not been set + 
// throw an exception + throw new UnsupportedTypeException("Unsupported type " + arrowType.toString()); + } + + return builder.build(); + } + + /** + * Converts an arrow field object to a protobuf. + * @param field Arrow field to be converted + * @return Protobuf representing the arrow field + */ + public static GandivaTypes.Field arrowFieldToProtobuf(Field field) throws GandivaException { + GandivaTypes.Field.Builder builder = GandivaTypes.Field.newBuilder(); + builder.setName(field.getName()); + builder.setType(ArrowTypeHelper.arrowTypeToProtobuf(field.getType())); + builder.setNullable(field.isNullable()); + + for (Field child : field.getChildren()) { + builder.addChildren(ArrowTypeHelper.arrowFieldToProtobuf(child)); + } + + return builder.build(); + } + + /** + * Converts a schema object to a protobuf. + * @param schema Schema object to be converted + * @return Protobuf representing a schema object + */ + public static GandivaTypes.Schema arrowSchemaToProtobuf(Schema schema) throws GandivaException { + GandivaTypes.Schema.Builder builder = GandivaTypes.Schema.newBuilder(); + + for (Field field : schema.getFields()) { + builder.addColumns(ArrowTypeHelper.arrowFieldToProtobuf(field)); + } + + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BinaryNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BinaryNode.java new file mode 100644 index 000000000..8455f29b2 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BinaryNode.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +import com.google.protobuf.ByteString; + +/** + * Used to represent expression tree nodes representing binary constants. + */ +class BinaryNode implements TreeNode { + private final byte[] value; + + public BinaryNode(byte[] value) { + this.value = value; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.BinaryNode binaryNode = GandivaTypes.BinaryNode.newBuilder() + .setValue(ByteString.copyFrom(value)) + .build(); + + return GandivaTypes.TreeNode.newBuilder() + .setBinaryNode(binaryNode) + .build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BooleanNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BooleanNode.java new file mode 100644 index 000000000..505f01919 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/BooleanNode.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Used to represent expression tree nodes representing boolean constants. + * Used while creating expressions like if (!x). + */ +class BooleanNode implements TreeNode { + private final Boolean value; + + BooleanNode(Boolean value) { + this.value = value; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.BooleanNode.Builder boolBuilder = GandivaTypes.BooleanNode.newBuilder(); + boolBuilder.setValue(value.booleanValue()); + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setBooleanNode(boolBuilder.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/Condition.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/Condition.java new file mode 100644 index 000000000..1d584d673 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/Condition.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Opaque class representing a filter condition. + */ +public class Condition { + private final TreeNode root; + + Condition(TreeNode root) { + this.root = root; + } + + /** + * Converts an condition expression into a protobuf. + * @return A protobuf representing the condition expression tree + */ + public GandivaTypes.Condition toProtobuf() throws GandivaException { + GandivaTypes.Condition.Builder builder = GandivaTypes.Condition.newBuilder(); + builder.setRoot(root.toProtobuf()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java new file mode 100644 index 000000000..bf17aa0aa --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DecimalNode.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Used to represent expression tree nodes representing decimal constants. + * Used in the expression (x + 5.0) + */ +class DecimalNode implements TreeNode { + private final String value; + private final int precision; + private final int scale; + + DecimalNode(String value, int precision, int scale) { + this.value = value; + this.precision = precision; + this.scale = scale; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.DecimalNode.Builder decimalNode = GandivaTypes.DecimalNode.newBuilder(); + decimalNode.setValue(value); + decimalNode.setPrecision(precision); + decimalNode.setScale(scale); + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setDecimalNode(decimalNode.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DoubleNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DoubleNode.java new file mode 100644 index 000000000..f7a9436f1 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/DoubleNode.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Used to represent expression tree nodes representing double constants. + * Used in the expression (x + 5.0) + */ +class DoubleNode implements TreeNode { + private final Double value; + + DoubleNode(Double value) { + this.value = value; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.DoubleNode.Builder doubleBuilder = GandivaTypes.DoubleNode.newBuilder(); + doubleBuilder.setValue(value.doubleValue()); + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setDoubleNode(doubleBuilder.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ExpressionTree.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ExpressionTree.java new file mode 100644 index 000000000..353c8d12b --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ExpressionTree.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.vector.types.pojo.Field; + +/** + * Opaque class representing an expression. + */ +public class ExpressionTree { + private final TreeNode root; + private final Field resultField; + + ExpressionTree(TreeNode root, Field resultField) { + this.root = root; + this.resultField = resultField; + } + + /** + * Converts an expression tree into a protobuf. + * @return A protobuf representing the expression tree + */ + public GandivaTypes.ExpressionRoot toProtobuf() throws GandivaException { + GandivaTypes.ExpressionRoot.Builder builder = GandivaTypes.ExpressionRoot.newBuilder(); + builder.setRoot(root.toProtobuf()); + builder.setResultType(ArrowTypeHelper.arrowFieldToProtobuf(resultField)); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FieldNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FieldNode.java new file mode 100644 index 000000000..893bf7191 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FieldNode.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.vector.types.pojo.Field; + +/** + * Opaque class that represents a tree node. + */ +class FieldNode implements TreeNode { + private final Field field; + + FieldNode(Field field) { + this.field = field; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.FieldNode.Builder fieldNode = GandivaTypes.FieldNode.newBuilder(); + fieldNode.setField(ArrowTypeHelper.arrowFieldToProtobuf(field)); + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setFieldNode(fieldNode.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FloatNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FloatNode.java new file mode 100644 index 000000000..6afe96bfe --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FloatNode.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Used to represent expression tree nodes representing float constants. + * Used in the expression (x + 5.0) + */ +class FloatNode implements TreeNode { + private final Float value; + + public FloatNode(Float value) { + this.value = value; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.FloatNode.Builder floatBuilder = GandivaTypes.FloatNode.newBuilder(); + floatBuilder.setValue(value.floatValue()); + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setFloatNode(floatBuilder.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java new file mode 100644 index 000000000..ead1e146d --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import java.util.List; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.vector.types.pojo.ArrowType; + +/** + * Node representing an arbitrary function in an expression. 
+ */ +class FunctionNode implements TreeNode { + private final String function; + private final List<TreeNode> children; + private final ArrowType retType; + + FunctionNode(String function, List<TreeNode> children, ArrowType retType) { + this.function = function; + this.children = children; + this.retType = retType; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.FunctionNode.Builder fnNode = GandivaTypes.FunctionNode.newBuilder(); + fnNode.setFunctionName(function); + fnNode.setReturnType(ArrowTypeHelper.arrowTypeToProtobuf(retType)); + + for (TreeNode arg : children) { + fnNode.addInArgs(arg.toProtobuf()); + } + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setFnNode(fnNode.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IfNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IfNode.java new file mode 100644 index 000000000..19f9095fb --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IfNode.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.vector.types.pojo.ArrowType; + +/** + * Node representing a if-then-else block expression. + */ +class IfNode implements TreeNode { + private final TreeNode condition; + private final TreeNode thenNode; + private final TreeNode elseNode; + private final ArrowType retType; + + IfNode(TreeNode condition, TreeNode thenNode, TreeNode elseNode, ArrowType retType) { + this.condition = condition; + this.thenNode = thenNode; + this.elseNode = elseNode; + this.retType = retType; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.IfNode.Builder ifNodeBuilder = GandivaTypes.IfNode.newBuilder(); + ifNodeBuilder.setCond(condition.toProtobuf()); + ifNodeBuilder.setThenNode(thenNode.toProtobuf()); + ifNodeBuilder.setElseNode(elseNode.toProtobuf()); + ifNodeBuilder.setReturnType(ArrowTypeHelper.arrowTypeToProtobuf(retType)); + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setIfNode(ifNodeBuilder.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java new file mode 100644 index 000000000..0f8de9628 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import java.math.BigDecimal; +import java.nio.charset.Charset; +import java.util.Set; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +import com.google.protobuf.ByteString; + +/** + * In Node representation in java. + */ +public class InNode implements TreeNode { + private static final Charset charset = Charset.forName("UTF-8"); + + private final Set<Integer> intValues; + private final Set<Long> longValues; + private final Set<Float> floatValues; + private final Set<Double> doubleValues; + private final Set<BigDecimal> decimalValues; + private final Set<String> stringValues; + private final Set<byte[]> binaryValues; + private final TreeNode input; + + private final Integer precision; + private final Integer scale; + + private InNode(Set<Integer> values, Set<Long> longValues, Set<String> stringValues, Set<byte[]> + binaryValues, Set<BigDecimal> decimalValues, Integer precision, Integer scale, + Set<Float> floatValues, Set<Double> doubleValues, TreeNode node) { + this.intValues = values; + this.longValues = longValues; + this.decimalValues = decimalValues; + this.precision = precision; + this.scale = scale; + this.stringValues = stringValues; + this.binaryValues = binaryValues; + this.floatValues = floatValues; + this.doubleValues = doubleValues; + this.input = node; + } + + /** + * Makes an IN node for int values. + * + * @param node Node with the 'IN' clause. + * @param intValues Int values to build the IN node. 
+ * @retur InNode referring to tree node. + */ + public static InNode makeIntInExpr(TreeNode node, Set<Integer> intValues) { + return new InNode(intValues, + null, null, null, null, null, null, null, + null, node); + } + + /** + * Makes an IN node for long values. + * + * @param node Node with the 'IN' clause. + * @param longValues Long values to build the IN node. + * @retur InNode referring to tree node. + */ + public static InNode makeLongInExpr(TreeNode node, Set<Long> longValues) { + return new InNode(null, longValues, + null, null, null, null, null, null, + null, node); + } + + /** + * Makes an IN node for float values. + * + * @param node Node with the 'IN' clause. + * @param floatValues Float values to build the IN node. + * @retur InNode referring to tree node. + */ + public static InNode makeFloatInExpr(TreeNode node, Set<Float> floatValues) { + return new InNode(null, null, null, null, null, null, + null, floatValues, null, node); + } + + /** + * Makes an IN node for double values. + * + * @param node Node with the 'IN' clause. + * @param doubleValues Double values to build the IN node. + * @retur InNode referring to tree node. 
+ */ + public static InNode makeDoubleInExpr(TreeNode node, Set<Double> doubleValues) { + return new InNode(null, null, null, null, null, + null, null, null, doubleValues, node); + } + + public static InNode makeDecimalInExpr(TreeNode node, Set<BigDecimal> decimalValues, + Integer precision, Integer scale) { + return new InNode(null, null, null, null, + decimalValues, precision, scale, null, null, node); + } + + public static InNode makeStringInExpr(TreeNode node, Set<String> stringValues) { + return new InNode(null, null, stringValues, null, + null, null, null, null, null, node); + } + + public static InNode makeBinaryInExpr(TreeNode node, Set<byte[]> binaryValues) { + return new InNode(null, null, null, binaryValues, + null, null, null, null, null, node); + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.InNode.Builder inNode = GandivaTypes.InNode.newBuilder(); + + inNode.setNode(input.toProtobuf()); + + if (intValues != null) { + GandivaTypes.IntConstants.Builder intConstants = GandivaTypes.IntConstants.newBuilder(); + intValues.stream().forEach(val -> intConstants.addIntValues(GandivaTypes.IntNode.newBuilder() + .setValue(val).build())); + inNode.setIntValues(intConstants.build()); + } else if (longValues != null) { + GandivaTypes.LongConstants.Builder longConstants = GandivaTypes.LongConstants.newBuilder(); + longValues.stream().forEach(val -> longConstants.addLongValues(GandivaTypes.LongNode.newBuilder() + .setValue(val).build())); + inNode.setLongValues(longConstants.build()); + } else if (floatValues != null) { + GandivaTypes.FloatConstants.Builder floatConstants = GandivaTypes.FloatConstants.newBuilder(); + floatValues.stream().forEach(val -> floatConstants.addFloatValues(GandivaTypes.FloatNode.newBuilder() + .setValue(val).build())); + inNode.setFloatValues(floatConstants.build()); + } else if (doubleValues != null) { + GandivaTypes.DoubleConstants.Builder doubleConstants = 
GandivaTypes.DoubleConstants.newBuilder(); + doubleValues.stream().forEach(val -> doubleConstants.addDoubleValues(GandivaTypes.DoubleNode.newBuilder() + .setValue(val).build())); + inNode.setDoubleValues(doubleConstants.build()); + } else if (decimalValues != null) { + GandivaTypes.DecimalConstants.Builder decimalConstants = GandivaTypes.DecimalConstants.newBuilder(); + decimalValues.stream().forEach(val -> decimalConstants.addDecimalValues(GandivaTypes.DecimalNode.newBuilder() + .setValue(val.toPlainString()).setPrecision(precision).setScale(scale).build())); + inNode.setDecimalValues(decimalConstants.build()); + } else if (stringValues != null) { + GandivaTypes.StringConstants.Builder stringConstants = GandivaTypes.StringConstants + .newBuilder(); + stringValues.stream().forEach(val -> stringConstants.addStringValues(GandivaTypes.StringNode + .newBuilder().setValue(ByteString.copyFrom(val.getBytes(charset))).build())); + inNode.setStringValues(stringConstants.build()); + } else if (binaryValues != null) { + GandivaTypes.BinaryConstants.Builder binaryConstants = GandivaTypes.BinaryConstants + .newBuilder(); + binaryValues.stream().forEach(val -> binaryConstants.addBinaryValues(GandivaTypes.BinaryNode + .newBuilder().setValue(ByteString.copyFrom(val)).build())); + inNode.setBinaryValues(binaryConstants.build()); + } + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setInNode(inNode.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IntNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IntNode.java new file mode 100644 index 000000000..c3858ef7e --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/IntNode.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Used to represent expression tree nodes representing int constants. + * Used in the expression (x + 5) + */ +class IntNode implements TreeNode { + private final Integer value; + + IntNode(Integer value) { + this.value = value; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.IntNode.Builder intBuilder = GandivaTypes.IntNode.newBuilder(); + intBuilder.setValue(value.intValue()); + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setIntNode(intBuilder.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/LongNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/LongNode.java new file mode 100644 index 000000000..311c5d94d --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/LongNode.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Used to represent expression tree nodes representing long constants. + * Used in the expression (x + 5L) + */ +class LongNode implements TreeNode { + private final Long value; + + LongNode(Long value) { + this.value = value; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.LongNode.Builder longBuilder = GandivaTypes.LongNode.newBuilder(); + longBuilder.setValue(value.longValue()); + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setLongNode(longBuilder.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/NullNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/NullNode.java new file mode 100644 index 000000000..a8e7d6f82 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/NullNode.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.vector.types.pojo.ArrowType; + +/** An expression indicating a null value. */ +class NullNode implements TreeNode { + private final ArrowType type; + + NullNode(ArrowType type) { + this.type = type; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.NullNode.Builder nullNode = GandivaTypes.NullNode.newBuilder(); + nullNode.setType(ArrowTypeHelper.arrowTypeToProtobuf(type)); + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setNullNode(nullNode.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/OrNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/OrNode.java new file mode 100644 index 000000000..2dbdfed7c --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/OrNode.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import java.util.List; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Represents a logical OR Node. + */ +class OrNode implements TreeNode { + private final List<TreeNode> children; + + OrNode(List<TreeNode> children) { + this.children = children; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.OrNode.Builder orNode = GandivaTypes.OrNode.newBuilder(); + + for (TreeNode arg : children) { + orNode.addArgs(arg.toProtobuf()); + } + + GandivaTypes.TreeNode.Builder builder = GandivaTypes.TreeNode.newBuilder(); + builder.setOrNode(orNode.build()); + return builder.build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/StringNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/StringNode.java new file mode 100644 index 000000000..a44329739 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/StringNode.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import java.nio.charset.Charset; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +import com.google.protobuf.ByteString; + +/** + * Used to represent expression tree nodes representing utf8 constants. 
+ */ +class StringNode implements TreeNode { + private static final Charset charset = Charset.forName("UTF-8"); + private final String value; + + public StringNode(String value) { + this.value = value; + } + + @Override + public GandivaTypes.TreeNode toProtobuf() throws GandivaException { + GandivaTypes.StringNode stringNode = GandivaTypes.StringNode.newBuilder() + .setValue(ByteString.copyFrom(value.getBytes(charset))) + .build(); + + return GandivaTypes.TreeNode.newBuilder() + .setStringNode(stringNode) + .build(); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java new file mode 100644 index 000000000..8656e886a --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.expression; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; + +/** + * Contains helper functions for constructing expression trees. + */ +public class TreeBuilder { + private TreeBuilder() {} + + /** + * Helper functions to create literal constants. + */ + public static TreeNode makeLiteral(Boolean booleanConstant) { + return new BooleanNode(booleanConstant); + } + + public static TreeNode makeLiteral(Float floatConstant) { + return new FloatNode(floatConstant); + } + + public static TreeNode makeLiteral(Double doubleConstant) { + return new DoubleNode(doubleConstant); + } + + public static TreeNode makeLiteral(Integer integerConstant) { + return new IntNode(integerConstant); + } + + public static TreeNode makeLiteral(Long longConstant) { + return new LongNode(longConstant); + } + + public static TreeNode makeStringLiteral(String stringConstant) { + return new StringNode(stringConstant); + } + + public static TreeNode makeBinaryLiteral(byte[] binaryConstant) { + return new BinaryNode(binaryConstant); + } + + public static TreeNode makeDecimalLiteral(String decimalConstant, int precision, int scale) { + return new DecimalNode(decimalConstant, precision, scale); + } + + /** + * create a null literal. + */ + public static TreeNode makeNull(ArrowType type) { + return new NullNode(type); + } + + /** + * Invoke this function to create a node representing a field, e.g. a column name. + * + * @param field represents the input argument - includes the name and type of the field + * @return Node representing a field + */ + public static TreeNode makeField(Field field) { + return new FieldNode(field); + } + + /** + * Invoke this function to create a node representing a function. + * + * @param function Name of the function, e.g. 
add + * @param children The arguments to the function + * @param retType The type of the return value of the operator + * @return Node representing a function + */ + public static TreeNode makeFunction(String function, + List<TreeNode> children, + ArrowType retType) { + return new FunctionNode(function, children, retType); + } + + /** + * Invoke this function to create a node representing an if-clause. + * + * @param condition Node representing the condition + * @param thenNode Node representing the if-block + * @param elseNode Node representing the else-block + * @param retType Return type of the node + * @return Node representing an if-clause + */ + public static TreeNode makeIf(TreeNode condition, + TreeNode thenNode, + TreeNode elseNode, + ArrowType retType) { + return new IfNode(condition, thenNode, elseNode, retType); + } + + /** + * Invoke this function to create a node representing an and-clause. + * + * @param nodes Nodes in the 'and' clause. + * @return Node representing an and-clause + */ + public static TreeNode makeAnd(List<TreeNode> nodes) { + return new AndNode(nodes); + } + + /** + * Invoke this function to create a node representing an or-clause. + * + * @param nodes Nodes in the 'or' clause. + * @return Node representing an or-clause + */ + public static TreeNode makeOr(List<TreeNode> nodes) { + return new OrNode(nodes); + } + + /** + * Invoke this function to create an expression tree. + * + * @param root is returned by a call to MakeField, MakeFunction, or MakeIf + * @param resultField represents the return value of the expression + * @return ExpressionTree referring to the root of an expression tree + */ + public static ExpressionTree makeExpression(TreeNode root, + Field resultField) { + return new ExpressionTree(root, resultField); + } + + /** + * Short cut to create an expression tree involving a single function, e.g. a+b+c. + * + * @param function Name of the function, e.g. 
add() + * @param inFields In arguments to the function + * @param resultField represents the return value of the expression + * @return ExpressionTree referring to the root of an expression tree + */ + public static ExpressionTree makeExpression(String function, + List<Field> inFields, + Field resultField) { + List<TreeNode> children = new ArrayList<TreeNode>(inFields.size()); + for (Field field : inFields) { + children.add(makeField(field)); + } + + TreeNode root = makeFunction(function, children, resultField.getType()); + return makeExpression(root, resultField); + } + + /** + * Invoke this function to create a condition. + * + * @param root is returned by a call to MakeField, MakeFunction, MakeIf, .. + * @return condition referring to the root of an expression tree + */ + public static Condition makeCondition(TreeNode root) { + return new Condition(root); + } + + /** + * Short cut to create an expression tree involving a single function, e.g. a+b+c. + * + * @param function Name of the function, e.g. 
add() + * @param inFields In arguments to the function + * @return condition referring to the root of an expression tree + */ + public static Condition makeCondition(String function, + List<Field> inFields) { + List<TreeNode> children = new ArrayList<>(inFields.size()); + for (Field field : inFields) { + children.add(makeField(field)); + } + + TreeNode root = makeFunction(function, children, new ArrowType.Bool()); + return makeCondition(root); + } + + public static TreeNode makeInExpressionInt32(TreeNode resultNode, + Set<Integer> intValues) { + return InNode.makeIntInExpr(resultNode, intValues); + } + + public static TreeNode makeInExpressionBigInt(TreeNode resultNode, + Set<Long> longValues) { + return InNode.makeLongInExpr(resultNode, longValues); + } + + public static TreeNode makeInExpressionDecimal(TreeNode resultNode, + Set<BigDecimal> decimalValues, Integer precision, Integer scale) { + return InNode.makeDecimalInExpr(resultNode, decimalValues, precision, scale); + } + + public static TreeNode makeInExpressionFloat(TreeNode resultNode, + Set<Float> floatValues) { + return InNode.makeFloatInExpr(resultNode, floatValues); + } + + public static TreeNode makeInExpressionDouble(TreeNode resultNode, + Set<Double> doubleValues) { + return InNode.makeDoubleInExpr(resultNode, doubleValues); + } + + public static TreeNode makeInExpressionString(TreeNode resultNode, + Set<String> stringValues) { + return InNode.makeStringInExpr(resultNode, stringValues); + } + + public static TreeNode makeInExpressionBinary(TreeNode resultNode, + Set<byte[]> binaryValues) { + return InNode.makeBinaryInExpr(resultNode, binaryValues); + } +} diff --git a/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeNode.java b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeNode.java new file mode 100644 index 000000000..b8d70d6e7 --- /dev/null +++ b/src/arrow/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeNode.java @@ -0,0 
+1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; + +/** + * Defines an internal node in the expression tree. + */ +public interface TreeNode { + /** + * Converts a TreeNode into a protobuf. + * + * @return A treenode protobuf + * @throws GandivaException in case the TreeNode cannot be processed + */ + GandivaTypes.TreeNode toProtobuf() throws GandivaException; +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java new file mode 100644 index 000000000..4a36c0405 --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/BaseEvaluatorTest.java @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import java.math.BigDecimal; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.expression.Condition; +import org.apache.arrow.gandiva.expression.ExpressionTree; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; +import org.junit.Before; + +class BaseEvaluatorTest { + + interface BaseEvaluator { + + void evaluate(ArrowRecordBatch recordBatch, BufferAllocator allocator) throws GandivaException; + + long getElapsedMillis(); + } + + class ProjectEvaluator implements BaseEvaluator { + + private Projector projector; + private 
DataAndVectorGenerator generator; + private int numExprs; + private int maxRowsInBatch; + private long elapsedTime = 0; + private List<ValueVector> outputVectors = new ArrayList<>(); + + public ProjectEvaluator(Projector projector, + DataAndVectorGenerator generator, + int numExprs, + int maxRowsInBatch) { + this.projector = projector; + this.generator = generator; + this.numExprs = numExprs; + this.maxRowsInBatch = maxRowsInBatch; + } + + @Override + public void evaluate(ArrowRecordBatch recordBatch, + BufferAllocator allocator) throws GandivaException { + // set up output vectors + // for each expression, generate the output vector + for (int i = 0; i < numExprs; i++) { + ValueVector valueVector = generator.generateOutputVector(maxRowsInBatch); + outputVectors.add(valueVector); + } + + try { + long start = System.nanoTime(); + projector.evaluate(recordBatch, outputVectors); + long finish = System.nanoTime(); + elapsedTime += (finish - start); + } finally { + for (ValueVector valueVector : outputVectors) { + valueVector.close(); + } + } + outputVectors.clear(); + } + + @Override + public long getElapsedMillis() { + return TimeUnit.NANOSECONDS.toMillis(elapsedTime); + } + } + + class FilterEvaluator implements BaseEvaluator { + + private Filter filter; + private long elapsedTime = 0; + + public FilterEvaluator(Filter filter) { + this.filter = filter; + } + + @Override + public void evaluate(ArrowRecordBatch recordBatch, + BufferAllocator allocator) throws GandivaException { + ArrowBuf selectionBuffer = allocator.buffer(recordBatch.getLength() * 2); + SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer); + + try { + long start = System.nanoTime(); + filter.evaluate(recordBatch, selectionVector); + long finish = System.nanoTime(); + elapsedTime += (finish - start); + } finally { + selectionBuffer.close(); + } + } + + @Override + public long getElapsedMillis() { + return TimeUnit.NANOSECONDS.toMillis(elapsedTime); + } + } + + interface 
DataAndVectorGenerator { + + void writeData(ArrowBuf buffer); + + ValueVector generateOutputVector(int numRowsInBatch); + } + + class Int32DataAndVectorGenerator implements DataAndVectorGenerator { + + protected final BufferAllocator allocator; + protected final Random rand; + + Int32DataAndVectorGenerator(BufferAllocator allocator) { + this.allocator = allocator; + this.rand = new Random(); + } + + @Override + public void writeData(ArrowBuf buffer) { + buffer.writeInt(rand.nextInt()); + } + + @Override + public ValueVector generateOutputVector(int numRowsInBatch) { + IntVector intVector = new IntVector(BaseEvaluatorTest.EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRowsInBatch); + return intVector; + } + } + + class BoundedInt32DataAndVectorGenerator extends Int32DataAndVectorGenerator { + + private final int upperBound; + + BoundedInt32DataAndVectorGenerator(BufferAllocator allocator, int upperBound) { + super(allocator); + this.upperBound = upperBound; + } + + @Override + public void writeData(ArrowBuf buffer) { + buffer.writeInt(rand.nextInt(upperBound)); + } + } + + protected static final int THOUSAND = 1000; + protected static final int MILLION = THOUSAND * THOUSAND; + + protected static final String EMPTY_SCHEMA_PATH = ""; + + protected BufferAllocator allocator; + protected ArrowType boolType; + protected ArrowType int8; + protected ArrowType int32; + protected ArrowType int64; + protected ArrowType float64; + + @Before + public void init() { + allocator = new RootAllocator(Long.MAX_VALUE); + boolType = new ArrowType.Bool(); + int8 = new ArrowType.Int(8, true); + int32 = new ArrowType.Int(32, true); + int64 = new ArrowType.Int(64, true); + float64 = new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE); + } + + @After + public void tearDown() { + allocator.close(); + } + + ArrowBuf buf(int length) { + ArrowBuf buffer = allocator.buffer(length); + return buffer; + } + + ArrowBuf buf(byte[] bytes) { + ArrowBuf buffer = 
allocator.buffer(bytes.length); + buffer.writeBytes(bytes); + return buffer; + } + + ArrowBuf arrowBufWithAllValid(int size) { + int bufLen = (size + 7) / 8; + ArrowBuf buffer = allocator.buffer(bufLen); + for (int i = 0; i < bufLen; i++) { + buffer.writeByte(255); + } + + return buffer; + } + + ArrowBuf intBuf(int[] ints) { + ArrowBuf buffer = allocator.buffer(ints.length * 4); + for (int i = 0; i < ints.length; i++) { + buffer.writeInt(ints[i]); + } + return buffer; + } + + DecimalVector decimalVector(String[] values, int precision, int scale) { + DecimalVector vector = new DecimalVector("decimal" + Math.random(), allocator, precision, scale); + vector.allocateNew(); + for (int i = 0; i < values.length; i++) { + BigDecimal decimal = new BigDecimal(values[i]).setScale(scale); + vector.setSafe(i, decimal); + } + + vector.setValueCount(values.length); + return vector; + } + + Set decimalSet(String[] values, Integer scale) { + Set<BigDecimal> decimalSet = new HashSet<>(); + for (int i = 0; i < values.length; i++) { + decimalSet.add(new BigDecimal(values[i]).setScale(scale)); + } + + return decimalSet; + } + + VarCharVector varcharVector(String[] values) { + VarCharVector vector = new VarCharVector("VarCharVector" + Math.random(), allocator); + vector.allocateNew(); + for (int i = 0; i < values.length; i++) { + vector.setSafe(i, values[i].getBytes(), 0, values[i].length()); + } + + vector.setValueCount(values.length); + return vector; + } + + ArrowBuf longBuf(long[] longs) { + ArrowBuf buffer = allocator.buffer(longs.length * 8); + for (int i = 0; i < longs.length; i++) { + buffer.writeLong(longs[i]); + } + return buffer; + } + + ArrowBuf doubleBuf(double[] data) { + ArrowBuf buffer = allocator.buffer(data.length * 8); + for (int i = 0; i < data.length; i++) { + buffer.writeDouble(data[i]); + } + + return buffer; + } + + ArrowBuf stringToMillis(String[] dates) { + ArrowBuf buffer = allocator.buffer(dates.length * 8); + for (int i = 0; i < dates.length; i++) { + 
Instant instant = Instant.parse(dates[i]); + buffer.writeLong(instant.toEpochMilli()); + } + + return buffer; + } + + ArrowBuf stringToDayInterval(String[] values) { + ArrowBuf buffer = allocator.buffer(values.length * 8); + for (int i = 0; i < values.length; i++) { + buffer.writeInt(Integer.parseInt(values[i].split(" ")[0])); // days + buffer.writeInt(Integer.parseInt(values[i].split(" ")[1])); // millis + } + return buffer; + } + + void releaseRecordBatch(ArrowRecordBatch recordBatch) { + // There are 2 references to the buffers + // One in the recordBatch - release that by calling close() + // One in the allocator - release that explicitly + List<ArrowBuf> buffers = recordBatch.getBuffers(); + recordBatch.close(); + for (ArrowBuf buf : buffers) { + buf.getReferenceManager().release(); + } + } + + void releaseValueVectors(List<ValueVector> valueVectors) { + for (ValueVector valueVector : valueVectors) { + valueVector.close(); + } + } + + void generateData(DataAndVectorGenerator generator, int numRecords, ArrowBuf buffer) { + for (int i = 0; i < numRecords; i++) { + generator.writeData(buffer); + } + } + + private void generateDataAndEvaluate(DataAndVectorGenerator generator, + BaseEvaluator evaluator, + int numFields, + int numRows, int maxRowsInBatch, + int inputFieldSize) + throws GandivaException, Exception { + int numRemaining = numRows; + List<ArrowBuf> inputData = new ArrayList<ArrowBuf>(); + List<ArrowFieldNode> fieldNodes = new ArrayList<ArrowFieldNode>(); + + // set the bitmap + while (numRemaining > 0) { + int numRowsInBatch = maxRowsInBatch; + if (numRowsInBatch > numRemaining) { + numRowsInBatch = numRemaining; + } + + // generate data + for (int i = 0; i < numFields; i++) { + ArrowBuf buf = allocator.buffer(numRowsInBatch * inputFieldSize); + ArrowBuf validity = arrowBufWithAllValid(maxRowsInBatch); + generateData(generator, numRowsInBatch, buf); + + fieldNodes.add(new ArrowFieldNode(numRowsInBatch, 0)); + inputData.add(validity); + 
inputData.add(buf); + } + + // create record batch + ArrowRecordBatch recordBatch = new ArrowRecordBatch(numRowsInBatch, fieldNodes, inputData); + + evaluator.evaluate(recordBatch, allocator); + + // fix numRemaining + numRemaining -= numRowsInBatch; + + // release refs + releaseRecordBatch(recordBatch); + + inputData.clear(); + fieldNodes.clear(); + } + } + + long timedProject(DataAndVectorGenerator generator, + Schema schema, List<ExpressionTree> exprs, + int numRows, int maxRowsInBatch, + int inputFieldSize) + throws GandivaException, Exception { + Projector projector = Projector.make(schema, exprs); + try { + ProjectEvaluator evaluator = + new ProjectEvaluator(projector, generator, exprs.size(), maxRowsInBatch); + generateDataAndEvaluate(generator, evaluator, + schema.getFields().size(), numRows, maxRowsInBatch, inputFieldSize); + return evaluator.getElapsedMillis(); + } finally { + projector.close(); + } + } + + long timedFilter(DataAndVectorGenerator generator, + Schema schema, Condition condition, + int numRows, int maxRowsInBatch, + int inputFieldSize) + throws GandivaException, Exception { + + Filter filter = Filter.make(schema, condition); + try { + FilterEvaluator evaluator = new FilterEvaluator(filter); + generateDataAndEvaluate(generator, evaluator, + schema.getFields().size(), numRows, maxRowsInBatch, inputFieldSize); + return evaluator.getElapsedMillis(); + } finally { + filter.close(); + } + } +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java new file mode 100644 index 000000000..fe51c09e3 --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/DecimalTypeUtilTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.junit.Assert; +import org.junit.Test; + +public class DecimalTypeUtilTest { + + @Test + public void testOutputTypesForAdd() { + ArrowType.Decimal operand1 = getDecimal(30, 10); + ArrowType.Decimal operand2 = getDecimal(30, 10); + ArrowType.Decimal resultType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.ADD, operand1, operand2); + Assert.assertTrue(getDecimal(31, 10).equals(resultType)); + + operand1 = getDecimal(30, 6); + operand2 = getDecimal(30, 5); + resultType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.ADD, operand1, operand2); + Assert.assertTrue(getDecimal(32, 6).equals(resultType)); + + operand1 = getDecimal(30, 10); + operand2 = getDecimal(38, 10); + resultType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.ADD, operand1, operand2); + Assert.assertTrue(getDecimal(38, 9).equals(resultType)); + + operand1 = getDecimal(38, 10); + operand2 = getDecimal(38, 38); + resultType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.ADD, operand1, operand2); + Assert.assertTrue(getDecimal(38, 9).equals(resultType)); + + operand1 = getDecimal(38, 10); 
+ operand2 = getDecimal(38, 2); + resultType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.ADD, operand1, operand2); + Assert.assertTrue(getDecimal(38, 6).equals(resultType)); + + } + + @Test + public void testOutputTypesForMultiply() { + ArrowType.Decimal operand1 = getDecimal(30, 10); + ArrowType.Decimal operand2 = getDecimal(30, 10); + ArrowType.Decimal resultType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.MULTIPLY, operand1, operand2); + Assert.assertTrue(getDecimal(38, 6).equals(resultType)); + + operand1 = getDecimal(38, 10); + operand2 = getDecimal(9, 2); + resultType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.MULTIPLY, operand1, operand2); + Assert.assertTrue(getDecimal(38, 6).equals(resultType)); + + } + + @Test + public void testOutputTypesForMod() { + ArrowType.Decimal operand1 = getDecimal(30, 10); + ArrowType.Decimal operand2 = getDecimal(28, 7); + ArrowType.Decimal resultType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.MOD, operand1, operand2); + Assert.assertTrue(getDecimal(30, 10).equals(resultType)); + } + + private ArrowType.Decimal getDecimal(int precision, int scale) { + return new ArrowType.Decimal(precision, scale, 128); + } + +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryTest.java new file mode 100644 index 000000000..a51ac09ba --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import java.util.Set; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.junit.Assert; +import org.junit.Test; + +import com.google.common.collect.Lists; + +public class ExpressionRegistryTest { + + @Test + public void testTypes() throws GandivaException { + Set<ArrowType> types = ExpressionRegistry.getInstance().getSupportedTypes(); + ArrowType.Int uint8 = new ArrowType.Int(8, false); + Assert.assertTrue(types.contains(uint8)); + } + + @Test + public void testFunctions() throws GandivaException { + ArrowType.Int uint8 = new ArrowType.Int(8, false); + FunctionSignature signature = + new FunctionSignature("add", uint8, Lists.newArrayList(uint8, uint8)); + Set<FunctionSignature> functions = ExpressionRegistry.getInstance().getSupportedFunctions(); + Assert.assertTrue(functions.contains(signature)); + } + + @Test + public void testFunctionAliases() throws GandivaException { + ArrowType.Int int64 = new ArrowType.Int(64, true); + FunctionSignature signature = + new FunctionSignature("modulo", int64, Lists.newArrayList(int64, int64)); + Set<FunctionSignature> functions = ExpressionRegistry.getInstance().getSupportedFunctions(); + Assert.assertTrue(functions.contains(signature)); + } + + @Test + public void testCaseInsensitiveFunctionName() throws GandivaException 
{ + ArrowType.Utf8 utf8 = new ArrowType.Utf8(); + ArrowType.Int int64 = new ArrowType.Int(64, true); + FunctionSignature signature = + new FunctionSignature("castvarchar", utf8, Lists.newArrayList(utf8, int64)); + Set<FunctionSignature> functions = ExpressionRegistry.getInstance().getSupportedFunctions(); + Assert.assertTrue(functions.contains(signature)); + } +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterProjectTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterProjectTest.java new file mode 100644 index 000000000..51fc1c291 --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterProjectTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.expression.Condition; +import org.apache.arrow.gandiva.expression.ExpressionTree; +import org.apache.arrow.gandiva.expression.TreeBuilder; +import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Test; + +import com.google.common.collect.Lists; + +public class FilterProjectTest extends BaseEvaluatorTest { + + @Test + public void testSimpleSV16() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + Field c = Field.nullable("c", int32); + List<Field> args = Lists.newArrayList(a, b); + + Condition condition = TreeBuilder.makeCondition("less_than", args); + + Schema schema = new Schema(args); + Filter filter = Filter.make(schema, condition); + + ExpressionTree expression = TreeBuilder.makeExpression("add", Lists.newArrayList(a, b), c); + Projector projector = Projector.make(schema, Lists.newArrayList(expression), SelectionVectorType.SV_INT16); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + // second half is "undefined" + int[] aValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] bValues = new int[]{2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15}; + int[] expected = {3, 7, 11, 15}; + + verifyTestCaseFor16(filter, projector, numRows, validity, aValues, bValues, expected); + } 
+ + private void verifyTestCaseFor16(Filter filter, Projector projector, int numRows, byte[] validity, + int[] aValues, int[] bValues, int[] expected) throws GandivaException { + ArrowBuf validitya = buf(validity); + ArrowBuf valuesa = intBuf(aValues); + ArrowBuf validityb = buf(validity); + ArrowBuf valuesb = intBuf(bValues); + ArrowRecordBatch batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(validitya, valuesa, validityb, valuesb)); + + ArrowBuf selectionBuffer = buf(numRows * 2); + SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer); + + filter.evaluate(batch, selectionVector); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(selectionVector.getRecordCount()); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + projector.evaluate(batch, selectionVector, output); + for (int i = 0; i < selectionVector.getRecordCount(); i++) { + assertFalse(intVector.isNull(i)); + assertEquals(expected[i], intVector.get(i)); + } + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + selectionBuffer.close(); + filter.close(); + projector.close(); + } +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java new file mode 100644 index 000000000..ed6e43cd6 --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java @@ -0,0 +1,315 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.List; +import java.util.stream.IntStream; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.expression.Condition; +import org.apache.arrow.gandiva.expression.TreeBuilder; +import org.apache.arrow.gandiva.expression.TreeNode; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Test; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +public class FilterTest extends BaseEvaluatorTest { + + private int[] selectionVectorToArray(SelectionVector vector) { + int[] actual = new int[vector.getRecordCount()]; + for (int i = 0; i < vector.getRecordCount(); ++i) { + actual[i] = vector.getIndex(i); + } + return actual; + } + + private Charset utf8Charset = Charset.forName("UTF-8"); + private Charset utf16Charset = Charset.forName("UTF-16"); + + List<ArrowBuf> varBufs(String[] strings, Charset charset) { + ArrowBuf offsetsBuffer = allocator.buffer((strings.length + 1) * 4); + ArrowBuf 
dataBuffer = allocator.buffer(strings.length * 8); + + int startOffset = 0; + for (int i = 0; i < strings.length; i++) { + offsetsBuffer.writeInt(startOffset); + + final byte[] bytes = strings[i].getBytes(charset); + dataBuffer = dataBuffer.reallocIfNeeded(dataBuffer.writerIndex() + bytes.length); + dataBuffer.setBytes(startOffset, bytes, 0, bytes.length); + startOffset += bytes.length; + } + offsetsBuffer.writeInt(startOffset); // offset for the last element + + return Arrays.asList(offsetsBuffer, dataBuffer); + } + + List<ArrowBuf> stringBufs(String[] strings) { + return varBufs(strings, utf8Charset); + } + + @Test + public void testSimpleInString() throws GandivaException, Exception { + Field c1 = Field.nullable("c1", new ArrowType.Utf8()); + TreeNode l1 = TreeBuilder.makeLiteral(1L); + TreeNode l2 = TreeBuilder.makeLiteral(3L); + + List<Field> argsSchema = Lists.newArrayList(c1); + List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2); + TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8()); + TreeNode inExpr = + TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou")); + + Condition condition = TreeBuilder.makeCondition(inExpr); + + Schema schema = new Schema(argsSchema); + Filter filter = Filter.make(schema, condition); + + int numRows = 16; + byte[] validity = new byte[] {(byte) 255, 0}; + // second half is "undefined" + String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven", + "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", + "sixteen"}; + int[] expected = {0, 1, 2, 3}; + ArrowBuf c1Validity = buf(validity); + ArrowBuf c2Validity = buf(validity); + List<ArrowBuf> dataBufsX = stringBufs(c1Values); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(c1Validity, dataBufsX.get(0), 
dataBufsX.get(1), c2Validity)); + + ArrowBuf selectionBuffer = buf(numRows * 2); + SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer); + + filter.evaluate(batch, selectionVector); + + int[] actual = selectionVectorToArray(selectionVector); + releaseRecordBatch(batch); + selectionBuffer.close(); + filter.close(); + Assert.assertArrayEquals(expected, actual); + } + + @Test + public void testSimpleInInt() throws GandivaException, Exception { + Field c1 = Field.nullable("c1", int32); + + List<Field> argsSchema = Lists.newArrayList(c1); + TreeNode inExpr = + TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4)); + + Condition condition = TreeBuilder.makeCondition(inExpr); + + Schema schema = new Schema(argsSchema); + Filter filter = Filter.make(schema, condition); + + int numRows = 16; + byte[] validity = new byte[] {(byte) 255, 0}; + // second half is "undefined" + int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] expected = {0, 1, 2, 3}; + + ArrowBuf validitya = buf(validity); + ArrowBuf validityb = buf(validity); + ArrowBuf valuesa = intBuf(aValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(validitya, valuesa, validityb)); + + ArrowBuf selectionBuffer = buf(numRows * 2); + SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer); + + filter.evaluate(batch, selectionVector); + + // free buffers + int[] actual = selectionVectorToArray(selectionVector); + releaseRecordBatch(batch); + selectionBuffer.close(); + filter.close(); + Assert.assertArrayEquals(expected, actual); + } + + @Test + public void testSimpleSV16() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + List<Field> args = Lists.newArrayList(a, b); + + Condition condition = 
TreeBuilder.makeCondition("less_than", args); + + Schema schema = new Schema(args); + Filter filter = Filter.make(schema, condition); + + int numRows = 16; + byte[] validity = new byte[] {(byte) 255, 0}; + // second half is "undefined" + int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] bValues = new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15}; + int[] expected = {0, 2, 4, 6}; + + verifyTestCase(filter, numRows, validity, aValues, bValues, expected); + } + + @Test + public void testSimpleSV16_AllMatched() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + List<Field> args = Lists.newArrayList(a, b); + + Condition condition = TreeBuilder.makeCondition("less_than", args); + + Schema schema = new Schema(args); + Filter filter = Filter.make(schema, condition); + + int numRows = 32; + + byte[] validity = new byte[numRows / 8]; + + IntStream.range(0, numRows / 8).forEach(i -> validity[i] = (byte) 255); + + int[] aValues = new int[numRows]; + IntStream.range(0, numRows).forEach(i -> aValues[i] = i); + + int[] bValues = new int[numRows]; + IntStream.range(0, numRows).forEach(i -> bValues[i] = i + 1); + + int[] expected = new int[numRows]; + IntStream.range(0, numRows).forEach(i -> expected[i] = i); + + verifyTestCase(filter, numRows, validity, aValues, bValues, expected); + } + + @Test + public void testSimpleSV16_GreaterThan64Recs() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + List<Field> args = Lists.newArrayList(a, b); + + Condition condition = TreeBuilder.makeCondition("greater_than", args); + + Schema schema = new Schema(args); + Filter filter = Filter.make(schema, condition); + + int numRows = 1000; + + byte[] validity = new byte[numRows / 8]; + + IntStream.range(0, numRows / 8).forEach(i -> validity[i] = (byte) 255); + + int[] aValues = new int[numRows]; + 
IntStream.range(0, numRows).forEach(i -> aValues[i] = i); + + int[] bValues = new int[numRows]; + IntStream.range(0, numRows).forEach(i -> bValues[i] = i + 1); + + aValues[0] = 5; + bValues[0] = 0; + + int[] expected = {0}; + + verifyTestCase(filter, numRows, validity, aValues, bValues, expected); + } + + @Test + public void testSimpleSV32() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + List<Field> args = Lists.newArrayList(a, b); + + Condition condition = TreeBuilder.makeCondition("less_than", args); + + Schema schema = new Schema(args); + Filter filter = Filter.make(schema, condition); + + int numRows = 16; + byte[] validity = new byte[] {(byte) 255, 0}; + // second half is "undefined" + int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] bValues = new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15}; + int[] expected = {0, 2, 4, 6}; + + verifyTestCase(filter, numRows, validity, aValues, bValues, expected); + } + + @Test + public void testSimpleFilterWithNoOptimisation() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + List<Field> args = Lists.newArrayList(a, b); + + Condition condition = TreeBuilder.makeCondition("less_than", args); + + Schema schema = new Schema(args); + Filter filter = Filter.make(schema, condition, false); + + int numRows = 16; + byte[] validity = new byte[] {(byte) 255, 0}; + // second half is "undefined" + int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] bValues = new int[] {2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 14, 15}; + int[] expected = {0, 2, 4, 6}; + + verifyTestCase(filter, numRows, validity, aValues, bValues, expected); + } + + private void verifyTestCase( + Filter filter, int numRows, byte[] validity, int[] aValues, int[] bValues, int[] expected) + throws GandivaException { + ArrowBuf 
validitya = buf(validity); + ArrowBuf valuesa = intBuf(aValues); + ArrowBuf validityb = buf(validity); + ArrowBuf valuesb = intBuf(bValues); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(validitya, valuesa, validityb, valuesb)); + + ArrowBuf selectionBuffer = buf(numRows * 2); + SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer); + + filter.evaluate(batch, selectionVector); + + // free buffers + int[] actual = selectionVectorToArray(selectionVector); + releaseRecordBatch(batch); + selectionBuffer.close(); + filter.close(); + + Assert.assertArrayEquals(expected, actual); + } +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java new file mode 100644 index 000000000..6934c3f9e --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/MicroBenchmarkTest.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import java.util.List; + +import org.apache.arrow.gandiva.expression.Condition; +import org.apache.arrow.gandiva.expression.ExpressionTree; +import org.apache.arrow.gandiva.expression.TreeBuilder; +import org.apache.arrow.gandiva.expression.TreeNode; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +import com.google.common.collect.Lists; + +@Ignore +public class MicroBenchmarkTest extends BaseEvaluatorTest { + + private double toleranceRatio = 4.0; + + @Test + public void testAdd3() throws Exception { + Field x = Field.nullable("x", int32); + Field n2x = Field.nullable("n2x", int32); + Field n3x = Field.nullable("n3x", int32); + + // x + n2x + n3x + TreeNode add1 = + TreeBuilder.makeFunction( + "add", Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeField(n2x)), int32); + TreeNode add = + TreeBuilder.makeFunction( + "add", Lists.newArrayList(add1, TreeBuilder.makeField(n3x)), int32); + ExpressionTree expr = TreeBuilder.makeExpression(add, x); + + List<Field> cols = Lists.newArrayList(x, n2x, n3x); + Schema schema = new Schema(cols); + + long timeTaken = timedProject(new Int32DataAndVectorGenerator(allocator), + schema, + Lists.newArrayList(expr), + 1 * MILLION, 16 * THOUSAND, + 4); + System.out.println("Time taken for projecting 1m records of add3 is " + timeTaken + "ms"); + Assert.assertTrue(timeTaken <= 13 * toleranceRatio); + } + + @Test + public void testIf() throws Exception { + /* + * when x < 10 then 0 + * when x < 20 then 1 + * when x < 30 then 2 + * when x < 40 then 3 + * when x < 50 then 4 + * when x < 60 then 5 + * when x < 70 then 6 + * when x < 80 then 7 + * when x < 90 then 8 + * when x < 100 then 9 + * when x < 110 then 10 + * when x < 120 then 11 + * when x < 130 then 12 + * when x < 140 then 13 + * when x < 150 then 14 + * when x < 160 then 15 + * when x 
< 170 then 16 + * when x < 180 then 17 + * when x < 190 then 18 + * when x < 200 then 19 + * else 20 + */ + Field x = Field.nullable("x", int32); + TreeNode xNode = TreeBuilder.makeField(x); + + // if (x < 100) then 9 else 10 + int returnValue = 20; + TreeNode topNode = TreeBuilder.makeLiteral(returnValue); + int compareWith = 200; + while (compareWith >= 10) { + // cond (x < compareWith) + TreeNode condNode = + TreeBuilder.makeFunction( + "less_than", + Lists.newArrayList(xNode, TreeBuilder.makeLiteral(compareWith)), + boolType); + topNode = + TreeBuilder.makeIf( + condNode, // cond (x < compareWith) + TreeBuilder.makeLiteral(returnValue), // then returnValue + topNode, // else topNode + int32); + compareWith -= 10; + returnValue--; + } + + ExpressionTree expr = TreeBuilder.makeExpression(topNode, x); + Schema schema = new Schema(Lists.newArrayList(x)); + + long timeTaken = timedProject(new BoundedInt32DataAndVectorGenerator(allocator, 250), + schema, + Lists.newArrayList(expr), + 1 * MILLION, 16 * THOUSAND, + 4); + System.out.println("Time taken for projecting 10m records of nestedIf is " + timeTaken + "ms"); + Assert.assertTrue(timeTaken <= 15 * toleranceRatio); + } + + @Test + public void testFilterAdd2() throws Exception { + Field x = Field.nullable("x", int32); + Field n2x = Field.nullable("n2x", int32); + Field n3x = Field.nullable("n3x", int32); + + // x + n2x < n3x + TreeNode add = TreeBuilder.makeFunction("add", + Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeField(n2x)), int32); + TreeNode lessThan = TreeBuilder + .makeFunction("less_than", Lists.newArrayList(add, TreeBuilder.makeField(n3x)), boolType); + Condition condition = TreeBuilder.makeCondition(lessThan); + + List<Field> cols = Lists.newArrayList(x, n2x, n3x); + Schema schema = new Schema(cols); + + long timeTaken = timedFilter(new Int32DataAndVectorGenerator(allocator), + schema, + condition, + 1 * MILLION, 16 * THOUSAND, + 4); + System.out.println("Time taken for filtering 10m 
records of a+b<c is " + timeTaken + "ms"); + Assert.assertTrue(timeTaken <= 12 * toleranceRatio); + } +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java new file mode 100644 index 000000000..28a57c9f8 --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java @@ -0,0 +1,797 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.expression.ExpressionTree; +import org.apache.arrow.gandiva.expression.TreeBuilder; +import org.apache.arrow.gandiva.expression.TreeNode; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.Decimal; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import com.google.common.collect.Lists; + +public class ProjectorDecimalTest extends org.apache.arrow.gandiva.evaluator.BaseEvaluatorTest { + @Rule + public ExpectedException exception = ExpectedException.none(); + + @Test + public void test_add() throws GandivaException { + int precision = 38; + int scale = 8; + ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); + Field a = Field.nullable("a", decimal); + Field b = Field.nullable("b", decimal); + List<Field> args = Lists.newArrayList(a, b); + + ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.ADD, decimal, decimal); + Field retType = Field.nullable("c", outputType); + ExpressionTree root = 
TreeBuilder.makeExpression("add", args, retType); + + List<ExpressionTree> exprs = Lists.newArrayList(root); + + Schema schema = new Schema(args); + Projector eval = Projector.make(schema, exprs); + + int numRows = 4; + byte[] validity = new byte[]{(byte) 255}; + String[] aValues = new String[]{"1.12345678", "2.12345678", "3.12345678", "4.12345678"}; + String[] bValues = new String[]{"2.12345678", "3.12345678", "4.12345678", "5.12345678"}; + + DecimalVector valuesa = decimalVector(aValues, precision, scale); + DecimalVector valuesb = decimalVector(bValues, precision, scale); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(), + valuesb.getValidityBuffer(), valuesb.getDataBuffer())); + + DecimalVector outVector = new DecimalVector("decimal_output", allocator, outputType.getPrecision(), + outputType.getScale()); + outVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(outVector); + eval.evaluate(batch, output); + + // should have scaled down. 
+ BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(3.2469136), + BigDecimal.valueOf(5.2469136), + BigDecimal.valueOf(7.2469136), + BigDecimal.valueOf(9.2469136)}; + + for (int i = 0; i < 4; i++) { + assertFalse(outVector.isNull(i)); + assertTrue("index : " + i + " failed compare", expOutput[i].compareTo(outVector.getObject(i) + ) == 0); + } + + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void test_add_literal() throws GandivaException { + int precision = 2; + int scale = 0; + ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); + ArrowType.Decimal literalType = new ArrowType.Decimal(2, 1, 128); + Field a = Field.nullable("a", decimal); + + ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.ADD, decimal, literalType); + Field retType = Field.nullable("c", outputType); + TreeNode field = TreeBuilder.makeField(a); + TreeNode literal = TreeBuilder.makeDecimalLiteral("6", 2, 1); + List<TreeNode> args = Lists.newArrayList(field, literal); + TreeNode root = TreeBuilder.makeFunction("add", args, outputType); + ExpressionTree tree = TreeBuilder.makeExpression(root, retType); + + List<ExpressionTree> exprs = Lists.newArrayList(tree); + + Schema schema = new Schema(Lists.newArrayList(a)); + Projector eval = Projector.make(schema, exprs); + + int numRows = 4; + String[] aValues = new String[]{"1", "2", "3", "4"}; + + DecimalVector valuesa = decimalVector(aValues, precision, scale); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer())); + + DecimalVector outVector = new DecimalVector("decimal_output", allocator, outputType.getPrecision(), + outputType.getScale()); + outVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(outVector); + 
eval.evaluate(batch, output); + + BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(1.6), BigDecimal.valueOf(2.6), + BigDecimal.valueOf(3.6), BigDecimal.valueOf(4.6)}; + + for (int i = 0; i < 4; i++) { + assertFalse(outVector.isNull(i)); + assertTrue(expOutput[i].compareTo(outVector.getObject(i)) == 0); + } + + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void test_multiply() throws GandivaException { + int precision = 38; + int scale = 8; + ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); + Field a = Field.nullable("a", decimal); + Field b = Field.nullable("b", decimal); + List<Field> args = Lists.newArrayList(a, b); + + ArrowType.Decimal outputType = DecimalTypeUtil.getResultTypeForOperation(DecimalTypeUtil + .OperationType.MULTIPLY, decimal, decimal); + Field retType = Field.nullable("c", outputType); + ExpressionTree root = TreeBuilder.makeExpression("multiply", args, retType); + + List<ExpressionTree> exprs = Lists.newArrayList(root); + + Schema schema = new Schema(args); + Projector eval = Projector.make(schema, exprs); + + int numRows = 4; + byte[] validity = new byte[]{(byte) 255}; + String[] aValues = new String[]{"1.12345678", "2.12345678", "3.12345678", "999999999999.99999999"}; + String[] bValues = new String[]{"2.12345678", "3.12345678", "4.12345678", "999999999999.99999999"}; + + DecimalVector valuesa = decimalVector(aValues, precision, scale); + DecimalVector valuesb = decimalVector(bValues, precision, scale); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(), + valuesb.getValidityBuffer(), valuesb.getDataBuffer())); + + DecimalVector outVector = new DecimalVector("decimal_output", allocator, outputType.getPrecision(), + outputType.getScale()); + outVector.allocateNew(numRows); 
+ + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(outVector); + eval.evaluate(batch, output); + + // should have scaled down. + BigDecimal[] expOutput = new BigDecimal[]{BigDecimal.valueOf(2.385612), + BigDecimal.valueOf(6.632525), + BigDecimal.valueOf(12.879439), + new BigDecimal("999999999999999999980000.000000")}; + + for (int i = 0; i < 4; i++) { + assertFalse(outVector.isNull(i)); + assertTrue("index : " + i + " failed compare", expOutput[i].compareTo(outVector.getObject(i) + ) == 0); + } + + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testCompare() throws GandivaException { + Decimal aType = new Decimal(38, 3, 128); + Decimal bType = new Decimal(38, 2, 128); + Field a = Field.nullable("a", aType); + Field b = Field.nullable("b", bType); + List<Field> args = Lists.newArrayList(a, b); + + List<ExpressionTree> exprs = new ArrayList<>( + Arrays.asList( + TreeBuilder.makeExpression("equal", args, Field.nullable("eq", boolType)), + TreeBuilder.makeExpression("not_equal", args, Field.nullable("ne", boolType)), + TreeBuilder.makeExpression("less_than", args, Field.nullable("lt", boolType)), + TreeBuilder.makeExpression("less_than_or_equal_to", args, Field.nullable("le", boolType)), + TreeBuilder.makeExpression("greater_than", args, Field.nullable("gt", boolType)), + TreeBuilder.makeExpression("greater_than_or_equal_to", args, Field.nullable("ge", boolType)) + ) + ); + + Schema schema = new Schema(args); + Projector eval = Projector.make(schema, exprs); + + List<ValueVector> output = null; + ArrowRecordBatch batch = null; + try { + int numRows = 4; + String[] aValues = new String[]{"7.620", "2.380", "3.860", "-18.160"}; + String[] bValues = new String[]{"7.62", "3.50", "1.90", "-1.45"}; + + DecimalVector valuesa = decimalVector(aValues, aType.getPrecision(), aType.getScale()); + DecimalVector valuesb = decimalVector(bValues, bType.getPrecision(), bType.getScale()); + 
batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(), + valuesb.getValidityBuffer(), valuesb.getDataBuffer())); + + // expected results. + boolean[][] expected = { + {true, false, false, false}, // eq + {false, true, true, true}, // ne + {false, true, false, true}, // lt + {true, true, false, true}, // le + {false, false, true, false}, // gt + {true, false, true, false}, // ge + }; + + // Allocate output vectors. + output = new ArrayList<>( + Arrays.asList( + new BitVector("eq", allocator), + new BitVector("ne", allocator), + new BitVector("lt", allocator), + new BitVector("le", allocator), + new BitVector("gt", allocator), + new BitVector("ge", allocator) + ) + ); + for (ValueVector v : output) { + v.allocateNew(); + } + + // evaluate expressions. + eval.evaluate(batch, output); + + // compare the outputs. + for (int idx = 0; idx < output.size(); ++idx) { + boolean[] expectedArray = expected[idx]; + BitVector resultVector = (BitVector) output.get(idx); + + for (int i = 0; i < numRows; i++) { + assertFalse(resultVector.isNull(i)); + assertEquals("mismatch in result for expr at idx " + idx + " for row " + i, + expectedArray[i], resultVector.getObject(i).booleanValue()); + } + } + } finally { + // free buffers + if (batch != null) { + releaseRecordBatch(batch); + } + if (output != null) { + releaseValueVectors(output); + } + eval.close(); + } + } + + @Test + public void testRound() throws GandivaException { + Decimal aType = new Decimal(38, 2, 128); + Decimal aWithScaleZero = new Decimal(38, 0, 128); + Decimal aWithScaleOne = new Decimal(38, 1, 128); + Field a = Field.nullable("a", aType); + List<Field> args = Lists.newArrayList(a); + + List<ExpressionTree> exprs = new ArrayList<>( + Arrays.asList( + TreeBuilder.makeExpression("abs", args, Field.nullable("abs", aType)), + TreeBuilder.makeExpression("ceil", args, 
Field.nullable("ceil", aWithScaleZero)), + TreeBuilder.makeExpression("floor", args, Field.nullable("floor", aWithScaleZero)), + TreeBuilder.makeExpression("round", args, Field.nullable("round", aWithScaleZero)), + TreeBuilder.makeExpression("truncate", args, Field.nullable("truncate", aWithScaleZero)), + TreeBuilder.makeExpression( + TreeBuilder.makeFunction("round", + Lists.newArrayList(TreeBuilder.makeField(a), TreeBuilder.makeLiteral(1)), + aWithScaleOne), + Field.nullable("round_scale_1", aWithScaleOne)), + TreeBuilder.makeExpression( + TreeBuilder.makeFunction("truncate", + Lists.newArrayList(TreeBuilder.makeField(a), TreeBuilder.makeLiteral(1)), + aWithScaleOne), + Field.nullable("truncate_scale_1", aWithScaleOne)) + ) + ); + + Schema schema = new Schema(args); + Projector eval = Projector.make(schema, exprs); + + List<ValueVector> output = null; + ArrowRecordBatch batch = null; + try { + int numRows = 4; + String[] aValues = new String[]{"1.23", "1.58", "-1.23", "-1.58"}; + + DecimalVector valuesa = decimalVector(aValues, aType.getPrecision(), aType.getScale()); + batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer())); + + // expected results. 
+ BigDecimal[][] expected = { + {BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58), + BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58)}, // abs + {BigDecimal.valueOf(2), BigDecimal.valueOf(2), BigDecimal.valueOf(-1), BigDecimal.valueOf(-1)}, // ceil + {BigDecimal.valueOf(1), BigDecimal.valueOf(1), BigDecimal.valueOf(-2), BigDecimal.valueOf(-2)}, // floor + {BigDecimal.valueOf(1), BigDecimal.valueOf(2), BigDecimal.valueOf(-1), BigDecimal.valueOf(-2)}, // round + {BigDecimal.valueOf(1), BigDecimal.valueOf(1), BigDecimal.valueOf(-1), BigDecimal.valueOf(-1)}, // truncate + {BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.6), + BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.6)}, // round-to-scale-1 + {BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.5), + BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.5)}, // truncate-to-scale-1 + }; + + // Allocate output vectors. + output = new ArrayList<>( + Arrays.asList( + new DecimalVector("abs", allocator, aType.getPrecision(), aType.getScale()), + new DecimalVector("ceil", allocator, aType.getPrecision(), 0), + new DecimalVector("floor", allocator, aType.getPrecision(), 0), + new DecimalVector("round", allocator, aType.getPrecision(), 0), + new DecimalVector("truncate", allocator, aType.getPrecision(), 0), + new DecimalVector("round_to_scale_1", allocator, aType.getPrecision(), 1), + new DecimalVector("truncate_to_scale_1", allocator, aType.getPrecision(), 1) + ) + ); + for (ValueVector v : output) { + v.allocateNew(); + } + + // evaluate expressions. + eval.evaluate(batch, output); + + // compare the outputs. 
+ for (int idx = 0; idx < output.size(); ++idx) { + BigDecimal[] expectedArray = expected[idx]; + DecimalVector resultVector = (DecimalVector) output.get(idx); + + for (int i = 0; i < numRows; i++) { + assertFalse(resultVector.isNull(i)); + assertTrue("mismatch in result for " + + "field " + resultVector.getField().getName() + + " for row " + i + + " expected " + expectedArray[i] + + ", got " + resultVector.getObject(i), + expectedArray[i].compareTo(resultVector.getObject(i)) == 0); + } + } + } finally { + // free buffers + if (batch != null) { + releaseRecordBatch(batch); + } + if (output != null) { + releaseValueVectors(output); + } + eval.close(); + } + } + + @Test + public void testCastToDecimal() throws GandivaException { + Decimal decimalType = new Decimal(38, 2, 128); + Decimal decimalWithScaleOne = new Decimal(38, 1, 128); + Field dec = Field.nullable("dec", decimalType); + Field int64f = Field.nullable("int64", int64); + Field doublef = Field.nullable("float64", float64); + + List<ExpressionTree> exprs = new ArrayList<>( + Arrays.asList( + TreeBuilder.makeExpression("castDECIMAL", + Lists.newArrayList(int64f), + Field.nullable("int64_to_dec", decimalType)), + + TreeBuilder.makeExpression("castDECIMAL", + Lists.newArrayList(doublef), + Field.nullable("float64_to_dec", decimalType)), + + TreeBuilder.makeExpression("castDECIMAL", + Lists.newArrayList(dec), + Field.nullable("dec_to_dec", decimalWithScaleOne)) + ) + ); + + Schema schema = new Schema(Lists.newArrayList(int64f, doublef, dec)); + Projector eval = Projector.make(schema, exprs); + + List<ValueVector> output = null; + ArrowRecordBatch batch = null; + try { + int numRows = 4; + String[] aValues = new String[]{"1.23", "1.58", "-1.23", "-1.58"}; + DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale()); + batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList( + new ArrowFieldNode(numRows, 0), + new ArrowFieldNode(numRows, 0), + new 
ArrowFieldNode(numRows, 0)), + Lists.newArrayList( + arrowBufWithAllValid(4), + longBuf(new long[]{123, 158, -123, -158}), + arrowBufWithAllValid(4), + doubleBuf(new double[]{1.23, 1.58, -1.23, -1.58}), + valuesa.getValidityBuffer(), + valuesa.getDataBuffer()) + ); + + // Allocate output vectors. + output = new ArrayList<>( + Arrays.asList( + new DecimalVector("int64_to_dec", allocator, decimalType.getPrecision(), decimalType.getScale()), + new DecimalVector("float64_to_dec", allocator, decimalType.getPrecision(), decimalType.getScale()), + new DecimalVector("dec_to_dec", allocator, + decimalWithScaleOne.getPrecision(), decimalWithScaleOne.getScale()) + ) + ); + for (ValueVector v : output) { + v.allocateNew(); + } + + // evaluate expressions. + eval.evaluate(batch, output); + + // compare the outputs. + BigDecimal[][] expected = { + { BigDecimal.valueOf(123), BigDecimal.valueOf(158), + BigDecimal.valueOf(-123), BigDecimal.valueOf(-158)}, + { BigDecimal.valueOf(1.23), BigDecimal.valueOf(1.58), + BigDecimal.valueOf(-1.23), BigDecimal.valueOf(-1.58)}, + { BigDecimal.valueOf(1.2), BigDecimal.valueOf(1.6), + BigDecimal.valueOf(-1.2), BigDecimal.valueOf(-1.6)} + }; + for (int idx = 0; idx < output.size(); ++idx) { + BigDecimal[] expectedArray = expected[idx]; + DecimalVector resultVector = (DecimalVector) output.get(idx); + for (int i = 0; i < numRows; i++) { + assertFalse(resultVector.isNull(i)); + assertTrue("mismatch in result for " + + "field " + resultVector.getField().getName() + + " for row " + i + + " expected " + expectedArray[i] + + ", got " + resultVector.getObject(i), + expectedArray[i].compareTo(resultVector.getObject(i)) == 0); + } + } + } finally { + // free buffers + if (batch != null) { + releaseRecordBatch(batch); + } + if (output != null) { + releaseValueVectors(output); + } + eval.close(); + } + } + + @Test + public void testCastToLong() throws GandivaException { + Decimal decimalType = new Decimal(38, 2, 128); + Field dec = Field.nullable("dec", 
decimalType); + + Schema schema = new Schema(Lists.newArrayList(dec)); + Projector eval = Projector.make(schema, + Lists.newArrayList( + TreeBuilder.makeExpression("castBIGINT", + Lists.newArrayList(dec), + Field.nullable("dec_to_int64", int64) + ) + ) + ); + + List<ValueVector> output = null; + ArrowRecordBatch batch = null; + try { + int numRows = 5; + String[] aValues = new String[]{"1.23", "1.50", "98765.78", "-1.23", "-1.58"}; + DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale()); + batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList( + new ArrowFieldNode(numRows, 0) + ), + Lists.newArrayList( + valuesa.getValidityBuffer(), + valuesa.getDataBuffer() + ) + ); + + // Allocate output vectors. + BigIntVector resultVector = new BigIntVector("dec_to_int64", allocator); + resultVector.allocateNew(); + output = new ArrayList<>(Arrays.asList(resultVector)); + + // evaluate expressions. + eval.evaluate(batch, output); + + // compare the outputs. 
+ long[] expected = {1, 2, 98766, -1, -2}; + for (int i = 0; i < numRows; i++) { + assertFalse(resultVector.isNull(i)); + assertEquals(expected[i], resultVector.get(i)); + } + } finally { + // free buffers + if (batch != null) { + releaseRecordBatch(batch); + } + if (output != null) { + releaseValueVectors(output); + } + eval.close(); + } + } + + @Test + public void testCastToDouble() throws GandivaException { + Decimal decimalType = new Decimal(38, 2, 128); + Field dec = Field.nullable("dec", decimalType); + + Schema schema = new Schema(Lists.newArrayList(dec)); + Projector eval = Projector.make(schema, + Lists.newArrayList( + TreeBuilder.makeExpression("castFLOAT8", + Lists.newArrayList(dec), + Field.nullable("dec_to_float64", float64) + ) + ) + ); + + List<ValueVector> output = null; + ArrowRecordBatch batch = null; + try { + int numRows = 4; + String[] aValues = new String[]{"1.23", "1.58", "-1.23", "-1.58"}; + DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale()); + batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList( + new ArrowFieldNode(numRows, 0) + ), + Lists.newArrayList( + valuesa.getValidityBuffer(), + valuesa.getDataBuffer() + ) + ); + + // Allocate output vectors. + Float8Vector resultVector = new Float8Vector("dec_to_float64", allocator); + resultVector.allocateNew(); + output = new ArrayList<>(Arrays.asList(resultVector)); + + // evaluate expressions. + eval.evaluate(batch, output); + + // compare the outputs. 
+ double[] expected = {1.23, 1.58, -1.23, -1.58}; + for (int i = 0; i < numRows; i++) { + assertFalse(resultVector.isNull(i)); + assertEquals(expected[i], resultVector.get(i), 0); + } + } finally { + // free buffers + if (batch != null) { + releaseRecordBatch(batch); + } + if (output != null) { + releaseValueVectors(output); + } + eval.close(); + } + } + + @Test + public void testCastToString() throws GandivaException { + Decimal decimalType = new Decimal(38, 2, 128); + Field dec = Field.nullable("dec", decimalType); + Field str = Field.nullable("str", new ArrowType.Utf8()); + TreeNode field = TreeBuilder.makeField(dec); + TreeNode literal = TreeBuilder.makeLiteral(5L); + List<TreeNode> args = Lists.newArrayList(field, literal); + TreeNode cast = TreeBuilder.makeFunction("castVARCHAR", args, new ArrowType.Utf8()); + TreeNode root = TreeBuilder.makeFunction("equal", + Lists.newArrayList(cast, TreeBuilder.makeField(str)), new ArrowType.Bool()); + ExpressionTree tree = TreeBuilder.makeExpression(root, Field.nullable("are_equal", new ArrowType.Bool())); + + Schema schema = new Schema(Lists.newArrayList(dec, str)); + Projector eval = Projector.make(schema, Lists.newArrayList(tree) + ); + + List<ValueVector> output = null; + ArrowRecordBatch batch = null; + try { + int numRows = 4; + String[] aValues = new String[]{"10.51", "100.23", "-1000.23", "-0000.10"}; + String[] expected = {"10.51", "100.2", "-1000", "-0.10"}; + DecimalVector valuesa = decimalVector(aValues, decimalType.getPrecision(), decimalType.getScale()); + VarCharVector result = varcharVector(expected); + batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList( + new ArrowFieldNode(numRows, 0) + ), + Lists.newArrayList( + valuesa.getValidityBuffer(), + valuesa.getDataBuffer(), + result.getValidityBuffer(), + result.getOffsetBuffer(), + result.getDataBuffer() + ) + ); + + BitVector resultVector = new BitVector("res", allocator); + resultVector.allocateNew(); + output = new 
ArrayList<>(Arrays.asList(resultVector)); + + // evaluate expressions. + eval.evaluate(batch, output); + + // compare the outputs. + for (int i = 0; i < numRows; i++) { + assertTrue(resultVector.getObject(i).booleanValue()); + } + } finally { + // free buffers + if (batch != null) { + releaseRecordBatch(batch); + } + if (output != null) { + releaseValueVectors(output); + } + eval.close(); + } + } + + @Test + public void testCastStringToDecimal() throws GandivaException { + Decimal decimalType = new Decimal(4, 2, 128); + Field dec = Field.nullable("dec", decimalType); + + Field str = Field.nullable("str", new ArrowType.Utf8()); + TreeNode field = TreeBuilder.makeField(str); + List<TreeNode> args = Lists.newArrayList(field); + TreeNode cast = TreeBuilder.makeFunction("castDECIMAL", args, decimalType); + ExpressionTree tree = TreeBuilder.makeExpression(cast, Field.nullable("dec_str", decimalType)); + + Schema schema = new Schema(Lists.newArrayList(str)); + Projector eval = Projector.make(schema, Lists.newArrayList(tree) + ); + + List<ValueVector> output = null; + ArrowRecordBatch batch = null; + try { + int numRows = 4; + String[] aValues = new String[]{"10.5134", "-0.1", "10.516", "-1000"}; + VarCharVector valuesa = varcharVector(aValues); + batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList( + new ArrowFieldNode(numRows, 0) + ), + Lists.newArrayList( + valuesa.getValidityBuffer(), + valuesa.getOffsetBuffer(), + valuesa.getDataBuffer() + ) + ); + + DecimalVector resultVector = new DecimalVector("res", allocator, + decimalType.getPrecision(), decimalType.getScale()); + resultVector.allocateNew(); + output = new ArrayList<>(Arrays.asList(resultVector)); + + BigDecimal[] expected = {BigDecimal.valueOf(10.51), BigDecimal.valueOf(-0.10), + BigDecimal.valueOf(10.52), BigDecimal.valueOf(0.00)}; + // evaluate expressions. + eval.evaluate(batch, output); + + // compare the outputs. 
+ for (int i = 0; i < numRows; i++) { + assertTrue("mismatch in result for " + + "field " + resultVector.getField().getName() + + " for row " + i + + " expected " + expected[i] + + ", got " + resultVector.getObject(i), expected[i].compareTo(resultVector.getObject(i)) == 0); + } + } finally { + // free buffers + if (batch != null) { + releaseRecordBatch(batch); + } + if (output != null) { + releaseValueVectors(output); + } + eval.close(); + } + } + + @Test + public void testInvalidDecimal() throws GandivaException { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Gandiva only supports decimals of upto 38 precision. Input precision" + + " : 0"); + Decimal decimalType = new Decimal(0, 0, 128); + Field int64f = Field.nullable("int64", int64); + + Schema schema = new Schema(Lists.newArrayList(int64f)); + Projector eval = Projector.make(schema, + Lists.newArrayList( + TreeBuilder.makeExpression("castDECIMAL", + Lists.newArrayList(int64f), + Field.nullable("invalid_dec", decimalType) + ) + ) + ); + } + + @Test + public void testInvalidDecimalGt38() throws GandivaException { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Gandiva only supports decimals of upto 38 precision. 
Input precision" + + " : 42"); + Decimal decimalType = new Decimal(42, 0, 128); + Field int64f = Field.nullable("int64", int64); + + Schema schema = new Schema(Lists.newArrayList(int64f)); + Projector eval = Projector.make(schema, + Lists.newArrayList( + TreeBuilder.makeExpression("castDECIMAL", + Lists.newArrayList(int64f), + Field.nullable("invalid_dec", decimalType) + ) + ) + ); + } +} + diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java new file mode 100644 index 000000000..03c9377b0 --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java @@ -0,0 +1,2470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import java.math.BigDecimal; +import java.nio.charset.Charset; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.expression.ExpressionTree; +import org.apache.arrow.gandiva.expression.TreeBuilder; +import org.apache.arrow.gandiva.expression.TreeNode; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.ipc.message.ArrowFieldNode; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +public class ProjectorTest extends BaseEvaluatorTest { + + private Charset utf8Charset = Charset.forName("UTF-8"); + private Charset utf16Charset 
= Charset.forName("UTF-16"); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + List<ArrowBuf> varBufs(String[] strings, Charset charset) { + ArrowBuf offsetsBuffer = allocator.buffer((strings.length + 1) * 4); + + long dataBufferSize = 0L; + for (String string : strings) { + dataBufferSize += string.getBytes(charset).length; + } + + ArrowBuf dataBuffer = allocator.buffer(dataBufferSize); + + int startOffset = 0; + for (int i = 0; i < strings.length; i++) { + offsetsBuffer.writeInt(startOffset); + + final byte[] bytes = strings[i].getBytes(charset); + dataBuffer = dataBuffer.reallocIfNeeded(dataBuffer.writerIndex() + bytes.length); + dataBuffer.setBytes(startOffset, bytes, 0, bytes.length); + startOffset += bytes.length; + } + offsetsBuffer.writeInt(startOffset); // offset for the last element + + return Arrays.asList(offsetsBuffer, dataBuffer); + } + + List<ArrowBuf> stringBufs(String[] strings) { + return varBufs(strings, utf8Charset); + } + + List<ArrowBuf> binaryBufs(String[] strings) { + return varBufs(strings, utf16Charset); + } + + private void testMakeProjectorParallel(ConfigurationBuilder.ConfigOptions configOptions) throws InterruptedException { + List<Schema> schemas = Lists.newArrayList(); + Field a = Field.nullable("a", int64); + Field b = Field.nullable("b", int64); + IntStream.range(0, 1000) + .forEach( + i -> { + Field c = Field.nullable("" + i, int64); + List<Field> cols = Lists.newArrayList(a, b, c); + schemas.add(new Schema(cols)); + }); + + TreeNode aNode = TreeBuilder.makeField(a); + TreeNode bNode = TreeBuilder.makeField(b); + List<TreeNode> args = Lists.newArrayList(aNode, bNode); + + TreeNode cond = TreeBuilder.makeFunction("greater_than", args, boolType); + TreeNode ifNode = TreeBuilder.makeIf(cond, aNode, bNode, int64); + + ExpressionTree expr = TreeBuilder.makeExpression(ifNode, Field.nullable("c", int64)); + List<ExpressionTree> exprs = Lists.newArrayList(expr); + + // build projectors in parallel choosing schema 
at random + // this should hit the same cache entry thus exposing + // any threading issues. + ExecutorService executors = Executors.newFixedThreadPool(16); + + IntStream.range(0, 1000) + .forEach( + i -> { + executors.submit( + () -> { + try { + Projector evaluator = configOptions == null ? + Projector.make(schemas.get((int) (Math.random() * 100)), exprs) : + Projector.make(schemas.get((int) (Math.random() * 100)), exprs, configOptions); + evaluator.close(); + } catch (GandivaException e) { + e.printStackTrace(); + } + }); + }); + executors.shutdown(); + executors.awaitTermination(100, java.util.concurrent.TimeUnit.SECONDS); + } + + @Test + public void testMakeProjectorParallel() throws Exception { + testMakeProjectorParallel(null); + testMakeProjectorParallel(new ConfigurationBuilder.ConfigOptions().withTargetCPU(false)); + testMakeProjectorParallel(new ConfigurationBuilder.ConfigOptions().withTargetCPU(false).withOptimize(false)); + } + + // Will be fixed by https://issues.apache.org/jira/browse/ARROW-4371 + @Ignore + @Test + public void testMakeProjector() throws GandivaException { + Field a = Field.nullable("a", int64); + Field b = Field.nullable("b", int64); + TreeNode aNode = TreeBuilder.makeField(a); + TreeNode bNode = TreeBuilder.makeField(b); + List<TreeNode> args = Lists.newArrayList(aNode, bNode); + + List<Field> cols = Lists.newArrayList(a, b); + Schema schema = new Schema(cols); + + TreeNode cond = TreeBuilder.makeFunction("greater_than", args, boolType); + TreeNode ifNode = TreeBuilder.makeIf(cond, aNode, bNode, int64); + + ExpressionTree expr = TreeBuilder.makeExpression(ifNode, Field.nullable("c", int64)); + List<ExpressionTree> exprs = Lists.newArrayList(expr); + + long startTime = System.currentTimeMillis(); + Projector evaluator1 = Projector.make(schema, exprs); + System.out.println( + "Projector build: iteration 1 took " + (System.currentTimeMillis() - startTime) + " ms"); + startTime = System.currentTimeMillis(); + Projector evaluator2 = 
Projector.make(schema, exprs); + System.out.println( + "Projector build: iteration 2 took " + (System.currentTimeMillis() - startTime) + " ms"); + startTime = System.currentTimeMillis(); + Projector evaluator3 = Projector.make(schema, exprs); + long timeToMakeProjector = (System.currentTimeMillis() - startTime); + // should be getting the projector from the cache; + // giving 5ms for varying system load. + Assert.assertTrue(timeToMakeProjector < 5L); + + evaluator1.close(); + evaluator2.close(); + evaluator3.close(); + } + + @Test + public void testMakeProjectorValidationError() throws InterruptedException { + + Field a = Field.nullable("a", int64); + TreeNode aNode = TreeBuilder.makeField(a); + List<TreeNode> args = Lists.newArrayList(aNode); + + List<Field> cols = Lists.newArrayList(a); + Schema schema = new Schema(cols); + + TreeNode cond = TreeBuilder.makeFunction("non_existent_fn", args, boolType); + + ExpressionTree expr = TreeBuilder.makeExpression(cond, Field.nullable("c", int64)); + List<ExpressionTree> exprs = Lists.newArrayList(expr); + + boolean exceptionThrown = false; + try { + Projector evaluator1 = Projector.make(schema, exprs); + } catch (GandivaException e) { + exceptionThrown = true; + } + + Assert.assertTrue(exceptionThrown); + + // allow GC to collect any temp resources. + Thread.sleep(1000); + + // try again to ensure no temporary resources. 
+ exceptionThrown = false; + try { + Projector evaluator1 = Projector.make(schema, exprs); + } catch (GandivaException e) { + exceptionThrown = true; + } + + Assert.assertTrue(exceptionThrown); + } + + @Test + public void testEvaluate() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + List<Field> args = Lists.newArrayList(a, b); + + Field retType = Field.nullable("c", int32); + ExpressionTree root = TreeBuilder.makeExpression("add", args, retType); + + List<ExpressionTree> exprs = Lists.newArrayList(root); + + Schema schema = new Schema(args); + Projector eval = Projector.make(schema, exprs); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + // second half is "undefined" + int[] aValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] bValues = new int[]{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}; + + ArrowBuf validitya = buf(validity); + ArrowBuf valuesa = intBuf(aValues); + ArrowBuf validityb = buf(validity); + ArrowBuf valuesb = intBuf(bValues); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8)), + Lists.newArrayList(validitya, valuesa, validityb, valuesb)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + eval.evaluate(batch, output); + + for (int i = 0; i < 8; i++) { + assertFalse(intVector.isNull(i)); + assertEquals(17, intVector.get(i)); + } + for (int i = 8; i < 16; i++) { + assertTrue(intVector.isNull(i)); + } + + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testEvaluateDivZero() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + List<Field> args = 
Lists.newArrayList(a, b); + + Field retType = Field.nullable("c", int32); + ExpressionTree root = TreeBuilder.makeExpression("divide", args, retType); + + List<ExpressionTree> exprs = Lists.newArrayList(root); + + Schema schema = new Schema(args); + Projector eval = Projector.make(schema, exprs); + + int numRows = 2; + byte[] validity = new byte[]{(byte) 255}; + // second half is "undefined" + int[] aValues = new int[]{2, 2}; + int[] bValues = new int[]{1, 0}; + + ArrowBuf validitya = buf(validity); + ArrowBuf valuesa = intBuf(aValues); + ArrowBuf validityb = buf(validity); + ArrowBuf valuesb = intBuf(bValues); + ArrowRecordBatch batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(validitya, valuesa, validityb, valuesb)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + boolean exceptionThrown = false; + try { + eval.evaluate(batch, output); + } catch (GandivaException e) { + Assert.assertTrue(e.getMessage().contains("divide by zero")); + exceptionThrown = true; + } + Assert.assertTrue(exceptionThrown); + + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testDivZeroParallel() throws GandivaException, InterruptedException { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + Field c = Field.nullable("c", int32); + List<Field> cols = Lists.newArrayList(a, b); + Schema s = new Schema(cols); + + List<Field> args = Lists.newArrayList(a, b); + + ExpressionTree expr = TreeBuilder.makeExpression("divide", args, c); + List<ExpressionTree> exprs = Lists.newArrayList(expr); + + ExecutorService executors = Executors.newFixedThreadPool(16); + + AtomicInteger errorCount = new AtomicInteger(0); + AtomicInteger errorCountExp = new 
AtomicInteger(0); + // pre-build the projector so that same projector is used for all executions. + Projector test = Projector.make(s, exprs); + + IntStream.range(0, 1000).forEach(i -> { + executors.submit(() -> { + try { + Projector evaluator = Projector.make(s, exprs); + int numRows = 2; + byte[] validity = new byte[]{(byte) 255}; + int[] aValues = new int[]{2, 2}; + int[] bValues; + if (i % 2 == 0) { + errorCountExp.incrementAndGet(); + bValues = new int[]{1, 0}; + } else { + bValues = new int[]{1, 1}; + } + + ArrowBuf validitya = buf(validity); + ArrowBuf valuesa = intBuf(aValues); + ArrowBuf validityb = buf(validity); + ArrowBuf valuesb = intBuf(bValues); + ArrowRecordBatch batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, + 0)), + Lists.newArrayList(validitya, valuesa, validityb, valuesb)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + try { + evaluator.evaluate(batch, output); + } catch (GandivaException e) { + errorCount.incrementAndGet(); + } + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + evaluator.close(); + } catch (GandivaException ignore) { + } + }); + }); + executors.shutdown(); + executors.awaitTermination(100, java.util.concurrent.TimeUnit.SECONDS); + test.close(); + Assert.assertEquals(errorCountExp.intValue(), errorCount.intValue()); + } + + @Test + public void testAdd3() throws GandivaException, Exception { + Field x = Field.nullable("x", int32); + Field n2x = Field.nullable("n2x", int32); + Field n3x = Field.nullable("n3x", int32); + + List<TreeNode> args = new ArrayList<TreeNode>(); + + // x + n2x + n3x + TreeNode add1 = + TreeBuilder.makeFunction( + "add", Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeField(n2x)), int32); + TreeNode add = + TreeBuilder.makeFunction( + "add", 
Lists.newArrayList(add1, TreeBuilder.makeField(n3x)), int32); + ExpressionTree expr = TreeBuilder.makeExpression(add, x); + + List<Field> cols = Lists.newArrayList(x, n2x, n3x); + Schema schema = new Schema(cols); + + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + // second half is "undefined" + int[] xValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] n2xValues = new int[]{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}; + int[] n3xValues = new int[]{1, 2, 3, 4, 4, 3, 2, 1, 5, 6, 7, 8, 8, 7, 6, 5}; + + int[] expected = new int[]{18, 19, 20, 21, 21, 20, 19, 18, 18, 19, 20, 21, 21, 20, 19, 18}; + + ArrowBuf xValidity = buf(validity); + ArrowBuf xData = intBuf(xValues); + ArrowBuf n2xValidity = buf(validity); + ArrowBuf n2xData = intBuf(n2xValues); + ArrowBuf n3xValidity = buf(validity); + ArrowBuf n3xData = intBuf(n3xValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 8); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode, fieldNode), + Lists.newArrayList(xValidity, xData, n2xValidity, n2xData, n3xValidity, n3xData)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + eval.evaluate(batch, output); + + for (int i = 0; i < 8; i++) { + assertFalse(intVector.isNull(i)); + assertEquals(expected[i], intVector.get(i)); + } + for (int i = 8; i < 16; i++) { + assertTrue(intVector.isNull(i)); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testStringFields() throws GandivaException { + /* + * when x < "hello" then octet_length(x) + a + * else octet_length(x) + b + */ + + Field x = Field.nullable("x", new ArrowType.Utf8()); + Field a = Field.nullable("a", new ArrowType.Int(32, 
true)); + Field b = Field.nullable("b", new ArrowType.Int(32, true)); + + ArrowType retType = new ArrowType.Int(32, true); + + TreeNode cond = + TreeBuilder.makeFunction( + "less_than", + Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeStringLiteral("hello")), + boolType); + TreeNode octetLenFuncNode = + TreeBuilder.makeFunction( + "octet_length", Lists.newArrayList(TreeBuilder.makeField(x)), retType); + TreeNode octetLenPlusANode = + TreeBuilder.makeFunction( + "add", Lists.newArrayList(TreeBuilder.makeField(a), octetLenFuncNode), retType); + TreeNode octetLenPlusBNode = + TreeBuilder.makeFunction( + "add", Lists.newArrayList(TreeBuilder.makeField(b), octetLenFuncNode), retType); + + TreeNode ifHello = TreeBuilder.makeIf(cond, octetLenPlusANode, octetLenPlusBNode, retType); + + ExpressionTree expr = TreeBuilder.makeExpression(ifHello, Field.nullable("res", retType)); + Schema schema = new Schema(Lists.newArrayList(a, x, b)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 5; + byte[] validity = new byte[]{(byte) 255, 0}; + // "A função" means "The function" in portugese + String[] valuesX = new String[]{"hell", "abc", "hellox", "ijk", "A função"}; + int[] valuesA = new int[]{10, 20, 30, 40, 50}; + int[] valuesB = new int[]{110, 120, 130, 140, 150}; + int[] expected = new int[]{14, 23, 136, 143, 60}; + + ArrowBuf validityX = buf(validity); + List<ArrowBuf> dataBufsX = stringBufs(valuesX); + ArrowBuf validityA = buf(validity); + ArrowBuf dataA = intBuf(valuesA); + ArrowBuf validityB = buf(validity); + ArrowBuf dataB = intBuf(valuesB); + + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), + Lists.newArrayList( + validityA, dataA, validityX, dataBufsX.get(0), dataBufsX.get(1), validityB, dataB)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + + List<ValueVector> 
output = new ArrayList<ValueVector>(); + output.add(intVector); + eval.evaluate(batch, output); + + for (int i = 0; i < numRows; i++) { + assertFalse(intVector.isNull(i)); + assertEquals(expected[i], intVector.get(i)); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testStringOutput() throws GandivaException { + /* + * if (x >= 0) "hi" else "bye" + */ + + Field x = Field.nullable("x", new ArrowType.Int(32, true)); + + ArrowType retType = new ArrowType.Utf8(); + + TreeNode ifHiBye = TreeBuilder.makeIf( + TreeBuilder.makeFunction( + "greater_than_or_equal_to", + Lists.newArrayList( + TreeBuilder.makeField(x), + TreeBuilder.makeLiteral(0) + ), + boolType), + TreeBuilder.makeStringLiteral("hi"), + TreeBuilder.makeStringLiteral("bye"), + retType); + + ExpressionTree expr = TreeBuilder.makeExpression(ifHiBye, Field.nullable("res", retType)); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + // fill up input record batch + int numRows = 4; + byte[] validity = new byte[]{(byte) 255, 0}; + int[] xValues = new int[]{10, -10, 20, -20}; + String[] expected = new String[]{"hi", "bye", "hi", "bye"}; + ArrowBuf validityX = buf(validity); + ArrowBuf dataX = intBuf(xValues); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0)), + Lists.newArrayList( validityX, dataX)); + + // allocate data for output vector. + VarCharVector outVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator); + outVector.allocateNew(64, numRows); + + + // evaluate expression + List<ValueVector> output = new ArrayList<>(); + output.add(outVector); + eval.evaluate(batch, output); + + // match expected output. + for (int i = 0; i < numRows; i++) { + assertFalse(outVector.isNull(i)); + assertEquals(expected[i], new String(outVector.get(i))); + } + + // test with insufficient data buffer. 
+ try { + outVector.allocateNew(4, numRows); + eval.evaluate(batch, output); + } finally { + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + } + + @Test + public void testRegex() throws GandivaException { + /* + * like "%map%" + */ + + Field x = Field.nullable("x", new ArrowType.Utf8()); + + TreeNode cond = + TreeBuilder.makeFunction( + "like", + Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeStringLiteral("%map%")), + boolType); + ExpressionTree expr = TreeBuilder.makeExpression(cond, Field.nullable("res", boolType)); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 5; + byte[] validity = new byte[]{(byte) 255, 0}; + String[] valuesX = new String[]{"mapD", "maps", "google maps", "map", "MapR"}; + boolean[] expected = new boolean[]{true, true, true, true, false}; + + ArrowBuf validityX = buf(validity); + List<ArrowBuf> dataBufsX = stringBufs(valuesX); + + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(validityX, dataBufsX.get(0), dataBufsX.get(1))); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + for (int i = 0; i < numRows; i++) { + assertFalse(bitVector.isNull(i)); + assertEquals(expected[i], bitVector.getObject(i).booleanValue()); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testRegexpReplace() throws GandivaException { + + Field x = Field.nullable("x", new ArrowType.Utf8()); + Field replaceString = Field.nullable("replaceString", new ArrowType.Utf8()); + + Field retType = Field.nullable("c", new ArrowType.Utf8()); + + TreeNode cond = + TreeBuilder.makeFunction( + "regexp_replace", + 
Lists.newArrayList(TreeBuilder.makeField(x), TreeBuilder.makeStringLiteral("ana"), + TreeBuilder.makeField(replaceString)), + new ArrowType.Utf8()); + ExpressionTree expr = TreeBuilder.makeExpression(cond, retType); + Schema schema = new Schema(Lists.newArrayList(x, replaceString)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 5; + byte[] validity = new byte[]{(byte) 15, 0}; + String[] valuesX = new String[]{"banana", "bananaana", "bananana", "anaana", "anaana"}; + String[] valuesReplace = new String[]{"ue", "", "", "c", ""}; + String[] expected = new String[]{"buena", "bna", "bn", "cc", null}; + + ArrowBuf validityX = buf(validity); + ArrowBuf validityReplace = buf(validity); + List<ArrowBuf> dataBufsX = stringBufs(valuesX); + List<ArrowBuf> dataBufsReplace = stringBufs(valuesReplace); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(validityX, dataBufsX.get(0), dataBufsX.get(1), validityReplace, + dataBufsReplace.get(0), dataBufsReplace.get(1))); + + // allocate data for output vector. + VarCharVector outVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator); + outVector.allocateNew(numRows * 15, numRows); + + // evaluate expression + List<ValueVector> output = new ArrayList<>(); + output.add(outVector); + eval.evaluate(batch, output); + eval.close(); + + // match expected output. 
+ for (int i = 0; i < numRows - 1; i++) { + assertFalse("Expect none value equals null", outVector.isNull(i)); + assertEquals(expected[i], new String(outVector.get(i))); + } + + assertTrue("Last value must be null", outVector.isNull(numRows - 1)); + + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test + public void testRand() throws GandivaException { + + TreeNode randWithSeed = + TreeBuilder.makeFunction( + "rand", + Lists.newArrayList(TreeBuilder.makeLiteral(12)), + float64); + TreeNode rand = + TreeBuilder.makeFunction( + "rand", + Lists.newArrayList(), + float64); + ExpressionTree exprWithSeed = TreeBuilder.makeExpression(randWithSeed, Field.nullable("res", float64)); + ExpressionTree expr = TreeBuilder.makeExpression(rand, Field.nullable("res2", float64)); + Field x = Field.nullable("x", new ArrowType.Utf8()); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector evalWithSeed = Projector.make(schema, Lists.newArrayList(exprWithSeed)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 5; + byte[] validity = new byte[] {(byte) 255, 0}; + String[] valuesX = new String[] {"mapD", "maps", "google maps", "map", "MapR"}; + double[] expected = new double[] {0.1597116001879662D, 0.7347813877263527D, 0.6069965050584282D, + 0.7240285696335824D, 0.09975540272957834D}; + + ArrowBuf validityX = buf(validity); + List<ArrowBuf> dataBufsX = stringBufs(valuesX); + + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(validityX, dataBufsX.get(0), dataBufsX.get(1))); + + Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator); + float8Vector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(float8Vector); + evalWithSeed.evaluate(batch, output); + + for (int i = 0; i < numRows; i++) { + assertFalse(float8Vector.isNull(i)); + assertEquals(expected[i], 
float8Vector.getObject(i), 0.000000001); + } + + eval.evaluate(batch, output); // without seed + assertNotEquals(float8Vector.getObject(0), float8Vector.getObject(1), 0.000000001); + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + evalWithSeed.close(); + } + + @Test + public void testBinaryFields() throws GandivaException { + Field a = Field.nullable("a", new ArrowType.Binary()); + Field b = Field.nullable("b", new ArrowType.Binary()); + List<Field> args = Lists.newArrayList(a, b); + + ArrowType retType = new ArrowType.Bool(); + ExpressionTree expr = TreeBuilder.makeExpression("equal", args, Field.nullable("res", retType)); + + Schema schema = new Schema(Lists.newArrayList(args)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 5; + byte[] validity = new byte[]{(byte) 255, 0}; + String[] valuesA = new String[]{"a", "aa", "aaa", "aaaa", "A função"}; + String[] valuesB = new String[]{"a", "bb", "aaa", "bbbbb", "A função"}; + boolean[] expected = new boolean[]{true, false, true, false, true}; + + ArrowBuf validitya = buf(validity); + ArrowBuf validityb = buf(validity); + List<ArrowBuf> inBufsA = binaryBufs(valuesA); + List<ArrowBuf> inBufsB = binaryBufs(valuesB); + + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8)), + Lists.newArrayList( + validitya, + inBufsA.get(0), + inBufsA.get(1), + validityb, + inBufsB.get(0), + inBufsB.get(1))); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + for (int i = 0; i < numRows; i++) { + assertFalse(bitVector.isNull(i)); + assertEquals(expected[i], bitVector.getObject(i).booleanValue()); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + private TreeNode 
makeLongLessThanCond(TreeNode arg, long value) { + return TreeBuilder.makeFunction( + "less_than", Lists.newArrayList(arg, TreeBuilder.makeLiteral(value)), boolType); + } + + private TreeNode makeLongGreaterThanCond(TreeNode arg, long value) { + return TreeBuilder.makeFunction( + "greater_than", Lists.newArrayList(arg, TreeBuilder.makeLiteral(value)), boolType); + } + + private TreeNode ifLongLessThanElse( + TreeNode arg, long value, long thenValue, TreeNode elseNode, ArrowType type) { + return TreeBuilder.makeIf( + makeLongLessThanCond(arg, value), TreeBuilder.makeLiteral(thenValue), elseNode, type); + } + + @Test + public void testIf() throws GandivaException, Exception { + /* + * when x < 10 then 0 + * when x < 20 then 1 + * when x < 30 then 2 + * when x < 40 then 3 + * when x < 50 then 4 + * when x < 60 then 5 + * when x < 70 then 6 + * when x < 80 then 7 + * when x < 90 then 8 + * when x < 100 then 9 + * else 10 + */ + Field x = Field.nullable("x", int64); + TreeNode xNode = TreeBuilder.makeField(x); + + // if (x < 100) then 9 else 10 + TreeNode ifLess100 = ifLongLessThanElse(xNode, 100L, 9L, TreeBuilder.makeLiteral(10L), int64); + // if (x < 90) then 8 else ifLess100 + TreeNode ifLess90 = ifLongLessThanElse(xNode, 90L, 8L, ifLess100, int64); + // if (x < 80) then 7 else ifLess90 + TreeNode ifLess80 = ifLongLessThanElse(xNode, 80L, 7L, ifLess90, int64); + // if (x < 70) then 6 else ifLess80 + TreeNode ifLess70 = ifLongLessThanElse(xNode, 70L, 6L, ifLess80, int64); + // if (x < 60) then 5 else ifLess70 + TreeNode ifLess60 = ifLongLessThanElse(xNode, 60L, 5L, ifLess70, int64); + // if (x < 50) then 4 else ifLess60 + TreeNode ifLess50 = ifLongLessThanElse(xNode, 50L, 4L, ifLess60, int64); + // if (x < 40) then 3 else ifLess50 + TreeNode ifLess40 = ifLongLessThanElse(xNode, 40L, 3L, ifLess50, int64); + // if (x < 30) then 2 else ifLess40 + TreeNode ifLess30 = ifLongLessThanElse(xNode, 30L, 2L, ifLess40, int64); + // if (x < 20) then 1 else ifLess30 + TreeNode 
ifLess20 = ifLongLessThanElse(xNode, 20L, 1L, ifLess30, int64); + // if (x < 10) then 0 else ifLess20 + TreeNode ifLess10 = ifLongLessThanElse(xNode, 10L, 0L, ifLess20, int64); + + ExpressionTree expr = TreeBuilder.makeExpression(ifLess10, x); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, (byte) 255}; + long[] xValues = new long[]{9, 15, 21, 32, 43, 54, 65, 76, 87, 98, 109, 200, -10, 60, 77, 80}; + long[] expected = new long[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 0, 6, 7, 8}; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf xData = longBuf(xValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, Lists.newArrayList(fieldNode), Lists.newArrayList(bufValidity, xData)); + + BigIntVector bigIntVector = new BigIntVector(EMPTY_SCHEMA_PATH, allocator); + bigIntVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bigIntVector); + eval.evaluate(batch, output); + + for (int i = 0; i < numRows; i++) { + assertFalse(bigIntVector.isNull(i)); + assertEquals(expected[i], bigIntVector.get(i)); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testAnd() throws GandivaException, Exception { + /* + * x > 10 AND x < 20 + */ + ArrowType int64 = new ArrowType.Int(64, true); + + Field x = Field.nullable("x", int64); + TreeNode xNode = TreeBuilder.makeField(x); + TreeNode gt10 = makeLongGreaterThanCond(xNode, 10); + TreeNode lt20 = makeLongLessThanCond(xNode, 20); + TreeNode and = TreeBuilder.makeAnd(Lists.newArrayList(gt10, lt20)); + + Field res = Field.nullable("res", boolType); + + ExpressionTree expr = TreeBuilder.makeExpression(and, res); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, 
Lists.newArrayList(expr)); + + int numRows = 4; + byte[] validity = new byte[]{(byte) 255}; + long[] xValues = new long[]{9, 15, 17, 25}; + boolean[] expected = new boolean[]{false, true, true, false}; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf xData = longBuf(xValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, Lists.newArrayList(fieldNode), Lists.newArrayList(bufValidity, xData)); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + for (int i = 0; i < numRows; i++) { + assertFalse(bitVector.isNull(i)); + assertEquals(expected[i], bitVector.getObject(i).booleanValue()); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testOr() throws GandivaException, Exception { + /* + * x > 10 OR x < 5 + */ + ArrowType int64 = new ArrowType.Int(64, true); + + Field x = Field.nullable("x", int64); + TreeNode xNode = TreeBuilder.makeField(x); + TreeNode gt10 = makeLongGreaterThanCond(xNode, 10); + TreeNode lt5 = makeLongLessThanCond(xNode, 5); + TreeNode or = TreeBuilder.makeOr(Lists.newArrayList(gt10, lt5)); + + Field res = Field.nullable("res", boolType); + + ExpressionTree expr = TreeBuilder.makeExpression(or, res); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 4; + byte[] validity = new byte[]{(byte) 255}; + long[] xValues = new long[]{4, 9, 15, 17}; + boolean[] expected = new boolean[]{true, false, true, true}; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf xData = longBuf(xValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, Lists.newArrayList(fieldNode), 
Lists.newArrayList(bufValidity, xData)); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + for (int i = 0; i < numRows; i++) { + assertFalse(bitVector.isNull(i)); + assertEquals(expected[i], bitVector.getObject(i).booleanValue()); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testNull() throws GandivaException, Exception { + /* + * when x < 10 then 1 + * else null + */ + ArrowType int64 = new ArrowType.Int(64, true); + + Field x = Field.nullable("x", int64); + TreeNode xNode = TreeBuilder.makeField(x); + + // if (x < 10) then 1 else null + TreeNode ifLess10 = ifLongLessThanElse(xNode, 10L, 1L, TreeBuilder.makeNull(int64), int64); + + ExpressionTree expr = TreeBuilder.makeExpression(ifLess10, x); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 2; + byte[] validity = new byte[]{(byte) 255}; + long[] xValues = new long[]{5, 32}; + long[] expected = new long[]{1, 0}; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf xData = longBuf(xValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, Lists.newArrayList(fieldNode), Lists.newArrayList(bufValidity, xData)); + + BigIntVector bigIntVector = new BigIntVector(EMPTY_SCHEMA_PATH, allocator); + bigIntVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bigIntVector); + eval.evaluate(batch, output); + + // first element should be 1 + assertFalse(bigIntVector.isNull(0)); + assertEquals(expected[0], bigIntVector.get(0)); + + // second element should be null + assertTrue(bigIntVector.isNull(1)); + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + 
} + + @Test + public void testTimeNull() throws GandivaException, Exception { + + ArrowType time64 = new ArrowType.Time(TimeUnit.MICROSECOND, 64); + + Field x = Field.nullable("x", time64); + TreeNode xNode = TreeBuilder.makeNull(time64); + + ExpressionTree expr = TreeBuilder.makeExpression(xNode, x); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 2; + byte[] validity = new byte[]{(byte) 255}; + int[] xValues = new int[]{5, 32}; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf xData = intBuf(xValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, Lists.newArrayList(fieldNode), Lists.newArrayList(bufValidity, xData)); + + BigIntVector bigIntVector = new BigIntVector(EMPTY_SCHEMA_PATH, allocator); + bigIntVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bigIntVector); + eval.evaluate(batch, output); + + assertTrue(bigIntVector.isNull(0)); + assertTrue(bigIntVector.isNull(1)); + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testTimeEquals() throws GandivaException, Exception { /* + * when isnotnull(x) then x + * else y + */ + Field x = Field.nullable("x", new ArrowType.Time(TimeUnit.MILLISECOND, 32)); + TreeNode xNode = TreeBuilder.makeField(x); + + Field y = Field.nullable("y", new ArrowType.Time(TimeUnit.MILLISECOND, 32)); + TreeNode yNode = TreeBuilder.makeField(y); + + // if isnotnull(x) then x else y + TreeNode condition = TreeBuilder.makeFunction("isnotnull", Lists.newArrayList(xNode), + boolType); + TreeNode ifCoalesce = TreeBuilder.makeIf( + condition, + xNode, + yNode, + new ArrowType.Time(TimeUnit.MILLISECOND, 32)); + + ExpressionTree expr = TreeBuilder.makeExpression(ifCoalesce, x); + Schema schema = new Schema(Lists.newArrayList(x, y)); + Projector eval = 
Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 2; + byte[] validity = new byte[]{(byte) 1}; + byte[] yValidity = new byte[]{(byte) 3}; + int[] xValues = new int[]{5, 1}; + int[] yValues = new int[]{10, 2}; + int[] expected = new int[]{5, 2}; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf xData = intBuf(xValues); + + ArrowBuf yBufValidity = buf(yValidity); + ArrowBuf yData = intBuf(yValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, xData, yBufValidity, yData)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + eval.evaluate(batch, output); + + // output should be 5 and 2 + assertFalse(intVector.isNull(0)); + assertEquals(expected[0], intVector.get(0)); + assertEquals(expected[1], intVector.get(1)); + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testIsNull() throws GandivaException, Exception { + Field x = Field.nullable("x", float64); + + TreeNode xNode = TreeBuilder.makeField(x); + TreeNode isNull = TreeBuilder.makeFunction("isnull", Lists.newArrayList(xNode), boolType); + ExpressionTree expr = TreeBuilder.makeExpression(isNull, Field.nullable("result", boolType)); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + double[] xValues = + new double[]{ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0 + }; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf xData = doubleBuf(xValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, 
Lists.newArrayList(fieldNode), Lists.newArrayList(bufValidity, xData)); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + for (int i = 0; i < 8; i++) { + assertFalse(bitVector.getObject(i).booleanValue()); + } + for (int i = 8; i < numRows; i++) { + assertTrue(bitVector.getObject(i).booleanValue()); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testEquals() throws GandivaException, Exception { + Field c1 = Field.nullable("c1", int32); + Field c2 = Field.nullable("c2", int32); + + TreeNode c1Node = TreeBuilder.makeField(c1); + TreeNode c2Node = TreeBuilder.makeField(c2); + TreeNode equals = + TreeBuilder.makeFunction("equal", Lists.newArrayList(c1Node, c2Node), boolType); + ExpressionTree expr = TreeBuilder.makeExpression(equals, Field.nullable("result", boolType)); + Schema schema = new Schema(Lists.newArrayList(c1, c2)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + int[] c1Values = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] c2Values = new int[]{1, 2, 3, 4, 8, 7, 6, 5, 16, 15, 14, 13, 12, 11, 10, 9}; + + ArrowBuf c1Validity = buf(validity); + ArrowBuf c1Data = intBuf(c1Values); + ArrowBuf c2Validity = buf(validity); + ArrowBuf c2Data = intBuf(c2Values); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(c1Validity, c1Data, c2Validity, c2Data)); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + 
for (int i = 0; i < 4; i++) { + assertTrue(bitVector.getObject(i).booleanValue()); + } + for (int i = 4; i < 8; i++) { + assertFalse(bitVector.getObject(i).booleanValue()); + } + for (int i = 8; i < 16; i++) { + assertTrue(bitVector.isNull(i)); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testInExpr() throws GandivaException, Exception { + Field c1 = Field.nullable("c1", int32); + + TreeNode inExpr = + TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4, 5, 15, 16)); + ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType)); + Schema schema = new Schema(Lists.newArrayList(c1)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + int[] c1Values = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + + ArrowBuf c1Validity = buf(validity); + ArrowBuf c1Data = intBuf(c1Values); + ArrowBuf c2Validity = buf(validity); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(c1Validity, c1Data, c2Validity)); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + for (int i = 0; i < 5; i++) { + assertTrue(bitVector.getObject(i).booleanValue()); + } + for (int i = 5; i < 16; i++) { + assertFalse(bitVector.getObject(i).booleanValue()); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testInExprDecimal() throws GandivaException, Exception { + Integer precision = 26; + Integer scale = 5; + ArrowType.Decimal decimal = new ArrowType.Decimal(precision, scale, 128); + Field c1 = 
Field.nullable("c1", decimal); + + String[] values = new String[]{"1", "2", "3", "4"}; + Set<BigDecimal> decimalSet = decimalSet(values, scale); + decimalSet.add(new BigDecimal(-0.0)); + decimalSet.add(new BigDecimal(Long.MAX_VALUE)); + decimalSet.add(new BigDecimal(Long.MIN_VALUE)); + TreeNode inExpr = + TreeBuilder.makeInExpressionDecimal(TreeBuilder.makeField(c1), + decimalSet, precision, scale); + ExpressionTree expr = TreeBuilder.makeExpression(inExpr, + Field.nullable("result", boolType)); + Schema schema = new Schema(Lists.newArrayList(c1)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + String[] c1Values = + new String[]{"1", "2", "3", "4", "-0.0", "6", "7", "8", "9", "10", "11", "12", "13", "14", + String.valueOf(Long.MAX_VALUE), + String.valueOf(Long.MIN_VALUE)}; + + DecimalVector c1Data = decimalVector(c1Values, precision, scale); + ArrowBuf c1Validity = buf(validity); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(c1Validity, c1Data.getDataBuffer(), c1Data.getValidityBuffer())); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + for (int i = 0; i < 5; i++) { + assertTrue(bitVector.getObject(i).booleanValue()); + } + for (int i = 5; i < 16; i++) { + assertFalse(bitVector.getObject(i).booleanValue()); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testInExprDouble() throws GandivaException, Exception { + Field c1 = Field.nullable("c1", float64); + + TreeNode inExpr = + TreeBuilder.makeInExpressionDouble(TreeBuilder.makeField(c1), + Sets.newHashSet(1.0, -0.0, 3.0, 4.0, Double.NaN, + 
Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY)); + ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType)); + Schema schema = new Schema(Lists.newArrayList(c1)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + // Create a row-batch with some sample data to look for + int numRows = 16; + // Only the first 8 values will be valid. + byte[] validity = new byte[]{(byte) 255, 0}; + double[] c1Values = new double[]{1, -0.0, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.NaN, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 4, 3}; + + ArrowBuf c1Validity = buf(validity); + ArrowBuf c1Data = doubleBuf(c1Values); + ArrowBuf c2Validity = buf(validity); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(c1Validity, c1Data, c2Validity)); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + // The first four values in the vector must match the expression, but not the other ones. 
+ for (int i = 0; i < 4; i++) { + assertTrue(bitVector.getObject(i).booleanValue()); + } + for (int i = 4; i < 16; i++) { + assertFalse(bitVector.getObject(i).booleanValue()); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testInExprStrings() throws GandivaException, Exception { + Field c1 = Field.nullable("c1", new ArrowType.Utf8()); + + TreeNode l1 = TreeBuilder.makeLiteral(1L); + TreeNode l2 = TreeBuilder.makeLiteral(3L); + List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2); + TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8()); + TreeNode inExpr = + TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou")); + ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType)); + Schema schema = new Schema(Lists.newArrayList(c1)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven", + "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", + "sixteen"}; + + ArrowBuf c1Validity = buf(validity); + List<ArrowBuf> dataBufsX = stringBufs(c1Values); + ArrowBuf c2Validity = buf(validity); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(c1Validity, dataBufsX.get(0), dataBufsX.get(1), c2Validity)); + + BitVector bitVector = new BitVector(EMPTY_SCHEMA_PATH, allocator); + bitVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(bitVector); + eval.evaluate(batch, output); + + for (int i = 0; i < 4; i++) { + assertTrue(bitVector.getObject(i).booleanValue()); + } + for (int i = 5; i < 16; i++) { + 
assertFalse(bitVector.getObject(i).booleanValue()); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testSmallOutputVectors() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + List<Field> args = Lists.newArrayList(a, b); + + Field retType = Field.nullable("c", int32); + ExpressionTree root = TreeBuilder.makeExpression("add", args, retType); + + List<ExpressionTree> exprs = Lists.newArrayList(root); + + Schema schema = new Schema(args); + Projector eval = Projector.make(schema, exprs); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + // second half is "undefined" + int[] aValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] bValues = new int[]{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}; + + ArrowBuf aValidity = buf(validity); + ArrowBuf aData = intBuf(aValues); + ArrowBuf bValidity = buf(validity); + ArrowBuf b2Validity = buf(validity); + ArrowBuf bData = intBuf(bValues); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8)), + Lists.newArrayList(aValidity, aData, bValidity, bData, b2Validity)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + try { + eval.evaluate(batch, output); + } catch (Throwable t) { + intVector.allocateNew(numRows); + eval.evaluate(batch, output); + } + + for (int i = 0; i < 8; i++) { + assertFalse(intVector.isNull(i)); + assertEquals(17, intVector.get(i)); + } + for (int i = 8; i < 16; i++) { + assertTrue(intVector.isNull(i)); + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testDateTime() throws GandivaException, Exception { + ArrowType date64 = new ArrowType.Date(DateUnit.MILLISECOND); + // 
ArrowType time32 = new ArrowType.Time(TimeUnit.MILLISECOND, 32); + ArrowType timeStamp = new ArrowType.Timestamp(TimeUnit.MILLISECOND, "TZ"); + + Field dateField = Field.nullable("date", date64); + // Field timeField = Field.nullable("time", time32); + Field tsField = Field.nullable("timestamp", timeStamp); + + TreeNode dateNode = TreeBuilder.makeField(dateField); + TreeNode tsNode = TreeBuilder.makeField(tsField); + + List<TreeNode> dateArgs = Lists.newArrayList(dateNode); + TreeNode dateToYear = TreeBuilder.makeFunction("extractYear", dateArgs, int64); + TreeNode dateToMonth = TreeBuilder.makeFunction("extractMonth", dateArgs, int64); + TreeNode dateToDay = TreeBuilder.makeFunction("extractDay", dateArgs, int64); + TreeNode dateToHour = TreeBuilder.makeFunction("extractHour", dateArgs, int64); + TreeNode dateToMin = TreeBuilder.makeFunction("extractMinute", dateArgs, int64); + + List<TreeNode> tsArgs = Lists.newArrayList(tsNode); + TreeNode tsToYear = TreeBuilder.makeFunction("extractYear", tsArgs, int64); + TreeNode tsToMonth = TreeBuilder.makeFunction("extractMonth", tsArgs, int64); + TreeNode tsToDay = TreeBuilder.makeFunction("extractDay", tsArgs, int64); + TreeNode tsToHour = TreeBuilder.makeFunction("extractHour", tsArgs, int64); + TreeNode tsToMin = TreeBuilder.makeFunction("extractMinute", tsArgs, int64); + + Field resultField = Field.nullable("result", int64); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(dateToYear, resultField), + TreeBuilder.makeExpression(dateToMonth, resultField), + TreeBuilder.makeExpression(dateToDay, resultField), + TreeBuilder.makeExpression(dateToHour, resultField), + TreeBuilder.makeExpression(dateToMin, resultField), + TreeBuilder.makeExpression(tsToYear, resultField), + TreeBuilder.makeExpression(tsToMonth, resultField), + TreeBuilder.makeExpression(tsToDay, resultField), + TreeBuilder.makeExpression(tsToHour, resultField), + TreeBuilder.makeExpression(tsToMin, resultField)); + + Schema 
schema = new Schema(Lists.newArrayList(dateField, tsField)); + Projector eval = Projector.make(schema, exprs); + + int numRows = 8; + byte[] validity = new byte[]{(byte) 255}; + String[] values = + new String[]{ + "2007-01-01T01:00:00.00Z", + "2007-03-05T03:40:00.00Z", + "2008-05-31T13:55:00.00Z", + "2000-06-30T23:20:00.00Z", + "2000-07-10T20:30:00.00Z", + "2000-08-20T00:14:00.00Z", + "2000-09-30T02:29:00.00Z", + "2000-10-31T05:33:00.00Z" + }; + long[] expYearFromDate = new long[]{2007, 2007, 2008, 2000, 2000, 2000, 2000, 2000}; + long[] expMonthFromDate = new long[]{1, 3, 5, 6, 7, 8, 9, 10}; + long[] expDayFromDate = new long[]{1, 5, 31, 30, 10, 20, 30, 31}; + long[] expHourFromDate = new long[]{1, 3, 13, 23, 20, 0, 2, 5}; + long[] expMinFromDate = new long[]{0, 40, 55, 20, 30, 14, 29, 33}; + + long[][] expValues = + new long[][]{ + expYearFromDate, expMonthFromDate, expDayFromDate, expHourFromDate, expMinFromDate + }; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf millisData = stringToMillis(values); + ArrowBuf buf2Validity = buf(validity); + ArrowBuf millis2Data = stringToMillis(values); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(bufValidity, millisData, buf2Validity, millis2Data)); + + List<ValueVector> output = new ArrayList<ValueVector>(); + for (int i = 0; i < exprs.size(); i++) { + BigIntVector bigIntVector = new BigIntVector(EMPTY_SCHEMA_PATH, allocator); + bigIntVector.allocateNew(numRows); + output.add(bigIntVector); + } + eval.evaluate(batch, output); + eval.close(); + + for (int i = 0; i < output.size(); i++) { + long[] expected = expValues[i % 5]; + BigIntVector bigIntVector = (BigIntVector) output.get(i); + + for (int j = 0; j < numRows; j++) { + assertFalse(bigIntVector.isNull(j)); + assertEquals(expected[j], bigIntVector.get(j)); + } + } + + releaseRecordBatch(batch); + releaseValueVectors(output); 
+ } + + @Test + public void testDateTrunc() throws Exception { + ArrowType date64 = new ArrowType.Date(DateUnit.MILLISECOND); + Field dateField = Field.nullable("date", date64); + + TreeNode dateNode = TreeBuilder.makeField(dateField); + + List<TreeNode> dateArgs = Lists.newArrayList(dateNode); + TreeNode dateToYear = TreeBuilder.makeFunction("date_trunc_Year", dateArgs, date64); + TreeNode dateToMonth = TreeBuilder.makeFunction("date_trunc_Month", dateArgs, date64); + + Field resultField = Field.nullable("result", date64); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(dateToYear, resultField), + TreeBuilder.makeExpression(dateToMonth, resultField)); + + Schema schema = new Schema(Lists.newArrayList(dateField)); + Projector eval = Projector.make(schema, exprs); + + int numRows = 4; + byte[] validity = new byte[]{(byte) 255}; + String[] values = new String[]{ + "2007-01-01T01:00:00.00Z", + "2007-03-05T03:40:00.00Z", + "2008-05-31T13:55:00.00Z", + "2000-06-30T23:20:00.00Z", + }; + String[] expYearFromDate = new String[]{ + "2007-01-01T00:00:00.00Z", + "2007-01-01T00:00:00.00Z", + "2008-01-01T00:00:00.00Z", + "2000-01-01T00:00:00.00Z", + }; + String[] expMonthFromDate = new String[]{ + "2007-01-01T00:00:00.00Z", + "2007-03-01T00:00:00.00Z", + "2008-05-01T00:00:00.00Z", + "2000-06-01T00:00:00.00Z", + }; + + String[][] expValues = new String[][]{ expYearFromDate, expMonthFromDate}; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf millisData = stringToMillis(values); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, millisData)); + + List<ValueVector> output = new ArrayList<ValueVector>(); + for (int i = 0; i < exprs.size(); i++) { + BigIntVector bigIntVector = new BigIntVector(EMPTY_SCHEMA_PATH, allocator); + bigIntVector.allocateNew(numRows); + output.add(bigIntVector); + } + 
eval.evaluate(batch, output); + eval.close(); + + for (int i = 0; i < output.size(); i++) { + String[] expected = expValues[i]; + BigIntVector bigIntVector = (BigIntVector) output.get(i); + + for (int j = 0; j < numRows; j++) { + assertFalse(bigIntVector.isNull(j)); + assertEquals(Instant.parse(expected[j]).toEpochMilli(), bigIntVector.get(j)); + } + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test + public void testUnknownFunction() { + Field c1 = Field.nullable("c1", int8); + Field c2 = Field.nullable("c2", int8); + + TreeNode c1Node = TreeBuilder.makeField(c1); + TreeNode c2Node = TreeBuilder.makeField(c2); + + TreeNode unknown = + TreeBuilder.makeFunction("xxx_yyy", Lists.newArrayList(c1Node, c2Node), int8); + ExpressionTree expr = TreeBuilder.makeExpression(unknown, Field.nullable("result", int8)); + Schema schema = new Schema(Lists.newArrayList(c1, c2)); + boolean caughtException = false; + try { + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + } catch (GandivaException ge) { + caughtException = true; + } + + assertTrue(caughtException); + } + + @Test + public void testCastTimestampToString() throws Exception { + ArrowType timeStamp = new ArrowType.Timestamp(TimeUnit.MILLISECOND, "TZ"); + + Field tsField = Field.nullable("timestamp", timeStamp); + Field lenField = Field.nullable("outLength", int64); + + TreeNode tsNode = TreeBuilder.makeField(tsField); + TreeNode lenNode = TreeBuilder.makeField(lenField); + + TreeNode tsToString = TreeBuilder.makeFunction("castVARCHAR", Lists.newArrayList(tsNode, lenNode), + new ArrowType.Utf8()); + + Field resultField = Field.nullable("result", new ArrowType.Utf8()); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(tsToString, resultField)); + + Schema schema = new Schema(Lists.newArrayList(tsField, lenField)); + Projector eval = Projector.make(schema, exprs); + + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values 
= + new String[] { + "0007-01-01T01:00:00Z", + "2007-03-05T03:40:00Z", + "2008-05-31T13:55:00Z", + "2000-06-30T23:20:00Z", + "2000-07-10T20:30:00Z", + }; + long[] lenValues = + new long[] { + 23L, 24L, 22L, 0L, 4L + }; + + String[] expValues = + new String[] { + "0007-01-01 01:00:00.000", + "2007-03-05 03:40:00.000", + "2008-05-31 13:55:00.00", + "", + "2000", + }; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf millisData = stringToMillis(values); + ArrowBuf lenValidity = buf(validity); + ArrowBuf lenData = longBuf(lenValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(bufValidity, millisData, lenValidity, lenData)); + + List<ValueVector> output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + VarCharVector charVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator); + + charVector.allocateNew(numRows * 23, numRows); + output.add(charVector); + } + eval.evaluate(batch, output); + eval.close(); + + for (ValueVector valueVector : output) { + VarCharVector charVector = (VarCharVector) valueVector; + + for (int j = 0; j < numRows; j++) { + assertFalse(charVector.isNull(j)); + assertEquals(expValues[j], new String(charVector.get(j))); + } + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test + public void testCastDayIntervalToBigInt() throws Exception { + ArrowType dayIntervalType = new ArrowType.Interval(IntervalUnit.DAY_TIME); + + Field dayIntervalField = Field.nullable("dayInterval", dayIntervalType); + + TreeNode intervalNode = TreeBuilder.makeField(dayIntervalField); + + TreeNode intervalToBigint = TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(intervalNode), int64); + + Field resultField = Field.nullable("result", int64); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(intervalToBigint, resultField)); + + Schema schema = new 
Schema(Lists.newArrayList(dayIntervalField)); + Projector eval = Projector.make(schema, exprs); + + int numRows = 5; + byte[] validity = new byte[]{(byte) 255}; + String[] values = + new String[]{ + "1 0", // "days millis" + "2 0", + "1 1", + "10 5000", + "11 86400001", + }; + + Long[] expValues = + new Long[]{ + 86400000L, + 2 * 86400000L, + 86400000L + 1L, + 10 * 86400000L + 5000L, + 11 * 86400000L + 86400001L + }; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf intervalsData = stringToDayInterval(values); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(bufValidity, intervalsData)); + + List<ValueVector> output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + BigIntVector bigIntVector = new BigIntVector(EMPTY_SCHEMA_PATH, allocator); + bigIntVector.allocateNew(numRows); + output.add(bigIntVector); + } + eval.evaluate(batch, output); + eval.close(); + + for (ValueVector valueVector : output) { + BigIntVector bigintVector = (BigIntVector) valueVector; + + for (int j = 0; j < numRows; j++) { + assertFalse(bigintVector.isNull(j)); + assertEquals(expValues[j], Long.valueOf(bigintVector.get(j))); + } + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test + public void testCaseInsensitiveFunctions() throws Exception { + ArrowType timeStamp = new ArrowType.Timestamp(TimeUnit.MILLISECOND, "TZ"); + + Field tsField = Field.nullable("timestamp", timeStamp); + + TreeNode tsNode = TreeBuilder.makeField(tsField); + + TreeNode extractday = TreeBuilder.makeFunction("extractday", Lists.newArrayList(tsNode), + int64); + + ExpressionTree expr = TreeBuilder.makeExpression(extractday, Field.nullable("result", int64)); + Schema schema = new Schema(Lists.newArrayList(tsField)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 5; + byte[] validity = new byte[] 
{(byte) 255}; + String[] values = + new String[] { + "0007-01-01T01:00:00Z", + "2007-03-05T03:40:00Z", + "2008-05-31T13:55:00Z", + "2000-06-30T23:20:00Z", + "2000-07-10T20:30:00Z", + }; + + long[] expValues = + new long[] { + 1, 5, 31, 30, 10 + }; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf millisData = stringToMillis(values); + + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, millisData)); + + List<ValueVector> output = new ArrayList<>(); + BigIntVector bigIntVector = new BigIntVector(EMPTY_SCHEMA_PATH, allocator); + bigIntVector.allocateNew(numRows); + output.add(bigIntVector); + + eval.evaluate(batch, output); + eval.close(); + + for (ValueVector valueVector : output) { + BigIntVector vector = (BigIntVector) valueVector; + + for (int j = 0; j < numRows; j++) { + assertFalse(vector.isNull(j)); + assertEquals(expValues[j], vector.get(j)); + } + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test + public void testCastInt() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castINTFn = TreeBuilder.makeFunction("castINT", Lists.newArrayList(inNode), + int32); + Field resultField = Field.nullable("result", int32); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castINTFn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "0", "123", "-123", "-1", "1" + }; + int[] expValues = + new int[] { + 0, 123, -123, -1, 1 + }; + ArrowBuf bufValidity = buf(validity); + List<ArrowBuf> bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + 
new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List<ValueVector> output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + output.add(intVector); + } + eval.evaluate(batch, output); + eval.close(); + for (ValueVector valueVector : output) { + IntVector intVector = (IntVector) valueVector; + for (int j = 0; j < numRows; j++) { + assertFalse(intVector.isNull(j)); + assertTrue(expValues[j] == intVector.get(j)); + } + } + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test(expected = GandivaException.class) + public void testCastIntInvalidValue() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castINTFn = TreeBuilder.makeFunction("castINT", Lists.newArrayList(inNode), + int32); + Field resultField = Field.nullable("result", int32); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castINTFn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 1; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "abc" + }; + ArrowBuf bufValidity = buf(validity); + List<ArrowBuf> bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List<ValueVector> output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + output.add(intVector); + } + try { + eval.evaluate(batch, output); + } finally { + eval.close(); 
+ releaseRecordBatch(batch); + releaseValueVectors(output); + } + } + + @Test + public void testCastFloat() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castFLOAT8Fn = TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode), + float64); + Field resultField = Field.nullable("result", float64); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castFLOAT8Fn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "2.3", + "-11.11", + "0", + "111", + "12345.67" + }; + double[] expValues = + new double[] { + 2.3, -11.11, 0, 111, 12345.67 + }; + ArrowBuf bufValidity = buf(validity); + List<ArrowBuf> bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List<ValueVector> output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator); + float8Vector.allocateNew(numRows); + output.add(float8Vector); + } + eval.evaluate(batch, output); + eval.close(); + for (ValueVector valueVector : output) { + Float8Vector float8Vector = (Float8Vector) valueVector; + for (int j = 0; j < numRows; j++) { + assertFalse(float8Vector.isNull(j)); + assertTrue(expValues[j] == float8Vector.get(j)); + } + } + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test + public void testCastFloatVarbinary() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Binary()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castFLOAT8Fn = 
TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode), + float64); + Field resultField = Field.nullable("result", float64); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castFLOAT8Fn, resultField)); + Schema schema = new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "2.3", + "-11.11", + "0", + "111", + "12345.67" + }; + double[] expValues = + new double[] { + 2.3, -11.11, 0, 111, 12345.67 + }; + ArrowBuf bufValidity = buf(validity); + List<ArrowBuf> bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List<ValueVector> output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator); + float8Vector.allocateNew(numRows); + output.add(float8Vector); + } + eval.evaluate(batch, output); + eval.close(); + for (ValueVector valueVector : output) { + Float8Vector float8Vector = (Float8Vector) valueVector; + for (int j = 0; j < numRows; j++) { + assertFalse(float8Vector.isNull(j)); + assertTrue(expValues[j] == float8Vector.get(j)); + } + } + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test(expected = GandivaException.class) + public void testCastFloatInvalidValue() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Utf8()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castFLOAT8Fn = TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode), + float64); + Field resultField = Field.nullable("result", float64); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castFLOAT8Fn, resultField)); + Schema schema 
= new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "2.3", + "-11.11", + "abc", + "111", + "12345.67" + }; + ArrowBuf bufValidity = buf(validity); + List<ArrowBuf> bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List<ValueVector> output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator); + float8Vector.allocateNew(numRows); + output.add(float8Vector); + } + try { + eval.evaluate(batch, output); + } finally { + eval.close(); + releaseRecordBatch(batch); + releaseValueVectors(output); + } + } + + @Test + public void testEvaluateWithUnsetTargetHostCPU() throws Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + List<Field> args = Lists.newArrayList(a, b); + + Field retType = Field.nullable("c", int32); + ExpressionTree root = TreeBuilder.makeExpression("add", args, retType); + + List<ExpressionTree> exprs = Lists.newArrayList(root); + + Schema schema = new Schema(args); + Projector eval = Projector.make(schema, exprs, new ConfigurationBuilder.ConfigOptions().withTargetCPU(false )); + + int numRows = 16; + byte[] validity = new byte[]{(byte) 255, 0}; + // second half is "undefined" + int[] aValues = new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int[] bValues = new int[]{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}; + + ArrowBuf validitya = buf(validity); + ArrowBuf valuesa = intBuf(aValues); + ArrowBuf validityb = buf(validity); + ArrowBuf valuesb = intBuf(bValues); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new 
ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8)), + Lists.newArrayList(validitya, valuesa, validityb, valuesb)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + eval.evaluate(batch, output); + + for (int i = 0; i < 8; i++) { + assertFalse(intVector.isNull(i)); + assertEquals(17, intVector.get(i)); + } + for (int i = 8; i < 16; i++) { + assertTrue(intVector.isNull(i)); + } + + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testCastVarcharFromInteger() throws Exception { + Field inField = Field.nullable("input", int32); + Field lenField = Field.nullable("outLength", int64); + + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode lenNode = TreeBuilder.makeField(lenField); + + TreeNode tsToString = TreeBuilder.makeFunction("castVARCHAR", Lists.newArrayList(inNode, lenNode), + new ArrowType.Utf8()); + + Field resultField = Field.nullable("result", new ArrowType.Utf8()); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(tsToString, resultField)); + + Schema schema = new Schema(Lists.newArrayList(inField, lenField)); + Projector eval = Projector.make(schema, exprs); + + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + int[] values = + new int[] { + 2345, + 2345, + 2345, + 2345, + -2345, + }; + long[] lenValues = + new long[] { + 0L, 4L, 2L, 6L, 5L + }; + + String[] expValues = + new String[] { + "", + Integer.toString(2345).substring(0, 4), + Integer.toString(2345).substring(0, 2), + Integer.toString(2345), + Integer.toString(-2345) + }; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf bufData = intBuf(values); + ArrowBuf lenValidity = buf(validity); + ArrowBuf lenData = longBuf(lenValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new 
ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(bufValidity, bufData, lenValidity, lenData)); + + List<ValueVector> output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + VarCharVector charVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator); + + charVector.allocateNew(numRows * 5, numRows); + output.add(charVector); + } + eval.evaluate(batch, output); + eval.close(); + + for (ValueVector valueVector : output) { + VarCharVector charVector = (VarCharVector) valueVector; + + for (int j = 0; j < numRows; j++) { + assertFalse(charVector.isNull(j)); + assertEquals(expValues[j], new String(charVector.get(j))); + } + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test + public void testCastVarcharFromFloat() throws Exception { + Field inField = Field.nullable("input", float64); + Field lenField = Field.nullable("outLength", int64); + + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode lenNode = TreeBuilder.makeField(lenField); + + TreeNode tsToString = TreeBuilder.makeFunction("castVARCHAR", Lists.newArrayList(inNode, lenNode), + new ArrowType.Utf8()); + + Field resultField = Field.nullable("result", new ArrowType.Utf8()); + List<ExpressionTree> exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(tsToString, resultField)); + + Schema schema = new Schema(Lists.newArrayList(inField, lenField)); + Projector eval = Projector.make(schema, exprs); + + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + double[] values = + new double[] { + 0.0, + -0.0, + 1.0, + 0.001, + 0.0009, + 0.00099893, + 999999.9999, + 10000000.0, + 23943410000000.343434, + Double.POSITIVE_INFINITY, + Double.NEGATIVE_INFINITY, + Double.NaN, + 23.45, + 23.45, + -23.45, + }; + long[] lenValues = + new long[] { + 6L, 6L, 6L, 6L, 10L, 15L, 15L, 15L, 30L, + 15L, 15L, 15L, 0L, 6L, 6L + }; + + /* The Java real numbers are represented in two ways and Gandiva must + * follow the same rules: + 
* - If the magnitude of a non-zero number is greater than or equal to 10^7 or less than 10^(-3) + * it will be represented using scientific notation, e.g.: + * - 0.000012 -> 1.2E-5 + * - 10000002.3 -> 1.00000023E7 + * - Otherwise, the number is shown as is. + * + * The test checks that the Gandiva function casts the number using the same notation as + * Java. + * */ + String[] expValues = + new String[] { + Double.toString(0.0), // must be cast to -> "0.0" + Double.toString(-0.0), // must be cast to -> "-0.0" + Double.toString(1.0), // must be cast to -> "1.0" + Double.toString(0.001), // must be cast to -> "0.001" + Double.toString(0.0009), // must be cast to -> "9.0E-4" + Double.toString(0.00099893), // must be cast to -> "9.9893E-4" + Double.toString(999999.9999), // must be cast to -> "999999.9999" + Double.toString(10000000.0), // must be cast to -> "1.0E7" + Double.toString(23943410000000.343434), + Double.toString(Double.POSITIVE_INFINITY), + Double.toString(Double.NEGATIVE_INFINITY), + Double.toString(Double.NaN), + "", + Double.toString(23.45), + Double.toString(-23.45) + }; + + ArrowBuf bufValidity = buf(validity); + ArrowBuf bufData = doubleBuf(values); + ArrowBuf lenValidity = buf(validity); + ArrowBuf lenData = longBuf(lenValues); + + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode, fieldNode), + Lists.newArrayList(bufValidity, bufData, lenValidity, lenData)); + + List<ValueVector> output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + VarCharVector charVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator); + + charVector.allocateNew(numRows * 5, numRows); + output.add(charVector); + } + eval.evaluate(batch, output); + eval.close(); + + for (ValueVector valueVector : output) { + VarCharVector charVector = (VarCharVector) valueVector; + + for (int j = 0; j < numRows; j++) { + assertFalse(charVector.isNull(j)); +
assertEquals(expValues[j], new String(charVector.get(j))); + } + } + + releaseRecordBatch(batch); + releaseValueVectors(output); + } + + @Test + public void testInitCap() throws Exception { + + Field x = Field.nullable("x", new ArrowType.Utf8()); + + Field retType = Field.nullable("c", new ArrowType.Utf8()); + + TreeNode cond = + TreeBuilder.makeFunction( + "initcap", + Lists.newArrayList(TreeBuilder.makeField(x)), + new ArrowType.Utf8()); + ExpressionTree expr = TreeBuilder.makeExpression(cond, retType); + Schema schema = new Schema(Lists.newArrayList(x)); + Projector eval = Projector.make(schema, Lists.newArrayList(expr)); + + int numRows = 5; + byte[] validity = new byte[]{(byte) 15, 0}; + String[] valuesX = new String[]{ + " øhpqršvñ \n\n", + "möbelträger1füße \nmöbelträge'rfüße", + "ÂbĆDËFgh\néll", + "citroën CaR", + "kjk" + }; + + String[] expected = new String[]{ + " Øhpqršvñ \n\n", + "Möbelträger1füße \nMöbelträge'Rfüße", + "Âbćdëfgh\nÉll", + "Citroën Car", + null + }; + + ArrowBuf validityX = buf(validity); + List<ArrowBuf> dataBufsX = stringBufs(valuesX); + + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(validityX, dataBufsX.get(0), dataBufsX.get(1))); + + // allocate data for output vector. + VarCharVector outVector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator); + outVector.allocateNew(numRows * 100, numRows); + + // evaluate expression + List<ValueVector> output = new ArrayList<>(); + output.add(outVector); + eval.evaluate(batch, output); + eval.close(); + + // match expected output. 
+ for (int i = 0; i < numRows - 1; i++) { + assertFalse("Expect none value equals null", outVector.isNull(i)); + assertEquals(expected[i], new String(outVector.get(i))); + } + + assertTrue("Last value must be null", outVector.isNull(numRows - 1)); + + releaseRecordBatch(batch); + releaseValueVectors(output); + } +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/TestJniLoader.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/TestJniLoader.java new file mode 100644 index 000000000..116f0dd9e --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/TestJniLoader.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.evaluator; + +import org.junit.Assert; +import org.junit.Test; + +public class TestJniLoader { + + @Test + public void testDefaultConfiguration() throws Exception { + long configId = JniLoader.getConfiguration(ConfigurationBuilder.ConfigOptions.getDefault()); + Assert.assertEquals(configId, JniLoader.getDefaultConfiguration()); + Assert.assertEquals(configId, JniLoader.getConfiguration(ConfigurationBuilder.ConfigOptions.getDefault())); + + long configId2 = JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions().withOptimize(false)); + long configId3 = JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions().withTargetCPU(false)); + long configId4 = JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions().withOptimize(false) + .withTargetCPU(false)); + + Assert.assertTrue(configId != configId2 && configId2 != configId3 && configId3 != configId4); + + Assert.assertEquals(configId2, JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions() + .withOptimize(false))); + Assert.assertEquals(configId3, JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions() + .withTargetCPU(false))); + Assert.assertEquals(configId4, JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions() + .withOptimize(false).withTargetCPU(false))); + + JniLoader.removeConfiguration(new ConfigurationBuilder.ConfigOptions().withOptimize(false)); + // configids are monotonically updated. 
after a config is removed, new one is assigned with higher id + Assert.assertNotEquals(configId2, JniLoader.getConfiguration(new ConfigurationBuilder.ConfigOptions() + .withOptimize(false))); + + JniLoader.removeConfiguration(new ConfigurationBuilder.ConfigOptions()); + Assert.assertNotEquals(configId, JniLoader.getConfiguration(ConfigurationBuilder.ConfigOptions.getDefault())); + } +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/ArrowTypeHelperTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/ArrowTypeHelperTest.java new file mode 100644 index 000000000..7ddd602bf --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/ArrowTypeHelperTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.gandiva.expression; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Test; + +public class ArrowTypeHelperTest { + + private void testInt(int width, boolean isSigned, int expected) throws GandivaException { + ArrowType arrowType = new ArrowType.Int(width, isSigned); + GandivaTypes.ExtGandivaType gandivaType = ArrowTypeHelper.arrowTypeToProtobuf(arrowType); + assertEquals(expected, gandivaType.getType().getNumber()); + } + + @Test + public void testAllInts() throws GandivaException { + testInt(8, false, GandivaTypes.GandivaType.UINT8_VALUE); + testInt(8, true, GandivaTypes.GandivaType.INT8_VALUE); + testInt(16, false, GandivaTypes.GandivaType.UINT16_VALUE); + testInt(16, true, GandivaTypes.GandivaType.INT16_VALUE); + testInt(32, false, GandivaTypes.GandivaType.UINT32_VALUE); + testInt(32, true, GandivaTypes.GandivaType.INT32_VALUE); + testInt(64, false, GandivaTypes.GandivaType.UINT64_VALUE); + testInt(64, true, GandivaTypes.GandivaType.INT64_VALUE); + } + + private void testFloat(FloatingPointPrecision precision, int expected) throws GandivaException { + ArrowType arrowType = new ArrowType.FloatingPoint(precision); + GandivaTypes.ExtGandivaType gandivaType = ArrowTypeHelper.arrowTypeToProtobuf(arrowType); + assertEquals(expected, gandivaType.getType().getNumber()); + } + + @Test + public void testAllFloats() throws GandivaException { + testFloat(FloatingPointPrecision.HALF, GandivaTypes.GandivaType.HALF_FLOAT_VALUE); + testFloat(FloatingPointPrecision.SINGLE, GandivaTypes.GandivaType.FLOAT_VALUE); + 
testFloat(FloatingPointPrecision.DOUBLE, GandivaTypes.GandivaType.DOUBLE_VALUE); + } + + private void testBasic(ArrowType arrowType, int expected) throws GandivaException { + GandivaTypes.ExtGandivaType gandivaType = ArrowTypeHelper.arrowTypeToProtobuf(arrowType); + assertEquals(expected, gandivaType.getType().getNumber()); + } + + @Test + public void testSimpleTypes() throws GandivaException { + testBasic(new ArrowType.Bool(), GandivaTypes.GandivaType.BOOL_VALUE); + testBasic(new ArrowType.Binary(), GandivaTypes.GandivaType.BINARY_VALUE); + testBasic(new ArrowType.Utf8(), GandivaTypes.GandivaType.UTF8_VALUE); + } + + @Test + public void testField() throws GandivaException { + Field field = Field.nullable("col1", new ArrowType.Bool()); + GandivaTypes.Field f = ArrowTypeHelper.arrowFieldToProtobuf(field); + assertEquals(field.getName(), f.getName()); + assertEquals(true, f.getNullable()); + assertEquals(GandivaTypes.GandivaType.BOOL_VALUE, f.getType().getType().getNumber()); + } + + @Test + public void testSchema() throws GandivaException { + Field a = Field.nullable("a", new ArrowType.Int(16, false)); + Field b = Field.nullable("b", new ArrowType.Int(32, true)); + Field c = Field.nullable("c", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)); + + List<Field> fields = new ArrayList<Field>(); + fields.add(a); + fields.add(b); + fields.add(c); + + GandivaTypes.Schema schema = ArrowTypeHelper.arrowSchemaToProtobuf(new Schema(fields)); + int idx = 0; + for (GandivaTypes.Field f : schema.getColumnsList()) { + assertEquals(fields.get(idx).getName(), f.getName()); + idx++; + } + } +} diff --git a/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/TreeBuilderTest.java b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/TreeBuilderTest.java new file mode 100644 index 000000000..90373cf79 --- /dev/null +++ b/src/arrow/java/gandiva/src/test/java/org/apache/arrow/gandiva/expression/TreeBuilderTest.java @@ -0,0 +1,350 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.expression; + +import static org.junit.Assert.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.arrow.gandiva.exceptions.GandivaException; +import org.apache.arrow.gandiva.ipc.GandivaTypes; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.junit.Test; + +public class TreeBuilderTest { + + @Test + public void testMakeLiteral() throws GandivaException { + TreeNode n = TreeBuilder.makeLiteral(Boolean.TRUE); + GandivaTypes.TreeNode node = n.toProtobuf(); + + assertEquals(true, node.getBooleanNode().getValue()); + + n = TreeBuilder.makeLiteral(new Integer(10)); + node = n.toProtobuf(); + assertEquals(10, node.getIntNode().getValue()); + + n = TreeBuilder.makeLiteral(new Long(50)); + node = n.toProtobuf(); + assertEquals(50, node.getLongNode().getValue()); + + Float f = new Float(2.5); + n = TreeBuilder.makeLiteral(f); + node = n.toProtobuf(); + assertEquals(f.floatValue(), node.getFloatNode().getValue(), 0.1); + + Double d = new Double(3.3); + n = 
TreeBuilder.makeLiteral(d); + node = n.toProtobuf(); + assertEquals(d.doubleValue(), node.getDoubleNode().getValue(), 0.1); + + String s = new String("hello"); + n = TreeBuilder.makeStringLiteral(s); + node = n.toProtobuf(); + assertArrayEquals(s.getBytes(), node.getStringNode().getValue().toByteArray()); + + byte[] b = new String("hello").getBytes(); + n = TreeBuilder.makeBinaryLiteral(b); + node = n.toProtobuf(); + assertArrayEquals(b, node.getBinaryNode().getValue().toByteArray()); + } + + @Test + public void testMakeNull() throws GandivaException { + TreeNode n = TreeBuilder.makeNull(new ArrowType.Bool()); + GandivaTypes.TreeNode node = n.toProtobuf(); + assertEquals( + GandivaTypes.GandivaType.BOOL_VALUE, node.getNullNode().getType().getType().getNumber()); + + n = TreeBuilder.makeNull(new ArrowType.Int(32, true)); + node = n.toProtobuf(); + assertEquals( + GandivaTypes.GandivaType.INT32_VALUE, node.getNullNode().getType().getType().getNumber()); + + n = TreeBuilder.makeNull(new ArrowType.Int(64, false)); + node = n.toProtobuf(); + assertEquals( + GandivaTypes.GandivaType.UINT64_VALUE, node.getNullNode().getType().getType().getNumber()); + + n = TreeBuilder.makeNull(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)); + node = n.toProtobuf(); + assertEquals( + GandivaTypes.GandivaType.FLOAT_VALUE, node.getNullNode().getType().getType().getNumber()); + } + + @Test + public void testMakeField() throws GandivaException { + TreeNode n = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Int(32, true))); + GandivaTypes.TreeNode node = n.toProtobuf(); + + assertEquals("a", node.getFieldNode().getField().getName()); + assertEquals( + GandivaTypes.GandivaType.INT32_VALUE, + node.getFieldNode().getField().getType().getType().getNumber()); + } + + @Test + public void testMakeFunction() throws GandivaException { + TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Int(64, false))); + TreeNode b = TreeBuilder.makeField(Field.nullable("b", 
new ArrowType.Int(64, false))); + List<TreeNode> args = new ArrayList<TreeNode>(2); + args.add(a); + args.add(b); + + TreeNode addNode = TreeBuilder.makeFunction("add", args, new ArrowType.Int(64, false)); + GandivaTypes.TreeNode node = addNode.toProtobuf(); + + assertTrue(node.hasFnNode()); + assertEquals("add", node.getFnNode().getFunctionName()); + assertEquals("a", node.getFnNode().getInArgsList().get(0).getFieldNode().getField().getName()); + assertEquals("b", node.getFnNode().getInArgsList().get(1).getFieldNode().getField().getName()); + assertEquals( + GandivaTypes.GandivaType.UINT64_VALUE, + node.getFnNode().getReturnType().getType().getNumber()); + } + + @Test + public void testMakeIf() throws GandivaException { + Field a = Field.nullable("a", new ArrowType.Int(64, false)); + Field b = Field.nullable("b", new ArrowType.Int(64, false)); + TreeNode aNode = TreeBuilder.makeField(a); + TreeNode bNode = TreeBuilder.makeField(b); + List<TreeNode> args = new ArrayList<TreeNode>(2); + args.add(aNode); + args.add(bNode); + + ArrowType retType = new ArrowType.Bool(); + TreeNode cond = TreeBuilder.makeFunction("greater_than", args, retType); + TreeNode ifNode = TreeBuilder.makeIf(cond, aNode, bNode, retType); + + GandivaTypes.TreeNode node = ifNode.toProtobuf(); + + assertTrue(node.hasIfNode()); + assertEquals("greater_than", node.getIfNode().getCond().getFnNode().getFunctionName()); + assertEquals(a.getName(), node.getIfNode().getThenNode().getFieldNode().getField().getName()); + assertEquals(b.getName(), node.getIfNode().getElseNode().getFieldNode().getField().getName()); + assertEquals( + GandivaTypes.GandivaType.BOOL_VALUE, + node.getIfNode().getReturnType().getType().getNumber()); + } + + @Test + public void testMakeAnd() throws GandivaException { + TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Bool())); + TreeNode b = TreeBuilder.makeField(Field.nullable("b", new ArrowType.Bool())); + List<TreeNode> args = new ArrayList<TreeNode>(2); + 
args.add(a); + args.add(b); + + TreeNode andNode = TreeBuilder.makeAnd(args); + GandivaTypes.TreeNode node = andNode.toProtobuf(); + + assertTrue(node.hasAndNode()); + assertEquals(2, node.getAndNode().getArgsList().size()); + assertEquals("a", node.getAndNode().getArgsList().get(0).getFieldNode().getField().getName()); + assertEquals("b", node.getAndNode().getArgsList().get(1).getFieldNode().getField().getName()); + } + + @Test + public void testMakeOr() throws GandivaException { + TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Bool())); + TreeNode b = TreeBuilder.makeField(Field.nullable("b", new ArrowType.Bool())); + List<TreeNode> args = new ArrayList<TreeNode>(2); + args.add(a); + args.add(b); + + TreeNode orNode = TreeBuilder.makeOr(args); + GandivaTypes.TreeNode node = orNode.toProtobuf(); + + assertTrue(node.hasOrNode()); + assertEquals(2, node.getOrNode().getArgsList().size()); + assertEquals("a", node.getOrNode().getArgsList().get(0).getFieldNode().getField().getName()); + assertEquals("b", node.getOrNode().getArgsList().get(1).getFieldNode().getField().getName()); + } + + @Test + public void testExpression() throws GandivaException { + Field a = Field.nullable("a", new ArrowType.Int(64, false)); + Field b = Field.nullable("b", new ArrowType.Int(64, false)); + TreeNode aNode = TreeBuilder.makeField(a); + TreeNode bNode = TreeBuilder.makeField(b); + List<TreeNode> args = new ArrayList<TreeNode>(2); + args.add(aNode); + args.add(bNode); + + ArrowType retType = new ArrowType.Bool(); + TreeNode cond = TreeBuilder.makeFunction("greater_than", args, retType); + TreeNode ifNode = TreeBuilder.makeIf(cond, aNode, bNode, retType); + + ExpressionTree expr = TreeBuilder.makeExpression(ifNode, Field.nullable("c", retType)); + + GandivaTypes.ExpressionRoot root = expr.toProtobuf(); + + assertTrue(root.getRoot().hasIfNode()); + assertEquals( + "greater_than", root.getRoot().getIfNode().getCond().getFnNode().getFunctionName()); + assertEquals("c", 
root.getResultType().getName()); + assertEquals( + GandivaTypes.GandivaType.BOOL_VALUE, root.getResultType().getType().getType().getNumber()); + } + + @Test + public void testExpression2() throws GandivaException { + Field a = Field.nullable("a", new ArrowType.Int(64, false)); + Field b = Field.nullable("b", new ArrowType.Int(64, false)); + List<Field> args = new ArrayList<Field>(2); + args.add(a); + args.add(b); + + Field c = Field.nullable("c", new ArrowType.Int(64, false)); + ExpressionTree expr = TreeBuilder.makeExpression("add", args, c); + GandivaTypes.ExpressionRoot root = expr.toProtobuf(); + + GandivaTypes.TreeNode node = root.getRoot(); + + assertEquals("c", root.getResultType().getName()); + assertTrue(node.hasFnNode()); + assertEquals("add", node.getFnNode().getFunctionName()); + assertEquals("a", node.getFnNode().getInArgsList().get(0).getFieldNode().getField().getName()); + assertEquals("b", node.getFnNode().getInArgsList().get(1).getFieldNode().getField().getName()); + assertEquals( + GandivaTypes.GandivaType.UINT64_VALUE, + node.getFnNode().getReturnType().getType().getNumber()); + } + + @Test + public void testExpressionWithAnd() throws GandivaException { + TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Bool())); + TreeNode b = TreeBuilder.makeField(Field.nullable("b", new ArrowType.Bool())); + List<TreeNode> args = new ArrayList<TreeNode>(2); + args.add(a); + args.add(b); + + TreeNode andNode = TreeBuilder.makeAnd(args); + ExpressionTree expr = + TreeBuilder.makeExpression(andNode, Field.nullable("c", new ArrowType.Bool())); + GandivaTypes.ExpressionRoot root = expr.toProtobuf(); + + assertTrue(root.getRoot().hasAndNode()); + assertEquals( + "a", root.getRoot().getAndNode().getArgsList().get(0).getFieldNode().getField().getName()); + assertEquals( + "b", root.getRoot().getAndNode().getArgsList().get(1).getFieldNode().getField().getName()); + assertEquals("c", root.getResultType().getName()); + assertEquals( + 
GandivaTypes.GandivaType.BOOL_VALUE, root.getResultType().getType().getType().getNumber()); + } + + @Test + public void testExpressionWithOr() throws GandivaException { + TreeNode a = TreeBuilder.makeField(Field.nullable("a", new ArrowType.Bool())); + TreeNode b = TreeBuilder.makeField(Field.nullable("b", new ArrowType.Bool())); + List<TreeNode> args = new ArrayList<TreeNode>(2); + args.add(a); + args.add(b); + + TreeNode orNode = TreeBuilder.makeOr(args); + ExpressionTree expr = + TreeBuilder.makeExpression(orNode, Field.nullable("c", new ArrowType.Bool())); + GandivaTypes.ExpressionRoot root = expr.toProtobuf(); + + assertTrue(root.getRoot().hasOrNode()); + assertEquals( + "a", root.getRoot().getOrNode().getArgsList().get(0).getFieldNode().getField().getName()); + assertEquals( + "b", root.getRoot().getOrNode().getArgsList().get(1).getFieldNode().getField().getName()); + assertEquals("c", root.getResultType().getName()); + assertEquals( + GandivaTypes.GandivaType.BOOL_VALUE, root.getResultType().getType().getType().getNumber()); + } + + @Test + public void testCondition() throws GandivaException { + Field a = Field.nullable("a", new ArrowType.Int(64, false)); + Field b = Field.nullable("b", new ArrowType.Int(64, false)); + + TreeNode aNode = TreeBuilder.makeField(a); + TreeNode bNode = TreeBuilder.makeField(b); + List<TreeNode> args = new ArrayList<TreeNode>(2); + args.add(aNode); + args.add(bNode); + + TreeNode root = TreeBuilder.makeFunction("greater_than", args, new ArrowType.Bool()); + Condition condition = TreeBuilder.makeCondition(root); + + GandivaTypes.Condition conditionProto = condition.toProtobuf(); + assertTrue(conditionProto.getRoot().hasFnNode()); + assertEquals("greater_than", conditionProto.getRoot().getFnNode().getFunctionName()); + assertEquals( + "a", + conditionProto + .getRoot() + .getFnNode() + .getInArgsList() + .get(0) + .getFieldNode() + .getField() + .getName()); + assertEquals( + "b", + conditionProto + .getRoot() + .getFnNode() + 
.getInArgsList() + .get(1) + .getFieldNode() + .getField() + .getName()); + } + + @Test + public void testCondition2() throws GandivaException { + Field a = Field.nullable("a", new ArrowType.Int(64, false)); + Field b = Field.nullable("b", new ArrowType.Int(64, false)); + + Condition condition = TreeBuilder.makeCondition("greater_than", Arrays.asList(a, b)); + + GandivaTypes.Condition conditionProto = condition.toProtobuf(); + assertTrue(conditionProto.getRoot().hasFnNode()); + assertEquals("greater_than", conditionProto.getRoot().getFnNode().getFunctionName()); + assertEquals( + "a", + conditionProto + .getRoot() + .getFnNode() + .getInArgsList() + .get(0) + .getFieldNode() + .getField() + .getName()); + assertEquals( + "b", + conditionProto + .getRoot() + .getFnNode() + .getInArgsList() + .get(1) + .getFieldNode() + .getField() + .getName()); + } +} diff --git a/src/arrow/java/gandiva/src/test/resources/logback.xml b/src/arrow/java/gandiva/src/test/resources/logback.xml new file mode 100644 index 000000000..f9e449fa6 --- /dev/null +++ b/src/arrow/java/gandiva/src/test/resources/logback.xml @@ -0,0 +1,28 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor + license agreements. See the NOTICE file distributed with this work for additional + information regarding copyright ownership. The ASF licenses this file to + You under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of + the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required + by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS + OF ANY KIND, either express or implied. See the License for the specific + language governing permissions and limitations under the License. 
--> + +<configuration> + <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> + <!-- encoders are assigned the type + ch.qos.logback.classic.encoder.PatternLayoutEncoder by default --> + <encoder> + <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern> + </encoder> + </appender> + + <statusListener class="ch.qos.logback.core.status.NopStatusListener"/> + <logger name="org.apache.arrow" additivity="false"> + <level value="info" /> + <appender-ref ref="STDOUT" /> + </logger> + +</configuration> |